diff --git a/.bazelrc b/.bazelrc index 5e9ffb1fd0acea..bfc465c66cfe67 100644 --- a/.bazelrc +++ b/.bazelrc @@ -155,6 +155,8 @@ build:android_x86_64 --fat_apk_cpu=x86_64 # Build everything statically for Android since all static libs are later # bundled together into a single .so for deployment. build:android --dynamic_mode=off +# TODO(belitskiy): Remove once on Clang 20. +build:android --define=xnn_enable_avxvnniint8=false # Sets the default Apple platform to macOS. build:macos --apple_platform_type=macos @@ -217,6 +219,7 @@ build:mkl_aarch64 -c opt # Config setting to build oneDNN with Compute Library for the Arm Architecture (ACL). # with Eigen threadpool support build:mkl_aarch64_threadpool --define=build_with_mkl_aarch64=true +build:mkl_aarch64_threadpool --@compute_library//:openmp=false build:mkl_aarch64_threadpool -c opt # CUDA: This config refers to building CUDA op kernels with nvcc. @@ -245,6 +248,8 @@ build:cuda_clang --copt=-Qunused-arguments # major release. Example: sm_80 kernels can run on sm_89 GPUs but # not on sm_90 GPUs. compute_80 kernels though can also run on sm_90 GPUs. build:cuda_clang --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90" +# Permit newer CUDA versions than Clang is aware of +build:cuda_clang --copt="-Wno-unknown-cuda-version" # Set lld as the linker. build:cuda_clang --host_linkopt="-fuse-ld=lld" build:cuda_clang --host_linkopt="-lm" @@ -259,10 +264,11 @@ build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-1 build:cuda_clang_official --crosstool_top="@local_config_cuda//crosstool:toolchain" # Build with nvcc for CUDA and clang for host -build:nvcc_clang --config=cuda -build:nvcc_clang --action_env=TF_NVCC_CLANG="1" -build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc - +build:cuda_nvcc --config=cuda +build:cuda_nvcc --action_env=TF_NVCC_CLANG="1" +build:cuda_nvcc --@local_config_cuda//:cuda_compiler=nvcc +# Old config for backward compatibility +build:nvcc_clang --config=cuda_nvcc # Debug config build:dbg -c dbg @@ -354,12 +360,16 @@ build:linux --copt="-Werror=unused-result" # Add switch as an error on Linux. build:linux --copt="-Wswitch" build:linux --copt="-Werror=switch" +<<<<<<< HEAD # Required for building with clang build:linux --copt="-Wno-error=unused-but-set-variable" # We have some invalid linker scripts in the build, # so we need to disable this check build:linux --linkopt=-Wl,--undefined-version build:linux --host_linkopt=-Wl,--undefined-version +======= + +>>>>>>> master # Linux ARM64 specific options build:linux_arm64 --copt="-mtune=generic" --copt="-march=armv8-a" --copt="-O3" @@ -593,7 +603,7 @@ build:rbe_linux_rocm_base --action_env=TF_ROCM_CONFIG_REPO="@ubuntu20.04-gcc9_ma build:rbe_linux_rocm_py3.9 --config=rbe_linux_rocm_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-gcc9_manylinux2014-rocm_config_python3.9" build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda -build:rbe_linux_cuda_nvcc --config=nvcc_clang +build:rbe_linux_cuda_nvcc --config=cuda_nvcc build:rbe_linux_cuda_nvcc --repo_env TF_NCCL_USE_STUB=1 build:rbe_win_base --config=rbe_base @@ -787,27 +797,47 @@ build:linux_libtensorflow_build --config=cuda_wheel -- //tensorflow/tools/lib_pa test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium +<<<<<<< HEAD test:linux_cpu_wheel_test --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +======= +test:linux_cpu_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/tools/pip_package:import_api_packages_test +>>>>>>> master # CUDA WHEEL test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium +<<<<<<< HEAD test:linux_cuda_wheel_test --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +======= +test:linux_cuda_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/tools/pip_package:import_api_packages_test +>>>>>>> master # ARM64 WHEEL test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium +<<<<<<< HEAD test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test +======= +test:linux_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/tools/pip_package:import_api_packages_test +>>>>>>> master # MACOS ARM64 WHEEL test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium +<<<<<<< HEAD test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +======= +test:macos_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/tools/pip_package:import_api_packages_test +>>>>>>> master # MACOS X86 WHEEL test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium +<<<<<<< HEAD test:macos_x86_wheel_test --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +======= +test:macos_x86_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/tools/pip_package:import_api_packages_test +>>>>>>> master # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. diff --git a/.github/workflows/arm-cd.yml b/.github/workflows/arm-cd.yml index 28ab5412163a4f..dabb576e1e7b5b 100644 --- a/.github/workflows/arm-cd.yml +++ b/.github/workflows/arm-cd.yml @@ -52,12 +52,12 @@ jobs: run: find /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/. -name . -o -prune -exec sudo rm -rf -- {} + || true - name: Checkout repository for nightly (skipped for releases) if: ${{ github.event_name == 'schedule' }} - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 with: ref: 'nightly' - name: Checkout repository for releases (skipped for nightly) if: ${{ github.event_name == 'push' }} - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 - name: Build and test pip wheel shell: bash run: | diff --git a/.github/workflows/arm-ci-extended-cpp.yml b/.github/workflows/arm-ci-extended-cpp.yml index 1062ce908db518..b081cc402bcb86 100644 --- a/.github/workflows/arm-ci-extended-cpp.yml +++ b/.github/workflows/arm-ci-extended-cpp.yml @@ -50,12 +50,12 @@ jobs: run: find /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/. -name . -o -prune -exec sudo rm -rf -- {} + || true - name: Checkout repository for nightly (skipped for releases) if: ${{ github.event_name == 'schedule' }} - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 with: ref: 'nightly' - name: Checkout repository if: ${{ github.event_name == 'push' }} - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 - name: Build binary and run C++ tests shell: bash run: | diff --git a/.github/workflows/arm-ci-extended.yml b/.github/workflows/arm-ci-extended.yml index 0fb591138a1ef5..4ce8cc369b5317 100644 --- a/.github/workflows/arm-ci-extended.yml +++ b/.github/workflows/arm-ci-extended.yml @@ -51,12 +51,12 @@ jobs: run: find /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/. -name . -o -prune -exec sudo rm -rf -- {} + || true - name: Checkout repository for nightly (skipped for releases) if: ${{ github.event_name == 'schedule' }} - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 with: ref: 'nightly' - name: Checkout repository if: ${{ github.event_name == 'push' }} - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 - name: Build binary and run python tests on nightly for all python versions shell: bash run: | diff --git a/.github/workflows/arm-ci.yml b/.github/workflows/arm-ci.yml index cc8462b44b2e02..55461fd75ec916 100644 --- a/.github/workflows/arm-ci.yml +++ b/.github/workflows/arm-ci.yml @@ -47,7 +47,7 @@ jobs: shell: bash run: find /home/ubuntu/actions-runner/_work/tensorflow/tensorflow/. -name . -o -prune -exec sudo rm -rf -- {} + || true - name: Checkout repository - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 - name: Build binary and run python tests shell: bash run: | diff --git a/.github/workflows/cffconvert.yml b/.github/workflows/cffconvert.yml index 3a33d71402cdec..cca28c434f1777 100644 --- a/.github/workflows/cffconvert.yml +++ b/.github/workflows/cffconvert.yml @@ -30,7 +30,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out a copy of the repository - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 - name: Check whether the citation metadata from CITATION.cff is valid uses: citation-file-format/cffconvert-github-action@4cf11baa70a673bfdf9dad0acc7ee33b3f4b6084 # v2.0.0 diff --git a/.github/workflows/issue-on-pr-rollback.yml b/.github/workflows/issue-on-pr-rollback.yml index 60395af98a8a90..dc322d97030e82 100644 --- a/.github/workflows/issue-on-pr-rollback.yml +++ b/.github/workflows/issue-on-pr-rollback.yml @@ -33,7 +33,7 @@ jobs: startsWith(github.event.head_commit.message, 'Rollback of PR #') steps: - name: Checkout repo - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 - name: Create a new Github Issue uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 with: diff --git a/.github/workflows/osv-scanner-scheduled.yml b/.github/workflows/osv-scanner-scheduled.yml index 5d8cddb032a3f2..29be0ab8bf43bc 100644 --- a/.github/workflows/osv-scanner-scheduled.yml +++ b/.github/workflows/osv-scanner-scheduled.yml @@ -28,7 +28,7 @@ permissions: jobs: scan-scheduled: if: github.repository == 'tensorflow/tensorflow' - uses: "google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml@v1.8.4" + uses: "google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml@v1.8.5" with: scan-args: |- --lockfile=requirements.txt:./requirements_lock_3_9.txt diff --git a/.github/workflows/pylint-presubmit.yml b/.github/workflows/pylint-presubmit.yml index 56ac9aa85a5805..e7597f945e6b20 100644 --- a/.github/workflows/pylint-presubmit.yml +++ b/.github/workflows/pylint-presubmit.yml @@ -28,7 +28,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 - name: Get file changes id: get_file_changes uses: trilom/file-changes-action@a6ca26c14274c33b15e6499323aac178af06ad4b # v1.2.4 diff --git a/.github/workflows/release-branch-cherrypick.yml b/.github/workflows/release-branch-cherrypick.yml index bfc10da2956a17..74f4d609c28340 100644 --- a/.github/workflows/release-branch-cherrypick.yml +++ b/.github/workflows/release-branch-cherrypick.yml @@ -45,7 +45,7 @@ jobs: if: github.repository == 'tensorflow/tensorflow' # Don't do this in forks steps: - name: Checkout code - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 with: ref: ${{ github.event.inputs.release_branch }} - name: Get some helpful info for formatting @@ -58,7 +58,7 @@ jobs: echo "SHORTSHA=$(git log -1 ${{ github.event.inputs.git_commit }} --format="%h")" >> "$GITHUB_OUTPUT" echo "TITLE=$(git log -1 ${{ github.event.inputs.git_commit }} --format="%s")" >> "$GITHUB_OUTPUT" - name: Create Pull Request with changes - uses: peter-evans/create-pull-request@c5a7806660adbe173f04e3e038b0ccdcd758773c # v6.1.0 + uses: peter-evans/create-pull-request@5e914681df9dc83aa4e4905692ca88beb2f9e91f # v7.0.5 with: title: '${{ github.event.inputs.release_branch }} cherry-pick: ${{ steps.cherrypick.outputs.SHORTSHA }} "${{ steps.cherrypick.outputs.TITLE }}"' committer: TensorFlow Release Automation diff --git a/.github/workflows/scorecards-analysis.yml b/.github/workflows/scorecards-analysis.yml index cff50a90238e09..023a14357ded87 100644 --- a/.github/workflows/scorecards-analysis.yml +++ b/.github/workflows/scorecards-analysis.yml @@ -41,7 +41,7 @@ jobs: steps: - name: "Checkout code" - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 with: persist-credentials: false @@ -64,6 +64,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard (optional). # Commenting out will disable upload of results to your repo's Code Scanning dashboard - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@4dd16135b69a43b6c8efb853346f8437d92d3c93 # v3.26.6 + uses: github/codeql-action/upload-sarif@e2b3eafc8d227b0241d48be5f425d47c2d750a13 # v3.26.10 with: sarif_file: results.sarif diff --git a/.github/workflows/sigbuild-docker-branch.yml b/.github/workflows/sigbuild-docker-branch.yml index cd1faa29c4ec86..756b89f3b2045a 100644 --- a/.github/workflows/sigbuild-docker-branch.yml +++ b/.github/workflows/sigbuild-docker-branch.yml @@ -40,7 +40,7 @@ jobs: run: rm -rf /opt/hostedtoolcache - name: Checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 - name: Set up Docker Buildx uses: docker/setup-buildx-action@988b5a0280414f521da01fcc63a27aeeb4b104db # v3.6.1 @@ -67,7 +67,7 @@ jobs: - name: Build and push id: docker_build - uses: docker/build-push-action@5cd11c3a4ced054e52742c5fd54dca954e0edd85 # v6.7.0 + uses: docker/build-push-action@4f58ea79222b3b9dc2c8bbdd6debcef730109a75 # v6.9.0 with: push: true context: ./tensorflow/tools/tf_sig_build_dockerfiles diff --git a/.github/workflows/sigbuild-docker-presubmit.yml b/.github/workflows/sigbuild-docker-presubmit.yml index 6a5acb59c7082b..6f3adfbc81d47d 100644 --- a/.github/workflows/sigbuild-docker-presubmit.yml +++ b/.github/workflows/sigbuild-docker-presubmit.yml @@ -44,7 +44,7 @@ jobs: df -h - name: Checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 - name: Set up Docker Buildx uses: docker/setup-buildx-action@988b5a0280414f521da01fcc63a27aeeb4b104db # v3.6.1 @@ -73,7 +73,7 @@ jobs: - name: Build containers, and push to GCR only if the 'build and push to gcr.io for staging' label is applied id: docker_build - uses: docker/build-push-action@5cd11c3a4ced054e52742c5fd54dca954e0edd85 # v6.7.0 + uses: docker/build-push-action@4f58ea79222b3b9dc2c8bbdd6debcef730109a75 # v6.9.0 with: push: ${{ contains(github.event.pull_request.labels.*.name, 'build and push to gcr.io for staging') }} context: ./tensorflow/tools/tf_sig_build_dockerfiles diff --git a/.github/workflows/sigbuild-docker.yml b/.github/workflows/sigbuild-docker.yml index 673041fc59d224..8f28b298ea74a4 100644 --- a/.github/workflows/sigbuild-docker.yml +++ b/.github/workflows/sigbuild-docker.yml @@ -43,7 +43,7 @@ jobs: run: rm -rf /opt/hostedtoolcache - name: Checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 - name: Set up Docker Buildx uses: docker/setup-buildx-action@988b5a0280414f521da01fcc63a27aeeb4b104db # v3.6.1 @@ -82,7 +82,7 @@ jobs: - name: Build and push id: docker_build - uses: docker/build-push-action@5cd11c3a4ced054e52742c5fd54dca954e0edd85 # v6.7.0 + uses: docker/build-push-action@4f58ea79222b3b9dc2c8bbdd6debcef730109a75 # v6.9.0 with: push: true context: ./tensorflow/tools/tf_sig_build_dockerfiles diff --git a/.github/workflows/update-rbe.yml b/.github/workflows/update-rbe.yml index 42562e6d608d9c..2da000a6e13c1a 100644 --- a/.github/workflows/update-rbe.yml +++ b/.github/workflows/update-rbe.yml @@ -30,7 +30,7 @@ jobs: if: github.repository == 'tensorflow/tensorflow' # Don't do this in forks steps: - name: Checkout code - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 - name: Update the RBE Configs run: | function map() { @@ -130,7 +130,7 @@ jobs: map sigbuild-r2.17-clang-python3.11 2.17-python3.11 map sigbuild-r2.17-clang-python3.12 2.17-python3.12 - name: Create Pull Request with changes - uses: peter-evans/create-pull-request@c5a7806660adbe173f04e3e038b0ccdcd758773c # v6.1.0 + uses: peter-evans/create-pull-request@5e914681df9dc83aa4e4905692ca88beb2f9e91f # v7.0.5 with: title: Update the RBE images to the latest container versions committer: TensorFlow Release Automation diff --git a/.gitignore b/.gitignore index 614cde3446a16f..643ffca1c45c99 100644 --- a/.gitignore +++ b/.gitignore @@ -28,6 +28,7 @@ tensorflow/contrib/cmake/_build/ /api_init_files_list.txt /estimator_api_init_files_list.txt *.whl +dist # Android .gradle diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 17b77f808d9c80..58123a3cddd9b4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -234,7 +234,7 @@ There are two ways to run TensorFlow unit tests. for the required packages. Alternatively, use the said [tensorflow/build Docker images](https://hub.docker.com/r/tensorflow/build) (`tensorflow/tensorflow:devel` and `tensorflow/tensorflow:devel-gpu` are no - longer supported for) development. Use TF SIG Build Dockerfiles in + longer supported for development). Use TF SIG Build Dockerfiles in development to avoid installing the packages directly on your system (in which case remember to change the directory from `/root` to `/tensorflow` once you get into the running container so `bazel` can find the `tensorflow` @@ -254,15 +254,16 @@ There are two ways to run TensorFlow unit tests. ``` If the tests are to be run on the GPU: - * For TensorFlow versions starting from v.2.18.0: - Add the `cuda` option flag. + + * For TensorFlow versions starting from v.2.18.0: Add the `cuda` option + flag. ```bash export flags="--config=opt --config=cuda -k" ``` - * For TensorFlow versions prior v.2.18.0: - Add CUDA paths to LD_LIBRARY_PATH and add the `cuda` option flag. + * For TensorFlow versions prior v.2.18.0: Add CUDA paths to + LD_LIBRARY_PATH and add the `cuda` option flag. ```bash export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" diff --git a/RELEASE.md b/RELEASE.md index bac912119a3744..6713ac4368c2f8 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,4 +1,4 @@ -# Release 2.18.0 +# Release 2.19.0 ## TensorFlow @@ -9,26 +9,6 @@ * * -* `tf.lite` - * C API: - * An optional, fourth parameter was added `TfLiteOperatorCreate` as a step - forward towards a cleaner API for `TfLiteOperator`. Function - `TfLiteOperatorCreate` was added recently, in TensorFlow Lite version 2.17.0, - released on 7/11/2024, and we do not expect there will be much code using this - function yet. Any code breakages can be easily resolved by passing nullptr as - the new, 4th parameter. - * SignatureRunner is now supported for models with no signatures. - -* TensorRT support is disabled in CUDA builds for code health improvement. - -* Hermetic CUDA support is added. - - Hermetic CUDA uses a specific downloadable version of CUDA instead of the - user’s locally installed CUDA. Bazel will download CUDA, CUDNN and NCCL - distributions, and then use CUDA libraries and tools as dependencies in - various Bazel targets. This enables more reproducible builds for Google ML - projects and supported CUDA versions. - ### Known Caveats * @@ -40,44 +20,12 @@ * * -* `tf.lite`: - * The LiteRT [repo](https://github.com/google-ai-edge/LiteRT) is - live (see [announcement](https://developers.googleblog.com/en/tensorflow-lite-is-now-litert/)), which means that in the coming months there will be changes to the development experience - for TFLite. The TF Lite Runtime source will be moved later this year, - and sometime after that we will start accepting contributions through that repo. - ### Bug Fixes and Other Changes * * * -* `tf.data` - * Add optional `synchronous` argument to `map`, to specify that the `map` - should run synchronously, as opposed to be parallelizable when - `options.experimental_optimization.map_parallelization=True`. This saves - memory compared to setting `num_parallel_calls=1`. - * Add optional `use_unbounded_threadpool` argument to `map`, to specify that - the `map` should use an unbounded threadpool instead of the default pool - that is based on the number of cores on the machine. This can improve - throughput for map functions which perform IO or otherwise release the - CPU. - * Add [`tf.data.experimental.get_model_proto`](https://www.tensorflow.org/api_docs/python/tf/data/experimental/get_model_proto) - to allow users to peek into the analytical model inside of a dataset - iterator. - -* `tf.lite` - * `Dequantize` op supports `TensorType_INT4`. - * This change includes per-channel dequantization. - * Add support for `stablehlo.composite`. - * `EmbeddingLookup` op supports per-channel - quantization and `TensorType_INT4` values. - * `FullyConnected` op supports `TensorType_INT16` activation and - `TensorType_Int4` weight per-channel quantization. - -* `tf.tensor_scatter_update`, `tf.tensor_scatter_add` and of other reduce types. - * Support `bad_indices_policy`. - ## Keras @@ -110,6 +58,71 @@ This release contains contributions from many people at Google, as well as: , , , , , +# Release 2.18.0 + +## TensorFlow + +### Breaking Changes + +* `tf.lite` + * C API: + * An optional, fourth parameter was added `TfLiteOperatorCreate` as a step forward towards a cleaner API for `TfLiteOperator`. Function `TfLiteOperatorCreate` was added recently, in TensorFlow Lite version 2.17.0, released on 7/11/2024, and we do not expect there will be much code using this function yet. Any code breakages can be easily resolved by passing nullptr as the new, 4th parameter. + +* TensorRT support is disabled in CUDA builds for code health improvement. + +* TensorFlow now supports and is compiled with NumPy 2.0 by default. Please see the [NumPy 2 release notes](https://numpy.org/doc/stable/release/2.0.0-notes.html) and the [NumPy 2 migration guide](https://numpy.org/devdocs/numpy_2_0_migration_guide.html#numpy-2-migration-guide). + * Note that NumPy's type promotion rules have been changed(See [NEP 50](https://numpy.org/neps/nep-0050-scalar-promotion.html#nep50)for details). This may change the precision at which computations happen, leading either to type errors or to numerical changes to results. + * Tensorflow will continue to support NumPy 1.26 until 2025, aligning with community standard deprecation timeline [here](https://scientific-python.org/specs/spec-0000/). + +* Hermetic CUDA support is added. + + Hermetic CUDA uses a specific downloadable version of CUDA instead of the user’s locally installed CUDA. Bazel will download CUDA, CUDNN and NCCL distributions, and then use CUDA libraries and tools as dependencies in various Bazel targets. This enables more reproducible builds for Google ML projects and supported CUDA versions. + +* Remove the `EnumNamesXNNPackFlags` function in `tensorflow/lite/acceleration/configuration/configuration_generated.h`. + + This change is a bug fix in the automatically generated code. This change is automatically generated by the new flatbuffer generator. The flatbuffers library is updated to 24.3.25 in https://github.com/tensorflow/tensorflow/commit/c17d64df85a83c1bd0fd7dcc0b1230812b0d3d48. The new flatbuffers library includes the following change https://github.com/google/flatbuffers/pull/7813 which fixed a underlying flatbuffer code generator bug. + + +### Known Caveats + +### Major Features and Improvements + +* `tf.lite`: + * The LiteRT [repo](https://github.com/google-ai-edge/LiteRT) is live (see [announcement](https://developers.googleblog.com/en/tensorflow-lite-is-now-litert/)), which means that in the coming months there will be changes to the development experience for TFLite. The TF Lite Runtime source will be moved later this year, and sometime after that we will start accepting contributions through that repo. + * SignatureRunner is now supported for models with no signatures. + +### Bug Fixes and Other Changes + +* `tf.data` + * Add optional `synchronous` argument to `map`, to specify that the `map` should run synchronously, as opposed to be parallelizable when `options.experimental_optimization.map_parallelization=True`. This saves memory compared to setting `num_parallel_calls=1`. + * Add optional `use_unbounded_threadpool` argument to `map`, to specify that the `map` should use an unbounded threadpool instead of the default pool that is based on the number of cores on the machine. This can improve throughput for map functions which perform IO or otherwise release the CPU. + * Add [`tf.data.experimental.get_model_proto`](https://www.tensorflow.org/api_docs/python/tf/data/experimental/get_model_proto) to allow users to peek into the analytical model inside of a dataset iterator. + +* `tf.lite` + * `Dequantize` op supports `TensorType_INT4`. + * This change includes per-channel dequantization. + * Add support for `stablehlo.composite`. + * `EmbeddingLookup` op supports per-channel quantization and `TensorType_INT4` values. + * `FullyConnected` op supports `TensorType_INT16` activation and `TensorType_Int4` weight per-channel quantization. + * Enable per-tensor quantization support in dynamic range quantization of `TRANSPOSE_CONV` layer. Fixes TFLite converter [bug](https://github.com/tensorflow/tensorflow/issues/76624). + +* `tf.tensor_scatter_update`, `tf.tensor_scatter_add` and of other reduce types. + * Support `bad_indices_policy`. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +Akhil Goel, akhilgoe, Alexander Pivovarov, Amir Samani, Andrew Goodbody, Andrey Portnoy, Anthony Platanios, bernardoArcari, Brett Taylor, buptzyb, Chao, Christian Clauss, Cocoa, Daniil Kutz, Darya Parygina, dependabot[bot], Dimitris Vardoulakis, Dragan Mladjenovic, Elfie Guo, eukub, Faijul Amin, flyingcat, Frédéric Bastien, ganyu.08, Georg Stefan Schmid, Grigory Reznikov, Harsha H S, Harshit Monish, Heiner, Ilia Sergachev, Jan, Jane Liu, Jaroslav Sevcik, Kaixi Hou, Kanvi Khanna, Kristof Maar, Kristóf Maár, LakshmiKalaKadali, Lbertho-Gpsw, lingzhi98, MarcoFalke, Masahiro Hiramori, Mmakevic-Amd, mraunak, Nobuo Tsukamoto, Notheisz57, Olli Lupton, Pearu Peterson, pemeliya, Peyara Nando, Philipp Hack, Phuong Nguyen, Pol Dellaiera, Rahul Batra, Ruturaj Vaidya, sachinmuradi, Sergey Kozub, Shanbin Ke, Sheng Yang, shengyu, Shraiysh, Shu Wang, Surya, sushreebarsa, Swatheesh-Mcw, syzygial, Tai Ly, terryysun, tilakrayal, Tj Xu, Trevor Morris, Tzung-Han Juang, wenchenvincent, wondertx, Xuefei Jiang, Ye Huang, Yimei Sun, Yunlong Liu, Zahid Iqbal, Zhan Lu, Zoranjovanovic-Ns, Zuri Obozuwa + +# Release 2.17.1 + +### Bug Fixes and Other Changes + +* Add necessary header files in the aar library. These are needed if developers build apps with header files unpacked from tflite aar files from maven. +* Implement Name() for GCSWritableFile to fix the profiler trace viewer cache file generation. +* Fix `cstring.h` missing file issue with the Libtensorflow archive. + # Release 2.17.0 ## TensorFlow diff --git a/SECURITY.md b/SECURITY.md index 87d45a8671754d..9753a70d195513 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -166,7 +166,7 @@ vulnerabilities. We recognize issues as vulnerabilities only when they occur in scenarios that we outline as safe; issues that have a security impact only when TensorFlow is used in a discouraged way (e.g. running untrusted models or checkpoints, data parsing -outside of the safe formats, etc.) are not treated as vulnerabilities.. +outside of the safe formats, etc.) are not treated as vulnerabilities. ### Reporting process diff --git a/WORKSPACE b/WORKSPACE index 32ffd0433108c7..0171c60db38dde 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -24,14 +24,20 @@ load("@//tensorflow:workspace3.bzl", "tf_workspace3") tf_workspace3() # Initialize hermetic Python -load("@local_xla//third_party/py:python_init_rules.bzl", "python_init_rules") +load("@local_tsl//third_party/py:python_init_rules.bzl", "python_init_rules") python_init_rules() -load("@local_xla//third_party/py:python_init_repositories.bzl", "python_init_repositories") +load("@local_tsl//third_party/py:python_init_repositories.bzl", "python_init_repositories") python_init_repositories( default_python_version = "system", + local_wheel_dist_folder = "dist", + local_wheel_inclusion_list = [ + "tensorflow*", + "tf_nightly*", + ], + local_wheel_workspaces = ["//:WORKSPACE"], requirements = { "3.9": "//:requirements_lock_3_9.txt", "3.10": "//:requirements_lock_3_10.txt", @@ -40,11 +46,11 @@ python_init_repositories( }, ) -load("@local_xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains") +load("@local_tsl//third_party/py:python_init_toolchains.bzl", "python_init_toolchains") python_init_toolchains() -load("@local_xla//third_party/py:python_init_pip.bzl", "python_init_pip") +load("@local_tsl//third_party/py:python_init_pip.bzl", "python_init_pip") python_init_pip() diff --git a/ci/official/containers/linux_arm64/Dockerfile b/ci/official/containers/linux_arm64/Dockerfile index d47764374d696d..43b596d5bd4b84 100644 --- a/ci/official/containers/linux_arm64/Dockerfile +++ b/ci/official/containers/linux_arm64/Dockerfile @@ -84,5 +84,6 @@ RUN /setup.python.sh python3.9 devel.requirements.txt RUN /setup.python.sh python3.10 devel.requirements.txt RUN /setup.python.sh python3.11 devel.requirements.txt RUN /setup.python.sh python3.12 devel.requirements.txt +RUN /setup.python.sh python3.13 devel.requirements.txt # "python3" commands by default run under 3.10 RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 diff --git a/ci/official/containers/linux_arm64/devel.packages.txt b/ci/official/containers/linux_arm64/devel.packages.txt index 61c7a97f1daf0f..fb69b4c89a0e15 100644 --- a/ci/official/containers/linux_arm64/devel.packages.txt +++ b/ci/official/containers/linux_arm64/devel.packages.txt @@ -29,6 +29,7 @@ moreutils openjdk-21-jdk openjdk-21-jre-headless openssl +parallel patchelf pkg-config python3-dev diff --git a/ci/official/containers/linux_arm64/setup.python.sh b/ci/official/containers/linux_arm64/setup.python.sh index d8ea04961e3d9a..3f3da75e9cb01c 100755 --- a/ci/official/containers/linux_arm64/setup.python.sh +++ b/ci/official/containers/linux_arm64/setup.python.sh @@ -36,7 +36,6 @@ else $VERSION $VERSION-dev $VERSION-venv -$VERSION-distutils EOF fi /setup.packages.sh pythons.txt diff --git a/ci/official/containers/ml_build/Dockerfile b/ci/official/containers/ml_build/Dockerfile new file mode 100644 index 00000000000000..9df1b75c7f1a85 --- /dev/null +++ b/ci/official/containers/ml_build/Dockerfile @@ -0,0 +1,61 @@ +################################################################################ +FROM ubuntu:22.04@sha256:58b87898e82351c6cf9cf5b9f3c20257bb9e2dcf33af051e12ce532d7f94e3fe AS devel +################################################################################ + +# Install devtoolset build dependencies +COPY setup.sources.sh /setup.sources.sh +COPY setup.packages.sh /setup.packages.sh +COPY builder.packages.txt /builder.packages.txt + +RUN /setup.sources.sh && /setup.packages.sh /builder.packages.txt + +# Install devtoolset-9 in /dt9 with glibc 2.17 and libstdc++ 4.8, for building +# manylinux2014-compatible packages. +COPY builder.devtoolset/fixlinks.sh /fixlinks.sh +COPY builder.devtoolset/rpm-patch.sh /rpm-patch.sh +COPY builder.devtoolset/build_devtoolset.sh /build_devtoolset.sh +COPY builder.devtoolset/glibc2.17-inline.patch /glibc2.17-inline.patch +RUN /build_devtoolset.sh devtoolset-9 /dt9 + +# Setup Python +COPY setup.python.sh /setup.python.sh +COPY builder.requirements.txt /builder.requirements.txt +RUN /setup.python.sh python3.9 builder.requirements.txt +RUN /setup.python.sh python3.10 builder.requirements.txt +RUN /setup.python.sh python3.11 builder.requirements.txt +RUN /setup.python.sh python3.12 builder.requirements.txt +RUN /setup.python.sh python3.13 builder.requirements.txt + +# Setup links for TensorFlow to compile. +# Referenced in devel.usertools/*.bazelrc. +# Set python3.12 as the default python version. +# TF does not support python3.13. +RUN ln -sf /usr/bin/python3.12 /usr/bin/python3 +RUN ln -sf /usr/bin/python3.12 /usr/bin/python +RUN ln -sf /usr/lib/python3.12 /usr/lib/tf_python + +# Make sure clang is on the path +RUN ln -s /usr/lib/llvm-18/bin/clang /usr/bin/clang + +# Install various tools. +# - bats: bash unit testing framework +# - bazelisk: always use the correct bazel version +# - buildifier: clean bazel build deps +# - buildozer: clean bazel build deps +# - gcloud SDK: communicate with Google Cloud Platform (GCP) for RBE, CI +# - patchelf: Utility tool to modify existing ELF executables and libraries +RUN git clone --branch v1.11.0 https://github.com/bats-core/bats-core.git && bats-core/install.sh /usr/local && rm -rf bats-core +RUN wget https://github.com/bazelbuild/bazelisk/releases/download/v1.21.0/bazelisk-linux-amd64 -O /usr/local/bin/bazel && chmod +x /usr/local/bin/bazel +RUN wget https://github.com/bazelbuild/buildtools/releases/download/v7.3.1/buildifier-linux-amd64 -O /usr/local/bin/buildifier && chmod +x /usr/local/bin/buildifier +RUN wget https://github.com/bazelbuild/buildtools/releases/download/v7.3.1/buildozer-linux-amd64 -O /usr/local/bin/buildozer && chmod +x /usr/local/bin/buildozer +RUN curl -sSL https://sdk.cloud.google.com > /tmp/gcloud && bash /tmp/gcloud --install-dir=~/usr/local/bin --disable-prompts +# Download and install patchelf v0.18.0 from GitHub. The default Ubuntu focal +# packages only provide the "0.10-2build1" version. We use patchelf to manipulate +# certain shared libraries during the wheel building process (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/pip_package/build_pip_package.sh#L255-L262). +# When we use Patchelf versions <0.12, those shared libraries end up with a +# corrupted PT_NOTE program header. This was fixed in v0.12, see https://github.com/NixOS/patchelf/commit/43a33482b501b0f5ee9da312aabfca3806570cc9. +RUN wget https://github.com/NixOS/patchelf/releases/download/0.18.0/patchelf-0.18.0-x86_64.tar.gz && tar -zxvf patchelf-0.18.0-x86_64.tar.gz -C /usr && rm -rf patchelf-0.18.0-x86_64.tar.gz + +# Don't use the bazel cache when a new docker image is created. +RUN echo build --action_env=DOCKER_CACHEBUSTER=$(date +%s%N)$RANDOM >> /etc/bazel.bazelrc +RUN echo build --host_action_env=DOCKER_HOST_CACHEBUSTER=$(date +%s%N)$RANDOM >> /etc/bazel.bazelrc diff --git a/ci/official/containers/ml_build/README.md b/ci/official/containers/ml_build/README.md new file mode 100644 index 00000000000000..53c01f529b300b --- /dev/null +++ b/ci/official/containers/ml_build/README.md @@ -0,0 +1,8 @@ +WIP ML Build Docker container for ML repositories (Tensorflow, JAX and XLA). + +This container branches off from +/tensorflow/tools/tf_sig_build_dockerfiles/. However, since +hermetic CUDA and hermetic Python is now available for Tensorflow, a lot of the +requirements installed on the original container can be removed to reduce the +footprint of the container and make it more reusable across different ML +repositories. diff --git a/ci/official/containers/ml_build/builder.devtoolset/build_devtoolset.sh b/ci/official/containers/ml_build/builder.devtoolset/build_devtoolset.sh new file mode 100755 index 00000000000000..b4c63677d7ae76 --- /dev/null +++ b/ci/official/containers/ml_build/builder.devtoolset/build_devtoolset.sh @@ -0,0 +1,198 @@ +#!/bin/bash -eu +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Builds a devtoolset cross-compiler targeting manylinux 2010 (glibc 2.12 / +# libstdc++ 4.4) or manylinux2014 (glibc 2.17 / libstdc++ 4.8). + +VERSION="$1" +TARGET="$2" + +case "${VERSION}" in +devtoolset-7) + LIBSTDCXX_VERSION="6.0.24" + LIBSTDCXX_ABI="gcc4-compatible" + ;; +devtoolset-9) + LIBSTDCXX_VERSION="6.0.28" + LIBSTDCXX_ABI="new" + ;; +*) + echo "Usage: $0 {devtoolset-7|devtoolset-9} " + echo "Use 'devtoolset-7' to build a manylinux2010 compatible toolchain or 'devtoolset-9' to build a manylinux2014 compatible toolchain" + exit 1 + ;; +esac + +mkdir -p "${TARGET}" + +# Download glibc's shared and development libraries based on the value of the +# `VERSION` parameter. +# Note: 'Templatizing' this and the other conditional branches would require +# defining several variables (version, os, path) making it difficult to maintain +# and extend for future modifications. +case "${VERSION}" in +devtoolset-7) + # Download binary glibc 2.12 shared library release. + wget "http://old-releases.ubuntu.com/ubuntu/pool/main/e/eglibc/libc6_2.12.1-0ubuntu6_amd64.deb" && \ + unar "libc6_2.12.1-0ubuntu6_amd64.deb" && \ + tar -C "${TARGET}" -xvzf "libc6_2.12.1-0ubuntu6_amd64/data.tar.gz" && \ + rm -rf "libc6_2.12.1-0ubuntu6_amd64.deb" "libc6_2.12.1-0ubuntu6_amd64" + # Download binary glibc 2.12 development library release. + wget "http://old-releases.ubuntu.com/ubuntu/pool/main/e/eglibc/libc6-dev_2.12.1-0ubuntu6_amd64.deb" && \ + unar "libc6-dev_2.12.1-0ubuntu6_amd64.deb" && \ + tar -C "${TARGET}" -xvzf "libc6-dev_2.12.1-0ubuntu6_amd64/data.tar.gz" && \ + rm -rf "libc6-dev_2.12.1-0ubuntu6_amd64.deb" "libc6-dev_2.12.1-0ubuntu6_amd64" + ;; +devtoolset-9) + # Download binary glibc 2.17 shared library release. + wget "http://old-releases.ubuntu.com/ubuntu/pool/main/e/eglibc/libc6_2.17-0ubuntu5.1_amd64.deb" && \ + unar "libc6_2.17-0ubuntu5.1_amd64.deb" && \ + tar -C "${TARGET}" -xvzf "libc6_2.17-0ubuntu5.1_amd64/data.tar.gz" && \ + rm -rf "libc6_2.17-0ubuntu5.1_amd64.deb" "libc6_2.17-0ubuntu5.1_amd64" + # Download binary glibc 2.17 development library release. + wget "http://old-releases.ubuntu.com/ubuntu/pool/main/e/eglibc/libc6-dev_2.17-0ubuntu5.1_amd64.deb" && \ + unar "libc6-dev_2.17-0ubuntu5.1_amd64.deb" && \ + tar -C "${TARGET}" -xvzf "libc6-dev_2.17-0ubuntu5.1_amd64/data.tar.gz" && \ + rm -rf "libc6-dev_2.17-0ubuntu5.1_amd64.deb" "libc6-dev_2.17-0ubuntu5.1_amd64" + ;; +esac + +# Put the current kernel headers from ubuntu in place. +ln -s "/usr/include/linux" "/${TARGET}/usr/include/linux" +ln -s "/usr/include/asm-generic" "/${TARGET}/usr/include/asm-generic" +ln -s "/usr/include/x86_64-linux-gnu/asm" "/${TARGET}/usr/include/asm" + +# Symlinks in the binary distribution are set up for installation in /usr, we +# need to fix up all the links to stay within /${TARGET}. +/fixlinks.sh "/${TARGET}" + +# Patch to allow non-glibc 2.12 compatible builds to work. +sed -i '54i#define TCP_USER_TIMEOUT 18' "/${TARGET}/usr/include/netinet/tcp.h" + +# Download specific version of libstdc++ shared library based on the value of +# the `VERSION` parameter +case "${VERSION}" in +devtoolset-7) + # Download binary libstdc++ 4.4 release we are going to link against. + # We only need the shared library, as we're going to develop against the + # libstdc++ provided by devtoolset. + wget "http://old-releases.ubuntu.com/ubuntu/pool/main/g/gcc-4.4/libstdc++6_4.4.3-4ubuntu5_amd64.deb" && \ + unar "libstdc++6_4.4.3-4ubuntu5_amd64.deb" && \ + tar -C "/${TARGET}" -xvzf "libstdc++6_4.4.3-4ubuntu5_amd64/data.tar.gz" "./usr/lib/libstdc++.so.6.0.13" && \ + rm -rf "libstdc++6_4.4.3-4ubuntu5_amd64.deb" "libstdc++6_4.4.3-4ubuntu5_amd64" + ;; +devtoolset-9) + # Download binary libstdc++ 4.8 shared library release + wget "http://old-releases.ubuntu.com/ubuntu/pool/main/g/gcc-4.8/libstdc++6_4.8.1-10ubuntu8_amd64.deb" && \ + unar "libstdc++6_4.8.1-10ubuntu8_amd64.deb" && \ + tar -C "/${TARGET}" -xvzf "libstdc++6_4.8.1-10ubuntu8_amd64/data.tar.gz" "./usr/lib/x86_64-linux-gnu/libstdc++.so.6.0.18" && \ + rm -rf "libstdc++6_4.8.1-10ubuntu8_amd64.deb" "libstdc++6_4.8.1-10ubuntu8_amd64" + ;; +esac + +mkdir -p "${TARGET}-src" +cd "${TARGET}-src" + +# Build a devtoolset cross-compiler based on our glibc 2.12/glibc 2.17 sysroot setup. +case "${VERSION}" in +devtoolset-7) + wget "http://vault.centos.org/centos/6/sclo/Source/rh/devtoolset-7/devtoolset-7-gcc-7.3.1-5.15.el6.src.rpm" + rpm2cpio "devtoolset-7-gcc-7.3.1-5.15.el6.src.rpm" |cpio -idmv + tar -xvjf "gcc-7.3.1-20180303.tar.bz2" --strip 1 + ;; +devtoolset-9) + wget "https://vault.centos.org/centos/7/sclo/Source/rh/devtoolset-9-gcc-9.3.1-2.2.el7.src.rpm" + rpm2cpio "devtoolset-9-gcc-9.3.1-2.2.el7.src.rpm" |cpio -idmv + tar -xvf "gcc-9.3.1-20200408.tar.xz" --strip 1 + ;; +esac + +# Apply the devtoolset patches to gcc. +/rpm-patch.sh "gcc.spec" + +./contrib/download_prerequisites + +mkdir -p "${TARGET}-build" +cd "${TARGET}-build" + +"${TARGET}-src/configure" \ + --prefix=/"${TARGET}/usr" \ + --with-sysroot="/${TARGET}" \ + --disable-bootstrap \ + --disable-libmpx \ + --disable-libsanitizer \ + --disable-libunwind-exceptions \ + --disable-libunwind-exceptions \ + --disable-lto \ + --disable-multilib \ + --enable-__cxa_atexit \ + --enable-gnu-indirect-function \ + --enable-gnu-unique-object \ + --enable-initfini-array \ + --enable-languages="c,c++" \ + --enable-linker-build-id \ + --enable-plugin \ + --enable-shared \ + --enable-threads=posix \ + --with-default-libstdcxx-abi=${LIBSTDCXX_ABI} \ + --with-gcc-major-version-only \ + --with-linker-hash-style="gnu" \ + --with-tune="generic" \ + && \ + make -j 42 && \ + make install + + +# Create the devtoolset libstdc++ linkerscript that links dynamically against +# the system libstdc++ 4.4 and provides all other symbols statically. +case "${VERSION}" in +devtoolset-7) +mv "/${TARGET}/usr/lib/libstdc++.so.${LIBSTDCXX_VERSION}" \ + "/${TARGET}/usr/lib/libstdc++.so.${LIBSTDCXX_VERSION}.backup" +echo -e "OUTPUT_FORMAT(elf64-x86-64)\nINPUT ( libstdc++.so.6.0.13 -lstdc++_nonshared44 )" \ + > "/${TARGET}/usr/lib/libstdc++.so.${LIBSTDCXX_VERSION}" +cp "./x86_64-pc-linux-gnu/libstdc++-v3/src/.libs/libstdc++_nonshared44.a" \ + "/${TARGET}/usr/lib" + ;; +devtoolset-9) +# Note that the installation path for libstdc++ here is /${TARGET}/usr/lib64/ +mv "/${TARGET}/usr/lib64/libstdc++.so.${LIBSTDCXX_VERSION}" \ + "/${TARGET}/usr/lib64/libstdc++.so.${LIBSTDCXX_VERSION}.backup" +echo -e "OUTPUT_FORMAT(elf64-x86-64)\nINPUT ( libstdc++.so.6.0.18 -lstdc++_nonshared44 )" \ + > "/${TARGET}/usr/lib64/libstdc++.so.${LIBSTDCXX_VERSION}" +cp "./x86_64-pc-linux-gnu/libstdc++-v3/src/.libs/libstdc++_nonshared44.a" \ + "/${TARGET}/usr/lib64" +;; +esac + +# Link in architecture specific includes from the system; note that we cannot +# link in the whole x86_64-linux-gnu folder, as otherwise we're overlaying +# system gcc paths that we do not want to find. +# TODO(klimek): Automate linking in all non-gcc / non-kernel include +# directories. +mkdir -p "/${TARGET}/usr/include/x86_64-linux-gnu" +PYTHON_VERSIONS=("python3.9" "python3.10" "python3.11" "python3.12") +for v in "${PYTHON_VERSIONS[@]}"; do + ln -s "/usr/local/include/${v}" "/${TARGET}/usr/include/x86_64-linux-gnu/${v}" +done + +# Patch glibc to be compatable with modern clang +case "${VERSION}" in +devtoolset-9) + cd / + patch -p0 < /glibc2.17-inline.patch +;; +esac diff --git a/ci/official/containers/ml_build/builder.devtoolset/fixlinks.sh b/ci/official/containers/ml_build/builder.devtoolset/fixlinks.sh new file mode 100755 index 00000000000000..86856d80d9ceb1 --- /dev/null +++ b/ci/official/containers/ml_build/builder.devtoolset/fixlinks.sh @@ -0,0 +1,28 @@ +#!/bin/bash +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Re-direct all links in $1 that point to /lib... to point to $1/lib... instead. + +BASE="$1" +find "${BASE}" -type l | \ + while read l ; do + if [[ "$(readlink "$l")" == /lib* ]]; then + ORIG="$(readlink "$l")"; + rm "$l"; + ln -s "${BASE}${ORIG}" "$l" + fi + done + diff --git a/ci/official/containers/ml_build/builder.devtoolset/glibc2.17-inline.patch b/ci/official/containers/ml_build/builder.devtoolset/glibc2.17-inline.patch new file mode 100644 index 00000000000000..db8c3423a38298 --- /dev/null +++ b/ci/official/containers/ml_build/builder.devtoolset/glibc2.17-inline.patch @@ -0,0 +1,11 @@ +--- /dt9/usr/include/x86_64-linux-gnu/sys/cdefs.h 2013-09-30 13:58:17.000000000 +0000 ++++ /dt9/usr/include/x86_64-linux-gnu/sys/cdefs.new.h 2022-11-04 17:17:31.727061220 +0000 +@@ -320,7 +320,7 @@ + + /* GCC 4.3 and above with -std=c99 or -std=gnu99 implements ISO C99 + inline semantics, unless -fgnu89-inline is used. */ +-#if (!defined __cplusplus || __GNUC_PREREQ (4,3)) && defined __GNUC__ ++#if (!defined __cplusplus || __GNUC_PREREQ (4,3) || defined __clang__) && defined __GNUC__ + # if defined __GNUC_STDC_INLINE__ || defined __cplusplus + # define __extern_inline extern __inline __attribute__ ((__gnu_inline__)) + # define __extern_always_inline \ \ No newline at end of file diff --git a/ci/official/containers/ml_build/builder.devtoolset/rpm-patch.sh b/ci/official/containers/ml_build/builder.devtoolset/rpm-patch.sh new file mode 100755 index 00000000000000..892ae2af86a3fa --- /dev/null +++ b/ci/official/containers/ml_build/builder.devtoolset/rpm-patch.sh @@ -0,0 +1,28 @@ +#!/bin/bash -eu +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Given an RPM spec file $1, apply its patches. + +SPEC="$1" +grep '%patch' "${SPEC}" |while read cmd ; do + N=$(echo "${cmd}" |sed 's,%patch\([0-9]\+\).*,\1,') + file=$(grep "Patch$N:" "${SPEC}" |sed 's,.*: ,,') + parg=$(echo "${cmd}" |sed 's,.*\(-p[0-9]\).*,\1,') + if [[ ! "${file}" =~ doxygen && "${cmd}" != \#* ]]; then + echo "patch ${parg} -s < ${file}" + patch ${parg} -s < "${file}" + fi +done diff --git a/ci/official/containers/ml_build/builder.packages.txt b/ci/official/containers/ml_build/builder.packages.txt new file mode 100644 index 00000000000000..e1a8bf3cc0e85d --- /dev/null +++ b/ci/official/containers/ml_build/builder.packages.txt @@ -0,0 +1,34 @@ +# Packages to be installed for the new Docker image. + +# Packages needed to build devtoolset +file +flex +g++ +make +patch +rpm2cpio +unar +wget +xz-utils +cpio + +# Other build-related tools +apt-transport-https +autoconf +automake +build-essential +ca-certificates +llvm-18 +clang-18 +clang-tidy-18 +lld-18 +clang-format-12 +curl +git +parallel +sudo +swig +unzip +zip +openjdk-21-jdk +vim diff --git a/ci/official/containers/ml_build/builder.requirements.txt b/ci/official/containers/ml_build/builder.requirements.txt new file mode 100644 index 00000000000000..2aede7454f3080 --- /dev/null +++ b/ci/official/containers/ml_build/builder.requirements.txt @@ -0,0 +1,6 @@ +# For wheel verification, and uploading +auditwheel ~= 6.1.0 +twine ~= 5.1.1 + +# For JAX +build ~= 1.2.2 diff --git a/ci/official/containers/ml_build/setup.packages.sh b/ci/official/containers/ml_build/setup.packages.sh new file mode 100755 index 00000000000000..f808cf7d22a7ce --- /dev/null +++ b/ci/official/containers/ml_build/setup.packages.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash +# +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# setup.packages.sh: Given a list of Ubuntu packages, install them and clean up. +# Usage: setup.packages.sh +set -e + +# Prevent apt install tzinfo from asking our location (assumes UTC) +export DEBIAN_FRONTEND=noninteractive + +apt-get update +# Remove commented lines and blank lines +apt-get install -y --no-install-recommends $(sed -e '/^\s*#.*$/d' -e '/^\s*$/d' "$1" | sort -u) +rm -rf /var/lib/apt/lists/* diff --git a/ci/official/containers/ml_build/setup.python.sh b/ci/official/containers/ml_build/setup.python.sh new file mode 100755 index 00000000000000..831bd612c41bb0 --- /dev/null +++ b/ci/official/containers/ml_build/setup.python.sh @@ -0,0 +1,68 @@ +#!/usr/bin/env bash +# +# Copyright 2022 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# setup.python.sh: Install a specific Python version and packages for it. +# Usage: setup.python.sh +set -xe + +source ~/.bashrc +VERSION=$1 +REQUIREMENTS=$2 + +# Install Python packages for this container's version +if [[ ${VERSION} == "python3.13" ]]; then + cat >pythons.txt <pythons.txt </etc/apt/sources.list.d/custom.list <= 2.31.0 packaging==23.2 -setuptools==68.2.2 +setuptools==70.0.0 jax==0.4.7 +# The dependencies below are needed for TF wheel testing. +tensorflow-io-gcs-filesystem==0.37.1 +libclang >= 13.0.0 +google_pasta ~= 0.2 +flatbuffers ~= 24.3.25 diff --git a/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_10.txt b/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_10.txt index 9567cdea7ee792..b729fb3d08e8d7 100644 --- a/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_10.txt +++ b/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_10.txt @@ -1,3 +1,9 @@ +# +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: +# +# bazel run //ci/official/requirements_updater:requirements.update +# absl-py==2.1.0 \ --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff @@ -7,11 +13,11 @@ absl-py==2.1.0 \ astor==0.7.1 \ --hash=sha256:95c30d87a6c2cf89aa628b87398466840f0ad8652f88eb173125a6df8533fb8d \ --hash=sha256:fb503b9e2fdd05609fbf557b916b4a7824171203701660f0c55bbf5a7a68713e - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in astunparse==1.6.3 \ --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in certifi==2024.7.4 \ --hash=sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b \ --hash=sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90 @@ -111,7 +117,7 @@ charset-normalizer==3.3.2 \ dill==0.3.7 \ --hash=sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e \ --hash=sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in dm-tree==0.1.8 \ --hash=sha256:054b461f8176f4bce7a21f7b1870f873a1ced3bdbe1282c816c550bb43c71fa6 \ --hash=sha256:09964470f76a5201aff2e8f9b26842976de7889300676f927930f6285e256760 \ @@ -160,10 +166,19 @@ dm-tree==0.1.8 \ --hash=sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb \ --hash=sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d # via keras-nightly +flatbuffers==24.3.25 \ + --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ + --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 + # via -r ci/official/requirements_updater/requirements.in gast==0.4.0 \ --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in +google-pasta==0.2.0 \ + --hash=sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954 \ + --hash=sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed \ + --hash=sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e + # via -r ci/official/requirements_updater/requirements.in grpcio==1.64.1 \ --hash=sha256:03b43d0ccf99c557ec671c7dede64f023c7da9bb632ac65dbc57f166e4970040 \ --hash=sha256:0a12ddb1678ebc6a84ec6b0487feac020ee2b1659cbe69b80f06dbffdb249122 \ @@ -212,7 +227,7 @@ grpcio==1.64.1 \ --hash=sha256:ee73a2f5ca4ba44fa33b4d7d2c71e2c8a9e9f78d53f6507ad68e7d2ad5f64a22 \ --hash=sha256:f10193c69fc9d3d726e83bbf0f3d316f1847c3071c8c93d8090cf5f326b14309 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # tb-nightly h5py==3.11.0 \ --hash=sha256:083e0329ae534a264940d6513f47f5ada617da536d8dccbafc3026aefc33c90e \ @@ -237,7 +252,7 @@ h5py==3.11.0 \ --hash=sha256:f3736fe21da2b7d8a13fe8fe415f1272d2a1ccdeff4849c1421d2fb30fd533bc \ --hash=sha256:f4e025e852754ca833401777c25888acb96889ee2c27e7e629a19aee288833f0 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # keras-nightly idna==3.7 \ --hash=sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc \ @@ -245,14 +260,26 @@ idna==3.7 \ # via requests jax==0.4.7 \ --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in keras-nightly==3.0.4.dev2024021403 \ --hash=sha256:24ce69d29d582771685bf4235f59663723405b5a5b16f3eaff2657e52e74663a \ --hash=sha256:9f416e66b820ef833779d219d255b346b8b90a72fdbd0b2f1e90a43ad142a03d - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in +libclang==18.1.1 \ + --hash=sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a \ + --hash=sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8 \ + --hash=sha256:4dd2d3b82fab35e2bf9ca717d7b63ac990a3519c7e312f19fa8e86dcc712f7fb \ + --hash=sha256:54dda940a4a0491a9d1532bf071ea3ef26e6dbaf03b5000ed94dd7174e8f9592 \ + --hash=sha256:69f8eb8f65c279e765ffd28aaa7e9e364c776c17618af8bff22a8df58677ff4f \ + --hash=sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5 \ + --hash=sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8 \ + --hash=sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250 \ + --hash=sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b \ + --hash=sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe + # via -r ci/official/requirements_updater/requirements.in lit==17.0.6 \ --hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in markdown==3.6 \ --hash=sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f \ --hash=sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224 @@ -346,44 +373,52 @@ ml-dtypes==0.4.0 \ --hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \ --hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # jax # keras-nightly namex==0.0.8 \ --hash=sha256:32a50f6c565c0bb10aa76298c959507abdc0e850efe085dc38f3440fcb3aa90b \ --hash=sha256:7ddb6c2bb0e753a311b7590f84f6da659dd0c05e65cb89d519d54c0a250c0487 # via keras-nightly -numpy==1.23.5 ; python_version <= "3.11" \ - --hash=sha256:01dd17cbb340bf0fc23981e52e1d18a9d4050792e8fb8363cecbf066a84b827d \ - --hash=sha256:06005a2ef6014e9956c09ba07654f9837d9e26696a0470e42beedadb78c11b07 \ - --hash=sha256:09b7847f7e83ca37c6e627682f145856de331049013853f344f37b0c9690e3df \ - --hash=sha256:0aaee12d8883552fadfc41e96b4c82ee7d794949e2a7c3b3a7201e968c7ecab9 \ - --hash=sha256:0cbe9848fad08baf71de1a39e12d1b6310f1d5b2d0ea4de051058e6e1076852d \ - --hash=sha256:1b1766d6f397c18153d40015ddfc79ddb715cabadc04d2d228d4e5a8bc4ded1a \ - --hash=sha256:33161613d2269025873025b33e879825ec7b1d831317e68f4f2f0f84ed14c719 \ - --hash=sha256:5039f55555e1eab31124a5768898c9e22c25a65c1e0037f4d7c495a45778c9f2 \ - --hash=sha256:522e26bbf6377e4d76403826ed689c295b0b238f46c28a7251ab94716da0b280 \ - --hash=sha256:56e454c7833e94ec9769fa0f86e6ff8e42ee38ce0ce1fa4cbb747ea7e06d56aa \ - --hash=sha256:58f545efd1108e647604a1b5aa809591ccd2540f468a880bedb97247e72db387 \ - --hash=sha256:5e05b1c973a9f858c74367553e236f287e749465f773328c8ef31abe18f691e1 \ - --hash=sha256:7903ba8ab592b82014713c491f6c5d3a1cde5b4a3bf116404e08f5b52f6daf43 \ - --hash=sha256:8969bfd28e85c81f3f94eb4a66bc2cf1dbdc5c18efc320af34bffc54d6b1e38f \ - --hash=sha256:92c8c1e89a1f5028a4c6d9e3ccbe311b6ba53694811269b992c0b224269e2398 \ - --hash=sha256:9c88793f78fca17da0145455f0d7826bcb9f37da4764af27ac945488116efe63 \ - --hash=sha256:a7ac231a08bb37f852849bbb387a20a57574a97cfc7b6cabb488a4fc8be176de \ - --hash=sha256:abdde9f795cf292fb9651ed48185503a2ff29be87770c3b8e2a14b0cd7aa16f8 \ - --hash=sha256:af1da88f6bc3d2338ebbf0e22fe487821ea4d8e89053e25fa59d1d79786e7481 \ - --hash=sha256:b2a9ab7c279c91974f756c84c365a669a887efa287365a8e2c418f8b3ba73fb0 \ - --hash=sha256:bf837dc63ba5c06dc8797c398db1e223a466c7ece27a1f7b5232ba3466aafe3d \ - --hash=sha256:ca51fcfcc5f9354c45f400059e88bc09215fb71a48d3768fb80e357f3b457e1e \ - --hash=sha256:ce571367b6dfe60af04e04a1834ca2dc5f46004ac1cc756fb95319f64c095a96 \ - --hash=sha256:d208a0f8729f3fb790ed18a003f3a57895b989b40ea4dce4717e9cf4af62c6bb \ - --hash=sha256:dbee87b469018961d1ad79b1a5d50c0ae850000b639bcb1b694e9981083243b6 \ - --hash=sha256:e9f4c4e51567b616be64e05d517c79a8a22f3606499941d97bb76f2ca59f982d \ - --hash=sha256:f063b69b090c9d918f9df0a12116029e274daf0181df392839661c4c7ec9018a \ - --hash=sha256:f9a909a8bae284d46bbfdefbdd4a262ba19d3bc9921b1e76126b1d21c3c34135 +numpy==1.26.4 \ + --hash=sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b \ + --hash=sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818 \ + --hash=sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20 \ + --hash=sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0 \ + --hash=sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010 \ + --hash=sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a \ + --hash=sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea \ + --hash=sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c \ + --hash=sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71 \ + --hash=sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110 \ + --hash=sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be \ + --hash=sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a \ + --hash=sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a \ + --hash=sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5 \ + --hash=sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed \ + --hash=sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd \ + --hash=sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c \ + --hash=sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e \ + --hash=sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0 \ + --hash=sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c \ + --hash=sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a \ + --hash=sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b \ + --hash=sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0 \ + --hash=sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6 \ + --hash=sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2 \ + --hash=sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a \ + --hash=sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30 \ + --hash=sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218 \ + --hash=sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5 \ + --hash=sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07 \ + --hash=sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2 \ + --hash=sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4 \ + --hash=sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764 \ + --hash=sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef \ + --hash=sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3 \ + --hash=sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # h5py # jax # keras-nightly @@ -395,16 +430,18 @@ opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # jax packaging==23.2 \ --hash=sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5 \ --hash=sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7 - # via -r requirements.in + # via + # -r ci/official/requirements_updater/requirements.in + # tb-nightly portpicker==1.6.0 \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in protobuf==4.25.3 \ --hash=sha256:19b270aeaa0099f16d3ca02628546b8baefe2955bbe23224aaf856134eccf1e4 \ --hash=sha256:209ba4cc916bab46f64e56b85b090607a676f66b473e6b762e6f1d9d591eb2e8 \ @@ -443,7 +480,7 @@ pygments==2.18.0 \ requests==2.32.3 \ --hash=sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760 \ --hash=sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in rich==13.7.1 \ --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ --hash=sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432 @@ -475,34 +512,53 @@ scipy==1.11.3 \ --hash=sha256:e04aa19acc324a1a076abb4035dabe9b64badb19f76ad9c798bde39d41025cdc \ --hash=sha256:e1f97cd89c0fe1a0685f8f89d85fa305deb3067d0668151571ba50913e445820 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # jax six==1.16.0 \ --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 # via # astunparse + # google-pasta # tb-nightly -tb-nightly==2.18.0a20240611 \ - --hash=sha256:c299eb7dc3de22c7164a1b0c0091b784f2214d65b9a8b967eeeba9818314016d - # via -r requirements.in +tb-nightly==2.18.0a20240925 \ + --hash=sha256:ad4b476bfb4e3861b769f93c9ee4efc52f025e3b5db91da5d30c9e806e99a0eb + # via -r ci/official/requirements_updater/requirements.in tblib==2.0.0 \ --hash=sha256:9100bfa016b047d5b980d66e7efed952fbd20bd85b56110aaf473cb97d18709a \ --hash=sha256:a6df30f272c08bf8be66e0775fad862005d950a6b8449b94f7c788731d70ecd7 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in tensorboard-data-server==0.7.2 \ --hash=sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb \ --hash=sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60 \ --hash=sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530 # via tb-nightly +tensorflow-io-gcs-filesystem==0.37.1 \ + --hash=sha256:0df00891669390078a003cedbdd3b8e645c718b111917535fa1d7725e95cdb95 \ + --hash=sha256:249c12b830165841411ba71e08215d0e94277a49c551e6dd5d72aab54fe5491b \ + --hash=sha256:257aab23470a0796978efc9c2bcf8b0bc80f22e6298612a4c0a50d3f4e88060c \ + --hash=sha256:286389a203a5aee1a4fa2e53718c661091aa5fea797ff4fa6715ab8436b02e6c \ + --hash=sha256:32c50ab4e29a23c1f91cd0f9ab8c381a0ab10f45ef5c5252e94965916041737c \ + --hash=sha256:426de1173cb81fbd62becec2012fc00322a295326d90eb6c737fab636f182aed \ + --hash=sha256:6e1f2796b57e799a8ca1b75bf47c2aaa437c968408cc1a402a9862929e104cda \ + --hash=sha256:8943036bbf84e7a2be3705cb56f9c9df7c48c9e614bb941f0936c58e3ca89d6f \ + --hash=sha256:8febbfcc67c61e542a5ac1a98c7c20a91a5e1afc2e14b1ef0cb7c28bc3b6aa70 \ + --hash=sha256:9679b36e3a80921876f31685ab6f7270f3411a4cc51bc2847e80d0e4b5291e27 \ + --hash=sha256:b02f9c5f94fd62773954a04f69b68c4d576d076fd0db4ca25d5479f0fbfcdbad \ + --hash=sha256:ee5da49019670ed364f3e5fb86b46420841a6c3cb52a300553c63841671b3e6d \ + --hash=sha256:ee7c8ee5fe2fd8cb6392669ef16e71841133041fee8a330eff519ad9b36e4556 \ + --hash=sha256:fbb33f1745f218464a59cecd9a18e32ca927b0f4d77abd8f8671b645cc1a182f \ + --hash=sha256:fe8dcc6d222258a080ac3dfcaaaa347325ce36a7a046277f6b3e19abc1efb3c5 \ + --hash=sha256:ffebb6666a7bfc28005f4fbbb111a455b5e7d6cd3b12752b7050863ecb27d5cc + # via -r ci/official/requirements_updater/requirements.in termcolor==2.3.0 \ --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in typing-extensions==4.8.0 \ --hash=sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0 \ --hash=sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in urllib3==2.2.2 \ --hash=sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472 \ --hash=sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168 @@ -515,7 +571,7 @@ wheel==0.41.3 \ --hash=sha256:488609bc63a29322326e05560731bf7bfea8e48ad646e1f5e40d366607de0942 \ --hash=sha256:4d4987ce51a49370ea65c0bfd2234e8ce80a12780820d9dc462597a6e60d0841 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # astunparse wrapt==1.16.0 \ --hash=sha256:0d2691979e93d06a95a26257adb7bfd0c93818e89b1406f5a28f36e0d8c1e1fc \ @@ -588,11 +644,12 @@ wrapt==1.16.0 \ --hash=sha256:f6b2d0c6703c988d334f297aa5df18c45e97b0af3679bb75059e0e0bd8b1069d \ --hash=sha256:f8212564d49c50eb4565e502814f694e240c55551a5f1bc841d4fcaabb0a9b8a \ --hash=sha256:ffa565331890b90056c01db69c0fe634a776f8019c143a5ae265f9c6bc4bd6d4 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in +# The following packages are considered to be unsafe in a requirements file: setuptools==70.0.0 \ --hash=sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4 \ --hash=sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # tb-nightly diff --git a/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_11.txt b/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_11.txt index 9567cdea7ee792..bc1ec067fecf12 100644 --- a/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_11.txt +++ b/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_11.txt @@ -1,3 +1,9 @@ +# +# This file is autogenerated by pip-compile with Python 3.11 +# by the following command: +# +# bazel run //ci/official/requirements_updater:requirements.update +# absl-py==2.1.0 \ --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff @@ -7,11 +13,11 @@ absl-py==2.1.0 \ astor==0.7.1 \ --hash=sha256:95c30d87a6c2cf89aa628b87398466840f0ad8652f88eb173125a6df8533fb8d \ --hash=sha256:fb503b9e2fdd05609fbf557b916b4a7824171203701660f0c55bbf5a7a68713e - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in astunparse==1.6.3 \ --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in certifi==2024.7.4 \ --hash=sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b \ --hash=sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90 @@ -111,7 +117,7 @@ charset-normalizer==3.3.2 \ dill==0.3.7 \ --hash=sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e \ --hash=sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in dm-tree==0.1.8 \ --hash=sha256:054b461f8176f4bce7a21f7b1870f873a1ced3bdbe1282c816c550bb43c71fa6 \ --hash=sha256:09964470f76a5201aff2e8f9b26842976de7889300676f927930f6285e256760 \ @@ -160,10 +166,19 @@ dm-tree==0.1.8 \ --hash=sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb \ --hash=sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d # via keras-nightly +flatbuffers==24.3.25 \ + --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ + --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 + # via -r ci/official/requirements_updater/requirements.in gast==0.4.0 \ --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in +google-pasta==0.2.0 \ + --hash=sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954 \ + --hash=sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed \ + --hash=sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e + # via -r ci/official/requirements_updater/requirements.in grpcio==1.64.1 \ --hash=sha256:03b43d0ccf99c557ec671c7dede64f023c7da9bb632ac65dbc57f166e4970040 \ --hash=sha256:0a12ddb1678ebc6a84ec6b0487feac020ee2b1659cbe69b80f06dbffdb249122 \ @@ -212,7 +227,7 @@ grpcio==1.64.1 \ --hash=sha256:ee73a2f5ca4ba44fa33b4d7d2c71e2c8a9e9f78d53f6507ad68e7d2ad5f64a22 \ --hash=sha256:f10193c69fc9d3d726e83bbf0f3d316f1847c3071c8c93d8090cf5f326b14309 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # tb-nightly h5py==3.11.0 \ --hash=sha256:083e0329ae534a264940d6513f47f5ada617da536d8dccbafc3026aefc33c90e \ @@ -237,7 +252,7 @@ h5py==3.11.0 \ --hash=sha256:f3736fe21da2b7d8a13fe8fe415f1272d2a1ccdeff4849c1421d2fb30fd533bc \ --hash=sha256:f4e025e852754ca833401777c25888acb96889ee2c27e7e629a19aee288833f0 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # keras-nightly idna==3.7 \ --hash=sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc \ @@ -245,14 +260,26 @@ idna==3.7 \ # via requests jax==0.4.7 \ --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in keras-nightly==3.0.4.dev2024021403 \ --hash=sha256:24ce69d29d582771685bf4235f59663723405b5a5b16f3eaff2657e52e74663a \ --hash=sha256:9f416e66b820ef833779d219d255b346b8b90a72fdbd0b2f1e90a43ad142a03d - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in +libclang==18.1.1 \ + --hash=sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a \ + --hash=sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8 \ + --hash=sha256:4dd2d3b82fab35e2bf9ca717d7b63ac990a3519c7e312f19fa8e86dcc712f7fb \ + --hash=sha256:54dda940a4a0491a9d1532bf071ea3ef26e6dbaf03b5000ed94dd7174e8f9592 \ + --hash=sha256:69f8eb8f65c279e765ffd28aaa7e9e364c776c17618af8bff22a8df58677ff4f \ + --hash=sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5 \ + --hash=sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8 \ + --hash=sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250 \ + --hash=sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b \ + --hash=sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe + # via -r ci/official/requirements_updater/requirements.in lit==17.0.6 \ --hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in markdown==3.6 \ --hash=sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f \ --hash=sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224 @@ -346,44 +373,52 @@ ml-dtypes==0.4.0 \ --hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \ --hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # jax # keras-nightly namex==0.0.8 \ --hash=sha256:32a50f6c565c0bb10aa76298c959507abdc0e850efe085dc38f3440fcb3aa90b \ --hash=sha256:7ddb6c2bb0e753a311b7590f84f6da659dd0c05e65cb89d519d54c0a250c0487 # via keras-nightly -numpy==1.23.5 ; python_version <= "3.11" \ - --hash=sha256:01dd17cbb340bf0fc23981e52e1d18a9d4050792e8fb8363cecbf066a84b827d \ - --hash=sha256:06005a2ef6014e9956c09ba07654f9837d9e26696a0470e42beedadb78c11b07 \ - --hash=sha256:09b7847f7e83ca37c6e627682f145856de331049013853f344f37b0c9690e3df \ - --hash=sha256:0aaee12d8883552fadfc41e96b4c82ee7d794949e2a7c3b3a7201e968c7ecab9 \ - --hash=sha256:0cbe9848fad08baf71de1a39e12d1b6310f1d5b2d0ea4de051058e6e1076852d \ - --hash=sha256:1b1766d6f397c18153d40015ddfc79ddb715cabadc04d2d228d4e5a8bc4ded1a \ - --hash=sha256:33161613d2269025873025b33e879825ec7b1d831317e68f4f2f0f84ed14c719 \ - --hash=sha256:5039f55555e1eab31124a5768898c9e22c25a65c1e0037f4d7c495a45778c9f2 \ - --hash=sha256:522e26bbf6377e4d76403826ed689c295b0b238f46c28a7251ab94716da0b280 \ - --hash=sha256:56e454c7833e94ec9769fa0f86e6ff8e42ee38ce0ce1fa4cbb747ea7e06d56aa \ - --hash=sha256:58f545efd1108e647604a1b5aa809591ccd2540f468a880bedb97247e72db387 \ - --hash=sha256:5e05b1c973a9f858c74367553e236f287e749465f773328c8ef31abe18f691e1 \ - --hash=sha256:7903ba8ab592b82014713c491f6c5d3a1cde5b4a3bf116404e08f5b52f6daf43 \ - --hash=sha256:8969bfd28e85c81f3f94eb4a66bc2cf1dbdc5c18efc320af34bffc54d6b1e38f \ - --hash=sha256:92c8c1e89a1f5028a4c6d9e3ccbe311b6ba53694811269b992c0b224269e2398 \ - --hash=sha256:9c88793f78fca17da0145455f0d7826bcb9f37da4764af27ac945488116efe63 \ - --hash=sha256:a7ac231a08bb37f852849bbb387a20a57574a97cfc7b6cabb488a4fc8be176de \ - --hash=sha256:abdde9f795cf292fb9651ed48185503a2ff29be87770c3b8e2a14b0cd7aa16f8 \ - --hash=sha256:af1da88f6bc3d2338ebbf0e22fe487821ea4d8e89053e25fa59d1d79786e7481 \ - --hash=sha256:b2a9ab7c279c91974f756c84c365a669a887efa287365a8e2c418f8b3ba73fb0 \ - --hash=sha256:bf837dc63ba5c06dc8797c398db1e223a466c7ece27a1f7b5232ba3466aafe3d \ - --hash=sha256:ca51fcfcc5f9354c45f400059e88bc09215fb71a48d3768fb80e357f3b457e1e \ - --hash=sha256:ce571367b6dfe60af04e04a1834ca2dc5f46004ac1cc756fb95319f64c095a96 \ - --hash=sha256:d208a0f8729f3fb790ed18a003f3a57895b989b40ea4dce4717e9cf4af62c6bb \ - --hash=sha256:dbee87b469018961d1ad79b1a5d50c0ae850000b639bcb1b694e9981083243b6 \ - --hash=sha256:e9f4c4e51567b616be64e05d517c79a8a22f3606499941d97bb76f2ca59f982d \ - --hash=sha256:f063b69b090c9d918f9df0a12116029e274daf0181df392839661c4c7ec9018a \ - --hash=sha256:f9a909a8bae284d46bbfdefbdd4a262ba19d3bc9921b1e76126b1d21c3c34135 +numpy==1.26.4 \ + --hash=sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b \ + --hash=sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818 \ + --hash=sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20 \ + --hash=sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0 \ + --hash=sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010 \ + --hash=sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a \ + --hash=sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea \ + --hash=sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c \ + --hash=sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71 \ + --hash=sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110 \ + --hash=sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be \ + --hash=sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a \ + --hash=sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a \ + --hash=sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5 \ + --hash=sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed \ + --hash=sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd \ + --hash=sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c \ + --hash=sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e \ + --hash=sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0 \ + --hash=sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c \ + --hash=sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a \ + --hash=sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b \ + --hash=sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0 \ + --hash=sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6 \ + --hash=sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2 \ + --hash=sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a \ + --hash=sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30 \ + --hash=sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218 \ + --hash=sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5 \ + --hash=sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07 \ + --hash=sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2 \ + --hash=sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4 \ + --hash=sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764 \ + --hash=sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef \ + --hash=sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3 \ + --hash=sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # h5py # jax # keras-nightly @@ -395,16 +430,18 @@ opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # jax packaging==23.2 \ --hash=sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5 \ --hash=sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7 - # via -r requirements.in + # via + # -r ci/official/requirements_updater/requirements.in + # tb-nightly portpicker==1.6.0 \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in protobuf==4.25.3 \ --hash=sha256:19b270aeaa0099f16d3ca02628546b8baefe2955bbe23224aaf856134eccf1e4 \ --hash=sha256:209ba4cc916bab46f64e56b85b090607a676f66b473e6b762e6f1d9d591eb2e8 \ @@ -443,7 +480,7 @@ pygments==2.18.0 \ requests==2.32.3 \ --hash=sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760 \ --hash=sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in rich==13.7.1 \ --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ --hash=sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432 @@ -475,34 +512,53 @@ scipy==1.11.3 \ --hash=sha256:e04aa19acc324a1a076abb4035dabe9b64badb19f76ad9c798bde39d41025cdc \ --hash=sha256:e1f97cd89c0fe1a0685f8f89d85fa305deb3067d0668151571ba50913e445820 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # jax six==1.16.0 \ --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 # via # astunparse + # google-pasta # tb-nightly -tb-nightly==2.18.0a20240611 \ - --hash=sha256:c299eb7dc3de22c7164a1b0c0091b784f2214d65b9a8b967eeeba9818314016d - # via -r requirements.in +tb-nightly==2.18.0a20240925 \ + --hash=sha256:ad4b476bfb4e3861b769f93c9ee4efc52f025e3b5db91da5d30c9e806e99a0eb + # via -r ci/official/requirements_updater/requirements.in tblib==2.0.0 \ --hash=sha256:9100bfa016b047d5b980d66e7efed952fbd20bd85b56110aaf473cb97d18709a \ --hash=sha256:a6df30f272c08bf8be66e0775fad862005d950a6b8449b94f7c788731d70ecd7 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in tensorboard-data-server==0.7.2 \ --hash=sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb \ --hash=sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60 \ --hash=sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530 # via tb-nightly +tensorflow-io-gcs-filesystem==0.37.1 \ + --hash=sha256:0df00891669390078a003cedbdd3b8e645c718b111917535fa1d7725e95cdb95 \ + --hash=sha256:249c12b830165841411ba71e08215d0e94277a49c551e6dd5d72aab54fe5491b \ + --hash=sha256:257aab23470a0796978efc9c2bcf8b0bc80f22e6298612a4c0a50d3f4e88060c \ + --hash=sha256:286389a203a5aee1a4fa2e53718c661091aa5fea797ff4fa6715ab8436b02e6c \ + --hash=sha256:32c50ab4e29a23c1f91cd0f9ab8c381a0ab10f45ef5c5252e94965916041737c \ + --hash=sha256:426de1173cb81fbd62becec2012fc00322a295326d90eb6c737fab636f182aed \ + --hash=sha256:6e1f2796b57e799a8ca1b75bf47c2aaa437c968408cc1a402a9862929e104cda \ + --hash=sha256:8943036bbf84e7a2be3705cb56f9c9df7c48c9e614bb941f0936c58e3ca89d6f \ + --hash=sha256:8febbfcc67c61e542a5ac1a98c7c20a91a5e1afc2e14b1ef0cb7c28bc3b6aa70 \ + --hash=sha256:9679b36e3a80921876f31685ab6f7270f3411a4cc51bc2847e80d0e4b5291e27 \ + --hash=sha256:b02f9c5f94fd62773954a04f69b68c4d576d076fd0db4ca25d5479f0fbfcdbad \ + --hash=sha256:ee5da49019670ed364f3e5fb86b46420841a6c3cb52a300553c63841671b3e6d \ + --hash=sha256:ee7c8ee5fe2fd8cb6392669ef16e71841133041fee8a330eff519ad9b36e4556 \ + --hash=sha256:fbb33f1745f218464a59cecd9a18e32ca927b0f4d77abd8f8671b645cc1a182f \ + --hash=sha256:fe8dcc6d222258a080ac3dfcaaaa347325ce36a7a046277f6b3e19abc1efb3c5 \ + --hash=sha256:ffebb6666a7bfc28005f4fbbb111a455b5e7d6cd3b12752b7050863ecb27d5cc + # via -r ci/official/requirements_updater/requirements.in termcolor==2.3.0 \ --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in typing-extensions==4.8.0 \ --hash=sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0 \ --hash=sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in urllib3==2.2.2 \ --hash=sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472 \ --hash=sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168 @@ -515,7 +571,7 @@ wheel==0.41.3 \ --hash=sha256:488609bc63a29322326e05560731bf7bfea8e48ad646e1f5e40d366607de0942 \ --hash=sha256:4d4987ce51a49370ea65c0bfd2234e8ce80a12780820d9dc462597a6e60d0841 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # astunparse wrapt==1.16.0 \ --hash=sha256:0d2691979e93d06a95a26257adb7bfd0c93818e89b1406f5a28f36e0d8c1e1fc \ @@ -588,11 +644,12 @@ wrapt==1.16.0 \ --hash=sha256:f6b2d0c6703c988d334f297aa5df18c45e97b0af3679bb75059e0e0bd8b1069d \ --hash=sha256:f8212564d49c50eb4565e502814f694e240c55551a5f1bc841d4fcaabb0a9b8a \ --hash=sha256:ffa565331890b90056c01db69c0fe634a776f8019c143a5ae265f9c6bc4bd6d4 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in +# The following packages are considered to be unsafe in a requirements file: setuptools==70.0.0 \ --hash=sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4 \ --hash=sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # tb-nightly diff --git a/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_12.txt b/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_12.txt index 4e6a881f64ede6..ba6f7304b535f4 100644 --- a/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_12.txt +++ b/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_12.txt @@ -1,3 +1,9 @@ +# +# This file is autogenerated by pip-compile with Python 3.12 +# by the following command: +# +# bazel run //ci/official/requirements_updater:requirements.update +# absl-py==2.1.0 \ --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff @@ -7,11 +13,11 @@ absl-py==2.1.0 \ astor==0.7.1 \ --hash=sha256:95c30d87a6c2cf89aa628b87398466840f0ad8652f88eb173125a6df8533fb8d \ --hash=sha256:fb503b9e2fdd05609fbf557b916b4a7824171203701660f0c55bbf5a7a68713e - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in astunparse==1.6.3 \ --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in certifi==2024.7.4 \ --hash=sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b \ --hash=sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90 @@ -111,7 +117,7 @@ charset-normalizer==3.3.2 \ dill==0.3.7 \ --hash=sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e \ --hash=sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in dm-tree==0.1.8 \ --hash=sha256:054b461f8176f4bce7a21f7b1870f873a1ced3bdbe1282c816c550bb43c71fa6 \ --hash=sha256:09964470f76a5201aff2e8f9b26842976de7889300676f927930f6285e256760 \ @@ -160,10 +166,19 @@ dm-tree==0.1.8 \ --hash=sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb \ --hash=sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d # via keras-nightly +flatbuffers==24.3.25 \ + --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ + --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 + # via -r ci/official/requirements_updater/requirements.in gast==0.4.0 \ --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in +google-pasta==0.2.0 \ + --hash=sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954 \ + --hash=sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed \ + --hash=sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e + # via -r ci/official/requirements_updater/requirements.in grpcio==1.64.1 \ --hash=sha256:03b43d0ccf99c557ec671c7dede64f023c7da9bb632ac65dbc57f166e4970040 \ --hash=sha256:0a12ddb1678ebc6a84ec6b0487feac020ee2b1659cbe69b80f06dbffdb249122 \ @@ -212,7 +227,7 @@ grpcio==1.64.1 \ --hash=sha256:ee73a2f5ca4ba44fa33b4d7d2c71e2c8a9e9f78d53f6507ad68e7d2ad5f64a22 \ --hash=sha256:f10193c69fc9d3d726e83bbf0f3d316f1847c3071c8c93d8090cf5f326b14309 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # tb-nightly h5py==3.11.0 \ --hash=sha256:083e0329ae534a264940d6513f47f5ada617da536d8dccbafc3026aefc33c90e \ @@ -237,7 +252,7 @@ h5py==3.11.0 \ --hash=sha256:f3736fe21da2b7d8a13fe8fe415f1272d2a1ccdeff4849c1421d2fb30fd533bc \ --hash=sha256:f4e025e852754ca833401777c25888acb96889ee2c27e7e629a19aee288833f0 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # keras-nightly idna==3.7 \ --hash=sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc \ @@ -245,14 +260,26 @@ idna==3.7 \ # via requests jax==0.4.7 \ --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in keras-nightly==3.0.4.dev2024021403 \ --hash=sha256:24ce69d29d582771685bf4235f59663723405b5a5b16f3eaff2657e52e74663a \ --hash=sha256:9f416e66b820ef833779d219d255b346b8b90a72fdbd0b2f1e90a43ad142a03d - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in +libclang==18.1.1 \ + --hash=sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a \ + --hash=sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8 \ + --hash=sha256:4dd2d3b82fab35e2bf9ca717d7b63ac990a3519c7e312f19fa8e86dcc712f7fb \ + --hash=sha256:54dda940a4a0491a9d1532bf071ea3ef26e6dbaf03b5000ed94dd7174e8f9592 \ + --hash=sha256:69f8eb8f65c279e765ffd28aaa7e9e364c776c17618af8bff22a8df58677ff4f \ + --hash=sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5 \ + --hash=sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8 \ + --hash=sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250 \ + --hash=sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b \ + --hash=sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe + # via -r ci/official/requirements_updater/requirements.in lit==17.0.6 \ --hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in markdown==3.6 \ --hash=sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f \ --hash=sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224 @@ -346,14 +373,14 @@ ml-dtypes==0.4.0 \ --hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \ --hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # jax # keras-nightly namex==0.0.8 \ --hash=sha256:32a50f6c565c0bb10aa76298c959507abdc0e850efe085dc38f3440fcb3aa90b \ --hash=sha256:7ddb6c2bb0e753a311b7590f84f6da659dd0c05e65cb89d519d54c0a250c0487 # via keras-nightly -numpy==1.26.4 ; python_version >= "3.12" \ +numpy==1.26.4 \ --hash=sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b \ --hash=sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818 \ --hash=sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20 \ @@ -391,7 +418,7 @@ numpy==1.26.4 ; python_version >= "3.12" \ --hash=sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3 \ --hash=sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # h5py # jax # keras-nightly @@ -403,16 +430,18 @@ opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # jax packaging==23.2 \ --hash=sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5 \ --hash=sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7 - # via -r requirements.in + # via + # -r ci/official/requirements_updater/requirements.in + # tb-nightly portpicker==1.6.0 \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in protobuf==4.25.3 \ --hash=sha256:19b270aeaa0099f16d3ca02628546b8baefe2955bbe23224aaf856134eccf1e4 \ --hash=sha256:209ba4cc916bab46f64e56b85b090607a676f66b473e6b762e6f1d9d591eb2e8 \ @@ -451,7 +480,7 @@ pygments==2.18.0 \ requests==2.32.3 \ --hash=sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760 \ --hash=sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in rich==13.7.1 \ --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ --hash=sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432 @@ -483,34 +512,53 @@ scipy==1.11.3 \ --hash=sha256:e04aa19acc324a1a076abb4035dabe9b64badb19f76ad9c798bde39d41025cdc \ --hash=sha256:e1f97cd89c0fe1a0685f8f89d85fa305deb3067d0668151571ba50913e445820 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # jax six==1.16.0 \ --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 # via # astunparse + # google-pasta # tb-nightly -tb-nightly==2.18.0a20240611 \ - --hash=sha256:c299eb7dc3de22c7164a1b0c0091b784f2214d65b9a8b967eeeba9818314016d - # via -r requirements.in +tb-nightly==2.18.0a20240925 \ + --hash=sha256:ad4b476bfb4e3861b769f93c9ee4efc52f025e3b5db91da5d30c9e806e99a0eb + # via -r ci/official/requirements_updater/requirements.in tblib==2.0.0 \ --hash=sha256:9100bfa016b047d5b980d66e7efed952fbd20bd85b56110aaf473cb97d18709a \ --hash=sha256:a6df30f272c08bf8be66e0775fad862005d950a6b8449b94f7c788731d70ecd7 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in tensorboard-data-server==0.7.2 \ --hash=sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb \ --hash=sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60 \ --hash=sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530 # via tb-nightly +tensorflow-io-gcs-filesystem==0.37.1 \ + --hash=sha256:0df00891669390078a003cedbdd3b8e645c718b111917535fa1d7725e95cdb95 \ + --hash=sha256:249c12b830165841411ba71e08215d0e94277a49c551e6dd5d72aab54fe5491b \ + --hash=sha256:257aab23470a0796978efc9c2bcf8b0bc80f22e6298612a4c0a50d3f4e88060c \ + --hash=sha256:286389a203a5aee1a4fa2e53718c661091aa5fea797ff4fa6715ab8436b02e6c \ + --hash=sha256:32c50ab4e29a23c1f91cd0f9ab8c381a0ab10f45ef5c5252e94965916041737c \ + --hash=sha256:426de1173cb81fbd62becec2012fc00322a295326d90eb6c737fab636f182aed \ + --hash=sha256:6e1f2796b57e799a8ca1b75bf47c2aaa437c968408cc1a402a9862929e104cda \ + --hash=sha256:8943036bbf84e7a2be3705cb56f9c9df7c48c9e614bb941f0936c58e3ca89d6f \ + --hash=sha256:8febbfcc67c61e542a5ac1a98c7c20a91a5e1afc2e14b1ef0cb7c28bc3b6aa70 \ + --hash=sha256:9679b36e3a80921876f31685ab6f7270f3411a4cc51bc2847e80d0e4b5291e27 \ + --hash=sha256:b02f9c5f94fd62773954a04f69b68c4d576d076fd0db4ca25d5479f0fbfcdbad \ + --hash=sha256:ee5da49019670ed364f3e5fb86b46420841a6c3cb52a300553c63841671b3e6d \ + --hash=sha256:ee7c8ee5fe2fd8cb6392669ef16e71841133041fee8a330eff519ad9b36e4556 \ + --hash=sha256:fbb33f1745f218464a59cecd9a18e32ca927b0f4d77abd8f8671b645cc1a182f \ + --hash=sha256:fe8dcc6d222258a080ac3dfcaaaa347325ce36a7a046277f6b3e19abc1efb3c5 \ + --hash=sha256:ffebb6666a7bfc28005f4fbbb111a455b5e7d6cd3b12752b7050863ecb27d5cc + # via -r ci/official/requirements_updater/requirements.in termcolor==2.3.0 \ --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in typing-extensions==4.8.0 \ --hash=sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0 \ --hash=sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in urllib3==2.2.2 \ --hash=sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472 \ --hash=sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168 @@ -523,7 +571,7 @@ wheel==0.41.3 \ --hash=sha256:488609bc63a29322326e05560731bf7bfea8e48ad646e1f5e40d366607de0942 \ --hash=sha256:4d4987ce51a49370ea65c0bfd2234e8ce80a12780820d9dc462597a6e60d0841 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # astunparse wrapt==1.16.0 \ --hash=sha256:0d2691979e93d06a95a26257adb7bfd0c93818e89b1406f5a28f36e0d8c1e1fc \ @@ -596,11 +644,12 @@ wrapt==1.16.0 \ --hash=sha256:f6b2d0c6703c988d334f297aa5df18c45e97b0af3679bb75059e0e0bd8b1069d \ --hash=sha256:f8212564d49c50eb4565e502814f694e240c55551a5f1bc841d4fcaabb0a9b8a \ --hash=sha256:ffa565331890b90056c01db69c0fe634a776f8019c143a5ae265f9c6bc4bd6d4 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in +# The following packages are considered to be unsafe in a requirements file: setuptools==70.0.0 \ --hash=sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4 \ --hash=sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # tb-nightly diff --git a/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_9.txt b/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_9.txt index cc6bd2ea02ea9f..1b3a3b3efd3287 100644 --- a/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_9.txt +++ b/ci/official/requirements_updater/numpy1_requirements/requirements_lock_3_9.txt @@ -1,3 +1,9 @@ +# +# This file is autogenerated by pip-compile with Python 3.9 +# by the following command: +# +# bazel run //ci/official/requirements_updater:requirements.update +# absl-py==2.1.0 \ --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff @@ -7,11 +13,11 @@ absl-py==2.1.0 \ astor==0.7.1 \ --hash=sha256:95c30d87a6c2cf89aa628b87398466840f0ad8652f88eb173125a6df8533fb8d \ --hash=sha256:fb503b9e2fdd05609fbf557b916b4a7824171203701660f0c55bbf5a7a68713e - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in astunparse==1.6.3 \ --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in certifi==2024.7.4 \ --hash=sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b \ --hash=sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90 @@ -111,7 +117,7 @@ charset-normalizer==3.3.2 \ dill==0.3.7 \ --hash=sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e \ --hash=sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in dm-tree==0.1.8 \ --hash=sha256:054b461f8176f4bce7a21f7b1870f873a1ced3bdbe1282c816c550bb43c71fa6 \ --hash=sha256:09964470f76a5201aff2e8f9b26842976de7889300676f927930f6285e256760 \ @@ -160,10 +166,19 @@ dm-tree==0.1.8 \ --hash=sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb \ --hash=sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d # via keras-nightly +flatbuffers==24.3.25 \ + --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ + --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 + # via -r ci/official/requirements_updater/requirements.in gast==0.4.0 \ --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in +google-pasta==0.2.0 \ + --hash=sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954 \ + --hash=sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed \ + --hash=sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e + # via -r ci/official/requirements_updater/requirements.in grpcio==1.64.1 \ --hash=sha256:03b43d0ccf99c557ec671c7dede64f023c7da9bb632ac65dbc57f166e4970040 \ --hash=sha256:0a12ddb1678ebc6a84ec6b0487feac020ee2b1659cbe69b80f06dbffdb249122 \ @@ -212,7 +227,7 @@ grpcio==1.64.1 \ --hash=sha256:ee73a2f5ca4ba44fa33b4d7d2c71e2c8a9e9f78d53f6507ad68e7d2ad5f64a22 \ --hash=sha256:f10193c69fc9d3d726e83bbf0f3d316f1847c3071c8c93d8090cf5f326b14309 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # tb-nightly h5py==3.11.0 \ --hash=sha256:083e0329ae534a264940d6513f47f5ada617da536d8dccbafc3026aefc33c90e \ @@ -237,7 +252,7 @@ h5py==3.11.0 \ --hash=sha256:f3736fe21da2b7d8a13fe8fe415f1272d2a1ccdeff4849c1421d2fb30fd533bc \ --hash=sha256:f4e025e852754ca833401777c25888acb96889ee2c27e7e629a19aee288833f0 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # keras-nightly idna==3.7 \ --hash=sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc \ @@ -249,14 +264,26 @@ importlib-metadata==7.1.0 \ # via markdown jax==0.4.7 \ --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in keras-nightly==3.0.4.dev2024021403 \ --hash=sha256:24ce69d29d582771685bf4235f59663723405b5a5b16f3eaff2657e52e74663a \ --hash=sha256:9f416e66b820ef833779d219d255b346b8b90a72fdbd0b2f1e90a43ad142a03d - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in +libclang==18.1.1 \ + --hash=sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a \ + --hash=sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8 \ + --hash=sha256:4dd2d3b82fab35e2bf9ca717d7b63ac990a3519c7e312f19fa8e86dcc712f7fb \ + --hash=sha256:54dda940a4a0491a9d1532bf071ea3ef26e6dbaf03b5000ed94dd7174e8f9592 \ + --hash=sha256:69f8eb8f65c279e765ffd28aaa7e9e364c776c17618af8bff22a8df58677ff4f \ + --hash=sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5 \ + --hash=sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8 \ + --hash=sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250 \ + --hash=sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b \ + --hash=sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe + # via -r ci/official/requirements_updater/requirements.in lit==17.0.6 \ --hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in markdown==3.6 \ --hash=sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f \ --hash=sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224 @@ -350,44 +377,52 @@ ml-dtypes==0.4.0 \ --hash=sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06 \ --hash=sha256:f1724ddcdf5edbaf615a62110af47407f1719b8d02e68ccee60683acb5f74da1 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # jax # keras-nightly namex==0.0.8 \ --hash=sha256:32a50f6c565c0bb10aa76298c959507abdc0e850efe085dc38f3440fcb3aa90b \ --hash=sha256:7ddb6c2bb0e753a311b7590f84f6da659dd0c05e65cb89d519d54c0a250c0487 # via keras-nightly -numpy==1.23.5 ; python_version <= "3.11" \ - --hash=sha256:01dd17cbb340bf0fc23981e52e1d18a9d4050792e8fb8363cecbf066a84b827d \ - --hash=sha256:06005a2ef6014e9956c09ba07654f9837d9e26696a0470e42beedadb78c11b07 \ - --hash=sha256:09b7847f7e83ca37c6e627682f145856de331049013853f344f37b0c9690e3df \ - --hash=sha256:0aaee12d8883552fadfc41e96b4c82ee7d794949e2a7c3b3a7201e968c7ecab9 \ - --hash=sha256:0cbe9848fad08baf71de1a39e12d1b6310f1d5b2d0ea4de051058e6e1076852d \ - --hash=sha256:1b1766d6f397c18153d40015ddfc79ddb715cabadc04d2d228d4e5a8bc4ded1a \ - --hash=sha256:33161613d2269025873025b33e879825ec7b1d831317e68f4f2f0f84ed14c719 \ - --hash=sha256:5039f55555e1eab31124a5768898c9e22c25a65c1e0037f4d7c495a45778c9f2 \ - --hash=sha256:522e26bbf6377e4d76403826ed689c295b0b238f46c28a7251ab94716da0b280 \ - --hash=sha256:56e454c7833e94ec9769fa0f86e6ff8e42ee38ce0ce1fa4cbb747ea7e06d56aa \ - --hash=sha256:58f545efd1108e647604a1b5aa809591ccd2540f468a880bedb97247e72db387 \ - --hash=sha256:5e05b1c973a9f858c74367553e236f287e749465f773328c8ef31abe18f691e1 \ - --hash=sha256:7903ba8ab592b82014713c491f6c5d3a1cde5b4a3bf116404e08f5b52f6daf43 \ - --hash=sha256:8969bfd28e85c81f3f94eb4a66bc2cf1dbdc5c18efc320af34bffc54d6b1e38f \ - --hash=sha256:92c8c1e89a1f5028a4c6d9e3ccbe311b6ba53694811269b992c0b224269e2398 \ - --hash=sha256:9c88793f78fca17da0145455f0d7826bcb9f37da4764af27ac945488116efe63 \ - --hash=sha256:a7ac231a08bb37f852849bbb387a20a57574a97cfc7b6cabb488a4fc8be176de \ - --hash=sha256:abdde9f795cf292fb9651ed48185503a2ff29be87770c3b8e2a14b0cd7aa16f8 \ - --hash=sha256:af1da88f6bc3d2338ebbf0e22fe487821ea4d8e89053e25fa59d1d79786e7481 \ - --hash=sha256:b2a9ab7c279c91974f756c84c365a669a887efa287365a8e2c418f8b3ba73fb0 \ - --hash=sha256:bf837dc63ba5c06dc8797c398db1e223a466c7ece27a1f7b5232ba3466aafe3d \ - --hash=sha256:ca51fcfcc5f9354c45f400059e88bc09215fb71a48d3768fb80e357f3b457e1e \ - --hash=sha256:ce571367b6dfe60af04e04a1834ca2dc5f46004ac1cc756fb95319f64c095a96 \ - --hash=sha256:d208a0f8729f3fb790ed18a003f3a57895b989b40ea4dce4717e9cf4af62c6bb \ - --hash=sha256:dbee87b469018961d1ad79b1a5d50c0ae850000b639bcb1b694e9981083243b6 \ - --hash=sha256:e9f4c4e51567b616be64e05d517c79a8a22f3606499941d97bb76f2ca59f982d \ - --hash=sha256:f063b69b090c9d918f9df0a12116029e274daf0181df392839661c4c7ec9018a \ - --hash=sha256:f9a909a8bae284d46bbfdefbdd4a262ba19d3bc9921b1e76126b1d21c3c34135 +numpy==1.26.4 \ + --hash=sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b \ + --hash=sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818 \ + --hash=sha256:1af303d6b2210eb850fcf03064d364652b7120803a0b872f5211f5234b399f20 \ + --hash=sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0 \ + --hash=sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010 \ + --hash=sha256:2e4ee3380d6de9c9ec04745830fd9e2eccb3e6cf790d39d7b98ffd19b0dd754a \ + --hash=sha256:3373d5d70a5fe74a2c1bb6d2cfd9609ecf686d47a2d7b1d37a8f3b6bf6003aea \ + --hash=sha256:47711010ad8555514b434df65f7d7b076bb8261df1ca9bb78f53d3b2db02e95c \ + --hash=sha256:4c66707fabe114439db9068ee468c26bbdf909cac0fb58686a42a24de1760c71 \ + --hash=sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110 \ + --hash=sha256:52b8b60467cd7dd1e9ed082188b4e6bb35aa5cdd01777621a1658910745b90be \ + --hash=sha256:60dedbb91afcbfdc9bc0b1f3f402804070deed7392c23eb7a7f07fa857868e8a \ + --hash=sha256:62b8e4b1e28009ef2846b4c7852046736bab361f7aeadeb6a5b89ebec3c7055a \ + --hash=sha256:666dbfb6ec68962c033a450943ded891bed2d54e6755e35e5835d63f4f6931d5 \ + --hash=sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed \ + --hash=sha256:679b0076f67ecc0138fd2ede3a8fd196dddc2ad3254069bcb9faf9a79b1cebcd \ + --hash=sha256:7349ab0fa0c429c82442a27a9673fc802ffdb7c7775fad780226cb234965e53c \ + --hash=sha256:7ab55401287bfec946ced39700c053796e7cc0e3acbef09993a9ad2adba6ca6e \ + --hash=sha256:7e50d0a0cc3189f9cb0aeb3a6a6af18c16f59f004b866cd2be1c14b36134a4a0 \ + --hash=sha256:95a7476c59002f2f6c590b9b7b998306fba6a5aa646b1e22ddfeaf8f78c3a29c \ + --hash=sha256:96ff0b2ad353d8f990b63294c8986f1ec3cb19d749234014f4e7eb0112ceba5a \ + --hash=sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b \ + --hash=sha256:9ff0f4f29c51e2803569d7a51c2304de5554655a60c5d776e35b4a41413830d0 \ + --hash=sha256:a354325ee03388678242a4d7ebcd08b5c727033fcff3b2f536aea978e15ee9e6 \ + --hash=sha256:a4abb4f9001ad2858e7ac189089c42178fcce737e4169dc61321660f1a96c7d2 \ + --hash=sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a \ + --hash=sha256:afedb719a9dcfc7eaf2287b839d8198e06dcd4cb5d276a3df279231138e83d30 \ + --hash=sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218 \ + --hash=sha256:b97fe8060236edf3662adfc2c633f56a08ae30560c56310562cb4f95500022d5 \ + --hash=sha256:bfe25acf8b437eb2a8b2d49d443800a5f18508cd811fea3181723922a8a82b07 \ + --hash=sha256:cd25bcecc4974d09257ffcd1f098ee778f7834c3ad767fe5db785be9a4aa9cb2 \ + --hash=sha256:d209d8969599b27ad20994c8e41936ee0964e6da07478d6c35016bc386b66ad4 \ + --hash=sha256:d5241e0a80d808d70546c697135da2c613f30e28251ff8307eb72ba696945764 \ + --hash=sha256:edd8b5fe47dab091176d21bb6de568acdd906d1887a4584a15a9a96a1dca06ef \ + --hash=sha256:f870204a840a60da0b12273ef34f7051e98c3b5961b61b0c2c1be6dfd64fbcd3 \ + --hash=sha256:ffa75af20b44f8dba823498024771d5ac50620e6915abac414251bd971b4529f # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # h5py # jax # keras-nightly @@ -399,16 +434,18 @@ opt-einsum==3.3.0 \ --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # jax packaging==23.2 \ --hash=sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5 \ --hash=sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7 - # via -r requirements.in + # via + # -r ci/official/requirements_updater/requirements.in + # tb-nightly portpicker==1.6.0 \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in protobuf==4.25.3 \ --hash=sha256:19b270aeaa0099f16d3ca02628546b8baefe2955bbe23224aaf856134eccf1e4 \ --hash=sha256:209ba4cc916bab46f64e56b85b090607a676f66b473e6b762e6f1d9d591eb2e8 \ @@ -447,7 +484,7 @@ pygments==2.18.0 \ requests==2.32.3 \ --hash=sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760 \ --hash=sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in rich==13.7.1 \ --hash=sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222 \ --hash=sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432 @@ -479,34 +516,53 @@ scipy==1.11.3 \ --hash=sha256:e04aa19acc324a1a076abb4035dabe9b64badb19f76ad9c798bde39d41025cdc \ --hash=sha256:e1f97cd89c0fe1a0685f8f89d85fa305deb3067d0668151571ba50913e445820 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # jax six==1.16.0 \ --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 # via # astunparse + # google-pasta # tb-nightly -tb-nightly==2.18.0a20240611 \ - --hash=sha256:c299eb7dc3de22c7164a1b0c0091b784f2214d65b9a8b967eeeba9818314016d - # via -r requirements.in +tb-nightly==2.18.0a20240925 \ + --hash=sha256:ad4b476bfb4e3861b769f93c9ee4efc52f025e3b5db91da5d30c9e806e99a0eb + # via -r ci/official/requirements_updater/requirements.in tblib==2.0.0 \ --hash=sha256:9100bfa016b047d5b980d66e7efed952fbd20bd85b56110aaf473cb97d18709a \ --hash=sha256:a6df30f272c08bf8be66e0775fad862005d950a6b8449b94f7c788731d70ecd7 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in tensorboard-data-server==0.7.2 \ --hash=sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb \ --hash=sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60 \ --hash=sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530 # via tb-nightly +tensorflow-io-gcs-filesystem==0.37.1 \ + --hash=sha256:0df00891669390078a003cedbdd3b8e645c718b111917535fa1d7725e95cdb95 \ + --hash=sha256:249c12b830165841411ba71e08215d0e94277a49c551e6dd5d72aab54fe5491b \ + --hash=sha256:257aab23470a0796978efc9c2bcf8b0bc80f22e6298612a4c0a50d3f4e88060c \ + --hash=sha256:286389a203a5aee1a4fa2e53718c661091aa5fea797ff4fa6715ab8436b02e6c \ + --hash=sha256:32c50ab4e29a23c1f91cd0f9ab8c381a0ab10f45ef5c5252e94965916041737c \ + --hash=sha256:426de1173cb81fbd62becec2012fc00322a295326d90eb6c737fab636f182aed \ + --hash=sha256:6e1f2796b57e799a8ca1b75bf47c2aaa437c968408cc1a402a9862929e104cda \ + --hash=sha256:8943036bbf84e7a2be3705cb56f9c9df7c48c9e614bb941f0936c58e3ca89d6f \ + --hash=sha256:8febbfcc67c61e542a5ac1a98c7c20a91a5e1afc2e14b1ef0cb7c28bc3b6aa70 \ + --hash=sha256:9679b36e3a80921876f31685ab6f7270f3411a4cc51bc2847e80d0e4b5291e27 \ + --hash=sha256:b02f9c5f94fd62773954a04f69b68c4d576d076fd0db4ca25d5479f0fbfcdbad \ + --hash=sha256:ee5da49019670ed364f3e5fb86b46420841a6c3cb52a300553c63841671b3e6d \ + --hash=sha256:ee7c8ee5fe2fd8cb6392669ef16e71841133041fee8a330eff519ad9b36e4556 \ + --hash=sha256:fbb33f1745f218464a59cecd9a18e32ca927b0f4d77abd8f8671b645cc1a182f \ + --hash=sha256:fe8dcc6d222258a080ac3dfcaaaa347325ce36a7a046277f6b3e19abc1efb3c5 \ + --hash=sha256:ffebb6666a7bfc28005f4fbbb111a455b5e7d6cd3b12752b7050863ecb27d5cc + # via -r ci/official/requirements_updater/requirements.in termcolor==2.3.0 \ --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in typing-extensions==4.8.0 \ --hash=sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0 \ --hash=sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in urllib3==2.2.2 \ --hash=sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472 \ --hash=sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168 @@ -519,7 +575,7 @@ wheel==0.41.3 \ --hash=sha256:488609bc63a29322326e05560731bf7bfea8e48ad646e1f5e40d366607de0942 \ --hash=sha256:4d4987ce51a49370ea65c0bfd2234e8ce80a12780820d9dc462597a6e60d0841 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # astunparse wrapt==1.16.0 \ --hash=sha256:0d2691979e93d06a95a26257adb7bfd0c93818e89b1406f5a28f36e0d8c1e1fc \ @@ -592,15 +648,16 @@ wrapt==1.16.0 \ --hash=sha256:f6b2d0c6703c988d334f297aa5df18c45e97b0af3679bb75059e0e0bd8b1069d \ --hash=sha256:f8212564d49c50eb4565e502814f694e240c55551a5f1bc841d4fcaabb0a9b8a \ --hash=sha256:ffa565331890b90056c01db69c0fe634a776f8019c143a5ae265f9c6bc4bd6d4 - # via -r requirements.in + # via -r ci/official/requirements_updater/requirements.in zipp==3.19.2 \ --hash=sha256:bf1dcf6450f873a13e952a29504887c89e6de7506209e5b1bcc3460135d4de19 \ --hash=sha256:f091755f667055f2d02b32c53771a7a6c8b47e1fdbc4b72a8b9072b3eef8015c # via importlib-metadata +# The following packages are considered to be unsafe in a requirements file: setuptools==70.0.0 \ --hash=sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4 \ --hash=sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0 # via - # -r requirements.in + # -r ci/official/requirements_updater/requirements.in # tb-nightly diff --git a/ci/official/requirements_updater/requirements.in b/ci/official/requirements_updater/requirements.in index 749d097c6997b1..e50806d6b3650d 100644 --- a/ci/official/requirements_updater/requirements.in +++ b/ci/official/requirements_updater/requirements.in @@ -1,4 +1,5 @@ -numpy ~= 2.0.0 +# Note that numpy 2.1.0 does not support python 3.9 +numpy >= 2.0.0, < 2.2.0 wheel ~= 0.41.2 h5py >= 3.11.0 lit ~= 17.0.2 @@ -17,12 +18,17 @@ ml_dtypes >= 0.4.0, < 0.5.0 # Note that we must use nightly here as these are used in nightly jobs # For release jobs, we will pin these on the release branch keras-nightly ~= 3.0.0.dev -tb-nightly ~= 2.18.0.a +tb-nightly ~= 2.19.0.a # Test dependencies grpcio >= 1.24.3, < 2.0 portpicker == 1.6.0 -scipy ~= 1.13.0 +scipy >= 1.13.0 requests >= 2.31.0 packaging==23.2 setuptools==70.0.0 jax==0.4.7 +# The dependencies below are needed for TF wheel testing. +tensorflow-io-gcs-filesystem==0.37.1 +libclang >= 13.0.0 +google_pasta ~= 0.2 +flatbuffers ~= 24.3.25 diff --git a/ci/official/utilities/code_check_full.bats b/ci/official/utilities/code_check_full.bats index d414a88ecfad36..f681a78b2461e3 100644 --- a/ci/official/utilities/code_check_full.bats +++ b/ci/official/utilities/code_check_full.bats @@ -316,7 +316,7 @@ EOF # See b/279852433 (internal). # TODO(b/279852433) Replace deps(//tensorflow/...) with deps(//...) @test "Verify that it's possible to query every TensorFlow target without BUILD errors" { - bazel query "deps(//tensorflow/...)" > /dev/null + bazel query "deps(//tensorflow/... -//tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test)" > /dev/null } teardown_file() { diff --git a/ci/official/utilities/rename_and_verify_wheels.sh b/ci/official/utilities/rename_and_verify_wheels.sh index 36e55652c3cfd4..2111be61b802cc 100755 --- a/ci/official/utilities/rename_and_verify_wheels.sh +++ b/ci/official/utilities/rename_and_verify_wheels.sh @@ -57,15 +57,28 @@ fi venv=$(mktemp -d) "python${TFCI_PYTHON_VERSION}" -m venv "$venv" python="$venv/bin/python3" -# TODO(b/366266944) Remove the check after tf docker image upgrade for NumPy 2. -if [[ "$TFCI_WHL_NUMPY_VERSION" == 2 ]]; then - "$python" -m pip install numpy==2.0.0 +# TODO(b/366266944) Remove the check after tf docker image upgrade for NumPy 2 +# and numpy 1 support is dropped b/361369076. +if [[ "$TFCI_WHL_NUMPY_VERSION" == 1 ]]; then + "$python" -m pip install numpy==1.26.0 fi "$python" -m pip install *.whl $TFCI_PYTHON_VERIFY_PIP_INSTALL_ARGS if [[ "$TFCI_WHL_IMPORT_TEST_ENABLE" == "1" ]]; then "$python" -c 'import tensorflow as tf; t1=tf.constant([1,2,3,4]); t2=tf.constant([5,6,7,8]); print(tf.add(t1,t2).shape)' "$python" -c 'import sys; import tensorflow as tf; sys.exit(0 if "keras" in tf.keras.__name__ else 1)' fi +# Import tf nightly wheel built with numpy2 from PyPI in numpy1 env for testing. +# This aims to maintain TF compatibility with NumPy 1.x until 2025 b/361369076. +if [[ "$TFCI_WHL_NUMPY_VERSION" == 1 ]]; then + # Uninstall tf nightly wheel built with numpy1. + "$python" -m pip uninstall -y tf_nightly_numpy1 + # Install tf nightly cpu wheel built with numpy2.x from PyPI in numpy1.x env. + "$python" -m pip install tf-nightly-cpu + if [[ "$TFCI_WHL_IMPORT_TEST_ENABLE" == "1" ]]; then + "$python" -c 'import tensorflow as tf; t1=tf.constant([1,2,3,4]); t2=tf.constant([5,6,7,8]); print(tf.add(t1,t2).shape)' + "$python" -c 'import sys; import tensorflow as tf; sys.exit(0 if "keras" in tf.keras.__name__ else 1)' + fi +fi # VERY basic check to ensure the [and-cuda] package variant is installable. # Checks TFCI_BAZEL_COMMON_ARGS for "gpu" or "cuda", implying that the test is # relevant. All of the GPU test machines have CUDA installed via other means, diff --git a/ci/official/wheel.sh b/ci/official/wheel.sh index ec05db2716ce63..ebe7cf31bff5c5 100755 --- a/ci/official/wheel.sh +++ b/ci/official/wheel.sh @@ -27,8 +27,16 @@ if [[ "$TFCI_NIGHTLY_UPDATE_VERSION_ENABLE" == 1 ]]; then export TFCI_BUILD_PIP_PACKAGE_ARGS="$(echo $TFCI_BUILD_PIP_PACKAGE_ARGS | sed 's/tensorflow/tf_nightly/')" fi +# TODO(b/361369076) Remove the following block after TF NumPy 1 is dropped +# Move hermetic requirement lock files for NumPy 1 to the root +if [[ "$TFCI_WHL_NUMPY_VERSION" == 1 ]]; then + cp ./ci/official/requirements_updater/numpy1_requirements/*.txt . +fi + tfrun bazel build $TFCI_BAZEL_COMMON_ARGS --config=cuda_wheel //tensorflow/tools/pip_package:wheel $TFCI_BUILD_PIP_PACKAGE_ARGS tfrun find ./bazel-bin/tensorflow/tools/pip_package -iname "*.whl" -exec cp {} $TFCI_OUTPUT_DIR \; +tfrun mkdir ./dist +tfrun cp $TFCI_OUTPUT_DIR/*.whl ./dist tfrun ./ci/official/utilities/rename_and_verify_wheels.sh if [[ "$TFCI_ARTIFACT_STAGING_GCS_ENABLE" == 1 ]]; then @@ -37,5 +45,5 @@ if [[ "$TFCI_ARTIFACT_STAGING_GCS_ENABLE" == 1 ]]; then fi if [[ "$TFCI_WHL_BAZEL_TEST_ENABLE" == 1 ]]; then - tfrun bazel test $TFCI_BAZEL_COMMON_ARGS --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_wheel_test" + tfrun bazel test $TFCI_BAZEL_COMMON_ARGS $TFCI_BUILD_PIP_PACKAGE_ARGS --repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_wheel_test" fi diff --git a/ci/official/wheel_test/BUILD b/ci/official/wheel_test/BUILD deleted file mode 100644 index 3cca20e70a545b..00000000000000 --- a/ci/official/wheel_test/BUILD +++ /dev/null @@ -1,5 +0,0 @@ -py_test( - name = "test_import_api_packages", - srcs = ["test_import_api_packages.py"], - deps = ["@pypi_tensorflow//:pkg"], -) diff --git a/ci/official/wheel_test/README.md b/ci/official/wheel_test/README.md deleted file mode 100644 index cef4131e63f579..00000000000000 --- a/ci/official/wheel_test/README.md +++ /dev/null @@ -1,94 +0,0 @@ -## Wheel Test - -This directory is dedicated to tests that require a built TensorFlow wheel -file for testing, such as: - -* Ensuring the entire API is importable -* Testing downstream projects against the wheel - -Ensure you have Bazel installed and accessible from your command line. - -These tests use hermetic Python. They also require a built TensorFlow wheel file -and a requirements_lock file. The requirements_lock file is generated by the -[requirements_updater](https://github.com/tensorflow/tensorflow/tree/master/ci/official/requirements_updater) -tool using the path to this wheel file. - -### Hermetic Python - -For details about hermetic Python and setting its toolchain version, see -[requirements updater readme](https://github.com/tensorflow/tensorflow/blob/master/ci/official/requirements_updater/README.md) - -### Prerequisites for Local Testing - -To run tests locally, follow these steps: - -1. Navigate to the relevant directory: - ``` - cd ci/official/wheel_test - ``` -2. Run a script for creating requirements file: - ``` - bash update_requirements.sh - e.g.: - bash update_requirements.sh /tmp/tensorflow-2.14.0-cp311-cp311-linux_x86_64.whl 3_11 - ``` - -#### Requirements Updater Script -This script automates the process of updating TensorFlow requirements for a -specific Python version. - -##### Parameters -`path_to_tensorflow_wheel`: The local path to the TensorFlow wheel file. -Example: `/tmp/tensorflow-2.14.0-cp311-cp311-linux_x86_64.whl` - -`python_version`: The target Python version, replacing `.` with `_`. -Example: For Python 3.11, use `3_11` - -The script performs the following steps: - -1. Navigates to the `../requirements_updater` directory. -2. Creates a `requirements_wheel_test.in` file and specifies the path -to the actual TensorFlow wheel. -3. Creates a `requirements_lock_.txt` file. -4. Updates the `requirements_lock_.txt` file using -a Bazel command. -5. Moves the updated `requirements_lock_.txt` file -to the `../wheel_test/` directory. - - -### How it Works in the Presubmit Job - -`_requirements_lock` files will be generated by the presubmit job. A detailed -description will be provided once it's integrated into presubmit. - -### test_import_api_packages - -This Python test verifies whether the API v2 packages can be imported from the -current build. It utilizes the `_api/v2/api_packages.txt` list of packages from -the local wheel file specified in the `requirements_lock_.txt`. - -Packages are imported one by one in alphabetical order during runtime. - -The test doesn't identify package's order-dependent issues; for instance, -importing "tf.foo" followed by "tf.bar" won't reveal that "tf.bar" depends on -"tf.foo" being imported first. - -The `_api/v2/api_packages.txt` file is generated during the TensorFlow API v2 -init files creation process and is subsequently stored in the wheel file after -the build. It also contains a few paths that cannot be directly imported. These -paths point to attributes or sub-modules within a module's namespace, but they -don't correspond to an actual file or directory on the filesystem. The list of -such paths is stored in the packages_for_skip variable and will be skipped -during the test. - -##### How to Build - -``` -bazel build //:test_import_api_packages -``` - -##### How to Run - -``` -bazel test //:test_import_api_packages --test_output=all -``` diff --git a/ci/official/wheel_test/WORKSPACE b/ci/official/wheel_test/WORKSPACE deleted file mode 100644 index db46144dadbbb1..00000000000000 --- a/ci/official/wheel_test/WORKSPACE +++ /dev/null @@ -1,65 +0,0 @@ -# buildifier: disable=load-on-top - -workspace(name = "wheel_test") - -# buildifier: disable=load-on-top - -load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") - -http_archive( - name = "bazel_skylib", - sha256 = "74d544d96f4a5bb630d465ca8bbcfe231e3594e5aae57e1edbf17a6eb3ca2506", - urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/bazelbuild/bazel-skylib/releases/download/1.3.0/bazel-skylib-1.3.0.tar.gz", - "https://github.com/bazelbuild/bazel-skylib/releases/download/1.3.0/bazel-skylib-1.3.0.tar.gz", - ], -) - -http_archive( - name = "rules_python", - sha256 = "9d04041ac92a0985e344235f5d946f71ac543f1b1565f2cdbc9a2aaee8adf55b", - strip_prefix = "rules_python-0.26.0", - url = "https://github.com/bazelbuild/rules_python/releases/download/0.26.0/rules_python-0.26.0.tar.gz", -) - -# buildifier: disable=same-origin-load -load("@rules_python//python:repositories.bzl", "py_repositories") - -py_repositories() - -## Load HERMETIC_PYTHON_VERSION variable -local_repository( - name = "local_tensorflow", - path = "../../..", -) - -load( - "@local_tensorflow//tensorflow/tools/toolchains/python:python_repo.bzl", - "python_repository", -) - -python_repository(name = "python_version_repo") - -load("@python_version_repo//:py_version.bzl", "TF_PYTHON_VERSION") - -# Register multi toolchains -load("@rules_python//python:repositories.bzl", "python_register_toolchains") # buildifier: disable=same-origin-load - -python_register_toolchains( - name = "python", - ignore_root_user_error = True, - python_version = TF_PYTHON_VERSION, -) - -load("@python//:defs.bzl", "interpreter") -load("@rules_python//python:pip.bzl", "pip_parse") - -pip_parse( - name = "pypi", - python_interpreter_target = interpreter, - requirements = "//:requirements_lock_" + TF_PYTHON_VERSION.replace(".", "_") + ".txt", -) - -load("@pypi//:requirements.bzl", "install_deps") - -install_deps() diff --git a/ci/official/wheel_test/update_requirements.sh b/ci/official/wheel_test/update_requirements.sh deleted file mode 100644 index bed56273b48952..00000000000000 --- a/ci/official/wheel_test/update_requirements.sh +++ /dev/null @@ -1,53 +0,0 @@ -#!/bin/bash -# Copyright 2023 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -# script to run pip-compile for keras, tensorboard, estimator deps. -# if there is a change in requirements.in then all lock files will be updated -# accordingly. - -# -e: abort script if one command fails -# -u: error if undefined variable used -# -o pipefail: entire command fails if pipe fails. watch out for yes | ... -# -o history: record shell history -set -euo pipefail -o history - -# Check for required arguments -if [ -z "$1" ]; then - echo "Usage: $0 " - exit 1 -fi - -TENSORFLOW_WHEEL_PATH="$1" -PYTHON_VERSION="$2" - -# All commands run relative to this directory -cd "$(dirname "${BASH_SOURCE[0]}")" -cd ../requirements_updater || exit 1 - -# Create the requirements_wheel_test.in file -echo "tensorflow @ file://localhost/$TENSORFLOW_WHEEL_PATH" > requirements_wheel_test.in - -# Create the requirements_lock file -REQUIREMENTS_LOCK_FILE="requirements_lock_${PYTHON_VERSION}.txt" -touch "$REQUIREMENTS_LOCK_FILE" - -### Update the requirements_lock file -bazel run --experimental_convenience_symlinks=ignore --repo_env=REQUIREMENTS_FILE_NAME=requirements_wheel_test.in //:requirements_${PYTHON_VERSION}.update - -# Move the updated file to the appropriate directory -mv "$REQUIREMENTS_LOCK_FILE" ../wheel_test/ - -echo "All tasks completed successfully." diff --git a/configure.py b/configure.py index 50ed76e9f23d14..c756a2f52980a7 100644 --- a/configure.py +++ b/configure.py @@ -20,15 +20,10 @@ import os import platform import re +import shutil import subprocess import sys -# pylint: disable=g-import-not-at-top -try: - from shutil import which -except ImportError: - from distutils.spawn import find_executable as which -# pylint: enable=g-import-not-at-top _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0' @@ -163,9 +158,9 @@ def get_python_path(environ_cp, python_bin_path): except subprocess.CalledProcessError: library_paths = [ run_shell([ - python_bin_path, '-c', - 'from distutils.sysconfig import get_python_lib;' - 'print(get_python_lib())' + python_bin_path, + '-c', + 'import sysconfig;print(sysconfig.get_path("purelib")', ]) ] @@ -425,9 +420,9 @@ def retrieve_bazel_version(): Returns: The bazel version detected. """ - bazel_executable = which('bazel') + bazel_executable = shutil.which('bazel') if bazel_executable is None: - bazel_executable = which('bazelisk') + bazel_executable = shutil.which('bazelisk') if bazel_executable is None: print('Cannot find bazel. Please install bazel/bazelisk.') sys.exit(1) @@ -617,7 +612,7 @@ def set_clang_cuda_compiler_path(environ_cp): if not os.path.exists(default_clang_path): default_clang_path = '/usr/lib/llvm-16/bin/clang' if not os.path.exists(default_clang_path): - default_clang_path = which('clang') or '' + default_clang_path = shutil.which('clang') or '' clang_cuda_compiler_path = prompt_loop_or_load_from_env( environ_cp, @@ -779,7 +774,7 @@ def get_ndk_api_level(environ_cp, android_ndk_home_path): def set_gcc_host_compiler_path(environ_cp): """Set GCC_HOST_COMPILER_PATH.""" - default_gcc_host_compiler_path = which('gcc') or '' + default_gcc_host_compiler_path = shutil.which('gcc') or '' gcc_host_compiler_path = prompt_loop_or_load_from_env( environ_cp, @@ -840,7 +835,7 @@ def set_clang_compiler_path(environ_cp): if not os.path.exists(default_clang_path): default_clang_path = '/usr/lib/llvm-16/bin/clang' if not os.path.exists(default_clang_path): - default_clang_path = which('clang') or '' + default_clang_path = shutil.which('clang') or '' clang_compiler_path = prompt_loop_or_load_from_env( environ_cp, @@ -879,7 +874,7 @@ def set_clang_compiler_path_win(environ_cp): # Default path if clang-16 is installed by using apt-get install default_clang_path = 'C:/Program Files/LLVM/bin/clang.exe' if not os.path.exists(default_clang_path): - default_clang_path = which('clang') or '' + default_clang_path = shutil.which('clang') or '' clang_compiler_path = prompt_loop_or_load_from_env( environ_cp, @@ -1181,7 +1176,7 @@ def configure_ios(environ_cp): def get_gcc_compiler(environ_cp): - gcc_env = environ_cp.get('CXX') or environ_cp.get('CC') or which('gcc') + gcc_env = environ_cp.get('CXX') or environ_cp.get('CC') or shutil.which('gcc') if gcc_env is not None: gcc_version = run_shell([gcc_env, '--version']).split() if gcc_version[0] in ('gcc', 'g++'): diff --git a/third_party/xla/third_party/py/non_hermetic/ml_dtypes/BUILD b/dep_graph.dot similarity index 100% rename from third_party/xla/third_party/py/non_hermetic/ml_dtypes/BUILD rename to dep_graph.dot diff --git a/requirements_lock_3_10.txt b/requirements_lock_3_10.txt index e058fe135f93f6..30bc5923688f5d 100644 --- a/requirements_lock_3_10.txt +++ b/requirements_lock_3_10.txt @@ -166,10 +166,19 @@ dm-tree==0.1.8 \ --hash=sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb \ --hash=sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d # via keras-nightly +flatbuffers==24.3.25 \ + --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ + --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 + # via -r ci/official/requirements_updater/requirements.in gast==0.4.0 \ --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 # via -r ci/official/requirements_updater/requirements.in +google-pasta==0.2.0 \ + --hash=sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954 \ + --hash=sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed \ + --hash=sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e + # via -r ci/official/requirements_updater/requirements.in grpcio==1.64.1 \ --hash=sha256:03b43d0ccf99c557ec671c7dede64f023c7da9bb632ac65dbc57f166e4970040 \ --hash=sha256:0a12ddb1678ebc6a84ec6b0487feac020ee2b1659cbe69b80f06dbffdb249122 \ @@ -256,6 +265,18 @@ keras-nightly==3.0.4.dev2024021403 \ --hash=sha256:24ce69d29d582771685bf4235f59663723405b5a5b16f3eaff2657e52e74663a \ --hash=sha256:9f416e66b820ef833779d219d255b346b8b90a72fdbd0b2f1e90a43ad142a03d # via -r ci/official/requirements_updater/requirements.in +libclang==18.1.1 \ + --hash=sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a \ + --hash=sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8 \ + --hash=sha256:4dd2d3b82fab35e2bf9ca717d7b63ac990a3519c7e312f19fa8e86dcc712f7fb \ + --hash=sha256:54dda940a4a0491a9d1532bf071ea3ef26e6dbaf03b5000ed94dd7174e8f9592 \ + --hash=sha256:69f8eb8f65c279e765ffd28aaa7e9e364c776c17618af8bff22a8df58677ff4f \ + --hash=sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5 \ + --hash=sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8 \ + --hash=sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250 \ + --hash=sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b \ + --hash=sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe + # via -r ci/official/requirements_updater/requirements.in lit==17.0.6 \ --hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b # via -r ci/official/requirements_updater/requirements.in @@ -359,52 +380,60 @@ namex==0.0.8 \ --hash=sha256:32a50f6c565c0bb10aa76298c959507abdc0e850efe085dc38f3440fcb3aa90b \ --hash=sha256:7ddb6c2bb0e753a311b7590f84f6da659dd0c05e65cb89d519d54c0a250c0487 # via keras-nightly -numpy==2.0.2 \ - --hash=sha256:0123ffdaa88fa4ab64835dcbde75dcdf89c453c922f18dced6e27c90d1d0ec5a \ - --hash=sha256:11a76c372d1d37437857280aa142086476136a8c0f373b2e648ab2c8f18fb195 \ - --hash=sha256:13e689d772146140a252c3a28501da66dfecd77490b498b168b501835041f951 \ - --hash=sha256:1e795a8be3ddbac43274f18588329c72939870a16cae810c2b73461c40718ab1 \ - --hash=sha256:26df23238872200f63518dd2aa984cfca675d82469535dc7162dc2ee52d9dd5c \ - --hash=sha256:286cd40ce2b7d652a6f22efdfc6d1edf879440e53e76a75955bc0c826c7e64dc \ - --hash=sha256:2b2955fa6f11907cf7a70dab0d0755159bca87755e831e47932367fc8f2f2d0b \ - --hash=sha256:2da5960c3cf0df7eafefd806d4e612c5e19358de82cb3c343631188991566ccd \ - --hash=sha256:312950fdd060354350ed123c0e25a71327d3711584beaef30cdaa93320c392d4 \ - --hash=sha256:423e89b23490805d2a5a96fe40ec507407b8ee786d66f7328be214f9679df6dd \ - --hash=sha256:496f71341824ed9f3d2fd36cf3ac57ae2e0165c143b55c3a035ee219413f3318 \ - --hash=sha256:49ca4decb342d66018b01932139c0961a8f9ddc7589611158cb3c27cbcf76448 \ - --hash=sha256:51129a29dbe56f9ca83438b706e2e69a39892b5eda6cedcb6b0c9fdc9b0d3ece \ - --hash=sha256:5fec9451a7789926bcf7c2b8d187292c9f93ea30284802a0ab3f5be8ab36865d \ - --hash=sha256:671bec6496f83202ed2d3c8fdc486a8fc86942f2e69ff0e986140339a63bcbe5 \ - --hash=sha256:7f0a0c6f12e07fa94133c8a67404322845220c06a9e80e85999afe727f7438b8 \ - --hash=sha256:807ec44583fd708a21d4a11d94aedf2f4f3c3719035c76a2bbe1fe8e217bdc57 \ - --hash=sha256:883c987dee1880e2a864ab0dc9892292582510604156762362d9326444636e78 \ - --hash=sha256:8c5713284ce4e282544c68d1c3b2c7161d38c256d2eefc93c1d683cf47683e66 \ - --hash=sha256:8cafab480740e22f8d833acefed5cc87ce276f4ece12fdaa2e8903db2f82897a \ - --hash=sha256:8df823f570d9adf0978347d1f926b2a867d5608f434a7cff7f7908c6570dcf5e \ - --hash=sha256:9059e10581ce4093f735ed23f3b9d283b9d517ff46009ddd485f1747eb22653c \ - --hash=sha256:905d16e0c60200656500c95b6b8dca5d109e23cb24abc701d41c02d74c6b3afa \ - --hash=sha256:9189427407d88ff25ecf8f12469d4d39d35bee1db5d39fc5c168c6f088a6956d \ - --hash=sha256:96a55f64139912d61de9137f11bf39a55ec8faec288c75a54f93dfd39f7eb40c \ - --hash=sha256:97032a27bd9d8988b9a97a8c4d2c9f2c15a81f61e2f21404d7e8ef00cb5be729 \ - --hash=sha256:984d96121c9f9616cd33fbd0618b7f08e0cfc9600a7ee1d6fd9b239186d19d97 \ - --hash=sha256:9a92ae5c14811e390f3767053ff54eaee3bf84576d99a2456391401323f4ec2c \ - --hash=sha256:9ea91dfb7c3d1c56a0e55657c0afb38cf1eeae4544c208dc465c3c9f3a7c09f9 \ - --hash=sha256:a15f476a45e6e5a3a79d8a14e62161d27ad897381fecfa4a09ed5322f2085669 \ - --hash=sha256:a392a68bd329eafac5817e5aefeb39038c48b671afd242710b451e76090e81f4 \ - --hash=sha256:a3f4ab0caa7f053f6797fcd4e1e25caee367db3112ef2b6ef82d749530768c73 \ - --hash=sha256:a46288ec55ebbd58947d31d72be2c63cbf839f0a63b49cb755022310792a3385 \ - --hash=sha256:a61ec659f68ae254e4d237816e33171497e978140353c0c2038d46e63282d0c8 \ - --hash=sha256:a842d573724391493a97a62ebbb8e731f8a5dcc5d285dfc99141ca15a3302d0c \ - --hash=sha256:becfae3ddd30736fe1889a37f1f580e245ba79a5855bff5f2a29cb3ccc22dd7b \ - --hash=sha256:c05e238064fc0610c840d1cf6a13bf63d7e391717d247f1bf0318172e759e692 \ - --hash=sha256:c1c9307701fec8f3f7a1e6711f9089c06e6284b3afbbcd259f7791282d660a15 \ - --hash=sha256:c7b0be4ef08607dd04da4092faee0b86607f111d5ae68036f16cc787e250a131 \ - --hash=sha256:cfd41e13fdc257aa5778496b8caa5e856dc4896d4ccf01841daee1d96465467a \ - --hash=sha256:d731a1c6116ba289c1e9ee714b08a8ff882944d4ad631fd411106a30f083c326 \ - --hash=sha256:df55d490dea7934f330006d0f81e8551ba6010a5bf035a249ef61a94f21c500b \ - --hash=sha256:ec9852fb39354b5a45a80bdab5ac02dd02b15f44b3804e9f00c556bf24b4bded \ - --hash=sha256:f15975dfec0cf2239224d80e32c3170b1d168335eaedee69da84fbe9f1f9cd04 \ - --hash=sha256:f26b258c385842546006213344c50655ff1555a9338e2e5e02a0756dc3e803dd +numpy==2.1.1 \ + --hash=sha256:046356b19d7ad1890c751b99acad5e82dc4a02232013bd9a9a712fddf8eb60f5 \ + --hash=sha256:0b8cc2715a84b7c3b161f9ebbd942740aaed913584cae9cdc7f8ad5ad41943d0 \ + --hash=sha256:0d07841fd284718feffe7dd17a63a2e6c78679b2d386d3e82f44f0108c905550 \ + --hash=sha256:13cc11c00000848702322af4de0147ced365c81d66053a67c2e962a485b3717c \ + --hash=sha256:13ce49a34c44b6de5241f0b38b07e44c1b2dcacd9e36c30f9c2fcb1bb5135db7 \ + --hash=sha256:24c2ad697bd8593887b019817ddd9974a7f429c14a5469d7fad413f28340a6d2 \ + --hash=sha256:251105b7c42abe40e3a689881e1793370cc9724ad50d64b30b358bbb3a97553b \ + --hash=sha256:2ca4b53e1e0b279142113b8c5eb7d7a877e967c306edc34f3b58e9be12fda8df \ + --hash=sha256:3269c9eb8745e8d975980b3a7411a98976824e1fdef11f0aacf76147f662b15f \ + --hash=sha256:397bc5ce62d3fb73f304bec332171535c187e0643e176a6e9421a6e3eacef06d \ + --hash=sha256:3fc5eabfc720db95d68e6646e88f8b399bfedd235994016351b1d9e062c4b270 \ + --hash=sha256:50a95ca3560a6058d6ea91d4629a83a897ee27c00630aed9d933dff191f170cd \ + --hash=sha256:52ac2e48f5ad847cd43c4755520a2317f3380213493b9d8a4c5e37f3b87df504 \ + --hash=sha256:53e27293b3a2b661c03f79aa51c3987492bd4641ef933e366e0f9f6c9bf257ec \ + --hash=sha256:57eb525e7c2a8fdee02d731f647146ff54ea8c973364f3b850069ffb42799647 \ + --hash=sha256:5889dd24f03ca5a5b1e8a90a33b5a0846d8977565e4ae003a63d22ecddf6782f \ + --hash=sha256:59ca673ad11d4b84ceb385290ed0ebe60266e356641428c845b39cd9df6713ab \ + --hash=sha256:6435c48250c12f001920f0751fe50c0348f5f240852cfddc5e2f97e007544cbe \ + --hash=sha256:6e5a9cb2be39350ae6c8f79410744e80154df658d5bea06e06e0ac5bb75480d5 \ + --hash=sha256:7be6a07520b88214ea85d8ac8b7d6d8a1839b0b5cb87412ac9f49fa934eb15d5 \ + --hash=sha256:7c803b7934a7f59563db459292e6aa078bb38b7ab1446ca38dd138646a38203e \ + --hash=sha256:7dd86dfaf7c900c0bbdcb8b16e2f6ddf1eb1fe39c6c8cca6e94844ed3152a8fd \ + --hash=sha256:8661c94e3aad18e1ea17a11f60f843a4933ccaf1a25a7c6a9182af70610b2313 \ + --hash=sha256:8ae0fd135e0b157365ac7cc31fff27f07a5572bdfc38f9c2d43b2aff416cc8b0 \ + --hash=sha256:910b47a6d0635ec1bd53b88f86120a52bf56dcc27b51f18c7b4a2e2224c29f0f \ + --hash=sha256:913cc1d311060b1d409e609947fa1b9753701dac96e6581b58afc36b7ee35af6 \ + --hash=sha256:920b0911bb2e4414c50e55bd658baeb78281a47feeb064ab40c2b66ecba85553 \ + --hash=sha256:950802d17a33c07cba7fd7c3dcfa7d64705509206be1606f196d179e539111ed \ + --hash=sha256:981707f6b31b59c0c24bcda52e5605f9701cb46da4b86c2e8023656ad3e833cb \ + --hash=sha256:98ce7fb5b8063cfdd86596b9c762bf2b5e35a2cdd7e967494ab78a1fa7f8b86e \ + --hash=sha256:99f4a9ee60eed1385a86e82288971a51e71df052ed0b2900ed30bc840c0f2e39 \ + --hash=sha256:9a8e06c7a980869ea67bbf551283bbed2856915f0a792dc32dd0f9dd2fb56728 \ + --hash=sha256:ae8ce252404cdd4de56dcfce8b11eac3c594a9c16c231d081fb705cf23bd4d9e \ + --hash=sha256:afd9c680df4de71cd58582b51e88a61feed4abcc7530bcd3d48483f20fc76f2a \ + --hash=sha256:b49742cdb85f1f81e4dc1b39dcf328244f4d8d1ded95dea725b316bd2cf18c95 \ + --hash=sha256:b5613cfeb1adfe791e8e681128f5f49f22f3fcaa942255a6124d58ca59d9528f \ + --hash=sha256:bab7c09454460a487e631ffc0c42057e3d8f2a9ddccd1e60c7bb8ed774992480 \ + --hash=sha256:c8a0e34993b510fc19b9a2ce7f31cb8e94ecf6e924a40c0c9dd4f62d0aac47d9 \ + --hash=sha256:caf5d284ddea7462c32b8d4a6b8af030b6c9fd5332afb70e7414d7fdded4bfd0 \ + --hash=sha256:cea427d1350f3fd0d2818ce7350095c1a2ee33e30961d2f0fef48576ddbbe90f \ + --hash=sha256:d0cf7d55b1051387807405b3898efafa862997b4cba8aa5dbe657be794afeafd \ + --hash=sha256:d10c39947a2d351d6d466b4ae83dad4c37cd6c3cdd6d5d0fa797da56f710a6ae \ + --hash=sha256:d2b9cd92c8f8e7b313b80e93cedc12c0112088541dcedd9197b5dee3738c1201 \ + --hash=sha256:d4c57b68c8ef5e1ebf47238e99bf27657511ec3f071c465f6b1bccbef12d4136 \ + --hash=sha256:d51fc141ddbe3f919e91a096ec739f49d686df8af254b2053ba21a910ae518bf \ + --hash=sha256:e097507396c0be4e547ff15b13dc3866f45f3680f789c1a1301b07dadd3fbc78 \ + --hash=sha256:e30356d530528a42eeba51420ae8bf6c6c09559051887196599d96ee5f536468 \ + --hash=sha256:e8d5f8a8e3bc87334f025194c6193e408903d21ebaeb10952264943a985066ca \ + --hash=sha256:e8dfa9e94fc127c40979c3eacbae1e61fda4fe71d84869cc129e2721973231ef \ + --hash=sha256:f212d4f46b67ff604d11fff7cc62d36b3e8714edf68e44e9760e19be38c03eb0 \ + --hash=sha256:f7506387e191fe8cdb267f912469a3cccc538ab108471291636a96a54e599556 \ + --hash=sha256:fac6e277a41163d27dfab5f4ec1f7a83fac94e170665a4a50191b545721c6521 \ + --hash=sha256:fcd8f556cdc8cfe35e70efb92463082b7f43dd7e547eb071ffc36abc0ca4699b # via # -r ci/official/requirements_updater/requirements.in # h5py @@ -424,6 +453,7 @@ packaging==23.2 \ --hash=sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5 \ --hash=sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7 # via -r ci/official/requirements_updater/requirements.in + # tb-nightly portpicker==1.6.0 \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa @@ -505,9 +535,10 @@ six==1.16.0 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 # via # astunparse + # google-pasta # tb-nightly -tb-nightly==2.18.0a20240611 \ - --hash=sha256:c299eb7dc3de22c7164a1b0c0091b784f2214d65b9a8b967eeeba9818314016d +tb-nightly==2.19.0a20240926 \ + --hash=sha256:4f2f4dd02eda684fbb2edd9cb46b7bd0ee7ba6ca35d38e4de2e293df8567c1b4 # via -r ci/official/requirements_updater/requirements.in tblib==2.0.0 \ --hash=sha256:9100bfa016b047d5b980d66e7efed952fbd20bd85b56110aaf473cb97d18709a \ @@ -518,6 +549,24 @@ tensorboard-data-server==0.7.2 \ --hash=sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60 \ --hash=sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530 # via tb-nightly +tensorflow-io-gcs-filesystem==0.37.1 \ + --hash=sha256:0df00891669390078a003cedbdd3b8e645c718b111917535fa1d7725e95cdb95 \ + --hash=sha256:249c12b830165841411ba71e08215d0e94277a49c551e6dd5d72aab54fe5491b \ + --hash=sha256:257aab23470a0796978efc9c2bcf8b0bc80f22e6298612a4c0a50d3f4e88060c \ + --hash=sha256:286389a203a5aee1a4fa2e53718c661091aa5fea797ff4fa6715ab8436b02e6c \ + --hash=sha256:32c50ab4e29a23c1f91cd0f9ab8c381a0ab10f45ef5c5252e94965916041737c \ + --hash=sha256:426de1173cb81fbd62becec2012fc00322a295326d90eb6c737fab636f182aed \ + --hash=sha256:6e1f2796b57e799a8ca1b75bf47c2aaa437c968408cc1a402a9862929e104cda \ + --hash=sha256:8943036bbf84e7a2be3705cb56f9c9df7c48c9e614bb941f0936c58e3ca89d6f \ + --hash=sha256:8febbfcc67c61e542a5ac1a98c7c20a91a5e1afc2e14b1ef0cb7c28bc3b6aa70 \ + --hash=sha256:9679b36e3a80921876f31685ab6f7270f3411a4cc51bc2847e80d0e4b5291e27 \ + --hash=sha256:b02f9c5f94fd62773954a04f69b68c4d576d076fd0db4ca25d5479f0fbfcdbad \ + --hash=sha256:ee5da49019670ed364f3e5fb86b46420841a6c3cb52a300553c63841671b3e6d \ + --hash=sha256:ee7c8ee5fe2fd8cb6392669ef16e71841133041fee8a330eff519ad9b36e4556 \ + --hash=sha256:fbb33f1745f218464a59cecd9a18e32ca927b0f4d77abd8f8671b645cc1a182f \ + --hash=sha256:fe8dcc6d222258a080ac3dfcaaaa347325ce36a7a046277f6b3e19abc1efb3c5 \ + --hash=sha256:ffebb6666a7bfc28005f4fbbb111a455b5e7d6cd3b12752b7050863ecb27d5cc + # via -r ci/official/requirements_updater/requirements.in termcolor==2.3.0 \ --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a diff --git a/requirements_lock_3_11.txt b/requirements_lock_3_11.txt index d57e2df7c8abf1..d77fe7a825811b 100644 --- a/requirements_lock_3_11.txt +++ b/requirements_lock_3_11.txt @@ -166,10 +166,19 @@ dm-tree==0.1.8 \ --hash=sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb \ --hash=sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d # via keras-nightly +flatbuffers==24.3.25 \ + --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ + --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 + # via -r ci/official/requirements_updater/requirements.in gast==0.4.0 \ --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 # via -r ci/official/requirements_updater/requirements.in +google-pasta==0.2.0 \ + --hash=sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954 \ + --hash=sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed \ + --hash=sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e + # via -r ci/official/requirements_updater/requirements.in grpcio==1.64.1 \ --hash=sha256:03b43d0ccf99c557ec671c7dede64f023c7da9bb632ac65dbc57f166e4970040 \ --hash=sha256:0a12ddb1678ebc6a84ec6b0487feac020ee2b1659cbe69b80f06dbffdb249122 \ @@ -256,6 +265,18 @@ keras-nightly==3.0.4.dev2024021403 \ --hash=sha256:24ce69d29d582771685bf4235f59663723405b5a5b16f3eaff2657e52e74663a \ --hash=sha256:9f416e66b820ef833779d219d255b346b8b90a72fdbd0b2f1e90a43ad142a03d # via -r ci/official/requirements_updater/requirements.in +libclang==18.1.1 \ + --hash=sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a \ + --hash=sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8 \ + --hash=sha256:4dd2d3b82fab35e2bf9ca717d7b63ac990a3519c7e312f19fa8e86dcc712f7fb \ + --hash=sha256:54dda940a4a0491a9d1532bf071ea3ef26e6dbaf03b5000ed94dd7174e8f9592 \ + --hash=sha256:69f8eb8f65c279e765ffd28aaa7e9e364c776c17618af8bff22a8df58677ff4f \ + --hash=sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5 \ + --hash=sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8 \ + --hash=sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250 \ + --hash=sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b \ + --hash=sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe + # via -r ci/official/requirements_updater/requirements.in lit==17.0.6 \ --hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b # via -r ci/official/requirements_updater/requirements.in @@ -359,52 +380,60 @@ namex==0.0.8 \ --hash=sha256:32a50f6c565c0bb10aa76298c959507abdc0e850efe085dc38f3440fcb3aa90b \ --hash=sha256:7ddb6c2bb0e753a311b7590f84f6da659dd0c05e65cb89d519d54c0a250c0487 # via keras-nightly -numpy==2.0.2 \ - --hash=sha256:0123ffdaa88fa4ab64835dcbde75dcdf89c453c922f18dced6e27c90d1d0ec5a \ - --hash=sha256:11a76c372d1d37437857280aa142086476136a8c0f373b2e648ab2c8f18fb195 \ - --hash=sha256:13e689d772146140a252c3a28501da66dfecd77490b498b168b501835041f951 \ - --hash=sha256:1e795a8be3ddbac43274f18588329c72939870a16cae810c2b73461c40718ab1 \ - --hash=sha256:26df23238872200f63518dd2aa984cfca675d82469535dc7162dc2ee52d9dd5c \ - --hash=sha256:286cd40ce2b7d652a6f22efdfc6d1edf879440e53e76a75955bc0c826c7e64dc \ - --hash=sha256:2b2955fa6f11907cf7a70dab0d0755159bca87755e831e47932367fc8f2f2d0b \ - --hash=sha256:2da5960c3cf0df7eafefd806d4e612c5e19358de82cb3c343631188991566ccd \ - --hash=sha256:312950fdd060354350ed123c0e25a71327d3711584beaef30cdaa93320c392d4 \ - --hash=sha256:423e89b23490805d2a5a96fe40ec507407b8ee786d66f7328be214f9679df6dd \ - --hash=sha256:496f71341824ed9f3d2fd36cf3ac57ae2e0165c143b55c3a035ee219413f3318 \ - --hash=sha256:49ca4decb342d66018b01932139c0961a8f9ddc7589611158cb3c27cbcf76448 \ - --hash=sha256:51129a29dbe56f9ca83438b706e2e69a39892b5eda6cedcb6b0c9fdc9b0d3ece \ - --hash=sha256:5fec9451a7789926bcf7c2b8d187292c9f93ea30284802a0ab3f5be8ab36865d \ - --hash=sha256:671bec6496f83202ed2d3c8fdc486a8fc86942f2e69ff0e986140339a63bcbe5 \ - --hash=sha256:7f0a0c6f12e07fa94133c8a67404322845220c06a9e80e85999afe727f7438b8 \ - --hash=sha256:807ec44583fd708a21d4a11d94aedf2f4f3c3719035c76a2bbe1fe8e217bdc57 \ - --hash=sha256:883c987dee1880e2a864ab0dc9892292582510604156762362d9326444636e78 \ - --hash=sha256:8c5713284ce4e282544c68d1c3b2c7161d38c256d2eefc93c1d683cf47683e66 \ - --hash=sha256:8cafab480740e22f8d833acefed5cc87ce276f4ece12fdaa2e8903db2f82897a \ - --hash=sha256:8df823f570d9adf0978347d1f926b2a867d5608f434a7cff7f7908c6570dcf5e \ - --hash=sha256:9059e10581ce4093f735ed23f3b9d283b9d517ff46009ddd485f1747eb22653c \ - --hash=sha256:905d16e0c60200656500c95b6b8dca5d109e23cb24abc701d41c02d74c6b3afa \ - --hash=sha256:9189427407d88ff25ecf8f12469d4d39d35bee1db5d39fc5c168c6f088a6956d \ - --hash=sha256:96a55f64139912d61de9137f11bf39a55ec8faec288c75a54f93dfd39f7eb40c \ - --hash=sha256:97032a27bd9d8988b9a97a8c4d2c9f2c15a81f61e2f21404d7e8ef00cb5be729 \ - --hash=sha256:984d96121c9f9616cd33fbd0618b7f08e0cfc9600a7ee1d6fd9b239186d19d97 \ - --hash=sha256:9a92ae5c14811e390f3767053ff54eaee3bf84576d99a2456391401323f4ec2c \ - --hash=sha256:9ea91dfb7c3d1c56a0e55657c0afb38cf1eeae4544c208dc465c3c9f3a7c09f9 \ - --hash=sha256:a15f476a45e6e5a3a79d8a14e62161d27ad897381fecfa4a09ed5322f2085669 \ - --hash=sha256:a392a68bd329eafac5817e5aefeb39038c48b671afd242710b451e76090e81f4 \ - --hash=sha256:a3f4ab0caa7f053f6797fcd4e1e25caee367db3112ef2b6ef82d749530768c73 \ - --hash=sha256:a46288ec55ebbd58947d31d72be2c63cbf839f0a63b49cb755022310792a3385 \ - --hash=sha256:a61ec659f68ae254e4d237816e33171497e978140353c0c2038d46e63282d0c8 \ - --hash=sha256:a842d573724391493a97a62ebbb8e731f8a5dcc5d285dfc99141ca15a3302d0c \ - --hash=sha256:becfae3ddd30736fe1889a37f1f580e245ba79a5855bff5f2a29cb3ccc22dd7b \ - --hash=sha256:c05e238064fc0610c840d1cf6a13bf63d7e391717d247f1bf0318172e759e692 \ - --hash=sha256:c1c9307701fec8f3f7a1e6711f9089c06e6284b3afbbcd259f7791282d660a15 \ - --hash=sha256:c7b0be4ef08607dd04da4092faee0b86607f111d5ae68036f16cc787e250a131 \ - --hash=sha256:cfd41e13fdc257aa5778496b8caa5e856dc4896d4ccf01841daee1d96465467a \ - --hash=sha256:d731a1c6116ba289c1e9ee714b08a8ff882944d4ad631fd411106a30f083c326 \ - --hash=sha256:df55d490dea7934f330006d0f81e8551ba6010a5bf035a249ef61a94f21c500b \ - --hash=sha256:ec9852fb39354b5a45a80bdab5ac02dd02b15f44b3804e9f00c556bf24b4bded \ - --hash=sha256:f15975dfec0cf2239224d80e32c3170b1d168335eaedee69da84fbe9f1f9cd04 \ - --hash=sha256:f26b258c385842546006213344c50655ff1555a9338e2e5e02a0756dc3e803dd +numpy==2.1.1 \ + --hash=sha256:046356b19d7ad1890c751b99acad5e82dc4a02232013bd9a9a712fddf8eb60f5 \ + --hash=sha256:0b8cc2715a84b7c3b161f9ebbd942740aaed913584cae9cdc7f8ad5ad41943d0 \ + --hash=sha256:0d07841fd284718feffe7dd17a63a2e6c78679b2d386d3e82f44f0108c905550 \ + --hash=sha256:13cc11c00000848702322af4de0147ced365c81d66053a67c2e962a485b3717c \ + --hash=sha256:13ce49a34c44b6de5241f0b38b07e44c1b2dcacd9e36c30f9c2fcb1bb5135db7 \ + --hash=sha256:24c2ad697bd8593887b019817ddd9974a7f429c14a5469d7fad413f28340a6d2 \ + --hash=sha256:251105b7c42abe40e3a689881e1793370cc9724ad50d64b30b358bbb3a97553b \ + --hash=sha256:2ca4b53e1e0b279142113b8c5eb7d7a877e967c306edc34f3b58e9be12fda8df \ + --hash=sha256:3269c9eb8745e8d975980b3a7411a98976824e1fdef11f0aacf76147f662b15f \ + --hash=sha256:397bc5ce62d3fb73f304bec332171535c187e0643e176a6e9421a6e3eacef06d \ + --hash=sha256:3fc5eabfc720db95d68e6646e88f8b399bfedd235994016351b1d9e062c4b270 \ + --hash=sha256:50a95ca3560a6058d6ea91d4629a83a897ee27c00630aed9d933dff191f170cd \ + --hash=sha256:52ac2e48f5ad847cd43c4755520a2317f3380213493b9d8a4c5e37f3b87df504 \ + --hash=sha256:53e27293b3a2b661c03f79aa51c3987492bd4641ef933e366e0f9f6c9bf257ec \ + --hash=sha256:57eb525e7c2a8fdee02d731f647146ff54ea8c973364f3b850069ffb42799647 \ + --hash=sha256:5889dd24f03ca5a5b1e8a90a33b5a0846d8977565e4ae003a63d22ecddf6782f \ + --hash=sha256:59ca673ad11d4b84ceb385290ed0ebe60266e356641428c845b39cd9df6713ab \ + --hash=sha256:6435c48250c12f001920f0751fe50c0348f5f240852cfddc5e2f97e007544cbe \ + --hash=sha256:6e5a9cb2be39350ae6c8f79410744e80154df658d5bea06e06e0ac5bb75480d5 \ + --hash=sha256:7be6a07520b88214ea85d8ac8b7d6d8a1839b0b5cb87412ac9f49fa934eb15d5 \ + --hash=sha256:7c803b7934a7f59563db459292e6aa078bb38b7ab1446ca38dd138646a38203e \ + --hash=sha256:7dd86dfaf7c900c0bbdcb8b16e2f6ddf1eb1fe39c6c8cca6e94844ed3152a8fd \ + --hash=sha256:8661c94e3aad18e1ea17a11f60f843a4933ccaf1a25a7c6a9182af70610b2313 \ + --hash=sha256:8ae0fd135e0b157365ac7cc31fff27f07a5572bdfc38f9c2d43b2aff416cc8b0 \ + --hash=sha256:910b47a6d0635ec1bd53b88f86120a52bf56dcc27b51f18c7b4a2e2224c29f0f \ + --hash=sha256:913cc1d311060b1d409e609947fa1b9753701dac96e6581b58afc36b7ee35af6 \ + --hash=sha256:920b0911bb2e4414c50e55bd658baeb78281a47feeb064ab40c2b66ecba85553 \ + --hash=sha256:950802d17a33c07cba7fd7c3dcfa7d64705509206be1606f196d179e539111ed \ + --hash=sha256:981707f6b31b59c0c24bcda52e5605f9701cb46da4b86c2e8023656ad3e833cb \ + --hash=sha256:98ce7fb5b8063cfdd86596b9c762bf2b5e35a2cdd7e967494ab78a1fa7f8b86e \ + --hash=sha256:99f4a9ee60eed1385a86e82288971a51e71df052ed0b2900ed30bc840c0f2e39 \ + --hash=sha256:9a8e06c7a980869ea67bbf551283bbed2856915f0a792dc32dd0f9dd2fb56728 \ + --hash=sha256:ae8ce252404cdd4de56dcfce8b11eac3c594a9c16c231d081fb705cf23bd4d9e \ + --hash=sha256:afd9c680df4de71cd58582b51e88a61feed4abcc7530bcd3d48483f20fc76f2a \ + --hash=sha256:b49742cdb85f1f81e4dc1b39dcf328244f4d8d1ded95dea725b316bd2cf18c95 \ + --hash=sha256:b5613cfeb1adfe791e8e681128f5f49f22f3fcaa942255a6124d58ca59d9528f \ + --hash=sha256:bab7c09454460a487e631ffc0c42057e3d8f2a9ddccd1e60c7bb8ed774992480 \ + --hash=sha256:c8a0e34993b510fc19b9a2ce7f31cb8e94ecf6e924a40c0c9dd4f62d0aac47d9 \ + --hash=sha256:caf5d284ddea7462c32b8d4a6b8af030b6c9fd5332afb70e7414d7fdded4bfd0 \ + --hash=sha256:cea427d1350f3fd0d2818ce7350095c1a2ee33e30961d2f0fef48576ddbbe90f \ + --hash=sha256:d0cf7d55b1051387807405b3898efafa862997b4cba8aa5dbe657be794afeafd \ + --hash=sha256:d10c39947a2d351d6d466b4ae83dad4c37cd6c3cdd6d5d0fa797da56f710a6ae \ + --hash=sha256:d2b9cd92c8f8e7b313b80e93cedc12c0112088541dcedd9197b5dee3738c1201 \ + --hash=sha256:d4c57b68c8ef5e1ebf47238e99bf27657511ec3f071c465f6b1bccbef12d4136 \ + --hash=sha256:d51fc141ddbe3f919e91a096ec739f49d686df8af254b2053ba21a910ae518bf \ + --hash=sha256:e097507396c0be4e547ff15b13dc3866f45f3680f789c1a1301b07dadd3fbc78 \ + --hash=sha256:e30356d530528a42eeba51420ae8bf6c6c09559051887196599d96ee5f536468 \ + --hash=sha256:e8d5f8a8e3bc87334f025194c6193e408903d21ebaeb10952264943a985066ca \ + --hash=sha256:e8dfa9e94fc127c40979c3eacbae1e61fda4fe71d84869cc129e2721973231ef \ + --hash=sha256:f212d4f46b67ff604d11fff7cc62d36b3e8714edf68e44e9760e19be38c03eb0 \ + --hash=sha256:f7506387e191fe8cdb267f912469a3cccc538ab108471291636a96a54e599556 \ + --hash=sha256:fac6e277a41163d27dfab5f4ec1f7a83fac94e170665a4a50191b545721c6521 \ + --hash=sha256:fcd8f556cdc8cfe35e70efb92463082b7f43dd7e547eb071ffc36abc0ca4699b # via # -r ci/official/requirements_updater/requirements.in # h5py @@ -424,6 +453,7 @@ packaging==23.2 \ --hash=sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5 \ --hash=sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7 # via -r ci/official/requirements_updater/requirements.in + # tb-nightly portpicker==1.6.0 \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa @@ -505,9 +535,10 @@ six==1.16.0 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 # via # astunparse + # google-pasta # tb-nightly -tb-nightly==2.18.0a20240611 \ - --hash=sha256:c299eb7dc3de22c7164a1b0c0091b784f2214d65b9a8b967eeeba9818314016d +tb-nightly==2.19.0a20240926 \ + --hash=sha256:4f2f4dd02eda684fbb2edd9cb46b7bd0ee7ba6ca35d38e4de2e293df8567c1b4 # via -r ci/official/requirements_updater/requirements.in tblib==2.0.0 \ --hash=sha256:9100bfa016b047d5b980d66e7efed952fbd20bd85b56110aaf473cb97d18709a \ @@ -518,6 +549,24 @@ tensorboard-data-server==0.7.2 \ --hash=sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60 \ --hash=sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530 # via tb-nightly +tensorflow-io-gcs-filesystem==0.37.1 \ + --hash=sha256:0df00891669390078a003cedbdd3b8e645c718b111917535fa1d7725e95cdb95 \ + --hash=sha256:249c12b830165841411ba71e08215d0e94277a49c551e6dd5d72aab54fe5491b \ + --hash=sha256:257aab23470a0796978efc9c2bcf8b0bc80f22e6298612a4c0a50d3f4e88060c \ + --hash=sha256:286389a203a5aee1a4fa2e53718c661091aa5fea797ff4fa6715ab8436b02e6c \ + --hash=sha256:32c50ab4e29a23c1f91cd0f9ab8c381a0ab10f45ef5c5252e94965916041737c \ + --hash=sha256:426de1173cb81fbd62becec2012fc00322a295326d90eb6c737fab636f182aed \ + --hash=sha256:6e1f2796b57e799a8ca1b75bf47c2aaa437c968408cc1a402a9862929e104cda \ + --hash=sha256:8943036bbf84e7a2be3705cb56f9c9df7c48c9e614bb941f0936c58e3ca89d6f \ + --hash=sha256:8febbfcc67c61e542a5ac1a98c7c20a91a5e1afc2e14b1ef0cb7c28bc3b6aa70 \ + --hash=sha256:9679b36e3a80921876f31685ab6f7270f3411a4cc51bc2847e80d0e4b5291e27 \ + --hash=sha256:b02f9c5f94fd62773954a04f69b68c4d576d076fd0db4ca25d5479f0fbfcdbad \ + --hash=sha256:ee5da49019670ed364f3e5fb86b46420841a6c3cb52a300553c63841671b3e6d \ + --hash=sha256:ee7c8ee5fe2fd8cb6392669ef16e71841133041fee8a330eff519ad9b36e4556 \ + --hash=sha256:fbb33f1745f218464a59cecd9a18e32ca927b0f4d77abd8f8671b645cc1a182f \ + --hash=sha256:fe8dcc6d222258a080ac3dfcaaaa347325ce36a7a046277f6b3e19abc1efb3c5 \ + --hash=sha256:ffebb6666a7bfc28005f4fbbb111a455b5e7d6cd3b12752b7050863ecb27d5cc + # via -r ci/official/requirements_updater/requirements.in termcolor==2.3.0 \ --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a diff --git a/requirements_lock_3_12.txt b/requirements_lock_3_12.txt index 46778af8ee1b4b..e0e8c4eafa7761 100644 --- a/requirements_lock_3_12.txt +++ b/requirements_lock_3_12.txt @@ -166,10 +166,19 @@ dm-tree==0.1.8 \ --hash=sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb \ --hash=sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d # via keras-nightly +flatbuffers==24.3.25 \ + --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ + --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 + # via -r ci/official/requirements_updater/requirements.in gast==0.4.0 \ --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 # via -r ci/official/requirements_updater/requirements.in +google-pasta==0.2.0 \ + --hash=sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954 \ + --hash=sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed \ + --hash=sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e + # via -r ci/official/requirements_updater/requirements.in grpcio==1.64.1 \ --hash=sha256:03b43d0ccf99c557ec671c7dede64f023c7da9bb632ac65dbc57f166e4970040 \ --hash=sha256:0a12ddb1678ebc6a84ec6b0487feac020ee2b1659cbe69b80f06dbffdb249122 \ @@ -256,6 +265,18 @@ keras-nightly==3.0.4.dev2024021403 \ --hash=sha256:24ce69d29d582771685bf4235f59663723405b5a5b16f3eaff2657e52e74663a \ --hash=sha256:9f416e66b820ef833779d219d255b346b8b90a72fdbd0b2f1e90a43ad142a03d # via -r ci/official/requirements_updater/requirements.in +libclang==18.1.1 \ + --hash=sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a \ + --hash=sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8 \ + --hash=sha256:4dd2d3b82fab35e2bf9ca717d7b63ac990a3519c7e312f19fa8e86dcc712f7fb \ + --hash=sha256:54dda940a4a0491a9d1532bf071ea3ef26e6dbaf03b5000ed94dd7174e8f9592 \ + --hash=sha256:69f8eb8f65c279e765ffd28aaa7e9e364c776c17618af8bff22a8df58677ff4f \ + --hash=sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5 \ + --hash=sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8 \ + --hash=sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250 \ + --hash=sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b \ + --hash=sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe + # via -r ci/official/requirements_updater/requirements.in lit==17.0.6 \ --hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b # via -r ci/official/requirements_updater/requirements.in @@ -359,52 +380,60 @@ namex==0.0.8 \ --hash=sha256:32a50f6c565c0bb10aa76298c959507abdc0e850efe085dc38f3440fcb3aa90b \ --hash=sha256:7ddb6c2bb0e753a311b7590f84f6da659dd0c05e65cb89d519d54c0a250c0487 # via keras-nightly -numpy==2.0.2 \ - --hash=sha256:0123ffdaa88fa4ab64835dcbde75dcdf89c453c922f18dced6e27c90d1d0ec5a \ - --hash=sha256:11a76c372d1d37437857280aa142086476136a8c0f373b2e648ab2c8f18fb195 \ - --hash=sha256:13e689d772146140a252c3a28501da66dfecd77490b498b168b501835041f951 \ - --hash=sha256:1e795a8be3ddbac43274f18588329c72939870a16cae810c2b73461c40718ab1 \ - --hash=sha256:26df23238872200f63518dd2aa984cfca675d82469535dc7162dc2ee52d9dd5c \ - --hash=sha256:286cd40ce2b7d652a6f22efdfc6d1edf879440e53e76a75955bc0c826c7e64dc \ - --hash=sha256:2b2955fa6f11907cf7a70dab0d0755159bca87755e831e47932367fc8f2f2d0b \ - --hash=sha256:2da5960c3cf0df7eafefd806d4e612c5e19358de82cb3c343631188991566ccd \ - --hash=sha256:312950fdd060354350ed123c0e25a71327d3711584beaef30cdaa93320c392d4 \ - --hash=sha256:423e89b23490805d2a5a96fe40ec507407b8ee786d66f7328be214f9679df6dd \ - --hash=sha256:496f71341824ed9f3d2fd36cf3ac57ae2e0165c143b55c3a035ee219413f3318 \ - --hash=sha256:49ca4decb342d66018b01932139c0961a8f9ddc7589611158cb3c27cbcf76448 \ - --hash=sha256:51129a29dbe56f9ca83438b706e2e69a39892b5eda6cedcb6b0c9fdc9b0d3ece \ - --hash=sha256:5fec9451a7789926bcf7c2b8d187292c9f93ea30284802a0ab3f5be8ab36865d \ - --hash=sha256:671bec6496f83202ed2d3c8fdc486a8fc86942f2e69ff0e986140339a63bcbe5 \ - --hash=sha256:7f0a0c6f12e07fa94133c8a67404322845220c06a9e80e85999afe727f7438b8 \ - --hash=sha256:807ec44583fd708a21d4a11d94aedf2f4f3c3719035c76a2bbe1fe8e217bdc57 \ - --hash=sha256:883c987dee1880e2a864ab0dc9892292582510604156762362d9326444636e78 \ - --hash=sha256:8c5713284ce4e282544c68d1c3b2c7161d38c256d2eefc93c1d683cf47683e66 \ - --hash=sha256:8cafab480740e22f8d833acefed5cc87ce276f4ece12fdaa2e8903db2f82897a \ - --hash=sha256:8df823f570d9adf0978347d1f926b2a867d5608f434a7cff7f7908c6570dcf5e \ - --hash=sha256:9059e10581ce4093f735ed23f3b9d283b9d517ff46009ddd485f1747eb22653c \ - --hash=sha256:905d16e0c60200656500c95b6b8dca5d109e23cb24abc701d41c02d74c6b3afa \ - --hash=sha256:9189427407d88ff25ecf8f12469d4d39d35bee1db5d39fc5c168c6f088a6956d \ - --hash=sha256:96a55f64139912d61de9137f11bf39a55ec8faec288c75a54f93dfd39f7eb40c \ - --hash=sha256:97032a27bd9d8988b9a97a8c4d2c9f2c15a81f61e2f21404d7e8ef00cb5be729 \ - --hash=sha256:984d96121c9f9616cd33fbd0618b7f08e0cfc9600a7ee1d6fd9b239186d19d97 \ - --hash=sha256:9a92ae5c14811e390f3767053ff54eaee3bf84576d99a2456391401323f4ec2c \ - --hash=sha256:9ea91dfb7c3d1c56a0e55657c0afb38cf1eeae4544c208dc465c3c9f3a7c09f9 \ - --hash=sha256:a15f476a45e6e5a3a79d8a14e62161d27ad897381fecfa4a09ed5322f2085669 \ - --hash=sha256:a392a68bd329eafac5817e5aefeb39038c48b671afd242710b451e76090e81f4 \ - --hash=sha256:a3f4ab0caa7f053f6797fcd4e1e25caee367db3112ef2b6ef82d749530768c73 \ - --hash=sha256:a46288ec55ebbd58947d31d72be2c63cbf839f0a63b49cb755022310792a3385 \ - --hash=sha256:a61ec659f68ae254e4d237816e33171497e978140353c0c2038d46e63282d0c8 \ - --hash=sha256:a842d573724391493a97a62ebbb8e731f8a5dcc5d285dfc99141ca15a3302d0c \ - --hash=sha256:becfae3ddd30736fe1889a37f1f580e245ba79a5855bff5f2a29cb3ccc22dd7b \ - --hash=sha256:c05e238064fc0610c840d1cf6a13bf63d7e391717d247f1bf0318172e759e692 \ - --hash=sha256:c1c9307701fec8f3f7a1e6711f9089c06e6284b3afbbcd259f7791282d660a15 \ - --hash=sha256:c7b0be4ef08607dd04da4092faee0b86607f111d5ae68036f16cc787e250a131 \ - --hash=sha256:cfd41e13fdc257aa5778496b8caa5e856dc4896d4ccf01841daee1d96465467a \ - --hash=sha256:d731a1c6116ba289c1e9ee714b08a8ff882944d4ad631fd411106a30f083c326 \ - --hash=sha256:df55d490dea7934f330006d0f81e8551ba6010a5bf035a249ef61a94f21c500b \ - --hash=sha256:ec9852fb39354b5a45a80bdab5ac02dd02b15f44b3804e9f00c556bf24b4bded \ - --hash=sha256:f15975dfec0cf2239224d80e32c3170b1d168335eaedee69da84fbe9f1f9cd04 \ - --hash=sha256:f26b258c385842546006213344c50655ff1555a9338e2e5e02a0756dc3e803dd +numpy==2.1.1 \ + --hash=sha256:046356b19d7ad1890c751b99acad5e82dc4a02232013bd9a9a712fddf8eb60f5 \ + --hash=sha256:0b8cc2715a84b7c3b161f9ebbd942740aaed913584cae9cdc7f8ad5ad41943d0 \ + --hash=sha256:0d07841fd284718feffe7dd17a63a2e6c78679b2d386d3e82f44f0108c905550 \ + --hash=sha256:13cc11c00000848702322af4de0147ced365c81d66053a67c2e962a485b3717c \ + --hash=sha256:13ce49a34c44b6de5241f0b38b07e44c1b2dcacd9e36c30f9c2fcb1bb5135db7 \ + --hash=sha256:24c2ad697bd8593887b019817ddd9974a7f429c14a5469d7fad413f28340a6d2 \ + --hash=sha256:251105b7c42abe40e3a689881e1793370cc9724ad50d64b30b358bbb3a97553b \ + --hash=sha256:2ca4b53e1e0b279142113b8c5eb7d7a877e967c306edc34f3b58e9be12fda8df \ + --hash=sha256:3269c9eb8745e8d975980b3a7411a98976824e1fdef11f0aacf76147f662b15f \ + --hash=sha256:397bc5ce62d3fb73f304bec332171535c187e0643e176a6e9421a6e3eacef06d \ + --hash=sha256:3fc5eabfc720db95d68e6646e88f8b399bfedd235994016351b1d9e062c4b270 \ + --hash=sha256:50a95ca3560a6058d6ea91d4629a83a897ee27c00630aed9d933dff191f170cd \ + --hash=sha256:52ac2e48f5ad847cd43c4755520a2317f3380213493b9d8a4c5e37f3b87df504 \ + --hash=sha256:53e27293b3a2b661c03f79aa51c3987492bd4641ef933e366e0f9f6c9bf257ec \ + --hash=sha256:57eb525e7c2a8fdee02d731f647146ff54ea8c973364f3b850069ffb42799647 \ + --hash=sha256:5889dd24f03ca5a5b1e8a90a33b5a0846d8977565e4ae003a63d22ecddf6782f \ + --hash=sha256:59ca673ad11d4b84ceb385290ed0ebe60266e356641428c845b39cd9df6713ab \ + --hash=sha256:6435c48250c12f001920f0751fe50c0348f5f240852cfddc5e2f97e007544cbe \ + --hash=sha256:6e5a9cb2be39350ae6c8f79410744e80154df658d5bea06e06e0ac5bb75480d5 \ + --hash=sha256:7be6a07520b88214ea85d8ac8b7d6d8a1839b0b5cb87412ac9f49fa934eb15d5 \ + --hash=sha256:7c803b7934a7f59563db459292e6aa078bb38b7ab1446ca38dd138646a38203e \ + --hash=sha256:7dd86dfaf7c900c0bbdcb8b16e2f6ddf1eb1fe39c6c8cca6e94844ed3152a8fd \ + --hash=sha256:8661c94e3aad18e1ea17a11f60f843a4933ccaf1a25a7c6a9182af70610b2313 \ + --hash=sha256:8ae0fd135e0b157365ac7cc31fff27f07a5572bdfc38f9c2d43b2aff416cc8b0 \ + --hash=sha256:910b47a6d0635ec1bd53b88f86120a52bf56dcc27b51f18c7b4a2e2224c29f0f \ + --hash=sha256:913cc1d311060b1d409e609947fa1b9753701dac96e6581b58afc36b7ee35af6 \ + --hash=sha256:920b0911bb2e4414c50e55bd658baeb78281a47feeb064ab40c2b66ecba85553 \ + --hash=sha256:950802d17a33c07cba7fd7c3dcfa7d64705509206be1606f196d179e539111ed \ + --hash=sha256:981707f6b31b59c0c24bcda52e5605f9701cb46da4b86c2e8023656ad3e833cb \ + --hash=sha256:98ce7fb5b8063cfdd86596b9c762bf2b5e35a2cdd7e967494ab78a1fa7f8b86e \ + --hash=sha256:99f4a9ee60eed1385a86e82288971a51e71df052ed0b2900ed30bc840c0f2e39 \ + --hash=sha256:9a8e06c7a980869ea67bbf551283bbed2856915f0a792dc32dd0f9dd2fb56728 \ + --hash=sha256:ae8ce252404cdd4de56dcfce8b11eac3c594a9c16c231d081fb705cf23bd4d9e \ + --hash=sha256:afd9c680df4de71cd58582b51e88a61feed4abcc7530bcd3d48483f20fc76f2a \ + --hash=sha256:b49742cdb85f1f81e4dc1b39dcf328244f4d8d1ded95dea725b316bd2cf18c95 \ + --hash=sha256:b5613cfeb1adfe791e8e681128f5f49f22f3fcaa942255a6124d58ca59d9528f \ + --hash=sha256:bab7c09454460a487e631ffc0c42057e3d8f2a9ddccd1e60c7bb8ed774992480 \ + --hash=sha256:c8a0e34993b510fc19b9a2ce7f31cb8e94ecf6e924a40c0c9dd4f62d0aac47d9 \ + --hash=sha256:caf5d284ddea7462c32b8d4a6b8af030b6c9fd5332afb70e7414d7fdded4bfd0 \ + --hash=sha256:cea427d1350f3fd0d2818ce7350095c1a2ee33e30961d2f0fef48576ddbbe90f \ + --hash=sha256:d0cf7d55b1051387807405b3898efafa862997b4cba8aa5dbe657be794afeafd \ + --hash=sha256:d10c39947a2d351d6d466b4ae83dad4c37cd6c3cdd6d5d0fa797da56f710a6ae \ + --hash=sha256:d2b9cd92c8f8e7b313b80e93cedc12c0112088541dcedd9197b5dee3738c1201 \ + --hash=sha256:d4c57b68c8ef5e1ebf47238e99bf27657511ec3f071c465f6b1bccbef12d4136 \ + --hash=sha256:d51fc141ddbe3f919e91a096ec739f49d686df8af254b2053ba21a910ae518bf \ + --hash=sha256:e097507396c0be4e547ff15b13dc3866f45f3680f789c1a1301b07dadd3fbc78 \ + --hash=sha256:e30356d530528a42eeba51420ae8bf6c6c09559051887196599d96ee5f536468 \ + --hash=sha256:e8d5f8a8e3bc87334f025194c6193e408903d21ebaeb10952264943a985066ca \ + --hash=sha256:e8dfa9e94fc127c40979c3eacbae1e61fda4fe71d84869cc129e2721973231ef \ + --hash=sha256:f212d4f46b67ff604d11fff7cc62d36b3e8714edf68e44e9760e19be38c03eb0 \ + --hash=sha256:f7506387e191fe8cdb267f912469a3cccc538ab108471291636a96a54e599556 \ + --hash=sha256:fac6e277a41163d27dfab5f4ec1f7a83fac94e170665a4a50191b545721c6521 \ + --hash=sha256:fcd8f556cdc8cfe35e70efb92463082b7f43dd7e547eb071ffc36abc0ca4699b # via # -r ci/official/requirements_updater/requirements.in # h5py @@ -424,6 +453,7 @@ packaging==23.2 \ --hash=sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5 \ --hash=sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7 # via -r ci/official/requirements_updater/requirements.in + # tb-nightly portpicker==1.6.0 \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa @@ -505,9 +535,10 @@ six==1.16.0 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 # via # astunparse + # google-pasta # tb-nightly -tb-nightly==2.18.0a20240611 \ - --hash=sha256:c299eb7dc3de22c7164a1b0c0091b784f2214d65b9a8b967eeeba9818314016d +tb-nightly==2.19.0a20240926 \ + --hash=sha256:4f2f4dd02eda684fbb2edd9cb46b7bd0ee7ba6ca35d38e4de2e293df8567c1b4 # via -r ci/official/requirements_updater/requirements.in tblib==2.0.0 \ --hash=sha256:9100bfa016b047d5b980d66e7efed952fbd20bd85b56110aaf473cb97d18709a \ @@ -518,6 +549,24 @@ tensorboard-data-server==0.7.2 \ --hash=sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60 \ --hash=sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530 # via tb-nightly +tensorflow-io-gcs-filesystem==0.37.1 \ + --hash=sha256:0df00891669390078a003cedbdd3b8e645c718b111917535fa1d7725e95cdb95 \ + --hash=sha256:249c12b830165841411ba71e08215d0e94277a49c551e6dd5d72aab54fe5491b \ + --hash=sha256:257aab23470a0796978efc9c2bcf8b0bc80f22e6298612a4c0a50d3f4e88060c \ + --hash=sha256:286389a203a5aee1a4fa2e53718c661091aa5fea797ff4fa6715ab8436b02e6c \ + --hash=sha256:32c50ab4e29a23c1f91cd0f9ab8c381a0ab10f45ef5c5252e94965916041737c \ + --hash=sha256:426de1173cb81fbd62becec2012fc00322a295326d90eb6c737fab636f182aed \ + --hash=sha256:6e1f2796b57e799a8ca1b75bf47c2aaa437c968408cc1a402a9862929e104cda \ + --hash=sha256:8943036bbf84e7a2be3705cb56f9c9df7c48c9e614bb941f0936c58e3ca89d6f \ + --hash=sha256:8febbfcc67c61e542a5ac1a98c7c20a91a5e1afc2e14b1ef0cb7c28bc3b6aa70 \ + --hash=sha256:9679b36e3a80921876f31685ab6f7270f3411a4cc51bc2847e80d0e4b5291e27 \ + --hash=sha256:b02f9c5f94fd62773954a04f69b68c4d576d076fd0db4ca25d5479f0fbfcdbad \ + --hash=sha256:ee5da49019670ed364f3e5fb86b46420841a6c3cb52a300553c63841671b3e6d \ + --hash=sha256:ee7c8ee5fe2fd8cb6392669ef16e71841133041fee8a330eff519ad9b36e4556 \ + --hash=sha256:fbb33f1745f218464a59cecd9a18e32ca927b0f4d77abd8f8671b645cc1a182f \ + --hash=sha256:fe8dcc6d222258a080ac3dfcaaaa347325ce36a7a046277f6b3e19abc1efb3c5 \ + --hash=sha256:ffebb6666a7bfc28005f4fbbb111a455b5e7d6cd3b12752b7050863ecb27d5cc + # via -r ci/official/requirements_updater/requirements.in termcolor==2.3.0 \ --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a diff --git a/requirements_lock_3_9.txt b/requirements_lock_3_9.txt index 87287e12978f31..580cdebe637239 100644 --- a/requirements_lock_3_9.txt +++ b/requirements_lock_3_9.txt @@ -166,10 +166,19 @@ dm-tree==0.1.8 \ --hash=sha256:f7ac31b9aecccb2c6e1ab29706f6ded3eba0c2c69c770322c9c685929c3d6afb \ --hash=sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d # via keras-nightly +flatbuffers==24.3.25 \ + --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ + --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 + # via -r ci/official/requirements_updater/requirements.in gast==0.4.0 \ --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 # via -r ci/official/requirements_updater/requirements.in +google-pasta==0.2.0 \ + --hash=sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954 \ + --hash=sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed \ + --hash=sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e + # via -r ci/official/requirements_updater/requirements.in grpcio==1.64.1 \ --hash=sha256:03b43d0ccf99c557ec671c7dede64f023c7da9bb632ac65dbc57f166e4970040 \ --hash=sha256:0a12ddb1678ebc6a84ec6b0487feac020ee2b1659cbe69b80f06dbffdb249122 \ @@ -260,6 +269,18 @@ keras-nightly==3.0.4.dev2024021403 \ --hash=sha256:24ce69d29d582771685bf4235f59663723405b5a5b16f3eaff2657e52e74663a \ --hash=sha256:9f416e66b820ef833779d219d255b346b8b90a72fdbd0b2f1e90a43ad142a03d # via -r ci/official/requirements_updater/requirements.in +libclang==18.1.1 \ + --hash=sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a \ + --hash=sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8 \ + --hash=sha256:4dd2d3b82fab35e2bf9ca717d7b63ac990a3519c7e312f19fa8e86dcc712f7fb \ + --hash=sha256:54dda940a4a0491a9d1532bf071ea3ef26e6dbaf03b5000ed94dd7174e8f9592 \ + --hash=sha256:69f8eb8f65c279e765ffd28aaa7e9e364c776c17618af8bff22a8df58677ff4f \ + --hash=sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5 \ + --hash=sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8 \ + --hash=sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250 \ + --hash=sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b \ + --hash=sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe + # via -r ci/official/requirements_updater/requirements.in lit==17.0.6 \ --hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b # via -r ci/official/requirements_updater/requirements.in @@ -428,6 +449,7 @@ packaging==23.2 \ --hash=sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5 \ --hash=sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7 # via -r ci/official/requirements_updater/requirements.in + # tb-nightly portpicker==1.6.0 \ --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa @@ -509,9 +531,10 @@ six==1.16.0 \ --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 # via # astunparse + # google-pasta # tb-nightly -tb-nightly==2.18.0a20240611 \ - --hash=sha256:c299eb7dc3de22c7164a1b0c0091b784f2214d65b9a8b967eeeba9818314016d +tb-nightly==2.19.0a20240926 \ + --hash=sha256:4f2f4dd02eda684fbb2edd9cb46b7bd0ee7ba6ca35d38e4de2e293df8567c1b4 # via -r ci/official/requirements_updater/requirements.in tblib==2.0.0 \ --hash=sha256:9100bfa016b047d5b980d66e7efed952fbd20bd85b56110aaf473cb97d18709a \ @@ -522,6 +545,24 @@ tensorboard-data-server==0.7.2 \ --hash=sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60 \ --hash=sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530 # via tb-nightly +tensorflow-io-gcs-filesystem==0.37.1 \ + --hash=sha256:0df00891669390078a003cedbdd3b8e645c718b111917535fa1d7725e95cdb95 \ + --hash=sha256:249c12b830165841411ba71e08215d0e94277a49c551e6dd5d72aab54fe5491b \ + --hash=sha256:257aab23470a0796978efc9c2bcf8b0bc80f22e6298612a4c0a50d3f4e88060c \ + --hash=sha256:286389a203a5aee1a4fa2e53718c661091aa5fea797ff4fa6715ab8436b02e6c \ + --hash=sha256:32c50ab4e29a23c1f91cd0f9ab8c381a0ab10f45ef5c5252e94965916041737c \ + --hash=sha256:426de1173cb81fbd62becec2012fc00322a295326d90eb6c737fab636f182aed \ + --hash=sha256:6e1f2796b57e799a8ca1b75bf47c2aaa437c968408cc1a402a9862929e104cda \ + --hash=sha256:8943036bbf84e7a2be3705cb56f9c9df7c48c9e614bb941f0936c58e3ca89d6f \ + --hash=sha256:8febbfcc67c61e542a5ac1a98c7c20a91a5e1afc2e14b1ef0cb7c28bc3b6aa70 \ + --hash=sha256:9679b36e3a80921876f31685ab6f7270f3411a4cc51bc2847e80d0e4b5291e27 \ + --hash=sha256:b02f9c5f94fd62773954a04f69b68c4d576d076fd0db4ca25d5479f0fbfcdbad \ + --hash=sha256:ee5da49019670ed364f3e5fb86b46420841a6c3cb52a300553c63841671b3e6d \ + --hash=sha256:ee7c8ee5fe2fd8cb6392669ef16e71841133041fee8a330eff519ad9b36e4556 \ + --hash=sha256:fbb33f1745f218464a59cecd9a18e32ca927b0f4d77abd8f8671b645cc1a182f \ + --hash=sha256:fe8dcc6d222258a080ac3dfcaaaa347325ce36a7a046277f6b3e19abc1efb3c5 \ + --hash=sha256:ffebb6666a7bfc28005f4fbbb111a455b5e7d6cd3b12752b7050863ecb27d5cc + # via -r ci/official/requirements_updater/requirements.in termcolor==2.3.0 \ --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a diff --git a/tensorflow/BUILD b/tensorflow/BUILD index d62edbbfcf48c5..b145a683737dff 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -13,7 +13,7 @@ load( "if_google", "if_oss", "if_xla_available", - "tf_cc_shared_object", + "pywrap_aware_tf_cc_shared_object", "tf_custom_op_library_additional_deps_impl", "tf_monitoring_python_deps", "tf_native_cc_binary", @@ -552,18 +552,30 @@ config_setting( config_setting( name = "linux_aarch64", + constraint_values = if_google( + ["//third_party/bazel_platforms/os:linux"], + [], + ), values = {"cpu": "aarch64"}, visibility = ["//visibility:public"], ) config_setting( name = "linux_armhf", + constraint_values = if_google( + ["//third_party/bazel_platforms/os:linux"], + [], + ), values = {"cpu": "armhf"}, visibility = ["//visibility:public"], ) config_setting( name = "linux_x86_64", + constraint_values = if_google( + ["//third_party/bazel_platforms/os:linux"], + [], + ), values = {"cpu": "k8"}, visibility = ["//visibility:public"], ) @@ -577,6 +589,10 @@ config_setting( # This condition takes precedence over :linux_x86_64 config_setting( name = "linux_x86_64_no_sse", + constraint_values = if_google( + ["//third_party/bazel_platforms/os:linux"], + [], + ), values = { "cpu": "k8", "copt": "-mno-sse4.2", @@ -588,6 +604,10 @@ config_setting( # TODO(b/290533709): Remove this with PJRT build rule cleanup. config_setting( name = "linux_x86_64_with_weightwatcher", + constraint_values = if_google( + ["//third_party/bazel_platforms/os:linux"], + [], + ), define_values = {"tensorflow_weightwatcher": "true"}, values = {"cpu": "k8"}, visibility = ["//visibility:public"], @@ -595,24 +615,40 @@ config_setting( config_setting( name = "linux_ppc64le", + constraint_values = if_google( + ["//third_party/bazel_platforms/os:linux"], + [], + ), values = {"cpu": "ppc"}, visibility = ["//visibility:public"], ) config_setting( name = "linux_s390x", + constraint_values = if_google( + ["//third_party/bazel_platforms/os:linux"], + [], + ), values = {"cpu": "s390x"}, visibility = ["//visibility:public"], ) config_setting( name = "linux_mips64", + constraint_values = if_google( + ["//third_party/bazel_platforms/os:linux"], + [], + ), values = {"cpu": "mips64"}, visibility = ["//visibility:public"], ) config_setting( name = "linux_riscv64", + constraint_values = if_google( + ["//third_party/bazel_platforms/os:linux"], + [], + ), values = {"cpu": "riscv64"}, visibility = ["//visibility:public"], ) @@ -683,15 +719,6 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting( - name = "no_nccl_support", - define_values = dict( - if_google({"GOOGLE_CUDA_COMPILER": "clang"}), - no_nccl_support = "true", - ), - visibility = ["//visibility:public"], -) - # Experimental features config_setting( name = "stackdriver_support", @@ -760,7 +787,7 @@ alias( name = "is_cuda_enabled", actual = if_oss( "@local_config_cuda//:is_cuda_enabled", - "@local_config_cuda//cuda:using_clang", + "@local_config_cuda//cuda:using_config_cuda", ), ) @@ -774,6 +801,14 @@ alias( ), ) +selects.config_setting_group( + name = "is_cuda_clang", + match_all = [ + ":is_cuda_enabled", + ":is_cuda_compiler_clang", + ], +) + # Config setting that is satisfied when CUDA device code should be compiled # with nvcc. It does not imply that CUDA support has been enabled. alias( @@ -784,6 +819,14 @@ alias( ), ) +selects.config_setting_group( + name = "is_cuda_nvcc", + match_all = [ + ":is_cuda_enabled", + ":is_cuda_compiler_nvcc", + ], +) + # Config setting that is satisfied when building with --config=cuda in OSS. selects.config_setting_group( name = "is_cuda_enabled_and_oss", @@ -912,43 +955,6 @@ selects.config_setting_group( visibility = ["//visibility:public"], ) -# This flag disables all google production dependencies, intended for -# applications run with non-prod environment. -# TODO(timshen): Currently this option only disables some dependencies. -# See b/122528503. -# copybara:uncomment_begin(google-only) -# bool_flag( -# name = "tf_no_prod_deps", -# build_setting_default = False, -# ) -# -# config_setting( -# name = "no_prod_deps_define", -# define_values = {"tf_no_prod_deps": "1"}, -# ) -# -# config_setting( -# name = "no_prod_deps_flag", -# flag_values = {":tf_no_prod_deps": "True"}, -# ) -# -# selects.config_setting_group( -# name = "no_prod_deps", -# match_any = [ -# ":no_prod_deps_define", -# ":no_prod_deps_flag", -# ], -# ) -# -# config_setting( -# name = "no_prod_deps_cuda", -# define_values = { -# "tf_no_prod_deps": "1", -# "GOOGLE_CUDA_COMPILER": "clang", -# }, -# ) -# copybara:uncomment_end - config_setting( name = "lite_protos_legacy", define_values = {"TENSORFLOW_PROTOS": "lite"}, @@ -1115,20 +1121,13 @@ bzl_library( "@local_config_cuda//cuda:build_defs_bzl", "@local_config_rocm//rocm:build_defs_bzl", "@local_config_tensorrt//:build_defs_bzl", - "@local_tsl//tsl/platform/default:cuda_build_defs_bzl", + "@local_tsl//third_party/py/rules_pywrap:pywrap_bzl", "@local_xla//xla/tsl:tsl_bzl", "@local_xla//xla/tsl/mkl:build_defs_bzl", "@rules_java//java:rules", ], ) -bzl_library( - name = "tensorflow_default_bzl", - srcs = ["tensorflow.default.bzl"], - visibility = ["//visibility:public"], - deps = [":tensorflow_bzl"], -) - # TODO(jakeharmon8): Remove these in favor of tsl:grpc # copybara:comment_begin(oss-only) cc_library( @@ -1262,7 +1261,7 @@ cc_import( # an "-exported_symbols_list" command. -z defs disallows undefined # symbols in object files. -tf_cc_shared_object( +pywrap_aware_tf_cc_shared_object( name = "tensorflow", linkopts = select({ "//tensorflow:macos": [ diff --git a/tensorflow/api_template.__init__.py b/tensorflow/api_template.__init__.py index 5ea9ef248d55e2..613b7dac69bcfe 100644 --- a/tensorflow/api_template.__init__.py +++ b/tensorflow/api_template.__init__.py @@ -27,13 +27,15 @@ """ # pylint: disable=g-bad-import-order,protected-access,g-import-not-at-top -import distutils as _distutils +import sysconfig as _sysconfig import importlib import inspect as _inspect import os as _os import site as _site import sys as _sys +_os.environ.setdefault("ENABLE_RUNTIME_UPTIME_TELEMETRY", "1") + # Do not remove this line; See https://github.com/tensorflow/tensorflow/issues/42596 from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow # pylint: disable=unused-import from tensorflow.python.tools import module_util as _module_util @@ -100,8 +102,9 @@ if "getsitepackages" in dir(_site): _site_packages_dirs += _site.getsitepackages() -if "sysconfig" in dir(_distutils): - _site_packages_dirs += [_distutils.sysconfig.get_python_lib()] +for _scheme in _sysconfig.get_scheme_names(): + for _name in ["purelib", "platlib"]: + _site_packages_dirs += [_sysconfig.get_path(_name, _scheme)] _site_packages_dirs = list(set(_site_packages_dirs)) diff --git a/tensorflow/api_template_v1.__init__.py b/tensorflow/api_template_v1.__init__.py index 6a4ab4e655fd77..124795b619208c 100644 --- a/tensorflow/api_template_v1.__init__.py +++ b/tensorflow/api_template_v1.__init__.py @@ -14,12 +14,12 @@ # ============================================================================== """Bring in all of the public TensorFlow interface into this module.""" -import distutils as _distutils import importlib import inspect as _inspect import os as _os import site as _site import sys as _sys +import sysconfig # pylint: disable=g-bad-import-order,protected-access,g-import-not-at-top from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import @@ -147,8 +147,9 @@ if "getsitepackages" in dir(_site): _site_packages_dirs += _site.getsitepackages() -if "sysconfig" in dir(_distutils): - _site_packages_dirs += [_distutils.sysconfig.get_python_lib()] +for _scheme in sysconfig.get_scheme_names(): + for _name in ["purelib", "platlib"]: + _site_packages_dirs += [sysconfig.get_path(_name, _scheme)] _site_packages_dirs = list(set(_site_packages_dirs)) diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index ae24298fad8f4e..a130ee7ca54ff3 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -3,6 +3,7 @@ load("@bazel_skylib//lib:selects.bzl", "selects") load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") +load("@local_tsl//tsl/platform:build_config_root.bzl", "if_pywrap") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load( "//tensorflow:tensorflow.bzl", @@ -150,10 +151,7 @@ tf_cuda_library( "tf_tensor.h", "tf_tstring.h", ], - visibility = [ - "//tensorflow:internal", - "//tensorflow/c:__subpackages__", - ], + visibility = ["//visibility:public"], deps = selects.with_or({ "//tensorflow:android": [ "//tensorflow/core:portable_tensorflow_lib_lite", @@ -347,8 +345,6 @@ tf_cuda_library( ], visibility = [ "//tensorflow/c:__subpackages__", - "//tensorflow/cc/experimental/libtf:__pkg__", - "//tensorflow/cc/experimental/libtf:__subpackages__", # copybara:uncomment_begin(google-only) # "//tensorflow/cc/experimental/tf2:__pkg__", # "//tensorflow/cc/experimental/tf2:__subpackages__", @@ -558,7 +554,7 @@ tf_cuda_library( "tf_tensor_helper.h", "tf_tensor_internal.h", ], - visibility = ["//tensorflow:internal"], + visibility = ["//visibility:public"], deps = [ ":c_api_macros", ":tensor_interface", @@ -934,6 +930,7 @@ tf_cuda_cc_test( ":test_op1.so", "//tensorflow/cc/saved_model:saved_model_half_plus_two", ], + extra_copts = if_pywrap(["-DTENSORFLOW_NO_SHARED_OBJECTS"]), linkopts = select({ "//tensorflow:macos": ["-headerpad_max_install_names"], "//conditions:default": [], @@ -1133,7 +1130,10 @@ tf_cuda_library( name = "python_api", srcs = ["python_api.cc"], hdrs = ["python_api.h"], - visibility = ["//tensorflow/python:__pkg__"], + visibility = [ + "//tensorflow:internal", + "//tensorflow/python:__pkg__", + ], deps = [ ":c_api", ":c_api_internal", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index aa4b5d6987871b..08c5de71906e31 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -285,8 +285,8 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, } // Helpers for loading a TensorFlow plugin (a .so file). -Status LoadDynamicLibrary(const char* library_filename, void** result, - const void** buf, size_t* len); +absl::Status LoadDynamicLibrary(const char* library_filename, void** result, + const void** buf, size_t* len); // TODO(josh11b,mrry): Change Session to be able to use a Graph* // directly, instead of requiring us to serialize to a GraphDef and diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index bedba2c51c6d39..b00cfb389ad137 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -332,7 +332,7 @@ TF_Buffer* TFE_GetServerDef(const char* text_proto, TF_Status* status) { "Invalid text proto for ServerDef: ", text_proto); return nullptr; } - status->status = tensorflow::Status(); + status->status = absl::Status(); TF_Buffer* ret = TF_NewBuffer(); TF_CHECK_OK(MessageToBuffer(server_def, ret)); return ret; @@ -595,10 +595,11 @@ void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array, } namespace tensorflow { -Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); +absl::Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); // Helpers for loadding a TensorFlow PluggableDevice plugin (a .so file). -Status LoadPluggableDeviceLibrary(const char* library_filename, void** result); +absl::Status LoadPluggableDeviceLibrary(const char* library_filename, + void** result); } // namespace tensorflow void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes, diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index 25805954eff67c..df890eabb7f022 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -39,7 +39,7 @@ using tensorflow::errors::InvalidArgument; namespace tensorflow { namespace { -Status ValidateNonRefOutput(const Node* node, int idx) { +absl::Status ValidateNonRefOutput(const Node* node, int idx) { const DataType& dt = node->output_type(idx); return IsRefType(dt) ? InvalidArgument("Output ", idx, " of node '", node->name(), @@ -51,7 +51,7 @@ Status ValidateNonRefOutput(const Node* node, int idx) { // does various checks while doing so. `input_nodes` will contain the same // information as input_tensors just in a different structure to make // following processing easier. TODO(iga): Simplify this nested structure. -Status ProcessInputs( +absl::Status ProcessInputs( const TF_Graph* fn_body, const char* fn_name, int ninputs, const TF_Output* inputs, std::vector* input_tensors, std::unordered_map>* input_nodes) @@ -88,9 +88,9 @@ Status ProcessInputs( // Converts `noutputs` and `outputs` into `outputs_tensors` and does various // checks while doing so. -Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name, - int noutputs, const TF_Output* outputs, - std::vector* output_tensors) +absl::Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name, + int noutputs, const TF_Output* outputs, + std::vector* output_tensors) TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { output_tensors->reserve(noutputs); for (int i = 0; i < noutputs; ++i) { @@ -110,7 +110,7 @@ Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name, // Populates `body_nodes` with the nodes that will become function's body. // Performs various checks. -Status ComputeBodyNodes( +absl::Status ComputeBodyNodes( const TF_Graph* fn_body, const char* fn_name, int num_opers, const TF_Operation* const* opers, const std::unordered_map>& input_nodes, diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 14045bbc2daef4..48cb17b190f334 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -59,8 +59,8 @@ limitations under the License. #include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { -TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status); -Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); +TF_Tensor* TF_TensorFromTensor(const Tensor& src, absl::Status* status); +absl::Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); namespace { @@ -236,7 +236,7 @@ TEST(CAPI, LibraryLoadFunctions) { void TestEncodeDecode(int line, const std::vector& data) { const int64_t n = data.size(); - Status status; + absl::Status status; for (const std::vector& dims : std::vector>{{n}, {1, n}, {n, 1}, {n / 2, 2}}) { // Create C++ Tensor @@ -1450,7 +1450,7 @@ TEST(CAPI, SavedModel) { TF_Operation* input_op = TF_GraphOperationByName(graph, input_op_name.c_str()); ASSERT_TRUE(input_op != nullptr); - Status status; + absl::Status status; csession.SetInputs({{input_op, TF_TensorFromTensor(input, &status)}}); ASSERT_TRUE(status.ok()) << status.message(); @@ -2634,7 +2634,7 @@ TEST(CAPI, TestTensorIsNotAligned) { // Take an unaligned slice. Tensor y = x.Slice(1, 13); - Status status; + absl::Status status; TF_Tensor* a = TF_TensorFromTensor(y, &status); if (TF_TensorDefaultAlignment() > 0) { EXPECT_FALSE(TF_TensorIsAligned(a)); diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc index 6c2d6b22d61c3b..97a5bbd4b6077a 100644 --- a/tensorflow/c/checkpoint_reader.cc +++ b/tensorflow/c/checkpoint_reader.cc @@ -90,7 +90,7 @@ const string CheckpointReader::DebugString() const { void CheckpointReader::GetTensor( const string& name, std::unique_ptr* out_tensor, TF_Status* out_status) const { - Status status; + absl::Status status; if (reader_ != nullptr) { status = reader_->GetTensor(name, out_tensor); } else { diff --git a/tensorflow/c/eager/abstract_context.h b/tensorflow/c/eager/abstract_context.h index 2132daf2cfa388..4bf6ff9b781902 100644 --- a/tensorflow/c/eager/abstract_context.h +++ b/tensorflow/c/eager/abstract_context.h @@ -55,10 +55,10 @@ class AbstractContext { // Registers a function with this context, after this the function is // available to be called/referenced by its name in this context. - virtual Status RegisterFunction(AbstractFunction*) = 0; + virtual absl::Status RegisterFunction(AbstractFunction*) = 0; // Remove a function. 'func' argument is the name of a previously added // FunctionDef. The name is in fdef.signature.name. - virtual Status RemoveFunction(const string& func) = 0; + virtual absl::Status RemoveFunction(const string& func) = 0; private: const AbstractContextKind kind_; diff --git a/tensorflow/c/eager/abstract_function.h b/tensorflow/c/eager/abstract_function.h index 6989679f7c65c0..7bc8f8bd58eb1c 100644 --- a/tensorflow/c/eager/abstract_function.h +++ b/tensorflow/c/eager/abstract_function.h @@ -38,7 +38,7 @@ class AbstractFunction : public core::RefCounted { AbstractFunctionKind getKind() const { return kind_; } // Returns the AbstractFunction as a FunctionDef. - virtual Status GetFunctionDef(const FunctionDef**) = 0; + virtual absl::Status GetFunctionDef(const FunctionDef**) = 0; // Returns a shared reference to the wrapped function. virtual absl::StatusOr> diff --git a/tensorflow/c/eager/abstract_op_attrs.h b/tensorflow/c/eager/abstract_op_attrs.h index 134f3f49b4eea3..e799552a96ece5 100644 --- a/tensorflow/c/eager/abstract_op_attrs.h +++ b/tensorflow/c/eager/abstract_op_attrs.h @@ -41,7 +41,7 @@ class AbstractOpAttrs { virtual bool GetFloat(absl::string_view attr_name, float* result) const = 0; virtual bool GetBool(absl::string_view attr_name, bool* result) const = 0; virtual bool GetType(absl::string_view attr_name, DataType* result) const = 0; - virtual Status GetTypeList( + virtual absl::Status GetTypeList( absl::string_view attr_name, absl::InlinedVector* type_list) const = 0; diff --git a/tensorflow/c/eager/abstract_operation.h b/tensorflow/c/eager/abstract_operation.h index 68f74c3c31833d..95142210bfa218 100644 --- a/tensorflow/c/eager/abstract_operation.h +++ b/tensorflow/c/eager/abstract_operation.h @@ -53,7 +53,7 @@ class AbstractOperation { // clients MUST call Release() in order to destroy an instance of this class. virtual void Release() = 0; - virtual Status Reset(const char* op, const char* raw_device_name) = 0; + virtual absl::Status Reset(const char* op, const char* raw_device_name) = 0; virtual const string& Name() const = 0; @@ -78,47 +78,54 @@ class AbstractOperation { // // The value will override the previous value - that is, no "merging" of // existing and given constraints will be performed. - virtual Status SetDeviceName(const char* name) = 0; + virtual absl::Status SetDeviceName(const char* name) = 0; - virtual Status AddInput(AbstractTensorHandle* input) = 0; - virtual Status AddInputList( + virtual absl::Status AddInput(AbstractTensorHandle* input) = 0; + virtual absl::Status AddInputList( absl::Span inputs) = 0; - virtual Status Execute(absl::Span retvals, - int* num_retvals) = 0; - - virtual Status SetAttrString(const char* attr_name, const char* data, - size_t length) = 0; - virtual Status SetAttrInt(const char* attr_name, int64_t value) = 0; - virtual Status SetAttrFloat(const char* attr_name, float value) = 0; - virtual Status SetAttrBool(const char* attr_name, bool value) = 0; - virtual Status SetAttrType(const char* attr_name, DataType value) = 0; - virtual Status SetAttrShape(const char* attr_name, const int64_t* dims, - const int num_dims) = 0; - virtual Status SetAttrShape(const char* attr_name, - const PartialTensorShape shape); - virtual Status SetAttrFunction(const char* attr_name, - const AbstractOperation* value) = 0; - virtual Status SetAttrFunctionName(const char* attr_name, const char* value, + virtual absl::Status Execute(absl::Span retvals, + int* num_retvals) = 0; + + virtual absl::Status SetAttrString(const char* attr_name, const char* data, size_t length) = 0; - virtual Status SetAttrTensor(const char* attr_name, - AbstractTensorInterface* tensor) = 0; - virtual Status SetAttrStringList(const char* attr_name, - const void* const* values, - const size_t* lengths, int num_values) = 0; - virtual Status SetAttrStringList(const char* attr_name, - absl::Span values); - virtual Status SetAttrFloatList(const char* attr_name, const float* values, - int num_values) = 0; - virtual Status SetAttrIntList(const char* attr_name, const int64_t* values, - int num_values) = 0; - virtual Status SetAttrTypeList(const char* attr_name, const DataType* values, - int num_values) = 0; - virtual Status SetAttrBoolList(const char* attr_name, - const unsigned char* values, - int num_values) = 0; - virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims, - const int* num_dims, int num_values) = 0; - virtual Status SetAttrFunctionList( + virtual absl::Status SetAttrInt(const char* attr_name, int64_t value) = 0; + virtual absl::Status SetAttrFloat(const char* attr_name, float value) = 0; + virtual absl::Status SetAttrBool(const char* attr_name, bool value) = 0; + virtual absl::Status SetAttrType(const char* attr_name, DataType value) = 0; + virtual absl::Status SetAttrShape(const char* attr_name, const int64_t* dims, + const int num_dims) = 0; + virtual absl::Status SetAttrShape(const char* attr_name, + const PartialTensorShape shape); + virtual absl::Status SetAttrFunction(const char* attr_name, + const AbstractOperation* value) = 0; + virtual absl::Status SetAttrFunctionName(const char* attr_name, + const char* value, + size_t length) = 0; + virtual absl::Status SetAttrTensor(const char* attr_name, + AbstractTensorInterface* tensor) = 0; + virtual absl::Status SetAttrStringList(const char* attr_name, + const void* const* values, + const size_t* lengths, + int num_values) = 0; + virtual absl::Status SetAttrStringList(const char* attr_name, + absl::Span values); + virtual absl::Status SetAttrFloatList(const char* attr_name, + const float* values, + int num_values) = 0; + virtual absl::Status SetAttrIntList(const char* attr_name, + const int64_t* values, + int num_values) = 0; + virtual absl::Status SetAttrTypeList(const char* attr_name, + const DataType* values, + int num_values) = 0; + virtual absl::Status SetAttrBoolList(const char* attr_name, + const unsigned char* values, + int num_values) = 0; + virtual absl::Status SetAttrShapeList(const char* attr_name, + const int64_t** dims, + const int* num_dims, + int num_values) = 0; + virtual absl::Status SetAttrFunctionList( const char* attr_name, absl::Span values) = 0; private: @@ -127,12 +134,12 @@ class AbstractOperation { // TODO(b/193656009): Defining these in a cc file causes linker errors with // fastbuild. -inline Status AbstractOperation::SetAttrShape(const char* attr_name, - const PartialTensorShape shape) { +inline absl::Status AbstractOperation::SetAttrShape( + const char* attr_name, const PartialTensorShape shape) { return SetAttrShape(attr_name, shape.dim_sizes().data(), shape.dims()); } -inline Status AbstractOperation::SetAttrStringList( +inline absl::Status AbstractOperation::SetAttrStringList( const char* attr_name, absl::Span values) { std::vector raw_strs; std::vector lengths; diff --git a/tensorflow/c/eager/abstract_tensor_handle.cc b/tensorflow/c/eager/abstract_tensor_handle.cc index e04a9810638f61..2bbe5042d3b76c 100644 --- a/tensorflow/c/eager/abstract_tensor_handle.cc +++ b/tensorflow/c/eager/abstract_tensor_handle.cc @@ -19,7 +19,7 @@ namespace tensorflow { std::string AbstractTensorHandle::DebugString() const { PartialTensorShape shape; - Status s = Shape(&shape); + absl::Status s = Shape(&shape); std::string shape_string; if (!s.ok()) { shape_string = ""; @@ -31,7 +31,7 @@ std::string AbstractTensorHandle::DebugString() const { ", type=", FullType().DebugString(), ")"); } -Status AbstractTensorHandle::TensorHandleStatus() const { +absl::Status AbstractTensorHandle::TensorHandleStatus() const { // Tensor handles in current runtime don't carry error info and this method // should always return OK status. return absl::OkStatus(); diff --git a/tensorflow/c/eager/abstract_tensor_handle.h b/tensorflow/c/eager/abstract_tensor_handle.h index d950d143ce6239..4a40b1c9319bbb 100644 --- a/tensorflow/c/eager/abstract_tensor_handle.h +++ b/tensorflow/c/eager/abstract_tensor_handle.h @@ -39,11 +39,10 @@ class AbstractTensorHandle : public core::RefCounted { // Returns the status of the tensor handle. If it is a tfrt::TensorHandle, // the tensor handle can be an error and return non-OK status. - virtual tensorflow::Status TensorHandleStatus() const; + virtual absl::Status TensorHandleStatus() const; // Returns tensor shape. If tensor has unknown rank, shape remains untouched. - virtual tensorflow::Status Shape( - tensorflow::PartialTensorShape* shape) const = 0; + virtual absl::Status Shape(tensorflow::PartialTensorShape* shape) const = 0; // Returns tensor (full) type. // While there is no immediate plan to deprecate dtype and shape in favor diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index d20be8abcf02a4..9bbcdf1bcd7f99 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -439,7 +439,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { const string& name() override { return name_; } - tensorflow::Status CopyTensorToDevice( + absl::Status CopyTensorToDevice( ImmediateExecutionTensorHandle* handle, ImmediateExecutionTensorHandle** result) override { handle->Ref(); @@ -454,7 +454,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { return status.status; } - tensorflow::Status CopyTensorFromDevice( + absl::Status CopyTensorFromDevice( ImmediateExecutionTensorHandle* handle, const tensorflow::string& target_device_name, ImmediateExecutionTensorHandle** result) override { @@ -471,9 +471,9 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { return status.status; } - tensorflow::Status Execute(const ImmediateExecutionOperation* op, - ImmediateExecutionTensorHandle** retvals, - int* num_retvals) override { + absl::Status Execute(const ImmediateExecutionOperation* op, + ImmediateExecutionTensorHandle** retvals, + int* num_retvals) override { std::vector outputs(*num_retvals); TF_Status status; device_.execute(tensorflow::wrap(op), num_retvals, outputs.data(), &status, @@ -488,8 +488,8 @@ class CustomDeviceAPI : public tensorflow::CustomDevice { return status.status; } - tensorflow::Status Pack(absl::Span handles, - ImmediateExecutionTensorHandle** result) override { + absl::Status Pack(absl::Span handles, + ImmediateExecutionTensorHandle** result) override { TF_Status status; *result = tensorflow::unwrap(device_.pack(context_, tensorflow::wrap(handles.data()), @@ -530,12 +530,12 @@ class CAPICustomDeviceTensorHandle ~CAPICustomDeviceTensorHandle() override { methods_.deallocator(data_); } void* DevicePointer() const override { return data_; } - Status NumDims(int* num_dims) const override { + absl::Status NumDims(int* num_dims) const override { TF_Status s; *num_dims = methods_.num_dims(data_, &s); return s.status; } - Status Dim(int dim_index, int64_t* dim) const override { + absl::Status Dim(int dim_index, int64_t* dim) const override { TF_Status s; *dim = methods_.dim(data_, dim_index, &s); return s.status; @@ -545,7 +545,7 @@ class CAPICustomDeviceTensorHandle return methods_.summarize != nullptr; } - Status SummarizeValue(std::string& summary) const override { + absl::Status SummarizeValue(std::string& summary) const override { if (methods_.summarize == nullptr) { return tensorflow::CustomDeviceTensorHandle::SummarizeValue(summary); } diff --git a/tensorflow/c/eager/c_api_cluster_test.cc b/tensorflow/c/eager/c_api_cluster_test.cc index c4b58c3dd733e7..b9a19e883af73a 100644 --- a/tensorflow/c/eager/c_api_cluster_test.cc +++ b/tensorflow/c/eager/c_api_cluster_test.cc @@ -148,7 +148,7 @@ void TestRemoteExecuteChangeServerDef(bool async) { serialized = updated_server_def.SerializeAsString(); updated_server_def.set_task_index(1); - tensorflow::Status s = tensorflow::GrpcServer::Create( + absl::Status s = tensorflow::GrpcServer::Create( updated_server_def, tensorflow::Env::Default(), &worker_server); ASSERT_TRUE(s.ok()) << s.message(); ASSERT_TRUE(worker_server->Start().ok()); @@ -430,7 +430,7 @@ void TestConnectToCluster(bool keep_localhost_for_first_connect) { TFE_TensorHandle* var_handle0 = TestVariable(ctx, 1.0, dev0_name); EXPECT_NE(var_handle0, nullptr); - tensorflow::Status status2; + absl::Status status2; EXPECT_EQ(tensorflow::unwrap(var_handle0)->DeviceName(&status2), dev0_name); // Rename local device diff --git a/tensorflow/c/eager/c_api_debug.cc b/tensorflow/c/eager/c_api_debug.cc index aa3af18ec8fd7e..cef9bb27834fde 100644 --- a/tensorflow/c/eager/c_api_debug.cc +++ b/tensorflow/c/eager/c_api_debug.cc @@ -28,7 +28,7 @@ using tensorflow::string; namespace { std::vector TensorShapeAsVector(const tensorflow::TensorHandle& handle, - tensorflow::Status* status) { + absl::Status* status) { std::vector shape; int rank = -1; *status = handle.NumDims(&rank); diff --git a/tensorflow/c/eager/c_api_distributed_test.cc b/tensorflow/c/eager/c_api_distributed_test.cc index 3cb7a5d0fa5f1a..d899c0eb23f919 100644 --- a/tensorflow/c/eager/c_api_distributed_test.cc +++ b/tensorflow/c/eager/c_api_distributed_test.cc @@ -312,7 +312,7 @@ class GraphErrorInjectionPass : public tensorflow::GraphOptimizationPass { static bool enabled_; GraphErrorInjectionPass() {} - tensorflow::Status Run( + absl::Status Run( const tensorflow::GraphOptimizationPassOptions& options) override { if (!enabled_) { return absl::OkStatus(); @@ -431,14 +431,14 @@ class FunctionErrorInjectionPass : public tensorflow::FunctionOptimizationPass { public: FunctionErrorInjectionPass(string error_node, string error_device) : error_node_(error_node), error_device_(error_device) {} - tensorflow::Status Run(const std::string& function_name, - const tensorflow::DeviceSet& device_set, - const tensorflow::ConfigProto& config_proto, - const FunctionOptions& function_options, - std::unique_ptr* graph, - tensorflow::FunctionLibraryDefinition* flib_def, - std::vector* control_ret_node_names, - bool* control_rets_updated) override { + absl::Status Run(const std::string& function_name, + const tensorflow::DeviceSet& device_set, + const tensorflow::ConfigProto& config_proto, + const FunctionOptions& function_options, + std::unique_ptr* graph, + tensorflow::FunctionLibraryDefinition* flib_def, + std::vector* control_ret_node_names, + bool* control_rets_updated) override { // Inject failure to function instantiation if finding a node that contains // the given node name (error_node_) and requested device (error_device_). for (const auto node : graph->get()->nodes()) { diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index 9ad742cd3c55f0..1756e7b42995ac 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -917,8 +917,7 @@ void TFE_ReportErrorToCluster(TFE_Context* ctx, int error_code, "Coordination service is not enabled."); return; } - tensorflow::Status s(static_cast(error_code), - error_message); + absl::Status s(static_cast(error_code), error_message); status->status = coord_agent->ReportError(s); } diff --git a/tensorflow/c/eager/c_api_unified_experimental.cc b/tensorflow/c/eager/c_api_unified_experimental.cc index ab29b1cd6ff051..b2cd9d34d55273 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.cc +++ b/tensorflow/c/eager/c_api_unified_experimental.cc @@ -48,7 +48,7 @@ void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) { GetFactories()[name] = factory; } -Status SetDefaultTracingEngine(const char* name) { +absl::Status SetDefaultTracingEngine(const char* name) { auto entry = GetFactories().find(name); if (entry != GetFactories().end()) { default_factory = GetFactories().find(name)->second; diff --git a/tensorflow/c/eager/c_api_unified_experimental_graph.cc b/tensorflow/c/eager/c_api_unified_experimental_graph.cc index ca500236ecd8cb..23ebb99839c46b 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_graph.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_graph.cc @@ -63,8 +63,7 @@ class GraphTensor : public TracingTensorHandle { return static_cast(TF_OperationOutputType(output_)); } - tensorflow::Status Shape( - tensorflow::PartialTensorShape* shape) const override { + absl::Status Shape(tensorflow::PartialTensorShape* shape) const override { DCHECK(shape != nullptr); TF_Status status; int num_dims = TF_GraphGetTensorNumDims(graph_, output_, &status); @@ -111,7 +110,7 @@ class GraphOperation : public TracingOperation { public: explicit GraphOperation(TF_Graph* g) : TracingOperation(kGraph), g_(g) {} void Release() override { delete this; } - Status Reset(const char* op, const char* raw_device_name) override { + absl::Status Reset(const char* op, const char* raw_device_name) override { if (op_) { return errors::FailedPrecondition("Reset called on already built op."); } @@ -121,7 +120,7 @@ class GraphOperation : public TracingOperation { op_type_ = op; return absl::OkStatus(); } - Status SetOpName(const char* const op_name) override { + absl::Status SetOpName(const char* const op_name) override { if (op_) { return errors::FailedPrecondition( "SetOpName called on already built op."); @@ -140,13 +139,13 @@ class GraphOperation : public TracingOperation { const string& Name() const override { return op_type_; } const string& DeviceName() const override { return device_name_; } - Status SetDeviceName(const char* name) override { + absl::Status SetDeviceName(const char* name) override { // TODO(srbs): Implement this. device_name_ = name; return absl::OkStatus(); } - Status AddInput(AbstractTensorHandle* input) override { + absl::Status AddInput(AbstractTensorHandle* input) override { GraphTensor* t = dyn_cast(input); if (!t) { return tensorflow::errors::InvalidArgument( @@ -155,7 +154,8 @@ class GraphOperation : public TracingOperation { TF_AddInput(op_.get(), t->output_); return absl::OkStatus(); } - Status AddInputList(absl::Span inputs) override { + absl::Status AddInputList( + absl::Span inputs) override { std::vector tf_outputs(inputs.size()); for (int i = 0; i < inputs.size(); i++) { GraphTensor* t = dyn_cast(inputs[i]); @@ -168,8 +168,8 @@ class GraphOperation : public TracingOperation { TF_AddInputList(op_.get(), tf_outputs.data(), tf_outputs.size()); return absl::OkStatus(); } - Status Execute(absl::Span retvals, - int* num_retvals) override { + absl::Status Execute(absl::Span retvals, + int* num_retvals) override { auto* tf_opdesc = op_.release(); if (tf_opdesc == nullptr) { return errors::InvalidArgument("AbstractOp is incomplete."); @@ -185,35 +185,36 @@ class GraphOperation : public TracingOperation { return absl::OkStatus(); } - Status SetAttrString(const char* attr_name, const char* data, - size_t length) override { + absl::Status SetAttrString(const char* attr_name, const char* data, + size_t length) override { tensorflow::StringPiece s(data, length); op_->node_builder.Attr(attr_name, s); return absl::OkStatus(); } - Status SetAttrInt(const char* attr_name, int64_t value) override { + absl::Status SetAttrInt(const char* attr_name, int64_t value) override { op_->node_builder.Attr(attr_name, static_cast(value)); return absl::OkStatus(); } - Status SetAttrFloat(const char* attr_name, float value) override { + absl::Status SetAttrFloat(const char* attr_name, float value) override { op_->node_builder.Attr(attr_name, value); return absl::OkStatus(); } - Status SetAttrBool(const char* attr_name, bool value) override { + absl::Status SetAttrBool(const char* attr_name, bool value) override { op_->node_builder.Attr(attr_name, value); return absl::OkStatus(); } - Status SetAttrType(const char* const attr_name, DataType value) override { + absl::Status SetAttrType(const char* const attr_name, + DataType value) override { if (!op_) { - return Status( + return absl::Status( absl::StatusCode::kFailedPrecondition, "op_type and op_name must be specified before specifying attrs."); } op_->node_builder.Attr(attr_name, value); return absl::OkStatus(); } - Status SetAttrShape(const char* attr_name, const int64_t* dims, - const int num_dims) override { + absl::Status SetAttrShape(const char* attr_name, const int64_t* dims, + const int num_dims) override { PartialTensorShape shape; if (num_dims >= 0) { shape = PartialTensorShape(ArraySlice( @@ -222,25 +223,27 @@ class GraphOperation : public TracingOperation { op_->node_builder.Attr(attr_name, shape); return absl::OkStatus(); } - Status SetAttrFunction(const char* attr_name, - const AbstractOperation* value) override { + absl::Status SetAttrFunction(const char* attr_name, + const AbstractOperation* value) override { return tensorflow::errors::Unimplemented( "SetAttrFunction has not been implemented yet."); } - Status SetAttrFunctionName(const char* attr_name, const char* value, - size_t length) override { + absl::Status SetAttrFunctionName(const char* attr_name, const char* value, + size_t length) override { tensorflow::NameAttrList func_name; func_name.set_name(string(value, value + length)); op_->node_builder.Attr(attr_name, func_name); return absl::OkStatus(); } - Status SetAttrTensor(const char* attr_name, - AbstractTensorInterface* tensor) override { + absl::Status SetAttrTensor(const char* attr_name, + AbstractTensorInterface* tensor) override { return tensorflow::errors::Unimplemented( "SetAttrTensor has not been implemented yet."); } - Status SetAttrStringList(const char* attr_name, const void* const* values, - const size_t* lengths, int num_values) override { + absl::Status SetAttrStringList(const char* attr_name, + const void* const* values, + const size_t* lengths, + int num_values) override { if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) { op_->colocation_constraints.clear(); for (int i = 0; i < num_values; ++i) { @@ -257,27 +260,28 @@ class GraphOperation : public TracingOperation { } return absl::OkStatus(); } - Status SetAttrFloatList(const char* attr_name, const float* values, - int num_values) override { + absl::Status SetAttrFloatList(const char* attr_name, const float* values, + int num_values) override { op_->node_builder.Attr(attr_name, ArraySlice(values, num_values)); return absl::OkStatus(); } - Status SetAttrIntList(const char* attr_name, const int64_t* values, - int num_values) override { + absl::Status SetAttrIntList(const char* attr_name, const int64_t* values, + int num_values) override { op_->node_builder.Attr( attr_name, ArraySlice( reinterpret_cast(values), num_values)); return absl::OkStatus(); } - Status SetAttrTypeList(const char* attr_name, const DataType* values, - int num_values) override { + absl::Status SetAttrTypeList(const char* attr_name, const DataType* values, + int num_values) override { op_->node_builder.Attr(attr_name, ArraySlice(values, num_values)); return absl::OkStatus(); } - Status SetAttrBoolList(const char* attr_name, const unsigned char* values, - int num_values) override { + absl::Status SetAttrBoolList(const char* attr_name, + const unsigned char* values, + int num_values) override { std::unique_ptr b(new bool[num_values]); for (int i = 0; i < num_values; ++i) { b[i] = values[i]; @@ -287,8 +291,8 @@ class GraphOperation : public TracingOperation { return absl::OkStatus(); } - Status SetAttrShapeList(const char* attr_name, const int64_t** dims, - const int* num_dims, int num_values) override { + absl::Status SetAttrShapeList(const char* attr_name, const int64_t** dims, + const int* num_dims, int num_values) override { std::vector shapes; shapes.reserve(num_values); for (int i = 0; i < num_values; ++i) { @@ -302,7 +306,7 @@ class GraphOperation : public TracingOperation { op_->node_builder.Attr(attr_name, shapes); return absl::OkStatus(); } - Status SetAttrFunctionList( + absl::Status SetAttrFunctionList( const char* attr_name, absl::Span values) override { return tensorflow::errors::Unimplemented( @@ -341,8 +345,8 @@ class GraphContext : public TracingContext { return new GraphOperation(graph_.get()); } - Status AddParameter(DataType dtype, const PartialTensorShape& shape, - TracingTensorHandle** output) override { + absl::Status AddParameter(DataType dtype, const PartialTensorShape& shape, + TracingTensorHandle** output) override { TracingOperationPtr operation(CreateOperation()); TF_RETURN_IF_ERROR(operation->Reset("Placeholder", nullptr)); TF_RETURN_IF_ERROR( @@ -371,7 +375,7 @@ class GraphContext : public TracingContext { return absl::OkStatus(); } - Status Finalize(OutputList* outputs, AbstractFunction** f) override { + absl::Status Finalize(OutputList* outputs, AbstractFunction** f) override { std::vector graph_outputs; graph_outputs.reserve(outputs->outputs.size()); for (auto* abstract_output : outputs->outputs) { @@ -396,12 +400,12 @@ class GraphContext : public TracingContext { return absl::OkStatus(); } - Status RegisterFunction(AbstractFunction* func) override { + absl::Status RegisterFunction(AbstractFunction* func) override { return errors::Unimplemented( "Registering graph functions has not been implemented yet."); } - Status RemoveFunction(const string& func) override { + absl::Status RemoveFunction(const string& func) override { return errors::Unimplemented( "GraphContext::RemoveFunction has not been implemented yet."); } diff --git a/tensorflow/c/eager/c_api_unified_experimental_internal.h b/tensorflow/c/eager/c_api_unified_experimental_internal.h index cd0d7610c7faa8..872b9081a932ef 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_internal.h +++ b/tensorflow/c/eager/c_api_unified_experimental_internal.h @@ -79,7 +79,7 @@ class TracingOperation : public AbstractOperation { // Sets the name of the operation: this is an optional identifier that is // not intended to carry semantics and preserved/propagated without // guarantees. - virtual Status SetOpName(const char* op_name) = 0; + virtual absl::Status SetOpName(const char* op_name) = 0; // For LLVM style RTTI. static bool classof(const AbstractOperation* ptr) { @@ -108,12 +108,13 @@ class TracingContext : public AbstractContext { public: // Add a function parameter and return the corresponding tensor. - virtual Status AddParameter(DataType dtype, const PartialTensorShape& shape, - TracingTensorHandle**) = 0; + virtual absl::Status AddParameter(DataType dtype, + const PartialTensorShape& shape, + TracingTensorHandle**) = 0; // Finalize this context and make a function out of it. The context is in a // invalid state after this call and must be destroyed. - virtual Status Finalize(OutputList* outputs, AbstractFunction**) = 0; + virtual absl::Status Finalize(OutputList* outputs, AbstractFunction**) = 0; // For LLVM style RTTI. static bool classof(const AbstractContext* ptr) { @@ -122,7 +123,7 @@ class TracingContext : public AbstractContext { }; typedef TracingContext* (*FactoryFunction)(const char* fn_name, TF_Status*); -Status SetDefaultTracingEngine(const char* name); +absl::Status SetDefaultTracingEngine(const char* name); void RegisterTracingEngineFactory(const ::tensorflow::string& name, FactoryFunction factory); } // namespace tracing diff --git a/tensorflow/c/eager/c_api_unified_experimental_test.cc b/tensorflow/c/eager/c_api_unified_experimental_test.cc index e866ec0ca78151..be795a559671a6 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_test.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_test.cc @@ -46,7 +46,7 @@ class UnifiedCAPI void SetUp() override { TF_StatusPtr status(TF_NewStatus()); TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); - Status s = StatusFromTF_Status(status.get()); + absl::Status s = StatusFromTF_Status(status.get()); CHECK_EQ(errors::OK, s.code()) << s.message(); } }; diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index f8de31aadbaa6f..e3447215192f01 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -160,8 +160,8 @@ absl::optional DeviceNameFromDlContext(const DLDevice& ctx, } // Converts DLPack data type to TF_DATATYPE. -Status TfDataTypeFormDlDataType(const DLDataType& dtype, - TF_DataType* tf_dtype) { +absl::Status TfDataTypeFormDlDataType(const DLDataType& dtype, + TF_DataType* tf_dtype) { switch (dtype.code) { case DLDataTypeCode::kDLBool: if (dtype.bits != 8) { @@ -354,7 +354,7 @@ TFE_TensorHandle* TFE_HandleFromDLPack(void* dlm, TF_Status* status, return nullptr; } TF_DataType dtype; - Status s = TfDataTypeFormDlDataType(dl_tensor->dtype, &dtype); + absl::Status s = TfDataTypeFormDlDataType(dl_tensor->dtype, &dtype); if (!s.ok()) { status->status = std::move(s); return nullptr; diff --git a/tensorflow/c/eager/gradient_checker.cc b/tensorflow/c/eager/gradient_checker.cc index 2fcaee07b37f50..dfd129e11fb402 100644 --- a/tensorflow/c/eager/gradient_checker.cc +++ b/tensorflow/c/eager/gradient_checker.cc @@ -46,10 +46,10 @@ void GetDims(const TF_Tensor* t, int64_t* out_dims) { // Runs model as is if output is a scalar, // else sums the output tensor before returning. -Status RunAndMaybeSum(AbstractContext* ctx, Model forward, - absl::Span inputs, - absl::Span outputs, - bool use_function) { +absl::Status RunAndMaybeSum(AbstractContext* ctx, Model forward, + absl::Span inputs, + absl::Span outputs, + bool use_function) { AbstractTensorHandle* model_outputs[1]; // Run the model. @@ -89,10 +89,10 @@ Status RunAndMaybeSum(AbstractContext* ctx, Model forward, } // ========================= End Helper Functions============================== -Status CalcNumericalGrad(AbstractContext* ctx, Model forward, - absl::Span inputs, - int input_index, bool use_function, - AbstractTensorHandle** numerical_grad) { +absl::Status CalcNumericalGrad(AbstractContext* ctx, Model forward, + absl::Span inputs, + int input_index, bool use_function, + AbstractTensorHandle** numerical_grad) { vector theta_inputs(inputs.size()); for (int i{}; i < inputs.size(); ++i) { theta_inputs[i] = inputs[i]; diff --git a/tensorflow/c/eager/gradient_checker.h b/tensorflow/c/eager/gradient_checker.h index c1671480bf9bf9..d64ad44888b0df 100644 --- a/tensorflow/c/eager/gradient_checker.h +++ b/tensorflow/c/eager/gradient_checker.h @@ -35,10 +35,10 @@ namespace gradients { * `numerical_grad` is the pointer to the AbstractTensorHandle* which will * hold the numerical gradient data at the end of the function. */ -Status CalcNumericalGrad(AbstractContext* ctx, Model forward, - absl::Span inputs, - int input_index, bool use_function, - AbstractTensorHandle** numerical_grad); +absl::Status CalcNumericalGrad(AbstractContext* ctx, Model forward, + absl::Span inputs, + int input_index, bool use_function, + AbstractTensorHandle** numerical_grad); } // namespace gradients } // namespace tensorflow diff --git a/tensorflow/c/eager/gradient_checker_test.cc b/tensorflow/c/eager/gradient_checker_test.cc index e012b29e93fdfc..791f8f198e4785 100644 --- a/tensorflow/c/eager/gradient_checker_test.cc +++ b/tensorflow/c/eager/gradient_checker_test.cc @@ -35,7 +35,7 @@ void CompareNumericalAndManualGradients( absl::Span inputs, int input_index, float* expected_grad, int num_grad, bool use_function, double abs_error = 1e-2) { - Status s; + absl::Status s; AbstractTensorHandlePtr numerical_grad; { AbstractTensorHandle* numerical_grad_raw; @@ -62,17 +62,17 @@ void CompareNumericalAndManualGradients( TF_DeleteTensor(numerical_tensor); } -Status MatMulModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs) { +absl::Status MatMulModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) { return ops::MatMul(ctx, inputs[0], inputs[1], &outputs[0], /*transpose_a=*/false, /*transpose_b=*/false, "MatMul"); } -Status MulModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs) { +absl::Status MulModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) { return ops::Mul(ctx, inputs[0], inputs[1], &outputs[0], "Mul"); } @@ -89,13 +89,13 @@ class GradientCheckerTest TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); { - Status s = StatusFromTF_Status(status.get()); + absl::Status s = StatusFromTF_Status(status.get()); CHECK_EQ(errors::OK, s.code()) << s.message(); } { AbstractContext* ctx_raw = nullptr; - Status s = + absl::Status s = BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); ASSERT_EQ(errors::OK, s.code()) << s.message(); ctx_.reset(ctx_raw); @@ -120,8 +120,8 @@ TEST_P(GradientCheckerTest, TestMatMul) { AbstractTensorHandlePtr A; { AbstractTensorHandle* A_raw; - Status s = TestTensorHandleWithDims(ctx_.get(), A_vals, - A_dims, 2, &A_raw); + absl::Status s = TestTensorHandleWithDims( + ctx_.get(), A_vals, A_dims, 2, &A_raw); ASSERT_EQ(errors::OK, s.code()) << s.message(); A.reset(A_raw); } @@ -130,8 +130,8 @@ TEST_P(GradientCheckerTest, TestMatMul) { AbstractTensorHandlePtr B; { AbstractTensorHandle* B_raw; - Status s = TestTensorHandleWithDims(ctx_.get(), B_vals, - B_dims, 2, &B_raw); + absl::Status s = TestTensorHandleWithDims( + ctx_.get(), B_vals, B_dims, 2, &B_raw); ASSERT_EQ(errors::OK, s.code()) << s.message(); B.reset(B_raw); } @@ -146,7 +146,7 @@ TEST_P(GradientCheckerTest, TestMul) { AbstractTensorHandlePtr x; { AbstractTensorHandle* x_raw = nullptr; - Status s = + absl::Status s = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw); ASSERT_EQ(errors::OK, s.code()) << s.message(); x.reset(x_raw); @@ -155,7 +155,7 @@ TEST_P(GradientCheckerTest, TestMul) { AbstractTensorHandlePtr y; { AbstractTensorHandle* y_raw = nullptr; - Status s = + absl::Status s = TestScalarTensorHandle(ctx_.get(), 7.0f, &y_raw); ASSERT_EQ(errors::OK, s.code()) << s.message(); y.reset(y_raw); diff --git a/tensorflow/c/eager/gradients.cc b/tensorflow/c/eager/gradients.cc index 326a9e8cb829d4..2fa9f90726896a 100644 --- a/tensorflow/c/eager/gradients.cc +++ b/tensorflow/c/eager/gradients.cc @@ -33,8 +33,8 @@ int64_t ToId(const AbstractTensorHandle* t) { return static_cast(reinterpret_cast(t)); } -Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* t, - AbstractTensorHandle** result) { +absl::Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* t, + AbstractTensorHandle** result) { AbstractOperationPtr op(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op->Reset("ZerosLike", /*raw_device_name=*/nullptr)); if (isa(op.get())) { @@ -51,7 +51,7 @@ Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* t, } } // namespace -Status GradientRegistry::Register( +absl::Status GradientRegistry::Register( const string& op_name, GradientFunctionFactory gradient_function_factory) { auto iter = registry_.find(op_name); if (iter != registry_.end()) { @@ -61,7 +61,7 @@ Status GradientRegistry::Register( registry_.insert({op_name, gradient_function_factory}); return absl::OkStatus(); } -Status GradientRegistry::Lookup( +absl::Status GradientRegistry::Lookup( const ForwardOperation& op, std::unique_ptr* gradient_function) const { auto iter = registry_.find(op.op_name); @@ -107,15 +107,15 @@ class TapeVSpace // Calls the passed-in backward function. // op_type is the op's name provided in RecordOperation. - Status CallBackwardFunction( + absl::Status CallBackwardFunction( const string& op_type, GradientFunction* gradient_function, const std::vector& unneeded_gradients, gtl::ArraySlice output_gradients, absl::Span result) const override; // Builds a tensor filled with ones with the same shape and dtype as `t`. - Status BuildOnesLike(const TapeTensor& t, - AbstractTensorHandle** result) const override; + absl::Status BuildOnesLike(const TapeTensor& t, + AbstractTensorHandle** result) const override; // Looks up the ID of a Gradient. int64_t TensorId(AbstractTensorHandle* tensor) const override; @@ -151,7 +151,7 @@ AbstractTensorHandle* TapeVSpace::AggregateGradients( } AbstractOperationPtr op(ctx_->CreateOperation()); - Status s = op->Reset("AddN", /*raw_device_name=*/nullptr); + absl::Status s = op->Reset("AddN", /*raw_device_name=*/nullptr); if (!s.ok()) { return nullptr; } @@ -171,7 +171,7 @@ AbstractTensorHandle* TapeVSpace::AggregateGradients( // Calls the passed-in backward function. // op_type is the op's name provided in RecordOperation. -Status TapeVSpace::CallBackwardFunction( +absl::Status TapeVSpace::CallBackwardFunction( const string& op_type, GradientFunction* gradient_function, const std::vector& unneeded_gradients, gtl::ArraySlice output_gradients, @@ -186,8 +186,8 @@ Status TapeVSpace::CallBackwardFunction( return gradient_function->Compute(ctx_, output_gradients, result); } -Status TapeVSpace::BuildOnesLike(const TapeTensor& t, - AbstractTensorHandle** result) const { +absl::Status TapeVSpace::BuildOnesLike(const TapeTensor& t, + AbstractTensorHandle** result) const { AbstractOperationPtr op(ctx_->CreateOperation()); TF_RETURN_IF_ERROR(op->Reset("OnesLike", /*raw_device_name=*/nullptr)); if (isa(op.get())) { @@ -269,7 +269,7 @@ std::vector MakeTensorIDList( return ids; } -Status Tape::ComputeGradient( +absl::Status Tape::ComputeGradient( AbstractContext* ctx, absl::Span targets, absl::Span sources, absl::Span output_gradients, @@ -299,21 +299,21 @@ Status Tape::ComputeGradient( // the state of the ForwardOperation and call the tape as appropriate. // These APIs are mainly to facilitate testing and are subject to change. namespace internal { -Status Reset(AbstractOperation* op_, const char* op, - const char* raw_device_name, ForwardOperation* forward_op_) { +absl::Status Reset(AbstractOperation* op_, const char* op, + const char* raw_device_name, ForwardOperation* forward_op_) { forward_op_->op_name = op; forward_op_->attrs.Reset(op); return op_->Reset(op, raw_device_name); } -Status AddInput(AbstractOperation* op_, AbstractTensorHandle* input, - ForwardOperation* forward_op_) { +absl::Status AddInput(AbstractOperation* op_, AbstractTensorHandle* input, + ForwardOperation* forward_op_) { TF_RETURN_IF_ERROR(op_->AddInput(input)); forward_op_->inputs.push_back(input); return absl::OkStatus(); } -Status AddInputList(AbstractOperation* op_, - absl::Span inputs, - ForwardOperation* forward_op_) { +absl::Status AddInputList(AbstractOperation* op_, + absl::Span inputs, + ForwardOperation* forward_op_) { TF_RETURN_IF_ERROR(op_->AddInputList(inputs)); for (auto input : inputs) { forward_op_->inputs.push_back(input); @@ -321,35 +321,35 @@ Status AddInputList(AbstractOperation* op_, return absl::OkStatus(); } -Status SetAttrString(AbstractOperation* op_, const char* attr_name, - const char* data, size_t length, - ForwardOperation* forward_op_) { +absl::Status SetAttrString(AbstractOperation* op_, const char* attr_name, + const char* data, size_t length, + ForwardOperation* forward_op_) { forward_op_->attrs.Set(attr_name, StringPiece(data, length)); return op_->SetAttrString(attr_name, data, length); } -Status SetAttrInt(AbstractOperation* op_, const char* attr_name, int64_t value, - ForwardOperation* forward_op_) { +absl::Status SetAttrInt(AbstractOperation* op_, const char* attr_name, + int64_t value, ForwardOperation* forward_op_) { forward_op_->attrs.Set(attr_name, static_cast(value)); return op_->SetAttrInt(attr_name, value); } -Status SetAttrFloat(AbstractOperation* op_, const char* attr_name, float value, - ForwardOperation* forward_op_) { +absl::Status SetAttrFloat(AbstractOperation* op_, const char* attr_name, + float value, ForwardOperation* forward_op_) { forward_op_->attrs.Set(attr_name, value); return op_->SetAttrFloat(attr_name, value); } -Status SetAttrBool(AbstractOperation* op_, const char* attr_name, bool value, - ForwardOperation* forward_op_) { +absl::Status SetAttrBool(AbstractOperation* op_, const char* attr_name, + bool value, ForwardOperation* forward_op_) { forward_op_->attrs.Set(attr_name, value); return op_->SetAttrBool(attr_name, value); } -Status SetAttrType(AbstractOperation* op_, const char* attr_name, - DataType value, ForwardOperation* forward_op_) { +absl::Status SetAttrType(AbstractOperation* op_, const char* attr_name, + DataType value, ForwardOperation* forward_op_) { forward_op_->attrs.Set(attr_name, value); return op_->SetAttrType(attr_name, value); } -Status SetAttrShape(AbstractOperation* op_, const char* attr_name, - const int64_t* dims, const int num_dims, - ForwardOperation* forward_op_) { +absl::Status SetAttrShape(AbstractOperation* op_, const char* attr_name, + const int64_t* dims, const int num_dims, + ForwardOperation* forward_op_) { if (num_dims > TensorShape::MaxDimensions()) { return errors::InvalidArgument("Value specified for `", attr_name, "` has ", num_dims, @@ -368,28 +368,28 @@ Status SetAttrShape(AbstractOperation* op_, const char* attr_name, forward_op_->attrs.Set(attr_name, proto); return op_->SetAttrShape(attr_name, dims, num_dims); } -Status SetAttrFunction(AbstractOperation* op_, const char* attr_name, - const AbstractOperation* value, - ForwardOperation* forward_op_) { +absl::Status SetAttrFunction(AbstractOperation* op_, const char* attr_name, + const AbstractOperation* value, + ForwardOperation* forward_op_) { return tensorflow::errors::Unimplemented( "SetAttrFunction has not been implemented yet."); } -Status SetAttrFunctionName(AbstractOperation* op_, const char* attr_name, - const char* value, size_t length, - ForwardOperation* forward_op_) { +absl::Status SetAttrFunctionName(AbstractOperation* op_, const char* attr_name, + const char* value, size_t length, + ForwardOperation* forward_op_) { return tensorflow::errors::Unimplemented( "SetAttrFunctionName has not been implemented " "yet."); } -Status SetAttrTensor(AbstractOperation* op_, const char* attr_name, - AbstractTensorInterface* tensor, - ForwardOperation* forward_op_) { +absl::Status SetAttrTensor(AbstractOperation* op_, const char* attr_name, + AbstractTensorInterface* tensor, + ForwardOperation* forward_op_) { return tensorflow::errors::Unimplemented( "SetAttrTensor has not been implemented yet."); } -Status SetAttrStringList(AbstractOperation* op_, const char* attr_name, - const void* const* values, const size_t* lengths, - int num_values, ForwardOperation* forward_op_) { +absl::Status SetAttrStringList(AbstractOperation* op_, const char* attr_name, + const void* const* values, const size_t* lengths, + int num_values, ForwardOperation* forward_op_) { std::vector v(num_values); for (int i = 0; i < num_values; ++i) { v[i] = StringPiece(static_cast(values[i]), lengths[i]); @@ -397,31 +397,31 @@ Status SetAttrStringList(AbstractOperation* op_, const char* attr_name, forward_op_->attrs.Set(attr_name, v); return op_->SetAttrStringList(attr_name, values, lengths, num_values); } -Status SetAttrFloatList(AbstractOperation* op_, const char* attr_name, - const float* values, int num_values, - ForwardOperation* forward_op_) { +absl::Status SetAttrFloatList(AbstractOperation* op_, const char* attr_name, + const float* values, int num_values, + ForwardOperation* forward_op_) { forward_op_->attrs.Set(attr_name, gtl::ArraySlice(values, num_values)); return op_->SetAttrFloatList(attr_name, values, num_values); } -Status SetAttrIntList(AbstractOperation* op_, const char* attr_name, - const int64_t* values, int num_values, - ForwardOperation* forward_op_) { +absl::Status SetAttrIntList(AbstractOperation* op_, const char* attr_name, + const int64_t* values, int num_values, + ForwardOperation* forward_op_) { forward_op_->attrs.Set( attr_name, gtl::ArraySlice( reinterpret_cast(values), num_values)); return op_->SetAttrIntList(attr_name, values, num_values); } -Status SetAttrTypeList(AbstractOperation* op_, const char* attr_name, - const DataType* values, int num_values, - ForwardOperation* forward_op_) { +absl::Status SetAttrTypeList(AbstractOperation* op_, const char* attr_name, + const DataType* values, int num_values, + ForwardOperation* forward_op_) { forward_op_->attrs.Set(attr_name, gtl::ArraySlice(values, num_values)); return op_->SetAttrTypeList(attr_name, values, num_values); } -Status SetAttrBoolList(AbstractOperation* op_, const char* attr_name, - const unsigned char* values, int num_values, - ForwardOperation* forward_op_) { +absl::Status SetAttrBoolList(AbstractOperation* op_, const char* attr_name, + const unsigned char* values, int num_values, + ForwardOperation* forward_op_) { std::unique_ptr b(new bool[num_values]); for (int i = 0; i < num_values; ++i) { b[i] = values[i]; @@ -430,9 +430,9 @@ Status SetAttrBoolList(AbstractOperation* op_, const char* attr_name, gtl::ArraySlice(b.get(), num_values)); return op_->SetAttrBoolList(attr_name, values, num_values); } -Status SetAttrShapeList(AbstractOperation* op_, const char* attr_name, - const int64_t** dims, const int* num_dims, - int num_values, ForwardOperation* forward_op_) { +absl::Status SetAttrShapeList(AbstractOperation* op_, const char* attr_name, + const int64_t** dims, const int* num_dims, + int num_values, ForwardOperation* forward_op_) { std::unique_ptr proto(new TensorShapeProto[num_values]); for (int i = 0; i < num_values; ++i) { const auto num_dims_i = num_dims[i]; @@ -457,17 +457,17 @@ Status SetAttrShapeList(AbstractOperation* op_, const char* attr_name, attr_name, gtl::ArraySlice(proto.get(), num_values)); return op_->SetAttrShapeList(attr_name, dims, num_dims, num_values); } -Status SetAttrFunctionList(AbstractOperation* op_, const char* attr_name, - absl::Span values, - ForwardOperation* forward_op_) { +absl::Status SetAttrFunctionList(AbstractOperation* op_, const char* attr_name, + absl::Span values, + ForwardOperation* forward_op_) { return tensorflow::errors::Unimplemented( "SetAttrFunctionList has not been " "implemented yet."); } -Status Execute(AbstractOperation* op_, AbstractContext* ctx, - absl::Span retvals, int* num_retvals, - ForwardOperation* forward_op_, Tape* tape, - const GradientRegistry& registry) { +absl::Status Execute(AbstractOperation* op_, AbstractContext* ctx, + absl::Span retvals, + int* num_retvals, ForwardOperation* forward_op_, + Tape* tape, const GradientRegistry& registry) { TF_RETURN_IF_ERROR(op_->Execute(retvals, num_retvals)); for (int i = 0; i < *num_retvals; i++) { // TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs. diff --git a/tensorflow/c/eager/gradients.h b/tensorflow/c/eager/gradients.h index 2b4bd601b8ad0c..88c1df24945a5e 100644 --- a/tensorflow/c/eager/gradients.h +++ b/tensorflow/c/eager/gradients.h @@ -54,9 +54,10 @@ namespace gradients { // } class GradientFunction { public: - virtual Status Compute(AbstractContext* ctx, - absl::Span grad_outputs, - absl::Span grad_inputs) = 0; + virtual absl::Status Compute( + AbstractContext* ctx, + absl::Span grad_outputs, + absl::Span grad_inputs) = 0; virtual ~GradientFunction() {} }; @@ -77,10 +78,11 @@ using GradientFunctionFactory = // Map from op name to a `GradientFunctionFactory`. class GradientRegistry { public: - Status Register(const string& op, - GradientFunctionFactory gradient_function_factory); - Status Lookup(const ForwardOperation& op, - std::unique_ptr* gradient_function) const; + absl::Status Register(const string& op, + GradientFunctionFactory gradient_function_factory); + absl::Status Lookup( + const ForwardOperation& op, + std::unique_ptr* gradient_function) const; private: absl::flat_hash_map registry_; @@ -163,7 +165,7 @@ class Tape : protected eager::GradientTape targets, absl::Span sources, absl::Span output_gradients, diff --git a/tensorflow/c/eager/gradients_internal.h b/tensorflow/c/eager/gradients_internal.h index 1e14302c1721c1..93c2d36b33163f 100644 --- a/tensorflow/c/eager/gradients_internal.h +++ b/tensorflow/c/eager/gradients_internal.h @@ -27,58 +27,64 @@ namespace internal { // These APIs are mainly to facilitate testing and are subject to change. // Records the op name in the `ForwardOperation`. -Status Reset(AbstractOperation*, const char* op, const char* raw_device_name, - ForwardOperation*); +absl::Status Reset(AbstractOperation*, const char* op, + const char* raw_device_name, ForwardOperation*); // Records the inputs in the `ForwardOperation`. -Status AddInput(AbstractOperation*, AbstractTensorHandle*, ForwardOperation*); -Status AddInputList(AbstractOperation*, - absl::Span inputs, - ForwardOperation*); +absl::Status AddInput(AbstractOperation*, AbstractTensorHandle*, + ForwardOperation*); +absl::Status AddInputList(AbstractOperation*, + absl::Span inputs, + ForwardOperation*); // Sets the attrs in the `ForwardOperation`. -Status SetAttrString(AbstractOperation*, const char* attr_name, - const char* data, size_t length, ForwardOperation*); -Status SetAttrInt(AbstractOperation*, const char* attr_name, int64_t value, - ForwardOperation*); -Status SetAttrFloat(AbstractOperation*, const char* attr_name, float value, - ForwardOperation*); -Status SetAttrBool(AbstractOperation*, const char* attr_name, bool value, - ForwardOperation*); -Status SetAttrType(AbstractOperation*, const char* attr_name, DataType value, - ForwardOperation*); -Status SetAttrShape(AbstractOperation*, const char* attr_name, - const int64_t* dims, const int num_dims, ForwardOperation*); -Status SetAttrFunction(AbstractOperation*, const char* attr_name, - const AbstractOperation* value, ForwardOperation*); -Status SetAttrFunctionName(AbstractOperation*, const char* attr_name, - const char* value, size_t length, ForwardOperation*); -Status SetAttrTensor(AbstractOperation*, const char* attr_name, - AbstractTensorInterface* tensor, ForwardOperation*); -Status SetAttrStringList(AbstractOperation*, const char* attr_name, - const void* const* values, const size_t* lengths, - int num_values, ForwardOperation*); -Status SetAttrFloatList(AbstractOperation*, const char* attr_name, - const float* values, int num_values, ForwardOperation*); -Status SetAttrIntList(AbstractOperation*, const char* attr_name, - const int64_t* values, int num_values, ForwardOperation*); -Status SetAttrTypeList(AbstractOperation*, const char* attr_name, - const DataType* values, int num_values, - ForwardOperation*); -Status SetAttrBoolList(AbstractOperation*, const char* attr_name, - const unsigned char* values, int num_values, - ForwardOperation*); -Status SetAttrShapeList(AbstractOperation*, const char* attr_name, - const int64_t** dims, const int* num_dims, - int num_values, ForwardOperation*); -Status SetAttrFunctionList(AbstractOperation*, const char* attr_name, - absl::Span values, - ForwardOperation*); +absl::Status SetAttrString(AbstractOperation*, const char* attr_name, + const char* data, size_t length, ForwardOperation*); +absl::Status SetAttrInt(AbstractOperation*, const char* attr_name, + int64_t value, ForwardOperation*); +absl::Status SetAttrFloat(AbstractOperation*, const char* attr_name, + float value, ForwardOperation*); +absl::Status SetAttrBool(AbstractOperation*, const char* attr_name, bool value, + ForwardOperation*); +absl::Status SetAttrType(AbstractOperation*, const char* attr_name, + DataType value, ForwardOperation*); +absl::Status SetAttrShape(AbstractOperation*, const char* attr_name, + const int64_t* dims, const int num_dims, + ForwardOperation*); +absl::Status SetAttrFunction(AbstractOperation*, const char* attr_name, + const AbstractOperation* value, ForwardOperation*); +absl::Status SetAttrFunctionName(AbstractOperation*, const char* attr_name, + const char* value, size_t length, + ForwardOperation*); +absl::Status SetAttrTensor(AbstractOperation*, const char* attr_name, + AbstractTensorInterface* tensor, ForwardOperation*); +absl::Status SetAttrStringList(AbstractOperation*, const char* attr_name, + const void* const* values, const size_t* lengths, + int num_values, ForwardOperation*); +absl::Status SetAttrFloatList(AbstractOperation*, const char* attr_name, + const float* values, int num_values, + ForwardOperation*); +absl::Status SetAttrIntList(AbstractOperation*, const char* attr_name, + const int64_t* values, int num_values, + ForwardOperation*); +absl::Status SetAttrTypeList(AbstractOperation*, const char* attr_name, + const DataType* values, int num_values, + ForwardOperation*); +absl::Status SetAttrBoolList(AbstractOperation*, const char* attr_name, + const unsigned char* values, int num_values, + ForwardOperation*); +absl::Status SetAttrShapeList(AbstractOperation*, const char* attr_name, + const int64_t** dims, const int* num_dims, + int num_values, ForwardOperation*); +absl::Status SetAttrFunctionList(AbstractOperation*, const char* attr_name, + absl::Span values, + ForwardOperation*); // Make the call to `Tape::RecordOperation`. -Status Execute(AbstractOperation*, AbstractContext*, - absl::Span retvals, int* num_retvals, - ForwardOperation*, Tape*, const GradientRegistry&); +absl::Status Execute(AbstractOperation*, AbstractContext*, + absl::Span retvals, + int* num_retvals, ForwardOperation*, Tape*, + const GradientRegistry&); } // namespace internal } // namespace gradients diff --git a/tensorflow/c/eager/gradients_test.cc b/tensorflow/c/eager/gradients_test.cc index 9df16f10290d0b..ec1f506699c202 100644 --- a/tensorflow/c/eager/gradients_test.cc +++ b/tensorflow/c/eager/gradients_test.cc @@ -52,12 +52,12 @@ class CppGradients void SetUp() override { TF_StatusPtr status(TF_NewStatus()); TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); - Status s = StatusFromTF_Status(status.get()); + absl::Status s = StatusFromTF_Status(status.get()); CHECK_EQ(errors::OK, s.code()) << s.message(); } }; -Status RegisterGradients(GradientRegistry* registry) { +absl::Status RegisterGradients(GradientRegistry* registry) { TF_RETURN_IF_ERROR(RegisterNotDifferentiable(registry, "CheckNumerics")); return absl::OkStatus(); } @@ -68,7 +68,7 @@ TEST_P(CppGradients, TestSetAttrString) { AbstractContextPtr ctx; { AbstractContext* ctx_raw = nullptr; - Status s = + absl::Status s = BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); ASSERT_EQ(errors::OK, s.code()) << s.message(); ctx.reset(ctx_raw); @@ -77,15 +77,16 @@ TEST_P(CppGradients, TestSetAttrString) { AbstractTensorHandlePtr t; { AbstractTensorHandle* x_raw = nullptr; - Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw); + absl::Status s = + TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw); ASSERT_EQ(errors::OK, s.code()) << s.message(); t.reset(x_raw); } AbstractOperationPtr check_numerics_op(ctx->CreateOperation()); ForwardOperation forward_op; - Status s = Reset(check_numerics_op.get(), "CheckNumerics", - /*raw_device_name=*/nullptr, &forward_op); + absl::Status s = Reset(check_numerics_op.get(), "CheckNumerics", + /*raw_device_name=*/nullptr, &forward_op); ASSERT_EQ(errors::OK, s.code()) << s.message(); if (isa(check_numerics_op.get())) { s = dyn_cast(check_numerics_op.get()) @@ -114,7 +115,7 @@ TEST_P(CppGradients, TestSetAttrString) { ASSERT_EQ(read_message, message); } -Status RecordOperationWithNullGradientFunctionModel( +absl::Status RecordOperationWithNullGradientFunctionModel( AbstractContext* ctx, absl::Span inputs, absl::Span outputs) { Tape tape(/*persistent=*/false); @@ -134,7 +135,7 @@ TEST_P(CppGradients, TestRecordOperationWithNullGradientFunctionRaises) { AbstractContextPtr ctx; { AbstractContext* ctx_raw = nullptr; - Status s = + absl::Status s = BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); ASSERT_EQ(errors::OK, s.code()) << s.message(); ctx.reset(ctx_raw); @@ -143,15 +144,16 @@ TEST_P(CppGradients, TestRecordOperationWithNullGradientFunctionRaises) { AbstractTensorHandlePtr x; { AbstractTensorHandle* x_raw = nullptr; - Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw); + absl::Status s = + TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw); ASSERT_EQ(errors::OK, s.code()) << s.message(); x.reset(x_raw); } std::vector outputs(1); - Status s = RunModel(RecordOperationWithNullGradientFunctionModel, ctx.get(), - {x.get()}, absl::MakeSpan(outputs), - /*use_function=*/!std::get<2>(GetParam())); + absl::Status s = RunModel(RecordOperationWithNullGradientFunctionModel, + ctx.get(), {x.get()}, absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam())); ASSERT_EQ(error::INVALID_ARGUMENT, s.code()); ASSERT_EQ( "Provided null gradient_function for 'Neg'.\nIf the intent is to treat " diff --git a/tensorflow/c/eager/graph_function.cc b/tensorflow/c/eager/graph_function.cc index a0b4e249d7253b..ce11530aa15b3a 100644 --- a/tensorflow/c/eager/graph_function.cc +++ b/tensorflow/c/eager/graph_function.cc @@ -27,7 +27,7 @@ GraphFunction::GraphFunction(FunctionDef fdef) : AbstractFunction(kGraph), func_record_(new FunctionRecord(std::move(fdef), {}, true)) {} GraphFunction::~GraphFunction() {} -Status GraphFunction::GetFunctionDef(const FunctionDef **fdef) { +absl::Status GraphFunction::GetFunctionDef(const FunctionDef **fdef) { *fdef = &(func_record_->fdef()); return absl::OkStatus(); } diff --git a/tensorflow/c/eager/graph_function.h b/tensorflow/c/eager/graph_function.h index bde33e1ea7f50a..b15d1b4be2eed8 100644 --- a/tensorflow/c/eager/graph_function.h +++ b/tensorflow/c/eager/graph_function.h @@ -30,7 +30,7 @@ class GraphFunction : public AbstractFunction { // GraphFunction maybe stay alive for the duration of the returned // FunctionDef. - Status GetFunctionDef(const FunctionDef** fdef) override; + absl::Status GetFunctionDef(const FunctionDef** fdef) override; // Returns a shared reference to the wrapped function. absl::StatusOr> GetFunctionRecord() diff --git a/tensorflow/c/eager/immediate_execution_context.h b/tensorflow/c/eager/immediate_execution_context.h index badd80e0498fc8..216fcfe93bd8b0 100644 --- a/tensorflow/c/eager/immediate_execution_context.h +++ b/tensorflow/c/eager/immediate_execution_context.h @@ -96,7 +96,7 @@ class ImmediateExecutionContext : public AbstractContext { // Copy the handle to another device. virtual ImmediateExecutionTensorHandle* CopyTensorHandleToDevice( ImmediateExecutionTensorHandle* handle, const char* device_name, - Status* status) = 0; + absl::Status* status) = 0; // Create an operation to perform op execution ImmediateExecutionOperation* CreateOperation() override = 0; @@ -111,24 +111,25 @@ class ImmediateExecutionContext : public AbstractContext { // Add `devices` into context's device manager. Context's device manager // will take ownership and maintain devices' lifetime. - virtual Status AddDevices(std::vector> devices) = 0; + virtual absl::Status AddDevices( + std::vector> devices) = 0; // Block until all pending nodes are finished. - virtual Status AsyncWait() = 0; + virtual absl::Status AsyncWait() = 0; // Add a function (serialized FunctionDef protocol buffer) so that it can // be executed as an op. Return error if the function with the same name // already exists. - virtual Status AddFunctionDef(const FunctionDef& fdef) = 0; + virtual absl::Status AddFunctionDef(const FunctionDef& fdef) = 0; // Notifies about the function removal. - virtual Status AddRemoveFunctionNotifier(const string& func, - std::function notifier) = 0; + virtual absl::Status AddRemoveFunctionNotifier( + const string& func, std::function notifier) = 0; // Same as `AddFunctionDef`, but additionally saves the `stack_traces` under // the key of the function definition name (to be retrieved during function // instantiation). - virtual Status AddFunctionDefWithStackTraces( + virtual absl::Status AddFunctionDefWithStackTraces( const FunctionDef& fdef, const StackTracesMap& stack_traces) = 0; // Find and return a added function by its name. @@ -184,8 +185,8 @@ class ImmediateExecutionContext : public AbstractContext { // already registered. // TODO(tfrt-devs): Remove this method. Let caller register it directly into // CustomDeviceOpHandler. - virtual Status RegisterCustomDevice(const string& name, - std::unique_ptr device) = 0; + virtual absl::Status RegisterCustomDevice( + const string& name, std::unique_ptr device) = 0; // Return FunctionLibraryDefinition. Transformations need to use it to use it // to invoke MLIR compiler passes. @@ -258,7 +259,7 @@ class ImmediateExecutionContext : public AbstractContext { // all tasks in the cluster. // This call internally coordinates with other tasks to initialize the eager // context and TF server for multi-client execution. - virtual Status EnableCollectiveOps(const ServerDef& server_def) = 0; + virtual absl::Status EnableCollectiveOps(const ServerDef& server_def) = 0; // Set a distributed manager that helps set up, update, and check liveness // of member tasks in the cluster. diff --git a/tensorflow/c/eager/immediate_execution_distributed_manager.h b/tensorflow/c/eager/immediate_execution_distributed_manager.h index 09011464a8b3d7..f4f4f0931dd78e 100644 --- a/tensorflow/c/eager/immediate_execution_distributed_manager.h +++ b/tensorflow/c/eager/immediate_execution_distributed_manager.h @@ -41,26 +41,26 @@ class ImmediateExecutionDistributedManager { // existing context state with the provided `server_def`. Contexts created // on remote tasks will be considered stale and garbage collected after // `keep_alive_secs` of inactivity. - virtual Status SetOrUpdateServerDef(const ServerDef& server_def, - bool reset_context, int keep_alive_secs, - int64_t init_timeout_in_ms, int retries, - bool clear_existing_contexts = false) = 0; + virtual absl::Status SetOrUpdateServerDef( + const ServerDef& server_def, bool reset_context, int keep_alive_secs, + int64_t init_timeout_in_ms, int retries, + bool clear_existing_contexts = false) = 0; // Initializes context for the local worker and no contexts will be created // for remote workers. Currently this only works for resetting context. // TODO(b/289445025): Consider removing this when we find a proper fix. - virtual Status InitializeLocalOnlyContext(const ServerDef& server_def, - int keep_alive_secs) = 0; + virtual absl::Status InitializeLocalOnlyContext(const ServerDef& server_def, + int keep_alive_secs) = 0; // Set up a multi-client distributed execution environment. Must be called // on all tasks in the cluster. This call internally coordinates with other // tasks to initialize the eager context and TF server for multi-client // execution. - virtual Status EnableCollectiveOps(const ServerDef& server_def) = 0; + virtual absl::Status EnableCollectiveOps(const ServerDef& server_def) = 0; // Check if the remote task is alive. - virtual Status CheckRemoteAlive(const std::string& remote_task_name, - bool* is_alive) = 0; + virtual absl::Status CheckRemoteAlive(const std::string& remote_task_name, + bool* is_alive) = 0; // Get pointer to the coordination service agent instance. virtual tsl::CoordinationServiceAgent* GetCoordinationServiceAgent() = 0; diff --git a/tensorflow/c/eager/immediate_execution_operation.h b/tensorflow/c/eager/immediate_execution_operation.h index 6e56b7b22ed78b..fb76af9dd60990 100644 --- a/tensorflow/c/eager/immediate_execution_operation.h +++ b/tensorflow/c/eager/immediate_execution_operation.h @@ -45,8 +45,8 @@ class ImmediateExecutionOperation : public AbstractOperation { // Returns the inputs of this op. virtual absl::Span GetInputs() const = 0; - virtual Status SetInput(size_t index, - ImmediateExecutionTensorHandle* input) = 0; + virtual absl::Status SetInput(size_t index, + ImmediateExecutionTensorHandle* input) = 0; virtual ImmediateExecutionContext* GetContext() const = 0; @@ -57,8 +57,8 @@ class ImmediateExecutionOperation : public AbstractOperation { virtual const tensorflow::OpDef* OpDef() const = 0; - virtual Status InputLength(const char* input_name, int* length) = 0; - virtual Status OutputLength(const char* output_name, int* length) = 0; + virtual absl::Status InputLength(const char* input_name, int* length) = 0; + virtual absl::Status OutputLength(const char* output_name, int* length) = 0; // Set stack trace to be used for potential async error reporting. virtual void SetStackTrace(ManagedStackTrace stack_trace) = 0; diff --git a/tensorflow/c/eager/immediate_execution_tensor_handle.cc b/tensorflow/c/eager/immediate_execution_tensor_handle.cc index c99a270f0cb804..25d72bff5277b5 100644 --- a/tensorflow/c/eager/immediate_execution_tensor_handle.cc +++ b/tensorflow/c/eager/immediate_execution_tensor_handle.cc @@ -34,7 +34,7 @@ std::string ImmediateExecutionTensorHandle::DebugString() const { // messages. value_string = absl::StrCat(value_string.substr(0, 100), " [...]"); } - Status s; + absl::Status s; const char* device_name = DeviceName(&s); if (!s.ok()) { device_name = ""; @@ -44,9 +44,9 @@ std::string ImmediateExecutionTensorHandle::DebugString() const { device_name, "\")"); } -Status ImmediateExecutionTensorHandle::SummarizeValue( +absl::Status ImmediateExecutionTensorHandle::SummarizeValue( std::string& summary) const { - Status status; + absl::Status status; AbstractTensorPtr resolved( // TODO(allenl): Resolve should be const, and the caches that get updated // marked mutable. diff --git a/tensorflow/c/eager/immediate_execution_tensor_handle.h b/tensorflow/c/eager/immediate_execution_tensor_handle.h index 133d0bbca2a909..61fc0fe8c04929 100644 --- a/tensorflow/c/eager/immediate_execution_tensor_handle.h +++ b/tensorflow/c/eager/immediate_execution_tensor_handle.h @@ -34,27 +34,27 @@ namespace tensorflow { class ImmediateExecutionTensorHandle : public AbstractTensorHandle { public: // Returns number of dimensions. - virtual Status NumDims(int* num_dims) const = 0; + virtual absl::Status NumDims(int* num_dims) const = 0; // Returns number of elements across all dimensions. - virtual Status NumElements(int64_t* num_elements) const = 0; + virtual absl::Status NumElements(int64_t* num_elements) const = 0; // Returns size of specified dimension // // -1 indicates an unknown axis length; this is unreachable for most standard // ImmediateExecutionTensorHandles, but comes up for example when computing // the shape of a parallel tensor with component shapes differing across // devices. - virtual Status Dim(int dim_index, int64_t* dim) const = 0; + virtual absl::Status Dim(int dim_index, int64_t* dim) const = 0; // Returns the device which created the handle. - virtual const char* DeviceName(Status* status) const = 0; + virtual const char* DeviceName(absl::Status* status) const = 0; // Returns the device where the tensor was placed. - virtual const char* BackingDeviceName(Status* status) const = 0; + virtual const char* BackingDeviceName(absl::Status* status) const = 0; // Returns the device type which created the handle. - virtual const char* DeviceType(Status* status) const = 0; + virtual const char* DeviceType(absl::Status* status) const = 0; // Returns the device ID which created the handle. - virtual int DeviceId(Status* status) const = 0; + virtual int DeviceId(absl::Status* status) const = 0; // Returns a tensor for the handle. If tensor is remote, it will be copied. - virtual AbstractTensorInterface* Resolve(Status* status) = 0; + virtual AbstractTensorInterface* Resolve(absl::Status* status) = 0; std::string DebugString() const override; @@ -73,7 +73,7 @@ class ImmediateExecutionTensorHandle : public AbstractTensorHandle { // debugging. Does not include a shape or dtype. // // Included in the default implementation of DebugString. - virtual Status SummarizeValue(std::string& summary) const; + virtual absl::Status SummarizeValue(std::string& summary) const; // For LLVM style RTTI. static bool classof(const AbstractTensorHandle* ptr) { diff --git a/tensorflow/c/eager/parallel_device/BUILD b/tensorflow/c/eager/parallel_device/BUILD index 23a52afcfd9e8a..0802cc46267f66 100644 --- a/tensorflow/c/eager/parallel_device/BUILD +++ b/tensorflow/c/eager/parallel_device/BUILD @@ -65,6 +65,7 @@ cc_library( "//tensorflow/c/eager:c_api_experimental", "//tensorflow/c/eager:tfe_tensorhandle_internal", "//tensorflow/core/platform:status", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:variant", @@ -114,6 +115,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/platform:status", + "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/c/eager/parallel_device/parallel_device.cc b/tensorflow/c/eager/parallel_device/parallel_device.cc index d735c07419f39c..84141067a1122a 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/types/optional.h" #include "absl/types/variant.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/c/tf_buffer.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/core/platform/status.h" namespace tensorflow { namespace parallel_device { @@ -210,7 +210,7 @@ void ParallelTensorDeallocator(void* data) { // number of dimensions of a parallel tensor. int ParallelTensorNumDims(void* data, TF_Status* status) { const std::vector* shape; - Status s = reinterpret_cast(data)->Shape(&shape); + absl::Status s = reinterpret_cast(data)->Shape(&shape); if (!s.ok()) { tsl::Set_TF_Status_from_Status(status, s); return -1; @@ -222,7 +222,7 @@ int ParallelTensorNumDims(void* data, TF_Status* status) { // dimension of a parallel tensor. int64_t ParallelTensorDim(void* data, int dim_index, TF_Status* status) { const std::vector* shape; - Status s = reinterpret_cast(data)->Shape(&shape); + absl::Status s = reinterpret_cast(data)->Shape(&shape); if (!s.ok()) { tsl::Set_TF_Status_from_Status(status, s); return -1; @@ -233,7 +233,7 @@ int64_t ParallelTensorDim(void* data, int dim_index, TF_Status* status) { TF_Buffer* ParallelTensorSummarize(void* data, TF_Status* status) { ParallelTensor* parallel_tensor = reinterpret_cast(data); std::string summary; - Status cpp_status = parallel_tensor->SummarizeValue(summary); + absl::Status cpp_status = parallel_tensor->SummarizeValue(summary); if (!cpp_status.ok()) { tsl::Set_TF_Status_from_Status(status, cpp_status); return nullptr; diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc index c322e6448cc307..8b11c1d5da45f6 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc @@ -37,7 +37,6 @@ limitations under the License. #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/status.h" #include "tensorflow/core/util/device_name_utils.h" #include "tsl/platform/errors.h" #include "tsl/platform/thread_annotations.h" @@ -582,7 +581,7 @@ std::unique_ptr ParallelTensor::FromTensorHandles( new ParallelTensor(parallel_device, std::move(components), dtype)); } -Status ParallelTensor::Shape(const std::vector** shape) const { +absl::Status ParallelTensor::Shape(const std::vector** shape) const { if (!shape_.has_value()) { TF_Status status; PartialTensorShape combined_shape; @@ -621,7 +620,7 @@ Status ParallelTensor::Shape(const std::vector** shape) const { return absl::OkStatus(); } -Status ParallelTensor::SummarizeValue(std::string& summary) { +absl::Status ParallelTensor::SummarizeValue(std::string& summary) { summary = "{"; std::vector summarized_devices = device_.SummarizeDeviceNames(); for (int component_index = 0; component_index < tensors_.size(); diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.h b/tensorflow/c/eager/parallel_device/parallel_device_lib.h index fd431d70bb78ac..03845d15e34b40 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib.h +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "absl/types/variant.h" @@ -200,12 +201,12 @@ class ParallelTensor { // `shape` output argument. This blocks waiting for async tensors, may return // a delayed bad status encountered during async execution, and will return a // bad status unless all tensors have the same shape. - Status Shape(const std::vector** shape) const; + absl::Status Shape(const std::vector** shape) const; TF_DataType dtype() const { return dtype_; } // Sets its output argument to a summary of the values of this tensor on every // component device. - Status SummarizeValue(std::string& summary); + absl::Status SummarizeValue(std::string& summary); std::vector release_tensors() { return std::move(tensors_); } diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc index 1123ccbf33284f..04a8a82bf2ea07 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/eager/parallel_device/parallel_device_lib.h" #include +#include "absl/status/status.h" #include "tensorflow/c/c_api_experimental.h" #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -134,7 +134,7 @@ TEST(PARALLEL_DEVICE_LIB, TestExplicitOutputShape) { ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); const std::vector>& handles = *outputs; const std::vector* shape; - Status s = handles[0]->Shape(&shape); + absl::Status s = handles[0]->Shape(&shape); ASSERT_TRUE(s.ok()); EXPECT_EQ(0, shape->size()); } diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index f7fa3b2491a40b..7ed8025ba24d9f 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -98,15 +98,15 @@ class VSpace { // // `unneeded_gradients` contains sorted list of input indices for which a // gradient is not required. - virtual Status CallBackwardFunction( + virtual absl::Status CallBackwardFunction( const string& op_type, BackwardFunction* backward_function, const std::vector& unneeded_gradients, gtl::ArraySlice output_gradients, absl::Span result) const = 0; // Builds a tensor filled with ones with the same shape and dtype as `t`. - virtual Status BuildOnesLike(const TapeTensor& t, - Gradient** result) const = 0; + virtual absl::Status BuildOnesLike(const TapeTensor& t, + Gradient** result) const = 0; // Looks up the ID of a Gradient. virtual int64_t TensorId(Gradient* tensor) const = 0; @@ -172,7 +172,7 @@ class GradientTape { // When running backward functions, builds zeros-like tensors for // incoming grads which are nullptrs, unless `build_default_zeros_grads` // is set to false. - Status ComputeGradient( + absl::Status ComputeGradient( const VSpace& vspace, const absl::Span target_tensor_ids, const absl::Span source_tensor_ids, @@ -203,13 +203,13 @@ class GradientTape { // that. template class ForwardFunction - : public std::function&, - std::vector*, bool)> { + : public std::function&, + std::vector*, bool)> { public: template explicit ForwardFunction(lambda_type lambda) - : std::function&, - std::vector*, bool)>(lambda) {} + : std::function&, + std::vector*, bool)>(lambda) {} }; // Computes Jacobian-vector products using forward-mode automatic @@ -280,7 +280,7 @@ class ForwardAccumulator { // // This method is not thread-safe (and in general ForwardAccumulator is not // thread-safe). - Status Accumulate( + absl::Status Accumulate( const string& op_type, const std::vector& input_tensors, const std::vector& output_tensors, absl::Span input_tensor_id, @@ -329,7 +329,7 @@ class ForwardAccumulator { // Accumulate will forward op executions to the tape while the backward // function is running; this effectively adds the backward tape to the active // set (but does not require complicated callbacks to the language bindings). - Status ForwardpropFromTape( + absl::Status ForwardpropFromTape( const string& op_type, const std::vector& output_tensors, const std::function& backward_function_getter, const std::function& backward_function_deleter, @@ -603,7 +603,7 @@ std::vector InitialStack( } template -Status InitialGradients( +absl::Status InitialGradients( const VSpace& vspace, absl::Span target_tensor_ids, const std::unordered_map& sources_that_are_targets, @@ -688,7 +688,8 @@ constexpr int kMinAggregateCount = 4; constexpr int kMinAggregateBytes = 128 * 1024 * 1024; template -Status GradientTape::ComputeGradient( +absl::Status +GradientTape::ComputeGradient( const VSpace& vspace, const absl::Span target_tensor_ids, const absl::Span source_tensor_ids, @@ -702,9 +703,9 @@ Status GradientTape::ComputeGradient( std::vector op_stack = InitialStack(state.op_tape, state.op_missing_tensor); std::unordered_map> gradients; - Status s = InitialGradients(vspace, target_tensor_ids, - sources_that_are_targets, output_gradients, - tensor_tape_, state.op_tape, &gradients); + absl::Status s = InitialGradients(vspace, target_tensor_ids, + sources_that_are_targets, output_gradients, + tensor_tape_, state.op_tape, &gradients); auto cleanup = gtl::MakeCleanup([this, &state]() { if (!persistent_) { // Release all backprop functions @@ -789,7 +790,7 @@ Status GradientTape::ComputeGradient( for (const auto i : zero_indices) { out_gradients[i] = trace.output_tensor_info[i].ZerosLike(); } - Status s; + absl::Status s; s = vspace.CallBackwardFunction(trace.op_type, trace.backward_function, unneeded_gradients, out_gradients, absl::MakeSpan(in_gradients)); @@ -929,7 +930,7 @@ bool ForwardAccumulator::ShouldRecord( } template -Status +absl::Status ForwardAccumulator::ForwardpropFromTape( const string& op_type, const std::vector& output_tensors, const std::function& backward_function_getter, @@ -1028,7 +1029,8 @@ ForwardAccumulator::ForwardpropFromTape( } template -Status ForwardAccumulator::Accumulate( +absl::Status +ForwardAccumulator::Accumulate( const string& op_type, const std::vector& input_tensors, const std::vector& output_tensors, absl::Span input_tensor_id, diff --git a/tensorflow/c/eager/tracing_utils.cc b/tensorflow/c/eager/tracing_utils.cc index 17ce98500eca99..0ae0b87bcf0bb7 100644 --- a/tensorflow/c/eager/tracing_utils.cc +++ b/tensorflow/c/eager/tracing_utils.cc @@ -22,7 +22,7 @@ limitations under the License. namespace tensorflow { namespace tracing { -Status MaybeSetOpName(AbstractOperation* op, const char* op_name) { +absl::Status MaybeSetOpName(AbstractOperation* op, const char* op_name) { if (isa(op)) { TF_RETURN_IF_ERROR(dyn_cast(op)->SetOpName(op_name)); } diff --git a/tensorflow/c/eager/tracing_utils.h b/tensorflow/c/eager/tracing_utils.h index 45a7e33fd01465..1c33632237d326 100644 --- a/tensorflow/c/eager/tracing_utils.h +++ b/tensorflow/c/eager/tracing_utils.h @@ -19,7 +19,7 @@ limitations under the License. namespace tensorflow { namespace tracing { -Status MaybeSetOpName(AbstractOperation*, const char* op_name); +absl::Status MaybeSetOpName(AbstractOperation*, const char* op_name); } // namespace tracing } // namespace tensorflow diff --git a/tensorflow/c/eager/unified_api_test.cc b/tensorflow/c/eager/unified_api_test.cc index 8d807eca7c5ff1..fbd996d467e41f 100644 --- a/tensorflow/c/eager/unified_api_test.cc +++ b/tensorflow/c/eager/unified_api_test.cc @@ -29,7 +29,7 @@ class UnifiedAPI void SetUp() override { TF_StatusPtr status(TF_NewStatus()); TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); - Status s = StatusFromTF_Status(status.get()); + absl::Status s = StatusFromTF_Status(status.get()); CHECK_EQ(errors::OK, s.code()) << s.message(); } @@ -39,9 +39,9 @@ class UnifiedAPI }; // Checks that inputs[0] is a scalar. -Status TestScalarShape(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs) { +absl::Status TestScalarShape(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) { PartialTensorShape shape; TF_RETURN_IF_ERROR(inputs[0]->Shape(&shape)); if (shape.dims() != 0) { @@ -59,7 +59,7 @@ TEST_P(UnifiedAPI, TestTensorShapeScalar) { AbstractContextPtr ctx; { AbstractContext* ctx_raw = nullptr; - Status s = + absl::Status s = BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); ASSERT_EQ(errors::OK, s.code()) << s.message(); ctx.reset(ctx_raw); @@ -68,22 +68,23 @@ TEST_P(UnifiedAPI, TestTensorShapeScalar) { AbstractTensorHandlePtr x; { AbstractTensorHandle* x_raw = nullptr; - Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw); + absl::Status s = + TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw); ASSERT_EQ(errors::OK, s.code()) << s.message(); x.reset(x_raw); } - Status s = RunModel(TestScalarShape, ctx.get(), - /*inputs=*/{x.get()}, - /*outputs=*/{}, - /*use_function=*/UseFunction()); + absl::Status s = RunModel(TestScalarShape, ctx.get(), + /*inputs=*/{x.get()}, + /*outputs=*/{}, + /*use_function=*/UseFunction()); ASSERT_EQ(errors::OK, s.code()) << s.message(); } // Checks that inputs[0] is a matrix with shape 2x4. -Status TestTensorShape2x4(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs) { +absl::Status TestTensorShape2x4(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) { PartialTensorShape shape; TF_RETURN_IF_ERROR(inputs[0]->Shape(&shape)); if (shape.dims() != 2) { @@ -109,7 +110,7 @@ TEST_P(UnifiedAPI, TestTensorShape2x4) { AbstractContextPtr ctx; { AbstractContext* ctx_raw = nullptr; - Status s = + absl::Status s = BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); ASSERT_EQ(errors::OK, s.code()) << s.message(); ctx.reset(ctx_raw); @@ -120,16 +121,16 @@ TEST_P(UnifiedAPI, TestTensorShape2x4) { AbstractTensorHandle* x_raw = nullptr; float data[] = {0., 0., 0., 0., 0., 0., 0., 0}; int64_t dim_sizes[] = {2, 4}; - Status s = TestTensorHandleWithDims(ctx.get(), data, - dim_sizes, 2, &x_raw); + absl::Status s = TestTensorHandleWithDims( + ctx.get(), data, dim_sizes, 2, &x_raw); ASSERT_EQ(errors::OK, s.code()) << s.message(); x.reset(x_raw); } - Status s = RunModel(TestTensorShape2x4, ctx.get(), - /*inputs=*/{x.get()}, - /*outputs=*/{}, - /*use_function=*/UseFunction()); + absl::Status s = RunModel(TestTensorShape2x4, ctx.get(), + /*inputs=*/{x.get()}, + /*outputs=*/{}, + /*use_function=*/UseFunction()); ASSERT_EQ(errors::OK, s.code()) << s.message(); } @@ -146,14 +147,14 @@ TEST_P(UnifiedAPI, TestUnknownShapeTracing) { { tracing::TracingTensorHandle* x_raw = nullptr; PartialTensorShape shape; - Status s = dyn_cast(ctx.get())->AddParameter( + absl::Status s = dyn_cast(ctx.get())->AddParameter( DT_FLOAT, shape, &x_raw); ASSERT_EQ(errors::OK, s.code()) << s.message(); x.reset(x_raw); } PartialTensorShape shape; - Status s = x->Shape(&shape); + absl::Status s = x->Shape(&shape); ASSERT_EQ(errors::OK, s.code()) << s.message(); ASSERT_TRUE(shape.unknown_rank()); } @@ -171,7 +172,7 @@ TEST_P(UnifiedAPI, TestPartialShapeTracing) { tracing::TracingTensorHandle* x_raw = nullptr; PartialTensorShape shape; int64_t dim_sizes[] = {2, -1}; - Status s = PartialTensorShape::MakePartialShape(dim_sizes, 2, &shape); + absl::Status s = PartialTensorShape::MakePartialShape(dim_sizes, 2, &shape); ASSERT_EQ(errors::OK, s.code()) << s.message(); s = dyn_cast(ctx.get())->AddParameter( DT_FLOAT, shape, &x_raw); @@ -180,7 +181,7 @@ TEST_P(UnifiedAPI, TestPartialShapeTracing) { } PartialTensorShape shape; - Status s = x->Shape(&shape); + absl::Status s = x->Shape(&shape); ASSERT_EQ(errors::OK, s.code()) << s.message(); ASSERT_FALSE(shape.unknown_rank()); diff --git a/tensorflow/c/eager/unified_api_testutil.cc b/tensorflow/c/eager/unified_api_testutil.cc index a3ddfa4f761663..513da3ce27f265 100644 --- a/tensorflow/c/eager/unified_api_testutil.cc +++ b/tensorflow/c/eager/unified_api_testutil.cc @@ -34,9 +34,9 @@ AbstractContext* BuildFunction(const char* fn_name) { return unwrap(graph_ctx); } -Status CreateParamsForInputs(AbstractContext* ctx, - absl::Span inputs, - std::vector* params) { +absl::Status CreateParamsForInputs( + AbstractContext* ctx, absl::Span inputs, + std::vector* params) { tracing::TracingTensorHandle* handle = nullptr; for (auto input : inputs) { PartialTensorShape shape; @@ -49,9 +49,10 @@ Status CreateParamsForInputs(AbstractContext* ctx, } // Runs `model` maybe wrapped in a function. -Status RunModel(Model model, AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, bool use_function) { +absl::Status RunModel(Model model, AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + bool use_function) { if (use_function) { const char* fn_name = "test_fn"; core::RefCountPtr scoped_func; @@ -119,7 +120,8 @@ Status RunModel(Model model, AbstractContext* ctx, } } -Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) { +absl::Status BuildImmediateExecutionContext(bool use_tfrt, + AbstractContext** ctx) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); @@ -130,7 +132,7 @@ Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) { return absl::OkStatus(); } -Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) { +absl::Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_TensorHandle* result_t = diff --git a/tensorflow/c/eager/unified_api_testutil.h b/tensorflow/c/eager/unified_api_testutil.h index 3d47b775a2817e..2df18c13d0c692 100644 --- a/tensorflow/c/eager/unified_api_testutil.h +++ b/tensorflow/c/eager/unified_api_testutil.h @@ -31,14 +31,14 @@ AbstractContext* BuildFunction(const char* fn_name); // Creates parameters (placeholders) in the tracing `ctx` using the shape and // dtype of `inputs`. -Status CreateParamsForInputs(AbstractContext* ctx, - absl::Span inputs, - std::vector* params); +absl::Status CreateParamsForInputs( + AbstractContext* ctx, absl::Span inputs, + std::vector* params); // A callable that takes tensor inputs and returns zero or more tensor outputs. -using Model = std::function, - absl::Span)>; +using Model = std::function, + absl::Span)>; // Runs `model` maybe wrapped in a function call op. This can be thought as // being equivalent to the following python code. @@ -47,17 +47,19 @@ using Model = std::function inputs, - absl::Span outputs, bool use_function); +absl::Status RunModel(Model model, AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + bool use_function); -Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx); +absl::Status BuildImmediateExecutionContext(bool use_tfrt, + AbstractContext** ctx); // Return a tensor handle with given type, values and dimensions. template -Status TestTensorHandleWithDims(AbstractContext* ctx, const T* data, - const int64_t* dims, int num_dims, - AbstractTensorHandle** tensor) { +absl::Status TestTensorHandleWithDims(AbstractContext* ctx, const T* data, + const int64_t* dims, int num_dims, + AbstractTensorHandle** tensor) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_Context* eager_ctx = @@ -72,8 +74,8 @@ Status TestTensorHandleWithDims(AbstractContext* ctx, const T* data, // Return a scalar tensor handle with given value. template -Status TestScalarTensorHandle(AbstractContext* ctx, const T value, - AbstractTensorHandle** tensor) { +absl::Status TestScalarTensorHandle(AbstractContext* ctx, const T value, + AbstractTensorHandle** tensor) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_Context* eager_ctx = @@ -87,7 +89,7 @@ Status TestScalarTensorHandle(AbstractContext* ctx, const T value, } // Places data from `t` into *result_tensor. -Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor); +absl::Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor); } // namespace tensorflow #endif // TENSORFLOW_C_EAGER_UNIFIED_API_TESTUTIL_H_ diff --git a/tensorflow/c/env.cc b/tensorflow/c/env.cc index 4ed617ba6dc628..9dc78ada8ae4fb 100644 --- a/tensorflow/c/env.cc +++ b/tensorflow/c/env.cc @@ -59,8 +59,7 @@ void TF_FileStat(const char* filename, TF_FileStatistics* stats, TF_Status* status) { ::tensorflow::FileStatistics cc_stats; TF_SetStatus(status, TF_OK, ""); - ::tensorflow::Status s = - ::tensorflow::Env::Default()->Stat(filename, &cc_stats); + absl::Status s = ::tensorflow::Env::Default()->Stat(filename, &cc_stats); ::tensorflow::Set_TF_Status_from_Status(status, s); if (s.ok()) { stats->length = cc_stats.length; @@ -73,8 +72,7 @@ void TF_NewWritableFile(const char* filename, TF_WritableFileHandle** handle, TF_Status* status) { std::unique_ptr<::tensorflow::WritableFile> f; TF_SetStatus(status, TF_OK, ""); - ::tensorflow::Status s = - ::tensorflow::Env::Default()->NewWritableFile(filename, &f); + absl::Status s = ::tensorflow::Env::Default()->NewWritableFile(filename, &f); ::tensorflow::Set_TF_Status_from_Status(status, s); if (s.ok()) { diff --git a/tensorflow/c/experimental/filesystem/BUILD b/tensorflow/c/experimental/filesystem/BUILD index d25e6e9314f088..a8df18adf63470 100644 --- a/tensorflow/c/experimental/filesystem/BUILD +++ b/tensorflow/c/experimental/filesystem/BUILD @@ -36,11 +36,19 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":filesystem_interface", + "//tensorflow/c:tf_file_statistics", "//tensorflow/c:tf_status_helper", "//tensorflow/c:tf_status_internal", + "//tensorflow/core:portable_gif_internal", "//tensorflow/core/platform:env", "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:file_statistics", "//tensorflow/core/platform:status", + "//tensorflow/core/platform:strcat", + "//tensorflow/core/platform:stringpiece", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", ], ) diff --git a/tensorflow/c/experimental/filesystem/modular_filesystem.cc b/tensorflow/c/experimental/filesystem/modular_filesystem.cc index d030948787acdd..7fede4ff7dc801 100644 --- a/tensorflow/c/experimental/filesystem/modular_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/modular_filesystem.cc @@ -18,11 +18,23 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "tensorflow/c/experimental/filesystem/filesystem_interface.h" #include "tensorflow/c/experimental/filesystem/modular_filesystem_registration.h" +#include "tensorflow/c/tf_file_statistics.h" +#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/file_statistics.h" +#include "tensorflow/core/platform/file_system.h" #include "tensorflow/core/platform/file_system_helper.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" #include "tsl/platform/errors.h" +#include "tsl/platform/file_system.h" // TODO(b/139060984): After all filesystems are converted, all calls to // methods from `FileSystem` will have to be replaced to calls to private diff --git a/tensorflow/c/experimental/filesystem/modular_filesystem.h b/tensorflow/c/experimental/filesystem/modular_filesystem.h index 091b84529668a5..a3e020d580e4ac 100644 --- a/tensorflow/c/experimental/filesystem/modular_filesystem.h +++ b/tensorflow/c/experimental/filesystem/modular_filesystem.h @@ -18,7 +18,12 @@ limitations under the License. #include #include "tensorflow/c/experimental/filesystem/filesystem_interface.h" +#include "tensorflow/core/platform/file_statistics.h" #include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/file_system.h" /// This file builds classes needed to hold a filesystem implementation in the /// modular world. Once all TensorFlow filesystems are converted to use the @@ -61,52 +66,58 @@ class ModularFileSystem final : public FileSystem { TF_USE_FILESYSTEM_METHODS_WITH_NO_TRANSACTION_SUPPORT; - Status NewRandomAccessFile( + absl::Status NewRandomAccessFile( const std::string& fname, TransactionToken* token, std::unique_ptr* result) override; - Status NewWritableFile(const std::string& fname, TransactionToken* token, - std::unique_ptr* result) override; - Status NewAppendableFile(const std::string& fname, TransactionToken* token, - std::unique_ptr* result) override; - Status NewReadOnlyMemoryRegionFromFile( + absl::Status NewWritableFile(const std::string& fname, + TransactionToken* token, + std::unique_ptr* result) override; + absl::Status NewAppendableFile( + const std::string& fname, TransactionToken* token, + std::unique_ptr* result) override; + absl::Status NewReadOnlyMemoryRegionFromFile( const std::string& fname, TransactionToken* token, std::unique_ptr* result) override; - Status FileExists(const std::string& fname, TransactionToken* token) override; + absl::Status FileExists(const std::string& fname, + TransactionToken* token) override; bool FilesExist(const std::vector& files, TransactionToken* token, - std::vector* status) override; - Status GetChildren(const std::string& dir, TransactionToken* token, - std::vector* result) override; - Status GetMatchingPaths(const std::string& pattern, TransactionToken* token, - std::vector* results) override; - Status DeleteFile(const std::string& fname, TransactionToken* token) override; - Status DeleteRecursively(const std::string& dirname, TransactionToken* token, - int64_t* undeleted_files, - int64_t* undeleted_dirs) override; - Status DeleteDir(const std::string& dirname, - TransactionToken* token) override; - Status RecursivelyCreateDir(const std::string& dirname, - TransactionToken* token) override; - Status CreateDir(const std::string& dirname, - TransactionToken* token) override; - Status Stat(const std::string& fname, TransactionToken* token, - FileStatistics* stat) override; - Status IsDirectory(const std::string& fname, - TransactionToken* token) override; - Status GetFileSize(const std::string& fname, TransactionToken* token, - uint64* file_size) override; - Status RenameFile(const std::string& src, const std::string& target, - TransactionToken* token) override; - Status CopyFile(const std::string& src, const std::string& target, - TransactionToken* token) override; + std::vector* status) override; + absl::Status GetChildren(const std::string& dir, TransactionToken* token, + std::vector* result) override; + absl::Status GetMatchingPaths(const std::string& pattern, + TransactionToken* token, + std::vector* results) override; + absl::Status DeleteFile(const std::string& fname, + TransactionToken* token) override; + absl::Status DeleteRecursively(const std::string& dirname, + TransactionToken* token, + int64_t* undeleted_files, + int64_t* undeleted_dirs) override; + absl::Status DeleteDir(const std::string& dirname, + TransactionToken* token) override; + absl::Status RecursivelyCreateDir(const std::string& dirname, + TransactionToken* token) override; + absl::Status CreateDir(const std::string& dirname, + TransactionToken* token) override; + absl::Status Stat(const std::string& fname, TransactionToken* token, + FileStatistics* stat) override; + absl::Status IsDirectory(const std::string& fname, + TransactionToken* token) override; + absl::Status GetFileSize(const std::string& fname, TransactionToken* token, + uint64* file_size) override; + absl::Status RenameFile(const std::string& src, const std::string& target, + TransactionToken* token) override; + absl::Status CopyFile(const std::string& src, const std::string& target, + TransactionToken* token) override; std::string TranslateName(const std::string& name) const override; void FlushCaches(TransactionToken* token) override; - Status SetOption(const std::string& name, - const std::vector& values) override; - Status SetOption(const std::string& name, - const std::vector& values) override; - Status SetOption(const std::string& name, - const std::vector& values) override; + absl::Status SetOption(const std::string& name, + const std::vector& values) override; + absl::Status SetOption(const std::string& name, + const std::vector& values) override; + absl::Status SetOption(const std::string& name, + const std::vector& values) override; private: std::unique_ptr filesystem_; @@ -130,9 +141,9 @@ class ModularRandomAccessFile final : public RandomAccessFile { ~ModularRandomAccessFile() override { ops_->cleanup(file_.get()); } - Status Read(uint64 offset, size_t n, StringPiece* result, - char* scratch) const override; - Status Name(StringPiece* result) const override; + absl::Status Read(uint64 offset, size_t n, StringPiece* result, + char* scratch) const override; + absl::Status Name(StringPiece* result) const override; private: std::string filename_; @@ -151,12 +162,12 @@ class ModularWritableFile final : public WritableFile { ~ModularWritableFile() override { ops_->cleanup(file_.get()); } - Status Append(StringPiece data) override; - Status Close() override; - Status Flush() override; - Status Sync() override; - Status Name(StringPiece* result) const override; - Status Tell(int64_t* position) override; + absl::Status Append(StringPiece data) override; + absl::Status Close() override; + absl::Status Flush() override; + absl::Status Sync() override; + absl::Status Name(StringPiece* result) const override; + absl::Status Tell(int64_t* position) override; private: std::string filename_; @@ -185,7 +196,7 @@ class ModularReadOnlyMemoryRegion final : public ReadOnlyMemoryRegion { }; // Registers a filesystem plugin so that core TensorFlow can use it. -Status RegisterFilesystemPlugin(const std::string& dso_path); +absl::Status RegisterFilesystemPlugin(const std::string& dso_path); } // namespace tensorflow diff --git a/tensorflow/c/experimental/filesystem/modular_filesystem_registration.cc b/tensorflow/c/experimental/filesystem/modular_filesystem_registration.cc index ce5a4282e61091..58112a3fbe2296 100644 --- a/tensorflow/c/experimental/filesystem/modular_filesystem_registration.cc +++ b/tensorflow/c/experimental/filesystem/modular_filesystem_registration.cc @@ -16,10 +16,17 @@ limitations under the License. #include +#include "absl/log/log.h" +#include "tensorflow/c/experimental/filesystem/filesystem_interface.h" #include "tensorflow/c/experimental/filesystem/modular_filesystem.h" +#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_internal.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tsl/platform/errors.h" namespace tensorflow { diff --git a/tensorflow/c/experimental/filesystem/modular_filesystem_registration.h b/tensorflow/c/experimental/filesystem/modular_filesystem_registration.h index d8b0a28723b55e..4ddfbf6c210e87 100644 --- a/tensorflow/c/experimental/filesystem/modular_filesystem_registration.h +++ b/tensorflow/c/experimental/filesystem/modular_filesystem_registration.h @@ -25,7 +25,7 @@ namespace filesystem_registration { // // Don't call this directly. Instead call `RegisterFilesystemPlugin`. // Exposed only for static registration of local filesystems. -Status RegisterFilesystemPluginImpl(const TF_FilesystemPluginInfo* info); +absl::Status RegisterFilesystemPluginImpl(const TF_FilesystemPluginInfo* info); } // namespace filesystem_registration } // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD index a3fa49fffa34b7..92c87960986436 100644 --- a/tensorflow/c/experimental/gradients/BUILD +++ b/tensorflow/c/experimental/gradients/BUILD @@ -55,7 +55,9 @@ cc_library( hdrs = [ "nn_grad.h", ], - visibility = ["//visibility:private"], + visibility = [ + "//tensorflow/python/framework/experimental:__pkg__", + ], deps = [ "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:gradients_internal", diff --git a/tensorflow/c/experimental/gradients/tape/BUILD b/tensorflow/c/experimental/gradients/tape/BUILD index c29b7929d43b27..3097c31e289fd6 100644 --- a/tensorflow/c/experimental/gradients/tape/BUILD +++ b/tensorflow/c/experimental/gradients/tape/BUILD @@ -18,6 +18,10 @@ cc_library( deps = [ ":tape_operation", "//tensorflow/c/eager:abstract_context", + "//tensorflow/c/eager:abstract_function", + "//tensorflow/c/eager:gradients_internal", + "//tensorflow/core:portable_gif_internal", + "//tensorflow/core/platform:status", ], ) @@ -31,9 +35,22 @@ cc_library( "//tensorflow:internal", ], deps = [ + "//tensorflow/c:tensor_interface", "//tensorflow/c/eager:abstract_context", "//tensorflow/c/eager:abstract_operation", + "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:gradients_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:portable_gif_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:errors", + "//tensorflow/core/platform:status", + "//tensorflow/core/platform:strcat", + "//tensorflow/core/platform:stringpiece", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", ], ) @@ -49,9 +66,16 @@ cc_library( deps = [ ":tape_context", ":tape_operation", + "//tensorflow/c:tensor_interface", "//tensorflow/c/eager:abstract_context", + "//tensorflow/c/eager:abstract_function", "//tensorflow/c/eager:abstract_operation", + "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:gradients_internal", + "//tensorflow/core:portable_gif_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:status", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/c/experimental/gradients/tape/tape_context.cc b/tensorflow/c/experimental/gradients/tape/tape_context.cc index 1fa1a3f24f193a..5285b6a088e5b0 100644 --- a/tensorflow/c/experimental/gradients/tape/tape_context.cc +++ b/tensorflow/c/experimental/gradients/tape/tape_context.cc @@ -15,7 +15,11 @@ limitations under the License. #include "tensorflow/c/experimental/gradients/tape/tape_context.h" #include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/abstract_function.h" +#include "tensorflow/c/eager/gradients.h" #include "tensorflow/c/experimental/gradients/tape/tape_operation.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace gradients { diff --git a/tensorflow/c/experimental/gradients/tape/tape_context.h b/tensorflow/c/experimental/gradients/tape/tape_context.h index 291053226fb4a5..a7588362325fc1 100644 --- a/tensorflow/c/experimental/gradients/tape/tape_context.h +++ b/tensorflow/c/experimental/gradients/tape/tape_context.h @@ -16,7 +16,11 @@ limitations under the License. #define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_CONTEXT_H_ #include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/abstract_function.h" +#include "tensorflow/c/eager/gradients.h" #include "tensorflow/c/experimental/gradients/tape/tape_operation.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace gradients { diff --git a/tensorflow/c/experimental/gradients/tape/tape_operation.cc b/tensorflow/c/experimental/gradients/tape/tape_operation.cc index a972521915c9bf..5bd3daa4037fbe 100644 --- a/tensorflow/c/experimental/gradients/tape/tape_operation.cc +++ b/tensorflow/c/experimental/gradients/tape/tape_operation.cc @@ -14,8 +14,22 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/gradients/tape/tape_operation.h" -#include "tensorflow/c/eager/abstract_context.h" +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_operation.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/tensor_interface.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/errors.h" namespace tensorflow { namespace gradients { diff --git a/tensorflow/c/experimental/gradients/tape/tape_operation.h b/tensorflow/c/experimental/gradients/tape/tape_operation.h index b971176d9e71a3..2ab67394988cf9 100644 --- a/tensorflow/c/experimental/gradients/tape/tape_operation.h +++ b/tensorflow/c/experimental/gradients/tape/tape_operation.h @@ -15,8 +15,14 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_ #define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_TAPE_TAPE_OPERATION_H_ +#include "absl/types/span.h" #include "tensorflow/c/eager/abstract_operation.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/tensor_interface.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace gradients { diff --git a/tensorflow/c/experimental/grappler/BUILD b/tensorflow/c/experimental/grappler/BUILD index 5f996d322a7f7a..6c3620088eb9b2 100644 --- a/tensorflow/c/experimental/grappler/BUILD +++ b/tensorflow/c/experimental/grappler/BUILD @@ -71,6 +71,6 @@ tf_cc_test( "//tensorflow/core/platform:status", "//tensorflow/core/protobuf:for_core_protos_cc", "@com_google_absl//absl/log:check", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) diff --git a/tensorflow/c/experimental/grappler/grappler.cc b/tensorflow/c/experimental/grappler/grappler.cc index f6f00d43347240..f154e11a836670 100644 --- a/tensorflow/c/experimental/grappler/grappler.cc +++ b/tensorflow/c/experimental/grappler/grappler.cc @@ -94,8 +94,9 @@ absl::Status ValidateTPOptimizerConfigs(const TP_OptimizerConfigs& configs) { namespace tensorflow { namespace grappler { -Status CGraphOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, - GraphDef* optimized_graph_def) { +absl::Status CGraphOptimizer::Optimize(Cluster* cluster, + const GrapplerItem& item, + GraphDef* optimized_graph_def) { OwnedTFStatus c_status(TF_NewStatus()); OwnedTFBuffer graph_buf(TF_NewBuffer()); OwnedTFBuffer optimized_graph_buf(TF_NewBuffer()); diff --git a/tensorflow/c/experimental/grappler/grappler_test.cc b/tensorflow/c/experimental/grappler/grappler_test.cc index 357432e2c58018..5199aadc45632f 100644 --- a/tensorflow/c/experimental/grappler/grappler_test.cc +++ b/tensorflow/c/experimental/grappler/grappler_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/c/tf_buffer_internal.h" #include "tensorflow/c/tf_status.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace grappler { @@ -102,7 +102,7 @@ TEST(Grappler, DeviceTypeNotSet) { params->device_type = nullptr; }; - tensorflow::Status status = InitGraphPlugin(plugin_init); + absl::Status status = InitGraphPlugin(plugin_init); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); ASSERT_EQ( status.message(), @@ -118,7 +118,7 @@ TEST(Grappler, OptimizeFuncNotSet) { params->optimizer->optimize_func = nullptr; }; - tensorflow::Status status = InitGraphPlugin(plugin_init); + absl::Status status = InitGraphPlugin(plugin_init); ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); ASSERT_EQ(status.message(), "'optimize_func' field in TP_Optimizer must be set."); @@ -223,7 +223,7 @@ TEST(TF_GraphProperties, InputProperties) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); tensorflow::OpInfo::TensorProperties in_props; - Status s = tensorflow::BufferToMessage(in_props_buf[0], &in_props); + absl::Status s = tensorflow::BufferToMessage(in_props_buf[0], &in_props); TF_ASSERT_OK(s); EXPECT_EQ(DT_FLOAT, in_props.dtype()); @@ -271,7 +271,8 @@ TEST(TF_GraphProperties, OutputProperties) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); tensorflow::OpInfo::TensorProperties out_props; - Status s = tensorflow::BufferToMessage(out_props_buf[0], &out_props); + absl::Status s = + tensorflow::BufferToMessage(out_props_buf[0], &out_props); TF_ASSERT_OK(s); EXPECT_EQ(DT_FLOAT, out_props.dtype()); @@ -294,7 +295,7 @@ TEST(TF_FunctionLibraryDefinition, LookUpOpDef) { TF_Buffer* op_buf = TF_NewBuffer(); TF_Status* status = TF_NewStatus(); GraphDef g_def; - Status s = MessageToBuffer(g_def, g_buf); + absl::Status s = MessageToBuffer(g_def, g_buf); TF_ASSERT_OK(s); TF_FunctionLibraryDefinition* func = TF_NewFunctionLibraryDefinition(g_buf, status); diff --git a/tensorflow/c/experimental/next_pluggable_device/BUILD b/tensorflow/c/experimental/next_pluggable_device/BUILD index 45c55c315f5350..893263bcb9437d 100644 --- a/tensorflow/c/experimental/next_pluggable_device/BUILD +++ b/tensorflow/c/experimental/next_pluggable_device/BUILD @@ -89,7 +89,6 @@ tf_cc_test( "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", "@local_xla//xla:shape_util", "@local_xla//xla/pjrt:pjrt_api", "@local_xla//xla/pjrt:pjrt_c_api_client", @@ -98,5 +97,6 @@ tf_cc_test( "@local_xla//xla/pjrt/c:pjrt_c_api_wrapper_impl", "@local_xla//xla/pjrt/cpu:cpu_client", "@local_xla//xla/tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) diff --git a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc index 1952364d882776..6f807a83237021 100644 --- a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc +++ b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc @@ -31,12 +31,12 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/tfrt/common/async_value_tensor.h" #include "tensorflow/core/tfrt/common/pjrt_util.h" #include "tsl/platform/casts.h" #include "tsl/platform/status_matchers.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/c/experimental/ops/array_ops.cc b/tensorflow/c/experimental/ops/array_ops.cc index 23deef1d637f2d..f60b1cf0546297 100644 --- a/tensorflow/c/experimental/ops/array_ops.cc +++ b/tensorflow/c/experimental/ops/array_ops.cc @@ -36,9 +36,9 @@ namespace ops { // or value. // // Description: -Status Identity(AbstractContext* ctx, AbstractTensorHandle* const input, - AbstractTensorHandle** output, const char* name, - const char* raw_device_name) { +absl::Status Identity(AbstractContext* ctx, AbstractTensorHandle* const input, + AbstractTensorHandle** output, const char* name, + const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("Identity", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -67,10 +67,10 @@ Status Identity(AbstractContext* ctx, AbstractTensorHandle* const input, // def ApplyG(op, dy, _): // return [None, g(dy)] # Do not backprop to f(x). // ``` -Status IdentityN(AbstractContext* ctx, - absl::Span input, - absl::Span output, const char* name, - const char* raw_device_name) { +absl::Status IdentityN(AbstractContext* ctx, + absl::Span input, + absl::Span output, + const char* name, const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("IdentityN", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -83,9 +83,9 @@ Status IdentityN(AbstractContext* ctx, // Summary: Returns a tensor of zeros with the same shape and type as x. // // Description: -Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* const x, - AbstractTensorHandle** y, const char* name, - const char* raw_device_name) { +absl::Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle** y, const char* name, + const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("ZerosLike", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -107,9 +107,9 @@ Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* const x, // # 't' is [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]]] // shape(t) ==> [2, 2, 3] // ``` -Status Shape(AbstractContext* ctx, AbstractTensorHandle* const input, - AbstractTensorHandle** output, DataType out_type, const char* name, - const char* raw_device_name) { +absl::Status Shape(AbstractContext* ctx, AbstractTensorHandle* const input, + AbstractTensorHandle** output, DataType out_type, + const char* name, const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("Shape", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -153,10 +153,10 @@ Status Shape(AbstractContext* ctx, AbstractTensorHandle* const input, // // This operation is related to `squeeze()`, which removes dimensions of // size 1. -Status ExpandDims(AbstractContext* ctx, AbstractTensorHandle* const input, - AbstractTensorHandle* const dim, - AbstractTensorHandle** output, const char* name, - const char* raw_device_name) { +absl::Status ExpandDims(AbstractContext* ctx, AbstractTensorHandle* const input, + AbstractTensorHandle* const dim, + AbstractTensorHandle** output, const char* name, + const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("ExpandDims", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -170,9 +170,9 @@ Status ExpandDims(AbstractContext* ctx, AbstractTensorHandle* const input, // Summary: Returns a tensor of ones with the same shape and type as x. // // Description: -Status OnesLike(AbstractContext* ctx, AbstractTensorHandle* const x, - AbstractTensorHandle** y, const char* name, - const char* raw_device_name) { +absl::Status OnesLike(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle** y, const char* name, + const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("OnesLike", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); diff --git a/tensorflow/c/experimental/ops/array_ops.h b/tensorflow/c/experimental/ops/array_ops.h index 466c36f1dde8ae..af4a46cf58b1d3 100644 --- a/tensorflow/c/experimental/ops/array_ops.h +++ b/tensorflow/c/experimental/ops/array_ops.h @@ -29,37 +29,39 @@ namespace ops { // Return a tensor with the same shape and contents as the input tensor or // value. -Status Identity(AbstractContext* ctx, AbstractTensorHandle* const input, - AbstractTensorHandle** output, const char* name = nullptr, - const char* raw_device_name = nullptr); +absl::Status Identity(AbstractContext* ctx, AbstractTensorHandle* const input, + AbstractTensorHandle** output, const char* name = nullptr, + const char* raw_device_name = nullptr); // Returns a list of tensors with the same shapes and contents as the input -Status IdentityN(AbstractContext* ctx, - absl::Span input, - absl::Span output, - const char* name = nullptr, - const char* raw_device_name = nullptr); +absl::Status IdentityN(AbstractContext* ctx, + absl::Span input, + absl::Span output, + const char* name = nullptr, + const char* raw_device_name = nullptr); // Returns a tensor of zeros with the same shape and type as x. -Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* const x, - AbstractTensorHandle** y, const char* name = nullptr, - const char* raw_device_name = nullptr); +absl::Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle** y, const char* name = nullptr, + const char* raw_device_name = nullptr); // Returns the shape of a tensor. -Status Shape(AbstractContext* ctx, AbstractTensorHandle* const input, - AbstractTensorHandle** output, DataType out_type = DT_INT32, - const char* name = nullptr, const char* raw_device_name = nullptr); +absl::Status Shape(AbstractContext* ctx, AbstractTensorHandle* const input, + AbstractTensorHandle** output, DataType out_type = DT_INT32, + const char* name = nullptr, + const char* raw_device_name = nullptr); // Inserts a dimension of 1 into a tensor's shape. -Status ExpandDims(AbstractContext* ctx, AbstractTensorHandle* const input, - AbstractTensorHandle* const dim, - AbstractTensorHandle** output, const char* name = nullptr, - const char* raw_device_name = nullptr); +absl::Status ExpandDims(AbstractContext* ctx, AbstractTensorHandle* const input, + AbstractTensorHandle* const dim, + AbstractTensorHandle** output, + const char* name = nullptr, + const char* raw_device_name = nullptr); // Returns a tensor of ones with the same shape and type as x. -Status OnesLike(AbstractContext* ctx, AbstractTensorHandle* const x, - AbstractTensorHandle** y, const char* name = nullptr, - const char* raw_device_name = nullptr); +absl::Status OnesLike(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle** y, const char* name = nullptr, + const char* raw_device_name = nullptr); } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/gen/cpp/BUILD b/tensorflow/c/experimental/ops/gen/cpp/BUILD index d2fd0294adbc2e..36919e7083306e 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/BUILD +++ b/tensorflow/c/experimental/ops/gen/cpp/BUILD @@ -20,6 +20,7 @@ cc_library( deps = [ "//tensorflow/c/experimental/ops/gen/common", "//tensorflow/c/experimental/ops/gen/cpp/renderers", + "//tensorflow/c/experimental/ops/gen/cpp/views", "//tensorflow/c/experimental/ops/gen/model", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -39,11 +40,14 @@ tf_cc_test( data = ["//tensorflow/c/experimental/ops/gen/cpp/golden"], deps = [ ":cpp", + "//tensorflow/c/experimental/ops/gen/common", + "//tensorflow/c/experimental/ops/gen/cpp/renderers", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "@local_tsl//tsl/platform:status", ], ) diff --git a/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.cc b/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.cc index 9e8aede1a21906..82368608201ebd 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.cc @@ -14,8 +14,15 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/cpp/cpp_generator.h" +#include "tensorflow/c/experimental/ops/gen/common/path_config.h" +#include "tensorflow/c/experimental/ops/gen/common/source_code.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.h" #include "tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_file_renderer.h" -#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" +#include "tensorflow/c/experimental/ops/gen/cpp/views/op_view.h" +#include "tensorflow/c/experimental/ops/gen/model/op_spec.h" +#include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.h b/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.h index 2b2857943ef456..0a7b08cd9b171f 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.h +++ b/tensorflow/c/experimental/ops/gen/cpp/cpp_generator.h @@ -16,8 +16,11 @@ limitations under the License. #define TENSORFLOW_C_EXPERIMENTAL_OPS_GEN_CPP_CPP_GENERATOR_H_ #include "tensorflow/c/experimental/ops/gen/common/controller.h" +#include "tensorflow/c/experimental/ops/gen/common/path_config.h" +#include "tensorflow/c/experimental/ops/gen/common/source_code.h" #include "tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.h" #include "tensorflow/c/experimental/ops/gen/cpp/renderers/renderer_context.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/cpp/cpp_generator_test.cc b/tensorflow/c/experimental/ops/gen/cpp/cpp_generator_test.cc index cf453ba75188de..6d33a389ad88aa 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/cpp_generator_test.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/cpp_generator_test.cc @@ -16,9 +16,13 @@ limitations under the License. #include -#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/c/experimental/ops/gen/common/path_config.h" +#include "tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/status.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/cpp/views/BUILD b/tensorflow/c/experimental/ops/gen/cpp/views/BUILD index ea6d23c8b16917..fd8194d584d32b 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/views/BUILD +++ b/tensorflow/c/experimental/ops/gen/cpp/views/BUILD @@ -11,7 +11,10 @@ cc_library( exclude = ["*_test.cc"], ), hdrs = glob(["*.h"]), - visibility = ["//tensorflow/c/experimental/ops/gen/cpp/renderers:__pkg__"], + visibility = [ + "//tensorflow/c/experimental/ops/gen/cpp:__pkg__", + "//tensorflow/c/experimental/ops/gen/cpp/renderers:__pkg__", + ], deps = [ "//tensorflow/c/experimental/ops/gen/common", "//tensorflow/c/experimental/ops/gen/model", diff --git a/tensorflow/c/experimental/ops/io_ops.cc b/tensorflow/c/experimental/ops/io_ops.cc index 4896c57458a347..920d82cf1be3ec 100644 --- a/tensorflow/c/experimental/ops/io_ops.cc +++ b/tensorflow/c/experimental/ops/io_ops.cc @@ -50,12 +50,12 @@ namespace ops { // // Callers must ensure all the named tensors are indeed stored in the // checkpoint. -Status RestoreV2(AbstractContext* ctx, AbstractTensorHandle* const prefix, - AbstractTensorHandle* const tensor_names, - AbstractTensorHandle* const shape_and_slices, - absl::Span tensors, - absl::Span dtypes, const char* name, - const char* raw_device_name) { +absl::Status RestoreV2(AbstractContext* ctx, AbstractTensorHandle* const prefix, + AbstractTensorHandle* const tensor_names, + AbstractTensorHandle* const shape_and_slices, + absl::Span tensors, + absl::Span dtypes, const char* name, + const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("RestoreV2", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -75,11 +75,11 @@ Status RestoreV2(AbstractContext* ctx, AbstractTensorHandle* const prefix, // By default, saves the named tensors in full. If the caller wishes to save // specific slices of full tensors, "shape_and_slices" should be non-empty // strings and correspondingly well-formed. -Status SaveV2(AbstractContext* ctx, AbstractTensorHandle* const prefix, - AbstractTensorHandle* const tensor_names, - AbstractTensorHandle* const shape_and_slices, - absl::Span tensors, const char* name, - const char* raw_device_name) { +absl::Status SaveV2(AbstractContext* ctx, AbstractTensorHandle* const prefix, + AbstractTensorHandle* const tensor_names, + AbstractTensorHandle* const shape_and_slices, + absl::Span tensors, + const char* name, const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("SaveV2", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); diff --git a/tensorflow/c/experimental/ops/io_ops.h b/tensorflow/c/experimental/ops/io_ops.h index 4160ac36a93428..ceccddad5ea188 100644 --- a/tensorflow/c/experimental/ops/io_ops.h +++ b/tensorflow/c/experimental/ops/io_ops.h @@ -28,20 +28,20 @@ namespace tensorflow { namespace ops { // Restores tensors from a V2 checkpoint. -Status RestoreV2(AbstractContext* ctx, AbstractTensorHandle* const prefix, - AbstractTensorHandle* const tensor_names, - AbstractTensorHandle* const shape_and_slices, - absl::Span tensors, - absl::Span dtypes, const char* name = nullptr, - const char* raw_device_name = nullptr); +absl::Status RestoreV2(AbstractContext* ctx, AbstractTensorHandle* const prefix, + AbstractTensorHandle* const tensor_names, + AbstractTensorHandle* const shape_and_slices, + absl::Span tensors, + absl::Span dtypes, const char* name = nullptr, + const char* raw_device_name = nullptr); // Saves tensors in V2 checkpoint format. -Status SaveV2(AbstractContext* ctx, AbstractTensorHandle* const prefix, - AbstractTensorHandle* const tensor_names, - AbstractTensorHandle* const shape_and_slices, - absl::Span tensors, - const char* name = nullptr, - const char* raw_device_name = nullptr); +absl::Status SaveV2(AbstractContext* ctx, AbstractTensorHandle* const prefix, + AbstractTensorHandle* const tensor_names, + AbstractTensorHandle* const shape_and_slices, + absl::Span tensors, + const char* name = nullptr, + const char* raw_device_name = nullptr); } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/math_ops.cc b/tensorflow/c/experimental/ops/math_ops.cc index 98887da0a20959..2a2ea0f26534b9 100644 --- a/tensorflow/c/experimental/ops/math_ops.cc +++ b/tensorflow/c/experimental/ops/math_ops.cc @@ -36,9 +36,9 @@ namespace ops { // Description: // *NOTE*: `Multiply` supports broadcasting. More about broadcasting // [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -Status Mul(AbstractContext* ctx, AbstractTensorHandle* const x, - AbstractTensorHandle* const y, AbstractTensorHandle** z, - const char* name, const char* raw_device_name) { +absl::Status Mul(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle* const y, AbstractTensorHandle** z, + const char* name, const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("Mul", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -66,9 +66,9 @@ Status Mul(AbstractContext* ctx, AbstractTensorHandle* const x, // # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j] // tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j] // ``` -Status Conj(AbstractContext* ctx, AbstractTensorHandle* const input, - AbstractTensorHandle** output, const char* name, - const char* raw_device_name) { +absl::Status Conj(AbstractContext* ctx, AbstractTensorHandle* const input, + AbstractTensorHandle** output, const char* name, + const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("Conj", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -84,9 +84,9 @@ Status Conj(AbstractContext* ctx, AbstractTensorHandle* const input, // *NOTE*: `Add` supports broadcasting. `AddN` does not. More about // broadcasting // [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -Status AddV2(AbstractContext* ctx, AbstractTensorHandle* const x, - AbstractTensorHandle* const y, AbstractTensorHandle** z, - const char* name, const char* raw_device_name) { +absl::Status AddV2(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle* const y, AbstractTensorHandle** z, + const char* name, const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("AddV2", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -107,10 +107,11 @@ Status AddV2(AbstractContext* ctx, AbstractTensorHandle* const x, // // *Note*: The default kernel implementation for MatMul on GPUs uses // cublas. -Status MatMul(AbstractContext* ctx, AbstractTensorHandle* const a, - AbstractTensorHandle* const b, AbstractTensorHandle** product, - bool transpose_a, bool transpose_b, const char* name, - const char* raw_device_name) { +absl::Status MatMul(AbstractContext* ctx, AbstractTensorHandle* const a, + AbstractTensorHandle* const b, + AbstractTensorHandle** product, bool transpose_a, + bool transpose_b, const char* name, + const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("MatMul", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -127,9 +128,9 @@ Status MatMul(AbstractContext* ctx, AbstractTensorHandle* const a, // // Description: // I.e., \\(y = -x\\). -Status Neg(AbstractContext* ctx, AbstractTensorHandle* const x, - AbstractTensorHandle** y, const char* name, - const char* raw_device_name) { +absl::Status Neg(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle** y, const char* name, + const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("Neg", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -146,10 +147,10 @@ Status Neg(AbstractContext* ctx, AbstractTensorHandle* const x, // `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry // in `axis`. If `keep_dims` is true, the reduced dimensions are retained with // length 1. -Status Sum(AbstractContext* ctx, AbstractTensorHandle* const input, - AbstractTensorHandle* const reduction_indices, - AbstractTensorHandle** output, bool keep_dims, const char* name, - const char* raw_device_name) { +absl::Status Sum(AbstractContext* ctx, AbstractTensorHandle* const input, + AbstractTensorHandle* const reduction_indices, + AbstractTensorHandle** output, bool keep_dims, + const char* name, const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("Sum", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -166,9 +167,9 @@ Status Sum(AbstractContext* ctx, AbstractTensorHandle* const input, // Description: // *NOTE*: `Subtract` supports broadcasting. More about broadcasting // [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -Status Sub(AbstractContext* ctx, AbstractTensorHandle* const x, - AbstractTensorHandle* const y, AbstractTensorHandle** z, - const char* name, const char* raw_device_name) { +absl::Status Sub(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle* const y, AbstractTensorHandle** z, + const char* name, const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("Sub", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -184,9 +185,9 @@ Status Sub(AbstractContext* ctx, AbstractTensorHandle* const x, // Description: // *NOTE*: `Div` supports broadcasting. More about broadcasting // [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -Status Div(AbstractContext* ctx, AbstractTensorHandle* const x, - AbstractTensorHandle* const y, AbstractTensorHandle** z, - const char* name, const char* raw_device_name) { +absl::Status Div(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle* const y, AbstractTensorHandle** z, + const char* name, const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("Div", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -203,9 +204,9 @@ Status Div(AbstractContext* ctx, AbstractTensorHandle* const x, // // *NOTE*: `DivNoNan` supports broadcasting. More about broadcasting // [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -Status DivNoNan(AbstractContext* ctx, AbstractTensorHandle* const x, - AbstractTensorHandle* const y, AbstractTensorHandle** z, - const char* name, const char* raw_device_name) { +absl::Status DivNoNan(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle* const y, AbstractTensorHandle** z, + const char* name, const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("DivNoNan", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -245,9 +246,9 @@ Status DivNoNan(AbstractContext* ctx, AbstractTensorHandle* const x, // x = tf.constant(1 + 1j) // tf.math.exp(x) ==> 1.4686939399158851+2.2873552871788423j // ``` -Status Exp(AbstractContext* ctx, AbstractTensorHandle* const x, - AbstractTensorHandle** y, const char* name, - const char* raw_device_name) { +absl::Status Exp(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle** y, const char* name, + const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("Exp", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -261,9 +262,9 @@ Status Exp(AbstractContext* ctx, AbstractTensorHandle* const x, // // Description: // I.e., \\(y = \sqrt{x} = x^{1/2}\\). -Status Sqrt(AbstractContext* ctx, AbstractTensorHandle* const x, - AbstractTensorHandle** y, const char* name, - const char* raw_device_name) { +absl::Status Sqrt(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle** y, const char* name, + const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("Sqrt", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -278,9 +279,9 @@ Status Sqrt(AbstractContext* ctx, AbstractTensorHandle* const x, // Description: // Specifically, `grad = dy * 0.5 / y`, where `y = sqrt(x)`, and `dy` // is the corresponding input gradient. -Status SqrtGrad(AbstractContext* ctx, AbstractTensorHandle* const y, - AbstractTensorHandle* const dy, AbstractTensorHandle** z, - const char* name, const char* raw_device_name) { +absl::Status SqrtGrad(AbstractContext* ctx, AbstractTensorHandle* const y, + AbstractTensorHandle* const dy, AbstractTensorHandle** z, + const char* name, const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("SqrtGrad", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -302,9 +303,9 @@ Status SqrtGrad(AbstractContext* ctx, AbstractTensorHandle* const y, // x = tf.constant([0, 0.5, 1, 5]) // tf.math.log1p(x) ==> [0., 0.4054651, 0.6931472, 1.7917595] // ``` -Status Log1p(AbstractContext* ctx, AbstractTensorHandle* const x, - AbstractTensorHandle** y, const char* name, - const char* raw_device_name) { +absl::Status Log1p(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle** y, const char* name, + const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("Log1p", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); diff --git a/tensorflow/c/experimental/ops/math_ops.h b/tensorflow/c/experimental/ops/math_ops.h index 612640df89a8a8..c7cde54acad483 100644 --- a/tensorflow/c/experimental/ops/math_ops.h +++ b/tensorflow/c/experimental/ops/math_ops.h @@ -26,74 +26,79 @@ namespace tensorflow { namespace ops { // Returns x * y element-wise. -Status Mul(AbstractContext* ctx, AbstractTensorHandle* const x, - AbstractTensorHandle* const y, AbstractTensorHandle** z, - const char* name = nullptr, const char* raw_device_name = nullptr); +absl::Status Mul(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle* const y, AbstractTensorHandle** z, + const char* name = nullptr, + const char* raw_device_name = nullptr); // Returns the complex conjugate of a complex number. -Status Conj(AbstractContext* ctx, AbstractTensorHandle* const input, - AbstractTensorHandle** output, const char* name = nullptr, - const char* raw_device_name = nullptr); +absl::Status Conj(AbstractContext* ctx, AbstractTensorHandle* const input, + AbstractTensorHandle** output, const char* name = nullptr, + const char* raw_device_name = nullptr); // Returns x + y element-wise. -Status AddV2(AbstractContext* ctx, AbstractTensorHandle* const x, - AbstractTensorHandle* const y, AbstractTensorHandle** z, - const char* name = nullptr, const char* raw_device_name = nullptr); +absl::Status AddV2(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle* const y, AbstractTensorHandle** z, + const char* name = nullptr, + const char* raw_device_name = nullptr); // Multiply the matrix "a" by the matrix "b". -Status MatMul(AbstractContext* ctx, AbstractTensorHandle* const a, - AbstractTensorHandle* const b, AbstractTensorHandle** product, - bool transpose_a = false, bool transpose_b = false, - const char* name = nullptr, - const char* raw_device_name = nullptr); +absl::Status MatMul(AbstractContext* ctx, AbstractTensorHandle* const a, + AbstractTensorHandle* const b, + AbstractTensorHandle** product, bool transpose_a = false, + bool transpose_b = false, const char* name = nullptr, + const char* raw_device_name = nullptr); // Computes numerical negative value element-wise. -Status Neg(AbstractContext* ctx, AbstractTensorHandle* const x, - AbstractTensorHandle** y, const char* name = nullptr, - const char* raw_device_name = nullptr); +absl::Status Neg(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle** y, const char* name = nullptr, + const char* raw_device_name = nullptr); // Computes the sum of elements across dimensions of a tensor. -Status Sum(AbstractContext* ctx, AbstractTensorHandle* const input, - AbstractTensorHandle* const reduction_indices, - AbstractTensorHandle** output, bool keep_dims = false, - const char* name = nullptr, const char* raw_device_name = nullptr); +absl::Status Sum(AbstractContext* ctx, AbstractTensorHandle* const input, + AbstractTensorHandle* const reduction_indices, + AbstractTensorHandle** output, bool keep_dims = false, + const char* name = nullptr, + const char* raw_device_name = nullptr); // Returns x - y element-wise. -Status Sub(AbstractContext* ctx, AbstractTensorHandle* const x, - AbstractTensorHandle* const y, AbstractTensorHandle** z, - const char* name = nullptr, const char* raw_device_name = nullptr); +absl::Status Sub(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle* const y, AbstractTensorHandle** z, + const char* name = nullptr, + const char* raw_device_name = nullptr); // Returns x / y element-wise. -Status Div(AbstractContext* ctx, AbstractTensorHandle* const x, - AbstractTensorHandle* const y, AbstractTensorHandle** z, - const char* name = nullptr, const char* raw_device_name = nullptr); +absl::Status Div(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle* const y, AbstractTensorHandle** z, + const char* name = nullptr, + const char* raw_device_name = nullptr); // Returns 0 if the denominator is zero. -Status DivNoNan(AbstractContext* ctx, AbstractTensorHandle* const x, - AbstractTensorHandle* const y, AbstractTensorHandle** z, - const char* name = nullptr, - const char* raw_device_name = nullptr); +absl::Status DivNoNan(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle* const y, AbstractTensorHandle** z, + const char* name = nullptr, + const char* raw_device_name = nullptr); // Computes exponential of x element-wise. \\(y = e^x\\). -Status Exp(AbstractContext* ctx, AbstractTensorHandle* const x, - AbstractTensorHandle** y, const char* name = nullptr, - const char* raw_device_name = nullptr); +absl::Status Exp(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle** y, const char* name = nullptr, + const char* raw_device_name = nullptr); // Computes square root of x element-wise. -Status Sqrt(AbstractContext* ctx, AbstractTensorHandle* const x, - AbstractTensorHandle** y, const char* name = nullptr, - const char* raw_device_name = nullptr); +absl::Status Sqrt(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle** y, const char* name = nullptr, + const char* raw_device_name = nullptr); // Computes the gradient for the sqrt of `x` wrt its input. -Status SqrtGrad(AbstractContext* ctx, AbstractTensorHandle* const y, - AbstractTensorHandle* const dy, AbstractTensorHandle** z, - const char* name = nullptr, - const char* raw_device_name = nullptr); +absl::Status SqrtGrad(AbstractContext* ctx, AbstractTensorHandle* const y, + AbstractTensorHandle* const dy, AbstractTensorHandle** z, + const char* name = nullptr, + const char* raw_device_name = nullptr); // Computes natural logarithm of (1 + x) element-wise. -Status Log1p(AbstractContext* ctx, AbstractTensorHandle* const x, - AbstractTensorHandle** y, const char* name = nullptr, - const char* raw_device_name = nullptr); +absl::Status Log1p(AbstractContext* ctx, AbstractTensorHandle* const x, + AbstractTensorHandle** y, const char* name = nullptr, + const char* raw_device_name = nullptr); } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/nn_ops.cc b/tensorflow/c/experimental/ops/nn_ops.cc index f0eed05ce0dd8e..6be53fb7fe0bf5 100644 --- a/tensorflow/c/experimental/ops/nn_ops.cc +++ b/tensorflow/c/experimental/ops/nn_ops.cc @@ -40,13 +40,11 @@ namespace ops { // given row. // // Inputs are the logits, not probabilities. -Status SparseSoftmaxCrossEntropyWithLogits(AbstractContext* ctx, - AbstractTensorHandle* const features, - AbstractTensorHandle* const labels, - AbstractTensorHandle** loss, - AbstractTensorHandle** backprop, - const char* name, - const char* raw_device_name) { +absl::Status SparseSoftmaxCrossEntropyWithLogits( + AbstractContext* ctx, AbstractTensorHandle* const features, + AbstractTensorHandle* const labels, AbstractTensorHandle** loss, + AbstractTensorHandle** backprop, const char* name, + const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR( op_ptr->Reset("SparseSoftmaxCrossEntropyWithLogits", raw_device_name)); @@ -55,7 +53,7 @@ Status SparseSoftmaxCrossEntropyWithLogits(AbstractContext* ctx, TF_RETURN_IF_ERROR(op_ptr->AddInput(labels)); int num_retvals = 2; AbstractTensorHandle* temp_outputs[2]; - Status status = op_ptr->Execute(temp_outputs, &num_retvals); + absl::Status status = op_ptr->Execute(temp_outputs, &num_retvals); *loss = temp_outputs[0]; *backprop = temp_outputs[1]; return status; @@ -65,10 +63,11 @@ Status SparseSoftmaxCrossEntropyWithLogits(AbstractContext* ctx, // Summary: Computes rectified linear gradients for a Relu operation. // // Description: -Status ReluGrad(AbstractContext* ctx, AbstractTensorHandle* const gradients, - AbstractTensorHandle* const features, - AbstractTensorHandle** backprops, const char* name, - const char* raw_device_name) { +absl::Status ReluGrad(AbstractContext* ctx, + AbstractTensorHandle* const gradients, + AbstractTensorHandle* const features, + AbstractTensorHandle** backprops, const char* name, + const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("ReluGrad", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -86,9 +85,9 @@ Status ReluGrad(AbstractContext* ctx, AbstractTensorHandle* const gradients, // Example usage: // >>> tf.nn.relu([-2., 0., 3.]).numpy() // array([0., 0., 3.], dtype=float32) -Status Relu(AbstractContext* ctx, AbstractTensorHandle* const features, - AbstractTensorHandle** activations, const char* name, - const char* raw_device_name) { +absl::Status Relu(AbstractContext* ctx, AbstractTensorHandle* const features, + AbstractTensorHandle** activations, const char* name, + const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("Relu", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -103,10 +102,10 @@ Status Relu(AbstractContext* ctx, AbstractTensorHandle* const features, // Description: // This is a special case of `tf.add` where `bias` is restricted to be 1-D. // Broadcasting is supported, so `value` may have any number of dimensions. -Status BiasAdd(AbstractContext* ctx, AbstractTensorHandle* const value, - AbstractTensorHandle* const bias, AbstractTensorHandle** output, - const char* data_format, const char* name, - const char* raw_device_name) { +absl::Status BiasAdd(AbstractContext* ctx, AbstractTensorHandle* const value, + AbstractTensorHandle* const bias, + AbstractTensorHandle** output, const char* data_format, + const char* name, const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("BiasAdd", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -125,10 +124,10 @@ Status BiasAdd(AbstractContext* ctx, AbstractTensorHandle* const value, // It accumulates all the values from out_backprop into the feature dimension. // For NHWC data format, the feature dimension is the last. For NCHW data // format, the feature dimension is the third-to-last. -Status BiasAddGrad(AbstractContext* ctx, - AbstractTensorHandle* const out_backprop, - AbstractTensorHandle** output, const char* data_format, - const char* name, const char* raw_device_name) { +absl::Status BiasAddGrad(AbstractContext* ctx, + AbstractTensorHandle* const out_backprop, + AbstractTensorHandle** output, const char* data_format, + const char* name, const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("BiasAddGrad", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); diff --git a/tensorflow/c/experimental/ops/nn_ops.h b/tensorflow/c/experimental/ops/nn_ops.h index fd145a3b0b8161..204ed13a3ba9fd 100644 --- a/tensorflow/c/experimental/ops/nn_ops.h +++ b/tensorflow/c/experimental/ops/nn_ops.h @@ -26,35 +26,41 @@ namespace tensorflow { namespace ops { // Computes softmax cross entropy cost and gradients to backpropagate. -Status SparseSoftmaxCrossEntropyWithLogits( +absl::Status SparseSoftmaxCrossEntropyWithLogits( AbstractContext* ctx, AbstractTensorHandle* const features, AbstractTensorHandle* const labels, AbstractTensorHandle** loss, AbstractTensorHandle** backprop, const char* name = nullptr, const char* raw_device_name = nullptr); // Computes rectified linear gradients for a Relu operation. -Status ReluGrad(AbstractContext* ctx, AbstractTensorHandle* const gradients, - AbstractTensorHandle* const features, - AbstractTensorHandle** backprops, const char* name = nullptr, - const char* raw_device_name = nullptr); +absl::Status ReluGrad(AbstractContext* ctx, + AbstractTensorHandle* const gradients, + AbstractTensorHandle* const features, + AbstractTensorHandle** backprops, + const char* name = nullptr, + const char* raw_device_name = nullptr); // Computes rectified linear: `max(features, 0)`. -Status Relu(AbstractContext* ctx, AbstractTensorHandle* const features, - AbstractTensorHandle** activations, const char* name = nullptr, - const char* raw_device_name = nullptr); +absl::Status Relu(AbstractContext* ctx, AbstractTensorHandle* const features, + AbstractTensorHandle** activations, + const char* name = nullptr, + const char* raw_device_name = nullptr); // Adds `bias` to `value`. -Status BiasAdd(AbstractContext* ctx, AbstractTensorHandle* const value, - AbstractTensorHandle* const bias, AbstractTensorHandle** output, - const char* data_format = "NHWC", const char* name = nullptr, - const char* raw_device_name = nullptr); +absl::Status BiasAdd(AbstractContext* ctx, AbstractTensorHandle* const value, + AbstractTensorHandle* const bias, + AbstractTensorHandle** output, + const char* data_format = "NHWC", + const char* name = nullptr, + const char* raw_device_name = nullptr); // The backward operation for "BiasAdd" on the "bias" tensor. -Status BiasAddGrad(AbstractContext* ctx, - AbstractTensorHandle* const out_backprop, - AbstractTensorHandle** output, - const char* data_format = "NHWC", const char* name = nullptr, - const char* raw_device_name = nullptr); +absl::Status BiasAddGrad(AbstractContext* ctx, + AbstractTensorHandle* const out_backprop, + AbstractTensorHandle** output, + const char* data_format = "NHWC", + const char* name = nullptr, + const char* raw_device_name = nullptr); } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/resource_variable_ops.cc b/tensorflow/c/experimental/ops/resource_variable_ops.cc index d9d480c2fcb3f4..68304ebff5bbbe 100644 --- a/tensorflow/c/experimental/ops/resource_variable_ops.cc +++ b/tensorflow/c/experimental/ops/resource_variable_ops.cc @@ -37,11 +37,11 @@ namespace ops { // Summary: Creates a handle to a Variable resource. // // Description: -Status VarHandleOp(AbstractContext* ctx, AbstractTensorHandle** resource, - DataType dtype, const PartialTensorShape shape, - const char* container, const char* shared_name, - absl::Span allowed_devices, const char* name, - const char* raw_device_name) { +absl::Status VarHandleOp(AbstractContext* ctx, AbstractTensorHandle** resource, + DataType dtype, const PartialTensorShape shape, + const char* container, const char* shared_name, + absl::Span allowed_devices, + const char* name, const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("VarHandleOp", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -67,10 +67,10 @@ Status VarHandleOp(AbstractContext* ctx, AbstractTensorHandle** resource, // the writes on which this operation depends directly or indirectly, and to // not be influenced by any of the writes which depend directly or indirectly // on this operation. -Status ReadVariableOp(AbstractContext* ctx, - AbstractTensorHandle* const resource, - AbstractTensorHandle** value, DataType dtype, - const char* name, const char* raw_device_name) { +absl::Status ReadVariableOp(AbstractContext* ctx, + AbstractTensorHandle* const resource, + AbstractTensorHandle** value, DataType dtype, + const char* name, const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("ReadVariableOp", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -86,10 +86,11 @@ Status ReadVariableOp(AbstractContext* ctx, // Description: // Any ReadVariableOp with a control dependency on this op is guaranteed to // return this value or a subsequent newer value of the variable. -Status AssignVariableOp(AbstractContext* ctx, - AbstractTensorHandle* const resource, - AbstractTensorHandle* const value, bool validate_shape, - const char* name, const char* raw_device_name) { +absl::Status AssignVariableOp(AbstractContext* ctx, + AbstractTensorHandle* const resource, + AbstractTensorHandle* const value, + bool validate_shape, const char* name, + const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("AssignVariableOp", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); @@ -107,10 +108,10 @@ Status AssignVariableOp(AbstractContext* ctx, // Description: // All subsequent operations using the resource will result in a NotFound // error status. -Status DestroyResourceOp(AbstractContext* ctx, - AbstractTensorHandle* const resource, - bool ignore_lookup_error, const char* name, - const char* raw_device_name) { +absl::Status DestroyResourceOp(AbstractContext* ctx, + AbstractTensorHandle* const resource, + bool ignore_lookup_error, const char* name, + const char* raw_device_name) { AbstractOperationPtr op_ptr(ctx->CreateOperation()); TF_RETURN_IF_ERROR(op_ptr->Reset("DestroyResourceOp", raw_device_name)); TF_RETURN_IF_ERROR(MaybeSetOpName(op_ptr.get(), name)); diff --git a/tensorflow/c/experimental/ops/resource_variable_ops.h b/tensorflow/c/experimental/ops/resource_variable_ops.h index 4e0eb3db4d9ac2..5ba2b8fdd5656d 100644 --- a/tensorflow/c/experimental/ops/resource_variable_ops.h +++ b/tensorflow/c/experimental/ops/resource_variable_ops.h @@ -30,33 +30,35 @@ namespace tensorflow { namespace ops { // Creates a handle to a Variable resource. -Status VarHandleOp(AbstractContext* ctx, AbstractTensorHandle** resource, - DataType dtype, const PartialTensorShape shape, - const char* container = "", const char* shared_name = "", - absl::Span allowed_devices = {}, - const char* name = nullptr, - const char* raw_device_name = nullptr); +absl::Status VarHandleOp(AbstractContext* ctx, AbstractTensorHandle** resource, + DataType dtype, const PartialTensorShape shape, + const char* container = "", + const char* shared_name = "", + absl::Span allowed_devices = {}, + const char* name = nullptr, + const char* raw_device_name = nullptr); // Reads the value of a variable. -Status ReadVariableOp(AbstractContext* ctx, - AbstractTensorHandle* const resource, - AbstractTensorHandle** value, DataType dtype, - const char* name = nullptr, - const char* raw_device_name = nullptr); +absl::Status ReadVariableOp(AbstractContext* ctx, + AbstractTensorHandle* const resource, + AbstractTensorHandle** value, DataType dtype, + const char* name = nullptr, + const char* raw_device_name = nullptr); // Assigns a new value to a variable. -Status AssignVariableOp(AbstractContext* ctx, - AbstractTensorHandle* const resource, - AbstractTensorHandle* const value, - bool validate_shape = false, const char* name = nullptr, - const char* raw_device_name = nullptr); +absl::Status AssignVariableOp(AbstractContext* ctx, + AbstractTensorHandle* const resource, + AbstractTensorHandle* const value, + bool validate_shape = false, + const char* name = nullptr, + const char* raw_device_name = nullptr); // Deletes the resource specified by the handle. -Status DestroyResourceOp(AbstractContext* ctx, - AbstractTensorHandle* const resource, - bool ignore_lookup_error = true, - const char* name = nullptr, - const char* raw_device_name = nullptr); +absl::Status DestroyResourceOp(AbstractContext* ctx, + AbstractTensorHandle* const resource, + bool ignore_lookup_error = true, + const char* name = nullptr, + const char* raw_device_name = nullptr); } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h b/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h index 116ff80bbcf56a..55af07ad79f4be 100644 --- a/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h +++ b/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h @@ -30,7 +30,7 @@ using TFInitProfilerFn = void (*)(TF_ProfilerRegistrationParams* const, TF_Status* const); // Registers plugin's profiler to TensorFlow's profiler registry. -Status InitPluginProfiler(TFInitProfilerFn init_fn); +absl::Status InitPluginProfiler(TFInitProfilerFn init_fn); } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/c/experimental/stream_executor/BUILD b/tensorflow/c/experimental/stream_executor/BUILD index 9ab1345270dce8..ccfe38e1e689ba 100644 --- a/tensorflow/c/experimental/stream_executor/BUILD +++ b/tensorflow/c/experimental/stream_executor/BUILD @@ -73,6 +73,7 @@ cc_library( "//tensorflow/c:tf_status_helper", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", "@local_xla//xla/stream_executor", "@local_xla//xla/stream_executor:event", diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc index ae50a9ead0fc1b..ff2d6146c5ead1 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc @@ -21,6 +21,7 @@ limitations under the License. // device. #include "tensorflow/c/experimental/stream_executor/stream_executor.h" +#include #include #include #include @@ -203,7 +204,7 @@ class CStreamExecutor : public StreamExecutorCommon { absl::Status Init() override { return absl::OkStatus(); } - DeviceMemoryBase Allocate(uint64 size, int64_t memory_space) override { + DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override { SP_DeviceMemoryBase mem = {SP_DEVICE_MEMORY_BASE_STRUCT_SIZE}; stream_executor_->allocate(&device_, size, memory_space, &mem); absl::Status status = ValidateSPDeviceMemoryBase(mem); @@ -212,7 +213,7 @@ class CStreamExecutor : public StreamExecutorCommon { } return DeviceMemoryBaseFromC(mem); } - DeviceMemoryBase Allocate(uint64 size) { + DeviceMemoryBase Allocate(uint64_t size) { return Allocate(size, /*memory_space=*/0); } @@ -222,7 +223,7 @@ class CStreamExecutor : public StreamExecutorCommon { } absl::StatusOr> HostMemoryAllocate( - uint64 size) override { + uint64_t size) override { auto* buffer = stream_executor_->host_memory_allocate(&device_, size); if (buffer == nullptr && size > 0) { return absl::InternalError( @@ -235,7 +236,7 @@ class CStreamExecutor : public StreamExecutorCommon { stream_executor_->host_memory_deallocate(&device_, mem); } - void* UnifiedMemoryAllocate(uint64 size) override { + void* UnifiedMemoryAllocate(uint64_t size) override { CHECK(stream_executor_->unified_memory_allocate); return stream_executor_->unified_memory_allocate(&device_, size); } @@ -283,14 +284,14 @@ class CStreamExecutor : public StreamExecutorCommon { return true; } absl::Status SynchronousMemZero(DeviceMemoryBase* location, - uint64 size) override { + uint64_t size) override { // TODO(annarev): figure out if we should support memzero/memset // functionality by allocating on host and then copying to device. return tsl::errors::Unimplemented( "SynchronousMemZero is not supported by pluggable device."); } absl::Status SynchronousMemcpy(DeviceMemoryBase* gpu_dst, - const void* host_src, uint64 size) override { + const void* host_src, uint64_t size) override { OwnedTFStatus c_status(TF_NewStatus()); SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(gpu_dst); stream_executor_->sync_memcpy_htod(&device_, &device_memory_base, host_src, @@ -299,7 +300,7 @@ class CStreamExecutor : public StreamExecutorCommon { } absl::Status SynchronousMemcpy(void* host_dst, const DeviceMemoryBase& gpu_src, - uint64 size) override { + uint64_t size) override { OwnedTFStatus c_status(TF_NewStatus()); SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(&gpu_src); stream_executor_->sync_memcpy_dtoh(&device_, host_dst, &device_memory_base, @@ -317,33 +318,6 @@ class CStreamExecutor : public StreamExecutorCommon { return StatusFromTF_Status(c_status.get()); } - absl::Status BlockHostUntilDone(Stream* stream) override { - OwnedTFStatus c_status(TF_NewStatus()); - SP_Stream stream_handle = static_cast(stream)->Handle(); - - // If `block_host_until_done` is set, use it. - if (stream_executor_->block_host_until_done != nullptr) { - stream_executor_->block_host_until_done(&device_, stream_handle, - c_status.get()); - return StatusFromTF_Status(c_status.get()); - } - // Create and record an event and then wait for it. - SP_Event event_handle; - stream_executor_->create_event(&device_, &event_handle, c_status.get()); - TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get())); - stream_executor_->record_event(&device_, stream_handle, event_handle, - c_status.get()); - absl::Status s = StatusFromTF_Status(c_status.get()); - if (!s.ok()) { - stream_executor_->destroy_event(&device_, event_handle); - return s; - } - stream_executor_->block_host_for_event(&device_, event_handle, - c_status.get()); - stream_executor_->destroy_event(&device_, event_handle); - return StatusFromTF_Status(c_status.get()); - } - absl::Status EnablePeerAccessTo(StreamExecutor* other) override { return tsl::errors::Unimplemented( "EnablePeerAccessTo is not supported by pluggable device."); diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h index 769f640d6968d2..b8217ea3cb43f6 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h +++ b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h @@ -35,6 +35,7 @@ limitations under the License. #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_common.h" #include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace stream_executor { @@ -180,7 +181,7 @@ class CStream : public StreamCommon { stream_executor_(stream_executor), stream_handle_(nullptr) {} ~CStream() override { - parent()->BlockHostUntilDone(this).IgnoreError(); + BlockHostUntilDone().IgnoreError(); parent()->DeallocateStream(this); Destroy(); } @@ -210,6 +211,33 @@ class CStream : public StreamCommon { return static_cast(event)->Record(stream_handle_); } + absl::Status BlockHostUntilDone() override { + tensorflow::TF_StatusPtr c_status(TF_NewStatus()); + SP_Stream stream_handle = Handle(); + + // If `block_host_until_done` is set, use it. + if (stream_executor_->block_host_until_done != nullptr) { + stream_executor_->block_host_until_done(device_, stream_handle, + c_status.get()); + return tensorflow::StatusFromTF_Status(c_status.get()); + } + // Create and record an event and then wait for it. + SP_Event event_handle; + stream_executor_->create_event(device_, &event_handle, c_status.get()); + TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get())); + stream_executor_->record_event(device_, stream_handle, event_handle, + c_status.get()); + absl::Status s = tensorflow::StatusFromTF_Status(c_status.get()); + if (!s.ok()) { + stream_executor_->destroy_event(device_, event_handle); + return s; + } + stream_executor_->block_host_for_event(device_, event_handle, + c_status.get()); + stream_executor_->destroy_event(device_, event_handle); + return tensorflow::StatusFromTF_Status(c_status.get()); + } + absl::Status WaitFor(Stream* other) override { tensorflow::TF_StatusPtr c_status(TF_NewStatus()); SP_Stream other_handle = static_cast(other)->Handle(); diff --git a/tensorflow/c/kernels_experimental.cc b/tensorflow/c/kernels_experimental.cc index 02e41428c6ae58..d63e631e657b25 100644 --- a/tensorflow/c/kernels_experimental.cc +++ b/tensorflow/c/kernels_experimental.cc @@ -68,7 +68,7 @@ struct TF_VariableInputLockHolder { std::unique_ptr> shared_locks; }; -tensorflow::Status EnsureSparseVariableAccess( +absl::Status EnsureSparseVariableAccess( TF_OpKernelContext* ctx, bool variantType, void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source, TF_Tensor* dest), @@ -108,7 +108,7 @@ tensorflow::Status EnsureSparseVariableAccess( attr.set_nic_compatible(true); TF_RETURN_IF_ERROR(context->allocate_temp( var->tensor()->dtype(), var->tensor()->shape(), &tmp, attr)); - tensorflow::Status s; + absl::Status s; TF_Tensor* tf_tmp = TF_TensorFromTensor(tmp, &s); TF_Tensor* tf_tensor = TF_TensorFromTensor(*var->tensor(), &s); copyFunc(ctx, tf_tensor, tf_tmp); @@ -118,11 +118,12 @@ tensorflow::Status EnsureSparseVariableAccess( return absl::OkStatus(); } -tensorflow::Status PrepareToUpdateVariable( - TF_OpKernelContext* ctx, tensorflow::Tensor* tensor, bool copy_on_read_mode, - bool variantType, - void (*copyFunc)(TF_OpKernelContext* ctx, TF_Tensor* source, - TF_Tensor* dest)) { +absl::Status PrepareToUpdateVariable(TF_OpKernelContext* ctx, + tensorflow::Tensor* tensor, + bool copy_on_read_mode, bool variantType, + void (*copyFunc)(TF_OpKernelContext* ctx, + TF_Tensor* source, + TF_Tensor* dest)) { auto* context = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); if (copy_on_read_mode || !tensor->RefCountIsOne()) { // Tensor's buffer is in use by some read, so we need to copy before @@ -145,7 +146,7 @@ tensorflow::Status PrepareToUpdateVariable( attr.set_nic_compatible(true); TF_RETURN_IF_ERROR( context->allocate_temp(tensor->dtype(), tensor->shape(), &tmp, attr)); - tensorflow::Status s; + absl::Status s; TF_Tensor* tf_tmp = TF_TensorFromTensor(tmp, &s); TF_Tensor* tf_tensor = TF_TensorFromTensor(*tensor, &s); copyFunc(ctx, tf_tensor, tf_tmp); @@ -209,7 +210,7 @@ void TF_AssignVariable(TF_OpKernelContext* ctx, int input_index, attr.set_nic_compatible(true); OP_REQUIRES_OK(cc_ctx, cc_ctx->allocate_temp(value.dtype(), value.shape(), &tmp, attr)); - tensorflow::Status s; + absl::Status s; TF_Tensor* tf_tmp = TF_TensorFromTensor(tmp, &s); TF_Tensor* tf_value = TF_TensorFromTensor(value, &s); copyFunc(ctx, tf_value, tf_tmp); @@ -232,7 +233,7 @@ void TF_AssignRefVariable(TF_OpKernelContext* ctx, int input_ref_index, auto copy = [copyFunc, ctx](::tensorflow::OpKernelContext* cc_ctx, ::tensorflow::Tensor* lhs, const ::tensorflow::Tensor& rhs) { - ::tensorflow::Status s; + absl::Status s; TF_Tensor* tf_lhs = TF_TensorFromTensor(*lhs, &s); OP_REQUIRES_OK(cc_ctx, s); @@ -282,7 +283,7 @@ void TF_AssignUpdateVariable(TF_OpKernelContext* ctx, int input_index, PrepareToUpdateVariable(ctx, var_tensor, variable->copy_on_read_mode.load(), isVariantType, copyFunc)); - tensorflow::Status s; + absl::Status s; TF_Tensor* tf_var_tensor = TF_TensorFromTensor(*var_tensor, &s); TF_Tensor* tf_value = TF_TensorFromTensor(value, &s); updateFunc(ctx, tf_var_tensor, tf_value, Op); @@ -461,7 +462,7 @@ void TF_GetInputTensorFromVariable(TF_OpKernelContext* ctx, int input, ::tensorflow::Set_TF_Status_from_Status(status, cc_ctx->status()); }); - tensorflow::Status s; + absl::Status s; if (cc_ctx->input_dtype(input) == tensorflow::DT_RESOURCE) { tensorflow::core::RefCountPtr var; OP_REQUIRES_OK( @@ -508,7 +509,7 @@ void TF_GetInputByName(TF_OpKernelContext* ctx, const char* inputName, TF_Tensor** tensor, TF_Status* status) { auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(ctx); const ::tensorflow::Tensor* cc_tensor = nullptr; - tensorflow::Status s = cc_ctx->input(inputName, &cc_tensor); + absl::Status s = cc_ctx->input(inputName, &cc_tensor); if (!s.ok()) { ::tensorflow::Set_TF_Status_from_Status(status, s); @@ -527,7 +528,7 @@ void TF_OpKernelConstruction_GetAttrTensorShape(TF_OpKernelConstruction* ctx, TF_Status* status) { ::tensorflow::TensorShape shape; auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx); - ::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &shape); + absl::Status s = cc_ctx->GetAttr(attr_name, &shape); ::tensorflow::Set_TF_Status_from_Status(status, s); size_t rank = static_cast(shape.dims()); diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc index ea4d60724e0760..b8b8b2f29cfe13 100644 --- a/tensorflow/c/kernels_test.cc +++ b/tensorflow/c/kernels_test.cc @@ -111,12 +111,12 @@ static void MyDeleteFunc(void* kernel) { } namespace tensorflow { -Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); +absl::Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); static std::unique_ptr GetFakeKernel(const char* device_name, const char* op_name, const char* node_name, - Status* status) { + absl::Status* status) { NodeDef def; def.set_op(op_name); def.set_name(node_name); @@ -135,7 +135,7 @@ static std::unique_ptr GetFakeKernel(const char* device_name, static std::unique_ptr GetFakeKernel2(const char* device_name, const char* op_name, const char* node_name, - Status* status) { + absl::Status* status) { NodeDef def; def.set_op(op_name); def.set_name(node_name); @@ -189,7 +189,7 @@ TEST(TestKernel, TestRegisterKernelBuilder) { } { - Status status; + absl::Status status; std::unique_ptr kernel = GetFakeKernel(device_name, op_name, node_name, &status); TF_EXPECT_OK(status); @@ -235,7 +235,7 @@ TEST(TestKernel, TF_RegisterKernelBuilderWithKernelDef) { } { - Status status; + absl::Status status; std::unique_ptr kernel = GetFakeKernel(device_name, op_name, node_name, &status); TF_EXPECT_OK(status); @@ -277,7 +277,7 @@ TEST(TestKernel, TestRegisterAsyncKernelBuilder) { } { - Status status; + absl::Status status; std::unique_ptr kernel = GetFakeKernel(device_name, op_name, node_name, &status); TF_EXPECT_OK(status); @@ -327,7 +327,8 @@ class TestKernelAttr : public ::testing::Test { ~TestKernelAttr() override {} std::unique_ptr GetFakeKernelWithAttr(const char* op_name, - AttrValue v, Status* status) { + AttrValue v, + absl::Status* status) { NodeDef def; def.set_op(op_name); def.set_name("FakeNode"); @@ -347,7 +348,7 @@ class TestKernelAttr : public ::testing::Test { EXPECT_EQ(TF_OK, TF_GetCode(status)); TF_DeleteStatus(status); } - Status status; + absl::Status status; std::unique_ptr kernel = GetFakeKernelWithAttr(op_name, v, &status); TF_EXPECT_OK(status); @@ -873,7 +874,7 @@ TEST(TestKernel, TestInputAndOutputCount) { inputs.emplace_back(); p.inputs = inputs; - Status status; + absl::Status status; std::unique_ptr kernel = GetFakeKernel(device_name, op_name, node_name, &status); TF_EXPECT_OK(status); @@ -1362,7 +1363,7 @@ TEST_F(DeviceKernelOpTest, TestGetKernelInfo) { inputs.emplace_back(&t2_1); inputs.emplace_back(&t2_2); - Status status; + absl::Status status; std::unique_ptr kernel = GetFakeKernel2(device_name, op_name, node_name, &status); TF_EXPECT_OK(status); @@ -1430,7 +1431,7 @@ TEST_F(DeviceKernelOpTest, TestForwardInputOrAllocateOutput) { inputs.emplace_back(); p.inputs = inputs; - Status status; + absl::Status status; std::unique_ptr kernel = GetFakeKernel(device_name, op_name, node_name, &status); TF_EXPECT_OK(status); diff --git a/tensorflow/c/ops.cc b/tensorflow/c/ops.cc index 6cbe75fc7e3ebd..0881b2ab6cb22c 100644 --- a/tensorflow/c/ops.cc +++ b/tensorflow/c/ops.cc @@ -90,11 +90,11 @@ void TF_OpDefinitionBuilderSetShapeInferenceFunction( TF_Status* status)) { auto* cc_builder = reinterpret_cast(builder); cc_builder->SetShapeFn( - [shape_inference_func](InferenceContext* ctx) -> tensorflow::Status { + [shape_inference_func](InferenceContext* ctx) -> absl::Status { TF_Status* c_status = TF_NewStatus(); auto c_ctx = reinterpret_cast(ctx); shape_inference_func(c_ctx, c_status); - tensorflow::Status result = ::tensorflow::StatusFromTF_Status(c_status); + absl::Status result = ::tensorflow::StatusFromTF_Status(c_status); TF_DeleteStatus(c_status); return result; }); diff --git a/tensorflow/c/tf_buffer.cc b/tensorflow/c/tf_buffer.cc index a891f89ed16d0c..287966498845a7 100644 --- a/tensorflow/c/tf_buffer.cc +++ b/tensorflow/c/tf_buffer.cc @@ -56,8 +56,8 @@ TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; } namespace tensorflow { -Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, - TF_Buffer* out) { +absl::Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, + TF_Buffer* out) { if (out->data != nullptr) { return errors::InvalidArgument("Passing non-empty TF_Buffer is invalid."); } @@ -81,8 +81,8 @@ Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, return absl::OkStatus(); } -Status BufferToMessage(const TF_Buffer* in, - tensorflow::protobuf::MessageLite* out) { +absl::Status BufferToMessage(const TF_Buffer* in, + tensorflow::protobuf::MessageLite* out) { if (in == nullptr || !out->ParseFromArray(in->data, in->length)) { return errors::InvalidArgument("Unparseable ", out->GetTypeName(), " proto"); diff --git a/tensorflow/c/tf_buffer_internal.h b/tensorflow/c/tf_buffer_internal.h index 7382e558ef66a8..85436f42294693 100644 --- a/tensorflow/c/tf_buffer_internal.h +++ b/tensorflow/c/tf_buffer_internal.h @@ -24,11 +24,11 @@ limitations under the License. namespace tensorflow { -Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, - TF_Buffer* out); +absl::Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, + TF_Buffer* out); -Status BufferToMessage(const TF_Buffer* in, - tensorflow::protobuf::MessageLite* out); +absl::Status BufferToMessage(const TF_Buffer* in, + tensorflow::protobuf::MessageLite* out); namespace internal { diff --git a/tensorflow/c/tf_tensor.cc b/tensorflow/c/tf_tensor.cc index 96c3fd97344115..a2d46fb51810c0 100644 --- a/tensorflow/c/tf_tensor.cc +++ b/tensorflow/c/tf_tensor.cc @@ -249,8 +249,10 @@ void TensorInterface::SetShape(const int64_t* dims, int num_dims) { tensor_.set_shape(s); } -Status TensorInterface::BitcastFrom(const TensorInterface& from, DataType type, - const int64_t* new_dims, int num_new_dims) { +absl::Status TensorInterface::BitcastFrom(const TensorInterface& from, + DataType type, + const int64_t* new_dims, + int num_new_dims) { tensorflow::TensorShape s; for (int i = 0; i < num_new_dims; ++i) { TF_RETURN_IF_ERROR(s.AddDimWithStatus(new_dims[i])); @@ -258,7 +260,7 @@ Status TensorInterface::BitcastFrom(const TensorInterface& from, DataType type, return tensor_.BitcastFrom(from.tensor_, type, s); } -Status TensorInterface::FromProto(const tensorflow::TensorProto& from) { +absl::Status TensorInterface::FromProto(const tensorflow::TensorProto& from) { bool success = tensor_.FromProto(from); if (success) return absl::OkStatus(); return errors::InvalidArgument("Unparseable tensor proto"); @@ -295,7 +297,7 @@ static TF_Tensor* EmptyTensor(TF_DataType dtype, namespace tensorflow { AbstractTensorInterface* TensorInterfaceFromTensor(const Tensor& src, - Status* status) { + absl::Status* status) { *status = absl::OkStatus(); if (!src.IsInitialized()) { *status = FailedPrecondition( @@ -318,12 +320,13 @@ AbstractTensorInterface* TensorInterfaceFromTensor(const Tensor& src, } // Non-static for testing. -TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) { +TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, + absl::Status* status) { return new TF_Tensor{TensorInterfaceFromTensor(src, status)}; } TF_Tensor* TF_TensorFromTensorShallow(const tensorflow::Tensor& src, - Status* status) { + absl::Status* status) { *status = absl::OkStatus(); if (!src.IsInitialized()) { *status = FailedPrecondition( @@ -336,12 +339,12 @@ TF_Tensor* TF_TensorFromTensorShallow(const tensorflow::Tensor& src, return new TF_Tensor{new tensorflow::TensorInterface(src)}; } -Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) { +absl::Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) { return tensorflow::down_cast(src->tensor) ->ToTensor(dst); } -Status TensorInterface::ToTensor(tensorflow::Tensor* dst) const { +absl::Status TensorInterface::ToTensor(tensorflow::Tensor* dst) const { *dst = tensor_; return absl::OkStatus(); } diff --git a/tensorflow/c/tf_tensor_helper.h b/tensorflow/c/tf_tensor_helper.h index 5201e39e0a1f8b..b77d5a78a6270f 100644 --- a/tensorflow/c/tf_tensor_helper.h +++ b/tensorflow/c/tf_tensor_helper.h @@ -25,11 +25,11 @@ namespace tensorflow { class Tensor; -Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); +absl::Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); -TF_Tensor* TF_TensorFromTensor(const Tensor& src, Status* status); +TF_Tensor* TF_TensorFromTensor(const Tensor& src, absl::Status* status); -TF_Tensor* TF_TensorFromTensorShallow(const Tensor& src, Status* status); +TF_Tensor* TF_TensorFromTensorShallow(const Tensor& src, absl::Status* status); namespace internal { diff --git a/tensorflow/c/tf_tensor_internal.h b/tensorflow/c/tf_tensor_internal.h index e68bcebd4cd014..61bceee5d5ab4a 100644 --- a/tensorflow/c/tf_tensor_internal.h +++ b/tensorflow/c/tf_tensor_internal.h @@ -113,10 +113,10 @@ class TensorInterface : public AbstractTensorInterface { std::string SummarizeValue() const override; void SetShape(const int64_t* dims, int num_dims); - Status ToTensor(tensorflow::Tensor* dst) const; - Status BitcastFrom(const TensorInterface& from, DataType type, - const int64_t* new_dims, int num_new_dims); - Status FromProto(const tensorflow::TensorProto& from); + absl::Status ToTensor(tensorflow::Tensor* dst) const; + absl::Status BitcastFrom(const TensorInterface& from, DataType type, + const int64_t* new_dims, int num_new_dims); + absl::Status FromProto(const tensorflow::TensorProto& from); tensorflow::Tensor& Tensor() { return tensor_; } @@ -129,7 +129,7 @@ inline Tensor& TensorFromInterface(AbstractTensorInterface* tensor) { } AbstractTensorInterface* TensorInterfaceFromTensor(const Tensor& src, - Status* status); + absl::Status* status); } // namespace tensorflow diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index de384d01460b0a..45c8258eacf7d5 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -917,7 +917,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/log", "@com_google_absl//absl/status", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -938,7 +938,7 @@ tf_cc_test( "//tensorflow/core:tensorflow", "//tensorflow/core:test", "//tensorflow/core:test_main", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -951,7 +951,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/status", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -968,7 +968,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "@com_google_absl//absl/status", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) diff --git a/tensorflow/cc/experimental/libexport/BUILD b/tensorflow/cc/experimental/libexport/BUILD index d206c115abea65..117bc64b436864 100644 --- a/tensorflow/cc/experimental/libexport/BUILD +++ b/tensorflow/cc/experimental/libexport/BUILD @@ -8,7 +8,6 @@ package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ - "//tensorflow/cc/experimental/libtf:__subpackages__", "//tensorflow/python/saved_model:__subpackages__", ], licenses = ["notice"], diff --git a/tensorflow/cc/experimental/libexport/load.cc b/tensorflow/cc/experimental/libexport/load.cc index be9319b066d74d..cec8af507a7fbd 100644 --- a/tensorflow/cc/experimental/libexport/load.cc +++ b/tensorflow/cc/experimental/libexport/load.cc @@ -28,7 +28,7 @@ namespace libexport { using protobuf::RepeatedPtrField; -tensorflow::StatusOr TFPackage::Load(const std::string& path) { +absl::StatusOr TFPackage::Load(const std::string& path) { // Load the proto TFPackage tf_package; const string saved_model_pb_path = io::JoinPath(path, kSavedModelFilenamePb); @@ -83,8 +83,7 @@ tensorflow::StatusOr TFPackage::Load(const std::string& path) { return tf_package; } -tensorflow::StatusOr TFPackage::GetVariableCheckpointKey( - int index) { +absl::StatusOr TFPackage::GetVariableCheckpointKey(int index) { // TODO(danielellis): make sure valid index const auto& trackable_object = trackable_object_graph_.nodes(index); const TrackableObjectGraph::TrackableObject::SerializedTensor* @@ -105,7 +104,7 @@ const SavedObjectGraph& TFPackage::GetObjectGraph() { return saved_model_proto_.mutable_meta_graphs(0)->object_graph_def(); } -tensorflow::StatusOr TFPackage::GetGraphDefNode( +absl::StatusOr TFPackage::GetGraphDefNode( std::string name) { const auto& iter = graph_def_nodes_by_name_.find(name); if (iter == graph_def_nodes_by_name_.end()) { diff --git a/tensorflow/cc/experimental/libexport/load.h b/tensorflow/cc/experimental/libexport/load.h index 8ab5019eba45fe..6775f73b5ab8fb 100644 --- a/tensorflow/cc/experimental/libexport/load.h +++ b/tensorflow/cc/experimental/libexport/load.h @@ -42,7 +42,7 @@ namespace libexport { class TFPackage { public: // Load a SavedModel, parsing the associated protobuf for later access. - static tensorflow::StatusOr Load(const std::string& path); + static absl::StatusOr Load(const std::string& path); // Reads and returns a checkpoint key associated with a variable. // @@ -53,7 +53,7 @@ class TFPackage { // checkpoint files by "checkpoint keys". These keys along with dtype and // shape / slice information allow RestoreV2 to look up a variable's value in // the SavedModel and restore it into a tensor. - tensorflow::StatusOr GetVariableCheckpointKey(int index); + absl::StatusOr GetVariableCheckpointKey(int index); // Retrieves the object graph from the SavedModel. // @@ -74,8 +74,7 @@ class TFPackage { // Since we may need to load many constants, we create a hash map of these // names to their corresponding nodes at load time in order to look them up // in constant time. - tensorflow::StatusOr GetGraphDefNode( - std::string name); + absl::StatusOr GetGraphDefNode(std::string name); // Returns a list of function defs in the SavedModel. const protobuf::RepeatedPtrField& GetFunctionDefs(); diff --git a/tensorflow/cc/experimental/libexport/load_test.cc b/tensorflow/cc/experimental/libexport/load_test.cc index 0b1565be4355fa..a8ad6e211718d7 100644 --- a/tensorflow/cc/experimental/libexport/load_test.cc +++ b/tensorflow/cc/experimental/libexport/load_test.cc @@ -24,7 +24,7 @@ namespace libexport { namespace { TEST(LoadTest, TestDiskSavedModelLoad) { - StatusOr result = TFPackage::Load("test"); + absl::StatusOr result = TFPackage::Load("test"); EXPECT_FALSE(result.status().ok()); } diff --git a/tensorflow/cc/experimental/libtf/function.cc b/tensorflow/cc/experimental/libtf/function.cc deleted file mode 100644 index 06b7fa15db5a6b..00000000000000 --- a/tensorflow/cc/experimental/libtf/function.cc +++ /dev/null @@ -1,263 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/cc/experimental/libtf/function.h" - -#include -#include - -#include "absl/cleanup/cleanup.h" -#include "tensorflow/c/eager/abstract_function.h" -#include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/cc/experimental/libtf/object.h" -#include "tensorflow/cc/experimental/libtf/value.h" -#include "tensorflow/cc/experimental/libtf/value_iostream.h" -#include "tensorflow/core/framework/op_def.pb.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/statusor.h" - -namespace tf { -namespace libtf { - -using tensorflow::AbstractContext; -using tensorflow::AbstractFunctionPtr; -using tensorflow::AbstractOperationPtr; -using tensorflow::AbstractTensorHandle; -using tensorflow::Status; -using tensorflow::StatusOr; - -// TODO(srbs): Move this to unified execution API. -tensorflow::Status ExecuteFunction( - AbstractFunctionPtr trace, AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs) { - // TODO(srbs): Provide a function execution API on ctx so that we do not - // expose the internals of how functions are to be executed here. - std::string fname; - { - const tensorflow::FunctionDef* fdef = nullptr; - TF_RETURN_IF_ERROR(trace->GetFunctionDef(&fdef)); - fname = fdef->signature().name(); - } - // TODO(srbs): Update RegisterFunction to accept AbstractFunctionPtr. - TF_RETURN_IF_ERROR(ctx->RegisterFunction(trace.get())); - auto cleanup = absl::MakeCleanup( - [fname, ctx]() { ctx->RemoveFunction(fname).IgnoreError(); }); - auto call_op = AbstractOperationPtr(ctx->CreateOperation()); - TF_RETURN_IF_ERROR( - call_op->Reset(fname.c_str(), /*raw_device_name=*/nullptr)); - for (auto t : inputs) { - TF_RETURN_IF_ERROR(call_op->AddInput(t)); - } - int num_outputs = outputs.size(); - return call_op->Execute(outputs, &num_outputs); -} - -Status VerifySupportedSignature(TaggedValue signature) { - if (signature.type() == TaggedValue::Type::TENSOR_SPEC) { - return absl::OkStatus(); - } - if (signature.type() == TaggedValue::Type::TUPLE) { - for (const auto& t : signature.tuple()) { - if (t.type() != TaggedValue::Type::TENSOR_SPEC) { - break; - } - } - return absl::OkStatus(); - } - return tensorflow::errors::Unimplemented( - "Only functions with inputs/outputs containing a single tensor or a tuple" - " of tensors are supported right now."); -} - -Status VerifySupportedArgs(TaggedValue args) { - if (args.type() == TaggedValue::Type::TENSOR) { - return absl::OkStatus(); - } - if (args.type() == TaggedValue::Type::TUPLE) { - for (const auto& t : args.tuple()) { - if (t.type() != TaggedValue::Type::TENSOR) { - break; - } - } - return absl::OkStatus(); - } - return tensorflow::errors::Unimplemented( - "Only functions with inputs/outputs containing a single tensor or a tuple" - " of tensors are supported right now."); -} - -Status Function::RegisterTrace(AbstractFunctionPtr fn, - TaggedValue input_signature, - TaggedValue output_signature) { - TF_RETURN_IF_ERROR(VerifySupportedSignature(input_signature)); - TF_RETURN_IF_ERROR(VerifySupportedSignature(output_signature)); - concrete_fns_.push_back({fn, input_signature, output_signature}); - return absl::OkStatus(); -} - -bool Match(TaggedValue signature, TaggedValue value) { - // TODO(b/187216309): Extend this to handle more elaborate signatures and - // values. - switch (signature.type()) { - case TaggedValue::Type::TENSOR_SPEC: { - if (value.type() != TaggedValue::Type::TENSOR) { - return false; - } - auto spec = signature.tensor_spec(); - const auto& tensor = value.tensor(); - if (tensor->DataType() != spec.dtype) { - return false; - } - tensorflow::PartialTensorShape tensor_shape; - DCHECK(tensor->Shape(&tensor_shape).ok()); - if (!tensor_shape.IsCompatibleWith(spec.shape)) { - return false; - } - } break; - case TaggedValue::Type::TUPLE: { - if (value.type() != TaggedValue::Type::TUPLE) { - return false; - } - if (value.tuple().size() != signature.tuple().size()) { - return false; - } - for (auto i = 0; i < value.tuple().size(); i++) { - if (!Match(signature.tuple()[i], value.tuple()[i])) { - return false; - } - } - } break; - default: - return false; - } - return true; -} - -// TODO(b/190203981): Move to a separate nest-like library. -void Flatten(const TaggedValue& value, - std::vector* flat_args) { - if (value.type() == TaggedValue::Type::TENSOR) { - flat_args->emplace_back(value.tensor().get()); - } else if (value.type() == TaggedValue::Type::TUPLE) { - for (const auto& t : value.tuple()) { - Flatten(t, flat_args); - } - } else { - // TODO(b/190203981): Supported arbitrary structures. - LOG(ERROR) << "Unimplemented"; - } -} - -absl::StatusOr Unflatten( - absl::Span flat_args, TaggedValue structure) { - if (structure.type() == TaggedValue::Type::TENSOR_SPEC) { - if (flat_args.size() != 1) { - // Denotes a corrupted SavedModel in which output_signature does not match - // FunctionDef outputs. - return tensorflow::errors::Internal("Expected single tensor but found ", - flat_args.size()); - } - TaggedValue wrapped_t = - TaggedValue(impl::TaggedValueTensor(flat_args[0], /*add_ref=*/true)); - if (!Match(structure, wrapped_t)) { - // Denotes a corrupted SavedModel in which output_signature does not match - // FunctionDef outputs. - std::stringstream stream; - stream << "Shape and dtype of tensor " << wrapped_t - << " does not match that in signature " << structure; - return tensorflow::errors::Internal(stream.str()); - } - return wrapped_t; - } else if (structure.type() == TaggedValue::Type::TUPLE) { - // TODO(b/190203981): Remove this check when handling nested structures - // inside tuples. - if (flat_args.size() != structure.tuple().size()) { - return tensorflow::errors::InvalidArgument( - "Tuple length ", structure.tuple().size(), - " does not match length of flat args ", flat_args.size()); - } - auto result = impl::TaggedValue::Tuple(); - for (auto i = 0; i < structure.tuple().size(); i++) { - TF_ASSIGN_OR_RETURN(TaggedValue ele, - Unflatten({flat_args[i]}, structure.tuple()[i])); - result.tuple().emplace_back(std::move(ele)); - } - return result; - } else { - // TODO(b/190203981): Support arbitrary structures. - return tensorflow::errors::Unimplemented( - "Only tensors and tuples of tensors are supported right now."); - } -} - -size_t GetFlatSize(const TaggedValue& value) { - if (value.type() == TaggedValue::Type::TUPLE) { - size_t result = 0; - for (const auto& t : value.tuple()) { - result += GetFlatSize(t); - } - return result; - } else if (value.type() == TaggedValue::Type::LIST) { - size_t result = 0; - for (const auto& t : value.list()) { - result += GetFlatSize(t); - } - return result; - } else if (value.type() == TaggedValue::Type::DICT) { - size_t result = 0; - for (const auto& t : value.dict()) { - result += GetFlatSize(t.second); - } - return result; - } - return 1; -} - -absl::StatusOr Function::Execute(AbstractContext* ctx, - TaggedValue value) const { - TF_RETURN_IF_ERROR(VerifySupportedArgs(value)); - TF_ASSIGN_OR_RETURN(auto concrete_fn, GetConcreteFunction(value)); - std::vector args; - Flatten(value, &args); - std::vector outs( - GetFlatSize(concrete_fn.output_signature)); - TF_RETURN_IF_ERROR( - ExecuteFunction(concrete_fn.trace, ctx, args, absl::MakeSpan(outs))); - auto cleanup_tensors = absl::MakeCleanup([outs]() { - for (auto t : outs) { - t->Unref(); - } - }); - return Unflatten(outs, concrete_fn.output_signature); -} - -absl::StatusOr Function::GetConcreteFunction( - TaggedValue value) const { - if (concrete_fns_.empty()) { - return tensorflow::errors::FailedPrecondition( - "No registered ConcreteFunctions."); - } - for (auto& spec : concrete_fns_) { - if (Match(spec.input_signature, value)) { - return spec; - } - } - return tensorflow::errors::InvalidArgument("No match found."); -} - -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/function.h b/tensorflow/cc/experimental/libtf/function.h deleted file mode 100644 index 21232dd6fecc69..00000000000000 --- a/tensorflow/cc/experimental/libtf/function.h +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_FUNCTION_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_FUNCTION_H_ - -#include - -#include "tensorflow/c/eager/abstract_context.h" -#include "tensorflow/c/eager/abstract_function.h" -#include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/cc/experimental/libtf/object.h" -#include "tensorflow/core/platform/statusor.h" - -namespace tf { -namespace libtf { - -class Function { - public: - tensorflow::Status RegisterTrace(tensorflow::AbstractFunctionPtr, - TaggedValue input_signature, - TaggedValue output_signature); - - // Executes this function under the execution context. - // - // Raises an error is no matching signature is found for TaggedValue. - absl::StatusOr Execute(tensorflow::AbstractContext*, - TaggedValue) const; - - private: - struct ConcreteFunction { - tensorflow::AbstractFunctionPtr trace; - TaggedValue input_signature; - TaggedValue output_signature; - }; - absl::StatusOr GetConcreteFunction(TaggedValue) const; - std::vector concrete_fns_; -}; - -} // namespace libtf -} // namespace tf - -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_FUNCTION_H_ diff --git a/tensorflow/cc/experimental/libtf/impl/BUILD b/tensorflow/cc/experimental/libtf/impl/BUILD deleted file mode 100644 index 97b06b21682daa..00000000000000 --- a/tensorflow/cc/experimental/libtf/impl/BUILD +++ /dev/null @@ -1,134 +0,0 @@ -# libtf implementation details. - -load( - "//tensorflow:tensorflow.bzl", - "tf_cc_test", -) -load( - "//tensorflow/core/platform:rules_cc.bzl", - "cc_library", -) - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//tensorflow/cc/experimental/libtf:__subpackages__", - ], - licenses = ["notice"], -) - -cc_library( - name = "iostream", - srcs = [ - "iostream.cc", - ], - deps = [ - ":none", - ":string", - ":tensor_spec", - ], -) - -tf_cc_test( - name = "iostream_test", - size = "small", - srcs = ["iostream_test.cc"], - deps = [ - ":iostream", - ":none", - ":scalars", - ":string", - ":tensor_spec", - "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -cc_library( - name = "scalars", - hdrs = [ - "scalars.h", - ], -) - -tf_cc_test( - name = "scalars_test", - size = "small", - srcs = ["scalars_test.cc"], - deps = [ - ":scalars", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -cc_library( - name = "string", - srcs = [ - "string.cc", - ], - hdrs = [ - "string.h", - ], -) - -tf_cc_test( - name = "string_test", - size = "small", - srcs = ["string_test.cc"], - deps = [ - ":string", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -cc_library( - name = "none", - srcs = [ - "none.cc", - ], - hdrs = [ - "none.h", - ], -) - -tf_cc_test( - name = "none_test", - size = "small", - srcs = ["none_test.cc"], - deps = [ - ":none", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@com_google_absl//absl/container:flat_hash_set", - ], -) - -cc_library( - name = "tensor_spec", - hdrs = [ - "tensor_spec.h", - ], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", - ], -) - -tf_cc_test( - name = "tensor_spec_test", - size = "small", - srcs = ["tensor_spec_test.cc"], - deps = [ - ":iostream", # Necessary for absl::VerifyTypeImplementsAbslHashCorrectly. - ":tensor_spec", - "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@com_google_absl//absl/hash:hash_testing", - ], -) diff --git a/tensorflow/cc/experimental/libtf/impl/iostream.cc b/tensorflow/cc/experimental/libtf/impl/iostream.cc deleted file mode 100644 index eee899b8704d82..00000000000000 --- a/tensorflow/cc/experimental/libtf/impl/iostream.cc +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// Specializations of ostream::operator<< for API values. These are defined here -// so that they don't need to be linked in executables that need to be kept -// small (and don't use the functionality). -#include - -#include "tensorflow/cc/experimental/libtf/impl/none.h" -#include "tensorflow/cc/experimental/libtf/impl/string.h" -#include "tensorflow/cc/experimental/libtf/impl/tensor_spec.h" - -namespace tf { -namespace libtf { -namespace impl { - -std::ostream& operator<<(std::ostream& o, const None& none) { - return o << "None"; -} - -std::ostream& operator<<(std::ostream& o, const String& str) { - return o << str.str(); -} - -std::ostream& operator<<(std::ostream& o, const TensorSpec& x) { - o << "TensorSpec(shape = " << x.shape.DebugString() << ", dtype = " << x.dtype - << ")"; - return o; -} - -} // namespace impl -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/impl/iostream_test.cc b/tensorflow/cc/experimental/libtf/impl/iostream_test.cc deleted file mode 100644 index dede1483d76187..00000000000000 --- a/tensorflow/cc/experimental/libtf/impl/iostream_test.cc +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/impl/none.h" -#include "tensorflow/cc/experimental/libtf/impl/scalars.h" -#include "tensorflow/cc/experimental/libtf/impl/string.h" -#include "tensorflow/cc/experimental/libtf/impl/tensor_spec.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/platform/test.h" - -namespace tf { -namespace libtf { -namespace impl { - -TEST(OStreamTest, TestInt64) { - Int64 x(42); - std::stringstream stream; - stream << x; - ASSERT_EQ(stream.str(), "42"); -} - -TEST(OStreamTest, TestFloat32) { - Float32 x(0.375); // Exactly representable as a float. - std::stringstream stream; - stream << x; - ASSERT_EQ(stream.str(), "0.375"); -} - -TEST(OStreamTest, TestString) { - String s("foo"); - std::stringstream stream; - stream << s; - ASSERT_EQ(stream.str(), "foo"); -} - -TEST(OStreamTest, TestNone) { - std::stringstream stream; - stream << None::GetInstance(); - ASSERT_EQ(stream.str(), "None"); -} - -TEST(OStreamTest, TestTensorSpec) { - std::stringstream stream; - TensorSpec tensor_spec; - tensor_spec.shape = tensorflow::PartialTensorShape({2}); - tensor_spec.dtype = tensorflow::DT_FLOAT; - stream << tensor_spec; - ASSERT_EQ(stream.str(), "TensorSpec(shape = [2], dtype = 1)"); -} - -} // namespace impl -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/impl/none.h b/tensorflow/cc/experimental/libtf/impl/none.h deleted file mode 100644 index 84dd654a4502b5..00000000000000 --- a/tensorflow/cc/experimental/libtf/impl/none.h +++ /dev/null @@ -1,57 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_IMPL_NONE_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_IMPL_NONE_H_ - -#include -#include - -namespace tf { -namespace libtf { -namespace impl { -/// @brief The Singleton `None` class. -/// -/// This class is not user-constructible. To create a `None` instance, use -/// None::GetInstance(). - -class None final { - public: - /// Retrieves the `None` instance. - /// - /// @return Returns the `None` singleton. - static None& GetInstance(); - - /// Equality operator. - bool operator==(const None& other) const { return true; } - - /// Overload AbslHashValue. - template - friend H AbslHashValue(H h, const None& n) { - return H::combine(std::move(h), 34559); - } - - private: - // Private contructor. - None() {} -}; - -// Defined in iostream.cc. -std::ostream& operator<<(std::ostream& o, const None& none); - -} // namespace impl -} // namespace libtf -} // namespace tf - -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_IMPL_NONE_H_ diff --git a/tensorflow/cc/experimental/libtf/impl/none_test.cc b/tensorflow/cc/experimental/libtf/impl/none_test.cc deleted file mode 100644 index d9629e09704eb5..00000000000000 --- a/tensorflow/cc/experimental/libtf/impl/none_test.cc +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/cc/experimental/libtf/impl/none.h" - -#include "absl/container/flat_hash_set.h" -#include "tensorflow/core/platform/test.h" - -namespace tf { -namespace libtf { -namespace impl { - -TEST(NoneTest, TestSingleton) { - None& a = None::GetInstance(); - None& b = None::GetInstance(); - EXPECT_EQ(&a, &b); -} - -TEST(NoneTest, TestSupportsAbslHash) { - absl::flat_hash_set none_set; - None& a = None::GetInstance(); - None& b = None::GetInstance(); - none_set.insert(a); - none_set.insert(b); - EXPECT_EQ(none_set.size(), 1); -} - -} // namespace impl -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/impl/scalars.h b/tensorflow/cc/experimental/libtf/impl/scalars.h deleted file mode 100644 index 2345705637e585..00000000000000 --- a/tensorflow/cc/experimental/libtf/impl/scalars.h +++ /dev/null @@ -1,70 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_IMPL_SCALARS_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_IMPL_SCALARS_H_ - -#include - -#include -#include - -namespace tf { -namespace libtf { -namespace impl { - -/** A thin wrapper around a C++ scalar value. - * This wrapper makes the scalar immutable. - */ -template -class Scalar final { - public: - explicit Scalar(T x) : value_(x) {} - Scalar(const Scalar& o) : value_(o.value_) {} - - bool operator==(const Scalar& o) const { return o.value_ == value_; } - - T get() const { return value_; } - - /** Absl hash function. */ - template - friend H AbslHashValue(H h, const Scalar& x) { - return H::combine(std::move(h), x.value_); - } - - private: - const T value_; -}; - -template -inline std::ostream& operator<<(std::ostream& o, const Scalar& x) { - return o << x.get(); -} - -/** The overloaded addition operator. */ -template -inline auto operator+(const Scalar& x1, const Scalar& x2) - -> Scalar { - using Ret = decltype(x1 + x2); // Return type of this function. - return Ret(x1.get() + x2.get()); -} - -using Int64 = Scalar; -using Float32 = Scalar; - -} // namespace impl -} // namespace libtf -} // namespace tf - -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_IMPL_SCALARS_H_ diff --git a/tensorflow/cc/experimental/libtf/impl/string.cc b/tensorflow/cc/experimental/libtf/impl/string.cc deleted file mode 100644 index 70c716e552a08f..00000000000000 --- a/tensorflow/cc/experimental/libtf/impl/string.cc +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/impl/string.h" - -#include - -// It is important for the container below to not invalidate pointers to -// elements when elements are inserted, because the String class stores such -// pointers. This rules out, for example, absl::flat_hash_set. -using StringTable = std::unordered_set; - -namespace tf { -namespace libtf { -namespace impl { - -String::String(const char* s) { - static StringTable* table = new StringTable; - value_ = &*table->insert(s).first; -} - -} // namespace impl -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/impl/string.h b/tensorflow/cc/experimental/libtf/impl/string.h deleted file mode 100644 index a54fb25b9775c5..00000000000000 --- a/tensorflow/cc/experimental/libtf/impl/string.h +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_IMPL_STRING_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_IMPL_STRING_H_ - -#include -#include - -namespace tf { -namespace libtf { -namespace impl { - -/** A string value. - * This class wraps an interned, immutable string value. Currently, interned - * values are never deleted, so memory usage increases without bound as new - * strings are created. - */ -class String final { - public: - /** Interning constructor. - * Interns the given string value. - */ - explicit String(const char* s); - - String() : String("") {} - String(const String& s) : value_(s.value_) {} - - // This is the same as the default equality operator, which works because - // we're interning all strings. It is specified here so we are explicit about - // it. We're not saying "= default;" because we can't use C++20 features yet. - bool operator==(const String& other) const { return value_ == other.value_; } - - const std::string& str() const { return *value_; } - - /** Absl hash function. */ - template - friend H AbslHashValue(H h, const String& s) { - return H::combine(std::move(h), *s.value_); - } - - private: - //! The interned string value. This is never null. - const std::string* value_; -}; - -// This is defined in the `iostream.cc` file in this directory. It is not -// defined inline here because the `iosfwd` header does not provide enough -// functionality (in Windows), and we don't want to include `iostream` to avoid -// increasing the binary size. -std::ostream& operator<<(std::ostream& o, const String& str); - -} // namespace impl -} // namespace libtf -} // namespace tf - -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_IMPL_STRING_H_ diff --git a/tensorflow/cc/experimental/libtf/impl/tensor_spec.h b/tensorflow/cc/experimental/libtf/impl/tensor_spec.h deleted file mode 100644 index be7c19297d8c8c..00000000000000 --- a/tensorflow/cc/experimental/libtf/impl/tensor_spec.h +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_IMPL_TENSOR_SPEC_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_IMPL_TENSOR_SPEC_H_ - -#include - -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.pb.h" - -namespace tf { -namespace libtf { -namespace impl { -/// @brief The TensorSpec struct. -/// -/// The TensorSpec describes the shape and dtype of a Tensor. - -struct TensorSpec { - tensorflow::PartialTensorShape shape; - tensorflow::DataType dtype; - - bool operator==(const TensorSpec& o) const { - return dtype == o.dtype && shape.IsIdenticalTo(o.shape); - } - - /// Overload AbslHashValue to make TensorSpec hashable. - template - friend H AbslHashValue(H h, const TensorSpec& t) { - return H::combine(std::move(h), t.shape.DebugString(), t.dtype); - } -}; - -// Defined in `iostream.cc`. -std::ostream& operator<<(std::ostream& o, const TensorSpec& x); - -} // namespace impl -} // namespace libtf -} // namespace tf - -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_IMPL_TENSOR_SPEC_H_ diff --git a/tensorflow/cc/experimental/libtf/impl/tensor_spec_test.cc b/tensorflow/cc/experimental/libtf/impl/tensor_spec_test.cc deleted file mode 100644 index dc07f77c7ba9b7..00000000000000 --- a/tensorflow/cc/experimental/libtf/impl/tensor_spec_test.cc +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/cc/experimental/libtf/impl/tensor_spec.h" - -#include "absl/hash/hash_testing.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/platform/test.h" - -namespace tf { -namespace libtf { -namespace impl { - -TEST(TensorSpecTest, TestSupportsAbslHash) { - tensorflow::PartialTensorShape unknown_shape; - TensorSpec ts1; - ts1.shape = unknown_shape; - ts1.dtype = tensorflow::DT_FLOAT; - - TensorSpec ts2; - ts2.shape = tensorflow::PartialTensorShape({2}); - ts2.dtype = tensorflow::DT_FLOAT; - - TensorSpec ts3; - ts3.shape = tensorflow::PartialTensorShape({1, 2}); - ts3.dtype = tensorflow::DT_FLOAT; - - EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly({ts1, ts2, ts3})); -} - -} // namespace impl -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/mlir/BUILD b/tensorflow/cc/experimental/libtf/mlir/BUILD deleted file mode 100644 index db86e4c34e8de9..00000000000000 --- a/tensorflow/cc/experimental/libtf/mlir/BUILD +++ /dev/null @@ -1,30 +0,0 @@ -# Parts of new C++ API that interface with MLIR. - -load( - "//tensorflow/core/platform:rules_cc.bzl", - "cc_library", -) - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//tensorflow/cc/experimental/libtf:__subpackages__", - ], - licenses = ["notice"], -) - -cc_library( - name = "transform", - srcs = [ - "mlir_transform.cc", - ], - hdrs = ["mlir_transform.h"], - deps = [ - "//tensorflow/cc/experimental/libtf", - "//tensorflow/cc/saved_model:bundle_v2", - "//tensorflow/compiler/mlir/tensorflow:import_model", - "//tensorflow/compiler/mlir/tensorflow:translate_lib", - "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", - "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes", - ], -) diff --git a/tensorflow/cc/experimental/libtf/mlir/mlir_transform.cc b/tensorflow/cc/experimental/libtf/mlir/mlir_transform.cc deleted file mode 100644 index f5bd971caec516..00000000000000 --- a/tensorflow/cc/experimental/libtf/mlir/mlir_transform.cc +++ /dev/null @@ -1,91 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/mlir/mlir_transform.h" - -#include -#include - -#include "tensorflow/cc/experimental/libtf/object.h" -#include "tensorflow/cc/experimental/libtf/value.h" -#include "tensorflow/cc/saved_model/bundle_v2.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" - -namespace tf { -namespace libtf { - -// TODO(b/190837282): All return None's become errors. -Handle LoadModule(Object self, String saved_model) { - // Parse arguments. - // Load SavedModel into memory. - tensorflow::SavedModelV2Bundle bundle; - tensorflow::Status status = - tensorflow::SavedModelV2Bundle::Load(saved_model.get(), &bundle); - if (!status.ok()) { - return None(); - } - // Fetch MLIR context - auto* context = self.Get(String("_context")) - ->cast(); - - // Load the saved model into MLIR TF dialect. - absl::Span exported_names(nullptr, 0); - auto module_or = - tensorflow::ConvertSavedModelToMlir(&bundle, context, exported_names); - if (!module_or.status().ok()) { - return None(); - } - - // Make a module to wrap MLIR module and allow getting strings and running - // transforms. - // auto obj = TaggedValue::Dict(); - Object obj; - obj.Set( - String("_module"), - Handle(impl::TaggedValue::Capsule(new mlir::OwningOpRef( - std::move(module_or).value())))); - - auto get_string = [](Object self) { - auto ref = self.Get(String("_module")) - ->cast*>(); - return String(tensorflow::MlirModuleToString(ref->get(), false).c_str()); - }; - obj.Set(String("ToString"), Callable(TFLIB_CALLABLE_ADAPTOR(get_string))); - - return obj; -} - -None SaveModule(Object self, Object module, String directory) { - // TODO(b/190835292): Implement save. - return None(); -} - -None Transform(Object self, Object module, List passes) { - // TODO(b/190835292): Implement save. - return None(); -} - -Object MLIR() { - Object obj; - obj.Set(String("LoadSavedModel"), - Callable(TFLIB_CALLABLE_ADAPTOR(LoadModule))); - obj.Set(String("SaveSavedModel"), - Callable(TFLIB_CALLABLE_ADAPTOR(SaveModule))); - obj.Set(String("_context"), - Handle(impl::TaggedValue::Capsule(new mlir::MLIRContext()))); - return obj; -} - -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/module.cc b/tensorflow/cc/experimental/libtf/module.cc deleted file mode 100644 index b2102dc466edd6..00000000000000 --- a/tensorflow/cc/experimental/libtf/module.cc +++ /dev/null @@ -1,119 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/module.h" - -#include - -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/protobuf/saved_object_graph.pb.h" -namespace tf { -namespace libtf { -namespace impl { - -using tensorflow::libexport::TFPackage; -using tf::libtf::runtime::Runtime; - -// TODO(danielellis): Fill in with implementations. - -// Builds a vector of runtime representations of `SavedObject`s from a -// SavedModel. These are returned as a flat list. The full hierarchy building -// and initialization should be done in a later pass. -absl::StatusOr> BuildObjects(TFPackage& tf_package) { - std::vector objects; - const tensorflow::SavedObjectGraph object_graph = tf_package.GetObjectGraph(); - for (auto& node : object_graph.nodes()) { - if (node.kind_case() == tensorflow::SavedObject::kUserObject) { - absl::StatusOr result = BuildSavedUserObject(node); - if (result.ok()) { - objects.push_back(*result); - } else { - return result.status(); - } - } - } - return objects; -} - -absl::StatusOr BuildSavedUserObject( - tensorflow::SavedObject saved_object_proto) { - if (saved_object_proto.kind_case() != tensorflow::SavedObject::kUserObject) { - return tensorflow::errors::InvalidArgument("Not a UserObject."); - } - - std::string identifier = saved_object_proto.user_object().identifier(); - if (identifier == "trackable_list_wrapper") { - tf::libtf::List user_list; - // TODO(b/191267013): Populate with values. - return user_list; - } - if (identifier == "trackable_dict_wrapper") { - tf::libtf::Dictionary user_dict; - // TODO(b/191267013): Populate with values. - return user_dict; - } - if (identifier == "signature_map") { - tf::libtf::Dictionary signature_map; - // TODO(b/191267013): Populate with values. - return signature_map; - } - if (identifier == "_generic_user_object") { - tf::libtf::Dictionary user_object; - // TODO(b/191267013): Populate with values. - return user_object; - } - return tensorflow::errors::Unimplemented(absl::StrCat( - "UserObject with identifier '", identifier, "' not implemented.")); -} - -// Register all available concrete functions from a SavedModel into a runtime. -tensorflow::Status RegisterConcreteFunctions(Runtime runtime, - TFPackage tf_package) { - return tensorflow::errors::Unimplemented("Not implemented."); -} - -// Initialize any variables found in the SavedModel and attach them to the -// appropriate object representation in the runtime. -tensorflow::Status InitializeVariables(Runtime runtime, TFPackage tf_package, - std::vector objects) { - return tensorflow::errors::Unimplemented("Not implemented."); -} - -// Register concrete functions with their associated polymorphic functions. -tensorflow::Status SetupPolymorphicFunctions(Runtime runtime, - TFPackage tf_package, - std::vector objects) { - return tensorflow::errors::Unimplemented("Not implemented."); -} - -// Register any captures with their associated higher-level functions. -tensorflow::Status SetupFunctionCaptures(Runtime runtime, TFPackage tf_package, - std::vector objects) { - return tensorflow::errors::Unimplemented("Not implemented."); -} - -// Takes a flat list of Handles and builds them into the hierarchical -// representation defined by the SavedModel. -absl::StatusOr BuildObjectHierarchy(TFPackage tf_package, - std::vector objects) { - return tensorflow::errors::Unimplemented("Not implemented."); -} - -absl::StatusOr BuildProgram(Runtime runtime, TFPackage& tf_package) { - return tensorflow::errors::Unimplemented("Not implemented."); -} - -} // namespace impl -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/module.h b/tensorflow/cc/experimental/libtf/module.h deleted file mode 100644 index c857f702888a82..00000000000000 --- a/tensorflow/cc/experimental/libtf/module.h +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_MODULE_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_MODULE_H_ - -#include "tensorflow/cc/experimental/libexport/load.h" -#include "tensorflow/cc/experimental/libtf/runtime/runtime.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/protobuf/saved_object_graph.pb.h" - -namespace tf { -namespace libtf { -namespace impl { - -// The main interface for taking a serialized saved model and getting back a -// fully-built model. -// -// Implementation steps: -// -// 1) For each function def in the SavedModel, register it with the runtime. -// 2) For each object in the object graph def, build it. -// 3) For each variable stored in the checkpoint in the SavedModel, -// restore it, and attach it to the associated variable object. -// 4) For each polymorphic function, associate it with the appropriate -// concrete function(s). -// 5) For each function with captures, bind the appropriate objects as -// captured inputs. -// 6) Take the fully-prepared objects, and build them into a hierarchy. -// 7) Return the prepared model. - -// Converts a SavedUserObject into its corresponding data structure. -// TODO(b/185579152): This method returns empty data structures currently. -absl::StatusOr BuildSavedUserObject( - tensorflow::SavedObject saved_object_proto); - -// "Build" all SavedObjects, ie convert from proto to their runtime -// representation, in the tf_package. -absl::StatusOr> BuildObjects( - tensorflow::libexport::TFPackage& tf_package); - -// Convert tf_package to a program in the runtime. -absl::StatusOr BuildProgram( - runtime::Runtime runtime, tensorflow::libexport::TFPackage& tf_package); - -} // namespace impl -} // namespace libtf -} // namespace tf - -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_MODULE_H_ diff --git a/tensorflow/cc/experimental/libtf/object.h b/tensorflow/cc/experimental/libtf/object.h deleted file mode 100644 index bebf28b6d496d3..00000000000000 --- a/tensorflow/cc/experimental/libtf/object.h +++ /dev/null @@ -1,709 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -/// @file object.h -/// @brief Object hierarchy for the TensorFlow C++ API. All "objects" are -/// derived from the `Handle` class. Instances of `Handle` are referred to as -/// "handles". All handles have a tagged value. -/// -/// Example Usage: -/// Object runtime = GetRuntime("tfrt"); -/// Object module = runtime.Get("Import")("cool_mobilenet") -/// runtime.Get("Tensor")(Tuple(5,5,5), 3.3); -/// Object test = CreateModule("test"); -/// test.Set("cool_function", callable); -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_OBJECT_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_OBJECT_H_ - -#include -#include - -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" -#include "tensorflow/cc/experimental/libtf/value.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/statusor.h" - -namespace tf { -namespace libtf { - -using TaggedValue = impl::TaggedValue; -class Handle; - -// Necessary forward declare. -template -Handle Convert(T value); - -/// @brief Base Handle class that wraps TaggedValue data. All data creation and -/// manipulation should done using Handle instances. Users should not be working -/// with TaggedValues directly. - -/// The `Handle` class contains a TaggedValue in the `value_` member, which -/// contains the underlying data. An object belonging to `Foo`, a derived class -/// of `Handle`, can be referred to as a `Foo` handle. -/// -/// It is important that all derived classes do not add any new data fields. -/// This ensures that it is always safe to slice down (i.e. assign an object of -/// a derived class to the base class) a handle to the base Handle class. -class Handle { - public: - /// Default constructor, which initializes a TaggedValue with type NONE. - Handle() : value_(TaggedValue::None()) {} - - public: - /// Constructs a handle from a TaggedValue. - explicit Handle(TaggedValue value) : value_(std::move(value)) {} - // explicit Handle(TaggedValue value, Handle* class_input) - // : value_(std::move(value)), class_(class_input) {} - // const Handle& type() { return *class_; } - - protected: - /// The wrapped TaggedValue. - TaggedValue value_; - // effectively a "weak reference" to intern'd class value. - // types are compared by comparing pointer values here. - // Handle* class_; // effectively a "weak reference" to intern'd class value. - - /// The Integer handle. - friend class Integer; - /// The Float handle. - friend class Float; - /// The String handle. - friend class String; - /// The Object handle. - friend class Object; - /// The List handle. - friend class List; - /// The Dictionary handle. - friend class Dictionary; - /// The Tuple handle. - friend class Tuple; - /// The Callable handle. - friend class Callable; - /// The Tensor handle. - friend class Tensor; - /// Converts a Handle instance to an instance of a derived class `T`. - template - friend tensorflow::StatusOr Cast(Handle handle); - /// Infrastructure for converting a TaggedValue tuple function signature to an - /// unpacked variable list. - template - friend class UneraseCallHelper; -}; - -// Forward declare. -template -tensorflow::StatusOr Cast(Handle handle); - -/// @brief The None class for holding TaggedValues of type NONE. -class None final : public Handle { - public: - /// Creates a handle that wraps a NONE TaggedValue. - None() : Handle(TaggedValue::None()) {} - - private: - explicit None(TaggedValue v) : Handle(std::move(v)) {} - template - friend tensorflow::StatusOr Cast(Handle handle); -}; - -/// @brief The String class for holding TaggedValues of type STRING. -class String final : public Handle { - public: - /// Creates a handle that wraps a STRING TaggedValue. - explicit String(const char* s) : Handle(TaggedValue(s)) {} - /// Returns the underlying TaggedValue string. - const char* get() const { return value_.s(); } - - private: - // Private since it is in general unsafe. - explicit String(TaggedValue v) : Handle(std::move(v)) {} - template - friend tensorflow::StatusOr Cast(Handle handle); -}; - -/// @brief The `Object` class modeled after Python "objects". -/// -/// An `Object` uses a TaggedValue dictionary to store its attributes. The -/// "__parent__" attribute is reserved. -class Object : public Handle { - public: - /// Constructs a handle that acts as an object. - Object() : Handle(TaggedValue::Dict()) {} - /// Retrieves the key of the object's parent. - static const String& ParentKey(); - - /// @brief Gets an object member attribute`key`. - /// - /// If the `key` is not found in the object, the object's "__parent__" - /// attribute is then searched. - /// - /// @tparam T The desired return type. - /// @param key The key to look up. - /// @return `StatusOr` wrapping the key's value. - template - tensorflow::StatusOr Get(const String& key) { - auto& dict = value_.dict(); - auto it = dict.find(key.value_); - if (it != dict.end()) { - return Cast(Handle(it->second)); - } else { - // Lookup in object stored by reference in attribute "__parent__". - auto it_class = dict.find(ParentKey().value_); - if (it_class != dict.end()) { - auto& class_dict_maybe = it_class->second; - if (class_dict_maybe.type() == TaggedValue::DICT) { - auto& dict = class_dict_maybe.dict(); - auto it = dict.find(key.value_); - if (it != dict.end()) { - return Cast(Handle(it->second)); - } - } - } - } - return absl::NotFoundError("Key not in dictionary."); - } - - /// Sets `key` attribute with the underlying value of `h`. - void Set(const String& key, Handle h) { - value_.dict()[key.value_] = std::move(h.value_); - } - - /// Removes `key` from the object's attributes. - void Unset(const String& key) { value_.dict().erase(key.value_); } - // TODO(b/): Adding dir() is in the future. - private: - // Private since it is in general unsafe. - explicit Object(TaggedValue v) : Handle(std::move(v)) {} - template - friend tensorflow::StatusOr Cast(Handle handle); -}; - -/// @brief The Dictionary class for holding TaggedValues of type DICT. -class Dictionary final : public Handle { - public: - /// Constructs a handle that wraps a DICT TaggedValue. - Dictionary() : Handle(TaggedValue::Dict()) {} - // TODO(aselle): make this private to preserve invariant. - - /// Retrieves `key` with type `T`. - template - tensorflow::StatusOr Get(const Handle& key) { - auto it = value_.dict().find(key.value_); - if (it != value_.dict().end()) return Cast(Handle(it->second)); - return absl::NotFoundError("Key not in dictionary."); - } - /// Sets `key` with value `value`. - void Set(const String& key, Handle value) { - value_.dict()[key.value_] = std::move(value.value_); - } - /// Sets `key` with value `value`. - void Set(const Handle& key, Handle value) { - value_.dict()[key.value_] = std::move(value.value_); - } - /// Retrieves size of dictionary. - size_t size() const { return value_.dict().size(); } - - private: - // Private since it is in general unsafe. - explicit Dictionary(TaggedValue v) : Handle(std::move(v)) {} - template - friend tensorflow::StatusOr Cast(Handle handle); -}; - -/// @brief The Integer class for holding TaggedValues of type INT. -class Integer final : public Handle { - public: - /// Creates a handle that wraps an INT TaggedValue. - explicit Integer(Handle h) : Handle(h.value_) {} - /// Creates a handle that wraps an INT TaggedValue. - explicit Integer(int64_t i) : Handle(TaggedValue(i)) {} - /// Retrieves the underlying integer value. - int64_t get() const { return value_.i64().get(); } - - private: - // Private since it is in general unsafe. - explicit Integer(TaggedValue v) : Handle(std::move(v)) {} - template - friend tensorflow::StatusOr Cast(Handle handle); -}; - -/// @brief The Float class for holding TaggedValues of type FLOAT. -class Float final : public Handle { - public: - /// Constructs a Float handle that wraps a FLOAT TaggedValue. - explicit Float(Handle h) : Handle(h.value_) {} - /// Constructs a Float handle that wraps a FLOAT TaggedValue. - explicit Float(float i) : Handle(TaggedValue(i)) {} - /// Retrieves the underlying float value. - float get() const { return value_.f32().get(); } - - private: - // Private since it is in general unsafe. - explicit Float(TaggedValue v) : Handle(std::move(v)) {} - template - friend tensorflow::StatusOr Cast(Handle handle); -}; - -/// @brief The Tensor class for holding TaggedValues of type TENSOR. -class Tensor final : public Handle { - public: - /// Constructs a Tensor handle from a Handle that wraps a TENSOR TaggedValue. - explicit Tensor(Handle h) : Handle(h.value_) {} - - /// @brief Retrieves the value of the Tensor handle. - - /// @param data Buffer in which to copy contents of the handle. - /// @throws InvalidArgument Raises error if `data` is of invalid size. - template - tensorflow::Status GetValue(absl::Span data) const; - - private: - // Private since it is in general unsafe. - explicit Tensor(TaggedValue v) : Handle(std::move(v)) {} - template - friend tensorflow::StatusOr Cast(Handle handle); -}; - -template -tensorflow::Status Tensor::GetValue(absl::Span data) const { - tensorflow::AbstractTensorPtr t; - { - const auto abstract_t = value_.tensor().get(); - if (!tensorflow::ImmediateExecutionTensorHandle::classof(abstract_t)) { - return absl::InvalidArgumentError( - "Attempting to get value of non eager tensor."); - } - auto imm_t = - static_cast(abstract_t); - tensorflow::Status status; - t.reset(imm_t->Resolve(&status)); - if (!status.ok()) { - return status; - } - } - if (data.size() != t->NumElements()) { - return tensorflow::errors::InvalidArgument(absl::StrCat( - "Mismatched number of elements: \n", "Expected: ", data.size(), "\n", - "Actual: ", t->NumElements(), "\n")); - } - memcpy(data.data(), t->Data(), t->ByteSize()); - return absl::OkStatus(); -} - -/// @brief The Tuple class for holding TaggedValues of type TUPLE. -class Tuple : public Handle { - public: - /// Constructs a Tuple handle. - template - explicit Tuple(T... args) : Handle(TaggedValue::Tuple()) { - add(args...); - } - - /// Retrieves value at index `i`. - template - tensorflow::StatusOr Get(size_t i) { - if (i >= value_.tuple().size()) - return absl::InvalidArgumentError("Out of bounds index."); - return Cast(Handle(value_.tuple()[i])); - } - - /// Retrieves number of elements. - size_t size() const { return value_.tuple().size(); } - - private: - // Add an item to a tuple. Should only be done by special construction - // like Callables (which are a friend). - void add() {} - template - void add(T arg, T2... args) { - value_.tuple().emplace_back(Convert(arg).value_); - add(args...); - } - - // Private since it is in general unsafe. - explicit Tuple(TaggedValue v) : Handle(std::move(v)) {} - template - friend tensorflow::StatusOr Cast(Handle handle); -}; - -/// @brief The List class for holding TaggedValues of type LIST. -class List final : public Handle { - public: - /// Constructs a List handle. - template - explicit List(T... args) : Handle(TaggedValue::List()) {} - /// Retrieves value at index `i`. - template - tensorflow::StatusOr Get(size_t i) { - if (i >= size()) { - return absl::InvalidArgumentError("Out of bounds index."); - } - return Cast(Handle(value_.list()[i])); - } - - /// Sets value `h` at index `i`. - tensorflow::Status Set(size_t i, Handle h) { - if (i >= size()) { - return absl::InvalidArgumentError("Out of bounds index."); - } - value_.list()[i] = std::move(h.value_); - return absl::OkStatus(); - } - - /// Appends `arg` to list. - template - void append(T arg) { - value_.list().emplace_back(Convert(arg).value_); - } - /// Retrieves size of list. - size_t size() const { return value_.list().size(); } - - private: - // Private since it is in general unsafe. - explicit List(TaggedValue v) : Handle(std::move(v)) {} - template - friend tensorflow::StatusOr Cast(Handle handle); -}; - -/// @brief The `KeywordArg` class for storing keyword arguments as name value -/// pairs. -class KeywordArg { - public: - explicit KeywordArg(const char* s) : key_(String(s)), value_() {} - - template - KeywordArg& operator=(const T obj) { - value_ = Convert(obj); - return *this; - } - - friend class Callable; - - private: - String key_; - Handle value_; -}; - -/// @brief The Callable class for creating callables. -class Callable final : public Handle { - private: - // Collect arguments for call - void CollectArgs(Tuple& args, Dictionary& kwargs, int idx) {} - template - void CollectArgs(Tuple& args, Dictionary& kwargs, int idx, T v, - Types... vars) { - const Handle& o = Convert(v); - args.value_.tuple().emplace_back(o.value_); - CollectArgs(args, kwargs, idx + 1, vars...); - } - template - void CollectArgs(Tuple& args, Dictionary& kwargs, int idx, KeywordArg v, - Types... vars) { - kwargs.Set(v.key_, v.value_); - CollectArgs(args, kwargs, idx + 1, vars...); - } - - public: - /// @brief Calls the wrapped TaggedValue function on a variable argument - /// list. - template - tensorflow::StatusOr Call(Types... vars) { - Dictionary kwargs = Dictionary(); - Tuple args; - CollectArgs(args, kwargs, 0, vars...); - auto maybe_value = - value_.func()(std::move(args.value_), std::move(kwargs.value_)); - if (!maybe_value.ok()) { - return maybe_value.status(); - } - return Cast(Handle(maybe_value.value())); - } - - public: - // TODO(aselle): need to find a way to write test w/o this being public. - // Private since it is in general unsafe. - explicit Callable(TaggedValue v) : Handle(std::move(v)) {} - template - friend tensorflow::StatusOr Cast(Handle handle); -}; - -namespace internal { -/// @brief The Capsule class for holding pointers. -class Capsule final : public Handle { - public: - /// Statically cast the TaggedValue capsule to type `T`. - template - T cast() { - return static_cast(value_.capsule()); - } - - private: - // Private since it is in general unsafe. - explicit Capsule(TaggedValue v) : Handle(std::move(v)) {} - template - friend tensorflow::StatusOr tf::libtf::Cast(Handle handle); -}; -} // namespace internal - -/// @defgroup Util Functions for type conversion -/// -/// @brief Functions for retrieving and converting Handle types. -/// @{ - -/// Retrieves tagged type of `T` handle. -template -inline TaggedValue::Type TypeToTaggedType() {} -/// Retrieves tagged type of base class handle. -template <> -inline TaggedValue::Type TypeToTaggedType() { - return TaggedValue::Type::NONE; -} -/// Retrieves tagged type of None handle. -template <> -inline TaggedValue::Type TypeToTaggedType() { - return TaggedValue::Type::NONE; -} -/// Retrieves tagged type of String handle. -template <> -inline TaggedValue::Type TypeToTaggedType() { - return TaggedValue::Type::STRING; -} -/// Retrieves tagged type of Callable handle. -template <> -inline TaggedValue::Type TypeToTaggedType() { - return TaggedValue::Type::FUNC; -} -/// Retrieves tagged type of Integer handle. -template <> -inline TaggedValue::Type TypeToTaggedType() { - return TaggedValue::Type::INT64; -} -/// Retrieves tagged type of Float handle. -template <> -inline TaggedValue::Type TypeToTaggedType() { - return TaggedValue::Type::FLOAT32; -} -/// Retrieves tagged type of Object handle. -template <> -inline TaggedValue::Type TypeToTaggedType() { - return TaggedValue::Type::DICT; -} -/// Retrieves tagged type of Dictionary handle. -template <> -inline TaggedValue::Type TypeToTaggedType() { - return TaggedValue::Type::DICT; -} -/// Retrieves tagged type of List handle. -template <> -inline TaggedValue::Type TypeToTaggedType() { - return TaggedValue::Type::LIST; -} -/// Retrieves tagged type of Tensor handle. -template <> -inline TaggedValue::Type TypeToTaggedType() { - return TaggedValue::Type::TENSOR; -} -/// Retrieves tagged type of Capsule handle. -template <> -inline TaggedValue::Type TypeToTaggedType() { - return TaggedValue::Type::CAPSULE; -} -// TODO(unknown): fully populate - -/// @brief Casts a handle to type `T` -/// -/// @param handle The handle to cast. -/// @tparam T The target handle type. -/// @exception InvalidArgument Raises error if the underlying TaggedValue type -/// of `handle` is not equivalent to `T`. -template -tensorflow::StatusOr Cast(Handle handle) { - if (handle.value_.type() == TypeToTaggedType() || - std::is_same::value) - return T((std::move(handle.value_))); - return absl::InvalidArgumentError("Incompatible cast."); -} - -// Converters for C++ primitives like float and int to handles. Allows callable -// calls and list appends to be more idiomatic. - -/// Converts a C++ const char* to a String handle. -template <> -inline Handle Convert(const char* value) { - return String(value); -} -/// Converts a C++ int32_t to an Integer handle. -template <> -inline Handle Convert(int32_t value) { - return Integer(value); -} -/// Converts a C++ int64_t to an Integer handle. -template <> -inline Handle Convert(int64_t value) { - return Integer(value); -} -/// Converts a C++ float to an Integer handle. -template <> -inline Handle Convert(float value) { - return Float(value); -} -/// Converts a value with primitive type T to a Handle. -template -inline Handle Convert(T value) { - return Handle(std::move(value)); -} - -/// @} - -// in the future it will be possible to make additional hard typed APIs -// by generating code by introspecting objects. - -// Here's a code gen'd example -// The dynamic structure can be turned into it. -/* -class Tf : Object { - Tensor ones(Tensor shape, String dtype); - // ... -} -*/ - -// Adapter to allow users to define Callables. Use TFLIB_CALLABLE_ADAPTOR -// instead. -template -class CallableWrapper; - -// Template extracts arguments from a lambda function. This base -// class definition inherits from a another specialization in order. We use -// this top level template to extract the function pointer associated with -// the created lambda functor class. -template -class CallableWrapperUnpackArgs - : public CallableWrapperUnpackArgs { - public: - CallableWrapperUnpackArgs(TLambda fn, const char* name) - : CallableWrapperUnpackArgs(fn, name) {} -}; - -// This specialization unpacks the arguments from a normal function pointer. -template -class CallableWrapperUnpackArgs - : public CallableWrapper { - using Fn = TReturn (*)(TFuncArgs...); - - public: - CallableWrapperUnpackArgs(Fn fn, const char* name) - : CallableWrapper(fn, name) {} -}; - -// This is the second stage of extracting the arguments from lambda function. -// NOTE: CallableWrapper's first template argument is the type of the -// function or functor (not the member pointer). -template -class CallableWrapperUnpackArgs - : public CallableWrapper { - using Fn = TClass; - - public: - CallableWrapperUnpackArgs(Fn fn, const char* name) - : CallableWrapper(fn, name) {} -}; - -template -class UneraseCallHelper; - -// UneraseCallHelper::Call allows transforming all the incoming arguments -// from a TaggedValue tuple to a variadic list of args. The class template -// starts as a list of argument types and ends empty. The static member -// template starts empty and ends with the unerased types of the signature. - -// Base case (all arguments are processed, so call the function TFunc. -template -class UneraseCallHelper { - public: - template - static absl::StatusOr Call(const char* name, Fn functor_, - int argument_index, - const TaggedValue& args_in, - ArgsOut... args) { - // Call concrete type function - TReturn ret = functor_(args...); - return ret.value_; - } -}; - -// Unpack a single argument case. Each argument is then cast. -template -class UneraseCallHelper { - public: - template - static absl::StatusOr Call(const char* name, Fn fn, - int argument_index, - TaggedValue& args_in, - TArgsOut... args) { - Handle h(std::move(args_in.tuple()[argument_index])); - tensorflow::StatusOr x = Cast(std::move(h)); - if (!x.ok()) - return absl::InvalidArgumentError( - absl::StrCat(std::string("Function ") + name + " Arg " + - std::to_string(argument_index) + - " cannot be cast to desired signature type ")); - return UneraseCallHelper::Call( - name, fn, argument_index + 1, args_in, args..., *x); - } -}; - -// Template specialization that allows extracting arguments from a C function -// pointer. -template -class CallableWrapper { - private: - Fn functor_; - const char* name_; - - public: - explicit CallableWrapper(Fn fn, const char* name) - : functor_(fn), name_(name) {} - - // Entry point of the Adaptor functor. Note args, and kwargs are attempted - // to be moved. - absl::StatusOr operator()(TaggedValue args, TaggedValue kwargs) { - constexpr size_t argument_count = sizeof...(TFuncArgs); - if (argument_count != args.tuple().size()) - return absl::InvalidArgumentError( - absl::StrCat(std::string("Function ") + name_ + " expected " + - std::to_string(argument_count) + " args.")); - return UneraseCallHelper::Call(name_, functor_, - 0, args); - } -}; - -// Wrap a function that uses object handles as arguments and return types -// with one that takes TaggedValues. For example: -// Tuple Pack(Integer, Float, String); -// TaggedValue callable = TFLIB_CALLABLE_ADAPTOR(Pack); -#define TFLIB_CALLABLE_ADAPTOR(x) ::tf::libtf::CreateCallableAdaptor(x, #x) - -template -TaggedValue CreateCallableAdaptor(TF x, const char* name) { - return TaggedValue((CallableWrapperUnpackArgs(x, name))); -} - -} // namespace libtf -} // namespace tf - -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_OBJECT_H_ diff --git a/tensorflow/cc/experimental/libtf/runtime/BUILD b/tensorflow/cc/experimental/libtf/runtime/BUILD deleted file mode 100644 index b20c0e6e3f903b..00000000000000 --- a/tensorflow/cc/experimental/libtf/runtime/BUILD +++ /dev/null @@ -1,44 +0,0 @@ -load( - "//tensorflow/core/platform:rules_cc.bzl", - "cc_library", -) - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//tensorflow/cc/experimental/libtf:__subpackages__", - ], - licenses = ["notice"], -) - -cc_library( - name = "runtime", - srcs = [ - "runtime.cc", - ], - hdrs = [ - "runtime.h", - ], - deps = [ - "//tensorflow/c:tf_datatype", - "//tensorflow/c:tf_status_helper", - "//tensorflow/c:tf_status_internal", - "//tensorflow/c/eager:abstract_context", - "//tensorflow/c/eager:c_api", - "//tensorflow/c/eager:c_api_experimental", - "//tensorflow/c/eager:graph_function", - "//tensorflow/c/eager:immediate_execution_context", - "//tensorflow/c/eager:tfe_context_internal", - "//tensorflow/cc/experimental/libexport:load", - "//tensorflow/cc/experimental/libtf", - "//tensorflow/cc/experimental/libtf:function", - "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform:errors", - "//tensorflow/core/platform:statusor", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) diff --git a/tensorflow/cc/experimental/libtf/runtime/core/BUILD b/tensorflow/cc/experimental/libtf/runtime/core/BUILD deleted file mode 100644 index 09106ea8cb75b4..00000000000000 --- a/tensorflow/cc/experimental/libtf/runtime/core/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//tensorflow/cc/experimental/libtf:__subpackages__", - ], - licenses = ["notice"], -) - -cc_library( - name = "core", - srcs = [ - "core.cc", - ], - hdrs = [ - "core.h", - ], - deps = [ - "//tensorflow/c:tf_status_internal", - "//tensorflow/c/eager:c_api", - "//tensorflow/c/eager:tfe_context_internal", - "//tensorflow/cc/experimental/libtf", - "//tensorflow/cc/experimental/libtf/runtime", - ], -) diff --git a/tensorflow/cc/experimental/libtf/runtime/core/core.cc b/tensorflow/cc/experimental/libtf/runtime/core/core.cc deleted file mode 100644 index 5d23c7aa0920da..00000000000000 --- a/tensorflow/cc/experimental/libtf/runtime/core/core.cc +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/runtime/core/core.h" - -#include "tensorflow/c/eager/c_api.h" -#include "tensorflow/c/eager/tfe_context_internal.h" -#include "tensorflow/c/tf_status.h" -#include "tensorflow/c/tf_status_internal.h" -#include "tensorflow/cc/experimental/libtf/object.h" -#include "tensorflow/cc/experimental/libtf/runtime/runtime.h" - -namespace tf { -namespace libtf { -namespace runtime { -namespace core { - -runtime::Runtime Runtime() { - TaggedValue ctx_capsule; - TFE_Context* ctx; - TFE_ContextOptions* ctx_options = TFE_NewContextOptions(); - TFE_ContextOptionsSetDevicePlacementPolicy(ctx_options, - TFE_DEVICE_PLACEMENT_WARN); - TF_Status* status = TF_NewStatus(); - ctx = TFE_NewContext(ctx_options, status); - TF_DeleteStatus(status); - TFE_DeleteContextOptions(ctx_options); - return runtime::Runtime(tensorflow::unwrap(ctx)); -} - -} // namespace core -} // namespace runtime -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/runtime/runtime.cc b/tensorflow/cc/experimental/libtf/runtime/runtime.cc deleted file mode 100644 index 460964be0f4f29..00000000000000 --- a/tensorflow/cc/experimental/libtf/runtime/runtime.cc +++ /dev/null @@ -1,185 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/runtime/runtime.h" - -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "tensorflow/c/eager/abstract_context.h" -#include "tensorflow/c/eager/c_api.h" -#include "tensorflow/c/eager/c_api_experimental.h" -#include "tensorflow/c/eager/graph_function.h" -#include "tensorflow/c/eager/immediate_execution_context.h" -#include "tensorflow/c/eager/tfe_context_internal.h" -#include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/c/tf_status_internal.h" -#include "tensorflow/cc/experimental/libexport/load.h" -#include "tensorflow/cc/experimental/libtf/function.h" -#include "tensorflow/cc/experimental/libtf/object.h" -#include "tensorflow/cc/experimental/libtf/value.h" -#include "tensorflow/core/framework/function.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/protobuf/saved_object_graph.pb.h" -#include "tensorflow/core/protobuf/struct.pb.h" -#include "tensorflow/core/protobuf/trackable_object_graph.pb.h" - -namespace tf { -namespace libtf { -namespace runtime { - -using tensorflow::AbstractContext; -using tensorflow::AbstractFunctionPtr; -using tensorflow::DataType; -using tensorflow::FunctionDef; -using tensorflow::PartialTensorShape; -using tensorflow::SavedConcreteFunction; -using tensorflow::SavedObjectGraph; -using tensorflow::Status; -using tensorflow::StructuredValue; -using tensorflow::TensorSpecProto; -using tensorflow::libexport::TFPackage; -using tensorflow::protobuf::RepeatedPtrField; -using tensorflow::tracing::graph::GraphFunction; - -TaggedValue MakeCallable(const std::string& fn_name, Function fn, - AbstractContext* ctx) { - auto CallFn = [fn_name, fn, ctx](TaggedValue args_, - TaggedValue kwargs_) -> TaggedValue { - std::cout << "Calling " << fn_name << std::endl; - tensorflow::StatusOr v = fn.Execute(ctx, args_); - return v.value(); - }; - return TaggedValue(CallFn); -} - -// Import a module from a saved model. -// -// Returns a TaggedValue::Dict. All functions found on the root of the module -// will be attached as callables to this TaggedValue. -// -// `name` specifies the full path to the saved model. -// -// `ctx` should outlive the lifetime of the module. -static tensorflow::StatusOr ImportModule(String name, - AbstractContext* ctx) { - // Load the serialized model. - tensorflow::StatusOr tf_package = TFPackage::Load(name.get()); - if (!tf_package.status().ok()) { - return tf_package.status(); - } - TaggedValue module = TaggedValue::Dict(); - - // Initialize concrete function traces. - const RepeatedPtrField function_defs = - tf_package->GetFunctionDefs(); - absl::flat_hash_map traces; - for (auto& fdef : function_defs) { - AbstractFunctionPtr trace(new GraphFunction(fdef), /*add_ref=*/false); - traces[fdef.signature().name()] = trace; - } - - // Setup polymorphic functions and wrap in Callables. - // - // For each child of the root, check what type it is. If it's a - // SavedFunction, attach that function to the module as a Callable. - const SavedObjectGraph object_graph = tf_package->GetObjectGraph(); - auto& nodes = object_graph.nodes(); - // Get a map of the concrete functions to their input / output signatures. - auto& concrete_functions = object_graph.concrete_functions(); - auto& root = nodes.at(0); - for (auto& child : root.children()) { - // The child's name describes the name of the edge that connects to the - // parent object. This name will be the name of the object stored in the - // generated module. - auto& child_node = nodes.at(child.node_id()); - auto child_name = child.local_name().c_str(); - - if (child_node.kind_case() == tensorflow::SavedObject::kFunction) { - Function tf_function; - for (const std::string& fn_name : - child_node.function().concrete_functions()) { - // Setup input signature. - // - // For now, we have to do a lot of manual digging through these and - // assume they are tensorspecs. Once TODO(b/190203981) is done, we - // should be able to pass along the `StructuredValue`s to an API in a - // much cleaner way. - // - // TODO(b/190206621): Implement API for inspecting signatures - SavedConcreteFunction saved_concrete_function = - concrete_functions.at(fn_name); - TaggedValue input_signature = TaggedValue::Tuple(); - const RepeatedPtrField& args = - saved_concrete_function.canonicalized_input_signature() - .tuple_value() - .values(0) - .tuple_value() - .values(); - for (const StructuredValue& arg : args) { - PartialTensorShape shape = arg.tensor_spec_value().shape(); - DataType dtype = arg.tensor_spec_value().dtype(); - TaggedValue tensor_spec(shape, dtype); - input_signature.tuple().emplace_back(tensor_spec); - } - - // Setup output signature. - TensorSpecProto output_tensor_spec_proto = - saved_concrete_function.output_signature().tensor_spec_value(); - PartialTensorShape output_shape = output_tensor_spec_proto.shape(); - DataType output_dtype = output_tensor_spec_proto.dtype(); - TaggedValue output_tensor_spec(output_shape, output_dtype); - - // Register the function trace. - // - // This does *not* currently register the function with the runtime. - // Instead, we're registering JIT at call time. This is likely - // something that we'll change in TODO(b/190070277). - auto& trace = traces[fn_name]; - Status status = tf_function.RegisterTrace( - std::move(trace), input_signature, output_tensor_spec); - } - TaggedValue callable = MakeCallable(child_name, tf_function, ctx); - module.dict()[TaggedValue(child_name)] = callable; - } - } - return module; -} - -// Instantiate the Runtime, creating relevant Callables for later use. -Runtime::Runtime(AbstractContext* ctx) { - TaggedValue ctx_capsule = - TaggedValue::Capsule(static_cast(ctx), [](void* p) { - auto ctx = static_cast(p); - ctx->Release(); - }); - Set(String("ctx"), Handle(ctx_capsule)); - auto Load = [](Object self, String name) -> Object { - auto ctx_capsule = self.Get(String("ctx")).value(); - auto ctx = ctx_capsule.cast(); - // TODO(b/191689645): This needs to do error handling better. - return *Cast(Handle(*ImportModule(name, ctx))); - }; - - Set(String("Load"), Callable(TFLIB_CALLABLE_ADAPTOR(Load))); -} - -tensorflow::StatusOr Runtime::Load(const String& name) { - return Get(String("Load"))->Call(*this, name); -} - -} // namespace runtime -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/runtime/runtime.h b/tensorflow/cc/experimental/libtf/runtime/runtime.h deleted file mode 100644 index 5c3ac94fbe03c6..00000000000000 --- a/tensorflow/cc/experimental/libtf/runtime/runtime.h +++ /dev/null @@ -1,107 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_RUNTIME_RUNTIME_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_RUNTIME_RUNTIME_H_ - -#include - -#include "absl/strings/str_cat.h" -#include "absl/types/span.h" -#include "tensorflow/c/eager/c_api.h" -#include "tensorflow/c/eager/tfe_context_internal.h" -#include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/c/tf_status_internal.h" -#include "tensorflow/cc/experimental/libtf/object.h" -#include "tensorflow/cc/experimental/libtf/value.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/statusor.h" - -namespace tf { -namespace libtf { -namespace runtime { - -/// @brief A runtime object capable of loading modules and executing functions. -/// -/// It is the responsibility of the owner of the Runtime to keep it alive longer -/// than all imported modules. -class Runtime : public Object { - public: - // TODO(b/191264214): Remove need for AbstractContext - explicit Runtime(tensorflow::AbstractContext* ctx); - /// @brief Loads the module indicated by `name` and returns it. - /// - /// @param name The name of the module / file path to load - /// @return An `Object` representing the module, if successful. Otherwise, a - /// non-ok `absl::Status`. - tensorflow::StatusOr Load(const String& name); - // TODO(b/186787000): Loading a module with identically-named functions as - // a previously loaded module results in undefined behavior. This - // functionality will be supported in the future. - - // Create a host tensor and copy data into it. - // - // Raises an error if shape or dtype are incompatible with T. - // TODO(b/189458441): Update this when we decide on the representation of - // shape and dtype in this API. - // Disclaimer: This API is subject to change as we add support for creating - // device tensors b/187222691 and enable buffer re-use b/187223179. - // TODO(b/190715501): Make this available via a soft API as well. - template - tensorflow::StatusOr CreateHostTensor(absl::Span shape, - int dtype, - absl::Span data); -}; - -template -tensorflow::StatusOr Runtime::CreateHostTensor( - absl::Span shape, int dtype, absl::Span data) { - size_t num_elements = 1; - for (int dim = 0; dim < shape.size(); dim++) { - if (shape[dim] < 0) { - return tensorflow::errors::InvalidArgument(absl::StrCat( - "Shape must be fully-defined, got: shape[", dim, "] = ", shape[dim])); - } - num_elements *= shape[dim]; - } - if (data.size() != num_elements) { - return tensorflow::errors::InvalidArgument(absl::StrCat( - "Mismatched shape and data size: \n", "Shape num_elements: ", - num_elements, "\n", "Data size: ", data.size(), "\n")); - } - auto maybe_capsule = Get(String("ctx")); - if (!maybe_capsule.status().ok()) { - return maybe_capsule.status(); - } - auto capsule = maybe_capsule.value(); - auto ctx = capsule.cast(); - tensorflow::AbstractTensorPtr t( - ctx->CreateTensor(static_cast(dtype), shape)); - // TODO(srbs): This is still a weak check. Check that dtype and T are - // compatible. - if (t->ByteSize() != sizeof(T) * data.size()) { - return tensorflow::errors::InvalidArgument(absl::StrCat( - "Invalid number of bytes in data buffer\n", "Expected bytes: ", - t->ByteSize(), "\n", "Actual bytes: ", sizeof(T) * data.size())); - } - memcpy(t->Data(), data.data(), t->ByteSize()); - return Tensor(Convert(TaggedValue( - impl::TaggedValueTensor(ctx->CreateLocalHandle(t.get()), false)))); -} - -} // namespace runtime -} // namespace libtf -} // namespace tf - -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_RUNTIME_RUNTIME_H_ diff --git a/tensorflow/cc/experimental/libtf/tests/function_test.cc b/tensorflow/cc/experimental/libtf/tests/function_test.cc deleted file mode 100644 index fa1f21389df969..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/function_test.cc +++ /dev/null @@ -1,294 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/function.h" - -#include "tensorflow/c/eager/abstract_context.h" -#include "tensorflow/c/eager/abstract_function.h" -#include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/c/eager/graph_function.h" -#include "tensorflow/c/eager/unified_api_testutil.h" -#include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/cc/experimental/libtf/object.h" -#include "tensorflow/cc/experimental/libtf/value.h" -#include "tensorflow/core/framework/function.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/platform/test.h" - -namespace tf { -namespace libtf { -using tensorflow::AbstractContext; -using tensorflow::AbstractContextPtr; -using tensorflow::AbstractFunctionPtr; -using tensorflow::AbstractTensorHandle; -using tensorflow::DT_FLOAT; -using tensorflow::FunctionDef; -using tensorflow::FunctionDefHelper; -using tensorflow::PartialTensorShape; -using tensorflow::Status; -using tensorflow::StatusOr; -using tensorflow::TF_StatusPtr; -using tensorflow::tracing::graph::GraphFunction; - -class FunctionTest - : public ::testing::TestWithParam> { - public: - template - impl::TaggedValueTensor CreateScalarTensor(T val) { - AbstractTensorHandle* raw = nullptr; - Status s = TestScalarTensorHandle(ctx_.get(), val, &raw); - CHECK_EQ(tensorflow::errors::OK, s.code()) << s.message(); - return impl::TaggedValueTensor(raw, /*add_ref=*/false); - } - - bool UseTfrt() { return std::get<1>(GetParam()); } - - AbstractContextPtr ctx_; - - protected: - void SetUp() override { - // Set the tracing impl, GraphDef vs MLIR. - TF_StatusPtr status(TF_NewStatus()); - TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); - Status s = tensorflow::StatusFromTF_Status(status.get()); - CHECK_EQ(tensorflow::errors::OK, s.code()) << s.message(); - - // Set the runtime impl, Core RT vs TFRT. - AbstractContext* ctx_raw = nullptr; - s = BuildImmediateExecutionContext(UseTfrt(), &ctx_raw); - CHECK_EQ(tensorflow::errors::OK, s.code()) << s.message(); - ctx_.reset(ctx_raw); - } -}; - -// TODO(b/191361582): Use Abstract* APIs for building functions so that we can -// test with MLIR. -FunctionDef SquareFunc() { - return FunctionDefHelper::Define( - // Function Name - "SquareFunc", - // Args - {"x: float"}, - // Returns - {"y: float"}, - // Attr def - {}, - // Nodes - {{/*ret=*/{"y"}, - /*op=*/"Square", - /*arg=*/{"x"}, - /*attr=*/{{"T", DT_FLOAT}}, - /*dep=*/{}, - /*device=*/"", - /*name=*/"square"}}); -} - -FunctionDef AddFunc() { - return FunctionDefHelper::Define( - // Function Name - "AddFunc", - // Args - {"x: float", "y: float"}, - // Returns - {"z: float"}, - // Attr def - {}, - // Nodes - {{/*ret=*/{"z"}, - /*op=*/"Add", - /*arg=*/{"x", "y"}, - /*attr=*/{{"T", DT_FLOAT}}, - /*dep=*/{}, - /*device=*/"", - /*name=*/"add"}}); -} - -FunctionDef IdentityNFunc() { - return FunctionDefHelper::Define( - // Function Name - "IdentityNFunc", - // Args - {"x: float", "y: float"}, - // Returns - {"u: float", "v: float"}, - // Attr def - {}, - // Nodes - {{/*ret=*/{"u", "v"}, - /*op=*/"IdentityN", - /*arg=*/{"x", "y"}, - /*attr=*/{{"T", tensorflow::DataTypeSlice({DT_FLOAT, DT_FLOAT})}}, - /*dep=*/{}, - /*device=*/""}}); -} - -template -void ExpectEquals(AbstractTensorHandle* t, T expected) { - TF_Tensor* result_t; - Status s = tensorflow::GetValue(t, &result_t); - ASSERT_TRUE(s.ok()) << s.message(); - auto value = static_cast(TF_TensorData(result_t)); - EXPECT_EQ(*value, expected); - TF_DeleteTensor(result_t); -} - -// TODO(srbs): Add tests for captures. -// TODO(srbs): Add tests for polymorphism (different shapes and dtypes). -TEST_P(FunctionTest, Square) { - // Construct a scalar. - impl::TaggedValueTensor x = CreateScalarTensor(2.0f); - FunctionDef fdef = SquareFunc(); - AbstractFunctionPtr trace(new GraphFunction(fdef), /*add_ref=*/false); - Function tf_function; - PartialTensorShape unknown_shape; - TaggedValue signature(unknown_shape, DT_FLOAT); - Status s = tf_function.RegisterTrace(std::move(trace), signature, signature); - ASSERT_TRUE(s.ok()) << s.message(); - TaggedValue args(std::move(x)); - StatusOr v = tf_function.Execute(ctx_.get(), args); - ASSERT_TRUE(v.ok()) << v.status().message(); - const TaggedValue& result = v.value(); - AbstractTensorHandle* t = result.tensor().get(); - ExpectEquals(t, 4.0f); -} - -TEST_P(FunctionTest, Add) { - // Construct a scalar. - impl::TaggedValueTensor x = CreateScalarTensor(2.0f); - FunctionDef fdef = AddFunc(); - AbstractFunctionPtr trace(new GraphFunction(fdef), /*add_ref=*/false); - Function tf_function; - PartialTensorShape unknown_shape; - TaggedValue tensor_spec(unknown_shape, DT_FLOAT); - TaggedValue input_signature = TaggedValue::Tuple(); - input_signature.tuple().emplace_back(tensor_spec); - input_signature.tuple().emplace_back(tensor_spec); - Status s = - tf_function.RegisterTrace(std::move(trace), input_signature, tensor_spec); - ASSERT_TRUE(s.ok()) << s.message(); - TaggedValue args = TaggedValue::Tuple(); - args.tuple().emplace_back(TaggedValue(x)); - args.tuple().emplace_back(TaggedValue(x)); - StatusOr v = tf_function.Execute(ctx_.get(), args); - ASSERT_TRUE(v.ok()) << v.status().message(); - const TaggedValue& result = v.value(); - ExpectEquals(result.tensor().get(), 4.0f); -} - -TEST_P(FunctionTest, IdentityN) { - impl::TaggedValueTensor x = CreateScalarTensor(2.0f); - impl::TaggedValueTensor y = CreateScalarTensor(4.0f); - FunctionDef fdef = IdentityNFunc(); - AbstractFunctionPtr trace(new GraphFunction(fdef), /*add_ref=*/false); - Function tf_function; - PartialTensorShape unknown_shape; - TaggedValue tensor_spec(unknown_shape, DT_FLOAT); - TaggedValue signature = TaggedValue::Tuple(); - signature.tuple().emplace_back(tensor_spec); - signature.tuple().emplace_back(tensor_spec); - Status s = tf_function.RegisterTrace(std::move(trace), signature, signature); - ASSERT_TRUE(s.ok()) << s.message(); - TaggedValue args = TaggedValue::Tuple(); - args.tuple().emplace_back(TaggedValue(x)); - args.tuple().emplace_back(TaggedValue(y)); - StatusOr v = tf_function.Execute(ctx_.get(), args); - ASSERT_TRUE(v.ok()) << v.status().message(); - const TaggedValue& result = v.value(); - ExpectEquals(result.tuple()[0].tensor().get(), 2.0f); - ExpectEquals(result.tuple()[1].tensor().get(), 4.0f); -} - -TEST_P(FunctionTest, UnaryFuncCalledWithMultipleArgsFails) { - // Construct a scalar. - impl::TaggedValueTensor x = CreateScalarTensor(2.0f); - FunctionDef fdef = SquareFunc(); - AbstractFunctionPtr trace(new GraphFunction(fdef), /*add_ref=*/false); - Function tf_function; - PartialTensorShape unknown_shape; - TaggedValue signature(unknown_shape, DT_FLOAT); - Status s = tf_function.RegisterTrace(std::move(trace), signature, signature); - ASSERT_TRUE(s.ok()) << s.message(); - TaggedValue args = TaggedValue::Tuple(); - args.tuple().emplace_back(TaggedValue(x)); - args.tuple().emplace_back(TaggedValue(x)); - StatusOr v = tf_function.Execute(ctx_.get(), args); - ASSERT_TRUE(tensorflow::errors::IsInvalidArgument(v.status())); - ASSERT_TRUE(absl::StrContains(v.status().message(), "No match")); -} - -TEST_P(FunctionTest, IncorrectArityOfOutputSignatureFails) { - if (UseTfrt()) { - GTEST_SKIP() << "TFRT crashes if expected number of output tensors does not" - " match actual."; - } - impl::TaggedValueTensor x = CreateScalarTensor(2.0f); - impl::TaggedValueTensor y = CreateScalarTensor(4.0f); - FunctionDef fdef = IdentityNFunc(); - AbstractFunctionPtr trace(new GraphFunction(fdef), /*add_ref=*/false); - Function tf_function; - PartialTensorShape unknown_shape; - TaggedValue tensor_spec(unknown_shape, DT_FLOAT); - TaggedValue input_signature = TaggedValue::Tuple(); - input_signature.tuple().emplace_back(tensor_spec); - input_signature.tuple().emplace_back(tensor_spec); - // This is wrong! - TaggedValue output_signature(unknown_shape, DT_FLOAT); - Status s = tf_function.RegisterTrace(std::move(trace), input_signature, - output_signature); - ASSERT_TRUE(s.ok()) << s.message(); - TaggedValue args = TaggedValue::Tuple(); - args.tuple().emplace_back(TaggedValue(x)); - args.tuple().emplace_back(TaggedValue(y)); - StatusOr v = tf_function.Execute(ctx_.get(), args); - ASSERT_TRUE(tensorflow::errors::IsInvalidArgument(v.status())) << v.status(); - ASSERT_TRUE(absl::StrContains(v.status().message(), - "Expecting 2 outputs, but *num_retvals is 1")); -} - -TEST_P(FunctionTest, IncorrectDtypeInOutputSignatureFails) { - // Construct a scalar. - impl::TaggedValueTensor x = CreateScalarTensor(2.0f); - FunctionDef fdef = AddFunc(); - AbstractFunctionPtr trace(new GraphFunction(fdef), /*add_ref=*/false); - Function tf_function; - PartialTensorShape unknown_shape; - TaggedValue input_tensor_spec(unknown_shape, tensorflow::DT_FLOAT); - TaggedValue input_signature = TaggedValue::Tuple(); - input_signature.tuple().emplace_back(input_tensor_spec); - input_signature.tuple().emplace_back(input_tensor_spec); - // Incorrect type. - TaggedValue output_tensor_spec(unknown_shape, tensorflow::DT_INT64); - Status s = tf_function.RegisterTrace(std::move(trace), input_signature, - output_tensor_spec); - ASSERT_TRUE(s.ok()) << s.message(); - TaggedValue args = TaggedValue::Tuple(); - args.tuple().emplace_back(TaggedValue(x)); - args.tuple().emplace_back(TaggedValue(x)); - StatusOr v = tf_function.Execute(ctx_.get(), args); - ASSERT_TRUE(tensorflow::errors::IsInternal(v.status())) << v.status(); - ASSERT_TRUE( - absl::StrContains(v.status().message(), "Shape and dtype of tensor")); - ASSERT_TRUE(absl::StrContains(v.status().message(), - "does not match that in signature")); -} - -INSTANTIATE_TEST_SUITE_P(TF2CAPI, FunctionTest, - ::testing::Combine(::testing::Values("graphdef", - "mlir"), - ::testing::Values(false))); - -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/tests/generate_testdata.py b/tensorflow/cc/experimental/libtf/tests/generate_testdata.py deleted file mode 100644 index 09b84399a00e2f..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/generate_testdata.py +++ /dev/null @@ -1,105 +0,0 @@ -# /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ==============================================================================*/ -r"""Generate models in testdata for use in tests. - -If this script is being run via ` run`, pass an absolute path. -Otherwise, this script will attempt to write to a non-writable directory. - -Example: - run //third_party/tensorflow/cc/experimental/libtf:generate_testdata - -- \ - --path`pwd`/third_party/tensorflow/cc/experimental/libtf/tests/testdata/ \ - --model_name=simple-model -""" -import os - -from absl import app -from absl import flags - -from tensorflow.python.compat import v2_compat -from tensorflow.python.eager import def_function -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import tensor_spec -from tensorflow.python.module import module -from tensorflow.python.ops import variables -from tensorflow.python.saved_model import saved_model - -TESTDATA_PATH = flags.DEFINE_string( - "path", None, help="Path to testdata directory.") - -MODEL_NAME = flags.DEFINE_string( - "model_name", None, help="Name of model to generate.") - - -class DataStructureModel(module.Module): - """Model used for testing data structures in the C++ API.""" - - def __init__(self): - self.arr1 = [1.] - self.const_arr = [constant_op.constant(1.)] - self.var_arr = [variables.Variable(1.), variables.Variable(2.)] - self.dict1 = {"a": 1.} - self.var_dict = {"a": variables.Variable(1.), "b": variables.Variable(2.)} - - -class SimpleModel(module.Module): - """A simple model used for exercising the C++ API.""" - - @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32), - ]) - def test_float(self, x): - return constant_op.constant(3.0) * x - - @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=(), dtype=dtypes.int32), - ]) - def test_int(self, x): - return constant_op.constant(3) * x - - @def_function.function(input_signature=[ - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32), - tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32), - ]) - def test_add(self, x, y): - # Test a function with multiple arguments. - return x + y - - -TEST_MODELS = { - "simple-model": SimpleModel, - "data-structure-model": DataStructureModel -} - - -def get_model(name): - if name not in TEST_MODELS: - raise ValueError("Model name '{}' not in TEST_MODELS") - return TEST_MODELS[name]() - - -def main(unused_argv): - - model = get_model(MODEL_NAME.value) - path = os.path.join(TESTDATA_PATH.value, MODEL_NAME.value) - saved_model.save(model, path) - - -if __name__ == "__main__": - v2_compat.enable_v2_behavior() - flags.mark_flag_as_required("path") - flags.mark_flag_as_required("model_name") - app.run(main) diff --git a/tensorflow/cc/experimental/libtf/tests/mlir_transform_test.cc b/tensorflow/cc/experimental/libtf/tests/mlir_transform_test.cc deleted file mode 100644 index 897b1235821e49..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/mlir_transform_test.cc +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/mlir/mlir_transform.h" - -#include -#include - -#include "tensorflow/cc/experimental/libtf/object.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/resource_loader.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/platform/test.h" - -namespace tf { -namespace libtf { - -TEST(TransformTest, LoadSavedModel) { - Object mlir = MLIR(); - TF_ASSERT_OK_AND_ASSIGN(Callable load, - mlir.Get(String("LoadSavedModel"))); - - TF_ASSERT_OK_AND_ASSIGN( - Handle model_bad, - load.Call(mlir, String("/error/doesnotexist___31284382"))); - TF_ASSERT_OK(Cast(model_bad).status()); - - const std::string model_good_path = tensorflow::GetDataDependencyFilepath( - "tensorflow/cc/experimental/libtf/tests/testdata/simple-model"); - - TF_ASSERT_OK_AND_ASSIGN( - Object model_good, - load.Call(mlir, String(model_good_path.c_str()))); - - TF_ASSERT_OK_AND_ASSIGN(Callable to_string, - model_good.Get(String("ToString"))); - - TF_ASSERT_OK_AND_ASSIGN(String s, to_string.Call(model_good)); - - ASSERT_GT(strlen(s.get()), 0); -} - -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/tests/module_test.cc b/tensorflow/cc/experimental/libtf/tests/module_test.cc deleted file mode 100644 index 78620846c59aee..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/module_test.cc +++ /dev/null @@ -1,135 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/module.h" - -#include - -#include "tensorflow/cc/experimental/libtf/runtime/core/core.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/resource_loader.h" -#include "tensorflow/core/platform/status_matchers.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/protobuf/error_codes.pb.h" -#include "tensorflow/core/protobuf/saved_object_graph.pb.h" - -namespace tf { -namespace libtf { -namespace impl { - -using ::tensorflow::libexport::TFPackage; -using ::tensorflow::testing::StatusIs; -using ::tf::libtf::runtime::Runtime; - -TEST(ModuleTest, TestStubbedFunctions) { - Runtime runtime = runtime::core::Runtime(); - TFPackage tf_package; - tensorflow::StatusOr result = BuildProgram(runtime, tf_package); - ASSERT_FALSE(result.status().ok()); -} - -TEST(ModuleTest, TestBuildObjectsDataStructures) { - const std::string path = tensorflow::GetDataDependencyFilepath( - "tensorflow/cc/experimental/libtf/tests/testdata/data-structure-model"); - TF_ASSERT_OK_AND_ASSIGN(TFPackage tf_package, TFPackage::Load(path)); - - TF_ASSERT_OK_AND_ASSIGN(std::vector objects, - BuildObjects(tf_package)); - EXPECT_EQ(objects.size(), 7); - // The first node of data-structure-model is a dictionary. - TF_ASSERT_OK_AND_ASSIGN(tf::libtf::Dictionary node, - Cast(objects.front())); - - // The next three nodes of data-structure-model are lists. - for (unsigned int i = 1; i < 4; i++) { - TF_ASSERT_OK_AND_ASSIGN(tf::libtf::List node, - Cast(objects.at(i))); - } - // The last three nodes of data-structure-model are dictionaries. - for (unsigned int i = 4; i < 7; i++) { - TF_ASSERT_OK_AND_ASSIGN(tf::libtf::Dictionary node, - Cast(objects.at(i))); - } -} - -TEST(ModuleTest, TestBuildEmptyList) { - tensorflow::SavedObject saved_object_proto; - const std::string pb_txt = R"pb( - user_object { - identifier: "trackable_list_wrapper" - version { producer: 1 min_consumer: 1 } - } - )pb"; - - ASSERT_TRUE(::tensorflow::protobuf::TextFormat::ParseFromString( - pb_txt, &saved_object_proto)); - TF_ASSERT_OK_AND_ASSIGN(Handle result, - BuildSavedUserObject(saved_object_proto)); - EXPECT_EQ(Cast(result)->size(), 0); -} - -TEST(ModuleTest, TestBuildEmptyDict) { - tensorflow::SavedObject saved_object_proto; - const std::string pb_txt = R"pb( - user_object { - identifier: "trackable_dict_wrapper" - version { producer: 1 min_consumer: 1 } - } - )pb"; - - ASSERT_TRUE(::tensorflow::protobuf::TextFormat::ParseFromString( - pb_txt, &saved_object_proto)); - - TF_ASSERT_OK_AND_ASSIGN(Handle result, - BuildSavedUserObject(saved_object_proto)); - EXPECT_EQ(Cast(result)->size(), 0); -} - -TEST(ModuleTest, TestBuildSignatureMap) { - tensorflow::SavedObject saved_object_proto; - const std::string pb_txt = R"pb( - user_object { - identifier: "signature_map" - version { producer: 1 min_consumer: 1 } - } - )pb"; - - ASSERT_TRUE(::tensorflow::protobuf::TextFormat::ParseFromString( - pb_txt, &saved_object_proto)); - TF_ASSERT_OK_AND_ASSIGN(Handle result, - BuildSavedUserObject(saved_object_proto)); - EXPECT_EQ(Cast(result)->size(), 0); -} - -TEST(ModuleTest, TestUnimplementedUserObject) { - tensorflow::SavedObject saved_object_proto; - const std::string pb_txt = R"pb( - user_object { - identifier: "foo" - version { producer: 1 min_consumer: 1 } - } - )pb"; - - ASSERT_TRUE(::tensorflow::protobuf::TextFormat::ParseFromString( - pb_txt, &saved_object_proto)); - - EXPECT_THAT( - BuildSavedUserObject(saved_object_proto), - StatusIs(tensorflow::error::UNIMPLEMENTED, ::testing::HasSubstr("foo"))); -} - -} // namespace impl -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/tests/object_test.cc b/tensorflow/cc/experimental/libtf/tests/object_test.cc deleted file mode 100644 index dd0916facf9984..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/object_test.cc +++ /dev/null @@ -1,184 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/object.h" - -#include - -#include "tensorflow/cc/experimental/libtf/value.h" -#include "tensorflow/cc/experimental/libtf/value_iostream.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/platform/test.h" - -namespace tf { -namespace libtf { - -TEST(ObjectTest, TestDictionary) { - Dictionary foo; - foo.Set(String("a"), Integer(33)); - foo.Set(String("b"), Integer(33)); - EXPECT_EQ(foo.Get(String("b"))->get(), 33); -} - -TEST(ObjectTest, TestTuple) { - Tuple foo(String("a"), Integer(33), Float(10.f)); - EXPECT_EQ(foo.size(), 3); - EXPECT_EQ(foo.Get(1)->get(), 33); -} - -TEST(ObjectTest, TestList) { - List l; - EXPECT_EQ(l.size(), 0); - l.append(Integer(3)); - EXPECT_EQ(l.Get(0)->get(), 3); - EXPECT_EQ(l.size(), 1); -} - -TaggedValue AddIntegers(TaggedValue args_, TaggedValue kwargs_) { - auto& args = args_.tuple(); - // auto& kwargs = kwargs_.dict(); - return TaggedValue(args[0].i64() + args[1].i64()); -} - -TEST(ObjectTest, TestCast) { - Integer i(3); - auto result = Cast(i); - ASSERT_TRUE(!result.ok()); -} - -TEST(ObjectTest, TestCall) { - TaggedValue add_func(AddIntegers); - Callable add(add_func); - TF_ASSERT_OK_AND_ASSIGN(Integer i, - add.Call(Integer(1), Integer(10))); - EXPECT_EQ(i.get(), 11); - - TF_ASSERT_OK_AND_ASSIGN( - Integer i2, add.Call(1, Integer(10), KeywordArg("foo") = 3)); - EXPECT_EQ(i2.get(), 11); -} - -TEST(ObjectTest, MakeObject) { - // TaggedValue func(f); - Object parent; - parent.Set(String("test3"), Integer(3)); - Object child; - child.Set(String("test1"), Integer(1)); - child.Set(String("test2"), Integer(2)); - child.Set(Object::ParentKey(), parent); - EXPECT_EQ(child.Get(String("test1"))->get(), 1); - EXPECT_EQ(child.Get(String("test2"))->get(), 2); - EXPECT_EQ(child.Get(String("test3"))->get(), 3); - ASSERT_FALSE(child.Get(String("test4")).status().ok()); - TF_ASSERT_OK(child.Get(String("test3")).status()); -} - -TEST(ObjectTest, CallFunctionOnObject) { - Object module; - module.Set(String("add"), Callable(TaggedValue(AddIntegers))); - TF_ASSERT_OK_AND_ASSIGN(Callable method, module.Get(String("add"))); - - TF_ASSERT_OK_AND_ASSIGN(Integer val, method.Call(1, 2)); - EXPECT_EQ(val.get(), 3); -} - -TEST(ObjectTest, Capsule) { - Object obj; - int* hundred = new int(100); - Handle capsule = - Handle(TaggedValue::Capsule(static_cast(hundred), [](void* p) { - delete static_cast(p); - })); - obj.Set(String("hundred"), capsule); - EXPECT_EQ(*static_cast( - obj.Get(String("hundred"))->cast()), - 100); -} - -None AppendIntegerToList(List a, Integer b) { - a.append(b); - return None(); -} -Integer AddIntegersTyped(Integer a, Integer b) { - return Integer(a.get() + b.get()); -} -Integer ReturnFive() { return Integer(5); } - -TEST(TypeUneraseCallTest, TestCallable) { - // Add two integers. - Callable add(TFLIB_CALLABLE_ADAPTOR(AddIntegersTyped)); - auto res = add.Call(Integer(3), Integer(1)); - EXPECT_EQ(res->get(), 4); -} - -TEST(TypeUneraseCallTest, TestAppend) { - // Append some indices to a list. - Callable append(TFLIB_CALLABLE_ADAPTOR(AppendIntegerToList)); - List l; - TF_ASSERT_OK(append.Call(l, Integer(3)).status()); - TF_ASSERT_OK(append.Call(l, Integer(6)).status()); - EXPECT_EQ(l.size(), 2); - EXPECT_EQ(l.Get(0)->get(), 3); - EXPECT_EQ(l.Get(1)->get(), 6); -} - -TEST(TypeUneraseCallTest, TestCallableWrongArgs) { - // Try variants of wrong argument types. - Callable append(TFLIB_CALLABLE_ADAPTOR(AddIntegersTyped)); - ASSERT_FALSE(append.Call(Object(), Integer(3)).ok()); - ASSERT_FALSE(append.Call(Object(), Object()).ok()); - // Try variants of wrong numbers of arguments. - ASSERT_FALSE(append.Call().ok()); - ASSERT_FALSE(append.Call(Integer(3)).ok()); - ASSERT_FALSE(append.Call(Integer(3), Integer(4), Integer(5)).ok()); -} - -Handle Polymorph(Handle a) { - auto i = Cast(a); - if (i.ok()) { - return Integer(i->get() * 2); - } - auto f = Cast(a); - if (f.ok()) { - return Float(f->get() * 2.f); - } - return None(); -} - -TEST(TypeUneraseCallTest, TestCallableGeneric) { - Callable f(TFLIB_CALLABLE_ADAPTOR(Polymorph)); - EXPECT_EQ(f.Call(Float(.2))->get(), .4f); - EXPECT_EQ(Cast(*f.Call(Float(.2)))->get(), .4f); - EXPECT_EQ(f.Call(Integer(3))->get(), 6); -} - -TEST(TypeUneraseCallTest, TestLambda) { - // Test a trivial lambda that doubles an integer. - Callable c( - TFLIB_CALLABLE_ADAPTOR([](Integer a) { return Integer(a.get() * 2); })); - EXPECT_EQ(c.Call(Integer(3))->get(), 6); - // Testa lambda that has captured state (call count). - int call_count = 0; - Callable f(TFLIB_CALLABLE_ADAPTOR([&call_count](Integer a, Integer b) { - call_count++; - return Integer(a.get() + b.get()); - })); - EXPECT_EQ(f.Call(Integer(3), Integer(-1))->get(), 2); - EXPECT_EQ(f.Call(Integer(3), Integer(-3))->get(), 0); - EXPECT_EQ(call_count, 2); -} - -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/tests/perf_test.cc b/tensorflow/cc/experimental/libtf/tests/perf_test.cc deleted file mode 100644 index 3c40ac0438e77a..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/perf_test.cc +++ /dev/null @@ -1,99 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include - -#include "tensorflow/cc/experimental/libtf/object.h" -#include "tensorflow/cc/experimental/libtf/value.h" -#include "tensorflow/cc/experimental/libtf/value_iostream.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/test_benchmark.h" - -namespace tf { -namespace libtf { - -namespace { - -// AddTagged using tagged values -TaggedValue AddTagged(TaggedValue args, TaggedValue kwargs) { - return TaggedValue(args.tuple()[0].i64() + args.tuple()[1].i64()); -} - -int64_t AddRaw(int64_t a, int64_t b) { return a + b; } - -} // namespace - -// Add numbers in a loop by calling a callable. -void CallFunctions(::testing::benchmark::State& state) { - Integer sum(0); - Callable callable((impl::TaggedValue(impl::Func(AddTagged)))); - *callable.Call(sum, Integer(30)); - size_t i = 0; - for (auto dummy : state) { - sum = *callable.Call(sum, Integer(i)); - i++; - } -} - -// Add numbers in a loop by calling a callable, looking up method every -// time by tokenized string. -void CallFunctionsIndirect(::testing::benchmark::State& state) { - Integer sum(0); - Callable callable((impl::TaggedValue(impl::Func(AddTagged)))); - Object o; - String name("f"); - o.Set(name, callable); - size_t i = 0; - for (auto dummy : state) { - sum = *(o.Get(name))->Call(sum, Integer(i)); - i++; - } -} - -// Add numbers in a loop by calling a callable, looking up method every -// time by non-tokenized string. -void CallFunctionsIndirectNaive(::testing::benchmark::State& state) { - Integer sum(0); - Callable callable((impl::TaggedValue(impl::Func(AddTagged)))); - Object o; - o.Set(String("f"), callable); - size_t i = 0; - for (auto dummy : state) { - sum = *(o.Get(String("f")))->Call(sum, Integer(i)); - i++; - } -} - -// Add numbers in a loop by calling a raw C++ function with a function -// pointer. -void CallFunctionsBase(::testing::benchmark::State& state) { - int64_t sum = 0; - typedef int64_t (*Func)(int64_t a, int64_t b); - volatile Func f_raw = AddRaw; - Func f = f_raw; - size_t i = 0; - for (auto dummy : state) { - sum = f(sum, i); - i++; - } - // volatile int64_t result = sum; -} - -BENCHMARK(CallFunctions)->Arg(1 << 10); -BENCHMARK(CallFunctionsIndirect)->Arg(1 << 10); -BENCHMARK(CallFunctionsIndirectNaive)->Arg(1 << 10); -BENCHMARK(CallFunctionsBase)->Arg(1 << 10); - -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/tests/runtime_test.cc b/tensorflow/cc/experimental/libtf/tests/runtime_test.cc deleted file mode 100644 index 3610b8a964648b..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/runtime_test.cc +++ /dev/null @@ -1,131 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/tests/runtime_test.h" - -namespace tf { -namespace libtf { -namespace runtime { - -using ::tensorflow::testing::StatusIs; -using ::testing::HasSubstr; -using ::tf::libtf::impl::TaggedValueTensor; - -constexpr char kSimpleModel[] = - "tensorflow/cc/experimental/libtf/tests/testdata/simple-model"; - -TEST_P(RuntimeTest, SimpleModelCallableFloatTest) { - Runtime runtime = RuntimeTest::GetParam()(); - - // Import the module and grab the callable - const std::string module_path = - tensorflow::GetDataDependencyFilepath(kSimpleModel); - - TF_ASSERT_OK_AND_ASSIGN(Object module, - runtime.Load(String(module_path.c_str()))); - std::cout << "Module imported." << std::endl; - - TF_ASSERT_OK_AND_ASSIGN(Callable fn, - module.Get(String("test_float"))); - TF_ASSERT_OK_AND_ASSIGN( - Tensor tensor, runtime.CreateHostTensor({}, TF_FLOAT, {2.0f})); - TF_ASSERT_OK_AND_ASSIGN(Tensor result, fn.Call(Tensor(tensor))); - - float out_val[1]; - TF_ASSERT_OK(result.GetValue(absl::MakeSpan(out_val))); - EXPECT_EQ(out_val[0], 6.0); -} - -TEST_P(RuntimeTest, SimpleModelCallableIntTest) { - Runtime runtime = RuntimeTest::GetParam()(); - - // Import the module and grab the callable - const std::string module_path = - tensorflow::GetDataDependencyFilepath(kSimpleModel); - TF_ASSERT_OK_AND_ASSIGN(Object module, - runtime.Load(String(module_path.c_str()))); - - TF_ASSERT_OK_AND_ASSIGN(Callable fn, - module.Get(String("test_int"))); - - // Call the function - TF_ASSERT_OK_AND_ASSIGN(Tensor host_tensor, - runtime.CreateHostTensor({}, TF_INT32, {2})); - - TF_ASSERT_OK_AND_ASSIGN(Tensor tensor, fn.Call(Tensor(host_tensor))); - - int out_val[1]; - TF_ASSERT_OK(tensor.GetValue(absl::MakeSpan(out_val))); - EXPECT_EQ(out_val[0], 6); -} - -TEST_P(RuntimeTest, SimpleModelCallableMultipleArgsTest) { - Runtime runtime = RuntimeTest::GetParam()(); - - // Import the module and grab the callable - const std::string module_path = - tensorflow::GetDataDependencyFilepath(kSimpleModel); - TF_ASSERT_OK_AND_ASSIGN(Object module, - runtime.Load(String(module_path.c_str()))); - - TF_ASSERT_OK_AND_ASSIGN(Callable fn, - module.Get(String("test_add"))); - - TF_ASSERT_OK_AND_ASSIGN(Tensor tensor1, - runtime.CreateHostTensor({}, TF_FLOAT, {2.0f})) - TF_ASSERT_OK_AND_ASSIGN(Tensor tensor2, - runtime.CreateHostTensor({}, TF_FLOAT, {3.0f})) - - TF_ASSERT_OK_AND_ASSIGN(Tensor result_tensor, - fn.Call(tensor1, tensor2)); - float out_val[1]; - TF_ASSERT_OK(result_tensor.GetValue(absl::MakeSpan(out_val))); - EXPECT_EQ(out_val[0], 5.0f); -} - -TEST_P(RuntimeTest, CreateHostTensorIncompatibleShape) { - Runtime runtime = RuntimeTest::GetParam()(); - EXPECT_THAT(runtime.CreateHostTensor({2}, TF_FLOAT, {2.0f}), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Mismatched shape and data size"))); -} - -TEST_P(RuntimeTest, CreateHostTensorNonFullyDefinedShapeRaises) { - Runtime runtime = RuntimeTest::GetParam()(); - EXPECT_THAT(runtime.CreateHostTensor({-1}, TF_FLOAT, {2.0f}), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Shape must be fully-defined"))); -} - -TEST_P(RuntimeTest, CreateHostTensorIncompatibleDataType) { - Runtime runtime = RuntimeTest::GetParam()(); - EXPECT_THAT(runtime.CreateHostTensor({1}, TF_BOOL, {2.0f}), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Invalid number of bytes in data buffer"))); -} - -TEST_P(RuntimeTest, TensorCopyInvalidSize) { - Runtime runtime = RuntimeTest::GetParam()(); - TF_ASSERT_OK_AND_ASSIGN( - Tensor tensor, runtime.CreateHostTensor({1}, TF_FLOAT, {2.0f})) - float val[2]; - - EXPECT_THAT(tensor.GetValue(absl::MakeSpan(val)), - StatusIs(absl::StatusCode::kInvalidArgument, - HasSubstr("Mismatched number of elements"))); -} - -} // namespace runtime -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/tests/runtime_test.h b/tensorflow/cc/experimental/libtf/tests/runtime_test.h deleted file mode 100644 index 3ae665c663b784..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/runtime_test.h +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_TESTS_RUNTIME_TEST_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_TESTS_RUNTIME_TEST_H_ - -#include "absl/status/status.h" -#include "absl/strings/match.h" -#include "tensorflow/c/eager/unified_api_testutil.h" -#include "tensorflow/c/tf_datatype.h" -#include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/cc/experimental/libtf/object.h" -#include "tensorflow/cc/experimental/libtf/runtime/runtime.h" -#include "tensorflow/cc/experimental/libtf/value.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/resource_loader.h" -#include "tensorflow/core/platform/status_matchers.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/platform/test.h" - -namespace tf { -namespace libtf { -namespace runtime { - -typedef Runtime (*RuntimeFn)(); - -class RuntimeTest : public ::testing::TestWithParam {}; - -} // namespace runtime -} // namespace libtf -} // namespace tf - -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_TESTS_RUNTIME_TEST_H_ diff --git a/tensorflow/cc/experimental/libtf/tests/runtime_test_core.cc b/tensorflow/cc/experimental/libtf/tests/runtime_test_core.cc deleted file mode 100644 index 599520025229f1..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/runtime_test_core.cc +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/runtime/core/core.h" -#include "tensorflow/cc/experimental/libtf/tests/runtime_test.h" - -namespace tf { -namespace libtf { -namespace runtime { - -INSTANTIATE_TEST_SUITE_P(TF2CAPI, RuntimeTest, - ::testing::Values(core::Runtime)); -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(RuntimeTest); -} // namespace runtime -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/tests/tensor_test.cc b/tensorflow/cc/experimental/libtf/tests/tensor_test.cc deleted file mode 100644 index 85243dd428775f..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/tensor_test.cc +++ /dev/null @@ -1,129 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include - -#include "tensorflow/c/eager/abstract_context.h" -#include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/c/eager/c_api_unified_experimental.h" -#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" -#include "tensorflow/c/eager/unified_api_testutil.h" -#include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/cc/experimental/libtf/object.h" -#include "tensorflow/cc/experimental/libtf/value.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/test.h" - -namespace tf { -namespace libtf { - -using AbstractContextPtr = tensorflow::AbstractContextPtr; -using AbstractContext = tensorflow::AbstractContext; -using AbstractTensorHandle = tensorflow::AbstractTensorHandle; -using TF_StatusPtr = tensorflow::TF_StatusPtr; -using Status = tensorflow::Status; - -class UnifiedCAPI - : public ::testing::TestWithParam> { - protected: - void SetUp() override { - TF_StatusPtr status(TF_NewStatus()); - TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); - Status s = tensorflow::StatusFromTF_Status(status.get()); - CHECK_EQ(tensorflow::errors::OK, s.code()) << s.message(); - } -}; - -namespace { -template -TaggedValue MakeContext(T runtime) { - AbstractContext* ctx_raw = nullptr; - Status s = BuildImmediateExecutionContext(runtime, &ctx_raw); - // ASSERT_EQ(tensorflow::errors::OK, s.code()) << s.message(); - return TaggedValue::Capsule(static_cast(ctx_raw), [](void* p) { - tensorflow::internal::AbstractContextDeleter()( - static_cast(p)); - }); -} -} // namespace - -TEST_P(UnifiedCAPI, HoldTensors) { - // Use the parametrized test parameters to make a context. - AbstractContextPtr ctx; - { - AbstractContext* ctx_raw = nullptr; - Status s = - BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); - ASSERT_EQ(tensorflow::errors::OK, s.code()) << s.message(); - ctx.reset(ctx_raw); - } - - // Construct a scalar. - impl::TaggedValueTensor x; - { - AbstractTensorHandle* x_raw = nullptr; - Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw); - ASSERT_EQ(tensorflow::errors::OK, s.code()) << s.message(); - x.reset(x_raw, false); - } - // Manually copy pointer so we can later compare the reference count. - impl::TaggedValueTensor x2(x); - - { - // Take ownership of x2 pointer. Semantics of AbstractTensorHandlePtr - // are that it has a reference. Here we steal that reference and put it - // into TaggedValue. If we used release() we would double free. - impl::TaggedValue tensor(std::move(x2)); - auto list = TaggedValue::List(); - // Test adding values by copying and moving. - list.list().emplace_back(3.f); - list.list().push_back(tensor); - list.list().emplace_back(std::move(tensor)); - ASSERT_FALSE(x->RefCountIsOne()); - } - ASSERT_TRUE(x->RefCountIsOne()); -} - -TaggedValue MakeScalarTensor(TaggedValue self, TaggedValue val) { - if (val.type() != TaggedValue::FLOAT32) return TaggedValue::None(); - if (self.type() != TaggedValue::DICT) return TaggedValue::None(); - TaggedValue ctx_capsule = (self.dict())[TaggedValue("context")]; - AbstractContext* ctx = static_cast(ctx_capsule.capsule()); - AbstractTensorHandle* x_raw = nullptr; - Status s = - TestScalarTensorHandle(ctx, val.f32().get(), &x_raw); - if (!s.ok()) return TaggedValue::None(); - return TaggedValue(impl::TaggedValueTensor(x_raw, false)); -} -TEST_P(UnifiedCAPI, SimpleCreationFunctions) { - // Use the parametrized test parameters to make a context. - TaggedValue context = MakeContext(std::get<1>(GetParam())); - Object methods; - methods.Set(String("context"), Handle(MakeContext(std::get<1>(GetParam())))); - methods.Set(String("make_scalar_tensor"), - Callable(TaggedValue(MakeScalarTensor))); - - Handle foo = *methods.Get(String("make_scalar_tensor")) - ->Call(methods, Float(3.f)); -} - -INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, - ::testing::Combine(::testing::Values("graphdef", - "mlir"), - ::testing::Values(false))); - -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/tests/testdata/README b/tensorflow/cc/experimental/libtf/tests/testdata/README deleted file mode 100644 index 84ad79dac73564..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/testdata/README +++ /dev/null @@ -1,2 +0,0 @@ -The models in this directory are generated using -//third_party/tensorflow/cc/experimental/libtf/tests:generate_testdata \ No newline at end of file diff --git a/tensorflow/cc/experimental/libtf/tests/testdata/data-structure-model/saved_model.pb b/tensorflow/cc/experimental/libtf/tests/testdata/data-structure-model/saved_model.pb deleted file mode 100644 index 60e1a6028942d3..00000000000000 Binary files a/tensorflow/cc/experimental/libtf/tests/testdata/data-structure-model/saved_model.pb and /dev/null differ diff --git a/tensorflow/cc/experimental/libtf/tests/testdata/data-structure-model/variables/variables.data-00000-of-00001 b/tensorflow/cc/experimental/libtf/tests/testdata/data-structure-model/variables/variables.data-00000-of-00001 deleted file mode 100644 index f21dbb4945eb0c..00000000000000 Binary files a/tensorflow/cc/experimental/libtf/tests/testdata/data-structure-model/variables/variables.data-00000-of-00001 and /dev/null differ diff --git a/tensorflow/cc/experimental/libtf/tests/testdata/data-structure-model/variables/variables.index b/tensorflow/cc/experimental/libtf/tests/testdata/data-structure-model/variables/variables.index deleted file mode 100644 index 52b3dfd35d92cd..00000000000000 Binary files a/tensorflow/cc/experimental/libtf/tests/testdata/data-structure-model/variables/variables.index and /dev/null differ diff --git a/tensorflow/cc/experimental/libtf/tests/testdata/simple-model/saved_model.pb b/tensorflow/cc/experimental/libtf/tests/testdata/simple-model/saved_model.pb deleted file mode 100644 index db4d863337e783..00000000000000 Binary files a/tensorflow/cc/experimental/libtf/tests/testdata/simple-model/saved_model.pb and /dev/null differ diff --git a/tensorflow/cc/experimental/libtf/tests/testdata/simple-model/variables/variables.data-00000-of-00001 b/tensorflow/cc/experimental/libtf/tests/testdata/simple-model/variables/variables.data-00000-of-00001 deleted file mode 100644 index 48e4a89a932e88..00000000000000 Binary files a/tensorflow/cc/experimental/libtf/tests/testdata/simple-model/variables/variables.data-00000-of-00001 and /dev/null differ diff --git a/tensorflow/cc/experimental/libtf/tests/testdata/simple-model/variables/variables.index b/tensorflow/cc/experimental/libtf/tests/testdata/simple-model/variables/variables.index deleted file mode 100644 index 585d3e16bbd336..00000000000000 Binary files a/tensorflow/cc/experimental/libtf/tests/testdata/simple-model/variables/variables.index and /dev/null differ diff --git a/tensorflow/cc/experimental/libtf/tests/value_test.cc b/tensorflow/cc/experimental/libtf/tests/value_test.cc deleted file mode 100644 index 32301d9fa2e0ef..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/value_test.cc +++ /dev/null @@ -1,114 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/value.h" - -#include - -#include "tensorflow/cc/experimental/libtf/value_iostream.h" -#include "tensorflow/core/platform/test.h" - -namespace tf { -namespace libtf { -namespace impl { - -TEST(ValueTest, TestBasic) { - TaggedValue valuef(3.f); - TaggedValue valuei(int64_t(3)); - TaggedValue list = TaggedValue::List(); - TaggedValue tuple = TaggedValue::Tuple(); - tuple.tuple().push_back(TaggedValue(int64_t(310))); - list.list().push_back(valuei); - list.list().push_back(valuef); - list.list().push_back(tuple); - std::stringstream stream; - stream << list; - ASSERT_EQ(stream.str(), "[3, 3, (310, ), ]"); -} - -TEST(ValueTest, TestString) { - TaggedValue value1a("string1"); - std::string s = "string"; - s += "1"; - TaggedValue value1b(s.c_str()); - // Verify that interned the pointers are the same. - ASSERT_EQ(value1b.s(), value1a.s()); - TaggedValue value2("string2"); - ASSERT_NE(value1a.s(), value2.s()); - ASSERT_STREQ(value1a.s(), "string1"); - ASSERT_STREQ(value2.s(), "string2"); -} - -TEST(Test1, TestDict) { - TaggedValue s1("test1"); - TaggedValue s2("test2"); - TaggedValue d = TaggedValue::Dict(); - d.dict()[s2] = TaggedValue(6.f); - std::stringstream stream; - stream << d; - ASSERT_EQ(stream.str(), "{test2: 6, }"); -} - -namespace { -TaggedValue add(TaggedValue args, TaggedValue kwargs) { - if (args.type() == TaggedValue::TUPLE) { - return TaggedValue(args.tuple()[0].f32() + args.tuple()[1].f32()); - } - return TaggedValue::None(); -} -} // namespace -TEST(Test1, TestFunctionCall) { - TaggedValue f32 = TaggedValue(add); - TaggedValue args = TaggedValue::Tuple(); - args.tuple().emplace_back(TaggedValue(1.f)); - args.tuple().emplace_back(TaggedValue(2.f)); - TaggedValue c = f32.func()(args, TaggedValue::None()).value(); - ASSERT_EQ(c, TaggedValue(3.f)); -} - -namespace { -int alloc_count = 0; -class Cool { - public: - Cool() { alloc_count++; } - ~Cool() { alloc_count--; } -}; -} // namespace - -TEST(Test1, TestCapsule) { - TaggedValue test_moved, test_copy; - ASSERT_EQ(alloc_count, 0); - void* ptr_value = new Cool(); - { - TaggedValue capsule = - TaggedValue::Capsule(static_cast(ptr_value), - [](void* x) { delete static_cast(x); }); - ASSERT_EQ(alloc_count, 1); - ASSERT_EQ(capsule.capsule(), ptr_value); - test_moved = std::move(capsule); - ASSERT_EQ(capsule.type(), TaggedValue::NONE); // NOLINT - test_copy = test_moved; - ASSERT_EQ(test_moved.capsule(), ptr_value); - ASSERT_EQ(test_copy.capsule(), ptr_value); - } - ASSERT_EQ(alloc_count, 1); - test_moved = TaggedValue::None(); - ASSERT_EQ(alloc_count, 1); - test_copy = TaggedValue(3.f); - ASSERT_EQ(alloc_count, 0); -} - -} // namespace impl -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/tests/variable_test.cc b/tensorflow/cc/experimental/libtf/tests/variable_test.cc deleted file mode 100644 index 1e37ed9cb2b96b..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/variable_test.cc +++ /dev/null @@ -1,120 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/c/eager/abstract_context.h" -#include "tensorflow/c/eager/abstract_function.h" -#include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/c/eager/graph_function.h" -#include "tensorflow/c/eager/unified_api_testutil.h" -#include "tensorflow/c/experimental/ops/resource_variable_ops.h" -#include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/cc/experimental/libtf/function.h" -#include "tensorflow/cc/experimental/libtf/object.h" -#include "tensorflow/cc/experimental/libtf/value.h" -#include "tensorflow/core/framework/function.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/platform/test.h" - -namespace tf { -namespace libtf { -using tensorflow::AbstractContext; -using tensorflow::AbstractContextPtr; -using tensorflow::AbstractFunctionPtr; -using tensorflow::AbstractTensorHandle; -using tensorflow::DT_FLOAT; -using tensorflow::PartialTensorShape; -using tensorflow::Status; -using tensorflow::TF_StatusPtr; - -class VariableTest - : public ::testing::TestWithParam> { - public: - template - impl::TaggedValueTensor CreateScalarTensor(T val) { - AbstractTensorHandle* raw = nullptr; - Status s = TestScalarTensorHandle(ctx_.get(), val, &raw); - CHECK_EQ(tensorflow::errors::OK, s.code()) << s.message(); - return impl::TaggedValueTensor(raw, /*add_ref=*/false); - } - - bool UseTfrt() { return std::get<1>(GetParam()); } - - AbstractContextPtr ctx_; - - protected: - void SetUp() override { - // Set the tracing impl, GraphDef vs MLIR. - TF_StatusPtr status(TF_NewStatus()); - TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); - Status s = tensorflow::StatusFromTF_Status(status.get()); - CHECK_EQ(tensorflow::errors::OK, s.code()) << s.message(); - - // Set the runtime impl, Core RT vs TFRT. - AbstractContext* ctx_raw = nullptr; - s = BuildImmediateExecutionContext(UseTfrt(), &ctx_raw); - CHECK_EQ(tensorflow::errors::OK, s.code()) << s.message(); - ctx_.reset(ctx_raw); - } -}; - -template -void ExpectEquals(AbstractTensorHandle* t, T expected) { - TF_Tensor* result_t; - Status s = tensorflow::GetValue(t, &result_t); - ASSERT_TRUE(s.ok()) << s.message(); - auto value = static_cast(TF_TensorData(result_t)); - EXPECT_EQ(*value, expected); - TF_DeleteTensor(result_t); -} - -TEST_P(VariableTest, CreateAssignReadDestroy) { - // Create uninitialized variable. - tensorflow::AbstractTensorHandlePtr var; - { - AbstractTensorHandle* var_ptr = nullptr; - PartialTensorShape scalar_shape; - TF_EXPECT_OK( - PartialTensorShape::MakePartialShape({}, 0, &scalar_shape)); - TF_EXPECT_OK(tensorflow::ops::VarHandleOp(ctx_.get(), &var_ptr, DT_FLOAT, - scalar_shape)); - var.reset(var_ptr); - } - // Assign a value. - auto x = CreateScalarTensor(2.0f); - TF_EXPECT_OK( - tensorflow::ops::AssignVariableOp(ctx_.get(), var.get(), x.get())); - // Read variable. - tensorflow::AbstractTensorHandlePtr value; - { - AbstractTensorHandle* value_ptr = nullptr; - TF_EXPECT_OK(tensorflow::ops::ReadVariableOp(ctx_.get(), var.get(), - &value_ptr, DT_FLOAT)); - value.reset(value_ptr); - } - ExpectEquals(value.get(), 2.0f); - // Destroy variable. - TF_EXPECT_OK(tensorflow::ops::DestroyResourceOp(ctx_.get(), var.get())); -} - -INSTANTIATE_TEST_SUITE_P(TF2CAPI, VariableTest, - ::testing::Combine(::testing::Values("graphdef", - "mlir"), - ::testing::Values(false))); - -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/tests/visit_test.cc b/tensorflow/cc/experimental/libtf/tests/visit_test.cc deleted file mode 100644 index fe905d9972a629..00000000000000 --- a/tensorflow/cc/experimental/libtf/tests/visit_test.cc +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include - -#include "tensorflow/cc/experimental/libtf/value.h" -#include "tensorflow/cc/experimental/libtf/value_iostream.h" -#include "tensorflow/core/platform/test.h" - -namespace tf { -namespace libtf { -namespace impl { - -struct Visitor { - const char* operator()(Int64 i) { return "int64"; } - const char* operator()(Float32 f) { return "float32"; } - template - const char* operator()(const T& i) { - return "else"; - } -}; - -TEST(VisitTest, Test1) { - TaggedValue a(Int64(1)), b(Float32(1.1f)); - TaggedValue c = TaggedValue::None(); - - ASSERT_EQ(a.visit(Visitor()), "int64"); - ASSERT_EQ(b.visit(Visitor()), "float32"); - ASSERT_EQ(c.visit(Visitor()), "else"); -} - -} // namespace impl -} // namespace libtf -} // namespace tf diff --git a/tensorflow/cc/experimental/libtf/value.h b/tensorflow/cc/experimental/libtf/value.h deleted file mode 100644 index 61a2888426ee3d..00000000000000 --- a/tensorflow/cc/experimental/libtf/value.h +++ /dev/null @@ -1,596 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -/// @file value.h -/// @brief The TaggedValue struct that supports Python-like behavior in C++. -/// -/// The TaggedValue struct implements a tagged union data structure -/// (https://en.wikipedia.org/wiki/Tagged_union) in the TensorFlow C++ API. It -/// contains a `Type` enum (sometimes referred to as a "tag") -/// and a `Data` union for holding values. - -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_VALUE_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_VALUE_H_ - -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "tensorflow/c/eager/abstract_tensor_handle.h" -#include "tensorflow/core/platform/intrusive_ptr.h" -#include "tensorflow/core/platform/statusor.h" - -// TODO(b/195578409): Move all value objects into `impl`. Currently only values -// that do not reference TaggedValue are there. -#include "tensorflow/cc/experimental/libtf/impl/none.h" -#include "tensorflow/cc/experimental/libtf/impl/scalars.h" -#include "tensorflow/cc/experimental/libtf/impl/string.h" -#include "tensorflow/cc/experimental/libtf/impl/tensor_spec.h" - -namespace tf { -namespace libtf { -namespace impl { -// Necessary forward declares. -class TaggedValue; -class Tuple; -template -// TODO(ccrusius): Use absl::Hash specializations instead. -class TaggedValueHash; -using List = std::vector; -using ListPtr = std::shared_ptr; -using Dict = - absl::flat_hash_map>; -using DictPtr = std::shared_ptr; -using TuplePtr = std::shared_ptr; -using Func = - std::function(TaggedValue, TaggedValue)>; -// A capsule holds a pointer and a destructor for the pointer (i.e. a generic -// shared_ptr to void with a custom deleter). -using Capsule = std::shared_ptr; -using TaggedValueTensor = - tensorflow::core::IntrusivePtr; - -// Declare hash types so they can be instantiated below. - -/// @brief TaggedValue hashing infrastructure, which uses absl::hash. -/// -/// Hashable TaggedValues overload `AbslHashValue`. Non-hashable structures -/// return 0. -template <> -struct TaggedValueHash { - size_t operator()(const TaggedValue& v) const; -}; - -/// @brief Hash implementation for TaggedValue Tuples. -template <> -struct TaggedValueHash { - size_t operator()(const Tuple& t) const; -}; - -/// @brief The basic `TaggedValue` tagged union type. -/// -/// A `TaggedValue` contains a `Type` (or "tag") as an enum and a `Value` union. -/// Values include tensors, primitive values, lists, tuples, and dictionaries. -/// In the future we might also want to have representation of python objects in -/// the form of PyObject*. -class TaggedValue final { - public: - /// @brief Enum that describes the possible types a `TaggedValue` can be. - /// - /// A `TaggedValue` must be one of the following types: NONE, INT64, FLOAT32, - /// STRING, FUNC, DICT, LIST, TUPLE, TENSOR, TENSOR_SPEC, CAPSULE. - enum Type { - NONE = 0, - INT64 = 1, - FLOAT32 = 2, - STRING = 3, - FUNC = 4, - DICT = 5, - LIST = 6, - TUPLE = 7, - TENSOR = 8, - TENSOR_SPEC = 9, - CAPSULE = 10, - }; - TaggedValue() : type_(NONE), data_() {} - - /// Move assignment operator. - TaggedValue& operator=(TaggedValue&& v) { - destroy(); - MoveIntoUnion(std::move(v)); - return *this; - } - /// Move constructor. - TaggedValue(TaggedValue&& v) : type_(NONE) { MoveIntoUnion(std::move(v)); } - /// Copy constructor. - TaggedValue(const TaggedValue& v) : type_(NONE) { CopyIntoUnion(v); } - /// Copy assignment operator. - TaggedValue& operator=(const TaggedValue& v) { - destroy(); - CopyIntoUnion(v); - return *this; - } - /// TaggedValue constructor for type TENSOR. - explicit TaggedValue(TaggedValueTensor tensor) - : type_(TENSOR), data_(std::move(tensor)) {} - /// TaggedValue constructor for type TENSOR_SPEC. - explicit TaggedValue(tensorflow::PartialTensorShape shape, - tensorflow::DataType dtype) - : type_(TENSOR_SPEC), data_(shape, dtype) {} - /// TaggedValue constructor for type FUNC. - explicit TaggedValue(Func f32) : type_(FUNC), data_(f32) {} - /// TaggedValue constructor for type FLOAT32. - explicit TaggedValue(float f32) : type_(FLOAT32), data_(Float32(f32)) {} - /// TaggedValue constructor for type INT64. - explicit TaggedValue(int64_t i64) : type_(INT64), data_(Int64(i64)) {} - /// TaggedValue constructor for type FLOAT32. - explicit TaggedValue(Float32 f32) : type_(FLOAT32), data_(f32) {} - /// TaggedValue constructor for type INT64. - explicit TaggedValue(Int64 i64) : type_(INT64), data_(i64) {} - /// TaggedValue constructor for type STRING. - explicit TaggedValue(const char* s) : type_(STRING), data_(s) {} - /// Constructs a TaggedValue with type NONE. - static TaggedValue None() { - TaggedValue v; - v.type_ = NONE; - return v; - } - /// Constructs a TaggedValue with type LIST. - static TaggedValue List() { - TaggedValue v; - v.type_ = LIST; - using T = decltype(v.data_.list); - new (&v.data_.list) T(std::make_shared()); - return v; - } - /// Constructs a TaggedValue with type TUPLE. - static TaggedValue Tuple() { - TaggedValue v; - v.type_ = TUPLE; - using T = decltype(v.data_.tuple); - new (&v.data_.tuple) T(std::make_shared()); - return v; - } - /// Constructs a TaggedValue with type DICT. - static TaggedValue Dict() { - TaggedValue v; - v.type_ = DICT; - using T = decltype(v.data_.dict); - new (&v.data_.dict) T(std::make_shared()); - return v; - } - /// Constructs a TaggedValue with type TENSOR. - static TaggedValue Tensor(tensorflow::AbstractTensorHandle* raw_ptr) { - TaggedValue v; - v.type_ = TENSOR; - using T = decltype(v.data_.tensor); - new (&v.data_.tensor) T(raw_ptr, /*add_ref=*/false); - return v; - } - - /// Constructs a TaggedValue with type CAPSULE with a default destructor. - template - static TaggedValue Capsule(T* data) { - return Capsule(static_cast(data), - [](void* x) { delete static_cast(x); }); - } - /// Constructs a TaggedValue with type CAPSULE with a custom destructor. - static TaggedValue Capsule(void* data, void (*deleter)(void*)) { - TaggedValue v; - v.type_ = CAPSULE; - using T = decltype(v.data_.capsule); - new (&v.data_.capsule) T(data, deleter); - return v; - } - /// Destroys TaggedValue. Shared pointers in unions must be explicitly - /// deleted. - void destroy() { - if (type_ != NONE) { - // Explicitly run the destructor on the correct type. - visit([](auto& x) { - using T = typename std::decay::type; - x.~T(); - }); - // Make the type None, whenever we destroy so we always have an - // initialized value. - type_ = NONE; - } - } - ~TaggedValue() { destroy(); } - - /// @brief Get the underlying value based on type. - /// - /// @tparam T The desired return type. - /// @return The unwrapped value. If this `TaggedValue` type does not currently - /// contain a value of type `T`, the program terminates via a call to - /// `assert`. - template - T& get() { - assert(type_ == EnumValueOf::value); - return UnionAccess::unsafe_reference(*this); - } - - /// @brief Get the underlying value based on type. - /// - /// @tparam T The desired return type. - /// @return The unwrapped value. If this `TaggedValue` type does not currently - /// contain a value of type `T`, the program terminates via a call to - /// `assert`. - template - const T& get() const { - assert(type_ == EnumValueOf::value); - return UnionAccess::unsafe_reference(*this); - } - - /// Retrieves underlying value from a TaggedValue with type INT64. - const Int64& i64() const { return get(); } - - /// Retrieves underlying value from a TaggedValue with type FLOAT32. - const Float32& f32() const { return get(); } - - /// Retrieves underlying value from a TaggedValue with type STRING. - const char* s() const { return get().str().c_str(); } - - /// Retrieves underlying value from a TaggedValue with type LIST. - impl::List& list() { return *get(); } - /// Retrieves underlying value from a TaggedValue with type LIST. - const impl::List& list() const { return *get(); } - - /// Retrieves underlying value from a TaggedValue with type TUPLE. - impl::Tuple& tuple() { return *get(); } - /// Retrieves underlying value from TaggedValues with type TUPLE. - const impl::Tuple& tuple() const { return *get(); } - - /// Retrieves underlying value from a TaggedValue with type DICT. - impl::Dict& dict() { return *get(); } - /// Retrieves underlying value from TaggedValues with type DICT. - const impl::Dict& dict() const { return *get(); } - - /// Retrieves underlying value from a TaggedValue with type FUNC. - impl::Func func() const { return get(); } - - // TODO(danielellis): make const-only if possible, once the API allows for it - /// Retrieves underlying value from a TaggedValue with type TENSOR. - TaggedValueTensor& tensor() { return get(); } - /// Retrieves underlying value from a TaggedValue with type TENSOR. - const TaggedValueTensor& tensor() const { return get(); } - - /// Retrieves underlying value from a TaggedValue with type TENSOR_SPEC. - const TensorSpec& tensor_spec() const { return get(); } - - /// Retrieves underlying value from a TaggedValue with type CAPSULE. - void* capsule() const { return get().get(); } - - /// Retrieves type of TaggedValue. - Type type() const { return type_; } - - /// @brief Implements equality operator for TaggedValue. - bool operator==(const TaggedValue& o) const { - if (type_ != o.type_) return false; - switch (type_) { - case LIST: - return data_.list == o.data_.list; - break; - case TUPLE: - return data_.tuple == o.data_.tuple; - break; - case DICT: - return data_.dict == o.data_.dict; - break; - case FUNC: - // TODO(b/187536093): This is definitely wrong because the exact ptr of - // the function pointer is almost always different, because we hold - // it by value. Two tagged values that hold the same std::function - // will have different std::function ptrs. operator== is not defined - // for std::function's so we need a better solution here, or these - // are not comparable which seems bad. - return &data_.func == &o.data_.func; - break; - case FLOAT32: - return data_.f32 == o.data_.f32; - break; - case INT64: - return data_.i64 == o.data_.i64; - break; - case STRING: - return data_.s == o.data_.s; - break; - case TENSOR: - return data_.tensor == o.data_.tensor; - case TENSOR_SPEC: - return data_.tensor_spec == o.data_.tensor_spec; - case CAPSULE: - return data_.capsule.get() == o.data_.capsule.get(); - case NONE: - return true; - } - } - - /// @brief Implements visitor pattern for doing type-based dispatch. - /// - /// @tparam R The desired return type. - /// @tparam Visitor The visitor class which has a callable operator. - /// @return The `visitor` called on the correct value. - template - R visit(Visitor visitor) { - switch (type_) { - case LIST: - return visitor(data_.list); - case TUPLE: - return visitor(data_.tuple); - case DICT: - return visitor(data_.dict); - case FUNC: - return visitor(data_.func); - case FLOAT32: - return visitor(data_.f32); - case INT64: - return visitor(data_.i64); - case STRING: - return visitor(data_.s); - case TENSOR: - return visitor(data_.tensor); - case TENSOR_SPEC: - return visitor(data_.tensor_spec); - case CAPSULE: - return visitor(data_.capsule); - case NONE: - return visitor(impl::None::GetInstance()); - } - } - - /// @brief Implements visitor pattern for doing type-based dispatch. - /// - /// @tparam R The desired return type. - /// @tparam Visitor The visitor class which has a callable operator. - /// @return The `visitor` called on the correct value. - template - R visit(Visitor visitor) const { - switch (type_) { - case LIST: - return visitor(data_.list); - case TUPLE: - return visitor(data_.tuple); - case DICT: - return visitor(data_.dict); - case FUNC: - return visitor(data_.func); - case FLOAT32: - return visitor(data_.f32); - case INT64: - return visitor(data_.i64); - case STRING: - return visitor(data_.s); - case TENSOR: - return visitor(data_.tensor); - case TENSOR_SPEC: - return visitor(data_.tensor_spec); - case CAPSULE: - return visitor(data_.capsule); - case NONE: - return visitor(impl::None::GetInstance()); - } - } - - private: - /// @brief A utility class for mapping C++ types to Type values. - template - struct EnumValueOf; - - /// @brief A utility class for accessing the `Data` union members. - template - struct UnionAccess; - - // Unsafe Move, because it assumes the union has already been destroyed - // or is new! - void MoveIntoUnion(TaggedValue&& v) { - assert(type_ == NONE); - type_ = v.type_; - if (type_ != NONE) { - visit([&v](auto& left) -> void { - using T = typename std::decay::type; - new (&left) T(std::move(UnionAccess::unsafe_reference(v))); - }); - } - // Destroy the source r-value reference (making it None) - v.destroy(); - } - - // Unsafe Move, because it assumes the union has already been destroyed - // or is new! - void CopyIntoUnion(const TaggedValue& v) { - assert(type_ == NONE); - type_ = v.type_; - if (type_ != NONE) { - visit([&v](auto& left) -> void { - using T = typename std::decay::type; - new (&left) T(UnionAccess::unsafe_reference(v)); - }); - } - } - - /// @brief The type of the TaggedValue, i.e. the "tag" of a tagged union. - /// - /// In principle this could be incorporated into the union - /// for pointer types and non-64bit values, but then int64 and float64 values - /// would need to be indirected. This means that we are aiming for a total - /// data type size of <=16 bytes, comprised of one pointer (8 bytes) and - /// one type (<=8bytes). - Type type_; - - // we use an explicit union here because we want to avoid C++17's - // variant structures due to c++14 compatibility requirements. - // TODO(b/183980966): Compare against absl::variant. - union Data { - explicit Data() {} - explicit Data(Float32 f32) : f32(f32) {} - explicit Data(Int64 i64) : i64(i64) {} - explicit Data(const char* s) : s(String(s)) {} - explicit Data(Func fn) : func(fn) {} - explicit Data(TaggedValueTensor tensor_in) { - new (&tensor) TaggedValueTensor(std::move(tensor_in)); - } - explicit Data(tensorflow::PartialTensorShape shape, - tensorflow::DataType dtype) - : tensor_spec({shape, dtype}) {} - ~Data() {} - Float32 f32; - Int64 i64; - String s; - Func func; - // TODO(aselle): look at tensorflow thing - std::shared_ptr dict; - std::shared_ptr list; - std::shared_ptr tuple; - impl::Capsule capsule; - TaggedValueTensor tensor; - TensorSpec tensor_spec; - } data_; - friend std::ostream& operator<<(std::ostream& o, const TaggedValue& v); - friend TaggedValueHash; -}; - -#define TF_ENUM_VALUE_OF(TYPE, ENUM) \ - template <> \ - struct TaggedValue::EnumValueOf { \ - static constexpr Type value = ENUM; \ - }; - -TF_ENUM_VALUE_OF(impl::Capsule, CAPSULE); -TF_ENUM_VALUE_OF(impl::Float32, FLOAT32); -TF_ENUM_VALUE_OF(impl::Int64, INT64); -TF_ENUM_VALUE_OF(impl::List, LIST); -TF_ENUM_VALUE_OF(impl::ListPtr, LIST); -TF_ENUM_VALUE_OF(impl::Tuple, TUPLE); -TF_ENUM_VALUE_OF(impl::TuplePtr, TUPLE); -TF_ENUM_VALUE_OF(impl::Dict, DICT); -TF_ENUM_VALUE_OF(impl::DictPtr, DICT); -TF_ENUM_VALUE_OF(impl::None, NONE); -TF_ENUM_VALUE_OF(impl::Func, FUNC); -TF_ENUM_VALUE_OF(impl::String, STRING); -TF_ENUM_VALUE_OF(impl::TaggedValueTensor, TENSOR); -TF_ENUM_VALUE_OF(impl::TensorSpec, TENSOR_SPEC); -#undef TF_ENUM_VALUE_OF - -#define TF_UNION_ACCESS_INSTANCE(TYPE, MEMBER) \ - template <> \ - struct TaggedValue::UnionAccess { \ - static TYPE& unsafe_reference(TaggedValue& t) { return t.data_.MEMBER; } \ - static const TYPE& unsafe_reference(const TaggedValue& t) { \ - return t.data_.MEMBER; \ - } \ - }; - -TF_UNION_ACCESS_INSTANCE(impl::Capsule, capsule); -TF_UNION_ACCESS_INSTANCE(impl::Float32, f32); -TF_UNION_ACCESS_INSTANCE(impl::Int64, i64); -TF_UNION_ACCESS_INSTANCE(impl::ListPtr, list); -TF_UNION_ACCESS_INSTANCE(impl::TuplePtr, tuple); -TF_UNION_ACCESS_INSTANCE(impl::DictPtr, dict); -TF_UNION_ACCESS_INSTANCE(impl::Func, func); -TF_UNION_ACCESS_INSTANCE(impl::String, s); -TF_UNION_ACCESS_INSTANCE(impl::TaggedValueTensor, tensor); -TF_UNION_ACCESS_INSTANCE(impl::TensorSpec, tensor_spec); -#undef TF_UNION_ACCESS_INSTANCE - -/// The union accessor for `NoneType`. -template <> -struct TaggedValue::UnionAccess { - static impl::None& unsafe_reference(TaggedValue& t) { - return None::GetInstance(); - } - static const impl::None& unsafe_reference(const TaggedValue& t) { - return None::GetInstance(); - } -}; - -/// @brief The Tuple class for holding tuples of TaggedValues. -/// TODO: Need to wrap vector in Tuple otherwise variant has duplicate types. -class Tuple { - using TU = std::vector; - using value_type = TU::value_type; - using iterator = TU::iterator; - using const_iterator = TU::const_iterator; - TU values_; - - public: - TU::iterator begin() { return values_.begin(); } - TU::iterator end() { return values_.end(); } - TU::const_iterator begin() const { return values_.begin(); } - TU::const_iterator end() const { return values_.end(); } - const TU::value_type& operator[](size_t i) const { return values_[i]; } - TU::value_type& operator[](size_t i) { return values_[i]; } - size_t size() const { return values_.size(); } - void emplace_back(TaggedValue v) { values_.emplace_back(std::move(v)); } - void push_back(const TaggedValue& v) { values_.push_back(v); } -}; - -/// Hashing infrastructure for Tuple. -inline size_t TaggedValueHash::operator()(const Tuple& t) const { - std::size_t hash = 0; - for (auto& i : t) { - hash ^= TaggedValueHash()(i); - } - return hash; -} - -/// @brief The TaggedValueHashVisitor class for doing type-based hashing -/// of TaggedValues. -class TaggedValueHashVisitor { - public: - size_t operator()(const TaggedValueTensor& v) { - assert(false); - return 0; - } - size_t operator()(const ListPtr& v) { - assert(false); - return 0; - } - size_t operator()(const DictPtr& v) { - assert(false); - return 0; - } - size_t operator()(const Capsule& t) { return std::hash()(t); } - size_t operator()(const Func& t) { - assert(false); - return 0; - } - size_t operator()(const TuplePtr& t) { - std::size_t hash = 0; - for (auto it = t->begin(); it != t->end(); ++it) { - hash ^= TaggedValueHash()(*it); - } - return hash; - } - template - size_t operator()(const T& t) { - return absl::Hash()(t); - } -}; - -/// Hashing infrastructure for TaggedValues. Hashable TaggedValues overload -/// `AbslHashValue`. Non-hashable structures return 0, since we have no easy -/// way to abort. -inline size_t TaggedValueHash::operator()( - const TaggedValue& v) const { - return v.visit(TaggedValueHashVisitor()); -} - -} // namespace impl -} // namespace libtf -} // namespace tf - -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_VALUE_H_ diff --git a/tensorflow/cc/experimental/libtf/value_iostream.h b/tensorflow/cc/experimental/libtf/value_iostream.h deleted file mode 100644 index c26ed493890407..00000000000000 --- a/tensorflow/cc/experimental/libtf/value_iostream.h +++ /dev/null @@ -1,93 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_VALUE_IOSTREAM_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_VALUE_IOSTREAM_H_ - -#include - -#include "tensorflow/cc/experimental/libtf/value.h" - -namespace tf { -namespace libtf { -namespace impl { - -inline std::ostream& operator<<(std::ostream& o, const Dict& v) { - o << "{"; - for (auto& x : v) { - o << x.first; - o << ": "; - o << x.second; - o << ", "; - } - o << "}"; - return o; -} -template -inline std::ostream& OutList(std::ostream& o, IT v_start, IT const v_end, - char start, char end) { - o << start; - for (IT p = v_start; p != v_end; ++p) { - o << *p; - o << ", "; - } - o << end; - return o; -} - -class TaggedValueIOStreamVisitor { - std::ostream& o_; - - public: - explicit TaggedValueIOStreamVisitor(std::ostream& o) : o_(o) {} - - std::ostream& operator()(const ListPtr& x) { - OutList(o_, x->begin(), x->end(), '[', ']'); - return o_; - } - std::ostream& operator()(const TuplePtr& x) { - OutList(o_, x->begin(), x->end(), '(', ')'); - return o_; - } - std::ostream& operator()(const DictPtr& x) { - o_ << *x; - return o_; - } - std::ostream& operator()(const Capsule& x) { - o_ << "Capsule(" << x.get() << ")"; - return o_; - } - std::ostream& operator()(const Func& x) { - o_ << "Func"; - return o_; - } - std::ostream& operator()(const TaggedValueTensor& x) { - o_ << "Tensor"; - return o_; - } - - template - std::ostream& operator()(const T& x) { - o_ << x; - return o_; - } -}; - -inline std::ostream& operator<<(std::ostream& o, const TaggedValue& v) { - return v.visit(TaggedValueIOStreamVisitor(o)); -} -} // namespace impl -} // namespace libtf -} // namespace tf -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_VALUE_IOSTREAM_H_ diff --git a/tensorflow/cc/ops/while_loop.cc b/tensorflow/cc/ops/while_loop.cc index b1c7e38f44c46c..a69935219b13b9 100644 --- a/tensorflow/cc/ops/while_loop.cc +++ b/tensorflow/cc/ops/while_loop.cc @@ -63,8 +63,8 @@ string NextIterationName(const Scope& scope, int loop_var_idx) { // Creates the `loop_var_idx`-th Merge node of a loop being constructed with // `scope`. `enter_output` is the `loop_var_idx`-th Enter node's output. -Status CreateMerge(const Scope& scope, int loop_var_idx, - const Output& enter_output, Output* merge_output) { +absl::Status CreateMerge(const Scope& scope, int loop_var_idx, + const Output& enter_output, Output* merge_output) { // The merge nodes accept the while loop's back edges as an input (i.e. the // not-yet-created next iteration nodes). Use the underlying NodeBuilder API // directly to create the back edge. @@ -88,8 +88,8 @@ Status CreateMerge(const Scope& scope, int loop_var_idx, } // Creates the condition subgraph defined by `cond`. -Status CreateCond(const Scope& scope, const CondGraphBuilderFn& cond, - const std::vector& inputs, Output* output) { +absl::Status CreateCond(const Scope& scope, const CondGraphBuilderFn& cond, + const std::vector& inputs, Output* output) { // The control dependency is for constants in the cond graph, and other ops // that do not depend on the loop variables. This ensures that these ops are // in the while loop frame (since they will indirectly depend on an Enter node @@ -118,9 +118,9 @@ Status CreateCond(const Scope& scope, const CondGraphBuilderFn& cond, // Create the body subgraph defined by `body`. `outputs` must be non-null and // empty. -Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body, - const std::vector& inputs, - std::vector* outputs) { +absl::Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body, + const std::vector& inputs, + std::vector* outputs) { DCHECK(outputs != nullptr); DCHECK(outputs->empty()); @@ -169,11 +169,12 @@ Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body, // If there are multiple loop variables, each of the control flow ops is // duplicated for each loop variable. // TODO(skyewm): link to public version of design doc -Status BuildWhileLoop(const Scope& scope, const std::vector& inputs, - const CondGraphBuilderFn& cond, - const BodyGraphBuilderFn& body, const string& frame_name, - OutputList* outputs, bool create_while_ctx, - Output* cond_output) { +absl::Status BuildWhileLoop(const Scope& scope, + const std::vector& inputs, + const CondGraphBuilderFn& cond, + const BodyGraphBuilderFn& body, + const string& frame_name, OutputList* outputs, + bool create_while_ctx, Output* cond_output) { DCHECK(!inputs.empty()); DCHECK(outputs != nullptr); DCHECK(outputs->empty()); diff --git a/tensorflow/cc/ops/while_loop.h b/tensorflow/cc/ops/while_loop.h index 6dbf1d23dbaf4f..5a1a45dac7207d 100644 --- a/tensorflow/cc/ops/while_loop.h +++ b/tensorflow/cc/ops/while_loop.h @@ -27,14 +27,15 @@ namespace ops { // Function that takes cond graph inputs and returns cond graph boolean output. // 'output' need not be set if an error is returned. -typedef std::function& inputs, - Output* output)> +typedef std::function& inputs, Output* output)> CondGraphBuilderFn; // Function that takes body graph inputs and returns body graph outputs. // 'outputs' need not be populated if an error is returned. -typedef std::function& inputs, - std::vector* outputs)> +typedef std::function& inputs, + std::vector* outputs)> BodyGraphBuilderFn; // Constructs a while loop. @@ -65,11 +66,13 @@ typedef std::function& inputs, // // TODO(skyewm): clean up partially-constructed loop in error case // TODO(skyewm): create public interface to this method -Status BuildWhileLoop(const Scope& scope, const std::vector& inputs, - const CondGraphBuilderFn& cond, - const BodyGraphBuilderFn& body, const string& frame_name, - OutputList* outputs, bool create_while_ctx = true, - Output* cond_output = nullptr); +absl::Status BuildWhileLoop(const Scope& scope, + const std::vector& inputs, + const CondGraphBuilderFn& cond, + const BodyGraphBuilderFn& body, + const string& frame_name, OutputList* outputs, + bool create_while_ctx = true, + Output* cond_output = nullptr); } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/ops/while_loop_test.cc b/tensorflow/cc/ops/while_loop_test.cc index 1e9338eb0c2fe8..5db62989c2204f 100644 --- a/tensorflow/cc/ops/while_loop_test.cc +++ b/tensorflow/cc/ops/while_loop_test.cc @@ -39,7 +39,7 @@ class WhileLoopTest : public ::testing::Test { const ops::BodyGraphBuilderFn& body, error::Code error_code = error::OK, const string& error_msg = "") { - Status s = + absl::Status s = ops::BuildWhileLoop(scope_, inputs_, cond, body, kFrameName, &outputs_); EXPECT_EQ(s.code(), error_code); EXPECT_EQ(s.message(), error_msg); @@ -76,14 +76,14 @@ class WhileLoopTest : public ::testing::Test { const char* const WhileLoopTest::kFrameName = "test_loop"; -Status LessThanTenCond(const Scope& s, const std::vector& inputs, - Output* output) { +absl::Status LessThanTenCond(const Scope& s, const std::vector& inputs, + Output* output) { *output = ops::Less(s, inputs[0], 10); return s.status(); } -Status AddOneBody(const Scope& s, const std::vector& inputs, - std::vector* outputs) { +absl::Status AddOneBody(const Scope& s, const std::vector& inputs, + std::vector* outputs) { outputs->push_back(ops::Add(s, inputs[0], 1)); return s.status(); } diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 02fd6786698cd1..8bc7d96887aae1 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -505,6 +505,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:types", ] + if_not_mobile([ @@ -522,7 +523,7 @@ cc_library( name = "fingerprinting", hdrs = ["fingerprinting.h"], visibility = [ - "//learning/brain/contrib/hub/server/distro:__subpackages__", + "//learning/brain/contrib/hub/server/ingestion:__subpackages__", "//learning/brain/contrib/tpu_modeling:__subpackages__", "//learning/metadata/artifactoid/cc:__subpackages__", "//learning/tfx/pipeline/util:__subpackages__", @@ -549,8 +550,6 @@ cc_library( ":constants", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/graph/regularization:simple_delete", - "//tensorflow/core/graph/regularization:util", "//tensorflow/core/util/tensor_bundle:naming", "//tensorflow/tools/proto_splitter:chunk_proto_cc", "//tensorflow/tools/proto_splitter:merge", @@ -558,7 +557,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:protobuf", + "@com_google_absl//absl/strings:str_format", "@riegeli//riegeli/bytes:fd_reader", "@riegeli//riegeli/records:record_reader", ], diff --git a/tensorflow/cc/saved_model/bundle_v2.h b/tensorflow/cc/saved_model/bundle_v2.h index e199bd1cc5dc7d..ec85d14f3755f5 100644 --- a/tensorflow/cc/saved_model/bundle_v2.h +++ b/tensorflow/cc/saved_model/bundle_v2.h @@ -38,12 +38,13 @@ namespace tensorflow { /// loaded into an executable in-memory representation). class SavedModelV2Bundle { public: - using RestoreObjectsCallback = - std::function; + using RestoreObjectsCallback = std::function; /// Loads persistent representations for a SavedModelV2 from the specified /// export directory. - static Status Load(const std::string& export_dir, SavedModelV2Bundle* bundle); + static absl::Status Load(const std::string& export_dir, + SavedModelV2Bundle* bundle); /// MetaGraphDef from the loaded SavedModel. MetaGraphDef& meta_graph_def() { return meta_graph_def_; } @@ -68,10 +69,10 @@ class SavedModelV2Bundle { /// saved_object_graph() and the corresponding TrackableObject from the /// trackable_object_graph(). The callback may use the variable_reader() but /// must not modify the underlying saved_object_graph(). - Status VisitObjectsToRestore(RestoreObjectsCallback callback); + absl::Status VisitObjectsToRestore(RestoreObjectsCallback callback); private: - Status RecurseObjectsToRestore( + absl::Status RecurseObjectsToRestore( const SavedObject* saved_object, int saved_object_node_id, const TrackableObjectGraph::TrackableObject* trackable_object, std::string object_name, diff --git a/tensorflow/cc/saved_model/fingerprinting.cc b/tensorflow/cc/saved_model/fingerprinting.cc index cf2ae4721623fa..edb61db527c668 100644 --- a/tensorflow/cc/saved_model/fingerprinting.cc +++ b/tensorflow/cc/saved_model/fingerprinting.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" #include "tensorflow/cc/saved_model/constants.h" @@ -39,6 +40,7 @@ limitations under the License. #include "tensorflow/core/protobuf/saved_model.pb.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h" #include "tensorflow/core/util/tensor_bundle/naming.h" +#include "tsl/platform/random.h" // b/291933687, b/291001524 #if !defined(PLATFORM_WINDOWS) && !defined(__APPLE__) #include "tensorflow/cc/saved_model/fingerprinting_utils.h" @@ -181,6 +183,8 @@ absl::StatusOr CreateFingerprintDefPb( fingerprint_def.set_saved_object_graph_hash(object_graph_hash); // Set fingerprint field #5. fingerprint_def.set_checkpoint_hash(HashCheckpointIndexFile(export_dir)); + // Assign a random UUID to the fingerprint. + fingerprint_def.set_uuid(absl::StrFormat("%016d", tsl::random::New64())); // Set version of the fingerprint. VersionDef* version = fingerprint_def.mutable_version(); version->set_producer(kFingerprintProducer); diff --git a/tensorflow/cc/saved_model/fingerprinting_test.cc b/tensorflow/cc/saved_model/fingerprinting_test.cc index 053884da265528..dbc784eb8de53d 100644 --- a/tensorflow/cc/saved_model/fingerprinting_test.cc +++ b/tensorflow/cc/saved_model/fingerprinting_test.cc @@ -20,12 +20,14 @@ limitations under the License. #include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/numbers.h" #include "absl/strings/string_view.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/fingerprint.pb.h" #include "tensorflow/core/protobuf/saved_model.pb.h" #include "tsl/platform/statusor.h" @@ -60,6 +62,12 @@ TEST(FingerprintingTest, TestCreateFingerprint) { EXPECT_EQ(fingerprint_def.graph_def_program_hash(), 10127142238652115842U); EXPECT_EQ(fingerprint_def.signature_def_hash(), 15570736222402453744U); EXPECT_EQ(fingerprint_def.saved_object_graph_hash(), 3678101440349108924U); + + // The uuid is a random number, but it should be a number > 0. + uint64 uuid = 0; + EXPECT_TRUE(absl::SimpleAtoi(fingerprint_def.uuid(), &uuid)); + EXPECT_GT(uuid, 0); + // TODO(b/242348400): The checkpoint hash is non-deterministic, so we cannot // check its value here. EXPECT_GT(fingerprint_def.checkpoint_hash(), 0); @@ -94,6 +102,7 @@ TEST(FingerprintingTest, TestCompareFingerprintForTwoModelSavedTwice) { fingerprint_def2.signature_def_hash()); EXPECT_EQ(fingerprint_def.saved_object_graph_hash(), fingerprint_def2.saved_object_graph_hash()); + EXPECT_NE(fingerprint_def.uuid(), fingerprint_def2.uuid()); } TEST(FingerprintingTest, TestFingerprintComputationDoesNotMutateModel) { diff --git a/tensorflow/cc/saved_model/fingerprinting_utils.cc b/tensorflow/cc/saved_model/fingerprinting_utils.cc index b5248562c3f490..a41ab4ecd02bc9 100644 --- a/tensorflow/cc/saved_model/fingerprinting_utils.cc +++ b/tensorflow/cc/saved_model/fingerprinting_utils.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "riegeli/bytes/fd_reader.h" // from @riegeli #include "riegeli/records/record_reader.h" // from @riegeli @@ -46,6 +47,7 @@ limitations under the License. #include "tensorflow/tools/proto_splitter/chunk.pb.h" #include "tensorflow/tools/proto_splitter/merge.h" #include "tsl/platform/errors.h" +#include "tsl/platform/random.h" #include "tsl/platform/statusor.h" // IWYU pragma: no_include "third_party/protobuf/repeated_ptr_field.h" // IWYU pragma: no_include "third_party/protobuf/io/coded_stream.h" @@ -473,6 +475,7 @@ absl::StatusOr CreateFingerprintDefCpb( fingerprint_def.set_checkpoint_hash(HashCheckpointIndexFile(export_dir)); + fingerprint_def.set_uuid(absl::StrFormat("%016d", tsl::random::New64())); reader.Close(); // Set version of the fingerprint. diff --git a/tensorflow/cc/saved_model/loader.h b/tensorflow/cc/saved_model/loader.h index 1dcd951d92b5ed..f549645eff856e 100644 --- a/tensorflow/cc/saved_model/loader.h +++ b/tensorflow/cc/saved_model/loader.h @@ -101,16 +101,17 @@ class SavedModelBundleLite : public SavedModelBundleInterface { // indicated metagraph. // The recommended way to load a saved model is to call LoadSavedModel, // which provides an already initialized Metagraph, Session, and DebugInfo. -Status RestoreSession(const RunOptions& run_options, - const MetaGraphDef& meta_graph, const string& export_dir, - std::unique_ptr* session); +absl::Status RestoreSession(const RunOptions& run_options, + const MetaGraphDef& meta_graph, + const string& export_dir, + std::unique_ptr* session); // Initialize a session which wraps this metagraph. // The recommended way to load a saved model is to call LoadSavedModel, // which provides an already initialized Metagraph, Session, and DebugInfo. -Status LoadMetagraphIntoSession(const SessionOptions& session_options, - const MetaGraphDef& meta_graph, - std::unique_ptr* session); +absl::Status LoadMetagraphIntoSession(const SessionOptions& session_options, + const MetaGraphDef& meta_graph, + std::unique_ptr* session); /// Loads a SavedModel from the specified export directory. The MetaGraphDef /// to be loaded is identified by the supplied tags, corresponding exactly to @@ -118,10 +119,11 @@ Status LoadMetagraphIntoSession(const SessionOptions& session_options, /// *bundle with a session and the requested MetaGraphDef, if found. /// /// NOTE: Prefer the overload that takes a SavedModelBundleLite* in new code. -Status LoadSavedModel(const SessionOptions& session_options, - const RunOptions& run_options, const string& export_dir, - const std::unordered_set& tags, - SavedModelBundle* bundle); +absl::Status LoadSavedModel(const SessionOptions& session_options, + const RunOptions& run_options, + const string& export_dir, + const std::unordered_set& tags, + SavedModelBundle* bundle); /// Loads a SavedModel from the specified export directory. The MetaGraphDef /// to be loaded is identified by the supplied tags, corresponding exactly to @@ -130,10 +132,11 @@ Status LoadSavedModel(const SessionOptions& session_options, /// /// This overload creates a SavedModelBundleLite, which consumes less RAM than /// an equivalent SavedModelBundle. -Status LoadSavedModel(const SessionOptions& session_options, - const RunOptions& run_options, const string& export_dir, - const std::unordered_set& tags, - SavedModelBundleLite* bundle); +absl::Status LoadSavedModel(const SessionOptions& session_options, + const RunOptions& run_options, + const string& export_dir, + const std::unordered_set& tags, + SavedModelBundleLite* bundle); /// Checks whether the provided directory could contain a SavedModel. Note that /// the method does not load any data by itself. If the method returns `false`, diff --git a/tensorflow/cc/saved_model/loader_util.cc b/tensorflow/cc/saved_model/loader_util.cc index 3a984bf31b3cd9..334d631d27cc58 100644 --- a/tensorflow/cc/saved_model/loader_util.cc +++ b/tensorflow/cc/saved_model/loader_util.cc @@ -28,8 +28,9 @@ namespace internal { // A SavedModel may store the name of the initialization op to run in the // in the SignatureDef (v2) or a collection (v1). If an init_op collection // exists, then the collection must contain exactly one op. -Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def, - string* init_op_name) { +absl::Status GetInitOp(const string& export_dir, + const MetaGraphDef& meta_graph_def, + string* init_op_name) { const auto& sig_def_map = meta_graph_def.signature_def(); const auto& init_op_sig_it = meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey); @@ -65,8 +66,8 @@ Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def, return absl::OkStatus(); } -Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def, - std::vector* asset_file_defs) { +absl::Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def, + std::vector* asset_file_defs) { // With SavedModel v2, we write asset file def into metagraph instead of // collection, so read from metagraph first. if (meta_graph_def.asset_file_def_size() > 0) { diff --git a/tensorflow/cc/saved_model/loader_util.h b/tensorflow/cc/saved_model/loader_util.h index e5f36976162dcc..9ce3500c881444 100644 --- a/tensorflow/cc/saved_model/loader_util.h +++ b/tensorflow/cc/saved_model/loader_util.h @@ -27,11 +27,12 @@ namespace internal { // A SavedModel may store the name of the initialization op to run in the // in the SignatureDef (v2) or a collection (v1). If an init_op collection // exists, then the collection must contain exactly one op. -Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def, - string* init_op_name); +absl::Status GetInitOp(const string& export_dir, + const MetaGraphDef& meta_graph_def, + string* init_op_name); -Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def, - std::vector* asset_file_defs); +absl::Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def, + std::vector* asset_file_defs); } // namespace internal } // namespace tensorflow diff --git a/tensorflow/cc/saved_model/reader.h b/tensorflow/cc/saved_model/reader.h index 17785db6bc6ee0..b5e81f9e3a2523 100644 --- a/tensorflow/cc/saved_model/reader.h +++ b/tensorflow/cc/saved_model/reader.h @@ -29,8 +29,8 @@ limitations under the License. #include "tensorflow/core/protobuf/saved_model.pb.h" namespace tensorflow { -Status ReadSavedModel(absl::string_view export_dir, - SavedModel* saved_model_proto); +absl::Status ReadSavedModel(absl::string_view export_dir, + SavedModel* saved_model_proto); // Finds and returns the MetaGraphDef (within the provided SavedModel) that // matches the given set of tags. The lifetime of the returned MetaGraphDef is @@ -45,12 +45,12 @@ absl::StatusOr FindMetaGraphDef( // finds the MetaGraphDef that matches the given set of tags and writes it to // the `meta_graph_def` parameter. Returns a failure status when the SavedModel // file does not exist or no MetaGraphDef matches the tags. -Status ReadMetaGraphDefFromSavedModel(absl::string_view export_dir, - const std::unordered_set& tags, - MetaGraphDef* meta_graph_def); +absl::Status ReadMetaGraphDefFromSavedModel( + absl::string_view export_dir, const std::unordered_set& tags, + MetaGraphDef* meta_graph_def); // Store debug info from the SavedModel export dir. -Status ReadSavedModelDebugInfoIfPresent( +absl::Status ReadSavedModelDebugInfoIfPresent( absl::string_view export_dir, std::unique_ptr* debug_info_proto); diff --git a/tensorflow/cc/saved_model/reader_test.cc b/tensorflow/cc/saved_model/reader_test.cc index 7e00186b3ad0a3..c8e153bed364c8 100644 --- a/tensorflow/cc/saved_model/reader_test.cc +++ b/tensorflow/cc/saved_model/reader_test.cc @@ -77,8 +77,8 @@ TEST_F(ReaderTest, NoTagMatch) { MetaGraphDef meta_graph_def; const string export_dir = GetDataDependencyFilepath(TestDataSharded()); - Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"}, - &meta_graph_def); + absl::Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"}, + &meta_graph_def); EXPECT_FALSE(st.ok()); EXPECT_TRUE(absl::StrContains( st.message(), @@ -90,7 +90,7 @@ TEST_F(ReaderTest, NoTagMatchMultiple) { MetaGraphDef meta_graph_def; const string export_dir = GetDataDependencyFilepath(TestDataSharded()); - Status st = ReadMetaGraphDefFromSavedModel( + absl::Status st = ReadMetaGraphDefFromSavedModel( export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def); EXPECT_FALSE(st.ok()); EXPECT_TRUE(absl::StrContains( @@ -102,8 +102,8 @@ TEST_F(ReaderTest, InvalidExportPath) { MetaGraphDef meta_graph_def; const string export_dir = GetDataDependencyFilepath("missing-path"); - Status st = ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe}, - &meta_graph_def); + absl::Status st = ReadMetaGraphDefFromSavedModel( + export_dir, {kSavedModelTagServe}, &meta_graph_def); EXPECT_FALSE(st.ok()); } @@ -119,7 +119,7 @@ TEST_F(ReaderTest, MetricsNotUpdatedFailedRead) { const int read_count_v2 = metrics::SavedModelReadCount("2").value(); const string export_dir = GetDataDependencyFilepath("missing-path"); - Status st = + absl::Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"serve"}, &meta_graph_def); EXPECT_FALSE(st.ok()); @@ -132,7 +132,7 @@ TEST_F(ReaderTest, MetricsUpdatedSuccessfulRead) { const int read_count_v1 = metrics::SavedModelReadCount("1").value(); const string export_dir = GetDataDependencyFilepath(TestDataSharded()); - Status st = + absl::Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"serve"}, &meta_graph_def); EXPECT_EQ(metrics::SavedModelReadCount("1").value(), read_count_v1 + 1); } diff --git a/tensorflow/cc/saved_model/util.cc b/tensorflow/cc/saved_model/util.cc index b474f1ef3ed0f3..56c576f640872e 100644 --- a/tensorflow/cc/saved_model/util.cc +++ b/tensorflow/cc/saved_model/util.cc @@ -42,7 +42,7 @@ std::set GetMapKeys( return keys; } -Status GetInputValues( +absl::Status GetInputValues( const SignatureDef& signature, const ::google::protobuf::Map& request_inputs, std::vector>& inputs) { diff --git a/tensorflow/cc/saved_model/util.h b/tensorflow/cc/saved_model/util.h index aacf6c2bcb5c9e..2489f837dc97db 100644 --- a/tensorflow/cc/saved_model/util.h +++ b/tensorflow/cc/saved_model/util.h @@ -45,7 +45,7 @@ std::set GetMapKeys( // Get the default input value from signature if it's missing in the request // inputs. If `is_alias` is set to true, the keys of the `request_inputs` are // alias names rather than the feed names in the graph. -Status GetInputValues( +absl::Status GetInputValues( const SignatureDef& signature, const ::google::protobuf::Map& request_inputs, std::vector>& inputs); diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc index 8a15850edd839f..c23f9161a448fd 100644 --- a/tensorflow/cc/tools/freeze_saved_model.cc +++ b/tensorflow/cc/tools/freeze_saved_model.cc @@ -133,7 +133,7 @@ Status GetVariableNameToTensorMap( std::unordered_set variable_names_set, std::unordered_map* variable_name_to_value_map) { if (variable_names_set.empty()) { - return OkStatus(); + return absl::OkStatus(); } std::vector variable_names; variable_names.reserve(variable_names_set.size()); @@ -156,7 +156,7 @@ Status GetVariableNameToTensorMap( for (size_t i = 0; i < variable_names.size(); i++) { (*variable_name_to_value_map)[variable_names[i]] = outputs[i]; } - return OkStatus(); + return absl::OkStatus(); } // Converts a Variable NodeDef into a Constant NodeDef. @@ -229,7 +229,7 @@ Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle, *frozen_graph_def->mutable_library() = graph_def.library(); // If the graph is empty there is nothing left to do. if (graph_def.node_size() == 0) { - return OkStatus(); + return absl::OkStatus(); } // name_to_node_map is needed to get the inputs from the NodeDef corresponding // the a string node name. These inputs are used when doing our backwards @@ -277,7 +277,7 @@ Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle, // If the node isn't a variable, just copy the node as-is. *frozen_graph_def->add_node() = node; } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -289,7 +289,7 @@ Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle, GetSignatureDefsInputsAndOutputs(saved_model_bundle, inputs, outputs); TF_RETURN_IF_ERROR( FreezeGraphDef(saved_model_bundle, *outputs, frozen_graph_def)); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc index eb4ef40b8927f6..a64aab9e0bb5f5 100644 --- a/tensorflow/cc/tools/freeze_saved_model_test.cc +++ b/tensorflow/cc/tools/freeze_saved_model_test.cc @@ -81,7 +81,7 @@ class FreezeTest : public ::testing::Test { return saved_model_bundle->session->Run( /* inputs */ {}, /* output_tensors */ {}, {init_node}, &outputs); } - return OkStatus(); + return absl::OkStatus(); } // Adds `graph_def` to `saved_model_bundle` and initializes a session with diff --git a/tensorflow/cc/training/coordinator.cc b/tensorflow/cc/training/coordinator.cc index 8fccfc48c91376..35b5ccf6ae4088 100644 --- a/tensorflow/cc/training/coordinator.cc +++ b/tensorflow/cc/training/coordinator.cc @@ -16,10 +16,10 @@ limitations under the License. #include "tensorflow/cc/training/coordinator.h" #include "absl/status/status.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/framework/cost_graph.pb.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/status.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { diff --git a/tensorflow/cc/training/coordinator.h b/tensorflow/cc/training/coordinator.h index e96b3d049eed73..404f313e6c490e 100644 --- a/tensorflow/cc/training/coordinator.h +++ b/tensorflow/cc/training/coordinator.h @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/framework/cost_graph.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/macros.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/error_codes.pb.h" #include "tsl/platform/thread_annotations.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { diff --git a/tensorflow/cc/training/coordinator_test.cc b/tensorflow/cc/training/coordinator_test.cc index 211fd1e68011e4..d051ff1fbe1226 100644 --- a/tensorflow/cc/training/coordinator_test.cc +++ b/tensorflow/cc/training/coordinator_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/status/status.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/env.h" @@ -24,7 +25,6 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/threadpool.h" #include "tensorflow/core/protobuf/error_codes.pb.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc index 508be19fd50a0f..6a1b5ec5726803 100644 --- a/tensorflow/cc/training/queue_runner.cc +++ b/tensorflow/cc/training/queue_runner.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "tensorflow/cc/training/coordinator.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/framework/cost_graph.pb.h" #include "tensorflow/core/framework/ops_util.h" #include "tensorflow/core/platform/blocking_counter.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/queue_runner.pb.h" #include "tensorflow/core/public/session.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { @@ -205,8 +205,7 @@ void QueueRunner::Run(Session* sess, const string& enqueue_op) { UpdateStatus(RealRun(sess, close_op_name_, false)); } } else if (!status.ok()) { - LOG(ERROR) << "Queue runner thread got a failure status: " - << status.ToString(); + LOG(ERROR) << "Queue runner thread got a failure status: " << status; UpdateStatus(status); if (coord_) { coord_->RequestStop().IgnoreError(); diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc index f4de69b25a61b2..4fc5775bd2d1c6 100644 --- a/tensorflow/cc/training/queue_runner_test.cc +++ b/tensorflow/cc/training/queue_runner_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/cc/ops/state_ops.h" #include "tensorflow/cc/training/coordinator.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/framework/cost_graph.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/tensor.h" @@ -42,7 +43,6 @@ limitations under the License. #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" #include "tsl/platform/status.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index fc351b2cd829b5..9dde12b7221c97 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -96,7 +96,7 @@ cc_library( "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:client_library", "@local_xla//xla/client:compile_only_client", - "@local_xla//xla/client:xla_computation", + "@local_xla//xla/hlo/builder:xla_computation", "@local_xla//xla/service:compiler", "@local_xla//xla/service/cpu:buffer_info_util", "@local_xla//xla/service/cpu:cpu_compiler", diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 0151da956351b4..9dee02eb8e2548 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -31,7 +31,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "xla/client/client_library.h" #include "xla/client/compile_only_client.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/service/cpu/cpu_compiler.h" #include "xla/stream_executor/platform_manager.h" #include "xla/util.h" diff --git a/tensorflow/compiler/aot/quantize.h b/tensorflow/compiler/aot/quantize.h index d18565071c7e91..e2412749290e77 100644 --- a/tensorflow/compiler/aot/quantize.h +++ b/tensorflow/compiler/aot/quantize.h @@ -21,7 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 99c8541c55488c..82fdb603138136 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -212,6 +212,7 @@ def _tf_library( ] + freeze_saver_srcs, outs = [freeze_file], cmd = ( + "PYWRAP_TARGET='//tensorflow/python:_pywrap_tensorflow' " + "CUDA_VISIBLE_DEVICES='' " + "$(location " + "//tensorflow/python/tools:freeze_graph)" + diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 968aeb2f028d64..2056388d4d4987 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -844,6 +844,7 @@ cc_library( "@local_xla//xla:status_macros", "@local_xla//xla/client:executable_build_options", "@local_xla//xla/client:local_client", + "@local_xla//xla/hlo/translate:portable_api", "@local_xla//xla/service:hlo_graph_dumper", "@local_xla//xla/stream_executor:platform", "@local_xla//xla/stream_executor/host:host_platform_id", @@ -1145,7 +1146,10 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -1315,10 +1319,13 @@ tf_cc_test( "//tensorflow/core:session_options", "//tensorflow/core:test", "//tensorflow/core:testlib", + "//tensorflow/core/common_runtime:device_set", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", "@local_xla//xla:test", ], ) @@ -1860,7 +1867,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core/platform:errors", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc index affe924d77a402..bed899bfed2f3e 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -225,7 +225,7 @@ Output IncomingEdgeAsOutput(const Edge* e) { return Output(e->src(), e->src_output()); } -Status GetXlaClusterInfo(Node* n, XlaClusterInfo* result) { +absl::Status GetXlaClusterInfo(Node* n, XlaClusterInfo* result) { int num_constant_inputs, num_resource_inputs; TF_RETURN_IF_ERROR( GetNodeAttr(n->attrs(), kXlaNumConstantArgsAttr, &num_constant_inputs)); @@ -263,7 +263,7 @@ Status GetXlaClusterInfo(Node* n, XlaClusterInfo* result) { return absl::OkStatus(); } -Status CopyIncomingControlEdges(Graph* g, Node* from, Node* to) { +absl::Status CopyIncomingControlEdges(Graph* g, Node* from, Node* to) { for (const Edge* e : from->in_edges()) { if (e->IsControlEdge()) { g->AddControlEdge(e->src(), to); @@ -283,8 +283,9 @@ void RemoveAllIncomingControlEdges(Graph* g, Node* n) { } // Returns true (into `result`) if a node placed on `device` must be compiled. -Status DeviceRequiresCompilation(const jit::DeviceInfoCache& device_info_cache, - jit::DeviceId device, bool* result) { +absl::Status DeviceRequiresCompilation( + const jit::DeviceInfoCache& device_info_cache, jit::DeviceId device, + bool* result) { const XlaOpRegistry::DeviceRegistration* registration = device_info_cache.GetCompilationDevice(device); *result = registration->autoclustering_policy == @@ -423,8 +424,8 @@ absl::StatusOr GetOutputMemoryTypes(const Scope& root, // To prevent this, we add control dependencies to make the int32 input edges // into the PartitionedCall dead. With this change the D2H copy only happens if // the PartitionedCall is actually executed. -Status PredicateInt32Inputs(const Scope& root, Node* n, - Operation predicate_as_control) { +absl::Status PredicateInt32Inputs(const Scope& root, Node* n, + Operation predicate_as_control) { std::vector int32_inputs; std::vector int32_inputs_input_idxs; for (const Edge* e : n->in_edges()) { @@ -464,7 +465,7 @@ Status PredicateInt32Inputs(const Scope& root, Node* n, return absl::OkStatus(); } -Status ReplaceNodeWithXlaCompileAndXlaRun( +absl::Status ReplaceNodeWithXlaCompileAndXlaRun( jit::DeviceInfoCache* device_info_cache, const GraphOptimizationPassOptions& options, const FunctionLibraryDefinition& flib_def, bool lazy_compilation_enabled, @@ -486,7 +487,7 @@ Status ReplaceNodeWithXlaCompileAndXlaRun( string device_name_str = string(device_info_cache->GetNameFor(device)); - Status status; + absl::Status status; Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr) .NewSubScope(n->name()) .WithDevice(n->requested_device()) @@ -569,7 +570,7 @@ Status ReplaceNodeWithXlaCompileAndXlaRun( } } // namespace -Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) { +absl::Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) { Graph* graph = options.graph->get(); // Copy out the nodes we want to rewrite to avoid modifying the graph while we diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.h b/tensorflow/compiler/jit/build_xla_ops_pass.h index 1bc2e0b332290f..c1219d7ccd3c34 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.h +++ b/tensorflow/compiler/jit/build_xla_ops_pass.h @@ -34,7 +34,7 @@ class BuildXlaOpsPass : public GraphOptimizationPass { std::optional enable_lazy_compilation = std::nullopt) : enable_lazy_compilation_(enable_lazy_compilation) {} - Status Run(const GraphOptimizationPassOptions& options) override; + absl::Status Run(const GraphOptimizationPassOptions& options) override; private: std::optional enable_lazy_compilation_; diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc index 8e2324f68ddd34..c3b5ba5521ee65 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc @@ -55,8 +55,8 @@ using ::tensorflow::testing::matchers::Op; using ::tensorflow::testing::matchers::Out; using ::testing::_; -Status BuildXlaOps(const Scope& s, const FunctionDefLibrary& fdef_lib, - std::unique_ptr* result) { +absl::Status BuildXlaOps(const Scope& s, const FunctionDefLibrary& fdef_lib, + std::unique_ptr* result) { auto graph = std::make_unique(OpRegistry::Global()); TF_RETURN_IF_ERROR(s.ToGraph(graph.get())); FunctionLibraryDefinition flib_def(graph->op_registry(), fdef_lib); @@ -85,9 +85,10 @@ Status BuildXlaOps(const Scope& s, const FunctionDefLibrary& fdef_lib, return absl::OkStatus(); } -Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name, - const string& node_name, int num_constant_args, - int num_resource_args, Node** result) { +absl::Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name, + const string& node_name, + int num_constant_args, int num_resource_args, + Node** result) { NodeDef call_node; call_node.set_name(node_name); call_node.set_op(callee_name); @@ -98,8 +99,8 @@ Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name, return absl::OkStatus(); } -Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name, - const string& node_name, Node** result) { +absl::Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name, + const string& node_name, Node** result) { return MakeXlaCompiledKernel(graph, callee_name, node_name, /*num_constant_args=*/0, /*num_resource_args=*/0, result); @@ -167,7 +168,7 @@ TEST_F(BuildXlaOpsTest, CleanFailureOnBogusAttr) { root.graph()->AddControlEdge(call, write_op); std::unique_ptr graph; - Status failure_status = BuildXlaOps(root, fdef_lib, &graph); + absl::Status failure_status = BuildXlaOps(root, fdef_lib, &graph); ASSERT_FALSE(failure_status.ok()); EXPECT_EQ(failure_status.code(), error::INVALID_ARGUMENT); } diff --git a/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc b/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc index 2ee5e20b34be65..bb8dce848cfbc9 100644 --- a/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc +++ b/tensorflow/compiler/jit/clone_constants_for_better_clustering.cc @@ -32,11 +32,11 @@ class CloneConstantsForBetterClusteringPassImpl { public: explicit CloneConstantsForBetterClusteringPassImpl(Graph* graph) : graph_(graph), unique_name_counter_(0) {} - Status Run(); + absl::Status Run(); private: - Status CloneSmallConstantInputs(const absl::flat_hash_set& name_set, - Node* n); + absl::Status CloneSmallConstantInputs( + const absl::flat_hash_set& name_set, Node* n); string GenerateUniqueName(const absl::flat_hash_set& name_set, absl::string_view prefix); absl::StatusOr CloneNode(const absl::flat_hash_set& name_set, @@ -110,7 +110,8 @@ bool IsInPlaceOp(absl::string_view op_name) { } } // namespace -Status CloneConstantsForBetterClusteringPassImpl::CloneSmallConstantInputs( +absl::Status +CloneConstantsForBetterClusteringPassImpl::CloneSmallConstantInputs( const absl::flat_hash_set& name_set, Node* n) { std::vector in_edges; // Get the edges and sort them so we clone in a deterministic order. @@ -140,7 +141,7 @@ Status CloneConstantsForBetterClusteringPassImpl::CloneSmallConstantInputs( return absl::OkStatus(); } -Status CloneConstantsForBetterClusteringPassImpl::Run() { +absl::Status CloneConstantsForBetterClusteringPassImpl::Run() { absl::flat_hash_set name_set; absl::c_transform(graph_->nodes(), std::inserter(name_set, name_set.begin()), [](Node* n) { return n->name(); }); @@ -198,7 +199,7 @@ Status CloneConstantsForBetterClusteringPassImpl::Run() { return absl::OkStatus(); } -Status CloneConstantsForBetterClusteringPass::Run( +absl::Status CloneConstantsForBetterClusteringPass::Run( const GraphOptimizationPassOptions& options) { if (GetGlobalJitLevelForGraph(options) == OptimizerOptions::OFF) { return absl::OkStatus(); diff --git a/tensorflow/compiler/jit/clone_constants_for_better_clustering.h b/tensorflow/compiler/jit/clone_constants_for_better_clustering.h index 19e6c49ec44538..ebe510083d7205 100644 --- a/tensorflow/compiler/jit/clone_constants_for_better_clustering.h +++ b/tensorflow/compiler/jit/clone_constants_for_better_clustering.h @@ -55,7 +55,7 @@ class CloneConstantsForBetterClusteringPass : public GraphOptimizationPass { public: CloneConstantsForBetterClusteringPass() = default; - Status Run(const GraphOptimizationPassOptions& options) override; + absl::Status Run(const GraphOptimizationPassOptions& options) override; }; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/clone_constants_for_better_clustering_test.cc b/tensorflow/compiler/jit/clone_constants_for_better_clustering_test.cc index db5ebaac54b3ba..f132a91bc8d10b 100644 --- a/tensorflow/compiler/jit/clone_constants_for_better_clustering_test.cc +++ b/tensorflow/compiler/jit/clone_constants_for_better_clustering_test.cc @@ -28,8 +28,8 @@ namespace tensorflow { namespace { using ::tensorflow::testing::FindNodeByName; -Status CloneConstantsForBetterClustering(const Scope& s, - std::unique_ptr* result) { +absl::Status CloneConstantsForBetterClustering(const Scope& s, + std::unique_ptr* result) { auto graph = std::make_unique(OpRegistry::Global()); SessionOptions session_options; session_options.config.mutable_graph_options() diff --git a/tensorflow/compiler/jit/cluster_scoping_pass.cc b/tensorflow/compiler/jit/cluster_scoping_pass.cc index edf72f83861e54..e4efb8922089c6 100644 --- a/tensorflow/compiler/jit/cluster_scoping_pass.cc +++ b/tensorflow/compiler/jit/cluster_scoping_pass.cc @@ -34,10 +34,10 @@ class ClusterScopingPassImpl { global_jit_level_(global_jit_level), unique_scope_id_(0) {} - Status Run(); + absl::Status Run(); private: - Status ScopingForPipelineStages(); + absl::Status ScopingForPipelineStages(); size_t GetUniqueScopeId() { return unique_scope_id_++; } @@ -131,7 +131,7 @@ void ClusterScopingPassImpl::AddScopeToAllTransitiveSuccessors(Node* start) { // // Unstage -> Node_Y // -Status ClusterScopingPassImpl::ScopingForPipelineStages() { +absl::Status ClusterScopingPassImpl::ScopingForPipelineStages() { for (Node* n : graph_->nodes()) { DCHECK(n); if (n->type_string() == "Unstage") { @@ -145,7 +145,7 @@ Status ClusterScopingPassImpl::ScopingForPipelineStages() { return absl::OkStatus(); } -Status ClusterScopingPassImpl::Run() { +absl::Status ClusterScopingPassImpl::Run() { if (global_jit_level_ == OptimizerOptions::OFF) { return absl::OkStatus(); } @@ -154,7 +154,8 @@ Status ClusterScopingPassImpl::Run() { } } // namespace -Status ClusterScopingPass::Run(const GraphOptimizationPassOptions& options) { +absl::Status ClusterScopingPass::Run( + const GraphOptimizationPassOptions& options) { Graph* graph = options.graph->get(); return ClusterScopingPassImpl{graph, GetGlobalJitLevelForGraph(options)} diff --git a/tensorflow/compiler/jit/cluster_scoping_pass.h b/tensorflow/compiler/jit/cluster_scoping_pass.h index 9651c3f878cfc6..0b0c2ccf842db2 100644 --- a/tensorflow/compiler/jit/cluster_scoping_pass.h +++ b/tensorflow/compiler/jit/cluster_scoping_pass.h @@ -30,7 +30,7 @@ namespace tensorflow { // clustering decision. class ClusterScopingPass : public GraphOptimizationPass { public: - Status Run(const GraphOptimizationPassOptions& options) override; + absl::Status Run(const GraphOptimizationPassOptions& options) override; }; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/cluster_scoping_pass_test.cc b/tensorflow/compiler/jit/cluster_scoping_pass_test.cc index 436d2f867c94d3..b09cb2c12fa297 100644 --- a/tensorflow/compiler/jit/cluster_scoping_pass_test.cc +++ b/tensorflow/compiler/jit/cluster_scoping_pass_test.cc @@ -31,7 +31,7 @@ limitations under the License. namespace tensorflow { namespace { -Status ClusterScoping(std::unique_ptr* graph) { +absl::Status ClusterScoping(std::unique_ptr* graph) { FixupSourceAndSinkEdges(graph->get()); GraphOptimizationPassWrapper wrapper; diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index 5421637e80e5e0..2b15a4affc76af 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -80,8 +80,9 @@ bool IsInOutsideCompilationCluster(const Node& n) { return n.attrs().Find(kXlaOutsideCompilationAttr) != nullptr; } -Status MakeCallNodeFromAttribute(const Node& node, const std::string& attr_name, - NodeDef* node_def) { +absl::Status MakeCallNodeFromAttribute(const Node& node, + const std::string& attr_name, + NodeDef* node_def) { const NameAttrList* name_attr; TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), attr_name, &name_attr)); node_def->set_op(name_attr->name()); @@ -200,7 +201,8 @@ bool RecursiveCompilabilityChecker::HasXLAKernel( return false; } - Status s = FindKernelDef(jit_device_type_, node.def(), nullptr, nullptr); + absl::Status s = + FindKernelDef(jit_device_type_, node.def(), nullptr, nullptr); if (!s.ok()) { *uncompilable_reason = s.message(); return false; @@ -322,7 +324,7 @@ bool RecursiveCompilabilityChecker::IsCompilableCall( } FunctionLibraryRuntime::Handle handle; - Status s; + absl::Status s; NameAttrList function; s = NameAndAttrsFromFunctionCall(call_def, &function); if (s.ok()) { @@ -628,11 +630,10 @@ bool CanCreateXlaKernel(const NodeDef& node_def) { return HasBoolAttr(node_def, kXlaMustCompileAttr); } -Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, - const NameAttrList& function, - const FunctionBody** fbody, - std::vector* constant_arg_indices, - std::vector* resource_arg_indices) { +absl::Status GetBodyAndConstantsAndResources( + FunctionLibraryRuntime* flr, const NameAttrList& function, + const FunctionBody** fbody, std::vector* constant_arg_indices, + std::vector* resource_arg_indices) { FunctionLibraryRuntime::Handle handle; TF_RETURN_IF_ERROR( flr->Instantiate(function.name(), AttrSlice(&function.attr()), &handle)); diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index 18f6e5197b9cae..0d86c22de11a22 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -285,11 +285,10 @@ RecursiveCompilabilityChecker::OperationFilter CreateOperationFilter( // `fbody` is owned by `flr`. // `constant_arg_indices` and `resource_arg_indices` should be empty vector. // They are sorted in ascending order on this function's return. -Status GetBodyAndConstantsAndResources(FunctionLibraryRuntime* flr, - const NameAttrList& function, - const FunctionBody** fbody, - std::vector* constant_arg_indices, - std::vector* resource_arg_indices); +absl::Status GetBodyAndConstantsAndResources( + FunctionLibraryRuntime* flr, const NameAttrList& function, + const FunctionBody** fbody, std::vector* constant_arg_indices, + std::vector* resource_arg_indices); // Given a NodeDef `node_def` returns true iff `node_def` has kXlaCompileAttr // set. diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index c359bc3bd0e56a..2b2db07642d1ab 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -414,8 +414,8 @@ class PredicateFactory { return new_pred_ptr; } - Status MakeSymbolPredicate(Node* node, int output_idx, bool must_be_true, - Predicate** predicate) { + absl::Status MakeSymbolPredicate(Node* node, int output_idx, + bool must_be_true, Predicate** predicate) { TensorId tensor_id(node->name(), output_idx); bool is_boolean_tensor = @@ -449,9 +449,9 @@ class PredicateFactory { return absl::OkStatus(); } - Status MakeSymbolPredicate(Node* node, int output_idx, - std::optional must_have_value, - Predicate** predicate) { + absl::Status MakeSymbolPredicate(Node* node, int output_idx, + std::optional must_have_value, + Predicate** predicate) { TensorId tensor_id(node->name(), output_idx); TF_RET_CHECK(BaseType(node->output_type(tensor_id.index())) == DT_INT32); @@ -824,9 +824,9 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { explicit DeadnessAnalysisImpl(const Graph* graph) : graph_(*graph), vlog_(VLOG_IS_ON(2)) {} - Status Populate(bool enable_optimistic); - Status PopulateFrame(absl::Span topo, bool use_optimistic_mode, - bool* success); + absl::Status Populate(bool enable_optimistic); + absl::Status PopulateFrame(absl::Span topo, + bool use_optimistic_mode, bool* success); absl::StatusOr GetPredicateFor( Node* n, int oidx) const override; void Print() const override; @@ -836,8 +836,8 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { private: enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly }; - Status GetInputPreds(Node* n, EdgeKind edge_kind, - std::vector* result); + absl::Status GetInputPreds(Node* n, EdgeKind edge_kind, + std::vector* result); // Sets the predicate for output `output_idx` of `n` to `pred`. Sets the i'th // bit of `should_revisit` if `pred` is different from the current predicate @@ -867,15 +867,15 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { } } - Status HandleSwitch(Node* n, std::vector* should_revisit); - Status HandleMerge(Node* n, std::vector* should_revisit, - bool use_optimistic_mode); - Status HandleRecv(Node* n, std::vector* should_revisit); - Status HandleGeneric(Node* n, std::vector* should_revisit); - Status HandleNode(Node* n, std::vector* should_revisit, - bool use_optimistic_mode = false); + absl::Status HandleSwitch(Node* n, std::vector* should_revisit); + absl::Status HandleMerge(Node* n, std::vector* should_revisit, + bool use_optimistic_mode); + absl::Status HandleRecv(Node* n, std::vector* should_revisit); + absl::Status HandleGeneric(Node* n, std::vector* should_revisit); + absl::Status HandleNode(Node* n, std::vector* should_revisit, + bool use_optimistic_mode = false); - Status GetFrameBasedTopologicalOrder(std::vector* order); + absl::Status GetFrameBasedTopologicalOrder(std::vector* order); bool IsRootEnter(const Node* n) const { return IsEnter(n) && control_flow_info_[n->id()].parent_frame->IsSource(); @@ -897,7 +897,7 @@ TensorId InputEdgeToTensorId(const Edge* e) { return TensorId(e->src()->name(), e->src_output()); } -Status DeadnessAnalysisImpl::GetInputPreds( +absl::Status DeadnessAnalysisImpl::GetInputPreds( Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind, std::vector* result) { result->clear(); @@ -927,8 +927,8 @@ Status DeadnessAnalysisImpl::GetInputPreds( return absl::OkStatus(); } -Status DeadnessAnalysisImpl::HandleSwitch(Node* n, - std::vector* should_revisit) { +absl::Status DeadnessAnalysisImpl::HandleSwitch( + Node* n, std::vector* should_revisit) { std::vector input_preds; TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); const Edge* pred_edge; @@ -981,7 +981,7 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n, } namespace { -Status CreateMultipleNextIterationInputsError(Node* merge) { +absl::Status CreateMultipleNextIterationInputsError(Node* merge) { std::vector backedges; for (const Edge* backedge : merge->in_edges()) { if (backedge->src()->IsNextIteration()) { @@ -994,7 +994,7 @@ Status CreateMultipleNextIterationInputsError(Node* merge) { "\nMerge nodes can have at most one incoming NextIteration edge."); } -Status FindUniqueBackedge(Node* merge, const Edge** result) { +absl::Status FindUniqueBackedge(Node* merge, const Edge** result) { *result = nullptr; CHECK(merge->IsMerge()); for (const Edge* e : merge->in_edges()) { @@ -1056,8 +1056,9 @@ Predicate* DeduceStepPredicate(PredicateFactory* predicate_factory, return found_sym ? predicate_factory->MakeAndPredicate(and_ops) : nullptr; } -Status GetFullFrame(const Node* n, absl::Span cfi_infos, - std::vector* frame) { +absl::Status GetFullFrame(const Node* n, + absl::Span cfi_infos, + std::vector* frame) { int depth = 0; for (const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()]; !n->IsSource(); n = cfi_iter->parent_frame, cfi_iter = &cfi_infos[n->id()]) { @@ -1075,8 +1076,9 @@ Status GetFullFrame(const Node* n, absl::Span cfi_infos, // If the node is inside some frames, get the name of the outermost non-empty // frame. Otherwise, get an empty frame name. -Status GetRootFrame(const Node* n, absl::Span cfi_infos, - absl::string_view* frame) { +absl::Status GetRootFrame(const Node* n, + absl::Span cfi_infos, + absl::string_view* frame) { int depth = 0; const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()]; while (!cfi_iter->parent_frame->IsSource()) { @@ -1095,9 +1097,8 @@ Status GetRootFrame(const Node* n, absl::Span cfi_infos, } } // namespace -Status DeadnessAnalysisImpl::HandleMerge(Node* n, - std::vector* should_revisit, - bool use_optimistic_mode) { +absl::Status DeadnessAnalysisImpl::HandleMerge( + Node* n, std::vector* should_revisit, bool use_optimistic_mode) { // Merge ignores deadness of its control inputs. A merge that isn't the // target of a backedge has is alive iff any of its data inputs are. The // liveness of a merge that is the target of a backedge can sometimes be @@ -1185,8 +1186,8 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, return absl::OkStatus(); } -Status DeadnessAnalysisImpl::HandleRecv(Node* n, - std::vector* should_revisit) { +absl::Status DeadnessAnalysisImpl::HandleRecv( + Node* n, std::vector* should_revisit) { // In addition to being alive or dead based on the inputs, a _Recv can also // acquire a dead signal from a _Send. std::vector input_preds; @@ -1201,8 +1202,8 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n, return absl::OkStatus(); } -Status DeadnessAnalysisImpl::HandleGeneric(Node* n, - std::vector* should_revisit) { +absl::Status DeadnessAnalysisImpl::HandleGeneric( + Node* n, std::vector* should_revisit) { // Generally nodes are alive iff all their inputs are alive. std::vector input_preds; TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds)); @@ -1214,9 +1215,9 @@ Status DeadnessAnalysisImpl::HandleGeneric(Node* n, return absl::OkStatus(); } -Status DeadnessAnalysisImpl::HandleNode(Node* n, - std::vector* should_revisit, - bool use_optimistic_mode) { +absl::Status DeadnessAnalysisImpl::HandleNode(Node* n, + std::vector* should_revisit, + bool use_optimistic_mode) { if (n->IsSwitch()) { TF_RETURN_IF_ERROR(HandleSwitch(n, should_revisit)); } else if (n->IsMerge()) { @@ -1240,7 +1241,7 @@ Status DeadnessAnalysisImpl::HandleNode(Node* n, // many inputs of each node are ready; a node is ready to be scheduled if all // of its inputs are ready. // Ref. to https://en.wikipedia.org/wiki/Topological_sorting for details. -Status DeadnessAnalysisImpl::GetFrameBasedTopologicalOrder( +absl::Status DeadnessAnalysisImpl::GetFrameBasedTopologicalOrder( std::vector* order) { absl::flat_hash_map num_enters_for_frame; absl::flat_hash_map num_exits_for_frame; @@ -1356,7 +1357,7 @@ Status DeadnessAnalysisImpl::GetFrameBasedTopologicalOrder( // (root) frame. Note that we don't separate while loops belonging to the same // nested while, as there is no clean cut for separating them in the topological // order. -Status DeadnessAnalysisImpl::Populate(bool enable_optimistic) { +absl::Status DeadnessAnalysisImpl::Populate(bool enable_optimistic) { std::vector unreachable_nodes; // Compute the loop structure of the graph. TF_RETURN_IF_ERROR( @@ -1418,9 +1419,9 @@ Status DeadnessAnalysisImpl::Populate(bool enable_optimistic) { return absl::OkStatus(); } -Status DeadnessAnalysisImpl::PopulateFrame(absl::Span topo, - bool use_optimistic_mode, - bool* success) { +absl::Status DeadnessAnalysisImpl::PopulateFrame(absl::Span topo, + bool use_optimistic_mode, + bool* success) { CHECK(use_optimistic_mode && success != nullptr || !use_optimistic_mode && success == nullptr); @@ -1567,7 +1568,7 @@ void DeadnessAnalysisImpl::Print() const { DeadnessAnalysis::~DeadnessAnalysis() {} -/*static*/ Status DeadnessAnalysis::Run( +/*static*/ absl::Status DeadnessAnalysis::Run( const Graph& graph, std::unique_ptr* result) { std::unique_ptr analysis( new DeadnessAnalysisImpl(&graph)); @@ -1591,8 +1592,9 @@ DeadnessAnalysisImpl::PredicateMapAsString() const { } namespace deadness_analysis_internal { -Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map, - bool enable_optimistic) { +absl::Status ComputePredicates(const Graph& graph, + PredicateMapTy* out_predicate_map, + bool enable_optimistic) { DeadnessAnalysisImpl impl(&graph); TF_RETURN_IF_ERROR(impl.Populate(enable_optimistic)); *out_predicate_map = impl.PredicateMapAsString(); diff --git a/tensorflow/compiler/jit/deadness_analysis.h b/tensorflow/compiler/jit/deadness_analysis.h index fe00acb866f179..80fa9a20faef41 100644 --- a/tensorflow/compiler/jit/deadness_analysis.h +++ b/tensorflow/compiler/jit/deadness_analysis.h @@ -85,8 +85,8 @@ class DeadnessAnalysis { // Run the deadness analysis over `graph` and returns an error or a populated // instance of DeadnessAnalysis in `result`. - static Status Run(const Graph& graph, - std::unique_ptr* result); + static absl::Status Run(const Graph& graph, + std::unique_ptr* result); protected: static DeadnessPredicate MakeDeadnessPredicate(void* pred) { diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h index b2f0e72bc14ae6..0dc18d3e129d79 100644 --- a/tensorflow/compiler/jit/deadness_analysis_internal.h +++ b/tensorflow/compiler/jit/deadness_analysis_internal.h @@ -25,8 +25,9 @@ namespace deadness_analysis_internal { // Returns a map describing the predicate each Tensor was mapped to. For // testing purposes only. using PredicateMapTy = absl::flat_hash_map; -Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map, - bool enable_optimistic = true); +absl::Status ComputePredicates(const Graph& graph, + PredicateMapTy* out_predicate_map, + bool enable_optimistic = true); } // namespace deadness_analysis_internal } // namespace tensorflow diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index e2db6c0acca490..894ee659121e25 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -55,8 +55,8 @@ absl::StatusOr HasInputsWithMismatchingDeadness( using deadness_analysis_internal::ComputePredicates; using deadness_analysis_internal::PredicateMapTy; -Status AnalyzeDeadness(Graph* graph, - std::unique_ptr* result) { +absl::Status AnalyzeDeadness(Graph* graph, + std::unique_ptr* result) { FixupSourceAndSinkEdges(graph); return DeadnessAnalysis::Run(*graph, result); } diff --git a/tensorflow/compiler/jit/device_compilation_cache.h b/tensorflow/compiler/jit/device_compilation_cache.h index e4d39745883038..ad87134940c00c 100644 --- a/tensorflow/compiler/jit/device_compilation_cache.h +++ b/tensorflow/compiler/jit/device_compilation_cache.h @@ -73,7 +73,7 @@ class DeviceCompilationCache { using Key = DeviceCompilationClusterSignature; struct Value { DeviceCompileState compile_state = DeviceCompileState::kUncompiled; - Status compilation_status; + absl::Status compilation_status; int64_t request_count = 0; const XlaCompiler::CompilationResult* compilation_result = nullptr; ExecutableType* executable = nullptr; @@ -93,7 +93,7 @@ class DeviceCompilationCache { // corresponding `request_count`. Only arguments that are not std::nullopt are // updated in the cache. void Store(const Key& key, std::optional compile_state, - std::optional compilation_status, + std::optional compilation_status, std::optional> compilation_result, std::optional> executable); @@ -113,7 +113,7 @@ class DeviceCompilationCache { int64_t request_count TF_GUARDED_BY(mu) = 0; // Did compilation succeed? - Status compilation_status TF_GUARDED_BY(mu); + absl::Status compilation_status TF_GUARDED_BY(mu); // Output of the XlaCompiler. std::unique_ptr compilation_result @@ -206,7 +206,7 @@ DeviceCompilationCache::LookupOrCreate(const Key& key) { template void DeviceCompilationCache::Store( const Key& key, std::optional compile_state, - std::optional compilation_status, + std::optional compilation_status, std::optional> compilation_result, std::optional> executable) { diff --git a/tensorflow/compiler/jit/device_compilation_cache_test.cc b/tensorflow/compiler/jit/device_compilation_cache_test.cc index cdafa5ce35255f..2c6b22f8830f6a 100644 --- a/tensorflow/compiler/jit/device_compilation_cache_test.cc +++ b/tensorflow/compiler/jit/device_compilation_cache_test.cc @@ -22,9 +22,9 @@ limitations under the License. #include #include +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/test.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/jit/device_compilation_profiler.cc b/tensorflow/compiler/jit/device_compilation_profiler.cc index b2da1959c98f72..5e1b3b26e8ecb5 100644 --- a/tensorflow/compiler/jit/device_compilation_profiler.cc +++ b/tensorflow/compiler/jit/device_compilation_profiler.cc @@ -94,7 +94,7 @@ void DeviceCompilationProfiler::RegisterExecution( RegisterExecutionForCluster(function, &it->second); } -Status DeviceCompilationProfiler::RegisterCompilation( +absl::Status DeviceCompilationProfiler::RegisterCompilation( const NameAttrList& function, int64_t compile_time_us, bool used_persistent_cache) { metrics::UpdateXlaCompilationTime(compile_time_us); diff --git a/tensorflow/compiler/jit/device_compilation_profiler.h b/tensorflow/compiler/jit/device_compilation_profiler.h index 2057e1adc12dee..9f1d9521f4f1be 100644 --- a/tensorflow/compiler/jit/device_compilation_profiler.h +++ b/tensorflow/compiler/jit/device_compilation_profiler.h @@ -74,9 +74,9 @@ class DeviceCompilationProfiler : public ResourceBase { // Registers a cluster compilation. Increments the compilation count and // accumulates the compile time for the given cluster. Also broadcasts an // XlaJitCompilationActivity. - virtual Status RegisterCompilation(const NameAttrList& function, - int64_t compile_time_us, - bool used_persistent_cache); + virtual absl::Status RegisterCompilation(const NameAttrList& function, + int64_t compile_time_us, + bool used_persistent_cache); void IncrementOngoingAsyncCompilations(); void DecrementOngoingAsyncCompilations(); diff --git a/tensorflow/compiler/jit/device_compiler.h b/tensorflow/compiler/jit/device_compiler.h index da2809cb7b62a1..1baa70850a02b4 100644 --- a/tensorflow/compiler/jit/device_compiler.h +++ b/tensorflow/compiler/jit/device_compiler.h @@ -99,7 +99,7 @@ class DeviceCompiler : public ResourceBase { // `ExecutableType` and sets `out_executable` to point to it. The // resulting executable pointer may be null if the computation has no // non-constant outputs. - Status CompileIfNeeded( + absl::Status CompileIfNeeded( const XlaCompiler::Options& options, const NameAttrList& function, const std::vector& args, const XlaCompiler::CompileOptions& compile_options, @@ -108,7 +108,7 @@ class DeviceCompiler : public ResourceBase { ExecutableType** out_executable); // As above, but for a single op. - Status CompileSingleOpIfNeeded( + absl::Status CompileSingleOpIfNeeded( const XlaCompiler::Options& options, const std::vector& args, const XlaCompiler::CompileOptions& compile_options, OpKernelContext* ctx, @@ -131,7 +131,7 @@ class DeviceCompiler : public ResourceBase { private: // Common implementation of Compile and CompileSingleOp. The `OpKernelContext` // parameter is always null for the former. - Status CompileImpl( + absl::Status CompileImpl( const XlaCompiler::CompileOptions& compile_options, const XlaCompiler::Options& options, const NameAttrList& function, const std::vector& args, CompileScope scope, @@ -152,13 +152,13 @@ class DeviceCompiler : public ResourceBase { DeviceCompilationProfiler* profiler, mutex* mu) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu); - Status CompileAsynchronous(const DeviceCompilationClusterSignature& sig, - const XlaCompiler::CompileOptions& compile_options, - const XlaCompiler::Options& options, - const std::vector& args, - const NameAttrList& function, CompileScope scope, - OpKernelContext* ctx, - DeviceCompilationProfiler* profiler); + absl::Status CompileAsynchronous( + const DeviceCompilationClusterSignature& sig, + const XlaCompiler::CompileOptions& compile_options, + const XlaCompiler::Options& options, + const std::vector& args, + const NameAttrList& function, CompileScope scope, OpKernelContext* ctx, + DeviceCompilationProfiler* profiler); std::unique_ptr> persistor_; @@ -191,8 +191,8 @@ inline void LogOnceXlaCompiledFirstCluster() { } template -inline Status EligibleToPersist(DeviceCompileState compile_state, - const ExecutableType* executable) { +inline absl::Status EligibleToPersist(DeviceCompileState compile_state, + const ExecutableType* executable) { if (compile_state != DeviceCompileState::kCompiled) { return errors::FailedPrecondition( "Cache entry to serialize is not compiled."); @@ -244,7 +244,7 @@ string DeviceCompiler::DebugString() const { } template -Status DeviceCompiler::CompileIfNeeded( +absl::Status DeviceCompiler::CompileIfNeeded( const XlaCompiler::Options& options, const NameAttrList& function, const std::vector& args, const XlaCompiler::CompileOptions& compile_options, @@ -257,7 +257,8 @@ Status DeviceCompiler::CompileIfNeeded( } template -Status DeviceCompiler::CompileSingleOpIfNeeded( +absl::Status +DeviceCompiler::CompileSingleOpIfNeeded( const XlaCompiler::Options& options, const std::vector& args, const XlaCompiler::CompileOptions& compile_options, OpKernelContext* ctx, @@ -350,7 +351,7 @@ DeviceCompiler::CompileStrict( } template -Status DeviceCompiler::CompileAsynchronous( +absl::Status DeviceCompiler::CompileAsynchronous( const DeviceCompilationClusterSignature& signature, const XlaCompiler::CompileOptions& compile_options, const XlaCompiler::Options& options, @@ -395,7 +396,7 @@ Status DeviceCompiler::CompileAsynchronous( } template -Status DeviceCompiler::CompileImpl( +absl::Status DeviceCompiler::CompileImpl( const XlaCompiler::CompileOptions& compile_options, const XlaCompiler::Options& options, const NameAttrList& function, const std::vector& args, CompileScope scope, diff --git a/tensorflow/compiler/jit/device_compiler_disable_test.cc b/tensorflow/compiler/jit/device_compiler_disable_test.cc index 1d53716c435850..481330aea6292e 100644 --- a/tensorflow/compiler/jit/device_compiler_disable_test.cc +++ b/tensorflow/compiler/jit/device_compiler_disable_test.cc @@ -64,7 +64,7 @@ TEST(DeviceCompilerTest, TestDisabledXlaCompilation) { core::ScopedUnref profiler_ref(profiler); // Check that strict compilation is disallowed. - Status status = xla_device_compiler->CompileIfNeeded( + absl::Status status = xla_device_compiler->CompileIfNeeded( XlaCompiler::Options{}, fn, args, XlaCompiler::CompileOptions{}, DeviceCompileMode::kStrict, profiler, &compilation_result, &executable); EXPECT_FALSE(status.ok()); diff --git a/tensorflow/compiler/jit/device_compiler_test.cc b/tensorflow/compiler/jit/device_compiler_test.cc index f13191682d2817..c458be2c7adf10 100644 --- a/tensorflow/compiler/jit/device_compiler_test.cc +++ b/tensorflow/compiler/jit/device_compiler_test.cc @@ -110,7 +110,7 @@ class MockXlaDeviceExecutablePersistor : DeviceExecutablePersistor( Config{testing::TmpDir(), false, "xla"}, DeviceType(DEVICE_CPU_XLA_JIT)) {} - MOCK_METHOD(Status, TryToPersistExecutable, + MOCK_METHOD(absl::Status, TryToPersistExecutable, (uint64, const std::string&, const XlaCompiler::Options&, const XlaCompiler::CompilationResult&, const xla::LocalExecutable&, @@ -124,7 +124,7 @@ class MockDeviceCompilationProfiler : public DeviceCompilationProfiler { (const NameAttrList& function, DeviceCompileMode compile_mode, int64_t current_request_count), (override)); - MOCK_METHOD(Status, RegisterCompilation, + MOCK_METHOD(absl::Status, RegisterCompilation, (const NameAttrList& function, int64_t compile_time_us, bool used_persistent_cache), (override)); diff --git a/tensorflow/compiler/jit/device_context_test.cc b/tensorflow/compiler/jit/device_context_test.cc index d02337d36d7a35..34a0c3d5ea067b 100644 --- a/tensorflow/compiler/jit/device_context_test.cc +++ b/tensorflow/compiler/jit/device_context_test.cc @@ -46,7 +46,7 @@ class DeviceContextTest : public ::testing::Test { auto device_factory = DeviceFactory::GetFactory(device_type); SessionOptions options; std::vector> devices; - Status s = device_factory->CreateDevices( + absl::Status s = device_factory->CreateDevices( options, "/job:worker/replica:0/task:0", &devices); device_ = std::move(devices[0]); diff --git a/tensorflow/compiler/jit/device_executable_persistor.h b/tensorflow/compiler/jit/device_executable_persistor.h index a6621f76008015..458441c86b5c43 100644 --- a/tensorflow/compiler/jit/device_executable_persistor.h +++ b/tensorflow/compiler/jit/device_executable_persistor.h @@ -106,7 +106,7 @@ class DeviceExecutablePersistor { // pipeline and persists that to disk. // TODO(b/255826209): Take in Signature instead hash and string once cache // is refactored. - virtual Status TryToPersistExecutable( + virtual absl::Status TryToPersistExecutable( uint64 signature_hash, const std::string& signature_str, const XlaCompiler::Options& options, const XlaCompiler::CompilationResult& compilation_result, @@ -138,7 +138,7 @@ class DeviceExecutablePersistor { // Saves the cache entry in the file directory supplied during the // construction of this class. Overwrites existing entries. - Status SaveSerializedEntry(const XlaSerializedCacheEntry& entry) const; + absl::Status SaveSerializedEntry(const XlaSerializedCacheEntry& entry) const; // Tries to read a cache entry given a `key` by searching the file directory // supplied during the construction of this class. Returns std::nullopt if no @@ -147,9 +147,9 @@ class DeviceExecutablePersistor { TryToReadSerializedEntry(const XlaSerializedCacheKey& key) const; // Checks if the loaded `entry` matches the expected `key` and `hlo_module`. - Status VerifyLoadedCacheEntry(const XlaSerializedCacheKey& key, - const xla::HloModuleProto& hlo_module, - const XlaSerializedCacheEntry& entry) const; + absl::Status VerifyLoadedCacheEntry( + const XlaSerializedCacheKey& key, const xla::HloModuleProto& hlo_module, + const XlaSerializedCacheEntry& entry) const; std::string GetFilePath(const XlaSerializedCacheKey& key) const; @@ -233,7 +233,7 @@ DeviceExecutablePersistor::TryToReadSerializedEntry( } template -Status +absl::Status DeviceExecutablePersistor::VerifyLoadedCacheEntry( const XlaSerializedCacheKey& key, const xla::HloModuleProto& hlo_module, const XlaSerializedCacheEntry& entry) const { @@ -267,7 +267,7 @@ DeviceExecutablePersistor::VerifyLoadedCacheEntry( } template -Status +absl::Status DeviceExecutablePersistor::SaveSerializedEntry( const XlaSerializedCacheEntry& entry) const { Env* env = Env::Default(); @@ -374,7 +374,7 @@ DeviceExecutablePersistor::TryToLoadExecutable( } template -Status +absl::Status DeviceExecutablePersistor::TryToPersistExecutable( uint64 signature_hash, const std::string& signature_str, const XlaCompiler::Options& options, diff --git a/tensorflow/compiler/jit/device_executable_persistor_test.cc b/tensorflow/compiler/jit/device_executable_persistor_test.cc index 3b338192e4284a..75092750ed6948 100644 --- a/tensorflow/compiler/jit/device_executable_persistor_test.cc +++ b/tensorflow/compiler/jit/device_executable_persistor_test.cc @@ -134,7 +134,7 @@ class DeviceExecutionPersistorTest : public ::testing::Test { return options; } - Status CreatePjRtCompilerClient() { + absl::Status CreatePjRtCompilerClient() { // Create PjRtClient manually while GetOrCreatePjRtClient() is WIP. TF_RETURN_IF_ERROR(SetPjRtClientInTFGlobalResourceManager( DEVICE_CPU_XLA_JIT, diff --git a/tensorflow/compiler/jit/device_util.cc b/tensorflow/compiler/jit/device_util.cc index 182d2f1d53ae49..828da0b08c2590 100644 --- a/tensorflow/compiler/jit/device_util.cc +++ b/tensorflow/compiler/jit/device_util.cc @@ -87,7 +87,8 @@ string DeviceInfoCache::DebugString(const DeviceSet& device_set) const { } } // namespace jit -Status DeviceNameToDeviceType(const string& device, DeviceType* device_type) { +absl::Status DeviceNameToDeviceType(const string& device, + DeviceType* device_type) { DeviceNameUtils::ParsedName parsed; if (!DeviceNameUtils::ParseFullName(device, &parsed)) { return errors::Internal("Malformed assigned device '", device, "'"); diff --git a/tensorflow/compiler/jit/device_util.h b/tensorflow/compiler/jit/device_util.h index de06732f39008d..745f87309501d8 100644 --- a/tensorflow/compiler/jit/device_util.h +++ b/tensorflow/compiler/jit/device_util.h @@ -149,7 +149,8 @@ class DeviceInfoCache { } // namespace jit // Returns the DeviceType corresponding to 'device'. -Status DeviceNameToDeviceType(const string& device, DeviceType* device_type); +absl::Status DeviceNameToDeviceType(const string& device, + DeviceType* device_type); // Picks the device for which XLA should compile a cluster that contains // operations placed in devices in `devices`. For instance a cluster that diff --git a/tensorflow/compiler/jit/device_util_test.cc b/tensorflow/compiler/jit/device_util_test.cc index c63aa6683b9544..cef39df6283f2b 100644 --- a/tensorflow/compiler/jit/device_util_test.cc +++ b/tensorflow/compiler/jit/device_util_test.cc @@ -21,9 +21,9 @@ limitations under the License. namespace tensorflow { namespace { -Status PickDeviceHelper(bool allow_mixing_unknown_and_cpu, - absl::Span device_names, - string* result) { +absl::Status PickDeviceHelper(bool allow_mixing_unknown_and_cpu, + absl::Span device_names, + string* result) { jit::DeviceInfoCache cache; jit::DeviceSet device_set; for (absl::string_view name : device_names) { diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 47034a6a791f77..3e8a43ce08ed58 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -133,7 +133,7 @@ class Encapsulator { // Find subgraphs marked with 'group_attribute', and build a new // subgraph, one for each value of 'group_attribute'. - Status SplitIntoSubgraphs(FunctionLibraryDefinition* library); + absl::Status SplitIntoSubgraphs(FunctionLibraryDefinition* library); // Build a FunctionDef for each subgraph, and add it 'library'. The values of // the 'group_attribute' annotations become the function names. @@ -141,13 +141,14 @@ class Encapsulator { // same name, if any. // If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before // function conversion. - Status BuildFunctionDefs(const RewriteSubgraphFn& rewrite_subgraph_fn, - bool reuse_existing_functions, - FunctionLibraryDefinition* library); + absl::Status BuildFunctionDefs(const RewriteSubgraphFn& rewrite_subgraph_fn, + bool reuse_existing_functions, + FunctionLibraryDefinition* library); // Write a copy of the input graph to 'graph_out', where the subgraphs are // replaced with calls to the new functions. - Status BuildOutputGraph(Graph* graph_out, FunctionLibraryDefinition* library); + absl::Status BuildOutputGraph(Graph* graph_out, + FunctionLibraryDefinition* library); private: // A subgraph of the input, all marked with a common 'group_attribute' @@ -181,13 +182,13 @@ class Encapsulator { // 'reuse_existing_functions' is set, use an existing function with the same // name, if any. If 'rewrite_subgraph_fn' is set, it is applied to the // subgraph before function conversion. - Status BuildFunctionDef(const string& name_in, - const RewriteSubgraphFn& rewrite_subgraph_fn, - bool reuse_existing_functions, - FunctionLibraryDefinition* library); + absl::Status BuildFunctionDef(const string& name_in, + const RewriteSubgraphFn& rewrite_subgraph_fn, + bool reuse_existing_functions, + FunctionLibraryDefinition* library); // Adds the function call node to graph_out. - Status AddFunctionCallNode( + absl::Status AddFunctionCallNode( const absl::flat_hash_map& node_images, Graph* graph_out); @@ -205,32 +206,34 @@ class Encapsulator { // args_by_src_, if none exists yet. Also adds its index to args_by_dst_, // and adds the edge within the subgraph from the _Arg node to the image of // the dst node. - Status RecordArg(const Edge* edge, - const absl::flat_hash_map& node_images, - std::vector>* src_arg_pairs); + absl::Status RecordArg( + const Edge* edge, + const absl::flat_hash_map& node_images, + std::vector>* src_arg_pairs); // Records the src of the given edge as a control result of the graph. // Used during graph to function conversion to tie control results to // the function signature. - Status RecordControlResult( + absl::Status RecordControlResult( const Edge* edge, const absl::flat_hash_map& node_images); // Creates a _Retval node for the src node of edge, and add it to results_, // if none exists yet. If a new _Retval node is created, also adds the edge // within the subgraph from the src to the _Retval node. - Status RecordResult( + absl::Status RecordResult( const Edge* edge, const absl::flat_hash_map& node_images); // Creates the sequencer node if it doesn't exist, adding it to graph_out. - Status MakeSequencingNode(const string& subgraph_name, Graph* graph_out); + absl::Status MakeSequencingNode(const string& subgraph_name, + Graph* graph_out); // If there is a sequencer node, adds a control edge from the sequencer to // the call node. void ConnectSequencerToCallNode(Graph* graph_out); - Status ReplaceFunctionDef(FunctionLibraryDefinition* library); + absl::Status ReplaceFunctionDef(FunctionLibraryDefinition* library); private: // The subgraph extracted from the input graph, suitable for being turned @@ -280,31 +283,31 @@ class Encapsulator { // Returns the key attribute associated with a node in attr. Sets either // result to the empty string if the respective attribute is not found. - Status GetFunctionNameAttr(Node const* node, string* attr) const; + absl::Status GetFunctionNameAttr(Node const* node, string* attr) const; // Copies edges local to a subgraph. Adds _Arg and _Retval nodes to // subgraphs for data edges that cross subgraph boundaries. - Status CopySubgraphEdges( + absl::Status CopySubgraphEdges( const absl::flat_hash_map& node_images, std::vector>* src_arg_pairs); // Copies all marked nodes to a subgraph. Does nothing for unmarked nodes. - Status CopySubgraphNodes( + absl::Status CopySubgraphNodes( absl::flat_hash_map* node_images); // Copies all nodes that aren't in a compiled subgraph to the output graph. - Status CopyNodesToOutputGraph( + absl::Status CopyNodesToOutputGraph( Graph* graph_out, absl::flat_hash_map* node_images); // Adds function call nodes for each compiled subgraph. - Status AddFunctionCallNodes( + absl::Status AddFunctionCallNodes( const absl::flat_hash_map& node_images, Graph* graph_out); // Finds the image of an edge source in the output graph. If the edge crosses // a subgraph boundary it is the output of a call node, otherwise it is a node // in the output graph. - Status FindOutputImageOfEdgeSrc( + absl::Status FindOutputImageOfEdgeSrc( const string& src_func_id, const string& dst_func_id, const absl::flat_hash_map& node_images, const Node* original_src_node, Node** src_image); @@ -319,7 +322,7 @@ class Encapsulator { // Finds the image of an edge destination in the output graph. If the edge // crosses a subgraph boundary it is the input of a call node, otherwise it is // a node in the output graph. - Status FindOutputImageOfEdgeDst( + absl::Status FindOutputImageOfEdgeDst( const string& src_func_id, const string& dst_func_id, const absl::flat_hash_map& node_images, const Node* original_dst_node, Node** dst_image); @@ -333,7 +336,7 @@ class Encapsulator { // Copies a single edge to the output graph. The edge is either entirely // within the output graph, or crosses into or out of a compiled subgraph. - Status CopyEdgeToOutputGraph( + absl::Status CopyEdgeToOutputGraph( const Edge* edge, const string& src_func_id, const string& dst_func_id, const absl::flat_hash_map& node_images, Graph* graph_out, @@ -341,7 +344,7 @@ class Encapsulator { OutputInputTensorPairHasher>* edges_added); // Adds all edges to the output graph. - Status AddEdgesToOutputGraph( + absl::Status AddEdgesToOutputGraph( const absl::flat_hash_map& node_images, Graph* graph_out); @@ -349,7 +352,7 @@ class Encapsulator { // one node in send_from_host_nodes and store it in pruned_graph. On exit // nodes_images contains a mapping from nodes in graph to nodes in // pruned_graph. All functions in the copied graph are inlined. - Status MakePrunedGraphCopyAndInline( + absl::Status MakePrunedGraphCopyAndInline( const Graph& graph, const std::vector& sink_nodes, std::unique_ptr* pruned_graph, absl::flat_hash_map* node_images, @@ -448,7 +451,7 @@ Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) { Graph* Encapsulator::Subgraph::GetGraph() const { return graph_.get(); } -Status Encapsulator::Subgraph::RecordArg( +absl::Status Encapsulator::Subgraph::RecordArg( const Edge* edge, const absl::flat_hash_map& node_images, std::vector>* src_arg_pairs) { @@ -467,7 +470,7 @@ Status Encapsulator::Subgraph::RecordArg( DataType dtype = edge->dst()->input_type(edge->dst_input()); builder.Attr("T", dtype); builder.Attr("index", arg_index); - Status s = builder.Finalize(&arg_def); + absl::Status s = builder.Finalize(&arg_def); if (!s.ok()) return s; TF_ASSIGN_OR_RETURN(Node * arg, graph_->AddNode(arg_def)); @@ -482,7 +485,7 @@ Status Encapsulator::Subgraph::RecordArg( return absl::OkStatus(); } -Status Encapsulator::Subgraph::RecordControlResult( +absl::Status Encapsulator::Subgraph::RecordControlResult( const Edge* edge, const absl::flat_hash_map& node_images) { Node* src_node = edge->src(); @@ -491,7 +494,7 @@ Status Encapsulator::Subgraph::RecordControlResult( return absl::OkStatus(); } -Status Encapsulator::Subgraph::RecordResult( +absl::Status Encapsulator::Subgraph::RecordResult( const Edge* edge, const absl::flat_hash_map& node_images) { Node* src_node = edge->src(); @@ -511,7 +514,7 @@ Status Encapsulator::Subgraph::RecordResult( builder.Attr("T", dtype); builder.Attr("index", ret_index); builder.Input(src_image->name(), src_slot, dtype); - Status s = builder.Finalize(&ret_def); + absl::Status s = builder.Finalize(&ret_def); if (!s.ok()) return s; TF_ASSIGN_OR_RETURN(Node * ret, graph_->AddNode(ret_def)); graph_->AddEdge(src_image, src_slot, ret, 0); @@ -519,15 +522,15 @@ Status Encapsulator::Subgraph::RecordResult( return absl::OkStatus(); } -Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, - Graph* graph_out) { +absl::Status Encapsulator::Subgraph::MakeSequencingNode( + const string& subgraph_name, Graph* graph_out) { if (sequencer_ == nullptr) { NodeDef seq_def; // TODO(shikharagarwal): What source node should we use for errors? NodeDefBuilder builder(absl::StrCat(subgraph_name, "_sequencer"), "NoOp"); builder.Attr(kXlaHostTransferSequencerAttr, subgraph_name); builder.Device(device_); - Status s = builder.Finalize(&seq_def); + absl::Status s = builder.Finalize(&seq_def); if (!s.ok()) return s; TF_ASSIGN_OR_RETURN(sequencer_, graph_out->AddNode(seq_def)); @@ -543,7 +546,7 @@ void Encapsulator::Subgraph::ConnectSequencerToCallNode(Graph* graph_out) { } } -Status Encapsulator::Subgraph::BuildFunctionDef( +absl::Status Encapsulator::Subgraph::BuildFunctionDef( const string& name_in, const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, FunctionLibraryDefinition* library) { // name_in is copied here because name may be modified below if @@ -620,7 +623,7 @@ Status Encapsulator::Subgraph::BuildFunctionDef( return absl::OkStatus(); } -Status Encapsulator::Subgraph::ReplaceFunctionDef( +absl::Status Encapsulator::Subgraph::ReplaceFunctionDef( FunctionLibraryDefinition* library) { const string& name = function_def_name_; @@ -639,7 +642,7 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef( return absl::OkStatus(); } -Status Encapsulator::Subgraph::AddFunctionCallNode( +absl::Status Encapsulator::Subgraph::AddFunctionCallNode( const absl::flat_hash_map& node_images, Graph* graph_out) { TF_ASSIGN_OR_RETURN(call_node_, graph_out->AddNode(call_node_def_)); @@ -650,7 +653,8 @@ Status Encapsulator::Subgraph::AddFunctionCallNode( return absl::OkStatus(); } -Status Encapsulator::GetFunctionNameAttr(Node const* node, string* attr) const { +absl::Status Encapsulator::GetFunctionNameAttr(Node const* node, + string* attr) const { AttrSlice attrs = node->attrs(); attr->clear(); for (const auto& node_attr : attrs) { @@ -665,7 +669,7 @@ Status Encapsulator::GetFunctionNameAttr(Node const* node, string* attr) const { bool IsInSubgraph(const string& func_id) { return !func_id.empty(); } -Status Encapsulator::CopySubgraphNodes( +absl::Status Encapsulator::CopySubgraphNodes( absl::flat_hash_map* node_images) { for (Node* node : graph_in_->op_nodes()) { string func_id; @@ -680,7 +684,7 @@ Status Encapsulator::CopySubgraphNodes( return absl::OkStatus(); } -Status Encapsulator::CopySubgraphEdges( +absl::Status Encapsulator::CopySubgraphEdges( const absl::flat_hash_map& node_images, std::vector>* src_arg_pairs) { for (const Edge* edge : graph_in_->edges()) { @@ -751,8 +755,9 @@ Status Encapsulator::CopySubgraphEdges( return absl::OkStatus(); } -Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) { - Status s; +absl::Status Encapsulator::SplitIntoSubgraphs( + FunctionLibraryDefinition* library) { + absl::Status s; // Map from input graph nodes to subgraph nodes. absl::flat_hash_map node_images; @@ -784,7 +789,7 @@ Status Encapsulator::SplitIntoSubgraphs(FunctionLibraryDefinition* library) { return s; } -Status Encapsulator::BuildFunctionDefs( +absl::Status Encapsulator::BuildFunctionDefs( const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, FunctionLibraryDefinition* library) { for (auto& subgraph_entry : subgraphs_) { @@ -796,7 +801,7 @@ Status Encapsulator::BuildFunctionDefs( return absl::OkStatus(); } -Status Encapsulator::CopyNodesToOutputGraph( +absl::Status Encapsulator::CopyNodesToOutputGraph( Graph* graph_out, absl::flat_hash_map* node_images) { for (Node* node : graph_in_->op_nodes()) { string func_id; @@ -813,7 +818,7 @@ Status Encapsulator::CopyNodesToOutputGraph( return absl::OkStatus(); } -Status Encapsulator::AddFunctionCallNodes( +absl::Status Encapsulator::AddFunctionCallNodes( const absl::flat_hash_map& node_images, Graph* graph_out) { for (auto& subgraph_entry : subgraphs_) { @@ -823,7 +828,7 @@ Status Encapsulator::AddFunctionCallNodes( return absl::OkStatus(); } -Status Encapsulator::FindOutputImageOfEdgeSrc( +absl::Status Encapsulator::FindOutputImageOfEdgeSrc( const string& src_func_id, const string& dst_func_id, const absl::flat_hash_map& node_images, const Node* original_src_node, Node** src_image) { @@ -854,7 +859,7 @@ int Encapsulator::FindOutputSlotOfEdgeSrc(const string& src_func_id, } } -Status Encapsulator::FindOutputImageOfEdgeDst( +absl::Status Encapsulator::FindOutputImageOfEdgeDst( const string& src_func_id, const string& dst_func_id, const absl::flat_hash_map& node_images, const Node* original_dst_node, Node** dst_image) { @@ -885,7 +890,7 @@ int Encapsulator::FindOutputSlotOfEdgeDst(const string& src_func_id, } } -Status Encapsulator::CopyEdgeToOutputGraph( +absl::Status Encapsulator::CopyEdgeToOutputGraph( const Edge* edge, const string& src_func_id, const string& dst_func_id, const absl::flat_hash_map& node_images, Graph* graph_out, @@ -927,7 +932,7 @@ Status Encapsulator::CopyEdgeToOutputGraph( return absl::OkStatus(); } -Status Encapsulator::AddEdgesToOutputGraph( +absl::Status Encapsulator::AddEdgesToOutputGraph( const absl::flat_hash_map& node_images, Graph* graph_out) { // Set of edges already added to the output graph, represented as (src, dst) @@ -1011,7 +1016,7 @@ Node* AddDummyShapedNode(const Node* src_node, int src_port, } // namespace -Status Encapsulator::MakePrunedGraphCopyAndInline( +absl::Status Encapsulator::MakePrunedGraphCopyAndInline( const Graph& graph, const std::vector& sink_nodes, std::unique_ptr* pruned_graph, absl::flat_hash_map* node_images, @@ -1071,8 +1076,8 @@ Status Encapsulator::MakePrunedGraphCopyAndInline( return absl::OkStatus(); } -Status Encapsulator::BuildOutputGraph(Graph* graph_out, - FunctionLibraryDefinition* library) { +absl::Status Encapsulator::BuildOutputGraph( + Graph* graph_out, FunctionLibraryDefinition* library) { // Map from nodes in the input graph to nodes in the output graph. absl::flat_hash_map node_images; @@ -1085,7 +1090,7 @@ Status Encapsulator::BuildOutputGraph(Graph* graph_out, } // anonymous namespace -Status EncapsulateSubgraphsInFunctions( +absl::Status EncapsulateSubgraphsInFunctions( string group_attribute, const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, std::unique_ptr* graph_out, FunctionLibraryDefinition* library) { @@ -1105,7 +1110,7 @@ Status EncapsulateSubgraphsInFunctions( } // Finds the types of the _Arg nodes, indexed by position. -static Status GetArgTypes(const Graph& graph, DataTypeVector* types) { +static absl::Status GetArgTypes(const Graph& graph, DataTypeVector* types) { for (Node* n : graph.op_nodes()) { if (n->type_string() == kArgOp) { int index; @@ -1122,8 +1127,8 @@ static Status GetArgTypes(const Graph& graph, DataTypeVector* types) { // Renumber the indices of _Arg nodes in a graph, according to // 'permutation' that maps old indices to new indices. -static Status RenumberArguments(Graph* graph, - const std::vector& permutation) { +static absl::Status RenumberArguments(Graph* graph, + const std::vector& permutation) { for (Node* n : graph->op_nodes()) { if (n->type_string() == kArgOp) { int index; @@ -1138,7 +1143,7 @@ static Status RenumberArguments(Graph* graph, return absl::OkStatus(); } -Status EncapsulateSubgraphsPass::Run( +absl::Status EncapsulateSubgraphsPass::Run( const GraphOptimizationPassOptions& options) { VLOG(1) << "EncapsulateSubgraphsPass::Run"; if (VLOG_IS_ON(1)) { diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index 6ff0180a3a1c2c..0c7729f67349b5 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -33,7 +33,7 @@ namespace tensorflow { // that TF function annotated with kXlaCompiledKernelAttr (_XlaCompiledKernel). class EncapsulateSubgraphsPass : public GraphOptimizationPass { public: - Status Run(const GraphOptimizationPassOptions& options) override; + absl::Status Run(const GraphOptimizationPassOptions& options) override; }; // A rewriting function to apply to each subgraph during encapsulation. @@ -48,7 +48,7 @@ class EncapsulateSubgraphsPass : public GraphOptimizationPass { // construction, provided to allow additional attributes to be set. // The rewrite may also change the NodeDef's operator name, and that // name will be used as the name of the generated function. -typedef std::function& arg_source_tensors, std::unique_ptr* graph, std::vector* input_permutation, std::vector* output_permutation, NodeDef* node_def)> @@ -72,7 +72,7 @@ typedef std::function* graph_out, FunctionLibraryDefinition* library); diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index e634982c85db8a..1e05ad067def7f 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -45,9 +45,9 @@ namespace { const char* const kXlaHostTransferSequencerAttr = "_xla_host_transfer_sequencer"; -Status AddGraphDefToFunctionLibrary(const GraphDefBuilder& graphdef_builder, - const string& name_suffix, - FunctionDefLibrary* library) { +absl::Status AddGraphDefToFunctionLibrary( + const GraphDefBuilder& graphdef_builder, const string& name_suffix, + FunctionDefLibrary* library) { GraphDef graphdef; TF_RETURN_IF_ERROR(graphdef_builder.ToGraphDef(&graphdef)); std::unique_ptr graph = @@ -477,9 +477,9 @@ Node* RetOp(int index, ops::NodeOut a, const GraphDefBuilder::Options& opts) { return opts.FinalizeBuilder(&node_builder); } -Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, - const std::vector& encapsulated_functions) { - Status s; +absl::Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, + const std::vector& encapsulated_functions) { + absl::Status s; // Convert the GraphDef to a Graph std::unique_ptr lib_def( new FunctionLibraryDefinition(OpRegistry::Global(), *library)); @@ -550,7 +550,7 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library, return s; } -Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) { +absl::Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) { std::vector encapsulated_functions; return Encapsulate(graphdef, library, encapsulated_functions); } diff --git a/tensorflow/compiler/jit/encapsulate_util.cc b/tensorflow/compiler/jit/encapsulate_util.cc index 71b1f35d539ddd..fa94a341bbabc6 100644 --- a/tensorflow/compiler/jit/encapsulate_util.cc +++ b/tensorflow/compiler/jit/encapsulate_util.cc @@ -47,9 +47,10 @@ std::optional GetStringAttr(const Node& n, const string& attr_name) { // Adds a value to the node's list attribute. template -Status AppendToListAttr(Node* n, const string& attr_name, const string& value) { +absl::Status AppendToListAttr(Node* n, const string& attr_name, + const string& value) { std::vector attr_value; - Status s = GetNodeAttr(n->attrs(), attr_name, &attr_value); + absl::Status s = GetNodeAttr(n->attrs(), attr_name, &attr_value); if (!s.ok() && s.code() != error::NOT_FOUND) { return s; } @@ -69,7 +70,7 @@ void ReplaceAttr(Node* n, const string& attr_name, const T& value) { // Step 1 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of // `PreprocessEdgesBetweenOutsideCompilations` for details. -Status PreprocessControlEdgesBetweenOutsideCompilations( +absl::Status PreprocessControlEdgesBetweenOutsideCompilations( Graph* g, const string& outside_compilation_attr_name) { // Gather edges to remove. We should not remove the edge while iterating. std::vector edges_to_remove; @@ -109,7 +110,7 @@ Status PreprocessControlEdgesBetweenOutsideCompilations( // Step 2 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of // `PreprocessEdgesBetweenOutsideCompilations` for details. -Status PreprocessDataEdgesBetweenOutsideCompilations( +absl::Status PreprocessDataEdgesBetweenOutsideCompilations( Graph* g, const string& outside_compilation_attr_name) { // Gather edges between outside compilation and host computation. Notice that // we do not store `Edge*` directly because we remove some nodes while adding @@ -193,7 +194,7 @@ Status PreprocessDataEdgesBetweenOutsideCompilations( // Step 1 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of // `PostprocessEdgesBetweenOutsideCompilations` for details. -Status PostprocessDataEdgesBetweenOutsideCompilations( +absl::Status PostprocessDataEdgesBetweenOutsideCompilations( Graph* g, const string& outside_compilation_attr_name) { // Gather all outside compilation to outside compilation nodes. std::vector placeholder_nodes; @@ -269,14 +270,14 @@ Status PostprocessDataEdgesBetweenOutsideCompilations( // Step 2 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of // `PostprocessEdgesBetweenOutsideCompilations` for details. -Status PostprocessControlEdgesBetweenOutsideCompilations( +absl::Status PostprocessControlEdgesBetweenOutsideCompilations( Graph* g, const string& outside_compilation_attr_name) { auto node_name_index = g->BuildNodeNameIndex(); // Reconnect outside compilation to outside compilation control edge. for (Node* n : g->nodes()) { std::vector control_deps; - Status s = + absl::Status s = GetNodeAttr(n->attrs(), kXlaControlDependenciesWithinXlaClusterAttrName, &control_deps); if (!s.ok()) { @@ -317,7 +318,7 @@ const char kXlaLiftedArgOutsideCompilationAttrName[] = "_xla_lifted_arg_oc"; const char kXlaOutsideCompilationInputsAttrName[] = "_xla_oc_inputs"; const char kXlaIsPlaceholderForArg[] = "_xla_is_placeholder_for_arg"; -Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g) { +absl::Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g) { // Perform shape inference. std::map arg_shapes; GraphShapeInfo shape_info; @@ -378,7 +379,7 @@ OutsideCompilationClusterDependencies( return std::move(cluster_deps_ordered); } -Status PreprocessEdgesBetweenOutsideCompilations( +absl::Status PreprocessEdgesBetweenOutsideCompilations( Graph* g, const string& outside_compilation_attr_name) { // Remove edges from source node to outside compilation nodes, and edges // from outside compilation nodes to sink node. @@ -404,7 +405,7 @@ Status PreprocessEdgesBetweenOutsideCompilations( return absl::OkStatus(); } -Status PostprocessEdgesBetweenOutsideCompilations( +absl::Status PostprocessEdgesBetweenOutsideCompilations( Graph* g, const string& outside_compilation_attr_name) { TF_RETURN_IF_ERROR(PostprocessDataEdgesBetweenOutsideCompilations( g, outside_compilation_attr_name)); diff --git a/tensorflow/compiler/jit/encapsulate_util.h b/tensorflow/compiler/jit/encapsulate_util.h index ee31751a45cafd..7c99763c770728 100644 --- a/tensorflow/compiler/jit/encapsulate_util.h +++ b/tensorflow/compiler/jit/encapsulate_util.h @@ -34,7 +34,7 @@ extern const char kXlaInferredShapesAttrName[]; // We have to perform shape inference before encapsulation because after // encapsulation, some nodes will be encapsulated into function call, and shape // inference does not handle function call at the moment. -Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g); +absl::Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g); // Attribute indicating that some ops in this node's XLA computation has control // dependency on this node. Attribute value will always be "true". @@ -134,7 +134,7 @@ OutsideCompilationClusterDependencies( // outside compilation node. // 2. For data edges between different outside compilations, remove the edge // and create a Placeholder node as dst node's input. -Status PreprocessEdgesBetweenOutsideCompilations( +absl::Status PreprocessEdgesBetweenOutsideCompilations( Graph* g, const string& outside_compilation_attr_name); // Postprocesses edges within the same XLA cluster. This function reverts what @@ -148,7 +148,7 @@ Status PreprocessEdgesBetweenOutsideCompilations( // Notice that control edges marked by // `PreprocessEdgesBetweenOutsideCompilations` step 1b are not handled here. // They are handled in `RewriteOutsideCompilationSubgraphFn`. -Status PostprocessEdgesBetweenOutsideCompilations( +absl::Status PostprocessEdgesBetweenOutsideCompilations( Graph* g, const string& outside_compilation_attr_name); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc index c6ae12576cbd0d..1adbac0e5e187a 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc @@ -71,7 +71,7 @@ bool is_guaranteed_constant(const Node& n) { } // Finds the `index` of an _Arg or _Retval node. -Status GetIndexAttr(const Node& n, int num_args, int* index) { +absl::Status GetIndexAttr(const Node& n, int num_args, int* index) { TF_RETURN_IF_ERROR(GetNodeAttr(n.attrs(), "index", index)); if (*index < 0 || *index >= num_args) { return errors::InvalidArgument("Invalid ", n.type_string(), " number ", @@ -111,11 +111,10 @@ void AddControlOutputs(const Node& node, absl::flat_hash_set* deps) { // of the arguments. // // TODO(b/113166435): Ordering constraints on XlaLaunch op can be relaxed. -Status RewriteSubgraph(const std::vector& arg_source_tensors, - std::unique_ptr* graph_ptr, - std::vector* input_permutation, - std::vector* output_permutation, - NodeDef* call_def) { +absl::Status RewriteSubgraph( + const std::vector& arg_source_tensors, + std::unique_ptr* graph_ptr, std::vector* input_permutation, + std::vector* output_permutation, NodeDef* call_def) { Graph* graph = graph_ptr->get(); const int num_args = input_permutation->size(); const int num_retvals = output_permutation->size(); @@ -194,7 +193,7 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, } // namespace -/*static*/ Status EncapsulateXlaComputationsPass::Encapsulate( +/*static*/ absl::Status EncapsulateXlaComputationsPass::Encapsulate( std::unique_ptr* graph, FunctionLibraryDefinition* flib_def) { // Check for undeclared outputs before Encapsulation, so we can give a better // error message. @@ -226,7 +225,7 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, return absl::OkStatus(); } -/*static*/ Status EncapsulateXlaComputationsPass::BuildXlaLaunchOps( +/*static*/ absl::Status EncapsulateXlaComputationsPass::BuildXlaLaunchOps( Graph* graph, const std::function(const Node&)>& is_xla_launch_node, const std::function(const Node&)>& @@ -358,7 +357,7 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, return absl::OkStatus(); } -/*static*/ Status EncapsulateXlaComputationsPass::BuildXlaLaunchOps( +/*static*/ absl::Status EncapsulateXlaComputationsPass::BuildXlaLaunchOps( Graph* graph) { const auto is_xla_launch_node = [](const Node& node) -> absl::StatusOr { const string& name = GetNodeAttrString(node.attrs(), kXlaClusterIdAttr); @@ -377,7 +376,7 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, /*add_edges_to_output_of_downstream_nodes=*/true); } -Status EncapsulateXlaComputationsPass::Run( +absl::Status EncapsulateXlaComputationsPass::Run( const GraphOptimizationPassOptions& options) { VLOG(1) << "EncapsulateXlaComputations(): " << DumpGraphToFile("encapsulate_xla_computations_before", diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h index b6af1277976f44..6301e963763756 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.h @@ -38,7 +38,7 @@ namespace tensorflow { // XlaLaunch operators. class EncapsulateXlaComputationsPass : public GraphOptimizationPass { public: - Status Run(const GraphOptimizationPassOptions& options) override; + absl::Status Run(const GraphOptimizationPassOptions& options) override; // The following methods are public only for unit tests. @@ -48,13 +48,13 @@ class EncapsulateXlaComputationsPass : public GraphOptimizationPass { // functions contain the computations to be passed to XlaLaunch. During // encapsulation, we sort the arguments into the order expected by // XlaLaunch. - static Status Encapsulate(std::unique_ptr* graph, - FunctionLibraryDefinition* flib_def); + static absl::Status Encapsulate(std::unique_ptr* graph, + FunctionLibraryDefinition* flib_def); // b) we rewrite the function calls generated in phase (a) into XlaLaunch // operators. We also convert the XlaClusterOutput output nodes of the // function call into the outputs of the XlaLaunch operator. - static Status BuildXlaLaunchOps(Graph* graph); + static absl::Status BuildXlaLaunchOps(Graph* graph); struct XlaFunctionInfo { int variable_start_index = -1; @@ -71,7 +71,7 @@ class EncapsulateXlaComputationsPass : public GraphOptimizationPass { // The output graph of this function would look like the following when // add_edges_to_output_of_downstream_nodes is true: // XlaLaunch -> NodeA - static Status BuildXlaLaunchOps( + static absl::Status BuildXlaLaunchOps( Graph* graph, const std::function(const Node&)>& is_xla_launch_node, diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc index 8d06ed38b1bae7..16a17c3c2a03a6 100644 --- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc @@ -60,7 +60,7 @@ static std::unique_ptr MakeOuterGraph( .Attr("_variable_start_index", 4) .Finalize(&def)); - Status status; + absl::Status status; Node* launch = scope.graph()->AddNode(def, &status); TF_CHECK_OK(status); TF_CHECK_OK(scope.DoShapeInference(launch)); diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc index aebd56d8fb9011..140c47dbcac804 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.cc @@ -59,7 +59,7 @@ absl::StatusOr AddHostComputeKeyPlaceholder( builder.Attr("dtype", DT_STRING); builder.Attr("shape", PartialTensorShape({2})); builder.Attr("_host_compute_call_node", xla_cluster_name); - Status s = builder.Finalize(&key_def); + absl::Status s = builder.Finalize(&key_def); if (!s.ok()) return s; Node* n = g->AddNode(key_def, &s); @@ -85,8 +85,8 @@ std::vector GatherNodesWithType(const Graph& g, const string& type) { } // Gets data types from `arg_nodes` and fills them into `recv_at_host_dtypes`. -Status GetArgDataTypes(const std::vector& arg_nodes, - std::vector* recv_at_host_dtypes) { +absl::Status GetArgDataTypes(const std::vector& arg_nodes, + std::vector* recv_at_host_dtypes) { recv_at_host_dtypes->resize(arg_nodes.size(), DT_INVALID); for (auto* n : arg_nodes) { int index; @@ -185,8 +185,8 @@ absl::StatusOr ReplaceArgNodesWithRecvAtHostNode( } // Gets data types from `ret_nodes` and fills them into `send_from_host_dtypes`. -Status GetRetDataTypes(const std::vector& ret_nodes, - std::vector* send_from_host_dtypes) { +absl::Status GetRetDataTypes(const std::vector& ret_nodes, + std::vector* send_from_host_dtypes) { send_from_host_dtypes->resize(ret_nodes.size(), DT_INVALID); for (auto* n : ret_nodes) { int index; @@ -392,7 +392,7 @@ TF_ATTRIBUTE_NOINLINE absl::StatusOr ReplaceOutsideCompilationCallNode( // Resets "_device_ordinal" attr to placeholder value for related nodes // (XlaRecvAtHost nodes; XlaSendFromHost nodes; If/While/FuncCall nodes // containing XlaRecvAtHost/XlaSendFromHost). -Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) { +absl::Status ResetDeviceOrdinalToPlaceholderValue(Graph* g) { AttrValue device_ordinal_value; device_ordinal_value.set_placeholder("_device_ordinal"); for (Node* n : g->nodes()) { @@ -524,9 +524,9 @@ absl::StatusOr AddOutsideCompilationInputArgToFunctionBody( // Add _Retval node that matches newly added `arg_node` and connect `arg_node` // to it. -Status AddMatchingRetvalNode(const FunctionBody& function_body, - const int arg_idx, const DataType& data_type, - Node* arg_node) { +absl::Status AddMatchingRetvalNode(const FunctionBody& function_body, + const int arg_idx, const DataType& data_type, + Node* arg_node) { NodeDefBuilder ret_builder(absl::StrCat("ret_", arg_idx), "_Retval"); ret_builder.Attr("T", data_type); ret_builder.Attr("index", arg_idx); @@ -562,11 +562,12 @@ void ReplaceLiftedArgNodePlaceholderWithArg( // Adds function def to function definition library and update the function // callsite operation `callsite_node` to invoke new function instead. -Status AddFunctionWithNewName(const std::string& new_name, - const std::string& func_attr_name, - const FunctionDef& function_def, - NameAttrList* func_attr, Node* callsite_node, - FunctionLibraryDefinition* fld) { +absl::Status AddFunctionWithNewName(const std::string& new_name, + const std::string& func_attr_name, + const FunctionDef& function_def, + NameAttrList* func_attr, + Node* callsite_node, + FunctionLibraryDefinition* fld) { TF_RETURN_IF_ERROR(fld->AddFunctionDef(function_def)); func_attr->set_name(new_name); callsite_node->ClearAttr(func_attr_name); @@ -576,7 +577,7 @@ Status AddFunctionWithNewName(const std::string& new_name, // Reconnect outside compilation lifted arguments in a functional While node to // its outside compilation tensor sources. -Status PostprocessLiftedArgsForWhile( +absl::Status PostprocessLiftedArgsForWhile( const std::unordered_map& outside_compilation_attr_to_node, Graph* g, Node* n, FunctionLibraryDefinition* fld) { TF_RET_CHECK(n->IsWhileNode()); @@ -685,7 +686,7 @@ Status PostprocessLiftedArgsForWhile( return absl::OkStatus(); } -Status PostprocessLiftedArgsForIf( +absl::Status PostprocessLiftedArgsForIf( const std::unordered_map& outside_compilation_attr_to_node, Graph* g, Node* n, FunctionLibraryDefinition* fld) { TF_RET_CHECK(n->IsIfNode()); @@ -824,7 +825,7 @@ Status PostprocessLiftedArgsForIf( return absl::OkStatus(); } -Status PostprocessLiftedArgsForCall( +absl::Status PostprocessLiftedArgsForCall( const std::unordered_map& outside_compilation_attr_to_node, Graph* g, Node* n, FunctionLibraryDefinition* fld) { const FunctionDef* fdef = fld->Find(n->type_string()); @@ -941,7 +942,7 @@ absl::StatusOr> OutsideCompilationAttrToNode( return outside_compilation_attr_to_node; } -Status PostprocessLiftedArgs(Graph* g, FunctionLibraryDefinition* fld) { +absl::Status PostprocessLiftedArgs(Graph* g, FunctionLibraryDefinition* fld) { TF_ASSIGN_OR_RETURN(auto outside_compilation_attr_to_node, OutsideCompilationAttrToNode(*g)); @@ -986,7 +987,7 @@ Status PostprocessLiftedArgs(Graph* g, FunctionLibraryDefinition* fld) { // 2) a "key placeholder" node. Later in ExpandHostGraphIntoMainGraph(), we will // replace this node with compilation result node. // 3) all outside compilation graphs. -Status ConstructHostGraph( +absl::Status ConstructHostGraph( const string& xla_cluster_name, const string& outside_compilation_attr_name, const std::vector& outside_compilation_host_graphs, FunctionLibraryDefinition* fld, std::unique_ptr* host_graph) { @@ -1036,7 +1037,7 @@ Status ConstructHostGraph( std::map node_map; node_map[host_fbody->graph->source_node()] = (*host_graph)->source_node(); node_map[host_fbody->graph->sink_node()] = (*host_graph)->sink_node(); - Status s; + absl::Status s; ReverseDFS( *host_fbody->graph, /*enter=*/nullptr, [&](const Node* n) { @@ -1122,11 +1123,11 @@ Status ConstructHostGraph( // Expand XLA computation's outside compilation host side graph into main graph. // Add a control edge between sequencer node and the XLA computation node. -Status ExpandHostGraphIntoMainGraph(Graph* main_graph, - FunctionLibraryDefinition* fld, - const string& host_graph_func_name, - Node* xla_computation_node, - Node* pivot_node) { +absl::Status ExpandHostGraphIntoMainGraph(Graph* main_graph, + FunctionLibraryDefinition* fld, + const string& host_graph_func_name, + Node* xla_computation_node, + Node* pivot_node) { // Temporarily use "0" as "_device_ordinal". It will be rewritten with the // correct value in a later pass. We cannot just use placeholder value here // because FunctionDef instantiation does not allow placeholder value for @@ -1155,7 +1156,7 @@ Status ExpandHostGraphIntoMainGraph(Graph* main_graph, node_map[host_graph->source_node()] = main_graph->source_node(); } node_map[host_graph->sink_node()] = main_graph->sink_node(); - Status s = absl::OkStatus(); + absl::Status s = absl::OkStatus(); auto copy_node_fn = [&](const Node* n) { if (!s.ok()) { return; @@ -1205,9 +1206,9 @@ Status ExpandHostGraphIntoMainGraph(Graph* main_graph, // removed those placeholder nodes. // 2) Remove control edges. // 3) Prune nodes that are not useful for shape inference. -Status RewriteShapeInferenceGraph(const string& shape_inference_graph_name, - Graph* host_graph, Node* pivot_node, - FunctionLibraryDefinition* fld) { +absl::Status RewriteShapeInferenceGraph( + const string& shape_inference_graph_name, Graph* host_graph, + Node* pivot_node, FunctionLibraryDefinition* fld) { // Use "0" as "_device_ordinal". It does not matter for shape inference. AttrValue device_ordinal_attr; device_ordinal_attr.set_i(0); @@ -1355,9 +1356,9 @@ TF_ATTRIBUTE_NOINLINE absl::StatusOr BuildSendIfPredNode( } // Replaces key placeholder node with an _Arg node. -Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name, - const string& func_name, - FunctionLibraryDefinition* fld) { +absl::Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name, + const string& func_name, + FunctionLibraryDefinition* fld) { // Temporarily use "0" as "_device_ordinal". It will be reset to placeholder // value after rewriting. AttrValue device_ordinal_attr; @@ -1402,7 +1403,7 @@ Status ReplaceKeyPlaceholderWithArgNode(const string& xla_cluster_name, } // Builds host side graph for If node. -TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForIfNode( +TF_ATTRIBUTE_NOINLINE absl::Status BuildHostGraphForIfNode( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const string& xla_cluster_name, const string& if_node_name, const string& host_transfer_key, @@ -1482,7 +1483,7 @@ TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForIfNode( } // Rewrites loop cond to add a node which sends loop cond to host. -TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond( +TF_ATTRIBUTE_NOINLINE absl::Status AddSendLoopPredToLoopCond( const string& cond_xla_func_name, const string& host_transfer_key, NameAttrList* loop_cond_func, FunctionLibraryDefinition* fld, Node* while_node) { @@ -1558,7 +1559,7 @@ TF_ATTRIBUTE_NOINLINE Status AddSendLoopPredToLoopCond( } // Rewrites while loop cond function for host. -Status RewriteHostWhileLoopCond( +absl::Status RewriteHostWhileLoopCond( const string& cond_host_func_name, const string& while_node_name, const string& host_transfer_key, const string& xla_cluster_attr_name, const string& xla_cluster_name, const string& outside_compilation_attr_name, @@ -1632,7 +1633,7 @@ Status RewriteHostWhileLoopCond( } // Rewrites while loop body function for host. -Status RewriteHostWhileLoopBody( +absl::Status RewriteHostWhileLoopBody( const string& body_host_func_name, const string& while_node_name, const string& host_transfer_key, const string& xla_cluster_attr_name, const string& xla_cluster_name, const string& outside_compilation_attr_name, @@ -1690,7 +1691,7 @@ Status RewriteHostWhileLoopBody( } // Builds host side graph for while node. -TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForWhileNode( +TF_ATTRIBUTE_NOINLINE absl::Status BuildHostGraphForWhileNode( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const string& xla_cluster_name, const string& while_node_name, const string& host_transfer_key, @@ -1757,7 +1758,7 @@ TF_ATTRIBUTE_NOINLINE Status BuildHostGraphForWhileNode( } // Builds host graph for func call nodes. -Status BuildHostGraphForFuncCallNode( +absl::Status BuildHostGraphForFuncCallNode( const string& xla_cluster_attr_name, const string& xla_cluster_name, const string& outside_compilation_attr_name, const string& func_call_node_name, const string& func_call_host_func_name, @@ -1805,7 +1806,7 @@ Status BuildHostGraphForFuncCallNode( return absl::OkStatus(); } -TF_ATTRIBUTE_NOINLINE Status ExtractOutsideCompilationForFuncCallNode( +TF_ATTRIBUTE_NOINLINE absl::Status ExtractOutsideCompilationForFuncCallNode( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const string& xla_cluster_name, const std::map& host_compute_core, Graph* g, Node* n, @@ -1891,7 +1892,7 @@ TF_ATTRIBUTE_NOINLINE Status ExtractOutsideCompilationForFuncCallNode( return absl::OkStatus(); } -Status ExtractOutsideCompilationForIfNode( +absl::Status ExtractOutsideCompilationForIfNode( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const string& xla_cluster_name, const std::map& host_compute_core, Graph* g, Node* n, @@ -2010,7 +2011,7 @@ Status ExtractOutsideCompilationForIfNode( return absl::OkStatus(); } -Status ExtractOutsideCompilationForWhileNode( +absl::Status ExtractOutsideCompilationForWhileNode( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const string& xla_cluster_name, const std::map& host_compute_core, Graph* g, Node* n, @@ -2111,7 +2112,7 @@ Status ExtractOutsideCompilationForWhileNode( return absl::OkStatus(); } -Status ExtractOutsideCompilationForNodesWithAssociatedFunctions( +absl::Status ExtractOutsideCompilationForNodesWithAssociatedFunctions( Graph* g, const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const string& xla_cluster_name, const std::map& host_compute_core, FunctionLibraryRuntime* flr, @@ -2153,7 +2154,7 @@ Status ExtractOutsideCompilationForNodesWithAssociatedFunctions( return absl::OkStatus(); } -Status CopyOutsideCompilationConstNodes( +absl::Status CopyOutsideCompilationConstNodes( Graph* g, const string& outside_compilation_attr_name) { for (Node* n : g->op_nodes()) { if (!n->IsConstant() || @@ -2200,7 +2201,7 @@ Status CopyOutsideCompilationConstNodes( } // namespace -Status RewriteOutsideCompilationSubgraphFn::operator()( +absl::Status RewriteOutsideCompilationSubgraphFn::operator()( const std::vector& arg_source_tensors, std::unique_ptr* graph, std::vector* input_permutation, std::vector* output_permutation, NodeDef* node_def) { @@ -2304,7 +2305,7 @@ Status RewriteOutsideCompilationSubgraphFn::operator()( return absl::OkStatus(); } -Status ExtractOutsideCompilationForFunction( +absl::Status ExtractOutsideCompilationForFunction( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const string& xla_cluster_name, const NameAttrList& func_name_attrs, const string& new_func_name, @@ -2317,7 +2318,7 @@ Status ExtractOutsideCompilationForFunction( FunctionLibraryRuntime::Handle handle; TF_RETURN_IF_ERROR( flr->Instantiate(func_name, AttrSlice(&func_name_attrs.attr()), &handle)); - Status ret_status = absl::OkStatus(); + absl::Status ret_status = absl::OkStatus(); auto cleanup_handle = gtl::MakeCleanup([&]() { auto s = flr->ReleaseHandle(handle); if (!s.ok()) { @@ -2497,7 +2498,7 @@ Status ExtractOutsideCompilationForFunction( return ret_status; } -Status ExtractOutsideCompilation( +absl::Status ExtractOutsideCompilation( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const std::unordered_map& clusters, Graph* g, diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass.h b/tensorflow/compiler/jit/extract_outside_compilation_pass.h index cdf2acca9f386a..7631ccd0bc6ab0 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass.h +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass.h @@ -52,10 +52,11 @@ class RewriteOutsideCompilationSubgraphFn { xla_cluster_name_(xla_cluster_name), new_function_name_(new_function_name) {} - Status operator()(const std::vector&, - std::unique_ptr* graph, - std::vector* input_permutation, - std::vector* output_permutation, NodeDef* node_def); + absl::Status operator()(const std::vector&, + std::unique_ptr* graph, + std::vector* input_permutation, + std::vector* output_permutation, + NodeDef* node_def); private: string xla_cluster_attr_name_; @@ -86,7 +87,7 @@ class RewriteOutsideCompilationSubgraphFn { // function names. These functions need to be rewritten later. // has_outside_compilation: a bool indicating whether this function has any // outside compilation nodes. -Status ExtractOutsideCompilationForFunction( +absl::Status ExtractOutsideCompilationForFunction( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const string& xla_cluster_name, const NameAttrList& func_name_attrs, const string& new_func_name, @@ -99,7 +100,7 @@ Status ExtractOutsideCompilationForFunction( // with XlaHostCompute, and moves those outside compilations into `g`. If shapes // of outside compilation outputs cannot be determined now, we will store shape // inference graph into `fld`. -Status ExtractOutsideCompilation( +absl::Status ExtractOutsideCompilation( const string& xla_cluster_attr_name, const string& outside_compilation_attr_name, const std::unordered_map& clusters, Graph* g, diff --git a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc index 6bf8d977743aae..f44d45b6e6abb7 100644 --- a/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/extract_outside_compilation_pass_test.cc @@ -235,7 +235,7 @@ class ExtractOutsideCompilationForFunctionTest : public ::testing::Test { device_mgr_ = std::make_unique(std::move(devices)); } - Status ExtractOutsideCompilationTest( + absl::Status ExtractOutsideCompilationTest( const string &xla_cluster_attr_name, const string &outside_compilation_attr_name, const string &xla_cluster_name, const NameAttrList &func_name_attrs, @@ -740,7 +740,7 @@ TEST_F(ExtractOutsideCompilationForFunctionTest, OutsideCompilationInFunction) { .Attr("dtype", DT_INT32) .Attr("value", tensor_proto) .Finalize(&const_def)); - Status s; + absl::Status s; Node *const_node = g->AddNode(const_def, &s); TF_CHECK_OK(s); diff --git a/tensorflow/compiler/jit/force_xla_constants_on_host_pass.cc b/tensorflow/compiler/jit/force_xla_constants_on_host_pass.cc index d407f5d15d904d..d411ecde8905ae 100644 --- a/tensorflow/compiler/jit/force_xla_constants_on_host_pass.cc +++ b/tensorflow/compiler/jit/force_xla_constants_on_host_pass.cc @@ -21,7 +21,7 @@ limitations under the License. namespace tensorflow { -Status ForceXlaConstantsOnHostPass::Run( +absl::Status ForceXlaConstantsOnHostPass::Run( const GraphOptimizationPassOptions& options) { Graph* graph = options.graph->get(); diff --git a/tensorflow/compiler/jit/force_xla_constants_on_host_pass.h b/tensorflow/compiler/jit/force_xla_constants_on_host_pass.h index d7aaf02a3d2aac..ae7cf14962788e 100644 --- a/tensorflow/compiler/jit/force_xla_constants_on_host_pass.h +++ b/tensorflow/compiler/jit/force_xla_constants_on_host_pass.h @@ -28,7 +28,7 @@ class ForceXlaConstantsOnHostPass : public GraphOptimizationPass { public: ForceXlaConstantsOnHostPass() = default; - Status Run(const GraphOptimizationPassOptions& options) override; + absl::Status Run(const GraphOptimizationPassOptions& options) override; }; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/force_xla_constants_on_host_pass_test.cc b/tensorflow/compiler/jit/force_xla_constants_on_host_pass_test.cc index e4b937551838f9..75bd1d7310a295 100644 --- a/tensorflow/compiler/jit/force_xla_constants_on_host_pass_test.cc +++ b/tensorflow/compiler/jit/force_xla_constants_on_host_pass_test.cc @@ -37,9 +37,9 @@ limitations under the License. namespace tensorflow { namespace { -Status ForceXlaConstantsOnHost(const Scope& s, - FunctionLibraryDefinition* flib_def, - std::unique_ptr* result) { +absl::Status ForceXlaConstantsOnHost(const Scope& s, + FunctionLibraryDefinition* flib_def, + std::unique_ptr* result) { auto graph = std::make_unique(OpRegistry::Global()); GraphOptimizationPassOptions options; SessionOptions session_options; diff --git a/tensorflow/compiler/jit/get_compiler_ir.cc b/tensorflow/compiler/jit/get_compiler_ir.cc index 2f99bc5357c2af..7f5e348a1f72a7 100644 --- a/tensorflow/compiler/jit/get_compiler_ir.cc +++ b/tensorflow/compiler/jit/get_compiler_ir.cc @@ -39,6 +39,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "xla/client/executable_build_options.h" #include "xla/client/local_client.h" +#include "xla/hlo/translate/portable_api.h" #include "xla/service/hlo_graph_dumper.h" #include "xla/status_macros.h" #include "xla/stream_executor/host/host_platform_id.h" @@ -97,6 +98,8 @@ static absl::StatusOr BuildHLOString( IrExportStage stage, const XlaCompiler::CompilationResult& result, xla::LocalClient* local_client, const XlaCompiler::Options& options) { switch (stage) { + case IrExportStage::STABLEHLO: + case IrExportStage::STABLEHLO_SERIALIZED: case IrExportStage::HLO: case IrExportStage::HLO_NO_METADATA: case IrExportStage::HLO_SERIALIZED: { @@ -107,16 +110,27 @@ static absl::StatusOr BuildHLOString( std::unique_ptr new_module, xla::HloModule::CreateFromProto(result.computation->proto(), config)); + if (stage == IrExportStage::STABLEHLO_SERIALIZED) { + TF_ASSIGN_OR_RETURN( + std::string stablehlo, + xla::ConvertHloToStablehlo(*new_module, /*emit_bytecode=*/true)); + return stablehlo; + } + if (stage == IrExportStage::STABLEHLO) { + TF_ASSIGN_OR_RETURN( + std::string stablehlo, + xla::ConvertHloToStablehlo(*new_module, /*emit_bytecode=*/false)); + return stablehlo; + } + xla::HloPrintOptions opts; if (stage == IrExportStage::HLO_NO_METADATA) { opts.set_print_metadata(false); } - if (stage == IrExportStage::HLO_SERIALIZED) { return new_module->ToProto().SerializeAsString(); - } else { - return new_module->ToString(opts); } + return new_module->ToString(opts); } case IrExportStage::OPTIMIZED_HLO: case IrExportStage::OPTIMIZED_HLO_SERIALIZED: { @@ -223,7 +237,9 @@ absl::Status ValidateGetCompilerIrTfrtTpu(absl::string_view device_type, auto is_tfrt_tpu_supported_stage = [](IrExportStage stage) { return stage == IrExportStage::HLO || stage == IrExportStage::HLO_NO_METADATA || - stage == IrExportStage::HLO_SERIALIZED; + stage == IrExportStage::HLO_SERIALIZED || + stage == IrExportStage::STABLEHLO || + stage == IrExportStage::STABLEHLO_SERIALIZED; }; // TODO(b/238830423): support GetCompilerIr on TFRT TPU device for stages // that requires compilation from HLO to executable. diff --git a/tensorflow/compiler/jit/get_compiler_ir.h b/tensorflow/compiler/jit/get_compiler_ir.h index 079a7af7ad3cd9..cc831ec298fe0a 100644 --- a/tensorflow/compiler/jit/get_compiler_ir.h +++ b/tensorflow/compiler/jit/get_compiler_ir.h @@ -32,6 +32,8 @@ class TensorHandle; class EagerContext; enum class IrExportStage { + STABLEHLO, + STABLEHLO_SERIALIZED, HLO, HLO_NO_METADATA, HLO_SERIALIZED, diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc index 0ef7156ef9f593..fa4a9405e808cb 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc @@ -14,25 +14,40 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h" + #include + #include "absl/algorithm/container.h" -#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/framework/scope_internal.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/ops/math_ops.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" -#include "tensorflow/compiler/tf2xla/cc/ops/xla_ops.h" #include "xla/status_macros.h" -#include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/algorithm.h" -#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/util/dump_graph.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace tensorflow { namespace { @@ -169,9 +184,10 @@ class ConstantCache { }; // Returns a node computing the size of the Slice op with inputs `slice_inputs`. -Status ComputeSliceSize(const Scope& host_scope, - const SliceInputs& slice_inputs, - std::vector control_deps, Output* size) { +absl::Status ComputeSliceSize(const Scope& host_scope, + const SliceInputs& slice_inputs, + std::vector control_deps, + Output* size) { // If slice_size[i] >= 0 then slice_size[i] = slice_size[i]. // // If slice_size[i] == -1 then slice_size[i] = input_size[i] - @@ -233,14 +249,14 @@ Status ComputeSliceSize(const Scope& host_scope, // Terminology: "static sized" slice is a slice with the // _XlaCompileTimeConstantInputs attribute set to {2}. The output shape of // these slices can be solely determined by their "size" input. -Status ConvertTensorFlowSliceToStaticShapedSlice( +absl::Status ConvertTensorFlowSliceToStaticShapedSlice( Graph* g, Node* slice, const SliceInputs& slice_inputs, absl::string_view cluster_name, Node** result) { string host_name; TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName( slice->assigned_device_name(), &host_name)); - Status status; + absl::Status status; Scope main_scope = NewInternalScope(g, &status, /*refiner=*/nullptr) .WithXlaCluster(string(cluster_name)) @@ -301,8 +317,9 @@ void ReplaceTensorFlowSliceWithStaticShapedSlice(Graph* g, Node* slice, g->RemoveNode(slice); } -Status RewriteSlice(Graph* g, Node* slice, const SliceInputs& slice_inputs, - absl::string_view cluster_name) { +absl::Status RewriteSlice(Graph* g, Node* slice, + const SliceInputs& slice_inputs, + absl::string_view cluster_name) { VLOG(3) << "Rewriting slice " << slice->name() << " to a \"static shaped\" Slice"; Node* static_shaped_slice; @@ -343,7 +360,7 @@ absl::StatusOr ShouldRewriteSlice(Node* n) { return !slice_inputs->begin.node()->IsConstant(); } -Status FindAndRewriteSlices(Graph* g, bool* changed) { +absl::Status FindAndRewriteSlices(Graph* g, bool* changed) { std::vector slices_to_rewrite; for (Node* n : g->nodes()) { TF_ASSIGN_OR_RETURN(bool is_rewritable, ShouldRewriteSlice(n)); @@ -371,7 +388,7 @@ Status FindAndRewriteSlices(Graph* g, bool* changed) { } } // namespace -Status IncreaseDynamismForAutoJitPass::Run( +absl::Status IncreaseDynamismForAutoJitPass::Run( const GraphOptimizationPassOptions& options) { MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); if (flags->tf_xla_clustering_debug) { diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h index 818ca948d64b03..2f4cfdafe66874 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_JIT_INCREASE_DYNAMISM_FOR_AUTO_JIT_PASS_H_ #include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { @@ -49,7 +50,7 @@ namespace tensorflow { // In the future we will also translate StridedSlice and Pad a similar way. class IncreaseDynamismForAutoJitPass : public GraphOptimizationPass { public: - Status Run(const GraphOptimizationPassOptions& options) override; + absl::Status Run(const GraphOptimizationPassOptions& options) override; }; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc index e864ef1dd12ae9..569254e91fc7bb 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc @@ -15,14 +15,31 @@ limitations under the License. #include "tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.h" +#include +#include "absl/status/status.h" #include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/compiler/jit/node_matchers.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" -#include "tensorflow/core/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "tensorflow/core/common_runtime/device_set.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/public/session_options.h" +#include "tsl/platform/errors.h" namespace tensorflow { namespace { @@ -44,7 +61,9 @@ class FakeDevice : public Device { explicit FakeDevice(const DeviceAttributes& device_attributes) : Device(nullptr, device_attributes) {} - Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); } + absl::Status Sync() override { + return errors::Unimplemented("FakeDevice::Sync()"); + } Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; } @@ -59,8 +78,8 @@ class FakeDevice : public Device { const char* kHostName = "/job:worker/replica:0/task:0/device:CPU:0"; const char* kDeviceName = "/job:worker/replica:0/task:0/device:GPU:0"; -Status IncreaseDynamismForAutoJit(const Scope& s, - std::unique_ptr* result) { +absl::Status IncreaseDynamismForAutoJit(const Scope& s, + std::unique_ptr* result) { std::vector> devices; devices.push_back(FakeDevice::Make(kDeviceName, DEVICE_GPU)); devices.push_back(FakeDevice::Make(kHostName, DEVICE_CPU)); diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 5a29e8ef36e9b3..86cb79d981ee85 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -205,7 +205,8 @@ XlaComputationLaunchContext GetLaunchContext( return launch_context; } -Status GetTaskName(const std::string_view device_name, std::string* task_name) { +absl::Status GetTaskName(const std::string_view device_name, + std::string* task_name) { string ignored; if (!DeviceNameUtils::SplitDeviceName(device_name, task_name, &ignored)) { return errors::InvalidArgument("Unable to parse device name: ", @@ -366,7 +367,7 @@ GetXlaCompilerArgsAndSnapshotVariables( return result; } -Status CompileToLocalExecutable( +absl::Status CompileToLocalExecutable( OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars, const XlaPlatformInfo& platform_info, const std::vector& args, @@ -419,7 +420,7 @@ Status CompileToLocalExecutable( compilation_result, executable); } -Status GetUpdatedVariables( +absl::Status GetUpdatedVariables( const OpKernelContext* ctx, absl::Span inputs, absl::Span variable_indices, const XlaCompiler::CompilationResult& compilation_result, @@ -509,7 +510,7 @@ void XlaLocalLaunchBase::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { std::set variables_updated; // Here we only need to reader-lock the variables, so we pass an empty // variables_updated set here. - Status status = GetVariableInfosFromInputs( + absl::Status status = GetVariableInfosFromInputs( ctx->resource_manager(), ctx->device(), inputs, resources_, &variables_updated, &variable_infos); OP_REQUIRES_OK_ASYNC(ctx, status, done); @@ -528,7 +529,7 @@ void XlaLocalLaunchBase::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { platform_info_.device_type()); if (use_pjrt) { VLOG(2) << "Compiling using PJRT"; - Status status = CompileToPjRtLoadedExecutable( + absl::Status status = CompileToPjRtLoadedExecutable( *ctx, platform_info_, function_, xla_compiler_args, DeviceCompileMode::kStrict, has_ref_vars_, /*may_alias_resource_update=*/true, &compilation_result, &pjrt_client, @@ -569,7 +570,7 @@ void XlaLocalLaunchBase::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { return; } - Status status = CompileToLocalExecutable( + absl::Status status = CompileToLocalExecutable( ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_, xla_compiler_args, DeviceCompileMode::kStrict, /*may_alias_resource_update=*/true, &client, &compilation_result, @@ -776,7 +777,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { // Do not alias resource updates as locking variables in XlaCompile and // unlocking them in XlaRun may lead to deadlocks. - Status status; + absl::Status status; if (use_pjrt) { VLOG(2) << "Using PJRT for compilation. Function name: " << function_.name(); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 96cf169b914285..03ae04f567443e 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -139,7 +139,7 @@ class MarkForCompilationPassImpl { cpu_global_jit_(cpu_global_jit), cluster_name_prefix_(cluster_name_prefix) {} - Status Run(); + absl::Status Run(); private: // Represents a "cluster" or a connected subgraph of a TensorFlow graph. @@ -295,7 +295,7 @@ class MarkForCompilationPassImpl { // Contracts as many edges as possible to create XLA clusters. After this // finishes the clustering decisions made are implicitly stored in // `clusters_`. - Status RunEdgeContractionLoop(); + absl::Status RunEdgeContractionLoop(); // "Fixes up" clusters by removing some modes. // @@ -304,14 +304,14 @@ class MarkForCompilationPassImpl { // of those constants, and increase overall memory usage. // // This function removes "obviously bad" cases like these. - Status DeclusterNodes(); + absl::Status DeclusterNodes(); // Manifests the clustering decisions into the TF graph by tagging nodes with // an `_XlaCluster` attribute. Also some basic filter logic, like // tf_xla_min_cluster_size, are applied here. - Status CreateClusters(); + absl::Status CreateClusters(); - Status DumpDebugInfo(); + absl::Status DumpDebugInfo(); bool IsCompilationCandidate(Node* n) const { return compilation_candidates_.find(n) != compilation_candidates_.end(); @@ -322,12 +322,12 @@ class MarkForCompilationPassImpl { absl::StatusOr TryToContractEdge(Cluster* from, Cluster* to); // Nodes that XLA can compile are put in `compilation_candidates_`. - Status FindCompilationCandidates(); + absl::Status FindCompilationCandidates(); bool CompilationDisallowedByXlaCompileAttr(Node* node); // Populates `clusters_`. - Status BuildInitialClusterSet(); + absl::Status BuildInitialClusterSet(); absl::StatusOr ShouldCompileClusterImpl(const Cluster& cluster); @@ -614,7 +614,7 @@ void MarkForCompilationPassImpl::Cluster::Merge(Cluster* other) { other->resource_var_operation_node_ids_.clear(); } -Status IgnoreResourceOpForSafetyAnalysis( +absl::Status IgnoreResourceOpForSafetyAnalysis( jit::DeviceInfoCache* device_info_cache, const Node& n, bool* ignore) { // If a resource operation is assigned to XLA_CPU or XLA_GPU explicitly then // ignore it during resource operation safety analysis. We need this hack @@ -772,7 +772,7 @@ bool MarkForCompilationPassImpl::IsScalarIntegerResourceOperation( return TensorShapeUtils::IsScalar(proto->tensor_shape()); } -Status MarkForCompilationPassImpl::RunEdgeContractionLoop() { +absl::Status MarkForCompilationPassImpl::RunEdgeContractionLoop() { TF_RET_CHECK(initialized_ && !edges_contracted_ && !clusters_created_); edges_contracted_ = true; @@ -898,7 +898,7 @@ Status MarkForCompilationPassImpl::RunEdgeContractionLoop() { return absl::OkStatus(); } -Status MarkForCompilationPassImpl::DeclusterNodes() { +absl::Status MarkForCompilationPassImpl::DeclusterNodes() { for (Node* n : compilation_candidates_) { Cluster* cluster = GetClusterForNode(n); if (cluster == nullptr) { @@ -959,7 +959,7 @@ int64_t GetNextClusterSequenceNumber(uint64 fingerprint) { return ClusterSequenceNumberGenerator::Global().GetNext(fingerprint); } -Status MarkForCompilationPassImpl::CreateClusters() { +absl::Status MarkForCompilationPassImpl::CreateClusters() { TF_RET_CHECK(initialized_ && edges_contracted_ && !clusters_created_); clusters_created_ = true; @@ -1016,7 +1016,7 @@ Status MarkForCompilationPassImpl::CreateClusters() { return absl::OkStatus(); } -Status MarkForCompilationPassImpl::DumpDebugInfo() { +absl::Status MarkForCompilationPassImpl::DumpDebugInfo() { TF_RET_CHECK(initialized_ && edges_contracted_ && clusters_created_); if (debug_options_.dump_graphs) { @@ -1112,7 +1112,7 @@ static bool GetNodeOrFuncAttr(Node* node, FunctionLibraryDefinition* flib_def, return out; } -Status MarkForCompilationPassImpl::BuildInitialClusterSet() { +absl::Status MarkForCompilationPassImpl::BuildInitialClusterSet() { auto ignore_resource_ops = [&](const Node& n, bool* ignore) { return IgnoreResourceOpForSafetyAnalysis(&device_info_cache_, n, ignore); }; @@ -1267,7 +1267,7 @@ absl::flat_hash_set GetOrCreateAllowlist() { return allowlist; } -Status MarkForCompilationPassImpl::FindCompilationCandidates() { +absl::Status MarkForCompilationPassImpl::FindCompilationCandidates() { OptimizerOptions opts; std::unique_ptr pflr( new ProcessFunctionLibraryRuntime(nullptr, env_, /*config=*/nullptr, @@ -1489,7 +1489,7 @@ bool MarkForCompilationPassImpl::CompilationDisallowedByXlaCompileAttr( // If there is a _XlaCompile annotation, use its value. bool compile = false; - Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); + absl::Status status = GetNodeAttr(node->attrs(), kXlaCompileAttr, &compile); if (status.ok()) { if (!compile) { VLOG(2) << "Rejecting " << node->name() << ": kXlaCompileAttr(" @@ -1587,7 +1587,7 @@ absl::StatusOr MarkForCompilationPassImpl::TryToContractEdge( return MergeClusters(from, to); } -Status MarkForCompilationPassImpl::Run() { +absl::Status MarkForCompilationPassImpl::Run() { // Make sure that kernels have been registered on the JIT device. XlaOpRegistry::RegisterCompilationKernels(); @@ -1853,7 +1853,7 @@ absl::StatusOr MarkForCompilationPassImpl::ShouldCompileCluster( return should_compile; } -Status MarkForCompilation( +absl::Status MarkForCompilation( const GraphOptimizationPassOptions& options, const MarkForCompilationPassImpl::DebugOptions& debug_options) { Graph* graph = options.graph->get(); @@ -1908,7 +1908,7 @@ std::atomic* GetPointerToFuel(int64_t initial_value) { } // anonymous namespace -Status MarkForCompilationPass::Run( +absl::Status MarkForCompilationPass::Run( const GraphOptimizationPassOptions& options) { MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); @@ -1928,7 +1928,7 @@ Status MarkForCompilationPass::Run( return MarkForCompilation(options, debug_options); } -Status MarkForCompilationPass::RunForTest( +absl::Status MarkForCompilationPass::RunForTest( const GraphOptimizationPassOptions& options, bool disable_deadness_analysis, bool deterministic_cluster_names) { MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h index eb198502d76a97..558912f2eee2e0 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h @@ -37,12 +37,12 @@ class MarkForCompilationPass : public GraphOptimizationPass { public: MarkForCompilationPass() = default; - Status Run(const GraphOptimizationPassOptions& options) override; + absl::Status Run(const GraphOptimizationPassOptions& options) override; private: - Status RunForTest(const GraphOptimizationPassOptions& options, - bool disable_deadness_analysis, - bool deterministic_cluster_names); + absl::Status RunForTest(const GraphOptimizationPassOptions& options, + bool disable_deadness_analysis, + bool deterministic_cluster_names); friend class MarkForCompilationPassTestHelper; }; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index aabedf61202d3f..1a120791206369 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -374,8 +374,8 @@ TEST(XlaCompilationTest, CallXlaDeviceFuncWithResourceOp) { EXPECT_NE(clusters["A"], ""); } -static Status GradForUnaryCwise(FunctionDef* g, - std::vector nodes) { +static absl::Status GradForUnaryCwise( + FunctionDef* g, std::vector nodes) { for (auto& n : nodes) { if (n.attr.empty()) { n.attr = {{"T", DT_FLOAT}}; @@ -394,7 +394,7 @@ static Status GradForUnaryCwise(FunctionDef* g, } // A gradient containing only supported operators -Status SupportedGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status SupportedGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"y"}, "Tanh", {"x"}}, @@ -408,7 +408,7 @@ Status SupportedGrad(const AttrSlice& attrs, FunctionDef* g) { REGISTER_OP_GRADIENT("Supported", SupportedGrad); // A gradient containing an unsupported operator. -Status UnsupportedGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status UnsupportedGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"y"}, "Tanh", {"x"}}, @@ -799,7 +799,7 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { NodeDef def; TF_CHECK_OK(builder.Finalize(&def)); - Status status; + absl::Status status; Node* node = graph->AddNode(def, &status); TF_CHECK_OK(status); return node; @@ -815,7 +815,8 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { TF_EXPECT_OK(root.ToGraph(graph.get())); - Status status = MarkForCompilationPassTestHelper::MarkForCompilation(&graph); + absl::Status status = + MarkForCompilationPassTestHelper::MarkForCompilation(&graph); EXPECT_FALSE(status.ok()); EXPECT_TRUE(absl::StrContains(status.ToString(), "Edge from c to a would create a cycle.\n" @@ -1009,7 +1010,7 @@ TEST(XlaCompilationTest, RandomShapeWithFunc) { NodeDef call_node; call_node.set_name("fn_call"); call_node.set_op("Stateful_func"); - Status status; + absl::Status status; Node* call = root.graph()->AddNode(call_node, &status); TF_ASSERT_OK(status); @@ -1903,7 +1904,7 @@ Node* MakeStageNode(GraphDefBuilder& builder, string name, } // namespace TEST(XlaCompilationTest, StagePipelinePreservedByClusterScopingPass) { - auto build_staged_graph = [](std::unique_ptr* graph) -> Status { + auto build_staged_graph = [](std::unique_ptr* graph) -> absl::Status { // Construct a graph as below with two pipeline stages and test that nodes // in different stages will not be merged if ClusterScopingPass is on. // diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc index 8a2957520d7f00..4c1c65d47209b5 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/core/public/session_options.h" namespace tensorflow { -/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation( +/*static*/ absl::Status MarkForCompilationPassTestHelper::MarkForCompilation( std::unique_ptr* graph, FunctionLibraryDefinition* flib_def, MarkForCompilationPassTestHelper::Options options) { // Assign all unassigned nodes to the CPU device. @@ -71,7 +71,7 @@ namespace tensorflow { /*deterministic_cluster_names=*/options.deterministic_cluster_names); } -/*static*/ Status MarkForCompilationPassTestHelper::MarkForCompilation( +/*static*/ absl::Status MarkForCompilationPassTestHelper::MarkForCompilation( std::unique_ptr* graph, MarkForCompilationPassTestHelper::Options options) { FunctionDefLibrary flib; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h index 327d2eb8450989..84d24898223165 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h @@ -72,13 +72,13 @@ class MarkForCompilationPassTestHelper { // Runs the MarkForCompilation pass on `graph` after assigning all nodes in // `graph` to the CPU device. To make testing easier, ignores device // registration and _XlaCompile attributes. - static Status MarkForCompilation(std::unique_ptr* graph, - FunctionLibraryDefinition* flib_def, - Options options = Options()); + static absl::Status MarkForCompilation(std::unique_ptr* graph, + FunctionLibraryDefinition* flib_def, + Options options = Options()); // Like `MarkForCompilation` but creates `flib_def` from the op registry. - static Status MarkForCompilation(std::unique_ptr* graph, - Options options = Options()); + static absl::Status MarkForCompilation(std::unique_ptr* graph, + Options options = Options()); }; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc index eb66a8d905cc8c..442635c9a29696 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass.cc @@ -37,9 +37,9 @@ namespace { bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); } namespace reduce_device_to_host_copies { -Status FindNodesToDecluster(const Graph& graph, - absl::flat_hash_set* result, - absl::Span post_order) { +absl::Status FindNodesToDecluster(const Graph& graph, + absl::flat_hash_set* result, + absl::Span post_order) { // Find nodes that have at least one user outside their cluster that expects // hostmem output. These nodes should be cloned to outside the cluster to // avoid the device-host copy we'd otherwise need. @@ -116,7 +116,7 @@ Status FindNodesToDecluster(const Graph& graph, return absl::OkStatus(); } -Status PartiallyDeclusterNode(Graph* graph, Node* n) { +absl::Status PartiallyDeclusterNode(Graph* graph, Node* n) { absl::string_view cluster_name = *GetXlaClusterForNode(*n); absl::InlinedVector out_edges_to_clone; for (const Edge* out_edge : n->out_edges()) { @@ -185,7 +185,7 @@ Status PartiallyDeclusterNode(Graph* graph, Node* n) { // where the ===> arrow has a hostmem source and destination and would entail a // device to host copy if the source and destination were not in the same XLA // cluster. -Status PartiallyDeclusterGraph(Graph* graph) { +absl::Status PartiallyDeclusterGraph(Graph* graph) { // When deciding whether to decluster a particular node, we base our decision // on if we've decided that some of its consumers have to be declustered too. // Iterating the graph in post-order guarantees that consumers have been @@ -244,7 +244,7 @@ bool IsMustCompileDevice(const DeviceType& device_type) { return false; } -Status MustCompileNode(const Node* n, bool* must_compile) { +absl::Status MustCompileNode(const Node* n, bool* must_compile) { DeviceType device_type(""); TF_RETURN_IF_ERROR( DeviceNameToDeviceType(n->assigned_device_name(), &device_type)); @@ -288,9 +288,9 @@ Status MustCompileNode(const Node* n, bool* must_compile) { // regress performance in any significant manner. We will have to revisit this // algorithm with a more complex cost model if this assumption turns out to be // incorrect. -Status PartiallyDeclusterGraph(Graph* graph, - const FunctionLibraryDefinition* flib_def, - Env* env) { +absl::Status PartiallyDeclusterGraph(Graph* graph, + const FunctionLibraryDefinition* flib_def, + Env* env) { std::vector compile_time_const_nodes(graph->num_node_ids()); OptimizerOptions opts; auto pflr = std::make_unique( @@ -369,7 +369,7 @@ Status PartiallyDeclusterGraph(Graph* graph, namespace decluster_root_shape_consumers { -Status PartiallyDeclusterGraph(Graph* graph) { +absl::Status PartiallyDeclusterGraph(Graph* graph) { std::vector reverse_post_order; GetReversePostOrder(*graph, &reverse_post_order, /*stable_comparator=*/NodeComparatorName(), @@ -402,7 +402,7 @@ Status PartiallyDeclusterGraph(Graph* graph) { } // namespace decluster_root_shape_consumers } // namespace -Status PartiallyDeclusterPass::Run( +absl::Status PartiallyDeclusterPass::Run( const GraphOptimizationPassOptions& options) { // NB! In this pass we assume the only XLA-auto-clusterable operations that // may have side effects are resource variable operations so we don't cluster diff --git a/tensorflow/compiler/jit/partially_decluster_pass.h b/tensorflow/compiler/jit/partially_decluster_pass.h index cfc4ddb5630bec..18b0091c5bb7eb 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass.h +++ b/tensorflow/compiler/jit/partially_decluster_pass.h @@ -27,7 +27,7 @@ namespace tensorflow { // - Reducing the number of XLA recompilations. class PartiallyDeclusterPass : public GraphOptimizationPass { public: - Status Run(const GraphOptimizationPassOptions& options) override; + absl::Status Run(const GraphOptimizationPassOptions& options) override; }; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/partially_decluster_pass_test.cc b/tensorflow/compiler/jit/partially_decluster_pass_test.cc index d7758e009f3b23..c8bbcee20e3829 100644 --- a/tensorflow/compiler/jit/partially_decluster_pass_test.cc +++ b/tensorflow/compiler/jit/partially_decluster_pass_test.cc @@ -82,7 +82,7 @@ REGISTER_KERNEL_BUILDER( Name("FakeResourceUpdate").Device(DEVICE_CPU).HostMemory("something_else"), FakeResourceUpdateOp); -Status PartiallyDecluster(std::unique_ptr* graph) { +absl::Status PartiallyDecluster(std::unique_ptr* graph) { FixupSourceAndSinkEdges(graph->get()); // Assign all nodes to the CPU device. static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; diff --git a/tensorflow/compiler/jit/pjrt_compile_util.cc b/tensorflow/compiler/jit/pjrt_compile_util.cc index 156cb21ff8df2c..f57fe186b57621 100644 --- a/tensorflow/compiler/jit/pjrt_compile_util.cc +++ b/tensorflow/compiler/jit/pjrt_compile_util.cc @@ -38,7 +38,7 @@ namespace tensorflow { using PjRtDeviceCompiler = DeviceCompiler; -Status CompileToPjRtLoadedExecutable( +absl::Status CompileToPjRtLoadedExecutable( const DeviceBase* device, const XlaPlatformInfo& platform_info, const NameAttrList& function, const std::vector& args, @@ -70,7 +70,7 @@ Status CompileToPjRtLoadedExecutable( compilation_result, executable); } -Status CompileToPjRtLoadedExecutable( +absl::Status CompileToPjRtLoadedExecutable( const OpKernelContext& ctx, const XlaPlatformInfo& platform_info, const NameAttrList& function, const std::vector& args, diff --git a/tensorflow/compiler/jit/pjrt_compile_util.h b/tensorflow/compiler/jit/pjrt_compile_util.h index d3ba6b7e2d0208..11645651784b35 100644 --- a/tensorflow/compiler/jit/pjrt_compile_util.h +++ b/tensorflow/compiler/jit/pjrt_compile_util.h @@ -30,7 +30,7 @@ namespace tensorflow { // The compilation result is output in `compilation_result`. The PJRT client // used for compilation is output in `client`. The PJRT executable is output in // `executable`. -Status CompileToPjRtLoadedExecutable( +absl::Status CompileToPjRtLoadedExecutable( const OpKernelContext& ctx, const XlaPlatformInfo& platform_info, const NameAttrList& function, const std::vector& args, @@ -46,7 +46,7 @@ Status CompileToPjRtLoadedExecutable( // - `rm`: the resource manager for DeviceCompiler to store JIT-compiled XLA // computation. // - `flr`: the FunctionLibraryRuntime for the `function`. -Status CompileToPjRtLoadedExecutable( +absl::Status CompileToPjRtLoadedExecutable( const DeviceBase* device, const XlaPlatformInfo& platform_info, const NameAttrList& function, const std::vector& args, diff --git a/tensorflow/compiler/jit/pjrt_device_context.cc b/tensorflow/compiler/jit/pjrt_device_context.cc index 794f32d3fea9a1..9807c605939a57 100644 --- a/tensorflow/compiler/jit/pjrt_device_context.cc +++ b/tensorflow/compiler/jit/pjrt_device_context.cc @@ -138,7 +138,7 @@ void PjRtDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, xla::PjRtFuture<> future = device_buffer->ToLiteral(literal.get()); future.OnReady([literal = std::move(literal), done = std::move(done)]( - const tensorflow::Status& status) { done(status); }); + const absl::Status& status) { done(status); }); } void PjRtDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, diff --git a/tensorflow/compiler/jit/rearrange_function_argument_pass_test.cc b/tensorflow/compiler/jit/rearrange_function_argument_pass_test.cc index ffbcef3371ae81..a0da7f7fc1a623 100644 --- a/tensorflow/compiler/jit/rearrange_function_argument_pass_test.cc +++ b/tensorflow/compiler/jit/rearrange_function_argument_pass_test.cc @@ -221,7 +221,7 @@ TEST(RearrangeFunctionArgumentForFunctionTest, TF_CHECK_OK(s.ToGraph(g.get())); std::vector> fbodies; - Status status = RearrangeFunctionArguments( + absl::Status status = RearrangeFunctionArguments( [&](const NameAttrList &function, const FunctionBody **fbody) { std::unique_ptr new_fbody; TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fld.Find(function.name()), diff --git a/tensorflow/compiler/jit/report_clustering_info_pass.cc b/tensorflow/compiler/jit/report_clustering_info_pass.cc index b2b71b47c7996c..26871a4267bc9c 100644 --- a/tensorflow/compiler/jit/report_clustering_info_pass.cc +++ b/tensorflow/compiler/jit/report_clustering_info_pass.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_cluster_util.h" namespace tensorflow { -Status ReportClusteringInfoPass::Run( +absl::Status ReportClusteringInfoPass::Run( const GraphOptimizationPassOptions& options) { XlaAutoClusteringActivity activity; *activity.mutable_summary() = GetXlaAutoClusteringSummary(**options.graph); diff --git a/tensorflow/compiler/jit/report_clustering_info_pass.h b/tensorflow/compiler/jit/report_clustering_info_pass.h index 97471cff134aec..2ac67bf1c68f48 100644 --- a/tensorflow/compiler/jit/report_clustering_info_pass.h +++ b/tensorflow/compiler/jit/report_clustering_info_pass.h @@ -25,7 +25,7 @@ namespace tensorflow { // broadcasts it via xla_activity_listener. class ReportClusteringInfoPass : public GraphOptimizationPass { public: - Status Run(const GraphOptimizationPassOptions& options) override; + absl::Status Run(const GraphOptimizationPassOptions& options) override; }; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc index 92f79dde874217..2fee2b0b898890 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc @@ -96,9 +96,10 @@ namespace { // Maps `n` to the XlaResourceOpKind corresponding to its operation. If `n` is // not a resource operation recognized by XLA then sets `out_resource_op_kind` // to nullopt. -Status XlaResourceOpKindForNode( +absl::Status XlaResourceOpKindForNode( const Node& n, const FunctionLibraryDefinition* flib_def, - const std::function& resource_ops_to_ignore, + const std::function& + resource_ops_to_ignore, std::optional* out_resource_op_kind) { bool should_ignore = false; if (resource_ops_to_ignore) { @@ -246,9 +247,10 @@ string NodeToString(const Node& n, XlaResourceOpKind resource_op_kind) { } } // namespace -Status ComputeIncompatibleResourceOperationPairs( +absl::Status ComputeIncompatibleResourceOperationPairs( const Graph& g, const FunctionLibraryDefinition* flib_def, - const std::function& resource_ops_to_ignore, + const std::function& + resource_ops_to_ignore, std::vector>* result) { CHECK(result->empty()); diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.h b/tensorflow/compiler/jit/resource_operation_safety_analysis.h index 436c4be0f35c0a..eea18fb12fa13f 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis.h +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.h @@ -59,9 +59,10 @@ namespace tensorflow { // // If `resource_ops_to_ignore` is set then nodes for which it returns true are // ignored (we pretend these nodes are not resource operations). -Status ComputeIncompatibleResourceOperationPairs( +absl::Status ComputeIncompatibleResourceOperationPairs( const Graph& g, const FunctionLibraryDefinition* flib_def, - const std::function& resource_ops_to_ignore, + const std::function& + resource_ops_to_ignore, std::vector>* result); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc index 5529a7cbc723ed..8a80b8ae9b3497 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc @@ -69,8 +69,8 @@ Node* MakeNeutral(const Scope& scope, const string& id) { return ops::Const(scope.WithOpName("Const" + id), 42.0f).node(); } -Status ComputeIncompatiblePairs(Graph* g, - std::vector>* result) { +absl::Status ComputeIncompatiblePairs( + Graph* g, std::vector>* result) { FixupSourceAndSinkEdges(g); return ComputeIncompatibleResourceOperationPairs(*g, &g->flib_def(), {}, result); @@ -250,7 +250,7 @@ FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) { } Node* MakeCall(Graph* graph, const string& callee_name, const string& node_name, - Status* status) { + absl::Status* status) { NodeDef call_node; call_node.set_name(node_name); call_node.set_op(callee_name); @@ -265,7 +265,7 @@ TEST(ResourceOperationSafetyAnalysisTest, CallRead) { TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); Node* read = MakeRead(root, "R"); - Status status; + absl::Status status; Node* call = MakeCall(root.graph(), "Const_func", "C", &status); TF_ASSERT_OK(status); @@ -287,7 +287,7 @@ TEST(ResourceOperationSafetyAnalysisTest, ReadCall) { TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); Node* read = MakeRead(root, "R"); - Status status; + absl::Status status; Node* call = MakeCall(root.graph(), "Const_func", "C", &status); TF_ASSERT_OK(status); @@ -307,7 +307,7 @@ TEST(ResourceOperationSafetyAnalysisTest, CallWrite) { TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); Node* write = MakeWrite(root, "W"); - Status status; + absl::Status status; Node* call = MakeCall(root.graph(), "Const_func", "C", &status); TF_ASSERT_OK(status); @@ -327,7 +327,7 @@ TEST(ResourceOperationSafetyAnalysisTest, WriteCall) { TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); Node* write = MakeWrite(root, "W"); - Status status; + absl::Status status; Node* call = MakeCall(root.graph(), "Const_func", "C", &status); TF_ASSERT_OK(status); diff --git a/tensorflow/compiler/jit/shape_inference.cc b/tensorflow/compiler/jit/shape_inference.cc index 2560551ce6b530..7a5106aa69bbbf 100644 --- a/tensorflow/compiler/jit/shape_inference.cc +++ b/tensorflow/compiler/jit/shape_inference.cc @@ -42,9 +42,9 @@ namespace tensorflow { namespace { // Converts a shape inference handle to a PartialTensorShape. -Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context, - const shape_inference::ShapeHandle& handle, - PartialTensorShape* shape) { +absl::Status ShapeHandleToTensorShape( + shape_inference::InferenceContext* context, + const shape_inference::ShapeHandle& handle, PartialTensorShape* shape) { // The default is already unknown if (!context->RankKnown(handle)) return absl::OkStatus(); @@ -55,10 +55,10 @@ Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context, return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape); } -Status PropagateShapes(Graph* graph, - const std::map& arg_shapes, - const std::vector& back_edges, - ShapeRefiner* shape_refiner) { +absl::Status PropagateShapes( + Graph* graph, const std::map& arg_shapes, + const std::vector& back_edges, + ShapeRefiner* shape_refiner) { std::map merge_to_next_iteration; for (const auto& e : back_edges) { if (e.src->IsNextIteration() && e.dst->IsMerge()) { @@ -77,7 +77,7 @@ Status PropagateShapes(Graph* graph, << ", type: " << n->type_string(); // Ignore the status returned by the shape_refiner. We want the best effort // shapes, even if no shape function is registered for a node. - Status status = shape_refiner->AddNode(n); + absl::Status status = shape_refiner->AddNode(n); if (!status.ok()) { VLOG(1) << "Shape inference failed for node " << n->name() << ": " << status; @@ -227,8 +227,9 @@ Status PropagateShapes(Graph* graph, } // Store the shapes of the output tensors in a map -Status StoreOutputShapes(const Graph& graph, const ShapeRefiner& shape_refiner, - GraphShapeInfo* shape_info) { +absl::Status StoreOutputShapes(const Graph& graph, + const ShapeRefiner& shape_refiner, + GraphShapeInfo* shape_info) { for (const Node* node : graph.nodes()) { shape_inference::InferenceContext* context = shape_refiner.GetContext(node); if (!context) continue; @@ -264,9 +265,10 @@ Status StoreOutputShapes(const Graph& graph, const ShapeRefiner& shape_refiner, } // namespace -Status InferShapes(Graph* graph, const std::map& arg_shapes, - const tensorflow::FunctionLibraryDefinition* fnlib_def, - GraphShapeInfo* shape_info) { +absl::Status InferShapes(Graph* graph, + const std::map& arg_shapes, + const tensorflow::FunctionLibraryDefinition* fnlib_def, + GraphShapeInfo* shape_info) { ShapeRefiner shape_refiner(graph->versions(), graph->op_registry()); shape_refiner.set_require_shape_inference_fns(false); // TODO(dlibenzi): Verify if it is worth trying to infer shaped within diff --git a/tensorflow/compiler/jit/shape_inference.h b/tensorflow/compiler/jit/shape_inference.h index 3bd814823013f0..467ecb83a74aae 100644 --- a/tensorflow/compiler/jit/shape_inference.h +++ b/tensorflow/compiler/jit/shape_inference.h @@ -43,9 +43,10 @@ typedef std::unordered_map> GraphShapeInfo; // `arg_shapes`: user given map from the `index` to shapes of this // node, where `index` is the `index` attribute of `_Arg` op or `_index` // attribute of `Placeholder` op. -Status InferShapes(Graph* graph, const std::map& arg_shapes, - const tensorflow::FunctionLibraryDefinition* fnlib_def, - GraphShapeInfo* shape_info); +absl::Status InferShapes(Graph* graph, + const std::map& arg_shapes, + const tensorflow::FunctionLibraryDefinition* fnlib_def, + GraphShapeInfo* shape_info); // Merges two InferredShapes. Return an error if the two shapes cannot be // merged. diff --git a/tensorflow/compiler/jit/shape_inference_helpers.cc b/tensorflow/compiler/jit/shape_inference_helpers.cc index 9290861d48f0bc..df90f2df81ea43 100644 --- a/tensorflow/compiler/jit/shape_inference_helpers.cc +++ b/tensorflow/compiler/jit/shape_inference_helpers.cc @@ -23,7 +23,7 @@ limitations under the License. namespace tensorflow { -Status BackEdgeHelper::Remove(Graph* graph) { +absl::Status BackEdgeHelper::Remove(Graph* graph) { if (graph_ != nullptr) { return errors::Internal("BackEdgeHelper duplicate call to Remove."); } @@ -49,7 +49,7 @@ const std::vector& BackEdgeHelper::RemovedEdges() return back_edges_; } -Status BackEdgeHelper::Replace() { +absl::Status BackEdgeHelper::Replace() { if (graph_ == nullptr) { return errors::Internal("BackEdgeHelper Replace called before Remove."); } diff --git a/tensorflow/compiler/jit/shape_inference_helpers.h b/tensorflow/compiler/jit/shape_inference_helpers.h index 2f053c9a45dd47..d4c8195471e18d 100644 --- a/tensorflow/compiler/jit/shape_inference_helpers.h +++ b/tensorflow/compiler/jit/shape_inference_helpers.h @@ -45,13 +45,13 @@ class BackEdgeHelper { BackEdgeHelper& operator=(const BackEdgeHelper& other) = delete; // Temporarily removes all the back edges in graph. - Status Remove(Graph* graph); + absl::Status Remove(Graph* graph); // Gets the list of removed edges. const std::vector& RemovedEdges() const; // Replaces the back edges removed by a prior call to Remove. - Status Replace(); + absl::Status Replace(); private: Graph* graph_ = nullptr; // not owned diff --git a/tensorflow/compiler/jit/test_util.cc b/tensorflow/compiler/jit/test_util.cc index f073902bc03d4a..81ab1d8d05f96e 100644 --- a/tensorflow/compiler/jit/test_util.cc +++ b/tensorflow/compiler/jit/test_util.cc @@ -27,7 +27,7 @@ limitations under the License. namespace tensorflow { -Status ShapeAnnotationsMatch( +absl::Status ShapeAnnotationsMatch( const Graph& graph, const GraphShapeInfo& shape_info, std::map> expected_shapes) { for (Node* node : graph.op_nodes()) { diff --git a/tensorflow/compiler/jit/test_util.h b/tensorflow/compiler/jit/test_util.h index 569898654e420f..ec694662297399 100644 --- a/tensorflow/compiler/jit/test_util.h +++ b/tensorflow/compiler/jit/test_util.h @@ -42,7 +42,7 @@ namespace tensorflow { // `expected_shapes`. Returns an error if there are nodes in `expected_shapes` // that do not have shape information. Ignores nodes in `graph` that do not have // `expected_shapes` entries. -Status ShapeAnnotationsMatch( +absl::Status ShapeAnnotationsMatch( const Graph& graph, const GraphShapeInfo& shape_info, std::map> expected_shapes); diff --git a/tensorflow/compiler/jit/tests/auto_clustering_test.cc b/tensorflow/compiler/jit/tests/auto_clustering_test.cc index 6f901799a149ff..90e73c23d210d7 100644 --- a/tensorflow/compiler/jit/tests/auto_clustering_test.cc +++ b/tensorflow/compiler/jit/tests/auto_clustering_test.cc @@ -22,7 +22,7 @@ namespace { class AutoClusteringTestImpl : public AutoClusteringTest { protected: // Test auto-clustering with a proto text file ${key}.pbtxt. - Status RunAutoClusteringTestWithPbtxt(absl::string_view key) { + absl::Status RunAutoClusteringTestWithPbtxt(absl::string_view key) { string file_name_without_extension = absl::StrCat(testing::TensorFlowSrcRoot(), "/compiler/jit/tests/", key); @@ -32,7 +32,7 @@ class AutoClusteringTestImpl : public AutoClusteringTest { } // Test auto-clustering with a gzipped proto text file ${key}.pbtxt.gz. - Status RunAutoClusteringTestWithGzippedPbtxt(absl::string_view key) { + absl::Status RunAutoClusteringTestWithGzippedPbtxt(absl::string_view key) { string file_name_without_extension = absl::StrCat(testing::TensorFlowSrcRoot(), "/compiler/jit/tests/", key); @@ -77,7 +77,7 @@ TEST_F(AutoClusteringTestImpl, OpenSeq2SeqGNMT) { } #if defined(PLATFORM_GOOGLE) -Status BenchmarkHelper(absl::string_view key, benchmark::State& state) { +absl::Status BenchmarkHelper(absl::string_view key, benchmark::State& state) { return BenchmarkMarkForCompilation( absl::StrCat(testing::TensorFlowSrcRoot(), "/compiler/jit/tests/", key, ".pbtxt"), diff --git a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc index 2cb75500dd7916..74462a1cdfd1c6 100644 --- a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc +++ b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc @@ -82,7 +82,7 @@ absl::StatusOr SummarizeClustering( return result; } -Status AssertGraphDefIsUnclustered(const GraphDef& graphdef) { +absl::Status AssertGraphDefIsUnclustered(const GraphDef& graphdef) { const char* kXlaClusterAttr = "_XlaCluster"; const char* kXlaAlreadyClusteredAttr = "_XlaAlreadyClustered"; @@ -99,8 +99,8 @@ Status AssertGraphDefIsUnclustered(const GraphDef& graphdef) { return absl::OkStatus(); } -Status ReadTextProtoFromString(Env* env, const string& data, - ::tensorflow::protobuf::Message* proto) { +absl::Status ReadTextProtoFromString(Env* env, const string& data, + ::tensorflow::protobuf::Message* proto) { if (!::tensorflow::protobuf::TextFormat::ParseFromString(data, proto)) { return errors::DataLoss("Can't parse input data as text proto"); } @@ -108,7 +108,7 @@ Status ReadTextProtoFromString(Env* env, const string& data, } } // namespace -Status AutoClusteringTest::RunAutoClusteringTestImpl( +absl::Status AutoClusteringTest::RunAutoClusteringTestImpl( GraphDef graphdef, absl::string_view golden_summary_file_path) { if (!IsGoogleCudaEnabled()) { // There is some slight change in the clustering decisions under @@ -162,7 +162,7 @@ Status AutoClusteringTest::RunAutoClusteringTestImpl( return absl::OkStatus(); } -Status AutoClusteringTest::RunAutoClusteringTestWithPbtxt( +absl::Status AutoClusteringTest::RunAutoClusteringTestWithPbtxt( absl::string_view pbtxt_file_path, absl::string_view golden_summary_file_path) { GraphDef graphdef; @@ -172,7 +172,7 @@ Status AutoClusteringTest::RunAutoClusteringTestWithPbtxt( golden_summary_file_path); } -Status AutoClusteringTest::RunAutoClusteringTestWithGzippedPbtxt( +absl::Status AutoClusteringTest::RunAutoClusteringTestWithGzippedPbtxt( absl::string_view gzipped_pbtxt_file_path, absl::string_view golden_summary_file_path) { Env* env = Env::Default(); @@ -187,7 +187,7 @@ Status AutoClusteringTest::RunAutoClusteringTestWithGzippedPbtxt( /*output_buffer_bytes=*/k_buffer_size, io::ZlibCompressionOptions::GZIP()); tstring decompressed_pbtxt_string; - Status s = in.ReadNBytes(INT_MAX, &decompressed_pbtxt_string); + absl::Status s = in.ReadNBytes(INT_MAX, &decompressed_pbtxt_string); if (!s.ok() && !errors::IsOutOfRange(s)) { // OutOfRange is fine since we set the number of read bytes to INT_MAX. // Only return other kinds of errors. @@ -202,8 +202,8 @@ Status AutoClusteringTest::RunAutoClusteringTestWithGzippedPbtxt( } #if defined(PLATFORM_GOOGLE) -Status BenchmarkMarkForCompilation(absl::string_view graph_def_path, - benchmark::State& state) { +absl::Status BenchmarkMarkForCompilation(absl::string_view graph_def_path, + benchmark::State& state) { GraphDef graph_def; TF_RETURN_IF_ERROR( ReadTextProto(Env::Default(), string(graph_def_path), &graph_def)); diff --git a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.h b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.h index 7f97ee0fe8136e..4750803c14b88a 100644 --- a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.h +++ b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.h @@ -47,23 +47,23 @@ namespace tensorflow { class AutoClusteringTest : public ::testing::Test { protected: - Status RunAutoClusteringTestWithPbtxt( + absl::Status RunAutoClusteringTestWithPbtxt( absl::string_view pbtxt_file_path, absl::string_view golden_summary_file_path); - Status RunAutoClusteringTestWithGzippedPbtxt( + absl::Status RunAutoClusteringTestWithGzippedPbtxt( absl::string_view gzipped_pbtxt_file_path, absl::string_view golden_summary_file_path); private: - Status RunAutoClusteringTestImpl(GraphDef graphdef, - absl::string_view golden_summary_file_path); + absl::Status RunAutoClusteringTestImpl( + GraphDef graphdef, absl::string_view golden_summary_file_path); }; #if defined(PLATFORM_GOOGLE) // Reads the GraphDef stored in graph_def_path (which must be a pbtxt file) and // benchmarks MarkForCompilationPass on this graphdef. -Status BenchmarkMarkForCompilation(absl::string_view graph_def_path, - benchmark::State& state); +absl::Status BenchmarkMarkForCompilation(absl::string_view graph_def_path, + benchmark::State& state); #endif // PLATFORM_GOOGLE } // namespace tensorflow diff --git a/tensorflow/compiler/jit/tests/device_compiler_test_helper.cc b/tensorflow/compiler/jit/tests/device_compiler_test_helper.cc index 4eb93e85819651..e4be1a1f641656 100644 --- a/tensorflow/compiler/jit/tests/device_compiler_test_helper.cc +++ b/tensorflow/compiler/jit/tests/device_compiler_test_helper.cc @@ -80,8 +80,8 @@ GraphDef DeviceCompilerSerializeTest::GetTestGraph( return graph; } -Status DeviceCompilerSerializeTest::ExecuteWithBatch(const GraphDef& graph, - int batch) { +absl::Status DeviceCompilerSerializeTest::ExecuteWithBatch( + const GraphDef& graph, int batch) { const TensorShape shape({batch, 4}); // Compute the golden output tensor @@ -134,7 +134,8 @@ Status DeviceCompilerSerializeTest::ExecuteWithBatch(const GraphDef& graph, return absl::OkStatus(); } -Status DeviceCompilerSerializeTest::AlterPersistentCacheEntryHloModuleNames( +absl::Status +DeviceCompilerSerializeTest::AlterPersistentCacheEntryHloModuleNames( absl::string_view persistent_cache_dir_path, absl::string_view file_prefix) { Env* env = Env::Default(); diff --git a/tensorflow/compiler/jit/tests/device_compiler_test_helper.h b/tensorflow/compiler/jit/tests/device_compiler_test_helper.h index 9cf36d0cbc6cb1..58e0a03456862b 100644 --- a/tensorflow/compiler/jit/tests/device_compiler_test_helper.h +++ b/tensorflow/compiler/jit/tests/device_compiler_test_helper.h @@ -31,24 +31,25 @@ namespace tensorflow { // A listener to inspect the use of XLA's persistent compilation cache entries. class JitCompilationListener : public XlaActivityListener { public: - Status Listen( + absl::Status Listen( const XlaAutoClusteringActivity& auto_clustering_activity) override { return absl::OkStatus(); } - Status Listen( + absl::Status Listen( const XlaJitCompilationActivity& jit_compilation_activity) override { activity_history_.push_back(jit_compilation_activity); return absl::OkStatus(); } - Status Listen(const XlaOptimizationRemark& optimization_remark) override { + absl::Status Listen( + const XlaOptimizationRemark& optimization_remark) override { return absl::OkStatus(); } ~JitCompilationListener() override = default; - Status VerifyPersistentCacheUseListenerHistory( + absl::Status VerifyPersistentCacheUseListenerHistory( bool expect_persistent_cache_use) { for (const auto& activity : activity_history_) { if (activity.used_persistent_cache() != expect_persistent_cache_use) { @@ -85,12 +86,12 @@ class DeviceCompilerSerializeTest : public ::testing::Test { // Runs the graph using specified batch size both with and without XLA JIT // compilation. Returns an error if the results between the two do not match. - Status ExecuteWithBatch(const GraphDef& graph, int batch); + absl::Status ExecuteWithBatch(const GraphDef& graph, int batch); // Adds the suffix "_altered" to the HLO module names of all of the persistent // XLA compilation cache entries found at the specified directory. If none are // found, returns NOT_FOUND error. - Status AlterPersistentCacheEntryHloModuleNames( + absl::Status AlterPersistentCacheEntryHloModuleNames( absl::string_view persistent_cache_dir_path, absl::string_view file_prefix = "xla_compile_cache"); diff --git a/tensorflow/compiler/jit/tf_graph_to_hlo_compiler.cc b/tensorflow/compiler/jit/tf_graph_to_hlo_compiler.cc index 052ed6b6f38508..3efea0181974be 100644 --- a/tensorflow/compiler/jit/tf_graph_to_hlo_compiler.cc +++ b/tensorflow/compiler/jit/tf_graph_to_hlo_compiler.cc @@ -19,15 +19,14 @@ limitations under the License. namespace tensorflow { -Status TfGraphToHloCompiler::Compile(const XlaCompiler::CompileOptions& options, - const NameAttrList& function, - absl::Span args, - XlaCompilationResult* result) { +absl::Status TfGraphToHloCompiler::Compile( + const XlaCompiler::CompileOptions& options, const NameAttrList& function, + absl::Span args, XlaCompilationResult* result) { return ADD_SOURCE_LOCATION( xla_compiler_.CompileFunction(options, function, args, result)); } -Status TfGraphToHloCompiler::CompileSingleOp( +absl::Status TfGraphToHloCompiler::CompileSingleOp( const XlaCompiler::CompileOptions& options, const OpKernelContext* ctx, absl::Span args, XlaCompilationResult* result) { return ADD_SOURCE_LOCATION(xla_compiler_.CompileSingleOp( diff --git a/tensorflow/compiler/jit/tf_graph_to_hlo_compiler.h b/tensorflow/compiler/jit/tf_graph_to_hlo_compiler.h index be7a7eb15e9912..adc2a74e1ac017 100644 --- a/tensorflow/compiler/jit/tf_graph_to_hlo_compiler.h +++ b/tensorflow/compiler/jit/tf_graph_to_hlo_compiler.h @@ -35,18 +35,18 @@ class TfGraphToHloCompiler : public TfToHloCompiler { // Compiles a Tensorflow `function` into an HloModuleProto stored in the // XlaCompilationResult pointed to by `result` by calling // XlaCompiler::CompileFunction. - Status Compile(const XlaCompiler::CompileOptions& options, - const NameAttrList& function, - absl::Span args, - XlaCompilationResult* result) override; + absl::Status Compile(const XlaCompiler::CompileOptions& options, + const NameAttrList& function, + absl::Span args, + XlaCompilationResult* result) override; // Compiles a Tensorflow single op into an HloModuleProto stored in the // XlaCompilationResult pointed to by `result` by calling // XlaCompiler::CompileSingleOp. - Status CompileSingleOp(const XlaCompiler::CompileOptions& options, - const OpKernelContext* ctx, - absl::Span args, - XlaCompilationResult* result) override; + absl::Status CompileSingleOp(const XlaCompiler::CompileOptions& options, + const OpKernelContext* ctx, + absl::Span args, + XlaCompilationResult* result) override; private: XlaCompiler xla_compiler_; diff --git a/tensorflow/compiler/jit/tf_to_hlo_compiler.h b/tensorflow/compiler/jit/tf_to_hlo_compiler.h index 0dd11c8f552a10..f9937a65147a2f 100644 --- a/tensorflow/compiler/jit/tf_to_hlo_compiler.h +++ b/tensorflow/compiler/jit/tf_to_hlo_compiler.h @@ -31,17 +31,16 @@ class TfToHloCompiler { // Compiles a Tensorflow `function` to an HloModuleProto stored in the // XlaCompilationResult pointed to by `result`. - virtual Status Compile(const XlaCompiler::CompileOptions& options, - const NameAttrList& function, - absl::Span args, - XlaCompilationResult* result) = 0; + virtual absl::Status Compile(const XlaCompiler::CompileOptions& options, + const NameAttrList& function, + absl::Span args, + XlaCompilationResult* result) = 0; // Compiles a Tensorflow single op to an HloModuleProto stored in the // XlaCompilationResult pointed to by `result`. - virtual Status CompileSingleOp(const XlaCompiler::CompileOptions& options, - const OpKernelContext* ctx, - absl::Span args, - XlaCompilationResult* result) = 0; + virtual absl::Status CompileSingleOp( + const XlaCompiler::CompileOptions& options, const OpKernelContext* ctx, + absl::Span args, XlaCompilationResult* result) = 0; private: TfToHloCompiler(const TfToHloCompiler&) = delete; diff --git a/tensorflow/compiler/jit/variable_info_util.cc b/tensorflow/compiler/jit/variable_info_util.cc index 315d5d63c73fc7..01b105cb47e2dd 100644 --- a/tensorflow/compiler/jit/variable_info_util.cc +++ b/tensorflow/compiler/jit/variable_info_util.cc @@ -35,19 +35,19 @@ limitations under the License. namespace tensorflow { -Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev, - absl::Span inputs, - absl::Span variable_indices, - std::vector* result) { +absl::Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev, + absl::Span inputs, + absl::Span variable_indices, + std::vector* result) { return GetVariableInfosFromInputs(rm, dev, inputs, variable_indices, nullptr, result); } -Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev, - absl::Span inputs, - absl::Span variable_indices, - const std::set* variables_updated, - std::vector* result) { +absl::Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev, + absl::Span inputs, + absl::Span variable_indices, + const std::set* variables_updated, + std::vector* result) { result->clear(); result->reserve(variable_indices.size()); for (int var_idx : variable_indices) { @@ -85,7 +85,7 @@ Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev, return absl::OkStatus(); } -Status LockVariables(absl::Span variables) { +absl::Status LockVariables(absl::Span variables) { std::vector lock_order(variables.size()); std::iota(lock_order.begin(), lock_order.end(), 0); @@ -137,7 +137,7 @@ Status LockVariables(absl::Span variables) { return absl::OkStatus(); } -Status LockVariables(absl::Span variables) { +absl::Status LockVariables(absl::Span variables) { std::vector variable_ptrs; variable_ptrs.reserve(variables.size()); for (auto& var : variables) { @@ -146,10 +146,10 @@ Status LockVariables(absl::Span variables) { return LockVariables(absl::MakeSpan(variable_ptrs)); } -Status SnapshotResourceVariables(OpKernelContext* ctx, - absl::Span variable_indices, - absl::Span variable_infos, - ResourceVarsSnapshot* result) { +absl::Status SnapshotResourceVariables( + OpKernelContext* ctx, absl::Span variable_indices, + absl::Span variable_infos, + ResourceVarsSnapshot* result) { for (int i = 0, end = variable_indices.size(); i < end; i++) { Var* var = variable_infos[i].var(); (*result)[variable_indices[i]] = @@ -168,7 +168,7 @@ std::vector GetResourceVariableIndicesFromContext(OpKernelContext* ctx) { return out; } -Status CreateVariableInfoLookup( +absl::Status CreateVariableInfoLookup( absl::Span variable_args, absl::flat_hash_map& variable_info_lookup) { for (const VariableInfo& info : variable_args) { diff --git a/tensorflow/compiler/jit/variable_info_util.h b/tensorflow/compiler/jit/variable_info_util.h index 386a4f3755633c..ac825d14687834 100644 --- a/tensorflow/compiler/jit/variable_info_util.h +++ b/tensorflow/compiler/jit/variable_info_util.h @@ -43,10 +43,10 @@ using ResourceVarsSnapshot = absl::flat_hash_map>; // We snapshot the entire set of resource variables as one atomic operation. // This models Read->* dependencies between resource variable operations. See // jit/resource_operation_safety_analysis for details. -Status SnapshotResourceVariables(OpKernelContext* ctx, - absl::Span variable_indices, - absl::Span variable_infos, - ResourceVarsSnapshot* result); +absl::Status SnapshotResourceVariables( + OpKernelContext* ctx, absl::Span variable_indices, + absl::Span variable_infos, + ResourceVarsSnapshot* result); // Acquires the mutexes for all the variables in `variables` using a // deadlock-safe protocol (acquire the mutexes in increasing-address order). @@ -55,9 +55,9 @@ Status SnapshotResourceVariables(OpKernelContext* ctx, // variable (i.e. variables[i].var() can be null for some i). // // If the variable is read_only(), only acquires reader locks. -Status LockVariables(absl::Span variables) +absl::Status LockVariables(absl::Span variables) TF_EXCLUSIVE_LOCK_FUNCTION(); -Status LockVariables(absl::Span variables) +absl::Status LockVariables(absl::Span variables) TF_EXCLUSIVE_LOCK_FUNCTION(); // Returns a vector of VariableInfo instances for the resource variable inputs, @@ -66,25 +66,25 @@ Status LockVariables(absl::Span variables) // // When using the VariableInfos generated by this version, all variables would // be writer-locked. -Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev, - absl::Span inputs, - absl::Span variable_indices, - std::vector* result); +absl::Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev, + absl::Span inputs, + absl::Span variable_indices, + std::vector* result); // variables_updated is a set containing the indices of the variables that are // going to be mutated. If variables_updated is empty, then in LockVariables all // variables would only be reader-locked. If variables_updated is null, then we // consider this information unknown and will acquire writer-lock for all // variables. -Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev, - absl::Span inputs, - absl::Span variable_indices, - const std::set* variables_updated, - std::vector* result); +absl::Status GetVariableInfosFromInputs(ResourceMgr* rm, DeviceBase* dev, + absl::Span inputs, + absl::Span variable_indices, + const std::set* variables_updated, + std::vector* result); std::vector GetResourceVariableIndicesFromContext(OpKernelContext* ctx); -Status CreateVariableInfoLookup( +absl::Status CreateVariableInfoLookup( absl::Span variable_args, absl::flat_hash_map& variable_info_lookup); diff --git a/tensorflow/compiler/jit/xla_activity_listener.cc b/tensorflow/compiler/jit/xla_activity_listener.cc index c3df741fbb08e2..471e6813466213 100644 --- a/tensorflow/compiler/jit/xla_activity_listener.cc +++ b/tensorflow/compiler/jit/xla_activity_listener.cc @@ -39,7 +39,7 @@ XlaActivityListenerList* GetXlaActivityListenerList() { } template -Status ForEachListener(FnTy fn) { +absl::Status ForEachListener(FnTy fn) { XlaActivityListenerList* listener_list = GetXlaActivityListenerList(); absl::ReaderMutexLock reader_lock(&listener_list->mutex); @@ -52,7 +52,7 @@ Status ForEachListener(FnTy fn) { } void FlushAllListeners() { - Status s = ForEachListener([](XlaActivityListener* listener) { + absl::Status s = ForEachListener([](XlaActivityListener* listener) { listener->Flush(); return absl::OkStatus(); }); @@ -60,28 +60,29 @@ void FlushAllListeners() { } } // namespace -Status BroadcastXlaActivity( +absl::Status BroadcastXlaActivity( XlaAutoClusteringActivity auto_clustering_activity) { return ForEachListener([&](XlaActivityListener* listener) { return listener->Listen(auto_clustering_activity); }); } -Status BroadcastXlaActivity( +absl::Status BroadcastXlaActivity( XlaJitCompilationActivity jit_compilation_activity) { return ForEachListener([&](XlaActivityListener* listener) { return listener->Listen(jit_compilation_activity); }); } -Status BroadcastOptimizationRemark(XlaOptimizationRemark optimization_remark) { +absl::Status BroadcastOptimizationRemark( + XlaOptimizationRemark optimization_remark) { VLOG(2) << "OptimizationRemark: " << optimization_remark.DebugString(); return ForEachListener([&](XlaActivityListener* listener) { return listener->Listen(optimization_remark); }); } -Status BroadcastOptimizationRemark( +absl::Status BroadcastOptimizationRemark( XlaOptimizationRemark::Warning optimization_warning, string debug_information) { XlaOptimizationRemark remark; diff --git a/tensorflow/compiler/jit/xla_activity_listener.h b/tensorflow/compiler/jit/xla_activity_listener.h index 05328c896d34f4..d8be8309045b3a 100644 --- a/tensorflow/compiler/jit/xla_activity_listener.h +++ b/tensorflow/compiler/jit/xla_activity_listener.h @@ -22,18 +22,21 @@ limitations under the License. namespace tensorflow { // Broadcast `auto_clustering_activity` to all the registered listeners. -Status BroadcastXlaActivity(XlaAutoClusteringActivity auto_clustering_activity); +absl::Status BroadcastXlaActivity( + XlaAutoClusteringActivity auto_clustering_activity); // Broadcast `jit_compilation_activity` to all the registered listeners. -Status BroadcastXlaActivity(XlaJitCompilationActivity jit_compilation_activity); +absl::Status BroadcastXlaActivity( + XlaJitCompilationActivity jit_compilation_activity); // Broadcast `jit_compilation_activity` to all the registered listeners. -Status BroadcastOptimizationRemark(XlaOptimizationRemark optimization_remark); +absl::Status BroadcastOptimizationRemark( + XlaOptimizationRemark optimization_remark); // LINT.IfChange // Called after TensorFlow realizes possible lost performance. The parameters in // this should match all of the values in the XlaOptimizationRemark proto. -Status BroadcastOptimizationRemark( +absl::Status BroadcastOptimizationRemark( XlaOptimizationRemark::Warning optimization_warning, string debug_information); @@ -46,15 +49,16 @@ Status BroadcastOptimizationRemark( class XlaActivityListener { public: // Called after TensorFlow auto-clusters a graph. - virtual Status Listen( + virtual absl::Status Listen( const XlaAutoClusteringActivity& auto_clustering_activity) = 0; // Called after TensorFlow JIT compiles an XLA cluster. - virtual Status Listen( + virtual absl::Status Listen( const XlaJitCompilationActivity& jit_compilation_activity) = 0; // Called after TensorFlow realizes possible lost performance. - virtual Status Listen(const XlaOptimizationRemark& optimization_remark) = 0; + virtual absl::Status Listen( + const XlaOptimizationRemark& optimization_remark) = 0; // Called at program exit in best-effort manner to give listeners a chance to // flush their state. diff --git a/tensorflow/compiler/jit/xla_activity_listener_test.cc b/tensorflow/compiler/jit/xla_activity_listener_test.cc index ee58c280d66d80..3a678b78f4e22f 100644 --- a/tensorflow/compiler/jit/xla_activity_listener_test.cc +++ b/tensorflow/compiler/jit/xla_activity_listener_test.cc @@ -31,19 +31,20 @@ namespace { class TestListener : public XlaActivityListener { public: - Status Listen( + absl::Status Listen( const XlaAutoClusteringActivity& auto_clustering_activity) override { auto_clustering_activity_ = auto_clustering_activity; return absl::OkStatus(); } - Status Listen( + absl::Status Listen( const XlaJitCompilationActivity& jit_compilation_activity) override { jit_compilation_activity_ = jit_compilation_activity; return absl::OkStatus(); } - Status Listen(const XlaOptimizationRemark& optimization_remark) override { + absl::Status Listen( + const XlaOptimizationRemark& optimization_remark) override { return absl::OkStatus(); } diff --git a/tensorflow/compiler/jit/xla_activity_logging_listener.cc b/tensorflow/compiler/jit/xla_activity_logging_listener.cc index 20262548e8bc2b..141a5c4c31b302 100644 --- a/tensorflow/compiler/jit/xla_activity_logging_listener.cc +++ b/tensorflow/compiler/jit/xla_activity_logging_listener.cc @@ -23,7 +23,7 @@ namespace { // Listens to XLA activity and logs them using tensorflow::Logger. class XlaActivityLoggingListener final : public XlaActivityListener { public: - Status Listen( + absl::Status Listen( const XlaAutoClusteringActivity& auto_clustering_activity) override { if (!IsEnabled()) { VLOG(3) << "Logging XlaAutoClusteringActivity disabled"; @@ -33,7 +33,7 @@ class XlaActivityLoggingListener final : public XlaActivityListener { return absl::OkStatus(); } - Status Listen( + absl::Status Listen( const XlaJitCompilationActivity& jit_compilation_activity) override { if (!IsEnabled()) { VLOG(3) << "Logging XlaJitCompilationActivity disabled"; @@ -43,7 +43,8 @@ class XlaActivityLoggingListener final : public XlaActivityListener { return absl::OkStatus(); } - Status Listen(const XlaOptimizationRemark& optimization_remark) override { + absl::Status Listen( + const XlaOptimizationRemark& optimization_remark) override { if (!IsEnabled()) { VLOG(3) << "Logging XlaJitCompilationActivity disabled"; return absl::OkStatus(); diff --git a/tensorflow/compiler/jit/xla_cluster_util.cc b/tensorflow/compiler/jit/xla_cluster_util.cc index 79e973c4f1adc4..6d7e5518524c29 100644 --- a/tensorflow/compiler/jit/xla_cluster_util.cc +++ b/tensorflow/compiler/jit/xla_cluster_util.cc @@ -202,7 +202,7 @@ std::optional GetXlaClusterForNode(const Node& node) { if (attr_value == nullptr) { return std::nullopt; } - Status s = AttrValueHasType(*attr_value, "string"); + absl::Status s = AttrValueHasType(*attr_value, "string"); if (!s.ok()) { return std::nullopt; } @@ -420,7 +420,7 @@ CallTargetListTy GetCallTargetListFromNode( enum class Direction { kForward, kBackward }; -Status GetNodesRelatedToRefVariablesInDirection( +absl::Status GetNodesRelatedToRefVariablesInDirection( const Graph& graph, FunctionLibraryRuntime* lib_runtime, Direction direction, int depth, absl::flat_hash_set* result); @@ -480,7 +480,7 @@ absl::StatusOr DoesAnyCalleeHaveRefNodes( // Helper for GetNodesRelatedToRefVariables that traverses the graph in one // direction. -Status GetNodesRelatedToRefVariablesInDirection( +absl::Status GetNodesRelatedToRefVariablesInDirection( const Graph& graph, FunctionLibraryRuntime* lib_runtime, Direction direction, int depth, absl::flat_hash_set* result) { std::vector nodes_in_order; diff --git a/tensorflow/compiler/jit/xla_cluster_util_test.cc b/tensorflow/compiler/jit/xla_cluster_util_test.cc index ecb96992836fd9..808931c10714ae 100644 --- a/tensorflow/compiler/jit/xla_cluster_util_test.cc +++ b/tensorflow/compiler/jit/xla_cluster_util_test.cc @@ -255,7 +255,7 @@ TEST(NodesRelatedToRefVariables, Basic) { EXPECT_EQ(names, expected); } -Status MakeLoop(Scope s, Output init_value, absl::string_view loop_name) { +absl::Status MakeLoop(Scope s, Output init_value, absl::string_view loop_name) { s = s.NewSubScope(std::string(loop_name)); ops::internal::Enter enter(s.WithOpName("init_value"), init_value, loop_name); ops::Merge merge(s.WithOpName("merge"), {init_value, init_value}); diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index 4135f7de7850c6..d8b5d9bf24ff03 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -83,7 +83,7 @@ XlaCompiler::CompileOptions GetCompileOptions(bool for_pjrt = false) { // Gets `variables` from `ctx`, locks them and builds XlaCompiler::Arguments // using them. Stores the arguments in `args`. `variables` and `args` passed in // will be cleared before populating them. -Status GetAndLockVariablesAndBuildXlaCompilerArguments( +absl::Status GetAndLockVariablesAndBuildXlaCompilerArguments( const OpKernelContext& ctx, const std::vector& inputs, const std::vector& constant_indices, const std::vector& variable_indices, @@ -103,11 +103,11 @@ Status GetAndLockVariablesAndBuildXlaCompilerArguments( } } // namespace -Status XlaCompileOnDemandOp::Run(const ResourceVarsSnapshot& variable_args, - const XlaCompiler::CompilationResult* result, - const XlaDeviceCompiler* xla_device_compiler, - xla::LocalExecutable* executable, - OpKernelContext* ctx) { +absl::Status XlaCompileOnDemandOp::Run( + const ResourceVarsSnapshot& variable_args, + const XlaCompiler::CompilationResult* result, + const XlaDeviceCompiler* xla_device_compiler, + xla::LocalExecutable* executable, OpKernelContext* ctx) { xla::LocalClient* client = static_cast(xla_device_compiler->client()); @@ -167,7 +167,7 @@ Status XlaCompileOnDemandOp::Run(const ResourceVarsSnapshot& variable_args, return absl::OkStatus(); } -Status XlaCompileOnDemandOp::Compile( +absl::Status XlaCompileOnDemandOp::Compile( const std::vector& args, OpKernelContext* ctx, PjRtDeviceCompiler** pjrt_device_compiler, DeviceCompilationProfiler** profiler, @@ -189,7 +189,7 @@ Status XlaCompileOnDemandOp::Compile( result, executable); } -Status XlaCompileOnDemandOp::Compile( +absl::Status XlaCompileOnDemandOp::Compile( const std::vector& args, OpKernelContext* ctx, XlaDeviceCompiler** xla_device_compiler, DeviceCompilationProfiler** profiler, diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.h b/tensorflow/compiler/jit/xla_compile_on_demand_op.h index d8dead2737d794..dfe9ddaa8ac3a7 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.h +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.h @@ -47,27 +47,27 @@ class XlaCompileOnDemandOp : public OpKernel { void Compute(OpKernelContext* ctx) override; private: - Status Compile(const std::vector& args, - OpKernelContext* ctx, - DeviceCompiler** - xla_device_compiler, - DeviceCompilationProfiler** profiler, - const XlaCompiler::CompilationResult** result, - xla::LocalExecutable** executable); + absl::Status Compile(const std::vector& args, + OpKernelContext* ctx, + DeviceCompiler** + xla_device_compiler, + DeviceCompilationProfiler** profiler, + const XlaCompiler::CompilationResult** result, + xla::LocalExecutable** executable); - Status Compile(const std::vector& args, - OpKernelContext* ctx, - DeviceCompiler** - pjrt_device_compiler, - DeviceCompilationProfiler** profiler, - const XlaCompiler::CompilationResult** result, - xla::PjRtLoadedExecutable** executable); + absl::Status Compile(const std::vector& args, + OpKernelContext* ctx, + DeviceCompiler** pjrt_device_compiler, + DeviceCompilationProfiler** profiler, + const XlaCompiler::CompilationResult** result, + xla::PjRtLoadedExecutable** executable); - Status Run(const ResourceVarsSnapshot& variable_args, - const XlaCompiler::CompilationResult* result, - const DeviceCompiler* - xla_device_compiler, - xla::LocalExecutable* executable, OpKernelContext* ctx); + absl::Status Run(const ResourceVarsSnapshot& variable_args, + const XlaCompiler::CompilationResult* result, + const DeviceCompiler* + xla_device_compiler, + xla::LocalExecutable* executable, OpKernelContext* ctx); const XlaPlatformInfo platform_info_; }; diff --git a/tensorflow/compiler/jit/xla_compile_util.cc b/tensorflow/compiler/jit/xla_compile_util.cc index 05feb2c8f36769..ab73bb4e188dfd 100644 --- a/tensorflow/compiler/jit/xla_compile_util.cc +++ b/tensorflow/compiler/jit/xla_compile_util.cc @@ -51,7 +51,7 @@ absl::StatusOr> CreateSingleOpGraph( for (int64_t i = 0, end = args.size(); i < end; ++i) { Node* node; string arg_name = absl::StrCat("_arg", i); - Status status = + absl::Status status = NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp) .ControlInput(graph->source_node()) .Attr("T", args[i].kind == XlaArgument::kResource ? DT_RESOURCE @@ -66,11 +66,12 @@ absl::StatusOr> CreateSingleOpGraph( for (int64_t i = 0, end = result_types.size(); i < end; ++i) { Node* node; string retval_name = absl::StrCat("_retval", i); - Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp) - .Input(main_node, i) - .Attr("T", result_types[i]) - .Attr("index", i) - .Finalize(graph.get(), &node); + absl::Status status = + NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp) + .Input(main_node, i) + .Attr("T", result_types[i]) + .Attr("index", i) + .Finalize(graph.get(), &node); TF_RETURN_IF_ERROR(status); } FixupSourceAndSinkEdges(graph.get()); diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 4c41805d034a0e..092d4d0891678e 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -36,12 +36,14 @@ using tensorflow::IdentityShapeRepresentationFn; class XlaCpuDeviceFactory : public DeviceFactory { public: - Status ListPhysicalDevices(std::vector* devices) override; - Status CreateDevices(const SessionOptions& options, const string& name_prefix, - std::vector>* devices) override; + absl::Status ListPhysicalDevices(std::vector* devices) override; + absl::Status CreateDevices( + const SessionOptions& options, const string& name_prefix, + std::vector>* devices) override; }; -Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector* devices) { +absl::Status XlaCpuDeviceFactory::ListPhysicalDevices( + std::vector* devices) { XlaDeviceFlags* flags = GetXlaDeviceFlags(); if (!flags->tf_xla_enable_xla_devices && !XlaDevicesCreationRequired()) { VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set " @@ -53,7 +55,7 @@ Status XlaCpuDeviceFactory::ListPhysicalDevices(std::vector* devices) { return absl::OkStatus(); } -Status XlaCpuDeviceFactory::CreateDevices( +absl::Status XlaCpuDeviceFactory::CreateDevices( const SessionOptions& session_options, const string& name_prefix, std::vector>* devices) { XlaDeviceFlags* flags = GetXlaDeviceFlags(); @@ -104,7 +106,7 @@ Status XlaCpuDeviceFactory::CreateDevices( // tensorflow_accelerator_device_info() == nullptr is used as an IsCPU test. // We need XlaCpuDevice to be treated not as CPU because it allocates // XlaTensors, not regular Tensors. - Status status = device->UseAcceleratorDeviceInfo(); + absl::Status status = device->UseAcceleratorDeviceInfo(); if (!status.ok()) { errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT); return status; diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 471f54571d2b53..dcc661e4f73cf5 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -65,7 +65,7 @@ namespace tensorflow { // Default PaddedShapeFn implementation that simply returns the unpadded // on-device shape. This is accurate for CPU and GPU devices that neither // transpose nor pad tensors. -Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { +absl::Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { const tensorflow::XlaTensor* xla_tensor = tensorflow::XlaTensor::FromTensor(&tensor); if (xla_tensor == nullptr) { @@ -169,7 +169,7 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { return device_type_; } -/*static*/ Status XlaDevice::GetMetadataFromDevice( +/*static*/ absl::Status XlaDevice::GetMetadataFromDevice( DeviceBase* device, const XlaDevice::Metadata** metadata) { *metadata = nullptr; XlaDevice* xla_device = dynamic_cast(device->UnderlyingDevice()); @@ -184,13 +184,13 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { return absl::OkStatus(); } -/* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx, - const Metadata** metadata) { +/* static */ absl::Status XlaDevice::GetMetadata(OpKernelContext* ctx, + const Metadata** metadata) { return GetMetadataFromDevice(ctx->device(), metadata); } -/* static */ Status XlaDevice::GetMetadata(OpKernelConstruction* ctx, - const Metadata** metadata) { +/* static */ absl::Status XlaDevice::GetMetadata(OpKernelConstruction* ctx, + const Metadata** metadata) { return GetMetadataFromDevice(ctx->device(), metadata); } @@ -287,15 +287,14 @@ Allocator* XlaDevice::GetAllocatorLocked(AllocatorAttributes attr) { return xla_allocator_; } -Status XlaDevice::EnsureDeviceContextOk() { +absl::Status XlaDevice::EnsureDeviceContextOk() { mutex_lock lock(mu_); return GetDeviceContextLocked().status(); } -Status XlaDevice::EnsureStreamOkLocked(xla::Backend* backend, - const string& name, - std::shared_ptr* stream, - bool* stream_was_changed) { +absl::Status XlaDevice::EnsureStreamOkLocked( + xla::Backend* backend, const string& name, + std::shared_ptr* stream, bool* stream_was_changed) { if (!(*stream) || !(*stream)->ok()) { xla::StreamPool::Ptr ptr; TF_ASSIGN_OR_RETURN(ptr, backend->BorrowStream(device_ordinal_)); @@ -438,13 +437,13 @@ absl::StatusOr XlaDevice::GetDeviceContextDefault() { return GetDeviceContextWithIndex(0); } -Status XlaDevice::UseAcceleratorDeviceInfo() { +absl::Status XlaDevice::UseAcceleratorDeviceInfo() { mutex_lock lock(mu_); use_accelerator_device_info_ = true; return GetDeviceContextLocked().status(); } -Status XlaDevice::TryGetDeviceContext(DeviceContext** out_context) { +absl::Status XlaDevice::TryGetDeviceContext(DeviceContext** out_context) { TF_ASSIGN_OR_RETURN(auto device_context, GetDeviceContextDefault()); device_context->Ref(); *out_context = device_context; @@ -482,7 +481,7 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, op_kernel->ComputeAsync(context, done); } -Status XlaDevice::Sync() { +absl::Status XlaDevice::Sync() { VLOG(1) << "XlaDevice::Sync"; tsl::profiler::TraceMe activity("XlaDevice::Sync", tsl::profiler::TraceMeLevel::kInfo); @@ -493,7 +492,7 @@ Status XlaDevice::Sync() { } if (!stream) return absl::OkStatus(); - Status status = stream->BlockHostUntilDone(); + absl::Status status = stream->BlockHostUntilDone(); TF_RETURN_IF_ERROR(status); if (!stream->ok()) { return errors::Internal("XlaDevice::Sync() failed."); @@ -502,17 +501,16 @@ Status XlaDevice::Sync() { return absl::OkStatus(); } -Status XlaDevice::MakeTensorFromProto(DeviceContext* device_context, - const TensorProto& tensor_proto, - const AllocatorAttributes alloc_attrs, - Tensor* tensor) { +absl::Status XlaDevice::MakeTensorFromProto( + DeviceContext* device_context, const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, Tensor* tensor) { Tensor parsed(tensor_proto.dtype()); if (!parsed.FromProto(cpu_allocator(), tensor_proto)) { return errors::InvalidArgument("Cannot parse tensor from proto: ", tensor_proto.DebugString()); } - Status status; + absl::Status status; if (alloc_attrs.on_host()) { *tensor = parsed; } else { @@ -530,9 +528,9 @@ Status XlaDevice::MakeTensorFromProto(DeviceContext* device_context, return status; } -Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, - const AllocatorAttributes alloc_attrs, - Tensor* tensor) { +absl::Status XlaDevice::MakeTensorFromProto( + const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, + Tensor* tensor) { VLOG(1) << "XlaDevice::MakeTensorFromProto"; DeviceContext* device_context; TF_ASSIGN_OR_RETURN(device_context, GetDeviceContextDefault()); @@ -549,13 +547,14 @@ bool XlaDevice::AllowsSyncOnCompletion() const { return sync_on_completion_; } -void XlaDevice::SetHandleDeviceErrorCallback(std::function callback) { +void XlaDevice::SetHandleDeviceErrorCallback( + std::function callback) { mutex_lock lock(mu_); device_error_callback_ = callback; } -Status XlaDevice::HandleDeviceError() { - std::function local_device_error_callback; +absl::Status XlaDevice::HandleDeviceError() { + std::function local_device_error_callback; { mutex_lock lock(mu_); local_device_error_callback = device_error_callback_; @@ -566,7 +565,7 @@ Status XlaDevice::HandleDeviceError() { return absl::OkStatus(); } -Status XlaDevice::RefreshStatus() { +absl::Status XlaDevice::RefreshStatus() { std::shared_ptr stream; { mutex_lock lock(mu_); @@ -575,7 +574,7 @@ Status XlaDevice::RefreshStatus() { if (!stream) { return absl::OkStatus(); } - Status status = stream->RefreshStatus(); + absl::Status status = stream->RefreshStatus(); if (!status.ok()) { // Ignore errors from HandleDeviceError, since by definition the status is // already non-ok, so there's nothing extra to report if HandleDeviceError diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 64f3abbeca7a45..cbaa97dc15e1c0 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -53,7 +53,7 @@ class XlaDevice : public LocalDevice { // Given a tensor, sets `xla::Shape*` the shape of tensor's representation // on device, fully padded. On error, the contents of `xla::Shape*` // are undefined. - typedef std::function PaddedShapeFn; + typedef std::function PaddedShapeFn; // Wrapper class to store metadata about the XlaDevice, where it can be // retrieved e.g., when lazily creating the XlaCompilationCache device. @@ -93,16 +93,17 @@ class XlaDevice : public LocalDevice { }; // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`. - static Status GetMetadata(OpKernelContext* ctx, const Metadata** metadata); + static absl::Status GetMetadata(OpKernelContext* ctx, + const Metadata** metadata); // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by `ctx`. - static Status GetMetadata(OpKernelConstruction* ctx, - const Metadata** metadata); + static absl::Status GetMetadata(OpKernelConstruction* ctx, + const Metadata** metadata); // Sets `*metadata` to the XlaDevice Metadata in the XLA device used by // `device`. - static Status GetMetadataFromDevice(DeviceBase* device, - const XlaDevice::Metadata** metadata); + static absl::Status GetMetadataFromDevice( + DeviceBase* device, const XlaDevice::Metadata** metadata); struct Options { // The StreamExecutor platform. Not owned. Must be non-null. @@ -157,19 +158,20 @@ class XlaDevice : public LocalDevice { void Compute(OpKernel* op_kernel, OpKernelContext* context) override; void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) override; - Status Sync() override; + absl::Status Sync() override; - Status TryGetDeviceContext(DeviceContext** out_context) override + absl::Status TryGetDeviceContext(DeviceContext** out_context) override TF_LOCKS_EXCLUDED(mu_); - Status MakeTensorFromProto(const TensorProto& tensor_proto, - const AllocatorAttributes alloc_attrs, - Tensor* tensor) override TF_LOCKS_EXCLUDED(mu_); + absl::Status MakeTensorFromProto(const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor) override + TF_LOCKS_EXCLUDED(mu_); - Status MakeTensorFromProto(DeviceContext* device_context, - const TensorProto& tensor_proto, - const AllocatorAttributes alloc_attrs, - Tensor* tensor); + absl::Status MakeTensorFromProto(DeviceContext* device_context, + const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor); const Metadata& metadata() { return xla_metadata_; } @@ -179,7 +181,7 @@ class XlaDevice : public LocalDevice { // // TODO(b/111859745): The Eager context needs to call this method to recover // from failures. - Status EnsureDeviceContextOk() TF_LOCKS_EXCLUDED(mu_); + absl::Status EnsureDeviceContextOk() TF_LOCKS_EXCLUDED(mu_); // Two convenient methods to get the underlying device context. // Get the default device context, created by the first @@ -190,7 +192,7 @@ class XlaDevice : public LocalDevice { // Instructs this XlaDevice to set a AcceleratorDeviceInfo, which holds extra // information for GPU and TPU devices. - Status UseAcceleratorDeviceInfo() TF_LOCKS_EXCLUDED(mu_); + absl::Status UseAcceleratorDeviceInfo() TF_LOCKS_EXCLUDED(mu_); // Instructs this XlaDevice to return 'sync_on_completion' for // AllowsSyncOnCompletion(). @@ -199,17 +201,17 @@ class XlaDevice : public LocalDevice { bool AllowsSyncOnCompletion() const override TF_LOCKS_EXCLUDED(mu_); // Installs an error handling callback when RefreshStatus sees !status.ok(). - void SetHandleDeviceErrorCallback(std::function callback); + void SetHandleDeviceErrorCallback(std::function callback); - Status RefreshStatus() override TF_LOCKS_EXCLUDED(mu_); + absl::Status RefreshStatus() override TF_LOCKS_EXCLUDED(mu_); private: absl::StatusOr GetOrCreateClient() const; Allocator* GetAllocatorLocked(AllocatorAttributes attr) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); - Status EnsureStreamOkLocked(xla::Backend* backend, const string& name, - std::shared_ptr* stream, - bool* stream_was_changed) + absl::Status EnsureStreamOkLocked(xla::Backend* backend, const string& name, + std::shared_ptr* stream, + bool* stream_was_changed) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Return a vector of device context, ordered by the sequence in the given @@ -218,7 +220,7 @@ class XlaDevice : public LocalDevice { TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Handles error when RefreshStatus sees !status.ok(). - Status HandleDeviceError(); + absl::Status HandleDeviceError(); mutable mutex mu_; // The metadata of this XlaDevice. @@ -279,7 +281,7 @@ class XlaDevice : public LocalDevice { bool sync_on_completion_ TF_GUARDED_BY(mu_) = true; // A callback that will be invoked when RefreshStatus sees a status error. - std::function device_error_callback_ TF_GUARDED_BY(mu_); + std::function device_error_callback_ TF_GUARDED_BY(mu_); // Set of devices to use. This controls which of the devices on the given // platform will have resources allocated. For GPUs this will be @@ -311,7 +313,7 @@ XlaDeviceOpRegistrations* RegisterXlaDeviceKernels( XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, const char* jit_device); -Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape); +absl::Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape); } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index faf3b65d407a7e..2cecba9c3da5ec 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -136,7 +136,7 @@ void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, XlaLayoutPreference layout_preference = shape_determination_fns_.layout_preference_fn( device_tensor->shape(), device_tensor->dtype(), std::nullopt); - Status status = [&]() -> Status { + absl::Status status = [&]() -> absl::Status { TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_determination_fns_.shape_representation_fn( device_tensor->shape(), device_tensor->dtype(), @@ -263,7 +263,7 @@ void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, device_to_host_stream.get(), xla_tensor->shaped_buffer(), literal, [this, ref, xla_tensor, done, device_to_host_stream, device_allows_sync_on_completion](absl::Status status) { - Status done_status = status; + absl::Status done_status = status; VLOG(2) << "Transfer from device as literal: " << xla_tensor->shaped_buffer().ToString(); // For devices don't allow sync on completion, the device execution is @@ -300,9 +300,9 @@ se::Stream* XlaDeviceContext::GetDeviceToDeviceStream() { return device_to_device_stream(stream); } -Status XlaDeviceContext::ThenExecute(Device* device, - stream_executor::Stream* stream, - std::function func) { +absl::Status XlaDeviceContext::ThenExecute(Device* device, + stream_executor::Stream* stream, + std::function func) { VLOG(2) << "XlaDeviceContext::ThenExecute"; return stream->DoHostCallback(std::move(func)); } diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index a0c72517cd5ebe..4e8a769eae0f9b 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -89,8 +89,8 @@ class XlaDeviceContext : public DeviceContext { // Returns a device-to-device stream, in round-robin fashion. se::Stream* GetDeviceToDeviceStream(); - Status ThenExecute(Device* device, stream_executor::Stream* stream, - std::function func) override; + absl::Status ThenExecute(Device* device, stream_executor::Stream* stream, + std::function func) override; private: bool UseMultipleStreams() const { return stream_ != host_to_device_stream_; } diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index a16415ececc035..64f98698ccd951 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -39,12 +39,14 @@ namespace tensorflow { class XlaGpuDeviceFactory : public DeviceFactory { public: - Status ListPhysicalDevices(std::vector* devices) override; - Status CreateDevices(const SessionOptions& options, const string& name_prefix, - std::vector>* devices) override; + absl::Status ListPhysicalDevices(std::vector* devices) override; + absl::Status CreateDevices( + const SessionOptions& options, const string& name_prefix, + std::vector>* devices) override; }; -Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector* devices) { +absl::Status XlaGpuDeviceFactory::ListPhysicalDevices( + std::vector* devices) { XlaDeviceFlags* flags = GetXlaDeviceFlags(); if (!flags->tf_xla_enable_xla_devices && !XlaDevicesCreationRequired()) { VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set " @@ -72,7 +74,7 @@ Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector* devices) { return absl::OkStatus(); } -Status XlaGpuDeviceFactory::CreateDevices( +absl::Status XlaGpuDeviceFactory::CreateDevices( const SessionOptions& session_options, const string& name_prefix, std::vector>* devices) { XlaDeviceFlags* flags = GetXlaDeviceFlags(); @@ -139,7 +141,7 @@ Status XlaGpuDeviceFactory::CreateDevices( options.shape_determination_fns = {shape_representation_fns}; auto device = std::make_unique(session_options, options); - Status status = device->UseAcceleratorDeviceInfo(); + absl::Status status = device->UseAcceleratorDeviceInfo(); if (!status.ok()) { LOG(INFO) << "Ignoring visible " << DEVICE_GPU_XLA_JIT << " device. Device number is " << i << ", reason: " << status; diff --git a/tensorflow/compiler/jit/xla_host_recv_device_context.cc b/tensorflow/compiler/jit/xla_host_recv_device_context.cc index ae3c149d5d1387..479abe923e0fb8 100644 --- a/tensorflow/compiler/jit/xla_host_recv_device_context.cc +++ b/tensorflow/compiler/jit/xla_host_recv_device_context.cc @@ -24,7 +24,7 @@ void XlaHostRecvDeviceContext::CopyDeviceTensorToCPU( Tensor* cpu_tensor, StatusCallback done) { DataType dtype = EncodePrimitiveTypeAsDataType(shape_.element_type()).value(); TensorShape tensor_shape; - Status status = XLAShapeToTensorShape(shape_, &tensor_shape); + absl::Status status = XLAShapeToTensorShape(shape_, &tensor_shape); if (!status.ok()) { done(status); return; diff --git a/tensorflow/compiler/jit/xla_kernel_creator.cc b/tensorflow/compiler/jit/xla_kernel_creator.cc index 337a93a5bb3a62..f180403305d9d5 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator.cc @@ -50,9 +50,9 @@ bool XlaKernelCreator::CanCreateKernel( !XlaOpRegistry::IsCompilationDevice(flr.device()->device_type()); } -static Status CreateXlaKernel(FunctionLibraryRuntime* flr, - const NodeDef& node_def, - std::unique_ptr* kernel) { +static absl::Status CreateXlaKernel(FunctionLibraryRuntime* flr, + const NodeDef& node_def, + std::unique_ptr* kernel) { if (!CanCreateXlaKernel(node_def)) { return errors::Internal("Invalid node: ", node_def.ShortDebugString()); } @@ -77,7 +77,7 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr, // Create the kernel. Device* dev = flr->device(); - Status s; + absl::Status s; auto props = std::make_shared( &fbody->record->fdef().signature(), node_def, fbody->arg_types, fbody->ret_types); @@ -93,7 +93,7 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr, return s; } -Status XlaKernelCreator::CreateKernel( +absl::Status XlaKernelCreator::CreateKernel( FunctionLibraryRuntime* flr, const std::shared_ptr& props, std::unique_ptr* kernel) const { diff --git a/tensorflow/compiler/jit/xla_kernel_creator.h b/tensorflow/compiler/jit/xla_kernel_creator.h index 843a21acd19176..67c843bdb5cf72 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator.h +++ b/tensorflow/compiler/jit/xla_kernel_creator.h @@ -37,9 +37,9 @@ class XlaKernelCreator : public CustomKernelCreator { const std::shared_ptr& props) const override; // Given a supported NodeDef, returns a XlaLaunchOp that computes the node. - Status CreateKernel(FunctionLibraryRuntime* flr, - const std::shared_ptr& props, - std::unique_ptr* kernel) const override; + absl::Status CreateKernel(FunctionLibraryRuntime* flr, + const std::shared_ptr& props, + std::unique_ptr* kernel) const override; }; bool RegisterLaunchOpCreator(); diff --git a/tensorflow/compiler/jit/xla_kernel_creator_test.cc b/tensorflow/compiler/jit/xla_kernel_creator_test.cc index b66e4270d3a004..12ab76a7c1ce37 100644 --- a/tensorflow/compiler/jit/xla_kernel_creator_test.cc +++ b/tensorflow/compiler/jit/xla_kernel_creator_test.cc @@ -107,7 +107,8 @@ TEST_F(XlaKernelCreatorTest, OneFloatOneResourceArgument) { (*(callsite->node_def.mutable_attr()))["_XlaMustCompile"] = BoolAttr(true); // Note: need to set attribute on the created node. - Status status = xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_); + absl::Status status = + xla_kernel_creator.CreateKernel(flr_, callsite, &kernel_); ASSERT_TRUE(status.ok()) << status.ToString(); EXPECT_EQ("XTimesY", kernel_->name()); @@ -129,14 +130,12 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrNotSet) { Init({fdef}); XlaKernelCreator xla_kernel_creator; - Status status = - xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto( - name: 'XTimesY' - op: 'XTimesY' - input: 'a' - input: 'b' - )proto"), - &kernel_); + absl::Status status = xla_kernel_creator.CreateKernel( + flr_, + ToNodeProperties(R"pb( + name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b' + )pb"), + &kernel_); EXPECT_TRUE(absl::IsInternal(status)) << status; } @@ -146,14 +145,12 @@ TEST_F(XlaKernelCreatorTest, FailsIfXlaCompileAttrIsSetToFalse) { Init({fdef}); XlaKernelCreator xla_kernel_creator; - Status status = - xla_kernel_creator.CreateKernel(flr_, ToNodeProperties(R"proto( - name: 'XTimesY' - op: 'XTimesY' - input: 'a' - input: 'b' - )proto"), - &kernel_); + absl::Status status = xla_kernel_creator.CreateKernel( + flr_, + ToNodeProperties(R"pb( + name: 'XTimesY' op: 'XTimesY' input: 'a' input: 'b' + )pb"), + &kernel_); EXPECT_TRUE(absl::IsInternal(status)) << status; } diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 27a8f16b5f1323..ff53f31a10c719 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -279,7 +279,7 @@ static absl::StatusOr GetOrCreateTensorForOutput( } // Sets output `output_num` for `ctx` provided it is known at a compile time. -Status SetOutputForConstant( +absl::Status SetOutputForConstant( OpKernelContext* ctx, bool requires_copy_to_device, const XlaCompiler::CompilationResult* compilation_result, int output_num) { CHECK(compilation_result->outputs[output_num].is_constant); @@ -302,7 +302,7 @@ Status SetOutputForConstant( } ctx->op_device_context()->CopyCPUTensorToDevice( &const_tensor, device, output_tensor, - [&](Status status) { TF_CHECK_OK(status); }); + [&](absl::Status status) { TF_CHECK_OK(status); }); if (device->device_type() == DEVICE_GPU) { // The GPUDeviceContext enqueues the host->device transfer in a @@ -357,7 +357,7 @@ absl::StatusOr> GatherVariableInfo( return std::move(out); } -Status XlaComputationLaunchContext::PopulateOutputs( +absl::Status XlaComputationLaunchContext::PopulateOutputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* compilation_result, ScopedShapedBuffer output, int missing_ctx_input_prefix, @@ -606,7 +606,7 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments( } // TODO(b/289002708) Create a unit test to cover use_pjrt_tensor_buffer=true. -Status PreparePjRtExecutableArguments( +absl::Status PreparePjRtExecutableArguments( int num_missing_prefix_ctx_inputs, const std::vector& input_mapping, const std::vector& inputs, const absl::flat_hash_map& variable_snapshots, @@ -684,7 +684,7 @@ Status PreparePjRtExecutableArguments( } // TODO(b/289002708) Create a unit test to cover use_pjrt_tensor_buffer=true. -Status PopulateCtxOutputsFromPjRtExecutableOutputs( +absl::Status PopulateCtxOutputsFromPjRtExecutableOutputs( int num_missing_prefix_ctx_inputs, const std::vector& inputs, const std::vector& variables, const XlaCompiler::CompilationResult& compilation_result, @@ -825,7 +825,7 @@ DeviceType GetDeviceType(OpKernelContext* ctx) { return DeviceType(device->device_type()); } -Status RunPjRtExecutable( +absl::Status RunPjRtExecutable( const std::vector& inputs, const std::vector& variables, const XlaCompiler::CompilationResult& compilation_result, @@ -841,7 +841,7 @@ Status RunPjRtExecutable( } // TODO(b/289421064): Add unit test for this. -Status RunPjRtExecutable( +absl::Status RunPjRtExecutable( int num_missing_prefix_ctx_inputs, const std::vector& inputs, const absl::flat_hash_map& variable_snapshots, const std::vector& updated_variables, @@ -940,7 +940,7 @@ absl::StatusOr>> RunPjRtExecutable( // is ready i.e. when the execution is complete. if (!owned_executable_args.empty() && future.has_value()) { future->OnReady([owned_executable_args = - std::move(owned_executable_args)](Status s) {}); + std::move(owned_executable_args)](absl::Status s) {}); } return execute_outputs; diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 32f277ff13b0b8..5e5128d515bf97 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -49,7 +49,7 @@ std::vector InputsFromContext(OpKernelContext* ctx); absl::StatusOr> GetConstantInputIndicesFromContext( OpKernelContext* ctx); -Status SetOutputForConstant( +absl::Status SetOutputForConstant( OpKernelContext* ctx, bool requires_copy_to_device, const XlaCompiler::CompilationResult* compilation_result, int output_num); @@ -78,7 +78,7 @@ Status SetOutputForConstant( // complete. Therefore we put the newly created PjRtBuffer into `owned_args`. // Caller is responsible to ensure `owned_args` lives till the end of XLA // computation. -Status PreparePjRtExecutableArguments( +absl::Status PreparePjRtExecutableArguments( int num_missing_prefix_ctx_inputs, const std::vector& input_mapping, const std::vector& inputs, const absl::flat_hash_map& variable_snapshots, @@ -95,7 +95,7 @@ Status PreparePjRtExecutableArguments( // Assumes that the first `num_missing_prefix_ctx_inputs` inputs to the // compilation_result are missing in `inputs` and adjusts indexing into `inputs` // accordingly. -Status PopulateCtxOutputsFromPjRtExecutableOutputs( +absl::Status PopulateCtxOutputsFromPjRtExecutableOutputs( int num_missing_prefix_ctx_inputs, const std::vector& inputs, const std::vector& variables, const XlaCompiler::CompilationResult& compilation_result, @@ -118,7 +118,7 @@ DeviceType GetDeviceType(OpKernelContext* ctx); // `variables` are the input arguments to the computation, usually read from the // OpKernelContext, `ctx`. Requires the device-appropriate `pjrt_client` and the // `compilation_result` used to build the `executable`. -Status RunPjRtExecutable( +absl::Status RunPjRtExecutable( const std::vector& inputs, const std::vector& variables, const XlaCompiler::CompilationResult& compilation_result, @@ -132,7 +132,7 @@ Status RunPjRtExecutable( // Assumes that the first `num_missing_prefix_ctx_inputs` inputs to the // compilation_result are missing in `inputs` and adjusts indexing into `inputs` // accordingly. -Status RunPjRtExecutable( +absl::Status RunPjRtExecutable( int num_missing_prefix_ctx_inputs, const std::vector& inputs, const absl::flat_hash_map& variable_snapshots, const std::vector& updated_variables, @@ -202,7 +202,7 @@ class XlaComputationLaunchContext { // // Assumes that the first `missing_ctx_input_prefix` inputs to the // compilation_result are missing and adjusts input indices accordingly. - Status PopulateOutputs( + absl::Status PopulateOutputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* compilation_result, xla::ScopedShapedBuffer output, int missing_ctx_input_prefix, diff --git a/tensorflow/compiler/jit/xla_platform_info.cc b/tensorflow/compiler/jit/xla_platform_info.cc index 84f493572d2f5e..f9af695e33c163 100644 --- a/tensorflow/compiler/jit/xla_platform_info.cc +++ b/tensorflow/compiler/jit/xla_platform_info.cc @@ -98,7 +98,7 @@ absl::StatusOr>> GetAllowedGpus( return gpu_ids; } -Status GetCompilationDeviceTypeAndPjRtClient( +absl::Status GetCompilationDeviceTypeAndPjRtClient( const XlaPlatformInfo& platform_info, FunctionLibraryRuntime* flr, DeviceType* compilation_device_type, xla::PjRtClient** pjrt_client) { DeviceType device_type = platform_info.device_type(); @@ -204,10 +204,11 @@ absl::StatusOr GetCompilationDeviceType( return compilation_device_type; } -Status BuildXlaDeviceCompiler(DeviceBase* device, FunctionLibraryRuntime* flr, - const XlaPlatformInfo& platform_info, - DeviceType compilation_device_type, - XlaDeviceCompiler** xla_device_compiler) { +absl::Status BuildXlaDeviceCompiler(DeviceBase* device, + FunctionLibraryRuntime* flr, + const XlaPlatformInfo& platform_info, + DeviceType compilation_device_type, + XlaDeviceCompiler** xla_device_compiler) { if (platform_info.platform_id() == nullptr && platform_info.device_type() == DEVICE_GPU) { // We do not need to (and cannot) build a real device compiler for GPU @@ -267,7 +268,7 @@ Status BuildXlaDeviceCompiler(DeviceBase* device, FunctionLibraryRuntime* flr, // // So bail out of _XlaCompile in this case, and let the executor handle // the situation for us. - const Status& status = compiler_for_platform.status(); + const absl::Status& status = compiler_for_platform.status(); if (status.code() == error::NOT_FOUND) { return errors::Unimplemented("Could not find compiler for platform ", platform.value()->Name(), ": ", @@ -295,7 +296,7 @@ Status BuildXlaDeviceCompiler(DeviceBase* device, FunctionLibraryRuntime* flr, return absl::OkStatus(); } -Status GetOrCreatePjRtDeviceCompilerAndProfiler( +absl::Status GetOrCreatePjRtDeviceCompilerAndProfiler( const XlaPlatformInfo& platform_info, ResourceMgr* rm, FunctionLibraryRuntime* flr, PjRtDeviceCompiler** pjrt_device_compiler, DeviceCompilationProfiler** profiler) { @@ -307,7 +308,7 @@ Status GetOrCreatePjRtDeviceCompilerAndProfiler( bool deleted_old_device_compiler = false; // Lookup the DeviceCompiler, create one if not found. - Status s = rm->Lookup( + absl::Status s = rm->Lookup( rm->default_container(), compiler_name, pjrt_device_compiler); if (s.ok() && device_type == DEVICE_TPU) { auto* existing_pjrt_client = (*pjrt_device_compiler)->client(); @@ -352,7 +353,7 @@ Status GetOrCreatePjRtDeviceCompilerAndProfiler( return absl::OkStatus(); } -Status GetOrCreatePjRtDeviceCompilerAndProfiler( +absl::Status GetOrCreatePjRtDeviceCompilerAndProfiler( const OpKernelContext& ctx, const XlaPlatformInfo& platform_info, FunctionLibraryRuntime* flr, DeviceCompiler** diff --git a/tensorflow/compiler/jit/xla_platform_info.h b/tensorflow/compiler/jit/xla_platform_info.h index 94764e4d3dd7fe..7c5099f0ff94a8 100644 --- a/tensorflow/compiler/jit/xla_platform_info.h +++ b/tensorflow/compiler/jit/xla_platform_info.h @@ -116,7 +116,7 @@ absl::StatusOr GetCompilationDeviceType( // point to it. Uses flags from `MarkForCompilationPassFlags` for configuring // the persistor used in the DeviceCompiler. The platform ID from // `platform_info` must not be null in CPU case. -Status BuildXlaDeviceCompiler( +absl::Status BuildXlaDeviceCompiler( DeviceBase* dev, FunctionLibraryRuntime* flr, const XlaPlatformInfo& platform_info, DeviceType compilation_device_type, DeviceCompiler** @@ -132,7 +132,7 @@ Status BuildXlaDeviceCompiler( // non-XLA devices aren't supported yet. This is because: // 1. PjRtClient doesn't support data transfer for non-XLA devices yet // 2. Fetching the PjRtClient for non-XLA devices is also not supported yet -Status GetOrCreatePjRtDeviceCompilerAndProfiler( +absl::Status GetOrCreatePjRtDeviceCompilerAndProfiler( const OpKernelContext& ctx, const XlaPlatformInfo& platform_info, FunctionLibraryRuntime* flr, DeviceCompiler** @@ -141,7 +141,7 @@ Status GetOrCreatePjRtDeviceCompilerAndProfiler( // Same as the above function but takes the resource manager `rm` instead of an // OpKernelContext. -Status GetOrCreatePjRtDeviceCompilerAndProfiler( +absl::Status GetOrCreatePjRtDeviceCompilerAndProfiler( const XlaPlatformInfo& platform_info, ResourceMgr* rm, FunctionLibraryRuntime* flr, DeviceCompiler** diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc index 3a7ea396e61862..e9cdad219dd28d 100644 --- a/tensorflow/compiler/jit/xla_tensor.cc +++ b/tensorflow/compiler/jit/xla_tensor.cc @@ -41,10 +41,10 @@ namespace tensorflow { } } -Status XlaTensor::AllocateShapedBuffer(DataType dtype, - const xla::Shape& on_device_shape, - xla::LocalClient* client, - int device_ordinal) { +absl::Status XlaTensor::AllocateShapedBuffer(DataType dtype, + const xla::Shape& on_device_shape, + xla::LocalClient* client, + int device_ordinal) { xla::Shape on_host_shape = xla::ShapeUtil::DeviceShapeToHostShape(on_device_shape); xla::ScopedShapedBuffer shaped_buffer(on_host_shape, on_device_shape, @@ -94,9 +94,9 @@ void XlaTensor::ResetDefinitionEvent(std::shared_ptr event, streams_defined_on_ = {stream}; } -Status XlaTensor::RefreshStatusOfStreams() { +absl::Status XlaTensor::RefreshStatusOfStreams() { mutex_lock lock(mu_); - Status status; + absl::Status status; for (se::Stream* stream : streams_defined_on_) { status.Update(stream->RefreshStatus()); } diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h index 6ffd6401d79f93..91e06ddf17a684 100644 --- a/tensorflow/compiler/jit/xla_tensor.h +++ b/tensorflow/compiler/jit/xla_tensor.h @@ -48,8 +48,10 @@ class XlaTensor { // Assign the internal ShapedBuffer to new memory for the given dtype and // shape. If a ShapedBuffer exists already (has_shaped_buffer() == true), it // is replaced and the managed memory deallocated. - Status AllocateShapedBuffer(DataType dtype, const xla::Shape& on_device_shape, - xla::LocalClient* client, int device_ordinal); + absl::Status AllocateShapedBuffer(DataType dtype, + const xla::Shape& on_device_shape, + xla::LocalClient* client, + int device_ordinal); // Some Tensors can have complex on-device shapes, including tuple shapes. To // manage the memory for these tensors a ShapedBuffer may be required. @@ -87,7 +89,7 @@ class XlaTensor { // Refresh the status of streams_defined_on_. Return the first not-OK stream's // status or OK. - Status RefreshStatusOfStreams(); + absl::Status RefreshStatusOfStreams(); // Convert from a raw pointer to an XlaTensor, removing the pointer tag. static XlaTensor* FromOpaquePointer(void* ptr); diff --git a/tensorflow/compiler/jit/xla_tpu_device.cc b/tensorflow/compiler/jit/xla_tpu_device.cc index f9524163a6223b..9c7ba59bf9fe03 100644 --- a/tensorflow/compiler/jit/xla_tpu_device.cc +++ b/tensorflow/compiler/jit/xla_tpu_device.cc @@ -75,7 +75,7 @@ absl::StatusOr TpuShapeRepresentation( // Given a tensor, returns the shape of its representation on device, // fully padded. Contents of `shape` are undefined on error. -Status TpuPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { +absl::Status TpuPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { const tensorflow::XlaTensor* xla_tensor = tensorflow::XlaTensor::FromTensor(&tensor); if (xla_tensor == nullptr) { @@ -106,7 +106,7 @@ Status TpuPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) { // Check if TPU has been initialized. TPU initialization is not necessary // for 1x1. -Status CheckIfTPUInitialized() { +absl::Status CheckIfTPUInitialized() { auto* tpu_platform = tpu::TpuPlatformInterface::GetRegisteredPlatform(); if (!tpu_platform->Initialized()) { return errors::FailedPrecondition( @@ -136,9 +136,9 @@ void TpuDeviceToDeviceCopy(DeviceContext* src_dev_context, static const bool should_use_substream = tpu_use_substreams_for_cross_tpu_device_transfers_flag; - auto impl = [&]() -> Status { + auto impl = [&]() -> absl::Status { if (src->name() != dst->name()) { - Status s = CheckIfTPUInitialized(); + absl::Status s = CheckIfTPUInitialized(); if (!s.ok()) { done(s); return absl::OkStatus(); @@ -301,7 +301,7 @@ void TpuDeviceToDeviceCopy(DeviceContext* src_dev_context, done(absl::OkStatus()); }); }; - Status status = impl(); + absl::Status status = impl(); if (!status.ok()) { done(status); } @@ -309,12 +309,14 @@ void TpuDeviceToDeviceCopy(DeviceContext* src_dev_context, class TpuNodeDeviceFactory : public DeviceFactory { public: - Status ListPhysicalDevices(std::vector* devices) override; - Status CreateDevices(const SessionOptions& options, const string& name_prefix, - std::vector>* devices) override; + absl::Status ListPhysicalDevices(std::vector* devices) override; + absl::Status CreateDevices( + const SessionOptions& options, const string& name_prefix, + std::vector>* devices) override; }; -Status TpuNodeDeviceFactory::ListPhysicalDevices(std::vector* devices) { +absl::Status TpuNodeDeviceFactory::ListPhysicalDevices( + std::vector* devices) { tpu::TpuPlatformInterface* platform = tpu::TpuPlatformInterface::GetRegisteredPlatform(); if (platform == nullptr) { @@ -332,7 +334,7 @@ Status TpuNodeDeviceFactory::ListPhysicalDevices(std::vector* devices) { return absl::OkStatus(); } -Status TpuNodeDeviceFactory::CreateDevices( +absl::Status TpuNodeDeviceFactory::CreateDevices( const SessionOptions& session_options, const string& name_prefix, std::vector>* devices) { tpu::TpuPlatformInterface* platform = @@ -392,7 +394,7 @@ Status TpuNodeDeviceFactory::CreateDevices( // The AcceleratorDeviceInfo actually provides information not only for GPU // devices but also for TPU. The name is a legacy from the pre-TPU // dark ages. - Status status = device->UseAcceleratorDeviceInfo(); + absl::Status status = device->UseAcceleratorDeviceInfo(); if (!status.ok()) { errors::AppendToMessage(&status, "while setting up ", DEVICE_TPU_XLA_JIT, " device number ", i); @@ -411,12 +413,13 @@ Status TpuNodeDeviceFactory::CreateDevices( class TpuSystemDeviceFactory : public DeviceFactory { public: - Status ListPhysicalDevices(std::vector* devices) override; - Status CreateDevices(const SessionOptions& options, const string& name_prefix, - std::vector>* devices) override; + absl::Status ListPhysicalDevices(std::vector* devices) override; + absl::Status CreateDevices( + const SessionOptions& options, const string& name_prefix, + std::vector>* devices) override; }; -Status TpuSystemDeviceFactory::ListPhysicalDevices( +absl::Status TpuSystemDeviceFactory::ListPhysicalDevices( std::vector* devices) { int device_count = 0; TF_RETURN_IF_ERROR(tpu::TpuPlatform::TpusPerHost(&device_count)); @@ -430,7 +433,7 @@ Status TpuSystemDeviceFactory::ListPhysicalDevices( return absl::OkStatus(); } -Status TpuSystemDeviceFactory::CreateDevices( +absl::Status TpuSystemDeviceFactory::CreateDevices( const SessionOptions& options, const string& name_prefix, std::vector>* devices) { int device_count = 0; diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index eabb87e1e89913..10691fa490b63a 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -128,6 +128,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:import_model", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tf2xla:mlir_bridge_rollout_policy", + "//tensorflow/compiler/mlir/tf2xla/api/v2:graph_to_tf_executor", "//tensorflow/compiler/mlir/tf2xla/api/v2:tf_executor_to_graph", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", @@ -195,9 +196,11 @@ cc_library( hdrs = ["register_common_dialects.h"], deps = [ "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:mlprogram_util", "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", + "//tensorflow/core/ir/types:Dialect", "@llvm-project//mlir:AllExtensions", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", @@ -233,6 +236,7 @@ tf_cc_binary( "//tensorflow/compiler/mlir/tensorflow:translate_cl_options", "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/compiler/mlir/tensorflow:translate_registration", + "//tensorflow/compiler/mlir/tf2xla/tests/registration:graph_to_tf_executor_registration", "//tensorflow/core:lib", "//tensorflow/core:tensorflow", "@com_google_absl//absl/strings", @@ -240,8 +244,8 @@ tf_cc_binary( "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@llvm-project//mlir:TranslateLib", - "@local_xla//xla/translate/hlo_to_mhlo:translate_registration", - "@local_xla//xla/translate/mhlo_to_hlo:translate_registration", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:translate_registration", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:translate_registration", ], ) diff --git a/tensorflow/compiler/mlir/glob_lit_test.bzl b/tensorflow/compiler/mlir/glob_lit_test.bzl index c87dc83bdde956..ad44b889cc62a8 100644 --- a/tensorflow/compiler/mlir/glob_lit_test.bzl +++ b/tensorflow/compiler/mlir/glob_lit_test.bzl @@ -81,6 +81,7 @@ def glob_lit_tests( driver = _default_driver, features = [], exec_properties = {}, + use_lit_test_suite = None, # @unused hermetic_cuda_data_dir = None): """Creates all plausible Lit tests (and their inputs) under this directory. @@ -101,6 +102,7 @@ def glob_lit_tests( exec_properties: a dictionary of properties to pass on. hermetic_cuda_data_dir: string. If set, the tests will be run with a `--xla_gpu_cuda_data_dir` flag set to the hermetic CUDA data directory. + use_lit_test_suite: unused. For compatibility. """ # Ignore some patterns by default for tests and input data. diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 88c47a36ff938f..214d93390a670b 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -27,7 +27,7 @@ package_group( "//third_party/iree/...", "//third_party/odml/infra/...", "//tensorflow/compiler/mlir/...", - "//tensorflow/lite/python/...", + "//tensorflow/lite/...", "//waymo/accelerator/alpine/tools/...", "//waymo/ml/compiler/mlir/...", # Allow visibility from the mlir language server. @@ -905,6 +905,7 @@ cc_library( "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", + "//tensorflow/compiler/mlir/lite/stablehlo:legalize_tf", "//tensorflow/compiler/mlir/lite/stablehlo:optimize_layout", "//tensorflow/compiler/mlir/lite/stablehlo:prepare_hlo", "//tensorflow/compiler/mlir/lite/stablehlo:tf_legalize_hlo", @@ -1127,19 +1128,19 @@ cc_library( cc_library( name = "tensorflow_lite_d2s", srcs = [ - "transforms/dense_to_sparse.cc", + "transforms/dense_to_sparse_pass.cc", ], hdrs = [ - "transforms/passes.h", + "transforms/dense_to_sparse_pass.h", ], deps = [ - ":tensorflow_lite", - ":tensorflow_lite_passes_inc_gen", + ":pass", + ":pass_options", + ":tensorflow_lite_ops", "//tensorflow/compiler/mlir/lite/kernels/internal/utils:sparsity_format_converter", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@eigen_archive//:eigen3", + "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -1372,15 +1373,16 @@ cc_library( ], deps = [ ":const_tensor_utils", + ":control_edges", ":convert_type", ":flatbuffer_tflite_operator_lib", ":offset_buffer", ":size_utils", ":tensorflow_lite", - "//tensorflow/compiler/mlir/lite:control_edges", "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder", "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/lite/schema:debug_metadata_fbs_with_mutable", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_mutable", "//tensorflow/compiler/mlir/lite/schema:schema_utils", @@ -1397,7 +1399,9 @@ cc_library( "//tensorflow/core/platform:errors", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Analysis", "@llvm-project//llvm:Support", @@ -1410,6 +1414,7 @@ cc_library( "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:Support", "@llvm-project//mlir:TranslateLib", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@stablehlo//:stablehlo_ops", @@ -1567,8 +1572,8 @@ tf_cc_binary( "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:translate", "@local_xla//xla/mlir_hlo", - "@local_xla//xla/translate/hlo_to_mhlo:translate", "@stablehlo//:stablehlo_ops", ], ) @@ -1593,6 +1598,7 @@ cc_library( ":tensorflow_lite_push_transpose_through_ewise", # buildcleaner: keep ":tensorflow_lite_quantize", # buildcleaner: keep ":tensorflow_lite_tf_unfreeze_global_tensors", + "//tensorflow/compiler/mlir/lite/core:macros", "//tensorflow/compiler/mlir/lite/quantization:quantization_passes", "//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_quantization_passes", "//tensorflow/compiler/mlir/lite/stablehlo:build_stablehlo_composite", @@ -1620,6 +1626,7 @@ cc_library( "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo:mhlo_passes", "@local_xla//xla/mlir_hlo:stablehlo_extension_passes", + "@stablehlo//:stablehlo_passes", ], ) @@ -1729,7 +1736,7 @@ cc_library( name = "string_utils", srcs = ["utils/string_utils.cc"], hdrs = ["utils/string_utils.h"], - visibility = ["//tensorflow/lite:__pkg__"], + visibility = ["//visibility:public"], ) exports_files(srcs = ["allocation.h"]) @@ -1760,10 +1767,7 @@ exports_files(srcs = ["utils/control_edges.h"]) cc_library( name = "control_edges", hdrs = ["utils/control_edges.h"], - visibility = [ - "//tensorflow/compiler/mlir/lite/experimental/remat:__pkg__", - "//tensorflow/lite:__pkg__", - ], + visibility = ["//tensorflow/compiler/mlir/lite/experimental/remat:__pkg__"], ) tf_cc_test( @@ -1881,21 +1885,18 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "types_py_proto", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":types_proto"], # ) # # py_proto_library( # name = "model_flags_py_proto", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":model_flags_proto"], # ) # # py_proto_library( # name = "converter_flags_py_proto", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":converter_flags_proto"], # ) diff --git a/tensorflow/compiler/mlir/lite/converter_gen.cc b/tensorflow/compiler/mlir/lite/converter_gen.cc index 2de28d68703b7c..1941fbd7e63105 100644 --- a/tensorflow/compiler/mlir/lite/converter_gen.cc +++ b/tensorflow/compiler/mlir/lite/converter_gen.cc @@ -22,6 +22,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" @@ -101,7 +102,7 @@ static int HasOptions(const Record &def) { } static void EmitOptionBuilders(const RecordKeeper &record_keeper, - const std::vector &defs, + const std::vector &defs, raw_ostream *ostream) { raw_ostream &os = *ostream; @@ -129,7 +130,7 @@ static void EmitOptionBuilders(const RecordKeeper &record_keeper, mlir::tblgen::Operator op(*def); for (unsigned i = 0, e = arg_values->getNumArgs(); i != e; ++i) { auto arg = arg_values->getArg(i); - DefInit *arg_def = dyn_cast(arg); + const auto *arg_def = dyn_cast(arg); if (!arg_def) continue; if (arg_def->getDef()->isSubClassOf(attr_type)) { // This binds the name of the attribute in the TD file with the name @@ -187,7 +188,7 @@ static void EmitOptionBuilders(const RecordKeeper &record_keeper, // arguments that depend on op definitions should be auto-generated and then // operator should be built by the caller because it does not require // auto-generation. -static void EmitOperatorBuilders(const std::vector &defs, +static void EmitOperatorBuilders(const std::vector &defs, raw_ostream *ostream) { raw_ostream &os = *ostream; @@ -275,7 +276,7 @@ static inline std::string GetOperatorName(const Record &def) { // // TODO(hinsu): Consider converting this to a static constant associative // container instead of a series of if conditions, if required. -static void EmitGetBuiltinOpCode(const std::vector &defs, +static void EmitGetBuiltinOpCode(const std::vector &defs, raw_ostream *ostream) { raw_ostream &os = *ostream; @@ -305,7 +306,7 @@ static void EmitGetBuiltinOpCode(const std::vector &defs, // return {0, 0}; // } static void EmitOperandNumbers(const RecordKeeper &record_keeper, - const std::vector &defs, + const std::vector &defs, raw_ostream *ostream) { raw_ostream &os = *ostream; const auto attr_type = record_keeper.getClass("Attr"); @@ -351,7 +352,7 @@ static void EmitOperandNumbers(const RecordKeeper &record_keeper, // const std::vector& intermediates, // flatbuffers::FlatBufferBuilder *fbb, // std::optional debug_metadata_index); -static void EmitBuildOperator(const std::vector &defs, +static void EmitBuildOperator(const std::vector &defs, raw_ostream *ostream) { raw_ostream &os = *ostream; @@ -390,10 +391,9 @@ static void EmitBuildOperator(const std::vector &defs, // // where id is an empty string if builtin_options_id is 1, or builtin_options_id // otherwise. -static void EmitBuiltinOptionsToAttributes(const RecordKeeper &record_keeper, - const std::vector &defs, - raw_ostream *ostream, - const int builtin_options_id) { +static void EmitBuiltinOptionsToAttributes( + const RecordKeeper &record_keeper, const std::vector &defs, + raw_ostream *ostream, const int builtin_options_id) { raw_ostream &os = *ostream; const std::string builtin_options_suffix = [&] { @@ -433,7 +433,7 @@ static void EmitBuiltinOptionsToAttributes(const RecordKeeper &record_keeper, auto *arg_values = def->getValueAsDag("arguments"); for (unsigned i = 0, e = arg_values->getNumArgs(); i != e; ++i) { auto arg = arg_values->getArg(i); - DefInit *arg_def = dyn_cast(arg); + const auto *arg_def = dyn_cast(arg); if (!arg_def) continue; if (arg_def->getDef()->isSubClassOf(attr_type)) { StringRef arg_name = arg_values->getArgNameStr(i); @@ -464,11 +464,11 @@ static void EmitBuiltinOptionsToAttributes(const RecordKeeper &record_keeper, // The function below has a non-constant reference as that is required by LLVM's // TableGenMain. // NOLINTNEXTLINE -static bool OperatorWritersMain(raw_ostream &os, RecordKeeper &records) { +static bool OperatorWritersMain(raw_ostream &os, const RecordKeeper &records) { emitSourceFileHeader("MLIR TFLite FlatBuffer Builders", os); // Retrieve all the definitions derived from TFL_Op and sort by record name. - std::vector defs = records.getAllDerivedDefinitions("TFL_Op"); + std::vector defs = records.getAllDerivedDefinitions("TFL_Op"); llvm::sort(defs, LessRecord()); for (const auto *def : defs) { @@ -503,7 +503,7 @@ static bool OperatorWritersMain(raw_ostream &os, RecordKeeper &records) { } static void GenOperandResultVerifier(raw_ostream &os, - llvm::ArrayRef values, + llvm::ArrayRef values, StringRef valueKind) { mlir::tblgen::FmtContext fctx; @@ -551,11 +551,12 @@ static void GenOperandResultVerifier(raw_ostream &os, } // NOLINTNEXTLINE -static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) { +static bool RuntimeVerifierWriterMain(raw_ostream &os, + const RecordKeeper &records) { emitSourceFileHeader("MLIR TFLite Runtime Verifiers", os); // Retrieve all the definitions derived from TFL_Op and sort by record name. - std::vector defs = records.getAllDerivedDefinitions("Op"); + std::vector defs = records.getAllDerivedDefinitions("Op"); llvm::sort(defs, LessRecord()); // Iterate through all the ops defined. diff --git a/tensorflow/compiler/mlir/lite/core/BUILD b/tensorflow/compiler/mlir/lite/core/BUILD index d76299aa723d51..daf69cf2c87e12 100644 --- a/tensorflow/compiler/mlir/lite/core/BUILD +++ b/tensorflow/compiler/mlir/lite/core/BUILD @@ -26,10 +26,7 @@ cc_library( hdrs = ["model_builder_base.h"], compatible_with = get_compatible_with_portable(), copts = tflite_copts_warnings(), - visibility = [ - "//tensorflow/compiler/mlir/lite:__subpackages__", - "//tensorflow/lite/core:__pkg__", - ], + visibility = ["//visibility:public"], deps = [ ":macros", "//tensorflow/compiler/mlir/lite:allocation", @@ -48,9 +45,7 @@ cc_library( hdrs = ["absl_error_model_builder.h"], compatible_with = get_compatible_with_portable(), copts = tflite_copts_warnings(), - visibility = [ - "//tensorflow/compiler/mlir/lite:__subpackages__", - ], + visibility = ["//visibility:public"], deps = [ ":model_builder_base", "//tensorflow/compiler/mlir/lite/core/api:error_reporter", diff --git a/tensorflow/compiler/mlir/lite/core/api/BUILD b/tensorflow/compiler/mlir/lite/core/api/BUILD index 0aaca3928420d6..cad2d9f8567863 100644 --- a/tensorflow/compiler/mlir/lite/core/api/BUILD +++ b/tensorflow/compiler/mlir/lite/core/api/BUILD @@ -4,10 +4,7 @@ load("//tensorflow/compiler/mlir/lite:build_def.bzl", "tflite_copts") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//tensorflow/compiler/mlir/lite:__subpackages__", - "//tensorflow/lite:__subpackages__", - ], + default_visibility = ["//visibility:public"], licenses = ["notice"], ) diff --git a/tensorflow/compiler/mlir/lite/core/model_builder_base.h b/tensorflow/compiler/mlir/lite/core/model_builder_base.h index 3484f4e3d071d6..26cfe2890b3eae 100644 --- a/tensorflow/compiler/mlir/lite/core/model_builder_base.h +++ b/tensorflow/compiler/mlir/lite/core/model_builder_base.h @@ -386,9 +386,15 @@ class FlatBufferModelBase { size_t allocation_size = std::min(allocation->bytes(), static_cast(FLATBUFFERS_MAX_BUFFER_SIZE - 1)); + flatbuffers::Verifier::Options options; + // TODO(b/366118885): Remove after the root cause of the crash on Windows + // is found. +#if defined(_WIN32) + options.assert = true; +#endif flatbuffers::Verifier base_verifier( - reinterpret_cast(allocation->base()), - allocation_size); + reinterpret_cast(allocation->base()), allocation_size, + options); if (!VerifyModelBuffer(base_verifier)) { TF_LITE_REPORT_ERROR(error_reporter, "The model is not a valid Flatbuffer buffer"); @@ -495,7 +501,7 @@ class FlatBufferModelBase { std::map keys_values; if (!model || !model->metadata() || !model->buffers()) return keys_values; - for (int i = 0; i < model->metadata()->size(); ++i) { + for (size_t i = 0; i < model->metadata()->size(); ++i) { auto metadata = model->metadata()->Get(i); auto buf = metadata->buffer(); if (buf >= model->buffers()->size()) continue; diff --git a/tensorflow/compiler/mlir/lite/debug/BUILD b/tensorflow/compiler/mlir/lite/debug/BUILD index d0fc7fc7d8693c..c8bfd87e378aa1 100644 --- a/tensorflow/compiler/mlir/lite/debug/BUILD +++ b/tensorflow/compiler/mlir/lite/debug/BUILD @@ -34,10 +34,10 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", - "@local_tsl//tsl/lib/io:buffered_file", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:stringpiece", + "@local_xla//xla/tsl/lib/io:buffered_file", ], ) diff --git a/tensorflow/compiler/mlir/lite/debug/debug.cc b/tensorflow/compiler/mlir/lite/debug/debug.cc index d0b85019cfe200..8b4b611a18108e 100644 --- a/tensorflow/compiler/mlir/lite/debug/debug.cc +++ b/tensorflow/compiler/mlir/lite/debug/debug.cc @@ -47,8 +47,8 @@ limitations under the License. #include "re2/re2.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/debug/debug_options.pb.h" #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h" +#include "xla/tsl/lib/io/buffered_file.h" #include "tensorflow/core/platform/logging.h" -#include "tsl/lib/io/buffered_file.h" #include "tsl/platform/env.h" #include "tsl/platform/file_system.h" #include "tsl/platform/path.h" diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/BUILD b/tensorflow/compiler/mlir/lite/experimental/lrt/BUILD deleted file mode 100644 index 5d6dbd10c9c94f..00000000000000 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/BUILD +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/compiler/mlir/lite/experimental/lrt:__subpackages__"], -) - -cc_binary( - name = "apply_plugin", - srcs = [ - "apply_plugin.cc", - # TODO: b/366821557 - Support pre-compiled plugins as data dependencies. - "//tensorflow/compiler/mlir/lite/experimental/lrt/examples:mul_op_plugin_so", - ], - deps = [ - "//tensorflow/compiler/mlir/lite/experimental/lrt/core:algo", - "//tensorflow/compiler/mlir/lite/experimental/lrt/core:api_internal", - "//tensorflow/compiler/mlir/lite/experimental/lrt/core:lite_rt_model_init", - "//tensorflow/compiler/mlir/lite/experimental/lrt/core:model", - "//tensorflow/lite/schema:schema_fbs", - "@llvm-project//llvm:Support", - ], -) diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/apply_plugin.cc b/tensorflow/compiler/mlir/lite/experimental/lrt/apply_plugin.cc deleted file mode 100644 index c6e5543d000484..00000000000000 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/apply_plugin.cc +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include -#include -#include -#include - -#include "llvm/Support/CommandLine.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_compiler_plugin.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_support.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/algo.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_model_init.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/model.h" - -// NOLINTNEXTLINE -static llvm::cl::opt model_path( - "model_path", llvm::cl::desc("Path to flatbuffer file."), - llvm::cl::init("")); - -// TODO: b/366821557 - Support path to pre-compiled plugin in flags. -// NOLINTNEXTLINE -static llvm::cl::opt soc_manufacturer( - "soc_man", - llvm::cl::desc("String identifier of SoC backend (pixel, qcc, darwinn)."), - llvm::cl::init("ExampleSocManufacturer")); - -// NOLINTNEXTLINE -static llvm::cl::opt soc_model( - "soc_model", - llvm::cl::desc("Compilation configuration identifier (chip type)."), - llvm::cl::init("DummyMulOp")); - -// TODO swap "dry_run" for optional "don't delete partitioned subgraphs". -// NOLINTNEXTLINE -static llvm::cl::opt dry_run( - "dry_run", - llvm::cl::desc( - "Only run \"partition\" phase and output the spliced out subgraphs."), - llvm::cl::init(true)); - -#define EXIT_IF_NULL(val, msg) \ - if (!val) { \ - std::cerr << msg << "\n"; \ - return 1; \ - } - -void DumpSubgraph(const LrtSubgraphT& subgraph, std::string_view label) { -#ifndef NDEBUG - std::cerr << "===== " << label << " =====\n"; - for (auto op : subgraph.ops) { - debug::DumpOp(*op); - } - for (auto tensor : subgraph.inputs) { - std::cerr << "SG_IN " << tensor << "\n"; - } - - for (auto tensor : subgraph.outputs) { - std::cerr << "SG_OUT " << tensor << "\n"; - } -#endif -} - -bool IsSocModelSupported(LrtCompilerPlugin plugin, - std::string_view requested_soc_model) { - const auto num_supported_configs = LrtPluginNumSupportedSocModels(plugin); - for (int i = 0; i < num_supported_configs; ++i) { - const char* config; - LRT_RETURN_VAL_IF_NOT_OK( - LrtPluginGetSupportedSocModelId(plugin, i, &config), false); - if (requested_soc_model == config) { - return true; - } - } - - return false; -} - -// TODO: b/366821557 - Replace loading pre-compiled plugin. -UniqueLrtCompilerPlugin LoadPlugin() { - if (soc_manufacturer != LrtPluginSocManufacturer()) { - std::cerr << "Only ExampleSocManufacturer currently supported"; - return nullptr; - } - - LrtCompilerPlugin plugin; - LRT_RETURN_VAL_IF_NOT_OK(LrtPluginInit(&plugin), nullptr); - auto result = UniqueLrtCompilerPlugin(plugin); - - if (!IsSocModelSupported(result.get(), soc_model)) { - std::cerr << "Only DummyMulOp currently supported\n"; - return nullptr; - } - - return result; -} - -UniqueLrtModel LoadModel(std::string_view filename) { - LrtModel model; - LRT_RETURN_VAL_IF_NOT_OK(LoadModelFromFile(filename.data(), &model), nullptr); - return UniqueLrtModel(model); -} - -LrtStatus ApplyPlugin(LrtModel model, LrtCompilerPlugin plugin) { - LRT_RETURN_STATUS_IF_NOT_OK( - RegisterCustomOpCode(model, LrtPluginSocManufacturer())); - - LrtOpListT selected_ops; - LRT_RETURN_STATUS_IF_NOT_OK( - LrtPluginPartitionModel(plugin, model, &selected_ops)); - - auto partitions = - algo::DisjointSets::GetPartitionsFromFlatList(selected_ops.ops); - - // TODO: b/366821557 - Support multiple subgraphs in plugin application. - auto& main_subgraph = model->subgraphs.front(); - DumpSubgraph(main_subgraph, "Main subgraph before partioning."); - - std::vector slices; - std::vector custom_ops; - slices.reserve(partitions.size()); - custom_ops.reserve(partitions.size()); - - for (auto& partition : partitions) { - LrtSubgraph new_subgraph = &model->subgraphs.emplace_back(); - - LrtOp custom_op = algo::GraphSlicer::SlicePartitionFromGraph( - main_subgraph, new_subgraph, partition); - custom_ops.push_back(custom_op); - slices.push_back(new_subgraph); - - DumpSubgraph(*new_subgraph, "New subgraph"); - } - - DumpSubgraph(main_subgraph, "Main subgraph after partioning."); - - if (dry_run) { - return StatusOk(); - } - - LrtCompiledResult compiled_result; - LRT_RETURN_STATUS_IF_NOT_OK( - LrtPluginCompile(plugin, slices.data(), slices.size(), &compiled_result)); - - lrt_param_index_t num_calls_compiled; - LRT_RETURN_STATUS_IF_NOT_OK( - LrtCompiledResultGetNumCalls(compiled_result, &num_calls_compiled)); - - if (num_calls_compiled != slices.size()) { - std::cerr - << "Plugin must provide and entry point for each compiled partition\n"; - return StatusCreate(kLrtStatusErrorNotFound); - } - - for (int i = 0; i < num_calls_compiled; ++i) { - const void* call_info; - size_t call_info_size; - - LRT_RETURN_STATUS_IF_NOT_OK(LrtCompiledResultGetCallInfo( - compiled_result, i, &call_info, &call_info_size)); - - auto* custom_op = custom_ops.at(i); - custom_op->custom_options.assign(reinterpret_cast(call_info), - call_info_size); - } - - const void* byte_code; - size_t byte_code_size; - - LRT_RETURN_STATUS_IF_NOT_OK(LrtCompiledResultGetByteCode( - compiled_result, &byte_code, &byte_code_size)); - - LRT_RETURN_STATUS_IF_NOT_OK(AppendMetadata(model, byte_code, byte_code_size, - LrtPluginSocManufacturer())); - - return StatusOk(); -} - -int main(int argc, char** argv) { - llvm::cl::ParseCommandLineOptions(argc, argv); - - auto model = LoadModel(model_path); - EXIT_IF_NULL(model, "Failed to load model"); - - auto plugin = LoadPlugin(); - EXIT_IF_NULL(plugin, "Failed to load plugin."); - - LRT_RETURN_VAL_IF_NOT_OK(ApplyPlugin(model.get(), plugin.get()), 1); - - uint8_t* buf; - size_t buf_size; - size_t buf_offset; - - LRT_RETURN_VAL_IF_NOT_OK( - SerializeModel(model.release(), &buf, &buf_size, &buf_offset), 1); - - std::string out(reinterpret_cast(buf) + buf_offset, - buf_size - buf_offset); - std::cout << out; - - delete[] buf; - - return 0; -} diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h b/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h deleted file mode 100644 index 9cd70dd4bc5168..00000000000000 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_COMMON_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_COMMON_H_ - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -// Declares canonical opaque type. -#define LITE_RT_DEFINE_HANDLE(name) typedef struct name##T* name -// Declares an array of references to opaque type. `name` must be -// previously declared opaque type. -#define LITE_RT_DEFINE_HANDLE_ARRAY(name) typedef name* name##Array - -// Common status type. May have an attached opaque payload. Requires cleanup. -LITE_RT_DEFINE_HANDLE(LrtStatus); - -typedef enum { - kLrtStatusOk = 0, - - kLrtStatusErrorInvalidArgument = 1, - kLrtStatusErrorMemoryAllocationFailure = 2, - kLrtStatusErrorRuntimeFailure = 3, - kLrtStatusErrorMissingInputTensor = 4, - kLrtStatusErrorUnsupported = 5, - kLrtStatusErrorNotFound = 6, - - // File related errors. - kLrtStatusBadFileOp = 500, - kLrtStatusFlatbufferFailedVerify = 501, - - // IR related errors. - kLrtParamIndexOOB = 1000, - kLrtStatusBadTensorType = 1001, - kLrtStatusGraphInvariantError = 1002, -} LrtStatusCode; - -// Get code from status. -LrtStatusCode GetStatusCode(LrtStatus status); - -// Free any payloads attached to status. -void StatusDestroy(LrtStatus status); - -// Create a status with given code. -LrtStatus StatusCreate(LrtStatusCode code); - -// Create an ok status. -LrtStatus StatusOk(); - -// TODO: b/365295276 - Implement error message payloads for lrt status. - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_COMMON_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_compiler_plugin.h b/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_compiler_plugin.h deleted file mode 100644 index a254ae00a85ec5..00000000000000 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_compiler_plugin.h +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_COMPILER_PLUGIN_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_COMPILER_PLUGIN_H_ - -#include - -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -LITE_RT_DEFINE_HANDLE(LrtCompilerPlugin); - -// Artifact produced from compiling a selected partition of ops. -LITE_RT_DEFINE_HANDLE(LrtCompiledResult); - -// -// Plugin -// - -LrtStatus LrtPluginInit(LrtCompilerPlugin* compiler_plugin); - -void LrtPluginDestroy(LrtCompilerPlugin compiler_plugin); - -// Name associated with the manufacturer this plugin relates to (darwinn, QCC). -const char* LrtPluginSocManufacturer(); - -// Number of soc models supported by this plugin. -lrt_param_index_t LrtPluginNumSupportedSocModels( - LrtCompilerPlugin compiler_plugin); - -// Gets a string identifying the given config index. -LrtStatus LrtPluginGetSupportedSocModelId(LrtCompilerPlugin compiler_plugin, - lrt_param_index_t config_idx, - const char** config_id); - -// Select desired ops for compilation. This will be called only once -// during the plugin application flow, all ops should be selected during this -// call. -LrtStatus LrtPluginPartitionModel(LrtCompilerPlugin compiler_plugin, - LrtModel model, LrtOpList selected_ops); - -// Prepare result to pass to the runtime for given partition. The given -// subgraphs are valid sub-DAG within the ops selected in partition step. -LrtStatus LrtPluginCompile(LrtCompilerPlugin compiler_plugin, - LrtSubgraphArray partitions, - lrt_param_index_t num_partitions, - LrtCompiledResult* compiled_result); - -// -// Compiled Partition -// - -void LrtCompiledResultDestroy(LrtCompiledResult result); - -// Get serialized result to compiled modules available to all custom ops. -// This could be one module with multiple entry points or multiple modules -// concat together. -LrtStatus LrtCompiledResultGetByteCode(LrtCompiledResult compiled_result, - const void** byte_code, - size_t* byte_code_size); - -// Get info to embed in a particular custom op. This could be any opaque data -// parsed in the custom op. -LrtStatus LrtCompiledResultGetCallInfo(LrtCompiledResult compiled_result, - lrt_param_index_t call_idx, - const void** call_info, - size_t* call_info_size); - -// Get the number of calls that will be made to the HAL for this graph. -// This should equal the number of partitions given for compilation which -// is equal to the number of custom ops in the final model. -LrtStatus LrtCompiledResultGetNumCalls(LrtCompiledResult compiled_result, - lrt_param_index_t* num_calls); - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_COMPILER_PLUGIN_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h b/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h deleted file mode 100644 index 190294c2d8a287..00000000000000 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_MODEL_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_MODEL_H_ - -#include -#include - -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h" -#include "tensorflow/lite/core/c/c_api_types.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -LITE_RT_DEFINE_HANDLE(LrtBuffer); - -LITE_RT_DEFINE_HANDLE(LrtTensor); -LITE_RT_DEFINE_HANDLE_ARRAY(LrtTensor); - -LITE_RT_DEFINE_HANDLE(LrtOp); -LITE_RT_DEFINE_HANDLE_ARRAY(LrtOp); - -LITE_RT_DEFINE_HANDLE(LrtSubgraph); -LITE_RT_DEFINE_HANDLE_ARRAY(LrtSubgraph); - -LITE_RT_DEFINE_HANDLE(LrtModel); - -// Append only list of ops. -LITE_RT_DEFINE_HANDLE(LrtOpList); - -// For indexing into lrt collections or counting lrt things. -typedef uint64_t lrt_param_index_t; - -// -// Tensors -// - -typedef enum { - kLrtElementTypeNone = kTfLiteNoType, - kLrtElementTypeBool = kTfLiteBool, - kLrtElementTypeInt4 = kTfLiteInt4, - kLrtElementTypeInt8 = kTfLiteInt8, - kLrtElementTypeInt16 = kTfLiteInt16, - kLrtElementTypeInt32 = kTfLiteInt32, - kLrtElementTypeInt64 = kTfLiteInt64, - kLrtElementTypeUInt8 = kTfLiteUInt8, - kLrtElementTypeUInt16 = kTfLiteUInt16, - kLrtElementTypeUInt32 = kTfLiteUInt32, - kLrtElementTypeUInt64 = kTfLiteUInt64, - kLrtElementTypeFloat16 = kTfLiteFloat16, - kLrtElementTypeBFloat16 = kTfLiteBFloat16, - kLrtElementTypeFloat32 = kTfLiteFloat32, - kLrtElementTypeFloat64 = kTfLiteFloat64, - kLrtElementTypeComplex64 = kTfLiteComplex64, - kLrtElementTypeComplex128 = kTfLiteComplex128, - kLrtElementTypeTfResource = kTfLiteResource, - kLrtElementTypeTfString = kTfLiteString, - kLrtElementTypeTfVariant = kTfLiteVariant, -} LrtElementType; - -typedef struct { - uint32_t rank; - // TODO: b/365299994 - Decide on canonical type(s) for indices({s}32/64). Also - // representation of dynamic dim. - const int32_t* dimensions; -} LrtLayout; - -// Tensor whose rank is dynamic. -typedef struct { - LrtElementType element_type; -} LrtUnrankedTensorType; - -// Tensor whose rank is static but dimenions may be dynamic. -typedef struct { - LrtElementType element_type; - LrtLayout layout; -} LrtRankedTensorType; - -typedef enum { - kLrtRankedTensorType = 0, - kLrtUnrankedTensorType = 1, - // TODO: b/365299994 - q types. -} LrtTensorTypeId; - -// Get type identifier from tensor. -LrtStatus GetTensorTypeId(LrtTensor tensor, LrtTensorTypeId* type_id); - -// Get unranked tensor type info, return bad status if not unranked. -LrtStatus GetUrankedTensorType(LrtTensor tensor, - LrtUnrankedTensorType* unranked_tensor_type); - -// Get ranked tensor type info, return bad status if not ranked. -LrtStatus GetRankedTensorType(LrtTensor tensor, - LrtRankedTensorType* ranked_tensor_type); - -// Get opaque array from given buffer. -LrtStatus GetBufferInfo(LrtBuffer buffer, size_t* size, const void** addr); - -// Get buffer associated with given tensor. All tensors have a buffer, -// null buffers have size = 0; -LrtStatus GetTensorBuffer(LrtTensor tensor, LrtBuffer* buffer); - -// Get all the ops that reference given tensor, and at what operand index. -LrtStatus GetTensorUses(LrtTensor tensor, lrt_param_index_t* num_uses, - LrtOpArray* users, lrt_param_index_t** user_arg_inds); - -// Get the op that defines this tensor and the corresponding output index. If -// tensor is a subgraph input, defining op will be null. -LrtStatus GetTensorDefiningOp(LrtTensor tensor, LrtOp* maybe_defining_op, - lrt_param_index_t* maybe_defining_op_output_ind); - -// -// Op -// - -// Get output tensors of given op. -LrtStatus GetOpOutputs(LrtOp op, lrt_param_index_t* num_outputs, - LrtTensorArray* output); - -// Get input tensors of given op. -LrtStatus GetOpInputs(LrtOp op, lrt_param_index_t* num_inputs, - LrtTensorArray* inputs); - -// Get code corresponding to operation type for given op. -LrtStatus GetOpCode(LrtOp op, LrtOpCode* code); - -// -// Subgraph -// - -// Get input tensors for given subgraph. -LrtStatus GetSubgraphInputs(LrtSubgraph subgraph, lrt_param_index_t* num_inputs, - LrtTensorArray* inputs); - -// Get output tensors for given subgraph. -LrtStatus GetSubgraphOutputs(LrtSubgraph subgraph, - lrt_param_index_t* num_outputs, - LrtTensorArray* outputs); - -// Get all ops in given subgraph in a topological order. -LrtStatus GetSubgraphOps(LrtSubgraph subgraph, lrt_param_index_t* num_ops, - LrtOpArray* ops); - -// -// Model -// - -// Get number of subgraphs in model. -LrtStatus GetModelNumSubgraphs(LrtModel model, - lrt_param_index_t* num_subgraphs); - -// Get subgraph at given index in model. -LrtStatus GetModelSubgraph(LrtModel model, lrt_param_index_t subgraph_index, - LrtSubgraph* subgraph); - -// Get the index of the entry subgraph. -// TODO: b/365299994 - Figure out signatures. -LrtStatus GetModelMainSubgraph(LrtModel model, - lrt_param_index_t* main_subgraph_index); - -LrtStatus PushOp(LrtOpList op_list, LrtOp op); - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_MODEL_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h b/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h deleted file mode 100644 index 52e39f463a0e07..00000000000000 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h +++ /dev/null @@ -1,244 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_OP_CODE_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_OP_CODE_H_ - -#include "tensorflow/lite/builtin_ops.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -typedef enum { - kLrtOpCodeTflAdd = kTfLiteBuiltinAdd, - kLrtOpCodeTflAveragePool2d = kTfLiteBuiltinAveragePool2d, - kLrtOpCodeTflConcatenation = kTfLiteBuiltinConcatenation, - kLrtOpCodeTflConv2d = kTfLiteBuiltinConv2d, - kLrtOpCodeTflDepthwiseConv2d = kTfLiteBuiltinDepthwiseConv2d, - kLrtOpCodeTflDepthToSpace = kTfLiteBuiltinDepthToSpace, - kLrtOpCodeTflDequantize = kTfLiteBuiltinDequantize, - kLrtOpCodeTflEmbeddingLookup = kTfLiteBuiltinEmbeddingLookup, - kLrtOpCodeTflFloor = kTfLiteBuiltinFloor, - kLrtOpCodeTflFullyConnected = kTfLiteBuiltinFullyConnected, - kLrtOpCodeTflHashtableLookup = kTfLiteBuiltinHashtableLookup, - kLrtOpCodeTflL2Normalization = kTfLiteBuiltinL2Normalization, - kLrtOpCodeTflL2Pool2d = kTfLiteBuiltinL2Pool2d, - kLrtOpCodeTflLocalResponseNormalization = - kTfLiteBuiltinLocalResponseNormalization, - kLrtOpCodeTflLogistic = kTfLiteBuiltinLogistic, - kLrtOpCodeTflLshProjection = kTfLiteBuiltinLshProjection, - kLrtOpCodeTflLstm = kTfLiteBuiltinLstm, - kLrtOpCodeTflMaxPool2d = kTfLiteBuiltinMaxPool2d, - kLrtOpCodeTflMul = kTfLiteBuiltinMul, - kLrtOpCodeTflRelu = kTfLiteBuiltinRelu, - kLrtOpCodeTflReluN1To1 = kTfLiteBuiltinReluN1To1, - kLrtOpCodeTflRelu6 = kTfLiteBuiltinRelu6, - kLrtOpCodeTflReshape = kTfLiteBuiltinReshape, - kLrtOpCodeTflResizeBilinear = kTfLiteBuiltinResizeBilinear, - kLrtOpCodeTflRnn = kTfLiteBuiltinRnn, - kLrtOpCodeTflSoftmax = kTfLiteBuiltinSoftmax, - kLrtOpCodeTflSpaceToDepth = kTfLiteBuiltinSpaceToDepth, - kLrtOpCodeTflSvdf = kTfLiteBuiltinSvdf, - kLrtOpCodeTflTanh = kTfLiteBuiltinTanh, - kLrtOpCodeTflConcatEmbeddings = kTfLiteBuiltinConcatEmbeddings, - kLrtOpCodeTflSkipGram = kTfLiteBuiltinSkipGram, - kLrtOpCodeTflCall = kTfLiteBuiltinCall, - kLrtOpCodeTflCustom = kTfLiteBuiltinCustom, - kLrtOpCodeTflEmbeddingLookupSparse = kTfLiteBuiltinEmbeddingLookupSparse, - kLrtOpCodeTflPad = kTfLiteBuiltinPad, - kLrtOpCodeTflUnidirectionalSequenceRnn = - kTfLiteBuiltinUnidirectionalSequenceRnn, - kLrtOpCodeTflGather = kTfLiteBuiltinGather, - kLrtOpCodeTflBatchToSpaceNd = kTfLiteBuiltinBatchToSpaceNd, - kLrtOpCodeTflSpaceToBatchNd = kTfLiteBuiltinSpaceToBatchNd, - kLrtOpCodeTflTranspose = kTfLiteBuiltinTranspose, - kLrtOpCodeTflMean = kTfLiteBuiltinMean, - kLrtOpCodeTflSuv = kTfLiteBuiltinSub, - kLrtOpCodeTflDiv = kTfLiteBuiltinDiv, - kLrtOpCodeTflSqueeze = kTfLiteBuiltinSqueeze, - kLrtOpCodeTflUnidirectionalSequenceLstm = - kTfLiteBuiltinUnidirectionalSequenceLstm, - kLrtOpCodeTflStridedSlice = kTfLiteBuiltinStridedSlice, - kLrtOpCodeTflBidirectionalSequenceRnn = - kTfLiteBuiltinBidirectionalSequenceRnn, - kLrtOpCodeTflExp = kTfLiteBuiltinExp, - kLrtOpCodeTflTopkV2 = kTfLiteBuiltinTopkV2, - kLrtOpCodeTflSplit = kTfLiteBuiltinSplit, - kLrtOpCodeTflLogSoftmax = kTfLiteBuiltinLogSoftmax, - kLrtOpCodeTflDelegate = kTfLiteBuiltinDelegate, - kLrtOpCodeTflBidirectionalSequenceLstm = - kTfLiteBuiltinBidirectionalSequenceLstm, - kLrtOpCodeTflCast = kTfLiteBuiltinCast, - kLrtOpCodeTflPrelu = kTfLiteBuiltinPrelu, - kLrtOpCodeTflMaximum = kTfLiteBuiltinMaximum, - kLrtOpCodeTflArgMax = kTfLiteBuiltinArgMax, - kLrtOpCodeTflMinimum = kTfLiteBuiltinMinimum, - kLrtOpCodeTflLess = kTfLiteBuiltinLess, - kLrtOpCodeTflNeg = kTfLiteBuiltinNeg, - kLrtOpCodeTflPadv2 = kTfLiteBuiltinPadv2, - kLrtOpCodeTflGreater = kTfLiteBuiltinGreater, - kLrtOpCodeTflGreaterEqual = kTfLiteBuiltinGreaterEqual, - kLrtOpCodeTflLessEqual = kTfLiteBuiltinLessEqual, - kLrtOpCodeTflSelect = kTfLiteBuiltinSelect, - kLrtOpCodeTflSlice = kTfLiteBuiltinSlice, - kLrtOpCodeTflSin = kTfLiteBuiltinSin, - kLrtOpCodeTflTransposeConv = kTfLiteBuiltinTransposeConv, - kLrtOpCodeTflSparseToDense = kTfLiteBuiltinSparseToDense, - kLrtOpCodeTflTile = kTfLiteBuiltinTile, - kLrtOpCodeTflExpandDims = kTfLiteBuiltinExpandDims, - kLrtOpCodeTflEqual = kTfLiteBuiltinEqual, - kLrtOpCodeTflNotEqual = kTfLiteBuiltinNotEqual, - kLrtOpCodeTflLog = kTfLiteBuiltinLog, - kLrtOpCodeTflSum = kTfLiteBuiltinSum, - kLrtOpCodeTflSqrt = kTfLiteBuiltinSqrt, - kLrtOpCodeTflRsqrt = kTfLiteBuiltinRsqrt, - kLrtOpCodeTflShape = kTfLiteBuiltinShape, - kLrtOpCodeTflPow = kTfLiteBuiltinPow, - kLrtOpCodeTflArgMin = kTfLiteBuiltinArgMin, - kLrtOpCodeTflFakeQuant = kTfLiteBuiltinFakeQuant, - kLrtOpCodeTflReduceProd = kTfLiteBuiltinReduceProd, - kLrtOpCodeTflReduceMax = kTfLiteBuiltinReduceMax, - kLrtOpCodeTflPack = kTfLiteBuiltinPack, - kLrtOpCodeTflLogicalOr = kTfLiteBuiltinLogicalOr, - kLrtOpCodeTflOneHot = kTfLiteBuiltinOneHot, - kLrtOpCodeTflLogicalAnd = kTfLiteBuiltinLogicalAnd, - kLrtOpCodeTflLogicalNot = kTfLiteBuiltinLogicalNot, - kLrtOpCodeTflUnpack = kTfLiteBuiltinUnpack, - kLrtOpCodeTflReduceMin = kTfLiteBuiltinReduceMin, - kLrtOpCodeTflFloorDiv = kTfLiteBuiltinFloorDiv, - kLrtOpCodeTflReduceAny = kTfLiteBuiltinReduceAny, - kLrtOpCodeTflSquare = kTfLiteBuiltinSquare, - kLrtOpCodeTflZerosLike = kTfLiteBuiltinZerosLike, - kLrtOpCodeTflFill = kTfLiteBuiltinFill, - kLrtOpCodeTflFloorMod = kTfLiteBuiltinFloorMod, - kLrtOpCodeTflRange = kTfLiteBuiltinRange, - kLrtOpCodeTflResizeNearestNeighbor = kTfLiteBuiltinResizeNearestNeighbor, - kLrtOpCodeTflLeakyRelu = kTfLiteBuiltinLeakyRelu, - kLrtOpCodeTflSquaredDifference = kTfLiteBuiltinSquaredDifference, - kLrtOpCodeTflMirrorPad = kTfLiteBuiltinMirrorPad, - kLrtOpCodeTflAbs = kTfLiteBuiltinAbs, - kLrtOpCodeTflSplitV = kTfLiteBuiltinSplitV, - kLrtOpCodeTflUnique = kTfLiteBuiltinUnique, - kLrtOpCodeTflCeil = kTfLiteBuiltinCeil, - kLrtOpCodeTflReverseV2 = kTfLiteBuiltinReverseV2, - kLrtOpCodeTflAddN = kTfLiteBuiltinAddN, - kLrtOpCodeTflGatherNd = kTfLiteBuiltinGatherNd, - kLrtOpCodeTflCos = kTfLiteBuiltinCos, - kLrtOpCodeTflWhere = kTfLiteBuiltinWhere, - kLrtOpCodeTflRank = kTfLiteBuiltinRank, - kLrtOpCodeTflElu = kTfLiteBuiltinElu, - kLrtOpCodeTflReverseSequence = kTfLiteBuiltinReverseSequence, - kLrtOpCodeTflMatrixDiag = kTfLiteBuiltinMatrixDiag, - kLrtOpCodeTflQuantize = kTfLiteBuiltinQuantize, - kLrtOpCodeTflMatrixSetDiag = kTfLiteBuiltinMatrixSetDiag, - kLrtOpCodeTflRound = kTfLiteBuiltinRound, - kLrtOpCodeTflHardSwish = kTfLiteBuiltinHardSwish, - kLrtOpCodeTflIf = kTfLiteBuiltinIf, - kLrtOpCodeTflWhile = kTfLiteBuiltinWhile, - kLrtOpCodeTflNonMaxSuppressionV4 = kTfLiteBuiltinNonMaxSuppressionV4, - kLrtOpCodeTflNonMaxSuppressionV5 = kTfLiteBuiltinNonMaxSuppressionV5, - kLrtOpCodeTflScatterNd = kTfLiteBuiltinScatterNd, - kLrtOpCodeTflSelectV2 = kTfLiteBuiltinSelectV2, - kLrtOpCodeTflDensify = kTfLiteBuiltinDensify, - kLrtOpCodeTflSegmentSum = kTfLiteBuiltinSegmentSum, - kLrtOpCodeTflBatchMatmul = kTfLiteBuiltinBatchMatmul, - kLrtOpCodeTflPlaceholderForGreaterOpCodeTfls = - kTfLiteBuiltinPlaceholderForGreaterOpCodes, - kLrtOpCodeTflCumsum = kTfLiteBuiltinCumsum, - kLrtOpCodeTflCallOnce = kTfLiteBuiltinCallOnce, - kLrtOpCodeTflBroadcastTo = kTfLiteBuiltinBroadcastTo, - kLrtOpCodeTflRfft2d = kTfLiteBuiltinRfft2d, - kLrtOpCodeTflConv3d = kTfLiteBuiltinConv3d, - kLrtOpCodeTflImag = kTfLiteBuiltinImag, - kLrtOpCodeTflReal = kTfLiteBuiltinReal, - kLrtOpCodeTflComplexAbs = kTfLiteBuiltinComplexAbs, - kLrtOpCodeTflHashtable = kTfLiteBuiltinHashtable, - kLrtOpCodeTflHashtableFind = kTfLiteBuiltinHashtableFind, - kLrtOpCodeTflHashtableImport = kTfLiteBuiltinHashtableImport, - kLrtOpCodeTflHashtableSize = kTfLiteBuiltinHashtableSize, - kLrtOpCodeTflReduceAll = kTfLiteBuiltinReduceAll, - kLrtOpCodeTflConv3dTranspose = kTfLiteBuiltinConv3dTranspose, - kLrtOpCodeTflVarHandle = kTfLiteBuiltinVarHandle, - kLrtOpCodeTflReadVariable = kTfLiteBuiltinReadVariable, - kLrtOpCodeTflAssignVariable = kTfLiteBuiltinAssignVariable, - kLrtOpCodeTflBroadcastArgs = kTfLiteBuiltinBroadcastArgs, - kLrtOpCodeTflRandomStandardNormal = kTfLiteBuiltinRandomStandardNormal, - kLrtOpCodeTflBucketize = kTfLiteBuiltinBucketize, - kLrtOpCodeTflRandomUniform = kTfLiteBuiltinRandomUniform, - kLrtOpCodeTflMultinomial = kTfLiteBuiltinMultinomial, - kLrtOpCodeTflGelu = kTfLiteBuiltinGelu, - kLrtOpCodeTflDynamicUpdateSlice = kTfLiteBuiltinDynamicUpdateSlice, - kLrtOpCodeTflRelu0To1 = kTfLiteBuiltinRelu0To1, - kLrtOpCodeTflUnsortedSegmentProd = kTfLiteBuiltinUnsortedSegmentProd, - kLrtOpCodeTflUnsortedSegmentMax = kTfLiteBuiltinUnsortedSegmentMax, - kLrtOpCodeTflUnsortedSegmentSum = kTfLiteBuiltinUnsortedSegmentSum, - kLrtOpCodeTflAtan2 = kTfLiteBuiltinAtan2, - kLrtOpCodeTflUnsortedSegmentMin = kTfLiteBuiltinUnsortedSegmentMin, - kLrtOpCodeTflSign = kTfLiteBuiltinSign, - kLrtOpCodeTflBitcast = kTfLiteBuiltinBitcast, - kLrtOpCodeTflBitwiseXor = kTfLiteBuiltinBitwiseXor, - kLrtOpCodeTflRightShift = kTfLiteBuiltinRightShift, - kLrtOpCodeShloLogistic = kTfLiteBuiltinStablehloLogistic, - kLrtOpCodeShloAdd = kTfLiteBuiltinStablehloAdd, - kLrtOpCodeShloDivide = kTfLiteBuiltinStablehloDivide, - kLrtOpCodeShloMultiply = kTfLiteBuiltinStablehloMultiply, - kLrtOpCodeShloMaximum = kTfLiteBuiltinStablehloMaximum, - kLrtOpCodeShloReshape = kTfLiteBuiltinStablehloReshape, - kLrtOpCodeShloClamp = kTfLiteBuiltinStablehloClamp, - kLrtOpCodeShloConcatenate = kTfLiteBuiltinStablehloConcatenate, - kLrtOpCodeShloBroadcastInDim = kTfLiteBuiltinStablehloBroadcastInDim, - kLrtOpCodeShloConvolution = kTfLiteBuiltinStablehloConvolution, - kLrtOpCodeShloSlice = kTfLiteBuiltinStablehloSlice, - kLrtOpCodeShloCustomCall = kTfLiteBuiltinStablehloCustomCall, - kLrtOpCodeShloReduce = kTfLiteBuiltinStablehloReduce, - kLrtOpCodeShloAbs = kTfLiteBuiltinStablehloAbs, - kLrtOpCodeShloAnd = kTfLiteBuiltinStablehloAnd, - kLrtOpCodeShloCosine = kTfLiteBuiltinStablehloCosine, - kLrtOpCodeShloExponential = kTfLiteBuiltinStablehloExponential, - kLrtOpCodeShloFloor = kTfLiteBuiltinStablehloFloor, - kLrtOpCodeShloLog = kTfLiteBuiltinStablehloLog, - kLrtOpCodeShloMinimum = kTfLiteBuiltinStablehloMinimum, - kLrtOpCodeShloNegate = kTfLiteBuiltinStablehloNegate, - kLrtOpCodeShloOr = kTfLiteBuiltinStablehloOr, - kLrtOpCodeShloPower = kTfLiteBuiltinStablehloPower, - kLrtOpCodeShloRemainder = kTfLiteBuiltinStablehloRemainder, - kLrtOpCodeShloRsqrt = kTfLiteBuiltinStablehloRsqrt, - kLrtOpCodeShloSelect = kTfLiteBuiltinStablehloSelect, - kLrtOpCodeShloSubtract = kTfLiteBuiltinStablehloSubtract, - kLrtOpCodeShloTanh = kTfLiteBuiltinStablehloTanh, - kLrtOpCodeShloScatter = kTfLiteBuiltinStablehloScatter, - kLrtOpCodeShloCompare = kTfLiteBuiltinStablehloCompare, - kLrtOpCodeShloConvert = kTfLiteBuiltinStablehloConvert, - kLrtOpCodeShloDynamicSlice = kTfLiteBuiltinStablehloDynamicSlice, - kLrtOpCodeShloDynamicUpdateSlice = kTfLiteBuiltinStablehloDynamicUpdateSlice, - kLrtOpCodeShloPad = kTfLiteBuiltinStablehloPad, - kLrtOpCodeShloIota = kTfLiteBuiltinStablehloIota, - kLrtOpCodeShloGeneral = kTfLiteBuiltinStablehloDotGeneral, - kLrtOpCodeShloWindow = kTfLiteBuiltinStablehloReduceWindow, - kLrtOpCodeShloSort = kTfLiteBuiltinStablehloSort, - kLrtOpCodeShloWhile = kTfLiteBuiltinStablehloWhile, - kLrtOpCodeShloGather = kTfLiteBuiltinStablehloGather, - kLrtOpCodeShloTranspose = kTfLiteBuiltinStablehloTranspose, - kLrtOpCodeTflDilate = kTfLiteBuiltinDilate, - kLrtOpCodeShloRngBitGenerator = kTfLiteBuiltinStablehloRngBitGenerator, - kLrtOpCodeTflReduceWindow = kTfLiteBuiltinReduceWindow, - kLrtOpCodeShloComposite = kTfLiteBuiltinStablehloComposite, -} LrtOpCode; - -#ifdef __cplusplus -} -#endif // __cplusplus - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_OP_CODE_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_support.h b/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_support.h deleted file mode 100644 index 20c8ce08be76ef..00000000000000 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_support.h +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_SUPPORT_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_SUPPORT_H_ - -#include - -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h" // IWYU pragma: keep - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -// #define LRT_ABORT abort() -// TODO: b/365295276 - Find a fatal error approach that will pass kokoro. -#define LRT_ABORT - -#define LRT_FATAL(msg) \ - { \ - fprintf(stderr, "%s\n", (msg)); \ - LRT_ABORT; \ - } - -#define LRT_RETURN_STATUS_IF_NOT_OK(expr) \ - { \ - LrtStatus stat = expr; \ - if (GetStatusCode(stat) != kLrtStatusOk) return stat; \ - StatusDestroy(stat); \ - } - -// TODO: b/365295276 - Add optional debug only print messages support -// to all macros. -#define LRT_RETURN_STATUS_IF_NOT_OK_MSG(expr, d_msg) \ - LRT_RETURN_STATUS_IF_NOT_OK(expr) - -#define LRT_RETURN_VAL_IF_NOT_OK(expr, ret_val) \ - { \ - LrtStatus stat = expr; \ - LrtStatusCode code_ = GetStatusCode(stat); \ - StatusDestroy(stat); \ - if (code_ != kLrtStatusOk) return ret_val; \ - } - - -#ifdef __cplusplus -} // extern "C" -#endif - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_C_LITE_RT_SUPPORT_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h b/tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h deleted file mode 100644 index f030c37079a616..00000000000000 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h +++ /dev/null @@ -1,194 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CC_LITE_RT_SUPPORT_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CC_LITE_RT_SUPPORT_H_ - -#include - -#include -#include - -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h" // IWYU pragma: keep -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_compiler_plugin.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_support.h" // IWYU pragma: export -#ifndef NDEBUG -#include // IWYU pragma: keep -#endif - -#define _CONCAT_NAME_IMPL(x, y) x##y - -#define _CONCAT_NAME(x, y) _CONCAT_NAME_IMPL(x, y) - -#define _RETURN_VAL(val) return val - -struct LrtStatusDeleter { - void operator()(LrtStatus status) { - if (status != nullptr) { - StatusDestroy(status); - } - } -}; - -using UniqueLrtStatus = std::unique_ptr; - -inline UniqueLrtStatus UniqueStatusFromCode(LrtStatusCode code) { - return UniqueLrtStatus(StatusCreate(code)); -} - -inline UniqueLrtStatus UniqueStatusOk() { - return UniqueStatusFromCode(kLrtStatusOk); -} - -// TODO: b/365295276 - Put all smart pointer wrappers in support.h. -struct LrtCompilerPluginDeleter { - void operator()(LrtCompilerPlugin plugin) { - if (plugin != nullptr) { - LrtPluginDestroy(plugin); - } - } -}; - -using UniqueLrtCompilerPlugin = - std::unique_ptr; - -// `StatusOr` analog for lrt. Very basic currently. -// TODO: b/365295276 - Figure out how to better infer template param -// and not require passing typing to macros. -template -class LrtResult { - public: - // TODO: b/365295276 - Implement emplace for LrtResult. - - static LrtResult FromValue(const T& value) { - LrtResult result; - result.data_ = value; - return result; - } - - static LrtResult TakeValue(T&& value) { - LrtResult result; - result.data_ = std::move(value); - return result; - } - - static LrtResult FromCode(LrtStatusCode code) { - LrtResult result; - result.data_ = code; - return result; - } - - T& Value() { - if (!HasValue()) { - LRT_FATAL("Result does not contain a value."); - } - return std::get(data_); - } - - LrtStatusCode Code() { - if (std::holds_alternative(data_)) { - return kLrtStatusOk; - } - return std::get(data_); - } - - bool HasValue() { return std::holds_alternative(data_); } - - private: - std::variant data_; -}; - -#ifdef NDEBUG -#define _LRT_D_MSG(msg) -#else -#define _LRT_D_MSG(msg) \ - std::cerr << msg << " " << __FILE__ << ":" << __LINE__ << "\n"; -#endif - -#ifdef LRT_RETURN_STATUS_IF_NOT_OK_MSG -#undef LRT_RETURN_STATUS_IF_NOT_OK_MSG -#define LRT_RETURN_STATUS_IF_NOT_OK_MSG(expr, d_msg) \ - { \ - LrtStatus stat = expr; \ - if (GetStatusCode(stat) != kLrtStatusOk) { \ - _LRT_D_MSG(d_msg) \ - return stat; \ - } \ - StatusDestroy(stat); \ - } -#endif - -// TODO: b/365295276 Make c friendly `CHECK` macro(s) and move to c api. -#define LRT_CHECK_STATUS_HAS_CODE_MSG(expr, code, d_msg) \ - { \ - LrtStatus stat = expr; \ - CHECK_NE(stat, nullptr); \ - LrtStatusCode code_ = GetStatusCode(stat); \ - StatusDestroy(stat); \ - if (code_ != code) { \ - _LRT_D_MSG(d_msg) \ - CHECK(false); \ - } \ - } - -#define LRT_CHECK_STATUS_HAS_CODE(expr, code) \ - LRT_CHECK_STATUS_HAS_CODE_MSG(expr, code, ""); - -#define LRT_CHECK_STATUS_OK(expr) LRT_CHECK_STATUS_HAS_CODE(expr, kLrtStatusOk); - -#define LRT_CHECK_STATUS_OK_MSG(expr, d_msg) \ - LRT_CHECK_STATUS_HAS_CODE_MSG(expr, kLrtStatusOk, d_msg); - -// If expr doesn't retur ok status, wrap as result and return. -#define LRT_RETURN_RESULT_IF_NOT_OK(expr, ty) \ - { \ - LrtStatus stat = (expr); \ - LrtStatusCode code_ = GetStatusCode(stat); \ - StatusDestroy(stat); \ - if (code_ != kLrtStatusOk) return LrtResult::FromCode(code_); \ - } - -#define _ASSIGN_OR_BLOCK(decl, expr, block, result) \ - auto result = (expr); \ - if (!result.HasValue()) { \ - block; \ - } \ - decl = result.Value(); - -#define _ASSIGN_OR_RETURN_VAL(decl, expr, val, result) \ - _ASSIGN_OR_BLOCK(decl, expr, _RETURN_VAL(val), result) - -// Assign value behind result returned from expr. If not ok, return val. -#define LRT_ASSIGN_OR_RETURN_VAL(decl, expr, val) \ - _ASSIGN_OR_RETURN_VAL(decl, expr, val, _CONCAT_NAME(_result, __COUNTER__)) - -#define _STATUS_FROM_RESULT(result) StatusCreate(result.Code()); - -#define _ASSIGN_OR_RETURN_STATUS(decl, expr, result) \ - _ASSIGN_OR_RETURN_VAL(decl, expr, _STATUS_FROM_RESULT(result), result) - -// Assign value behind result returned from expr. If not ok, return status. -#define LRT_ASSIGN_OR_RETURN_STATUS(decl, expr) \ - _ASSIGN_OR_RETURN_STATUS(decl, expr, _CONCAT_NAME(_result, __COUNTER__)) - -#define _FORWARD_RESULT(result, ty) LrtResult::FromCode(result.Code()); - -#define _ASSIGN_OR_RETURN_RESULT(decl, expr, ty, result) \ - _ASSIGN_OR_RETURN_VAL(decl, expr, _FORWARD_RESULT(result, ty), result) - -// Assign value behind result returned from expr. If not ok, return result. -#define LRT_ASSIGN_OR_RETURN_RESULT(decl, expr, ty) \ - _ASSIGN_OR_RETURN_RESULT(decl, expr, ty, _CONCAT_NAME(_result, __COUNTER__)) - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CC_LITE_RT_SUPPORT_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/core/BUILD b/tensorflow/compiler/mlir/lite/experimental/lrt/core/BUILD deleted file mode 100644 index d584670b49730e..00000000000000 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/core/BUILD +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/compiler/mlir/lite/experimental/lrt:__subpackages__"], -) - -cc_library( - name = "api_internal", - srcs = ["lite_rt_common.cc"], - hdrs = [ - "//tensorflow/compiler/mlir/lite/experimental/lrt/c:lite_rt_common.h", - "//tensorflow/compiler/mlir/lite/experimental/lrt/c:lite_rt_compiler_plugin.h", - "//tensorflow/compiler/mlir/lite/experimental/lrt/c:lite_rt_model.h", - "//tensorflow/compiler/mlir/lite/experimental/lrt/c:lite_rt_op_code.h", - "//tensorflow/compiler/mlir/lite/experimental/lrt/c:lite_rt_support.h", - "//tensorflow/compiler/mlir/lite/experimental/lrt/cc:lite_rt_support.h", - ], - deps = [ - "//tensorflow/lite:builtin_ops", - "//tensorflow/lite/core/c:c_api_types", - ], -) - -cc_library( - name = "model", - srcs = ["model.cc"], - hdrs = [ - "model.h", - ], - deps = [ - ":api_internal", - "//tensorflow/compiler/mlir/lite/core:model_builder_base", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/core/c:c_api_types", - "//tensorflow/lite/schema:schema_fbs", - ], -) - -cc_library( - name = "lite_rt_model_init", - srcs = ["lite_rt_model_init.cc"], - hdrs = ["lite_rt_model_init.h"], - deps = [ - ":api_internal", - ":model", - "//tensorflow/compiler/mlir/lite:allocation", - "//tensorflow/compiler/mlir/lite/core:model_builder_base", - "//tensorflow/lite:framework", - "//tensorflow/lite:stderr_reporter", - "//tensorflow/lite/c:c_api_types", - "//tensorflow/lite/schema:schema_fbs", - "@com_google_absl//absl/log:check", - "@flatbuffers//:runtime_cc", - ], -) - -cc_test( - name = "model_test", - srcs = ["model_test.cc"], - tags = ["no_oss"], - deps = [ - ":api_internal", - ":graph_tools", - ":lite_rt_model_init", - "//tensorflow/compiler/mlir/lite/experimental/lrt/test_data:test_data_util", - "//tensorflow/lite/schema:schema_fbs", - "@com_google_googletest//:gtest_main", - "@flatbuffers//:runtime_cc", - "@llvm-project//llvm:Support", - ], -) - -cc_library( - name = "algo", - hdrs = ["algo.h"], - deps = [ - ":api_internal", - ":model", - "//tensorflow/lite/schema:schema_fbs", - "@com_google_absl//absl/log:check", - "@llvm-project//llvm:Support", - ], -) - -cc_test( - name = "algo_test", - srcs = ["algo_test.cc"], - tags = ["no_oss"], - deps = [ - ":algo", - ":api_internal", - ":graph_tools", - ":model", - "//tensorflow/compiler/mlir/lite/experimental/lrt/test_data:test_data_util", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "graph_tools", - hdrs = [ - "graph_tools.h", - ], - deps = [ - ":api_internal", - "//tensorflow/compiler/mlir/lite/core:model_builder_base", - "//tensorflow/lite/c:c_api_types", - "@llvm-project//llvm:Support", - ], -) diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/core/graph_tools.h b/tensorflow/compiler/mlir/lite/experimental/lrt/core/graph_tools.h deleted file mode 100644 index 49c937fb74af20..00000000000000 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/core/graph_tools.h +++ /dev/null @@ -1,355 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CORE_GRAPH_TOOLS_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CORE_GRAPH_TOOLS_H_ - -#include -#include -#include - -#ifndef NDEBUG -#include -#endif - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallVector.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h" - -#define _D_MATCH_TRUE(v) \ - { \ - std::cerr << "failed match true " << __FILE__ << __LINE__ << "\n"; \ - if (!(v)) return false; \ - } - -#define _D_MATCH_EQ(lhs, rhs) \ - { \ - std::cerr << "failed match eq " << __FILE__ << __LINE__ << "\n"; \ - if (lhs != rhs) return false; \ - } - -#define _MATCH_EQ(lhs, rhs) \ - { \ - if (lhs != rhs) return false; \ - } - -#define _MATCH_TRUE(v) \ - { \ - if (!(v)) return false; \ - } - -#ifndef NDEBUG -#define MATCH_EQ(lhs, rhs) _D_MATCH_EQ(lhs, rhs) -#define MATCH_TRUE(v) _D_MATCH_TRUE(v) -#else -#define MATCH_EQ(lhs, rhs) _MATCH_EQ(lhs, rhs) -#define MATCH_TRUE(v) _MATCH_TRUE(v) -#endif - -namespace graph_tools { - -using RankedTypeInfo = std::tuple>; - -using TensorUseInfo = std::tuple; - -//===----------------------------------------------------------------------===// -// Getters // -//===----------------------------------------------------------------------===// - -// TODO: b/365299994 - Switch llvm container types for mobile friendly ones. -// Likely will need to define them. - -// Get the ops that reference given tensor. -inline LrtResult> GetTensorUses( - LrtTensor tensor) { - lrt_param_index_t num_uses; - lrt_param_index_t* use_user_arg_ind; - LrtOpArray users = nullptr; - - LRT_RETURN_RESULT_IF_NOT_OK( - GetTensorUses(tensor, &num_uses, &users, &use_user_arg_ind), - llvm::SmallVector); - - llvm::ArrayRef users_arr(users, num_uses); - llvm::ArrayRef user_arg_ind_arr(use_user_arg_ind, - num_uses); - - auto results = llvm::zip(users_arr, user_arg_ind_arr); - llvm::SmallVector results_vec(results.begin(), results.end()); - - return LrtResult>::FromValue(results_vec); -} - -// Get the only user of given tensor, bad status if tensor doesn't have -// exactly one user. -inline LrtResult GetTensorOnlyUse(LrtTensor tensor) { - LRT_ASSIGN_OR_RETURN_RESULT(auto uses, GetTensorUses(tensor), TensorUseInfo); - if (uses.size() != 1) { - return LrtResult::FromCode(kLrtStatusGraphInvariantError); - } - return LrtResult::FromValue(uses[0]); -} - -// Get tensor inputs to given op. -inline LrtResult> GetOpIns(LrtOp op) { - lrt_param_index_t num_inputs; - LrtTensorArray inputs = nullptr; - - LRT_RETURN_RESULT_IF_NOT_OK(GetOpInputs(op, &num_inputs, &inputs), - llvm::ArrayRef); - - return LrtResult>::FromValue( - llvm::ArrayRef(inputs, num_inputs)); -} - -// Get the only tensor input to given op, bad status if op doesn't have -// exacty one input. -inline LrtResult GetOnlyOpIn(LrtOp op) { - LRT_ASSIGN_OR_RETURN_RESULT(auto ins, GetOpIns(op), LrtTensor); - if (ins.size() != 1) { - return LrtResult::FromCode(kLrtStatusGraphInvariantError); - } - return LrtResult::FromValue(ins[0]); -} - -// Get tensors outputs to given op. -inline LrtResult> GetOpOuts(LrtOp op) { - lrt_param_index_t num_outputs; - LrtTensorArray outputs = nullptr; - - LRT_RETURN_RESULT_IF_NOT_OK(GetOpOutputs(op, &num_outputs, &outputs), - llvm::ArrayRef); - - return LrtResult>::FromValue( - llvm::ArrayRef(outputs, num_outputs)); -} - -// Get the only tensor output to given op, bad status if op doesn't have -// exacty one output. -inline LrtResult GetOnlyOpOut(LrtOp op) { - LRT_ASSIGN_OR_RETURN_RESULT(auto outs, GetOpOuts(op), LrtTensor); - if (outs.size() != 1) { - return LrtResult::FromCode(kLrtStatusGraphInvariantError); - } - return LrtResult::FromValue(outs[0]); -} - -// Get all ops in given subgraph in topological order. -inline LrtResult> GetSubgraphOps(LrtSubgraph subgraph) { - lrt_param_index_t num_ops; - LrtOpArray ops = nullptr; - LRT_RETURN_RESULT_IF_NOT_OK(GetSubgraphOps(subgraph, &num_ops, &ops), - llvm::ArrayRef); - - return LrtResult>::FromValue( - llvm::ArrayRef(ops, num_ops)); -} - -// Get tensor inputs to given subgraph. -inline LrtResult> GetSubgraphInputs( - LrtSubgraph subgraph) { - lrt_param_index_t num_inputs; - LrtTensorArray inputs = nullptr; - LRT_RETURN_RESULT_IF_NOT_OK(GetSubgraphInputs(subgraph, &num_inputs, &inputs), - llvm::ArrayRef); - - return LrtResult>::FromValue( - llvm::ArrayRef(inputs, num_inputs)); -} - -// Get tensor outputs to given subgraph. -inline LrtResult> GetSubgraphOutputs( - LrtSubgraph subgraph) { - lrt_param_index_t num_outputs; - LrtTensorArray outputs = nullptr; - LRT_RETURN_RESULT_IF_NOT_OK( - GetSubgraphOutputs(subgraph, &num_outputs, &outputs), - llvm::ArrayRef); - - return LrtResult>::FromValue( - llvm::ArrayRef(outputs, num_outputs)); -} - -// Get only subgraph in given model, bad status if model doens't have exactly -// one subgraph. -// TODO: b/365299994 - Add multi-subgraph getters for graph tools. -inline LrtResult GetSubgraph(LrtModel model) { - lrt_param_index_t num_subgraphs; - LRT_RETURN_RESULT_IF_NOT_OK(GetModelNumSubgraphs(model, &num_subgraphs), - LrtSubgraph); - - if (num_subgraphs != 1) { - return LrtResult::FromCode(kLrtStatusErrorUnsupported); - } - - LrtSubgraph subgraph = nullptr; - LRT_RETURN_RESULT_IF_NOT_OK(GetModelSubgraph(model, 0, &subgraph), - LrtSubgraph); - - return LrtResult::FromValue(subgraph); -} - -//===----------------------------------------------------------------------===// -// Matchers // -//===----------------------------------------------------------------------===// - -// Matches tensor type id, shape and element type for given tensor. -inline bool MatchRankedTensorType(LrtTensor tensor, LrtElementType element_type, - llvm::ArrayRef shape) { - LrtTensorTypeId type_id; - LRT_RETURN_VAL_IF_NOT_OK(GetTensorTypeId(tensor, &type_id), false); - MATCH_EQ(type_id, kLrtRankedTensorType); - - LrtRankedTensorType ranked_tensor_type; - LRT_RETURN_VAL_IF_NOT_OK(GetRankedTensorType(tensor, &ranked_tensor_type), - false); - MATCH_EQ(ranked_tensor_type.element_type, element_type); - MATCH_EQ(ranked_tensor_type.layout.rank, shape.size()); - - for (int i = 0; i < shape.size(); ++i) { - MATCH_EQ(shape[i], ranked_tensor_type.layout.dimensions[i]); - } - - return true; -} - -// Matches users of given tensor (ordering doesn't matter). If strict is true, -// `use_info` must have same number of elements as tensor has uses. If not, -// it must be a subset. -inline bool MatchTensorHasUses(LrtTensor tensor, - llvm::ArrayRef use_info, - bool strict = true) { - // uses are unique so this is sufficient to check for equality. - LRT_ASSIGN_OR_RETURN_VAL(auto uses, GetTensorUses(tensor), false); - MATCH_TRUE(!strict || (uses.size() == use_info.size())); - - llvm::SetVector unique_uses(uses.begin(), uses.end()); - - return llvm::all_of(use_info, - [&](auto use) { return unique_uses.contains(use); }); -} - -// Matches a tensor with no uses. -inline bool MatchkTensorNoUses(LrtTensor tensor) { - lrt_param_index_t num_uses; - lrt_param_index_t* use_user_arg_ind; - LrtOpArray users = nullptr; - - LRT_RETURN_VAL_IF_NOT_OK( - GetTensorUses(tensor, &num_uses, &users, &use_user_arg_ind), false); - - return num_uses == 0; -} - -// Matches a tensors defining op and output indice. -inline bool MatchTensorDefiningOp( - LrtTensor tensor, lrt_param_index_t expected_defining_op_out_ind, - LrtOp expected_defining_op) { - LrtOp defining_op = nullptr; - lrt_param_index_t defining_op_out_ind; - - LRT_RETURN_VAL_IF_NOT_OK( - GetTensorDefiningOp(tensor, &defining_op, &defining_op_out_ind), false); - MATCH_EQ(defining_op, expected_defining_op); - - return expected_defining_op == nullptr || - expected_defining_op_out_ind == defining_op_out_ind; -} - -// Matches a tensor that is not the output of an op (subgraph inputs/consts). -inline bool MatchTensorNoDefiningOp(LrtTensor tensor) { - return MatchTensorDefiningOp(tensor, 0, nullptr); -} - -// Matches the op code and types of given ops inputs and outputs. -inline bool MatchOpType(LrtOp op, - llvm::ArrayRef input_type_info, - llvm::ArrayRef output_type_info, - LrtOpCode code) { - LrtOpCode actual_code; - LRT_RETURN_VAL_IF_NOT_OK(GetOpCode(op, &actual_code), false); - MATCH_EQ(actual_code, code); - - const auto exptected_num_inputs = input_type_info.size(); - - LRT_ASSIGN_OR_RETURN_VAL(auto inputs, GetOpIns(op), false); - for (int i = 0; i < exptected_num_inputs; ++i) { - const auto& [type, shape] = input_type_info[i]; - MATCH_TRUE(MatchRankedTensorType(inputs[i], type, shape)); - } - - const auto expected_num_outputs = output_type_info.size(); - - LRT_ASSIGN_OR_RETURN_VAL(auto outputs, GetOpOuts(op), false); - for (int i = 0; i < expected_num_outputs; ++i) { - const auto& [type, shape] = output_type_info[i]; - MATCH_TRUE(MatchRankedTensorType(outputs[i], type, shape)); - } - - return true; -} - -// Checks that doubly linked structure of ops <-> tensors is valid. -inline bool ValidateTopology(llvm::ArrayRef ops) { - for (auto& op : ops) { - LRT_ASSIGN_OR_RETURN_VAL(auto inputs, GetOpIns(op), false); - for (auto [input_ind, input] : llvm::enumerate(inputs)) { - MATCH_TRUE(MatchTensorHasUses(input, {{op, input_ind}}, false)); - } - - LRT_ASSIGN_OR_RETURN_VAL(auto outputs, GetOpOuts(op), false); - for (auto [output_ind, output] : llvm::enumerate(outputs)) { - MATCH_TRUE(MatchTensorDefiningOp(output, output_ind, op)); - } - } - return true; -} - -// Match buffer behind given tensor contains data. -template -inline bool MatchBuffer(LrtTensor tensor, llvm::ArrayRef expected_data) { - LrtBuffer buffer = nullptr; - LRT_RETURN_VAL_IF_NOT_OK(GetTensorBuffer(tensor, &buffer), false); - MATCH_TRUE(buffer != nullptr); - - size_t size; - const void* data = nullptr; - LRT_RETURN_VAL_IF_NOT_OK(GetBufferInfo(buffer, &size, &data), false); - MATCH_TRUE(data != nullptr); - - MATCH_EQ(size, expected_data.size() * sizeof(T)); - return llvm::ArrayRef(static_cast(data), expected_data.size()) == - expected_data; -} - -// Match given tensor having no (empty) buffer. -inline bool MatchNoBuffer(LrtTensor tensor) { - LrtBuffer buffer = nullptr; - LRT_RETURN_VAL_IF_NOT_OK(GetTensorBuffer(tensor, &buffer), false); - MATCH_TRUE(buffer != nullptr); - - size_t size; - const void* data = nullptr; - LRT_RETURN_VAL_IF_NOT_OK(GetBufferInfo(buffer, &size, &data), false); - - return size == 0; -} -} // namespace graph_tools - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CORE_GRAPH_TOOLS_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/core/model.cc b/tensorflow/compiler/mlir/lite/experimental/lrt/core/model.cc deleted file mode 100644 index 02e024532c98fe..00000000000000 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/core/model.cc +++ /dev/null @@ -1,165 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/model.h" - -#include - -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h" - -// -// Model -// - -LrtStatus GetModelNumSubgraphs(LrtModel model, - lrt_param_index_t* num_subgraphs) { - *num_subgraphs = model->subgraphs.size(); - return StatusOk(); -} - -LrtStatus GetModelSubgraph(LrtModel model, lrt_param_index_t subgraph_index, - LrtSubgraph* subgraph) { - if (subgraph_index >= model->subgraphs.size()) { - return StatusCreate(kLrtParamIndexOOB); - } - *subgraph = model->subgraphs.data() + subgraph_index; - return StatusOk(); -} - -LrtStatus GetModelMainSubgraph(LrtModel model, - lrt_param_index_t* main_subgraph_index) { - // TODO replace this with signature. - *main_subgraph_index = 0; - return StatusOk(); -} - -void ModelDestroy(LrtModel model) { delete model; } - -LrtStatus PushOp(LrtOpList op_list, LrtOp op) { - op_list->ops.push_back(op); - return StatusOk(); -} - -// -// Subgraph -// - -LrtStatus GetSubgraphInputs(LrtSubgraph subgraph, lrt_param_index_t* num_inputs, - LrtTensorArray* inputs) { - *num_inputs = subgraph->inputs.size(); - *inputs = subgraph->inputs.data(); - return StatusOk(); -} - -LrtStatus GetSubgraphOutputs(LrtSubgraph subgraph, - lrt_param_index_t* num_outputs, - LrtTensorArray* outputs) { - *num_outputs = subgraph->outputs.size(); - *outputs = subgraph->outputs.data(); - return StatusOk(); -} - -LrtStatus GetSubgraphOps(LrtSubgraph subgraph, lrt_param_index_t* num_ops, - LrtOpArray* ops) { - *num_ops = subgraph->ops.size(); - *ops = subgraph->ops.data(); - return StatusOk(); -} - -// -// Op -// - -LrtStatus GetOpOutputs(LrtOp op, lrt_param_index_t* num_outputs, - LrtTensorArray* outputs) { - *num_outputs = op->outputs.size(); - *outputs = op->outputs.data(); - return StatusOk(); -} - -LrtStatus GetOpInputs(LrtOp op, lrt_param_index_t* num_inputs, - LrtTensorArray* inputs) { - *num_inputs = op->inputs.size(); - *inputs = op->inputs.data(); - return StatusOk(); -} - -LrtStatus GetOpCode(LrtOp op, LrtOpCode* code) { - *code = op->op_code; - return StatusOk(); -} - -// -// Tensor -// - -LrtStatus GetBufferInfo(LrtBuffer buffer, size_t* size, const void** addr) { - if (buffer->fb_buffer == nullptr) { - *size = 0; - *addr = nullptr; - } else { - *size = buffer->fb_buffer->data.size(); - *addr = buffer->fb_buffer->data.data(); - } - return StatusOk(); -} - -LrtStatus GetTensorBuffer(LrtTensor tensor, LrtBuffer* buffer) { - *buffer = &tensor->buffer; - return StatusOk(); -} - -LrtStatus GetTensorUses(LrtTensor tensor, lrt_param_index_t* num_uses, - LrtOpArray* use_users, - lrt_param_index_t** use_user_arg_inds) { - *num_uses = tensor->users.size(); - *use_users = tensor->users.data(); - *use_user_arg_inds = tensor->user_arg_inds.data(); - return StatusOk(); -} - -// Null if subgraph input or constant. -LrtStatus GetTensorDefiningOp(LrtTensor tensor, LrtOp* maybe_defining_op, - lrt_param_index_t* maybe_defining_op_output_ind) { - if (tensor->defining_op != nullptr) { - *maybe_defining_op = tensor->defining_op; - *maybe_defining_op_output_ind = tensor->defining_op_out_ind; - } - return StatusOk(); -} - -LrtStatus GetTensorTypeId(LrtTensor tensor, LrtTensorTypeId* type_id) { - *type_id = tensor->type_id; - return StatusOk(); -} - -LrtStatus GetUrankedTensorType(LrtTensor tensor, - LrtUnrankedTensorType* unranked_tensor_type) { - if (tensor->type_id != kLrtUnrankedTensorType) { - return StatusCreate(kLrtStatusBadTensorType); - } - *unranked_tensor_type = tensor->type_detail.unranked_tensor_type; - return StatusOk(); -} - -LrtStatus GetRankedTensorType(LrtTensor tensor, - LrtRankedTensorType* ranked_tensor_type) { - if (tensor->type_id != kLrtRankedTensorType) { - return StatusCreate(kLrtStatusBadTensorType); - } - *ranked_tensor_type = tensor->type_detail.ranked_tensor_type; - return StatusOk(); -} diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/core/model.h b/tensorflow/compiler/mlir/lite/experimental/lrt/core/model.h deleted file mode 100644 index 72d0f7d4e0e8af..00000000000000 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/core/model.h +++ /dev/null @@ -1,242 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CORE_MODEL_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CORE_MODEL_H_ - -#include -#ifndef NDEBUG -#include -#include -#endif - -#include -#include - -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h" -#include "tensorflow/lite/core/c/c_api_types.h" -#include "tensorflow/lite/schema/schema_generated.h" - -// -// Tensor -// - -struct LrtBufferT { - std::unique_ptr fb_buffer = nullptr; -}; - -typedef union { - LrtUnrankedTensorType unranked_tensor_type; - LrtRankedTensorType ranked_tensor_type; -} LrtTypeDetail; - -struct LrtTensorT { - // Empty if subgraph output. This is a reference. - std::vector users; - - // Which arg number for user i. - std::vector user_arg_inds; - - // Null if subgraph input or constant. This is a reference. - LrtOp defining_op = nullptr; - - // Which output ind from defining op made this tensor. - lrt_param_index_t defining_op_out_ind; - - // Not a reference. - LrtBufferT buffer; - - LrtTensorTypeId type_id; - - LrtTypeDetail type_detail; -}; - -// -// Op -// - -struct LrtOpT { - // These are references. - std::vector inputs; - - // These are references. - std::vector outputs; - - LrtOpCode op_code; - - // This is a placeholder to be usd by just custom ops for now. - std::string custom_options; - - // TODO: b/365299994 - Add support for op options. -}; - -// -// Subgraph -// - -struct LrtSubgraphT { - // Storage and views of tensors. Clients are only shown views. Facilitates - // efficient topological mutation. - std::list tensors_storage; - std::vector tensors; - - // Storage and vies of ops. - std::list ops_storage; - std::vector ops; - - // Shared view of initial flatbuffer data. - std::shared_ptr flatbuffer_subgraph; - - // These are references and a subset of `tensors`. - std::vector inputs; - - // These are references and a subset of `tensors`. - std::vector outputs; -}; - -// -// Model -// - -// A (partial) unpacking of the flatbuffer model into a list of subgraphs. -// Keeps a reference to the flatbuffer model. Lifetimes of all storage -// are linked to the containing model. -struct LrtModelT { - // Subgraphs that have been unpacked into usable types. - std::vector subgraphs; - - // TODO: b/365299994 - Delete this. - // Shared views of remaining unpacked flatbuffer data. - std::vector> flatbuffer_subgraphs; - - // Initial flatbuffer loaded in. "Subgraphs" field has been invalidated. - std::unique_ptr flatbuffer_model; - - // Custom code associated with all customs ops emitted during - // re-serialization. - std::string custom_op_code; -}; - -// -// Utils -// - -// Used for communicating selections of ops. -struct LrtOpListT { - std::vector ops; -}; - -namespace debug { - -// TODO: b/365299994 - Flesh out printing api and move elsewhere. -inline void DumpOp(const LrtOpT& op) { -#ifndef NDEBUG - using DumpInfo = std::pair, std::string>; - - auto op_name = [&](const LrtOpT& op) -> std::string { - std::stringstream result; - switch (op.op_code) { - case kLrtOpCodeTflAdd: - result << "TFL_ADD"; - break; - case kLrtOpCodeTflMul: - result << "TFL_MUL"; - break; - case kLrtOpCodeTflCustom: - result << "TFL_CUSTOM_OP"; - break; - default: - result << "UKNOWN_OP_CODE: " << op.op_code; - break; - } - result << " " << &op; - return result.str(); - }; - - // TODO: b/365299994 - Pull tensor dump into separate functiona nd - // only dump relevant topology when called in DumpOp. - auto tensor_dump = [&](const LrtTensorT& tensor) -> DumpInfo { - DumpInfo result; - - for (int i = 0; i < tensor.users.size(); ++i) { - auto& user = result.first.emplace_back(); - char* s; - asprintf(&s, "%s [%lu], ", op_name(*tensor.users[i]).c_str(), - tensor.user_arg_inds[i]); - user.assign(s); - free(s); - } - - if (tensor.defining_op != nullptr) { - char* s; - asprintf(&s, "%s [%lu], ", op_name(*tensor.defining_op).c_str(), - tensor.defining_op_out_ind); - result.second.assign(s); - free(s); - } else { - result.second = "NO DEF OP"; - } - - return result; - }; - - auto validate_tensor = [](const LrtTensorT& tensor) -> void { - if (tensor.users.size() != tensor.user_arg_inds.size()) { - LRT_FATAL("Invalid tensor."); - } - }; - - auto print_users = [](const DumpInfo& info) { - for (const auto& user : info.first) { - std::cerr << " USER: " << user << "\n"; - } - }; - - auto print_def = [](const DumpInfo& info) { - std::cerr << " DEFINING OP: " << info.second << "\n"; - }; - - std::cerr << op_name(op) << " {\n"; - - for (const auto& inp : op.inputs) { - validate_tensor(*inp); - std::cerr << " INPUT: " << &inp << "\n"; - print_def(tensor_dump(*inp)); - std::cerr << "\n"; - } - - for (const auto& out : op.outputs) { - validate_tensor(*out); - std::cerr << " OUTPUT: " << &out << "\n"; - print_users(tensor_dump(*out)); - if (out != op.outputs.back()) { - std::cerr << "\n"; - } - } - - std::cerr << "}\n"; -#endif -} - -} // namespace debug - -// TODO: b/365299994 - Make dumping a generic streamable. -#define LRT_DUMP_OP(op) \ - _LRT_D_MSG(""); \ - debug::DumpOp(op); - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CORE_MODEL_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/examples/BUILD b/tensorflow/compiler/mlir/lite/experimental/lrt/examples/BUILD deleted file mode 100644 index fbb21622ab2d30..00000000000000 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/examples/BUILD +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:private"], -) - -cc_library( - name = "mul_op_plugin", - srcs = ["mul_op_plugin.cc"], - deps = [ - "//tensorflow/compiler/mlir/lite/experimental/lrt/c:lite_rt_c_api", - "//tensorflow/compiler/mlir/lite/experimental/lrt/cc:lite_rt_cc_api", - "//tensorflow/compiler/mlir/lite/experimental/lrt/core:graph_tools", - ], -) - -cc_shared_library( - name = "mul_op_plugin_so", - shared_lib_name = "mul_op_plugin.so", - visibility = ["//tensorflow/compiler/mlir/lite/experimental/lrt:__subpackages__"], - deps = [":mul_op_plugin"], -) - -cc_test( - name = "mul_op_plugin_test", - srcs = ["mul_op_plugin_test.cc"], - tags = ["no_oss"], - deps = [ - ":mul_op_plugin", # buildcleaner: keep - "//tensorflow/compiler/mlir/lite/experimental/lrt/c:lite_rt_c_api", - "//tensorflow/compiler/mlir/lite/experimental/lrt/cc:lite_rt_cc_api", - "//tensorflow/compiler/mlir/lite/experimental/lrt/core:graph_tools", - "//tensorflow/compiler/mlir/lite/experimental/lrt/core:model", - "//tensorflow/compiler/mlir/lite/experimental/lrt/test_data:test_data_util", - "@com_google_absl//absl/log:check", - "@com_google_googletest//:gtest_main", - ], -) diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/examples/mul_op_plugin.cc b/tensorflow/compiler/mlir/lite/experimental/lrt/examples/mul_op_plugin.cc deleted file mode 100644 index 867195abf29377..00000000000000 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/examples/mul_op_plugin.cc +++ /dev/null @@ -1,177 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include -#include -#include -#include - -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_compiler_plugin.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/graph_tools.h" - -// -// Configurations -// - -constexpr char kPluginMan[] = "ExampleSocManufacturer"; -constexpr char kPluginModel[] = "DummyMulOp"; - -const char* LrtPluginSocManufacturer() { return kPluginMan; } - -lrt_param_index_t LrtPluginNumSupportedSocModels( - LrtCompilerPlugin compiler_plugin) { - return 1; -} - -LrtStatus LrtPluginGetSupportedSocModelId(LrtCompilerPlugin compiler_plugin, - lrt_param_index_t config_idx, - const char** config_id) { - if (config_idx != 0) { - return StatusCreate(kLrtStatusErrorUnsupported); - } - *config_id = kPluginModel; - return StatusOk(); -} - -// -// Compiled Result Definition -// - -struct LrtCompiledResultT { - std::string byte_code; - std::vector per_op_data; -}; - -LrtStatus LrtCompiledResultGetByteCode(LrtCompiledResult compiled_result, - const void** byte_code, - size_t* byte_code_size) { - *byte_code = compiled_result->byte_code.data(); - *byte_code_size = compiled_result->byte_code.size(); - return StatusOk(); -} - -LrtStatus LrtCompiledResultGetCallInfo(LrtCompiledResult compiled_result, - lrt_param_index_t call_idx, - const void** call_info, - size_t* call_info_size) { - if (call_idx >= compiled_result->per_op_data.size()) { - return StatusCreate(kLrtParamIndexOOB); - } - - *call_info = compiled_result->per_op_data.at(call_idx).data(); - *call_info_size = compiled_result->per_op_data.at(call_idx).size(); - - return StatusOk(); -} - -LrtStatus LrtCompiledResultGetNumCalls(LrtCompiledResult compiled_result, - lrt_param_index_t* num_calls) { - *num_calls = compiled_result->per_op_data.size(); - return StatusOk(); -} - -void LrtCompiledResultDestroy(LrtCompiledResult compiled_result) { - delete compiled_result; -} - -// -// Plugin Definition -// - -// Plugins can hold state. -struct LrtCompilerPluginT { -}; - -LrtStatus LrtPluginInit(LrtCompilerPlugin* compiler_plugin) { - *compiler_plugin = new LrtCompilerPluginT; - return StatusOk(); -} - -void LrtPluginDestroy(LrtCompilerPlugin compiler_plugin) { - delete compiler_plugin; -} - -LrtStatus LrtPluginPartitionModel(LrtCompilerPlugin compiler_plugin, - LrtModel model, LrtOpList selected_ops) { - LRT_ASSIGN_OR_RETURN_STATUS(auto subgraph, graph_tools::GetSubgraph(model)); - LRT_ASSIGN_OR_RETURN_STATUS(auto ops, graph_tools::GetSubgraphOps(subgraph)); - - for (auto op : ops) { - LrtOpCode op_code; - LRT_RETURN_STATUS_IF_NOT_OK(GetOpCode(op, &op_code)); - if (op_code != kLrtOpCodeTflMul) { - continue; - } - LRT_RETURN_STATUS_IF_NOT_OK(PushOp(selected_ops, op)); - } - return StatusOk(); -} - -LrtStatus CompileSinglePartition(lrt_param_index_t partition_index, - LrtSubgraph subgraph, - LrtCompiledResultT& result) { - LRT_ASSIGN_OR_RETURN_STATUS(auto ops, graph_tools::GetSubgraphOps(subgraph)); - - int num_muls_in_partition = 0; - for (auto op : ops) { - LrtOpCode op_code; - - LRT_RETURN_STATUS_IF_NOT_OK(GetOpCode(op, &op_code)); - if (op_code != kLrtOpCodeTflMul) { - return StatusCreate(kLrtStatusErrorUnsupported); - } - - ++num_muls_in_partition; - } - - { - char* byte_code_append; - (void)asprintf(&byte_code_append, - "Partition_%lu_with_%d_muls:", partition_index, - num_muls_in_partition); - result.byte_code.append(byte_code_append); - free(byte_code_append); - } - - { - char* per_op_data; - (void)asprintf(&per_op_data, "Partition_%lu", partition_index); - result.per_op_data.push_back(per_op_data); - free(per_op_data); - } - - return StatusOk(); -} - -LrtStatus LrtPluginCompile(LrtCompilerPlugin compiler_plugin, - LrtSubgraphArray partitions, - lrt_param_index_t num_partitions, - LrtCompiledResult* compiled_result) { - LrtCompiledResult result = new LrtCompiledResultT; - - for (auto i = 0; i < num_partitions; ++i) { - LRT_RETURN_STATUS_IF_NOT_OK( - CompileSinglePartition(i, partitions[i], *result)); - } - - *compiled_result = result; - - return StatusOk(); -} diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/examples/mul_op_plugin_test.cc b/tensorflow/compiler/mlir/lite/experimental/lrt/examples/mul_op_plugin_test.cc deleted file mode 100644 index 46eafd32136a38..00000000000000 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/examples/mul_op_plugin_test.cc +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include - -#include -#include "absl/log/check.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_compiler_plugin.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/graph_tools.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/test_data/test_data_util.h" - -namespace { - -UniqueLrtCompilerPlugin GetDummyPlugin() { - LrtCompilerPlugin dummy_plugin; - LRT_CHECK_STATUS_OK(LrtPluginInit(&dummy_plugin)); - CHECK_NE(dummy_plugin, nullptr); - return UniqueLrtCompilerPlugin(dummy_plugin); -} - -TEST(TestDummyPlugin, GetConfigInfo) { - ASSERT_STREQ(LrtPluginSocManufacturer(), "ExampleSocManufacturer"); - - auto plugin = GetDummyPlugin(); - - ASSERT_EQ(1, LrtPluginNumSupportedSocModels(plugin.get())); - - const char* config_id; - ASSERT_STATUS_OK( - LrtPluginGetSupportedSocModelId(plugin.get(), 0, &config_id)); - ASSERT_STREQ(config_id, "DummyMulOp"); -} - -TEST(TestCallDummyPlugin, PartitionSimpleMultiAdd) { - auto plugin = GetDummyPlugin(); - auto model = LoadTestFileModel("simple_multi_op.tflite"); - - LrtOpListT selected_ops; - ASSERT_STATUS_OK( - LrtPluginPartitionModel(plugin.get(), model.get(), &selected_ops)); - - ASSERT_EQ(selected_ops.ops.size(), 2); - ASSERT_EQ(selected_ops.ops[0]->op_code, kLrtOpCodeTflMul); - ASSERT_EQ(selected_ops.ops[1]->op_code, kLrtOpCodeTflMul); -} - -TEST(TestCallDummyPlugin, CompileMulSubgraph) { - auto plugin = GetDummyPlugin(); - auto model = LoadTestFileModel("mul_simple.tflite"); - - ASSERT_RESULT_OK_ASSIGN(auto subgraph, graph_tools::GetSubgraph(model.get())); - - LrtCompiledResult compiled; - ASSERT_STATUS_OK(LrtPluginCompile(plugin.get(), &subgraph, 1, &compiled)); - - const void* byte_code; - size_t byte_code_size; - - ASSERT_STATUS_OK( - LrtCompiledResultGetByteCode(compiled, &byte_code, &byte_code_size)); - - std::string byte_code_string(reinterpret_cast(byte_code), - byte_code_size); - ASSERT_EQ(byte_code_string, "Partition_0_with_2_muls:"); - - const void* op_data; - size_t op_data_size; - - ASSERT_STATUS_OK( - LrtCompiledResultGetCallInfo(compiled, 0, &op_data, &op_data_size)); - - std::string op_data_string(reinterpret_cast(op_data), - op_data_size); - ASSERT_EQ(op_data_string, "Partition_0"); - - LrtCompiledResultDestroy(compiled); -} - -} // namespace diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/test_data/BUILD b/tensorflow/compiler/mlir/lite/experimental/lrt/test_data/BUILD deleted file mode 100644 index cee72cde127b9a..00000000000000 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/test_data/BUILD +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright 2024 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/compiler/mlir/lite/experimental/lrt:__subpackages__"], -) - -# TODO: b/365295276 - Make custom rule and move to `.sh`. - -OUT_DIR = "$(RULEDIR)" - -CONVERTER = "//tensorflow/compiler/mlir/lite:tf_tfl_translate" - -CMD = """ -for mlir_file in $(SRCS); do - $(location {converter}) --input-mlir $$mlir_file --o={out_dir}/$$(basename $$mlir_file .mlir).tflite -done -""".format( - converter = CONVERTER, - out_dir = OUT_DIR, -) - -genrule( - name = "tflite_test_data", - srcs = glob(["*.mlir"]), - outs = [s.removesuffix(".mlir") + ".tflite" for s in glob(["*.mlir"])], - cmd = CMD, - tools = ["//tensorflow/compiler/mlir/lite:tf_tfl_translate"], -) - -cc_library( - name = "test_data_util", - testonly = 1, - hdrs = ["test_data_util.h"], - data = [":tflite_test_data"], - deps = [ - "//tensorflow/compiler/mlir/lite/experimental/lrt/c:lite_rt_c_api", - "//tensorflow/compiler/mlir/lite/experimental/lrt/cc:lite_rt_cc_api", - "//tensorflow/compiler/mlir/lite/experimental/lrt/core:lite_rt_model_init", - "//tensorflow/compiler/mlir/lite/experimental/lrt/core:model", - "@com_google_absl//absl/log:check", - "@local_tsl//tsl/platform", - ], -) diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/test_data/test_data_util.h b/tensorflow/compiler/mlir/lite/experimental/lrt/test_data/test_data_util.h deleted file mode 100644 index 889a19711b190f..00000000000000 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/test_data/test_data_util.h +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright 2024 Google LLC. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_TEST_DATA_TEST_DATA_UTIL_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_TEST_DATA_TEST_DATA_UTIL_H_ - -// NOLINTNEXTLINE -#include -#include -#include - -#include "absl/log/check.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_model_init.h" -#include "tsl/platform/platform.h" - -#define _ASSERT_RESULT_OK_ASSIGN(decl, expr, result) \ - auto result = (expr); \ - ASSERT_TRUE(result.HasValue()); \ - decl = result.Value(); - -#define ASSERT_RESULT_OK_ASSIGN(decl, expr) \ - _ASSERT_RESULT_OK_ASSIGN(decl, expr, _CONCAT_NAME(_result, __COUNTER__)) - -#define ASSERT_STATUS_HAS_CODE(expr, code) \ - { \ - auto stat = (expr); \ - auto code_ = GetStatusCode(stat); \ - StatusDestroy(stat); \ - ASSERT_EQ(code_, code); \ - } - -#define ASSERT_STATUS_OK(expr) ASSERT_STATUS_HAS_CODE(expr, kLrtStatusOk); - -inline std::string GetTestFilePath(std::string_view filename) { - static constexpr std::string_view kTestDataDir = - "tensorflow/compiler/mlir/lite/experimental/lrt/" - "test_data/"; - - std::filesystem::path result_path; - if constexpr (!tsl::kIsOpenSource) { - result_path.append("third_party"); - } - - result_path.append(kTestDataDir); - result_path.append(filename); - - return result_path.generic_string(); -} - -inline UniqueLrtModel LoadTestFileModel(std::string_view filename) { - LrtModel model = nullptr; - LRT_CHECK_STATUS_OK( - LoadModelFromFile(GetTestFilePath(filename).c_str(), &model)); - CHECK_NE(model, nullptr); - return UniqueLrtModel(model); -} - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_TEST_DATA_TEST_DATA_UTIL_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/remat/BUILD b/tensorflow/compiler/mlir/lite/experimental/remat/BUILD index 089d5e695ea20a..73937af298e896 100644 --- a/tensorflow/compiler/mlir/lite/experimental/remat/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/remat/BUILD @@ -26,11 +26,7 @@ cc_library( srcs = ["metadata_util.cc"], hdrs = ["metadata_util.h"], compatible_with = get_compatible_with_portable(), - visibility = [ - "//tensorflow/compiler/mlir/lite:__pkg__", - "//tensorflow/lite/core:__pkg__", - "//tensorflow/lite/delegates:__pkg__", - ], + visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/mlir/lite:control_edges", ], diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h b/tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h index a89d7cacb7e6e7..741ee5d5203d39 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h +++ b/tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h @@ -20,7 +20,7 @@ limitations under the License. #include "mlir/Bytecode/BytecodeOpInterface.h" // from @llvm-project #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.cc index 0c37a8da20575f..d9b3711339e55c 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.cc @@ -21,7 +21,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.cc index 278c54e8805f3d..49113a666a4172 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.cc @@ -22,7 +22,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index fad3e5c1409372..896131bf877915 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -60,7 +60,7 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index a289126d26b6ca..512f6047f1c058 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -29,7 +29,9 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "llvm/ADT/APFloat.h" @@ -49,8 +51,8 @@ limitations under the License. #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project @@ -80,10 +82,11 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/offset_buffer.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/lite/schema/mutable/debug_metadata_generated.h" #include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" #include "tensorflow/compiler/mlir/lite/utils/const_tensor_utils.h" #include "tensorflow/compiler/mlir/lite/utils/control_edges.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" @@ -99,6 +102,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/platform/errors.h" +#include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" @@ -130,6 +134,17 @@ using ::mlir::tf_saved_model::kTfSavedModelExportedNamesAttr; using ::mlir::tf_saved_model::kTfSavedModelIndexPathAttr; using ::tflite::IsValidBufferOffset; +struct DebugMetadata { + // Debug metadata locations. + std::vector debug_metadata_locations; + + // Maps from operator (subgraph_debug_metadata_idx, + // operator_debug_metadata_idx) to its top-level location index in + // `debug_metadata_locations`, which is: + // <, location_idx>. + absl::flat_hash_map> operator_location_map; +}; + // Create the MLIR NamedLoc location corresponding to a given tensor Location TensorLoc(const TensorT& tensor, Builder builder, Location base) { if (tensor.name.empty()) { @@ -138,27 +153,223 @@ Location TensorLoc(const TensorT& tensor, Builder builder, Location base) { return mlir::NameLoc::get(builder.getStringAttr(tensor.name), base); } -// Create the MLIR Location corresponding to a given op. This is an -// experimental/debugging feature and production code should not rely on names -// of intermediate tensors since importer doesn't guarantee to preserve tensor -// names except output tensors. -Location OpLoc(const OperatorT& op, - const std::vector>& tensors, - Builder builder, Location base) { +// Build and return the MLIR location. +StatusOr BuildLocation( + Builder builder, const debug_metadata::Location& location, + const std::vector& debug_metadata_locations, + const absl::flat_hash_map& + attribute_location_idx_map) { + switch (location.location_type()) { + // FileLineColLoc. + case debug_metadata::LocationType_FileLineColLoc: { + auto file_line_col_loc = + static_cast( + location.location()); + return mlir::FileLineColLoc::get( + builder.getContext(), + builder.getStringAttr(file_line_col_loc->filename()->string_view()), + file_line_col_loc->line(), file_line_col_loc->column()); + } + // CallSiteLoc. + case debug_metadata::LocationType_CallSiteLoc: { + auto callsite_loc = + static_cast(location.location()); + if (!attribute_location_idx_map.contains(callsite_loc->callee_index()) || + !attribute_location_idx_map.contains(callsite_loc->caller_index())) { + return absl::InternalError( + "Invalid/corrupt DebugMetadata, expected invariant broken (callee " + "or caller index of a CallSiteLoc is not valid)"); + } + return mlir::CallSiteLoc::get( + debug_metadata_locations[attribute_location_idx_map.at( + callsite_loc->callee_index())], + debug_metadata_locations[attribute_location_idx_map.at( + callsite_loc->caller_index())]); + } + // NameLoc. + case debug_metadata::LocationType_NameLoc: { + auto name_loc = + static_cast(location.location()); + if (!attribute_location_idx_map.contains(name_loc->child_index())) { + return absl::InternalError( + "Invalid/corrupt DebugMetadata, expected invariant broken (child " + "index of a NameLoc is not valid)"); + } + return mlir::NameLoc::get( + builder.getStringAttr(name_loc->name()->string_view()), + debug_metadata_locations[attribute_location_idx_map.at( + name_loc->child_index())]); + } + // FusedLoc. + case debug_metadata::LocationType_FusedLoc: { + auto fused_loc = + static_cast(location.location()); + auto fused_location_indexes = fused_loc->location_indexes(); + std::vector fused_locations; + fused_locations.reserve(fused_location_indexes->size()); + for (int fused_loc_idx = 0; + fused_loc_idx < fused_location_indexes->size(); ++fused_loc_idx) { + if (!attribute_location_idx_map.contains( + fused_location_indexes->Get(fused_loc_idx))) { + return absl::InternalError( + "Invalid/corrupt DebugMetadata, expected invariant broken " + "(location index of a FusedLoc is not valid)"); + } + fused_locations.push_back( + debug_metadata_locations[attribute_location_idx_map.at( + fused_location_indexes->Get(fused_loc_idx))]); + } + return mlir::FusedLoc::get( + fused_locations, mlir::StringAttr::get(builder.getContext(), ""), + builder.getContext()); + } + default: { + return mlir::UnknownLoc::get(builder.getContext()); + } + } +} + +// Parses all locations in ConversionDebugMetadata, build the mlir::location +// counterparts, and put them inside debug_metadata_. Additionally, maintain a +// map that maps the top location index of each operator. +Status ParseAndBuildLocation( + Builder builder, + const debug_metadata::ConversionDebugMetadata* conversion_debug_metadata, + DebugMetadata& debug_metadata_var) { + auto attribute_types = conversion_debug_metadata->attributes_type(); + auto attributes = conversion_debug_metadata->attributes(); + + auto& debug_metadata_locations = debug_metadata_var.debug_metadata_locations; + debug_metadata_locations.reserve(attribute_types->size()); + + // Map index in the attribute_vector to the index in the data structure we + // are building: DebugMetadata::debug_metadata_locations. + absl::flat_hash_map attribute_location_idx_map; + + for (int i = 0; i < attribute_types->size(); ++i) { + if (attribute_types->Get(i) == debug_metadata::Attribute_Location) { + auto location = + static_cast(attributes->Get(i)); + TF_ASSIGN_OR_RETURN( + auto mlir_location, + BuildLocation(builder, *location, debug_metadata_locations, + attribute_location_idx_map)); + debug_metadata_locations.push_back(mlir_location); + + // Create index mapping. + attribute_location_idx_map[i] = debug_metadata_locations.size() - 1; + } + } + + // Collect the top location idx of each operator. + auto subgraphs_debug_metadata = + conversion_debug_metadata->subgraphs_debug_metadata(); + for (int subgraph_idx = 0; subgraph_idx < subgraphs_debug_metadata->size(); + ++subgraph_idx) { + const auto* subgraph_debug_metadata = + subgraphs_debug_metadata->Get(subgraph_idx); + auto operators_debug_metadata = + subgraph_debug_metadata->operators_debug_metadata(); + for (int operator_idx = 0; operator_idx < operators_debug_metadata->size(); + ++operator_idx) { + const auto* operator_debug_metadata = + operators_debug_metadata->Get(operator_idx); + // Find the location attribute of the operator. Note that there should + // be at most one idx pointing to location attribute for each operator. + std::vector location_attribute_idxs; + for (int i = 0; + i < operator_debug_metadata->attribute_metadata_indexes()->size(); + ++i) { + auto attribute_idx = + operator_debug_metadata->attribute_metadata_indexes()->Get(i); + if (attribute_types->Get(attribute_idx) == + debug_metadata::Attribute_Location) { + location_attribute_idxs.push_back(attribute_idx); + } + } + if (location_attribute_idxs.size() > 1) { + return absl::InternalError( + "Invalid/corrupt DebugMetadata, expected invariant broken (more " + "than one location attribute for an operator)"); + } + if (location_attribute_idxs.empty()) { + continue; + } + + if (!attribute_location_idx_map.contains(location_attribute_idxs[0])) { + return absl::InternalError( + "Invalid/corrupt DebugMetadata, expected invariant broken " + "(location attribute index of an operator is not valid)"); + } + debug_metadata_var.operator_location_map[subgraph_idx][operator_idx] = + attribute_location_idx_map[location_attribute_idxs[0]]; + } + } + + return absl::OkStatus(); +} + +// Parse the DebugMetadata flatbuffer and store debug metadata in struct +// `debug_metadata`. +Status ParseDebugMetadata(Builder builder, const char* data, size_t size, + DebugMetadata& debug_metadata_var) { + auto debug_metadata_fb = debug_metadata::GetDebugMetadata(data); + + if (debug_metadata_fb->debug_metadata_type()->size() != + debug_metadata_fb->debug_metadata()->size()) { + return absl::InternalError( + "Invalid/corrupt DebugMetadata, expected invariant broken (size of " + "debug_metadata_type and debug_metadata not equal)"); + } + + for (int i = 0; i < debug_metadata_fb->debug_metadata_type()->size(); ++i) { + if (debug_metadata_fb->debug_metadata_type()->Get(i) == + debug_metadata::DebugMetadataType_ConversionDebugMetadata) { + auto conversion_debug_metadata = + static_cast( + debug_metadata_fb->debug_metadata()->Get(i)); + TF_RETURN_IF_ERROR(ParseAndBuildLocation( + builder, conversion_debug_metadata, debug_metadata_var)); + } else { + LOG(WARNING) << "Unsupported DebugMetadataType: " + << debug_metadata_fb->debug_metadata_type()->Get(i); + } + } + + return absl::OkStatus(); +} + +// Return MLIR location if it exists in the debug metadata. Otherwise, create a +// MLIR location by fusing its output tensor names. +Location OpLoc(const OperatorT& op, Builder builder, + DebugMetadata& debug_metadata, const tflite::SubGraphT& subgraph, + Location base) { + const int subgraph_debug_metadata_idx = subgraph.debug_metadata_index; + if (debug_metadata.operator_location_map.contains( + subgraph_debug_metadata_idx) && + debug_metadata.operator_location_map[subgraph_debug_metadata_idx] + .contains(op.debug_metadata_index)) { + int location_idx = + debug_metadata.operator_location_map[subgraph_debug_metadata_idx] + [op.debug_metadata_index]; + return debug_metadata.debug_metadata_locations[location_idx]; + } + if (op.outputs.empty()) return base; llvm::SmallVector locations; locations.reserve(op.outputs.size()); for (auto tensor_index : op.outputs) { - locations.push_back(TensorLoc(*tensors[tensor_index], builder, base)); + locations.push_back( + TensorLoc(*subgraph.tensors[tensor_index], builder, base)); } return mlir::FusedLoc::get(builder.getContext(), locations); } // Extract the min max information in the tensor and create the quant stats op. -// If the input `tensor` has scale/zero_point, `res` should have quantized -// type, thus none stats op is required and nullptr is returned. -// If the min max information is invalid, nullptr is returned. +// If the input `tensor` has scale/zero_point, `res` should have quantized type, +// thus none stats op is required and nullptr is returned. If the min max +// information is invalid, nullptr is returned. mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b, Value res) { // If the `tensor` has scale/zero_point, it must have been quantized, then the @@ -678,8 +889,8 @@ StatusOr ConvertOp( } // While the last several tensors could be optional tensors for an tfl op, the - // number of input operands could vary. Gets the min/max number of - // operands from tflite op name. + // number of input operands could vary. Gets the min/max number of operands + // from tflite op name. // Also, since the above code special-handles the `tfl.reshape` op and add an // additional input, we put these function block here. llvm::MinMax input_min_max = mlir::OperandNumbersMinMax(op_name); @@ -1117,7 +1328,7 @@ StatusOr ConvertSubgraph( const tflite::SignatureDefT* signature, const tflite::ControlEdges& control_edges, const std::unique_ptr& model_ptr, - bool use_stablehlo_constant) { + bool use_stablehlo_constant, DebugMetadata& debug_metadata) { // Populate from metadata. ControlNodes control_nodes; for (const auto [from, to] : control_edges) { @@ -1301,11 +1512,12 @@ StatusOr ConvertSubgraph( TF_ASSIGN_OR_RETURN( mlir::TensorType type, tfl::GetTensorType(*subgraph.tensors[intermediate], builder, - /*is_constant=*/false, /*is_intermediate=*/true)); + /*is_constant=*/false, + /*is_intermediate=*/true)); intermediate_types.emplace_back(type); } - auto op_loc = OpLoc(*op, subgraph.tensors, builder, base_loc); + auto op_loc = OpLoc(*op, builder, debug_metadata, subgraph, base_loc); // If there's an optional argument, maybe_optional_arg_marker has been set // to a valid Value @@ -1504,7 +1716,7 @@ OwningOpRef tflite::FlatBufferToMlir( const bool disable_vhlo_to_stablehlo) { mlir::DialectRegistry registry; registry.insert(); @@ -1513,7 +1725,7 @@ OwningOpRef tflite::FlatBufferToMlir( context->loadDialect< mlir::arith::ArithDialect, mlir::func::FuncDialect, - mlir::quant::QuantizationDialect, + mlir::quant::QuantDialect, mlir::quantfork::QuantizationForkDialect, mlir::TFL::TensorFlowLiteDialect, mlir::TF::TensorFlowDialect, mlir::stablehlo::StablehloDialect, mlir::vhlo::VhloDialect>(); @@ -1535,6 +1747,7 @@ OwningOpRef tflite::FlatBufferToMlir( llvm::SmallVector metadata_attrs; mlir::StringSet<> seen_attr; + DebugMetadata debug_metadata; for (const auto& metadata : model->metadata) { if (metadata->name == tflite::kModelControlDependenciesMetadataKey) { const std::vector& data = model->buffers[metadata->buffer]->data; @@ -1559,6 +1772,17 @@ OwningOpRef tflite::FlatBufferToMlir( continue; } + if (metadata->name == "debug_metadata") { + const std::vector& data = model->buffers[metadata->buffer]->data; + auto status = ParseDebugMetadata( + builder, reinterpret_cast(data.data()), data.size(), + debug_metadata); + if (!status.ok()) { + return emitError(base_loc, std::string(status.message())), nullptr; + } + continue; + } + std::vector buffer = model->buffers[metadata->buffer]->data; metadata_attrs.emplace_back( builder.getStringAttr(metadata->name), @@ -1618,7 +1842,7 @@ OwningOpRef tflite::FlatBufferToMlir( ? subgraph_to_signature_map.at(subgraph_index) : nullptr, model_control_dependencies[subgraph_index], model_ptr, - use_stablehlo_constant); + use_stablehlo_constant, debug_metadata); if (!func_or_error.ok()) { return emitError(base_loc, "could not translate function ") << subgraph->name << ": " << func_or_error.status().message(), diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc index b84a981ca9e541..bc281bb5fbad44 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc @@ -24,8 +24,8 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -229,7 +229,7 @@ static TranslateToMLIRRegistration FlatBufferFileToMlirTransReg( static TranslateFromMLIRRegistration MLIRToFlatBufferTranslate( "mlir-to-tflite-flatbuffer", "mlir-to-tflite-flatbuffer", MlirToFlatBufferFileTranslateFunction, [](DialectRegistry& registry) { - registry.insert(); mlir::RegisterAllTensorFlowDialects(registry); registry.insert(); diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 1b7dd7c77dc2df..d5d2d22e303768 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -52,7 +52,7 @@ limitations under the License. #include "llvm/Support/Threading.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/Traits.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -5255,13 +5255,12 @@ Attribute ConstBytesAttr::parse(AsmParser& parser, Type type) { if (parser.parseString(&data)) { return nullptr; } - if (data.size() < 2 || data.substr(0, 2) != "0x") { - parser.emitError(parser.getNameLoc(), "Hex string doesn't start with `0x`"); - return nullptr; + if (data.size() >= 2 && data.substr(0, 2) == "0x") { + std::string bytes_data = absl::HexStringToBytes(data.substr(2)); + return ConstBytesAttr::get(parser.getBuilder().getContext(), bytes_data); } - std::string bytes_data = absl::HexStringToBytes(data.substr(2)); - return ConstBytesAttr::get(parser.getBuilder().getContext(), bytes_data); + return ConstBytesAttr::get(parser.getBuilder().getContext(), data); } void ConstBytesAttr::print(mlir::AsmPrinter& printer) const { diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h index ed3a963bbb7523..5946ce0f31da73 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h @@ -25,6 +25,8 @@ limitations under the License. #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/DialectImplementation.h" // from @llvm-project #include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/TypeSupport.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project #include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project #include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 77b16fe3fb3cf8..68c1c1192937c4 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -23,7 +23,7 @@ include "mlir/IR/OpBase.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/FunctionInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/Dialect/Quant/QuantOpsBase.td" +include "mlir/Dialect/Quant/IR/QuantBase.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td" @@ -1135,7 +1135,7 @@ in the batch dimensions and broadcasting. `inputs[0]`: required: input LHS `inputs[1]`: required: input RHS `adjoint_lhs`: optional: Transpose LHS (default false) - `adjoint_lhs`: optional: Transpose LHS (default false) + `adjoint_rhs`: optional: Transpose RHS (default false) }]; let arguments = (ins diff --git a/tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape.h b/tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape.h index 1c3f6ce789dcc4..3a602ba9c51687 100644 --- a/tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape.h +++ b/tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape.h @@ -20,6 +20,7 @@ limitations under the License. // LINT.IfChange #include +#include #include #include #include diff --git a/tensorflow/compiler/mlir/lite/metrics/BUILD b/tensorflow/compiler/mlir/lite/metrics/BUILD index 2deb787ee627d2..3954cbd1fd1179 100644 --- a/tensorflow/compiler/mlir/lite/metrics/BUILD +++ b/tensorflow/compiler/mlir/lite/metrics/BUILD @@ -11,7 +11,6 @@ package( default_visibility = [ "//tensorflow:__pkg__", "//tensorflow/compiler/mlir/lite:__subpackages__", - "//tensorflow/lite/python/converter:__subpackages__", "//tensorflow/lite/toco/python:__subpackages__", ], licenses = ["notice"], @@ -93,7 +92,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "converter_error_data_proto_py", -# api_version = 2, # visibility = [ # "//visibility:public", # ], diff --git a/tensorflow/compiler/mlir/lite/python/BUILD b/tensorflow/compiler/mlir/lite/python/BUILD index edddef0e7e992f..9aeef35691b4f8 100644 --- a/tensorflow/compiler/mlir/lite/python/BUILD +++ b/tensorflow/compiler/mlir/lite/python/BUILD @@ -1,3 +1,4 @@ +load("@local_tsl//tsl/platform:build_config_root.bzl", "if_pywrap") load("//tensorflow:strict.default.bzl", "py_strict_library") load("//tensorflow:tensorflow.default.bzl", "tf_python_pybind_extension") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") @@ -123,9 +124,9 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", - "@local_xla//xla/service:hlo_parser", + "@local_xla//xla/hlo/parser:hlo_parser", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", "@local_xla//xla/service:hlo_proto_cc", - "@local_xla//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", ], ) @@ -164,9 +165,7 @@ py_strict_library( config_setting( name = "tflite_convert_with_select_tf_ops", define_values = {"tflite_convert_with_select_tf_ops": "true"}, - visibility = [ - "//tensorflow/lite:__subpackages__", - ], + visibility = ["//visibility:private"], ) filegroup( @@ -174,9 +173,7 @@ filegroup( srcs = [ "converter_python_api.h", ], - visibility = [ - "//tensorflow/python:__subpackages__", - ], + visibility = ["//visibility:private"], ) cc_library( @@ -239,10 +236,13 @@ tf_python_pybind_extension( pytype_srcs = [ "_pywrap_converter_api.pyi", ], + visibility = [ + "//tensorflow/python:__pkg__", + ], deps = [ "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", "//tensorflow/python/lib/core:pybind11_lib", "//third_party/python_runtime:headers", "@pybind11", - ], + ] + if_pywrap([":converter_python_api"]), ) diff --git a/tensorflow/compiler/mlir/lite/python/_pywrap_converter_api.pyi b/tensorflow/compiler/mlir/lite/python/_pywrap_converter_api.pyi index cdb1e881b7dc9f..989d4f1dbe56fb 100644 --- a/tensorflow/compiler/mlir/lite/python/_pywrap_converter_api.pyi +++ b/tensorflow/compiler/mlir/lite/python/_pywrap_converter_api.pyi @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -def Convert(model_flags_proto_txt_raw: object, toco_flags_proto_txt_raw: object, input_contents_txt_raw: object, extended_return: bool = ..., debug_info_txt_raw: object = ..., enable_mlir_converter: bool = ..., quantization_py_function_library = ...) -> object: ... +def Convert(model_flags_proto_txt_raw: object, converter_flags_proto_txt_raw: object, input_contents_txt_raw: object, extended_return: bool = ..., debug_info_txt_raw: object = ..., quantization_py_function_library = ...) -> object: ... def ExperimentalMlirQuantizeModel(input_contents_txt_raw: object, disable_per_channel: bool = ..., fully_quantize: bool = ..., inference_type: int = ..., input_data_type: int = ..., output_data_type: int = ..., enable_numeric_verify: bool = ..., enable_whole_model_verify: bool = ..., op_blocklist: object = ..., node_blocklist: object = ..., enable_variable_quantization: bool = ..., disable_per_channel_for_dense_layers: bool = ..., debug_options_proto_txt_raw: object = ...) -> object: ... def ExperimentalMlirSparsifyModel(input_contents_txt_raw: object) -> object: ... def FlatBufferToMlir(arg0: str, arg1: bool) -> str: ... diff --git a/tensorflow/compiler/mlir/lite/python/converter_python_api.cc b/tensorflow/compiler/mlir/lite/python/converter_python_api.cc index 1b0250fd535e9a..c7059d721a062f 100644 --- a/tensorflow/compiler/mlir/lite/python/converter_python_api.cc +++ b/tensorflow/compiler/mlir/lite/python/converter_python_api.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -61,7 +60,7 @@ namespace tflite { PyObject* Convert(PyObject* model_flags_proto_txt_raw, PyObject* converter_flags_proto_txt_raw, PyObject* input_contents_txt_raw, bool extended_return, - PyObject* debug_info_txt_raw, bool enable_mlir_converter, + PyObject* debug_info_txt_raw, const tensorflow::quantization::PyFunctionLibrary* quantization_py_function_library) { // Use Python C API to validate and convert arguments. In py3 (bytes), diff --git a/tensorflow/compiler/mlir/lite/python/converter_python_api.h b/tensorflow/compiler/mlir/lite/python/converter_python_api.h index 6dbcf0603d7e8c..cfcba696d01b7a 100644 --- a/tensorflow/compiler/mlir/lite/python/converter_python_api.h +++ b/tensorflow/compiler/mlir/lite/python/converter_python_api.h @@ -30,15 +30,12 @@ namespace tflite { // representing the contents of the converted model. When extended_return // flag is set to true returns a dictionary that contains string representation // of the converted model and some statistics like arithmetic ops count. -// `debug_info_str` contains the `GraphDebugInfo` proto. When -// `enable_mlir_converter` is True, use MLIR-based conversion instead of -// TOCO conversion. +// `debug_info_str` contains the `GraphDebugInfo` proto. PyObject* Convert(PyObject* model_flags_proto_txt_raw, - PyObject* toco_flags_proto_txt_raw, + PyObject* converter_flags_proto_txt_raw, PyObject* input_contents_txt_raw, bool extended_return = false, PyObject* debug_info_txt_raw = nullptr, - bool enable_mlir_converter = false, const tensorflow::quantization::PyFunctionLibrary* quantization_py_function_library = nullptr); diff --git a/tensorflow/compiler/mlir/lite/python/converter_python_api_wrapper.cc b/tensorflow/compiler/mlir/lite/python/converter_python_api_wrapper.cc index de46a6f9115339..83e3da9e540bcf 100644 --- a/tensorflow/compiler/mlir/lite/python/converter_python_api_wrapper.cc +++ b/tensorflow/compiler/mlir/lite/python/converter_python_api_wrapper.cc @@ -27,21 +27,21 @@ PYBIND11_MODULE(_pywrap_converter_api, m) { m.def( "Convert", [](py::object model_flags_proto_txt_raw, - py::object toco_flags_proto_txt_raw, py::object input_contents_txt_raw, - bool extended_return, py::object debug_info_txt_raw, - bool enable_mlir_converter, + py::object converter_flags_proto_txt_raw, + py::object input_contents_txt_raw, bool extended_return, + py::object debug_info_txt_raw, const tensorflow::quantization::PyFunctionLibrary* quantization_py_function_library) { return tensorflow::PyoOrThrow(tflite::Convert( - model_flags_proto_txt_raw.ptr(), toco_flags_proto_txt_raw.ptr(), - input_contents_txt_raw.ptr(), extended_return, - debug_info_txt_raw.ptr(), enable_mlir_converter, + model_flags_proto_txt_raw.ptr(), + converter_flags_proto_txt_raw.ptr(), input_contents_txt_raw.ptr(), + extended_return, debug_info_txt_raw.ptr(), quantization_py_function_library)); }, - py::arg("model_flags_proto_txt_raw"), py::arg("toco_flags_proto_txt_raw"), + py::arg("model_flags_proto_txt_raw"), + py::arg("converter_flags_proto_txt_raw"), py::arg("input_contents_txt_raw"), py::arg("extended_return") = false, py::arg("debug_info_txt_raw") = py::none(), - py::arg("enable_mlir_converter") = false, py::arg("quantization_py_function_library") = py::none(), R"pbdoc( Convert a model represented in `input_contents`. `model_flags_proto` @@ -50,9 +50,7 @@ PYBIND11_MODULE(_pywrap_converter_api, m) { representing the contents of the converted model. When extended_return flag is set to true returns a dictionary that contains string representation of the converted model and some statistics like arithmetic ops count. - `debug_info_str` contains the `GraphDebugInfo` proto. When - `enable_mlir_converter` is True, tuse MLIR-based conversion instead of - TOCO conversion. + `debug_info_str` contains the `GraphDebugInfo` proto. )pbdoc"); m.def( "ExperimentalMlirQuantizeModel", diff --git a/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc index b9558bad138bd5..acedcdf29134fb 100644 --- a/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc @@ -40,9 +40,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/types.pb.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_parser.h" -#include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/mlir/lite/python/wrap_converter.py b/tensorflow/compiler/mlir/lite/python/wrap_converter.py index 1c198f062388fc..ee3c5f2435fd9d 100644 --- a/tensorflow/compiler/mlir/lite/python/wrap_converter.py +++ b/tensorflow/compiler/mlir/lite/python/wrap_converter.py @@ -22,19 +22,17 @@ def wrapped_convert( model_flags_str, - toco_flags_str, + converter_flags_str, input_data_str, debug_info_str, - enable_mlir_converter, ): - """Wraps TocoConvert with lazy loader.""" + """Wraps Convert with lazy loader.""" return _pywrap_converter_api.Convert( model_flags_str, - toco_flags_str, + converter_flags_str, input_data_str, False, # extended_return debug_info_str, - enable_mlir_converter, py_function_lib.PyFunctionLibrary(), ) diff --git a/tensorflow/compiler/mlir/lite/quantization/device_target.cc b/tensorflow/compiler/mlir/lite/quantization/device_target.cc index bc3456ed3c650a..92880736b3fcb6 100644 --- a/tensorflow/compiler/mlir/lite/quantization/device_target.cc +++ b/tensorflow/compiler/mlir/lite/quantization/device_target.cc @@ -23,7 +23,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/quantization/device_target.h b/tensorflow/compiler/mlir/lite/quantization/device_target.h index c974a758d23291..6c475fc23499f7 100644 --- a/tensorflow/compiler/mlir/lite/quantization/device_target.h +++ b/tensorflow/compiler/mlir/lite/quantization/device_target.h @@ -29,7 +29,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc index 8ba8ef0fcad20a..8eb6087c7825c4 100644 --- a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc +++ b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc @@ -22,7 +22,7 @@ limitations under the License. #include "llvm/Support/CommandLine.h" #include "llvm/Support/Regex.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project @@ -81,7 +81,7 @@ class ImportQuantStatsPass void runOnOperation() override; void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); } diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/ConvertConst.cc b/tensorflow/compiler/mlir/lite/quantization/ir/ConvertConst.cc index c8b658e9e2a58b..79773263d38c8c 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/ConvertConst.cc +++ b/tensorflow/compiler/mlir/lite/quantization/ir/ConvertConst.cc @@ -15,7 +15,7 @@ limitations under the License. #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/ConvertSimQuant.cc b/tensorflow/compiler/mlir/lite/quantization/ir/ConvertSimQuant.cc index a51956ce08a239..d38b9e39423c8a 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/ConvertSimQuant.cc +++ b/tensorflow/compiler/mlir/lite/quantization/ir/ConvertSimQuant.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc b/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc index bf11fa66272cdd..f60fa016c0bc30 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc +++ b/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "llvm/ADT/STLExtras.h" -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.td b/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.td index 4d18c866c35564..39bd4b711c448e 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.td +++ b/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.td @@ -23,10 +23,41 @@ limitations under the License. #define QUANT_FORK_OPS include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOpsBase.td" -include "mlir/Dialect/Quant/QuantOpsBase.td" +include "mlir/Dialect/Quant/IR/QuantBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" +class quant_TypedPrimitiveOrContainer : + Type.predicate, + VectorOf<[etype]>.predicate]>, + "primitive/tensor/vector of " # etype.summary>; + +// A primitive type that can represent a real value. This is either a +// floating point value or a quantized type. +def quant_RealPrimitiveType : + Type, + "real valued primitive (float or quantized type)">; + +// A primitive type that can represent a storage value. This is either an +// integer or quantized type. +def quant_StoragePrimitiveType : + Type, + "quantized storage primitive (integer or quantized type)">; + +// A primitive or container of RealPrimitiveType. +def quant_RealValueType : + quant_TypedPrimitiveOrContainer; + +// A primitive or container of StoragePrimitiveType. +def quant_StorageValueType : + quant_TypedPrimitiveOrContainer; + +// Either a real valued or storage primitive or container type. +def quant_RealOrStorageValueType : + Type, + "real valued or storage primitive or container type">; + //===----------------------------------------------------------------------===// // Base classes //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.cc b/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.cc index 88dad98d3dd3af..2d79db85fadc35 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.cc +++ b/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.h" -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD index 51c130c31b22ba..bcb756f7108854 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD @@ -73,6 +73,7 @@ tf_cc_test( "//tensorflow/compiler/mlir/lite/schema:schema_utils", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", + "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:platform_port", "@local_xla//xla/tsl/lib/core:status_test_util", diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils_test.cc index db7fd7fd052552..6601ce5dc9fa6c 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils_test.cc @@ -27,8 +27,10 @@ limitations under the License. #include #include +#include #include #include "absl/status/status.h" +#include "Eigen/Core" // from @eigen_archive #include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h" #include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc index ecba20595c0f91..8682cba5cdc5a9 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc @@ -23,7 +23,7 @@ limitations under the License. #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_context.h b/tensorflow/compiler/mlir/lite/quantization/quantization_context.h index 8a6bf4f83a28c5..405002b1bcef5b 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_context.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_context.h @@ -21,7 +21,7 @@ limitations under the License. #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_info.proto b/tensorflow/compiler/mlir/lite/quantization/quantization_info.proto index d06edca64556e0..f8e52abc4b8de6 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_info.proto +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_info.proto @@ -2,8 +2,6 @@ syntax = "proto3"; package mlir.quant; -option cc_enable_arenas = true; - // Represents the quantization parameters for a list of named tensors. message QuantizationInfo { // min/max of the per axis value range. To quantize the value, the metadata @@ -34,13 +32,16 @@ message QuantizationInfo { // The metadata defines the target properties. message Metadata { - // Bit number of fixed-point data the target kernel supports. + // Bit number of fixed-point data the target kernel supports. int32 num_bits = 1; - // The quantized axis index if it is per-axis quantization. + + // The quantized axis index if it is per-axis quantization. int32 quantize_axis = 2; + // The minimum allowed value of the fixed-point data range. // This can also be used to derive the sign of storage type. int32 range_min = 3; + // The minimum allowed value of the fixed-point data range. int32 range_max = 4; } diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc index b4015181886788..12856137123f63 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/DialectRegistry.h" // from @llvm-project @@ -55,7 +55,7 @@ struct LegalizeTFToQuant void runOnOperation() override; void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); } diff --git a/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc b/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc index 47440b4c4c0beb..874118ae4f93d0 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc @@ -41,7 +41,7 @@ using mlir::tblgen::Operator; // The function below has a non-constant reference as that is required by LLVM's // TableGenMain. // NOLINTNEXTLINE -static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) { +static bool OpQuantSpecWriter(raw_ostream &os, const RecordKeeper &records) { llvm::Regex acc_uniform_trait_regex{"AccumulatorUniformScale<([0-9]*),"}; llvm::Regex coeff_index_trait_regex{"AffineOpCoefficient<(-?[0-9]*),"}; llvm::Regex fixed_uniform_trait_regex{ @@ -50,7 +50,7 @@ static bool OpQuantSpecWriter(raw_ostream &os, RecordKeeper &records) { // Retrieve all the definitions derived from Op definition and sort by record // name. - std::vector defs = records.getAllDerivedDefinitions("Op"); + std::vector defs = records.getAllDerivedDefinitions("Op"); llvm::sort(defs, LessRecord()); OUT(0) << "static std::unique_ptr " diff --git a/tensorflow/compiler/mlir/lite/quantization/tools/tflite_op_coverage_spec_getters_gen.cc b/tensorflow/compiler/mlir/lite/quantization/tools/tflite_op_coverage_spec_getters_gen.cc index 98d5e7fc56215e..c92d43da951af1 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tools/tflite_op_coverage_spec_getters_gen.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tools/tflite_op_coverage_spec_getters_gen.cc @@ -20,12 +20,14 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_replace.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/InitLLVM.h" #include "llvm/TableGen/Main.h" #include "llvm/TableGen/Record.h" #include "mlir/TableGen/Operator.h" // from @llvm-project -#include "tsl/platform/logging.h" #include "tsl/platform/regexp.h" using llvm::LessRecord; @@ -54,7 +56,8 @@ const std::map &GetTypeToStringRepresentation() { return *entries; } -void EmitDynamicRangeOp(std::vector &defs, raw_ostream *ostream) { +void EmitDynamicRangeOp(std::vector &defs, + raw_ostream *ostream) { std::string dynamic_quant_kernel_support_regex = "bool GetDynamicRangeQuantKernelSupport() { return true; }"; raw_ostream &os = *ostream; @@ -104,7 +107,7 @@ void EmitDynamicRangeOp(std::vector &defs, raw_ostream *ostream) { os.indent(0) << "}\n"; } -void EmitSparseOp(std::vector &defs, raw_ostream *ostream) { +void EmitSparseOp(std::vector &defs, raw_ostream *ostream) { raw_ostream &os = *ostream; llvm::sort(defs, LessRecord()); @@ -126,7 +129,7 @@ void EmitSparseOp(std::vector &defs, raw_ostream *ostream) { os.indent(0) << "}\n"; } -bool CheckTypeConstraints(llvm::Init *input_value, +bool CheckTypeConstraints(const llvm::Init *input_value, std::list required_types, bool per_axis) { auto *def_init = llvm::cast(input_value); @@ -144,7 +147,7 @@ bool CheckTypeConstraints(llvm::Init *input_value, return true; } -void GenerateStaticQuantOp(std::vector &defs, +void GenerateStaticQuantOp(std::vector &defs, std::vector &result, InputDataType act_type, const bool per_axis, const bool is_toco) { @@ -180,7 +183,7 @@ void GenerateStaticQuantOp(std::vector &defs, Operator op(def); if (!op.getTrait("::mlir::OpTrait::quant::QuantizableResult")) continue; - llvm::DagInit *args_in_dag = def->getValueAsDag("arguments"); + const llvm::DagInit *args_in_dag = def->getValueAsDag("arguments"); // Assumes argument name is "input" for input activations. Otherwise, assume // the first argument is the input activation. int input_idx = 0; @@ -224,7 +227,7 @@ void GenerateStaticQuantOp(std::vector &defs, } } -void EmitStaticInt8PerAxisQuantOp(std::vector &defs, +void EmitStaticInt8PerAxisQuantOp(std::vector &defs, raw_ostream &os) { os.indent(0) << "const std::set &ExportStaticInt8PerAxisSpec() {\n"; @@ -244,7 +247,7 @@ void EmitStaticInt8PerAxisQuantOp(std::vector &defs, os.indent(0) << "}\n"; } -void EmitStaticInt8PerTensorQuantOp(std::vector &defs, +void EmitStaticInt8PerTensorQuantOp(std::vector &defs, raw_ostream &os) { os.indent(0) << "const std::set &ExportStaticInt8PerTensorSpec() {\n"; @@ -264,7 +267,7 @@ void EmitStaticInt8PerTensorQuantOp(std::vector &defs, os.indent(0) << "}\n"; } -void EmitStaticUInt8PerAxisQuantOp(std::vector &defs, +void EmitStaticUInt8PerAxisQuantOp(std::vector &defs, raw_ostream &os) { os.indent(0) << "const std::set &ExportStaticUInt8PerAxisSpec() {\n"; @@ -284,7 +287,7 @@ void EmitStaticUInt8PerAxisQuantOp(std::vector &defs, os.indent(0) << "}\n"; } -void EmitStaticUInt8PerTensorQuantOp(std::vector &defs, +void EmitStaticUInt8PerTensorQuantOp(std::vector &defs, raw_ostream &os) { os.indent(0) << "const std::set &ExportStaticUInt8PerTensorSpec() {\n"; @@ -304,7 +307,8 @@ void EmitStaticUInt8PerTensorQuantOp(std::vector &defs, os.indent(0) << "}\n"; } -void EmitStaticQuantOp(std::vector &defs, raw_ostream *ostream) { +void EmitStaticQuantOp(std::vector &defs, + raw_ostream *ostream) { raw_ostream &os = *ostream; llvm::sort(defs, LessRecord()); @@ -314,7 +318,7 @@ void EmitStaticQuantOp(std::vector &defs, raw_ostream *ostream) { EmitStaticUInt8PerTensorQuantOp(defs, os); } -void EmitStaticQuantWithInt16ActOp(std::vector &defs, +void EmitStaticQuantWithInt16ActOp(std::vector &defs, raw_ostream *ostream) { raw_ostream &os = *ostream; llvm::sort(defs, LessRecord()); @@ -337,7 +341,7 @@ void EmitStaticQuantWithInt16ActOp(std::vector &defs, os.indent(0) << "}\n"; } -void EmitStaticQuantWithInt16ActTocoOp(std::vector &defs, +void EmitStaticQuantWithInt16ActTocoOp(std::vector &defs, raw_ostream *ostream) { raw_ostream &os = *ostream; llvm::sort(defs, LessRecord()); @@ -361,8 +365,9 @@ void EmitStaticQuantWithInt16ActTocoOp(std::vector &defs, } static bool TFLiteOpCoverageSpecWritersMain(raw_ostream &os, - RecordKeeper &records) { - std::vector op_defs = records.getAllDerivedDefinitions("TFL_Op"); + const RecordKeeper &records) { + std::vector op_defs = + records.getAllDerivedDefinitions("TFL_Op"); EmitStaticQuantOp(op_defs, &os); EmitDynamicRangeOp(op_defs, &os); EmitStaticQuantWithInt16ActOp(op_defs, &os); diff --git a/tensorflow/compiler/mlir/lite/schema/schema_generated.h b/tensorflow/compiler/mlir/lite/schema/schema_generated.h index 5c81bbeb36b806..36e715a328173a 100755 --- a/tensorflow/compiler/mlir/lite/schema/schema_generated.h +++ b/tensorflow/compiler/mlir/lite/schema/schema_generated.h @@ -14835,6 +14835,7 @@ struct OperatorT : public ::flatbuffers::NativeTable { uint64_t large_custom_options_offset = 0; uint64_t large_custom_options_size = 0; tflite::BuiltinOptions2Union builtin_options_2{}; + int32_t debug_metadata_index = -1; }; struct Operator FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -14853,7 +14854,8 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_LARGE_CUSTOM_OPTIONS_OFFSET = 22, VT_LARGE_CUSTOM_OPTIONS_SIZE = 24, VT_BUILTIN_OPTIONS_2_TYPE = 26, - VT_BUILTIN_OPTIONS_2 = 28 + VT_BUILTIN_OPTIONS_2 = 28, + VT_DEBUG_METADATA_INDEX = 30 }; uint32_t opcode_index() const { return GetField(VT_OPCODE_INDEX, 0); @@ -15340,6 +15342,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { const tflite::StablehloShiftLeftOptions *builtin_options_2_as_StablehloShiftLeftOptions() const { return builtin_options_2_type() == tflite::BuiltinOptions2_StablehloShiftLeftOptions ? static_cast(builtin_options_2()) : nullptr; } + int32_t debug_metadata_index() const { + return GetField(VT_DEBUG_METADATA_INDEX, -1); + } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_OPCODE_INDEX, 4) && @@ -15362,6 +15367,7 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VerifyField(verifier, VT_BUILTIN_OPTIONS_2_TYPE, 1) && VerifyOffset(verifier, VT_BUILTIN_OPTIONS_2) && VerifyBuiltinOptions2(verifier, builtin_options_2(), builtin_options_2_type()) && + VerifyField(verifier, VT_DEBUG_METADATA_INDEX, 4) && verifier.EndTable(); } OperatorT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -16004,6 +16010,9 @@ struct OperatorBuilder { void add_builtin_options_2(::flatbuffers::Offset builtin_options_2) { fbb_.AddOffset(Operator::VT_BUILTIN_OPTIONS_2, builtin_options_2); } + void add_debug_metadata_index(int32_t debug_metadata_index) { + fbb_.AddElement(Operator::VT_DEBUG_METADATA_INDEX, debug_metadata_index, -1); + } explicit OperatorBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -16029,10 +16038,12 @@ inline ::flatbuffers::Offset CreateOperator( uint64_t large_custom_options_offset = 0, uint64_t large_custom_options_size = 0, tflite::BuiltinOptions2 builtin_options_2_type = tflite::BuiltinOptions2_NONE, - ::flatbuffers::Offset builtin_options_2 = 0) { + ::flatbuffers::Offset builtin_options_2 = 0, + int32_t debug_metadata_index = -1) { OperatorBuilder builder_(_fbb); builder_.add_large_custom_options_size(large_custom_options_size); builder_.add_large_custom_options_offset(large_custom_options_offset); + builder_.add_debug_metadata_index(debug_metadata_index); builder_.add_builtin_options_2(builtin_options_2); builder_.add_intermediates(intermediates); builder_.add_mutating_variable_inputs(mutating_variable_inputs); @@ -16061,7 +16072,8 @@ inline ::flatbuffers::Offset CreateOperatorDirect( uint64_t large_custom_options_offset = 0, uint64_t large_custom_options_size = 0, tflite::BuiltinOptions2 builtin_options_2_type = tflite::BuiltinOptions2_NONE, - ::flatbuffers::Offset builtin_options_2 = 0) { + ::flatbuffers::Offset builtin_options_2 = 0, + int32_t debug_metadata_index = -1) { auto inputs__ = inputs ? _fbb.CreateVector(*inputs) : 0; auto outputs__ = outputs ? _fbb.CreateVector(*outputs) : 0; auto custom_options__ = custom_options ? _fbb.CreateVector(*custom_options) : 0; @@ -16081,7 +16093,8 @@ inline ::flatbuffers::Offset CreateOperatorDirect( large_custom_options_offset, large_custom_options_size, builtin_options_2_type, - builtin_options_2); + builtin_options_2, + debug_metadata_index); } ::flatbuffers::Offset CreateOperator(::flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -16093,6 +16106,7 @@ struct SubGraphT : public ::flatbuffers::NativeTable { std::vector outputs{}; std::vector> operators{}; std::string name{}; + int32_t debug_metadata_index = -1; SubGraphT() = default; SubGraphT(const SubGraphT &o); SubGraphT(SubGraphT&&) FLATBUFFERS_NOEXCEPT = default; @@ -16107,7 +16121,8 @@ struct SubGraph FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { VT_INPUTS = 6, VT_OUTPUTS = 8, VT_OPERATORS = 10, - VT_NAME = 12 + VT_NAME = 12, + VT_DEBUG_METADATA_INDEX = 14 }; const ::flatbuffers::Vector<::flatbuffers::Offset> *tensors() const { return GetPointer> *>(VT_TENSORS); @@ -16124,6 +16139,9 @@ struct SubGraph FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { const ::flatbuffers::String *name() const { return GetPointer(VT_NAME); } + int32_t debug_metadata_index() const { + return GetField(VT_DEBUG_METADATA_INDEX, -1); + } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_TENSORS) && @@ -16138,6 +16156,7 @@ struct SubGraph FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { verifier.VerifyVectorOfTables(operators()) && VerifyOffset(verifier, VT_NAME) && verifier.VerifyString(name()) && + VerifyField(verifier, VT_DEBUG_METADATA_INDEX, 4) && verifier.EndTable(); } SubGraphT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -16164,6 +16183,9 @@ struct SubGraphBuilder { void add_name(::flatbuffers::Offset<::flatbuffers::String> name) { fbb_.AddOffset(SubGraph::VT_NAME, name); } + void add_debug_metadata_index(int32_t debug_metadata_index) { + fbb_.AddElement(SubGraph::VT_DEBUG_METADATA_INDEX, debug_metadata_index, -1); + } explicit SubGraphBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -16181,8 +16203,10 @@ inline ::flatbuffers::Offset CreateSubGraph( ::flatbuffers::Offset<::flatbuffers::Vector> inputs = 0, ::flatbuffers::Offset<::flatbuffers::Vector> outputs = 0, ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset>> operators = 0, - ::flatbuffers::Offset<::flatbuffers::String> name = 0) { + ::flatbuffers::Offset<::flatbuffers::String> name = 0, + int32_t debug_metadata_index = -1) { SubGraphBuilder builder_(_fbb); + builder_.add_debug_metadata_index(debug_metadata_index); builder_.add_name(name); builder_.add_operators(operators); builder_.add_outputs(outputs); @@ -16197,7 +16221,8 @@ inline ::flatbuffers::Offset CreateSubGraphDirect( const std::vector *inputs = nullptr, const std::vector *outputs = nullptr, const std::vector<::flatbuffers::Offset> *operators = nullptr, - const char *name = nullptr) { + const char *name = nullptr, + int32_t debug_metadata_index = -1) { auto tensors__ = tensors ? _fbb.CreateVector<::flatbuffers::Offset>(*tensors) : 0; auto inputs__ = inputs ? _fbb.CreateVector(*inputs) : 0; auto outputs__ = outputs ? _fbb.CreateVector(*outputs) : 0; @@ -16209,7 +16234,8 @@ inline ::flatbuffers::Offset CreateSubGraphDirect( inputs__, outputs__, operators__, - name__); + name__, + debug_metadata_index); } ::flatbuffers::Offset CreateSubGraph(::flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -21216,6 +21242,7 @@ inline void Operator::UnPackTo(OperatorT *_o, const ::flatbuffers::resolver_func { auto _e = large_custom_options_size(); _o->large_custom_options_size = _e; } { auto _e = builtin_options_2_type(); _o->builtin_options_2.type = _e; } { auto _e = builtin_options_2(); if (_e) _o->builtin_options_2.value = tflite::BuiltinOptions2Union::UnPack(_e, builtin_options_2_type(), _resolver); } + { auto _e = debug_metadata_index(); _o->debug_metadata_index = _e; } } inline ::flatbuffers::Offset Operator::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const OperatorT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { @@ -21239,6 +21266,7 @@ inline ::flatbuffers::Offset CreateOperator(::flatbuffers::FlatBufferB auto _large_custom_options_size = _o->large_custom_options_size; auto _builtin_options_2_type = _o->builtin_options_2.type; auto _builtin_options_2 = _o->builtin_options_2.Pack(_fbb); + auto _debug_metadata_index = _o->debug_metadata_index; return tflite::CreateOperator( _fbb, _opcode_index, @@ -21253,13 +21281,15 @@ inline ::flatbuffers::Offset CreateOperator(::flatbuffers::FlatBufferB _large_custom_options_offset, _large_custom_options_size, _builtin_options_2_type, - _builtin_options_2); + _builtin_options_2, + _debug_metadata_index); } inline SubGraphT::SubGraphT(const SubGraphT &o) : inputs(o.inputs), outputs(o.outputs), - name(o.name) { + name(o.name), + debug_metadata_index(o.debug_metadata_index) { tensors.reserve(o.tensors.size()); for (const auto &tensors_ : o.tensors) { tensors.emplace_back((tensors_) ? new tflite::TensorT(*tensors_) : nullptr); } operators.reserve(o.operators.size()); @@ -21272,6 +21302,7 @@ inline SubGraphT &SubGraphT::operator=(SubGraphT o) FLATBUFFERS_NOEXCEPT { std::swap(outputs, o.outputs); std::swap(operators, o.operators); std::swap(name, o.name); + std::swap(debug_metadata_index, o.debug_metadata_index); return *this; } @@ -21289,6 +21320,7 @@ inline void SubGraph::UnPackTo(SubGraphT *_o, const ::flatbuffers::resolver_func { auto _e = outputs(); if (_e) { _o->outputs.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->outputs[_i] = _e->Get(_i); } } else { _o->outputs.resize(0); } } { auto _e = operators(); if (_e) { _o->operators.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { if(_o->operators[_i]) { _e->Get(_i)->UnPackTo(_o->operators[_i].get(), _resolver); } else { _o->operators[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); }; } } else { _o->operators.resize(0); } } { auto _e = name(); if (_e) _o->name = _e->str(); } + { auto _e = debug_metadata_index(); _o->debug_metadata_index = _e; } } inline ::flatbuffers::Offset SubGraph::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { @@ -21304,13 +21336,15 @@ inline ::flatbuffers::Offset CreateSubGraph(::flatbuffers::FlatBufferB auto _outputs = _o->outputs.size() ? _fbb.CreateVector(_o->outputs) : 0; auto _operators = _o->operators.size() ? _fbb.CreateVector<::flatbuffers::Offset> (_o->operators.size(), [](size_t i, _VectorArgs *__va) { return CreateOperator(*__va->__fbb, __va->__o->operators[i].get(), __va->__rehasher); }, &_va ) : 0; auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name); + auto _debug_metadata_index = _o->debug_metadata_index; return tflite::CreateSubGraph( _fbb, _tensors, _inputs, _outputs, _operators, - _name); + _name, + _debug_metadata_index); } inline BufferT *Buffer::UnPack(const ::flatbuffers::resolver_function_t *_resolver) const { diff --git a/tensorflow/compiler/mlir/lite/sparsity/BUILD b/tensorflow/compiler/mlir/lite/sparsity/BUILD index 79a355dcb73e3d..3b9825f72bde45 100644 --- a/tensorflow/compiler/mlir/lite/sparsity/BUILD +++ b/tensorflow/compiler/mlir/lite/sparsity/BUILD @@ -29,6 +29,7 @@ cc_library( ], deps = [ "//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib", + "//tensorflow/compiler/mlir/lite:pass_registry_utils", "//tensorflow/compiler/mlir/lite:tensorflow_lite_d2s", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/tools/optimize:reduced_precision_metadata", diff --git a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc index e180d3d46a9d8b..76a64df1f6e7bd 100644 --- a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc +++ b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc @@ -36,7 +36,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/tools/optimize/reduced_precision_metadata.h" -#include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/transforms/dense_to_sparse_pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass_registry_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/types.pb.h" @@ -67,7 +68,7 @@ absl::Status SparsifyModel(const tflite::ModelT& input_model, } PassManager pm((*module)->getName(), OpPassManager::Nesting::Implicit); - pm.addPass(TFL::CreateDenseToSparsePass()); + pm.addPass(TFL::Create()); if (failed(pm.run(module.get()))) { LOG(ERROR) << "Failed to sparsify: " diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index a0c3febeead92f..6502be87c8728b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -1,5 +1,6 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("@local_tsl//tsl/platform:build_config_root.bzl", "if_static") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") @@ -119,6 +120,92 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "legalize_utils", + srcs = ["transforms/utils.cc"], + hdrs = ["transforms/utils.h"], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@local_xla//xla/mlir_hlo", + ], +) + +tf_cc_test( + name = "legalize_utils_test", + srcs = ["transforms/utils_test.cc"], + deps = [ + ":legalize_utils", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@local_xla//xla/mlir_hlo", + ], +) + +gentbl_cc_library( + name = "legalize_tf_patterns_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = [ + ( + ["-gen-rewriters"], + "transforms/generated_legalize_tf.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "transforms/legalize_tf_patterns.td", + deps = [ + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncTdFiles", + "@llvm-project//mlir:TensorOpsTdFiles", + "@local_xla//xla/mlir_hlo:hlo_ops_td_files", + ], +) + +cc_library( + name = "legalize_tf", + srcs = [ + "transforms/generated_legalize_tf.inc", + "transforms/legalize_tf.cc", + ], + hdrs = [ + "transforms/legalize_tf_passes.h", + ], + deps = [ + ":legalize_tf_patterns_inc_gen", + ":legalize_utils", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", + "//tensorflow/compiler/mlir/tensorflow:xla_sharding_util", + "//tensorflow/core:framework", + "//tensorflow/core/kernels:conv_grad_shape_utils", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@local_tsl//tsl/platform:bfloat16", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", + "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla/client:padding", + "@local_xla//xla/client:sharding_builder", + "@local_xla//xla/client/lib:conv_grad_size_util", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:attribute_importer", + "@local_xla//xla/mlir_hlo", + "@local_xla//xla/mlir_hlo:convert_op_folder", + "@local_xla//xla/translate/hlo_to_mhlo:attribute_importer", + "@stablehlo//:chlo_ops", + ] + if_static(["@local_tsl//tsl/platform:tensor_float_32_utils"]), +) + cc_library( name = "tf_stablehlo", srcs = [ @@ -131,6 +218,7 @@ cc_library( "-Ithird_party", ], deps = [ + ":legalize_tf", ":stablehlo_util", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow/transforms:lower_tf_lib", @@ -270,18 +358,18 @@ gentbl_cc_library( "-gen-pass-decls", "-name=OdmlStablehlo", ], - "transforms/passes.h.inc", + "transforms/stablehlo_passes.h.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "transforms/passes.td", + td_file = "transforms/stablehlo_passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], ) cc_library( name = "compose_uniform_quantized_type_pass", srcs = ["transforms/compose_uniform_quantized_type_pass.cc"], - hdrs = ["transforms/passes.h"], + hdrs = ["transforms/stablehlo_passes.h"], copts = ["-Ithird_party"], deps = [ ":passes_inc_gen", @@ -306,7 +394,7 @@ cc_library( "transforms/unfuse_batch_norm_pass.cc", ], hdrs = [ - "transforms/passes.h", + "transforms/stablehlo_passes.h", ], copts = [ "-Ithird_party", @@ -334,7 +422,7 @@ cc_library( "transforms/fold_broadcast_pass.cc", ], hdrs = [ - "transforms/passes.h", + "transforms/stablehlo_passes.h", ], copts = [ "-Ithird_party", @@ -359,7 +447,7 @@ cc_library( "transforms/fold_broadcast_to_pass.cc", ], hdrs = [ - "transforms/passes.h", + "transforms/stablehlo_passes.h", ], copts = [ "-Ithird_party", @@ -385,7 +473,7 @@ cc_library( "transforms/fuse_convolution_pass.cc", ], hdrs = [ - "transforms/passes.h", + "transforms/stablehlo_passes.h", ], copts = [ "-Ithird_party", @@ -414,7 +502,7 @@ cc_library( "transforms/unfold_splat_constant_pass.cc", ], hdrs = [ - "transforms/passes.h", + "transforms/stablehlo_passes.h", ], copts = [ "-Ithird_party", @@ -469,8 +557,8 @@ cc_library( "transforms/legalize_stablehlo_composite_to_tfl_custom.cc", ], hdrs = [ - "transforms/passes.h", - "transforms/passes.h.inc", + "transforms/stablehlo_passes.h", + "transforms/stablehlo_passes.h.inc", ], copts = [ "-Ithird_party", @@ -495,8 +583,8 @@ cc_library( "transforms/legalize_stablehlo_to_vhlo.cc", ], hdrs = [ - "transforms/passes.h", - "transforms/passes.h.inc", + "transforms/stablehlo_passes.h", + "transforms/stablehlo_passes.h.inc", ], copts = [ "-Ithird_party", @@ -538,8 +626,8 @@ cc_library( "transforms/legalize_stablehlo_custom_call_to_composite.cc", ], hdrs = [ - "transforms/passes.h", - "transforms/passes.h.inc", + "transforms/stablehlo_passes.h", + "transforms/stablehlo_passes.h.inc", ], copts = [ "-Ithird_party", @@ -563,7 +651,7 @@ cc_library( "transforms/optimize.cc", ], hdrs = [ - "transforms/passes.h", + "transforms/stablehlo_passes.h", ], copts = [ "-Ithird_party", @@ -585,8 +673,8 @@ cc_library( name = "uniform_quantized_stablehlo_to_tfl_pass", srcs = ["transforms/uniform_quantized_stablehlo_to_tfl_pass.cc"], hdrs = [ - "transforms/passes.h", - "transforms/passes.h.inc", + "transforms/stablehlo_passes.h", + "transforms/stablehlo_passes.h.inc", ], copts = ["-Ithird_party"], deps = [ @@ -654,7 +742,7 @@ cc_library( "transforms/tflite_legalize_hlo.cc", ], hdrs = [ - "transforms/passes.h", + "transforms/stablehlo_passes.h", ], copts = [ "-Ithird_party", @@ -667,6 +755,7 @@ cc_library( "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:conv", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:custom_call", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:dot_general", + "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:fft", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:gather", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:get_dimension_size", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:if", @@ -715,7 +804,7 @@ cc_library( "transforms/prepare_hlo.cc", ], hdrs = [ - "transforms/passes.h", + "transforms/stablehlo_passes.h", ], copts = [ "-Ithird_party", @@ -745,7 +834,7 @@ cc_library( srcs = [ "transforms/tfl_legalize_chlo.cc", ], - hdrs = ["transforms/passes.h"], + hdrs = ["transforms/stablehlo_passes.h"], compatible_with = get_compatible_with_portable(), copts = [ "-Ithird_party", @@ -772,8 +861,8 @@ cc_library( "transforms/legalize_hlo.cc", ], hdrs = [ - "transforms/passes.h", - "transforms/passes.h.inc", + "transforms/stablehlo_passes.h", + "transforms/stablehlo_passes.h.inc", ], copts = [ "-Ithird_party", @@ -878,7 +967,7 @@ cc_library( ], hdrs = [ "transforms/composite_avg_pool.h", - "transforms/passes.h", + "transforms/stablehlo_passes.h", ], copts = [ "-Ithird_party", @@ -905,7 +994,7 @@ cc_library( srcs = [ "transforms/optimize_layout.cc", ], - hdrs = ["transforms/passes.h"], + hdrs = ["transforms/stablehlo_passes.h"], compatible_with = get_compatible_with_portable(), deps = [ ":passes_inc_gen", @@ -953,6 +1042,7 @@ tf_cc_binary( " [tf.lite.OpsSet.EXPERIMENTAL_STABLEHLO_OPS]", deps = [ ":check_accepted_ops_pass", + ":legalize_tf", ":op_stat_pass", ":stablehlo_util", ":transforms", @@ -969,7 +1059,7 @@ tf_cc_binary( "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_graph_optimization_pass", "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util", - "//tensorflow/compiler/mlir/tf2xla/transforms:legalize_tf", + "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf", "//tensorflow/core:core_cpu_base", "//tensorflow/core:lib", "//tensorflow/core/ir/types:Dialect", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/odml_converter_main.cc b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/odml_converter_main.cc index a510e640a7abd8..32d4626b7802d4 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/odml_converter_main.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/odml_converter_main.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/stablehlo/odml_converter/passes.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" const char* art = R"( diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc index c2579fb3619911..bdba7dc58a379f 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc @@ -37,7 +37,7 @@ limitations under the License. #include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project #include "mlir/IR/AsmState.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project @@ -56,6 +56,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/check_accepted_ops_pass.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h" @@ -279,7 +280,7 @@ tensorflow::Status RunConverter(const PassPipelineCLParser& pass_pipeline) { mhlo::registerAllMhloDialects(registry); stablehlo::registerAllDialects(registry); registry.insert(); + mlir::quant::QuantDialect>(); mlir::quant::RegisterOps(); MLIRContext context(registry); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tf.mlir new file mode 100644 index 00000000000000..bc2ce85d20f9f2 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-tf.mlir @@ -0,0 +1,2532 @@ +// RUN: odml-to-stablehlo-opt --tf-stablehlo \ +// RUN: %s | FILECHECK_OPTS="" FileCheck %s + +//===----------------------------------------------------------------------===// +// BatchNorm op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// fusedBatchNormV2 is almost identical to fusedBatchNormV3 (and uses the same +// code), so only do a couple of basic checks. + +// CHECK-LABEL: fusedBatchNormV2_noTraining +func.func @fusedBatchNormV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK: "stablehlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormV2_training +func.func @fusedBatchNormV2_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "stablehlo.batch_norm_training"({{.*}}, %arg1, %arg2) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + %0:5 = "tf.FusedBatchNormV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormV3_noTraining +func.func @fusedBatchNormV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK: "stablehlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<8x8x8x8xf32> + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormV3_noTraining_mixedPrecision +// CHECK-SAME: ([[X:%.*]]: tensor<8x8x8x8xbf16>, [[SCALE:%.*]]: tensor<8xf32>, [[OFFSET:%.*]]: tensor<8xf32>, [[MEAN:%.*]]: tensor<8xf32>, [[VARIANCE:%.*]]: tensor<8xf32>) +func.func @fusedBatchNormV3_noTraining_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) { + // CHECK: [[DUMMY:%.*]] = stablehlo.constant dense<0.000000e+00> : tensor<0xf32> + // CHECK: [[CONVERT_X:%.*]] = stablehlo.convert [[X]] : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> + // CHECK: [[Y:%.*]] = "stablehlo.batch_norm_inference"([[CONVERT_X]], [[SCALE]], [[OFFSET]], [[MEAN]], [[VARIANCE]]) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32>) + // CHECK: [[Y_CONVERT:%.*]] = stablehlo.convert [[Y]] : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + // CHECK: [[DUMMY_CAST:%.*]] = tensor.cast [[DUMMY]] : tensor<0xf32> to tensor<*xf32> + // CHECK: return [[Y_CONVERT]], [[MEAN]], [[VARIANCE]], [[MEAN]], [[VARIANCE]], [[DUMMY_CAST]] + func.return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<*xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormV3_training +func.func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "stablehlo.batch_norm_training"({{.*}}, %arg1, %arg2) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: func @fusedBatchNormV3_training_batchVariance +func.func @fusedBatchNormV3_training_batchVariance(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> tensor<8xf32> { + // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "stablehlo.batch_norm_training"({{.*}}, %arg1, %arg2) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK: return %[[VAR]] + func.return %0#4 : tensor<8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormV3_training_exponentialAvgFactor +func.func @fusedBatchNormV3_training_exponentialAvgFactor(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) { + // CHECK-DAG: %[[ALPHA:.*]] = stablehlo.constant dense<0.199999988> + // CHECK-DAG: %[[BETA:.*]] = stablehlo.constant dense<8.000000e-01> + // CHECK-DAG: %[[FACTOR:.*]] = stablehlo.constant dense<1.00195694> + // CHECK: %[[OUT:.*]], %[[MEAN:.*]], %[[VAR:.*]] = "stablehlo.batch_norm_training"({{.*}}, %arg1, %arg2) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 0.8 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK: %[[CORRECTED_VAR:.*]] = stablehlo.multiply %[[VAR]], %[[FACTOR]] + + // CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = stablehlo.multiply %arg3, %[[ALPHA]] + // CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = stablehlo.multiply %[[MEAN]], %[[BETA]] + // CHECK: %[[NEW_BATCH_MEAN:.*]] = stablehlo.add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]] + + // CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = stablehlo.multiply %arg4, %[[ALPHA]] + // CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = stablehlo.multiply %[[CORRECTED_VAR]], %[[BETA]] + // CHECK: %[[NEW_BATCH_VAR:.*]] = stablehlo.add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]] + + // CHECK: return %[[NEW_BATCH_MEAN]], %[[NEW_BATCH_VAR]], %[[MEAN]], %[[VAR]] + func.return %0#1, %0#2, %0#3, %0#4 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormV3_training_mixedPrecision +func.func @fusedBatchNormV3_training_mixedPrecision(%arg0: tensor<8x8x8x8xbf16>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { + // CHECK: stablehlo.convert %arg0 : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK: stablehlo.convert {{.*}} : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + func.return %0#0 : tensor<8x8x8x8xbf16> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormV3_NCHW +func.func @fusedBatchNormV3_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK: "stablehlo.batch_norm_training"({{.*}}, %arg1, %arg2) <{epsilon = 1.000000e-03 : f32, feature_index = 1 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormV3_NDHWC +func.func @fusedBatchNormV3_NDHWC(%arg0: tensor<8x8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8x8xf32>) { + // CHECK: feature_index = 4 : i64 + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NDHWC", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor<8x8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormV3_noTraining_dynamic_supported +func.func @fusedBatchNormV3_noTraining_dynamic_supported(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> (tensor) { + // CHECK: "stablehlo.batch_norm_inference"({{.*}}, %arg1, %arg2, %arg3, %arg4) <{epsilon = 1.000000e-03 : f32, feature_index = 1 : i64}> : (tensor, tensor, tensor, tensor, tensor) -> tensor + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = false} : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor, tensor, tensor) + func.return %0#0 : tensor +} + +// ----- + +// CHECK-LABEL: fusedBatchNormV3_training_dynamic_unsupported1 +func.func @fusedBatchNormV3_training_dynamic_unsupported1(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> (tensor) { + // CHECK: tf.FusedBatchNormV3 + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor, tensor, tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor, tensor, tensor) + func.return %0#0 : tensor +} + +// ----- + +// CHECK-LABEL: fusedBatchNormV3_training_dynamic_unsupported2 +func.func @fusedBatchNormV3_training_dynamic_unsupported2(%arg0: tensor, %arg1: tensor<6xf32>, %arg2: tensor<6xf32>, %arg3: tensor<6xf32>, %arg4: tensor<6xf32>) -> (tensor) { + // CHECK: tf.FusedBatchNormV3 + %0:6 = "tf.FusedBatchNormV3"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, exponential_avg_factor = 1.0 : f32, is_training = true} : (tensor, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>) -> (tensor, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>, tensor<6xf32>) + func.return %0#0 : tensor +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGrad_noTraining +func.func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK-NEXT: %[[EPS:.*]] = stablehlo.constant dense<1.000000e-03> : tensor<8xf32> + + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg4, %[[EPS]] : tensor<8xf32> + // CHECK-NEXT: %[[RSQRT:.*]] = stablehlo.rsqrt %[[ADD]] : tensor<8xf32> + + // CHECK-NEXT: %[[MUL2:.*]] = stablehlo.multiply %arg2, %[[RSQRT]] : tensor<8xf32> + // CHECK-NEXT: %[[BCAST_MUL2:.+]] = stablehlo.broadcast_in_dim %[[MUL2]], {{.*}} : (tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[MUL3:.*]] = stablehlo.multiply %arg0, %[[BCAST_MUL2]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: return %[[MUL3]] : tensor<8x8x8x8xf32> + + %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGrad_Training +func.func @fusedBatchNormGrad_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK-NEXT: %[[GRAD_OPERAND:.*]], %[[GRAD_SCALE:.*]], %[[GRAD_OFFSET:.*]] = "stablehlo.batch_norm_grad"(%arg1, %arg2, %arg3, %arg4, %arg0) {{.*}} + // CHECK-NEXT: return %[[GRAD_OPERAND]] : tensor<8x8x8x8xf32> + + %0:5 = "tf.FusedBatchNormGrad"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGradV2_noTraining +func.func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK-NEXT: %[[EPS:.*]] = stablehlo.constant dense<1.000000e-03> : tensor<8xf32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg4, %[[EPS]] : tensor<8xf32> + // CHECK-NEXT: %[[RSQRT:.*]] = stablehlo.rsqrt %[[ADD]] : tensor<8xf32> + // CHECK-NEXT: %[[MUL:.*]] = stablehlo.multiply %arg2, %[[RSQRT]] : tensor<8xf32> + // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %[[MUL]], dims = [3] : (tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[MUL2:.*]] = stablehlo.multiply %arg0, %[[BCAST]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: return %[[MUL2:.*]] : tensor<8x8x8x8xf32> + + %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGradV2_Training +func.func @fusedBatchNormGradV2_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK-NEXT: %[[GRAD_OPERAND:.*]], %[[GRAD_SCALE:.*]], %[[GRAD_OFFSET:.*]] = "stablehlo.batch_norm_grad"(%arg1, %arg2, %arg3, %arg4, %arg0) {{.*}} + // CHECK-NEXT: return %[[GRAD_OPERAND]] : tensor<8x8x8x8xf32> + + %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGradV2_noTraining_mixed_precision +func.func @fusedBatchNormGradV2_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { + // CHECK-NEST: %[[CST:.*]] = stablehlo.constant dense<1.000000e-03> : tensor<8xf32> + // CHECK-NEST: %[[ADD:.*]] = stablehlo.add %arg4, %[[CST]] : tensor<8xf32> + // CHECK-NEST: %[[RSQRT:.*]] = stablehlo.rsqrt %[[ADD]] : tensor<8xf32> + // CHECK-NEST: %[[MUL:.*]] = stablehlo.multiply %arg2, %[[RSQRT]] : tensor<8xf32> + // CHECK-NEST: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %[[MUL]], dims = [3] : (tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEST: %[[MUL2:.*]] = stablehlo.multiply %arg0, %[[BCAST]] : tensor<8x8x8x8xf32> + // CHECK-NEST: %[[CONVERT:.*]] = stablehlo.convert %[[MUL2]] : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + // CHECK-NEST: return %[[CONVERT]] : tensor<8x8x8x8xbf16> + + %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xbf16> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGradV2_Training_mixed_precision +func.func @fusedBatchNormGradV2_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { + // CHECK-NEXT: %[[CONVERT:.*]] = stablehlo.convert %arg1 : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[GRAD_OPERAND:.*]], %[[GRAD_SCALE:.*]], %[[GRAD_OFFSET:.*]] = "stablehlo.batch_norm_grad"(%[[CONVERT]], %arg2, %arg3, %arg4, %arg0) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK-NEXT: %[[CONVERT:.*]] = stablehlo.convert %[[GRAD_OPERAND]] : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + // CHECK-NEXT: return %[[CONVERT]] : tensor<8x8x8x8xbf16> + + %0:5 = "tf.FusedBatchNormGradV2"(%arg0, %arg1, %arg2, %arg3, %arg4) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xbf16> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGradV3_noTraining +func.func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { +// CHECK-NEXT: %[[EPS:.*]] = stablehlo.constant dense<1.000000e-03> : tensor<8xf32> +// CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg4, %[[EPS]] : tensor<8xf32> +// CHECK-NEXT: %[[RSQRT:.*]] = stablehlo.rsqrt %[[ADD]] : tensor<8xf32> +// CHECK-NEXT: %[[MUL:.*]] = stablehlo.multiply %arg2, %[[RSQRT]] : tensor<8xf32> +// CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %[[MUL]], dims = [3] : (tensor<8xf32>) -> tensor<8x8x8x8xf32> +// CHECK-NEXT: %[[MUL2:.*]] = stablehlo.multiply %arg0, %[[BCAST]] : tensor<8x8x8x8xf32> +// CHECK-NEXT: return %[[MUL2]] : tensor<8x8x8x8xf32> + + %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGradV3_Training +func.func @fusedBatchNormGradV3_Training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32>) { + // CHECK-NEXT: %[[EPS:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<0xf32> + // CHECK-NEXT: %[[GRAD_OPERAND:.*]], %[[GRAD_SCALE:.*]], %[[GRAD_OFFSET:.*]] = "stablehlo.batch_norm_grad"(%arg1, %arg2, %arg3, %arg4, %arg0) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK-NEXT: %[[CAST:.*]] = tensor.cast %[[EPS]] : tensor<0xf32> to tensor<*xf32> + // CHECK-NEXT: return %[[GRAD_OPERAND]], %[[EPS]], %[[CAST]] : tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32> + + %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<0xf32>, tensor<*xf32>) + func.return %0#0, %0#3, %0#4 : tensor<8x8x8x8xf32>, tensor<0xf32>, tensor<*xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGradV3_noTraining_mixed_precision +func.func @fusedBatchNormGradV3_noTraining_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { + // CHECK-NEXT: %[[EPS:.*]] = stablehlo.constant dense<1.000000e-03> : tensor<8xf32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg4, %[[EPS]] : tensor<8xf32> + // CHECK-NEXT: %[[RSQRT:.*]] = stablehlo.rsqrt %[[ADD]] : tensor<8xf32> + // CHECK-NEXT: %[[MUL:.*]] = stablehlo.multiply %arg2, %[[RSQRT]] : tensor<8xf32> + // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %[[MUL]], dims = [3] : (tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[MUL2:.*]] = stablehlo.multiply %arg0, %[[BCAST]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[CONVERT:.*]] = stablehlo.convert %[[MUL2]] : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + // CHECK-NEXT: return %[[CONVERT]] : tensor<8x8x8x8xbf16> + + %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xbf16> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGradV3_Training_mixed_precision +func.func @fusedBatchNormGradV3_Training_mixed_precision(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xbf16>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xbf16>) { + // CHECK-NEXT: %[[CONVERT:.*]] = stablehlo.convert %arg1 : (tensor<8x8x8x8xbf16>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[GRAD_OPERAND:.*]], %[[GRAD_SCALE:.*]], %[[GRAD_OFFSET:.*]] = "stablehlo.batch_norm_grad"(%[[CONVERT]], %arg2, %arg3, %arg4, %arg0) <{epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK-NEXT: %[[CONVERT2:.*]] = stablehlo.convert %[[GRAD_OPERAND]] : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xbf16> + // CHECK-NEXT: return %[[CONVERT2]] : tensor<8x8x8x8xbf16> + + %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NHWC", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xbf16>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xbf16> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGradV3_noTraining_NCHW +func.func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK-NEXT: %[[EPS:.*]] = stablehlo.constant dense<1.000000e-03> : tensor<8xf32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg4, %[[EPS]] : tensor<8xf32> + // CHECK-NEXT: %[[RSQRT:.*]] = stablehlo.rsqrt %[[ADD]] : tensor<8xf32> + // CHECK-NEXT: %[[MUL:.*]] = stablehlo.multiply %arg2, %[[RSQRT]] : tensor<8xf32> + // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %[[MUL]], dims = [1] : (tensor<8xf32>) -> tensor<8x8x8x8xf32> + // CHECK-NEXT: %[[MUL2:.*]] = stablehlo.multiply %arg0, %[[BCAST]] : tensor<8x8x8x8xf32> + // CHECK-NEXT: return %[[MUL2]] : tensor<8x8x8x8xf32> + + %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = false} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +// ----- + +// CHECK-LABEL: fusedBatchNormGradV3_Training_NCHW +func.func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x8x8x8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>, %arg5: tensor<8xf32>) -> (tensor<8x8x8x8xf32>) { + // CHECK-NEXT: %[[GRAD_OPERAND:.*]], %[[GRAD_SCALE:.*]], %[[GRAD_OFFSET:.*]] = "stablehlo.batch_norm_grad"(%arg1, %arg2, %arg3, %arg4, %arg0) <{epsilon = 1.000000e-03 : f32, feature_index = 1 : i64}> : (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>) + // CHECK-NEXT: return %[[GRAD_OPERAND]] : tensor<8x8x8x8xf32> + %0:5 = "tf.FusedBatchNormGradV3"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) {T = "tfdtype$DT_FLOAT", data_format = "NCHW", epsilon = 0.001 : f32, is_training = true} : (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) + func.return %0#0 : tensor<8x8x8x8xf32> +} + +//===----------------------------------------------------------------------===// +// Bias op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @biasAdd_default +func.func @biasAdd_default(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { + // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [3] : (tensor<32xi32>) -> tensor<1x32x10x32xi32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %[[BCAST]] : tensor<1x32x10x32xi32> + // CHECK-NEXT: return %[[ADD]] : tensor<1x32x10x32xi32> + %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + func.return %0 : tensor<1x32x10x32xi32> +} + +// ----- + +// CHECK-LABEL: func @biasAdd_NHWC +func.func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { + // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [3] : (tensor<32xi32>) -> tensor<1x32x10x32xi32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %[[BCAST]] : tensor<1x32x10x32xi32> + // CHECK-NEXT: return %[[ADD]] : tensor<1x32x10x32xi32> + %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NHWC"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + func.return %0 : tensor<1x32x10x32xi32> +} + +// ----- + +// CHECK-LABEL: func @biasAdd_NCHW +func.func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { + // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<32xi32>) -> tensor<1x32x10x32xi32> + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %[[BCAST]] : tensor<1x32x10x32xi32> + // CHECK-NEXT: return %[[ADD]] : tensor<1x32x10x32xi32> + %0 = "tf.BiasAdd"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", data_format = "NCHW"} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + func.return %0 : tensor<1x32x10x32xi32> +} + +// ----- + +// CHECK-LABEL: func @biasAdd_dynamic +func.func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor -> tensor<4xindex> + // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.dynamic_broadcast_in_dim %arg1, %[[SHAPE]], dims = [1] : (tensor, tensor<4xindex>) -> tensor + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %[[BCAST]] : tensor + // CHECK-NEXT: return %[[ADD]] : tensor + %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NCHW"} : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @biasAdd_partial_dynamic +func.func @biasAdd_partial_dynamic(%arg0: tensor, %arg1: tensor<512xi32>) -> tensor { + // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor -> tensor<4xindex> + // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.dynamic_broadcast_in_dim %arg1, %[[SHAPE]], dims = [3] : (tensor<512xi32>, tensor<4xindex>) -> tensor + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %[[BCAST]] : tensor + // CHECK-NEXT: %[[CAST:.*]] = tensor.cast %[[ADD]] : tensor to tensor + // CHECK-NEXT: return %[[CAST]] : tensor + %0 = "tf.BiasAdd"(%arg0, %arg1) {data_format = "NHWC"} : (tensor, tensor<512xi32>) -> tensor + func.return %0 : tensor +} + + +//===----------------------------------------------------------------------===// +// ClipByValue +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @clip +func.func @clip(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + // CHECK: [[VAL:%.+]] = stablehlo.clamp %arg1, %arg0, %arg2 + + %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + // CHECK: return [[VAL]] + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @clip_dynamic +func.func @clip_dynamic(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + // CHECK-DAG: [[CLAMP:%.+]] = stablehlo.clamp %arg1, %arg0, %arg2 + %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + + // CHECK: return [[CLAMP]] + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @clip_static_broadcast +func.func @clip_static_broadcast(%arg0 : tensor<5xf32>, %arg1 : tensor, %arg2 : tensor) -> tensor<5xf32> { + // CHECK-DAG: [[BROADCAST_MIN:%.+]] = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor) -> tensor<5xf32> + // CHECK-DAG: [[BROADCAST_MAX:%.+]] = stablehlo.broadcast_in_dim %arg2, dims = [] : (tensor) -> tensor<5xf32> + // CHECK-DAG: [[CLAMP:%.+]] = stablehlo.clamp [[BROADCAST_MIN]], %arg0, [[BROADCAST_MAX]] + %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor<5xf32>, tensor, tensor) -> tensor<5xf32> + + // CHECK: return [[CLAMP]] + func.return %0 : tensor<5xf32> +} + + +// CHECK-LABEL: @clip_dynamic_broadcast +func.func @clip_dynamic_broadcast(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + // CHECK: [[SHP:%.+]] = shape.shape_of %arg0 + // CHECK: [[SHPIDX:%.+]] = arith.index_cast [[SHP]] : tensor<1xindex> to tensor<1xi32> + // CHECK-DAG: [[BROADCAST_MIN:%.+]] = stablehlo.dynamic_broadcast_in_dim %arg1, [[SHPIDX]], dims = [] : (tensor, tensor<1xi32>) -> tensor + // CHECK-DAG: [[BROADCAST_MAX:%.+]] = stablehlo.dynamic_broadcast_in_dim %arg2, [[SHPIDX]], dims = [] : (tensor, tensor<1xi32>) -> tensor + // CHECK-DAG: [[CLAMP:%.+]] = stablehlo.clamp [[BROADCAST_MIN]], %arg0, [[BROADCAST_MAX]] + %0 = "tf.ClipByValue"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + + // CHECK: return [[CLAMP]] + func.return %0 : tensor +} + +//===----------------------------------------------------------------------===// +// DiagPart +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @diag_part +// CHECK-SAME: %[[ARG:.*]]: tensor<4x3x4x3xf32> +func.func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> { + // CHECK-NEXT: %[[CST0:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<12x12xf32> + // CHECK-NEXT: %[[CST1:.*]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %[[RESHAPE:.*]] = stablehlo.reshape %arg0 : (tensor<4x3x4x3xf32>) -> tensor<12x12xf32> + // CHECK-NEXT: %[[IOTA:.*]] = stablehlo.iota dim = 0 : tensor<12xi32> + // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %[[IOTA]], dims = [0] : (tensor<12xi32>) -> tensor<12x12xi32> + // CHECK-NEXT: %[[IOTA2:.*]] = stablehlo.iota dim = 0 : tensor<12xi32> + // CHECK-NEXT: %[[BCAST2:.*]] = stablehlo.broadcast_in_dim %[[IOTA2]], dims = [1] : (tensor<12xi32>) -> tensor<12x12xi32> + // CHECK-NEXT: %[[CMP:.*]] = stablehlo.compare EQ, %[[BCAST]], %[[BCAST2]], NOTYPE : (tensor<12x12xi32>, tensor<12x12xi32>) -> tensor<12x12xi1> + // CHECK-NEXT: %[[SEL:.*]] = stablehlo.select %[[CMP]], %[[RESHAPE]], %[[CST0]] : tensor<12x12xi1>, tensor<12x12xf32> + // CHECK-NEXT: %[[REDUCE:.*]] = stablehlo.reduce(%[[SEL]] init: %[[CST1]]) applies stablehlo.add across dimensions = [0] : (tensor<12x12xf32>, tensor) -> tensor<12xf32> + // CHECK-NEXT: %[[RESHAPE2:.*]] = stablehlo.reshape %[[REDUCE]] : (tensor<12xf32>) -> tensor<4x3xf32> + // CHECK-NEXT: return %[[RESHAPE2]] : tensor<4x3xf32> + + %0 = "tf.DiagPart"(%arg0) : (tensor<4x3x4x3xf32>) -> tensor<4x3xf32> + func.return %0: tensor<4x3xf32> +} + +//===----------------------------------------------------------------------===// +// MatrixDiagPart +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @matrix_diag_part +// CHECK-SAME: %[[ARG:.*]]: tensor<7x140x128xi32> +func.func @matrix_diag_part(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { + // CHECK-NEXT: %[[CST0:.*]] = stablehlo.constant dense<42> : tensor<7x22x128xi32> + // CHECK-NEXT: %[[CST1:.*]] = stablehlo.constant dense<128> : tensor<1x22x128xi32> + // CHECK-NEXT: %[[CST2:.*]] = stablehlo.constant dense<140> : tensor<1x22x128xi32> + // CHECK-NEXT: %[[CST3:.*]] = stablehlo.constant dense<11> : tensor<1x22x128xi32> + // CHECK-NEXT: %[[CST4:.*]] = stablehlo.constant dense<0> : tensor<1x22x128xi32> + // CHECK-NEXT: %[[IOTA0:.*]] = stablehlo.iota dim = 0 : tensor<22xi32> + // CHECK-NEXT: %[[BCAST0:.*]] = stablehlo.broadcast_in_dim %[[IOTA0]], dims = [1] : (tensor<22xi32>) -> tensor<1x22x128xi32> + // CHECK-NEXT: %[[IOTA1:.*]] = stablehlo.iota dim = 0 : tensor<128xi32> + // CHECK-NEXT: %[[BCAST1:.*]] = stablehlo.broadcast_in_dim %[[IOTA1]], dims = [2] : (tensor<128xi32>) -> tensor<1x22x128xi32> + // CHECK-NEXT: %[[SUB0:.*]] = stablehlo.subtract %[[CST3]], %[[BCAST0]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[NEG0:.*]] = stablehlo.negate %[[SUB0]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[MIN0:.*]] = stablehlo.minimum %[[SUB0]], %[[CST4]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[ADD0:.*]] = stablehlo.add %[[MIN0]], %[[CST2]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[MAX0:.*]] = stablehlo.maximum %[[SUB0]], %[[CST4]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[SUB1:.*]] = stablehlo.subtract %[[CST1]], %[[MAX0]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[MIN1:.*]] = stablehlo.minimum %[[ADD0]], %[[SUB1]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[CMP0:.*]] = stablehlo.compare GE, %[[SUB0]], %[[CST4]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK-NEXT: %[[SUB2:.*]] = stablehlo.subtract %[[CST1]], %[[MIN1]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[SELECT0:.*]] = stablehlo.select %[[CMP0]], %[[SUB2]], %[[CST4]] : tensor<1x22x128xi1>, tensor<1x22x128xi32> + // CHECK-NEXT: %[[MAX1:.*]] = stablehlo.maximum %[[SUB0]], %[[CST4]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[SUB2:.*]] = stablehlo.subtract %[[MAX1]], %[[SELECT0]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[MAX2:.*]] = stablehlo.maximum %[[NEG0]], %[[CST4]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[SUB3:.*]] = stablehlo.subtract %[[MAX2]], %[[SELECT0]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[ADD1:.*]] = stablehlo.add %[[BCAST1]], %[[SUB2]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[ADD2:.*]] = stablehlo.add %[[BCAST1]], %[[SUB3]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[CMP1:.*]] = stablehlo.compare GE, %[[ADD1]], %[[CST4]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK-NEXT: %[[CMP2:.*]] = stablehlo.compare LT, %[[ADD1]], %[[CST1]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK-NEXT: %[[AND0:.*]] = stablehlo.and %[[CMP1]], %[[CMP2]] : tensor<1x22x128xi1> + // CHECK-NEXT: %[[CMP3:.*]] = stablehlo.compare GE, %[[ADD2]], %[[CST4]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK-NEXT: %[[CMP4:.*]] = stablehlo.compare LT, %[[ADD2]], %[[CST2]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK-NEXT: %[[AND1:.*]] = stablehlo.and %[[CMP3]], %[[CMP4]] : tensor<1x22x128xi1> + // CHECK-NEXT: %[[AND2:.*]] = stablehlo.and %[[AND0]], %[[AND1]] : tensor<1x22x128xi1> + // CHECK-NEXT: %[[RESHAPE0:.*]] = stablehlo.reshape %[[AND2]] : (tensor<1x22x128xi1>) -> tensor<22x128xi1> + // CHECK-NEXT: %[[CONCAT0:.*]] = stablehlo.concatenate %[[ADD2]], %[[ADD1]], dim = 0 : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<2x22x128xi32> + // CHECK-NEXT: %[[GATHER0:.*]] = "stablehlo.gather"(%arg0, %[[CONCAT0]]) <{dimension_numbers = #{{.*}}, indices_are_sorted = false, slice_sizes = array}> : (tensor<7x140x128xi32>, tensor<2x22x128xi32>) -> tensor<7x22x128xi32> + // CHECK-NEXT: %[[BCAST1:.*]] = stablehlo.broadcast %[[RESHAPE0]], sizes = [7] : (tensor<22x128xi1>) -> tensor<7x22x128xi1> + // CHECK-NEXT: %[[SELECT1:.*]] = stablehlo.select %[[BCAST1]], %[[GATHER0]], %[[CST0]] : tensor<7x22x128xi1>, tensor<7x22x128xi32> + // CHECK-NEXT: return %[[SELECT1]] : tensor<7x22x128xi32> + + %0 = mhlo.constant dense<42> : tensor // padding value + %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k + %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { + T = i32, align = "RIGHT_LEFT" + } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x22x128xi32> + func.return %2: tensor<7x22x128xi32> +} + +// ----- + +// CHECK-LABEL: func @matrix_diag_part_zero_dim_complex +func.func @matrix_diag_part_zero_dim_complex(%arg0: tensor<4x0xcomplex>) -> tensor<0xcomplex> { + %cst = "tf.Const"() {value = dense<-3> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<(0.000000e+00,0.000000e+00)> : tensor>} : () -> tensor> + %0 = "tf.MatrixDiagPartV3"(%arg0, %cst, %cst_0) {align = "RIGHT_LEFT", device = ""} : (tensor<4x0xcomplex>, tensor, tensor>) -> tensor<0xcomplex> + // CHECK: return %{{[0-9]*}} : tensor<0xcomplex> + return %0 : tensor<0xcomplex> +} + +// ----- + +// CHECK-LABEL: func @matrix_diag_part_single_diagonal +func.func @matrix_diag_part_single_diagonal(%arg0: tensor<7x140x128xi32>) -> tensor<7x128xi32> { + // CHECK-NEXT: %[[CST0:.*]] = stablehlo.constant dense<42> : tensor<7x1x128xi32> + // CHECK-NEXT: %[[CST1:.*]] = stablehlo.constant dense<128> : tensor<1x1x128xi32> + // CHECK-NEXT: %[[CST2:.*]] = stablehlo.constant dense<140> : tensor<1x1x128xi32> + // CHECK-NEXT: %[[FALSE:.*]] = stablehlo.constant dense<0> : tensor<1x1x128xi32> + // CHECK-NEXT: %[[IOTA0:.*]] = stablehlo.iota dim = 0 : tensor<128xi32> + // CHECK-NEXT: %[[RESHAPE0:.*]] = stablehlo.reshape %[[IOTA0]] : (tensor<128xi32>) -> tensor<1x1x128xi32> + // CHECK-NEXT: %[[CMP0:.*]] = stablehlo.compare GE, %[[RESHAPE0]], %[[FALSE]] : (tensor<1x1x128xi32>, tensor<1x1x128xi32>) -> tensor<1x1x128xi1> + // CHECK-NEXT: %[[CMP1:.*]] = stablehlo.compare LT, %[[RESHAPE0]], %[[CST1]] : (tensor<1x1x128xi32>, tensor<1x1x128xi32>) -> tensor<1x1x128xi1> + // CHECK-NEXT: %[[AND0:.*]] = stablehlo.and %[[CMP0]], %[[CMP1]] : tensor<1x1x128xi1> + // CHECK-NEXT: %[[CMP2:.*]] = stablehlo.compare GE, %[[RESHAPE0]], %[[FALSE]] : (tensor<1x1x128xi32>, tensor<1x1x128xi32>) -> tensor<1x1x128xi1> + // CHECK-NEXT: %[[CMP3:.*]] = stablehlo.compare LT, %[[RESHAPE0]], %[[CST2]] : (tensor<1x1x128xi32>, tensor<1x1x128xi32>) -> tensor<1x1x128xi1> + // CHECK-NEXT: %[[AND1:.*]] = stablehlo.and %[[CMP2]], %[[CMP3]] : tensor<1x1x128xi1> + // CHECK-NEXT: %[[AND2:.*]] = stablehlo.and %[[AND0]], %[[AND1]] : tensor<1x1x128xi1> + // CHECK-NEXT: %[[RESHAPE1:.*]] = stablehlo.reshape %[[AND2]] : (tensor<1x1x128xi1>) -> tensor<1x128xi1> + // CHECK-NEXT: %[[CONCAT:.*]] = stablehlo.concatenate %[[RESHAPE0]], %[[RESHAPE0]], dim = 0 : (tensor<1x1x128xi32>, tensor<1x1x128xi32>) -> tensor<2x1x128xi32> + // CHECK-NEXT: %[[GATHER:.*]] = "stablehlo.gather"(%arg0, %[[CONCAT]]) <{dimension_numbers = #{{.*}}, indices_are_sorted = false, slice_sizes = array}> : (tensor<7x140x128xi32>, tensor<2x1x128xi32>) -> tensor<7x1x128xi32> + // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.broadcast %[[RESHAPE1]], sizes = [7] : (tensor<1x128xi1>) -> tensor<7x1x128xi1> + // CHECK-NEXT: %[[SELECT0:.*]] = stablehlo.select %[[BCAST]], %[[GATHER]], %[[CST0]] : tensor<7x1x128xi1>, tensor<7x1x128xi32> + // CHECK-NEXT: %[[RESHAPE2:.*]] = stablehlo.reshape %[[SELECT0]] : (tensor<7x1x128xi32>) -> tensor<7x128xi32> + // CHECK-NEXT: return %[[RESHAPE2]] : tensor<7x128xi32> + + %0 = mhlo.constant dense<42> : tensor // padding value + %1 = mhlo.constant dense<0> : tensor<2xi32> // k + %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { + T = i32, align = "RIGHT_LEFT" + } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x128xi32> + func.return %2: tensor<7x128xi32> +} + +// ----- + +// CHECK-LABEL: func @matrix_diag_part_align_ll +func.func @matrix_diag_part_align_ll(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { + // CHECK-NEXT: %[[CST0:.*]] = stablehlo.constant dense<42> : tensor<7x22x128xi32> + // CHECK-NEXT: %[[CST1:.*]] = stablehlo.constant dense<128> : tensor<1x22x128xi32> + // CHECK-NEXT: %[[CST2:.*]] = stablehlo.constant dense<140> : tensor<1x22x128xi32> + // CHECK-NEXT: %[[CST3:.*]] = stablehlo.constant dense<11> : tensor<1x22x128xi32> + // CHECK-NEXT: %[[FALSE:.*]] = stablehlo.constant dense<0> : tensor<1x22x128xi32> + // CHECK-NEXT: %[[IOTA0:.*]] = stablehlo.iota dim = 0 : tensor<22xi32> + // CHECK-NEXT: %[[BCAST0:.*]] = stablehlo.broadcast_in_dim %[[IOTA0]], dims = [1] : (tensor<22xi32>) -> tensor<1x22x128xi32> + // CHECK-NEXT: %[[IOTA1:.*]] = stablehlo.iota dim = 0 : tensor<128xi32> + // CHECK-NEXT: %[[BCAST1:.*]] = stablehlo.broadcast_in_dim %[[IOTA1]], dims = [2] : (tensor<128xi32>) -> tensor<1x22x128xi32> + // CHECK-NEXT: %[[SUB0:.*]] = stablehlo.subtract %[[CST3]], %[[BCAST0]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[NEG0:.*]] = stablehlo.negate %[[SUB0]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[MAX0:.*]] = stablehlo.maximum %[[SUB0]], %[[FALSE]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[SUB1:.*]] = stablehlo.subtract %[[MAX0]], %[[FALSE]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[MAX1:.*]] = stablehlo.maximum %[[NEG0]], %[[FALSE]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[SUB2:.*]] = stablehlo.subtract %[[MAX1]], %[[FALSE]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[ADD0:.*]] = stablehlo.add %[[BCAST1]], %[[SUB1]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[ADD1:.*]] = stablehlo.add %[[BCAST1]], %[[SUB2]] : tensor<1x22x128xi32> + // CHECK-NEXT: %[[CMP0:.*]] = stablehlo.compare GE, %[[ADD0]], %[[FALSE]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK-NEXT: %[[CMP1:.*]] = stablehlo.compare LT, %[[ADD0]], %[[CST1]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK-NEXT: %[[AND0:.*]] = stablehlo.and %[[CMP0]], %[[CMP1]] : tensor<1x22x128xi1> + // CHECK-NEXT: %[[CMP2:.*]] = stablehlo.compare GE, %[[ADD1]], %[[FALSE]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK-NEXT: %[[CMP3:.*]] = stablehlo.compare LT, %[[ADD1]], %[[CST2]] : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK-NEXT: %[[AND1:.*]] = stablehlo.and %[[CMP2]], %[[CMP3]] : tensor<1x22x128xi1> + // CHECK-NEXT: %[[AND2:.*]] = stablehlo.and %[[AND0]], %[[AND1]] : tensor<1x22x128xi1> + // CHECK-NEXT: %[[RESHAPE0:.*]] = stablehlo.reshape %[[AND2]] : (tensor<1x22x128xi1>) -> tensor<22x128xi1> + // CHECK-NEXT: %[[CONCAT0:.*]] = stablehlo.concatenate %[[ADD1]], %[[ADD0]], dim = 0 : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<2x22x128xi32> + // CHECK-NEXT: %[[GATHER0:.*]] = "stablehlo.gather"(%arg0, %[[CONCAT0]]) <{dimension_numbers = #{{.*}}, indices_are_sorted = false, slice_sizes = array}> : (tensor<7x140x128xi32>, tensor<2x22x128xi32>) -> tensor<7x22x128xi32> + // CHECK-NEXT: %[[BCAST2:.*]] = stablehlo.broadcast %[[RESHAPE0]], sizes = [7] : (tensor<22x128xi1>) -> tensor<7x22x128xi1> + // CHECK-NEXT: %[[SELECT0:.*]] = stablehlo.select %[[BCAST2]], %[[GATHER0]], %[[CST0]] : tensor<7x22x128xi1>, tensor<7x22x128xi32> + // CHECK-NEXT: return %[[SELECT0]] : tensor<7x22x128xi32> + + %0 = mhlo.constant dense<42> : tensor // padding value + %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k + %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { + T = i32, align = "LEFT_LEFT" + } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x22x128xi32> + func.return %2: tensor<7x22x128xi32> +} + +// ----- + +// CHECK-LABEL: func @matrix_diag_part_align_lr +func.func @matrix_diag_part_align_lr(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { + %0 = mhlo.constant dense<42> : tensor // padding value + %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k + %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { + T = i32, align = "LEFT_RIGHT" + } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x22x128xi32> + // CHECK: %[[LE:.*]] = stablehlo.compare LE, %{{.*}}, %{{.*}} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK: %{{.*}} = stablehlo.select %[[LE]], %{{.*}}, %{{.*}} : tensor<1x22x128xi1>, tensor<1x22x128xi32> + func.return %2: tensor<7x22x128xi32> +} + +// ----- + +// CHECK-LABEL: func @matrix_diag_part_align_rl +func.func @matrix_diag_part_align_rl(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { + %0 = mhlo.constant dense<42> : tensor // padding value + %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k + %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { + T = i32, align = "RIGHT_LEFT" + } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x22x128xi32> + // CHECK: %[[GE:.*]] = stablehlo.compare GE, %{{.*}}, %{{.*}} : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<1x22x128xi1> + // CHECK: %{{.*}} = stablehlo.select %[[GE]], %{{.*}}, %{{.*}} : tensor<1x22x128xi1>, tensor<1x22x128xi32> + + func.return %2: tensor<7x22x128xi32> +} + +// ----- + +// CHECK-LABEL: func @matrix_diag_part_align_rr +func.func @matrix_diag_part_align_rr(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32> { + %0 = mhlo.constant dense<42> : tensor // padding value + %1 = mhlo.constant dense<[-10, 11]> : tensor<2xi32> // k + %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { + T = i32, align = "RIGHT_RIGHT" + } : (tensor<7x140x128xi32>, tensor<2xi32>, tensor) -> tensor<7x22x128xi32> + // CHECK-NOT: MatrixDiagPartV3 + func.return %2: tensor<7x22x128xi32> +} + +// ----- + +// CHECK-LABEL: func @matrix_diag_part_align_7d +// CHECK: (%arg0: tensor<3x5x7x9x11x13x17xf32>) -> tensor<3x5x7x9x11x4x10xf32> +func.func @matrix_diag_part_align_7d(%arg0: tensor<3x5x7x9x11x13x17xf32>) -> tensor<3x5x7x9x11x4x10xf32> { + %0 = mhlo.constant dense<-1.> : tensor // padding value + %1 = mhlo.constant dense<[-6, -3]> : tensor<2xi32> // k + %2 = "tf.MatrixDiagPartV3"(%arg0, %1, %0) { + T = f32, align = "LEFT_RIGHT" + } : (tensor<3x5x7x9x11x13x17xf32>, tensor<2xi32>, tensor) -> tensor<3x5x7x9x11x4x10xf32> + func.return %2: tensor<3x5x7x9x11x4x10xf32> +} + +//===----------------------------------------------------------------------===// +// Erf +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @erf +func.func @erf(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + // CHECK: mhlo.erf(%arg0) {{.*}} : (tensor<2x3xf32>) -> tensor<2x3xf32> + %0 = "tf.Erf"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> + func.return %0 : tensor<2x3xf32> +} + +//===----------------------------------------------------------------------===// +// Erfc +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @erfc +func.func @erfc(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + // CHECK-NOT: tf.Erfc + %0 = "tf.Erfc"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> + func.return %0 : tensor<2x3xf32> +} + +//===----------------------------------------------------------------------===// +// Einsum. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @einsum +func.func @einsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> { + // CHECK: stablehlo.einsum + %0 = "tf.Einsum"(%arg0, %arg1) {equation = "ab,bc->ac"} : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32> + func.return %0: tensor<2x4xf32> +} + +// ----- + +// CHECK-LABEL: func @unary_einsum +func.func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { + // CHECK: stablehlo.constant{{.*}}1.000000e+00 + // CHECK: stablehlo.einsum{{.*}}",ab->aa" + %0 = "tf.Einsum"(%arg0) {equation = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32> + func.return %0: tensor<2x2xf32> +} + +//===----------------------------------------------------------------------===// +// FloorDiv and FloorMod. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @floordiv_broadcast_i32 +func.func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { + // CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1> : tensor<2x3xi32> + // CHECK-NEXT: %[[ZEROS0:.*]] = stablehlo.constant dense<0> : tensor<3xi32> + // CHECK-NEXT: %[[ZEROS1:.*]] = stablehlo.constant dense<0> : tensor<2x3xi32> + // CHECK-NEXT: %[[BCAST0:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<3xi32>) -> tensor<2x3xi32> + // CHECK-NEXT: %[[DIV0:.*]] = stablehlo.divide %arg0, %[[BCAST0]] : tensor<2x3xi32> + // CHECK-NEXT: %[[BCAST1:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<3xi32>) -> tensor<2x3xi32> + // CHECK-NEXT: %[[MUL0:.*]] = stablehlo.multiply %[[DIV0]], %[[BCAST1]] : tensor<2x3xi32> + // CHECK-NEXT: %[[CMP0:.*]] = stablehlo.compare NE, %[[MUL0]], %arg0 : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> + // CHECK-NEXT: %[[CMP1:.*]] = stablehlo.compare LT, %arg0, %[[ZEROS1]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> + // CHECK-NEXT: %[[CMP2:.*]] = stablehlo.compare LT, %arg1, %[[ZEROS0]] : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> + // CHECK-NEXT: %[[BCAST2:.*]] = stablehlo.broadcast_in_dim %[[CMP2]], dims = [1] : (tensor<3xi1>) -> tensor<2x3xi1> + // CHECK-NEXT: %[[CMP3:.*]] = stablehlo.compare NE, %[[CMP1]], %[[BCAST2]] : (tensor<2x3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> + // CHECK-NEXT: %[[AND0:.*]] = stablehlo.and %[[CMP0]], %[[CMP3]] : tensor<2x3xi1> + // CHECK-NEXT: %[[SUB0:.*]] = stablehlo.subtract %[[DIV0]], %[[ONES]] : tensor<2x3xi32> + // CHECK-NEXT: %[[SELECT0:.*]] = stablehlo.select %[[AND0]], %[[SUB0]], %[[DIV0]] : tensor<2x3xi1>, tensor<2x3xi32> + // CHECK-NEXT: return %[[SELECT0]] : tensor<2x3xi32> + + %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + func.return %0: tensor<2x3xi32> +} + +// ----- + +// CHECK-LABEL: func @floordiv_reverse_broadcast_i32 +func.func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { + // CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1> : tensor<2x3xi32> + // CHECK-NEXT: %[[ZEROS0:.*]] = stablehlo.constant dense<0> : tensor<2x3xi32> + // CHECK-NEXT: %[[ZEROS1:.*]] = stablehlo.constant dense<0> : tensor<3xi32> + // CHECK-NEXT: %[[BCAST0:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<3xi32>) -> tensor<2x3xi32> + // CHECK-NEXT: %[[DIV0:.*]] = stablehlo.divide %[[BCAST0]], %arg1 : tensor<2x3xi32> + // CHECK-NEXT: %[[MUL0:.*]] = stablehlo.multiply %[[DIV0]], %arg1 : tensor<2x3xi32> + // CHECK-NEXT: %[[BCAST1:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<3xi32>) -> tensor<2x3xi32> + // CHECK-NEXT: %[[CMP0:.*]] = stablehlo.compare NE, %[[MUL0]], %[[BCAST1]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> + // CHECK-NEXT: %[[CMP1:.*]] = stablehlo.compare LT, %arg0, %[[ZEROS1]] : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> + // CHECK-NEXT: %[[CMP2:.*]] = stablehlo.compare LT, %arg1, %[[ZEROS0]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> + // CHECK-NEXT: %[[BCAST2:.*]] = stablehlo.broadcast_in_dim %[[CMP1]], dims = [1] : (tensor<3xi1>) -> tensor<2x3xi1> + // CHECK-NEXT: %[[CMP3:.*]] = stablehlo.compare NE, %[[BCAST2]], %[[CMP2]] : (tensor<2x3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> + // CHECK-NEXT: %[[AND0:.*]] = stablehlo.and %[[CMP0]], %[[CMP3]] : tensor<2x3xi1> + // CHECK-NEXT: %[[SUB0:.*]] = stablehlo.subtract %[[DIV0]], %[[ONES]] : tensor<2x3xi32> + // CHECK-NEXT: %[[SELECT0:.*]] = stablehlo.select %[[AND0]], %[[SUB0]], %[[DIV0]] : tensor<2x3xi1>, tensor<2x3xi32> + // CHECK-NEXT: return %[[SELECT0]] : tensor<2x3xi32> + + %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + func.return %0: tensor<2x3xi32> +} + +// ----- + +// CHECK-LABEL: func @floordiv_f32 +func.func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK-NEXT: %[[DIV:.*]] = stablehlo.divide %arg0, %arg0 + // CHECK-NEXT: %[[FLOOR:.*]] = stablehlo.floor %[[DIV]] + // CHECK-NEXT: return %[[FLOOR]] : tensor<2xf32> + %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + func.return %0: tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @floordiv_bf16 +func.func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> { + // CHECK-NEXT: stablehlo.convert + // CHECK-NEXT: stablehlo.convert + // CHECK-NEXT: stablehlo.divide + // CHECK-NEXT: stablehlo.floor + // CHECK-NEXT: stablehlo.convert + // CHECK-NEXT: return + %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xbf16>, tensor<2xbf16>) -> tensor<2xbf16> + func.return %0: tensor<2xbf16> +} + +// ----- + +// CHECK-LABEL: func @floordiv_f16_broadcast +func.func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> { + // CHECK-NEXT: stablehlo.broadcast_in_dim + // CHECK-NEXT: stablehlo.divide + // CHECK-NEXT: stablehlo.floor + // CHECK-NEXT: return + %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> + func.return %0: tensor<2x3xf16> +} + +// ----- + +// CHECK-LABEL: func @floordiv_dynamic +func.func @floordiv_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: shape.assuming + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK: stablehlo.divide + // CHECK: shape.assuming + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK: stablehlo.multiply + // CHECK: shape.assuming + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK: stablehlo.compare + // CHECK: shape.assuming + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK: stablehlo.and + // + // CHECK: %[[SELECT:.*]] = stablehlo.select + // CHECK: return %[[SELECT]] + %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0: tensor +} + +// ----- + +// CHECK-LABEL: func @floordiv_unsigned +func.func @floordiv_unsigned(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: %[[RESULT:.*]] = shape.assuming + // CHECK: %[[BCAST0:.*]] = stablehlo.dynamic_broadcast_in_dim %arg0, + // CHECK: %[[BCAST1:.*]] = stablehlo.dynamic_broadcast_in_dim %arg1, + // CHECK: %[[DIV:.*]] = stablehlo.divide %[[BCAST0]], %[[BCAST1]] + // CHECK: shape.assuming_yield %[[DIV]] + // CHECK: return %[[RESULT]] + %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0: tensor +} + +// ----- + +// CHECK-LABEL: func @floordiv_int +func.func @floordiv_int(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: shape.assuming + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK: stablehlo.divide + // CHECK: shape.assuming + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK: stablehlo.multiply + // CHECK: shape.assuming + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK: stablehlo.compare + // CHECK: shape.assuming + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK: stablehlo.compare + // CHECK: shape.assuming + // CHECK: stablehlo.dynamic_broadcast_in_dim + // CHECK: stablehlo.and + // + // CHECK: %[[SELECT:.*]] = stablehlo.select + // CHECK: return %[[SELECT]] + %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0: tensor +} + +// ----- + +// CHECK-LABEL: func @floormod_broadcast_numerator +func.func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { + // CHECK: %[[BCAST0:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<3xi32>) -> tensor<2x3xi32> + // CHECK: %[[REM:.*]] = stablehlo.remainder %[[BCAST0]], %arg1 : tensor<2x3xi32> + // CHECK: %[[AND:.*]] = stablehlo.and + // CHECK: %[[ADD:.*]] = stablehlo.add + // CHECK: %[[SELECT:.*]] = stablehlo.select %[[AND]], %[[ADD]], %[[REM]] + // CHECK-NEXT: return %[[SELECT]] + %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + func.return %0: tensor<2x3xi32> +} + +// ----- + +// CHECK-LABEL: func @floormod_broadcast_denominator +func.func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { + // CHECK: %[[BCAST0:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<3xi32>) -> tensor<2x3xi32> + // CHECK: %[[REM:.*]] = stablehlo.remainder %arg0, %[[BCAST0]] + // CHECK: %[[AND:.*]] = stablehlo.and + // CHECK: %[[BCAST1:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<3xi32>) -> tensor<2x3xi32> + // CHECK: %[[ADD:.*]] = stablehlo.add %[[BCAST1]], %[[REM]] + // CHECK: %[[SELECT:.*]] = stablehlo.select %[[AND]], %[[ADD]], %[[REM]] + // CHECK-NEXT: return %[[SELECT]] + %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + func.return %0: tensor<2x3xi32> +} + +// ----- + +// CHECK-LABEL: func @floormod_unsigned_broadcast_denominator +func.func @floormod_unsigned_broadcast_denominator(%arg0: tensor<2x3xui32>, %arg1: tensor<3xui32>) -> tensor<2x3xui32> { + // CHECK-NEXT: %[[BCAST0:.*]] = stablehlo.broadcast_in_dim %arg1, dims = [1] : (tensor<3xui32>) -> tensor<2x3xui32> + // CHECK-NEXT: %[[REM:.*]] = stablehlo.remainder %arg0, %[[BCAST0]] : tensor<2x3xui32> + // CHECK-NEXT: return %[[REM]] : tensor<2x3xui32> + %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xui32>, tensor<3xui32>) -> tensor<2x3xui32> + func.return %0: tensor<2x3xui32> +} + +// ----- + +// CHECK-LABEL: func @floormod_dynamic_broadcast_numerator +func.func @floormod_dynamic_broadcast_numerator_(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: %[[REM:.*]] = shape.assuming {{.*}} { + // CHECK: stablehlo.remainder + // CHECK: shape.assuming {{.*}} { + // CHECK: stablehlo.compare + // CHECK: %[[AND:.*]] = shape.assuming {{.*}} { + // CHECK: stablehlo.and + // CHECK: %[[ADD:.*]] = shape.assuming {{.*}} { + // CHECK: stablehlo.add + // CHECK: %[[SELECT:.*]] = stablehlo.select %[[AND]], %[[ADD]], %[[REM]] + // CHECK-NEXT: return %[[SELECT]] + %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0: tensor +} + +// ----- + +// CHECK-LABEL: func @floormod_dynamic_broadcast_denominator +func.func @floormod_dynamic_broadcast_denominator_(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK-NOT: tf.FloorMod + // CHECK: %[[REM:.*]] = shape.assuming {{.*}} { + // CHECK: stablehlo.remainder + // CHECK: shape.assuming {{.*}} { + // CHECK: stablehlo.compare + // CHECK: %[[AND:.*]] = shape.assuming {{.*}} { + // CHECK: stablehlo.and + // CHECK: %[[ADD:.*]] = shape.assuming {{.*}} { + // CHECK: stablehlo.add + // CHECK: %[[SELECT:.*]] = stablehlo.select %[[AND]], %[[ADD]], %[[REM]] + // CHECK-NEXT: return %[[SELECT]] + %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0: tensor +} + +//===----------------------------------------------------------------------===// +// OnesLike +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @ones_like +// CHECK-SAME: (%[[ARG:.*]]: tensor<2x?xf32>) +func.func @ones_like(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> { + // CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1.000000e+00> : tensor + // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<2x?xf32> -> tensor<2xindex> + // CHECK-NEXT: %[[RESULT:.*]] = stablehlo.dynamic_broadcast_in_dim %[[ONES]], %[[SHAPE]], dims = [] : (tensor, tensor<2xindex>) -> tensor<2x?xf32> + // CHECK-NEXT: return %[[RESULT]] : tensor<2x?xf32> + %0 = "tf.OnesLike"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32> + func.return %0 : tensor<2x?xf32> +} + +//===----------------------------------------------------------------------===// +// ZerosLike +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @zeros_like +// CHECK-SAME: (%[[ARG:.*]]: tensor<2x?xf32>) +func.func @zeros_like(%arg0: tensor<2x?xf32>) -> tensor<2x?xf32> { + // CHECK-NEXT: %[[ZEROS:.*]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK-NEXT: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<2x?xf32> -> tensor<2xindex> + // CHECK-NEXT: %[[RESULT:.*]] = stablehlo.dynamic_broadcast_in_dim %[[ZEROS]], %[[SHAPE]], dims = [] : (tensor, tensor<2xindex>) -> tensor<2x?xf32> + // CHECK-NEXT: return %[[RESULT]] : tensor<2x?xf32> + %0 = "tf.ZerosLike"(%arg0) : (tensor<2x?xf32>) -> tensor<2x?xf32> + func.return %0 : tensor<2x?xf32> +} + +//===----------------------------------------------------------------------===// +// BroadcastTo. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @broadcast_to +func.func @broadcast_to(%arg0: tensor<16xf32>) -> tensor<16x16x16x16xf32> { + %cst = "tf.Const"() { value = dense<16> : tensor<4xi32> } : () -> tensor<4xi32> + // CHECK: stablehlo.broadcast_in_dim %arg0 + %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<16xf32>, tensor<4xi32>) -> tensor<16x16x16x16xf32> + func.return %0 : tensor<16x16x16x16xf32> +} + +//===----------------------------------------------------------------------===// +// Complex op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @complex +func.func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex> { + // CHECK: stablehlo.complex + %1 = "tf.Complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex> + func.return %1 : tensor<3xcomplex> +} + +// ----- + +// CHECK-LABEL: func @imag +func.func @imag(%arg0: tensor<3xcomplex>) -> tensor<3xf32> { + // CHECK: stablehlo.imag + %1 = "tf.Imag"(%arg0) : (tensor<3xcomplex>) -> tensor<3xf32> + func.return %1 : tensor<3xf32> +} + +// ----- + +// CHECK-LABEL: func @real +func.func @real(%arg0: tensor<3xcomplex>) -> tensor<3xf32> { + // CHECK: stablehlo.real + %1 = "tf.Real"(%arg0) : (tensor<3xcomplex>) -> tensor<3xf32> + func.return %1 : tensor<3xf32> +} + +//===----------------------------------------------------------------------===// +// Concat op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @concat_v2 +func.func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { + // CHECK: stablehlo.concatenate %arg0, %arg1, dim = 0 + %axis = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<6x3xf32> + func.return %1 : tensor<6x3xf32> +} + +// ----- + +// CHECK-LABEL: func @concat_v2_neg_axis +func.func @concat_v2_neg_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { + // CHECK: stablehlo.concatenate %arg0, %arg1, dim = 0 + + %axis = "tf.Const"() { value = dense<-2> : tensor } : () -> tensor + %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<6x3xf32> + func.return %1 : tensor<6x3xf32> +} + +// ----- + +// CHECK-LABEL: func @concat_v2_1d_axis +func.func @concat_v2_1d_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x6xf32> { + // CHECK: stablehlo.concatenate %arg0, %arg1, dim = 1 + + %axis = "tf.Const"() { value = dense<[1]> : tensor<1xi64> } : () -> tensor<1xi64> + %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<1xi64>) -> tensor<3x6xf32> + func.return %1 : tensor<3x6xf32> +} + +// ----- + +// CHECK-LABEL: func @concat_v2_non_const_axis +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 12 : i32}} { +func.func @concat_v2_non_const_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>, %axis: tensor) -> tensor<3x6xf32> { + // CHECK: "tf.ConcatV2" + %1 = "tf.ConcatV2"(%arg0, %arg1, %axis) : (tensor<3x3xf32>, tensor<3x3xf32>, tensor) -> tensor<3x6xf32> + func.return %1 : tensor<3x6xf32> +} +} + +//===----------------------------------------------------------------------===// +// Pad op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @padv2_1D +func.func @padv2_1D(%arg0: tensor<3xf32>, %arg1: tensor) -> tensor<6xf32> { + %padding = "tf.Const"() { value = dense<[[1, 2]]> : tensor<1x2xi64> } : () -> tensor<1x2xi64> + // CHECK: stablehlo.pad %arg0, %arg1, low = [1], high = [2], interior = [0] + %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3xf32>, tensor<1x2xi64>, tensor) -> tensor<6xf32> + func.return %1 : tensor<6xf32> +} + +// ----- + +// CHECK-LABEL: func @padv2_2D +func.func @padv2_2D(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<6x9xf32> { + %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi64> } : () -> tensor<2x2xi64> + // CHECK: stablehlo.pad %arg0, %arg1, low = [1, 3], high = [2, 4], interior = [0, 0] + %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi64>, tensor) -> tensor<6x9xf32> + func.return %1 : tensor<6x9xf32> +} + +// ----- + +// CHECK-LABEL: func @padv2_i32_paddings +func.func @padv2_i32_paddings(%arg0: tensor<3x2xf32>, %arg1: tensor) -> tensor<6x9xf32> { + %padding = "tf.Const"() { value = dense<[[1,2],[3,4]]> : tensor<2x2xi32> } : () -> tensor<2x2xi32> + // CHECK: stablehlo.pad %arg0, %arg1, low = [1, 3], high = [2, 4], interior = [0, 0] + %1 = "tf.PadV2"(%arg0, %padding, %arg1) : (tensor<3x2xf32>, tensor<2x2xi32>, tensor) -> tensor<6x9xf32> + func.return %1 : tensor<6x9xf32> +} + +// ----- + +// CHECK-LABEL: func @padv2_dynamic +func.func @padv2_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor<1x2xi64>) -> tensor { + // CHECK-NEXT: %[[ZEROS:.*]] = stablehlo.constant dense<0> : tensor<1xi64> + // CHECK-NEXT: %[[RESHAPE:.*]] = stablehlo.reshape %arg2 : (tensor<1x2xi64>) -> tensor<2xi64> + // CHECK-NEXT: %[[SLICE0:.*]] = stablehlo.slice %[[RESHAPE]] [0:1] : (tensor<2xi64>) -> tensor<1xi64> + // CHECK-NEXT: %[[SLICE1:.*]] = stablehlo.slice %[[RESHAPE]] [1:2] : (tensor<2xi64>) -> tensor<1xi64> + // CHECK-NEXT: %[[RESULT:.*]] = stablehlo.dynamic_pad %arg0, %arg1, %[[SLICE0]], %[[SLICE1]], %[[ZEROS]] : (tensor, tensor, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor + // CHECK-NEXT: return %[[RESULT]] : tensor + + %1 = "tf.PadV2"(%arg0, %arg2, %arg1) : (tensor, tensor<1x2xi64>, tensor) -> tensor + func.return %1 : tensor +} + +//===----------------------------------------------------------------------===// +// Identity op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @identity +func.func @identity(%arg0: tensor<1xi32>) -> tensor<1xi32> { + // CHECK-NEXT: return %arg0 : tensor<1xi32> + %0 = "tf.Identity"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> + func.return %0: tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: func @identityN +func.func @identityN(%arg0: tensor<1xi32>, %arg1: tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>) { + // CHECK-NEXT: return %arg0, %arg1 : tensor<1xi32>, tensor<1xf32> + %0:2 = "tf.IdentityN"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xf32>) -> (tensor<1xi32>, tensor<1xf32>) + func.return %0#0, %0#1: tensor<1xi32>, tensor<1xf32> +} + +// ----- + +// CHECK-LABEL: func @stopgradient +func.func @stopgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> { + // CHECK-NEXT: return %arg0 : tensor<1xi32> + %0 = "tf.StopGradient"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> + func.return %0: tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: func @preventgradient +func.func @preventgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> { + // CHECK-NEXT: return %arg0 : tensor<1xi32> + %0 = "tf.PreventGradient"(%arg0) {message = "fin gradients"} : (tensor<1xi32>) -> tensor<1xi32> + func.return %0: tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: func @checkNumerics +func.func @checkNumerics(%arg0: tensor<1xf32>) -> tensor<1xf32> { + // CHECK-NEXT: return %arg0 : tensor<1xf32> + %0 = "tf.CheckNumerics"(%arg0) {message = "check numerics"} : (tensor<1xf32>) -> tensor<1xf32> + func.return %0: tensor<1xf32> +} + +//===----------------------------------------------------------------------===// +// InfeedDequeueTuple legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @infeed_dequeue_tuple +func.func @infeed_dequeue_tuple() -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>) { + // CHECK: [[TOKEN:%.*]] = stablehlo.create_token : !stablehlo.token + // CHECK: [[INFEED:%.*]]:3 = "stablehlo.infeed"([[TOKEN]]) <{infeed_config = ""{{.*}}}> : (!stablehlo.token) -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>, !stablehlo.token) + // CHECK: return [[INFEED]]#0, [[INFEED]]#1 + %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<1x8x4x4xi32>, tensor<1x100x1xf32>) + func.return %0#0, %0#1 : tensor<1x8x4x4xi32>, tensor<1x100x1xf32> +} + +// ----- + +// CHECK-LABEL: func @infeed_dequeue_tuple_dynamic_error +func.func @infeed_dequeue_tuple_dynamic_error() -> (tensor<3x3xf32>, tensor<4x?xf32>) { + // We expect legalization to fail for dynamic shapes: + // CHECK: [[INFEED:%.*]] = "tf.InfeedDequeueTuple"{{.*}} + %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3x3xf32>, tensor<4x?xf32>) + func.return %0#0, %0#1 : tensor<3x3xf32>, tensor<4x?xf32> +} + +// The following op sharding is used: +// Proto debug string: +// type: TUPLE +// tuple_shardings { +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 0 +// } +// Serialized string: +// "\08\02*\08\08\01\1A\01\01\22\01\00" + +// CHECK-LABEL: infeed_dequeue_tuple_sharding +func.func @infeed_dequeue_tuple_sharding() -> tensor<8xi32> { + // CHECK: "stablehlo.infeed" + // An additional sharding is added at the end to account for token result. + // Proto debug string: + // type: TUPLE + // tuple_shardings { + // type: MAXIMAL + // tile_assignment_dimensions: 1 + // tile_assignment_devices: 0 + // } + // tuple_shardings { + // type: MAXIMAL + // tile_assignment_dimensions: 1 + // tile_assignment_devices: 0 + // } + // CHECK-SAME: mhlo.sharding = "\08\02*\08\08\01\1A\01\01\22\01\00*\08\08\01\1A\01\01\22\01\00" + %0 = "tf.InfeedDequeueTuple"() {_XlaSharding = "\08\02*\08\08\01\1A\01\01\22\01\00"} : () -> tensor<8xi32> + func.return %0 : tensor<8xi32> +} + +//===----------------------------------------------------------------------===// +// Nullary op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @const +func.func @const() -> tensor<2xi32> { + // CHECK: stablehlo.constant dense<0> : tensor<2xi32> + %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = dense<0> : tensor<2xi32>} : () -> (tensor<2xi32>) + func.return %0: tensor<2xi32> +} + +// ----- + +// CHECK-LABEL: @const_dynamic_output +func.func @const_dynamic_output() -> tensor<*xi32> { + // CHECK: [[CONST:%.*]] = stablehlo.constant dense<0> : tensor<2xi32> + // CHECK: [[CAST:%.*]] = tensor.cast [[CONST]] : tensor<2xi32> to tensor<*xi32> + %0 = "tf.Const"() {value = dense<0> : tensor<2xi32>} : () -> (tensor<*xi32>) + // CHECK: return [[CAST]] + func.return %0: tensor<*xi32> +} + +// ----- + +// CHECK-LABEL: @opaque_const +func.func @opaque_const() -> tensor>> { + // CHECK-NOT: stablehlo.constant + %0 = "tf.Const"() {device = "", name = "", dtype = "tfdtype$DT_INT32", value = #tf_type : tensor} : () -> tensor>> + func.return %0 : tensor>> +} + +//===----------------------------------------------------------------------===// +// Matmul op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: matmul_notranspose +// CHECK-SAME: (%[[A:.*]]: tensor<5x7xf32>, %[[B:.*]]: tensor<7x11xf32>) +func.func @matmul_notranspose(%a: tensor<5x7xf32>, %b: tensor<7x11xf32>) -> tensor<5x11xf32> { + // CHECK: stablehlo.dot %[[A]], %[[B]] + %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor<5x7xf32>, tensor<7x11xf32>) -> tensor<5x11xf32> + + func.return %0 : tensor<5x11xf32> +} + +// ----- + +// CHECK-LABEL: matmul_transpose_b +// CHECK-SAME: (%[[A:.*]]: tensor<5x7xf32>, %[[B:.*]]: tensor<11x7xf32>) +func.func @matmul_transpose_b(%a: tensor<5x7xf32>, %b: tensor<11x7xf32>) -> tensor<5x11xf32> { + // CHECK: %[[UPDATED_B:.*]] = stablehlo.transpose %[[B]], dims = [1, 0] + // CHECK: stablehlo.dot %[[A]], %[[UPDATED_B]] + %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = true} : (tensor<5x7xf32>, tensor<11x7xf32>) -> tensor<5x11xf32> + + func.return %0 : tensor<5x11xf32> +} + +// ----- + +// CHECK-LABEL: matmul_transpose_both +// CHECK-SAME: (%[[A:.*]]: tensor<7x5xf32>, %[[B:.*]]: tensor<11x7xf32>) +func.func @matmul_transpose_both(%a: tensor<7x5xf32>, %b: tensor<11x7xf32>) -> tensor<5x11xf32> { + // CHECK: %[[UPDATED_A:.*]] = stablehlo.transpose %[[A]] + // CHECK: %[[UPDATED_B:.*]] = stablehlo.transpose %[[B]] + // CHECK: stablehlo.dot %[[UPDATED_A]], %[[UPDATED_B]] + %0 = "tf.MatMul"(%a, %b) {transpose_a = true, transpose_b = true} : (tensor<7x5xf32>, tensor<11x7xf32>) -> tensor<5x11xf32> + + func.return %0 : tensor<5x11xf32> +} + +// Verify that MatMul with ranked inputs are lowered to HLO. +// CHECK-LABEL: matmul_ranked +func.func @matmul_ranked(%a: tensor, %b: tensor<7x?xf32>) -> tensor { + // CHECK: stablehlo.dot + %0 = "tf.MatMul"(%a, %b) {transpose_a = false, transpose_b = false} : (tensor, tensor<7x?xf32>) -> tensor + + func.return %0 : tensor +} + +// Verify SparseMatMul is legalized to dot. +// CHECK-LABEL: test_sparse_mat_mul +func.func @test_sparse_mat_mul(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xf32>) -> tensor<3x5xf32> { + // CHECK: stablehlo.dot + %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xf32>) -> tensor<3x5xf32> + func.return %0: tensor<3x5xf32> +} + +// SparseMatMul where one operand needs to be transposed and the other one not. +// +// CHECK-LABEL: @test_sparse_mat_mul_with_transpose + // CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor<5x4xf32> + // CHECK-SAME: -> tensor<3x5xf32> + // CHECK: %[[TRANSPOSE:.*]] = stablehlo.transpose %[[ARG1]] + // CHECK-SAME: dims = [1, 0] + // CHECK-SAME: -> tensor<4x5xf32> + // CHECK: %[[RESULT:.*]] = stablehlo.dot %[[ARG0]], %[[TRANSPOSE]] + // CHECK-SAME: -> tensor<3x5xf32> + // CHECK: return %[[RESULT]] +func.func @test_sparse_mat_mul_with_transpose(%arg0: tensor<3x4xf32>, %arg1: tensor<5x4xf32>) -> tensor<3x5xf32> { + %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = true} : (tensor<3x4xf32>, tensor<5x4xf32>) -> tensor<3x5xf32> + func.return %0: tensor<3x5xf32> +} + +// SparseMatMul where one operand needs to be casted and the other one not. +// +// CHECK-LABEL: @test_sparse_mat_mul_with_cast + // CHECK-SAME: %[[ARG0:.*]]: tensor<3x4xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor<4x5xbf16> + // CHECK-SAME: -> tensor<3x5xf32> + // CHECK: %[[CAST:.*]] = stablehlo.convert %[[ARG1]] + // CHECK-SAME: -> tensor<4x5xf32> + // CHECK: %[[RESULT:.*]] = stablehlo.dot %[[ARG0]], %[[CAST]] + // CHECK-SAME: -> tensor<3x5xf32> + // CHECK: return %[[RESULT]] +func.func @test_sparse_mat_mul_with_cast(%arg0: tensor<3x4xf32>, %arg1: tensor<4x5xbf16>) -> tensor<3x5xf32> { + %0 = "tf.SparseMatMul"(%arg0, %arg1) {a_is_sparse = true, b_is_sparse = false, transpose_a = false, transpose_b = false} : (tensor<3x4xf32>, tensor<4x5xbf16>) -> tensor<3x5xf32> + func.return %0: tensor<3x5xf32> +} + +//===----------------------------------------------------------------------===// +// MaxPool op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: maxpool_valid_padding +// CHECK-SAME: %[[ARG:.*]]: tensor +func.func @maxpool_valid_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> { + // CHECK: %[[INIT:.*]] = stablehlo.constant dense<-2147483648> : tensor + // CHECK: "stablehlo.reduce_window"(%[[ARG]], %[[INIT]]) + // CHECK-SAME: <{window_dimensions = array, window_strides = array}> + // CHECK: stablehlo.maximum + // CHECK: stablehlo.return + + %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> + func.return %0 : tensor<2x3x5x7xi32> +} + +// ----- + +// CHECK-LABEL: maxpool_same_padding +// CHECK-SAME: %[[ARG:.*]]: tensor +func.func @maxpool_same_padding(%arg0: tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32> { + // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64> + + %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 4, 1]} : (tensor<2x13x25x7xi32>) -> tensor<2x4x7x7xi32> + func.return %0 : tensor<2x4x7x7xi32> +} + +// ----- + +// CHECK-LABEL: maxpool_3d_valid_padding +// CHECK-SAME: %[[ARG:.*]]: tensor +func.func @maxpool_3d_valid_padding(%arg0: tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32> { + // CHECK: %[[INIT:.*]] = stablehlo.constant dense<0xFF800000> : tensor + // CHECK: "stablehlo.reduce_window"(%[[ARG]], %[[INIT]]) + // CHECK-SAME: <{window_dimensions = array, window_strides = array}> + // CHECK: stablehlo.maximum + // CHECK: stablehlo.return + + %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x12x20x7xf32>) -> tensor<2x8x3x5x7xf32> + func.return %0 : tensor<2x8x3x5x7xf32> +} + +// ----- + +// CHECK-LABEL: maxpool_3d_same_padding +// CHECK-SAME: %[[ARG:.*]]: tensor +func.func @maxpool_3d_same_padding(%arg0: tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x7xf32> { + // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64> + + %0 = "tf.MaxPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>) -> tensor<2x8x4x7x7xf32> + func.return %0 : tensor<2x8x4x7x7xf32> +} + +// ----- + +// CHECK-LABEL: maxpool_explicit_padding +func.func @maxpool_explicit_padding(%arg0: tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> { + // CHECK: tf.MaxPool + // TODO(b/165938852): need to support explicit padding in max_pool. + + %0 = "tf.MaxPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "EXPLICIT", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xi32>) -> tensor<2x3x5x7xi32> + func.return %0 : tensor<2x3x5x7xi32> +} + +//===----------------------------------------------------------------------===// +// MaxPoolGrad op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @max_pool_grad_valid +// CHECK-SAME: %[[INPUT:.*]]: tensor<10x24x24x64xf32>, %arg1: tensor<10x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x12x12x64xf32> +func.func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_output: tensor<10x12x12x64xf32>, %grad: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> { + // CHECK: %[[ZERO:.*]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK: %[[RESULT:.*]] = "stablehlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) + // CHECK-SAME: <{window_dimensions = array, window_strides = array}> ({ + // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): + // CHECK: %[[SELECT_RESULT:.*]] = stablehlo.compare GE, %[[VALUE_A]], %[[VALUE_B]], NOTYPE : (tensor, tensor) -> tensor + // CHECK: stablehlo.return %[[SELECT_RESULT]] : tensor + // CHECK: }, { + // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): + // CHECK: %[[SELECT_RESULT:.*]] = stablehlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor + // CHECK: stablehlo.return %[[SELECT_RESULT]] : tensor + // CHECK: }) : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor) -> tensor<10x24x24x64xf32> + // CHECK: return %[[RESULT]] : tensor<10x24x24x64xf32> + %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) { + data_format = "NHWC", + ksize = [1, 2, 2, 1], + padding = "VALID", + strides = [1, 2, 2, 1] + } : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> + func.return %result : tensor<10x24x24x64xf32> +} + +// ----- + +// CHECK-LABEL: @max_pool_3d_grad_valid +// CHECK-SAME: %[[INPUT:.*]]: tensor<10x8x24x24x64xf32>, %arg1: tensor<10x8x12x12x64xf32>, %[[GRAD:.*]]: tensor<10x8x12x12x64xf32> +func.func @max_pool_3d_grad_valid(%orig_input: tensor<10x8x24x24x64xf32>, %orig_output: tensor<10x8x12x12x64xf32>, %grad: tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32> { + // CHECK: %[[ZERO:.*]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK: %[[RESULT:.*]] = "stablehlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) + // CHECK-SAME: <{window_dimensions = array, window_strides = array}> ({ + // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): + // CHECK: %[[SELECT_RESULT:.*]] = stablehlo.compare GE, %[[VALUE_A]], %[[VALUE_B]], NOTYPE : (tensor, tensor) -> tensor + // CHECK: stablehlo.return %[[SELECT_RESULT]] : tensor + // CHECK: }, { + // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): + // CHECK: %[[SELECT_RESULT:.*]] = stablehlo.add %[[VALUE_A]], %[[VALUE_B]] : tensor + // CHECK: stablehlo.return %[[SELECT_RESULT]] : tensor + // CHECK: }) : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor) -> tensor<10x8x24x24x64xf32> + // CHECK: return %[[RESULT]] : tensor<10x8x24x24x64xf32> + %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 2, 2, 1]} : (tensor<10x8x24x24x64xf32>, tensor<10x8x12x12x64xf32>, tensor<10x8x12x12x64xf32>) -> tensor<10x8x24x24x64xf32> + func.return %result : tensor<10x8x24x24x64xf32> +} + +// ----- + +// CHECK-LABEL: @max_pool_grad_same +func.func @max_pool_grad_same(%orig_input: tensor<2x13x25x7xf32>, %orig_output: tensor<2x4x7x7xf32>, %grad: tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> { + // CHECK: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<4x2xi64> + %result = "tf.MaxPoolGrad"(%orig_input, %orig_output, %grad) { + data_format = "NHWC", + ksize = [1, 2, 3, 1], + padding = "SAME", + strides = [1, 4, 4, 1] + } : (tensor<2x13x25x7xf32>, tensor<2x4x7x7xf32>, tensor<2x4x7x7xf32>) -> tensor<2x13x25x7xf32> + func.return %result : tensor<2x13x25x7xf32> +} + +// ----- + +// CHECK-LABEL: @max_pool_3d_grad_same +func.func @max_pool_3d_grad_same(%orig_input: tensor<2x8x13x25x7xf32>, %orig_output: tensor<2x8x4x7x7xf32>, %grad: tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32> { + // CHECK: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi64> + %result = "tf.MaxPool3DGrad"(%orig_input, %orig_output, %grad) {data_format = "NDHWC", ksize = [1, 1, 2, 3, 1], padding = "SAME", strides = [1, 1, 4, 4, 1]} : (tensor<2x8x13x25x7xf32>, tensor<2x8x4x7x7xf32>, tensor<2x8x4x7x7xf32>) -> tensor<2x8x13x25x7xf32> + func.return %result : tensor<2x8x13x25x7xf32> +} + +//===----------------------------------------------------------------------===// +// OneHot op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL:one_hot +func.func @one_hot(%indices: tensor<3xi32>, %on_value: tensor, %off_value: tensor) -> tensor<3x5xf32> { + // CHECK-NEXT: %[[IOTA0:.*]] = stablehlo.iota dim = 0 : tensor<5xi32> + // CHECK-NEXT: %[[BCAST0:.*]] = stablehlo.broadcast_in_dim %[[IOTA0]], dims = [1] : (tensor<5xi32>) -> tensor<3x5xi32> + // CHECK-NEXT: %[[BCAST1:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [0] : (tensor<3xi32>) -> tensor<3x5xi32> + // CHECK-NEXT: %[[CMP0:.*]] = stablehlo.compare EQ, %[[BCAST1]], %[[BCAST0]], NOTYPE : (tensor<3x5xi32>, tensor<3x5xi32>) -> tensor<3x5xi1> + // CHECK-NEXT: %[[BCAST2:.*]] = stablehlo.broadcast %arg1, sizes = [3, 5] : (tensor) -> tensor<3x5xf32> + // CHECK-NEXT: %[[BCAST3:.*]] = stablehlo.broadcast %arg2, sizes = [3, 5] : (tensor) -> tensor<3x5xf32> + // CHECK-NEXT: %[[RESULT:.*]] = stablehlo.select %[[CMP0]], %[[BCAST2]], %[[BCAST3]] : tensor<3x5xi1>, tensor<3x5xf32> + // CHECK-NEXT: return %[[RESULT]] : tensor<3x5xf32> + %depth = "tf.Const"() { value = dense<5> : tensor } : () -> tensor + %result = "tf.OneHot"(%indices, %depth, %on_value, %off_value) {axis = -1 : i64} : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<3x5xf32> + func.return %result : tensor<3x5xf32> +} + +//===----------------------------------------------------------------------===// +// tf.OutfeedEnqueueTuple legalization +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @outfeed_enqueue_tuple +// CHECK-SAME: [[VAL_0:%.*]]: tensor<3xi32>, [[VAL_1:%.*]]: tensor<4xf32>) +func.func @outfeed_enqueue_tuple(%data_1: tensor<3xi32>, %data_2: tensor<4xf32>) -> () { + // CHECK: [[TOKEN:%.*]] = stablehlo.create_token : !stablehlo.token + // CHECK: "stablehlo.outfeed"([[VAL_0]], [[VAL_1]], [[TOKEN]]) <{outfeed_config = ""}> : (tensor<3xi32>, tensor<4xf32>, !stablehlo.token) -> !stablehlo.token + "tf.OutfeedEnqueueTuple"(%data_1, %data_2) : (tensor<3xi32>, tensor<4xf32>) -> () + func.return +} + +//===----------------------------------------------------------------------===// +// Pack op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @pack +func.func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { + // CHECK: stablehlo.reshape {{.*}} : (tensor<2xi32>) -> tensor<1x2xi32> + // CHECK: stablehlo.reshape {{.*}} : (tensor<2xi32>) -> tensor<1x2xi32> + // CHECK: stablehlo.concatenate {{.*}}, {{.*}}, dim = 0 + + %0 = "tf.Pack"(%arg0, %arg1) : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32> + func.return %0 : tensor<2x2xi32> +} + +//===----------------------------------------------------------------------===// +// PartitionedCall op legalization. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @partitioned_call +func.func @partitioned_call(%arg0: tensor) -> tensor { + // CHECK: call @pcall_func(%arg0) : (tensor) -> tensor + %0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @pcall_func} : (tensor) -> (tensor) + func.return %0 : tensor +} + + +func.func @pcall_func(%arg0: tensor) -> tensor { + func.return %arg0 : tensor +} + +// ----- + +// CHECK-LABEL: func @partitioned_call_multi_input +func.func @partitioned_call_multi_input(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: call @pcall_multi_input(%arg0, %arg1) : (tensor, tensor) -> tensor + %0 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_input} : (tensor, tensor) -> (tensor) + func.return %0 : tensor +} + + +func.func @pcall_multi_input(%arg0: tensor, %arg1: tensor) -> tensor { + func.return %arg0 : tensor +} + +// ----- + +// CHECK-LABEL: func @partitioned_call_multi_in_out +func.func @partitioned_call_multi_in_out(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + // CHECK: call @pcall_multi_in_out(%arg0, %arg1) : (tensor, tensor) -> (tensor, tensor) + %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor, tensor) -> (tensor, tensor) + func.return %0, %1 : tensor, tensor +} + + +func.func @pcall_multi_in_out(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + func.return %arg1, %arg0 : tensor, tensor +} + +// CHECK-LABEL: func @unhandled_partitioned_call +func.func @unhandled_partitioned_call(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> (tensor, tensor) { + // The argument types don't match the parameter types for the + // pcall_multi_in_out function. That's fine for a PartitionedCallOp but not + // for a standard CallOp, so this op can't be lowered. + // CHECK: "tf.PartitionedCall" + %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor<*xi32>, tensor<*xi32>) -> (tensor, tensor) + func.return %0, %1 : tensor, tensor +} + + +// CHECK-LABEL: func @unhandled_partitioned_call_2 +func.func @unhandled_partitioned_call_2(%arg0: tensor, %arg1: tensor<*xi32>) -> (tensor, tensor) { + // CHECK: "tf.PartitionedCall" + %0, %1 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @pcall_multi_in_out} : (tensor, tensor<*xi32>) -> (tensor, tensor) + func.return %0, %1 : tensor, tensor +} + +// ----- + +// CHECK-LABEL: func @no_args_and_results +func.func @no_args_and_results() { + // CHECK: call @callee() : () -> () + // CHECK: call @callee() : () -> () + // CHECK: call @callee() : () -> () + "tf.PartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @callee} : () -> () + "tf.StatefulPartitionedCall"() {config = "", config_proto = "", executor_type = "", f = @callee} : () -> () + "tf.LegacyCall"() {config = "", config_proto = "", executor_type = "", f = @callee} : () -> () + func.return +} + +func.func @callee() { + func.return +} + +//===----------------------------------------------------------------------===// +// ReverseV2 op legalization. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @reverse_func_32 +func.func @reverse_func_32(%arg0: tensor<5xi32>) -> tensor<5xi32> { + %axis = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> (tensor<1xi32>) + + // CHECK: [[VAL:%.+]] = stablehlo.reverse %arg0, dims = [0] : tensor<5xi32> + %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi32>) -> tensor<5xi32> + + // CHECK: return [[VAL]] : tensor<5xi32> + func.return %reversed : tensor<5xi32> +} + +// ----- + +// CHECK-LABEL: @reverse_func_64 +func.func @reverse_func_64(%arg0: tensor<5xi32>) -> tensor<5xi32> { + %axis = "tf.Const"() {value = dense<0> : tensor<1xi64>} : () -> (tensor<1xi64>) + + // CHECK: [[VAL:%.+]] = stablehlo.reverse %arg0, dims = [0] : tensor<5xi32> + %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5xi32>, tensor<1xi64>) -> tensor<5xi32> + + // CHECK: return [[VAL]] : tensor<5xi32> + func.return %reversed : tensor<5xi32> +} + +// ----- + +// CHECK-LABEL: @reverse_func_neg +func.func @reverse_func_neg(%arg0: tensor<5x5xi32>) -> tensor<5x5xi32> { + %axis = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>) + + // CHECK: [[VAL:%.+]] = stablehlo.reverse %arg0, dims = [1] : tensor<5x5xi32> + %reversed = "tf.ReverseV2"(%arg0, %axis) : (tensor<5x5xi32>, tensor<1xi32>) -> tensor<5x5xi32> + + // CHECK: return [[VAL]] : tensor<5x5xi32> + func.return %reversed : tensor<5x5xi32> +} + +//===----------------------------------------------------------------------===// +// StatefulPartitionedCall op legalization. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @stateful_partitioned_call +// CHECK-SAME: [[ARG:%.+]]: tensor +func.func @stateful_partitioned_call(%arg0: tensor) -> tensor { + // CHECK: call @stateful_pcall_func([[ARG]]) : (tensor) -> tensor + %0 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @stateful_pcall_func} : (tensor) -> (tensor) + func.return %0 : tensor +} + +func.func @stateful_pcall_func(%arg0: tensor) -> tensor { + func.return %arg0 : tensor +} + +// ----- + +// CHECK-LABEL: func @stateful_partitioned_call_multi_in_out +// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor) +func.func @stateful_partitioned_call_multi_in_out(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + // CHECK: call @stateful_pcall_multi_in_out([[ARG0]], [[ARG1]]) : (tensor, tensor) -> (tensor, tensor) + %0, %1 = "tf.StatefulPartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @stateful_pcall_multi_in_out} : (tensor, tensor) -> (tensor, tensor) + func.return %0, %1 : tensor, tensor +} + +func.func @stateful_pcall_multi_in_out(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + func.return %arg1, %arg0 : tensor, tensor +} + +//===----------------------------------------------------------------------===// +// Elu op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @elu +func.func @elu(%arg0: tensor<1xf32>) -> tensor<1xf32> { + // CHECK-DAG: %[[ZERO:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<1xf32> + // CHECK-DAG: %[[PRED:.*]] = stablehlo.compare GT, %arg0, %[[ZERO]] + // CHECK-DAG: %[[EXP:.*]] = stablehlo.exponential_minus_one %arg0 + // CHECK: %[[RESULT:.*]] = stablehlo.select %[[PRED]], %arg0, %[[EXP]] + // CHECK: return %[[RESULT]] + %0 = "tf.Elu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> + func.return %0: tensor<1xf32> +} + +// ----- + +// CHECK-LABEL: func @elu_grad +// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor) +func.func @elu_grad(%gradients: tensor<4x8xf32>, %features: tensor) -> tensor<4x8xf32> { + // CHECK-DAG: %[[ZERO:.*]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK-DAG: %[[ONE:.*]] = stablehlo.constant dense<1.000000e+00> : tensor + // CHECK-DAG: %[[BCAST0:.*]] = stablehlo.dynamic_broadcast_in_dim %[[ZERO]], {{.*}}, dims = [] : (tensor, tensor<2xindex>) -> tensor + // CHECK-DAG: %[[PRED:.*]] = stablehlo.compare GT, %[[FEATURES]], %[[BCAST0]] + // CHECK-DAG: %[[BCAST1:.*]] = stablehlo.dynamic_broadcast_in_dim %[[ONE]], {{.*}}, dims = [] : (tensor, tensor<2xindex>) -> tensor + // CHECK-DAG: %[[ADD1:.*]] = stablehlo.add %[[FEATURES]], %[[BCAST1]] + // CHECK-DAG: %[[MULGRAD:.*]] = stablehlo.multiply %[[GRADIENTS]], %[[ADD1]] : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32> + // CHECK: %[[RESULT:.*]] = stablehlo.select %[[PRED]], %[[GRADIENTS]], %[[MULGRAD]] + // CHECK: return %[[RESULT]] + %2 = "tf.EluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32> + func.return %2 : tensor<4x8xf32> +} + +//===----------------------------------------------------------------------===// +// Relu op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @relu +func.func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { + // CHECK: %[[ZERO:.*]] = stablehlo.constant dense<0> : tensor<1xi32> + // CHECK: stablehlo.maximum %arg0, %[[ZERO]] + %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> + func.return %0: tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: func @relu_unsigned +func.func @relu_unsigned(%arg0: tensor) -> tensor { + // CHECK: %[[ZERO:.*]] = stablehlo.constant dense<0> : tensor + // CHECK: %[[BCAST0:.*]] = stablehlo.dynamic_broadcast_in_dim %[[ZERO]], {{.*}}, dims = [] + // CHECK: stablehlo.maximum %arg0, %[[BCAST0]] + %0 = "tf.Relu"(%arg0) : (tensor) -> tensor + func.return %0: tensor +} + +// ----- + +// CHECK-LABEL: func @relu6 +func.func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { + // CHECK-DAG: %[[ZERO:.*]] = stablehlo.constant dense<0> : tensor + // CHECK-DAG: %[[SIX:.*]] = stablehlo.constant dense<6> : tensor + // CHECK: stablehlo.clamp %[[ZERO]], %arg0, %[[SIX]] + %0 = "tf.Relu6"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> + func.return %0: tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: func @relu6_unsigned +func.func @relu6_unsigned(%arg0: tensor) -> tensor { + // CHECK-DAG: %[[ZERO:.*]] = stablehlo.constant dense<0> : tensor + // CHECK-DAG: %[[SIX:.*]] = stablehlo.constant dense<6> : tensor + // CHECK: stablehlo.clamp %[[ZERO]], %arg0, %[[SIX]] + %0 = "tf.Relu6"(%arg0) : (tensor) -> tensor + func.return %0: tensor +} + +// ----- + +// CHECK-LABEL: func @leaky_relu +func.func @leaky_relu(%arg0: tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> attributes {tf.entry_function = {}} { + // CHECK-NEXT: %[[ALPHA:.*]] = stablehlo.constant dense<2.000000e-01> + // CHECK-NEXT: %[[ZERO:.*]] = stablehlo.constant dense<0.000000e+00> + // CHECK-NEXT: %[[LEAKY:.*]] = stablehlo.multiply %arg0, %[[ALPHA]] + // CHECK-NEXT: %[[CMP:.*]] = stablehlo.compare GT, %arg0, %[[ZERO]] + // CHECK-NEXT: %[[RES:.*]] = stablehlo.select %[[CMP]], %arg0, %[[LEAKY]] + // CHECK-NEXT: return %[[RES]] : tensor<1x4x4x3xf32> + %0 = "tf.LeakyRelu"(%arg0) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> + func.return %0 : tensor<1x4x4x3xf32> +} + +// ----- + +// CHECK-LABEL: func @leaky_relu_grad +func.func @leaky_relu_grad(%arg0: tensor<1x4x4xf32>, %arg1: tensor<1x4x4xf32>) -> tensor<1x4x4xf32> attributes {tf.entry_function = {}} { + // CHECK-NEXT: %[[ALPHA:.*]] = stablehlo.constant dense<2.000000e-01> + // CHECK-NEXT: %[[ZERO:.*]] = stablehlo.constant dense<0.000000e+00> + // CHECK-NEXT: %[[LEAKYGRAD:.*]] = stablehlo.multiply %[[GRADIENT:.*]], %[[ALPHA]] + // CHECK-NEXT: %[[CMP:.*]] = stablehlo.compare GT, %[[INP:.*]], %[[ZERO]], NOTYPE + // CHECK-NEXT: %[[RES:.*]] = stablehlo.select %[[CMP]], %[[GRADIENT]], %[[LEAKYGRAD]] + // CHECK-NEXT: return %[[RES]] : tensor<1x4x4xf32> + %0 = "tf.LeakyReluGrad"(%arg0, %arg1) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xf32> + func.return %0 : tensor<1x4x4xf32> +} + +// ----- + +// CHECK-LABEL: func @softsign +func.func @softsign(%arg0: tensor<4x10xf32>) -> tensor<4x10xf32> { + // CHECK-NEXT: %[[ONE:.*]] = stablehlo.constant dense<1.000000e+00> + // CHECK-NEXT: %[[ABS:.*]] = stablehlo.abs %{{.*}} + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[ABS]], %[[ONE]] + // CHECK-NEXT: %[[DIV:.*]] = stablehlo.divide %{{.*}}, %[[ADD]] + // CHECK-NEXT: return %[[DIV]] : tensor<4x10xf32> + %0 = "tf.Softsign"(%arg0) : (tensor<4x10xf32>) -> tensor<4x10xf32> + func.return %0 : tensor<4x10xf32> +} + +// ----- + +// CHECK-LABEL: func @softsign_grad +func.func @softsign_grad(%arg0: tensor<4x10xf32>, %arg1: tensor<4x10xf32>) -> tensor<4x10xf32> { + + // CHECK-NEXT: %[[ONE:.*]] = stablehlo.constant dense<1.000000e+00> + // CHECK-NEXT: %[[ABS:.*]] = stablehlo.abs %{{.*}} : tensor<4x10xf32> + // CHECK-NEXT: %[[BROADCAST_ADD:.*]] = stablehlo.add %[[ABS]], %[[ONE]] + // CHECK-NEXT: %[[MUL:.*]] = stablehlo.multiply %[[BROADCAST_ADD]], %[[BROADCAST_ADD]] + // CHECK-NEXT: %[[BROADCAST_DIV:.*]] = stablehlo.divide %{{.*}}, %[[MUL]] + // CHECK-NEXT: return %[[BROADCAST_DIV]] : tensor<4x10xf32> + %0 = "tf.SoftsignGrad"(%arg0, %arg1) : (tensor<4x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32> + func.return %0 : tensor<4x10xf32> +} + +//===----------------------------------------------------------------------===// +// Roll op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @Roll_0D +func.func @Roll_0D(%arg0: tensor<512xi32>, %shift: tensor) -> tensor<512xi32> { + %axis = "tf.Const"() {value = dense<0> : tensor} : () -> (tensor) + // CHECK: %[[AXIS_SIZE:.*]] = stablehlo.constant dense<512> : tensor + // CHECK: %[[T1:.+]] = stablehlo.remainder %arg1, %[[AXIS_SIZE]] : tensor + // CHECK: %[[T2:.+]] = stablehlo.add %[[T1]], %[[AXIS_SIZE]] : tensor + // CHECK: %[[T3:.+]] = stablehlo.remainder %[[T2]], %[[AXIS_SIZE]] : tensor + // CHECK: %[[CONCAT:.+]] = stablehlo.concatenate %arg0, %arg0, dim = 0 + // CHECK: %[[OFFSET:.+]] = stablehlo.subtract %[[AXIS_SIZE]], %[[T3]] : tensor + // CHECK: stablehlo.dynamic_slice %[[CONCAT]], %[[OFFSET]], sizes = [512] + %0 = "tf.Roll"(%arg0, %shift, %axis) {device = ""} : (tensor<512xi32>, tensor, tensor) -> tensor<512xi32> + func.return %0 : tensor<512xi32> +} + +//===----------------------------------------------------------------------===// +// Select op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @select_batch_static +func.func @select_batch_static(%arg0: tensor<2xi1>, %arg1: tensor<2x6x8xi32>, %arg2: tensor<2x6x8xi32>) -> tensor<2x6x8xi32> { + // CHECK: %[[BCAST:.*]] = stablehlo.broadcast_in_dim %arg0, dims = [0] + // CHECK: stablehlo.select %[[BCAST]], %arg1, %arg2 + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2x6x8xi32>, tensor<2x6x8xi32>) -> tensor<2x6x8xi32> + func.return %0: tensor<2x6x8xi32> +} + +// ----- + +// CHECK-LABEL: func @select_batch_static_r1 +func.func @select_batch_static_r1(%arg0: tensor, %arg1: tensor<2x6x8xi32>, %arg2: tensor<2x6x8xi32>) -> tensor<2x6x8xi32> { + // CHECK: stablehlo.select %arg0, %arg1, %arg2 + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor, tensor<2x6x8xi32>, tensor<2x6x8xi32>) -> tensor<2x6x8xi32> + func.return %0: tensor<2x6x8xi32> +} + +// ----- + +// CHECK-LABEL: func @select_batch_static_all_same +func.func @select_batch_static_all_same(%arg0: tensor<2x6x8xi1>, %arg1: tensor<2x6x8xi32>, %arg2: tensor<2x6x8xi32>) -> tensor<2x6x8xi32> { + // CHECK: stablehlo.select %arg0, %arg1, %arg2 + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<2x6x8xi1>, tensor<2x6x8xi32>, tensor<2x6x8xi32>) -> tensor<2x6x8xi32> + func.return %0: tensor<2x6x8xi32> +} + +// ----- + +// CHECK-LABEL: func @select_batch_dynamic_r1 +func.func @select_batch_dynamic_r1(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index + // CHECK-NEXT: %[[SHAPE0:.*]] = shape.shape_of %arg0 : tensor -> tensor<1xindex> + // CHECK-NEXT: %[[SHAPE1:.*]] = shape.shape_of %arg1 : tensor -> tensor<3xindex> + // CHECK-NEXT: %[[SHAPE2:.*]] = shape.shape_of %arg2 : tensor -> tensor<3xindex> + // CHECK-NEXT: %[[SHAPEEQ1:.*]] = shape.cstr_eq %[[SHAPE1]], %[[SHAPE2]] : tensor<3xindex>, tensor<3xindex> + // CHECK-NEXT: %[[HEAD:.*]], %[[TAIL:.*]] = "shape.split_at"(%[[SHAPE1]], %[[C1]]) : (tensor<3xindex>, index) -> (tensor<1xindex>, tensor<2xindex>) + // CHECK-NEXT: %[[SHAPEEQ2:.*]] = shape.cstr_eq %[[SHAPE0]], %[[HEAD]] : tensor<1xindex>, tensor<1xindex> + // CHECK-NEXT: %[[SHAPEEQ:.*]] = shape.assuming_all %[[SHAPEEQ1]], %[[SHAPEEQ2]] + // CHECK-NEXT: %[[ASSUMING:.*]] = shape.assuming %[[SHAPEEQ]] -> (tensor) { + // CHECK-NEXT: %[[BCAST:.*]] = stablehlo.dynamic_broadcast_in_dim %arg0, %[[SHAPE1]], dims = [0] + // CHECK-NEXT: %[[SELECT:.*]] = stablehlo.select %[[BCAST]], %arg1, %arg2 : tensor, tensor + // CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + func.return %0: tensor +} + +// ----- + +// CHECK-LABEL: func @select_batch_dynamic +func.func @select_batch_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK-NEXT: %[[SHAPE0:.*]] = shape.shape_of %arg0 : tensor -> tensor<3xindex> + // CHECK-NEXT: %[[SHAPE1:.*]] = shape.shape_of %arg1 : tensor -> tensor<3xindex> + // CHECK-NEXT: %[[SHAPE2:.*]] = shape.shape_of %arg2 : tensor -> tensor<3xindex> + // CHECK-NEXT: %[[SHAPEEQ1:.*]] = shape.cstr_eq %[[SHAPE1]], %[[SHAPE2]] : tensor<3xindex>, tensor<3xindex> + // CHECK-NEXT: %[[SHAPEEQ2:.*]] = shape.cstr_eq %[[SHAPE0]], %[[SHAPE1]] : tensor<3xindex>, tensor<3xindex> + // CHECK-NEXT: %[[SHAPEEQ3:.*]] = shape.cstr_eq %[[SHAPE1]], %[[SHAPE2]], %[[SHAPE0]], %[[SHAPE1]] : tensor<3xindex>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex> + // CHECK-NEXT: %[[SHAPEEQ:.*]] = shape.assuming %[[SHAPEEQ3]] + // CHECK-NEXT: %[[SELECT:.*]] = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor + // CHECK-NEXT: shape.assuming_yield %[[SELECT]] : tensor + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + func.return %0: tensor +} + +// ----- + +// CHECK-LABEL: testSelectInvalidUnranked +func.func @testSelectInvalidUnranked(%arg0: tensor<6x7xi1>, %arg1: tensor<*xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> { + // CHECK-NEXT: tf.Select + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<6x7xi1>, tensor<*xf16>, tensor<*xf16>) -> tensor<*xf16> + func.return %0: tensor<*xf16> +} + +// ----- + +// CHECK-LABEL: testSelectThenUnranked +func.func @testSelectThenUnranked(%arg0: tensor<3xi1>, %arg1: tensor<*xf16>, %arg2: tensor<3x2xf16>) -> tensor<*xf16> { + // CHECK-NEXT: tf.Select + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<*xf16>, tensor<3x2xf16>) -> tensor<*xf16> + func.return %0: tensor<*xf16> +} + +// ----- + +// CHECK-LABEL: testSelectElseUnranked +func.func @testSelectElseUnranked(%arg0: tensor<3xi1>, %arg1: tensor<3x2xf16>, %arg2: tensor<*xf16>) -> tensor<*xf16> { + // CHECK-NEXT: tf.Select + %0 = "tf.Select"(%arg0, %arg1, %arg2) : (tensor<3xi1>, tensor<3x2xf16>, tensor<*xf16>) -> tensor<*xf16> + func.return %0: tensor<*xf16> +} + +// ----- + +// CHECK-LABEL: func @selectv2_dynamic_ranked +func.func @selectv2_dynamic_ranked(%arg0: tensor<1xi1>, %arg1: tensor<2x?x8xi32>, %arg2: tensor<2x8x8xi32>) -> tensor<2x?x8xi32> { + // CHECK: stablehlo.select + %0 = "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<1xi1>, tensor<2x?x8xi32>, tensor<2x8x8xi32>) -> tensor<2x?x8xi32> + func.return %0: tensor<2x?x8xi32> +} + +//===----------------------------------------------------------------------===// +// Fast Fourier Transform op legalization. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @fft_1D +func.func @fft_1D(%arg0: tensor<8xcomplex>) -> tensor<8xcomplex> { + // CHECK: stablehlo.fft %arg0, type = FFT, length = [8] + %0 = "tf.FFT"(%arg0) : (tensor<8xcomplex>) -> tensor<8xcomplex> + func.return %0 : tensor<8xcomplex> +} + +// ----- + +// CHECK-LABEL: func @ifft_1D +func.func @ifft_1D(%arg0: tensor<8xcomplex>) -> tensor<8xcomplex> { + // CHECK: stablehlo.fft %arg0, type = IFFT, length = [8] + %0 = "tf.IFFT"(%arg0) : (tensor<8xcomplex>) -> tensor<8xcomplex> + func.return %0 : tensor<8xcomplex> +} + +// ----- + +// CHECK-LABEL: func @rfft_1D +func.func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<5xcomplex> { + %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) + // CHECK: stablehlo.fft %arg0, type = RFFT, length = [8] + %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<8xf32>, tensor<1xi32>) -> tensor<5xcomplex> + func.return %0 : tensor<5xcomplex> +} + +// ----- + +// CHECK-LABEL: func @rfft_1D_padded +func.func @rfft_1D_padded(%arg0: tensor<7xf32>) -> tensor<5xcomplex> { + %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) + // CHECK: %[[PADDED:.*]] = stablehlo.pad %arg0, %{{.*}}, low = [0], high = [1], interior = [0] + // CHECK: stablehlo.fft %[[PADDED]], type = RFFT, length = [8] + %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<7xf32>, tensor<1xi32>) -> tensor<5xcomplex> + func.return %0 : tensor<5xcomplex> +} + +// ----- + +// CHECK-LABEL: func @rfft_1D_sliced +func.func @rfft_1D_sliced(%arg0: tensor<2x9xf32>) -> tensor<2x5xcomplex> { + %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) + // CHECK: %[[SLICED:.*]] = stablehlo.slice %arg0 [0:2, 0:8] + // CHECK: stablehlo.fft %[[SLICED]], type = RFFT, length = [8] + %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor<2x9xf32>, tensor<1xi32>) -> tensor<2x5xcomplex> + func.return %0 : tensor<2x5xcomplex> +} + +// ----- + +// CHECK-LABEL: func @irfft_1D +func.func @irfft_1D(%arg0: tensor<8xcomplex>) -> tensor<8xf32> { + %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) + // CHECK: %[[SLICED:.*]] = stablehlo.slice %arg0 [0:5] + // CHECK: stablehlo.fft %[[SLICED]], type = IRFFT, length = [8] + %0 = "tf.IRFFT"(%arg0, %fftlength) : (tensor<8xcomplex>, tensor<1xi32>) -> tensor<8xf32> + func.return %0 : tensor<8xf32> +} + +// ----- + +// CHECK-LABEL: fft_1D_dynamic +func.func @fft_1D_dynamic(%arg0: tensor>) -> tensor<8xcomplex> { + // CHECK: "tf.FFT" + %0 = "tf.FFT"(%arg0) : (tensor>) -> tensor<8xcomplex> + func.return %0 : tensor<8xcomplex> +} + +// ----- + +// CHECK-LABEL: rfft_1D_dynamic +func.func @rfft_1D_dynamic(%arg0: tensor) -> tensor<8xcomplex> { + %fftlength = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) + // CHECK: "tf.RFFT" + %0 = "tf.RFFT"(%arg0, %fftlength) : (tensor, tensor<1xi32>) -> tensor<8xcomplex> + func.return %0 : tensor<8xcomplex> +} + +//===----------------------------------------------------------------------===// +// Shape op legalization. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @shape_1D +func.func @shape_1D(%arg0: tensor) -> tensor<1xi32> { + // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0 + // CHECK: [[TENSOR:%.+]] = arith.index_cast [[SHAPE]] : tensor<1xindex> to tensor<1xi32> + %0 = "tf.Shape"(%arg0) : (tensor) -> tensor<1xi32> + + // CHECK: return [[TENSOR]] + func.return %0 : tensor<1xi32> +} + +// ----- + +// CHECK-LABEL: func @shape_2D +func.func @shape_2D(%arg0: tensor) -> tensor<2xi32> { + // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0 + // CHECK: [[TENSOR:%.+]] = arith.index_cast [[SHAPE]] : tensor<2xindex> to tensor<2xi32> + %0 = "tf.Shape"(%arg0) : (tensor) -> tensor<2xi32> + + // CHECK: return [[TENSOR]] + func.return %0 : tensor<2xi32> +} + +// ----- + +// CHECK-LABEL: func @shape_rankless +func.func @shape_rankless(%arg0: tensor<*xf32>) -> tensor { + // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0 + // CHECK: [[TENSOR:%.+]] = arith.index_cast [[SHAPE]] : tensor to tensor + %0 = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor + + // CHECK: return [[TENSOR]] + func.return %0 : tensor +} + +//===----------------------------------------------------------------------===// +// Transpose op legalization. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @transpose_noop +func.func @transpose_noop(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + %permutation = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> (tensor<2xi64>) + // CHECK: return %arg0 + %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<2x3xf32> + func.return %0 : tensor<2x3xf32> +} + +// ----- + +// CHECK-LABEL: @transpose_2d +func.func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { + %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) + // CHECK: stablehlo.transpose + %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor<3x2xf32> + func.return %0 : tensor<3x2xf32> +} + +// ----- + +// CHECK-LABEL: @transpose_3d_int32 +func.func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { + %permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi32>} : () -> (tensor<3xi32>) + // CHECK: stablehlo.transpose + %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<1x2x3xf32>, tensor<3xi32>) -> tensor<3x2x1xf32> + func.return %0 : tensor<3x2x1xf32> +} + +// ----- + +// CHECK-LABEL: @transpose_3d +func.func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { + %permutation = "tf.Const"() {value = dense<[2, 1, 0]> : tensor<3xi64>} : () -> (tensor<3xi64>) + // CHECK: stablehlo.transpose + %0 = "tf.Transpose"(%arg0, %permutation) : (tensor<1x2x3xf32>, tensor<3xi64>) -> tensor<3x2x1xf32> + func.return %0 : tensor<3x2x1xf32> +} + +// ----- + +// CHECK-LABEL: @transpose_dynamic_2d +func.func @transpose_dynamic_2d(%arg0: tensor) -> tensor<4x?xf32> { + %permutation = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) + // CHECK: stablehlo.transpose + %0 = "tf.Transpose"(%arg0, %permutation) : (tensor, tensor<2xi64>) -> tensor<4x?xf32> + func.return %0 : tensor<4x?xf32> +} + +//===----------------------------------------------------------------------===// +// Unary op legalizations. +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: @abs +func.func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: stablehlo.abs %arg0 : tensor<2xf32> + %0 = "tf.Abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @abs_dynamic +func.func @abs_dynamic(%arg0: tensor) -> tensor { + // CHECK: stablehlo.abs %arg0 : tensor + %0 = "tf.Abs"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @acos +func.func @acos(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: %[[TEMP_0:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<2xf32> + // CHECK: %[[TEMP_1:.*]] = stablehlo.subtract %[[TEMP_0]], %arg0 : tensor<2xf32> + // CHECK: %[[TEMP_2:.*]] = stablehlo.add %arg0, %[[TEMP_0]] : tensor<2xf32> + // CHECK: %[[TEMP_3:.*]] = stablehlo.multiply %[[TEMP_1]], %[[TEMP_2]] : tensor<2xf32> + // CHECK: %[[TEMP_4:.*]] = stablehlo.sqrt %[[TEMP_3]] : tensor<2xf32> + // CHECK: %[[TEMP_5:.*]] = stablehlo.atan2 %[[TEMP_4]], %arg0 : tensor<2xf32> + // CHECK: return %[[TEMP_5]] : tensor<2xf32> + %0 = "tf.Acos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: @acos_complex +func.func @acos_complex(%arg0: tensor<2xcomplex>) -> tensor<2xcomplex> { +// CHECK-NEXT: %[[TEMP_1:.*]] = stablehlo.constant dense<4.33680869E-19> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_2:.*]] = stablehlo.constant dense<0.693147182> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_3:.*]] = stablehlo.constant dense<2.30584283E+20> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_4:.*]] = stablehlo.constant dense<2.30584274E+12> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_5:.*]] = stablehlo.constant dense<2.30584285E+30> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_6:.*]] = stablehlo.constant dense<1.41421354> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_7:.*]] = stablehlo.constant dense<2.30584287E+18> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_8:.*]] = stablehlo.constant dense<1.500000e+00> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_9:.*]] = stablehlo.constant dense<0x7F800000> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_10:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_11:.*]] = stablehlo.constant dense<2.000000e+00> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_12:.*]] = stablehlo.constant dense<5.000000e-01> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_13:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_14:.*]] = stablehlo.real %arg0 : (tensor<2xcomplex>) -> tensor<2xf32> +// CHECK-NEXT: %[[TEMP_15:.*]] = stablehlo.abs %[[TEMP_14]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_16:.*]] = stablehlo.imag %arg0 : (tensor<2xcomplex>) -> tensor<2xf32> +// CHECK-NEXT: %[[TEMP_17:.*]] = stablehlo.abs %[[TEMP_16]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_18:.*]] = stablehlo.maximum %[[TEMP_15]], %[[TEMP_17]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_19:.*]] = stablehlo.compare GE, %[[TEMP_18]], %[[TEMP_7]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_20:.*]] = stablehlo.compare LE, %[[TEMP_15]], %[[TEMP_13]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_21:.*]] = stablehlo.add %[[TEMP_15]], %[[TEMP_13]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_22:.*]] = stablehlo.abs %[[TEMP_21]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_23:.*]] = stablehlo.maximum %[[TEMP_22]], %[[TEMP_17]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_24:.*]] = stablehlo.minimum %[[TEMP_22]], %[[TEMP_17]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_25:.*]] = stablehlo.compare EQ, %[[TEMP_23]], %[[TEMP_24]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_26:.*]] = stablehlo.multiply %[[TEMP_23]], %[[TEMP_6]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_27:.*]] = stablehlo.divide %[[TEMP_24]], %[[TEMP_23]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_28:.*]] = stablehlo.multiply %[[TEMP_27]], %[[TEMP_27]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_29:.*]] = stablehlo.add %[[TEMP_28]], %[[TEMP_13]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_30:.*]] = stablehlo.sqrt %[[TEMP_29]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_31:.*]] = stablehlo.compare EQ, %[[TEMP_30]], %[[TEMP_13]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_32:.*]] = stablehlo.compare GT, %[[TEMP_28]], %[[TEMP_10]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_33:.*]] = stablehlo.and %[[TEMP_31]], %[[TEMP_32]] : tensor<2xi1> +// CHECK-NEXT: %[[TEMP_34:.*]] = stablehlo.multiply %[[TEMP_23]], %[[TEMP_28]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_35:.*]] = stablehlo.divide %[[TEMP_34]], %[[TEMP_11]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_36:.*]] = stablehlo.add %[[TEMP_23]], %[[TEMP_35]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_37:.*]] = stablehlo.multiply %[[TEMP_23]], %[[TEMP_30]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_38:.*]] = stablehlo.select %[[TEMP_33]], %[[TEMP_36]], %[[TEMP_37]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_39:.*]] = stablehlo.select %[[TEMP_25]], %[[TEMP_26]], %[[TEMP_38]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_40:.*]] = stablehlo.subtract %[[TEMP_15]], %[[TEMP_13]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_41:.*]] = stablehlo.abs %[[TEMP_40]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_42:.*]] = stablehlo.maximum %[[TEMP_41]], %[[TEMP_17]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_43:.*]] = stablehlo.minimum %[[TEMP_41]], %[[TEMP_17]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_44:.*]] = stablehlo.compare EQ, %[[TEMP_42]], %[[TEMP_43]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_45:.*]] = stablehlo.multiply %[[TEMP_42]], %[[TEMP_6]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_46:.*]] = stablehlo.divide %[[TEMP_43]], %[[TEMP_42]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_47:.*]] = stablehlo.multiply %[[TEMP_46]], %[[TEMP_46]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_48:.*]] = stablehlo.add %[[TEMP_47]], %[[TEMP_13]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_49:.*]] = stablehlo.sqrt %[[TEMP_48]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_50:.*]] = stablehlo.compare EQ, %[[TEMP_49]], %[[TEMP_13]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_51:.*]] = stablehlo.compare GT, %[[TEMP_47]], %[[TEMP_10]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_52:.*]] = stablehlo.and %[[TEMP_50]], %[[TEMP_51]] : tensor<2xi1> +// CHECK-NEXT: %[[TEMP_53:.*]] = stablehlo.multiply %[[TEMP_42]], %[[TEMP_47]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_54:.*]] = stablehlo.divide %[[TEMP_53]], %[[TEMP_11]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_55:.*]] = stablehlo.add %[[TEMP_42]], %[[TEMP_54]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_56:.*]] = stablehlo.multiply %[[TEMP_42]], %[[TEMP_49]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_57:.*]] = stablehlo.select %[[TEMP_52]], %[[TEMP_55]], %[[TEMP_56]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_58:.*]] = stablehlo.select %[[TEMP_44]], %[[TEMP_45]], %[[TEMP_57]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_59:.*]] = stablehlo.add %[[TEMP_39]], %[[TEMP_58]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_60:.*]] = stablehlo.multiply %[[TEMP_59]], %[[TEMP_12]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_61:.*]] = stablehlo.add %[[TEMP_60]], %[[TEMP_15]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_62:.*]] = stablehlo.multiply %[[TEMP_61]], %[[TEMP_12]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_63:.*]] = stablehlo.multiply %[[TEMP_17]], %[[TEMP_17]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_64:.*]] = stablehlo.add %[[TEMP_39]], %[[TEMP_21]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_65:.*]] = stablehlo.divide %[[TEMP_63]], %[[TEMP_64]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_66:.*]] = stablehlo.subtract %[[TEMP_58]], %[[TEMP_40]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_67:.*]] = stablehlo.add %[[TEMP_65]], %[[TEMP_66]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_68:.*]] = stablehlo.multiply %[[TEMP_62]], %[[TEMP_67]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_69:.*]] = stablehlo.sqrt %[[TEMP_68]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_70:.*]] = stablehlo.divide %[[TEMP_62]], %[[TEMP_64]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_71:.*]] = stablehlo.add %[[TEMP_58]], %[[TEMP_40]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_72:.*]] = stablehlo.divide %[[TEMP_62]], %[[TEMP_71]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_73:.*]] = stablehlo.add %[[TEMP_70]], %[[TEMP_72]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_74:.*]] = stablehlo.sqrt %[[TEMP_73]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_75:.*]] = stablehlo.multiply %[[TEMP_17]], %[[TEMP_74]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_76:.*]] = stablehlo.select %[[TEMP_20]], %[[TEMP_69]], %[[TEMP_75]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_77:.*]] = stablehlo.select %[[TEMP_19]], %[[TEMP_17]], %[[TEMP_76]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_78:.*]] = stablehlo.compare LT, %[[TEMP_15]], %[[TEMP_5]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_79:.*]] = stablehlo.select %[[TEMP_78]], %[[TEMP_4]], %[[TEMP_3]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_80:.*]] = stablehlo.compare GE, %[[TEMP_17]], %[[TEMP_79]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_81:.*]] = stablehlo.select %[[TEMP_80]], %[[TEMP_17]], %[[TEMP_15]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_82:.*]] = stablehlo.select %[[TEMP_80]], %[[TEMP_79]], %[[TEMP_7]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_83:.*]] = stablehlo.compare GE, %[[TEMP_81]], %[[TEMP_82]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_84:.*]] = stablehlo.log %[[TEMP_81]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_85:.*]] = stablehlo.add %[[TEMP_84]], %[[TEMP_2]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_86:.*]] = stablehlo.compare EQ, %[[TEMP_17]], %[[TEMP_9]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_87:.*]] = stablehlo.not %[[TEMP_86]] : tensor<2xi1> +// CHECK-NEXT: %[[TEMP_88:.*]] = stablehlo.and %[[TEMP_80]], %[[TEMP_87]] : tensor<2xi1> +// CHECK-NEXT: %[[TEMP_89:.*]] = stablehlo.divide %[[TEMP_15]], %[[TEMP_17]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_90:.*]] = stablehlo.select %[[TEMP_88]], %[[TEMP_89]], %[[TEMP_10]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_91:.*]] = stablehlo.multiply %[[TEMP_90]], %[[TEMP_90]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_92:.*]] = stablehlo.log_plus_one %[[TEMP_91]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_93:.*]] = stablehlo.multiply %[[TEMP_92]], %[[TEMP_12]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_94:.*]] = stablehlo.add %[[TEMP_85]], %[[TEMP_93]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_95:.*]] = stablehlo.compare LT, %[[TEMP_17]], %[[TEMP_1]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_96:.*]] = stablehlo.compare LT, %[[TEMP_15]], %[[TEMP_13]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_97:.*]] = stablehlo.and %[[TEMP_95]], %[[TEMP_96]] : tensor<2xi1> +// CHECK-NEXT: %[[TEMP_98:.*]] = stablehlo.multiply %[[TEMP_21]], %[[TEMP_40]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_99:.*]] = stablehlo.add %[[TEMP_60]], %[[TEMP_13]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_100:.*]] = stablehlo.divide %[[TEMP_98]], %[[TEMP_99]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_101:.*]] = stablehlo.negate %[[TEMP_100]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_102:.*]] = stablehlo.compare GE, %[[TEMP_15]], %[[TEMP_13]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_103:.*]] = stablehlo.multiply %[[TEMP_63]], %[[TEMP_12]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_104:.*]] = stablehlo.divide %[[TEMP_103]], %[[TEMP_64]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_105:.*]] = stablehlo.multiply %[[TEMP_71]], %[[TEMP_12]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_106:.*]] = stablehlo.add %[[TEMP_104]], %[[TEMP_105]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_107:.*]] = stablehlo.compare LE, %[[TEMP_60]], %[[TEMP_8]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_108:.*]] = stablehlo.divide %[[TEMP_103]], %[[TEMP_66]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_109:.*]] = stablehlo.add %[[TEMP_104]], %[[TEMP_108]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_110:.*]] = stablehlo.subtract %[[TEMP_60]], %[[TEMP_13]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_111:.*]] = stablehlo.select %[[TEMP_107]], %[[TEMP_109]], %[[TEMP_110]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_112:.*]] = stablehlo.select %[[TEMP_102]], %[[TEMP_106]], %[[TEMP_111]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_113:.*]] = stablehlo.select %[[TEMP_97]], %[[TEMP_101]], %[[TEMP_112]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_114:.*]] = stablehlo.multiply %[[TEMP_113]], %[[TEMP_99]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_115:.*]] = stablehlo.sqrt %[[TEMP_114]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_116:.*]] = stablehlo.divide %[[TEMP_17]], %[[TEMP_115]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_117:.*]] = stablehlo.add %[[TEMP_113]], %[[TEMP_115]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_118:.*]] = stablehlo.log_plus_one %[[TEMP_117]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_119:.*]] = stablehlo.select %[[TEMP_97]], %[[TEMP_116]], %[[TEMP_118]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_120:.*]] = stablehlo.select %[[TEMP_83]], %[[TEMP_94]], %[[TEMP_119]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_121:.*]] = stablehlo.real %arg0 : (tensor<2xcomplex>) -> tensor<2xf32> +// CHECK-NEXT: %[[TEMP_122:.*]] = stablehlo.atan2 %[[TEMP_77]], %[[TEMP_121]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_123:.*]] = stablehlo.imag %arg0 : (tensor<2xcomplex>) -> tensor<2xf32> +// CHECK-NEXT: %[[TEMP_124:.*]] = stablehlo.compare LT, %[[TEMP_123]], %[[TEMP_10]] : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> +// CHECK-NEXT: %[[TEMP_125:.*]] = stablehlo.negate %[[TEMP_120]] : tensor<2xf32> +// CHECK-NEXT: %[[TEMP_126:.*]] = stablehlo.select %[[TEMP_124]], %[[TEMP_120]], %[[TEMP_125]] : tensor<2xi1>, tensor<2xf32> +// CHECK-NEXT: %[[TEMP_127:.*]] = stablehlo.complex %[[TEMP_122]], %[[TEMP_126]] : tensor<2xcomplex> +// CHECK-NEXT: return %[[TEMP_127]] : tensor<2xcomplex> + + %0 = "tf.Acos"(%arg0) : (tensor<2xcomplex>) -> tensor<2xcomplex> + func.return %0 : tensor<2xcomplex> +} + +// ----- + +// CHECK-LABEL: @acos_dynamic +func.func @acos_dynamic(%arg0: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: "tf.Acos" + %0 = "tf.Acos"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + +// CHECK-LABEL: func @cast_dynamic_i2f +func.func @cast_dynamic_i2f(%arg0: tensor) -> tensor { + // CHECK: stablehlo.convert %arg0 : (tensor) -> tensor + %0 = "tf.Cast"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @cast_i2f +func.func @cast_i2f(%arg0: tensor<2xi32>) -> tensor<2xf32> { + // CHECK: stablehlo.convert %arg0 : (tensor<2xi32>) -> tensor<2xf32> + %0 = "tf.Cast"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @cast_c2f +func.func @cast_c2f(%arg0: tensor<2xcomplex>) -> tensor<2xf32> { + // CHECK: stablehlo.convert %arg0 : (tensor<2xcomplex>) -> tensor<2xf32> + %0 = "tf.Cast"(%arg0) : (tensor<2xcomplex>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: @ceil +func.func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: stablehlo.ceil %arg0 : tensor<2xf32> + %0 = "tf.Ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @ceil_dynamic +func.func @ceil_dynamic(%arg0: tensor) -> tensor { + // CHECK: stablehlo.ceil %arg0 : tensor + %0 = "tf.Ceil"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @complex_abs +func.func @complex_abs(%arg0: tensor<2xcomplex>) -> tensor<2xf32> { + // CHECK: stablehlo.abs %arg0 : (tensor<2xcomplex>) -> tensor<2xf32> + %0 = "tf.ComplexAbs"(%arg0) : (tensor<2xcomplex>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: @cos +func.func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: stablehlo.cosine %arg0 : tensor<2xf32> + %0 = "tf.Cos"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: @tan +func.func @tan(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: stablehlo.tan %arg0 : tensor<2xf32> + %0 = "tf.Tan"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @cos_dynamic +func.func @cos_dynamic(%arg0: tensor) -> tensor { + // CHECK: stablehlo.cosine %arg0 : tensor + %0 = "tf.Cos"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @exp +func.func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: stablehlo.exponential %arg0 : tensor<2xf32> + %0 = "tf.Exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: @expm1 +func.func @expm1(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: stablehlo.exponential_minus_one %arg0 : tensor<2xf32> + %0 = "tf.Expm1"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @exp_dynamic +func.func @exp_dynamic(%arg0: tensor) -> tensor { + // CHECK: stablehlo.exponential %arg0 : tensor + %0 = "tf.Exp"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @floor +func.func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: stablehlo.floor %arg0 : tensor<2xf32> + %0 = "tf.Floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @floor_dynamic +func.func @floor_dynamic(%arg0: tensor) -> tensor { + // CHECK: stablehlo.floor %arg0 : tensor + %0 = "tf.Floor"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @is_finite +func.func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> { + // CHECK: stablehlo.is_finite %arg0 : (tensor<2xf32>) -> tensor<2xi1> + %0 = "tf.IsFinite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1> + func.return %0 : tensor<2xi1> +} + +// ----- + +// CHECK-LABEL: func @is_finite_dynamic +func.func @is_finite_dynamic(%arg0: tensor) -> tensor { + // CHECK: stablehlo.is_finite %arg0 : (tensor) -> tensor + %0 = "tf.IsFinite"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @log +func.func @log(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: stablehlo.log %arg0 : tensor<2xf32> + %0 = "tf.Log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @log_dynamic +func.func @log_dynamic(%arg0: tensor) -> tensor { + // CHECK: stablehlo.log %arg0 : tensor + %0 = "tf.Log"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @log1p +func.func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: stablehlo.log_plus_one %arg0 : tensor<2xf32> + %0 = "tf.Log1p"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @log1p_dynamic +func.func @log1p_dynamic(%arg0: tensor) -> tensor { + // CHECK: stablehlo.log_plus_one %arg0 : tensor + %0 = "tf.Log1p"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @neg +func.func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: stablehlo.negate %arg0 : tensor<2xf32> + %0 = "tf.Neg"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: func @neg_dynamic +func.func @neg_dynamic(%arg0: tensor) -> tensor { + // CHECK: stablehlo.negate %arg0 : tensor + %0 = "tf.Neg"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @sigmoid +func.func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> { + // CHECK: stablehlo.logistic + %0 = "tf.Sigmoid"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// ----- + +// CHECK-LABEL: @sigmoid_complex +func.func @sigmoid_complex(%arg0: tensor<2xcomplex>) -> tensor<2xcomplex> { + // CHECK: stablehlo.logistic + %0 = "tf.Sigmoid"(%arg0) : (tensor<2xcomplex>) -> tensor<2xcomplex> + func.return %0 : tensor<2xcomplex> +} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir index 29cf4ced5b142e..cabce601fa1396 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir @@ -2751,8 +2751,8 @@ func.func @complex_abs(%arg0: tensor<2xcomplex>) -> tensor<2xf32> { func.return %0 : tensor<2xf32> } -// CHECK-NOT: tfl - +// CHECK: %0 = "tfl.complex_abs"(%arg0) : (tensor<2xcomplex>) -> tensor<2xf32> +// CHECK: return %0 : tensor<2xf32> // ----- func.func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> { @@ -3636,3 +3636,36 @@ func.func @if(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (ten // CHECK: "tfl.yield"(%2) : (tensor) -> () // CHECK: }) : (tensor) -> tensor // CHECK: return %1 : tensor + +// ----- + +//===----------------------------------------------------------------------===// +// mhlo.fft +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: rfft_2d +func.func @rfft_2d(%arg0: tensor<1x512xf32>) -> tensor<1x257xcomplex> { + %0 = "mhlo.fft"(%arg0) <{fft_length = dense<512> : tensor<1xi64>, fft_type = #mhlo}> : (tensor<1x512xf32>) -> tensor<1x257xcomplex> + func.return %0 : tensor<1x257xcomplex> +} + +// CHECK: %cst = arith.constant dense<-2> : tensor +// CHECK: %0 = "tfl.expand_dims"(%arg0, %cst) : (tensor<1x512xf32>, tensor) -> tensor<1x1x512xf32> +// CHECK-DAG: %cst_0 = arith.constant dense<1> : tensor<1xi32> +// CHECK-DAG: %cst_1 = arith.constant dense<512> : tensor<1xi32> +// CHECK: %1 = "tfl.concatenation"(%cst_0, %cst_1) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> +// CHECK: %2 = "tfl.rfft2d"(%0, %1) : (tensor<1x1x512xf32>, tensor<2xi32>) -> tensor<1x1x257xcomplex> +// CHECK: %3 = "tfl.squeeze"(%2) <{squeeze_dims = [-2]}> : (tensor<1x1x257xcomplex>) -> tensor<1x257xcomplex> +// CHECK: return %3 : tensor<1x257xcomplex> + +// ----- + +// CHECK-LABEL: @fft +func.func @fft(%arg0: tensor<3x9xcomplex>) -> tensor<3x9xcomplex> { + %0 = "mhlo.fft"(%arg0) <{ fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo }> : (tensor<3x9xcomplex>) -> tensor<3x9xcomplex> + func.return %0 : tensor<3x9xcomplex> +} + +// CHECK: %0 = "mhlo.fft"(%arg0) <{fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo}> : (tensor<3x9xcomplex>) -> tensor<3x9xcomplex> +// CHECK: return %0 : tensor<3x9xcomplex> + diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc index 2b96254e04fc3d..31ecc93ece9cc8 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc @@ -20,8 +20,8 @@ limitations under the License. #include "llvm/ADT/Sequence.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project // NOLINT: Required to register quantization dialect. -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // NOLINT: Required to register quantization dialect. +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -36,7 +36,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h" @@ -53,7 +53,7 @@ using ::mlir::quant::UniformQuantizedPerAxisType; using ::mlir::quant::UniformQuantizedType; #define GEN_PASS_DEF_COMPOSEUNIFORMQUANTIZEDTYPEPASS -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h.inc" // These strings are used to identify the uniform_quantize / uniform_dequantize // functions. diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_pass.cc index 8ced899f6de108..53ba950a9abc54 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_pass.cc @@ -27,7 +27,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.h" // IWYU pragma: keep -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" // IWYU pragma: keep #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" // IWYU pragma: keep @@ -39,7 +39,7 @@ namespace { // This file is generated from `passes.td` and provides the implementation base // class. #define GEN_PASS_DEF_COMPOSITELOWERINGPASS -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h.inc" class CompositeLoweringPass : public impl::CompositeLoweringPassBase { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td index 6b50e8141447e8..8ede6e261d71aa 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td @@ -120,3 +120,12 @@ def LegalizeCompositeOdmlEmbeddingLookup : Pat< [(HasRank<1> $indices), (I32ElementsVal $indices), (HasRankAtLeast<2> $table)]>; + +def LegalizeCompositeOdmlEmbeddingLookupDynamicShaped : Pat< + (MHLO_CompositeOp:$composite + (variadic $_, $_, $_, $indices, $table), + ConstantStrAttr, $attrs, $_, $_), + (TFL_EmbeddingLookupOp $indices, $table), + [(HasRank<1> $indices), + (I32ElementsVal $indices), + (HasRankAtLeast<2> $table)]>; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_pass.cc index c2b31aeb540720..3ad12741c5d21a 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_pass.cc @@ -35,7 +35,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc index b1735991388680..7626cd362a47f5 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc @@ -39,7 +39,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project @@ -68,7 +68,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/hlo_matchers.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" #include "tensorflow/compiler/mlir/lite/utils/utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -85,7 +85,7 @@ namespace { #define DEBUG_TYPE "tf-legalize-hlo" #define GEN_PASS_DEF_LEGALIZEHLOTOTFPASS -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h.inc" class LegalizeHloToTf : public impl::LegalizeHloToTfPassBase { /// Performs the legalization to the TF dialect. diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD index 29677cb30bcf3b..87d6ec4bcd707b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD @@ -334,3 +334,19 @@ cc_library( "@local_xla//xla/mlir_hlo", ], ) + +cc_library( + name = "fft", + srcs = ["fft.cc"], + hdrs = ["fft.h"], + deps = [ + "//tensorflow/compiler/mlir/lite:constant_utils", + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@local_xla//xla/mlir_hlo", + ], +) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h index ed8b06e036d816..fe9664c13cdccb 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h @@ -24,6 +24,7 @@ limitations under the License. #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.h index 679737b9a25fbf..c7c3bddeacfb21 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CUSTOM_CALL_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CUSTOM_CALL_H_ +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.cc index 3c139e7dbbcdd3..940c75256b9e75 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.cc @@ -35,7 +35,6 @@ limitations under the License. #include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/ValueRange.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.h index 2ea7c96dfbae08..91df1b63e76a7c 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.h @@ -19,6 +19,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_DOT_GENERAL_H_ #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/fft.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/fft.cc new file mode 100644 index 00000000000000..a40da742896d16 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/fft.cc @@ -0,0 +1,206 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/fft.h" + +#include +#include +#include + +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/utils/constant_utils.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir::odml { +namespace { + +// Convert a DenseIntElementsAttr to a vector of int64_t. +std::vector ConvertI64DenseIntAttr(DenseIntElementsAttr attr) { + auto values = attr.getValues(); + return {values.begin(), values.end()}; +} + +// Returns true if the fft op is a supported rfft op. +bool IsSupportedRfftOp(mhlo::FftOp fft_op) { + const auto fft_type = llvm::StringSwitch>( + mlir::mhlo::stringifyFftType(fft_op.getFftType())) + .Case("FFT", mhlo::FftType::FFT) + .Case("RFFT", mhlo::FftType::RFFT) + .Case("IFFT", mhlo::FftType::IFFT) + .Case("IRFFT", mhlo::FftType::IRFFT) + .Default(std::nullopt); + if (!fft_type || *fft_type != mhlo::FftType::RFFT) { + return false; + } + + const std::vector fft_lengths = + ConvertI64DenseIntAttr(fft_op.getFftLength()); + if (fft_lengths.size() != 1) { + return false; + } + + const int fft_len = fft_lengths.back(); + const std::vector input_shape = + mlir::cast(fft_op.getOperand().getType()).getShape(); + if (fft_len != input_shape.back()) { + return false; + } + + auto input_type = + mlir::dyn_cast_or_null(fft_op.getOperand().getType()); + if (!input_type || input_type.getRank() != 2) return false; + + return true; +} + +// Convert rfft to rfft2d. +// The transformation pattern looks like below: +// +// input fft_len +// \ / +// rfft +// +// || +// \/ +// +// input fft_len +// \ / +// expand_dim concat with [1] at the front +// \ / +// rfft_2d +// | +// squeeze +class LegalizeRfftOp : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::FftOp fft_op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final { + const auto fft_type = llvm::StringSwitch>( + mlir::mhlo::stringifyFftType(fft_op.getFftType())) + .Case("FFT", mhlo::FftType::FFT) + .Case("RFFT", mhlo::FftType::RFFT) + .Case("IFFT", mhlo::FftType::IFFT) + .Case("IRFFT", mhlo::FftType::IRFFT) + .Default(std::nullopt); + if (!fft_type || *fft_type != mhlo::FftType::RFFT) { + return rewriter.notifyMatchFailure(fft_op, "Unsupported fft type."); + } + + const auto fft_lengths = + llvm::to_vector(fft_op.getFftLength().getValues()); + if (fft_lengths.size() != 1) { + return rewriter.notifyMatchFailure( + fft_op, "Can only lower a single fft dimension"); + } + + const int fft_len = fft_lengths.back(); + const std::vector input_shape = + mlir::cast(fft_op.getOperand().getType()).getShape(); + if (fft_len != input_shape.back()) { + return rewriter.notifyMatchFailure(fft_op, "Unsupported fft length."); + } + + auto input_type = + mlir::dyn_cast_or_null(fft_op.getOperand().getType()); + if (!input_type || input_type.getRank() != 2) + return rewriter.notifyMatchFailure(fft_op, "Unsupported input type."); + + auto output_type = mlir::cast(fft_op.getResult().getType()); + + // Expanded inputs. + // Insert at -2 location. + auto one_ele_type = mlir::RankedTensorType::get( + llvm::ArrayRef{1}, rewriter.getIntegerType(32)); + auto minus_two = TFL::CreateConstOpWithSingleValue( + &rewriter, fft_op.getLoc(), one_ele_type, -2); + + SmallVector expanded_input_shape; + SmallVector expanded_output_shape; + int expanded_rank = input_type.getRank() + 1; + int r = 0; + for (int i = 0; i < expanded_rank; ++i) { + if (i == expanded_rank - 2) { + expanded_input_shape.push_back(1); + expanded_output_shape.push_back(1); + } else { + expanded_input_shape.push_back(input_type.getDimSize(r)); + expanded_output_shape.push_back(output_type.getDimSize(r)); + r++; + } + } + + auto expaned_input_type = mlir::RankedTensorType::get( + expanded_input_shape, input_type.getElementType()); + TFL::ExpandDimsOp expanded_input = rewriter.create( + fft_op.getLoc(), expaned_input_type, fft_op.getOperand(), + minus_two->getResult()); + + // Expanded fft_len. + auto one_attr = mlir::DenseIntElementsAttr::get(one_ele_type, {1}); + + auto one = rewriter.create(fft_op.getLoc(), one_attr); + auto fft_len_attr = + mlir::DenseIntElementsAttr::get(one_ele_type, {fft_len}); + auto fft_len_const = + rewriter.create(fft_op.getLoc(), fft_len_attr); + + auto expanded_fft_len_type = mlir::RankedTensorType::get( + llvm::ArrayRef{2}, rewriter.getIntegerType(32)); + + TFL::ConcatenationOp expanded_fft_len = + rewriter.create( + fft_op.getLoc(), expanded_fft_len_type, + llvm::SmallVector({one, fft_len_const}), + /*axis*/ 0, /*fused_activation_function*/ "NONE"); + + // Insert the rfft_2d. + auto rfft2d_out_type = mlir::RankedTensorType::get( + expanded_output_shape, output_type.getElementType()); + auto rfft2d = rewriter.create( + fft_op.getLoc(), rfft2d_out_type, expanded_input.getResult(), + expanded_fft_len.getResult()); + + // Insert the squeeze op. + auto squeeze_dim = rewriter.getI64ArrayAttr({-2}); + auto squeeze = rewriter.create( + fft_op.getLoc(), output_type, rfft2d.getResult(), squeeze_dim); + + rewriter.replaceOp(fft_op, squeeze.getResult()); + return success(); + } +}; + +// Returns true if the fft op is a legal fft op. +bool IsLegalFftOp(mhlo::FftOp fft_op) { return !IsSupportedRfftOp(fft_op); } + +} // namespace + +void PopulateFftPatterns(MLIRContext* ctx, RewritePatternSet& patterns, + ConversionTarget& target) { + patterns.add(ctx); + target.addDynamicallyLegalOp(IsLegalFftOp); +} + +} // namespace mlir::odml diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/fft.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/fft.h new file mode 100644 index 00000000000000..7b9491caf6fa0e --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/fft.h @@ -0,0 +1,30 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_FFT_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_FFT_H_ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir::odml { + +// Patterns to legalize mhlo.fft to TFL. +void PopulateFftPatterns(MLIRContext* ctx, RewritePatternSet& patterns, + ConversionTarget& target); + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_FFT_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/get_dimension_size.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/get_dimension_size.h index 74081e1e04716e..6cd637303b3a7e 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/get_dimension_size.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/get_dimension_size.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_GET_DIMENSION_SIZE_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_GET_DIMENSION_SIZE_H_ +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir::odml { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.h index a53bdeda2a2097..7d4f76bd3f8b6b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_IOTA_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_IOTA_H_ +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir::odml { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h index 3c2c8ae5ced600..16a5c293b0e989 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir::odml { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.cc index ae784414e4bac0..f237a7168e5660 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/hlo_matchers.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h index 82c4f88937d061..3bf03aec97dcfe 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_H_ +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" // IWYU pragma: keep diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h index 9bbb1f3fde06ab..c293bad98cf4ef 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SORT_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SORT_H_ +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir::odml { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h index 01c619cbbf6178..28661d299e03df 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/Region.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/while.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/while.h index 129b19388821c9..3b3022153b2d43 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/while.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/while.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_WHILE_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_WHILE_H_ +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc index fdad31b31b4b86..113b1b51fa7b55 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc @@ -30,7 +30,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" #define DEBUG_TYPE "composite-to-custom" @@ -38,12 +38,13 @@ namespace mlir { namespace odml { #define GEN_PASS_DEF_LEGALIZECOMPOSITETOCUSTOMOPPASS -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h.inc" namespace { bool IsSupportedComposite(::mlir::stablehlo::CompositeOp op) { // List of supported composites to represent using CustomOp. - return llvm::is_contained({"odml.update_kv_cache"}, op.getName()); + return llvm::is_contained( + {"odml.update_kv_cache", "odml.update_external_kv_cache"}, op.getName()); } bool IsKVCacheCompositeOp(::mlir::stablehlo::CompositeOp op) { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_custom_call_to_composite.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_custom_call_to_composite.cc index e699c303bbaac2..113293596536c9 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_custom_call_to_composite.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_custom_call_to_composite.cc @@ -27,13 +27,13 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" namespace mlir { namespace odml { #define GEN_PASS_DEF_LEGALIZESTABLEHLOCUSTOMCALLTOCOMPOSITEPASS -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h.inc" struct ReplaceCustomCallWithComposite final : OpRewritePattern { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc index 75782674b9a02c..e628cfded7fe6b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc @@ -41,7 +41,7 @@ limitations under the License. #include "stablehlo/transforms/Passes.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/core/macros.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" #define DEBUG_TYPE "compat-passes" @@ -50,7 +50,7 @@ namespace odml { #define GEN_PASS_DEF_LEGALIZESTABLEHLOTOVHLOPASS #define GEN_PASS_DEF_LEGALIZEVHLOTOSTABLEHLOPASS -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h.inc" namespace { @@ -58,7 +58,7 @@ namespace { // StableHLO --> VHLO types //===----------------------------------------------------------------------===// -std::optional MaterializeIllegalCast(OpBuilder &builder, Type type, +Value MaterializeIllegalCast(OpBuilder &builder, Type type, ValueRange inputs, Location loc) { return builder.create(loc, type, inputs) ->getResult(0); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf.cc new file mode 100644 index 00000000000000..0e7f1744d5fb63 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf.cc @@ -0,0 +1,6911 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements logic for lowering TensorFlow dialect to XLA dialect. +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/ChloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" +#include "xla/client/lib/conv_grad_size_util.h" +#include "xla/client/padding.h" +#include "xla/client/sharding_builder.h" +#include "xla/hlo/translate/hlo_to_mhlo/attribute_importer.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/mlir_hlo/utils/convert_op_folder.h" +#include "xla/mlir_hlo/utils/hlo_utils.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/framework/kernel_shape_util.h" +#include "tensorflow/core/framework/rng_alg.h" +#include "tensorflow/core/kernels/conv_grad_shape_utils.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" +#include "tsl/platform/bfloat16.h" +#include "tsl/platform/status.h" +#include "tsl/platform/tensor_float_32_utils.h" + +namespace mlir { +// Keep this in the mlir namespace to allow the use of the mhlo ops. +namespace mhlo { +namespace { + +// The utils are copied into the odml namespace to avoid duplicate names and +// they are imported here to avoid having to change the code below. +using ::mlir::odml::BuildReduceBody; +using ::mlir::odml::GetI64ElementsAttr; +using ::mlir::odml::GetScalarConstOfType; +using ::mlir::odml::GetScalarNegZeroOfType; + +constexpr char kShardingAttr[] = "mhlo.sharding"; + +/// Returns the feature dimension for the given format and input type. +static size_t GetFeatureDimension(tensorflow::TensorFormat format, + RankedTensorType input_ty) { + return GetTensorFeatureDimIndex(input_ty.getRank(), format); +} + +// Gets all integer values from the given attribute and push them to `values`. +void GetI64ArrayAttrValues(Attribute attr, SmallVectorImpl *values) { + auto array_attr = mlir::cast(attr); + values->reserve(array_attr.getValue().size()); + for (Attribute val : array_attr.getValue()) + values->push_back(mlir::cast(val).getValue().getSExtValue()); +} + +// Returns 1D 32-bit dense elements attribute with the given values. +static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef values, + Builder *builder) { + RankedTensorType ty = tensorflow::GetTypeFromTFTensorShape( + {static_cast(values.size())}, builder->getIntegerType(32)); + return DenseIntElementsAttr::get(ty, values); +} + +// Returns a 1-d i64 elements attribute populated with numbers from start to +// end, excluding. +static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end, + Builder *builder) { + int size = end - start; + + SmallVector vals; + vals.resize(size); + std::iota(vals.begin(), vals.end(), start); + + TensorType ty = + tensorflow::GetTypeFromTFTensorShape({size}, builder->getIntegerType(64)); + return DenseIntElementsAttr::get(ty, vals); +} + +// Returns a 1-d i64 elements attribute populated with `val` repeated `size` +// times. +static DenseIntElementsAttr GetI64ElementsAttrForValue(int size, int64_t val, + Builder *builder) { + TensorType ty = + tensorflow::GetTypeFromTFTensorShape({size}, builder->getIntegerType(64)); + return DenseIntElementsAttr::get(ty, val); +} + +// Returns the corresponding type that should be used for performing sum +// accumulation over the given input type. +Type GetSumAccumulationType(Type input_type) { + MLIRContext *ctx = input_type.getContext(); + if (input_type.isBF16() || input_type.isF16()) return FloatType::getF32(ctx); + if (input_type.isSignlessInteger(8) || input_type.isSignlessInteger(16)) + return IntegerType::get(ctx, 32); + return input_type; +} + +// Returns axis in HLO format from TF elements attr with exactly one element or +// is an IntegerAttr, containing axis in the TensorFlow format. TensorFlow +// format supports negative indexing unlike HLO. +static IntegerAttr GetHLOAxisFromTFAxis(Attribute attr, int64_t rank, + Builder *b) { + IntegerAttr intAttr = mlir::dyn_cast_or_null(attr); + if (auto elementAttr = mlir::dyn_cast_or_null(attr)) { + SmallVector index(elementAttr.getShapedType().getRank(), 0); + intAttr = elementAttr.getValues()[index]; + } + + assert(intAttr && "Invalid attribute passed to GetHLOAxisFromTFAxis"); + + int64_t axis = intAttr.getInt(); + if (axis < 0) { + axis += rank; + } + return b->getI64IntegerAttr(axis); +} + +// Returns a PrecisionConfig as an array attribute based on whether TF32 +// execution is enabled +static ArrayAttr GetPrecisionConfig(Builder *builder) { + mlir::mhlo::Precision precision = tsl::tensor_float_32_execution_enabled() + ? mhlo::Precision::DEFAULT + : mlir::mhlo::Precision::HIGHEST; + llvm::SmallVector attr_vec; + const int num_inputs = 2; + for (int i = 0; i < num_inputs; i++) { + attr_vec.push_back( + mlir::mhlo::PrecisionAttr::get(builder->getContext(), precision)); + } + return builder->getArrayAttr(attr_vec); +} + +// If `value` is an IntegerAttr, returns the integer value for the HLO axis +// corresponding to the tensorflow axis. In particular, the tensorflow axis can +// be negative, in which case, the corresponding HLO axis is +// (axis + rank-of-the-tensor). +static std::optional GetIntegerHLOAxisFromTFAxis(Value value, + int64_t rank) { + DenseIntElementsAttr attrs; + if (!matchPattern(value, m_Constant(&attrs)) || + attrs.getType().getRank() != 0) { + return std::nullopt; + } + int64_t axis = attrs.getValues()[0].getInt(); + return axis < 0 ? axis + rank : axis; +} + +/// Returns a `ConvertOp` that casts the elements to a i64 type while retaining +/// the shape of the input value. +static ConvertOp CastValueToI64(Location loc, Value value, + PatternRewriter *rewriter) { + return rewriter->create(loc, value, rewriter->getIntegerType(64)); +} + +// Creates an unpack op along the 0th dimension of the tensor. The `value` input +// must be a ranked tensor. +static TF::UnpackOp UnpackTensorAlongZeroDim(Location loc, Value value, + PatternRewriter *rewriter) { + auto indices_type = mlir::cast(value.getType()); + int num_outputs = indices_type.getShape().front(); + SmallVector unpacked_indices_type( + num_outputs, + tensorflow::GetTypeFromTFTensorShape({}, indices_type.getElementType())); + auto unpacked_indices = rewriter->create( + loc, unpacked_indices_type, value, + IntegerAttr::get(rewriter->getIntegerType(64), 0)); + return unpacked_indices; +} + +// Returns size of dimension at the specified index, if ranked tensor. +// Otherwise, returns -1. +// +// Aborts if the type is ranked but doesn't have the dimension. +int64_t GetDimSize(Type ty, int64_t index) { + RankedTensorType ranked_ty = mlir::dyn_cast(ty); + if (!ranked_ty) return -1; + + return ranked_ty.getDimSize(index); +} + +template +tensorflow::TensorShape ToTensorShape(llvm::ArrayRef sizes) { + return tensorflow::TensorShape( + llvm::SmallVector(sizes.begin(), sizes.end())); +} + +template +tensorflow::TensorShape ToTensorShape( + llvm::iterator_range> sizes) { + return tensorflow::TensorShape( + llvm::SmallVector(sizes.begin(), sizes.end())); +} + +// Returns a limit scalar const op for the given type. +// Requires FloatType or IntegerType +static ConstantOp GetScalarLimitConstOfType(Type ty, Location loc, + hlo::ScalarLimit limit, + OpBuilder *builder) { + return builder->create(loc, hlo::getScalarLimitOfType(ty, limit)); +} + +// Deprecated: This is maintained to aid in porting old code that is not yet +// dynamic shape aware and uses broadcasting modes that CHLO does not support. +// Gets the resulting type from a broadcast between two types for statically +// shaped types. This is to be used for legacy lowerings that both use non +// left-padded broadcasting and static shapes. Its use should not be permitted +// in new code. +// May return nullptr on invalid static broadcast dimensions. +// ABSL_DEPRECATED() +static RankedTensorType GetStaticBroadcastType( + RankedTensorType x, RankedTensorType y, + DenseIntElementsAttr broadcast_dimensions_attr) { + auto element_type = x.getElementType(); + auto shape_x = x.getShape(); + auto shape_y = y.getShape(); + + if (shape_x.size() == shape_y.size()) { + llvm::SmallVector out_shape(shape_x.size()); + for (int i = 0; i < shape_x.size(); i++) { + auto x_val = shape_x[i]; + auto y_val = shape_y[i]; + out_shape[i] = std::max(x_val, y_val); + } + return tensorflow::GetTypeFromTFTensorShape(out_shape, element_type); + } + + auto shape_large = shape_x.size() > shape_y.size() ? shape_x : shape_y; + auto shape_small = shape_x.size() <= shape_y.size() ? shape_x : shape_y; + + llvm::SmallVector broadcast_dimensions; + // Explicit broadcast dimensions. + for (const APInt &int_value : broadcast_dimensions_attr) { + broadcast_dimensions.push_back(int_value.getSExtValue()); + } + if (broadcast_dimensions.size() != shape_small.size()) { + return nullptr; + } + llvm::SmallVector out_shape(shape_large.begin(), + shape_large.end()); + + // Update according to the broadcast dimensions. + for (const auto &index_pair : llvm::enumerate(broadcast_dimensions)) { + auto old_value = out_shape[index_pair.value()]; + auto new_value = shape_small[index_pair.index()]; + out_shape[index_pair.value()] = std::max(old_value, new_value); + } + return tensorflow::GetTypeFromTFTensorShape(out_shape, element_type); +} + +// Deprecated: This is maintained to aid in porting old code that is not yet +// dynamic shape aware and uses broadcasting modes that CHLO does not support. +// Applies static binary broadcasting to a binary elementwise op. +// This is a legacy helper to provide general broadcasting support in legacy, +// static shaped code that relies on non-left-padded broadcasting semantics. +template +static Value StaticBinaryBroadcast(Location loc, Value x, Value y, + DenseIntElementsAttr broadcast_dims, + OpBuilder &builder) { + auto x_type = mlir::cast(x.getType()); + auto y_type = mlir::cast(y.getType()); + auto result_type = GetStaticBroadcastType(x_type, y_type, broadcast_dims); + if (!result_type) { + emitError(loc) << "could not binary broadcast " << x_type << ", " << y_type + << " with broadcast_dims = " << broadcast_dims; + return nullptr; + } + auto larger_broadcast_dims = + GetI64ElementsAttrForSeq(0, result_type.getRank(), &builder); + if (x_type.getRank() < y_type.getRank()) { + if (x_type != result_type) { + x = builder.create(loc, result_type, x, broadcast_dims); + } + if (y_type != result_type) { + y = builder.create(loc, result_type, y, + larger_broadcast_dims); + } + } else { + if (x_type != result_type) { + x = builder.create(loc, result_type, x, + larger_broadcast_dims); + } + if (y_type != result_type) { + y = builder.create(loc, result_type, y, broadcast_dims); + } + } + return builder.create(loc, x, y); +} + +// Gets a 1D tensor type suitable for expressing extents of the given tensor +// value type. If the value type is ranked, the result will be statically +// shaped. Otherwise, it will have a dynamic dimension. +static RankedTensorType GetExtentsTensorTypeFor(TensorType value_type) { + Builder b(value_type.getContext()); + int64_t dim = value_type.hasRank() ? value_type.getRank() : -1; + return tensorflow::GetTypeFromTFTensorShape({dim}, b.getIndexType()); +} + +// Given a value (broadcast_to) and a feature dimension, broadcasts a 1D +// value (broadcast_from) along that feature dimension. This is a shortcut +// for the cases where a 1D tensor must be broadcast along a specific feature +// dimension, which can vary based on data layout, etc. +// +// The extent of `broadcast_from` dim0 must be equal to the extent of the +// feature_dim of `broadcast_to`. +// +// Example: +// [1x2x3x4], [2], 1 -> [1x2x3x4] +// TODO(laurenzo): Swap the order of broadcast_to and broadcast_from for +// consistency. Possibly also rename for clarity. +static Value Broadcast1DToFeatureDim(Location loc, Value broadcast_to, + Value broadcast_from, int64_t feature_dim, + OpBuilder &builder) { + auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &builder); + auto to_type = mlir::cast(broadcast_to.getType()); + auto result_shape = builder.create(loc, broadcast_to); + auto result_extents_type = GetExtentsTensorTypeFor(to_type); + auto result_extents = builder.create( + loc, result_extents_type, result_shape); + return builder.create( + loc, to_type, broadcast_from, result_extents, broadcast_dims); +} + +// Broadcasts `input` to the shape of `broadcast_to` value following +// TF::BroadcastTo semantics. +// +// Requires that input is a ranked tensor. +// +// TODO(hinsu): Utilize TF::ShapeOp followed by TF::BroadcastTo once ShapeOp +// supports unranked inputs in the lowering. +static Value BroadcastToShapeOf(Location loc, Value input, Value broadcast_to, + OpBuilder &builder) { + auto result_shape = builder.create(loc, broadcast_to); + auto to_type = mlir::cast(broadcast_to.getType()); + auto result_extents_type = GetExtentsTensorTypeFor(to_type); + auto result_extents = builder.create( + loc, result_extents_type, result_shape); + int64_t rank = mlir::cast(input.getType()).getRank(); + auto broadcast_dims = GetI64ElementsAttrForSeq(0, rank, &builder); + return builder.create( + loc, to_type, input, result_extents, broadcast_dims); +} + +// Builds a set of operations for applying reduction on the input value. A +// tf.sum op is created and will be legalized to tfl ops automatically. +static Value ApplyReduction(Location loc, Value input, + DenseIntElementsAttr reduce_dims, + OpBuilder *builder) { + auto reduce_dims_op = builder->create(loc, reduce_dims); + return builder->create(loc, input, reduce_dims_op, + builder->getBoolAttr(false)); +} + +// Creates a mhlo.rng_uniform op with `builder` to generate `num_elements` +// 32-bit integer numbers in the range of [`lower_limit`, `upper_limit`). +static mhlo::RngOp CreateRngUniform32(Location loc, int num_elements, + int lower_limit, int upper_limit, + OpBuilder *builder) { + auto shape_tensor = builder->create( + loc, GetI64ElementsAttr({num_elements}, builder)); + + auto lower = builder->create( + loc, builder->getI32IntegerAttr(lower_limit)); + auto upper = builder->create( + loc, builder->getI32IntegerAttr(upper_limit)); + + return builder->create(loc, lower, upper, shape_tensor, + ::mlir::mhlo::RngDistribution::UNIFORM); +} + +using WhileBodyFnType = llvm::function_ref old_values, + SmallVectorImpl *new_values, OpBuilder *builder)>; + +// Creates a mhlo.while op with `builder` to loop `num_interations` times, +// each time calling the given `body_fn` on a set of values to generate a new +// set of values. Returns the final set of values via `final_values`. The +// initial set of values is passed in via `init_values`. +// +// This effectively does: +// +// ```c++ +// SmallVector old_values = init_values; +// SmallVector new_values; +// for (int i = 0; i < num_iterations; ++i) { +// body_fn(old_values, &new_values, ...); +// old_values = new_values; +// } +// ``` +// +// Under the hood an induction variable is prepended to values to control the +// number of iterations, but that is transparent to `body_fn`, which does not +// need to care about that. +static void CreateWhile32(Location loc, int num_iterations, + WhileBodyFnType body_fn, ArrayRef init_values, + SmallVectorImpl *final_values, + OpBuilder *builder) { + int value_count = init_values.size() + 1; + + // Prepend a loop induction variable to the initial values. + SmallVector init_values_with_loop_iv; + SmallVector init_types_with_loop_iv; + init_values_with_loop_iv.reserve(value_count); + init_types_with_loop_iv.reserve(value_count); + + // The initial value for the loop induction variable is 0. + init_values_with_loop_iv.push_back( + builder->create(loc, builder->getI32IntegerAttr(0))); + init_values_with_loop_iv.append(init_values.begin(), init_values.end()); + + // Accumulate types of all the init values. + for (const auto &init_value_with_loop_iv : init_values_with_loop_iv) + init_types_with_loop_iv.push_back(init_value_with_loop_iv.getType()); + + // Create the while op. + auto while_op = builder->create(loc, init_types_with_loop_iv, + init_values_with_loop_iv); + auto ivs_count = init_types_with_loop_iv.size(); + + { + OpBuilder::InsertionGuard guard(*builder); + + // Build up the only block in the condition region. + Region &condition = while_op.getCond(); + Block *block = builder->createBlock(&condition); + block->addArguments(init_types_with_loop_iv, + SmallVector(ivs_count, loc)); + + // Get the loop induction variable and compare it against the upper limit. + auto loop_iv = block->getArgument(0); + auto upper_limit = builder->create( + loc, builder->getI32IntegerAttr(num_iterations)); + Value compare = builder->create(loc, loop_iv, upper_limit, + ComparisonDirection::LT); + + builder->create(loc, compare); + } + + { + OpBuilder::InsertionGuard guard(*builder); + + // Build up the only block in the body region. + Region &body = while_op.getBody(); + Block *block = builder->createBlock(&body); + block->addArguments(init_types_with_loop_iv, + SmallVector(ivs_count, loc)); + + SmallVector new_values; // Generated by this iteration + new_values.reserve(value_count); + + // Feed all values excluding the loop induction variable to body_fn. + body_fn(loc, block->getArgument(0), + ArrayRef(block->getArguments().begin() + 1, + block->getArguments().end()), + &new_values, builder); + + // Increment the loop induction variable by one. + auto one = + builder->create(loc, builder->getI32IntegerAttr(1)); + auto scalar_broadcast_dims = builder->getDenseI64ArrayAttr({}); + auto plus_one = builder->create( + loc, block->getArgument(0), one, scalar_broadcast_dims); + // Prepend with the updated loop induction variable. + new_values.insert(new_values.begin(), plus_one); + + builder->create(loc, new_values); + } + + // TODO(jpienaar): Support multi-operand while op. + final_values->reserve(init_values.size()); + for (int i = 0, e = init_values.size(); i < e; ++i) + final_values->push_back(while_op.getResult(i + 1)); +} + +//===----------------------------------------------------------------------===// +// BatchNorm op utilities. +//===----------------------------------------------------------------------===// + +static IntegerAttr getFeatureDimensionAttr(Builder &b, + tensorflow::TensorFormat format, + Value input) { + return b.getI64IntegerAttr(GetFeatureDimension( + format, mlir::cast(input.getType()))); +} + +//===----------------------------------------------------------------------===// +// FFT op utilities. +//===----------------------------------------------------------------------===// + +// Returns the 1D i64 elements attribute populated with the inner-most dim of +// the value. +static DenseIntElementsAttr GetInnerDimFromValue(ShapedType type, + Builder *builder) { + if (type.getRank() == 0) { + return builder->getI64TensorAttr({}); + } + return builder->getI64TensorAttr(type.getShape().back()); +} + +// Returns True if the inner-most dim is static. +bool CheckInnerDimStatic(ShapedType type, Builder *builder) { + if (!type.hasRank()) { + return false; + } + return !type.isDynamicDim(type.getShape().size() - 1); +} + +//===----------------------------------------------------------------------===// +// MatMul op utilities. +//===----------------------------------------------------------------------===// + +// If the 'transpose' attribute is true returns ElementsAttr to transpose 2D +// matrix. Otherwise, returns ElementsAttr for identity transpose. +static DenseIntElementsAttr Get2DTransposePerm(BoolAttr transpose, Builder *b) { + if (transpose.getValue()) return GetI64ElementsAttr({1, 0}, b); + return GetI64ElementsAttr({0, 1}, b); +} + +//===----------------------------------------------------------------------===// +// Pad op utilities. +//===----------------------------------------------------------------------===// + +// Slices input attribute of rank two and returns the specified column. +// +// Always returns 64 bit integer attribute regardless of bitwidth of the input +// attribute. +static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D( + ElementsAttr input, int column) { + auto int_attr = mlir::cast(input); + auto shaped_type = int_attr.getType(); + auto shape = shaped_type.getShape(); + + if (shape.size() != 2) return DenseIntElementsAttr(); + + llvm::SmallVector values; + values.reserve(shaped_type.getNumElements() / shape[1]); + + for (const auto &it : llvm::enumerate(int_attr.getValues())) { + if (static_cast(it.index() % shape[1]) == column) { + values.push_back(it.value().getSExtValue()); + } + } + + auto element_type = IntegerType::get(input.getContext(), 64); + return DenseIntElementsAttr::get( + tensorflow::GetTypeFromTFTensorShape({shape[0]}, element_type), values); +} + +// Returns interior padding to use in HLO Pad op based on the TensorFlow padding +// in TensorFlow PadV2 op. +static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) { + auto length = tf_padding.getShapedType().getShape()[0]; + auto element_type = IntegerType::get(tf_padding.getContext(), 64); + return DenseIntElementsAttr::get( + tensorflow::GetTypeFromTFTensorShape({length}, element_type), 0); +} + +//===----------------------------------------------------------------------===// +// Binary op utilities. +//===----------------------------------------------------------------------===// + +// Returns whether the two values are guaranteed to be broadcastable to the +// same shape, this broadcasts size 1 tensors up to any rank. Dynamic dimensions +// must be broadcasted with a size 1 tensor or another dynamic dimension. +// Returns false on rankless. +static bool AreBroadcastCompatible(Value x, Value y) { + auto x_rankless = mlir::dyn_cast(x.getType()); + auto y_rankless = mlir::dyn_cast(y.getType()); + if (!x_rankless || !y_rankless) { + return false; + } + + // Check that the shapes can be broadcasted. + auto shape_x = x_rankless.getShape(); + auto shape_y = y_rankless.getShape(); + + int rank_diff = shape_x.size() - shape_y.size(); + int offset_x = rank_diff > 0 ? rank_diff : 0; + int offset_y = rank_diff < 0 ? -rank_diff : 0; + for (int i = 0, s = std::min(shape_x.size(), shape_y.size()); i < s; i++) { + int index_x = i + offset_x; + int index_y = i + offset_y; + if ((shape_x[index_x] == -1 && shape_y[index_y] != 1) || + (shape_y[index_y] == -1 && shape_x[index_x] != 1)) { + return false; + } + } + + return true; +} + +// Return a new TensorType the same rank and dimensions as the input with an +// updated element type. +static Type ChangeTensorElementType(Builder *b, Type tensor_type, + Type element_type) { + RankedTensorType ranked_type = mlir::dyn_cast(tensor_type); + if (ranked_type) { + return tensorflow::GetTypeFromTFTensorShape(ranked_type.getShape(), + element_type); + } + + return UnrankedTensorType::get(element_type); +} + +//===----------------------------------------------------------------------===// +// Softmax op utilities. +//===----------------------------------------------------------------------===// + +// Returns the type to use for accumulating the given type. +static Type GetAccumulationType(Type ty) { + // Upcast 16 bit sum reductions to 32 bit to reduce the precision loss from + // repeated floating point additions. + return (ty.isF16() || ty.isBF16()) ? FloatType::getF32(ty.getContext()) : ty; +} + +//===----------------------------------------------------------------------===// +// Softplus op utilities. +//===----------------------------------------------------------------------===// + +static DenseElementsAttr GetEpsilonValue(Type ty) { + auto element_ty = mlir::cast(ty).getElementType(); + auto scalar_ty = tensorflow::GetTypeFromTFTensorShape({}, element_ty); + if (element_ty.isF16()) { + uint16_t raw_epsilon = Eigen::numext::bit_cast( + Eigen::NumTraits::epsilon()); + auto value = APFloat(APFloat::IEEEhalf(), APInt(16, raw_epsilon)); + return DenseElementsAttr::get(scalar_ty, value); + } else if (element_ty.isBF16()) { + uint16_t raw_epsilon = Eigen::numext::bit_cast( + Eigen::NumTraits::epsilon()); + auto value = APFloat(APFloat::BFloat(), APInt(16, raw_epsilon)); + return DenseElementsAttr::get(scalar_ty, value); + } else if (element_ty.isF32()) { + auto value = APFloat(std::numeric_limits::epsilon()); + return DenseElementsAttr::get(scalar_ty, value); + } else if (element_ty.isF64()) { + auto value = APFloat(std::numeric_limits::epsilon()); + return DenseElementsAttr::get(scalar_ty, value); + } + llvm_unreachable("unsupported element type for tf.SoftPlus"); +} + +//===----------------------------------------------------------------------===// +// ArgMax/ArgMin op utilities. +//===----------------------------------------------------------------------===// + +static void BuildArgMinMaxReductionBody(Type input_element_type, + Type index_element_type, + ComparisonDirection direction, + Region *body, OpBuilder *builder) { + OpBuilder::InsertionGuard insertion_point_gurad(*builder); + + Type input_type = + tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, input_element_type); + Type index_type = + tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, index_element_type); + Block *block = builder->createBlock(body); + Location loc = body->getLoc(); + block->addArguments({input_type, index_type, input_type, index_type}, + SmallVector(4, loc)); + + Value lhs_val = block->getArgument(0); + Value lhs_index = block->getArgument(1); + Value rhs_val = block->getArgument(2); + Value rhs_index = block->getArgument(3); + + ImplicitLocOpBuilder b(loc, *builder); + Value compare_dt = b.create(lhs_val, rhs_val, direction); + Value selected_input = + b.create(input_type, compare_dt, lhs_val, rhs_val); + + Value compare_eq = + b.create(lhs_val, rhs_val, ComparisonDirection::EQ); + Value min_index = b.create(lhs_index, rhs_index); + Value min_val_index = + b.create(index_type, compare_dt, lhs_index, rhs_index); + Value selected_index = + b.create(index_type, compare_eq, min_index, min_val_index); + + Value return_values[] = {selected_input, selected_index}; + b.create(return_values); +} + +//===----------------------------------------------------------------------===// +// PartitionedCall op utilities. +//===----------------------------------------------------------------------===// + +// Verify that the arguments to be passed into the function are the same types +// as the function paramter types. +static bool ArgTypesMatchCallee(mlir::Operation *op, OperandRange args, + SymbolRefAttr func) { + auto module = op->getParentOfType(); + auto function = + dyn_cast_or_null(SymbolTable::lookupSymbolIn(module, func)); + FunctionType function_ty = function.getFunctionType(); + + for (auto arg_in : llvm::zip(args, function_ty.getInputs())) { + if (std::get<0>(arg_in).getType() != std::get<1>(arg_in)) { + // Argument type and input type mismatch. + return false; + } + } + return true; +} + +//===----------------------------------------------------------------------===// +// Slice op utilities. +//===----------------------------------------------------------------------===// + +static bool CanBeTranslatedToDynamicSlice(Value input, Value start_indices, + DenseIntElementsAttr slice_sizes) { + auto input_ty = mlir::dyn_cast(input.getType()); + if (!input_ty) return false; + auto start_indices_ty = + mlir::dyn_cast(start_indices.getType()); + if (!start_indices_ty) return false; + + int64_t input_rank = input_ty.getRank(); + ArrayRef input_shape = input_ty.getShape(); + DenseIntElementsAttr constant_start_indices; + bool is_constant_start = + matchPattern(start_indices, m_Constant(&constant_start_indices)); + + for (int64_t i = 0; i < input_rank; ++i) { + int64_t input_size = input_shape[i]; + int64_t slice_size = slice_sizes.getValues()[i].getInt(); + // A slice_size of -1 means "all elements from start_index to the end". + // In order to support these semantics, we need to know both the start index + // and the shape of the input dimension. + if (slice_size < 0 && (!is_constant_start || input_size < 0)) return false; + } + return true; +} + +// TF slice size can be -1, which represents all elements from start_index to +// the end. HLO slice size can't be -1. As such, we need to translate TF slice +// size -1 to HLO slice size. +static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes( + Value input, Value start_indices, DenseIntElementsAttr slice_sizes, + Builder *builder) { + DenseIntElementsAttr constant_start_indices; + if (!matchPattern(start_indices, m_Constant(&constant_start_indices))) { + return mlir::cast( + hlo::convertElementsAttr(slice_sizes, builder->getIntegerType(64))); + } + + auto input_ty = mlir::dyn_cast(input.getType()); + int64_t input_rank = input_ty.getRank(); + ArrayRef input_shape = input_ty.getShape(); + SmallVector normalized_sizes; + + for (int64_t i = 0; i < input_rank; ++i) { + int64_t input_size = input_shape[i]; + int64_t start_index = + constant_start_indices.getValues()[i].getInt(); + int64_t slice_size = slice_sizes.getValues()[i].getInt(); + normalized_sizes.push_back(slice_size == -1 ? input_size - start_index + : slice_size); + } + + return GetI64ElementsAttr(normalized_sizes, builder); +} + +//===----------------------------------------------------------------------===// +// XlaGather op utilities. +//===----------------------------------------------------------------------===// + +bool HasValidGatherDims(StringAttr attr) { + ::xla::GatherDimensionNumbers dims; + return dims.ParseFromString(attr.getValue().str()); +} + +GatherDimensionNumbersAttr GetGatherDimNumsAttr(StringAttr attr, + Builder *builder) { + ::xla::GatherDimensionNumbers dims; + if (!dims.ParseFromString(attr.getValue().str())) return {}; + return ::xla::ConvertGatherDimensionNumbers(dims, builder); +} + +//===----------------------------------------------------------------------===// +// XlaDot op utilities. +//===----------------------------------------------------------------------===// + +bool HasValidDotDims(StringAttr attr) { + ::xla::DotDimensionNumbers dims; + return dims.ParseFromString(attr.getValue().str()); +} + +DotDimensionNumbersAttr GetDotDimNumsAttr(StringAttr attr, Builder *builder) { + ::xla::DotDimensionNumbers dims; + if (!dims.ParseFromString(attr.getValue().str())) return {}; + return ::xla::ConvertDotDimensionNumbers(dims, builder); +} + +bool HasValidPrecisionConfig(StringAttr attr) { + ::xla::PrecisionConfig precision; + return precision.ParseFromString(attr.getValue().str()); +} + +mlir::ArrayAttr GetPrecisionConfigAttr(StringAttr attr, Builder *builder) { + ::xla::PrecisionConfig precision; + if (!precision.ParseFromString(attr.getValue().str())) return {}; + return ::xla::ConvertPrecisionConfig(&precision, builder); +} + +//===----------------------------------------------------------------------===// +// XlaVariadicReduceV2 op utilities. +//===----------------------------------------------------------------------===// + +static void BuildBodyWithCall(PatternRewriter &rewriter, const Location &loc, + mlir::SymbolRefAttr func, + mlir::FunctionType func_ty, Region *body) { + OpBuilder::InsertionGuard guard(rewriter); + + Block *block = rewriter.createBlock(body); + auto inputs = func_ty.getInputs(); + block->addArguments(inputs, SmallVector(inputs.size(), loc)); + mlir::func::CallOp call_op = rewriter.create( + loc, func, func_ty.getResults(), block->getArguments()); + rewriter.create(loc, call_op.getResults()); +} + +//===----------------------------------------------------------------------===// +// Op converters. +//===----------------------------------------------------------------------===// + +NamedAttribute GetConvDimensionNumbersAttr(ArrayRef spatial_dims, + tensorflow::TensorFormat format, + Builder *builder) { + int64_t num_spatial_dims = spatial_dims.size(); + int64_t num_dims = num_spatial_dims + 2; + + int64_t batch_dim = GetTensorBatchDimIndex(num_dims, format); + int64_t feature_dim = GetTensorFeatureDimIndex(num_dims, format); + + // Filters data_format is always HWIO so input channels dimension is after + // all spatial dimensions. + int64_t kernel_input_feature_dim = num_spatial_dims; + int64_t kernel_output_feature_dim = num_spatial_dims + 1; + SmallVector kernel_spatial_dimensions; + kernel_spatial_dimensions.resize(num_spatial_dims); + std::iota(kernel_spatial_dimensions.begin(), kernel_spatial_dimensions.end(), + 0); + + return builder->getNamedAttr( + "dimension_numbers", + ConvDimensionNumbersAttr::get( + builder->getContext(), batch_dim, feature_dim, spatial_dims, + kernel_input_feature_dim, kernel_output_feature_dim, + kernel_spatial_dimensions, batch_dim, feature_dim, spatial_dims)); +} + +// Converts a TF::BiasAddOp to HLO. +// This differs from a normal TF::AddOp with respect to how the data_format +// is handled, which can optionally require a general broadcast of the +// 'bias' term in a way that is not compatible with the standard left-padded +// broadcast semantics (i.e. NCHW will broadcast into dimension 1). +// The correct 'bias' broadcast will be synthesized manually. +class ConvertBiasAddOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TF::BiasAddOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + tensorflow::TensorFormat data_format; + if (!FormatFromString(op.getDataFormat().str(), &data_format)) + return op.emitOpError("invalid data format"); + + auto value_type = mlir::dyn_cast(op.getValue().getType()); + if (!value_type) return failure(); + auto feature_dim = GetFeatureDimension(data_format, value_type); + auto bias_broadcast = Broadcast1DToFeatureDim( + loc, op.getValue(), op.getBias(), feature_dim, rewriter); + Value add = rewriter.create(loc, op.getValue(), bias_broadcast); + if (add.getType() != op.getType()) { + add = rewriter.create(loc, op.getType(), add); + } + rewriter.replaceOp(op, {add}); + return success(); + } +}; + +// Conterts tf.Conv2D to mhlo.dynamic_conv. +// TODO(disc): To recover static special case's performance with adding folding, +// canonicalization func and removing ConvertConvOp. +template +class ConvertConvDynamic : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + bool GetPaddingValues(OpT &op, PatternRewriter &rewriter, Value input_size, + Value filter_size, int64_t dilation_rate, + int64_t stride, tensorflow::Padding padding_type, + Type shape_scalar_type, Value *padding_low, + Value *padding_high) const { + // Stride must be > 0 + if (stride <= 0) return false; + // Dilation rate must be >= 1 + if (dilation_rate < 1) return false; + + Location loc = op.getLoc(); + switch (padding_type) { + case tensorflow::Padding::VALID: { + auto zero = + rewriter.create(loc, 0, shape_scalar_type); + *padding_low = *padding_high = zero; + break; + } + case tensorflow::Padding::EXPLICIT: + break; + case tensorflow::Padding::SAME: { + auto zero = + rewriter.create(loc, 0, shape_scalar_type); + auto one = + rewriter.create(loc, 1, shape_scalar_type); + auto two = + rewriter.create(loc, 2, shape_scalar_type); + // See also the parallel implementation in + // GetWindowedOutputSizeFromDimsV2. effective_filter_size = (filter_size + // - 1) * dilation_rate + 1 + Value stride_value = rewriter.create( + loc, stride, shape_scalar_type); + Value dilation_rate_value = rewriter.create( + loc, dilation_rate, shape_scalar_type); + Value effective_filter_size_op = rewriter.create( + loc, one, + rewriter.create( + loc, dilation_rate_value, + rewriter.create(loc, filter_size, one))); + // output_size = (input_size + stride - 1) / stride; + Value output_size = rewriter.create( + loc, + rewriter.create( + loc, input_size, + rewriter.create(loc, stride_value, one)), + stride_value); + // std::max(int64{0}, (output_size - 1) * stride + + // effective_filter_size - input_size); + Value padding_needed = rewriter.create( + loc, + rewriter.create( + loc, effective_filter_size_op, + rewriter.create( + loc, stride_value, + rewriter.create(loc, output_size, one))), + input_size); + Value cond = rewriter.create( + loc, arith::CmpIPredicate::sge, padding_needed, zero); + padding_needed = rewriter.create( + loc, padding_needed.getType(), cond, padding_needed, zero); + *padding_low = + rewriter.create(loc, padding_needed, two); + *padding_high = + rewriter.create(loc, padding_needed, *padding_low); + break; + } + } + return true; + } + + LogicalResult matchAndRewriteDynamicConv(OpT op, + PatternRewriter &rewriter) const { + tensorflow::TensorFormat data_format; + if (!FormatFromString(op.getDataFormat().str(), &data_format)) + return op.emitOpError("invalid data format"); + + tensorflow::Padding padding; + if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) + return failure(); + + auto input_ty = mlir::dyn_cast(op.getInput().getType()); + auto filter_ty = mlir::dyn_cast(op.getFilter().getType()); + auto result_ty = mlir::dyn_cast(op.getType()); + if (!input_ty || !filter_ty || !result_ty) return failure(); + // TODO(disc): Remove this constraint once fold and canonicalization + // implemented. + if (input_ty.hasStaticShape() && filter_ty.hasStaticShape()) + return failure(); + + ArrayRef dilations = op.getDilations().getValue(); + ArrayRef strides = op.getStrides().getValue(); + ArrayRef explicit_paddings; + if (padding == tensorflow::Padding::EXPLICIT) { + // EXPLICIT padding mode and the associated attribute is attached to + // Conv2D. + explicit_paddings = + op->template getAttrOfType("explicit_paddings").getValue(); + } + + SmallVector spatial_dim_indices; + SmallVector rhs_dilations; + SmallVector window_strides; + SmallVector paddings; + + auto get_int = [](Attribute attr) { + return mlir::cast(attr).getInt(); + }; + + constexpr int num_dims = num_spatial_dims + 2; + + Location loc = op.getLoc(); + auto shape_scalar_type = rewriter.getIntegerType(32); + + auto get_const = [&](int64_t val) { + return rewriter.create(loc, val, + shape_scalar_type); + }; + auto get_dim_value = [&](Value val, int64_t dim) { + Value dim_value = rewriter.create(loc, val, dim); + return rewriter.create(loc, shape_scalar_type, + dim_value); + }; + + for (auto i : llvm::seq(0, num_spatial_dims)) { + const int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i); + spatial_dim_indices.push_back(dim); + + const int64_t dilation = get_int(dilations[dim]); + rhs_dilations.push_back(dilation); + const int64_t stride = get_int(strides[dim]); + window_strides.push_back(stride); + + Value pad_low, pad_high; + if (padding == tensorflow::Padding::EXPLICIT) { + pad_low = get_const(get_int(explicit_paddings[2 * dim])); + pad_high = get_const(get_int(explicit_paddings[2 * dim + 1])); + } else { + auto input_size = get_dim_value(op.getInput(), dim); + auto filter_size = get_dim_value(op.getFilter(), i); + if (!GetPaddingValues(op, rewriter, input_size, filter_size, dilation, + stride, padding, shape_scalar_type, &pad_low, + &pad_high)) { + return failure(); + } + } + paddings.push_back(pad_low); + paddings.push_back(pad_high); + } + auto rhs_dilations_attr = rewriter.getNamedAttr( + "rhs_dilation", GetI64ElementsAttr(rhs_dilations, &rewriter)); + + auto window_strides_attr = rewriter.getNamedAttr( + "window_strides", GetI64ElementsAttr(window_strides, &rewriter)); + + auto dimension_numbers_attr = GetConvDimensionNumbersAttr( + spatial_dim_indices, data_format, &rewriter); + + const int64_t input_channels = + GetDimSize(input_ty, GetTensorFeatureDimIndex(num_dims, data_format)); + // Filters data_format is always HWIO so input channels dimension is after + // all spatial dimensions. + const int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims); + // TensorFlow convolution op verifies that the number of input channels is + // divisible by the number of filter channels. + // For depthwise convolution the feature_group_count argument would be set + // to the input feature dimension. + const int64_t feature_group_count = + depthwise_conv ? input_channels : input_channels / filter_channels; + auto feature_group_count_attr = rewriter.getNamedAttr( + "feature_group_count", rewriter.getI64IntegerAttr(feature_group_count)); + + auto batch_group_count_attr = rewriter.getNamedAttr( + "batch_group_count", rewriter.getI64IntegerAttr(1)); + + auto precision_config_attr = rewriter.getNamedAttr( + "precision_config", GetPrecisionConfig(&rewriter)); + + Value paddings_op = rewriter.create( + op.getLoc(), + tensorflow::GetTypeFromTFTensorShape(2 * num_spatial_dims, + rewriter.getI32Type()), + paddings); + + SmallVector operands(op.getOperands()); + operands.push_back(paddings_op); + // Reshape the filter to {spatial_dims...., 1,in_channels * + // channel_multiplier} + if (depthwise_conv) { + ArrayRef filter_shape = filter_ty.getShape(); + llvm::SmallVector new_shape( + filter_shape.begin(), filter_shape.begin() + num_spatial_dims); + new_shape.push_back(1); + new_shape.push_back(filter_shape[num_spatial_dims] * + filter_shape[num_spatial_dims + 1]); + operands[1] = rewriter.create( + op.getLoc(), + tensorflow::GetTypeFromTFTensorShape(new_shape, + filter_ty.getElementType()), + operands[1]); + } + NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr, + dimension_numbers_attr, feature_group_count_attr, + batch_group_count_attr, precision_config_attr}; + rewriter.replaceOpWithNewOp(op, op.getType(), operands, + llvm::ArrayRef(attrs)); + return success(); + } + + LogicalResult matchAndRewrite(OpT op, + PatternRewriter &rewriter) const override { + return matchAndRewriteDynamicConv(op, rewriter); + } +}; + +using ConvertConv2DDynamic = + ConvertConvDynamic; + +// Converts the TensorFlow conv op in template to the generic HLO conv op by +// converting TensorFlow op attributes to HLO op attributes. +// +// Sample result for Conv2D: +// +// %conv = "mhlo.convolution"(%input, %filter) { +// strides = [1, 2], +// paddings = [[1, 0], [1, 1]], +// ... +// } +// +// This pattern is not defined using declarative rewrite rules as computation of +// the paddings attribute anyway requires multiple source op attributes and +// result op attributes. Defining it as declarative rewrite rule will introduce +// some duplication in the C++ helper methods. +template +class ConvertConvOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + tensorflow::TensorFormat data_format; + if (!FormatFromString(op.getDataFormat().str(), &data_format)) + return op.emitOpError("invalid data format"); + + tensorflow::Padding padding; + if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) + return failure(); + + auto input_ty = mlir::dyn_cast(op.getInput().getType()); + auto filter_ty = mlir::dyn_cast(op.getFilter().getType()); + + // With the exception of input's batch dimension, input and filter need to + // have static shape for calculation of HLO paddings and feature group count + // attributes. Filter is validated here, input is mostly validated at use. + if (!input_ty || !filter_ty || !filter_ty.hasStaticShape()) + return failure(); + + ArrayRef dilations = op.getDilations().getValue(); + ArrayRef strides = op.getStrides().getValue(); + ArrayRef explicit_paddings; + if (padding == tensorflow::Padding::EXPLICIT) { + // EXPLICIT padding mode and the associated attribute is limited to + // Conv2D. So, fetch attribute by identifier instead of the + // op.explicit_paddings() attribute getter. + explicit_paddings = + op->template getAttrOfType("explicit_paddings").getValue(); + } + + SmallVector spatial_dim_indices; + SmallVector rhs_dilations; + SmallVector window_strides; + SmallVector paddings; + + auto get_int = [](Attribute attr) { + return mlir::cast(attr).getInt(); + }; + + constexpr int num_dims = num_spatial_dims + 2; + for (auto i : llvm::seq(0, num_spatial_dims)) { + const int64_t dim = GetTensorSpatialDimIndex(num_dims, data_format, i); + spatial_dim_indices.push_back(dim); + + const int64_t dilation = get_int(dilations[dim]); + rhs_dilations.push_back(dilation); + const int64_t stride = get_int(strides[dim]); + window_strides.push_back(stride); + + int64_t pad_low, pad_high; + if (padding == tensorflow::Padding::EXPLICIT) { + pad_low = get_int(explicit_paddings[2 * dim]); + pad_high = get_int(explicit_paddings[2 * dim + 1]); + } else { + int64_t output_size; + int64_t pad_low_int64; + int64_t pad_high_int64; + int64_t input_size = input_ty.getDimSize(dim); + if (input_size == ShapedType::kDynamic) return failure(); + absl::Status status = tensorflow::GetWindowedOutputSizeVerbose( + input_size, filter_ty.getDimSize(i), dilation, stride, padding, + &output_size, &pad_low_int64, &pad_high_int64); + if (!status.ok()) return failure(); + pad_low = pad_low_int64; + pad_high = pad_high_int64; + } + paddings.push_back(pad_low); + paddings.push_back(pad_high); + } + + auto rhs_dilations_attr = rewriter.getNamedAttr( + "rhs_dilation", GetI64ElementsAttr(rhs_dilations, &rewriter)); + + auto window_strides_attr = rewriter.getNamedAttr( + "window_strides", GetI64ElementsAttr(window_strides, &rewriter)); + + auto dimension_numbers_attr = GetConvDimensionNumbersAttr( + spatial_dim_indices, data_format, &rewriter); + + const int64_t input_channels = + GetDimSize(input_ty, GetTensorFeatureDimIndex(num_dims, data_format)); + if (input_channels == ShapedType::kDynamic) return failure(); + // Filters data_format is always HWIO so input channels dimension is after + // all spatial dimensions. + const int64_t filter_channels = GetDimSize(filter_ty, num_spatial_dims); + // TensorFlow convolution op verifies that the number of input channels is + // divisible by the number of filter channels. + // For depthwise convolution the feature_group_count argument would be set + // to the input feature dimension. + const int64_t feature_group_count = + depthwise_conv ? input_channels : input_channels / filter_channels; + auto feature_group_count_attr = rewriter.getNamedAttr( + "feature_group_count", rewriter.getI64IntegerAttr(feature_group_count)); + + auto batch_group_count_attr = rewriter.getNamedAttr( + "batch_group_count", rewriter.getI64IntegerAttr(1)); + + RankedTensorType paddings_ty = tensorflow::GetTypeFromTFTensorShape( + {num_spatial_dims, 2}, rewriter.getIntegerType(64)); + auto paddings_attr = rewriter.getNamedAttr( + "padding", DenseElementsAttr::get(paddings_ty, paddings)); + + auto precision_config_attr = rewriter.getNamedAttr( + "precision_config", GetPrecisionConfig(&rewriter)); + + SmallVector operands(op.getOperands()); + // Reshape the filter to {spatial_dims...., 1,in_channels * + // channel_multiplier} + if (depthwise_conv) { + ArrayRef filter_shape = filter_ty.getShape(); + llvm::SmallVector new_shape( + filter_shape.begin(), filter_shape.begin() + num_spatial_dims); + new_shape.push_back(1); + new_shape.push_back(filter_shape[num_spatial_dims] * + filter_shape[num_spatial_dims + 1]); + operands[1] = rewriter.create( + op.getLoc(), + tensorflow::GetTypeFromTFTensorShape(new_shape, + filter_ty.getElementType()), + operands[1]); + } + NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr, + dimension_numbers_attr, feature_group_count_attr, + batch_group_count_attr, paddings_attr, + precision_config_attr}; + rewriter.replaceOpWithNewOp(op, op.getType(), operands, + llvm::ArrayRef(attrs)); + return success(); + } +}; + +using ConvertConv2DOp = ConvertConvOp; +using ConvertConv3DOp = ConvertConvOp; +using ConvertDepthConv2DOp = + ConvertConvOp; + +// Converts tf.PadV2Op to mhlo.DynamicPadOp. Padding values must be const. +class ConvertPadOpDynamic : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + // TODO(disc): To recover static special case's performance with folding and + // canonicalization. + LogicalResult matchAndRewrite(TF::PadV2Op op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto input = op.getInput(); + auto paddings = op.getPaddings(); + auto constant_values = op.getConstantValues(); + auto input_type = mlir::dyn_cast(input.getType()); + auto paddings_type = mlir::dyn_cast(paddings.getType()); + if (!input_type || !paddings_type || !paddings_type.hasStaticShape()) + return failure(); + + // TODO(disc): Remove this constraint once fold and canonicalization is + // implemented. + if (input_type.hasStaticShape()) return failure(); + + int input_rank = input_type.getRank(); + // interior padding + std::vector interior_values(input_rank, 0); + auto interior_attr = GetI64ElementsAttr(interior_values, &rewriter); + + Value interior_padding_tensor = + rewriter.create(loc, interior_attr); + Type paddings_elem_ty = paddings_type.getElementType(); + if (!paddings_elem_ty.isInteger(64)) { + interior_padding_tensor = rewriter.create( + loc, interior_padding_tensor, paddings_elem_ty); + } + llvm::SmallVector transposed_shape = {2, input_rank}; + auto transpose_attr = GetI64ElementsAttr({1, 0}, &rewriter); + Value transposed_paddings = + rewriter.create(loc, paddings, transpose_attr); + Value reshaped_paddings = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape({input_rank * 2}, + paddings_elem_ty), + transposed_paddings); + + auto left_padding_start_attr = GetI64ElementsAttr({0}, &rewriter); + auto left_padding_limit_attr = GetI64ElementsAttr({input_rank}, &rewriter); + auto left_padding_stride_attr = GetI64ElementsAttr({1}, &rewriter); + Value left_padding_tensor = rewriter.create( + loc, reshaped_paddings, left_padding_start_attr, + left_padding_limit_attr, left_padding_stride_attr); + + auto right_padding_start_attr = GetI64ElementsAttr({input_rank}, &rewriter); + auto right_padding_limit_attr = + GetI64ElementsAttr({2 * input_rank}, &rewriter); + auto right_padding_stride_attr = GetI64ElementsAttr({1}, &rewriter); + Value right_padding_tensor = rewriter.create( + loc, reshaped_paddings, right_padding_start_attr, + right_padding_limit_attr, right_padding_stride_attr); + + rewriter.replaceOpWithNewOp( + op, op.getType(), input, constant_values, left_padding_tensor, + right_padding_tensor, interior_padding_tensor); + + return success(); + } +}; + +class ConvertGatherNdOpDynamic : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + // Converts tf.GatherNdOp to mhlo.DynamicGatherOp. + // Here we leave 'slice_sizes' as an Attr, without defining a new + // DynamicGatherOp, since GatherDimensionNumbers has already provide enough + // information for shape inference and code generation of mhlo::GatherOp. '?' + // will be filled into slice_sizes for dimensions that are dynamic sized. + // TODO(disc): To recover static special case's performance with folding and + // canonicalization. + LogicalResult matchAndRewrite(TF::GatherNdOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto params = op.getParams(); + auto params_ty = mlir::dyn_cast(params.getType()); + auto indices = op.getIndices(); + auto indices_ty = mlir::dyn_cast(indices.getType()); + auto params_rank = params_ty.getRank(); + auto indices_rank = indices_ty.getRank(); + int64_t num_index_dims = indices_ty.getDimSize(indices_rank - 1); + if (!params_ty || !indices_ty) return failure(); + // the last dim of indices of GatherNdOp must be fixed shaped + if (num_index_dims == ShapedType::kDynamic) return failure(); + + SmallVector slice_sizes; + slice_sizes.reserve(params_rank); + for (int64_t i = 0; i < params_rank; ++i) { + if (i < num_index_dims) { + slice_sizes.push_back(1); + } else { + // potentially dynamic + int64_t dim_size = params_ty.getDimSize(i); + slice_sizes.push_back(dim_size); + } + } + SmallVector slice_sizes_vals; + Value slice_sizes_value = nullptr; + for (int64_t i = 0; i < params_rank; ++i) { + if (i < num_index_dims) { + slice_sizes_vals.push_back(rewriter.create( + loc, rewriter.getIntegerAttr(indices_ty.getElementType(), 1))); + } else { + int64_t dim_size = params_ty.getDimSize(i); + if (dim_size != ShapedType::kDynamic) { + slice_sizes_vals.push_back(rewriter.create( + loc, + rewriter.getIntegerAttr(indices_ty.getElementType(), dim_size))); + } else { + slice_sizes_vals.push_back(rewriter.create( + loc, indices_ty.getElementType(), + rewriter.create(loc, params, i))); + } + } + } + slice_sizes_value = + rewriter.create(loc, slice_sizes_vals); + + // collapsed_slice_dims + SmallVector collapsed_slice_dims; + collapsed_slice_dims.reserve(num_index_dims); + for (int64_t i = 0; i < num_index_dims; ++i) { + collapsed_slice_dims.push_back(i); + } + // offset_dims + SmallVector offset_dims; + offset_dims.reserve(params_rank - num_index_dims); + for (int64_t i = num_index_dims; i < params_rank; i++) { + offset_dims.push_back(i + indices_rank - 1 - num_index_dims); + } + // start_index_map + SmallVector start_index_map; + offset_dims.reserve(num_index_dims); + for (int64_t i = 0; i < num_index_dims; i++) { + start_index_map.push_back(i); + } + // index_vector_dim + int64_t index_vector_dim = indices_rank - 1; + + auto dims_attr = GatherDimensionNumbersAttr::get( + rewriter.getContext(), offset_dims, collapsed_slice_dims, + /*operandBatchingDims=*/{}, + /*startIndicesBatchingDims=*/{}, start_index_map, index_vector_dim); + // TODO(disc): Remove this if-statement once fold and canonicalization is + // implemented. + if (params_ty.hasStaticShape() && indices_ty.hasStaticShape()) { + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getParams(), op.getIndices(), dims_attr, + GetI64ElementsAttr(slice_sizes, &rewriter)); + } else { + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getParams(), op.getIndices(), slice_sizes_value, + dims_attr); + } + return success(); + } +}; + +// Converts BF16 FloorDiv op to have casting operators on either end as BF16 +// division can result in strange behavior. +// +// floordiv = cast(floordiv(cast(left), cast(right)))) +// +// %left_cast = cast(%left) +// %right_cast = cast(%right) +// %div = div(%left, %left) +// %floored = floor(%div) +// %floored_cast = cast(%floored) +// +// Required to manually specify the intermediate types. +class ConvertBF16FloorDivOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::FloorDivOp op, + PatternRewriter &rewriter) const override { + auto l = mlir::dyn_cast>(op.getX()); + auto r = mlir::dyn_cast>(op.getY()); + if (!l || !r) return failure(); + + auto element_type = getElementTypeOrSelf(l.getType()); + if (!element_type.isBF16()) return failure(); + + auto out_type = op.getZ().getType(); + + l = rewriter.create(op.getLoc(), l, rewriter.getF32Type()); + r = rewriter.create(op.getLoc(), r, rewriter.getF32Type()); + + auto intermediate = rewriter.create( + op.getLoc(), + ChangeTensorElementType(&rewriter, out_type, rewriter.getF32Type()), l, + r); + + auto floor_op = + rewriter.create(op.getLoc(), out_type, intermediate); + rewriter.replaceOp(op, floor_op.getResult()); + return success(); + } +}; + +class ConvertBroadcastToOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::BroadcastToOp op, + PatternRewriter &rewriter) const override { + auto input_type = mlir::dyn_cast(op.getInput().getType()); + auto output_type = op.getOutput().getType(); + if (!input_type) { + return rewriter.notifyMatchFailure(op, "requires ranked input shape"); + } + llvm::SmallVector broadcast_dimensions; + if (input_type.getRank() > 0) { + auto ranked_output_type = mlir::dyn_cast(output_type); + if (!ranked_output_type) { + return rewriter.notifyMatchFailure(op, "requires ranked output shape"); + } + auto rank_diff = ranked_output_type.getRank() - input_type.getRank(); + // The tf.BroadcastTo op performs "right-aligned" numpy-style + // broadcasting. + broadcast_dimensions = llvm::to_vector<4>( + llvm::seq(rank_diff, ranked_output_type.getRank())); + } + rewriter.replaceOpWithNewOp( + op, output_type, op.getInput(), op.getShape(), + rewriter.getI64TensorAttr(broadcast_dimensions)); + return success(); + } +}; + +/// Converts a TF::RollOp to HLO. Only support 0D axis and shift case, and axis +/// have to be a constant. +class ConvertRollOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TF::RollOp op, + PatternRewriter &rewriter) const override { + auto shift_ty = mlir::dyn_cast(op.getShift().getType()); + if (!shift_ty || shift_ty.getRank() != 0) { + return rewriter.notifyMatchFailure( + op, "require the type of shift to be 0D tensor"); + } + + APInt val; + if (!matchPattern(op.getAxis(), m_ConstantInt(&val))) { + return rewriter.notifyMatchFailure(op, "require axis to be constant"); + } + int axis = val.getSExtValue(); + + auto input_ty = mlir::dyn_cast(op.getInput().getType()); + if (!input_ty || !input_ty.hasStaticShape()) { + return rewriter.notifyMatchFailure( + op, "require the type of input to have static shapes"); + } + ArrayRef input_shape = input_ty.getShape(); + int input_rank = input_ty.getRank(); + if (axis < 0) axis += input_rank; + + // Adjust large offsets into [0, axis_size). This also makes negative + // offsets positive. + // offset = ((offset % axis_size) + axis_size) % axis_size + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + Value offset = op.getShift(); + auto axis_size = b.create(b.getIntegerAttr( + getElementTypeOrSelf(offset.getType()), input_shape[axis])); + offset = b.create( + b.create(b.create(offset, axis_size), axis_size), + axis_size); + + // Stack two copies of the dimension, then slice from the calculated + // offset. This also works if shift is not constant. + // DynamicSliceOp requires the sizes being integer, and we can get the + // information from input shape. + auto concat = b.create( + ValueRange{op.getInput(), op.getInput()}, b.getI64IntegerAttr(axis)); + Value zero = b.create( + b.getIntegerAttr(getElementTypeOrSelf(offset.getType()), 0)); + SmallVector slice_begin_indices(input_rank, zero); + slice_begin_indices[axis] = b.create(axis_size, offset); + rewriter.replaceOpWithNewOp( + op, input_ty, concat, slice_begin_indices, + rewriter.getI64TensorAttr(input_shape)); + return success(); + } +}; + +/// Converts a TF::LeakyReluOp to HLO. +/// LeakyRelu(x) = alpha * x if x < 0 else x. +class ConvertLeakyReluOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TF::LeakyReluOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value features = op.getFeatures(); + + // Use ConstantLike for `alpha` to match the shape of feature. + auto alphaVal = chlo::getConstantLike( + rewriter, loc, op.getAlpha().convertToFloat(), features); + Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features); + + Value leakyActivationVal = + rewriter.create(loc, features, alphaVal); + + Value compareGtZero = rewriter.create( + loc, features, zeroVal, ComparisonDirection::GT); + + rewriter.replaceOpWithNewOp(op, compareGtZero, features, + leakyActivationVal); + return success(); + } +}; + +/// Converts a TF::LeakyReluGradOp to HLO. +/// LeakyReluGrad(gradient, inputs) = gradient if input > 0 +/// else alpha * gradient. +class ConvertLeakyReluGradOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TF::LeakyReluGradOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value gradients = op.getGradients(); + Value features = op.getFeatures(); + auto featureType = features.getType(); + + // Use ConstantLike for `alpha` to match the shape of feature. + auto alphaVal = chlo::getConstantLike( + rewriter, loc, op.getAlpha().convertToFloat(), features); + Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features); + + Value leakyGradientVal = + rewriter.create(loc, gradients, alphaVal); + + Value compareGtZero = rewriter.create( + loc, features, zeroVal, ComparisonDirection::GT); + + rewriter.replaceOpWithNewOp(op, featureType, compareGtZero, + gradients, leakyGradientVal); + return success(); + } +}; + +// Converts TensorFlow DiagPartOp to HLO ops using reduction on masked matrix. +// For a Rank-2 input, it creates the following ops: +// %1 = "mhlo.iota"() {iota_dimension = 0 : i64} +// %2 = "mhlo.iota"() {iota_dimension = 1 : i64} +// %3 = "mhlo.compare"(%1, %2) {comparison_direction = "EQ"} +// %4 = mhlo.constant dense<0.000000e+00> : tensor +// %5 = "mhlo.broadcast"(%4) +// %6 = "mhlo.select"(%3, %input, %5) +// %7 = "mhlo.reduce"(%6, %4) ({ +// ^bb0(%arg1: tensor, %arg2: tensor): +// %9 = mhlo.add %arg1, %arg2 : tensor +// "mhlo.return"(%9) : (tensor) -> () +// }) {dimensions = dense<0> : tensor<1xi64>} +// +// If the input's rank N is greater than 2, we will reshape it to R2 first and +// create the above ops, then reshape it back to rank N/2. +class ConvertDiagPartOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::DiagPartOp op, + PatternRewriter &rewriter) const override { + auto input_type = mlir::dyn_cast(op.getInput().getType()); + if (!input_type || !input_type.hasStaticShape()) return failure(); + int64_t num_dims = input_type.getRank(); + if (num_dims < 2 || num_dims % 2 != 0) return failure(); + const int64_t out_dims = num_dims / 2; + + int64_t new_size = 1; + llvm::SmallVector new_dims; + for (int i = 0; i < out_dims; i++) { + if (input_type.getDimSize(i) != input_type.getDimSize(i + out_dims)) + return op.emitOpError("invalid dimensions size"); + new_size *= input_type.getDimSize(i); + new_dims.push_back(input_type.getDimSize(i)); + } + Value reshaped_input = rewriter.create( + op.getLoc(), + tensorflow::GetTypeFromTFTensorShape({new_size, new_size}, + input_type.getElementType()), + op.getInput()); + auto iota_type = tensorflow::GetTypeFromTFTensorShape( + {new_size, new_size}, rewriter.getIntegerType(32)); + auto iota0 = rewriter.create(op.getLoc(), iota_type, + rewriter.getI64IntegerAttr(0)); + auto iota1 = rewriter.create(op.getLoc(), iota_type, + rewriter.getI64IntegerAttr(1)); + Value compare = rewriter.create(op.getLoc(), iota0, iota1, + ComparisonDirection::EQ); + Value zero = GetScalarConstOfType(input_type.getElementType(), op.getLoc(), + 0, &rewriter); + Value zero_matrix = rewriter.create( + op.getLoc(), reshaped_input.getType(), zero, + GetI64ElementsAttr({new_size, new_size}, &rewriter)); + Value masked = + rewriter.create(op.getLoc(), reshaped_input.getType(), + compare, reshaped_input, zero_matrix); + auto reduce = rewriter.create(op.getLoc(), masked, zero, + GetI64ElementsAttr({0}, &rewriter), + input_type.getElementType()); + assert(!input_type.getElementType().isInteger(1) && + "data type should not be i1"); + BuildReduceBody(input_type.getElementType(), &reduce.getBody(), + &rewriter); + rewriter.replaceOpWithNewOp( + op, + tensorflow::GetTypeFromTFTensorShape(new_dims, + input_type.getElementType()), + reduce.getResult(0)); + return success(); + } +}; + +// Converts TensorFlow MatrixDiagPartOp to HLO ops. +class ConvertMatrixDiagPartV3Op + : public OpRewritePattern { + using Shape = llvm::SmallVector; + + // Parse the "k" parameter. MatrixDiagPartV3 allows to specify the diagonal(s) + // with k. This can be either a single value (for a single diagonal) or a + // tuple of two values (starting and ending diagonal, for a band). + LogicalResult ExtractK(TF::MatrixDiagPartV3Op op, int64_t (*k)[2]) const { + DenseIntElementsAttr kattr; + if (!matchPattern(op.getK(), m_Constant(&kattr))) { + return failure(); + } + DenseIntElementsAttr::iterator it = kattr.begin(); + (*k)[0] = (*it).getSExtValue(); + it++; + if (it == kattr.end()) { + // Handle input like e.g. "k = 5", in which case we extract a single + // diagonal. + (*k)[1] = (*k)[0]; + } else { + // Handle input like e.g. "k = [-1, 1]", in which case we extract a + // band (multiple diagonals). + (*k)[1] = (*it).getSExtValue(); + } + return success(); + } + + // Utility method for broadcasting integer constants to a given shape. + BroadcastOp BroadcastConstant(Location loc, Shape shape, int32_t constant, + int int_size, PatternRewriter &rewriter) const { + return rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape(shape, + rewriter.getIntegerType(int_size)), + GetScalarConstOfType(rewriter.getIntegerType(int_size), loc, constant, + &rewriter), + GetI64ElementsAttr(shape, &rewriter)); + } + + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::MatrixDiagPartV3Op op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + ShapedType input_type = mlir::dyn_cast(op.getInput().getType()); + + // Align is a string specifying how superdiagonals and subdiagonals should + // be aligned/padded for diagonals that are shorter than max_diag_len. The + // format is "{super}_{sub}", with {super} the superdiagonal alignment and + // {sub} the subdiagonal alignment. "LEFT" means rows will be padded to the + // left, "RIGHT" means rows will be padded ot the right. The default is + // "RIGHT_LEFT". + StringRef align = op->getAttrOfType("align").getValue(); + enum Alignment { kLeft, kRight }; + + // default is RIGHT_LEFT + Alignment superdiagonal_align = kRight; + Alignment subdiagonal_align = kLeft; + + if (align == "RIGHT_LEFT") { + superdiagonal_align = kRight; + subdiagonal_align = kLeft; + } else if (align == "RIGHT_RIGHT") { + superdiagonal_align = kRight; + subdiagonal_align = kRight; + } else if (align == "LEFT_RIGHT") { + superdiagonal_align = kLeft; + subdiagonal_align = kRight; + } else if (align == "LEFT_LEFT") { + superdiagonal_align = kLeft; + subdiagonal_align = kLeft; + } else { + return failure(); // unsupported alignment + } + + // MatrixDiagPart operates on a matrix of shape [I, J, ..., L, M, N], and + // will extract the diagonal(s) out of [M, N], for all [I, J, ..., L]. + if (!input_type || !input_type.hasStaticShape()) return failure(); + int64_t num_dims = input_type.getRank(); + if (num_dims < 2) return failure(); + int64_t rows = input_type.getDimSize(num_dims - 2); // rows + int64_t cols = input_type.getDimSize(num_dims - 1); // cols + + // We extract the diagonals from k[0] up to and including k[1]. + // Addressing is 0 for the main diagonal. (So k = [0, 0] would just extract + // the main diagonal). It's negative for subdiagonals (under and to the left + // of the main diagonal) and positive for superdiagonals (above and to the + // right of the main diagonal). + int64_t k[2]; + if (failed(ExtractK(op, &k))) return failure(); + int num_diags = k[1] - k[0] + 1; + + // Shifting diagonals away from the main diagonal might shorten them. This + // is the longest diagonal we will see. We make this the last dimension of + // the output shape. + int64_t max_diag_len = + std::min(rows + std::min(k[1], static_cast(0)), + cols + std::min(-k[0], static_cast(0))); + + // The first dimension is the index vector dimension we'll use for gather. + // It's 1 here, but will be 2 once we glue x and y together. + Shape indices_shape({1, num_diags, max_diag_len}); + + RankedTensorType iota_type = tensorflow::GetTypeFromTFTensorShape( + indices_shape, rewriter.getIntegerType(32)); + Value iotaM = + rewriter.create(loc, iota_type, rewriter.getI64IntegerAttr(1)); + Value iotaN = + rewriter.create(loc, iota_type, rewriter.getI64IntegerAttr(2)); + + // Boradcasted constants, of the same shape as iotaM and iotaN. + Value b_zero = BroadcastConstant(loc, indices_shape, 0, 32, rewriter); + Value b_false = BroadcastConstant(loc, indices_shape, 0, 1, rewriter); + Value b_true = BroadcastConstant(loc, indices_shape, 1, 1, rewriter); + Value b_k1 = BroadcastConstant(loc, indices_shape, k[1], 32, rewriter); + Value b_rows = BroadcastConstant(loc, indices_shape, rows, 32, rewriter); + Value b_cols = BroadcastConstant(loc, indices_shape, cols, 32, rewriter); + Value b_max_diag_len = + BroadcastConstant(loc, indices_shape, max_diag_len, 32, rewriter); + + // d = k[1] - m + // (A.k.a. the number of the diagonal, depending on m. Note that we + // subtract m here. This means we start with the superdiagonals and + // move downwards towards the subdiagonals. So the start indices will + // be decreasing.) + Value d = rewriter.create(loc, b_k1, iotaM); + Value neg_d = rewriter.create(loc, d); + + // diag_len_d = min(rows + min(d, 0), cols - max(d, 0)) + // (Length of a diagonal for a given d. Same as max_diag_len for m = 0.) + Value diag_len_d = rewriter.create( + loc, + rewriter.create(loc, b_rows, + rewriter.create(loc, d, b_zero)), + rewriter.create(loc, b_cols, + rewriter.create(loc, d, b_zero))); + + // offset is max_diag_len - diag_len_d if we're padding, 0 otherwise. + Value cmp; + if (subdiagonal_align == kRight && superdiagonal_align == kRight) { + cmp = b_true; + } else if (superdiagonal_align == kRight) { + // offset = d>=0 ? max_diag_len - diag_len_d : 0 + cmp = rewriter.create(loc, d, b_zero); + } else if (subdiagonal_align == kRight) { + // offset = d<=0 ? max_diag_len - diag_len_d : 0 + cmp = rewriter.create(loc, d, b_zero); + } else { + // offset = 0 + cmp = b_false; + } + + // This offset shifts the diagonals to the "left" or "right", depending + // on alignment. + Value offset = rewriter.create( + loc, b_zero.getType(), cmp, + rewriter.create(loc, b_max_diag_len, diag_len_d), b_zero); + + // x = max(d, 0) - offset + // y = max(-d, 0) - offset + Value x = rewriter.create( + loc, rewriter.create(loc, d, b_zero), offset); + Value y = rewriter.create( + loc, rewriter.create(loc, neg_d, b_zero), offset); + + Value n_plus_x = rewriter.create(loc, iotaN, x); + Value n_plus_y = rewriter.create(loc, iotaN, y); + + // GatherOp is happy about letting us index out of bounds values, but those + // values will be undefined. So we mask them later. Set up the boolean + // expression that tells us which entries, in the output shape, are out of + // bounds and thus become the padding_value. + Value x_in_bounds = rewriter.create( + loc, + rewriter.create(loc, b_false.getType(), n_plus_x, + b_zero), + rewriter.create(loc, b_false.getType(), n_plus_x, b_cols)); + Value y_in_bounds = rewriter.create( + loc, + rewriter.create(loc, b_false.getType(), n_plus_y, + b_zero), + rewriter.create(loc, b_false.getType(), n_plus_y, b_rows)); + Value in_bounds = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape(Shape({num_diags, max_diag_len}), + rewriter.getIntegerType(1)), + rewriter.create(loc, x_in_bounds, y_in_bounds)); + + // Now combine x and y into the index data structure needed for gather. + Shape concat_shape({2, num_diags, max_diag_len}); + Value start_indices = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape(concat_shape, + rewriter.getIntegerType(32)), + mlir::ValueRange({n_plus_y, n_plus_x}), + mlir::IntegerAttr::get(rewriter.getIntegerType(64), 0)); + + // Shape of the final output. (Except for dimension folding in the + // single diagonal case.) + Shape output_shape; + for (int i = 0; i < num_dims - 2; i++) { + output_shape.push_back(input_type.getDimSize(i)); + } + output_shape.push_back(num_diags); + output_shape.push_back(max_diag_len); + + // A slice is the shape of what GatherOp copies per lookup. So the last + // two dimensions (M, N in the matrix-diag-part docs) are where we go + // through entry by entry. + ArrayRef input_shape = input_type.getShape(); + int input_shape_size = input_shape.size(); + Shape slice_sizes(input_shape.begin(), input_shape.end()); + int slice_dimensions = slice_sizes.size(); + slice_sizes[slice_dimensions - 2] = + std::min((int64_t)1, input_shape[input_shape_size - 2]); + slice_sizes[slice_dimensions - 1] = + std::min((int64_t)1, input_shape[input_shape_size - 1]); + + // Dimensions of the input we won't see in the output (M and N). + SmallVector collapsed_dims( + {slice_dimensions - 2, slice_dimensions - 1}); + + // Which dimensions (in the input) the two offset "columns" map to. + SmallVector start_index_map({num_dims - 2, num_dims - 1}); + + // Gather the diagonal entries. + // TODO(kramm): For a single diagonal, this might be slower than the + // mask + sum approach. Special-case num_diags==1? + auto dims_attr = GatherDimensionNumbersAttr::get( + rewriter.getContext(), + /*offsetDims=*/llvm::to_vector<4>(llvm::seq(0, num_dims - 2)), + /*collapsedSliceDims=*/collapsed_dims, + /*operandBatchingDims=*/{}, + /*startIndicesBatchingDims=*/{}, start_index_map, + /*indexVectorDim=*/0); + Value gather = rewriter.create( + loc, op.getInput(), start_indices, dims_attr, + GetI64ElementsAttr(slice_sizes, &rewriter)); + + // We now need to broadcast the "in_bounds" boolean expression, as well as + // the padding value, to do the final select. + Shape broadcast_bounds; + for (int i = 0; i < output_shape.size() - 2; i++) { + broadcast_bounds.push_back(output_shape[i]); + } + Value b_in_bounds = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape(output_shape, + rewriter.getIntegerType(1)), + in_bounds, GetI64ElementsAttr(broadcast_bounds, &rewriter)); + Value b_padding = rewriter.create( + loc, op.getPaddingValue(), GetI64ElementsAttr(output_shape, &rewriter)); + + // Replace all out-of-bounds values in the result with padding_value. + Value result = + rewriter.create(loc, b_in_bounds, gather, b_padding); + + if (num_diags == 1) { + // matrix_diag_part folds away the 1-sized band dimension if we only + // extract a single diagonal. + result = rewriter.create(loc, op.getType(), result); + } + + rewriter.replaceOp(op, result); + return success(); + } +}; + +// Converts TensorFlow EinsumOp to HLO EinsumOp +class ConvertEinsumOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::EinsumOp op, + PatternRewriter &rewriter) const override { + // Prepend `,` to equation if unary einsum. + std::string equation_str = op.getEquation().str(); + llvm::SmallVector inputs; + + // Unary einsum prepends `,` to equation and + // creates a scalar constant 1.0 for first operand. + if (op.getN() == 1) { + equation_str = "," + equation_str; + inputs.push_back(rewriter.create( + op.getLoc(), hlo::getScalarOfType( + mlir::getElementTypeOrSelf(op.getOperand(0)), 1))); + } + // Insert remaining operands into inputs, TF op verifier requires there be + // 0 or 1 operands. + auto operands = op.getInputs(); + inputs.insert(inputs.end(), operands.begin(), operands.end()); + assert(inputs.size() == 2); + + rewriter.replaceOpWithNewOp(op, op.getType(), inputs[0], + inputs[1], equation_str); + return success(); + } +}; + +// Bypasses IdentityN op. +class ConvertIdentityNOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TF::IdentityNOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOp(op, op.getOperands()); + return success(); + } +}; + +template +class ConvertFFTOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + auto input_ty = mlir::cast(op.getInput().getType()); + if (!input_ty.hasRank()) { + return failure(); + } + auto input_shape = input_ty.getShape(); + DenseIntElementsAttr fft_length_attr; + if (!matchPattern(op.getFftLength(), m_Constant(&fft_length_attr))) { + return failure(); + } + int64_t fft_length; + if (fft_length_attr.getNumElements() != 0) { + fft_length = fft_length_attr.getValues()[0].getInt(); + } else { + return failure(); + } + + int64_t expected_dim = fft_length; + std::string fft_string = "RFFT"; + if (typeid(OpTy) == typeid(TF::IRFFTOp)) { + expected_dim = fft_length / 2 + 1; + fft_string = "IRFFT"; + } + Location loc = op.getLoc(); + + // The inner-most dim cannot be dynamic. + if (input_ty.isDynamicDim(input_shape.size() - 1)) { + return failure(); + } + + auto expected_shape = llvm::to_vector<4>(input_shape.drop_back()); + expected_shape.push_back(expected_dim); + + // Zero pad or truncate the last axis + Value reshaped = op.getInput(); + SmallVector begin_indices(input_shape.size(), 0); + SmallVector strides(input_shape.size(), 1); + + // Last dim larger than expected_dim, slice the input + if (input_shape.back() > expected_dim) { + reshaped = rewriter.create( + op.getLoc(), + tensorflow::GetTypeFromTFTensorShape(expected_shape, + input_ty.getElementType()), + op.getInput(), GetI64ElementsAttr(begin_indices, &rewriter), + GetI64ElementsAttr(expected_shape, &rewriter), + GetI64ElementsAttr(strides, &rewriter)); + + // Last dim smaller than expected_dim, zero-pad the input + } else if (input_ty.getShape().back() < expected_dim) { + SmallVector no_padding(input_shape.size(), 0); + SmallVector padding(input_shape.size() - 1, 0); + padding.push_back(expected_dim - input_shape.back()); + Value zero = + GetScalarConstOfType(input_ty.getElementType(), loc, 0, &rewriter); + reshaped = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape(expected_shape, + input_ty.getElementType()), + op.getInput(), zero, GetI64ElementsAttr(no_padding, &rewriter), + GetI64ElementsAttr(padding, &rewriter), + GetI64ElementsAttr(no_padding, &rewriter)); + } + + rewriter.replaceOpWithNewOp( + op, op.getType(), reshaped, + FftTypeAttr::get(rewriter.getContext(), + symbolizeFftType(fft_string).value()), + rewriter.getI64TensorAttr(fft_length)); + return success(); + } +}; + +using ConvertRFFTOp = ConvertFFTOp; +using ConvertIRFFTOp = ConvertFFTOp; + +// The base class to convert TensorFlow FusedBatchNormGrad*Op to HLO +// BatchNormGradOp for training and a sequence of binary ops for inference. +// TODO(b/145536565): move to legalize_tf_patterns.td if it applies. +template +class ConvertFusedBatchNormGradBase + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(FusedBatchNormGradOpT op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value grad = op.getYBackprop(); + Value act = op.getX(); + Value scale = op.getScale(); + Value mean = op.getReserveSpace_1(); + Value var = op.getReserveSpace_2(); + + // TODO(b/141785544): Update this to not require static shapes. + // activation shape needs to be static to convert negative indices in + // TensorFlow to absolute indices required by HLO. + RankedTensorType act_type = mlir::dyn_cast(act.getType()); + if (!act_type) return failure(); + Type act_ele_type = act_type.getElementType(); + // To support mixed precision, the statistics type, which maybe more + // precise than the input types, are used for this op. + Type kernel_type = mlir::cast(scale.getType()).getElementType(); + grad = rewriter.create(loc, grad, kernel_type); + act = rewriter.create(loc, act, kernel_type); + + tensorflow::TensorFormat data_format; + if (!FormatFromString(op.getDataFormat().str(), &data_format)) + return op.emitOpError("invalid data format"); + + auto feature_dim_attr = getFeatureDimensionAttr(rewriter, data_format, act); + auto feature_dim = feature_dim_attr.getValue().getSExtValue(); + + // Gets the result values. + Value x_backprop, scale_backprop, offset_backprop; + if (op.getIsTraining()) { // training + // TODO(b/145536565): handle GPU logic separately. + // Infers the output type with the converted `act`. + Type feature_type = tensorflow::GetTypeFromTFTensorShape( + {GetDimSize(act_type, feature_dim)}, kernel_type); + + SmallVector operand_types = {act.getType(), feature_type, + feature_type}; + auto training_op = rewriter.create( + loc, operand_types, act, scale, mean, var, grad, op.getEpsilon(), + feature_dim); + + x_backprop = training_op.getResult(0); + + scale_backprop = training_op.getResult(1); + + offset_backprop = training_op.getResult(2); + } else { // inference + SmallVector non_feature_dims; + for (int64_t i = 0; i < act_type.getRank(); ++i) { + if (i == feature_dim) continue; + non_feature_dims.push_back(i); + } + auto reduce_dims = GetI64ElementsAttr(non_feature_dims, &rewriter); + auto scalar_broadcast_dims = rewriter.getDenseI64ArrayAttr({}); + + // scratch1 = rsqrt(var + epsilon) + RankedTensorType scalar_float = + tensorflow::GetTypeFromTFTensorShape({}, kernel_type); + auto epsilon = rewriter.create( + loc, DenseFPElementsAttr::get(scalar_float, {op.getEpsilon()})); + auto add_op = rewriter.create( + loc, var, epsilon.getResult(), scalar_broadcast_dims); + + Value scratch1 = rewriter.create(loc, add_op); + + // scratch2 = sum(y_backprop * (x - mean)) + auto sub_op = rewriter.create( + loc, act, + Broadcast1DToFeatureDim(loc, act, mean, feature_dim, rewriter)); + auto weighted_grad = rewriter.create(loc, grad, sub_op); + Value scratch2 = + ApplyReduction(loc, weighted_grad, reduce_dims, &rewriter); + + // x_backprop = y_backprop * (scale * scratch1) + auto scaled_grad = + rewriter.create(loc, op.getScale(), scratch1); + x_backprop = rewriter.create( + loc, grad, + Broadcast1DToFeatureDim(loc, act, scaled_grad, feature_dim, + rewriter)); + + // scale_backprop = scratch2 * scratch1 + scale_backprop = rewriter.create(loc, scratch1, scratch2); + + // offset_backprop = sum(y_backprop) + offset_backprop = ApplyReduction(loc, grad, reduce_dims, &rewriter); + } + + x_backprop = rewriter.create(loc, x_backprop, act_ele_type); + Value last_val[2]; + if (op.getResult(3).use_empty() && op.getResult(4).use_empty()) { + // It doesn't matter what values we provide for the last 2 results. + last_val[0] = last_val[1] = op.getX(); + } else { + auto const_val = rewriter.create( + op.getLoc(), DenseElementsAttr::get( + tensorflow::GetTypeFromTFTensorShape( + {0}, getElementTypeOrSelf(op.getResult(3))), + 0.0)); + auto maybe_cast = [&](Value val, Type t) -> Value { + if (val.getType() == t) return val; + return rewriter.create(op.getLoc(), t, val); + }; + last_val[0] = maybe_cast(const_val, op.getResult(3).getType()); + last_val[1] = maybe_cast(const_val, op.getResult(4).getType()); + } + rewriter.replaceOp( + op, {/*x_backprop=*/x_backprop, + /*scale_backprop=*/scale_backprop, + /*offset_backprop=*/offset_backprop, last_val[0], last_val[1]}); + return success(); + } +}; + +using ConvertFusedBatchNormGradOp = + ConvertFusedBatchNormGradBase; +using ConvertFusedBatchNormGradV2Op = + ConvertFusedBatchNormGradBase; +using ConvertFusedBatchNormGradV3Op = + ConvertFusedBatchNormGradBase; + +// Converts TensorFlow FusedBatchNormV3Op to either HLO BatchNormTrainingOp or +// HLO BatchNormInferenceOp, depending on the value of the 'is_training' +// parameter. +template +class ConvertFusedBatchNormBase : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(FusedBatchNormOpT op, + PatternRewriter &rewriter) const override { + tensorflow::TensorFormat data_format; + if (!FormatFromString(op.getDataFormat().str(), &data_format)) + return op.emitOpError("invalid data format"); + + auto feature_dim = + getFeatureDimensionAttr(rewriter, data_format, op.getX()); + + auto input_type_tensor = mlir::cast(op.getX().getType()); + auto input_element_type = input_type_tensor.getElementType(); + + auto scale_type_tensor = mlir::cast(op.getScale().getType()); + auto scale_element_type = scale_type_tensor.getElementType(); + + auto mean_type_tensor = mlir::cast(op.getMean().getType()); + auto mean_element_type = mean_type_tensor.getElementType(); + // In the training case, dimensions of input tensors must be static. + if (op.getIsTraining() && (!input_type_tensor.hasStaticShape() || + !scale_type_tensor.hasStaticShape() || + !mean_type_tensor.hasStaticShape())) + return failure(); + + // TODO(b/69928690): Support mixed precision in the XLA batch + // normalization operators. As a workaround, create a new x with the same + // element type as scale (which may be more precise than the input type). + Value bn_train_input = rewriter.create( + op.getLoc(), op.getX(), scale_element_type); + TensorType bn_train_input_type_tensor = + mlir::cast(bn_train_input.getType()); + + if (op.getIsTraining()) { + // Training case. + auto operand_shape = bn_train_input_type_tensor.getShape(); + // The mean and variance are each 1 dimensional arrays the size of the + // feature dimension, with the same element type as the operand (x). + // This shape must be constructed manually because the mean and variance + // inputs are empty in the training case. + Type mean_var_type = tensorflow::GetTypeFromTFTensorShape( + {operand_shape[feature_dim.getInt()]}, scale_element_type); + // Op result type is a tuple of 3 values: output with same shape as input; + // batch_mean, and batch_var. + SmallVector operand_types = {bn_train_input_type_tensor, + mean_var_type, mean_var_type}; + auto bn_train_op = rewriter.create( + op.getLoc(), operand_types, bn_train_input, op.getScale(), + op.getOffset(), op.getEpsilon(), feature_dim.getInt()); + // HLO op outputs a tuple of tensors. Extract those results. + Value y_out = bn_train_op.getResult(0); + Value batch_mean = bn_train_op.getResult(1); + Value reserve_space_1 = batch_mean; + Value batch_variance = bn_train_op.getResult(2); + + // Apply Bessel's correction on the variance. + int total_input_size = bn_train_input_type_tensor.getNumElements(); + int total_scale_size = scale_type_tensor.getNumElements(); + int sample_size = + total_scale_size > 0 ? total_input_size / total_scale_size : 0; + int sample_size_minus_one = std::max(1, sample_size - 1); + double factor = static_cast(sample_size) / + static_cast(sample_size_minus_one); + auto factor_const_op = rewriter.create( + op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor)); + + Value corrected_variance = rewriter.create( + op.getLoc(), batch_variance.getType(), batch_variance, + factor_const_op, /*broadcast_dimensions=*/DenseI64ArrayAttr()); + + // Convert back to input type to stay aligned with expected output type + // for TF op. + y_out = rewriter.create(op.getLoc(), y_out, + input_element_type); + + float exponential_avg_factor = + op.getExponentialAvgFactor().convertToFloat(); + if (exponential_avg_factor != 1.0f) { + auto alpha = rewriter.create( + op.getLoc(), rewriter.getFloatAttr(mean_element_type, + 1.0f - exponential_avg_factor)); + auto beta = rewriter.create( + op.getLoc(), + rewriter.getFloatAttr(mean_element_type, exponential_avg_factor)); + + // new_running_mean = alpha * old_mean + beta * batch_mean. + auto alpha_mul_old_mean = rewriter.create( + op.getLoc(), op.getMean().getType(), alpha, op.getMean(), + /*broadcast_dimensions=*/DenseI64ArrayAttr()); + auto beta_mul_batch_mean = rewriter.create( + op.getLoc(), batch_mean.getType(), beta, batch_mean, + /*broadcast_dimensions=*/DenseI64ArrayAttr()); + batch_mean = rewriter.create( + op.getLoc(), alpha_mul_old_mean, beta_mul_batch_mean, + /*broadcast_dimensions=*/DenseI64ArrayAttr()); + + // new_running_variance = alpha * old_variance + beta * batch_variance. + auto alpha_mul_old_variance = rewriter.create( + op.getLoc(), op.getVariance().getType(), alpha, op.getVariance(), + /*broadcast_dimensions=*/DenseI64ArrayAttr()); + auto beta_mul_batch_variance = rewriter.create( + op.getLoc(), corrected_variance.getType(), beta, corrected_variance, + /*broadcast_dimensions=*/DenseI64ArrayAttr()); + corrected_variance = rewriter.create( + op.getLoc(), alpha_mul_old_variance, beta_mul_batch_variance, + /*broadcast_dimensions=*/DenseI64ArrayAttr()); + } + + if (std::is_same::value) { + // FusedBatchNormV2 expects 4 outputs. + // Outputs 3 and 4 are currently marked as "reserved spaces 1 and 2". + // They are used to pass the per-batch mean and variance to the + // gradiant. Here we maintain the same behavior by setting them to the + // mean and variance calculated by BatchNormTraining. + rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean, + /*batch_variance=*/corrected_variance, + /*reserve_space_1=*/reserve_space_1, + /*reserve_space_2=*/batch_variance}); + } else { // TF::FusedBatchNormV3Op + // For FusedBatchNormV3Op, also create a constant tensor to forward to + // last reserve_space_3 output. + auto reserve_space_3_type = + mlir::cast(op.getResult(5).getType()); + int num_elements = reserve_space_3_type.hasStaticShape() + ? reserve_space_3_type.getNumElements() + : 0; + auto const_attr_type = tensorflow::GetTypeFromTFTensorShape( + {num_elements}, getElementTypeOrSelf(reserve_space_3_type)); + Value dummy_const = rewriter.create( + op.getLoc(), DenseElementsAttr::get(const_attr_type, 0.0)); + if (const_attr_type != reserve_space_3_type) + dummy_const = rewriter.create( + op.getLoc(), reserve_space_3_type, dummy_const); + rewriter.replaceOp(op, {y_out, /*batch_mean=*/batch_mean, + /*batch_variance=*/corrected_variance, + /*reserve_space_1=*/reserve_space_1, + /*reserve_space_2=*/batch_variance, + /*reserve_space_3=*/dummy_const}); + } + } else { // Inference case. + auto bn_train_op = rewriter.create( + op.getLoc(), + /*result_type=*/bn_train_input_type_tensor, bn_train_input, + op.getScale(), op.getOffset(), op.getMean(), op.getVariance(), + op.getEpsilon(), feature_dim.getInt()); + + // Convert back to input type to stay aligned with expected output type + // for TF op. + auto y_out = rewriter.create(op.getLoc(), bn_train_op, + input_element_type); + + // The mean, variance, and reserved space outputs of the batch norm op are + // not used for inference. It doesn't matter what values we provide for + // the last 5 results as long as they are of the same type. Forward + // input mean and variance to output mean, variance, reserved_space_1 and + // reserved_space_2. + if (std::is_same::value) { + rewriter.replaceOp(op, {/*y=*/y_out, + /*batch_mean=*/op.getMean(), + /*batch_variance=*/op.getVariance(), + /*reserve_space_1=*/op.getMean(), + /*reserve_space_2=*/op.getVariance()}); + } else { + // For FusedBatchNormV3Op, also create a constant tensor to forward to + // last reserve_space_3 output. + auto reserve_space_3_type = + mlir::cast(op.getResult(5).getType()); + int num_elements = reserve_space_3_type.hasStaticShape() + ? reserve_space_3_type.getNumElements() + : 0; + auto const_attr_type = tensorflow::GetTypeFromTFTensorShape( + {num_elements}, getElementTypeOrSelf(reserve_space_3_type)); + Value dummy_const = rewriter.create( + op.getLoc(), DenseElementsAttr::get(const_attr_type, 0.0)); + if (const_attr_type != reserve_space_3_type) + dummy_const = rewriter.create( + op.getLoc(), reserve_space_3_type, dummy_const); + rewriter.replaceOp(op, {/*y=*/y_out, + /*batch_mean=*/op.getMean(), + /*batch_variance=*/op.getVariance(), + /*reserve_space_1=*/op.getMean(), + /*reserve_space_2=*/op.getVariance(), + /*reserve_space_3=*/dummy_const}); + } + } + return success(); + } +}; + +using ConvertFusedBatchNormV2Op = + ConvertFusedBatchNormBase; +using ConvertFusedBatchNormV3Op = + ConvertFusedBatchNormBase; + +using PaddingArray = std::vector>; + +// Returns padding values for ReduceWindow op as a vector of pairs. +// +// Requires padding to be either 'SAME' or 'VALID' and the number of input +// dimensions to be equal to the size of window dimensions and window strides. +template +static PaddingArray GetReduceWindowPaddingAsArray( + llvm::ArrayRef input_dims, ArrayAttr window_dims, + ArrayAttr window_strides, StringRef padding, Builder *builder) { + if (padding == "VALID") { + return PaddingArray(num_dims, std::make_pair(0, 0)); + } + assert(padding == "SAME"); + llvm::SmallVector input_shape, window_shape, strides; + input_shape.reserve(input_dims.size()); + window_shape.reserve(window_shape.size()); + strides.reserve(window_strides.size()); + + for (const auto &dim : input_dims) input_shape.push_back(dim); + for (Attribute attr : window_dims) + window_shape.push_back(mlir::cast(attr).getInt()); + for (Attribute attr : window_strides) + strides.push_back(mlir::cast(attr).getInt()); + + PaddingArray paddings = ::xla::MakePadding(input_shape, window_shape, strides, + ::xla::Padding::kSame); + return paddings; +} + +// Same as GetReduceWindowPaddingAsArray but returns padding as +// DenseIntElementsAttr. Returns empty attribute for `VALID` padding. +template +static DenseIntElementsAttr GetReduceWindowPaddingAsAttr( + llvm::ArrayRef input_dims, ArrayAttr window_dims, + ArrayAttr window_strides, StringRef padding, Builder *builder) { + if (padding == "VALID") return {}; + assert(padding == "SAME"); + PaddingArray paddings = GetReduceWindowPaddingAsArray( + input_dims, window_dims, window_strides, padding, builder); + int64_t rank = paddings.size(); + llvm::SmallVector flatten_paddings(rank * 2); + for (int i = 0; i < rank; i++) { + flatten_paddings[2 * i] = paddings[i].first; + flatten_paddings[2 * i + 1] = paddings[i].second; + } + return DenseIntElementsAttr::get(tensorflow::GetTypeFromTFTensorShape( + {rank, 2}, builder->getIntegerType(64)), + flatten_paddings); +} + +// Helper function for dividing each entry of `pooled` by the count of its +// corresponding window, i.e., the number of non-padding entries of the window +// which an `AvgPool` operation performed on an `input_shape`-tensor would map +// to this entry, depending on `ksize` and `strides`. This function is used for +// `AvgPool` and `AvgPoolGrad` legalizations. +// `zero` is passed as a parameter because it can be reused from caller level. +// `pooled` must have `RankedTensorType`. +template +Operation *AvgPoolDivideByCount( + Value pooled, const SmallVector &input_shape, + const SmallVector &ksize, + const SmallVector &strides, OpTy op, Value zero, + PatternRewriter &rewriter) { + Location loc = op.getLoc(); + RankedTensorType pooled_type = mlir::cast(pooled.getType()); + Type element_type = pooled_type.getElementType(); + Operation *result = nullptr; + RankedTensorType orig_input_type = + tensorflow::GetTypeFromTFTensorShape(input_shape, element_type); + + if (op.getPadding() == "VALID") { + // All window counts are equal here because we don't have padding + // (each entry of `pooled` corresponds to a window that consists of + // original input entries only). + int64_t window_count = std::accumulate(ksize.begin(), ksize.end(), 1, + std::multiplies()); + // Divide `pooled` by window counts. + Value divisor = + GetScalarConstOfType(element_type, loc, window_count, &rewriter); + auto scalar_broadcast_dims = rewriter.getDenseI64ArrayAttr({}); + result = rewriter.create( + loc, pooled_type, pooled, divisor, scalar_broadcast_dims); + } else { + assert(op.getPadding() == "SAME"); + // For SAME padding, only original entries that contributed to a window + // are counted for the average of this window, not padded entries. + + // Build all-ones tensor of same shape as the original input. + ElementsAttr splat = hlo::getSplat(&rewriter, orig_input_type, 1); + auto all_ones_tensor = rewriter.create(loc, splat); + + // Get padding for the input. + DenseIntElementsAttr input_padding_attr = + GetReduceWindowPaddingAsAttr(input_shape, op.getKsize(), + op.getStrides(), op.getPadding(), + &rewriter); + + // Count the 1's in each window, using the same padding as for the input, + // which gives us the window counts by which `pooled` needs to be divided. + auto divisor = rewriter.create( + loc, pooled_type, + /*operand=*/all_ones_tensor, + /*init_value=*/zero, + /*window_dimensions=*/GetI64ElementsAttr(op.getKsize()), + /*window_strides=*/GetI64ElementsAttr(op.getStrides()), + /*base_dilations=*/DenseIntElementsAttr(), + /*window_dilations=*/DenseIntElementsAttr(), + /*padding=*/input_padding_attr); + BuildReduceBody(element_type, &divisor.getBody(), &rewriter); + + // Divide `pooled` by window counts. + result = rewriter.create(loc, pooled_type, pooled, + divisor.getResult(0)); + } + return result; +} + +Value GetAvgPoolInput(TF::AvgPoolOp op) { return op.getValue(); } +Value GetAvgPoolInput(TF::AvgPool3DOp op) { return op.getInput(); } + +// Converts AvgPool op to HLO ReduceWindow op by setting appropriate window +// dimensions with add as the reduction function. The reduction result is +// then divided by the number of elements in the window. +template +class ConvertAvgPoolOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + Value input_value = GetAvgPoolInput(op); + auto input_type = mlir::dyn_cast(input_value.getType()); + if (!input_type) return failure(); + + // We will do accumulation first; use a larger bitwidth if suitable. + Type input_element_type = input_type.getElementType(); + Type sum_element_type = GetSumAccumulationType(input_element_type); + Type result_type; + + // The result type for reduction and division with the proper element type. + if (auto ranked_type = mlir::dyn_cast(op.getType())) + result_type = tensorflow::GetTypeFromTFTensorShape(ranked_type.getShape(), + sum_element_type); + else + result_type = UnrankedTensorType::get(sum_element_type); + + // Convert if we need enlarge the element type's bitwidth. + if (input_element_type != sum_element_type) + input_value = rewriter.create(op.getLoc(), input_value, + sum_element_type); + + // Create the ReduceWindow op. + Value init = + GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter); + DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr( + input_type.getShape(), op.getKsize(), op.getStrides(), op.getPadding(), + &rewriter); + auto reduce = rewriter.create( + op.getLoc(), result_type, input_value, init, + GetI64ElementsAttr(op.getKsize()), GetI64ElementsAttr(op.getStrides()), + /*base_dilations=*/DenseIntElementsAttr(), + /*window_dilations=*/DenseIntElementsAttr(), paddings_attr); + BuildReduceBody(sum_element_type, &reduce.getBody(), &rewriter); + + // Count the number of elements in the window. The following calculation + // is only valid for no paddings. + SmallVector input_shape( + llvm::to_vector(input_type.getShape())); + SmallVector ksize, strides; + GetI64ArrayAttrValues(op.getKsize(), &ksize); + GetI64ArrayAttrValues(op.getStrides(), &strides); + + Operation *result_op = AvgPoolDivideByCount( + reduce.getResult(0), input_shape, ksize, strides, op, init, rewriter); + + // Convert back if we enlarged the element type's bitwidth. + Value result = result_op->getOpResult(0); + if (input_element_type != sum_element_type) + result = + rewriter.create(op.getLoc(), result, input_element_type); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +using ConvertAvgPool2DOp = ConvertAvgPoolOp; +using ConvertAvgPool3DOp = ConvertAvgPoolOp; + +// `AvgPoolGradOp` is converted to the following operations: +// 1. Divide each entry of the output gradient (the gradient for the previous +// layer in backpropagation order) by the count of the corresponding window +// (i.e., the number of non-padding entries of the window which `AvgPool` +// has mapped to this entry in forward propagation). +// 2. Add appropriate interior and exterior padding for step 3 (see example +// below). +// 3. Convolve the result of step 2. with a kernel consisting of 1's (same shape +// as windows) and stride 1 in each dimension. This is implemented as a +// `ReduceWindowOp` with `AddOp` as body. +// +// Example: +// Let f : R^4 -> R^2 be an average pool function with window size 3, stride 2, +// and SAME padding with 0's. It is defined by +// f(x) = [ (x_1 + x_2 + x_3) / 3 ] ( x = (x_1, x_2, x_3, x_4) ) +// [ (x_3 + x_4 + 0) / 2 ] (the 0 results from right padding) +// Note that for SAME padding in `AvgPool` the padded entries are not counted +// for the average, this is why the second denominator is 2 and not 3. +// The Jacobian Df is +// [ 1/3 1/3 1/3 0 ] +// [ 0 0 1/2 1/2 ] +// +// Note that the Jacobian is constant (this is why `ConvertAvgPoolGradOp` only +// needs the original input shape and not the tensor as argument). +// Let v = [ 4 6 ]^T be the output gradient (^T = transposed). Then the +// average pool gradient is given by +// Df^T * v = [ 4/3 4/3 13/3 3 ]^T +// Instead of a matrix-vector-multiplication we can utilize the sparsity and +// structure of Df by using the 3-step approach from above: +// 1. Divide output gradient v by window counts: [ 4/3 6/2 ]^T +// 2. Add appropriate padding: [ 0 0 4/3 0 3 0 ]^T +// 3. Convolve with kernel [ 1 1 1 ]: [ 4/3 4/3 11/3 3 ]^T +// +// Note that the padding in step 2. is chosen in such a way that the subsequent +// convolution produces the gradient. Higher dimensions, different padding, and +// different windows/strides work in a similar way, the main difference is in +// the computation of the paddings in step 2. +// +// For more details on backpropagation for convolution of which `AvgPoolGrad` +// is a special case see `tensorflow/core/kernels/conv_grad_ops.h`. +// `tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir` has more +// examples for different cases. +template +class ConvertAvgPoolGradOp : public OpRewritePattern { + using DimVector = SmallVector; + + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + tensorflow::TensorFormat data_format; + if (!FormatFromString(op.getDataFormat().str(), &data_format)) { + return op.emitOpError("invalid data format"); + } + // `out_grad` is the gradient that was propagated via backpropagation from + // the output layer. + Value out_grad = op.getGrad(); + auto out_grad_type = mlir::dyn_cast(out_grad.getType()); + if (!out_grad_type) { + return failure(); + } + Type element_type = out_grad_type.getElementType(); + DenseIntElementsAttr orig_input_shape_attr; + if (!matchPattern(op.getOrigInputShape(), + m_Constant(&orig_input_shape_attr))) { + return failure(); + } + auto orig_input_shape_values = orig_input_shape_attr.getValues(); + DimVector orig_input_shape(orig_input_shape_values.begin(), + orig_input_shape_values.end()); + DimVector ksize, strides; + GetI64ArrayAttrValues(op.getKsize(), &ksize); + GetI64ArrayAttrValues(op.getStrides(), &strides); + Value zero = GetScalarConstOfType(element_type, loc, 0, &rewriter); + + auto out_grad_divided = AvgPoolDivideByCount( + out_grad, orig_input_shape, ksize, strides, op, zero, rewriter); + + // Get same padding as for original input. + PaddingArray orig_padding = GetReduceWindowPaddingAsArray( + orig_input_shape, op.getKsize(), op.getStrides(), op.getPadding(), + &rewriter); + + // Add padding around `out_grad_divided` values in such a way that the + // subsequent `ReduceWindowOp` produces the gradient. + DimVector out_grad_shape( + llvm::to_vector(out_grad_type.getShape())); + DimVector low_padding(num_dims, 0); + DimVector high_padding(num_dims, 0); + DimVector interior_padding(num_dims, 0); + constexpr int num_spatial_dims = num_dims - 2; + for (int i = 0; i < num_spatial_dims; ++i) { + int dim = tensorflow::GetTensorSpatialDimIndex(num_dims, data_format, i); + int orig_input_shape_padded_in_dim = orig_input_shape[dim] + + orig_padding[dim].first + + orig_padding[dim].second; + // Set interior padding such that neighboring entries from + // `out_grad_divided` have distance `strides[dim]` from each other in + // every dimension. + interior_padding[dim] = strides[dim] - 1; + // Set exterior padding in the same way as for convolution gradient + // computation. + auto status = ::xla::ConvGradExtractAndVerifyDimension( + /*input_size=*/orig_input_shape_padded_in_dim, + /*filter_size=*/ksize[dim], + /*output_size=*/out_grad_shape[dim], + /*dilation=*/1, + /*stride=*/strides[dim], + /*padding=*/::xla::Padding::kValid); + if (!status.ok()) { + return failure(); + } + ::xla::SpatialDimensionOutputSizeAndPadding &conv_grad_spatial_dim = + status.value(); + // Subtract the original exterior padding since it doesn't contribute to + // the gradient. Note that we save one `PadOp` and some unnecessary kernel + // computations, compared to the `xla::AvgPoolGrad` implementation, by + // subtracting the original exterior padding before `ReduceWindowOp` + // instead of trimming the result of `ReduceWindowOp` (the final result is + // the same because all strides are 1). + low_padding[dim] = + conv_grad_spatial_dim.pad_before - orig_padding[dim].first; + high_padding[dim] = + conv_grad_spatial_dim.pad_after - orig_padding[dim].second; + + // Update `out_grad_shape` to result shape of following `PadOp`. + out_grad_shape[dim] = low_padding[dim] + high_padding[dim] + + (out_grad_shape[dim] - 1) * strides[dim] + 1; + } + Value reduce_window_input = rewriter.create( + loc, tensorflow::GetTypeFromTFTensorShape(out_grad_shape, element_type), + /*operand=*/out_grad_divided->getOpResult(0), + /*padding_value=*/zero, + /*edge_padding_low=*/GetI64ElementsAttr(low_padding, &rewriter), + /*edge_padding_high=*/GetI64ElementsAttr(high_padding, &rewriter), + /*interior_padding=*/GetI64ElementsAttr(interior_padding, &rewriter)); + + // Compute result by convolving `reduce_window_input` with an all-ones + // kernel, using `ReduceWindowOp` with `AddOp` body. + + Type sum_element_type = GetSumAccumulationType(element_type); + if (element_type != sum_element_type) { + // Convert to appropriate sum accumulation type to avoid precision loss. + reduce_window_input = rewriter.create(loc, reduce_window_input, + sum_element_type); + zero = GetScalarConstOfType(sum_element_type, loc, 0, &rewriter); + } + auto ones = GetI64ElementsAttr(DimVector(num_dims, 1), &rewriter); + auto reduce_window_op = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape(orig_input_shape, + sum_element_type), + /*operand=*/reduce_window_input, + /*init_value=*/zero, + /*window_dimensions=*/GetI64ElementsAttr(op.getKsize()), + /*window_strides=*/ones, + /*base_dilations=*/DenseIntElementsAttr(), + /*window_dilations=*/DenseIntElementsAttr(), + /*padding=*/DenseIntElementsAttr()); + BuildReduceBody(sum_element_type, &reduce_window_op.getBody(), + &rewriter); + Value result = reduce_window_op.getResult(0); + + if (element_type != sum_element_type) { + // Convert back to original element type. + result = rewriter.create(op.getLoc(), result, element_type); + } + rewriter.replaceOp(op, {result}); + return success(); + } +}; + +using ConvertAvgPool2DGradOp = + ConvertAvgPoolGradOp; +using ConvertAvgPool3DGradOp = + ConvertAvgPoolGradOp; + +// Converts MaxPool op to HLO ReduceWindow op by setting appropriate window +// dimensions with max as the reduction function. +// +// Sample result for VALID padding mode: +// +// %init = arith.constant dense<...> : tensor +// %max_pool = "mhlo.reduce"(%inp, %init) ["mhlo.maximum"] +// {window_dimensions = ..., window_strides = ... } +// +template +class ConvertMaxPoolOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + Type element_type = + mlir::cast(op.getInput().getType()).getElementType(); + if (!element_type.isSignlessIntOrFloat()) return failure(); + tensorflow::Padding padding; + if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) + return failure(); + if (padding == tensorflow::Padding::EXPLICIT) { + return failure(); + } + Location loc = op.getLoc(); + ConstantOp init = GetScalarLimitConstOfType( + element_type, loc, hlo::kInfinityLowest, &rewriter); + + auto input_ty = mlir::dyn_cast(op.getInput().getType()); + if (!input_ty) return failure(); + DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr( + input_ty.getShape(), op.getKsize(), op.getStrides(), op.getPadding(), + &rewriter); + auto reduce = rewriter.create( + loc, op.getType(), op.getInput(), init, + GetI64ElementsAttr(op.getKsize()), GetI64ElementsAttr(op.getStrides()), + /*base_dilations=*/DenseIntElementsAttr(), + /*window_dilations=*/DenseIntElementsAttr(), paddings_attr); + BuildReduceBody(element_type, &reduce.getBody(), &rewriter); + + rewriter.replaceOp(op, reduce.getResult(0)); + return success(); + } +}; + +using ConvertMaxPool2DOp = ConvertMaxPoolOp; +using ConvertMaxPool3DOp = ConvertMaxPoolOp; + +// Converts tf.Select (SelectV1) to mhlo.select. It has optional broadcasting on +// the condition only. +class ConvertSelectOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::SelectOp op, + PatternRewriter &rewriter) const override { + // This lowering only works on ranked types. + auto cond_type = + mlir::dyn_cast(op.getCondition().getType()); + auto then_type = + mlir::dyn_cast(op.getThenValue().getType()); + auto else_type = + mlir::dyn_cast(op.getElseValue().getType()); + if (!cond_type || !then_type || !else_type) { + return failure(); + } + + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + Value cond_shape = b.createOrFold(op.getCondition()); + Value then_shape = b.createOrFold(op.getThenValue()); + Value else_shape = b.createOrFold(op.getElseValue()); + + // First check that the `then` and `else` shapes are the equal. + Value assumption = + b.createOrFold(ValueRange{then_shape, else_shape}); + // For a vector cond we also verify that the majormost dim of `then` matches + // the vector size. To do that split off the first dim of `then`. + bool needs_broadcast = cond_type.getRank() == 1 && then_type.getRank() != 1; + Value then_shape_split = then_shape; + if (needs_broadcast) { + Value const_one = b.create(1); + Type extent_first = shape::getExtentTensorType(b.getContext(), 1); + Type extent_second = + shape::getExtentTensorType(b.getContext(), then_type.getRank() - 1); + SmallVector then_split; + b.createOrFold(then_split, + TypeRange{extent_first, extent_second}, + then_shape, const_one); + then_shape_split = then_split[0]; + } + // If the condition is not a scalar, check that it matches the other shapes. + if (cond_type.getRank() > 0) { + Value eq_cstr = b.createOrFold( + ValueRange{cond_shape, then_shape_split}); + auto witness = shape::WitnessType::get(b.getContext()); + assumption = b.createOrFold( + witness, ValueRange{assumption, eq_cstr}); + } + auto result_type = mlir::cast(op.getResult().getType()); + auto assuming_op = + b.create(ArrayRef{result_type}, assumption); + + OpBuilder::InsertionGuard guard(b); + b.createBlock(&assuming_op.getDoRegion()); + + // Broadcast the cond if necessary. + Value cond = op.getCondition(); + if (needs_broadcast) { + Value result_extents = b.create( + GetExtentsTensorTypeFor(result_type), then_shape); + cond = b.create( + tensorflow::GetTypeFromTFTensorShape(result_type.getShape(), + b.getI1Type()), + cond, result_extents, + GetI64ElementsAttrForSeq(0, cond_type.getRank(), &b)); + } + Value select = b.create( + result_type, cond, op.getThenValue(), op.getElseValue()); + b.create(select); + rewriter.replaceOp(op, {assuming_op.getResult(0)}); + return success(); + } +}; + +// Converts the tf.Slice op into mhlo.real_dynamic_slice +// TODO(disc): To recover static special case's performance with folding and +// canonicalization. +class ConvertSliceOpDynamic : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::SliceOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = op.getInput(); + Value begin_indices = op.getBegin(); + Value sizes = op.getSize(); + + auto input_ty = mlir::dyn_cast(input.getType()); + auto begin_type = mlir::dyn_cast(begin_indices.getType()); + auto size_type = mlir::dyn_cast(sizes.getType()); + + if (!input_ty || !begin_type || !size_type || + !begin_type.hasStaticShape() || !size_type.hasStaticShape() || + begin_type.getRank() != 1 || size_type.getRank() != 1) { + return failure(); + } + // TODO(disc): remove static shape check once folding/canonicalization func + // added + DenseIntElementsAttr size_attr; + if (matchPattern(op.getSize(), m_Constant(&size_attr))) { + return failure(); + } + + int rank = begin_type.getDimSize(0); + auto shape_scalar_type = begin_type.getElementType(); + Value one = rewriter.create(loc, 1); + SmallVector stride_values(rank, one); + SmallVector end_values; + SmallVector begin_values; + end_values.reserve(rank); + for (int i = 0; i < rank; ++i) { + SmallVector indices; + indices.push_back(rewriter.create(loc, i)); + auto begin_value = + rewriter.create(loc, begin_indices, indices); + auto size_value = rewriter.create(loc, sizes, indices); + Value minus_one = rewriter.create( + loc, shape_scalar_type, + rewriter.create(loc, -1)); + auto is_minus_one = rewriter.create( + loc, arith::CmpIPredicate::eq, size_value, minus_one); + Value end_value = + rewriter.create(loc, begin_value, size_value); + auto dim_value = rewriter.create( + loc, shape_scalar_type, + rewriter.create(loc, input, i)); + end_value = rewriter.create(loc, is_minus_one, + dim_value, end_value); + auto end_value_casted = rewriter.create( + loc, rewriter.getIndexType(), end_value); + end_values.push_back(end_value_casted); + + auto begin_value_casted = rewriter.create( + loc, rewriter.getIndexType(), begin_value); + begin_values.push_back(begin_value_casted); + } + auto index_ty = rewriter.getIndexType(); + auto start_indices = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(begin_values.size())}, index_ty), + begin_values); + auto end_indices = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(end_values.size())}, index_ty), + end_values); + auto stride_indices = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(stride_values.size())}, index_ty), + stride_values); + + auto d_slice = rewriter.create( + loc, op.getOperation()->getResult(0).getType(), input, start_indices, + end_indices, stride_indices); + rewriter.replaceOp(op, d_slice.getOperation()->getResults()); + return success(); + } +}; + +static void BroadcastBatchMatMulV2Operands(Value lhs, Value rhs, Location loc, + Value *out_lhs, Value *out_rhs, + PatternRewriter *rewriter) { + // The dimension structure of the relevant operands to a tf.BatchMatMulV2 is: + // - lhs: [LHSBATCHDIMS..., LHSROWS, LHSCOLS] + // - rhs: [RHSBATCHDIMS..., RHSROWS, RHSCOLS] + // - result: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, RHSCOLS] + // To perform the matmul, we need to first broadcast lhs and rhs to a common + // set of leading dimensions before doing the actual matmul. + // That's what the code below does. + // In particular, we populate out_lhs and out_rhs to have dimension structure: + // - out_lhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., LHSROWS, LHSCOLS] + // - out_rhs: [broadcast(LHSBATCHDIMS, RHSBATCHDIMS)..., RHSROWS, RHSCOLS] + // To do this, we need to calculate those output shapes, which involves + // slicing off the leading batch dims of each operand, broadcasting them, + // then concatenating the broadcasted leading dims back to the row/col dims. + // Finally, we create a TF::BroadcastTo op that does the actual broadcast. + + // TODO(silvasean): Reduce duplication across reified shape calculations and + // the static computation of output types needed to create ops. + Value lhs_shape = rewriter->create(loc, lhs); + Value rhs_shape = rewriter->create(loc, rhs); + Value const_neg2 = + rewriter->create(loc, rewriter->getIndexAttr(-2)); + auto shape_type = shape::ShapeType::get(rewriter->getContext()); + auto lhs_splitted = rewriter->create( + loc, TypeRange{shape_type, shape_type}, lhs_shape, const_neg2); + auto rhs_splitted = rewriter->create( + loc, TypeRange{shape_type, shape_type}, rhs_shape, const_neg2); + auto lhs_type = mlir::cast(lhs.getType()); + auto rhs_type = mlir::cast(rhs.getType()); + // The last two dimensions are the matrix row/col dimensions. Don't broadcast + // them. + SmallVector result_batch_shape_compile_time_extents; + mlir::OpTrait::util::getBroadcastedShape( + lhs_type.getShape().drop_back(2), rhs_type.getShape().drop_back(2), + result_batch_shape_compile_time_extents); + auto result_batch_shape = rewriter->create( + loc, shape_type, lhs_splitted.getHead(), rhs_splitted.getHead(), + /*error=*/nullptr); + // Lambda which handles the broadcasting of one side to the common + // leading-batch dimensions. + auto broadcast_one_side = [&](Value side, RankedTensorType type, + Value tail_shape, Value *out_side) { + ArrayRef matrix_dims = type.getShape().take_back(2); + auto result_shape = result_batch_shape_compile_time_extents; + result_shape.append(matrix_dims.begin(), matrix_dims.end()); + auto result_type = tensorflow::GetTypeFromTFTensorShape( + result_shape, type.getElementType()); + auto shape = rewriter->create( + loc, shape_type, result_batch_shape, tail_shape); + auto shape_tensor = rewriter->create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(result_shape.size())}, + rewriter->getIndexType()), + shape); + *out_side = rewriter->create(loc, result_type, side, + shape_tensor); + }; + broadcast_one_side(lhs, lhs_type, lhs_splitted.getTail(), out_lhs); + broadcast_one_side(rhs, rhs_type, rhs_splitted.getTail(), out_rhs); +} + +class ConvertBatchMatMulV2Op : public OpRewritePattern { + public: + // TODO(hinsu): Legalize this op to Einsum op. HLO Einsum op needs to be moved + // to CHLO and it is missing legalization to MHLO. Once that is done, this + // pattern's benefit can be changed back to one as well as the fallback + // lowering pattern for the op can be removed. + // + // Set benefit of this pattern to zero to prefer the fallback pattern when + // available and applicable. That pattern avoids broadcast on operands and is + // therefore faster. + // + // Native legalization for BatchMatMulV3 needs to be added as well. + explicit ConvertBatchMatMulV2Op(MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/0) {} + + LogicalResult matchAndRewrite(TF::BatchMatMulV2Op op, + PatternRewriter &rewriter) const override { + Value lhs = op.getX(); + Value rhs = op.getY(); + auto lhs_type = mlir::dyn_cast(lhs.getType()); + auto rhs_type = mlir::dyn_cast(rhs.getType()); + if (!lhs_type || !rhs_type) return failure(); + if (mlir::isa(lhs_type.getElementType()) && op.getAdjX()) { + lhs = rewriter.create(op.getLoc(), lhs_type, lhs); + } + if (mlir::isa(rhs_type.getElementType()) && op.getAdjY()) { + rhs = rewriter.create(op.getLoc(), rhs_type, rhs); + } + + // Broadcast both operands. + BroadcastBatchMatMulV2Operands(lhs, rhs, op.getLoc(), &lhs, &rhs, + &rewriter); + lhs_type = mlir::cast(lhs.getType()); + rhs_type = mlir::cast(rhs.getType()); + assert(lhs_type.getRank() == rhs_type.getRank()); + int64_t rank = lhs_type.getRank(); + auto batch_dimensions = llvm::to_vector<4>(llvm::seq(0, rank - 2)); + auto lhs_contracting_dimensions = llvm::to_vector<4>( + llvm::ArrayRef({op.getAdjX() ? rank - 2 : rank - 1})); + auto rhs_contracting_dimensions = llvm::to_vector<4>( + llvm::ArrayRef({op.getAdjY() ? rank - 1 : rank - 2})); + auto dimension_numbers = DotDimensionNumbersAttr::get( + rewriter.getContext(), + /*lhs_batching_dimensions=*/batch_dimensions, + /*rhs_batching_dimensions=*/batch_dimensions, + /*lhs_contracting_dimensions=*/lhs_contracting_dimensions, + /*rhs_contracting_dimensions=*/rhs_contracting_dimensions); + // TODO(silvasean): Emit shape checks for contracting dimensions. + // (The batch dimensions are checked by the broadcasting logic) + rewriter.replaceOpWithNewOp( + op, op.getType(), lhs, rhs, dimension_numbers, + /*precision_config=*/GetPrecisionConfig(&rewriter), + /*algorithm=*/DotAlgorithmAttr{}); + return success(); + } +}; + +// Converts the tf.Split op into a series of HLO slice ops when the tensor to be +// split has fully static shape and the dimension to split is a constant. +// +// The main logic of this pattern is to calculate the index start and end range +// for each slice. And this happens only on the dimension to be split; for all +// other dimensions, all resultant slices' index start and end range covers the +// input tensor's full range. Strides for all resultant slices are all one. +// +// For example, the following source IR: +// +// %dim = "tf.Const"() {value = dense<1> : tensor} : () -> tensor +// %0:3 = "tf.Split"(%dim, %input) : (tensor, tensor<4x6xf32>) -> +// (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) +// +// will be converted into: +// +// %0 = "mhlo.slice"(%input) { +// limit_indices = dense<[4, 2]> : tensor<2xi64>, +// start_indices = dense<0> : tensor<2xi64>, +// strides = dense<1> : tensor<2xi64>} : +// (tensor<4x6xf32>) -> tensor<4x2xf32> +// %1 = "mhlo.slice"(%input) { +// limit_indices = dense<4> : tensor<2xi64>, +// start_indices = dense<[0, 2]> : tensor<2xi64>, +// strides = dense<1> : tensor<2xi64>} : +// (tensor<4x6xf32>) -> tensor<4x2xf32> +// %2 = "mhlo.slice"(%input) { +// limit_indices = dense<[4, 6]> : tensor<2xi64>, +// start_indices = dense<[0, 4]> : tensor<2xi64>, +// strides = dense<1> : tensor<2xi64>} : +// (tensor<4x6xf32>) -> tensor<4x2xf32> +// TODO(antiagainst): consider lowering into TF ops so the pattern can be more +// applicable. +class ConvertSplitOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::SplitOp op, + PatternRewriter &rewriter) const override { + // We can only split inputs that have fully static shape. + auto input_type = mlir::dyn_cast(op.getValue().getType()); + if (!input_type || !input_type.hasStaticShape()) return failure(); + + // We can only match when the split dimension is a constant scalar. + DenseIntElementsAttr split_dim_attr; + if (!matchPattern(op.getSplitDim(), m_Constant(&split_dim_attr))) + return failure(); + + // Get the dimension we are splitting at. Offset properly if it's negative. + int64_t input_rank = input_type.getRank(); + int64_t dim_index = (*split_dim_attr.begin()).getSExtValue(); + if (dim_index < 0) dim_index += input_rank; + + // Calculate the dimension size for each slice along the split dimension. + int64_t input_dim_size = input_type.getDimSize(dim_index); + + int64_t num_splits = op.getNumResults(); + int64_t slice_size = input_dim_size / num_splits; + + // Get each slice's type. + auto slice_shape = llvm::to_vector<4>(input_type.getShape()); + slice_shape[dim_index] = slice_size; + Type slice_type = tensorflow::GetTypeFromTFTensorShape( + slice_shape, input_type.getElementType()); + + // Parameters for constructing each slice. + SmallVector begin_indices(input_rank, 0); + auto end_indices = llvm::to_vector<4>(input_type.getShape()); + SmallVector strides(input_rank, 1); + + // All HLO slice results used to replace the original tf.Split op. + SmallVector slices; + slices.reserve(num_splits); + + for (int i = 0; i < num_splits; ++i) { + begin_indices[dim_index] = i * slice_size; + end_indices[dim_index] = (i + 1) * slice_size; + slices.push_back( + rewriter.create(op.getLoc(), slice_type, op.getValue(), + GetI64ElementsAttr(begin_indices, &rewriter), + GetI64ElementsAttr(end_indices, &rewriter), + GetI64ElementsAttr(strides, &rewriter))); + } + + rewriter.replaceOp(op, slices); + return success(); + } +}; + +// Converts the tf.Split op into a series of mhlo.real_dynamic_slice ops the +// dimension to split is a constant. +// TODO(disc): To recover static special case's performance with folding and +// canonicalization. delete ConvertSplitOp +class ConvertSplitOpDynamic : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::SplitOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = op.getValue(); + auto input_type = mlir::dyn_cast(input.getType()); + if (!input_type) return failure(); + + // TODO(disc): remove static shape check once folding/canonicalization func + // added and ConvertSplitOp deleted. Calculate the dimension size for each + // slice along the split dimension. We are splitting along the dynamic + // dimension, or using static pattern transform + if (input_type.hasStaticShape()) return failure(); + + // We can only match when the split dimension is a constant scalar. + DenseIntElementsAttr split_dim_attr; + if (!matchPattern(op.getSplitDim(), m_Constant(&split_dim_attr))) + return failure(); + + // Get the dimension we are splitting at. Offset properly if it's negative. + int64_t input_rank = input_type.getRank(); + int64_t dim_index = (*split_dim_attr.begin()).getSExtValue(); + if (dim_index < 0) dim_index += input_rank; + + Value input_dim_size = + rewriter.create(loc, input, dim_index); + // Calculate the dimension size for each slice along the split dimension. + int num_splits = op.getNumResults(); + Value num_splits_value = rewriter.create( + loc, rewriter.getIndexAttr(num_splits)); + Value slice_size = + rewriter.create(loc, input_dim_size, num_splits_value); + + Value zero = rewriter.create(loc, 0); + Value one = rewriter.create(loc, 1); + + SmallVector begin_indices(input_rank, zero); + SmallVector end_indices; + end_indices.reserve(input_rank); + SmallVector strides(input_rank, one); + for (int i = 0; i < input_rank; ++i) { + end_indices.push_back(rewriter.create(loc, input, i)); + } + + // All HLO d_slice results used to replace the original tf.Split op. + SmallVector slices; + slices.reserve(num_splits); + + for (int i = 0; i < num_splits; ++i) { + begin_indices[dim_index] = rewriter.create( + loc, slice_size, rewriter.create(loc, i)); + end_indices[dim_index] = rewriter.create( + loc, slice_size, rewriter.create(loc, i + 1)); + + Type index_ty = rewriter.getIndexType(); + auto begin_value = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(begin_indices.size())}, index_ty), + begin_indices); + auto end_value = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(end_indices.size())}, index_ty), + end_indices); + auto stride_value = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(strides.size())}, index_ty), + strides); + slices.push_back(rewriter.create( + loc, op.getOperation()->getResult(i).getType(), input, begin_value, + end_value, stride_value)); + } + + rewriter.replaceOp(op, slices); + return success(); + } +}; + +// Converts the tf.SplitV op into a series of HLO slice ops when the tensor to +// be split has fully static shape and the dimension to split and split sizes +// are constants. +// +// This is similar to the conversion for tf.Split op other than that the size of +// each chunk on the dimension to split is explicitly given as an op operand +// and they are not necessarily the same. +// +// For example, given the following IR: +// +// %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>} +// %split_dim = "tf.Const"() {value = dense<1> : tensor} +// %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : +// (tensor<4x6xf32>, tensor<3xi32>, tensor) -> +// (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) +// +// We will generate slices following slices: +// %0 = "mhlo.slice"(%input) { +// limit_indices = dense<[4, 1]> : tensor<2xi64>, +// start_indices = dense<0> : tensor<2xi64>, +// strides = dense<1> : tensor<2xi64>} : +// (tensor<4x6xf32>) -> tensor<4x1xf32> +// %1 = "mhlo.slice"(%input) { +// limit_indices = dense<[4, 3]> : tensor<2xi64>, +// start_indices = dense<[0, 1]> : tensor<2xi64>, +// strides = dense<1> : tensor<2xi64>} : +// (tensor<4x6xf32>) -> tensor<4x2xf32> +// %2 = "mhlo.slice"(%input) { +// limit_indices = dense<[4, 6]> : tensor<2xi64>, +// start_indices = dense<[0, 3]> : tensor<2xi64>, +// strides = dense<1> : tensor<2xi64>} : +// (tensor<4x6xf32>) -> tensor<4x3xf32> +class ConvertSplitVOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::SplitVOp op, + PatternRewriter &rewriter) const override { + // We can only split inputs that have fully static shape. + // TODO(b/145731001): enhance to support dynamic-shaped inputs. + auto input_type = mlir::dyn_cast(op.getValue().getType()); + if (!input_type || !input_type.hasStaticShape()) return failure(); + + // We can only match when the split dimension is a constant scalar. + DenseIntElementsAttr split_dim_attr; + if (!matchPattern(op.getSplitDim(), m_Constant(&split_dim_attr))) + return failure(); + + // We can only match when the split sizes is a constant int vector. + DenseIntElementsAttr split_sizes_attr; + if (!matchPattern(op.getSizeSplits(), m_Constant(&split_sizes_attr))) + return failure(); + + // Get each chunck's size along the dimension to split. It may contain + // dynamic sizes and we need to update it if so. + SmallVector split_sizes; + int64_t total_dim_size = 0; // Total dimension size assigned to splits + std::optional dynamic_dim_index; + split_sizes.reserve( + mlir::cast(split_sizes_attr.getType()).getNumElements()); + for (const auto &dim : llvm::enumerate(split_sizes_attr)) { + int64_t dim_val = dim.value().getSExtValue(); + split_sizes.push_back(dim_val); + if (dim_val == -1) { + // We cannot have more than one dynamic dimension. + assert(!dynamic_dim_index && "invalid split sizes"); + dynamic_dim_index = dim.index(); + } else { + total_dim_size += dim_val; + } + } + + // Get the dimension we are splitting at. Offset properly if it's negative. + int64_t input_rank = input_type.getRank(); + int64_t dim_index = (*split_dim_attr.begin()).getSExtValue(); + if (dim_index < 0) dim_index += input_rank; + + int64_t input_dim_size = input_type.getDimSize(dim_index); + assert(((dynamic_dim_index && total_dim_size <= input_dim_size) || + (!dynamic_dim_index && total_dim_size == input_dim_size)) && + "invalid split sizes"); + + // Update the dynamic dimension with calculated concrete size. + if (dynamic_dim_index) + split_sizes[*dynamic_dim_index] = input_dim_size - total_dim_size; + + // Parameters for constructing each slice. + SmallVector begin_indices(input_rank, 0); + auto end_indices = llvm::to_vector<4>(input_type.getShape()); + SmallVector strides(input_rank, 1); + + // All HLO slice results used to replace the original tf.Split op. + SmallVector slices; + slices.reserve(op.getNumResults()); + + for (int i = 0, end = op.getNumResults(); i < end; ++i) { + end_indices[dim_index] = begin_indices[dim_index] + split_sizes[i]; + slices.push_back(rewriter.create( + op.getLoc(), op.getValue(), + GetI64ElementsAttr(begin_indices, &rewriter), + GetI64ElementsAttr(end_indices, &rewriter), + GetI64ElementsAttr(strides, &rewriter))); + // Prepare the begin indice for the next slice. + begin_indices[dim_index] = end_indices[dim_index]; + } + + rewriter.replaceOp(op, slices); + return success(); + } +}; + +// Converts StridedSlice op to HLO Slice op along with Reverse op to handle +// negative strides and Reshape op to update the output shape. Indices and +// strides operands are converted to attributes with non-negative indexing. +// +// If the begin input is not a compile time constant, the begin input needs to +// be sliced and the slice needs to be lowered to mhlo.DynamicSlice. In this +// case, strides must have a known value of 1 (otherwise we have insufficient +// information to conform to XLA's op semantics). +// +// For example with an op like following, +// tf.StridedSlice(%input, %begin, %end, %strides) {shrink_axis_mask = 1} +// : tensor -> tensor +// +// If the %begin input is constant, output would be: +// %reversed = "mhlo.Reverse" (%input) {dimensions = ...} +// %sliced = "mhlo.Slice" (%input) +// {start_indices = ..., limit_indices = ..., strides = ...} +// %output = "mhlo.Reshape" (%sliced) : tensor<1xPxf32> -> tensor +// +class ConvertStridedSliceOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult rewriteWithConstantBegin(TF::StridedSliceOp op, + ArrayRef begin_indices, + ArrayRef end_indices, + ArrayRef strides, + RankedTensorType input_ty, + PatternRewriter &rewriter) const { + SmallVector hlo_begin_indices, hlo_end_indices, hlo_strides, + dims_to_reverse; + int64_t input_rank = input_ty.getRank(); + ArrayRef input_shape = input_ty.getShape(); + hlo_begin_indices.reserve(input_rank); + hlo_end_indices.reserve(input_rank); + hlo_strides.reserve(input_rank); + + int64_t indices_elements = begin_indices.size(); + if (input_rank < indices_elements) return failure(); + + // Convert from TensorFlow negative or out of range indices and strides + // values to legal HLO Slice attributes. + for (int i = 0, e = indices_elements; i != e; i++) { + int64_t begin = begin_indices[i]; + int64_t end = end_indices[i]; + int64_t stride = strides[i]; + + if (stride < 0) { + // Negative stride means that the output values are computed starting + // from end until begin. Mark the dimension for reversal before slice + // and compute indices for the reversed input. + dims_to_reverse.push_back(i); + begin = (input_shape[i] - 1) - begin; + end = (input_shape[i] - 1) - end; + stride = -stride; + } + + // Unlike TensorFlow, HLO requires begin and end values to be within + // range. + begin = std::max(int64_t(0), begin); + end = std::max(begin, end); + end = std::min(end, input_shape[i]); + + hlo_begin_indices.push_back(begin); + hlo_end_indices.push_back(end); + hlo_strides.push_back(stride); + } + + Location loc = op.getLoc(); + Value input = op.getInput(); + if (!dims_to_reverse.empty()) + input = rewriter.create( + loc, input_ty, op.getInput(), + GetI64ElementsAttr(dims_to_reverse, &rewriter)); + auto sliced = rewriter.create( + loc, input, GetI64ElementsAttr(hlo_begin_indices, &rewriter), + GetI64ElementsAttr(hlo_end_indices, &rewriter), + GetI64ElementsAttr(hlo_strides, &rewriter)); + + // Reshape slice result so that the shape is updated depending on + // 'new_axis_mask' or 'shrink_axis_mask' attributes. + rewriter.replaceOpWithNewOp(op, op.getType(), sliced); + return success(); + } + + LogicalResult rewriteWithUnknownBegin(TF::StridedSliceOp op, + RankedTensorType input_ty, + RankedTensorType result_ty, + PatternRewriter &rewriter) const { + // If begin and end values are dynamic, we can only support this lowering + // if strides are a known value of 1. + DenseIntElementsAttr sparse_strides_attr; + if (!matchPattern(op.getStrides(), m_Constant(&sparse_strides_attr))) { + return rewriter.notifyMatchFailure( + op, + "requires that strides are known when begin/end values are dynamic"); + } + SmallVector strides; + int64_t stride_value; + for (const APInt &stride : sparse_strides_attr) { + if ((stride_value = stride.getSExtValue()) != 1) { + return rewriter.notifyMatchFailure(op, + "requires that strides are all 1 " + "when begin/end values are dynamic"); + } + strides.push_back(stride_value); + } + + ArrayRef input_shape = input_ty.getShape(); + int last_dim = std::max(static_cast(input_shape.size()) - 1, 0); + + // When begin/end values are dynamic, the ellipsis mask, if set, must refer + // to the last dimension. + int ellipsis_mask = op.getEllipsisMask(); + if (!(ellipsis_mask == 0 || ellipsis_mask == (1 << last_dim))) + return rewriter.notifyMatchFailure( + op, + "requires that ellipsis_mask, if set, refer to the last dimension of " + "input (when begin/end values are dynamic)"); + + // In this case where the begin and end values are dynamic, we only support + // cases where the number of output elements has to be equal to the number + // of input elements that are sliced. Each dimension is either sliced fully + // or sliced with a size of one. + int output_elements = result_ty.getNumElements(); + int input_elements_sliced = 1; + + // Begin must be a ranked, 1-dimensional tensor: This is checked by the + // verifier. + int64_t slicing_dim_size = + mlir::cast(op.getBegin().getType()).getDimSize(0); + uint64_t begin_mask = op.getBeginMask(); + uint64_t end_mask = op.getEndMask(); + const int input_rank = input_shape.size(); + for (int d = 0; d < input_rank; ++d) { + // Each dimension is either sliced fully or has size of one. + if ((((begin_mask >> d) & 1) && ((end_mask >> d) & 1)) || + (d >= slicing_dim_size)) { + input_elements_sliced *= input_shape[d]; + } + } + if (input_elements_sliced != output_elements) { + return rewriter.notifyMatchFailure( + op, + "requires the number of output elements to be equal to the number of " + "input elements sliced (when begin/end values are dynamic)"); + } + + SmallVector slice_begin_indices; + // For the dimensions that are to be sliced, all have slice sizes of 1. + SmallVector slice_sizes; + auto begin_element_ty = + mlir::cast(op.getBegin().getType()).getElementType(); + // Scalar tensor type. + TensorType type = + tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, begin_element_ty); + Location loc = op.getLoc(); + auto zero = GetScalarConstOfType(begin_element_ty, loc, 0, &rewriter); + for (int d = 0; d < input_rank; ++d) { + if ((((begin_mask >> d) & 1) && ((end_mask >> d) & 1)) || + (d >= slicing_dim_size)) { + slice_begin_indices.push_back(zero); + slice_sizes.push_back(input_shape[d]); + continue; + } + + auto index = rewriter.create( + loc, op.getBegin(), GetI64ElementsAttr({d}, &rewriter), + GetI64ElementsAttr({d + 1}, &rewriter), + GetI64ElementsAttr({1}, &rewriter)); + // Convert index to scalar. + auto reshaped_index = rewriter.create(loc, type, index); + // If the index is negative, wrap it around with dimension size. + auto index_negative = + rewriter.create(loc, reshaped_index, zero); + auto input_val = GetScalarConstOfType(begin_element_ty, loc, + input_shape[d], &rewriter); + auto wrapped_index = + rewriter.create(loc, input_val, reshaped_index); + auto final_index = rewriter.create( + loc, type, index_negative, wrapped_index, reshaped_index); + slice_begin_indices.push_back(final_index); + slice_sizes.push_back(1); + } + + auto slice_sizes_attr = GetI64ElementsAttr(slice_sizes, &rewriter); + auto sliced_type = tensorflow::GetTypeFromTFTensorShape( + slice_sizes, op.getType().getElementType()); + // This must be an xla DynamicSlice op due to the inputs that aren't + // constant. + auto sliced = rewriter.create( + loc, sliced_type, op.getInput(), slice_begin_indices, slice_sizes_attr); + + // Reshape slice result so that the shape is updated depending on + // 'new_axis_mask' or 'shrink_axis_mask' attributes. + rewriter.replaceOpWithNewOp(op, op.getType(), sliced); + return success(); + } + + LogicalResult matchAndRewrite(TF::StridedSliceOp op, + PatternRewriter &rewriter) const override { + // Input shape needs to be static to convert negative indices in TensorFlow + // to absolute indices required by HLO. + // + // TODO(hinsu): Relax this constraint for ops without negative indices and + // strides. + auto input_ty = mlir::dyn_cast(op.getInput().getType()); + if (!input_ty || !input_ty.hasStaticShape()) return failure(); + + // Output shape needs to be static to apply 'new_axis_mask' or + // 'shrink_axis_mask' by reshaping tensor after slice. + // + // TODO(hinsu): Relax this constraint for ops without the above masks. + auto result_ty = mlir::dyn_cast(op.getType()); + if (!result_ty || !result_ty.hasStaticShape()) return failure(); + + DenseIntElementsAttr sparse_begin_attr, sparse_end_attr; + if (!matchPattern(op.getBegin(), m_Constant(&sparse_begin_attr)) || + !matchPattern(op.getEnd(), m_Constant(&sparse_end_attr))) { + return rewriteWithUnknownBegin(op, input_ty, result_ty, rewriter); + } + + SmallVector begin_indices, end_indices, strides; + if (!op.GetSlicedBoundRanges(&begin_indices, &end_indices, &strides)) { + return failure(); + } + return rewriteWithConstantBegin(op, begin_indices, end_indices, strides, + input_ty, rewriter); + } +}; + +// Converts tf.StridedSliceGrad to HLO reshape, reverse and padding ops. +// +// tf.StridedSlice is taking slice of the input tensor. tf.StridedSliceGrad does +// the reverse: it propagates the graident for the sliced tensor to the original +// input tensor by doing padding with zeros. The main logic is calculating the +// indices and strides for padding. +class ConvertStridedSliceGradOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::StridedSliceGradOp op, + PatternRewriter &rewriter) const override { + // We need constant input shape to perform padding calculations later. + DenseIntElementsAttr input_shape_attr; + if (!matchPattern(op.getShape(), m_Constant(&input_shape_attr))) + return failure(); + + // We also need constant begin/end indices and strides to perform padding + // calculations. + // Bounded shape after performing strided slice + SmallVector shape; + // Bounded begin, end, and strides for strided slice + SmallVector begin_indices, end_indices, strides; + if (!op.GetSlicedShapeAndBoundRanges(&shape, &begin_indices, &end_indices, + &strides)) + return failure(); + + Value grad = op.getDy(); + Type element_type = mlir::cast(grad.getType()).getElementType(); + + // Perform reshape to undo any new/shrink axes done by strided slice. + grad = rewriter.create( + op.getLoc(), tensorflow::GetTypeFromTFTensorShape(shape, element_type), + grad); + + SmallVector padding_low, padding_high, padding_interm; + SmallVector dims_to_reverse; + padding_low.reserve(shape.size()); + padding_high.reserve(shape.size()); + padding_interm.reserve(shape.size()); + + // Prepare padding parameters for each dimension. + for (int i = 0, e = shape.size(); i < e; ++i) { + int64_t input_dim = (*(input_shape_attr.begin() + i)).getSExtValue(); + if (strides[i] > 0) { + padding_low.push_back(begin_indices[i]); + padding_interm.push_back(strides[i] - 1); + + // Pad the upper dimension up to the expected input shape. It's not + // sufficient simply to use end_indices[i] to compute the padding in + // cases where the stride does not divide evenly into the interval + // between begin_indices[i] and end_indices[i]. + int64_t size = + padding_low[i] + shape[i] + (shape[i] - 1) * padding_interm[i]; + padding_high.push_back(input_dim - size); + } else { + dims_to_reverse.push_back(i); + padding_high.push_back(input_dim - begin_indices[i] - 1); + padding_interm.push_back(-strides[i] - 1); + + // Pad the lower dimension up to the expected input shape. + int64_t size = + padding_high[i] + shape[i] + (shape[i] - 1) * padding_interm[i]; + padding_low.push_back(input_dim - size); + } + } + + if (!dims_to_reverse.empty()) { + grad = rewriter.create( + op.getLoc(), grad.getType(), grad, + GetI64ElementsAttr(dims_to_reverse, &rewriter)); + } + + auto zero = GetScalarConstOfType(element_type, op.getLoc(), 0, &rewriter); + rewriter.replaceOpWithNewOp( + op, op.getType(), grad, zero, + GetI64ElementsAttr(padding_low, &rewriter), + GetI64ElementsAttr(padding_high, &rewriter), + GetI64ElementsAttr(padding_interm, &rewriter)); + return success(); + } +}; + +/// Converts the RangeOp tensorflow op to a mhlo.iota op with a scaling and +/// offset applied to generate the range values. The output tensor needs to +/// have a static shape. +/// +/// For example an op like the following: +/// %result = "tf.Range"(%start, %limit, %delta) {Tidx = "tfdtype$DT_FLOAT"} +/// : (tensor, tensor, tensor) -> tensor<5xf32> +/// +/// Output would be: +/// %iota = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xf32> +/// %scaled = "mhlo.multiply"(%iota, %delta) +/// {broadcast_dimensions = dense<[]> : tensor<0xi64>} : +/// (tensor<5xf32>, tensor) -> tensor<5xf32> +/// %result = "mhlo.add"(%scaled, %offset) +/// {broadcast_dimensions = dense<[]> : tensor<0xi64>} : +/// (tensor<5xf32>, tensor) -> tensor<5xf32> +/// +/// Implementation is defined in C++ due to no type interface for the iota op. +class ConvertRangeOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::RangeOp op, + PatternRewriter &rewriter) const override { + auto result = op.getResult(); + auto result_type = result.getType(); + if (!mlir::cast(result_type).hasStaticShape()) { + return failure(); + } + + auto iota = rewriter.create(op.getLoc(), result_type, + rewriter.getI64IntegerAttr(0)); + auto scaled = rewriter.create( + op.getLoc(), result_type, iota, op.getDelta(), + hlo::getBroadcastDimensionsAttr(&rewriter, iota, op.getDelta())); + rewriter.replaceOpWithNewOp( + op, result_type, scaled, op.getStart(), + hlo::getBroadcastDimensionsAttr(&rewriter, scaled, op.getStart())); + return success(); + } +}; + +// Converts RangeOp for cases with the length is a dynamic value. The shape of +// the resulting tensor computed, then the start and delta is used with the +// dynamic_iota value to compute the final range value. +// +// For example, the resulting range op value: +// %range = "tf.range"(%start, %limit, %delta) +// +// Is converted to the following. +// %start + %delta * iota(ceil(abs((%limit - %start) / %delta)) +// +// Implementation is defined in C++ due to the complicated type behavior. +class ConvertDynamicRangeOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::RangeOp op, + PatternRewriter &rewriter) const override { + auto result = op.getResult(); + auto result_type = mlir::cast(result.getType()); + if (result_type.hasStaticShape()) { + return failure(); + } + + Value start = op.getStart(); + Value delta = op.getDelta(); + Value limit = op.getLimit(); + + // To compute the length we need to use floating point calculations so that + // ceil can be computed for the number of steps. + auto compute_element_type = + mlir::isa(getElementTypeOrSelf(start.getType())) + ? getElementTypeOrSelf(start.getType()) + : rewriter.getF64Type(); + auto compute_type = tensorflow::GetTypeFromTFTensorShape( + mlir::cast(limit.getType()).getShape(), + compute_element_type); + + // Compute the length of the sequence we are going to need. This includes + // some conversion to float for the operations. + // + // %size = ceil(abs((%limit - %start) / %delta)) + auto range = rewriter.create(op.getLoc(), limit, start); + auto abs = rewriter.create(op.getLoc(), range); + + // Delta is not necessarily the same type as start and limit. + auto abs_cast = + rewriter.create(op.getLoc(), compute_type, abs); + auto delta_cast = + rewriter.create(op.getLoc(), compute_type, delta); + + // Compute the total number of integer steps and convert to the HLO + // dimension tensor. + auto normalized = + rewriter.create(op.getLoc(), abs_cast, delta_cast); + auto ceil = rewriter.create(op.getLoc(), normalized); + auto steps = rewriter.create( + op.getLoc(), + tensorflow::GetTypeFromTFTensorShape({}, rewriter.getI64Type()), ceil); + auto reshape = rewriter.create( + op.getLoc(), + tensorflow::GetTypeFromTFTensorShape({1}, rewriter.getI64Type()), + steps); + + // Using the resulting length compute the correct range value: + // + // %range = %start + %delta * iota(%size) + auto out_scalar_type = tensorflow::GetTypeFromTFTensorShape( + {}, getElementTypeOrSelf(result_type)); + auto start_out_cast = + rewriter.create(op.getLoc(), out_scalar_type, start); + auto delta_out_cast = + rewriter.create(op.getLoc(), out_scalar_type, delta); + + auto iota = rewriter.create( + op.getLoc(), result_type, reshape, rewriter.getI64IntegerAttr(0)); + auto scaled = rewriter.create( + op.getLoc(), result_type, iota, delta_out_cast, + hlo::getBroadcastDimensionsAttr(&rewriter, iota, delta_cast)); + rewriter.replaceOpWithNewOp( + op, result_type, scaled, start_out_cast, + hlo::getBroadcastDimensionsAttr(&rewriter, scaled, start_out_cast)); + return success(); + } +}; + +ElementsAttr ConvertAxisAttr(Value val, ElementsAttr attr, Builder *builder) { + auto int_attr = mlir::cast(attr); + auto type = mlir::cast(val.getType()); + + SmallVector axis; + axis.reserve(int_attr.getNumElements()); + + int64_t rank = type.getRank(); + for (auto val : int_attr.getValues()) { + axis.push_back((val.getSExtValue() + rank) % rank); + } + + return builder->getI64TensorAttr(axis); +} + +/// Converts the LinSpace tensorflow op to a mhlo.iota op with a scaling +/// and offset applied to generate the linspace values. The output tensor needs +/// to have a static shape. The implementation is defined in C++ because there +/// is no type inference for the iota op. +class ConvertLinSpaceOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::LinSpaceOp op, + PatternRewriter &rewriter) const override { + auto result = op.getResult(); + auto result_type = mlir::dyn_cast(result.getType()); + if (!result_type || !result_type.hasStaticShape()) { + return failure(); + } + + DenseIntElementsAttr num_attr; + if (!matchPattern(op.getNum(), m_Constant(&num_attr))) { + return rewriter.notifyMatchFailure(op, "Num must be a constant scalar"); + } + + if (num_attr.begin() == num_attr.end()) { + return rewriter.notifyMatchFailure(op, "Num must not be empty"); + } + int64_t num = (*num_attr.begin()).getSExtValue(); + + // Calculate the scaling that needs to be applied to the iota. + auto step_numerator = rewriter.create( + op.getLoc(), op.getStart().getType(), op.getStop(), op.getStart(), + hlo::getBroadcastDimensionsAttr(&rewriter, op.getStop(), + op.getStart())); + Value step_denominator = rewriter.create( + op.getLoc(), op.getNum(), result_type.getElementType()); + if (num > 1) { + Value one = GetScalarConstOfType(result_type.getElementType(), + op.getLoc(), 1, &rewriter); + step_denominator = rewriter.create( + op.getLoc(), step_denominator.getType(), step_denominator, one, + hlo::getBroadcastDimensionsAttr(&rewriter, step_denominator, one)); + } + auto step = rewriter.create( + op.getLoc(), step_numerator.getType(), step_numerator, step_denominator, + hlo::getBroadcastDimensionsAttr(&rewriter, step_numerator, + step_denominator)); + + // Scale the iota and add the offset. + auto iota = rewriter.create(op.getLoc(), result_type, + rewriter.getI64IntegerAttr(0)); + auto scaled = rewriter.create( + op.getLoc(), result_type, iota, step, + hlo::getBroadcastDimensionsAttr(&rewriter, iota, step)); + rewriter.replaceOpWithNewOp( + op, result_type, scaled, op.getStart(), + hlo::getBroadcastDimensionsAttr(&rewriter, scaled, op.getStart())); + return success(); + } +}; + +/// Converts a generic OpTy tensorflow op to a mhlo.reduce op over +/// ReductionOp. +/// `is_accumulation` controls whether it uses higher precision for the actual +/// reduction. This is set to false for ops like max where there is no precision +/// concerns. +// +// The Derived class should have a static method to return the initial value to +// use for reduction: +// static Value GetInitialValue(Type reduce_element_type, Location loc, +// PatternRewriter *rewriter); +// The reduce_element_type is guaranteed to be a float, int, or complex type +// suitable for use with GetScalarConstOfType or GetScalarLimitConstOfType. +template +class GenericConvertReductionOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // TODO(b/141785544): Update this to not require ranked shapes. + // Input shape needs to be ranked to convert negative indices in TensorFlow + // to absolute indices required by HLO. + auto input_ty = mlir::dyn_cast(op.getInput().getType()); + if (!input_ty) return failure(); + ArrayRef input_shape = input_ty.getShape(); + + DenseIntElementsAttr dimensions; + if (!matchPattern(op.getReductionIndices(), m_Constant(&dimensions))) + return failure(); + + // Build the final shape from input_shape and dimensions using a bitmap + // to mark the reduced dimensions. + SmallVector reduced_dimensions_bitmap(input_shape.size(), false); + SmallVector xla_dimensions; + for (const APInt &index_raw : dimensions.getValues()) { + int64_t index = index_raw.getSExtValue(); + int64_t rank = input_shape.size(); + if ((index < -rank || index >= rank)) return failure(); + index = (index + rank) % rank; + reduced_dimensions_bitmap[index] = true; + xla_dimensions.push_back(index); + } + + Location loc = op.getLoc(); + Type element_type = input_ty.getElementType(); + + // Only float, int, and complex types are currently supported. + if (!mlir::isa(element_type) && + !mlir::isa(element_type) && + !mlir::isa(element_type)) { + return rewriter.notifyMatchFailure( + op, "element type must be float, int, or complex type"); + } + + // Convert to an accumulation type to not lose precision when doing + // repeated arithmetic operations. + Type reduce_element_type = + is_accumulation ? GetAccumulationType(element_type) : element_type; + auto casted_input = + rewriter.create(loc, op.getInput(), reduce_element_type); + + // Each reduction op can have a different initial value. + Value init = Derived::GetInitialValue(reduce_element_type, loc, &rewriter); + + auto reduction = rewriter.create( + loc, casted_input.getResult(), init, + GetI64ElementsAttr(xla_dimensions, &rewriter), reduce_element_type); + BuildReduceBody(reduce_element_type, &reduction.getBody(), + &rewriter); + Value result = reduction.getResult(0); + + // The mean op needs to divide by the product of the reduced dimensions. + if (std::is_same::value) { + Value in_shape = rewriter.create(loc, op.getInput()); + Value divisor_count = rewriter.create(loc, 1); + for (size_t i = 0; i < input_shape.size(); ++i) { + if (reduced_dimensions_bitmap[i]) { + Value index = rewriter.create(loc, i); + auto dim = rewriter.create(loc, in_shape, index); + divisor_count = + rewriter.create(loc, divisor_count, dim); + } + } + // HLO ops are only defined on tensors, so we cast the divisor from + // index -> i64 -> tensor<1xi64> -> tensor -> tensor + Value divisor_casted = rewriter.create( + loc, rewriter.getI64Type(), divisor_count); + Value divisor_tensor = rewriter.create( + loc, tensorflow::GetTypeFromTFTensorShape({}, rewriter.getI64Type()), + divisor_casted); + Value divisor = rewriter.create( + loc, tensorflow::GetTypeFromTFTensorShape({}, reduce_element_type), + divisor_tensor); + auto broadcast_dims = rewriter.getDenseI64ArrayAttr({}); + result = rewriter.create(loc, result, divisor, + broadcast_dims); + } + + result = rewriter.create(loc, result, element_type); + + // Need to reshape back after the reduction if we're keeping the reduced + // dimensions. Note that we do this through successive (nominally 1) + // applications of the TF ExpandDims op vs a more labor intensive + // reshape. Various code generation techniques benefit from the knowledge + // that this is a restricted form of shape manipulation that is just adding + // unit dims. + if (op.getKeepDims()) { + for (const auto &dim_is_reduced : + llvm::enumerate(reduced_dimensions_bitmap)) { + if (dim_is_reduced.value()) { + auto index_attr = GetI32ElementsAttr( + {static_cast(dim_is_reduced.index())}, &rewriter); + Value index = rewriter.create(loc, index_attr); + result = rewriter.create(loc, result, index); + } + } + } + rewriter.replaceOp(op, {result}); + + return success(); + } +}; + +// Converts Mean op to HLO Reduce op. +// +// %init = arith.constant dense<...> : tensor +// %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"] +// {dimensions = ...} +// %divisor = arith.constant dense<...> : tensor +// %mean = "mhlo.divide"(%sum, %divisor) +class ConvertMeanOp + : public GenericConvertReductionOp { + public: + using GenericConvertReductionOp::GenericConvertReductionOp; + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter *rewriter) { + return GetScalarNegZeroOfType(reduce_element_type, loc, rewriter); + } +}; + +// Converts Sum op to HLO Reduce op. +// +// %init = arith.constant dense<...> : tensor +// %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"] +// {dimensions = ...} +class ConvertSumOp + : public GenericConvertReductionOp { + public: + using GenericConvertReductionOp::GenericConvertReductionOp; + + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter *rewriter) { + // The neutral element of fp addition is -0.0, not 0.0: '0.0 + -0.0 = 0.0'. + return GetScalarNegZeroOfType(reduce_element_type, loc, rewriter); + } +}; + +// Converts Max op to HLO Reduce op. +// +// %init = arith.constant dense<...> : tensor +// %max = "mhlo.reduce"(%inp, %init) ["mhlo.maximum"] +// {dimensions = ...} +class ConvertMaxOp + : public GenericConvertReductionOp { + public: + using GenericConvertReductionOp::GenericConvertReductionOp; + + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter *rewriter) { + return GetScalarLimitConstOfType(reduce_element_type, loc, + hlo::kInfinityLowest, rewriter); + } +}; + +// Converts Min op to HLO Reduce op. +// +// %init = arith.constant dense<...> : tensor +// %min = "mhlo.reduce"(%inp, %init) ["mhlo.minimum"] +// {dimensions = ...} +class ConvertMinOp + : public GenericConvertReductionOp { + public: + using GenericConvertReductionOp::GenericConvertReductionOp; + + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter *rewriter) { + return GetScalarLimitConstOfType(reduce_element_type, loc, + hlo::kInfinityMax, rewriter); + } +}; + +// Converts Prod op to HLO Reduce op. +// +// %init = arith.constant dense<...> : tensor +// %prod = "mhlo.reduce"(%inp, %init) ["mhlo.multiply"] +// {dimensions = ...} +class ConvertProdOp + : public GenericConvertReductionOp { + public: + using GenericConvertReductionOp::GenericConvertReductionOp; + + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter *rewriter) { + return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter); + } +}; + +// Converts All op to HLO Reduce op. +// +// %init = arith.constant dense<...> : tensor +// %max = "mhlo.reduce"(%inp, %init) ["mhlo.and"] +// {dimensions = ...} +class ConvertAllOp + : public GenericConvertReductionOp { + public: + using GenericConvertReductionOp::GenericConvertReductionOp; + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter *rewriter) { + return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter); + } +}; + +// Converts Any op to HLO Reduce op. +// +// %init = arith.constant dense<...> : tensor +// %max = "mhlo.reduce"(%inp, %init) ["mhlo.or"] +// {dimensions = ...} +class ConvertAnyOp + : public GenericConvertReductionOp { + public: + using GenericConvertReductionOp::GenericConvertReductionOp; + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter *rewriter) { + return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter); + } +}; + +// Converts tensorflow ArgMin or ArgMax op to mhlo operations that perform +// a reduction on the original input and the corresponding index. The reduction +// sub-computation selects the max (or min) value and the index for the value. +// Derived: is the resulting derived class of this class. +// OpTy: is TF::ArgMaxOp or TF::ArgMinOp. +template +class ConvertArgMinMaxOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + RankedTensorType input_type = + mlir::dyn_cast(op.getInput().getType()); + if (!input_type) { + return failure(); + } + + Type input_element_type = input_type.getElementType(); + // TODO(bixia): Clarify whether tf.ArgMax supports complex data types. If + // tf.ArgMax doesn't support complex data types, this check can be removed. + if (!input_element_type.isSignlessIntOrFloat()) return failure(); + + Location loc = op.getLoc(); + Value init_value = + Derived::GetInitialValue(input_element_type, loc, rewriter); + + RankedTensorType output_type = + mlir::dyn_cast(op.getOutput().getType()); + if (!output_type) { + return rewriter.notifyMatchFailure(op, "requires known rank"); + } + + Type index_element_type = output_type.getElementType(); + Value index_init_value = + GetScalarConstOfType(index_element_type, loc, 0, &rewriter); + + RankedTensorType index_type = tensorflow::GetTypeFromTFTensorShape( + input_type.getShape(), index_element_type); + + std::optional optional_axis = + GetIntegerHLOAxisFromTFAxis(op.getDimension(), input_type.getRank()); + if (!optional_axis.has_value()) + return rewriter.notifyMatchFailure(op, "required axis"); + int64_t axis = optional_axis.value(); + + IntegerAttr iota_dimension = + IntegerAttr::get(rewriter.getIntegerType(64), axis); + Value input_shape = rewriter.create(loc, op.getInput()); + Value index_values = rewriter.create( + loc, index_type, input_shape, iota_dimension); + + Value operands[] = {op.getInput(), index_values}; + Value init_values[] = {init_value, index_init_value}; + DenseIntElementsAttr reduction_dimensions = + GetI64ElementsAttr({axis}, &rewriter); + + auto reduction = rewriter.create( + loc, llvm::ArrayRef(operands), + llvm::ArrayRef(init_values), reduction_dimensions, + TypeRange({input_element_type, index_element_type})); + auto direction = Derived::GetDirection(); + BuildArgMinMaxReductionBody(input_element_type, index_element_type, + direction, &reduction.getBody(), &rewriter); + + rewriter.replaceOp(op, {reduction.getResult(1)}); + return success(); + } +}; + +// Converts tensorflow ArgMax op to mhlo operations. The actual +// implementation is in class ConvertArgMinMaxOp: +// +// %init_index = arith.constant dense<...> : tensor +// %init = arith.constant dense<...> : tensor +// %reduce = "mhlo.reduce"(%selected_input, %select_index, %init, +// %init_index) ["mhlo.arg_max"] +class ConvertArgMaxOp + : public ConvertArgMinMaxOp { + public: + using ConvertArgMinMaxOp::ConvertArgMinMaxOp; + + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter &rewriter) { + return GetScalarLimitConstOfType(reduce_element_type, loc, + hlo::kInfinityLowest, &rewriter); + } + + static ComparisonDirection GetDirection() { return ComparisonDirection::GE; } +}; + +// Converts tensorflow ArgMin op to mhlo operations. The actual +// implementation is in class ConvertArgMinMaxOp: +// +// %init_index = arith.constant dense<...> : tensor +// %init = arith.constant dense<...> : tensor +// %reduce = "mhlo.reduce"(%selected_input, %select_index, %init, +// %init_index) ["mhlo.arg_min"] +class ConvertArgMinOp + : public ConvertArgMinMaxOp { + public: + using ConvertArgMinMaxOp::ConvertArgMinMaxOp; + + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter &rewriter) { + return GetScalarLimitConstOfType(reduce_element_type, loc, + hlo::kInfinityMax, &rewriter); + } + + static ComparisonDirection GetDirection() { return ComparisonDirection::LE; } +}; + +// Converts TF TensorScatterUpdate/Min/Max/Add/Sub op into Scatter Op with +// assignment: +// +// %result = "mhlo.scatter"(%tensor, %indices, %updates) +// { dimensions = ... } +// +template +class ConvertTensorScatterOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + auto tensor_ty = mlir::dyn_cast(op.getTensor().getType()); + auto indices_ty = + mlir::dyn_cast(op.getIndices().getType()); + auto updates_ty = + mlir::dyn_cast(op.getUpdates().getType()); + + if (!tensor_ty || !indices_ty || !updates_ty) return failure(); + // Last dimension of the indices needs to known at compile time for + // computation of the 'update_window_dims' attribute in the dimensions + // struct. + int64_t num_index_dims = indices_ty.getShape().back(); + if (ShapedType::isDynamic(num_index_dims)) return failure(); + + auto updates = op.getUpdates(); + + // Broadcast scalar `updates` in into expected shape as following shape: + // updates.shape == indices.shape[:-1] + tensor.shape[indices.shape[-1]:] + if (updates_ty.getRank() == 0 && + (std::is_same::value || + std::is_same::value)) { + if (!tensor_ty.hasStaticShape()) { + return failure(); + } + + if (!indices_ty.hasStaticShape()) { + return failure(); + } + + auto tensor_shape = tensor_ty.getShape(); + auto indices_shape = indices_ty.getShape(); + auto index_depth = indices_shape.back(); + llvm::SmallVector expected_update_shape; + + // create the expected update shape which scalar update is broadcasted to + expected_update_shape.append(indices_shape.begin(), + std::prev(indices_shape.end())); + + expected_update_shape.append(std::next(tensor_shape.begin(), index_depth), + tensor_shape.end()); + + auto const_type = tensorflow::GetTypeFromTFTensorShape( + {static_cast(expected_update_shape.size())}, + rewriter.getIntegerType(64)); + + auto const_attr = GetI64ElementsAttr(expected_update_shape, &rewriter); + + auto const_op = + rewriter.create(op->getLoc(), const_type, const_attr); + + auto broadcast_to_type = tensorflow::GetTypeFromTFTensorShape( + llvm::ArrayRef(expected_update_shape), + updates_ty.getElementType()); + + updates = rewriter.create( + op->getLoc(), broadcast_to_type, op.getUpdates(), const_op); + + updates_ty = mlir::dyn_cast(updates.getType()); + } + + int64_t tensor_rank = tensor_ty.getRank(); + int64_t indices_rank = indices_ty.getRank(); + int64_t updates_rank = + mlir::dyn_cast(updates.getType()).getRank(); + + int64_t window_dims = tensor_rank - num_index_dims; + auto dims_attr = ScatterDimensionNumbersAttr::get( + rewriter.getContext(), + llvm::to_vector<4>( + llvm::seq(updates_rank - window_dims, updates_rank)), + llvm::to_vector<4>(llvm::seq(0, num_index_dims)), + /*inputBatchingDims=*/{}, + /*scatterIndicesBatchingDims=*/{}, + llvm::to_vector<4>(llvm::seq(0, num_index_dims)), + indices_rank - 1); + + Location loc = op.getLoc(); + auto scatter = rewriter.create( + loc, op.getType(), ValueRange(Value(op.getTensor())), op.getIndices(), + updates, dims_attr); + Derived::BuildScatterBody(tensor_ty.getElementType(), + &scatter.getUpdateComputation(), loc, rewriter); + + rewriter.replaceOp(op, scatter.getResult(0)); + return success(); + } +}; + +class ConvertTensorScatterUpdateOp + : public ConvertTensorScatterOp { + public: + using ConvertTensorScatterOp::ConvertTensorScatterOp; + + static void BuildScatterBody(Type element_type, Region *region, Location loc, + OpBuilder &builder) { + OpBuilder::InsertionGuard guard(builder); + Block *block = builder.createBlock(region); + Type type = + tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); + block->addArguments({type, type}, SmallVector(2, loc)); + builder.create(loc, block->getArgument(1)); + } +}; + +class ConvertTensorScatterAddOp + : public ConvertTensorScatterOp { + public: + using ConvertTensorScatterOp::ConvertTensorScatterOp; + + static void BuildScatterBody(Type element_type, Region *region, Location loc, + OpBuilder &builder) { + OpBuilder::InsertionGuard guard(builder); + Block *block = builder.createBlock(region); + Type type = + tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); + block->addArguments({type, type}, SmallVector(2, loc)); + auto add_op = builder.create(loc, block->getArgument(0), + block->getArgument(1)); + builder.create(loc, add_op.getResult()); + } +}; + +class ConvertTensorScatterSubOp + : public ConvertTensorScatterOp { + public: + using ConvertTensorScatterOp::ConvertTensorScatterOp; + + static void BuildScatterBody(Type element_type, Region *region, Location loc, + OpBuilder &builder) { + OpBuilder::InsertionGuard guard(builder); + Block *block = builder.createBlock(region); + Type type = + tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); + block->addArguments({type, type}, SmallVector(2, loc)); + auto sub_op = builder.create(loc, block->getArgument(0), + block->getArgument(1)); + builder.create(loc, sub_op.getResult()); + } +}; + +class ConvertTensorScatterMinOp + : public ConvertTensorScatterOp { + public: + using ConvertTensorScatterOp::ConvertTensorScatterOp; + + static void BuildScatterBody(Type element_type, Region *region, Location loc, + OpBuilder &builder) { + OpBuilder::InsertionGuard guard(builder); + Block *block = builder.createBlock(region); + Type type = + tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); + block->addArguments({type, type}, SmallVector(2, loc)); + auto min_op = builder.create(loc, block->getArgument(0), + block->getArgument(1)); + builder.create(loc, min_op.getResult()); + } +}; + +class ConvertTensorScatterMaxOp + : public ConvertTensorScatterOp { + public: + using ConvertTensorScatterOp::ConvertTensorScatterOp; + + static void BuildScatterBody(Type element_type, Region *region, Location loc, + OpBuilder &builder) { + OpBuilder::InsertionGuard guard(builder); + Block *block = builder.createBlock(region); + Type type = + tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); + block->addArguments({type, type}, SmallVector(2, loc)); + auto max_op = builder.create(loc, block->getArgument(0), + block->getArgument(1)); + builder.create(loc, max_op.getResult()); + } +}; + +// Converts Tile op to HLO BroadcastInDim and Reshape ops. +// For shape [S1, S2] and multiples [M1, M2], +// MS1 = M1 * S1; MS2 = M2 * S2 +// +// %broadcast = mhlo.broadcast_in_dim(%input) { +// broadcast_dimensions = [0, 2] +// } +// %result = "mhlo.reshape"(%broadcast) : (tensor) +// -> tensor +class ConvertTileOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::TileOp op, + PatternRewriter &rewriter) const override { + auto input_ty = mlir::dyn_cast(op.getInput().getType()); + if (!input_ty || !input_ty.hasStaticShape()) return failure(); + ArrayRef input_shape = input_ty.getShape(); + Type element_type = input_ty.getElementType(); + + DenseIntElementsAttr multiples; + if (!matchPattern(op.getMultiples(), m_Constant(&multiples)) || + multiples.getType().getRank() != 1) + return failure(); + + const int64_t input_shape_size = input_shape.size(); + if (multiples.getNumElements() != input_shape_size) return failure(); + + SmallVector broadcasted_shape; + SmallVector broadcast_dimensions; + broadcasted_shape.reserve(input_shape.size() * 2); + broadcast_dimensions.reserve(input_shape.size()); + for (auto multiple_and_input : + llvm::zip(multiples.getValues(), input_shape)) { + int64_t multiple = std::get<0>(multiple_and_input).getSExtValue(); + int64_t input_size = std::get<1>(multiple_and_input); + + if (multiple < 0) return failure(); + + // Line input up with the next dimension in broadcasted_shape + // when broadcasting. + int64_t broadcast_dim; + int64_t output_size = input_size * multiple; + if (input_size == 1 || multiple == 1) { + // Special case for when normal broadcasting will just work. + broadcast_dim = broadcasted_shape.size(); + broadcasted_shape.push_back(output_size); + } else { + // Tiling will happen for this dimension during the ReshapeOp below. + broadcasted_shape.push_back(multiple); + broadcast_dim = broadcasted_shape.size(); + broadcasted_shape.push_back(input_size); + } + broadcast_dimensions.push_back(broadcast_dim); + } + Location loc = op.getLoc(); + Type broadcasted_type = + tensorflow::GetTypeFromTFTensorShape(broadcasted_shape, element_type); + Type output_type = op.getType(); + + Value result = rewriter.create( + loc, broadcasted_type, op.getInput(), + GetI64ElementsAttr(broadcast_dimensions, &rewriter)); + + if (output_type != broadcasted_type) { + result = rewriter.create(loc, output_type, result); + } + + rewriter.replaceOp(op, {result}); + + return success(); + } +}; + +// Converts the tf.TileOp op into mhlo.dynamic_reshape +// TODO(disc): To recover static special case's performance with folding and +// canonicalization. +class ConvertTileOpDynamic : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + // clang-format off + // Converts Tile op to HLO DBroadcastInDim and DReshape ops. + // For shape [S1, S2] and multiples [M1, M2], + // MS1 = M1 * S1; MS2 = M2 * S2 + // + // %out_dim_size = [S1, M1, S2, M2] + // %broadcast_dimensions = [1, 3]; + // %broadcast = mhlo.d_broadcast_in_dim(%input, %out_dim_size, %braodcast_dimensions); + // %shape = [MS1, MS2] + // %result = "mhlo.d_reshape"(%broadcast, %shape) : (tensor) -> tensor + // clang-format on + LogicalResult matchAndRewrite(TF::TileOp op, + PatternRewriter &rewriter) const final { + Location loc = op.getLoc(); + Value input = op.getInput(); + Value multiples = op.getMultiples(); + auto input_ty = mlir::dyn_cast(input.getType()); + if (!input_ty) return failure(); + // TODO(disc): Remove this constraint once fold and canonicalization + // implemented. + if (input_ty.hasStaticShape()) return failure(); + + Type element_type = input_ty.getElementType(); + int64_t input_rank = input_ty.getRank(); + SmallVector input_shape_values; + for (int64_t i = 0; i < input_rank; ++i) { + auto dim_size = input_ty.getDimSize(i); + if (dim_size == ShapedType::kDynamic) { + input_shape_values.push_back( + rewriter.create(loc, input, i)); + } else { + input_shape_values.push_back(rewriter.create( + loc, rewriter.getIndexAttr(dim_size))); + } + } + + auto multiples_ty = mlir::dyn_cast(multiples.getType()); + int64_t multiples_rank = multiples_ty.getRank(); + // rank of multiples input of tf.TileOp must be 1 + if (multiples_rank != 1) return failure(); + // multiples input of tf.TileOp must be fixed shaped + if ((!multiples_ty.hasStaticShape()) || + (multiples_ty.getDimSize(0) != input_rank)) { + return failure(); + } + Type index_ty = rewriter.getIndexType(); + // %out_dim_size + SmallVector out_dim_size; + out_dim_size.reserve(input_rank * 2); + for (int64_t dim_idx = 0; dim_idx < input_rank; ++dim_idx) { + Value index = rewriter.create( + loc, rewriter.getIndexAttr(dim_idx)); + Value multiples_size = + rewriter.create(loc, multiples, ValueRange{index}); + Value multiples_size_casted = + rewriter.create(loc, index_ty, multiples_size); + out_dim_size.push_back(multiples_size_casted); + out_dim_size.push_back(input_shape_values[dim_idx]); + } + SmallVector broadcast_dimensions; + broadcast_dimensions.reserve(input_rank); + for (int64_t dim_idx = 0; dim_idx < input_rank; ++dim_idx) { + broadcast_dimensions.push_back(1 + 2 * dim_idx); + } + auto broadcast_dims_attr = + GetI64ElementsAttr(broadcast_dimensions, &rewriter); + + Value out_dim_size_tensor = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(out_dim_size.size())}, index_ty), + out_dim_size); + SmallVector broadcast_shape(input_rank * 2, + ShapedType::kDynamic); + RankedTensorType broadcast_type = + tensorflow::GetTypeFromTFTensorShape(broadcast_shape, element_type); + Value broadcast = rewriter.create( + loc, broadcast_type, input, out_dim_size_tensor, broadcast_dims_attr); + + // %shape = [MS1, MS2] + SmallVector shape_values; + shape_values.reserve(input_rank); + for (int64_t i = 0; i < input_rank; ++i) { + Value dim_size_value = rewriter.create( + loc, out_dim_size[2 * i], out_dim_size[2 * i + 1]); + shape_values.push_back(dim_size_value); + } + Value shape = rewriter.create( + loc, tensorflow::GetTypeFromTFTensorShape({input_rank}, index_ty), + shape_values); + rewriter.replaceOpWithNewOp(op, op.getType(), + broadcast, shape); + return success(); + } +}; + +template +class ConvertMaxPoolGradOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + Type element_type = + mlir::cast(op.getOrigInput().getType()).getElementType(); + + // Compute paddings using the original input and kernel shape and strides. + // Here, ReduceWindow op as used as the MaxPool op is lowered to the + // ReduceWindow op. + auto input_ty = + mlir::dyn_cast(op.getOrigInput().getType()); + if (!input_ty) return failure(); + DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr( + input_ty.getShape(), op.getKsize(), op.getStrides(), op.getPadding(), + &rewriter); + + auto result = rewriter.create( + loc, op.getType(), op.getOrigInput(), op.getGrad(), + GetScalarConstOfType(element_type, loc, 0, &rewriter), + GetI64ElementsAttr(op.getKsize()), GetI64ElementsAttr(op.getStrides()), + paddings_attr); + + BuildReduceBody(element_type, &result.getScatter(), &rewriter); + { + OpBuilder::InsertionGuard guard(rewriter); + Block *block = rewriter.createBlock(&result.getSelect()); + + // Block arguments are scalars of the given element type. + Type type = + tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); + block->addArguments({type, type}, SmallVector(2, loc)); + + auto reducer = rewriter.create(loc, block->getArgument(0), + block->getArgument(1), + ComparisonDirection::GE); + rewriter.create(loc, reducer.getResult()); + } + + rewriter.replaceOp(op, result); + + return success(); + } +}; + +using ConvertMaxPool2DGradOp = + ConvertMaxPoolGradOp; +using ConvertMaxPool3DGradOp = + ConvertMaxPoolGradOp; + +// Converts tf.Conv?DBackpropInputOp into: +// %rev_filter = "mhlo.reverse"(%filter) +// %result = "mhlo.convolution"(%out_backprop, %rev_filter) +template +class ConvertConvBackpropInputOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // Unpack all of the attributes. + tensorflow::TensorFormat data_format; + if (!FormatFromString(op.getDataFormat().str(), &data_format)) + return op.emitOpError("invalid data format"); + constexpr int num_dims = num_spatial_dims + 2; + int batch_dim = GetTensorBatchDimIndex(num_dims, data_format); + + tensorflow::Padding padding; + if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) + return failure(); + + auto out_backprop_ty = + mlir::dyn_cast(op.getOutBackprop().getType()); + auto filter_ty = mlir::dyn_cast(op.getFilter().getType()); + + // With the exception of out_backprop's batch dimension, out_backprop and + // filter need to have static shape. Filter is validated here, out_backprop + // is mostly validated at use. + if (!out_backprop_ty || !filter_ty || !filter_ty.hasStaticShape()) + return failure(); + + // Compute input_shape by supporting either: + // 1) Fully static shapes, represented as constants. + // 2) Static shapes with a dynamic batch dimension, represented as + // 1D tf.Pack of a batch dimension (can be static or dynamic) + // and other dimensions (can only be static), for example: + // "tf.Pack"(%142, %cst_301, %cst_301, %cst_300) {axis = 0 : i64, ...} + std::vector input_shape; + DenseIntElementsAttr input_shape_attr; + if (matchPattern(op.getInputSizes(), m_Constant(&input_shape_attr)) && + input_shape_attr.getType().getRank() == 1) { + input_shape.insert(input_shape.end(), + input_shape_attr.getValues().begin(), + input_shape_attr.getValues().end()); + } else { + auto pack = op.getInputSizes().template getDefiningOp(); + if (!pack || pack.getAxis() != 0) return failure(); + auto pack_ty = mlir::dyn_cast(pack.getType()); + if (!pack_ty || pack_ty.getRank() != 1) return failure(); + for (auto i = 0; i < pack_ty.getDimSize(0); ++i) { + if (i == batch_dim) { + // We don't use the batch dimension below, so we don't care about + // its size. Might as well populate it with -1. + input_shape.push_back(ShapedType::kDynamic); + } else { + DenseIntElementsAttr input_dims_attr; + if (matchPattern(pack.getValues()[i], m_Constant(&input_dims_attr)) && + input_dims_attr.getType().getRank() == 0) { + input_shape.push_back(input_dims_attr.getSplatValue()); + } else { + return failure(); + } + } + } + } + + auto dilations_attr = GetI64ElementsAttr(op.getDilations()); + std::vector dilations{ + dilations_attr.template getValues().begin(), + dilations_attr.template getValues().end()}; + auto strides_attr = GetI64ElementsAttr(op.getStrides()); + std::vector strides{ + strides_attr.template getValues().begin(), + strides_attr.template getValues().end()}; + + std::vector explicit_paddings; + if (padding == tensorflow::Padding::EXPLICIT) { + // EXPLICIT padding mode and the associated attribute is limited to + // Conv2DBackpropInput. So, fetch attribute by identifier instead of the + // op.explicit_paddings() attribute getter. + ArrayRef explicit_paddings_attr = + op->template getAttrOfType("explicit_paddings").getValue(); + explicit_paddings.reserve(explicit_paddings_attr.size()); + for (Attribute explicit_padding : explicit_paddings_attr) + explicit_paddings.push_back( + mlir::cast(explicit_padding).getInt()); + } + + ArrayRef filter_shape = filter_ty.getShape(); + + // Compute ConvDimensionNumbers, dilation, and padding. + SmallVector spatial_dims; + SmallVector lhs_dilation; + SmallVector rhs_dilation; + SmallVector paddings; + + for (int i : llvm::seq(0, num_spatial_dims)) { + const int64_t spatial_dim = + GetTensorSpatialDimIndex(num_dims, data_format, i); + spatial_dims.push_back(spatial_dim); + + // Prepare metadata indexed by spatial_dim for computing pad_before + // and pad_after. + int64_t input_size = input_shape[spatial_dim]; + if (input_size == ShapedType::kDynamic) return failure(); + int64_t output_size = out_backprop_ty.getDimSize(spatial_dim); + if (output_size == ShapedType::kDynamic) return failure(); + int64_t filter_size = filter_ty.getDimSize(i); + int64_t stride = strides[spatial_dim]; + int64_t dilation = dilations[spatial_dim]; + + // Compute pad_before and pad_after following the logic from + // ConvBackpropComputeDimensionsV2. (Unfortunately, we cannot call + // the function in question because it doesn't work with dynamic dims). + int64_t padding_before = -1, padding_after = -1; + if (padding == tensorflow::Padding::EXPLICIT) { + padding_before = explicit_paddings[2 * spatial_dim]; + padding_after = explicit_paddings[2 * spatial_dim + 1]; + } + int64_t expected_output_size = 0; + auto status = GetWindowedOutputSizeVerbose( + input_size, filter_size, dilation, stride, padding, + &expected_output_size, &padding_before, &padding_after); + if (!status.ok()) return failure(); + if (output_size != expected_output_size) return failure(); + int64_t effective_filter_size = (filter_size - 1) * dilation + 1; + int64_t pad_before = effective_filter_size - 1 - padding_before; + int64_t padded_out_size = input_size + effective_filter_size - 1; + int64_t expanded_output_size = (output_size - 1) * stride + 1; + int64_t pad_after = padded_out_size - expanded_output_size - pad_before; + + // Populate metadata for the upcoming mhlo.conv op using the result of + // the computations performed above. + lhs_dilation.push_back(stride); + rhs_dilation.push_back(dilation); + paddings.push_back(pad_before); + paddings.push_back(pad_after); + } + + RankedTensorType paddings_ty = tensorflow::GetTypeFromTFTensorShape( + {num_spatial_dims, 2}, rewriter.getIntegerType(64)); + auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, paddings); + + Value filter = op.getFilter(); + + const int feature_dim = + tensorflow::GetTensorFeatureDimIndex(num_dims, data_format); + const int64_t in_depth = *(input_shape.begin() + feature_dim); + if (in_depth == ShapedType::kDynamic) return failure(); + const int64_t filter_in_depth = filter_shape[num_spatial_dims]; + const int64_t feature_group_count = in_depth / filter_in_depth; + + if (feature_group_count != 1) { + // 1. Reshape filter from + // [H, W, ..., filter_in_depth, out_depth] to + // [H, W, ..., filter_in_depth, G, out_depth / G]. + auto new_shape = llvm::to_vector<6>(filter_shape); + new_shape.back() = feature_group_count; + new_shape.push_back(filter_shape.back() / feature_group_count); + Type filter_element_ty = filter_ty.getElementType(); + auto ty = + tensorflow::GetTypeFromTFTensorShape(new_shape, filter_element_ty); + filter = rewriter.create(op.getLoc(), ty, filter); + + // 2. Transpose to [H, W, ..., G, filter_in_depth, out_depth / G]. + llvm::SmallVector perm(num_dims + 1); + std::iota(perm.begin(), perm.end(), 0); + std::swap(perm[num_spatial_dims], perm[num_spatial_dims + 1]); + std::swap(new_shape[num_spatial_dims], new_shape[num_spatial_dims + 1]); + ty = tensorflow::GetTypeFromTFTensorShape(new_shape, filter_element_ty); + filter = rewriter.create( + op.getLoc(), ty, filter, GetI64ElementsAttr(perm, &rewriter)); + + // 3. Reshape to [H, W, ..., in_depth, out_depth / G]. + new_shape[num_spatial_dims] *= new_shape[num_spatial_dims + 1]; + new_shape[num_spatial_dims + 1] = new_shape.back(); + new_shape.pop_back(); + ty = tensorflow::GetTypeFromTFTensorShape(new_shape, filter_element_ty); + filter = rewriter.create(op.getLoc(), ty, filter); + } + + SmallVector kernel_spatial_dims; + kernel_spatial_dims.resize(num_spatial_dims); + std::iota(kernel_spatial_dims.begin(), kernel_spatial_dims.end(), 0); + + // Mirror the filter in the spatial dimensions. + filter = rewriter.create( + op.getLoc(), filter, + GetI64ElementsAttr(kernel_spatial_dims, &rewriter)); + + // activation gradients + // = gradients (with padding and dilation) mirrored_weights + Value result = rewriter.create( + op.getLoc(), op.getType(), op.getOutBackprop(), filter, + /*window_strides=*/ + GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1, + &rewriter), + /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter), + GetI64ElementsAttr(rhs_dilation, &rewriter), + /*window_reversal=*/nullptr, + ConvDimensionNumbersAttr::get( + rewriter.getContext(), + /*inputBatchDimension=*/batch_dim, + /*inputFeatureDimension=*/feature_dim, + /*inputSpatialDimensions=*/spatial_dims, + // TF filter shape is [ H, W, ..., inC, outC ] + // Transpose the input and output features for computing the + // gradient. + /*kernelInputFeatureDimension=*/ + num_spatial_dims + 1, + /*kernelOutputFeatureDimension=*/ + num_spatial_dims, + /*kernelSpatialDimensions=*/kernel_spatial_dims, + /*outputBatchDimension=*/batch_dim, + /*outputFeatureDimension=*/feature_dim, + /*outputSpatialDimensions=*/spatial_dims), + rewriter.getI64IntegerAttr(feature_group_count), + /*batch_group_count=*/rewriter.getI64IntegerAttr(1), + /*precision_config=*/GetPrecisionConfig(&rewriter)); + + rewriter.replaceOp(op, {result}); + + return success(); + } +}; + +using ConvertConv2DBackpropInputOp = + ConvertConvBackpropInputOp; +using ConvertConv3DBackpropInputOp = + ConvertConvBackpropInputOp; + +// Converts tf.Conv?DBackpropFilterOp into: +// %result = "mhlo.convolution"(%input, %out_backprop) +template +class ConvertConvBackpropFilterOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // Unpack all of the attributes. + tensorflow::TensorFormat data_format; + if (!FormatFromString(op.getDataFormat().str(), &data_format)) + return op.emitOpError("invalid data format"); + + tensorflow::Padding padding; + if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) + return failure(); + + auto out_backprop_ty = + mlir::dyn_cast(op.getOutBackprop().getType()); + auto input_ty = mlir::dyn_cast(op.getInput().getType()); + + for (RankedTensorType ty : {out_backprop_ty, input_ty}) + if (!ty || !ty.hasStaticShape()) return failure(); + + ArrayRef out_backprop_shape = out_backprop_ty.getShape(); + ArrayRef input_shape = input_ty.getShape(); + + DenseIntElementsAttr filter_shape_attr; + if (!matchPattern(op.getFilterSizes(), m_Constant(&filter_shape_attr)) || + filter_shape_attr.getType().getRank() != 1) + return failure(); + + auto dilations_attr = GetI64ElementsAttr(op.getDilations()); + std::vector dilations{ + dilations_attr.template getValues().begin(), + dilations_attr.template getValues().end()}; + auto strides_attr = GetI64ElementsAttr(op.getStrides()); + std::vector strides{ + strides_attr.template getValues().begin(), + strides_attr.template getValues().end()}; + + std::vector explicit_paddings; + if (padding == tensorflow::Padding::EXPLICIT) { + // EXPLICIT padding mode and the associated attribute is limited to + // Conv2DBackpropFilter. So, fetch attribute by identifier instead of the + // op.explicit_paddings() attribute getter. + ArrayRef explicit_paddings_attr = + op->template getAttrOfType("explicit_paddings").getValue(); + explicit_paddings.reserve(explicit_paddings_attr.size()); + for (Attribute explicit_padding : explicit_paddings_attr) + explicit_paddings.push_back( + mlir::cast(explicit_padding).getInt()); + } + + constexpr int num_dims = num_spatial_dims + 2; + auto filter_shape = filter_shape_attr.getValues(); + + // Reuse dimension computation logic from conv_grad_shape_utils.cc. + tensorflow::ConvBackpropDimensions dims; + if (!tensorflow::ConvBackpropComputeDimensionsV2( + /*label=*/"", num_spatial_dims, + ToTensorShape(input_shape), + ToTensorShape(filter_shape), + ToTensorShape(out_backprop_shape), dilations, + strides, padding, explicit_paddings, data_format, &dims) + .ok()) { + return failure(); + } + + // The activations (inputs) form the LHS of the convolution. + // Activations have shape: [batch, in_rows, in_cols, ..., in_depth] + // For the gradient computation, we need to: + // 1. In the case of group convolution, move the num_groups dimension before + // the batch dimension + // 2. Swap the roles of the batch and feature dimensions. + const int feature_dim = + tensorflow::GetTensorFeatureDimIndex(num_dims, data_format); + const int64_t in_depth = input_shape[feature_dim]; + const int64_t filter_in_depth = *(filter_shape.begin() + num_spatial_dims); + const int64_t batch_group_count = in_depth / filter_in_depth; + + // Compute ConvDimensionNumbers, dilation, and padding. + SmallVector spatial_dims; + SmallVector kernel_spatial_dims; + SmallVector rhs_dilation; + SmallVector paddings; + SmallVector window_strides; + + // The filter gradients are computed by a convolution of the input + // activations and the output gradients, with some appropriate padding. + // See the comment at the top of conv_grad_ops.h for details. + + for (int i : llvm::seq(0, num_spatial_dims)) { + const int64_t dim = + tensorflow::GetTensorSpatialDimIndex(num_dims, data_format, i); + kernel_spatial_dims.push_back(dim); + // Besides padding the input, we will also expand output_rows to + // expanded_out_rows = (output_rows - 1) * stride + 1 + // with zeros in between: + // + // a . . . b . . . c . . . d . . . e + // + // This is done by specifying the window dilation factors in the + // convolution HLO below. + const auto &spatial_dim_i = dims.spatial_dims[i]; + rhs_dilation.push_back(spatial_dim_i.stride); + window_strides.push_back(dilations[dim]); + + // We will also need to pad the input with zeros such that after the + // convolution, we get the right size for the filter. + // The padded_in_rows should be such that when we convolve this with the + // expanded_out_rows as a filter, we should get filter_rows back. + + const int64_t padded_in_size = + spatial_dim_i.expanded_output_size + + (spatial_dim_i.filter_size - 1) * dilations[dim]; + + // However it can be smaller than input_rows: in this + // case it means some of the inputs are not used. + // + // An example is to have input_cols = 3, filter_cols = 2 and stride = 2: + // + // INPUT = [ A B C ] + // + // FILTER = [ x y ] + // + // and the output will only have one column: a = A * x + B * y + // + // and input "C" is not used at all. + // + // We apply negative padding in this case. + const int64_t pad_total = padded_in_size - spatial_dim_i.input_size; + + // + For the EXPLICIT padding, we pad the top/left side with the explicit + // padding and pad the bottom/right side with the remaining space. + // + For the VALID padding, we don't pad anything on the top/left side + // and pad the bottom/right side with the remaining space. + // + For the SAME padding, we pad top/left side the same as bottom/right + // side. + // + // In addition, if the padded input size is smaller than the input size, + // we need to ignore some training elements of the input. We do this by + // applying negative padding on the right/bottom. + const int64_t pad_before = padding == tensorflow::Padding::EXPLICIT + ? explicit_paddings[2 * dim] + : padding == tensorflow::Padding::SAME + ? std::max(pad_total / 2, 0) + : 0; + paddings.push_back(pad_before); + paddings.push_back(pad_total - pad_before); + } + + RankedTensorType paddings_ty = tensorflow::GetTypeFromTFTensorShape( + {num_spatial_dims, 2}, rewriter.getIntegerType(64)); + auto paddings_attr = DenseIntElementsAttr::get(paddings_ty, paddings); + + SmallVector output_spatial_dimensions; + output_spatial_dimensions.resize(num_spatial_dims); + std::iota(output_spatial_dimensions.begin(), + output_spatial_dimensions.end(), 0); + + const int batch_dim = + tensorflow::GetTensorBatchDimIndex(num_dims, data_format); + + Value result = rewriter.create( + op.getLoc(), op.getType(), op.getInput(), op.getOutBackprop(), + /*window_strides=*/GetI64ElementsAttr(window_strides, &rewriter), + /*padding=*/paddings_attr, /*lhs_dilation=*/ + GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1, + &rewriter), + GetI64ElementsAttr(rhs_dilation, &rewriter), + /*window_reversal=*/nullptr, + ConvDimensionNumbersAttr::get( + rewriter.getContext(), + // Swap batch_dim and feature_dim in the activations. + /*inputBatchDimension=*/feature_dim, + /*inputFeatureDimension=*/batch_dim, + /*inputSpatialDimensions=*/kernel_spatial_dims, + // The gradients become the RHS of the convolution. + // The gradients have shape [batch, out_rows, out_cols, ..., + // out_depth] where the batch becomes the input feature for the + // convolution. + /*kernelInputFeatureDimension=*/batch_dim, + /*kernelOutputFeatureDimension=*/feature_dim, + /*kernelSpatialDimensions=*/kernel_spatial_dims, + /*outputBatchDimension=*/num_spatial_dims, + /*outputFeatureDimension=*/num_spatial_dims + 1, + /*outputSpatialDimensions=*/output_spatial_dimensions), + /*feature_group_count=*/rewriter.getI64IntegerAttr(1), + rewriter.getI64IntegerAttr(batch_group_count), + /*precision_config=*/GetPrecisionConfig(&rewriter)); + + rewriter.replaceOp(op, {result}); + + return success(); + } +}; + +using ConvertConv2DBackpropFilterOp = + ConvertConvBackpropFilterOp; +using ConvertConv3DBackpropFilterOp = + ConvertConvBackpropFilterOp; + +class ConvertOneHotOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::OneHotOp op, + PatternRewriter &rewriter) const override { + auto indices_ty = + mlir::dyn_cast(op.getIndices().getType()); + if (!indices_ty || !indices_ty.hasStaticShape()) return failure(); + ArrayRef indices_shape = indices_ty.getShape(); + Type element_type = indices_ty.getElementType(); + + DenseIntElementsAttr depth_attr; + if (!matchPattern(op.getDepth(), m_Constant(&depth_attr))) { + return failure(); + } + + int64_t depth = depth_attr.getValues()[0].getSExtValue(); + int64_t axis = op.getAxis(); + if (axis == -1) axis = indices_shape.size(); + + llvm::SmallVector broadcast_dims(indices_shape.size()); + std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); + std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); + + llvm::SmallVector output_dims = + llvm::to_vector<4>(indices_shape); + output_dims.insert(output_dims.begin() + axis, depth); + + Location loc = op.getLoc(); + + // The iota result is the effective output shape of the computation, + // and indices must be broadcast into it. At this point, this computation + // would need to be reworked quite a bit to support dynamic shapes, so + // just using static broadcasting. + auto index_type = + tensorflow::GetTypeFromTFTensorShape(output_dims, element_type); + auto iota = rewriter.create( + loc, index_type, IntegerAttr::get(rewriter.getIntegerType(64), axis)); + auto broadcast_indices = rewriter.create( + loc, index_type, op.getIndices(), + GetI64ElementsAttr(broadcast_dims, &rewriter)); + + Value compare = rewriter.create( + loc, broadcast_indices, iota, ComparisonDirection::EQ); + Value on_value = rewriter.create( + loc, op.getType(), op.getOnValue(), + GetI64ElementsAttr(output_dims, &rewriter)); + Value off_value = rewriter.create( + loc, op.getType(), op.getOffValue(), + GetI64ElementsAttr(output_dims, &rewriter)); + Value result = rewriter.create(loc, op.getType(), compare, + on_value, off_value); + + rewriter.replaceOp(op, {result}); + + return success(); + } +}; + +// Converts InfeedDequeueTuple to XLA HLO create_token, infeed and +// get_tuple_element ops. +// +// All HLO infeed ops expect a HLO token type operand and produce a tuple +// containing a token. This HLO token type is used to order multiple infeed +// operations within a computation. The token type can come from other +// infeed/outfeed/send/recv ops or can be generated using create_token op with +// no operands. Here we emit a create_token op to generate the token type +// operand of infeed. The mhlo.InfeedOp can produce multiple results and later +// will be exported to XLA infeed op with single tuple return type. +// +// For example the following IR: +// %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3xi32>, tensor<4xf32>) +// +// would be lowered to +// +// %token = "mhlo.create_token"() : () -> !mhlo.token +// %data_and_token = "mhlo.infeed"(%token) {infeed_config = ""} : +// (!mhlo.token) -> tensor<3xi32>, tensor<4xf32>, !mhlo.token> +// +class ConvertInfeedDequeueTupleOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::InfeedDequeueTupleOp op, + PatternRewriter &rewriter) const override { + SmallVector result_types; + result_types.reserve(op.getOutputs().size() + 1); + for (const auto &output : op.getOutputs()) { + Type ty = output.getType(); + if (auto tensor_ty = mlir::dyn_cast(ty)) { + if (!tensor_ty.hasStaticShape()) return failure(); + } + result_types.push_back(ty); + } + + // Infeed takes a single token operand. Generate the token using + // create_token op to pass to the infeed op. + auto token = rewriter.create( + op.getLoc(), mhlo::TokenType::get(rewriter.getContext())); + + result_types.push_back(token.getType()); + + ArrayAttr layout; // filled in during the xla-adjust-layout pass + auto data_and_token = + rewriter.create(op.getLoc(), result_types, token, + /*infeed_config=*/rewriter.getStringAttr(""), + /*layout=*/layout); + + result_types.pop_back(); // remove the token type. + + if (op.get_XlaSharding().has_value()) { + // _XlaSharding attribute in TF is a serialized string of the OpSharding + // proto, so convert to a text form here. + ::xla::OpSharding sharding_proto; + if (tensorflow::DecodeShardingAttribute( + op.get_XlaSharding().value().str(), sharding_proto) + .failed()) { + return failure(); + } + // Token is a control signal and not a real data, so arbitrarily assign + // the token to device 0. + if (sharding_proto.type() == ::xla::OpSharding::TUPLE) { + *sharding_proto.add_tuple_shardings() = + ::xla::sharding_builder::AssignDevice(0); + data_and_token->setAttr( + kShardingAttr, + rewriter.getStringAttr(sharding_proto.SerializeAsString())); + } else { + data_and_token->setAttr(kShardingAttr, op.get_XlaShardingAttr()); + } + } + + if (op->hasAttr("layouts")) { + // Append a UnitAttr for the "token" operand of the mhlo.infeed op here to + // avoid compilation failure when exporting "layouts" attribute of the + // corresponding InfeedDequeueTupleOp to a graph node. + data_and_token->setAttr("layout", op->getAttr("layouts")); + } + llvm::SmallVector results; + results.reserve(result_types.size()); + for (const auto &idx_and_type : llvm::enumerate(result_types)) { + results.push_back(data_and_token.getResult(idx_and_type.index())); + } + rewriter.replaceOp(op, ValueRange(results)); + return success(); + } +}; + +// Converts tf.OutfeedEnqueueTuple to XLA HLO tuple, create_token and outfeed +// ops. +// +// XLA HLO outfeed op expects a token, which we generate by emitting an +// create_token op. +// +// For example the following IR: +// "tf.OutfeedEnqueueTuple"(%val_1, %val_2) : (tensor<3xi32>, tensor<4xf32>) -> +// () +// +// would be lowered to +// +// %token = "mhlo.create_token"() : () -> !mhlo.token +// %outfeed_token = "mhlo.outfeed"(%val_1, %val_2, %token) {outfeed_config = ""} +// : +// (tensor<3xi32>, tensor<4xf32>, !mhlo.token) -> !mhlo.token +// +class ConvertOutfeedEnqueueTupleOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::OutfeedEnqueueTupleOp op, + PatternRewriter &rewriter) const override { + auto token_type = mhlo::TokenType::get(rewriter.getContext()); + auto token = rewriter.create(op.getLoc(), token_type); + + rewriter.create(op.getLoc(), token_type, op.getInputs(), token, + /*outfeed_config=*/rewriter.getStringAttr("")); + rewriter.eraseOp(op); + return success(); + } +}; + +// Converts tf.TopKV2 to chlo.top_k. +class ConvertTopKV2Op : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::TopKV2Op op, + PatternRewriter &rewriter) const override { + // We can only match when the `k` operand is a constant scalar. + DenseIntElementsAttr k_attr; + if (!matchPattern(op.getK(), m_Constant(&k_attr))) return failure(); + int64_t k = (*k_attr.begin()).getSExtValue(); + + TensorType input_type = mlir::cast(op.getInput().getType()); + if (!input_type.hasRank()) return failure(); + int64_t input_rank = input_type.getRank(); + int64_t last_dim_index = input_rank - 1; + int64_t last_dim_size = input_type.getDimSize(last_dim_index); + if (last_dim_size == ShapedType::kDynamic) return failure(); + + rewriter.replaceOpWithNewOp(op, op.getInput(), k); + return success(); + } +}; + +// Converts tf.Unpack to a series of XLA HLO slice ops. +// +// Each slice takes one element along the dimension to unpack and takes the full +// range for all other dimensions. Each slice is then reshaped to drop the +// dimension to unpack (which is always of size 1). +// TODO(antiagainst): consider changing this into a TF internal lowering pass. +class ConvertUnpackOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::UnpackOp op, + PatternRewriter &rewriter) const override { + auto value_type = mlir::dyn_cast(op.getValue().getType()); + if (!value_type) return failure(); + + int64_t value_rank = value_type.getRank(); + int64_t axis = op.getAxis(); + if (axis < 0) axis += value_rank; + + // Parameters for constructing each slice. + SmallVector begin_indices(value_rank, 0); + auto end_indices = llvm::to_vector<4>(value_type.getShape()); + SmallVector strides(value_rank, 1); + + // All HLO slice+squeeze results used to replace the original tf.Unpack op. + SmallVector results; + results.reserve(op.getNumResults()); + + for (int i = 0, end = op.getNumResults(); i < end; ++i) { + begin_indices[axis] = i; + end_indices[axis] = i + 1; + + auto slice_op = rewriter.create( + op.getLoc(), op.getValue(), + GetI64ElementsAttr(begin_indices, &rewriter), + GetI64ElementsAttr(end_indices, &rewriter), + GetI64ElementsAttr(strides, &rewriter)); + // Reshape to drop the axis dimension. + auto result = rewriter.create( + op.getLoc(), op.getType(i), slice_op, + rewriter.getI64ArrayAttr(op.getAxis())); + results.push_back(result); + } + + rewriter.replaceOp(op, results); + return success(); + } +}; + +// Converts tf.Unpack to a series of XLA HLO Slice ops. +// TODO(disc): To recover static special case's performance with folding and +// canonicalization. +class ConvertUnpackOpDynamic : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::UnpackOp op, + PatternRewriter &rewriter) const override { + auto value_type = mlir::dyn_cast(op.getValue().getType()); + if (!value_type) return failure(); + // TODO(disc): Remove this constraint once fold and canonicalization + // implemented. + if (value_type.hasStaticShape()) return failure(); + + int64_t value_rank = value_type.getRank(); + int64_t axis = op.getAxis(); + if (axis < 0) axis += value_rank; + Location loc = op.getLoc(); + + auto shape_scalar_type = rewriter.getIntegerType(32); + // Parameters for constructing each slice. + SmallVector begin_indices, end_indices, strides; + begin_indices.reserve(value_rank); + end_indices.reserve(value_rank); + strides.reserve(value_rank); + // final output shape + SmallVector shape_values; + shape_values.reserve(value_rank - 1); + // slice shape before reshape, should be like{?, 1, ?, ?} if axis = 1 + SmallVector slice_shape(value_rank, ShapedType::kDynamic); + for (int64_t dim_idx = 0; dim_idx < value_rank; ++dim_idx) { + int64_t dim_size = value_type.getDimSize(dim_idx); + if (dim_size == ShapedType::kDynamic) { + Value dim_i = rewriter.create( + loc, shape_scalar_type, + rewriter.create(loc, op.getOperand(), dim_idx)); + end_indices.push_back(dim_i); + if (dim_idx != axis) { + shape_values.push_back(dim_i); + } + } else { + Value dim_i = rewriter.create( + loc, shape_scalar_type, + rewriter.getIntegerAttr(shape_scalar_type, dim_size)); + end_indices.push_back(dim_i); + if (dim_idx != axis) { + shape_values.push_back(dim_i); + slice_shape[dim_idx] = dim_size; + } else { + slice_shape[dim_idx] = 1; + } + } + begin_indices.push_back( + rewriter.create(loc, 0, 32)); + strides.push_back(rewriter.create(loc, 1, 32)); + } + + SmallVector results; + results.reserve(op.getNumResults()); + Type i32_ty = rewriter.getI32Type(); + for (int64_t i = 0; i < op.getNumResults(); ++i) { + begin_indices[axis] = rewriter.create(loc, i, 32); + end_indices[axis] = rewriter.create(loc, i + 1, 32); + Value slice_op = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape(slice_shape, + value_type.getElementType()), + op.getValue(), + rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(begin_indices.size())}, i32_ty), + begin_indices), + rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(end_indices.size())}, i32_ty), + end_indices), + rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(strides.size())}, i32_ty), + strides)); + // Reshape to drop the axis dimension. + Value new_shape = rewriter.create( + loc, + tensorflow::GetTypeFromTFTensorShape( + {static_cast(shape_values.size())}, i32_ty), + shape_values); + Value reshape_op = rewriter.create(loc, op.getType(i), + slice_op, new_shape); + results.push_back(reshape_op); + } + + rewriter.replaceOp(op, results); + return success(); + } +}; + +// Converts the tf.SigmoidGradOp +// TODO(disc): To recover static special case's performance with folding and +// canonicalization. +class ConvertSigmoidGradOpDynamic : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::SigmoidGradOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value y = op.getY(); + Value dy = op.getDy(); + auto tp_y = mlir::dyn_cast(y.getType()); + auto tp_dy = mlir::dyn_cast(dy.getType()); + if (!tp_y || !tp_dy) return failure(); + + // TODO(disc): Remove this constraint once fold and canonicalization + // implemented. + if (tp_y.hasStaticShape() || tp_dy.hasStaticShape()) return failure(); + + Attribute attr; + Type elem_tp = tp_y.getElementType(); + if (elem_tp.isSignlessInteger()) { + attr = rewriter.getIntegerAttr(elem_tp, 1); + } else { + assert(mlir::isa(elem_tp)); + attr = rewriter.getFloatAttr(elem_tp, 1); + } + Value one = rewriter.create( + loc, DenseElementsAttr::get( + tensorflow::GetTypeFromTFTensorShape({}, elem_tp), attr)); + + auto v0 = rewriter.create( + loc, dy, y, hlo::getBroadcastDimensionsAttr(&rewriter, dy, y)); + auto v1 = rewriter.create( + loc, one, y, hlo::getBroadcastDimensionsAttr(&rewriter, one, y)); + auto result = rewriter.create( + loc, v0, v1, hlo::getBroadcastDimensionsAttr(&rewriter, v0, v1)); + + rewriter.replaceOp(op, result.getOperation()->getResults()); + return success(); + } +}; + +// Converts TF unsorted segment reduction ops to XLA HLO scatter op. +// +// TF unsorted segment reduction op peforms the following calculation: +// +// Assume segment ids' shape is [SI0, SI1, ..., SIm] and data's shape is +// [D0, D1, ..., Dn]. Note that segment ids' shape must be a prefix of data's +// shape, so we can have data's shape represented as [SI0, SI1, ..., SIm, +// Dm+1, ..., Dn]. Then +// output[segment_ids[SI_i0, SI_i1, ..., SI_im], D_im+1, ..., D_in] = +// over data[SI_i0, SI_i1, ..., SI_im, D_im+1, ..., D_in] +// where SI_iN is in the range of [0, SIN) and D_iN is in the range of [0, DN). +// +// The op will be translated to XLA HLO scatter with the following parameters: +// * Update window dims is [segment_id_rank, data_rank). +// * Inserted window dims is {0}. +// * Scatter dims to operand dims mapping is {0}. +// * Index vector dim is segment_id_rank. +template +class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + auto data_type = mlir::dyn_cast(op.getData().getType()); + if (!data_type) return failure(); + int64_t data_rank = data_type.getRank(); + + auto segment_ids_type = + mlir::dyn_cast(op.getSegmentIds().getType()); + if (!segment_ids_type) return failure(); + int64_t segment_ids_rank = segment_ids_type.getRank(); + + DenseIntElementsAttr num_segments_attr; + if (!matchPattern(op.getNumSegments(), m_Constant(&num_segments_attr))) + return failure(); + + // The final shape for TF unsorted segment reduction op is [num_segments] + + // data_shape[segment_ids_rank:]. + SmallVector output_shape; + output_shape.push_back((*num_segments_attr.begin()).getSExtValue()); + auto suffix = data_type.getShape().drop_front(segment_ids_rank); + output_shape.append(suffix.begin(), suffix.end()); + auto output_type = tensorflow::GetTypeFromTFTensorShape( + output_shape, data_type.getElementType()); + + // Broadcast the initial value for reduction. This will become the + // 'operand' parameter to scatter to for the final scatter op. + Value init = ConcreteClass::GetInitialValue(data_type.getElementType(), + op.getLoc(), &rewriter); + auto broadcasted_init = rewriter.create( + op.getLoc(), output_type, init, + GetI64ElementsAttr(output_shape, &rewriter)); + + // Parameters for the generated scatter op. + SmallVector inserted_window_dims(1, 0); + SmallVector scatter_dims_to_operand_dims(1, 0); + int64_t index_vector_dim = segment_ids_rank; + + // Put all parameters in a StructAttr. + auto dims_attr = ScatterDimensionNumbersAttr::get( + rewriter.getContext(), + llvm::to_vector<4>(llvm::seq(segment_ids_rank, data_rank)), + inserted_window_dims, + /*inputBatchingDims=*/{}, + /*scatterIndicesBatchingDims=*/{}, scatter_dims_to_operand_dims, + index_vector_dim); + + auto scatter = rewriter.create( + op.getLoc(), op.getType(), ValueRange(Value(broadcasted_init)), + op.getSegmentIds(), op.getData(), dims_attr); + BuildReduceBody(data_type.getElementType(), + &scatter.getUpdateComputation(), &rewriter); + + rewriter.replaceOp(op, scatter.getResult(0)); + return success(); + } +}; + +class ConvertUnsortedSegmentMaxOp + : public GenericConvertUnsortedSegmentReductionOp< + ConvertUnsortedSegmentMaxOp, TF::UnsortedSegmentMaxOp, MaxOp> { + public: + using GenericConvertUnsortedSegmentReductionOp:: + GenericConvertUnsortedSegmentReductionOp; + + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter *rewriter) { + return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kLowest, + rewriter); + } +}; + +class ConvertUnsortedSegmentMinOp + : public GenericConvertUnsortedSegmentReductionOp< + ConvertUnsortedSegmentMinOp, TF::UnsortedSegmentMinOp, MinOp> { + public: + using GenericConvertUnsortedSegmentReductionOp:: + GenericConvertUnsortedSegmentReductionOp; + + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter *rewriter) { + return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kMax, + rewriter); + } +}; + +class ConvertUnsortedSegmentProdOp + : public GenericConvertUnsortedSegmentReductionOp< + ConvertUnsortedSegmentProdOp, TF::UnsortedSegmentProdOp, MulOp> { + public: + using GenericConvertUnsortedSegmentReductionOp:: + GenericConvertUnsortedSegmentReductionOp; + + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter *rewriter) { + return GetScalarConstOfType(reduce_element_type, loc, 1, rewriter); + } +}; + +class ConvertUnsortedSegmentSumOp + : public GenericConvertUnsortedSegmentReductionOp< + ConvertUnsortedSegmentSumOp, TF::UnsortedSegmentSumOp, AddOp> { + public: + using GenericConvertUnsortedSegmentReductionOp:: + GenericConvertUnsortedSegmentReductionOp; + + static Value GetInitialValue(Type reduce_element_type, Location loc, + PatternRewriter *rewriter) { + return GetScalarConstOfType(reduce_element_type, loc, 0, rewriter); + } +}; + +// Converts tf.RandomShuffle op into a series of XLA HLO ops. +// +// tf.RandomShuffle shuffles tensors along the first dimension. If the input +// tensor's rank is 1, then it is translated into HLO sort op(s) according to +// indices randomly generated via HLO rng_uniform ops. Otherwise, it is +// translated into an HLO while op to first emulate shuffling indices using +// HLO dynamic_slice and dynamic_update_slice ops, then finally HLO gather +// with the shuffled indices. +class ConvertRandomShuffleOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::RandomShuffleOp op, + PatternRewriter &rewriter) const override { + auto no_op = [&]() { + rewriter.replaceOp(op, op.getValue()); + return success(); + }; + + auto input_type = mlir::dyn_cast(op.getValue().getType()); + if (!input_type) return failure(); + if (input_type.hasStaticShape() && input_type.getNumElements() <= 1) + // No shuffling is required, so copy input directly to output. + return no_op(); + + int64_t input_rank = input_type.getRank(); + int64_t first_dim_size = input_type.getDimSize(0); + if (ShapedType::isDynamic(first_dim_size)) return failure(); + + if (first_dim_size <= 1) + // No shuffling is required, so copy input directly to output. + return no_op(); + + // For vectors, shuffle values by sorting instead of the obvious + // Fisher-Yates algorithm. Fisher-Yates is simple to implement and correct, + // but not easily parallelizable. For a sufficiently parallel architecture, + // it is faster to sort many times, than Fisher-Yates shuffle once. + if (input_rank == 1) { + // Shuffle values by assigning each value a random key and sorting the + // keys. Keys can collide causing detectable patterns in the shuffled + // output. Collisions translates into more ascending sub-sequences in the + // shuffled output than would be expected by chance. To avoid collisions, + // the number of possible key values must be sufficiently large. + + // How are more than 2^32 keys created? In each loop iteration, the + // algorithm sorts by random keys. Conceptually, the earlier iterations + // are sorting on the lower-order bits of larger keys that are never + // actually assembled. + + // The expected number of collisions is n - d + d(1 - 1/d)^n, where d is + // the number of possible keys and n is the number of values. If d = n^2, + // then the limit as n goes to infinity is 1/2. If d = n^3, then the limit + // as n goes to infinity is zero. + + // This implementation ensures that the key-space is greater than or equal + // to the cube of the number of values. The risk of collisions can be + // further reduced by increasing Exponent at the expense of + // performance. + + // For Exponent = 2, the expected number of collisions per shuffle is + // maximized at n = floor((2^32-1)^(1/2)) = 65535 where the expectation is + // about 1/2. + + // For Exponent = 3, the expected number of collisions per shuffle is + // maximized at n = floor((2^32-1)^(1/3)) = 1625 where the expectation is + // about 1/3255. + + // For Exponent = 4, the expected number of collisions per shuffle is + // maximized at n = floor((2^32-1)^(1/4)) = 255 where the expectation is + // about 1/132622. + constexpr int exponent = 3; + int64_t num_elements = input_type.getNumElements(); + uint32_t u32_max = std::numeric_limits::max(); + int rounds = + std::ceil(exponent * std::log(num_elements) / std::log(u32_max)); + + Value current = op.getValue(); + for (int i = 0; i < rounds; ++i) { + auto keys = + CreateRngUniform32(op.getLoc(), num_elements, /*lower_limit=*/0, + /*upper_limit=*/u32_max, &rewriter); + auto sorted = createSortOp( + &rewriter, op.getLoc(), {keys, current}, + {rewriter.getIntegerType(32), input_type.getElementType()}, + /*dimension=*/-1, /*isStable=*/false, + /*direction=*/ComparisonDirection::LT); + current = sorted.getResult(1); + } + rewriter.replaceOp(op, current); + return success(); + } + + // The Fisher-Yates algorithm. + + // Generate range(n) as the initial value for the indices to be swapped. + auto indices_type = tensorflow::GetTypeFromTFTensorShape( + {first_dim_size}, rewriter.getIntegerType(32)); + Value indices = rewriter.create( + op.getLoc(), indices_type, rewriter.getI64IntegerAttr(0)); + + // Generate random numbers to be used as swaps for the indices. + Value swaps = CreateRngUniform32(op.getLoc(), first_dim_size, 0, + first_dim_size, &rewriter); + + // While loop body to perform index swaps. + auto swap_body_fn = [&](Location loc, Value i, ArrayRef old_values, + SmallVectorImpl *new_values, + OpBuilder *builder) { + Value swaps = old_values[0]; + Value indices = old_values[1]; + + auto scalar_i32_type = + tensorflow::GetTypeFromTFTensorShape({}, builder->getIntegerType(32)); + auto one_cross_i64_type = tensorflow::GetTypeFromTFTensorShape( + {1}, builder->getIntegerType(64)); + + auto scalar_one = + DenseIntElementsAttr::get(one_cross_i64_type, ArrayRef(1)); + + // We need to swap the indices[i] with indices[swaps[i]]. First get + // these index values. + Value source_index = + builder->create(loc, indices, i, scalar_one); + Value swap_index = builder->create( + loc, scalar_i32_type, + builder->create(loc, swaps, i, scalar_one)); + Value target_index = builder->create( + loc, indices, swap_index, scalar_one); + + // Then perform the swap. + // indices[i] <- indices[swaps[i]] + indices = builder->create( + loc, indices.getType(), indices, target_index, llvm::ArrayRef(i)); + // indices[swaps[i]] <- indices[i] + indices = builder->create( + loc, indices.getType(), indices, source_index, + llvm::ArrayRef(swap_index)); + + // Update new values. + new_values->assign({swaps, indices}); + }; + + // Create a while op to swap indices. + SmallVector while_output; + CreateWhile32(op.getLoc(), first_dim_size, swap_body_fn, {swaps, indices}, + &while_output, &rewriter); + Value swaped_indices = while_output[1]; + + // Gather the data using the swapped indices as the shuffled order. + auto slice_sizes = tensorflow::ConvertMlirShapeToTF(input_type.getShape()); + slice_sizes[0] = 1; + auto dims_attr = GatherDimensionNumbersAttr::get( + rewriter.getContext(), + /*offsetDims=*/llvm::to_vector<4>(llvm::seq(1, input_rank)), + /*collapsedSliceDims=*/{0}, + /*operandBatchingDims=*/{}, + /*startIndicesBatchingDims=*/{}, + /*startIndexMap=*/{0}, + /*indexVectorDim=*/1); + + SmallVector slice_sizes_values; + for (auto i = 0; i < slice_sizes.size(); ++i) { + if (slice_sizes[i] == tensorflow::kTFDynamicSize) { + Value i_const = rewriter.create( + op.getLoc(), rewriter.getIndexAttr(i)); + Value slice_size_index = + rewriter.create(op.getLoc(), op.getValue(), i_const); + Value index_to_i64 = rewriter.create( + op.getLoc(), rewriter.getI64Type(), slice_size_index); + Value i64_to_tensor = rewriter.create( + op.getLoc(), + tensorflow::GetTypeFromTFTensorShape({1}, rewriter.getI64Type()), + index_to_i64); + slice_sizes_values.push_back(i64_to_tensor); + } else { + slice_sizes_values.push_back(rewriter.create( + op.getLoc(), GetI64ElementsAttr({slice_sizes[i]}, &rewriter))); + } + } + + auto slice_sizes_concat = rewriter.create( + op.getLoc(), slice_sizes_values, rewriter.getI64IntegerAttr(0)); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getValue(), swaped_indices, slice_sizes_concat, + dims_attr); + + return success(); + } +}; + +// Converts an XlaSharding op to a XLA HLO shard op with sharding attributes. +class ConvertXlaShardingOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::XlaShardingOp op, + PatternRewriter &rewriter) const override { + // TODO(b/148313088): define sharding attribute struct in MLIR intead of + // using a string. + if (!op.get_XlaSharding().has_value()) return failure(); + + NamedAttribute call_target_name = rewriter.getNamedAttr( + "call_target_name", rewriter.getStringAttr("Sharding")); + + auto custom_call = rewriter.create( + op.getLoc(), op.getType(), op.getInput(), + ArrayRef{call_target_name}); + custom_call->setAttr(kShardingAttr, op.get_XlaShardingAttr()); + rewriter.replaceOp(op, custom_call.getResult(0)); + + return success(); + } +}; + +// Converts a TF InplaceUpdate op to DynamicUpdateSlice HLO. +class ConvertInplaceUpdateOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::InplaceUpdateOp op, + PatternRewriter &rewriter) const override { + auto input = mlir::dyn_cast>(op.getX()); + if (!input) return failure(); + auto indices = op.getI(); + auto updates = op.getV(); + + // Slice each row of `i` and `v` to perform a separate dynamic-update-slice + // on the contents of `x`. + auto input_type = mlir::cast(input.getType()); + auto updates_type = mlir::cast(updates.getType()); + auto indices_type = mlir::cast(indices.getType()); + if (!input_type.hasRank()) return failure(); + if (!updates_type.hasRank() || updates_type.isDynamicDim(0)) + return failure(); + if (!indices_type.hasStaticShape()) return failure(); + + if (indices_type.getRank() != 1) return failure(); + + SmallVector unpacked_indices_type( + indices_type.getDimSize(0), tensorflow::GetTypeFromTFTensorShape( + {}, indices_type.getElementType())); + // Note on zero_attr integer type: DynamicUpdateSlice op start_indices are + // required to have matching types. This rewrite rule creates + // DynamicUpdateSlice ops where the first "start index" is always i32 and + // subsequent ones are constructed based on zero_attr. Thus the type + // for zero_attr needs to be i32 as well. + auto zero_attr = IntegerAttr::get(rewriter.getIntegerType(32), 0); + auto unpacked_indices = rewriter.create( + op.getLoc(), unpacked_indices_type, indices, zero_attr); + + SmallVector split_updates_shape; + split_updates_shape.append(updates_type.getShape().begin(), + updates_type.getShape().end()); + split_updates_shape.front() = 1; + SmallVector split_updates_type; + split_updates_type.resize( + updates_type.getShape().front(), + tensorflow::GetTypeFromTFTensorShape(split_updates_shape, + updates_type.getElementType())); + + auto cst = + rewriter.create(op.getLoc(), zero_attr).getResult(); + auto split_updates = rewriter.create( + op.getLoc(), split_updates_type, cst, updates); + + SmallVector input_indices; + input_indices.resize(input_type.getRank(), cst); + + for (auto pair : + llvm::zip(unpacked_indices.getOutput(), split_updates.getOutput())) { + input_indices.front() = std::get<0>(pair); + input = rewriter.create( + op.getLoc(), op.getType(), input, std::get<1>(pair), input_indices); + } + + rewriter.replaceOp(op, input); + return success(); + } +}; + +// Converts a TF XlaDynamicUpdateSlice op to DynamicUpdateSlice HLO. +class ConvertXlaDynamicUpdateSliceOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::XlaDynamicUpdateSliceOp op, + PatternRewriter &rewriter) const override { + auto indices_type = + mlir::dyn_cast(op.getIndices().getType()); + if (!indices_type || !indices_type.hasStaticShape() || + indices_type.getShape().size() != 1) + return failure(); + + SmallVector unpacked_indices_type( + indices_type.getDimSize(0), tensorflow::GetTypeFromTFTensorShape( + {}, indices_type.getElementType())); + auto unpacked_indices = rewriter.create( + op.getLoc(), unpacked_indices_type, op.getIndices(), + IntegerAttr::get(rewriter.getIntegerType(64), 0)); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getInput(), op.getUpdate(), + unpacked_indices.getOutput()); + return success(); + } +}; + +// Converts a TF XlaReduceScatter op to ReduceScatter HLO. +class ConvertXlaReduceScatterOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::XlaReduceScatterOp op, + PatternRewriter &rewriter) const override { + DenseIntElementsAttr group_assignment; + if (!matchPattern(op.getGroupAssignment(), m_Constant(&group_assignment))) + return failure(); + auto replica_groups = + mlir::cast(hlo::convertElementsAttr( + group_assignment, rewriter.getIntegerType(64))); + if (replica_groups.getType().getRank() != 2) return failure(); + + APInt scatter_dimension; + if (!matchPattern(op.getScatterDimension(), + m_ConstantInt(&scatter_dimension))) + return failure(); + + Location loc = op.getLoc(); + Type element_type = getElementTypeOrSelf(op.getInput().getType()); + + auto reduce_scatter = rewriter.create( + loc, op.getType(), op.getInput(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + scatter_dimension.getSExtValue()), + replica_groups, ChannelHandleAttr()); + StringRef reduce_op = op.getReduceOp(); + if (reduce_op == "Add") { + BuildReduceBody(element_type, &reduce_scatter.getComputation(), + &rewriter); + } else if (reduce_op == "Mul") { + BuildReduceBody(element_type, &reduce_scatter.getComputation(), + &rewriter); + } else if (reduce_op == "Min") { + BuildReduceBody(element_type, &reduce_scatter.getComputation(), + &rewriter); + } else if (reduce_op == "Max") { + BuildReduceBody(element_type, &reduce_scatter.getComputation(), + &rewriter); + } else { + // For mean, add replicas in the same group. Then divide the sum by the + // number of replicas in each group below. + assert(reduce_op == "Mean"); + BuildReduceBody(element_type, &reduce_scatter.getComputation(), + &rewriter); + } + Value result = reduce_scatter.getResult(); + + // For mean, divide the merge result by group size. + if (reduce_op == "Mean") { + int64_t replica_group_size = replica_groups.getType().getDimSize(1); + if (replica_group_size == 0) return failure(); + auto divisor = GetScalarConstOfType(element_type, loc, replica_group_size, + &rewriter); + auto broadcast_dims = rewriter.getDenseI64ArrayAttr({}); + result = rewriter.create( + loc, result, divisor.getResult(), broadcast_dims); + } + + rewriter.replaceOp(op, {result}); + return success(); + } +}; + +// Converts tf.XlaReduceWindow to mhlo.ReduceWindow +class ConvertXlaReduceWindowOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::XlaReduceWindowOp op, + PatternRewriter &rewriter) const override { + DenseElementsAttr window_dimensions, window_strides, base_dilations, + window_dilations, padding; + if (!(matchPattern(op.getWindowDimensions(), + m_Constant(&window_dimensions)) && + matchPattern(op.getWindowStrides(), m_Constant(&window_strides)) && + matchPattern(op.getBaseDilations(), m_Constant(&base_dilations)) && + matchPattern(op.getWindowDilations(), + m_Constant(&window_dilations)) && + matchPattern(op.getPadding(), m_Constant(&padding)))) + return failure(); + + Location loc = op.getLoc(); + + SmallVector result_types{op.getResult().getType()}; + // Create the mhlo.SelectAndScatter op. + auto reduce_window_op = rewriter.create( + loc, result_types, op.getInput(), op.getInitValue(), + mlir::cast(hlo::convertElementsAttr( + window_dimensions, rewriter.getIntegerType(64))), + mlir::cast(hlo::convertElementsAttr( + window_strides, rewriter.getIntegerType(64))), + mlir::cast(hlo::convertElementsAttr( + base_dilations, rewriter.getIntegerType(64))), + mlir::cast(hlo::convertElementsAttr( + window_dilations, rewriter.getIntegerType(64))), + mlir::cast( + hlo::convertElementsAttr(padding, rewriter.getIntegerType(64)))); + // Insert a call to the reducer in the region of the mhlo op. + mlir::SymbolRefAttr func = op.getComputation(); + auto func_op = cast(SymbolTable::lookupSymbolIn( + op->getParentOfType(), func)); + auto func_ty = func_op.getFunctionType(); + BuildBodyWithCall(rewriter, loc, func, func_ty, + &reduce_window_op.getBody()); + + rewriter.replaceOp(op, reduce_window_op.getResults()); + + return success(); + } +}; + +// Converts ClipByValue to XLA's clamp operation. Includes the broadcasting +// semantics for static and dynamic cases. +class ConvertClipByValueOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::ClipByValueOp op, + PatternRewriter &rewriter) const override { + Value input = op.getX(); + Value min = op.getClipValueMin(); + Value max = op.getClipValueMax(); + + auto input_ty = mlir::cast(input.getType()); + auto min_ty = mlir::cast(min.getType()); + auto max_ty = mlir::cast(max.getType()); + + if (!input_ty.hasRank() || !min_ty.hasRank() || !max_ty.hasRank()) { + return failure(); + } + + auto shape = rewriter.create( + op.getLoc(), + tensorflow::GetTypeFromTFTensorShape({input_ty.getRank()}, + rewriter.getI32Type()), + input); + + if (min_ty != input_ty) { + min = + rewriter.create(op.getLoc(), input_ty, min, shape); + } + + if (max_ty != input_ty) { + max = + rewriter.create(op.getLoc(), input_ty, max, shape); + } + + rewriter.replaceOpWithNewOp(op, input_ty, min, input, max); + return success(); + } +}; + +// Converts ConstOp to XLA's constant operation and introduces a tensor cast if +// needed. +class ConvertConstOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::ConstOp op, + PatternRewriter &rewriter) const override { + // Convert only for valid HLO tensors. + auto ty = mlir::dyn_cast(op.getType()); + if (!ty || + !mlir::isa(ty.getElementType())) + return failure(); + + Location loc = op.getLoc(); + Value result = rewriter.create(loc, op.getValue()); + if (result.getType() != op.getType()) + result = rewriter.create(loc, op.getType(), result); + rewriter.replaceOp(op, result); + return success(); + } +}; + +// Converts the Cumsum or Cumprod TensorFlow op to the HLO ReduceWindow op by +// setting appropriate window dimensions, with the given aggregation op as the +// reduction function. The input tensor needs to have a static shape, and 'axis' +// must be const. The TableGen pattern is not used for this rewrite because it +// involves regions. +template +class ConvertCumOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpT op, + PatternRewriter &rewriter) const override { + auto input = mlir::dyn_cast>(op.getX()); + if (!input) return failure(); + auto input_type = mlir::dyn_cast(input.getType()); + if (!input_type || !input_type.hasStaticShape()) { + return failure(); + } + + ArrayRef input_shape = input_type.getShape(); + int64_t rank = input_shape.size(); + + // We can only match when the axis is a constant scalar. + DenseIntElementsAttr axis_attr; + if (!matchPattern(op.getAxis(), m_Constant(&axis_attr))) { + return failure(); + } + + // Get the dimension to apply the reduction on, and offset properly if it is + // negative. + int64_t axis = (*axis_attr.begin()).getSExtValue(); + if (axis < 0) { + axis += rank; + } + + // If we're supposed to sum things up in the reverse direction, we reverse + // the input and then later reverse the output. + if (op.getReverse()) { + llvm::SmallVector dims_to_reverse({axis}); + input = rewriter.create( + op.getLoc(), input, GetI64ElementsAttr(dims_to_reverse, &rewriter)); + } + + // Convert if we need to enlarge the element type's bitwidth to avoid + // precision loss. + Type input_element_type = input_type.getElementType(); + + // TODO(hinsu): Handle complex element types. + if (!input_element_type.isIntOrFloat()) return failure(); + + Type sum_element_type = GetSumAccumulationType(input_element_type); + input = rewriter.create(op.getLoc(), input, sum_element_type); + + SmallVector window_dims(rank, 1); + SmallVector window_strides(rank, 1); + window_dims[axis] = input_shape[axis]; + + SmallVector paddings(rank * 2, 0); + paddings[axis * 2] = + std::max(input_shape[axis] - 1, static_cast(0)); + auto paddings_attr = + DenseIntElementsAttr::get(tensorflow::GetTypeFromTFTensorShape( + {rank, 2}, rewriter.getIntegerType(64)), + paddings); + + int64_t init_value = (std::is_same::value) ? 0 : 1; + Value init = GetScalarConstOfType(sum_element_type, op.getLoc(), init_value, + &rewriter); + + auto reduce = rewriter.create( + op.getLoc(), input.getType(), input, init, + GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_dims)), + GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_strides)), + /*base_dilations=*/DenseIntElementsAttr(), + /*window_dilations=*/DenseIntElementsAttr(), paddings_attr); + BuildReduceBody(sum_element_type, &reduce.getBody(), + &rewriter); + Value result = reduce.getResult(0); + + if (op.getExclusive()) { + // In "exclusive" operation, the output will start with the "init" (0) + // values. There is no way to express that as a ReduceWindowOp, so run the + // normal operation, and then use a PadOp to add the 0 "column" on the + // left and cut away the last column on the right. + llvm::SmallVector low_padding(rank, 0); + llvm::SmallVector high_padding(rank, 0); + llvm::SmallVector interior_padding(rank, 0); + low_padding[axis] = 1; + high_padding[axis] = -1; + result = rewriter.create( + op.getLoc(), result, init, GetI64ElementsAttr(low_padding, &rewriter), + GetI64ElementsAttr(high_padding, &rewriter), + GetI64ElementsAttr(interior_padding, &rewriter)); + } + + // Convert back if we enlarged the element type's bitwidth. + result = + rewriter.create(op.getLoc(), result, input_element_type); + + if (op.getReverse()) { + llvm::SmallVector dims_to_reverse({axis}); + result = rewriter.create( + op.getLoc(), result, GetI64ElementsAttr(dims_to_reverse, &rewriter)); + } + + rewriter.replaceOp(op, result); + return success(); + } +}; + +using ConvertCumsumOp = ConvertCumOp; +using ConvertCumprodOp = ConvertCumOp; + +// Converts the Tensorflow ShapeOp to a sequence of Shape dialect and Standard +// dialect lowerings. This involves extracting the shape type, extracting and +// converting each dimension to a known integer type, and repacking into a final +// tensor. +class ConvertShapeOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::ShapeOp op, + PatternRewriter &rewriter) const override { + Value input = op.getInput(); + + auto result_ty = mlir::dyn_cast(op.getResult().getType()); + if (!result_ty) { + return failure(); + } + + auto index_tensor = tensorflow::GetTypeFromTFTensorShape( + result_ty.getShape(), rewriter.getIndexType()); + auto shape_op = + rewriter.create(op.getLoc(), index_tensor, input); + rewriter.replaceOpWithNewOp(op, result_ty, shape_op); + return success(); + } +}; + +class ConvertDynamicExpandDimsOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::ExpandDimsOp op, + PatternRewriter &rewriter) const override { + auto input = op.getInput(); + auto input_ty = mlir::cast(input.getType()); + auto result_ty = mlir::cast(op.getType()); + if (!result_ty.hasRank() || !input_ty.hasRank() || + result_ty.hasStaticShape()) { + return failure(); + } + + DenseIntElementsAttr expand_dims_attr; + if (!matchPattern(op.getDim(), m_Constant(&expand_dims_attr))) { + return failure(); + } + + auto shape = rewriter.create( + op.getLoc(), + tensorflow::GetTypeFromTFTensorShape({input_ty.getRank()}, + rewriter.getIndexType()), + input); + auto expand_dims = llvm::to_vector<6>(expand_dims_attr.getValues()); + + llvm::SmallVector dims; + dims.resize(result_ty.getRank()); + + auto inserted_dim = expand_dims[0].getSExtValue(); + + // Handle the negative value use case. + if (inserted_dim < 0) { + inserted_dim += result_ty.getRank(); + // This means the value is completely incorrect, just return. + if (inserted_dim < 0) { + return failure(); + } + } + + dims[inserted_dim] = + rewriter.create(op.getLoc(), 1); + + for (int i = 0; i < dims.size() - 1; i++) { + // Add the extracted dim. + Value index = rewriter.create(op.getLoc(), i); + Value dim = rewriter.create(op.getLoc(), shape, index); + dims[i >= inserted_dim ? i + 1 : i] = dim; + } + + auto from_extents = + rewriter.create(op.getLoc(), dims); + rewriter.replaceOpWithNewOp(op, result_ty, input, + from_extents); + return success(); + } +}; + +class ConvertDynamicSqueezeOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::SqueezeOp op, + PatternRewriter &rewriter) const override { + auto input = op.getInput(); + auto input_ty = mlir::cast(input.getType()); + auto result_ty = mlir::cast(op.getType()); + if (!result_ty.hasRank() || !input_ty.hasRank() || + result_ty.hasStaticShape()) { + return failure(); + } + + // The fully dynamic case is unsupported. + if (op.getSqueezeDims().empty()) { + return failure(); + } + + SmallVector squeeze_dims; + int64_t input_rank = input_ty.getRank(); + for (const auto &squeeze_dim_apint : + op.getSqueezeDims().getAsValueRange()) { + int64_t squeeze_dim = squeeze_dim_apint.getSExtValue(); + // Handle negative inputs. + if (squeeze_dim < 0) squeeze_dim += input_rank; + assert(squeeze_dim >= 0 && squeeze_dim < input_rank && + "squeeze dim out of bounds"); + + squeeze_dims.push_back(squeeze_dim); + } + + // Collect the unsqueezed dimensions. + llvm::SmallVector dims; + for (int64_t i = 0; i != input_rank; ++i) { + if (llvm::is_contained(squeeze_dims, i)) continue; + dims.push_back(rewriter.create(op.getLoc(), input, i)); + } + + auto from_extents = + rewriter.create(op.getLoc(), dims); + rewriter.replaceOpWithNewOp(op, result_ty, input, + from_extents); + return success(); + } +}; + +// Converts tf.XlaConvV2 to mhlo.Conv +class ConvertXlaConvV2Op : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::XlaConvV2Op op, + PatternRewriter &rewriter) const override { + DenseElementsAttr window_strides_attr, padding_attr, lhs_dilation_attr, + rhs_dilation_attr, feature_group_count_attr; + if (!(matchPattern(op.getWindowStrides(), + m_Constant(&window_strides_attr)) && + matchPattern(op.getPadding(), m_Constant(&padding_attr)) && + matchPattern(op.getLhsDilation(), m_Constant(&lhs_dilation_attr)) && + matchPattern(op.getRhsDilation(), m_Constant(&rhs_dilation_attr)) && + matchPattern(op.getFeatureGroupCount(), + m_Constant(&feature_group_count_attr)))) + return failure(); + + auto window_strides_named_attr = rewriter.getNamedAttr( + "window_strides", + mlir::cast(hlo::convertElementsAttr( + window_strides_attr, rewriter.getIntegerType(64)))); + + auto padding_named_attr = rewriter.getNamedAttr( + "padding", mlir::cast(hlo::convertElementsAttr( + padding_attr, rewriter.getIntegerType(64)))); + + auto lhs_dilation_named_attr = rewriter.getNamedAttr( + "lhs_dilation", + mlir::cast(hlo::convertElementsAttr( + lhs_dilation_attr, rewriter.getIntegerType(64)))); + + auto rhs_dilation_named_attr = rewriter.getNamedAttr( + "rhs_dilation", + mlir::cast(hlo::convertElementsAttr( + rhs_dilation_attr, rewriter.getIntegerType(64)))); + + int64_t feature_group_count_val = + feature_group_count_attr.getValues()[0].getInt(); + auto feature_group_count_named_attr = rewriter.getNamedAttr( + "feature_group_count", + rewriter.getI64IntegerAttr(feature_group_count_val)); + + auto batch_group_count_named_attr = + rewriter.getNamedAttr("batch_group_count", op.getBatchGroupCountAttr()); + + xla::ConvolutionDimensionNumbers dnums; + dnums.ParseFromString(op.getDimensionNumbersAttr().getValue().str()); + auto dimension_numbers_named_attr = rewriter.getNamedAttr( + "dimension_numbers", + xla::ConvertConvDimensionNumbers(dnums, &rewriter)); + + xla::PrecisionConfig precision_config; + precision_config.ParseFromString( + op.getPrecisionConfigAttr().getValue().str()); + auto precision_config_named_attr = rewriter.getNamedAttr( + "precision_config", + xla::ConvertPrecisionConfig(&precision_config, &rewriter)); + + SmallVector operands{op.getLhs(), op.getRhs()}; + NamedAttribute attrs[] = { + window_strides_named_attr, padding_named_attr, + lhs_dilation_named_attr, rhs_dilation_named_attr, + feature_group_count_named_attr, batch_group_count_named_attr, + dimension_numbers_named_attr, precision_config_named_attr}; + rewriter.replaceOpWithNewOp(op, op.getType(), operands, + llvm::ArrayRef(attrs)); + return success(); + } +}; + +// Converts tf.XlaSelectAndScatter to mhlo.SelectAndScatter +class ConvertXlaSelectAndScatterOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::XlaSelectAndScatterOp op, + PatternRewriter &rewriter) const override { + ElementsAttr window_dimensions, window_strides, padding; + if (!(matchPattern(op.getWindowDimensions(), + m_Constant(&window_dimensions)) && + matchPattern(op.getWindowStrides(), m_Constant(&window_strides)) && + matchPattern(op.getPadding(), m_Constant(&padding)))) + return failure(); + + Location loc = op.getLoc(); + + SmallVector result_types{op.getResult().getType()}; + // Create the mhlo.SelectAndScatter op. + auto select_and_scatter_op = rewriter.create( + loc, result_types, op.getOperand(), op.getSource(), op.getInitValue(), + mlir::cast(hlo::convertElementsAttr( + window_dimensions, rewriter.getIntegerType(64))), + mlir::cast(hlo::convertElementsAttr( + window_strides, rewriter.getIntegerType(64))), + mlir::cast( + hlo::convertElementsAttr(padding, rewriter.getIntegerType(64)))); + + auto insert_call_to = [&](const mlir::SymbolRefAttr &func, Region *region) { + auto func_op = cast(SymbolTable::lookupSymbolIn( + op->getParentOfType(), func)); + auto func_ty = func_op.getFunctionType(); + BuildBodyWithCall(rewriter, loc, func, func_ty, region); + }; + + // Insert a call to the select function in the select region of the mhlo op. + insert_call_to(op.getSelect(), &select_and_scatter_op.getSelect()); + // Insert a call to the scatter function in the scatter region of the mhlo + // op. + insert_call_to(op.getScatter(), &select_and_scatter_op.getScatter()); + + rewriter.replaceOp(op, select_and_scatter_op.getResult()); + + return success(); + } +}; + +// Convert tf.XlaSort to mhlo.Sort +class ConvertXlaSortOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::XlaSortOp op, + PatternRewriter &rewriter) const override { + // Create the sort op. + Type element_type = getElementTypeOrSelf(op.getInput().getType()); + auto sort_op = + createSortOp(&rewriter, op.getLoc(), {op.getInput()}, {element_type}, + /*dimension=*/-1, /*isStable=*/false, + /*direction=*/ComparisonDirection::LT); + rewriter.replaceOp(op, sort_op.getResult(0)); + return success(); + } +}; + +inline std::optional TensorFlowRngAlgToXla( + tensorflow::Algorithm alg) { + if (alg == tensorflow::RNG_ALG_PHILOX) { + return xla::RandomAlgorithm::RNG_PHILOX; + } else if (alg == tensorflow::RNG_ALG_THREEFRY) { + return xla::RandomAlgorithm::RNG_THREE_FRY; + } else if (alg == tensorflow::RNG_ALG_AUTO_SELECT) { + return xla::RandomAlgorithm::RNG_DEFAULT; + } + return std::nullopt; +} + +// Converts tf.XlaRngBitGenerator op to mhlo.RngBitGenerator op. +class ConvertXlaRngBitGeneratorOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::XlaRngBitGeneratorOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + DenseElementsAttr algorithm; + if (!(matchPattern(op.getAlgorithm(), m_Constant(&algorithm))) || + algorithm.getType().getRank()) { + return op.emitOpError() << "algorithm must be a constant scalar"; + } + auto alg = static_cast( + algorithm.getValues()[0].getInt()); + auto xla_alg = TensorFlowRngAlgToXla(alg); + if (!xla_alg) { + return op.emitOpError() << "unknown algorithm"; + } + + auto algorithm_attr = mlir::mhlo::RngAlgorithmAttr::get( + rewriter.getContext(), + *mlir::mhlo::symbolizeRngAlgorithm(xla_alg.value())); + auto rng_bit_generator_op = rewriter.create( + loc, op.getResultTypes(), algorithm_attr, op.getInitialState()); + + rewriter.replaceOp(op, rng_bit_generator_op.getResults()); + + return success(); + } +}; + +// Converts tf.XlaVariadicReduceV2 to mhlo.Reduce +class ConvertXlaVariadicReduceV2Op + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::XlaVariadicReduceV2Op op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + + mlir::SymbolRefAttr func = op.getReducer(); + auto func_op = cast(SymbolTable::lookupSymbolIn( + op->getParentOfType(), func)); + auto func_ty = func_op.getFunctionType(); + SmallVector elementTypes{llvm::map_range( + func_ty.getResults(), + [](Type ty) { return mlir::cast(ty).getElementType(); })}; + + // Create the mhlo.reduce op. + auto reduce_op = rewriter.create( + loc, op.getInputs(), op.getInitValues(), + GetI64ElementsAttr(op.getDimensionsToReduce()), elementTypes); + + // Insert a call to the reducer in the region of the mhlo op. + BuildBodyWithCall(rewriter, loc, func, func_ty, &reduce_op.getBody()); + + rewriter.replaceOp(op, reduce_op.getResults()); + + return success(); + } +}; + +// Convert tf.XlaVariadicSort to mhlo.Sort +class ConvertXlaVariadicSortOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::XlaVariadicSortOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + ElementsAttr dimension; + matchPattern(op.getDimension(), m_Constant(&dimension)); + // Create the mhlo.sort op. + auto sort_op = rewriter.create( + loc, op.getInputs(), dimension.getValues()[0].getInt(), + op.getIsStable()); + mlir::SymbolRefAttr func = op.getComparator(); + auto func_op = cast(SymbolTable::lookupSymbolIn( + op->getParentOfType(), func)); + auto func_ty = func_op.getFunctionType(); + // Insert a call to the reducer in the region of the mhlo op. + BuildBodyWithCall(rewriter, loc, func, func_ty, &sort_op.getComparator()); + + rewriter.replaceOp(op, sort_op.getResults()); + return success(); + } +}; + +// Convert tf.XlaReducePrecision to mhlo.ReducePrecision +class ConvertXlaReducePrecisionOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::XlaReducePrecisionOp op, + PatternRewriter &rewriter) const override { + IntegerType int32_type = rewriter.getIntegerType(32); + APInt exponent_bits = op.getExponentBitsAttr().getValue(); + // Truncating to 32-bits is safe, since pasing any number above the dtype + // size (which is at most 64, for float64) is equivalent to passing the + // dtype size. + IntegerAttr new_exponent_attr = + IntegerAttr::get(int32_type, exponent_bits.truncSSat(32)); + APInt mantissa_bits = op.getMantissaBitsAttr().getValue(); + IntegerAttr new_mantissa_attr = + IntegerAttr::get(int32_type, mantissa_bits.truncSSat(32)); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getOperand(), new_exponent_attr, + new_mantissa_attr); + return success(); + } +}; + +class LowerYieldOp : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + TF::YieldOp op, TF::YieldOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } +}; + +// Returns a new tensor type from the given type with element type updated to +// the given type. +TensorType UpdateElementTypeTo(Type ty, Type element_ty) { + auto ranked_ty = mlir::dyn_cast(ty); + if (!ranked_ty) { + return UnrankedTensorType::get(element_ty); + } + return RankedTensorType::get(ranked_ty.getShape(), element_ty, + ranked_ty.getEncoding()); +} + +template +class LowerControlFlowOp : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + SrcOpT op, typename SrcOpT::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + DstOpT mhlo_op; + Location loc = op.getLoc(); + + // To handle quant type conversions, use the converted operands' element + // types and original source op's shapes and encoding to get converted op's + // result types. This is only done for the While op for now. + llvm::SmallVector element_types; + int64_t num_results = op.getNumResults(); + if constexpr (std::is_same::value) { + element_types.reserve(num_results); + for (Value value : adaptor.getOperands()) { + element_types.push_back(getElementTypeOrSelf(value.getType())); + } + } + + if constexpr (std::is_same::value) { + // Explicitly handle the Case op because it has variadic regions and takes + // the number of regions as an input along with the operands. + mhlo_op = rewriter.create(loc, op.getResultTypes(), + adaptor.getBranchIndex(), + op.getBranches().size()); + } else if constexpr (std::is_same::value) { + llvm::SmallVector while_result_types; + while_result_types.reserve(num_results); + for (int64_t idx = 0; idx < num_results; ++idx) { + auto ty = UpdateElementTypeTo(op.getType(idx), element_types[idx]); + while_result_types.push_back(ty); + } + + mhlo_op = rewriter.create(loc, TypeRange(while_result_types), + adaptor.getOperands()); + } else { + mhlo_op = rewriter.create(loc, op.getResultTypes(), + adaptor.getOperands()); + } + + int64_t num_regions = op.getNumRegions(); + for (int64_t idx = 0; idx < num_regions; ++idx) { + Region ®ion = mhlo_op.getBodyRegion(idx); + rewriter.inlineRegionBefore(op.getBodyRegion(idx), region, region.end()); + + // Update region's entry blocks argument types to handle quantized element + // types. + if constexpr (std::is_same::value) { + TypeConverter::SignatureConversion signature(num_results); + Block &block = region.front(); + for (const auto &[block_idx, original_ty] : + llvm::enumerate(block.getArgumentTypes())) { + TensorType updated_ty = + UpdateElementTypeTo(original_ty, element_types[block_idx]); + signature.addInputs(block_idx, {updated_ty}); + } + rewriter.applySignatureConversion(®ion.front(), signature); + } + } + + // Replace all uses of `op` results with the newly created op. + rewriter.replaceOp(op, mhlo_op); + return success(); + } +}; + +// Keep all these in the odml namespace to avoid collisions with the tf2xla +// version for now. +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/generated_legalize_tf.inc" + +// LINT.IfChange +void PopulatePatterns(MLIRContext *context, RewritePatternSet *patterns) { + populateWithGenerated(*patterns); + // clang-format off + patterns->add< + ConvertAllOp, + ConvertAnyOp, + ConvertArgMaxOp, + ConvertArgMinOp, + ConvertBatchMatMulV2Op, + ConvertBiasAddOp, + ConvertBroadcastToOp, + ConvertBF16FloorDivOp, + ConvertClipByValueOp, + ConvertConstOp, + ConvertConv2DOp, + ConvertConv3DOp, + ConvertDepthConv2DOp, + ConvertConv2DBackpropFilterOp, + ConvertConv3DBackpropFilterOp, + ConvertConv2DBackpropInputOp, + ConvertConv3DBackpropInputOp, + ConvertCumprodOp, + ConvertCumsumOp, + ConvertDiagPartOp, + ConvertDynamicExpandDimsOp, + ConvertDynamicSqueezeOp, + ConvertEinsumOp, + ConvertRFFTOp, + ConvertIRFFTOp, + ConvertFusedBatchNormGradOp, + ConvertFusedBatchNormGradV2Op, + ConvertFusedBatchNormGradV3Op, + ConvertFusedBatchNormV2Op, + ConvertFusedBatchNormV3Op, + ConvertInfeedDequeueTupleOp, + ConvertIdentityNOp, + ConvertInplaceUpdateOp, + ConvertLinSpaceOp, + ConvertMaxOp, + ConvertMinOp, + ConvertAvgPool2DOp, + ConvertAvgPool3DOp, + ConvertAvgPool2DGradOp, + ConvertAvgPool3DGradOp, + ConvertMaxPool2DOp, + ConvertMaxPool3DOp, + ConvertMaxPool2DGradOp, + ConvertMaxPool3DGradOp, + ConvertMeanOp, + ConvertOneHotOp, + ConvertOutfeedEnqueueTupleOp, + ConvertProdOp, + ConvertDynamicRangeOp, + ConvertMatrixDiagPartV3Op, + ConvertRangeOp, + ConvertSelectOp, + ConvertShapeOp, + ConvertSplitOp, + ConvertSplitVOp, + ConvertStridedSliceOp, + ConvertStridedSliceGradOp, + ConvertSumOp, + ConvertTensorScatterAddOp, + ConvertTensorScatterSubOp, + ConvertTensorScatterMinOp, + ConvertTensorScatterMaxOp, + ConvertTensorScatterUpdateOp, + ConvertTileOp, + ConvertTopKV2Op, + ConvertUnpackOp, + ConvertUnsortedSegmentMaxOp, + ConvertUnsortedSegmentMinOp, + ConvertUnsortedSegmentProdOp, + ConvertUnsortedSegmentSumOp, + ConvertRandomShuffleOp, + ConvertXlaShardingOp, + ConvertXlaDynamicUpdateSliceOp, + ConvertXlaConvV2Op, + ConvertXlaReducePrecisionOp, + ConvertXlaReduceScatterOp, + ConvertXlaReduceWindowOp, + ConvertXlaRngBitGeneratorOp, + ConvertXlaSelectAndScatterOp, + ConvertXlaSortOp, + ConvertXlaVariadicReduceV2Op, + ConvertXlaVariadicSortOp, + ConvertRollOp, + ConvertLeakyReluOp, + ConvertLeakyReluGradOp, + ConvertSplitOpDynamic, + ConvertSliceOpDynamic, + ConvertTileOpDynamic, + ConvertUnpackOpDynamic, + ConvertSigmoidGradOpDynamic, + ConvertConv2DDynamic, + ConvertPadOpDynamic, + ConvertGatherNdOpDynamic, + LowerControlFlowOp, + LowerControlFlowOp, + LowerControlFlowOp, + LowerYieldOp>(context); + // clang-format on +} +// LINT.ThenChange(:MlirAlwaysOps) +} // end namespace +} // end namespace mhlo + +namespace odml { +void PopulateLegalizeTfPatterns(MLIRContext *context, + RewritePatternSet *patterns) { + mlir::mhlo::PopulatePatterns(context, patterns); +} +} // end namespace odml +} // end namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h new file mode 100644 index 00000000000000..9594769e93f71c --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h @@ -0,0 +1,51 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_TF_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_TF_PASSES_H_ + +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir { + +namespace func { +class FuncOp; +} +class ModuleOp; +class Operation; +template +class OperationPass; +class Pass; + +namespace odml { + +/// Adds the TF to TF lowerings and TF to XLA rewrite patterns to the pattern +/// list. +void PopulateLegalizeTfPatterns(MLIRContext* context, + RewritePatternSet* patterns); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_TF_PASSES_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td new file mode 100644 index 00000000000000..185216448a15ed --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td @@ -0,0 +1,802 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the legalization pattern definition file for TF to XLA. + +include "mlir/IR/OpBase.td" +include "mlir/Dialect/Shape/IR/ShapeOps.td" +include "mlir/Dialect/Func/IR/FuncOps.td" +include "mlir/Dialect/Tensor/IR/TensorOps.td" +include "stablehlo/dialect/ChloOps.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" +include "mhlo/IR/hlo_ops.td" + +def SignedIntTensor : TensorOf<[I1, I8, I16, I32, I64]>; +def UnsignedIntTensor : TensorOf<[UI8, UI16, UI32, UI64]>; + +// IEEE compliant floating point tensors. +def IEEEFloatTensor : TensorOf<[F16, F32, F64]>; + +//===----------------------------------------------------------------------===// +// BatchNorm op patterns. +//===----------------------------------------------------------------------===// + +def FalseBoolAttr : AttrConstraint().getValue()">>; +def TrueBoolAttr : AttrConstraint().getValue()">>; + +def CastValueToI64: NativeCodeCall< + "CastValueToI64($0.getLoc(), $1, &$_builder)">; + +def CastValueToElementType: NativeCodeCall< + "$_builder.create($0.getLoc(), $1, " + "getElementTypeOrSelf($2.getType()))">; + +// Here, $0 is an ElementsAttr with exactly one element of type integer. $1 is +// the corresponding value of ranked tensor type whose axis is referred in $0. +def GetHLOAxisFromTFAxis : NativeCodeCall< + "GetHLOAxisFromTFAxis(" + "$0, $1.getType().cast().getRank(), &$_builder)">; + +// Same as the above but with $1 of type operand_range from variadic TensorFlow +// input. +def GetHLOAxisFromTFAxisVariadic : NativeCodeCall< + "GetHLOAxisFromTFAxis(" + "$0, (*$1.begin()).getType().cast().getRank(), " + "&$_builder)">; + +def CastElementsToI64Elements : NativeCodeCall< + "hlo::convertElementsAttr(" + "$0.cast(), $_builder.getIntegerType(64)).cast()">; + +def EmptyDotAlgorithmAttr : NativeCodeCall<"mlir::mhlo::DotAlgorithmAttr{}">; + +//===----------------------------------------------------------------------===// +// ApproximateEqual op pattern. +//===----------------------------------------------------------------------===// + +class MHLO_ComparisonDirectionValue : + ConstantAttr; + +class CHLO_ComparisonDirectionValue : + ConstantAttr; + +// TODO(b/228291745): Assert that $x and $y have the same shape. +def : Pat<(TF_ApproximateEqualOp:$result $x, $y, $tolerance), + (CHLO_BroadcastCompareOp + (MHLO_AbsOp:$abs (MHLO_SubtractOp $x, $y)), + (CastValueToElementType $result, (MHLO_ConstantOp $tolerance), $abs), + (NullDenseI64ArrayAttr), + CHLO_ComparisonDirectionValue<"LT">, + (CHLO_DEFAULT_COMPARISON_TYPE))>; + +//===----------------------------------------------------------------------===// +// Assert op pattern. +//===----------------------------------------------------------------------===// + +// HLO and XLA doesn't support Assertions. +def LowerAssert : Pattern<(TF_AssertOp $condition, $data, $summarize), []>; + +//===----------------------------------------------------------------------===// +// Binary op patterns. +//===----------------------------------------------------------------------===// + +// Check that two values can be broadcasted together +def AreBroadcastCompatible : Constraint, + "types must be broadcastable">; + +class DirectBinaryPat + : Pat<(FromOp AnyTensor:$l, AnyTensor:$r), + (ToOp $l, $r, (BinBroadcastDimensions $l, $r))>; + +foreach fromToBinPair = [[TF_AddV2Op, CHLO_BroadcastAddOp], + [TF_Atan2Op, CHLO_BroadcastAtan2Op], + [TF_ComplexOp, CHLO_BroadcastComplexOp], + [TF_DivOp, CHLO_BroadcastDivOp], + [TF_LeftShiftOp, CHLO_BroadcastShiftLeftOp], + [TF_MaximumOp, CHLO_BroadcastMaxOp], + [TF_MinimumOp, CHLO_BroadcastMinOp], + [TF_ModOp, CHLO_BroadcastRemOp], + [TF_MulOp, CHLO_BroadcastMulOp], + [TF_NextAfterOp, CHLO_BroadcastNextAfterOp], + [TF_PolygammaOp, CHLO_BroadcastPolygammaOp], + [TF_PowOp, CHLO_BroadcastPowOp], + [TF_RealDivOp, CHLO_BroadcastDivOp], + [TF_SubOp, CHLO_BroadcastSubOp], + [TF_ZetaOp, CHLO_BroadcastZetaOp]] in + def : DirectBinaryPat; + +def LowerRightShiftSigned : + Pat<(TF_RightShiftOp AnyTensor:$l, AnyTensor:$r), + (CHLO_BroadcastShiftRightArithmeticOp $l, $r, + (BinBroadcastDimensions $l, $r)), + [(SignedIntTensor $r)]>; + +def LowerRightShiftUnsigned : + Pat<(TF_RightShiftOp AnyTensor:$l, AnyTensor:$r), + (CHLO_BroadcastShiftRightLogicalOp $l, $r, + (BinBroadcastDimensions $l, $r)), + [(UnsignedIntTensor $r)]>; + +// Performs a substitution of FloorDiv, pseudo code below: +// +// return floor(div(x, y)) +def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), + (MHLO_FloorOp + (CHLO_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r))), + [(IEEEFloatTensor $l)]>; + +// Performs a substitution of FloorDiv for integer tensors, which required +// additional correction for a negative numerator / denominator. Equivalent +// pseudocode is shown below: +// +// T z = x / y +// return (z * y != x && (x < 0) != (y < 0)) ? z - 1 : z +// +// BroadcastToDimensions is used to compute the broadcast attr to higher +// dimensions. This computes the broadcast of 'l' to broadcast('l', 'r') +// without returning the broadcast of 'r' to broadcast('l', 'r'). +def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), + (MHLO_SelectOp + (CHLO_BroadcastAndOp + (CHLO_BroadcastCompareOp + (CHLO_BroadcastMulOp:$mul + (CHLO_BroadcastDivOp:$div $l, $r, + (BinBroadcastDimensions $l, $r)), + $r, (BinBroadcastDimensions $div, $r)), + $l, (BinBroadcastDimensions $mul, $l), CHLO_ComparisonDirectionValue<"NE">, + (CHLO_DEFAULT_COMPARISON_TYPE)), + (CHLO_BroadcastCompareOp + (CHLO_BroadcastCompareOp:$l_cmp $l, + (MHLO_ConstantOp:$l_zeros (GetScalarOfType<0> $l)), + (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, + (CHLO_DEFAULT_COMPARISON_TYPE)), + (CHLO_BroadcastCompareOp:$r_cmp $r, + (MHLO_ConstantOp:$r_zeros (GetScalarOfType<0> $r)), + (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, + (CHLO_DEFAULT_COMPARISON_TYPE)), + (BinBroadcastDimensions $l_cmp, $r_cmp), CHLO_ComparisonDirectionValue<"NE">, + (CHLO_DEFAULT_COMPARISON_TYPE)), + (NullDenseI64ArrayAttr)), + (CHLO_BroadcastSubOp $div, + (MHLO_ConstantOp:$ones (GetScalarOfType<1> $div)), + (NullDenseI64ArrayAttr)), $div), + [(SignedIntTensor $l)]>; + +// FloorDiv of unsigned is just div. +def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), + (CHLO_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r)), + [(UnsignedIntTensor $l)]>; + +// Performs a substitution of FloorMod designed to correct for possibly negative +// values. Pseudocode shown below: +// +// T trunc_mod = std::fmod(x, y); +// return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? trunc_mod + y +// : trunc_mod +def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r), + (MHLO_SelectOp + (CHLO_BroadcastAndOp + (CHLO_BroadcastCompareOp + (CHLO_BroadcastRemOp:$rem $l, $r, (BinBroadcastDimensions $l, $r)), + (MHLO_ConstantOp:$l_zeros (GetScalarOfType<0> $l)), + (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"NE">, + (CHLO_DEFAULT_COMPARISON_TYPE)), + (CHLO_BroadcastCompareOp + (CHLO_BroadcastCompareOp:$r_cmp $r, + (MHLO_ConstantOp:$r_zeros (GetScalarOfType<0> $r)), + (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, + (CHLO_DEFAULT_COMPARISON_TYPE)), + (CHLO_BroadcastCompareOp:$rem_cmp $rem, $r_zeros, + (BinBroadcastDimensions $rem, $r_zeros), CHLO_ComparisonDirectionValue<"LT">, + (CHLO_DEFAULT_COMPARISON_TYPE)), + (BinBroadcastDimensions $r_cmp, $rem_cmp), CHLO_ComparisonDirectionValue<"NE">, + (CHLO_DEFAULT_COMPARISON_TYPE)), + (NullDenseI64ArrayAttr)), + (CHLO_BroadcastAddOp $r, + $rem, (BinBroadcastDimensions $r, $rem)), $rem), + [(TensorOf<[I8, I16, I32, I64, F16, F32, F64]> $l)]>; + +// FloorMod of unsigned is just mod. +def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r), + (CHLO_BroadcastRemOp $l, $r, (BinBroadcastDimensions $l, $r)), + [(UnsignedIntTensor $l)]>; + +def Get2DTransposePerm: NativeCodeCall< + "Get2DTransposePerm($0, &$_builder)">; + +def : Pat<(TF_RiscAddOp $l, $r), (MHLO_AddOp $l, $r)>; + +def : Pat<(TF_RiscDotOp $a, $b, $transpose_a, $transpose_b), + (MHLO_DotOp + (TF_TransposeOp $a, (TF_ConstOp (Get2DTransposePerm $transpose_a))), + (TF_TransposeOp $b, (TF_ConstOp (Get2DTransposePerm $transpose_b))), + /*precision_config=*/(NullArrayAttr))>; + +//===----------------------------------------------------------------------===// +// Logical & bitwise binary op patterns. +//===----------------------------------------------------------------------===// + +class DirectLogicalBinaryPat + : Pat<(FromOp AnyTensor:$l, AnyTensor:$r), + (ToOp $l, $r, (BinBroadcastDimensions $l, $r)), + [(AnyTypeOf<[SignedIntTensor, UnsignedIntTensor]> $l)]>; + +foreach fromToBinPair = [[TF_LogicalAndOp, CHLO_BroadcastAndOp], + [TF_LogicalOrOp, CHLO_BroadcastOrOp], + [TF_BitwiseAndOp, CHLO_BroadcastAndOp], + [TF_BitwiseOrOp, CHLO_BroadcastOrOp], + [TF_BitwiseXorOp, CHLO_BroadcastXorOp]] in + def : DirectLogicalBinaryPat; + +//===----------------------------------------------------------------------===// +// Compare op patterns. +//===----------------------------------------------------------------------===// + +class DirectComparePat + : Pat<(FromOp AnyTensor:$l, AnyTensor:$r), + (CHLO_BroadcastCompareOp + $l, $r, (BinBroadcastDimensions $l, $r), direction, + (CHLO_DEFAULT_COMPARISON_TYPE))>; + +def : DirectComparePat>; +def : DirectComparePat>; +def : DirectComparePat>; +def : DirectComparePat>; + +class EqualityPat + : Pat<(FromOp AnyTensor:$l, AnyTensor:$r, + TrueBoolAttr:$incompatible_shape_error), + (CHLO_BroadcastCompareOp + $l, $r, (BinBroadcastDimensions $l, $r), direction, + (CHLO_DEFAULT_COMPARISON_TYPE)), + [(MHLO_Tensor $l)]>; + +def : EqualityPat>; +def : EqualityPat>; + +//===----------------------------------------------------------------------===// +// Concat op patterns. +//===----------------------------------------------------------------------===// + +def OneElementAttrPred + : CPred<"$_self.cast().getShapedType().getNumElements() == 1">; + +def OneElementAttr + : ElementsAttrBase, + "Scalar ElementsAttr">; + +def HasRankedFirstOperand + : Constraint()">>; + +def IsShapedTensor + : Constraint()">>; + +// This pattern converts TensorFlow axis format to HLO axis format which +// doesn't wrap around like TensorFlow and is always positive. For this +// conversion, use the first input to get inputs rank. Other inputs need not be +// ranked. +// Defining op for `axis` is TensorFlow constant op in the pattern as during +// the conversion, original Concat op operands still refers to the old ops even +// if HLO constant op is introduced as an replacement for the TensorFlow +// Constant op. +def : Pat<(TF_ConcatV2Op $inputs, (ConstantLikeMatcher OneElementAttr:$axis)), + (MHLO_ConcatenateOp $inputs, + (GetHLOAxisFromTFAxisVariadic $axis, $inputs)), + [(HasRankedFirstOperand $inputs)]>; + +//===----------------------------------------------------------------------===// +// CollectivePermute op patterns. +//===----------------------------------------------------------------------===// + +def : Pat<(TF_CollectivePermuteOp $input, (ConstantLikeMatcher ElementsAttr:$source_target_pairs)), + (MHLO_CollectivePermuteOp $input, + (CastElementsToI64Elements $source_target_pairs), + (NullChannelHandleAttr))>; + +//===----------------------------------------------------------------------===// +// CrossReplicaSum op patterns. +//===----------------------------------------------------------------------===// + +def : Pat<(TF_CrossReplicaSumOp $input, (ConstantLikeMatcher ElementsAttr:$group_assignment)), + (MHLO_CrossReplicaSumOp $input, + (CastElementsToI64Elements $group_assignment))>; + +//===----------------------------------------------------------------------===// +// All2All op patterns. +//===----------------------------------------------------------------------===// + +def ValueToVariadic: NativeCodeCall<"SmallVector{$0}">; +def : Pat<(TF_AllToAllOp AnyRankedTensor:$input, (ConstantLikeMatcher ElementsAttr:$group_assignment), I64Attr:$concat_dimension, $split_dimension, $split_count), + (MHLO_AllToAllOp (ValueToVariadic $input), $split_dimension, $concat_dimension, $split_count, (CastElementsToI64Elements $group_assignment), (NullChannelHandleAttr))>; + +//===----------------------------------------------------------------------===// +// FFT op patterns. +//===----------------------------------------------------------------------===// + +class MHLO_FftTypeValue : + ConstantAttr; + +def GetInnerDimFromValue : NativeCodeCall< + "GetInnerDimFromValue($0.getType().cast(), &$_builder)">; + +def CheckInnerDimStatic + : Constraint(), &$_builder)">>; + +def : Pat<(TF_FFTOp:$res $input), + (MHLO_FftOp $input, MHLO_FftTypeValue<"FFT">, (GetInnerDimFromValue $res)), + [(CheckInnerDimStatic $input)]>; + +def : Pat<(TF_IFFTOp:$res $input), + (MHLO_FftOp $input, MHLO_FftTypeValue<"IFFT">, (GetInnerDimFromValue $res)), + [(CheckInnerDimStatic $input)]>; + +//===----------------------------------------------------------------------===// +// GatherV2 op patterns. +//===----------------------------------------------------------------------===// + +// Here, $params and $indices needs to be ranked so that $axis and $batch_dims +// attributes can be converted from TensorFlow axis format supporting negative +// indexing to the HLO format. +def LegalizeGatherV2 : + Pat<(TF_GatherV2Op AnyRankedTensor:$params, AnyRankedTensor:$indices, + (ConstantLikeMatcher ElementsAttr:$axis), $batch_dims), + (MHLO_TorchIndexSelectOp $params, $indices, + (GetHLOAxisFromTFAxis $axis, $params), + (GetHLOAxisFromTFAxis $batch_dims, $indices))>; + +//===----------------------------------------------------------------------===// +// Pad op patterns. +//===----------------------------------------------------------------------===// + +class SliceDenseIntElementsAttrColumn2D : NativeCodeCall< + "SliceDenseIntElementsAttrColumn2D($0.cast(), " # column # " )">; + +class SliceDenseIntElementsAttr : NativeCodeCall< + "SliceDenseIntElementsAttr($0.cast(), " # index # ", " # axis # ")">; + +// Interior padding attribute based on the TF padding. +def GetInteriorPadding : NativeCodeCall < + "GetInteriorPadding($0.cast())">; + +def : Pat<(TF_PadV2Op $input, (ConstantLikeMatcher ElementsAttr:$padding), $c), + (MHLO_PadOp $input, $c, + (SliceDenseIntElementsAttrColumn2D<"0"> $padding), + (SliceDenseIntElementsAttrColumn2D<"1"> $padding), + (GetInteriorPadding $padding))>; + +//===----------------------------------------------------------------------===// +// Identity op patterns. +//===----------------------------------------------------------------------===// + +foreach src = [TF_IdentityOp, TF_StopGradientOp, TF__EagerConstOp] in + def : Pat<(src $op), (replaceWithValue $op)>; + +// TODO(b/32223192): Support CheckNumerics in HLO. +foreach src = [TF_PreventGradientOp, TF_CheckNumericsOp] in + def : Pat<(src $op, $msg), (replaceWithValue $op)>; + +//===----------------------------------------------------------------------===// +// MatMul op patterns. +//===----------------------------------------------------------------------===// + +def GetPrecisionConfig: NativeCodeCall< + "GetPrecisionConfig(&$_builder)">; + +def : Pat<(TF_MatMulOp $a, $b, $transpose_a, $transpose_b, $grad_a, $grad_b), + (MHLO_DotOp + (TF_TransposeOp $a, (TF_ConstOp (Get2DTransposePerm $transpose_a))), + (TF_TransposeOp $b, (TF_ConstOp (Get2DTransposePerm $transpose_b))), + /*precision_config=*/(GetPrecisionConfig))>; + +//===----------------------------------------------------------------------===// +// Lower `tf.ZerosLike` +//===----------------------------------------------------------------------===// + +def : Pat<(TF_ZerosLikeOp AnyTensor:$arg), + (MHLO_ConstantLike<"0"> $arg)>; + +//===----------------------------------------------------------------------===// +// Lower `tf.OnesLike` +//===----------------------------------------------------------------------===// + +def : Pat<(TF_OnesLikeOp AnyTensor:$arg), + (MHLO_ConstantLike<"1"> $arg)>; + +//===----------------------------------------------------------------------===// +// Elu op patterns. +//===----------------------------------------------------------------------===// + +def : Pat<(TF_EluOp AnyTensor:$features), + (MHLO_SelectOp + (MHLO_CompareOp + $features, + (MHLO_ConstantLike<"0">:$zero $features), + MHLO_ComparisonDirectionValue<"GT">, (MHLO_DEFAULT_COMPARISON_TYPE)), + $features, + (MHLO_Expm1Op $features))>; + +def : Pat<(TF_EluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$features), + (MHLO_SelectOp + (CHLO_BroadcastCompareOp + $features, + (MHLO_ConstantOp:$zero (GetScalarOfType<0> $features)), + (BinBroadcastDimensions $zero, $features), + CHLO_ComparisonDirectionValue<"GT">, (CHLO_DEFAULT_COMPARISON_TYPE)), + $gradients, + (MHLO_MulOp + $gradients, + (CHLO_BroadcastAddOp + $features, + (MHLO_ConstantOp:$one (GetScalarOfType<1> $features)), + (BinBroadcastDimensions $one, $features))))>; + +//===----------------------------------------------------------------------===// +// Relu op patterns. +//===----------------------------------------------------------------------===// + +// TODO(hinsu): Make these patterns to TF to TF lowering. Relu6 lowering will +// require HLO canonicalization of min and max on a tensor to ClampOp. + +// TODO(hinsu): Lower quantized types after supporting them in GetScalarOfType. +def : Pat<(TF_ReluOp AnyTensor:$input), + (CHLO_BroadcastMaxOp + (MHLO_ConstantOp:$zero (GetScalarOfType<0> $input)), $input, + (BinBroadcastDimensions $zero, $input)), + [(TF_IntOrFpTensor $input)]>; + +// TODO(hinsu): Lower quantized types after supporting them in GetScalarOfType. +def : Pat<(TF_Relu6Op AnyRankedTensor:$input), + (MHLO_ClampOp (MHLO_ConstantOp (GetScalarOfType<0> $input)), $input, + (MHLO_ConstantOp (GetScalarOfType<6> $input))), + [(TF_IntOrFpTensor $input)]>; + +// ReluGrad(gradients, features) = gradients * (features > 0) +// The condition that $gradients and $features need to have the same shape is +// implicitly enforced: $zero is created to have the same shape as $features, +// MHLO_SelectOp enforces that $gradients and $zero have the same shape. +def : Pat<(TF_ReluGradOp AnyTensor:$gradients, AnyTensor:$features), + (MHLO_SelectOp + (MHLO_CompareOp $features, (MHLO_ConstantLike<"0">:$zero $features), + MHLO_ComparisonDirectionValue<"GT">, (MHLO_DEFAULT_COMPARISON_TYPE)), + $gradients, $zero)>; + +//===----------------------------------------------------------------------===// +// Softsign op patterns. +//===----------------------------------------------------------------------===// + +/// Converts a TF::SoftsignOp to HLO. +/// Softsign(features) = features / (1 + abs(features)) +def : Pat<(TF_SoftsignOp AnyTensor:$input), + (MHLO_DivOp + $input, + (MHLO_AddOp (MHLO_ConstantLike<"1"> $input), (MHLO_AbsOp $input)) + ) + >; + +/// Converts a TF::SoftsignGradOp to HLO. +/// SoftsignGrad(gradient, features) = gradient / ((1 + abs(features)) ^ 2) +def : Pattern< + (TF_SoftsignGradOp AnyRankedTensor:$gradients, AnyRankedTensor:$features), + [(CHLO_BroadcastAddOp:$add + (MHLO_ConstantOp:$one (GetScalarOfType<1> $features)), (MHLO_AbsOp $features), + (BinBroadcastDimensions $one, $features) + ), + (CHLO_BroadcastDivOp + $gradients, + (MHLO_MulOp $add, $add), + (BinBroadcastDimensions $gradients, $add) + ) + ]>; + +//===----------------------------------------------------------------------===// +// Slice op patterns. +//===----------------------------------------------------------------------===// + +def UnpackStartingIndices: NativeCodeCall< + "UnpackTensorAlongZeroDim($0.getLoc(), $1, &$_builder).getOutput()">; + +def CanBeTranslatedToDynamicSlice : Constraint())">>; + +def TFSliceSizes2HLOSliceSizes : NativeCodeCall< + "TFSliceSizes2HLOSliceSizes($0, $1, $2.cast()," + "&$_builder)">; + +def : Pat<(TF_SliceOp:$op MHLO_Tensor:$input, MHLO_Tensor:$starting_indices, + (ConstantLikeMatcher AnyAttr:$slice_sizes)), + (MHLO_DynamicSliceOp $input, + (UnpackStartingIndices $op, $starting_indices), + (TFSliceSizes2HLOSliceSizes $input, $starting_indices, $slice_sizes)), + [(CanBeTranslatedToDynamicSlice $input, $starting_indices, + $slice_sizes)]>; + +//===----------------------------------------------------------------------===// +// Select op patterns. +//===----------------------------------------------------------------------===// + + def : Pat<(TF_SelectV2Op MHLO_Tensor:$pred, MHLO_Tensor:$on_true, + MHLO_Tensor:$on_false), + (CHLO_BroadcastSelectOp $pred, $on_true, $on_false)>; + +//===----------------------------------------------------------------------===// +// PartitionedCall and LegacyCall op patterns. +//===----------------------------------------------------------------------===// + +def ArgTypesMatchCallee : Constraint< + // $0 is a resultset (possibly empty), and $_op isn't assigned. So retrieve + // the op using the builder. + CPred<"ArgTypesMatchCallee(&*$_builder.getInsertionPoint(), $1, $2)">>; + +foreach callOp = [TF_PartitionedCallOp, TF_StatefulPartitionedCallOp] in { + def : Pat<(callOp:$op $args, FlatSymbolRefAttr:$f, + $config, $config_proto, $executor_type), + (CallOp $f, $args), + [(ArgTypesMatchCallee $op, $args, $f)]>; +} + +// The extra attr on this op is _disable_call_shape_inference, which we ignore +// in the bridge. +def : Pat<(TF_LegacyCallOp:$op $args, FlatSymbolRefAttr:$f, $attr), + (CallOp $f, $args), + [(ArgTypesMatchCallee $op, $args, $f)]>; + +//===----------------------------------------------------------------------===// +// Reverse op patterns. +//===----------------------------------------------------------------------===// + +// Handles axis conversion for TF reverse. +def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, $1.cast(), &$_builder)">; + +def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (ConstantLikeMatcher ElementsAttr:$axis)), + (MHLO_ReverseOp $values, (ConvertAxisAttr $values, $axis))>; + +//===----------------------------------------------------------------------===// +// Unary op patterns. +//===----------------------------------------------------------------------===// + +foreach Mapping = [ + [TF_AbsOp, MHLO_AbsOp], + [TF_CeilOp, MHLO_CeilOp], + [TF_ComplexAbsOp, MHLO_AbsOp], + [TF_CosOp, MHLO_CosineOp], + [TF_ExpOp, MHLO_ExpOp], + [TF_Expm1Op, MHLO_Expm1Op], + [TF_ErfOp, MHLO_ErfOp], + [TF_FloorOp, MHLO_FloorOp], + [TF_ImagOp, MHLO_ImagOp], + [TF_InvertOp, MHLO_NotOp], + [TF_IsFiniteOp, MHLO_IsFiniteOp], + [TF_LogOp, MHLO_LogOp], + [TF_Log1pOp, MHLO_Log1pOp], + [TF_LogicalNotOp, MHLO_NotOp], + [TF_NegOp, MHLO_NegOp], + [TF_RealOp, MHLO_RealOp], + [TF_RsqrtOp, MHLO_RsqrtOp], + [TF_SigmoidOp, MHLO_LogisticOp], + [TF_SinOp, MHLO_SineOp], + [TF_SqrtOp, MHLO_SqrtOp], + [TF_TanhOp, MHLO_TanhOp], + [TF_TanOp, MHLO_TanOp] + ] in { + def : Pat<(Mapping[0] MHLO_Tensor:$input), + (Mapping[1] $input)>; +} + +foreach Mapping = [ + [TF_AcosOp, CHLO_AcosOp], + [TF_AcoshOp, CHLO_AcoshOp], + [TF_AsinOp, CHLO_AsinOp], + [TF_AsinhOp, CHLO_AsinhOp], + [TF_AtanOp, CHLO_AtanOp], + [TF_AtanhOp, CHLO_AtanhOp], + [TF_CoshOp, CHLO_CoshOp], + [TF_ConjOp, CHLO_ConjOp], + [TF_DigammaOp, CHLO_DigammaOp], + [TF_ErfcOp, CHLO_ErfcOp], + [TF_IsInfOp, CHLO_IsInfOp], + [TF_LgammaOp, CHLO_LgammaOp], + [TF_SinhOp, CHLO_SinhOp], + ] in { + def : Pat<(Mapping[0] MHLO_AnyTensor:$input), + (Mapping[1] $input)>; +} + +def : Pat<(TF_AngleOp $x), (MHLO_Atan2Op (MHLO_ImagOp $x), (MHLO_RealOp $x))>; + +// TODO(bixia): Lower with Truncate=True for floating point value conversions. +def : Pat<(TF_CastOp $arg, ConstBoolAttrFalse), (MHLO_ConvertOp $arg)>; + +def : Pat<(TF_TransposeOp:$res $arg, (ConstantLikeMatcher ElementsAttr:$permutation)), + (MHLO_TransposeOp $arg, (CastElementsToI64Elements $permutation))>; + + +// Lowering these ops with static shape to mhlo.reshape +foreach TfOp = [TF_ExpandDimsOp, TF_ReshapeOp, TF_SqueezeOp, ] in { + def : Pat<(TfOp:$res MHLO_Tensor:$arg, $ignored), + (MHLO_ReshapeOp $arg), [(AnyStaticShapeTensor $res)], [], + (addBenefit 2)>; +} + +// Returns NaN if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0. +def : Pat<(TF_SignOp $x), (MHLO_SignOp $x)>; + +def BothElementTypesSameWidthIntOrFloat : Constraint, + "element types must be integers or floats">; + +// TODO(mgester): Due to restrictions of xla::BitcastConvertType we currently +// only lower if both input and output types are int or float and have same width + +def : Pat<(TF_BitcastOp:$res MHLO_Tensor:$arg), + (MHLO_BitcastConvertOp $arg), + [(BothElementTypesSameWidthIntOrFloat $res, $arg)]>; + +// TODO(jpienaar): Lower constant like to constant to broadcast if dynamic +// and going to MHLO. + +//===----------------------------------------------------------------------===// +// Random ops. +//===----------------------------------------------------------------------===// +// TODO(b/148269299): handle random number generator seeds/states correctly. + +class MHLO_RngDistributionValue : + ConstantAttr; + +def : Pat<(TF_RandomUniformOp:$old $shape, $seed, $seed2), + (MHLO_RngOp + (MHLO_ConstantOp + (NativeCodeCall<"$_builder.getFloatAttr(old.getDtype(), 0.0)">)), + (MHLO_ConstantOp + (NativeCodeCall<"$_builder.getFloatAttr(old.getDtype(), 1.0)">)), + (CastValueToI64 $old, $shape), + MHLO_RngDistributionValue<"UNIFORM">), + [(IsShapedTensor $shape)]>; + +def : Pat<(TF_RandomStandardNormalOp:$old $shape, $seed, $seed2), + (MHLO_RngOp + (MHLO_ConstantOp + (NativeCodeCall<"$_builder.getFloatAttr(old.getDtype(), 0.0)">)), + (MHLO_ConstantOp + (NativeCodeCall<"$_builder.getFloatAttr(old.getDtype(), 1.0)">)), + (CastValueToI64 $old, $shape), + MHLO_RngDistributionValue<"NORMAL">), + [(IsShapedTensor $shape)]>; + +//===----------------------------------------------------------------------===// +// Sigmoid grad op. +//===----------------------------------------------------------------------===// + +// TODO(hinsu): Handle unranked inputs by broadcasting constant one to the +// shape of $l instead of having it as a constant. +def : Pat<(TF_SigmoidGradOp AnyRankedTensor:$l, AnyRankedTensor:$r), + (MHLO_MulOp + (MHLO_MulOp $r, $l), + (MHLO_SubtractOp (MHLO_ConstantOp (ConstantSplat<"1"> $l)), $l))>; + +//===----------------------------------------------------------------------===// +// Softplus op. +//===----------------------------------------------------------------------===// + +def EpsilonValue : NativeCodeCall<"GetEpsilonValue($0.getType())">; + +def : Pattern<(TF_SoftplusOp AnyTensor:$features), + [ + (MHLO_ExpOp:$features_exp $features), + (CHLO_BroadcastAddOp:$threshold + (MHLO_LogOp (MHLO_ConstantOp (EpsilonValue $features))), + (MHLO_ConstantOp (GetScalarOfType<2> $features)), + (NullDenseI64ArrayAttr) + ), + (MHLO_SelectOp:$output + (CHLO_BroadcastCompareOp + $features, + (MHLO_NegOp $threshold), + (NullDenseI64ArrayAttr), + CHLO_ComparisonDirectionValue<"GT">, + (CHLO_DEFAULT_COMPARISON_TYPE) + ), + $features, + (MHLO_SelectOp + (CHLO_BroadcastCompareOp + $features, + $threshold, + (NullDenseI64ArrayAttr), + CHLO_ComparisonDirectionValue<"LT">, + (CHLO_DEFAULT_COMPARISON_TYPE) + ), + $features_exp, + (MHLO_Log1pOp $features_exp) + ) + ), + (replaceWithValue $output) + ]>; + +//===----------------------------------------------------------------------===// +// XlaReplicaId op. +//===----------------------------------------------------------------------===// + +def : Pat<(TF_XlaReplicaIdOp), + (TF_CastOp (MHLO_ReplicaIdOp), /*truncate=*/ConstBoolAttrFalse)>; + +//===----------------------------------------------------------------------===// +// XlaGather op. +//===----------------------------------------------------------------------===// + +def ToGatherDimNumsAttr : NativeCodeCall<"GetGatherDimNumsAttr($0, &$_builder)">; + +def HasValidGatherDims : Constraint>; + +def : Pat<(TF_XlaGatherOp $operand, $start_indices, (ConstantLikeMatcher ElementsAttr:$slice_sizes), + $dimension_numbers, $indices_are_sorted), + (MHLO_GatherOp $operand, $start_indices, + (ToGatherDimNumsAttr $dimension_numbers), + (CastElementsToI64Elements $slice_sizes), + $indices_are_sorted), + [(HasValidGatherDims $dimension_numbers)]>; + +//===----------------------------------------------------------------------===// +// XlaDotOp op. +//===----------------------------------------------------------------------===// + +def ToDotDimNumsAttr : NativeCodeCall<"GetDotDimNumsAttr($0, &$_builder)">; + +def ToPrecisionConfigsAttr : NativeCodeCall<"GetPrecisionConfigAttr($0, &$_builder)">; + +def HasValidDotDims : Constraint>; + +def HasValidPrecisionConfig : Constraint>; + +def : Pat<(TF_XlaDotOp $lhs, $rhs, $dimension_numbers, $precision_config), + (MHLO_DotGeneralOp $lhs, $rhs, + (ToDotDimNumsAttr $dimension_numbers), + (ToPrecisionConfigsAttr $precision_config), + (EmptyDotAlgorithmAttr)), + [(HasValidDotDims $dimension_numbers), (HasValidPrecisionConfig $precision_config)]>; + +//===----------------------------------------------------------------------===// +// XlaDotV2Op op. +//===----------------------------------------------------------------------===// + +def : Pat<(TF_XlaDotV2Op $lhs, $rhs, $dimension_numbers, $precision_config), + (MHLO_DotGeneralOp $lhs, $rhs, + (ToDotDimNumsAttr $dimension_numbers), + (ToPrecisionConfigsAttr $precision_config), + (EmptyDotAlgorithmAttr)), + [(HasValidDotDims $dimension_numbers), (HasValidPrecisionConfig $precision_config)]>; + +//===----------------------------------------------------------------------===// +// XlaDynamicSlice op. +//===----------------------------------------------------------------------===// + +def : Pat<(TF_XlaDynamicSliceOp:$op MHLO_Tensor:$input, MHLO_Tensor:$starting_indices, + (ConstantLikeMatcher AnyAttr:$slice_sizes)), + (MHLO_DynamicSliceOp $input, + (UnpackStartingIndices $op, $starting_indices), + (TFSliceSizes2HLOSliceSizes $input, $starting_indices, $slice_sizes))>; + +//===----------------------------------------------------------------------===// +// XlaEisumOp op. +//===----------------------------------------------------------------------===// + +def : Pat<(TF_XlaEinsumOp $lhs, $rhs, $equation), + (MHLO_EinsumOp $lhs, $rhs, $equation)>; + +//===----------------------------------------------------------------------===// +// XlaOptimizationBarrierOp op. +//===----------------------------------------------------------------------===// + +def : Pat<(TF_XlaOptimizationBarrierOp $args), + (MHLO_OptimizationBarrierOp $args)>; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc index 5f04704d54ef78..5a63a339e460b9 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc @@ -24,7 +24,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -244,7 +244,7 @@ class TFXlaCallModuleOpToStablehloPass } void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); + quant::QuantDialect, shape::ShapeDialect>(); } void runOnOperation() override { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc index 82c7a4b4687055..f52ca0a40553c5 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc @@ -34,7 +34,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc index 062222f72b3b9a..70f62c3e0b582e 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc @@ -29,7 +29,7 @@ limitations under the License. #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize_layout.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize_layout.cc index 11cb7254b75c76..abf14a6cb2f6e5 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize_layout.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize_layout.cc @@ -43,7 +43,7 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" namespace mlir { namespace odml { @@ -52,7 +52,7 @@ namespace { #define DEBUG_TYPE "stablehlo-optimize-layout" #define GEN_PASS_DEF_TRANSPOSECOMMUTEOPSPASS -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h.inc" class TransposeCommuteOpsPass : public impl::TransposeCommuteOpsPassBase { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.cc index 250d0656031471..c3544edd9d6cb7 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.cc @@ -39,7 +39,7 @@ namespace odml { namespace { #define GEN_PASS_DEF_PREPAREHLOPASS -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h.inc" class PrepareHloPass : public impl::PrepareHloPassBase { public: diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h similarity index 90% rename from tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h rename to tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h index 331505e2445e87..4a95b2530488f3 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_PASSES_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_PASSES_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_STABLEHLO_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_STABLEHLO_PASSES_H_ #include @@ -80,9 +80,9 @@ void PopulateLegalizeHloToTfPatterns(RewritePatternSet* patterns, #define GEN_PASS_DECL #define GEN_PASS_REGISTRATION -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h.inc" } // namespace odml } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_PASSES_H_ +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_STABLEHLO_PASSES_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.td similarity index 99% rename from tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td rename to tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.td index 2cb67eb8c72044..c1490ce575d8ca 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.td @@ -54,7 +54,7 @@ def ComposeUniformQuantizedTypePass : Pass<"compose-uniform-quantized-type", "Mo }]; let dependentDialects = [ "stablehlo::StablehloDialect", - "quant::QuantizationDialect", + "quant::QuantDialect", ]; } @@ -68,7 +68,7 @@ def UniformQuantizedStableHloToTflPass }]; let dependentDialects = [ "stablehlo::StablehloDialect", - "quant::QuantizationDialect", + "quant::QuantDialect", "mlir::TFL::TFLDialect", ]; } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc index ae4ee26eab9b8c..e38cad1d4c7edc 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc @@ -33,11 +33,11 @@ limitations under the License. #include "mlir/Transforms/Passes.h" // from @llvm-project #include "stablehlo/dialect/ChloOps.h" // from @stablehlo #include "stablehlo/dialect/Register.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla_passes.h" -#include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" @@ -102,7 +102,7 @@ void TFToMhloPass::runOnOperation() { MLIRContext *context = func->getContext(); RewritePatternSet patterns(context); - mhlo::PopulateLegalizeTfPatterns(context, &patterns); + odml::PopulateLegalizeTfPatterns(context, &patterns); TF::PopulateTFLoweringBeforeHLOPatterns(context, &patterns); mhlo::Tf2XlaTypeConverter converter; mhlo::PopulateLegalizeTfWithTf2XlaPatterns( diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_legalize_chlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_legalize_chlo.cc index 775e773b2f408f..471422be2bc2f1 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_legalize_chlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_legalize_chlo.cc @@ -33,7 +33,7 @@ namespace odml { namespace { #define GEN_PASS_DEF_LEGALIZECHLOTOTFLPASS -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h.inc" class LegalizeChloToTflPass : public impl::LegalizeChloToTflPassBase { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc index d5c878c5d9a21c..bb92fbdbe2a7c7 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project @@ -40,6 +41,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/fft.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/gather.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/get_dimension_size.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/if.h" @@ -51,7 +53,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/while.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" // IWYU pragma: keep #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" @@ -260,7 +262,7 @@ bool ValueGreaterThanZero(ElementsAttr float_or_int) { } #define GEN_PASS_DEF_LEGALIZEHLOTOTFLITEPASS -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h.inc" bool SupportedComparisonType(mhlo::ComparisonTypeAttr comp_type) { if (!comp_type) return true; @@ -291,6 +293,14 @@ bool IsCompareLegal(mhlo::CompareOp op) { return !SupportedComparisonType(op.getCompareTypeAttr()); } +bool IsAbsOpLegal(mhlo::AbsOp op) { + return !llvm::cast(op.getOperand().getType()) + .getElementType() + .isIntOrFloat() && + !llvm::cast( + op.getOperand().getType().getElementType()); +} + void SetUnaryOpLegal(ConversionTarget& target) { auto is_legal = [](Operation* op) { return !llvm::cast(op->getOperand(0).getType()) @@ -300,7 +310,6 @@ void SetUnaryOpLegal(ConversionTarget& target) { target.addDynamicallyLegalOp< // go/keep-sorted start // clang-format off - mhlo::AbsOp, mhlo::BitcastConvertOp, mhlo::CeilOp, mhlo::ConvertOp, @@ -377,6 +386,7 @@ void LegalizeHloToTfLitePass::runOnOperation() { target.addLegalOp(); target.addDynamicallyLegalOp(IsCbrtLegal); + target.addDynamicallyLegalOp(IsAbsOpLegal); target.addDynamicallyLegalOp(IsNotOpLegal); target.addDynamicallyLegalOp(IsCompareLegal); target.addDynamicallyLegalOp( @@ -432,6 +442,7 @@ void LegalizeHloToTfLitePass::runOnOperation() { PopulateWhilePatterns(context, patterns, target); PopulateGetDimensionSizePatterns(context, patterns, target); PopulateIfPatterns(context, patterns, target); + PopulateFftPatterns(context, patterns, target); PopulateCustomCallPatterns(context, patterns, target); patterns.add(context); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td index d4368384cb9c6a..27ec0fa643b315 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td @@ -326,9 +326,13 @@ def LowerCbrt : Pat<(MHLO_CbrtOp $opr), TFL_AF_None)), [(F32Tensor $opr)]>; +// Pattern to legacyze mhlo.abs(complex) to tfl.complex_abs. +def : Pat<(MHLO_AbsOp MHLO_ComplexTensor:$arg), (TFL_ComplexAbsOp $arg)>; + +// Pattern to match non-complex abs. +def : Pat<(MHLO_AbsOp MHLO_PredIntFpOrQuantizedTensor:$arg), (TFL_AbsOp $arg)>; foreach pair = [ - [MHLO_AbsOp, TFL_AbsOp], [MHLO_BitcastConvertOp, TFL_BitcastOp], [MHLO_CeilOp, TFL_CeilOp], [MHLO_CosineOp, TFL_CosOp], diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/torch/build_stablehlo_composite_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/torch/build_stablehlo_composite_pass.cc index e717114610b527..c4500d552d0573 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/torch/build_stablehlo_composite_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/torch/build_stablehlo_composite_pass.cc @@ -50,13 +50,13 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" // IWYU pragma: keep namespace mlir { namespace odml { #define GEN_PASS_DEF_BUILDSTABLEHLOCOMPOSITEPASS -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h.inc" namespace { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/torch/lift_callsite_loc_caller_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/torch/lift_callsite_loc_caller_pass.cc index b8a1e02adfaabd..fbbb8a2217e47e 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/torch/lift_callsite_loc_caller_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/torch/lift_callsite_loc_caller_pass.cc @@ -18,12 +18,12 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" // IWYU pragma: keep namespace mlir { namespace odml { #define GEN_PASS_DEF_LIFTCALLSITELOCCALLERPASS -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h.inc" namespace { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc index 4ec630a5850ed6..b7b93fb4e9da41 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc @@ -19,9 +19,9 @@ limitations under the License. #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/drop_savedmodel_semantics.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfold_splat_constant_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfold_splat_constant_pass.cc index 7a3abd35d0d376..b0a023494f1ca4 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfold_splat_constant_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfold_splat_constant_pass.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -38,7 +38,7 @@ namespace { #define DEBUG_TYPE "unfold-splat-constant-pass" #define GEN_PASS_DEF_UNFOLDSPLATCONSTANTPASS -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h.inc" // Undo the MHLO::BroadcastInDimOp folding pattern on splat tensor. // TODO(b/295966255): Remove this pass after moving MHLO folders to a separate diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfuse_batch_norm_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfuse_batch_norm_pass.cc index dadcabc55a5e57..f5d756d971610e 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfuse_batch_norm_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfuse_batch_norm_pass.cc @@ -32,7 +32,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc index c3a05d5a0706a7..bbda975563b666 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc @@ -27,8 +27,8 @@ limitations under the License. #include "llvm/Support/Debug.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project // NOLINT: Required to register quantization dialect. -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // NOLINT: Required to register quantization dialect. +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -80,7 +80,7 @@ const char* kPaddingSame = "SAME"; const char* kPaddingValid = "VALID"; #define GEN_PASS_DEF_UNIFORMQUANTIZEDSTABLEHLOTOTFLPASS -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h.inc" class UniformQuantizedStableHloToTflPass : public impl::UniformQuantizedStableHloToTflPassBase< diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.cc new file mode 100644 index 00000000000000..b120a6f02e1460 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.cc @@ -0,0 +1,55 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h" + +#include + +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/mlir_hlo/utils/hlo_utils.h" + +namespace mlir { +namespace odml { + +mhlo::ConstantOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value, + OpBuilder* builder) { + return builder->create(loc, + hlo::getScalarOfType(ty, raw_value)); +} + +mhlo::ConstantOp GetScalarNegZeroOfType(Type ty, Location loc, + OpBuilder* builder) { + return builder->create(loc, + hlo::getScalarNegZeroOfType(ty)); +} + +DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr) { + RankedTensorType ty = + RankedTensorType::get(static_cast(attr.size()), + IntegerType::get(attr.getContext(), 64)); + return DenseIntElementsAttr::get(ty, attr.getValue()); +} + +DenseIntElementsAttr GetI64ElementsAttr(ArrayRef values, + Builder* builder) { + RankedTensorType ty = RankedTensorType::get( + {static_cast(values.size())}, builder->getIntegerType(64)); + return DenseIntElementsAttr::get(ty, values); +} + +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h new file mode 100644 index 00000000000000..13ff4c4767721d --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h @@ -0,0 +1,61 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_UTILS_H_ + +#include "llvm/ADT/ArrayRef.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace odml { + +// Builds body for reduce op by using the template binary op as the +// reducer op. +template +void BuildReduceBody(Type element_type, Region* body, OpBuilder* builder) { + OpBuilder::InsertionGuard guard(*builder); + Block* block = builder->createBlock(body); + + // Block arguments are scalars of the given element type. + Type type = RankedTensorType::get(/*shape=*/{}, element_type); + Location loc = body->getLoc(); + block->addArguments({type, type}, SmallVector(2, loc)); + + auto reducer = + builder->create(loc, block->getArgument(0), block->getArgument(1)); + builder->create(loc, reducer.getResult()); +} + +mhlo::ConstantOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value, + OpBuilder* builder); + +mhlo::ConstantOp GetScalarNegZeroOfType(Type ty, Location loc, + OpBuilder* builder); + +// Converts an ArrayAttr to a 1D 64-bit dense elements attribute. +DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr); +DenseIntElementsAttr GetI64ElementsAttr(llvm::ArrayRef values, + Builder* builder); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_UTILS_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils_test.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils_test.cc new file mode 100644 index 00000000000000..40d3cc27164427 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils_test.cc @@ -0,0 +1,82 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h" + +#include + +#include +#include +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace odml { +namespace { + +TEST(UtilsTest, GetScalarConstOfType) { + MLIRContext context; + context.loadDialect(); + OpBuilder builder(&context); + Location loc = UnknownLoc::get(&context); + Type ty = builder.getI32Type(); + mhlo::ConstantOp op = GetScalarConstOfType(ty, loc, 123, &builder); + EXPECT_EQ(op.getValue().getValues()[0], 123); + + op->destroy(); +} + +TEST(UtilsTest, GetScalarNegZeroOfType) { + MLIRContext context; + context.loadDialect(); + OpBuilder builder(&context); + Location loc = UnknownLoc::get(&context); + Type ty = builder.getF32Type(); + mhlo::ConstantOp op = GetScalarNegZeroOfType(ty, loc, &builder); + EXPECT_EQ(op.getValue().getValues()[0], -0.f); + + op->destroy(); +} + +TEST(UtilsTest, GetI64ElementsAttr) { + MLIRContext context; + context.loadDialect(); + OpBuilder builder(&context); + Location loc = UnknownLoc::get(&context); + SmallVector values = {1, 2, 3}; + auto valuesAttr = builder.getI64ArrayAttr(values); + DenseIntElementsAttr attr = GetI64ElementsAttr(valuesAttr); + EXPECT_THAT(SmallVector(attr.getValues()), + testing::ElementsAreArray(values)); +} + +TEST(UtilsTest, GetI64ElementsAttrBuilder) { + MLIRContext context; + context.loadDialect(); + OpBuilder builder(&context); + Location loc = UnknownLoc::get(&context); + SmallVector values = {1, 2, 3}; + DenseIntElementsAttr attr = GetI64ElementsAttr(values, &builder); + EXPECT_THAT(SmallVector(attr.getValues()), + testing::ElementsAreArray(values)); +} + +} // namespace + +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_gather_round_trip.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_gather_round_trip.mlir new file mode 100644 index 00000000000000..12de9da5939573 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_gather_round_trip.mlir @@ -0,0 +1,32 @@ +// RUN: tf_tfl_translate --enable-hlo-to-tf-conversion --input-mlir %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s --check-prefix=CHECK-ROUNDTRIP + + +module { + // CHECK-LABEL: func.func public @main + func.func public @main(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> { + // CHECK-ROUNDTRIP: %[[iota_1:.*]] = "tfl.pseudo_const"() <{{.*}}> : () -> tensor<4x3x5x1xi32 + // CHECK-ROUNDTRIP: %[[iota_2:.*]] = "tfl.pseudo_const"() <{{.*}}> : () -> tensor<4x3x5x1xi32> + // CHECK-ROUNDTRIP: %[[concat:.*]] = "tfl.concatenation"(%[[iota_1]], %[[iota_2]], %arg1) <{axis = 3 : i32, fused_activation_function = "NONE"}> : + // CHECK-ROUNDTRIP-SAME: (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32> + // CHECK-ROUNDTRIP: %[[gather:.*]] = "stablehlo.gather"(%arg0, %2) <{ + // CHECK-ROUNDTRIP-SAME: dimension_numbers = #stablehlo.gather< + // CHECK-ROUNDTRIP-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], + // CHECK-ROUNDTRIP-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, + // CHECK-ROUNDTRIP-SAME: slice_sizes = array}> : + // CHECK-ROUNDTRIP-SAME: (tensor<3x2x4x7x9xi32>, tensor<4x3x5x4xi32>) -> tensor<4x3x5x8xi32> + // CHECK-ROUNDTRIP: return %[[gather]] + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [3], + collapsed_slice_dims = [1, 3], + operand_batching_dims = [0, 2], + start_indices_batching_dims = [1, 0], + start_index_map = [1, 3], + index_vector_dim = 3 + >, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> + return %0 : tensor<4x3x5x8xi32> + } +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_scatter_round_trip.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_scatter_round_trip.mlir new file mode 100644 index 00000000000000..44d1bb7dd8b72f --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_scatter_round_trip.mlir @@ -0,0 +1,34 @@ +// RUN: tf_tfl_translate --enable-hlo-to-tf-conversion --input-mlir %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s --check-prefix=CHECK-ROUNDTRIP + + +module { + // CHECK-LABEL: func.func public @main + func.func public @main(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>, %arg2: tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> { + // CHECK-ROUNDTRIP: %[[iota_1:.*]] = "tfl.pseudo_const"() <{{.*}}> : () -> tensor<4x3x5x1xi32 + // CHECK-ROUNDTRIP: %[[iota_2:.*]] = "tfl.pseudo_const"() <{{.*}}> : () -> tensor<4x3x5x1xi32> + // CHECK-ROUNDTRIP: %[[concat:.*]] = "tfl.concatenation"(%[[iota_1]], %[[iota_2]], %arg1) <{axis = 3 : i32, fused_activation_function = "NONE"}> : + // CHECK-ROUNDTRIP-SAME: (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32> + // CHECK-ROUNDTRIP: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %2, %arg2) <{ + // CHECK-ROUNDTRIP-SAME: scatter_dimension_numbers = #stablehlo.scatter + // CHECK-ROUNDTRIP-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2, 3], + // CHECK-ROUNDTRIP-SAME: scatter_dims_to_operand_dims = [0, 2, 1, 3], index_vector_dim = 3>}> + // CHECK-ROUNDTRIP: (tensor<3x2x4x7x9xi32>, tensor<4x3x5x4xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> + // CHECK-ROUNDTRIP: return %[[scatter]] + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ + indices_are_sorted = false, + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [3], + inserted_window_dims = [1, 3], + input_batching_dims = [0, 2], + scatter_indices_batching_dims = [1, 0], + scatter_dims_to_operand_dims = [1, 3], + index_vector_dim = 3 + >, + unique_indices = false + }> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + stablehlo.return %arg4 : tensor + }) : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x2xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> + return %0 : tensor<3x2x4x7x9xi32> + } +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/custom_op.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/custom_op.mlir index d1f7a4bb6423a2..ef67e98e03e808 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/custom_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/custom_op.mlir @@ -6,3 +6,11 @@ func.func @main(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, % } // CHECK-LABEL: main // CHECK: "tfl.custom"(%arg0, %arg1, %arg2) <{custom_code = "Convolution2DTransposeBias", custom_option = #tfl}> : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32> + +func.func @main_non_hex_bytes(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<1x32x42x128xf32>, %arg2: tensor<4xi32>) -> tensor<1x64x84x32xf32> { + %0 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "Convolution2DTransposeBias", custom_option = #tfl} : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32> + func.return %0 : tensor<1x64x84x32xf32> +} +// CHECK-LABEL: main_non_hex_bytes +// Hex representation below determined by the following command: echo -n "this is a string" | od -An -t x1 | tr -d ' ' | tr '[:lower:]' '[:upper:'] +// CHECK: "tfl.custom"(%arg0, %arg1, %arg2) <{custom_code = "Convolution2DTransposeBias", custom_option = #tfl}> : (tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<4xi32>) -> tensor<1x64x84x32xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/debug_metadata.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/debug_metadata.mlir new file mode 100644 index 00000000000000..61df9ad531515a --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/debug_metadata.mlir @@ -0,0 +1,36 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer --serialize-debug-metadata=true %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir --mlir-print-debuginfo -o - | FileCheck %s +// This test verifies that debug locations are round-trippable. + +module @jit_relu attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32, tfl._legalize_tfl_variables = true} { + func.func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + %0 = "tfl.less"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1> loc(#loc) + // CHECK-DAG: {{.*}} = tfl.less(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1> loc([[LOC:.+]]) + %1 = "tf.If"(%0, %arg0, %arg1) {else_branch = @cond_false, then_branch = @cond_true, is_stateless = false} : (tensor<1xi1>, tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> loc(#loc) + // CHECK-DAG: {{.*}} = "tf.If"(%0, %arg0, %arg1) {{.*}} -> tensor<1xf32> loc([[LOC]]) + func.return %1 : tensor<1xf32> loc(#loc) + } + + func.func @cond_true(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32> loc(#loc4) + // CHECK-DAG: {{.*}} = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32> loc([[LOC4:.+]]) + func.return %0 : tensor<*xf32> loc(#loc) + } + + func.func @cond_false(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32> loc(#loc5) + // CHECK-DAG: {{.*}} = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32> loc([[LOC5:.+]]) + func.return %0 : tensor<*xf32> loc(#loc) + } +} loc(#loc) +#loc = loc(unknown) +// CHECK-DAG: [[LOC]] = loc(unknown) +#loc1 = loc("":1:4) +// CHECK-DAG: [[LOC1:.+]] = loc("":1:4) +#loc2 = loc("third_party/py/IPython/v3_2_3/core/interactiveshell.py":3066:16) +// CHECK-DAG: [[LOC2:.+]] = loc("third_party/py/IPython/v3_2_3/core/interactiveshell.py":3066:16) +#loc3 = loc(callsite(#loc1 at #loc2)) +// CHECK-DAG: [[LOC3:.+]] = loc(callsite([[LOC1]] at [[LOC2]])) +#loc4 = loc("jit(relu)/jit(main)/max"(#loc3)) +// CHECK-DAG: [[LOC4]] = loc("jit(relu)/jit(main)/max"([[LOC3]])) +#loc5 = loc(fused<"">[#loc1, #loc2]) +// CHECK-DAG: [[LOC5]] = loc(fused<"">[[[LOC1]], [[LOC2]]]) \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index d70b18c9fa6036..95a88b7cf453ca 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -842,6 +842,26 @@ func.func @FuseReshapeAroundBMMNagativeTest2(%arg0: tensor<2x1536xf32>) -> tenso // CHECK: return %3 : tensor<2x768xf32> } +// CHECK-LABEL: @FuseReshapeAroundBMMNagativeTest3 +// Checks that the pattern matcher FuseReshapesAroundBatchMatMulLHS does not get +// applied for this case that does not pass the constraint around input rank. +func.func @FuseReshapeAroundBMMNagativeTest3(%arg0: tensor<10xf32>) -> tensor<5xf32> { + %cst_0 = arith.constant dense_resource<__elided__> : tensor<10x5xf32> + %cst_3 = arith.constant dense<5> : tensor<1xi32> + %cst_4 = arith.constant dense<[1, 10]> : tensor<2xi32> + %0 = "tfl.reshape"(%arg0, %cst_4) : (tensor<10xf32>, tensor<2xi32>) -> tensor<1x10xf32> + %1 = "tfl.batch_matmul"(%0, %cst_0) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<1x10xf32>, tensor<10x5xf32>) -> tensor<1x5xf32> + %2 = "tfl.reshape"(%1, %cst_3) : (tensor<1x5xf32>, tensor<1xi32>) -> tensor<5xf32> + return %2 : tensor<5xf32> + // CHECK: %cst = arith.constant dense_resource<__elided__> : tensor<10x5xf32> + // CHECK: %cst_0 = arith.constant dense<5> : tensor<1xi32> + // CHECK: %cst_1 = arith.constant dense<[1, 10]> : tensor<2xi32> + // CHECK: %0 = "tfl.reshape"(%arg0, %cst_1) : (tensor<10xf32>, tensor<2xi32>) -> tensor<1x10xf32> + // CHECK: %1 = "tfl.batch_matmul"(%0, %cst) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<1x10xf32>, tensor<10x5xf32>) -> tensor<1x5xf32> + // CHECK: %2 = "tfl.reshape"(%1, %cst_0) : (tensor<1x5xf32>, tensor<1xi32>) -> tensor<5xf32> + // CHECK: return %2 : tensor<5xf32> +} + // CHECK-LABEL: @convert_bmm_rhs_transpose_into_fc // FOLD-LABEL: @convert_bmm_rhs_transpose_into_fc func.func @convert_bmm_rhs_transpose_into_fc(%arg0: tensor<8x256xf32>, %arg1: tensor<256x256xf32>) -> (tensor<8x256xf32>) { @@ -968,6 +988,24 @@ func.func @FuseBMMOutputReshape_WithTwoLHSContractionDims(%arg0: tensor<8x256x17 // CHECK: return %2 : tensor<1x128x1792xf32> } +// CHECK-LABEL: @FuseBMMOutputReshape_WithTwoLHSContractionDims_Negative +func.func @FuseBMMOutputReshape_WithTwoLHSContractionDims_Negative(%arg0: tensor<1x3872x1x128xf32>) -> tensor<1x3872x8x16xf32> { + %cst_84 = arith.constant dense<[3872, 128]> : tensor<2xi32> + %cst_82 = arith.constant dense<[1, 3872, 8, 16]> : tensor<4xi32> + %cst_24 = arith.constant dense_resource<__elided__> : tensor<128x128xf32> + %59 = "tfl.reshape"(%arg0, %cst_84) : (tensor<1x3872x1x128xf32>, tensor<2xi32>) -> tensor<3872x128xf32> + %60 = "tfl.batch_matmul"(%59, %cst_24) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<3872x128xf32>, tensor<128x128xf32>) -> tensor<3872x128xf32> + %67 = "tfl.reshape"(%60, %cst_82) : (tensor<3872x128xf32>, tensor<4xi32>) -> tensor<1x3872x8x16xf32> + func.return %67: tensor<1x3872x8x16xf32> + // CHECK: %cst = arith.constant dense<[3872, 128]> : tensor<2xi32> + // CHECK: %cst_0 = arith.constant dense<[1, 3872, 8, 16]> : tensor<4xi32> + // CHECK: %cst_1 = arith.constant dense_resource<__elided__> : tensor<128x128xf32> + // CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<1x3872x1x128xf32>, tensor<2xi32>) -> tensor<3872x128xf32> + // CHECK: %1 = "tfl.batch_matmul"(%0, %cst_1) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<3872x128xf32>, tensor<128x128xf32>) -> tensor<3872x128xf32> + // CHECK: %2 = "tfl.reshape"(%1, %cst_0) : (tensor<3872x128xf32>, tensor<4xi32>) -> tensor<1x3872x8x16xf32> + // CHECK: return %2 : tensor<1x3872x8x16xf32> +} + // CHECK-LABEL: @FuseBMMOutputReshape_WithThreeLHSContractionDims func.func @FuseBMMOutputReshape_WithThreeLHSContractionDims(%arg0: tensor<2x8x256x1792xf32>, %arg1: tensor<1x2x128x8x256xf32>) -> (tensor<1x128x1792xf32>){ %cst = arith.constant dense<[1, 128, 1792]> : tensor<3xi32> @@ -3636,6 +3674,20 @@ func.func @gelu(%arg0: tensor<3xf32>) -> tensor<3xf32> { // CHECK: "tfl.gelu"(%arg0) <{approximate = false}> : (tensor<3xf32>) -> tensor<3xf32> } +func.func @gelu_erfc(%arg0: tensor<3xf32>) -> tensor<3xf32> { + %cst = arith.constant dense<0.707106769> : tensor + %cst_0 = arith.constant dense<5.000000e-01> : tensor + %2 = "tfl.neg"(%arg0) : (tensor<3xf32>) -> tensor<3xf32> + %3 = "tfl.mul"(%2, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xf32>, tensor) -> tensor<3xf32> + %4 = "tf.Erfc"(%3) : (tensor<3xf32>) -> tensor<3xf32> + %5 = "tfl.mul"(%arg0, %cst_0) <{fused_activation_function = "NONE"}> : (tensor<3xf32>, tensor) -> tensor<3xf32> + %6 = "tfl.mul"(%5, %4) <{fused_activation_function = "NONE"}> : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + func.return %6 : tensor<3xf32> + +// CHECK-LABEL:gelu_erfc +// CHECK: "tfl.gelu"(%arg0) <{approximate = false}> : (tensor<3xf32>) -> tensor<3xf32> +} + func.func @gelu_no_match(%arg0: tensor<3xf32>) -> tensor<3xf32> { %cst = arith.constant dense<0.707106769> : tensor %cst_0 = arith.constant dense<5.000000e-01> : tensor diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf-fake-quant.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf-fake-quant.mlir index 4cefb2596aeb50..1034782d68d9b5 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf-fake-quant.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf-fake-quant.mlir @@ -345,4 +345,43 @@ func.func @fakeQuantConcat(%arg0: tensor<1x6400x2xf32>, %arg1: tensor<1x1600x2xf // CHECK: return %9 } + +// CHECK-LABEL: populateFakeQuantOnMeanOutput +func.func @populateFakeQuantOnMeanOutput(%arg0: tensor) -> (tensor) { + %cst = arith.constant dense<-1.0> : tensor + %cst_1 = arith.constant dense<1.0> : tensor + %cst_2 = arith.constant dense<0> : tensor<1xi32> + %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_1) {num_bits = 8, narrow_range = false} : (tensor, tensor, tensor) -> tensor + %1 = "tf.Mean"(%0, %cst_2) <{keep_dims = false}> : (tensor, tensor<1xi32>) -> tensor + return %1 : tensor + +// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) +// CHECK: %1 = "tfl.quantize"(%0) <{qtype = tensor>}> : (tensor) -> tensor> +// CHECK: %2 = "tfl.dequantize"(%1) : (tensor>) -> tensor +// CHECK: %3 = "tf.Mean"(%2, %cst_1) +// CHECK: %4 = "tf.FakeQuantWithMinMaxVars"(%3, %cst, %cst_0) +// CHECK: %5 = "tfl.quantize"(%4) <{qtype = tensor>}> : (tensor) -> tensor> +// CHECK: %6 = "tfl.dequantize"(%5) : (tensor>) -> tensor +// CHECK: return %6 +} + +// CHECK-LABEL: populateFakeQuantOnMeanOutputNegativeCase +func.func @populateFakeQuantOnMeanOutputNegativeCase(%arg0: tensor) -> (tensor) { + %cst = arith.constant dense<-1.0> : tensor + %cst_1 = arith.constant dense<1.0> : tensor + %cst_2 = arith.constant dense<0> : tensor<1xi32> + %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_1) {num_bits = 8, narrow_range = false} : (tensor, tensor, tensor) -> tensor + %1 = "tf.Mean"(%0, %cst_2) <{keep_dims = false}> : (tensor, tensor<1xi32>) -> tensor + %2 = "tf.FakeQuantWithMinMaxVars"(%1, %cst, %cst_1) {num_bits = 8, narrow_range = false} : (tensor, tensor, tensor) -> tensor + return %2 : tensor + +// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) +// CHECK: %1 = "tfl.quantize"(%0) +// CHECK: %2 = "tfl.dequantize"(%1) +// CHECK: %3 = "tf.Mean"(%2, %cst_1) +// CHECK: %4 = "tf.FakeQuantWithMinMaxVars"(%3, %cst, %cst_0) +// CHECK-NOT: "tf.FakeQuantWithMinMaxVars" +} + } + diff --git a/tensorflow/compiler/mlir/lite/tests/quantize.mlir b/tensorflow/compiler/mlir/lite/tests/quantize.mlir index 7640a499250cae..ed7403ae283604 100644 --- a/tensorflow/compiler/mlir/lite/tests/quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/quantize.mlir @@ -1,6 +1,6 @@ -// RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize | FileCheck %s -// RUN: tf-opt %s -tfl-quantize="legacy-quantize=true" | FileCheck --check-prefix=LEGACY %s -// RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize="ops-blocklist=tfl.fully_connected,tfl.softmax locs-blocklist=Block,NullBlock" | FileCheck --check-prefix=BLOCK %s +// RUN: tf-opt %s -split-input-file -tfl-prepare-quantize -tfl-quantize | FileCheck %s +// RUN: tf-opt %s -split-input-file -tfl-quantize="legacy-quantize=true" | FileCheck --check-prefix=LEGACY %s +// RUN: tf-opt %s -split-input-file -tfl-prepare-quantize -tfl-quantize="ops-blocklist=tfl.fully_connected,tfl.softmax locs-blocklist=Block,NullBlock" | FileCheck --check-prefix=BLOCK %s // CHECK-LABEL: QuantizeFloatConst func.func @QuantizeFloatConst() -> tensor<2x2x!quant.uniform> { @@ -12,6 +12,8 @@ func.func @QuantizeFloatConst() -> tensor<2x2x!quant.uniform tensor<2x4x!quant.uniform> { %0 = arith.constant dense<[[-0.75, -0.5, -0.25, 0.0], [0.25, 0.5, 0.75, 1.0]]> : tensor<2x4xf32> @@ -22,6 +24,8 @@ func.func @QuantizeFloatConst4Bits() -> tensor<2x4x!quant.uniform tensor<2x2x!quant.uniform> { %0 = arith.constant dense<[[-0.1, 1.0], [1.0, 3.0]]> : tensor<2x2xf32> @@ -32,6 +36,8 @@ func.func @QuantizeDenseFloatConst() -> tensor<2x2x!quant.uniform tensor<2x2x!quant.uniform> { %0 = arith.constant dense<3.0> : tensor<2x2xf32> @@ -42,6 +48,8 @@ func.func @QuantizeSplatFloatConst() -> tensor<2x2x!quant.uniform tensor<2x2xf32> { %0 = arith.constant dense<-0.1> : tensor<2x2xf32> @@ -53,6 +61,8 @@ func.func @NotQuantizeFloatConst() -> tensor<2x2xf32> { // CHECK: return %[[cst]] : tensor<2x2xf32> } +// ----- + // CHECK-LABEL: DequantizeAndQuantize func.func @DequantizeAndQuantize() -> tensor<2x2x!quant.uniform> { %cst = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform>, value = dense<-1> : tensor<2x2xi8>} : () -> tensor<2x2x!quant.uniform> @@ -64,6 +74,8 @@ func.func @DequantizeAndQuantize() -> tensor<2x2x!quant.uniform> } +// ----- + // CHECK-LABEL: QuantizeConv2D func.func @QuantizeConv2D(tensor<1x224x224x3x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> { ^bb0(%arg0: tensor<1x224x224x3x!quant.uniform>): @@ -82,6 +94,8 @@ func.func @QuantizeConv2D(tensor<1x224x224x3x!quant.uniform> } +// ----- + // CHECK-LABEL: QuantizeConv2D4Bit func.func @QuantizeConv2D4Bit(tensor<1x224x224x3x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> { ^bb0(%arg0: tensor<1x224x224x3x!quant.uniform>): @@ -100,6 +114,8 @@ func.func @QuantizeConv2D4Bit(tensor<1x224x224x3x!quant.uniform> } +// ----- + // CHECK-LABEL: QuantizeDepthwiseConv2D func.func @QuantizeDepthwiseConv2D(tensor<1x224x224x3x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> { ^bb0(%arg0: tensor<1x224x224x3x!quant.uniform>): @@ -117,6 +133,8 @@ func.func @QuantizeDepthwiseConv2D(tensor<1x224x224x3x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> { ^bb0(%arg0: tensor<1x224x224x3x!quant.uniform>): @@ -134,6 +152,8 @@ func.func @QuantizeDepthwiseConv2D4Bit(tensor<1x224x224x3x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> { ^bb0(%arg0: tensor<1x224x224x3x!quant.uniform>): @@ -159,6 +179,8 @@ func.func @QuantizeFullyConnected(tensor<1x224x224x3x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> { ^bb0(%arg0: tensor<1x224x224x3x!quant.uniform>): @@ -184,6 +206,8 @@ func.func @QuantizeFullyConnected4Bit(tensor<1x224x224x3x!quant.uniform>, %arg1: tensor<3x3x!quant.uniform:f32, 1.0>>, %arg2: none) -> tensor<3x!quant.uniform> { %0 = "tfl.dequantize"(%arg0) : (tensor<3x!quant.uniform>) -> tensor<3xf32> @@ -196,6 +220,8 @@ func.func @QuantizeNoBiasFullyConnected(%arg0: tensor<3x!quant.uniform>) -> tensor<1x1x1x16xf32> { ^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>): @@ -208,6 +234,8 @@ func.func @QuantizeAveragePool2D(tensor<1x6x6x16x!quant.uniform } +// ----- + // CHECK-LABEL: QuantizeReshape2D func.func @QuantizeReshape2D(tensor<1x6x6x16x!quant.uniform>) -> tensor<1x36x16xf32> { ^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>): @@ -221,6 +249,8 @@ func.func @QuantizeReshape2D(tensor<1x6x6x16x!quant.uniform } +// ----- + // CHECK-LABEL: QuantizeSoftmax func.func @QuantizeSoftmax(tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> { ^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>): @@ -233,6 +263,8 @@ func.func @QuantizeSoftmax(tensor<1x6x6x16x!quant.uniform } +// ----- + // CHECK-LABEL: QuantizeLogistic func.func @QuantizeLogistic(tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> { ^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>): @@ -245,6 +277,8 @@ func.func @QuantizeLogistic(tensor<1x6x6x16x!quant.uniform>, tensor<1x56x56x24x!quant.uniform>) -> tensor<1x56x56x24x!quant.uniform> { ^bb0(%arg0: tensor<1x56x56x24x!quant.uniform>, %arg1: tensor<1x56x56x24x!quant.uniform>): @@ -264,6 +298,8 @@ func.func @QuantizeAdd(tensor<1x56x56x24x!quant.uniform> } +// ----- + // CHECK-LABEL: QuantizeConcat func.func @QuantizeConcat(tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2x!quant.uniform> { ^bb0(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>): @@ -277,6 +313,8 @@ func.func @QuantizeConcat(tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2x!quant // CHECK: return %[[cc]] : tensor<2x2x!quant.uniform> } +// ----- + // CHECK-LABEL: QuantizeConcatRequantize func.func @QuantizeConcatRequantize(tensor<1x2x!quant.uniform>, tensor<1x2xf32>) -> tensor<2x2x!quant.uniform> { ^bb0(%arg0: tensor<1x2x!quant.uniform>, %arg1: tensor<1x2xf32>): @@ -291,6 +329,8 @@ func.func @QuantizeConcatRequantize(tensor<1x2x!quant.uniform>, // CHECK: return %[[cc]] : tensor<2x2x!quant.uniform> } +// ----- + // CHECK-LABEL: QuantizeMaxPool2D func.func @QuantizeMaxPool2D(tensor<1x6x6x16x!quant.uniform>) -> tensor<1x1x1x16xf32> { ^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>): @@ -303,6 +343,8 @@ func.func @QuantizeMaxPool2D(tensor<1x6x6x16x!quant.uniform } +// ----- + // CHECK-LABEL: QuantizeSplit func.func @QuantizeSplit(%arg: tensor<4x!quant.uniform>, %cst: tensor) -> (tensor<2x!quant.uniform>,tensor<2x!quant.uniform>) { %0 = "tfl.dequantize"(%arg) : (tensor<4x!quant.uniform>) -> tensor<4xf32> @@ -315,6 +357,8 @@ func.func @QuantizeSplit(%arg: tensor<4x!quant.uniform>, %cst: tens // CHECK: return %[[sp]]#0, %[[sp]]#1 } +// ----- + // CHECK-LABEL: QuantizeSplitUnusedResults func.func @QuantizeSplitUnusedResults(%arg: tensor<4x!quant.uniform>, %cst: tensor) -> (tensor<2x!quant.uniform>,tensor<2x!quant.uniform>) { @@ -328,6 +372,8 @@ func.func @QuantizeSplitUnusedResults(%arg: tensor<4x!quant.uniform // CHECK: return %[[sp]]#0, %[[sp]]#1 } +// ----- + // CHECK-LABEL: QuantizeShape func.func @QuantizeShape(%arg0: tensor<*x!quant.uniform>, %arg1: tensor>) -> (tensor,tensor<3xi32>) { @@ -342,6 +388,8 @@ func.func @QuantizeShape(%arg0: tensor<*x!quant.uniform>, // CHECK-NEXT: %[[s2]], %[[s3]] : tensor, tensor<3xi32> } +// ----- + // CHECK-LABEL: QuantizeMultipleUsers func.func @QuantizeMultipleUsers(%arg1: tensor>) -> (tensor<1xi32>,tensor<1xi32>) { %1 = "tfl.dequantize"(%arg1) : (tensor>) -> tensor @@ -352,6 +400,8 @@ func.func @QuantizeMultipleUsers(%arg1: tensor>) - // CHECK-NEXT: %[[s1]], %[[s1]] : tensor<1xi32>, tensor<1xi32> } +// ----- + // CHECK-LABEL: NotQuantizePow func.func @NotQuantizePow(%arg0: tensor<4x!quant.uniform>, %arg1: tensor<4x!quant.uniform>) -> (tensor<4x!quant.uniform>) { @@ -370,6 +420,8 @@ func.func @NotQuantizePow(%arg0: tensor<4x!quant.uniform>, } +// ----- + // CHECK-LABEL: QuantizeCustomTfOp func.func @QuantizeCustomTfOp(%arg0: tensor<128x128x!quant.uniform>, %arg1: tensor<1x!quant.uniform>, %arg2: tensor<1x!quant.uniform>, @@ -394,6 +446,8 @@ func.func @QuantizeCustomTfOp(%arg0: tensor<128x128x!quant.uniform tensor<128x128x!quant.uniform> } +// ----- + // CHECK-LABEL: NotQuantizeCustomTfOp func.func @NotQuantizeCustomTfOp(%arg0: tensor<128x128x!quant.uniform>, %arg1: tensor<1x!quant.uniform>, %arg2: tensor<1x!quant.uniform>, @@ -416,6 +470,8 @@ func.func @NotQuantizeCustomTfOp(%arg0: tensor<128x128x!quant.uniform, tensor<1xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<128x128xf32> } +// ----- + // CHECK-LABEL: NotQuantizableValues func.func @NotQuantizableValues(%arg0: tensor<1x!tf_type.string>) -> (tensor<1x?x16x!quant.uniform>, tensor<1x!tf_type.string>, tensor<1xi32>) { %0:3 = "tfl.custom_tf"(%arg0) ({ @@ -433,6 +489,8 @@ func.func @NotQuantizableValues(%arg0: tensor<1x!tf_type.string>) -> (tensor<1x? // CHECK: }) {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x!tf_type.string>) -> (tensor<1x?x16x!quant.uniform>, tensor<1x!tf_type.string>, tensor<1xi32>) } +// ----- + // Checks that legacy path correctly handles asymmetric quantized values. // LEGACY-LABEL: CheckLegacyQuantizeAdd func.func @CheckLegacyQuantizeAdd() -> tensor<1x2x!quant.uniform> { @@ -443,6 +501,8 @@ func.func @CheckLegacyQuantizeAdd() -> tensor<1x2x!quant.uniform>, value = dense<{{\[\[}}-1, 127]]> : tensor<1x2xi8>}> } +// ----- + func.func private @testIfThen(tensor<*xf32>) -> tensor<*xf32> func.func private @testIfElse(tensor<*xf32>) -> tensor<*xf32> @@ -461,6 +521,8 @@ func.func @NotQuantizeIf(%arg0: tensor, // CHECK-NEXT: return %[[q]] } +// ----- + // CHECK-LABEL: NotQuantizeReadVariable func.func @NotQuantizeReadVariable() -> tensor<1x2x3x!quant.uniform:f32, 0.047244094488188976:128>> { %0 = "tfl.var_handle"() {container = "", shared_name = "states"} : () -> tensor>> @@ -473,6 +535,8 @@ func.func @NotQuantizeReadVariable() -> tensor<1x2x3x!quant.uniform:f3 // CHECK-NEXT: return %[[quantize]] } +// ----- + // CHECK-LABEL: foldQuantWeightsIntoTposeConv func.func @foldQuantWeightsIntoTposeConv(%arg0: tensor<2x2x3x2048xf32>) -> tensor<2x3x2x2048xf32> { %output_shape = arith.constant dense<[2, 3, 2, 2048]> : tensor<4xi32> @@ -486,6 +550,8 @@ func.func @foldQuantWeightsIntoTposeConv(%arg0: tensor<2x2x3x2048xf32>) -> tenso // CHECK: "tfl.transpose_conv"(%cst, %1, %arg0, %0) <{fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<4xi32>, tensor<4x2x2x2048x!quant.uniform:f32 } +// ----- + // CHECK-LABEL: foldQuantWeightsIntoTposeConvf16NotFolded func.func @foldQuantWeightsIntoTposeConvf16NotFolded(%arg0: tensor<2x2x3x2048xf32>) -> tensor<2x3x2x2048xf32> { %output_shape = arith.constant dense<[2, 3, 2, 2048]> : tensor<4xi32> @@ -498,6 +564,8 @@ func.func @foldQuantWeightsIntoTposeConvf16NotFolded(%arg0: tensor<2x2x3x2048xf3 // CHECK: "tfl.dequantize" } +// ----- + // CHECK-LABEL: foldQuantWeightsIntoEmbeddingLookup func.func @foldQuantWeightsIntoEmbeddingLookup(%arg0: tensor<3xi32>) -> tensor<3x512xf32> { %q_weighs = "tfl.pseudo_qconst"() {qtype = tensor<3074x512x!quant.uniform:f32, 0.15:151>>, value = dense<-76> : tensor<3074x512xi8>} : () -> tensor<3074x512x!quant.uniform:f32, 0.15:151>> @@ -508,3 +576,28 @@ func.func @foldQuantWeightsIntoEmbeddingLookup(%arg0: tensor<3xi32>) -> tensor<3 // CHECK-NOT: "tfl.dequantize" // CHECK: "tfl.embedding_lookup"(%arg0, %0) : (tensor<3xi32>, tensor<3074x512x!quant.uniform:f32 } + +// ----- + +// CHECK-LABEL: quantizeTFCustomOp +func.func @quantizeTFCustomOp(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor, tensor, tensor, tensor) { + %0 = "tfl.quantize"(%arg0) {qtype = tensor>} : (tensor) -> tensor> + %1 = "tfl.dequantize"(%0) : (tensor>) -> (tensor) + %2 = "tfl.quantize"(%arg1) {qtype = tensor>} : (tensor) -> tensor> + %3 = "tfl.dequantize"(%2) : (tensor>) -> (tensor) + %4 = "tfl.quantize"(%arg2) {qtype = tensor>} : (tensor) -> tensor> + %5 = "tfl.dequantize"(%4) : (tensor>) -> (tensor) + %6:4 = "tfl.custom_tf"(%1, %3, %5) ({ + ^bb0(%arg3: tensor, %arg4: tensor, %arg5: tensor): + %7:4 = "tf.TFLite_Detection_PostProcess"(%arg3, %arg4, %arg5) {_output_quantized = true, _output_types = [f32, f32, f32, f32], _support_output_type_float_in_quantized_op = true} : (tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor) + "tfl.yield"(%7#0, %7#1, %7#2, %7#3) : (tensor, tensor, tensor, tensor) -> () + }) {_output_quantized = true, _output_types = [f32, f32, f32, f32], _support_output_type_float_in_quantized_op = true} : (tensor, tensor, tensor) -> (tensor, tensor, tensor, tensor) + return %6#0, %6#1, %6#2, %6#3 : tensor, tensor, tensor, tensor + + // CHECK: %0 = "tfl.quantize"(%arg0) <{qtype = tensor>}> : (tensor) -> tensor> + // CHECK: %1 = "tfl.quantize"(%arg1) <{qtype = tensor>}> : (tensor) -> tensor> + // CHECK: %2 = "tfl.quantize"(%arg2) <{qtype = tensor>}> : (tensor) -> tensor> + // CHECK: %3:4 = "tfl.custom_tf"(%0, %1, %2) + // CHECK: (tensor>, tensor>, tensor>) -> (tensor, tensor, tensor, tensor + // CHECK: return %3#0, %3#1, %3#2, %3#3 : tensor, tensor, tensor, tensor +} diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 3c46ef3532aecf..913625a0bebe3f 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -20,18 +20,21 @@ limitations under the License. #include #include +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project +#include "stablehlo/transforms/Passes.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/converter_flags.pb.h" +#include "tensorflow/compiler/mlir/lite/core/macros.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h" #include "tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h" #include "tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.h" #include "tensorflow/compiler/mlir/lite/transforms/pass.h" @@ -194,6 +197,13 @@ void AddPreQuantizationStableHloToTfPasses( // to be consistent with other entrypoints. pass_manager.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + // Expand backward compatibility with the given StableHLO version by + // decomposing newer StableHLO operations into equivalent operations supported + // by that older version. + pass_manager.addNestedPass( + mlir::stablehlo::createStablehloCompatibilityExpanderPass( + {tflite_supported_stablehlo_version})); + // Decompose CHLO into StableHLO ops pass_manager.addNestedPass( mlir::odml::CreateLegalizeChloToTflPass()); @@ -330,7 +340,7 @@ void AddPreVariableFreezingTFToTFLConversionPasses( // folded before being converted to tfl.quantize and tfl.dequantize ops. std::vector target_ops = mlir::TFL::AllTfFakeQuantOps(); mlir::TFL::RaiseCustomOpsPassOptions raise_custom_ops_pass_options; - raise_custom_ops_pass_options.target_ops_ = target_ops; + raise_custom_ops_pass_options.target_ops_ = llvm::to_vector(target_ops); pass_manager->addNestedPass( mlir::TFL::CreateRaiseCustomOpsPass(raise_custom_ops_pass_options)); diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index d53c01d45a0e3c..54fdcd08e75ddf 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -53,8 +53,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h" +#include "xla/hlo/translate/hlo_to_mhlo/translate.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/translate/hlo_to_mhlo/translate.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/errors.h" diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index 1ef84baa619d9a..2ea617080874fb 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -40,7 +40,7 @@ limitations under the License. #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project #include "mlir/IR/AsmState.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -60,6 +60,7 @@ limitations under the License. #include "stablehlo/dialect/VhloOps.h" // from @stablehlo #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" +#include "tensorflow/compiler/mlir/lite/converter_flags.pb.h" #include "tensorflow/compiler/mlir/lite/debug/debug.h" #include "tensorflow/compiler/mlir/lite/experimental/remat/metadata_util.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" @@ -69,7 +70,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" @@ -530,7 +531,7 @@ absl::Status ConvertTFExecutorToTFLOrFlatbuffer( .insert +#include #include #include diff --git a/tensorflow/compiler/mlir/lite/transforms/decompose_hybrid_quantization.cc b/tensorflow/compiler/mlir/lite/transforms/decompose_hybrid_quantization.cc index 3fcd82ef033938..61b822a5e00f82 100644 --- a/tensorflow/compiler/mlir/lite/transforms/decompose_hybrid_quantization.cc +++ b/tensorflow/compiler/mlir/lite/transforms/decompose_hybrid_quantization.cc @@ -26,7 +26,7 @@ limitations under the License. #include -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc index 3c52ea8b61c235..f1b602a6763aca 100644 --- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc +++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc @@ -22,7 +22,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse_pass.cc similarity index 97% rename from tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc rename to tensorflow/compiler/mlir/lite/transforms/dense_to_sparse_pass.cc index 73d102a0502f1f..7668a8af959a60 100644 --- a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc +++ b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse_pass.cc @@ -14,18 +14,17 @@ limitations under the License. ==============================================================================*/ // This transformation pass convert dense tensor to sparse format. +#include "tensorflow/compiler/mlir/lite/transforms/dense_to_sparse_pass.h" #include "absl/memory/memory.h" #include "Eigen/Core" // from @eigen_archive #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.h" -#include "tensorflow/compiler/mlir/lite/transforms/passes.h" //===----------------------------------------------------------------------===// // The DenseToSparse Pass. @@ -35,9 +34,6 @@ namespace TFL { namespace { -#define GEN_PASS_DEF_DENSETOSPARSEPASS -#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc" - // If sparsity level is below this threshold, keep the tensor in dense format. constexpr float kMinSparsityLevel = 0.3; // Heuristic to check if a block configuration is correct for float constants. @@ -277,13 +273,7 @@ std::vector BuildSparsityParameterAttribute( return compressed_data; } - -struct DenseToSparsePass - : public impl::DenseToSparsePassBase { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DenseToSparsePass) - - void runOnOperation() override; -}; +} // namespace void DenseToSparsePass::runOnOperation() { func::FuncOp func = getOperation(); @@ -418,7 +408,6 @@ void DenseToSparsePass::runOnOperation() { }); } -} // namespace // Creates an instance of the TensorFlow Lite dialect DenseToSparse pass. std::unique_ptr> CreateDenseToSparsePass() { diff --git a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse_pass.h b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse_pass.h new file mode 100644 index 00000000000000..fa39e09c8d0aad --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse_pass.h @@ -0,0 +1,80 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This transformation pass convert dense tensor to sparse format. +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DENSE_TO_SPARSE_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DENSE_TO_SPARSE_PASS_H_ + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/pass_options.h" + +namespace mlir { +namespace TFL { + +// This pass encodes sparse weights in the model in the proper format, and adds +// Densify() op if necessary. The general algorithm is: +// 1. Get list of operands (weights) of an op that can be sparse. +// 2. Get list of supported block configurations of the op. +// 3. Calculate random sparsity of the weight. +// 3.1. If sparsity level is below the encoding threshold, keep in dense. +// 3.2. If sparsity level is above the encoding threshold, go to 4. +// 4. Try to encode the weight with supported block configurations. If the +// weight was pruned with the same block config, the blocked sparsity level +// should match the random sparsity. +// 4.1. Return the matching block config if found. +// 4.2. If no matching block config is found, encode the weight with random +// sparsity, and add Densify() op to fall back to dense execution. + +class DenseToSparsePass + : public Pass { + public: + DenseToSparsePass() = default; + DenseToSparsePass(const DenseToSparsePass &other) {} + + void runOnOperation() final; + + /// Returns the command-line argument attached to this pass. + static llvm::StringRef GetArgument() { return "tfl-dense-to-sparse"; } + + static llvm::StringRef GetDescription() { + return "Convert dense tensor to sparse format."; + } + + /// Returns the derived pass name. + static llvm::StringRef GetName() { return "DenseToSparsePass"; } + + /// Return the dialect that must be loaded in the context before this pass. + void getDependentDialects(::mlir::DialectRegistry ®istry) const override { + registry.insert(); + } + + /// Explicitly declare the TypeID for this class. We declare an explicit + /// private instantiation because Pass classes should only be visible by the + /// current library. + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DenseToSparsePass) +}; + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_DENSE_TO_SPARSE_PASS_H_ diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_jax_random.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_jax_random.cc index 9b0a80a4f92a71..17a45356ac53d1 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_jax_random.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_jax_random.cc @@ -29,7 +29,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 78f52b63f09243..dddd2205271621 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -36,7 +36,7 @@ limitations under the License. #include "llvm/Support/Threading.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index c1cb138b2cadd0..ca4ac9f43fe70a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -1441,6 +1441,23 @@ def MatchGelu : Pat< (HasOneUse $mul_out1), ]>; +// For Gelu, replaces +// 0.5 * x * ( erfc( -x * sqrt_1_2 ) ) +def MatchGeluWithErfc : Pat< + (TFL_MulOp + (TFL_MulOp:$mul_out $arg0, (Arith_ConstantOp F32ElementsAttr:$Cst_1_2), TFL_AF_None), + (TF_ErfcOp:$erfc_out + (TFL_MulOp:$mul_out1 + (TFL_NegOp $arg0), + (Arith_ConstantOp F32ElementsAttr:$Cst_sqrt_1_2), TFL_AF_None)), TFL_AF_None), + (TFL_GeluOp $arg0, ConstBoolAttrFalse), + [(FloatValueEquals<"0.5"> $Cst_1_2), + (FloatValueEquals<"0.707106769"> $Cst_sqrt_1_2), + (HasOneUse $mul_out), + (HasOneUse $erfc_out), + (HasOneUse $mul_out1), + ]>; + // Fetches the output of FC op, from the provided arguments. def GetFcOutput : NativeCodeCall< "GetFcOutput(&$_builder, $0, $1, $2, $3, $4, $5, $6, $7)">; @@ -1577,7 +1594,9 @@ def FuseReshapesAroundBatchMatMulLHS: Pat< $rhs, $adj_x, $adj_y, $bool_attr), (Arith_ConstantOp $s1)), (TFL_BatchMatMulOp $input, $rhs, $adj_x, $adj_y, $bool_attr), - [(HasRank<2> $rhs), + [(HasRankAtLeast<2> $input), + (HasRankAtLeast<2> $final_shape_change), + (HasRank<2> $rhs), (HasRank<2> $initial_shape_change), (AnyStaticShapeTensor $input), (AnyStaticShapeTensor $initial_shape_change), @@ -1595,6 +1614,7 @@ def FuseReshapesAroundBatchMatMulLHS: Pat< // 1. The rank of rhs is 2 // 2. The original input reshape has a) reduction in leading broadcast dim and // b) flattening of the contracting dims. +// 3. non-broadcasting, non-contracting dims of the rhs are not flattened. def FuseOutputReshape_BatchMatMulWithFlattenedContractingDims: Pat< (TFL_ReshapeOp:$final_shape_change (TFL_BatchMatMulOp:$bmm_tmp_output @@ -1614,6 +1634,7 @@ def FuseOutputReshape_BatchMatMulWithFlattenedContractingDims: Pat< (AnyStaticShapeTensor $final_shape_change), (IsBroadcastDimEqualToOne $input), (IsBroadcastDimEqualToOne $final_shape_change), + (TrailingDimValuesEqual $bmm_tmp_output, $final_shape_change), (ContractingDimsProductEqual<2> $input, $initial_shape_change), (NonBroadcastingNonContractingLhsDimsProductEqual<0,1> $input, $initial_shape_change), (NonBroadcastingNonContractingLhsDimsProductEqual<0,1> $final_shape_change, $bmm_tmp_output)]>; diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index 5dc38cc8317003..b0c3fd41171b61 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -29,7 +29,7 @@ limitations under the License. namespace mlir { namespace quant { -class QuantizationDialect; +class QuantDialect; } namespace quantfork { class QuantizationForkDialect; @@ -181,10 +181,6 @@ std::unique_ptr> CreateDefaultQuantParamsPass(); // Creates an instance of the IdentifyDilatedConvPass. std::unique_ptr> CreateIdentifyDilatedConvPass(); -// Creates an instance of the TensorFlow Lite dialect pass to convert dense -// tensor to sparse format. -std::unique_ptr> CreateDenseToSparsePass(); - // Creates function pass to legalize TF While to TFL While. std::unique_ptr> CreateLegalizeTFWhilePass(); @@ -267,7 +263,6 @@ std::unique_ptr> CreatePartitionedTopologicalSortPass(); #define GEN_PASS_DECL_DEFAULTQUANTPARAMSPASS -#define GEN_PASS_DECL_DENSETOSPARSEPASS #define GEN_PASS_DECL_LEGALIZETFPASS #define GEN_PASS_DECL_LOWERSTATICTENSORLISTPASS #define GEN_PASS_DECL_MODIFYIONODESPASS diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.td b/tensorflow/compiler/mlir/lite/transforms/passes.td index 09ba813055a333..98495a024a7e38 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.td +++ b/tensorflow/compiler/mlir/lite/transforms/passes.td @@ -52,35 +52,6 @@ def DecomposeHybridQuantizationPass : Pass<"tfl-decompose-hybrid-quantization", let dependentDialects = ["TFL::TensorFlowLiteDialect"]; } -def DenseToSparsePass : Pass<"tfl-dense-to-sparse", "mlir::func::FuncOp"> { - let summary = "Convert dense tensor to sparse format."; - let description = [{ - This pass encodes sparse weights in the model in the proper format, and adds - Densify() op if necessary. The general algorithm is: - 1. Get list of operands (weights) of an op that can be sparse. - 2. Get list of supported block configurations of the op. - 3. Calculate random sparsity of the weight. - 3.1. If sparsity level is below the encoding threshold, keep in dense. - 3.2. If sparsity level is above the encoding threshold, go to 4. - 4. Try to encode the weight with supported block configurations. If the - weight was pruned with the same block config, the blocked sparsity level - should match the random sparsity. - 4.1. Return the matching block config if found. - 4.2. If no matching block config is found, encode the weight with random - sparsity, and add Densify() op to fall back to dense execution. - }]; - let constructor = "CreateDenseToSparsePass()"; - let dependentDialects = ["TFL::TensorFlowLiteDialect"]; - let options = [ - Option<"default_min_", "default-min", "double", "-1.0", - "Default minimum value for TFLite quantization">, - Option<"default_max_", "default-max", "double", "1.0", - "Default maximum value for TFLite quantization">, - Option<"is_signed_", "is-signed", "bool", "false", - "Is the corresponding integer signed">, - ]; -} - def IdentifyDilatedConvPass : Pass<"tfl-identify-dilated-conv", "mlir::func::FuncOp"> { let summary = "Convert dense tensor to sparse format."; let constructor = "CreateIdentifyDilatedConvPass()"; @@ -115,7 +86,7 @@ def LegalizeTFPass : Pass<"tfl-legalize-tf", "mlir::func::FuncOp"> { let summary = "Legalize from TensorFlow to TensorFlow Lite dialect."; let constructor = "CreateLegalizeTFPass()"; let dependentDialects = ["TFL::TensorFlowLiteDialect" , - "quant::QuantizationDialect", + "quant::QuantDialect", "quantfork::QuantizationForkDialect" ]; let options = [ @@ -272,7 +243,7 @@ def PrepareQuantizePass : Pass<"tfl-prepare-quantize", "mlir::func::FuncOp"> { let summary = "Remove qdq from input and output nodes after quantization."; let constructor = "CreatePrepareQuantizePass()"; let dependentDialects = ["TFL::TensorFlowLiteDialect", - "quant::QuantizationDialect", + "quant::QuantDialect", "quantfork::QuantizationForkDialect" ]; let options = [ @@ -303,7 +274,7 @@ def PrepareDynamicRangeQuantizePass : Pass<"tfl-prepare-quantize-dynamic-range", let summary = "Prepare TFL dialect for dynamic range quantization."; let constructor = "CreatePrepareDynamicRangeQuantizePass()"; let dependentDialects = ["TFL::TensorFlowLiteDialect", - "mlir::quant::QuantizationDialect", + "mlir::quant::QuantDialect", "mlir::quantfork::QuantizationForkDialect" ]; let options = [ @@ -329,7 +300,7 @@ def PrepareTFPass : Pass<"tfl-prepare-tf", "mlir::func::FuncOp"> { let summary = "Prepare TF for legalization to TensorFlow Lite dialect."; let constructor = "CreatePrepareTFPass()"; let dependentDialects = ["TFL::TensorFlowLiteDialect", - "mlir::quant::QuantizationDialect", + "mlir::quant::QuantDialect", "mlir::quantfork::QuantizationForkDialect", "mhlo::MhloDialect" ]; @@ -350,7 +321,7 @@ def QuantizePass : Pass<"tfl-quantize", "mlir::func::FuncOp"> { let summary = "Apply quantization on models in TensorFlow Lite dialect."; let constructor = "CreateDefaultQuantizePass()"; let dependentDialects = ["TFL::TensorFlowLiteDialect", - "mlir::quant::QuantizationDialect", + "mlir::quant::QuantDialect", "mlir::quantfork::QuantizationForkDialect" ]; diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index 00f3e1f81acd5d..03c52e6baf04ac 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -32,8 +32,8 @@ limitations under the License. #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc index 0b823844aa4a58..67b8583f3e4e59 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include "llvm/Support/CommandLine.h" -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h index ee65d3c21a0bc5..824976e39953d5 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h @@ -29,7 +29,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index e67e0e45961117..ea3e873d464958 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -45,7 +45,7 @@ limitations under the License. #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" // from @llvm-project #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -64,7 +64,8 @@ limitations under the License. #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" #include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" @@ -1613,6 +1614,83 @@ class QuantizeConcatResult : public OpRewritePattern { bool use_fake_quant_num_bits_; }; +// Quantizes Mean ops where the inputs are quantized with fake quant but the +// result is not explicitly quantized. Propagating the quant parameters from the +// input to the output allow proper quantization later. +// Note that this pass is intended to work around a shortcoming of TF QAT in +// which some models do not have FQ ops generated for the output of this op. +class QuantizeMeanResult : public OpRewritePattern { + public: + QuantizeMeanResult(MLIRContext *context, bool use_fake_quant_num_bits) + : OpRewritePattern(context), + use_fake_quant_num_bits_(use_fake_quant_num_bits) {} + + LogicalResult matchAndRewrite(TF::MeanOp mean, + PatternRewriter &rewriter) const override { + // Skip ops where the output is already quantized. + for (auto *user : mean->getUsers()) { + if (mlir::dyn_cast_or_null(user) || + mlir::dyn_cast_or_null(user)) { + return failure(); + } + } + + // At this point, all pre-existing FakeQuantWithMinMaxVarsOps should have + // had qdq ops generated so we'll need to follow up the chain to get to the + // fake quants. + Value operand_value = mean.getInput(); + auto dq = mlir::dyn_cast_or_null( + operand_value.getDefiningOp()); + + if (!dq) { + return failure(); + } + + auto q = + mlir::dyn_cast_or_null(dq.getInput().getDefiningOp()); + + if (!q) { + return failure(); + } + + auto fq = mlir::dyn_cast_or_null( + q.getInput().getDefiningOp()); + + if (!fq) { + return failure(); + } + + Value mean_result = mean.getResult(); + llvm::SmallVector uses; + for (OpOperand &use : mean_result.getUses()) { + uses.push_back(&use); + } + + llvm::SmallVector inputs{mean_result, fq.getMin(), fq.getMax()}; + + rewriter.setInsertionPointAfter(mean.getOperation()); + auto new_fake_quant_op = rewriter.create( + mean.getLoc(), mean->getResultTypes(), inputs, fq->getAttrs()); + + for (OpOperand *use : uses) { + use->assign(new_fake_quant_op); + } + + // Rather than directly generating qdq ops ourselves we leverage existing + // logic to do it for us. + (void)InsertTFLQuantOpsAfterTFFakeQuantOp< + TF::FakeQuantWithMinMaxVarsOp, /*PerAxis=*/false, + FetchConstantMinMaxInputs>( + use_fake_quant_num_bits_) + .matchAndRewrite(new_fake_quant_op, rewriter); + + return success(); + } + + private: + bool use_fake_quant_num_bits_; +}; + void PrepareTFPass::runOnOperation() { MLIRContext *ctx = &getContext(); RewritePatternSet patterns(ctx); @@ -1687,6 +1765,7 @@ void PrepareTFPass::runOnOperation() { (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns)); phase_3_patterns.add(ctx, use_fake_quant_num_bits_); + phase_3_patterns.add(ctx, use_fake_quant_num_bits_); (void)applyPatternsAndFoldGreedily(func, std::move(phase_3_patterns)); } diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize.cc b/tensorflow/compiler/mlir/lite/transforms/quantize.cc index e41f98af795347..cbd64278b925bf 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/quantize.cc @@ -25,8 +25,8 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/transforms/reduce_while_operands.cc b/tensorflow/compiler/mlir/lite/transforms/reduce_while_operands.cc index ab03af3a4c062a..4cb6313201a1e1 100644 --- a/tensorflow/compiler/mlir/lite/transforms/reduce_while_operands.cc +++ b/tensorflow/compiler/mlir/lite/transforms/reduce_while_operands.cc @@ -25,7 +25,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc b/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc index 0a450f9c28152b..d3ca3179b2b818 100644 --- a/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc @@ -33,7 +33,7 @@ limitations under the License. #include "llvm/ADT/bit.h" #include "llvm/Support/Endian.h" #include "llvm/Support/ErrorHandling.h" -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.h b/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.h index 135ddb1faef32e..477c5c678fd7fe 100644 --- a/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.h @@ -26,7 +26,7 @@ limitations under the License. #include "absl/base/attributes.h" #include "absl/meta/type_traits.h" #include "absl/status/statusor.h" -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/AffineMap.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/utils/variables_utils.cc b/tensorflow/compiler/mlir/lite/utils/variables_utils.cc index dca30aee7fe606..0cab3ff3db32fd 100644 --- a/tensorflow/compiler/mlir/lite/utils/variables_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/variables_utils.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/variables_utils.h" #include "llvm/Support/Casting.h" -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc index 8d9802deaeaa66..6b0f902735754d 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -38,10 +38,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" +#include "tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/function_optimization_registry.h" @@ -237,8 +237,8 @@ Status MlirFunctionOptimizationPass::Run( tensorflow::metrics::GetGraphOptimizationCounter(), {kTfMlirCategory, "convert_graph_to_mlir"}); - auto module_ref_status = ConvertGraphToMlir(**graph, debug_info, *flib_def, - import_config, &context); + auto module_ref_status = tensorflow::tf2xla::v2::ConvertGraphToTfExecutor( + **graph, debug_info, *flib_def, import_config, &context); mlir_function_pass_graph_conversion_count ->GetCell(absl::StatusCodeToString(module_ref_status.status().code())) ->IncrementBy(1); @@ -414,7 +414,7 @@ Status MlirV1CompatGraphOptimizationPass::Run( // session runtime. import_config.restrict_functionalization_to_compiled_nodes = true; - auto module_ref_status = ConvertGraphToMlir( + auto module_ref_status = tensorflow::tf2xla::v2::ConvertGraphToTfExecutor( **options.graph, debug_info, *options.flib_def, import_config, &context); if (!module_ref_status.ok()) { if (pass_state == MlirOptimizationPassState::Enabled) { diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc index 720616309afe38..40dd3bdc04b0f0 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints_test.cc @@ -316,7 +316,7 @@ TEST_F(AttrsAndConstraintsTest, I64ArrayInI32RangeAreCastedCorrectly) { TEST_F(AttrsAndConstraintsTest, CastingFailsForI64ArrayUnderI32Range) { const int64_t under_min_i32 = -2147483658; - ArrayRef array_i64{under_min_i32}; + ArrayRef array_i64(under_min_i32); EXPECT_EQ(under_min_i32, llvm::minIntN(32) - 10); EXPECT_TRUE(failed(CastI64ArrayToI32(array_i64))); @@ -324,7 +324,7 @@ TEST_F(AttrsAndConstraintsTest, CastingFailsForI64ArrayUnderI32Range) { TEST_F(AttrsAndConstraintsTest, CastingFailsForI64ArrayAboveI32Range) { const int64_t below_max_i32 = 2147483657; - ArrayRef array_i64{below_max_i32}; + ArrayRef array_i64(below_max_i32); EXPECT_EQ(below_max_i32, llvm::maxIntN(32) + 10); EXPECT_TRUE(failed(CastI64ArrayToI32(array_i64))); diff --git a/tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.cc b/tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.cc index 292e0eeb3cce71..7df20e5cd9cb23 100644 --- a/tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.cc +++ b/tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.cc @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h b/tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h index ab89c235ff99ad..47a355228bbdf8 100644 --- a/tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h +++ b/tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h @@ -44,7 +44,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_FAKEQUANTSUPPORT_H_ #define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_FAKEQUANTSUPPORT_H_ -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/common/ir/QuantOps.cc b/tensorflow/compiler/mlir/quantization/common/ir/QuantOps.cc index a6849d5f532319..2274dda2f83fc8 100644 --- a/tensorflow/compiler/mlir/quantization/common/ir/QuantOps.cc +++ b/tensorflow/compiler/mlir/quantization/common/ir/QuantOps.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "llvm/ADT/STLExtras.h" -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -32,7 +32,7 @@ namespace mlir::quant::ir { using mlir::quant::QuantizedType; -void QuantDialect::initialize() { +void TFQuantDialect::initialize() { addOperations< #define GET_OP_LIST #include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.cc.inc" diff --git a/tensorflow/compiler/mlir/quantization/common/ir/QuantOps.td b/tensorflow/compiler/mlir/quantization/common/ir/QuantOps.td index 0c7a62200dd0bb..e17e25171579b0 100644 --- a/tensorflow/compiler/mlir/quantization/common/ir/QuantOps.td +++ b/tensorflow/compiler/mlir/quantization/common/ir/QuantOps.td @@ -22,17 +22,48 @@ limitations under the License. #ifndef QUANTIZATION_OPS #define QUANTIZATION_OPS -include "mlir/Dialect/Quant/QuantOpsBase.td" +include "mlir/Dialect/Quant/IR/QuantBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "tensorflow/compiler/mlir/quantization/common/ir/QuantOpsBase.td" +class quant_TypedPrimitiveOrContainer : + Type.predicate, + VectorOf<[etype]>.predicate]>, + "primitive/tensor/vector of " # etype.summary>; + +// A primitive type that can represent a real value. This is either a +// floating point value or a quantized type. +def quant_RealPrimitiveType : + Type, + "real valued primitive (float or quantized type)">; + +// A primitive type that can represent a storage value. This is either an +// integer or quantized type. +def quant_StoragePrimitiveType : + Type, + "quantized storage primitive (integer or quantized type)">; + +// A primitive or container of RealPrimitiveType. +def quant_RealValueType : + quant_TypedPrimitiveOrContainer; + +// A primitive or container of StoragePrimitiveType. +def quant_StorageValueType : + quant_TypedPrimitiveOrContainer; + +// Either a real valued or storage primitive or container type. +def quant_RealOrStorageValueType : + Type, + "real valued or storage primitive or container type">; + //===----------------------------------------------------------------------===// // Base classes //===----------------------------------------------------------------------===// class Quantization_Op traits> : - Op; + Op; //===----------------------------------------------------------------------===// // Quantization casts diff --git a/tensorflow/compiler/mlir/quantization/common/ir/QuantOpsBase.td b/tensorflow/compiler/mlir/quantization/common/ir/QuantOpsBase.td index fb762f933d6f00..9e8a4c35f31b0e 100644 --- a/tensorflow/compiler/mlir/quantization/common/ir/QuantOpsBase.td +++ b/tensorflow/compiler/mlir/quantization/common/ir/QuantOpsBase.td @@ -24,7 +24,7 @@ limitations under the License. include "mlir/IR/OpBase.td" -def Quant_Dialect : Dialect { +def TF_Quant_Dialect : Dialect { let name = "quantization"; let cppNamespace = "::mlir::quant::ir"; } diff --git a/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.cc b/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.cc index c0509bb8243bfc..3d5535791f31f4 100644 --- a/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.cc +++ b/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.cc @@ -22,7 +22,7 @@ limitations under the License. #include #include -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h b/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h index c0c6c30e0d6e58..5a5b7d6510042d 100644 --- a/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h +++ b/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h @@ -26,7 +26,7 @@ limitations under the License. #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/APSInt.h" -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization.td b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization.td index 0305f7921bd234..0f9b6a74762f9b 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization.td +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization.td @@ -20,17 +20,24 @@ limitations under the License. #define TF_Quantization include "mlir/IR/OpBase.td" -include "mlir/Dialect/Quant/QuantOpsBase.td" +include "mlir/Dialect/Quant/IR/QuantBase.td" //===----------------------------------------------------------------------===// // QuantizedType definitions. //===----------------------------------------------------------------------===// -// The base class of a quantized type. +// The base class of a quantized type. Signed quantized types may be expressed +// as signless integers (i.e. up to op interpretation), but we include an +// explicit signedness check to differentiate the signed/unsigned constraints +// predicates from one another at the TD level. class QuantizedType params, bit signed> : Type()">, CPred<"$_self.cast()" # - ".getStorageTypeIntegralWidth() == " # !head(params)>]>, + ".getStorageTypeIntegralWidth() == " # !head(params)>, + Or<[CPred<"$_self.cast()" # + ".getStorageType().isSignlessInteger()">, + CPred<"$_self.cast()" # + ".getStorageType().isSignedInteger() == " # signed>]>]>, "Q" # !if (signed, "I", "UI") # !head(params) # " type"> { string name = n; string asTraitArgsStr = diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc index 7645177160fc62..09860eab3ef340 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc @@ -24,11 +24,10 @@ limitations under the License. #include #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.h b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.h index 070ecb75f5db5b..43edaab968dc20 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.h +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.h @@ -26,7 +26,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver_test.cc b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver_test.cc index f017054cbe7044..b2e6e39f641e6b 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver_test.cc @@ -26,8 +26,7 @@ limitations under the License. #include "llvm/ADT/DenseMap.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h index fd62f131e8dca1..e93cc4cfb46c72 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h @@ -22,7 +22,7 @@ limitations under the License. #include #include -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.cc b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.cc index 5223b6200fb5a8..29b14dc98dd836 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.cc +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.cc @@ -33,7 +33,7 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project @@ -209,10 +209,63 @@ bool IsOpQuantizable(Operation* op) { op->getAttrOfType(kQuantTraitAttrName).getValue().str() == QuantTraitValues[QuantizationTrait::FullyQuantizable]; + const bool attr_output_quantized = QuantizableOpSupportsFloatOutputType(op); + const bool trait_enforced_quantizable = op->hasTrait(); - return attr_enforced_quantizable || trait_enforced_quantizable; + return attr_enforced_quantizable || trait_enforced_quantizable || + attr_output_quantized; +} + +// Checks if an op has specific attributes that enable quantized inputs with +// float outputs. +bool QuantizableOpSupportsFloatOutputType(Operation* op) { + static constexpr char kOutputTypes[] = "_output_types"; + static constexpr char kSupportOutputTypeFloat[] = + "_support_output_type_float_in_quantized_op"; + + if (!(op->hasAttrOfType(kOutputQuantized) && + op->getAttrOfType(kOutputQuantized).getValue())) { + return false; + } + + if (!(op->hasAttrOfType(kSupportOutputTypeFloat) && + op->getAttrOfType(kSupportOutputTypeFloat) + .getValue())) { + return false; + } + + if (!op->hasAttrOfType(kOutputTypes)) { + return false; + } + + auto output_types_attr = op->getAttrOfType(kOutputTypes); + + if (output_types_attr.size() != op->getResultTypes().size()) { + return false; + } + + for (const auto [attr_element, result_type] : + llvm::zip_equal(output_types_attr, op->getResultTypes())) { + auto type_attr = mlir::dyn_cast_or_null(attr_element); + + if (!type_attr) { + return false; + } + + auto tensor_type = mlir::dyn_cast_or_null(result_type); + + if (!tensor_type) { + return false; + } + + if (type_attr.getValue() != tensor_type.getElementType()) { + return false; + } + } + + return true; } // Returns the quantized type for the diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h index 3f9f56d45fbaa7..94169e3e9436c1 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h @@ -38,7 +38,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project @@ -48,6 +48,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project @@ -80,6 +81,7 @@ inline constexpr char kQuantTraitAttrName[] = "_tfl_quant_trait"; enum QuantizationTrait { FullyQuantizable = 0, NotQuantizable = 1 }; inline constexpr absl::string_view QuantTraitValues[] = {"fully_quantizable", "not_quantizable"}; +inline constexpr char kOutputQuantized[] = "_output_quantized"; inline constexpr double kNearZeroTolerance = 1.0e-6; @@ -194,6 +196,7 @@ quant::QuantizedType DownCastScale(quant::QuantizedType type, double min, double max, Location loc); bool IsOpQuantizable(Operation* op); +bool QuantizableOpSupportsFloatOutputType(Operation* op); // Specialized version of location to string for flatbuffer exported locations. inline std::string GetTensorNameFromLoc(Location loc) { @@ -539,74 +542,90 @@ class QuantizationPattern : public RewritePattern { } } - // Collect all the quantized outputs and replace them by the results of - // the new quantized op. - llvm::SmallDenseMap outputs_replaced; - SmallVector output_types; - output_types.reserve(quantizing_op->getNumResults()); - for (const auto& enumerated_result : - llvm::enumerate(quantizing_op->getResults())) { - Value result = enumerated_result.value(); - Type result_type = result.getType(); - // Add this to the test coverage once we create test ops with none type - // results. - if (result_type.isa()) { - outputs_replaced.insert({result, enumerated_result.index()}); - output_types.push_back(result_type); - continue; + Operation* quantized_op; + if (QuantizableOpSupportsFloatOutputType(quantizing_op)) { + rewriter.setInsertionPointAfter(quantizing_op); + OperationState new_state( + quantizing_op->getLoc(), quantizing_op->getName().getStringRef(), + inputs, quantizing_op->getResultTypes(), quantizing_op->getAttrs()); + for (const auto& indexed_regions : + llvm::enumerate(quantizing_op->getRegions())) { + Region* target_region = new_state.addRegion(); + IRMapping mapping; + indexed_regions.value().cloneInto(target_region, mapping); } - Type result_ele_type = - result.getType().cast().getElementType(); - // If the user is the QuantizeOp, it must be the only user. - if (result.hasOneUse() && - llvm::isa(*result.user_begin())) { - auto user = llvm::cast(*result.user_begin()); - outputs_replaced.insert( - {user.getResult(), enumerated_result.index()}); - output_types.push_back(user.getType()); - is_operand_or_result_modified = true; - } else if (!result_ele_type.isF32()) { - // If the result is an integer tensor, then it doesn't require the - // D op in the pattern. - outputs_replaced.insert({result, enumerated_result.index()}); - output_types.push_back(result.getType()); - } else if (static_cast(this) - ->AllowDynamicRangeQuantizedResult(quantizing_op, - custom_map)) { - outputs_replaced.insert({result, enumerated_result.index()}); - output_types.push_back(result.getType()); - } else { - return failure(); + quantized_op = rewriter.create(new_state); + rewriter.replaceOp(quantizing_op, quantized_op); + } else { + // Collect all the quantized outputs and replace them by the results of + // the new quantized op. + llvm::SmallDenseMap outputs_replaced; + SmallVector output_types; + output_types.reserve(quantizing_op->getNumResults()); + for (const auto& enumerated_result : + llvm::enumerate(quantizing_op->getResults())) { + Value result = enumerated_result.value(); + Type result_type = result.getType(); + // Add this to the test coverage once we create test ops with none + // type results. + if (result_type.isa()) { + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result_type); + continue; + } + Type result_ele_type = + result.getType().cast().getElementType(); + // If the user is the QuantizeOp, it must be the only user. + if (result.hasOneUse() && + llvm::isa(*result.user_begin())) { + auto user = llvm::cast(*result.user_begin()); + outputs_replaced.insert( + {user.getResult(), enumerated_result.index()}); + output_types.push_back(user.getType()); + is_operand_or_result_modified = true; + } else if (!result_ele_type.isF32()) { + // If the result is an integer tensor, then it doesn't require the + // D op in the pattern. + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result.getType()); + } else if (static_cast(this) + ->AllowDynamicRangeQuantizedResult(quantizing_op, + custom_map)) { + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result.getType()); + } else { + return failure(); + } } - } - // For float16 quantization if none of the operand or result is modified, - // replacing the op. See b/335025403. - if (inference_type == tensorflow::DT_HALF && - !is_operand_or_result_modified) { - return failure(); - } + // For float16 quantization if none of the operand or result is + // modified, replacing the op. See b/335025403. + if (inference_type == tensorflow::DT_HALF && + !is_operand_or_result_modified) { + return failure(); + } - rewriter.setInsertionPointAfter(quantizing_op); - OperationState new_state(quantizing_op->getLoc(), - quantizing_op->getName().getStringRef(), inputs, - output_types, quantizing_op->getAttrs()); - for (int i = 0; i < quantizing_op->getNumRegions(); ++i) { - new_state.addRegion(); - } - Operation* quantized_op = rewriter.create(new_state); - if (quantizing_op->getNumRegions() != 0) { - for (const auto& indexed_regions : - llvm::enumerate(quantizing_op->getRegions())) { - Region& target_region = - quantized_op->getRegion(indexed_regions.index()); - IRMapping mapping; - indexed_regions.value().cloneInto(&target_region, mapping); + rewriter.setInsertionPointAfter(quantizing_op); + OperationState new_state( + quantizing_op->getLoc(), quantizing_op->getName().getStringRef(), + inputs, output_types, quantizing_op->getAttrs()); + for (int i = 0; i < quantizing_op->getNumRegions(); ++i) { + new_state.addRegion(); + } + quantized_op = rewriter.create(new_state); + if (quantizing_op->getNumRegions() != 0) { + for (const auto& indexed_regions : + llvm::enumerate(quantizing_op->getRegions())) { + Region& target_region = + quantized_op->getRegion(indexed_regions.index()); + IRMapping mapping; + indexed_regions.value().cloneInto(&target_region, mapping); + } + } + for (auto output : outputs_replaced) { + output.getFirst().replaceAllUsesWith( + quantized_op->getResult(output.getSecond())); } - } - for (auto output : outputs_replaced) { - output.getFirst().replaceAllUsesWith( - quantized_op->getResult(output.getSecond())); } // To verify the numericals, the original floating-point ops are diff --git a/tensorflow/compiler/mlir/quantization/common/test_base.h b/tensorflow/compiler/mlir/quantization/common/test_base.h index 4564de6c3d5603..f33e586c100d5f 100644 --- a/tensorflow/compiler/mlir/quantization/common/test_base.h +++ b/tensorflow/compiler/mlir/quantization/common/test_base.h @@ -22,7 +22,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -52,7 +52,7 @@ class QuantizationTestBase : public Test { arith::ArithDialect, mlir::stablehlo::StablehloDialect, func::FuncDialect, TF::TensorFlowDialect, TFL::TensorFlowLiteDialect, tf_saved_model::TensorFlowSavedModelDialect, - tf_executor::TensorFlowExecutorDialect, quant::QuantizationDialect, + tf_executor::TensorFlowExecutorDialect, quant::QuantDialect, quantfork::QuantizationForkDialect>(); } diff --git a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.cc b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.cc index 7f66d76798acfa..fb7386d351209e 100644 --- a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.cc +++ b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.cc @@ -19,7 +19,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #include "llvm/Support/MathExtras.h" -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h index e30db98a9616de..99815f73104da3 100644 --- a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h +++ b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h @@ -17,7 +17,7 @@ limitations under the License. #include -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types_test.cc b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types_test.cc index d4055b1732b1d8..6fa22395b2cb91 100644 --- a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include #include "absl/strings/string_view.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project @@ -46,7 +46,7 @@ using ::testing::Test; class CreateI8F32UniformQuantizedTypeTest : public Test { protected: CreateI8F32UniformQuantizedTypeTest() : ctx_() { - ctx_.loadDialect(); + ctx_.loadDialect(); } MLIRContext ctx_; @@ -106,7 +106,7 @@ TEST_F(CreateI8F32UniformQuantizedTypeTest, HasScaleAndZeroPointProperlySet) { class CreateI32F32UniformQuantizedTypeTest : public Test { protected: CreateI32F32UniformQuantizedTypeTest() : ctx_() { - ctx_.loadDialect(); + ctx_.loadDialect(); } MLIRContext ctx_; @@ -161,7 +161,7 @@ TEST_F(CreateI32F32UniformQuantizedTypeTest, HasScaleAndZeroPointProperlySet) { class CreateI8F32UniformQuantizedPerAxisTypeTest : public Test { protected: CreateI8F32UniformQuantizedPerAxisTypeTest() : ctx_() { - ctx_.loadDialect(); + ctx_.loadDialect(); } MLIRContext ctx_; @@ -256,7 +256,7 @@ TEST_F(CreateI8F32UniformQuantizedPerAxisTypeTest, class CreateI32F32UniformQuantizedPerAxisTypeTest : public Test { protected: CreateI32F32UniformQuantizedPerAxisTypeTest() : ctx_() { - ctx_.loadDialect(); + ctx_.loadDialect(); } MLIRContext ctx_; @@ -328,7 +328,7 @@ TEST_F(CreateI32F32UniformQuantizedPerAxisTypeTest, class IsI8F32UniformQuantizedTypeTest : public Test { protected: IsI8F32UniformQuantizedTypeTest() : builder_(&ctx_) { - ctx_.loadDialect(); + ctx_.loadDialect(); } MLIRContext ctx_; @@ -371,7 +371,7 @@ TEST_F(IsI8F32UniformQuantizedTypeTest, ExpressedTypeF32Succeeds) { class IsI8F32UniformQuantizedPerAxisTypeTest : public Test { protected: IsI8F32UniformQuantizedPerAxisTypeTest() : builder_(&ctx_) { - ctx_.loadDialect(); + ctx_.loadDialect(); } MLIRContext ctx_; @@ -429,7 +429,7 @@ TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, ExpressedTypeF32Succeeds) { class IsI32F32UniformQuantizedTypeTest : public Test { protected: IsI32F32UniformQuantizedTypeTest() : builder_(&ctx_) { - ctx_.loadDialect(); + ctx_.loadDialect(); } MLIRContext ctx_; @@ -483,7 +483,7 @@ TEST_F(IsI32F32UniformQuantizedTypeTest, ExpressedTypeF32Succeeds) { class IsI32F32UniformQuantizedPerAxisTypeTest : public Test { protected: IsI32F32UniformQuantizedPerAxisTypeTest() : builder_(&ctx_) { - ctx_.loadDialect(); + ctx_.loadDialect(); } MLIRContext ctx_; @@ -556,7 +556,7 @@ TEST_F(IsI32F32UniformQuantizedPerAxisTypeTest, ExpressedTypeF32Succeeds) { class IsSupportedByTfliteQuantizeOrDequantizeOpsTest : public Test { protected: IsSupportedByTfliteQuantizeOrDequantizeOpsTest() : builder_(&ctx_) { - ctx_.loadDialect(); + ctx_.loadDialect(); } MLIRContext ctx_; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index 036d0680d97009..9eea3596a84296 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -144,10 +144,10 @@ cc_library( "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:regexp", "@local_tsl//tsl/platform:str_util", - "@local_tsl//tsl/protobuf:protos_all_cc", "@local_xla//xla/mlir_hlo", "@local_xla//xla/mlir_hlo:mhlo_passes", "@local_xla//xla/mlir_hlo:unfuse_batch_norm", + "@local_xla//xla/tsl/protobuf:protos_all_cc", "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_ops", "@stablehlo//:stablehlo_portable_api", @@ -340,9 +340,9 @@ cc_library( "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:attribute_importer", "@local_xla//xla/mlir_hlo", "@local_xla//xla/mlir_hlo:mhlo_passes", - "@local_xla//xla/translate/hlo_to_mhlo:attribute_importer", "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_passes", ], @@ -726,7 +726,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "quantization_options_py_pb2", -# api_version = 2, # visibility = [":internal_visibility_allowlist_package"], # deps = [":quantization_options_proto"], # ) @@ -746,7 +745,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "quantization_config_py_pb2", -# api_version = 2, # visibility = [ # ":internal_visibility_allowlist_package", # ], diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD index f850b7f4e775ff..dcf16782319d12 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD @@ -278,10 +278,8 @@ cc_library( hdrs = ["pre_calibration.h"], compatible_with = get_compatible_with_portable(), visibility = [ - "//tensorflow:__pkg__", "//tensorflow/compiler/mlir/quantization/stablehlo:__subpackages__", "//tensorflow/compiler/mlir/quantization/tensorflow:__subpackages__", - "//tensorflow/python:__pkg__", ], deps = [ ":component", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc index 5575a7516fccc9..d2acdfd64065ae 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc @@ -24,8 +24,8 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project @@ -49,8 +49,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/utils.h" +#include "xla/hlo/translate/hlo_to_mhlo/attribute_importer.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/translate/hlo_to_mhlo/attribute_importer.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc index ad0179f3c051a1..f206fbdafe739c 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/DialectRegistry.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project @@ -69,7 +69,7 @@ class ConvertTfQuantToMhloIntTest : public Test { void SetUp() override { DialectRegistry dialects; dialects.insert(); + mhlo::MhloDialect, quant::QuantDialect>(); ctx_ = std::make_unique(dialects); ctx_->loadAllAvailableDialects(); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.td index 0eaebcbd28c734..28756bd79a4dcd 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.td @@ -27,7 +27,7 @@ def ConvertTFQuantOpsToMHLO : Pass<"quant-convert-tf-quant-ops-to-mhlo", "mlir:: let constructor = "mlir::quant::stablehlo::CreateConvertTFQuantOpsToMHLOPass()"; let dependentDialects = ["TF::TensorFlowDialect", "chlo::ChloDialect", "mhlo::MhloDialect", "tf_type::TFTypeDialect", - "quant::QuantizationDialect"]; + "quant::QuantDialect"]; } def ConvertTFQuantTypes : Pass<"convert-tf-quant-types", "mlir::func::FuncOp"> { @@ -53,7 +53,7 @@ def VerifyQuantLegalization : Pass<"verify-quant-legalization", "mlir::func::Fun let constructor = "mlir::quant::stablehlo::CreateVerifyQuantLegalizationPass()"; let dependentDialects = ["tf_type::TFTypeDialect", - "quant::QuantizationDialect"]; + "quant::QuantDialect"]; } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/verify_quant_legalization.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/verify_quant_legalization.cc index 7484ed89aa51b1..ad8798b9695c9d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/verify_quant_legalization.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/verify_quant_legalization.cc @@ -24,8 +24,8 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_xla_call_module_op_to_bfloat16.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_xla_call_module_op_to_bfloat16.cc index 1a05364991ccc6..38d70269eb0130 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_xla_call_module_op_to_bfloat16.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_xla_call_module_op_to_bfloat16.cc @@ -19,7 +19,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // IWYU pragma: keep #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project // IWYU pragma: keep #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_calibration_statistics_saver.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_calibration_statistics_saver.cc index 8cb0b645c312cf..e855c51749e6d5 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_calibration_statistics_saver.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_calibration_statistics_saver.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -180,7 +181,7 @@ CreateInsertCalibrationStatisticsSaverPass( StringRef calibration_data_dir, const std::vector& aggregator_ops_to_ignore) { InsertCalibrationStatisticsSaverPassOptions options = { - .aggregator_ops_to_ignore_ = aggregator_ops_to_ignore, + .aggregator_ops_to_ignore_ = llvm::to_vector(aggregator_ops_to_ignore), .calibration_data_dir_ = calibration_data_dir.str(), }; return std::make_unique(options); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc index 28396ec71ab07e..7583f1b0fe2be5 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc @@ -18,8 +18,8 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project // IWYU pragma: keep -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project // IWYU pragma: keep #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/merge_fusion_with_dequantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/merge_fusion_with_dequantize.cc index 9a0d8fb2a25b2b..24e148949215e8 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/merge_fusion_with_dequantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/merge_fusion_with_dequantize.cc @@ -17,7 +17,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td index 7661e8d562fbe9..da59c218a56926 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td @@ -71,7 +71,7 @@ def QuantizeCompositeFunctionsPass : Pass<"stablehlo-quantize-composite-function let dependentDialects = [ "mlir::arith::ArithDialect", "mlir::stablehlo::StablehloDialect", - "mlir::quant::QuantizationDialect", + "mlir::quant::QuantDialect", "mlir::quantfork::QuantizationForkDialect", "TF::TensorFlowDialect", ]; @@ -89,7 +89,7 @@ def PrepareQuantizePass : Pass<"stablehlo-prepare-quantize", "mlir::ModuleOp"> { ]; let dependentDialects = [ "mlir::stablehlo::StablehloDialect", - "mlir::quant::QuantizationDialect", + "mlir::quant::QuantDialect", "mlir::quantfork::QuantizationForkDialect", "mlir::arith::ArithDialect", ]; @@ -105,7 +105,7 @@ def QuantizePass : Pass<"stablehlo-quantize", "mlir::ModuleOp"> { ]; let dependentDialects = [ "mlir::stablehlo::StablehloDialect", - "mlir::quant::QuantizationDialect", + "mlir::quant::QuantDialect", "mlir::quantfork::QuantizationForkDialect", ]; } @@ -147,7 +147,7 @@ def ConvertXlaCallModuleOpToBfloat16Pass : Pass<"stablehlo-convert-xla-call-modu let summary = "Convert serialized XlaCallModuleOp to bfloat16"; let dependentDialects = [ "TF::TensorFlowDialect", - "mlir::quant::QuantizationDialect", + "mlir::quant::QuantDialect", "mlir::shape::ShapeDialect", "mlir::stablehlo::StablehloDialect", ]; @@ -192,7 +192,7 @@ def InsertWeightParamPass : Pass<"stablehlo-insert-weight-param", "mlir::func::F let dependentDialects = [ "mlir::stablehlo::StablehloDialect", "TF::TensorFlowDialect", - "mlir::quant::QuantizationDialect", + "mlir::quant::QuantDialect", "mlir::quantfork::QuantizationForkDialect", ]; } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/post_quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/post_quantize.cc index eb99e657875a7d..4052988230b108 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/post_quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/post_quantize.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "llvm/Support/Casting.h" -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc index 7d2df9e27f9220..824d24065e239b 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/prepare_quantize.cc @@ -17,7 +17,7 @@ limitations under the License. #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // IWYU pragma: keep #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc index 787fca3594f14a..350b6f786452ab 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc @@ -24,8 +24,8 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project // IWYU pragma: keep -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/BlockSupport.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h index c07314d6cff6cf..5e45d6d7d36625 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h @@ -24,7 +24,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc index 86dbae8e4181f9..1542cf181e649d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize.cc @@ -19,7 +19,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // IWYU pragma: keep #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc index a713f5501b271d..1c2c559d79ee84 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_composite_functions.cc @@ -16,7 +16,7 @@ limitations under the License. #include "absl/status/status.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project // IWYU pragma: keep #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // IWYU pragma: keep #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.td index ee525f2deead04..3139e48b840744 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.td @@ -43,7 +43,7 @@ def TestPostCalibrationComponentPass : Pass<"stablehlo-test-post-calibration-com let dependentDialects = [ "mlir::stablehlo::StablehloDialect", "mlir::TF::TensorFlowDialect", "mlir::func::FuncDialect", "mlir::mhlo::MhloDialect", - "mlir::quant::QuantizationDialect", "mlir::chlo::ChloDialect", + "mlir::quant::QuantDialect", "mlir::chlo::ChloDialect", "mlir::vhlo::VhloDialect", "mlir::shape::ShapeDialect", "mlir::quantfork::QuantizationForkDialect", ]; @@ -56,7 +56,7 @@ def TestTFToStablehloPass : Pass<"stablehlo-test-tf-to-stablehlo", "mlir::Module }]; let dependentDialects = [ "mlir::stablehlo::StablehloDialect", "mlir::TF::TensorFlowDialect", - "mlir::chlo::ChloDialect", "mlir::quant::QuantizationDialect", + "mlir::chlo::ChloDialect", "mlir::quant::QuantDialect", "mlir::mhlo::MhloDialect", "mlir::shape::ShapeDialect", "mlir::sparse_tensor::SparseTensorDialect", "mlir::vhlo::VhloDialect", ]; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_post_calibration_component.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_post_calibration_component.cc index bdf7f311f26bfa..8ad34c8ac66674 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_post_calibration_component.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_post_calibration_component.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project // IWYU pragma: keep -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // IWYU pragma: keep #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_tf_to_stablehlo_pass.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_tf_to_stablehlo_pass.cc index 3af53a213b0064..862a4b628d6497 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_tf_to_stablehlo_pass.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/test_tf_to_stablehlo_pass.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project // IWYU pragma: keep -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project // IWYU pragma: keep -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project // IWYU pragma: keep #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" // from @llvm-project // IWYU pragma: keep #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD index 5ca03bfc209656..a022cbc9e05688 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD @@ -151,6 +151,9 @@ tf_python_pybind_extension( name = "pywrap_quantization", srcs = ["pywrap_quantization.cc"], pytype_srcs = ["pywrap_quantization.pyi"], + visibility = [ + "//tensorflow/python:__pkg__", + ], # Each dependency MUST be either header-only or exclusive. deps = [ ":pywrap_quantization_lib_header_only", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc b/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc index c6a7904bc0fa3d..c14cff87984890 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc @@ -17,7 +17,7 @@ limitations under the License. #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project @@ -60,7 +60,7 @@ int main(int argc, char** argv) { mlir::tf_saved_model::TensorFlowSavedModelDialect, mlir::func::FuncDialect, mlir::shape::ShapeDialect, mlir::arith::ArithDialect, mlir::tf_type::TFTypeDialect, - mlir::quant::QuantizationDialect, mlir::tensor::TensorDialect, + mlir::quant::QuantDialect, mlir::tensor::TensorDialect, mlir::quantfork::QuantizationForkDialect, mlir::stablehlo::StablehloDialect, mlir::tf_executor::TensorFlowExecutorDialect, diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils_test.cc index 62abb400ca5b34..1acfb785dbd117 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils_test.cc @@ -22,8 +22,8 @@ limitations under the License. #include #include #include "llvm/Support/Casting.h" -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -85,7 +85,7 @@ std::unique_ptr CreateContext() { RegisterCommonToolingDialects(mlir_registry); context->appendDialectRegistry(mlir_registry); context->getOrLoadDialect(); - context->getOrLoadDialect(); + context->getOrLoadDialect(); context->getOrLoadDialect(); context->getOrLoadDialect(); return context; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD index 9c86ba1366869f..b9b5aded172925 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD @@ -463,9 +463,9 @@ cc_library( "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_tsl//tsl/platform:str_util", - "@local_tsl//tsl/protobuf:protos_all_cc", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/mlir_hlo", + "@local_xla//xla/tsl/protobuf:protos_all_cc", "@stablehlo//:chlo_ops", ], # Alwayslink is required for registering the MLIR passes. @@ -555,7 +555,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "quantization_options_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":quantization_options_proto"], # ) @@ -576,13 +575,13 @@ tf_proto_library( ":internal_visibility_allowlist_package", # To be visible from `lib_internal_impl`. "//tensorflow/core:__pkg__", + "//tensorflow/python:__pkg__", ], ) # copybara:uncomment_begin(google-only) # py_proto_library( # name = "exported_model_py_pb2", -# api_version = 2, # deps = [":exported_model_proto"], # ) # copybara:uncomment_end diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD index 7931e3cd51e9db..b1f7caa3b0b0ec 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/BUILD @@ -110,7 +110,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "calibration_statistics_py_pb2", -# api_version = 2, # deps = [ # ":calibration_statistics_proto", # ], diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD index 80167b43fb9c5e..acbb7c3b03327b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD @@ -44,6 +44,7 @@ tf_cc_test( deps = [ ":mlir_dump", "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/strings:string_view", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.cc b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.cc index 2510fb96e39591..a5a7206253c19b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.cc @@ -27,10 +27,8 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/string_view.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project @@ -39,7 +37,6 @@ limitations under the License. #include "tsl/platform/errors.h" #include "tsl/platform/file_system.h" #include "tsl/platform/path.h" -#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tsl/platform/stringpiece.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc index be49ddb7e03b08..0c6aaf2124ff47 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc @@ -19,7 +19,9 @@ limitations under the License. #include #include "absl/cleanup/cleanup.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/LogicalResult.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinDialect.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/exported_model.proto b/tensorflow/compiler/mlir/quantization/tensorflow/exported_model.proto index b238851a7e5415..673caa3268828e 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/exported_model.proto +++ b/tensorflow/compiler/mlir/quantization/tensorflow/exported_model.proto @@ -6,8 +6,6 @@ import "tensorflow/core/framework/graph.proto"; import "tensorflow/core/protobuf/meta_graph.proto"; import "tensorflow/core/protobuf/saver.proto"; -option cc_enable_arenas = true; - // Represents an exported TensorFlow model. It consists of a GraphDef and extra // metadata required for building a SavedModel. This message is primarily used // to "export" the model produced from various quantization passes in c++ to @@ -15,10 +13,11 @@ option cc_enable_arenas = true; // Next ID: 11 message ExportedModel { reserved 3, 4, 7, 9; - reserved 'variable_shared_names'; - reserved 'restore_node_name'; - reserved 'save_node_name'; - reserved 'file_prefix_tensor_name'; + + reserved "variable_shared_names"; + reserved "restore_node_name"; + reserved "save_node_name"; + reserved "file_prefix_tensor_name"; GraphDef graph_def = 1; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op.cc index 47beb9e0c2636f..64586754f6f0cd 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/types/optional.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op_test.cc index 6fea7f1cc4778a..6af57c4013f2f4 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op_test.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -49,7 +49,7 @@ TEST(TfQuantOpTest, applyUniformQuantization) { MLIRContext context; OwningOpRef module(ModuleOp::create(UnknownLoc::get(&context))); OpBuilder builder(&module->getBodyRegion()); - context.loadDialect(); EmptyPatternRewriter pattern_rewriter(builder); Value value = CreateConstValue(builder, module->getLoc(), {1024, 2}, diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc index 239fe32946ab87..bd5c2aee07e3ea 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -139,7 +139,7 @@ class AddDumpTensorOpPass void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); - registry.insert(); + registry.insert(); registry.insert(); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_custom_aggregation_op_to_quant_stats.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_custom_aggregation_op_to_quant_stats.cc index 8c02ace87d8001..29d96627fe47f7 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_custom_aggregation_op_to_quant_stats.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_custom_aggregation_op_to_quant_stats.cc @@ -20,7 +20,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/SourceMgr.h" -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -59,7 +59,7 @@ class ConvertCustomAggregationOpToQuantStatsPass void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); - registry.insert(); + registry.insert(); registry.insert(); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_fake_quant_to_qdq.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_fake_quant_to_qdq.cc index 2f2906ceb933c9..b2452461827a40 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_fake_quant_to_qdq.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_fake_quant_to_qdq.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "llvm/ADT/StringRef.h" -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" @@ -46,7 +46,7 @@ class ConvertFakeQuantToQdqPass void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); - registry.insert(); + registry.insert(); registry.insert(); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc index a87245345f6987..fa0f3b3cf77930 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc @@ -23,7 +23,7 @@ limitations under the License. #include "llvm/Support/CommandLine.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc index cad8c1686eb67b..1e92ad5fc4475b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc @@ -23,8 +23,8 @@ limitations under the License. #include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -61,7 +61,7 @@ using QuantMethod = tensorflow::quantization::QuantizationMethod::PresetMethod; class PrepareQuantizePass : public PassWrapper> { void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize_drq.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize_drq.cc index b2c0ceb205ca99..0f001dfdc8a5f6 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize_drq.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize_drq.cc @@ -24,8 +24,8 @@ limitations under the License. #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project @@ -53,7 +53,7 @@ using ::tensorflow::quantization::OpSet; class PrepareQuantizeDRQPass : public PassWrapper> { void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc index 08b2faadacd3d5..8d7193a3973a52 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "llvm/Support/CommandLine.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -64,7 +64,7 @@ using ::tensorflow::quantization::OpSet; class PreprocessOpPass : public PassWrapper> { void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc index 26e468556a36ab..6d28e3d1c793f4 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc @@ -22,7 +22,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc index 50409709d44854..b255ddda01395f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc @@ -30,8 +30,8 @@ limitations under the License. #include "llvm/Support/CommandLine.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -118,7 +118,7 @@ class QuantizeCompositeFunctionsPass } void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_weights.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_weights.cc index 6355fb1e37c8e3..0ac4ba709cde61 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_weights.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_weights.cc @@ -18,8 +18,8 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project @@ -80,7 +80,7 @@ class QuantizeWeightsPass } void getDependentDialects(DialectRegistry& registry) const override { - registry.insert(); + registry.insert(); } private: diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_opt.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_opt.cc index 87d230fb16bbde..a403f75403d4f4 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_opt.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_opt.cc @@ -14,7 +14,7 @@ limitations under the License. #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/InitAllDialects.h" // from @llvm-project @@ -39,7 +39,7 @@ int main(int argc, char **argv) { mlir::tf_saved_model::TensorFlowSavedModelDialect, mlir::func::FuncDialect, mlir::shape::ShapeDialect, mlir::arith::ArithDialect, mlir::tf_type::TFTypeDialect, - mlir::quant::QuantizationDialect, + mlir::quant::QuantDialect, mlir::quantfork::QuantizationForkDialect, mlir::tf_executor::TensorFlowExecutorDialect, mlir::stablehlo::StablehloDialect>(); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD index 2ae799cca85968..031c0896ac5081 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD @@ -1,3 +1,4 @@ +load("@local_tsl//tsl/platform:build_config_root.bzl", "if_pywrap") load("//tensorflow:pytype.default.bzl", "pytype_strict_library") load( "//tensorflow:tensorflow.default.bzl", @@ -93,6 +94,7 @@ cc_library( ":py_function_lib", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/core/protobuf:for_core_protos_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -202,6 +204,7 @@ tf_python_pybind_extension( pytype_srcs = ["pywrap_function_lib.pyi"], visibility = [ "__subpackages__", + "//tensorflow/python:__pkg__", "//tensorflow/tools/pip_package:__subpackages__", ], deps = [ @@ -229,7 +232,6 @@ tf_python_pybind_extension( # All deps must be header-only. deps = [ ":py_function_lib", - ":quantize_model_cc", ":type_casters", "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", @@ -245,7 +247,10 @@ tf_python_pybind_extension( "@pybind11_abseil//pybind11_abseil:import_status_module", "@pybind11_abseil//pybind11_abseil:status_casters", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", - ], + ] + if_pywrap( + [":quantize_model_cc_impl"], + [":quantize_model_cc"], + ), ) tf_py_strict_test( diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc index fc181edb8a75f5..499a496c572153 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/pywrap_function_lib.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "pybind11/cast.h" // from @pybind11 #include "pybind11/detail/common.h" // from @pybind11 diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc index e38310879184ef..016ba8dd41ce6f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.cc @@ -49,11 +49,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" -#include "tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h index a54e988c043aa3..9e36ce52f74cbc 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/quantize_model.h @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" namespace tensorflow { namespace quantization { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.cc index c3f5c32bdd9720..e7086c57ddc2c2 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/unfreeze_constants.cc @@ -26,7 +26,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "tensorflow/core/platform/env.h" #include "tsl/platform/errors.h" -#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto b/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto index d2c79b6ce4c668..a9ec6b554146f7 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto @@ -5,8 +5,6 @@ package tensorflow.quantization; import "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto"; import "tensorflow/core/framework/tensor.proto"; -option cc_enable_arenas = true; - // This file contains the definition of TF GraphDef-level mixed-precision // quantization configuration. The configuration will be used in the // quantization path to determine the following factors: @@ -95,9 +93,11 @@ message UnitWiseQuantizationSpec { // Type of the op, ex: Conv2D, MatMul, Einsum... The node_name field can // be omitted if it is intended to match all nodes with this type. string op_type = 1; + // Name of the node. This field accepts re2 regex format. If the node name // has enough granularity, the op_type field can be omitted. string node_name = 2; + // The function scope. If set, only ops and nodes under specified functions // are matched. Note that, Uniqueness of node name isn't guaranteed across // functions. But within each function, uniqueness is guaranteed. If users @@ -105,6 +105,7 @@ message UnitWiseQuantizationSpec { // field accepts re2 regex format. string func_name = 3; } + repeated QuantizationUnit unit = 5; // Quantization option information for the current unit. @@ -118,13 +119,17 @@ message UnitWiseQuantizationSpec { // NEXT ID: 5 enum OpSet { OP_SET_UNSPECIFIED = 0; // go/do-include-enum-unspecified + // Uses TF ops that mimic quantization behavior. Used when the corresponding // integer op is not yet present. TF = 1; + // Uses TF XLA ops XLA = 2; + // Uses TF Uniform Quantized ops UNIFORM_QUANTIZED = 3; + // Uses the StableHLO Quantizer. StableHLO Quantizer will be available as // an option in the TF Quantizer in StableHLO Quantizer v1. STABLEHLO = 4; @@ -173,7 +178,6 @@ message QuantizationOptions { // the quantization configuration for units that are not specified in // unit-wise configurations. QuantizationMethod quantization_method = 1; - OpSet op_set = 2; // If not specified, it defaults to `XLA`. // Quantization spec for each unit. This quantization spec will override the diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc index 7eb557e1d788c0..fbc28304cbc79e 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc @@ -32,8 +32,8 @@ limitations under the License. #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h" diff --git a/tensorflow/compiler/mlir/register_common_dialects.cc b/tensorflow/compiler/mlir/register_common_dialects.cc index fe626375a8ee8f..e1e228576a5aa7 100644 --- a/tensorflow/compiler/mlir/register_common_dialects.cc +++ b/tensorflow/compiler/mlir/register_common_dialects.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/register_common_dialects.h" -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project @@ -23,11 +23,13 @@ limitations under the License. #include "mlir/InitAllExtensions.h" // from @llvm-project #include "stablehlo/dialect/Register.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/mlprogram_util.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "xla/mlir/framework/ir/xla_framework.h" #include "xla/mlir_hlo/mhlo/IR/register.h" +#include "tensorflow/core/ir/types/dialect.h" namespace mlir { @@ -40,7 +42,7 @@ void RegisterCommonToolingDialects(mlir::DialectRegistry& registry) { registry.insert(); registry.insert(); - registry.insert(); + registry.insert(); registry.insert(); registry.insert(); registry.insert(); diff --git a/tensorflow/compiler/mlir/stablehlo/BUILD b/tensorflow/compiler/mlir/stablehlo/BUILD index 355fd9ecb5da21..0425d7d4300f96 100644 --- a/tensorflow/compiler/mlir/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/stablehlo/BUILD @@ -32,6 +32,11 @@ tsl_pybind_extension( "-frtti", ], features = ["-use_header_modules"], + visibility = [ + ":friends", + "//tensorflow/python:__pkg__", + "//tensorflow/tools/pip_package:__subpackages__", + ], deps = [ "//third_party/python_runtime:headers", "@llvm-project//llvm:Support", diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 1b1674c42d5a58..2cf27e9b688efb 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -654,7 +654,9 @@ cc_library( hdrs = ["utils/topological_sort.h"], deps = [ "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -769,9 +771,9 @@ cc_library( "@llvm-project//mlir:Support", "@local_xla//xla:status_macros", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:sharding_builder", + "@local_xla//xla/hlo/builder:sharding_builder", "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/service:hlo_parser", + "@local_xla//xla/hlo/parser:hlo_parser", ], ) @@ -791,8 +793,12 @@ cc_library( hdrs = ["utils/tpu_cluster_util.h"], deps = [ ":device_util", + ":tensorflow", + ":tensorflow_structs", ":tpu_rewrite_device_util", + "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", ], ) @@ -808,6 +814,7 @@ cc_library( deps = [ "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/status:statusor", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", @@ -1019,12 +1026,15 @@ cc_library( ":translate_cl_options", "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util", "//tensorflow/compiler/mlir/utils:string_container_utils", + "//tensorflow/compiler/tf2xla:layout_util", "//tensorflow/compiler/tf2xla:xla_argument", "//tensorflow/compiler/tf2xla:xla_helpers", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:errors", "//tensorflow/core/platform:status", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", @@ -1032,13 +1042,19 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:TranslateLib", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + "@local_xla//xla:shape_util", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder:xla_computation", "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape", "@local_xla//xla/service:hlo_module_config", "@local_xla//xla/service:hlo_proto_cc", - "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape", "@stablehlo//:stablehlo_ops", ], alwayslink = 1, @@ -1139,6 +1155,7 @@ cc_library( "//tensorflow/core/protobuf/tpu:topology_proto_cc", "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -1164,9 +1181,15 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core/platform:status_matchers", "//tensorflow/core/protobuf/tpu:topology_proto_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:statusor", + "@local_xla//xla:xla_data_proto_cc", ], ) @@ -1217,7 +1240,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", - "@local_tsl//tsl/lib/io:buffered_file", + "@local_xla//xla/tsl/lib/io:buffered_file", ], ) @@ -1293,10 +1316,14 @@ tf_cc_test( ":serialize_mlir_module_utils", ":tensorflow", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", + "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/platform:test", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", "@llvm-project//mlir:Transforms", ], ) @@ -1331,6 +1358,8 @@ cc_library( ], deps = [ ":dump_mlir_util", + "//tensorflow/core:lib", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -1395,9 +1424,10 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", - "@local_tsl//tsl/lib/math:math_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/service:hlo_parser", + "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/hlo/parser:hlo_parser", + "@local_xla//xla/tsl/lib/math:math_util", ], ) @@ -1416,6 +1446,7 @@ tf_cc_test( "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:statusor", + "@local_xla//xla:xla_data_proto_cc", ], ) @@ -1441,6 +1472,9 @@ cc_library( ":tensorflow_types", "//tensorflow/core/ir:shape_inference_utils", "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Support", ], ) @@ -1470,6 +1504,7 @@ cc_library( hdrs = ["utils/verify_suitable_for_graph_export.h"], deps = [ ":tensorflow", + "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", @@ -1602,6 +1637,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/protobuf/tpu:topology_proto_cc", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -1627,9 +1663,12 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/protobuf/tpu:topology_proto_cc", + "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/c/BUILD b/tensorflow/compiler/mlir/tensorflow/c/BUILD index 73b57c0338a635..a0d8c94cf47997 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/c/BUILD @@ -45,8 +45,10 @@ tf_cuda_library( "//tensorflow/core/lib/llvm_rtti", "//tensorflow/core/platform:errors", "//tensorflow/core/platform:refcount", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc index 726be6f34d9de5..86b4bd4eef46e9 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc +++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc @@ -13,14 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include #include #include #include #include +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/raw_ostream.h" @@ -56,7 +60,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h" +#include "tensorflow/core/framework/full_type.pb.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/mlir/tensorflow/tests/ici_weight_distribution_non_spmd_clustering_end_to_end.mlir b/tensorflow/compiler/mlir/tensorflow/tests/ici_weight_distribution_non_spmd_clustering_end_to_end.mlir index 7d39e68a02e046..8f0c48df73f6db 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/ici_weight_distribution_non_spmd_clustering_end_to_end.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/ici_weight_distribution_non_spmd_clustering_end_to_end.mlir @@ -4,25 +4,25 @@ // CHECK %0 = "tf.TPUCompilationResult"() {_tpu_compilation_status = "cluster__train_helper", device = ""} : () -> tensor // CHECK %1 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource>>) -> tensor<25x5xf32> // CHECK %2 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource>>) -> tensor<25x5xf32> -// CHECK %cst = "tf.Const"() <{value = dense<0.000000e+00> : tensor}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor -// CHECK %cst_0 = "tf.Const"() <{value = dense<[25, 5]> : tensor<2xi64>}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<2xi64> -// CHECK %3 = "tf.Fill"(%cst_0, %cst) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor<2xi64>, tensor) -> tensor<25x5xf32> -// CHECK %cst_1 = "tf.Const"() <{value = dense<0.000000e+00> : tensor}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor -// CHECK %cst_2 = "tf.Const"() <{value = dense<[25, 5]> : tensor<2xi64>}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<2xi64> -// CHECK %4 = "tf.Fill"(%cst_2, %cst_1) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor<2xi64>, tensor) -> tensor<25x5xf32> -// CHECK %cst_3 = "tf.Const"() <{value = dense<0.000000e+00> : tensor}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor -// CHECK %cst_4 = "tf.Const"() <{value = dense<[25, 5]> : tensor<2xi64>}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<2xi64> -// CHECK %5 = "tf.Fill"(%cst_4, %cst_3) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor<2xi64>, tensor) -> tensor<25x5xf32> -// CHECK %cst_5 = "tf.Const"() <{value = dense<0.000000e+00> : tensor}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor -// CHECK %cst_6 = "tf.Const"() <{value = dense<[25, 5]> : tensor<2xi64>}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<2xi64> -// CHECK %6 = "tf.Fill"(%cst_6, %cst_5) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor<2xi64>, tensor) -> tensor<25x5xf32> +// CHECK %cst = "tf.Const"() <{value = dense<0.000000e+00> : tensor}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor +// CHECK %cst_0 = "tf.Const"() <{value = dense<[25, 5]> : tensor<2xi64>}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<2xi64> +// CHECK %3 = "tf.Fill"(%cst_0, %cst) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<2xi64>, tensor) -> tensor<25x5xf32> +// CHECK %cst_1 = "tf.Const"() <{value = dense<0.000000e+00> : tensor}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor +// CHECK %cst_2 = "tf.Const"() <{value = dense<[25, 5]> : tensor<2xi64>}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<2xi64> +// CHECK %4 = "tf.Fill"(%cst_2, %cst_1) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<2xi64>, tensor) -> tensor<25x5xf32> +// CHECK %cst_3 = "tf.Const"() <{value = dense<0.000000e+00> : tensor}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor +// CHECK %cst_4 = "tf.Const"() <{value = dense<[25, 5]> : tensor<2xi64>}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<2xi64> +// CHECK %5 = "tf.Fill"(%cst_4, %cst_3) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<2xi64>, tensor) -> tensor<25x5xf32> +// CHECK %cst_5 = "tf.Const"() <{value = dense<0.000000e+00> : tensor}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor +// CHECK %cst_6 = "tf.Const"() <{value = dense<[25, 5]> : tensor<2xi64>}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<2xi64> +// CHECK %6 = "tf.Fill"(%cst_6, %cst_5) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<2xi64>, tensor) -> tensor<25x5xf32> // CHECK %7:8 = tf_device.replicate([%1, %3, %4, %3] as %arg2: tensor<25x5xf32>, [%2, %5, %6, %5] as %arg3: tensor<25x5xf32>) {n = 4 : i32} { // CHECK %10 = "tf_device.launch"() <{device = "TPU_REPLICATED_HOST_0"}> ({ -// CHECK %13 = "tf.Identity"(%arg2) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor<25x5xf32>) -> tensor<25x5xf32> +// CHECK %13 = "tf.Identity"(%arg2) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<25x5xf32>) -> tensor<25x5xf32> // CHECK tf_device.return %13 : tensor<25x5xf32> // CHECK }) : () -> tensor<25x5xf32> // CHECK %11 = "tf_device.launch"() <{device = "TPU_REPLICATED_HOST_0"}> ({ -// CHECK %13 = "tf.Identity"(%arg3) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor<25x5xf32>) -> tensor<25x5xf32> +// CHECK %13 = "tf.Identity"(%arg3) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<25x5xf32>) -> tensor<25x5xf32> // CHECK tf_device.return %13 : tensor<25x5xf32> // CHECK }) : () -> tensor<25x5xf32> // CHECK %12:2 = "tf_device.cluster_func"(%10, %11) <{func = @_func}> {_dynamic_arg_index = [], _has_manual_control_dependencies = true, _replication_info = "cluster__train_helper", _xla_compile_device_type = "TPU", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0], host_compute_core = [], input_sharding_configuration = ["", ""], num_cores_per_replica = 1 : i64, output_sharding_configuration = ["", ""], padding_map = [], step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\02\02\01\01\10\02\18\02\22\10\00\00\00\00\00\01\00\00\01\00\00\00\01\01\00\00*\02\08\01", tpu_compile_options_proto = "", use_spmd_for_xla_partitioning = true, use_tpu = true} : (tensor<25x5xf32>, tensor<25x5xf32>) -> (tensor<128xf32>, tensor<128xf32>) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/ici_weight_distribution_non_spmd_mlir_end_to_end.mlir b/tensorflow/compiler/mlir/tensorflow/tests/ici_weight_distribution_non_spmd_mlir_end_to_end.mlir index 482112f82689c1..b0546071dcb7e7 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/ici_weight_distribution_non_spmd_mlir_end_to_end.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/ici_weight_distribution_non_spmd_mlir_end_to_end.mlir @@ -4,32 +4,32 @@ // CHECK-LABEL: func.func @main // CHECK %outputs, %control = tf_executor.island wraps "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf_type.resource>>) -> tensor<25x5xf32> // CHECK %outputs_0, %control_1 = tf_executor.island wraps "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource>>) -> tensor<25x5xf32> -// CHECK %outputs_2, %control_3 = tf_executor.island wraps "tf.Const"() <{value = dense<0.000000e+00> : tensor}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor -// CHECK %outputs_4, %control_5 = tf_executor.island wraps "tf.Const"() <{value = dense<[25, 5]> : tensor<2xi64>}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<2xi64> -// CHECK %outputs_6, %control_7 = tf_executor.island wraps "tf.Fill"(%outputs_4, %outputs_2) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor<2xi64>, tensor) -> tensor<25x5xf32> -// CHECK %outputs_8, %control_9 = tf_executor.island wraps "tf.Const"() <{value = dense<0.000000e+00> : tensor}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor -// CHECK %outputs_10, %control_11 = tf_executor.island wraps "tf.Const"() <{value = dense<[25, 5]> : tensor<2xi64>}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<2xi64> -// CHECK %outputs_12, %control_13 = tf_executor.island wraps "tf.Fill"(%outputs_10, %outputs_8) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor<2xi64>, tensor) -> tensor<25x5xf32> -// CHECK %outputs_14, %control_15 = tf_executor.island wraps "tf.Const"() <{value = dense<0.000000e+00> : tensor}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor -// CHECK %outputs_16, %control_17 = tf_executor.island wraps "tf.Const"() <{value = dense<[25, 5]> : tensor<2xi64>}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<2xi64> -// CHECK %outputs_18, %control_19 = tf_executor.island wraps "tf.Fill"(%outputs_16, %outputs_14) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor<2xi64>, tensor) -> tensor<25x5xf32> -// CHECK %outputs_20, %control_21 = tf_executor.island wraps "tf.Const"() <{value = dense<0.000000e+00> : tensor}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor -// CHECK %outputs_22, %control_23 = tf_executor.island wraps "tf.Const"() <{value = dense<[25, 5]> : tensor<2xi64>}> {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<2xi64> -// CHECK %outputs_24, %control_25 = tf_executor.island wraps "tf.Fill"(%outputs_22, %outputs_20) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor<2xi64>, tensor) -> tensor<25x5xf32> +// CHECK %outputs_2, %control_3 = tf_executor.island wraps "tf.Const"() <{value = dense<0.000000e+00> : tensor}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor +// CHECK %outputs_4, %control_5 = tf_executor.island wraps "tf.Const"() <{value = dense<[25, 5]> : tensor<2xi64>}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<2xi64> +// CHECK %outputs_6, %control_7 = tf_executor.island wraps "tf.Fill"(%outputs_4, %outputs_2) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<2xi64>, tensor) -> tensor<25x5xf32> +// CHECK %outputs_8, %control_9 = tf_executor.island wraps "tf.Const"() <{value = dense<0.000000e+00> : tensor}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor +// CHECK %outputs_10, %control_11 = tf_executor.island wraps "tf.Const"() <{value = dense<[25, 5]> : tensor<2xi64>}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<2xi64> +// CHECK %outputs_12, %control_13 = tf_executor.island wraps "tf.Fill"(%outputs_10, %outputs_8) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<2xi64>, tensor) -> tensor<25x5xf32> +// CHECK %outputs_14, %control_15 = tf_executor.island wraps "tf.Const"() <{value = dense<0.000000e+00> : tensor}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor +// CHECK %outputs_16, %control_17 = tf_executor.island wraps "tf.Const"() <{value = dense<[25, 5]> : tensor<2xi64>}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<2xi64> +// CHECK %outputs_18, %control_19 = tf_executor.island wraps "tf.Fill"(%outputs_16, %outputs_14) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<2xi64>, tensor) -> tensor<25x5xf32> +// CHECK %outputs_20, %control_21 = tf_executor.island wraps "tf.Const"() <{value = dense<0.000000e+00> : tensor}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor +// CHECK %outputs_22, %control_23 = tf_executor.island wraps "tf.Const"() <{value = dense<[25, 5]> : tensor<2xi64>}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<2xi64> +// CHECK %outputs_24, %control_25 = tf_executor.island wraps "tf.Fill"(%outputs_22, %outputs_20) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<2xi64>, tensor) -> tensor<25x5xf32> // CHECK %outputs_26:2, %control_27 = tf_executor.island wraps "tf._TPUCompileMlir"() // CHECK %outputs_28, %control_29 = tf_executor.island wraps "tf.Identity"(%outputs_26#0) {device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"} : (tensor) -> tensor // CHECK %control_30 = tf_executor.island wraps "tf.TPUCompileSucceededAssert"(%outputs_28) {device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"} : (tensor) -> () -// CHECK %outputs_31, %control_32 = tf_executor.island wraps "tf.Identity"(%outputs) {_parallel_execution_ids = "r0:0", device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0", ici_weight_distribution_mlir_bridge_marker = true} : (tensor<25x5xf32>) -> tensor<25x5xf32> -// CHECK %outputs_33, %control_34 = tf_executor.island wraps "tf.Identity"(%outputs_0) {_parallel_execution_ids = "r0:0", device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0", ici_weight_distribution_mlir_bridge_marker = true} : (tensor<25x5xf32>) -> tensor<25x5xf32> +// CHECK %outputs_31, %control_32 = tf_executor.island wraps "tf.Identity"(%outputs) {_ici_weight_distribution_mlir_bridge_marker = true, _parallel_execution_ids = "r0:0", device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"} : (tensor<25x5xf32>) -> tensor<25x5xf32> +// CHECK %outputs_33, %control_34 = tf_executor.island wraps "tf.Identity"(%outputs_0) {_ici_weight_distribution_mlir_bridge_marker = true, _parallel_execution_ids = "r0:0", device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"} : (tensor<25x5xf32>) -> tensor<25x5xf32> // CHECK %outputs_35:2, %control_36 = tf_executor.island wraps "tf.TPUExecute"(%outputs_31, %outputs_33, %outputs_26#1) {_parallel_execution_ids = "r0:0", device = "/job:tpu_host_worker/replica:0/task:0/device:TPU:0"} : (tensor<25x5xf32>, tensor<25x5xf32>, tensor<3x!tf_type.string>) -> (tensor<128xf32>, tensor<128xf32>) -// CHECK %outputs_37, %control_38 = tf_executor.island wraps "tf.Identity"(%outputs_6) {_parallel_execution_ids = "r0:1", device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0", ici_weight_distribution_mlir_bridge_marker = true} : (tensor<25x5xf32>) -> tensor<25x5xf32> -// CHECK %outputs_39, %control_40 = tf_executor.island wraps "tf.Identity"(%outputs_18) {_parallel_execution_ids = "r0:1", device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0", ici_weight_distribution_mlir_bridge_marker = true} : (tensor<25x5xf32>) -> tensor<25x5xf32> +// CHECK %outputs_37, %control_38 = tf_executor.island wraps "tf.Identity"(%outputs_6) {_ici_weight_distribution_mlir_bridge_marker = true, _parallel_execution_ids = "r0:1", device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"} : (tensor<25x5xf32>) -> tensor<25x5xf32> +// CHECK %outputs_39, %control_40 = tf_executor.island wraps "tf.Identity"(%outputs_18) {_ici_weight_distribution_mlir_bridge_marker = true, _parallel_execution_ids = "r0:1", device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"} : (tensor<25x5xf32>) -> tensor<25x5xf32> // CHECK %outputs_41:2, %control_42 = tf_executor.island wraps "tf.TPUExecute"(%outputs_37, %outputs_39, %outputs_26#1) {_parallel_execution_ids = "r0:1", device = "/job:tpu_host_worker/replica:0/task:1/device:TPU:0"} : (tensor<25x5xf32>, tensor<25x5xf32>, tensor<3x!tf_type.string>) -> (tensor<128xf32>, tensor<128xf32>) -// CHECK %outputs_43, %control_44 = tf_executor.island wraps "tf.Identity"(%outputs_12) {_parallel_execution_ids = "r0:2", device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0", ici_weight_distribution_mlir_bridge_marker = true} : (tensor<25x5xf32>) -> tensor<25x5xf32> -// CHECK %outputs_45, %control_46 = tf_executor.island wraps "tf.Identity"(%outputs_24) {_parallel_execution_ids = "r0:2", device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0", ici_weight_distribution_mlir_bridge_marker = true} : (tensor<25x5xf32>) -> tensor<25x5xf32> +// CHECK %outputs_43, %control_44 = tf_executor.island wraps "tf.Identity"(%outputs_12) {_ici_weight_distribution_mlir_bridge_marker = true, _parallel_execution_ids = "r0:2", device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"} : (tensor<25x5xf32>) -> tensor<25x5xf32> +// CHECK %outputs_45, %control_46 = tf_executor.island wraps "tf.Identity"(%outputs_24) {_ici_weight_distribution_mlir_bridge_marker = true, _parallel_execution_ids = "r0:2", device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"} : (tensor<25x5xf32>) -> tensor<25x5xf32> // CHECK %outputs_47:2, %control_48 = tf_executor.island wraps "tf.TPUExecute"(%outputs_43, %outputs_45, %outputs_26#1) {_parallel_execution_ids = "r0:2", device = "/job:tpu_host_worker/replica:0/task:0/device:TPU:1"} : (tensor<25x5xf32>, tensor<25x5xf32>, tensor<3x!tf_type.string>) -> (tensor<128xf32>, tensor<128xf32>) -// CHECK %outputs_49, %control_50 = tf_executor.island wraps "tf.Identity"(%outputs_6) {_parallel_execution_ids = "r0:3", device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0", ici_weight_distribution_mlir_bridge_marker = true} : (tensor<25x5xf32>) -> tensor<25x5xf32> -// CHECK %outputs_51, %control_52 = tf_executor.island wraps "tf.Identity"(%outputs_18) {_parallel_execution_ids = "r0:3", device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0", ici_weight_distribution_mlir_bridge_marker = true} : (tensor<25x5xf32>) -> tensor<25x5xf32> +// CHECK %outputs_49, %control_50 = tf_executor.island wraps "tf.Identity"(%outputs_6) {_ici_weight_distribution_mlir_bridge_marker = true, _parallel_execution_ids = "r0:3", device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"} : (tensor<25x5xf32>) -> tensor<25x5xf32> +// CHECK %outputs_51, %control_52 = tf_executor.island wraps "tf.Identity"(%outputs_18) {_ici_weight_distribution_mlir_bridge_marker = true, _parallel_execution_ids = "r0:3", device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"} : (tensor<25x5xf32>) -> tensor<25x5xf32> // CHECK %outputs_53:2, %control_54 = tf_executor.island wraps "tf.TPUExecute"(%outputs_49, %outputs_51, %outputs_26#1) {_parallel_execution_ids = "r0:3", device = "/job:tpu_host_worker/replica:0/task:1/device:TPU:1"} : (tensor<25x5xf32>, tensor<25x5xf32>, tensor<3x!tf_type.string>) -> (tensor<128xf32>, tensor<128xf32>) // CHECK %outputs_55, %control_56 = tf_executor.island wraps "tf.Identity"(%outputs_41#0) {device = "/job:tpu_host_worker/replica:0/task:1/device:TPU:0"} : (tensor<128xf32>) -> tensor<128xf32> // CHECK %outputs_57, %control_58 = tf_executor.island wraps "tf.Identity"(%outputs_53#1) {device = "/job:tpu_host_worker/replica:0/task:1/device:TPU:1"} : (tensor<128xf32>) -> tensor<128xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/ici_weight_distribution_spmd_clustering_end_to_end.mlir b/tensorflow/compiler/mlir/tensorflow/tests/ici_weight_distribution_spmd_clustering_end_to_end.mlir index 3db8828faa3dcb..b48b67ccb4c15e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/ici_weight_distribution_spmd_clustering_end_to_end.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/ici_weight_distribution_spmd_clustering_end_to_end.mlir @@ -4,20 +4,37 @@ // CHECK %0 = "tf.TPUCompilationResult"() {_tpu_compilation_status = "cluster__train_helper", device = ""} : () -> tensor // CHECK %1 = "tf.ReadVariableOp"(%arg3) : (tensor<*x!tf_type.resource>>) -> tensor<128x1024xf32> // CHECK %2 = "tf.ReadVariableOp"(%arg4) : (tensor<*x!tf_type.resource>>) -> tensor<1024xf32> -// CHECK %3:2 = tf_device.replicate {n = 2 : i32} { -// CHECK %6 = "tf_device.cluster_func"(%1, %2) <{func = @_func}> {_dynamic_arg_index = [], _has_manual_control_dependencies = true, _replication_info = "cluster__train_helper", _xla_compile_device_type = "TPU", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0], host_compute_core = [], input_sharding_configuration = ["\08\03\1A\01\04\22\04\00\01\02\03", ""], num_cores_per_replica = 4 : i64, output_sharding_configuration = [""], padding_map = [], step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\02\02\02\01\10\04\18\02\22 \00\00\00\00\00\01\00\00\01\00\00\00\01\01\00\00\00\00\01\00\00\01\01\00\01\00\01\00\01\01\01\00*\02\08\01", tpu_compile_options_proto = "", use_spmd_for_xla_partitioning = true, use_tpu = true} : (tensor<128x1024xf32>, tensor<1024xf32>) -> tensor<*xf32> -// CHECK tf_device.return %6 : tensor<*xf32> +// CHECK %cst = "tf.Const"() <{value = dense<0.000000e+00> : tensor}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor +// CHECK %cst_0 = "tf.Const"() <{value = dense<[128, 1024]> : tensor<2xi64>}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<2xi64> +// CHECK %3 = "tf.Fill"(%cst_0, %cst) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<2xi64>, tensor) -> tensor<128x1024xf32> +// CHECK %cst_1 = "tf.Const"() <{value = dense<0.000000e+00> : tensor}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor +// CHECK %cst_2 = "tf.Const"() <{value = dense<1024> : tensor<1xi64>}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<1xi64> +// CHECK %4 = "tf.Fill"(%cst_2, %cst_1) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<1xi64>, tensor) -> tensor<1024xf32> +// CHECK %5:2 = tf_device.replicate([%1, %3] as %arg22: tensor<128x1024xf32>, [%2, %4] as %arg23: tensor<1024xf32>) {n = 2 : i32} { +// CHECK %8 = "tf_device.launch"() <{device = "TPU_REPLICATED_HOST_0"}> ({ +// CHECK %11 = "tf.Identity"(%arg22) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<128x1024xf32>) -> tensor<128x1024xf32> +// CHECK tf_device.return %11 : tensor<128x1024xf32> +// CHECK }) {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<128x1024xf32> +// CHECK %9 = "tf_device.launch"() <{device = "TPU_REPLICATED_HOST_0"}> ({ +// CHECK %11 = "tf.Identity"(%arg23) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<1024xf32>) -> tensor<1024xf32> +// CHECK tf_device.return %11 : tensor<1024xf32> +// CHECK }) {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<1024xf32> +// CHECK %10 = "tf_device.cluster_func"(%8, %9) <{func = @_func}> {_dynamic_arg_index = [], _has_manual_control_dependencies = true, _replication_info = "cluster__train_helper", _xla_compile_device_type = "TPU", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0], host_compute_core = [], input_sharding_configuration = ["\08\03\1A\01\04\22\04\00\01\02\03", ""], num_cores_per_replica = 4 : i64, output_sharding_configuration = [""], padding_map = [], step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\02\02\02\01\10\04\18\02\22 \00\00\00\00\00\01\00\00\01\00\00\00\01\01\00\00\00\00\01\00\00\01\01\00\01\00\01\00\01\01\01\00*\02\08\01", tpu_compile_options_proto = "", use_spmd_for_xla_partitioning = true, use_tpu = true} : (tensor<128x1024xf32>, tensor<1024xf32>) -> tensor<*xf32> +// CHECK tf_device.return %10 : tensor<*xf32> // CHECK } -// CHECK %4 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf_type.resource>>) -> tensor -// CHECK %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor -// CHECK return %5 : tensor +// CHECK %6 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf_type.resource>>) -> tensor +// CHECK %7 = "tf.Identity"(%6) {device = ""} : (tensor) -> tensor +// CHECK return %7 : tensor - -// CHECK-LABEL: func.func private @_func(%arg0: tensor<128x1024xf32> {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\03\1A\01\04\22\04\00\01\02\03"}, %arg1: tensor<1024xf32> {mhlo.is_same_data_across_replicas = true, mhlo.sharding = ""}) -> (tensor<*xf32> {mhlo.sharding = ""}) { -// CHECK %0 = "tf.XlaSharding"(%arg0) <{_XlaSharding = "\08\03\1A\01\04\22\04\00\01\02\03", sharding = "\08\03\1A\01\04\22\04\00\01\02\03"}> {unspecified_dims = []} : (tensor<128x1024xf32>) -> tensor<128x1024xf32> -// CHECK %1 = "tf.MatMul"(%0, %arg1) : (tensor<128x1024xf32>, tensor<1024xf32>) -> tensor<*xf32> -// CHECK return %1 : tensor<*xf32> -// CHECK } +// CHECK-LABEL: func.func private @_func(%arg0: tensor<128x1024xf32> {mhlo.sharding = "\08\03\1A\01\04\22\04\00\01\02\03"}, %arg1: tensor<1024xf32> {mhlo.sharding = ""}) -> (tensor<*xf32> {mhlo.sharding = ""}) { +// CHECK %cst = "tf.Const"() <{value = dense<[[0, 1]]> : tensor<1x2xi32>}> : () -> tensor<1x2xi32> +// CHECK %0 = "tf.XlaAllReduce"(%arg0, %cst) <{mode = "CrossReplica", reduce_op = "Add"}> : (tensor<128x1024xf32>, tensor<1x2xi32>) -> tensor<128x1024xf32> +// CHECK %cst_0 = "tf.Const"() <{value = dense<[[0, 1]]> : tensor<1x2xi32>}> : () -> tensor<1x2xi32> +// CHECK %1 = "tf.XlaAllReduce"(%arg1, %cst_0) <{mode = "CrossReplica", reduce_op = "Add"}> : (tensor<1024xf32>, tensor<1x2xi32>) -> tensor<1024xf32> +// CHECK %2 = "tf.XlaSharding"(%0) <{_XlaSharding = "\08\03\1A\01\04\22\04\00\01\02\03", sharding = "\08\03\1A\01\04\22\04\00\01\02\03"}> {unspecified_dims = []} : (tensor<128x1024xf32>) -> tensor<128x1024xf32> +// CHECK %3 = "tf.MatMul"(%2, %1) : (tensor<128x1024xf32>, tensor<1024xf32>) -> tensor<*xf32> +// CHECK return %3 : tensor<*xf32> +// CHECK } module attributes {tf.devices = {"/job:tpu_host_worker/replica:0/task:0/device:CPU:0", "/job:tpu_host_worker/replica:0/task:0/device:TPU:0", "/job:tpu_host_worker/replica:0/task:0/device:TPU:1", "/job:tpu_host_worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:1/device:CPU:0", "/job:tpu_host_worker/replica:0/task:1/device:TPU:0", "/job:tpu_host_worker/replica:0/task:1/device:TPU:1", "/job:tpu_host_worker/replica:0/task:1/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:2/device:CPU:0", "/job:tpu_host_worker/replica:0/task:2/device:TPU:0", "/job:tpu_host_worker/replica:0/task:2/device:TPU:1", "/job:tpu_host_worker/replica:0/task:2/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:3/device:CPU:0", "/job:tpu_host_worker/replica:0/task:3/device:TPU:0", "/job:tpu_host_worker/replica:0/task:3/device:TPU:1", "/job:tpu_host_worker/replica:0/task:3/device:TPU_SYSTEM:0"}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1857 : i32}} { func.func @main(%arg0: tensor {tf._user_specified_name = "steps", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg1: tensor<*x!tf_type.resource>> {tf._user_specified_name = "899", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg2: tensor<*x!tf_type.resource>> {tf._user_specified_name = "901", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg3: tensor<*x!tf_type.resource>> {tf._user_specified_name = "903", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg4: tensor<*x!tf_type.resource>> {tf._user_specified_name = "905", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg5: tensor<*x!tf_type.resource>> {tf._user_specified_name = "907", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg6: tensor<*x!tf_type.resource>> {tf._user_specified_name = "909", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg7: tensor<*x!tf_type.resource>> {tf._user_specified_name = "911", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg8: tensor<*x!tf_type.resource>> {tf._user_specified_name = "913", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg9: tensor<*x!tf_type.resource>> {tf._user_specified_name = "915", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg10: tensor<*x!tf_type.resource>> {tf._user_specified_name = "917", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg11: tensor<*x!tf_type.resource>> {tf._user_specified_name = "919", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg12: tensor<*x!tf_type.resource>> {tf._user_specified_name = "921", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg13: tensor<*x!tf_type.resource>> {tf._user_specified_name = "923", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg14: tensor<*x!tf_type.resource>> {tf._user_specified_name = "925", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg15: tensor<*x!tf_type.resource>> {tf._user_specified_name = "927", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg16: tensor<*x!tf_type.resource>> {tf._user_specified_name = "929", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg17: tensor<*x!tf_type.resource>> {tf._user_specified_name = "931", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg18: tensor<*x!tf_type.resource>> {tf._user_specified_name = "933", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg19: tensor<*x!tf_type.resource>> {tf._user_specified_name = "935", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg20: tensor<*x!tf_type.resource>> {tf._user_specified_name = "937", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg21: tensor<*x!tf_type.resource>> {tf._user_specified_name = "939", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}) -> tensor<*xi64> attributes {allow_soft_placement = false, tf.entry_function = {control_outputs = "", inputs = "steps,unknown,unknown_0,unknown_1,unknown_2,unknown_3,unknown_4,unknown_5,unknown_6,unknown_7,unknown_8,unknown_9,unknown_10,unknown_11,unknown_12,unknown_13,unknown_14,unknown_15,unknown_16,unknown_17,unknown_18,unknown_19", outputs = "statefulpartitionedcall_RetVal"}} { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/ici_weight_distribution_spmd_mlir_end_to_end.mlir b/tensorflow/compiler/mlir/tensorflow/tests/ici_weight_distribution_spmd_mlir_end_to_end.mlir index ce7a4f02f96619..9b3d887b1aa1ed 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/ici_weight_distribution_spmd_mlir_end_to_end.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/ici_weight_distribution_spmd_mlir_end_to_end.mlir @@ -3,22 +3,33 @@ // CHECK-LABEL: func.func @main // CHECK: %outputs, %control = tf_executor.island wraps "tf.ReadVariableOp"(%arg3) : (tensor<*x!tf_type.resource>>) -> tensor<128x1024xf32> // CHECK: %outputs_0, %control_1 = tf_executor.island wraps "tf.ReadVariableOp"(%arg4) : (tensor<*x!tf_type.resource>>) -> tensor<1024xf32> -// CHECK: %outputs_2:5, %control_3 = tf_executor.island wraps "tf._TPUCompileMlir"() -// CHECK: %outputs_4, %control_5 = tf_executor.island wraps "tf.Identity"(%outputs_2#0) {device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"} : (tensor) -> tensor -// CHECK: %control_6 = tf_executor.island wraps "tf.TPUCompileSucceededAssert"(%outputs_4) {device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"} : (tensor) -> () -// CHECK: %outputs_7, %control_8 = tf_executor.island wraps "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor -// CHECK: %outputs_9:4, %control_10 = tf_executor.island wraps "tf.Split"(%outputs_7, %outputs) {num_split = 4 : i32} : (tensor, tensor<128x1024xf32>) -> (tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xf32>) -// CHECK: %outputs_11, %control_12 = tf_executor.island wraps "tf.TPUExecute"(%outputs_9#0, %outputs_0, %outputs_2#1) {_parallel_execution_ids = "r0:0,p0:0", device = "/job:tpu_host_worker/replica:0/task:0/device:TPU:0"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> -// CHECK: %outputs_13, %control_14 = tf_executor.island wraps "tf.TPUExecute"(%outputs_9#1, %outputs_0, %outputs_2#2) {_parallel_execution_ids = "r0:0,p0:1", device = "/job:tpu_host_worker/replica:0/task:0/device:TPU:1"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> -// CHECK: %outputs_15, %control_16 = tf_executor.island wraps "tf.TPUExecute"(%outputs_9#2, %outputs_0, %outputs_2#3) {_parallel_execution_ids = "r0:0,p0:2", device = "/job:tpu_host_worker/replica:0/task:1/device:TPU:0"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> -// CHECK: %outputs_17, %control_18 = tf_executor.island wraps "tf.TPUExecute"(%outputs_9#3, %outputs_0, %outputs_2#4) {_parallel_execution_ids = "r0:0,p0:3", device = "/job:tpu_host_worker/replica:0/task:1/device:TPU:1"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> -// CHECK: %outputs_19, %control_20 = tf_executor.island wraps "tf.TPUExecute"(%outputs_9#0, %outputs_0, %outputs_2#1) {_parallel_execution_ids = "r0:1,p0:0", device = "/job:tpu_host_worker/replica:0/task:2/device:TPU:0"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> -// CHECK: %outputs_21, %control_22 = tf_executor.island wraps "tf.TPUExecute"(%outputs_9#1, %outputs_0, %outputs_2#2) {_parallel_execution_ids = "r0:1,p0:1", device = "/job:tpu_host_worker/replica:0/task:2/device:TPU:1"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> -// CHECK: %outputs_23, %control_24 = tf_executor.island wraps "tf.TPUExecute"(%outputs_9#2, %outputs_0, %outputs_2#3) {_parallel_execution_ids = "r0:1,p0:2", device = "/job:tpu_host_worker/replica:0/task:3/device:TPU:0"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> -// CHECK: %outputs_25, %control_26 = tf_executor.island wraps "tf.TPUExecute"(%outputs_9#3, %outputs_0, %outputs_2#4) {_parallel_execution_ids = "r0:1,p0:3", device = "/job:tpu_host_worker/replica:0/task:3/device:TPU:1"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> -// CHECK: %outputs_27, %control_28 = tf_executor.island wraps "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf_type.resource>>) -> tensor -// CHECK: %outputs_29, %control_30 = tf_executor.island wraps "tf.Identity"(%outputs_27) {device = ""} : (tensor) -> tensor -// CHECK: tf_executor.fetch %outputs_29, %control, %control_1, %control_12, %control_14, %control_16, %control_18, %control_20, %control_22, %control_24, %control_26, %control_28 : tensor, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control +// CHECK: %outputs_2, %control_3 = tf_executor.island wraps "tf.Const"() <{value = dense<0.000000e+00> : tensor}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor +// CHECK: %outputs_4, %control_5 = tf_executor.island wraps "tf.Const"() <{value = dense<[128, 1024]> : tensor<2xi64>}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<2xi64> +// CHECK: %outputs_6, %control_7 = tf_executor.island wraps "tf.Fill"(%outputs_4, %outputs_2) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<2xi64>, tensor) -> tensor<128x1024xf32> +// CHECK: %outputs_8, %control_9 = tf_executor.island wraps "tf.Const"() <{value = dense<0.000000e+00> : tensor}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor +// CHECK: %outputs_10, %control_11 = tf_executor.island wraps "tf.Const"() <{value = dense<1024> : tensor<1xi64>}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<1xi64> +// CHECK: %outputs_12, %control_13 = tf_executor.island wraps "tf.Fill"(%outputs_10, %outputs_8) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<1xi64>, tensor) -> tensor<1024xf32> +// CHECK: %outputs_14:5, %control_15 = tf_executor.island wraps "tf._TPUCompileMlir"() +// CHECK: %outputs_16, %control_17 = tf_executor.island wraps "tf.Identity"(%outputs_14#0) {device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"} : (tensor) -> tensor +// CHECK: %control_18 = tf_executor.island wraps "tf.TPUCompileSucceededAssert"(%outputs_16) {device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"} : (tensor) -> () +// CHECK: %outputs_19, %control_20 = tf_executor.island wraps "tf.Const"() <{value = dense<0> : tensor}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor +// CHECK: %outputs_21, %control_22 = tf_executor.island wraps "tf.Identity"(%outputs) {_ici_weight_distribution_mlir_bridge_marker = true, _parallel_execution_ids = "r0:0", device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"} : (tensor<128x1024xf32>) -> tensor<128x1024xf32> +// CHECK: %outputs_23, %control_24 = tf_executor.island wraps "tf.Identity"(%outputs_0) {_ici_weight_distribution_mlir_bridge_marker = true, _parallel_execution_ids = "r0:0", device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"} : (tensor<1024xf32>) -> tensor<1024xf32> +// CHECK: %outputs_25:4, %control_26 = tf_executor.island wraps "tf.Split"(%outputs_19, %outputs_21) {_ici_weight_distribution_mlir_bridge_marker = true, _parallel_execution_ids = "r0:0", num_split = 4 : i32} : (tensor, tensor<128x1024xf32>) -> (tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xf32>) +// CHECK: %outputs_27, %control_28 = tf_executor.island wraps "tf.TPUExecute"(%outputs_25#0, %outputs_23, %outputs_14#1) {_parallel_execution_ids = "r0:0,p0:0", device = "/job:tpu_host_worker/replica:0/task:0/device:TPU:0"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> +// CHECK: %outputs_29, %control_30 = tf_executor.island wraps "tf.TPUExecute"(%outputs_25#1, %outputs_23, %outputs_14#2) {_parallel_execution_ids = "r0:0,p0:1", device = "/job:tpu_host_worker/replica:0/task:0/device:TPU:1"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> +// CHECK: %outputs_31, %control_32 = tf_executor.island wraps "tf.TPUExecute"(%outputs_25#2, %outputs_23, %outputs_14#3) {_parallel_execution_ids = "r0:0,p0:2", device = "/job:tpu_host_worker/replica:0/task:1/device:TPU:0"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> +// CHECK: %outputs_33, %control_34 = tf_executor.island wraps "tf.TPUExecute"(%outputs_25#3, %outputs_23, %outputs_14#4) {_parallel_execution_ids = "r0:0,p0:3", device = "/job:tpu_host_worker/replica:0/task:1/device:TPU:1"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> +// CHECK: %outputs_35, %control_36 = tf_executor.island wraps "tf.Identity"(%outputs_6) {_ici_weight_distribution_mlir_bridge_marker = true, _parallel_execution_ids = "r0:1", device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"} : (tensor<128x1024xf32>) -> tensor<128x1024xf32> +// CHECK: %outputs_37, %control_38 = tf_executor.island wraps "tf.Identity"(%outputs_12) {_ici_weight_distribution_mlir_bridge_marker = true, _parallel_execution_ids = "r0:1", device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"} : (tensor<1024xf32>) -> tensor<1024xf32> +// CHECK: %outputs_39:4, %control_40 = tf_executor.island wraps "tf.Split"(%outputs_19, %outputs_35) {_ici_weight_distribution_mlir_bridge_marker = true, _parallel_execution_ids = "r0:1", num_split = 4 : i32} : (tensor, tensor<128x1024xf32>) -> (tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xf32>) +// CHECK: %outputs_41, %control_42 = tf_executor.island wraps "tf.TPUExecute"(%outputs_39#0, %outputs_37, %outputs_14#1) {_parallel_execution_ids = "r0:1,p0:0", device = "/job:tpu_host_worker/replica:0/task:2/device:TPU:0"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> +// CHECK: %outputs_43, %control_44 = tf_executor.island wraps "tf.TPUExecute"(%outputs_39#1, %outputs_37, %outputs_14#2) {_parallel_execution_ids = "r0:1,p0:1", device = "/job:tpu_host_worker/replica:0/task:2/device:TPU:1"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> +// CHECK: %outputs_45, %control_46 = tf_executor.island wraps "tf.TPUExecute"(%outputs_39#2, %outputs_37, %outputs_14#3) {_parallel_execution_ids = "r0:1,p0:2", device = "/job:tpu_host_worker/replica:0/task:3/device:TPU:0"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> +// CHECK: %outputs_47, %control_48 = tf_executor.island wraps "tf.TPUExecute"(%outputs_39#3, %outputs_37, %outputs_14#4) {_parallel_execution_ids = "r0:1,p0:3", device = "/job:tpu_host_worker/replica:0/task:3/device:TPU:1"} : (tensor<32x1024xf32>, tensor<1024xf32>, tensor<3x!tf_type.string>) -> tensor<*xf32> +// CHECK: %outputs_49, %control_50 = tf_executor.island wraps "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf_type.resource>>) -> tensor +// CHECK: %outputs_51, %control_52 = tf_executor.island wraps "tf.Identity"(%outputs_49) {device = ""} : (tensor) -> tensor +// CHECK: tf_executor.fetch %outputs_51, %control, %control_1, %control_28, %control_30, %control_32, %control_34, %control_42, %control_44, %control_46, %control_48, %control_50 : tensor, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control, !tf_executor.control module attributes {tf.devices = {"/job:tpu_host_worker/replica:0/task:0/device:CPU:0", "/job:tpu_host_worker/replica:0/task:0/device:TPU:0", "/job:tpu_host_worker/replica:0/task:0/device:TPU:1", "/job:tpu_host_worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:1/device:CPU:0", "/job:tpu_host_worker/replica:0/task:1/device:TPU:0", "/job:tpu_host_worker/replica:0/task:1/device:TPU:1", "/job:tpu_host_worker/replica:0/task:1/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:2/device:CPU:0", "/job:tpu_host_worker/replica:0/task:2/device:TPU:0", "/job:tpu_host_worker/replica:0/task:2/device:TPU:1", "/job:tpu_host_worker/replica:0/task:2/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:3/device:CPU:0", "/job:tpu_host_worker/replica:0/task:3/device:TPU:0", "/job:tpu_host_worker/replica:0/task:3/device:TPU:1", "/job:tpu_host_worker/replica:0/task:3/device:TPU_SYSTEM:0"}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1857 : i32}} { func.func @main(%arg0: tensor {tf._user_specified_name = "steps", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg1: tensor<*x!tf_type.resource>> {tf._user_specified_name = "899", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg2: tensor<*x!tf_type.resource>> {tf._user_specified_name = "901", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg3: tensor<*x!tf_type.resource>> {tf._user_specified_name = "903", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg4: tensor<*x!tf_type.resource>> {tf._user_specified_name = "905", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg5: tensor<*x!tf_type.resource>> {tf._user_specified_name = "907", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg6: tensor<*x!tf_type.resource>> {tf._user_specified_name = "909", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg7: tensor<*x!tf_type.resource>> {tf._user_specified_name = "911", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg8: tensor<*x!tf_type.resource>> {tf._user_specified_name = "913", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg9: tensor<*x!tf_type.resource>> {tf._user_specified_name = "915", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg10: tensor<*x!tf_type.resource>> {tf._user_specified_name = "917", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg11: tensor<*x!tf_type.resource>> {tf._user_specified_name = "919", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg12: tensor<*x!tf_type.resource>> {tf._user_specified_name = "921", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg13: tensor<*x!tf_type.resource>> {tf._user_specified_name = "923", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg14: tensor<*x!tf_type.resource>> {tf._user_specified_name = "925", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg15: tensor<*x!tf_type.resource>> {tf._user_specified_name = "927", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg16: tensor<*x!tf_type.resource>> {tf._user_specified_name = "929", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg17: tensor<*x!tf_type.resource>> {tf._user_specified_name = "931", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg18: tensor<*x!tf_type.resource>> {tf._user_specified_name = "933", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg19: tensor<*x!tf_type.resource>> {tf._user_specified_name = "935", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg20: tensor<*x!tf_type.resource>> {tf._user_specified_name = "937", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg21: tensor<*x!tf_type.resource>> {tf._user_specified_name = "939", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}) -> tensor<*xi64> attributes {allow_soft_placement = false, tf.entry_function = {control_outputs = "", inputs = "steps,unknown,unknown_0,unknown_1,unknown_2,unknown_3,unknown_4,unknown_5,unknown_6,unknown_7,unknown_8,unknown_9,unknown_10,unknown_11,unknown_12,unknown_13,unknown_14,unknown_15,unknown_16,unknown_17,unknown_18,unknown_19", outputs = "statefulpartitionedcall_RetVal"}} { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD index f2f686f9822927..cc66e3c3222247 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/BUILD @@ -308,4 +308,5 @@ glob_lit_tests( for file in test_files }, test_file_exts = ["py"], + use_lit_test_suite = False, # Each test gets a large binary, and is too big when consolidated. ) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir index 07e3eea2d4ea69..454b47ad213ec8 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -1891,19 +1891,19 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // Tests tile sharding of inputs with number of splits that does not evenly divide -// the input results in an error. +// the input results in an error, when shapes are not fully known. module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { - func.func @uneven_input_sharding_disallowed(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { - %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { + func.func @uneven_input_sharding_disallowed(%arg0: tensor, %arg1: tensor, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { // expected-error@+1 {{incorrect input sharding configuration received. 1-th dimension of the input must be evenly divisible by 4}} - %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } func.return %0#0, %1#0 : tensor<*xi32>, tensor<*xi1> } - func.func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { - %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<*xi32>, tensor<*xi1>) + func.func @tpu0_func(%arg0: tensor, %arg1: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { + %1, %2 = "tf.A"(%arg0) : (tensor) -> (tensor<*xi32>, tensor<*xi1>) %4 = "tf.B"(%1, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> (tensor<*xi32>) %3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<*xi1>) -> tensor<*xi1> func.return %4, %3 : tensor<*xi32>, tensor<*xi1> @@ -2793,7 +2793,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // ----- -// Tests ici_weight_distribution_mlir_bridge_marker attribute can be +// Tests _ici_weight_distribution_mlir_bridge_marker attribute can be // produced in created Split op module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0"]} { // CHECK-LABEL: func @propagate_ici_weight_attr_to_split_op @@ -2809,7 +2809,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0) // CHECK: %[[CONST_SPLIT_0_DIM:.*]] = "tf.Const"() // CHECK: %[[SPLIT_0_OUT:[a-z0-9]+]]:2 = "tf.Split" - // CHECK: ici_weight_distribution_mlir_bridge_marker = true + // CHECK: _ici_weight_distribution_mlir_bridge_marker = true // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:5 = "tf_device.parallel_execute" // CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch" // CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"(%[[SPLIT_0_OUT]]#0, %[[COMPILE]]#1) @@ -2824,9 +2824,9 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc // CHECK-NEXT: %[[EXECUTE_3_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_0_OUT]]#1, %[[COMPILE]]#4) // CHECK: tf_device.return %[[EXECUTE_3_OUTPUT]] %1 = "tf_device.launch"() <{device = "TPU_REPLICATED_HOST_0"}> ({ - %identity = "tf.Identity"(%ri_1) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor<128x10xf32>) -> tensor<128x10xf32> + %identity = "tf.Identity"(%ri_1) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<128x10xf32>) -> tensor<128x10xf32> tf_device.return %identity : tensor<128x10xf32> - }) {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<128x10xf32> + }) {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<128x10xf32> %2, %3 = "tf_device.cluster_func"(%1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 4, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1], input_sharding_configuration = ["\08\03\1A\03\02\01\02\22\04\00\01\02\030\01", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %2, %3 : tensor<*xi32>, tensor<*xi1> } @@ -2839,3 +2839,169 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:loc func.return %4, %3 : tensor<*xi32>, tensor<*xi1> } } + +// ----- + +// Tests that outputs are correctly merged and fed from TPU computation for +// tiled output sharding with padding for concat ops. + +// The following OpSharding is used for TPU computation outputs in below test: +// Proto debug string: +// output 0 +// type: OTHER +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// Serialized string: +// "\08\03\1A\02\01\02\22\02\00\01" +// +// output 1 +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 0 +// Serialized string: +// "\08\01\1A\01\01\22\01\01" + + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { + // CHECK-LABEL: func @parallel_execute_with_tiled_output + func.func @parallel_execute_with_tiled_output(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<128x10xi32>, %arg3: tensor<128x10xi32>) -> (tensor<128x5xi32>, tensor<10x5xi1>) { + // CHECK: tf_device.replicate + // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x10xf32> + // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<128x10xi32> + // CHECK-SAME: devices = + // CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"] + // CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"] + %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<128x10xi32>) {n = 2 : i32} { + // CHECK: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}> + // CHECK-NEXT: "tf._TPUCompileMlir" + // CHECK: "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}> + // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0) + // + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:3 = "tf_device.parallel_execute" + // CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_0"}> + // CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute" + // CHECK-NEXT: tf_device.return %[[EXECUTE_0_OUTPUT]] + // CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_1"}> + // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute" + // CHECK-NEXT: tf_device.return %[[EXECUTE_1_OUTPUT]] + // + // CHECK: %[[CONST_CONCAT3_DIM:.*]] = "tf.Const"() + // CHECK: %[[CONCAT3_OUTPUT:[0-9]*]] = "tf.Concat"(%[[CONST_CONCAT3_DIM]], %[[PARALLEL_EXECUTE_OUTPUT]]#0, %[[PARALLEL_EXECUTE_OUTPUT]]#2) + // CHECK: %[[CONST_SLICE_BEGIN:.*]] = "tf.Const"() + // dense<0> + // tensor<2xi64>}> : () -> tensor<2xi64> + // CHECK: %[[CONST_SLICE_SIZE:.*]] = "tf.Const"() + // dense<[128, 5]> : tensor<2xi64>}> : () -> tensor<2xi64> + // CHECK: "tf.Slice"(%[[CONCAT3_OUTPUT]], %[[CONST_SLICE_BEGIN]], %[[CONST_SLICE_SIZE]]) + // : (tensor<128x6xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<128x5xi32> + %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<128x10xi32>) -> (tensor<128x5xi32>, tensor<10x5xi1>) + tf_device.return %1, %2 : tensor<128x5xi32>, tensor<10x5xi1> + } + func.return %0#0, %1#0 : tensor<128x5xi32>, tensor<10x5xi1> + } + func.func @tpu0_func(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xi32>) -> (tensor<128x10xi32>, tensor<10x5xi1>) { + %1, %2 = "tf.A"(%arg0) : (tensor<128x10xf32>) -> (tensor<128x10xi32>, tensor<10x5xi1>) + %4 = "tf.B"(%1, %arg1) : (tensor<128x10xi32>, tensor<128x10xi32>) -> (tensor<128x10xi32>) + %3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<10x5xi1>) -> tensor<10x5xi1> + func.return %4, %3 : tensor<128x10xi32>, tensor<10x5xi1> + } +} + +// ----- + +// Tests inputs are correctly split and fed into TPU computation for tiled input +// sharding with padding. + +// The following OpSharding is used for TPU computation inputs in the below +// test: +// Proto debug string: +// input 0 +// type: OTHER +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// Serialized string: +// "\08\03\1A\02\01\02\22\02\00\01" +// +// input 1 +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// tile_assignment_devices: 1 +// Serialized string: +// "\08\01\1A\01\01\22\01\01" + + +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { + // CHECK-LABEL: func @parallel_execute_with_tiled_input + // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<128x9xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<128x9xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<128x10xi32>, %[[ARG_3:[a-z0-9]*]]: tensor<128x10xi32>) + func.func @parallel_execute_with_tiled_input(%arg0: tensor<128x9xf32>, %arg1: tensor<128x9xf32>, %arg2: tensor<128x10xi32>, %arg3: tensor<128x10xi32>) -> (tensor<128x10xi32>, tensor<10x5xi1>) { + // CHECK: tf_device.replicate + // CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %[[RI_0:[a-z0-9]*]]: tensor<128x9xf32> + // CHECK-SAME: [%[[ARG_2]], %[[ARG_3]]] as %[[RI_1:[a-z0-9]*]]: tensor<128x10xi32> + // CHECK-SAME: devices = + // CHECK-SAME: TPU_REPLICATED_CORE_0 = ["/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1"] + // CHECK-SAME: TPU_REPLICATED_CORE_1 = ["/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU:0"] + %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x9xf32>, [%arg2, %arg3] as %ri_2: tensor<128x10xi32>) {n = 2 : i32} { + // CHECK: %[[DEVICE_LAUNCH_OUT:[a-z0-9]+]] = "tf_device.launch"() <{device = "TPU_REPLICATED_HOST_0"}> + // CHECK: %[[COMPILE:[a-z0-9]+]]:3 = "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}> + // CHECK-NEXT: "tf._TPUCompileMlir" + // CHECK: "tf_device.launch"() <{device = "/job:localhost/replica:0/task:0/device:CPU:0"}> + // CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE]]#0) + // + // CHECK: %[[PAD_SHAPE:[a-z0-9]+]] = "tf.Const"() + // CHECK: [0, 0], [0, 1] + // CHECK: : tensor<2x2xi64>}> : () -> tensor<2x2xi64> + // CHECK: %[[PAD_OUT:[a-z0-9]+]] = "tf.Pad"(%[[DEVICE_LAUNCH_OUT]], %[[PAD_SHAPE]]) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<128x9xf32>, tensor<2x2xi64>) -> tensor<128x10xf32> + // CHECK: %[[CONST_SPLIT_DIM:.*]] = "tf.Const"() <{value = dense<1> : tensor}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor + // CHECK: %[[SPLIT_OUT:[a-z0-9]+]]:2 = "tf.Split"(%[[CONST_SPLIT_DIM]], %[[PAD_OUT]]) {_ici_weight_distribution_mlir_bridge_marker = true, num_split = 2 : i32} : (tensor, tensor<128x10xf32>) -> (tensor<128x5xf32>, tensor<128x5xf32>) + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:3 = "tf_device.parallel_execute" + // CHECK-NEXT: %[[LAUNCH_0_OUTPUT:[0-9]*]]:2 = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_0"}> + // + // CHECK-NEXT: %[[EXECUTE_0_OUTPUT:[0-9]*]]:2 = "tf.TPUExecute"(%[[SPLIT_OUT]]#0, %[[COMPILE]]#1) + // CHECK-NEXT: tf_device.return %[[EXECUTE_0_OUTPUT]] + // CHECK: %[[LAUNCH_1_OUTPUT:[0-9]*]] = "tf_device.launch"() <{device = "TPU_REPLICATED_CORE_1"}> + // CHECK-NEXT: %[[EXECUTE_1_OUTPUT:[0-9]*]] = "tf.TPUExecute"(%[[SPLIT_OUT]]#1, %[[RI_1]], %[[COMPILE]]#2) + // CHECK-NEXT: tf_device.return %[[EXECUTE_1_OUTPUT]] + %1 = "tf_device.launch"() <{device = "TPU_REPLICATED_HOST_0"}> ({ + %identity = "tf.Identity"(%ri_1) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<128x9xf32>) -> tensor<128x9xf32> + tf_device.return %identity : tensor<128x9xf32> + }) {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<128x9xf32> + %2, %3 = "tf_device.cluster_func"(%1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x9xf32>, tensor<128x10xi32>) -> (tensor<128x10xi32>, tensor<10x5xi1>) + tf_device.return %2, %3 : tensor<128x10xi32>, tensor<10x5xi1> + } + func.return %0#0, %1#0 : tensor<128x10xi32>, tensor<10x5xi1> + } + func.func @tpu0_func(%arg0: tensor<128x9xf32>, %arg1: tensor<128x10xi32>) -> (tensor<128x10xi32>, tensor<10x5xi1>) { + %1, %2 = "tf.A"(%arg0) : (tensor<128x9xf32>) -> (tensor<128x10xi32>, tensor<10x5xi1>) + %4 = "tf.B"(%1, %arg1) : (tensor<128x10xi32>, tensor<128x10xi32>) -> (tensor<128x10xi32>) + %3 = "tf.XlaSharding"(%2) { _XlaSharding = "", sharding = "" } : (tensor<10x5xi1>) -> tensor<10x5xi1> + func.return %4, %3 : tensor<128x10xi32>, tensor<10x5xi1> + } +} + +// ----- + +// CHECK: "tf.Split" +// : (tensor<128x1024xf32>) -> (tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xf32>, tensor<32x1024xf32>) +module attributes {tf.devices = {"/job:tpu_host_worker/replica:0/task:0/device:CPU:0", "/job:tpu_host_worker/replica:0/task:0/device:TPU:0", "/job:tpu_host_worker/replica:0/task:0/device:TPU:1", "/job:tpu_host_worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:1/device:CPU:0", "/job:tpu_host_worker/replica:0/task:1/device:TPU:0", "/job:tpu_host_worker/replica:0/task:1/device:TPU:1", "/job:tpu_host_worker/replica:0/task:1/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:2/device:CPU:0", "/job:tpu_host_worker/replica:0/task:2/device:TPU:0", "/job:tpu_host_worker/replica:0/task:2/device:TPU:1", "/job:tpu_host_worker/replica:0/task:2/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:3/device:CPU:0", "/job:tpu_host_worker/replica:0/task:3/device:TPU:0", "/job:tpu_host_worker/replica:0/task:3/device:TPU:1", "/job:tpu_host_worker/replica:0/task:3/device:TPU_SYSTEM:0"}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1857 : i32}} { + func.func @main(%arg0: tensor {tf._user_specified_name = "steps", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg1: tensor<*x!tf_type.resource>> {tf._user_specified_name = "899", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg2: tensor<*x!tf_type.resource>> {tf._user_specified_name = "901", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg3: tensor<*x!tf_type.resource>> {tf._user_specified_name = "903", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg4: tensor<*x!tf_type.resource>> {tf._user_specified_name = "905", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg5: tensor<*x!tf_type.resource>> {tf._user_specified_name = "907", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg6: tensor<*x!tf_type.resource>> {tf._user_specified_name = "909", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg7: tensor<*x!tf_type.resource>> {tf._user_specified_name = "911", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg8: tensor<*x!tf_type.resource>> {tf._user_specified_name = "913", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg9: tensor<*x!tf_type.resource>> {tf._user_specified_name = "915", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg10: tensor<*x!tf_type.resource>> {tf._user_specified_name = "917", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg11: tensor<*x!tf_type.resource>> {tf._user_specified_name = "919", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg12: tensor<*x!tf_type.resource>> {tf._user_specified_name = "921", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg13: tensor<*x!tf_type.resource>> {tf._user_specified_name = "923", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg14: tensor<*x!tf_type.resource>> {tf._user_specified_name = "925", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg15: tensor<*x!tf_type.resource>> {tf._user_specified_name = "927", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg16: tensor<*x!tf_type.resource>> {tf._user_specified_name = "929", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg17: tensor<*x!tf_type.resource>> {tf._user_specified_name = "931", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg18: tensor<*x!tf_type.resource>> {tf._user_specified_name = "933", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg19: tensor<*x!tf_type.resource>> {tf._user_specified_name = "935", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg20: tensor<*x!tf_type.resource>> {tf._user_specified_name = "937", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg21: tensor<*x!tf_type.resource>> {tf._user_specified_name = "939", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}) -> tensor attributes {allow_soft_placement = false, tf.entry_function = {control_outputs = "", inputs = "steps,unknown,unknown_0,unknown_1,unknown_2,unknown_3,unknown_4,unknown_5,unknown_6,unknown_7,unknown_8,unknown_9,unknown_10,unknown_11,unknown_12,unknown_13,unknown_14,unknown_15,unknown_16,unknown_17,unknown_18,unknown_19", outputs = "statefulpartitionedcall_RetVal"}} { + %0 = "tf.TPUCompilationResult"() {_tpu_compilation_status = "cluster__train_helper", device = ""} : () -> tensor + %1 = "tf.ReadVariableOp"(%arg3) : (tensor<*x!tf_type.resource>>) -> tensor<128x1024xf32> + %2 = "tf.ReadVariableOp"(%arg4) : (tensor<*x!tf_type.resource>>) -> tensor<1024xf32> + %3:2 = tf_device.replicate {n = 2 : i32} { + %6 = "tf_device.cluster_func"(%1, %2) <{func = @_func}> {_dynamic_arg_index = [], _has_manual_control_dependencies = true, _replication_info = "cluster__train_helper", _xla_compile_device_type = "TPU", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0], host_compute_core = [], input_sharding_configuration = ["\08\03\1A\01\04\22\04\00\01\02\03", ""], num_cores_per_replica = 4 : i64, output_sharding_configuration = [""], padding_map = [], step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\02\02\02\01\10\04\18\02\22 \00\00\00\00\00\01\00\00\01\00\00\00\01\01\00\00\00\00\01\00\00\01\01\00\01\00\01\00\01\01\01\00*\02\08\01", tpu_compile_options_proto = "", use_spmd_for_xla_partitioning = true, use_tpu = true} : (tensor<128x1024xf32>, tensor<1024xf32>) -> tensor<*xf32> + tf_device.return %6 : tensor<*xf32> + } + %4 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf_type.resource>>) -> tensor + %5 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + return %5 : tensor + } + func.func private @_func(%arg0: tensor<128x1024xf32> {mhlo.is_same_data_across_replicas = true, mhlo.sharding = "\08\03\1A\01\04\22\04\00\01\02\03"}, %arg1: tensor<1024xf32> {mhlo.is_same_data_across_replicas = true, mhlo.sharding = ""}) -> (tensor<*xf32> {mhlo.sharding = ""}) { + %0 = "tf.XlaSharding"(%arg0) <{_XlaSharding = "\08\03\1A\01\04\22\04\00\01\02\03", sharding = "\08\03\1A\01\04\22\04\00\01\02\03"}> {unspecified_dims = []} : (tensor<128x1024xf32>) -> tensor<128x1024xf32> + %1 = "tf.MatMul"(%0, %arg1) : (tensor<128x1024xf32>, tensor<1024xf32>) -> tensor<*xf32> + return %1 : tensor<*xf32> + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir index abb68b92146ba7..1aa574d12bf1ba 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_sharding_identification.mlir @@ -957,3 +957,47 @@ func.func @func(%arg0: tensor) -> tensor<4xf32> { func.return %1 : tensor<4xf32> } + +// ----- +// CHECK-LABEL: func @check_AddV2_variant_shape_with_input_sharding_propagation +func.func @check_AddV2_variant_shape_with_input_sharding_propagation(%arg0: tensor, %arg1: tensor<12x384xbf16>) { + // CHECK: tf_device.cluster_func + // CHECK-SAME: input_sharding_configuration = ["sharding_info_1", "sharding_info_1"] + // CHECK-SAME: output_sharding_configuration = ["sharding_info_1"] + "tf_device.cluster_func"(%arg0, %arg1) { + func = @func, + use_spmd_for_xla_partitioning = true, num_cores_per_replica = 1 : i64 + } : (tensor, tensor<12x384xbf16>) -> tensor + func.return +} + +// CHECK-LABEL: func @func +// CHECK: {{.*}}mhlo.sharding = "sharding_info_1"{{.*}}mhlo.sharding = "sharding_info_1"{{.*}}->{{.*}}mhlo.sharding = "sharding_info_1" +func.func @func(%arg0: tensor, %arg1: tensor<12x384xbf16>) -> tensor { + %add = "tf.AddV2"(%arg0, %arg1) : (tensor, tensor<12x384xbf16>) -> tensor + %0 = "tf.XlaSharding"(%add) { _XlaSharding = "sharding_info_1"} : (tensor) -> tensor + func.return %0 : tensor +} + + + +// ----- +// CHECK-LABEL: func @check_BatchMatMul_variant_shape_without_input_sharding_propagation +func.func @check_BatchMatMul_variant_shape_without_input_sharding_propagation(%arg0: tensor, %arg1: tensor<256x384xbf16>) { + // CHECK: tf_device.cluster_func + // CHECK-SAME: input_sharding_configuration = ["", ""] + // CHECK-SAME: output_sharding_configuration = ["sharding_info_1"] + "tf_device.cluster_func"(%arg0, %arg1) { + func = @func, + use_spmd_for_xla_partitioning = true, num_cores_per_replica = 1 : i64 + } : (tensor, tensor<256x384xbf16>) -> tensor + func.return +} + +// CHECK-LABEL: func @func +// CHECK: {{.*}}mhlo.sharding = ""{{.*}}mhlo.sharding = ""{{.*}}->{{.*}}mhlo.sharding = "sharding_info_1" +func.func @func(%arg0: tensor, %arg1: tensor<256x384xbf16>) -> tensor { + %mul = "tf.BatchMatMul"(%arg0, %arg1) : (tensor, tensor<256x384xbf16>) -> tensor + %0 = "tf.XlaSharding"(%mul) { _XlaSharding = "sharding_info_1"} : (tensor) -> tensor + func.return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/xla_broadcast.mlir b/tensorflow/compiler/mlir/tensorflow/tests/xla_broadcast.mlir index a28ea04aec9308..31434cc48d1f99 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/xla_broadcast.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/xla_broadcast.mlir @@ -1,22 +1,22 @@ // RUN: tf-opt %s -split-input-file -tf-xla-broadcast | FileCheck %s module attributes {tf.devices = {"/job:tpu_host_worker/replica:0/task:0/device:CPU:0", "/job:tpu_host_worker/replica:0/task:0/device:TPU:0", "/job:tpu_host_worker/replica:0/task:0/device:TPU:1", "/job:tpu_host_worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:1/device:CPU:0", "/job:tpu_host_worker/replica:0/task:1/device:TPU:0", "/job:tpu_host_worker/replica:0/task:1/device:TPU:1", "/job:tpu_host_worker/replica:0/task:1/device:TPU_SYSTEM:0"}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1850 : i32}} { -// CHECK-LABEL: func @move_broadcast -func.func @move_broadcast(%arg0: tensor) -> () { +// CHECK-LABEL: func @move_broadcast_non_spmd +func.func @move_broadcast_non_spmd(%arg0: tensor) -> () { // CHECK: %[[ELEM_0:.*]] = "tf.Const"() - // CHECK: {ici_weight_distribution_mlir_bridge_marker = true} + // CHECK: {_ici_weight_distribution_mlir_bridge_marker = true} // CHECK-NEXT: %[[SHAPE_0:.*]] = "tf.Const"() - // CHECK: {ici_weight_distribution_mlir_bridge_marker = true} - // CHECK-NEXT: %[[FULL_0:.*]] = "tf.Fill"(%[[SHAPE_0]], %[[ELEM_0]]) {ici_weight_distribution_mlir_bridge_marker = true} + // CHECK: {_ici_weight_distribution_mlir_bridge_marker = true} + // CHECK-NEXT: %[[FULL_0:.*]] = "tf.Fill"(%[[SHAPE_0]], %[[ELEM_0]]) {_ici_weight_distribution_mlir_bridge_marker = true} // CHECK: %[[ELEM_1:.*]] = "tf.Const"() - // CHECK: {ici_weight_distribution_mlir_bridge_marker = true} + // CHECK: {_ici_weight_distribution_mlir_bridge_marker = true} // CHECK-NEXT: %[[SHAPE_1:.*]] = "tf.Const"() - // CHECK: {ici_weight_distribution_mlir_bridge_marker = true} - // CHECK-NEXT: %[[FULL_1:.*]] = "tf.Fill"(%[[SHAPE_1]], %[[ELEM_1]]) {ici_weight_distribution_mlir_bridge_marker = true} + // CHECK: {_ici_weight_distribution_mlir_bridge_marker = true} + // CHECK-NEXT: %[[FULL_1:.*]] = "tf.Fill"(%[[SHAPE_1]], %[[ELEM_1]]) {_ici_weight_distribution_mlir_bridge_marker = true} // CHECK-NEXT: tf_device.replicate([%arg0, %[[FULL_0]], %[[FULL_1]], %[[FULL_0]]] as %[[REPVAR:.*]]: tensor) {n = 4 : i32} { // CHECK-NEXT: %[[ID:.*]] = "tf_device.launch"() <{device = "TPU_REPLICATED_HOST_0"}> ({ - // CHECK-NEXT: %[[IDINSIDE:.*]] = "tf.Identity"(%[[REPVAR]]) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor) -> tensor + // CHECK-NEXT: %[[IDINSIDE:.*]] = "tf.Identity"(%[[REPVAR]]) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor) -> tensor // CHECK-NEXT: tf_device.return %[[IDINSIDE]] : tensor - // CHECK-NEXT: }) {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor + // CHECK-NEXT: }) {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor // CHECK-NEXT: "tf_device.cluster"() ({ // CHECK-NEXT: %[[GROUP:.*]] = "tf.Const"() // CHECK-SAME: [0, 1, 2, 3] @@ -31,3 +31,32 @@ func.func @move_broadcast(%arg0: tensor) -> () { func.return } } + +// ----- +module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0", "/job:worker/replica:0/task:0/device:TPU:1", "/job:worker/replica:0/task:0/device:TPU:2", "/job:worker/replica:0/task:0/device:TPU:3", "/job:worker/replica:0/task:0/device:TPU:4", "/job:worker/replica:0/task:0/device:TPU:5", "/job:worker/replica:0/task:0/device:TPU:6", "/job:worker/replica:0/task:0/device:TPU:7"]} { +// CHECK-LABEL: func @move_broadcast_spmd +func.func @move_broadcast_spmd(%arg0: tensor) -> () { + // CHECK: %[[ELEM_0:.*]] = "tf.Const"() + // CHECK: {_ici_weight_distribution_mlir_bridge_marker = true} + // CHECK-NEXT: %[[SHAPE_0:.*]] = "tf.Const"() + // CHECK: {_ici_weight_distribution_mlir_bridge_marker = true} + // CHECK-NEXT: %[[FULL_0:.*]] = "tf.Fill"(%[[SHAPE_0]], %[[ELEM_0]]) {_ici_weight_distribution_mlir_bridge_marker = true} + // CHECK-NEXT: tf_device.replicate([%arg0, %[[FULL_0]], %[[FULL_0]], %[[FULL_0]]] as %[[REPVAR:.*]]: tensor) {n = 4 : i32} { + // CHECK-NEXT: %[[ID:.*]] = "tf_device.launch"() <{device = "TPU_REPLICATED_HOST_0"}> ({ + // CHECK-NEXT: %[[IDINSIDE:.*]] = "tf.Identity"(%[[REPVAR]]) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor) -> tensor + // CHECK-NEXT: tf_device.return %[[IDINSIDE]] : tensor + // CHECK-NEXT: }) {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor + // CHECK-NEXT: "tf_device.cluster"() ({ + // CHECK-NEXT: %[[GROUP:.*]] = "tf.Const"() + // CHECK-SAME: [0, 1, 2, 3] + // CHECK-NEXT: %[[REDUCED:.*]] = "tf.XlaAllReduce"(%[[ID]], %[[GROUP]]) <{mode = "CrossReplica", reduce_op = "Add"}> : (tensor, tensor<1x4xi32>) -> tensor + // CHECK-NEXT: "tf.OpA"(%[[REDUCED]]) : (tensor) -> () + tf_device.replicate {n = 4 : i32} { + "tf_device.cluster"() ({ + "tf.OpA"(%arg0) : (tensor) -> () + tf_device.return + }) {allow_soft_placement = false, computation_shape = [], device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1], host_compute_core = [], num_cores_per_replica = 2 : i64, num_replicas = 4 : i64, padding_map = [], step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\02\02\01\02\10\01\18\08\22 \00\00\00\00\00\00\00\01\01\00\00\00\01\00\00\01\00\01\00\00\00\01\00\01\01\01\00\00\01\01\00\01*\02\08\01", tpu_compile_options_proto = "", use_spmd_for_xla_partitioning = true, use_tpu = true} : () -> () + } + func.return +} +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD index fd0668055a3dc5..f134946abe8495 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD @@ -556,7 +556,6 @@ cc_library( "tpu_resource_read_for_write.cc", "tpu_space_to_depth_pass.cc", "tpu_update_embedding_enqueue_op_inputs.cc", - "tpu_validate_inputs.cc", "update_control_dependencies.cc", "verify_suitable_for_graph_export_pass.cc", "xla_call_module_deserialization.cc", @@ -614,6 +613,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:stablehlo_custom_call_utils", "//tensorflow/compiler/mlir/tensorflow:string_util", "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_op_interfaces", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "//tensorflow/compiler/mlir/tensorflow:tensorflow_side_effects", "//tensorflow/compiler/mlir/tensorflow:tensorflow_traits", @@ -687,7 +687,7 @@ cc_library( "@local_xla//xla:window_util", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla:xla_proto_cc", - "@local_xla//xla/client:sharding_builder", + "@local_xla//xla/hlo/builder:sharding_builder", "@local_xla//xla/mlir_hlo", "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_ops", @@ -826,9 +826,9 @@ cc_library( "@local_xla//xla:shape_util", "@local_xla//xla:window_util", "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:hlo_utils", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape", "@local_xla//xla/service:shape_inference", - "@local_xla//xla/translate/hlo_to_mhlo:hlo_utils", - "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape", "@local_xla//xla/tsl/util:env_var", ], ) @@ -976,8 +976,8 @@ cc_library( hdrs = ["tf_graph_optimization_pass.h"], deps = [ "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:import_model", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tf2xla/api/v2:graph_to_tf_executor", "//tensorflow/compiler/mlir/tf2xla/api/v2:tf_executor_to_graph", "//tensorflow/compiler/tf2xla:functionalize_control_flow_pass_registration", "//tensorflow/core:core_cpu", @@ -1069,10 +1069,10 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@local_xla//xla:shape_util", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape", "@local_xla//xla/mlir_hlo", "@local_xla//xla/stream_executor/tpu:c_api_conversions", "@local_xla//xla/stream_executor/tpu:tpu_api", - "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/hoist_loop_invariant.cc b/tensorflow/compiler/mlir/tensorflow/transforms/hoist_loop_invariant.cc index 91f14794494de7..2acf81dbcd78bc 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/hoist_loop_invariant.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/hoist_loop_invariant.cc @@ -18,15 +18,19 @@ limitations under the License. #include #include "llvm/ADT/DenseSet.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Support/DebugStringHelper.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" // from @llvm-project -#include "mlir/Transforms/RegionUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" namespace mlir { namespace TF { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/hoist_replicate_invariant_resource_writes.cc b/tensorflow/compiler/mlir/tensorflow/transforms/hoist_replicate_invariant_resource_writes.cc index ffc650cb9073bd..67c1d911889ac4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/hoist_replicate_invariant_resource_writes.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/hoist_replicate_invariant_resource_writes.cc @@ -17,11 +17,18 @@ limitations under the License. #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #define DEBUG_TYPE "tf-hoist-replicate-invariant-resource-writes" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc index 77f9361ab94a27..0130e7a63c70bc 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc @@ -74,7 +74,8 @@ class LowerClusterToRuntimeOpsTest : public ::testing::Test { env_ = Env::Default(); test_group_name_ = "TestGroup"; test_dir_ = testing::TmpDir(); - setenv("TF_DUMP_GRAPH_PREFIX", test_dir_.c_str(), /*overwrite=*/1); + setenv(/*name=*/"TF_DUMP_GRAPH_PREFIX", /*value=*/test_dir_.c_str(), + /*overwrite=*/1); } absl::Status CreateMlirModule(std::string mlir_module_filename) { @@ -179,8 +180,9 @@ TEST_F(LowerClusterToRuntimeOpsTest, DumpsPipelinePasses) { std::vector files; TF_ASSERT_OK(env_->GetChildren(test_dir_, &files)); EXPECT_THAT(files, ::testing::IsEmpty()); - setenv("TF_DUMP_GRAPH_NAME_FILTER", "*", /*overwrite=*/1); - setenv("TF_DUMP_GRAPH_GROUPS", "main,runtime_lowering", /*overwrite=*/1); + setenv(/*name=*/"TF_DUMP_GRAPH_NAME_FILTER", /*value=*/"*", /*overwrite=*/1); + setenv(/*name=*/"TF_DUMP_GRAPH_GROUPS", /*value=*/"main,runtime_lowering", + /*overwrite=*/1); DEBUG_DATA_DUMPER()->LoadEnvvars(); TF_ASSERT_OK(CreateMlirModule("basic_cluster.mlir")); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 36025277fd915b..54dad08e546e64 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -502,8 +502,6 @@ CreateConvertToLegacyCompileAndReplicateAttributesPass(); std::unique_ptr> CreateTPUPartitionedOpConversionPass(); -std::unique_ptr> CreateTPUValidateInputsPass(); - // Creates a pass that cleans up `_replication_info` attribute on operations // that are inside a cluster. std::unique_ptr> @@ -671,7 +669,6 @@ enum MoveTransposeDirection { kBegin, kEnd }; #define GEN_PASS_DECL_TPURESOURCEREADSWRITESPARTITIONINGPASS #define GEN_PASS_DECL_TPUSPACETODEPTHPASS #define GEN_PASS_DECL_TPUUPDATEEMBEDDINGENQUEUEOPINPUTSPASS -#define GEN_PASS_DECL_TPUVALIDATEINPUTSPASS #define GEN_PASS_DECL_TENSORARRAYOPSDECOMPOSITIONPASS #define GEN_PASS_DECL_TENSORDEVICECOPYCONVERSIONPASS #define GEN_PASS_DECL_TENSORFLOWOPTIMIZEPASS diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.cc b/tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.cc index 0eb552208194e1..6fa61e1bde3d93 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.cc @@ -27,12 +27,12 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/layout.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/shape.h" #include "xla/stream_executor/tpu/c_api_conversions.h" #include "xla/stream_executor/tpu/tpu_api.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index d3630226ed1f32..2cee935dc96f23 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -83,10 +83,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" #include "tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/service/shape_inference.h" #include "xla/shape.h" -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/tsl/util/env_var.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/BUILD index bff95d357c885f..d19d5e8e8ab5aa 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/BUILD @@ -8,7 +8,6 @@ package( "//tensorflow/compiler/mlir:__pkg__", "//tensorflow/compiler/mlir/tensorflow/transforms:__pkg__", "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:__pkg__", - "//tensorflow/compiler/mlir/tf2xla/api:__subpackages__", "//tensorflow/compiler/mlir/tf2xla/internal:__pkg__", ], licenses = ["notice"], diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc index 1ccfc8775d1c44..369840c888f4a2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc @@ -23,8 +23,8 @@ limitations under the License. #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/optimization_registry.h" @@ -106,8 +106,8 @@ void GraphOptPass::runOnOperation() { // Convert Graph to MLIR GraphDebugInfo debug_info; GraphImportConfig specs; - auto module_or_status = - ConvertGraphToMlir(**options.graph, debug_info, flib_def, specs, &ctx); + auto module_or_status = tensorflow::tf2xla::v2::ConvertGraphToTfExecutor( + **options.graph, debug_info, flib_def, specs, &ctx); if (!module_or_status.ok()) { mlir::emitError(mlir::UnknownLoc::get(&ctx)) << module_or_status.status().message(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td index fb1c541e56b938..f724b4a6443639 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td @@ -837,18 +837,6 @@ def TPUPartitionedOpConversionPass : Pass<"tf-tpu-partitioned-op-conversion", "m let constructor = "TFTPU::CreateTPUPartitionedOpConversionPass()"; } -def TPUValidateInputsPass : Pass<"tf-tpu-validate-inputs", "ModuleOp"> { - let summary = "Validates inputs to the TPU TF/XLA bridge"; - - let description = [{ - This pass checks that the IR has valid input to TPU TF/XLA bridge. - It checks the relations of multiple ops. Properties of single ops are - checked by the 'verify' method of ops. - }]; - - let constructor = "TFTPU::CreateTPUValidateInputsPass()"; -} - def ClusterConstantSinkingPass : Pass<"tf-device-constant-sinking", "mlir::func::FuncOp"> { let summary = "Sinks constants implicitly captured in a tf_device.cluster region."; @@ -2225,7 +2213,7 @@ def XlaCallModuleDeserializationPass "shape::ShapeDialect", "stablehlo::StablehloDialect", "vhlo::VhloDialect", - "quant::QuantizationDialect", + "quant::QuantDialect", ]; let constructor = "TF::CreateXlaCallModuleDeserializationPass()"; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc index b38f1b2d68ac94..1dd16ed7bc753c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc @@ -24,7 +24,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // IWYU pragma: keep #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/translate/BUILD b/tensorflow/compiler/mlir/tensorflow/translate/BUILD index 297b6f16cd8a79..168672d951c2ed 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/translate/BUILD @@ -1,4 +1,3 @@ -load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( @@ -17,7 +16,6 @@ cc_library( ], deps = [ ":mlir_roundtrip_flags", - ":node_order", ":upgrade_graph", "//tensorflow/cc/saved_model:bundle_v2", "//tensorflow/cc/saved_model:constants", @@ -42,6 +40,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow/transforms:mark_initialized_variables_lib", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes", + "//tensorflow/compiler/mlir/tf2xla/api/v2:graph_to_tf_executor", "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -69,20 +68,6 @@ cc_library( ], ) -tf_cc_test( - name = "tf_mlir_translate_registration_test", - size = "small", - srcs = ["tf_mlir_translate_registration_test.cc"], - deps = [ - ":translate_registration", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:TranslateLib", - ], -) - cc_library( name = "export_tf_dialect_op", srcs = [ @@ -252,44 +237,3 @@ cc_library( "@llvm-project//llvm:Support", ], ) - -cc_library( - name = "node_order", - srcs = ["node_order.cc"], - hdrs = ["node_order.h"], - deps = [ - "//tensorflow/core:core_cpu_base", - "//tensorflow/core:lib", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - ], -) - -tf_cc_test( - name = "node_order_test", - size = "small", - srcs = [ - "node_order_test.cc", - ], - deps = [ - ":node_order", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:cc_ops_internal", - "//tensorflow/cc:function_ops", - "//tensorflow/cc:sendrecv_ops", - "//tensorflow/core", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:direct_session_internal", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:ops", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - "@com_google_absl//absl/log:check", - "@com_google_googletest//:gtest", - ], -) diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 3a6572dbfa7f7c..5ed8645ed9752c 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -16,29 +16,22 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include -#include #include #include #include -#include #include -#include #include #include -#include #include #include -#include #include #include #include #include -#include "absl/algorithm/container.h" #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" @@ -47,14 +40,11 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" -#include "llvm/Support/Casting.h" #include "llvm/Support/LogicalResult.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" @@ -72,25 +62,18 @@ limitations under the License. #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/IR/ValueRange.h" // from @llvm-project #include "mlir/IR/Verifier.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/cc/saved_model/bundle_v2.h" #include "tensorflow/cc/saved_model/constants.h" #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/cc/saved_model/loader_util.h" -#include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/mark_initialized_variables.h" @@ -98,22 +81,18 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/node_order.h" #include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" +#include "tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.h" #include "xla/status_macros.h" #include "tensorflow/core/common_runtime/function_body.h" #include "tensorflow/core/common_runtime/function_def_utils.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_runner.h" -#include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" @@ -123,24 +102,18 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def_builder.h" -#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" -#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_debug_info_builder.h" #include "tensorflow/core/graph/graph_node_util.h" -#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/crash_analysis.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/path.h" -#include "tensorflow/core/platform/stack_frame.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/threadpool.h" #include "tensorflow/core/platform/types.h" @@ -149,23 +122,13 @@ limitations under the License. #include "tensorflow/core/protobuf/saver.pb.h" #include "tensorflow/core/protobuf/struct.pb.h" #include "tensorflow/core/protobuf/trackable_object_graph.pb.h" -#include "tensorflow/core/util/device_name_utils.h" -#include "tensorflow/core/util/dump_graph.h" #include "tsl/platform/errors.h" -#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" -static inline absl::string_view StringRefToView(llvm::StringRef ref) { - return {ref.data(), ref.size()}; -} - namespace tensorflow { constexpr size_t kNumThreadToConvertSignatures = 10; -constexpr absl::string_view kOutputShapesAttrName = "_output_shapes"; -using ::mlir::NamedAttrList; -using ::mlir::TensorType; using ::mlir::tf_saved_model::AssetOp; using ::mlir::tf_saved_model::GlobalTensorOp; using ::mlir::tf_saved_model::kTfSavedModelExportedNamesAttr; @@ -178,18 +141,6 @@ using ::tsl::StatusOr; namespace { -bool IsOutputShapesAttribute(const AttrValue& attr_value, - llvm::StringRef attr_name) { - return attr_name.compare(kOutputShapesAttrName) == 0 && - attr_value.value_case() == AttrValue::kList; -} - -bool IsResourceOutputShapesAttribute(const AttrValue& attr_value, - llvm::StringRef attr_name) { - if (attr_name == "_handle_dtypes" || attr_name == "_handle_shapes") - return attr_value.value_case() == AttrValue::kList; - return false; -} void LoadImporterDialects(mlir::MLIRContext& context) { // Load dialects involved in the conversion @@ -239,286 +190,6 @@ class NameUniquifier : public OpOrArgNameMapper { const FunctionLibraryDefinition& flib_; }; -// Stateful helper class to import a TensorFlow model into an MLIR Module. -// -// This is the base class that contains common utilities shared between the -// GraphDef importer and SavedModel importer. -// -// A subclass is expected to call `PrepareConvert` first to perform necessary -// preparation over the graph and also certain internal bookkeeping data. -// Afterwards the other protected methods can be called. -class ImporterBase { - protected: - explicit ImporterBase( - const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info, - const GraphImportConfig& specs, mlir::ModuleOp module, - std::unordered_map* tf_name_to_mlir_name, - NameUniquifier* function_name_uniquifier, - llvm::StringRef function_name_for_debug_info = "") - : builder_(module.getContext()), - module_(module), - context_(module.getContext()), - tf_name_to_mlir_name_(tf_name_to_mlir_name), - graph_flib_(flib), - specs_(specs), - debug_info_(debug_info), - function_name_for_debug_info_(function_name_for_debug_info), - function_name_uniquifier_(function_name_uniquifier), - error_handler_(module.getContext()) { - // Log import config. - if (VLOG_IS_ON(1)) { - LOG(INFO) << "Importing with: " << specs.str(); - for (auto& it : *tf_name_to_mlir_name) { - LOG(INFO) << "\t" << it.first << " -> " << it.second; - } - } - - stack_traces_ = LoadTracesFromDebugInfo(debug_info_); - } - - // Returns the inferred function signature of the given function body. Input - // types are unranked tensor of the respective datatype in the function and - // result types are inferred by the shape_refiner_. Result types need not be - // unranked tensors and could be ranked tensors in cases where result type - // depends on an op with static output shape like tf.Const. - absl::StatusOr InferLibFunctionType( - const FunctionBody& fbody); - - // Extracts arg and ret nodes from FunctionBody. - void GetArgsAndRetsFromFunctionBody( - const FunctionBody& fbody, - absl::InlinedVector* arg_nodes, - absl::InlinedVector* ret_nodes, - absl::InlinedVector* control_ret_nodes); - - // Prepares converting the graph to an MLIR module. This step removes the - // backedges of the graph, orders the nodes and infers the shapes. - // PrepareConvert needs to ensure that the original `graph` is cloned prior - // execution. The cloning procedure relies on the roundtrip through the - // GraphDef. Graph to GraphDef def conversion is heavy, in case, `graph_def` - // was obtained previously provide it to the PrepareConvert to reuse. - Status PrepareConvert(const Graph& graph, - std::unique_ptr graph_def = nullptr); - - // Converts the prepared graph to a Function and adds it to the module. A set - // of nodes from the graph are given to converted to the arguments and returns - // of the function. - Status Convert(llvm::StringRef func_name, mlir::FunctionType func_type, - const absl::InlinedVector& arg_nodes, - const absl::InlinedVector& ret_nodes, - const absl::InlinedVector& control_ret_nodes, - llvm::ArrayRef attrs); - - // Finds out the function definition for the given function name from the - // graph and converts it to a function of the module. This method is called - // on demand because the graph flib_def does not provide an iterator - // interface. - Status ConvertLibFunction(llvm::StringRef func_name); - - // Returns the list of nodes in the graph. Nodes are presented in the reverse - // order of a post-order depth-first visit starting from the graph's source - // nodes. - llvm::ArrayRef GetOrderedNodes() const { return ordered_nodes_; } - - // Returns the inferred input type at index `idx` of the `node` in the - // context. - absl::StatusOr InferInputType(const Node& node, int idx, - mlir::Builder builder); - - // Returns the inferred output type at index `idx` of the `node` in the - // context. - absl::StatusOr InferOutputType(const Node& node, int idx, - mlir::Builder builder); - - // Convert deferred TF functions to the MLIR representation. - // Conversion is deferred for efficiency reasons, e.g., to limit depth - // of recursion and reduce stack size pressure. - Status ConvertDeferredFunctions(); - - private: - // Most types with subtypes have only one subtype. - using ElementSubtypes = llvm::SmallVector; - - // Metadata used for deferred function conversion. - struct DeferredConversionMetaData { - DeferredConversionMetaData( - const std::string& function_name, - const std::vector& attributes) - : function_name(function_name), attributes(attributes) {} - - std::string function_name; - std::vector attributes; - }; - - // Adds all the ordered_nodes to the shape refiner shape_refiner_. Then all - // data type and shape information is maintained by the shape_refiner_. - // TODO(jpienaar): Remove once shape inference on import is removed. - Status AddNodesToShapeRefiner( - std::unordered_map* node_name_map); - - // Prune nodes that do not feed into fetch nodes. - Status PruneUnreachableNodes( - std::unordered_map* node_name_map); - - // Converts feeds to Placeholder nodes. - Status ConvertFeedsToPlaceholders( - std::unordered_map* node_name_map); - - // Converts the inferred shape referred to by 'handle' in 'context', with - // given element type, and returns an MLIR tensor type. - absl::StatusOr ConvertDataTypeAndShape( - DataType dtype, const shape_inference::ShapeHandle& handle, - const std::vector* handle_subtypes, - shape_inference::InferenceContext* context, mlir::Builder builder); - - // Converts the inferred shape referred to by 'handle' in 'context', with - // given element type, and returns an MLIR tensor type. - absl::StatusOr ConvertElementTypeAndShape( - mlir::Type element_type, const shape_inference::ShapeHandle& handle, - shape_inference::InferenceContext* context, mlir::Builder builder); - - // Converts the inferred subtypes for an element type to corresponding MLIR - // types in 'context'. - absl::StatusOr ConvertSubtypes( - const std::vector* handle_subtypes, - shape_inference::InferenceContext* context, mlir::Builder builder); - - // Converts the tensor proto into an MLIR elements attribute. - absl::StatusOr ConvertTensorProto( - const TensorProto& value) { - return ::tensorflow::ConvertTensorProto(value, &builder_); - } - - // Converts func name in graphdef to mlir::SymbolRefAttribute. - absl::StatusOr ConvertFunctionCallName( - const std::string& func_name); - - // Converts the given non-function-call AttrValue to an MLIR Attribute. - absl::StatusOr ConvertAttributeValue(const AttrValue& value); - - // Converts the given function-call AttrValue to MLIR Attributes and pushes - // them to the given attributes list. For example, if there is a kFunc - // AttrValue {name : foo, attrs : {k1 : bar, k2 : rfc}}, it will convert it to - // a list of MLIR Attributes: {{base_name : foo}, {base_name.k1 : bar}, - // {base_name.k2 : rfc}}. - Status ConvertFunctionCallAttribute(const std::string& base_name, - const AttrValue& value, - NamedAttrList* attributes); - - // Helper to create either a tf_executor operation or a TF operation wrapped - // in an island. - mlir::Operation* CreateOperation( - const Node& node, llvm::StringRef node_type_name, - const mlir::OperationState& result, - const llvm::SmallVectorImpl& control_operands); - - // Converts one NodeDef from the input GraphDef into an Operation and - // inserts it into the MLIR module using builder_. - Status ConvertNode(const Node& node); - - // If the input graph represents a while-loop, the edges pointing from a - // "NextIteration" node to a "Merge" node add cyclic dependencies and make the - // topological sorting impossible. We need to remove these edges from the - // input graph to infer shapes and construct a Function. For each - // "NextIteration" node, there are two operations, "NextIteration.source" - // and "NextIteration.sink" are added to the MLIR module. - using BackEdge = BackEdgeHelper::BackEdge; - - // Removes backedges from the input graph. The removed edges are added back to - // to OpBuilder after the remaining graph is converted to the Function. - Status RemoveBackedges(); - - // Restores backedges removed during shape inference to the final Function. - Status AddBackedges(); - - // Restores a single backedge in the Function by adding a replicated - // operation before the dst operation. - Status AddBackedge(mlir::Operation* sink, mlir::Operation* dst, - int dst_input); - - // Adds the input arguments and return operation to the function. The - // arguments are added as basic block argument. Also the argument types and - // the id of the nodes from the input graph needs to be specified. - Status ConvertFunctionArgAndRets( - mlir::func::FuncOp func, mlir::tf_executor::GraphOp graph_op, - llvm::ArrayRef arg_types, - const absl::InlinedVector& arg_nodes, - const absl::InlinedVector& ret_nodes, - const absl::InlinedVector& control_ret_nodes); - - // Gets the location information of the given node. It uses the - // "original_node_name" in the NodeDef to get the corresponding file location - // (FileLineColLoc) from the input DebugInfo and returns an CallSiteLoc. If - // there are multiple "original_node_names", a FusedLoc is returned. If the - // node name couldn't be found in the input DebugInfo, a NameLoc is used as - // the location. - mlir::Location GetLocation(const Node& node); - - // Appends the location string for the node to the error message and returns - // the combined error status. - Status EmitErrorWithLocationStr(const Node& node, const Status& error_status); - - // Inserts a placeholder node in the graph to replace a feed output tensor, - // and returns the new placeholder node and a boolean indicating if the - // original input node was removed from the graph. Uses of the feed output - // tensor are replaced with this placeholder node. If the feed output tensor - // is of a single output node, the control dependencies are forwarded to the - // the placeholder node, and the original node will be removed. - // Note: This modifies the graph, and so any list of ordered nodes needs to be - // reconstructed. - absl::StatusOr> CreatePlaceholderNodeForFeed( - const TensorShapeProto& shape, DataType dtype, Node* node, int index, - const std::unordered_map& node_name_map); - - // Gets the input and output nodes corresponding to the specified input and - // output nodes in specs_. If there are no input or output nodes specified, - // nodes will be empty. - Status GetInputOutputNodes( - const std::unordered_map& node_name_map, - std::unordered_set* nodes); - - // The input graph with backedges removed. The removed backedges are stored - // in the back_edge_helper. - BackEdgeHelper back_edge_helper_; - // A map between node and output index, for each backedge. - absl::flat_hash_map back_edge_node_output_; - absl::flat_hash_map back_edge_dst_inputs_; - // A map between sink and source operation of NextIteration - absl::flat_hash_map - next_iteration_sink_source_; - - // All nodes and version information about the (copied) imported graph. - std::unique_ptr graph_; - std::vector ordered_nodes_; - - // Maps from a Node ID to a MLIR value. - using NodeValueMap = absl::flat_hash_map; - - mlir::OpBuilder builder_; - mlir::ModuleOp module_; - mlir::MLIRContext* context_; - std::unordered_map* tf_name_to_mlir_name_; - const FunctionLibraryDefinition& graph_flib_; - const GraphImportConfig& specs_; - const GraphDebugInfo& debug_info_; - StackTracesMap stack_traces_; - llvm::StringRef function_name_for_debug_info_; - NodeValueMap node_values_; - // TODO(jpienaar): Remove once shape inference on import is removed. - // The shape_refinner_ will be nullptr if shape inference on import is - // not enabled. - std::unique_ptr shape_refiner_ = nullptr; - NameUniquifier* function_name_uniquifier_; - mlir::StatusScopedDiagnosticHandler error_handler_; - // All the TF ops encountered that aren't modelled in dialect. - llvm::DenseSet unmodelled_op_names_; - - protected: - // Maps feed as TensorId to new Placeholder node name. - absl::flat_hash_map remapped_feeds_; - // Keep track of functions required deferred conversion. - std::queue deferred_functions_; -}; // Returns true if the node with given name has a non primary output that is // used by some other node as an input. Returns false if no outputs are in use @@ -596,2231 +267,8 @@ Status PreprocessGraphDef(const GraphImportConfig* specs, GraphDef* graph_def) { return absl::OkStatus(); } -// Mapping from node name to feed (index and ArrayInfo). Node name must outlive -// this map. -using FeedsByNode = absl::flat_hash_map< - absl::string_view, - absl::flat_hash_map*>>; - -// Creates from a `GraphImportConfig::InputArrays` a mapping from a feeds output -// tensor name to index and ArrayInfo. Keys and values are backed by -// `GraphImportConfig::InputArrays`. -absl::StatusOr GetFeedsByNode( - const GraphImportConfig::InputArrays& inputs) { - FeedsByNode feeds_by_node; - feeds_by_node.reserve(inputs.size()); - - for (const auto& input : inputs) { - TensorId tensor = ParseTensorName(input.first); - if (tensor.index() < 0) - return errors::FailedPrecondition( - "Feed output tensor must be a data output '", tensor.ToString(), "'"); - - auto& node = feeds_by_node[tensor.node()]; - if (!node.insert({tensor.index(), &input}).second) - return errors::FailedPrecondition( - "Multiple feeds for the same output tensor '", tensor.ToString(), - "'"); - } - - return feeds_by_node; -} - -// Creates a unique name for a node that will be replacing a feed output tensor. -std::string GetUniqueNodeName( - absl::string_view node_name, int index, - const std::unordered_map& node_name_map) { - std::string new_node_name_base = absl::StrCat(node_name, "_", index); - int count = 0; - std::string new_node_name = new_node_name_base; - while (node_name_map.find(new_node_name) != node_name_map.end()) { - new_node_name = absl::StrCat(new_node_name_base, "_", count++); - } - return new_node_name; -} - -Status ImporterBase::ConvertDeferredFunctions() { - while (!deferred_functions_.empty()) { - auto conversion_metadata = deferred_functions_.front(); - deferred_functions_.pop(); - - const FunctionDef* func_def = - graph_flib_.Find(conversion_metadata.function_name); - // Converts the graph to an MLIR function and adds it to the module. - // We populate the NodeSpec so that all the _Arg ops get their shape - // added correctly. - GraphImportConfig specs; - specs.enable_shape_inference = specs_.enable_shape_inference; - specs.unconditionally_use_set_output_shapes = - specs_.unconditionally_use_set_output_shapes; - for (const auto& name_and_value : func_def->attr()) { - if (name_and_value.first == "_input_shapes") { - auto& list = name_and_value.second.list(); - auto& signature = func_def->signature(); - // Some models have "_input_shapes" attribute, but with its value empty - if (list.shape_size() > 0 && - list.shape_size() != signature.input_arg_size()) { - return errors::FailedPrecondition( - "Number of input arguments must be equal to the length of " - "_input_shapes attribute in function '", - StringRefToView(conversion_metadata.function_name), "'."); - } - for (int i = 0, e = signature.input_arg_size(); i < e; i++) { - auto& input_arg = signature.input_arg(i); - auto& array_info = specs.inputs[input_arg.name()]; - array_info.imported_dtype = input_arg.type(); - // set to unranked for empty "_input_shapes" attribute - if (list.shape_size() > 0) - array_info.shape = list.shape(i); - else - array_info.shape.set_unknown_rank(true); - } - } - } - - ImporterBase importer(graph_flib_, debug_info_, specs, module_, - tf_name_to_mlir_name_, function_name_uniquifier_, - conversion_metadata.function_name); - - std::unique_ptr fbody; - TF_RETURN_IF_ERROR( - FunctionDefToBodyHelper(*func_def, AttrSlice(), &graph_flib_, &fbody)); - TF_RETURN_IF_ERROR(importer.PrepareConvert(*fbody->graph)); - - TF_ASSIGN_OR_RETURN(auto func_type, importer.InferLibFunctionType(*fbody)); - - absl::InlinedVector arg_nodes; - absl::InlinedVector ret_nodes; - absl::InlinedVector control_ret_nodes; - importer.GetArgsAndRetsFromFunctionBody(*fbody, &arg_nodes, &ret_nodes, - &control_ret_nodes); - const std::string& mlir_func_name = - (*tf_name_to_mlir_name_)[conversion_metadata.function_name]; - - TF_RETURN_IF_ERROR(importer.Convert(mlir_func_name, func_type, arg_nodes, - ret_nodes, control_ret_nodes, - conversion_metadata.attributes)); - - // Additional function bodies could be discovered during the deferred - // loading of the current function. Add them to the working queue. - while (!importer.deferred_functions_.empty()) { - deferred_functions_.push(importer.deferred_functions_.front()); - importer.deferred_functions_.pop(); - } - } - - return absl::OkStatus(); -} - -Status ImporterBase::RemoveBackedges() { - // Remove all the backedges. So the nodes can be added to the shape refiner. - TF_RETURN_IF_ERROR(back_edge_helper_.Remove(graph_.get())); - VLOG(1) << "Found " << (back_edge_helper_.RemovedEdges().size()) - << " backedges."; - - // Creates a map for quickly identifying whether a node output is a backedge. - for (const auto& edge : back_edge_helper_.RemovedEdges()) { - if (back_edge_node_output_.find(edge.src) != back_edge_node_output_.end() && - back_edge_node_output_[edge.src] != edge.src_output) { - return errors::FailedPrecondition( - "More than one of the src node outputs are backedges!"); - } - back_edge_node_output_[edge.src] = edge.src_output; - // We expect a merge to receive a single backedge (multiple NextIteration - // nodes feeding into the same merge is unexpected here). - DCHECK(!back_edge_dst_inputs_.contains(edge.dst)); - back_edge_dst_inputs_[edge.dst] = edge; - } - - // Obtains a RPO ordering, using node names as a tiebreak for stable sorting. - - ordered_nodes_.clear(); - TopologicalOrdering( - *graph_, [&](Node* n) { ordered_nodes_.push_back(n); }, GroupByDevice()); - return absl::OkStatus(); -} - -Status CopyStackTraces(const Graph& from, Graph* to) { - // Copy over the stack traces. - // TODO(jpienaar): This really shouldn't be needed, copying the Graph above - // and then needing these traversals is unfortunate. - std::unordered_map node_map = from.BuildNodeNameIndex(); - for (Node* node : to->nodes()) { - if (const Node* old_node = node_map[node->name()]) { - if (const std::shared_ptr& stack = - old_node->GetStackTrace()) { - DVLOG(2) << "Stack for " << node->name() << " " - << old_node->GetStackTrace()->ToString( - AbstractStackTrace::TracePrintingOptions()); - node->SetStackTrace(stack); - } else { - DVLOG(1) << "No stack for " << node->name() << " (" << node - << ") in Graph " << &from; - } - } else { - DVLOG(1) << "No stack for " << node->name() << " (" << node - << ") in Graph " << &from; - } - } - - return absl::OkStatus(); -} - -absl::StatusOr> -ImporterBase::CreatePlaceholderNodeForFeed( - const TensorShapeProto& shape, DataType dtype, Node* node, int index, - const std::unordered_map& node_name_map) { - DCHECK_LT(index, node->num_outputs()); - const bool update_inplace = node->num_outputs() == 1 && index == 0; - std::string new_node_name = - update_inplace ? node->name() - : GetUniqueNodeName(node->name(), index, node_name_map); - - Node* placeholder_node; - NodeBuilder builder(new_node_name, "Placeholder"); - builder.Attr("shape", shape); - builder.Attr("dtype", dtype); - TF_RETURN_IF_ERROR(builder.Finalize(graph_.get(), &placeholder_node)); - - // Update edges from original feed with Placeholder node. - std::vector data_edges; - std::vector control_edges; - for (const tensorflow::Edge* edge : node->out_edges()) { - if (edge->src_output() == index) { - data_edges.push_back(edge); - } else if (update_inplace && edge->IsControlEdge()) { - control_edges.push_back(edge); - } - } - - for (const auto* edge : data_edges) { - TF_RETURN_IF_ERROR(graph_->UpdateEdge(placeholder_node, 0, edge->dst(), - edge->dst_input())); - } - - // TODO(lyandy): Preserve control dependencies properly by not forwarding - // control dependencies to data outputs and not removing single output nodes. - // When a data output is replaced as a feed, unless there is another non feed - // data output or an explicit control output used by the same node, transitive - // control dependencies are not to be executed. For single output nodes, - // Placeholders can be converted to a NoOp if there are no uses, and - // PlaceholderWithDefault can be converted to an Identity. - for (const auto* edge : control_edges) { - graph_->AddControlEdge(placeholder_node, edge->dst()); - graph_->RemoveControlEdge(edge); - } - - if (update_inplace) { - graph_->RemoveNode(node); - } - - return std::pair(placeholder_node, update_inplace); -} - -Status ImporterBase::GetInputOutputNodes( - const std::unordered_map& node_name_map, - std::unordered_set* nodes) { - auto add_node = [&](absl::string_view name) { - auto it = node_name_map.find(std::string(name)); - if (it == node_name_map.end()) { - return errors::FailedPrecondition( - absl::StrCat("Graph does not contain node: ", name)); - } - nodes->insert(it->second); - return absl::OkStatus(); - }; - - // Remap feeds and fetches to newly created Placeholder nodes. - for (const auto& input : specs_.inputs) { - TensorId tensor = ParseTensorName(input.first); - auto remapped_it = remapped_feeds_.find(tensor); - if (remapped_it != remapped_feeds_.end()) { - TF_RETURN_IF_ERROR(add_node(remapped_it->second)); - } else { - TF_RETURN_IF_ERROR(add_node(tensor.node())); - } - } - - for (const auto& output : specs_.outputs) { - TensorId tensor = ParseTensorName(output); - auto remapped_it = remapped_feeds_.find(tensor); - if (remapped_it != remapped_feeds_.end()) { - TF_RETURN_IF_ERROR(add_node(remapped_it->second)); - } else { - TF_RETURN_IF_ERROR(add_node(tensor.node())); - } - } - - for (const auto& control_output : specs_.control_outputs) - TF_RETURN_IF_ERROR(add_node(control_output)); - - return absl::OkStatus(); -} - -// TODO(jpienaar): Remove this post shape inference on import flag is removed. -Status ImporterBase::AddNodesToShapeRefiner( - std::unordered_map* node_name_map) { - shape_refiner_ = - std::make_unique(graph_->versions(), graph_->op_registry()); - // Some operations (for example "TPUExecute") don't have shape inference - // function defined, so we should set this to false for adding nodes with - // these types of operations. - shape_refiner_->set_require_shape_inference_fns(false); - shape_refiner_->set_function_library_for_shape_inference(&graph_flib_); - - TF_ASSIGN_OR_RETURN(auto feeds_by_node, GetFeedsByNode(specs_.inputs)); - - // First add all nodes to the refiner. - for (Node* node : ordered_nodes_) { - // We need to use a TensorFlow node to teach the shape refiner that user - // specifies certain data type and shape for the inputs in the `specs_`. - // This node shouldn't have any inputs, only have one output and its - // output type/shape is only determined by its "named" attributes. (The - // attributes should have fixed names so we can use the info from `specs_` - // to set the value of them.) `Placeholder` satisfies these constraints. - // - // Therefore, if the input node isn't a `Placeholder`, we create one and use - // it to replace the original input node, so the shape refiner can - // successfully propagate the user's input type and shape to the rest of the - // graph. - bool node_added_to_shape_refiner = false; - auto it = feeds_by_node.find(node->name()); - if (it != feeds_by_node.end()) { - auto op_name = node->op_def().name(); - if (op_name != "Placeholder" && op_name != "LegacyFedInput" && - op_name != FunctionLibraryDefinition::kArgOp) { - for (const auto& output_tensor : it->second) { - const int index = output_tensor.first; - const ArrayInfo& array_info = output_tensor.second->second; - - DataType dtype = array_info.imported_dtype; - // Uses the existing output type if it isn't specified by the user. - if (dtype == DT_INVALID) { - dtype = node->output_type(index); - } - - TF_ASSIGN_OR_RETURN( - auto placeholder_node_and_removed, - CreatePlaceholderNodeForFeed(array_info.shape, dtype, node, index, - *node_name_map)); - - Node* placeholder_node = placeholder_node_and_removed.first; - if (placeholder_node_and_removed.second) { - // Original node has been removed from the graph. - node = placeholder_node; - node_added_to_shape_refiner = true; - } - remapped_feeds_[{it->first, index}] = placeholder_node->name(); - (*node_name_map)[placeholder_node->name()] = placeholder_node; - // Add the new placeholder node to the shape refiner. - Status status = shape_refiner_->AddNode(placeholder_node); - if (!status.ok()) { - return EmitErrorWithLocationStr(*placeholder_node, status); - } - } - } else { - auto index_it = it->second.find(0); - if (index_it == it->second.end()) { - return errors::FailedPrecondition( - "Missing feed output tensor at index 0 for node '", node->name(), - "'"); - } - node->AddAttr("shape", index_it->second->second.shape); - DataType dtype = index_it->second->second.imported_dtype; - // Uses the existing output type if it isn't specified by the user. - if (dtype == DT_INVALID) { - dtype = node->output_type(0); - } - node->AddAttr("dtype", dtype); - } - } - if (!node_added_to_shape_refiner) { - // Add the node to the shape refiner if the node hasn't been removed. - Status status = shape_refiner_->AddNode(node); - if (!status.ok()) { - return EmitErrorWithLocationStr(*node, status); - } - } - - auto set_shape_from_list_attr = [&](const AttrValue* attr) { - auto& list = attr->list(); - // This follows the same approach as in ValidateShape, but only flags - // warning in case where there are mismatch in number of shapes and - // outputs and in which case it just returns without attempting to refine. - if (list.shape_size() != node->num_outputs()) { - LOG(WARNING) << "Node '" << node->name() << "' has " - << node->num_outputs() << " outputs but the " - << kOutputShapesAttrName - << " attribute specifies shapes for " << list.shape_size() - << " outputs"; - return absl::OkStatus(); - } - - for (const auto& shape : llvm::enumerate(list.shape())) { - auto* node_context = shape_refiner_->GetContext(node); - shape_inference::ShapeHandle handle; - Status status = - node_context->MakeShapeFromShapeProto(shape.value(), &handle); - if (!status.ok()) { - return EmitErrorWithLocationStr(*node, status); - } - node_context->set_output(shape.index(), handle); - } - return absl::OkStatus(); - }; - - // If it is the argument node, the shape handle is set explicitly, so it - // can be propagated to the body nodes of the function. - if (StringPiece(node->type_string()) == FunctionLibraryDefinition::kArgOp) { - auto* node_context = shape_refiner_->GetContext(node); - DCHECK(node_context != nullptr); - if (const AttrValue* attr = node->attrs().Find("shape")) { - shape_inference::ShapeHandle handle; - Status status = - node_context->MakeShapeFromShapeProto(attr->shape(), &handle); - if (!status.ok()) { - return EmitErrorWithLocationStr(*node, status); - } - node_context->set_output(0, handle); - } else if (const AttrValue* attr = - node->attrs().Find(kOutputShapesAttrName)) { - TF_RETURN_IF_ERROR(set_shape_from_list_attr(attr)); - } else { - node_context->set_output(0, node_context->UnknownShape()); - } - } - - // Following GraphConstructor::ValidateShape called from - // GraphConstructor::Convert, override the shape if _output_shapes is set. - if (specs_.unconditionally_use_set_output_shapes || - node->op_def().name() == "ReadVariableOp") { - if (const AttrValue* attr = node->attrs().Find(kOutputShapesAttrName)) - TF_RETURN_IF_ERROR(set_shape_from_list_attr(attr)); - } - } - - // Since we might have inserted and removed nodes from the graph, fix - // source/sink edges and reconstruct the RPO ordering of nodes - FixupSourceAndSinkEdges(graph_.get()); - - // Prune nodes in the graph that are not reachable from the output. - if (specs_.prune_unused_nodes) { - std::unordered_set prune_start; - TF_RETURN_IF_ERROR(GetInputOutputNodes(*node_name_map, &prune_start)); - if (!prune_start.empty()) { - if (PruneForReverseReachability(graph_.get(), prune_start)) { - VLOG(1) << "Pruned unused nodes in graphdef"; - } else { - VLOG(1) << "No unused nodes in graphdef to prune"; - } - } else { - VLOG(1) << "No output nodes specified, skipping pruning"; - } - } else { - VLOG(1) << "Pruning unused nodes in graphdef is disabled"; - } - - // Re-initialize ordered_nodes_ since we might have modified the graph. - ordered_nodes_.clear(); - TopologicalOrdering( - *graph_, [&](Node* n) { ordered_nodes_.push_back(n); }, GroupByDevice()); - - VLOG(1) << "Inferring graph shapes to fixpoint"; - - // The "changed" information from UpdateNode can give false positives, so we - // create a dedicated method to verify the shapes are not changed before and - // after the shape refine. - auto same_inferred_shape = [](shape_inference::InferenceContext* c, - shape_inference::ShapeHandle s0, - shape_inference::ShapeHandle s1) -> bool { - if (s0.SameHandle(s1) || (!c->RankKnown(s0) && !c->RankKnown(s1))) { - return true; - } - if (c->Rank(s0) != c->Rank(s1)) { - return false; - } - for (int i = 0; i < c->Rank(s0); ++i) { - if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) { - int64_t val0 = c->Value(c->Dim(s0, i)); - int64_t val1 = c->Value(c->Dim(s1, i)); - // Negative value is treated as unknown so all negative values indicate - // the same dimension. - if (val0 >= 0 && val1 >= 0 && val0 != val1) return false; - } - } - return true; - }; - - bool changed = true; - int i = 0; - const int kMaxIterationCount = 2; - while (changed && i != kMaxIterationCount) { - changed = false; - for (const Node* node : ordered_nodes_) { - auto* shape_context = shape_refiner_->GetContext(node); - DCHECK(shape_context != nullptr); - absl::InlinedVector existing; - existing.reserve(shape_context->num_outputs()); - for (int o = 0; o < shape_context->num_outputs(); ++o) { - existing.push_back(shape_context->output(o)); - } - bool inferred = false; - shape_inference::ShapeHandle handle; - Status status = - shape_refiner_->UpdateNode(node, /*relax=*/false, &inferred); - if (!status.ok()) { - return EmitErrorWithLocationStr(*node, status); - } - for (int o = 0; o < shape_context->num_outputs(); ++o) { - if (!same_inferred_shape(shape_context, shape_context->output(o), - existing[o])) { - changed = true; - break; - } - } - } - ++i; - } - if (i >= kMaxIterationCount) { - LOG(WARNING) << "Graph shapes did not converge to a fixpoint within " - << kMaxIterationCount - << " iterations. Graph shapes may be conservative."; - } - VLOG(1) << "Graph shapes were inferred with " << (i - 1) - << " extra rounds of analysis to reach a fixpoint."; - return absl::OkStatus(); -} - -absl::StatusOr ImporterBase::InferInputType(const Node& node, - int idx, - mlir::Builder builder) { - if (specs_.enable_shape_inference) { - // TODO(jpienaar): Remove this if shape inference on import flag is removed. - auto* context = shape_refiner_->GetContext(&node); - DataType dtype = node.input_type(idx); - return ConvertDataTypeAndShape(dtype, context->input(idx), - context->input_handle_shapes_and_types(idx), - context, builder); - } - DataType dtype = node.properties()->input_types[idx]; - mlir::Type element_type; - TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &element_type)); - return mlir::UnrankedTensorType::get(element_type); -} - -absl::StatusOr ImporterBase::InferOutputType( - const Node& node, int idx, mlir::Builder builder) { - DataType dtype = node.properties()->output_types[idx]; - - // Returns output type given inference context. - auto shape_ic = - [&](shape_inference::InferenceContext* c) -> absl::StatusOr { - // TODO(b/200093974): Post triage, consider following - // GraphConstructor::ValidateShape in checking _output_shapes always. - if (specs_.unconditionally_use_set_output_shapes) { - if (const AttrValue* attr = node.attrs().Find(kOutputShapesAttrName)) { - auto& list = attr->list(); - if (list.shape_size() > idx) { - const TensorShapeProto& p = list.shape()[idx]; - shape_inference::ShapeHandle h; - Status s = c->MakeShapeFromShapeProto(p, &h); - if (!s.ok()) - return errors::InvalidArgument( - "Node '", node.name(), " has an invalid ", - kOutputShapesAttrName, " attribute (shape #", idx, " error:'", - s.message(), "')"); - c->set_output(idx, h); - } - } - } - - return ConvertDataTypeAndShape(dtype, c->output(idx), - c->output_handle_shapes_and_types(idx), c, - builder); - }; - - if (specs_.enable_shape_inference) { - // TODO(jpienaar): Remove this if shape inference on import flag is removed. - shape_inference::InferenceContext* shape_context = - shape_refiner_->GetContext(&node); - return shape_ic(shape_context); - } - - // Treat TensorList init ops specially here as the op requires knowing its - // element dtype. - // TODO(jpienaar): Reconsider post refactoring shape functions. - if (node.type_string() == "TensorListReserve" || - node.type_string() == "EmptyTensorList") { - mlir::Type etype; - if (auto element_dtype = node.attrs().Find("element_dtype")) { - TF_RETURN_IF_ERROR( - ConvertDataType(element_dtype->type(), builder, &etype)); - } - return GetTypeFromTFTensorShape( - {}, mlir::TF::VariantType::get({mlir::UnrankedTensorType::get(etype)}, - etype.getContext())); - } - - if (node.IsWhileNode()) { - auto* output_shapes = node.attrs().Find("output_shapes"); - auto* element_types = node.attrs().Find("T"); - if (output_shapes && !output_shapes->list().shape().empty()) { - const auto& output_shape = output_shapes->list().shape(idx); - const auto& element_type = element_types->list().type(idx); - return ConvertToMlirTensorType(output_shape, element_type, &builder); - } - } - - auto type_from_array_attr = [&node, &idx, &builder]( - absl::string_view output_shape_attr, - absl::string_view element_type_attr) { - auto* output_shapes = node.attrs().Find(output_shape_attr); - auto* element_types = node.attrs().Find(element_type_attr); - const auto& output_shape = output_shapes->list().shape(idx); - const auto& element_type = element_types->list().type(idx); - return ConvertToMlirTensorType(output_shape, element_type, &builder); - }; - - if (node.type_string() == "IteratorGetNext" || - node.type_string() == "IteratorGetNextSync" || - node.type_string() == "MultiDeviceIteratorGetNextFromShard") - return type_from_array_attr("output_shapes", "output_types"); - - if (node.type_string() == "InfeedDequeueTuple") - return type_from_array_attr("shapes", "dtypes"); - - if (node.type_string() == "InfeedDequeue") { - assert(idx == 0); - const auto& output_shape = node.attrs().Find("shape")->shape(); - const auto& element_type = node.attrs().Find("dtype")->type(); - return ConvertToMlirTensorType(output_shape, element_type, &builder); - } - - // Returns a simple, more conservative unranked tensor type. - auto default_type = [&]() -> absl::StatusOr { - mlir::Type element_type; - TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &element_type)); - - // TODO(b/200093974): Post triage, consider following - // GraphConstructor::ValidateShape in checking _output_shapes. - if (specs_.unconditionally_use_set_output_shapes) { - if (const AttrValue* attr = node.attrs().Find(kOutputShapesAttrName)) { - auto& list = attr->list(); - if (list.shape_size() > idx) { - llvm::SmallVector shape; - const TensorShapeProto& shape_proto = list.shape()[idx]; - if (shape_proto.unknown_rank()) - return mlir::UnrankedTensorType::get(element_type); - TF_RETURN_IF_ERROR(ConvertToMlirShape(shape_proto, &shape)); - return GetTypeFromTFTensorShape(shape, element_type); - } - } - } - - return mlir::UnrankedTensorType::get(element_type); - }; - - // Below we only try and do some shape inference for "source" ops which have - // no inputs. - if (node.num_inputs() > 0) return default_type(); - - // Do some simply inference here to get the function arguments correct for - // this common case. - // TODO(jpienaar): Reconsider post refactoring shape functions. - if (node.IsArg()) { - if (dtype == DT_RESOURCE) { - const AttrValue* dtype_attr = node.attrs().Find("_handle_dtypes"); - const AttrValue* shape_attr = node.attrs().Find("_handle_shapes"); - if (dtype_attr && shape_attr) { - if (dtype_attr->list().type().empty()) { - return errors::InvalidArgument( - "Invalid \"_handle_dtypes\" attribute value for _Arg node: ", - shape_attr->DebugString()); - } - if (shape_attr->list().shape().empty()) { - return errors::InvalidArgument( - "Invalid \"_handle_shapes\" attribute value for _Arg node: ", - shape_attr->DebugString()); - } - DataType dtype = dtype_attr->list().type(0); - const TensorShapeProto& shape_proto = shape_attr->list().shape(0); - TF_ASSIGN_OR_RETURN( - auto etype, ConvertToMlirTensorType(shape_proto, dtype, &builder)); - return mlir::UnrankedTensorType::get(mlir::TF::ResourceType::get( - {mlir::cast(etype)}, builder.getContext())); - } else { - return mlir::UnrankedTensorType::get( - mlir::TF::ResourceType::get(builder.getContext())); - } - } else if (auto shape = node.attrs().Find("_output_shapes")) { - if (shape->has_list() && shape->list().shape_size() == 1) { - return ConvertToMlirTensorType(shape->list().shape().at(0), dtype, - &builder); - } - } - } - - const tensorflow::OpRegistrationData* op_reg_data; - TF_RETURN_IF_ERROR( - graph_->op_registry()->LookUp(node.type_string(), &op_reg_data)); - if (!op_reg_data) { - DVLOG(1) << "Skipping inference for unregistered op " << node.type_string(); - return default_type(); - } - if (op_reg_data->shape_inference_fn == nullptr) { - DVLOG(1) << "Skipping inference for op without shape function " - << node.type_string(); - return default_type(); - } - shape_inference::InferenceContext c(graph_->versions().producer(), - node.attrs(), op_reg_data->op_def, - std::vector{}, {}, - /*input_tensors_as_shapes=*/{}, {}); - TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn)); - return shape_ic(&c); -} - -absl::StatusOr ImporterBase::ConvertDataTypeAndShape( - DataType dtype, const shape_inference::ShapeHandle& handle, - const std::vector* handle_subtypes, - shape_inference::InferenceContext* context, mlir::Builder builder) { - TF_ASSIGN_OR_RETURN(auto subtypes, - ConvertSubtypes(handle_subtypes, context, builder)); - - mlir::Type element_type; - if (dtype == DT_VARIANT) - element_type = mlir::TF::VariantType::get(subtypes, context_); - else if (dtype == DT_RESOURCE) - element_type = mlir::TF::ResourceType::get(subtypes, context_); - else - TF_RETURN_IF_ERROR( - ::tensorflow::ConvertDataType(dtype, builder, &element_type)); - - return ConvertElementTypeAndShape(element_type, handle, context, builder); -} - -absl::StatusOr ImporterBase::ConvertElementTypeAndShape( - mlir::Type element_type, const shape_inference::ShapeHandle& handle, - shape_inference::InferenceContext* context, mlir::Builder builder) { - if (!context->RankKnown(handle)) { - return mlir::UnrankedTensorType::get(element_type); - } - - // Sentinel for an unknown dimension size. getTensorType interprets any - // negative value as an unknown dimension. - // TODO(jmolloy): Ideally this shouldn't be a local sentinel. - const int64_t kUnknownDim = -1; - - absl::InlinedVector dimensions; - int32_t rank = context->Rank(handle); - dimensions.reserve(rank); - for (int i = 0; i < rank; ++i) { - auto dim_handle = context->Dim(handle, i); - if (!context->ValueKnown(dim_handle)) - dimensions.push_back(kUnknownDim); - else - dimensions.push_back(context->Value(dim_handle)); - } - - return GetTypeFromTFTensorShape( - llvm::ArrayRef(dimensions.begin(), dimensions.end()), element_type); -} - -absl::StatusOr ImporterBase::ConvertSubtypes( - const std::vector* handle_subtypes, - shape_inference::InferenceContext* context, mlir::Builder builder) { - ElementSubtypes subtypes; - if (!handle_subtypes) return subtypes; - - subtypes.reserve(handle_subtypes->size()); - for (const auto& subtype : *handle_subtypes) { - mlir::Type element_type; - TF_RETURN_IF_ERROR( - ::tensorflow::ConvertDataType(subtype.dtype, builder, &element_type)); - TF_ASSIGN_OR_RETURN(TensorType type, - ConvertElementTypeAndShape(element_type, subtype.shape, - context, builder)); - subtypes.push_back(type); - } - return subtypes; -} - -Status ImporterBase::ConvertFunctionCallAttribute(const std::string& base_name, - const AttrValue& value, - NamedAttrList* attributes) { - TF_ASSIGN_OR_RETURN(auto func_attr, - ConvertFunctionCallName(value.func().name())); - if (!func_attr) return absl::OkStatus(); - attributes->push_back(builder_.getNamedAttr(base_name, func_attr)); - - for (const auto& it : value.func().attr()) { - auto name = absl::StrCat(base_name, ".", it.first); - TF_ASSIGN_OR_RETURN(auto value, ConvertAttributeValue(it.second)); - attributes->push_back(builder_.getNamedAttr(name, value)); - } - return absl::OkStatus(); -} - -absl::StatusOr ImporterBase::ConvertFunctionCallName( - const std::string& func_name) { - // Some ops like XlaHostCompute op uses empty value to represent missing - // functions. Such attribute values should be defined optional in MLIR - // definition. - if (func_name.empty()) return mlir::FlatSymbolRefAttr(); - - TF_RETURN_IF_ERROR(ConvertLibFunction(func_name)); - auto mlir_func_name = (*tf_name_to_mlir_name_)[func_name]; - return mlir::SymbolRefAttr::get(builder_.getContext(), mlir_func_name); -} - -absl::StatusOr ImporterBase::ConvertAttributeValue( - const AttrValue& value) { - switch (value.value_case()) { - case AttrValue::kFunc: { - // TODO(b/156546237): Unify kFunc/NameAttrList attribute representation. - // Currently kFunc/NameAttrList attributes in a kList/repeated AttrValue - // will not use this representation. This also doesn't handle empty - // function values like ConvertFunctionCallName method. - NamedAttrList attrs; - for (const auto& func_attr : value.func().attr()) { - TF_ASSIGN_OR_RETURN( - auto attr, ImporterBase::ConvertAttributeValue(func_attr.second)); - attrs.push_back(builder_.getNamedAttr(func_attr.first, attr)); - } - auto func_attrs = builder_.getDictionaryAttr(attrs); - return mlir::TF::FuncAttr::get(context_, value.func().name(), func_attrs); - } - case AttrValue::kList: { - if (!value.list().func().empty()) { - absl::InlinedVector attrs; - for (const auto& item : value.list().func()) { - TF_ASSIGN_OR_RETURN(auto attr, ConvertFunctionCallName(item.name())); - if (item.attr_size() != 0) - return errors::Unimplemented( - "func attributes with non-zero attr.size()"); - if (attr) attrs.push_back(attr); - } - return builder_.getArrayAttr( - llvm::ArrayRef(attrs.begin(), attrs.end())); - } - return ConvertNonFuncAttributeValue(value, &builder_); - } - default: - return ConvertNonFuncAttributeValue(value, &builder_); - } -} - -void ImporterBase::GetArgsAndRetsFromFunctionBody( - const FunctionBody& fbody, absl::InlinedVector* arg_nodes, - absl::InlinedVector* ret_nodes, - absl::InlinedVector* control_ret_nodes) { - arg_nodes->reserve(fbody.arg_nodes.size()); - ret_nodes->reserve(fbody.ret_nodes.size()); - for (auto arg : fbody.arg_nodes) { - arg_nodes->emplace_back(arg, 0); - } - for (auto ret : fbody.ret_nodes) { - ret_nodes->emplace_back(ret, 0); - } - *control_ret_nodes = fbody.control_ret_nodes; -} - -Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) { - // If the library function has been converted already, nothing needs to be - // done. - if (tf_name_to_mlir_name_->find(std::string(func_name)) != - tf_name_to_mlir_name_->end()) - return absl::OkStatus(); - - std::string mlir_func_name( - function_name_uniquifier_->GetUniqueName(func_name)); - (*tf_name_to_mlir_name_)[std::string(func_name)] = mlir_func_name; - - const auto& func_lib = graph_flib_; - const auto* func_def = func_lib.Find(std::string(func_name)); - if (func_def == nullptr) { - return errors::FailedPrecondition( - absl::StrCat("Failed to find function '", StringRefToView(func_name), - "'. The imported TensorFlow GraphDef is ill-formed.")); - } - - // Converts the argument and return types to MLIR types. - std::vector attributes; - attributes.reserve(func_def->attr_size()); - for (const auto& name_and_value : func_def->attr()) { - // This is a function definition attribute, so it shouldn't contain - // kFunc attribute and it is treated as normal one. - TF_ASSIGN_OR_RETURN(auto attr, - ConvertAttributeValue(name_and_value.second)); - std::string attr_name = - mangling_util::MangleAttributeName(name_and_value.first); - attributes.push_back(builder_.getNamedAttr(attr_name, attr)); - } - - // Checks opdef stateful attribute and import that as Function Attribute - if (func_def->signature().is_stateful()) { - auto stateful_str = mlir::TF::TensorFlowDialect::GetStatefulAttrName(); - attributes.push_back( - builder_.getNamedAttr(stateful_str, builder_.getUnitAttr())); - } - - // Checks for an associated custom gradient function. Adds it to the attribute - // list of this function. - auto grad_func_name = func_lib.FindGradient(std::string(func_name)); - if (!grad_func_name.empty()) { - TF_RETURN_IF_ERROR(ConvertLibFunction(grad_func_name)); - auto mlir_grad_func_name = (*tf_name_to_mlir_name_)[grad_func_name]; - auto gradient_attr = - mlir::SymbolRefAttr::get(builder_.getContext(), mlir_grad_func_name); - auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName(); - attributes.push_back(builder_.getNamedAttr(grad_string, gradient_attr)); - } - - deferred_functions_.emplace(func_name.str(), attributes); - return absl::OkStatus(); -} - -Status ImporterBase::PruneUnreachableNodes( - std::unordered_map* node_name_map) { - std::unordered_set prune_start; - TF_RETURN_IF_ERROR(GetInputOutputNodes(*node_name_map, &prune_start)); - - if (!prune_start.empty()) { - if (PruneForReverseReachability(graph_.get(), prune_start)) { - VLOG(1) << "Pruned unused nodes in graphdef"; - } else { - VLOG(1) << "No unused nodes in graphdef to prune"; - } - } else { - VLOG(1) << "No output nodes specified, skipping pruning"; - } - return absl::OkStatus(); -} - -Status ImporterBase::ConvertFeedsToPlaceholders( - std::unordered_map* node_name_map) { - // Feeds (edges) are converted into single-output placeholder nodes to - // simplify the conversion process. - TF_ASSIGN_OR_RETURN(auto feeds_by_node, GetFeedsByNode(specs_.inputs)); - for (const auto& it : feeds_by_node) { - TensorId tensor = ParseTensorName(it.first); - auto jt = node_name_map->find(std::string(tensor.node())); - if (jt == node_name_map->end()) { - return errors::FailedPrecondition( - absl::StrCat("Graph does not contain node: ", tensor.node())); - } - - Node* node = jt->second; - auto op_name = node->op_def().name(); - if (op_name != "Placeholder" && op_name != "LegacyFedInput" && - op_name != FunctionLibraryDefinition::kArgOp) { - for (const auto& output_tensor : it.second) { - const int index = output_tensor.first; - const ArrayInfo& array_info = output_tensor.second->second; - - DataType dtype = array_info.imported_dtype; - // Uses the existing output type if it isn't specified by the user. - if (dtype == DT_INVALID) { - dtype = node->output_type(index); - } - - TF_ASSIGN_OR_RETURN( - auto placeholder_node_and_removed, - CreatePlaceholderNodeForFeed(array_info.shape, dtype, node, index, - *node_name_map)); - - Node* placeholder_node = placeholder_node_and_removed.first; - if (placeholder_node->in_edges().empty()) { - graph_->AddControlEdge(graph_->source_node(), placeholder_node, - true /* skip test for duplicates */); - } - if (placeholder_node->out_edges().empty()) { - graph_->AddControlEdge(placeholder_node, graph_->sink_node(), - true /* skip test for duplicates */); - } - remapped_feeds_[{it.first, index}] = placeholder_node->name(); - (*node_name_map)[placeholder_node->name()] = placeholder_node; - } - } - } - return absl::OkStatus(); -} - -Status ImporterBase::PrepareConvert(const Graph& graph, - std::unique_ptr graph_def) { - // TODO(fengliuai): Converting to GraphDef and back is the easiest way to - // clone a graph. - // TODO(fengliuai): clone the graph without going to graph_def first. - if (graph_def == nullptr) { - graph_def = std::make_unique(); - graph.ToGraphDef(graph_def.get()); - } - graph_ = std::make_unique(graph.flib_def()); - GraphConstructorOptions opts; - opts.allow_internal_ops = true; - opts.add_default_attributes = true; - TF_RETURN_IF_ERROR(::tensorflow::ConvertGraphDefToGraph( - opts, std::move(*graph_def), graph_.get())); - - TF_RETURN_IF_ERROR(RemoveBackedges()); - - TF_RETURN_IF_ERROR(CopyStackTraces(graph, graph_.get())); - - auto node_name_map = graph_->BuildNodeNameIndex(); - - if (specs_.enable_shape_inference) { - // TODO(jpienaar): Remove once infer shapes on import flag is removed. - TF_RETURN_IF_ERROR(AddNodesToShapeRefiner(&node_name_map)); - } else { - TF_RETURN_IF_ERROR(ConvertFeedsToPlaceholders(&node_name_map)); - } - - // Prune nodes in the graph that are not reachable from the output. - if (specs_.prune_unused_nodes) { - TF_RETURN_IF_ERROR(PruneUnreachableNodes(&node_name_map)); - } - - if (!specs_.enable_shape_inference) { - // Re-initialize ordered_nodes_ since we might have modified the graph. - ordered_nodes_.clear(); - TopologicalOrdering( - *graph_, [&](Node* n) { ordered_nodes_.push_back(n); }, - GroupByDevice()); - } - - return absl::OkStatus(); -} - -Status ImporterBase::Convert( - llvm::StringRef func_name, mlir::FunctionType func_type, - const absl::InlinedVector& arg_nodes, - const absl::InlinedVector& ret_nodes, - const absl::InlinedVector& control_ret_nodes, - llvm::ArrayRef attrs) { - // TODO(b/122040776): Uses debug info for FunctionDef. - auto function = mlir::func::FuncOp::create(mlir::UnknownLoc::get(context_), - func_name, func_type, attrs); - - module_.push_back(function); - // Seeds the builder with an initial block. - function.addEntryBlock(); - builder_ = mlir::OpBuilder(function.getBody()); - - // Create the graph operation in which we will convert the individual nodes. - auto graph = builder_.create( - function.getLoc(), func_type.getResults()); - builder_.createBlock(&graph.getBody()); - - for (const Node* node : ordered_nodes_) { - TF_RETURN_IF_ERROR(ConvertNode(*node)); - } - - // Adds the backedges back to the function by creating the source and sink - // pairs. - TF_RETURN_IF_ERROR(AddBackedges()); - - TF_RETURN_IF_ERROR(ConvertFunctionArgAndRets(function, graph, - func_type.getInputs(), arg_nodes, - ret_nodes, control_ret_nodes)); - - // TODO(jpienaar): Update post removing shape_refinier_. - if (!specs_.enable_shape_inference) { - // Refine graph's type given more precise fetch. - auto fetch = graph.GetFetch(); - bool all_equal = true; - for (auto it : - llvm::zip_first(graph.getResults(), fetch.getOperandTypes())) { - auto rt = std::get<1>(it); - if (rt == std::get<0>(it).getType()) continue; - std::get<0>(it).setType(rt); - all_equal = false; - } - if (!all_equal) { - function.setType(mlir::FunctionType::get(function.getContext(), - func_type.getInputs(), - graph.getResultTypes())); - } - } - - return absl::OkStatus(); -} - -Status ImporterBase::ConvertFunctionArgAndRets( - mlir::func::FuncOp func, mlir::tf_executor::GraphOp graph_op, - llvm::ArrayRef arg_types, - const absl::InlinedVector& arg_nodes, - const absl::InlinedVector& ret_nodes, - const absl::InlinedVector& control_ret_nodes) { - // Store the arg/return attributes as a list rather than uniqueuing during - // construction. - llvm::SmallVector arg_attrs; - arg_attrs.resize(func.getNumArguments()); - llvm::SmallVector ret_attrs; - ret_attrs.resize(func.getNumResults()); - - auto set_attributes_on_func = [&](Node* node, int64_t index, bool is_arg) { - for (const auto& node_attr : node->attrs()) { - const auto& key = node_attr.first; - // Only import optional attributes (e.g., those starting with an - // underscore). - if (key.empty() || key[0] != '_') continue; - // Ignore shape inference attributes as shape information is already - // populated in the result type. - if (IsOutputShapesAttribute(node_attr.second, key) || - IsResourceOutputShapesAttribute(node_attr.second, key)) - continue; - TF_ASSIGN_OR_RETURN(auto converted_attr, - ConvertAttributeValue(node_attr.second)); - std::string dialect_attribute = "tf." + key; - if (is_arg) { - arg_attrs[index].set(dialect_attribute, converted_attr); - } else { - func.setResultAttr(index, dialect_attribute, converted_attr); - ret_attrs[index].set(dialect_attribute, converted_attr); - } - } - return absl::OkStatus(); - }; - - auto* bb = &func.front(); - llvm::SmallDenseMap, mlir::Value, 4> - arg_nodes_to_values; - for (int i = 0, e = arg_types.size(); i < e; ++i) { - auto& arg_node = arg_nodes[i]; - // The lookup can't fail here: otherwise some nodes in the function haven't - // be converted to mlir operations and don't have a mapping. - mlir::Operation* island = node_values_.find(arg_node.node->id())->second; - - auto bb_arg = bb->getArgument(i); - mlir::Value arg_def = bb_arg; - - if (island->getNumResults() != 2) - return errors::InvalidArgument( - "Only feed output tensors of single output nodes are supported"); - - // Collect mapping of OutputTensor to associated block arg. - arg_nodes_to_values.try_emplace({arg_node.node, arg_node.index}, arg_def); - island->getResult(0).replaceAllUsesWith(arg_def); - // Erase control outputs from feed. - auto control_uses = island->getResult(1).getUses(); - for (auto& control_use : llvm::make_early_inc_range(control_uses)) - control_use.getOwner()->eraseOperand(control_use.getOperandNumber()); - - if (!arg_node.node->requested_device().empty()) - arg_attrs[i].set("tf.device", builder_.getStringAttr( - arg_node.node->requested_device())); - - if (arg_node.node->IsArg()) { - TF_RETURN_IF_ERROR( - set_attributes_on_func(arg_node.node, i, /*is_arg=*/true)); - } - - island->dropAllReferences(); - island->erase(); - } - - llvm::SmallVector inst_to_return; - for (const auto& ret_and_idx : llvm::enumerate(ret_nodes)) { - const auto& ret = ret_and_idx.value(); - auto* inst = node_values_[ret.node->id()]; - if (ret.node->IsRetval()) { - if (!ret.node->requested_device().empty()) - ret_attrs[ret_and_idx.index()].set( - "tf.device", builder_.getStringAttr(ret.node->requested_device())); - TF_RETURN_IF_ERROR(set_attributes_on_func(ret.node, ret_and_idx.index(), - /*is_arg=*/false)); - // Lookup the instruction inside the island - auto island_op = llvm::cast(inst); - mlir::Operation* inner_op = &island_op.GetBody().front(); - // Remove kRetOp or kDeviceRetOp operation and return its operand. - // kRetOp and kDeviceRetOp should have just one operand unless they have - // control dependencies. - if (inner_op->getNumOperands() != 1) - return errors::Unimplemented("Return node with multiple inputs."); - inst_to_return.push_back(inner_op->getOperand(0)); - inst->dropAllReferences(); - inst->erase(); - } else { - // Lookup and use block arg if fetch is a feed. - auto it = arg_nodes_to_values.find({ret.node, ret.index}); - if (it != arg_nodes_to_values.end()) - inst_to_return.push_back(it->second); - else - inst_to_return.push_back(inst->getResult(ret.index)); - } - } - - for (Node* control_ret : control_ret_nodes) { - auto* inst = node_values_[control_ret->id()]; - inst_to_return.push_back(*std::prev(inst->result_end())); - } - - // Terminate the function by adding a Fetch operation to terminate the graph - // and a return operation to return the Graph results. - builder_.setInsertionPointToEnd(&graph_op.getBody().front()); - builder_.create(graph_op.getLoc(), - inst_to_return); - builder_.setInsertionPointToEnd(bb); - builder_.create(mlir::UnknownLoc::get(context_), - graph_op.getResults()); - - func.setAllArgAttrs( - llvm::to_vector<4>(llvm::map_range(arg_attrs, [&](NamedAttrList& list) { - return list.getDictionary(context_); - }))); - func.setAllResultAttrs( - llvm::to_vector<4>(llvm::map_range(ret_attrs, [&](NamedAttrList& list) { - return list.getDictionary(context_); - }))); - - return absl::OkStatus(); -} - -mlir::Location ImporterBase::GetLocation(const Node& node) { - DVLOG(1) << "Getting location for " << node.name() << " " << &node; - // TODO(b/142400497): What is the semantic contract for locations? - // Create a location for node `name` in function `function_name`. - auto create_location = [&](llvm::StringRef name, - llvm::StringRef function_name) -> mlir::Location { - // Use the catenation of function and node names as the lookup key into the - // debug info. This matches the way that the key is formed on the python - // side. - // - // We also use this as the name for the NameLoc for ops in function, since - // otherwise our names could collide across functions. - // For ops in the main graph, we omit the "@function_name" (which, would be - // just "@" since function_name would be empty) because some code seems to - // depend on the name being this way for correctness. - std::string debug_info_key = (name + "@" + function_name).str(); - std::string name_for_name_loc = - function_name.empty() ? name.str() : debug_info_key; - auto name_loc_id = mlir::StringAttr::get(context_, name_for_name_loc); - - std::shared_ptr stack_trace = node.GetStackTrace(); - - // Prefer stack traces if available, fallback to debug info if not, and then - // finally to just name. Older versions of debug info concatenated `@` onto - // the node name for the default graph, so we check both locations. - if (stack_trace != nullptr) { - } else if (stack_traces_.contains(name_for_name_loc)) { - stack_trace = stack_traces_.at(name_for_name_loc); - } else if (stack_traces_.contains(debug_info_key)) { - stack_trace = stack_traces_.at(debug_info_key); - } else { - DVLOG(1) << "No stack trace for " << node.name(); - } - - llvm::SmallVector locations; - - if (stack_trace != nullptr) { - DVLOG(1) << "Stack available for " << node.name(); - for (const StackFrame& frame : stack_trace->ToUncachedFrames()) { - auto file_name = mlir::StringAttr::get(context_, frame.file_name); - // Use col 1 as there is no column info in StackTrace. - auto file_line_loc = - mlir::FileLineColLoc::get(file_name, frame.line_number, 1); - locations.push_back(file_line_loc); - } - } - - // If there are no locations in the stack trace, fall back to just a - // NameLoc with no child. - if (locations.empty()) return mlir::NameLoc::get(name_loc_id); - - // Use the front FileLineColLoc to generate a NameLoc. - mlir::Location node_name_loc = - mlir::NameLoc::get(name_loc_id, locations.front()); - - // If there are more locations then generate a stack trace, otherwise just - // return the name loc. - auto callsite_locs = llvm::ArrayRef(locations).drop_front(); - return callsite_locs.empty() - ? node_name_loc - : mlir::CallSiteLoc::get(node_name_loc, callsite_locs); - }; - - // Create a location for node `name` in function `function_name`. - auto create_op_type_and_name_locations = [&]() { - return mlir::FusedLoc::get( - context_, - // Add the type operation for the propagation of op_type metadata. - {mlir::NameLoc::get( - mlir::StringAttr::get(context_, node.type_string() + ":")), - create_location(node.name(), function_name_for_debug_info_)}); - }; - - // For NextIteration nodes, location is used to pair source and sink nodes. - // Hence, we use node name as location to keep it unique. - // TODO(prakalps): In future the plan is to use tokens to pair source/sink - // nodes. Then NextIteration nodes would not need to be handled separately. - if (node.type_string() == "NextIteration") { - return create_op_type_and_name_locations(); - } - - const auto& node_def = node.def(); - auto original_nodes = - node_def.experimental_debug_info().original_node_names(); - auto original_funcs = - node_def.experimental_debug_info().original_func_names(); - - if (original_nodes.empty()) { - return create_op_type_and_name_locations(); - } else { - // If the original nodes are defined, then we use them to get a list of - // call sites, and then fuse them to a single fused location, with the name - // of the node_def. - llvm::SmallVector node_locations; - node_locations.reserve(original_nodes.size() + 2); - // Add the type operation for the propagation of op_type metadata. - node_locations.push_back(mlir::NameLoc::get( - mlir::StringAttr::get(context_, node.type_string() + ":"))); - // Retrieve the names from the experimental_debug_info. - for (int i = 0, e = original_nodes.size(); i != e; ++i) { - const auto& node_name = original_nodes[i]; - auto func_name = (i < original_funcs.size()) ? original_funcs[i] : ""; - node_locations.push_back(create_location(node_name, func_name)); - } - // Retrieve the name of the node_def. - node_locations.push_back( - create_location(node.name(), function_name_for_debug_info_)); - return mlir::FusedLoc::get(context_, node_locations); - } -} - -Status ImporterBase::EmitErrorWithLocationStr(const Node& node, - const Status& error_status) { - const mlir::Location location = GetLocation(node); - mlir::emitError(location); - return error_handler_.Combine(error_status); -} - -mlir::Operation* ImporterBase::CreateOperation( - const Node& node, llvm::StringRef node_type_name, - const mlir::OperationState& result, - const llvm::SmallVectorImpl& control_operands) { - // For the tf.executor specific operations (not wrapped in an island), we - // have an extra returned value for the control result, and we concatenate - // control and non-control operands. - mlir::SmallVector types(result.types); - types.push_back(mlir::tf_executor::ControlType::get(builder_.getContext())); - mlir::SmallVector operands(result.operands); - operands.append(control_operands.begin(), control_operands.end()); - - auto loc = result.location; - // Dispatch based on the name and create the appropriate operation. - if (node.IsSwitch()) { - // Switch and _SwitchN both are in switch class, differentiate based on - // op name. - if (node.op_def().name() == "_SwitchN") { - return builder_.create(loc, types, operands, - result.attributes); - } - return builder_.create(loc, types, operands, - result.attributes); - } - if (node.IsMerge()) { - return builder_.create(loc, types, operands, - result.attributes); - } - if (node.IsNextIteration()) { - // NextIteration is a bit special, we create a pair of operations that are - // linked together through a token returned by the source. - // We make use of a separate builder to insert the source at the top of - // the block. - mlir::OpBuilder builder_at_begin(builder_.getBlock(), - builder_.getBlock()->begin()); - auto source_op = - builder_at_begin.create( - loc, operands[0].getType(), result.attributes); - return builder_.create( - loc, source_op.getToken(), operands, result.attributes); - } - if (node.IsLoopCond()) { - return builder_.create(loc, types, operands, - result.attributes); - } - if (node.IsEnter()) { - return builder_.create(loc, types, operands, - result.attributes); - } - if (node.IsExit()) { - return builder_.create(loc, types, operands, - result.attributes); - } - if (node.IsControlTrigger()) { - return builder_.create( - loc, mlir::ValueRange(operands), result.attributes); - } - // Regular TensorFlow operation are wrapped in a tf_executor.island. - auto island = builder_.create( - result.location, types, control_operands, - mlir::ArrayRef{}); - island.getBody().push_back(new mlir::Block); - mlir::OpBuilder island_builder = - mlir::OpBuilder::atBlockEnd(&island.GetBody()); - - // Create the operation inside the island now. - mlir::Operation* inner_op = island_builder.create(result); - - // Sets operand_segment_sizes or result_segment_sizes attribute to the op. - const auto set_segment_sizes_attr = - [&](const NameRangeMap& arg_ranges, - const protobuf::RepeatedPtrField& args, - llvm::StringRef attr_name) { - std::vector values; - values.reserve(args.size()); - for (const auto& arg : args) { - auto range = arg_ranges.at(arg.name()); - values.push_back(range.second - range.first); - } - auto attr_value = - mlir::DenseI32ArrayAttr::get(inner_op->getContext(), values); - inner_op->setAttr(attr_name, attr_value); - }; - - if (inner_op->hasTrait() || - inner_op->hasTrait()) { - // The op has multiple variadic operands or results. - // Calculate operand and result segment sizes using the OpDef. - NameRangeMap input_ranges, output_ranges; - // This will fail only if the OpDef is syntactically invalid. - // TODO(jpienaar): Convert this CHECK into a properly propagated error. - TF_CHECK_OK( - NameRangesForNode(node, node.op_def(), &input_ranges, &output_ranges)); - if (inner_op->hasTrait()) { - // Add derived "operand_segment_sizes" attr to the created operation. - // TODO(b/146937733): Don't use here. - set_segment_sizes_attr(input_ranges, node.op_def().input_arg(), - mlir::OpTrait::AttrSizedOperandSegments< - void>::getOperandSegmentSizeAttr()); - } - - if (inner_op->hasTrait()) { - // Add derived "result_segment_sizes" attr to the created operation. - // TODO(b/146937733): Don't use here. - set_segment_sizes_attr(output_ranges, node.op_def().output_arg(), - mlir::OpTrait::AttrSizedResultSegments< - void>::getResultSegmentSizeAttr()); - } - } - - if (VLOG_IS_ON(1)) { - mlir::OperationName name = inner_op->getName(); - if (!name.isRegistered() && - // Skip unmodelled ops that are handled differently. - (node_type_name != "_Arg" && node_type_name != "_Retval") && - !unmodelled_op_names_.count(name.getIdentifier())) { - if (node.op_def().is_stateful()) { - VLOG(1) << "[potentially conservative] Op type `" << node.type_string() - << "` is stateful but effects not modelled"; - } else { - // See if any resource type is used. - bool resource = false; - std::function record_resource; - record_resource = [&](mlir::Type type) { - type.walk([&](mlir::Type t) { - if (resource) return mlir::WalkResult::interrupt(); - if (mlir::isa(type)) { - resource = true; - return mlir::WalkResult::interrupt(); - } - - return mlir::WalkResult::advance(); - }); - - return resource; - }; - - for (mlir::Type t : inner_op->getResultTypes()) - if (record_resource(t)) break; - for (mlir::Type t : inner_op->getOperandTypes()) - if (record_resource(t)) break; - if (resource) { - unmodelled_op_names_.insert(name.getIdentifier()); - VLOG(1) << "[potentially conservative] Op type `" - << node.type_string() - << "` has resource operands/results but effects not modelled"; - } - } - } - } - - // Add the terminator for the island - island_builder.create(result.location, - inner_op->getResults()); - return island.getOperation(); -} - -Status ImporterBase::ConvertNode(const Node& node) { - if (!node.IsOp()) { - // Don't import the pseudo-nodes _SOURCE or _SINK. These are added by - // Graph and don't exist in GraphDef. - return absl::OkStatus(); - } - - // If it is a custom OP, its definition should be found in the library. We - // create the MLIR function and insert it to the module if it doesn't exist. - std::string node_type_name = node.type_string(); - const auto* func_def = graph_flib_.Find(node_type_name); - bool convert_to_legacy_call = false; - if (func_def) { - TF_RETURN_IF_ERROR(ConvertLibFunction(node_type_name)); - node_type_name = (*tf_name_to_mlir_name_)[node_type_name]; - convert_to_legacy_call = true; - } - - auto get_full_op_name = [&](const std::string& op_name) { - const char* kTfPrefix = "tf."; - return kTfPrefix + op_name; - }; - - std::string op_name = get_full_op_name(node_type_name); - if (back_edge_node_output_.contains(&node)) { - op_name = op_name + ".sink"; - } - - mlir::OperationState result(GetLocation(node), op_name); - for (int i = 0; i < node.num_outputs(); ++i) { - // The backedge has been removed, so we shouldn't count the corresponding - // output from the src node when converting to an operation. - if (back_edge_node_output_.contains(&node) && - back_edge_node_output_[&node] == i) { - continue; - } - TF_ASSIGN_OR_RETURN(auto type, InferOutputType(node, i, builder_)); - result.types.push_back(type); - } - - // Surprisingly input edges can be nondeterministically ordered. This - // particularly seems to be the case for the control edges between _SOURCE - // and _SINK that the Graph constructor inserts. Copy the input edges and - // sort the edges, but only the control edges, not data edges! - // TODO(jmolloy): We should probably just ignore _SOURCE and _SINK nodes. - // They'll break roundtripping anyway unless we strip them when converting - // back to graphdef. - absl::InlinedVector in_edges(node.in_edges().size()); - absl::c_copy(node.in_edges(), in_edges.begin()); - absl::c_stable_sort(in_edges, [](const Edge* e1, const Edge* e2) { - if (e1->IsControlEdge() && !e2->IsControlEdge()) return false; - if (!e1->IsControlEdge() && e2->IsControlEdge()) return true; - if (e1->IsControlEdge() && e2->IsControlEdge()) - return e1->src()->id() < e2->src()->id(); - return e1->dst_input() < e2->dst_input(); - }); - - result.operands.reserve(in_edges.size()); - - // Collect the control operands separately, they will be held by the island. - mlir::SmallVector control_operands; - - for (const auto* input_edge : in_edges) { - const Node& input_node = *input_edge->src(); - if (input_node.IsSource()) { - if (in_edges.size() != 1) { - return errors::FailedPrecondition( - "The node has other inputs besides the _Source node"); - } - // We don't import the _SOURCE node. - continue; - } - if (input_node.IsArg() && input_edge->IsControlEdge()) { - // Currently we have not reached consensus as to what TF function - // semantics are (b/133509504). Here we assume that all arguments to a - // function should be available before we start execution of any internal - // node. This makes the control dependencies between function arguments - // and internal nodes redundant, and so we do not import them. The TF - // inliner however assumes no such dependency between function args and - // internal nodes exists, unless explicitly stated. Since we drop control - // dependencies here, it leads to loss of information. If the function is - // inlined later, the inliner would not know of these explicit control - // dependencies present in the original graph. - continue; - } - if (node_values_.find(input_node.id()) == node_values_.end()) - return errors::FailedPrecondition( - "Graph not traversed in reverse post order; use seen before def!"); - mlir::Operation* inst = node_values_[input_node.id()]; - if (input_edge->IsControlEdge()) - control_operands.push_back(inst->getResult(inst->getNumResults() - 1)); - else - result.operands.push_back(inst->getResult(input_edge->src_output())); - } - - using FuncPairType = std::pair; - std::vector funcs; - result.attributes.reserve(node.attrs().size() + 2); - auto abstract_op = result.name.getRegisteredInfo(); - auto derived_op = - abstract_op - ? abstract_op->getInterface() - : nullptr; - for (const auto& name_and_value : node.attrs()) { - const auto& attr_name = name_and_value.first; - // Skip adding derived attributes to the generated op. - if (derived_op && derived_op->isDerivedAttribute(attr_name)) continue; - const AttrValue& attr_value = name_and_value.second; - - // Remove _output_shapes attribute that will be added by the exporter. - if (IsOutputShapesAttribute(attr_value, attr_name)) continue; - - if (attr_value.value_case() == AttrValue::kFunc) { - // Attribute iteration order is not defined for protocol buffer Map. - // Process function attributes separately in the lexicographical order to - // have deterministic order of functions in the constructed IR. - funcs.emplace_back(&attr_name, &attr_value); - } else { - TF_ASSIGN_OR_RETURN(auto attr, ConvertAttributeValue(attr_value)); - result.attributes.push_back(builder_.getNamedAttr(attr_name, attr)); - } - } - - auto comparator = [](const FuncPairType& a, const FuncPairType& b) { - return *a.first < *b.first; - }; - std::sort(funcs.begin(), funcs.end(), comparator); - for (const auto& func : funcs) { - TF_RETURN_IF_ERROR(ConvertFunctionCallAttribute(*func.first, *func.second, - &result.attributes)); - } - - const auto& node_def = node.def(); - // NodeDef can contain partial TF device names. In such cases, canonicalize - // it. Note that in current TF, placer will place full device name to each - // node. - DeviceNameUtils::ParsedName parsed_name; - if (!DeviceNameUtils::ParseFullName(node_def.device(), &parsed_name)) { - return errors::InvalidArgument( - "Op ", op_name, " has invalid device name: ", node_def.device()); - } - // Keep the parsed name untouched if the device name is empty. - if (!node_def.device().empty()) { - if (!parsed_name.has_type) { - parsed_name.type = "CPU"; - parsed_name.has_type = true; - } - if (!parsed_name.has_id) { - parsed_name.id = 0; - parsed_name.has_id = true; - } - } - result.attributes.push_back(builder_.getNamedAttr( - "device", builder_.getStringAttr( - DeviceNameUtils::ParsedNameToString(parsed_name)))); - - // Map user function calls to LegacyCall ops and add the user function name - // as an attribute. - if (convert_to_legacy_call) { - result.name = mlir::OperationName(get_full_op_name("LegacyCall"), context_); - mlir::SymbolRefAttr val = - mlir::SymbolRefAttr::get(builder_.getContext(), node_type_name); - result.addAttribute("f", val); - - if (!result.attributes.get("_disable_call_shape_inference")) { - result.addAttribute("_disable_call_shape_inference", - builder_.getBoolAttr(false)); - } - } - - auto composite_control_flow_op = [&](const std::string& name) { - result.name = mlir::OperationName(get_full_op_name(name), context_); - bool stateless = absl::StartsWith(node_type_name, "Stateless"); - mlir::BoolAttr val = builder_.getBoolAttr(stateless); - result.attributes.push_back(builder_.getNamedAttr("is_stateless", val)); - }; - - // Map Case/If/While and StatelessCase/If/While op in TensorFlow to the common - // Case/If/While op in MLIR and add the differentiating attribute. - if (node.IsCaseNode()) composite_control_flow_op("Case"); - if (node.IsIfNode()) composite_control_flow_op("If"); - if (node.IsWhileNode()) { - composite_control_flow_op("While"); - auto* output_shapes = node.attrs().Find("output_shapes"); - if (output_shapes && !output_shapes->list().shape().empty()) - result.attributes.push_back( - builder_.getNamedAttr("shape_invariant", builder_.getUnitAttr())); - } - - // Register the mapping between the TF node and the newly created operation. - node_values_[node.id()] = - CreateOperation(node, node_type_name, result, control_operands); - return absl::OkStatus(); -} - -// Add the backedges to the CFG. Given a backedge, we replace the original -// source and destination operations by two new operations. Most of the -// fields of the replacements are copied from the original operations. -// However, -// - for the src operation, one output is inserted to the front of the output -// list. The type of the output is set to the type of the non-control result -// of the dst operation, and -// - for the dst operation, one operand is inserted to the front of the -// operand list. This operand is using the first result of the src -// operation. -// TODO(fengliuai): Preserve the order of the results and operands if -// necessary. -Status ImporterBase::AddBackedges() { - for (auto it : back_edge_dst_inputs_) { - BackEdge& edge = it.second; - if (!edge.src->IsNextIteration() || !edge.dst->IsMerge()) { - return errors::FailedPrecondition( - "Invalid backedge; should be from NextIteration to Merge!"); - } - auto* sink = node_values_[edge.src->id()]; - auto* dst = node_values_[edge.dst->id()]; - TF_RETURN_IF_ERROR(AddBackedge(sink, dst, edge.dst_input)); - } - return absl::OkStatus(); -} - -Status ImporterBase::AddBackedge(mlir::Operation* sink, mlir::Operation* dst, - int dst_input) { - // Get the NextIteration.Source operation from the token operand of the sink. - mlir::Operation* source = sink->getOperand(0).getDefiningOp(); - - // Adds the "source" to the operands of the dst by creating a new dst - // operation. - mlir::OperationState state(dst->getLoc(), dst->getName()); - auto num_operands = dst->getNumOperands(); - state.operands.reserve(num_operands + 1); - for (int input = 0, e = num_operands + 1; input != e; ++input) { - if (input < dst_input) { - state.operands.push_back(dst->getOperand(input)); - } else if (input == dst_input) { - state.operands.push_back(source->getResult(0)); - } else { - state.operands.push_back(dst->getOperand(input - 1)); - } - } - state.attributes.assign(dst->getAttrs().begin(), dst->getAttrs().end()); - state.types.assign(dst->getResultTypes().begin(), - dst->getResultTypes().end()); - builder_.setInsertionPoint(dst); - auto* new_dst = builder_.create(state); - - // Replaces the output uses of the old operation by the corresponding - // result of the new operation, and deletes the old operation. - for (unsigned i = 0, e = dst->getNumResults(); i != e; ++i) { - auto new_output = new_dst->getResult(i); - dst->getResult(i).replaceAllUsesWith(new_output); - } - dst->dropAllReferences(); - dst->erase(); - return absl::OkStatus(); -} - -absl::StatusOr ImporterBase::InferLibFunctionType( - const FunctionBody& fbody) { - mlir::Builder builder(context_); - - // The FunctionBody contains a graph with a single-output _Arg node for each - // function argument and a single-input _Retval node for each function return - // value. - // - // We already populated the ShapeRefiner with all the information about the - // shapes of these graph edges, so we just query it to build the corresponding - // MLIR function type signature. - - llvm::SmallVector arg_types; - if (specs_.inputs.empty()) { - arg_types.reserve(fbody.arg_types.size()); - for (auto arg : fbody.arg_nodes) { - // Find node in the graph using the node id instead of using `arg` - // directly because the graph has been cloned. - auto* node = graph_->FindNodeId(arg->id()); - TF_ASSIGN_OR_RETURN(auto type, - InferOutputType(*node, /*idx=*/0, builder)); - arg_types.push_back(type); - } - } else { - arg_types.reserve(fbody.arg_types.size()); - for (const auto& it : llvm::enumerate(specs_.inputs)) { - mlir::Type element_type; - const auto& node_info = it.value().second; - DataType dtype = node_info.imported_dtype; - // Uses the existing output type of the arg node if the data type of the - // the node isn't specified through the import configuration. - if (dtype == DT_INVALID) { - auto arg = fbody.arg_nodes[it.index()]; - auto* node = graph_->FindNodeId(arg->id()); - dtype = node->output_type(0); - if (dtype == DT_INVALID) { - return errors::InvalidArgument("Input ", it.index(), - "has invalid data type"); - } - } - TF_RETURN_IF_ERROR( - ::tensorflow::ConvertDataType(dtype, builder, &element_type)); - if (node_info.shape.unknown_rank()) { - arg_types.push_back(mlir::UnrankedTensorType::get(element_type)); - } else { - llvm::SmallVector shape; - TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape)); - arg_types.push_back(GetTypeFromTFTensorShape(shape, element_type)); - } - } - } - - llvm::SmallVector ret_types; - ret_types.reserve(fbody.ret_types.size()); - for (auto ret : fbody.ret_nodes) { - // Find node in the graph using the node id instead of using `ret` directly - // because the graph has been cloned. - auto* node = graph_->FindNodeId(ret->id()); - TF_ASSIGN_OR_RETURN(auto type, InferInputType(*node, /*idx=*/0, builder)); - ret_types.push_back(type); - } - - return builder.getFunctionType(arg_types, ret_types); -} - -// Stateful helper class to import a TensorFlow model expressed in GraphDef into -// an MLIR Module. -// -// The nodes defined in the graph are converted to a function called -// 'func_name'. All library function definitions are converted to MLIR functions -// in the module. -class GraphDefImporter : public ImporterBase { - public: - // Main entry point: converts the given graph to an MLIR Module. - static absl::StatusOr> Convert( - mlir::MLIRContext* context, const Graph& graph, - const GraphDebugInfo& debug_info, - const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, - std::unordered_map& tf_name_to_mlir_name, - bool disable_crash_analysis = false); - - private: - explicit GraphDefImporter( - const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info, - const GraphImportConfig& specs, mlir::ModuleOp module, - std::unordered_map* tf_name_to_mlir_name, - NameUniquifier* function_name_uniquifier) - : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name, - function_name_uniquifier) {} - - // Returns the function signature of the main function of converted MLIR - // module, the input nodes and output nodes. The type and shape information - // for the function arguments are read from `specs`, but the type and shape - // information for the function returns are inferred by the shape refiner in - // ImporterBase. - absl::StatusOr InferMainFunctionType( - const GraphImportConfig& specs, mlir::MLIRContext* context, - absl::InlinedVector* arg_nodes, - absl::InlinedVector* ret_nodes); - - // Returns the function signature of the main function, alongside input and - // output nodes, for function graphs. Arguments and return values are - // determined by node op type. Type and shape information of the function are - // inferred by the shape refiner in ImporterBase. - absl::StatusOr GetArgsRetsAndTypesFromFunctionGraph( - mlir::MLIRContext* context, - absl::InlinedVector* arg_nodes, - absl::InlinedVector* ret_nodes); - - // Finds the graph's target nodes/function's control ret nodes based on - // supplied node names in `control_outputs`. If `control_outputs` are not - // unique or a control ret node is missing, an error will be returned. - Status GetControlRetsFromGraph( - llvm::ArrayRef control_outputs, - absl::InlinedVector* control_ret_nodes); -}; - -absl::StatusOr> GraphDefImporter::Convert( - mlir::MLIRContext* context, const Graph& graph, - const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def, - const GraphImportConfig& specs, - std::unordered_map& tf_name_to_mlir_name, - bool disable_crash_analysis) { - LoadImporterDialects(*context); - mlir::OwningOpRef module = - mlir::ModuleOp::create(mlir::UnknownLoc::get(context)); - NameUniquifier function_name_uniquifier(flib_def); - - // importer.PrepareConvert below will attemp to clone the original `graph` - // via conversion to the graph def first. Convert graph to graph_def here - // first and avoid extra copies later. - auto graph_def = std::make_unique(); - graph.ToGraphDef(graph_def.get(), /*include_flib_def=*/false); - - auto scope_exit = [&]() { - std::function cleanup = []() {}; - if (!disable_crash_analysis) { - static std::atomic counter(0); - uint32 current_file_prefix = counter++; - const auto* graph_crash_handle = crash_analysis::ReportProtoDataOnCrash( - absl::StrCat(current_file_prefix, "_mlir_import_graph.pbtxt"), - *graph_def); - auto reachable_flib = flib_def.ReachableDefinitions(*graph_def); - const auto* flib_crash_handle = crash_analysis::ReportProtoDataOnCrash( - absl::StrCat(current_file_prefix, "_mlir_import_flib.pbtxt"), - reachable_flib.ToProto()); - cleanup = [=]() { - crash_analysis::RemoveReportData(graph_crash_handle); - crash_analysis::RemoveReportData(flib_crash_handle); - }; - } - - return llvm::make_scope_exit(std::move(cleanup)); - }(); - - VLOG(2) << "Importing: " - << ::tensorflow::DumpGraphToFile("tf_mlir_importer_base", graph, - &flib_def); - - GraphDefImporter importer(flib_def, debug_info, specs, module.get(), - &tf_name_to_mlir_name, &function_name_uniquifier); - - TF_RETURN_IF_ERROR(importer.PrepareConvert(graph, std::move(graph_def))); - - mlir::FunctionType func_type; - absl::InlinedVector arg_nodes; - absl::InlinedVector ret_nodes; - absl::InlinedVector control_ret_nodes; - llvm::SmallVector attrs; - if (specs.graph_as_function) { - if (specs.prune_unused_nodes || !specs.inputs.empty() || - !specs.outputs.empty()) - return errors::InvalidArgument( - "Pruning of graph is currently unsupported when the main graph is " - "converted to a function."); - - TF_ASSIGN_OR_RETURN(func_type, - importer.GetArgsRetsAndTypesFromFunctionGraph( - context, &arg_nodes, &ret_nodes)); - - TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs, - &control_ret_nodes)); - - mlir::Builder b(context); - std::string s; - llvm::raw_string_ostream ss(s); - auto node_name = [&](const OutputTensor& tensor) { - ss << tensor.node->name(); - }; - llvm::interleave(arg_nodes, ss, node_name, ","); - auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str())); - s.clear(); - llvm::interleave(ret_nodes, ss, node_name, ","); - auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str())); - s.clear(); - llvm::interleave(specs.control_outputs, ss, ","); - auto control_outputs = - b.getNamedAttr("control_outputs", b.getStringAttr(ss.str())); - - // Under `graph_as_function` mode, `tf.entry_function` is always set as it - // is assumed feed, fetch, and target nodes are set correctly. - attrs.push_back(b.getNamedAttr( - "tf.entry_function", - b.getDictionaryAttr({inputs, outputs, control_outputs}))); - if (!specs.xla_compile_device_type.empty()) { - attrs.push_back( - b.getNamedAttr("_xla_compile_device_type", - b.getStringAttr(specs.xla_compile_device_type))); - } - attrs.push_back(b.getNamedAttr("allow_soft_placement", - b.getBoolAttr(specs.enable_soft_placement))); - } else { - // Collects the argument and return nodes by looking up the node names - // specified by the user. - TF_ASSIGN_OR_RETURN(func_type, importer.InferMainFunctionType( - specs, context, &arg_nodes, &ret_nodes)); - - TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs, - &control_ret_nodes)); - - // TODO(prakalps): Refactor to keep tf.entry_function attribute encoding and - // decoding in a centralized place. - // Record the input and output mapping. - if (!specs.inputs.empty() || !specs.outputs.empty() || - !specs.control_outputs.empty()) { - mlir::Builder b(context); - std::string s; - llvm::raw_string_ostream ss(s); - llvm::interleave( - specs.inputs, ss, - [&](const std::pair& v) { ss << v.first; }, - ","); - auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str())); - s.clear(); - llvm::interleave(specs.outputs, ss, ","); - auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str())); - s.clear(); - llvm::interleave(specs.control_outputs, ss, ","); - auto control_outputs = - b.getNamedAttr("control_outputs", b.getStringAttr(ss.str())); - - attrs.push_back(b.getNamedAttr( - "tf.entry_function", - b.getDictionaryAttr({inputs, outputs, control_outputs}))); - } - } - - // Record version info. - PopulateTfVersions(module.get(), graph.versions()); - - const llvm::StringRef& graph_func_name = - specs.graph_func_name.empty() ? kImportModelDefaultGraphFuncName - : specs.graph_func_name; - TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(graph_func_name, func_type, - arg_nodes, ret_nodes, - control_ret_nodes, attrs)); - TF_RETURN_IF_ERROR(importer.ImporterBase::ConvertDeferredFunctions()); - - // Mark main function public, others private. - for (auto function : module.get().getOps()) { - auto visibility = function.getName() == graph_func_name - ? mlir::func::FuncOp::Visibility::Public - : mlir::func::FuncOp::Visibility::Private; - function.setVisibility(visibility); - } - VLOG(2) << "Imported: " - << tensorflow::DumpMlirOpToFile("tf_mlir_imported_base", - module.get()); - return module; -} - -absl::StatusOr GraphDefImporter::InferMainFunctionType( - const GraphImportConfig& specs, mlir::MLIRContext* context, - absl::InlinedVector* arg_nodes, - absl::InlinedVector* ret_nodes) { - // Find all the input nodes and output nodes. - // Feeds have been remapped to single output nodes (Placeholder), so an exact - // name match is sufficient. - absl::flat_hash_map inputs; - for (const auto& input_and_idx : llvm::enumerate(specs.inputs)) { - TensorId tensor = ParseTensorName(input_and_idx.value().first); - auto remapped_it = remapped_feeds_.find(tensor); - if (remapped_it != remapped_feeds_.end()) { - inputs.insert({remapped_it->second, input_and_idx.index()}); - } else { - inputs.insert({tensor.node(), input_and_idx.index()}); - } - } - - absl::flat_hash_set output_node_names; - std::vector outputs; - output_node_names.reserve(specs.outputs.size()); - for (const auto& output : specs.outputs) { - TensorId tensor = ParseTensorName(output); - auto remapped_it = remapped_feeds_.find(tensor); - if (remapped_it != remapped_feeds_.end()) { - output_node_names.insert(remapped_it->second); - outputs.push_back({remapped_it->second, 0}); - } else { - output_node_names.insert(tensor.node()); - outputs.push_back(tensor); - } - } - - if (!inputs.empty() || !outputs.empty()) { - arg_nodes->resize(inputs.size()); - ret_nodes->resize(outputs.size()); - - for (Node* n : GetOrderedNodes()) { - // Handle inputs/arguments. - auto input_it = inputs.find(n->name()); - if (input_it != inputs.end()) { - (*arg_nodes)[input_it->second] = {n, 0}; - } - - // Handle outputs/returns. - if (output_node_names.contains(n->name())) { - for (int i = 0, e = outputs.size(); i != e; ++i) { - TensorId tensor = outputs[i]; - if (n->name() != tensor.node()) continue; - (*ret_nodes)[i] = {n, tensor.index()}; - } - } - } - } - - // Starts to construct the function type. - mlir::Builder builder(context); - llvm::SmallVector arg_types; - arg_types.reserve(specs.inputs.size()); - int i = 0; - for (const auto& it : specs.inputs) { - Node* arg_node = arg_nodes->at(i).node; - if (arg_node == nullptr) { - return errors::InvalidArgument("Input ", it.first, - " was not found in graph"); - } - mlir::Type element_type; - const auto& node_info = it.second; - DataType imported_dtype = node_info.imported_dtype; - // Uses the existing output type of the arg node if the data type of the - // the node isn't specified through the import configuration. - if (imported_dtype == DT_INVALID) { - imported_dtype = arg_node->output_type(0); - if (imported_dtype == DT_INVALID) { - return errors::InvalidArgument("Input ", i, "has invalid data type"); - } - } - // Check if we have subtypes first - if (!node_info.subtypes.empty()) { - std::vector subtypes; - for (const auto& st : node_info.subtypes) { - mlir::Type st_data_type; - llvm::SmallVector shape; - TF_RETURN_IF_ERROR(ConvertToMlirShape(st.shape, &shape)); - TF_RETURN_IF_ERROR( - ConvertDataType(st.imported_dtype, builder, &st_data_type)); - subtypes.push_back(GetTypeFromTFTensorShape(shape, st_data_type)); - } - if (imported_dtype == DT_RESOURCE) { - element_type = - mlir::TF::ResourceType::get(subtypes, builder.getContext()); - } else if (imported_dtype == DT_VARIANT) { - element_type = - mlir::TF::VariantType::get(subtypes, builder.getContext()); - } else { - return errors::InvalidArgument(DataType_Name(imported_dtype), - " takes no subtypes."); - } - } else { - TF_RETURN_IF_ERROR( - ConvertDataType(imported_dtype, builder, &element_type)); - } - if (node_info.shape.unknown_rank()) { - arg_types.push_back(mlir::UnrankedTensorType::get(element_type)); - } else { - llvm::SmallVector shape; - TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape)); - arg_types.push_back(GetTypeFromTFTensorShape(shape, element_type)); - } - i++; - } - - llvm::SmallVector ret_types; - ret_types.reserve(specs.outputs.size()); - for (int i = 0, e = specs.outputs.size(); i != e; ++i) { - if (ret_nodes->at(i).node == nullptr) { - return errors::InvalidArgument("Output ", specs.outputs[i], - " was not found in graph"); - } - } - for (const auto& ret : *ret_nodes) { - if (ret.node->num_outputs() <= ret.index) { - return errors::InvalidArgument("Invalid output index ", ret.index, - " specified for node: ", ret.node->name()); - } - TF_ASSIGN_OR_RETURN(auto type, - InferOutputType(*ret.node, ret.index, builder)); - ret_types.push_back(type); - } - - return builder.getFunctionType(arg_types, ret_types); -} - -absl::StatusOr -GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph( - mlir::MLIRContext* context, absl::InlinedVector* arg_nodes, - absl::InlinedVector* ret_nodes) { - auto add_node = [](Node* node, absl::InlinedVector* nodes) { - auto* attr = node->attrs().Find("index"); - if (!attr) - return errors::InvalidArgument(node->type_string(), " node '", - node->name(), - "' is missing attribute 'index'"); - - auto index = attr->i(); - const int num_nodes = nodes->size(); - if (num_nodes < index + 1) nodes->resize(index + 1); - - if ((*nodes)[index].node != nullptr) - return errors::InvalidArgument(node->type_string(), " node '", - node->name(), "' has attribute 'index' ", - index, " that conflicts with node '", - (*nodes)[index].node->name(), "'"); - (*nodes)[index] = {node, 0}; - - return absl::OkStatus(); - }; - - // Collect arg and ret nodes from graph. - for (auto* node : GetOrderedNodes()) - if (node->IsArg()) - TF_RETURN_IF_ERROR(add_node(node, arg_nodes)); - else if (node->IsRetval()) - TF_RETURN_IF_ERROR(add_node(node, ret_nodes)); - - // Collect arg and ret types and create function type. - mlir::Builder builder(context); - llvm::SmallVector arg_types; - arg_types.reserve(arg_nodes->size()); - for (const auto& arg_node_and_idx : llvm::enumerate(*arg_nodes)) { - auto& arg_node = arg_node_and_idx.value(); - if (arg_node.node == nullptr) - return errors::InvalidArgument("Graph missing _Arg at index ", - arg_node_and_idx.index()); - - TF_ASSIGN_OR_RETURN(auto type, - InferOutputType(*arg_node.node, /*idx=*/0, builder)); - arg_types.push_back(type); - } - - llvm::SmallVector ret_types; - ret_types.reserve(ret_nodes->size()); - for (const auto& ret_node_and_idx : llvm::enumerate(*ret_nodes)) { - auto& ret_node = ret_node_and_idx.value(); - if (ret_node.node == nullptr) - return errors::InvalidArgument("Graph missing _Retval at index ", - ret_node_and_idx.index()); - - TF_ASSIGN_OR_RETURN(auto type, - InferInputType(*ret_node.node, /*idx=*/0, builder)); - ret_types.push_back(type); - } - - return builder.getFunctionType(arg_types, ret_types); -} - -Status GraphDefImporter::GetControlRetsFromGraph( - llvm::ArrayRef control_outputs, - absl::InlinedVector* control_ret_nodes) { - if (control_outputs.empty()) return absl::OkStatus(); - - llvm::SmallDenseMap controls_to_idx; - for (const auto& control_and_idx : llvm::enumerate(control_outputs)) - controls_to_idx.insert({control_and_idx.value(), control_and_idx.index()}); - - if (controls_to_idx.size() != control_outputs.size()) - return errors::InvalidArgument("Control outputs must be unique"); - - control_ret_nodes->resize(controls_to_idx.size()); - - for (auto* node : GetOrderedNodes()) { - auto it = controls_to_idx.find(node->name()); - if (it != controls_to_idx.end()) (*control_ret_nodes)[it->second] = node; - } - - for (auto node_and_name : llvm::zip(*control_ret_nodes, control_outputs)) - if (std::get<0>(node_and_name) == nullptr) - return errors::InvalidArgument( - "Control output '", std::get<1>(node_and_name), "' is missing"); - - return absl::OkStatus(); -} -// Stateful helper class to import a TensorFlow model expressed in SavedModel -// into an MLIR Module. -class SavedModelObjectGraphImporter : public ImporterBase { - public: - // Main entry point: converts all functions in the given meta graph to an MLIR - // Module. - static absl::StatusOr> Convert( - SavedModelV2Bundle* saved_model, absl::Span exported_names, - mlir::MLIRContext* context, MLIRImportOptions options); - private: - explicit SavedModelObjectGraphImporter( - const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info, - const GraphImportConfig& specs, mlir::ModuleOp module, - std::unordered_map* tf_name_to_mlir_name, - NameUniquifier* function_name_uniquifier) - : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name, - function_name_uniquifier) {} -}; // Determines the names used to reference objects in the SavedObjectGraph. class ObjectNames { @@ -3577,11 +1025,9 @@ Status CreateSavedModelIR( return absl::OkStatus(); } -absl::StatusOr> -SavedModelObjectGraphImporter::Convert(SavedModelV2Bundle* saved_model, - absl::Span exported_names, - mlir::MLIRContext* context, - MLIRImportOptions import_options) { +absl::StatusOr> ConvertSavedModelObjectGraph( + SavedModelV2Bundle* saved_model, absl::Span exported_names, + mlir::MLIRContext* context, MLIRImportOptions import_options) { LoadImporterDialects(*context); GraphDebugInfo dummy_debug_info; const GraphDebugInfo& debug_info = @@ -3612,17 +1058,15 @@ SavedModelObjectGraphImporter::Convert(SavedModelV2Bundle* saved_model, options, std::move(preprocessed_graphdef), &graph)); NameUniquifier function_name_uniquifier(graph.flib_def()); - SavedModelObjectGraphImporter importer(graph.flib_def(), debug_info, specs, - module.get(), &tf_name_to_mlir_name, - &function_name_uniquifier); - - TF_RETURN_IF_ERROR(importer.PrepareConvert(graph)); - - auto fn_names = graph.flib_def().ListFunctionNames(); - for (const auto& fn_name : fn_names) { - TF_RETURN_IF_ERROR(importer.ConvertLibFunction(fn_name)); + for (const auto& fn_name : graph.flib_def().ListFunctionNames()) { + std::string mlir_func_name(function_name_uniquifier.GetUniqueName(fn_name)); + (tf_name_to_mlir_name)[std::string(fn_name)] = mlir_func_name; } - TF_RETURN_IF_ERROR(importer.ConvertDeferredFunctions()); + + specs.convert_all_functions_to_mlir = true; + TF_ASSIGN_OR_RETURN(module, tensorflow::tf2xla::v2::ConvertGraphToTfExecutor( + graph, debug_info, graph.flib_def(), specs, + module->getContext())); if (!saved_model->meta_graph_def().has_object_graph_def()) { return errors::InvalidArgument( @@ -3639,7 +1083,8 @@ SavedModelObjectGraphImporter::Convert(SavedModelV2Bundle* saved_model, llvm::make_early_inc_range(module->getOps())) { if (func.getName().starts_with("__inference__traced_save_") || func.getName().starts_with("__inference__traced_restore_") || - func.getName().starts_with("__inference_signature_wrapper_")) { + func.getName().starts_with("__inference_signature_wrapper_") || + func.getName().starts_with("main")) { func.erase(); } } @@ -3823,7 +1268,7 @@ class SavedModelSignatureDefImporterLite { const std::vector>& inputs, const std::vector>& outputs, std::vector control_outputs, - std::unordered_map& tf_name_to_mlir_name); + std::unordered_map* tf_name_to_mlir_name); // Moves the functions in `sub_module` to `module_` and skips the duplicate // functions. @@ -3932,7 +1377,7 @@ Status SavedModelSignatureDefImporterLite::ConvertInitializer( std::unordered_map tf_name_to_mlir_name; TF_ASSIGN_OR_RETURN(auto sub_module, ConvertGraph(target_node_name, inputs, {}, - {target_node_name}, tf_name_to_mlir_name)); + {target_node_name}, &tf_name_to_mlir_name)); mlir::SymbolTable sub_symbol_table(*sub_module); @@ -3971,7 +1416,7 @@ SavedModelSignatureDefImporterLite::ConvertGraph( const std::vector>& inputs, const std::vector>& outputs, const std::vector control_outputs, - std::unordered_map& tf_name_to_mlir_name) { + std::unordered_map* tf_name_to_mlir_name) { VLOG(1) << "Importing Signature: " << name; GraphImportConfig specs; @@ -3991,10 +1436,9 @@ SavedModelSignatureDefImporterLite::ConvertGraph( TF_ASSIGN_OR_RETURN(const auto* subgraph, input_.GetSubGraph(name, specs)); // Convert sub-graph to MLIR module. - return GraphDefImporter::Convert(module_->getContext(), *subgraph, - input_.debug_info(), subgraph->flib_def(), - specs, tf_name_to_mlir_name, - /*disable_crash_analysis=*/true); + return tensorflow::tf2xla::v2::ConvertGraphToTfExecutor( + *subgraph, input_.debug_info(), subgraph->flib_def(), specs, + module_->getContext(), tf_name_to_mlir_name); } Status SavedModelSignatureDefImporterLite::ConvertSignature( @@ -4021,7 +1465,7 @@ Status SavedModelSignatureDefImporterLite::ConvertSignature( // Convert sub-graph to MLIR module. TF_ASSIGN_OR_RETURN( auto sub_module, - ConvertGraph(sig_def_key, inputs, outputs, {}, tf_name_to_mlir_name)); + ConvertGraph(sig_def_key, inputs, outputs, {}, &tf_name_to_mlir_name)); mlir::OpBuilder builder(sub_module->getBodyRegion()); // Find the FuncOp which corresponds to current SignatureDef. @@ -4313,41 +1757,17 @@ absl::StatusOr> ConvertGraphdefToMlir( } TF_RETURN_IF_ERROR(ConvertGraphDefToGraph( options, std::move(preprocessed_graphdef), &graph)); - return ConvertGraphToMlir(graph, debug_info, graph.flib_def(), specs, - context); + return tensorflow::tf2xla::v2::ConvertGraphToTfExecutor( + graph, debug_info, graph.flib_def(), specs, context); } absl::StatusOr> ConvertGraphToMlir( const Graph& graph, const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, - mlir::MLIRContext* context) { - // TODO(jpienaar): Remove need to const_cast. - if (specs.upgrade_legacy) { - TF_RETURN_IF_ERROR( - UpgradeLegacyGraph(const_cast(&graph), - const_cast(&flib_def), - specs.restrict_functionalization_to_compiled_nodes)); - } - - std::unordered_map tf_name_to_mlir_name; - TF_ASSIGN_OR_RETURN(auto module, GraphDefImporter::Convert( - context, graph, debug_info, flib_def, - specs, tf_name_to_mlir_name)); - - if (specs.set_original_tf_func_name) { - // Set up the original function names in the imported TF MLIR. - mlir::Builder builder(module->getContext()); - mlir::SymbolTable symbol_table(*module); - for (const auto& [tf_name, mlir_name] : tf_name_to_mlir_name) { - auto func_op = symbol_table.lookup(mlir_name); - TF_RET_CHECK(func_op) - << "Graphdef importer should have created a function named " - << mlir_name << "."; - func_op->setAttr("tf._original_func_name", - builder.getStringAttr(tf_name)); - } - } - return module; + mlir::MLIRContext* context, + std::unordered_map* tf_name_to_mlir_name) { + return tensorflow::tf2xla::v2::ConvertGraphToTfExecutor( + graph, debug_info, flib_def, specs, context, tf_name_to_mlir_name); } absl::StatusOr> ConvertFunctionToMlir( @@ -4360,16 +1780,15 @@ absl::StatusOr> ConvertFunctionToMlir( specs.graph_as_function = true; for (const auto* control_ret_node : fbody->control_ret_nodes) specs.control_outputs.push_back(control_ret_node->name()); - std::unordered_map tf_name_to_mlir_name; - return GraphDefImporter::Convert(context, *fbody->graph, dummy_debug_info, - flib_def, specs, tf_name_to_mlir_name); + return ConvertGraphToMlir(*fbody->graph, dummy_debug_info, flib_def, specs, + context); } absl::StatusOr> ConvertSavedModelToMlir( SavedModelV2Bundle* saved_model, mlir::MLIRContext* context, absl::Span exported_names, MLIRImportOptions options) { - return SavedModelObjectGraphImporter::Convert(saved_model, exported_names, - context, options); + return ConvertSavedModelObjectGraph(saved_model, exported_names, context, + options); } absl::StatusOr> ConvertSavedModelV1ToMlir( diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h index bca1f7f80af9e8..9e806a189a782f 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_IMPORT_MODEL_H_ +#include #include #include "absl/strings/string_view.h" @@ -39,20 +40,25 @@ inline constexpr absl::string_view kImportModelDefaultGraphFuncName = "main"; // Given a GraphDef, returns a MLIR module containing the graph, expressed with // tf_executor dialect. +ABSL_DEPRECATED("Use tensorflow::tf2xla::v2::ConvertGraphToTfExecutor instead.") absl::StatusOr> ConvertGraphdefToMlir( const GraphDef& graphdef, const GraphDebugInfo& debug_info, const GraphImportConfig& specs, mlir::MLIRContext* context); // Given a Graph, returns a MLIR module containing the graph, expressed with // tf_executor dialect. +ABSL_DEPRECATED("Use tensorflow::tf2xla::v2::ConvertGraphToTfExecutor instead.") absl::StatusOr> ConvertGraphToMlir( const Graph& graph, const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, - mlir::MLIRContext* context); + mlir::MLIRContext* context, + std::unordered_map* tf_name_to_mlir_name = + nullptr); // [Experimental] // Given a Function, returns a MLIR module containing the graph, expressed with // tf_executor dialect. +ABSL_DEPRECATED("Use tensorflow::tf2xla::v2::ConvertGraphToTfExecutor instead.") absl::StatusOr> ConvertFunctionToMlir( const FunctionBody* fbody, const FunctionLibraryDefinition& flib_def, mlir::MLIRContext* context); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h index fca039c2601636..8873b0928b028f 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h @@ -99,6 +99,11 @@ struct GraphImportConfig { // If true, a function attribute, `tf._original_func_name`, will be set in // functions which contains the corresponding original TF function name. bool set_original_tf_func_name = false; + + // If true, all functions in the graph will be converted to MLIR regardless of + // whether the functions are referenced by the nodes. This is needed if + // aliases and saved model object graph function matching is needed. + bool convert_all_functions_to_mlir = false; }; struct GraphExportConfig { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc index 92ecf3082588ab..2ab216771ffd4d 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc @@ -17,60 +17,18 @@ limitations under the License. // to satisfy the API of MLIR pass registration. In order to do this, the // command-line option header is pulled in. -#include #include -#include "absl/container/flat_hash_set.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h" -#include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "xla/client/client_library.h" -#include "xla/client/compile_only_client.h" -#include "xla/stream_executor/host/host_platform_id.h" -#include "xla/stream_executor/platform_manager.h" -#include "tensorflow/core/common_runtime/graph_constructor.h" -#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" -#include "tsl/platform/protobuf.h" namespace mlir { using tsl::Status; using tsl::StatusOr; -static constexpr char kMlirToGraphCompilationCheckName[] = - "mlir-to-graph-compilation-check"; -// Use CPU arbitrarily in order to check that a graph compiles at all -static constexpr char kArbitraryDeviceName[] = "XLA_CPU_JIT"; - -namespace { -inline absl::string_view StringRefToView(llvm::StringRef ref) { - return {ref.data(), ref.size()}; -} -} // namespace - -static OwningOpRef GraphdefToMlirTranslateFunction( - llvm::StringRef input, MLIRContext* context) { - tensorflow::GraphdefToMlirOptions options{ - debug_info_file, xla_compile_device_type, - prune_unused_nodes, convert_legacy_fed_inputs, - graph_as_function, upgrade_legacy, - enable_shape_inference, unconditionally_use_set_output_shapes, - enable_soft_placement, set_original_tf_func_name}; - - auto module_or = tensorflow::GraphdefToMlirTranslateFunction( - input, input_arrays, input_dtypes, input_shapes, output_arrays, - control_output_arrays, options, context); - if (!module_or.status().ok()) return nullptr; - return std::move(module_or).value(); -} - -static TranslateToMLIRRegistration GraphdefToMlirTranslate( - "graphdef-to-mlir", "graphdef-to-mlir", GraphdefToMlirTranslateFunction); static OwningOpRef GraphdefToSplattedMlirTranslateFunction( llvm::StringRef input, MLIRContext* context) { @@ -90,112 +48,4 @@ static TranslateToMLIRRegistration GraphdefToSplattedMlirTranslate( "graphdef-to-splatted-mlir", "graphdef-to-splatted-mlir", GraphdefToSplattedMlirTranslateFunction); -static Status CompileGraph(tensorflow::Graph* graph, - xla::CompileOnlyClient* client) { - if (!graph || !client) { - return Status(absl::StatusCode::kInvalidArgument, - "Invalid graph or client"); - } - - tensorflow::FunctionDefLibrary flib; - auto flib_def = std::make_unique( - tensorflow::OpRegistry::Global(), flib); - - tensorflow::XlaCompiler::Options options; - options.device_type = tensorflow::DeviceType(kArbitraryDeviceName); - options.client = client; - options.flib_def = flib_def.get(); - tensorflow::XlaCompiler compiler(options); - - std::unique_ptr graph_copy( - new tensorflow::Graph(tensorflow::OpRegistry::Global())); - tensorflow::CopyGraph(*graph, graph_copy.get()); - - tensorflow::XlaCompiler::CompileOptions compile_options; - tensorflow::XlaCompiler::CompilationResult result; - return compiler.CompileGraph(compile_options, - kMlirToGraphCompilationCheckName, - std::move(graph_copy), {}, &result); -} - -static LogicalResult MlirToGraphTranslateFunction(ModuleOp module, - llvm::raw_ostream& output) { - if (!module) return failure(); - - tensorflow::GraphExportConfig confs; - confs.export_entry_func_to_flib = export_entry_func_to_flib; - confs.export_original_tf_func_name = export_original_tf_func_name; - - std::unique_ptr flib_def; - auto graph = - std::make_unique(tensorflow::OpRegistry::Global()); - absl::flat_hash_set control_ret_nodes; - auto status = tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( - module, confs, &graph, flib_def.get(), &control_ret_nodes); - if (!status.ok()) { - LOG(ERROR) << "Export to Graph failed: " << status; - return mlir::failure(); - } - - // Use Host platform, which should always exist, to make sure graphs compile. - auto platform = stream_executor::PlatformManager::PlatformWithId( - stream_executor::host::kHostPlatformId); - auto client = - xla::ClientLibrary::GetOrCreateCompileOnlyClient(platform.value()); - - tensorflow::XlaOpRegistry::RegisterCompilationKernels(); - - // Verify that the resulting graph can compile. - if (!CompileGraph(graph.get(), client.value()).ok()) { - return mlir::failure(); - } - - auto graphdef = std::make_unique(); - // Print the graph to the output after going through GraphDef conversion. - // The DumpGraphToFile would do this anyway so just skip straight to it. - graph->ToGraphDef(graphdef.get()); - output << tsl::LegacyUnredactedDebugString(*graphdef); - - return success(); -} - -static TranslateFromMLIRRegistration mlir_to_graph_translate( - /*name=*/"mlir-to-graph", /*description=*/"convert mlir to graph", - MlirToGraphTranslateFunction, [](DialectRegistry& registry) { - mlir::RegisterAllTensorFlowDialects(registry); - }); - -static LogicalResult MlirToGraphdefTranslateFunction( - ModuleOp module, llvm::raw_ostream& output) { - if (!module) return failure(); - - tensorflow::GraphExportConfig confs; - confs.export_entry_func_to_flib = export_entry_func_to_flib; - confs.export_original_tf_func_name = export_original_tf_func_name; - - tensorflow::FunctionLibraryDefinition flib_def( - tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary()); - auto graph = - std::make_unique(tensorflow::OpRegistry::Global()); - absl::flat_hash_set control_ret_nodes; - - auto status = tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( - module, confs, &graph, &flib_def, &control_ret_nodes); - if (!status.ok()) { - LOG(ERROR) << "Export to Graph failed: " << status; - return mlir::failure(); - } - - tensorflow::GraphDef graphdef; - graph->ToGraphDef(&graphdef); - output << tsl::LegacyUnredactedDebugString(graphdef); - return success(); -} - -static TranslateFromMLIRRegistration mlir_to_graphdef_translate( - "mlir-to-graphdef", "mlir-to-graphdef", MlirToGraphdefTranslateFunction, - [](DialectRegistry& registry) { - mlir::RegisterAllTensorFlowDialects(registry); - }); - } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc index f7058467d89035..858c70a54a58d6 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.cc @@ -15,21 +15,32 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" +#include #include +#include #include #include #include +#include "absl/log/log.h" #include "absl/strings/str_split.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" +#include "tensorflow/core/platform/env.h" namespace tensorflow { +constexpr char kPassFilterEnvVar[] = "MLIR_BRIDGE_LOG_PASS_FILTER"; +constexpr char kStringFilterEnvVar[] = "MLIR_BRIDGE_LOG_STRING_FILTER"; +constexpr char kEnableOnlyTopLevelPassesEnvVar[] = + "MLIR_BRIDGE_LOG_ENABLE_ONLY_TOP_LEVEL_PASSES"; + // Counter is used as a prefix for filenames. static std::atomic log_counter(0); @@ -39,8 +50,11 @@ BridgeLoggerConfig::BridgeLoggerConfig(bool print_module_scope, : mlir::PassManager::IRPrinterConfig( print_module_scope, print_after_only_on_change, /*printAfterOnlyOnFailure=*/false, op_printing_flags), - pass_filter_(GetFilter("MLIR_BRIDGE_LOG_PASS_FILTER")), - string_filter_(GetFilter("MLIR_BRIDGE_LOG_STRING_FILTER")) {} + pass_filter_(GetFilter(kPassFilterEnvVar)), + string_filter_(GetFilter(kStringFilterEnvVar)) { + setenv(/*name=*/kEnableOnlyTopLevelPassesEnvVar, + /*value=*/"false", /*overwrite=*/0); +} // Logs op to file with name of format // `_mlir_bridge__.mlir`. @@ -83,6 +97,14 @@ std::vector BridgeLoggerConfig::GetFilter( return filter; } +bool BridgeLoggerConfig::ShouldOnlyDumpTopLevelPasses() { + const char* env_var = getenv(kEnableOnlyTopLevelPassesEnvVar); + std::string value(env_var); + std::transform(value.begin(), value.end(), value.begin(), ::tolower); + // Return true if value is "1" or "true"; otherwise, false. + return value == "1" || value == "true"; +} + bool BridgeLoggerConfig::MatchesFilter(const std::string& str, const std::vector& filter, bool exact_match) { @@ -104,6 +126,18 @@ bool BridgeLoggerConfig::ShouldPrint(mlir::Pass* pass, mlir::Operation* op) { "`MLIR_BRIDGE_LOG_PASS_FILTER`"; return false; } + if (ShouldOnlyDumpTopLevelPasses()) { + // Check if the operation is the top-level module. + // Top-level module has no parent. + if (op->getParentOp() != nullptr) { + // This is a nested operation; do not print. + VLOG(1) << "Not logging invocation of pass `" << pass_name + << "` because it is applied to a nested operation, due to " + "`MLIR_BRIDGE_LOG_ENABLE_ONLY_TOP_LEVEL_PASSES` being set " + "to true"; + return false; + } + } if (!string_filter_.empty()) { std::string serialized_op; llvm::raw_string_ostream os(serialized_op); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h index 2f65a87bef4d84..84bc1c60d4e133 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h @@ -23,7 +23,6 @@ limitations under the License. #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Support/Timing.h" // from @llvm-project namespace tensorflow { @@ -83,6 +82,9 @@ class BridgeLoggerConfig : public mlir::PassManager::IRPrinterConfig { static bool MatchesFilter(const std::string& str, const std::vector& filter, bool exact_match); + // Determines whether only top-level passes should be dumped. + // Returns true unless the environment variable is set to "0" or "false". + static bool ShouldOnlyDumpTopLevelPasses(); // Only log pass invocations whose pass name exactly matches any string in // `pass_filter_` (or when `pass_filter_` is empty). diff --git a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger_test.cc index b2d2d71128a161..96df789efbaaaf 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger_test.cc @@ -17,12 +17,17 @@ limitations under the License. #include +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" -#include "tensorflow/core/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -47,11 +52,32 @@ func.func @main(%arg0: tensor<7x8x9xi8>, %arg1: tensor<7x8x9xi8>) -> tensor<7x8x } )"; +void UnsetEnvironmentVariables() { + unsetenv(/*name=*/"MLIR_BRIDGE_LOG_PASS_FILTER"); + unsetenv(/*name=*/"MLIR_BRIDGE_LOG_STRING_FILTER"); + unsetenv(/*name=*/"MLIR_BRIDGE_LOG_ENABLE_ONLY_TOP_LEVEL_PASSES"); +} + +class BridgeLoggerFilters : public ::testing::Test { + protected: + void SetUp() override { UnsetEnvironmentVariables(); } + + mlir::MLIRContext CreateMlirContext() { + mlir::DialectRegistry mlir_registry; + mlir::RegisterAllTensorFlowDialects(mlir_registry); + return mlir::MLIRContext(mlir_registry); + } + + mlir::func::FuncOp GetFuncOp(mlir::ModuleOp module_op) { + auto func_ops = module_op.getOps(); + EXPECT_FALSE(func_ops.empty()); + return *func_ops.begin(); + } +}; + // Test pass filter. -TEST(BridgeLoggerFilters, TestPassFilter) { - mlir::DialectRegistry mlir_registry; - mlir::RegisterAllTensorFlowDialects(mlir_registry); - mlir::MLIRContext mlir_context(mlir_registry); +TEST_F(BridgeLoggerFilters, TestPassFilter) { + mlir::MLIRContext mlir_context = CreateMlirContext(); mlir::OwningOpRef mlir_module_with_add; TF_ASSERT_OK(DeserializeMlirModule(module_with_add, &mlir_context, &mlir_module_with_add)); @@ -64,9 +90,10 @@ TEST(BridgeLoggerFilters, TestPassFilter) { // partitioning_pass and shape_inference_pass should match the filter, // inliner_pass should not. - setenv("MLIR_BRIDGE_LOG_PASS_FILTER", + setenv(/*name=*/"MLIR_BRIDGE_LOG_PASS_FILTER", + /*value=*/ "TPUResourceReadsWritesPartitioningPass;TensorFlowShapeInferencePass", - 1); + /*overwrite=*/1); BridgeLoggerConfig logger_config; EXPECT_TRUE(logger_config.ShouldPrint(partitioning_pass.get(), mlir_module_with_add.get())); @@ -77,10 +104,8 @@ TEST(BridgeLoggerFilters, TestPassFilter) { } // Test string filter. -TEST(BridgeLoggerFilters, TestStringFilter) { - mlir::DialectRegistry mlir_registry; - mlir::RegisterAllTensorFlowDialects(mlir_registry); - mlir::MLIRContext mlir_context(mlir_registry); +TEST_F(BridgeLoggerFilters, TestStringFilter) { + mlir::MLIRContext mlir_context = CreateMlirContext(); mlir::OwningOpRef mlir_module_with_add, mlir_module_with_sub; TF_ASSERT_OK(DeserializeMlirModule(module_with_add, &mlir_context, &mlir_module_with_add)); @@ -91,7 +116,8 @@ TEST(BridgeLoggerFilters, TestStringFilter) { mlir::TF::CreateTFShapeInferencePass(); // One string appears in both modules and the other one not. - setenv("MLIR_BRIDGE_LOG_STRING_FILTER", "func @main(%arg0: tensor;XXX", 1); + setenv(/*name=*/"MLIR_BRIDGE_LOG_STRING_FILTER", + /*value=*/"func @main(%arg0: tensor;XXX", /*overwrite=*/1); BridgeLoggerConfig logger_config1; EXPECT_TRUE( logger_config1.ShouldPrint(dummy_pass.get(), mlir_module_with_add.get())); @@ -99,7 +125,8 @@ TEST(BridgeLoggerFilters, TestStringFilter) { logger_config1.ShouldPrint(dummy_pass.get(), mlir_module_with_sub.get())); // Both strings do not appear in any module. - setenv("MLIR_BRIDGE_LOG_STRING_FILTER", "func @main(%arg0:tensor;XXX", 1); + setenv(/*name=*/"MLIR_BRIDGE_LOG_STRING_FILTER", + /*value=*/"func @main(%arg0:tensor;XXX", /*overwrite=*/1); BridgeLoggerConfig logger_config2; EXPECT_FALSE( logger_config2.ShouldPrint(dummy_pass.get(), mlir_module_with_add.get())); @@ -107,8 +134,9 @@ TEST(BridgeLoggerFilters, TestStringFilter) { logger_config2.ShouldPrint(dummy_pass.get(), mlir_module_with_sub.get())); // String appears in one module but not in the other. - setenv("MLIR_BRIDGE_LOG_STRING_FILTER", - "\"tf.AddV2\"(%arg0, %arg1) : (tensor<3x4x5xf32>", 1); + setenv(/*name=*/"MLIR_BRIDGE_LOG_STRING_FILTER", + /*value=*/"\"tf.AddV2\"(%arg0, %arg1) : (tensor<3x4x5xf32>", + /*overwrite=*/1); BridgeLoggerConfig logger_config3; EXPECT_TRUE( logger_config3.ShouldPrint(dummy_pass.get(), mlir_module_with_add.get())); @@ -116,11 +144,84 @@ TEST(BridgeLoggerFilters, TestStringFilter) { logger_config3.ShouldPrint(dummy_pass.get(), mlir_module_with_sub.get())); } -// Test both filters together. -TEST(BridgeLoggerFilters, TestBothFilters) { - mlir::DialectRegistry mlir_registry; - mlir::RegisterAllTensorFlowDialects(mlir_registry); - mlir::MLIRContext mlir_context(mlir_registry); +// Test enable only top level passes filter. +TEST_F(BridgeLoggerFilters, TestEnableOnlyTopLevelPassesFilter) { + mlir::MLIRContext mlir_context = CreateMlirContext(); + mlir::OwningOpRef mlir_module_with_add; + TF_ASSERT_OK(DeserializeMlirModule(module_with_add, &mlir_context, + &mlir_module_with_add)); + + std::unique_ptr shape_inference_pass = + mlir::TF::CreateTFShapeInferencePass(); + + BridgeLoggerConfig logger_config; + // ShouldPrint returns true for the top-level module operation. + EXPECT_TRUE(logger_config.ShouldPrint(shape_inference_pass.get(), + mlir_module_with_add.get())); + // Find the nested function operation within the module. + mlir::func::FuncOp func_op = GetFuncOp(mlir_module_with_add.get()); + // ShouldPrint returns true for the nested function operation. + EXPECT_TRUE(logger_config.ShouldPrint(shape_inference_pass.get(), func_op)); + + // Set the environment variable to enable only top-level passes. + setenv(/*name=*/"MLIR_BRIDGE_LOG_ENABLE_ONLY_TOP_LEVEL_PASSES", /*value=*/"1", + /*overwrite=*/1); + BridgeLoggerConfig logger_config_filter; + // ShouldPrint returns false for the nested function operation. + EXPECT_FALSE( + logger_config_filter.ShouldPrint(shape_inference_pass.get(), func_op)); +} + +// Additional tests for various possible values of +// MLIR_BRIDGE_LOG_ENABLE_ONLY_TOP_LEVEL_PASSES. +TEST_F(BridgeLoggerFilters, TestEnableOnlyTopLevelPassesEnvVarValues) { + mlir::MLIRContext mlir_context = CreateMlirContext(); + mlir::OwningOpRef mlir_module_with_add; + TF_ASSERT_OK(DeserializeMlirModule(module_with_add, &mlir_context, + &mlir_module_with_add)); + + std::unique_ptr shape_inference_pass = + mlir::TF::CreateTFShapeInferencePass(); + + mlir::ModuleOp module_op = *mlir_module_with_add; + // Find the nested function operation within the module. + mlir::func::FuncOp func_op = GetFuncOp(module_op); + + // Test with environment variable set to "FALSE". + setenv(/*name=*/"MLIR_BRIDGE_LOG_ENABLE_ONLY_TOP_LEVEL_PASSES", + /*value=*/"FALSE", /*overwrite=*/1); + BridgeLoggerConfig logger_config_false; + // ShouldPrint should return true for top-level operation. + EXPECT_TRUE(logger_config_false.ShouldPrint(shape_inference_pass.get(), + mlir_module_with_add.get())); + // ShouldPrint should return true for nested function. + EXPECT_TRUE( + logger_config_false.ShouldPrint(shape_inference_pass.get(), func_op)); + + // Test with environment variable unset (default behavior). + unsetenv(/*name=*/"MLIR_BRIDGE_LOG_ENABLE_ONLY_TOP_LEVEL_PASSES"); + BridgeLoggerConfig logger_config_default; + // ShouldPrint should return true for top-level operation. + EXPECT_TRUE(logger_config_default.ShouldPrint(shape_inference_pass.get(), + mlir_module_with_add.get())); + // ShouldPrint should return true for nested function since default + // is disabled. + EXPECT_TRUE( + logger_config_default.ShouldPrint(shape_inference_pass.get(), func_op)); + + // Test with environment variable set to "1". + setenv(/*name=*/"MLIR_BRIDGE_LOG_ENABLE_ONLY_TOP_LEVEL_PASSES", /*value=*/"1", + /*overwrite=*/1); + BridgeLoggerConfig logger_config_one; + // ShouldPrint should return false for nested function since filter + // is enabled. + EXPECT_FALSE( + logger_config_one.ShouldPrint(shape_inference_pass.get(), func_op)); +} + +// Test combinations of pass filter and string filter. +TEST_F(BridgeLoggerFilters, TestPassFilterAndStringFilter) { + mlir::MLIRContext mlir_context = CreateMlirContext(); mlir::OwningOpRef mlir_module_with_add; TF_ASSERT_OK(DeserializeMlirModule(module_with_add, &mlir_context, &mlir_module_with_add)); @@ -129,28 +230,193 @@ TEST(BridgeLoggerFilters, TestBothFilters) { mlir::TF::CreateTFShapeInferencePass(); // String filter is matched but pass filter is not. - setenv("MLIR_BRIDGE_LOG_STRING_FILTER", - "(tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32>", 1); - setenv("MLIR_BRIDGE_LOG_PASS_FILTER", "ensorFlowShapeInferencePass", 1); + setenv(/*name=*/"MLIR_BRIDGE_LOG_STRING_FILTER", + /*value=*/ + "(tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> " + "tensor<3x4x5xf32>", + /*overwrite=*/1); + setenv(/*name=*/"MLIR_BRIDGE_LOG_PASS_FILTER", + /*value=*/"ensorFlowShapeInferencePass", /*overwrite=*/1); BridgeLoggerConfig logger_config1; EXPECT_FALSE(logger_config1.ShouldPrint(shape_inference_pass.get(), mlir_module_with_add.get())); // Pass filter is matched but string filter is not. - setenv("MLIR_BRIDGE_LOG_STRING_FILTER", "XXX", 1); - setenv("MLIR_BRIDGE_LOG_PASS_FILTER", "TensorFlowShapeInferencePass", 1); + setenv(/*name=*/"MLIR_BRIDGE_LOG_STRING_FILTER", /*value=*/"XXX", 1); + setenv(/*name=*/"MLIR_BRIDGE_LOG_PASS_FILTER", + /*value=*/"TensorFlowShapeInferencePass", /*overwrite=*/1); BridgeLoggerConfig logger_config2; EXPECT_FALSE(logger_config2.ShouldPrint(shape_inference_pass.get(), mlir_module_with_add.get())); // Both filters are matched. - setenv("MLIR_BRIDGE_LOG_STRING_FILTER", - "(tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32>", 1); - setenv("MLIR_BRIDGE_LOG_PASS_FILTER", "TensorFlowShapeInferencePass", 1); + setenv(/*name=*/"MLIR_BRIDGE_LOG_STRING_FILTER", + /*value=*/ + "(tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> " + "tensor<3x4x5xf32>", + /*overwrite=*/1); + setenv(/*name=*/"MLIR_BRIDGE_LOG_PASS_FILTER", + /*value=*/"TensorFlowShapeInferencePass", /*overwrite=*/1); BridgeLoggerConfig logger_config3; EXPECT_TRUE(logger_config3.ShouldPrint(shape_inference_pass.get(), mlir_module_with_add.get())); } +// Test combinations of pass filter and enable only top level passes filter. +TEST_F(BridgeLoggerFilters, TestPassFilterAndEnableOnlyTopLevelPassesFilter) { + mlir::MLIRContext mlir_context = CreateMlirContext(); + mlir::OwningOpRef mlir_module_with_sub; + TF_ASSERT_OK(DeserializeMlirModule(module_with_sub, &mlir_context, + &mlir_module_with_sub)); + + std::unique_ptr shape_inference_pass = + mlir::TF::CreateTFShapeInferencePass(); + std::unique_ptr inliner_pass = mlir::createInlinerPass(); + + // Find the nested function operation within the module. + mlir::func::FuncOp func_op = GetFuncOp(mlir_module_with_sub.get()); + + setenv(/*name=*/"MLIR_BRIDGE_LOG_PASS_FILTER", + /*value=*/"TensorFlowShapeInferencePass", /*overwrite=*/1); + BridgeLoggerConfig logger_config; + // ShouldPrint should return true for top-level operation with matching pass + // filter. + EXPECT_TRUE(logger_config.ShouldPrint(shape_inference_pass.get(), + mlir_module_with_sub.get())); + // ShouldPrint should return true for nested operation when + // enable_only_top_level_passes_ is false. + EXPECT_TRUE(logger_config.ShouldPrint(shape_inference_pass.get(), func_op)); + // ShouldPrint should return false for pass not matching the pass filter. + EXPECT_FALSE(logger_config.ShouldPrint(inliner_pass.get(), + mlir_module_with_sub.get())); + + // Set the environment variable to enable only top-level passes. + setenv(/*name=*/"MLIR_BRIDGE_LOG_ENABLE_ONLY_TOP_LEVEL_PASSES", /*value=*/"1", + /*overwrite=*/1); + BridgeLoggerConfig logger_config_filter; + // ShouldPrint should return false for nested operation + EXPECT_FALSE( + logger_config_filter.ShouldPrint(shape_inference_pass.get(), func_op)); +} + +// Test combinations of string filter and enable only top level passes filter. +TEST_F(BridgeLoggerFilters, TestStringFilterAndEnableOnlyTopLevelPassesFilter) { + mlir::MLIRContext mlir_context = CreateMlirContext(); + mlir::OwningOpRef mlir_module_with_add; + TF_ASSERT_OK(DeserializeMlirModule(module_with_add, &mlir_context, + &mlir_module_with_add)); + + std::unique_ptr shape_inference_pass = + mlir::TF::CreateTFShapeInferencePass(); + + // Find the nested function operation within the module. + mlir::func::FuncOp func_op = GetFuncOp(mlir_module_with_add.get()); + + setenv(/*name=*/"MLIR_BRIDGE_LOG_STRING_FILTER", /*value=*/"tf.AddV2", + /*overwrite=*/1); + BridgeLoggerConfig logger_config; + // ShouldPrint should return true for top-level operation containing + // "tf.AddV2". + EXPECT_TRUE(logger_config.ShouldPrint(shape_inference_pass.get(), + mlir_module_with_add.get())); + // ShouldPrint should return true for nested operation since + // enable_only_top_level_passes_ is false. + EXPECT_TRUE(logger_config.ShouldPrint(shape_inference_pass.get(), func_op)); + + // Set the environment variable to enable only top-level passes. + setenv(/*name=*/"MLIR_BRIDGE_LOG_ENABLE_ONLY_TOP_LEVEL_PASSES", /*value=*/"1", + /*overwrite=*/1); + BridgeLoggerConfig logger_config_filter; + // ShouldPrint should return false for nested operation since string + // filter matches but enable_only_top_level_passes_ is true. + EXPECT_FALSE( + logger_config_filter.ShouldPrint(shape_inference_pass.get(), func_op)); + + // Change string filter to not match any operation. + setenv(/*name=*/"MLIR_BRIDGE_LOG_STRING_FILTER", /*value=*/"NonExistentOp", + /*overwrite=*/1); + BridgeLoggerConfig logger_config_no_match; + // ShouldPrint should return false since string filter does not match. + EXPECT_FALSE(logger_config_no_match.ShouldPrint(shape_inference_pass.get(), + mlir_module_with_add.get())); +} + +// Test combinations where all filters are set but none match. +TEST_F(BridgeLoggerFilters, TestAllFiltersNoMatch) { + mlir::MLIRContext mlir_context = CreateMlirContext(); + mlir::OwningOpRef mlir_module_with_sub; + TF_ASSERT_OK(DeserializeMlirModule(module_with_sub, &mlir_context, + &mlir_module_with_sub)); + + std::unique_ptr shape_inference_pass = + mlir::TF::CreateTFShapeInferencePass(); + + // Set pass filter to not match any pass + setenv(/*name=*/"MLIR_BRIDGE_LOG_PASS_FILTER", /*value=*/"NonExistentPass", + /*overwrite=*/1); + // Set string filter to not match any string + setenv(/*name=*/"MLIR_BRIDGE_LOG_STRING_FILTER", /*value=*/"NonExistentOp", + /*overwrite=*/1); + BridgeLoggerConfig logger_config; + // ShouldPrint should return false since none of the filters match. + EXPECT_FALSE(logger_config.ShouldPrint(shape_inference_pass.get(), + mlir_module_with_sub.get())); + + // Set the environment variable to enable only top-level passes. + setenv(/*name=*/"MLIR_BRIDGE_LOG_ENABLE_ONLY_TOP_LEVEL_PASSES", /*value=*/"1", + /*overwrite=*/1); + BridgeLoggerConfig logger_config_filter; + // ShouldPrint should still return false since pass and string filters do not + // match. + EXPECT_FALSE(logger_config_filter.ShouldPrint(shape_inference_pass.get(), + mlir_module_with_sub.get())); +} + +// Test combinations of all three filters. +TEST_F(BridgeLoggerFilters, TestAllFiltersCombination) { + mlir::MLIRContext mlir_context = CreateMlirContext(); + mlir::OwningOpRef mlir_module_with_add; + TF_ASSERT_OK(DeserializeMlirModule(module_with_add, &mlir_context, + &mlir_module_with_add)); + + std::unique_ptr shape_inference_pass = + mlir::TF::CreateTFShapeInferencePass(); + std::unique_ptr inliner_pass = mlir::createInlinerPass(); + + // Find the nested function operation within the module. + mlir::func::FuncOp func_op = GetFuncOp(mlir_module_with_add.get()); + + // Set all three filters. + setenv(/*name=*/"MLIR_BRIDGE_LOG_PASS_FILTER", + /*value=*/"TensorFlowShapeInferencePass", /*overwrite=*/1); + setenv(/*name=*/"MLIR_BRIDGE_LOG_STRING_FILTER", /*value=*/"tf.AddV2", + /*overwrite=*/1); + BridgeLoggerConfig logger_config; + // ShouldPrint should return true if all filters pass and operation is + // top-level. + EXPECT_TRUE(logger_config.ShouldPrint(shape_inference_pass.get(), + mlir_module_with_add.get())); + + // ShouldPrint should return false if enable_only_top_level_passes_ is + // true and operation is nested. + setenv(/*name=*/"MLIR_BRIDGE_LOG_ENABLE_ONLY_TOP_LEVEL_PASSES", /*value=*/"1", + /*overwrite=*/1); + BridgeLoggerConfig logger_config_filter; + EXPECT_FALSE( + logger_config_filter.ShouldPrint(shape_inference_pass.get(), func_op)); + // Change to a pass that does not match the pass filter. + EXPECT_FALSE(logger_config_filter.ShouldPrint(inliner_pass.get(), + mlir_module_with_add.get())); + // Set the environment variable to disable only top-level passes. + unsetenv(/*name=*/"MLIR_BRIDGE_LOG_ENABLE_ONLY_TOP_LEVEL_PASSES"); + BridgeLoggerConfig logger_config_no_filter; + // ShouldPrint should return true for nested operation since + // enable_only_top_level_passes_ is false. + EXPECT_TRUE( + logger_config_no_filter.ShouldPrint(shape_inference_pass.get(), func_op)); + // Change to a pass that does not match the pass filter. + EXPECT_FALSE(logger_config_no_filter.ShouldPrint(inliner_pass.get(), + mlir_module_with_add.get())); +} } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc index c49971a8a8c0c7..1270865e551d52 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc @@ -28,11 +28,11 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "xla/tsl/lib/io/buffered_file.h" #include "tensorflow/core/platform/crash_analysis.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/path.h" -#include "tsl/lib/io/buffered_file.h" using llvm::raw_ostream; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc index 14c723f8fce3da..1394ff66cd7269 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.cc @@ -17,8 +17,13 @@ limitations under the License. #include +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" +#include "tensorflow/core/ir/utils/shape_inference_utils.h" #define DEBUG_TYPE "tf-shape-inference-utils" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h index 040429ccf73057..28e2c93f5dfcac 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h @@ -18,6 +18,10 @@ limitations under the License. #include +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/core/ir/utils/shape_inference_utils.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.h b/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.h index 102dd7008f00f9..0c6e1532429d42 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.h @@ -20,7 +20,9 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.cc b/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.cc index 6ab4aa64a89070..8bec78295dcf47 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.cc @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo namespace mlir { namespace TF { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.h b/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.h index 7bb38112f77ee1..ea1ae8c8cc647b 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_STABLEHLO_CUSTOM_CALL_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_STABLEHLO_CUSTOM_CALL_H_ +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/string_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/string_util.cc index 7fd832e76042e4..b90fd80916e9e7 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/string_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/string_util.cc @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc index 5a29bae67afe01..aa3d8c50952068 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc @@ -20,6 +20,9 @@ limitations under the License. #include #include +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" @@ -33,32 +36,42 @@ limitations under the License. #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h" #include "tensorflow/compiler/mlir/utils/string_container_utils.h" +#include "tensorflow/compiler/tf2xla/layout_util.h" #include "tensorflow/compiler/tf2xla/xla_argument.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" +#include "xla/shape.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace { @@ -381,7 +394,7 @@ static void RegisterMlirInputDialects(mlir::DialectRegistry& registry) { registry .insert(); + mlir::quant::QuantDialect>(); mlir::func::registerAllExtensions(registry); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/topological_sort.cc b/tensorflow/compiler/mlir/tensorflow/utils/topological_sort.cc index 460c4baa49e8dc..289444a9977047 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/topological_sort.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/topological_sort.cc @@ -20,7 +20,13 @@ limitations under the License. #include #include -#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace TF { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/topological_sort.h b/tensorflow/compiler/mlir/tensorflow/utils/topological_sort.h index a62fe17add2462..1daab85570e738 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/topological_sort.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/topological_sort.h @@ -20,7 +20,11 @@ limitations under the License. #include #include "absl/types/span.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project namespace mlir { namespace TF { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_cluster_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_cluster_util.cc index 9c82c728f5d917..451ac2c1aa2640 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_cluster_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_cluster_util.cc @@ -19,7 +19,14 @@ limitations under the License. #include #include +#include "llvm/ADT/SmallVector.h" #include "mlir/Analysis/CallGraph.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc index bcf9a21d26efc1..c20601d51dffa0 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc @@ -32,6 +32,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/LogicalResult.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -50,7 +51,6 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/protobuf/tpu/topology.pb.h" #include "tensorflow/core/util/device_name_utils.h" -#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h index 250c21d627c7ed..cdbf73965968d9 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h @@ -20,12 +20,15 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" @@ -34,6 +37,7 @@ limitations under the License. #include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/protobuf/tpu/topology.pb.h" #include "tensorflow/core/util/device_name_utils.h" +#include "tsl/platform/statusor.h" namespace tensorflow { using tsl::StatusOr; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc index 8527ae80b967d2..11b3077b582bdc 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc @@ -21,20 +21,33 @@ limitations under the License. #include #include +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" -#include "tensorflow/core/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/xla_data.pb.h" #include "tensorflow/core/platform/status_matchers.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/tpu/topology.pb.h" #include "tensorflow/core/util/device_name_utils.h" +#include "tsl/platform/statusor.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc index e23e2313711f9c..72bf48ded7dc5f 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc @@ -15,8 +15,13 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" +#include "absl/status/statusor.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/platform/errors.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h index f9acbb9a88e7cb..60beacc8ecb837 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TRANSLATE_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TRANSLATE_UTILS_H_ +#include "absl/status/statusor.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.cc index bcc70642c97a73..f65494a279560f 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h" -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project namespace mlir { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h index 0fdcfb4799636f..1a399df89578ac 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VERIFICATION_UTILS_H_ #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace TF { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.cc b/tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.cc index 9bfa57fd29762b..c3f96608b5a3b8 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.cc @@ -15,8 +15,12 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/verify_suitable_for_graph_export.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/visitor.cc b/tensorflow/compiler/mlir/tensorflow/utils/visitor.cc index 517a56de5de0e5..61582cc85c98db 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/visitor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/visitor.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/visitor.h b/tensorflow/compiler/mlir/tensorflow/utils/visitor.h index 6a7ada0bdb8824..9fd25569335608 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/visitor.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/visitor.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_H_ #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_VISITOR_H_ +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util.cc index f5f9a1feab2e91..5e7768c3ce0fc3 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util.cc @@ -15,6 +15,21 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util.h" +#include "absl/log/log.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + namespace tensorflow { mlir::LogicalResult EraseClusterFuncs( llvm::MutableArrayRef to_be_erased) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util.h b/tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util.h index d1c8cabf115653..8ce5403e1b54a0 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util.h @@ -25,6 +25,10 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util_test.cc index c0046f83664223..d5253ae49d1fff 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_rewrite_util_test.cc @@ -17,19 +17,21 @@ limitations under the License. #include +#include "absl/status/statusor.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" -#include "llvm/Support/FormatVariadic.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "llvm/Support/Casting.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" -#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/tpu/topology.pb.h" -#include "tensorflow/core/util/device_name_utils.h" +#include "tsl/platform/statusor.h" // #include // #include diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc index addf76366984a0..ac8ecf1090b2a7 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc @@ -44,16 +44,19 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "xla/service/hlo_parser.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/tsl/lib/math/math_util.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tsl/lib/math/math_util.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" namespace tensorflow { namespace { @@ -73,15 +76,93 @@ int64_t GetPadding(const int split_dim, const int num_splits, return total_padding; } +mlir::TF::SliceOp CreateSliceOp(mlir::OpBuilder* builder, + const mlir::Location& location, + mlir::Value input, + const PartialTensorShape& shape) { + mlir::SmallVector slice_start_position; + for (int i = 0; i < shape.dims(); ++i) { + slice_start_position.push_back(0); + } + mlir::SmallVector slice_size; + for (int i = 0; i < shape.dims(); ++i) { + slice_size.push_back(shape.dim_size(i)); + } + + auto start_position_type = + mlir::RankedTensorType::get(shape.dims(), builder->getIntegerType(64)); + + auto start_position_op = builder->create( + input.getLoc(), mlir::DenseIntElementsAttr::get(start_position_type, + slice_start_position)); + + auto slice_size_op = builder->create( + input.getLoc(), mlir::DenseIntElementsAttr::get( + mlir::RankedTensorType::get( + shape.dims(), builder->getIntegerType(64)), + slice_size)); + + auto slice_result_type = + mlir::RankedTensorType::get(slice_size, getElementTypeOrSelf(input)); + + return builder->create(input.getLoc(), slice_result_type, + input, start_position_op, + slice_size_op); +} + +mlir::TF::PadOp CreatePadOp(mlir::OpBuilder* builder, + const mlir::Location& location, int64_t num_dims, + int64_t split_dim, mlir::Value src_input, + int64_t padding) { + auto input_type = mlir::cast(src_input.getType()); + llvm::SmallVector padding_values; + std::vector padded_shape; + for (int i = 0; i < num_dims; ++i) { + // 0 padding in the beginning. + padding_values.push_back(0); + if (i == split_dim) { + // pad the split dimension to make the total size of the input equal to + // the total size of the split dimension. + padding_values.push_back(padding); + padded_shape.push_back(input_type.getShape()[i] + padding); + } else { + padding_values.push_back(0); + padded_shape.push_back(input_type.getShape()[i]); + } + } + auto padding_type = + mlir::RankedTensorType::get({num_dims, 2}, builder->getIntegerType(64)); + auto paddings = mlir::DenseIntElementsAttr::get(padding_type, padding_values); + auto paddings_value = builder->create(location, paddings); + mlir::SmallVector expand_shape(padded_shape.begin(), + padded_shape.end()); + + auto expand_result_type = + mlir::RankedTensorType::get(expand_shape, input_type.getElementType()); + + return builder->create(location, expand_result_type, + src_input, paddings_value); +} + // Creates a tf::SplitOp that splits 'src_input' into 'num_splits' ways // in 'split_dimension' dimension and returns the split values. -mlir::LogicalResult CreateSplitOp(const int num_split, - const int split_dimension, - const mlir::Location& location, - mlir::Value src_input, - mlir::OpBuilder* builder, - mlir::TF::SplitOp* split_op, - bool is_ici_weight_dist_spmd) { +mlir::LogicalResult CreateSplitOp( + const int num_split, const int split_dimension, const int64_t padding, + const mlir::Location& location, mlir::Value src_input, + mlir::OpBuilder* builder, mlir::TF::SplitOp* split_op, + bool is_ici_weight_dist_spmd) { + if (padding > 0) { + int64_t num_dims = + mlir::cast(src_input.getType()).getRank(); + auto pad_op = CreatePadOp(builder, location, num_dims, split_dimension, + src_input, padding); + if (is_ici_weight_dist_spmd) { + pad_op->setAttr(kICIWeightDistributionMlirBridgeMarker, + builder->getBoolAttr(true)); + } + src_input = pad_op.getResult(); + } + // Creates a const op to hold split dimension value. auto split_dim_type = mlir::RankedTensorType::get({}, builder->getIntegerType(32)); @@ -139,6 +220,7 @@ mlir::LogicalResult CreateSplitOp(const int num_split, // Creates a tf::ConcatOp that merges `input` values in `concat_dimension`. mlir::TF::ConcatOp CreateConcatOp(const int concat_dimension, const mlir::Location& location, + const int64_t padding, mlir::ArrayRef inputs, mlir::OpBuilder* builder) { // Creates a const op to hold concat dimension value. @@ -265,6 +347,22 @@ mlir::LogicalResult CreateXlaSplitNDOp(const mlir::Location& location, return mlir::success(); } +bool IsShapeKnown(mlir::TensorType type) { + if (!type.hasRank()) return false; + + bool shape_known = false; + for (int i = 0; i < type.getRank(); ++i) { + if (type.getShape()[i] == mlir::ShapedType::kDynamic) { + shape_known = false; + break; + } else { + shape_known = true; + } + } + + return shape_known; +} + mlir::LogicalResult HandleTileShardedInputsUsingXlaSplitOps( const mlir::Location& location, const xla::OpSharding& input_sharding, const mlir::Value& original_source, mlir::OpBuilder* builder, @@ -335,17 +433,27 @@ mlir::LogicalResult HandleTileShardedInputsUsingTfSplitOps( LOG(ERROR) << dimension_to_splits_map.status(); return mlir::failure(); } - + PartialTensorShape shape; + const auto input_type = + mlir::cast(original_source.getType()); + bool input_shape_known = IsShapeKnown(input_type); + if (input_shape_known) { + shape = PartialTensorShape(input_type.getShape()); + } for (const auto& dimension_and_num_splits : *dimension_to_splits_map) { const int dimension = dimension_and_num_splits.first; const int num_splits = dimension_and_num_splits.second; + int padding = input_shape_known + ? GetPadding(dimension, num_splits, + PartialTensorShape(input_type.getShape())) + : 0; // Creates root split op. if (split_ops_for_tiled_input.empty()) { mlir::TF::SplitOp root_split_op; - auto result = - CreateSplitOp(num_splits, dimension, location, original_source, - builder, &root_split_op, is_ici_weight_dist_spmd); + auto result = CreateSplitOp(num_splits, dimension, padding, location, + original_source, builder, &root_split_op, + is_ici_weight_dist_spmd); if (mlir::failed(result)) return mlir::failure(); split_ops_for_tiled_input.emplace_back(root_split_op); @@ -358,7 +466,7 @@ mlir::LogicalResult HandleTileShardedInputsUsingTfSplitOps( for (auto split_op : split_ops_for_tiled_input) { for (auto parent_split_output_value : split_op.getResults()) { mlir::TF::SplitOp child_split_op; - auto result = CreateSplitOp(num_splits, dimension, location, + auto result = CreateSplitOp(num_splits, dimension, padding, location, parent_split_output_value, builder, &child_split_op, is_ici_weight_dist_spmd); if (mlir::failed(result)) return mlir::failure(); @@ -827,7 +935,15 @@ mlir::LogicalResult HandleTileShardedOutputsUsingTfConcatOps( LOG(ERROR) << dimension_to_splits_map.status(); return mlir::failure(); } - + auto output_type = + mlir::cast(cluster_func_output.getType()); + PartialTensorShape shape; + bool output_shape_known = IsShapeKnown(output_type); + if (output_shape_known) { + shape = PartialTensorShape(output_type.getShape()); + } + bool has_paddings = false; + std::vector paddings; for (auto it = dimension_to_splits_map->rbegin(); it != dimension_to_splits_map->rend(); ++it) { int concat_dimension = it->first; @@ -837,12 +953,21 @@ mlir::LogicalResult HandleTileShardedOutputsUsingTfConcatOps( new_outputs.reserve(num_splits); for (int i = 0, end = outputs_to_merge.size(); i < end; i = i + num_splits) { + int64_t padding; + if (output_shape_known) { + padding = GetPadding(concat_dimension, num_splits, shape); + } else { + padding = 0; + } mlir::TF::ConcatOp concat_op = - CreateConcatOp(concat_dimension, location, + CreateConcatOp(concat_dimension, location, padding, llvm::ArrayRef{ outputs_to_merge.begin() + i, outputs_to_merge.begin() + i + num_splits}, builder); + + paddings.push_back(padding); + has_paddings |= padding > 0; new_outputs.emplace_back(concat_op.getResult()); } @@ -850,6 +975,12 @@ mlir::LogicalResult HandleTileShardedOutputsUsingTfConcatOps( } assert(outputs_to_merge.size() == 1); + if (has_paddings) { + // Add slice op to remove paddings. + mlir::TF::SliceOp slice_op = + CreateSliceOp(builder, location, outputs_to_merge[0], shape); + cluster_func_output.replaceAllUsesWith(slice_op.getResult()); + } cluster_func_output.replaceAllUsesWith(outputs_to_merge[0]); return mlir::success(); } @@ -876,26 +1007,13 @@ mlir::LogicalResult ValidateAndGetTiledExecuteOutputShape( *tiled_logical_computation_type = cluster_func_output_type; break; } - if (use_xla_nd_ops) { - if (output_shape[dimension] % output_splits == 0) { - new_output_shape[dimension] = output_shape[dimension] / output_splits; - } else { - // Input will be padded to be divisible by output_splits, thus add 1 to - // the output shape. - new_output_shape[dimension] = - (output_shape[dimension] / output_splits) + 1; - } - } else { - if (output_shape[dimension] % output_splits != 0) { - mlir::emitError( - location, - llvm::formatv("incorrect output sharding received. " - "{0}-th dimension of the output must be " - "evenly divisible by {1}, got dimension " - "shape {2}", - dimension, output_splits, output_shape[dimension])); - } + if (output_shape[dimension] % output_splits == 0) { new_output_shape[dimension] = output_shape[dimension] / output_splits; + } else { + // Input will be padded to be divisible by output_splits, thus add 1 to + // the output shape. + new_output_shape[dimension] = + (output_shape[dimension] / output_splits) + 1; } } @@ -904,23 +1022,6 @@ mlir::LogicalResult ValidateAndGetTiledExecuteOutputShape( return mlir::success(); } - -bool IsShapeKnown(mlir::TensorType type) { - if (!type.hasRank()) return false; - - bool shape_known = false; - for (int i = 0; i < type.getRank(); ++i) { - if (type.getShape()[i] == mlir::ShapedType::kDynamic) { - shape_known = false; - break; - } else { - shape_known = true; - } - } - - return shape_known; -} - } // namespace bool AreInputOutputShapesStaticallyKnownForSplitSharding( diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h index ebdfc56dce6f68..699388de8457f9 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h @@ -27,9 +27,11 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "xla/xla_data.pb.h" @@ -43,7 +45,7 @@ inline constexpr llvm::StringRef kOutputShardingAttr = "output_sharding_configuration"; inline constexpr llvm::StringRef kICIWeightDistributionMlirBridgeMarker = - "ici_weight_distribution_mlir_bridge_marker"; + "_ici_weight_distribution_mlir_bridge_marker"; // Parses the sharding string. This sharding string can be binary (serialized) // or human readable. diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util_test.cc index 84d5697c9a6c2b..7e4791620f34fe 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" +#include "xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/tpu/topology.pb.h" #include "tsl/platform/statusor.h" @@ -57,9 +58,9 @@ TEST(XLAShardingUtilTest, TestShapesCheckForSplitSharding) { func.func @parallel_execute_with_tiled_input(%arg0: tensor<128x9xf32>, %arg1: tensor<128x9xf32>, %arg2: tensor<128x10xi32>, %arg3: tensor<128x10xi32>) -> (tensor<128x10xi32>, tensor<10x5xi1>) { %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x9xf32>, [%arg2, %arg3] as %ri_2: tensor<128x10xi32>) {n = 2 : i32} { %1 = "tf_device.launch"() <{device = "TPU_REPLICATED_HOST_0"}> ({ - %identity = "tf.Identity"(%ri_1) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor<128x9xf32>) -> tensor<128x9xf32> + %identity = "tf.Identity"(%ri_1) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<128x9xf32>) -> tensor<128x9xf32> tf_device.return %identity : tensor<128x9xf32> - }) {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<128x9xf32> + }) {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<128x9xf32> %2, %3 = "tf_device.cluster_func"(%1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x9xf32>, tensor<128x10xi32>) -> (tensor<128x10xi32>, tensor<10x5xi1>) tf_device.return %2, %3 : tensor<128x10xi32>, tensor<10x5xi1> } @@ -99,9 +100,9 @@ TEST(XLAShardingUtilTest, TestShapesCheckForSplitShardingWithUnknownShapes) { func.func @parallel_execute_with_tiled_input(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<128x10xi32>, %arg3: tensor<128x10xi32>) -> (tensor<128x10xi32>, tensor<10x5xi1>) { %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<*xf32>, [%arg2, %arg3] as %ri_2: tensor<128x10xi32>) {n = 2 : i32} { %1 = "tf_device.launch"() <{device = "TPU_REPLICATED_HOST_0"}> ({ - %identity = "tf.Identity"(%ri_1) {ici_weight_distribution_mlir_bridge_marker = true} : (tensor<*xf32>) -> tensor<*xf32> + %identity = "tf.Identity"(%ri_1) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<*xf32>) -> tensor<*xf32> tf_device.return %identity : tensor<*xf32> - }) {ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<*xf32> + }) {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<*xf32> %2, %3 = "tf_device.cluster_func"(%1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\1A\02\01\02\22\02\00\01", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<*xf32>, tensor<128x10xi32>) -> (tensor<128x10xi32>, tensor<10x5xi1>) tf_device.return %2, %3 : tensor<128x10xi32>, tensor<10x5xi1> } @@ -139,7 +140,6 @@ TEST(XLAShardingUtilTest, NotDivisibleShardingSplitOpTest) { module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:localhost/replica:0/task:0/device:CPU:0", "/job:localhost/replica:0/task:0/device:TPU:0", "/job:localhost/replica:0/task:0/device:TPU:1", "/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0", "/job:localhost/replica:0/task:1/device:CPU:0", "/job:localhost/replica:0/task:1/device:TPU:0", "/job:localhost/replica:0/task:1/device:TPU:1", "/job:localhost/replica:0/task:1/device:TPU_SYSTEM:0"]} { func.func @uneven_input_sharding_disallowed(%arg0: tensor<128x10xf32>, %arg1: tensor<128x10xf32>, %arg2: tensor<*xi32>, %arg3: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) { %0:2, %1:2 = tf_device.replicate([%arg0, %arg1] as %ri_1: tensor<128x10xf32>, [%arg2, %arg3] as %ri_2: tensor<*xi32>) {n = 2 : i32} { - // expected-error@+1 {{incorrect input sharding configuration received. 1-th dimension of the input must be evenly divisible by 4}} %1, %2 = "tf_device.cluster_func"(%ri_1, %ri_2) {_xla_compile_device_type = "TPU", _replication_info = "cluster0", func = @tpu0_func, num_cores_per_replica = 2, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\01\02\01\02\10\02\18\02\22\10\00\00\00\00\00\00\00\01\00\01\00\00\00\01\00\01", device_assignment = [0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0], input_sharding_configuration = ["\08\03\12\12\10\0b\1a\02\01\04\2a\06\0a\02\01\00\20\01\32\02\00\00\1a\02\01\04\22\04\00\01\02\03", "\08\01\1A\01\01\22\01\01"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00", ""], use_spmd_for_xla_partitioning = false} : (tensor<128x10xf32>, tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi1>) tf_device.return %1, %2 : tensor<*xi32>, tensor<*xi1> } @@ -165,6 +165,7 @@ TEST(XLAShardingUtilTest, NotDivisibleShardingSplitOpTest) { int num_cores_per_replica = 4; mlir::OpBuilder builder(&context); bool use_xla_nd_ops = true; + llvm::SmallVector, 4> input_list; auto result = tensorflow::ExtractInputsForLogicalDevices( num_cores_per_replica, cluster_func_op, &builder, use_xla_nd_ops, @@ -194,9 +195,30 @@ TEST(XLAShardingUtilTest, NotDivisibleShardingSplitOpTest) { // will appropriately add the values to the block. op->destroy(); + input_list.clear(); + // Expect error when use_xla_nd_ops is false. result = tensorflow::ExtractInputsForLogicalDevices( num_cores_per_replica, cluster_func_op, &builder, false, &input_list); - ASSERT_TRUE(failed(result)); + ASSERT_TRUE(succeeded(result)); + auto* split_op = input_list.front().front().getDefiningOp(); + ASSERT_TRUE(mlir::isa(split_op)); + + llvm::SmallVector split_inputs(split_op->getOperands()); + // Constant op for the split dimension + auto* const_op = split_inputs[0].getDefiningOp(); + ASSERT_TRUE(mlir::isa(const_op)); + // Pad op for the padding value to make it divisible by num_splits. + auto* pad_op = split_inputs[1].getDefiningOp(); + ASSERT_TRUE(mlir::isa(pad_op)); + llvm::SmallVector pad_inputs(pad_op->getOperands()); + auto* const_pad_value = pad_inputs[1].getDefiningOp(); + ASSERT_TRUE(mlir::isa(const_pad_value)); + // Destroy the ops to avoid error during block deletion (Same as above): + // use_empty() && "Cannot destroy a value that still has uses!" + split_op->destroy(); + const_op->destroy(); + pad_op->destroy(); + const_pad_value->destroy(); } } // namespace diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD index daf7311870fe10..067ddd4c08bb5e 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD @@ -32,6 +32,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow/transforms:shape_inference_pass", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tf2xla:mlir_bridge_rollout_policy", + "//tensorflow/compiler/mlir/tf2xla/api/v2:graph_to_tf_executor", "//tensorflow/compiler/mlir/tf2xla/internal:mlir_pass_instrumentation", "//tensorflow/compiler/mlir/tf2xla/internal/passes:lowering_passes", "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf_with_tf2xla", @@ -61,14 +62,14 @@ cc_library( "@local_tsl//tsl/platform:statusor", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_computation", + "@local_xla//xla/hlo/builder:xla_computation", "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:layout_util", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape", "@local_xla//xla/mlir_hlo", "@local_xla//xla/mlir_hlo:hlo_dialect_registration", "@local_xla//xla/mlir_hlo:mhlo_passes", - "@local_xla//xla/translate/mhlo_to_hlo:layout_util", - "@local_xla//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", - "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape", "@stablehlo//:register", ], ) @@ -100,8 +101,8 @@ tf_cc_test( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", "@local_xla//xla:shape_util", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client:xla_computation", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder:xla_computation", "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -114,7 +115,6 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:error_util", - "//tensorflow/compiler/mlir/tensorflow:import_model", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", "//tensorflow/compiler/mlir/tensorflow:translate_utils", @@ -142,8 +142,6 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", "@local_xla//xla:shape_util", "@local_xla//xla:status_macros", "@local_xla//xla/client:compile_only_client", @@ -183,8 +181,8 @@ tf_cc_test( "@local_tsl//tsl/platform:statusor", "@local_xla//xla:shape_util", "@local_xla//xla/client:client_library", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape", "@local_xla//xla/stream_executor:platform_manager", - "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape", "@local_xla//xla/tsl/lib/core:status_test_util", "@local_xla//xla/tsl/lib/monitoring:test_utils", ], @@ -209,6 +207,7 @@ cc_library( "//tensorflow/compiler/mlir/tf2xla/internal:clustering_bridge_passes", "//tensorflow/compiler/mlir/tf2xla/internal:logging_hooks", "//tensorflow/compiler/mlir/tf2xla/internal/inference:inference_metrics_pass", + "//tensorflow/compiler/mlir/tf2xla/internal/passes:clustering_passes", "//tensorflow/core:framework", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core/platform:errors", diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc index bde8c6b3b2f1b8..081f3b012be649 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc @@ -39,6 +39,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.h" #include "tensorflow/compiler/mlir/tf2xla/internal/inference/inference_passes.h" #include "tensorflow/compiler/mlir/tf2xla/internal/logging_hooks.h" +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h" #include "xla/tsl/framework/device_type.h" #include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/platform/errors.h" @@ -62,6 +63,8 @@ using mlir::func::FuncOp; namespace { void CreateReplicatedBridgePipelineV1(OpPassManager &pm) { + pm.addPass( + tensorflow::tf2xla::internal::CreateTPUValidateSessionInputsPass()); pm.addPass(mlir::tf2xla::internal::CreateInferenceMetricsPass()); // Convert to unified compilation and replication attributes. diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc index 8ed7d1ea727867..aa8bfe9e098aa6 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc @@ -63,6 +63,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" +#include "tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.h" #include "tensorflow/compiler/mlir/tf2xla/internal/mlir_pass_instrumentation.h" #include "tensorflow/compiler/mlir/tf2xla/internal/passes/lowering_passes.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" @@ -70,15 +71,15 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/translate/mhlo_to_hlo/layout_util.h" +#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/shape.h" -#include "xla/translate/mhlo_to_hlo/layout_util.h" -#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/error_payloads.h" @@ -1070,7 +1071,8 @@ absl::StatusOr> GraphToModule( config.unconditionally_use_set_output_shapes = unconditionally_use_set_output_shapes; GraphDebugInfo debug_info; - return ConvertGraphToMlir(graph, debug_info, flib_def, config, context); + return tensorflow::tf2xla::v2::ConvertGraphToTfExecutor( + graph, debug_info, flib_def, config, context); } Status BuildHloFromGraph( diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h index 4b8df7c35fe611..05ca6a5c26a35a 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/layout_util.h" #include "tensorflow/compiler/tf2xla/xla_argument.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc index 57769d2363bc18..c28eea50b5a121 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc @@ -36,8 +36,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/internal/test_matchers.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/shape_util.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/fake_input.h" diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc index 68fe52c4e8fd27..2563e04b957ea7 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc @@ -23,9 +23,11 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/LogicalResult.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -45,31 +47,28 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/internal/logging_hooks.h" #include "tensorflow/compiler/tf2xla/layout_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/client/compile_only_client.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/tsl/framework/device_type.h" #include "xla/tsl/lib/monitoring/sampler.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/lib/monitoring/counter.h" -#include "tensorflow/core/lib/monitoring/sampler.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/profile_utils/cpu_utils.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" #include "tensorflow/core/tpu/tpu_compile.h" #include "tensorflow/core/util/debug_data_dumper.h" #include "tsl/platform/errors.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" namespace tensorflow { namespace tf2xla { @@ -211,13 +210,14 @@ absl::Status CompileTFFunctionWithoutMlir( const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_funcs, const std::vector& arg_shapes, + const DeviceType& device_type, std::vector* arg_core_mapping, std::vector>* per_core_arg_shapes, xla::CompileOnlyClient* client, XlaCompiler::CompilationResult* compilation_result) { Status comp_status = CompileTFFunctionToHlo( *function_computation.flib_def, function_computation.graph_def_version, - shape_determination_funcs, arg_shapes, + shape_determination_funcs, arg_shapes, device_type, function_computation.guaranteed_constants, *function_computation.function, metadata, client, arg_core_mapping, per_core_arg_shapes, use_tuple_args, compilation_result); @@ -238,6 +238,7 @@ absl::Status CompileMLIRTFFunction( const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_funcs, const std::vector& arg_shapes, + const DeviceType& device_type, std::vector* arg_core_mapping, std::vector>* per_core_arg_shapes, xla::CompileOnlyClient* client, @@ -286,8 +287,8 @@ absl::Status CompileMLIRTFFunction( TF_RETURN_IF_ERROR(CompileTFFunctionToHlo( *flib_def, versions.producer(), shape_determination_funcs, arg_shapes, - consts, func, metadata, client, arg_core_mapping, per_core_arg_shapes, - use_tuple_args, compilation_result)); + device_type, consts, func, metadata, client, arg_core_mapping, + per_core_arg_shapes, use_tuple_args, compilation_result)); return PopulateInputOutputAliasing(main_fn, compilation_result, use_tuple_args); @@ -301,6 +302,7 @@ absl::Status CompileTensorflowGraphToHlo( const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_funcs, const std::vector& arg_shapes, + tsl::DeviceType device_type, std::vector* arg_core_mapping, std::vector>* per_core_arg_shapes, xla::CompileOnlyClient* client, @@ -319,14 +321,14 @@ absl::Status CompileTensorflowGraphToHlo( if (has_mlir) { TF_RETURN_IF_ERROR(CompileMLIRTFFunction( std::get<0>(computation), metadata, use_tuple_args, - shape_determination_funcs, arg_shapes, arg_core_mapping, + shape_determination_funcs, arg_shapes, device_type, arg_core_mapping, per_core_arg_shapes, client, compilation_result)); } else { FunctionToHloArgs function_computation = std::get<1>(computation); TF_RETURN_IF_ERROR(CompileTFFunctionWithoutMlir( function_computation, metadata, use_tuple_args, - shape_determination_funcs, arg_shapes, arg_core_mapping, + shape_determination_funcs, arg_shapes, device_type, arg_core_mapping, per_core_arg_shapes, client, compilation_result)); } diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.h b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.h index c3f2a6d2d0d868..7007d70b4bf6eb 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.h +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.h @@ -21,9 +21,13 @@ limitations under the License. #include "absl/status/status.h" #include "absl/types/variant.h" +#include "tensorflow/compiler/tf2xla/layout_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "xla/client/compile_only_client.h" #include "xla/pjrt/compile_options.pb.h" +#include "xla/shape.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/tpu/kernels/tpu_compile.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" @@ -39,6 +43,7 @@ absl::Status CompileTensorflowGraphToHlo( const tpu::TPUCompileMetadataProto& metadata, bool use_tuple_args, XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_funcs, const std::vector& arg_shapes, + tsl::DeviceType device_type, std::vector* arg_core_mapping, std::vector>* per_core_arg_shapes, xla::CompileOnlyClient* client, diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc index c1af969caef8cc..567a862b7407f3 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc @@ -26,9 +26,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/layout_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/client/client_library.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/shape.h" #include "xla/stream_executor/platform_manager.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" +#include "xla/tsl/framework/device_type.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/lib/monitoring/test_utils.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -102,8 +103,9 @@ class CompileTFGraphTest : public ::testing::Test { absl::Status compilation_status = tensorflow::tf2xla::v1::CompileTensorflowGraphToHlo( computation, metadata_proto, use_tuple_args, - shape_determination_fns, arg_shapes, &arg_core_mapping, - &per_core_arg_shapes, client, &compilation_result); + shape_determination_fns, arg_shapes, tsl::DeviceType("XLA_TPU_JIT"), + &arg_core_mapping, &per_core_arg_shapes, client, + &compilation_result); if (!compilation_status.ok()) return compilation_status; diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD index 31188f4456f711..70ca9617ab8391 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD @@ -18,54 +18,36 @@ cc_library( "//learning/brain/google/xla:__pkg__", "//learning/brain/mlir/bridge:__pkg__", "//tensorflow/compiler/mlir/quantization/stablehlo:__pkg__", + "//tensorflow/compiler/mlir/tf2xla/api/v2/testing:__pkg__", "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:__pkg__", ], deps = [ ":device_type_proto_cc", - "//tensorflow/compiler/jit:flags_headers", - "//tensorflow/compiler/jit:shape_inference", - "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", - "//tensorflow/compiler/mlir/tensorflow:error_util", - "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", - "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", - "//tensorflow/compiler/mlir/tensorflow:translate_utils", - "//tensorflow/compiler/mlir/tensorflow/transforms:set_tpu_infeed_layout", - "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", - "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util", "//tensorflow/compiler/mlir/tf2xla/api/v1:compile_tf_graph", "//tensorflow/compiler/mlir/tf2xla/internal:compilation_timer", - "//tensorflow/compiler/mlir/tf2xla/internal:legalize_tf_mlir", "//tensorflow/compiler/mlir/tf2xla/internal:legalize_tf_to_hlo", "//tensorflow/compiler/mlir/tf2xla/internal:reproducer_proto_cc", "//tensorflow/compiler/tf2xla:layout_util", - "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_helpers", "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core/platform:statusor", - "//tensorflow/core/tpu:tpu_compile", "//tensorflow/core/tpu/kernels:tpu_compile_op_support", "//tensorflow/core/tpu/kernels:tpu_compile_proto_cc", - "//tensorflow/core/tpu/kernels:tpu_util_hdrs", "@com_google_absl//absl/log", "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", "@com_google_absl//absl/types:variant", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", - "@local_tsl//tsl/platform:error_logging", "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", + "@local_xla//xla:shape_util", "@local_xla//xla:xla_proto_cc", "@local_xla//xla/client:compile_only_client", "@local_xla//xla/hlo/ir:hlo", - "@local_xla//xla/mlir_hlo:hlo_dialect_registration", "@local_xla//xla/pjrt:compile_options_proto_cc", - "@stablehlo//:register", ], ) @@ -73,12 +55,10 @@ tf_cc_test( name = "legalize_tf_test", srcs = ["legalize_tf_test.cc"], deps = [ - ":device_type_proto_cc", ":legalize_tf", "//tensorflow/compiler/jit", - "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tf2xla/api/v2/testing:compile_mlir", "//tensorflow/compiler/mlir/tf2xla/internal:test_matchers", - "//tensorflow/compiler/mlir/tf2xla/internal/utils:test_metadata_config", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_helpers", "//tensorflow/core:framework", @@ -89,15 +69,46 @@ tf_cc_test( "//tensorflow/core/protobuf:for_core_protos_cc", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "//tensorflow/core/tpu/kernels:tpu_compile_op_support", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_googletest//:gtest", + "@llvm-project//mlir:Pass", "@local_tsl//tsl/platform:statusor", + "@local_xla//xla:shape_util", "@local_xla//xla/client:client_library", + "@local_xla//xla/stream_executor:platform", "@local_xla//xla/stream_executor:platform_manager", "@local_xla//xla/tsl/lib/monitoring:test_utils", ], ) +tf_cc_test( + name = "legalize_tf_test_gpu", + srcs = ["legalize_tf_test_gpu.cc"], + tags = [ + "config-cuda-only", + "no_oss", # This test only runs with GPU. + "requires-gpu-nvidia", + ], + deps = [ + "//tensorflow/compiler/jit", + "//tensorflow/compiler/mlir/tf2xla/api/v2/testing:compile_mlir", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/protobuf:for_core_protos_cc", + "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", + "//tensorflow/core/tpu/kernels:tpu_compile_op_support", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + ], +) + tf_proto_library( name = "device_type_proto", srcs = ["device_type.proto"], @@ -112,7 +123,6 @@ cc_library( hdrs = ["cluster_tf.h"], visibility = [ "//learning/serving/contrib/tfrt/mlir/saved_model_analysis:__pkg__", - "//tensorflow/compiler/mlir/tensorflow/transforms:__pkg__", "//tensorflow/compiler/mlir/tfrt:__pkg__", "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:__pkg__", "//tensorflow/compiler/tf2xla:__pkg__", @@ -126,6 +136,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass", "//tensorflow/compiler/mlir/tf2xla/internal:clustering_bridge_passes", "//tensorflow/compiler/mlir/tf2xla/internal:logging_hooks", + "//tensorflow/compiler/mlir/tf2xla/internal/passes:clustering_passes", "//tensorflow/core:framework", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core/platform:error_payloads", @@ -178,7 +189,6 @@ cc_library( hdrs = ["tf_dialect_to_executor.h"], visibility = [ "//learning/serving/contrib/tfrt/mlir/saved_model_analysis:__pkg__", - "//tensorflow/compiler/mlir/tensorflow/transforms:__pkg__", "//tensorflow/compiler/mlir/tfrt:__pkg__", "//tensorflow/compiler/tf2xla:__pkg__", ], @@ -272,3 +282,115 @@ cc_library( "@local_xla//xla:status_macros", ], ) + +tf_cc_test( + name = "tf_executor_to_graph_test", + srcs = ["tf_executor_to_graph_test.cc"], + data = [ + "testdata/valid_executor.mlir", + "testdata/valid_graph.txt", + ], + deps = [ + ":tf_executor_to_graph", + "//tensorflow/compiler/jit", + "//tensorflow/compiler/mlir:register_common_dialects", + "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test_main", + "//tensorflow/core/platform:resource_loader", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@local_tsl//tsl/platform:protobuf", + "@local_xla//xla/tsl/lib/core:status_test_util", + "@riegeli//riegeli/bytes:fd_reader", + "@riegeli//riegeli/bytes:read_all", + ], +) + +cc_library( + name = "graph_to_tf_executor", + srcs = [ + "graph_to_tf_executor.cc", + ], + hdrs = [ + "graph_to_tf_executor.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/jit:shape_inference_helpers", + "//tensorflow/compiler/mlir:op_or_arg_name_mapper", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:convert_attr", + "//tensorflow/compiler/mlir/tensorflow:convert_tensor", + "//tensorflow/compiler/mlir/tensorflow:convert_type", + "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", + "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:mangling_util", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/compiler/mlir/tensorflow:translate_utils", + "//tensorflow/compiler/mlir/tensorflow/translate:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tensorflow/translate:upgrade_graph", + "//tensorflow/compiler/mlir/tf2xla/internal:node_order", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime:function_body", + "//tensorflow/core/platform:crash_analysis", + "//tensorflow/core/platform:types", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DerivedAttributeOpInterface", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:status", + "@local_xla//xla:status_macros", + ], +) + +tf_cc_test( + name = "graph_to_tf_executor_test", + srcs = ["graph_to_tf_executor_test.cc"], + data = [ + "testdata/graph_with_flib_def.txt", + "testdata/valid_graph.txt", + ], + deps = [ + ":graph_to_tf_executor", + "//tensorflow/compiler/jit", + "//tensorflow/compiler/mlir:register_common_dialects", + "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test_main", + "//tensorflow/core/platform:resource_loader", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:protobuf", + "@local_xla//xla/tsl/lib/core:status_test_util", + "@riegeli//riegeli/bytes:fd_reader", + "@riegeli//riegeli/bytes:read_all", + ], +) diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc index 41df5eb0750459..7c70fccac1070e 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/api/v2/device_type.pb.h" #include "tensorflow/compiler/mlir/tf2xla/internal/clustering_bridge_passes.h" #include "tensorflow/compiler/mlir/tf2xla/internal/logging_hooks.h" +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h" #include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/platform/error_payloads.h" #include "tensorflow/core/platform/errors.h" @@ -148,7 +149,7 @@ void CreateReplicatedClusteringPipeline(OpPassManager &pm, // TF2-only passes should go here. However, this should be very rare and // new passes generally should go into the internal // AddReplicatedBridgeClusteringPipelinePasses. - pm.addPass(mlir::TFTPU::CreateTPUValidateInputsPass()); + pm.addPass(tensorflow::tf2xla::internal::CreateTPUValidateInputsPass()); pm.addNestedPass( mlir::TF::CreateCanonicalizeCompileAndReplicateAttributesPass()); tensorflow::tf2xla::internal::AddReplicatedBridgeClusteringPipelinePasses( diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.cc new file mode 100644 index 00000000000000..678bd5f4083bbb --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.cc @@ -0,0 +1,2719 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Verifier.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/jit/shape_inference_helpers.h" +#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/upgrade_graph.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_attr.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" +#include "tensorflow/compiler/mlir/tf2xla/internal/node_order.h" +#include "xla/status_macros.h" +#include "tensorflow/core/common_runtime/function_body.h" +#include "tensorflow/core/common_runtime/function_def_utils.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" +#include "tensorflow/core/common_runtime/graph_runner.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_debug_info_builder.h" +#include "tensorflow/core/graph/graph_node_util.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/crash_analysis.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stack_frame.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" +#include "tensorflow/core/protobuf/saver.pb.h" +#include "tensorflow/core/protobuf/struct.pb.h" +#include "tensorflow/core/protobuf/trackable_object_graph.pb.h" +#include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/dump_graph.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" + +static inline absl::string_view StringRefToView(llvm::StringRef ref) { + return {ref.data(), ref.size()}; +} + +namespace tensorflow { +namespace tf2xla { +namespace v2 { + +using ::mlir::NamedAttrList; +using ::mlir::TensorType; +using ::tsl::StatusOr; + +constexpr absl::string_view kOutputShapesAttrName = "_output_shapes"; + +void LoadImporterDialects(mlir::MLIRContext& context) { + // Load dialects involved in the conversion + mlir::DialectRegistry registry; + mlir::RegisterAllTensorFlowDialectsImpl(registry, false); + context.appendDialectRegistry(registry); + for (llvm::StringRef name : registry.getDialectNames()) + context.getOrLoadDialect(name); +} + +bool IsOutputShapesAttribute(const AttrValue& attr_value, + llvm::StringRef attr_name) { + return attr_name.compare(kOutputShapesAttrName) == 0 && + attr_value.value_case() == AttrValue::kList; +} + +bool IsResourceOutputShapesAttribute(const AttrValue& attr_value, + llvm::StringRef attr_name) { + if (attr_name == "_handle_dtypes" || attr_name == "_handle_shapes") + return attr_value.value_case() == AttrValue::kList; + return false; +} + +class NameUniquifier : public OpOrArgNameMapper { + public: + explicit NameUniquifier(const FunctionLibraryDefinition& flib) + : flib_(flib) {} + + private: + bool IsUnique(llvm::StringRef name) override { + return !flib_.Contains(std::string(name)); + } + + std::string GetName(OpOrVal op_or_val) override { + DCHECK(false) << "Unimplemented"; + return ""; + } + + const FunctionLibraryDefinition& flib_; +}; + +// Stateful helper class to import a TensorFlow model into an MLIR Module. +// +// This is the base class that contains common utilities shared between the +// GraphDef importer and SavedModel importer. +// +// A subclass is expected to call `PrepareConvert` first to perform necessary +// preparation over the graph and also certain internal bookkeeping data. +// Afterwards the other protected methods can be called. +class ImporterBase { + protected: + explicit ImporterBase( + const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info, + const GraphImportConfig& specs, mlir::ModuleOp module, + std::unordered_map* tf_name_to_mlir_name, + NameUniquifier* function_name_uniquifier, + llvm::StringRef function_name_for_debug_info = "") + : builder_(module.getContext()), + module_(module), + context_(module.getContext()), + tf_name_to_mlir_name_(tf_name_to_mlir_name), + graph_flib_(flib), + specs_(specs), + debug_info_(debug_info), + function_name_for_debug_info_(function_name_for_debug_info), + function_name_uniquifier_(function_name_uniquifier), + error_handler_(module.getContext()) { + // Log import config. + if (VLOG_IS_ON(1)) { + LOG(INFO) << "Importing with: " << specs.str(); + for (auto& it : *tf_name_to_mlir_name) { + LOG(INFO) << "\t" << it.first << " -> " << it.second; + } + } + + stack_traces_ = LoadTracesFromDebugInfo(debug_info_); + } + + // Returns the inferred function signature of the given function body. Input + // types are unranked tensor of the respective datatype in the function and + // result types are inferred by the shape_refiner_. Result types need not be + // unranked tensors and could be ranked tensors in cases where result type + // depends on an op with static output shape like tf.Const. + absl::StatusOr InferLibFunctionType( + const FunctionBody& fbody); + + // Extracts arg and ret nodes from FunctionBody. + void GetArgsAndRetsFromFunctionBody( + const FunctionBody& fbody, + absl::InlinedVector* arg_nodes, + absl::InlinedVector* ret_nodes, + absl::InlinedVector* control_ret_nodes); + + // Prepares converting the graph to an MLIR module. This step removes the + // backedges of the graph, orders the nodes and infers the shapes. + // PrepareConvert needs to ensure that the original `graph` is cloned prior + // execution. The cloning procedure relies on the roundtrip through the + // GraphDef. Graph to GraphDef def conversion is heavy, in case, `graph_def` + // was obtained previously provide it to the PrepareConvert to reuse. + absl::Status PrepareConvert(const Graph& graph, + std::unique_ptr graph_def = nullptr); + + // Converts the prepared graph to a Function and adds it to the module. A set + // of nodes from the graph are given to converted to the arguments and returns + // of the function. + absl::Status Convert(llvm::StringRef func_name, mlir::FunctionType func_type, + const absl::InlinedVector& arg_nodes, + const absl::InlinedVector& ret_nodes, + const absl::InlinedVector& control_ret_nodes, + llvm::ArrayRef attrs); + + // Finds out the function definition for the given function name from the + // graph and converts it to a function of the module. This method is called + // on demand because the graph flib_def does not provide an iterator + // interface. + absl::Status ConvertLibFunction(llvm::StringRef func_name); + + // Returns the list of nodes in the graph. Nodes are presented in the reverse + // order of a post-order depth-first visit starting from the graph's source + // nodes. + llvm::ArrayRef GetOrderedNodes() const { return ordered_nodes_; } + + // Returns the inferred input type at index `idx` of the `node` in the + // context. + absl::StatusOr InferInputType(const Node& node, int idx, + mlir::Builder builder); + + // Returns the inferred output type at index `idx` of the `node` in the + // context. + absl::StatusOr InferOutputType(const Node& node, int idx, + mlir::Builder builder); + + // Convert deferred TF functions to the MLIR representation. + // Conversion is deferred for efficiency reasons, e.g., to limit depth + // of recursion and reduce stack size pressure. + absl::Status ConvertDeferredFunctions(); + + private: + // Most types with subtypes have only one subtype. + using ElementSubtypes = llvm::SmallVector; + + // Metadata used for deferred function conversion. + struct DeferredConversionMetaData { + DeferredConversionMetaData( + const std::string& function_name, + const std::vector& attributes) + : function_name(function_name), attributes(attributes) {} + + std::string function_name; + std::vector attributes; + }; + + // Adds all the ordered_nodes to the shape refiner shape_refiner_. Then all + // data type and shape information is maintained by the shape_refiner_. + // TODO(jpienaar): Remove once shape inference on import is removed. + absl::Status AddNodesToShapeRefiner( + std::unordered_map* node_name_map); + + // Prune nodes that do not feed into fetch nodes. + absl::Status PruneUnreachableNodes( + std::unordered_map* node_name_map); + + // Converts feeds to Placeholder nodes. + absl::Status ConvertFeedsToPlaceholders( + std::unordered_map* node_name_map); + + // Converts the inferred shape referred to by 'handle' in 'context', with + // given element type, and returns an MLIR tensor type. + absl::StatusOr ConvertDataTypeAndShape( + DataType dtype, const shape_inference::ShapeHandle& handle, + const std::vector* handle_subtypes, + shape_inference::InferenceContext* context, mlir::Builder builder); + + // Converts the inferred shape referred to by 'handle' in 'context', with + // given element type, and returns an MLIR tensor type. + absl::StatusOr ConvertElementTypeAndShape( + mlir::Type element_type, const shape_inference::ShapeHandle& handle, + shape_inference::InferenceContext* context, mlir::Builder builder); + + // Converts the inferred subtypes for an element type to corresponding MLIR + // types in 'context'. + absl::StatusOr ConvertSubtypes( + const std::vector* handle_subtypes, + shape_inference::InferenceContext* context, mlir::Builder builder); + + // Converts the tensor proto into an MLIR elements attribute. + absl::StatusOr ConvertTensorProto( + const TensorProto& value) { + return tensorflow::ConvertTensorProto(value, &builder_); + } + + // Converts func name in graphdef to mlir::SymbolRefAttribute. + absl::StatusOr ConvertFunctionCallName( + const std::string& func_name); + + // Converts the given non-function-call AttrValue to an MLIR Attribute. + absl::StatusOr ConvertAttributeValue(const AttrValue& value); + + // Converts the given function-call AttrValue to MLIR Attributes and pushes + // them to the given attributes list. For example, if there is a kFunc + // AttrValue {name : foo, attrs : {k1 : bar, k2 : rfc}}, it will convert it to + // a list of MLIR Attributes: {{base_name : foo}, {base_name.k1 : bar}, + // {base_name.k2 : rfc}}. + absl::Status ConvertFunctionCallAttribute(const std::string& base_name, + const AttrValue& value, + NamedAttrList* attributes); + + // Helper to create either a tf_executor operation or a TF operation wrapped + // in an island. + mlir::Operation* CreateOperation( + const Node& node, llvm::StringRef node_type_name, + const mlir::OperationState& result, + const llvm::SmallVectorImpl& control_operands); + + // Converts one NodeDef from the input GraphDef into an Operation and + // inserts it into the MLIR module using builder_. + absl::Status ConvertNode(const Node& node); + + // If the input graph represents a while-loop, the edges pointing from a + // "NextIteration" node to a "Merge" node add cyclic dependencies and make the + // topological sorting impossible. We need to remove these edges from the + // input graph to infer shapes and construct a Function. For each + // "NextIteration" node, there are two operations, "NextIteration.source" + // and "NextIteration.sink" are added to the MLIR module. + using BackEdge = BackEdgeHelper::BackEdge; + + // Removes backedges from the input graph. The removed edges are added back to + // to OpBuilder after the remaining graph is converted to the Function. + absl::Status RemoveBackedges(); + + // Restores backedges removed during shape inference to the final Function. + absl::Status AddBackedges(); + + // Restores a single backedge in the Function by adding a replicated + // operation before the dst operation. + absl::Status AddBackedge(mlir::Operation* sink, mlir::Operation* dst, + int dst_input); + + // Adds the input arguments and return operation to the function. The + // arguments are added as basic block argument. Also the argument types and + // the id of the nodes from the input graph needs to be specified. + absl::Status ConvertFunctionArgAndRets( + mlir::func::FuncOp func, mlir::tf_executor::GraphOp graph_op, + llvm::ArrayRef arg_types, + const absl::InlinedVector& arg_nodes, + const absl::InlinedVector& ret_nodes, + const absl::InlinedVector& control_ret_nodes); + + // Gets the location information of the given node. It uses the + // "original_node_name" in the NodeDef to get the corresponding file location + // (FileLineColLoc) from the input DebugInfo and returns an CallSiteLoc. If + // there are multiple "original_node_names", a FusedLoc is returned. If the + // node name couldn't be found in the input DebugInfo, a NameLoc is used as + // the location. + mlir::Location GetLocation(const Node& node); + + // Appends the location string for the node to the error message and returns + // the combined error absl::Status. + absl::Status EmitErrorWithLocationStr(const Node& node, + const absl::Status& error_status); + + // Inserts a placeholder node in the graph to replace a feed output tensor, + // and returns the new placeholder node and a boolean indicating if the + // original input node was removed from the graph. Uses of the feed output + // tensor are replaced with this placeholder node. If the feed output tensor + // is of a single output node, the control dependencies are forwarded to the + // the placeholder node, and the original node will be removed. + // Note: This modifies the graph, and so any list of ordered nodes needs to be + // reconstructed. + absl::StatusOr> CreatePlaceholderNodeForFeed( + const TensorShapeProto& shape, DataType dtype, Node* node, int index, + const std::unordered_map& node_name_map); + + // Gets the input and output nodes corresponding to the specified input and + // output nodes in specs_. If there are no input or output nodes specified, + // nodes will be empty. + absl::Status GetInputOutputNodes( + const std::unordered_map& node_name_map, + std::unordered_set* nodes); + + // The input graph with backedges removed. The removed backedges are stored + // in the back_edge_helper. + BackEdgeHelper back_edge_helper_; + // A map between node and output index, for each backedge. + absl::flat_hash_map back_edge_node_output_; + absl::flat_hash_map back_edge_dst_inputs_; + // A map between sink and source operation of NextIteration + absl::flat_hash_map + next_iteration_sink_source_; + + // All nodes and version information about the (copied) imported graph. + std::unique_ptr graph_; + std::vector ordered_nodes_; + + // Maps from a Node ID to a MLIR value. + using NodeValueMap = absl::flat_hash_map; + + mlir::OpBuilder builder_; + mlir::ModuleOp module_; + mlir::MLIRContext* context_; + std::unordered_map* tf_name_to_mlir_name_; + const FunctionLibraryDefinition& graph_flib_; + const GraphImportConfig& specs_; + const GraphDebugInfo& debug_info_; + StackTracesMap stack_traces_; + llvm::StringRef function_name_for_debug_info_; + NodeValueMap node_values_; + // TODO(jpienaar): Remove once shape inference on import is removed. + // The shape_refinner_ will be nullptr if shape inference on import is + // not enabled. + std::unique_ptr shape_refiner_ = nullptr; + NameUniquifier* function_name_uniquifier_; + mlir::StatusScopedDiagnosticHandler error_handler_; + // All the TF ops encountered that aren't modelled in dialect. + llvm::DenseSet unmodelled_op_names_; + + protected: + // Maps feed as TensorId to new Placeholder node name. + absl::flat_hash_map remapped_feeds_; + // Keep track of functions required deferred conversion. + std::queue deferred_functions_; +}; + +// Mapping from node name to feed (index and ArrayInfo). Node name must outlive +// this map. +using FeedsByNode = absl::flat_hash_map< + absl::string_view, + absl::flat_hash_map*>>; + +// Creates from a `GraphImportConfig::InputArrays` a mapping from a feeds output +// tensor name to index and ArrayInfo. Keys and values are backed by +// `GraphImportConfig::InputArrays`. +absl::StatusOr GetFeedsByNode( + const GraphImportConfig::InputArrays& inputs) { + FeedsByNode feeds_by_node; + feeds_by_node.reserve(inputs.size()); + + for (const auto& input : inputs) { + TensorId tensor = ParseTensorName(input.first); + if (tensor.index() < 0) + return errors::FailedPrecondition( + "Feed output tensor must be a data output '", tensor.ToString(), "'"); + + auto& node = feeds_by_node[tensor.node()]; + if (!node.insert({tensor.index(), &input}).second) + return errors::FailedPrecondition( + "Multiple feeds for the same output tensor '", tensor.ToString(), + "'"); + } + + return feeds_by_node; +} + +// Creates a unique name for a node that will be replacing a feed output tensor. +std::string GetUniqueNodeName( + absl::string_view node_name, int index, + const std::unordered_map& node_name_map) { + std::string new_node_name_base = absl::StrCat(node_name, "_", index); + int count = 0; + std::string new_node_name = new_node_name_base; + while (node_name_map.find(new_node_name) != node_name_map.end()) { + new_node_name = absl::StrCat(new_node_name_base, "_", count++); + } + return new_node_name; +} + +absl::Status ImporterBase::ConvertDeferredFunctions() { + while (!deferred_functions_.empty()) { + auto conversion_metadata = deferred_functions_.front(); + deferred_functions_.pop(); + + const FunctionDef* func_def = + graph_flib_.Find(conversion_metadata.function_name); + // Converts the graph to an MLIR function and adds it to the module. + // We populate the NodeSpec so that all the _Arg ops get their shape + // added correctly. + GraphImportConfig specs; + specs.enable_shape_inference = specs_.enable_shape_inference; + specs.unconditionally_use_set_output_shapes = + specs_.unconditionally_use_set_output_shapes; + for (const auto& name_and_value : func_def->attr()) { + if (name_and_value.first == "_input_shapes") { + auto& list = name_and_value.second.list(); + auto& signature = func_def->signature(); + // Some models have "_input_shapes" attribute, but with its value empty + if (list.shape_size() > 0 && + list.shape_size() != signature.input_arg_size()) { + return errors::FailedPrecondition( + "Number of input arguments must be equal to the length of " + "_input_shapes attribute in function '", + StringRefToView(conversion_metadata.function_name), "'."); + } + for (int i = 0, e = signature.input_arg_size(); i < e; i++) { + auto& input_arg = signature.input_arg(i); + auto& array_info = specs.inputs[input_arg.name()]; + array_info.imported_dtype = input_arg.type(); + // set to unranked for empty "_input_shapes" attribute + if (list.shape_size() > 0) + array_info.shape = list.shape(i); + else + array_info.shape.set_unknown_rank(true); + } + } + } + + ImporterBase importer(graph_flib_, debug_info_, specs, module_, + tf_name_to_mlir_name_, function_name_uniquifier_, + conversion_metadata.function_name); + + std::unique_ptr fbody; + TF_RETURN_IF_ERROR( + FunctionDefToBodyHelper(*func_def, AttrSlice(), &graph_flib_, &fbody)); + TF_RETURN_IF_ERROR(importer.PrepareConvert(*fbody->graph)); + + TF_ASSIGN_OR_RETURN(auto func_type, importer.InferLibFunctionType(*fbody)); + + absl::InlinedVector arg_nodes; + absl::InlinedVector ret_nodes; + absl::InlinedVector control_ret_nodes; + importer.GetArgsAndRetsFromFunctionBody(*fbody, &arg_nodes, &ret_nodes, + &control_ret_nodes); + const std::string& mlir_func_name = + (*tf_name_to_mlir_name_)[conversion_metadata.function_name]; + + TF_RETURN_IF_ERROR(importer.Convert(mlir_func_name, func_type, arg_nodes, + ret_nodes, control_ret_nodes, + conversion_metadata.attributes)); + + // Additional function bodies could be discovered during the deferred + // loading of the current function. Add them to the working queue. + while (!importer.deferred_functions_.empty()) { + deferred_functions_.push(importer.deferred_functions_.front()); + importer.deferred_functions_.pop(); + } + } + + return absl::OkStatus(); +} + +absl::Status ImporterBase::RemoveBackedges() { + // Remove all the backedges. So the nodes can be added to the shape refiner. + TF_RETURN_IF_ERROR(back_edge_helper_.Remove(graph_.get())); + VLOG(1) << "Found " << (back_edge_helper_.RemovedEdges().size()) + << " backedges."; + + // Creates a map for quickly identifying whether a node output is a backedge. + for (const auto& edge : back_edge_helper_.RemovedEdges()) { + if (back_edge_node_output_.find(edge.src) != back_edge_node_output_.end() && + back_edge_node_output_[edge.src] != edge.src_output) { + return errors::FailedPrecondition( + "More than one of the src node outputs are backedges!"); + } + back_edge_node_output_[edge.src] = edge.src_output; + // We expect a merge to receive a single backedge (multiple NextIteration + // nodes feeding into the same merge is unexpected here). + DCHECK(!back_edge_dst_inputs_.contains(edge.dst)); + back_edge_dst_inputs_[edge.dst] = edge; + } + + // Obtains a RPO ordering, using node names as a tiebreak for stable sorting. + + ordered_nodes_.clear(); + TopologicalOrdering( + *graph_, [&](Node* n) { ordered_nodes_.push_back(n); }, GroupByDevice()); + return absl::OkStatus(); +} + +absl::Status CopyStackTraces(const Graph& from, Graph* to) { + // Copy over the stack traces. + // TODO(jpienaar): This really shouldn't be needed, copying the Graph above + // and then needing these traversals is unfortunate. + std::unordered_map node_map = from.BuildNodeNameIndex(); + for (Node* node : to->nodes()) { + if (const Node* old_node = node_map[node->name()]) { + if (const std::shared_ptr& stack = + old_node->GetStackTrace()) { + DVLOG(2) << "Stack for " << node->name() << " " + << old_node->GetStackTrace()->ToString( + AbstractStackTrace::TracePrintingOptions()); + node->SetStackTrace(stack); + } else { + DVLOG(1) << "No stack for " << node->name() << " (" << node + << ") in Graph " << &from; + } + } else { + DVLOG(1) << "No stack for " << node->name() << " (" << node + << ") in Graph " << &from; + } + } + + return absl::OkStatus(); +} + +absl::StatusOr> +ImporterBase::CreatePlaceholderNodeForFeed( + const TensorShapeProto& shape, DataType dtype, Node* node, int index, + const std::unordered_map& node_name_map) { + DCHECK_LT(index, node->num_outputs()); + const bool update_inplace = node->num_outputs() == 1 && index == 0; + std::string new_node_name = + update_inplace ? node->name() + : GetUniqueNodeName(node->name(), index, node_name_map); + + Node* placeholder_node; + NodeBuilder builder(new_node_name, "Placeholder"); + builder.Attr("shape", shape); + builder.Attr("dtype", dtype); + TF_RETURN_IF_ERROR(builder.Finalize(graph_.get(), &placeholder_node)); + + // Update edges from original feed with Placeholder node. + std::vector data_edges; + std::vector control_edges; + for (const tensorflow::Edge* edge : node->out_edges()) { + if (edge->src_output() == index) { + data_edges.push_back(edge); + } else if (update_inplace && edge->IsControlEdge()) { + control_edges.push_back(edge); + } + } + + for (const auto* edge : data_edges) { + TF_RETURN_IF_ERROR(graph_->UpdateEdge(placeholder_node, 0, edge->dst(), + edge->dst_input())); + } + + // TODO(lyandy): Preserve control dependencies properly by not forwarding + // control dependencies to data outputs and not removing single output nodes. + // When a data output is replaced as a feed, unless there is another non feed + // data output or an explicit control output used by the same node, transitive + // control dependencies are not to be executed. For single output nodes, + // Placeholders can be converted to a NoOp if there are no uses, and + // PlaceholderWithDefault can be converted to an Identity. + for (const auto* edge : control_edges) { + graph_->AddControlEdge(placeholder_node, edge->dst()); + graph_->RemoveControlEdge(edge); + } + + if (update_inplace) { + graph_->RemoveNode(node); + } + + return std::pair(placeholder_node, update_inplace); +} + +absl::Status ImporterBase::GetInputOutputNodes( + const std::unordered_map& node_name_map, + std::unordered_set* nodes) { + auto add_node = [&](absl::string_view name) { + auto it = node_name_map.find(std::string(name)); + if (it == node_name_map.end()) { + return errors::FailedPrecondition( + absl::StrCat("Graph does not contain node: ", name)); + } + nodes->insert(it->second); + return absl::OkStatus(); + }; + + // Remap feeds and fetches to newly created Placeholder nodes. + for (const auto& input : specs_.inputs) { + TensorId tensor = ParseTensorName(input.first); + auto remapped_it = remapped_feeds_.find(tensor); + if (remapped_it != remapped_feeds_.end()) { + TF_RETURN_IF_ERROR(add_node(remapped_it->second)); + } else { + TF_RETURN_IF_ERROR(add_node(tensor.node())); + } + } + + for (const auto& output : specs_.outputs) { + TensorId tensor = ParseTensorName(output); + auto remapped_it = remapped_feeds_.find(tensor); + if (remapped_it != remapped_feeds_.end()) { + TF_RETURN_IF_ERROR(add_node(remapped_it->second)); + } else { + TF_RETURN_IF_ERROR(add_node(tensor.node())); + } + } + + for (const auto& control_output : specs_.control_outputs) + TF_RETURN_IF_ERROR(add_node(control_output)); + + return absl::OkStatus(); +} + +// TODO(jpienaar): Remove this post shape inference on import flag is removed. +absl::Status ImporterBase::AddNodesToShapeRefiner( + std::unordered_map* node_name_map) { + shape_refiner_ = + std::make_unique(graph_->versions(), graph_->op_registry()); + // Some operations (for example "TPUExecute") don't have shape inference + // function defined, so we should set this to false for adding nodes with + // these types of operations. + shape_refiner_->set_require_shape_inference_fns(false); + shape_refiner_->set_function_library_for_shape_inference(&graph_flib_); + + TF_ASSIGN_OR_RETURN(auto feeds_by_node, GetFeedsByNode(specs_.inputs)); + + // First add all nodes to the refiner. + for (Node* node : ordered_nodes_) { + // We need to use a TensorFlow node to teach the shape refiner that user + // specifies certain data type and shape for the inputs in the `specs_`. + // This node shouldn't have any inputs, only have one output and its + // output type/shape is only determined by its "named" attributes. (The + // attributes should have fixed names so we can use the info from `specs_` + // to set the value of them.) `Placeholder` satisfies these constraints. + // + // Therefore, if the input node isn't a `Placeholder`, we create one and use + // it to replace the original input node, so the shape refiner can + // successfully propagate the user's input type and shape to the rest of the + // graph. + bool node_added_to_shape_refiner = false; + auto it = feeds_by_node.find(node->name()); + if (it != feeds_by_node.end()) { + auto op_name = node->op_def().name(); + if (op_name != "Placeholder" && op_name != "LegacyFedInput" && + op_name != FunctionLibraryDefinition::kArgOp) { + for (const auto& output_tensor : it->second) { + const int index = output_tensor.first; + const ArrayInfo& array_info = output_tensor.second->second; + + DataType dtype = array_info.imported_dtype; + // Uses the existing output type if it isn't specified by the user. + if (dtype == DT_INVALID) { + dtype = node->output_type(index); + } + + TF_ASSIGN_OR_RETURN( + auto placeholder_node_and_removed, + CreatePlaceholderNodeForFeed(array_info.shape, dtype, node, index, + *node_name_map)); + + Node* placeholder_node = placeholder_node_and_removed.first; + if (placeholder_node_and_removed.second) { + // Original node has been removed from the graph. + node = placeholder_node; + node_added_to_shape_refiner = true; + } + remapped_feeds_[{it->first, index}] = placeholder_node->name(); + (*node_name_map)[placeholder_node->name()] = placeholder_node; + // Add the new placeholder node to the shape refiner. + absl::Status status = shape_refiner_->AddNode(placeholder_node); + if (!status.ok()) { + return EmitErrorWithLocationStr(*placeholder_node, status); + } + } + } else { + auto index_it = it->second.find(0); + if (index_it == it->second.end()) { + return errors::FailedPrecondition( + "Missing feed output tensor at index 0 for node '", node->name(), + "'"); + } + node->AddAttr("shape", index_it->second->second.shape); + DataType dtype = index_it->second->second.imported_dtype; + // Uses the existing output type if it isn't specified by the user. + if (dtype == DT_INVALID) { + dtype = node->output_type(0); + } + node->AddAttr("dtype", dtype); + } + } + if (!node_added_to_shape_refiner) { + // Add the node to the shape refiner if the node hasn't been removed. + absl::Status status = shape_refiner_->AddNode(node); + if (!status.ok()) { + return EmitErrorWithLocationStr(*node, status); + } + } + + auto set_shape_from_list_attr = [&](const AttrValue* attr) { + auto& list = attr->list(); + // This follows the same approach as in ValidateShape, but only flags + // warning in case where there are mismatch in number of shapes and + // outputs and in which case it just returns without attempting to refine. + if (list.shape_size() != node->num_outputs()) { + LOG(WARNING) << "Node '" << node->name() << "' has " + << node->num_outputs() << " outputs but the " + << kOutputShapesAttrName + << " attribute specifies shapes for " << list.shape_size() + << " outputs"; + return absl::OkStatus(); + } + + for (const auto& shape : llvm::enumerate(list.shape())) { + auto* node_context = shape_refiner_->GetContext(node); + shape_inference::ShapeHandle handle; + absl::Status status = + node_context->MakeShapeFromShapeProto(shape.value(), &handle); + if (!status.ok()) { + return EmitErrorWithLocationStr(*node, status); + } + node_context->set_output(shape.index(), handle); + } + return absl::OkStatus(); + }; + + // If it is the argument node, the shape handle is set explicitly, so it + // can be propagated to the body nodes of the function. + if (StringPiece(node->type_string()) == FunctionLibraryDefinition::kArgOp) { + auto* node_context = shape_refiner_->GetContext(node); + DCHECK(node_context != nullptr); + if (const AttrValue* attr = node->attrs().Find("shape")) { + shape_inference::ShapeHandle handle; + absl::Status status = + node_context->MakeShapeFromShapeProto(attr->shape(), &handle); + if (!status.ok()) { + return EmitErrorWithLocationStr(*node, status); + } + node_context->set_output(0, handle); + } else if (const AttrValue* attr = + node->attrs().Find(kOutputShapesAttrName)) { + TF_RETURN_IF_ERROR(set_shape_from_list_attr(attr)); + } else { + node_context->set_output(0, node_context->UnknownShape()); + } + } + + // Following GraphConstructor::ValidateShape called from + // GraphConstructor::Convert, override the shape if _output_shapes is set. + if (specs_.unconditionally_use_set_output_shapes || + node->op_def().name() == "ReadVariableOp") { + if (const AttrValue* attr = node->attrs().Find(kOutputShapesAttrName)) + TF_RETURN_IF_ERROR(set_shape_from_list_attr(attr)); + } + } + + // Since we might have inserted and removed nodes from the graph, fix + // source/sink edges and reconstruct the RPO ordering of nodes + FixupSourceAndSinkEdges(graph_.get()); + + // Prune nodes in the graph that are not reachable from the output. + if (specs_.prune_unused_nodes) { + std::unordered_set prune_start; + TF_RETURN_IF_ERROR(GetInputOutputNodes(*node_name_map, &prune_start)); + if (!prune_start.empty()) { + if (PruneForReverseReachability(graph_.get(), prune_start)) { + VLOG(1) << "Pruned unused nodes in graphdef"; + } else { + VLOG(1) << "No unused nodes in graphdef to prune"; + } + } else { + VLOG(1) << "No output nodes specified, skipping pruning"; + } + } else { + VLOG(1) << "Pruning unused nodes in graphdef is disabled"; + } + + // Re-initialize ordered_nodes_ since we might have modified the graph. + ordered_nodes_.clear(); + TopologicalOrdering( + *graph_, [&](Node* n) { ordered_nodes_.push_back(n); }, GroupByDevice()); + + VLOG(1) << "Inferring graph shapes to fixpoint"; + + // The "changed" information from UpdateNode can give false positives, so we + // create a dedicated method to verify the shapes are not changed before and + // after the shape refine. + auto same_inferred_shape = [](shape_inference::InferenceContext* c, + shape_inference::ShapeHandle s0, + shape_inference::ShapeHandle s1) -> bool { + if (s0.SameHandle(s1) || (!c->RankKnown(s0) && !c->RankKnown(s1))) { + return true; + } + if (c->Rank(s0) != c->Rank(s1)) { + return false; + } + for (int i = 0; i < c->Rank(s0); ++i) { + if (!c->Dim(s0, i).SameHandle(c->Dim(s1, i))) { + int64_t val0 = c->Value(c->Dim(s0, i)); + int64_t val1 = c->Value(c->Dim(s1, i)); + // Negative value is treated as unknown so all negative values indicate + // the same dimension. + if (val0 >= 0 && val1 >= 0 && val0 != val1) return false; + } + } + return true; + }; + + bool changed = true; + int i = 0; + const int kMaxIterationCount = 2; + while (changed && i != kMaxIterationCount) { + changed = false; + for (const Node* node : ordered_nodes_) { + auto* shape_context = shape_refiner_->GetContext(node); + DCHECK(shape_context != nullptr); + absl::InlinedVector existing; + existing.reserve(shape_context->num_outputs()); + for (int o = 0; o < shape_context->num_outputs(); ++o) { + existing.push_back(shape_context->output(o)); + } + bool inferred = false; + shape_inference::ShapeHandle handle; + absl::Status status = + shape_refiner_->UpdateNode(node, /*relax=*/false, &inferred); + if (!status.ok()) { + return EmitErrorWithLocationStr(*node, status); + } + for (int o = 0; o < shape_context->num_outputs(); ++o) { + if (!same_inferred_shape(shape_context, shape_context->output(o), + existing[o])) { + changed = true; + break; + } + } + } + ++i; + } + if (i >= kMaxIterationCount) { + LOG(WARNING) << "Graph shapes did not converge to a fixpoint within " + << kMaxIterationCount + << " iterations. Graph shapes may be conservative."; + } + VLOG(1) << "Graph shapes were inferred with " << (i - 1) + << " extra rounds of analysis to reach a fixpoint."; + return absl::OkStatus(); +} + +absl::StatusOr ImporterBase::InferInputType(const Node& node, + int idx, + mlir::Builder builder) { + if (specs_.enable_shape_inference) { + // TODO(jpienaar): Remove this if shape inference on import flag is removed. + auto* context = shape_refiner_->GetContext(&node); + DataType dtype = node.input_type(idx); + return ConvertDataTypeAndShape(dtype, context->input(idx), + context->input_handle_shapes_and_types(idx), + context, builder); + } + DataType dtype = node.properties()->input_types[idx]; + mlir::Type element_type; + TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &element_type)); + return mlir::UnrankedTensorType::get(element_type); +} + +absl::StatusOr ImporterBase::InferOutputType( + const Node& node, int idx, mlir::Builder builder) { + DataType dtype = node.properties()->output_types[idx]; + + // Returns output type given inference context. + auto shape_ic = + [&](shape_inference::InferenceContext* c) -> absl::StatusOr { + // TODO(b/200093974): Post triage, consider following + // GraphConstructor::ValidateShape in checking _output_shapes always. + if (specs_.unconditionally_use_set_output_shapes) { + if (const AttrValue* attr = node.attrs().Find(kOutputShapesAttrName)) { + auto& list = attr->list(); + if (list.shape_size() > idx) { + const TensorShapeProto& p = list.shape()[idx]; + shape_inference::ShapeHandle h; + absl::Status s = c->MakeShapeFromShapeProto(p, &h); + if (!s.ok()) + return errors::InvalidArgument( + "Node '", node.name(), " has an invalid ", + kOutputShapesAttrName, " attribute (shape #", idx, " error:'", + s.message(), "')"); + c->set_output(idx, h); + } + } + } + + return ConvertDataTypeAndShape(dtype, c->output(idx), + c->output_handle_shapes_and_types(idx), c, + builder); + }; + + if (specs_.enable_shape_inference) { + // TODO(jpienaar): Remove this if shape inference on import flag is removed. + shape_inference::InferenceContext* shape_context = + shape_refiner_->GetContext(&node); + return shape_ic(shape_context); + } + + // Treat TensorList init ops specially here as the op requires knowing its + // element dtype. + // TODO(jpienaar): Reconsider post refactoring shape functions. + if (node.type_string() == "TensorListReserve" || + node.type_string() == "EmptyTensorList") { + mlir::Type etype; + if (auto element_dtype = node.attrs().Find("element_dtype")) { + TF_RETURN_IF_ERROR( + ConvertDataType(element_dtype->type(), builder, &etype)); + } + return GetTypeFromTFTensorShape( + {}, mlir::TF::VariantType::get({mlir::UnrankedTensorType::get(etype)}, + etype.getContext())); + } + + if (node.IsWhileNode()) { + auto* output_shapes = node.attrs().Find("output_shapes"); + auto* element_types = node.attrs().Find("T"); + if (output_shapes && !output_shapes->list().shape().empty()) { + const auto& output_shape = output_shapes->list().shape(idx); + const auto& element_type = element_types->list().type(idx); + return ConvertToMlirTensorType(output_shape, element_type, &builder); + } + } + + auto type_from_array_attr = [&node, &idx, &builder]( + absl::string_view output_shape_attr, + absl::string_view element_type_attr) { + auto* output_shapes = node.attrs().Find(output_shape_attr); + auto* element_types = node.attrs().Find(element_type_attr); + const auto& output_shape = output_shapes->list().shape(idx); + const auto& element_type = element_types->list().type(idx); + return ConvertToMlirTensorType(output_shape, element_type, &builder); + }; + + if (node.type_string() == "IteratorGetNext" || + node.type_string() == "IteratorGetNextSync" || + node.type_string() == "MultiDeviceIteratorGetNextFromShard") + return type_from_array_attr("output_shapes", "output_types"); + + if (node.type_string() == "InfeedDequeueTuple") + return type_from_array_attr("shapes", "dtypes"); + + if (node.type_string() == "InfeedDequeue") { + assert(idx == 0); + const auto& output_shape = node.attrs().Find("shape")->shape(); + const auto& element_type = node.attrs().Find("dtype")->type(); + return ConvertToMlirTensorType(output_shape, element_type, &builder); + } + + // Returns a simple, more conservative unranked tensor type. + auto default_type = [&]() -> absl::StatusOr { + mlir::Type element_type; + TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &element_type)); + + // TODO(b/200093974): Post triage, consider following + // GraphConstructor::ValidateShape in checking _output_shapes. + if (specs_.unconditionally_use_set_output_shapes) { + if (const AttrValue* attr = node.attrs().Find(kOutputShapesAttrName)) { + auto& list = attr->list(); + if (list.shape_size() > idx) { + llvm::SmallVector shape; + const TensorShapeProto& shape_proto = list.shape()[idx]; + if (shape_proto.unknown_rank()) + return mlir::UnrankedTensorType::get(element_type); + TF_RETURN_IF_ERROR(ConvertToMlirShape(shape_proto, &shape)); + return GetTypeFromTFTensorShape(shape, element_type); + } + } + } + + return mlir::UnrankedTensorType::get(element_type); + }; + + // Below we only try and do some shape inference for "source" ops which have + // no inputs. + if (node.num_inputs() > 0) return default_type(); + + // Do some simply inference here to get the function arguments correct for + // this common case. + // TODO(jpienaar): Reconsider post refactoring shape functions. + if (node.IsArg()) { + if (dtype == DT_RESOURCE) { + const AttrValue* dtype_attr = node.attrs().Find("_handle_dtypes"); + const AttrValue* shape_attr = node.attrs().Find("_handle_shapes"); + if (dtype_attr && shape_attr) { + if (dtype_attr->list().type().empty()) { + return errors::InvalidArgument( + "Invalid \"_handle_dtypes\" attribute value for _Arg node: ", + shape_attr->DebugString()); + } + if (shape_attr->list().shape().empty()) { + return errors::InvalidArgument( + "Invalid \"_handle_shapes\" attribute value for _Arg node: ", + shape_attr->DebugString()); + } + DataType dtype = dtype_attr->list().type(0); + const TensorShapeProto& shape_proto = shape_attr->list().shape(0); + TF_ASSIGN_OR_RETURN( + auto etype, ConvertToMlirTensorType(shape_proto, dtype, &builder)); + return mlir::UnrankedTensorType::get(mlir::TF::ResourceType::get( + {mlir::cast(etype)}, builder.getContext())); + } else { + return mlir::UnrankedTensorType::get( + mlir::TF::ResourceType::get(builder.getContext())); + } + } else if (auto shape = node.attrs().Find("_output_shapes")) { + if (shape->has_list() && shape->list().shape_size() == 1) { + return ConvertToMlirTensorType(shape->list().shape().at(0), dtype, + &builder); + } + } + } + + const tensorflow::OpRegistrationData* op_reg_data; + TF_RETURN_IF_ERROR( + graph_->op_registry()->LookUp(node.type_string(), &op_reg_data)); + if (!op_reg_data) { + DVLOG(1) << "Skipping inference for unregistered op " << node.type_string(); + return default_type(); + } + if (op_reg_data->shape_inference_fn == nullptr) { + DVLOG(1) << "Skipping inference for op without shape function " + << node.type_string(); + return default_type(); + } + shape_inference::InferenceContext c(graph_->versions().producer(), + node.attrs(), op_reg_data->op_def, + std::vector{}, {}, + /*input_tensors_as_shapes=*/{}, {}); + TF_RETURN_IF_ERROR(c.Run(op_reg_data->shape_inference_fn)); + return shape_ic(&c); +} + +absl::StatusOr ImporterBase::ConvertDataTypeAndShape( + DataType dtype, const shape_inference::ShapeHandle& handle, + const std::vector* handle_subtypes, + shape_inference::InferenceContext* context, mlir::Builder builder) { + TF_ASSIGN_OR_RETURN(auto subtypes, + ConvertSubtypes(handle_subtypes, context, builder)); + + mlir::Type element_type; + if (dtype == DT_VARIANT) + element_type = mlir::TF::VariantType::get(subtypes, context_); + else if (dtype == DT_RESOURCE) + element_type = mlir::TF::ResourceType::get(subtypes, context_); + else + TF_RETURN_IF_ERROR( + ::tensorflow::ConvertDataType(dtype, builder, &element_type)); + + return ConvertElementTypeAndShape(element_type, handle, context, builder); +} + +absl::StatusOr ImporterBase::ConvertElementTypeAndShape( + mlir::Type element_type, const shape_inference::ShapeHandle& handle, + shape_inference::InferenceContext* context, mlir::Builder builder) { + if (!context->RankKnown(handle)) { + return mlir::UnrankedTensorType::get(element_type); + } + + // Sentinel for an unknown dimension size. getTensorType interprets any + // negative value as an unknown dimension. + // TODO(jmolloy): Ideally this shouldn't be a local sentinel. + const int64_t kUnknownDim = -1; + + absl::InlinedVector dimensions; + int32_t rank = context->Rank(handle); + dimensions.reserve(rank); + for (int i = 0; i < rank; ++i) { + auto dim_handle = context->Dim(handle, i); + if (!context->ValueKnown(dim_handle)) + dimensions.push_back(kUnknownDim); + else + dimensions.push_back(context->Value(dim_handle)); + } + + return GetTypeFromTFTensorShape( + llvm::ArrayRef(dimensions.begin(), dimensions.end()), element_type); +} + +absl::StatusOr ImporterBase::ConvertSubtypes( + const std::vector* handle_subtypes, + shape_inference::InferenceContext* context, mlir::Builder builder) { + ElementSubtypes subtypes; + if (!handle_subtypes) return subtypes; + + subtypes.reserve(handle_subtypes->size()); + for (const auto& subtype : *handle_subtypes) { + mlir::Type element_type; + TF_RETURN_IF_ERROR( + ::tensorflow::ConvertDataType(subtype.dtype, builder, &element_type)); + TF_ASSIGN_OR_RETURN(TensorType type, + ConvertElementTypeAndShape(element_type, subtype.shape, + context, builder)); + subtypes.push_back(type); + } + return subtypes; +} + +absl::Status ImporterBase::ConvertFunctionCallAttribute( + const std::string& base_name, const AttrValue& value, + NamedAttrList* attributes) { + TF_ASSIGN_OR_RETURN(auto func_attr, + ConvertFunctionCallName(value.func().name())); + if (!func_attr) return absl::OkStatus(); + attributes->push_back(builder_.getNamedAttr(base_name, func_attr)); + + for (const auto& it : value.func().attr()) { + auto name = absl::StrCat(base_name, ".", it.first); + TF_ASSIGN_OR_RETURN(auto value, ConvertAttributeValue(it.second)); + attributes->push_back(builder_.getNamedAttr(name, value)); + } + return absl::OkStatus(); +} + +absl::StatusOr ImporterBase::ConvertFunctionCallName( + const std::string& func_name) { + // Some ops like XlaHostCompute op uses empty value to represent missing + // functions. Such attribute values should be defined optional in MLIR + // definition. + if (func_name.empty()) return mlir::FlatSymbolRefAttr(); + + TF_RETURN_IF_ERROR(ConvertLibFunction(func_name)); + auto mlir_func_name = (*tf_name_to_mlir_name_)[func_name]; + return mlir::SymbolRefAttr::get(builder_.getContext(), mlir_func_name); +} + +absl::StatusOr ImporterBase::ConvertAttributeValue( + const AttrValue& value) { + switch (value.value_case()) { + case AttrValue::kFunc: { + // TODO(b/156546237): Unify kFunc/NameAttrList attribute representation. + // Currently kFunc/NameAttrList attributes in a kList/repeated AttrValue + // will not use this representation. This also doesn't handle empty + // function values like ConvertFunctionCallName method. + NamedAttrList attrs; + for (const auto& func_attr : value.func().attr()) { + TF_ASSIGN_OR_RETURN( + auto attr, ImporterBase::ConvertAttributeValue(func_attr.second)); + attrs.push_back(builder_.getNamedAttr(func_attr.first, attr)); + } + auto func_attrs = builder_.getDictionaryAttr(attrs); + return mlir::TF::FuncAttr::get(context_, value.func().name(), func_attrs); + } + case AttrValue::kList: { + if (!value.list().func().empty()) { + absl::InlinedVector attrs; + for (const auto& item : value.list().func()) { + TF_ASSIGN_OR_RETURN(auto attr, ConvertFunctionCallName(item.name())); + if (item.attr_size() != 0) + return errors::Unimplemented( + "func attributes with non-zero attr.size()"); + if (attr) attrs.push_back(attr); + } + return builder_.getArrayAttr( + llvm::ArrayRef(attrs.begin(), attrs.end())); + } + return ConvertNonFuncAttributeValue(value, &builder_); + } + default: + return ConvertNonFuncAttributeValue(value, &builder_); + } +} + +void ImporterBase::GetArgsAndRetsFromFunctionBody( + const FunctionBody& fbody, absl::InlinedVector* arg_nodes, + absl::InlinedVector* ret_nodes, + absl::InlinedVector* control_ret_nodes) { + arg_nodes->reserve(fbody.arg_nodes.size()); + ret_nodes->reserve(fbody.ret_nodes.size()); + for (auto arg : fbody.arg_nodes) { + arg_nodes->emplace_back(arg, 0); + } + for (auto ret : fbody.ret_nodes) { + ret_nodes->emplace_back(ret, 0); + } + *control_ret_nodes = fbody.control_ret_nodes; +} + +absl::Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) { + // If the library function has been converted already, nothing needs to be + // done. + if (tf_name_to_mlir_name_->find(std::string(func_name)) != + tf_name_to_mlir_name_->end()) + return absl::OkStatus(); + + std::string mlir_func_name( + function_name_uniquifier_->GetUniqueName(func_name)); + (*tf_name_to_mlir_name_)[std::string(func_name)] = mlir_func_name; + + const auto& func_lib = graph_flib_; + const auto* func_def = func_lib.Find(std::string(func_name)); + if (func_def == nullptr) { + return errors::FailedPrecondition( + absl::StrCat("Failed to find function '", StringRefToView(func_name), + "'. The imported TensorFlow GraphDef is ill-formed.")); + } + + // Converts the argument and return types to MLIR types. + std::vector attributes; + attributes.reserve(func_def->attr_size()); + for (const auto& name_and_value : func_def->attr()) { + // This is a function definition attribute, so it shouldn't contain + // kFunc attribute and it is treated as normal one. + TF_ASSIGN_OR_RETURN(auto attr, + ConvertAttributeValue(name_and_value.second)); + std::string attr_name = + mangling_util::MangleAttributeName(name_and_value.first); + attributes.push_back(builder_.getNamedAttr(attr_name, attr)); + } + + // Checks opdef stateful attribute and import that as Function Attribute + if (func_def->signature().is_stateful()) { + auto stateful_str = mlir::TF::TensorFlowDialect::GetStatefulAttrName(); + attributes.push_back( + builder_.getNamedAttr(stateful_str, builder_.getUnitAttr())); + } + + // Checks for an associated custom gradient function. Adds it to the attribute + // list of this function. + auto grad_func_name = func_lib.FindGradient(std::string(func_name)); + if (!grad_func_name.empty()) { + TF_RETURN_IF_ERROR(ConvertLibFunction(grad_func_name)); + auto mlir_grad_func_name = (*tf_name_to_mlir_name_)[grad_func_name]; + auto gradient_attr = + mlir::SymbolRefAttr::get(builder_.getContext(), mlir_grad_func_name); + auto grad_string = mlir::TF::TensorFlowDialect::GetGradientAttrName(); + attributes.push_back(builder_.getNamedAttr(grad_string, gradient_attr)); + } + + deferred_functions_.emplace(func_name.str(), attributes); + return absl::OkStatus(); +} + +absl::Status ImporterBase::PruneUnreachableNodes( + std::unordered_map* node_name_map) { + std::unordered_set prune_start; + TF_RETURN_IF_ERROR(GetInputOutputNodes(*node_name_map, &prune_start)); + + if (!prune_start.empty()) { + if (PruneForReverseReachability(graph_.get(), prune_start)) { + VLOG(1) << "Pruned unused nodes in graphdef"; + } else { + VLOG(1) << "No unused nodes in graphdef to prune"; + } + } else { + VLOG(1) << "No output nodes specified, skipping pruning"; + } + return absl::OkStatus(); +} + +absl::Status ImporterBase::ConvertFeedsToPlaceholders( + std::unordered_map* node_name_map) { + // Feeds (edges) are converted into single-output placeholder nodes to + // simplify the conversion process. + TF_ASSIGN_OR_RETURN(auto feeds_by_node, GetFeedsByNode(specs_.inputs)); + for (const auto& it : feeds_by_node) { + TensorId tensor = ParseTensorName(it.first); + auto jt = node_name_map->find(std::string(tensor.node())); + if (jt == node_name_map->end()) { + return errors::FailedPrecondition( + absl::StrCat("Graph does not contain node: ", tensor.node())); + } + + Node* node = jt->second; + auto op_name = node->op_def().name(); + if (op_name != "Placeholder" && op_name != "LegacyFedInput" && + op_name != FunctionLibraryDefinition::kArgOp) { + for (const auto& output_tensor : it.second) { + const int index = output_tensor.first; + const ArrayInfo& array_info = output_tensor.second->second; + + DataType dtype = array_info.imported_dtype; + // Uses the existing output type if it isn't specified by the user. + if (dtype == DT_INVALID) { + dtype = node->output_type(index); + } + + TF_ASSIGN_OR_RETURN( + auto placeholder_node_and_removed, + CreatePlaceholderNodeForFeed(array_info.shape, dtype, node, index, + *node_name_map)); + + Node* placeholder_node = placeholder_node_and_removed.first; + if (placeholder_node->in_edges().empty()) { + graph_->AddControlEdge(graph_->source_node(), placeholder_node, + true /* skip test for duplicates */); + } + if (placeholder_node->out_edges().empty()) { + graph_->AddControlEdge(placeholder_node, graph_->sink_node(), + true /* skip test for duplicates */); + } + remapped_feeds_[{it.first, index}] = placeholder_node->name(); + (*node_name_map)[placeholder_node->name()] = placeholder_node; + } + } + } + return absl::OkStatus(); +} + +absl::Status ImporterBase::PrepareConvert(const Graph& graph, + std::unique_ptr graph_def) { + // TODO(fengliuai): Converting to GraphDef and back is the easiest way to + // clone a graph. + // TODO(fengliuai): clone the graph without going to graph_def first. + if (graph_def == nullptr) { + graph_def = std::make_unique(); + graph.ToGraphDef(graph_def.get()); + } + graph_ = std::make_unique(graph.flib_def()); + GraphConstructorOptions opts; + opts.allow_internal_ops = true; + opts.add_default_attributes = true; + TF_RETURN_IF_ERROR(::tensorflow::ConvertGraphDefToGraph( + opts, std::move(*graph_def), graph_.get())); + + TF_RETURN_IF_ERROR(RemoveBackedges()); + + TF_RETURN_IF_ERROR(CopyStackTraces(graph, graph_.get())); + + auto node_name_map = graph_->BuildNodeNameIndex(); + + if (specs_.enable_shape_inference) { + // TODO(jpienaar): Remove once infer shapes on import flag is removed. + TF_RETURN_IF_ERROR(AddNodesToShapeRefiner(&node_name_map)); + } else { + TF_RETURN_IF_ERROR(ConvertFeedsToPlaceholders(&node_name_map)); + } + + // Prune nodes in the graph that are not reachable from the output. + if (specs_.prune_unused_nodes) { + TF_RETURN_IF_ERROR(PruneUnreachableNodes(&node_name_map)); + } + + if (!specs_.enable_shape_inference) { + // Re-initialize ordered_nodes_ since we might have modified the graph. + ordered_nodes_.clear(); + TopologicalOrdering( + *graph_, [&](Node* n) { ordered_nodes_.push_back(n); }, + GroupByDevice()); + } + + return absl::OkStatus(); +} + +absl::Status ImporterBase::Convert( + llvm::StringRef func_name, mlir::FunctionType func_type, + const absl::InlinedVector& arg_nodes, + const absl::InlinedVector& ret_nodes, + const absl::InlinedVector& control_ret_nodes, + llvm::ArrayRef attrs) { + // TODO(b/122040776): Uses debug info for FunctionDef. + auto function = mlir::func::FuncOp::create(mlir::UnknownLoc::get(context_), + func_name, func_type, attrs); + + module_.push_back(function); + // Seeds the builder with an initial block. + function.addEntryBlock(); + builder_ = mlir::OpBuilder(function.getBody()); + + // Create the graph operation in which we will convert the individual nodes. + auto graph = builder_.create( + function.getLoc(), func_type.getResults()); + builder_.createBlock(&graph.getBody()); + + for (const Node* node : ordered_nodes_) { + TF_RETURN_IF_ERROR(ConvertNode(*node)); + } + + // Adds the backedges back to the function by creating the source and sink + // pairs. + TF_RETURN_IF_ERROR(AddBackedges()); + + TF_RETURN_IF_ERROR(ConvertFunctionArgAndRets(function, graph, + func_type.getInputs(), arg_nodes, + ret_nodes, control_ret_nodes)); + + // TODO(jpienaar): Update post removing shape_refinier_. + if (!specs_.enable_shape_inference) { + // Refine graph's type given more precise fetch. + auto fetch = graph.GetFetch(); + bool all_equal = true; + for (auto it : + llvm::zip_first(graph.getResults(), fetch.getOperandTypes())) { + auto rt = std::get<1>(it); + if (rt == std::get<0>(it).getType()) continue; + std::get<0>(it).setType(rt); + all_equal = false; + } + if (!all_equal) { + function.setType(mlir::FunctionType::get(function.getContext(), + func_type.getInputs(), + graph.getResultTypes())); + } + } + + return absl::OkStatus(); +} + +absl::Status ImporterBase::ConvertFunctionArgAndRets( + mlir::func::FuncOp func, mlir::tf_executor::GraphOp graph_op, + llvm::ArrayRef arg_types, + const absl::InlinedVector& arg_nodes, + const absl::InlinedVector& ret_nodes, + const absl::InlinedVector& control_ret_nodes) { + // Store the arg/return attributes as a list rather than uniqueuing during + // construction. + llvm::SmallVector arg_attrs; + arg_attrs.resize(func.getNumArguments()); + llvm::SmallVector ret_attrs; + ret_attrs.resize(func.getNumResults()); + + auto set_attributes_on_func = [&](Node* node, int64_t index, bool is_arg) { + for (const auto& node_attr : node->attrs()) { + const auto& key = node_attr.first; + // Only import optional attributes (e.g., those starting with an + // underscore). + if (key.empty() || key[0] != '_') continue; + // Ignore shape inference attributes as shape information is already + // populated in the result type. + if (IsOutputShapesAttribute(node_attr.second, key) || + IsResourceOutputShapesAttribute(node_attr.second, key)) + continue; + TF_ASSIGN_OR_RETURN(auto converted_attr, + ConvertAttributeValue(node_attr.second)); + std::string dialect_attribute = "tf." + key; + if (is_arg) { + arg_attrs[index].set(dialect_attribute, converted_attr); + } else { + func.setResultAttr(index, dialect_attribute, converted_attr); + ret_attrs[index].set(dialect_attribute, converted_attr); + } + } + return absl::OkStatus(); + }; + + auto* bb = &func.front(); + llvm::SmallDenseMap, mlir::Value, 4> + arg_nodes_to_values; + for (int i = 0, e = arg_types.size(); i < e; ++i) { + auto& arg_node = arg_nodes[i]; + // The lookup can't fail here: otherwise some nodes in the function haven't + // be converted to mlir operations and don't have a mapping. + mlir::Operation* island = node_values_.find(arg_node.node->id())->second; + + auto bb_arg = bb->getArgument(i); + mlir::Value arg_def = bb_arg; + + if (island->getNumResults() != 2) + return errors::InvalidArgument( + "Only feed output tensors of single output nodes are supported"); + + // Collect mapping of OutputTensor to associated block arg. + arg_nodes_to_values.try_emplace({arg_node.node, arg_node.index}, arg_def); + island->getResult(0).replaceAllUsesWith(arg_def); + // Erase control outputs from feed. + auto control_uses = island->getResult(1).getUses(); + for (auto& control_use : llvm::make_early_inc_range(control_uses)) + control_use.getOwner()->eraseOperand(control_use.getOperandNumber()); + + if (!arg_node.node->requested_device().empty()) + arg_attrs[i].set("tf.device", builder_.getStringAttr( + arg_node.node->requested_device())); + + if (arg_node.node->IsArg()) { + TF_RETURN_IF_ERROR( + set_attributes_on_func(arg_node.node, i, /*is_arg=*/true)); + } + + island->dropAllReferences(); + island->erase(); + } + + llvm::SmallVector inst_to_return; + for (const auto& ret_and_idx : llvm::enumerate(ret_nodes)) { + const auto& ret = ret_and_idx.value(); + auto* inst = node_values_[ret.node->id()]; + if (ret.node->IsRetval()) { + if (!ret.node->requested_device().empty()) + ret_attrs[ret_and_idx.index()].set( + "tf.device", builder_.getStringAttr(ret.node->requested_device())); + TF_RETURN_IF_ERROR(set_attributes_on_func(ret.node, ret_and_idx.index(), + /*is_arg=*/false)); + // Lookup the instruction inside the island + auto island_op = llvm::cast(inst); + mlir::Operation* inner_op = &island_op.GetBody().front(); + // Remove kRetOp or kDeviceRetOp operation and return its operand. + // kRetOp and kDeviceRetOp should have just one operand unless they have + // control dependencies. + if (inner_op->getNumOperands() != 1) + return errors::Unimplemented("Return node with multiple inputs."); + inst_to_return.push_back(inner_op->getOperand(0)); + inst->dropAllReferences(); + inst->erase(); + } else { + // Lookup and use block arg if fetch is a feed. + auto it = arg_nodes_to_values.find({ret.node, ret.index}); + if (it != arg_nodes_to_values.end()) + inst_to_return.push_back(it->second); + else + inst_to_return.push_back(inst->getResult(ret.index)); + } + } + + for (Node* control_ret : control_ret_nodes) { + auto* inst = node_values_[control_ret->id()]; + inst_to_return.push_back(*std::prev(inst->result_end())); + } + + // Terminate the function by adding a Fetch operation to terminate the graph + // and a return operation to return the Graph results. + builder_.setInsertionPointToEnd(&graph_op.getBody().front()); + builder_.create(graph_op.getLoc(), + inst_to_return); + builder_.setInsertionPointToEnd(bb); + builder_.create(mlir::UnknownLoc::get(context_), + graph_op.getResults()); + + func.setAllArgAttrs( + llvm::to_vector<4>(llvm::map_range(arg_attrs, [&](NamedAttrList& list) { + return list.getDictionary(context_); + }))); + func.setAllResultAttrs( + llvm::to_vector<4>(llvm::map_range(ret_attrs, [&](NamedAttrList& list) { + return list.getDictionary(context_); + }))); + + return absl::OkStatus(); +} + +mlir::Location ImporterBase::GetLocation(const Node& node) { + DVLOG(1) << "Getting location for " << node.name() << " " << &node; + // TODO(b/142400497): What is the semantic contract for locations? + // Create a location for node `name` in function `function_name`. + auto create_location = [&](llvm::StringRef name, + llvm::StringRef function_name) -> mlir::Location { + // Use the catenation of function and node names as the lookup key into the + // debug info. This matches the way that the key is formed on the python + // side. + // + // We also use this as the name for the NameLoc for ops in function, since + // otherwise our names could collide across functions. + // For ops in the main graph, we omit the "@function_name" (which, would be + // just "@" since function_name would be empty) because some code seems to + // depend on the name being this way for correctness. + std::string debug_info_key = (name + "@" + function_name).str(); + std::string name_for_name_loc = + function_name.empty() ? name.str() : debug_info_key; + auto name_loc_id = mlir::StringAttr::get(context_, name_for_name_loc); + + std::shared_ptr stack_trace = node.GetStackTrace(); + + // Prefer stack traces if available, fallback to debug info if not, and then + // finally to just name. Older versions of debug info concatenated `@` onto + // the node name for the default graph, so we check both locations. + if (stack_trace != nullptr) { + } else if (stack_traces_.contains(name_for_name_loc)) { + stack_trace = stack_traces_.at(name_for_name_loc); + } else if (stack_traces_.contains(debug_info_key)) { + stack_trace = stack_traces_.at(debug_info_key); + } else { + DVLOG(1) << "No stack trace for " << node.name(); + } + + llvm::SmallVector locations; + + if (stack_trace != nullptr) { + DVLOG(1) << "Stack available for " << node.name(); + for (const StackFrame& frame : stack_trace->ToUncachedFrames()) { + auto file_name = mlir::StringAttr::get(context_, frame.file_name); + // Use col 1 as there is no column info in StackTrace. + auto file_line_loc = + mlir::FileLineColLoc::get(file_name, frame.line_number, 1); + locations.push_back(file_line_loc); + } + } + + // If there are no locations in the stack trace, fall back to just a + // NameLoc with no child. + if (locations.empty()) return mlir::NameLoc::get(name_loc_id); + + // Use the front FileLineColLoc to generate a NameLoc. + mlir::Location node_name_loc = + mlir::NameLoc::get(name_loc_id, locations.front()); + + // If there are more locations then generate a stack trace, otherwise just + // return the name loc. + auto callsite_locs = llvm::ArrayRef(locations).drop_front(); + return callsite_locs.empty() + ? node_name_loc + : mlir::CallSiteLoc::get(node_name_loc, callsite_locs); + }; + + // Create a location for node `name` in function `function_name`. + auto create_op_type_and_name_locations = [&]() { + return mlir::FusedLoc::get( + context_, + // Add the type operation for the propagation of op_type metadata. + {mlir::NameLoc::get( + mlir::StringAttr::get(context_, node.type_string() + ":")), + create_location(node.name(), function_name_for_debug_info_)}); + }; + + // For NextIteration nodes, location is used to pair source and sink nodes. + // Hence, we use node name as location to keep it unique. + // TODO(prakalps): In future the plan is to use tokens to pair source/sink + // nodes. Then NextIteration nodes would not need to be handled separately. + if (node.type_string() == "NextIteration") { + return create_op_type_and_name_locations(); + } + + const auto& node_def = node.def(); + auto original_nodes = + node_def.experimental_debug_info().original_node_names(); + auto original_funcs = + node_def.experimental_debug_info().original_func_names(); + + if (original_nodes.empty()) { + return create_op_type_and_name_locations(); + } else { + // If the original nodes are defined, then we use them to get a list of + // call sites, and then fuse them to a single fused location, with the name + // of the node_def. + llvm::SmallVector node_locations; + node_locations.reserve(original_nodes.size() + 2); + // Add the type operation for the propagation of op_type metadata. + node_locations.push_back(mlir::NameLoc::get( + mlir::StringAttr::get(context_, node.type_string() + ":"))); + // Retrieve the names from the experimental_debug_info. + for (int i = 0, e = original_nodes.size(); i != e; ++i) { + const auto& node_name = original_nodes[i]; + auto func_name = (i < original_funcs.size()) ? original_funcs[i] : ""; + node_locations.push_back(create_location(node_name, func_name)); + } + // Retrieve the name of the node_def. + node_locations.push_back( + create_location(node.name(), function_name_for_debug_info_)); + return mlir::FusedLoc::get(context_, node_locations); + } +} + +absl::Status ImporterBase::EmitErrorWithLocationStr( + const Node& node, const absl::Status& error_status) { + const mlir::Location location = GetLocation(node); + mlir::emitError(location); + return error_handler_.Combine(error_status); +} + +mlir::Operation* ImporterBase::CreateOperation( + const Node& node, llvm::StringRef node_type_name, + const mlir::OperationState& result, + const llvm::SmallVectorImpl& control_operands) { + // For the tf.executor specific operations (not wrapped in an island), we + // have an extra returned value for the control result, and we concatenate + // control and non-control operands. + mlir::SmallVector types(result.types); + types.push_back(mlir::tf_executor::ControlType::get(builder_.getContext())); + mlir::SmallVector operands(result.operands); + operands.append(control_operands.begin(), control_operands.end()); + + auto loc = result.location; + // Dispatch based on the name and create the appropriate operation. + if (node.IsSwitch()) { + // Switch and _SwitchN both are in switch class, differentiate based on + // op name. + if (node.op_def().name() == "_SwitchN") { + return builder_.create(loc, types, operands, + result.attributes); + } + return builder_.create(loc, types, operands, + result.attributes); + } + if (node.IsMerge()) { + return builder_.create(loc, types, operands, + result.attributes); + } + if (node.IsNextIteration()) { + // NextIteration is a bit special, we create a pair of operations that are + // linked together through a token returned by the source. + // We make use of a separate builder to insert the source at the top of + // the block. + mlir::OpBuilder builder_at_begin(builder_.getBlock(), + builder_.getBlock()->begin()); + auto source_op = + builder_at_begin.create( + loc, operands[0].getType(), result.attributes); + return builder_.create( + loc, source_op.getToken(), operands, result.attributes); + } + if (node.IsLoopCond()) { + return builder_.create(loc, types, operands, + result.attributes); + } + if (node.IsEnter()) { + return builder_.create(loc, types, operands, + result.attributes); + } + if (node.IsExit()) { + return builder_.create(loc, types, operands, + result.attributes); + } + if (node.IsControlTrigger()) { + return builder_.create( + loc, mlir::ValueRange(operands), result.attributes); + } + // Regular TensorFlow operation are wrapped in a tf_executor.island. + auto island = builder_.create( + result.location, types, control_operands, + mlir::ArrayRef{}); + island.getBody().push_back(new mlir::Block); + mlir::OpBuilder island_builder = + mlir::OpBuilder::atBlockEnd(&island.GetBody()); + + // Create the operation inside the island now. + mlir::Operation* inner_op = island_builder.create(result); + + // Sets operand_segment_sizes or result_segment_sizes attribute to the op. + const auto set_segment_sizes_attr = + [&](const NameRangeMap& arg_ranges, + const protobuf::RepeatedPtrField& args, + llvm::StringRef attr_name) { + std::vector values; + values.reserve(args.size()); + for (const auto& arg : args) { + auto range = arg_ranges.at(arg.name()); + values.push_back(range.second - range.first); + } + auto attr_value = + mlir::DenseI32ArrayAttr::get(inner_op->getContext(), values); + inner_op->setAttr(attr_name, attr_value); + }; + + if (inner_op->hasTrait() || + inner_op->hasTrait()) { + // The op has multiple variadic operands or results. + // Calculate operand and result segment sizes using the OpDef. + NameRangeMap input_ranges, output_ranges; + // This will fail only if the OpDef is syntactically invalid. + // TODO(jpienaar): Convert this CHECK into a properly propagated error. + TF_CHECK_OK( + NameRangesForNode(node, node.op_def(), &input_ranges, &output_ranges)); + if (inner_op->hasTrait()) { + // Add derived "operand_segment_sizes" attr to the created operation. + // TODO(b/146937733): Don't use here. + set_segment_sizes_attr(input_ranges, node.op_def().input_arg(), + mlir::OpTrait::AttrSizedOperandSegments< + void>::getOperandSegmentSizeAttr()); + } + + if (inner_op->hasTrait()) { + // Add derived "result_segment_sizes" attr to the created operation. + // TODO(b/146937733): Don't use here. + set_segment_sizes_attr(output_ranges, node.op_def().output_arg(), + mlir::OpTrait::AttrSizedResultSegments< + void>::getResultSegmentSizeAttr()); + } + } + + if (VLOG_IS_ON(1)) { + mlir::OperationName name = inner_op->getName(); + if (!name.isRegistered() && + // Skip unmodelled ops that are handled differently. + (node_type_name != "_Arg" && node_type_name != "_Retval") && + !unmodelled_op_names_.count(name.getIdentifier())) { + if (node.op_def().is_stateful()) { + VLOG(1) << "[potentially conservative] Op type `" << node.type_string() + << "` is stateful but effects not modelled"; + } else { + // See if any resource type is used. + bool resource = false; + std::function record_resource; + record_resource = [&](mlir::Type type) { + type.walk([&](mlir::Type t) { + if (resource) return mlir::WalkResult::interrupt(); + if (mlir::isa(type)) { + resource = true; + return mlir::WalkResult::interrupt(); + } + + return mlir::WalkResult::advance(); + }); + + return resource; + }; + + for (mlir::Type t : inner_op->getResultTypes()) + if (record_resource(t)) break; + for (mlir::Type t : inner_op->getOperandTypes()) + if (record_resource(t)) break; + if (resource) { + unmodelled_op_names_.insert(name.getIdentifier()); + VLOG(1) << "[potentially conservative] Op type `" + << node.type_string() + << "` has resource operands/results but effects not modelled"; + } + } + } + } + + // Add the terminator for the island + island_builder.create(result.location, + inner_op->getResults()); + return island.getOperation(); +} + +absl::Status ImporterBase::ConvertNode(const Node& node) { + if (!node.IsOp()) { + // Don't import the pseudo-nodes _SOURCE or _SINK. These are added by + // Graph and don't exist in GraphDef. + return absl::OkStatus(); + } + + // If it is a custom OP, its definition should be found in the library. We + // create the MLIR function and insert it to the module if it doesn't exist. + std::string node_type_name = node.type_string(); + const auto* func_def = graph_flib_.Find(node_type_name); + bool convert_to_legacy_call = false; + if (func_def) { + TF_RETURN_IF_ERROR(ConvertLibFunction(node_type_name)); + node_type_name = (*tf_name_to_mlir_name_)[node_type_name]; + convert_to_legacy_call = true; + } + + auto get_full_op_name = [&](const std::string& op_name) { + const char* kTfPrefix = "tf."; + return kTfPrefix + op_name; + }; + + std::string op_name = get_full_op_name(node_type_name); + if (back_edge_node_output_.contains(&node)) { + op_name = op_name + ".sink"; + } + + mlir::OperationState result(GetLocation(node), op_name); + for (int i = 0; i < node.num_outputs(); ++i) { + // The backedge has been removed, so we shouldn't count the corresponding + // output from the src node when converting to an operation. + if (back_edge_node_output_.contains(&node) && + back_edge_node_output_[&node] == i) { + continue; + } + TF_ASSIGN_OR_RETURN(auto type, InferOutputType(node, i, builder_)); + result.types.push_back(type); + } + + // Surprisingly input edges can be nondeterministically ordered. This + // particularly seems to be the case for the control edges between _SOURCE + // and _SINK that the Graph constructor inserts. Copy the input edges and + // sort the edges, but only the control edges, not data edges! + // TODO(jmolloy): We should probably just ignore _SOURCE and _SINK nodes. + // They'll break roundtripping anyway unless we strip them when converting + // back to graphdef. + absl::InlinedVector in_edges(node.in_edges().size()); + absl::c_copy(node.in_edges(), in_edges.begin()); + absl::c_stable_sort(in_edges, [](const Edge* e1, const Edge* e2) { + if (e1->IsControlEdge() && !e2->IsControlEdge()) return false; + if (!e1->IsControlEdge() && e2->IsControlEdge()) return true; + if (e1->IsControlEdge() && e2->IsControlEdge()) + return e1->src()->id() < e2->src()->id(); + return e1->dst_input() < e2->dst_input(); + }); + + result.operands.reserve(in_edges.size()); + + // Collect the control operands separately, they will be held by the island. + mlir::SmallVector control_operands; + + for (const auto* input_edge : in_edges) { + const Node& input_node = *input_edge->src(); + if (input_node.IsSource()) { + if (in_edges.size() != 1) { + return errors::FailedPrecondition( + "The node has other inputs besides the _Source node"); + } + // We don't import the _SOURCE node. + continue; + } + if (input_node.IsArg() && input_edge->IsControlEdge()) { + // Currently we have not reached consensus as to what TF function + // semantics are (b/133509504). Here we assume that all arguments to a + // function should be available before we start execution of any internal + // node. This makes the control dependencies between function arguments + // and internal nodes redundant, and so we do not import them. The TF + // inliner however assumes no such dependency between function args and + // internal nodes exists, unless explicitly stated. Since we drop control + // dependencies here, it leads to loss of information. If the function is + // inlined later, the inliner would not know of these explicit control + // dependencies present in the original graph. + continue; + } + if (node_values_.find(input_node.id()) == node_values_.end()) + return errors::FailedPrecondition( + "Graph not traversed in reverse post order; use seen before def!"); + mlir::Operation* inst = node_values_[input_node.id()]; + if (input_edge->IsControlEdge()) + control_operands.push_back(inst->getResult(inst->getNumResults() - 1)); + else + result.operands.push_back(inst->getResult(input_edge->src_output())); + } + + using FuncPairType = std::pair; + std::vector funcs; + result.attributes.reserve(node.attrs().size() + 2); + auto abstract_op = result.name.getRegisteredInfo(); + auto derived_op = + abstract_op + ? abstract_op->getInterface() + : nullptr; + for (const auto& name_and_value : node.attrs()) { + const auto& attr_name = name_and_value.first; + // Skip adding derived attributes to the generated op. + if (derived_op && derived_op->isDerivedAttribute(attr_name)) continue; + const AttrValue& attr_value = name_and_value.second; + + // Remove _output_shapes attribute that will be added by the exporter. + if (IsOutputShapesAttribute(attr_value, attr_name)) continue; + + if (attr_value.value_case() == AttrValue::kFunc) { + // Attribute iteration order is not defined for protocol buffer Map. + // Process function attributes separately in the lexicographical order to + // have deterministic order of functions in the constructed IR. + funcs.emplace_back(&attr_name, &attr_value); + } else { + TF_ASSIGN_OR_RETURN(auto attr, ConvertAttributeValue(attr_value)); + result.attributes.push_back(builder_.getNamedAttr(attr_name, attr)); + } + } + + auto comparator = [](const FuncPairType& a, const FuncPairType& b) { + return *a.first < *b.first; + }; + std::sort(funcs.begin(), funcs.end(), comparator); + for (const auto& func : funcs) { + TF_RETURN_IF_ERROR(ConvertFunctionCallAttribute(*func.first, *func.second, + &result.attributes)); + } + + const auto& node_def = node.def(); + // NodeDef can contain partial TF device names. In such cases, canonicalize + // it. Note that in current TF, placer will place full device name to each + // node. + DeviceNameUtils::ParsedName parsed_name; + if (!DeviceNameUtils::ParseFullName(node_def.device(), &parsed_name)) { + return errors::InvalidArgument( + "Op ", op_name, " has invalid device name: ", node_def.device()); + } + // Keep the parsed name untouched if the device name is empty. + if (!node_def.device().empty()) { + if (!parsed_name.has_type) { + parsed_name.type = "CPU"; + parsed_name.has_type = true; + } + if (!parsed_name.has_id) { + parsed_name.id = 0; + parsed_name.has_id = true; + } + } + result.attributes.push_back(builder_.getNamedAttr( + "device", builder_.getStringAttr( + DeviceNameUtils::ParsedNameToString(parsed_name)))); + + // Map user function calls to LegacyCall ops and add the user function name + // as an attribute. + if (convert_to_legacy_call) { + result.name = mlir::OperationName(get_full_op_name("LegacyCall"), context_); + mlir::SymbolRefAttr val = + mlir::SymbolRefAttr::get(builder_.getContext(), node_type_name); + result.addAttribute("f", val); + + if (!result.attributes.get("_disable_call_shape_inference")) { + result.addAttribute("_disable_call_shape_inference", + builder_.getBoolAttr(false)); + } + } + + auto composite_control_flow_op = [&](const std::string& name) { + result.name = mlir::OperationName(get_full_op_name(name), context_); + bool stateless = absl::StartsWith(node_type_name, "Stateless"); + mlir::BoolAttr val = builder_.getBoolAttr(stateless); + result.attributes.push_back(builder_.getNamedAttr("is_stateless", val)); + }; + + // Map Case/If/While and StatelessCase/If/While op in TensorFlow to the common + // Case/If/While op in MLIR and add the differentiating attribute. + if (node.IsCaseNode()) composite_control_flow_op("Case"); + if (node.IsIfNode()) composite_control_flow_op("If"); + if (node.IsWhileNode()) { + composite_control_flow_op("While"); + auto* output_shapes = node.attrs().Find("output_shapes"); + if (output_shapes && !output_shapes->list().shape().empty()) + result.attributes.push_back( + builder_.getNamedAttr("shape_invariant", builder_.getUnitAttr())); + } + + // Register the mapping between the TF node and the newly created operation. + node_values_[node.id()] = + CreateOperation(node, node_type_name, result, control_operands); + return absl::OkStatus(); +} + +// Add the backedges to the CFG. Given a backedge, we replace the original +// source and destination operations by two new operations. Most of the +// fields of the replacements are copied from the original operations. +// However, +// - for the src operation, one output is inserted to the front of the output +// list. The type of the output is set to the type of the non-control result +// of the dst operation, and +// - for the dst operation, one operand is inserted to the front of the +// operand list. This operand is using the first result of the src +// operation. +// TODO(fengliuai): Preserve the order of the results and operands if +// necessary. +absl::Status ImporterBase::AddBackedges() { + for (auto it : back_edge_dst_inputs_) { + BackEdge& edge = it.second; + if (!edge.src->IsNextIteration() || !edge.dst->IsMerge()) { + return errors::FailedPrecondition( + "Invalid backedge; should be from NextIteration to Merge!"); + } + auto* sink = node_values_[edge.src->id()]; + auto* dst = node_values_[edge.dst->id()]; + TF_RETURN_IF_ERROR(AddBackedge(sink, dst, edge.dst_input)); + } + return absl::OkStatus(); +} + +absl::Status ImporterBase::AddBackedge(mlir::Operation* sink, + mlir::Operation* dst, int dst_input) { + // Get the NextIteration.Source operation from the token operand of the sink. + mlir::Operation* source = sink->getOperand(0).getDefiningOp(); + + // Adds the "source" to the operands of the dst by creating a new dst + // operation. + mlir::OperationState state(dst->getLoc(), dst->getName()); + auto num_operands = dst->getNumOperands(); + state.operands.reserve(num_operands + 1); + for (int input = 0, e = num_operands + 1; input != e; ++input) { + if (input < dst_input) { + state.operands.push_back(dst->getOperand(input)); + } else if (input == dst_input) { + state.operands.push_back(source->getResult(0)); + } else { + state.operands.push_back(dst->getOperand(input - 1)); + } + } + state.attributes.assign(dst->getAttrs().begin(), dst->getAttrs().end()); + state.types.assign(dst->getResultTypes().begin(), + dst->getResultTypes().end()); + builder_.setInsertionPoint(dst); + auto* new_dst = builder_.create(state); + + // Replaces the output uses of the old operation by the corresponding + // result of the new operation, and deletes the old operation. + for (unsigned i = 0, e = dst->getNumResults(); i != e; ++i) { + auto new_output = new_dst->getResult(i); + dst->getResult(i).replaceAllUsesWith(new_output); + } + dst->dropAllReferences(); + dst->erase(); + return absl::OkStatus(); +} + +absl::StatusOr ImporterBase::InferLibFunctionType( + const FunctionBody& fbody) { + mlir::Builder builder(context_); + + // The FunctionBody contains a graph with a single-output _Arg node for each + // function argument and a single-input _Retval node for each function return + // value. + // + // We already populated the ShapeRefiner with all the information about the + // shapes of these graph edges, so we just query it to build the corresponding + // MLIR function type signature. + + llvm::SmallVector arg_types; + if (specs_.inputs.empty()) { + arg_types.reserve(fbody.arg_types.size()); + for (auto arg : fbody.arg_nodes) { + // Find node in the graph using the node id instead of using `arg` + // directly because the graph has been cloned. + auto* node = graph_->FindNodeId(arg->id()); + TF_ASSIGN_OR_RETURN(auto type, + InferOutputType(*node, /*idx=*/0, builder)); + arg_types.push_back(type); + } + } else { + arg_types.reserve(fbody.arg_types.size()); + for (const auto& it : llvm::enumerate(specs_.inputs)) { + mlir::Type element_type; + const auto& node_info = it.value().second; + DataType dtype = node_info.imported_dtype; + // Uses the existing output type of the arg node if the data type of the + // the node isn't specified through the import configuration. + if (dtype == DT_INVALID) { + auto arg = fbody.arg_nodes[it.index()]; + auto* node = graph_->FindNodeId(arg->id()); + dtype = node->output_type(0); + if (dtype == DT_INVALID) { + return errors::InvalidArgument("Input ", it.index(), + "has invalid data type"); + } + } + TF_RETURN_IF_ERROR( + ::tensorflow::ConvertDataType(dtype, builder, &element_type)); + if (node_info.shape.unknown_rank()) { + arg_types.push_back(mlir::UnrankedTensorType::get(element_type)); + } else { + llvm::SmallVector shape; + TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape)); + arg_types.push_back(GetTypeFromTFTensorShape(shape, element_type)); + } + } + } + + llvm::SmallVector ret_types; + ret_types.reserve(fbody.ret_types.size()); + for (auto ret : fbody.ret_nodes) { + // Find node in the graph using the node id instead of using `ret` directly + // because the graph has been cloned. + auto* node = graph_->FindNodeId(ret->id()); + TF_ASSIGN_OR_RETURN(auto type, InferInputType(*node, /*idx=*/0, builder)); + ret_types.push_back(type); + } + + return builder.getFunctionType(arg_types, ret_types); +} + +// Stateful helper class to import a TensorFlow model expressed in GraphDef into +// an MLIR Module. +// +// The nodes defined in the graph are converted to a function called +// 'func_name'. All library function definitions are converted to MLIR functions +// in the module. +class GraphDefImporter : public ImporterBase { + public: + // Main entry point: converts the given graph to an MLIR Module. + static absl::StatusOr> Convert( + mlir::MLIRContext* context, const Graph& graph, + const GraphDebugInfo& debug_info, + const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, + std::unordered_map* tf_name_to_mlir_name, + bool disable_crash_analysis = false); + + private: + explicit GraphDefImporter( + const FunctionLibraryDefinition& flib, const GraphDebugInfo& debug_info, + const GraphImportConfig& specs, mlir::ModuleOp module, + std::unordered_map* tf_name_to_mlir_name, + NameUniquifier* function_name_uniquifier) + : ImporterBase(flib, debug_info, specs, module, tf_name_to_mlir_name, + function_name_uniquifier) {} + + // Returns the function signature of the main function of converted MLIR + // module, the input nodes and output nodes. The type and shape information + // for the function arguments are read from `specs`, but the type and shape + // information for the function returns are inferred by the shape refiner in + // ImporterBase. + absl::StatusOr InferMainFunctionType( + const GraphImportConfig& specs, mlir::MLIRContext* context, + absl::InlinedVector* arg_nodes, + absl::InlinedVector* ret_nodes); + + // Returns the function signature of the main function, alongside input and + // output nodes, for function graphs. Arguments and return values are + // determined by node op type. Type and shape information of the function are + // inferred by the shape refiner in ImporterBase. + absl::StatusOr GetArgsRetsAndTypesFromFunctionGraph( + mlir::MLIRContext* context, + absl::InlinedVector* arg_nodes, + absl::InlinedVector* ret_nodes); + + // Finds the graph's target nodes/function's control ret nodes based on + // supplied node names in `control_outputs`. If `control_outputs` are not + // unique or a control ret node is missing, an error will be returned. + absl::Status GetControlRetsFromGraph( + llvm::ArrayRef control_outputs, + absl::InlinedVector* control_ret_nodes); +}; + +absl::StatusOr> GraphDefImporter::Convert( + mlir::MLIRContext* context, const Graph& graph, + const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def, + const GraphImportConfig& specs, + std::unordered_map* tf_name_to_mlir_name, + bool disable_crash_analysis) { + LoadImporterDialects(*context); + mlir::OwningOpRef module = + mlir::ModuleOp::create(mlir::UnknownLoc::get(context)); + NameUniquifier function_name_uniquifier(flib_def); + + // importer.PrepareConvert below will attemp to clone the original `graph` + // via conversion to the graph def first. Convert graph to graph_def here + // first and avoid extra copies later. + auto graph_def = std::make_unique(); + graph.ToGraphDef(graph_def.get(), /*include_flib_def=*/false); + + auto scope_exit = [&]() { + std::function cleanup = []() {}; + if (!disable_crash_analysis) { + static std::atomic counter(0); + uint32 current_file_prefix = counter++; + const auto* graph_crash_handle = crash_analysis::ReportProtoDataOnCrash( + absl::StrCat(current_file_prefix, "_mlir_import_graph.pbtxt"), + *graph_def); + auto reachable_flib = flib_def.ReachableDefinitions(*graph_def); + const auto* flib_crash_handle = crash_analysis::ReportProtoDataOnCrash( + absl::StrCat(current_file_prefix, "_mlir_import_flib.pbtxt"), + reachable_flib.ToProto()); + cleanup = [=]() { + crash_analysis::RemoveReportData(graph_crash_handle); + crash_analysis::RemoveReportData(flib_crash_handle); + }; + } + + return llvm::make_scope_exit(std::move(cleanup)); + }(); + + VLOG(2) << "Importing: " + << ::tensorflow::DumpGraphToFile("tf_mlir_importer_base", graph, + &flib_def); + + GraphDefImporter importer(flib_def, debug_info, specs, module.get(), + tf_name_to_mlir_name, &function_name_uniquifier); + + TF_RETURN_IF_ERROR(importer.PrepareConvert(graph, std::move(graph_def))); + + mlir::FunctionType func_type; + absl::InlinedVector arg_nodes; + absl::InlinedVector ret_nodes; + absl::InlinedVector control_ret_nodes; + llvm::SmallVector attrs; + if (specs.graph_as_function) { + if (specs.prune_unused_nodes || !specs.inputs.empty() || + !specs.outputs.empty()) + return errors::InvalidArgument( + "Pruning of graph is currently unsupported when the main graph is " + "converted to a function."); + + TF_ASSIGN_OR_RETURN(func_type, + importer.GetArgsRetsAndTypesFromFunctionGraph( + context, &arg_nodes, &ret_nodes)); + + TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs, + &control_ret_nodes)); + + mlir::Builder b(context); + std::string s; + llvm::raw_string_ostream ss(s); + auto node_name = [&](const OutputTensor& tensor) { + ss << tensor.node->name(); + }; + llvm::interleave(arg_nodes, ss, node_name, ","); + auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str())); + s.clear(); + llvm::interleave(ret_nodes, ss, node_name, ","); + auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str())); + s.clear(); + llvm::interleave(specs.control_outputs, ss, ","); + auto control_outputs = + b.getNamedAttr("control_outputs", b.getStringAttr(ss.str())); + + // Under `graph_as_function` mode, `tf.entry_function` is always set as it + // is assumed feed, fetch, and target nodes are set correctly. + attrs.push_back(b.getNamedAttr( + "tf.entry_function", + b.getDictionaryAttr({inputs, outputs, control_outputs}))); + if (!specs.xla_compile_device_type.empty()) { + attrs.push_back( + b.getNamedAttr("_xla_compile_device_type", + b.getStringAttr(specs.xla_compile_device_type))); + } + attrs.push_back(b.getNamedAttr("allow_soft_placement", + b.getBoolAttr(specs.enable_soft_placement))); + } else { + // Collects the argument and return nodes by looking up the node names + // specified by the user. + TF_ASSIGN_OR_RETURN(func_type, importer.InferMainFunctionType( + specs, context, &arg_nodes, &ret_nodes)); + + TF_RETURN_IF_ERROR(importer.GetControlRetsFromGraph(specs.control_outputs, + &control_ret_nodes)); + + // TODO(prakalps): Refactor to keep tf.entry_function attribute encoding and + // decoding in a centralized place. + // Record the input and output mapping. + if (!specs.inputs.empty() || !specs.outputs.empty() || + !specs.control_outputs.empty()) { + mlir::Builder b(context); + std::string s; + llvm::raw_string_ostream ss(s); + llvm::interleave( + specs.inputs, ss, + [&](const std::pair& v) { ss << v.first; }, + ","); + auto inputs = b.getNamedAttr("inputs", b.getStringAttr(ss.str())); + s.clear(); + llvm::interleave(specs.outputs, ss, ","); + auto outputs = b.getNamedAttr("outputs", b.getStringAttr(ss.str())); + s.clear(); + llvm::interleave(specs.control_outputs, ss, ","); + auto control_outputs = + b.getNamedAttr("control_outputs", b.getStringAttr(ss.str())); + + attrs.push_back(b.getNamedAttr( + "tf.entry_function", + b.getDictionaryAttr({inputs, outputs, control_outputs}))); + } + } + + // Record version info. + PopulateTfVersions(module.get(), graph.versions()); + + const llvm::StringRef& graph_func_name = + specs.graph_func_name.empty() ? kImportModelDefaultGraphFuncName + : specs.graph_func_name; + TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(graph_func_name, func_type, + arg_nodes, ret_nodes, + control_ret_nodes, attrs)); + if (specs.convert_all_functions_to_mlir) { + auto fn_names = graph.flib_def().ListFunctionNames(); + for (const auto& fn_name : fn_names) { + TF_RETURN_IF_ERROR(importer.ConvertLibFunction(fn_name)); + } + } + TF_RETURN_IF_ERROR(importer.ImporterBase::ConvertDeferredFunctions()); + + // Mark main function public, others private. + for (auto function : module.get().getOps()) { + auto visibility = function.getName() == graph_func_name + ? mlir::func::FuncOp::Visibility::Public + : mlir::func::FuncOp::Visibility::Private; + function.setVisibility(visibility); + } + VLOG(2) << "Imported: " + << tensorflow::DumpMlirOpToFile("tf_mlir_imported_base", + module.get()); + return module; +} + +absl::StatusOr GraphDefImporter::InferMainFunctionType( + const GraphImportConfig& specs, mlir::MLIRContext* context, + absl::InlinedVector* arg_nodes, + absl::InlinedVector* ret_nodes) { + // Find all the input nodes and output nodes. + // Feeds have been remapped to single output nodes (Placeholder), so an exact + // name match is sufficient. + absl::flat_hash_map inputs; + for (const auto& input_and_idx : llvm::enumerate(specs.inputs)) { + TensorId tensor = ParseTensorName(input_and_idx.value().first); + auto remapped_it = remapped_feeds_.find(tensor); + if (remapped_it != remapped_feeds_.end()) { + inputs.insert({remapped_it->second, input_and_idx.index()}); + } else { + inputs.insert({tensor.node(), input_and_idx.index()}); + } + } + + absl::flat_hash_set output_node_names; + std::vector outputs; + output_node_names.reserve(specs.outputs.size()); + for (const auto& output : specs.outputs) { + TensorId tensor = ParseTensorName(output); + auto remapped_it = remapped_feeds_.find(tensor); + if (remapped_it != remapped_feeds_.end()) { + output_node_names.insert(remapped_it->second); + outputs.push_back({remapped_it->second, 0}); + } else { + output_node_names.insert(tensor.node()); + outputs.push_back(tensor); + } + } + + if (!inputs.empty() || !outputs.empty()) { + arg_nodes->resize(inputs.size()); + ret_nodes->resize(outputs.size()); + + for (Node* n : GetOrderedNodes()) { + // Handle inputs/arguments. + auto input_it = inputs.find(n->name()); + if (input_it != inputs.end()) { + (*arg_nodes)[input_it->second] = {n, 0}; + } + + // Handle outputs/returns. + if (output_node_names.contains(n->name())) { + for (int i = 0, e = outputs.size(); i != e; ++i) { + TensorId tensor = outputs[i]; + if (n->name() != tensor.node()) continue; + (*ret_nodes)[i] = {n, tensor.index()}; + } + } + } + } + + // Starts to construct the function type. + mlir::Builder builder(context); + llvm::SmallVector arg_types; + arg_types.reserve(specs.inputs.size()); + int i = 0; + for (const auto& it : specs.inputs) { + Node* arg_node = arg_nodes->at(i).node; + if (arg_node == nullptr) { + return errors::InvalidArgument("Input ", it.first, + " was not found in graph"); + } + mlir::Type element_type; + const auto& node_info = it.second; + DataType imported_dtype = node_info.imported_dtype; + // Uses the existing output type of the arg node if the data type of the + // the node isn't specified through the import configuration. + if (imported_dtype == DT_INVALID) { + imported_dtype = arg_node->output_type(0); + if (imported_dtype == DT_INVALID) { + return errors::InvalidArgument("Input ", i, "has invalid data type"); + } + } + // Check if we have subtypes first + if (!node_info.subtypes.empty()) { + std::vector subtypes; + for (const auto& st : node_info.subtypes) { + mlir::Type st_data_type; + llvm::SmallVector shape; + TF_RETURN_IF_ERROR(ConvertToMlirShape(st.shape, &shape)); + TF_RETURN_IF_ERROR( + ConvertDataType(st.imported_dtype, builder, &st_data_type)); + subtypes.push_back(GetTypeFromTFTensorShape(shape, st_data_type)); + } + if (imported_dtype == DT_RESOURCE) { + element_type = + mlir::TF::ResourceType::get(subtypes, builder.getContext()); + } else if (imported_dtype == DT_VARIANT) { + element_type = + mlir::TF::VariantType::get(subtypes, builder.getContext()); + } else { + return errors::InvalidArgument(DataType_Name(imported_dtype), + " takes no subtypes."); + } + } else { + TF_RETURN_IF_ERROR( + ConvertDataType(imported_dtype, builder, &element_type)); + } + if (node_info.shape.unknown_rank()) { + arg_types.push_back(mlir::UnrankedTensorType::get(element_type)); + } else { + llvm::SmallVector shape; + TF_RETURN_IF_ERROR(ConvertToMlirShape(node_info.shape, &shape)); + arg_types.push_back(GetTypeFromTFTensorShape(shape, element_type)); + } + i++; + } + + llvm::SmallVector ret_types; + ret_types.reserve(specs.outputs.size()); + for (int i = 0, e = specs.outputs.size(); i != e; ++i) { + if (ret_nodes->at(i).node == nullptr) { + return errors::InvalidArgument("Output ", specs.outputs[i], + " was not found in graph"); + } + } + for (const auto& ret : *ret_nodes) { + if (ret.node->num_outputs() <= ret.index) { + return errors::InvalidArgument("Invalid output index ", ret.index, + " specified for node: ", ret.node->name()); + } + TF_ASSIGN_OR_RETURN(auto type, + InferOutputType(*ret.node, ret.index, builder)); + ret_types.push_back(type); + } + + return builder.getFunctionType(arg_types, ret_types); +} + +absl::StatusOr +GraphDefImporter::GetArgsRetsAndTypesFromFunctionGraph( + mlir::MLIRContext* context, absl::InlinedVector* arg_nodes, + absl::InlinedVector* ret_nodes) { + auto add_node = [](Node* node, absl::InlinedVector* nodes) { + auto* attr = node->attrs().Find("index"); + if (!attr) + return errors::InvalidArgument(node->type_string(), " node '", + node->name(), + "' is missing attribute 'index'"); + + auto index = attr->i(); + const int num_nodes = nodes->size(); + if (num_nodes < index + 1) nodes->resize(index + 1); + + if ((*nodes)[index].node != nullptr) + return errors::InvalidArgument(node->type_string(), " node '", + node->name(), "' has attribute 'index' ", + index, " that conflicts with node '", + (*nodes)[index].node->name(), "'"); + (*nodes)[index] = {node, 0}; + + return absl::OkStatus(); + }; + + // Collect arg and ret nodes from graph. + for (auto* node : GetOrderedNodes()) + if (node->IsArg()) + TF_RETURN_IF_ERROR(add_node(node, arg_nodes)); + else if (node->IsRetval()) + TF_RETURN_IF_ERROR(add_node(node, ret_nodes)); + + // Collect arg and ret types and create function type. + mlir::Builder builder(context); + llvm::SmallVector arg_types; + arg_types.reserve(arg_nodes->size()); + for (const auto& arg_node_and_idx : llvm::enumerate(*arg_nodes)) { + auto& arg_node = arg_node_and_idx.value(); + if (arg_node.node == nullptr) + return errors::InvalidArgument("Graph missing _Arg at index ", + arg_node_and_idx.index()); + + TF_ASSIGN_OR_RETURN(auto type, + InferOutputType(*arg_node.node, /*idx=*/0, builder)); + arg_types.push_back(type); + } + + llvm::SmallVector ret_types; + ret_types.reserve(ret_nodes->size()); + for (const auto& ret_node_and_idx : llvm::enumerate(*ret_nodes)) { + auto& ret_node = ret_node_and_idx.value(); + if (ret_node.node == nullptr) + return errors::InvalidArgument("Graph missing _Retval at index ", + ret_node_and_idx.index()); + + TF_ASSIGN_OR_RETURN(auto type, + InferInputType(*ret_node.node, /*idx=*/0, builder)); + ret_types.push_back(type); + } + + return builder.getFunctionType(arg_types, ret_types); +} + +absl::Status GraphDefImporter::GetControlRetsFromGraph( + llvm::ArrayRef control_outputs, + absl::InlinedVector* control_ret_nodes) { + if (control_outputs.empty()) return absl::OkStatus(); + + llvm::SmallDenseMap controls_to_idx; + for (const auto& control_and_idx : llvm::enumerate(control_outputs)) + controls_to_idx.insert({control_and_idx.value(), control_and_idx.index()}); + + if (controls_to_idx.size() != control_outputs.size()) + return errors::InvalidArgument("Control outputs must be unique"); + + control_ret_nodes->resize(controls_to_idx.size()); + + for (auto* node : GetOrderedNodes()) { + auto it = controls_to_idx.find(node->name()); + if (it != controls_to_idx.end()) (*control_ret_nodes)[it->second] = node; + } + + for (auto node_and_name : llvm::zip(*control_ret_nodes, control_outputs)) + if (std::get<0>(node_and_name) == nullptr) + return errors::InvalidArgument( + "Control output '", std::get<1>(node_and_name), "' is missing"); + + return absl::OkStatus(); +} + +absl::StatusOr> ConvertGraphToTfExecutor( + const Graph& graph, const GraphDebugInfo& debug_info, + const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, + mlir::MLIRContext* context, + std::unordered_map* tf_name_to_mlir_name) { + // TODO(jpienaar): Remove need to const_cast. + if (specs.upgrade_legacy) { + TF_RETURN_IF_ERROR( + UpgradeLegacyGraph(const_cast(&graph), + const_cast(&flib_def), + specs.restrict_functionalization_to_compiled_nodes)); + } + + std::unordered_map local_tf_name_to_mlir_name; + TF_ASSIGN_OR_RETURN( + auto module, + GraphDefImporter::Convert(context, graph, debug_info, flib_def, specs, + tf_name_to_mlir_name == nullptr + ? &local_tf_name_to_mlir_name + : tf_name_to_mlir_name)); + + if (specs.set_original_tf_func_name) { + // Set up the original function names in the imported TF MLIR. + mlir::Builder builder(module->getContext()); + mlir::SymbolTable symbol_table(*module); + for (const auto& [tf_name, mlir_name] : + (tf_name_to_mlir_name == nullptr ? local_tf_name_to_mlir_name + : *tf_name_to_mlir_name)) { + auto func_op = symbol_table.lookup(mlir_name); + TF_RET_CHECK(func_op) + << "Graphdef importer should have created a function named " + << mlir_name << "."; + func_op->setAttr("tf._original_func_name", + builder.getStringAttr(tf_name)); + } + } + return module; +} + +} // namespace v2 +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.h b/tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.h new file mode 100644 index 00000000000000..4822edd85f7d90 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.h @@ -0,0 +1,48 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_GRAPH_TO_TF_EXECUTOR_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_GRAPH_TO_TF_EXECUTOR_H_ + +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { +namespace tf2xla { +namespace v2 { + +inline constexpr absl::string_view kImportModelDefaultGraphFuncName = "main"; + +// Given a Graph, returns a MLIR module containing the graph, expressed with +// tf_executor dialect. +absl::StatusOr> ConvertGraphToTfExecutor( + const Graph& graph, const GraphDebugInfo& debug_info, + const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, + mlir::MLIRContext* context, + std::unordered_map* tf_name_to_mlir_name = + nullptr); + +} // namespace v2 +} // namespace tf2xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_GRAPH_TO_TF_EXECUTOR_H_ diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor_test.cc new file mode 100644 index 00000000000000..6c2aaa605ea615 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor_test.cc @@ -0,0 +1,180 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.h" + +#include + +#include +#include +#include + +#include +#include "absl/status/statusor.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "riegeli/bytes/fd_reader.h" // from @riegeli +#include "riegeli/bytes/read_all.h" // from @riegeli +#include "tensorflow/compiler/mlir/register_common_dialects.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/resource_loader.h" +#include "tsl/platform/protobuf.h" + +namespace tensorflow { +namespace tf2xla { +namespace v2 { +namespace { + +using mlir::DialectRegistry; +using mlir::MLIRContext; + +constexpr char kGraphWithFlibDefFileName[] = "graph_with_flib_def.txt"; + +std::string TestDataPath() { + return tensorflow::GetDataDependencyFilepath( + "tensorflow/compiler/mlir/tf2xla/api/v2/testdata/"); +} + +class GraphToTfExecutorTest : public ::testing::Test { + public: + GraphToTfExecutorTest() { + mlir::RegisterCommonToolingDialects(registry_); + context_.appendDialectRegistry(registry_); + context_.loadAllAvailableDialects(); + } + + GraphDef CreateGraphDef(std::string graphdef_filename) { + std::string file_path = TestDataPath() + graphdef_filename; + std::string contents; + GraphDef graph_def; + auto status = riegeli::ReadAll(riegeli::FdReader(file_path), contents); + if (!status.ok()) { + return graph_def; + } + tsl::protobuf::TextFormat::ParseFromString(contents, &graph_def); + return graph_def; + } + + int CountNumberOfFunctionsInModule(mlir::ModuleOp module) { + int count = 0; + for (auto unused : module.getOps()) { + (void)unused; // Avoid unused variable warning + count++; + } + return count; + } + + DialectRegistry registry_; + MLIRContext context_; +}; + +TEST_F(GraphToTfExecutorTest, BasicConvertGraphToTfExecutorPasses) { + Graph graph(OpRegistry::Global()); + GraphDebugInfo debug_info; + FunctionLibraryDefinition flib_def(OpRegistry::Global(), + FunctionDefLibrary()); + GraphImportConfig specs; + GraphDef graph_def = CreateGraphDef("valid_graph.txt"); + GraphConstructorOptions opts; + TF_ASSERT_OK(ConvertGraphDefToGraph(opts, graph_def, &graph)); + + TF_ASSERT_OK( + ConvertGraphToTfExecutor(graph, debug_info, flib_def, specs, &context_)); +} + +TEST_F( + GraphToTfExecutorTest, + ConvertGraphToTfExecutorConvertAllFunctionsTrueConvertsAllFunctionsInFlibDef) { + Graph graph(OpRegistry::Global()); + GraphDebugInfo debug_info; + FunctionLibraryDefinition flib_def(OpRegistry::Global(), + FunctionDefLibrary()); + GraphDef graph_def = CreateGraphDef(kGraphWithFlibDefFileName); + GraphConstructorOptions opts; + TF_ASSERT_OK(ConvertGraphDefToGraph(opts, graph_def, &graph)); + GraphImportConfig specs; + specs.convert_all_functions_to_mlir = true; + + absl::StatusOr> result = + ConvertGraphToTfExecutor(graph, debug_info, graph.flib_def(), specs, + &context_); + + // should equal main + 4 functions in flib_def + ASSERT_EQ(CountNumberOfFunctionsInModule(result->get()), 5); +} + +TEST_F( + GraphToTfExecutorTest, + ConvertGraphToTfExecutorConvertAllFunctionsFalseOnlyConvertsFunctionsReferencedInGraph) { + Graph graph(OpRegistry::Global()); + GraphDebugInfo debug_info; + FunctionLibraryDefinition flib_def(OpRegistry::Global(), + FunctionDefLibrary()); + GraphDef graph_def = CreateGraphDef(kGraphWithFlibDefFileName); + GraphConstructorOptions opts; + TF_ASSERT_OK(ConvertGraphDefToGraph(opts, graph_def, &graph)); + GraphImportConfig specs; + specs.convert_all_functions_to_mlir = false; + + absl::StatusOr> result = + ConvertGraphToTfExecutor(graph, debug_info, graph.flib_def(), specs, + &context_); + + // should equal main + 2 functions referenced by nodes in the graph via the + // "f" attr. + ASSERT_EQ(CountNumberOfFunctionsInModule(result->get()), 3); +} + +TEST_F(GraphToTfExecutorTest, + ConvertGraphToTfExecutorPopulatesTfNameToMlirNameMap) { + Graph graph(OpRegistry::Global()); + GraphDebugInfo debug_info; + FunctionLibraryDefinition flib_def(OpRegistry::Global(), + FunctionDefLibrary()); + GraphDef graph_def = CreateGraphDef(kGraphWithFlibDefFileName); + GraphConstructorOptions opts; + TF_ASSERT_OK(ConvertGraphDefToGraph(opts, graph_def, &graph)); + GraphImportConfig specs; + std::unordered_map tf_name_to_mlir_name; + + TF_ASSERT_OK(ConvertGraphToTfExecutor(graph, debug_info, graph.flib_def(), + specs, &context_, + &tf_name_to_mlir_name)); + + std::unordered_set result_set; + for (const auto& pair : tf_name_to_mlir_name) { + result_set.insert(pair.first); + } + // Converted functions referenced by nodes in the graph via the + // "f" attr. These are before they are converted to their corresponding MLIR + // name. + std::unordered_set expected_set = { + "__inference__traced_save_45", "__inference__traced_restore_57"}; + EXPECT_EQ(result_set, expected_set); +} + +} // namespace +} // namespace v2 +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.cc index 7c2a16a55ef41d..ca9b7d27772e19 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.cc @@ -23,30 +23,27 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" -#include "absl/strings/str_cat.h" #include "absl/types/variant.h" #include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/StringRef.h" #include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" -#include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h" #include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.h" #include "tensorflow/compiler/mlir/tf2xla/internal/compilation_timer.h" -#include "tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.h" #include "tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.h" #include "tensorflow/compiler/mlir/tf2xla/internal/reproducer.pb.h" #include "tensorflow/compiler/tf2xla/layout_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "xla/client/compile_only_client.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/shape.h" +#include "xla/tsl/framework/device_type.h" #include "xla/tsl/lib/monitoring/sampler.h" #include "xla/xla.pb.h" -#include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" #include "tensorflow/core/util/debug_data_dumper.h" -#include "tensorflow/core/util/dump_graph.h" -#include "tsl/platform/error_logging.h" #include "tsl/platform/errors.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/statusor.h" @@ -169,8 +166,8 @@ absl::StatusOr LegalizeMlirToHlo( if (ShouldFallbackToGraphCompiler(computation)) { TF_RETURN_IF_ERROR(tf2xla::v1::CompileTensorflowGraphToHlo( computation, metadata, use_tuple_args, shape_determination_fns, - arg_shapes, arg_core_mapping, per_core_arg_shapes, client, - compilation_result.get())); + arg_shapes, tsl::DeviceType(device_type.str()), arg_core_mapping, + per_core_arg_shapes, client, compilation_result.get())); DumpHloCompilationResult("legalize_tf_fallback.hlo", compilation_result.get()) diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc index 1cbc08269d0eb0..63fe03309ad290 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc @@ -22,15 +22,22 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tf2xla/api/v2/testing/compile_mlir.h" #include "tensorflow/compiler/mlir/tf2xla/internal/test_matchers.h" -#include "tensorflow/compiler/mlir/tf2xla/internal/utils/test_metadata_config.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/client/client_library.h" +#include "xla/shape.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/lib/monitoring/test_utils.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" #include "tensorflow/core/lib/monitoring/test_utils.h" #include "tensorflow/core/platform/env.h" @@ -40,18 +47,20 @@ limitations under the License. #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" #include "tensorflow/core/util/debug_data_dumper.h" #include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" namespace tensorflow { namespace tf2xla { namespace v2 { using ::tensorflow::monitoring::testing::CellReader; +using tensorflow::tf2xla::v2::testing::CompileMlirModule; using ::testing::Not; using ::testing::TestWithParam; using tpu::FunctionToHloArgs; -using tpu::MlirToHloArgs; using tpu::ShardingAndIndex; using tpu::TPUCompileMetadataProto; +using tsl::testing::TmpDir; static constexpr char kCompilationTimeStreamzName[] = "/tensorflow/core/tf2xla/api/v2/phase2_compilation_time"; @@ -102,38 +111,6 @@ static constexpr char kUnsupportedMlirBridgeModuleStr[] = R"( } })"; -absl::StatusOr CompileMlirModule( - const char* mlir_module_str, - ConfigProto::Experimental::MlirBridgeRollout rollout_state) { - MlirToHloArgs mlir_to_hlo_args; - mlir_to_hlo_args.rollout_state = rollout_state; - mlir_to_hlo_args.mlir_module = mlir_module_str; - - se::Platform* platform = - se::PlatformManager::PlatformWithName("Host").value(); - auto client = - xla::ClientLibrary::GetOrCreateCompileOnlyClient(platform).value(); - - std::vector arg_shapes; - TPUCompileMetadataProto metadata_proto; - // Configure metadata requires parsing the module and if we are testing a - // failure, we ignore this particular set up error assuming we'll not get - // far enough to need valid metadata. - tensorflow::tf2xla::internal::ConfigureMetadata(mlir_module_str, arg_shapes, - metadata_proto) - .IgnoreError(); - bool use_tuple_args = true; - std::vector arg_core_mapping; - std::vector> per_core_arg_shapes; - std::vector> custom_legalization_passes; - - return LegalizeMlirToHlo(mlir_to_hlo_args, metadata_proto, use_tuple_args, - /*device_type=*/"XLA_TPU_JIT", - custom_legalization_passes, - /*shape_determination_fns=*/{}, arg_shapes, - &arg_core_mapping, &per_core_arg_shapes, client); -} - TEST(LegalizeTFTest, RecordsStreamzForSuccessfulLegalizeWithMlirBridge) { CellReader compilation_status(kCompilationStatusStreamzName); @@ -207,7 +184,7 @@ INSTANTIATE_TEST_SUITE_P( TEST(LegalizeTFTest, DumpsProducedHLO) { Env* env = Env::Default(); - std::string test_dir = testing::TmpDir(); + std::string test_dir = TmpDir(); setenv("TF_DUMP_GRAPH_PREFIX", test_dir.c_str(), /*overwrite=*/1); setenv("TF_DUMP_GRAPH_NAME_FILTER", "*", 1); DEBUG_DATA_DUMPER()->LoadEnvvars(); diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test_gpu.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test_gpu.cc new file mode 100644 index 00000000000000..62ecc87711d385 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test_gpu.cc @@ -0,0 +1,83 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tensorflow/compiler/mlir/tf2xla/api/v2/testing/compile_mlir.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { +namespace tf2xla { +namespace v2 { + +// These test are in a separate file because they requires separate build +// environments (--config=cuda). +using tensorflow::tf2xla::v2::testing::CompileMlirModule; +using tsl::testing::StatusIs; + +// MLIR which is legalize only with the right device type. +// The mlir is generated by running +// tensorflow/compiler/mlir/tf-opt -tf-xla-call-module-serialization +// +// module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, +// producer = 268 : i32}} { +// func.func private @_jit_sin(%arg0: tensor) -> tensor { +// %0 = stablehlo.sine %arg0 : tensor +// return %0 : tensor +// } +// func.func @main(%arg0: tensor) -> tensor<*xf32> { +// %0 = "tf.XlaCallModule"(%arg0) {Sout = [#tf_type.shape<*>], device = +// "", dim_args_spec = [], _entry_function = @_jit_sin, module = "", +// platforms = ["CUDA"], version = 6 : i64} : (tensor) -> +// tensor<*xf32> +// func.return %0 : tensor<*xf32> +// } +// } +// +static constexpr char kGpuMlirModuleStr[] = R"( +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + func.func @main(%arg0: tensor) -> tensor<*xf32> { + %0 = "tf.XlaCallModule"(%arg0) {Sout = [#tf_type.shape<*>], device = "", dim_args_spec = [], module = "ML\EFR\03MLIRxxx-trunk\00\01\17\05\01\05\01\03\05\03\07\07\t\0B\03K5\07\01\1B\07\0B\13\0B3\0B\0B\0B\0B\0F\0B\13\0B\03\1B\0F\1B\0B\0B\0B\0B\0B\0F\13\0B\0B\0B\0B\03\07\0F\17\07\02\A7\1F\05\0D\03\03\03\07\05\0F\03\0B\0B\1B\0D'\0F)\031\113\05\11\05\13\05\15\05\17\1D\15\17\05\19\17\19\EF\01\05\1B\03\03\1D\0D\05\1F!#%\1D\1D\1D\1F\1D!\1D##\03\03\03+\0D\03-/\1D%\1D'\1D)\1D+)\01\05\11\03\01\03\01\t\04A\05\01\11\01\05\07\03\01\05\03\11\01\t\05\03\05\0B\03\01\01\05\06\13\03\01\03\01\07\04\01\03\03\06\03\01\05\01\00\9A\04-\0F\0B\03!\1B\1D\05\1B\83/\1F\15\1D\15\11\13\15\11\11\0F\0B\11builtin\00vhlo\00module\00func_v1\00sine_v1\00return_v1\00sym_name\00jit_sin\00arg_attrs\00function_type\00res_attrs\00sym_visibility\00jit(sin)/jit(main)/sin\00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\00jax.arg_info\00x\00mhlo.sharding\00{replicated}\00jax.result_info\00\00main\00public\00", platforms = ["CUDA"], version = 6 : i64} : (tensor) -> tensor<*xf32> + func.return %0 : tensor<*xf32> + } +})"; + +TEST(LegalizeTFTest, RightDeviceTypeShallPass) { + TF_ASSERT_OK_AND_ASSIGN( + XlaCompiler::CompilationResult result, + CompileMlirModule( + kGpuMlirModuleStr, + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED, + "XLA_GPU_JIT")); +} + +TEST(LegalizeTFTest, WrongDeviceTypeShallFail) { + absl::StatusOr result = CompileMlirModule( + kGpuMlirModuleStr, + ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED, + "XLA_TPU_JIT"); + EXPECT_THAT(result, StatusIs(absl::StatusCode::kNotFound)); +} + +} // namespace v2 +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/testdata/graph_with_flib_def.txt b/tensorflow/compiler/mlir/tf2xla/api/v2/testdata/graph_with_flib_def.txt new file mode 100644 index 00000000000000..42361060a5e8f2 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/testdata/graph_with_flib_def.txt @@ -0,0 +1,1535 @@ +node { + name: "Variable" + op: "VarHandleOp" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "debug_name" + value { + s: "Variable/" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + attr { + key: "shared_name" + value { + s: "Variable" + } + } +} +node { + name: "Variable/Read/ReadVariableOp" + op: "ReadVariableOp" + input: "Variable" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } +} +node { + name: "NoOp" + op: "NoOp" +} +node { + name: "Const" + op: "Const" + device: "/device:CPU:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "\n3\n\005\010\001\022\001v\n\n\010\002\022\006callee\n\n\010\003\022\006caller\n\016\010\004\022\nsignatures*\002\010\001\n>\0228\n\016VARIABLE_VALUE\022\010Variable\032\034v/.ATTRIBUTES/VARIABLE_VALUE*\002\010\001\n\017\n\013\010\005\022\007trace_0*\000\n\017\n\013\010\006\022\007trace_0*\000\n\002*\000\n\002*\000\n\002*\000" + } + } + } +} +node { + name: "saver_filename" + op: "Placeholder" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "shape" + value { + shape { + } + } + } +} +node { + name: "StatefulPartitionedCall" + op: "StatefulPartitionedCall" + input: "saver_filename" + input: "Variable" + input: "Const" + attr { + key: "Tin" + value { + list { + type: DT_STRING + type: DT_RESOURCE + type: DT_STRING + } + } + } + attr { + key: "Tout" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "_collective_manager_ids" + value { + list { + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "_read_only_resource_inputs" + value { + list { + } + } + } + attr { + key: "config_proto" + value { + s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0002\002J\0008\001\202\001\000\222\001\002J\000" + } + } + attr { + key: "f" + value { + func { + name: "__inference__traced_save_45" + } + } + } +} +node { + name: "StatefulPartitionedCall_1" + op: "StatefulPartitionedCall" + input: "saver_filename" + input: "Variable" + attr { + key: "Tin" + value { + list { + type: DT_STRING + type: DT_RESOURCE + } + } + } + attr { + key: "Tout" + value { + list { + type: DT_STRING + } + } + } + attr { + key: "_collective_manager_ids" + value { + list { + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "_read_only_resource_inputs" + value { + list { + } + } + } + attr { + key: "config_proto" + value { + s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0002\002J\0008\001\202\001\000\222\001\002J\000" + } + } + attr { + key: "f" + value { + func { + name: "__inference__traced_restore_57" + } + } + } +} +library { + function { + signature { + name: "__inference_callee_10" + input_arg { + name: "x" + type: DT_FLOAT + } + input_arg { + name: "readvariableop_resource" + type: DT_RESOURCE + handle_data { + dtype: DT_FLOAT + shape { + } + } + } + output_arg { + name: "identity" + type: DT_FLOAT + } + output_arg { + name: "identity_1" + type: DT_FLOAT + } + is_stateful: true + control_output: "ReadVariableOp" + } + node_def { + name: "ReadVariableOp" + op: "ReadVariableOp" + input: "readvariableop_resource" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "Identity" + op: "Identity" + input: "x" + input: "^NoOp" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "Identity_1" + op: "Identity" + input: "ReadVariableOp:value:0" + input: "^NoOp" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "NoOp" + op: "NoOp" + input: "^ReadVariableOp" + attr { + key: "_output_shapes" + value { + list { + } + } + } + } + ret { + key: "identity" + value: "Identity:output:0" + } + ret { + key: "identity_1" + value: "Identity_1:output:0" + } + attr { + key: "_construction_context" + value { + s: "kEagerRuntime" + } + } + attr { + key: "_input_shapes" + value { + list { + shape { + } + shape { + } + } + } + } + control_ret { + key: "ReadVariableOp" + value: "ReadVariableOp" + } + arg_attr { + key: 0 + value { + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "_user_specified_name" + value { + s: "x" + } + } + } + } + arg_attr { + key: 1 + value { + attr { + key: "_user_specified_name" + value { + s: "resource" + } + } + } + } + } + function { + signature { + name: "__inference__traced_save_45" + input_arg { + name: "file_prefix" + type: DT_STRING + } + input_arg { + name: "read_disablecopyonread_variable" + type: DT_RESOURCE + handle_data { + dtype: DT_FLOAT + shape { + } + } + } + input_arg { + name: "savev2_const" + type: DT_STRING + } + output_arg { + name: "identity_3" + type: DT_STRING + } + is_stateful: true + control_output: "MergeV2Checkpoints" + control_output: "Read/DisableCopyOnRead" + control_output: "Read/ReadVariableOp" + } + node_def { + name: "StaticRegexFullMatch" + op: "StaticRegexFullMatch" + input: "file_prefix" + device: "/device:CPU:*" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "pattern" + value { + s: "^s3://.*" + } + } + } + node_def { + name: "Const" + op: "Const" + device: "/device:CPU:*" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: ".part" + } + } + } + } + node_def { + name: "Const_1" + op: "Const" + device: "/device:CPU:*" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "_temp/part" + } + } + } + } + node_def { + name: "Select" + op: "Select" + input: "StaticRegexFullMatch:output:0" + input: "Const:output:0" + input: "Const_1:output:0" + device: "/device:CPU:*" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "StringJoin" + op: "StringJoin" + input: "file_prefix" + input: "Select:output:0" + device: "/device:CPU:*" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "Read/DisableCopyOnRead" + op: "DisableCopyOnRead" + input: "read_disablecopyonread_variable" + attr { + key: "_output_shapes" + value { + list { + } + } + } + } + node_def { + name: "Read/ReadVariableOp" + op: "ReadVariableOp" + input: "read_disablecopyonread_variable" + input: "^Read/DisableCopyOnRead" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "Identity" + op: "Identity" + input: "Read/ReadVariableOp:value:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "Identity_1" + op: "Identity" + input: "Identity:output:0" + device: "/device:CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "num_shards" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node_def { + name: "ShardedFilename/shard" + op: "Const" + device: "/device:CPU:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } + } + node_def { + name: "ShardedFilename" + op: "ShardedFilename" + input: "StringJoin:output:0" + input: "ShardedFilename/shard:output:0" + input: "num_shards:output:0" + device: "/device:CPU:0" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "SaveV2/tensor_names" + op: "Const" + device: "/device:CPU:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 2 + } + } + string_val: "v/.ATTRIBUTES/VARIABLE_VALUE" + string_val: "_CHECKPOINTABLE_OBJECT_GRAPH" + } + } + } + } + node_def { + name: "SaveV2/shape_and_slices" + op: "Const" + device: "/device:CPU:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 2 + } + } + string_val: "" + string_val: "" + } + } + } + } + node_def { + name: "SaveV2" + op: "SaveV2" + input: "ShardedFilename:filename:0" + input: "SaveV2/tensor_names:output:0" + input: "SaveV2/shape_and_slices:output:0" + input: "Identity_1:output:0" + input: "savev2_const" + device: "/device:CPU:0" + attr { + key: "_has_manual_control_dependencies" + value { + b: true + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + type: DT_STRING + } + } + } + } + node_def { + name: "MergeV2Checkpoints/checkpoint_prefixes" + op: "Pack" + input: "ShardedFilename:filename:0" + input: "^SaveV2" + device: "/device:CPU:0" + attr { + key: "N" + value { + i: 1 + } + } + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node_def { + name: "MergeV2Checkpoints" + op: "MergeV2Checkpoints" + input: "MergeV2Checkpoints/checkpoint_prefixes:output:0" + input: "file_prefix" + device: "/device:CPU:0" + attr { + key: "_has_manual_control_dependencies" + value { + b: true + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } + } + node_def { + name: "Identity_2" + op: "Identity" + input: "file_prefix" + input: "^MergeV2Checkpoints" + device: "/device:CPU:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "Identity_3" + op: "Identity" + input: "Identity_2:output:0" + input: "^NoOp" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "NoOp" + op: "NoOp" + input: "^MergeV2Checkpoints" + input: "^Read/DisableCopyOnRead" + input: "^Read/ReadVariableOp" + attr { + key: "_output_shapes" + value { + list { + } + } + } + } + ret { + key: "identity_3" + value: "Identity_3:output:0" + } + attr { + key: "_construction_context" + value { + s: "kEagerRuntime" + } + } + attr { + key: "_input_shapes" + value { + list { + shape { + } + shape { + } + shape { + } + } + } + } + control_ret { + key: "MergeV2Checkpoints" + value: "MergeV2Checkpoints" + } + control_ret { + key: "Read/DisableCopyOnRead" + value: "Read/DisableCopyOnRead" + } + control_ret { + key: "Read/ReadVariableOp" + value: "Read/ReadVariableOp" + } + arg_attr { + key: 0 + value { + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "_user_specified_name" + value { + s: "file_prefix" + } + } + } + } + arg_attr { + key: 1 + value { + attr { + key: "_user_specified_name" + value { + s: "Variable" + } + } + } + } + arg_attr { + key: 2 + value { + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "_user_specified_name" + value { + s: "Const" + } + } + } + } + } + function { + signature { + name: "__inference_caller_19" + input_arg { + name: "x" + type: DT_FLOAT + } + input_arg { + name: "unknown" + type: DT_RESOURCE + handle_data { + dtype: DT_FLOAT + shape { + } + } + } + output_arg { + name: "identity" + type: DT_FLOAT + } + output_arg { + name: "identity_1" + type: DT_FLOAT + } + is_stateful: true + control_output: "StatefulPartitionedCall" + } + node_def { + name: "StatefulPartitionedCall" + op: "StatefulPartitionedCall" + input: "x" + input: "unknown" + attr { + key: "Tin" + value { + list { + type: DT_FLOAT + type: DT_RESOURCE + } + } + } + attr { + key: "Tout" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + attr { + key: "_collective_manager_ids" + value { + list { + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + shape { + } + } + } + } + attr { + key: "_read_only_resource_inputs" + value { + list { + i: 1 + } + } + } + attr { + key: "config_proto" + value { + s: "\n\007\n\003CPU\020\001\n\007\n\003GPU\020\0002\002J\0008\001\202\001\000\222\001\002J\000" + } + } + attr { + key: "f" + value { + func { + name: "__inference_callee_10" + } + } + } + } + node_def { + name: "Identity" + op: "Identity" + input: "StatefulPartitionedCall:output:0" + input: "^NoOp" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "Identity_1" + op: "Identity" + input: "StatefulPartitionedCall:output:1" + input: "^NoOp" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "NoOp" + op: "NoOp" + input: "^StatefulPartitionedCall" + attr { + key: "_output_shapes" + value { + list { + } + } + } + } + ret { + key: "identity" + value: "Identity:output:0" + } + ret { + key: "identity_1" + value: "Identity_1:output:0" + } + attr { + key: "_construction_context" + value { + s: "kEagerRuntime" + } + } + attr { + key: "_input_shapes" + value { + list { + shape { + } + shape { + } + } + } + } + control_ret { + key: "StatefulPartitionedCall" + value: "StatefulPartitionedCall" + } + arg_attr { + key: 0 + value { + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "_user_specified_name" + value { + s: "x" + } + } + } + } + arg_attr { + key: 1 + value { + attr { + key: "_user_specified_name" + value { + s: "13" + } + } + } + } + } + function { + signature { + name: "__inference__traced_restore_57" + input_arg { + name: "file_prefix" + type: DT_STRING + } + input_arg { + name: "assignvariableop_variable" + type: DT_RESOURCE + handle_data { + dtype: DT_FLOAT + shape { + } + } + } + output_arg { + name: "identity_2" + type: DT_STRING + } + is_stateful: true + control_output: "AssignVariableOp" + } + node_def { + name: "RestoreV2/tensor_names" + op: "Const" + device: "/device:CPU:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 2 + } + } + string_val: "v/.ATTRIBUTES/VARIABLE_VALUE" + string_val: "_CHECKPOINTABLE_OBJECT_GRAPH" + } + } + } + } + node_def { + name: "RestoreV2/shape_and_slices" + op: "Const" + device: "/device:CPU:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 2 + } + } + string_val: "" + string_val: "" + } + } + } + } + node_def { + name: "RestoreV2" + op: "RestoreV2" + input: "file_prefix" + input: "RestoreV2/tensor_names:output:0" + input: "RestoreV2/shape_and_slices:output:0" + device: "/device:CPU:0" + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + shape { + unknown_rank: true + } + } + } + } + attr { + key: "dtypes" + value { + list { + type: DT_FLOAT + type: DT_STRING + } + } + } + } + node_def { + name: "Identity" + op: "Identity" + input: "RestoreV2:tensors:0" + device: "/device:CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + } + node_def { + name: "AssignVariableOp" + op: "AssignVariableOp" + input: "assignvariableop_variable" + input: "Identity:output:0" + device: "/device:CPU:0" + attr { + key: "_has_manual_control_dependencies" + value { + b: true + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "NoOp" + op: "NoOp" + device: "/device:CPU:0" + attr { + key: "_has_manual_control_dependencies" + value { + b: true + } + } + attr { + key: "_output_shapes" + value { + list { + } + } + } + } + node_def { + name: "Identity_1" + op: "Identity" + input: "file_prefix" + input: "^AssignVariableOp" + input: "^NoOp" + device: "/device:CPU:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "Identity_2" + op: "Identity" + input: "Identity_1:output:0" + input: "^NoOp_1" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + } + node_def { + name: "NoOp_1" + op: "NoOp" + input: "^AssignVariableOp" + attr { + key: "_output_shapes" + value { + list { + } + } + } + } + ret { + key: "identity_2" + value: "Identity_2:output:0" + } + attr { + key: "_construction_context" + value { + s: "kEagerRuntime" + } + } + attr { + key: "_input_shapes" + value { + list { + shape { + } + shape { + } + } + } + } + control_ret { + key: "AssignVariableOp" + value: "AssignVariableOp" + } + arg_attr { + key: 0 + value { + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "_user_specified_name" + value { + s: "file_prefix" + } + } + } + } + arg_attr { + key: 1 + value { + attr { + key: "_user_specified_name" + value { + s: "Variable" + } + } + } + } + } +} +versions { + producer: 2009 + min_consumer: 12 +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/testdata/valid_executor.mlir b/tensorflow/compiler/mlir/tf2xla/api/v2/testdata/valid_executor.mlir new file mode 100644 index 00000000000000..3db375e788a033 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/testdata/valid_executor.mlir @@ -0,0 +1,10 @@ + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + func.func @main() { + tf_executor.graph { + %0:2 = tf_executor.island wraps "tf.Const"() {device = "/job:localhost/replica:0/task:0/device:TPU:0", dtype = "tfdtype$DT_INT32", value = #tf_type : tensor<2xi32>} : () -> tensor<2xi32> loc("Empty/shape") + tf_executor.fetch + } + func.return + } +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/testdata/valid_graph.txt b/tensorflow/compiler/mlir/tf2xla/api/v2/testdata/valid_graph.txt new file mode 100644 index 00000000000000..4eed21fedff195 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/testdata/valid_graph.txt @@ -0,0 +1,44 @@ + node { + name: "Empty/shape" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:TPU:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\200\000\000\000\200\000\000\000" + } + } + } + experimental_debug_info { + } +} +library { +} +versions { + producer: 268 +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/testing/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v2/testing/BUILD index 8639d4023ee6a8..aff5ee5508d4fe 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/testing/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/testing/BUILD @@ -8,6 +8,33 @@ package( ], ) +cc_library( + name = "compile_mlir", + testonly = True, + srcs = ["compile_mlir.cc"], + hdrs = ["compile_mlir.h"], + deps = [ + "//tensorflow/compiler/jit", + "//tensorflow/compiler/mlir/tf2xla/api/v2:legalize_tf", + "//tensorflow/compiler/mlir/tf2xla/internal/utils:test_metadata_config", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core/protobuf:for_core_protos_cc", + "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", + "//tensorflow/core/tpu/kernels:tpu_compile_op_support", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//mlir:Pass", + "@local_xla//xla:shape_util", + "@local_xla//xla/client:client_library", + "@local_xla//xla/stream_executor:platform", + "@local_xla//xla/stream_executor:platform_manager", + ], +) + cc_library( name = "utils", testonly = True, diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/testing/compile_mlir.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/testing/compile_mlir.cc new file mode 100644 index 00000000000000..fa86350614ce82 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/testing/compile_mlir.cc @@ -0,0 +1,81 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tf2xla/api/v2/testing/compile_mlir.h" + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.h" +#include "tensorflow/compiler/mlir/tf2xla/internal/utils/test_metadata_config.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "xla/client/client_library.h" +#include "xla/shape.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { +namespace tf2xla { +namespace v2 { +namespace testing { + +using tpu::MlirToHloArgs; +using tpu::ShardingAndIndex; +using tpu::TPUCompileMetadataProto; + +absl::StatusOr CompileMlirModule( + const char* mlir_module_str, + ConfigProto::Experimental::MlirBridgeRollout rollout_state, + absl::string_view device_type) { + MlirToHloArgs mlir_to_hlo_args; + mlir_to_hlo_args.rollout_state = rollout_state; + mlir_to_hlo_args.mlir_module = mlir_module_str; + + TF_ASSIGN_OR_RETURN(se::Platform * platform, + se::PlatformManager::PlatformWithName("Host")); + TF_ASSIGN_OR_RETURN( + auto client, xla::ClientLibrary::GetOrCreateCompileOnlyClient(platform)); + + std::vector arg_shapes; + TPUCompileMetadataProto metadata_proto; + // Configure metadata requires parsing the module and if we are testing a + // failure, we ignore this particular set up error assuming we'll not get + // far enough to need valid metadata. + tensorflow::tf2xla::internal::ConfigureMetadata(mlir_module_str, arg_shapes, + metadata_proto) + .IgnoreError(); + bool use_tuple_args = true; + std::vector arg_core_mapping; + std::vector> per_core_arg_shapes; + std::vector> custom_legalization_passes; + + return LegalizeMlirToHlo(mlir_to_hlo_args, metadata_proto, use_tuple_args, + device_type, custom_legalization_passes, + /*shape_determination_fns=*/{}, arg_shapes, + &arg_core_mapping, &per_core_arg_shapes, client); +} + +} // namespace testing +} // namespace v2 +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/testing/compile_mlir.h b/tensorflow/compiler/mlir/tf2xla/api/v2/testing/compile_mlir.h new file mode 100644 index 00000000000000..7394fe37a45818 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/testing/compile_mlir.h @@ -0,0 +1,40 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_TESTING_COMPILE_MLIR_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_TESTING_COMPILE_MLIR_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" + +namespace tensorflow { +namespace tf2xla { +namespace v2 { +namespace testing { + +// Compiles the given MLIR module to XLA HLO. +absl::StatusOr CompileMlirModule( + const char* mlir_module_str, + ConfigProto::Experimental::MlirBridgeRollout rollout_state, + absl::string_view device_type = "XLA_TPU_JIT"); + +} // namespace testing +} // namespace v2 +} // namespace tf2xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_API_V2_TESTING_COMPILE_MLIR_H_ diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph_test.cc new file mode 100644 index 00000000000000..9a53e51e4c71e1 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph_test.cc @@ -0,0 +1,110 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h" + +#include + +#include +#include + +#include +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "riegeli/bytes/fd_reader.h" // from @riegeli +#include "riegeli/bytes/read_all.h" // from @riegeli +#include "tensorflow/compiler/mlir/register_common_dialects.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/resource_loader.h" +#include "tsl/platform/protobuf.h" + +namespace tensorflow { +namespace tf2xla { +namespace v2 { +namespace { + +using mlir::DialectRegistry; +using mlir::MLIRContext; +using mlir::ModuleOp; +using mlir::OwningOpRef; + +std::string TestDataPath() { + return tensorflow::GetDataDependencyFilepath( + "tensorflow/compiler/mlir/tf2xla/api/v2/testdata/"); +} + +class TfExecutorToGraphTest : public ::testing::Test { + public: + TfExecutorToGraphTest() { + mlir::RegisterCommonToolingDialects(registry_); + context_.appendDialectRegistry(registry_); + context_.loadAllAvailableDialects(); + } + + absl::StatusOr> CreateMlirModule( + std::string mlir_module_filename) { + std::string mlir_module_path = TestDataPath() + mlir_module_filename; + return mlir::parseSourceFile(mlir_module_path, &context_); + } + + GraphDef CreateGraphDef(std::string graphdef_filename) { + std::string file_path = TestDataPath() + graphdef_filename; + std::string contents; + GraphDef graph_def; + auto status = riegeli::ReadAll(riegeli::FdReader(file_path), contents); + if (!status.ok()) { + return graph_def; + } + tsl::protobuf::TextFormat::ParseFromString(contents, &graph_def); + return graph_def; + } + + DialectRegistry registry_; + MLIRContext context_; + OwningOpRef mlir_module_; +}; + +TEST_F(TfExecutorToGraphTest, ConvertMlirToGraphSucceeds) { + auto valid_executor_module = CreateMlirModule("valid_executor.mlir"); + GraphExportConfig confs; + absl::flat_hash_set control_ret_nodes; + FunctionLibraryDefinition flib_def(OpRegistry::Global(), + FunctionDefLibrary()); + auto result_graph = std::make_unique(flib_def); + + TF_ASSERT_OK(ConvertTfExecutorToGraph(valid_executor_module.value().get(), + confs, &result_graph, &flib_def, + &control_ret_nodes)); + + GraphDef result_graphdef; + result_graph->ToGraphDef(&result_graphdef); + GraphDef expected_graphdef = CreateGraphDef("valid_graph.txt"); + EXPECT_EQ(result_graphdef.DebugString(), expected_graphdef.DebugString()); +} + +} // namespace +} // namespace v2 +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/internal/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/BUILD index ba1f7deb3c491c..e47d20e137bb7d 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/BUILD @@ -56,7 +56,7 @@ tf_cc_test( "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", - "@local_xla//xla/client:xla_computation", + "@local_xla//xla/hlo/builder:xla_computation", "@local_xla//xla/service:hlo_proto_cc", ], ) @@ -139,14 +139,12 @@ cc_library( srcs = ["legalize_tf_to_hlo.cc"], hdrs = ["legalize_tf_to_hlo.h"], deps = [ - ":compilation_timer", ":legalize_tf_mlir", "//tensorflow/compiler/mlir/tf2xla/api/v1:compile_tf_graph", "//tensorflow/compiler/tf2xla:layout_util", "//tensorflow/compiler/tf2xla:xla_helpers", "//tensorflow/compiler/tf2xla:xla_tpu_backend_registration", "//tensorflow/core:framework", - "//tensorflow/core:lib", "//tensorflow/core/platform:status", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "//tensorflow/core/tpu/kernels:tpu_compile_op_support", @@ -312,3 +310,44 @@ tf_proto_library( "//tensorflow/compiler/mlir/tf2xla/api/v2:__pkg__", ], ) + +cc_library( + name = "node_order", + srcs = ["node_order.cc"], + hdrs = ["node_order.h"], + deps = [ + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + ], +) + +tf_cc_test( + name = "node_order_test", + size = "small", + srcs = [ + "node_order_test.cc", + ], + deps = [ + ":node_order", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:sendrecv_ops", + "//tensorflow/core", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:direct_session_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "@com_google_absl//absl/log:check", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.cc b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.cc index e26741e0877d7b..f1f3de668bab70 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_to_hlo.cc @@ -23,18 +23,17 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.h" -#include "tensorflow/compiler/mlir/tf2xla/internal/compilation_timer.h" #include "tensorflow/compiler/mlir/tf2xla/internal/legalize_tf_mlir.h" #include "tensorflow/compiler/tf2xla/layout_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/client/compile_only_client.h" #include "xla/shape.h" +#include "xla/tsl/framework/device_type.h" #include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" -#include "tsl/platform/statusor.h" namespace tensorflow { namespace tf2xla { @@ -75,8 +74,8 @@ absl::StatusOr LegalizeTfToHlo( Status old_bridge_status = v1::CompileTensorflowGraphToHlo( MlirToHloArgs{mlir_compilation.value()}, metadata, use_tuple_args, - shape_determination_fns, arg_shapes, arg_core_mapping, - per_core_arg_shapes, client, compilation_result); + shape_determination_fns, arg_shapes, tsl::DeviceType(device_type.str()), + arg_core_mapping, per_core_arg_shapes, client, compilation_result); if (!old_bridge_status.ok()) { IncrementTfMlirBridgeSecondPhaseCounter( diff --git a/tensorflow/compiler/mlir/tf2xla/internal/logging_hooks_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/logging_hooks_test.cc index 3365a85e5868fb..b874bd2a0dcff6 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/logging_hooks_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/logging_hooks_test.cc @@ -64,7 +64,8 @@ class LoggingHooksTest : public ::testing::Test { test_group_name_ = "TestGroup"; test_dir_ = testing::TmpDir(); - setenv("TF_DUMP_GRAPH_PREFIX", test_dir_.c_str(), 1); + setenv(/*name=*/"TF_DUMP_GRAPH_PREFIX", /*value=*/test_dir_.c_str(), + /*overwrite=*/1); } absl::Status CreateMlirModule(std::string mlir_module_filename) { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/node_order.cc b/tensorflow/compiler/mlir/tf2xla/internal/node_order.cc similarity index 98% rename from tensorflow/compiler/mlir/tensorflow/translate/node_order.cc rename to tensorflow/compiler/mlir/tf2xla/internal/node_order.cc index ac49a477f40f2b..aa72e567cf1237 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/node_order.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/node_order.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/tensorflow/translate/node_order.h" +#include "tensorflow/compiler/mlir/tf2xla/internal/node_order.h" #include #include diff --git a/tensorflow/compiler/mlir/tensorflow/translate/node_order.h b/tensorflow/compiler/mlir/tf2xla/internal/node_order.h similarity index 89% rename from tensorflow/compiler/mlir/tensorflow/translate/node_order.h rename to tensorflow/compiler/mlir/tf2xla/internal/node_order.h index e11f3372a3b228..a6f65006512328 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/node_order.h +++ b/tensorflow/compiler/mlir/tf2xla/internal/node_order.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_NODE_ORDER_H_ -#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_NODE_ORDER_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_NODE_ORDER_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_NODE_ORDER_H_ #include #include @@ -48,4 +48,4 @@ void TopologicalOrdering( } // namespace tensorflow -#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_NODE_ORDER_H_ +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_NODE_ORDER_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/translate/node_order_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/node_order_test.cc similarity index 99% rename from tensorflow/compiler/mlir/tensorflow/translate/node_order_test.cc rename to tensorflow/compiler/mlir/tf2xla/internal/node_order_test.cc index 3224640a105330..cd2ba68b9b08f6 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/node_order_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/node_order_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/tensorflow/translate/node_order.h" +#include "tensorflow/compiler/mlir/tf2xla/internal/node_order.h" #include #include diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD index a33cf4fa995be6..a1f6f41f806f8b 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD @@ -30,6 +30,8 @@ cc_library( ":mark_ops_for_outside_compilation", ":tpu_cluster_formation", ":tpu_sharding_identification_pass", + ":tpu_validate_inputs", + ":tpu_validate_session_inputs", ":verify_clustering_pass", ":xla_broadcast", ":xla_cluster_formation", @@ -359,7 +361,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:sharding_builder", + "@local_xla//xla/hlo/builder:sharding_builder", ], ) @@ -505,3 +507,81 @@ tf_cc_test( "@local_tsl//tsl/platform:statusor", ], ) + +cc_library( + name = "tpu_validate_inputs_utils", + srcs = ["tpu_validate_inputs_utils.cc"], + hdrs = ["tpu_validate_inputs_utils.h"], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "tpu_validate_session_inputs", + srcs = ["tpu_validate_session_inputs.cc"], + textual_hdrs = [ + "clustering_passes.h.inc", + ], + deps = [ + ":clustering_passes_inc_gen", + ":tpu_validate_inputs_utils", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", + "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + ], +) + +cc_library( + name = "tpu_validate_inputs", + srcs = ["tpu_validate_inputs.cc"], + textual_hdrs = [ + "clustering_passes.h.inc", + ], + deps = [ + ":clustering_passes_inc_gen", + ":tpu_validate_inputs_utils", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:attribute_utils", + "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", + "//tensorflow/compiler/mlir/tensorflow:string_util", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_analysis", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/compiler/mlir/tensorflow:tpu_rewrite_device_util", + "//tensorflow/compiler/mlir/tensorflow/transforms:tf_pass_inc_gen", + "//tensorflow/core:framework", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla/client:sharding_builder", + ], +) + +tf_cc_test( + name = "tpu_validate_inputs_utils_test", + srcs = ["tpu_validate_inputs_utils_test.cc"], + deps = [ + ":tpu_validate_inputs_utils", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tf2xla/transforms:test_utils", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h index 85703c2306ad6b..4d91f113a9bafc 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h @@ -67,12 +67,21 @@ CreateXlaBroadcastPass(); std::unique_ptr> CreateTPUShardingIdentificationPass(); +// Creates a pass that validates the inputs to a TPU computation. +std::unique_ptr> +CreateTPUValidateSessionInputsPass(); + +std::unique_ptr> +CreateTPUValidateInputsPass(); + #define GEN_PASS_REGISTRATION #define GEN_PASS_DECL_MARKOPSFOROUTSIDECOMPILATIONPASS #define GEN_PASS_DECL_TPUCLUSTERFORMATIONPASS #define GEN_PASS_DECL_TPUEXTRACTHEADTAILOUTSIDECOMPILATIONPASS #define GEN_PASS_DECL_TPUEXTRACTOUTSIDECOMPILATIONPASS #define GEN_PASS_DECL_TPUSHARDINGIDENTIFICATIONPASS +#define GEN_PASS_DECL_TPUVALIDATEINPUTSPASS +#define GEN_PASS_DECL_TPUVALIDATESESSIONINPUTSPASS #define GEN_PASS_DECL_VERIFYCLUSTERINGPASS #define GEN_PASS_DECL_XLACLUSTERFORMATIONPASS #include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc" diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td index c1c34561ff0eb7..e01599c03e1eac 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td @@ -441,4 +441,26 @@ def TPUShardingIdentificationPass : Pass<"tf-tpu-sharding-identification", "Modu }]; } +def TPUValidateSessionInputsPass : Pass<"tf-tpu-validate-session-inputs", "mlir::ModuleOp"> { + let summary = "Validates inputs to the TPU TF/XLA bridge in session api"; + let description = [{ + This pass checks that the IR has valid input to TPU TF/XLA bridge in session api. + It checks the relations of multiple ops. Properties of single ops are + checked by the 'verify' method of ops. + }]; + + let constructor = "tensorflow::tf2xla::internal::CreateTPUValidateSessionInputsPass()"; +} + +def TPUValidateInputsPass : Pass<"tf-tpu-validate-inputs", "mlir::ModuleOp"> { + let summary = "Validates inputs to the TPU TF/XLA bridge"; + + let description = [{ + This pass checks that the IR has valid input to TPU TF/XLA bridge. + It checks the relations of multiple ops. Properties of single ops are + checked by the 'verify' method of ops. + }]; + + let constructor = "tensorflow::tf2xla::internal::CreateTPUValidateInputsPass()"; +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_sharding_identification_pass.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_sharding_identification_pass.cc index 0d2475c5be5433..cbe88190bada21 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_sharding_identification_pass.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_sharding_identification_pass.cc @@ -49,7 +49,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" -#include "xla/client/sharding_builder.h" +#include "xla/hlo/builder/sharding_builder.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" @@ -129,13 +129,33 @@ bool BinaryOpHasTraitsForSharding(Operation* op) { return false; } -bool DoTypesHaveSameShape(Value value_0, Value value_1) { +bool DoTypesHavePartialSameShape(Value value_0, Value value_1) { auto shape_0 = mlir::dyn_cast_or_null(value_0.getType()); auto shape_1 = mlir::dyn_cast_or_null(value_1.getType()); if (shape_0 && shape_1) { - return shape_0.getShape() == shape_1.getShape(); + if (shape_0.hasStaticShape() && shape_1.hasStaticShape()) + return shape_0.getShape() == shape_1.getShape(); + int i = 0, j = 0; + while (i < shape_0.getShape().size() && j < shape_1.getShape().size()) { + if (shape_0.getShape()[i] != shape_1.getShape()[j] && + !shape_0.isDynamicDim(i) && !shape_1.isDynamicDim(j)) { + return false; + } + if (shape_0.getShape()[i] == shape_1.getShape()[j]) { + i++; + j++; + } else { + if (shape_0.isDynamicDim(i)) { + i++; + } + if (shape_1.isDynamicDim(j)) { + j++; + } + } + } + return i == shape_0.getShape().size() && j == shape_1.getShape().size(); } return false; } @@ -337,7 +357,8 @@ std::optional GetXlaShardingFromArg( } if (BinaryOpHasTraitsForSharding(owner)) { - if (DoTypesHaveSameShape(value_to_visit, owner->getResult(0))) { + if (DoTypesHavePartialSameShape(value_to_visit, + owner->getResult(0))) { next_values_to_visit.push_back(use.getOwner()->getResult(0)); continue; } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_validate_inputs.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs.cc similarity index 76% rename from tensorflow/compiler/mlir/tensorflow/transforms/tpu_validate_inputs.cc rename to tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs.cc index 21f62e41383401..07213b6dadd4bf 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_validate_inputs.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs.cc @@ -12,45 +12,79 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include -#include #include #include #include -#include #include #include +#include "absl/log/log.h" #include "absl/strings/match.h" -#include "llvm/Support/raw_ostream.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" -#include "xla/client/sharding_builder.h" +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs_utils.h" #include "xla/xla_data.pb.h" -namespace mlir { -namespace TFTPU { +namespace tensorflow { +namespace tf2xla { +namespace internal { namespace { #define GEN_PASS_DEF_TPUVALIDATEINPUTSPASS -#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.h.inc" +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc" constexpr char kXLAShardingAttr[] = "_XlaSharding"; constexpr char kShardingAttr[] = "sharding"; -typedef std::unordered_map MetadataMap; +using mlir::ModuleOp; +using mlir::Operation; +using mlir::OperationPass; +using mlir::StringAttr; +using mlir::StringRef; +using mlir::Type; +using mlir::TypeID; +using mlir::func::FuncOp; +using mlir::func::ReturnOp; +using mlir::TF::AssertOp; +using mlir::TF::ConstOp; +using mlir::TF::kCompileDeviceTypeAttr; +using mlir::TF::kTpuReplicateAttr; +using mlir::TF::OutfeedEnqueueTupleOp; +using mlir::TF::PartitionedCallOp; +using mlir::TF::StatefulPartitionedCallOp; +using mlir::TF::TPUPartitionedCallOp; +using mlir::TF::TPUPartitionedInputOp; +using mlir::TF::TPUPartitionedInputV2Op; +using mlir::TF::TPUPartitionedOutputOp; +using mlir::TF::TPUPartitionedOutputV2Op; +using mlir::TF::TPUReplicatedInputOp; +using mlir::TF::TPUReplicatedOutputOp; +using mlir::TF::TPUReplicateMetadataOp; +using mlir::TF::WhileOp; +using mlir::TF::XlaSetDynamicDimensionSizeOp; +using mlir::tf_executor::FetchOp; +using mlir::tf_executor::GraphOp; +using mlir::tf_executor::IslandOp; +using mlir::tf_executor::YieldOp; + +typedef std::unordered_map MetadataMap; struct TPUValidateInputsPass : public impl::TPUValidateInputsPassBase { @@ -60,21 +94,21 @@ bool IsTpuRegularOp(Operation* op) { static auto* ops = [] { llvm::SmallDenseSet* ops_set = new llvm::SmallDenseSet{ - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), }; return ops_set; }(); @@ -87,10 +121,10 @@ bool IsIntersectionXlaNonXlaOps(Operation* op) { static auto* ops = [] { llvm::SmallDenseSet* ops_set = new llvm::SmallDenseSet{ - TypeID::get(), - TypeID::get(), - TypeID::get(), - TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), }; return ops_set; }(); @@ -103,9 +137,9 @@ bool IsPartitionedOp(Operation* op) { static auto* ops = [] { llvm::SmallDenseSet* ops_set = new llvm::SmallDenseSet{ - TypeID::get(), - TypeID::get(), - TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), }; return ops_set; }(); @@ -140,12 +174,12 @@ llvm::SmallVector GetPredecessors(Operation* op) { bool CheckTpuReplicateAttr(Operation* op, StringAttr attr, std::function errormsg) { - if (!op->hasAttr(TF::kTpuReplicateAttr)) { + if (!op->hasAttr(kTpuReplicateAttr)) { op->emitOpError("TF2XLA TPU bridge input check: " + errormsg() + "missing _tpu_replicate attr"); return false; } - auto opattr = op->getAttr(TF::kTpuReplicateAttr); + auto opattr = op->getAttr(kTpuReplicateAttr); if (opattr != attr) { op->emitOpError("TF2XLA TPU bridge input check: " + errormsg() + "invalid _tpu_replicate attr.") @@ -155,7 +189,7 @@ bool CheckTpuReplicateAttr(Operation* op, StringAttr attr, return true; } -bool ValidateReplicatedInput(TF::TPUReplicatedInputOp rep, int num_replicas, +bool ValidateReplicatedInput(TPUReplicatedInputOp rep, int num_replicas, StringAttr attr) { int arity = rep.getInputs().size(); if (rep.getIsPacked() && arity != 1) { @@ -179,7 +213,7 @@ bool ValidateReplicatedInput(TF::TPUReplicatedInputOp rep, int num_replicas, } return true; } -bool ValidateReplicatedOutput(TF::TPUReplicatedOutputOp rep, int num_replicas, +bool ValidateReplicatedOutput(TPUReplicatedOutputOp rep, int num_replicas, StringAttr attr) { int arity = rep.getOutputs().size(); if (arity != num_replicas) { @@ -198,7 +232,7 @@ bool ValidateReplicatedOutput(TF::TPUReplicatedOutputOp rep, int num_replicas, } return true; } -bool ValidatePartitionedInput(TF::TPUPartitionedInputOp rep, +bool ValidatePartitionedInput(TPUPartitionedInputOp rep, int num_cores_per_replica) { int arity = rep.getInputs().size(); if (arity != num_cores_per_replica) { @@ -210,7 +244,7 @@ bool ValidatePartitionedInput(TF::TPUPartitionedInputOp rep, } return true; } -bool ValidatePartitionedInputV2(TF::TPUPartitionedInputV2Op rep, +bool ValidatePartitionedInputV2(TPUPartitionedInputV2Op rep, int num_cores_per_replica) { int arity = rep.getInputs().size(); if (rep.getIsPacked() && arity != 1) { @@ -241,33 +275,33 @@ bool ValidatePartitionedOutput(T rep, int num_cores_per_replica) { return true; } -bool CheckReplicatedIOOp(Operation* op, TF::TPUReplicateMetadataOp metadata, +bool CheckReplicatedIOOp(Operation* op, TPUReplicateMetadataOp metadata, Operation* parent) { int num_replicas = metadata.getNumReplicas(); int num_cores_per_replica = metadata.getNumCoresPerReplica(); StringAttr tpu_replicate_attr = - metadata->getAttrOfType(TF::kTpuReplicateAttr); - if (auto repinput = dyn_cast(op)) { + metadata->getAttrOfType(kTpuReplicateAttr); + if (auto repinput = dyn_cast(op)) { if (!ValidateReplicatedInput(repinput, num_replicas, tpu_replicate_attr)) return false; } - if (auto repoutput = dyn_cast(op)) { + if (auto repoutput = dyn_cast(op)) { if (!ValidateReplicatedOutput(repoutput, num_replicas, tpu_replicate_attr)) return false; } - if (auto partinput = dyn_cast(op)) { + if (auto partinput = dyn_cast(op)) { if (!ValidatePartitionedInput(partinput, num_cores_per_replica)) return false; } - if (auto partinput = dyn_cast(op)) { + if (auto partinput = dyn_cast(op)) { if (!ValidatePartitionedInputV2(partinput, num_cores_per_replica)) return false; } - if (auto partoutput = dyn_cast(op)) { + if (auto partoutput = dyn_cast(op)) { if (!ValidatePartitionedOutput(partoutput, num_cores_per_replica)) return false; } - if (auto partoutput = dyn_cast(op)) { + if (auto partoutput = dyn_cast(op)) { if (!ValidatePartitionedOutput(partoutput, num_cores_per_replica)) return false; } @@ -277,8 +311,8 @@ bool CheckReplicatedIOOp(Operation* op, TF::TPUReplicateMetadataOp metadata, bool CheckClusterSuccessors(Operation* op, std::string cluster, Operation* parent, MetadataMap& metadata_map) { std::string cluster_succ = ""; - if (op->hasAttr(TF::kTpuReplicateAttr)) { - cluster_succ = op->getAttrOfType(TF::kTpuReplicateAttr).str(); + if (op->hasAttr(kTpuReplicateAttr)) { + cluster_succ = op->getAttrOfType(kTpuReplicateAttr).str(); } if (cluster_succ.empty()) { // TODO (b/269195256#comment16): Change to error after resolving issue @@ -304,7 +338,7 @@ bool CheckClusterSuccessors(Operation* op, std::string cluster, bool CheckNonClusterSuccessors(Operation* op, Operation* parent, MetadataMap& metadata_map) { if (!IsTpuRegularOp(op)) { - if (isa(op)) { + if (isa(op)) { op->emitOpError("TF2XLA TPU bridge input check: non-cluster op = ") << parent->getName() << " has invalid successor op = " << op->getName(); @@ -319,7 +353,7 @@ bool CheckNonClusterSuccessors(Operation* op, Operation* parent, bool CheckNonClusterPredecessors(Operation* op, Operation* parent, MetadataMap& metadata_map) { if (!IsTpuRegularOp(op)) { - if (isa(op)) { + if (isa(op)) { op->emitOpError("TF2XLA TPU bridge input check: non-cluster op = ") << parent->getName() << " has invalid predecessor op = " << op->getName(); @@ -334,8 +368,8 @@ bool CheckNonClusterPredecessors(Operation* op, Operation* parent, bool CheckOpsClusterIO(Operation* op, MetadataMap& metadata_map) { bool is_cluster_op = false; std::string cluster = ""; - if (op->hasAttr(TF::kTpuReplicateAttr)) { - cluster = op->getAttrOfType(TF::kTpuReplicateAttr).str(); + if (op->hasAttr(kTpuReplicateAttr)) { + cluster = op->getAttrOfType(kTpuReplicateAttr).str(); if (cluster.empty()) { op->emitOpError("TF2XLA TPU bridge input check: empty _tpu_replicate") << " attr for op = " << op->getName(); @@ -372,7 +406,7 @@ bool CheckOpsClusterIO(Operation* op, MetadataMap& metadata_map) { bool TypeMustBeNonXLA(const Type& type) { const Type elem = getElementTypeOrSelf(type); - return !mlir::isa(elem) && + return !mlir::isa(elem) && !tensorflow::TypeValidForXLA(type); } @@ -397,34 +431,32 @@ bool IsMustBeXlaOp(Operation* op, MetadataMap metadata_map) { // All PartitionedCall are inlined-out before XLA. // So MustBeXLA should return false if (IsPartitionedOp(op)) return false; - if (!op->hasAttr(TF::kTpuReplicateAttr)) return false; - auto cluster = op->getAttrOfType(TF::kTpuReplicateAttr).str(); + if (!op->hasAttr(kTpuReplicateAttr)) return false; + auto cluster = op->getAttrOfType(kTpuReplicateAttr).str(); if (metadata_map.find(cluster) == metadata_map.end()) return false; auto metadata = metadata_map[cluster]; if (!metadata.getAllowSoftPlacement() && - !op->hasAttr(TF::kXlaOutsideCompilationAttr)) + !op->hasAttr(mlir::TF::kXlaOutsideCompilationAttr)) return true; std::string device = ""; - if (op->hasAttr(TF::kDeviceAttr)) - device = op->getAttrOfType(TF::kDeviceAttr).str(); + if (op->hasAttr(mlir::TF::kDeviceAttr)) + device = op->getAttrOfType(mlir::TF::kDeviceAttr).str(); else return false; - if (absl::StrContains(device, TF::kTpuDevice)) return true; + if (absl::StrContains(device, mlir::TF::kTpuDevice)) return true; return false; } bool ValidateIntersectionXlaNonXlaOps(Operation* op, MetadataMap metadata_map) { - if (isa(op) || - isa(op) || isa(op) || - isa(op) || - isa(op) || - isa(op) || - isa(op)) + if (isa(op) || isa(op) || + isa(op) || isa(op) || + isa(op) || isa(op) || + isa(op)) return true; if (IsMustBeXlaOp(op, metadata_map) && IsMustNotBeXlaOp(op)) { // TODO(b/269195256#comment19) change the warning for Identity op to error // when issue with input graph is resolved. Possible issue with python layer // inserting Identity op incorrectly. - if (isa(op)) { + if (isa(op)) { op->emitWarning("TF/XLA TPU bridge input check: found invalid op. ") << op->getName() << " can't be both xla and non-xla"; return true; @@ -488,7 +520,7 @@ bool IsValidShardingTupleForArity(Operation* op) { } bool IsValidMAXIMALSharding(Operation* op, MetadataMap& metadata_map) { - if (!op->hasAttr(TF::kTpuReplicateAttr)) return true; + if (!op->hasAttr(kTpuReplicateAttr)) return true; if (!op->hasAttr(kXLAShardingAttr) && !op->hasAttr(kShardingAttr)) { return true; } @@ -498,7 +530,7 @@ bool IsValidMAXIMALSharding(Operation* op, MetadataMap& metadata_map) { // for it. These checks are already performed in CheckOpsClusterIO. // Also assuming that if there is sharding, then there must be // cluster and the metadata corresponding to it. - auto cluster = op->getAttrOfType(TF::kTpuReplicateAttr).str(); + auto cluster = op->getAttrOfType(kTpuReplicateAttr).str(); if (cluster.empty()) { return true; } @@ -542,8 +574,8 @@ bool IsValidMAXIMALSharding(Operation* op, MetadataMap& metadata_map) { bool HasSingleCoreTpu(Operation* op) { if (auto compilation_attr = - op->getAttrOfType(TF::kCompileDeviceTypeAttr)) { - if (compilation_attr.getValue().str() == TF::kTpuDevice) { + op->getAttrOfType(kCompileDeviceTypeAttr)) { + if (compilation_attr.getValue().str() == mlir::TF::kTpuDevice) { op->emitOpError( "TF2XLA TPU bridge input check: found a single-core TPU graph"); return true; @@ -556,16 +588,22 @@ void TPUValidateInputsPass::runOnOperation() { ModuleOp module = getOperation(); bool success = true; int num_metadata = 0; - TF::TPUReplicateMetadataOp metadata; + TPUReplicateMetadataOp metadata; MetadataMap metadata_map; - module.walk([&](TF::TPUReplicateMetadataOp meta) { + module.walk([&](TPUReplicateMetadataOp meta) { ++num_metadata; metadata = meta; - metadata_map[meta->getAttrOfType(TF::kTpuReplicateAttr).str()] = + metadata_map[meta->getAttrOfType(kTpuReplicateAttr).str()] = meta; }); getOperation().walk([&](mlir::Operation* op) { + if (IsPotentialUnsupportedOp(op)) { + LOG(WARNING) << "Potential unsupported op: " + << op->getName().getStringRef().str() + << ". TF2XLA MLIR bridge does not guarantee to support it."; + } + if (IsTpuRegularOp(op)) { success &= CheckOpsClusterIO(op, metadata_map); } @@ -577,6 +615,17 @@ void TPUValidateInputsPass::runOnOperation() { success &= IsValidShardingTupleForArity(op); } success &= !HasSingleCoreTpu(op); + + if (!success) { + signalPassFailure(); + } + }); + + module.walk([&](GraphOp graph) { + if (HasV1ControlFlow(graph)) { + LOG(WARNING) << "TF2XLA MLIR bridge does not support v1 control flow." + << " Use at your own risk."; + } if (!success) { signalPassFailure(); } @@ -589,5 +638,6 @@ std::unique_ptr> CreateTPUValidateInputsPass() { return std::make_unique(); } -} // namespace TFTPU -} // namespace mlir +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_validate_inputs.mlir b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs.mlir similarity index 91% rename from tensorflow/compiler/mlir/tensorflow/tests/tpu_validate_inputs.mlir rename to tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs.mlir index 295079b24fe799..909bdc6147b2e3 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_validate_inputs.mlir +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs.mlir @@ -253,3 +253,27 @@ func.func @num_replicas_1(%arg0: tensor) -> (tensor) { } return %0 : tensor } + +// ----- +func.func @contians_InfeedDequeueTuple(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { + %0:2 = tf_executor.graph { + %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () + %ri, %c0 = tf_executor.island wraps "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 1 : i64, is_mirrored_variable = false, is_packed = false} : (tensor, tensor) -> tensor + %out, %c1 = tf_executor.island wraps "tf.opA"(%ri) {_tpu_replicate = "cluster"} : (tensor) -> tensor + // expected-warning @+1 {{TPU_REPLICATED_CORE:0 device is not supported for op = tf.InfeedDequeueTuple in TF2XLA MLIR Bridge}} + %infeed_output:3, %c2 = tf_executor.island wraps "tf.InfeedDequeueTuple"() {device = "/device:TPU_REPLICATED_CORE:0"} : () -> (tensor<3xi32>, tensor<4x?xf32>, tensor<*xi16>) + %ro:2, %c3 = tf_executor.island wraps "tf.TPUReplicatedOutput"(%out) : (tensor) -> (tensor, tensor) + tf_executor.fetch %ro#0, %ro#1 : tensor, tensor + } + return %0#0, %0#1 : tensor, tensor +} + +// ----- +func.func @graph_contains_v1_control_flow() { + tf_executor.graph { + // expected-warning @+1 {{ is v1 control flow op which is not supported in TF2XLA MLIR Bridge.}} + %control = tf_executor.ControlTrigger {} + tf_executor.fetch + } + func.return +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs_utils.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs_utils.cc new file mode 100644 index 00000000000000..5969fd7afe24f6 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs_utils.cc @@ -0,0 +1,66 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs_utils.h" + +#include + +#include "absl/strings/match.h" +#include "llvm/ADT/DenseSet.h" +#include "mlir/Support/TypeID.h" // from @llvm-project + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +bool IsPotentialUnsupportedOp(Operation* op) { + static auto* ops = [] { + llvm::SmallDenseSet* ops_set = + new llvm::SmallDenseSet{ + TypeID::get(), + }; + return ops_set; + }(); + auto abstractOp = op->getRegisteredInfo(); + if (!abstractOp) return false; + + bool is_in_ops = ops->count(abstractOp->getTypeID()) != 0; + if (!is_in_ops) return false; + + std::string device = ""; + if (!op->hasAttr(kDeviceAttr)) return false; + device = op->getAttrOfType(kDeviceAttr).str(); + if (!absl::StrContains(device, kTpuReplicatedCoreZeroAttr)) return false; + op->emitWarning("TPU_REPLICATED_CORE:0 device is not supported for op = ") + << op->getName() << " in TF2XLA MLIR Bridge"; + + return true; +} + +bool HasV1ControlFlow(GraphOp graph) { + for (Operation& op : graph.GetBody().without_terminator()) { + auto island_op = llvm::dyn_cast(op); + if (!island_op) { + op.emitWarning() << " is v1 control flow op which is not supported in " + "TF2XLA MLIR Bridge."; + return true; + } + } + return false; +} + +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs_utils.h b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs_utils.h new file mode 100644 index 00000000000000..152b2e02081658 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs_utils.h @@ -0,0 +1,49 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_TPU_VALIDATE_INPUTS_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_TPU_VALIDATE_INPUTS_UTILS_H_ + +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +constexpr char kTpuReplicatedCoreZeroAttr[] = "TPU_REPLICATED_CORE:0"; + +using mlir::ModuleOp; +using mlir::Operation; +using mlir::StringAttr; +using mlir::TypeID; +using mlir::TF::InfeedDequeueTupleOp; +using mlir::TF::kDeviceAttr; +using mlir::tf_executor::GraphOp; + +bool IsPotentialUnsupportedOp(Operation* op); + +bool HasV1ControlFlow(GraphOp graph); + +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_TPU_VALIDATE_INPUTS_UTILS_H_ diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs_utils_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs_utils_test.cc new file mode 100644 index 00000000000000..a64f06b838f12f --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs_utils_test.cc @@ -0,0 +1,93 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs_utils.h" + +#include +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h" + +namespace tensorflow { +namespace tf2xla { +namespace internal { +namespace { + +using mlir::mhlo::test::GetMlirModuleFromString; + +TEST(IsPotentialUnsupportedOp, ClusterOpReturnsFalse) { + mlir::MLIRContext context; + context.loadDialect(); + mlir::OwningOpRef module_ref = + mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); + mlir::OpBuilder builder(module_ref->getBodyRegion()); + + llvm::SmallVector result_types; + auto cluster = builder.create( + mlir::UnknownLoc::get(&context), result_types); + cluster->dump(); + EXPECT_FALSE(IsPotentialUnsupportedOp(cluster)); +} + +TEST(IsPotentialUnsupportedOp, InfeedDequeueTupleOpReturnsTrue) { + mlir::MLIRContext context; + context.loadDialect(); + mlir::OwningOpRef module_ref = + mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); + mlir::OpBuilder builder(module_ref->getBodyRegion()); + + llvm::SmallVector result_types; + mlir::StringAttr _XlaSharding = mlir::StringAttr::get(&context, ""); + mlir::ArrayAttr layouts = mlir::ArrayAttr::get(&context, {}); + + auto infeed_dequeue_tuple = builder.create( + mlir::UnknownLoc::get(&context), result_types, _XlaSharding, layouts); + + infeed_dequeue_tuple->setAttr( + kDeviceAttr, mlir::StringAttr::get(&context, kTpuReplicatedCoreZeroAttr)); + + EXPECT_TRUE(IsPotentialUnsupportedOp(infeed_dequeue_tuple)); +} + +TEST(HasV1ControlFlow, ReturnsTrue) { + static constexpr char kMlirModuleStr[] = R"( + module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + func.func @graph_contains_v1_control_flow() { + tf_executor.graph { + %control = tf_executor.ControlTrigger {} + tf_executor.fetch + } + func.return + } + })"; + mlir::MLIRContext context; + context.loadDialect(); + auto module = GetMlirModuleFromString(kMlirModuleStr, &context); + + module->get().walk( + [&](GraphOp graph) { EXPECT_TRUE(HasV1ControlFlow(graph)); }); +} + +} // namespace + +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_session_inputs.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_session_inputs.cc new file mode 100644 index 00000000000000..fc1c84ca135b73 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_session_inputs.cc @@ -0,0 +1,75 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "absl/log/log.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs_utils.h" + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +namespace { + +#define GEN_PASS_DEF_TPUVALIDATESESSIONINPUTSPASS +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc" + +using mlir::ModuleOp; + +struct TPUValidateSessionInputsPass + : public impl::TPUValidateSessionInputsPassBase< + TPUValidateSessionInputsPass> { + void runOnOperation() override; +}; + +void TPUValidateSessionInputsPass::runOnOperation() { + ModuleOp module = getOperation(); + bool success = true; + + module.walk([&](mlir::Operation* op) { + if (IsPotentialUnsupportedOp(op)) { + LOG(WARNING) << "Potential unsupported op: " + << op->getName().getStringRef().str() + << ". TF2XLA MLIR bridge does not guarantee to support it."; + } + if (!success) { + signalPassFailure(); + } + }); + + module.walk([&](GraphOp graph) { + if (HasV1ControlFlow(graph)) { + LOG(WARNING) << "TF2XLA MLIR bridge does not support v1 control flow." + << " Use at your own risk."; + } + if (!success) { + signalPassFailure(); + } + }); +} + +} // anonymous namespace + +std::unique_ptr> +CreateTPUValidateSessionInputsPass() { + return std::make_unique(); +} + +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_session_inputs.mlir b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_session_inputs.mlir new file mode 100644 index 00000000000000..18c7388d16f198 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_session_inputs.mlir @@ -0,0 +1,37 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-validate-session-inputs | FileCheck %s + +// CHECK-LABEL: func @does_not_contian_InfeedDequeueTuple +func.func @does_not_contian_InfeedDequeueTuple(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { + %0:2 = tf_executor.graph { + %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () + %ri, %c0 = tf_executor.island wraps "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 1 : i64, is_mirrored_variable = false, is_packed = false} : (tensor, tensor) -> tensor + %out, %c1 = tf_executor.island wraps "tf.opA"(%ri) {_tpu_replicate = "cluster"} : (tensor) -> tensor + %ro:2, %c2 = tf_executor.island wraps "tf.TPUReplicatedOutput"(%out) : (tensor) -> (tensor, tensor) + tf_executor.fetch %ro#0, %ro#1 : tensor, tensor + } + return %0#0, %0#1 : tensor, tensor +} + +// ----- +func.func @contians_InfeedDequeueTuple(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> (tensor, tensor) { + %0:2 = tf_executor.graph { + %control = tf_executor.island() wraps "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "/device:TPU:0", num_replicas = 2, topology = "topology"} : () -> () + %ri, %c0 = tf_executor.island wraps "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 1 : i64, is_mirrored_variable = false, is_packed = false} : (tensor, tensor) -> tensor + %out, %c1 = tf_executor.island wraps "tf.opA"(%ri) {_tpu_replicate = "cluster"} : (tensor) -> tensor + // expected-warning @+1 {{TPU_REPLICATED_CORE:0 device is not supported for op = tf.InfeedDequeueTuple in TF2XLA MLIR Bridge}} + %infeed_output:3, %c2 = tf_executor.island wraps "tf.InfeedDequeueTuple"() {device = "/device:TPU_REPLICATED_CORE:0"} : () -> (tensor<3xi32>, tensor<4x?xf32>, tensor<*xi16>) + %ro:2, %c3 = tf_executor.island wraps "tf.TPUReplicatedOutput"(%out) : (tensor) -> (tensor, tensor) + tf_executor.fetch %ro#0, %ro#1 : tensor, tensor + } + return %0#0, %0#1 : tensor, tensor +} + +// ----- +func.func @graph_contains_v1_control_flow() { + tf_executor.graph { + // expected-warning @+1 {{ is v1 control flow op which is not supported in TF2XLA MLIR Bridge.}} + %control = tf_executor.ControlTrigger {} + tf_executor.fetch + } + func.return +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/xla_broadcast.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/xla_broadcast.cc index 5e0ec8a50da525..80ba50fcd898e8 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/xla_broadcast.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/xla_broadcast.cc @@ -302,11 +302,6 @@ LogicalResult MoveAllBroadcastsToCluster(ClusterOp cluster, if (!num_cores_per_replica_attr) return cluster.emitOpError( CreateMissingAttributeMsg(tensorflow::kNumCoresPerReplicaAttr)); - int num_cores_per_replica = num_cores_per_replica_attr.getInt(); - - // TODO(b/329483850): Support spmd ICI weight distribution so when num of core - // per replica > 1, it does not need to be skipped. - if (num_cores_per_replica != 1) return success(); llvm::SetVector bcasts; cluster->walk([&](Operation* op) { diff --git a/tensorflow/compiler/mlir/tf2xla/internal/test_matchers_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/test_matchers_test.cc index 696776f75b021c..e795d430db97e4 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/test_matchers_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/test_matchers_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/service/hlo.pb.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" #include "tensorflow/core/lib/monitoring/counter.h" diff --git a/tensorflow/compiler/mlir/tf2xla/internal/utils/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/utils/BUILD index 70f8d206840047..eeb10c1e5a0854 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/utils/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/utils/BUILD @@ -73,7 +73,7 @@ cc_library( "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:errors", "@local_xla//xla:shape_util", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape", "@local_xla//xla/mlir_hlo:hlo_dialect_registration", - "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/internal/utils/test_metadata_config.cc b/tensorflow/compiler/mlir/tf2xla/internal/utils/test_metadata_config.cc index fc11c2dab477cc..79928fd362664a 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/utils/test_metadata_config.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/utils/test_metadata_config.cc @@ -28,9 +28,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/register_common_dialects.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/shape.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tsl/platform/errors.h" diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir index 4231990e0769d1..1939a3dc8cd875 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir @@ -2654,3813 +2654,3 @@ func.func @sigmoid_grad_dynamic(%arg0: tensor, %arg1: tensor) -> t func.return %0 : tensor } -// ----- - -// CHECK-LABEL: @sin -func.func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: mhlo.sine %arg0 : tensor<2xf32> - %0 = "tf.Sin"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: func @sin_dynamic -func.func @sin_dynamic(%arg0: tensor) -> tensor { - // CHECK: mhlo.sine %arg0 : tensor - %0 = "tf.Sin"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @rsqrt -func.func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: mhlo.rsqrt %arg0 : tensor<2xf32> - %0 = "tf.Rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: func @rsqrt_dynamic -func.func @rsqrt_dynamic(%arg0: tensor) -> tensor { - // CHECK: mhlo.rsqrt %arg0 : tensor - %0 = "tf.Rsqrt"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @sqrt -func.func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: mhlo.sqrt %arg0 : tensor<2xf32> - %0 = "tf.Sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: func @sqrt_dynamic -func.func @sqrt_dynamic(%arg0: tensor) -> tensor { - // CHECK: mhlo.sqrt %arg0 : tensor - %0 = "tf.Sqrt"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @tanh -func.func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: mhlo.tanh %arg0 : tensor<2xf32> - %0 = "tf.Tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: func @tanh_dynamic -func.func @tanh_dynamic(%arg0: tensor) -> tensor { - // CHECK: mhlo.tanh %arg0 : tensor - %0 = "tf.Tanh"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @bitcast -func.func @bitcast(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK: mhlo.bitcast_convert %arg0 : (tensor<2xf32>) -> tensor<2xf32> - %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> - func.return %0 : tensor<2xf32> -} - -// ----- - -// CHECK-LABEL: func @bitcast_dynamic -func.func @bitcast_dynamic(%arg0: tensor) -> tensor { - // CHECK: mhlo.bitcast_convert %arg0 : (tensor) -> tensor - %0 = "tf.Bitcast"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @bitcast_same_widths -func.func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> { - // CHECK: mhlo.bitcast_convert %arg0 : (tensor<2xf32>) -> tensor<2xi32> - %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2xi32> - func.return %0 : tensor<2xi32> -} - -// ----- - -// CHECK-LABEL: func @bitcast_smaller_input_width -func.func @bitcast_smaller_input_width(%arg0: tensor<8xi8>) -> tensor { - // CHECK: mhlo.bitcast_convert %arg0 : (tensor<8xi8>) -> tensor - %0 = "tf.Bitcast"(%arg0) : (tensor<8xi8>) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @bitcast_smaller_output_width -func.func @bitcast_smaller_output_width(%arg0: tensor<2xf32>) -> tensor<2x2xf16> { - // CHECK: mhlo.bitcast_convert %arg0 : (tensor<2xf32>) -> tensor<2x2xf16> - %0 = "tf.Bitcast"(%arg0) : (tensor<2xf32>) -> tensor<2x2xf16> - func.return %0 : tensor<2x2xf16> -} - -// ----- - -// CHECK-LABEL: squeeze -func.func @squeeze(%arg0: tensor<1x1x10xf32>) -> tensor<1x10xf32> { - // CHECK: mhlo.reshape - %0 = "tf.Squeeze"(%arg0) : (tensor<1x1x10xf32>) -> tensor<1x10xf32> - func.return %0 : tensor<1x10xf32> -} - -// ----- - -// CHECK-LABEL: squeeze_ranked -func.func @squeeze_ranked(%arg0: tensor) -> tensor { - // CHECK: %[[C2:.*]] = arith.constant 2 : index - // CHECK: %[[D2:.*]] = tensor.dim %arg0, %[[C2]] : tensor - // CHECK: %[[T:.*]] = tensor.from_elements %[[D2]] : tensor<1xindex> - // CHECK: %[[R:.*]] = mhlo.dynamic_reshape %arg0, %[[T]] : (tensor, tensor<1xindex>) -> tensor - // CHECK: return %[[R]] : tensor - %0 = "tf.Squeeze"(%arg0) { squeeze_dims = [0, 1] }: (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: squeeze_ranked_negative -func.func @squeeze_ranked_negative(%arg0: tensor) -> tensor { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[D0:.*]] = tensor.dim %arg0, %[[C0]] : tensor - // CHECK: %[[C2:.*]] = arith.constant 2 : index - // CHECK: %[[D2:.*]] = tensor.dim %arg0, %[[C2]] : tensor - // CHECK: %[[T:.*]] = tensor.from_elements %[[D0]], %[[D2]] : tensor<2xindex> - // CHECK: %[[R:.*]] = mhlo.dynamic_reshape %arg0, %[[T]] : (tensor, tensor<2xindex>) -> tensor - // CHECK: return %[[R]] : tensor - %0 = "tf.Squeeze"(%arg0) { squeeze_dims = [-2] }: (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: squeeze_ranked_dynamic -func.func @squeeze_ranked_dynamic(%arg0: tensor) -> tensor { - // CHECK: "tf.Squeeze" - %0 = "tf.Squeeze"(%arg0) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: squeeze_dynamic -func.func @squeeze_dynamic(%arg0: tensor) -> tensor<*xf32> { - // CHECK: "tf.Squeeze" - %0 = "tf.Squeeze"(%arg0) : (tensor) -> tensor<*xf32> - func.return %0 : tensor<*xf32> -} - -// ----- - -// CHECK-LABEL: expand_dims -func.func @expand_dims(%arg0: tensor<2xf32>, %axis: tensor) -> tensor<1x2xf32> { - // CHECK: mhlo.reshape - %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor<2xf32>, tensor) -> tensor<1x2xf32> - func.return %0 : tensor<1x2xf32> -} - -// ----- - -// CHECK-LABEL: expand_dims_dynamic -func.func @expand_dims_dynamic(%arg0: tensor) -> tensor { - %axis = "tf.Const"() {value = dense<1> : tensor} : () -> (tensor) - - // CHECK-DAG: %[[SHAPEOF:.+]] = shape.shape_of %arg0 - // CHECK-DAG: %[[CST0:.+]] = arith.constant 0 - // CHECK-DAG: %[[CST1:.+]] = arith.constant 1 - // CHECK-DAG: %[[GETEXTENT0:.+]] = tensor.extract %[[SHAPEOF]][%[[CST0]]] - // CHECK-DAG: %[[CST1_0:.+]] = arith.constant 1 - // CHECK-DAG: %[[GETEXTENT1:.+]] = tensor.extract %[[SHAPEOF]][%[[CST1_0]]] - // CHECK-DAG: %[[TOEXTENTS:.+]] = tensor.from_elements %[[GETEXTENT0]], %[[CST1]], %[[GETEXTENT1]] - // CHECK-DAG: %[[RESHAPE:.+]] = mhlo.dynamic_reshape %arg0, %[[TOEXTENTS]] - %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor, tensor) -> tensor - - // CHECK: return %[[RESHAPE]] - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: expand_dynamic_dims_rank1_axis -func.func @expand_dynamic_dims_rank1_axis(%arg0: tensor) -> tensor { - %axis = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - - // CHECK-DAG: %[[SHAPEOF:.+]] = shape.shape_of %arg0 - // CHECK-DAG: %[[CST0:.+]] = arith.constant 0 - // CHECK-DAG: %[[CST1:.+]] = arith.constant 1 - // CHECK-DAG: %[[GETEXTENT0:.+]] = tensor.extract %[[SHAPEOF]][%[[CST0]]] - // CHECK-DAG: %[[CST1_0:.+]] = arith.constant 1 - // CHECK-DAG: %[[GETEXTENT1:.+]] = tensor.extract %[[SHAPEOF]][%[[CST1_0]]] - // CHECK-DAG: %[[CST2:.+]] = arith.constant 2 - // CHECK-DAG: %[[GETEXTENT2:.+]] = tensor.extract %[[SHAPEOF]][%[[CST2]]] - // CHECK-DAG: %[[TOEXTENTS:.+]] = tensor.from_elements %[[GETEXTENT0]], %[[CST1]], %[[GETEXTENT1]], %[[GETEXTENT2]] - // CHECK-DAG: %[[RESHAPE:.+]] = mhlo.dynamic_reshape %arg0, %[[TOEXTENTS]] - %0 = "tf.ExpandDims"(%arg0, %axis) : (tensor, tensor<1xi32>) -> tensor - - // CHECK: return %[[RESHAPE]] - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @sign -// CHECK-SAME: [[ARG:%arg.*]]: tensor<1x2x3x4xf32> -func.func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { - // CHECK: [[SIGN:%.*]] = mhlo.sign [[ARG]] - // CHECK: return [[SIGN]] : tensor<1x2x3x4xf32> - %0 = "tf.Sign"(%arg0) : (tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>) - func.return %0 : tensor<1x2x3x4xf32> -} - -// ----- - -// CHECK-LABEL: func @sign_dynamic -func.func @sign_dynamic(%arg0: tensor) -> tensor { - // CHECK: [[SIGN:%.*]] = mhlo.sign %arg0 : tensor - // CHECK: return [[SIGN]] : tensor - %0 = "tf.Sign"(%arg0) : (tensor) -> (tensor) - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: slice_constant_start -func.func @slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { - // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor - // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, - // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, - // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : - // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<1xi64> - // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor - // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[START]]) - // CHECK-DAG-SAME: {slice_sizes = dense<2> : tensor<1xi64>} : - // CHECK-DAG-SAME: (tensor<4xi32>, tensor) -> tensor<2xi32> - // CHECK: return %[[RESULT]] : tensor<2xi32> - %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) - %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>) - %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi32> - func.return %0 : tensor<2xi32> -} - -// ----- - -// CHECK-LABEL: slice_i32_consts -func.func @slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> { - // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor - // CHECK: "mhlo.dynamic_slice"(%arg0, %[[START]]) <{slice_sizes = dense<2> : tensor<1xi64>}> : (tensor<4xi32>, tensor) -> tensor<2xi32> - %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> (tensor<1xi32>) - %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> (tensor<1xi32>) - %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - func.return %0 : tensor<2xi32> -} - -// ----- - -// CHECK-LABEL: slice_constant_start_negative_one_size -func.func @slice_constant_start_negative_one_size(%arg0: tensor<4xi32>) -> tensor<3xi32> { - // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor - // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[START]]) <{slice_sizes = dense<3> : tensor<1xi64>}> : (tensor<4xi32>, tensor) -> tensor<3xi32> - // CHECK: return %[[RESULT]] : tensor<3xi32> - %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) - %sizes = "tf.Const"() {value = dense<[-1]> : tensor<1xi64>} : () -> (tensor<1xi64>) - %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi32> - func.return %0 : tensor<3xi32> -} - -// ----- - -// CHECK-LABEL: slice_constant_start_dynamic_shape -func.func @slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { - // CHECK-DAG: %[[START1:.*]] = mhlo.constant dense<1> : tensor - // CHECK-DAG: %[[START2:.*]] = mhlo.constant dense<0> : tensor - // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice" - // CHECK-DAG-SAME: (%arg0, %[[START1]], %[[START2]]) - // CHECK-DAG-SAME: {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : - // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor<1x4xi32> - // CHECK: return %[[RESULT]] : tensor<1x4xi32> - %starts = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) - %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) - %0 = "tf.Slice"(%arg0, %starts, %sizes) : (tensor, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> - func.return %0 : tensor<1x4xi32> -} - -// ----- - -// CHECK-LABEL: slice_variable_start -func.func @slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { - // CHECK: %[[SLICED_START1:.*]] = "mhlo.slice"(%arg1) - // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, - // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, - // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> - // CHECK: %[[RESHAPED_START1:.*]] = mhlo.reshape %[[SLICED_START1]] : (tensor<1xi64>) -> tensor - // CHECK: %[[SLICED_START2:.*]] = "mhlo.slice"(%arg1) - // CHECK-DAG-SAME: {limit_indices = dense<2> : tensor<1xi64>, - // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>, - // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> - // CHECK: %[[RESHAPED_START2:.*]] = mhlo.reshape %[[SLICED_START2]] : (tensor<1xi64>) -> tensor - // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> - // CHECK: return %[[RESULT]] : tensor<1x4xi32> - %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) - %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> - func.return %0 : tensor<1x4xi32> -} - -// ----- - -// CHECK-LABEL: slice_mhlo_sizes -func.func @slice_mhlo_sizes(%arg0: tensor<1x1024x4xf32>, %arg1: tensor<3xi32>) -> tensor<1x512x4xf32> { - // CHECK-NOT: "tf.Slice" - %0 = "mhlo.constant"() {value = dense<[1, 512, 4]> : tensor<3xi32>} : () -> tensor<3xi32> - %1 = "tf.Slice"(%arg0, %arg1, %0) : (tensor<1x1024x4xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x512x4xf32> - func.return %1 : tensor<1x512x4xf32> -} - -// ----- - -// CHECK-LABEL: slice_variable_start_negative_one_size -func.func @slice_variable_start_negative_one_size(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { - // CHECK: %[[RESULT:.*]] = "tf.Slice" - // CHECK: return %[[RESULT]] : tensor<1x4xi32> - %sizes = "tf.Const"() {value = dense<[1, -1]> : tensor<2xi64>} : () -> (tensor<2xi64>) - %0 = "tf.Slice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> - func.return %0 : tensor<1x4xi32> -} - -// ----- - -// CHECK-LABEL: slice_real_dynamic_slice -func.func @slice_real_dynamic_slice(%arg0: tensor<4xi32>, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>) -> tensor { - // CHECK: tensor.extract {{.*}} : tensor<1xi64> - // CHECK: tensor.extract {{.*}} : tensor<1xi64> - // CHECK: arith.index_cast {{.*}} : index to i64 - // CHECK: arith.cmpi eq, {{.*}} : i64 - // CHECK: arith.addi {{.*}} : i64 - // CHECK: tensor.dim {{.*}} : tensor<4xi32> - // CHECK: arith.index_cast {{.*}} : index to i64 - // CHECK: select {{.*}} : i64 - // CHECK: arith.index_cast {{.*}} : i64 to index - // CHECK: arith.index_cast {{.*}} : i64 to index - // CHECK: tensor.from_elements {{.*}} : tensor<1xindex> - // CHECK: tensor.from_elements {{.*}} : tensor<1xindex> - // CHECK: tensor.from_elements {{.*}} : tensor<1xindex> - %0 = "tf.Slice"(%arg0, %arg1, %arg2) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor - func.return %0 : tensor -} - -//===----------------------------------------------------------------------===// -// StridedSlice op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: simple_strided_slice -func.func @simple_strided_slice(%input: tensor<4x8xf32>) -> tensor<3x2xf32> { - %begin = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %end = "tf.Const"() {value = dense<[3, 7]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) - - // CHECK: mhlo.slice - // CHECK-DAG-SAME: start_indices = dense<[0, 1]> - // CHECK-DAG-SAME: limit_indices = dense<[3, 7]> - // CHECK-DAG-SAME: strides = dense<[1, 3]> - // CHECK-SAME: -> tensor<3x2xf32> - - %output = "tf.StridedSlice"(%input, %begin, %end, %strides) - : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<3x2xf32> - func.return %output : tensor<3x2xf32> -} - -// ----- - -// CHECK-LABEL: dynamic_strided_slice -func.func @dynamic_strided_slice(%input: tensor) -> tensor { - %begin = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %end = "tf.Const"() {value = dense<[3, 7]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) - - // CHECK: "tf.StridedSlice" - %output = "tf.StridedSlice"(%input, %begin, %end, %strides) - : (tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor - func.return %output : tensor -} - -// ----- - -// CHECK-LABEL: strided_slice_negative_indices -func.func @strided_slice_negative_indices(%input: tensor<4x8xf32>) -> tensor<3x2xf32> { - %begin = "tf.Const"() {value = dense<[-1, -2]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %end = "tf.Const"() {value = dense<[-4, -8]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %strides = "tf.Const"() {value = dense<[-1, -3]> : tensor<2xi32>} : () -> (tensor<2xi32>) - - // CHECK: "mhlo.reverse"(%arg0) <{dimensions = dense<[0, 1]> : tensor<2xi64>}> - - // CHECK: mhlo.slice - // CHECK-DAG-SAME: start_indices = dense<[0, 1]> - // CHECK-DAG-SAME: limit_indices = dense<[3, 7]> - // CHECK-DAG-SAME: strides = dense<[1, 3]> - // CHECK-SAME: -> tensor<3x2xf32> - - %output = "tf.StridedSlice"(%input, %begin, %end, %strides) - : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<3x2xf32> - func.return %output : tensor<3x2xf32> -} - -// ----- - -// CHECK-LABEL: dynamic_strided_slice_negative_indices -func.func @dynamic_strided_slice_negative_indices(%input: tensor) -> tensor { - %begin = "tf.Const"() {value = dense<[-1, -2]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %end = "tf.Const"() {value = dense<[-4, -8]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %strides = "tf.Const"() {value = dense<[-1, -3]> : tensor<2xi32>} : () -> (tensor<2xi32>) - - // CHECK: tf.StridedSlice - %output = "tf.StridedSlice"(%input, %begin, %end, %strides) - : (tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor - func.return %output : tensor -} - -// ----- - -// CHECK-LABEL: strided_slice_range_clamping -func.func @strided_slice_range_clamping(%input: tensor<4x8xf32>) -> tensor<1x3xf32> { - %begin = "tf.Const"() {value = dense<[-4, -10]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %end = "tf.Const"() {value = dense<[1, 10]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %strides = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) - - // CHECK: mhlo.slice - // CHECK-DAG-SAME: start_indices = dense<[0, 0]> - // CHECK-DAG-SAME: limit_indices = dense<[1, 8]> - // CHECK-DAG-SAME: strides = dense<[1, 3]> - // CHECK-SAME: -> tensor<1x3xf32> - %output = "tf.StridedSlice"(%input, %begin, %end, %strides) - : (tensor<4x8xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x3xf32> - func.return %output : tensor<1x3xf32> -} - -// ----- - -// CHECK-LABEL: strided_slice_empty -func.func @strided_slice_empty(%input: tensor<4xf32>) -> tensor<0xf32> { - %begin = "tf.Const"() {value = dense<[-4]> : tensor<1xi32>} : () -> (tensor<1xi32>) - %end = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>) - %strides = "tf.Const"() {value = dense<[-1]> : tensor<1xi32>} : () -> (tensor<1xi32>) - - // CHECK: mhlo.constant dense<> : tensor<0xf32> - %output = "tf.StridedSlice"(%input, %begin, %end, %strides) - : (tensor<4xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xf32> - func.return %output : tensor<0xf32> -} - -// ----- - -// CHECK-LABEL: strided_slice_begin_end_mask -// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<4x128x1024xf32> -func.func @strided_slice_begin_end_mask(%input: tensor<4x128x1024xf32>) { - - // For StridedSlice - // Dim #: 0, 1, 2 - // Input shape: [4, 128, 1024] - // Begin: 1, 4, -3 - // End: 8, 65, 42 - // Stride: 1, 4, -1 - // Begin mask: 0, 0, 1 (= 1) - // End mask: 1, 0, 0 (= 4) - - // So result shape: - // Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4 - // Dim #1: 4 to 65 stride 4: so 16 - // Dim #2: begin -3 + 1024 = 1021; end mask (1) -> end = -1: so 1022 - // result shape: [4, 16, 1022] - - %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>) - %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>) - %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>) - - // CHECK: %[[REVERSE:.*]] = "mhlo.reverse"(%[[INPUT]]) - - // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[REVERSE]]) - // CHECK-DAG-SAME: limit_indices = dense<[4, 65, 1024]> - // CHECK-DAG-SAME: start_indices = dense<[0, 4, 2]> - // CHECK-DAG-SAME: strides = dense<[1, 4, 1]> - // CHECK-SAME: -> tensor<4x16x1022xf32> - - %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 1, end_mask = 4} : (tensor<4x128x1024xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<4x16x1022xf32> - - // CHECK: mhlo.reshape %[[SLICE]] - // CHECK-SAME: -> tensor<4x16x1022xf32> - - func.return -} - -// ----- - -// CHECK-LABEL: strided_slice_shrink_axis_mask -// CHECK-SAME: %[[INPUT:.+]]: tensor<4x128x1024xf32> -func.func @strided_slice_shrink_axis_mask(%input: tensor<4x128x1024xf32>) { - - // For StridedSlice - // Dim #: 0, 1, 2 - // Input shape: [4, 128, 1024] - // Begin: 1, 4, -3 - // End: 8, 65, 42 - // Stride: 1, 4, -1 - // Begin mask: 1, 0, 0 (= 1) - // End mask: 0, 0, 1 (= 4) - // Shrink axis mask: 1, 0, 1 (= 5) - - // So result shape: - // Dim #0: shrink axis, take value at [1] - // Dim #1: 4 to 65 stride 4: so 16 - // Dim #2: shrink axis, take value at [-3] - // result shape: [16] - - // As output shape of StridedSlice differs, a reshape will follow. - - %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>) - %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>) - %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>) - - // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]]) - // CHECK-DAG-SAME: limit_indices = dense<[1, 65, 1022]> - // CHECK-DAG-SAME: start_indices = dense<[0, 4, 1021]> - // CHECK-DAG-SAME: strides = dense<[1, 4, 1]> - // CHECK-SAME: -> tensor<1x16x1xf32> - - %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 1, end_mask = 4, shrink_axis_mask = 5} : (tensor<4x128x1024xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<16xf32> - - // CHECK: mhlo.reshape %[[SLICE]] - // CHECK-SAME: -> tensor<16xf32> - - func.return -} - -// ----- - -// CHECK-LABEL: strided_slice_ellipsis_mask -// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<2x4x8x16x32x64xf32> -func.func @strided_slice_ellipsis_mask(%input: tensor<2x4x8x16x32x64xf32>) { - // For StridedSlice input[1, ..., 8:, :10, 2:6:2] - // The ellipsis mask is applied to dim #1, #2, i.e, we get canonicalized - // slice input[1, :, :, 8:, :10, 2:6:2] - - // The start, limit indices and strides attributes of mhlo.slice would - // reflect the canonicalized slice. - // As output shape of StridedSlice differs, a reshape will follow. - - %begin = "tf.Const"() {value = dense<[1, 0, 8, 1, 2]> : tensor<5xi32>} : () -> (tensor<5xi32>) - %end = "tf.Const"() {value = dense<[2, 0, 10, 10, 6]> : tensor<5xi32>} : () -> (tensor<5xi32>) - %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 2]> : tensor<5xi32>} : () -> (tensor<5xi32>) - - // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]]) - // CHECK-DAG-SAME: limit_indices = dense<[2, 4, 8, 16, 10, 6]> : tensor<6xi64> - // CHECK-DAG-SAME: start_indices = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64> - // CHECK-DAG-SAME: strides = dense<[1, 1, 1, 1, 1, 2]> : tensoe<6xi64> - // CHECK-SAME: -> tensor<1x4x8x8x10x2xf32> - %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 8, end_mask = 4, shrink_axis_mask = 1, ellipsis_mask = 2} : (tensor<2x4x8x16x32x64xf32>, tensor<5xi32>, tensor<5xi32>, tensor<5xi32>) -> tensor<4x8x8x10x2xf32> - - // CHECK: mhlo.reshape %[[SLICE]] - // CHECK-SAME: -> tensor<4x8x8x10x2xf32> - - func.return -} - -// ----- - -// CHECK-LABEL: strided_slice_new_axis_mask -// CHECK-SAME: %[[INPUT:[a-z0-9]+]]: tensor<2x4x8x16x32x64xf32> -func.func @strided_slice_new_axis_mask(%input: tensor<2x4x8x16x32x64xf32>) { - // For StridedSlice input[1, tf.new_axis, ..., 8:, :10, 2:6:2, tf.new_axis] - // New axis mask is at index 1 and 6 of sparse spec, so - // new_axis_mask = 2^1 + 2^6 = 66 - // The ellipsis mask is applied to dim #1, #2 of input i.e, we get - // canonicalized slice input[1, :, :, 8:, :10, 2:6:2] - // This is then reshaped to add the new axes. - - // The start, limit indices and strides attributes of mhlo.slice would - // reflect the canonicalized slice. - // As output shape of StridedSlice differs, a reshape will follow to reflect - // new axes added. - - %begin = "tf.Const"() {value = dense<[1, 0, 0, 8, 1, 2, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>) - %end = "tf.Const"() {value = dense<[2, 0, 0, 10, 10, 6, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>) - %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 1, 2, 1]> : tensor<7xi32>} : () -> (tensor<7xi32>) - - // CHECK: %[[SLICE:.*]] = "mhlo.slice"(%[[INPUT]]) - // CHECK-DAG-SAME: limit_indices = dense<[2, 4, 8, 16, 10, 6]> : tensor<6xi64> - // CHECK-DAG-SAME: start_indices = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64> - // CHECK-DAG-SAME: strides = dense<[1, 1, 1, 1, 1, 2]> : tensoe<6xi64> - // CHECK-SAME: -> tensor<1x4x8x8x10x2xf32> - %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {begin_mask = 16, end_mask = 8, shrink_axis_mask = 1, ellipsis_mask = 4, new_axis_mask = 66} : (tensor<2x4x8x16x32x64xf32>, tensor<7xi32>, tensor<7xi32>, tensor<7xi32>) -> tensor<1x4x8x8x10x2x1xf32> - - // CHECK: mhlo.reshape %[[SLICE]] - // CHECK-SAME: -> tensor<1x4x8x8x10x2x1xf32> - - func.return -} - -// ----- - -// CHECK-LABEL: strided_slice_implicit_ellipsis_mask( -// CHECK-SAME: [[INPUT:%.*]]: tensor<10x16x2xf32> -func.func @strided_slice_implicit_ellipsis_mask(%input: tensor<10x16x2xf32>) -> tensor<2x16x2xf32> { - // StridedSlice gets input[8:10], which is same as input[8:10, ...] - // The start_indices, limit_indices, and strides attribute of mhlo.slice - // reflect the canonicalized slice. - %begin = "tf.Const"() {value = dense<8> : tensor<1xi32>} : () -> tensor<1xi32> - %end = "tf.Const"() {value = dense<10> : tensor<1xi32>} : () -> tensor<1xi32> - %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: [[SLICE:%.*]] = "mhlo.slice"([[INPUT]]) - // CHECK-DAG-SAME: limit_indices = dense<[10, 16, 2]> : tensor<3xi64> - // CHECK-DAG-SAME: start_indices = dense<[8, 0, 0]> : tensor<3xi64> - // CHECK-DAG-SAME: strides = dense<1> : tensor<3xi64> - // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[SLICE]] : (tensor<2x16x2xf32>) -> tensor<2x16x2xf32> - %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = f32} : (tensor<10x16x2xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2x16x2xf32> - // CHECK: return [[RESHAPE]] : tensor<2x16x2xf32> - func.return %0 : tensor<2x16x2xf32> -} - -// ----- - -// CHECK-LABEL: strided_slice_nonconstant_begin_end -func.func @strided_slice_nonconstant_begin_end(%arg0: tensor, %arg1: tensor<32x1x97xi32>) -> (tensor<1x97xi32>) { - // In this case, the `begin` and `end` inputs are unknown at compile time -- - // so the StridedSlice needs to slice these vectors and use that as input to - // an HLO dynamic slice. - %begin = "tf.Pack"(%arg0) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %2 = "tf.AddV2"(%arg0, %0) {T = i32, device = ""} : (tensor, tensor) -> tensor - %end = "tf.Pack"(%2) {N = 1 : i64, T = i32, axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - // CHECK: %[[A:.*]] = mhlo.reshape %arg0 : (tensor) -> tensor<1xi32> - // CHECK-NEXT: %[[BEGIN:.*]] = "mhlo.concatenate"(%[[A]]) - // CHECK-DAG-SAME: {dimension = 0 : i64} : (tensor<1xi32>) -> tensor<1xi32> - // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor - // CHECK-NEXT: %[[INDEX:.*]] = "mhlo.slice"(%[[BEGIN]]) - // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, - // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, - // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32> - // CHECK-NEXT: %[[INDEX2:.*]] = mhlo.reshape %[[INDEX]] : (tensor<1xi32>) -> tensor - // CHECK-NEXT: %[[CMP:.*]] = chlo.broadcast_compare %[[INDEX2]], %[[ZERO]] - // CHECK-DAG-SAME: {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - // CHECK-NEXT: %[[DIM:.*]] = mhlo.constant dense<32> : tensor - // CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[INDEX2]], %[[DIM]] : (tensor, tensor) -> tensor - // CHECK-NEXT: %[[INDEX3:.*]] = mhlo.select %[[CMP]], %[[WRAP]], %[[INDEX2]] : - // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor - // CHECK-NEXT: %[[SLICED:.*]] = "mhlo.dynamic_slice" - // CHECK-DAG-SAME: (%arg1, %[[INDEX3]], %[[ZERO]], %[[ZERO]]) - // CHECK-DAG-SAME: {slice_sizes = dense<[1, 1, 97]> : tensor<3xi64>} : - // CHECK-DAG-SAME: (tensor<32x1x97xi32>, tensor, tensor, tensor) -> tensor<1x1x97xi32> - // CHECK-NEXT: %[[FINAL:.*]] = mhlo.reshape %[[SLICED]] : (tensor<1x1x97xi32>) -> tensor<1x97xi32> - %result = "tf.StridedSlice"(%arg1, %begin, %end, %1) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> - // CHECK-NEXT: return %[[FINAL]] : tensor<1x97xi32> - func.return %result : tensor<1x97xi32> -} - -// ----- - -// CHECK-LABEL: strided_slice_nonconstant_begin_end_with_start_end_mask -// CHECK-SAME: (%[[INPUT:.*]]: tensor<32x1x97xi32>, %[[BEGIN:.*]]: tensor<3xi32>, %[[END:.*]]: tensor<3xi32>) -func.func @strided_slice_nonconstant_begin_end_with_start_end_mask(%input: tensor<32x1x97xi32>, %begin: tensor<3xi32>, %end: tensor<3xi32>) -> (tensor<1x97xi32>) { - %strides = "tf.Const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32> - - // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor - // CHECK: %[[INDEX:.*]] = "mhlo.slice"(%[[BEGIN]]) - // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64> - // CHECK-DAG-SAME: limit_indices = dense<1> : tensor<1xi64> - // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64> - // CHECK-NEXT: %[[INDEX2:.*]] = mhlo.reshape %[[INDEX]] : (tensor<1xi32>) -> tensor - // CHECK-NEXT: %[[CMP:.*]] = chlo.broadcast_compare %[[INDEX2]], %[[ZERO]] - // CHECK-DAG-SAME: {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - // CHECK-NEXT: %[[DIM:.*]] = mhlo.constant dense<32> : tensor - // CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[INDEX2]], %[[DIM]] : (tensor, tensor) -> tensor - // CHECK-NEXT: %[[INDEX3:.*]] = mhlo.select %[[CMP]], %[[WRAP]], %[[INDEX2]] : - // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor - // CHECK-NEXT: %[[SLICED:.*]] = "mhlo.dynamic_slice" - // CHECK-DAG-SAME: (%arg1, %[[INDEX3]], %[[ZERO]], %[[ZERO]]) - // CHECK-DAG-SAME: {slice_sizes = dense<[1, 1, 97]> : tensor<3xi64>} : - // CHECK-DAG-SAME: (tensor<32x1x97xi32>, tensor, tensor, tensor) -> tensor<1x1x97xi32> - // CHECK-NEXT: %[[FINAL:.*]] = mhlo.reshape %[[SLICED]] : (tensor<1x1x97xi32>) -> tensor<1x97xi32> - %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 6 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 6 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x97xi32> - func.return %result : tensor<1x97xi32> -} - -// ----- - -// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_1 -func.func @strided_slice_nonconstant_begin_end_stride_1(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>, %strides: tensor<1xi32>) -> (tensor<1x97xi32>) { - // Dynamic stride: when `begin` and `end` inputs are unknown at compile time, - // `strides` must be known. - // CHECK: tf.StridedSlice - %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> - func.return %result : tensor<1x97xi32> -} - -// ----- - -// CHECK-LABEL: strided_slice_nonconstant_begin_end_stride_2 -func.func @strided_slice_nonconstant_begin_end_stride_2(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { - // Invalid stride (not equal to 1): when `begin` and `end` inputs are unknown - // at compile time, `strides` must be known to have all 1 values. - %strides = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: tf.StridedSlice - %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 4 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> - func.return %result : tensor<1x97xi32> -} - -// ----- - -// CHECK-LABEL: strided_slice_nonconstant_begin_end_invalid_elem_count -func.func @strided_slice_nonconstant_begin_end_invalid_elem_count(%input: tensor<4x8xf32>, %begin: tensor<2xi64>, %end: tensor<2xi64>) -> tensor<6x10xf32> { - %strides = "tf.Const"() { value = dense<[1, 1]> : tensor<2xi64> } : () -> tensor<2xi64> - // When begin/end are dynamic, the number of output elements must be equal to - // the number of input elements sliced. - // CHECK: tf.StridedSlice - %0 = "tf.StridedSlice"(%input, %begin, %end, %strides) : (tensor<4x8xf32>, tensor<2xi64>, tensor<2xi64>, tensor<2xi64>) -> tensor<6x10xf32> - func.return %0 : tensor<6x10xf32> -} - -// ----- - -// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_ellipsis_mask -func.func @strided_slice_nonconstant_begin_end_and_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { - // This ellipsis mask is not supported because it does not refer to the last - // dimension. - // [0, 1, 0] = 2 - %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: tf.StridedSlice - %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 2 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> - func.return %result : tensor<1x97xi32> -} - -// ----- - -// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask -func.func @strided_slice_nonconstant_begin_end_and_valid_ellipsis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { - // This ellipsis mask is supported because it refers to the last dimension. - // [1, 0, 0] = 4 - %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: mhlo.dynamic_slice - %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 4 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> - func.return %result : tensor<1x97xi32> -} - -// ----- - -// CHECK-LABEL: strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask -func.func @strided_slice_nonconstant_begin_end_and_valid_shrink_axis_mask(%input: tensor<32x1x97xi32>, %begin: tensor<1xi32>, %end: tensor<1xi32>) -> (tensor<1x97xi32>) { - // This shrink_axis mask is supported because it refers to a major dimension. - // [1, 1, 1] = 7 - %strides = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: mhlo.dynamic_slice - %result = "tf.StridedSlice"(%input, %begin, %end, %strides) {Index = i32, T = i32, begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 7 : i64} : (tensor<32x1x97xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x97xi32> - func.return %result : tensor<1x97xi32> -} - -//===----------------------------------------------------------------------===// -// Reduction op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @mean -func.func @mean(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { - // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : (tensor<4x8xf16>) -> tensor<4x8xf32> - // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<-0.000000e+00> : tensor - // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.add across dimensions = [1] : (tensor<4x8xf32>, tensor) -> tensor<4xf32> - // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %{{.*}} {broadcast_dimensions = array} : (tensor<4xf32>, tensor) -> tensor<4xf32> - // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[MEAN]] : (tensor<4xf32>) -> tensor<4xf16> - // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> - // CHECK: return %[[RESULT]] : tensor<4x1xf16> - %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> - %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> - func.return %0 : tensor<4x1xf16> -} - -// ----- - -// CHECK-LABEL: func @mean_scalar_dim -func.func @mean_scalar_dim(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { - // Verify that tf.Mean op with scalar attributes are lowered successfully. - - // CHECK-NOT: tf.Mean - %dimension = "tf.Const"() { value = dense<1> : tensor } : () -> tensor - %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor) -> tensor<4x1xf16> - func.return %0 : tensor<4x1xf16> -} - -// ----- - -// CHECK-LABEL: func @mean_dynamic -func.func @mean_dynamic(%arg0: tensor) -> tensor { - // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : (tensor) -> tensor - // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<-0.000000e+00> : tensor - // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.add across dimensions = [1] : (tensor, tensor) -> tensor - // CHECK: %[[SHAPE0:.*]] = shape.shape_of %arg0 : tensor -> tensor<2xindex> - // CHECK-DAG: %[[C1_1:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index - // CHECK: %[[REDUCED_DIM:.*]] = tensor.extract %[[SHAPE0]][%[[C1_2]]] : tensor<2xindex> - // CHECK: %[[MUL:.*]] = arith.muli %[[C1_1]], %[[REDUCED_DIM]] : index - // CHECK: %[[INDEX_CAST:.*]] = arith.index_cast %[[MUL]] : index to i64 - // CHECK: %[[TENSOR:.*]] = tensor.from_elements %[[INDEX_CAST]] : tensor - // CHECK: %[[CONVERT:.*]] = mhlo.convert %[[TENSOR]] : (tensor) -> tensor - // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %[[CONVERT]] {broadcast_dimensions = array} : (tensor, tensor) -> tensor - // CHECK: %[[MEAN_CONVERTED:.*]] = mhlo.convert %[[MEAN]] : (tensor) -> tensor - // CHECK: %[[SHAPE1:.*]] = shape.shape_of %[[MEAN_CONVERTED]] : tensor -> tensor<1xindex> - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK: %[[UNREDUCED_DIM:.*]] = tensor.extract %[[SHAPE1]][%[[C0]]] : tensor<1xindex> - // CHECK: %[[RESULT_SHAPE:.*]] = tensor.from_elements %[[UNREDUCED_DIM]], %[[C1]] : tensor<2xindex> - // CHECK: %[[RESULT:.*]] = mhlo.dynamic_reshape %[[MEAN_CONVERTED]], %[[RESULT_SHAPE]] : (tensor, tensor<2xindex>) -> tensor - // CHECK: return %[[RESULT]] : tensor - %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> - %0 = "tf.Mean"(%arg0, %dimension) { keep_dims = true }: (tensor, tensor<1xi64>) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @sum -func.func @sum(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { - // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : (tensor<4x8xf16>) -> tensor<4x8xf32> - // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<-0.000000e+00> : tensor - // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.add across dimensions = [1] : (tensor<4x8xf32>, tensor) -> tensor<4xf32> - // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[REDUCED]] : (tensor<4xf32>) -> tensor<4xf16> - // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> - // CHECK: return %[[RESULT]] : tensor<4x1xf16> - %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> - %0 = "tf.Sum"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> - func.return %0 : tensor<4x1xf16> -} - -// ----- - -// CHECK-LABEL: func @sum_dynamic -func.func @sum_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> { - // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : (tensor<4x?xf16>) -> tensor<4x?xf32> - // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<-0.000000e+00> : tensor - // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.add across dimensions = [1] : (tensor<4x?xf32>, tensor) -> tensor<4xf32> - // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[REDUCED]] : (tensor<4xf32>) -> tensor<4xf16> - // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> - // CHECK: return %[[RESULT]] : tensor<4x1xf16> - %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> - %0 = "tf.Sum"(%arg0, %dimension) { keep_dims = true }: (tensor<4x?xf16>, tensor<1xi64>) -> tensor<4x1xf16> - func.return %0 : tensor<4x1xf16> -} - -// ----- - -// CHECK-LABEL: func @max -func.func @max(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { - // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : tensor<4x8xf16> - // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0xFC00> : tensor - // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.maximum across dimensions = [1] : (tensor<4x8xf16>, tensor) -> tensor<4xf16> - // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[REDUCED]] : tensor<4xf16> - // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> - // CHECK: return %[[RESULT]] : tensor<4x1xf16> - %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> - %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> - func.return %0 : tensor<4x1xf16> -} - -// ----- - -// CHECK-LABEL: func @max_qint -// Regression test to ensure we don't crash getting the initial value for -// tf.Max when using quantized integer types. -func.func @max_qint(%arg0: tensor<4x8x!tf_type.qint8>) -> tensor<4x1x!tf_type.qint8> { - %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> - %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf_type.qint8>, tensor<1xi64>) -> tensor<4x1x!tf_type.qint8> - func.return %0 : tensor<4x1x!tf_type.qint8> -} - -// ----- - -// CHECK-LABEL: func @max_dynamic -func.func @max_dynamic(%arg0: tensor<4x?xf16>) -> tensor<4x1xf16> { - // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : tensor<4x?xf16> - // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0xFC00> : tensor - // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.maximum across dimensions = [1] : (tensor<4x?xf16>, tensor) -> tensor<4xf16> - // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[REDUCED]] : tensor<4xf16> - // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> - // CHECK: return %[[RESULT]] : tensor<4x1xf16> - %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> - %0 = "tf.Max"(%arg0, %dimension) { keep_dims = true }: (tensor<4x?xf16>, tensor<1xi64>) -> tensor<4x1xf16> - func.return %0 : tensor<4x1xf16> -} - -// ----- - -// CHECK-LABEL: func @min -func.func @min(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { - // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : tensor<4x8xf16> - // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<0x7C00> : tensor - // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.minimum across dimensions = [1] : (tensor<4x8xf16>, tensor) -> tensor<4xf16> - // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[REDUCED]] : tensor<4xf16> - // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> - // CHECK: return %[[RESULT]] : tensor<4x1xf16> - %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> - %0 = "tf.Min"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> - func.return %0 : tensor<4x1xf16> -} - -// ----- - -// CHECK-LABEL: func @min_qint -// Regression test to ensure we don't crash getting the initial value for -// tf.Min when using quantized integer types. -func.func @min_qint(%arg0: tensor<4x8x!tf_type.qint8>) -> tensor<4x1x!tf_type.qint8> { - %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> - %0 = "tf.Min"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf_type.qint8>, tensor<1xi64>) -> tensor<4x1x!tf_type.qint8> - func.return %0 : tensor<4x1x!tf_type.qint8> -} - -// ----- - -// CHECK-LABEL: func @prod -func.func @prod(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { - // CHECK: %[[CAST:.*]] = mhlo.convert %arg0 : (tensor<4x8xf16>) -> tensor<4x8xf32> - // CHECK: %[[INITIAL:.*]] = mhlo.constant dense<1.000000e+00> : tensor - // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%[[CAST]] init: %[[INITIAL]]) applies mhlo.multiply across dimensions = [1] : (tensor<4x8xf32>, tensor) -> tensor<4xf32> - // CHECK: %[[CAST_BACK:.*]] = mhlo.convert %[[REDUCED]] : (tensor<4xf32>) -> tensor<4xf16> - // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[CAST_BACK]] : (tensor<4xf16>) -> tensor<4x1xf16> - // CHECK: return %[[RESULT]] : tensor<4x1xf16> - %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> - %0 = "tf.Prod"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8xf16>, tensor<1xi64>) -> tensor<4x1xf16> - func.return %0 : tensor<4x1xf16> -} - -// ----- - -// CHECK-LABEL: func @prod_qint -// Regression test to ensure we don't crash getting the initial value for -// tf.Prod when using quantized integer types. -func.func @prod_qint(%arg0: tensor<4x8x!tf_type.qint8>) -> tensor<4x1x!tf_type.qint8> { - %dimension = "tf.Const"() { value = dense<1> : tensor<1xi64> } : () -> tensor<1xi64> - %0 = "tf.Prod"(%arg0, %dimension) { keep_dims = true }: (tensor<4x8x!tf_type.qint8>, tensor<1xi64>) -> tensor<4x1x!tf_type.qint8> - func.return %0 : tensor<4x1x!tf_type.qint8> -} - -// ----- - -// CHECK-LABEL: @all -func.func @all(%input: tensor<4x8xi1>) -> tensor<4xi1> { - %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %[[INIT:.*]] = mhlo.constant dense : tensor - // CHECK: %[[REDUCED:.*]] = mhlo.reduce(%{{.*}} init: %[[INIT]]) applies mhlo.and across dimensions = [1] : (tensor<4x8xi1>, tensor) -> tensor<4xi1> - %0 = "tf.All"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1> - func.return %0 : tensor<4xi1> -} - -// ----- - -// CHECK-LABEL: @all_keep_dim -func.func @all_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> { - // CHECK: mhlo.reshape %{{.*}} : (tensor<4xi1>) -> tensor<4x1xi1> - %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1> - func.return %0 : tensor<4x1xi1> -} - -// ----- - -// CHECK-LABEL: @all_dynamic -func.func @all_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> { - %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %[[ARG:.*]] = mhlo.convert %{{.*}} : tensor<4x?xi1> - // CHECK: mhlo.reduce(%[[ARG]] - %0 = "tf.All"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1> - func.return %0 : tensor<4x1xi1> -} - -// ----- - -// CHECK-LABEL: @any -func.func @any(%input: tensor<4x8xi1>) -> tensor<4xi1> { - %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %[[INIT:.*]] = mhlo.constant dense : tensor - // CHECK: mhlo.reduce(%{{.*}} init: %[[INIT]]) applies mhlo.or across dimensions = [1] : (tensor<4x8xi1>, tensor) -> tensor<4xi1> - %0 = "tf.Any"(%input, %dims) : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4xi1> - func.return %0 : tensor<4xi1> -} - -// ----- - -// CHECK-LABEL: @any_keep_dim -func.func @any_keep_dim(%input: tensor<4x8xi1>) -> tensor<4x1xi1> { - // CHECK: mhlo.reshape %{{.*}} : (tensor<4xi1>) -> tensor<4x1xi1> - %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x8xi1>, tensor<1xi32>) -> tensor<4x1xi1> - func.return %0 : tensor<4x1xi1> -} - -// ----- - -// CHECK-LABEL: @any_dynamic -func.func @any_dynamic(%input: tensor<4x?xi1>) -> tensor<4x1xi1> { - %dims = "tf.Const"() { value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %[[ARG:.*]] = mhlo.convert %{{.*}} : tensor<4x?xi1> - // CHECK: mhlo.reduce(%[[ARG]] - %0 = "tf.Any"(%input, %dims) {keep_dims = true} : (tensor<4x?xi1>, tensor<1xi32>) -> tensor<4x1xi1> - func.return %0 : tensor<4x1xi1> -} - -//===----------------------------------------------------------------------===// -// Tile op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @tile_by_reshape -func.func @tile_by_reshape(%arg0: tensor<4x8xf32>) -> tensor<28x24xf32> { - // CHECK: %[[BROADCASTED:.*]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>}> : (tensor<4x8xf32>) -> tensor<7x4x3x8xf32> - // CHECK: %[[RESULT:.*]] = mhlo.reshape %[[BROADCASTED]] : (tensor<7x4x3x8xf32>) -> tensor<28x24xf32> - // CHECK: return %[[RESULT]] : tensor<28x24xf32> - %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi64> } : () -> tensor<2xi64> - %0 = "tf.Tile"(%arg0, %multiples) : (tensor<4x8xf32>, tensor<2xi64>) -> tensor<28x24xf32> - func.return %0 : tensor<28x24xf32> -} - -// ----- - -// CHECK-LABEL: func @tile_just_broadcast -func.func @tile_just_broadcast(%arg0: tensor<1x1xf32>) -> tensor<7x3xf32> { - // CHECK: %[[RESULT:.*]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x1xf32>) -> tensor<7x3xf32> - // CHECK: return %[[RESULT]] : tensor<7x3xf32> - %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi64> } : () -> tensor<2xi64> - %0 = "tf.Tile"(%arg0, %multiples) : (tensor<1x1xf32>, tensor<2xi64>) -> tensor<7x3xf32> - func.return %0 : tensor<7x3xf32> -} - -// ----- - -// CHECK-LABEL: func @tile_dynamic_shape -func.func @tile_dynamic_shape(%arg0: tensor) -> tensor { - %multiples = "tf.Const"() { value = dense<[7,3]> : tensor<2xi32> } : () -> tensor<2xi32> - // CHECK: tensor.dim {{.*}} : tensor - // CHECK: tensor.from_elements {{.*}} : tensor<4xindex> - // CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}) <{broadcast_dimensions = dense<[1, 3]> : tensor<2xi64>}> : (tensor, tensor<4xindex>) -> tensor - // CHECK: muli {{.*}} : index - // CHECK: tensor.from_elements {{.*}} : tensor<2xindex> - // CHECK: mhlo.dynamic_reshape {{.*}} : (tensor, tensor<2xindex>) -> tensor - %0 = "tf.Tile"(%arg0, %multiples) : (tensor, tensor<2xi32>) -> tensor - func.return %0 : tensor -} - -//===----------------------------------------------------------------------===// -// ArgMax/ArgMin op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @argmax_i64_input_i32_output_axis_0 -func.func @argmax_i64_input_i32_output_axis_0(%arg0: tensor<3x7xi64>) -> tensor<7xi32> { - // CHECK: %[[INIT:.*]] = mhlo.constant dense<-9223372036854775808> : tensor - // CHECK-NEXT: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor - // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x7xi64> -> tensor<2xindex> - // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) <{iota_dimension = 0 : i64}> : (tensor<2xindex>) -> tensor<3x7xi32> - // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) - // CHECK: (%[[ARG1:.*]]: tensor, %[[ARG3:.*]]: tensor) (%[[ARG2:.*]]: tensor, %[[ARG4:.*]]: tensor) - // CHECK: %[[COMPARE:.*]] = mhlo.compare GE, %[[ARG1]], %[[ARG3]], NOTYPE : (tensor, tensor) -> tensor - // CHECK: %[[RESULT1:.*]] = mhlo.select %[[COMPARE]], %[[ARG1]], %[[ARG3]] : tensor, tensor - // CHECK: %[[COMPARE_EQ:.*]] = mhlo.compare EQ, %[[ARG1]], %[[ARG3]], NOTYPE : (tensor, tensor) -> tensor - // CHECK: %[[MIN:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]] - // CHECK: %[[RESULT2:.*]] = mhlo.select %[[COMPARE]], %[[ARG2]], %[[ARG4]] : tensor, tensor - // CHECK: %[[RESULT3:.*]] = mhlo.select %[[COMPARE_EQ]], %[[MIN]], %[[RESULT2]] : tensor, tensor - // CHECK: mhlo.return %[[RESULT1]], %[[RESULT3]] : tensor, tensor - // CHECK: return %[[REDUCE]]#1 : tensor<7xi32> - %axis = "tf.Const"() { value = dense<0> : tensor } : () -> tensor - %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xi64>, tensor) -> tensor<7xi32> - func.return %0 : tensor<7xi32> -} - -// ----- - -// CHECK-LABEL: func @argmax_f32_input_i64_output_axis_1 -func.func @argmax_f32_input_i64_output_axis_1(%arg0: tensor<3x7xf32>) -> tensor<3xi64> { - // CHECK: %[[INIT:.*]] = mhlo.constant dense<0xFF800000> : tensor - // CHECK-NEXT: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor - // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x7xf32> -> tensor<2xindex> - // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) <{iota_dimension = 1 : i64}> : (tensor<2xindex>) -> tensor<3x7xi64> - // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) - // CHECK: return %[[REDUCE]]#1 : tensor<3xi64> - %axis = "tf.Const"() { value = dense<1> : tensor } : () -> tensor - %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xf32>, tensor) -> tensor<3xi64> - func.return %0 : tensor<3xi64> -} - -// ----- - -// CHECK-LABEL: func @argmax_i1_input_i64_output_axis_1 -func.func @argmax_i1_input_i64_output_axis_1(%arg0: tensor<3x7xi1>) -> tensor<3xi64> { - // CHECK-DAG: %[[INIT:.*]] = mhlo.constant dense : tensor - // CHECK-DAG: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor - // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x7xi1> -> tensor<2xindex> - // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) <{iota_dimension = 1 : i64}> : (tensor<2xindex>) -> tensor<3x7xi64> - // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) - // CHECK: return %[[REDUCE]]#1 : tensor<3xi64> - %axis = "tf.Const"() { value = dense<1> : tensor } : () -> tensor - %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x7xi1>, tensor) -> tensor<3xi64> - func.return %0 : tensor<3xi64> -} - -// ----- - -// CHECK-LABEL: func @argmax_dynamic_shape_input_output -func.func @argmax_dynamic_shape_input_output(%arg0: tensor<3x?xi32>) -> tensor { - // CHECK: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor - // CHECK-NEXT: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor - // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x?xi32> -> tensor<2xindex> - // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) <{iota_dimension = 0 : i64}> : (tensor<2xindex>) -> tensor<3x?xi32> - // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) - // CHECK: return %[[REDUCE]]#1 : tensor - %axis = "tf.Const"() { value = dense<0> : tensor } : () -> tensor - %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x?xi32>, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @argmax_dynamic_shape_input -func.func @argmax_dynamic_shape_input(%arg0: tensor<3x?xi32>) -> tensor<3xi32> { - // CHECK-DAG: %[[INIT:.*]] = mhlo.constant dense<-2147483648> : tensor - // CHECK-DAG: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor - // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x?xi32> -> tensor<2xindex> - // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) <{iota_dimension = 1 : i64}> : (tensor<2xindex>) -> tensor<3x?xi32> - // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) - // CHECK: return %[[REDUCE]]#1 : tensor<3xi32> - %axis = "tf.Const"() { value = dense<1> : tensor } : () -> tensor - %0 = "tf.ArgMax"(%arg0, %axis) : (tensor<3x?xi32>, tensor) -> tensor<3xi32> - func.return %0 : tensor<3xi32> -} - -// ----- - -// CHECK-LABEL: func @argmin_i64_input_i32_output_axis_0 -func.func @argmin_i64_input_i32_output_axis_0(%arg0: tensor<3x7xi64>) -> tensor<7xi32> { - // CHECK: %[[INIT:.*]] = mhlo.constant dense<9223372036854775807> : tensor - // CHECK-NEXT: %[[INDEX_INIT:.*]] = mhlo.constant dense<0> : tensor - // CHECK: %[[SHAPE:.*]] = shape.shape_of %arg0 : tensor<3x7xi64> -> tensor<2xindex> - // CHECK: %[[INDEX:.*]] = "mhlo.dynamic_iota"(%[[SHAPE]]) <{iota_dimension = 0 : i64}> : (tensor<2xindex>) -> tensor<3x7xi32> - // CHECK: %[[REDUCE:.*]]:2 = mhlo.reduce(%arg0 init: %[[INIT]]), (%[[INDEX]] init: %[[INDEX_INIT]]) - // CHECK: (%[[ARG1:.*]]: tensor, %[[ARG3:.*]]: tensor) (%[[ARG2:.*]]: tensor, %[[ARG4:.*]]: tensor) - // CHECK: %[[COMPARE:.*]] = mhlo.compare LE, %[[ARG1]], %[[ARG3]], NOTYPE : (tensor, tensor) -> tensor - // CHECK: %[[RESULT1:.*]] = mhlo.select %[[COMPARE]], %[[ARG1]], %[[ARG3]] : tensor, tensor - // CHECK: %[[COMPARE_EQ:.*]] = mhlo.compare EQ, %[[ARG1]], %[[ARG3]], NOTYPE : (tensor, tensor) -> tensor - // CHECK: %[[MIN:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]] - // CHECK: %[[RESULT2:.*]] = mhlo.select %[[COMPARE]], %[[ARG2]], %[[ARG4]] : tensor, tensor - // CHECK: %[[RESULT3:.*]] = mhlo.select %[[COMPARE_EQ]], %[[MIN]], %[[RESULT2]] : tensor, tensor - // CHECK: mhlo.return %[[RESULT1]], %[[RESULT3]] : tensor, tensor - // CHECK: return %[[REDUCE]]#1 : tensor<7xi32> - %axis = "tf.Const"() { value = dense<0> : tensor } : () -> tensor - %0 = "tf.ArgMin"(%arg0, %axis) : (tensor<3x7xi64>, tensor) -> tensor<7xi32> - func.return %0 : tensor<7xi32> -} - -//===----------------------------------------------------------------------===// -// Random op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @rng_uniform -func.func @rng_uniform(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { - // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor - // CHECK: %[[CONV:.*]] = mhlo.convert %arg0 : (tensor<3xi32>) -> tensor<3xi64> - // CHECK: %[[F32:.*]] = "mhlo.rng"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*UNIFORM.*}} -> tensor<12x?x64xf32> - %0 = "tf.RandomUniform"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32> - // CHECK: return %[[F32]] - func.return %0 : tensor<12x?x64xf32> -} - -// ----- - -// CHECK-LABEL: func @random_uniform_simple -func.func @random_uniform_simple(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { - // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor - // CHECK: %[[CONV:.*]] = mhlo.convert %arg0 : (tensor<3xi32>) -> tensor<3xi64> - // CHECK: %[[F32:.*]] = "mhlo.rng"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*UNIFORM.*}} -> tensor<12x?x64xf32> - %0 = "tf.RandomUniform"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32> - // CHECK: return %[[F32]] - func.return %0 : tensor<12x?x64xf32> -} - -// ----- - -// CHECK-LABEL: func @random_uniform_with_seeds -func.func @random_uniform_with_seeds(%arg0: tensor<4xi32>) -> tensor<32x12x12x64xf32> { - // CHECK: %0 = mhlo.constant dense<[32, 12, 12, 64]> : tensor<4xi32> - // CHECK-NEXT: %1 = mhlo.constant dense<0.000000e+00> : tensor - // CHECK-NEXT: %2 = mhlo.constant dense<1.000000e+00> : tensor - // CHECK-NEXT: %3 = mhlo.convert %0 : (tensor<4xi32>) -> tensor<4xi64> - // CHECK-NEXT: %4 = "mhlo.rng"(%1, %2, %3) <{rng_distribution = #mhlo.rng_distribution}> : (tensor, tensor, tensor<4xi64>) -> tensor<32x12x12x64xf32> - %cst = "tf.Const"() {value = dense<[32, 12, 12, 64]> : tensor<4xi32>} : () -> tensor<4xi32> - %0 = "tf.RandomUniform"(%cst) {seed = 87654321 : i64, seed2 = 0 : i64} : (tensor<4xi32>) -> tensor<32x12x12x64xf32> - // CHECK: return %4 : tensor<32x12x12x64xf32> - func.return %0 : tensor<32x12x12x64xf32> -} - -// ----- - -// CHECK-LABEL: func @rng_std_normal -func.func @rng_std_normal(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { - // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor - // CHECK: %[[CONV:.*]] = mhlo.convert %arg0 : (tensor<3xi32>) -> tensor<3xi64> - // CHECK: %[[F32:.*]] = "mhlo.rng"(%[[ZERO]], %[[ONE]], %[[CONV]]) {{.*NORMAL.*}} -> tensor<12x?x64xf32> - %0 = "tf.RandomStandardNormal"(%arg0) : (tensor<3xi32>) -> tensor<12x?x64xf32> - // CHECK: return %[[F32]] - func.return %0 : tensor<12x?x64xf32> -} - -//===----------------------------------------------------------------------===// -// Range op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @range -// CHECK-SAME: [[START:%.*]]: tensor, [[DELTA:%.*]]: tensor -func.func @range(%arg0: tensor, %arg1: tensor) -> tensor<5xf32> { - %1 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "range/limit", value = dense<5.000000e+00> : tensor} : () -> tensor - // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota" - // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = array} - // CHECK: chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = array} - %3 = "tf.Range"(%arg0, %1, %arg1) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> tensor<5xf32> - func.return %3 : tensor<5xf32> -} - -// ----- - -// CHECK-LABEL: func @range_dynamic -// CHECK-SAME: [[START:%.*]]: tensor, [[DELTA:%.*]]: tensor -func.func @range_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract %arg1, %arg0 - // CHECK-DAG: [[ABS1:%.+]] = mhlo.abs [[SUB]] - // CHECK-DAG: [[CONVERT_1:%.+]] = mhlo.convert [[ABS1]] - // CHECK-DAG: [[CONVERT_2:%.+]] = mhlo.convert %arg2 - // CHECK-DAG: [[DIV:%.+]] = mhlo.divide [[CONVERT_1]], [[CONVERT_2]] - // CHECK-DAG: [[CEIL:%.+]] = mhlo.ceil [[DIV]] - // CHECK-DAG: [[CONVERT_3:%.+]] = mhlo.convert [[CEIL]] - // CHECK-DAG: [[RESHAPE:%.+]] = mhlo.reshape [[CONVERT_3]] - // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) <{iota_dimension = 0 : i64}> - // CHECK-DAG: [[CONVERT_3:%.+]] = mhlo.convert %arg0 - // CHECK-DAG: [[CONVERT_4:%.+]] = mhlo.convert %arg2 - // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT_4]] {broadcast_dimensions = array} - // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT_3]] {broadcast_dimensions = array} - %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> tensor - - // CHECK: return [[ADD]] - func.return %2 : tensor -} - -// ----- - -// CHECK-LABEL: func @range_int_dynamic -// CHECK-SAME: [[START:%.*]]: tensor, [[DELTA:%.*]]: tensor -func.func @range_int_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // CHECK-DAG: [[SUB:%.+]] = mhlo.subtract %arg1, %arg0 - // CHECK-DAG: [[ABS1:%.+]] = mhlo.abs [[SUB]] - // CHECK-DAG: [[CONVERT_1:%.+]] = mhlo.convert [[ABS1]] - // CHECK-DAG: [[CONVERT_2:%.+]] = mhlo.convert %arg2 - // CHECK-DAG: [[DIV:%.+]] = mhlo.divide [[CONVERT_1]], [[CONVERT_2]] - // CHECK-DAG: [[CEIL:%.+]] = mhlo.ceil [[DIV]] - // CHECK-DAG: [[CONVERT_3:%.+]] = mhlo.convert [[CEIL]] - // CHECK-DAG: [[RESHAPE:%.+]] = mhlo.reshape [[CONVERT_3]] - // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) <{iota_dimension = 0 : i64}> - // CHECK-DAG: [[CONVERT_3:%.+]] = mhlo.convert %arg0 - // CHECK-DAG: [[CONVERT_4:%.+]] = mhlo.convert %arg2 - // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT_4]] {broadcast_dimensions = array} - // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT_3]] {broadcast_dimensions = array} - %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> tensor - - // CHECK: return [[ADD]] - func.return %2 : tensor -} - -// ----- - -// CHECK-LABEL: func @linspace_static -// CHECK-SAME: [[START:%.*]]: tensor, [[STOP:%.*]]: tensor -func.func @linspace_static(%arg0: tensor, %arg1: tensor) -> tensor<4xf32> { - // CHECK-DAG: [[NUM:%.*]] = mhlo.constant dense<4> - // CHECK-DAG: [[NUM_F32:%.*]] = mhlo.convert [[NUM]] - // CHECK-DAG: [[ONE:%.*]] = mhlo.constant dense<1.000000e+00> - // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = chlo.broadcast_subtract [[NUM_F32]], [[ONE]] - // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = chlo.broadcast_subtract [[STOP]], [[START]] - // CHECK-DAG: [[STEP:%.*]] = chlo.broadcast_divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]] - // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota"() <{iota_dimension = 0 : i64}> - // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = array} - // CHECK-DAG: [[LINSPACE:%.*]] = chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = array} - // CHECK: return [[LINSPACE]] - %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<4> : tensor} : () -> tensor - %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor, tensor, tensor) -> tensor<4xf32> - func.return %1 : tensor<4xf32> -} - -// ----- - -// CHECK-LABEL: func @linspace_dynamic -func.func @linspace_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // CHECK: "tf.LinSpace" - %0 = "tf.LinSpace"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @linspace_invalid_num -func.func @linspace_invalid_num(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: mhlo.constant dense<> : tensor<0xi32> - // CHECK: "tf.LinSpace" - %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<> : tensor<0xi32>} : () -> tensor<0xi32> - %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor, tensor, tensor<0xi32>) -> tensor - func.return %1 : tensor -} - -//===----------------------------------------------------------------------===// -// LegacyCall op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -func.func @identity_func(%arg0: tensor<10x2xf32>) -> tensor<10x2xf32> { - func.return %arg0: tensor<10x2xf32> -} - -// CHECK-LABEL: testSimpleLegacyCallOp -func.func @testSimpleLegacyCallOp(%arg0: tensor<10x2xf32>) -> tensor<10x2xf32> { - // CHECK: %[[RESULT:.*]] = call @identity_func(%arg0) : (tensor<10x2xf32>) -> tensor<10x2xf32> - %0 = "tf.LegacyCall"(%arg0) {f = @identity_func} : (tensor<10x2xf32>) -> tensor<10x2xf32> - // CHECK: return %[[RESULT]] - func.return %0: tensor<10x2xf32> -} - -// ----- - -func.func @select_first(%arg0: tensor<10x2xf32>, %arg1: tensor<10x2xf32>) -> tensor<10x2xf32> { - func.return %arg0: tensor<10x2xf32> -} - -// CHECK-LABEL: testMultiInputLegacyCallOp -func.func @testMultiInputLegacyCallOp(%arg0: tensor<10x2xf32>, %arg1: tensor<10x2xf32>) -> tensor<10x2xf32> { - // CHECK: %[[RESULT:.*]] = call @select_first(%arg0, %arg1) : (tensor<10x2xf32>, tensor<10x2xf32>) -> tensor<10x2xf32> - %0 = "tf.LegacyCall"(%arg0, %arg1) {_disable_call_shape_inference = true, _tpu_replicate = "cluster", device = "", f = @select_first} : (tensor<10x2xf32>, tensor<10x2xf32>) -> tensor<10x2xf32> - // CHECK: return %[[RESULT]] - func.return %0: tensor<10x2xf32> -} - -//===----------------------------------------------------------------------===// -// Conv op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: conv_simple -func.func @conv_simple(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32> { - - // CHECK: mhlo.convolution(%arg0, %arg1) - // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] - // CHECK-SAME{LITERAL}: window = {stride = [4, 5], pad = [[0, 1], [2, 3]], rhs_dilate = [2, 3]} - // CHECK-SAME: batch_group_count = 1 - // CHECK-SAME: feature_group_count = 2 - - %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor<256x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<256x8x7x16xf32> - func.return %0 : tensor<256x8x7x16xf32> -} - -// ----- - -// CHECK-LABEL: conv3d_simple -func.func @conv3d_simple(%arg0: tensor<256x32x32x32x6xf32>, %arg1: tensor<3x3x3x3x16xf32>) -> tensor<256x7x6x5x16xf32> { - - // CHECK: mhlo.convolution(%arg0, %arg1) - // CHECK-SAME: dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f] - // CHECK-SAME{LITERAL}: window = {stride = [5, 6, 7], pad = [[1, 2], [2, 3], [2, 3]], rhs_dilate = [2, 3, 4]} - // CHECK-SAME: batch_group_count = 1 - // CHECK-SAME: feature_group_count = 2 - - %0 = "tf.Conv3D"(%arg0, %arg1) {data_format = "NDHWC", dilations = [1, 2, 3, 4, 1], padding = "SAME", strides = [1, 5, 6, 7, 1]} : (tensor<256x32x32x32x6xf32>, tensor<3x3x3x3x16xf32>) -> tensor<256x7x6x5x16xf32> - func.return %0 : tensor<256x7x6x5x16xf32> -} - -// ----- - -// CHECK-LABEL: depthwiseconv_simple -func.func @depthwiseconv_simple(%arg0: tensor, %arg1: tensor<2x2x3x3xf32>) -> tensor { - // CHECK: %[[RESHAPED_FILTER:.*]] = mhlo.reshape %arg1 : (tensor<2x2x3x3xf32>) -> tensor<2x2x1x9xf32> - // CHECK: mhlo.convolution(%arg0, %[[RESHAPED_FILTER]]) - // CHECK-SAME: feature_group_count = 3 - %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) { - data_format = "NHWC", - device = "", - dilations = [1, 1, 1, 1], - explicit_paddings = [], - padding = "VALID", - strides = [1, 1, 1, 1] - } : (tensor, tensor<2x2x3x3xf32>) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: conv_valid_padding -func.func @conv_valid_padding(%arg0: tensor<1x4x5x1xf32>, %arg1: tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32> { - // CHECK: mhlo.convolution(%arg0, %arg1) - - %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x4x5x1xf32>, tensor<3x3x1x1xf32>) -> tensor<1x2x3x1xf32> - func.return %0 : tensor<1x2x3x1xf32> -} - -// ----- - -// CHECK-LABEL: conv_explicit_paddings -func.func @conv_explicit_paddings(%arg0: tensor<256x32x32x6xf32>, %arg1: tensor<3x3x3x16xf32>) -> tensor<256x9x7x16xf32> { - - // CHECK: mhlo.convolution(%arg0, %arg1) - // CHECK-SAME{LITERAL}: pad = [[6, 0], [3, 3]] - - %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "EXPLICIT", explicit_paddings = [0, 0, 6, 0, 3, 3, 0, 0], strides = [1, 4, 5, 1]} : (tensor<256x32x32x6xf32>, tensor<3x3x3x16xf32>) -> tensor<256x9x7x16xf32> - func.return %0 : tensor<256x9x7x16xf32> -} - -// ----- - -// CHECK-LABEL: @conv2d_backprop_input_dynamic -func.func @conv2d_backprop_input_dynamic(%filter: tensor<2x2x1x16xf32>, %out_backprop: tensor) -> tensor { - // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) <{dimensions = dense<[0, 1]> : tensor<2xi64>}> - // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg1, %[[REV_FILTER]]) - // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f] - // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} - // CHECK-SAME: batch_group_count = 1 : i64 - // CHECK-SAME: feature_group_count = 1 : i64 - // CHECK: return %[[RESULT]] - %cst_0_1d = "tf.Const"() {device = "", value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> - %cst_1_0d = "tf.Const"() {device = "", value = dense<1> : tensor} : () -> tensor - %cst_1_1d = "tf.Const"() {device = "", value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %cst_512_0d = "tf.Const"() {device = "", value = dense<512> : tensor} : () -> tensor - %out_backprop_shape = "tf.Shape"(%out_backprop) {device = ""} : (tensor) -> tensor<4xi32> - %batch_size = "tf.StridedSlice"(%out_backprop_shape, %cst_0_1d, %cst_1_1d, %cst_1_1d) {begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor - %input_shape = "tf.Pack"(%batch_size, %cst_512_0d, %cst_512_0d, %cst_1_0d) {axis = 0 : i64, device = ""} : (tensor, tensor, tensor, tensor) -> tensor<4xi32> - %result = "tf.Conv2DBackpropInput"(%input_shape, %filter, %out_backprop) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true} : (tensor<4xi32>, tensor<2x2x1x16xf32>, tensor) -> tensor - return %result : tensor -} - -// ----- - -// CHECK-LABEL: @conv2d_backprop_input -func.func @conv2d_backprop_input( - %filter: tensor<3x3x1x32xf32>, - %out_backprop: tensor<100x26x26x32xf32> - ) -> tensor<100x28x28x1xf32> { - // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) <{dimensions = dense<[0, 1]> : tensor<2xi64>}> - // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg1, %[[REV_FILTER]]) - // CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f] - // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - // CHECK-SAME: batch_group_count = 1 : i64 - // CHECK-SAME: feature_group_count = 1 : i64 - // CHECK: return %[[RESULT]] - %input_sizes = "tf.Const" () { value = dense<[100,28,28,1]> : tensor<4xi32> } : () -> tensor<4xi32> - %result = "tf.Conv2DBackpropInput"(%input_sizes, %filter, %out_backprop) { - data_format = "NHWC", - dilations = [1, 1, 1, 1], - explicit_paddings = [], - padding = "VALID", - strides = [1, 1, 1, 1], - use_cudnn_on_gpu = true - } : (tensor<4xi32>, tensor<3x3x1x32xf32>, tensor<100x26x26x32xf32>) -> tensor<100x28x28x1xf32> - func.return %result : tensor<100x28x28x1xf32> -} - -// ----- - -// CHECK-LABEL: @conv2d_backprop_input_grouped -func.func @conv2d_backprop_input_grouped( - %filter: tensor<2x2x5x21xf32>, - %out_backprop: tensor<5x2x2x21xf32> - ) -> tensor<5x3x3x15xf32> { - %input_sizes = "tf.Const" () { value = dense<[5, 3, 3, 15]> : tensor<4xi32> } : () -> tensor<4xi32> - - // Verify filter transformation for grouped convolution. - - // CHECK: %[[RESHAPE:.*]] = mhlo.reshape %arg0 : (tensor<2x2x5x21xf32>) -> tensor<2x2x5x3x7xf32> - // CHECK: %[[TRANSPOSE:.*]] = "mhlo.transpose"(%[[RESHAPE]]) - // CHECK-SAME: permutation = dense<[0, 1, 3, 2, 4]> - // CHECK-SAME: (tensor<2x2x5x3x7xf32>) -> tensor<2x2x3x5x7xf32> - // CHECK: mhlo.reshape %[[TRANSPOSE]] : (tensor<2x2x3x5x7xf32>) -> tensor<2x2x15x7xf32> - - %result = "tf.Conv2DBackpropInput"(%input_sizes, %filter, %out_backprop) { - data_format = "NHWC", - dilations = [1, 1, 1, 1], - explicit_paddings = [], - padding = "VALID", - strides = [1, 1, 1, 1], - use_cudnn_on_gpu = true - } : (tensor<4xi32>, tensor<2x2x5x21xf32>, tensor<5x2x2x21xf32>) -> tensor<5x3x3x15xf32> - func.return %result : tensor<5x3x3x15xf32> -} - - -// CHECK-LABEL: @conv3d_backprop_input -func.func @conv3d_backprop_input(%filter: tensor<3x3x3x1x6xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> { - // CHECK: %[[REV_FILTER:.*]] = "mhlo.reverse"(%arg0) <{dimensions = dense<[0, 1, 2]> : tensor<3xi64>}> - // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg1, %[[REV_FILTER]]) - // CHECK-SAME: dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, o, i]->[b, 0, 1, 2, f] - // CHECK-SAME{LITERAL}: window = {stride = [1, 1, 1], pad = [[1, 1], [1, 1], [1, 1]], lhs_dilate = [1, 1, 1], rhs_dilate = [1, 1, 1]} - // CHECK-SAME: batch_group_count = 1 : i64, - // CHECK-SAME: feature_group_count = 1 : i64 - - // CHECK: return %[[RESULT]] - %input_sizes = "tf.Const" () {value = dense<[2, 8, 8, 8, 1]> : tensor<5xi32>} : () -> tensor<5xi32> - %result = "tf.Conv3DBackpropInputV2"(%input_sizes, %filter, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<5xi32>, tensor<3x3x3x1x6xf32>, tensor<2x8x8x8x6xf32>) -> tensor<2x8x8x8x1xf32> - func.return %result : tensor<2x8x8x8x1xf32> -} - -// ----- - -// CHECK-LABEL: @conv2d_backprop_filter -func.func @conv2d_backprop_filter( - %input: tensor<100x28x28x1xf32>, - %out_backprop: tensor<100x26x26x32xf32> - ) -> tensor<3x3x1x32xf32> { - // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1) - // CHECK-SAME: dim_numbers = [f, 0, 1, b]x[i, 0, 1, o]->[0, 1, b, f] - // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} - // CHECK-SAME: batch_group_count = 1 : i64 - // CHECK-SAME: feature_group_count = 1 : i64 - // CHECK: return %[[RESULT]] - %filter_sizes = "tf.Const" () { value = dense<[3,3,1,32]> : tensor<4xi32> } : () -> tensor<4xi32> - %result = "tf.Conv2DBackpropFilter"(%input, %filter_sizes, %out_backprop) { - data_format = "NHWC", - dilations = [1, 1, 1, 1], - explicit_paddings = [], - padding = "VALID", - strides = [1, 1, 1, 1], - use_cudnn_on_gpu = true - } : (tensor<100x28x28x1xf32>, tensor<4xi32>, tensor<100x26x26x32xf32>) -> tensor<3x3x1x32xf32> - func.return %result : tensor<3x3x1x32xf32> -} - -// ----- - -// CHECK-LABEL: @conv2d_backprop_filter_grouped -func.func @conv2d_backprop_filter_grouped( - %input: tensor<1x2x2x2xf32>, - %out_backprop: tensor<1x1x1x2xf32> - ) -> tensor<2x2x1x2xf32> { - - // CHECK: mhlo.convolution(%arg0, %arg1) - // CHECK-SAME: batch_group_count = 2 : i64 - // CHECK-SAME: feature_group_count = 1 : i64 - - %filter_sizes = "tf.Const" () { value = dense<[2, 2, 1, 2]> : tensor<4xi32> } : () -> tensor<4xi32> - %result = "tf.Conv2DBackpropFilter"(%input, %filter_sizes, %out_backprop) { - data_format = "NHWC", - dilations = [1, 1, 1, 1], - explicit_paddings = [], - padding = "VALID", - strides = [1, 1, 1, 1], - use_cudnn_on_gpu = true - } : (tensor<1x2x2x2xf32>, tensor<4xi32>, tensor<1x1x1x2xf32>) -> tensor<2x2x1x2xf32> - func.return %result : tensor<2x2x1x2xf32> -} - - -// CHECK-LABEL: @conv3d_backprop_filter -func.func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tensor<2x8x8x8x6xf32>) -> tensor<3x3x3x1x6xf32> { - // CHECK: %[[RESULT:.*]] = mhlo.convolution(%arg0, %arg1) - // CHECK-SAME: dim_numbers = [f, 0, 1, 2, b]x[i, 0, 1, 2, o]->[0, 1, 2, b, f] - // CHECK-SAME{LITERAL}: window = {stride = [1, 1, 1], pad = [[1, 1], [1, 1], [1, 1]], lhs_dilate = [1, 1, 1], rhs_dilate = [1, 1, 1]} - // CHECK-SAME: batch_group_count = 1 : i64 - // CHECK-SAME: feature_group_count = 1 : i64 - // CHECK: return %[[RESULT]] - %filter_sizes = "tf.Const"() {value = dense<[3, 3, 3, 1, 6]> : tensor<5xi32>} : () -> tensor<5xi32> - %result = "tf.Conv3DBackpropFilterV2"(%input, %filter_sizes, %out_backprop) {data_format = "NDHWC", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1, 1]} : (tensor<2x8x8x8x1xf32>, tensor<5xi32>, tensor<2x8x8x8x6xf32>) -> tensor<3x3x3x1x6xf32> - func.return %result : tensor<3x3x3x1x6xf32> -} - -// ----- - -// CHECK-LABEL: @collective_permute -func.func @collective_permute(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { - %source_target_pairs = "tf.Const" () { - value = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi32> - } : () -> tensor<3x2xi32> - - // CHECK: "mhlo.collective_permute" - // CHECK-SAME: source_target_pairs = dense<{{\[}}[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> - %0 = "tf.CollectivePermute"(%arg0, %source_target_pairs) { - } : (tensor<128x32xf32>, tensor<3x2xi32>) -> tensor<128x32xf32> - - func.return %0 : tensor<128x32xf32> -} - -// ----- - -// CHECK-LABEL: @cross_replica_sum -func.func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> { - %replica_groups = "tf.Const" () { - value = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32> - } : () -> tensor<2x4xi32> - - // CHECK: mhlo.cross-replica-sum - // CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - %result = "tf.CrossReplicaSum" (%input, %replica_groups) : (tensor<10xf32>, tensor<2x4xi32>) -> tensor<10xf32> - func.return %result : tensor<10xf32> -} - -// ----- - -// CHECK-LABEL: conv_dynamic -func.func @conv_dynamic(%arg0: tensor, %arg1: tensor<3x3x3x16xf32>) -> tensor { - // CHECK: "mhlo.dynamic_conv" - // CHECK-SAME: <{batch_group_count = 1 : i64, dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, feature_group_count = 2 : i64, precision_config = [#mhlo, #mhlo], rhs_dilation = dense<[2, 3]> : tensor<2xi64>, window_strides = dense<[4, 5]> : tensor<2xi64>}> : (tensor, tensor<3x3x3x16xf32>, tensor<4xi32>) -> tensor - %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 2, 3, 1], padding = "SAME", strides = [1, 4, 5, 1]} : (tensor, tensor<3x3x3x16xf32>) -> tensor - func.return %0 : tensor -} - -//===----------------------------------------------------------------------===// -// tf.Split legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @split_not_match_dynamic_split_dim_input -func.func @split_not_match_dynamic_split_dim_input(%input: tensor<4x4xf32>, %split_dim: tensor) -> (tensor<*xf32>, tensor<*xf32>) { - // CHECK: tf.Split - %0:2 = "tf.Split"(%split_dim, %input) : (tensor, tensor<4x4xf32>) -> (tensor<*xf32>, tensor<*xf32>) - func.return %0#0, %0#1 : tensor<*xf32>, tensor<*xf32> -} - -// ----- - -// CHECK-LABEL: @split_not_match_dynamic_input_shape -func.func @split_not_match_dynamic_input_shape(%input: tensor<4x?x4xf32>) -> (tensor<4x?x4xf32>, tensor<4x?x4xf32>) { - %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: tensor.dim {{.*}} : tensor<4x?x4xf32> - // CHECK: arith.divsi {{.*}} : index - // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> - // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor<4x?x4xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>) -> tensor<4x?x4xf32> - // CHECK: muli {{.*}} : index - // CHECK: muli {{.*}} : index - // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> - // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> - // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> - // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor<4x?x4xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>) -> tensor<4x?x4xf32> - %0:2 = "tf.Split"(%cst, %input) : (tensor, tensor<4x?x4xf32>) -> (tensor<4x?x4xf32>, tensor<4x?x4xf32>) - func.return %0#0, %0#1 : tensor<4x?x4xf32>, tensor<4x?x4xf32> -} - -// ----- - -// CHECK-LABEL: @split_not_match_static_split_dim_size -func.func @split_not_match_static_split_dim_size(%input: tensor<4x?x4xf32>) -> (tensor<2x?x4xf32>, tensor<2x?x4xf32>) { - %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - // CHECK: tensor.dim {{.*}} : tensor<4x?x4xf32> - // CHECK: arith.divsi {{.*}} : index - // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> - // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor<4x?x4xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>) -> tensor<2x?x4xf32> - // CHECK: muli {{.*}} : index - // CHECK: muli {{.*}} : index - // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> - // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> - // CHECK: tensor.from_elements {{.*}} : tensor<3xindex> - // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor<4x?x4xf32>, tensor<3xindex>, tensor<3xindex>, tensor<3xindex>) -> tensor<2x?x4xf32> - %0:2 = "tf.Split"(%cst, %input) : (tensor, tensor<4x?x4xf32>) -> (tensor<2x?x4xf32>, tensor<2x?x4xf32>) - func.return %0#0, %0#1 : tensor<2x?x4xf32>, tensor<2x?x4xf32> -} - -// ----- - -// CHECK-LABEL: @split_match_and_split_into_two -func.func @split_match_and_split_into_two(%input: tensor<4x6xf32>) -> (tensor<2x6xf32>, tensor<2x6xf32>) { - %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - // CHECK: %[[ONE:.*]] = "mhlo.slice"(%{{.*}}) <{limit_indices = dense<[2, 6]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<2x6xf32> - // CHECK: %[[TWO:.*]] = "mhlo.slice"(%{{.*}}) <{limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<2x6xf32> - %0:2 = "tf.Split"(%cst, %input) : (tensor, tensor<4x6xf32>) -> (tensor<2x6xf32>, tensor<2x6xf32>) - // CHECK: return %[[ONE]], %[[TWO]] - func.return %0#0, %0#1 : tensor<2x6xf32>, tensor<2x6xf32> -} - -// ----- - -// CHECK-LABEL: @split_match_and_split_into_three -// CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>) -func.func @split_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) { - %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: %[[ONE:.*]] = "mhlo.slice"(%[[ARG]]) <{limit_indices = dense<[4, 2]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<4x2xf32> - // CHECK: %[[TWO:.*]] = "mhlo.slice"(%[[ARG]]) <{limit_indices = dense<4> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<4x2xf32> - // CHECK: %[[THREE:.*]] = "mhlo.slice"(%[[ARG]]) <{limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 4]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<4x2xf32> - %0:3 = "tf.Split"(%cst, %input) : (tensor, tensor<4x6xf32>) -> (tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32>) - // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]] - func.return %0#0, %0#1, %0#2 : tensor<4x2xf32>, tensor<4x2xf32>, tensor<4x2xf32> -} - -//===----------------------------------------------------------------------===// -// tf.TopKV2 legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: topk_v2_non_const_k -func.func @topk_v2_non_const_k(%input: tensor<16xf32>, %k: tensor) -> (tensor, tensor) { - // CHECK: tf.TopKV2 - %0:2 = "tf.TopKV2"(%input, %k): (tensor<16xf32>, tensor) -> (tensor, tensor) - func.return %0#0, %0#1: tensor, tensor -} - -// ----- - -// CHECK-LABEL: topk_v2_unknown_input_last_dim -func.func @topk_v2_unknown_input_last_dim(%input: tensor<16x?xf32>) -> (tensor<16x?xf32>, tensor<16x?xi32>) { - %k = "tf.Const"() {value = dense<8> : tensor} : () -> tensor - // CHECK: tf.TopKV2 - %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x?xf32>, tensor) -> (tensor<16x?xf32>, tensor<16x?xi32>) - func.return %0#0, %0#1: tensor<16x?xf32>, tensor<16x?xi32> -} - -// ----- - -// CHECK-LABEL: topk_v2 -// CHECK-SAME: %[[INPUT:.*]]: tensor<16x16xf32> -func.func @topk_v2(%input: tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) { - %k = "tf.Const"() {value = dense<8> : tensor} : () -> tensor - - // CHECK: chlo.top_k(%[[INPUT]], k = 8) - %0:2 = "tf.TopKV2"(%input, %k): (tensor<16x16xf32>, tensor) -> (tensor<16x8xf32>, tensor<16x8xi32>) - func.return %0#0, %0#1: tensor<16x8xf32>, tensor<16x8xi32> -} - -//===----------------------------------------------------------------------===// -// tf.SplitV legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @splitv_match_and_split_into_three -// CHECK-SAME: (%[[ARG:.*]]: tensor<4x6xf32>) -func.func @splitv_match_and_split_into_three(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) { - %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32> - %split_dim = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: %[[ONE:.*]] = "mhlo.slice"(%[[ARG]]) <{limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<4x1xf32> - // CHECK: %[[TWO:.*]] = "mhlo.slice"(%[[ARG]]) <{limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<4x2xf32> - // CHECK: %[[THREE:.*]] = "mhlo.slice"(%[[ARG]]) <{limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x6xf32>) -> tensor<4x3xf32> - %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) - // CHECK: return %[[ONE]], %[[TWO]], %[[THREE]] - func.return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32> -} - -// ----- - -// CHECK-LABEL: @splitv_dynamic_dim_in_split_sizes -func.func @splitv_dynamic_dim_in_split_sizes(%input: tensor<4x6xf32>) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) { - %split_sizes = "tf.Const"() {value = dense<[1, -1, 3]> : tensor<3xi32>} : () -> tensor<3xi32> - %split_dim = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: limit_indices = dense<[4, 1]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64> - // CHECK: limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64> - // CHECK: limit_indices = dense<[4, 6]> : tensor<2xi64>, start_indices = dense<[0, 3]> : tensor<2xi64> - %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor<4x6xf32>, tensor<3xi32>, tensor) -> (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) - func.return %0#0, %0#1, %0#2 : tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32> -} - -// ----- - -// CHECK-LABEL: @splitv_dynamic -func.func @splitv_dynamic(%input: tensor) -> (tensor, tensor, tensor) { - %split_sizes = "tf.Const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32> - %split_dim = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: tf.SplitV - %0:3 = "tf.SplitV"(%input, %split_sizes, %split_dim) : (tensor, tensor<3xi32>, tensor) -> (tensor, tensor, tensor) - func.return %0#0, %0#1, %0#2 : tensor, tensor, tensor -} - -//===----------------------------------------------------------------------===// -// tf.Assert legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @assert -func.func @assert(%arg0: tensor, %arg1: tensor<*xf32>) { - // CHECK-NOT: tf.Assert - "tf.Assert"(%arg0, %arg1) {summarize = 1} : (tensor, tensor<*xf32>) -> () - func.return -} - -//===----------------------------------------------------------------------===// -// tf.Unpack legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @unpack -func.func @unpack(%input: tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) { - // CHECK: %[[SLICE1:.*]] = "mhlo.slice"(%{{.*}}) <{limit_indices = dense<[4, 1, 6]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> - // CHECK: %[[RES1:.*]] = mhlo.reshape %[[SLICE1]] : (tensor<4x1x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[SLICE2:.*]] = "mhlo.slice"(%{{.*}}) <{limit_indices = dense<[4, 2, 6]> : tensor<3xi64>, start_indices = dense<[0, 1, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> - // CHECK: %[[RES2:.*]] = mhlo.reshape %[[SLICE2]] : (tensor<4x1x6xf32>) -> tensor<4x6xf32> - // CHECK: %[[SLICE3:.*]] = "mhlo.slice"(%{{.*}}) <{limit_indices = dense<[4, 3, 6]> : tensor<3xi64>, start_indices = dense<[0, 2, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> : (tensor<4x3x6xf32>) -> tensor<4x1x6xf32> - // CHECK: %[[RES3:.*]] = mhlo.reshape %[[SLICE3]] : (tensor<4x1x6xf32>) -> tensor<4x6xf32> - - %0:3 = "tf.Unpack"(%input) {axis = 1} : (tensor<4x3x6xf32>) -> (tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32>) - // return %[[RES1]], %[[RES2]], %[[RES3]] - func.return %0#0, %0#1, %0#2 : tensor<4x6xf32>, tensor<4x6xf32>, tensor<4x6xf32> -} - -// ----- - -// CHECK-LABEL: func @unpack_dynamic -func.func @unpack_dynamic(%arg0: tensor) -> (tensor, tensor) { - // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor - // CHECK: tensor.from_elements {{.*}} : tensor<2xi32> - // CHECK: mhlo.dynamic_reshape {{.*}} : (tensor, tensor<2xi32>) -> tensor - // CHECK: tensor.from_elements {{.*}} : tensor<3xi32> - // CHECK: mhlo.real_dynamic_slice {{.*}} : (tensor, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor - // CHECK: tensor.from_elements {{.*}} : tensor<2xi32> - // CHECK: mhlo.dynamic_reshape {{.*}} : (tensor, tensor<2xi32>) -> tensor - // CHECK: return {{.*}} : tensor, tensor - %0:2 = "tf.Unpack"(%arg0) {axis = -1 : i64} : (tensor) -> (tensor, tensor) - func.return %0#0, %0#1 : tensor, tensor -} - -//===----------------------------------------------------------------------===// -// tf.UnsortedSegment{Max|Min|Prod|Sum} legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @unsorted_segment_sum -// CHECK-SAME: [[DATA:%.*]]: tensor<8x16x64xf32> -// CHECK-SAME: [[SI:%.*]]: tensor<8x16xi32> -func.func @unsorted_segment_sum(%data: tensor<8x16x64xf32>, %segment_ids : tensor<8x16xi32>) -> (tensor<4x64xf32>) { - %num_segments = "tf.Const"() {value = dense<4> : tensor} : () -> tensor - // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[INIT:%.*]] = "mhlo.broadcast"([[ZERO]]) <{broadcast_sizes = dense<[4, 64]> : tensor<2xi64>}> : (tensor) -> tensor<4x64xf32> - // CHECK: [[SCATTER:%.*]] = "mhlo.scatter"([[INIT]], [[SI]], [[DATA]]) - // CHECK-SAME: indices_are_sorted = false, - // CHECK-SAME: scatter_dimension_numbers = - // CHECK-SAME: update_window_dims = [2] - // CHECK-SAME: inserted_window_dims = [0] - // CHECK-SAME: scatter_dims_to_operand_dims = [0] - // CHECK-SAME: index_vector_dim = 2 - // CHECK-SAME: unique_indices = false - // CHECK: ^{{.*}}([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): - // CHECK: [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]] : tensor - // CHECK: mhlo.return [[ADD]] - // CHECK-NEXT: (tensor<4x64xf32>, tensor<8x16xi32>, tensor<8x16x64xf32>) -> tensor<4x64xf32> - // CHECK: return [[SCATTER]] - %0 = "tf.UnsortedSegmentSum"(%data, %segment_ids, %num_segments) : (tensor<8x16x64xf32>, tensor<8x16xi32>, tensor) -> (tensor<4x64xf32>) - func.return %0: tensor<4x64xf32> -} - -// ----- - -// CHECK-LABEL: @unsorted_segment_prod -// CHECK-SAME: [[DATA:%.*]]: tensor<8x?x64xf32> -// CHECK-SAME: [[SI:%.*]]: tensor -func.func @unsorted_segment_prod(%data: tensor<8x?x64xf32>, %segment_ids : tensor) -> (tensor<4x?xf32>) { - %num_segments = "tf.Const"() {value = dense<4> : tensor} : () -> tensor - // CHECK: [[ONE:%.*]] = mhlo.constant dense<1.000000e+00> : tensor - // CHECK: [[INIT:%.*]] = "mhlo.broadcast"([[ONE]]) <{broadcast_sizes = dense<[4, 64]> : tensor<2xi64>}> : (tensor) -> tensor<4x64xf32> - // CHECK: [[SCATTER:%.*]] = "mhlo.scatter"([[INIT]], [[SI]], [[DATA]]) - // CHECK-SAME: indices_are_sorted = false - // CHECK-SAME: scatter_dimension_numbers = - // CHECK-SAME: update_window_dims = [2] - // CHECK-SAME: inserted_window_dims = [0] - // CHECK-SAME: scatter_dims_to_operand_dims = [0] - // CHECK-SAME: index_vector_dim = 2 - // CHECK-SAME: unique_indices = false - // CHECK: ^{{.*}}([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): - // CHECK: [[MUL:%.*]] = mhlo.multiply [[LHS]], [[RHS]] : tensor - // CHECK: mhlo.return [[MUL]] - // CHECK-NEXT: (tensor<4x64xf32>, tensor, tensor<8x?x64xf32>) -> tensor<4x?xf32> - // CHECK: return [[SCATTER]] - %0 = "tf.UnsortedSegmentProd"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor, tensor) -> (tensor<4x?xf32>) - func.return %0: tensor<4x?xf32> -} - -// ----- - -// CHECK-LABEL: @unsorted_segment_min -func.func @unsorted_segment_min(%data: tensor<8x?x64xf32>, %segment_ids : tensor) -> (tensor<4x?xf32>) { - %num_segments = "tf.Const"() {value = dense<4> : tensor} : () -> tensor - // CHECK: mhlo.constant dense<3.40282347E+38> : tensor - // CHECK: mhlo.scatter - // CHECK: mhlo.minimum - %0 = "tf.UnsortedSegmentMin"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor, tensor) -> (tensor<4x?xf32>) - func.return %0: tensor<4x?xf32> -} - -// ----- - -// CHECK-LABEL: @unsorted_segment_max -func.func @unsorted_segment_max(%data: tensor<8x?x64xf32>, %segment_ids : tensor) -> (tensor<4x?xf32>) { - %num_segments = "tf.Const"() {value = dense<4> : tensor} : () -> tensor - // CHECK: mhlo.constant dense<-3.40282347E+38> : tensor - // CHECK: mhlo.scatter - // CHECK: mhlo.maximum - %0 = "tf.UnsortedSegmentMax"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor, tensor) -> (tensor<4x?xf32>) - func.return %0: tensor<4x?xf32> -} - -//===----------------------------------------------------------------------===// -// tf.GatherNd legalization -//===----------------------------------------------------------------------===// -// CHECK-LABEL: func @gatherNd_dynamic -func.func @gatherNd_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: tensor.dim - // CHECK: index_cast - // CHECK: tensor.from_elements - // CHECK: mhlo.dynamic_gather - // CHECK-SAME: dimension_numbers = - // CHECK-SAME: offset_dims = [2] - // CHECK-SAME: collapsed_slice_dims = [0, 1] - // CHECK-SAME: start_index_map = [0, 1] - // CHECK-SAME: index_vector_dim = 2 - // CHECK-SAME: indices_are_sorted = false - %0 = "tf.GatherNd"(%arg0, %arg1) {Tindices = i32, Tparams = i32, device = ""} : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: func @gatherNd_static -func.func @gatherNd_static(%arg0: tensor<2x4x128xf32>, %arg1: tensor<2x1xi32>) -> tensor<2x4x128xf32> { - // CHECK: "mhlo.gather"({{.*}}) <{ - // CHECK-SAME: dimension_numbers = - // CHECK-SAME: offset_dims = [1, 2] - // CHECK-SAME: collapsed_slice_dims = [0] - // CHECK-SAME: start_index_map = [0] - // CHECK-SAME: index_vector_dim = 1 - // CHECK-SAME: indices_are_sorted = false - // CHECK-SAME: slice_sizes = dense<[1, 4, 128]> - // CHECK-SAME: (tensor<2x4x128xf32>, tensor<2x1xi32>) -> tensor<2x4x128xf32> - %0 = "tf.GatherNd"(%arg0, %arg1) {Tindices = i32, Tparams = i32, device = ""} : (tensor<2x4x128xf32>, tensor<2x1xi32>) -> tensor<2x4x128xf32> - func.return %0 : tensor<2x4x128xf32> -} - -//===----------------------------------------------------------------------===// -// tf.GatherV2 legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @gather_v2 -// CHECK-SAME: %[[PARAMS:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]] -func.func @gather_v2(%params: tensor<16x2x3xf32>, %indices: tensor<16x5xi32>) -> tensor<16x2x5xf32> { - // CHECK: mhlo.torch_index_select - // CHECK-SAME: %[[PARAMS]], %[[INDICES]] - // CHECK-SAME: batch_dims = 1 - // CHECK-SAME: dim = 2 - %axis = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> - %1 = "tf.GatherV2"(%params, %indices, %axis) {batch_dims = -1 : i64} : (tensor<16x2x3xf32>, tensor<16x5xi32>, tensor<1xi32>) -> tensor<16x2x5xf32> - func.return %1 : tensor<16x2x5xf32> -} - -// ----- - -// CHECK-LABEL: @gather_v2_dynamic -// CHECK-SAME: %[[PARAMS:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]] -func.func @gather_v2_dynamic(%params: tensor, %indices: tensor) -> tensor { - // CHECK: mhlo.torch_index_select - // CHECK-SAME: %[[PARAMS]], %[[INDICES]] - // CHECK-SAME: batch_dims = 1 - // CHECK-SAME: dim = 2 - %axis = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> - %1 = "tf.GatherV2"(%params, %indices, %axis) {batch_dims = -1 : i64} : (tensor, tensor, tensor<1xi32>) -> tensor - func.return %1 : tensor -} - -// ----- - -// CHECK-LABEL: @gather_v2_dynamic_index_i64 -// CHECK-SAME: %[[PARAMS:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]] -func.func @gather_v2_dynamic_index_i64(%params: tensor, %indices: tensor) -> tensor { - // CHECK: mhlo.torch_index_select - // CHECK-SAME: %[[PARAMS]], %[[INDICES]] - // CHECK-SAME: batch_dims = 1 - // CHECK-SAME: dim = 2 - %axis = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> - %1 = "tf.GatherV2"(%params, %indices, %axis) {batch_dims = -1 : i64} : (tensor, tensor, tensor<1xi32>) -> tensor - func.return %1 : tensor -} - -// ----- - -// CHECK-LABEL: @gather_v2_dynamic_shape -// CHECK-SAME: %[[PARAMS:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]] -func.func @gather_v2_dynamic_shape(%params: tensor, %indices: tensor) -> tensor { - // CHECK: mhlo.torch_index_select - // CHECK-SAME: %[[PARAMS]], %[[INDICES]] - // CHECK-SAME: batch_dims = 1 - // CHECK-SAME: dim = 2 - %axis = "tf.Const"() { value = dense<[-1]> : tensor<1xi32> } : () -> tensor<1xi32> - %1 = "tf.GatherV2"(%params, %indices, %axis) {batch_dims = -1 : i64} : (tensor, tensor, tensor<1xi32>) -> tensor - func.return %1 : tensor -} - -//===----------------------------------------------------------------------===// -// tf.StridedSliceGrad legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: strided_slice_grad -// CHECK-SAME: [[GRAD:%.*]]: tensor<4x16x1022xf32> -func.func @strided_slice_grad(%grad: tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32> { - - // For StridedSlice - // Dim #: 0, 1, 2 - // Input shape: [4, 128, 1024] - // Begin: 1, 4, -3 - // End: 8, 65, 42 - // Stride: 1, 4, -1 - // Begin mask: 1, 0, 0 (= 1) - // End mask: 0, 0, 1 (= 4) - - // So result shape: - // Dim #0: begin mask (1) -> begin = 0; end 8 canonicalized to 4: so 4 - // Dim #1: 4 to 65 stride 4: so 16 - // Dim #2: begin -3 + 1024 = 1021; end mask (1) -> end = -1: so 1022 - // result shape: [4, 16, 1022] - - // To pad back: - // Dim #: 0, 1, 2 - // Pad low: 0, 4, 0 - // Pad interm: 0, 3, 0 - // Pad high: 0, 63, 2 - - %shape = "tf.Const"() {value = dense<[4, 128, 1024]> : tensor<3xi32>} : () -> (tensor<3xi32>) - %begin = "tf.Const"() {value = dense<[1, 4, -3]> : tensor<3xi32>} : () -> (tensor<3xi32>) - %end = "tf.Const"() {value = dense<[8, 65, 42]> : tensor<3xi32>} : () -> (tensor<3xi32>) - %strides = "tf.Const"() {value = dense<[1, 4, -1]> : tensor<3xi32>} : () -> (tensor<3xi32>) - - // CHECK: [[RESHAPE:%.*]] = mhlo.reshape %arg0 : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32> - // CHECK: [[REVERSE:%.*]] = "mhlo.reverse"([[RESHAPE]]) <{dimensions = dense<2> : tensor<1xi64>}> : (tensor<4x16x1022xf32>) -> tensor<4x16x1022xf32> - // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REVERSE]], [[ZERO]]) <{edge_padding_high = dense<[0, 63, 2]> : tensor<3xi64>, edge_padding_low = dense<[0, 4, 0]> : tensor<3xi64>, interior_padding = dense<[0, 3, 0]> : tensor<3xi64>}> : (tensor<4x16x1022xf32>, tensor) -> tensor<4x128x1024xf32> - - %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 1, end_mask = 4} : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<4x16x1022xf32>) -> tensor<4x128x1024xf32> - // CHECK: return [[PAD]] - func.return %0: tensor<4x128x1024xf32> -} - -// ----- - -// CHECK-LABEL: strided_slice_grad_shrink_axis_mask -// CHECK-SAME: [[GRAD:%.*]]: tensor<8xf32> -func.func @strided_slice_grad_shrink_axis_mask(%grad: tensor<8xf32>) -> tensor<4x8xf32> { - // Input to StridedSlice was of shape 4x8xf32 - // Strided slice gets input[2:3, 0:8] - // shrink_axis_mask is 1 denoting that dim#0 is shrunk. So the output is 8xf32 - // which is the shape of gradient. - // StridedSliceGrad would reshape the gradient to 1x8xf32 and - // then pad to match the shape of input 4x8xf32. - - %shape = "tf.Const"() {value = dense<[4, 8]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %begin = "tf.Const"() {value = dense<[2, 0]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %end = "tf.Const"() {value = dense<[3, 8]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>) - - // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[GRAD]] : (tensor<8xf32>) -> tensor<1x8xf32> - // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]]) - // CHECK-DAG-SAME: edge_padding_low = dense<[2, 0]> : tensor<2xi64> - // CHECK-DAG-SAME: edge_padding_high = dense<[1, 0]> : tensor<2xi64> - // CHECK-DAG-SAME: interior_padding = dense<0> : tensor<2xi64> - %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, shrink_axis_mask = 1} : (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<8xf32>) -> tensor<4x8xf32> - - // CHECK: return [[PAD]] : tensor<4x8xf32> - func.return %0 : tensor<4x8xf32> -} - -// ----- - -// CHECK-LABEL: strided_slice_grad_new_axis_mask -// CHECK-SAME: [[GRAD:%.*]]: tensor<1x2xf32> -func.func @strided_slice_grad_new_axis_mask(%grad: tensor<1x2xf32>) -> tensor<8xf32> { - // Input to StridedSlice was of shape 8xf32 - // Strided slice gets input[tf.new_axis, 2:4] - // new_axis_mask is 1 denoting new axis is inserted at dim#0. So the output is - // 1x2xf32 which is the shape of gradient. - // StridedSliceGrad would reshape the gradient to 2xf32 and - // then pad to match the shape of input 4x8xf32. - - %shape = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> (tensor<1xi32>) - %begin = "tf.Const"() {value = dense<[0, 2]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %end = "tf.Const"() {value = dense<[0, 4]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>) - - // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[GRAD]] : (tensor<1x2xf32>) -> tensor<2xf32> - // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]]) - // CHECK-DAG-SAME: edge_padding_low = dense<2> : tensor<1xi64> - // CHECK-DAG-SAME: edge_padding_high = dense<4> : tensor<1xi64> - // CHECK-DAG-SAME: interior_padding = dense<0> : tensor<1xi64> - %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, new_axis_mask = 1} : (tensor<1xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<1x2xf32>) -> tensor<8xf32> - - // CHECK: return [[PAD]] : tensor<8xf32> - func.return %0 : tensor<8xf32> -} - -// ----- - -// CHECK-LABEL: strided_slice_grad_ellipsis_mask -// CHECK-SAME: [[GRAD:%.*]]: tensor<2x4x8xf32> -func.func @strided_slice_grad_ellipsis_mask(%grad: tensor<2x4x8xf32>) -> tensor<4x4x8xf32> { - // Input to StridedSlice was of shape 4x4x8xf32 - // Strided slice gets input[2:4, ...] - // ellipsis_mask is 2 denoting that slice contains all elements in dim#1 and - // dim#2, ignoring begin and end indices for these dimensions. So the output - // is 2x4x8xf32 which is the shape of gradient. - // StridedSliceGrad would pad the gradient to match the shape of - // input 4x4x8xf32. - - %shape = "tf.Const"() {value = dense<[4, 4, 8]> : tensor<3xi32>} : () -> (tensor<3xi32>) - %begin = "tf.Const"() {value = dense<[2, 3]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %end = "tf.Const"() {value = dense<[4, 5]> : tensor<2xi32>} : () -> (tensor<2xi32>) - %strides = "tf.Const"() {value = dense<1> : tensor<2xi32>} : () -> (tensor<2xi32>) - - // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[GRAD]] : (tensor<2x4x8xf32>) -> tensor<2x4x8xf32> - // CHECK: [[ZEROS:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZEROS]]) - // CHECK-DAG-SAME: edge_padding_low = dense<[2, 0, 0]> : tensor<3xi64> - // CHECK-DAG-SAME: edge_padding_high = dense<0> : tensor<3xi64> - // CHECK-DAG-SAME: interior_padding = dense<0> : tensor<3xi64> - %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 0, end_mask = 0, ellipsis_mask = 2} : (tensor<3xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>, tensor<2x4x8xf32>) -> tensor<4x4x8xf32> - - // CHECK: return [[PAD]] : tensor<4x4x8xf32> - func.return %0 : tensor<4x4x8xf32> -} - - -// CHECK-LABEL: strided_slice_grad_all_masks -// CHECK-SAME: [[GRAD:%.*]]: tensor<1x4x8x8x10x2x1xf32> -func.func @strided_slice_grad_all_masks(%grad: tensor<1x4x8x8x10x2x1xf32>) -> tensor<2x4x8x16x32x64xf32> { - // For StridedSlice input[1, tf.new_axis, ..., 8:, :10, 2:6:2, tf.new_axis] - // New axis mask is at index 1 and 6 of sparse spec, so - // new_axis_mask = 2^1 + 2^6 = 66 - // The ellipsis mask is applied to dim #1, #2 of input i.e, we get - // canonicalized slice input[1, :, :, 8:, :10, 2:6:2] - // The StridedSliceGrad op would propogate the gradient for the sliced tensor - // to the original input tensor by padding with zeroes. - - %shape = "tf.Const"() {value = dense<[2, 4, 8, 16, 32, 64]> : tensor<6xi32>} : () -> (tensor<6xi32>) - %begin = "tf.Const"() {value = dense<[1, 0, 0, 8, 1, 2, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>) - %end = "tf.Const"() {value = dense<[2, 0, 0, 10, 10, 6, 0]> : tensor<7xi32>} : () -> (tensor<7xi32>) - %strides = "tf.Const"() {value = dense<[1, 1, 1, 1, 1, 2, 1]> : tensor<7xi32>} : () -> (tensor<7xi32>) - - // Remove 2 new axes (at index 1 and 6) and 1 shrink axis (at index 0) - // CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[GRAD]] : (tensor<1x4x8x8x10x2x1xf32>) -> tensor<1x4x8x8x10x2xf32> - // CHECK: [[ZERO:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // The edge_padding_low, edge_padding_high and interior_padding attributes of - // mhlo.pad would reflect the padding required to get the shape of the - // input of StridedSlice op. - // CHECK: [[PAD:%.*]] = "mhlo.pad"([[RESHAPE]], [[ZERO]]) - // CHECK-DAG-SAME: edge_padding_low = dense<[1, 0, 0, 8, 0, 2]> : tensor<6xi64> - // CHECK-DAG-SAME: edge_padding_high = dense<[0, 0, 0, 0, 22, 59]> : tensor<6xi64> - // CHECK-DAG-SAME: interior_padding = dense<[0, 0, 0, 0, 0, 1]> : tensor<6xi64> - %0 = "tf.StridedSliceGrad"(%shape, %begin, %end, %strides, %grad) {begin_mask = 16, end_mask = 8, shrink_axis_mask = 1, ellipsis_mask = 4, new_axis_mask = 66} : (tensor<6xi32>, tensor<7xi32>, tensor<7xi32>, tensor<7xi32>, tensor<1x4x8x8x10x2x1xf32>) -> tensor<2x4x8x16x32x64xf32> - - // CHECK: return [[PAD]] : tensor<2x4x8x16x32x64xf32> - func.return %0 : tensor<2x4x8x16x32x64xf32> -} - -// ----- - -// CHECK-LABEL: @tensor_scatter_update -func.func @tensor_scatter_update(%tensor: tensor, %indices: tensor, %updates: tensor) -> tensor { - // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) - // CHECK-SAME: indices_are_sorted = false - // CHECK-SAME: scatter_dimension_numbers - // CHECK-SAME: update_window_dims = [1] - // CHECK-SAME: inserted_window_dims = [0, 1] - // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] - // CHECK-SAME: index_vector_dim = 1 - // CHECK-SAME: unique_indices = false - // CHECK: ^bb0(%arg3: tensor, %arg4: tensor): - // CHECK: mhlo.return %arg4 : tensor - // CHECK: }) - %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor, tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @tensor_scatter_update_scalar_update -func.func @tensor_scatter_update_scalar_update(%tensor: tensor<4x3xi32>, %indices: tensor<2x1xi32>, %updates: tensor) -> tensor<4x3xi32> { - // CHECK: mhlo.constant dense<[2, 3]> : tensor<2xi64> - // CHECK: "mhlo.dynamic_broadcast_in_dim"(%arg2, %0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<2xi64>) -> tensor<2x3xi32> - // CHECK: "mhlo.scatter" - %0 = "tf.TensorScatterUpdate"(%tensor, %indices, %updates) : (tensor<4x3xi32>, tensor<2x1xi32>, tensor) -> tensor<4x3xi32> - func.return %0 : tensor<4x3xi32> -} - -// ----- - -// CHECK-LABEL: @tensor_scatter_add -func.func @tensor_scatter_add(%tensor: tensor, %indices: tensor, %updates: tensor) -> tensor { - // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) - // CHECK-SAME: indices_are_sorted = false - // CHECK-SAME: scatter_dimension_numbers - // CHECK-SAME: update_window_dims = [1] - // CHECK-SAME: inserted_window_dims = [0, 1] - // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] - // CHECK-SAME: index_vector_dim = 1 - // CHECK-SAME: unique_indices = false - // CHECK: ^bb0(%arg3: tensor, %arg4: tensor): - // CHECK: %1 = mhlo.add %arg3, %arg4 : tensor - // CHECK: mhlo.return %1 : tensor - // CHECK: }) - %0 = "tf.TensorScatterAdd"(%tensor, %indices, %updates) : (tensor, tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @tensor_scatter_add_scalar_update -func.func @tensor_scatter_add_scalar_update(%tensor: tensor<4x3xi32>, %indices: tensor<2x1xi32>, %updates: tensor) -> tensor<4x3xi32> { - // CHECK: mhlo.constant dense<[2, 3]> : tensor<2xi64> - // CHECK: "mhlo.dynamic_broadcast_in_dim"(%arg2, %0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<2xi64>) -> tensor<2x3xi32> - // CHECK: "mhlo.scatter - %0 = "tf.TensorScatterAdd"(%tensor, %indices, %updates) : (tensor<4x3xi32>, tensor<2x1xi32>, tensor) -> tensor<4x3xi32> - func.return %0 : tensor<4x3xi32> -} - -// ----- - -// CHECK-LABEL: @tensor_scatter_sub -func.func @tensor_scatter_sub(%tensor: tensor, %indices: tensor, %updates: tensor) -> tensor { - // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) - // CHECK-SAME: indices_are_sorted = false - // CHECK-SAME: scatter_dimension_numbers - // CHECK-SAME: update_window_dims = [1] - // CHECK-SAME: inserted_window_dims = [0, 1] - // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] - // CHECK-SAME: index_vector_dim = 1 - // CHECK-SAME: unique_indices = false - // CHECK: ^bb0(%arg3: tensor, %arg4: tensor): - // CHECK: %1 = mhlo.subtract %arg3, %arg4 : tensor - // CHECK: mhlo.return %1 : tensor - // CHECK: }) - %0 = "tf.TensorScatterSub"(%tensor, %indices, %updates) : (tensor, tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @tensor_scatter_min -func.func @tensor_scatter_min(%tensor: tensor, %indices: tensor, %updates: tensor) -> tensor { - // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) - // CHECK-SAME: indices_are_sorted = false - // CHECK-SAME: scatter_dimension_numbers - // CHECK-SAME: update_window_dims = [1] - // CHECK-SAME: inserted_window_dims = [0, 1] - // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] - // CHECK-SAME: index_vector_dim = 1 - // CHECK-SAME: unique_indices = false - // CHECK: ^bb0(%arg3: tensor, %arg4: tensor): - // CHECK: %1 = mhlo.minimum %arg3, %arg4 : tensor - // CHECK: mhlo.return %1 : tensor - // CHECK: }) - %0 = "tf.TensorScatterMin"(%tensor, %indices, %updates) : (tensor, tensor, tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @tensor_scatter_max -func.func @tensor_scatter_max(%tensor: tensor, %indices: tensor, %updates: tensor) -> tensor { - // CHECK: "mhlo.scatter"(%arg0, %arg1, %arg2) - // CHECK-SAME: indices_are_sorted = false - // CHECK-SAME: scatter_dimension_numbers - // CHECK-SAME: update_window_dims = [1] - // CHECK-SAME: inserted_window_dims = [0, 1] - // CHECK-SAME: scatter_dims_to_operand_dims = [0, 1] - // CHECK-SAME: index_vector_dim = 1 - // CHECK-SAME: unique_indices = false - // CHECK: ^bb0(%arg3: tensor, %arg4: tensor): - // CHECK: %1 = mhlo.maximum %arg3, %arg4 : tensor - // CHECK: mhlo.return %1 : tensor - // CHECK: }) - %0 = "tf.TensorScatterMax"(%tensor, %indices, %updates) : (tensor, tensor, tensor) -> tensor - func.return %0 : tensor -} - -//===----------------------------------------------------------------------===// -// tf.RandomShuffle legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @random_shuffle_num_elems_le_1 -func.func @random_shuffle_num_elems_le_1() -> tensor { - // CHECK: [[INPUT:%.*]] = mhlo.constant dense<1.000000e+20> : tensor - // CHECK-NEXT: return [[INPUT]] - %cst = "tf.Const"() {value = dense<1.000000e+20> : tensor} : () -> tensor - %0 = "tf.RandomShuffle"(%cst) {device = "", seed = -4294967297 : i64, seed2 = -2147483649 : i64} : (tensor) -> tensor - return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @random_shuffle_first_dim_1 -// CHECK-SAME: [[INPUT:%.*]]: tensor<1x?xf32> -func.func @random_shuffle_first_dim_1(%input: tensor<1x?xf32>) -> tensor<1x?xf32> { - %0 = "tf.RandomShuffle"(%input) : (tensor<1x?xf32>) -> (tensor<1x?xf32>) - // CHECK-NEXT: return [[INPUT]] - func.return %0: tensor<1x?xf32> -} - -// ----- - -// CHECK-LABEL: @random_shuffle_1D_16 -// CHECK-SAME: [[INPUT:%.*]]: tensor<16xf32> -func.func @random_shuffle_1D_16(%input: tensor<16xf32>) -> tensor<16xf32> { - // CHECK-DAG: [[SHAPE:%.*]] = mhlo.constant dense<16> : tensor<1xi64> - // CHECK-DAG: [[LOWER:%.*]] = mhlo.constant dense<0> : tensor - // CHECK-DAG: [[UPPER:%.*]] = mhlo.constant dense<-1> : tensor - // CHECK: [[RNG:%.*]] = "mhlo.rng"([[LOWER]], [[UPPER]], [[SHAPE]]) <{rng_distribution = #mhlo.rng_distribution}> - // CHECK: [[SORT:%.*]]:2 = "mhlo.sort"([[RNG]], [[INPUT]]) <{dimension = -1 : i64, is_stable = {{.*}}}> ({ - // CHECK: ^{{.*}}([[ARG1:%.*]]: tensor, [[ARG2:%.*]]: tensor, {{.*}}: tensor, {{.*}}: tensor): - // CHECK: mhlo.compare LT, [[ARG1]], [[ARG2]], TOTALORDER - // CHECK: }) : (tensor<16xi32>, tensor<16xf32>) -> (tensor<16xi32>, tensor<16xf32>) - // CHECK: return [[SORT]]#1 - %0 = "tf.RandomShuffle"(%input) : (tensor<16xf32>) -> (tensor<16xf32>) - func.return %0: tensor<16xf32> -} - -// ----- - -// CHECK-LABEL: @random_shuffle_1D_10240 -func.func @random_shuffle_1D_10240(%input: tensor<10240xf32>) -> tensor<10240xf32> { - // CHECK: mhlo.rng{{.*UNIFORM.*}} - // CHECK: mhlo.sort - // CHECK: mhlo.rng{{.*UNIFORM.*}} - // CHECK: mhlo.sort - %0 = "tf.RandomShuffle"(%input) : (tensor<10240xf32>) -> (tensor<10240xf32>) - func.return %0: tensor<10240xf32> -} - -// ----- - -// CHECK-LABEL: @random_shuffle_3D -// CHECK-SAME: [[INPUT:%.*]]: tensor<4x?x16xf32> -func.func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { - // CHECK: [[INDICES:%.*]] = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<4xi32> - - // CHECK-DAG: [[RNG_SHAPE:%.*]] = mhlo.constant dense<4> : tensor<1xi64> - // CHECK-DAG: [[RNG_LOWER:%.*]] = mhlo.constant dense<0> : tensor - // CHECK-DAG: [[RNG_UPPER:%.*]] = mhlo.constant dense<4> : tensor - // CHECK: [[SWAPS:%.*]] = "mhlo.rng"([[RNG_LOWER]], [[RNG_UPPER]], [[RNG_SHAPE]]) <{rng_distribution = #mhlo.rng_distribution}> - - // CHECK: [[IV_INIT:%.*]] = mhlo.constant dense<0> : tensor - - // CHECK: [[WHILE_OUT:%.*]]:3 = mhlo.while([[ITER_ARG0:.*]] = [[IV_INIT]], [[ITER_ARG1:.*]] = [[SWAPS]], [[ITER_ARG2:.*]] = [[INDICES]]) - // CHECK: [[LIMIT:%.*]] = mhlo.constant dense<4> : tensor - // CHECK: [[CMP:%.*]] = mhlo.compare LT, [[ITER_ARG0]], [[LIMIT]], NOTYPE - // CHECK: mhlo.return [[CMP]] - // CHECK: } do { - // CHECK: [[SRC_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG2]], [[ITER_ARG0]]) <{slice_sizes = dense<1> : tensor<1xi64>}> : (tensor<4xi32>, tensor) -> tensor<1xi32> - // CHECK: [[SWP_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG1]], [[ITER_ARG0]]) <{slice_sizes = dense<1> : tensor<1xi64>}> : (tensor<4xi32>, tensor) -> tensor<1xi32> - // CHECK: [[SWP:%.*]] = mhlo.reshape [[SWP_IDX]] : (tensor<1xi32>) -> tensor - // CHECK: [[TGT_IDX:%.*]] = "mhlo.dynamic_slice"([[ITER_ARG2]], [[SWP]]) <{slice_sizes = dense<1> : tensor<1xi64>}> - // CHECK: [[INDICES1:%.*]] = mhlo.dynamic_update_slice [[ITER_ARG2]], [[TGT_IDX]], [[ITER_ARG0]] : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> - // CHECK: [[INDICES2:%.*]] = mhlo.dynamic_update_slice [[INDICES1]], [[SRC_IDX]], [[SWP]] : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> - // CHECK: [[ONE:%.*]] = mhlo.constant dense<1> : tensor - // CHECK: [[NEW_IV:%.*]] = chlo.broadcast_add [[ITER_ARG0]], [[ONE]] - // CHECK: mhlo.return [[NEW_IV]], [[ITER_ARG1]], [[INDICES2]] - // CHECK: } - - // CHECK: [[CONSTANT1:%.*]] = mhlo.constant dense<1> : tensor<1xi64> - // CHECK: [[ARITH_CONSTANT:%.*]] = arith.constant 1 : index - // CHECK: [[SHAPE_DIM:%.*]] = shape.dim %arg0, [[ARITH_CONSTANT]] : tensor<4x?x16xf32>, index -> index - // CHECK: [[INDEX_CAST:%.*]] = arith.index_cast [[SHAPE_DIM]] : index to i64 - // CHECK: [[FROM_ELEMENTS:%.*]] = tensor.from_elements [[INDEX_CAST]] : tensor<1xi64> - // CHECK: [[CONSTANT2:%.*]] = mhlo.constant dense<16> : tensor<1xi64> - // CHECK: [[CONCATENATE:%.*]] = "mhlo.concatenate"([[CONSTANT1]], [[FROM_ELEMENTS]], [[CONSTANT2]]) <{dimension = 0 : i64}> : (tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<3xi64> - // CHECK: [[DYNAMIC_GATHER:%.*]] = "mhlo.dynamic_gather"([[INPUT]], [[WHILE_OUT]]#2, [[CONCATENATE]]) - // CHECK-SAME: dimension_numbers = - // CHECK-SAME: offset_dims = [1, 2] - // CHECK-SAME: collapsed_slice_dims = [0] - // CHECK-SAME: start_index_map = [0] - // CHECK-SAME: index_vector_dim = 1 - // CHECK-SAME: indices_are_sorted = false - // CHECK-SAME:: (tensor<4x?x16xf32>, tensor<4xi32>, tensor<3xi64>) -> tensor<4x?x16xf32> - - // CHECK: return [[DYNAMIC_GATHER]] - - %0 = "tf.RandomShuffle"(%input) : (tensor<4x?x16xf32>) -> (tensor<4x?x16xf32>) - func.return %0: tensor<4x?x16xf32> -} - -//===----------------------------------------------------------------------===// -// tf.AvgPool legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @avgpool_valid_padding -// CHECK-SAME: [[ARG:%.+]]: tensor<2x12x21x7xf16> -// CHECK: [[CONV32:%.+]] = mhlo.convert %arg0 : (tensor<2x12x21x7xf16>) -> tensor<2x12x21x7xf32> -// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) -// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> -// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> -// CHECK: ^bb0([[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor): -// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] -// CHECK: mhlo.return [[ADD]] -// CHECK: }) -// CHECK-SAME: -> tensor<2x3x5x7xf32> -// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor -// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] -// CHECK-SAME: broadcast_dimensions = array -// CHECK-SAME: -> tensor<2x3x5x7xf32> -// CHECK: [[CONV16:%.+]] = mhlo.convert [[DIV_RESULT]] -// CHECK-SAME: -> tensor<2x3x5x7xf16> -// CHECK: return [[CONV16]] -func.func @avgpool_valid_padding(%arg0: tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16> { - %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x21x7xf16>) -> tensor<2x3x5x7xf16> - func.return %0 : tensor<2x3x5x7xf16> -} - -// ----- - -// CHECK-LABEL: @avgpool_3d_valid_padding -// CHECK-SAME: [[ARG:%.+]]: tensor<2x4x12x21x7xf16> -// CHECK: [[CONV32:%.+]] = mhlo.convert %arg0 : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x12x21x7xf32> -// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) -// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2, 1]> -// CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]> -// CHECK: ^bb0([[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor): -// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] -// CHECK: mhlo.return [[ADD]] -// CHECK: }) -// CHECK-SAME: -> tensor<2x4x3x5x7xf32> -// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor -// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] -// CHECK-SAME: broadcast_dimensions = array -// CHECK-SAME: -> tensor<2x4x3x5x7xf32> -// CHECK: [[CONV16:%.+]] = mhlo.convert [[DIV_RESULT]] -// CHECK-SAME: -> tensor<2x4x3x5x7xf16> -// CHECK: return [[CONV16]] -func.func @avgpool_3d_valid_padding(%arg0: tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16> { - %0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 2, 2, 1], padding = "VALID", strides = [1, 1, 4, 4, 1]} : (tensor<2x4x12x21x7xf16>) -> tensor<2x4x3x5x7xf16> - func.return %0 : tensor<2x4x3x5x7xf16> -} - -// ----- - -// CHECK-LABEL: @avgpool_nchw_format -// CHECK-SAME: [[ARG:%.+]]: tensor<2x7x12x21xf16> -// CHECK: [[CONV32:%.+]] = mhlo.convert %arg0 : (tensor<2x7x12x21xf16>) -> tensor<2x7x12x21xf32> -// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) -// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2]> -// CHECK-SAME: window_strides = dense<[1, 1, 4, 4]> -// CHECK: ^bb0([[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor): -// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] -// CHECK: mhlo.return [[ADD]] -// CHECK: }) -// CHECK-SAME: -> tensor<2x7x3x5xf32> -// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor -// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] -// CHECK-SAME: broadcast_dimensions = array -// CHECK-SAME: -> tensor<2x7x3x5xf32> -// CHECK: [[CONV16:%.+]] = mhlo.convert [[DIV_RESULT]] -// CHECK-SAME: -> tensor<2x7x3x5xf16> -// CHECK: return [[CONV16]] -func.func @avgpool_nchw_format(%arg0: tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16> { - %0 = "tf.AvgPool"(%arg0) {data_format = "NCHW", ksize = [1, 1, 2, 2], padding = "VALID", strides = [1, 1, 4, 4]} : (tensor<2x7x12x21xf16>) -> tensor<2x7x3x5xf16> - func.return %0 : tensor<2x7x3x5xf16> -} - -// ----- - -// CHECK-LABEL: @avgpool_3d_ncdhw_format -// CHECK-SAME: [[ARG:%.+]]: tensor<2x7x4x12x21xf16> -// CHECK: [[CONV32:%.+]] = mhlo.convert %arg0 : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x12x21xf32> -// CHECK: [[ZERO:%.+]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: [[DIVIDEND:%.+]] = "mhlo.reduce_window"([[CONV32]], [[ZERO]]) -// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 2]> -// CHECK-SAME: window_strides = dense<[1, 1, 1, 4, 4]> -// CHECK: ^bb0([[ARG1:%.+]]: tensor, [[ARG2:%.+]]: tensor): -// CHECK: [[ADD:%.+]] = mhlo.add [[ARG1]], [[ARG2]] -// CHECK: mhlo.return [[ADD]] -// CHECK: }) -// CHECK-SAME: -> tensor<2x7x4x3x5xf32> -// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor -// CHECK: [[DIV_RESULT:%.+]] = chlo.broadcast_divide [[DIVIDEND]], [[COUNT]] -// CHECK-SAME: broadcast_dimensions = array -// CHECK-SAME: -> tensor<2x7x4x3x5xf32> -// CHECK: [[CONV16:%.+]] = mhlo.convert [[DIV_RESULT]] -// CHECK-SAME: -> tensor<2x7x4x3x5xf16> -// CHECK: return [[CONV16]] -func.func @avgpool_3d_ncdhw_format(%arg0: tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16> { - %0 = "tf.AvgPool3D"(%arg0) {data_format = "NCDHW", ksize = [1, 1, 1, 2, 2], padding = "VALID", strides = [1, 1, 1, 4, 4]} : (tensor<2x7x4x12x21xf16>) -> tensor<2x7x4x3x5xf16> - func.return %0 : tensor<2x7x4x3x5xf16> -} - -// ----- - -// CHECK-LABEL: @avgpool_same_padding( -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> -// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) -// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]> -// CHECK-SAME: window_dimensions = dense<[1, 5, 2, 1]> -// CHECK-SAME: window_strides = dense<[1, 3, 4, 1]> -// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: mhlo.return %[[SUM1]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<2x4x6x7xf32> -// CHECK: %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x12x21x7xf32> -// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) -// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [1, 1], [0, 1], [0, 0]]> -// CHECK-SAME: window_dimensions = dense<[1, 5, 2, 1]> -// CHECK-SAME: window_strides = dense<[1, 3, 4, 1]> -// CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): -// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor -// CHECK: mhlo.return %[[SUM2]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<2x4x6x7xf32> -// CHECK: %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]] : tensor<2x4x6x7xf32> -// CHECK: return %[[RESULT]] : tensor<2x4x6x7xf32> -// CHECK: } -func.func @avgpool_same_padding(%arg0: tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> { - %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 5, 2, 1], padding = "SAME", strides = [1, 3, 4, 1]} : (tensor<2x12x21x7xf32>) -> tensor<2x4x6x7xf32> - func.return %0 : tensor<2x4x6x7xf32> -} - -// ----- - -// CHECK-LABEL: @avgpool_3d_same_padding( -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> -// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[DIVIDEND:.*]] = "mhlo.reduce_window"(%[[ARG0]], %[[ZERO]]) -// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]> -// CHECK-SAME: window_dimensions = dense<[1, 1, 5, 2, 1]> -// CHECK-SAME: window_strides = dense<[1, 1, 3, 4, 1]> -// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: mhlo.return %[[SUM1]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<2x4x4x6x7xf32> -// CHECK: %[[ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x4x12x21x7xf32> -// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ONES]], %[[ZERO]]) -// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [1, 1], [0, 1], [0, 0]]> -// CHECK-SAME: window_dimensions = dense<[1, 1, 5, 2, 1]> -// CHECK-SAME: window_strides = dense<[1, 1, 3, 4, 1]> -// CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): -// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor -// CHECK: mhlo.return %[[SUM2]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<2x4x4x6x7xf32> -// CHECK: %[[RESULT:.*]] = mhlo.divide %[[DIVIDEND]], %[[DIVISOR]] -// CHECK: return %[[RESULT]] : tensor<2x4x4x6x7xf32> -// CHECK: } -func.func @avgpool_3d_same_padding(%arg0: tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> { - %0 = "tf.AvgPool3D"(%arg0) {data_format = "NDHWC", ksize = [1, 1, 5, 2, 1], padding = "SAME", strides = [1, 1, 3, 4, 1]} : (tensor<2x4x12x21x7xf32>) -> tensor<2x4x4x6x7xf32> - func.return %0 : tensor<2x4x4x6x7xf32> -} - -//===----------------------------------------------------------------------===// -// AvgPoolGrad op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @avgpool_grad_valid_padding( -// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> { -// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK-DAG: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] -// CHECK-SAME: broadcast_dimensions = array -// CHECK-SAME: -> tensor<10x12x16x64xf32> -// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) -// CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]> -// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> -// CHECK-SAME: interior_padding = dense<[0, 1, 1, 0]> -// CHECK-SAME: -> tensor<10x25x33x64xf32> -// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) -// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> -// CHECK-SAME: window_strides = dense<1> -// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: mhlo.return %[[SUM]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<10x24x32x64xf32> -// CHECK: return %[[RESULT]] : tensor<10x24x32x64xf32> -func.func @avgpool_grad_valid_padding(%grad: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> { - %orig_input_shape = "tf.Const"() {value = dense<[10, 24, 32, 64]> : tensor<4xi32>} : () -> (tensor<4xi32>) - %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) { - data_format = "NHWC", - ksize = [1, 2, 2, 1], - padding = "VALID", - strides = [1, 2, 2, 1] - } : (tensor<4xi32>, tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> - func.return %result : tensor<10x24x32x64xf32> -} - -// ----- - -// CHECK-LABEL: @avgpool_3d_grad_valid_padding( -// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> { -// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK-DAG: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] {broadcast_dimensions = array} : (tensor<10x8x12x16x64xf32>, tensor) -> tensor<10x8x12x16x64xf32> -// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) -// CHECK-SAME: edge_padding_high = dense<[0, 0, 1, 1, 0]> -// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1, 0]> -// CHECK-SAME: interior_padding = dense<[0, 0, 1, 1, 0]> -// CHECK-SAME: -> tensor<10x8x25x33x64xf32> -// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) -// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 2, 1]> -// CHECK-SAME: window_strides = dense<1> -// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: mhlo.return %[[SUM]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<10x8x24x32x64xf32> -// CHECK: return %[[RESULT]] : tensor<10x8x24x32x64xf32> -func.func @avgpool_3d_grad_valid_padding(%grad: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> { - %orig_input_shape = "tf.Const"() {value = dense<[10, 8, 24, 32, 64]> : tensor<5xi32>} : () -> (tensor<5xi32>) - %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) { - data_format = "NDHWC", - ksize = [1, 1, 2, 2, 1], - padding = "VALID", - strides = [1, 1, 2, 2, 1]} : (tensor<5xi32>, tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> - func.return %result : tensor<10x8x24x32x64xf32> -} - -// ----- - -// CHECK-LABEL: @avgpool_grad_same_padding( -// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> { -// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK-DAG: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x13x25x9xf32> -// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) -// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 1], [1, 1], [0, 0]]> -// CHECK-SAME: window_dimensions = dense<[1, 2, 3, 1]> -// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> -// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: mhlo.return %[[SUM1]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<2x4x7x9xf32> -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x4x7x9xf32> -// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) -// CHECK-SAME: edge_padding_high = dense<[0, 0, 1, 0]> -// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> -// CHECK-SAME: interior_padding = dense<[0, 3, 3, 0]> -// CHECK-SAME: -> tensor<2x14x27x9xf32> -// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) -// CHECK-SAME: window_dimensions = dense<[1, 2, 3, 1]> -// CHECK-SAME: window_strides = dense<1> -// CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): -// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor -// CHECK: mhlo.return %[[SUM2]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<2x13x25x9xf32> -// CHECK: return %[[RESULT]] : tensor<2x13x25x9xf32> -func.func @avgpool_grad_same_padding(%grad: tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> { - %orig_input_shape = "tf.Const"() {value = dense<[2, 13, 25, 9]> : tensor<4xi32>} : () -> (tensor<4xi32>) - %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) { - data_format = "NHWC", - ksize = [1, 2, 3, 1], - padding = "SAME", - strides = [1, 4, 4, 1] - } : (tensor<4xi32>, tensor<2x4x7x9xf32>) -> tensor<2x13x25x9xf32> - func.return %result : tensor<2x13x25x9xf32> -} - -// ----- - -// CHECK-LABEL: @avgpool_3d_grad_same_padding( -// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> { -// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK-DAG: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x8x13x25x9xf32> -// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) -// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1], [0, 0]]> -// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3, 1]> -// CHECK-SAME: window_strides = dense<[1, 1, 4, 4, 1]> -// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: mhlo.return %[[SUM1]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<2x8x4x7x9xf32> -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x8x4x7x9xf32> -// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) -// CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 1, 0]> -// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1, 0]> -// CHECK-SAME: interior_padding = dense<[0, 0, 3, 3, 0]> -// CHECK-SAME: -> tensor<2x8x14x27x9xf32> -// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) -// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3, 1]> -// CHECK-SAME: window_strides = dense<1> -// CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): -// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor -// CHECK: mhlo.return %[[SUM2]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<2x8x13x25x9xf32> -// CHECK: return %[[RESULT]] : tensor<2x8x13x25x9xf32> -func.func @avgpool_3d_grad_same_padding(%grad: tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> { - %orig_input_shape = "tf.Const"() {value = dense<[2, 8, 13, 25, 9]> : tensor<5xi32>} : () -> (tensor<5xi32>) - %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) { - data_format = "NDHWC", - ksize = [1, 1, 2, 3, 1], - padding = "SAME", - strides = [1, 1, 4, 4, 1]} : (tensor<5xi32>, tensor<2x8x4x7x9xf32>) -> tensor<2x8x13x25x9xf32> - func.return %result : tensor<2x8x13x25x9xf32> -} - -// ----- - -// CHECK-LABEL: @avgpool_grad_nchw_format( -// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> { -// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK-DAG: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x13x25xf32> -// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) -// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 1], [1, 1]]> -// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3]> -// CHECK-SAME: window_strides = dense<[1, 1, 4, 4]> -// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: mhlo.return %[[SUM1]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<2x9x4x7xf32> -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x4x7xf32> -// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) -// CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 1]> -// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1]> -// CHECK-SAME: interior_padding = dense<[0, 0, 3, 3]> -// CHECK-SAME: -> tensor<2x9x14x27xf32> -// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) -// CHECK-SAME: window_dimensions = dense<[1, 1, 2, 3]> -// CHECK-SAME: window_strides = dense<1> -// CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): -// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor -// CHECK: mhlo.return %[[SUM2]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<2x9x13x25xf32> -// CHECK: return %[[RESULT]] : tensor<2x9x13x25xf32> -func.func @avgpool_grad_nchw_format(%grad: tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> { - %orig_input_shape = "tf.Const"() {value = dense<[2, 9, 13, 25]> : tensor<4xi32>} : () -> (tensor<4xi32>) - %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) { - data_format = "NCHW", - ksize = [1, 1, 2, 3], - padding = "SAME", - strides = [1, 1, 4, 4] - } : (tensor<4xi32>, tensor<2x9x4x7xf32>) -> tensor<2x9x13x25xf32> - func.return %result : tensor<2x9x13x25xf32> -} - -// ----- - -// CHECK-LABEL: @avgpool_3d_grad_ncdwh_format( -// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> { -// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK-DAG: %[[ALL_ONES:.*]] = mhlo.constant dense<1.000000e+00> : tensor<2x9x8x13x25xf32> -// CHECK: %[[DIVISOR:.*]] = "mhlo.reduce_window"(%[[ALL_ONES]], %[[ZERO]]) -// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 0], [0, 0], [0, 1], [1, 1]]> -// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 3]> -// CHECK-SAME: window_strides = dense<[1, 1, 1, 4, 4]> -// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM1:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: mhlo.return %[[SUM1]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<2x9x8x4x7xf32> -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = mhlo.divide %[[OUT_GRAD]], %[[DIVISOR]] : tensor<2x9x8x4x7xf32> -// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) -// CHECK-SAME: edge_padding_high = dense<[0, 0, 0, 0, 1]> -// CHECK-SAME: edge_padding_low = dense<[0, 0, 0, 1, 1]> -// CHECK-SAME: interior_padding = dense<[0, 0, 0, 3, 3]> -// CHECK-SAME: -> tensor<2x9x8x14x27xf32> -// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT]], %[[ZERO]]) -// CHECK-SAME: window_dimensions = dense<[1, 1, 1, 2, 3]> -// CHECK-SAME: window_strides = dense<1> : tensor<5xi64> -// CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): -// CHECK: %[[SUM2:.*]] = mhlo.add %[[ARG3]], %[[ARG4]] : tensor -// CHECK: mhlo.return %[[SUM2]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<2x9x8x13x25xf32> -// CHECK: return %[[RESULT]] : tensor<2x9x8x13x25xf32> -func.func @avgpool_3d_grad_ncdwh_format(%grad: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> { - %orig_input_shape = "tf.Const"() {value = dense<[2, 9, 8, 13, 25]> : tensor<5xi32>} : () -> (tensor<5xi32>) - %result = "tf.AvgPool3DGrad"(%orig_input_shape, %grad) { - data_format = "NCDHW", - ksize = [1, 1, 1, 2, 3], - padding = "SAME", - strides = [1, 1, 1, 4, 4]} : (tensor<5xi32>, tensor<2x9x8x4x7xf32>) -> tensor<2x9x8x13x25xf32> - func.return %result : tensor<2x9x8x13x25xf32> -} - -// ----- - -// CHECK-LABEL: @avgpool_grad_bf16( -// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> { -// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK-DAG: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] -// CHECK-SAME: broadcast_dimensions = array -// CHECK-SAME: -> tensor<10x12x16x64xbf16> -// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) -// CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]> -// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> -// CHECK-SAME: interior_padding = dense<[0, 1, 1, 0]> -// CHECK-SAME: -> tensor<10x25x33x64xbf16> -// CHECK: %[[REDUCE_WINDOW_INPUT_CONVERTED:.*]] = mhlo.convert %[[REDUCE_WINDOW_INPUT]] : (tensor<10x25x33x64xbf16>) -> tensor<10x25x33x64xf32> -// CHECK: %[[ZERO_F32:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[RESULT:.*]] = "mhlo.reduce_window"(%[[REDUCE_WINDOW_INPUT_CONVERTED]], %[[ZERO_F32]]) -// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> -// CHECK-SAME: window_strides = dense<1> -// CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[SUM:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: mhlo.return %[[SUM]] : tensor -// CHECK: }) -// CHECK-SAME: -> tensor<10x24x32x64xf32> -// CHECK: %[[RESULT_CONVERTED:.*]] = mhlo.convert %[[RESULT]] : (tensor<10x24x32x64xf32>) -> tensor<10x24x32x64xbf16> -// CHECK: return %[[RESULT_CONVERTED]] : tensor<10x24x32x64xbf16> -func.func @avgpool_grad_bf16(%grad: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> { - %orig_input_shape = "tf.Const"() {value = dense<[10, 24, 32, 64]> : tensor<4xi32>} : () -> (tensor<4xi32>) - %result = "tf.AvgPoolGrad"(%orig_input_shape, %grad) { - data_format = "NHWC", - ksize = [1, 2, 2, 1], - padding = "VALID", - strides = [1, 2, 2, 1] - } : (tensor<4xi32>, tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> - func.return %result : tensor<10x24x32x64xbf16> -} - -// ----- - -// CHECK-LABEL: xla_sharding -func.func @xla_sharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> { - // CHECK-NEXT: mhlo.custom_call @Sharding(%arg0) {mhlo.sharding = ""} - %0 = "tf.XlaSharding"(%arg0) {_XlaSharding = "", sharding = ""} : (tensor<4x16xf32>) -> tensor<4x16xf32> - func.return %0 : tensor<4x16xf32> -} - -// ----- - -// CHECK-LABEL: inplace_update_one -func.func @inplace_update_one(%arg0: tensor<8x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<1xi32>) -> tensor<8x4xf32> { - // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0> - // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> - // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg1) <{limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> - // CHECK-DAG: [[RESHAPE1:%.+]] = mhlo.reshape [[SLICE1]] - // CHECK-DAG: [[UPDATE:%.+]] = mhlo.dynamic_update_slice %arg0, [[SLICE2]], [[RESHAPE1]], [[CST]] - %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x4xf32>, tensor<1xi32>, tensor<1x4xf32>) -> tensor<8x4xf32> - - // CHECK: return [[UPDATE]] - func.return %0 : tensor<8x4xf32> -} - -// ----- - -// CHECK-LABEL: inplace_update_three -func.func @inplace_update_three(%arg0: tensor<8x8x4xf32>, %arg1: tensor<3x8x4xf32>, %arg2: tensor<3xi32>) -> tensor<8x8x4xf32> { - // CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0> - // CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> - // CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> - // CHECK-DAG: [[SLICE3:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<3> : tensor<1xi64>, start_indices = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> - // CHECK-DAG: [[SLICE4:%.+]] = "mhlo.slice"(%arg1) <{limit_indices = dense<[1, 8, 4]> : tensor<3xi64>, start_indices = dense<0> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> - // CHECK-DAG: [[SLICE5:%.+]] = "mhlo.slice"(%arg1) <{limit_indices = dense<[2, 8, 4]> : tensor<3xi64>, start_indices = dense<[1, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> - // CHECK-DAG: [[SLICE6:%.+]] = "mhlo.slice"(%arg1) <{limit_indices = dense<[3, 8, 4]> : tensor<3xi64>, start_indices = dense<[2, 0, 0]> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}> - // CHECK-DAG: [[RESHAPE1:%.+]] = mhlo.reshape [[SLICE1]] - // CHECK-DAG: [[RESHAPE2:%.+]] = mhlo.reshape [[SLICE2]] - // CHECK-DAG: [[RESHAPE3:%.+]] = mhlo.reshape [[SLICE3]] - // CHECK-DAG: [[UPDATE1:%.+]] = mhlo.dynamic_update_slice %arg0, [[SLICE4]], [[RESHAPE1]], [[CST]], [[CST]] - // CHECK-DAG: [[UPDATE2:%.+]] = mhlo.dynamic_update_slice [[UPDATE1]], [[SLICE5]], [[RESHAPE2]], [[CST]], [[CST]] - // CHECK-DAG: [[UPDATE3:%.+]] = mhlo.dynamic_update_slice [[UPDATE2]], [[SLICE6]], [[RESHAPE3]], [[CST]], [[CST]] - %0 = "tf.InplaceUpdate"(%arg0, %arg2, %arg1) : (tensor<8x8x4xf32>, tensor<3xi32>, tensor<3x8x4xf32>) -> tensor<8x8x4xf32> - - // CHECK: return [[UPDATE3]] : tensor<8x8x4xf32> - func.return %0 : tensor<8x8x4xf32> -} - -// ----- - -// CHECK-LABEL: xla_dynamic_update_slice -func.func @xla_dynamic_update_slice(%arg0: tensor<4x16xf32>, %arg1: tensor<2x4xf32>, %arg2: tensor<2xi32>) -> tensor<4x16xf32> { - // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> - // CHECK: [[RESHAPE0:%.+]] = mhlo.reshape [[SLICE0]] : (tensor<1xi32>) -> tensor - // CHECK: [[SLICE1:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> - // CHECK: [[RESHAPE1:%.+]] = mhlo.reshape [[SLICE1]] : (tensor<1xi32>) -> tensor - // CHECK: [[DUS:%.+]] = mhlo.dynamic_update_slice %arg0, %arg1, [[RESHAPE0]], [[RESHAPE1]] : (tensor<4x16xf32>, tensor<2x4xf32>, tensor, tensor) -> tensor<4x16xf32> - // CHECK: return [[DUS]] - %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4x16xf32>, tensor<2x4xf32>, tensor<2xi32>) -> tensor<4x16xf32> - func.return %0 : tensor<4x16xf32> -} - -// ----- - -// CHECK-LABEL: xla_dynamic_update_slice2 -func.func @xla_dynamic_update_slice2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg2: tensor<1xi32>) -> tensor<4xf32> { - // CHECK: [[SLICE0:%.+]] = "mhlo.slice"(%arg2) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<1xi32>) -> tensor<1xi32> - // CHECK: [[RESHAPE0:%.+]] = mhlo.reshape [[SLICE0]] : (tensor<1xi32>) -> tensor - // CHECK: [[DUS:%.+]] = mhlo.dynamic_update_slice %arg0, %arg1, [[RESHAPE0]] : (tensor<4xf32>, tensor<2xf32>, tensor) -> tensor<4xf32> - // CHECK: return [[DUS]] - %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<1xi32>) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -//===----------------------------------------------------------------------===// -// AllToAll op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @alltoall_basic -// See https://www.tensorflow.org/api_docs/python/tf/raw_ops/AllToAll -func.func @alltoall_basic(%input: tensor<1x2xf32>) -> tensor<2x1xf32> { - %group_assignment = "tf.Const" () { - value = dense<[[0, 1]]> : tensor<1x2xi32> - } : () -> tensor<1x2xi32> - %result = "tf.AllToAll"(%input, %group_assignment) {T = f32, concat_dimension = 0 : i64, split_count = 2 : i64, split_dimension = 1 : i64} : (tensor<1x2xf32>, tensor<1x2xi32>) -> tensor<2x1xf32> - // CHECK: mhlo.all_to_all - // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> - func.return %result : tensor<2x1xf32> -} - - -//===----------------------------------------------------------------------===// -// Cumsum op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @cumsum_static -// CHECK-SAME: [[X:%.*]]: tensor<4xf32> -func.func @cumsum_static(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor - // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[X]] : tensor<4xf32> - // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) <{padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>}> ({ - // CHECK: ^bb0([[A:%.*]]: tensor, [[B:%.*]]: tensor): - // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor - // CHECK: mhlo.return [[SUM]] : tensor - // CHECK: }) : (tensor<4xf32>, tensor) -> tensor<4xf32> - // CHECK: [[CONVERT_REDUCE:%.*]] = mhlo.convert [[REDUCE]] : tensor<4xf32> - // CHECK: return [[CONVERT_REDUCE]] - %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor - %1 = "tf.Cumsum"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor) -> tensor<4xf32> - func.return %1 : tensor<4xf32> -} - -// ----- - -// CHECK-LABEL: func @cumsum_exclusive -// CHECK-SAME: [[X:%.*]]: tensor<4xf32> -func.func @cumsum_exclusive(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor - // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[X]] : tensor<4xf32> - // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) <{padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>}> ({ - // CHECK: ^bb0([[A:%.*]]: tensor, [[B:%.*]]: tensor): - // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor - // CHECK: mhlo.return [[SUM]] : tensor - // CHECK: }) : (tensor<4xf32>, tensor) -> tensor<4xf32> - // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REDUCE]], %{{.*}}) <{edge_padding_high = dense<-1> : tensor<1xi64>, edge_padding_low = dense<1> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>}> : (tensor<4xf32>, tensor) -> tensor<4xf32> - // CHECK: [[CONVERT_REDUCE:%.*]] = mhlo.convert [[PAD]] : tensor<4xf32> - // CHECK: return [[CONVERT_REDUCE]] - %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor - %1 = "tf.Cumsum"(%arg0, %0) {exclusive = true, reverse = false} : (tensor<4xf32>, tensor) -> tensor<4xf32> - func.return %1 : tensor<4xf32> -} - -// ----- - -// CHECK-LABEL: func @cumsum_reverse -// CHECK-SAME: [[X:%.*]]: tensor<4xf32> -func.func @cumsum_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor - // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) <{dimensions = dense<0> : tensor<1xi64>}> : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[REVERSE1]] : tensor<4xf32> - // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) <{padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>}> ({ - // CHECK: ^bb0([[A:%.*]]: tensor, [[B:%.*]]: tensor): - // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor - // CHECK: mhlo.return [[SUM]] : tensor - // CHECK: }) : (tensor<4xf32>, tensor) -> tensor<4xf32> - // CHECK: [[CONVERT_REDUCE:%.*]] = mhlo.convert [[REDUCE]] : tensor<4xf32> - // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) <{dimensions = dense<0> : tensor<1xi64>}> : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: return [[REVERSE_BACK]] - %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor - %1 = "tf.Cumsum"(%arg0, %0) {exclusive = false, reverse = true} : (tensor<4xf32>, tensor) -> tensor<4xf32> - func.return %1 : tensor<4xf32> -} - -// ----- - -// CHECK-LABEL: func @cumsum_exclusive_reverse -// CHECK-SAME: [[X:%.*]]: tensor<4xf32> -func.func @cumsum_exclusive_reverse(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: [[AXIS:%.*]] = mhlo.constant dense<0> : tensor - // CHECK: [[REVERSE1:%.*]] = "mhlo.reverse"([[X]]) <{dimensions = dense<0> : tensor<1xi64>}> : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: [[CONVERT_X:%.*]] = mhlo.convert [[REVERSE1]] : tensor<4xf32> - // CHECK: [[INIT:%.*]] = mhlo.constant dense<0.000000e+00> : tensor - // CHECK: [[REDUCE:%.*]] = "mhlo.reduce_window"([[CONVERT_X]], [[INIT]]) <{padding = dense<{{\[\[}}3, 0]]> : tensor<1x2xi64>, window_dimensions = dense<4> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>}> ({ - // CHECK: ^bb0([[A:%.*]]: tensor, [[B:%.*]]: tensor): - // CHECK: [[SUM:%.*]] = mhlo.add [[A]], [[B]] : tensor - // CHECK: mhlo.return [[SUM]] : tensor - // CHECK: }) : (tensor<4xf32>, tensor) -> tensor<4xf32> - // CHECK: [[PAD:%.*]] = "mhlo.pad"([[REDUCE]], %{{.*}}) <{edge_padding_high = dense<-1> : tensor<1xi64>, edge_padding_low = dense<1> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>}> : (tensor<4xf32>, tensor) -> tensor<4xf32> - // CHECK: [[CONVERT_REDUCE:%.*]] = mhlo.convert [[PAD]] : tensor<4xf32> - // CHECK: [[REVERSE_BACK:%.*]] = "mhlo.reverse"([[CONVERT_REDUCE]]) <{dimensions = dense<0> : tensor<1xi64>}> : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: return [[REVERSE_BACK]] - %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor - %1 = "tf.Cumsum"(%arg0, %0) {exclusive = true, reverse = true} : (tensor<4xf32>, tensor) -> tensor<4xf32> - func.return %1 : tensor<4xf32> -} - -// ----- - -// CHECK-LABEL: func @cumsum_empty -func.func @cumsum_empty(%arg0: tensor<0xf32>) -> tensor<0xf32> { - %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - - // CHECK: mhlo.constant dense<> : tensor<0xf32> - %1 = "tf.Cumsum"(%arg0, %0) : (tensor<0xf32>, tensor) -> tensor<0xf32> - func.return %1 : tensor<0xf32> -} - -// ----- - -// CHECK-LABEL: func @cumsum_dynamic -func.func @cumsum_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: "tf.Cumsum" - %0 = "tf.Cumsum"(%arg0, %arg1) : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -//===----------------------------------------------------------------------===// -// Cumprod op legalizations. -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @cumprod -func.func @cumprod(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: [[INIT:%.*]] = mhlo.constant dense<1.000000e+00> : tensor - // CHECK: "mhlo.reduce_window"({{.*}}, [[INIT]]) - // CHECK: mhlo.mul - %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor - %1 = "tf.Cumprod"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor) -> tensor<4xf32> - func.return %1 : tensor<4xf32> -} - -//===----------------------------------------------------------------------===// -// tf.Softplus legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @softplus_f16 -// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf16>) -func.func @softplus_f16(%arg0: tensor<8x16xf16>) -> tensor<8x16xf16> { - // CHECK-DAG: [[FEATURES_EXP:%.*]] = mhlo.exponential [[FEATURES]] - // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.220700e-04> : tensor - // CHECK-DAG: [[EPSILON_LOG:%.*]] = mhlo.log [[EPSILON]] - // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor - // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] - // CHECK: [[NEG_THRESHOLD:%.*]] = mhlo.negate [[THRESHOLD]] - // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = #chlo} - // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = #chlo} - // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = mhlo.log_plus_one [[FEATURES_EXP]] - // CHECK: [[ELSE_SELECT:%.*]] = mhlo.select [[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]] - // CHECK: [[ENTRY_SELECT:%.*]] = mhlo.select [[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]] - %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf16>) -> tensor<8x16xf16> - - // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf16> - func.return %0 : tensor<8x16xf16> -} - -// ----- - -// CHECK-LABEL: func @softplus_bf16 -// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xbf16>) -func.func @softplus_bf16(%arg0: tensor<8x16xbf16>) -> tensor<8x16xbf16> { - // CHECK-DAG: [[FEATURES_EXP:%.*]] = mhlo.exponential [[FEATURES]] - // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<7.812500e-03> : tensor - // CHECK-DAG: [[EPSILON_LOG:%.*]] = mhlo.log [[EPSILON]] - // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor - // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] - // CHECK: [[NEG_THRESHOLD:%.*]] = mhlo.negate [[THRESHOLD]] - // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = #chlo} - // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = #chlo} - // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = mhlo.log_plus_one [[FEATURES_EXP]] - // CHECK: [[ELSE_SELECT:%.*]] = mhlo.select [[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]] - // CHECK: [[ENTRY_SELECT:%.*]] = mhlo.select [[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]] - %0 = "tf.Softplus"(%arg0) : (tensor<8x16xbf16>) -> tensor<8x16xbf16> - - // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xbf16> - func.return %0 : tensor<8x16xbf16> -} - -// ----- - -// CHECK-LABEL: func @softplus_f32 -// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf32>) -func.func @softplus_f32(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { - // CHECK-DAG: [[FEATURES_EXP:%.*]] = mhlo.exponential [[FEATURES]] - // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.1920929E-7> : tensor - // CHECK-DAG: [[EPSILON_LOG:%.*]] = mhlo.log [[EPSILON]] - // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor - // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] - // CHECK: [[NEG_THRESHOLD:%.*]] = mhlo.negate [[THRESHOLD]] - // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = #chlo} - // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = #chlo} - // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = mhlo.log_plus_one [[FEATURES_EXP]] - // CHECK: [[ELSE_SELECT:%.*]] = mhlo.select [[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]] - // CHECK: [[ENTRY_SELECT:%.*]] = mhlo.select [[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]] - %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> - - // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf32> - func.return %0 : tensor<8x16xf32> -} - -// ----- - -// CHECK-LABEL: func @softplus_f64 -// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf64>) -func.func @softplus_f64(%arg0: tensor<8x16xf64>) -> tensor<8x16xf64> { - // CHECK-DAG: [[FEATURES_EXP:%.*]] = mhlo.exponential [[FEATURES]] - // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<2.2204460492503131E-16> : tensor - // CHECK-DAG: [[EPSILON_LOG:%.*]] = mhlo.log [[EPSILON]] - // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor - // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] - // CHECK: [[NEG_THRESHOLD:%.*]] = mhlo.negate [[THRESHOLD]] - // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = #chlo} - // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = #chlo} - // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = mhlo.log_plus_one [[FEATURES_EXP]] - // CHECK: [[ELSE_SELECT:%.*]] = mhlo.select [[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]] - // CHECK: [[ENTRY_SELECT:%.*]] = mhlo.select [[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]] - %0 = "tf.Softplus"(%arg0) : (tensor<8x16xf64>) -> tensor<8x16xf64> - - // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf64> - func.return %0 : tensor<8x16xf64> -} - -// ----- - -// CHECK-LABEL: @xla_gather -func.func @xla_gather(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<1x300x10xf32> { - %cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi64> } : () -> tensor<3xi64> - - // CHECK: "mhlo.gather" - // CHECK-SAME: dimension_numbers = - // CHECK-SAME: offset_dims = [0, 1] - // CHECK-SAME: collapsed_slice_dims = [0] - // CHECK-SAME: start_index_map = [0, 1] - // CHECK-SAME: index_vector_dim = 1 - // CHECK-SAME: indices_are_sorted = true - // CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64> - - %0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\02\00\01\12\01\00\1A\02\00\01\20\01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<1x300x10xf32> - func.return %0 : tensor<1x300x10xf32> -} - -// ----- - -// CHECK-LABEL: @xla_gather_i32 -func.func @xla_gather_i32(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<1x300x10xf32> { - %cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi32> } : () -> tensor<3xi32> - - // CHECK: "mhlo.gather" - // CHECK-SAME: dimension_numbers = - // CHECK-SAME: offset_dims = [0, 1] - // CHECK-SAME: collapsed_slice_dims = [0] - // CHECK-SAME: start_index_map = [0, 1] - // CHECK-SAME: index_vector_dim = 1 - // CHECK-SAME: indices_are_sorted = true - // CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64> - - %0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\02\00\01\12\01\00\1A\02\00\01\20\01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi32>) -> tensor<1x300x10xf32> - func.return %0 : tensor<1x300x10xf32> -} - - -// CHECK: func @stridedslice_with_i32 -func.func @stridedslice_with_i32(%arg0: tensor) -> tensor<4xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "const_0_arg", outputs = "identity_0_retval_RetVal"}} { -// CHECK-NOT: tf.StridedSlice -// CHECK: [[DYNSLICE:%.*]] = "mhlo.dynamic_slice -// CHECK: [[RESHAPE:%.*]] = mhlo.reshape [[DYNSLICE]] -// CHECK: return [[RESHAPE]] - %0 = "tf.Const"() {value = dense<[[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00]]> : tensor<2x4xf32>} : () -> tensor<2x4xf32> - %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %2 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %3 = "tf.AddV2"(%arg0, %1) {_xla_inferred_shapes = [#tf_type.shape<>], device = ""} : (tensor, tensor) -> tensor - %4 = "tf.Pack"(%3) {_xla_inferred_shapes = [#tf_type.shape<1>], axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %5 = "tf.Pack"(%arg0) {_xla_inferred_shapes = [#tf_type.shape<1>], axis = 0 : i64, device = ""} : (tensor) -> tensor<1xi32> - %6 = "tf.StridedSlice"(%0, %5, %4, %2) {_xla_inferred_shapes = [#tf_type.shape<4>], begin_mask = 0 : i64, device = "", ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64} : (tensor<2x4xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xf32> - func.return %6 : tensor<4xf32> -} - -func.func @replica_id() -> tensor { - // CHECK: %[[ID:.*]] = mhlo.replica_id : tensor - // CHECK: %[[RESULT:.*]] = mhlo.convert %0 : (tensor) -> tensor - %0 = "tf.XlaReplicaId"() : () -> tensor - func.return %0 : tensor -} - -// CHECK: func @angle_c64 -// CHECK-SAME: ([[ARG0:%.*]]: tensor>) -func.func @angle_c64(%arg0: tensor>) -> tensor { -// CHECK: [[IMAG:%.*]] = mhlo.imag [[ARG0]] -// CHECK: [[REAL:%.*]] = mhlo.real [[ARG0]] -// CHECK: [[ATAN2:%.*]] = mhlo.atan2 [[IMAG]], [[REAL]] - %0 = "tf.Angle"(%arg0): (tensor>) -> tensor - func.return %0 : tensor -} - -//===----------------------------------------------------------------------===// -// tf.ApproximateEqual legalization -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: func @approximateequal_f64 -func.func @approximateequal_f64(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: %[[SUB:.*]] = mhlo.subtract %arg0, %arg1 : tensor - // CHECK: %[[ABS:.*]] = mhlo.abs %[[SUB]] : tensor - // CHECK: %[[CST:.*]] = mhlo.constant dense<2.000000e+00> : tensor - // CHECK: %[[CONVERT:.*]] = mhlo.convert %[[CST]] : (tensor) -> tensor - // CHECK: %[[LE:.*]] = chlo.broadcast_compare %[[ABS]], %[[CONVERT]] {comparison_direction = #chlo} : (tensor, tensor) -> tensor - // CHECK: return %[[LE]] : tensor - %equal = "tf.ApproximateEqual"(%arg0, %arg1) { tolerance = 2. : f32 } : (tensor, tensor) -> tensor - func.return %equal : tensor -} - -// CHECK-LABEL: func @approximateequal_i32 -func.func @approximateequal_i32(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK: %[[SUB:.*]] = mhlo.subtract %arg0, %arg1 : tensor - // CHECK: %[[ABS:.*]] = mhlo.abs %[[SUB]] : tensor - // CHECK: %[[CST:.*]] = mhlo.constant dense<2.000000e+00> : tensor - // CHECK: %[[CONVERT:.*]] = mhlo.convert %[[CST]] : (tensor) -> tensor - // CHECK: %[[LE:.*]] = chlo.broadcast_compare %[[ABS]], %[[CONVERT]] {comparison_direction = #chlo} : (tensor, tensor) -> tensor - // CHECK: return %[[LE]] : tensor - %equal = "tf.ApproximateEqual"(%arg0, %arg1) { tolerance = 2. : f32 } : (tensor, tensor) -> tensor - func.return %equal : tensor -} - -// CHECK-LABEL: func @approximateequal_complex64 -func.func @approximateequal_complex64(%arg0: tensor>, %arg1: tensor>) -> tensor { - // CHECK: %[[SUB:.*]] = mhlo.subtract %arg0, %arg1 : tensor> - // CHECK: %[[ABS:.*]] = mhlo.abs %[[SUB]] : (tensor>) -> tensor - // CHECK: %[[CST:.*]] = mhlo.constant dense<2.000000e+00> : tensor - // CHECK: %[[CONVERT:.*]] = mhlo.convert %[[CST]] : tensor - // CHECK: %[[LE:.*]] = chlo.broadcast_compare %[[ABS]], %[[CONVERT]] {comparison_direction = #chlo} : (tensor, tensor) -> tensor - // CHECK: return %[[LE]] : tensor - %equal = "tf.ApproximateEqual"(%arg0, %arg1) { tolerance = 2. : f32 } : (tensor>, tensor>) -> tensor - func.return %equal : tensor -} - -//===----------------------------------------------------------------------===// -// tf.XlaConvV2 legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: xla_conv_v2 -func.func @xla_conv_v2(%lhs: tensor<8x4x16x16x16xf32>, %rhs: tensor<4x3x3x16x16xf32>) -> (tensor<4x4x14x14x16xf32>) { - %feature_group_count = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - %lhs_dilation = "tf.Const"() {value = dense<[4, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32> - %rhs_dilation = "tf.Const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32> - %padding = "tf.Const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32> - %strides = "tf.Const"() {value = dense<[3, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32> - // CHECK: mhlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f], window = {stride = [3, 1, 1], pad = {{\[\[}}0, 0], {{\[}}0, 0], {{\[}}0, 0]], lhs_dilate = [4, 1, 1], rhs_dilate = [1, 1, 1]} {batch_group_count = 2 : i64, feature_group_count = 1 : i64, precision_config = []} : (tensor<8x4x16x16x16xf32>, tensor<4x3x3x16x16xf32>) -> tensor<4x4x14x14x16xf32> - %0 = "tf.XlaConvV2"(%lhs, %rhs, %strides, %padding, %lhs_dilation, %rhs_dilation, %feature_group_count) {batch_group_count = 2 : i64, dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", precision_config = ""} : (tensor<8x4x16x16x16xf32>, tensor<4x3x3x16x16xf32>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<4x4x14x14x16xf32> - func.return %0 : tensor<4x4x14x14x16xf32> -} - -//===----------------------------------------------------------------------===// -// tf.XlaDot legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @xladot_matmul( -// CHECK-SAME: %[[LHS:.*]]: tensor<64x32xi8>, %[[RHS:.*]]: tensor<32x16xi8>) -> tensor<64x16xi32> -func.func @xladot_matmul(%lhs : tensor<64x32xi8>, %rhs : tensor<32x16xi8>) -> tensor<64x16xi32> { - // CHECK: "mhlo.dot_general"(%[[LHS]], %[[RHS]]) <{ - // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< - // CHECK-NOT: lhs_batching_dimensions = - // CHECK-NOT: rhs_batching_dimensions = - // CHECK-SAME: lhs_contracting_dimensions = [1] - // CHECK-SAME: rhs_contracting_dimensions = [0] - // CHECK-SAME: precision_config = [] - %res = "tf.XlaDot"(%lhs, %rhs) {dimension_numbers = "\0A\01\01\12\01\00", precision_config = ""} : (tensor<64x32xi8>, tensor<32x16xi8>) -> tensor<64x16xi32> - func.return %res : tensor<64x16xi32> -} - -//===----------------------------------------------------------------------===// -// tf.XlaDotV2 legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @xladotv2_matmul( -// CHECK-SAME: %[[LHS:.*]]: tensor<64x32xi8>, %[[RHS:.*]]: tensor<32x16xi8>) -> tensor<64x16xi32> -func.func @xladotv2_matmul(%lhs : tensor<64x32xi8>, %rhs : tensor<32x16xi8>) -> tensor<64x16xi32> { - // CHECK: "mhlo.dot_general"(%[[LHS]], %[[RHS]]) <{ - // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< - // CHECK-NOT: lhs_batching_dimensions = - // CHECK-NOT: rhs_batching_dimensions = - // CHECK-SAME: lhs_contracting_dimensions = [1] - // CHECK-SAME: rhs_contracting_dimensions = [0] - // CHECK-SAME: precision_config = [] - %res = "tf.XlaDotV2"(%lhs, %rhs) {dimension_numbers = "\0A\01\01\12\01\00", precision_config = ""} : (tensor<64x32xi8>, tensor<32x16xi8>) -> tensor<64x16xi32> - func.return %res : tensor<64x16xi32> -} - -//===----------------------------------------------------------------------===// -// tf.XlaDynamicSlice legalization -//===----------------------------------------------------------------------===// -// ----- - -// CHECK-LABEL: xla_dynamic_slice_constant_start -func.func @xla_dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { - // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor - // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, - // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, - // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : - // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor<1xi64> - // CHECK-DAG-SAME: (tensor<1xi64>) -> tensor - // CHECK-NEXT: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[START]]) - // CHECK-DAG-SAME: {slice_sizes = dense<2> : tensor<1xi64>} : - // CHECK-DAG-SAME: (tensor<4xi32>, tensor) -> tensor<2xi32> - // CHECK-NEXT: return %[[RESULT]] : tensor<2xi32> - %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi64>} : () -> (tensor<1xi64>) - %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi64>} : () -> (tensor<1xi64>) - %0 = "tf.XlaDynamicSlice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<2xi32> - func.return %0 : tensor<2xi32> -} - -// ----- - -// CHECK-LABEL: xla_dynamic_slice_i32_consts -func.func @xla_dynamic_slice_i32_consts(%arg0: tensor<4xi32>) -> tensor<2xi32> { - // CHECK: %[[START:.*]] = mhlo.constant dense<1> : tensor - // CHECK: "mhlo.dynamic_slice"(%arg0, %[[START]]) <{slice_sizes = dense<2> : tensor<1xi64>}> : (tensor<4xi32>, tensor) -> tensor<2xi32> - %starts = "tf.Const"() {value = dense<[1]> : tensor<1xi32>} : () -> (tensor<1xi32>) - %sizes = "tf.Const"() {value = dense<[2]> : tensor<1xi32>} : () -> (tensor<1xi32>) - %0 = "tf.XlaDynamicSlice"(%arg0, %starts, %sizes) : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> - func.return %0 : tensor<2xi32> -} - -// ----- - -// CHECK-LABEL: xla_dynamic_slice_constant_start_dynamic_shape -func.func @xla_dynamic_slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { - // CHECK-DAG: %[[START1:.*]] = mhlo.constant dense<1> : tensor - // CHECK-DAG: %[[START2:.*]] = mhlo.constant dense<0> : tensor - // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice" - // CHECK-DAG-SAME: (%arg0, %[[START1]], %[[START2]]) - // CHECK-DAG-SAME: {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : - // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor<1x4xi32> - // CHECK: return %[[RESULT]] : tensor<1x4xi32> - %starts = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi64>} : () -> (tensor<2xi64>) - %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) - %0 = "tf.XlaDynamicSlice"(%arg0, %starts, %sizes) : (tensor, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> - func.return %0 : tensor<1x4xi32> -} - -// ----- - -// CHECK-LABEL: xla_dynamic_slice_variable_start -func.func @xla_dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { - // CHECK: %[[SLICED_START1:.*]] = "mhlo.slice"(%arg1) - // CHECK-DAG-SAME: {limit_indices = dense<1> : tensor<1xi64>, - // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, - // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> - // CHECK: %[[RESHAPED_START1:.*]] = mhlo.reshape %[[SLICED_START1]] : (tensor<1xi64>) -> tensor - // CHECK: %[[SLICED_START2:.*]] = "mhlo.slice"(%arg1) - // CHECK-DAG-SAME: {limit_indices = dense<2> : tensor<1xi64>, - // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>, - // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64> - // CHECK: %[[RESHAPED_START2:.*]] = mhlo.reshape %[[SLICED_START2]] : (tensor<1xi64>) -> tensor - // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_slice"(%arg0, %[[RESHAPED_START1]], %[[RESHAPED_START2]]) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> - // CHECK: return %[[RESULT]] : tensor<1x4xi32> - %sizes = "tf.Const"() {value = dense<[1, 4]> : tensor<2xi64>} : () -> (tensor<2xi64>) - %0 = "tf.XlaDynamicSlice"(%arg0, %arg1, %sizes) : (tensor<3x4xi32>, tensor<2xi64>, tensor<2xi64>) -> tensor<1x4xi32> - func.return %0 : tensor<1x4xi32> -} - -// ----- - -// CHECK-LABEL: xla_dynamic_slice_mhlo_sizes -func.func @xla_dynamic_slice_mhlo_sizes(%arg0: tensor<1x1024x4xf32>, %arg1: tensor<3xi32>) -> tensor<1x512x4xf32> { - // CHECK-NOT: "tf.XlaDynamicSlice" - %0 = "mhlo.constant"() {value = dense<[1, 512, 4]> : tensor<3xi32>} : () -> tensor<3xi32> - %1 = "tf.XlaDynamicSlice"(%arg0, %arg1, %0) : (tensor<1x1024x4xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x512x4xf32> - func.return %1 : tensor<1x512x4xf32> -} - -//===----------------------------------------------------------------------===// -// tf.XlaEinsum legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: func @xlaeinsum -func.func @xlaeinsum(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>) -> tensor<2x4xf32> { - // CHECK-NEXT: mhlo.einsum - %0 = "tf.XlaEinsum"(%arg0, %arg1) {equation = "ab,bc->ac"} : (tensor<2x3xf32>, tensor<3x4xf32>) -> tensor<2x4xf32> - func.return %0: tensor<2x4xf32> -} - - -//===----------------------------------------------------------------------===// -// tf.XlaReduceWindow legalization -//===----------------------------------------------------------------------===// -// ----- -// CHECK-LABEL: @test_xla_reduce_window -func.func @test_xla_reduce_window(%arg0: tensor<7xf32>, %arg1: tensor) -> tensor<10xf32> { - %cst = "tf.Const"() {value = dense<0> : tensor<1x2xi32>} : () -> tensor<1x2xi32> - %cst_0 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> - %cst_1 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> - %cst_2 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32> - %cst_3 = "tf.Const"() {value = dense<4> : tensor<1xi32>} : () -> tensor<1xi32> - // CHECK: %[[REDUCE:.*]] = "mhlo.reduce_window"(%arg0, %arg1) <{base_dilations = dense<3> : tensor<1xi64>, padding = dense<0> : tensor<1x2xi64>, window_dilations = dense<4> : tensor<1xi64>, window_dimensions = dense<1> : tensor<1xi64>, window_strides = dense<2> : tensor<1xi64>}> ({ - // CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) - // CHECK-NEXT: %[[SUM:.*]] = func.call @sum_reducer3(%[[ARG0]], %[[ARG1]]){{.*}} - // CHECK-NEXT: mhlo.return %[[SUM]] : tensor - // CHECK-NEXT: }) : (tensor<7xf32>, tensor) -> tensor<10xf32> - // CHECK-NEXT: return %[[REDUCE]] - %0 = "tf.XlaReduceWindow"(%arg0, %arg1, %cst_0, %cst_1, %cst_2, %cst_3, %cst) {computation = @sum_reducer3} : (tensor<7xf32>, tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<10xf32> - func.return %0 : tensor<10xf32> -} - -func.func private @sum_reducer3(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "tf.AddV2"(%arg0, %arg1) {device = ""} : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -//===----------------------------------------------------------------------===// -// tf.XlaSort legalization -//===----------------------------------------------------------------------===// - -// ----- - -// CHECK-LABEL: @xlasort_int -// CHECK-SAME: %[[INPUT:.*]]: tensor<16xi32> -func.func @xlasort_int(%input: tensor<16xi32>) -> (tensor<16xi32>) { - // CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]]) <{dimension = -1 : i64, is_stable = false}> ({ - // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor) - // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare LT, %[[LHS]], %[[RHS]], NOTYPE - // CHECK-NEXT: mhlo.return %[[CMP]] - // CHECK-NEXT: }) : (tensor<16xi32>) -> tensor<16xi32> - // CHECK-NEXT: return %[[SORT]] - %output = "tf.XlaSort"(%input) : (tensor<16xi32>) -> (tensor<16xi32>) - func.return %output : tensor<16xi32> -} - -// ----- - -// CHECK-LABEL: @xlasort_float -// CHECK-SAME: %[[INPUT:.*]]: tensor<8xf64> -func.func @xlasort_float(%input: tensor<8xf64>) -> (tensor<8xf64>) { - // CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]]) <{dimension = -1 : i64, is_stable = false}> ({ - // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor) - // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare LT, %[[LHS]], %[[RHS]], TOTALORDER - // CHECK-NEXT: mhlo.return %[[CMP]] - // CHECK-NEXT: }) : (tensor<8xf64>) -> tensor<8xf64> - // CHECK-NEXT: return %[[SORT]] - %output = "tf.XlaSort"(%input) : (tensor<8xf64>) -> (tensor<8xf64>) - func.return %output : tensor<8xf64> -} - -// ----- - -// CHECK-LABEL: @xlasort_const -func.func @xlasort_const() -> (tensor<2x3xi64>) { - // CHECK: [2, 4, 3], [6, 5, 1] - %input = "tf.Const"() {value = dense<[[2, 4, 3], [6, 5, 1]]> : tensor<2x3xi64>} : () -> (tensor<2x3xi64>) - // CHECK-NEXT: [2, 3, 4], [1, 5, 6] - %output = "tf.XlaSort"(%input): (tensor<2x3xi64>) -> (tensor<2x3xi64>) - func.return %output : tensor<2x3xi64> -} - -//===----------------------------------------------------------------------===// -// tf.XlaRngBitGenerator legalization -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: @xla_rng_bit_generator -// CHECK-SAME: %[[STATE:.*]]: tensor<2xui64> -func.func @xla_rng_bit_generator(%arg0: tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>) attributes {tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1,_arg2", outputs = "_retval0,_retval1"}} { - // CHECK-NEXT: %0 = mhlo.constant dense<[10, 12]> : tensor<2xi32> - %cst = "tf.Const"() {value = dense<[10, 12]> : tensor<2xi32>} : () -> tensor<2xi32> - // CHECK-NEXT: %1 = mhlo.constant dense<3> : tensor - %cst_0 = "tf.Const"() {value = dense<3> : tensor} : () -> tensor - // CHECK-NEXT: %[[OUTPUT_STATE:.*]], %[[OUTPUT:.*]] = "mhlo.rng_bit_generator"(%[[STATE]]) <{rng_algorithm = #mhlo.rng_algorithm}> : (tensor<2xui64>) -> (tensor<2xui64>, tensor<10x12xui32>) - // CHECK-NEXT: return %[[OUTPUT_STATE]], %[[OUTPUT]] : tensor<2xui64>, tensor<10x12xui32> - %output_key, %output = "tf.XlaRngBitGenerator"(%cst_0, %arg0, %cst) : (tensor, tensor<2xui64>, tensor<2xi32>) -> (tensor<2xui64>, tensor<10x12xui32>) - func.return %output_key, %output : tensor<2xui64>, tensor<10x12xui32> -} - -//===----------------------------------------------------------------------===// -// tf.XlaVariadicV2 legalization -//===----------------------------------------------------------------------===// - -// ----- -// CHECK-LABEL: @xla_variadic_reduce_v2 -func.func @xla_variadic_reduce_v2(%arg0: tensor<2x3xcomplex>, %arg1: tensor>) -> tensor<3xcomplex> attributes {tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1", outputs = "_retval0"}} { - // CHECK: %[[REDUCE:.*]] = mhlo.reduce(%arg0 init: %arg1) - // CHECK-SAME: dimensions = [0] - // CHECK-NEXT: (%[[ARG0:.*]]: tensor>, %[[ARG1:.*]]: tensor>) - // CHECK-NEXT: %[[SUM:.*]] = func.call @sum_reducer(%[[ARG0]], %[[ARG1]]){{.*}} - // CHECK-NEXT: mhlo.return %[[SUM]] : tensor> - // CHECK: return %[[REDUCE]] - %0 = "tf.XlaVariadicReduceV2"(%arg0, %arg1) {_XlaHasReferenceVars = false, device = "/job:localhost/replica:0/task:0/device:XLA_GPU:0", dimensions_to_reduce = [0], operandSegmentSizes = array, reducer = @sum_reducer} : (tensor<2x3xcomplex>, tensor>) -> tensor<3xcomplex> - func.return %0 : tensor<3xcomplex> -} - -func.func private @sum_reducer(%arg0: tensor>, %arg1: tensor>) -> tensor> { - %0 = "tf.AddV2"(%arg1, %arg0) : (tensor>, tensor>) -> tensor> - func.return %0 : tensor> -} - -// ----- - -// CHECK-LABEL: @xla_variadic_reduce_v2_dynamic -func.func @xla_variadic_reduce_v2_dynamic(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1", outputs = "_retval0"}} { - // CHECK: %[[REDUCE:.*]] = mhlo.reduce(%arg0 init: %arg1) - // CHECK-SAME: dimensions = [0] - // CHECK-NEXT: (%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) - // CHECK-NEXT: %[[SUM:.*]] = func.call @sum_reducer2(%[[ARG0]], %[[ARG1]]){{.*}} - // CHECK-NEXT: mhlo.return %[[SUM]] : tensor - // CHECK: return %[[REDUCE]] - %0 = "tf.XlaVariadicReduceV2"(%arg0, %arg1) {_XlaHasReferenceVars = false, device = "/job:localhost/replica:0/task:0/device:XLA_GPU:0", dimensions_to_reduce = [0], operandSegmentSizes = array, reducer = @sum_reducer2} : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -func.func private @sum_reducer2(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "tf.AddV2"(%arg1, %arg0) : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -//===----------------------------------------------------------------------===// -// tf.XlaVariadicSort legalization -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: @xla_variadic_sort -// CHECK-SAME: %[[INPUT:.*]]: tensor<2x3x4xui8> -func.func @xla_variadic_sort(%arg0: tensor<2x3x4xui8>) -> tensor<2x3x4xui8> attributes {tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1", outputs = "_retval0"}} { - // CHECK-NEXT: {{.*}} = mhlo.constant dense<0> : tensor - %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - // CHECK-NEXT: %[[SORT:.*]] = "mhlo.sort"(%[[INPUT]]) <{dimension = 0 : i64, is_stable = false}> ({ - // CHECK-NEXT: ^{{.*}}(%[[LHS:.*]]: tensor, %[[RHS:.*]]: tensor) - // CHECK-NEXT: %[[CMP:.*]] = func.call @compare_lt(%[[LHS]], %[[RHS]]) : (tensor, tensor) -> tensor - // CHECK-NEXT: mhlo.return %[[CMP]] - // CHECK-NEXT: }) : (tensor<2x3x4xui8>) -> tensor<2x3x4xui8> - // CHECK-NEXT: return %[[SORT]] - %0 = "tf.XlaVariadicSort"(%arg0, %cst) {_XlaHasReferenceVars = false, comparator = @compare_lt, device = "/job:localhost/replica:0/task:0/device:XLA_GPU:0", is_stable = false} : (tensor<2x3x4xui8>, tensor) -> tensor<2x3x4xui8> - func.return %0 : tensor<2x3x4xui8> -} - -func.func private @compare_lt(%arg0: tensor, %arg1: tensor) -> tensor attributes {tf._disable_call_shape_inference = true} { - %0 = "tf.Less"(%arg0, %arg1) {device = ""} : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -//===----------------------------------------------------------------------===// -// tf.NextAfter legalization -//===----------------------------------------------------------------------===// -// CHECK-LABEL: func @nextafter -func.func @nextafter(%arg0: tensor<2xf32>, %arg1 : tensor<2xf32>) -> tensor<2xf32> { - // CHECK-NEXT: %0 = chlo.broadcast_next_after %arg0, %arg1 : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - // CHECK-NEXT: return %0 : tensor<2xf32> - %0 = "tf.NextAfter"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> - func.return %0: tensor<2xf32> -} - -//===----------------------------------------------------------------------===// -// tf.XlaReduceScatter legalization -//===----------------------------------------------------------------------===// -// CHECK-LABEL: func @xla_reduce_scatter -func.func @xla_reduce_scatter(%arg0: tensor<128x128xf32>) -> tensor<64x128xf32> { - %cst = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %cst_0 = "tf.Const"() {value = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi32>} : () -> tensor<4x2xi32> - // CHECK: "mhlo.reduce_scatter"(%arg0) - // CHECK{LITERAL}: replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> - // CHECK-SAME: scatter_dimension = 0 - // - %1 = "tf.XlaReduceScatter"(%arg0, %cst_0, %cst) {reduce_op = "Add"} : (tensor<128x128xf32>, tensor<4x2xi32>, tensor) -> tensor<64x128xf32> - func.return %1 : tensor<64x128xf32> -} - - -//===----------------------------------------------------------------------===// -// tf.XlaSelectAndScatter legalization -//===----------------------------------------------------------------------===// -func.func @test_xla_select_and_scatter(%arg0: tensor<4x5x1x1xbf16>, %arg1: tensor<2x2x1x1xbf16>, %arg2: tensor) -> tensor { - %cst = "tf.Const"() {value = dense<0> : tensor<4x2xi32>} : () -> tensor<4x2xi32> - %cst_0 = "tf.Const"() {value = dense<[2, 2, 1, 1]> : tensor<4xi32>} : () -> tensor<4xi32> - %cst_1 = "tf.Const"() {value = dense<[2, 3, 1, 1]> : tensor<4xi32>} : () -> tensor<4xi32> - // CHECK: %[[SELECT_AND_SCATTER:.*]] = "mhlo.select_and_scatter"(%arg0, %arg1, %arg2) <{padding = dense<0> : tensor<4x2xi64>, window_dimensions = dense<[2, 3, 1, 1]> : tensor<4xi64>, window_strides = dense<[2, 2, 1, 1]> : tensor<4xi64>}> ({ - // CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) - // CHECK-NEXT: %[[RES:.*]] = func.call @ge_select(%[[ARG0]], %[[ARG1]]){{.*}} - // CHECK-NEXT: mhlo.return %[[RES]] : tensor - // CHECK-NEXT: }, { - // CHECK-NEXT: ^{{.*}}(%[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor) - // CHECK-NEXT: %[[RES:.*]] = func.call @add_scatter(%[[ARG2]], %[[ARG3]]){{.*}} - // CHECK-NEXT: mhlo.return %[[RES]] : tensor - // CHECK-NEXT: }) : (tensor<4x5x1x1xbf16>, tensor<2x2x1x1xbf16>, tensor) -> tensor - // CHECK-NEXT: return %[[SELECT_AND_SCATTER]] - %0 = "tf.XlaSelectAndScatter"(%arg0, %cst_1, %cst_0, %cst, %arg1, %arg2) {scatter = @add_scatter, select = @ge_select} : (tensor<4x5x1x1xbf16>, tensor<4xi32>, tensor<4xi32>, tensor<4x2xi32>, tensor<2x2x1x1xbf16>, tensor) -> tensor - func.return %0 : tensor -} - -func.func private @add_scatter(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "tf.AddV2"(%arg0, %arg1) {device = ""} : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -func.func private @ge_select(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "tf.GreaterEqual"(%arg0, %arg1) {device = ""} : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -//===----------------------------------------------------------------------===// -// tf.XlaOptimizationBarrier legalization -//===----------------------------------------------------------------------===// - -func.func @test_xla_optimization_barrier(%arg0: tensor<4x4xf32>, %arg1: tensor<3x4xi32>) -> (tensor<4x4xf32>, tensor<3x4xi32>) { - // CHECK: %[[OPT_BARRIER:.*]]:2 = mhlo.optimization_barrier %arg0, %arg1 - // CHECK-NEXT: return %[[OPT_BARRIER]]#0, %[[OPT_BARRIER]]#1 - %0, %1 = "tf.XlaOptimizationBarrier"(%arg0, %arg1) : (tensor<4x4xf32>, tensor<3x4xi32>) -> (tensor<4x4xf32>, tensor<3x4xi32>) - func.return %0, %1 : tensor<4x4xf32>, tensor<3x4xi32> -} - -// CHECK-LABEL: @ifRegion -// CHECK-SAME: ([[ARG0:%.+]]: tensor, [[ARG1:%.+]]: tensor) -func.func @ifRegion(%arg0: tensor, %arg1: tensor) -> (tensor) { - // CHECK: [[VAL0:%.+]] = mhlo.compare GT, [[ARG0]], [[ARG1]] - %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - // CHECK: [[VAL1:%.+]] = "mhlo.if"([[VAL0]]) ({ - %1 = "tf.IfRegion"(%0) ({ - // CHECK: [[VAL2:%.+]] = mhlo.log [[ARG0]] - %2 = "tf.Log"(%arg0) : (tensor) -> tensor - // CHECK: mhlo.return [[VAL2]] - "tf.Yield"(%2) : (tensor) -> () - }, { - // CHECK: [[VAL3:%.+]] = mhlo.exponential [[ARG1]] - %2 = "tf.Exp"(%arg1) : (tensor) -> tensor - // CHECK: mhlo.return [[VAL3]] - "tf.Yield"(%2) : (tensor) -> () - // CHECK: }) : (tensor) -> tensor - }) {is_stateless = true} : (tensor) -> tensor - // CHECK: return [[VAL1]] - func.return %1 : tensor -} - -// CHECK-LABEL: func @caseRegion -// CHECK-SAME: ([[BRANCH_INDEX:%.+]]: tensor, [[ARG0:.+]]: tensor, [[ARG1:%.+]]: tensor) -func.func @caseRegion(%index: tensor, %arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - // CHECK: [[VAL1:%.+]]:2 = "mhlo.case"([[BRANCH_INDEX]]) ({ - %0:2 = "tf.CaseRegion"(%index) ({ - // CHECK: [[VAL2:%.+]] = mhlo.exponential [[ARG1]] - %1 = mhlo.exponential %arg1 : (tensor) -> tensor - // CHECK: mhlo.return [[VAL2]], [[ARG1]] - "tf.Yield"(%1, %arg1) : (tensor, tensor) -> () - }, { - // CHECK: [[VAL3:%.+]] = mhlo.log [[ARG0]] - %1 = mhlo.log %arg0 : (tensor) -> tensor - // CHECK: mhlo.return [[VAL3]], [[ARG1]] - "tf.Yield"(%1, %arg1) : (tensor, tensor) -> () - }, { - // CHECK: [[VAL4:%.+]] = mhlo.floor [[ARG0]] - %1 = "mhlo.floor"(%arg0) : (tensor) -> tensor - // CHECK: mhlo.return [[VAL4]], [[ARG1]] - "tf.Yield"(%1, %arg1) : (tensor, tensor) -> () - // CHECK: }) : (tensor) -> (tensor, tensor) - }) {is_stateless = true} : (tensor) -> (tensor, tensor) - // CHECK: return [[VAL1]]#0, [[VAL1]]#1 : tensor, tensor - func.return %0#0, %0#1 : tensor, tensor -} - -// ----- - -// This test case also ensures the mhlo dialect is loaded as a dependency by the -// pass and hence the split here. - -// CHECK-LABEL: func @whileRegion -func.func @whileRegion() -> tensor { - %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - %1 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor - %2:3 = "tf.WhileRegion"(%0, %1, %0) ({ - ^cond(%carg0: tensor, %carg1: tensor, %carg2: tensor): - %3 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - "tf.Yield"(%3) : (tensor) -> () - }, { - ^body(%barg0: tensor, %barg1: tensor, %barg2: tensor): - %4 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - "tf.Yield"(%4, %4, %4) : (tensor, tensor, tensor) -> () - }) {is_stateless = true, parallel_iterations = 10 : i64} : (tensor, tensor, tensor) -> (tensor, tensor, tensor) - func.return %2#2 : tensor -} - -// ----- - -// CHECK-LABEL: func @whileRegionAdd -func.func @whileRegionAdd() -> tensor { - // CHECK: [[VAL0:%.+]] = mhlo.constant - %0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor - // CHECK: [[VAL1:%.+]] = mhlo.constant - %1 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor - // CHECK: [[VAL2:%.+]]:3 = mhlo.while([[ITER_ARG0:.*]] = [[VAL0]], [[ITER_ARG1:.*]] = [[VAL1]], [[ITER_ARG2:.*]] = [[VAL0]]) - %2:3 = "tf.WhileRegion"(%0, %1, %0) ({ - ^cond(%carg0: tensor, %carg1: tensor, %carg2: tensor): - // CHECK: [[VAL3:%.+]] = mhlo.constant - %3 = "tf.Const"() {value = dense<10> : tensor} : () -> tensor - // CHECK: [[VAL4:%.+]] = mhlo.compare LT, [[ITER_ARG2]], [[VAL3]] - %4 = "mhlo.compare"(%carg2, %3) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - // CHECK: mhlo.return [[VAL4]] - "tf.Yield"(%4) : (tensor) -> () - }, { - ^body(%barg0: tensor, %barg1: tensor, %barg2: tensor): - // CHECK: [[VAL5:%.+]] = mhlo.constant - %5 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor - // CHECK: [[VAL6:%.+]] = mhlo.add [[ITER_ARG2]], [[VAL5]] - %6 = mhlo.add %barg2, %5 : tensor - // CHECK: [[VAL7:%.+]] = mhlo.add [[ITER_ARG0]], [[VAL5]] - %7 = mhlo.add %barg0, %5 : tensor - // CHECK: mhlo.return [[VAL7]], [[ITER_ARG1]], [[VAL6]] - "tf.Yield"(%7, %barg1, %6) : (tensor, tensor, tensor) -> () - }) {is_stateless = true, parallel_iterations = 10 : i64} : (tensor, tensor, tensor) -> (tensor, tensor, tensor) - // CHECK: return [[VAL2]]#2 - func.return %2#2 : tensor -} - -// ----- - -// CHECK-LABEL: func @whileRegionImplicitInputs -// CHECK-SAME: ([[ARG0:%.+]]: tensor) -func.func @whileRegionImplicitInputs(%arg0: tensor) -> tensor { - // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0> - %0 = mhlo.constant dense<0> : tensor - // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1> - %1 = mhlo.constant dense<-1> : tensor - // CHECK: [[VAL2:%.+]] = mhlo.while([[ITER_ARG0:.*]] = [[ARG0]]) - %2 = "tf.WhileRegion"(%arg0) ({ - ^cond(%carg0: tensor): - // CHECK: [[VAL3:%.+]] = mhlo.compare LT, [[ITER_ARG0]], [[VAL0]] - %3 = mhlo.compare LT, %carg0, %0 : (tensor, tensor) -> tensor - // CHECK: mhlo.return [[VAL3]] - "tf.Yield"(%3) : (tensor) -> () - }, { - ^body(%barg0: tensor): - // CHECK: [[VAL3:%.+]] = mhlo.add [[ITER_ARG0]], [[VAL1]] - %3 = mhlo.add %barg0, %1 : tensor - // CHECK: [[VAL4:%.+]] = mhlo.add [[ITER_ARG0]], [[VAL3]] - %4 = mhlo.add %barg0, %3 : tensor - // CHECK: mhlo.return [[VAL4]] - "tf.Yield"(%4) : (tensor) -> () - }) {is_stateless = true, parallel_iterations = 10 : i64} : (tensor) -> tensor - // CHECK: return [[VAL2]] - func.return %2 : tensor -} - -// CHECK-LABEL: func @whileRegionMultipleImplicitInputs -func.func @whileRegionMultipleImplicitInputs() { - // CHECK: [[VAL0:%.+]] = mhlo.constant dense<0> - %0 = mhlo.constant dense<0> : tensor - // CHECK: [[VAL1:%.+]] = mhlo.constant dense<-1> - %1 = mhlo.constant dense<-1> : tensor - // CHECK: mhlo.while() - "tf.WhileRegion"() ({ - // CHECK: [[VAL3:%.+]] = mhlo.compare LT, [[VAL0]], [[VAL1]] - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - // CHECK: mhlo.return [[VAL3]] - "tf.Yield"(%2) : (tensor) -> () - }, { - // CHECK: [[VAL3:%.+]] = mhlo.add [[VAL0]], [[VAL1]] - %2 = mhlo.add %0, %1 : tensor - // CHECK: mhlo.return - "tf.Yield"() : () -> () - }) {is_stateless = true, parallel_iterations = 10 : i64} : () -> () - // CHECK: return - func.return -} diff --git a/tensorflow/compiler/mlir/tf2xla/tests/registration/BUILD b/tensorflow/compiler/mlir/tf2xla/tests/registration/BUILD new file mode 100644 index 00000000000000..f0829ca95721d0 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/tests/registration/BUILD @@ -0,0 +1,60 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "graph_to_tf_executor_registration", + srcs = [ + "graph_to_tf_executor_registration.cc", + ], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow/translate:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tensorflow/translate:translate_cl_options", + "//tensorflow/compiler/mlir/tensorflow/translate:translate_lib", + "//tensorflow/compiler/mlir/tf2xla/api/v2:tf_executor_to_graph", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TranslateLib", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:status", + "@local_xla//xla/client:client_library", + "@local_xla//xla/client:compile_only_client", + "@local_xla//xla/service/cpu:cpu_compiler", + "@local_xla//xla/service/cpu:cpu_transfer_manager", + "@local_xla//xla/stream_executor", + "@local_xla//xla/stream_executor/host:host_platform", + "@local_xla//xla/stream_executor/host:host_platform_id", + ], + alwayslink = 1, +) + +tf_cc_test( + name = "graph_to_tf_executor_registration_test", + size = "small", + srcs = ["graph_to_tf_executor_registration_test.cc"], + deps = [ + ":graph_to_tf_executor_registration", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TranslateLib", + ], +) diff --git a/tensorflow/compiler/mlir/tf2xla/tests/registration/graph_to_tf_executor_registration.cc b/tensorflow/compiler/mlir/tf2xla/tests/registration/graph_to_tf_executor_registration.cc new file mode 100644 index 00000000000000..8c7eb5e66d8e40 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/tests/registration/graph_to_tf_executor_registration.cc @@ -0,0 +1,193 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/LogicalResult.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_cl.h" +#include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "xla/client/client_library.h" +#include "xla/client/compile_only_client.h" +#include "xla/stream_executor/host/host_platform_id.h" +#include "xla/stream_executor/platform_manager.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.h" +#include "tsl/platform/protobuf.h" +#include "tsl/platform/status.h" + +namespace tensorflow { +namespace tf2xla { +namespace v2 { +namespace testing { + +using tsl::Status; + +static constexpr char kMlirToGraphCompilationCheckName[] = + "mlir-to-graph-compilation-check"; +// Use CPU arbitrarily in order to check that a graph compiles at all +static constexpr char kArbitraryDeviceName[] = "XLA_CPU_JIT"; + +static Status CompileGraph(tensorflow::Graph* graph, + xla::CompileOnlyClient* client) { + if (!graph || !client) { + return Status(absl::StatusCode::kInvalidArgument, + "Invalid graph or client"); + } + + tensorflow::FunctionDefLibrary flib; + auto flib_def = std::make_unique( + tensorflow::OpRegistry::Global(), flib); + + tensorflow::XlaCompiler::Options options; + options.device_type = tensorflow::DeviceType(kArbitraryDeviceName); + options.client = client; + options.flib_def = flib_def.get(); + tensorflow::XlaCompiler compiler(options); + + std::unique_ptr graph_copy( + new tensorflow::Graph(tensorflow::OpRegistry::Global())); + tensorflow::CopyGraph(*graph, graph_copy.get()); + + tensorflow::XlaCompiler::CompileOptions compile_options; + tensorflow::XlaCompiler::CompilationResult result; + return compiler.CompileGraph(compile_options, + kMlirToGraphCompilationCheckName, + std::move(graph_copy), {}, &result); +} + +static mlir::OwningOpRef GraphdefToMlirTranslateFunction( + llvm::StringRef input, mlir::MLIRContext* context) { + tensorflow::GraphdefToMlirOptions options{ + debug_info_file, xla_compile_device_type, + prune_unused_nodes, convert_legacy_fed_inputs, + graph_as_function, upgrade_legacy, + enable_shape_inference, unconditionally_use_set_output_shapes, + enable_soft_placement, set_original_tf_func_name}; + + auto module_or = tensorflow::GraphdefToMlirTranslateFunction( + input, input_arrays, input_dtypes, input_shapes, output_arrays, + control_output_arrays, options, context); + if (!module_or.status().ok()) return nullptr; + return std::move(module_or).value(); +} + +static mlir::TranslateToMLIRRegistration GraphdefToMlirTranslate( + "graphdef-to-mlir", "graphdef-to-mlir", GraphdefToMlirTranslateFunction); + +static mlir::LogicalResult MlirToGraphTranslateFunction( + mlir::ModuleOp module, llvm::raw_ostream& output) { + if (!module) return mlir::failure(); + + tensorflow::GraphExportConfig confs; + confs.export_entry_func_to_flib = export_entry_func_to_flib; + confs.export_original_tf_func_name = export_original_tf_func_name; + + std::unique_ptr flib_def; + auto graph = + std::make_unique(tensorflow::OpRegistry::Global()); + absl::flat_hash_set control_ret_nodes; + auto status = tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( + module, confs, &graph, flib_def.get(), &control_ret_nodes); + if (!status.ok()) { + LOG(ERROR) << "Export to Graph failed: " << status; + return mlir::failure(); + } + + // Use Host platform, which should always exist, to make sure graphs compile. + auto platform = stream_executor::PlatformManager::PlatformWithId( + stream_executor::host::kHostPlatformId); + if (!platform.ok()) { + return mlir::failure(); + } + auto client = + xla::ClientLibrary::GetOrCreateCompileOnlyClient(platform.value()); + + tensorflow::XlaOpRegistry::RegisterCompilationKernels(); + + // Verify that the resulting graph can compile. + if (client.ok() && !CompileGraph(graph.get(), client.value()).ok()) { + return mlir::failure(); + } + + auto graphdef = std::make_unique(); + // Print the graph to the output after going through GraphDef conversion. + // The DumpGraphToFile would do this anyway so just skip straight to it. + graph->ToGraphDef(graphdef.get()); + output << tsl::LegacyUnredactedDebugString(*graphdef); + + return mlir::success(); +} + +static mlir::TranslateFromMLIRRegistration mlir_to_graph_translate( + /*name=*/"mlir-to-graph", /*description=*/"convert mlir to graph", + MlirToGraphTranslateFunction, [](mlir::DialectRegistry& registry) { + mlir::RegisterAllTensorFlowDialects(registry); + }); + +static llvm::LogicalResult MlirToGraphdefTranslateFunction( + mlir::ModuleOp module, llvm::raw_ostream& output) { + if (!module) return mlir::failure(); + + tensorflow::GraphExportConfig confs; + confs.export_entry_func_to_flib = export_entry_func_to_flib; + confs.export_original_tf_func_name = export_original_tf_func_name; + + tensorflow::FunctionLibraryDefinition flib_def( + tensorflow::OpRegistry::Global(), tensorflow::FunctionDefLibrary()); + auto graph = + std::make_unique(tensorflow::OpRegistry::Global()); + absl::flat_hash_set control_ret_nodes; + + auto status = tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( + module, confs, &graph, &flib_def, &control_ret_nodes); + if (!status.ok()) { + LOG(ERROR) << "Export to Graph failed: " << status; + return mlir::failure(); + } + + tensorflow::GraphDef graphdef; + graph->ToGraphDef(&graphdef); + output << tsl::LegacyUnredactedDebugString(graphdef); + return mlir::success(); +} + +static mlir::TranslateFromMLIRRegistration mlir_to_graphdef_translate( + "mlir-to-graphdef", "mlir-to-graphdef", MlirToGraphdefTranslateFunction, + [](mlir::DialectRegistry& registry) { + mlir::RegisterAllTensorFlowDialects(registry); + }); + +} // namespace testing +} // namespace v2 +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration_test.cc b/tensorflow/compiler/mlir/tf2xla/tests/registration/graph_to_tf_executor_registration_test.cc similarity index 79% rename from tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration_test.cc rename to tensorflow/compiler/mlir/tf2xla/tests/registration/graph_to_tf_executor_registration_test.cc index df4a37c2fe64a9..c83a1483b6d3c5 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/tests/registration/graph_to_tf_executor_registration_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,14 +19,25 @@ limitations under the License. #include #include "absl/strings/match.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project #include "tensorflow/core/platform/test.h" -namespace mlir { -namespace { +namespace tensorflow { +namespace tf2xla { +namespace v2 { +namespace testing { + +using mlir::LogicalResult; +using mlir::StringRef; +using mlir::Translation; +using mlir::TranslationParser; class MlirTranslationTest : public ::testing::Test { private: @@ -52,17 +63,15 @@ class MlirTranslationTest : public ::testing::Test { } private: - llvm::cl::opt* + llvm::cl::opt* RegisterTranslation() { // Can only register once per process. static const auto requested_translation = - new llvm::cl::opt( + new llvm::cl::opt( llvm::cl::desc("Translation to perform")); return requested_translation; } - llvm::cl::opt* - translation_; + llvm::cl::opt* translation_; }; TEST_F(MlirTranslationTest, TranslatesMlirToGraph) { @@ -83,5 +92,7 @@ func.func @main() -> (tensor<1x2xf16>, tensor<2xf16>) { EXPECT_TRUE(absl::StrContains(result, "node {")); } -} // namespace -} // namespace mlir +} // namespace testing +} // namespace v2 +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD index 4092a86dd38e5c..a9872645e40a45 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD @@ -169,12 +169,12 @@ cc_library( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:padding", - "@local_xla//xla/client:sharding_builder", - "@local_xla//xla/client/lib:conv_grad_size_util", + "@local_xla//xla/hlo/builder:padding", + "@local_xla//xla/hlo/builder:sharding_builder", + "@local_xla//xla/hlo/builder/lib:conv_grad_size_util", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:attribute_importer", "@local_xla//xla/mlir_hlo", "@local_xla//xla/mlir_hlo:convert_op_folder", - "@local_xla//xla/translate/hlo_to_mhlo:attribute_importer", "@stablehlo//:chlo_ops", ] + if_static(["@local_tsl//tsl/platform:tensor_float_32_utils"]), ) @@ -293,16 +293,16 @@ cc_library( "@local_xla//xla:shape_util", "@local_xla//xla:side_effect_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:padding", - "@local_xla//xla/client:sharding_builder", + "@local_xla//xla/hlo/builder:padding", + "@local_xla//xla/hlo/builder:sharding_builder", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:attribute_importer", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape", "@local_xla//xla/mlir_hlo", "@local_xla//xla/mlir_hlo:convert_op_folder", "@local_xla//xla/mlir_hlo:mhlo_passes", "@local_xla//xla/mlir_hlo:type_conversion", "@local_xla//xla/stream_executor/tpu:c_api_conversions", "@local_xla//xla/stream_executor/tpu:tpu_api", - "@local_xla//xla/translate/hlo_to_mhlo:attribute_importer", - "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape", "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_ops", ], @@ -354,14 +354,14 @@ cc_library( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client:xla_computation", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder:xla_computation", "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:hlo_function_importer", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape", "@local_xla//xla/mlir_hlo", "@local_xla//xla/service:hlo_proto_cc", - "@local_xla//xla/translate/hlo_to_mhlo:hlo_function_importer", - "@local_xla//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", - "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape", ], ) @@ -390,8 +390,8 @@ tf_cc_test( "@local_tsl//tsl/platform:statusor", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client:xla_computation", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder:xla_computation", "@local_xla//xla/mlir_hlo", "@local_xla//xla/tsl/lib/core:status_test_util", ], @@ -435,7 +435,6 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", - "@local_xla//xla/client:xla_builder", "@local_xla//xla/mlir_hlo", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/infeed_ops_xla_adjust_layout.cc b/tensorflow/compiler/mlir/tf2xla/transforms/infeed_ops_xla_adjust_layout.cc index b2ce3f56ef9960..f1e843b81f5476 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/infeed_ops_xla_adjust_layout.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/infeed_ops_xla_adjust_layout.cc @@ -32,12 +32,12 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/layout.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/shape.h" #include "xla/stream_executor/tpu/c_api_conversions.h" #include "xla/stream_executor/tpu/tpu_api.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" namespace mlir { namespace mhlo { diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc index 9f45164ba4dfe3..47fab19abe61d0 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc @@ -59,13 +59,13 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/utils.h" -#include "xla/client/lib/conv_grad_size_util.h" -#include "xla/client/padding.h" -#include "xla/client/sharding_builder.h" +#include "xla/hlo/builder/lib/conv_grad_size_util.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/sharding_builder.h" +#include "xla/hlo/translate/hlo_to_mhlo/attribute_importer.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/utils/convert_op_folder.h" #include "xla/mlir_hlo/utils/hlo_utils.h" -#include "xla/translate/hlo_to_mhlo/attribute_importer.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/framework/rng_alg.h" diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc index 68c412f79ff393..a3cbb4ba2cd763 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc @@ -42,11 +42,11 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "xla/client/sharding_builder.h" +#include "xla/hlo/builder/sharding_builder.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/primitive_util.h" #include "xla/side_effect_util.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc index b528215c75194f..9057e2406fab06 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc @@ -49,7 +49,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc index 2709f9dada21a7..6f864f8eb52736 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" +#include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" @@ -60,15 +61,15 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo.pb.h" -#include "xla/translate/hlo_to_mhlo/hlo_function_importer.h" -#include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/attr_value.pb.h" @@ -94,6 +95,22 @@ using ::tensorflow::Tensor; using ::tsl::StatusOr; using ::xla::XlaComputation; +// The OpOrArgLocNameMapper adds invalid characters to the name of the op when +// concatenating locations. This version removes those characters to make the +// name valid for NodeDef. +class OpOrArgLocNameMapperWithoutInvalidCharacters + : public tensorflow::OpOrArgLocNameMapper { + public: + OpOrArgLocNameMapperWithoutInvalidCharacters() = default; + ~OpOrArgLocNameMapperWithoutInvalidCharacters() override = default; + + protected: + std::string GetName(tensorflow::OpOrVal op_or_val) override { + std::string name = OpOrArgLocNameMapper::GetName(op_or_val); + return absl::StrReplaceAll(name, {{";", "."}}); + } +}; + static std::unique_ptr CreateDeviceMgr( const std::string& device_type) { // Register compilation kernels for all registered XLA backends. @@ -125,6 +142,8 @@ Tf2XlaRewriter::Tf2XlaRewriter(Operation* op, PatternRewriter& rewriter, : op_(op), device_type_(device_type), rewriter_(rewriter), + name_mapper_( + std::make_unique()), context_(nullptr), xla_builder_(op_->getName().getStringRef().str()) {} @@ -319,7 +338,7 @@ LogicalResult Tf2XlaRewriter::LegalizeOp() { } auto nodedef_or = tensorflow::ConvertTFDialectOpToNodeDef( - op_, name_mapper_.GetUniqueName(op_), + op_, name_mapper_->GetUniqueName(op_), /*ignore_unregistered_attrs=*/true); if (!nodedef_or.ok()) { return op_->emitRemark() << "failed to convert op to NodeDef: " diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h index 2b8c52750a6c44..dc8b0ad459d2e1 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h @@ -29,8 +29,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_expression.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/framework/op_kernel.h" @@ -106,7 +106,7 @@ class Tf2XlaRewriter { std::string device_type_; mlir::PatternRewriter& rewriter_; - tensorflow::OpOrArgLocNameMapper name_mapper_; + std::unique_ptr name_mapper_; tensorflow::XlaContext* context_; // Ref-counted. diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc index 8d0f0404b8980f..2cd2f3591ba0cd 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc @@ -38,8 +38,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/shape_util.h" #include "xla/tsl/lib/core/status_test_util.h" @@ -317,6 +317,22 @@ TEST_F(Tf2XlaRewriterTest, CreatesDefaultValues) { TF_ASSERT_OK(LegalizeModule(kModuleWithOpWithoutValuesThatShouldBeDefaulted)); } +TEST_F(Tf2XlaRewriterTest, OpWithLocationDoesntBreakNodeDefName) { + // A named location 'Name(Source)' causes the GetNameFromLoc method to append + // all the other locations to the name with a ';' separator. This test ensures + // that the name used for the NodeDef does not contain that invalid character. + static constexpr char kModuleWithOpWithoutValuesThatShouldBeDefaulted[] = + R"mlir( + module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1610 : i32}} { + func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "tf.Exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> loc(fused["exp"("exp"), "exp"]) + func.return %0 : tensor<2xf32> + } + })mlir"; + + TF_ASSERT_OK(LegalizeModule(kModuleWithOpWithoutValuesThatShouldBeDefaulted)); +} + TEST_F(Tf2XlaRewriterTest, ErrorsWithInvalidNumberOfParametersToArgs) { XlaBuilder builder("test_builder"); XlaComputation to_apply; diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc index ee377b93b7662c..aa38150e6a14c3 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc @@ -24,7 +24,7 @@ limitations under the License. #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" // from @llvm-project #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.td b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.td index 19c31018185c82..368afec6ef07f0 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.td +++ b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.td @@ -50,7 +50,7 @@ def LegalizeTF : Pass<"xla-legalize-tf", "ModuleOp"> { "chlo::ChloDialect", "func::FuncDialect", "mhlo::MhloDialect", - "quant::QuantizationDialect", + "quant::QuantDialect", "shape::ShapeDialect", "sparse_tensor::SparseTensorDialect", "stablehlo::StablehloDialect" diff --git a/tensorflow/compiler/mlir/tfr/BUILD b/tensorflow/compiler/mlir/tfr/BUILD index 333712a56e55b8..71cbd5066128b3 100644 --- a/tensorflow/compiler/mlir/tfr/BUILD +++ b/tensorflow/compiler/mlir/tfr/BUILD @@ -272,6 +272,9 @@ tf_python_pybind_extension( pytype_srcs = [ "tfr_wrapper.pyi", ], + visibility = [ + "//tensorflow/python:__pkg__", + ], deps = [ ":tfr", "//tensorflow/compiler/mlir/tensorflow", diff --git a/tensorflow/compiler/mlir/tfr/build_defs.bzl b/tensorflow/compiler/mlir/tfr/build_defs.bzl index e9dd5e9178080b..9f10f82f0e1b2f 100644 --- a/tensorflow/compiler/mlir/tfr/build_defs.bzl +++ b/tensorflow/compiler/mlir/tfr/build_defs.bzl @@ -1,9 +1,11 @@ """BUILD extension for TF composition project.""" +load("@local_tsl//third_party/py/rules_pywrap:pywrap.bzl", "use_pywrap_rules") load("//tensorflow:strict.default.bzl", "py_strict_binary", "py_strict_library") load("//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_wrapper_py") load("//tensorflow:tensorflow.default.bzl", "tf_custom_op_py_library") +# TODO(b/356020232): cleanup use_pywrap_rules once migration is done def gen_op_libraries( name, src, @@ -22,7 +24,11 @@ def gen_op_libraries( if not src.endswith(".py") or name == src[:-3]: fail("'src' %s conflicts with op Python wrapper. Rename it to be different from 'name'." % src) - py_deps = [ + py_deps = [] + if use_pywrap_rules(): + py_deps = ["//tensorflow/python:_pywrap_tensorflow"] + + py_deps += [ "//tensorflow/compiler/mlir/tfr:op_reg_gen", "//tensorflow/compiler/mlir/tfr:tfr_gen", "//tensorflow/compiler/mlir/tfr:composite", @@ -42,7 +48,9 @@ def gen_op_libraries( name = registered_op, srcs = [], outs = [name + ".inc.cc"], - cmd = "$(location %s) --output=$@ --gen_register_op=true" % gen_op_lib_exec, + cmd = + "PYWRAP_TARGET='//third_party/tensorflow/python:_pywrap_tensorflow' " + + "$(location %s) --output=$@ --gen_register_op=true" % gen_op_lib_exec, tools = [":" + gen_op_lib_exec], tags = tags, ) @@ -105,7 +113,9 @@ def gen_op_libraries( name = name + "_mlir", srcs = [], outs = [name + ".mlir"], - cmd = "$(location %s) --output=$@ --gen_register_op=false" % gen_tfr_lib_exec, + cmd = + "PYWRAP_TARGET='//third_party/tensorflow/python:_pywrap_tensorflow' " + + "$(location %s) --output=$@ --gen_register_op=false" % gen_tfr_lib_exec, tools = [":" + gen_tfr_lib_exec], tags = tags, ) diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc index 76ea4fa36f17a5..42825c297736ef 100644 --- a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc @@ -29,7 +29,7 @@ limitations under the License. #include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td index b3bd4d618bd808..0f051851ad92ad 100644 --- a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td @@ -21,7 +21,7 @@ limitations under the License. include "mlir/Dialect/Shape/IR/ShapeBase.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/FunctionInterfaces.td" -include "mlir/Dialect/Quant/QuantOpsBase.td" +include "mlir/Dialect/Quant/IR/QuantBase.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" diff --git a/tensorflow/compiler/mlir/tfr/passes/rewrite_quantized_io.cc b/tensorflow/compiler/mlir/tfr/passes/rewrite_quantized_io.cc index babfef28d33bad..34ae51c14ed177 100644 --- a/tensorflow/compiler/mlir/tfr/passes/rewrite_quantized_io.cc +++ b/tensorflow/compiler/mlir/tfr/passes/rewrite_quantized_io.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "llvm/ADT/StringRef.h" -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tfr/passes/tfr_opt.cc b/tensorflow/compiler/mlir/tfr/passes/tfr_opt.cc index 5c2dd3780bf798..dab7ee6fa72cac 100644 --- a/tensorflow/compiler/mlir/tfr/passes/tfr_opt.cc +++ b/tensorflow/compiler/mlir/tfr/passes/tfr_opt.cc @@ -16,8 +16,8 @@ limitations under the License. #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/InitAllDialects.h" // from @llvm-project @@ -37,7 +37,7 @@ int main(int argc, char **argv) { mlir::DialectRegistry registry; registry.insert(); mlir::func::registerAllExtensions(registry); diff --git a/tensorflow/compiler/mlir/tfr/resources/composite_ops.cc b/tensorflow/compiler/mlir/tfr/resources/composite_ops.cc index 8120625bc89e27..3523f295ee8291 100644 --- a/tensorflow/compiler/mlir/tfr/resources/composite_ops.cc +++ b/tensorflow/compiler/mlir/tfr/resources/composite_ops.cc @@ -13,10 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tfr/resources/test_ops.cc b/tensorflow/compiler/mlir/tfr/resources/test_ops.cc index 3aaa0850805030..c9dcfd26104e86 100644 --- a/tensorflow/compiler/mlir/tfr/resources/test_ops.cc +++ b/tensorflow/compiler/mlir/tfr/resources/test_ops.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD index 5c1519e43c2bcb..ff635198f2071d 100644 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/BUILD @@ -228,6 +228,7 @@ cc_library( ":cost_analysis", ":fallback_converter", ":tensor_array_side_effect_analysis", + ":tfrt_compile_options", ":tfrt_pipeline_options", ":tpu_passes", ":transform_utils", @@ -248,8 +249,12 @@ cc_library( "//tensorflow/compiler/mlir/tfrt/ir:tfrt_gpu_opdefs", "//tensorflow/compiler/tf2xla:tf2xla_defs", "//tensorflow/core:framework", + "//tensorflow/core/ir/types:Dialect", + "//tensorflow/core/platform:errors", "//tensorflow/core/platform:status", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@llvm-project//llvm:Support", @@ -257,10 +262,12 @@ cc_library( "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Rewrite", "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", + "@local_tsl//tsl/platform:errors", "@tf_runtime//:basic_kernels_opdefs", "@tf_runtime//:core_runtime_opdefs", "@tf_runtime//:stream_analysis", @@ -504,7 +511,9 @@ cc_library( deps = [ "//tensorflow/compiler/mlir/tensorflow", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", ], ) @@ -658,7 +667,9 @@ cc_library( deps = [ "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "//tensorflow/compiler/mlir/tensorflow/ir/host_runtime:tensorflow_tfrt_ops", "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", diff --git a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc index 407871ecbdcd94..478b6d156dee4c 100644 --- a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc +++ b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc @@ -31,7 +31,6 @@ limitations under the License. #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/mlir/tfrt/tests/ir/BUILD b/tensorflow/compiler/mlir/tfrt/tests/ir/BUILD index 44fc2c0f6945b4..d68c5c3f6a68b3 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/ir/BUILD +++ b/tensorflow/compiler/mlir/tfrt/tests/ir/BUILD @@ -40,6 +40,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/platform:resource_loader", + "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@tf_runtime//:init_tfrt_dialects", ], diff --git a/tensorflow/compiler/mlir/tfrt/tests/ir/tfrt_fallback_util_test.cc b/tensorflow/compiler/mlir/tfrt/tests/ir/tfrt_fallback_util_test.cc index bfa9c148174e9c..ddfb073fa2c699 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/ir/tfrt_fallback_util_test.cc +++ b/tensorflow/compiler/mlir/tfrt/tests/ir/tfrt_fallback_util_test.cc @@ -18,6 +18,9 @@ limitations under the License. #include #include +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h" #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_sync.h" diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD index d3896c65d63f21..ee6178b16f9f8c 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD @@ -1,4 +1,8 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load( + "@local_tsl//tsl/platform:build_config.bzl", + "tf_proto_library", +) load("//tensorflow:tensorflow.bzl", "if_google", "tf_cc_test") package( @@ -125,6 +129,7 @@ cc_library( srcs = ["tf2hlo.cc"], hdrs = ["tf2hlo.h"], deps = [ + ":ifrt_compilation_proto_cc", ":ifrt_constants", ":ifrt_types", "//tensorflow/compiler/jit:xla_cpu_jit", @@ -151,16 +156,19 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", - "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:fingerprint", "@local_tsl//tsl/platform:statusor", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:client_library", + "@local_xla//xla/client:compile_only_client", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", + "@local_xla//xla/pjrt:pjrt_compiler", "@local_xla//xla/python/ifrt", "@local_xla//xla/service:computation_placer_hdr", - "@local_xla//xla/service/llvm_ir:llvm_util", + "@local_xla//xla/service:hlo_proto_cc", "@local_xla//xla/stream_executor", - "@local_xla//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", + "@local_xla//xla/tsl/concurrency:ref_count", ], ) @@ -204,16 +212,25 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core/platform:resource_loader", "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", + "@local_xla//xla/pjrt:pjrt_compiler", + "@local_xla//xla/pjrt/cpu:cpu_client", + "@local_xla//xla/pjrt/gpu:se_gpu_pjrt_client", "@local_xla//xla/python/ifrt", + "@local_xla//xla/python/ifrt:mock", "@local_xla//xla/python/ifrt:test_util", + "@local_xla//xla/python/pjrt_ifrt", "@local_xla//xla/python/pjrt_ifrt:tfrt_cpu_client_test_lib", + "@local_xla//xla/service:computation_placer_hdr", ], ) @@ -263,16 +280,21 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core/platform:resource_loader", "//tensorflow/core/tfrt/graph_executor:graph_execution_options", + "//tensorflow/core/tfrt/ifrt:ifrt_executable_registry", "//tensorflow/core/tfrt/ifrt:ifrt_model_context", "//tensorflow/core/tfrt/ifrt:ifrt_serving_core_selector", "//tensorflow/core/tfrt/runtime", "//tensorflow/core/tfrt/saved_model:saved_model_testutil", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_xla//xla/python/ifrt", "@local_xla//xla/python/ifrt:test_util", @@ -281,3 +303,14 @@ tf_cc_test( "@tf_runtime//:hostcontext", ], ) + +tf_proto_library( + name = "ifrt_compilation_proto", + srcs = ["ifrt_compilation.proto"], + protodeps = [ + "//tensorflow/compiler/tf2xla:host_compute_metadata_proto", + "@local_xla//xla/service:hlo_proto", + "//tensorflow/core:protos_all", + "//tensorflow/core/protobuf/tpu:compile_metadata_proto", + ], +) diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc index 4c77fc0d42e4e1..865e2eb1fe865c 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.cc @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Verifier.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project @@ -87,6 +88,13 @@ CompileAndRegisterIfrtPrograms(absl::string_view model_name, func.setPublic(); } }); + // Remove the program id attribute from the submodule because they are not + // needed and will prevent us generating consistent cache key. + // program id is already in ifrt_call op's attribute and that part is not + // touched here. + submodule->get()->walk([](mlir::func::FuncOp func) { + func->removeAttr("tfrt_ifrt_serving.program_id"); + }); TF_ASSIGN_OR_RETURN( auto executable, @@ -100,7 +108,8 @@ CompileAndRegisterIfrtPrograms(absl::string_view model_name, ifrt_model_context.GetDeviceMgr(), ifrt_model_context.GetShapeRepresentationFn(), ifrt_model_context.GetIfrtServingCoreSelector(), - ifrt_model_context.GetCompilationEnvironmentProto())); + ifrt_model_context.GetCompilationEnvironmentProto(), + ifrt_model_context.GetPersistentCompilationCache())); // Register the Ifrt program to `ServingExecutableRegistry` so that // the client TF program can invoke them via `IfrtCall` op. @@ -148,15 +157,27 @@ absl::Status IfrtBackendCompiler::CompileTensorflow( "Failed to find model context for ifrt serving."); } + if ((*ifrt_model_context)->IsFrozen()) { + return absl::FailedPreconditionError( + "Cannot compile IFRT programs after the model is frozen. Please make " + "sure warmup covers all signatures by following go/tf-model-warmup."); + } + mlir::StatusScopedDiagnosticHandler diag_handler(module->getContext()); if (VLOG_IS_ON(1)) { tensorflow::DumpMlirOpToFile("ifrt_tpu_bct_conversion_before", module); } + TfrtTpuCompileOptions options; + options.disable_set_default_tpu_device_and_device_assignment_attributes = + compile_options_ + .disable_set_default_tpu_device_and_device_assignment_attributes; + options.support_multi_dims_sharding = true; + if (tpu_compiler_ != nullptr) { // Run backward compat pass so that we can use bridge to do clustering. if (mlir::failed( - tpu_compiler_->RunTPUBackwardCompatConversion(module, {}))) { + tpu_compiler_->RunTPUBackwardCompatConversion(module, options))) { return diag_handler.Combine( absl::InternalError("Failed to handle legacy TPU Ops")); } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h index 085c70812feaed..0dfaa081822862 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h @@ -28,9 +28,22 @@ namespace ifrt_serving { // Implements the custom backend compiler for IFRT based serving in TFRT. class IfrtBackendCompiler : public tensorflow::BackendCompiler { public: + struct Options { + // If true, disable running TFRTSetTPUDeviceAttrPass which set the default + // `tf.device` and `device_assignment` attributes. + // This is a server-level option for now. We can consider to make it a + // per-model option in the future. + bool disable_set_default_tpu_device_and_device_assignment_attributes = true; + }; + explicit IfrtBackendCompiler(TpuCompiler* tpu_compiler = nullptr) : tpu_compiler_(tpu_compiler) {} + explicit IfrtBackendCompiler(const Options& ifrt_backend_compile_options, + TpuCompiler* tpu_compiler = nullptr) + : tpu_compiler_(tpu_compiler), + compile_options_(ifrt_backend_compile_options) {} + void GetDependentDialects(mlir::DialectRegistry& registry) const override { if (tpu_compiler_) { tpu_compiler_->RegisterTPUDialects(®istry); @@ -45,6 +58,7 @@ class IfrtBackendCompiler : public tensorflow::BackendCompiler { private: TpuCompiler* tpu_compiler_; // Not owned. + Options compile_options_; }; } // namespace ifrt_serving diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc index 3b190d326ce58f..a9259beab0a9a9 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc @@ -16,11 +16,16 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler.h" #include +#include #include +#include #include +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/DialectRegistry.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -35,18 +40,19 @@ limitations under the License. #include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_executable_registry.h" #include "tensorflow/core/tfrt/ifrt/ifrt_model_context.h" #include "tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h" #include "tensorflow/core/tfrt/runtime/runtime.h" #include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h" #include "tsl/platform/env.h" +#include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/threadpool.h" #include "tfrt/host_context/resource_context.h" // from @tf_runtime namespace tensorflow { namespace ifrt_serving { -namespace { tsl::thread::ThreadPool& GetThreadPool() { constexpr int kMaxParallelism = 16; @@ -56,47 +62,119 @@ tsl::thread::ThreadPool& GetThreadPool() { return *thread_pool; } -TEST(IfrtBackendCompilerTest, Basic) { +class IfrtBackendCompilerTest : public ::testing::Test { + protected: + void SetUp() override { + mlir::registerAllDialects(registry_); + mlir::RegisterAllTensorFlowDialects(registry_); + context_.appendDialectRegistry(registry_); + + // Create contexts required for the compiler execution. + TF_ASSERT_OK_AND_ASSIGN(client_, xla::ifrt::test_util::GetClient()); + + core_selector_ = std::make_unique( + &mock_serving_device_selector_, client_->addressable_device_count()); + + runtime_context_.resource_context().CreateResource( + "IfrtModelContext", client_, core_selector_.get(), &GetThreadPool(), + /*compilation_environment_proto=*/nullptr); + } + + void verifyModules() { + absl::MutexLock l(&ServingExecutableRegistry::mu_); + for (const auto& [_, executable] : + *ServingExecutableRegistry::executables_) { + absl::MutexLock l(&executable->mutex_); + executable->module_->walk([](mlir::func::FuncOp func) { + ASSERT_FALSE(func->hasAttr("tfrt_ifrt_serving.program_id")); + }); + } + } + + mlir::DialectRegistry registry_; + mlir::MLIRContext context_; + std::shared_ptr client_; + + std::unique_ptr runtime_ = + tensorflow::tfrt_stub::DefaultTfrtRuntime(/*num_threads=*/1); + tensorflow::tfrt_stub::GraphExecutionOptions graph_execution_options_ = + tensorflow::tfrt_stub::GraphExecutionOptions(runtime_.get()); + tfrt::ResourceContext resource_context_; + tensorflow::tfrt_stub::ModelRuntimeContext runtime_context_ = + tensorflow::tfrt_stub::ModelRuntimeContext( + &graph_execution_options_, /*export_dir=*/"", &resource_context_); + + tsl::test_util::MockServingDeviceSelector mock_serving_device_selector_; + std::unique_ptr core_selector_; + IfrtBackendCompiler compiler_; +}; + +namespace { +using ::testing::HasSubstr; +using ::tsl::testing::StatusIs; + +struct IfrtBackendCompilerTestParams { + std::string mlir_file_name; +}; + +class IfrtBackendCompilerParameterizedTest + : public IfrtBackendCompilerTest, + public ::testing::WithParamInterface {}; + +TEST_P(IfrtBackendCompilerParameterizedTest, CompilesOk) { // Create test input module constexpr absl::string_view kDataDirectory = "tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata"; std::string mlir_module_path = tensorflow::GetDataDependencyFilepath( - absl::StrCat(kDataDirectory, "/ifrt_cluster.mlir")); + absl::StrCat(kDataDirectory, "/", GetParam().mlir_file_name)); + mlir::OwningOpRef mlir_module = + mlir::parseSourceFile(mlir_module_path, &context_); + + ASSERT_TRUE(mlir_module); + ASSERT_TRUE(mlir_module.get() != nullptr); - mlir::DialectRegistry registry; - mlir::registerAllDialects(registry); - mlir::RegisterAllTensorFlowDialects(registry); + TF_ASSERT_OK( + compiler_.CompileTensorflow(runtime_context_, mlir_module.get())); + verifyModules(); +} - mlir::MLIRContext context(registry); +INSTANTIATE_TEST_SUITE_P(IfrtBackendCompilerParameterizedTest, + IfrtBackendCompilerParameterizedTest, + ::testing::ValuesIn({ + {.mlir_file_name = "ifrt_cluster.mlir"}, + {.mlir_file_name = "restore_with_reference.mlir"}, + })); +TEST_F(IfrtBackendCompilerTest, CompileShallFailAfterModelIsFrozen) { + // Create test input module + constexpr absl::string_view kDataDirectory = + "tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata"; + std::string mlir_module_path = tensorflow::GetDataDependencyFilepath( + absl::StrCat(kDataDirectory, "/ifrt_cluster.mlir")); mlir::OwningOpRef mlir_module = - mlir::parseSourceFile(mlir_module_path, &context); + mlir::parseSourceFile(mlir_module_path, &context_); ASSERT_TRUE(mlir_module); ASSERT_TRUE(mlir_module.get() != nullptr); - // Create contexts required for the compiler execution. - TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, - xla::ifrt::test_util::GetClient()); + TF_ASSERT_OK( + compiler_.CompileTensorflow(runtime_context_, mlir_module.get())); - std::unique_ptr runtime = - tensorflow::tfrt_stub::DefaultTfrtRuntime(/*num_threads=*/1); - tensorflow::tfrt_stub::GraphExecutionOptions graph_execution_options( - runtime.get()); - tfrt::ResourceContext resource_context; - tensorflow::tfrt_stub::ModelRuntimeContext runtime_context( - &graph_execution_options, /*export_dir=*/"", &resource_context); - - tsl::test_util::MockServingDeviceSelector mock_serving_device_selector; - IfrtServingCoreSelector core_selector(&mock_serving_device_selector, - client->addressable_device_count()); - - runtime_context.resource_context().CreateResource( - "IfrtModelContext", client, &core_selector, &GetThreadPool(), - /*compilation_environment_proto=*/nullptr); - - IfrtBackendCompiler compiler; - TF_ASSERT_OK(compiler.CompileTensorflow(runtime_context, mlir_module.get())); + std::optional ifrt_model_context = + runtime_context_.resource_context().GetResource( + "IfrtModelContext"); + ASSERT_TRUE(ifrt_model_context.has_value()); + + TF_ASSERT_OK((*ifrt_model_context)->Freeze()); + + mlir::OwningOpRef another_mlir_module = + mlir::parseSourceFile(mlir_module_path, &context_); + + EXPECT_THAT( + compiler_.CompileTensorflow(runtime_context_, another_mlir_module.get()), + StatusIs( + absl::StatusCode::kFailedPrecondition, + HasSubstr("Cannot compile IFRT programs after the model is frozen"))); } } // namespace diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_compilation.proto b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_compilation.proto new file mode 100644 index 00000000000000..7a97feca8f0f32 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_compilation.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +package tensorflow.ifrt_serving; + +import "tensorflow/compiler/tf2xla/host_compute_metadata.proto"; +import "xla/service/hlo.proto"; +import "tensorflow/core/framework/types.proto"; +import "tensorflow/core/protobuf/tpu/compile_metadata.proto"; + +message Tf2HLOResultProto { + xla.HloModuleProto hlo_module_proto = 1; + tensorflow.tpu.TPUCompileMetadataProto compile_metadata = 2; + tensorflow.tf2xla.HostComputeMetadata host_compute_metadata = 3; +} diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/restore_with_reference.mlir b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/restore_with_reference.mlir new file mode 100644 index 00000000000000..cb9e459a2277e0 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/restore_with_reference.mlir @@ -0,0 +1,14 @@ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}, tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() <{initializers = [@"save/restore_all_1"]}> : () -> () + func.func @"save/restore_all_1"() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_save/restore_all_1"], tf_saved_model.initializer_type = "restore_op"} { + %cst = "tf.Const"() <{value = dense<"restore_ariables"> : tensor}> : () -> tensor + %cst_0 = "tf.Const"() <{value = dense<["", ""]> : tensor<2x!tf_type.string>}> : () -> tensor<2x!tf_type.string> + %cst_1 = "tf.Const"() <{value = dense<["y", "z"]> : tensor<2x!tf_type.string>}> : () -> tensor<2x!tf_type.string> + %0:2 = "tf.RestoreV2"(%cst, %cst_1, %cst_0): (tensor, tensor<2x!tf_type.string>, tensor<2x!tf_type.string>) -> (tensor, tensor<1x3xf32>) + %1 = "tf.VariableV2"() <{container = "x", shape = #tf_type.shape<>, shared_name = "y"}> : () -> tensor + %dummy = "tf.Assign"(%1, %0#0) : (tensor, tensor) -> tensor + %2 = "tf.VarHandleOp"() <{container = "x", shared_name = "z"}> : () -> tensor>> + "tf.AssignVariableOp"(%2, %0#1) : (tensor>>, tensor<1x3xf32>) -> () + return + } +} diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/tf2hlo_gpu.mlir b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/tf2hlo_gpu.mlir new file mode 100644 index 00000000000000..8dfb7f9cca3dbe --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata/tf2hlo_gpu.mlir @@ -0,0 +1,25 @@ +// MLIR which is legalize only with the right device type. +// The mlir is generated by running +// tensorflow/compiler/mlir/tf-opt -tf-xla-call-module-serialization +// +// module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, +// producer = 268 : i32}} { +// func.func private @_jit_sin(%arg0: tensor) -> tensor { +// %0 = stablehlo.sine %arg0 : tensor +// return %0 : tensor +// } +// func.func @main(%arg0: tensor) -> tensor<*xf32> { +// %0 = "tf.XlaCallModule"(%arg0) {Sout = [#tf_type.shape<*>], device = +// "", dim_args_spec = [], _entry_function = @_jit_sin, module = "", +// platforms = ["CUDA"], version = 6 : i64} : (tensor) -> +// tensor<*xf32> +// func.return %0 : tensor<*xf32> +// } +// } +// +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { + func.func @main(%arg0: tensor) -> tensor<*xf32> attributes {__tpu_compile_metadata_text = "args { dtype: DT_FLOAT kind: PARAMETER } retvals { } num_replicas: 1 num_cores_per_replica: 1"} { + %0 = "tf.XlaCallModule"(%arg0) {Sout = [#tf_type.shape<*>], device = "", dim_args_spec = [], module = "ML\EFR\03MLIRxxx-trunk\00\01\17\05\01\05\01\03\05\03\07\07\t\0B\03K5\07\01\1B\07\0B\13\0B3\0B\0B\0B\0B\0F\0B\13\0B\03\1B\0F\1B\0B\0B\0B\0B\0B\0F\13\0B\0B\0B\0B\03\07\0F\17\07\02\A7\1F\05\0D\03\03\03\07\05\0F\03\0B\0B\1B\0D'\0F)\031\113\05\11\05\13\05\15\05\17\1D\15\17\05\19\17\19\EF\01\05\1B\03\03\1D\0D\05\1F!#%\1D\1D\1D\1F\1D!\1D##\03\03\03+\0D\03-/\1D%\1D'\1D)\1D+)\01\05\11\03\01\03\01\t\04A\05\01\11\01\05\07\03\01\05\03\11\01\t\05\03\05\0B\03\01\01\05\06\13\03\01\03\01\07\04\01\03\03\06\03\01\05\01\00\9A\04-\0F\0B\03!\1B\1D\05\1B\83/\1F\15\1D\15\11\13\15\11\11\0F\0B\11builtin\00vhlo\00module\00func_v1\00sine_v1\00return_v1\00sym_name\00jit_sin\00arg_attrs\00function_type\00res_attrs\00sym_visibility\00jit(sin)/jit(main)/sin\00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\00jax.arg_info\00x\00mhlo.sharding\00{replicated}\00jax.result_info\00\00main\00public\00", platforms = ["CUDA"], version = 6 : i64} : (tensor) -> tensor<*xf32> + func.return %0 : tensor<*xf32> + } +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc index 9156e41928acc7..2db91568dec130 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h" +#include #include #include -#include #include #include "absl/log/check.h" @@ -28,6 +28,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -39,18 +40,20 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_compilation.pb.h" #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_constants.h" #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" #include "tensorflow/compiler/tf2xla/layout_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/client/client_library.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" +#include "xla/pjrt/pjrt_compiler.h" #include "xla/python/ifrt/client.h" #include "xla/service/computation_placer.h" -#include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape.h" #include "xla/stream_executor/platform_manager.h" -#include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -59,15 +62,69 @@ limitations under the License. #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/protobuf/tpu/topology.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" -#include "tsl/platform/errors.h" +#include "tsl/platform/fingerprint.h" #include "tsl/platform/statusor.h" namespace tensorflow { namespace ifrt_serving { namespace { static constexpr absl::string_view kEntryFuncName = "main"; +uint64_t MlirModuleFingerprint(mlir::ModuleOp module) { + std::string s; + llvm::raw_string_ostream os(s); + mlir::OpPrintingFlags flags; + flags.enableDebugInfo(false); + module.print(os, flags); + return tsl::Fingerprint64(os.str()); +} } // namespace +absl::StatusOr Tf2HloArg::Key() { + uint64_t fingerprint = tsl::Fingerprint64(platform_name); + if (topology) { + TF_ASSIGN_OR_RETURN(std::string serialized_topology, topology->Serialize()); + fingerprint = tsl::Fingerprint64(serialized_topology); + } + if (platform_name != xla::CudaName() && !topology) { + return absl::FailedPreconditionError( + "Topology is required for non-GPU compilation."); + } + fingerprint = + tsl::FingerprintCat64(fingerprint, MlirModuleFingerprint(module)); + for (const auto& dtype_and_shape : input_dtypes_and_shapes) { + fingerprint = tsl::FingerprintCat64( + fingerprint, + tsl::Fingerprint64(tensorflow::DataType_Name(dtype_and_shape.dtype))); + + std::string serialized_shape; + if (!tsl::SerializeToStringDeterministic(dtype_and_shape.shape.AsProto(), + &serialized_shape)) { + return absl::InternalError("Failed to serialize shape"); + } + + fingerprint = tsl::FingerprintCat64(fingerprint, + tsl::Fingerprint64(serialized_shape)); + } + fingerprint = tsl::FingerprintCat64(fingerprint, + tsl::Fingerprint64(entry_function_name)); + std::string serialized_compile_metadata; + if (!tsl::SerializeToStringDeterministic(compile_metadata, + &serialized_compile_metadata)) { + return absl::InternalError("Failed to serialize compile metadata"); + } + fingerprint = tsl::FingerprintCat64( + fingerprint, tsl::Fingerprint64(serialized_compile_metadata)); + return absl::StrCat(absl::Hex(fingerprint)); +} + +Tf2HLOResultProto Tf2HloResult::ToProto() const { + Tf2HLOResultProto proto; + *proto.mutable_hlo_module_proto() = hlo_module_proto; + *proto.mutable_compile_metadata() = compile_metadata; + *proto.mutable_host_compute_metadata() = host_compute_metadata; + return proto; +} + absl::Status UpdateCompileMetadata( tensorflow::tpu::TPUCompileMetadataProto& metadata, absl::Span inputs) { @@ -147,17 +204,21 @@ absl::StatusOr GetCompileMetadata( return metadata; } -absl::StatusOr CompileTfToHlo( - mlir::ModuleOp module, absl::Span inputs, - absl::string_view entry_function_name, const xla::ifrt::Client& ifrt_client, - const tensorflow::tpu::TPUCompileMetadataProto& compile_metadata, - tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn) { +absl::StatusOr CompileTfToHlo(const Tf2HloArg& arg) { if (VLOG_IS_ON(1)) { - tensorflow::DumpMlirOpToFile("ifrt_before_bridge_phase2", module); + tensorflow::DumpMlirOpToFile("ifrt_before_bridge_phase2", arg.module); + } + + // Device_type is a string of + // tensorflow/compiler/mlir/tf2xla/api/v2/device_type.proto:DeviceType + std::string device_type = "XLA_TPU_JIT"; + if (arg.platform_name == xla::CudaName()) { + device_type = "XLA_GPU_JIT"; } + VLOG(1) << "device_type: " << device_type; tpu::MlirToHloArgs mlir_to_hlo_args; - std::string module_str = tensorflow::SerializeMlirModule(module); + std::string module_str = tensorflow::SerializeMlirModule(arg.module); mlir_to_hlo_args.mlir_module = module_str; // Use fallback bridge as other modes may get deprecated. mlir_to_hlo_args.rollout_state = @@ -171,7 +232,7 @@ absl::StatusOr CompileTfToHlo( std::vector arg_shapes; - for (const auto& input : inputs) { + for (const auto& input : arg.input_dtypes_and_shapes) { arg_shapes.push_back(input.shape); } @@ -183,11 +244,12 @@ absl::StatusOr CompileTfToHlo( TF_ASSIGN_OR_RETURN( tensorflow::XlaCompiler::CompilationResult compilation_result, tensorflow::tf2xla::v2::LegalizeMlirToHlo( - mlir_to_hlo_args, compile_metadata, use_tuple_args, - /*device_type=*/"XLA_TPU_JIT", custom_legalization_passes, + mlir_to_hlo_args, arg.compile_metadata, use_tuple_args, device_type, + custom_legalization_passes, /*shape_determination_fns=*/ tensorflow::XlaShapeLayoutHelpers::ShapeDeterminationFns( - tensorflow::UseNoPreferenceLayoutFn(), shape_representation_fn), + tensorflow::UseNoPreferenceLayoutFn(), + arg.shape_representation_fn), arg_shapes, &arg_core_mapping, &per_core_arg_shapes, client)); for (auto arg_shapes_iter = per_core_arg_shapes.begin() + 1; @@ -200,18 +262,10 @@ absl::StatusOr CompileTfToHlo( } Tf2HloResult result; - result.mlir_hlo_module = xla::llvm_ir::CreateMlirModuleOp(module->getLoc()); - result.compile_metadata = std::move(compile_metadata); + result.hlo_module_proto = compilation_result.computation->proto(); + result.compile_metadata = arg.compile_metadata; result.host_compute_metadata = compilation_result.host_compute_metadata; - TF_RETURN_IF_ERROR(xla::ConvertHloToMlirHlo( - *result.mlir_hlo_module, &compilation_result.computation->proto())); - - if (VLOG_IS_ON(1)) { - tensorflow::DumpMlirOpToFile("ifrt_after_bridge_phase2", - result.mlir_hlo_module.get()); - } - return result; } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h index 48d7cabdd14286..16ac14cb340289 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h @@ -16,24 +16,47 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_TF2HLO_H_ #define TENSORFLOW_COMPILER_MLIR_TFRT_TRANSFORMS_IFRT_TF2HLO_H_ +#include +#include + +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_compilation.pb.h" #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "xla/client/compile_only_client.h" #include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/topology.h" +#include "xla/service/hlo.pb.h" +#include "xla/tsl/concurrency/ref_count.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" namespace tensorflow { namespace ifrt_serving { +struct Tf2HloArg { + mlir::ModuleOp module; + absl::Span input_dtypes_and_shapes; + absl::string_view entry_function_name; + tensorflow::tpu::TPUCompileMetadataProto compile_metadata; + tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn; + std::shared_ptr topology; + absl::string_view platform_name; + + absl::StatusOr Key(); +}; + struct Tf2HloResult { - mlir::OwningOpRef mlir_hlo_module; + xla::HloModuleProto hlo_module_proto; tensorflow::tpu::TPUCompileMetadataProto compile_metadata; tf2xla::HostComputeMetadata host_compute_metadata; + Tf2HLOResultProto ToProto() const; }; absl::Status UpdateCompileMetadata( @@ -44,12 +67,7 @@ absl::StatusOr GetCompileMetadata( mlir::ModuleOp module, const xla::ifrt::Client& ifrt_client); // A class that convert tf module to hlo -// TODO(b/304839793): provide wrap persistent compilation cache. -absl::StatusOr CompileTfToHlo( - mlir::ModuleOp module, absl::Span inputs, - absl::string_view entry_function_name, const xla::ifrt::Client& ifrt_client, - const tensorflow::tpu::TPUCompileMetadataProto& compile_metadata, - tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn); +absl::StatusOr CompileTfToHlo(const Tf2HloArg& arg); } // namespace ifrt_serving } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc index 1de61abd9e9385..d2c371ffa5109b 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc @@ -23,8 +23,10 @@ limitations under the License. #include #include #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "mlir/IR/AsmState.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/DialectRegistry.h" // from @llvm-project @@ -35,17 +37,27 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "xla/pjrt/cpu/cpu_client.h" +#include "xla/pjrt/gpu/se_gpu_pjrt_client.h" +#include "xla/pjrt/pjrt_compiler.h" #include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/mock.h" #include "xla/python/ifrt/test_util.h" +#include "xla/python/pjrt_ifrt/pjrt_topology.h" +#include "xla/service/computation_placer.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/test.h" #include "tsl/platform/protobuf.h" +#include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" namespace tensorflow { namespace ifrt_serving { namespace { +using ::testing::HasSubstr; +using ::testing::status::IsOkAndHolds; +using tsl::testing::StatusIs; // TODO(b/229726259): Make EqualsProto available in OSS class ProtoStringMatcher { @@ -100,9 +112,24 @@ TEST(Tf2HloTest, Empty) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, {})); - auto result = - CompileTfToHlo(mlir_module.get(), {}, "main", *client, compile_metadata, - tensorflow::IdentityShapeRepresentationFn()); + xla::TfrtCpuTopologyDescription cpu_topology = + xla::TfrtCpuTopologyDescription::Create( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*devices=*/std::vector>{}, + /*machine_attributes=*/std::vector{}); + std::shared_ptr cpu_topology_ptr = + std::make_shared(cpu_topology); + + Tf2HloArg arg{ + .module = mlir_module.get(), + .input_dtypes_and_shapes = {}, + .entry_function_name = "main", + .compile_metadata = compile_metadata, + .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(), + .topology = std::make_shared(cpu_topology_ptr), + .platform_name = xla::CpuName(), + }; + auto result = CompileTfToHlo(arg); TF_ASSERT_OK(result.status()); } @@ -139,9 +166,25 @@ TEST(Tf2HloTest, Tuple) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - auto result = CompileTfToHlo(mlir_module.get(), dtype_and_shapes, "main", - *client, compile_metadata, - tensorflow::IdentityShapeRepresentationFn()); + xla::TfrtCpuTopologyDescription cpu_topology = + xla::TfrtCpuTopologyDescription::Create( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*devices=*/std::vector>{}, + /*machine_attributes=*/std::vector{}); + std::shared_ptr cpu_topology_ptr = + std::make_shared(cpu_topology); + + Tf2HloArg arg{ + .module = mlir_module.get(), + .input_dtypes_and_shapes = dtype_and_shapes, + .entry_function_name = "main", + .compile_metadata = compile_metadata, + .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(), + .topology = std::make_shared(cpu_topology_ptr), + .platform_name = xla::CpuName(), + }; + + auto result = CompileTfToHlo(arg); TF_ASSERT_OK(result.status()); } @@ -177,9 +220,25 @@ TEST(Tf2HloTest, Spmd) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - auto result = CompileTfToHlo(mlir_module.get(), dtype_and_shapes, "main", - *client, compile_metadata, - tensorflow::IdentityShapeRepresentationFn()); + xla::TfrtCpuTopologyDescription cpu_topology = + xla::TfrtCpuTopologyDescription::Create( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*devices=*/std::vector>{}, + /*machine_attributes=*/std::vector{}); + std::shared_ptr cpu_topology_ptr = + std::make_shared(cpu_topology); + + Tf2HloArg arg{ + .module = mlir_module.get(), + .input_dtypes_and_shapes = dtype_and_shapes, + .entry_function_name = "main", + .compile_metadata = compile_metadata, + .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(), + .topology = std::make_shared(cpu_topology_ptr), + .platform_name = xla::CpuName(), + }; + + auto result = CompileTfToHlo(arg); LOG(INFO) << result->compile_metadata; TF_ASSERT_OK(result.status()); @@ -253,9 +312,25 @@ TEST(Tf2HloTest, UsingDefaultDeviceAssignment) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - auto result = CompileTfToHlo(mlir_module.get(), dtype_and_shapes, "main", - *client, compile_metadata, - tensorflow::IdentityShapeRepresentationFn()); + xla::TfrtCpuTopologyDescription cpu_topology = + xla::TfrtCpuTopologyDescription::Create( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*devices=*/std::vector>{}, + /*machine_attributes=*/std::vector{}); + std::shared_ptr cpu_topology_ptr = + std::make_shared(cpu_topology); + + Tf2HloArg arg{ + .module = mlir_module.get(), + .input_dtypes_and_shapes = dtype_and_shapes, + .entry_function_name = "main", + .compile_metadata = compile_metadata, + .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(), + .topology = std::make_shared(cpu_topology_ptr), + .platform_name = xla::CpuName(), + }; + + auto result = CompileTfToHlo(arg); LOG(INFO) << result->compile_metadata; TF_ASSERT_OK(result.status()); @@ -354,9 +429,25 @@ TEST(Tf2HloTest, XlaCallHostCallback) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - auto result = CompileTfToHlo(mlir_module.get(), dtype_and_shapes, "main", - *client, compile_metadata, - tensorflow::IdentityShapeRepresentationFn()); + xla::TfrtCpuTopologyDescription cpu_topology = + xla::TfrtCpuTopologyDescription::Create( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*devices=*/std::vector>{}, + /*machine_attributes=*/std::vector{}); + std::shared_ptr cpu_topology_ptr = + std::make_shared(cpu_topology); + + Tf2HloArg arg{ + .module = mlir_module.get(), + .input_dtypes_and_shapes = dtype_and_shapes, + .entry_function_name = "main", + .compile_metadata = compile_metadata, + .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(), + .topology = std::make_shared(cpu_topology_ptr), + .platform_name = xla::CpuName(), + }; + + auto result = CompileTfToHlo(arg); TF_ASSERT_OK(result.status()); @@ -367,6 +458,189 @@ TEST(Tf2HloTest, XlaCallHostCallback) { ASSERT_EQ((*result).host_compute_metadata.host_to_device().size(), 0); } +// On GPU enabled build, the compilation should pass. On a GPU disabled build, +// the compilation should fail with a correct error message. +TEST(Tf2HloTest, GpuCompile) { + // Create test input module + constexpr absl::string_view kDataDirectory = + "tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata"; + std::string mlir_module_path = tensorflow::GetDataDependencyFilepath( + absl::StrCat(kDataDirectory, "/tf2hlo_gpu.mlir")); + + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlir::RegisterAllTensorFlowDialects(registry); + + mlir::MLIRContext context(registry); + + mlir::OwningOpRef mlir_module = + mlir::parseSourceFile(mlir_module_path, &context); + + ASSERT_TRUE(mlir_module); + ASSERT_TRUE(mlir_module.get() != nullptr); + + xla::ifrt::MockClient mock_client; + ON_CALL(mock_client, GetDefaultDeviceAssignment) + .WillByDefault([]() -> absl::StatusOr { + return xla::DeviceAssignment(1, 1); + }); + + std::vector dtype_and_shapes; + dtype_and_shapes.push_back(DtypeAndShape{DT_FLOAT, {}}); + + TF_ASSERT_OK_AND_ASSIGN( + tensorflow::tpu::TPUCompileMetadataProto compile_metadata, + GetCompileMetadata(mlir_module.get(), mock_client)); + TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); + + Tf2HloArg arg{ + .module = mlir_module.get(), + .input_dtypes_and_shapes = dtype_and_shapes, + .entry_function_name = "main", + .compile_metadata = compile_metadata, + .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(), + .topology = std::make_shared( + std::make_shared( + xla::CudaId(), xla::CudaName(), /*gpu_topology=*/nullptr)), + .platform_name = xla::CudaName(), + }; + + auto result = CompileTfToHlo(arg); +#if defined(GOOGLE_CUDA) + LOG(INFO) << "GPU compile success"; + EXPECT_OK(result); +#else + LOG(INFO) << "Non-GPU compile failure"; + EXPECT_THAT(result, StatusIs(absl::StatusCode::kUnimplemented, + HasSubstr("CUDA or ROCM build required"))); +#endif +} + +TEST(Tf2HloTest, SameArgProduceSameKeyFingerprint) { + constexpr absl::string_view kDataDirectory = + "tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata"; + std::string mlir_module_path = tensorflow::GetDataDependencyFilepath( + absl::StrCat(kDataDirectory, "/xla_call_host_callback.mlir")); + + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlir::RegisterAllTensorFlowDialects(registry); + + mlir::MLIRContext context(registry); + + mlir::OwningOpRef mlir_module = + mlir::parseSourceFile(mlir_module_path, + mlir::ParserConfig(&context)); + + ASSERT_TRUE(mlir_module); + ASSERT_TRUE(mlir_module.get() != nullptr); + + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, + xla::ifrt::test_util::GetClient()); + + std::vector dtype_and_shapes; + dtype_and_shapes.push_back(DtypeAndShape{DT_INT32, {1}}); + dtype_and_shapes.push_back(DtypeAndShape{DT_INT32, {1}}); + + TF_ASSERT_OK_AND_ASSIGN( + tensorflow::tpu::TPUCompileMetadataProto compile_metadata, + GetCompileMetadata(mlir_module.get(), *client)); + TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); + + xla::TfrtCpuTopologyDescription cpu_topology = + xla::TfrtCpuTopologyDescription::Create( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*devices=*/std::vector>{}, + /*machine_attributes=*/std::vector{}); + std::shared_ptr cpu_topology_ptr = + std::make_shared(cpu_topology); + + Tf2HloArg arg0{ + .module = mlir_module.get(), + .input_dtypes_and_shapes = dtype_and_shapes, + .entry_function_name = "main", + .compile_metadata = compile_metadata, + .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(), + .topology = std::make_shared(cpu_topology_ptr), + }; + mlir::OwningOpRef mlir_module_clone = + mlir::OwningOpRef(mlir_module->clone()); + Tf2HloArg arg1{ + .module = mlir_module_clone.get(), + .input_dtypes_and_shapes = dtype_and_shapes, + .entry_function_name = "main", + .compile_metadata = compile_metadata, + .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(), + .topology = std::make_shared(cpu_topology_ptr), + }; + + EXPECT_THAT(arg0.Key(), IsOkAndHolds(arg1.Key().value())); +} + +TEST(Tf2HloTest, DifferentCompileMetadataProduceDifferentKeyFingerprint) { + constexpr absl::string_view kDataDirectory = + "tensorflow/compiler/mlir/tfrt/transforms/ifrt/testdata"; + std::string mlir_module_path = tensorflow::GetDataDependencyFilepath( + absl::StrCat(kDataDirectory, "/xla_call_host_callback.mlir")); + + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + mlir::RegisterAllTensorFlowDialects(registry); + + mlir::MLIRContext context(registry); + + mlir::OwningOpRef mlir_module = + mlir::parseSourceFile(mlir_module_path, + mlir::ParserConfig(&context)); + + ASSERT_TRUE(mlir_module); + ASSERT_TRUE(mlir_module.get() != nullptr); + + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr client, + xla::ifrt::test_util::GetClient()); + + std::vector dtype_and_shapes; + dtype_and_shapes.push_back(DtypeAndShape{DT_INT32, {1}}); + dtype_and_shapes.push_back(DtypeAndShape{DT_INT32, {1}}); + + TF_ASSERT_OK_AND_ASSIGN( + tensorflow::tpu::TPUCompileMetadataProto compile_metadata, + GetCompileMetadata(mlir_module.get(), *client)); + TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); + + xla::TfrtCpuTopologyDescription cpu_topology = + xla::TfrtCpuTopologyDescription::Create( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*devices=*/std::vector>{}, + /*machine_attributes=*/std::vector{}); + std::shared_ptr cpu_topology_ptr = + std::make_shared(cpu_topology); + + Tf2HloArg arg0{ + .module = mlir_module.get(), + .input_dtypes_and_shapes = dtype_and_shapes, + .entry_function_name = "main", + .compile_metadata = compile_metadata, + .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(), + .topology = std::make_shared(cpu_topology_ptr), + }; + mlir::OwningOpRef mlir_module_clone = + mlir::OwningOpRef(mlir_module->clone()); + compile_metadata.set_num_replicas(11111); + Tf2HloArg arg1{ + .module = mlir_module_clone.get(), + .input_dtypes_and_shapes = dtype_and_shapes, + .entry_function_name = "main", + .compile_metadata = compile_metadata, + .shape_representation_fn = tensorflow::IdentityShapeRepresentationFn(), + .topology = std::make_shared(cpu_topology_ptr), + }; + + ASSERT_OK_AND_ASSIGN(std::string key0, arg0.Key()); + ASSERT_OK_AND_ASSIGN(std::string key1, arg1.Key()); + EXPECT_NE(key0, key1); +} + } // namespace } // namespace ifrt_serving } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc index ef9270c7b64f2c..53619c6ca4285d 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc @@ -59,6 +59,9 @@ void AddClusterToIfrtRuntimeOpsPassPipeline(OpPassManager& pm, pm.addNestedPass(CreateTfRestorePruningPass()); pm.addNestedPass(CreateTfRestoreMergingPass()); + // Convert reference variable to resource variable since + // LowerToIfrtRestoreVariablePass does not support reference variable. + pm.addPass(CreateConvertReferenceVariableToResourceVariablePass()); pm.addPass(CreateLowerToIfrtRestoreVariablePass()); pm.addPass(CreateRewriteClusterToIfrtCallPass()); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc index 749f86c0ed4199..c554f8a26490e6 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc @@ -1056,7 +1056,7 @@ class TfToMlrtConversionPass type_converter_.addSourceMaterialization( [](mlir::OpBuilder &builder, mlir::Type result_type, mlir::ValueRange inputs, - mlir::Location loc) -> std::optional { + mlir::Location loc) -> mlir::Value { return builder .create(loc, result_type, inputs) diff --git a/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc b/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc index 0e47fad312c7cc..202aa9c8d2f9ec 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc @@ -12,11 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Rewrite/FrozenRewritePatternSet.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" +#include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { namespace tfrt_compiler { diff --git a/tensorflow/compiler/mlir/tfrt/transforms/optimize_tf_control_flow_side_effect.cc b/tensorflow/compiler/mlir/tfrt/transforms/optimize_tf_control_flow_side_effect.cc index 1c227426ce99e1..c9e3c79d6b7b79 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/optimize_tf_control_flow_side_effect.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/optimize_tf_control_flow_side_effect.cc @@ -13,9 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tfrt/analysis/tensor_array_side_effect_analysis.h" #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/tfrt/transforms/passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/passes.cc index 2b97ec6a9536ac..cda19dc2157651 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/passes.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/passes.cc @@ -21,14 +21,20 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/PassOptions.h" #include "mlir/Transforms/Passes.h" +#include "absl/log/log.h" +#include "absl/status/status.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.h" #include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h" #include "tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/util/device_name_utils.h" +#include "tsl/platform/errors.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/mlir/tfrt/transforms/passes.h b/tensorflow/compiler/mlir/tfrt/transforms/passes.h index a883db5e479268..7b1f322712fd47 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/passes.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/passes.h @@ -21,7 +21,11 @@ limitations under the License. #include "llvm/Support/CommandLine.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassOptions.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" diff --git a/tensorflow/compiler/mlir/tfrt/transforms/remove_device_attribute.cc b/tensorflow/compiler/mlir/tfrt/transforms/remove_device_attribute.cc index c765e08742ae24..6905ede4f2dbca 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/remove_device_attribute.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/remove_device_attribute.cc @@ -15,13 +15,15 @@ limitations under the License. // This pass removes the device attribute from every corert.executeop. -#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "tfrt/core_runtime/opdefs/core_runtime.h" // from @tf_runtime namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tfrt/transforms/remove_tf_if_const_args.cc b/tensorflow/compiler/mlir/tfrt/transforms/remove_tf_if_const_args.cc index d855fa41344905..d60a09b14e5656 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/remove_tf_if_const_args.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/remove_tf_if_const_args.cc @@ -13,8 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/BitVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/tfrt/transforms/reorder_assert.cc b/tensorflow/compiler/mlir/tfrt/transforms/reorder_assert.cc index eb6eb9dbdaee77..69fb2a858b092d 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/reorder_assert.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/reorder_assert.cc @@ -12,8 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.cc b/tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.cc index a80d1ba7e180ef..f8343c034e0ce3 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/transforms/set_shape_invariant_in_while_ops.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tfrt/transforms/sink_in_invariant_ops.cc b/tensorflow/compiler/mlir/tfrt/transforms/sink_in_invariant_ops.cc index 8bdb39c913bf75..5645bdf16c11fe 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/sink_in_invariant_ops.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/sink_in_invariant_ops.cc @@ -22,8 +22,6 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/MapVector.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" @@ -31,17 +29,16 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/IR/DialectRegistry.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc index f090745e0ae1c4..0f71991b2f8f82 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc @@ -29,24 +29,36 @@ limitations under the License. #include "mlir/Pass/PassOptions.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h" #include "tensorflow/compiler/mlir/tfrt/analysis/tensor_array_side_effect_analysis.h" @@ -58,10 +70,15 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.h" #include "tensorflow/compiler/mlir/tfrt/transforms/gpu_passes.h" #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/tfrt_pipeline_options.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/tpu_passes.h" #include "tensorflow/compiler/mlir/tfrt/transforms/utils.h" +#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h" #include "tensorflow/compiler/tf2xla/tf2xla_defs.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/platform/errors.h" #include "tfrt/basic_kernels/opdefs/basic_kernels.h" // from @tf_runtime #include "tfrt/basic_kernels/opdefs/tfrt_base.h" // from @tf_runtime #include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime diff --git a/tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.cc b/tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.cc index 24dcb1904fa4ed..bebd279e28fadf 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.cc @@ -15,7 +15,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.h" #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.h" +#include "tensorflow/core/tfrt/fallback/cost_recorder.h" namespace tensorflow { namespace tfrt_compiler { diff --git a/tensorflow/compiler/mlir/tfrt/transforms/utils.cc b/tensorflow/compiler/mlir/tfrt/transforms/utils.cc index 711438f21d13f9..cf65e50af55abb 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/utils.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/utils.cc @@ -18,19 +18,20 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" -#include "tfrt/basic_kernels/opdefs/tfrt_base.h" // from @tf_runtime -#include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime -#include "tfrt/core_runtime/opdefs/types.h" // from @tf_runtime namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tfrt/transforms/xla_rewrite_pass.cc b/tensorflow/compiler/mlir/tfrt/transforms/xla_rewrite_pass.cc index cc28f332cb96aa..57b07c69bf2b55 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/xla_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/xla_rewrite_pass.cc @@ -16,11 +16,22 @@ limitations under the License. #include #include +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" +#include "tensorflow/core/ir/types/dialect.h" namespace tensorflow { namespace tfrt_compiler { diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index 1dcd829833b985..7e319de3d8217b 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -56,6 +56,7 @@ cc_library( "//tensorflow/compiler/mlir/tools/kernel_gen/transforms:passes", "//tensorflow/core:lib", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineToStandard", "@llvm-project//mlir:ArithDialect", @@ -211,6 +212,7 @@ cc_library( "//tensorflow/core/framework:resource_base", "//tensorflow/core/platform:status", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@llvm-project//llvm:Support", "@llvm-project//mlir:ExecutionEngine", "@local_tsl//tsl/platform:thread_annotations", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/compile_cache_item.proto b/tensorflow/compiler/mlir/tools/kernel_gen/compile_cache_item.proto index 01aac5186e2c08..cb27dbd5bfd07b 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/compile_cache_item.proto +++ b/tensorflow/compiler/mlir/tools/kernel_gen/compile_cache_item.proto @@ -2,8 +2,6 @@ syntax = "proto3"; package mlir.kernel_gen; -option cc_enable_arenas = true; - // Protocolbuffer representing a compilation input and output. This will be used // for caching JIT compiles of kernels. message CompilationCacheItem { diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD index e32d215709cc81..42b29d86d31e51 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD @@ -94,6 +94,6 @@ cc_library( "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc index 75f619fe9b940e..9a5ab0f00fd651 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc @@ -33,7 +33,7 @@ limitations under the License. // Generated dialect definitions. #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_dialect.cc.inc" -#include "tsl/protobuf/error_codes.pb.h" +#include "xla/tsl/protobuf/error_codes.pb.h" namespace mlir { namespace kernel_gen { diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h index 380fb8c52448cf..6b224e91bfb5eb 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.h @@ -24,6 +24,7 @@ limitations under the License. #include +#include "absl/status/statusor.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.cc index 51a505dc8de1e2..79884a02769785 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_jit_cache.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "llvm/Support/Error.h" #include "mlir/ExecutionEngine/ExecutionEngine.h" // from @llvm-project #include "tensorflow/core/platform/mutex.h" diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index 489e13d172c059..dee541c450dabd 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -167,7 +167,7 @@ cc_library( "@local_xla//xla/service/gpu:target_constants", "@local_xla//xla/service/gpu/llvm_gpu_backend", ] + if_cuda_is_configured([ - "@local_tsl//tsl/platform:cuda_libdevice_path", + "@local_tsl//tsl/platform:cuda_root_path", "@local_xla//xla/stream_executor/cuda:cuda_asm_compiler", ]) + if_rocm_is_configured([ "@local_xla//xla/stream_executor/gpu:asm_compiler", diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc index e4b4c6cd84dc96..d5b0ce09538dbc 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc @@ -36,7 +36,6 @@ limitations under the License. #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/statusor.h" -#include "tsl/platform/cuda_libdevice_path.h" #if GOOGLE_CUDA #include "xla/stream_executor/cuda/cuda_asm_compiler.h" diff --git a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir index 11648c9572b63c..3f3b7bcc9ef7a9 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir @@ -418,56 +418,18 @@ func.func @test_rsqrt(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_sin -// CHECK-SAME: -> tensor<10xf32> +// CHECK: %[[VAR0:.*]] = tosa.sin %arg0 func.func @test_sin(%arg0: tensor<10xf32>) -> tensor<*xf32> { - // CHECK-DAG: %[[ONE:.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1xf32>}> - // CHECK-DAG: %[[TWO:.+]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xf32>}> - // CHECK-DAG: %[[RESULT_SCALE:.+]] = "tosa.const"() <{value = dense<2.38418579E-7> : tensor<1xf32>}> - // CHECK-DAG: %[[INT_MAX:.+]] = "tosa.const"() <{value = dense<3.276700e+04> : tensor<1xf32>}> - // CHECK-DAG: %[[IN_SCALE:.+]] = "tosa.const"() <{value = dense<0.159154937> : tensor<1xf32>}> - // CHECK-DAG: %[[TBLVAL:.+]] = "tosa.const"() <{value = dense<{{.+}} : tensor<513xi16>}> - // CHECK-DAG: %[[IN_SCALED:.+]] = tosa.mul %arg0, %[[IN_SCALE]] - // CHECK-DAG: %[[FLOOR:.+]] = tosa.floor %[[IN_SCALED]] - // CHECK-DAG: %[[SUB1:.+]] = tosa.sub %[[IN_SCALED]], %[[FLOOR]] - // CHECK-DAG: %[[MUL1:.+]] = tosa.mul %[[SUB1]], %[[TWO]] - // CHECK-DAG: %[[SUB2:.+]] = tosa.sub %[[MUL1]], %[[ONE]] - // CHECK-DAG: %[[MUL2:.+]] = tosa.mul %[[SUB2]], %[[INT_MAX]] - // CHECK-DAG: %[[TO_INT:.+]] = tosa.cast %[[MUL2]] - // CHECK-DAG: %[[TABLE:.+]] = tosa.table %[[TO_INT]], %[[TBLVAL]] - // CHECK-DAG: %[[TABLE_CAST:.+]] = tosa.cast %[[TABLE]] - // CHECK-DAG: %[[RESULT:.+]] = tosa.mul %[[TABLE_CAST:.+]], %[[RESULT_SCALE]] %0 = "tf.Sin"(%arg0) : (tensor<10xf32>) -> tensor<*xf32> - - // CHECK: return %[[RESULT]] func.return %0 : tensor<*xf32> } // ----- // CHECK-LABEL: test_cos -// CHECK-SAME: -> tensor<10xf32> +// CHECK: %[[VAR0:.*]] = tosa.cos %arg0 func.func @test_cos(%arg0: tensor<10xf32>) -> tensor<*xf32> { - // CHECK-DAG: %[[RESULT_SCALE:.+]] = "tosa.const"() <{value = dense<2.38418579E-7> : tensor<1xf32>}> - // CHECK-DAG: %[[INT_MAX:.+]] = "tosa.const"() <{value = dense<3.276700e+04> : tensor<1xf32>}> - // CHECK-DAG: %[[ONE:.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1xf32>}> - // CHECK-DAG: %[[TWO:.+]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xf32>}> - // CHECK-DAG: %[[IN_SCALE:.+]] = "tosa.const"() <{value = dense<0.159154937> : tensor<1xf32>}> - // CHECK-DAG: %[[HALF_PI:.+]] = "tosa.const"() <{value = dense<1.57079637> : tensor<1xf32>}> - // CHECK-DAG: %[[TBLVAL:.+]] = "tosa.const"() <{value = dense<{{.+}} : tensor<513xi16>}> - // CHECK-DAG: %[[IN_TRANSLATE:.+]] = tosa.add %arg0, %[[HALF_PI]] - // CHECK-DAG: %[[IN_SCALED:.+]] = tosa.mul %[[IN_TRANSLATE]], %[[IN_SCALE]] - // CHECK-DAG: %[[FLOOR:.+]] = tosa.floor %[[IN_SCALED]] - // CHECK-DAG: %[[SUB1:.+]] = tosa.sub %[[IN_SCALED]], %[[FLOOR]] - // CHECK-DAG: %[[MUL1:.+]] = tosa.mul %[[SUB1]], %[[TWO]] - // CHECK-DAG: %[[SUB2:.+]] = tosa.sub %[[MUL1]], %[[ONE]] - // CHECK-DAG: %[[MUL2:.+]] = tosa.mul %[[SUB2]], %[[INT_MAX]] - // CHECK-DAG: %[[TO_INT:.+]] = tosa.cast %[[MUL2]] - // CHECK-DAG: %[[TABLE:.+]] = tosa.table %[[TO_INT]], %[[TBLVAL]] - // CHECK-DAG: %[[TABLE_CAST:.+]] = tosa.cast %[[TABLE]] - // CHECK-DAG: %[[RESULT:.+]] = tosa.mul %[[TABLE_CAST:.+]], %[[RESULT_SCALE]] %0 = "tf.Cos"(%arg0) : (tensor<10xf32>) -> tensor<*xf32> - - // CHECK: return %[[RESULT]] func.return %0 : tensor<*xf32> } diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index 19f0c6e216c259..e12c0a9ae0b38e 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -818,56 +818,20 @@ func.func @test_sign(%arg0: tensor<21x45xi32>) -> tensor<21x45xi32> { // ----- // CHECK-LABEL: test_sin -// CHECK-SAME: -> tensor<10xf32> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<10xf32> +// CHECK: %[[VAL_1:.*]] = tosa.sin %[[VAL_0]] : (tensor<10xf32> func.func @test_sin(%arg0: tensor<10xf32>) -> tensor<*xf32> { - // CHECK-DAG: %[[ONE:.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1xf32>}> - // CHECK-DAG: %[[TWO:.+]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xf32>}> - // CHECK-DAG: %[[RESULT_SCALE:.+]] = "tosa.const"() <{value = dense<2.38418579E-7> : tensor<1xf32>}> - // CHECK-DAG: %[[INT_MAX:.+]] = "tosa.const"() <{value = dense<3.276700e+04> : tensor<1xf32>}> - // CHECK-DAG: %[[IN_SCALE:.+]] = "tosa.const"() <{value = dense<0.159154937> : tensor<1xf32>}> - // CHECK-DAG: %[[TBLVAL:.+]] = "tosa.const"() <{value = dense<{{.+}}> : tensor<513xi16>}> - // CHECK-DAG: %[[IN_SCALED:.+]] = tosa.mul %arg0, %[[IN_SCALE]] - // CHECK-DAG: %[[FLOOR:.+]] = tosa.floor %[[IN_SCALED]] - // CHECK-DAG: %[[SUB1:.+]] = tosa.sub %[[IN_SCALED]], %[[FLOOR]] - // CHECK-DAG: %[[MUL1:.+]] = tosa.mul %[[SUB1]], %[[TWO]] - // CHECK-DAG: %[[SUB2:.+]] = tosa.sub %[[MUL1]], %[[ONE]] - // CHECK-DAG: %[[MUL2:.+]] = tosa.mul %[[SUB2]], %[[INT_MAX]] - // CHECK-DAG: %[[TO_INT:.+]] = tosa.cast %[[MUL2]] - // CHECK-DAG: %[[TABLE:.+]] = tosa.table %[[TO_INT]], %[[TBLVAL]] - // CHECK-DAG: %[[TABLE_CAST:.+]] = tosa.cast %[[TABLE]] - // CHECK-DAG: %[[RESULT:.+]] = tosa.mul %[[TABLE_CAST:.+]], %[[RESULT_SCALE]] %0 = "tfl.sin"(%arg0) : (tensor<10xf32>) -> tensor<*xf32> - - // CHECK: return %[[RESULT]] func.return %0 : tensor<*xf32> } // ----- // CHECK-LABEL: test_cos -// CHECK-SAME: -> tensor<10xf32> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<10xf32> +// CHECK: %[[VAL_1:.*]] = tosa.cos %[[VAL_0]] : (tensor<10xf32> func.func @test_cos(%arg0: tensor<10xf32>) -> tensor<*xf32> { - // CHECK-DAG: %[[RESULT_SCALE:.+]] = "tosa.const"() <{value = dense<2.38418579E-7> : tensor<1xf32>}> - // CHECK-DAG: %[[INT_MAX:.+]] = "tosa.const"() <{value = dense<3.276700e+04> : tensor<1xf32>}> - // CHECK-DAG: %[[ONE:.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1xf32>}> - // CHECK-DAG: %[[TWO:.+]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xf32>}> - // CHECK-DAG: %[[IN_SCALE:.+]] = "tosa.const"() <{value = dense<0.159154937> : tensor<1xf32>}> - // CHECK-DAG: %[[HALF_PI:.+]] = "tosa.const"() <{value = dense<1.57079637> : tensor<1xf32>}> - // CHECK-DAG: %[[TBLVAL:.+]] = "tosa.const"() <{value = dense<{{.+}}> : tensor<513xi16>}> - // CHECK-DAG: %[[IN_TRANSLATE:.+]] = tosa.add %arg0, %[[HALF_PI]] - // CHECK-DAG: %[[IN_SCALED:.+]] = tosa.mul %[[IN_TRANSLATE]], %[[IN_SCALE]] - // CHECK-DAG: %[[FLOOR:.+]] = tosa.floor %[[IN_SCALED]] - // CHECK-DAG: %[[SUB1:.+]] = tosa.sub %[[IN_SCALED]], %[[FLOOR]] - // CHECK-DAG: %[[MUL1:.+]] = tosa.mul %[[SUB1]], %[[TWO]] - // CHECK-DAG: %[[SUB2:.+]] = tosa.sub %[[MUL1]], %[[ONE]] - // CHECK-DAG: %[[MUL2:.+]] = tosa.mul %[[SUB2]], %[[INT_MAX]] - // CHECK-DAG: %[[TO_INT:.+]] = tosa.cast %[[MUL2]] - // CHECK-DAG: %[[TABLE:.+]] = tosa.table %[[TO_INT]], %[[TBLVAL]] - // CHECK-DAG: %[[TABLE_CAST:.+]] = tosa.cast %[[TABLE]] - // CHECK-DAG: %[[RESULT:.+]] = tosa.mul %[[TABLE_CAST:.+]], %[[RESULT_SCALE]] %0 = "tfl.cos"(%arg0) : (tensor<10xf32>) -> tensor<*xf32> - - // CHECK: return %[[RESULT]] func.return %0 : tensor<*xf32> } diff --git a/tensorflow/compiler/mlir/tosa/transforms/dequantize_tfl_softmax.cc b/tensorflow/compiler/mlir/tosa/transforms/dequantize_tfl_softmax.cc index ba194e3e81c964..5c8dd934fe8117 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/dequantize_tfl_softmax.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/dequantize_tfl_softmax.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc index 25707c2bde1331..4724e061baf235 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc @@ -41,7 +41,7 @@ limitations under the License. #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Arith/Utils/Utils.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project @@ -4586,109 +4586,6 @@ std::optional convertOneHotOp(PatternRewriter& rewriter, Operation* op, .getResult(); } -// Lowers Sin operator to a sequence of TOSA ops. -std::optional convertSinOp(PatternRewriter& rewriter, Operation* op, - Value input, ShapedType output_type) { - RankedTensorType input_type = dyn_cast(input.getType()); - Location loc = op->getLoc(); - - Type input_ety = input_type.getElementType(); - Type output_ety = output_type.getElementType(); - - if (!input) return std::nullopt; - - if (input_ety != output_ety) { - (void)rewriter.notifyMatchFailure(op, - "input/output element type must match"); - return std::nullopt; - } - - bool input_is_fp = input_ety.isF32(); - bool output_is_fp = output_ety.isF32(); - - if (!input_is_fp || !output_is_fp) { - (void)rewriter.notifyMatchFailure(op, "input/result must be fp32"); - return std::nullopt; - } - - // To perform a sin operation we remap the sin domain to be over a single - // period of the function, remapping to the domain of the table function. - // We then remap the range of the table function to map to the range of the - // sin operation. - - // 1. Normalize the period of the domain from [0, 2π) to [0, 1). - auto fp_scalar_ty = RankedTensorType::get({}, rewriter.getF32Type()); - Value fp_scale = rewriter.create( - loc, fp_scalar_ty, - DenseElementsAttr::get(fp_scalar_ty, {static_cast(0.5 / M_PI)})); - - // 2. Remap the periodic behavior of the domain to line up within [0, 1). - Value fp_scaled = CreateOpAndInfer( - rewriter, loc, input_type, input, fp_scale, rewriter.getI8IntegerAttr(0)); - auto floored = - CreateOpAndInfer(rewriter, loc, input_type, fp_scaled); - auto repeated = CreateOpAndInfer(rewriter, loc, input_type, - fp_scaled, floored); - - // 3. Scale and translate the normalized domain to the table domain. This - // includes a translating and scaling to [-int16_max, int16_max] and casting - // to an i16. - Value one = rewriter.create( - loc, fp_scalar_ty, DenseElementsAttr::get(fp_scalar_ty, {1.0f})); - - Value two = rewriter.create( - loc, fp_scalar_ty, DenseElementsAttr::get(fp_scalar_ty, {2.0f})); - auto scale_up = CreateOpAndInfer( - rewriter, loc, input_type, repeated, two, rewriter.getI8IntegerAttr(0)); - auto translate = - CreateOpAndInfer(rewriter, loc, input_type, scale_up, one); - - Value int_limit = rewriter.create( - loc, fp_scalar_ty, - DenseElementsAttr::get( - fp_scalar_ty, - {static_cast(std::numeric_limits::max())})); - auto int_scaled = - CreateOpAndInfer(rewriter, loc, input_type, translate, - int_limit, rewriter.getI8IntegerAttr(0)); - - auto int16_ty = input_type.clone(rewriter.getIntegerType(16)); - auto casted = - CreateOpAndInfer(rewriter, loc, int16_ty, int_scaled); - - // 4. Compute the lookup table using the range of [-255, 255] for sin. - llvm::SmallVector values; - const int num_values = 513; - values.resize(num_values, 0); - // First and last values should be 0; - for (int i = 1; i < num_values - 1; ++i) - values[i] = std::numeric_limits::max() * - sin(static_cast(i) * 2.0 * M_PI / (num_values - 1.0)); - - auto table_ty = - RankedTensorType::get({num_values}, rewriter.getIntegerType(16)); - Value table = rewriter.create( - loc, table_ty, DenseElementsAttr::get(table_ty, llvm::ArrayRef(values))); - - auto table_result_ty = input_type.clone(rewriter.getIntegerType(32)); - auto table_result = CreateOpAndInfer( - rewriter, loc, table_result_ty, casted, table); - - // 5. The range of table is a 23-bit two's compliment value. Normalize the - // range by casting to an fp32 and dividing by 2^22. - auto table_result_fp = - CreateOpAndInfer(rewriter, loc, input_type, table_result); - auto output_scale = rewriter.create( - loc, fp_scalar_ty, - DenseElementsAttr::get( - fp_scalar_ty, - {static_cast(1.0 / static_cast(1 << 22))})); - - return CreateOpAndInfer(rewriter, loc, output_type, table_result_fp, - output_scale, rewriter.getI8IntegerAttr(0)) - .getResult(); -} - // Lowers Sign operator to a sequence of TOSA ops. std::optional convertSignOp(PatternRewriter& rewriter, Operation* op, Value input, RankedTensorType output_type) { diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h index 20dbfb19d44702..cfe063408edea0 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h @@ -298,10 +298,6 @@ std::optional convertOneHotOp(PatternRewriter& rewriter, Operation* op, Value on_value, Value off_value, int32_t depth, int32_t axis); -// Lowers 32-bit floating sin operator to a sequence of TOSA ops. -std::optional convertSinOp(PatternRewriter& rewriter, Operation* op, - Value input, ShapedType output_type); - // Lowers Sign operator to a sequence of TOSA ops. std::optional convertSignOp(PatternRewriter& rewriter, Operation* op, Value input, RankedTensorType output_type); diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc index 904394d370bcce..496a4275c0007b 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc @@ -25,7 +25,7 @@ limitations under the License. #include #include -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project @@ -270,43 +270,30 @@ LogicalResult ConvertTFSignOp::matchAndRewrite( LogicalResult ConvertTFSinOp::matchAndRewrite(Operation* op, PatternRewriter& rewriter) const { auto tf_sin_op = cast(op); + ShapedType output_type = - mlir::cast(tf_sin_op.getResult().getType()); + dyn_cast(tf_sin_op.getResult().getType()); + if (!output_type) + return rewriter.notifyMatchFailure(op, "output_type required"); - std::optional result = - convertSinOp(rewriter, op, tf_sin_op.getX(), output_type); - if (!result) return failure(); + CreateReplaceOpAndInfer(rewriter, op, output_type, + tf_sin_op.getX()); - rewriter.replaceOp(op, {result.value()}); return success(); } LogicalResult ConvertTFCosOp::matchAndRewrite(Operation* op, PatternRewriter& rewriter) const { auto tf_cos_op = cast(op); - Value input = tf_cos_op.getX(); - RankedTensorType input_ty = dyn_cast(input.getType()); - ShapedType output_ty = dyn_cast(tf_cos_op.getResult().getType()); - - if (!input_ty || !output_ty) return failure(); - bool input_is_fp = mlir::isa(input_ty.getElementType()); - bool output_is_fp = mlir::isa(output_ty.getElementType()); - - if (!input_is_fp || !output_is_fp) { - return rewriter.notifyMatchFailure( - op, "ConvertTFCosOp: input/result must be fp."); - } + ShapedType output_type = + dyn_cast(tf_cos_op.getResult().getType()); + if (!output_type) + return rewriter.notifyMatchFailure(op, "output_type required"); - // Replace with the equivalent sin operation: - // cos(x) = sin(x + π / 2). - auto fp_scalar_ty = RankedTensorType::get({}, rewriter.getF32Type()); - auto pi_2 = rewriter.create( - op->getLoc(), fp_scalar_ty, - DenseElementsAttr::get(fp_scalar_ty, {static_cast(M_PI_2)})); - auto offset = rewriter.create(op->getLoc(), input_ty, input, pi_2); + CreateReplaceOpAndInfer(rewriter, op, output_type, + tf_cos_op.getX()); - CreateReplaceOpAndInfer(rewriter, op, output_ty, offset); return success(); } diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc index e6e7bc98e8d613..382bd14474d8dd 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc @@ -32,7 +32,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project @@ -3305,42 +3305,29 @@ LogicalResult ConvertTFLHardSwishOp::matchAndRewrite( LogicalResult ConvertTFLSinOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tfl_sin_op = cast(op); - auto input = tfl_sin_op.getX(); + ShapedType output_type = dyn_cast(tfl_sin_op.getResult().getType()); - std::optional result = convertSinOp(rewriter, op, input, output_type); - if (!result) return failure(); + if (!output_type) + return rewriter.notifyMatchFailure(op, "output_type required"); + CreateReplaceOpAndInfer(rewriter, op, output_type, + tfl_sin_op.getX()); - rewriter.replaceOp(op, {result.value()}); return success(); } LogicalResult ConvertTFLCosOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tfl_cos_op = cast(op); - Value input = tfl_cos_op.getX(); - RankedTensorType input_ty = dyn_cast(input.getType()); - ShapedType output_ty = dyn_cast(tfl_cos_op.getResult().getType()); - - if (!input_ty || !output_ty) return failure(); - - bool input_is_fp = mlir::isa(input_ty.getElementType()); - bool output_is_fp = mlir::isa(output_ty.getElementType()); - if (!input_is_fp || !output_is_fp) { - return rewriter.notifyMatchFailure(op, "input/result must be fp"); - } - - // Replace with the equivalent sin operation: - // cos(x) = sin(x + π / 2). - auto fp_scalar_ty = RankedTensorType::get({}, rewriter.getF32Type()); - auto pi_2 = rewriter.create( - op->getLoc(), fp_scalar_ty, - DenseElementsAttr::get(fp_scalar_ty, {static_cast(M_PI_2)})); - auto offset = rewriter.create(op->getLoc(), input_ty, input, pi_2); + ShapedType output_type = + dyn_cast(tfl_cos_op.getResult().getType()); + if (!output_type) + return rewriter.notifyMatchFailure(op, "output_type required"); + CreateReplaceOpAndInfer(rewriter, op, output_type, + tfl_cos_op.getX()); - CreateReplaceOpAndInfer(rewriter, op, output_ty, offset); return success(); } diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h index acb9dff2a4a8ff..c576504d102bb6 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h @@ -26,7 +26,7 @@ limitations under the License. #include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tosa/transforms/passes.h b/tensorflow/compiler/mlir/tosa/transforms/passes.h index e41453b0b9af8b..de0872b660d4ec 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/passes.h +++ b/tensorflow/compiler/mlir/tosa/transforms/passes.h @@ -28,7 +28,7 @@ limitations under the License. namespace mlir { namespace quant { -class QuantizationDialect; +class QuantDialect; } namespace quantfork { diff --git a/tensorflow/compiler/mlir/tosa/transforms/passes.td b/tensorflow/compiler/mlir/tosa/transforms/passes.td index 3cf7749d875f9d..08f8f9634cfce3 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/passes.td +++ b/tensorflow/compiler/mlir/tosa/transforms/passes.td @@ -19,7 +19,7 @@ def TosaLegalizeTFPass : Pass<"tosa-legalize-tf", "mlir::func::FuncOp"> { let summary = "Legalize from TensorFlow to TOSA"; let constructor = "createLegalizeTFPass()"; let dependentDialects = ["TosaDialect", - "quant::QuantizationDialect","quantfork::QuantizationForkDialect" + "quant::QuantDialect","quantfork::QuantizationForkDialect" ]; } diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index e4fb5a414c614f..456fd4dd467223 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -191,6 +191,51 @@ tf_xla_combined_py_test( # go/keep-sorted end ], ) + +tf_xla_combined_py_test( + name = "combined_ops_test_e", + size = "medium", + timeout = "long", + package = "tensorflow.compiler.tests", + python_version = "PY3", + tags = [ + "no_cuda_asan", # times out in individual tests + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "noasan", # times out consistently from 2023-08-24 + ], + tests = [ + # go/keep-sorted start + ":adagrad_da_test_lib", + ":adam_test_lib", + ":argminmax_test_lib", + ":listdiff_op_test_lib", + ":slice_ops_test_lib", + ":unary_ops_test_lib", + # go/keep-sorted end + ], +) + +tf_xla_combined_py_test( + name = "combined_ops_test_f", + size = "medium", + timeout = "long", + # copybara:uncomment_begin + # #TODO(b/286470564): Remove once the bug is fixed. + # disable_tpu_tfrt = True, + # copybara:uncomment_end + package = "tensorflow.compiler.tests", + python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + tests = [ + # go/keep-sorted start + ":add_n_test_lib", + ":cond_test_lib", + ":while_test_lib", + # go/keep-sorted end + ], +) #LINT.ThenChange(:individual_tests) #LINT.IfChange(individual_tests) @@ -242,6 +287,7 @@ tf_xla_py_strict_test( tags = [ "no_cuda_asan", # times out "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "notap", ], deps = [ ":xla_test", @@ -263,6 +309,7 @@ tf_xla_py_strict_test( tags = [ "no_cuda_asan", # times out "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "notap", ], deps = [ ":xla_test", @@ -286,11 +333,10 @@ tf_xla_py_strict_test( # #TODO(b/286470564): Remove once the bug is fixed. # disable_tpu_tfrt = True, # copybara:uncomment_end - # TensorList ops are not implemented in the on-demand compilation model yet. - disabled_backends = ["cpu_ondemand"], python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "notap", ], deps = [ ":xla_test", @@ -312,6 +358,7 @@ tf_xla_py_strict_test( "no_cuda_asan", # times out "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "noasan", # times out consistently from 2023-08-24 + "notap", ], deps = [ ":xla_test", @@ -330,7 +377,6 @@ tf_xla_py_strict_test( python_version = "PY3", shard_count = 5, tags = [ - "no_oss", # TODO(b/148108508): Re-enable this test in OSS. "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "optonly", # Times out frequently in fastbuild mode. ], @@ -453,6 +499,7 @@ tf_xla_py_strict_test( python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "notap", ], deps = [ ":xla_test", @@ -587,7 +634,6 @@ tf_xla_py_strict_test( srcs = ["matrix_triangular_solve_op_test.py"], python_version = "PY3", tags = [ - "no_oss", # TODO(b/295649328): fix failed nightly tests "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "optonly", ], @@ -930,6 +976,7 @@ tf_xla_py_strict_test( tags = [ "no_cuda_asan", # times out "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "notap", ], deps = [ ":xla_test", @@ -1017,7 +1064,6 @@ tf_xla_py_strict_test( python_version = "PY3", shard_count = 10, tags = [ - "no_oss", # TODO(b/282033702): Re-enable this test in OSS. "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "optonly", # Times out frequently in fastbuild mode. ], @@ -1043,6 +1089,7 @@ tf_xla_py_strict_test( tags = [ "no_cuda_asan", # times out "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "notap", ], deps = [ ":xla_test", @@ -1796,6 +1843,7 @@ tf_xla_py_strict_test( "no_cuda_asan", # times out "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "noasan", #times out + "notap", ], deps = [ ":xla_test", @@ -1879,6 +1927,7 @@ tf_xla_py_strict_test( python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "notap", ], deps = [ ":xla_test", @@ -2492,7 +2541,6 @@ tpu_py_strict_test( name = "approx_topk_test", srcs = ["approx_topk_test.py"], disable_experimental = False, - tags = ["no_oss"], deps = [ "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:def_function", @@ -3007,7 +3055,6 @@ tpu_py_strict_test( # TODO(b/188995810): Add an optimization in MLIR importer to not # materialize giant splat constants. python_version = "PY3", - tags = ["no_oss"], deps = [ "//tensorflow/python/distribute:tpu_strategy", "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py", diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 86356d89f63b4e..b38831b8beaffc 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -76,7 +76,7 @@ cc_library( deps = if_tensorrt([ "@local_config_tensorrt//:tensorrt_headers", "//tensorflow/core:lib", - "@local_xla//xla/stream_executor/platform:dso_loader", + "@local_tsl//tsl/platform:dso_loader", ]), ) @@ -1044,7 +1044,6 @@ tf_cuda_library( ":common_utils", ":tensorrt_lib", ":op_converter_registry", - "@local_xla//xla/stream_executor/platform:dso_loader", ]), ) @@ -1112,7 +1111,6 @@ pybind_extension( # py_proto_library( # name = "trt_engine_instance_proto_py_pb2", # has_services = 0, -# api_version = 2, # deps = [":trt_engine_instance_proto"], # ) # copybara:uncomment_end diff --git a/tensorflow/compiler/tf2tensorrt/stub/nvinfer_plugin_stub.cc b/tensorflow/compiler/tf2tensorrt/stub/nvinfer_plugin_stub.cc index 8fc3c6e478f00f..3ecdc219226572 100644 --- a/tensorflow/compiler/tf2tensorrt/stub/nvinfer_plugin_stub.cc +++ b/tensorflow/compiler/tf2tensorrt/stub/nvinfer_plugin_stub.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/stream_executor/platform/dso_loader.h" #include "tensorflow/core/platform/env.h" +#include "tsl/platform/dso_loader.h" #include "third_party/tensorrt/NvInferPlugin.h" // Implements the TensorRT API by forwarding to TensorRT loaded from the DSO. @@ -25,8 +25,7 @@ void* GetDsoHandle() { return nullptr; #else static auto handle = []() -> void* { - auto handle_or = - stream_executor::internal::DsoLoader::GetNvInferPluginDsoHandle(); + auto handle_or = tsl::internal::DsoLoader::GetNvInferPluginDsoHandle(); if (!handle_or.ok()) return nullptr; return handle_or.value(); }(); diff --git a/tensorflow/compiler/tf2tensorrt/stub/nvinfer_stub.cc b/tensorflow/compiler/tf2tensorrt/stub/nvinfer_stub.cc index 1a4964032ba1c3..f352589907f6db 100644 --- a/tensorflow/compiler/tf2tensorrt/stub/nvinfer_stub.cc +++ b/tensorflow/compiler/tf2tensorrt/stub/nvinfer_stub.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/stream_executor/platform/dso_loader.h" #include "tensorflow/core/platform/env.h" +#include "tsl/platform/dso_loader.h" #include "third_party/tensorrt/NvInfer.h" // Implements the TensorRT API by forwarding to TensorRT loaded from the DSO. @@ -25,8 +25,7 @@ void* GetDsoHandle() { return nullptr; #else static auto handle = []() -> void* { - auto handle_or = - stream_executor::internal::DsoLoader::GetNvInferDsoHandle(); + auto handle_or = tsl::internal::DsoLoader::GetNvInferDsoHandle(); if (!handle_or.ok()) return nullptr; return handle_or.value(); }(); diff --git a/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc b/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc index 31f0586e7962af..113dcd57294d4e 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc @@ -20,7 +20,7 @@ limitations under the License. #if GOOGLE_CUDA && GOOGLE_TENSORRT #include "tensorflow/compiler/tf2tensorrt/common/utils.h" #include "tensorflow/compiler/tf2tensorrt/convert/op_converter_registry.h" -#include "xla/stream_executor/platform/dso_loader.h" +#include "tsl/platform/dso_loader.h" #include "third_party/tensorrt/NvInfer.h" #endif @@ -33,7 +33,7 @@ bool IsGoogleTensorRTEnabled() { LOG(INFO) << "TensorRT libraries are statically linked, skip dlopen check"; return true; #else // TF_USE_TENSORRT_STATIC - auto handle_or = se::internal::DsoLoader::TryDlopenTensorRTLibraries(); + auto handle_or = tsl::internal::DsoLoader::TryDlopenTensorRTLibraries(); if (!handle_or.ok()) { LOG_WARNING_WITH_PREFIX << "Could not find TensorRT"; } diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 91350d7d1b7184..cfab323375bc51 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -98,7 +98,6 @@ tf_proto_library( xla_py_proto_library( name = "tf2xla_py", has_services = False, - api_version = 2, visibility = ["//visibility:public"], deps = [":tf2xla_proto"], ) @@ -161,7 +160,7 @@ cc_library( "//tensorflow/core/platform:status", "@com_google_absl//absl/strings", "@local_xla//xla/client", - "@local_xla//xla/client:xla_computation", + "@local_xla//xla/hlo/builder:xla_computation", ], ) @@ -197,7 +196,7 @@ cc_library( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", "@local_xla//xla/client", - "@local_xla//xla/client:xla_computation", + "@local_xla//xla/hlo/builder:xla_computation", ], ) @@ -433,7 +432,7 @@ cc_library( "@local_xla//xla/client:client_library", "@local_xla//xla/client:executable_build_options", "@local_xla//xla/client:local_client", - "@local_xla//xla/client:xla_computation", + "@local_xla//xla/hlo/builder:xla_computation", "@local_xla//xla/service:platform_util", "@local_xla//xla/stream_executor:platform", ] + if_libtpu( @@ -520,15 +519,19 @@ cc_library( "//tensorflow/core/tpu:tpu_defs", "//tensorflow/core/util:overflow", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", "@local_xla//xla:executable_run_options", + "@local_xla//xla:literal", "@local_xla//xla:protobuf_util", "@local_xla//xla:shape_util", "@local_xla//xla:status_macros", @@ -536,13 +539,13 @@ cc_library( "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:client_library", "@local_xla//xla/client:local_client", - "@local_xla//xla/client:value_inference", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client:xla_computation", + "@local_xla//xla/hlo/builder:value_inference", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder:xla_computation", "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:layout_util", "@local_xla//xla/service:computation_placer_hdr", "@local_xla//xla/service:hlo_proto_cc", - "@local_xla//xla/translate/mhlo_to_hlo:layout_util", ] + if_libtpu([ ":xla_tpu_backend_registration", ]) + if_static(["@local_tsl//tsl/platform:tensor_float_32_utils"]), @@ -567,7 +570,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:session_options", "//tensorflow/core/common_runtime:core_cpu_internal", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], alwayslink = 1, ) @@ -606,8 +609,8 @@ cc_library( "@local_xla//xla:status_macros", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:client_library", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client:xla_computation", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder:xla_computation", ], alwayslink = 1, ) @@ -634,6 +637,16 @@ cc_library( "//tensorflow/core/common_runtime/next_pluggable_device:next_pluggable_device_factory_hdrs", "//tensorflow/core/platform:stream_executor_no_cuda", "//tensorflow/core/tfrt/common:pjrt_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:status", + "@local_xla//xla:util", "@local_xla//xla/client:client_library", ], alwayslink = 1, @@ -656,8 +669,8 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:optional", "@local_xla//xla/client", - "@local_xla//xla/client:value_inference", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:value_inference", + "@local_xla//xla/hlo/builder:xla_builder", ], alwayslink = 1, ) @@ -681,10 +694,16 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@local_xla//xla:shape_util", + "@local_xla//xla:status_macros", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], alwayslink = 1, ) @@ -710,16 +729,16 @@ cc_library( "@com_google_absl//absl/types:span", "@local_xla//xla:executable_run_options", "@local_xla//xla:types", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client:xla_computation", - "@local_xla//xla/client/lib:arithmetic", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder:xla_computation", + "@local_xla//xla/hlo/builder/lib:arithmetic", + "@local_xla//xla/hlo/builder/lib:constants", "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:layout_util", "@local_xla//xla/service:computation_placer_hdr", "@local_xla//xla/service/gpu:gpu_executable_run_options", "@local_xla//xla/service/gpu/runtime:nccl_clique_key", "@local_xla//xla/stream_executor", - "@local_xla//xla/translate/mhlo_to_hlo:layout_util", ], alwayslink = 1, ) @@ -739,7 +758,7 @@ cc_library( "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/ir:hlo", ], alwayslink = 1, @@ -765,7 +784,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/ir:hlo", ], alwayslink = 1, @@ -845,7 +864,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/strings", "@local_xla//xla:status_macros", - "@local_xla//xla/client:sharding_builder", + "@local_xla//xla/hlo/builder:sharding_builder", ], ) @@ -936,7 +955,7 @@ tf_cc_test( "@local_xla//xla:literal_util", "@local_xla//xla/client:client_library", "@local_xla//xla/client:local_client", - "@local_xla//xla/client:xla_computation", + "@local_xla//xla/hlo/builder:xla_computation", "@local_xla//xla/service:cpu_plugin", ] + if_static(["@local_tsl//tsl/platform:tensor_float_32_utils"]), ) @@ -999,13 +1018,15 @@ tf_cc_test( "//tensorflow/core/framework:tensor_testutil", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", "@local_xla//xla:literal", "@local_xla//xla:shape_util", "@local_xla//xla:status_macros", + "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:client_library", "@local_xla//xla/client:local_client", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/service:cpu_plugin", "@local_xla//xla/service:hlo_proto_cc", "@local_xla//xla/service:hlo_proto_util", @@ -1028,9 +1049,9 @@ tf_cc_test( "//tensorflow/core:testlib", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", "@local_xla//xla:literal", "@local_xla//xla:literal_util", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -1263,10 +1284,7 @@ cc_library( "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", - "//tensorflow/core:graph", "//tensorflow/core:lib", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/types:optional", "@local_xla//xla:status_macros", ], ) @@ -1387,8 +1405,11 @@ tf_cc_test( srcs = ["xla_op_registry_test.cc"], deps = [ ":xla_compiler", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/log", ], ) @@ -1481,7 +1502,7 @@ cc_library( "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -1538,7 +1559,10 @@ cc_library( "tf2xla_opset.cc", ], hdrs = ["tf2xla_opset.h"], - visibility = ["//tensorflow/python:__pkg__"], + visibility = [ + "//tensorflow/python:__pkg__", + "//tensorflow/python/util:__pkg__", + ], deps = [ ":tf2xla_util", ":xla_op_registry", diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index b8d91294ca7b18..96a293b8676046 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -33,9 +33,9 @@ namespace tensorflow { namespace { -Status GetFunctionBody(FunctionLibraryRuntime* flib_runtime, - const NodeDef& node, StringPiece func_attr_name, - const FunctionBody** fbody) { +absl::Status GetFunctionBody(FunctionLibraryRuntime* flib_runtime, + const NodeDef& node, StringPiece func_attr_name, + const FunctionBody** fbody) { NameAttrList name_attr_list; TF_RETURN_IF_ERROR(GetNodeAttr(node, func_attr_name, &name_attr_list)); FunctionLibraryRuntime::Handle func_handle; @@ -45,9 +45,10 @@ Status GetFunctionBody(FunctionLibraryRuntime* flib_runtime, return absl::OkStatus(); } -Status GetFunctionBodies(FunctionLibraryRuntime* flib_runtime, - const NodeDef& node, StringPiece func_list_attr_name, - std::vector* fbodies) { +absl::Status GetFunctionBodies(FunctionLibraryRuntime* flib_runtime, + const NodeDef& node, + StringPiece func_list_attr_name, + std::vector* fbodies) { std::vector name_attr_lists; TF_RETURN_IF_ERROR(GetNodeAttr(node, func_list_attr_name, &name_attr_lists)); for (const NameAttrList& name_attr_list : name_attr_lists) { @@ -60,7 +61,7 @@ Status GetFunctionBodies(FunctionLibraryRuntime* flib_runtime, return absl::OkStatus(); } -Status CondConstInputIndices( +absl::Status CondConstInputIndices( absl::Span branch_bodies, std::vector* const_input_idxs, FunctionLibraryRuntime* flib_runtime) { TF_RET_CHECK(!branch_bodies.empty()); @@ -87,10 +88,11 @@ Status CondConstInputIndices( return absl::OkStatus(); } -Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel, - const OpDef* op_def, - std::vector* const_input_idxs, - FunctionLibraryRuntime* flib_runtime) { +absl::Status GetCompileTimeConstInputs(const NodeDef& node, + const OpKernel* op_kernel, + const OpDef* op_def, + std::vector* const_input_idxs, + FunctionLibraryRuntime* flib_runtime) { DCHECK(op_def != nullptr || op_kernel != nullptr); if (node.op() == "While" || node.op() == "StatelessWhile") { // For While nodes, recurse into the body and cond graphs. @@ -172,9 +174,9 @@ Status GetCompileTimeConstInputs(const NodeDef& node, const OpKernel* op_kernel, } } -Status GetCompileTimeConstInputs(const Node* node, - std::vector* const_input_idxs, - FunctionLibraryRuntime* flib_runtime) { +absl::Status GetCompileTimeConstInputs(const Node* node, + std::vector* const_input_idxs, + FunctionLibraryRuntime* flib_runtime) { return GetCompileTimeConstInputs(node->def(), /*op_kernel=*/nullptr, &node->op_def(), const_input_idxs, flib_runtime); @@ -184,7 +186,7 @@ Status GetCompileTimeConstInputs(const Node* node, // Backwards dataflow analysis that finds arguments to a graph that must be // compile-time constants. -Status BackwardsConstAnalysis( +absl::Status BackwardsConstAnalysis( const Graph& g, std::vector* compile_time_const_arg_indices, std::vector* compile_time_const_nodes, FunctionLibraryRuntime* flib_runtime, @@ -207,7 +209,7 @@ Status BackwardsConstAnalysis( compile_time_const_nodes = &compile_time_const_nodes_impl; } - Status status; + absl::Status status; auto visit = [&](Node* node) { if (!status.ok()) return; @@ -294,9 +296,9 @@ Status BackwardsConstAnalysis( return status; } -Status GetCompileTimeConstInputs(const OpKernel* op_kernel, - std::vector* const_input_idxs, - FunctionLibraryRuntime* flib_runtime) { +absl::Status GetCompileTimeConstInputs(const OpKernel* op_kernel, + std::vector* const_input_idxs, + FunctionLibraryRuntime* flib_runtime) { return GetCompileTimeConstInputs(op_kernel->def(), op_kernel, /*op_def=*/nullptr, const_input_idxs, flib_runtime); diff --git a/tensorflow/compiler/tf2xla/const_analysis.h b/tensorflow/compiler/tf2xla/const_analysis.h index ba5fa45fd9a6c1..ea7d9eb8b1104f 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.h +++ b/tensorflow/compiler/tf2xla/const_analysis.h @@ -35,7 +35,7 @@ namespace tensorflow { // // If `edge_filter` is non-null, only propagate const-ness along edges for which // `edge_filter` returns true. -Status BackwardsConstAnalysis( +absl::Status BackwardsConstAnalysis( const Graph& g, std::vector* compile_time_const_arg_indices, std::vector* compile_time_const_nodes, FunctionLibraryRuntime* flib_runtime, @@ -43,9 +43,9 @@ Status BackwardsConstAnalysis( // Given an op kernel and function library runtime, return all the indices of // inputs that need to be compile time constant. -Status GetCompileTimeConstInputs(const OpKernel* op_kernel, - std::vector* const_input_idxs, - FunctionLibraryRuntime* flib_runtime); +absl::Status GetCompileTimeConstInputs(const OpKernel* op_kernel, + std::vector* const_input_idxs, + FunctionLibraryRuntime* flib_runtime); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_CONST_ANALYSIS_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc index 577cbd9126e62b..92a644843c5d46 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc @@ -119,7 +119,7 @@ string DebugString(StateMap::CondId cond_state) { } // Returns the predicate of a switch. -Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) { +absl::Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) { const Edge* pred_edge; TF_RETURN_IF_ERROR(switch_node.input_edge(1, &pred_edge)); // The predicate can be preceded by a identity node. Look through @@ -131,7 +131,7 @@ Status GetSwitchPredicate(const Node& switch_node, OutputTensor* pred) { return absl::OkStatus(); } -Status GetSwitchValue(const Node& switch_node, OutputTensor* val) { +absl::Status GetSwitchValue(const Node& switch_node, OutputTensor* val) { const Edge* val_edge; TF_RETURN_IF_ERROR(switch_node.input_edge(0, &val_edge)); *val = OutputTensor(val_edge->src(), val_edge->src_output()); @@ -301,10 +301,10 @@ class Conditional { StateMap* cond_state_map, const ShapeRefiner& refiner); // Adds merge node that is part of this conditional. - Status AddMerge(Node* m); + absl::Status AddMerge(Node* m); // Constructs an If node from the merge nodes. - Status BuildAndReplace( + absl::Status BuildAndReplace( Graph* graph, FunctionLibraryDefinition* library, std::unordered_map* merge_to_replacement); @@ -312,31 +312,31 @@ class Conditional { // Extracts the then/else bodies: creates new graphs with the nodes // corresponding to the nodes in the then/else branches as of this conditional // as function bodies. - Status ExtractBodies(Graph* graph); + absl::Status ExtractBodies(Graph* graph); // Builds the arguments that are the input to the If. - Status BuildArgumentNodes(); + absl::Status BuildArgumentNodes(); // Builds the If node for the extracted bodies with the given predicate. - Status BuildIfNode(Graph* graph, FunctionLibraryDefinition* library); + absl::Status BuildIfNode(Graph* graph, FunctionLibraryDefinition* library); // Adds input edges to If node. - Status AddInputEdges( + absl::Status AddInputEdges( Graph* graph, const std::unordered_map& merge_to_replacement); // Adds output edges from If node. // Record new output tensor for all Merge nodes in 'merge_to_replacement'. - Status AddOutputEdges( + absl::Status AddOutputEdges( Graph* graph, std::unordered_map* merge_to_replacement); // Adds switch node that is part of this conditional. - Status AddSwitch(Node* s); + absl::Status AddSwitch(Node* s); // Adds a switch node along the edge and rewire the edge to go via the switch. - Status AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch, - Graph* graph); + absl::Status AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch, + Graph* graph); // Internal name of conditional. The name is based on the first merge node // added. @@ -392,12 +392,12 @@ Conditional::Conditional(OutputTensor predicate, FunctionalizeCond* parent, predicate_(predicate), refiner_(refiner) {} -Status Conditional::AddMerge(Node* m) { +absl::Status Conditional::AddMerge(Node* m) { merges_.insert(m); return absl::OkStatus(); } -Status Conditional::AddSwitch(Node* s) { +absl::Status Conditional::AddSwitch(Node* s) { VLOG(5) << "Adding switch " << s->DebugString(); OutputTensor predicate; TF_RETURN_IF_ERROR(GetSwitchPredicate(*s, &predicate)); @@ -413,7 +413,7 @@ Status Conditional::AddSwitch(Node* s) { return absl::OkStatus(); } -Status Conditional::BuildArgumentNodes() { +absl::Status Conditional::BuildArgumentNodes() { VLOG(1) << "Build function arguments"; struct Hash { size_t operator()(const std::pair& item) const { @@ -495,8 +495,9 @@ Status Conditional::BuildArgumentNodes() { return absl::OkStatus(); } -Status Conditional::AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch, - Graph* graph) { +absl::Status Conditional::AddSwitchNodeAlongEdge(const Edge* edge, + BranchType branch, + Graph* graph) { // Previously we had edge: // src:src_output ---- edge ----> dst:dst_input // post this we have (in graph) @@ -524,7 +525,7 @@ Status Conditional::AddSwitchNodeAlongEdge(const Edge* edge, BranchType branch, return AddSwitch(switch_node); } -Status Conditional::ExtractBodies(Graph* graph) { +absl::Status Conditional::ExtractBodies(Graph* graph) { VLOG(2) << "Extracting bodies for " << name(); for (auto b : {BranchType::kElseBranch, BranchType::kThenBranch}) { bodies_[static_cast(b)] = @@ -744,8 +745,8 @@ Status Conditional::ExtractBodies(Graph* graph) { return absl::OkStatus(); } -Status Conditional::BuildIfNode(Graph* graph, - FunctionLibraryDefinition* library) { +absl::Status Conditional::BuildIfNode(Graph* graph, + FunctionLibraryDefinition* library) { VLOG(2) << "Build cond function for " << name(); NodeDebugInfo debug_info((*merges_.begin())->def()); NodeDefBuilder builder(name(), "If", library, &debug_info); @@ -837,7 +838,7 @@ Status Conditional::BuildIfNode(Graph* graph, return absl::OkStatus(); } -Status Conditional::AddInputEdges( +absl::Status Conditional::AddInputEdges( Graph* graph, const std::unordered_map& merge_to_replacement) { VLOG(2) << "AddInputEdges for " << if_node_->name(); @@ -874,7 +875,7 @@ Status Conditional::AddInputEdges( return absl::OkStatus(); } -Status Conditional::AddOutputEdges( +absl::Status Conditional::AddOutputEdges( Graph* graph, std::unordered_map* merge_to_replacement) { VLOG(2) << "AddOutputEdges for " << if_node_->name(); @@ -913,7 +914,7 @@ Status Conditional::AddOutputEdges( return absl::OkStatus(); } -Status Conditional::BuildAndReplace( +absl::Status Conditional::BuildAndReplace( Graph* graph, FunctionLibraryDefinition* library, std::unordered_map* merge_to_replacement) { VLOG(1) << "Build If and replace merge nodes " @@ -952,8 +953,8 @@ string Conditional::name() const { return absl::StrCat((*merges_.begin())->name(), "_if"); } -Status FunctionalizeCond::AddIdentityNode(const Node* replacee, Node* if_node, - int port) { +absl::Status FunctionalizeCond::AddIdentityNode(const Node* replacee, + Node* if_node, int port) { NodeBuilder id_builder(replacee->name(), "Identity"); id_builder.Input(if_node, port); string outside_compilation; @@ -987,7 +988,7 @@ absl::StatusOr FunctionalizeCond::AddIfNode( return ret; } -Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) { +absl::Status FunctionalizeCond::PropagateUpdatedState(const Node* replacee) { VLOG(2) << "Propagating update state for " << replacee->name() << " " << state_map_.CondStateToString(replacee); // Redo topological sort as the order could have changed. @@ -1155,7 +1156,7 @@ StateMap::CondId FunctionalizeCond::StateAlongEdge(const Edge* e) { return id; } -Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) { +absl::Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) { // Only Merge nodes with two inputs are supported, but if this is a redundant // merge, then the dead edge may already have been removed (if due to a // switch) and so the input count would be incorrect. @@ -1185,7 +1186,7 @@ Status FunctionalizeCond::DetermineCondStateMerge(Node* dst) { return absl::OkStatus(); } -Status FunctionalizeCond::DetermineCondStateNonMerge(Node* dst) { +absl::Status FunctionalizeCond::DetermineCondStateNonMerge(Node* dst) { // Handle non-merge join. for (auto e : dst->in_edges()) { VLOG(4) << "Processing forward flow for: " << e->DebugString() << " " @@ -1203,7 +1204,7 @@ Status FunctionalizeCond::DetermineCondStateNonMerge(Node* dst) { return absl::OkStatus(); } -Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { +absl::Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { // Handle redundant merge nodes. A merge node is considered redundant if // one input edge is dead while the other has a value. if (!state_map_.IsDead(state_map_.LookupCondId(node))) @@ -1242,7 +1243,7 @@ Status FunctionalizeCond::RemoveRedundantMerge(Node* node) { return absl::OkStatus(); } -Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { +absl::Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { // Handle redundant switch nodes. A switch node is considered redundant if // the predicate of the switch already holds on the current branch. E.g., if // p is the predicate of the switch but p is already known to hold on this @@ -1312,7 +1313,8 @@ Status FunctionalizeCond::RemoveRedundantSwitch(Node* node) { return absl::OkStatus(); } -Status FunctionalizeCond::DetermineStates(std::vector rev_topo_order) { +absl::Status FunctionalizeCond::DetermineStates( + std::vector rev_topo_order) { // The state that is propagated along the given edge. for (auto it = rev_topo_order.rbegin(); it != rev_topo_order.rend(); ++it) { Node* dst = *it; @@ -1328,7 +1330,7 @@ Status FunctionalizeCond::DetermineStates(std::vector rev_topo_order) { return absl::OkStatus(); } -Status FunctionalizeCond::DetermineAncestorState(Node* dst) { +absl::Status FunctionalizeCond::DetermineAncestorState(Node* dst) { StateMap::AncestorId id = nullptr; StateMap::AncestorState state; @@ -1457,7 +1459,7 @@ void FunctionalizeCond::SortMergeNodes(std::vector* merge_order) { } } -Status FunctionalizeCond::FunctionalizeInternal() { +absl::Status FunctionalizeCond::FunctionalizeInternal() { // The general approach for converting a tf.cond (as lowered via switch/merge // nodes) to a functional if is as follows: // 1. Determine the topological order and collect all the switch and merge @@ -1595,9 +1597,9 @@ void FunctionalizeCond::AddSwitchId(int switch_id) { switch_ids_.push_back(switch_id); } -Status FunctionalizeCond::Functionalize(Graph* graph, - FunctionLibraryDefinition* library, - const NodeFilter& node_filter) { +absl::Status FunctionalizeCond::Functionalize( + Graph* graph, FunctionLibraryDefinition* library, + const NodeFilter& node_filter) { VLOG(1) << "FunctionalizeCond::Functionalize"; FunctionalizeCond fc(graph, library, node_filter); return fc.FunctionalizeInternal(); @@ -1605,8 +1607,8 @@ Status FunctionalizeCond::Functionalize(Graph* graph, } // namespace functionalize_cond -Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library, - const NodeFilter& node_filter) { +absl::Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library, + const NodeFilter& node_filter) { // FunctionalizeControlFlow is invoked for every function, so the loops's // bodies and conditionals that were extracted into functions will be handled // in successive invocations. diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h index 23b2acb56978d0..e37555b053d7ed 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond.h +++ b/tensorflow/compiler/tf2xla/functionalize_cond.h @@ -37,8 +37,8 @@ namespace tensorflow { // b) While loops must have been functionalized before according to // `node_filter` (e.g., by calling `FunctionalizeWhileLoop` with the same // filter before calling this function). -Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library, - const NodeFilter& node_filter = {}); +absl::Status FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library, + const NodeFilter& node_filter = {}); // Internal functions/classes exposed for testing purposes. namespace functionalize_cond { @@ -184,12 +184,13 @@ class StateMap { class FunctionalizeCond { public: // See comment for function `FunctionalizeCond`. - static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library, - const NodeFilter& node_filter); + static absl::Status Functionalize(Graph* graph, + FunctionLibraryDefinition* library, + const NodeFilter& node_filter); // Build identity node with the same name as the merge that will be replaced // in case the output is fetched/colocated. - Status AddIdentityNode(const Node* replacee, Node* if_node, int port); + absl::Status AddIdentityNode(const Node* replacee, Node* if_node, int port); // Add a If node to the graph defined by def that will, amongst other, replace // replacee in the graph. @@ -197,7 +198,7 @@ class FunctionalizeCond { const OutputTensor& predicate); // Propagates the state of a newly inserted node. - Status PropagateUpdatedState(const Node* replacee); + absl::Status PropagateUpdatedState(const Node* replacee); // Dump graph with the CondState annotated. void DumpGraphWithCondState(const string& name); @@ -212,7 +213,7 @@ class FunctionalizeCond { // Performs the actual cond functionalization. Iterate over groups of merge // nodes (linked by common predicates & ancestor IDs), from innermost to // outermost, and extract into If nodes. - Status FunctionalizeInternal(); + absl::Status FunctionalizeInternal(); // Returns the forward flow state propagated along edge `e`. // This may modify state_map_. @@ -221,19 +222,19 @@ class FunctionalizeCond { // Determines the CondState and AncestorState of all the nodes in the given // vector where the input is expected in reverse topological order. // This populates the state_map_. - Status DetermineStates(std::vector rev_topo_order); + absl::Status DetermineStates(std::vector rev_topo_order); // Determine the CondState for a given node using the incoming edges // to the node. Note: it is expected that this node's CondState is only // determined once its input's CondState is. - Status DetermineCondState(Node* dst) { + absl::Status DetermineCondState(Node* dst) { if (IsMerge(dst)) return DetermineCondStateMerge(dst); return DetermineCondStateNonMerge(dst); } // Helper functions for DetermineCondState. - Status DetermineCondStateNonMerge(Node* dst); - Status DetermineCondStateMerge(Node* dst); + absl::Status DetermineCondStateNonMerge(Node* dst); + absl::Status DetermineCondStateMerge(Node* dst); // Determines the dst node's CondState by joining the src and dst's CondState // where either the dst node is a merge or not. @@ -245,13 +246,13 @@ class FunctionalizeCond { StateMap::CondId dst); // Determines which switch/merge nodes are ancestors of this node. - Status DetermineAncestorState(Node* dst); + absl::Status DetermineAncestorState(Node* dst); // Checks if a merge node is redundant and if so removes it from the graph. - Status RemoveRedundantMerge(Node* node); + absl::Status RemoveRedundantMerge(Node* node); // Checks if a switch node is redundant and if so removes it from the graph. - Status RemoveRedundantSwitch(Node* node); + absl::Status RemoveRedundantSwitch(Node* node); // Sorts merge nodes (in reverse topological order) in order of increasing // nesting depth. diff --git a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc index d8015ce6835d09..57f1cbdf3bd44f 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc @@ -94,7 +94,8 @@ TEST_F(FunctionalizeCondTest, JoinCondStates) { } // An non-merge op with inputs from then and else branch. - Status status = JoinCondStatesNonMerge(then_branch, else_branch).status(); + absl::Status status = + JoinCondStatesNonMerge(then_branch, else_branch).status(); EXPECT_TRUE(errors::IsInvalidArgument(status)); // Merge between then and else branch. diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 2bad3b58d34761..ac38725269bfd9 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -94,7 +94,7 @@ void UpdateFunctionMap(FuncMap* func_map, const string& canonicalized_name, } // Adds new function def to graph's function library if necessary. -Status AddFunctionDefToGraphLibrary( +absl::Status AddFunctionDefToGraphLibrary( const string& func_name, const AssociatedFunctionInfo& associated_function, Graph* graph, FunctionLibraryDefinition* fld) { const OpRegistrationData* op_reg_data; @@ -128,7 +128,7 @@ Status AddFunctionDefToGraphLibrary( } // Functionalizes function given by `func_name`. Update `func_map` accordingly. -Status FunctionalizeControlFlowForFunction( +absl::Status FunctionalizeControlFlowForFunction( const string& func_name, const string& new_func_name, const protobuf::Map& attrs, FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, @@ -137,7 +137,7 @@ Status FunctionalizeControlFlowForFunction( // Functionalizes all functions that are (directly or indirectly) associated to // any node in `graph`. Adds processed functions to `func_map`. -Status FunctionalizeControlFlowForNodeAssociatedFunctions( +absl::Status FunctionalizeControlFlowForNodeAssociatedFunctions( FuncMap* func_map, Graph* graph, FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, bool* any_function_modified, const NodeFilter& node_filter) { @@ -201,7 +201,7 @@ Status FunctionalizeControlFlowForNodeAssociatedFunctions( return absl::OkStatus(); } -Status FunctionalizeControlFlowForFunction( +absl::Status FunctionalizeControlFlowForFunction( const string& func_name, const string& new_func_name, const protobuf::Map& attrs, FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, @@ -211,7 +211,7 @@ Status FunctionalizeControlFlowForFunction( // Convert the function to a graph. FunctionLibraryRuntime::Handle handle; TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); - Status ret_status = absl::OkStatus(); + absl::Status ret_status = absl::OkStatus(); auto cleanup_handle = gtl::MakeCleanup([&]() { auto s = flr->ReleaseHandle(handle); if (!s.ok()) { @@ -270,10 +270,10 @@ Status FunctionalizeControlFlowForFunction( return ret_status; } -Status FunctionalizeControlFlow(Graph* graph, - FunctionLibraryDefinition* library, - const NodeFilter& node_filter, - bool include_functions) { +absl::Status FunctionalizeControlFlow(Graph* graph, + FunctionLibraryDefinition* library, + const NodeFilter& node_filter, + bool include_functions) { VLOG(2) << "FunctionalizeControlFlow (initial): " << DumpGraphToFile("functionalize_initial", *graph, library); @@ -308,10 +308,9 @@ Status FunctionalizeControlFlow(Graph* graph, return absl::OkStatus(); } -Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, - FunctionLibraryDefinition* library, - const NodeFilter& node_filter, - bool include_functions) { +absl::Status FunctionalizeControlFlowForGraphDef( + GraphDef* graph_def, FunctionLibraryDefinition* library, + const NodeFilter& node_filter, bool include_functions) { FunctionDefLibrary function_lib = graph_def->library(); Graph graph(OpRegistry::Global()); @@ -323,7 +322,7 @@ Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, return absl::OkStatus(); } -Status FunctionalizeControlFlowForXlaPass::Run( +absl::Status FunctionalizeControlFlowForXlaPass::Run( const GraphOptimizationPassOptions& options) { Graph* graph = options.graph->get(); if (VLOG_IS_ON(4)) { diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index ff5ae841e3044c..554186fefd1916 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -48,22 +48,21 @@ namespace tensorflow { // // The user of this function is responsible for using a node filter that // satisfies the above conditions. -Status FunctionalizeControlFlow(Graph* graph, - FunctionLibraryDefinition* library, - const NodeFilter& node_filter = {}, - bool include_functions = false); +absl::Status FunctionalizeControlFlow(Graph* graph, + FunctionLibraryDefinition* library, + const NodeFilter& node_filter = {}, + bool include_functions = false); -Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, - FunctionLibraryDefinition* library, - const NodeFilter& node_filter = {}, - bool include_functions = false); +absl::Status FunctionalizeControlFlowForGraphDef( + GraphDef* graph_def, FunctionLibraryDefinition* library, + const NodeFilter& node_filter = {}, bool include_functions = false); // Rewrites the graph by turning V1 control flow structure // (Switch/Merge/etc.) into V2 control flow structure (If/While), only modifies // functions that will be executed by XLA. class FunctionalizeControlFlowForXlaPass : public GraphOptimizationPass { public: - Status Run(const GraphOptimizationPassOptions& options) override; + absl::Status Run(const GraphOptimizationPassOptions& options) override; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 25a08224c8b946..604a24514f8e5a 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -46,8 +46,8 @@ namespace { // Returns the names of the "then" and "else" functions for the If node in a // graph. -Status FindIfThenAndElse(const GraphDef& graph, string* op_name, - NameAttrList* then_fn, NameAttrList* else_fn) { +absl::Status FindIfThenAndElse(const GraphDef& graph, string* op_name, + NameAttrList* then_fn, NameAttrList* else_fn) { for (const NodeDef& node : graph.node()) { if (node.op() == "If") { *op_name = node.name(); @@ -249,7 +249,7 @@ void ConditionalTestFixture::RunTest() { cond_fn.set_name("cond_node"); cond_fn.set_op("cond_fn"); *(cond_fn.add_input()) = "source"; - Status status; + absl::Status status; scope.graph()->AddNode(cond_fn, &status); TF_ASSERT_OK(status); TF_ASSERT_OK(scope.ToGraph(&graph)); @@ -308,8 +308,8 @@ void ConditionalTestFixture::RunTest() { // Returns the names of the "cond" and "body" functions for the While node // in a graph. -Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, - NameAttrList* body) { +absl::Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, + NameAttrList* body) { for (const NodeDef& node : graph.node()) { if (node.op() == "While") { const NameAttrList* result; @@ -463,7 +463,7 @@ FunctionDef GetNoinlineFunctionDef() { // return [x + 1] // Define the above function, and add it to the given graph. It's used as the // while loop body in NoinlineLoopBody test. -Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) { +absl::Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) { FunctionDefLibrary fdef_lib; *(fdef_lib.add_function()) = GetNoinlineFunctionDef(); TF_RETURN_IF_ERROR(graph->AddFunctionLibrary(fdef_lib)); @@ -472,7 +472,7 @@ Status AddNoinlineFunctionToGraph(const string& node_name, Graph* graph) { increment_fn.set_op("increment_fn"); *increment_fn.add_input() = "while/Identity"; *increment_fn.add_input() = "^while/Identity"; - Status status; + absl::Status status; graph->AddNode(increment_fn, &status); return status; } @@ -511,7 +511,7 @@ TEST(FunctionalizeControlFlow, NoinlineLoopBody) { *next_iter.add_input() = noinline_node_name; (*next_iter.mutable_attr())["T"].set_type(DT_INT32); - Status status; + absl::Status status; Node* n = scope.graph()->AddNode(next_iter, &status); TF_ASSERT_OK(status); @@ -563,7 +563,7 @@ TEST(FunctionalizeControlFlow, NoinlineLoopBody) { *retval.add_input() = noinline_node_name; (*retval.mutable_attr())["T"].set_type(DT_INT32); (*retval.mutable_attr())["index"].set_i(0); - Status status; + absl::Status status; scope.graph()->AddNode(retval, &status); TF_ASSERT_OK(status); @@ -600,7 +600,8 @@ TEST(FunctionalizeControlFlow, MissingFunctionDefInLibrary) { graph.ToGraphDef(&graph_def); graph_def.clear_library(); - Status status = FunctionalizeControlFlowForGraphDef(&graph_def, &library); + absl::Status status = + FunctionalizeControlFlowForGraphDef(&graph_def, &library); EXPECT_EQ(tensorflow::error::NOT_FOUND, status.code()); } @@ -1105,9 +1106,10 @@ void ComplexTestFixture::RunTest() { ? [](const Node* n) { return n->attrs().Find("_tpu_replicate"); } : NodeFilter{}; - Status status1 = FunctionalizeControlFlowForGraphDef(&optimized_graph_def, - &library, node_filter); - Status status2 = FunctionalizeControlFlow(&graph, &library, node_filter); + absl::Status status1 = FunctionalizeControlFlowForGraphDef( + &optimized_graph_def, &library, node_filter); + absl::Status status2 = + FunctionalizeControlFlow(&graph, &library, node_filter); ASSERT_EQ(status1, status2); if (restrict_to_tpu_nodes_ && mark_outer_loop_tpu_ && !mark_inner_loop_tpu_) { // This case violates the precondition of `FunctionalizeControlFlow`, we diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc index c2bc42b5c24e14..cf3413154b8baa 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.cc @@ -40,7 +40,7 @@ absl::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index) { return graph->AddNode(ret_def); } -Status ExtractWhileLoopFrames( +absl::Status ExtractWhileLoopFrames( const std::vector& cf_info, const Graph* graph, std::unordered_map* frames, const NodeFilter& node_filter) { @@ -82,7 +82,7 @@ Status ExtractWhileLoopFrames( } // Check that the graph has no cycle containing the given node. -Status CheckNodeNotInCycle(const Node* node, const int num_nodes) { +absl::Status CheckNodeNotInCycle(const Node* node, const int num_nodes) { std::vector ready; ready.push_back(node); std::vector visited(num_nodes); diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h index 5d7ce5618fe252..970f62daa42af3 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_util.h @@ -74,13 +74,13 @@ struct WhileLoopFrame { // If `node_filter` is defined, then we keep track of frames that should be // functionalized according to the filter (see comment for // `FunctionalizeControlFlow` for more details about node filters). -Status ExtractWhileLoopFrames( +absl::Status ExtractWhileLoopFrames( const std::vector& cf_info, const Graph* graph, std::unordered_map* frames, const NodeFilter& node_filter = {}); // Check that the graph has no cycle containing the given node. -Status CheckNodeNotInCycle(const Node* node, const int num_nodes); +absl::Status CheckNodeNotInCycle(const Node* node, const int num_nodes); // Comparison function used for sorting nodes consistently. // a) resource variables are last, and diff --git a/tensorflow/compiler/tf2xla/functionalize_while.cc b/tensorflow/compiler/tf2xla/functionalize_while.cc index 70f98b3e88daec..73afe1909b4d92 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.cc +++ b/tensorflow/compiler/tf2xla/functionalize_while.cc @@ -56,10 +56,10 @@ namespace { // taking from the Switch node was not necessarily the first output, but _Arg // nodes only have one output. By adding the Switch node to `squash_src_outputs` // we rewrite the src_output of the corresponding edge to be 0. -Status CopySubgraph(const Graph& graph, const WhileLoopFrame* frame, - std::vector stack, - const std::vector& squash_src_outputs, - std::vector* node_map, Graph* output) { +absl::Status CopySubgraph(const Graph& graph, const WhileLoopFrame* frame, + std::vector stack, + const std::vector& squash_src_outputs, + std::vector* node_map, Graph* output) { VLOG(3) << "Stack: " << NodesToString(stack); std::vector visited(graph.num_node_ids(), false); while (!stack.empty()) { @@ -117,8 +117,8 @@ absl::StatusOr BuildArgNode(Graph* graph, DataType type, int index) { } // Builds a graph for the loop condition. -Status BuildLoopCondition(const Graph& graph, WhileLoopFrame* frame, - std::unique_ptr* cond_output) { +absl::Status BuildLoopCondition(const Graph& graph, WhileLoopFrame* frame, + std::unique_ptr* cond_output) { VLOG(2) << "Building loop condition for " << frame->name; *cond_output = std::make_unique(graph.op_registry()); Graph* output = cond_output->get(); @@ -153,9 +153,9 @@ Status BuildLoopCondition(const Graph& graph, WhileLoopFrame* frame, } // Builds a graph for the loop body. -Status BuildLoopBody(const Graph& graph, WhileLoopFrame* frame, - DataTypeVector* arg_types, - std::unique_ptr* body_output) { +absl::Status BuildLoopBody(const Graph& graph, WhileLoopFrame* frame, + DataTypeVector* arg_types, + std::unique_ptr* body_output) { VLOG(2) << "Building loop body for " << frame->name; *body_output = std::make_unique(graph.op_registry()); Graph* output = body_output->get(); @@ -209,9 +209,9 @@ Status BuildLoopBody(const Graph& graph, WhileLoopFrame* frame, return absl::OkStatus(); } -Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame, - FunctionLibraryDefinition* library, - const NodeFilter& node_filter) { +absl::Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame, + FunctionLibraryDefinition* library, + const NodeFilter& node_filter) { if (node_filter && !frame->should_be_functionalized) { VLOG(2) << "Skipping functionalization for frame " << frame->name << " because it has control flow nodes that are filtered out by " @@ -505,8 +505,9 @@ Status FunctionalizeLoop(Graph* graph, WhileLoopFrame* frame, } } // namespace -Status FunctionalizeWhileLoop(Graph* graph, FunctionLibraryDefinition* library, - const NodeFilter& node_filter) { +absl::Status FunctionalizeWhileLoop(Graph* graph, + FunctionLibraryDefinition* library, + const NodeFilter& node_filter) { // Note: BuildControlFlowInfo() requires that the graph's source node is // connected to all source nodes in the graph. Many graphs violate this // invariant. diff --git a/tensorflow/compiler/tf2xla/functionalize_while.h b/tensorflow/compiler/tf2xla/functionalize_while.h index ddd6b655cd52c2..e9b361f603d2f7 100644 --- a/tensorflow/compiler/tf2xla/functionalize_while.h +++ b/tensorflow/compiler/tf2xla/functionalize_while.h @@ -31,8 +31,9 @@ namespace tensorflow { // // Preconditions: // Same as for `FunctionalizeControlFlow` (see comment there). -Status FunctionalizeWhileLoop(Graph* graph, FunctionLibraryDefinition* library, - const NodeFilter& node_filter = {}); +absl::Status FunctionalizeWhileLoop(Graph* graph, + FunctionLibraryDefinition* library, + const NodeFilter& node_filter = {}); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc b/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc index 70c09bc84ac275..2759ad8384cd81 100644 --- a/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc +++ b/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc @@ -42,7 +42,7 @@ limitations under the License. namespace tensorflow { namespace { -Status GetTestDevice(Session* session, string* test_device) { +absl::Status GetTestDevice(Session* session, string* test_device) { std::vector devices; TF_RETURN_IF_ERROR(session->ListDevices(&devices)); diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 23eb33224dc24b..f23c423fbb2632 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -30,7 +30,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "xla/client/client_library.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" @@ -62,10 +62,10 @@ auto* graph_compiler_failed_compilation_op_count = /*metric_label=*/"op_name"); namespace { -Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, - const std::vector& expressions, - const NameAttrList& func, - std::vector* args) { +absl::Status PrepareArguments( + XlaOpKernelContext* ctx, Graph* graph, + const std::vector& expressions, + const NameAttrList& func, std::vector* args) { auto client = ctx->compiler()->client(); std::vector arg_must_be_compile_time_constant(expressions.size()); @@ -117,7 +117,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, return absl::OkStatus(); } } // namespace -Status GraphCompiler::Compile() { +absl::Status GraphCompiler::Compile() { // Check that the graph has no illegal cycles. TF_RETURN_IF_ERROR(graph::ValidateGraphHasNoCycle(*graph_)); // Maintain a mapping from node id to node outputs. @@ -144,7 +144,7 @@ Status GraphCompiler::Compile() { OpKernel* op_kernel_raw = nullptr; // The kernel is not actually run for functional ops, we just need it // for metadata. - Status s = flib_->CreateKernel(n->properties(), &op_kernel_raw); + absl::Status s = flib_->CreateKernel(n->properties(), &op_kernel_raw); // Transfer ownership of the kernel to a local smart pointer. std::unique_ptr op_kernel(op_kernel_raw); @@ -183,7 +183,7 @@ Status GraphCompiler::Compile() { TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context)); } else { device_->Compute(CHECK_NOTNULL(params.op_kernel), &op_context); - Status s = op_context.status(); + absl::Status s = op_context.status(); if (!s.ok()) { graph_compiler_failed_compilation_op_count ->GetCell(params.op_kernel->def().op()) @@ -209,8 +209,8 @@ Status GraphCompiler::Compile() { namespace { -Status GetFunctionNameAndAttr(const FunctionLibraryRuntime& flib, - const Node& node, NameAttrList* func) { +absl::Status GetFunctionNameAndAttr(const FunctionLibraryRuntime& flib, + const Node& node, NameAttrList* func) { if (node.IsPartitionedCall()) { const AttrValue* attr_value; TF_RETURN_IF_ERROR( @@ -235,8 +235,8 @@ Status GetFunctionNameAndAttr(const FunctionLibraryRuntime& flib, } // namespace -Status GraphCompiler::CompileFunctionalNode(Node* n, - OpKernelContext* op_context) { +absl::Status GraphCompiler::CompileFunctionalNode(Node* n, + OpKernelContext* op_context) { TF_RET_CHECK(IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n)); // For functional nodes, compile them using compiler from the context and call // into the functions. diff --git a/tensorflow/compiler/tf2xla/graph_compiler.h b/tensorflow/compiler/tf2xla/graph_compiler.h index da4dcda1010c28..6ab20955057bfa 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.h +++ b/tensorflow/compiler/tf2xla/graph_compiler.h @@ -66,7 +66,7 @@ class GraphCompiler { // Compiles the graph. The results are written in xla_context stored in the // resource_manager of the 'XlaCompilationDevice' that's passed into the // constructor. - Status Compile(); + absl::Status Compile(); private: // Partially sets params. This partially set params can be reused @@ -76,7 +76,7 @@ class GraphCompiler { // Compiles a functional node and writes result to OpkernelContext. A // functional node represents a defined computation and should be compiled // using `compiler_`. - Status CompileFunctionalNode(Node* n, OpKernelContext* op_context); + absl::Status CompileFunctionalNode(Node* n, OpKernelContext* op_context); XlaCompilationDevice* device_; Graph* graph_; diff --git a/tensorflow/compiler/tf2xla/graph_compiler_test.cc b/tensorflow/compiler/tf2xla/graph_compiler_test.cc index 6ec8b8f879333a..3010ac7f0b026b 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler_test.cc @@ -87,7 +87,7 @@ class GraphCompilerTest : public ::testing::Test { device_mgr_ = std::make_unique(absl::WrapUnique(device_)); } - Status RunGraphCompiler(Graph& graph) { + absl::Status RunGraphCompiler(Graph& graph) { ProcessFunctionLibraryRuntime runtime( device_mgr_.get(), Env::Default(), nullptr, TF_GRAPH_DEF_VERSION, &graph.flib_def(), OptimizerOptions()); @@ -106,7 +106,8 @@ class GraphCompilerTest : public ::testing::Test { auto step_container = std::make_unique(0, [this](const string& name) { - Status status = this->device_->resource_manager()->Cleanup(name); + absl::Status status = + this->device_->resource_manager()->Cleanup(name); }); auto container_status = step_container->Create( device_->resource_manager(), XlaContext::kXlaContextResourceName, diff --git a/tensorflow/compiler/tf2xla/graph_compiler_util.cc b/tensorflow/compiler/tf2xla/graph_compiler_util.cc index ac064805f1a470..d1c984e26f390a 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler_util.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler_util.cc @@ -51,10 +51,11 @@ typedef std::unordered_map NodeMap; // tensor with a placeholder. For each feed tensor, replaces all edges so they // point from a new _Arg node instead. The newly created _Arg nodes are added to // `arg_nodes`. -Status AddArgNodes(Graph* graph, const NodeMap& node_map, - const protobuf::RepeatedPtrField& feeds, - const std::unordered_map& feed_remapping, - std::unordered_set* arg_nodes) { +absl::Status AddArgNodes( + Graph* graph, const NodeMap& node_map, + const protobuf::RepeatedPtrField& feeds, + const std::unordered_map& feed_remapping, + std::unordered_set* arg_nodes) { for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) { const tf2xla::Feed& feed = feeds[arg_index]; // All feeds have been replaced by placeholders. @@ -111,9 +112,10 @@ Status AddArgNodes(Graph* graph, const NodeMap& node_map, // Each fetch id identifies the positional output of some node. For each fetch // node, adds a new _Retval node instead, and adds the node to `retval_nodes`. -Status AddRetvalNodes(Graph* graph, const NodeMap& node_map, - const protobuf::RepeatedPtrField& fetches, - std::unordered_set* retval_nodes) { +absl::Status AddRetvalNodes( + Graph* graph, const NodeMap& node_map, + const protobuf::RepeatedPtrField& fetches, + std::unordered_set* retval_nodes) { for (int ret_index = 0; ret_index < fetches.size(); ++ret_index) { const tf2xla::TensorId& id = fetches[ret_index].id(); auto it = node_map.find(id.node_name()); @@ -145,7 +147,7 @@ Status AddRetvalNodes(Graph* graph, const NodeMap& node_map, // fetch ids respectively), and rewrites the edges so that inputs flow from _Arg // nodes, and outputs flow to _Retval nodes. This allows the symbolic graph // execution to know the input and output args for the generated function. -Status RewriteAndPruneGraph( +absl::Status RewriteAndPruneGraph( Graph* graph, const tf2xla::Config& config, const std::unordered_map& feed_remapping) { NodeMap node_map; @@ -198,7 +200,8 @@ Status RewriteAndPruneGraph( // CollectArgNodes collects _Arg nodes from the graph, and performs basic // sanity-checking to ensure the index and type attributes of each node are // initialized correctly. -Status CollectArgNodes(const Graph& graph, std::vector* arg_nodes) { +absl::Status CollectArgNodes(const Graph& graph, + std::vector* arg_nodes) { std::map indexed_arg_nodes; for (Node* n : graph.nodes()) { if (n->type_string() == FunctionLibraryDefinition::kArgOp) { @@ -229,8 +232,8 @@ Status CollectArgNodes(const Graph& graph, std::vector* arg_nodes) { } // namespace -Status CreateXlaArgs(const Graph& graph, - std::vector* xla_args) { +absl::Status CreateXlaArgs(const Graph& graph, + std::vector* xla_args) { std::vector arg_nodes; TF_RETURN_IF_ERROR(CollectArgNodes(graph, &arg_nodes)); for (const Node* node : arg_nodes) { @@ -262,8 +265,8 @@ void PopulateXlaArgs(const tf2xla::Config& config, } } -Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, - std::unique_ptr* graph) { +absl::Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, + std::unique_ptr* graph) { TF_RETURN_IF_ERROR(ValidateConfig(config)); FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph_def.library()); diff --git a/tensorflow/compiler/tf2xla/graph_compiler_util.h b/tensorflow/compiler/tf2xla/graph_compiler_util.h index 4e0e3405884193..ebdf07f7eb05c1 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler_util.h +++ b/tensorflow/compiler/tf2xla/graph_compiler_util.h @@ -26,8 +26,8 @@ limitations under the License. namespace tensorflow { // Fills in xla_args from the corresponding _Arg nodes in the graph. -Status CreateXlaArgs(const Graph& graph, - std::vector* xla_args); +absl::Status CreateXlaArgs(const Graph& graph, + std::vector* xla_args); // Populate xla_args for the given XLA config. void PopulateXlaArgs(const tf2xla::Config& config, @@ -43,8 +43,8 @@ void PopulateXlaArgs(const tf2xla::Config& config, // _Arg node instead. Each fetch id causes a new _Retval node to be created, // with a new edge pointing from the named node's output index to that _Retval // node. -Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, - std::unique_ptr* graph); +absl::Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, + std::unique_ptr* graph); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 156ddbb3581222..a82e979120ab19 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -217,7 +217,7 @@ tf_cuda_library( "@local_xla//xla:shape_util", "@local_xla//xla:status_macros", "@local_xla//xla:util", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/service:custom_call_status", "@local_xla//xla/service:custom_call_target_registry", "@local_xla//xla/service:hlo_proto_cc", @@ -250,9 +250,9 @@ cc_library( "@local_xla//xla:shape_util", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:arithmetic", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:arithmetic", + "@local_xla//xla/hlo/builder/lib:constants", ], alwayslink = 1, ) @@ -290,9 +290,9 @@ cc_library( "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", "@local_xla//xla:literal_util", "@local_xla//xla:util", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:arithmetic", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:arithmetic", + "@local_xla//xla/hlo/builder/lib:constants", ] + if_static(["@local_tsl//tsl/platform:tensor_float_32_utils"]), ) @@ -312,7 +312,7 @@ cc_library( "@local_xla//xla:status_macros", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -333,7 +333,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@local_xla//xla:literal", - "@local_xla//xla/client:value_inference", + "@local_xla//xla/hlo/builder:value_inference", ], ) @@ -360,13 +360,13 @@ cc_library( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", "@local_xla//xla:shape_util", - "@local_xla//xla/client:xla_computation", + "@local_xla//xla/hlo/builder:xla_computation", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "@local_xla//xla/mlir/utils:type_util", "@local_xla//xla/mlir_hlo", "@local_xla//xla/mlir_hlo:mhlo_passes", "@local_xla//xla/python:refine_polymorphic_shapes", "@local_xla//xla/service:hlo_proto_cc", - "@local_xla//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_ops", "@stablehlo//:stablehlo_serialization", @@ -404,14 +404,14 @@ tf_kernel_library( "@local_xla//xla:debug_options_flags", "@local_xla//xla:shape_util", "@local_xla//xla:util", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client:xla_computation", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder:xla_computation", "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", + "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape", "@local_xla//xla/mlir_hlo", "@local_xla//xla/mlir_hlo:mhlo_passes", "@local_xla//xla/service:hlo_module_config", - "@local_xla//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", - "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape", ], ) @@ -445,9 +445,9 @@ tf_kernel_library( "@local_xla//xla:status_macros", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client:xla_computation", - "@local_xla//xla/client/lib:tuple", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder:xla_computation", + "@local_xla//xla/hlo/builder/lib:tuple", ], ) @@ -475,8 +475,8 @@ tf_kernel_library( "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_xla//xla:shape_util", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:dynamic_shaped_ops", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:dynamic_shaped_ops", ], ) @@ -495,9 +495,9 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:dynamic_shaped_ops", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:dynamic_shaped_ops", ], ) @@ -555,7 +555,7 @@ tf_kernel_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:errors", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -574,9 +574,9 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:errors", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:math", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:math", ], ) @@ -595,8 +595,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client/lib:math", - "@local_xla//xla/client/lib:matrix", + "@local_xla//xla/hlo/builder/lib:math", + "@local_xla//xla/hlo/builder/lib:matrix", ], ) @@ -617,9 +617,9 @@ tf_kernel_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client:xla_computation", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder:xla_computation", + "@local_xla//xla/hlo/builder/lib:constants", ], ) @@ -640,10 +640,10 @@ tf_kernel_library( "//tensorflow/core:protos_all_cc", "@local_xla//xla:literal", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:arithmetic", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:sorting", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:arithmetic", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:sorting", ], ) @@ -667,9 +667,9 @@ tf_kernel_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:math", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:math", ], ) @@ -691,9 +691,9 @@ tf_kernel_library( "@com_google_absl//absl/status:statusor", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:value_inference", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:sorting", + "@local_xla//xla/hlo/builder:value_inference", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:sorting", ], ) @@ -718,7 +718,7 @@ tf_kernel_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_xla//xla:status_macros", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -739,8 +739,8 @@ tf_kernel_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/types:span", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:value_inference", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:value_inference", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -767,10 +767,10 @@ tf_kernel_library( "@local_xla//xla:literal", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:value_inference", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:dynamic_shaped_ops", + "@local_xla//xla/hlo/builder:value_inference", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:dynamic_shaped_ops", ], ) @@ -790,7 +790,7 @@ tf_kernel_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -809,9 +809,9 @@ tf_kernel_library( "//tensorflow/core:lib", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:slicing", - "@local_xla//xla/client/lib:svd", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:slicing", + "@local_xla//xla/hlo/builder/lib:svd", ], ) @@ -836,8 +836,8 @@ tf_kernel_library( "@local_xla//xla:literal", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", ], ) @@ -852,7 +852,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -874,7 +874,7 @@ tf_kernel_library( "@local_xla//xla:shape_util", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -892,8 +892,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:math", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:math", ], ) @@ -924,7 +924,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -949,8 +949,8 @@ tf_kernel_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", ], ) @@ -967,8 +967,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:comparators", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:comparators", ], ) @@ -986,9 +986,9 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core/platform:errors", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:matrix", - "@local_xla//xla/client/lib:qr", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:matrix", + "@local_xla//xla/hlo/builder/lib:qr", ], ) @@ -1006,8 +1006,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client:xla_computation", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder:xla_computation", ], ) @@ -1028,8 +1028,8 @@ tf_kernel_library( "//tensorflow/core/platform:errors", "@local_xla//xla:literal", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:value_inference", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:value_inference", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -1065,8 +1065,8 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core/platform:errors", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:value_inference", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:value_inference", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -1086,7 +1086,7 @@ tf_kernel_library( "//tensorflow/core/platform:errors", "@com_google_absl//absl/container:inlined_vector", "@local_xla//xla:literal", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -1114,13 +1114,13 @@ tf_kernel_library( "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:statusor", "@local_xla//xla:shape_util", - "@local_xla//xla/client:value_inference", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:arithmetic", - "@local_xla//xla/client/lib:comparators", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:dynamic_shaped_ops", - "@local_xla//xla/client/lib:loops", + "@local_xla//xla/hlo/builder:value_inference", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:arithmetic", + "@local_xla//xla/hlo/builder/lib:comparators", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:dynamic_shaped_ops", + "@local_xla//xla/hlo/builder/lib:loops", ], ) @@ -1143,8 +1143,8 @@ tf_kernel_library( "@com_google_absl//absl/types:optional", "@local_xla//xla:status_macros", "@local_xla//xla/client:client_library", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:slicing", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:slicing", ], ) @@ -1165,9 +1165,9 @@ tf_kernel_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:errors", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:value_inference", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder:value_inference", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", ], ) @@ -1189,10 +1189,10 @@ tf_kernel_library( "@com_google_absl//absl/algorithm:container", "@local_xla//xla:comparison_util", "@local_xla//xla:shape_util", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:arithmetic", - "@local_xla//xla/client/lib:comparators", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:arithmetic", + "@local_xla//xla/hlo/builder/lib:comparators", + "@local_xla//xla/hlo/builder/lib:constants", ], ) @@ -1215,7 +1215,7 @@ tf_kernel_library( "//tensorflow/core/platform:status", "@com_google_absl//absl/container:inlined_vector", "@local_xla//xla:shape_util", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -1260,8 +1260,8 @@ tf_kernel_library( "@local_xla//xla:shape_util", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:value_inference", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:value_inference", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -1276,7 +1276,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -1295,8 +1295,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:math", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:math", ], ) @@ -1318,8 +1318,8 @@ tf_kernel_library( "//tensorflow/core/kernels:stochastic_cast_op_header", "@com_google_absl//absl/status:statusor", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", ], ) @@ -1339,7 +1339,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core/platform:errors", "@com_google_absl//absl/algorithm:container", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -1356,7 +1356,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -1375,7 +1375,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "@local_xla//xla:literal_util", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -1415,8 +1415,8 @@ tf_kernel_library( "//tensorflow/core:lib", "@com_google_absl//absl/status", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", ], ) @@ -1438,8 +1438,8 @@ tf_kernel_library( "@local_xla//xla:shape_util", "@local_xla//xla:status_macros", "@local_xla//xla:util", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", ], ) @@ -1470,8 +1470,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:qr", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:qr", ], ) @@ -1491,10 +1491,10 @@ tf_kernel_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/types:span", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:value_inference", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:dynamic_shaped_ops", + "@local_xla//xla/hlo/builder:value_inference", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:dynamic_shaped_ops", ], ) @@ -1513,7 +1513,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "@local_xla//xla:literal_util", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -1532,9 +1532,9 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "@local_xla//xla:util", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:math", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:math", ], ) @@ -1551,7 +1551,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "@local_xla//xla:sharding_op_util", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -1571,7 +1571,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "@local_xla//xla:literal_util", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -1589,8 +1589,8 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client:xla_computation", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder:xla_computation", ], ) @@ -1620,9 +1620,9 @@ tf_kernel_library( "@local_xla//xla:shape_util", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client:xla_computation", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder:xla_computation", + "@local_xla//xla/hlo/builder/lib:constants", ], ) @@ -1646,7 +1646,7 @@ tf_kernel_library( "@local_xla//xla:literal", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -1663,8 +1663,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "//tensorflow/core/platform:errors", - "@local_xla//xla/client:padding", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:padding", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -1686,9 +1686,9 @@ tf_kernel_library( "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:client_library", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:math", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:math", ], ) @@ -1708,9 +1708,9 @@ tf_kernel_library( "@com_google_absl//absl/strings", "@local_xla//xla:literal", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:slicing", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:slicing", ], ) @@ -1739,9 +1739,9 @@ tf_kernel_library( "@local_xla//xla:shape_util", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:prng", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:prng", ], ) @@ -1762,9 +1762,9 @@ tf_kernel_library( "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:statusor", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:math", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:math", ], ) @@ -1788,8 +1788,8 @@ tf_kernel_library( "@com_google_absl//absl/strings", "@local_xla//xla:shape_util", "@local_xla//xla/client:client_library", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", ], ) @@ -1808,7 +1808,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core/platform:errors", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -1827,8 +1827,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "@local_xla//xla:literal", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", ], ) @@ -1856,8 +1856,8 @@ tf_kernel_library( "@local_xla//xla:literal", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client:xla_computation", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder:xla_computation", ], ) @@ -1873,8 +1873,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:arithmetic", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:arithmetic", ], ) @@ -1891,7 +1891,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/lib:data_format", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -1926,8 +1926,8 @@ tf_kernel_library( "//tensorflow/core/platform:errors", "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:matrix", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:matrix", ], ) @@ -1945,9 +1945,9 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core/platform:errors", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:matrix", - "@local_xla//xla/client/lib:qr", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:matrix", + "@local_xla//xla/hlo/builder/lib:qr", ], ) @@ -1967,10 +1967,10 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:arithmetic", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:prng", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:arithmetic", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:prng", ], ) @@ -1988,7 +1988,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "@local_xla//xla:literal_util", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -2005,11 +2005,11 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/lib:broadcast", "//tensorflow/compiler/tf2xla/ops:xla_ops", "@local_xla//xla:status_macros", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:arithmetic", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:loops", - "@local_xla//xla/client/lib:math", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:arithmetic", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:loops", + "@local_xla//xla/hlo/builder/lib:math", ], ) @@ -2035,11 +2035,11 @@ tf_kernel_library( "@local_xla//xla:shape_util", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client:xla_computation", - "@local_xla//xla/client/lib:arithmetic", - "@local_xla//xla/client/lib:comparators", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder:xla_computation", + "@local_xla//xla/hlo/builder/lib:arithmetic", + "@local_xla//xla/hlo/builder/lib:comparators", + "@local_xla//xla/hlo/builder/lib:constants", ], ) @@ -2061,9 +2061,9 @@ tf_kernel_library( "@local_xla//xla:literal", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:value_inference", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder:value_inference", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", ], ) @@ -2090,13 +2090,13 @@ tf_kernel_library( "@local_xla//xla:shape_util", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:padding", - "@local_xla//xla/client:value_inference", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client:xla_computation", - "@local_xla//xla/client/lib:arithmetic", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:pooling", + "@local_xla//xla/hlo/builder:padding", + "@local_xla//xla/hlo/builder:value_inference", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder:xla_computation", + "@local_xla//xla/hlo/builder/lib:arithmetic", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:pooling", ], ) @@ -2113,8 +2113,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:slicing", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:slicing", ], ) @@ -2132,8 +2132,8 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:portable_gif_internal", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:quantize", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:quantize", ], ) @@ -2151,7 +2151,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -2179,8 +2179,8 @@ tf_kernel_library( "@local_xla//xla:literal", "@local_xla//xla:literal_util", "@local_xla//xla:shape_util", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", ], ) @@ -2206,13 +2206,13 @@ tf_kernel_library( "@local_xla//xla:shape_util", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:arithmetic", - "@local_xla//xla/client/lib:comparators", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:dynamic_shaped_ops", - "@local_xla//xla/client/lib:loops", - "@local_xla//xla/client/lib:sorting", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:arithmetic", + "@local_xla//xla/hlo/builder/lib:comparators", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:dynamic_shaped_ops", + "@local_xla//xla/hlo/builder/lib:loops", + "@local_xla//xla/hlo/builder/lib:sorting", ], ) @@ -2232,7 +2232,7 @@ tf_kernel_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:errors", "@local_xla//xla:status_macros", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -2254,7 +2254,7 @@ tf_kernel_library( "@com_google_absl//absl/status", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/service:hlo_proto_cc", ], ) @@ -2276,7 +2276,7 @@ tf_kernel_library( "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -2297,7 +2297,7 @@ tf_kernel_library( "//tensorflow/core:protos_all_cc", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -2317,9 +2317,9 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "@local_xla//xla:literal_util", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:matrix", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:matrix", ], ) @@ -2336,9 +2336,9 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:matrix", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:matrix", ], ) @@ -2356,7 +2356,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "@local_xla//xla:literal", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -2379,12 +2379,12 @@ tf_kernel_library( "@local_xla//xla:status_macros", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client:xla_computation", - "@local_xla//xla/client/lib:arithmetic", - "@local_xla//xla/client/lib:comparators", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:dynamic_shaped_ops", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder:xla_computation", + "@local_xla//xla/hlo/builder/lib:arithmetic", + "@local_xla//xla/hlo/builder/lib:comparators", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:dynamic_shaped_ops", ], ) @@ -2407,7 +2407,7 @@ tf_kernel_library( "@com_google_absl//absl/strings", "@local_xla//xla:shape_util", "@local_xla//xla:status_macros", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -2443,10 +2443,10 @@ tf_kernel_library( "//tensorflow/core:framework", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:matrix", - "@local_xla//xla/client/lib:pooling", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:matrix", + "@local_xla//xla/hlo/builder/lib:pooling", ], ) @@ -2468,8 +2468,8 @@ tf_kernel_library( "//tensorflow/core/platform:errors", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:arithmetic", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:arithmetic", ], ) @@ -2489,7 +2489,7 @@ tf_kernel_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:errors", "@local_xla//xla:comparison_util", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -2510,7 +2510,7 @@ tf_kernel_library( "//tensorflow/core/platform:errors", "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -2527,8 +2527,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", ], ) @@ -2545,10 +2545,10 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "@local_xla//xla:shape_util", - "@local_xla//xla/client:xla_computation", - "@local_xla//xla/client/lib:arithmetic", - "@local_xla//xla/client/lib:comparators", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder:xla_computation", + "@local_xla//xla/hlo/builder/lib:arithmetic", + "@local_xla//xla/hlo/builder/lib:comparators", + "@local_xla//xla/hlo/builder/lib:constants", ], ) @@ -2565,8 +2565,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", - "@local_xla//xla/client/lib:slicing", - "@local_xla//xla/client/lib:tridiagonal", + "@local_xla//xla/hlo/builder/lib:slicing", + "@local_xla//xla/hlo/builder/lib:tridiagonal", ], ) @@ -2585,10 +2585,10 @@ tf_kernel_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@local_xla//xla/client:client_library", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:arithmetic", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:math", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:arithmetic", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:math", ], ) @@ -2607,7 +2607,7 @@ tf_kernel_library( "//tensorflow/core:lib", "@com_google_absl//absl/strings", "@local_xla//xla:literal", - "@local_xla//xla/client:value_inference", + "@local_xla//xla/hlo/builder:value_inference", ], ) @@ -2624,7 +2624,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:lib", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -2642,7 +2642,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "@local_xla//xla:shape_util", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -2658,7 +2658,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -2675,7 +2675,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -2723,11 +2723,11 @@ tf_kernel_library( "@local_xla//xla:shape_util", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:value_inference", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:dynamic_shaped_ops", - "@local_xla//xla/client/lib:prng", + "@local_xla//xla/hlo/builder:value_inference", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:dynamic_shaped_ops", + "@local_xla//xla/hlo/builder/lib:prng", ], ) @@ -2749,9 +2749,9 @@ tf_kernel_library( "@local_xla//xla:literal_util", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client:xla_computation", - "@local_xla//xla/client/lib:approx_topk", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder:xla_computation", + "@local_xla//xla/hlo/builder/lib:approx_topk", ], ) @@ -2778,9 +2778,9 @@ tf_kernel_library( "@local_xla//xla:shape_util", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:prng", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:prng", ], ) @@ -2828,7 +2828,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/lib:broadcast", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:lib", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -2851,9 +2851,9 @@ tf_kernel_library( "@local_xla//xla:literal", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:value_inference", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder:value_inference", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", ], ) @@ -2879,8 +2879,8 @@ tf_kernel_library( "//tensorflow/core/platform:status", "@com_google_absl//absl/status", "@local_tsl//tsl/platform:errors", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:slicing", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:slicing", ], ) @@ -2904,9 +2904,9 @@ tf_kernel_library( "@local_xla//xla:shape_util", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:matrix", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:matrix", ], ) @@ -2927,8 +2927,8 @@ tf_kernel_library( "//tensorflow/core/platform:errors", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", ], ) @@ -2950,7 +2950,7 @@ tf_kernel_library( "@local_tsl//tsl/platform:status", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -2967,7 +2967,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", - "@local_xla//xla/client/lib:self_adjoint_eig", + "@local_xla//xla/hlo/builder/lib:self_adjoint_eig", ], ) @@ -2989,9 +2989,9 @@ tf_kernel_library( "//tensorflow/core:lib", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:arithmetic", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:arithmetic", + "@local_xla//xla/hlo/builder/lib:constants", ], ) @@ -3009,7 +3009,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -3028,7 +3028,7 @@ tf_kernel_library( "//tensorflow/core:framework", "@local_xla//xla:literal_util", "@local_xla//xla:util", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -3048,7 +3048,7 @@ tf_kernel_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -3067,7 +3067,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:portable_gif_internal", "//tensorflow/core/platform:errors", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -3084,7 +3084,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -3103,8 +3103,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "@local_xla//xla:literal", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", ], ) @@ -3124,9 +3124,9 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", "@local_xla//xla:util", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:math", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:math", ], ) @@ -3147,9 +3147,9 @@ tf_kernel_library( "//tensorflow/core:protos_all_cc", "@local_xla//xla:shape_util", "@local_xla//xla:util", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:matrix", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:matrix", ], ) @@ -3171,7 +3171,7 @@ tf_kernel_library( "//tensorflow/core/platform:status", "@com_google_absl//absl/status", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -3188,8 +3188,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", - "@local_xla//xla/client:value_inference", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:value_inference", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -3217,9 +3217,9 @@ tf_kernel_library( "@local_xla//xla:shape_util", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:prng", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:prng", ], ) @@ -3237,7 +3237,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -3254,8 +3254,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:matrix", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:matrix", ], ) @@ -3270,8 +3270,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:matrix", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:matrix", ], ) @@ -3293,11 +3293,11 @@ tf_kernel_library( "@com_google_absl//absl/types:span", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client:xla_computation", - "@local_xla//xla/client/lib:arithmetic", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:math", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder:xla_computation", + "@local_xla//xla/hlo/builder/lib:arithmetic", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:math", ], ) @@ -3314,8 +3314,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:lib", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:arithmetic", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:arithmetic", ], ) @@ -3332,7 +3332,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/core:framework", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder/lib:constants", ], ) diff --git a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc index 0daa74bf112f29..87089bebe82791 100644 --- a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc b/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc index 5c3b931a17097a..95cd1f1a5c1c7d 100644 --- a/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/all_reduce_op.cc @@ -20,9 +20,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/math.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/util.h" #include "tensorflow/core/util/tensor_format.h" diff --git a/tensorflow/compiler/tf2xla/kernels/approx_topk_op.cc b/tensorflow/compiler/tf2xla/kernels/approx_topk_op.cc index 0b2711d7d0158e..19c65b653fb54e 100644 --- a/tensorflow/compiler/tf2xla/kernels/approx_topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/approx_topk_op.cc @@ -17,9 +17,9 @@ limitations under the License. #include "absl/strings/match.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "xla/client/lib/approx_topk.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/approx_topk.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/literal_util.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index 785db26c8f26d0..8d764de9b406a8 100644 --- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index 81bbbe1955642d..11cf4682e810bf 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/math.h" -#include "xla/client/lib/matrix.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/lib/matrix.h" #include "xla/xla_data.pb.h" #include "tsl/platform/tensor_float_32_utils.h" diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index b1878892a2cf79..9e4703163e0f13 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -25,9 +25,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/math.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/util.h" #include "tensorflow/core/util/tensor_format.h" diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 045e46a8d10708..b84733e7d55185 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index 25ca7b8c20688f..95d9280924a1ab 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/value_inference.h" +#include "xla/hlo/builder/value_inference.h" #include "xla/literal.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/beta_op.cc b/tensorflow/compiler/tf2xla/kernels/beta_op.cc index 5de8a491217a97..b504493b7ddb0e 100644 --- a/tensorflow/compiler/tf2xla/kernels/beta_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/beta_op.cc @@ -19,11 +19,11 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/loops.h" -#include "xla/client/lib/math.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/loops.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/status_macros.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc index b1aa537b075c7c..d0fb98c575f73d 100644 --- a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/util/tensor_format.h" diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index b917eb79c865a8..762f5a25c5f547 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -24,9 +24,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/client/client_library.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/math.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/primitive_util.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_def_builder.h" diff --git a/tensorflow/compiler/tf2xla/kernels/bincount_op.cc b/tensorflow/compiler/tf2xla/kernels/bincount_op.cc index fddbee67bd4753..374f05fa918a8c 100644 --- a/tensorflow/compiler/tf2xla/kernels/bincount_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/bincount_op.cc @@ -20,10 +20,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/comparators.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/comparators.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/shape_util.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc index 3a1f24c19c1bd3..d7fc2be632cd29 100644 --- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc index c524b8b5c59496..e3e64b14dc5302 100644 --- a/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/bucketize_op.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/case_op.cc b/tensorflow/compiler/tf2xla/kernels/case_op.cc index b1b22e047450b9..ab0a26b2f9fe37 100644 --- a/tensorflow/compiler/tf2xla/kernels/case_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/case_op.cc @@ -25,9 +25,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/dynamic_shaped_ops.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/dynamic_shaped_ops.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index 67452a81f42dfd..ca7d3280cff15d 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -19,9 +19,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/primitive_util.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_def_builder.h" diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index ec0d3d3b7d3fec..cf3dbfa2655f27 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -23,10 +23,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/prng.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/prng.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc index 2dd824dd016d89..bc06b3f952f75e 100644 --- a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/xla_builder.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc index a034c8afbe9dd9..7039fa55651a16 100644 --- a/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/clip_by_value_op.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { @@ -34,7 +34,7 @@ class ClipByValueOp : public XlaOpKernel { auto min = ctx->Input(1); auto max = ctx->Input(2); - auto shape_error = [&]() -> tensorflow::Status { + auto shape_error = [&]() -> absl::Status { return errors::InvalidArgument( "clip_value_min and clip_value_max must be either of " "the same shape as input, or a scalar. ", diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index 6c5171a2524ae2..3d515693034ae3 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index aec9a54fa610f0..a1eeea070f7f7d 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index 3d0bc0cb7ac311..a202361a90b539 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -29,9 +29,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "xla/util.h" #include "tensorflow/core/framework/bounds_check.h" @@ -116,7 +116,7 @@ xla::XlaOp ReshapeFilterForDepthwiseConvolution(const xla::Shape& filter_shape, // Performs some basic checks on ConvOpAttrs that are true for all kinds of XLA // convolutions (as currently implemented). -Status CheckConvAttrs(const ConvOpAttrs& attrs) { +absl::Status CheckConvAttrs(const ConvOpAttrs& attrs) { const int num_dims = attrs.num_spatial_dims + 2; const int attrs_strides_size = attrs.strides.size(); if (attrs_strides_size != num_dims) { @@ -153,7 +153,7 @@ Status CheckConvAttrs(const ConvOpAttrs& attrs) { // Wrapper around ConvBackpropComputeDimensions that converts from XLA shapes // to TensorShapes. -Status ConvBackpropComputeDimensionsV2XlaShapes( +absl::Status ConvBackpropComputeDimensionsV2XlaShapes( StringPiece label, int num_spatial_dims, const xla::Shape& input_shape, const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape, absl::Span dilations, const std::vector& strides, diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h index 7266b6a35c7381..ff0272f43fca9f 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/statusor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 35264b105dd566..3d876be0042949 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -24,9 +24,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/node_def_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/cross_op.cc b/tensorflow/compiler/tf2xla/kernels/cross_op.cc index 1dfd1b5f208647..a7753644312856 100644 --- a/tensorflow/compiler/tf2xla/kernels/cross_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cross_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc index 11d7ce39086c9b..c68e60c7884cc5 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc @@ -27,8 +27,8 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h index c09818549b7bd3..9be97745d12023 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "xla/client/client_library.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/util/bcast.h" diff --git a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc index 992b8bb3387366..62c2ab5202f7a3 100644 --- a/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/data_format_ops.cc @@ -20,8 +20,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/tensor_format.h" diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index 77c0eda900fd8b..93ca01039dda5f 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/util/tensor_format.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc index ef4f92cf327861..c8c1705a52f801 100644 --- a/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dequantize_op.cc @@ -18,9 +18,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/tf2xla/kernels/device_index_op.cc b/tensorflow/compiler/tf2xla/kernels/device_index_op.cc index 5b116bb666354b..d2726af1a2b10f 100644 --- a/tensorflow/compiler/tf2xla/kernels/device_index_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/device_index_op.cc @@ -21,10 +21,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/client/client_library.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/math.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index b574113195fb2a..4edc4143f1a80a 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -22,10 +22,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/lib/pooling.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/lib/pooling.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc index bae9915ae7d4e9..e5dcff94279c08 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc @@ -23,11 +23,11 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/comparators.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" #include "xla/comparison_util.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/comparators.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc index 1b69ee1e398973..f903d5fd130359 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_slice_ops.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index 96901ae32c0344..8fb19b1c1c9dae 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/einsum_op.cc b/tensorflow/compiler/tf2xla/kernels/einsum_op.cc index ed5111f624624a..d48d1fe84e67c9 100644 --- a/tensorflow/compiler/tf2xla/kernels/einsum_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/einsum_op.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.h b/tensorflow/compiler/tf2xla/kernels/elu_op.h index d5ab2e9ee90220..09c88fcbe62d54 100644 --- a/tensorflow/compiler/tf2xla/kernels/elu_op.h +++ b/tensorflow/compiler/tf2xla/kernels/elu_op.h @@ -15,8 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_ELU_OP_H_ #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_ELU_OP_H_ -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" namespace xla { XlaOp Elu(XlaOp x); diff --git a/tensorflow/compiler/tf2xla/kernels/empty_op.cc b/tensorflow/compiler/tf2xla/kernels/empty_op.cc index 98b23ee7882c62..decc24126d0f10 100644 --- a/tensorflow/compiler/tf2xla/kernels/empty_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/empty_op.cc @@ -21,8 +21,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/ensure_shape_op.cc b/tensorflow/compiler/tf2xla/kernels/ensure_shape_op.cc index 08d7469a8b9035..11256663b59e97 100644 --- a/tensorflow/compiler/tf2xla/kernels/ensure_shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/ensure_shape_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc index 1b5e1b63c2cca2..ded81d938d2baa 100644 --- a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -21,9 +21,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape_util.h" #include "xla/util.h" #include "tensorflow/core/framework/kernel_shape_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc b/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc index b8f7b19a3d4b1f..52412ee73f9ce8 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_param_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" +#include "xla/hlo/builder/lib/constants.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index 1f84172f766620..2fa32e1112f8e1 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -20,8 +20,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc index ac7f4f6cba02ec..a9673934262d1f 100644 --- a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "xla/util.h" #include "tensorflow/core/framework/bounds_check.h" diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index ffd9433881b157..89824e7a3313b5 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -21,8 +21,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/value_inference.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/value_inference.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/fused_conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/fused_conv_ops.cc index 2640442bf89661..96aef937421f6d 100644 --- a/tensorflow/compiler/tf2xla/kernels/fused_conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fused_conv_ops.cc @@ -21,8 +21,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/math.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/math.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/util/tensor_format.h" @@ -116,7 +116,7 @@ class FusedConv2DInt8Op : public XlaOpKernel { : ActivationMode::kRelu; } - Status DoCompile(XlaOpKernelContext* ctx) { + absl::Status DoCompile(XlaOpKernelContext* ctx) { XlaOp conv_input = ctx->Input(0); XlaOp filter = ctx->Input(1); XlaOp bias = ctx->Input(2); diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 8087b271ba5fe2..08285e0bccbc18 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -26,8 +26,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/status_macros.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -35,11 +35,11 @@ limitations under the License. namespace tensorflow { -Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, - const xla::XlaOp& indices, const TensorShape& indices_shape, - int64_t axis, bool indices_are_nd, DataType dtype, - DataType index_type, xla::XlaBuilder* builder, - xla::XlaOp* gather_output) { +absl::Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, + const xla::XlaOp& indices, + const TensorShape& indices_shape, int64_t axis, + bool indices_are_nd, DataType dtype, DataType index_type, + xla::XlaBuilder* builder, xla::XlaOp* gather_output) { // There is no deep reason why we need this precondition, but this is the only // combination that is used and tested today. CHECK(!indices_are_nd || axis == 0); @@ -160,10 +160,11 @@ Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, return absl::OkStatus(); } -Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context, - const xla::XlaOp input, - const TensorShape& input_shape, - int batch_dims, xla::XlaOp* gather_output) { +absl::Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context, + const xla::XlaOp input, + const TensorShape& input_shape, + int batch_dims, + xla::XlaOp* gather_output) { auto indices = context->Input(1); auto indices_shape = context->InputShape(1); diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h index fa160886eb5140..8a8a66669dcf4c 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "xla/client/client_library.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/util/bcast.h" @@ -33,19 +33,20 @@ namespace tensorflow { // If `indices_are_nd` is true, the last dimension of `indices` are treated as // a multidimensional index values. Otherwise, `indices` is treated as a tensor // of scalar indices. -Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, - const xla::XlaOp& indices, const TensorShape& indices_shape, - int64_t axis, bool indices_are_nd, DataType dtype, - DataType index_type, xla::XlaBuilder* builder, - xla::XlaOp* gather_output); +absl::Status XlaGather(const xla::XlaOp& input, const TensorShape& input_shape, + const xla::XlaOp& indices, + const TensorShape& indices_shape, int64_t axis, + bool indices_are_nd, DataType dtype, DataType index_type, + xla::XlaBuilder* builder, xla::XlaOp* gather_output); // The implementation of Gather and ResourceGather through XLA. Uses `input` as // the input instead of context->input(0) in order to allow ResourceGather to // handle obtaining the data from the ResourceVariable. -Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context, - xla::XlaOp input, - const TensorShape& input_shape, - int batch_dims, xla::XlaOp* gather_output); +absl::Status XlaGatherWithBatchDimsOpImpl(XlaOpKernelContext* context, + xla::XlaOp input, + const TensorShape& input_shape, + int batch_dims, + xla::XlaOp* gather_output); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_GATHER_OP_HELPERS_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc b/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc index 3427b0b57227fd..305557cd773faa 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_scatter_ops.cc @@ -16,7 +16,7 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/if_op.cc b/tensorflow/compiler/tf2xla/kernels/if_op.cc index 853a9a5afe39c0..2e746b7d59d930 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_op.cc @@ -27,8 +27,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_resource.h" -#include "xla/client/lib/dynamic_shaped_ops.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/dynamic_shaped_ops.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "tensorflow/core/common_runtime/function_body.h" @@ -104,10 +104,10 @@ static absl::StatusOr PopulateTensorArrayGradients( } // Checks that shapes matches on both sides of the conditional. -static Status ValidateShapes(XlaOpKernelContext* ctx, - const XlaCompiler::CompilationResult& then_result, - const XlaCompiler::CompilationResult& else_result, - std::vector& output_shapes) { +static absl::Status ValidateShapes( + XlaOpKernelContext* ctx, const XlaCompiler::CompilationResult& then_result, + const XlaCompiler::CompilationResult& else_result, + std::vector& output_shapes) { // Check that both branches have identical input shapes. if (then_result.xla_input_shapes.size() != 1) { return errors::FailedPrecondition("Expected one input shape"); diff --git a/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc b/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc index ecb6236a6e9bb6..f044eddd0e4218 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc +++ b/tensorflow/compiler/tf2xla/kernels/if_while_utils.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "xla/client/value_inference.h" +#include "xla/hlo/builder/value_inference.h" #include "xla/literal.h" #include "tensorflow/core/common_runtime/function_body.h" #include "tensorflow/core/framework/attr_value.pb.h" @@ -91,10 +91,10 @@ absl::InlinedVector ConvertCompileTimeConstArgumentsToConst( return resolved_constant_idxs; } -Status FindMustBeConstNodes(XlaOpKernelContext* ctx, - const NameAttrList& func_name, - std::vector* must_be_const_nodes, - const FunctionBody** body) { +absl::Status FindMustBeConstNodes(XlaOpKernelContext* ctx, + const NameAttrList& func_name, + std::vector* must_be_const_nodes, + const FunctionBody** body) { TF_RETURN_IF_ERROR(ctx->compiler()->FindFunctionBody(func_name, body)); must_be_const_nodes->resize((*body)->graph->num_node_ids(), false); return BackwardsConstAnalysis(*((*body)->graph), diff --git a/tensorflow/compiler/tf2xla/kernels/if_while_utils.h b/tensorflow/compiler/tf2xla/kernels/if_while_utils.h index 831df8a6222ff8..eb103954ac8683 100644 --- a/tensorflow/compiler/tf2xla/kernels/if_while_utils.h +++ b/tensorflow/compiler/tf2xla/kernels/if_while_utils.h @@ -43,10 +43,10 @@ absl::InlinedVector ConvertCompileTimeConstArgumentsToConst( // Find and populate `must_be_const_nodes` and `body` of the function // corresponding to the kernel with context `ctx` with name `func_name`. -Status FindMustBeConstNodes(XlaOpKernelContext* ctx, - const NameAttrList& func_name, - std::vector* must_be_const_nodes, - const FunctionBody** body); +absl::Status FindMustBeConstNodes(XlaOpKernelContext* ctx, + const NameAttrList& func_name, + std::vector* must_be_const_nodes, + const FunctionBody** body); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc index f502bcf4244cfd..2213074a89d42e 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -24,13 +24,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/comparators.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/dynamic_shaped_ops.h" -#include "xla/client/lib/loops.h" -#include "xla/client/lib/sorting.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/comparators.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/dynamic_shaped_ops.h" +#include "xla/hlo/builder/lib/loops.h" +#include "xla/hlo/builder/lib/sorting.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc index e1be96ea80e3c3..5d8981dd5e6e3d 100644 --- a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -26,9 +26,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/primitive_util.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc b/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc index 87fa296c198564..a3d801a1a32819 100644 --- a/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/in_topk_op.cc @@ -16,9 +16,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index ee34f0d0ec1a04..357ab3e9b0783d 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -20,8 +20,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc index 60bfb71895d8fa..8b2e29e29ca8ec 100644 --- a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc index 17a22f18607f7d..227fc71821fc6b 100644 --- a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc +++ b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc @@ -45,7 +45,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/layout_util.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" @@ -235,7 +235,7 @@ static absl::StatusOr TensorFromProto(const TensorProto& proto) { return out; } -Status LightOutsideCompilationOp::CompileToCustomCallCallingTfKernel( +absl::Status LightOutsideCompilationOp::CompileToCustomCallCallingTfKernel( int graph_def_version, const NodeDef& node_def, XlaOpKernelContext* ctx) { const OpRegistrationData* data = OpRegistry::Global()->LookUp(node_def.op()); int num_inputs = ctx->num_inputs(); @@ -482,9 +482,9 @@ class TfCallbackDevice : public DeviceBase { #endif } - Status ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device, - DeviceContext* dc, - Allocator* allocator) override { + absl::Status ReinitializeGpuDevice(OpKernelContext* context, + PerOpGpuDevice* device, DeviceContext* dc, + Allocator* allocator) override { #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM auto concrete_device = static_cast(device); concrete_device->Reinitialize( @@ -534,9 +534,10 @@ class TfCallbackDevice : public DeviceBase { // Populate the output with actual dimensions of the allocated shapes. // // Populates the vector on the host and then copies it over to the GPU. -Status PopulateMetadataBufferIfNeeded(OpKernelContext& ctx, - const TfCallbackData& callback_data, - void** buffers, se::Stream* stream) { +absl::Status PopulateMetadataBufferIfNeeded(OpKernelContext& ctx, + const TfCallbackData& callback_data, + void** buffers, + se::Stream* stream) { for (int i = 0; i < ctx.num_outputs(); i++) { if (callback_data.outputs(i).is_dynamically_padded()) { Tensor* allocated = ctx.mutable_output(i); @@ -572,15 +573,15 @@ class FakeDeviceContext : public DeviceContext { se::Stream* stream_; }; -Status CallTfKernel(void* stream_handle, void** buffers, const char* opaque, - int opaque_len) { +absl::Status CallTfKernel(void* stream_handle, void** buffers, + const char* opaque, int opaque_len) { // Look up the platform only once, for a small performance gain. - static Status* platform_status = nullptr; + static absl::Status* platform_status = nullptr; static se::Platform* platform = [&]() -> se::Platform* { absl::StatusOr p = se::PlatformManager::PlatformWithName(PLATFORM); if (!p.ok()) { - platform_status = new Status(p.status()); + platform_status = new absl::Status(p.status()); return nullptr; } return *p; @@ -718,7 +719,7 @@ Status CallTfKernel(void* stream_handle, void** buffers, const char* opaque, void GenericTfCallback(void* stream_handle, void** buffers, const char* opaque, int opaque_len, XlaCustomCallStatus* status) { - Status s = CallTfKernel(stream_handle, buffers, opaque, opaque_len); + absl::Status s = CallTfKernel(stream_handle, buffers, opaque, opaque_len); if (!s.ok()) { auto msg = s.message(); XlaCustomCallStatusSetFailure(status, msg.data(), msg.size()); diff --git a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.h b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.h index 015929ab08f41f..f9c42e03eb6cf7 100644 --- a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.h +++ b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.h @@ -55,11 +55,11 @@ class LightOutsideCompilationOp : public XlaOpKernel { } private: - Status CompileToCustomCallCallingTfKernel(int graph_def_version, - const NodeDef& node_def, - XlaOpKernelContext* ctx); - static Status CallTfKernel(void* stream_handle, void** buffers, - const char* opaque, int opaque_len); + absl::Status CompileToCustomCallCallingTfKernel(int graph_def_version, + const NodeDef& node_def, + XlaOpKernelContext* ctx); + static absl::Status CallTfKernel(void* stream_handle, void** buffers, + const char* opaque, int opaque_len); NodeDef def_; int graph_def_version_; diff --git a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc index b420422911158f..eeb8617a61a39e 100644 --- a/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/listdiff_op.cc @@ -24,7 +24,7 @@ limitations under the License. #include "absl/status/status.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -57,7 +57,7 @@ class ListDiffOp : public XlaOpKernel { DataType val_type = context->expected_output_dtype(0); DataType idx_type = context->expected_output_dtype(1); - Status status; + absl::Status status; switch (val_type) { case DT_INT32: status = ListDiffWithIndexType(context, idx_type); @@ -77,7 +77,7 @@ class ListDiffOp : public XlaOpKernel { private: template - Status ListDiff(XlaOpKernelContext* context) { + absl::Status ListDiff(XlaOpKernelContext* context) { std::vector x_input, y_input; TF_RETURN_IF_ERROR(context->ConstantInputAsIntVector(0, &x_input)); TF_RETURN_IF_ERROR(context->ConstantInputAsIntVector(1, &y_input)); @@ -107,7 +107,8 @@ class ListDiffOp : public XlaOpKernel { } template - Status ListDiffWithIndexType(XlaOpKernelContext* context, DataType idx_type) { + absl::Status ListDiffWithIndexType(XlaOpKernelContext* context, + DataType idx_type) { switch (idx_type) { case DT_INT32: return ListDiff(context); diff --git a/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc b/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc index 204d60e1ed8992..46e46f6d8b3d32 100644 --- a/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/lower_upper_bound_ops.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" #include "xla/comparison_util.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc index 8eca5fd0f650d9..b4ea95e04a43b8 100644 --- a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/padding.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index ed198538049ef3..bbc77d331ebe4d 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -20,8 +20,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc index ae379f845adfbf..af0f84aa2e1254 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_diag_ops.cc @@ -24,9 +24,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_inverse_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_inverse_op.cc index 229a48269bfbd0..4981751c489fa7 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_inverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_inverse_op.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/lib/qr.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/lib/qr.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_solve_op.cc index b85942d6888a02..9b5530c569dd27 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_solve_op.cc @@ -15,9 +15,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/lib/qr.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/lib/qr.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc index 7c23bc3adce0ff..91d2d344b07ad0 100644 --- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index 4d978042f31ebc..3556900f49b670 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -16,8 +16,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/shape.h" #include "xla/status_macros.h" diff --git a/tensorflow/compiler/tf2xla/kernels/next_after_op.cc b/tensorflow/compiler/tf2xla/kernels/next_after_op.cc index ce7ed4c8bd8a79..42cd8dc213f1a9 100644 --- a/tensorflow/compiler/tf2xla/kernels/next_after_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/next_after_op.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/math.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" diff --git a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc index 0220fae6c541da..e41db50beeec48 100644 --- a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc index 426cd36bcf918c..a096b8f2a23e02 100644 --- a/tensorflow/compiler/tf2xla/kernels/pack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index 2fa6d35ff9ec88..fc19df334a4a75 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/value_inference.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/value_inference.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 22c40422be9e3b..e8e6ca0beb361e 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -28,13 +28,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/pooling.h" -#include "xla/client/padding.h" -#include "xla/client/value_inference.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/pooling.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/value_inference.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/xla_data.pb.h" @@ -54,7 +54,7 @@ namespace tensorflow { namespace { template -static Status ValidateKernelSizes(const T& ksizes) { +static absl::Status ValidateKernelSizes(const T& ksizes) { for (size_t i = 0; i < ksizes.size(); ++i) { if (ksizes[i] <= 0) { return errors::InvalidArgument( @@ -66,7 +66,7 @@ static Status ValidateKernelSizes(const T& ksizes) { } template -static Status ValidateStrides(const T& strides) { +static absl::Status ValidateStrides(const T& strides) { for (size_t i = 0; i < strides.size(); ++i) { if (strides[i] <= 0) { return errors::InvalidArgument( diff --git a/tensorflow/compiler/tf2xla/kernels/qr_op.cc b/tensorflow/compiler/tf2xla/kernels/qr_op.cc index 0839b858ae0031..6120903fe9c991 100644 --- a/tensorflow/compiler/tf2xla/kernels/qr_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/qr_op.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/qr.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/qr.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc index f6cec3fa07e369..de7247399567e3 100644 --- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc @@ -21,10 +21,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/math.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/xla_data.pb.h" diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index dfabf80b644410..ab83bbbe7120b3 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -28,10 +28,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/dynamic_shaped_ops.h" -#include "xla/client/value_inference.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/dynamic_shaped_ops.h" +#include "xla/hlo/builder/value_inference.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops_util.cc b/tensorflow/compiler/tf2xla/kernels/random_ops_util.cc index ffc37e7659e86c..62c7c00592efdf 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops_util.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops_util.cc @@ -24,9 +24,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/rng_converter_utils.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/prng.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/prng.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/primitive_util.h" #include "xla/shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops_util.h b/tensorflow/compiler/tf2xla/kernels/random_ops_util.h index 0f6464c7ce7925..11ff44602f1900 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops_util.h +++ b/tensorflow/compiler/tf2xla/kernels/random_ops_util.h @@ -21,8 +21,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "xla/client/lib/prng.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/prng.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 576c21cdb57419..5f911018c244b5 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -24,8 +24,8 @@ limitations under the License. #include "absl/status/status.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" @@ -97,7 +97,8 @@ class MaxOp : public XlaReductionOp { OP_REQUIRES_OK(ctx, PrimitiveTypeCheck(xla_reduction_type_)); } - static Status PrimitiveTypeCheck(xla::PrimitiveType xla_reduction_type) { + static absl::Status PrimitiveTypeCheck( + xla::PrimitiveType xla_reduction_type) { if (xla_reduction_type == xla::C64 || xla_reduction_type == xla::C128 || xla_reduction_type == xla::TUPLE || xla_reduction_type == xla::OPAQUE_TYPE) { diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h index 63fe536aa4666b..0c7e87015f940a 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h @@ -21,7 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.pb.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 60042e2b0e05bd..d1933ff4cff27c 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -25,8 +25,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/literal.h" #include "xla/shape_util.h" #include "xla/xla_data.pb.h" diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index d4b5339285bef1..f274b271596ff5 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -21,8 +21,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.h b/tensorflow/compiler/tf2xla/kernels/relu_op.h index a968194211b29a..b980df777ee734 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.h +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.h @@ -15,8 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_RELU_OP_H_ #define TENSORFLOW_COMPILER_TF2XLA_KERNELS_RELU_OP_H_ -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" namespace xla { XlaOp Relu(XlaOp x); diff --git a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc index 106df7cd62e644..e4b08184ba5c43 100644 --- a/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/resampler_ops.cc @@ -21,9 +21,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index 31a6b689811e74..df67f3f4938356 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -20,9 +20,9 @@ limitations under the License. #include "absl/log/log.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/value_inference.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/value_inference.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index 9a5cfbdc37348b..5637d9091dd2fc 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc index 93140c404d6e5e..17b0f35fad3b81 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape_util.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/roll_op.cc b/tensorflow/compiler/tf2xla/kernels/roll_op.cc index 5cead12fd8baf3..870c3092865367 100644 --- a/tensorflow/compiler/tf2xla/kernels/roll_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/roll_op.cc @@ -18,9 +18,9 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc index b4b33518881d84..1444abda838008 100644 --- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -20,8 +20,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc index f98ae95b58151c..29281a7696e589 100644 --- a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" @@ -34,10 +34,10 @@ namespace { // Check whether updates.shape = indices.shape[:batch_dim] + // buffer_shape[num_index_dims:] -Status ValidateUpdateShape(const TensorShape& buffer_shape, - const TensorShape& indices_shape, - const TensorShape& updates_shape, - bool broadcast_scalar_update) { +absl::Status ValidateUpdateShape(const TensorShape& buffer_shape, + const TensorShape& indices_shape, + const TensorShape& updates_shape, + bool broadcast_scalar_update) { if (indices_shape.dims() < 1) { return errors::InvalidArgument( "indices shape must have >= 1 dimension; got ", diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index 071644b20ac9d9..73c5a9c6ed98e6 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -19,9 +19,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/value_inference.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/value_inference.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc index 57278a1cbddc9d..fc9e96939b2c38 100644 --- a/tensorflow/compiler/tf2xla/kernels/select_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc index 569a92d1f7fa3b..e1e93d614286a3 100644 --- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index 81af04fc6adcf4..60a4a1a5bc62d1 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -18,9 +18,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/value_inference.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/value_inference.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/primitive_util.h" #include "xla/xla_data.pb.h" diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 2f1af1e40254be..b721011f512624 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -30,8 +30,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.cc b/tensorflow/compiler/tf2xla/kernels/shape_util.cc index e09b7dacb82824..f217bc09ec79e1 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_util.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.cc @@ -28,8 +28,8 @@ limitations under the License. namespace tensorflow { -Status TensorShapeToConstant(const TensorShape& input_shape, - Tensor* shape_constant) { +absl::Status TensorShapeToConstant(const TensorShape& input_shape, + Tensor* shape_constant) { const int dims = input_shape.dims(); if (shape_constant->dtype() == DT_INT32) { auto vec = shape_constant->vec(); diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.h b/tensorflow/compiler/tf2xla/kernels/shape_util.h index edef94037d38e0..4ec37b1fe7cfda 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_util.h +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.h @@ -28,8 +28,8 @@ namespace tensorflow { // // The input TensorShape input_shape is used to populate the elements of // shape_constant, which is modified in place. -Status TensorShapeToConstant(const TensorShape& input_shape, - Tensor* shape_constant); +absl::Status TensorShapeToConstant(const TensorShape& input_shape, + Tensor* shape_constant); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/sharding_op.cc b/tensorflow/compiler/tf2xla/kernels/sharding_op.cc index 5cad1625a46519..a56dd7ed74791c 100644 --- a/tensorflow/compiler/tf2xla/kernels/sharding_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/sharding_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/sharding_op_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" diff --git a/tensorflow/compiler/tf2xla/kernels/sharding_util_ops.cc b/tensorflow/compiler/tf2xla/kernels/sharding_util_ops.cc index 8192283016c026..63bdacfb795665 100644 --- a/tensorflow/compiler/tf2xla/kernels/sharding_util_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sharding_util_ops.cc @@ -23,8 +23,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" @@ -43,10 +43,11 @@ constexpr absl::string_view kNumSplitsAttrName = "num_splits"; constexpr absl::string_view kNumConcatsAttrName = "num_concats"; template -Status GetAndValidateAttributes(OpKernelConstruction* ctx, - std::vector& num_partitions, - int& num_slices, std::vector& paddings, - bool& has_paddings) { +absl::Status GetAndValidateAttributes(OpKernelConstruction* ctx, + std::vector& num_partitions, + int& num_slices, + std::vector& paddings, + bool& has_paddings) { absl::string_view num_partitions_attr_name = Split ? kNumSplitsAttrName : kNumConcatsAttrName; TF_RETURN_IF_ERROR(ctx->GetAttr(num_partitions_attr_name, &num_partitions)); @@ -140,9 +141,9 @@ class XlaSplitNDBaseOp : public XlaOpKernel { } protected: - Status CompileInternal(XlaOpKernelContext* ctx, const xla::XlaOp input, - const TensorShape& input_shape, - const DataType input_dtype) { + absl::Status CompileInternal(XlaOpKernelContext* ctx, const xla::XlaOp input, + const TensorShape& input_shape, + const DataType input_dtype) { xla::PrimitiveType type; TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(input_dtype, &type)); @@ -399,10 +400,10 @@ class XlaConcatNDBaseOp : public XlaOpKernel { DataType dtype_; private: - Status GetInputsAndOutputShape(XlaOpKernelContext* ctx, - std::vector& input_handles, - std::vector& input_shapes, - std::vector& output_shape) { + absl::Status GetInputsAndOutputShape(XlaOpKernelContext* ctx, + std::vector& input_handles, + std::vector& input_shapes, + std::vector& output_shape) { TF_RETURN_IF_ERROR(ctx->InputList("inputs", &input_handles, &input_shapes)); const TensorShape& slice_shape = input_shapes[0]; diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index e4d7adca29debf..35c936d5fb88db 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -21,10 +21,10 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/dynamic_shaped_ops.h" -#include "xla/client/value_inference.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/dynamic_shaped_ops.h" +#include "xla/hlo/builder/value_inference.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index c50cceb41ea546..406b79d9981846 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -23,9 +23,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc index d858a5f837f9f4..cf2d49ff5bdae2 100644 --- a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc @@ -16,8 +16,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/comparators.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/comparators.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index a591aae75f6a5f..858233c28c8d03 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/shape_util.h" #include "xla/xla_data.pb.h" diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index a0bbb14b33c548..2648c0b077e689 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/data_format.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/platform/errors.h" diff --git a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc index ee9416185e1677..f3afba664bedbe 100644 --- a/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/sparse_to_dense_op.cc @@ -18,8 +18,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/scatter.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/value_inference.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/value_inference.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index b67cac26ea7468..ebef4cd81b2687 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" diff --git a/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc b/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc index f9f070e8d228d9..496440e9cafbf3 100644 --- a/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/spmd_manual_sharding_ops.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index c5de7862c80d67..d8bd987232b569 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_resource.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" @@ -41,8 +41,8 @@ limitations under the License. namespace tensorflow { namespace { -Status GetStackShape(xla::XlaBuilder* builder, XlaResource* resource, - TensorShape* stack_shape) { +absl::Status GetStackShape(xla::XlaBuilder* builder, XlaResource* resource, + TensorShape* stack_shape) { auto shape_or_status = builder->GetShape(resource->value()); if (!shape_or_status.ok()) { return shape_or_status.status(); @@ -63,8 +63,9 @@ Status GetStackShape(xla::XlaBuilder* builder, XlaResource* resource, // // TODO(phawkins): consider changing the API of the stack operators to // allow an optional element shape at stack construction time. -Status MaybeInitializeStack(xla::XlaBuilder* builder, XlaResource* resource, - DataType dtype, const TensorShape& elem_shape) { +absl::Status MaybeInitializeStack(xla::XlaBuilder* builder, + XlaResource* resource, DataType dtype, + const TensorShape& elem_shape) { if (resource->type() != dtype) { return errors::InvalidArgument( "Stack dtype is ", DataTypeString(resource->type()), diff --git a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc index cc7684d0d6b474..01a44c9d734448 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateful_random_ops.cc @@ -28,9 +28,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/prng.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/prng.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/primitive_util.h" #include "xla/shape.h" #include "xla/util.h" @@ -138,7 +138,8 @@ int64_t GetMinStateSize(xla::RandomAlgorithm alg) { } } -Status CheckStateShape(xla::RandomAlgorithm alg, const TensorShape& shape) { +absl::Status CheckStateShape(xla::RandomAlgorithm alg, + const TensorShape& shape) { if (shape.dims() != 1) { return errors::InvalidArgument( "RNG state must have one and only one dimension, not ", shape.dims()); @@ -203,7 +204,7 @@ xla::XlaOp CounterAndKeyToVariable(xla::RandomAlgorithm alg, xla::XlaOp state, // A helper function containing the common part of several kernels below. // Precondition: 'algorithm' and 'shape' are compile-time constants. -Status CompileImpl( +absl::Status CompileImpl( XlaOpKernelContext* ctx, int state_input_idx, int alg_input_idx, int shape_input_idx, std::functionConstantInputAsShape(0, &static_shape); + absl::Status status = ctx->ConstantInputAsShape(0, &static_shape); if (status.ok()) { ctx->SetOutput(0, uniform); return; @@ -280,7 +280,7 @@ class StatelessRandomNormalOp : public XlaOpKernel { // If the input shape is constant, no need to set dimension sizes. // TODO(hinsu): Simplify this once MLIR bridge can handle bounded types. TensorShape static_shape; - Status status = ctx->ConstantInputAsShape(0, &static_shape); + absl::Status status = ctx->ConstantInputAsShape(0, &static_shape); if (status.ok()) { ctx->SetOutput(0, normal); return; diff --git a/tensorflow/compiler/tf2xla/kernels/stochastic_cast_op.cc b/tensorflow/compiler/tf2xla/kernels/stochastic_cast_op.cc index 5af51f6cff0466..94349c2e7f4c0c 100644 --- a/tensorflow/compiler/tf2xla/kernels/stochastic_cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/stochastic_cast_op.cc @@ -21,8 +21,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index d45176aad55630..2a31e5f15fe5e4 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -27,10 +27,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/dynamic_shaped_ops.h" -#include "xla/client/value_inference.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/dynamic_shaped_ops.h" +#include "xla/hlo/builder/value_inference.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/shape.h" #include "xla/xla_data.pb.h" diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 5f5d17c0aa1f0f..25110df1c7d733 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_resource.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/status_macros.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" @@ -52,9 +52,9 @@ namespace { // the TensorArray with elements of `elem_shape`. For both initialized and // uninitialized TensorArrays, checks that the tensor has a type compatible with // 'dtype' and shape compatible with 'elem_shape'. -Status MaybeInitializeTensorArray(xla::XlaBuilder* builder, - XlaResource* resource, DataType dtype, - const TensorShape& elem_shape) { +absl::Status MaybeInitializeTensorArray(xla::XlaBuilder* builder, + XlaResource* resource, DataType dtype, + const TensorShape& elem_shape) { if (resource->kind() != XlaResource::kTensorArray) { return errors::InvalidArgument("Unexpected non-TensorArray resource"); } @@ -94,9 +94,9 @@ Status MaybeInitializeTensorArray(xla::XlaBuilder* builder, // Checks that the TensorArray 'resource' has been initialized, and has type // 'dtype'. Sets 'shape' to the shape -Status CheckTensorArrayIsInitialized(const string& op_name, - const XlaResource* resource, - DataType dtype) { +absl::Status CheckTensorArrayIsInitialized(const string& op_name, + const XlaResource* resource, + DataType dtype) { if (resource->kind() != XlaResource::kTensorArray) { return errors::InvalidArgument( "Unexpected non-TensorArray resource passed to ", op_name); @@ -114,8 +114,8 @@ Status CheckTensorArrayIsInitialized(const string& op_name, return absl::OkStatus(); } -Status GetTensorArrayShape(const XlaResource* resource, - xla::XlaBuilder* builder, TensorShape* shape) { +absl::Status GetTensorArrayShape(const XlaResource* resource, + xla::XlaBuilder* builder, TensorShape* shape) { *shape = resource->shape(); shape->InsertDim(0, resource->max_array_size()); return absl::OkStatus(); @@ -321,7 +321,7 @@ class TensorArrayGatherOp : public XlaOpKernel { // Look for the case where the gather takes a simple slice from the // tensor array (0, 1, 2, 3, 4, ..., N) std::vector const_indices; - Status status = ctx->ConstantInputAsIntVector(1, &const_indices); + absl::Status status = ctx->ConstantInputAsIntVector(1, &const_indices); if (status.ok()) { bool gather_is_dense_slice = true; for (auto i = 0; i < const_indices.size(); i++) { @@ -393,7 +393,7 @@ class TensorArrayScatterOp : public XlaOpKernel { // tensor array implementation allows for this to be a straight addition. bool scatter_all_elements_in_order = false; std::vector const_indices; - Status status = ctx->ConstantInputAsIntVector(1, &const_indices); + absl::Status status = ctx->ConstantInputAsIntVector(1, &const_indices); if (status.ok() && num_indices == value_shape.dim_size(0)) { scatter_all_elements_in_order = true; for (auto i = 0; i < num_indices; i++) { diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index 0020da36e87301..76257c25a932c6 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -27,8 +27,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/value_inference.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/value_inference.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" @@ -111,9 +111,10 @@ REGISTER_XLA_OP(Name("TensorListLength").IsMetadataOp(), TensorListLengthOp); // "input" is the shape input for EmptyTensorList/TensorListReserve ops. // If "input" is a compile time constant and not "unknown rank" (-1), return // its value in "*shape". -Status TryGetElementShapeFromInput(XlaOpKernelContext* ctx, xla::XlaOp input, - xla::PrimitiveType dtype, bool* got_shape, - xla::Shape* shape) { +absl::Status TryGetElementShapeFromInput(XlaOpKernelContext* ctx, + xla::XlaOp input, + xla::PrimitiveType dtype, + bool* got_shape, xla::Shape* shape) { auto is_compile_time_constant_or = input.builder()->IsConstant(input); TF_RETURN_IF_ERROR(is_compile_time_constant_or.status()); diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc index a2d7137dfc3359..37d0ae44178998 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc @@ -21,7 +21,7 @@ limitations under the License. #include "absl/status/status.h" #include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -119,13 +119,13 @@ bool IsTensorListInput(XlaOpKernelContext* ctx, int index) { return ctx->InputExpression(index).kind() == XlaExpression::Kind::kTensorList; } -Status IsTensorListInitialized(xla::XlaOp list, bool* is_initialized) { +absl::Status IsTensorListInitialized(xla::XlaOp list, bool* is_initialized) { TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list)); *is_initialized = list_shape.IsTuple(); return absl::OkStatus(); } -Status IsNestedTensorList(xla::XlaOp list, bool* is_nested_list) { +absl::Status IsNestedTensorList(xla::XlaOp list, bool* is_nested_list) { bool is_initialized; TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized)); if (!is_initialized) { @@ -136,14 +136,15 @@ Status IsNestedTensorList(xla::XlaOp list, bool* is_nested_list) { return absl::OkStatus(); } -Status BuildNonNestedTensorList(xla::XlaOp buffer, xla::XlaOp push_index, - xla::XlaOp* output_list) { +absl::Status BuildNonNestedTensorList(xla::XlaOp buffer, xla::XlaOp push_index, + xla::XlaOp* output_list) { TF_RET_CHECK(buffer.builder()); *output_list = xla::Tuple(buffer.builder(), {buffer, push_index}); return absl::OkStatus(); } -Status GetTensorListBufferShape(xla::XlaOp list, xla::Shape* buffer_shape) { +absl::Status GetTensorListBufferShape(xla::XlaOp list, + xla::Shape* buffer_shape) { bool is_initialized; TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized)); if (!is_initialized) { @@ -154,7 +155,7 @@ Status GetTensorListBufferShape(xla::XlaOp list, xla::Shape* buffer_shape) { return absl::OkStatus(); } -Status GetTensorListBuffer(xla::XlaOp list, xla::XlaOp* buffer) { +absl::Status GetTensorListBuffer(xla::XlaOp list, xla::XlaOp* buffer) { bool is_initialized; TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized)); if (!is_initialized) { @@ -164,7 +165,7 @@ Status GetTensorListBuffer(xla::XlaOp list, xla::XlaOp* buffer) { return absl::OkStatus(); } -Status GetTensorListPushIndex(xla::XlaOp list, xla::XlaOp* push_index) { +absl::Status GetTensorListPushIndex(xla::XlaOp list, xla::XlaOp* push_index) { bool is_initialized; TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized)); if (!is_initialized) { @@ -176,8 +177,8 @@ Status GetTensorListPushIndex(xla::XlaOp list, xla::XlaOp* push_index) { return absl::OkStatus(); } -Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index, - xla::XlaOp* result) { +absl::Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index, + xla::XlaOp* result) { bool is_initialized; TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized)); if (!is_initialized) { @@ -210,9 +211,9 @@ xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b, } } -Status GetLeadingDimForTensorList(xla::XlaOp list, int64_t* leading_dim, - bool* leading_dim_is_dynamic, - xla::XlaOp* leading_dim_dynamic_size) { +absl::Status GetLeadingDimForTensorList(xla::XlaOp list, int64_t* leading_dim, + bool* leading_dim_is_dynamic, + xla::XlaOp* leading_dim_dynamic_size) { bool is_initialized; TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized)); TF_ASSIGN_OR_RETURN(xla::Shape list_shape, list.builder()->GetShape(list)); @@ -230,7 +231,7 @@ Status GetLeadingDimForTensorList(xla::XlaOp list, int64_t* leading_dim, return absl::OkStatus(); } -Status GetTensorListShapeFromElementTensorListShape( +absl::Status GetTensorListShapeFromElementTensorListShape( const xla::Shape& element_tensor_list_shape, int64_t leading_dim, bool leading_dim_is_dynamic, xla::Shape* tensor_list_shape) { std::vector shapes; @@ -252,10 +253,10 @@ Status GetTensorListShapeFromElementTensorListShape( return absl::OkStatus(); } -Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape, - int64_t leading_dim, - bool leading_dim_is_dynamic, - xla::Shape* tensor_list_shape) { +absl::Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape, + int64_t leading_dim, + bool leading_dim_is_dynamic, + xla::Shape* tensor_list_shape) { if (!element_shape.IsArray()) { return errors::InvalidArgument( "GetTensorListShapeFromElementShape() only supports normal tensor " @@ -275,7 +276,7 @@ Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape, return absl::OkStatus(); } -Status CreateZerosTensorListWithShape( +absl::Status CreateZerosTensorListWithShape( xla::XlaBuilder* b, const xla::Shape& list_shape, const std::vector>& dynamic_dims, xla::XlaOp* list) { @@ -304,9 +305,10 @@ Status CreateZerosTensorListWithShape( return absl::OkStatus(); } -Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element, - bool element_is_tensor_list, - xla::XlaOp* initialized_list) { +absl::Status GetInitializedTensorListForElement(xla::XlaOp list, + xla::XlaOp element, + bool element_is_tensor_list, + xla::XlaOp* initialized_list) { int64_t leading_dim; xla::XlaOp leading_dim_dynamic_size; bool leading_dim_is_dynamic; @@ -360,9 +362,9 @@ Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element, } } -Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element, - bool element_is_tensor_list, - xla::XlaOp* result) { +absl::Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element, + bool element_is_tensor_list, + xla::XlaOp* result) { bool is_initialized; TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized)); if (!is_initialized) { @@ -422,9 +424,9 @@ Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element, return absl::OkStatus(); } -Status ExecuteTensorListPopBack(xla::XlaOp list, xla::XlaOp* list_result, - xla::XlaOp* element_result, - bool* element_is_tensor_list) { +absl::Status ExecuteTensorListPopBack(xla::XlaOp list, xla::XlaOp* list_result, + xla::XlaOp* element_result, + bool* element_is_tensor_list) { bool is_initialized; TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized)); if (!is_initialized) { @@ -471,8 +473,8 @@ Status ExecuteTensorListPopBack(xla::XlaOp list, xla::XlaOp* list_result, return absl::OkStatus(); } -Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index, - xla::XlaOp element, xla::XlaOp* result) { +absl::Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index, + xla::XlaOp element, xla::XlaOp* result) { bool is_initialized; TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized)); if (!is_initialized) { @@ -528,8 +530,8 @@ Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index, return absl::OkStatus(); } -Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index, - xla::XlaOp* result) { +absl::Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index, + xla::XlaOp* result) { bool is_initialized; TF_RETURN_IF_ERROR(IsTensorListInitialized(list, &is_initialized)); if (!is_initialized) { @@ -570,8 +572,8 @@ Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index, return absl::OkStatus(); } -Status ExecuteTensorListFromTensor(int push_index, xla::XlaOp tensor, - xla::XlaOp* result) { +absl::Status ExecuteTensorListFromTensor(int push_index, xla::XlaOp tensor, + xla::XlaOp* result) { xla::XlaBuilder* b = tensor.builder(); TF_ASSIGN_OR_RETURN(xla::Shape shape, b->GetShape(tensor)); if (!shape.IsArray()) { diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h index ed76f2e57e69c1..a86336ce79454c 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/status.h" @@ -30,37 +30,38 @@ namespace tensorflow { bool IsTensorListInput(XlaOpKernelContext* ctx, int index); // Whether the TensorList is initialized (has known data type and shape). -Status IsTensorListInitialized(xla::XlaOp list, bool* is_initialized); +absl::Status IsTensorListInitialized(xla::XlaOp list, bool* is_initialized); // Whether the TensorList is a nested TensorList. // Input must be an initialized TensorList. // Non-nested and nested TensorLists are both supported. -Status IsNestedTensorList(xla::XlaOp list, bool* is_nested_list); +absl::Status IsNestedTensorList(xla::XlaOp list, bool* is_nested_list); // Builds a non-nested TensorList from `buffer` and `push_index`. -Status BuildNonNestedTensorList(xla::XlaOp buffer, xla::XlaOp push_index, - xla::XlaOp* output_list); +absl::Status BuildNonNestedTensorList(xla::XlaOp buffer, xla::XlaOp push_index, + xla::XlaOp* output_list); // Returns buffer shape for the TensorList. // Input must be an initialized TensorList. // Non-nested and nested TensorLists are both supported. -Status GetTensorListBufferShape(xla::XlaOp list, xla::Shape* buffer_shape); +absl::Status GetTensorListBufferShape(xla::XlaOp list, + xla::Shape* buffer_shape); // Returns buffer for the TensorList. // Input must be an initialized TensorList. // Non-nested and nested TensorLists are both supported. -Status GetTensorListBuffer(xla::XlaOp list, xla::XlaOp* buffer); +absl::Status GetTensorListBuffer(xla::XlaOp list, xla::XlaOp* buffer); // Returns push index for the TensorList. // Input must be an initialized TensorList. // Non-nested and nested TensorLists are both supported. -Status GetTensorListPushIndex(xla::XlaOp list, xla::XlaOp* push_index); +absl::Status GetTensorListPushIndex(xla::XlaOp list, xla::XlaOp* push_index); // Returns a new TensorList with given push_index. // Input must be an initialized TensorList. // Non-nested and nested TensorLists are both supported. -Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index, - xla::XlaOp* result); +absl::Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index, + xla::XlaOp* result); // Returns an uninitialized TensorList. xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b, @@ -71,19 +72,19 @@ xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b, // Returns leading dimension for the TensorList as well as a dynamic op // representing the dynamic size. Input can be initialized or uninitialized // TensorList. Non-nested and nested TensorLists are both supported. -Status GetLeadingDimForTensorList(xla::XlaOp list, int64_t* leading_dim, - bool* leading_dim_is_dynamic, - xla::XlaOp* leading_dim_dynamic_size); +absl::Status GetLeadingDimForTensorList(xla::XlaOp list, int64_t* leading_dim, + bool* leading_dim_is_dynamic, + xla::XlaOp* leading_dim_dynamic_size); // Returns TensorList shape for the element shape. // Element shape must be a normal tensor shape. -Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape, - int64_t leading_dim, - bool leading_dim_is_dynamic, - xla::Shape* tensor_list_shape); +absl::Status GetTensorListShapeFromElementShape(const xla::Shape& element_shape, + int64_t leading_dim, + bool leading_dim_is_dynamic, + xla::Shape* tensor_list_shape); // Returns a TensorList filled by zeros with the given shape. -Status CreateZerosTensorListWithShape( +absl::Status CreateZerosTensorListWithShape( xla::XlaBuilder* b, const xla::Shape& list_shape, const std::vector>& dynamic_dims, xla::XlaOp* list); @@ -91,40 +92,41 @@ Status CreateZerosTensorListWithShape( // If the TensorList is uninitialized, initialize it with the element shape. // Input can be initialized or uninitialized TensorList. // "element" can be normal tensor or TensorList. -Status GetInitializedTensorListForElement(xla::XlaOp list, xla::XlaOp element, - bool element_is_tensor_list, - xla::XlaOp* initialized_list); +absl::Status GetInitializedTensorListForElement(xla::XlaOp list, + xla::XlaOp element, + bool element_is_tensor_list, + xla::XlaOp* initialized_list); // Executes TensorListPushBack with given TensorList and element. // Input must be an initialized TensorList. // Non-nested and nested TensorLists are both supported. -Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element, - bool element_is_tensor_list, - xla::XlaOp* result); +absl::Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element, + bool element_is_tensor_list, + xla::XlaOp* result); // Executes TensorListPopBack with given TensorList. // Input must be an initialized TensorList. // Non-nested and nested TensorLists are both supported. -Status ExecuteTensorListPopBack(xla::XlaOp list, xla::XlaOp* list_result, - xla::XlaOp* element_result, - bool* element_is_tensor_list); +absl::Status ExecuteTensorListPopBack(xla::XlaOp list, xla::XlaOp* list_result, + xla::XlaOp* element_result, + bool* element_is_tensor_list); // Executes TensorListSetItem with given TensorList, index and element. // Input must be an initialized TensorList. // Only non-nested TensorList is supported. -Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index, - xla::XlaOp element, xla::XlaOp* result); +absl::Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index, + xla::XlaOp element, xla::XlaOp* result); // Executes TensorListGetItem with given TensorList and index. // Input must be an initialized TensorList. // Only non-nested TensorList is supported. -Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index, - xla::XlaOp* result); +absl::Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index, + xla::XlaOp* result); // Executes TensorListPushBack with given tensor and push index. // "tensor" must be a normal tensor. -Status ExecuteTensorListFromTensor(int push_index, xla::XlaOp tensor, - xla::XlaOp* result); +absl::Status ExecuteTensorListFromTensor(int push_index, xla::XlaOp tensor, + xla::XlaOp* result); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index 269fcbd3a9bfb8..d6bf070137f226 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -22,8 +22,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/value_inference.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/value_inference.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" diff --git a/tensorflow/compiler/tf2xla/kernels/to_bool_op.cc b/tensorflow/compiler/tf2xla/kernels/to_bool_op.cc index c3e384ccfb13d0..fddfbb288124f0 100644 --- a/tensorflow/compiler/tf2xla/kernels/to_bool_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/to_bool_op.cc @@ -18,8 +18,8 @@ limitations under the License. #include "absl/status/status.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" @@ -38,7 +38,7 @@ class ToBoolOp : public XlaOpKernel { } private: - Status DoCompile(XlaOpKernelContext* ctx) { + absl::Status DoCompile(XlaOpKernelContext* ctx) { auto input = ctx->Input(0); // If the input is a scalar, then non-zero value returns True. diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index 9f8507fca1a093..a8003fbb9927d5 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -17,9 +17,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/sorting.h" -#include "xla/client/value_inference.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/sorting.h" +#include "xla/hlo/builder/value_inference.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 180decfb524c7d..1f2495c15a0512 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -16,9 +16,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/math.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index 7e6fa3267e191b..3d6beb1c1a1120 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/scatter.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/primitive_util.h" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" @@ -131,7 +131,7 @@ class InvertPermutationOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { DataType dtype = ctx->expected_output_dtype(0); - Status status; + absl::Status status; switch (dtype) { case DT_INT32: InvertPermutation(ctx); diff --git a/tensorflow/compiler/tf2xla/kernels/tridiagonal_ops.cc b/tensorflow/compiler/tf2xla/kernels/tridiagonal_ops.cc index ff9a9f36d2b3f2..0747a79bb09217 100644 --- a/tensorflow/compiler/tf2xla/kernels/tridiagonal_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tridiagonal_ops.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/lib/tridiagonal.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/lib/tridiagonal.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index f8db1591157a7a..5eb6438f89d322 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -19,9 +19,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/math.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tsl/platform/statusor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc index 670330467a403b..04f8e9eca3bd15 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc @@ -24,9 +24,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/relu_op.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/math.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/platform/errors.h" diff --git a/tensorflow/compiler/tf2xla/kernels/unique_op.cc b/tensorflow/compiler/tf2xla/kernels/unique_op.cc index 0faa21053a0fce..00d11ef7f34543 100644 --- a/tensorflow/compiler/tf2xla/kernels/unique_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unique_op.cc @@ -26,12 +26,12 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/comparators.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" #include "xla/comparison_util.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/comparators.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc index f9255e633a1209..0fc6e3e317c30b 100644 --- a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index 6d59a19d53f52e..c69593dd1c21c2 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -23,8 +23,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_resource.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor.h" @@ -39,7 +39,7 @@ limitations under the License. namespace tensorflow { namespace { -Status ValidateAssignUpdateVariableOpShapes(XlaOpKernelContext* ctx) { +absl::Status ValidateAssignUpdateVariableOpShapes(XlaOpKernelContext* ctx) { DataType variable_dtype; TensorShape variable_shape; TensorShape value_shape = ctx->InputShape(1); diff --git a/tensorflow/compiler/tf2xla/kernels/where_op.cc b/tensorflow/compiler/tf2xla/kernels/where_op.cc index 5583670045e3af..8d45963d124bf4 100644 --- a/tensorflow/compiler/tf2xla/kernels/where_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/where_op.cc @@ -22,12 +22,12 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/comparators.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/dynamic_shaped_ops.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/comparators.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/dynamic_shaped_ops.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 8f6901fbb72a63..9a0a633aa63ea4 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -34,9 +34,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_resource.h" #include "xla/client/client.h" -#include "xla/client/lib/tuple.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/tuple.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" @@ -59,8 +59,8 @@ namespace tensorflow { namespace { // Verify that input resources are grouped in the end. -Status VerifyResourceArgsGroupedAtEnd(XlaOpKernelContext* ctx, - const NameAttrList& body_name_attr) { +absl::Status VerifyResourceArgsGroupedAtEnd( + XlaOpKernelContext* ctx, const NameAttrList& body_name_attr) { const FunctionBody* body; TF_RETURN_IF_ERROR(ctx->compiler()->FindFunctionBody(body_name_attr, &body)); bool has_seen_resource = false; @@ -83,7 +83,7 @@ Status VerifyResourceArgsGroupedAtEnd(XlaOpKernelContext* ctx, } // Builds XlaCompiler argument descriptions `args` from `ctx`. -Status MakeXlaCompilerArgumentsFromInputs( +absl::Status MakeXlaCompilerArgumentsFromInputs( XlaOpKernelContext* ctx, std::vector* args, bool* has_uninitialized_vars, bool* has_tensor_arrays, bool* has_uninitialized_tensor_lists) { @@ -153,7 +153,7 @@ void GetLoopInvariants(XlaOpKernelContext* ctx, // Converts entries in `args` which are loop invariants and have compile time // constant inputs and need to be constants in order to be compilable to // constants so that they can be propagated in the loop body. -Status ConvertLoopInvariantsToConst( +absl::Status ConvertLoopInvariantsToConst( XlaOpKernelContext* ctx, const NameAttrList& body_name_attr, const NameAttrList& cond_name_attr, std::vector* args, @@ -191,7 +191,7 @@ Status ConvertLoopInvariantsToConst( return absl::OkStatus(); } -Status VerifyBodyInputAndOutputShapeMatch( +absl::Status VerifyBodyInputAndOutputShapeMatch( XlaOpKernelContext* ctx, const std::vector& compile_time_const_arg_indices, const XlaCompiler::CompilationResult& body, bool has_token_input_output) { diff --git a/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc index 6b361a2281acf9..13ac54b85463df 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_broadcast_helper_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc index 252d746727f443..df134f8ba50e6b 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc @@ -59,14 +59,14 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/mlir/utils/type_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/python/refine_polymorphic_shapes.h" #include "xla/service/hlo.pb.h" #include "xla/shape.h" -#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h index 0260e7fc1010cf..c21f7508266b77 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h @@ -32,7 +32,7 @@ limitations under the License. #include "mlir/IR/TypeRange.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/shape.h" #include "tsl/platform/statusor.h" diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc index 0438e648f4fe1c..4fd6de74ca0835 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc @@ -54,17 +54,17 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" #include "xla/debug_options_flags.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/util.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/xla_custom_call_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_custom_call_op.cc index 9d613d088a4db5..01c22410bb8801 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_custom_call_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_custom_call_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/kernels/xla_custom_call_v2_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_custom_call_v2_op.cc index f8712244ce86c9..3a2e8015c1037e 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_custom_call_v2_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_custom_call_v2_op.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/layout_util.h" #include "xla/service/hlo.pb.h" #include "xla/shape.h" @@ -59,7 +59,7 @@ class XlaCustomCallV2Op : public XlaOpKernel { } private: - Status CompileImpl(XlaOpKernelContext& ctx) const { + absl::Status CompileImpl(XlaOpKernelContext& ctx) const { std::vector operands(ctx.num_inputs()); std::vector operand_shapes(ctx.num_inputs()); for (int i = 0; i < ctx.num_inputs(); ++i) { diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc index 633dce69c5a889..7b0ea597c63488 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_dequantize_op.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/quantize.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/quantize.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" diff --git a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc index e2b951c9462525..8236e67eeded01 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_dot_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" diff --git a/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc index ca2b0cc7789782..2341a820ea921a 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_pad_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" diff --git a/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc index 514848da7341ed..0cfd247bdd1de6 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_self_adjoint_eig_op.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/self_adjoint_eig.h" +#include "xla/hlo/builder/lib/self_adjoint_eig.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc index 7bbd710f1b41ef..f3bd088ced826a 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_svd_op.cc @@ -17,9 +17,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/lib/svd.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/lib/svd.h" #include "xla/shape_util.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/compiler/tf2xla/layout_util.cc b/tensorflow/compiler/tf2xla/layout_util.cc index dcf2b2133f88f7..011a93e97c52fb 100644 --- a/tensorflow/compiler/tf2xla/layout_util.cc +++ b/tensorflow/compiler/tf2xla/layout_util.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_argument.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -47,7 +47,7 @@ XlaShapeLayoutHelpers::LayoutPreferenceFn UseNoPreferenceLayoutFn() { } // Rewrites the layout of xla_shape if there is tiled sharding. -Status RewriteLayoutWithShardedShape( +absl::Status RewriteLayoutWithShardedShape( const std::optional& sharding, bool use_fast_memory, XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, xla::Shape* xla_shape) { diff --git a/tensorflow/compiler/tf2xla/layout_util.h b/tensorflow/compiler/tf2xla/layout_util.h index cf5168a96561d2..9acf7da980fef1 100644 --- a/tensorflow/compiler/tf2xla/layout_util.h +++ b/tensorflow/compiler/tf2xla/layout_util.h @@ -23,7 +23,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "tensorflow/compiler/tf2xla/xla_argument.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/shape.h" #include "xla/xla_data.pb.h" @@ -64,7 +64,7 @@ class XlaShapeLayoutHelpers { XlaShapeLayoutHelpers::LayoutPreferenceFn UseNoPreferenceLayoutFn(); // Rewrites the layout of xla_shape if there is tiled sharding. -Status RewriteLayoutWithShardedShape( +absl::Status RewriteLayoutWithShardedShape( const std::optional& sharding, bool use_fast_memory, XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, xla::Shape* xla_shape); diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 5618de11014bae..fe4558c368684f 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -33,8 +33,8 @@ cc_library( "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:broadcast", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:broadcast", ], ) @@ -48,9 +48,9 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:statusor", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", - "@local_xla//xla/client/lib:math", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", + "@local_xla//xla/hlo/builder/lib:math", ], ) @@ -67,8 +67,8 @@ cc_library( "@local_xla//xla:shape_util", "@local_xla//xla:status_macros", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client:xla_computation", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder:xla_computation", ], ) @@ -84,8 +84,8 @@ cc_library( "@local_xla//xla:literal_util", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client:xla_computation", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder:xla_computation", ], ) @@ -99,6 +99,6 @@ cc_library( "@com_google_absl//absl/status:statusor", "@local_xla//xla:shape_util", "@local_xla//xla:util", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.cc b/tensorflow/compiler/tf2xla/lib/broadcast.cc index a5ab8a8b0e1fa0..80526def9b6ca1 100644 --- a/tensorflow/compiler/tf2xla/lib/broadcast.cc +++ b/tensorflow/compiler/tf2xla/lib/broadcast.cc @@ -21,8 +21,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/shape_util.h" -#include "xla/client/lib/broadcast.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/broadcast.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" @@ -37,7 +37,7 @@ absl::StatusOr BroadcastTo(xla::XlaOp input, return xla::BroadcastTo(input, output_dims); } -Status BroadcastOpsToSame(xla::XlaOp* lhs, xla::XlaOp* rhs) { +absl::Status BroadcastOpsToSame(xla::XlaOp* lhs, xla::XlaOp* rhs) { TF_ASSIGN_OR_RETURN(auto lhs_xla_shape, lhs->builder()->GetShape(*lhs)); TF_ASSIGN_OR_RETURN(auto rhs_xla_shape, rhs->builder()->GetShape(*rhs)); tensorflow::TensorShape lhs_tf_shape; diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.h b/tensorflow/compiler/tf2xla/lib/broadcast.h index 4c513a09688b83..48dec32af8b081 100644 --- a/tensorflow/compiler/tf2xla/lib/broadcast.h +++ b/tensorflow/compiler/tf2xla/lib/broadcast.h @@ -18,7 +18,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/statusor.h" @@ -30,7 +30,7 @@ absl::StatusOr BroadcastTo(xla::XlaOp input, absl::Span output_dims); // Forwards to xla::BroadcastOpsToSame. -Status BroadcastOpsToSame(xla::XlaOp* lhs, xla::XlaOp* rhs); +absl::Status BroadcastOpsToSame(xla::XlaOp* lhs, xla::XlaOp* rhs); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_ diff --git a/tensorflow/compiler/tf2xla/lib/data_format.cc b/tensorflow/compiler/tf2xla/lib/data_format.cc index 091aaab37d2453..a087abc806e5d7 100644 --- a/tensorflow/compiler/tf2xla/lib/data_format.cc +++ b/tensorflow/compiler/tf2xla/lib/data_format.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/data_format.h" #include "absl/status/statusor.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/util.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/compiler/tf2xla/lib/data_format.h b/tensorflow/compiler/tf2xla/lib/data_format.h index 2845b49656d591..131f5491fe5e06 100644 --- a/tensorflow/compiler/tf2xla/lib/data_format.h +++ b/tensorflow/compiler/tf2xla/lib/data_format.h @@ -17,7 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_LIB_DATA_FORMAT_H_ #include "absl/status/statusor.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/util/tensor_format.h" diff --git a/tensorflow/compiler/tf2xla/lib/random.cc b/tensorflow/compiler/tf2xla/lib/random.cc index 8cf8807831c514..1998acb8e4b4e0 100644 --- a/tensorflow/compiler/tf2xla/lib/random.cc +++ b/tensorflow/compiler/tf2xla/lib/random.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include -#include "xla/client/lib/constants.h" -#include "xla/client/lib/math.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/lib/random.h b/tensorflow/compiler/tf2xla/lib/random.h index d586af85f18602..3c03633d34da7a 100644 --- a/tensorflow/compiler/tf2xla/lib/random.h +++ b/tensorflow/compiler/tf2xla/lib/random.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_ #define TENSORFLOW_COMPILER_TF2XLA_LIB_RANDOM_H_ -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/statusor.h" diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 783a1401c81202..086684336b6de5 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -23,8 +23,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h index 304200ff26eab3..90af6e63fcbf05 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.h +++ b/tensorflow/compiler/tf2xla/lib/scatter.h @@ -19,8 +19,8 @@ limitations under the License. #include #include "absl/status/statusor.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "tensorflow/core/platform/statusor.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index 88694a535f99b3..550f77f0dccfb0 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/lib/util.h" #include "absl/log/log.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/primitive_util.h" diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index 83afeb2e6f5850..24f66027dd3ce5 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -17,8 +17,8 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ #include "absl/types/span.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/platform/statusor.h" diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 4dabab86ab3b55..8bae314ff472fa 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -36,17 +36,17 @@ limitations under the License. namespace tensorflow { -Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, - xla::BorrowingLiteral* literal) { +absl::Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, + xla::BorrowingLiteral* literal) { xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(), host_tensor.shape(), &xla_shape)); return HostTensorToBorrowingLiteral(xla_shape, host_tensor, literal); } -Status HostTensorToBorrowingLiteral(const xla::Shape& xla_shape, - const Tensor& host_tensor, - xla::BorrowingLiteral* literal) { +absl::Status HostTensorToBorrowingLiteral(const xla::Shape& xla_shape, + const Tensor& host_tensor, + xla::BorrowingLiteral* literal) { const auto& tshape = host_tensor.shape(); TF_RET_CHECK(tshape.IsFullyDefined() && tshape.dims() == xla_shape.dimensions_size() && @@ -63,7 +63,7 @@ absl::StatusOr HostTensorToLiteral(const Tensor& host_tensor) { return literal.Clone(); } -Status HostTensorToMutableBorrowingLiteral( +absl::Status HostTensorToMutableBorrowingLiteral( Tensor* host_tensor, xla::MutableBorrowingLiteral* literal) { xla::Shape xla_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor->dtype(), @@ -71,7 +71,7 @@ Status HostTensorToMutableBorrowingLiteral( return HostTensorToMutableBorrowingLiteral(xla_shape, host_tensor, literal); } -Status HostTensorToMutableBorrowingLiteral( +absl::Status HostTensorToMutableBorrowingLiteral( const xla::Shape& xla_shape, Tensor* host_tensor, xla::MutableBorrowingLiteral* literal) { *literal = xla::MutableBorrowingLiteral( @@ -80,8 +80,8 @@ Status HostTensorToMutableBorrowingLiteral( return absl::OkStatus(); } -Status HostTensorsToBorrowingLiteralTuple(absl::Span host_tensors, - xla::BorrowingLiteral* literal) { +absl::Status HostTensorsToBorrowingLiteralTuple( + absl::Span host_tensors, xla::BorrowingLiteral* literal) { std::vector buf_ptrs; buf_ptrs.reserve(host_tensors.size()); std::vector tensor_shapes(host_tensors.size()); @@ -100,8 +100,8 @@ Status HostTensorsToBorrowingLiteralTuple(absl::Span host_tensors, return absl::OkStatus(); } -Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, - Tensor* host_tensor) { +absl::Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, + Tensor* host_tensor) { TF_RET_CHECK(literal.shape().IsArray() && xla::ShapeUtil::ElementsIn(literal.shape()) == host_tensor->NumElements()); @@ -123,8 +123,8 @@ Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, return absl::OkStatus(); } -Status LiteralToHostTensor(const xla::LiteralSlice& literal, - DataType target_type, Tensor* host_tensor) { +absl::Status LiteralToHostTensor(const xla::LiteralSlice& literal, + DataType target_type, Tensor* host_tensor) { TensorShape shape; TF_RETURN_IF_ERROR(XLAShapeToTensorShape(literal.shape(), &shape)); *host_tensor = Tensor(target_type, shape); diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index c0623c43953eb5..4463024eb48f82 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -31,14 +31,14 @@ namespace tensorflow { // Returns a BorrowingLiteral that utilizes the same underlying buffer owned by // 'host_tensor'. -Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, - xla::BorrowingLiteral* literal); +absl::Status HostTensorToBorrowingLiteral(const Tensor& host_tensor, + xla::BorrowingLiteral* literal); // Similar as above, except the literal shape is explicitly provided and used // instead of obtaining it from the 'host_tensor'. The provided literal shape // 'xla_shape' must be compatible with the shape of 'host_tensor'. -Status HostTensorToBorrowingLiteral(const xla::Shape& xla_shape, - const Tensor& host_tensor, - xla::BorrowingLiteral* literal); +absl::Status HostTensorToBorrowingLiteral(const xla::Shape& xla_shape, + const Tensor& host_tensor, + xla::BorrowingLiteral* literal); // Returns a Literal with the contents of 'host_tensor', backed by its own // storage (i.e., not reusing 'host_tensor's buffers.) @@ -46,19 +46,19 @@ absl::StatusOr HostTensorToLiteral(const Tensor& host_tensor); // Returns a MutableBorrowingLiteral that utilizes the same underlying buffer // owned by 'host_tensor', but is mutable via the xla::Literal methods. -Status HostTensorToMutableBorrowingLiteral( +absl::Status HostTensorToMutableBorrowingLiteral( Tensor* host_tensor, xla::MutableBorrowingLiteral* literal); // Similar as above, except the literal shape is explicitly provided and used // instead of obtaining it from the 'host_tensor'. The provided literal shape // 'xla_shape' must be compatible with the shape of 'host_tensor'. -Status HostTensorToMutableBorrowingLiteral( +absl::Status HostTensorToMutableBorrowingLiteral( const xla::Shape& xla_shape, Tensor* host_tensor, xla::MutableBorrowingLiteral* literal); // Returns a BorrowingLiteral tuple that utilizes the same underlying buffers // owned by 'host_tensors'. -Status HostTensorsToBorrowingLiteralTuple(absl::Span host_tensors, - xla::BorrowingLiteral* literal); +absl::Status HostTensorsToBorrowingLiteralTuple( + absl::Span host_tensors, xla::BorrowingLiteral* literal); // Copies 'literal' to freshly allocated 'host_tensor', which is allocated of // type . @@ -67,14 +67,14 @@ Status HostTensorsToBorrowingLiteralTuple(absl::Span host_tensors, // derivable from the type of , because multiple tensorflow types map // to the same XLA type (e.g. INT32 and QINT32 both map to INT32 in // XLA). -Status LiteralToHostTensor(const xla::LiteralSlice& literal, - DataType target_type, Tensor* host_tensor); +absl::Status LiteralToHostTensor(const xla::LiteralSlice& literal, + DataType target_type, Tensor* host_tensor); // Copies the contents of 'literal' to a previously allocated tensor // 'host_tensor'. The tensor and the literal must have the same number of // elements and the same type. -Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, - Tensor* host_tensor); +absl::Status CopyLiteralToHostTensor(const xla::LiteralSlice& literal, + Tensor* host_tensor); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc index 31294fe7b036f1..b7c9b5fd7bbf13 100644 --- a/tensorflow/compiler/tf2xla/literal_util_test.cc +++ b/tensorflow/compiler/tf2xla/literal_util_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/literal.h" #include "xla/literal_util.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_testutil.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc index c8afadca62771d..33c4395a1f053c 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.cc @@ -235,10 +235,10 @@ MlirOptimizationPassState MlirBridgePass::GetPassState( // and attached to a "compile" operation, whose result is fed to an "execute" // operation. The kernel for these operations is responsible to lower the // encapsulated graph to a particular device. -Status MlirBridgePass::Run(const std::string& function_name, - const ConfigProto& config_proto, - mlir::ModuleOp module, const Graph& graph, - const FunctionLibraryDefinition& function_library) { +absl::Status MlirBridgePass::Run( + const std::string& function_name, const ConfigProto& config_proto, + mlir::ModuleOp module, const Graph& graph, + const FunctionLibraryDefinition& function_library) { static absl::once_flag flag; absl::call_once(flag, UpdateLogVerbosityIfDefined, "TF_DEBUG_LOG_VERBOSITY"); @@ -362,8 +362,8 @@ MlirOptimizationPassState MlirBridgeV1CompatPass::GetPassState( } } -Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options, - mlir::ModuleOp module) { +absl::Status MlirBridgeV1CompatPass::Run( + const GraphOptimizationPassOptions& options, mlir::ModuleOp module) { static absl::once_flag flag; absl::call_once(flag, UpdateLogVerbosityIfDefined, "TF_DEBUG_LOG_VERBOSITY"); diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h index 9adb9e8794c27e..eae5fb83c5d682 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h @@ -46,9 +46,10 @@ class MlirBridgePass : public MlirOptimizationPass { // This should be used as a thin mapper around mlir::ModulePass::runOnModule // API integrated with the Tensorflow runtime. - Status Run(const std::string& function_name, const ConfigProto& config_proto, - mlir::ModuleOp module, const Graph& graph, - const FunctionLibraryDefinition& function_library) override; + absl::Status Run(const std::string& function_name, + const ConfigProto& config_proto, mlir::ModuleOp module, + const Graph& graph, + const FunctionLibraryDefinition& function_library) override; }; // This pass uses MLIR to implement all the conversion steps to target XLA from @@ -65,8 +66,8 @@ class MlirBridgeV1CompatPass : public MlirV1CompatOptimizationPass { // This should be used as a thin mapper around mlir::ModulePass::runOnModule // API integrated with the Tensorflow runtime. - Status Run(const GraphOptimizationPassOptions& options, - mlir::ModuleOp module) override; + absl::Status Run(const GraphOptimizationPassOptions& options, + mlir::ModuleOp module) override; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc index 1e7bbd59464685..d69576fa7fda1c 100644 --- a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc +++ b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc @@ -36,7 +36,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/framework/device.h" #include "tensorflow/core/framework/device_attributes.pb.h" @@ -64,12 +64,14 @@ class FakeDevice : public Device { explicit FakeDevice(const DeviceAttributes& device_attributes) : Device(nullptr, device_attributes) {} - Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); } + absl::Status Sync() override { + return errors::Unimplemented("FakeDevice::Sync()"); + } }; // Translates the graph input information from tf2xla:::Config to // GraphImportConfig. -Status ConvertInputInfo( +absl::Status ConvertInputInfo( const tf2xla::Config& config, const std::unordered_map& feed_name_remap, GraphImportConfig* specs) { @@ -99,8 +101,8 @@ Status ConvertInputInfo( // Translates the graph output information from tf2xla:::Config to // GraphImportConfig. -Status ConvertOutputInfo(const tf2xla::Config& config, - GraphImportConfig* specs) { +absl::Status ConvertOutputInfo(const tf2xla::Config& config, + GraphImportConfig* specs) { std::vector array_names; for (const tf2xla::Fetch& fetch : config.fetch()) { array_names.push_back(fetch.id().node_name()); @@ -111,7 +113,7 @@ Status ConvertOutputInfo(const tf2xla::Config& config, } // namespace -Status ConvertGraphDefToXlaViaMlir( +absl::Status ConvertGraphDefToXlaViaMlir( GraphDef graph_def, const tf2xla::Config& config, xla::XlaComputation* computation, absl::string_view debug_info_filename, absl::string_view debug_info_path_begin_marker) { diff --git a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc index be344f17ed7941..b1a93508d92896 100644 --- a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc @@ -29,7 +29,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/device.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" @@ -51,7 +51,7 @@ class MLIRContextResource : public ResourceBase { static constexpr const char* kDefaultResourceName = "mlir-xla-op-cached-context"; - static Status Create(MLIRContextResource** resource) { + static absl::Status Create(MLIRContextResource** resource) { *resource = new MLIRContextResource(); return absl::OkStatus(); } @@ -70,7 +70,7 @@ class MLIRContextResource : public ResourceBase { } // namespace -Status MlirXlaOpKernel::ContextToXlaArgs( +absl::Status MlirXlaOpKernel::ContextToXlaArgs( XlaOpKernelContext* ctx, std::vector& xla_args) { // Collect arguments that are registered as CompileTimeConstantInput. std::vector registered_consts_vec; @@ -108,7 +108,7 @@ Status MlirXlaOpKernel::ContextToXlaArgs( MlirXlaOpKernel::MlirXlaOpKernel(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} -Status MlirXlaOpKernel::ConstructXlaOp(XlaOpKernelContext* ctx) { +absl::Status MlirXlaOpKernel::ConstructXlaOp(XlaOpKernelContext* ctx) { // Create input XlaArguments. std::vector xla_args; TF_RETURN_IF_ERROR(ContextToXlaArgs(ctx, xla_args)); diff --git a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h index b2968b0c0a87b8..6053f5d68635d0 100644 --- a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h @@ -30,10 +30,10 @@ class MlirXlaOpKernel : public XlaOpKernel { explicit MlirXlaOpKernel(OpKernelConstruction* ctx); private: - Status ContextToXlaArgs(XlaOpKernelContext* ctx, - std::vector& xla_args); + absl::Status ContextToXlaArgs(XlaOpKernelContext* ctx, + std::vector& xla_args); void Compile(XlaOpKernelContext* ctx) override; - Status ConstructXlaOp(XlaOpKernelContext* ctx); + absl::Status ConstructXlaOp(XlaOpKernelContext* ctx); }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index a51ad205015bad..e65c948c87e4c8 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -49,7 +49,7 @@ namespace { // Helper shape function for operators that return an output with the same rank // as their first input. -Status UnchangedRank(shape_inference::InferenceContext* c) { +absl::Status UnchangedRank(shape_inference::InferenceContext* c) { if (c->RankKnown(c->input(0))) { c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0)))); } else { @@ -215,7 +215,7 @@ preferred_element_type: type of the tensor. batch_group_count: number of batch groups or grouped filters. )doc"); -static Status XlaDotShapeFunction(shape_inference::InferenceContext* c) { +static absl::Status XlaDotShapeFunction(shape_inference::InferenceContext* c) { shape_inference::ShapeHandle lhs_shape_handle = c->input(0); shape_inference::ShapeHandle rhs_shape_handle = c->input(1); if (!c->RankKnown(lhs_shape_handle) || !c->RankKnown(rhs_shape_handle)) { @@ -395,7 +395,7 @@ REGISTER_OP("XlaDynamicSlice") .Output("output: T") .Attr("T: type") .Attr("Tindices: {int32, int64}") - .SetShapeFn([](shape_inference::InferenceContext* c) -> Status { + .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { shape_inference::ShapeHandle size_indices_shape = c->input(2); if (!c->RankKnown(size_indices_shape)) { return UnchangedRank(c); @@ -1297,7 +1297,7 @@ scatter_dimension: Dimension to scatter. reduce_op: Reduction computation. )doc"); -Status OptimizationBarrierShape(shape_inference::InferenceContext* c) { +absl::Status OptimizationBarrierShape(shape_inference::InferenceContext* c) { for (int i = 0; i < c->num_inputs(); ++i) { c->set_output(i, c->input(i)); } diff --git a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc index a081fa18891ba2..84ed56a468df8e 100644 --- a/tensorflow/compiler/tf2xla/rearrange_function_argument.cc +++ b/tensorflow/compiler/tf2xla/rearrange_function_argument.cc @@ -52,9 +52,10 @@ std::vector ShuffleInputDataTypeAttribute( // be rewritten, `resource_input_count` will be set to number of DT_RESOURCE // inputs, and `index_mapping` will hold a mapping for original input index to // rearranged input index. -Status InputTypesNeedsRearrange(const std::vector& in_types, - bool* need_rewrite, int* resource_input_count, - std::vector* index_mapping) { +absl::Status InputTypesNeedsRearrange(const std::vector& in_types, + bool* need_rewrite, + int* resource_input_count, + std::vector* index_mapping) { int first_resource_index = -1; for (int i = 0, end = in_types.size(); i < end; i++) { DataType type = in_types[i]; @@ -105,8 +106,8 @@ Status InputTypesNeedsRearrange(const std::vector& in_types, // Given mapping between original input index and rearranged input index, // reorder input edges for the node. -Status ReorderInputEdges(Graph* g, Node* n, - const std::vector& index_mapping) { +absl::Status ReorderInputEdges(Graph* g, Node* n, + const std::vector& index_mapping) { std::vector input_edges; for (const Edge* e : n->in_edges()) { if (e->IsControlEdge()) { @@ -129,9 +130,9 @@ Status ReorderInputEdges(Graph* g, Node* n, // input index, reorder output edges for the node. DT_RESOURCE outputs are // removed from the node and we will use the node's corresponding input for the // edge. -Status ReorderOutputEdges(Graph* g, Node* n, int input_count, - int resource_input_count, - const std::vector& index_mapping) { +absl::Status ReorderOutputEdges(Graph* g, Node* n, int input_count, + int resource_input_count, + const std::vector& index_mapping) { std::vector output_edges; for (const Edge* e : n->out_edges()) { if (e->IsControlEdge()) { @@ -159,9 +160,8 @@ Status ReorderOutputEdges(Graph* g, Node* n, int input_count, // Given mapping between original input index and rearranged input index, change // "index" attribute for _Arg nodes. -void RearrangeArgNodes( - const gtl::InlinedVector* arg_nodes, // non-absl ok - const std::vector& index_mapping) { +void RearrangeArgNodes(const absl::InlinedVector* arg_nodes, + const std::vector& index_mapping) { for (int i = 0; i < arg_nodes->size(); i++) { Node* n = (*arg_nodes)[i]; int new_index = index_mapping.at(i); @@ -176,8 +176,8 @@ void RearrangeArgNodes( // original _Retval to rearranged _Retval, and `resource_retval_to_arg` will // hold mapping from DT_RESOURCE _Retval index to its input _Arg index. Here we // assume that all DT_RESOURCE _Retval nodes come from _Arg nodes directly. -Status CalculateRetvalRearrange( - const gtl::InlinedVector& ret_nodes, // non-absl ok +absl::Status CalculateRetvalRearrange( + const absl::InlinedVector& ret_nodes, std::map* retval_index_mapping, std::map* resource_retval_to_arg) { for (int i = 0, end = ret_nodes.size(); i < end; i++) { @@ -225,9 +225,9 @@ std::vector ShuffleOutputDataTypeAttribute( // and rearranged input index, reorder output edges for the node. DT_RESOURCE // outputs are removed from the node and we will use the node's corresponding // input for the edge. -Status RearrangeOutputEdges(Node* n, Graph* g, - const std::map& retval_index_mapping, - const std::map& resource_retval_to_arg) { +absl::Status RearrangeOutputEdges( + Node* n, Graph* g, const std::map& retval_index_mapping, + const std::map& resource_retval_to_arg) { std::vector out_edges; for (const Edge* e : n->out_edges()) { if (!e->IsControlEdge()) { @@ -258,9 +258,9 @@ Status RearrangeOutputEdges(Node* n, Graph* g, // Given mapping between original output index and rearranged output index, // change "index" attribute for _Retval nodes. Notice that DT_RESOURCE _Retval // nodes will be removed. -void RearrangeRetvalNodes( - const gtl::InlinedVector& ret_nodes, // non-absl ok - Graph* g, const std::map& retval_index_mapping) { +void RearrangeRetvalNodes(const absl::InlinedVector& ret_nodes, + Graph* g, + const std::map& retval_index_mapping) { for (int i = 0, end = ret_nodes.size(); i < end; i++) { Node* n = ret_nodes[i]; auto iter = retval_index_mapping.find(i); @@ -273,8 +273,8 @@ void RearrangeRetvalNodes( } } -Status MaybeRewriteWhileNode( - std::function +absl::Status MaybeRewriteWhileNode( + std::function get_function_body_fn, Graph* g, Node* n, FunctionLibraryDefinition* fld, bool* node_rewritten) { // Check if this While node needs rewrite. @@ -382,8 +382,8 @@ Status MaybeRewriteWhileNode( return absl::OkStatus(); } -Status MaybeRewriteIfNode( - std::function +absl::Status MaybeRewriteIfNode( + std::function get_function_body_fn, Graph* g, Node* n, FunctionLibraryDefinition* fld, bool* node_rewritten, const FunctionLibraryDefinition* global_fld) { @@ -519,8 +519,8 @@ Status MaybeRewriteIfNode( } // namespace -Status RearrangeFunctionArguments( - std::function +absl::Status RearrangeFunctionArguments( + std::function get_function_body_fn, Graph* g, FunctionLibraryDefinition* fld, const FunctionLibraryDefinition* global_fld) { @@ -537,7 +537,7 @@ Status RearrangeFunctionArguments( const FunctionBody* fbody; TF_RETURN_IF_ERROR(get_function_body_fn(func_name_attrs, &fbody)); InlineFunctionBodyOptions opts; - Status s = InlineFunctionBody(*fld, g, n, fbody, opts); + absl::Status s = InlineFunctionBody(*fld, g, n, fbody, opts); // Inlining might fail because the function is marked with attribute // _noinline. s.IgnoreError(); diff --git a/tensorflow/compiler/tf2xla/rearrange_function_argument.h b/tensorflow/compiler/tf2xla/rearrange_function_argument.h index fae625308b00b3..1a290017a55dc0 100644 --- a/tensorflow/compiler/tf2xla/rearrange_function_argument.h +++ b/tensorflow/compiler/tf2xla/rearrange_function_argument.h @@ -31,8 +31,8 @@ namespace tensorflow { // `fld` is used to store rewritten functions. // `global_fld` is used to potentially supply stack traces for functions when // they are not found in `fld`. -Status RearrangeFunctionArguments( - std::function +absl::Status RearrangeFunctionArguments( + std::function get_function_body_fn, Graph* g, FunctionLibraryDefinition* fld, const FunctionLibraryDefinition* global_fld = nullptr); diff --git a/tensorflow/compiler/tf2xla/resource_util.cc b/tensorflow/compiler/tf2xla/resource_util.cc index 88ed094fe446f0..e78828df4e13a4 100644 --- a/tensorflow/compiler/tf2xla/resource_util.cc +++ b/tensorflow/compiler/tf2xla/resource_util.cc @@ -41,7 +41,7 @@ const char kRetvalOp[] = "_Retval"; const int kMaxCallDepth = 100; -Status AnalyzeResourceUsage( +absl::Status AnalyzeResourceUsage( const Graph* graph, const std::optional& function_name, const int call_depth, const absl::flat_hash_set& resource_arg_indices, FunctionLibraryRuntime* lib_runtime, @@ -95,7 +95,7 @@ void PropagateFromStackOrTensorArraySourceOp( } } -Status PropagateFromArgOp( +absl::Status PropagateFromArgOp( const Node& n, const std::optional& function_name, const absl::flat_hash_set& resource_arg_indices, absl::flat_hash_map* @@ -125,7 +125,7 @@ Status PropagateFromArgOp( return absl::OkStatus(); } -Status UpdateResourceUsageFromFunctionBodyAnalysis( +absl::Status UpdateResourceUsageFromFunctionBodyAnalysis( const Node& call_node, const std::optional& caller_function_name, const FunctionBody& fbody, @@ -179,7 +179,7 @@ Status UpdateResourceUsageFromFunctionBodyAnalysis( return absl::OkStatus(); } -Status PropagateThroughCallOp( +absl::Status PropagateThroughCallOp( const Node& n, const std::optional& function_name, const int call_depth, FunctionLibraryRuntime* lib_runtime, absl::flat_hash_map* @@ -223,7 +223,7 @@ Status PropagateThroughCallOp( } // Analyzes pass through values for Identity and IdentityN ops. -Status PropagateThroughIdentityOp( +absl::Status PropagateThroughIdentityOp( const Node& n, absl::flat_hash_map* user_to_source) { @@ -249,7 +249,7 @@ Status PropagateThroughIdentityOp( return absl::OkStatus(); } -Status AnalyzeResourceUsage( +absl::Status AnalyzeResourceUsage( const Graph* graph, const std::optional& function_name, const int call_depth, const absl::flat_hash_set& resource_arg_indices, FunctionLibraryRuntime* lib_runtime, @@ -318,7 +318,7 @@ Status AnalyzeResourceUsage( } // anonymous namespace -/*Static*/ Status ResourceUsageAnalysis::Analyze( +/*Static*/ absl::Status ResourceUsageAnalysis::Analyze( const Graph* graph, FunctionLibraryRuntime* lib_runtime, absl::flat_hash_map>* source_to_path) { diff --git a/tensorflow/compiler/tf2xla/resource_util.h b/tensorflow/compiler/tf2xla/resource_util.h index 4aac73638d6963..e4bdb5112be893 100644 --- a/tensorflow/compiler/tf2xla/resource_util.h +++ b/tensorflow/compiler/tf2xla/resource_util.h @@ -86,7 +86,7 @@ class ResourceUsageAnalysis { // source_to_path maps the nodes that creates resources to all nodes that // operate on the corresponding resource, not including sources themselves. It // is cleared upon calling this method. - static Status Analyze( + static absl::Status Analyze( const Graph* graph, FunctionLibraryRuntime* lib_runtime, absl::flat_hash_map>* source_to_path); diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index 83467cbd47e150..b8b56d4eafdcfa 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -29,8 +29,8 @@ limitations under the License. namespace tensorflow { namespace { -Status PopulateInfeedLayoutVector(const xla::Shape& shape, - std::vector* layouts) { +absl::Status PopulateInfeedLayoutVector(const xla::Shape& shape, + std::vector* layouts) { if (shape.IsTuple()) { int64_t tuple_elements = xla::ShapeUtil::TupleElementCount(shape); for (int64_t i = 0; i < tuple_elements; ++i) { @@ -73,7 +73,7 @@ absl::StatusOr MakeLayout(absl::Span minor_to_major, return true; } -Status AssignLayout( +absl::Status AssignLayout( absl::Span minor_to_major, const std::function& layout_func, xla::Shape* shape) { @@ -89,8 +89,8 @@ Status AssignLayout( } // namespace // Convert an XLA Shape into the equivalent TensorFlow shape. -Status XLAShapeToTensorShape(const xla::Shape& shape, - TensorShape* tensor_shape) { +absl::Status XLAShapeToTensorShape(const xla::Shape& shape, + TensorShape* tensor_shape) { if (shape.IsTuple()) { return errors::InvalidArgument("XLA shape ", xla::ShapeUtil::HumanString(shape), @@ -104,19 +104,18 @@ Status XLAShapeToTensorShape(const xla::Shape& shape, } // Convert a TensorShape into the equivalent XLA Shape proto. -Status TensorShapeToXLAShape(DataType dtype, - const PartialTensorShape& tensor_shape, - xla::Shape* shape) { +absl::Status TensorShapeToXLAShape(DataType dtype, + const PartialTensorShape& tensor_shape, + xla::Shape* shape) { xla::PrimitiveType type; TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type)); *shape = TensorShapeToXLAShape(type, tensor_shape); return absl::OkStatus(); } -Status TensorShapeToBoundedXLAShape(DataType dtype, - const PartialTensorShape& tensor_shape, - const TensorShape& bound, - xla::Shape* shape) { +absl::Status TensorShapeToBoundedXLAShape( + DataType dtype, const PartialTensorShape& tensor_shape, + const TensorShape& bound, xla::Shape* shape) { xla::PrimitiveType type; TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type)); if (tensor_shape.unknown_rank()) { @@ -185,8 +184,9 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, } // Convert a TensorShape into the equivalent XLA Shape proto. -Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, - xla::Shape* shape) { +absl::Status TensorShapeToXLAShape(DataType dtype, + const TensorShape& tensor_shape, + xla::Shape* shape) { xla::PrimitiveType type; TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type)); *shape = TensorShapeToXLAShape(type, tensor_shape); @@ -220,7 +220,7 @@ absl::StatusOr> GetShapeLayoutVector(const xla::Shape& shape) { return layouts; } -Status GetShapeWithLayout( +absl::Status GetShapeWithLayout( const xla::Shape& input_shape, absl::Span minor_to_major, const std::function& layout_func, xla::Shape* output_shape) { diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h index 8450a09bc4ee1a..018ab191677034 100644 --- a/tensorflow/compiler/tf2xla/shape_util.h +++ b/tensorflow/compiler/tf2xla/shape_util.h @@ -30,14 +30,15 @@ namespace tensorflow { // Convert an XLA Shape into the equivalent TensorFlow shape. May fail since // not all XLA shapes can be represented as TensorShapes. -Status XLAShapeToTensorShape(const xla::Shape& shape, - TensorShape* tensor_shape); +absl::Status XLAShapeToTensorShape(const xla::Shape& shape, + TensorShape* tensor_shape); // Convert a TensorShape into the equivalent XLA Shape proto. Unlike Tensorflow, // XLA shapes include the type. Not all `dtype` values can be represented by // XLA, so this conversion may fail. -Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape, - xla::Shape* shape); +absl::Status TensorShapeToXLAShape(DataType dtype, + const TensorShape& tensor_shape, + xla::Shape* shape); absl::StatusOr TensorShapeToXLAShape( DataType dtype, const TensorShape& tensor_shape); @@ -49,19 +50,18 @@ xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, // Convert a PartialTensorShape into the equivalent XLA Shape proto. An shape // with unknown rank is represented by an r1 with empty dimension. -Status TensorShapeToXLAShape(DataType dtype, - const PartialTensorShape& tensor_shape, - xla::Shape* shape); +absl::Status TensorShapeToXLAShape(DataType dtype, + const PartialTensorShape& tensor_shape, + xla::Shape* shape); // Convert a PartialTensorShape into the equivalent XLA Shape proto. An shape // with unknown rank is represented by an r1 with empty dimension. xla::Shape TensorShapeToXLAShape(xla::PrimitiveType type, const PartialTensorShape& tensor_shape); -Status TensorShapeToBoundedXLAShape(DataType dtype, - const PartialTensorShape& tensor_shape, - const TensorShape& bound, - xla::Shape* shape); +absl::Status TensorShapeToBoundedXLAShape( + DataType dtype, const PartialTensorShape& tensor_shape, + const TensorShape& bound, xla::Shape* shape); // Given an XLA shape with layouts, builds a layout vector in the form able to // be fed to ops like InfeedEnqueue/InfeedEnqueueTuple/XRTAllocateV2/.... @@ -77,7 +77,7 @@ absl::StatusOr> GetShapeLayoutVector(const xla::Shape& shape); // of the layouts, create the output shape by rewriting the input shape layouts. // If a layout is missing (has -1 values) for a matching tuple subshape, the // layout_func will be called, if not nullptr. -Status GetShapeWithLayout( +absl::Status GetShapeWithLayout( const xla::Shape& input_shape, absl::Span minor_to_major, const std::function& layout_func, xla::Shape* output_shape); diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc index 94cf80956f97aa..4cf9fdbd39fa0c 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.cc +++ b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -48,7 +48,7 @@ void AssignOpMetadataToSharding(xla::OpSharding& sharding, } } -Status CoreOutOfRangeError(int core, int num_cores_per_replica) { +absl::Status CoreOutOfRangeError(int core, int num_cores_per_replica) { return errors::InvalidArgument( "Invalid replicated core id: ", core, "; num_cores_per_replica=", num_cores_per_replica); diff --git a/tensorflow/compiler/tf2xla/sharding_util.h b/tensorflow/compiler/tf2xla/sharding_util.h index 3077971fee586f..473ad1dd0a5dd1 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.h +++ b/tensorflow/compiler/tf2xla/sharding_util.h @@ -17,7 +17,7 @@ limitations under the License. #include -#include "xla/client/sharding_builder.h" +#include "xla/hlo/builder/sharding_builder.h" #include "xla/status_macros.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" diff --git a/tensorflow/compiler/tf2xla/sharding_util_test.cc b/tensorflow/compiler/tf2xla/sharding_util_test.cc index 5b5f061cf50b0d..67832de3dffbe6 100644 --- a/tensorflow/compiler/tf2xla/sharding_util_test.cc +++ b/tensorflow/compiler/tf2xla/sharding_util_test.cc @@ -115,7 +115,7 @@ TEST_P(ShardingWithMetadataTest, GetShardingFromNode) { { Graph graph(OpRegistry::Global()); - Status status; + absl::Status status; Node* node = graph.AddNode(node_def, &status); TF_ASSERT_OK(status); diff --git a/tensorflow/compiler/tf2xla/side_effect_util.cc b/tensorflow/compiler/tf2xla/side_effect_util.cc index 9446e4b4adadb9..afe82e0de40f62 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.cc +++ b/tensorflow/compiler/tf2xla/side_effect_util.cc @@ -34,7 +34,7 @@ const char kXlaIsPlaceholderForTailOcAttrName[] = const char kXlaOriginalOutsideCompilationNodeName[] = "_xla_original_oc_node_name"; -Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) { +absl::Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal) { if (!HasNodeAttr(node->def(), kXlaHasHostTransferAttrName)) { return errors::InvalidArgument("Node ", node->DebugString(), " does not have attribute ", @@ -122,8 +122,9 @@ bool HasSideEffectingNodes(const Graph& g) { return false; } -Status ParseHostComputeCoreList(absl::Span list_from_attr, - std::map* host_compute_core) { +absl::Status ParseHostComputeCoreList( + absl::Span list_from_attr, + std::map* host_compute_core) { for (const auto& hc_core : list_from_attr) { std::vector parts = str_util::Split(hc_core, ":"); if (parts.size() != 2) { diff --git a/tensorflow/compiler/tf2xla/side_effect_util.h b/tensorflow/compiler/tf2xla/side_effect_util.h index f91fe75c8a4cf3..34f30eb7661bc1 100644 --- a/tensorflow/compiler/tf2xla/side_effect_util.h +++ b/tensorflow/compiler/tf2xla/side_effect_util.h @@ -49,7 +49,7 @@ extern const char kXlaOriginalOutsideCompilationNodeName[]; // Sets device ordinal attribute for nodes with attribute // `kXlaHasHostTransferAttrName`. -Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal); +absl::Status SetDeviceOrdinalAttributeForNode(Node* node, int device_ordinal); // Calculates side-effect dependencies for the graph's token output. // Returns a set of node names representing these dependencies. @@ -61,8 +61,8 @@ bool HasSideEffectingNodes(const Graph& g); // Parse the mapping from outside_compilation_subgraph name to core number, // which is specified in an attr as a list of strings // :. -Status ParseHostComputeCoreList(absl::Span list_from_attr, - std::map* host_compute_core); +absl::Status ParseHostComputeCoreList(absl::Span list_from_attr, + std::map* host_compute_core); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/test_util.cc b/tensorflow/compiler/tf2xla/test_util.cc index 3fb8523ce71e0c..43623a8db8014f 100644 --- a/tensorflow/compiler/tf2xla/test_util.cc +++ b/tensorflow/compiler/tf2xla/test_util.cc @@ -20,9 +20,9 @@ limitations under the License. namespace tensorflow { -Status InstantiateFunctionForTest(const string& name, - const FunctionLibraryDefinition& library, - InstantiationResultForTest* result) { +absl::Status InstantiateFunctionForTest( + const string& name, const FunctionLibraryDefinition& library, + InstantiationResultForTest* result) { const FunctionDef* fdef = library.Find(name); TF_RET_CHECK(fdef != nullptr); diff --git a/tensorflow/compiler/tf2xla/test_util.h b/tensorflow/compiler/tf2xla/test_util.h index 4ffc94ae3bc7c9..2b2eb4f582af3e 100644 --- a/tensorflow/compiler/tf2xla/test_util.h +++ b/tensorflow/compiler/tf2xla/test_util.h @@ -40,9 +40,9 @@ struct InstantiationResultForTest { // Instantiates a function, producing a GraphDef to compare against the // expected graph. -Status InstantiateFunctionForTest(const string& name, - const FunctionLibraryDefinition& library, - InstantiationResultForTest* result); +absl::Status InstantiateFunctionForTest( + const string& name, const FunctionLibraryDefinition& library, + InstantiationResultForTest* result); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index ef87b320cdcd0c..b899298d9e2dc8 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -30,7 +30,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" @@ -54,9 +54,10 @@ namespace { // Converts the TensorFlow graph into an XLA computation, by executing the // graph symbolically, with each op building up the XLA HLO. -Status ConvertGraphToXla(std::unique_ptr graph, - const tf2xla::Config& config, xla::Client* client, - xla::XlaComputation* computation) { +absl::Status ConvertGraphToXla(std::unique_ptr graph, + const tf2xla::Config& config, + xla::Client* client, + xla::XlaComputation* computation) { XlaOpRegistry::RegisterCompilationKernels(); for (Node* node : graph->nodes()) { node->set_assigned_device_name( @@ -128,8 +129,8 @@ Status ConvertGraphToXla(std::unique_ptr graph, return absl::OkStatus(); } -Status ConvertVarHandlesToAotVarHandles(GraphDef* graph_def) { - auto update_var_handle_op_node = [](NodeDef& node) -> Status { +absl::Status ConvertVarHandlesToAotVarHandles(GraphDef* graph_def) { + auto update_var_handle_op_node = [](NodeDef& node) -> absl::Status { if (node.op() == "VarHandleOp") { node.set_op(tfcompile::kXlaAotOnlyVarHandleOp); const auto& it = node.attr().find("allowed_devices"); @@ -156,9 +157,10 @@ Status ConvertVarHandlesToAotVarHandles(GraphDef* graph_def) { } // namespace -Status ConvertGraphDefToXla(GraphDef graph_def, const tf2xla::Config& config, - xla::Client* client, - xla::XlaComputation* computation) { +absl::Status ConvertGraphDefToXla(GraphDef graph_def, + const tf2xla::Config& config, + xla::Client* client, + xla::XlaComputation* computation) { std::unique_ptr graph; TF_RETURN_IF_ERROR(ConvertVarHandlesToAotVarHandles(&graph_def)); TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph)); diff --git a/tensorflow/compiler/tf2xla/tf2xla.h b/tensorflow/compiler/tf2xla/tf2xla.h index 3fd2c641dd6e4f..095ad49afc6b66 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.h +++ b/tensorflow/compiler/tf2xla/tf2xla.h @@ -19,7 +19,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "xla/client/client.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/platform/status.h" @@ -32,9 +32,10 @@ namespace tensorflow { // // The computation is built in the context of the given `client`, which may // subsequently be used to compile or execute the computation. -Status ConvertGraphDefToXla(GraphDef graph_def, const tf2xla::Config& config, - xla::Client* client, - xla::XlaComputation* computation); +absl::Status ConvertGraphDefToXla(GraphDef graph_def, + const tf2xla::Config& config, + xla::Client* client, + xla::XlaComputation* computation); // Similar to ConvertGraphDefToXla, but uses MLIR and handle debug information. // @@ -42,7 +43,7 @@ Status ConvertGraphDefToXla(GraphDef graph_def, const tf2xla::Config& config, // debug_info_path_begin_marker: if not empty, file pathes in the debug // information are trimmed from the beginning to the first appearance of the // marker. -Status ConvertGraphDefToXlaViaMlir( +absl::Status ConvertGraphDefToXlaViaMlir( GraphDef graph_def, const tf2xla::Config& config, xla::XlaComputation* computation, absl::string_view debug_info_filename, absl::string_view debug_info_path_begin_marker); diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index 01bb69d16ee264..c9906ada9c1254 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "xla/client/client_library.h" #include "xla/client/local_client.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "tensorflow/core/framework/attr_value.pb.h" diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index ca22c7b9ceb10a..9f21af2741dcde 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -48,7 +48,7 @@ namespace tensorflow { namespace { -Status ValidateTensorId(const tf2xla::TensorId& id) { +absl::Status ValidateTensorId(const tf2xla::TensorId& id) { if (id.node_name().empty()) { return errors::InvalidArgument("TensorId node_name must be non-empty"); } @@ -58,8 +58,8 @@ Status ValidateTensorId(const tf2xla::TensorId& id) { return absl::OkStatus(); } -Status CheckNameDuplicates(const string& kind, const string& name, - std::set* names) { +absl::Status CheckNameDuplicates(const string& kind, const string& name, + std::set* names) { if (!name.empty()) { if (!names->insert(name).second) { return errors::InvalidArgument("duplicate ", kind, " name: ", name); @@ -68,8 +68,8 @@ Status CheckNameDuplicates(const string& kind, const string& name, return absl::OkStatus(); } -Status CheckFeedFetchNameConflicts(const string& kind, - const std::set& names) { +absl::Status CheckFeedFetchNameConflicts(const string& kind, + const std::set& names) { // We don't allow the feeds or fetches to contain both "foo" and "foo_data", // since that will cause a collision in codegen symbols. for (const string& name : names) { @@ -84,9 +84,9 @@ Status CheckFeedFetchNameConflicts(const string& kind, // For graph `g`, copy all function call nodes' FunctionDef from `lookup_fld` to // `fld`. This is to ensure that `fld` can instantiate FunctionDef of graph `g`. -Status CopyAssociatedFunctions(Graph* g, - const FunctionLibraryDefinition* lookup_fld, - FunctionLibraryDefinition* fld) { +absl::Status CopyAssociatedFunctions( + Graph* g, const FunctionLibraryDefinition* lookup_fld, + FunctionLibraryDefinition* fld) { for (Node* n : g->op_nodes()) { for (const auto& associated_function : GetAssociatedFunctions(*n, lookup_fld)) { @@ -127,8 +127,8 @@ absl::StatusOr ReplaceEdge(Graph* g, Node* dst, int dst_input, // Replaces usages of the given `src_output` index of the given `src` node with // the given `replacement` node (assumes the :0 output of `replacement`). -Status ReplaceSrcOutputUsageWithNode(Graph* g, Node* src, int src_output, - Node* replacement) { +absl::Status ReplaceSrcOutputUsageWithNode(Graph* g, Node* src, int src_output, + Node* replacement) { VLOG(1) << "Replace usages of output " << src_output << " of node " << (VLOG_IS_ON(3) ? src->DebugString() : src->name()) << " with " << (VLOG_IS_ON(3) ? replacement->DebugString() : replacement->name()); @@ -167,7 +167,7 @@ Status ReplaceSrcOutputUsageWithNode(Graph* g, Node* src, int src_output, // For graph `g`, replaces _Arg nodes whose "index" attribute is in // `const_input_index_to_node` with Const nodes. -Status ReplaceArgUsageWithConstNode( +absl::Status ReplaceArgUsageWithConstNode( Graph* g, const absl::flat_hash_map& const_input_index_to_node) { // Collect all _Arg nodes. @@ -196,7 +196,7 @@ Status ReplaceArgUsageWithConstNode( // Replaces the single input to _Retval nodes with an index in the keys of // const_input_index_to_node with the single output of the corresponding _Arg // node. -Status ReplaceRetvalInputWithArg( +absl::Status ReplaceRetvalInputWithArg( Graph* g, const absl::flat_hash_map& const_input_index_to_node) { absl::flat_hash_map arg_nodes; @@ -226,7 +226,7 @@ Status ReplaceRetvalInputWithArg( // For a node's function attr (e.g. then/else branch for "If" nodes), rewrites // the function to replace _Arg nodes in `const_input_index_to_node` with Const // inputs. -Status PropagateConstIntoFuncAttr( +absl::Status PropagateConstIntoFuncAttr( Node* n, const string& attr_name, const absl::flat_hash_map& const_input_index_to_node, const FunctionLibraryDefinition* lookup_fld, FunctionLibraryDefinition* fld, @@ -281,9 +281,9 @@ Status PropagateConstIntoFuncAttr( // For an "If" node in graph `g`, if it has Const node inputs, rewrite its // then/else branch function to replace _Arg nodes with those Const inputs. -Status PropagateConstIntoIfNode(Graph* g, Node* if_node, - const FunctionLibraryDefinition* lookup_fld, - FunctionLibraryDefinition* fld) { +absl::Status PropagateConstIntoIfNode( + Graph* g, Node* if_node, const FunctionLibraryDefinition* lookup_fld, + FunctionLibraryDefinition* fld) { // Notice that first input for If node is predicate; other inputs are function // inputs. absl::flat_hash_map const_input_index_to_node; @@ -326,8 +326,8 @@ absl::StatusOr FindOrInsert( return errors::Internal("Traverse: Cannot find body function ", name); } std::unique_ptr fbody; - Status s = FunctionDefToBodyHelper(*body_func, AttrSlice(&body_attr.attr()), - lookup_fld, &fbody); + absl::Status s = FunctionDefToBodyHelper( + *body_func, AttrSlice(&body_attr.attr()), lookup_fld, &fbody); if (!s.ok() && fallback_fld != nullptr) { TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( *body_func, AttrSlice(&body_attr.attr()), fallback_fld, &fbody)); @@ -401,7 +401,7 @@ absl::StatusOr IsLoopInvariant( // For a "While" node in graph `g`, if it has Const node inputs, rewrite its // cond/body function to replace _Arg nodes with those Const inputs. Then, // propagate these Const to consumers of the relevant outputs of the while loop. -Status PropagateConstIntoAndAroundWhileNode( +absl::Status PropagateConstIntoAndAroundWhileNode( Graph* g, Node* while_node, const FunctionLibraryDefinition* lookup_fld, FunctionLibraryDefinition* fld) { VLOG(1) << "Propagate const into " << while_node->name(); @@ -486,7 +486,7 @@ absl::StatusOr IsLoopInvariant( /*fallback_fld=*/nullptr, &cache); } -Status ValidateConfig(const tf2xla::Config& config) { +absl::Status ValidateConfig(const tf2xla::Config& config) { std::set names; for (const tf2xla::Feed& feed : config.feed()) { TF_RETURN_IF_ERROR(ValidateTensorId(feed.id())); @@ -506,7 +506,7 @@ Status ValidateConfig(const tf2xla::Config& config) { return absl::OkStatus(); } -Status AddPlaceholdersForFeeds( +absl::Status AddPlaceholdersForFeeds( const tf2xla::Config& config, const OpRegistryInterface* op_registry, std::unordered_map* feed_remapping, GraphDef* graph_def) { struct PlaceholderInfo { @@ -603,8 +603,8 @@ Status AddPlaceholdersForFeeds( return absl::OkStatus(); } -Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, - GraphDef* out) { +absl::Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, + GraphDef* out) { *out = in; out->clear_node(); @@ -672,7 +672,7 @@ string TensorIdToString(const tf2xla::TensorId& id) { return absl::StrCat(id.node_name(), ":", id.output_index()); } -Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) { +absl::Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) { int core = -1; const Node* matching_node = nullptr; for (const Edge* edge : (out_edges ? n->out_edges() : n->in_edges())) { @@ -792,7 +792,7 @@ std::vector GetAssociatedFunctions( return results; } -Status RewriteAssociatedFunction( +absl::Status RewriteAssociatedFunction( Graph* graph, Node* node, FunctionLibraryDefinition* fld, const AssociatedFunctionInfo& associated_function, const string& rewritten_function_name) { @@ -862,7 +862,7 @@ Status RewriteAssociatedFunction( return absl::OkStatus(); } -Status CachedFunctionHandles::GetOrInstantiate( +absl::Status CachedFunctionHandles::GetOrInstantiate( const string& func_name, AttrSlice attrs, FunctionLibraryRuntime::Handle* handle) { string canonicalized_name = Canonicalize(func_name, attrs); @@ -877,8 +877,8 @@ Status CachedFunctionHandles::GetOrInstantiate( return absl::OkStatus(); } -Status CachedFunctionHandles::ReleaseAllHandles() { - Status result; +absl::Status CachedFunctionHandles::ReleaseAllHandles() { + absl::Status result; for (const auto& iter : handles_) { result.Update(flr_->ReleaseHandle(iter.second)); } @@ -936,7 +936,7 @@ absl::StatusOr BuildIdentityNode( return id_node; } -Status PropagateConstIntoFunctionalNodes( +absl::Status PropagateConstIntoFunctionalNodes( Graph* g, const FunctionLibraryDefinition* lookup_fld, FunctionLibraryDefinition* fld) { absl::flat_hash_set done_node_ids; @@ -969,8 +969,8 @@ Status PropagateConstIntoFunctionalNodes( return absl::OkStatus(); } -Status PruneUnreachableFunctionsFromGraph(const Graph& g, - FunctionLibraryDefinition* fld) { +absl::Status PruneUnreachableFunctionsFromGraph( + const Graph& g, FunctionLibraryDefinition* fld) { GraphDef graph_def; g.ToGraphDef(&graph_def); FunctionLibraryDefinition reachable_functions = @@ -983,8 +983,8 @@ Status PruneUnreachableFunctionsFromGraph(const Graph& g, return absl::OkStatus(); } -Status RewriteTensorListWithConstElement(Graph* g, - FunctionLibraryDefinition* fld) { +absl::Status RewriteTensorListWithConstElement(Graph* g, + FunctionLibraryDefinition* fld) { for (Node* n : g->nodes()) { if (n->type_string() != "EmptyTensorList") { continue; diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index 707997e614c4c1..f2ce3944ac158c 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -32,21 +32,21 @@ limitations under the License. namespace tensorflow { // ValidateConfig returns OK iff config is valid. -Status ValidateConfig(const tf2xla::Config& config); +absl::Status ValidateConfig(const tf2xla::Config& config); // Modifies to include placeholders for each fed tensor, and // update references to the fed tensors to refer to the placeholders. // The existing nodes referenced by the feeds are not removed or modified // (except where their input edges are modified by the replacement of other // feeds). -Status AddPlaceholdersForFeeds( +absl::Status AddPlaceholdersForFeeds( const tf2xla::Config& config, const OpRegistryInterface* op_registry, std::unordered_map* feed_remapping, GraphDef* graph_def); // Returns in a copy of , pruned to only include fetches from // . -Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, - GraphDef* out); +absl::Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, + GraphDef* out); // Returns node:port for the given . string TensorIdToString(const tf2xla::TensorId& id); @@ -54,7 +54,7 @@ string TensorIdToString(const tf2xla::TensorId& id); // Updates the sharding of based on the sharding of its neighbors. // If is true, outgoing edges from are considered; else incoming // edges are considered. -Status SetNodeShardingFromNeighbors(Node* n, bool out_edges); +absl::Status SetNodeShardingFromNeighbors(Node* n, bool out_edges); // Add an allowed data type to the AttrConstraint with the given name. void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype, @@ -139,7 +139,7 @@ std::vector GetAssociatedFunctions( // 2. For SymbolicGradient op, add or replace GradientDef in // FunctionLibraryDefinition; // 3. For nodes like XlaWhile/XlaIf, modify their function attributes. -Status RewriteAssociatedFunction( +absl::Status RewriteAssociatedFunction( Graph* graph, Node* node, FunctionLibraryDefinition* fld, const AssociatedFunctionInfo& associated_function, const string& rewritten_function_name); @@ -152,12 +152,12 @@ class CachedFunctionHandles { // Populates `handle` for requested function and attributes. If we have // instantiated the function with the same attributes before, `handle` will be // cached handle; otherwise instantiate the function and populate `handle`. - Status GetOrInstantiate(const string& func_name, AttrSlice attrs, - FunctionLibraryRuntime::Handle* handle); + absl::Status GetOrInstantiate(const string& func_name, AttrSlice attrs, + FunctionLibraryRuntime::Handle* handle); // Releases all handles in the cache. Returns first non-OK status if any; // returns OK otherwise. - Status ReleaseAllHandles(); + absl::Status ReleaseAllHandles(); ~CachedFunctionHandles() { ReleaseAllHandles().IgnoreError(); } @@ -193,13 +193,13 @@ absl::StatusOr BuildIdentityNode(Graph* graph, const string& node_name, // input for tf.ones/tf.zeros. But XLA requires that shape input to be compile // time constant, so XLA compilation will fail. This rewriting process will // change the shape input to Const node. -Status PropagateConstIntoFunctionalNodes( +absl::Status PropagateConstIntoFunctionalNodes( Graph* g, const FunctionLibraryDefinition* lookup_fld, FunctionLibraryDefinition* fld); // Prunes unreachable FunctionDefs from FunctionLibraryDefinition. -Status PruneUnreachableFunctionsFromGraph(const Graph& g, - FunctionLibraryDefinition* fld); +absl::Status PruneUnreachableFunctionsFromGraph(const Graph& g, + FunctionLibraryDefinition* fld); // Finds the following pattern in the graph: // 1) EmptyTensorList -> forward While op -> backward While op, @@ -208,8 +208,8 @@ Status PruneUnreachableFunctionsFromGraph(const Graph& g, // And rewrites backward While op to use Const node instead of TensorListPopBack // result. // TODO(b/128633174) remove the TensorList and related TensorList ops. -Status RewriteTensorListWithConstElement(Graph* g, - FunctionLibraryDefinition* fld); +absl::Status RewriteTensorListWithConstElement(Graph* g, + FunctionLibraryDefinition* fld); inline bool IsConstTraversableOpType(const Node* node) { return node->type_string() == "Identity" || diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index d1ea22324c7e8c..e66a8a38813474 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -40,7 +40,7 @@ limitations under the License. namespace tensorflow { namespace { -void ExpectErrorContains(const Status& status, absl::string_view str) { +void ExpectErrorContains(const absl::Status& status, absl::string_view str) { EXPECT_NE(absl::OkStatus(), status); EXPECT_TRUE(absl::StrContains(status.message(), str)) << "expected error: " << status.message() << " to contain: " << str; diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index 655a2c3cdec160..6383d277be852d 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -24,7 +24,8 @@ limitations under the License. namespace tensorflow { -Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { +absl::Status DataTypeToPrimitiveType(DataType data_type, + xla::PrimitiveType* type) { switch (data_type) { case tensorflow::DT_BOOL: *type = xla::PRED; diff --git a/tensorflow/compiler/tf2xla/type_util.h b/tensorflow/compiler/tf2xla/type_util.h index 505d6b8ed56747..a3027a5fe9a1b8 100644 --- a/tensorflow/compiler/tf2xla/type_util.h +++ b/tensorflow/compiler/tf2xla/type_util.h @@ -24,7 +24,8 @@ limitations under the License. namespace tensorflow { // Converts a Tensorflow DataType to an XLA PrimitiveType. -Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type); +absl::Status DataTypeToPrimitiveType(DataType data_type, + xla::PrimitiveType* type); // Converts an XLA PrimitiveType to a TensorFlow DataType. // Caution: The mapping from TF types to XLA types is not one-to-one: for diff --git a/tensorflow/compiler/tf2xla/xla_argument.h b/tensorflow/compiler/tf2xla/xla_argument.h index 49ee115b953d6d..9e2eccd29b1885 100644 --- a/tensorflow/compiler/tf2xla/xla_argument.h +++ b/tensorflow/compiler/tf2xla/xla_argument.h @@ -20,7 +20,7 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_resource.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_sharding.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index bbdf5c7d2c74fa..215decdb4d8843 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/platform/mem.h" @@ -140,9 +140,9 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel, VLOG(4) << "Done"; } -Status XlaCompilationDevice::Sync() { return absl::OkStatus(); } +absl::Status XlaCompilationDevice::Sync() { return absl::OkStatus(); } -Status XlaCompilationDevice::MakeTensorFromProto( +absl::Status XlaCompilationDevice::MakeTensorFromProto( const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, Tensor* tensor) { return errors::InvalidArgument( diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h index de6a3356e05d8a..e3f6571c3039c6 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.h +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h @@ -54,11 +54,11 @@ class XlaCompilationDevice : public LocalDevice { void Compute(OpKernel* op_kernel, OpKernelContext* context) override; - Status Sync() override; + absl::Status Sync() override; - Status MakeTensorFromProto(const TensorProto& tensor_proto, - const AllocatorAttributes alloc_attrs, - Tensor* tensor) override; + absl::Status MakeTensorFromProto(const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor) override; private: std::unique_ptr allocator_; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 970cc1775cc810..c51107fb9deaff 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -33,6 +33,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "absl/types/variant.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/flags.h" @@ -53,12 +54,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/client/client_library.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/protobuf_util.h" #include "xla/service/hlo.pb.h" #include "xla/shape_util.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function.h" @@ -75,6 +77,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/protobuf/error_codes.pb.h" #include "tensorflow/core/tpu/tpu_defs.h" #include "tensorflow/core/util/debug_data_dumper.h" @@ -92,8 +95,8 @@ constexpr char kCompileFunctionComponent[] = "TF2XLA_XLA_COMPILER_COMPILE_FUNCTION"; // Checks that arguments `args` match types `types`. -Status CheckSignature(const DataTypeVector& types, - absl::Span args) { +absl::Status CheckSignature(const DataTypeVector& types, + absl::Span args) { if (args.size() != types.size()) { return errors::Internal("Compilation arguments have ", args.size(), " elements while function has ", types.size()); @@ -151,9 +154,9 @@ ComputeArgAndRetvalShardings(const Graph& graph) { // cleaned up here need to change how resources are cleaned up in // graph_compiler_test. // LINT.IfChange(ExecuteGraph) -Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, - XlaCompilationDevice* device, FunctionLibraryRuntime* flib, - int64_t step_id) { +absl::Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, + XlaCompilationDevice* device, + FunctionLibraryRuntime* flib, int64_t step_id) { // Resource cleanup is a bit messy. XlaContext is a ref-countd resource; the // resource manager takes ownership via Create, and unrefs via Cleanup. We // explicitly add a reference to ensure the refcount at entry is maintained at @@ -162,7 +165,7 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, // The Executor requires us to use ScopedStepContainer. We wrap it in a // unique_ptr so we can capture the cleanup status in the end. xla_context->Ref(); - Status status; + absl::Status status; auto step_container = std::make_unique( step_id, [&status, device](const string& name) { status = device->resource_manager()->Cleanup(name); @@ -194,7 +197,7 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, // `resource_updates` is a ResourceUpdate, whose `index` is the index of a // resource variable argument to the computation to be updated, and `type` is // the type of the final output. -Status BuildComputation( +absl::Status BuildComputation( const std::vector& args, const std::vector& retvals, const std::map& arg_shardings, @@ -572,9 +575,9 @@ uint64 XlaCompiler::SignatureHash::operator()( return std::hash()(signature.first); } -static Status GetFunctionBody(const NameAttrList& function, - FunctionLibraryRuntime* flib_runtime, - const FunctionBody** fbody) { +static absl::Status GetFunctionBody(const NameAttrList& function, + FunctionLibraryRuntime* flib_runtime, + const FunctionBody** fbody) { FunctionLibraryRuntime::Handle handle; TF_RETURN_IF_ERROR(flib_runtime->Instantiate( function.name(), AttrSlice(&function.attr()), &handle)); @@ -584,9 +587,9 @@ static Status GetFunctionBody(const NameAttrList& function, return absl::OkStatus(); } -Status XlaCompiler::FindFunctionBody(const NameAttrList& function, - const FunctionBody** fbody, - const ConfigProto** config_proto) { +absl::Status XlaCompiler::FindFunctionBody(const NameAttrList& function, + const FunctionBody** fbody, + const ConfigProto** config_proto) { // The function may be in either the local_flib_runtime_ or flib_runtime_. // Look up the function in local first and if it is not found then look up the // function in flib_runtime_. @@ -756,7 +759,7 @@ std::vector GetValidControlRets( return valid_control_rets; } -Status XlaCompiler::CompileSingleOp( +absl::Status XlaCompiler::CompileSingleOp( const XlaCompiler::CompileOptions& compile_options, const XlaCompiler::SingleOpCompileArgument& single_op_compile_argument, absl::Span args, XlaCompiler::CompilationResult* result) { @@ -767,7 +770,7 @@ Status XlaCompiler::CompileSingleOp( single_op_compile_argument.output_dtypes)); *result = {}; - Status status = ADD_SOURCE_LOCATION(CompileGraph( + absl::Status status = ADD_SOURCE_LOCATION(CompileGraph( compile_options, node_def.name(), std::move(graph), args, result)); if (status.ok()) { tensorflow::metrics::IncrementPhase2XlaCompilerCounter( @@ -784,7 +787,7 @@ Status XlaCompiler::CompileSingleOp( return status; } -Status XlaCompiler::CompileFunction( +absl::Status XlaCompiler::CompileFunction( const XlaCompiler::CompileOptions& options, const NameAttrList& fn_name_attrs, absl::Span args, @@ -900,7 +903,7 @@ Status XlaCompiler::CompileFunction( } // Computes the XLA shape for argument 'arg'. -Status XlaCompiler::XLAShapeForArgument( +absl::Status XlaCompiler::XLAShapeForArgument( const XlaCompiler::Argument& arg, bool is_entry_computation, const std::optional& arg_sharding, xla::Shape* xla_shape) const { @@ -1053,7 +1056,7 @@ XlaCompiler::SingleOpCompileArgument::SingleOpCompileArgument( // Builds XLA computations for each of the arguments to the computation. // `args` are the arguments to the computation. -Status XlaCompiler::BuildArguments( +absl::Status XlaCompiler::BuildArguments( const Graph& graph, const std::vector& args, bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context, const std::map& arg_shardings, @@ -1280,8 +1283,8 @@ Status XlaCompiler::BuildArguments( namespace { // Check that the ops of all non-functional nodes have been registered. -Status ValidateFunctionDef(const FunctionDef* fdef, - const FunctionLibraryDefinition& flib_def) { +absl::Status ValidateFunctionDef(const FunctionDef* fdef, + const FunctionLibraryDefinition& flib_def) { for (const NodeDef& node : fdef->node_def()) { const string& op = node.op(); if (op == FunctionLibraryDefinition::kGradientOp || flib_def.Find(op)) { @@ -1298,7 +1301,7 @@ Status ValidateFunctionDef(const FunctionDef* fdef, // Returned pointer points to the internal string either in node's attributes // or in its NodeDef. This pointer is valid as long as the node has not been // modified. -Status GetPotentialFunctionName(const Node& node, const string** name) { +absl::Status GetPotentialFunctionName(const Node& node, const string** name) { if (node.IsPartitionedCall()) { const AttrValue* attr_value; TF_RETURN_IF_ERROR( @@ -1317,14 +1320,15 @@ Status GetPotentialFunctionName(const Node& node, const string** name) { // Check that the graph doesn't have any invalid nodes (e.g. incompatible with // given device_type, invalid data type, missing attributes...) -Status ValidateGraph(const Graph* graph, - const FunctionLibraryDefinition& flib_def, - const DeviceType& device_type, const string& name) { +absl::Status ValidateGraph(const Graph* graph, + const FunctionLibraryDefinition& flib_def, + const DeviceType& device_type, const string& name) { // Make sure the XLA compilation kernels are registered. This operation is // idempotent so it is fine if someone called it already. XlaOpRegistry::RegisterCompilationKernels(); - auto maybe_error = [&](const Node* node, const Status& s) -> Status { + auto maybe_error = [&](const Node* node, + const absl::Status& s) -> absl::Status { if (!s.ok()) { std::string errmsg = absl::StrCat( "Detected unsupported operations when trying to compile graph ", name, @@ -1358,7 +1362,7 @@ Status ValidateGraph(const Graph* graph, const string* function_name; TF_RETURN_IF_ERROR(GetPotentialFunctionName(*node, &function_name)); const FunctionDef* fdef = flib_def.Find(*function_name); - Status s; + absl::Status s; if (fdef) { s = ValidateFunctionDef(fdef, flib_def); TF_RETURN_IF_ERROR(maybe_error(node, s)); @@ -1443,11 +1447,10 @@ void IncreasePrecisionsToAvoidTF32(xla::HloModuleProto& module) { } // namespace -Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, - string const& name, - std::unique_ptr graph, - absl::Span args, - CompilationResult* result) { +absl::Status XlaCompiler::CompileGraph( + const XlaCompiler::CompileOptions& options, string const& name, + std::unique_ptr graph, absl::Span args, + CompilationResult* result) { VLOG(1) << "Executing graph symbolically to populate XlaBuilder.: " << name; if (VLOG_IS_ON(2) || DEBUG_DATA_DUMPER()->ShouldDump(name, kDebugGroupMain)) { VLOG(2) << "XlaCompiler::CompileGraph: " @@ -1531,8 +1534,8 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, } } - Status execute_status = ExecuteGraph(context, std::move(graph), device_, - flib_runtime_, NextStepId()); + absl::Status execute_status = ExecuteGraph(context, std::move(graph), device_, + flib_runtime_, NextStepId()); if (!execute_status.ok()) { VLOG(1) << "Failed executing graph " << name; return execute_status; @@ -1609,35 +1612,43 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, return absl::OkStatus(); } -Status XlaCompiler::GetChannelHandle(const string& key, - xla::ChannelHandle* channel) { +xla::ChannelHandle XlaCompiler::NewChannel( + xla::ChannelHandle::ChannelType type) { + xla::ChannelHandle new_handle; + absl::MutexLock lock(&channel_mutex_); + // Create a new channel handle with a unique value. + new_handle.set_handle(next_channel_++); + new_handle.set_type(type); + return new_handle; +} + +absl::Status XlaCompiler::GetChannelHandle(const string& key, + xla::ChannelHandle* channel) { auto result = channels_.emplace(key, xla::ChannelHandle()); if (result.second) { - TF_ASSIGN_OR_RETURN(result.first->second, client()->CreateChannelHandle()); + result.first->second = NewChannel(xla::ChannelHandle::DEVICE_TO_DEVICE); } *channel = result.first->second; VLOG(1) << "Channel: " << key << " " << channel->DebugString(); return absl::OkStatus(); } -Status XlaCompiler::GetHostToDeviceChannelHandle(const string& key, - xla::ChannelHandle* channel) { +absl::Status XlaCompiler::GetHostToDeviceChannelHandle( + const string& key, xla::ChannelHandle* channel) { auto result = channels_.emplace(key, xla::ChannelHandle()); if (result.second) { - TF_ASSIGN_OR_RETURN(result.first->second, - client()->CreateHostToDeviceChannelHandle()); + result.first->second = NewChannel(xla::ChannelHandle::HOST_TO_DEVICE); } *channel = result.first->second; VLOG(1) << "Host to device channel: " << key << " " << channel->DebugString(); return absl::OkStatus(); } -Status XlaCompiler::GetDeviceToHostChannelHandle(const string& key, - xla::ChannelHandle* channel) { +absl::Status XlaCompiler::GetDeviceToHostChannelHandle( + const string& key, xla::ChannelHandle* channel) { auto result = channels_.emplace(key, xla::ChannelHandle()); if (result.second) { - TF_ASSIGN_OR_RETURN(result.first->second, - client()->CreateDeviceToHostChannelHandle()); + result.first->second = NewChannel(xla::ChannelHandle::DEVICE_TO_HOST); } *channel = result.first->second; VLOG(1) << "Device to host channel: " << key << " " << channel->DebugString(); @@ -1660,7 +1671,7 @@ void SetTransfer(const string& key, absl::Span types, } // namespace -Status XlaCompiler::SetDeviceToHostMetadata( +absl::Status XlaCompiler::SetDeviceToHostMetadata( const string& key, absl::Span types, absl::Span shapes) { if (host_compute_sends_.find(key) != host_compute_sends_.end()) { @@ -1679,7 +1690,7 @@ Status XlaCompiler::SetDeviceToHostMetadata( return absl::OkStatus(); } -Status XlaCompiler::GetDeviceToHostShapes( +absl::Status XlaCompiler::GetDeviceToHostShapes( const string& key, std::vector* shapes) const { const auto iter = host_compute_sends_.find(key); if (iter == host_compute_sends_.end()) { @@ -1694,7 +1705,7 @@ Status XlaCompiler::GetDeviceToHostShapes( return absl::OkStatus(); } -Status XlaCompiler::SetHostToDeviceMetadata( +absl::Status XlaCompiler::SetHostToDeviceMetadata( const string& key, absl::Span types, absl::Span shapes) { if (host_compute_recvs_.find(key) != host_compute_recvs_.end()) { @@ -1713,7 +1724,7 @@ Status XlaCompiler::SetHostToDeviceMetadata( return absl::OkStatus(); } -Status XlaCompiler::GetHostComputeControlDependency( +absl::Status XlaCompiler::GetHostComputeControlDependency( const string& host_compute_name, xla::XlaOp* handle) { const auto iter = host_compute_control_output_.find(host_compute_name); if (iter == host_compute_control_output_.end()) { @@ -1726,7 +1737,7 @@ Status XlaCompiler::GetHostComputeControlDependency( return absl::OkStatus(); } -Status XlaCompiler::SetHostComputeControlDependency( +absl::Status XlaCompiler::SetHostComputeControlDependency( const string& host_compute_name, const xla::XlaOp handle) { if (host_compute_control_output_.find(host_compute_name) != host_compute_control_output_.end()) { @@ -1742,7 +1753,7 @@ void XlaCompiler::PushNodeTokenMapping() { node_token_mapping_stack_.emplace(std::map{}); } -Status XlaCompiler::PopNodeTokenMapping() { +absl::Status XlaCompiler::PopNodeTokenMapping() { if (node_token_mapping_stack_.empty()) { return errors::FailedPrecondition( "Calling PopNodeTokenMapping() when node_token_mapping_stack_ is " @@ -1752,7 +1763,8 @@ Status XlaCompiler::PopNodeTokenMapping() { return absl::OkStatus(); } -Status XlaCompiler::SetNodeToken(const string& node_name, const xla::XlaOp op) { +absl::Status XlaCompiler::SetNodeToken(const string& node_name, + const xla::XlaOp op) { if (node_token_mapping_stack_.empty()) { return errors::FailedPrecondition( "Calling SetNodeToken() when node_token_mapping_stack_ is " diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 24b0c1d8eced91..cbb57f38b7b2e8 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "absl/types/variant.h" #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" @@ -28,8 +29,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/status_macros.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -177,7 +178,8 @@ class XlaCompiler { // compilation device's resource manager when the compilation // device is created, and can be used to create metadata objects // that can be accessed by XLA op kernels. - std::function* populate_resource_manager = nullptr; + std::function* populate_resource_manager = + nullptr; // If not nullptr, this memory allocator can be used by the compiler for // temporary allocations it might want to make during compilation. @@ -228,12 +230,12 @@ class XlaCompiler { static void PopulateArgumentFromResource(const XlaResource& resource, Argument* arg); - Status CompileFunction(const CompileOptions& options, - const NameAttrList& fn_name_attrs, - absl::Span args, - CompilationResult* result); + absl::Status CompileFunction(const CompileOptions& options, + const NameAttrList& fn_name_attrs, + absl::Span args, + CompilationResult* result); - Status CompileSingleOp( + absl::Status CompileSingleOp( const CompileOptions& options, const SingleOpCompileArgument& single_op_compile_argument, absl::Span args, CompilationResult* result); @@ -241,15 +243,15 @@ class XlaCompiler { // Compiles a tensorflow::Graph into an xla::XlaComputation. // Similar to CompileFunction, but takes a Graph as input rather than a // function. - Status CompileGraph( - const CompileOptions& options, string const& name, - std::unique_ptr graph, absl::Span args, - CompilationResult* result); + absl::Status CompileGraph(const CompileOptions& options, string const& name, + std::unique_ptr graph, + absl::Span args, + CompilationResult* result); // Returns the shape of the XLA parameter for an argument 'arg'. // See the class comment for more details about the argument passing // convention. - Status XLAShapeForArgument( + absl::Status XLAShapeForArgument( const Argument& arg, bool is_entry_computation, const std::optional& arg_sharding, xla::Shape* xla_shape) const; @@ -259,33 +261,33 @@ class XlaCompiler { // Channel handles can be used to communicate between different // computations. Computations that communicate should be compiled with the // same XlaCompiler. - Status GetChannelHandle(const string& key, xla::ChannelHandle* channel); + absl::Status GetChannelHandle(const string& key, xla::ChannelHandle* channel); // Retrieves the host-to-device channel handle associated with `key`. // Allocates a new channel handle if none exists. - Status GetHostToDeviceChannelHandle(const string& key, - xla::ChannelHandle* channel); + absl::Status GetHostToDeviceChannelHandle(const string& key, + xla::ChannelHandle* channel); // Retrieves the device-to-host channel handle associated with `key`. // Allocates a new channel handle if none exists. - Status GetDeviceToHostChannelHandle(const string& key, - xla::ChannelHandle* channel); + absl::Status GetDeviceToHostChannelHandle(const string& key, + xla::ChannelHandle* channel); // Sets the shapes and types for the device to host transfer associated with // 'key'. - Status SetDeviceToHostMetadata(const string& key, - absl::Span types, - absl::Span shapes); + absl::Status SetDeviceToHostMetadata(const string& key, + absl::Span types, + absl::Span shapes); // Gets the shapes the device to host transfer associated with 'key'. - Status GetDeviceToHostShapes(const string& key, - std::vector* shapes) const; + absl::Status GetDeviceToHostShapes(const string& key, + std::vector* shapes) const; // Sets the shapes and types for the host to device transfer associated with // 'key'. - Status SetHostToDeviceMetadata(const string& key, - absl::Span types, - absl::Span shapes); + absl::Status SetHostToDeviceMetadata(const string& key, + absl::Span types, + absl::Span shapes); // In order to avoid deadlocks from dependencies in host computations, it can // be necessary to enforce a partial order on the execution of HostCompute @@ -298,40 +300,41 @@ class XlaCompiler { // 'host_compute_name' can be any string the client wishes to use to identify // a given HostCompute Op as long as the names are unique within the // compilation. - Status GetHostComputeControlDependency(const string& host_compute_name, - xla::XlaOp* handle); - Status SetHostComputeControlDependency(const string& host_compute_name, - xla::XlaOp handle); + absl::Status GetHostComputeControlDependency(const string& host_compute_name, + xla::XlaOp* handle); + absl::Status SetHostComputeControlDependency(const string& host_compute_name, + xla::XlaOp handle); const Options& options() const { return options_; } xla::Client* client() const { return options_.client; } FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; } void PushNodeTokenMapping(); - Status PopNodeTokenMapping(); - Status SetNodeToken(const string& node_name, xla::XlaOp op); + absl::Status PopNodeTokenMapping(); + absl::Status SetNodeToken(const string& node_name, xla::XlaOp op); absl::StatusOr GetNodeToken(const string& node_name); // Sets the function body `fbody` to the one registered as `function`. - Status FindFunctionBody(const NameAttrList& function, - const FunctionBody** fbody, - const ConfigProto** config_proto = nullptr); + absl::Status FindFunctionBody(const NameAttrList& function, + const FunctionBody** fbody, + const ConfigProto** config_proto = nullptr); private: + absl::Mutex channel_mutex_; // Returns the optimized graph object in this function body. std::unique_ptr GetGraph(const FunctionBody* fbody); // Builds XLA computations for each of the arguments to the computation. // `args` are the arguments to the computation. - Status BuildArguments(const Graph& graph, - const std::vector& args, - bool use_tuple_arg, xla::XlaBuilder* builder, - XlaContext* context, - const std::map& arg_shardings, - std::vector* arg_expressions, - std::vector* input_to_args, - std::vector* input_shapes, - bool is_entry_computation); + absl::Status BuildArguments( + const Graph& graph, const std::vector& args, + bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context, + const std::map& arg_shardings, + std::vector* arg_expressions, + std::vector* input_to_args, std::vector* input_shapes, + bool is_entry_computation); + + xla::ChannelHandle NewChannel(xla::ChannelHandle::ChannelType type); // Graph compiler needs to know how to get an optimized graph from a function // body. @@ -341,7 +344,7 @@ class XlaCompiler { Options options_; // Status set to non-OK in the constructor if initialization fails. - Status initialization_status_; + absl::Status initialization_status_; // Returns the next step sequence number. int64_t NextStepId(); @@ -352,6 +355,9 @@ class XlaCompiler { XlaCompilationDevice* device_; // Owned by device_mgr_ StaticDeviceMgr device_mgr_; + // The next sequence number to assign to a channel. + int64_t next_channel_ ABSL_GUARDED_BY(channel_mutex_) = 1; + // To avoid copying the client's function library, use a local function // library and runtime for functions created as part of the functionalize // control flow transformation. diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index d004236e22e22a..02c86294017a17 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include #include "absl/strings/match.h" +#include "absl/strings/str_cat.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/ops/data_flow_ops.h" @@ -33,13 +35,14 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/client/client_library.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_proto_util.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/tests/literal_test_util.h" +#include "xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/common_shape_fns.h" @@ -618,7 +621,7 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { XlaCompiler compiler(DefaultOptions()); XlaCompiler::CompilationResult result; - Status status = + absl::Status status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "reshape", std::move(graph), args, &result); EXPECT_FALSE(status.ok()); @@ -708,7 +711,7 @@ TEST_F(XlaCompilerTest, ConstantOutputsOfFunctionalNode) { foo.set_name("foo"); foo.set_op("foo"); *foo.add_input() = "input_arg"; - Status status; + absl::Status status; scope.graph()->AddNode(foo, &status); TF_ASSERT_OK(status); NodeDef retval_1; @@ -771,7 +774,7 @@ TEST_F(XlaCompilerTest, ResourceManager) { // Compiles the graph. auto options = DefaultOptions(); - std::function populate_function = + std::function populate_function = [resource](ResourceMgr* rm) { resource->Ref(); return rm->Create(rm->default_container(), "dummy", resource); @@ -988,7 +991,7 @@ TEST_F(XlaCompilerTest, UndefinedFunctionFails) { XlaCompiler::CompilationResult result; NameAttrList name_attr; name_attr.set_name("Function_NotDefined_"); - Status status = + absl::Status status = compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr, /*args=*/{}, &result); EXPECT_FALSE(status.ok()); @@ -1034,7 +1037,7 @@ TEST_F(XlaCompilerTest, FunctionCallWithConstants) { .Input(value.name(), 0, DT_INT32) .Input(shape.name(), 1, DT_INT32) .Finalize(&def)); - Status status; + absl::Status status; Node* fill = scope.graph()->AddNode(def, &status); TF_ASSERT_OK(status); TF_ASSERT_OK(scope.DoShapeInference(fill)); @@ -1065,7 +1068,7 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) { XlaCompiler::CompilationResult result; NameAttrList name_attr; name_attr.set_name("XTimesTwo"); - Status status = + absl::Status status = compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr, /*args=*/{}, &result); @@ -1121,7 +1124,7 @@ TEST_F(XlaCompilerTest, SliceWithDynamicBegins) { .Input(begin.node()->name(), 1, DT_INT32) .Input(size.name(), 2, DT_INT32) .Finalize(&def)); - Status status; + absl::Status status; Node* slice = scope.graph()->AddNode(def, &status); TF_ASSERT_OK(status); TF_ASSERT_OK(scope.DoShapeInference(slice)); @@ -1554,7 +1557,7 @@ TEST_F(XlaCompilerTest, FunctionWithInvalidOp) { .Input(value.name(), 0, DT_INT32) .Input(shape.name(), 1, DT_INT32) .Finalize(&def)); - Status status; + absl::Status status; Node* fill = scope.graph()->AddNode(def, &status); TF_ASSERT_OK(status); TF_ASSERT_OK(scope.DoShapeInference(fill)); @@ -1584,7 +1587,7 @@ TEST_F(XlaCompilerTest, NodeWithInvalidDataType) { shape.set_op("Shape"); (*shape.mutable_attr())["T"].set_type(DT_INT32); (*shape.mutable_attr())["out_type"].set_type(DT_BOOL); /* invalid type */ - Status status; + absl::Status status; Node* shape_node = graph->AddNode(shape, &status); TF_ASSERT_OK(status); graph->AddControlEdge(graph->source_node(), shape_node); @@ -1607,7 +1610,7 @@ TEST_F(XlaCompilerTest, SingleOpWithoutInputs) { NodeDef no_op; no_op.set_name("NoOp"); no_op.set_op("NoOp"); - Status status; + absl::Status status; graph->AddNode(no_op, &status); TF_ASSERT_OK(status); @@ -1645,7 +1648,7 @@ TEST_F(XlaCompilerTest, TokenInputAndOutput) { std::vector{kXlaTokenArgNodeName}, &side_effecting_op); AddNodeAttr(kXlaOriginalOutsideCompilationNodeName, side_effecting_op.name(), &side_effecting_op); - Status status; + absl::Status status; graph->AddNode(side_effecting_op, &status); TF_ASSERT_OK(status); EXPECT_TRUE(FixupSourceAndSinkEdges(graph.get())); @@ -1998,7 +2001,7 @@ TEST_F(XlaCompilerTest, SetDeviceToHostMetadataMismatchedDuplicate) { std::vector shapes2{TensorShape({1})}; TF_ASSERT_OK(compiler.SetDeviceToHostMetadata(key, types, shapes)); - Status status = compiler.SetDeviceToHostMetadata(key, types2, shapes2); + absl::Status status = compiler.SetDeviceToHostMetadata(key, types2, shapes2); EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT); } @@ -2027,10 +2030,31 @@ TEST_F(XlaCompilerTest, SetHostToDeviceMetadataMismatchedDuplicate) { std::vector shapes2{TensorShape({1})}; TF_ASSERT_OK(compiler.SetHostToDeviceMetadata(key, types, shapes)); - Status status = compiler.SetHostToDeviceMetadata(key, types2, shapes2); + absl::Status status = compiler.SetHostToDeviceMetadata(key, types2, shapes2); EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT); } +TEST_F(XlaCompilerTest, GetChannelHandleIndependently) { + XlaCompiler compiler1(DefaultOptions()); + XlaCompiler compiler2(DefaultOptions()); + int num_channels = 3; + std::vector channel_ids1, channel_ids2; + for (int j = 0; j < num_channels; ++j) { + xla::ChannelHandle channel_handle; + TF_ASSERT_OK( + compiler1.GetChannelHandle(/*key=*/absl::StrCat(j), &channel_handle)); + channel_ids1.push_back(channel_handle.handle()); + } + for (int j = 0; j < num_channels; ++j) { + xla::ChannelHandle channel_handle; + TF_ASSERT_OK( + compiler2.GetChannelHandle(/*key=*/absl::StrCat(j), &channel_handle)); + channel_ids2.push_back(channel_handle.handle()); + } + EXPECT_THAT(channel_ids1, ::testing::UnorderedElementsAreArray({1, 2, 3})); + EXPECT_THAT(channel_ids2, ::testing::UnorderedElementsAreArray({1, 2, 3})); +} + TEST_F(OpsTestBase, BuildSingleOpCompileArgument) { TF_EXPECT_OK(NodeDefBuilder("identity_op", "Identity") .Input(FakeInput(DT_FLOAT)) diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 4a8a1e9e531e32..92ddf0125aded1 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -25,8 +25,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/client/client_library.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "tensorflow/core/common_runtime/dma_helper.h" @@ -188,7 +188,7 @@ const xla::XlaComputation* XlaContext::LookupOrCreate( } } -Status XlaContext::RecordCollectiveInfoFromNestedCompilationResult( +absl::Status XlaContext::RecordCollectiveInfoFromNestedCompilationResult( const XlaCompilationResult& result) { if (result.collective_info) { return RecordCollectiveInfo(result.collective_info->group_key, diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 66df05e29d31a3..9184fb4300633c 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -22,8 +22,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/status_macros.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" @@ -111,7 +111,7 @@ class XlaContext : public ResourceBase { static const char kXlaContextResourceName[]; // Records the collective information from the nested compilation `result`. - Status RecordCollectiveInfoFromNestedCompilationResult( + absl::Status RecordCollectiveInfoFromNestedCompilationResult( const XlaCompilationResult& result); // Records the collective configurations for all the collectives in the XLA diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc index 1407e1c79ff38e..61bd10e413ccf3 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.cc +++ b/tensorflow/compiler/tf2xla/xla_expression.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" -#include "xla/client/value_inference.h" +#include "xla/hlo/builder/value_inference.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h index b2e323db1e5a41..d410b79a3da137 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.h +++ b/tensorflow/compiler/tf2xla/xla_expression.h @@ -19,8 +19,8 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/compiler/tf2xla/xla_resource.h" #include "xla/client/client.h" -#include "xla/client/value_inference.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/value_inference.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/statusor.h" diff --git a/tensorflow/compiler/tf2xla/xla_expression_test.cc b/tensorflow/compiler/tf2xla/xla_expression_test.cc index 13a280af323cc3..7a0cc34de9af2e 100644 --- a/tensorflow/compiler/tf2xla/xla_expression_test.cc +++ b/tensorflow/compiler/tf2xla/xla_expression_test.cc @@ -13,14 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/xla_expression.h" + #include #include "absl/memory/memory.h" -#include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_resource.h" #include "xla/client/client_library.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/shape_util.h" #include "xla/status_macros.h" diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 6bf772e7dd03f9..924ca93144e924 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -26,10 +26,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/service/gpu/gpu_executable_run_options.h" #include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/stream_executor/stream.h" @@ -69,7 +69,7 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type, return ::tensorflow::FloatLiteral(b, type, value); } -/* static */ Status XlaHelpers::ReshapeLiteral( +/* static */ absl::Status XlaHelpers::ReshapeLiteral( const xla::Literal& input, absl::Span dimensions, xla::Literal* output) { if (input.shape().IsTuple()) { @@ -90,10 +90,13 @@ xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type, return absl::OkStatus(); } -Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64_t depth, int axis, - DataType index_type, const TensorShape& indices_shape, - const xla::XlaOp indices, const xla::XlaOp on_value, - const xla::XlaOp off_value, xla::XlaOp* one_hot) { +absl::Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64_t depth, + int axis, DataType index_type, + const TensorShape& indices_shape, + const xla::XlaOp indices, + const xla::XlaOp on_value, + const xla::XlaOp off_value, + xla::XlaOp* one_hot) { // Broadcast the linspace constant across the indices along the new axis, // and test equality at each position. std::vector broadcast_dims(indices_shape.dims()); @@ -147,7 +150,7 @@ XlaHelpers::ShapeRepresentationFn IdentityShapeRepresentationFn() { }; } -Status ResolveDeviceAssignment( +absl::Status ResolveDeviceAssignment( OpKernelContext* ctx, const XlaCompilationResult::CollectiveInfo& collective_info, xla::ExecutableRunOptions& run_options, @@ -177,11 +180,11 @@ Status ResolveDeviceAssignment( VLOG(5) << "Using collective params to resolve device assignment: " << params->ToString(); - Status st; + absl::Status st; absl::Notification n; ctx->collective_executor()->CompleteParamsAsync( ctx->device()->attributes(), params.get(), ctx->cancellation_manager(), - [&](const Status& s) { + [&](const absl::Status& s) { st = s; n.Notify(); }); @@ -234,7 +237,7 @@ Status ResolveDeviceAssignment( const DeviceAttributes& device_attributes = params->group.members[device_idx].device; Device* resolved_device = nullptr; - Status lookup_status = + absl::Status lookup_status = device_mgr->LookupDevice(device_attributes.name(), &resolved_device); if (lookup_status.ok()) { // This is a local device, so include it in the mapping. diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index 9765ac8bc84db2..38f01c83db8251 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -23,11 +23,11 @@ limitations under the License. #include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" -#include "xla/client/xla_builder.h" #include "xla/executable_run_options.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/translate/mhlo_to_hlo/layout_util.h" #include "xla/service/computation_placer.h" -#include "xla/translate/mhlo_to_hlo/layout_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -68,19 +68,20 @@ class XlaHelpers { // Reshapes literal 'input' to have 'shape'. Both the original shape and // 'shape' must contain the same number of elements. - static Status ReshapeLiteral(const xla::Literal& input, - absl::Span shape, - xla::Literal* output); + static absl::Status ReshapeLiteral(const xla::Literal& input, + absl::Span shape, + xla::Literal* output); // Converts `indices` into a one-hot representation. `depth` is the size // of the new axis to add. `axis` is the position at which to add the new // axis. `indices_shape` is the shape of `indices`. `on_value` and // `off_value` represent the values to use for the on and off positions, // respectively. - static Status OneHot(xla::XlaBuilder* builder, int64_t depth, int axis, - DataType index_type, const TensorShape& indices_shape, - xla::XlaOp indices, xla::XlaOp on_value, - xla::XlaOp off_value, xla::XlaOp* one_hot); + static absl::Status OneHot(xla::XlaBuilder* builder, int64_t depth, int axis, + DataType index_type, + const TensorShape& indices_shape, + xla::XlaOp indices, xla::XlaOp on_value, + xla::XlaOp off_value, xla::XlaOp* one_hot); // Certain DataTypes should use increased precision DataTypes when performing // reductions. This function remaps a given DataType to a higher precision @@ -201,7 +202,7 @@ struct XlaCompilationResult { // Takes several extra configuration objects by reference since // xla::ExecutableRunOptions does not take ownership; these are configured and // bundled into `run_options` if applicable. -Status ResolveDeviceAssignment( +absl::Status ResolveDeviceAssignment( OpKernelContext* ctx, const XlaCompilationResult::CollectiveInfo& collective_info, xla::ExecutableRunOptions& run_options, diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 3acbab6fb7aa2a..ac8586148b6673 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -26,8 +26,8 @@ limitations under the License. #include "xla/client/client_library.h" #include "xla/client/executable_build_options.h" #include "xla/client/local_client.h" -#include "xla/client/xla_computation.h" #include "xla/cpu_function_runtime.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/service/cpu/buffer_info_util.h" #include "xla/service/cpu/cpu_executable.h" #include "xla/service/platform_util.h" diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 6909a8ad3076ef..a17ccd63d14f76 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -24,9 +24,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "xla/client/value_inference.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/value_inference.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/status_macros.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/platform/errors.h" @@ -119,7 +119,7 @@ DataType XlaOpKernelContext::InputType(absl::string_view name) { xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) { xla::PrimitiveType type; - Status status = DataTypeToPrimitiveType(input_type(index), &type); + absl::Status status = DataTypeToPrimitiveType(input_type(index), &type); if (!status.ok()) { SetStatus(status); return xla::PRIMITIVE_TYPE_INVALID; @@ -129,7 +129,7 @@ xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) { xla::PrimitiveType XlaOpKernelContext::InputXlaType(absl::string_view name) { xla::PrimitiveType type; - Status status = DataTypeToPrimitiveType(InputType(name), &type); + absl::Status status = DataTypeToPrimitiveType(InputType(name), &type); if (!status.ok()) { SetStatus(status); return xla::PRIMITIVE_TYPE_INVALID; @@ -137,9 +137,9 @@ xla::PrimitiveType XlaOpKernelContext::InputXlaType(absl::string_view name) { return type; } -Status XlaOpKernelContext::ConstantInput(int index, - xla::Literal* constant_literal, - xla::ValueInferenceMode mode) { +absl::Status XlaOpKernelContext::ConstantInput(int index, + xla::Literal* constant_literal, + xla::ValueInferenceMode mode) { if (this->InputXlaShape(index)->is_dynamic()) { return errors::InvalidArgument( "Reading input as constant from a dynamic tensor is not yet supported. " @@ -164,26 +164,26 @@ static absl::StatusOr InputIndex(XlaOpKernelContext* context, return start; } -Status XlaOpKernelContext::ResolveInputDynamism( +absl::Status XlaOpKernelContext::ResolveInputDynamism( int index, xla::Literal* dynamism_literal) { return ResolveInputDynamismReshaped( index, context_->input(index).shape().dim_sizes(), dynamism_literal); } -Status XlaOpKernelContext::ResolveInputDynamism( +absl::Status XlaOpKernelContext::ResolveInputDynamism( absl::string_view name, xla::Literal* dynamism_literal) { TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); return ResolveInputDynamism(index, dynamism_literal); } -Status XlaOpKernelContext::ConstantInput(absl::string_view name, - xla::Literal* constant_literal, - xla::ValueInferenceMode mode) { +absl::Status XlaOpKernelContext::ConstantInput(absl::string_view name, + xla::Literal* constant_literal, + xla::ValueInferenceMode mode) { TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); return ConstantInput(index, constant_literal, mode); } -Status XlaOpKernelContext::ConstantInputReshaped( +absl::Status XlaOpKernelContext::ConstantInputReshaped( int index, absl::Span new_dims, xla::Literal* constant_literal, xla::ValueInferenceMode mode) { TF_ASSIGN_OR_RETURN(Tensor constant, ConstantInputTensor(index, mode)); @@ -201,8 +201,8 @@ Status XlaOpKernelContext::ConstantInputReshaped( } // Converts an int16, int32 or int64 scalar literal to an int64. -static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal, - int64_t* out) { +static absl::Status LiteralToInt64Scalar(const xla::LiteralSlice& literal, + int64_t* out) { if (literal.shape().rank() != 0) { return errors::InvalidArgument("value is not a scalar"); } @@ -219,8 +219,8 @@ static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal, } // Converts an float32 or float64 scalar literal to a float64. -static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal, - double* out) { +static absl::Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal, + double* out) { if (literal.shape().rank() != 0) { return errors::InvalidArgument("value is not a scalar"); } @@ -234,14 +234,14 @@ static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal, return absl::OkStatus(); } -Status XlaOpKernelContext::ConstantInputAsIntScalar( +absl::Status XlaOpKernelContext::ConstantInputAsIntScalar( int index, int64_t* out, xla::ValueInferenceMode mode) { xla::Literal literal; TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode)); return LiteralToInt64Scalar(literal, out); } -Status XlaOpKernelContext::ConstantInputAsIntScalar( +absl::Status XlaOpKernelContext::ConstantInputAsIntScalar( absl::string_view name, int64_t* out, xla::ValueInferenceMode mode) { TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); return ConstantInputAsIntScalar(index, out, mode); @@ -254,15 +254,15 @@ absl::StatusOr XlaOpKernelContext::ConstantInputAsIntScalar( return out; } -Status XlaOpKernelContext::ConstantInputAsFloatScalar( +absl::Status XlaOpKernelContext::ConstantInputAsFloatScalar( int index, double* out, xla::ValueInferenceMode mode) { xla::Literal literal; TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode)); return LiteralToFloat64Scalar(literal, out); } -static Status LiteralToPredVector(const xla::LiteralSlice& literal, - std::vector* out) { +static absl::Status LiteralToPredVector(const xla::LiteralSlice& literal, + std::vector* out) { if (literal.shape().rank() != 1) { return errors::InvalidArgument("output_shape must be rank 1, got shape ", literal.shape().DebugString()); @@ -277,7 +277,8 @@ static Status LiteralToPredVector(const xla::LiteralSlice& literal, return absl::OkStatus(); } -Status XlaOpKernelContext::ResolveInputDynamismIntoPred(int index, bool* out) { +absl::Status XlaOpKernelContext::ResolveInputDynamismIntoPred(int index, + bool* out) { xla::Literal literal; XlaExpression e = InputExpression(index); absl::StatusOr dynamism_or_status = e.ResolveDynamism(); @@ -306,19 +307,19 @@ Status XlaOpKernelContext::ResolveInputDynamismIntoPred(int index, bool* out) { return absl::OkStatus(); } -Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector( +absl::Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector( absl::string_view name, std::vector* out) { TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); return ResolveInputDynamismIntoPredVector(index, out); } -Status XlaOpKernelContext::ResolveInputDynamismIntoPred(absl::string_view name, - bool* out) { +absl::Status XlaOpKernelContext::ResolveInputDynamismIntoPred( + absl::string_view name, bool* out) { TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); return ResolveInputDynamismIntoPred(index, out); } -Status XlaOpKernelContext::ResolveInputDynamismReshaped( +absl::Status XlaOpKernelContext::ResolveInputDynamismReshaped( int index, absl::Span new_dims, xla::Literal* dynamism_literal) { XlaExpression e = InputExpression(index); @@ -350,7 +351,7 @@ Status XlaOpKernelContext::ResolveInputDynamismReshaped( return absl::OkStatus(); } -Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector( +absl::Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector( int index, std::vector* out) { xla::Literal literal; TF_RETURN_IF_ERROR(ResolveInputDynamismReshaped( @@ -360,8 +361,8 @@ Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector( } // Converts an int32 or int64 1D literal to an int64 vector. -static Status LiteralToInt64Vector(const xla::LiteralSlice& literal, - std::vector* out) { +static absl::Status LiteralToInt64Vector(const xla::LiteralSlice& literal, + std::vector* out) { if (literal.shape().rank() != 1) { return errors::InvalidArgument("output_shape must be rank 1, got shape ", literal.shape().DebugString()); @@ -381,21 +382,21 @@ static Status LiteralToInt64Vector(const xla::LiteralSlice& literal, return absl::OkStatus(); } -Status XlaOpKernelContext::ConstantInputAsIntVector( +absl::Status XlaOpKernelContext::ConstantInputAsIntVector( int index, std::vector* out, xla::ValueInferenceMode mode) { xla::Literal literal; TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode)); return LiteralToInt64Vector(literal, out); } -Status XlaOpKernelContext::ConstantInputAsIntVector( +absl::Status XlaOpKernelContext::ConstantInputAsIntVector( absl::string_view name, std::vector* out, xla::ValueInferenceMode mode) { TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); return ConstantInputAsIntVector(index, out, mode); } -Status XlaOpKernelContext::ConstantInputReshapedToIntVector( +absl::Status XlaOpKernelContext::ConstantInputReshapedToIntVector( int index, std::vector* out, xla::ValueInferenceMode mode) { xla::Literal literal; TF_RETURN_IF_ERROR(ConstantInputReshaped( @@ -403,7 +404,7 @@ Status XlaOpKernelContext::ConstantInputReshapedToIntVector( return LiteralToInt64Vector(literal, out); } -Status XlaOpKernelContext::ConstantInputReshapedToIntVector( +absl::Status XlaOpKernelContext::ConstantInputReshapedToIntVector( absl::string_view name, std::vector* out, xla::ValueInferenceMode mode) { TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); @@ -413,7 +414,7 @@ Status XlaOpKernelContext::ConstantInputReshapedToIntVector( return LiteralToInt64Vector(literal, out); } -Status XlaOpKernelContext::ConstantInputAsInt64Literal( +absl::Status XlaOpKernelContext::ConstantInputAsInt64Literal( int index, xla::Literal* out, xla::ValueInferenceMode mode) { xla::Literal literal; TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode)); @@ -438,7 +439,7 @@ Status XlaOpKernelContext::ConstantInputAsInt64Literal( } } -Status XlaOpKernelContext::ConstantInputAsInt64Literal( +absl::Status XlaOpKernelContext::ConstantInputAsInt64Literal( absl::string_view name, xla::Literal* out, xla::ValueInferenceMode mode) { TF_ASSIGN_OR_RETURN(int index, InputIndex(this, name)); return ConstantInputAsInt64Literal(index, out, mode); @@ -446,8 +447,8 @@ Status XlaOpKernelContext::ConstantInputAsInt64Literal( // TODO(phawkins): validate that the dimensions form a valid shape, fail // gracefully if they do not. -Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape, - xla::ValueInferenceMode mode) { +absl::Status XlaOpKernelContext::ConstantInputAsShape( + int index, TensorShape* shape, xla::ValueInferenceMode mode) { xla::Literal literal; TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode)); std::vector dims; @@ -466,7 +467,7 @@ Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape, return absl::OkStatus(); } -Status XlaOpKernelContext::ConstantInputAsPartialShape( +absl::Status XlaOpKernelContext::ConstantInputAsPartialShape( int index, PartialTensorShape* shape) { xla::Literal literal; TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); @@ -487,9 +488,9 @@ Status XlaOpKernelContext::ConstantInputAsPartialShape( return absl::OkStatus(); } -Status XlaOpKernelContext::InputList(absl::string_view name, - std::vector* handles, - std::vector* shapes) { +absl::Status XlaOpKernelContext::InputList(absl::string_view name, + std::vector* handles, + std::vector* shapes) { OpInputList inputs; TF_RETURN_IF_ERROR(context_->input_list(name, &inputs)); handles->clear(); @@ -502,9 +503,9 @@ Status XlaOpKernelContext::InputList(absl::string_view name, return absl::OkStatus(); } -Status XlaOpKernelContext::ConstantInputList(absl::string_view name, - std::vector* outputs, - xla::ValueInferenceMode mode) { +absl::Status XlaOpKernelContext::ConstantInputList( + absl::string_view name, std::vector* outputs, + xla::ValueInferenceMode mode) { int start, stop; TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop)); outputs->resize(stop - start); @@ -521,7 +522,7 @@ absl::StatusOr XlaOpKernelContext::ConstantInputTensor( absl::StatusOr> constant_or_status = e.ResolveConstant(client, dynamic_dimension_is_minus_one_, mode); if (!constant_or_status.ok()) { - Status status = constant_or_status.status(); + absl::Status status = constant_or_status.status(); errors::AppendToMessage(&status, "while evaluating input ", index, " of ", context_->op_kernel().type_string(), " operator as a compile-time constant."); @@ -545,9 +546,9 @@ absl::StatusOr XlaOpKernelContext::ConstantInputTensor( namespace { -Status ReadVariableInputTensor(const Tensor& tensor, DataType type, - const XlaOpKernelContext* ctx, - TensorShape* shape, xla::XlaOp* value) { +absl::Status ReadVariableInputTensor(const Tensor& tensor, DataType type, + const XlaOpKernelContext* ctx, + TensorShape* shape, xla::XlaOp* value) { const XlaExpression* expression = XlaExpression::CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); @@ -596,22 +597,23 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type, } // namespace -Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, - TensorShape* shape, - xla::XlaOp* value) { +absl::Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, + TensorShape* shape, + xla::XlaOp* value) { return ReadVariableInputTensor(context_->input(index), type, this, shape, value); } -Status XlaOpKernelContext::ReadVariableInput(absl::string_view name, - DataType type, TensorShape* shape, - xla::XlaOp* value) { +absl::Status XlaOpKernelContext::ReadVariableInput(absl::string_view name, + DataType type, + TensorShape* shape, + xla::XlaOp* value) { return ReadVariableInputTensor(GetInputTensorByName(name), type, this, shape, value); } -Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, - TensorShape* shape) const { +absl::Status XlaOpKernelContext::GetVariableTypeAndShape( + int index, DataType* type, TensorShape* shape) const { const Tensor& tensor = context_->input(index); const XlaExpression* expression = XlaExpression::CastExpressionFromTensor(tensor); @@ -631,7 +633,7 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, void XlaOpKernelContext::SetOutputExpression(int index, const XlaExpression& expression) { - Status status = [&] { + absl::Status status = [&] { // The step's default allocator is the dummy XlaCompilationAllocator which // simply allocates a metadata buffer to hold the expression to which it // corresponds. @@ -666,7 +668,8 @@ void XlaOpKernelContext::SetOutputExpression(int index, xla::PrimitiveType XlaOpKernelContext::output_xla_type(int index) { xla::PrimitiveType type; - Status status = DataTypeToPrimitiveType(expected_output_dtype(index), &type); + absl::Status status = + DataTypeToPrimitiveType(expected_output_dtype(index), &type); if (!status.ok()) { SetStatus(status); return xla::PRIMITIVE_TYPE_INVALID; @@ -693,7 +696,8 @@ void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) { SetOutputExpression(index, XlaExpression::Resource(resource)); } -Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { +absl::Status XlaOpKernelContext::GetResourceInput(int index, + XlaResource** resource) { const XlaExpression* expression = XlaExpression::CastExpressionFromTensor(context_->input(index)); TF_RET_CHECK(expression->resource() != nullptr); @@ -703,9 +707,9 @@ Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { namespace { -Status AssignVariableTensor(const Tensor& tensor, DataType type, - const XlaOpKernelContext* ctx, xla::XlaOp handle, - xla::XlaBuilder* builder) { +absl::Status AssignVariableTensor(const Tensor& tensor, DataType type, + const XlaOpKernelContext* ctx, + xla::XlaOp handle, xla::XlaBuilder* builder) { const XlaExpression* expression = XlaExpression::CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); @@ -740,41 +744,43 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type, } // namespace -Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, - xla::XlaOp handle) { +absl::Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, + xla::XlaOp handle) { TF_RET_CHECK(handle.valid()); return AssignVariableTensor(context_->input(input_index), type, this, handle, builder()); } -Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type, - xla::XlaOp handle) { +absl::Status XlaOpKernelContext::AssignVariable(absl::string_view name, + DataType type, + xla::XlaOp handle) { TF_RET_CHECK(handle.valid()); return AssignVariableTensor(GetInputTensorByName(name), type, this, handle, builder()); } -static Status GetStatusWithStackTrace(const Status& s, - const XlaOpKernelContext* ctx) { +static absl::Status GetStatusWithStackTrace(const absl::Status& s, + const XlaOpKernelContext* ctx) { if (s.code() == error::INVALID_ARGUMENT) { - return Status{s.code(), absl::StrCat(s.message(), "\n", ctx->StackTrace())}; + return absl::Status{s.code(), + absl::StrCat(s.message(), "\n", ctx->StackTrace())}; } return s; } -void XlaOpKernelContext::CtxFailure(const Status& s) { +void XlaOpKernelContext::CtxFailure(const absl::Status& s) { context_->CtxFailure(GetStatusWithStackTrace(s, this)); } -void XlaOpKernelContext::CtxFailureWithWarning(const Status& s) { +void XlaOpKernelContext::CtxFailureWithWarning(const absl::Status& s) { context_->CtxFailureWithWarning(GetStatusWithStackTrace(s, this)); } void XlaOpKernelContext::CtxFailure(const char* file, int line, - const Status& s) { + const absl::Status& s) { context_->CtxFailure(file, line, GetStatusWithStackTrace(s, this)); } void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line, - const Status& s) { + const absl::Status& s) { context_->CtxFailureWithWarning(file, line, GetStatusWithStackTrace(s, this)); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index a66be8384f003c..b0830d0766acb2 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -16,16 +16,26 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ +#include "absl/base/attributes.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_resource.h" -#include "xla/client/value_inference.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/value_inference.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/literal.h" +#include "xla/shape.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { @@ -123,23 +133,25 @@ class XlaOpKernelContext { // Returns the named list-valued immutable input in "list", as // defined in the OpDef. If the named output is not list-valued, // returns a one-element list. - Status InputList(absl::string_view name, std::vector* handles, - std::vector* shapes); + absl::Status InputList(absl::string_view name, + std::vector* handles, + std::vector* shapes); // Evaluates input and returns their dynamism vector in a vector of // predicates. - Status ResolveInputDynamismIntoPredVector(int index, std::vector* out); - Status ResolveInputDynamismIntoPred(int index, bool* out); - Status ResolveInputDynamismIntoPredVector(absl::string_view name, - std::vector* out); - Status ResolveInputDynamismIntoPred(absl::string_view name, bool* out); - - Status ResolveInputDynamism(int index, xla::Literal* dynamism_literal); - Status ResolveInputDynamism(absl::string_view name, - xla::Literal* dynamism_literal); - - Status ResolveInputDynamismReshaped(int index, - absl::Span new_dims, - xla::Literal* dynamism_literal); + absl::Status ResolveInputDynamismIntoPredVector(int index, + std::vector* out); + absl::Status ResolveInputDynamismIntoPred(int index, bool* out); + absl::Status ResolveInputDynamismIntoPredVector(absl::string_view name, + std::vector* out); + absl::Status ResolveInputDynamismIntoPred(absl::string_view name, bool* out); + + absl::Status ResolveInputDynamism(int index, xla::Literal* dynamism_literal); + absl::Status ResolveInputDynamism(absl::string_view name, + xla::Literal* dynamism_literal); + + absl::Status ResolveInputDynamismReshaped(int index, + absl::Span new_dims, + xla::Literal* dynamism_literal); // Helper methods for constant inputs. // Evaluates input `index` and stores it in `*constant_literal`. If the @@ -147,18 +159,18 @@ class XlaOpKernelContext { // parameters, returns a non-OK status. This function can also be used to // infer constant input upper or lower bounds, by changing the `mode` // parameter. - Status ConstantInput( + absl::Status ConstantInput( int index, xla::Literal* constant_literal, xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); - Status ConstantInput( + absl::Status ConstantInput( absl::string_view name, xla::Literal* constant_literal, xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); // Converts a constant scalar int32 or int64 tensor into an int64. - Status ConstantInputAsIntScalar( + absl::Status ConstantInputAsIntScalar( int index, int64_t* out, xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); - Status ConstantInputAsIntScalar( + absl::Status ConstantInputAsIntScalar( absl::string_view name, int64_t* out, xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); @@ -167,48 +179,49 @@ class XlaOpKernelContext { xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); // Converts a constant scalar float32 or float64 tensor into a float64. - Status ConstantInputAsFloatScalar( + absl::Status ConstantInputAsFloatScalar( int index, double* out, xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); // Converts a constant 1D int32 or int64 tensor into a vector of int64s. - Status ConstantInputAsIntVector( + absl::Status ConstantInputAsIntVector( int index, std::vector* out, xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); - Status ConstantInputAsIntVector( + absl::Status ConstantInputAsIntVector( absl::string_view name, std::vector* out, xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); // Reshapes and converts a constant int32 or int64 tensor into a vector of // int64s. - Status ConstantInputReshapedToIntVector( + absl::Status ConstantInputReshapedToIntVector( int index, std::vector* out, xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); - Status ConstantInputReshapedToIntVector( + absl::Status ConstantInputReshapedToIntVector( absl::string_view name, std::vector* out, xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); // Converts a constant int32 or int64 Tensor into an xla int64 Literal. - Status ConstantInputAsInt64Literal( + absl::Status ConstantInputAsInt64Literal( int index, xla::Literal* out, xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); - Status ConstantInputAsInt64Literal( + absl::Status ConstantInputAsInt64Literal( absl::string_view name, xla::Literal* out, xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); // Converts a constant 1D int32 or int64 tensor into a TensorShape. - Status ConstantInputAsShape( + absl::Status ConstantInputAsShape( int index, TensorShape* shape, xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); // Converts a constant 1D int32 or int64 tensor, or a scalar with value -1 // into a PartialTensorShape. - Status ConstantInputAsPartialShape(int index, PartialTensorShape* shape); + absl::Status ConstantInputAsPartialShape(int index, + PartialTensorShape* shape); // Returns the named list-valued immutable input in "list", as // defined in the OpDef. If the named output is not list-valued, // returns a one-element list. - Status ConstantInputList( + absl::Status ConstantInputList( absl::string_view name, std::vector* outputs, xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); @@ -250,21 +263,21 @@ class XlaOpKernelContext { void SetTensorListOutput(int index, const xla::XlaOp& handle); // Status handling. - void SetStatus(const Status& status) { context_->SetStatus(status); } - Status status() { return context_->status(); } + void SetStatus(const absl::Status& status) { context_->SetStatus(status); } + absl::Status status() { return context_->status(); } // Variables // Sets `*resource` to the resource associated with input `index`. - Status GetResourceInput(int index, XlaResource** resource); + absl::Status GetResourceInput(int index, XlaResource** resource); // Sets output `index` to be a reference to resource `resource`. void SetResourceOutput(int index, XlaResource* resource); // Sets `*type` and `*shape` to the current type and shape of a variable's // value. - Status GetVariableTypeAndShape(int index, DataType* type, - TensorShape* shape) const; + absl::Status GetVariableTypeAndShape(int index, DataType* type, + TensorShape* shape) const; // When dynamic_dimension_is_minus_one is set, querying a dynamic dimension // returns "-1", this is useful when the underlying ops expect explicit @@ -283,27 +296,28 @@ class XlaOpKernelContext { // `index`. If `shape` is not nullptr, sets `*shape` to the shape of the // variable. Returns an error if the variable has not been initialized, or if // its type does not match `type`. - Status ReadVariableInput(int index, DataType type, TensorShape* shape, - xla::XlaOp* value); + absl::Status ReadVariableInput(int index, DataType type, TensorShape* shape, + xla::XlaOp* value); // Reads the current value of the resource variable referred to by input // `name`. - Status ReadVariableInput(absl::string_view name, DataType type, - TensorShape* shape, xla::XlaOp* value); + absl::Status ReadVariableInput(absl::string_view name, DataType type, + TensorShape* shape, xla::XlaOp* value); // Assigns the value `handle` to the variable referenced by input // `input_index`. The variable must be of `type`. Returns an error if the // variable has been initialized with a different type or with a // different shape. - Status AssignVariable(int input_index, DataType type, xla::XlaOp handle); + absl::Status AssignVariable(int input_index, DataType type, + xla::XlaOp handle); // Assigns the value `handle` to the variable referenced by input `name`. - Status AssignVariable(absl::string_view name, DataType type, - xla::XlaOp handle); + absl::Status AssignVariable(absl::string_view name, DataType type, + xla::XlaOp handle); // Helper routines for the OP_REQUIRES macros - void CtxFailure(const Status& s); - void CtxFailureWithWarning(const Status& s); - void CtxFailure(const char* file, int line, const Status& s); - void CtxFailureWithWarning(const char* file, int line, const Status& s); + void CtxFailure(const absl::Status& s); + void CtxFailureWithWarning(const absl::Status& s); + void CtxFailure(const char* file, int line, const absl::Status& s); + void CtxFailureWithWarning(const char* file, int line, const absl::Status& s); // If this kernel invocation is within a function execution, // call_frame() returns the call frame for the function call. @@ -361,7 +375,7 @@ class XlaOpKernelContext { // cannot be evaluated, e.g., because it depends on unbound parameters, // returns a non-Ok status. If InputShape(index).num_elements() != // new_shape.num_elements(), returns an error status. - Status ConstantInputReshaped( + absl::Status ConstantInputReshaped( int index, absl::Span new_dims, xla::Literal* constant_literal, xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue); diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 0109f6a3f07ef3..445065971f2a6a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -19,21 +19,34 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" -#include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_context.h" -#include "xla/client/client_library.h" -#include "tensorflow/core/common_runtime/device_factory.h" -#include "tensorflow/core/common_runtime/local_device.h" +#include "xla/util.h" #include "tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_factory.h" #include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/device_factory.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/op_def_util.h" -#include "tensorflow/core/platform/mem.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/tfrt/common/pjrt_util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status.h" namespace tensorflow { @@ -42,7 +55,7 @@ const char* const DEVICE_GPU_XLA_JIT = "XLA_GPU_JIT"; const char* const DEVICE_XLA_CPU = "XLA_CPU"; const char* const DEVICE_XLA_GPU = "XLA_GPU"; -static Status LaunchOpHasKernelForDevice(const DeviceType& device_type) { +static absl::Status LaunchOpHasKernelForDevice(const DeviceType& device_type) { const OpDef* op_def; TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("XlaLaunch", &op_def)); NodeDef node_def; @@ -244,7 +257,7 @@ void XlaOpRegistry::RegisterCompilationKernels() { for (auto& op_registration : op_registrations) { const OpDef* op_def; - Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def); + absl::Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def); if (!lookup_status.ok()) { LOG(ERROR) << lookup_status.message(); XLA_LOG_LINES( @@ -415,7 +428,7 @@ XlaOpRegistry::CompileTimeConstantInputArgNames(const string& op) { } } -/* static */ Status XlaOpRegistry::CompileTimeConstantInputs( +/* static */ absl::Status XlaOpRegistry::CompileTimeConstantInputs( const NodeDef& node_def, const OpKernel* op_kernel, const OpDef* op_def, std::vector* result) { result->clear(); diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 333a9168f3deda..11bbbf2b928871 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -22,16 +22,25 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session_options.h" +#include "tsl/platform/errors.h" namespace tensorflow { @@ -191,9 +200,9 @@ class XlaOpRegistry { // registered. // // `result` is sorted. - static Status CompileTimeConstantInputs(const NodeDef& node_def, - const OpDef& op_def, - std::vector* result) { + static absl::Status CompileTimeConstantInputs(const NodeDef& node_def, + const OpDef& op_def, + std::vector* result) { return CompileTimeConstantInputs(node_def, /*op_kernel=*/nullptr, &op_def, result); } @@ -209,8 +218,8 @@ class XlaOpRegistry { // compile-time constants. // // `result` is sorted. - static Status CompileTimeConstantInputs(const OpKernel& op_kernel, - std::vector* result) { + static absl::Status CompileTimeConstantInputs(const OpKernel& op_kernel, + std::vector* result) { return CompileTimeConstantInputs(op_kernel.def(), /*op_kernel=*/&op_kernel, /*op_def=*/nullptr, result); } @@ -305,10 +314,10 @@ class XlaOpRegistry { // their allowlists must not intersect. static bool IsCompatible(const OpRegistration& x, const OpRegistration& y); - static Status CompileTimeConstantInputs(const NodeDef& node_def, - const OpKernel* op_kernel, - const OpDef* op_def, - std::vector* result); + static absl::Status CompileTimeConstantInputs(const NodeDef& node_def, + const OpKernel* op_kernel, + const OpDef* op_def, + std::vector* result); // Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP. // Registrations present under the same key must satisfy IsCompatible above, diff --git a/tensorflow/compiler/tf2xla/xla_op_registry_test.cc b/tensorflow/compiler/tf2xla/xla_op_registry_test.cc index 4d8e1bc31f8d58..13b648b78004ac 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry_test.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry_test.cc @@ -14,7 +14,12 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +#include "absl/log/log.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index 0e1d33a0c1c718..c3b690e2fdfda3 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -18,11 +18,22 @@ limitations under the License. #include #include -#include "absl/memory/memory.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/status_macros.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/managed_stack_trace.h" +#include "tsl/platform/errors.h" namespace tensorflow { @@ -87,7 +98,8 @@ XlaResource::XlaResource( } } -Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) { +absl::Status XlaResource::SetTypeAndShape(DataType type, + const TensorShape& shape) { if (type == DT_INVALID) { return errors::InvalidArgument( "Attempted to set type of resource '", name_, "'' to an invalid type", @@ -112,7 +124,7 @@ Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) { return absl::OkStatus(); } -Status XlaResource::SetValue(const xla::XlaOp& value) { +absl::Status XlaResource::SetValue(const xla::XlaOp& value) { if (type_ == DT_INVALID) { return errors::InvalidArgument( "Resource '", name_, @@ -123,7 +135,7 @@ Status XlaResource::SetValue(const xla::XlaOp& value) { return absl::OkStatus(); } -Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { +absl::Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { is_overwritten_ = true; if (type_ == DT_INVALID) { return errors::InvalidArgument( @@ -162,9 +174,9 @@ Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { return absl::OkStatus(); } -Status XlaResource::GetOrCreateTensorArrayGradient(const string& source, - xla::XlaBuilder* builder, - XlaResource** gradient_out) { +absl::Status XlaResource::GetOrCreateTensorArrayGradient( + const string& source, xla::XlaBuilder* builder, + XlaResource** gradient_out) { VLOG(2) << "Gradient lookup for resource: " << name_ << " gradient: " << source; TF_RET_CHECK(kind_ == kTensorArray); @@ -186,7 +198,8 @@ Status XlaResource::GetOrCreateTensorArrayGradient(const string& source, return absl::OkStatus(); } -Status XlaResource::Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const { +absl::Status XlaResource::Pack(xla::XlaOp* pack, + xla::XlaBuilder* builder) const { if (tensor_array_gradients_.empty()) { *pack = value_; } else { @@ -201,9 +214,9 @@ Status XlaResource::Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const { return absl::OkStatus(); } -Status XlaResource::SetFromPack(const std::set& gradient_sources, - const xla::XlaOp& pack, - xla::XlaBuilder* builder) { +absl::Status XlaResource::SetFromPack(const std::set& gradient_sources, + const xla::XlaOp& pack, + xla::XlaBuilder* builder) { if (gradient_sources.empty()) { if (!initialized()) { initial_value_ = pack; diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index 902c62edd5664a..1e515c2a5dd5fd 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -19,11 +19,14 @@ limitations under the License. #include #include "absl/strings/string_view.h" -#include "xla/client/xla_builder.h" +#include "absl/types/optional.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/shape.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/managed_stack_trace.h" namespace tensorflow { @@ -100,14 +103,14 @@ class XlaResource { // Sets the type and shape of the resource. The type and shape of a resource // must not change once the variable has been initialized. - Status SetTypeAndShape(DataType type, const TensorShape& shape); + absl::Status SetTypeAndShape(DataType type, const TensorShape& shape); // Sets the current value of the resource. Returns an error if the type is not // set to a valid value. - Status SetValue(const xla::XlaOp& value); + absl::Status SetValue(const xla::XlaOp& value); // Sets the current value of the resource to an all-zero value. - Status SetZeroValue(xla::XlaBuilder* builder); + absl::Status SetZeroValue(xla::XlaBuilder* builder); // Sets the representational shape of the resource on device. void SetRepresentationShape(const xla::Shape& shape) { @@ -118,16 +121,16 @@ class XlaResource { // exist. The call target must be an initialized TensorArray resource. A // TensorArray can have multiple named gradients; see the operator // documentation for TensorArrayGradV3 for details. - Status GetOrCreateTensorArrayGradient(const string& source, - xla::XlaBuilder* builder, - XlaResource** gradient_out); + absl::Status GetOrCreateTensorArrayGradient(const string& source, + xla::XlaBuilder* builder, + XlaResource** gradient_out); // Packs a resource into a single XLA value `pack`, suitable for use as // an XlaCompiler::Argument. For non-TensorArrays or TensorArrays without // gradients, sets `*pack` to `value`. // For TensorArrays with gradients, packs the value and its gradient values in // a tuple; the gradients values are packed in order by source name. - Status Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const; + absl::Status Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const; // Updates the resource with values from `pack`. If `gradient_sources` is // non-empty, treats `pack` as a tuple that represents a TensorArray and @@ -135,8 +138,8 @@ class XlaResource { // If `reset_initial_values` is true, sets the initial_values as well as the // values. // Opposite of Pack(). - Status SetFromPack(const std::set& gradient_sources, - const xla::XlaOp& pack, xla::XlaBuilder* builder); + absl::Status SetFromPack(const std::set& gradient_sources, + const xla::XlaOp& pack, xla::XlaBuilder* builder); bool IsOverwritten() { return is_overwritten_; } diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 88c081ec806867..5a56b1ac0acabf 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -64,6 +64,8 @@ # Placeholder: load py_proto_library load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") +# copybara:uncomment(google-only) load("//tools/build_defs/go:go_library.bzl", "go_library") + load( "@local_xla//xla/tsl/mkl:build_defs.bzl", "if_mkl", @@ -87,7 +89,16 @@ load( "tf_opts_nortti_if_lite_protos", "transitive_hdrs", ) -load("//tensorflow:tensorflow.default.bzl", "cc_header_only_library", "filegroup", "get_compatible_with_portable", "tensorflow_opensource_extra_deps", "tf_monitoring_framework_deps", "tf_selective_registration_deps") +load( + "//tensorflow:tensorflow.default.bzl", + "cc_header_only_library", + "custom_op_cc_header_only_library", + "filegroup", + "get_compatible_with_portable", + "tensorflow_opensource_extra_deps", + "tf_monitoring_framework_deps", + "tf_selective_registration_deps", +) # For platform specific build config load( @@ -181,9 +192,9 @@ tf_proto_library( "//tensorflow/core/grappler/costs:op_performance_data", "@local_tsl//tsl/profiler/protobuf:profiler_options_proto", "@local_tsl//tsl/profiler/protobuf:xplane_proto", - "@local_tsl//tsl/protobuf:coordination_config_proto", - "@local_tsl//tsl/protobuf:distributed_runtime_payloads_proto", - "@local_tsl//tsl/protobuf:status_proto", + "@local_xla//xla/tsl/protobuf:coordination_config_proto", + "@local_xla//xla/tsl/protobuf:distributed_runtime_payloads_proto", + "@local_xla//xla/tsl/protobuf:status_proto", ], visibility = ["//visibility:public"], ) @@ -279,13 +290,13 @@ cc_library( "//tensorflow/core/lib/histogram:legacy_lib_histogram_all_headers", "//tensorflow/core/lib/io:legacy_lib_io_headers", "//tensorflow/core/lib/math:math_util.h", - "@local_tsl//tsl/lib/math:math_util.h", + "@local_xla//xla/tsl/lib/math:math_util.h", "//tensorflow/core/lib/monitoring:legacy_lib_monitoring_lib_headers", "//tensorflow/core/lib/random:legacy_lib_random_headers", "//tensorflow/core/lib/strings:legacy_lib_string_headers", "//tensorflow/core/platform:lib_hdrs", "//tensorflow/core/util:lib_hdrs", - "@local_tsl//tsl/lib/io:legacy_lib_io_headers", + "@local_xla//xla/tsl/lib/io:legacy_lib_io_headers", "@local_tsl//tsl/platform:lib_hdrs", ], visibility = ["//visibility:public"], @@ -852,7 +863,7 @@ filegroup( "//tensorflow/core/graph:mobile_srcs_only_runtime", "//tensorflow/core/kernels:mobile_srcs", "//tensorflow/core/lib/io:mobile_srcs_only_runtime", - "@local_tsl//tsl/lib/io:mobile_srcs_only_runtime", + "@local_xla//xla/tsl/lib/io:mobile_srcs_only_runtime", "//tensorflow/core/nccl:mobile_srcs", "//tensorflow/core/profiler:mobile_srcs_only_runtime", "//tensorflow/core/public:mobile_srcs_only_runtime", @@ -867,7 +878,7 @@ filegroup( "//tensorflow/core/lib/hash:mobile_srcs_only_runtime", "//tensorflow/core/lib/histogram:mobile_srcs_only_runtime", "//tensorflow/core/lib/math:mobile_srcs_only_runtime", - "@local_tsl//tsl/lib/math:mobile_srcs_only_runtime", + "@local_xla//xla/tsl/lib/math:mobile_srcs_only_runtime", "//tensorflow/core/lib/monitoring:mobile_srcs_only_runtime", "//tensorflow/core/lib/random:mobile_srcs_only_runtime", "//tensorflow/core/lib/strings:mobile_srcs_only_runtime", @@ -1211,11 +1222,11 @@ filegroup( "//tensorflow/core/lib/strings:legacy_lib_strings_all_headers", "//tensorflow/core/platform:legacy_lib_internal_headers", "//tensorflow/core/platform:lib_internal_private_hdrs", - "@local_tsl//tsl/lib/io:legacy_lib_io_all_headers", - "@local_tsl//tsl/lib/math:math_util.h", "@local_tsl//tsl/platform:legacy_lib_internal_headers", "@local_tsl//tsl/platform:lib_internal_private_hdrs", "@local_xla//xla/tsl/lib/core:legacy_lib_core_all_headers", + "@local_xla//xla/tsl/lib/io:legacy_lib_io_all_headers", + "@local_xla//xla/tsl/lib/math:math_util.h", ] + glob( [ "lib/**/*.h", @@ -1244,9 +1255,9 @@ filegroup( "//tensorflow/core/platform:legacy_platform_lib_hdrs", "//tensorflow/core/platform:lib_internal_public_hdrs", "//tensorflow/core/util:lib_internal_public_hdrs", - "@local_tsl//tsl/lib/io:legacy_lib_internal_public_headers", "@local_tsl//tsl/platform:lib_internal_public_hdrs", "@local_xla//xla/stream_executor/integrations:device_mem_allocator_headers", + "@local_xla//xla/tsl/lib/io:legacy_lib_internal_public_headers", ], visibility = ["//visibility:private"], ) @@ -1449,7 +1460,7 @@ cc_library( "@com_google_protobuf//:protobuf", "@double_conversion//:double-conversion", "@eigen_archive//:eigen3", - "@local_tsl//tsl/lib/math:math_util", + "@local_xla//xla/tsl/lib/math:math_util", "@ml_dtypes//:float8", "@ml_dtypes//:intn", "@snappy", @@ -1476,7 +1487,7 @@ cc_library( "@local_xla//xla:autotune_results_proto_cc_impl", "@local_xla//xla:autotuning_proto_cc_impl", "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc_impl", - "@local_tsl//tsl/protobuf:protos_all_cc_impl", + "@local_xla//xla/tsl/protobuf:protos_all_cc_impl", "@local_xla//xla:xla_proto_cc_impl", "@local_xla//xla:xla_data_proto_cc_impl", "@local_xla//xla/service:hlo_proto_cc_impl", @@ -1539,11 +1550,13 @@ cc_library( alias( name = "portable_jpeg_internal", actual = "//tensorflow/core/lib/jpeg:portable_jpeg_internal", + visibility = ["//visibility:public"], ) alias( name = "portable_gif_internal", actual = "//tensorflow/core/lib/gif:portable_gif_internal", + visibility = ["//visibility:public"], ) alias( @@ -1670,6 +1683,7 @@ tf_cuda_library( ":protos_all_cc", "//tensorflow/compiler/jit:common", "//tensorflow/core/activity_watcher", + "//tensorflow/core/config:flag_defs", "//tensorflow/core/example:feature_util", "//tensorflow/core/framework:allocator", "//tensorflow/core/framework:allocator_registry_impl", @@ -1727,7 +1741,7 @@ tf_cuda_library( alwayslink = 1, ) -cc_header_only_library( +custom_op_cc_header_only_library( name = "framework_headers_lib", # Fully depend on external repositories, because identifying the headers # is fragile. @@ -1857,7 +1871,7 @@ cc_library( hdrs = [ "//tensorflow/core/lib/gtl:legacy_lib_test_internal_headers", "//tensorflow/core/lib/io:legacy_lib_test_internal_headers", - "@local_tsl//tsl/lib/io:legacy_lib_test_internal_headers", + "@local_xla//xla/tsl/lib/io:legacy_lib_test_internal_headers", ], deps = [ ":lib", @@ -2040,7 +2054,6 @@ transitive_hdrs( # py_proto_library( # name = "protos_all_py_pb2", # has_services = 0, -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":protos_all"], # ) diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD index a0d40f014fa89a..de5c7d6e021588 100644 --- a/tensorflow/core/common_runtime/BUILD +++ b/tensorflow/core/common_runtime/BUILD @@ -1214,13 +1214,21 @@ cc_library( hdrs = ["simplify_ici_dummy_variables_pass.h"], copts = tf_copts(), deps = [ - ":colocate_predecessor_trees_pass", ":optimization_registry", "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core/config:flag_defs", + "//tensorflow/core/config:flags", + "//tensorflow/core/framework:node_def_util", "//tensorflow/core/framework:tensor_proto_cc", "//tensorflow/core/framework:tensor_shape_proto_cc", + "//tensorflow/core/platform:bfloat16", + "//tensorflow/core/platform:logging", "//tensorflow/core/platform:status", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", ], alwayslink = 1, ) @@ -2440,6 +2448,9 @@ tf_cc_tests( "threadpool_device_test.cc", ], create_named_test_suite = True, + data = [ + "testdata/simplify_ici_dummy_variables_pass_before.pbtxt", + ], linkopts = select({ "//tensorflow:macos": ["-headerpad_max_install_names"], "//conditions:default": [], @@ -2475,6 +2486,7 @@ tf_cc_tests( "//tensorflow/core/util:protos_test_cc", "@com_google_absl//absl/base", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@eigen_archive//:eigen3", ], diff --git a/tensorflow/core/common_runtime/accumulate_n_optimizer.cc b/tensorflow/core/common_runtime/accumulate_n_optimizer.cc index 473be0c108896d..02ff8ebdc9a3ef 100644 --- a/tensorflow/core/common_runtime/accumulate_n_optimizer.cc +++ b/tensorflow/core/common_runtime/accumulate_n_optimizer.cc @@ -46,7 +46,7 @@ Tensor make_zeros(const DataType& dtype, const TensorShapeProto& shape) { // third-party libraries aren't currently supported. class AccumulateNV2RemovePass : public GraphOptimizationPass { public: - Status Run(const GraphOptimizationPassOptions& options) override { + absl::Status Run(const GraphOptimizationPassOptions& options) override { // TODO(freiss.oss@gmail.com): Substantial shared code with // ParallelConcatRemovePass::Run(). Consider refactoring if someone makes // a third similar rewrite. @@ -101,7 +101,7 @@ class AccumulateNV2RemovePass : public GraphOptimizationPass { return absl::OkStatus(); } - Status RewriteIntoTempVariable(Node* n, Graph* g) { + absl::Status RewriteIntoTempVariable(Node* n, Graph* g) { VLOG(3) << "Rewrite AccumulateNV2 into TemporaryVariable and Assign: " << SummarizeNode(*n); @@ -229,7 +229,7 @@ class AccumulateNV2RemovePass : public GraphOptimizationPass { return absl::OkStatus(); } - Status RewriteIntoAddN(Node* n, Graph* g) { + absl::Status RewriteIntoAddN(Node* n, Graph* g) { VLOG(3) << "Rewrite AccumulateNV2 into AddN: " << SummarizeNode(*n); AttrSlice n_attrs = n->attrs(); diff --git a/tensorflow/core/common_runtime/all_to_all.cc b/tensorflow/core/common_runtime/all_to_all.cc index 60a00dbecf35bf..0b9a2c0e51dc78 100644 --- a/tensorflow/core/common_runtime/all_to_all.cc +++ b/tensorflow/core/common_runtime/all_to_all.cc @@ -43,8 +43,8 @@ AllToAll::AllToAll() : col_ctx_(nullptr), col_params_(nullptr), done_(nullptr), counter_(0) {} StatusCallback AllToAll::CheckCounterAndCallDone() { - return [this](const Status& s) { - Status final_status; + return [this](const absl::Status& s) { + absl::Status final_status; { mutex_lock l(mu_); status_.Update(s); @@ -75,7 +75,7 @@ StatusCallback AllToAll::CheckCounterAndCallDone() { }; } -Status AllToAll::InitializeCollectiveContext( +absl::Status AllToAll::InitializeCollectiveContext( std::shared_ptr col_ctx) { if (col_ctx->input->dim_size(0) != col_ctx->col_params->group.group_size) { return errors::InvalidArgument("input to all-to-all first dimension size (", diff --git a/tensorflow/core/common_runtime/all_to_all.h b/tensorflow/core/common_runtime/all_to_all.h index 38bfd3ddc2058a..f0fb1651c7ae43 100644 --- a/tensorflow/core/common_runtime/all_to_all.h +++ b/tensorflow/core/common_runtime/all_to_all.h @@ -33,13 +33,14 @@ class AllToAll : public CollectiveImplementationInterface { void Run(StatusCallback done) override; - Status InitializeCollectiveParams(CollectiveParams* col_params) override { + absl::Status InitializeCollectiveParams( + CollectiveParams* col_params) override { return absl::OkStatus(); } // Initializes members of CollectiveContext not yet initialized, i.e. device // and device_locality. Also saves the CollectiveContext in this object. - Status InitializeCollectiveContext( + absl::Status InitializeCollectiveContext( std::shared_ptr col_ctx) override; private: @@ -50,7 +51,7 @@ class AllToAll : public CollectiveImplementationInterface { std::vector output_chunks_; StatusCallback done_; mutex mu_; - Status status_ TF_GUARDED_BY(mu_); + absl::Status status_ TF_GUARDED_BY(mu_); int counter_ TF_GUARDED_BY(mu_); void DispatchSend(int src_rank, int target_rank, const Tensor* tensor, diff --git a/tensorflow/core/common_runtime/all_to_all_test.cc b/tensorflow/core/common_runtime/all_to_all_test.cc index 105b8058997088..ba483eb9452adc 100644 --- a/tensorflow/core/common_runtime/all_to_all_test.cc +++ b/tensorflow/core/common_runtime/all_to_all_test.cc @@ -115,8 +115,8 @@ TEST_F(AllToAllTest, Failure) { Device* device = nullptr; TF_CHECK_OK(test_env_->device_mgr->LookupDevice( col_params->group.members[i].device.name(), &device)); - Status status = RunCollective(test_env_.get(), col_params.get(), device, - &tensors[i], &tensors[i]); + absl::Status status = RunCollective(test_env_.get(), col_params.get(), + device, &tensors[i], &tensors[i]); if (!status.ok()) { mutex_lock l(mu); ++num_failures; @@ -147,8 +147,8 @@ TEST_F(AllToAllTest, WrongFirstDimensionSize) { Device* device = nullptr; TF_CHECK_OK(test_env_->device_mgr->LookupDevice( col_params->group.members[i].device.name(), &device)); - Status status = RunCollective(test_env_.get(), col_params.get(), device, - &tensors[i], &tensors[i]); + absl::Status status = RunCollective(test_env_.get(), col_params.get(), + device, &tensors[i], &tensors[i]); counter.DecrementCount(); EXPECT_TRUE(errors::IsInvalidArgument(status)); }); diff --git a/tensorflow/core/common_runtime/arg_ret_placement.cc b/tensorflow/core/common_runtime/arg_ret_placement.cc index 4a0ceba5da1e99..a1a8c85ca64e1d 100644 --- a/tensorflow/core/common_runtime/arg_ret_placement.cc +++ b/tensorflow/core/common_runtime/arg_ret_placement.cc @@ -59,7 +59,7 @@ bool LogMemoryTypeMismatch(bool use_host_memory, const FullTypeDef& ft) { return true; } -Status CheckMemoryType(bool use_host_memory, const FullTypeDef& ft) { +absl::Status CheckMemoryType(bool use_host_memory, const FullTypeDef& ft) { FullTypeId id = ft.type_id(); MemoryType mt_from_ft = MemoryTypeFromFullTypeId(id); if (id == TFT_PRODUCT) { @@ -78,7 +78,7 @@ Status CheckMemoryType(bool use_host_memory, const FullTypeDef& ft) { // Note that ints_on_device is only true for single device functions // (i.e. for cases where Placer is not run). -static Status SetMemoryTypeForNode( +static absl::Status SetMemoryTypeForNode( const Node* node, const DataType dtype, bool is_arg, bool weak_flag, bool ints_on_device, MemoryTypeVector* memory_types, std::vector* alloc_attrs) { @@ -156,7 +156,7 @@ static Status SetMemoryTypeForNode( } // This helper function takes a list of nodes. -static Status SetMemoryTypeHelper( +static absl::Status SetMemoryTypeHelper( const absl::InlinedVector& nodes, const DataTypeVector& dtypes, bool is_arg, bool weak_flag, MemoryTypeVector* memory_types, std::vector* alloc_attrs) { @@ -176,7 +176,7 @@ static Status SetMemoryTypeHelper( // Note that ints_on_device is only true for single device functions // (i.e. for cases where Placer is not run). The DataType specified by the "T" // attr of input nodes is used. -static Status SetMemoryTypeHelper( +static absl::Status SetMemoryTypeHelper( const std::vector> arg_nodes, bool weak_flag, bool ints_on_device, std::vector* alloc_attrs) { @@ -199,7 +199,7 @@ static Status SetMemoryTypeHelper( // Note that ints_on_device is only true for single device functions // (i.e. for cases where Placer is not run). The DataType specified by the "T" // attr of input nodes is used. -static Status SetMemoryTypeHelper( +static absl::Status SetMemoryTypeHelper( const std::vector> ret_nodes, bool weak_flag, bool ints_on_device, std::vector* alloc_attrs) { DCHECK(alloc_attrs != nullptr); @@ -217,9 +217,9 @@ static Status SetMemoryTypeHelper( return absl::OkStatus(); } -Status SetMemoryTypeForArgs(const absl::InlinedVector& nodes, - const DataTypeVector& dtypes, - MemoryTypeVector& memory_types) { +absl::Status SetMemoryTypeForArgs(const absl::InlinedVector& nodes, + const DataTypeVector& dtypes, + MemoryTypeVector& memory_types) { return SetMemoryTypeHelper(nodes, dtypes, /*is_arg=*/true, /*weak_flag=*/false, &memory_types, nullptr); } @@ -227,77 +227,77 @@ Status SetMemoryTypeForArgs(const absl::InlinedVector& nodes, // TODO(b/258849883) Delete the `Weak...` versions of these functions once // everything is working with the version without `Weak`. -Status WeakSetMemoryTypeForArgs(const absl::InlinedVector& nodes, - const DataTypeVector& dtypes, - MemoryTypeVector& memory_types) { +absl::Status WeakSetMemoryTypeForArgs( + const absl::InlinedVector& nodes, const DataTypeVector& dtypes, + MemoryTypeVector& memory_types) { return SetMemoryTypeHelper(nodes, dtypes, /*is_arg=*/true, /*weak_flag=*/true, &memory_types, nullptr); } -Status SetMemoryTypeForRets(const absl::InlinedVector& nodes, - const DataTypeVector& dtypes, - MemoryTypeVector& memory_types) { +absl::Status SetMemoryTypeForRets(const absl::InlinedVector& nodes, + const DataTypeVector& dtypes, + MemoryTypeVector& memory_types) { return SetMemoryTypeHelper(nodes, dtypes, /*is_arg=*/false, /*weak_flag=*/false, &memory_types, nullptr); } -Status WeakSetMemoryTypeForRets(const absl::InlinedVector& nodes, - const DataTypeVector& dtypes, - MemoryTypeVector& memory_types) { +absl::Status WeakSetMemoryTypeForRets( + const absl::InlinedVector& nodes, const DataTypeVector& dtypes, + MemoryTypeVector& memory_types) { return SetMemoryTypeHelper(nodes, dtypes, /*is_arg=*/false, /*weak_flag=*/true, &memory_types, nullptr); } -Status SetAllocAttrsForArgs(const absl::InlinedVector& nodes, - const DataTypeVector& dtypes, - std::vector& alloc_attrs) { +absl::Status SetAllocAttrsForArgs( + const absl::InlinedVector& nodes, const DataTypeVector& dtypes, + std::vector& alloc_attrs) { return SetMemoryTypeHelper(nodes, dtypes, /*is_arg=*/true, /*weak_flag=*/false, nullptr, &alloc_attrs); } -Status WeakSetAllocAttrsForArgs(const absl::InlinedVector& nodes, - const DataTypeVector& dtypes, - std::vector& alloc_attrs) { +absl::Status WeakSetAllocAttrsForArgs( + const absl::InlinedVector& nodes, const DataTypeVector& dtypes, + std::vector& alloc_attrs) { return SetMemoryTypeHelper(nodes, dtypes, /*is_arg=*/true, /*weak_flag=*/true, nullptr, &alloc_attrs); } -Status SetAllocAttrsForRets(const absl::InlinedVector& nodes, - const DataTypeVector& dtypes, - std::vector& alloc_attrs) { +absl::Status SetAllocAttrsForRets( + const absl::InlinedVector& nodes, const DataTypeVector& dtypes, + std::vector& alloc_attrs) { return SetMemoryTypeHelper(nodes, dtypes, /*is_arg=*/false, /*weak_flag=*/false, nullptr, &alloc_attrs); } -Status WeakSetAllocAttrsForRets(const absl::InlinedVector& nodes, - const DataTypeVector& dtypes, - std::vector& alloc_attrs) { +absl::Status WeakSetAllocAttrsForRets( + const absl::InlinedVector& nodes, const DataTypeVector& dtypes, + std::vector& alloc_attrs) { return SetMemoryTypeHelper(nodes, dtypes, /*is_arg=*/false, /*weak_flag=*/true, nullptr, &alloc_attrs); } -Status SingleDeviceSetAllocAttrsForArgs( +absl::Status SingleDeviceSetAllocAttrsForArgs( std::vector> arg_nodes, bool ints_on_device, std::vector& alloc_attrs) { return SetMemoryTypeHelper(arg_nodes, /*weak_flag=*/false, ints_on_device, &alloc_attrs); } -Status WeakSingleDeviceSetAllocAttrsForArgs( +absl::Status WeakSingleDeviceSetAllocAttrsForArgs( std::vector> arg_nodes, bool ints_on_device, std::vector& alloc_attrs) { return SetMemoryTypeHelper(arg_nodes, /*weak_flag=*/true, ints_on_device, &alloc_attrs); } -Status SingleDeviceSetAllocAttrsForRets( +absl::Status SingleDeviceSetAllocAttrsForRets( const std::vector> ret_nodes, bool ints_on_device, std::vector& alloc_attrs) { return SetMemoryTypeHelper(ret_nodes, /*weak_flag=*/false, ints_on_device, &alloc_attrs); } -Status WeakSingleDeviceSetAllocAttrsForRets( +absl::Status WeakSingleDeviceSetAllocAttrsForRets( const std::vector> ret_nodes, bool ints_on_device, std::vector& alloc_attrs) { return SetMemoryTypeHelper(ret_nodes, /*weak_flag=*/true, ints_on_device, diff --git a/tensorflow/core/common_runtime/arg_ret_placement.h b/tensorflow/core/common_runtime/arg_ret_placement.h index 63d9dbb794f37b..e0b401823975a5 100644 --- a/tensorflow/core/common_runtime/arg_ret_placement.h +++ b/tensorflow/core/common_runtime/arg_ret_placement.h @@ -31,9 +31,9 @@ namespace tensorflow::full_type { // expected full_type information. If an error raised about bad full // time information causes a breakage, changing `SetMemoryTypeForArgs` to // `WeakSetMemoryTypeForArgs` is a possible work around. -Status SetMemoryTypeForArgs(const absl::InlinedVector& nodes, - const DataTypeVector& dtypes, - MemoryTypeVector& memory_types); +absl::Status SetMemoryTypeForArgs(const absl::InlinedVector& nodes, + const DataTypeVector& dtypes, + MemoryTypeVector& memory_types); // TODO(b/258849883) Delete the `Weak...` versions of these functions once // everything is working with the version without `Weak`. @@ -41,9 +41,9 @@ Status SetMemoryTypeForArgs(const absl::InlinedVector& nodes, // Set the contents of memory_types for args (inputs to functions, "_Arg" ops) // based on dtype. Logging of warnings if an int32 arg does not have // expected full_type information can be enabled. -Status WeakSetMemoryTypeForArgs(const absl::InlinedVector& nodes, - const DataTypeVector& dtypes, - MemoryTypeVector& memory_types); +absl::Status WeakSetMemoryTypeForArgs( + const absl::InlinedVector& nodes, const DataTypeVector& dtypes, + MemoryTypeVector& memory_types); // Set the contents of memory_types for rets (outputs from functions, "_Retval" // ops) based on dtype. Raises an error if an int32 ret does not have @@ -51,33 +51,33 @@ Status WeakSetMemoryTypeForArgs(const absl::InlinedVector& nodes, // does not have expected full type information). If an error raised about bad // full time information causes a breakage, changing `SetMemoryTypeForRets` to // `WeakSetMemoryTypeForRets` is a possible work around. -Status SetMemoryTypeForRets(const absl::InlinedVector& nodes, - const DataTypeVector& dtypes, - MemoryTypeVector& memory_types); +absl::Status SetMemoryTypeForRets(const absl::InlinedVector& nodes, + const DataTypeVector& dtypes, + MemoryTypeVector& memory_types); // Set the contents of memory_types for rets (outputs from functions, "_Retval" // ops) based on dtype. Logging of warnings if an int32 ret does not have // expected full_type information (i.e. if the source of the input to the ret // does not have expected full type information) can be enabled. -Status WeakSetMemoryTypeForRets(const absl::InlinedVector& nodes, - const DataTypeVector& dtypes, - MemoryTypeVector& memory_types); +absl::Status WeakSetMemoryTypeForRets( + const absl::InlinedVector& nodes, const DataTypeVector& dtypes, + MemoryTypeVector& memory_types); // Set the contents of alloc_attrs for args (inputs to functions, "_Arg" ops) // based on dtype. Raises an error if an int32 arg does not have // expected full_type information. If an error raised about bad full // time information causes a breakage, changing `SetAllocAttrsForArgs` to // `WeakSetAllocAttrsForArgs` is a possible work around. -Status SetAllocAttrsForArgs(const absl::InlinedVector& nodes, - const DataTypeVector& dtypes, - std::vector& alloc_attrs); +absl::Status SetAllocAttrsForArgs( + const absl::InlinedVector& nodes, const DataTypeVector& dtypes, + std::vector& alloc_attrs); // Set the contents of alloc_attrs for args (inputs to functions, "_Arg" ops) // based on dtype. Logging of warnings if an int32 arg does not have // expected full_type information can be enabled. -Status WeakSetAllocAttrsForArgs(const absl::InlinedVector& nodes, - const DataTypeVector& dtypes, - std::vector& alloc_attrs); +absl::Status WeakSetAllocAttrsForArgs( + const absl::InlinedVector& nodes, const DataTypeVector& dtypes, + std::vector& alloc_attrs); // Set the contents of alloc_attrs for rets (outputs from functions, "_Retval" // ops) based on dtype. Raises an error if an int32 ret does not have @@ -85,17 +85,17 @@ Status WeakSetAllocAttrsForArgs(const absl::InlinedVector& nodes, // does not have expected full type information). If an error raised about bad // full time information causes a breakage, changing `SetAllocAttrsForRets` to // `WeakSetAllocAttrsForRets` is a possible work around. -Status SetAllocAttrsForRets(const absl::InlinedVector& nodes, - const DataTypeVector& dtypes, - std::vector& alloc_attrs); +absl::Status SetAllocAttrsForRets( + const absl::InlinedVector& nodes, const DataTypeVector& dtypes, + std::vector& alloc_attrs); // Set the contents of alloc_attrs for rets (outputs from functions, "_Retval" // ops) based on dtype. Logging of warnings if an int32 ret does not have // expected full_type information (i.e. if the source of the input to the ret // does not have expected full type information) can be enabled. -Status WeakSetAllocAttrsForRets(const absl::InlinedVector& nodes, - const DataTypeVector& dtypes, - std::vector& alloc_attrs); +absl::Status WeakSetAllocAttrsForRets( + const absl::InlinedVector& nodes, const DataTypeVector& dtypes, + std::vector& alloc_attrs); // Set the contents of alloc_attrs for args (inputs to functions, "_Arg" ops) // for a single device funtion based on dtype. Raises an error if an int32 arg @@ -104,7 +104,7 @@ Status WeakSetAllocAttrsForRets(const absl::InlinedVector& nodes, // `SingleDeviceSetAllocAttrsForArgs` to `WeakSingleDeviceSetAllocAttrsForArgs` // is a possible work around. The DataType specified by the "T" attr of input // nodes is used. -Status SingleDeviceSetAllocAttrsForArgs( +absl::Status SingleDeviceSetAllocAttrsForArgs( std::vector> arg_nodes, bool ints_on_device, std::vector& alloc_attrs); @@ -112,7 +112,7 @@ Status SingleDeviceSetAllocAttrsForArgs( // for a single device based on dtype. Logging of warnings if an int32 arg does // not have expected full_type information can be enabled. The DataType // specified by the "T" attr of input nodes is used. -Status WeakSingleDeviceSetAllocAttrsForArgs( +absl::Status WeakSingleDeviceSetAllocAttrsForArgs( std::vector> arg_nodes, bool ints_on_device, std::vector& alloc_attrs); @@ -124,7 +124,7 @@ Status WeakSingleDeviceSetAllocAttrsForArgs( // `SingleDeviceSetAllocAttrsForRets` to `WeakSingleDeviceSetAllocAttrsForRets` // is a possible work around. The DataType specified by the "T" attr of input // nodes is used. -Status SingleDeviceSetAllocAttrsForRets( +absl::Status SingleDeviceSetAllocAttrsForRets( std::vector> ret_nodes, bool ints_on_device, std::vector& alloc_attrs); @@ -133,7 +133,7 @@ Status SingleDeviceSetAllocAttrsForRets( // does not have expected full_type information (i.e. if the source of the input // to the ret does not have expected full type information) can be enabled. The // DataType specified by the "T" attr of input nodes is used. -Status WeakSingleDeviceSetAllocAttrsForRets( +absl::Status WeakSingleDeviceSetAllocAttrsForRets( std::vector> ret_nodes, bool ints_on_device, std::vector& alloc_attrs); @@ -151,7 +151,7 @@ bool LogMemoryTypeMismatch(bool use_host_memory, const FullTypeDef& ft); // and raise an error if not. Note the FT is expected to be the full type // information for a tensor, not for the whole ouput of an op, i.e. it should // not have an outer TFT_PRODUCT. -Status CheckMemoryType(bool use_host_memory, const FullTypeDef& ft); +absl::Status CheckMemoryType(bool use_host_memory, const FullTypeDef& ft); } // namespace tensorflow::full_type diff --git a/tensorflow/core/common_runtime/arg_ret_placement_test.cc b/tensorflow/core/common_runtime/arg_ret_placement_test.cc index 0d5bbe1443f1cc..8e8ea9ffb6caf5 100644 --- a/tensorflow/core/common_runtime/arg_ret_placement_test.cc +++ b/tensorflow/core/common_runtime/arg_ret_placement_test.cc @@ -41,14 +41,14 @@ class FullTypeGraphUtilsTest : public ::testing::Test { : graph_(OpRegistry::Global()), root_(Scope::NewRootScope().ExitOnError()) {} - Status MakeArg(Node **arg, DataType dtype) { + absl::Status MakeArg(Node **arg, DataType dtype) { return NodeBuilder("arg", "_Arg", &root_.graph()->flib_def()) .Attr("T", dtype) .Attr("index", 0) .Finalize(root_.graph(), arg); } - Status MakeRet(Node *src, Node **ret, DataType dtype) { + absl::Status MakeRet(Node *src, Node **ret, DataType dtype) { return NodeBuilder("ret", "_Retval", &root_.graph()->flib_def()) .Input(src, 0) .Attr("T", dtype) @@ -57,7 +57,7 @@ class FullTypeGraphUtilsTest : public ::testing::Test { } public: - Status MakeArgRet(Node **arg, Node **ret, DataType dtype) { + absl::Status MakeArgRet(Node **arg, Node **ret, DataType dtype) { TF_RETURN_IF_ERROR(MakeArg(arg, dtype)); return MakeRet(*arg, ret, dtype); } @@ -152,7 +152,8 @@ TEST_F(FullTypeGraphUtilsTest, ArgError) { nodes.push_back(arg); dtypes.push_back(DT_INT32); - Status status = full_type::SetMemoryTypeForArgs(nodes, dtypes, memory_types); + absl::Status status = + full_type::SetMemoryTypeForArgs(nodes, dtypes, memory_types); EXPECT_FALSE(status.ok()); } @@ -230,7 +231,8 @@ TEST_F(FullTypeGraphUtilsTest, RetError) { TF_ASSERT_OK(MakeArgRet(&arg, &ret, DT_INT32)); nodes.push_back(ret); dtypes.push_back(DT_INT32); - Status status = full_type::SetMemoryTypeForRets(nodes, dtypes, memory_types); + absl::Status status = + full_type::SetMemoryTypeForRets(nodes, dtypes, memory_types); EXPECT_FALSE(status.ok()); } @@ -302,7 +304,7 @@ TEST_F(FullTypeGraphUtilsTest, SingleDeviceAllocAttrsRetError) { // test TFT_SHAPE_TENSOR and ints_on_device=true mismatch AddArgFullType(arg, TFT_SHAPE_TENSOR, TFT_INT32); ret_nodes.push_back(std::make_pair(ret, 0)); - Status status = full_type::SingleDeviceSetAllocAttrsForRets( + absl::Status status = full_type::SingleDeviceSetAllocAttrsForRets( ret_nodes, /*ints_on_device=*/true, alloc_attrs); EXPECT_FALSE(status.ok()); } @@ -348,7 +350,7 @@ TEST_F(FullTypeGraphUtilsTest, CheckMemoryTypeBadFT) { AddArgFullType(node, TFT_SHAPE_TENSOR, TFT_INT32); // full type information for the whole node, not for one tensor / one output const FullTypeDef &ft = node->def().experimental_type(); - Status status = full_type::CheckMemoryType(true, ft); + absl::Status status = full_type::CheckMemoryType(true, ft); EXPECT_FALSE(status.ok()); } @@ -358,7 +360,7 @@ TEST_F(FullTypeGraphUtilsTest, CheckMemoryTypeWrongFT) { AddArgFullType(node, TFT_SHAPE_TENSOR, TFT_INT32); const FullTypeDef &ft = node->def().experimental_type().args()[0]; // use_host_memory=false does not match TFT_SHAPE_TENSOR - Status status = full_type::CheckMemoryType(false, ft); + absl::Status status = full_type::CheckMemoryType(false, ft); EXPECT_FALSE(status.ok()); } diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc index c5ed8fe9042032..c9ed34b29cd78c 100644 --- a/tensorflow/core/common_runtime/base_collective_executor.cc +++ b/tensorflow/core/common_runtime/base_collective_executor.cc @@ -228,8 +228,8 @@ CollectiveAdapter* MakeCollectiveAdapter(Tensor* output, int num_chunks, BaseCollectiveExecutor::~BaseCollectiveExecutor() {} -void BaseCollectiveExecutor::StartAbort(const Status& s) { - Status status; +void BaseCollectiveExecutor::StartAbort(const absl::Status& s) { + absl::Status status; { mutex_lock l(status_mu_); if (!status_.ok()) { @@ -237,7 +237,7 @@ void BaseCollectiveExecutor::StartAbort(const Status& s) { << s; return; } - status_ = StatusGroup::MakeDerived(Status( + status_ = StatusGroup::MakeDerived(absl::Status( s.code(), absl::StrCat( "Collective ops is aborted by: ", s.message(), @@ -253,7 +253,7 @@ void BaseCollectiveExecutor::StartAbort(const Status& s) { } } -Status BaseCollectiveExecutor::GetStatus(const Status& s) { +absl::Status BaseCollectiveExecutor::GetStatus(const absl::Status& s) { if (s.ok()) return s; mutex_lock l(status_mu_); // If the collective executor is already aborted, use the aborted status @@ -274,7 +274,8 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx, StatusCallback done) { // See CompleteParamsAsync() how done() and the timeout callback interacts. const auto is_callback_called = std::make_shared>(false); - auto done_safe = [this, done, ctx, is_callback_called](const Status& s) { + auto done_safe = [this, done, ctx, + is_callback_called](const absl::Status& s) { bool called = is_callback_called->exchange(true); if (!called) { if (!s.ok() && !IsCancelled(ctx->cancellation_manager())) { @@ -293,8 +294,8 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx, timeout_microseconds, [this, is_callback_called, done] { bool called = is_callback_called->exchange(true); if (!called) { - Status status(absl::StatusCode::kDeadlineExceeded, - "Collective has timed out during execution."); + absl::Status status(absl::StatusCode::kDeadlineExceeded, + "Collective has timed out during execution."); StartAbort(status); done(status); } @@ -313,7 +314,7 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx, ? &ctx->input(0) : nullptr; CollectiveImplementationInterface* col_impl = nullptr; - Status status = CreateCollective(*col_params, &col_impl); + absl::Status status = CreateCollective(*col_params, &col_impl); if (!status.ok()) { done_safe(status); DCHECK_EQ(nullptr, col_impl); @@ -331,15 +332,17 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx, // Run on an unbounded work queue that can handle blocking work so as to not // starve executor threads. col_impl->Ref(); - profiler::TraceMeProducer producer("BaseCollectiveExecutor::ExecuteAsync"); + tsl::profiler::TraceMeProducer producer( + "BaseCollectiveExecutor::ExecuteAsync"); RunClosure([col_impl, col_ctx, done_safe, ctx, context_id = producer.GetContextId()]() { core::ScopedUnref unref(col_impl); - profiler::TraceMeConsumer consumer( + tsl::profiler::TraceMeConsumer consumer( [ctx, col_ctx] { - string op = profiler::TraceMeOp(ctx->op_kernel().name_view(), - ctx->op_kernel().type_string_view()); - return profiler::TraceMeEncode( + string op = + tsl::profiler::TraceMeOp(ctx->op_kernel().name_view(), + ctx->op_kernel().type_string_view()); + return tsl::profiler::TraceMeEncode( std::move(op), {{"step_id", ctx->step_id()}, {"iter_id", ctx->frame_iter().iter_id}, @@ -350,7 +353,7 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx, }, context_id); col_impl->Ref(); - col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) { + col_impl->Run([col_impl, col_ctx, done_safe](const absl::Status& s) { core::ScopedUnref unref(col_impl); done_safe(s); }); @@ -367,15 +370,15 @@ void BaseCollectiveExecutor::CompleteParamsAsync( // timeout callback executes, done_safe will become a no-op and the timeout // callback is responsible for invoking done() at the end. const auto is_callback_called = std::make_shared>(false); - int64_t trace_id = profiler::TraceMe::ActivityStart([cp]() { - return profiler::TraceMeEncode("CollectiveExecutor::CompleteParams", - {{"group_key", cp->group.group_key}, - {"group_size", cp->group.group_size}}); + int64_t trace_id = tsl::profiler::TraceMe::ActivityStart([cp]() { + return tsl::profiler::TraceMeEncode("CollectiveExecutor::CompleteParams", + {{"group_key", cp->group.group_key}, + {"group_size", cp->group.group_size}}); }); auto done_safe = [this, is_callback_called, cancel_mgr, trace_id, - done](const Status& s) { - profiler::TraceMe::ActivityEnd(trace_id); + done](const absl::Status& s) { + tsl::profiler::TraceMe::ActivityEnd(trace_id); bool called = is_callback_called->exchange(true); if (!called) { if (!s.ok() && !IsCancelled(cancel_mgr)) { @@ -394,6 +397,7 @@ void BaseCollectiveExecutor::CompleteParamsAsync( // TODO(xldrx): Share the timeout watchdog thread among collectives. int64_t usecs = std::min(timeout_microseconds, mio); SchedNonBlockingClosureAfter( +<<<<<<< HEAD usecs, [this, is_callback_called, done, timeout_microseconds, usecs]() { for(auto cnt = timeout_microseconds - usecs; cnt > 0; cnt -= mio) { if(bool called = is_callback_called->exchange(false); called) { @@ -406,6 +410,14 @@ void BaseCollectiveExecutor::CompleteParamsAsync( Status status( absl::StatusCode::kDeadlineExceeded, "Collective has timed out waiting for other workers."); +======= + timeout_microseconds, [this, is_callback_called, done]() { + bool called = is_callback_called->exchange(true); + if (!called) { + absl::Status status( + absl::StatusCode::kDeadlineExceeded, + "Collective has timed out waiting for other workers."); +>>>>>>> master StartAbort(status); done(status); } @@ -415,7 +427,7 @@ void BaseCollectiveExecutor::CompleteParamsAsync( done_safe); } -Status BaseCollectiveExecutor::CreateCollective( +absl::Status BaseCollectiveExecutor::CreateCollective( const CollectiveParams& col_params, CollectiveImplementationInterface** col_impl) { VLOG(2) << "CreateCollective type " diff --git a/tensorflow/core/common_runtime/base_collective_executor.h b/tensorflow/core/common_runtime/base_collective_executor.h index 242b92b2e749f7..0c4689bce5eac5 100644 --- a/tensorflow/core/common_runtime/base_collective_executor.h +++ b/tensorflow/core/common_runtime/base_collective_executor.h @@ -107,7 +107,7 @@ class BaseCollectiveExecutor : public CollectiveExecutor { ~BaseCollectiveExecutor() override; - void StartAbort(const Status& s) override TF_LOCKS_EXCLUDED(status_mu_); + void StartAbort(const absl::Status& s) override TF_LOCKS_EXCLUDED(status_mu_); void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams* col_params, const string& exec_key, StatusCallback done) override; @@ -147,17 +147,17 @@ class BaseCollectiveExecutor : public CollectiveExecutor { // been launched. std::unordered_map launched_ TF_GUARDED_BY(launch_mu_); mutex status_mu_; - Status status_ TF_GUARDED_BY(status_mu_); + absl::Status status_ TF_GUARDED_BY(status_mu_); private: - Status CreateCollective(const CollectiveParams& col_params, - CollectiveImplementationInterface** col_impl); + absl::Status CreateCollective(const CollectiveParams& col_params, + CollectiveImplementationInterface** col_impl); // Check if all ops on which this collective depends on have launched. bool CheckDependencies(const CollectiveParams& col_params) TF_EXCLUSIVE_LOCKS_REQUIRED(launch_mu_); // Tries to return the status that is the original error. It returns the // aborted status if the collective executor is aborted. - Status GetStatus(const Status& s) TF_LOCKS_EXCLUDED(status_mu_); + absl::Status GetStatus(const absl::Status& s) TF_LOCKS_EXCLUDED(status_mu_); }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/buf_rendezvous.cc b/tensorflow/core/common_runtime/buf_rendezvous.cc index 8bd8a2c1a10ec6..62654d1f5c2d51 100644 --- a/tensorflow/core/common_runtime/buf_rendezvous.cc +++ b/tensorflow/core/common_runtime/buf_rendezvous.cc @@ -43,7 +43,7 @@ BufRendezvous::~BufRendezvous() { } } -void BufRendezvous::StartAbort(const Status& s) { +void BufRendezvous::StartAbort(const absl::Status& s) { CHECK(!s.ok()); HookTable dummy_table; { @@ -58,7 +58,7 @@ void BufRendezvous::StartAbort(const Status& s) { PurgeTable(s, &dummy_table); } -void BufRendezvous::PurgeTable(const Status& s, HookTable* table) { +void BufRendezvous::PurgeTable(const absl::Status& s, HookTable* table) { for (auto& it : *table) { Hook* h = it.second; if (h->cancellation_manager != nullptr) { @@ -96,7 +96,7 @@ void BufRendezvous::ProvideBuf(const string& key, Device* dev, } #endif Hook* h = nullptr; - Status providebuf_status; + absl::Status providebuf_status; do { mutex_lock l(mu_); if (!status_.ok()) { @@ -168,7 +168,7 @@ void BufRendezvous::ConsumeBuf(const string& key, const string& device_name, // Check the incarnation in the request matches the current device // incarnation of the producer. Device* device; - Status consumebuf_status = dev_mgr_->LookupDevice(device_name, &device); + absl::Status consumebuf_status = dev_mgr_->LookupDevice(device_name, &device); if (consumebuf_status.ok() && device->attributes().incarnation() != device_incarnation) { consumebuf_status = errors::FailedPrecondition( diff --git a/tensorflow/core/common_runtime/buf_rendezvous.h b/tensorflow/core/common_runtime/buf_rendezvous.h index 0ee6099eea5494..8c2d201e8781ca 100644 --- a/tensorflow/core/common_runtime/buf_rendezvous.h +++ b/tensorflow/core/common_runtime/buf_rendezvous.h @@ -48,17 +48,17 @@ class BufRendezvous { // Inform all waiting parties that this BufRendezvous is defunct because of // an error Status interrupting the Step. - void StartAbort(const Status& s); + void StartAbort(const absl::Status& s); struct Hook; // Provided by the consumer to be called when access to the buffer // is available. If the Status arg is not OK, then hook will not // be populated. Ownership of Hook passes to consumer with the // callback. - typedef std::function ConsumerCallback; + typedef std::function ConsumerCallback; // Provided by the producer to be called when the consumer has finished // reading the buffer and will no longer access it. - typedef std::function ProducerCallback; + typedef std::function ProducerCallback; struct Hook { Device* prod_dev; @@ -124,11 +124,11 @@ class BufRendezvous { const uint64 step_id_; const DeviceMgr* const dev_mgr_; // Not owned. mutex mu_; - Status status_ TF_GUARDED_BY(mu_); + absl::Status status_ TF_GUARDED_BY(mu_); typedef absl::flat_hash_map HookTable; HookTable hook_table_ TF_GUARDED_BY(mu_); - void PurgeTable(const Status& s, HookTable* table); + void PurgeTable(const absl::Status& s, HookTable* table); }; } // namespace tensorflow #endif // TENSORFLOW_CORE_COMMON_RUNTIME_BUF_RENDEZVOUS_H_ diff --git a/tensorflow/core/common_runtime/buf_rendezvous_test.cc b/tensorflow/core/common_runtime/buf_rendezvous_test.cc index 1bbb828f32c4e6..b549b012a9ffa6 100644 --- a/tensorflow/core/common_runtime/buf_rendezvous_test.cc +++ b/tensorflow/core/common_runtime/buf_rendezvous_test.cc @@ -35,7 +35,7 @@ class BufRendezvousTest : public ::testing::Test { public: explicit FakeDevice(const DeviceAttributes& attrs) : Device(nullptr, attrs) {} - Status Sync() override { return absl::OkStatus(); } + absl::Status Sync() override { return absl::OkStatus(); } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } }; DeviceAttributes attrs; @@ -80,14 +80,14 @@ const string* const BufRendezvousTest::kDefaultDeviceName = const uint64 BufRendezvousTest::kDefaultIncarnation = 12345; TEST_F(BufRendezvousTest, CorrectUseProducerFirst) { - Status prod_status; - Status cons_status; + absl::Status prod_status; + absl::Status cons_status; bool prod_callback_called = false; bool cons_callback_called = false; Notification note; br_->ProvideBuf( *kDefaultKey, default_device_, fake_device_context_, &a_, aa_, - [¬e, &prod_status, &prod_callback_called](const Status& s) { + [¬e, &prod_status, &prod_callback_called](const absl::Status& s) { prod_status = s; prod_callback_called = true; note.Notify(); @@ -96,7 +96,7 @@ TEST_F(BufRendezvousTest, CorrectUseProducerFirst) { EXPECT_FALSE(prod_callback_called); br_->ConsumeBuf( *kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation, - [this, &cons_status, &cons_callback_called](const Status& s, + [this, &cons_status, &cons_callback_called](const absl::Status& s, BufRendezvous::Hook* h) { cons_status = s; cons_callback_called = true; @@ -115,14 +115,14 @@ TEST_F(BufRendezvousTest, CorrectUseProducerFirst) { } TEST_F(BufRendezvousTest, CorrectUseConsumerFirst) { - Status prod_status; - Status cons_status; + absl::Status prod_status; + absl::Status cons_status; bool prod_callback_called = false; bool cons_callback_called = false; Notification note; br_->ConsumeBuf( *kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation, - [this, &cons_status, &cons_callback_called](const Status& s, + [this, &cons_status, &cons_callback_called](const absl::Status& s, BufRendezvous::Hook* h) { cons_status = s; cons_callback_called = true; @@ -136,7 +136,7 @@ TEST_F(BufRendezvousTest, CorrectUseConsumerFirst) { EXPECT_FALSE(cons_callback_called); br_->ProvideBuf( *kDefaultKey, default_device_, fake_device_context_, &a_, aa_, - [¬e, &prod_status, &prod_callback_called](const Status& s) { + [¬e, &prod_status, &prod_callback_called](const absl::Status& s) { prod_status = s; prod_callback_called = true; note.Notify(); @@ -153,13 +153,15 @@ TEST_F(BufRendezvousTest, ErrorDuplicatePut) { bool prod_callback_called = false; br_->ProvideBuf( *kDefaultKey, default_device_, fake_device_context_, &a_, aa_, - [&prod_callback_called](const Status& s) { prod_callback_called = true; }, + [&prod_callback_called](const absl::Status& s) { + prod_callback_called = true; + }, &cm_); - Status bad_status; + absl::Status bad_status; Notification note; br_->ProvideBuf( *kDefaultKey, default_device_, fake_device_context_, &a_, aa_, - [&bad_status, ¬e](const Status& s) { + [&bad_status, ¬e](const absl::Status& s) { bad_status = s; note.Notify(); }, @@ -174,10 +176,10 @@ TEST_F(BufRendezvousTest, ErrorDuplicatePut) { } TEST_F(BufRendezvousTest, ErrorDeleteNonEmpty) { - Status cons_status; + absl::Status cons_status; br_->ConsumeBuf( *kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation, - [&cons_status](const Status& s, BufRendezvous::Hook* h) { + [&cons_status](const absl::Status& s, BufRendezvous::Hook* h) { cons_status = s; EXPECT_EQ(h, nullptr); }, @@ -189,20 +191,21 @@ TEST_F(BufRendezvousTest, ErrorDeleteNonEmpty) { } TEST_F(BufRendezvousTest, AbortNonEmpty) { - Status cons_status; - Status prod_status; + absl::Status cons_status; + absl::Status prod_status; Notification prod_note; Notification cons_note; br_->ConsumeBuf( *kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation, - [&cons_note, &cons_status](const Status& s, BufRendezvous::Hook* h) { + [&cons_note, &cons_status](const absl::Status& s, + BufRendezvous::Hook* h) { cons_status = s; cons_note.Notify(); }, &cm_); br_->ProvideBuf( "key1", default_device_, fake_device_context_, &a_, aa_, - [&prod_note, &prod_status](const Status& s) { + [&prod_note, &prod_status](const absl::Status& s) { prod_status = s; prod_note.Notify(); }, @@ -222,20 +225,21 @@ TEST_F(BufRendezvousTest, AbortEmpty) { TEST_F(BufRendezvousTest, UseAfterAbort) { br_->StartAbort(errors::Internal("Falling sky detected")); - Status cons_status; - Status prod_status; + absl::Status cons_status; + absl::Status prod_status; Notification prod_note; Notification cons_note; br_->ConsumeBuf( *kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation, - [&cons_note, &cons_status](const Status& s, BufRendezvous::Hook* h) { + [&cons_note, &cons_status](const absl::Status& s, + BufRendezvous::Hook* h) { cons_status = s; cons_note.Notify(); }, &cm_); br_->ProvideBuf( "key1", default_device_, fake_device_context_, &a_, aa_, - [&prod_note, &prod_status](const Status& s) { + [&prod_note, &prod_status](const absl::Status& s) { prod_status = s; prod_note.Notify(); }, @@ -249,15 +253,15 @@ TEST_F(BufRendezvousTest, UseAfterAbort) { } TEST_F(BufRendezvousTest, DeviceIncarnationMismatch) { - Status cons_status; + absl::Status cons_status; Notification note; br_->ProvideBuf( *kDefaultKey, default_device_, fake_device_context_, &a_, aa_, - [](const Status&) {}, /*cancellation_manager=*/nullptr); + [](const absl::Status&) {}, /*cancellation_manager=*/nullptr); const uint64 incorrect_incarnation = 23456; br_->ConsumeBuf( *kDefaultKey, *kDefaultDeviceName, incorrect_incarnation, - [¬e, &cons_status](const Status& s, BufRendezvous::Hook* h) { + [¬e, &cons_status](const absl::Status& s, BufRendezvous::Hook* h) { cons_status = s; note.Notify(); }, @@ -267,11 +271,11 @@ TEST_F(BufRendezvousTest, DeviceIncarnationMismatch) { } TEST_F(BufRendezvousTest, ProvideThenCancel) { - Status status; + absl::Status status; Notification note; br_->ProvideBuf( *kDefaultKey, default_device_, fake_device_context_, &a_, aa_, - [&status, ¬e](const Status& s) { + [&status, ¬e](const absl::Status& s) { status = s; note.Notify(); }, @@ -286,12 +290,12 @@ TEST_F(BufRendezvousTest, ProvideThenCancel) { } TEST_F(BufRendezvousTest, CancelThenProvide) { - Status status; + absl::Status status; Notification note; cm_.StartCancel(); br_->ProvideBuf( *kDefaultKey, default_device_, fake_device_context_, &a_, aa_, - [&status, ¬e](const Status& s) { + [&status, ¬e](const absl::Status& s) { status = s; note.Notify(); }, @@ -305,11 +309,11 @@ TEST_F(BufRendezvousTest, CancelThenProvide) { } TEST_F(BufRendezvousTest, ConsumeThenCancel) { - Status status; + absl::Status status; Notification note; br_->ConsumeBuf( *kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation, - [&status, ¬e](const Status& s, BufRendezvous::Hook* h) { + [&status, ¬e](const absl::Status& s, BufRendezvous::Hook* h) { status = s; note.Notify(); }, @@ -324,12 +328,12 @@ TEST_F(BufRendezvousTest, ConsumeThenCancel) { } TEST_F(BufRendezvousTest, CancelThenConsume) { - Status status; + absl::Status status; Notification note; cm_.StartCancel(); br_->ConsumeBuf( *kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation, - [&status, ¬e](const Status& s, BufRendezvous::Hook* h) { + [&status, ¬e](const absl::Status& s, BufRendezvous::Hook* h) { status = s; note.Notify(); }, @@ -343,14 +347,14 @@ TEST_F(BufRendezvousTest, CancelThenConsume) { } TEST_F(BufRendezvousTest, ProvideConsumeThenCancel) { - Status prod_status; - Status cons_status; + absl::Status prod_status; + absl::Status cons_status; bool prod_callback_called = false; bool cons_callback_called = false; Notification note; br_->ProvideBuf( *kDefaultKey, default_device_, fake_device_context_, &a_, aa_, - [¬e, &prod_status, &prod_callback_called](const Status& s) { + [¬e, &prod_status, &prod_callback_called](const absl::Status& s) { prod_status = s; prod_callback_called = true; note.Notify(); @@ -359,7 +363,7 @@ TEST_F(BufRendezvousTest, ProvideConsumeThenCancel) { EXPECT_FALSE(prod_callback_called); br_->ConsumeBuf( *kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation, - [this, &cons_status, &cons_callback_called](const Status& s, + [this, &cons_status, &cons_callback_called](const absl::Status& s, BufRendezvous::Hook* h) { cons_status = s; cons_callback_called = true; @@ -379,14 +383,14 @@ TEST_F(BufRendezvousTest, ProvideConsumeThenCancel) { } TEST_F(BufRendezvousTest, CancelThenProvideConsume) { - Status prod_status; - Status cons_status; + absl::Status prod_status; + absl::Status cons_status; bool prod_callback_called = false; bool cons_callback_called = false; cm_.StartCancel(); br_->ProvideBuf( *kDefaultKey, default_device_, fake_device_context_, &a_, aa_, - [&prod_status, &prod_callback_called](const Status& s) { + [&prod_status, &prod_callback_called](const absl::Status& s) { prod_status = s; EXPECT_TRUE(errors::IsCancelled(prod_status)); prod_callback_called = true; @@ -396,7 +400,7 @@ TEST_F(BufRendezvousTest, CancelThenProvideConsume) { EXPECT_TRUE(errors::IsCancelled(prod_status)); br_->ConsumeBuf( *kDefaultKey, *kDefaultDeviceName, kDefaultIncarnation, - [&cons_status, &cons_callback_called](const Status& s, + [&cons_status, &cons_callback_called](const absl::Status& s, BufRendezvous::Hook* h) { cons_status = s; EXPECT_TRUE(errors::IsCancelled(cons_status)); diff --git a/tensorflow/core/common_runtime/collective_executor_mgr_test.cc b/tensorflow/core/common_runtime/collective_executor_mgr_test.cc index 86f70de4158efd..bfa7b0bcfc7da7 100644 --- a/tensorflow/core/common_runtime/collective_executor_mgr_test.cc +++ b/tensorflow/core/common_runtime/collective_executor_mgr_test.cc @@ -81,22 +81,22 @@ TEST_F(CollectiveExecutorMgrTest, FindOrCreate) { TEST_F(CollectiveExecutorMgrTest, StepSequenceRelated) { EXPECT_EQ(CollectiveExecutor::kInvalidId, cme_->NextStepId(123)); Notification ss_note; - Status ss_status; - cme_->RefreshStepIdSequenceAsync(123, - [&ss_status, &ss_note](const Status& s) { - ss_status = s; - ss_note.Notify(); - }); + absl::Status ss_status; + cme_->RefreshStepIdSequenceAsync( + 123, [&ss_status, &ss_note](const absl::Status& s) { + ss_status = s; + ss_note.Notify(); + }); ss_note.WaitForNotification(); EXPECT_FALSE(ss_status.ok()); EXPECT_EQ(ss_status.message(), "CollectiveExecutorMgr does not implement RefreshStepIdSequence."); Notification gs_note; - Status gs_status; + absl::Status gs_status; GetStepSequenceRequest* req = nullptr; GetStepSequenceResponse* resp = nullptr; cme_->GetStepSequenceAsync(req, resp, - [&gs_status, &gs_note](const Status& s) { + [&gs_status, &gs_note](const absl::Status& s) { gs_status = s; gs_note.Notify(); }); diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc index 46de7b68645064..ea16129c33cd42 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc @@ -101,7 +101,8 @@ struct RankFormatter { } }; -Status CheckUserSpecifiedRanks(const std::vector members) { +absl::Status CheckUserSpecifiedRanks( + const std::vector members) { absl::flat_hash_set user_ranks = {}; bool at_least_one_member_with_no_rank = false; bool at_least_one_member_with_user_rank = false; @@ -139,7 +140,7 @@ void CollectiveParamResolverLocal::CompleteGroupLocal( std::vector to_be_called; GroupRec* gr = nullptr; - Status status; + absl::Status status; { mutex_lock l(group_mu_); auto it = group_table_.find(group_params->group_key); @@ -182,7 +183,7 @@ void CollectiveParamResolverLocal::CompleteGroupLocal( return; } done = [cancel_mgr, token, - original_done = std::move(done)](const Status& status) { + original_done = std::move(done)](const absl::Status& status) { cancel_mgr->TryDeregisterCallback(token); original_done(status); }; @@ -612,7 +613,7 @@ CollectiveParamResolverLocal::GetOrCreateInstanceRec(CollectiveParams* cp, instance_table_[cp->group.group_key][key].reset(irec); } } - Status status; + absl::Status status; { mutex_lock l(status_mu_); status = status_; @@ -624,8 +625,8 @@ CollectiveParamResolverLocal::GetOrCreateInstanceRec(CollectiveParams* cp, return irec; } -Status CollectiveParamResolverLocal::LookupGroup(int32_t group_key, - CollGroupParams* group) { +absl::Status CollectiveParamResolverLocal::LookupGroup(int32_t group_key, + CollGroupParams* group) { mutex_lock l(group_mu_); auto group_rec = group_table_.find(group_key); if (group_rec == group_table_.end()) { @@ -654,7 +655,7 @@ void CollectiveParamResolverLocal::CompleteParamsAsync( << cp->ToString(); if (cp->run_group_initialization) { CompleteGroupLocal(device, &cp->group, cancel_mgr, - [this, device, cp, done](const Status& s) { + [this, device, cp, done](const absl::Status& s) { if (s.ok()) { CompleteInstanceLocal(device.name(), cp, done); } else { @@ -731,7 +732,7 @@ void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec( const string& device, CollectiveParams* cp, InstanceRec* ir, const StatusCallback& done) { auto expected_shape = cp->instance.shape; - Status status; + absl::Status status; // Populate the fields common across instance. { mutex_lock l(ir->mu); @@ -771,7 +772,7 @@ void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec( // discovery. if (cp->instance.type == BROADCAST_COLLECTIVE) { WaitForGroup(ir, cp, [col_impl, ir, device, cp, done](InstanceRec* irec) { - Status s; + absl::Status s; if (ir != irec) { s = errors::Internal("Expected ir ", ir, " and irec ", irec, " to be equal"); @@ -841,7 +842,7 @@ void CollectiveParamResolverLocal::WaitForGroup(InstanceRec* ir, } } -void CollectiveParamResolverLocal::StartAbort(const Status& s) { +void CollectiveParamResolverLocal::StartAbort(const absl::Status& s) { { mutex_lock l(status_mu_); if (!status_.ok()) { @@ -855,7 +856,7 @@ void CollectiveParamResolverLocal::StartAbort(const Status& s) { StartAbortLocal(s); } -void CollectiveParamResolverLocal::StartAbortLocal(const Status& s) { +void CollectiveParamResolverLocal::StartAbortLocal(const absl::Status& s) { std::vector pending_done; { mutex_lock l(group_mu_); diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h index 6a702e237c171c..88813b0e98abb1 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.h +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h @@ -63,9 +63,9 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { CancellationManager* cancel_mgr, const StatusCallback& done) override; - Status LookupGroup(int32_t group_key, CollGroupParams* group) override; + absl::Status LookupGroup(int32_t group_key, CollGroupParams* group) override; - void StartAbort(const Status& s) override; + void StartAbort(const absl::Status& s) override; protected: // For access to InstanceRec and CompleteDefaultRanking. @@ -75,7 +75,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { struct GroupRec { mutable mutex mu; CollGroupParams group TF_GUARDED_BY(mu); - Status status TF_GUARDED_BY(mu); + absl::Status status TF_GUARDED_BY(mu); std::unordered_map incarnations_by_device_name TF_GUARDED_BY(mu); std::vector pending_params TF_GUARDED_BY(mu); @@ -100,7 +100,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { void CancelGroup(int32 group_key) TF_LOCKS_EXCLUDED(group_mu_); // Lookup and populate parameters from an already initialized group. - Status LookupAndPopulateGroupParams(CollGroupParams* group_params); + absl::Status LookupAndPopulateGroupParams(CollGroupParams* group_params); // Used to complete/verify CollInstance. struct InstanceRec; @@ -113,7 +113,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { // If an error occurs during initialization this structure stays in the // table with a non-OK status. Purging the table and restarting needs to be // done at a higher level. - Status status TF_GUARDED_BY(mu); + absl::Status status TF_GUARDED_BY(mu); // These fields are used to count the instances that have called // in and become known while resolving broadcast source identity and @@ -172,8 +172,8 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { // If cp.device_names contains only devices local to this process // populates *localities, else returns an error. - Status GetLocalDeviceLocalities(const CollectiveParams& cp, - std::vector* localities); + absl::Status GetLocalDeviceLocalities( + const CollectiveParams& cp, std::vector* localities); // Sets cp->instance_default_rank according to location of device in // current ordering of cp->instance.device_names. @@ -183,7 +183,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { // best implementation. void AssignCollectiveType(CollectiveParams* cp); - void StartAbortLocal(const Status& s) + void StartAbortLocal(const absl::Status& s) TF_LOCKS_EXCLUDED(status_mu_, group_mu_, instance_mu_); const bool nccl_; @@ -207,7 +207,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface { std::unique_ptr, TupleHash>> instance_table_ TF_GUARDED_BY(instance_mu_); mutex status_mu_; - Status status_ TF_GUARDED_BY(status_mu_); + absl::Status status_ TF_GUARDED_BY(status_mu_); }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc index 24d88fa99297c5..e6c8051215a7d4 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc +++ b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc @@ -162,7 +162,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteDefaultRanking) { TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) { CollectiveParams* cps[NUM_DEVS]; - Status statuses[NUM_DEVS]; + absl::Status statuses[NUM_DEVS]; Notification note[NUM_DEVS]; for (int i = 0; i < NUM_DEVS; ++i) { cps[i] = new CollectiveParams(); @@ -182,7 +182,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) { strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i); prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp, nullptr /*CancellationManager*/, - [&statuses, ¬e, i](const Status& s) { + [&statuses, ¬e, i](const absl::Status& s) { statuses[i] = s; note[i].Notify(); }); @@ -226,7 +226,7 @@ void InitializeCollectiveParamsForBroadcast(int instance_key, int device_idx, TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) { constexpr int kInstanceKey = 5; CollectiveParams* cps[NUM_DEVS]; - Status statuses[NUM_DEVS]; + absl::Status statuses[NUM_DEVS]; Notification note[NUM_DEVS]; for (int i = 0; i < NUM_DEVS; ++i) { cps[i] = new CollectiveParams(); @@ -237,7 +237,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) { strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i); prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp, nullptr /*CancellationManager*/, - [&statuses, ¬e, i](const Status& s) { + [&statuses, ¬e, i](const absl::Status& s) { statuses[i] = s; note[i].Notify(); }); @@ -268,7 +268,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) { TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcastForgotSender) { constexpr int kInstanceKey = 8; CollectiveParams* cps[NUM_DEVS]; - Status statuses[NUM_DEVS]; + absl::Status statuses[NUM_DEVS]; Notification note[NUM_DEVS]; for (int i = 0; i < NUM_DEVS; ++i) { cps[i] = new CollectiveParams(); @@ -279,7 +279,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcastForgotSender) { strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", i); prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp, nullptr /*CancellationManager*/, - [&statuses, ¬e, i](const Status& s) { + [&statuses, ¬e, i](const absl::Status& s) { statuses[i] = s; note[i].Notify(); }); @@ -327,7 +327,7 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingGroup) { cp[i] = MakeCollectiveParams(/*group_key*/ 100, /*instance_key*/ 100, /*is_source*/ i == 0); prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp[i], &cancel_mgr, - [&done, cp = cp[i]](const Status& s) { + [&done, cp = cp[i]](const absl::Status& s) { EXPECT_EQ(s.code(), absl::StatusCode::kAborted); EXPECT_EQ(s.message(), "__aborted__"); @@ -338,7 +338,7 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingGroup) { }); } start.Wait(); - prl_->StartAbort(Status(absl::StatusCode::kAborted, "__aborted__")); + prl_->StartAbort(absl::Status(absl::StatusCode::kAborted, "__aborted__")); done.Wait(); } @@ -359,7 +359,7 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) { /*is_source*/ i == 0); prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp[i], &cancel_mgr, - [&done, cp = cp[i]](const Status& s) { + [&done, cp = cp[i]](const absl::Status& s) { EXPECT_EQ(s.code(), error::OK); done.DecrementCount(); cp->Unref(); @@ -378,7 +378,7 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) { cp[i] = MakeCollectiveParams(group_key, instance_key + 1, /*is_source*/ i == 0); prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp[i], &cancel_mgr, - [&done, cp = cp[i]](const Status& s) { + [&done, cp = cp[i]](const absl::Status& s) { EXPECT_EQ(s.code(), absl::StatusCode::kAborted); EXPECT_EQ(s.message(), "__aborted__"); @@ -389,7 +389,7 @@ TEST_F(CollectiveParamResolverLocalTest, AbortPendingInstance) { }); } start.Wait(); - prl_->StartAbort(Status(absl::StatusCode::kAborted, "__aborted__")); + prl_->StartAbort(absl::Status(absl::StatusCode::kAborted, "__aborted__")); done.Wait(); } @@ -410,7 +410,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) { /*is_source*/ i == 0); prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp[i], &cancel_mgr, - [&done, cp = cp[i]](const Status& s) { + [&done, cp = cp[i]](const absl::Status& s) { EXPECT_EQ(s.code(), error::OK); done.DecrementCount(); cp->Unref(); @@ -419,7 +419,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) { } done.Wait(); } - prl_->StartAbort(Status(absl::StatusCode::kAborted, "__aborted__")); + prl_->StartAbort(absl::Status(absl::StatusCode::kAborted, "__aborted__")); auto complete_params = [this, &cancel_mgr](int group_key, int instance_key) { string device = "/job:localhost/replica:0/task:0/device:CPU:0"; @@ -428,7 +428,7 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsAfterAbortion) { /*is_source*/ true); core::ScopedUnref unref(cp); prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp, &cancel_mgr, - [&done](const Status& s) { + [&done](const absl::Status& s) { EXPECT_EQ(s.code(), absl::StatusCode::kAborted); EXPECT_EQ(s.message(), "__aborted__"); done.Notify(); @@ -461,14 +461,14 @@ TEST_F(CollectiveParamResolverLocalTest, AbortNormalCompleteParamsAsync) { [this, i, device, &num_ok, &cancel_mgr, &done] { int key = 100; while (true) { - Status status; + absl::Status status; Notification n; auto* cp = MakeCollectiveParams(/* group_key*/ key, /*instance_key*/ key, /*is_source*/ i == 0); prl_->CompleteParamsAsync(GetDeviceAttributes(device), cp, &cancel_mgr, - [&status, &n](const Status& s) { + [&status, &n](const absl::Status& s) { status = s; n.Notify(); }); @@ -490,7 +490,7 @@ TEST_F(CollectiveParamResolverLocalTest, AbortNormalCompleteParamsAsync) { // on different code points each time. int64_t delay_ms = random::New64() % 50000; Env::Default()->SleepForMicroseconds(delay_ms); - prl_->StartAbort(Status(absl::StatusCode::kAborted, "__aborted__")); + prl_->StartAbort(absl::Status(absl::StatusCode::kAborted, "__aborted__")); done.Wait(); ResetParamResolver(ConfigProto()); } diff --git a/tensorflow/core/common_runtime/collective_rma_local.cc b/tensorflow/core/common_runtime/collective_rma_local.cc index 4c968f703af615..35a56dd4048f42 100644 --- a/tensorflow/core/common_runtime/collective_rma_local.cc +++ b/tensorflow/core/common_runtime/collective_rma_local.cc @@ -19,7 +19,7 @@ limitations under the License. namespace tensorflow { -void CollectiveRemoteAccessLocal::StartAbort(const Status& s) { +void CollectiveRemoteAccessLocal::StartAbort(const absl::Status& s) { buf_rendezvous_.StartAbort(s); } @@ -39,7 +39,7 @@ void CollectiveRemoteAccessLocal::RecvFromPeer( } Device* from_device; - Status status = dev_mgr_->LookupDevice(peer_device, &from_device); + absl::Status status = dev_mgr_->LookupDevice(peer_device, &from_device); if (!status.ok()) { done(status); return; @@ -47,9 +47,9 @@ void CollectiveRemoteAccessLocal::RecvFromPeer( auto consumer_callback = [to_tensor, to_device_ctx, to_device, to_alloc_attr, dev_to_dev_stream_index, - done](const Status& status, + done](const absl::Status& status, BufRendezvous::Hook* hook) { - Status s = status; + absl::Status s = status; if (s.ok()) { if (hook == nullptr) { s = errors::Internal("Invalid null hook in ConsumeBuf callback"); @@ -73,7 +73,7 @@ void CollectiveRemoteAccessLocal::RecvFromPeer( hook->prod_value, // src Tensor* to_tensor, // dst Tensor* dev_to_dev_stream_index, - [hook, done](const Status& memcpy_status) { + [hook, done](const absl::Status& memcpy_status) { // This callback may be executing in the GPUEventMgr // pool in which case it must be very short duration // and non-blocking (except e.g. for queue insertion). diff --git a/tensorflow/core/common_runtime/collective_rma_local.h b/tensorflow/core/common_runtime/collective_rma_local.h index 5c12ed413d9594..2c51b87af40e08 100644 --- a/tensorflow/core/common_runtime/collective_rma_local.h +++ b/tensorflow/core/common_runtime/collective_rma_local.h @@ -35,7 +35,7 @@ class CollectiveRemoteAccessLocal : public CollectiveRemoteAccess { ~CollectiveRemoteAccessLocal() override = default; - void StartAbort(const Status& s) override; + void StartAbort(const absl::Status& s) override; void RecvFromPeer(const string& peer_device, const string& peer_task, bool peer_is_local, const string& key, Device* to_device, diff --git a/tensorflow/core/common_runtime/collective_rma_local_test.cc b/tensorflow/core/common_runtime/collective_rma_local_test.cc index 15ea73867e5079..ff60c2d5dcd97d 100644 --- a/tensorflow/core/common_runtime/collective_rma_local_test.cc +++ b/tensorflow/core/common_runtime/collective_rma_local_test.cc @@ -73,12 +73,12 @@ TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU0) { TF_ASSERT_OK(device_mgr_->LookupDevice(kTaskName + "/device:CPU:0", &cpu0)); Tensor sink_tensor(DT_FLOAT, TensorShape({8})); Notification recv_note; - Status recv_status; + absl::Status recv_status; rma_->RecvFromPeer(kTaskName + "/device:CPU:0", kTaskName, true /*is_local*/, "key_0", cpu0 /*to_device*/, nullptr /*to_device_ctx*/, attr /*to_alloc_attr*/, &sink_tensor, dev_locality, 0 /*stream_index*/, cm_.get(), - [&recv_note, &recv_status](const Status& s) { + [&recv_note, &recv_status](const absl::Status& s) { recv_status = s; recv_note.Notify(); }); @@ -89,11 +89,12 @@ TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU0) { // Tensors have distinct storage. EXPECT_NE(DMAHelper::base(&source_tensor), DMAHelper::base(&sink_tensor)); Notification send_note; - Status send_status; + absl::Status send_status; rma_->PostToPeer(kTaskName + "/device:CPU:0", kTaskName, "key_0", cpu0 /*from_device*/, nullptr /*from_device_ctx*/, attr /*to_alloc_attr*/, &source_tensor, dev_locality, - cm_.get(), [&send_note, &send_status](const Status& s) { + cm_.get(), + [&send_note, &send_status](const absl::Status& s) { send_status = s; send_note.Notify(); }); @@ -116,12 +117,12 @@ TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU1_2) { TF_ASSERT_OK(device_mgr_->LookupDevice(kTaskName + "/device:CPU:2", &cpu2)); Tensor sink_tensor(DT_FLOAT, TensorShape({8})); Notification recv_note; - Status recv_status; + absl::Status recv_status; rma_->RecvFromPeer(kTaskName + "/device:CPU:1", kTaskName, true /*is_local*/, "key_0", cpu2 /*to_device*/, nullptr /*to_device_ctx*/, attr /*to_alloc_attr*/, &sink_tensor, dev_locality, 0 /*stream_index*/, cm_.get(), - [&recv_note, &recv_status](const Status& s) { + [&recv_note, &recv_status](const absl::Status& s) { recv_status = s; recv_note.Notify(); }); @@ -134,11 +135,12 @@ TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU1_2) { Device* cpu1 = nullptr; TF_ASSERT_OK(device_mgr_->LookupDevice(kTaskName + "/device:CPU:1", &cpu1)); Notification send_note; - Status send_status; + absl::Status send_status; rma_->PostToPeer(kTaskName + "/device:CPU:2", kTaskName, "key_0", cpu1 /*from_device*/, nullptr /*from_device_ctx*/, attr /*to_alloc_attr*/, &source_tensor, dev_locality, - cm_.get(), [&send_note, &send_status](const Status& s) { + cm_.get(), + [&send_note, &send_status](const absl::Status& s) { send_status = s; send_note.Notify(); }); @@ -155,10 +157,10 @@ TEST_F(CollectiveRemoteAccessLocalTest, PostRecvCPU1_2) { } TEST_F(CollectiveRemoteAccessLocalTest, CheckHealth) { - Status status; + absl::Status status; Notification done; rma_->CheckPeerHealth(kTaskName, /*timeout_in_ms=*/0, - [&status, &done](const Status& s) { + [&status, &done](const absl::Status& s) { status = s; done.Notify(); }); @@ -173,12 +175,12 @@ TEST_F(CollectiveRemoteAccessLocalTest, RecvThenCancel) { TF_ASSERT_OK(device_mgr_->LookupDevice(kTaskName + "/device:CPU:0", &cpu0)); Tensor sink_tensor(DT_FLOAT, TensorShape({8})); Notification recv_note; - Status recv_status; + absl::Status recv_status; rma_->RecvFromPeer(kTaskName + "/device:CPU:0", kTaskName, true /*is_local*/, "key_0", cpu0 /*to_device*/, nullptr /*to_device_ctx*/, attr /*to_alloc_attr*/, &sink_tensor, dev_locality, 0 /*stream_index*/, cm_.get(), - [&recv_note, &recv_status](const Status& s) { + [&recv_note, &recv_status](const absl::Status& s) { recv_status = s; recv_note.Notify(); }); @@ -195,13 +197,13 @@ TEST_F(CollectiveRemoteAccessLocalTest, CancelThenRecv) { TF_ASSERT_OK(device_mgr_->LookupDevice(kTaskName + "/device:CPU:0", &cpu0)); Tensor sink_tensor(DT_FLOAT, TensorShape({8})); Notification recv_note; - Status recv_status; + absl::Status recv_status; cm_->StartCancel(); rma_->RecvFromPeer(kTaskName + "/device:CPU:0", kTaskName, true /*is_local*/, "key_0", cpu0 /*to_device*/, nullptr /*to_device_ctx*/, attr /*to_alloc_attr*/, &sink_tensor, dev_locality, 0 /*stream_index*/, cm_.get(), - [&recv_note, &recv_status](const Status& s) { + [&recv_note, &recv_status](const absl::Status& s) { recv_status = s; recv_note.Notify(); }); @@ -217,11 +219,12 @@ TEST_F(CollectiveRemoteAccessLocalTest, PostThenCancel) { TF_ASSERT_OK(device_mgr_->LookupDevice(kTaskName + "/device:CPU:0", &cpu0)); Tensor source_tensor(DT_FLOAT, TensorShape({8})); Notification send_note; - Status send_status; + absl::Status send_status; rma_->PostToPeer(kTaskName + "/device:CPU:0", kTaskName, "key_0", cpu0 /*from_device*/, nullptr /*from_device_ctx*/, attr /*to_alloc_attr*/, &source_tensor, dev_locality, - cm_.get(), [&send_note, &send_status](const Status& s) { + cm_.get(), + [&send_note, &send_status](const absl::Status& s) { send_status = s; send_note.Notify(); }); @@ -238,12 +241,13 @@ TEST_F(CollectiveRemoteAccessLocalTest, CancelThenPost) { TF_ASSERT_OK(device_mgr_->LookupDevice(kTaskName + "/device:CPU:0", &cpu0)); Tensor source_tensor(DT_FLOAT, TensorShape({8})); Notification send_note; - Status send_status; + absl::Status send_status; cm_->StartCancel(); rma_->PostToPeer(kTaskName + "/device:CPU:0", kTaskName, "key_0", cpu0 /*from_device*/, nullptr /*from_device_ctx*/, attr /*to_alloc_attr*/, &source_tensor, dev_locality, - cm_.get(), [&send_note, &send_status](const Status& s) { + cm_.get(), + [&send_note, &send_status](const absl::Status& s) { send_status = s; send_note.Notify(); }); diff --git a/tensorflow/core/common_runtime/collective_test_util.cc b/tensorflow/core/common_runtime/collective_test_util.cc index 18ef2ab824daf1..cf856ef971d45a 100644 --- a/tensorflow/core/common_runtime/collective_test_util.cc +++ b/tensorflow/core/common_runtime/collective_test_util.cc @@ -294,8 +294,9 @@ Tensor CopyTensorToHost(Device* device, const Tensor& tensor) { LOG(FATAL) << "Unsupported device_type " << device->device_type(); } -Status RunCollective(CollectiveTestEnv* test_env, CollectiveParams* col_params, - Device* device, Tensor* input, Tensor* output) { +absl::Status RunCollective(CollectiveTestEnv* test_env, + CollectiveParams* col_params, Device* device, + Tensor* input, Tensor* output) { // Copy input and allocate output if on GPU. Tensor input_buffer; Tensor output_buffer; @@ -363,9 +364,9 @@ Status RunCollective(CollectiveTestEnv* test_env, CollectiveParams* col_params, TF_RETURN_IF_ERROR(collective_impl->InitializeCollectiveContext(col_ctx)); // Run the collective. - Status status; + absl::Status status; Notification n; - collective_impl->Run([&status, &n](Status s) { + collective_impl->Run([&status, &n](absl::Status s) { status = s; n.Notify(); }); diff --git a/tensorflow/core/common_runtime/collective_test_util.h b/tensorflow/core/common_runtime/collective_test_util.h index d6704188b96b3f..492097c577f2e2 100644 --- a/tensorflow/core/common_runtime/collective_test_util.h +++ b/tensorflow/core/common_runtime/collective_test_util.h @@ -100,8 +100,9 @@ std::vector GenerateEvenSubdivOffsets(int num_devices_per_worker, int num_subdivs); // Runs a collective. input and output should be on the host. -Status RunCollective(CollectiveTestEnv* test_env, CollectiveParams* col_params, - Device* device, Tensor* input, Tensor* output); +absl::Status RunCollective(CollectiveTestEnv* test_env, + CollectiveParams* col_params, Device* device, + Tensor* input, Tensor* output); } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/collective_util.cc b/tensorflow/core/common_runtime/collective_util.cc index dc4b7e359c1d2a..249b6c5fdb0d89 100644 --- a/tensorflow/core/common_runtime/collective_util.cc +++ b/tensorflow/core/common_runtime/collective_util.cc @@ -29,9 +29,10 @@ namespace tensorflow { namespace collective_util { /*static*/ -Status InitializeDeviceAndLocality(const DeviceMgr* dev_mgr, - const string& device_name, Device** device, - DeviceLocality* device_locality) { +absl::Status InitializeDeviceAndLocality(const DeviceMgr* dev_mgr, + const string& device_name, + Device** device, + DeviceLocality* device_locality) { if (!dev_mgr) { return errors::Internal("Required non-null dev_mgr ", dev_mgr, " for InitializeDeviceAndLocality"); @@ -39,7 +40,7 @@ Status InitializeDeviceAndLocality(const DeviceMgr* dev_mgr, // In rare cases during cancellation, this lookup can lead to a SIGSEGV. The // cancellation was caused by some other error. See b/301496136 for details. - Status status = dev_mgr->LookupDevice(device_name, device); + absl::Status status = dev_mgr->LookupDevice(device_name, device); if (status.ok()) { CHECK(*device); *device_locality = (*device)->attributes().locality(); @@ -97,9 +98,9 @@ SubContext::SubContext(OpKernelContext* ctx, OpKernelContext::Params* params, sub_ctx_.reset(new OpKernelContext(&sub_params_, 1)); } -Status ComputeBinOp(OpKernelContext* op_ctx, OpKernelContext::Params* params, - Device* device, OpKernel* op, Tensor* output, - Tensor* input) { +absl::Status ComputeBinOp(OpKernelContext* op_ctx, + OpKernelContext::Params* params, Device* device, + OpKernel* op, Tensor* output, Tensor* input) { // Prepare an OpKernelContext that is identical to that of the original Op // (i.e. the collective), except for the input output sizes and identities and // the Op itself. diff --git a/tensorflow/core/common_runtime/collective_util.h b/tensorflow/core/common_runtime/collective_util.h index b53e779701afce..79cd5d50081117 100644 --- a/tensorflow/core/common_runtime/collective_util.h +++ b/tensorflow/core/common_runtime/collective_util.h @@ -27,9 +27,10 @@ limitations under the License. namespace tensorflow { namespace collective_util { -Status InitializeDeviceAndLocality(const DeviceMgr* dev_mgr, - const string& device_name, Device** device, - DeviceLocality* device_locality); +absl::Status InitializeDeviceAndLocality(const DeviceMgr* dev_mgr, + const string& device_name, + Device** device, + DeviceLocality* device_locality); string SubdivPermDebugString(const CollectiveParams& col_params); // Used for executing a sub-operation, e.g. a merge_op instance, with @@ -49,9 +50,9 @@ class SubContext { ~SubContext() = default; }; -Status ComputeBinOp(OpKernelContext* op_ctx, OpKernelContext::Params* params, - Device* device, OpKernel* op, Tensor* output, - Tensor* input); +absl::Status ComputeBinOp(OpKernelContext* op_ctx, + OpKernelContext::Params* params, Device* device, + OpKernel* op, Tensor* output, Tensor* input); } // namespace collective_util } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/colocate_predecessor_trees_pass.cc b/tensorflow/core/common_runtime/colocate_predecessor_trees_pass.cc index 0861c7d2a3d633..07a93d91040fa1 100644 --- a/tensorflow/core/common_runtime/colocate_predecessor_trees_pass.cc +++ b/tensorflow/core/common_runtime/colocate_predecessor_trees_pass.cc @@ -166,7 +166,7 @@ void LogGraphProperties(bool is_graph_changed, bool has_valid_fill_op, } } // namespace -Status ColocatePredecessorTreesPass::Run( +absl::Status ColocatePredecessorTreesPass::Run( const GraphOptimizationPassOptions& options) { if (!ShouldRunPass(options)) { return absl::OkStatus(); diff --git a/tensorflow/core/common_runtime/colocate_predecessor_trees_pass.h b/tensorflow/core/common_runtime/colocate_predecessor_trees_pass.h index efaffe4c83aedb..b1c1eea6789558 100644 --- a/tensorflow/core/common_runtime/colocate_predecessor_trees_pass.h +++ b/tensorflow/core/common_runtime/colocate_predecessor_trees_pass.h @@ -130,7 +130,7 @@ namespace tensorflow { // heuristic because it reduces number of cut edges and tends to load balance. class ColocatePredecessorTreesPass : public GraphOptimizationPass { public: - Status Run(const GraphOptimizationPassOptions& options) override; + absl::Status Run(const GraphOptimizationPassOptions& options) override; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/colocation_graph.cc b/tensorflow/core/common_runtime/colocation_graph.cc index 2edb3d7fd7fc7c..a97d03ef2f8d70 100644 --- a/tensorflow/core/common_runtime/colocation_graph.cc +++ b/tensorflow/core/common_runtime/colocation_graph.cc @@ -163,7 +163,7 @@ bool HasHostMemoryOutType(const Node& node) { } } // namespace -Status Member::SetParentAndSupportedDevices( +absl::Status Member::SetParentAndSupportedDevices( const Node& node, const std::vector& types, const DeviceNameUtils::ParsedName* local_address_spec) { int id = node.id(); @@ -176,7 +176,7 @@ Status Member::SetParentAndSupportedDevices( types, node.def(), &supported_device_types_, local_address_spec); } -Status Member::SetAssignedDeviceName(const string& device_name) { +absl::Status Member::SetAssignedDeviceName(const string& device_name) { if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) { return errors::Internal( "Setting assigned device name when there is a requested device set " @@ -191,7 +191,7 @@ Status Member::SetAssignedDeviceName(const string& device_name) { return absl::OkStatus(); } -Status Member::SetResourceDeviceName(const Node& node) { +absl::Status Member::SetResourceDeviceName(const Node& node) { if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) { return errors::Internal( "Setting resource device name when there is a requested device set " @@ -211,7 +211,7 @@ Status Member::SetResourceDeviceName(const Node& node) { return absl::OkStatus(); } -Status Member::SetRequestedDeviceName(const Node& node) { +absl::Status Member::SetRequestedDeviceName(const Node& node) { if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) { return errors::Internal( "Setting requested device name when there is an assigned device set " @@ -231,7 +231,8 @@ Status Member::SetRequestedDeviceName(const Node& node) { return absl::OkStatus(); } -Status Member::FillPossibleDevices(PossibleDevices* possible_device) const { +absl::Status Member::FillPossibleDevices( + PossibleDevices* possible_device) const { if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) { return errors::Internal( "Cannot fill PossibleDevices from a member that has non-empty assigned " @@ -265,7 +266,7 @@ bool Member::IsEdgeFromCompositeDeviceToPhysicalDevice( return false; } -Status Member::EnsureCompatibilityAcrossResourceEdge( +absl::Status Member::EnsureCompatibilityAcrossResourceEdge( const Node& src, const Member& src_root, const Node& dst, /*dst_root is this*/ bool log_device_placement) { @@ -388,8 +389,8 @@ int Member::FindRoot(const std::vector& tree, int node_id) { return FindRoot(tree, member.parent_); } -Status Member::MergeDeviceNames(const Member& other, - bool allow_soft_placement) { +absl::Status Member::MergeDeviceNames(const Member& other, + bool allow_soft_placement) { // Assuming the "requested is a specialization of assigned and resource // devices" invariant holds for this and `other`, it will hold after the // merges below. @@ -487,14 +488,15 @@ bool Member::MergeSupportedDevices( return true; } -Status Member::AssignDevice(const Node& node) { +absl::Status Member::AssignDevice(const Node& node) { if (node.assigned_device_name_index() == assigned_device_name_index_) { return absl::OkStatus(); } DeviceNameUtils::ParsedName parsed; DeviceNameUtils::ParseFullName(node.assigned_device_name(), &parsed); - Status s = DeviceNameUtils::MergeDevNames(&assigned_device_name_, parsed); + absl::Status s = + DeviceNameUtils::MergeDevNames(&assigned_device_name_, parsed); if (!s.ok()) { return errors::Internal( "Constraining by assigned device should not cause an error. Original " @@ -553,8 +555,8 @@ void Member::MaybeExcludeXlaDevices() { } } -Status Member::LimitToPossibleDevices(const PossibleDevices& devices, - bool allow_soft_placement) { +absl::Status Member::LimitToPossibleDevices(const PossibleDevices& devices, + bool allow_soft_placement) { TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames( &requested_device_name_, devices.requested_device_name, allow_soft_placement)); @@ -652,7 +654,7 @@ ColocationGraph::ColocationGraph(const Graph* graph, const FunctionStack& stack, // the largest node ID. // NOTE: If this method returns an error, *this is left in an undefined // state. -Status ColocationGraph::ColocateAllNodes() { +absl::Status ColocationGraph::ColocateAllNodes() { // This maps from a colocation group identifier to the 'root' of that // colocation group. Note that the keys in this map are StringPiece; the // actual strings are stored under the NodeDef. The lifetime of this map @@ -704,8 +706,8 @@ Status ColocationGraph::ColocateAllNodes() { return absl::OkStatus(); } -Status ColocationGraph::ColocateResourceOrRefEdge(const Node* src, - const Node* dst) { +absl::Status ColocationGraph::ColocateResourceOrRefEdge(const Node* src, + const Node* dst) { // Colocate `src` and `dst` to maintain the invariant that nodes // connected by reference edges are colocated. int src_root_id = FindAndUpdateRoot(src->id()); @@ -721,7 +723,7 @@ Status ColocationGraph::ColocateResourceOrRefEdge(const Node* src, } TF_RETURN_IF_ERROR(dst_root.EnsureCompatibilityAcrossResourceEdge( *src, src_root, *dst, log_device_placement_)); - Status status = ColocateNodes(*src, src_root_id, *dst, dst_root_id); + absl::Status status = ColocateNodes(*src, src_root_id, *dst, dst_root_id); if (!status.ok()) { return AttachDef( errors::InvalidArgument( @@ -734,7 +736,7 @@ Status ColocationGraph::ColocateResourceOrRefEdge(const Node* src, return absl::OkStatus(); } -Status ColocationGraph::ColocateResourceAndRefEdges( +absl::Status ColocationGraph::ColocateResourceAndRefEdges( std::unordered_set* inspection_required) { // If `node` has an input edge with reference type, add an edge from the // source of that edge to `node`. @@ -809,7 +811,7 @@ DataType GetElementDataType(const Node& node) { } } // namespace -Status ColocationGraph::AddHostOnlyDataTypesConstraints() { +absl::Status ColocationGraph::AddHostOnlyDataTypesConstraints() { auto is_variant = [](DataType dtype) -> bool { return dtype == DT_VARIANT; }; auto is_cpu_device = [](const std::pair& entry) -> bool { @@ -913,7 +915,7 @@ Status ColocationGraph::AddHostOnlyDataTypesConstraints() { return absl::OkStatus(); } -Status ColocationGraph::AddInspectionConstraints( +absl::Status ColocationGraph::AddInspectionConstraints( const std::unordered_set& inspection_required) { for (Node* node : inspection_required) { IOColocationGroups groups; @@ -926,7 +928,7 @@ Status ColocationGraph::AddInspectionConstraints( return absl::OkStatus(); } -Status ColocationGraph::Initialize() { +absl::Status ColocationGraph::Initialize() { TF_RETURN_IF_ERROR(InitializeMembers()); std::unordered_set inspection_required; @@ -968,8 +970,8 @@ std::vector NodeAndBoolToString(const std::vector& nodes) { // Note: // The same node can be added multiple times to the same group. // The same node can be added to multiple groups. -Status GetGroupNodes(const IOColocationGroups& groups, const Node& node, - std::vector>* group_nodes) { +absl::Status GetGroupNodes(const IOColocationGroups& groups, const Node& node, + std::vector>* group_nodes) { group_nodes->reserve(groups.group_devices.size()); for (int arg_idx = 0; arg_idx < groups.input_groups.size(); ++arg_idx) { const Node* src; @@ -1009,7 +1011,7 @@ bool IsSupportedDeviceType(const DeviceAttributes& device_attributes, } // namespace -Status ColocationGraph::ApplyIOColocationGroups( +absl::Status ColocationGraph::ApplyIOColocationGroups( const IOColocationGroups& groups, const Node& node) { if (groups.input_groups.size() != node.num_inputs()) { return errors::Internal( @@ -1068,7 +1070,7 @@ Status ColocationGraph::ApplyIOColocationGroups( return absl::OkStatus(); } -Status ColocationGraph::ColocateNodeToGroup( +absl::Status ColocationGraph::ColocateNodeToGroup( std::unordered_map* colocation_group_root, const Node* node, StringPiece colocation_group) { @@ -1080,7 +1082,7 @@ Status ColocationGraph::ColocateNodeToGroup( } else { // Try to colocate the node with the root. If there is an // error, return it. - Status s = ColocateNodes(*node, *root_node); + absl::Status s = ColocateNodes(*node, *root_node); if (!s.ok()) { if (!allow_soft_placement_) { return AttachDef(s, *node); @@ -1103,7 +1105,7 @@ Status ColocationGraph::ColocateNodeToGroup( // // NOTE: If this method returns an error, *this is left in an undefined // state. -Status ColocationGraph::ColocateNodes(const Node& x, const Node& y) { +absl::Status ColocationGraph::ColocateNodes(const Node& x, const Node& y) { int x_root = FindAndUpdateRoot(x.id()); int y_root = FindAndUpdateRoot(y.id()); return ColocateNodes(x, x_root, y, y_root); @@ -1112,8 +1114,8 @@ Status ColocationGraph::ColocateNodes(const Node& x, const Node& y) { // This overload of ColocateNodes() allows a caller to provide the root node // ids for the two nodes. For large graphs, this noticeably reduces the // graph load time. -Status ColocationGraph::ColocateNodes(const Node& x, int x_root, const Node& y, - int y_root) { +absl::Status ColocationGraph::ColocateNodes(const Node& x, int x_root, + const Node& y, int y_root) { if (x_root == y_root) { return absl::OkStatus(); } @@ -1129,8 +1131,8 @@ Status ColocationGraph::ColocateNodes(const Node& x, int x_root, const Node& y, // TODO(mrry): Consider enriching the error message by pointing // out which nodes have the explicit partial device // specifications that caused this conflict. - Status s = new_root_member->MergeDeviceNames(*old_root_member, - allow_soft_placement_); + absl::Status s = new_root_member->MergeDeviceNames(*old_root_member, + allow_soft_placement_); if (!s.ok()) { return errors::InvalidArgument( "Cannot colocate nodes ", @@ -1158,7 +1160,7 @@ Status ColocationGraph::ColocateNodes(const Node& x, int x_root, const Node& y, return absl::OkStatus(); } -Status ColocationGraph::LimitToAssignedDevice(const Node& node) { +absl::Status ColocationGraph::LimitToAssignedDevice(const Node& node) { if (node.assigned_device_name_index() < 0) { return errors::Internal( "Expected an assigned node as argument to LimitToAssignedDevice but " @@ -1221,14 +1223,14 @@ void ColocationGraph::GetSoftDeviceCandidates( } } -Status ColocationGraph::LimitToPossibleDevices(const Node& node, - const PossibleDevices& devices) { +absl::Status ColocationGraph::LimitToPossibleDevices( + const Node& node, const PossibleDevices& devices) { int root = FindAndUpdateRoot(node.id()); Member& root_member = members_[root]; return root_member.LimitToPossibleDevices(devices, allow_soft_placement_); } -Status ColocationGraph::GetDevicesForNode( +absl::Status ColocationGraph::GetDevicesForNode( Node* node, const std::vector** possible_devices) { *possible_devices = nullptr; const int node_root = FindAndUpdateRoot(node->id()); @@ -1364,9 +1366,9 @@ Status ColocationGraph::GetDevicesForNode( return absl::OkStatus(); } -Status ColocationGraph::InitializeMembers() { +absl::Status ColocationGraph::InitializeMembers() { for (Node* node : graph_.op_nodes()) { - Status status = InitializeMember(*node, &members_[node->id()]); + absl::Status status = InitializeMember(*node, &members_[node->id()]); if (!status.ok()) { return AttachDef(status, *node); } @@ -1452,7 +1454,7 @@ string ColocationGraph::DebugInfo(const int node_root) const { return text; } -Status ColocationGraph::InitializeMemberWithAssignedDevice( +absl::Status ColocationGraph::InitializeMemberWithAssignedDevice( const string& assigned_device_name, const string& node_type, Member* member) { // This node has already been assigned to a device, so we @@ -1491,7 +1493,8 @@ Status ColocationGraph::InitializeMemberWithAssignedDevice( node_type); } -Status ColocationGraph::InitializeMember(const Node& node, Member* member) { +absl::Status ColocationGraph::InitializeMember(const Node& node, + Member* member) { TF_RETURN_IF_ERROR(member->SetParentAndSupportedDevices( node, device_types_, &local_address_spec_)); diff --git a/tensorflow/core/common_runtime/colocation_graph.h b/tensorflow/core/common_runtime/colocation_graph.h index e6362b47334949..887ac205393f38 100644 --- a/tensorflow/core/common_runtime/colocation_graph.h +++ b/tensorflow/core/common_runtime/colocation_graph.h @@ -37,7 +37,7 @@ class Member { public: Member() = default; - Status SetParentAndSupportedDevices( + absl::Status SetParentAndSupportedDevices( const Node& node, const std::vector& types, const DeviceNameUtils::ParsedName* local_address_spec); @@ -45,17 +45,17 @@ class Member { return requested_device_name_; } - Status SetAssignedDeviceName(const string& device_name); - Status SetResourceDeviceName(const Node& node); - Status SetRequestedDeviceName(const Node& node); + absl::Status SetAssignedDeviceName(const string& device_name); + absl::Status SetResourceDeviceName(const Node& node); + absl::Status SetRequestedDeviceName(const Node& node); - Status FillPossibleDevices(PossibleDevices* possible_device) const; + absl::Status FillPossibleDevices(PossibleDevices* possible_device) const; // Returns whether `src_root` is assigned to a CompositeDevice and `this` is // assigned to a physical device. bool IsEdgeFromCompositeDeviceToPhysicalDevice(const Member& src_root) const; - Status EnsureCompatibilityAcrossResourceEdge( + absl::Status EnsureCompatibilityAcrossResourceEdge( const Node& src, const Member& src_root, const Node& dst, /*dst_root is this*/ bool log_device_placement); @@ -78,14 +78,14 @@ class Member { static int FindRoot(const std::vector& tree, int node_id); static int FindAndUpdateRoot(std::vector* tree, int node_id); - Status MergeDeviceNames(const Member& other, bool allow_soft_placement); + absl::Status MergeDeviceNames(const Member& other, bool allow_soft_placement); // Updates this to contain the intersection of the device types in // this and "other". If the intersection is empty, returns false and does // not update this. Else returns true and updates this. bool MergeSupportedDevices(const Member& other); - Status AssignDevice(const Node& node); + absl::Status AssignDevice(const Node& node); // If user does not explicitly request XLA device and non-XLA device is // supported for this node, use only the non-XLA device. See b/140896502. @@ -93,8 +93,8 @@ class Member { // Limit the possible devices of this (should be a root) to the device // specifications in `devices`. - Status LimitToPossibleDevices(const PossibleDevices& devices, - bool allow_soft_placement); + absl::Status LimitToPossibleDevices(const PossibleDevices& devices, + bool allow_soft_placement); void set_possible_devices(std::vector&& devices) { possible_devices_ = devices; @@ -222,21 +222,21 @@ class ColocationGraph { const Device* default_local_device, bool allow_soft_placement, bool log_device_placement); - Status Initialize(); + absl::Status Initialize(); const std::vector& members() const { return members_; } // Limit the group containing `node` to the device specifications in // `devices`. - Status LimitToPossibleDevices(const Node& node, - const PossibleDevices& devices); + absl::Status LimitToPossibleDevices(const Node& node, + const PossibleDevices& devices); // Limits the possible devices of `node`'s colocation group to the device // to which `node` is assigned. This makes sure that all nodes in this // colocation group will be assigned to the same device. Without this // explicit restriction, heuristics can choose a different possible device // for other nodes in the group. - Status LimitToAssignedDevice(const Node& node); + absl::Status LimitToAssignedDevice(const Node& node); // Returns the root node of the disjoint tree to which the node with the // given id is connected. @@ -252,8 +252,8 @@ class ColocationGraph { // Note: This method returns a pointer to a field within members_. // The caller must not use the returned pointer after there is any possibility // that the members_[i].possible_devices field has been modified. - Status GetDevicesForNode(Node* node, - const std::vector** possible_devices); + absl::Status GetDevicesForNode(Node* node, + const std::vector** possible_devices); // Returns debugging info for the node referred to by 'node_root'. string DebugInfo(const int node_root) const; @@ -276,12 +276,12 @@ class ColocationGraph { // the largest node ID. // NOTE: If this method returns an error, *this is left in an undefined // state. - Status ColocateAllNodes(); + absl::Status ColocateAllNodes(); - Status ColocateResourceOrRefEdge(const Node* src, const Node* dst); + absl::Status ColocateResourceOrRefEdge(const Node* src, const Node* dst); // Adds colocation constraints to data types known not to support copying. - Status ColocateUncopiableTypeEdges( + absl::Status ColocateUncopiableTypeEdges( std::unordered_set* inspection_required); // Updates this ColocationGraph by making sure that all nodes @@ -291,7 +291,7 @@ class ColocationGraph { // PlacerInspectionRequiredOpChecker::IsPlacerInspectionRequired // deems as requiring deep inspection by placer. This is an optimization. // TODO(mdan): Deprecate in favor of ColocateUncopiableTypeEdges. - Status ColocateResourceAndRefEdges( + absl::Status ColocateResourceAndRefEdges( std::unordered_set* inspection_required); // Updates this ColocationGraph by making sure that all nodes having inputs of @@ -300,9 +300,9 @@ class ColocationGraph { // nodes that take variant inputs to the node that produces that variant. // TODO(ezhulenev): This function does not yet support "deep op" inspection, // that we have for DT_RESOURCE edges. - Status AddHostOnlyDataTypesConstraints(); + absl::Status AddHostOnlyDataTypesConstraints(); - Status AddInspectionConstraints( + absl::Status AddInspectionConstraints( const std::unordered_set& inspection_required); // Applies colocation groups for `node`'s inputs and outputs to this @@ -329,10 +329,10 @@ class ColocationGraph { // ColocateNodes(a, c) and LimitToPossibleDevices(`a`, "GPU"). The colocation // group of the `node` itself is not directly impacted. // - Status ApplyIOColocationGroups(const IOColocationGroups& groups, - const Node& node); + absl::Status ApplyIOColocationGroups(const IOColocationGroups& groups, + const Node& node); - Status ColocateNodeToGroup( + absl::Status ColocateNodeToGroup( std::unordered_map* colocation_group_root, const Node* node, StringPiece colocation_group); @@ -342,25 +342,26 @@ class ColocationGraph { // be placed on the same device type. // // If this method returns an error, *this is unchanged. - Status ColocateNodes(const Node& x, const Node& y); + absl::Status ColocateNodes(const Node& x, const Node& y); // This overload of ColocateNodes() allows a caller to provide the root node // ids for the two nodes. For large graphs, this noticeably reduces the // graph load time. // If this method returns an error, *this is unchanged. - Status ColocateNodes(const Node& x, int x_root, const Node& y, int y_root); + absl::Status ColocateNodes(const Node& x, int x_root, const Node& y, + int y_root); void GetSoftDeviceCandidates(const Node& node, const Member& root_member, int root_id, std::vector* possible_devices); - Status InitializeMembers(); + absl::Status InitializeMembers(); - Status InitializeMemberWithAssignedDevice(const string& assigned_device_name, - const string& node_type, - Member* member); + absl::Status InitializeMemberWithAssignedDevice( + const string& assigned_device_name, const string& node_type, + Member* member); - Status InitializeMember(const Node& node, Member* member); + absl::Status InitializeMember(const Node& node, Member* member); // Returns the root node of the disjoint tree to which the node with the // given id is connected. diff --git a/tensorflow/core/common_runtime/composite_device.cc b/tensorflow/core/common_runtime/composite_device.cc index afde64faf17fdf..6b0d98e44a83fd 100644 --- a/tensorflow/core/common_runtime/composite_device.cc +++ b/tensorflow/core/common_runtime/composite_device.cc @@ -24,7 +24,7 @@ const char* const kCompositeDeviceType = "COMPOSITE"; std::unique_ptr CompositeDevice::MakeDevice( const std::vector& underlying_devices, const int unique_device_id, - const DeviceNameUtils::ParsedName& host_name, Status* status) { + const DeviceNameUtils::ParsedName& host_name, absl::Status* status) { DeviceNameUtils::ParsedName parsed_name = host_name; parsed_name.type = kCompositeDeviceType; parsed_name.id = unique_device_id; @@ -34,7 +34,7 @@ std::unique_ptr CompositeDevice::MakeDevice( std::unique_ptr CompositeDevice::MakeDevice( const std::vector& underlying_devices, const string& device_name, - Status* status) { + absl::Status* status) { if (underlying_devices.empty()) { status->Update( errors::InvalidArgument("underlying_devices should not be empty.")); diff --git a/tensorflow/core/common_runtime/composite_device.h b/tensorflow/core/common_runtime/composite_device.h index 800a24f34117d7..6e79542a280722 100644 --- a/tensorflow/core/common_runtime/composite_device.h +++ b/tensorflow/core/common_runtime/composite_device.h @@ -31,7 +31,7 @@ extern const char* const kCompositeDeviceType; // op on this virtial device. class CompositeDevice : public Device { public: - Status Sync() override { + absl::Status Sync() override { return errors::Internal( "Sync() should never been invoked on CompositeDevice."); } @@ -46,12 +46,12 @@ class CompositeDevice : public Device { // CPU. static std::unique_ptr MakeDevice( const std::vector& underlying_devices, const int unique_device_id, - const DeviceNameUtils::ParsedName& host_name, Status* status); + const DeviceNameUtils::ParsedName& host_name, absl::Status* status); // Helper for creating a CompositeDevice with the given device name. static std::unique_ptr MakeDevice( const std::vector& underlying_devices, const string& device_name, - Status* status); + absl::Status* status); bool IsRemoteCallAllowed() const override { return false; } diff --git a/tensorflow/core/common_runtime/composite_device_test.cc b/tensorflow/core/common_runtime/composite_device_test.cc index af2c7915d2c699..0db56aeed17798 100644 --- a/tensorflow/core/common_runtime/composite_device_test.cc +++ b/tensorflow/core/common_runtime/composite_device_test.cc @@ -25,7 +25,7 @@ TEST(CompositeDeviceTest, Basic) { EXPECT_TRUE(DeviceNameUtils::ParseFullName(host_name, &parsed_host_name)); std::vector underlying_devices; { - Status status; + absl::Status status; std::unique_ptr composite_device = CompositeDevice::MakeDevice(underlying_devices, /*unique_device_id=*/0, parsed_host_name, &status); @@ -37,7 +37,7 @@ TEST(CompositeDeviceTest, Basic) { } { - Status status; + absl::Status status; underlying_devices.push_back( "/job:localhost/replica:0/task:0/device:CPU:0"); underlying_devices.push_back( @@ -51,7 +51,7 @@ TEST(CompositeDeviceTest, Basic) { } { - Status status; + absl::Status status; underlying_devices.push_back( "/job:localhost/replica:0/task:0/device:GPU:0"); std::unique_ptr composite_device = @@ -71,7 +71,7 @@ TEST(CompositeDeviceTest, DeviceName) { std::vector underlying_devices; underlying_devices.push_back("/job:worker/replica:0/task:0/device:CPU:0"); underlying_devices.push_back("/job:worker/replica:0/task:0/device:CPU:1"); - Status status; + absl::Status status; std::unique_ptr composite_device = CompositeDevice::MakeDevice(underlying_devices, composite_device_name, &status); diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index cf79d58555568a..6820a5ddd696d3 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -615,10 +615,10 @@ bool ReplaceTensorWithConstant( } // namespace -Status ConstantFold(const ConstantFoldingOptions& opts, - FunctionLibraryRuntime* function_library, Env* env, - const Device* partition_device, Graph* graph, - bool* was_mutated) { +absl::Status ConstantFold(const ConstantFoldingOptions& opts, + FunctionLibraryRuntime* function_library, Env* env, + const Device* partition_device, Graph* graph, + bool* was_mutated) { // TensorFlow flushes denormals to zero and rounds to nearest, so we do // the same here. port::ScopedFlushDenormal flush; @@ -689,7 +689,7 @@ Status ConstantFold(const ConstantFoldingOptions& opts, graph_runner.reset(nullptr); }); - Status s = + absl::Status s = graph_runner->Run(constant_graph.get(), function_library, {} /* inputs*/, tensors_to_fetch_names, &outputs); if (!s.ok()) { diff --git a/tensorflow/core/common_runtime/constant_folding.h b/tensorflow/core/common_runtime/constant_folding.h index 64441f73db4b86..fd74a554c7e03e 100644 --- a/tensorflow/core/common_runtime/constant_folding.h +++ b/tensorflow/core/common_runtime/constant_folding.h @@ -60,10 +60,10 @@ struct ConstantFoldingOptions { // Sets `was_mutated` to true if and only if "graph" has been mutated. // The status is only set to a non-OK state if an unexpected error is hit // running the graph. -Status ConstantFold(const ConstantFoldingOptions& opts, - FunctionLibraryRuntime* function_library, Env* env, - const Device* partition_device, Graph* graph, - bool* was_mutated); +absl::Status ConstantFold(const ConstantFoldingOptions& opts, + FunctionLibraryRuntime* function_library, Env* env, + const Device* partition_device, Graph* graph, + bool* was_mutated); } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/constant_folding_test.cc b/tensorflow/core/common_runtime/constant_folding_test.cc index bf518b59ac7234..d4b27716a217a7 100644 --- a/tensorflow/core/common_runtime/constant_folding_test.cc +++ b/tensorflow/core/common_runtime/constant_folding_test.cc @@ -99,7 +99,9 @@ class FakeDevice : public Device { : Device(nullptr, device_attributes) {} public: - Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); } + absl::Status Sync() override { + return errors::Unimplemented("FakeDevice::Sync()"); + } Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; } @@ -137,7 +139,8 @@ TEST_F(ConstantFoldingTest, Basic) { // Tests that different node creation ordering creates same graph after constant // folding. TEST_F(ConstantFoldingTest, DeterministicFolding) { - auto build_graph_and_constant_folding = [](Graph& g, bool swap) -> Status { + auto build_graph_and_constant_folding = [](Graph& g, + bool swap) -> absl::Status { Scope s = Scope::NewRootScope(); auto a = ops::Const(s, {1.0}, {}); auto b = ops::Const(s, {2.0}, {}); @@ -351,7 +354,7 @@ TEST_F(ConstantFoldingTest, TestNoReplaceFunctionCall) { NodeDefBuilder("times_two", "XTimesTwo", s.graph()->op_registry()) .Input(c.name(), 0, DT_INT32) .Finalize(&def)); - Status status; + absl::Status status; Node* times_two = s.graph()->AddNode(def, &status); TF_ASSERT_OK(status); TF_ASSERT_OK(s.DoShapeInference(times_two)); @@ -385,7 +388,7 @@ TEST_F(ConstantFoldingTest, TestNoReplaceNonCPUOp) { TF_ASSERT_OK(NodeDefBuilder("testop", "ConstantFoldingTestOp") .Input(aconst.name(), 0, DT_INT64) .Finalize(&def)); - Status status; + absl::Status status; Node* non_cpu = s.graph()->AddNode(def, &status); TF_ASSERT_OK(status); TF_ASSERT_OK(s.DoShapeInference(non_cpu)); @@ -681,7 +684,7 @@ class TestTFFileSystem : public ::tensorflow::NullFileSystem { using ::tensorflow::NullFileSystem::NewReadOnlyMemoryRegionFromFile; - ::tensorflow::Status NewReadOnlyMemoryRegionFromFile( + absl::Status NewReadOnlyMemoryRegionFromFile( const string& fname, ::tensorflow::TransactionToken* token, std::unique_ptr<::tensorflow::ReadOnlyMemoryRegion>* result) override { if (fname != kTestMemRegionName) { @@ -703,7 +706,7 @@ class TestTFEnvironment : public ::tensorflow::EnvWrapper { public: using tf_base = ::tensorflow::EnvWrapper; TestTFEnvironment() : ::tensorflow::EnvWrapper(Default()) {} - ::tensorflow::Status GetFileSystemForFile( + absl::Status GetFileSystemForFile( const string& fname, ::tensorflow::FileSystem** result) override { was_used_ = true; if (fname == "test://test") { @@ -732,8 +735,8 @@ TEST_F(ConstantFoldingTest, TestImmutableConst) { TF_ASSERT_OK(root.ToGraph(&g)); TestTFEnvironment test_env; bool was_mutated; - Status status = ConstantFold(ConstantFoldingOptions{}, nullptr, - Env::Default(), nullptr, &g, &was_mutated); + absl::Status status = ConstantFold(ConstantFoldingOptions{}, nullptr, + Env::Default(), nullptr, &g, &was_mutated); EXPECT_FALSE(was_mutated); EXPECT_FALSE(status.ok()); TF_EXPECT_OK(ConstantFold(ConstantFoldingOptions{}, nullptr, &test_env, diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc index 675ffc624c68e4..dadaaf0cd61f2d 100644 --- a/tensorflow/core/common_runtime/copy_tensor.cc +++ b/tensorflow/core/common_runtime/copy_tensor.cc @@ -62,7 +62,7 @@ void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator, auto* status_cb = new ReffedStatusCallback(std::move(done)); core::ScopedUnref status_cb_unref(status_cb); - auto wrapped_done = [status_cb](const Status& s) { + auto wrapped_done = [status_cb](const absl::Status& s) { status_cb->UpdateStatus(s); status_cb->Unref(); }; @@ -77,7 +77,7 @@ void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator, return absl::OkStatus(); } else { if (!DMAHelper::CanUseDMA(&from)) { - Status err = errors::InvalidArgument( + absl::Status err = errors::InvalidArgument( "During Variant Host->Device Copy: " "non-DMA-copy attempted of tensor type: ", DataTypeString(from.dtype())); @@ -98,7 +98,7 @@ void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator, const Variant* v = input->flat().data(); Variant* v_out = copy.flat().data(); - Status s_copy_init; + absl::Status s_copy_init; for (int64_t i = 0; i < input->NumElements(); ++i) { s_copy_init = VariantDeviceCopy( VariantDeviceCopyDirection::HOST_TO_DEVICE, v[i], &v_out[i], copier); @@ -132,7 +132,7 @@ void CopyDeviceToDevice(CopyTensor::CopyFunction copy_function, auto* status_cb = new ReffedStatusCallback(std::move(done)); core::ScopedUnref status_cb_unref(status_cb); - auto wrapped_done = [status_cb](const Status& s) { + auto wrapped_done = [status_cb](const absl::Status& s) { status_cb->UpdateStatus(s); status_cb->Unref(); }; @@ -151,7 +151,7 @@ void CopyDeviceToDevice(CopyTensor::CopyFunction copy_function, return absl::OkStatus(); } else { if (!DMAHelper::CanUseDMA(&from)) { - Status err = errors::InvalidArgument( + absl::Status err = errors::InvalidArgument( "During Variant Device->Device Copy: ", src->name(), " to ", dst->name(), " non-DMA-copy attempted of tensor type: ", DataTypeString(from.dtype())); @@ -173,7 +173,7 @@ void CopyDeviceToDevice(CopyTensor::CopyFunction copy_function, const Variant* v = input->flat().data(); Variant* v_out = copy.flat().data(); - Status s_copy_init; + absl::Status s_copy_init; for (int64_t i = 0; i < input->NumElements(); ++i) { s_copy_init = VariantDeviceCopy(VariantDeviceCopyDirection::DEVICE_TO_DEVICE, v[i], @@ -254,15 +254,15 @@ void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context, Tensor* cpu_tensor = new Tensor(cpu_allocator, input->dtype(), input->shape()); - auto delete_and_done = [cpu_tensor, - done = std::move(done)](const Status& status) { - delete cpu_tensor; - done(status); - }; + auto delete_and_done = + [cpu_tensor, done = std::move(done)](const absl::Status& status) { + delete cpu_tensor; + done(status); + }; auto then_copy_to_other_device = [delete_and_done = std::move(delete_and_done), recv_dev_context, cpu_tensor, cpu_allocator, out_allocator, edge_name, dst, output, - sync_dst_compute](Status status) { + sync_dst_compute](absl::Status status) { if (!status.ok()) { delete_and_done(status); return; @@ -301,10 +301,10 @@ void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context, } // static -Status CopyTensor::Register(DeviceType sender_device_type, - DeviceType receiver_device_type, - CopyFunction copy_function, - bool is_pluggable_device) { +absl::Status CopyTensor::Register(DeviceType sender_device_type, + DeviceType receiver_device_type, + CopyFunction copy_function, + bool is_pluggable_device) { std::vector* registry = MutableRegistry(); registry->emplace_back(sender_device_type, receiver_device_type, copy_function, is_pluggable_device); @@ -315,7 +315,7 @@ namespace { // The following registrations enable a DT_VARIANT tensor element that contains // a wrapped `tensorflow::Tensor` to be copied between devices. -static Status WrappedTensorDeviceCopy( +static absl::Status WrappedTensorDeviceCopy( const Tensor& from, Tensor* to, const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) { if (DMAHelper::CanUseDMA(&from)) { @@ -346,7 +346,7 @@ void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator, auto* status_cb = new ReffedStatusCallback(std::move(done)); core::ScopedUnref status_cb_unref(status_cb); - auto wrapped_done = [status_cb](const Status& s) { + auto wrapped_done = [status_cb](const absl::Status& s) { status_cb->UpdateStatus(s); status_cb->Unref(); }; @@ -360,7 +360,7 @@ void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator, return absl::OkStatus(); } else { if (!DMAHelper::CanUseDMA(&from)) { - Status err = errors::InvalidArgument( + absl::Status err = errors::InvalidArgument( "During Variant Device->Host Copy: " "non-DMA-copy attempted of tensor type: ", DataTypeString(from.dtype())); @@ -381,7 +381,7 @@ void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator, const Variant* v = input->flat().data(); Variant* v_out = copy.flat().data(); - Status s_copy_init; + absl::Status s_copy_init; for (int64_t i = 0; i < input->NumElements(); ++i) { s_copy_init = VariantDeviceCopy( VariantDeviceCopyDirection::DEVICE_TO_HOST, v[i], &v_out[i], copier); diff --git a/tensorflow/core/common_runtime/copy_tensor.h b/tensorflow/core/common_runtime/copy_tensor.h index 0bbbee2421b284..80187bde94b4b6 100644 --- a/tensorflow/core/common_runtime/copy_tensor.h +++ b/tensorflow/core/common_runtime/copy_tensor.h @@ -63,9 +63,10 @@ class CopyTensor { // Register a function for copying between two specific DeviceTypes. // Note: This should only be called via the constructor of // CopyTensor::Registration or from PluggableDevice implementation. - static Status Register(DeviceType sender_device_type, - DeviceType receiver_device_type, - CopyFunction copy_function, bool is_pluggable_device); + static absl::Status Register(DeviceType sender_device_type, + DeviceType receiver_device_type, + CopyFunction copy_function, + bool is_pluggable_device); }; void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator, diff --git a/tensorflow/core/common_runtime/costmodel_manager.cc b/tensorflow/core/common_runtime/costmodel_manager.cc index 36ef7d08933b84..7de63bb95dbf7b 100644 --- a/tensorflow/core/common_runtime/costmodel_manager.cc +++ b/tensorflow/core/common_runtime/costmodel_manager.cc @@ -53,8 +53,8 @@ bool CostModelManager::RemoveCostModelForGraph(const Graph* graph) { return true; } -Status CostModelManager::AddToCostGraphDef(const Graph* graph, - CostGraphDef* cost_graph) { +absl::Status CostModelManager::AddToCostGraphDef(const Graph* graph, + CostGraphDef* cost_graph) { mutex_lock l(mu_); // Get the cost model for the graph. auto it = cost_models_.find(graph); diff --git a/tensorflow/core/common_runtime/costmodel_manager.h b/tensorflow/core/common_runtime/costmodel_manager.h index 770d44287269da..8ea8a137034ab9 100644 --- a/tensorflow/core/common_runtime/costmodel_manager.h +++ b/tensorflow/core/common_runtime/costmodel_manager.h @@ -43,7 +43,7 @@ class CostModelManager { bool RemoveCostModelForGraph(const Graph* graph); - Status AddToCostGraphDef(const Graph* graph, CostGraphDef* cost_graph); + absl::Status AddToCostGraphDef(const Graph* graph, CostGraphDef* cost_graph); private: mutex mu_; diff --git a/tensorflow/core/common_runtime/debugger_state_interface.cc b/tensorflow/core/common_runtime/debugger_state_interface.cc index a9626069a79926..7728c30c46f1d9 100644 --- a/tensorflow/core/common_runtime/debugger_state_interface.cc +++ b/tensorflow/core/common_runtime/debugger_state_interface.cc @@ -60,7 +60,7 @@ void DebuggerStateRegistry::RegisterFactory( } // static -Status DebuggerStateRegistry::CreateState( +absl::Status DebuggerStateRegistry::CreateState( const DebugOptions& debug_options, std::unique_ptr* state) { if (factory_ == nullptr || *factory_ == nullptr) { @@ -81,7 +81,7 @@ void DebugGraphDecoratorRegistry::RegisterFactory( } // static -Status DebugGraphDecoratorRegistry::CreateDecorator( +absl::Status DebugGraphDecoratorRegistry::CreateDecorator( const DebugOptions& options, std::unique_ptr* decorator) { if (factory_ == nullptr || *factory_ == nullptr) { diff --git a/tensorflow/core/common_runtime/debugger_state_interface.h b/tensorflow/core/common_runtime/debugger_state_interface.h index e4815cca4f3c59..1b9f190e18bedf 100644 --- a/tensorflow/core/common_runtime/debugger_state_interface.h +++ b/tensorflow/core/common_runtime/debugger_state_interface.h @@ -48,7 +48,7 @@ class DebuggerStateInterface { // input_names: Name of the input Tensors (feed keys). // output_names: Names of the fetched Tensors. // target_names: Names of the target nodes. - virtual Status PublishDebugMetadata( + virtual absl::Status PublishDebugMetadata( const int64_t global_step, const int64_t session_run_index, const int64_t executor_step_index, const std::vector& input_names, const std::vector& output_names, @@ -62,11 +62,11 @@ class DebugGraphDecoratorInterface { // Insert special-purpose debug nodes to graph and dump the graph for // record. See the documentation of DebugNodeInserter::InsertNodes() for // details. - virtual Status DecorateGraph(Graph* graph, Device* device) = 0; + virtual absl::Status DecorateGraph(Graph* graph, Device* device) = 0; // Publish Graph to debug URLs. - virtual Status PublishGraph(const Graph& graph, - const string& device_name) = 0; + virtual absl::Status PublishGraph(const Graph& graph, + const string& device_name) = 0; }; typedef std::function( @@ -88,8 +88,9 @@ class DebuggerStateRegistry { // DebuggerStateInterface implementation using the registered factory, // owned by the caller and return an OK Status. Otherwise returns an error // Status. - static Status CreateState(const DebugOptions& debug_options, - std::unique_ptr* state); + static absl::Status CreateState( + const DebugOptions& debug_options, + std::unique_ptr* state); private: static DebuggerStateFactory* factory_; @@ -106,7 +107,7 @@ class DebugGraphDecoratorRegistry { public: static void RegisterFactory(const DebugGraphDecoratorFactory& factory); - static Status CreateDecorator( + static absl::Status CreateDecorator( const DebugOptions& options, std::unique_ptr* decorator); diff --git a/tensorflow/core/common_runtime/device/device_utils.cc b/tensorflow/core/common_runtime/device/device_utils.cc index e95f95cb8dfa8e..60ec4cd0082a67 100644 --- a/tensorflow/core/common_runtime/device/device_utils.cc +++ b/tensorflow/core/common_runtime/device/device_utils.cc @@ -22,13 +22,14 @@ limitations under the License. namespace tensorflow { namespace device_utils { -Status ValidateDeviceType(StringPiece type) { +absl::Status ValidateDeviceType(StringPiece type) { static const LazyRE2 kTfDeviceTypeRegEx = {"[A-Z][A-Z_]*"}; bool matches = RE2::FullMatch(type, *kTfDeviceTypeRegEx); if (!matches) { - return Status(absl::StatusCode::kFailedPrecondition, - strings::StrCat("Device name/type '", type, "' must match ", - kTfDeviceTypeRegEx->pattern(), ".")); + return absl::Status( + absl::StatusCode::kFailedPrecondition, + strings::StrCat("Device name/type '", type, "' must match ", + kTfDeviceTypeRegEx->pattern(), ".")); } return absl::OkStatus(); } diff --git a/tensorflow/core/common_runtime/device/device_utils.h b/tensorflow/core/common_runtime/device/device_utils.h index 91c93a64e46b66..05c52e0aa92081 100644 --- a/tensorflow/core/common_runtime/device/device_utils.h +++ b/tensorflow/core/common_runtime/device/device_utils.h @@ -33,7 +33,7 @@ namespace device_utils { // Note that lowercase "cpu" and "gpu" are currently supported only for // legacy reasons: // https://cs.opensource.google/tensorflow/tensorflow/+/master:tensorflow/python/framework/device_spec.py;l=46;drc=d3a378f9665d8eee827c74cb9ecbee81e4c288dd -Status ValidateDeviceType(StringPiece type); +absl::Status ValidateDeviceType(StringPiece type); } // namespace device_utils } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/device_mgr.h b/tensorflow/core/common_runtime/device_mgr.h index 82688a2311a345..87fe86835419c5 100644 --- a/tensorflow/core/common_runtime/device_mgr.h +++ b/tensorflow/core/common_runtime/device_mgr.h @@ -56,7 +56,8 @@ class DeviceMgr { // Assigns *device with pointer to Device of the given name. // Accepts either a full device name, or just the replica-local suffix. - virtual Status LookupDevice(StringPiece name, Device** device) const = 0; + virtual absl::Status LookupDevice(StringPiece name, + Device** device) const = 0; // Check if the current device manager contains device with the given // incarnation ID. Looking up by incarnation IDs because they are randomly @@ -100,7 +101,7 @@ class DynamicDeviceMgr : public DeviceMgr { std::vector ListDevices() const override; string DebugString() const override; string DeviceMappingString() const override; - Status LookupDevice(StringPiece name, Device** device) const override; + absl::Status LookupDevice(StringPiece name, Device** device) const override; bool ContainsDevice(int64_t device_incarnation) const override; void ClearContainers(absl::Span containers) const override; int NumDeviceType(const string& type) const override; @@ -108,17 +109,17 @@ class DynamicDeviceMgr : public DeviceMgr { Device* HostCPU() const override; // Add devices to device manager. Returns error for repeated device names. - Status AddDevices(std::vector> devices); + absl::Status AddDevices(std::vector> devices); // Remove devices from device manager. // Returns error for non-existing devices or if the HostCPU() device is in the // input list. If an error is returned, the device list is not modified. - Status RemoveDevices(const std::vector& devices); + absl::Status RemoveDevices(const std::vector& devices); // Remove devices from device manager by their names. Returns error for // non-existing devices or if the HostCPU() device is given in the input list. // If an error is returned, the device list is not modified. - Status RemoveDevicesByName(const std::vector& device_names); + absl::Status RemoveDevicesByName(const std::vector& device_names); private: mutable mutex devices_mu_; diff --git a/tensorflow/core/common_runtime/device_mgr_test.cc b/tensorflow/core/common_runtime/device_mgr_test.cc index 6cf9be8959599d..4944d317406faa 100644 --- a/tensorflow/core/common_runtime/device_mgr_test.cc +++ b/tensorflow/core/common_runtime/device_mgr_test.cc @@ -30,7 +30,7 @@ static Device* CreateDevice(const char* type, const char* name) { class FakeDevice : public Device { public: explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} - Status Sync() override { return absl::OkStatus(); } + absl::Status Sync() override { return absl::OkStatus(); } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } }; DeviceAttributes attr; diff --git a/tensorflow/core/common_runtime/device_resolver_local.cc b/tensorflow/core/common_runtime/device_resolver_local.cc index 1fca35b15ef012..5f7272dc57747e 100644 --- a/tensorflow/core/common_runtime/device_resolver_local.cc +++ b/tensorflow/core/common_runtime/device_resolver_local.cc @@ -20,11 +20,11 @@ limitations under the License. namespace tensorflow { -Status DeviceResolverLocal::GetDeviceAttributes(const string& device, - DeviceAttributes* attributes) { +absl::Status DeviceResolverLocal::GetDeviceAttributes( + const string& device, DeviceAttributes* attributes) { Device* dev; // LookupDevice returns InvalidArgument if the device is not found. - Status s = dev_mgr_->LookupDevice(device, &dev); + absl::Status s = dev_mgr_->LookupDevice(device, &dev); if (absl::IsInvalidArgument(s)) { return errors::NotFound(device, " not found"); } else if (!s.ok()) { @@ -34,13 +34,13 @@ Status DeviceResolverLocal::GetDeviceAttributes(const string& device, return absl::OkStatus(); } -Status DeviceResolverLocal::GetAllDeviceAttributes( +absl::Status DeviceResolverLocal::GetAllDeviceAttributes( const string& task, std::vector* attributes) { return errors::Internal( "GetTaskCached is not supposed to be called in local collectives"); } -Status DeviceResolverLocal::UpdateDeviceAttributes( +absl::Status DeviceResolverLocal::UpdateDeviceAttributes( const std::vector& attributes) { return errors::Internal( "UpdateDeviceAttributes shouldn't be called with local collectives"); diff --git a/tensorflow/core/common_runtime/device_resolver_local.h b/tensorflow/core/common_runtime/device_resolver_local.h index adb859abc1fd9e..814bea88a9a77f 100644 --- a/tensorflow/core/common_runtime/device_resolver_local.h +++ b/tensorflow/core/common_runtime/device_resolver_local.h @@ -30,13 +30,13 @@ class DeviceResolverLocal : public DeviceResolverInterface { public: explicit DeviceResolverLocal(const DeviceMgr* dev_mgr) : dev_mgr_(dev_mgr) {} - Status GetDeviceAttributes(const string& device, - DeviceAttributes* attributes) override; + absl::Status GetDeviceAttributes(const string& device, + DeviceAttributes* attributes) override; - Status GetAllDeviceAttributes( + absl::Status GetAllDeviceAttributes( const string& task, std::vector* attributes) override; - Status UpdateDeviceAttributes( + absl::Status UpdateDeviceAttributes( const std::vector& attributes) override; protected: diff --git a/tensorflow/core/common_runtime/device_set_test.cc b/tensorflow/core/common_runtime/device_set_test.cc index 20141e80f2a458..9c8f1ac675f01f 100644 --- a/tensorflow/core/common_runtime/device_set_test.cc +++ b/tensorflow/core/common_runtime/device_set_test.cc @@ -29,7 +29,7 @@ static Device* Dev(const char* type, const char* name) { class FakeDevice : public Device { public: explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} - Status Sync() override { return absl::OkStatus(); } + absl::Status Sync() override { return absl::OkStatus(); } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } }; DeviceAttributes attr; @@ -60,11 +60,12 @@ class DeviceSetTest : public ::testing::Test { class DummyFactory : public DeviceFactory { public: - Status ListPhysicalDevices(std::vector* devices) override { + absl::Status ListPhysicalDevices(std::vector* devices) override { return absl::OkStatus(); } - Status CreateDevices(const SessionOptions& options, const string& name_prefix, - std::vector>* devices) override { + absl::Status CreateDevices( + const SessionOptions& options, const string& name_prefix, + std::vector>* devices) override { return absl::OkStatus(); } }; diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index 1cd597c0c21ebb..0acad53f12144e 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -517,7 +517,7 @@ Status DirectSession::RunInternal( RunState run_state(step_id, &devices_); const size_t num_executors = executors_and_keys->items.size(); - profiler::TraceMeProducer activity( + tsl::profiler::TraceMeProducer activity( // To TraceMeConsumers in ExecutorState::Process/Finish. [&] { if (options_.config.experimental().has_session_metadata()) { @@ -525,17 +525,17 @@ Status DirectSession::RunInternal( options_.config.experimental().session_metadata(); string model_id = strings::StrCat(model_metadata.name(), ":", model_metadata.version()); - return profiler::TraceMeEncode("SessionRun", - {{"id", step_id}, - {"_r", 1} /*root_event*/, - {"model_id", model_id}}); + return tsl::profiler::TraceMeEncode("SessionRun", + {{"id", step_id}, + {"_r", 1} /*root_event*/, + {"model_id", model_id}}); } else { - return profiler::TraceMeEncode( + return tsl::profiler::TraceMeEncode( "SessionRun", {{"id", step_id}, {"_r", 1} /*root_event*/}); } }, tsl::profiler::ContextType::kTfExecutor, step_id, - profiler::TraceMeLevel::kInfo); + tsl::profiler::TraceMeLevel::kInfo); std::unique_ptr debugger_state; if (!run_options.debug_options().debug_tensor_watch_opts().empty()) { @@ -889,7 +889,7 @@ Status DirectSession::Run(const RunOptions& run_options, // fetch values to and from the executors. FunctionCallFrame call_frame(executors_and_keys->input_types, executors_and_keys->output_types); - gtl::InlinedVector feed_args(inputs.size()); + absl::InlinedVector feed_args(inputs.size()); for (const auto& it : inputs) { if (it.second.dtype() == DT_RESOURCE) { Tensor tensor_from_handle; @@ -1488,9 +1488,9 @@ Status DirectSession::CreateExecutors( } Status DirectSession::GetOrCreateExecutors( - gtl::ArraySlice inputs, gtl::ArraySlice outputs, - gtl::ArraySlice target_nodes, ExecutorsAndKeys** executors_and_keys, - RunStateArgs* run_state_args) { + absl::Span inputs, absl::Span outputs, + absl::Span target_nodes, + ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args) { int64_t handle_name_counter_value = -1; if (LogMemory::IsEnabled() || run_state_args->is_partial_run) { handle_name_counter_value = handle_name_counter_.fetch_add(1); diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index a755befa5bf0d8..a0ee4c32471d81 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -67,25 +67,25 @@ class DirectSession : public Session { typedef std::vector> NamedTensorList; typedef std::unordered_map NameNodeMap; - ::tensorflow::Status Create(const GraphDef& graph) override; - ::tensorflow::Status Create(GraphDef&& graph) override; - ::tensorflow::Status Extend(const GraphDef& graph) override; - ::tensorflow::Status Extend(GraphDef&& graph) override; - ::tensorflow::Status Run(const NamedTensorList& inputs, - const std::vector& output_names, - const std::vector& target_nodes, - std::vector* outputs) override; + absl::Status Create(const GraphDef& graph) override; + absl::Status Create(GraphDef&& graph) override; + absl::Status Extend(const GraphDef& graph) override; + absl::Status Extend(GraphDef&& graph) override; + absl::Status Run(const NamedTensorList& inputs, + const std::vector& output_names, + const std::vector& target_nodes, + std::vector* outputs) override; // NOTE: Experimental and subject to change. - ::tensorflow::Status Run(const ::tensorflow::RunOptions& run_options, - const NamedTensorList& inputs, - const std::vector& output_names, - const std::vector& target_nodes, - std::vector* outputs, - RunMetadata* run_metadata) override; + absl::Status Run(const ::tensorflow::RunOptions& run_options, + const NamedTensorList& inputs, + const std::vector& output_names, + const std::vector& target_nodes, + std::vector* outputs, + RunMetadata* run_metadata) override; // NOTE: Experimental and subject to change. - ::tensorflow::Status Run( + absl::Status Run( const ::tensorflow::RunOptions& run_options, const NamedTensorList& inputs, const std::vector& output_names, const std::vector& target_nodes, std::vector* outputs, @@ -94,22 +94,21 @@ class DirectSession : public Session { // NOTE: PRunSetup and PRun are added to support partial execution. This // feature is experimental and subject to change. - ::tensorflow::Status PRunSetup(const std::vector& input_names, - const std::vector& output_names, - const std::vector& target_nodes, - string* handle) override; - ::tensorflow::Status PRun(const string& handle, const NamedTensorList& inputs, - const std::vector& output_names, - std::vector* outputs) override; + absl::Status PRunSetup(const std::vector& input_names, + const std::vector& output_names, + const std::vector& target_nodes, + string* handle) override; + absl::Status PRun(const string& handle, const NamedTensorList& inputs, + const std::vector& output_names, + std::vector* outputs) override; // Reset clears 'containers' from the device_mgr of the DirectSession. // If 'containers' is empty, then Reset clears the default container. - ::tensorflow::Status Reset(const std::vector& containers); + absl::Status Reset(const std::vector& containers); - ::tensorflow::Status ListDevices( - std::vector* response) override; - ::tensorflow::Status Close() override; - ::tensorflow::Status LocalDeviceManager(const DeviceMgr** output) override { + absl::Status ListDevices(std::vector* response) override; + absl::Status Close() override; + absl::Status LocalDeviceManager(const DeviceMgr** output) override { *output = device_mgr_.get(); return absl::OkStatus(); } @@ -118,22 +117,22 @@ class DirectSession : public Session { cost_model_manager_.ExportCostModels(cost_models); } - ::tensorflow::Status MakeCallable(const CallableOptions& callable_options, - CallableHandle* out_handle) override; + absl::Status MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle) override; - ::tensorflow::Status RunCallable(CallableHandle handle, - const std::vector& feed_tensors, - std::vector* fetch_tensors, - RunMetadata* run_metadata) override; + absl::Status RunCallable(CallableHandle handle, + const std::vector& feed_tensors, + std::vector* fetch_tensors, + RunMetadata* run_metadata) override; - ::tensorflow::Status RunCallable( + absl::Status RunCallable( CallableHandle handle, const std::vector& feed_tensors, std::vector* fetch_tensors, RunMetadata* run_metadata, const thread::ThreadPoolOptions& threadpool_options) override; - ::tensorflow::Status ReleaseCallable(CallableHandle handle) override; + absl::Status ReleaseCallable(CallableHandle handle) override; - ::tensorflow::Status Finalize() override; + absl::Status Finalize() override; const SessionOptions& options() const { return options_; } @@ -198,7 +197,7 @@ class DirectSession : public Session { // 'status' is the current status of the execution. struct RunState { mutex mu; - Status status TF_GUARDED_BY(mu); + absl::Status status TF_GUARDED_BY(mu); std::unique_ptr collective_executor; std::unique_ptr collector; TensorStore tensor_store; @@ -240,14 +239,15 @@ class DirectSession : public Session { // Retrieves an already existing set of executors to run 'inputs' and // 'outputs', or creates and caches them for future use. - ::tensorflow::Status GetOrCreateExecutors( - absl::Span inputs, absl::Span outputs, - absl::Span target_nodes, - ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args); + absl::Status GetOrCreateExecutors(absl::Span inputs, + absl::Span outputs, + absl::Span target_nodes, + ExecutorsAndKeys** executors_and_keys, + RunStateArgs* run_state_args); // Creates a set of executors to run the subgraph defined by // `callable_options`. - ::tensorflow::Status CreateExecutors( + absl::Status CreateExecutors( const CallableOptions& callable_options, std::unique_ptr* out_executors_and_keys, std::unique_ptr* out_func_info, @@ -256,67 +256,65 @@ class DirectSession : public Session { // Creates several graphs given the existing graph_def_ and the // input feeds and fetches, given 'devices'. The graphs share a common // function library 'flib_def'. - ::tensorflow::Status CreateGraphs( + absl::Status CreateGraphs( const BuildGraphOptions& options, std::unordered_map>* outputs, std::unique_ptr* flib_def, RunStateArgs* run_state_args, DataTypeVector* input_types, DataTypeVector* output_types, int64_t* collective_graph_key); - ::tensorflow::Status RunInternal( - int64_t step_id, const RunOptions& run_options, - CallFrameInterface* call_frame, ExecutorsAndKeys* executors_and_keys, - RunMetadata* run_metadata, - const thread::ThreadPoolOptions& threadpool_options); + absl::Status RunInternal(int64_t step_id, const RunOptions& run_options, + CallFrameInterface* call_frame, + ExecutorsAndKeys* executors_and_keys, + RunMetadata* run_metadata, + const thread::ThreadPoolOptions& threadpool_options); // Returns whether inter-op execution uses a global pool or the input // `run_options` requests being run on inter_op_thread_pool = 0 in case // multiple pools are configured. bool ShouldUseRunHandlerPool(const RunOptions& run_options) const; - ::tensorflow::Status ExtendLocked(GraphDef&& graph) + absl::Status ExtendLocked(GraphDef&& graph) TF_EXCLUSIVE_LOCKS_REQUIRED(graph_state_lock_); - ::tensorflow::Status ResourceHandleToInputTensor( - const Tensor& resource_tensor, Tensor* retrieved_tensor); + absl::Status ResourceHandleToInputTensor(const Tensor& resource_tensor, + Tensor* retrieved_tensor); // Feeds more inputs to the executors, triggering further execution. - ::tensorflow::Status SendPRunInputs( + absl::Status SendPRunInputs( const std::vector>& inputs, const ExecutorsAndKeys* executors_and_keys, IntraProcessRendezvous* rendez); // Fetches more outputs from the executors. It waits until the output // tensors are computed. - ::tensorflow::Status RecvPRunOutputs( - const std::vector& output_names, - const ExecutorsAndKeys* executors_and_keys, PartialRunState* run_state, - std::vector* outputs); + absl::Status RecvPRunOutputs(const std::vector& output_names, + const ExecutorsAndKeys* executors_and_keys, + PartialRunState* run_state, + std::vector* outputs); // Check if the specified fetches can be computed from the feeds // that we have already provided. - ::tensorflow::Status CheckFetch( - const std::vector>& feeds, - const std::vector& fetches, - const ExecutorsAndKeys* executors_and_keys, - const PartialRunState* run_state); + absl::Status CheckFetch(const std::vector>& feeds, + const std::vector& fetches, + const ExecutorsAndKeys* executors_and_keys, + const PartialRunState* run_state); // Use the appropriate WaitForNotification function based on whether // operation_timeout_in_ms is greater than 0. // // If the timeout expires, the `cm->StartCancel()` will be called. - ::tensorflow::Status WaitForNotification(Notification* n, - int64_t timeout_in_ms); + absl::Status WaitForNotification(Notification* n, int64_t timeout_in_ms); void WaitForNotification(Notification* n, RunState* run_state, CancellationManager* cm, int64_t timeout_in_ms); - ::tensorflow::Status CheckNotClosed() { + absl::Status CheckNotClosed() { mutex_lock l(closed_lock_); if (closed_) return errors::Cancelled("Session has been closed."); return absl::OkStatus(); } - ::tensorflow::Status CheckGraphCreated(const char* method) { + absl::Status CheckGraphCreated(const char* method) { mutex_lock l(graph_state_lock_); if (!graph_created_) { return errors::InvalidArgument( @@ -325,12 +323,12 @@ class DirectSession : public Session { return absl::OkStatus(); } - ::tensorflow::Status CreateDebuggerState( + absl::Status CreateDebuggerState( const CallableOptions& options, int64_t global_step, int64_t session_run_index, int64_t executor_step_index, std::unique_ptr* debugger_state); - ::tensorflow::Status DecorateAndPublishGraphForDebug( + absl::Status DecorateAndPublishGraphForDebug( const DebugOptions& debug_options, Graph* graph, Device* device); const SessionOptions options_; @@ -350,7 +348,7 @@ class DirectSession : public Session { // is owned. std::vector> thread_pools_; - Status init_error_; // Set to an error if construction failed. + absl::Status init_error_; // Set to an error if construction failed. // If true, blocks until device has finished all queued operations in a step. bool sync_on_finish_ = true; diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index a25ef55e21d810..6850a0ef0082e3 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -146,7 +146,7 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork) { std::vector output_names = {y_ + ":0"}; std::vector target_nodes = {y_neg_}; std::vector outputs; - Status s = session->Run(inputs, output_names, target_nodes, &outputs); + absl::Status s = session->Run(inputs, output_names, target_nodes, &outputs); TF_ASSERT_OK(s); ASSERT_EQ(1, outputs.size()); @@ -182,7 +182,7 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_Callable) { EXPECT_FLOAT_EQ(5.0, mat(0, 0)); } - Status s = session->RunCallable(handle, {}, nullptr, nullptr); + absl::Status s = session->RunCallable(handle, {}, nullptr, nullptr); EXPECT_TRUE(errors::IsInvalidArgument(s)); EXPECT_TRUE( absl::StrContains(s.message(), "`fetch_tensors` must be provided")); @@ -215,7 +215,7 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_OptimizeForStaticGraph) { std::vector output_names = {y_ + ":0"}; std::vector target_nodes = {y_neg_}; std::vector outputs; - Status s = session->Run(inputs, output_names, target_nodes, &outputs); + absl::Status s = session->Run(inputs, output_names, target_nodes, &outputs); TF_ASSERT_OK(s); ASSERT_EQ(1, outputs.size()); @@ -246,7 +246,7 @@ TEST_F(DirectSessionMinusAXTest, std::vector output_names = {y_ + ":0"}; std::vector target_nodes = {y_neg_}; std::vector outputs; - Status s = session->Run(inputs, output_names, target_nodes, &outputs); + absl::Status s = session->Run(inputs, output_names, target_nodes, &outputs); TF_ASSERT_OK(s); ASSERT_EQ(1, outputs.size()); @@ -298,7 +298,7 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_FinalizeWithCallables) { TF_ASSERT_OK(session->ReleaseCallable(handle)); // Making a new callable fails because the session has been finalized. - Status s = + absl::Status s = session->MakeCallable(MakeCallableOptions({}, {y_ + ":0"}, {}), &handle); EXPECT_TRUE(errors::IsFailedPrecondition(s)); EXPECT_TRUE(absl::StrContains(s.message(), "Session has been finalized.")); @@ -331,7 +331,7 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetwork_FinalizeWithRun) { EXPECT_FLOAT_EQ(5.0, mat(0, 0)); // Running a different subgraph fails because the session has been finalized. - Status s = session->Run({}, {y_ + ":0"}, {}, &outputs); + absl::Status s = session->Run({}, {y_ + ":0"}, {}, &outputs); EXPECT_TRUE(errors::IsFailedPrecondition(s)); EXPECT_TRUE(absl::StrContains(s.message(), "Session has been finalized.")); } @@ -406,7 +406,7 @@ TEST_F(DirectSessionMinusAXTest, TestTensorConnection) { callable_options.add_fetch(y_ + ":0"); Session::CallableHandle handle; - Status s = session->MakeCallable(callable_options, &handle); + absl::Status s = session->MakeCallable(callable_options, &handle); EXPECT_TRUE(errors::IsInvalidArgument(s)); EXPECT_TRUE(absl::StrContains(s.message(), "would create a cycle")); } @@ -420,7 +420,7 @@ TEST_F(DirectSessionMinusAXTest, TestTensorConnection) { callable_options.add_fetch(y_ + ":0"); Session::CallableHandle handle; - Status s = session->MakeCallable(callable_options, &handle); + absl::Status s = session->MakeCallable(callable_options, &handle); EXPECT_TRUE(errors::IsInvalidArgument(s)); EXPECT_TRUE(absl::StrContains(s.message(), "unknown node")); } @@ -435,7 +435,7 @@ TEST_F(DirectSessionMinusAXTest, TestTensorConnection) { callable_options.add_fetch(y_ + ":0"); Session::CallableHandle handle; - Status s = session->MakeCallable(callable_options, &handle); + absl::Status s = session->MakeCallable(callable_options, &handle); EXPECT_TRUE(errors::IsInvalidArgument(s)); EXPECT_TRUE(absl::StrContains(s.message(), "unknown edge")); } @@ -449,7 +449,7 @@ TEST_F(DirectSessionMinusAXTest, TestTensorConnection) { callable_options.add_fetch(y_ + ":0"); Session::CallableHandle handle; - Status s = session->MakeCallable(callable_options, &handle); + absl::Status s = session->MakeCallable(callable_options, &handle); EXPECT_TRUE(errors::IsNotFound(s)); EXPECT_TRUE(absl::StrContains(s.message(), "unable to find feed output")); } @@ -466,7 +466,7 @@ TEST_F(DirectSessionMinusAXTest, TestTensorConnection) { callable_options.add_fetch(z_ + ":0"); Session::CallableHandle handle; - Status s = session->MakeCallable(callable_options, &handle); + absl::Status s = session->MakeCallable(callable_options, &handle); EXPECT_TRUE(errors::IsInvalidArgument(s)); EXPECT_TRUE(absl::StrContains(s.message(), "fed more than once")); } @@ -481,7 +481,7 @@ TEST_F(DirectSessionMinusAXTest, TestTensorConnection) { callable_options.add_fetch(y_neg_ + ":0"); Session::CallableHandle handle; - Status s = session->MakeCallable(callable_options, &handle); + absl::Status s = session->MakeCallable(callable_options, &handle); EXPECT_TRUE(errors::IsInvalidArgument(s)); EXPECT_TRUE(absl::StrContains(s.message(), "fed more than once")); } @@ -505,7 +505,7 @@ TEST_F(DirectSessionMinusAXTest, TestFeed) { std::vector outputs; // Run the graph - Status s = session->Run(inputs, output_names, {}, &outputs); + absl::Status s = session->Run(inputs, output_names, {}, &outputs); TF_ASSERT_OK(s); ASSERT_EQ(1, outputs.size()); @@ -565,7 +565,7 @@ TEST_F(DirectSessionMinusAXTest, TestConcurrency) { std::vector> inputs; std::vector outputs; // Run the graph - Status s = session->Run(inputs, output_names, {}, &outputs); + absl::Status s = session->Run(inputs, output_names, {}, &outputs); TF_ASSERT_OK(s); ASSERT_EQ(1, outputs.size()); auto mat = outputs[0].matrix(); @@ -635,7 +635,7 @@ TEST_F(DirectSessionMinusAXTest, TestPerSessionThreads) { std::vector> inputs; std::vector outputs; // Run the graph - Status s = session->Run(inputs, output_names, {}, &outputs); + absl::Status s = session->Run(inputs, output_names, {}, &outputs); TF_ASSERT_OK(s); ASSERT_EQ(1, outputs.size()); auto mat = outputs[0].matrix(); @@ -723,8 +723,8 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts) { RunMetadata run_metadata; EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 0); - Status s = session->Run(run_options, inputs, output_names, target_nodes, - &outputs, &run_metadata); + absl::Status s = session->Run(run_options, inputs, output_names, target_nodes, + &outputs, &run_metadata); TF_ASSERT_OK(s); ASSERT_EQ(1, outputs.size()); @@ -787,8 +787,8 @@ TEST_F(DirectSessionMinusAXTest, UseRunHandlerPool) { RunOptions run_options; run_options.mutable_experimental()->set_use_run_handler_pool(true); - Status s = session->Run(run_options, inputs, output_names, target_nodes, - &outputs, nullptr); + absl::Status s = session->Run(run_options, inputs, output_names, target_nodes, + &outputs, nullptr); TF_ASSERT_OK(s); ASSERT_EQ(1, outputs.size()); @@ -827,7 +827,7 @@ TEST(DirectSessionTest, KeepsStateAcrossRunsOfSession) { std::vector outputs; // Initialize the variable - Status s = session->Run(inputs, {init->name()}, {}, &outputs); + absl::Status s = session->Run(inputs, {init->name()}, {}, &outputs); TF_ASSERT_OK(s); // Get the variable's data @@ -861,7 +861,7 @@ TEST(DirectSessionTest, MultipleFeedTest) { std::vector outputs; // Fetch without feeding. - Status s = session->Run( + absl::Status s = session->Run( {}, {first_identity->name() + ":0", second_identity->name() + ":0"}, {}, &outputs); TF_ASSERT_OK(s); @@ -985,7 +985,7 @@ TEST(DirectSessionTest, MultipleFeedTest_Callable) { ASSERT_EQ(22.0, outputs[1].flat()(0)); // Feed [first_const, first_const] - Status s = session->MakeCallable( + absl::Status s = session->MakeCallable( MakeCallableOptions( {first_const->name(), first_const->name()}, {first_identity->name() + ":0", second_identity->name() + ":0"}, {}), @@ -1060,7 +1060,7 @@ TEST(DirectSessionTest, FetchMultipleTimes) { std::vector outputs; auto seven = seven_node->name(); - Status s = session->Run(inputs, {seven, seven}, {}, &outputs); + absl::Status s = session->Run(inputs, {seven, seven}, {}, &outputs); TF_ASSERT_OK(s); EXPECT_EQ(2, outputs.size()); @@ -1096,7 +1096,7 @@ TEST(DirectSessionTest, MultipleFeedTestSomeSyncRun) { std::vector outputs; // Fetch without feeding. - Status s = session->Run( + absl::Status s = session->Run( run_options, {}, {first_identity->name() + ":0", second_identity->name() + ":0"}, {}, &outputs, nullptr); @@ -1592,7 +1592,7 @@ TEST(DirectSessionTest, PartialRunTest) { std::vector outputs; string handle; - Status s = session->PRunSetup( + absl::Status s = session->PRunSetup( {first_const->name(), second_const->name()}, {first_identity->name() + ":0", second_identity->name() + ":0", third_identity->name() + ":0"}, @@ -1648,8 +1648,9 @@ TEST(DirectSessionTest, PartialRunMissingFeed) { std::vector outputs; string handle; - Status s = session->PRunSetup({first_const->name(), second_const->name()}, - {third_identity->name() + ":0"}, {}, &handle); + absl::Status s = + session->PRunSetup({first_const->name(), second_const->name()}, + {third_identity->name() + ":0"}, {}, &handle); TF_ASSERT_OK(s); // Feed first_const, fetch third_identity @@ -1681,8 +1682,9 @@ TEST(DirectSessionTest, PartialRunMultiOutputFeed) { std::vector outputs; string handle; - Status s = session->PRunSetup({switch_node->name() + ":1"}, - {fourth_identity->name() + ":0"}, {}, &handle); + absl::Status s = + session->PRunSetup({switch_node->name() + ":1"}, + {fourth_identity->name() + ":0"}, {}, &handle); TF_ASSERT_OK(s); // Fetch fourth_identity without feeds. @@ -1729,7 +1731,7 @@ TEST(DirectSessionTest, RunHandleTest) { // First run call: Create a handle. std::vector outputs; - Status s = session->Run({}, {node4->name() + ":0"}, {}, &outputs); + absl::Status s = session->Run({}, {node4->name() + ":0"}, {}, &outputs); ASSERT_TRUE(s.ok()); ASSERT_EQ(1, outputs.size()); @@ -1782,7 +1784,7 @@ TEST(DirectSessionTest, RunHandleTest_Callable) { // First run call: Create a handle. std::vector outputs; - Status s = session->Run({}, {node4->name() + ":0"}, {}, &outputs); + absl::Status s = session->Run({}, {node4->name() + ":0"}, {}, &outputs); ASSERT_TRUE(s.ok()); ASSERT_EQ(1, outputs.size()); @@ -1823,7 +1825,8 @@ TEST(DirectSessionTest, CreateGraphFailsWhenAssigningAFedVar) { // The graph is invalid since a constant cannot be assigned to a constant. // The return Status of session->Run should flag this as an invalid argument. std::vector outputs; - Status s = session->Run({{a->name(), zero}}, {assign->name()}, {}, &outputs); + absl::Status s = + session->Run({{a->name(), zero}}, {assign->name()}, {}, &outputs); ASSERT_TRUE(errors::IsInvalidArgument(s)); } @@ -1885,7 +1888,7 @@ TEST(DirectSessionTest, TimeoutSession) { TF_ASSERT_OK(session->Create(graph)); // Verifies that the error code is DEADLINE_EXCEEDED. - Status s = session->Run({}, {}, {"fifo_queue_Dequeue"}, nullptr); + absl::Status s = session->Run({}, {}, {"fifo_queue_Dequeue"}, nullptr); ASSERT_EQ(error::DEADLINE_EXCEEDED, s.code()); TF_ASSERT_OK(session->Close()); } @@ -1898,8 +1901,8 @@ TEST(DirectSessionTest, TimeoutSession) { RunOptions run_options; run_options.set_timeout_in_ms(20); // Verifies that the error code is DEADLINE_EXCEEDED. - Status s2 = session->Run(run_options, {}, {}, {"fifo_queue_Dequeue"}, - nullptr, nullptr); + absl::Status s2 = session->Run(run_options, {}, {}, {"fifo_queue_Dequeue"}, + nullptr, nullptr); ASSERT_EQ(error::DEADLINE_EXCEEDED, s2.code()); TF_ASSERT_OK(session->Close()); } @@ -1947,7 +1950,7 @@ TEST(DirectSessionTest, TestTimeoutCleanShutdown) { TF_ASSERT_OK(session->Create(graph)); // Verifies that the error code is DEADLINE_EXCEEDED. - Status s = session->Run({}, {}, {"cm_polling"}, nullptr); + absl::Status s = session->Run({}, {}, {"cm_polling"}, nullptr); ASSERT_EQ(error::DEADLINE_EXCEEDED, s.code()); // Verify that the op ran to completion. @@ -2033,9 +2036,10 @@ static void TestSessionInterOpThreadsImpl(bool use_function_lib, session = sessions[0].get(); } - Status s = session->Run(run_options, {} /* inputs */, - {node->name() + ":0"} /* output_names */, {}, - &outputs, nullptr /* run_metadata */); + absl::Status s = + session->Run(run_options, {} /* inputs */, + {node->name() + ":0"} /* output_names */, {}, + &outputs, nullptr /* run_metadata */); TF_CHECK_OK(s); ASSERT_EQ(1, outputs.size()); auto flat = outputs[0].flat(); @@ -2146,9 +2150,9 @@ TEST(DirectSessionTest, TestSessionInterOpThreadsInvalidOptions) { RunOptions run_options; run_options.set_inter_op_thread_pool(pool_num); std::vector outputs; - Status s = session->Run(run_options, {} /* inputs */, - {x->name() + ":0"} /* output_names */, {}, - &outputs, nullptr /* run_metadata */); + absl::Status s = session->Run(run_options, {} /* inputs */, + {x->name() + ":0"} /* output_names */, {}, + &outputs, nullptr /* run_metadata */); EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); EXPECT_TRUE(absl::StrContains( s.message(), @@ -2215,8 +2219,8 @@ TEST(DirectSessionTest, TestDirectSessionRunClose) { TF_ASSERT_OK(session->Close()); // Run the read on the variable to get an error. - Status s = session->Run({} /* inputs */, {}, - {var_assign->name()} /* target_nodes */, nullptr); + absl::Status s = session->Run( + {} /* inputs */, {}, {var_assign->name()} /* target_nodes */, nullptr); EXPECT_EQ(s.code(), error::CANCELLED); EXPECT_TRUE(absl::StrContains(s.message(), "Session has been closed.")); @@ -2252,7 +2256,7 @@ TEST(DirectSessionTest, TestDirectSessionPRunClose) { std::vector outputs; string handle; - Status s = session->PRunSetup( + absl::Status s = session->PRunSetup( {first_const->name(), second_const->name()}, {first_identity->name() + ":0", second_identity->name() + ":0", third_identity->name() + ":0"}, @@ -2309,8 +2313,8 @@ TEST(DirectSessionTest, TestDirectSessionReset) { // TODO(suharshs): This test only works because we close the Session in Reset. // If we change the behavior of Reset to not close the Session, this test will // fail, since the Variable buffer is cached by var. - Status s = session->Run({} /* inputs */, {}, - {var_assign->name()} /* target_nodes */, nullptr); + absl::Status s = session->Run( + {} /* inputs */, {}, {var_assign->name()} /* target_nodes */, nullptr); EXPECT_EQ(s.code(), error::CANCELLED); EXPECT_TRUE(absl::StrContains(s.message(), "Session has been closed.")); } @@ -2331,7 +2335,7 @@ class FakeDevice : public Device { explicit FakeDevice(const DeviceAttributes& device_attributes) : Device(nullptr, device_attributes) {} - Status Sync() override { + absl::Status Sync() override { return absl::UnimplementedError("FakeDevice::Sync()"); } }; @@ -2340,11 +2344,12 @@ class FakeDevice : public Device { template class FakeFactory : public DeviceFactory { public: - Status ListPhysicalDevices(std::vector* devices) override { + absl::Status ListPhysicalDevices(std::vector* devices) override { return absl::OkStatus(); } - Status CreateDevices(const SessionOptions& options, const string& name_prefix, - std::vector>* devices) override { + absl::Status CreateDevices( + const SessionOptions& options, const string& name_prefix, + std::vector>* devices) override { std::string name = absl::StrFormat("%cPU", FirstLetter); DeviceAttributes attr; attr.set_name( @@ -2594,7 +2599,7 @@ void TestFeedAndFetchTensorsInDeviceMemoryFailsToMakeCallable( opts.mutable_fetch_devices()->insert({"y:0", gpu_device_name}); opts.set_fetch_skip_sync(true); Session::CallableHandle handle; - Status status = session->MakeCallable(opts, &handle); + absl::Status status = session->MakeCallable(opts, &handle); EXPECT_FALSE(status.ok()) << DataType_Name(dtype); EXPECT_TRUE(absl::StrContains( status.message(), @@ -2610,7 +2615,7 @@ void TestFeedAndFetchTensorsInDeviceMemoryFailsToMakeCallable( opts.clear_feed_devices(); opts.mutable_feed_devices()->insert({"x:0", gpu_device_name}); Session::CallableHandle handle; - Status status = session->MakeCallable(opts, &handle); + absl::Status status = session->MakeCallable(opts, &handle); EXPECT_FALSE(status.ok()); EXPECT_TRUE(absl::StrContains( status.message(), @@ -2807,8 +2812,8 @@ class DirectSessionCollectiveTest : public ::testing::Test { public: // Creates a graph with CollectiveOps inside functions and runs it. Returns // the generated collective_graph_key. - Status RunGraphWithCollectiveFunctions(bool add_unused_function, - int64_t* collective_graph_key) { + absl::Status RunGraphWithCollectiveFunctions(bool add_unused_function, + int64_t* collective_graph_key) { GraphDef g = CreateGraph(add_unused_function); const Tensor t1 = test::AsTensor({0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1}); diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc index b82799f4e46383..46cf7d529490e0 100644 --- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc +++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc @@ -89,7 +89,7 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) { std::vector target_nodes = {y_neg->name()}; std::vector outputs; const int64_t start_micros = Env::Default()->NowMicros(); - Status s = session->Run(inputs, output_names, target_nodes, &outputs); + absl::Status s = session->Run(inputs, output_names, target_nodes, &outputs); const int64_t run_duration_micros = Env::Default()->NowMicros() - start_micros; TF_ASSERT_OK(s); @@ -207,8 +207,8 @@ static void TestHWAccelerator(bool enableHWTrace) { run_options.set_trace_level(RunOptions::FULL_TRACE); } RunMetadata run_metadata; - Status s = session->Run(run_options, inputs, output_names, target_nodes, - &outputs, &run_metadata); + absl::Status s = session->Run(run_options, inputs, output_names, target_nodes, + &outputs, &run_metadata); const int64_t run_duration_micros = Env::Default()->NowMicros() - start_micros; TF_ASSERT_OK(s); @@ -287,8 +287,8 @@ TEST(DirectSessionWithTrackingAllocTest, CostGraph) { std::vector outputs; RunMetadata run_metadata; const int64_t start_micros = Env::Default()->NowMicros(); - Status s = session->Run(run_options, inputs, output_names, target_nodes, - &outputs, &run_metadata); + absl::Status s = session->Run(run_options, inputs, output_names, target_nodes, + &outputs, &run_metadata); const int64_t run_duration_micros = Env::Default()->NowMicros() - start_micros; TF_ASSERT_OK(s); @@ -344,8 +344,8 @@ TEST(DirectSessionWithTrackingAllocTest, TrackMemoryAllocation) { std::vector output_names = {y->name() + ":0"}; std::vector outputs; RunMetadata run_metadata; - Status s = session->Run(run_options, inputs, output_names, {}, &outputs, - &run_metadata); + absl::Status s = session->Run(run_options, inputs, output_names, {}, &outputs, + &run_metadata); TF_ASSERT_OK(s); for (const auto& dev_stat : run_metadata.step_stats().dev_stats()) { diff --git a/tensorflow/core/common_runtime/dynamic_device_mgr.cc b/tensorflow/core/common_runtime/dynamic_device_mgr.cc index 06ccc121440394..d1f8fd52c338d8 100644 --- a/tensorflow/core/common_runtime/dynamic_device_mgr.cc +++ b/tensorflow/core/common_runtime/dynamic_device_mgr.cc @@ -32,7 +32,7 @@ DynamicDeviceMgr::DynamicDeviceMgr() : cpu_device_(nullptr) {} DynamicDeviceMgr::DynamicDeviceMgr( std::vector>&& devices) : cpu_device_(nullptr) { - Status status = AddDevices(std::move(devices)); + absl::Status status = AddDevices(std::move(devices)); CHECK(status.ok()); // Crash OK mutex_lock l(devices_mu_); // Initialize cpu_device_. @@ -104,7 +104,8 @@ string DynamicDeviceMgr::DeviceMappingString() const { return out; } -Status DynamicDeviceMgr::LookupDevice(StringPiece name, Device** device) const { +absl::Status DynamicDeviceMgr::LookupDevice(StringPiece name, + Device** device) const { tf_shared_lock l(devices_mu_); auto iter = device_map_.find(string(name)); if (iter == device_map_.end()) { @@ -128,7 +129,7 @@ bool DynamicDeviceMgr::ContainsDevice(int64_t device_incarnation) const { void DynamicDeviceMgr::ClearContainers( absl::Span containers) const { - Status s; + absl::Status s; tf_shared_lock l(devices_mu_); for (const auto& it : dynamic_devices_) { auto d = it.first; @@ -158,7 +159,7 @@ int DynamicDeviceMgr::NumDevices() const { return dynamic_devices_.size(); } -Status DynamicDeviceMgr::AddDevices( +absl::Status DynamicDeviceMgr::AddDevices( std::vector> devices) { mutex_lock l(devices_mu_); for (auto& d : devices) { @@ -184,7 +185,8 @@ Status DynamicDeviceMgr::AddDevices( return absl::OkStatus(); } -Status DynamicDeviceMgr::RemoveDevices(const std::vector& devices) { +absl::Status DynamicDeviceMgr::RemoveDevices( + const std::vector& devices) { mutex_lock l(devices_mu_); for (const auto& d : devices) { @@ -224,7 +226,7 @@ Status DynamicDeviceMgr::RemoveDevices(const std::vector& devices) { return absl::OkStatus(); } -Status DynamicDeviceMgr::RemoveDevicesByName( +absl::Status DynamicDeviceMgr::RemoveDevicesByName( const std::vector& device_names) { std::vector devices_to_remove; for (const string& name : device_names) { diff --git a/tensorflow/core/common_runtime/dynamic_device_mgr_test.cc b/tensorflow/core/common_runtime/dynamic_device_mgr_test.cc index 90c9782ab1a3e0..092566e5ea2a97 100644 --- a/tensorflow/core/common_runtime/dynamic_device_mgr_test.cc +++ b/tensorflow/core/common_runtime/dynamic_device_mgr_test.cc @@ -34,7 +34,7 @@ static Device* CreateDevice(const char* type, const char* name, class FakeDevice : public Device { public: explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} - Status Sync() override { return absl::OkStatus(); } + absl::Status Sync() override { return absl::OkStatus(); } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } }; @@ -159,7 +159,7 @@ TEST(DynamicDeviceMgrTest, AddRepeatedDeviceToMgr) { std::vector> added_devices; added_devices.emplace_back(std::move(d1)); - Status s = dm->AddDevices(std::move(added_devices)); + absl::Status s = dm->AddDevices(std::move(added_devices)); EXPECT_TRUE( absl::StrContains(s.message(), "name conflicts with an existing device")); } @@ -177,7 +177,7 @@ TEST(DynamicDeviceMgrTest, RemoveNonExistingDeviceFromMgr) { EXPECT_EQ(dm->ListDevices().size(), 1); std::vector removed_devices{d0_ptr, d1_ptr}; - Status s = dm->RemoveDevices(removed_devices); + absl::Status s = dm->RemoveDevices(removed_devices); EXPECT_TRUE(absl::StrContains(s.message(), "Unknown device")); EXPECT_EQ(dm->ListDevices().size(), 1); // d0 *not* removed. } @@ -194,7 +194,7 @@ TEST(DynamicDeviceMgrTest, RemoveNonExistingDeviceByNameFromMgr) { EXPECT_EQ(dm->ListDevices().size(), 1); std::vector removed_devices{d0_name, d1_name}; - Status s = dm->RemoveDevicesByName(removed_devices); + absl::Status s = dm->RemoveDevicesByName(removed_devices); EXPECT_TRUE(absl::StrContains(s.message(), "unknown device")); EXPECT_EQ(dm->ListDevices().size(), 1); // d0 *not* removed } diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index c7da536c946e30..9b8b1ce5a067f5 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -40,7 +40,7 @@ tf_cuda_library( srcs = [ "core.cc", ], - visibility = ["//tensorflow:internal"], + visibility = ["//visibility:public"], deps = [ ":context", ":eager_operation", @@ -91,7 +91,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core/platform:status_matchers", "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -103,7 +103,7 @@ tf_cuda_library( hdrs = [ "context.h", ], - visibility = ["//tensorflow:internal"], + visibility = ["//visibility:public"], deps = [ ":custom_device", ":eager_executor", diff --git a/tensorflow/core/common_runtime/eager/attr_builder.cc b/tensorflow/core/common_runtime/eager/attr_builder.cc index 2c41e95e0200c4..1f27eaf6d64f19 100644 --- a/tensorflow/core/common_runtime/eager/attr_builder.cc +++ b/tensorflow/core/common_runtime/eager/attr_builder.cc @@ -57,17 +57,17 @@ const AttrTypeMap* GetDefaultFunctionAttrTypeMap() { } // namespace -Status OpDefForOp(const string& op_name, const OpDef** op_def) { +absl::Status OpDefForOp(const string& op_name, const OpDef** op_def) { const OpRegistrationData* op_reg_data = nullptr; - Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data); + absl::Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data); if (s.ok()) { *op_def = &op_reg_data->op_def; } return s; } -Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out, - bool* is_function) { +absl::Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out, + bool* is_function) { { tf_shared_lock l(g_op_name_to_attr_type_map_lock); *is_function = false; @@ -84,7 +84,7 @@ Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out, if (*out != nullptr) return absl::OkStatus(); const OpDef* op_def = nullptr; - Status s = OpDefForOp(op_name, &op_def); + absl::Status s = OpDefForOp(op_name, &op_def); if (absl::IsNotFound(s)) { // If we did not find the op def, we assume `op_name` is a function. // If it is actually a misspelled op, user will get another error when @@ -161,8 +161,8 @@ DEFINE_GET_ATTR(tensorflow::DataType, type, "type"); #undef DEFINE_GET_ATTR template <> -Status AttrBuilder::Get(StringPiece attr_name, - absl::InlinedVector* value) const { +absl::Status AttrBuilder::Get(StringPiece attr_name, + absl::InlinedVector* value) const { auto it = encoded_attrs_.find(string(attr_name)); if (it == encoded_attrs_.end()) { return errors::NotFound("No attr named '", attr_name, @@ -192,7 +192,7 @@ void AttrBuilder::FillAttrValueMap(AttrValueMap* m) const { // specify all the default attr values (e.g. for matmul, the `transpose_a` // attr defaults to false). const OpDef* op_def = nullptr; - Status s = OpDefForOp(op_name().c_str(), &op_def); + absl::Status s = OpDefForOp(op_name().c_str(), &op_def); // This is expected, if this op is a custom function, and is therefore not // present in the op registry. if (!s.ok()) return; @@ -224,7 +224,7 @@ bool ValueMatchesDefault(const OpDef* op_def, const string& attr_name, void AttrBuilder::FillAttrValueMapWithoutDefaults(AttrValueMap* m) const { const OpDef* op_def = nullptr; - Status s = OpDefForOp(op_name().c_str(), &op_def); + absl::Status s = OpDefForOp(op_name().c_str(), &op_def); for (auto& entry : encoded_attrs_) { attr_tmp_.ParseFromString(entry.second); @@ -260,8 +260,8 @@ void AttrBuilder::CopyAttributes(const AttrBuilder& other) { other.encoded_attrs_.end()); } -Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, - TF_AttrType* out, unsigned char* is_list) { +absl::Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, + TF_AttrType* out, unsigned char* is_list) { auto* t = gtl::FindOrNull(m, attr_name); if (t == nullptr) { return errors::InvalidArgument("Attribute '", attr_name, @@ -322,28 +322,28 @@ void AttrBuilder::GetNameAttrList( name_and_attrs->set_name(op_name()); } -Status AttrBuilder::GetTypeList( +absl::Status AttrBuilder::GetTypeList( absl::string_view attr_name, absl::InlinedVector* type_list) const { return Get(attr_name, type_list); } bool AttrBuilder::GetInt(absl::string_view attr_name, int64_t* result) const { - Status s = Get(attr_name, result); + absl::Status s = Get(attr_name, result); return s.ok(); } bool AttrBuilder::GetFloat(absl::string_view attr_name, float* result) const { - Status s = Get(attr_name, result); + absl::Status s = Get(attr_name, result); return s.ok(); } bool AttrBuilder::GetBool(absl::string_view attr_name, bool* result) const { - Status s = Get(attr_name, result); + absl::Status s = Get(attr_name, result); return s.ok(); } bool AttrBuilder::GetType(absl::string_view attr_name, tensorflow::DataType* result) const { - Status s = Get(attr_name, result); + absl::Status s = Get(attr_name, result); return s.ok(); } diff --git a/tensorflow/core/common_runtime/eager/attr_builder.h b/tensorflow/core/common_runtime/eager/attr_builder.h index 6fc817039cc214..129841e8f90133 100644 --- a/tensorflow/core/common_runtime/eager/attr_builder.h +++ b/tensorflow/core/common_runtime/eager/attr_builder.h @@ -43,18 +43,18 @@ namespace tensorflow { typedef std::unordered_map AttrTypeMap; // Look up OpDef for `op_name`. -Status OpDefForOp(const string& op_name, const OpDef** op_def); +absl::Status OpDefForOp(const string& op_name, const OpDef** op_def); // Returns the AttrTypeMap for the TensorFlow operation named op_name. // If op_name is not registered in global op registry, AttrTypeMapForOp assumes // the op to be a function and returns the default attributes for a function. // `is_function` is set to true in this case. -Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out, - bool* is_function); +absl::Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out, + bool* is_function); // Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'. -Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, - TF_AttrType* out, unsigned char* is_list); +absl::Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, + TF_AttrType* out, unsigned char* is_list); // KernelAndDevice::Init needs a NodeDef only to pass the attribute map through. // An AttrBuilder is a convenience class to help with that - providing a smaller @@ -139,7 +139,7 @@ class AttrBuilder : public AbstractOpAttrs { // value type in this Node. This is not an issue, because Get is used rarely // and nodes have a small number of attributes. template - Status Get(StringPiece attr_name, T* value) const { + absl::Status Get(StringPiece attr_name, T* value) const { // Common attributes are stored in AttrVecs. This Get() template // is specialized for them below. If we end up here, the type must be // among those that we store in the node_def_. @@ -178,7 +178,7 @@ class AttrBuilder : public AbstractOpAttrs { bool GetBool(absl::string_view attr_name, bool* result) const override; bool GetType(absl::string_view attr_name, tensorflow::DataType* result) const override; - Status GetTypeList( + absl::Status GetTypeList( absl::string_view attr_name, absl::InlinedVector* type_list) const override; @@ -210,14 +210,14 @@ class AttrBuilder : public AbstractOpAttrs { }; template <> -Status AttrBuilder::Get(StringPiece attr_name, int* value) const; +absl::Status AttrBuilder::Get(StringPiece attr_name, int* value) const; template <> -Status AttrBuilder::Get(StringPiece attr_name, float* value) const; +absl::Status AttrBuilder::Get(StringPiece attr_name, float* value) const; template <> -Status AttrBuilder::Get(StringPiece attr_name, bool* value) const; +absl::Status AttrBuilder::Get(StringPiece attr_name, bool* value) const; template <> -Status AttrBuilder::Get(StringPiece attr_name, - tensorflow::DataType* value) const; +absl::Status AttrBuilder::Get(StringPiece attr_name, + tensorflow::DataType* value) const; } // namespace tensorflow #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_ diff --git a/tensorflow/core/common_runtime/eager/attr_builder_test.cc b/tensorflow/core/common_runtime/eager/attr_builder_test.cc index dd0a18cece4a38..77462842f493a2 100644 --- a/tensorflow/core/common_runtime/eager/attr_builder_test.cc +++ b/tensorflow/core/common_runtime/eager/attr_builder_test.cc @@ -41,7 +41,7 @@ TEST(AttrTypeMap, Lookup) { // Unknown ops are assumed to be functions. // Their maps are filled with default attributes. bool is_function = false; - Status s = AttrTypeMapForOp("SomeFunctionName", &m, &is_function); + absl::Status s = AttrTypeMapForOp("SomeFunctionName", &m, &is_function); EXPECT_TRUE(s.ok()); EXPECT_TRUE(is_function); ASSERT_NE(m->end(), m->find("executor_type")); @@ -134,7 +134,7 @@ TEST(AttrBuilder, GetTypeList) { AttrBuilder a("IdentityN"); a.Set("T", absl::Span({DT_FLOAT, DT_INT64})); absl::InlinedVector type_list; - Status s = a.GetTypeList("T", &type_list); + absl::Status s = a.GetTypeList("T", &type_list); ASSERT_TRUE(s.ok()) << s; ASSERT_EQ(2, type_list.size()) << type_list.size(); ASSERT_EQ(DT_FLOAT, type_list[0]) << type_list[0]; diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 0677e45b4c83a6..67e11021bce18c 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -332,8 +332,9 @@ Device* SelectBestMatchingDevice(const DeviceNameUtils::ParsedName& pattern, } // namespace -Status EagerContext::SelectDevice(DeviceNameUtils::ParsedName preferred, - const NodeDef& ndef, Device** out) const { +absl::Status EagerContext::SelectDevice(DeviceNameUtils::ParsedName preferred, + const NodeDef& ndef, + Device** out) const { DCHECK(out != nullptr); PrioritizedDeviceTypeVector supported_devs; @@ -554,11 +555,11 @@ void EagerContext::CloseRemoteContexts( int i = 0; for (const auto& worker : remote_contexts) { core::RefCountPtr client; - Status s = GetClient(worker, &client); + absl::Status s = GetClient(worker, &client); client->CloseContextAsync( &request, &responses[i], - [&worker, &counter, context_id](const Status& s) { + [&worker, &counter, context_id](const absl::Status& s) { if (!s.ok()) { LOG(ERROR) << "Unable to close remote context with ID " << context_id << " for worker: " << worker << " due to " @@ -649,7 +650,7 @@ EagerContext::~EagerContext() { // shutdown is supported. if (server_->worker_env()->session_mgr != nullptr) { // Tear down coordination service. - Status s = server_->StopCoordinationService(); + absl::Status s = server_->StopCoordinationService(); if (!s.ok()) { LOG(ERROR) << "Failed to stop coordination service: " << s; } @@ -692,7 +693,7 @@ bool EagerContext::FindFunctionByName(const string& name) const { return func_lib_def_.Find(name) != nullptr; } -Status EagerContext::FindFunctionOpData( +absl::Status EagerContext::FindFunctionOpData( const string& name, const tensorflow::OpRegistrationData** op_data) { return func_lib_def_.LookUp(name, op_data); } @@ -718,7 +719,7 @@ ImmediateExecutionTensorHandle* EagerContext::TFTensorHandleFromInterface( return handle; } -Status EagerContext::RegisterFunction(AbstractFunction* f) { +absl::Status EagerContext::RegisterFunction(AbstractFunction* f) { TF_ASSIGN_OR_RETURN(core::RefCountPtr record, f->GetFunctionRecord()); if (!record) { @@ -788,7 +789,8 @@ std::vector EagerContext::ListAllTfDevices() { return devices; } -Status EagerContext::AddDevices(std::vector> devices) { +absl::Status EagerContext::AddDevices( + std::vector> devices) { std::vector> local_devices, remote_devices; while (!devices.empty()) { if (devices.front()->IsLocal()) { @@ -839,7 +841,8 @@ ScopedStepContainer* EagerContext::StepContainer() { return step_container_.get(); } -Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) { +absl::Status EagerContext::MaybeRegisterFunctionRemotely( + const FunctionDef& fdef) { // Only client context can register function on remote worker context. if (!remote_device_manager_.Owned()) return absl::OkStatus(); #if !defined(IS_MOBILE_PLATFORM) @@ -862,7 +865,7 @@ Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) { eager_client->StreamingEnqueueAsync( this->Executor().StreamingEnqueue(), /*call_opts=*/nullptr, request.get(), response, - [request, response](const Status& status) { + [request, response](const absl::Status& status) { if (!status.ok()) { LOG(ERROR) << "Failed to register function remotely due to " << status.message() @@ -876,7 +879,8 @@ Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) { return absl::OkStatus(); } -Status EagerContext::MaybeRemoveFunctionRemotely(const string& function_name) { +absl::Status EagerContext::MaybeRemoveFunctionRemotely( + const string& function_name) { // Only client context can remove function on remote worker context. if (!remote_device_manager_.Owned()) { return absl::OkStatus(); @@ -899,7 +903,7 @@ Status EagerContext::MaybeRemoveFunctionRemotely(const string& function_name) { eager_client->StreamingEnqueueAsync( this->Executor().StreamingEnqueue(), /*call_opts=*/nullptr, request.get(), response.get(), - [request, response](const Status& status) { + [request, response](const absl::Status& status) { if (!status.ok()) { LOG(ERROR) << "Failed to remove function remotely due to " << status.message() @@ -912,7 +916,7 @@ Status EagerContext::MaybeRemoveFunctionRemotely(const string& function_name) { return absl::OkStatus(); } -Status EagerContext::RegisterExistingFunctionsOnRemoteWorkers( +absl::Status EagerContext::RegisterExistingFunctionsOnRemoteWorkers( const std::vector& remote_workers) { #if !defined(IS_MOBILE_PLATFORM) // Register multiple functions on selected remote workers. @@ -934,7 +938,7 @@ Status EagerContext::RegisterExistingFunctionsOnRemoteWorkers( for (auto& remote_worker : remote_workers) { core::RefCountPtr eager_client; - Status s = GetClient(remote_worker, &eager_client); + absl::Status s = GetClient(remote_worker, &eager_client); if (!s.ok()) { continue; } @@ -943,7 +947,7 @@ Status EagerContext::RegisterExistingFunctionsOnRemoteWorkers( eager_client->StreamingEnqueueAsync( this->Executor().StreamingEnqueue(), /*call_opts=*/nullptr, requests[i].get(), response.get(), - [request = requests[i], response](const Status& s) { + [request = requests[i], response](const absl::Status& s) { if (!s.ok()) { LOG(ERROR) << "Failed to register function remotely due to " << s.message() @@ -957,27 +961,27 @@ Status EagerContext::RegisterExistingFunctionsOnRemoteWorkers( return absl::OkStatus(); } -Status EagerContext::AddFunctionDefWithStackTraces( +absl::Status EagerContext::AddFunctionDefWithStackTraces( const FunctionDef& fdef, const StackTracesMap& stack_traces) { return AddFunctionDef(fdef, FunctionDefLibrary(), /* add_to_local_only=*/false, stack_traces); } -Status EagerContext::AddFunctionDef(const FunctionDef& fdef) { +absl::Status EagerContext::AddFunctionDef(const FunctionDef& fdef) { return AddFunctionDef(fdef, FunctionDefLibrary(), /* add_to_local_only=*/false); } -Status EagerContext::AddFunctionDef(const FunctionDef& fdef, - const FunctionDefLibrary& library, - const bool add_to_local_only, - const StackTracesMap& stack_traces) { +absl::Status EagerContext::AddFunctionDef(const FunctionDef& fdef, + const FunctionDefLibrary& library, + const bool add_to_local_only, + const StackTracesMap& stack_traces) { core::RefCountPtr func_record( new FunctionRecord(fdef, stack_traces, true)); return AddFunctionRecord(std::move(func_record), library, add_to_local_only); } -Status EagerContext::AddFunctionRecord( +absl::Status EagerContext::AddFunctionRecord( core::RefCountPtr func_record, const FunctionDefLibrary& library, bool add_to_local_only) { const FunctionDef& fdef = func_record->fdef(); @@ -1039,8 +1043,8 @@ Status EagerContext::AddFunctionRecord( return absl::OkStatus(); } -Status EagerContext::AddComponentFunction(const FunctionDef& fdef, - const FunctionDefLibrary& library) { +absl::Status EagerContext::AddComponentFunction( + const FunctionDef& fdef, const FunctionDefLibrary& library) { { mutex_lock l(cache_mu_); auto iter = component_function_libraries_.find(fdef.signature().name()); @@ -1083,8 +1087,8 @@ std::vector EagerContext::ListFunctionNames() { return func_lib_def_.ListFunctionNames(); } -Status EagerContext::AddRemoveFunctionNotifier(const string& func, - std::function notifier) { +absl::Status EagerContext::AddRemoveFunctionNotifier( + const string& func, std::function notifier) { mutex_lock l(remove_function_notifiers_mu_); auto iter = remove_function_notifiers_.find(func); if (iter != remove_function_notifiers_.end()) { @@ -1118,7 +1122,7 @@ EagerContext::GetCacheStats() { return stats; } -Status EagerContext::RemoveFunction(const string& func) { +absl::Status EagerContext::RemoveFunction(const string& func) { // TODO(mdan): The context owns these functions. Why check refcount then? std::vector> notifiers; bool is_last_ref = false; @@ -1159,7 +1163,7 @@ Status EagerContext::RemoveFunction(const string& func) { return absl::OkStatus(); } -Status EagerContext::SyncExecutors() { +absl::Status EagerContext::SyncExecutors() { VLOG(6) << "Calling SyncExecutors"; StatusGroup sg; // Synchronize on context default executor @@ -1184,7 +1188,7 @@ Status EagerContext::SyncExecutors() { request.set_context_id(GetContextId()); request.add_queue()->mutable_sync_remote_executor_for_stream(); BlockingCounter counter(static_cast(remote_contexts.size())); - std::vector statuses(remote_contexts.size()); + std::vector statuses(remote_contexts.size()); for (int i = 0; i < remote_contexts.size(); i++) { const auto& target = remote_contexts[i]; @@ -1195,14 +1199,15 @@ Status EagerContext::SyncExecutors() { eager_client->StreamingEnqueueAsync( this->Executor().StreamingEnqueue(), /*call_opts=*/nullptr, &request, response, - [response, target, &counter, &s = statuses[i]](const Status& status) { + [response, target, &counter, + &s = statuses[i]](const absl::Status& status) { s = status; delete response; counter.DecrementCount(); }); } counter.Wait(); - for (const Status& s : statuses) { + for (const absl::Status& s : statuses) { sg.Update(s); } #endif // !IS_MOBILE_PLATFORM @@ -1272,8 +1277,8 @@ void EagerContext::SetShouldStoreGraphs(bool value) { } } -Status EagerContext::FindDeviceFromName(const char* device_name, - Device** device) const { +absl::Status EagerContext::FindDeviceFromName(const char* device_name, + Device** device) const { *device = HostCPU(); if (device_name == nullptr || strlen(device_name) == 0) { return absl::OkStatus(); @@ -1291,7 +1296,7 @@ Status EagerContext::FindDeviceFromName(const char* device_name, return status; } -Status EagerContext::FindCompositeDeviceFromName( +absl::Status EagerContext::FindCompositeDeviceFromName( StringPiece device_name, CompositeDevice** device) const { tf_shared_lock l(composite_devices_mu_); for (const auto& d : composite_devices_) { @@ -1309,7 +1314,7 @@ bool EagerContext::IsCustomDevice(const string& device_name) { &device); } -Status EagerContext::RegisterCustomDevice( +absl::Status EagerContext::RegisterCustomDevice( const string& device_name, std::unique_ptr device) { Device* existing_physical_device = nullptr; if (FindDeviceFromName(device_name.c_str(), &existing_physical_device).ok()) { @@ -1320,7 +1325,7 @@ Status EagerContext::RegisterCustomDevice( std::move(device)); } -Status EagerContext::FindOrCreateCompositeDevice( +absl::Status EagerContext::FindOrCreateCompositeDevice( const std::vector& underlying_devices, const string& device_name, CompositeDevice** composite_device) { if (!device_name.empty() && @@ -1337,7 +1342,7 @@ Status EagerContext::FindOrCreateCompositeDevice( return absl::OkStatus(); } - Status s; + absl::Status s; std::unique_ptr device; if (device_name.empty()) { // Create a CompositeDevice on the same task as the host CPU, in order to @@ -1364,8 +1369,8 @@ bool EagerContext::OnSameTask(const Device* first, const Device* second) const { } // Gets the CPU device on the task of device. -Status EagerContext::CPUDeviceOnTask(const Device* device, - Device** cpu_device) const { +absl::Status EagerContext::CPUDeviceOnTask(const Device* device, + Device** cpu_device) const { string cpu_device_name; TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName( device->name(), &cpu_device_name)); @@ -1382,7 +1387,8 @@ void EagerContext::ClearResourceContainer(const string& name) { } } -Status EagerContext::GetGlobalRendezvousForFunctionLocalRendezvousStatus() { +absl::Status +EagerContext::GetGlobalRendezvousForFunctionLocalRendezvousStatus() { mutex_lock l(global_rendezvous_mu_); tsl::core::RefCountPtr rendezvous = local_rendezvous_cache_.Find(kGlobalRendezvousId); @@ -1400,7 +1406,7 @@ void EagerContext::UpdateGlobalRendezvousDeviceManager( } namespace { -Status GetTaskName(Device* d, string* task_name) { +absl::Status GetTaskName(Device* d, string* task_name) { string ignored; if (!DeviceNameUtils::SplitDeviceName(d->name(), task_name, &ignored)) { return errors::InvalidArgument("Unable to parse device name: ", d->name()); @@ -1411,13 +1417,14 @@ Status GetTaskName(Device* d, string* task_name) { } // namespace #if !defined(IS_MOBILE_PLATFORM) -Status EagerContext::GetClient(Device* device, - core::RefCountPtr* client) { +absl::Status EagerContext::GetClient( + Device* device, core::RefCountPtr* client) { return GetClient(device->parsed_name(), client); } -Status EagerContext::GetClient(const DeviceNameUtils::ParsedName& device_name, - core::RefCountPtr* client) { +absl::Status EagerContext::GetClient( + const DeviceNameUtils::ParsedName& device_name, + core::RefCountPtr* client) { string device_task_name; if (!DeviceNameUtils::GetTaskName(device_name, &device_task_name)) { return errors::InvalidArgument( @@ -1449,8 +1456,8 @@ Status EagerContext::GetClient(const DeviceNameUtils::ParsedName& device_name, return absl::OkStatus(); } -Status EagerContext::GetClient(const string& remote_task, - core::RefCountPtr* client) { +absl::Status EagerContext::GetClient( + const string& remote_task, core::RefCountPtr* client) { { tf_shared_lock l(remote_state_mu_); if (remote_eager_workers_ == nullptr) { @@ -1482,13 +1489,13 @@ void EagerContext::IncrementContextViewId() { context_view_id_ += 1; } -Status EagerContext::EnableCollectiveOps(const ServerDef& server_def) { +absl::Status EagerContext::EnableCollectiveOps(const ServerDef& server_def) { return distributed_manager_->EnableCollectiveOps(server_def); } // Set collective ops related state in the context. Passing nullptr to // `new_server` will reuse the existing GRPC server in context. -Status EagerContext::StoreCollectiveOpsServer( +absl::Status EagerContext::StoreCollectiveOpsServer( std::unique_ptr new_server, DeviceMgr* device_mgr, CollectiveExecutorMgrInterface* rpc_collective_executor_mgr) { collective_executor_mgr_.Reset(rpc_collective_executor_mgr); @@ -1536,7 +1543,7 @@ Status EagerContext::StoreCollectiveOpsServer( return absl::OkStatus(); } -Status EagerContext::SetRemoteDeviceFilters( +absl::Status EagerContext::SetRemoteDeviceFilters( const string& remote_worker, const std::vector& device_filters) { // Get fully specified task name for remote worker string remote_worker_task_name; @@ -1622,7 +1629,7 @@ void EagerContext::SetWorkerEnv(WorkerEnv* worker_env, worker_session_ = worker_session; } -Status EagerContext::InitializeRemoteMaster( +absl::Status EagerContext::InitializeRemoteMaster( std::unique_ptr server, WorkerEnv* worker_env, std::shared_ptr worker_session, std::unique_ptr remote_eager_workers, @@ -1653,7 +1660,7 @@ Status EagerContext::InitializeRemoteMaster( cluster_flr, std::move(remote_mgr)); } -Status EagerContext::UpdateRemoteMaster( +absl::Status EagerContext::UpdateRemoteMaster( uint64 context_id, std::unique_ptr remote_eager_workers, const std::vector& add_remote_contexts, @@ -1720,7 +1727,7 @@ Status EagerContext::UpdateRemoteMaster( } // Set distributed execution related state in the master context. -Status EagerContext::SetMasterContextState( +absl::Status EagerContext::SetMasterContextState( std::unique_ptr server, WorkerEnv* worker_env, std::shared_ptr worker_session, std::unique_ptr remote_eager_workers, @@ -1809,7 +1816,7 @@ Status EagerContext::SetMasterContextState( { for (const auto& worker : remote_contexts_) { core::RefCountPtr client; - Status s = + absl::Status s = remote_eager_workers_->GetClient(worker, &client); if (!s.ok()) { @@ -1827,7 +1834,7 @@ Status EagerContext::SetMasterContextState( request->set_context_id(context_id_); client->KeepAliveAsync( request, response, - [request, response](const Status& s) { + [request, response](const absl::Status& s) { delete request; delete response; }); @@ -1842,7 +1849,7 @@ Status EagerContext::SetMasterContextState( return absl::OkStatus(); } -Status EagerContext::InitializeRemoteWorker( +absl::Status EagerContext::InitializeRemoteWorker( std::unique_ptr remote_eager_workers, DynamicDeviceMgr* remote_device_mgr, const std::vector& remote_contexts, uint64 context_id, @@ -1899,7 +1906,7 @@ Status EagerContext::InitializeRemoteWorker( return absl::OkStatus(); } -Status EagerContext::UpdateRemoteWorker( +absl::Status EagerContext::UpdateRemoteWorker( std::unique_ptr remote_eager_workers, const std::vector& remote_contexts, uint64 context_id) { { diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 53b64bdde97e31..9dac42c1921215 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -133,7 +133,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { tensorflow::Tensor& t, const char* d_name) override; ImmediateExecutionTensorHandle* CopyTensorHandleToDevice( ImmediateExecutionTensorHandle* handle, const char* device_name, - Status* status) override; + absl::Status* status) override; ImmediateExecutionOperation* CreateOperation() override; // This is a virtual helper function to convert TFRT TensorHandle to @@ -143,7 +143,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { ImmediateExecutionTensorHandle* TFTensorHandleFromInterface( ImmediateExecutionTensorHandle* handle) override; - Status RegisterFunction(AbstractFunction* f) override; + absl::Status RegisterFunction(AbstractFunction* f) override; bool UsesTFRT() override; @@ -157,7 +157,8 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { void ListDevices(std::vector* device_attributes) override; - Status AddDevices(std::vector> devices) override; + absl::Status AddDevices( + std::vector> devices) override; thread::ThreadPool* GetThreadPool() { return thread_pool_.get(); } @@ -203,14 +204,14 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { // // The chosen device is stored in the `device` argument. The argument is not // modified unless this method returns `OkStatus()`. - Status SelectDevice(DeviceNameUtils::ParsedName preferred, - const NodeDef& ndef, Device** out) const; + absl::Status SelectDevice(DeviceNameUtils::ParsedName preferred, + const NodeDef& ndef, Device** out) const; // TODO(mdan): Rename to ContainsFunction. bool FindFunctionByName(const string& name) const; - Status FindFunctionOpData(const string& name, - const tensorflow::OpRegistrationData** op_data); + absl::Status FindFunctionOpData( + const string& name, const tensorflow::OpRegistrationData** op_data); const FunctionDef* FindFunctionDef(const string& name) const override; core::RefCountPtr FindRecord( @@ -232,53 +233,53 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { // Add the given `fdef` to the local FunctionLibraryDefinition. And add an // entry to the KernelAndDevice cache for it if it's not exist. - Status AddFunctionDef(const FunctionDef& fdef) override; + absl::Status AddFunctionDef(const FunctionDef& fdef) override; - Status AddFunctionDefWithStackTraces( + absl::Status AddFunctionDefWithStackTraces( const FunctionDef& fdef, const StackTracesMap& stack_traces) override; // `library` contains all FunctionDefs and GradientDefs to expand `fdef`. Add // it to the local FunctionLibraryDefinition as well, but no need to add it // to the KernelAndDevice cache since they won't be executed as // KernelAndDevices. - Status AddFunctionDef(const FunctionDef& fdef, - const FunctionDefLibrary& library, - bool add_to_local_only = false, - const StackTracesMap& stack_traces = {}); + absl::Status AddFunctionDef(const FunctionDef& fdef, + const FunctionDefLibrary& library, + bool add_to_local_only = false, + const StackTracesMap& stack_traces = {}); // `library` contains all FunctionDefs and GradientDefs to expand `fdef`. Add // it to the local FunctionLibraryDefinition as well, but no need to add it // to the KernelAndDevice cache since they won't be executed as // KernelAndDevices. - Status AddFunctionRecord(core::RefCountPtr func_record, - const FunctionDefLibrary& library, - bool add_to_local_only = false); + absl::Status AddFunctionRecord(core::RefCountPtr func_record, + const FunctionDefLibrary& library, + bool add_to_local_only = false); // Adds a component function (i.e. containing a subgraph of a multi-process // function) implemented as `fdef`. // // REQUIRES: `library` must contain all functions reachable from `fdef`. It // should not contain `fdef` itself. - Status AddComponentFunction(const FunctionDef& fdef, - const FunctionDefLibrary& library); + absl::Status AddComponentFunction(const FunctionDef& fdef, + const FunctionDefLibrary& library); const FunctionDef* GetFunctionDef(const string& function_name); std::vector ListFunctionNames() override; tensorflow::ImmediateExecutionContext::CacheStats GetCacheStats() override; - Status RemoveFunction(const string& func) override; - Status AddRemoveFunctionNotifier(const string& func, - std::function notifier) override; + absl::Status RemoveFunction(const string& func) override; + absl::Status AddRemoveFunctionNotifier( + const string& func, std::function notifier) override; // Wait for pending nodes to be finished in local executors (including context // default executor and thread executors) and executors on remote workers. // Return combined status of remote executors. If there are multiple errors, // the Status code will be the same as the first remote executor that has // errors, and the error message will be combined from all executors. - Status SyncExecutors(); + absl::Status SyncExecutors(); - Status AsyncWait() override { return SyncExecutors(); } + absl::Status AsyncWait() override { return SyncExecutors(); } core::RefCountPtr GetCachedKernel(Fprint128 cache_key); Device* GetCachedDevice(Fprint128 device_cache_key); @@ -316,7 +317,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { // Returns the global_rendezvous_for_functions' underlying LocalRendezvous' // status. If the underlying Rendezvous is not in the local_rendezvous_cache_ // returns OK. - Status GetGlobalRendezvousForFunctionLocalRendezvousStatus(); + absl::Status GetGlobalRendezvousForFunctionLocalRendezvousStatus(); // Returns a factory which maps from step_id to rendezvous. // @@ -414,18 +415,18 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { // destructing the RefCountPtr object at the caller's side. // `client` must not be initialized or holding a reference of another object // before calling this method. - Status GetClient(Device* device, - core::RefCountPtr* client); - Status GetClient(const DeviceNameUtils::ParsedName& device_name, - core::RefCountPtr* client); - Status GetClient(const string& remote_task, - core::RefCountPtr* client); + absl::Status GetClient(Device* device, + core::RefCountPtr* client); + absl::Status GetClient(const DeviceNameUtils::ParsedName& device_name, + core::RefCountPtr* client); + absl::Status GetClient(const string& remote_task, + core::RefCountPtr* client); uint64 GetContextId() const; uint64 GetContextViewId() const; void IncrementContextViewId(); - Status EnableCollectiveOps(const ServerDef& server_def) override; + absl::Status EnableCollectiveOps(const ServerDef& server_def) override; // TODO(nareshmodi): Encapsulate remote state into a separate // class/struct. @@ -442,7 +443,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { // (should contain no local devices). // - remote_contexts: A vector containing task names. // TODO(b/184375824): clean up parameter order for better readability. - Status InitializeRemoteMaster( + absl::Status InitializeRemoteMaster( std::unique_ptr server, WorkerEnv* worker_env, std::shared_ptr worker_session, std::unique_ptr remote_eager_workers, @@ -460,7 +461,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { // keep the current resource manager so that resources from the previous view // can still be accessed, and will automatically register existing functions // if there are newly added hosts. - Status UpdateRemoteMaster( + absl::Status UpdateRemoteMaster( uint64 context_id, std::unique_ptr remote_eager_workers, const std::vector& add_remote_contexts, @@ -468,7 +469,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { // Similar with InitializeRemoteMaster but this context will not kill remote // contexts in shutdown. - Status InitializeRemoteWorker( + absl::Status InitializeRemoteWorker( std::unique_ptr remote_eager_workers, DynamicDeviceMgr* remote_device_mgr, const std::vector& remote_contexts, uint64 context_id, @@ -482,17 +483,17 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { // Similar with InitializeRemoteWorker but will reuse existing context and // increment context_view_id. - Status UpdateRemoteWorker( + absl::Status UpdateRemoteWorker( std::unique_ptr remote_eager_workers, const std::vector& remote_contexts, uint64 context_id); - Status StoreCollectiveOpsServer( + absl::Status StoreCollectiveOpsServer( std::unique_ptr new_server, DeviceMgr* device_mgr, CollectiveExecutorMgrInterface* rpc_collective_executor_mgr); // For the specified remote worker, preprocess and set its device filters. - Status SetRemoteDeviceFilters(const string& remote_worker, - const std::vector& device_filters); + absl::Status SetRemoteDeviceFilters( + const string& remote_worker, const std::vector& device_filters); // For the specified remote worker, apply the stored device filters to the // list of device attributes following these rules: @@ -558,15 +559,16 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { tensorflow::Env* TFEnv() const { return env_; } - Status FindDeviceFromName(const char* device_name, Device** device) const; + absl::Status FindDeviceFromName(const char* device_name, + Device** device) const; - Status FindCompositeDeviceFromName(StringPiece device_name, - CompositeDevice** device) const; + absl::Status FindCompositeDeviceFromName(StringPiece device_name, + CompositeDevice** device) const; bool IsCustomDevice(const string& device_name) override; - Status RegisterCustomDevice(const string& name, - std::unique_ptr device) override; + absl::Status RegisterCustomDevice( + const string& name, std::unique_ptr device) override; CustomDeviceOpHandler& GetCustomDeviceOpHandler() override { return custom_device_op_handler_; @@ -574,13 +576,13 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { // Find or create a composite device with the given `underlying_devices` and // `device_name` (if not empty). - Status FindOrCreateCompositeDevice( + absl::Status FindOrCreateCompositeDevice( const std::vector& underlying_devices, const string& device_name, CompositeDevice** composite_device); bool OnSameTask(const Device* first, const Device* second) const; // Gets the CPU device on the task of device. - Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const; + absl::Status CPUDeviceOnTask(const Device* device, Device** cpu_device) const; const SessionOptions& session_options() const { return opts_; } void InitPrioritizedDeviceTypeList(); @@ -662,9 +664,9 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { ~EagerContext() override; - Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef); - Status MaybeRemoveFunctionRemotely(const string& function_name); - Status RegisterExistingFunctionsOnRemoteWorkers( + absl::Status MaybeRegisterFunctionRemotely(const FunctionDef& fdef); + absl::Status MaybeRemoveFunctionRemotely(const string& function_name); + absl::Status RegisterExistingFunctionsOnRemoteWorkers( const std::vector& remote_workers); void ResetPFLR(const DeviceMgr* device_mgr, Env* env, @@ -833,7 +835,7 @@ class EagerContext : public ImmediateExecutionContext, public core::RefCounted { uint64 context_id, uint64 context_view_id); // TODO(b/184375824): clean up parameter order for better readability. - Status SetMasterContextState( + absl::Status SetMasterContextState( std::unique_ptr server, WorkerEnv* worker_env, std::shared_ptr worker_session, std::unique_ptr remote_eager_workers, diff --git a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc index b36e2a2c6c6362..e13ee2ffac4a0a 100644 --- a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc +++ b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc @@ -29,6 +29,7 @@ limitations under the License. #include "google/protobuf/any.pb.h" #include "absl/time/time.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" #include "tensorflow/core/common_runtime/copy_tensor.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -54,7 +55,6 @@ limitations under the License. #include "tsl/platform/errors.h" #include "tsl/platform/macros.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/coordination_config.pb.h" #if !defined(IS_MOBILE_PLATFORM) #include "absl/base/thread_annotations.h" @@ -353,19 +353,19 @@ bool AreLocalDevicesCompatible(const EagerContext* context, context->session_options().config.SerializeAsString(); } -Status AddRemoteDevicesToMgr(const std::vector& added_remote_workers, - WorkerCacheInterface* worker_cache, - DynamicDeviceMgr* remote_device_mgr) { +absl::Status AddRemoteDevicesToMgr( + const std::vector& added_remote_workers, + WorkerCacheInterface* worker_cache, DynamicDeviceMgr* remote_device_mgr) { std::vector> remote_devices; mutex remote_devices_mu; int num_added_workers = added_remote_workers.size(); BlockingCounter counter(num_added_workers); - std::vector statuses(num_added_workers); + std::vector statuses(num_added_workers); for (int i = 0; i < num_added_workers; i++) { NewRemoteDevices( Env::Default(), worker_cache, added_remote_workers[i], [i, &statuses, &counter, &remote_devices, &remote_devices_mu]( - const Status& s, std::vector* devices) { + const absl::Status& s, std::vector* devices) { statuses[i] = s; if (s.ok()) { mutex_lock l(remote_devices_mu); @@ -385,9 +385,10 @@ Status AddRemoteDevicesToMgr(const std::vector& added_remote_workers, return absl::OkStatus(); } -Status GetAllRemoteDevices(const std::vector& remote_workers, - WorkerCacheInterface* worker_cache, - std::unique_ptr* device_mgr) { +absl::Status GetAllRemoteDevices( + const std::vector& remote_workers, + WorkerCacheInterface* worker_cache, + std::unique_ptr* device_mgr) { auto remote_device_mgr = std::make_unique(); TF_RETURN_IF_ERROR(AddRemoteDevicesToMgr(remote_workers, worker_cache, remote_device_mgr.get())); @@ -395,7 +396,7 @@ Status GetAllRemoteDevices(const std::vector& remote_workers, return absl::OkStatus(); } -Status RemoveRemoteDevicesFromMgr( +absl::Status RemoveRemoteDevicesFromMgr( const std::vector& removed_remote_workers, DynamicDeviceMgr* remote_device_mgr) { const std::vector remote_devices = @@ -413,8 +414,9 @@ Status RemoveRemoteDevicesFromMgr( return absl::OkStatus(); } -Status ListRemoteWorkers(ServerInterface* server, const string& local_worker, - std::vector* remote_workers) { +absl::Status ListRemoteWorkers(ServerInterface* server, + const string& local_worker, + std::vector* remote_workers) { server->master_env()->worker_cache->ListWorkers(remote_workers); remote_workers->erase( std::remove(remote_workers->begin(), remote_workers->end(), local_worker), @@ -455,13 +457,13 @@ void DifferentiateWorkerLists(const std::vector* current_list, existing->resize(existing_it - existing->begin()); } -Status GetReplacedFromExistingWorkers( +absl::Status GetReplacedFromExistingWorkers( const std::vector* existing_workers, uint64 context_id, uint64 context_view_id, const ServerDef& server_def, eager::EagerClientCache* client_cache, std::vector* replaced_workers) { BlockingCounter counter(existing_workers->size()); - std::vector statuses(existing_workers->size()); + std::vector statuses(existing_workers->size()); eager::KeepAliveRequest request; request.set_context_id(context_id); std::vector responses(existing_workers->size()); @@ -473,11 +475,12 @@ Status GetReplacedFromExistingWorkers( counter.DecrementCount(); continue; } - eager_client->KeepAliveAsync(&request, &responses[i], - [i, &statuses, &counter](const Status& s) { - statuses[i] = s; - counter.DecrementCount(); - }); + eager_client->KeepAliveAsync( + &request, &responses[i], + [i, &statuses, &counter](const absl::Status& s) { + statuses[i] = s; + counter.DecrementCount(); + }); } counter.Wait(); for (int i = 0; i < existing_workers->size(); i++) { @@ -493,7 +496,7 @@ Status GetReplacedFromExistingWorkers( return absl::OkStatus(); } -Status CreateRemoteContexts( +absl::Status CreateRemoteContexts( EagerContext* context, const std::vector& remote_workers, uint64 context_id, uint64 context_view_id, int keep_alive_secs, const ServerDef& server_def, eager::EagerClientCache* remote_eager_workers, @@ -501,7 +504,7 @@ Status CreateRemoteContexts( int64_t init_timeout_in_ms, int retries, bool clear_existing_contexts) { int num_remote_workers = remote_workers.size(); BlockingCounter counter(num_remote_workers); - std::vector statuses(num_remote_workers); + std::vector statuses(num_remote_workers); for (int i = 0; i < num_remote_workers; i++) { const string& remote_worker = remote_workers[i]; DeviceNameUtils::ParsedName parsed_name; @@ -554,7 +557,7 @@ Status CreateRemoteContexts( eager_client->CreateContextAsync( &request, response, - [i, &statuses, &counter, response](const Status& s) { + [i, &statuses, &counter, response](const absl::Status& s) { statuses[i] = s; delete response; counter.DecrementCount(); @@ -571,17 +574,16 @@ Status CreateRemoteContexts( return sg.as_summary_status(); } -Status UpdateRemoteContexts(EagerContext* context, - const std::vector& remote_workers, - const std::vector& added_workers, - const std::vector& removed_workers, - uint64 context_id, uint64 context_view_id, - const ServerDef& server_def, - eager::EagerClientCache* remote_eager_workers, - const eager::CreateContextRequest& base_request) { +absl::Status UpdateRemoteContexts( + EagerContext* context, const std::vector& remote_workers, + const std::vector& added_workers, + const std::vector& removed_workers, uint64 context_id, + uint64 context_view_id, const ServerDef& server_def, + eager::EagerClientCache* remote_eager_workers, + const eager::CreateContextRequest& base_request) { int num_remote_workers = remote_workers.size(); BlockingCounter counter(num_remote_workers); - std::vector statuses(num_remote_workers); + std::vector statuses(num_remote_workers); int cluster_device_count = base_request.cluster_device_attributes_size(); std::unordered_set added_or_removed(added_workers.begin(), @@ -661,7 +663,7 @@ Status UpdateRemoteContexts(EagerContext* context, eager_client->UpdateContextAsync( &request, response, - [i, &statuses, &counter, response](const Status& s) { + [i, &statuses, &counter, response](const absl::Status& s) { statuses[i] = s; delete response; counter.DecrementCount(); @@ -674,11 +676,11 @@ Status UpdateRemoteContexts(EagerContext* context, return absl::OkStatus(); } -Status UpdateContextWithServerDef(EagerContext* context, - const ServerDef& server_def, - bool reset_context, int keep_alive_secs, - int64_t init_timeout_in_ms, int retries, - bool clear_existing_contexts = false) { +absl::Status UpdateContextWithServerDef(EagerContext* context, + const ServerDef& server_def, + bool reset_context, int keep_alive_secs, + int64_t init_timeout_in_ms, int retries, + bool clear_existing_contexts = false) { string worker_name = strings::StrCat("/job:", server_def.job_name(), "/replica:0/task:", server_def.task_index()); @@ -823,7 +825,7 @@ Status UpdateContextWithServerDef(EagerContext* context, } // Initialize remote eager workers. - Status reset_context_status = absl::OkStatus(); + absl::Status reset_context_status = absl::OkStatus(); if (reset_context) { reset_context_status = CreateRemoteContexts( context, remote_workers, context_id, context_view_id, keep_alive_secs, @@ -920,7 +922,7 @@ Status UpdateContextWithServerDef(EagerContext* context, } } // namespace -Status EagerContextDistributedManager::SetOrUpdateServerDef( +absl::Status EagerContextDistributedManager::SetOrUpdateServerDef( const ServerDef& server_def, bool reset_context, int keep_alive_secs, int64_t init_timeout_in_ms, int retries, bool clear_existing_contexts) { if (server_def.has_cluster_device_filters()) { @@ -946,9 +948,9 @@ Status EagerContextDistributedManager::SetOrUpdateServerDef( "when updating the server def."; } } - Status s = UpdateContextWithServerDef(context_, server_def, reset_context, - keep_alive_secs, init_timeout_in_ms, - retries, clear_existing_contexts); + absl::Status s = UpdateContextWithServerDef( + context_, server_def, reset_context, keep_alive_secs, init_timeout_in_ms, + retries, clear_existing_contexts); if (!s.ok()) { coordination_service_agent_ = nullptr; return s; @@ -961,7 +963,7 @@ Status EagerContextDistributedManager::SetOrUpdateServerDef( return absl::OkStatus(); } -Status EagerContextDistributedManager::InitializeLocalOnlyContext( +absl::Status EagerContextDistributedManager::InitializeLocalOnlyContext( const ServerDef& server_def, int keep_alive_secs) { string worker_name = strings::StrCat("/job:", server_def.job_name(), @@ -1029,7 +1031,7 @@ Status EagerContextDistributedManager::InitializeLocalOnlyContext( return absl::OkStatus(); } -Status EagerContextDistributedManager::EnableCollectiveOps( +absl::Status EagerContextDistributedManager::EnableCollectiveOps( const ServerDef& server_def) { ServerInterface* server = context_->GetServer(); if (server == nullptr) { @@ -1052,7 +1054,7 @@ Status EagerContextDistributedManager::EnableCollectiveOps( LOG_AND_RETURN_IF_ERROR(session_mgr->CreateSession( session_name, server_def, context_->session_options().config.isolate_session_state(), - [this](Status s) { + [this](absl::Status s) { context_->GetCollectiveExecutorHandle()->get()->StartAbort(s); })); LOG_AND_RETURN_IF_ERROR( @@ -1070,7 +1072,7 @@ Status EagerContextDistributedManager::EnableCollectiveOps( absl::StatusOr time_or_status) { if (time_or_status.ok()) { const auto coord_task = coord_agent->GetOwnTask().value(); - Status s = coord_agent->InsertKeyValue( + absl::Status s = coord_agent->InsertKeyValue( "TF_DEFAULT_PREEMPTION_NOTICE_KEY", absl::StrCat("/job:", coord_task.job_name(), "/task:", coord_task.task_id())); @@ -1149,7 +1151,7 @@ Status EagerContextDistributedManager::EnableCollectiveOps( return absl::OkStatus(); } -Status EagerContextDistributedManager::CheckRemoteAlive( +absl::Status EagerContextDistributedManager::CheckRemoteAlive( const std::string& remote_task_name, bool* is_alive) { *is_alive = false; WorkerInterface* wi = @@ -1163,10 +1165,10 @@ Status EagerContextDistributedManager::CheckRemoteAlive( GetStatusRequest request; GetStatusResponse response; - Status remote_status; + absl::Status remote_status; Notification done; wi->GetStatusAsync(/*opts_=*/nullptr, &request, &response, /*fail_fast=*/true, - [&remote_status, &done](const Status& s) { + [&remote_status, &done](const absl::Status& s) { remote_status = s; done.Notify(); }); diff --git a/tensorflow/core/common_runtime/eager/context_distributed_manager.h b/tensorflow/core/common_runtime/eager/context_distributed_manager.h index ebc535b20fb253..9db43d9e64e5ae 100644 --- a/tensorflow/core/common_runtime/eager/context_distributed_manager.h +++ b/tensorflow/core/common_runtime/eager/context_distributed_manager.h @@ -44,18 +44,18 @@ class EagerContextDistributedManager // When running in a distributed context, `init_timeout_in_ms` requests the // amount of time to wait for remote workers to respond. - Status SetOrUpdateServerDef(const ServerDef& server_def, bool reset_context, - int keep_alive_secs, int64_t init_timeout_in_ms, - int retries, - bool clear_existing_contexts = false) override; + absl::Status SetOrUpdateServerDef( + const ServerDef& server_def, bool reset_context, int keep_alive_secs, + int64_t init_timeout_in_ms, int retries, + bool clear_existing_contexts = false) override; - Status InitializeLocalOnlyContext(const ServerDef& server_def, - int keep_alive_secs) override; + absl::Status InitializeLocalOnlyContext(const ServerDef& server_def, + int keep_alive_secs) override; - Status EnableCollectiveOps(const ServerDef& server_def) override; + absl::Status EnableCollectiveOps(const ServerDef& server_def) override; - Status CheckRemoteAlive(const std::string& remote_task_name, - bool* is_alive) override; + absl::Status CheckRemoteAlive(const std::string& remote_task_name, + bool* is_alive) override; tsl::CoordinationServiceAgent* GetCoordinationServiceAgent() override { return coordination_service_agent_; diff --git a/tensorflow/core/common_runtime/eager/context_test.cc b/tensorflow/core/common_runtime/eager/context_test.cc index 4c50c62c33ab5b..fd92a0597ecce9 100644 --- a/tensorflow/core/common_runtime/eager/context_test.cc +++ b/tensorflow/core/common_runtime/eager/context_test.cc @@ -54,7 +54,7 @@ static Device* CreateDevice(const string& type, int n) { class FakeDevice : public Device { public: explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} - Status Sync() override { return absl::OkStatus(); } + absl::Status Sync() override { return absl::OkStatus(); } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } }; DeviceAttributes attr; @@ -254,7 +254,7 @@ TEST_F(EagerContextTest, AddFunctionDefRepeatDifferent) { {{"scale"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}}, {{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}}, }); - Status s = context()->AddFunctionDef(x_times_two_copy); + absl::Status s = context()->AddFunctionDef(x_times_two_copy); EXPECT_FALSE(s.ok()); } @@ -281,7 +281,7 @@ TEST_F(EagerContextTest, FunctionErrorRecovery) { {{"T", DT_FLOAT}}, /*dep=*/{"assert"}}, }); - Status s = context()->AddFunctionDef(assert_and_identity); + absl::Status s = context()->AddFunctionDef(assert_and_identity); auto fail_op = ImmediateOpPtr(context()->CreateOperation()); TF_ASSERT_OK(fail_op->Reset("AssertAndIdentity", "/job:localhost/replica:0/task:0/device:CPU:0")); @@ -340,7 +340,7 @@ TEST_F(EagerContextTest, XlaCompileDeviceType) { {{"y"}, "Mul", {"x", "two"}, {{"T", DT_INT64}}}, }); - Status s = context()->AddFunctionDef(x_times_two); + absl::Status s = context()->AddFunctionDef(x_times_two); context()->SetJitCompileRewrite(true); auto op = ImmediateOpPtr(context()->CreateOperation()); TF_ASSERT_OK( diff --git a/tensorflow/core/common_runtime/eager/copy_to_device_node.h b/tensorflow/core/common_runtime/eager/copy_to_device_node.h index d496627d418499..37d943b24e6475 100644 --- a/tensorflow/core/common_runtime/eager/copy_to_device_node.h +++ b/tensorflow/core/common_runtime/eager/copy_to_device_node.h @@ -50,14 +50,14 @@ class CopyToDeviceNode : public EagerNode { } } - Status Run() override { + absl::Status Run() override { tensorflow::Tensor tensor; tsl::profiler::ScopedMemoryDebugAnnotation op_annotation( "eager::CopyToDeviceNode", "dynamic", tensor.dtype(), [&tensor]() { return tensor.shape().DebugString(); }); TF_RETURN_IF_ERROR(src_->CopyToDevice(ctx_, dstd_, &tensor)); if (!async_ && mirror_) { - Status s = dst_->AddLocalMirror(std::move(tensor), dstd_); + absl::Status s = dst_->AddLocalMirror(std::move(tensor), dstd_); // If a mirror was added since we called HasLocalMirror then just return // and ignore the error. if (s.ok() || (s.code() == error::Code::ALREADY_EXISTS)) { @@ -69,7 +69,7 @@ class CopyToDeviceNode : public EagerNode { } } - void Abort(Status status) override { dst_->Poison(status, dstd_); } + void Abort(absl::Status status) override { dst_->Poison(status, dstd_); } string DebugString() const override { string out = "[CopyToDeviceNode]"; diff --git a/tensorflow/core/common_runtime/eager/core.cc b/tensorflow/core/common_runtime/eager/core.cc index 741b9a4966ae64..c907c62eb1fa00 100644 --- a/tensorflow/core/common_runtime/eager/core.cc +++ b/tensorflow/core/common_runtime/eager/core.cc @@ -37,7 +37,7 @@ namespace tensorflow { // TODO(b/152902651): This should not depend on EagerContext. This can be // resolved by storing ctx->HostCPU() in the TensorHandle class. -AbstractTensorInterface* TensorHandle::Resolve(Status* status) { +AbstractTensorInterface* TensorHandle::Resolve(absl::Status* status) { *status = WaitUnknownDevice(); if (!status->ok()) { return nullptr; @@ -102,7 +102,7 @@ AbstractTensorInterface* TensorHandle::Resolve(Status* status) { ImmediateExecutionTensorHandle* EagerContext::CopyTensorHandleToDevice( ImmediateExecutionTensorHandle* handle, const char* device_name, - Status* status) { + absl::Status* status) { ImmediateExecutionTensorHandle* result = nullptr; Device* device; *status = this->FindDeviceFromName(device_name, &device); @@ -155,8 +155,8 @@ ImmediateExecutionOperation* EagerContext::CreateOperation() { // TODO(b/152902651): Once we move many execute.cc functions into // eager_operation.cc we can avoid a circular dependency between them. -Status EagerOperation::Execute(absl::Span retvals, - int* num_retvals) { +absl::Status EagerOperation::Execute(absl::Span retvals, + int* num_retvals) { for (ImmediateExecutionTensorHandle* handle : inputs_) { if (TensorHandle::classof(handle)) { TF_RETURN_IF_ERROR(down_cast(handle)->WaitUnknownDevice()); diff --git a/tensorflow/core/common_runtime/eager/custom_device.cc b/tensorflow/core/common_runtime/eager/custom_device.cc index 61c52b005285bc..08a2a025a86699 100644 --- a/tensorflow/core/common_runtime/eager/custom_device.cc +++ b/tensorflow/core/common_runtime/eager/custom_device.cc @@ -22,7 +22,7 @@ limitations under the License. namespace tensorflow { -Status CustomDeviceTensorHandle::Shape(PartialTensorShape* shape) const { +absl::Status CustomDeviceTensorHandle::Shape(PartialTensorShape* shape) const { int num_dims; TF_RETURN_IF_ERROR(NumDims(&num_dims)); std::vector dims(num_dims); @@ -32,7 +32,8 @@ Status CustomDeviceTensorHandle::Shape(PartialTensorShape* shape) const { return PartialTensorShape::MakePartialShape(dims.data(), num_dims, shape); } -Status CustomDeviceTensorHandle::NumElements(int64_t* num_elements) const { +absl::Status CustomDeviceTensorHandle::NumElements( + int64_t* num_elements) const { *num_elements = 1; int num_dims; TF_RETURN_IF_ERROR(NumDims(&num_dims)); @@ -50,7 +51,7 @@ Status CustomDeviceTensorHandle::NumElements(int64_t* num_elements) const { return absl::OkStatus(); } -const char* CustomDeviceTensorHandle::DeviceType(Status* status) const { +const char* CustomDeviceTensorHandle::DeviceType(absl::Status* status) const { const DeviceNameUtils::ParsedName* parsed = ParsedName(status); if (!status->ok()) { return ""; @@ -58,7 +59,7 @@ const char* CustomDeviceTensorHandle::DeviceType(Status* status) const { return parsed->type.c_str(); } -int CustomDeviceTensorHandle::DeviceId(Status* status) const { +int CustomDeviceTensorHandle::DeviceId(absl::Status* status) const { const DeviceNameUtils::ParsedName* parsed = ParsedName(status); if (!status->ok()) { return 0; @@ -66,7 +67,8 @@ int CustomDeviceTensorHandle::DeviceId(Status* status) const { return parsed->id; } -AbstractTensorInterface* CustomDeviceTensorHandle::Resolve(Status* status) { +AbstractTensorInterface* CustomDeviceTensorHandle::Resolve( + absl::Status* status) { core::RefCountPtr copied_off( context_->GetCustomDeviceOpHandler().CopyTensorHandleToDevice( context_, this, @@ -80,7 +82,7 @@ AbstractTensorInterface* CustomDeviceTensorHandle::Resolve(Status* status) { } const DeviceNameUtils::ParsedName* CustomDeviceTensorHandle::ParsedName( - Status* status) const { + absl::Status* status) const { if (!parsed_name_.has_value()) { DeviceNameUtils::ParsedName parsed_name; if (!DeviceNameUtils::ParseFullOrLocalName(device_->name(), &parsed_name)) { diff --git a/tensorflow/core/common_runtime/eager/custom_device.h b/tensorflow/core/common_runtime/eager/custom_device.h index 6ab6cbe8283793..2f4f5acc95549f 100644 --- a/tensorflow/core/common_runtime/eager/custom_device.h +++ b/tensorflow/core/common_runtime/eager/custom_device.h @@ -38,22 +38,22 @@ class CustomDevice { public: virtual ~CustomDevice() = default; virtual const string& name() = 0; - virtual Status CopyTensorToDevice( + virtual absl::Status CopyTensorToDevice( ImmediateExecutionTensorHandle* tensor, ImmediateExecutionTensorHandle** result) = 0; - virtual Status CopyTensorFromDevice( + virtual absl::Status CopyTensorFromDevice( ImmediateExecutionTensorHandle* tensor, const string& target_device_name, ImmediateExecutionTensorHandle** result) = 0; - virtual Status Execute(const ImmediateExecutionOperation* op, - ImmediateExecutionTensorHandle** retvals, - int* num_retvals) = 0; + virtual absl::Status Execute(const ImmediateExecutionOperation* op, + ImmediateExecutionTensorHandle** retvals, + int* num_retvals) = 0; // Creates a packed TensorHandle from a group of custom device TensorHandles, // one of which is on this custom device. - virtual Status Pack(absl::Span handles, - ImmediateExecutionTensorHandle** result) = 0; + virtual absl::Status Pack(absl::Span handles, + ImmediateExecutionTensorHandle** result) = 0; // Returns true signifying to pin to the current custom device. // Returns false to pin to the physical device. @@ -98,20 +98,20 @@ class CustomDeviceTensorHandle : public ImmediateExecutionTensorHandle { tensorflow::DataType DataType() const override { return dtype_; } tensorflow::FullTypeDef FullType() const override { return full_type_; } - Status Shape(PartialTensorShape* shape) const override; - Status NumElements(int64_t* num_elements) const override; + absl::Status Shape(PartialTensorShape* shape) const override; + absl::Status NumElements(int64_t* num_elements) const override; - const char* DeviceName(Status* status) const override { + const char* DeviceName(absl::Status* status) const override { return device_->name().c_str(); } - const char* BackingDeviceName(Status* status) const override { + const char* BackingDeviceName(absl::Status* status) const override { return device_->name().c_str(); } CustomDevice* device() const { return device_; } - const char* DeviceType(Status* status) const override; - int DeviceId(Status* status) const override; + const char* DeviceType(absl::Status* status) const override; + int DeviceId(absl::Status* status) const override; - AbstractTensorInterface* Resolve(Status* status) override; + AbstractTensorInterface* Resolve(absl::Status* status) override; // For LLVM style RTTI. static bool classof(const AbstractTensorHandle* ptr) { @@ -119,7 +119,7 @@ class CustomDeviceTensorHandle : public ImmediateExecutionTensorHandle { } protected: - const DeviceNameUtils::ParsedName* ParsedName(Status* status) const; + const DeviceNameUtils::ParsedName* ParsedName(absl::Status* status) const; ImmediateExecutionContext* const context_; CustomDevice* const device_; diff --git a/tensorflow/core/common_runtime/eager/custom_device_op_handler.cc b/tensorflow/core/common_runtime/eager/custom_device_op_handler.cc index 34b89b4da61eb0..426930f04b8cda 100644 --- a/tensorflow/core/common_runtime/eager/custom_device_op_handler.cc +++ b/tensorflow/core/common_runtime/eager/custom_device_op_handler.cc @@ -25,7 +25,7 @@ namespace tensorflow { void CustomDeviceOpHandler::Clear() { custom_devices_.clear(); } -Status CustomDeviceOpHandler::RegisterCustomDevice( +absl::Status CustomDeviceOpHandler::RegisterCustomDevice( const string& device_name, std::unique_ptr device) { DeviceNameUtils::ParsedName parsed; if (!DeviceNameUtils::ParseFullName(device_name, &parsed) || @@ -55,9 +55,9 @@ bool CustomDeviceOpHandler::FindCustomDeviceFromName( return true; } -Status CustomDeviceOpHandler::Execute(ImmediateExecutionOperation* op, - ImmediateExecutionTensorHandle** retvals, - int* num_retvals) { +absl::Status CustomDeviceOpHandler::Execute( + ImmediateExecutionOperation* op, ImmediateExecutionTensorHandle** retvals, + int* num_retvals) { tensorflow::CustomDevice* custom_device = nullptr; TF_RETURN_IF_ERROR(MaybePinToCustomDevice(&custom_device, *op)); @@ -85,7 +85,7 @@ Status CustomDeviceOpHandler::Execute(ImmediateExecutionOperation* op, tensorflow::ImmediateExecutionTensorHandle* new_tensor; TF_RETURN_IF_ERROR(previous->device()->CopyTensorFromDevice( previous, target_device, &new_tensor)); - Status s = op->SetInput(i, new_tensor); + absl::Status s = op->SetInput(i, new_tensor); new_tensor->Unref(); TF_RETURN_IF_ERROR(s); } @@ -101,7 +101,7 @@ Status CustomDeviceOpHandler::Execute(ImmediateExecutionOperation* op, ImmediateExecutionTensorHandle* CustomDeviceOpHandler::CopyTensorHandleToDevice( ImmediateExecutionContext* context, ImmediateExecutionTensorHandle* handle, - const char* device_name, Status* status) { + const char* device_name, absl::Status* status) { *status = absl::OkStatus(); ImmediateExecutionTensorHandle* result = nullptr; tensorflow::CustomDevice* dev; @@ -132,7 +132,7 @@ ImmediateExecutionTensorHandle* CustomDeviceOpHandler::CopyTensorHandleToDevice( return context->CopyTensorHandleToDevice(handle, device_name, status); } -Status CustomDeviceOpHandler::MaybePinToCustomDevice( +absl::Status CustomDeviceOpHandler::MaybePinToCustomDevice( CustomDevice** device, const ImmediateExecutionOperation& op) const { *device = nullptr; if (!FindCustomDeviceFromName(op.DeviceName(), device) && diff --git a/tensorflow/core/common_runtime/eager/custom_device_op_handler.h b/tensorflow/core/common_runtime/eager/custom_device_op_handler.h index 2f60726566f441..6c38e50d458dcd 100644 --- a/tensorflow/core/common_runtime/eager/custom_device_op_handler.h +++ b/tensorflow/core/common_runtime/eager/custom_device_op_handler.h @@ -29,25 +29,26 @@ class CustomDeviceOpHandler { public: ~CustomDeviceOpHandler() = default; // Register a new custom device. - Status RegisterCustomDevice(const string& device_name, - std::unique_ptr device); + absl::Status RegisterCustomDevice(const string& device_name, + std::unique_ptr device); // Find the custom device from given name. Return true if it finds one. bool FindCustomDeviceFromName(const string& name, CustomDevice** device) const; - Status Execute(ImmediateExecutionOperation* op, - ImmediateExecutionTensorHandle** retvals, int* num_retvals); + absl::Status Execute(ImmediateExecutionOperation* op, + ImmediateExecutionTensorHandle** retvals, + int* num_retvals); ImmediateExecutionTensorHandle* CopyTensorHandleToDevice( ImmediateExecutionContext* context, ImmediateExecutionTensorHandle* handle, const char* device_name, - Status* status); + absl::Status* status); // Determine whether to place an op on a custom device. This method is // exposed as public for test only. - Status MaybePinToCustomDevice(CustomDevice** device, - const ImmediateExecutionOperation& op) const; + absl::Status MaybePinToCustomDevice( + CustomDevice** device, const ImmediateExecutionOperation& op) const; void Clear(); diff --git a/tensorflow/core/common_runtime/eager/custom_device_test.cc b/tensorflow/core/common_runtime/eager/custom_device_test.cc index 603ed6a1eafac8..71d60929115966 100644 --- a/tensorflow/core/common_runtime/eager/custom_device_test.cc +++ b/tensorflow/core/common_runtime/eager/custom_device_test.cc @@ -39,13 +39,14 @@ class TestCustomDevice : public CustomDevice { public: explicit TestCustomDevice(std::string name) : name_(name) {} const std::string& name() override { return name_; } - Status CopyTensorToDevice(ImmediateExecutionTensorHandle* tensor, - ImmediateExecutionTensorHandle** result) override { + absl::Status CopyTensorToDevice( + ImmediateExecutionTensorHandle* tensor, + ImmediateExecutionTensorHandle** result) override { tensor->Ref(); *result = tensor; return absl::OkStatus(); } - Status CopyTensorFromDevice( + absl::Status CopyTensorFromDevice( ImmediateExecutionTensorHandle* tensor, const std::string& target_device_name, ImmediateExecutionTensorHandle** result) override { @@ -53,14 +54,14 @@ class TestCustomDevice : public CustomDevice { *result = tensor; return absl::OkStatus(); } - Status Execute(const ImmediateExecutionOperation* op, - ImmediateExecutionTensorHandle** retvals, - int* num_retvals) override { + absl::Status Execute(const ImmediateExecutionOperation* op, + ImmediateExecutionTensorHandle** retvals, + int* num_retvals) override { return errors::Unimplemented("Not implemented"); } - Status Pack(absl::Span handles, - ImmediateExecutionTensorHandle** result) override { + absl::Status Pack(absl::Span handles, + ImmediateExecutionTensorHandle** result) override { return errors::Unimplemented("Packing is not implemented"); } @@ -82,11 +83,11 @@ class TestCustomDeviceTensorHandle : public CustomDeviceTensorHandle { : CustomDeviceTensorHandle(context, device, dtype), length_(length) {} void* DevicePointer() const override { return nullptr; } - Status NumDims(int* num_dims) const override { + absl::Status NumDims(int* num_dims) const override { *num_dims = 1; return absl::OkStatus(); } - Status Dim(int dim_index, int64_t* dim) const override { + absl::Status Dim(int dim_index, int64_t* dim) const override { if (dim_index == 0) { *dim = length_; return absl::OkStatus(); @@ -95,7 +96,7 @@ class TestCustomDeviceTensorHandle : public CustomDeviceTensorHandle { } } - Status SummarizeValue(std::string& summary) const override { + absl::Status SummarizeValue(std::string& summary) const override { summary = std::string("TestValue"); return absl::OkStatus(); } @@ -117,7 +118,7 @@ TEST(CustomDevice, TestTensorHandle) { core::RefCountPtr tensor( new TestCustomDeviceTensorHandle(ctx.get(), &device, DT_FLOAT, /*length=*/3)); - Status s; + absl::Status s; std::string device_type = tensor->DeviceType(&s); ASSERT_TRUE(s.ok()) << s.message(); EXPECT_EQ("CUSTOM", device_type); @@ -150,7 +151,7 @@ TEST(CustomDevice, TestTensorHandleUnknownDimNumElements) { new TestCustomDeviceTensorHandle(ctx.get(), &device, DT_FLOAT, /*length=*/-1)); int64_t num_elements; - Status s = tensor->NumElements(&num_elements); + absl::Status s = tensor->NumElements(&num_elements); EXPECT_FALSE(s.ok()); EXPECT_THAT(s.message(), HasSubstr("representing varying shapes")); } diff --git a/tensorflow/core/common_runtime/eager/eager_executor.cc b/tensorflow/core/common_runtime/eager/eager_executor.cc index fbe7b40eabc18d..fc552f3127576d 100644 --- a/tensorflow/core/common_runtime/eager/eager_executor.cc +++ b/tensorflow/core/common_runtime/eager/eager_executor.cc @@ -64,10 +64,10 @@ EagerExecutor::~EagerExecutor() { } } -Status EagerExecutor::ShutDown() { +absl::Status EagerExecutor::ShutDown() { { bool has_thread; - Status status; + absl::Status status; { tensorflow::mutex_lock l(node_queue_mutex_); if (state_ != ExecutorState::kShutDown) { @@ -108,7 +108,7 @@ const char* EagerExecutor::StateStringLocked() { } } -Status EagerExecutor::SyncExecute(EagerNode* node) { +absl::Status EagerExecutor::SyncExecute(EagerNode* node) { if (Async()) { return errors::Internal("SyncExecute does not support async execution."); } @@ -119,7 +119,7 @@ Status EagerExecutor::SyncExecute(EagerNode* node) { uint64 id = next_node_id_++; - Status s = node->Prepare(); + absl::Status s = node->Prepare(); if (!s.ok()) { return s; } @@ -131,8 +131,8 @@ Status EagerExecutor::SyncExecute(EagerNode* node) { return s; } -Status EagerExecutor::AddOrExecute(std::unique_ptr node) { - Status status; +absl::Status EagerExecutor::AddOrExecute(std::unique_ptr node) { + absl::Status status; core::RefCountPtr item(new NodeItem); item->id = next_node_id_++; item->node = std::move(node); @@ -195,13 +195,12 @@ Status EagerExecutor::AddOrExecute(std::unique_ptr node) { return status; } -tensorflow::Status EagerExecutor::WaitForAllPendingNodes() { +absl::Status EagerExecutor::WaitForAllPendingNodes() { tensorflow::mutex_lock l(node_queue_mutex_); return WaitForAllPendingNodesLocked(&l); } -tensorflow::Status EagerExecutor::WaitForAllPendingNodesLocked( - mutex_lock* lock) { +absl::Status EagerExecutor::WaitForAllPendingNodesLocked(mutex_lock* lock) { tensorflow::condition_variable cond; // Don't wait if an error is already set. if (!status_.ok()) return status_; @@ -233,7 +232,7 @@ void EagerExecutor::ClearError() { } void EagerExecutor::NodeDone(const core::RefCountPtr& item, - const Status& status, bool from_queue) { + const absl::Status& status, bool from_queue) { DVLOG(3) << "Node Done: [id " << item->id << "] " << item->node->DebugString() << " with status: " << status; DCHECK(item->state != NodeState::kDONE); @@ -365,15 +364,15 @@ void EagerExecutor::Run() { curr_item.reset(node_queue_.front().get()); curr_item->Ref(); } - Status status = RunItem(std::move(curr_item), /*from_queue=*/true); + absl::Status status = RunItem(std::move(curr_item), /*from_queue=*/true); if (!status.ok()) { VLOG(1) << "Failed to run item: " << status; } } } -Status EagerExecutor::RunItem(core::RefCountPtr item, - bool from_queue) { +absl::Status EagerExecutor::RunItem(core::RefCountPtr item, + bool from_queue) { DVLOG(3) << "Running Node: [id " << item->id << "] " << item->node->DebugString(); AsyncRemoteExecuteNode* async_remote_node = @@ -386,7 +385,7 @@ Status EagerExecutor::RunItem(core::RefCountPtr item, // Running a remote function, need to sync if the function is going to // different device than last time we run remote distributed function. DVLOG(3) << "Executing Sync Executor for node" << item->id; - tensorflow::Status status = async_remote_node->SyncExecutors(); + absl::Status status = async_remote_node->SyncExecutors(); if (!status.ok()) { NodeDone(item, status, from_queue); return status; @@ -405,7 +404,7 @@ Status EagerExecutor::RunItem(core::RefCountPtr item, AsyncEagerNode* async_node = item->node->AsAsync(); if (async_node == nullptr) { - tensorflow::Status status = item->node->Run(); + absl::Status status = item->node->Run(); NodeDone(item, status, from_queue); return status; } @@ -416,7 +415,7 @@ Status EagerExecutor::RunItem(core::RefCountPtr item, TF_RETURN_IF_ERROR(MoveToUnfinished(std::move(item), from_queue)); - async_node->RunAsync([this, async_ref](const Status& status) { + async_node->RunAsync([this, async_ref](const absl::Status& status) { core::RefCountPtr async_item(async_ref); NodeDone(async_item, status, false); }); @@ -425,8 +424,8 @@ Status EagerExecutor::RunItem(core::RefCountPtr item, return status(); } -Status EagerExecutor::MoveToUnfinished(core::RefCountPtr item, - bool from_queue) { +absl::Status EagerExecutor::MoveToUnfinished(core::RefCountPtr item, + bool from_queue) { tensorflow::mutex_lock l(node_queue_mutex_); if (!status_.ok()) { return status_; diff --git a/tensorflow/core/common_runtime/eager/eager_executor.h b/tensorflow/core/common_runtime/eager/eager_executor.h index 05af8756133753..cec897b33bc4c3 100644 --- a/tensorflow/core/common_runtime/eager/eager_executor.h +++ b/tensorflow/core/common_runtime/eager/eager_executor.h @@ -59,17 +59,17 @@ class EagerNode { // Prepares the node when adding it into EagerExecutor. If any errors happens, // EagerExecutor will abort the node immediately. - virtual Status Prepare() { return absl::OkStatus(); } + virtual absl::Status Prepare() { return absl::OkStatus(); } // Runs the computation corresponding to this node and blocks till the // execution is done. - virtual Status Run() = 0; + virtual absl::Status Run() = 0; // Called when this node will not be run due to some error contained in // `status`. `status` must not be OK. // For example, if the node would have computed some tensors in the Run(), // it should poison the corresponding tensor handles in this method. - virtual void Abort(Status status) = 0; + virtual void Abort(absl::Status status) = 0; // Returns nullptr iff this Eager node is synchronous. virtual AsyncEagerNode* AsAsync() { return nullptr; } @@ -90,7 +90,7 @@ class AsyncEagerNode : public EagerNode { AsyncEagerNode* AsAsync() final { return this; } - Status Run() final { + absl::Status Run() final { return errors::Unimplemented("Don't call AsyncEagerNode::Run()."); } }; @@ -102,7 +102,7 @@ class AsyncRemoteExecuteNode : public AsyncEagerNode { virtual const eager::EagerClient* eager_client() const = 0; virtual bool needs_remote_inputs() const = 0; virtual bool allow_multiple_pending_requests() const = 0; - virtual Status SyncExecutors() = 0; + virtual absl::Status SyncExecutors() = 0; }; // A class for handling async execution (see TFE_ContextSetAsync). @@ -125,33 +125,33 @@ class EagerExecutor { // blocks until all pendings nodes have finished running. // Returns the status of executing pending nodes. // If async was not enabled, aborts and destroys all pending nodes. - Status ShutDown(); + absl::Status ShutDown(); bool Async() const; bool StreamingEnqueue() const; // Inline execute node if executor is in sync mode. - Status SyncExecute(EagerNode* node); + absl::Status SyncExecute(EagerNode* node); // - Async Mode: schedules `node` for execution. // - Sync Mode: inline execute the 'node' directly. // If an error occurs (e.g. EagerExecutor has already been shut down), the // `node` is not added to this executor and its Abort() method is called. - Status AddOrExecute(std::unique_ptr node); + absl::Status AddOrExecute(std::unique_ptr node); // Blocks till all currently pending ops are done. // In particular, if EnableAsync() has not beed called, it will not return // until that happens (and pendings, at the time of call, nodes finish // running). If this executor has already been shut down, its final status is // returned. - Status WaitForAllPendingNodes(); + absl::Status WaitForAllPendingNodes(); // Clears all currently set errors which re-enables async execution. void ClearError(); // Returns Status based on any errors that occurred during async execution. - Status status() const { + absl::Status status() const { if (ok()) return absl::OkStatus(); tf_shared_lock l(node_queue_mutex_); @@ -200,8 +200,8 @@ class EagerExecutor { const char* StateStringLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_); - void NodeDone(const core::RefCountPtr& item, const Status& status, - bool from_queue); + void NodeDone(const core::RefCountPtr& item, + const absl::Status& status, bool from_queue); void NotifyWaiters(uint64 id) TF_EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_); // Starts execution of pending EagerNodes. This function loops till executor @@ -210,15 +210,16 @@ class EagerExecutor { // `status_` is not ok. void Run(); - Status RunItem(core::RefCountPtr item, bool from_queue); - Status MoveToUnfinished(core::RefCountPtr item, bool from_queue); + absl::Status RunItem(core::RefCountPtr item, bool from_queue); + absl::Status MoveToUnfinished(core::RefCountPtr item, + bool from_queue); // The impl of WaitForAllPendingNodes // `lock` is the lock that holds node_queue_mutex_. - Status WaitForAllPendingNodesLocked(mutex_lock* lock) + absl::Status WaitForAllPendingNodesLocked(mutex_lock* lock) TF_EXCLUSIVE_LOCKS_REQUIRED(node_queue_mutex_); - Status WaitImpl(bool wait_all, uint64 node_id); + absl::Status WaitImpl(bool wait_all, uint64 node_id); std::atomic next_node_id_; @@ -239,7 +240,7 @@ class EagerExecutor { // `status_` is set based on any errors raised during execution of a // EagerNode. It remains set until ClearError is called. - Status status_ TF_GUARDED_BY(node_queue_mutex_); + absl::Status status_ TF_GUARDED_BY(node_queue_mutex_); std::atomic ok_ TF_GUARDED_BY(node_queue_mutex_); // Map from id of a EagerNode to condition_variables (not owned by the map). diff --git a/tensorflow/core/common_runtime/eager/eager_executor_test.cc b/tensorflow/core/common_runtime/eager/eager_executor_test.cc index 1650dbf975866e..80c205a5053cd5 100644 --- a/tensorflow/core/common_runtime/eager/eager_executor_test.cc +++ b/tensorflow/core/common_runtime/eager/eager_executor_test.cc @@ -18,11 +18,11 @@ limitations under the License. #include #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/status_matchers.h" #include "tensorflow/core/platform/test.h" #include "tsl/platform/status.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace { @@ -44,16 +44,16 @@ class TestState { class TestEagerNode : public EagerNode { public: explicit TestEagerNode(TestState* state, - Status prepare_return_status = absl::OkStatus(), - Status run_return_status = absl::OkStatus()) + absl::Status prepare_return_status = absl::OkStatus(), + absl::Status run_return_status = absl::OkStatus()) : state_(state), prepare_return_status_(prepare_return_status), run_return_status_(run_return_status) {} TestEagerNode(const TestEagerNode&) = delete; TestEagerNode& operator=(const TestEagerNode&) = delete; - Status Prepare() override { return prepare_return_status_; } + absl::Status Prepare() override { return prepare_return_status_; } - Status Run() override { + absl::Status Run() override { if (run_return_status_.ok()) { state_->update_success_state(); } else { @@ -62,27 +62,27 @@ class TestEagerNode : public EagerNode { return run_return_status_; }; - void Abort(Status status) override {} + void Abort(absl::Status status) override {} string DebugString() const override { return "testEagerNode"; } private: TestState* state_; - Status prepare_return_status_; - Status run_return_status_; + absl::Status prepare_return_status_; + absl::Status run_return_status_; }; class TestAsyncEagerNode : public AsyncEagerNode { public: - explicit TestAsyncEagerNode(TestState* state, - Status prepare_return_status = absl::OkStatus(), - Status run_return_status = absl::OkStatus()) + explicit TestAsyncEagerNode( + TestState* state, absl::Status prepare_return_status = absl::OkStatus(), + absl::Status run_return_status = absl::OkStatus()) : state_(state), prepare_return_status_(prepare_return_status), run_return_status_(run_return_status) {} TestAsyncEagerNode(const TestAsyncEagerNode&) = delete; TestAsyncEagerNode& operator=(const TestAsyncEagerNode&) = delete; - Status Prepare() override { return prepare_return_status_; } + absl::Status Prepare() override { return prepare_return_status_; } void RunAsync(StatusCallback done) override { if (run_return_status_.ok()) { @@ -93,13 +93,13 @@ class TestAsyncEagerNode : public AsyncEagerNode { done(run_return_status_); }; - void Abort(Status status) override {} + void Abort(absl::Status status) override {} string DebugString() const override { return "testAsyncEagerNode"; } private: TestState* state_; - Status prepare_return_status_; - Status run_return_status_; + absl::Status prepare_return_status_; + absl::Status run_return_status_; }; TEST(EagerExecutorTest, TestSyncExecutorWithEagerNode) { diff --git a/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.cc b/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.cc index 2d4a94a204d9c3..d05832344327b5 100644 --- a/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.cc +++ b/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.cc @@ -44,7 +44,7 @@ void EagerOpRewriteRegistry::Register(Phase phase, int32_t ordinal, std::make_pair(std::move(pass), ordinal)); } -Status EagerOpRewriteRegistry::RunRewrite( +absl::Status EagerOpRewriteRegistry::RunRewrite( Phase phase, EagerOperation* orig_op, std::unique_ptr* out_op) { EagerOperation* pre_op = orig_op; diff --git a/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h b/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h index a70877b57114a2..bd7098473d7532 100644 --- a/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h +++ b/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h @@ -37,8 +37,9 @@ class EagerOpRewrite { virtual ~EagerOpRewrite() = default; // To be implemented by an Eager op rewrite pass. - virtual Status Run(EagerOperation* orig_op, - std::unique_ptr* out_op) = 0; + virtual absl::Status Run( + EagerOperation* orig_op, + std::unique_ptr* out_op) = 0; // Holds information about the rewrite registration. struct DebugInfo { @@ -65,8 +66,8 @@ class EagerOpRewriteRegistry { std::unique_ptr pass); // Run the rewrite pass registered for a given phase. - Status RunRewrite(Phase phase, EagerOperation* orig_op, - std::unique_ptr* out_op); + absl::Status RunRewrite(Phase phase, EagerOperation* orig_op, + std::unique_ptr* out_op); // Returns the global registry of rewrite passes. static EagerOpRewriteRegistry* Global(); diff --git a/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry_test.cc b/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry_test.cc index 560835eb07fc00..d50f3e0a4ec411 100644 --- a/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry_test.cc +++ b/tensorflow/core/common_runtime/eager/eager_op_rewrite_registry_test.cc @@ -28,8 +28,9 @@ class TestEagerOpRewrite : public EagerOpRewrite { executor_(/*async=*/false, /*enable_streaming_enqueue=*/true) {} static int count_; EagerExecutor executor_; - Status Run(EagerOperation* orig_op, - std::unique_ptr* out_op) override { + absl::Status Run( + EagerOperation* orig_op, + std::unique_ptr* out_op) override { ++count_; // Create a new NoOp Eager operation. tensorflow::EagerOperation* op = diff --git a/tensorflow/core/common_runtime/eager/eager_operation.cc b/tensorflow/core/common_runtime/eager/eager_operation.cc index 488c8302c85eae..ce4b8df85e473e 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.cc +++ b/tensorflow/core/common_runtime/eager/eager_operation.cc @@ -48,40 +48,42 @@ void EagerOperation::Clear() { ClearInferenceState(); } -Status EagerOperation::SetAttrValue(const char* attr_name, - const AttrValue& value) { +absl::Status EagerOperation::SetAttrValue(const char* attr_name, + const AttrValue& value) { MutableAttrs()->Set(attr_name, value); return absl::OkStatus(); } -Status EagerOperation::SetAttrString(const char* attr_name, const char* data, - size_t length) { +absl::Status EagerOperation::SetAttrString(const char* attr_name, + const char* data, size_t length) { MutableAttrs()->Set(attr_name, StringPiece(data, length)); return absl::OkStatus(); } -Status EagerOperation::SetAttrInt(const char* attr_name, int64_t value) { +absl::Status EagerOperation::SetAttrInt(const char* attr_name, int64_t value) { MutableAttrs()->Set(attr_name, static_cast(value)); return absl::OkStatus(); } -Status EagerOperation::SetAttrFloat(const char* attr_name, float value) { +absl::Status EagerOperation::SetAttrFloat(const char* attr_name, float value) { MutableAttrs()->Set(attr_name, value); return absl::OkStatus(); } -Status EagerOperation::SetAttrBool(const char* attr_name, bool value) { +absl::Status EagerOperation::SetAttrBool(const char* attr_name, bool value) { MutableAttrs()->Set(attr_name, value); return absl::OkStatus(); } -Status EagerOperation::SetAttrType(const char* attr_name, DataType value) { +absl::Status EagerOperation::SetAttrType(const char* attr_name, + DataType value) { MutableAttrs()->Set(attr_name, value); return absl::OkStatus(); } -Status EagerOperation::SetAttrShape(const char* attr_name, const int64_t* dims, - const int num_dims) { +absl::Status EagerOperation::SetAttrShape(const char* attr_name, + const int64_t* dims, + const int num_dims) { if (num_dims > TensorShape::MaxDimensions()) { return errors::InvalidArgument("Value specified for `", attr_name, "` has ", num_dims, @@ -103,8 +105,8 @@ Status EagerOperation::SetAttrShape(const char* attr_name, const int64_t* dims, return absl::OkStatus(); } -Status EagerOperation::SetAttrFunction(const char* attr_name, - const AbstractOperation* value) { +absl::Status EagerOperation::SetAttrFunction(const char* attr_name, + const AbstractOperation* value) { AttrValue attr_value; NameAttrList* func = attr_value.mutable_func(); func->set_name(value->Name()); @@ -114,8 +116,9 @@ Status EagerOperation::SetAttrFunction(const char* attr_name, return absl::OkStatus(); } -Status EagerOperation::SetAttrFunctionName(const char* attr_name, - const char* data, size_t length) { +absl::Status EagerOperation::SetAttrFunctionName(const char* attr_name, + const char* data, + size_t length) { AttrValue attr_value; NameAttrList* func = attr_value.mutable_func(); func->set_name(data, length); @@ -123,17 +126,17 @@ Status EagerOperation::SetAttrFunctionName(const char* attr_name, return absl::OkStatus(); } -Status EagerOperation::SetAttrTensor(const char* attr_name, - AbstractTensorInterface* tensor) { +absl::Status EagerOperation::SetAttrTensor(const char* attr_name, + AbstractTensorInterface* tensor) { Tensor t = TensorFromInterface(tensor); MutableAttrs()->Set(attr_name, t); return absl::OkStatus(); } -Status EagerOperation::SetAttrStringList(const char* attr_name, - const void* const* values, - const size_t* lengths, - int num_values) { +absl::Status EagerOperation::SetAttrStringList(const char* attr_name, + const void* const* values, + const size_t* lengths, + int num_values) { std::vector v(num_values); for (int i = 0; i < num_values; ++i) { v[i] = StringPiece(static_cast(values[i]), lengths[i]); @@ -143,31 +146,34 @@ Status EagerOperation::SetAttrStringList(const char* attr_name, return absl::OkStatus(); } -Status EagerOperation::SetAttrFloatList(const char* attr_name, - const float* values, int num_values) { +absl::Status EagerOperation::SetAttrFloatList(const char* attr_name, + const float* values, + int num_values) { MutableAttrs()->Set(attr_name, gtl::ArraySlice(values, num_values)); return absl::OkStatus(); } -Status EagerOperation::SetAttrIntList(const char* attr_name, - const int64_t* values, int num_values) { +absl::Status EagerOperation::SetAttrIntList(const char* attr_name, + const int64_t* values, + int num_values) { MutableAttrs()->Set( attr_name, gtl::ArraySlice( reinterpret_cast(values), num_values)); return absl::OkStatus(); } -Status EagerOperation::SetAttrTypeList(const char* attr_name, - const DataType* values, int num_values) { +absl::Status EagerOperation::SetAttrTypeList(const char* attr_name, + const DataType* values, + int num_values) { MutableAttrs()->Set(attr_name, gtl::ArraySlice(values, num_values)); return absl::OkStatus(); } -Status EagerOperation::SetAttrBoolList(const char* attr_name, - const unsigned char* values, - int num_values) { +absl::Status EagerOperation::SetAttrBoolList(const char* attr_name, + const unsigned char* values, + int num_values) { std::unique_ptr b(new bool[num_values]); for (int i = 0; i < num_values; ++i) { b[i] = values[i]; @@ -177,9 +183,10 @@ Status EagerOperation::SetAttrBoolList(const char* attr_name, return absl::OkStatus(); } -Status EagerOperation::SetAttrShapeList(const char* attr_name, - const int64_t** dims, - const int* num_dims, int num_values) { +absl::Status EagerOperation::SetAttrShapeList(const char* attr_name, + const int64_t** dims, + const int* num_dims, + int num_values) { std::unique_ptr proto(new TensorShapeProto[num_values]); for (int i = 0; i < num_values; ++i) { const auto num_dims_i = num_dims[i]; @@ -205,7 +212,7 @@ Status EagerOperation::SetAttrShapeList(const char* attr_name, return absl::OkStatus(); } -Status EagerOperation::SetAttrFunctionList( +absl::Status EagerOperation::SetAttrFunctionList( const char* attr_name, absl::Span values) { size_t num_values = values.size(); std::unique_ptr funcs(new NameAttrList[num_values]); @@ -219,15 +226,15 @@ Status EagerOperation::SetAttrFunctionList( return absl::OkStatus(); } -const OpDef* EagerOperation::GetOpDef(Status* status) { +const OpDef* EagerOperation::GetOpDef(absl::Status* status) { const tensorflow::OpDef* op_def = OpDef(); if (op_def) return op_def; *status = OpDefForOp(Name(), &op_def); return op_def; } -Status EagerOperation::InputLength(const char* input_name, int* length) { - Status status; +absl::Status EagerOperation::InputLength(const char* input_name, int* length) { + absl::Status status; const tensorflow::OpDef* op_def = GetOpDef(&status); if (!status.ok()) { return status; @@ -253,8 +260,9 @@ absl::Span EagerOperation::GetInputs() inputs_.size()); } -Status EagerOperation::OutputLength(const char* output_name, int* length) { - Status status; +absl::Status EagerOperation::OutputLength(const char* output_name, + int* length) { + absl::Status status; const tensorflow::OpDef* op_def = GetOpDef(&status); if (!status.ok()) { return status; @@ -272,7 +280,7 @@ Status EagerOperation::OutputLength(const char* output_name, int* length) { return absl::OkStatus(); } -Status EagerOperation::AddInput(AbstractTensorHandle* input) { +absl::Status EagerOperation::AddInput(AbstractTensorHandle* input) { ImmediateExecutionTensorHandle* h = down_cast(input); // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here. @@ -283,7 +291,7 @@ Status EagerOperation::AddInput(AbstractTensorHandle* input) { return MaybeInferSingleInputAttrs(h); } -Status EagerOperation::AddInputList( +absl::Status EagerOperation::AddInputList( absl::Span inputs) { for (auto& input : inputs) { // TODO(b/175427838): It would be nice to be able to use tensorflow::isa @@ -298,8 +306,8 @@ Status EagerOperation::AddInputList( return InferInputListAttrs(inputs.size()); } -Status EagerOperation::SetInput(size_t index, - ImmediateExecutionTensorHandle* input) { +absl::Status EagerOperation::SetInput(size_t index, + ImmediateExecutionTensorHandle* input) { if (index >= inputs_.size()) { return errors::InvalidArgument("Index >= inputs.size: %d >= %d", index, inputs_.size()); @@ -317,7 +325,7 @@ Status EagerOperation::SetInput(size_t index, return absl::OkStatus(); } -Status EagerOperation::Reset( +absl::Status EagerOperation::Reset( const char* op, const char* device_name, bool remote, EagerExecutor* executor, const absl::optional eager_func_params) { @@ -367,7 +375,7 @@ Status EagerOperation::Reset( return SetDeviceName(device_name); } -Status EagerOperation::MaybeInferSingleInputAttrs( +absl::Status EagerOperation::MaybeInferSingleInputAttrs( ImmediateExecutionTensorHandle* handle) { if (!op_def_) return absl::OkStatus(); @@ -415,7 +423,7 @@ void EagerOperation::InferMixedTypeInputListAttrs( } } -Status EagerOperation::InferInputListAttrs(int num_inputs) { +absl::Status EagerOperation::InferInputListAttrs(int num_inputs) { if (!op_def_) return absl::OkStatus(); int start = inference_arg_idx_; @@ -442,7 +450,7 @@ Status EagerOperation::InferInputListAttrs(int num_inputs) { return absl::OkStatus(); } -Status EagerOperation::TensorHandleInputs( +absl::Status EagerOperation::TensorHandleInputs( const absl::InlinedVector** inputs) const { if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) { *inputs = reinterpret_cast*>( @@ -453,7 +461,7 @@ Status EagerOperation::TensorHandleInputs( } } -Status EagerOperation::MutableTensorHandleInputs( +absl::Status EagerOperation::MutableTensorHandleInputs( absl::InlinedVector** inputs) { if (TF_PREDICT_TRUE(!HasCustomDeviceInput())) { *inputs = @@ -464,7 +472,7 @@ Status EagerOperation::MutableTensorHandleInputs( } } -Status EagerOperation::SetDeviceName(const char* c_name) { +absl::Status EagerOperation::SetDeviceName(const char* c_name) { string name(c_name != nullptr ? c_name : ""); if (name != last_set_device_name_) { if (!DeviceNameUtils::ParseFullName(name, &device_parsed_name_)) { diff --git a/tensorflow/core/common_runtime/eager/eager_operation.h b/tensorflow/core/common_runtime/eager/eager_operation.h index 3ddf91c5ed5f52..b81b0fc75313df 100644 --- a/tensorflow/core/common_runtime/eager/eager_operation.h +++ b/tensorflow/core/common_runtime/eager/eager_operation.h @@ -53,7 +53,7 @@ class EagerOperation : public ImmediateExecutionOperation { void Release() override { delete this; } void Clear() override; - Status Reset(const char* op, const char* raw_device_name) override { + absl::Status Reset(const char* op, const char* raw_device_name) override { return Reset(op, raw_device_name, false, nullptr); } @@ -73,7 +73,7 @@ class EagerOperation : public ImmediateExecutionOperation { // This also resets the internal device pointer, unless the given name refers // to a known custom device, in which case the internal device pointer is // updated to that device. - Status SetDeviceName(const char* name) override; + absl::Status SetDeviceName(const char* name) override; void SetDevice(VariantDevice device) { device_ = device; @@ -87,51 +87,56 @@ class EagerOperation : public ImmediateExecutionOperation { last_set_device_name_ = "\177"; // DEL (an invalid value) } - Status SetAttrValue(const char* attr_name, const AttrValue& value); + absl::Status SetAttrValue(const char* attr_name, const AttrValue& value); - Status AddInput(AbstractTensorHandle* input) override; - Status AddInputList(absl::Span inputs) override; - Status SetInput(size_t index, ImmediateExecutionTensorHandle* input) override; + absl::Status AddInput(AbstractTensorHandle* input) override; + absl::Status AddInputList( + absl::Span inputs) override; + absl::Status SetInput(size_t index, + ImmediateExecutionTensorHandle* input) override; absl::Span GetInputs() const override; bool HasCustomDeviceInput() const override { return custom_device_tensor_handles_count_ > 0; } - Status Execute(absl::Span retvals, - int* num_retvals) override; + absl::Status Execute(absl::Span retvals, + int* num_retvals) override; const tensorflow::OpDef* OpDef() const override { return op_def_; }; - Status SetAttrString(const char* attr_name, const char* data, - size_t length) override; - Status SetAttrInt(const char* attr_name, int64_t value) override; - Status SetAttrFloat(const char* attr_name, float value) override; - Status SetAttrBool(const char* attr_name, bool value) override; - Status SetAttrType(const char* attr_name, DataType value) override; - Status SetAttrShape(const char* attr_name, const int64_t* dims, - int num_dims) override; - Status SetAttrFunction(const char* attr_name, - const AbstractOperation* value) override; - Status SetAttrFunctionName(const char* attr_name, const char* data, + absl::Status SetAttrString(const char* attr_name, const char* data, size_t length) override; - Status SetAttrTensor(const char* attr_name, - AbstractTensorInterface* tensor) override; - Status SetAttrStringList(const char* attr_name, const void* const* values, - const size_t* lengths, int num_values) override; - Status SetAttrFloatList(const char* attr_name, const float* values, - int num_values) override; - Status SetAttrIntList(const char* attr_name, const int64_t* values, - int num_values) override; - Status SetAttrTypeList(const char* attr_name, const DataType* values, - int num_values) override; - Status SetAttrBoolList(const char* attr_name, const unsigned char* values, - int num_values) override; - Status SetAttrShapeList(const char* attr_name, const int64_t** dims, - const int* num_dims, int num_values) override; - Status SetAttrFunctionList( + absl::Status SetAttrInt(const char* attr_name, int64_t value) override; + absl::Status SetAttrFloat(const char* attr_name, float value) override; + absl::Status SetAttrBool(const char* attr_name, bool value) override; + absl::Status SetAttrType(const char* attr_name, DataType value) override; + absl::Status SetAttrShape(const char* attr_name, const int64_t* dims, + int num_dims) override; + absl::Status SetAttrFunction(const char* attr_name, + const AbstractOperation* value) override; + absl::Status SetAttrFunctionName(const char* attr_name, const char* data, + size_t length) override; + absl::Status SetAttrTensor(const char* attr_name, + AbstractTensorInterface* tensor) override; + absl::Status SetAttrStringList(const char* attr_name, + const void* const* values, + const size_t* lengths, + int num_values) override; + absl::Status SetAttrFloatList(const char* attr_name, const float* values, + int num_values) override; + absl::Status SetAttrIntList(const char* attr_name, const int64_t* values, + int num_values) override; + absl::Status SetAttrTypeList(const char* attr_name, const DataType* values, + int num_values) override; + absl::Status SetAttrBoolList(const char* attr_name, + const unsigned char* values, + int num_values) override; + absl::Status SetAttrShapeList(const char* attr_name, const int64_t** dims, + const int* num_dims, int num_values) override; + absl::Status SetAttrFunctionList( const char* attr_name, absl::Span values) override; - Status InputLength(const char* input_name, int* length) override; - Status OutputLength(const char* output_name, int* length) override; + absl::Status InputLength(const char* input_name, int* length) override; + absl::Status OutputLength(const char* output_name, int* length) override; const AbstractOpAttrs* GetOpAttrs() const override; void AddAttrs(const AbstractOpAttrs* op_attrs) override; @@ -144,7 +149,7 @@ class EagerOperation : public ImmediateExecutionOperation { return stack_trace_; } - Status Reset( + absl::Status Reset( const char* op, const char* device_name, bool remote, EagerExecutor* executor, absl::optional eager_func_params = std::nullopt); @@ -177,9 +182,9 @@ class EagerOperation : public ImmediateExecutionOperation { // TensorHandleInputs and MutableTensorHandleInputs first check that all // inputs are TensorHandles, i.e. that there are no custom device inputs. They // return a bad status otherwise. - Status TensorHandleInputs( + absl::Status TensorHandleInputs( const absl::InlinedVector** inputs) const; - Status MutableTensorHandleInputs( + absl::Status MutableTensorHandleInputs( absl::InlinedVector** inputs); const absl::InlinedVector& Inputs() @@ -254,7 +259,7 @@ class EagerOperation : public ImmediateExecutionOperation { private: void AddTensorHandle(ImmediateExecutionTensorHandle* h); - const tensorflow::OpDef* GetOpDef(Status* status); + const tensorflow::OpDef* GetOpDef(absl::Status* status); void ClearInferenceState() { op_def_ = nullptr; @@ -262,8 +267,9 @@ class EagerOperation : public ImmediateExecutionOperation { inference_attrs_.clear_no_resize(); } - Status MaybeInferSingleInputAttrs(ImmediateExecutionTensorHandle* handle); - Status InferInputListAttrs(int num_inputs); + absl::Status MaybeInferSingleInputAttrs( + ImmediateExecutionTensorHandle* handle); + absl::Status InferInputListAttrs(int num_inputs); void InferSingleTypeInputListAttrs(const OpDef::ArgDef& input_def, DataType dtype, int num_inputs); diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 94ff2676189e75..ef2ed6455675b2 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -147,12 +147,12 @@ const string& DeviceNameOrUnspecified(Device* device) { // // `op_device` is passed in explicitly because `op->device()` might be // unset and we might have selected some specific device to run this op on. -Status CopyInputToExpectedDevice(EagerContext* ctx, EagerOperation* op, - Device* op_device, - TensorHandle* handle, // op->Inputs()[i] - int i, Device* handle_device, - Device* expected_input_device, - TensorHandle** result) { +absl::Status CopyInputToExpectedDevice(EagerContext* ctx, EagerOperation* op, + Device* op_device, + TensorHandle* handle, // op->Inputs()[i] + int i, Device* handle_device, + Device* expected_input_device, + TensorHandle** result) { VLOG(6) << "Expected input device: " << expected_input_device->name() << "; handle_device: " << handle_device->name(); // Should only be called when these don't match @@ -213,12 +213,12 @@ Status CopyInputToExpectedDevice(EagerContext* ctx, EagerOperation* op, " to ", expected_input_device->name()); }, tsl::profiler::TraceMeLevel::kInfo); - Status status = + absl::Status status = EagerCopyToDevice(handle, ctx, &op->Executor(), expected_input_device, /* mirror= */ true, &result_handle); activity.Stop(); if (!status.ok()) { - return Status( + return absl::Status( status.code(), absl::StrCat("Failed copying input tensor from ", handle_device->name(), " to ", expected_input_device->name(), " in order to run ", @@ -233,7 +233,7 @@ Status CopyInputToExpectedDevice(EagerContext* ctx, EagerOperation* op, // `op_device_name` the name of the device on which the op will run, if any. // For functions running using function library runtime, the device can be // unspecified. -Status ValidateInputTypeAndPlacement( +absl::Status ValidateInputTypeAndPlacement( EagerContext* ctx, EagerOperation* op, const core::RefCountPtr& kernel) { tsl::profiler::TraceMe activity("ValidateInputTypeAndPlacement", @@ -309,9 +309,10 @@ bool IsHostMemoryArg(const EagerOperation& op, const NodeDef* node_def, op_def.input_arg(arg_id).name()) != host_memory_args.end(); } -Status GetDeviceForInput(const EagerOperation& op, const EagerContext& ctx, - const bool is_host_memory_arg, - TensorHandle* tensor_handle, Device** result) { +absl::Status GetDeviceForInput(const EagerOperation& op, + const EagerContext& ctx, + const bool is_host_memory_arg, + TensorHandle* tensor_handle, Device** result) { Device* cpu_device = ctx.HostCPU(); string device_name; if (tensor_handle->Type() != TensorHandle::LOCAL) { @@ -370,9 +371,9 @@ Status GetDeviceForInput(const EagerOperation& op, const EagerContext& ctx, return absl::OkStatus(); } -Status GetFuncAttr(const EagerOperation* op, const EagerContext& ctx, - const char* attr_name, bool* value) { - Status status = op->Attrs().Get(attr_name, value); +absl::Status GetFuncAttr(const EagerOperation* op, const EagerContext& ctx, + const char* attr_name, bool* value) { + absl::Status status = op->Attrs().Get(attr_name, value); if (status.ok()) { VLOG(2) << "Caller explicitly specifies " << attr_name << (value ? "=true " : "=false, ") << op->DebugString(); @@ -396,8 +397,9 @@ Status GetFuncAttr(const EagerOperation* op, const EagerContext& ctx, // Checks if `op` is a function and contains TPU replication ops. If `op` does, // then `has_tpu_replication` is set to true. Other `has_tpu_replication` is // set to false. -Status HasTPUReplication(const EagerOperation& op, const EagerContext& ctx, - bool* has_tpu_replication) { +absl::Status HasTPUReplication(const EagerOperation& op, + const EagerContext& ctx, + bool* has_tpu_replication) { *has_tpu_replication = false; if (!op.is_function()) { return absl::OkStatus(); @@ -416,8 +418,9 @@ Status HasTPUReplication(const EagerOperation& op, const EagerContext& ctx, return absl::OkStatus(); } -Status MustCompileWithXLA(const EagerOperation* op, const EagerContext& ctx, - bool* compile_with_xla) { +absl::Status MustCompileWithXLA(const EagerOperation* op, + const EagerContext& ctx, + bool* compile_with_xla) { #if defined(PLUGGABLE_DEVICE_SUPPORTED_MACOS) *compile_with_xla = false; #else @@ -434,7 +437,8 @@ Status MustCompileWithXLA(const EagerOperation* op, const EagerContext& ctx, return absl::OkStatus(); } - Status status = GetFuncAttr(op, ctx, kXlaMustCompileAttr, compile_with_xla); + absl::Status status = + GetFuncAttr(op, ctx, kXlaMustCompileAttr, compile_with_xla); if (status.ok()) { return absl::OkStatus(); } @@ -457,8 +461,9 @@ Status MustCompileWithXLA(const EagerOperation* op, const EagerContext& ctx, // Check if `op` has tf.StatefulPartitionedCall op with _XlaMustCompile, sets // `has_jit_compile` and `device`. -Status HasNestedJitCompile(const EagerOperation& op, const EagerContext& ctx, - bool* has_jit_compile, string* device) { +absl::Status HasNestedJitCompile(const EagerOperation& op, + const EagerContext& ctx, bool* has_jit_compile, + string* device) { *has_jit_compile = false; const std::string kStatefulPartitionedCallOp = "StatefulPartitionedCall"; @@ -517,8 +522,10 @@ string CanonicalizeDeviceType(std::string_view device_type) { return canonical_device_type; } -Status UpdateCompileCounter(const EagerOperation* op, const EagerContext& ctx, - bool compile_with_xla, bool has_tpu_replication) { +absl::Status UpdateCompileCounter(const EagerOperation* op, + const EagerContext& ctx, + bool compile_with_xla, + bool has_tpu_replication) { if (has_tpu_replication) { function_compile_counter->GetCell(tensorflow::DEVICE_TPU, kEnabled) ->IncrementBy(1); @@ -587,8 +594,8 @@ string GetFlatName(const string orig_name, int index) { // // IdentityN[T:[DT_FLOAT, DT_INT64]] -> __wrapped__IdentityN_T_2 // Concat[N:2, T:DT_FLOAT] -> __wrapped__Concat_N_2 -Status BuildWrappedOpName(EagerOperation* op, const OpDef& opdef, - const AbstractOpAttrs* op_attrs, string* name) { +absl::Status BuildWrappedOpName(EagerOperation* op, const OpDef& opdef, + const AbstractOpAttrs* op_attrs, string* name) { string fname = absl::StrCat("__wrapped__", EscapeOrigName(op->Name())); // For every variadic arg in `args`, populates `attr_to_len` with // (attr_name, len(arg)). @@ -748,8 +755,8 @@ Status BuildWrappedOpName(EagerOperation* op, const OpDef& opdef, // // Note that the N attr is preserved so that it can get copied to the // inner op via a placeholder. This allows additional verification. -Status BuildWrappedOpSignature(EagerOperation* op, const OpDef& opdef, - const string& fname, OpDef& signature) { +absl::Status BuildWrappedOpSignature(EagerOperation* op, const OpDef& opdef, + const string& fname, OpDef& signature) { signature = opdef; signature.clear_input_arg(); signature.clear_output_arg(); @@ -814,9 +821,9 @@ Status BuildWrappedOpSignature(EagerOperation* op, const OpDef& opdef, // For mixed type inputs "list(type)" we create new attributes in the signature // for each element tensor (See examples in BuildWrappedOpSignature). Here // we construct the values for those attributes and set them on the wrapped op. -Status AddMixedTypeListAttrs(EagerOperation* wrapped_op, - const AbstractOpAttrs* op_attrs, - const OpDef& opdef) { +absl::Status AddMixedTypeListAttrs(EagerOperation* wrapped_op, + const AbstractOpAttrs* op_attrs, + const OpDef& opdef) { auto FillAttrsToAdd = [op_attrs](const ProtoArgListType& opdef_args, absl::flat_hash_map* attrs_to_add) { @@ -846,9 +853,9 @@ Status AddMixedTypeListAttrs(EagerOperation* wrapped_op, // Maps the op's outputs to the function outputs. Mainly useful for variadic // outputs which need to be flattened. -Status PopulateRetMap(FunctionDef* fdef, const AbstractOpAttrs* op_attrs, - const EagerOperation* op, const OpDef& opdef, - const OpDef& signature, const string& node_name) { +absl::Status PopulateRetMap(FunctionDef* fdef, const AbstractOpAttrs* op_attrs, + const EagerOperation* op, const OpDef& opdef, + const OpDef& signature, const string& node_name) { int next_sig_output = 0; for (size_t i = 0; i < opdef.output_arg_size(); i++) { const auto& output_arg = opdef.output_arg(i); @@ -889,14 +896,14 @@ inline void GetMKLNodeDef(NodeDef* ndef) { } #endif // INTEL_MKL -Status WrapInCallOp(EagerOperation* op, EagerOperation** wrapped_op) { +absl::Status WrapInCallOp(EagerOperation* op, EagerOperation** wrapped_op) { DCHECK(!op->is_function()); const OpDef& opdef = OpRegistry::Global()->LookUp(op->Name())->op_def; // Raise an error for ops which don't support wrapping yet. This includes // ops with list inputs/outputs and ops with private attrs. // TODO(srbs): Support list inputs/outputs. auto verify_wrappable_in_call_op = [](const OpDef& opdef, - EagerOperation* op) -> Status { + EagerOperation* op) -> absl::Status { absl::flat_hash_set opdef_attrs; for (const auto& attr : opdef.attr()) { opdef_attrs.insert(attr.name()); @@ -1011,7 +1018,7 @@ absl::StatusOr GetBoolInputs(EagerOperation* op, } // Identify boolean inputs to this EagerOperation that are on host. const TensorHandle* handle = inputs->at(i); - Status s; + absl::Status s; const char* input_device = handle->DeviceType(&s); if (!s.ok() || !absl::StrContains(input_device, "CPU")) { return errors::InvalidArgument( @@ -1064,7 +1071,7 @@ std::optional GetBoolArgumentValue(const EagerOperation& op, // If the input is not on host returns std::nullopt. const TensorHandle* handle = inputs->at(i); - Status s; + absl::Status s; const char* input_device = handle->DeviceType(&s); if (!s.ok() || !absl::StrContains(input_device, "CPU")) return std::nullopt; @@ -1146,7 +1153,7 @@ absl::StatusOr GetKernelCacheKey( // physical device names. // `input_resource_variable_dtypes_shape` - A map from input index // to dtype and shapes for resource inputs. -Status ExtractFunctionInputInfo( +absl::Status ExtractFunctionInputInfo( EagerOperation* op, const KernelDef* kernel_def, std::vector& input_device_ptrs, absl::flat_hash_map*>& composite_devices, @@ -1199,7 +1206,8 @@ Status ExtractFunctionInputInfo( return absl::OkStatus(); } -Status SetOpDevice(EagerContext& ctx, EagerOperation* op, Device** device) { +absl::Status SetOpDevice(EagerContext& ctx, EagerOperation* op, + Device** device) { // Here in local execute, set preferred device to be on the local task to // avoid placing op on a remote device with higher priority. const DeviceNameUtils::ParsedName& preferred_device = @@ -1230,7 +1238,7 @@ Status SetOpDevice(EagerContext& ctx, EagerOperation* op, Device** device) { return absl::OkStatus(); } -Status GetOrCreateKernelAndDevice( +absl::Status GetOrCreateKernelAndDevice( EagerOperation* op, TensorHandle** retvals, int* num_retvals, core::RefCountPtr* out_kernel) { EagerContext& ctx = op->EagerContext(); @@ -1309,9 +1317,9 @@ Status GetOrCreateKernelAndDevice( auto get_kernel_def = [](const EagerOperation& op, const NodeDef& node_def, const Device* op_device) -> const KernelDef* { const KernelDef* kernel_def = nullptr; - Status s = FindKernelDef(DeviceType(op_device->device_type()), node_def, - &kernel_def, - /*kernel_class_name=*/nullptr); + absl::Status s = FindKernelDef(DeviceType(op_device->device_type()), + node_def, &kernel_def, + /*kernel_class_name=*/nullptr); if (!s.ok()) return nullptr; return kernel_def; }; @@ -1417,7 +1425,7 @@ Status GetOrCreateKernelAndDevice( // _apply_op_helper's validation (which is reached when executing in // graph mode) or the eager execution's validation (which is reached via // the CreateOpKernel call). - auto validate_op = [](EagerOperation* op) -> Status { + auto validate_op = [](EagerOperation* op) -> absl::Status { const NodeDef& node_def = op->MutableAttrs()->BuildNodeDef(); const OpDef* op_def; TF_RETURN_IF_ERROR( @@ -1575,7 +1583,7 @@ Status GetOrCreateKernelAndDevice( return absl::OkStatus(); } -Status CreateUnshapedOutput( +absl::Status CreateUnshapedOutput( const KernelAndDevice& kernel, const int output_num, Device* output_device, const DataType& output_dtype, const absl::optional& eager_func_params, @@ -1610,8 +1618,8 @@ Status CreateUnshapedOutput( #endif // !IS_MOBILE_PLATFORM } -Status AddOrExecuteNode(core::RefCountPtr kernel, - EagerOperation* op, TensorHandle** retvals) { +absl::Status AddOrExecuteNode(core::RefCountPtr kernel, + EagerOperation* op, TensorHandle** retvals) { EagerExecutor& executor = op->Executor(); EagerContext& ctx = op->EagerContext(); GraphCollector* graph_collector = nullptr; @@ -1672,7 +1680,7 @@ Status AddOrExecuteNode(core::RefCountPtr kernel, op->GetCancellationManager(), {retvals, static_cast(num_outputs)}, op->GetStackTrace()); - Status s = executor.SyncExecute(&node); + absl::Status s = executor.SyncExecute(&node); // We release the inputs AFTER executing the operation in sync mode since // ExecuteNode does not increment the reference count and thus does not have // ownership of the inputs while executing. @@ -1695,8 +1703,8 @@ Status AddOrExecuteNode(core::RefCountPtr kernel, // runtime. In this case, we don't select a device because running // a function with explicitly requested device has different behavior than // running without an explicitly requested device. -Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, - int* num_retvals) { +absl::Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, + int* num_retvals) { tsl::profiler::ScopedMemoryDebugAnnotation op_annotation( op->op_name(), op->eager_func_params().has_value() ? op->eager_func_params().value().step_id.value_or(0) @@ -1747,7 +1755,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, } } - Status s = AddOrExecuteNode(std::move(kernel), op, retvals); + absl::Status s = AddOrExecuteNode(std::move(kernel), op, retvals); // Since the operation failed, we need to Unref any outputs if they were // allocated. if (!s.ok()) { @@ -1764,7 +1772,7 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals, // Run a Pack op to pack the tensors pointed by a packed input TensorHandle if // the op is a primitive op. -Status MaybePackInputTensor(EagerOperation* op) { +absl::Status MaybePackInputTensor(EagerOperation* op) { if (op->is_function() || op->EagerContext().RunEagerOpAsFunction()) { // Functions could take packed TensorHandles as inputs. return absl::OkStatus(); @@ -1799,8 +1807,8 @@ Status MaybePackInputTensor(EagerOperation* op) { #if !defined(IS_MOBILE_PLATFORM) -Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, - int* num_retvals) { +absl::Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, + int* num_retvals) { EagerContext& ctx = op->EagerContext(); // TODO(fishx): Remove following code when lazy tensor copy is ready. @@ -1930,7 +1938,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, DataTypeVector output_dtypes; auto get_output_dtypes = [](EagerOperation* op, - DataTypeVector* output_dtypes) -> Status { + DataTypeVector* output_dtypes) -> absl::Status { const auto& node_def = op->MutableAttrs()->BuildNodeDef(); const OpDef* op_def = nullptr; @@ -1980,7 +1988,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, // with the EnqueueRequest. auto store_resource_dtypes_and_shapes = [](const eager::Operation& remote_op, const DataTypeVector& output_dtypes, - TensorHandle** retvals) -> Status { + TensorHandle** retvals) -> absl::Status { if (remote_op.name() == "VarHandleOp") { if (output_dtypes.size() != 1) { return errors::Internal("VarHandleOp should only have one output."); @@ -2024,7 +2032,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, } } - Status s = executor.AddOrExecute(std::move(node)); + absl::Status s = executor.AddOrExecute(std::move(node)); // Since the operation failed, we need to Unref any outputs that were // allocated. if (!s.ok()) { @@ -2040,7 +2048,7 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, } #endif // IS_MOBILE_PLATFORM -Status GetKernelOutputs( +absl::Status GetKernelOutputs( std::vector* outputs, int num_outputs, TensorHandle** retvals, EagerContext* ctx, KernelAndDevice* kernel, const absl::optional& eager_func_params) { @@ -2124,8 +2132,8 @@ void CollectGraphs(EagerContext* ctx) { } } // namespace -Status DoEagerExecute(EagerOperation* op, TensorHandle** retvals, - int* num_retvals) { +absl::Status DoEagerExecute(EagerOperation* op, TensorHandle** retvals, + int* num_retvals) { tsl::profiler::TraceMe activity([&] { return tsl::profiler::TraceMeEncode( "EagerExecute", @@ -2169,7 +2177,7 @@ Status DoEagerExecute(EagerOperation* op, TensorHandle** retvals, } // TODO(gjn): Consider moving into ExecuteNode class -Status EagerKernelExecute( +absl::Status EagerKernelExecute( EagerContext* ctx, const absl::InlinedVector& op_inputs, const absl::optional& eager_func_params, const core::RefCountPtr& kernel, @@ -2212,8 +2220,8 @@ Status EagerKernelExecute( kernel.get(), eager_func_params); } -Status EagerExecute(EagerOperation* op, TensorHandle** retvals, - int* num_retvals) { +absl::Status EagerExecute(EagerOperation* op, TensorHandle** retvals, + int* num_retvals) { if (VLOG_IS_ON(1) && op->is_function()) { const std::string& op_name = op->Name(); const std::string& exec_mode = op->IsLocal() ? "local" : "remote"; @@ -2228,7 +2236,7 @@ Status EagerExecute(EagerOperation* op, TensorHandle** retvals, VLOG(1) << "Entering " << msg; - Status status = DoEagerExecute(op, retvals, num_retvals); + absl::Status status = DoEagerExecute(op, retvals, num_retvals); VLOG(1) << "Exiting " << msg << ", status code is " << status; @@ -2239,9 +2247,9 @@ Status EagerExecute(EagerOperation* op, TensorHandle** retvals, namespace { -Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx, - EagerExecutor* executor, Device* dstd, - bool mirror, TensorHandle** result) { +absl::Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx, + EagerExecutor* executor, Device* dstd, + bool mirror, TensorHandle** result) { TF_RETURN_IF_ERROR(executor->status()); Device* d = ctx->CanonicalDevice(dstd); if (mirror && h->HasLocalMirror(d)) { @@ -2264,7 +2272,7 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx, // reference count is still needed which will be removed if the operation // fails. if (async) { - Status s = h->AddEmptyLocalMirror(d); + absl::Status s = h->AddEmptyLocalMirror(d); if (!s.ok()) { // If a mirror was added since we called HasLocalMirror then just return // since another thread has already added the mirror. @@ -2284,7 +2292,7 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx, d, dstd, h->resource_device(), h->dtype, ctx); } - Status s; + absl::Status s; if (async) { // Note that `h` may not be currently ready. However execution order will // make sure that `h` is ready before the copy is actually done. @@ -2308,9 +2316,9 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx, } // namespace -Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, - EagerExecutor* executor, Device* device, bool mirror, - TensorHandle** result) { +absl::Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, + EagerExecutor* executor, Device* device, + bool mirror, TensorHandle** result) { TF_RETURN_IF_ERROR(h->WaitUnknownDevice()); auto send_device = h->DeviceOrHostCPU(*ctx); bool sender_is_local = send_device->IsLocal(); @@ -2344,7 +2352,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, return absl::OkStatus(); } - Status s = h->AddEmptyLocalMirror(d); + absl::Status s = h->AddEmptyLocalMirror(d); if (!s.ok()) { // If a mirror was added since we called HasLocalMirror then just // return since another thread has already added the mirror. @@ -2391,7 +2399,7 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, auto node = std::make_unique( ctx, executor, h, result[0], device, recv_op_id); - Status s = executor->AddOrExecute(std::move(node)); + absl::Status s = executor->AddOrExecute(std::move(node)); if (!s.ok()) { result[0]->Unref(); result[0] = nullptr; @@ -2417,7 +2425,7 @@ void EagerKernelExecuteAsync( auto inputs = std::make_shared(op_inputs.size()); auto outputs = std::make_shared>(1); - Status s = inputs->Init(ctx, op_inputs, kernel); + absl::Status s = inputs->Init(ctx, op_inputs, kernel); if (!s.ok()) { done(s); return; @@ -2434,8 +2442,8 @@ void EagerKernelExecuteAsync( eager_func_params, coord_agent, [retvals, inputs, outputs, num_outputs, ctx, graph_collector, eager_func_params, kernel_raw = kernel.get(), - done = std::move(done)](const Status& s) { - auto wrapped_done = [&](const Status& s) { + done = std::move(done)](const absl::Status& s) { + auto wrapped_done = [&](const absl::Status& s) { kernel_raw->Unref(); done(s); }; @@ -2476,7 +2484,8 @@ void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals, EagerContext& ctx = op->EagerContext(); core::RefCountPtr kernel; - Status s = GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel); + absl::Status s = + GetOrCreateKernelAndDevice(op, retvals, num_retvals, &kernel); if (!s.ok()) { done(s); return; @@ -2517,22 +2526,23 @@ void EagerLocalExecuteAsync(EagerOperation* op, TensorHandle** retvals, done(s); return; } - EagerKernelExecuteAsync( - &ctx, *inputs, op->eager_func_params(), std::move(kernel), - graph_collector, op->GetCancellationManager(), retvals, num_outputs, - [op, num_outputs, retvals, done = std::move(done)](const Status& s) { - op->Clear(); - // Since the operation failed, we need to Unref any outputs if they were - // allocated. - if (!s.ok()) { - for (int i = 0, end = num_outputs; i < end; ++i) { - if (retvals[i] != nullptr) { - retvals[i]->Unref(); - retvals[i] = nullptr; - } - } - } - done(s); - }); + EagerKernelExecuteAsync(&ctx, *inputs, op->eager_func_params(), + std::move(kernel), graph_collector, + op->GetCancellationManager(), retvals, num_outputs, + [op, num_outputs, retvals, + done = std::move(done)](const absl::Status& s) { + op->Clear(); + // Since the operation failed, we need to Unref any + // outputs if they were allocated. + if (!s.ok()) { + for (int i = 0, end = num_outputs; i < end; ++i) { + if (retvals[i] != nullptr) { + retvals[i]->Unref(); + retvals[i] = nullptr; + } + } + } + done(s); + }); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/execute.h b/tensorflow/core/common_runtime/eager/execute.h index 5cba8cf0243207..cbd1e0c90e2759 100644 --- a/tensorflow/core/common_runtime/eager/execute.h +++ b/tensorflow/core/common_runtime/eager/execute.h @@ -42,12 +42,12 @@ namespace tensorflow { // '*num_retvals' should be set to the size of this array. It is an error if // the size of 'retvals' is less than the number of outputs. This call sets // *num_retvals to the number of outputs. -Status EagerExecute(EagerOperation* op, TensorHandle** retvals, - int* num_retvals); +absl::Status EagerExecute(EagerOperation* op, TensorHandle** retvals, + int* num_retvals); // Low-level utility to execute the kernel specified by `kernel` on // `kernel->device()`, with the inputs op_inputs, in the context 'ctx'. -Status EagerKernelExecute( +absl::Status EagerKernelExecute( EagerContext* ctx, const absl::InlinedVector& op_inputs, const absl::optional& eager_func_params, const core::RefCountPtr& kernel, @@ -60,9 +60,9 @@ Status EagerKernelExecute( // the mirror flag, EagerCopyToDevice will attempt to add a mirror to the // original handle and update *result to point to h. Since this is not // guaranteed, callers should always use the value in *result. -Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, - EagerExecutor* executor, Device* device, bool mirror, - TensorHandle** result); +absl::Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx, + EagerExecutor* executor, Device* device, + bool mirror, TensorHandle** result); // Utility function that executes a fully constructed EagerOperation // asynchronously on the local task. This function works differently from diff --git a/tensorflow/core/common_runtime/eager/execute_node.cc b/tensorflow/core/common_runtime/eager/execute_node.cc index a65947f91687d1..09bebd3e1f7cf2 100644 --- a/tensorflow/core/common_runtime/eager/execute_node.cc +++ b/tensorflow/core/common_runtime/eager/execute_node.cc @@ -35,9 +35,10 @@ bool ExecuteNodeArgs::IsRemote(EagerContext* ctx, Device* input_device, } #endif // IS_MOBILE_PLATFORM -Status ExecuteNodeArgs::InitPackedHandle(const int index, EagerContext* ctx, - Device* input_device, - TensorHandle* packed_handle) { +absl::Status ExecuteNodeArgs::InitPackedHandle(const int index, + EagerContext* ctx, + Device* input_device, + TensorHandle* packed_handle) { int num_handles = packed_handle->NumPackedHandles(); packed_args_.emplace(index, absl::InlinedVector(num_handles)); @@ -47,7 +48,8 @@ Status ExecuteNodeArgs::InitPackedHandle(const int index, EagerContext* ctx, TF_RETURN_IF_ERROR(packed_handle->ExtractPackedHandle(i, &h)); // We have validated that h->device() is not a CustomDevice when // constructing a pack TensorHandle. - const Status status = h->TensorValue(h->device(), &packed_arg_flat[i]); + const absl::Status status = + h->TensorValue(h->device(), &packed_arg_flat[i]); if (!status.ok()) { #if !defined(IS_MOBILE_PLATFORM) if (IsRemote(ctx, input_device, h)) { @@ -64,7 +66,7 @@ Status ExecuteNodeArgs::InitPackedHandle(const int index, EagerContext* ctx, return absl::OkStatus(); } -Status ExecuteNodeArgs::Init( +absl::Status ExecuteNodeArgs::Init( EagerContext* ctx, const absl::InlinedVector& op_inputs, const core::RefCountPtr& kernel) { // If there are multiple references to a TensorHandle in 'op_inputs' we must @@ -79,7 +81,8 @@ Status ExecuteNodeArgs::Init( for (int i = 0; i < n_inputs; ++i) { TensorHandle* in = op_inputs_flat[i]; Device* d = kernel->InputDevice(i); - Status s = in->TensorValue(ctx->CanonicalDevice(d), &tensor_args_flat[i]); + absl::Status s = + in->TensorValue(ctx->CanonicalDevice(d), &tensor_args_flat[i]); if (!s.ok()) { #if !defined(IS_MOBILE_PLATFORM) if (IsRemote(ctx, d, in)) { @@ -103,7 +106,7 @@ Status ExecuteNodeArgs::Init( serialize_remote_handle_ = [ctx, &op_inputs, is_function]( const FunctionArgIndex& index, - eager::RemoteTensorHandle* handle) -> Status { + eager::RemoteTensorHandle* handle) -> absl::Status { TensorHandle* h = op_inputs[index.index]; if (op_inputs[index.index]->Type() == TensorHandle::PACKED) { TF_RETURN_IF_ERROR( @@ -124,9 +127,9 @@ Status ExecuteNodeArgs::Init( return absl::OkStatus(); } -Status ExecuteNodeArgs::GetLocalArg(const FunctionArgIndex& index, - Tensor* val) const { - Status s = EagerKernelArgs::GetLocalArg(index, val); +absl::Status ExecuteNodeArgs::GetLocalArg(const FunctionArgIndex& index, + Tensor* val) const { + absl::Status s = EagerKernelArgs::GetLocalArg(index, val); if (s.ok()) { return absl::OkStatus(); } diff --git a/tensorflow/core/common_runtime/eager/execute_node.h b/tensorflow/core/common_runtime/eager/execute_node.h index b3cb571a3d256e..52bf1ecfe67d78 100644 --- a/tensorflow/core/common_runtime/eager/execute_node.h +++ b/tensorflow/core/common_runtime/eager/execute_node.h @@ -54,19 +54,20 @@ class ExecuteNodeArgs : public EagerKernelArgs { public: explicit ExecuteNodeArgs(int count) : EagerKernelArgs(count) {} - Status Init(EagerContext* ctx, - const absl::InlinedVector& op_inputs, - const core::RefCountPtr& kernel); + absl::Status Init(EagerContext* ctx, + const absl::InlinedVector& op_inputs, + const core::RefCountPtr& kernel); - Status GetLocalArg(const FunctionArgIndex& index, Tensor* val) const override; + absl::Status GetLocalArg(const FunctionArgIndex& index, + Tensor* val) const override; bool HasRemoteOrPackedInputs() const override { return has_remote_inputs_ || has_packed_inputs_; }; #if !defined(IS_MOBILE_PLATFORM) - Status GetRemoteArg(const FunctionArgIndex& index, - eager::RemoteTensorHandle* val) const override { + absl::Status GetRemoteArg(const FunctionArgIndex& index, + eager::RemoteTensorHandle* val) const override { return serialize_remote_handle_(index, val); } #endif // IS_MOBILE_PLATFORM @@ -79,15 +80,17 @@ class ExecuteNodeArgs : public EagerKernelArgs { #endif // IS_MOBILE_PLATFORM // Initialize a packed TensorHandle which is the `index`-th argument. - Status InitPackedHandle(int index, EagerContext* ctx, Device* input_device, - TensorHandle* packed_handle); + absl::Status InitPackedHandle(int index, EagerContext* ctx, + Device* input_device, + TensorHandle* packed_handle); bool has_remote_inputs_ = false; bool has_packed_inputs_ = false; // Maps from the index of a packed arg to a list of sub-args. absl::flat_hash_map> packed_args_; #if !defined(IS_MOBILE_PLATFORM) - std::function + std::function serialize_remote_handle_; #endif // IS_MOBILE_PLATFORM }; @@ -112,12 +115,12 @@ class ExecuteNode : public EagerNode { retvals_(retvals), stack_trace_(stack_trace) {} - Status Run() override { + absl::Status Run() override { int i = 0; for (TensorHandle* h : inputs_) { if (h->RefCountIsOne()) { const Device* d = ctx_->CanonicalDevice(kernel_->InputDevice(i)); - Status s = h->Unprotect(d); + absl::Status s = h->Unprotect(d); if (!s.ok()) { VLOG(1) << "Unable to unprotect tensor: " << s; } @@ -129,7 +132,7 @@ class ExecuteNode : public EagerNode { stack_trace_); } - void Abort(Status status) override {} + void Abort(absl::Status status) override {} std::string DebugString() const override { std::string out = "[ExecuteNode]"; @@ -190,19 +193,19 @@ class AsyncExecuteNode : public EagerNode { } } - Status Run() override { + absl::Status Run() override { int i = 0; for (TensorHandle* h : inputs_) { if (h->RefCountIsOne()) { const Device* d = ctx_->CanonicalDevice(kernel_->InputDevice(i)); - Status s = h->Unprotect(d); + absl::Status s = h->Unprotect(d); if (!s.ok()) { VLOG(1) << "Unable to unprotect tensor: " << s; } } ++i; } - Status status = EagerKernelExecute( + absl::Status status = EagerKernelExecute( ctx_, inputs_, eager_func_params_, kernel_, graph_collector_, cancellation_manager_, absl::MakeSpan(retvals_), stack_trace_); if (!status.ok()) { @@ -219,7 +222,7 @@ class AsyncExecuteNode : public EagerNode { return absl::OkStatus(); } - void Abort(Status status) override { + void Abort(absl::Status status) override { int i = 0; for (auto handle : retvals_) { handle->Poison(status, ctx_->CanonicalDevice(kernel_->OutputDevice(i))); diff --git a/tensorflow/core/common_runtime/eager/execute_node_test.cc b/tensorflow/core/common_runtime/eager/execute_node_test.cc index b2714b3e1bbcec..d6e6b5fe57dea1 100644 --- a/tensorflow/core/common_runtime/eager/execute_node_test.cc +++ b/tensorflow/core/common_runtime/eager/execute_node_test.cc @@ -68,7 +68,7 @@ TEST(ExecuteNodeTest, ExecuteNodeArgs) { TF_ASSERT_OK(remote_device_mgr->AddDevices(std::move(remote_devices))); Device* device1 = remote_device_mgr->ListDevices().at(0); - Status s; + absl::Status s; std::unique_ptr composite_device = CompositeDevice::MakeDevice({device0->name(), device1->name()}, /*unique_device_id=*/0, diff --git a/tensorflow/core/common_runtime/eager/execute_test.cc b/tensorflow/core/common_runtime/eager/execute_test.cc index 74e5b88a64578f..e424f217130c93 100644 --- a/tensorflow/core/common_runtime/eager/execute_test.cc +++ b/tensorflow/core/common_runtime/eager/execute_test.cc @@ -168,7 +168,7 @@ TEST(ExecuteTest, SimpleFunctionInt32BadFullType) { std::vector retvals(1); int num_retvals = retvals.size(); - Status status = EagerExecute(op.get(), retvals.data(), &num_retvals); + absl::Status status = EagerExecute(op.get(), retvals.data(), &num_retvals); ASSERT_TRUE(errors::IsInvalidArgument(status)) << "Actual status: " << status; EXPECT_TRUE( absl::StrContains(status.message(), "TFT_TENSOR has 0 args instead of 1")) diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index 4dc051f5c69221..0738e0ead791c6 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -57,8 +57,8 @@ limitations under the License. namespace tensorflow { -Status EagerKernelArgs::GetLocalArg(const FunctionArgIndex& index, - Tensor* val) const { +absl::Status EagerKernelArgs::GetLocalArg(const FunctionArgIndex& index, + Tensor* val) const { if (index.sub_index >= 0) { return errors::InvalidArgument("Got unexpected sub_index ", index.sub_index, " for argument ", index.index); @@ -95,7 +95,7 @@ std::function)>* KernelAndDevice::get_runner() KernelAndDeviceFunc::~KernelAndDeviceFunc() { if (handle_ != kInvalidHandle) { - Status status = pflr_->ReleaseHandle(handle_); + absl::Status status = pflr_->ReleaseHandle(handle_); if (!status.ok()) { LOG(INFO) << "Ignoring error status when releasing multi-device function " "handle " @@ -104,7 +104,7 @@ KernelAndDeviceFunc::~KernelAndDeviceFunc() { } } -Status KernelAndDeviceOp::Init( +absl::Status KernelAndDeviceOp::Init( const bool log_device_placement, const NodeDef& ndef, GraphCollector* graph_collecto, const absl::optional& eager_func_params) { @@ -147,7 +147,7 @@ Status KernelAndDeviceOp::Init( return absl::OkStatus(); } -Status KernelAndDeviceFunc::InstantiateFunc( +absl::Status KernelAndDeviceFunc::InstantiateFunc( const bool log_device_placement, const NodeDef& ndef, GraphCollector* graph_collector, const absl::optional& eager_func_params) { @@ -259,7 +259,7 @@ Status KernelAndDeviceFunc::InstantiateFunc( return pflr_->IsCrossProcess(handle_, &is_cross_process_); } -Status KernelAndDeviceFunc::Init( +absl::Status KernelAndDeviceFunc::Init( const bool log_device_placement, const NodeDef& ndef, GraphCollector* graph_collector, const absl::optional& eager_func_params) { @@ -279,7 +279,7 @@ struct OpExecutionState : public core::RefCounted { }; } // anonymous namespace -Status KernelAndDeviceOp::Run( +absl::Status KernelAndDeviceOp::Run( ScopedStepContainer* step_container, const EagerKernelArgs& inputs, std::vector* outputs, CancellationManager* cancellation_manager, @@ -345,7 +345,7 @@ Status KernelAndDeviceOp::Run( op_execution_state->Unref(); } - Status s = context.status(); + absl::Status s = context.status(); if (TF_PREDICT_FALSE(!s.ok())) { if (absl::IsUnavailable(s) && !is_distributed_communication_op_) { s = errors::ReplaceErrorFromNonCommunicationOps(s, kernel_->name()); @@ -429,7 +429,7 @@ KernelAndDeviceFunc::PrepareForRun( return opts; } -Status KernelAndDeviceFunc::Run( +absl::Status KernelAndDeviceFunc::Run( ScopedStepContainer* step_container, const EagerKernelArgs& inputs, std::vector* outputs, CancellationManager* cancellation_manager, @@ -441,10 +441,10 @@ Status KernelAndDeviceFunc::Run( // Don't try to handle packed or remote inputs synchronously. if (inputs.HasRemoteOrPackedInputs() || eager_func_params.has_value()) { Notification n; - Status status; + absl::Status status; RunAsync(step_container, inputs, outputs, cancellation_manager, eager_func_params, coordination_service_agent, - [&status, &n](Status s) { + [&status, &n](absl::Status s) { status = s; n.Notify(); }); @@ -457,7 +457,7 @@ Status KernelAndDeviceFunc::Run( stack_trace, coordination_service_agent, &created_rendezvous); std::vector rets; - Status s; + absl::Status s; { port::ScopedFlushDenormal flush; port::ScopedSetRound round(FE_TONEAREST); @@ -480,7 +480,7 @@ void KernelAndDeviceFunc::RunAsync( CancellationManager* cancellation_manager, const absl::optional& eager_func_params, tsl::CoordinationServiceAgent* coordination_service_agent, - std::function done) { + std::function done) { tsl::profiler::TraceMe activity( [] { return tsl::profiler::TraceMeEncode("KernelAndDeviceFunc::RunAsync", @@ -492,16 +492,16 @@ void KernelAndDeviceFunc::RunAsync( step_container, outputs, cancellation_manager, eager_func_params, std::nullopt, coordination_service_agent, &created_rendezvous); - pflr_->Run( - *opts, handle_, inputs, outputs, - [opts, cancellation_manager, done = std::move(done), - created_rendezvous = created_rendezvous.release()](const Status& s) { - if (cancellation_manager == nullptr) { - delete opts->cancellation_manager; - } - created_rendezvous->Unref(); - done(s); - }); + pflr_->Run(*opts, handle_, inputs, outputs, + [opts, cancellation_manager, done = std::move(done), + created_rendezvous = + created_rendezvous.release()](const absl::Status& s) { + if (cancellation_manager == nullptr) { + delete opts->cancellation_manager; + } + created_rendezvous->Unref(); + done(s); + }); } tensorflow::Device* KernelAndDeviceOp::OutputDevice(int idx) const { diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h index 9a9d6193f30f08..c13e1524aeb788 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.h +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h @@ -89,7 +89,8 @@ class EagerKernelArgs : public FunctionArgsInterface { bool HasRemoteOrPackedInputs() const override { return false; }; TensorValue* MutableInput(int i) { return &tensor_args_[i]; } - Status GetLocalArg(const FunctionArgIndex& index, Tensor* val) const override; + absl::Status GetLocalArg(const FunctionArgIndex& index, + Tensor* val) const override; std::vector GetLocalTensors() const override; @@ -118,7 +119,7 @@ class KernelAndDevice : public core::RefCounted { // // The provided FunctionLibraryRuntime MUST outlive all calls to // Run() on the returned KernelAndDevice. - virtual Status Init( + virtual absl::Status Init( bool log_device_placement, const NodeDef& ndef, GraphCollector* graph_collector, const absl::optional& eager_func_params) = 0; @@ -146,7 +147,7 @@ class KernelAndDevice : public core::RefCounted { virtual bool IsCrossProcess() { return false; } // TODO(ashankar): Handle list-valued inputs. - virtual Status Run( + virtual absl::Status Run( ScopedStepContainer* step_container, const EagerKernelArgs& inputs, std::vector* outputs, CancellationManager* cancellation_manager, @@ -222,12 +223,12 @@ class KernelAndDeviceOp final : public KernelAndDevice { ~KernelAndDeviceOp() override = default; - Status Init( + absl::Status Init( bool log_device_placement, const NodeDef& ndef, GraphCollector* graph_collector, const absl::optional& eager_func_params) override; - Status Run( + absl::Status Run( ScopedStepContainer* step_container, const EagerKernelArgs& inputs, std::vector* outputs, CancellationManager* cancellation_manager, @@ -325,17 +326,17 @@ class KernelAndDeviceFunc : public KernelAndDevice { bool IsCrossProcess() override { return is_cross_process_; } - Status InstantiateFunc( + absl::Status InstantiateFunc( bool log_device_placement, const NodeDef& ndef, GraphCollector* graph_collector, const absl::optional& eager_func_params); - Status Init( + absl::Status Init( bool log_device_placement, const NodeDef& ndef, GraphCollector* graph_collector, const absl::optional& eager_func_params) override; - Status Run( + absl::Status Run( ScopedStepContainer* step_container, const EagerKernelArgs& inputs, std::vector* outputs, CancellationManager* cancellation_manager, diff --git a/tensorflow/core/common_runtime/eager/placement_test.cc b/tensorflow/core/common_runtime/eager/placement_test.cc index 951759d78b4321..b89b9384ba7196 100644 --- a/tensorflow/core/common_runtime/eager/placement_test.cc +++ b/tensorflow/core/common_runtime/eager/placement_test.cc @@ -60,7 +60,7 @@ static Device* CreateDevice(const char* type, const char* name) { class FakeDevice : public Device { public: explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} - Status Sync() override { return absl::OkStatus(); } + absl::Status Sync() override { return absl::OkStatus(); } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } }; DeviceAttributes attr; @@ -126,7 +126,7 @@ TEST_F(PlacementTest, SelectDeviceExplicitHardPlacement) { requested.Clear(); NodeDef invalid_op = NDef("invalid_op", "InvalidOp", {}, {}); - Status status = context()->SelectDevice(requested, invalid_op, &dev); + absl::Status status = context()->SelectDevice(requested, invalid_op, &dev); LOG(ERROR) << status; EXPECT_TRUE(errors::IsNotFound(status)); EXPECT_TRUE( @@ -167,7 +167,7 @@ TEST_F(PlacementTest, SelectDeviceExplicitSoftPlacement) { requested.Clear(); NodeDef invalid_op = NDef("invalid_op", "InvalidOp", {}, {}); - Status status = context()->SelectDevice(requested, invalid_op, &dev); + absl::Status status = context()->SelectDevice(requested, invalid_op, &dev); LOG(ERROR) << status; EXPECT_TRUE(errors::IsNotFound(status)); EXPECT_TRUE( diff --git a/tensorflow/core/common_runtime/eager/placement_utils.cc b/tensorflow/core/common_runtime/eager/placement_utils.cc index 4562c3287af6ee..3cbc844dddbb74 100644 --- a/tensorflow/core/common_runtime/eager/placement_utils.cc +++ b/tensorflow/core/common_runtime/eager/placement_utils.cc @@ -51,8 +51,8 @@ static bool IsPinnableOp(StringPiece op_name) { } // Validate if the remote device with the given incarnation is valid in the // remote device manager of the current eager context. -static Status ValidateTensorHandleRemoteDevice(EagerContext* ctx, - int64_t device_incarnation) { +static absl::Status ValidateTensorHandleRemoteDevice( + EagerContext* ctx, int64_t device_incarnation) { if (ctx->remote_device_mgr()->ContainsDevice(device_incarnation)) { return absl::OkStatus(); } @@ -69,7 +69,7 @@ bool IsColocationExempt(StringPiece op_name) { bool IsFunction(StringPiece op_name) { const OpDef* op_def = nullptr; - Status s = OpDefForOp(string(op_name), &op_def); + absl::Status s = OpDefForOp(string(op_name), &op_def); if (!s.ok()) { if (!absl::IsNotFound(s)) { LOG(WARNING) << "Looking up OpDef failed with error: " << s; @@ -80,7 +80,7 @@ bool IsFunction(StringPiece op_name) { return false; } -Status MaybePinSmallOpsToCpu( +absl::Status MaybePinSmallOpsToCpu( bool* result, StringPiece op_name, absl::Span args, StringPiece cpu_device_name) { @@ -100,7 +100,7 @@ Status MaybePinSmallOpsToCpu( int i = 0; for (auto* arg : args) { - Status s; + absl::Status s; const char* device_name = arg->DeviceName(&s); DataType dtype = arg->DataType(); TF_RETURN_IF_ERROR(s); @@ -137,7 +137,8 @@ Status MaybePinSmallOpsToCpu( return absl::OkStatus(); } -Status MaybePinToResourceDevice(Device** device, const EagerOperation& op) { +absl::Status MaybePinToResourceDevice(Device** device, + const EagerOperation& op) { if (op.colocation_exempt()) { return absl::OkStatus(); } diff --git a/tensorflow/core/common_runtime/eager/placement_utils.h b/tensorflow/core/common_runtime/eager/placement_utils.h index b182cf684027a4..9064b86314aed7 100644 --- a/tensorflow/core/common_runtime/eager/placement_utils.h +++ b/tensorflow/core/common_runtime/eager/placement_utils.h @@ -33,7 +33,7 @@ bool IsFunction(StringPiece op_name); // Pin the op to cpu if all op inputs are on the CPU, small (<64 elements) and // integers (int32/int64). This can be disabled by setting the environment // variable "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING" to "0" or "false". -Status MaybePinSmallOpsToCpu( +absl::Status MaybePinSmallOpsToCpu( bool* result, StringPiece op_name, absl::Span args, StringPiece cpu_device_name); @@ -41,7 +41,8 @@ Status MaybePinSmallOpsToCpu( // If a resource touching input is specified, all resource-touching ops run in // the device the resource is, regardless of anything else that has been // specified. This is identical to the graph mode behavior. -Status MaybePinToResourceDevice(Device** device, const EagerOperation& op); +absl::Status MaybePinToResourceDevice(Device** device, + const EagerOperation& op); } // namespace eager } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/placement_utils_test.cc b/tensorflow/core/common_runtime/eager/placement_utils_test.cc index 6220cc95778d66..c543b9475a072c 100644 --- a/tensorflow/core/common_runtime/eager/placement_utils_test.cc +++ b/tensorflow/core/common_runtime/eager/placement_utils_test.cc @@ -57,7 +57,7 @@ static Device* CreateDevice(const char* type, const char* name, public: explicit FakeDevice(const DeviceAttributes& attr, bool is_local) : Device(nullptr, attr), is_local_(is_local) {} - Status Sync() override { return absl::OkStatus(); } + absl::Status Sync() override { return absl::OkStatus(); } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } bool IsLocal() const override { return is_local_; } @@ -211,7 +211,7 @@ TEST(PlacementUtilsTest, MaybePinToResourceDevice_OtherDevice) { Device* device1 = remote_device_mgr->ListDevices().at(0); - Status s; + absl::Status s; std::unique_ptr composite_device = CompositeDevice::MakeDevice({device0->name(), device1->name()}, /*unique_device_id=*/0, diff --git a/tensorflow/core/common_runtime/eager/shape_inference.cc b/tensorflow/core/common_runtime/eager/shape_inference.cc index f730bd62f312dd..8bdd7b1dda4831 100644 --- a/tensorflow/core/common_runtime/eager/shape_inference.cc +++ b/tensorflow/core/common_runtime/eager/shape_inference.cc @@ -25,7 +25,7 @@ limitations under the License. namespace tensorflow { namespace eager { -Status RunShapeInference( +absl::Status RunShapeInference( const NodeDef& ndef, const FunctionLibraryDefinition& lib_def, const absl::InlinedVector& inputs, const absl::InlinedVector& retvals) { diff --git a/tensorflow/core/common_runtime/eager/shape_inference.h b/tensorflow/core/common_runtime/eager/shape_inference.h index 51536b3744109f..be386f978c341b 100644 --- a/tensorflow/core/common_runtime/eager/shape_inference.h +++ b/tensorflow/core/common_runtime/eager/shape_inference.h @@ -25,7 +25,7 @@ limitations under the License. namespace tensorflow { namespace eager { -Status RunShapeInference( +absl::Status RunShapeInference( const NodeDef& ndef, const FunctionLibraryDefinition& lib_def, const absl::InlinedVector& inputs, const absl::InlinedVector& retvals); diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index 24a0258fa3db33..d42b0319d10229 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -83,29 +83,31 @@ TensorHandle::PackedTensorHandleData::~PackedTensorHandleData() { } } -Status TensorHandle::PackedTensorHandleData::Shape(TensorShape* shape) const { +absl::Status TensorHandle::PackedTensorHandleData::Shape( + TensorShape* shape) const { *shape = shape_; return absl::OkStatus(); } -Status TensorHandle::PackedTensorHandleData::NumDims(int* num_dims) const { +absl::Status TensorHandle::PackedTensorHandleData::NumDims( + int* num_dims) const { *num_dims = shape_.dims(); return absl::OkStatus(); } -Status TensorHandle::PackedTensorHandleData::Dim(int dim_index, - int64_t* dim) const { +absl::Status TensorHandle::PackedTensorHandleData::Dim(int dim_index, + int64_t* dim) const { *dim = shape_.dim_size(dim_index); return absl::OkStatus(); } -Status TensorHandle::PackedTensorHandleData::NumElements( +absl::Status TensorHandle::PackedTensorHandleData::NumElements( int64_t* num_elements) const { *num_elements = shape_.num_elements(); return absl::OkStatus(); } -Status TensorHandle::PackedTensorHandleData::Unprotect() { +absl::Status TensorHandle::PackedTensorHandleData::Unprotect() { for (auto* handle : handles_) { TF_RETURN_IF_ERROR( std::visit([](auto& data) { return data.Unprotect(); }, handle->data_)); @@ -128,7 +130,7 @@ bool TensorHandle::PackedTensorHandleData::IsReady() const { return true; } -Status TensorHandle::PackedTensorHandleData::WaitReady( +absl::Status TensorHandle::PackedTensorHandleData::WaitReady( const char* caller) const { { tf_shared_lock l(mu_); @@ -142,7 +144,7 @@ Status TensorHandle::PackedTensorHandleData::WaitReady( return absl::OkStatus(); } -void TensorHandle::PackedTensorHandleData::Poison(Status status) { +void TensorHandle::PackedTensorHandleData::Poison(absl::Status status) { mutex_lock l(mu_); is_poisoned_ = status; } @@ -162,7 +164,7 @@ int TensorHandle::PackedTensorHandleData::NumPackedHandles() const { return handles_.size(); } -Status TensorHandle::PackedTensorHandleData::ExtractPackedHandle( +absl::Status TensorHandle::PackedTensorHandleData::ExtractPackedHandle( const int index, TensorHandle** handle) const { if (index < 0 || index >= handles_.size()) { return errors::InvalidArgument("Expect an index within [0, ", @@ -177,7 +179,7 @@ void TensorHandle::SetResourceHandleDtypeAndShape( handle_dtypes_and_shapes_ = std::move(dtypes_and_shapes); } -Status TensorHandle::GetResourceHandleDtypesAndShapes( +absl::Status TensorHandle::GetResourceHandleDtypesAndShapes( std::vector* result) { if (dtype != DT_RESOURCE) { return errors::InvalidArgument( @@ -209,8 +211,8 @@ int TensorHandle::NumPackedHandles() const { return std::get(data_).NumPackedHandles(); } -Status TensorHandle::ExtractPackedHandle(const int index, - TensorHandle** handle) const { +absl::Status TensorHandle::ExtractPackedHandle(const int index, + TensorHandle** handle) const { if (Type() != PACKED) { return errors::Internal("Invalid ExtractPackedHandleOnDevice call on a", TypeString(), " handle: ", this); @@ -303,12 +305,10 @@ TensorHandle::TensorHandle(Device* d, Device* op_device, << " device: " << SafeDeviceDebugString(device_); } -Status TensorHandle::CreatePackedHandle(std::vector&& handles, - const tensorflow::DataType dtype, - const tensorflow::TensorShape& shape, - const string& device_name, - EagerContext* ctx, - TensorHandle** packed_handle) { +absl::Status TensorHandle::CreatePackedHandle( + std::vector&& handles, const tensorflow::DataType dtype, + const tensorflow::TensorShape& shape, const string& device_name, + EagerContext* ctx, TensorHandle** packed_handle) { if (handles.empty()) { return errors::InvalidArgument("Handles should not be empty."); } @@ -335,9 +335,9 @@ Status TensorHandle::CreatePackedHandle(std::vector&& handles, return absl::OkStatus(); } -Status TensorHandle::CreatePackedHandle(std::vector&& handles, - EagerContext* ctx, - TensorHandle** packed_handle) { +absl::Status TensorHandle::CreatePackedHandle( + std::vector&& handles, EagerContext* ctx, + TensorHandle** packed_handle) { if (handles.empty()) { return errors::InvalidArgument("Handles should not be empty."); } @@ -434,7 +434,7 @@ bool TensorHandle::IsReady() const { return std::visit([](auto& data) { return data.IsReady(); }, data_); } -Status TensorHandle::WaitReady(const char* caller) const { +absl::Status TensorHandle::WaitReady(const char* caller) const { return std::visit([caller](auto& data) { return data.WaitReady(caller); }, data_); } @@ -459,7 +459,7 @@ string TensorHandle::TypeString() const { } } -Status TensorHandle::Tensor(const tensorflow::Tensor** t) const { +absl::Status TensorHandle::Tensor(const tensorflow::Tensor** t) const { DVLOG(3) << "Tensor on TensorHandle: " << this; if (Type() != LOCAL) { @@ -471,8 +471,8 @@ Status TensorHandle::Tensor(const tensorflow::Tensor** t) const { return data.Tensor(t); } -Status TensorHandle::TensorFromDevice(const Device* d, - const tensorflow::Tensor** t) const { +absl::Status TensorHandle::TensorFromDevice( + const Device* d, const tensorflow::Tensor** t) const { DVLOG(3) << "TensorFromDevice on TensorHandle: " << this << " device: " << d; if (d == device_) { @@ -496,7 +496,8 @@ Status TensorHandle::TensorFromDevice(const Device* d, return mirror.Tensor(t); } -Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) { +absl::Status TensorHandle::TensorValue(const Device* d, + tensorflow::TensorValue* t) { DVLOG(3) << "TensorValue on TensorHandle: " << this << " device: " << d; if (d == device_) { @@ -520,7 +521,7 @@ Status TensorHandle::TensorValue(const Device* d, tensorflow::TensorValue* t) { return mirror.TensorValue(t); } -Status TensorHandle::WaitUnknownDevice() const { +absl::Status TensorHandle::WaitUnknownDevice() const { if (unknown_device_) { TF_RETURN_IF_ERROR(std::visit( [](auto& data) { @@ -535,7 +536,7 @@ Device* TensorHandle::DeviceOrHostCPU(const EagerContext& ctx) const { return (device_ == nullptr) ? ctx.HostCPU() : device_; } -Status TensorHandle::Shape(tensorflow::TensorShape* shape) { +absl::Status TensorHandle::Shape(tensorflow::TensorShape* shape) { if (!IsReady() && inference_shape_.IsFullyDefined()) { bool fill = inference_shape_.AsTensorShape(shape); DCHECK(fill); @@ -545,7 +546,7 @@ Status TensorHandle::Shape(tensorflow::TensorShape* shape) { } } -Status TensorHandle::InferenceShape( +absl::Status TensorHandle::InferenceShape( shape_inference::InferenceContext* const inference_context, shape_inference::ShapeHandle* shape_handle) { if (IsReady()) { @@ -593,7 +594,7 @@ void TensorHandle::SetInferenceShape( TF_DCHECK_OK(s); } -Status TensorHandle::CopyInferenceShape(TensorHandle* other) { +absl::Status TensorHandle::CopyInferenceShape(TensorHandle* other) { if (IsReady()) { return absl::OkStatus(); } @@ -607,7 +608,7 @@ Status TensorHandle::CopyInferenceShape(TensorHandle* other) { return absl::OkStatus(); } -Status TensorHandle::Shape(tensorflow::PartialTensorShape* shape) const { +absl::Status TensorHandle::Shape(tensorflow::PartialTensorShape* shape) const { DCHECK(shape != nullptr); if (!IsReady() && !inference_shape_.unknown_rank()) { *shape = inference_shape_; @@ -616,7 +617,7 @@ Status TensorHandle::Shape(tensorflow::PartialTensorShape* shape) const { auto result = std::visit( [](auto& data) { TensorShape shape; - Status s = data.Shape(&shape); + absl::Status s = data.Shape(&shape); return std::make_pair(shape, s); }, data_); @@ -626,7 +627,7 @@ Status TensorHandle::Shape(tensorflow::PartialTensorShape* shape) const { return absl::OkStatus(); } -Status TensorHandle::NumDims(int* num_dims) const { +absl::Status TensorHandle::NumDims(int* num_dims) const { DCHECK(num_dims != nullptr); if (!IsReady() && !inference_shape_.unknown_rank()) { *num_dims = inference_shape_.dims(); @@ -637,7 +638,7 @@ Status TensorHandle::NumDims(int* num_dims) const { } } -Status TensorHandle::Dim(int dim_index, int64_t* dim) const { +absl::Status TensorHandle::Dim(int dim_index, int64_t* dim) const { DCHECK(dim != nullptr); if (!IsReady() && !inference_shape_.unknown_rank() && inference_shape_.dim_size(dim_index) != -1) { @@ -650,7 +651,7 @@ Status TensorHandle::Dim(int dim_index, int64_t* dim) const { } } -Status TensorHandle::NumElements(int64_t* num_elements) const { +absl::Status TensorHandle::NumElements(int64_t* num_elements) const { DCHECK(num_elements != nullptr); if (!IsReady() && inference_shape_.IsFullyDefined()) { *num_elements = inference_shape_.num_elements(); @@ -662,7 +663,7 @@ Status TensorHandle::NumElements(int64_t* num_elements) const { } } -Status TensorHandle::Unprotect(const Device* d) { +absl::Status TensorHandle::Unprotect(const Device* d) { DVLOG(3) << "Unprotect on TensorHandle: " << this << " device: " << d; if (d == device_) { @@ -688,7 +689,7 @@ bool TensorHandle::HasLocalMirror(const Device* d) const { return local_mirrors_.find(d) != local_mirrors_.end(); } -Status TensorHandle::AddEmptyLocalMirror(const Device* d) { +absl::Status TensorHandle::AddEmptyLocalMirror(const Device* d) { DVLOG(3) << "AddEmptyLocalMirror on TensorHandle: " << this << " device: " << d; @@ -708,8 +709,10 @@ Status TensorHandle::AddEmptyLocalMirror(const Device* d) { } #if !defined(IS_MOBILE_PLATFORM) -Status TensorHandle::RemoteAddress(const Device* d, const bool wait_until_ready, - int64_t* op_id, int32* output_num) const { +absl::Status TensorHandle::RemoteAddress(const Device* d, + const bool wait_until_ready, + int64_t* op_id, + int32* output_num) const { DVLOG(3) << "RemoteAddress on TensorHandle: " << this << " device: " << d << " " << d->name(); @@ -789,10 +792,11 @@ bool TensorHandle::HasResourceShapeMirror(const Device* d, return false; } -Status TensorHandle::AddUnshapedRemoteMirror(const Device* d, int64_t op_id, - int output_num, - const string& remote_task, - EagerContext* ctx) { +absl::Status TensorHandle::AddUnshapedRemoteMirror(const Device* d, + int64_t op_id, + int output_num, + const string& remote_task, + EagerContext* ctx) { DVLOG(3) << "AddUnshapedRemoteMirror on TensorHandle: " << this << " device: " << d << " " << d->name() << " op_id: " << op_id << " output_num: " << output_num; @@ -816,8 +820,9 @@ Status TensorHandle::AddUnshapedRemoteMirror(const Device* d, int64_t op_id, return absl::OkStatus(); } -Status TensorHandle::AddResourceShapeMirror(const Device* d, int64_t op_id, - int output_num, EagerContext* ctx) { +absl::Status TensorHandle::AddResourceShapeMirror(const Device* d, + int64_t op_id, int output_num, + EagerContext* ctx) { DVLOG(3) << "AddResourceShapeMirror on TensorHandle: " << this; mutex_lock l(mu_); @@ -848,15 +853,16 @@ Status TensorHandle::AddResourceShapeMirror(const Device* d, int64_t op_id, return absl::OkStatus(); } -Status TensorHandle::SetRemoteShape(const TensorShape& shape, const Device* d, - uint64 context_view_id) { +absl::Status TensorHandle::SetRemoteShape(const TensorShape& shape, + const Device* d, + uint64 context_view_id) { return SetRemoteShapeAndDevice(shape, d, context_view_id, /*op_device=*/""); } -Status TensorHandle::SetRemoteShapeAndDevice(const TensorShape& shape, - const Device* d, - uint64 context_view_id, - string op_device) { +absl::Status TensorHandle::SetRemoteShapeAndDevice(const TensorShape& shape, + const Device* d, + uint64 context_view_id, + string op_device) { DVLOG(3) << "SetRemoteShape on TensorHandle: " << this << " device: " << d << " " << d->name(); @@ -940,7 +946,7 @@ Status TensorHandle::SetRemoteShapeAndDevice(const TensorShape& shape, } } -void TensorHandle::PoisonRemote(Status status, const Device* d, +void TensorHandle::PoisonRemote(absl::Status status, const Device* d, uint64 context_view_id) { DVLOG(3) << "PoisonRemote on TensorHandle: " << this << " device: " << d << " " << d->name(); @@ -963,8 +969,8 @@ void TensorHandle::PoisonRemote(Status status, const Device* d, } #endif -Status TensorHandle::AddLocalMirror(tensorflow::Tensor&& tensor, - const Device* d) { +absl::Status TensorHandle::AddLocalMirror(tensorflow::Tensor&& tensor, + const Device* d) { if (d == device_) { return errors::Internal( "Local mirror assign conflicts with primary device."); @@ -981,7 +987,7 @@ Status TensorHandle::AddLocalMirror(tensorflow::Tensor&& tensor, return absl::OkStatus(); } -Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) { +absl::Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) { DVLOG(3) << "SetTensor on TensorHandle: " << this << " device: " << d; if (d == device_) { @@ -1008,7 +1014,7 @@ Status TensorHandle::SetTensor(tensorflow::Tensor&& t, const Device* d) { return absl::OkStatus(); } -void TensorHandle::Poison(Status status, const Device* d) { +void TensorHandle::Poison(absl::Status status, const Device* d) { DVLOG(3) << "Poison on TensorHandle: " << this << " device: " << d; if (d == device_) { @@ -1026,9 +1032,9 @@ void TensorHandle::Poison(Status status, const Device* d) { } } -Status TensorHandle::CopyToDevice(const EagerContext& ctx, - tensorflow::Device* d, - tensorflow::Tensor* output) const { +absl::Status TensorHandle::CopyToDevice(const EagerContext& ctx, + tensorflow::Device* d, + tensorflow::Tensor* output) const { tensorflow::Device* dstd = (d == nullptr) ? ctx.HostCPU() : d; tensorflow::Device* srcd = DeviceOrHostCPU(ctx); const bool dst_cpu = dstd->tensorflow_accelerator_device_info() == nullptr; @@ -1088,12 +1094,12 @@ Status TensorHandle::CopyToDevice(const EagerContext& ctx, // nothing to do with this tensor to complete). TF_RETURN_IF_ERROR(srcd->Sync()); tensorflow::Notification n; - tensorflow::Status status; + absl::Status status; tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context, srcd, dstd, tensorflow::AllocatorAttributes(), tensorflow::AllocatorAttributes(), src, &dst, 0 /*dev_to_dev_stream_index*/, - [&status, &n](const tensorflow::Status& s) { + [&status, &n](const absl::Status& s) { status = s; n.Notify(); }); @@ -1117,27 +1123,27 @@ Device* GetResourceDevice(const ResourceHandle& handle, EagerContext* ctx) { return device; } -const char* TensorHandle::DeviceName(Status* status) const { +const char* TensorHandle::DeviceName(absl::Status* status) const { status->Update(WaitUnknownDevice()); tensorflow::Device* d = op_device(); return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" : d->name().c_str(); } -const char* TensorHandle::BackingDeviceName(Status* status) const { +const char* TensorHandle::BackingDeviceName(absl::Status* status) const { status->Update(WaitUnknownDevice()); tensorflow::Device* d = device(); return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" : d->name().c_str(); } -const char* TensorHandle::DeviceType(Status* status) const { +const char* TensorHandle::DeviceType(absl::Status* status) const { status->Update(WaitUnknownDevice()); tensorflow::Device* d = op_device(); return (d == nullptr) ? "CPU" : d->parsed_name().type.c_str(); } -int TensorHandle::DeviceId(Status* status) const { +int TensorHandle::DeviceId(absl::Status* status) const { status->Update(WaitUnknownDevice()); tensorflow::Device* d = op_device(); return (d == nullptr) ? 0 : d->parsed_name().id; diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h index b3ace87ce236f0..ca60815d76ec9e 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -94,14 +94,15 @@ class TensorHandle : public ImmediateExecutionTensorHandle { // The new tensor handle shares ownership of the given handle: their reference // count will be increased by one after a call to `CreatePackedHandle`. // TODO(b/170414377): Use `TensorHandlePtr` instead. - static Status CreatePackedHandle(std::vector&& handles, - tensorflow::DataType dtype, - const tensorflow::TensorShape& shape, - const string& device_name, EagerContext* ctx, - TensorHandle** packed_handle); - static Status CreatePackedHandle(std::vector&& handles, - EagerContext* ctx, - TensorHandle** packed_handle); + static absl::Status CreatePackedHandle(std::vector&& handles, + tensorflow::DataType dtype, + const tensorflow::TensorShape& shape, + const string& device_name, + EagerContext* ctx, + TensorHandle** packed_handle); + static absl::Status CreatePackedHandle(std::vector&& handles, + EagerContext* ctx, + TensorHandle** packed_handle); #if !defined(IS_MOBILE_PLATFORM) // An unshaped remote handle refers to a tensor on a remote worker. It's not @@ -129,16 +130,16 @@ class TensorHandle : public ImmediateExecutionTensorHandle { void Release(); tensorflow::DataType DataType() const override; - Status Shape(tensorflow::PartialTensorShape* shape) const override; - Status NumDims(int* num_dims) const override; - Status NumElements(int64_t* num_elements) const override; - Status Dim(int dim_index, int64_t* dim) const override; + absl::Status Shape(tensorflow::PartialTensorShape* shape) const override; + absl::Status NumDims(int* num_dims) const override; + absl::Status NumElements(int64_t* num_elements) const override; + absl::Status Dim(int dim_index, int64_t* dim) const override; - const char* DeviceName(Status* status) const override; - const char* BackingDeviceName(Status* status) const override; - const char* DeviceType(Status* status) const override; - int DeviceId(Status* status) const override; - AbstractTensorInterface* Resolve(Status* status) override; + const char* DeviceName(absl::Status* status) const override; + const char* BackingDeviceName(absl::Status* status) const override; + const char* DeviceType(absl::Status* status) const override; + int DeviceId(absl::Status* status) const override; + AbstractTensorInterface* Resolve(absl::Status* status) override; // Subclasses may return True to instruct the string formatter // to use SummarizeValue instead of the NumPy formatter. @@ -147,16 +148,17 @@ class TensorHandle : public ImmediateExecutionTensorHandle { } // Return the Tensor from the default device. - Status Tensor(const tensorflow::Tensor** t) const; + absl::Status Tensor(const tensorflow::Tensor** t) const; // Return the Tensor from the specified device which could be either the // default device or a local mirror. The device pointer should be nullptr if // requesting the HostCPU. - Status TensorFromDevice(const Device* d, const tensorflow::Tensor** t) const; + absl::Status TensorFromDevice(const Device* d, + const tensorflow::Tensor** t) const; // Return the TensorValue from the specified device which could be either the // default device or a local mirror. The device pointer should be nullptr if // requesting the HostCPU. - Status TensorValue(const Device* d, tensorflow::TensorValue* t); + absl::Status TensorValue(const Device* d, tensorflow::TensorValue* t); Device* device() const { return device_; } Device* op_device() const { return op_device_; } @@ -167,13 +169,13 @@ class TensorHandle : public ImmediateExecutionTensorHandle { // If the devices are unknown at creation time, block until the actual devices // are set (data is ready). - Status WaitUnknownDevice() const; + absl::Status WaitUnknownDevice() const; Device* DeviceOrHostCPU(const EagerContext& ctx) const; - Status Shape(tensorflow::TensorShape* shape); + absl::Status Shape(tensorflow::TensorShape* shape); - Status Unprotect(const Device* d); + absl::Status Unprotect(const Device* d); // Checks if a mirror tensor exists for the specified device. Mirrors are only // maintained for local devices, like CPUs & GPUs. Note a mirror may be empty, @@ -181,25 +183,27 @@ class TensorHandle : public ImmediateExecutionTensorHandle { bool HasLocalMirror(const Device* d) const; // Add an empty mirror placeholder for the specified device. The expectation // is this will be populated by a call to SetTensor. - Status AddEmptyLocalMirror(const Device* d); + absl::Status AddEmptyLocalMirror(const Device* d); // Add a local mirror. This will fail if an empty local mirror was previously // added. For that case, SetTensor should be used instead. - Status AddLocalMirror(tensorflow::Tensor&& tensor, const Device* d); + absl::Status AddLocalMirror(tensorflow::Tensor&& tensor, const Device* d); #if !defined(IS_MOBILE_PLATFORM) bool HasRemoteMirror(const Device* d, uint64 context_view_id) const; bool HasResourceShapeMirror(const Device* d, uint64 context_view_id) const; - Status AddUnshapedRemoteMirror(const Device* d, int64_t op_id, int output_num, - const string& remote_task, EagerContext* ctx); - Status AddResourceShapeMirror(const Device* d, int64_t op_id, int output_num, - EagerContext* ctx); + absl::Status AddUnshapedRemoteMirror(const Device* d, int64_t op_id, + int output_num, + const string& remote_task, + EagerContext* ctx); + absl::Status AddResourceShapeMirror(const Device* d, int64_t op_id, + int output_num, EagerContext* ctx); // Return the op_id and output num if the handle refers to a remote tensor. // If wait_until_ready is true, block until the remote tensor is ready on the // given remote worker. - Status RemoteAddress(const Device* d, bool wait_until_ready, int64_t* op_id, - int32* output_num) const; + absl::Status RemoteAddress(const Device* d, bool wait_until_ready, + int64_t* op_id, int32* output_num) const; // Called on an async remote tensor once it's shape has been determined. This // transitions the tensor handle from a non-ready to a ready state by @@ -208,43 +212,46 @@ class TensorHandle : public ImmediateExecutionTensorHandle { // creating a TensorHandle (e.g. a remote output of a remote function). // This method or Poison must be called exactly once for remote tensors that // were created without a known shape. - Status SetRemoteShape(const TensorShape& shape, const Device* d, - uint64 context_view_id); + absl::Status SetRemoteShape(const TensorShape& shape, const Device* d, + uint64 context_view_id); // If op_device is not empty, reset the devices of a remote tensor which is // created without known devices (e.g. function outputs). - Status SetRemoteShapeAndDevice(const TensorShape& shape, const Device* d, - uint64 context_view_id, string op_device); + absl::Status SetRemoteShapeAndDevice(const TensorShape& shape, + const Device* d, uint64 context_view_id, + string op_device); // Poisons either this handle or a remote mirror with error `status`. // Poisoning means that the handle will become ready and methods trying // to access the remote shape will return this error `status`. // Exactly one of SetRemoteShape or PoisonRemote methods must be called on a // unshaped handle on a remote device. - void PoisonRemote(Status status, const Device* d, uint64 context_view_id); + void PoisonRemote(absl::Status status, const Device* d, + uint64 context_view_id); #endif // Sets the `tensor` for this async non-ready handle making it ready. // This method or Poison must be called exactly once for non-ready async // handles to make them ready. - Status SetTensor(tensorflow::Tensor&& tensor, const Device* d); + absl::Status SetTensor(tensorflow::Tensor&& tensor, const Device* d); // Poisons either this handle or a local mirror with error `status`. // Poisoning means that the handle will become ready and methods trying // to access the actual tensor or shape will return this error `status`. // Exactly one of SetTensor or Poison methods must be called on a non-ready // tensor for a specific device. - void Poison(Status status, const Device* d); + void Poison(absl::Status status, const Device* d); // TODO(b/154282629): Consider moving it to EagerContext. // Copies to the tensor on the given device `d`, or to host iff `d` is null. - Status CopyToDevice(const EagerContext& ctx, tensorflow::Device* d, - tensorflow::Tensor* output) const; + absl::Status CopyToDevice(const EagerContext& ctx, tensorflow::Device* d, + tensorflow::Tensor* output) const; - Status InferenceShape(shape_inference::InferenceContext* inference_context, - shape_inference::ShapeHandle* shape_handle); + absl::Status InferenceShape( + shape_inference::InferenceContext* inference_context, + shape_inference::ShapeHandle* shape_handle); void SetInferenceShape(shape_inference::InferenceContext* inference_context, const shape_inference::ShapeHandle& shape_handle); - Status CopyInferenceShape(TensorHandle* other); + absl::Status CopyInferenceShape(TensorHandle* other); // dtype for the handle. It must be the same as t.dtype() once the handle is // ready. @@ -260,14 +267,14 @@ class TensorHandle : public ImmediateExecutionTensorHandle { // If this TensorHandle is 1) a local tensor, and 2) a resource handle, // return data types and shapes of the underlying resource. - Status GetResourceHandleDtypesAndShapes( + absl::Status GetResourceHandleDtypesAndShapes( std::vector* result); // Returns the number of packed handles. 0 if the handle type is not PACKED. int NumPackedHandles() const; // It's called on a packed TensorHandle. Extract a handle with the given // index. - Status ExtractPackedHandle(int index, TensorHandle** handle) const; + absl::Status ExtractPackedHandle(int index, TensorHandle** handle) const; // For LLVM style RTTI. static bool classof(const AbstractTensorHandle* ptr) { @@ -292,7 +299,7 @@ class TensorHandle : public ImmediateExecutionTensorHandle { // to either SetTensor or SetRemoteShape which replaces the underlying data // with a ready version of the tensor handle data. bool IsReady() const; - Status WaitReady(const char* caller) const; + absl::Status WaitReady(const char* caller) const; tensorflow::Device* device_; @@ -356,20 +363,20 @@ class TensorHandle : public ImmediateExecutionTensorHandle { ~PackedTensorHandleData(); - Status Shape(TensorShape* shape) const; - Status NumDims(int* num_dims) const; - Status Dim(int dim_index, int64_t* dim) const; - Status NumElements(int64_t* num_elements) const; - Status Unprotect(); + absl::Status Shape(TensorShape* shape) const; + absl::Status NumDims(int* num_dims) const; + absl::Status Dim(int dim_index, int64_t* dim) const; + absl::Status NumElements(int64_t* num_elements) const; + absl::Status Unprotect(); bool IsReady() const; - Status WaitReady(const char* caller) const; - void Poison(Status status); + absl::Status WaitReady(const char* caller) const; + void Poison(absl::Status status); string DebugString() const; // Number of packed handles. int NumPackedHandles() const; // Extract a handle on the given index. - Status ExtractPackedHandle(int index, TensorHandle** handle) const; + absl::Status ExtractPackedHandle(int index, TensorHandle** handle) const; private: // TODO(b/170414377): Use `TensorHandlePtr` instead. @@ -377,7 +384,7 @@ class TensorHandle : public ImmediateExecutionTensorHandle { const TensorShape shape_; mutable mutex mu_; - Status is_poisoned_ TF_GUARDED_BY(mu_); + absl::Status is_poisoned_ TF_GUARDED_BY(mu_); }; // Does not need synchronization because it can be accessed only after diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_data.cc b/tensorflow/core/common_runtime/eager/tensor_handle_data.cc index 07f7051f5010a4..2212b19db9c683 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle_data.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle_data.cc @@ -24,7 +24,7 @@ limitations under the License. namespace tensorflow { -Status LocalTensorHandleData::Tensor(const tensorflow::Tensor** t) const { +absl::Status LocalTensorHandleData::Tensor(const tensorflow::Tensor** t) const { TF_RETURN_IF_ERROR(WaitReady("Tensor")); *t = &tensor_; @@ -32,7 +32,7 @@ Status LocalTensorHandleData::Tensor(const tensorflow::Tensor** t) const { return absl::OkStatus(); } -Status LocalTensorHandleData::TensorValue(tensorflow::TensorValue* t) { +absl::Status LocalTensorHandleData::TensorValue(tensorflow::TensorValue* t) { TF_RETURN_IF_ERROR(WaitReady("TensorValue")); tensorflow::Tensor& tensor = tensor_; @@ -41,7 +41,7 @@ Status LocalTensorHandleData::TensorValue(tensorflow::TensorValue* t) { return absl::OkStatus(); } -Status LocalTensorHandleData::Shape(TensorShape* shape) const { +absl::Status LocalTensorHandleData::Shape(TensorShape* shape) const { TF_RETURN_IF_ERROR(WaitReady("Shape")); *shape = tensor_.shape(); @@ -49,7 +49,7 @@ Status LocalTensorHandleData::Shape(TensorShape* shape) const { return absl::OkStatus(); } -Status LocalTensorHandleData::NumDims(int* num_dims) const { +absl::Status LocalTensorHandleData::NumDims(int* num_dims) const { TF_RETURN_IF_ERROR(WaitReady("NumDims")); *num_dims = tensor_.dims(); @@ -57,7 +57,7 @@ Status LocalTensorHandleData::NumDims(int* num_dims) const { return absl::OkStatus(); } -Status LocalTensorHandleData::Dim(int dim_index, int64_t* dim) const { +absl::Status LocalTensorHandleData::Dim(int dim_index, int64_t* dim) const { TF_RETURN_IF_ERROR(WaitReady("Dim")); *dim = tensor_.dim_size(dim_index); @@ -65,7 +65,7 @@ Status LocalTensorHandleData::Dim(int dim_index, int64_t* dim) const { return absl::OkStatus(); } -Status LocalTensorHandleData::NumElements(int64_t* num_elements) const { +absl::Status LocalTensorHandleData::NumElements(int64_t* num_elements) const { TF_RETURN_IF_ERROR(WaitReady("NumElements")); *num_elements = tensor_.NumElements(); @@ -73,7 +73,7 @@ Status LocalTensorHandleData::NumElements(int64_t* num_elements) const { return absl::OkStatus(); } -Status LocalTensorHandleData::Unprotect() { +absl::Status LocalTensorHandleData::Unprotect() { if (!IsReady()) { return errors::Internal("Cannot unprotect a non-ready tensor"); } @@ -83,7 +83,7 @@ Status LocalTensorHandleData::Unprotect() { return absl::OkStatus(); } -Status LocalTensorHandleData::SetTensor(tensorflow::Tensor&& t) { +absl::Status LocalTensorHandleData::SetTensor(tensorflow::Tensor&& t) { DCHECK(!IsReady()) << "SetTensor is only called on non-ready handles."; tensor_ = std::move(t); @@ -109,7 +109,7 @@ void LocalTensorHandleData::BlockingControl::SetReady() { is_ready_ = true; } -Status LocalTensorHandleData::BlockingControl::WaitReady( +absl::Status LocalTensorHandleData::BlockingControl::WaitReady( const char* caller) const { tf_shared_lock l(mu_); if (!is_ready_) { @@ -124,7 +124,7 @@ Status LocalTensorHandleData::BlockingControl::WaitReady( return is_poisoned_; } -void LocalTensorHandleData::BlockingControl::Poison(Status status) { +void LocalTensorHandleData::BlockingControl::Poison(absl::Status status) { mutex_lock l(mu_); if (is_ready_) { LOG(ERROR) << "Poison can only be called on non-ready handle: " << this; diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_data.h b/tensorflow/core/common_runtime/eager/tensor_handle_data.h index cfab1b5a5dec0e..ed58e83a183bfe 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle_data.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle_data.h @@ -35,30 +35,30 @@ class LocalTensorHandleData { ctrl_(absl::in_place_type) {} // A local tensor handle should be able to satisfy all of these requests. - Status Tensor(const tensorflow::Tensor** t) const; - Status TensorValue(tensorflow::TensorValue* t); - Status Shape(TensorShape* shape) const; - Status NumDims(int* num_dims) const; - Status Dim(int dim_index, int64_t* dim) const; - Status NumElements(int64_t* num_elements) const; - Status Unprotect(); + absl::Status Tensor(const tensorflow::Tensor** t) const; + absl::Status TensorValue(tensorflow::TensorValue* t); + absl::Status Shape(TensorShape* shape) const; + absl::Status NumDims(int* num_dims) const; + absl::Status Dim(int dim_index, int64_t* dim) const; + absl::Status NumElements(int64_t* num_elements) const; + absl::Status Unprotect(); bool IsReady() const { return std::visit([](auto& data) { return data.IsReady(); }, ctrl_); } - Status WaitReady(const char* caller) const { + absl::Status WaitReady(const char* caller) const { return std::visit([caller](auto& data) { return data.WaitReady(caller); }, ctrl_); } - void Poison(Status status) { + void Poison(absl::Status status) { return std::visit([status](auto& data) { data.Poison(status); }, ctrl_); } - Status IsPoisoned() const { + absl::Status IsPoisoned() const { return std::visit([](auto& data) { return data.IsPoisoned(); }, ctrl_); } - Status SetTensor(tensorflow::Tensor&& t); + absl::Status SetTensor(tensorflow::Tensor&& t); string DebugString() const; @@ -80,9 +80,11 @@ class LocalTensorHandleData { class NonBlockingControl { public: bool IsReady() const { return true; } - Status WaitReady(const char* caller) const { return absl::OkStatus(); } - void Poison(Status status) {} - Status IsPoisoned() const { return absl::OkStatus(); } + absl::Status WaitReady(const char* caller) const { + return absl::OkStatus(); + } + void Poison(absl::Status status) {} + absl::Status IsPoisoned() const { return absl::OkStatus(); } }; class BlockingControl { @@ -92,9 +94,9 @@ class LocalTensorHandleData { return is_ready_; } void SetReady(); - Status WaitReady(const char* caller) const; - void Poison(Status status); - Status IsPoisoned() const { + absl::Status WaitReady(const char* caller) const; + void Poison(absl::Status status); + absl::Status IsPoisoned() const { tf_shared_lock l(mu_); return is_poisoned_; } @@ -102,7 +104,7 @@ class LocalTensorHandleData { private: mutable mutex mu_; bool is_ready_ TF_GUARDED_BY(mu_); - Status is_poisoned_ TF_GUARDED_BY(mu_); + absl::Status is_poisoned_ TF_GUARDED_BY(mu_); }; std::variant ctrl_; diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_data_test.cc b/tensorflow/core/common_runtime/eager/tensor_handle_data_test.cc index b5dcd52f8436c0..4a11dbd6ce5725 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle_data_test.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle_data_test.cc @@ -115,8 +115,7 @@ TEST(TensorHandleData, NonBlockingControlPoisonHandle) { LocalTensorHandleData handle_data(std::move(t)); TF_EXPECT_OK(handle_data.IsPoisoned()); - tensorflow::Status fake_failure_status(absl::StatusCode::kAborted, - "Fake failure."); + absl::Status fake_failure_status(absl::StatusCode::kAborted, "Fake failure."); handle_data.Poison(fake_failure_status); // NonBlockingControl can never poison the tensor. @@ -127,8 +126,7 @@ TEST(TensorHandleData, BlockingControlPoisonHandle) { LocalTensorHandleData handle_data; TF_EXPECT_OK(handle_data.IsPoisoned()); - tensorflow::Status fake_failure_status(absl::StatusCode::kAborted, - "Fake failure."); + absl::Status fake_failure_status(absl::StatusCode::kAborted, "Fake failure."); handle_data.Poison(fake_failure_status); EXPECT_THAT(handle_data.IsPoisoned(), diff --git a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc index 8938d7517d2a22..fb2853c35f5667 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle_test.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle_test.cc @@ -84,7 +84,7 @@ class FakeDevice : public Device { public: explicit FakeDevice(const DeviceAttributes& attr, bool is_local) : Device(nullptr, attr), is_local_(is_local) {} - Status Sync() override { return absl::OkStatus(); } + absl::Status Sync() override { return absl::OkStatus(); } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } bool IsLocal() const override { return is_local_; } @@ -141,7 +141,7 @@ class PackedTensorHandleTest : public ::testing::Test { } bool IsReady(TensorHandle* handle) const { return handle->IsReady(); } - Status WaitReady(TensorHandle* handle) const { + absl::Status WaitReady(TensorHandle* handle) const { return handle->WaitReady("Test"); } @@ -289,8 +289,7 @@ TEST_F(PackedTensorHandleTest, PoisonHandle) { TF_EXPECT_OK(WaitReady(packed_handle)); // Poisoning the handle will make WaitReady fail. - tensorflow::Status fake_failure_status(absl::StatusCode::kAborted, - "Fake failure."); + absl::Status fake_failure_status(absl::StatusCode::kAborted, "Fake failure."); packed_handle->Poison(fake_failure_status, packed_handle->device()); EXPECT_THAT(WaitReady(packed_handle), StatusIs(fake_failure_status.code(), @@ -450,7 +449,7 @@ TEST_F(RemoteTensorHandleTest, UnknownRemoteDevice) { Device* d2 = device_mgr.ListDevices().at(2); TF_ASSERT_OK(h->SetRemoteShapeAndDevice( shape, d1, context->GetContextViewId(), d2->name())); - Status s; + absl::Status s; EXPECT_EQ(h->BackingDeviceName(&s), d2->name()); TF_EXPECT_OK(s); EXPECT_EQ(h->device(), d2); @@ -486,8 +485,7 @@ TEST_F(RemoteTensorHandleTest, PoisonRemote) { absl::Cleanup h_cleanup = [&]() { h->Unref(); }; EXPECT_EQ(h->device(), d1); - tensorflow::Status fake_failure_status(absl::StatusCode::kAborted, - "Fake failure."); + absl::Status fake_failure_status(absl::StatusCode::kAborted, "Fake failure."); h->PoisonRemote(fake_failure_status, d1, context->GetContextViewId()); Device* d2 = device_mgr.ListDevices().at(2); @@ -533,8 +531,7 @@ TEST_F(RemoteTensorHandleTest, PoisonRemoteMirror) { TF_ASSERT_OK( h->AddUnshapedRemoteMirror(d2, op_id, output_num, remote_task, context)); - tensorflow::Status fake_failure_status(absl::StatusCode::kAborted, - "Fake failure."); + absl::Status fake_failure_status(absl::StatusCode::kAborted, "Fake failure."); h->PoisonRemote(fake_failure_status, d2, context->GetContextViewId()); EXPECT_THAT(h->SetRemoteShapeAndDevice(shape, d2, context->GetContextViewId(), @@ -820,7 +817,7 @@ TEST(TensorHandle_DeviceNameTest, OnLocalDevice) { TensorShape shape = {2}; Tensor tcpu(dtype, shape); Tensor tgpu(dtype, shape); - Status s; + absl::Status s; TensorHandle* th_cpu = TensorHandle::CreateLocalHandle(std::move(tcpu), dcpu, dcpu, dcpu, ctx); diff --git a/tensorflow/core/common_runtime/gpu/BUILD b/tensorflow/core/common_runtime/gpu/BUILD index d28c154ea7ab59..1a0aec3772b831 100644 --- a/tensorflow/core/common_runtime/gpu/BUILD +++ b/tensorflow/core/common_runtime/gpu/BUILD @@ -171,7 +171,11 @@ tf_cuda_library( cuda_deps = [ "@local_config_cuda//cuda:cudnn_header", "@local_xla//xla/stream_executor/cuda:cuda_platform", +<<<<<<< HEAD "@local_xla//xla/stream_executor/gpu:gpu_stream", +======= + "@local_xla//xla/stream_executor/gpu:gpu_cudamallocasync_allocator", +>>>>>>> master ":gpu_runtime_hermetic_cuda_deps", ], defines = if_linux_x86_64(["TF_PLATFORM_LINUX_X86_64"]), @@ -204,7 +208,6 @@ tf_cuda_library( "@local_xla//xla/stream_executor", "@local_xla//xla/stream_executor/gpu:gpu_cudamallocasync_allocator", "@local_xla//xla/stream_executor/gpu:gpu_init_impl", - "@local_xla//xla/stream_executor/gpu:scoped_activate_context", "@local_xla//xla/tsl/framework:device_id_utils", ] + if_google( # TODO(b/282068262): PJRT pulls in TFRT components that are incompatible with ARM platform. @@ -230,7 +233,9 @@ tf_cuda_library( }) + if_cuda_or_rocm([ "@local_xla//xla/service:gpu_plugin_impl", # for registering cuda compiler. ]), - ), + ) + if_cuda_or_rocm([ + "@local_tsl//tsl/platform:dso_loader", + ]), alwayslink = 1, ) diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc index b3dc824bcdd4f9..83c18087ac73bf 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc @@ -26,12 +26,12 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_init.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/framework/device_id.h" +#include "xla/tsl/lib/gtl/inlined_vector.h" +#include "xla/tsl/lib/random/simple_philox.h" #include "tensorflow/core/common_runtime/device/device_mem_allocator.h" #include "tensorflow/core/framework/typed_allocator.h" #include "tensorflow/core/protobuf/bfc_memory_map.pb.h" #include "tensorflow/core/protobuf/config.pb.h" -#include "tsl/lib/gtl/inlined_vector.h" -#include "tsl/lib/random/simple_philox.h" #include "tsl/platform/logging.h" #include "tsl/platform/strcat.h" #include "tsl/platform/test.h" diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc index 8ea02e004a8769..9a279787d6ee2a 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc @@ -20,7 +20,6 @@ limitations under the License. #endif // GOOGLE_CUDA #include "xla/stream_executor/gpu/gpu_init.h" -#include "xla/stream_executor/gpu/scoped_activate_context.h" #include "xla/tsl/framework/device_id.h" #include "tsl/platform/logging.h" @@ -36,7 +35,8 @@ GPUcudaMallocAllocator::GPUcudaMallocAllocator( void* GPUcudaMallocAllocator::AllocateRaw(size_t alignment, size_t num_bytes) { #ifdef GOOGLE_CUDA // allocate with cudaMalloc - se::gpu::ScopedActivateContext scoped_activation{stream_exec_}; + std::unique_ptr scoped_activation = + stream_exec_->Activate(); CUdeviceptr rv = 0; CUresult res = cuMemAlloc(&rv, num_bytes); if (res != CUDA_SUCCESS) { diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc index bb932a5af9e0c0..de65df20e2dad4 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc @@ -25,10 +25,10 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/framework/device_id.h" +#include "xla/tsl/lib/gtl/inlined_vector.h" #include "tensorflow/core/common_runtime/device/device_mem_allocator.h" #include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h" #include "tensorflow/core/framework/typed_allocator.h" -#include "tsl/lib/gtl/inlined_vector.h" #include "tsl/platform/logging.h" #include "tsl/platform/test.h" #include "tsl/platform/types.h" diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index 205e1e2712002c..f8c8a2724cf452 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -84,9 +84,6 @@ limitations under the License. #include "xla/pjrt/pjrt_stream_executor_client.h" #include "xla/stream_executor/integrations/device_host_allocator.h" #endif // TF_GPU_USE_PJRT -#include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/scoped_activate_context.h" -#include "xla/stream_executor/platform/dso_loader.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/platform/logging.h" @@ -96,6 +93,7 @@ limitations under the License. #include "tensorflow/core/profiler/lib/scoped_annotation.h" #include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h" #include "tensorflow/core/public/session_options.h" +#include "tsl/platform/dso_loader.h" #ifdef TF_GPU_USE_PJRT #include "tensorflow/core/tfrt/common/pjrt_util.h" #endif // TF_GPU_USE_PJRT @@ -141,14 +139,12 @@ int GetPriority(const int tf_device_id, const GPUOptions& options) { typedef cudaStream_t gpuStream_t; typedef cudaDeviceProp gpuDeviceProp_t; #define EIGEN_GPU_SCRATCH_SIZE (Eigen::kGpuScratchSize) -using se::gpu::ScopedActivateContext; #elif TENSORFLOW_USE_ROCM typedef hipStream_t gpuStream_t; typedef hipDeviceProp_t gpuDeviceProp_t; #define EIGEN_GPU_SCRATCH_SIZE (Eigen::kGpuScratchSize) -using se::gpu::ScopedActivateContext; #endif @@ -790,7 +786,8 @@ void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { kernel_tracker_->PauseWhilePendingExceeds(pending_cap_); } } - ScopedActivateContext scoped_activation{stream->parent()}; + std::unique_ptr scoped_activation = + stream->parent()->Activate(); profiler::ScopedMemoryDebugAnnotation op_annotation( op_kernel->name_view().data(), context->step_id()); bool should_log_inputs_and_outputs = ShouldLogInputsAndOutputs(op_kernel); @@ -884,7 +881,8 @@ void BaseGPUDevice::ComputeAsync(AsyncOpKernel* op_kernel, }; } - ScopedActivateContext scoped_activation{stream->parent()}; + std::unique_ptr scoped_activation = + stream->parent()->Activate(); op_kernel->ComputeAsync(context, std::move(done)); } @@ -2339,7 +2337,7 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds( #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // Try to dlopen GPU libraries if they are supposed to be dynamically loaded. - auto handle_or = se::internal::DsoLoader::MaybeTryDlopenGPULibraries(); + auto handle_or = tsl::internal::DsoLoader::MaybeTryDlopenGPULibraries(); if (!handle_or.ok()) { LOG(WARNING) << "Cannot dlopen some GPU libraries. Please make sure the " "missing libraries mentioned above are installed properly " diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc b/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc index d04790208b74df..eafe7a4ad494f1 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc @@ -20,13 +20,13 @@ limitations under the License. namespace tensorflow { -Status GpuIdManager::InsertTfPlatformDeviceIdPair( +absl::Status GpuIdManager::InsertTfPlatformDeviceIdPair( TfDeviceId tf_device_id, PlatformDeviceId platform_device_id) { return DeviceIdManager::InsertTfPlatformDeviceIdPair(DEVICE_GPU, tf_device_id, platform_device_id); } -Status GpuIdManager::TfToPlatformDeviceId( +absl::Status GpuIdManager::TfToPlatformDeviceId( TfDeviceId tf_device_id, PlatformDeviceId* platform_device_id) { return DeviceIdManager::TfToPlatformDeviceId(DEVICE_GPU, tf_device_id, platform_device_id); diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager.h b/tensorflow/core/common_runtime/gpu/gpu_id_manager.h index 9b8df489857850..aa8553f6f90aa3 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_id_manager.h +++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager.h @@ -26,13 +26,13 @@ namespace tensorflow { class GpuIdManager { public: // Adds a mapping from tf_device_id to platform_device_id. - static Status InsertTfPlatformDeviceIdPair( + static absl::Status InsertTfPlatformDeviceIdPair( tsl::TfDeviceId tf_device_id, tsl::PlatformDeviceId platform_device_id); // Gets the platform_device_id associated with tf_device_id. Returns OK if // found. - static Status TfToPlatformDeviceId(tsl::TfDeviceId tf_device_id, - tsl::PlatformDeviceId* platform_device_id); + static absl::Status TfToPlatformDeviceId( + tsl::TfDeviceId tf_device_id, tsl::PlatformDeviceId* platform_device_id); // Clears the map. Used in unit tests only. static void TestOnlyReset(); diff --git a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc index c9e51e5a97c404..96d9ca758d67e0 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc @@ -342,7 +342,7 @@ Allocator* GPUProcessState::GetGpuHostAllocator(const GPUOptions& options, options.experimental().gpu_host_mem_limit_in_mb() * (1LL << 20); if (mem_limit_bytes <= 0) { int64_t limit_mb = -1; - Status status = + absl::Status status = tsl::ReadInt64FromEnvVar("TF_GPU_HOST_MEM_LIMIT_IN_MB", 1LL << 17 /*2^17 MB == 128GB*/, &limit_mb); if (!status.ok()) { diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.cc b/tensorflow/core/common_runtime/gpu/gpu_util.cc index b33cf0d01fde5e..74ff3f7c39cde6 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_util.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_util.cc @@ -80,10 +80,10 @@ namespace tensorflow { using se::DeviceMemoryBase; using se::Stream; -Status PrepareCopy(Device* device, const DeviceContext* ctx, const Tensor& src, - const Tensor* dst, - const DeviceBase::AcceleratorDeviceInfo** dev_info, - se::Stream** stream) { +absl::Status PrepareCopy(Device* device, const DeviceContext* ctx, + const Tensor& src, const Tensor* dst, + const DeviceBase::AcceleratorDeviceInfo** dev_info, + se::Stream** stream) { if (device == nullptr) { return errors::Internal("Unexpected null device."); } @@ -139,8 +139,8 @@ void GPUUtil::SetProtoFromGPU(const Tensor& tensor, Device* dev, VLOG(1) << "SetProtoFromGPU device_context " << device_context; const DeviceBase::AcceleratorDeviceInfo* dev_info = nullptr; se::Stream* send_stream = nullptr; - Status s = PrepareCopy(dev, device_context, tensor, nullptr, &dev_info, - &send_stream); + absl::Status s = PrepareCopy(dev, device_context, tensor, nullptr, &dev_info, + &send_stream); if (!s.ok()) { done(s); return; @@ -222,8 +222,8 @@ void GPUUtil::DeviceToDeviceCopy( int dev_to_dev_stream_index, StatusCallback done) { const DeviceBase::AcceleratorDeviceInfo* dev_info = nullptr; se::Stream* send_stream = nullptr; - Status s = PrepareCopy(src, send_dev_context, *input, output, &dev_info, - &send_stream); + absl::Status s = PrepareCopy(src, send_dev_context, *input, output, &dev_info, + &send_stream); if (!s.ok()) { done(s); return; @@ -323,8 +323,8 @@ void GPUUtil::CopyGPUTensorToCPU(Device* gpu_device, StatusCallback done) { const DeviceBase::AcceleratorDeviceInfo* dev_info = nullptr; se::Stream* send_stream = nullptr; - Status s = PrepareCopy(gpu_device, device_context, *gpu_tensor, cpu_tensor, - &dev_info, &send_stream); + absl::Status s = PrepareCopy(gpu_device, device_context, *gpu_tensor, + cpu_tensor, &dev_info, &send_stream); if (!s.ok()) { done(s); return; @@ -367,7 +367,7 @@ void GPUUtil::CopyGPUTensorToCPU(Device* gpu_device, xla::PjRtFuture<> future = pjrt_tensor_buffer->pjrt_buffer()->ToLiteral(literal.get()); future.OnReady([literal = std::move(literal), - done](const tensorflow::Status& status) { done(status); }); + done](const absl::Status& status) { done(status); }); return; } #endif // TF_GPU_USE_PJRT @@ -404,8 +404,8 @@ void GPUUtil::CopyCPUTensorToGPU(const Tensor* cpu_tensor, VLOG(1) << "CopyCPUTensorToGPU"; const DeviceBase::AcceleratorDeviceInfo* dev_info = nullptr; se::Stream* recv_stream = nullptr; - Status s = PrepareCopy(gpu_device, device_context, *cpu_tensor, gpu_tensor, - &dev_info, &recv_stream); + absl::Status s = PrepareCopy(gpu_device, device_context, *cpu_tensor, + gpu_tensor, &dev_info, &recv_stream); if (!s.ok()) { done(s); return; @@ -514,7 +514,7 @@ void GPUUtil::CopyCPUTensorToGPU(const Tensor* cpu_tensor, }); } -Status GPUUtil::Sync(Device* gpu_device) { +absl::Status GPUUtil::Sync(Device* gpu_device) { VLOG(1) << "GPUUtil::Sync"; auto* dev_info = gpu_device->tensorflow_accelerator_device_info(); if (!dev_info) { @@ -523,7 +523,7 @@ Status GPUUtil::Sync(Device* gpu_device) { return dev_info->stream->BlockHostUntilDone(); } -Status GPUUtil::SyncAll(Device* gpu_device) { +absl::Status GPUUtil::SyncAll(Device* gpu_device) { VLOG(1) << "GPUUtil::SyncAll"; auto* dev_info = gpu_device->tensorflow_accelerator_device_info(); if (!dev_info) { @@ -565,10 +565,10 @@ uint64 GPUUtil::Checksum(Device* gpu_device, const DeviceContext* device_context, const Tensor& tensor) { Tensor copy(tensor.dtype(), tensor.shape()); - Status s; + absl::Status s; Notification n; CopyGPUTensorToCPU(gpu_device, device_context, &tensor, ©, - [&s, &n](Status status) { + [&s, &n](absl::Status status) { s.Update(status); n.Notify(); }); @@ -598,8 +598,8 @@ void GPUUtil::CopyGPUTensorToSameGPU(Device* gpu_device, VLOG(1) << "CopyGPUTensorToSameGPU"; const DeviceBase::AcceleratorDeviceInfo* dev_info = nullptr; se::Stream* send_stream = nullptr; - Status s = PrepareCopy(gpu_device, device_context, *src_gpu_tensor, - dst_gpu_tensor, &dev_info, &send_stream); + absl::Status s = PrepareCopy(gpu_device, device_context, *src_gpu_tensor, + dst_gpu_tensor, &dev_info, &send_stream); if (!s.ok()) { done(s); return; diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.h b/tensorflow/core/common_runtime/gpu/gpu_util.h index b3614e1bf18968..0b650ad9804343 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_util.h +++ b/tensorflow/core/common_runtime/gpu/gpu_util.h @@ -52,12 +52,12 @@ class GPUUtil { // Blocks until all operations queued on the stream associated with // "gpu_device" at the time of the call have completed. Returns any // error pending on the stream at completion. - static Status Sync(Device* gpu_device); + static absl::Status Sync(Device* gpu_device); // Blocks until all operations queued on all streams associated with the // corresponding GPU device at the time of call have completed. // Returns any error pending on the stream at completion. - static Status SyncAll(Device* gpu_device); + static absl::Status SyncAll(Device* gpu_device); // For debugging purpose, given a "device" and a "tensor" allocated // on the device, return a string printing each byte in the tensor diff --git a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc index 783597123a3100..0eddde84668c39 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc @@ -47,8 +47,8 @@ void GPUDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor, done); } -Status GPUDeviceContext::ThenExecute(Device* device, se::Stream* stream, - std::function func) { +absl::Status GPUDeviceContext::ThenExecute(Device* device, se::Stream* stream, + std::function func) { const DeviceBase::AcceleratorDeviceInfo* gpu_info = device->tensorflow_accelerator_device_info(); gpu_info->event_mgr->ThenExecute(stream, func); diff --git a/tensorflow/core/common_runtime/gpu_device_context.h b/tensorflow/core/common_runtime/gpu_device_context.h index 1c8f6283c57c07..a4799bf23b1167 100644 --- a/tensorflow/core/common_runtime/gpu_device_context.h +++ b/tensorflow/core/common_runtime/gpu_device_context.h @@ -79,8 +79,8 @@ class GPUDeviceContext : public DeviceContext { void MaintainLifetimeOnStream(const Tensor* t, se::Stream* stream) const override {} - Status ThenExecute(Device* device, se::Stream* stream, - std::function func) override; + absl::Status ThenExecute(Device* device, se::Stream* stream, + std::function func) override; private: int stream_id_; diff --git a/tensorflow/core/common_runtime/gradients.cc b/tensorflow/core/common_runtime/gradients.cc index b91d6986705fcc..466977ecf772d6 100644 --- a/tensorflow/core/common_runtime/gradients.cc +++ b/tensorflow/core/common_runtime/gradients.cc @@ -72,7 +72,7 @@ static Node* AddZerosLike(Graph* g, NodeOut input) { read_def.set_op("ReadVariableOp"); read_def.add_input(input.name()); AddNodeAttr("dtype", DT_FLOAT, &read_def); - Status s; + absl::Status s; Node* read = g->AddNode(read_def, &s); TF_CHECK_OK(s); g->AddEdge(input.node, input.index, read, 0); @@ -91,7 +91,7 @@ static Node* AddZerosLike(Graph* g, NodeOut input) { ndef.set_op("ZerosLike"); ndef.add_input(input.name()); AddNodeAttr("T", input.dtype(), &ndef); - Status s; + absl::Status s; Node* ret = g->AddNode(ndef, &s); TF_CHECK_OK(s); g->AddEdge(input.node, input.index, ret, 0); @@ -143,7 +143,7 @@ static Node* AddSymGrad(Graph* g, Node* n, absl::Span grads) { (*func.mutable_attr())[attr.first] = attr.second; } AddNodeAttr("f", func, &ndef); - Status s; + absl::Status s; Node* ret = g->AddNode(ndef, &s); TF_CHECK_OK(s); return ret; @@ -157,7 +157,7 @@ class SymbolicGradientBuilder { std::vector* x_grad_node_outputs, Graph* graph); - Status Compute(); + absl::Status Compute(); private: absl::Span y_node_outputs_; @@ -324,7 +324,7 @@ NodeOut SymbolicGradientBuilder::SumGradients(const NodeOut& src) { } AddNodeAttr("N", static_cast(grads.size()), &ndef); AddNodeAttr("T", dtype, &ndef); - Status s; + absl::Status s; Node* add = graph_->AddNode(ndef, &s); TF_CHECK_OK(s); for (size_t i = 0; i < grads.size(); ++i) { @@ -336,11 +336,11 @@ NodeOut SymbolicGradientBuilder::SumGradients(const NodeOut& src) { static bool IsPrimitiveOpWithNoGrad(const string& func) { gradient::Creator creator; - Status s = gradient::GetOpGradientCreator(func, &creator); + absl::Status s = gradient::GetOpGradientCreator(func, &creator); return s.ok() && (creator == nullptr); } -Status SymbolicGradientBuilder::Compute() { +absl::Status SymbolicGradientBuilder::Compute() { // Initialize backprops. InitBackprop(); @@ -405,11 +405,11 @@ Status SymbolicGradientBuilder::Compute() { return absl::OkStatus(); } -Status AddSymbolicGradients(absl::Span y_node_outputs, - absl::Span x_node_outputs, - absl::Span y_grad_node_outputs, - std::vector* x_grad_node_outputs, - Graph* graph) { +absl::Status AddSymbolicGradients(absl::Span y_node_outputs, + absl::Span x_node_outputs, + absl::Span y_grad_node_outputs, + std::vector* x_grad_node_outputs, + Graph* graph) { SymbolicGradientBuilder builder(y_node_outputs, x_node_outputs, y_grad_node_outputs, x_grad_node_outputs, graph); diff --git a/tensorflow/core/common_runtime/gradients.h b/tensorflow/core/common_runtime/gradients.h index d14c85d112f08a..aaa9cad80ad691 100644 --- a/tensorflow/core/common_runtime/gradients.h +++ b/tensorflow/core/common_runtime/gradients.h @@ -47,11 +47,11 @@ struct NodeOut { // implementation only supports gradients for functions). In particular, // the nodes in 'x_nodes' are currently restricted to have one output. -Status AddSymbolicGradients(absl::Span y_node_outputs, - absl::Span x_node_outputs, - absl::Span y_grad_node_outputs, - std::vector* x_grad_node_outputs, - Graph* graph); +absl::Status AddSymbolicGradients(absl::Span y_node_outputs, + absl::Span x_node_outputs, + absl::Span y_grad_node_outputs, + std::vector* x_grad_node_outputs, + Graph* graph); } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/graph_constructor.cc b/tensorflow/core/common_runtime/graph_constructor.cc index 3705ede827e0f6..dc8dbe5711fb2e 100644 --- a/tensorflow/core/common_runtime/graph_constructor.cc +++ b/tensorflow/core/common_runtime/graph_constructor.cc @@ -167,7 +167,7 @@ class GraphConstructor { typedef absl::Span NodeDefSlice; // versions, library, and debug_info may be nullptr - static Status Construct( + static absl::Status Construct( const Options& opts, NodeDefSlice node_defs, const VersionDef* versions, const FunctionDefLibrary* library, const GraphDebugInfo* debug_info, Graph* g, ShapeRefiner* refiner, @@ -175,7 +175,7 @@ class GraphConstructor { std::vector* return_nodes, std::vector* missing_unused_input_map_keys); - static Status Construct( + static absl::Status Construct( const Options& opts, GraphDef&& graph_def, Graph* g, ShapeRefiner* refiner, std::vector>* return_tensors, std::vector* return_nodes, @@ -197,7 +197,7 @@ class GraphConstructor { virtual ~GraphConstructor() {} - Status TryImport() { + absl::Status TryImport() { TF_RETURN_IF_ERROR(EnsureNoNameCollisions()); TF_RETURN_IF_ERROR(ValidateInputMapAndControlDependencies()); TF_RETURN_IF_ERROR(BuildNodeIndex()); @@ -218,16 +218,16 @@ class GraphConstructor { } private: - Status EnsureNoNameCollisions(); - Status ValidateInputMapAndControlDependencies(); - Status BuildNodeIndex(); - Status InitFromEdges(); - Status Convert(); - Status AddBackEdges(); - Status UpdateVersionDef(); - Status PopulateReturnTensors(); - Status PopulateReturnNodes(); - Status PopulateMissingUnusedInputMapKeys(); + absl::Status EnsureNoNameCollisions(); + absl::Status ValidateInputMapAndControlDependencies(); + absl::Status BuildNodeIndex(); + absl::Status InitFromEdges(); + absl::Status Convert(); + absl::Status AddBackEdges(); + absl::Status UpdateVersionDef(); + absl::Status PopulateReturnTensors(); + absl::Status PopulateReturnNodes(); + absl::Status PopulateMissingUnusedInputMapKeys(); FunctionDefLibraryStackTraces CreateStackTracesForFunctionDefLibrary( const FunctionDefLibrary& library) const; @@ -241,12 +241,13 @@ class GraphConstructor { std::vector* is_on_cur_branch, absl::flat_hash_set* unvisited, const std::vector& node_names); - Status IsNodeFullyMapped(const NodeDef& node_def, bool* is_node_mapped); - Status ValidateColocationConstraints(const NodeDef& node_def); - Status MakeNode(NodeDef&& node_def, Node** node); - Status MakeEdge(Node* src, int output_index, Node* dst, int input_index); - Status ValidateShape(Node* node); - Status ModifyNodeDefForImport(NodeDef* node_def); + absl::Status IsNodeFullyMapped(const NodeDef& node_def, bool* is_node_mapped); + absl::Status ValidateColocationConstraints(const NodeDef& node_def); + absl::Status MakeNode(NodeDef&& node_def, Node** node); + absl::Status MakeEdge(Node* src, int output_index, Node* dst, + int input_index); + absl::Status ValidateShape(Node* node); + absl::Status ModifyNodeDefForImport(NodeDef* node_def); // Modifies node_def's inputs according to opts_.input_map. // input_already_exists is a pre-initialized vector of length // node_def->input_size(). This function will mark inputs that are remapped to @@ -495,10 +496,10 @@ bool ForwardCompatibilityWindowPassed(const VersionDef& versions) { return (versions.producer() - TF_GRAPH_DEF_VERSION) > 21; } -Status MaybeAppendVersionWarning(const VersionDef* versions, - const Status& import_status) { +absl::Status MaybeAppendVersionWarning(const VersionDef* versions, + const absl::Status& import_status) { if (versions && ForwardCompatibilityWindowPassed(*versions)) { - return Status( + return absl::Status( import_status.code(), absl::StrCat( "Converting GraphDef to Graph has failed with an error: '", @@ -516,7 +517,7 @@ Status MaybeAppendVersionWarning(const VersionDef* versions, return import_status; } -/* static */ Status GraphConstructor::Construct( +/* static */ absl::Status GraphConstructor::Construct( const Options& opts, NodeDefSlice node_defs, const VersionDef* versions, const FunctionDefLibrary* library, const GraphDebugInfo* debug_info, Graph* g, ShapeRefiner* refiner, @@ -531,7 +532,7 @@ Status MaybeAppendVersionWarning(const VersionDef* versions, NodeDefCopyingGraphConstructor c(opts, node_defs, versions, library, debug_info, g, refiner, return_tensors, return_nodes, missing_unused_input_map_keys); - Status s = c.TryImport(); + absl::Status s = c.TryImport(); if (!s.ok()) { c.Undo(); s = MaybeAppendVersionWarning(versions, s); @@ -539,7 +540,7 @@ Status MaybeAppendVersionWarning(const VersionDef* versions, return s; } -/* static */ Status GraphConstructor::Construct( +/* static */ absl::Status GraphConstructor::Construct( const Options& opts, GraphDef&& graph_def, Graph* g, ShapeRefiner* refiner, std::vector>* return_tensors, std::vector* return_nodes, @@ -551,7 +552,7 @@ Status MaybeAppendVersionWarning(const VersionDef* versions, NodeDefMovingGraphConstructor c(opts, std::move(graph_def), g, refiner, return_tensors, return_nodes, missing_unused_input_map_keys); - Status s = c.TryImport(); + absl::Status s = c.TryImport(); if (!s.ok()) { c.Undo(); s = MaybeAppendVersionWarning(&version_def, s); @@ -604,7 +605,7 @@ void AddPrefixes(StringPiece node_name, } } -Status GraphConstructor::EnsureNoNameCollisions() { +absl::Status GraphConstructor::EnsureNoNameCollisions() { existing_nodes_.reserve(g_->num_nodes()); // Populate existing_nodes_ and existing_prefixes_. for (Node* n : g_->nodes()) { @@ -646,7 +647,7 @@ Status GraphConstructor::EnsureNoNameCollisions() { return absl::OkStatus(); } -Status GraphConstructor::ValidateInputMapAndControlDependencies() { +absl::Status GraphConstructor::ValidateInputMapAndControlDependencies() { for (const auto& mapping : opts_.input_map) { TensorId src = mapping.first; TensorId dst = mapping.second; @@ -673,7 +674,7 @@ Status GraphConstructor::ValidateInputMapAndControlDependencies() { return absl::OkStatus(); } -Status GraphConstructor::BuildNodeIndex() { +absl::Status GraphConstructor::BuildNodeIndex() { // Validate the node names and add them to gdef_nodes_ and gdef_prefixes_. for (int n = 0; n < node_def_count(); ++n) { const NodeDef& node_def = get_node_def(n); @@ -717,7 +718,7 @@ Status GraphConstructor::BuildNodeIndex() { return absl::OkStatus(); } -Status GraphConstructor::InitFromEdges() { +absl::Status GraphConstructor::InitFromEdges() { const int num_nodes = node_def_count(); pending_count_.reserve(num_nodes); outputs_.resize(num_nodes); @@ -784,7 +785,7 @@ Status GraphConstructor::InitFromEdges() { return absl::OkStatus(); } -Status GraphConstructor::ValidateColocationConstraints( +absl::Status GraphConstructor::ValidateColocationConstraints( const NodeDef& node_def) { if (!opts_.validate_colocation_constraints || !opts_.importing) return absl::OkStatus(); @@ -802,9 +803,9 @@ Status GraphConstructor::ValidateColocationConstraints( return absl::OkStatus(); } -Status GraphConstructor::MakeNode(NodeDef&& node_def, Node** node) { +absl::Status GraphConstructor::MakeNode(NodeDef&& node_def, Node** node) { // Add the node to the graph. - Status status; + absl::Status status; *node = g_->AddNode(std::move(node_def), &status); if (!status.ok()) return status; if (opts_.expect_device_spec || @@ -814,7 +815,7 @@ Status GraphConstructor::MakeNode(NodeDef&& node_def, Node** node) { return absl::OkStatus(); } -Status GraphConstructor::ValidateShape(Node* node) { +absl::Status GraphConstructor::ValidateShape(Node* node) { if (!opts_.importing || !opts_.validate_shape) return absl::OkStatus(); TF_RETURN_IF_ERROR(refiner_->AddNode(node)); // For nodes with the _output_shapes attribute, override the shape. @@ -845,7 +846,7 @@ Status GraphConstructor::ValidateShape(Node* node) { for (int i = 0; i < node->num_outputs(); ++i) { const TensorShapeProto& p = *shape_attrs[i]; shape_inference::ShapeHandle h; - Status s = ic->MakeShapeFromShapeProto(p, &h); + absl::Status s = ic->MakeShapeFromShapeProto(p, &h); if (!s.ok()) { return errors::InvalidArgument("Node '", node->name(), " has an invalid ", kAttrName, " attribute (shape #", i, @@ -863,7 +864,7 @@ Status GraphConstructor::ValidateShape(Node* node) { return absl::OkStatus(); } -Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) { +absl::Status GraphConstructor::ModifyNodeDefForImport(NodeDef* node_def) { const OpDef* op_def; TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def)); AddDefaultsToNodeDef(*op_def, node_def); @@ -1076,8 +1077,8 @@ string GraphConstructor::FindUniqueName(StringPiece original_name) { return name; } -Status GraphConstructor::IsNodeFullyMapped(const NodeDef& node_def, - bool* is_node_mapped) { +absl::Status GraphConstructor::IsNodeFullyMapped(const NodeDef& node_def, + bool* is_node_mapped) { const OpDef* op_def; TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def.op(), &op_def)); for (int i = 0; i < op_def->output_arg_size(); ++i) { @@ -1161,7 +1162,7 @@ GraphConstructor::CreateStackTracesForFunctionDefLibrary( } } -Status GraphConstructor::Convert() { +absl::Status GraphConstructor::Convert() { if (debug_info() != nullptr) { traces_ = LoadTracesFromDebugInfo(*debug_info()); } @@ -1364,7 +1365,7 @@ Status GraphConstructor::Convert() { return absl::OkStatus(); } -Status GraphConstructor::AddBackEdges() { +absl::Status GraphConstructor::AddBackEdges() { // Add the back edges after all nodes are created. for (const auto& e : back_edges_) { Node* src_node = gdef_nodes_[e.src_name].node; @@ -1381,7 +1382,7 @@ Status GraphConstructor::AddBackEdges() { return absl::OkStatus(); } -Status GraphConstructor::UpdateVersionDef() { +absl::Status GraphConstructor::UpdateVersionDef() { if (versions() == nullptr) return absl::OkStatus(); if (!opts_.importing) { @@ -1407,7 +1408,7 @@ Status GraphConstructor::UpdateVersionDef() { return absl::OkStatus(); } -Status GraphConstructor::PopulateReturnTensors() { +absl::Status GraphConstructor::PopulateReturnTensors() { if (opts_.return_tensors.empty()) return absl::OkStatus(); for (const TensorId& id : opts_.return_tensors) { auto iter = opts_.input_map.find(id); @@ -1438,7 +1439,7 @@ Status GraphConstructor::PopulateReturnTensors() { return absl::OkStatus(); } -Status GraphConstructor::PopulateReturnNodes() { +absl::Status GraphConstructor::PopulateReturnNodes() { if (opts_.return_nodes.empty()) return absl::OkStatus(); for (StringPiece name : opts_.return_nodes) { auto iter = gdef_nodes_.find(name); @@ -1451,7 +1452,7 @@ Status GraphConstructor::PopulateReturnNodes() { return absl::OkStatus(); } -Status GraphConstructor::PopulateMissingUnusedInputMapKeys() { +absl::Status GraphConstructor::PopulateMissingUnusedInputMapKeys() { if (missing_unused_input_map_keys_ == nullptr) return absl::OkStatus(); for (const auto& input_map_pair : opts_.input_map) { TensorId key = input_map_pair.first; @@ -1489,8 +1490,8 @@ void GraphConstructor::Undo() { g_->set_versions(original_versions_); } -Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst, - int input_index) { +absl::Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst, + int input_index) { if (output_index >= src->num_outputs()) { return errors::InvalidArgument( "Output ", output_index, " of node ", src->name(), @@ -1515,8 +1516,8 @@ Status GraphConstructor::MakeEdge(Node* src, int output_index, Node* dst, } } // namespace -Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, - const GraphDef& gdef, Graph* g) { +absl::Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, + const GraphDef& gdef, Graph* g) { ShapeRefiner refiner(gdef.versions().producer(), g->op_registry()); return GraphConstructor::Construct( opts, gdef.node(), &gdef.versions(), &gdef.library(), &gdef.debug_info(), @@ -1524,8 +1525,8 @@ Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, /*missing_unused_input_map_keys=*/nullptr); } -Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, - GraphDef&& gdef, Graph* g) { +absl::Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, + GraphDef&& gdef, Graph* g) { ShapeRefiner refiner(gdef.versions().producer(), g->op_registry()); return GraphConstructor::Construct(opts, std::move(gdef), g, &refiner, /*return_tensors=*/nullptr, @@ -1533,9 +1534,9 @@ Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, /*missing_unused_input_map_keys=*/nullptr); } -Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts, - absl::Span nodes, Graph* g, - const GraphDebugInfo* debug_info) { +absl::Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts, + absl::Span nodes, Graph* g, + const GraphDebugInfo* debug_info) { ShapeRefiner refiner(TF_GRAPH_DEF_VERSION, g->op_registry()); // TODO(irving): Copy will go away once NodeInfo exists std::vector node_defs; @@ -1550,9 +1551,10 @@ Status ConvertNodeDefsToGraph(const GraphConstructorOptions& opts, /*missing_unused_input_map_keys=*/nullptr); } -Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef, - Graph* g, ShapeRefiner* refiner, - ImportGraphDefResults* results) { +absl::Status ImportGraphDef(const ImportGraphDefOptions& opts, + const GraphDef& gdef, Graph* g, + ShapeRefiner* refiner, + ImportGraphDefResults* results) { if (!opts.return_tensors.empty()) { if (results == nullptr) { return errors::InvalidArgument( diff --git a/tensorflow/core/common_runtime/graph_constructor.h b/tensorflow/core/common_runtime/graph_constructor.h index 34dfc6657d55a7..5f97f38760f3e3 100644 --- a/tensorflow/core/common_runtime/graph_constructor.h +++ b/tensorflow/core/common_runtime/graph_constructor.h @@ -53,15 +53,15 @@ struct GraphConstructorOptions { // value to the Node when they are missing from the NodeDef. bool add_default_attributes = true; }; -extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, - const GraphDef& gdef, Graph* g); -extern Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, - GraphDef&& gdef, Graph* g); +extern absl::Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, + const GraphDef& gdef, Graph* g); +extern absl::Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts, + GraphDef&& gdef, Graph* g); // Same as ConvertGraphDefToGraph, but takes just nodes. Used by function // instantiation. // TODO(irving): This will turn into std::vector soon. -extern Status ConvertNodeDefsToGraph( +extern absl::Status ConvertNodeDefsToGraph( const GraphConstructorOptions& opts, absl::Span nodes, Graph* g, const GraphDebugInfo* debug_info = nullptr); @@ -194,10 +194,10 @@ struct ImportGraphDefResults { // // TODO(ashankar): Push this mechanism and get rid of Session::Extend() // as a means of enhancing an existing Graph. -extern Status ImportGraphDef(const ImportGraphDefOptions& opts, - const GraphDef& gdef, Graph* g, - ShapeRefiner* refiner, - ImportGraphDefResults* results = nullptr); +extern absl::Status ImportGraphDef(const ImportGraphDefOptions& opts, + const GraphDef& gdef, Graph* g, + ShapeRefiner* refiner, + ImportGraphDefResults* results = nullptr); // Make a copy of "src" into "*dest". // diff --git a/tensorflow/core/common_runtime/graph_constructor_fuzz.cc b/tensorflow/core/common_runtime/graph_constructor_fuzz.cc index 3fd71e21e65b66..8766d4992d68e5 100644 --- a/tensorflow/core/common_runtime/graph_constructor_fuzz.cc +++ b/tensorflow/core/common_runtime/graph_constructor_fuzz.cc @@ -39,7 +39,7 @@ void FuzzGraphEndToEndSimpleFixedInput(const GraphDef& graph_def) { // Load an arbitrary graph and run a session on it using simple input. ImportGraphDefOptions options; auto graph = std::make_unique(OpRegistry::Global()); - Status status = + absl::Status status = ImportGraphDef(options, graph_def, graph.get(), nullptr, nullptr); if (!status.ok()) { return; @@ -77,7 +77,7 @@ void FuzzGraphEndToEndAllStatic(const GraphDef& graph_def) { // to explore any arbitrary graph computation. ImportGraphDefOptions options; auto graph = std::make_unique(OpRegistry::Global()); - Status status = + absl::Status status = ImportGraphDef(options, graph_def, graph.get(), nullptr, nullptr); if (!status.ok()) { return; @@ -353,7 +353,7 @@ void FuzzGraphEndToEndFDP(std::vector data) { } std::unique_ptr graph = std::make_unique(OpRegistry::Global()); - Status s = ImportGraphDef(opts, gdef_, graph.get(), nullptr, nullptr); + absl::Status s = ImportGraphDef(opts, gdef_, graph.get(), nullptr, nullptr); if (!s.ok()) { return; } diff --git a/tensorflow/core/common_runtime/graph_constructor_test.cc b/tensorflow/core/common_runtime/graph_constructor_test.cc index 2cae8ca92c3c81..419f09c4d17c55 100644 --- a/tensorflow/core/common_runtime/graph_constructor_test.cc +++ b/tensorflow/core/common_runtime/graph_constructor_test.cc @@ -62,7 +62,7 @@ class GraphConstructorTest : public ::testing::Test { Convert(gdef_ascii); GraphConstructorOptions opts; - Status status = ConvertGraphDefToGraph(opts, gdef_, &graph_); + absl::Status status = ConvertGraphDefToGraph(opts, gdef_, &graph_); EXPECT_FALSE(status.ok()); for (const string& error : expected_error_strs) { @@ -87,7 +87,8 @@ class GraphConstructorTest : public ::testing::Test { const string original_graph_description = GraphDebugString(); Convert(gdef_ascii); - Status status = ImportGraphDef(opts, gdef_, &graph_, refiner, results); + absl::Status status = + ImportGraphDef(opts, gdef_, &graph_, refiner, results); EXPECT_FALSE(status.ok()); for (const string& error : expected_error_strs) { @@ -108,7 +109,7 @@ class GraphConstructorTest : public ::testing::Test { ShapeRefiner* refiner = nullptr, ImportGraphDefResults* results = nullptr) { Convert(gdef_ascii); - Status s = ImportGraphDef(opts, gdef_, &graph_, refiner, results); + absl::Status s = ImportGraphDef(opts, gdef_, &graph_, refiner, results); EXPECT_EQ(absl::OkStatus(), s) << s; } @@ -156,7 +157,7 @@ class GraphConstructorTest : public ::testing::Test { return ""; } std::vector value; - Status s = GetNodeAttr(n->attrs(), kColocationAttrName, &value); + absl::Status s = GetNodeAttr(n->attrs(), kColocationAttrName, &value); if (!s.ok()) { return ""; } @@ -180,7 +181,7 @@ class GraphConstructorTest : public ::testing::Test { GraphDef gdef_; }; -Status Scalars(shape_inference::InferenceContext* c) { +absl::Status Scalars(shape_inference::InferenceContext* c) { for (int i = 0; i < c->num_outputs(); ++i) { c->set_output(i, c->Scalar()); } @@ -932,7 +933,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef) { const string& sink = graph_.FindNodeId(Graph::kSinkId)->name(); // Importing an empty graph is fine. - Status s = ImportGraphDef(opts, def, &graph_, nullptr); + absl::Status s = ImportGraphDef(opts, def, &graph_, nullptr); ASSERT_EQ(absl::OkStatus(), s) << s; EXPECT_EQ(2, graph_.num_nodes()); EXPECT_TRUE(HasControlEdge(source, sink)); @@ -1019,7 +1020,8 @@ TEST_F(GraphConstructorTest, ImportGraphDef_DefaultAttrs) { GraphDef def; ASSERT_TRUE(protobuf::TextFormat::ParseFromString( "node{ name:'A' op:'TestDefaultAttr'}", &def)); - Status s = ImportGraphDef(ImportGraphDefOptions(), def, &graph_, nullptr); + absl::Status s = + ImportGraphDef(ImportGraphDefOptions(), def, &graph_, nullptr); ASSERT_EQ(absl::OkStatus(), s) << s; Node* a = nullptr; for (Node* n : graph_.nodes()) { @@ -1040,7 +1042,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_Versioning) { const ImportGraphDefOptions opts; def.mutable_versions()->set_producer(TF_GRAPH_DEF_VERSION_MIN_PRODUCER - 1); - Status s = ImportGraphDef(opts, def, &graph_, nullptr); + absl::Status s = ImportGraphDef(opts, def, &graph_, nullptr); EXPECT_TRUE(errors::IsInvalidArgument(s)) << s; def.mutable_versions()->Clear(); @@ -1161,7 +1163,8 @@ node { )EOF", &def); ASSERT_TRUE(parsed); - Status s = ImportGraphDef(ImportGraphDefOptions(), def, &graph_, nullptr); + absl::Status s = + ImportGraphDef(ImportGraphDefOptions(), def, &graph_, nullptr); EXPECT_EQ(absl::OkStatus(), s) << s; Graph g2(OpRegistry::Global()); @@ -2255,7 +2258,8 @@ versions { )EOF", &def); ASSERT_TRUE(parsed); - Status s = ImportGraphDef(ImportGraphDefOptions(), def, &graph_, nullptr); + absl::Status s = + ImportGraphDef(ImportGraphDefOptions(), def, &graph_, nullptr); EXPECT_EQ(absl::OkStatus(), s) << s; } @@ -2443,7 +2447,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ErrorsDoNoChangeTheGraph) { const string& source = graph_.FindNodeId(Graph::kSourceId)->name(); const string& sink = graph_.FindNodeId(Graph::kSinkId)->name(); - Status s = ImportGraphDef(opts, def, &graph_, nullptr); + absl::Status s = ImportGraphDef(opts, def, &graph_, nullptr); ASSERT_EQ(absl::OkStatus(), s) << s; EXPECT_EQ(3, graph_.num_nodes()); // 'scope/A', source and sink EXPECT_TRUE(HasControlEdge(source, sink)); @@ -2732,7 +2736,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_NestedFunctionDefs) { EXPECT_TRUE(HasNode("Outer_966fa13d")); // Check that Inner and Outer have been imported const OpDef* op_def; - Status s = graph_.op_registry()->LookUpOpDef("Inner_d03c39a3", &op_def); + absl::Status s = graph_.op_registry()->LookUpOpDef("Inner_d03c39a3", &op_def); ASSERT_TRUE(s.ok()) << s.message(); s = graph_.op_registry()->LookUpOpDef("Outer_966fa13d", &op_def); ASSERT_TRUE(s.ok()) << s.message(); @@ -3212,7 +3216,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ValidateColocationConstraints) { ImportGraphDefOptions options; // TODO(yaozhang): Extend ExpectError to check error type and use ExpectError // and ExpectOK to replace the code below. - Status s = ImportGraphDef(options, def, &graph_, nullptr); + absl::Status s = ImportGraphDef(options, def, &graph_, nullptr); EXPECT_TRUE(errors::IsInvalidArgument(s)) << s; options.validate_colocation_constraints = false; TF_EXPECT_OK(ImportGraphDef(options, def, &graph_, nullptr)); diff --git a/tensorflow/core/common_runtime/graph_def_builder_util.cc b/tensorflow/core/common_runtime/graph_def_builder_util.cc index 4062fe2a4a82da..40e45334722652 100644 --- a/tensorflow/core/common_runtime/graph_def_builder_util.cc +++ b/tensorflow/core/common_runtime/graph_def_builder_util.cc @@ -18,7 +18,8 @@ limitations under the License. namespace tensorflow { -Status GraphDefBuilderToGraph(const GraphDefBuilder& builder, Graph* graph) { +absl::Status GraphDefBuilderToGraph(const GraphDefBuilder& builder, + Graph* graph) { GraphDef graph_def; TF_RETURN_IF_ERROR(builder.ToGraphDef(&graph_def)); GraphConstructorOptions opts; diff --git a/tensorflow/core/common_runtime/graph_def_builder_util.h b/tensorflow/core/common_runtime/graph_def_builder_util.h index 01f3d710460d6d..8fb539973f1dbf 100644 --- a/tensorflow/core/common_runtime/graph_def_builder_util.h +++ b/tensorflow/core/common_runtime/graph_def_builder_util.h @@ -28,7 +28,8 @@ class Graph; // Graph->GraphDef->Graph. This cleans up the graph (e.g. adds // edges from the source and to the sink node, resolves back edges // by name), and makes sure the resulting graph is valid. -Status GraphDefBuilderToGraph(const GraphDefBuilder& builder, Graph* graph); +absl::Status GraphDefBuilderToGraph(const GraphDefBuilder& builder, + Graph* graph); } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc index 744e72fcf75d57..7e0d16d0c143ac 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.cc +++ b/tensorflow/core/common_runtime/graph_execution_state.cc @@ -89,7 +89,7 @@ GraphExecutionState::~GraphExecutionState() { delete graph_; } -/* static */ Status GraphExecutionState::MakeForBaseGraph( +/* static */ absl::Status GraphExecutionState::MakeForBaseGraph( GraphDef&& graph_def, const GraphExecutionStateOptions& options, std::unique_ptr* out_state) { #ifndef __ANDROID__ @@ -130,7 +130,7 @@ GraphExecutionState::~GraphExecutionState() { return absl::OkStatus(); } -/* static */ Status GraphExecutionState::MakeForPrunedGraph( +/* static */ absl::Status GraphExecutionState::MakeForPrunedGraph( const GraphExecutionState& base_execution_state, const GraphExecutionStateOptions& options, const BuildGraphOptions& subgraph_options, @@ -180,7 +180,7 @@ GraphExecutionState::~GraphExecutionState() { return absl::OkStatus(); } -Status GraphExecutionState::Extend( +absl::Status GraphExecutionState::Extend( const GraphDef& extension_def, std::unique_ptr* out) const { if (!session_options_->config.experimental() @@ -318,9 +318,9 @@ class TensorConnectionPruneRewrite : public subgraph::PruneRewrite { : subgraph::PruneRewrite(endpoint_name, nullptr /* device_info */), from_tensor_(std::move(from_tensor)) {} - Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor, - Node** out_node) override { - Status s; + absl::Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor, + Node** out_node) override { + absl::Status s; auto check_no_cycle_fn = [this, feed_tensor, &s](Node* n) { if (n == feed_tensor.node) { s.Update(errors::InvalidArgument( @@ -352,9 +352,10 @@ class TensorConnectionPruneRewrite : public subgraph::PruneRewrite { }; template -Status LookupDevice(const DeviceSet& device_set, const string& tensor_name, - const Map& tensor2device, - const tensorflow::DeviceAttributes** out_device_attrs) { +absl::Status LookupDevice( + const DeviceSet& device_set, const string& tensor_name, + const Map& tensor2device, + const tensorflow::DeviceAttributes** out_device_attrs) { *out_device_attrs = nullptr; if (tensor2device.empty()) { *out_device_attrs = &device_set.client_device()->attributes(); @@ -430,7 +431,7 @@ bool IsFeedAndFetchSupported(DataType dtype, const string& device_type) { } } -Status ValidateFeedAndFetchDevices( +absl::Status ValidateFeedAndFetchDevices( const Graph& graph, const std::vector& tensors_and_devices) { if (tensors_and_devices.empty()) return absl::OkStatus(); @@ -468,9 +469,9 @@ Status ValidateFeedAndFetchDevices( return absl::OkStatus(); } -Status GetFeedShapeAndTypeFromAttribute(const NodeDef& node, - PartialTensorShape* shape, - DataType* type) { +absl::Status GetFeedShapeAndTypeFromAttribute(const NodeDef& node, + PartialTensorShape* shape, + DataType* type) { static const gtl::FlatSet* const kHasExplicitShapeAttribute = CHECK_NOTNULL((new gtl::FlatSet{ "Placeholder", "PlaceholderV2", "PlaceholderWithDefault", @@ -505,7 +506,7 @@ Status GetFeedShapeAndTypeFromAttribute(const NodeDef& node, } // namespace -Status GraphExecutionState::PruneGraph( +absl::Status GraphExecutionState::PruneGraph( const BuildGraphOptions& options, Graph* graph, subgraph::RewriteGraphMetadata* out_rewrite_metadata) { std::vector> feed_rewrites; @@ -611,7 +612,8 @@ Status GraphExecutionState::PruneGraph( return absl::OkStatus(); } -Status GraphExecutionState::InitBaseGraph(std::unique_ptr&& new_graph) { +absl::Status GraphExecutionState::InitBaseGraph( + std::unique_ptr&& new_graph) { // Save stateful placements before placing. RestoreStatefulNodes(new_graph.get()); @@ -649,7 +651,7 @@ Status GraphExecutionState::InitBaseGraph(std::unique_ptr&& new_graph) { return absl::OkStatus(); } -Status GraphExecutionState::OptimizeGraph( +absl::Status GraphExecutionState::OptimizeGraph( const BuildGraphOptions& options, const Graph& graph, const FunctionLibraryDefinition* flib_def, std::unique_ptr* optimized_graph, @@ -669,7 +671,7 @@ Status GraphExecutionState::OptimizeGraph( // Add devices to the GrapplerItem // It's ok to skip invalid device annotations in Grappler. for (const Device* d : device_set_->devices()) { - Status added_device = item.AddDevice(d->name()); + absl::Status added_device = item.AddDevice(d->name()); if (!added_device.ok()) VLOG(3) << added_device.message(); } VLOG(3) << "Grappler available devices: " @@ -731,8 +733,8 @@ Status GraphExecutionState::OptimizeGraph( // Try to get the type and shape of the feed node. PartialTensorShape partial_shape; DataType type; - Status st = GetFeedShapeAndTypeFromAttribute(node->def(), - &partial_shape, &type); + absl::Status st = GetFeedShapeAndTypeFromAttribute( + node->def(), &partial_shape, &type); // Failed to get type and shape of the feed node. if (!st.ok()) { @@ -856,8 +858,8 @@ Status GraphExecutionState::OptimizeGraph( #endif // IS_MOBILE_PLATFORM } -Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options, - std::unique_ptr* out) { +absl::Status GraphExecutionState::BuildGraph( + const BuildGraphOptions& options, std::unique_ptr* out) { VLOG(1) << "BuildGraph"; const uint64 start_time_usecs = Env::Default()->NowMicros(); if (!graph_) { @@ -872,8 +874,8 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options, std::unique_ptr optimized_graph; std::unique_ptr optimized_flib; - Status s = OptimizeGraph(options, *graph_, flib_def_.get(), &optimized_graph, - &optimized_flib); + absl::Status s = OptimizeGraph(options, *graph_, flib_def_.get(), + &optimized_graph, &optimized_flib); if (!s.ok()) { VLOG(2) << "Grappler optimization failed. Error: " << s.message(); // Simply copy the original graph and the function library if we couldn't diff --git a/tensorflow/core/common_runtime/graph_execution_state.h b/tensorflow/core/common_runtime/graph_execution_state.h index 87b3a12891d45a..b02cfc940d52c3 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.h +++ b/tensorflow/core/common_runtime/graph_execution_state.h @@ -101,14 +101,14 @@ class GraphExecutionState { // Creates a new `GraphExecutionState` for the given // `graph_def`, which represents the entire graph for a session. - static Status MakeForBaseGraph( + static absl::Status MakeForBaseGraph( GraphDef&& graph_def, const GraphExecutionStateOptions& options, std::unique_ptr* out_state); // Creates a new `GraphExecutionState` and `SimpleClientGraph` // for the subgraph of `original_graph_def` defined by // `subgraph_options`. - static Status MakeForPrunedGraph( + static absl::Status MakeForPrunedGraph( const GraphExecutionState& base_execution_state, const GraphExecutionStateOptions& options, const BuildGraphOptions& subgraph_options, @@ -133,18 +133,18 @@ class GraphExecutionState { // Note that using this interface requires setting the value of // config.experimental().disable_optimize_for_static_graph() in the state // options to `true`, otherwise it will return an error. - Status Extend(const GraphDef& extension_def, - std::unique_ptr* out) const; + absl::Status Extend(const GraphDef& extension_def, + std::unique_ptr* out) const; // Builds a ClientGraph (a sub-graph of the full graph as induced by // the Node set specified in "options"). If successful, returns OK // and the caller takes the ownership of "*out". Otherwise, returns // an error. - Status BuildGraph(const BuildGraphOptions& options, - std::unique_ptr* out); + absl::Status BuildGraph(const BuildGraphOptions& options, + std::unique_ptr* out); // Optimize the graph with the node set specified in `options`. - Status OptimizeGraph( + absl::Status OptimizeGraph( const BuildGraphOptions& options, const Graph& graph, const FunctionLibraryDefinition* flib_def, std::unique_ptr* optimized_graph, @@ -182,7 +182,7 @@ class GraphExecutionState { std::unique_ptr&& flib_def, const GraphExecutionStateOptions& options); - Status InitBaseGraph(std::unique_ptr&& graph); + absl::Status InitBaseGraph(std::unique_ptr&& graph); // Map of placed stateful nodes, i.e. nodes for which is_stateful() // is true, such as "params" and "queue" nodes. Once placed these @@ -195,8 +195,8 @@ class GraphExecutionState { // Extract the subset of the graph that needs to be run, adding feed/fetch // ops as needed. - Status PruneGraph(const BuildGraphOptions& options, Graph* graph, - subgraph::RewriteGraphMetadata* out_rewrite_metadata); + absl::Status PruneGraph(const BuildGraphOptions& options, Graph* graph, + subgraph::RewriteGraphMetadata* out_rewrite_metadata); // The GraphExecutionState must store a copy of the original GraphDef if // either of the following conditions holds: diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc index 101fc0cd1c01d8..90052d68873c6a 100644 --- a/tensorflow/core/common_runtime/graph_runner.cc +++ b/tensorflow/core/common_runtime/graph_runner.cc @@ -51,8 +51,8 @@ class SimpleRendezvous : public RendezvousInterface { public: explicit SimpleRendezvous() {} - Status Send(const ParsedKey& parsed, const Args& send_args, const Tensor& val, - const bool is_dead) override { + absl::Status Send(const ParsedKey& parsed, const Args& send_args, + const Tensor& val, const bool is_dead) override { if (is_dead) { return errors::Internal("Send of a dead tensor"); } @@ -69,7 +69,7 @@ class SimpleRendezvous : public RendezvousInterface { void RecvAsync(const ParsedKey& parsed, const Args& recv_args, DoneCallback done) override { Tensor tensor; - Status status = absl::OkStatus(); + absl::Status status = absl::OkStatus(); { string key(parsed.edge_name); mutex_lock l(mu_); @@ -82,7 +82,7 @@ class SimpleRendezvous : public RendezvousInterface { done(status, Args{}, recv_args, tensor, false); } - void StartAbort(const Status& status) override {} + void StartAbort(const absl::Status& status) override {} private: typedef std::unordered_map Table; @@ -100,10 +100,11 @@ GraphRunner::GraphRunner(Device* device) : device_(device) {} GraphRunner::~GraphRunner() {} -Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library, - const NamedTensorList& inputs, - const std::vector& output_names, - std::vector* outputs) { +absl::Status GraphRunner::Run(Graph* graph, + FunctionLibraryRuntime* function_library, + const NamedTensorList& inputs, + const std::vector& output_names, + std::vector* outputs) { if (device_ == nullptr) { return errors::NotFound("Cannot find a device for GraphRunner."); } diff --git a/tensorflow/core/common_runtime/graph_runner.h b/tensorflow/core/common_runtime/graph_runner.h index 95bb95371d5873..a40d17b862b0af 100644 --- a/tensorflow/core/common_runtime/graph_runner.h +++ b/tensorflow/core/common_runtime/graph_runner.h @@ -59,10 +59,10 @@ class GraphRunner { // REQUIRES: `graph`, `env`, and `outputs` are not nullptr. // `function_library` may be nullptr. typedef std::vector> NamedTensorList; - Status Run(Graph* graph, FunctionLibraryRuntime* function_library, - const NamedTensorList& inputs, - const std::vector& output_names, - std::vector* outputs); + absl::Status Run(Graph* graph, FunctionLibraryRuntime* function_library, + const NamedTensorList& inputs, + const std::vector& output_names, + std::vector* outputs); private: std::unique_ptr device_deleter_; diff --git a/tensorflow/core/common_runtime/graph_runner_test.cc b/tensorflow/core/common_runtime/graph_runner_test.cc index b559ccc937465d..fa9798b929f79e 100644 --- a/tensorflow/core/common_runtime/graph_runner_test.cc +++ b/tensorflow/core/common_runtime/graph_runner_test.cc @@ -46,7 +46,8 @@ TEST(GraphRunnerTest, SingleConst) { auto c = ops::Const(root, 42.0f); GraphRunner graph_runner(Env::Default()); std::vector outputs; - Status s = graph_runner.Run(root.graph(), nullptr, {}, {c.name()}, &outputs); + absl::Status s = + graph_runner.Run(root.graph(), nullptr, {}, {c.name()}, &outputs); TF_ASSERT_OK(s); test::ExpectEqual(test::AsScalar(42.0f), outputs[0]); } @@ -71,7 +72,7 @@ TEST(GraphRunnerTest, DeepCopy) { std::vector outputs; { GraphRunner graph_runner(Env::Default()); - Status s = + absl::Status s = graph_runner.Run(root.graph(), nullptr, inputs, {"add:0"}, &outputs); TF_ASSERT_OK(s); } @@ -84,8 +85,8 @@ TEST(GraphRunnerTest, MultiFetchConst) { auto pi = ops::Const(root, 3.14f); GraphRunner graph_runner(Env::Default()); std::vector outputs; - Status s = graph_runner.Run(root.graph(), nullptr, {}, {c.name(), pi.name()}, - &outputs); + absl::Status s = graph_runner.Run(root.graph(), nullptr, {}, + {c.name(), pi.name()}, &outputs); TF_ASSERT_OK(s); test::ExpectEqual(test::AsScalar(42.0f), outputs[0]); test::ExpectEqual(test::AsScalar(3.14f), outputs[1]); @@ -106,7 +107,7 @@ TEST(GraphRunnerTest, FeedAndFetch) { GraphRunner graph_runner(Env::Default()); std::vector outputs; - Status s = + absl::Status s = graph_runner.Run(root.graph(), nullptr, inputs, {"add:0"}, &outputs); TF_ASSERT_OK(s); test::ExpectEqual(test::AsScalar(3.0f), outputs[0]); diff --git a/tensorflow/core/common_runtime/graph_view.cc b/tensorflow/core/common_runtime/graph_view.cc index 4bbd22c89dfe6f..28217fc69404be 100644 --- a/tensorflow/core/common_runtime/graph_view.cc +++ b/tensorflow/core/common_runtime/graph_view.cc @@ -203,10 +203,10 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) { // Check ScopedAllocatorAttrs and forward_from. Also assign output_types. { std::vector forward_input; - Status fwd_status = + absl::Status fwd_status = GetNodeAttr(n->attrs(), "_forward_input", &forward_input); std::vector scoped_allocator_attrs; - Status sa_status = + absl::Status sa_status = GetNodeAttr(n->attrs(), "_scoped_allocator", &scoped_allocator_attrs); int* forward_from = item->forward_from_base(); @@ -244,7 +244,7 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) { return ptr; } -Status GraphView::Initialize(const Graph* g) { +absl::Status GraphView::Initialize(const Graph* g) { CHECK(node_offsets_ == nullptr); const int num_nodes = g->num_node_ids(); num_nodes_ = num_nodes; @@ -323,8 +323,8 @@ void GraphView::SetScopedAllocatorAttrs( NodeItem* item = node(use_node->id()); AllocatorAttributes* use_attrs = item->output_attr_base(); std::vector scoped_allocator_attrs; - Status s = GetNodeAttr(use_node->attrs(), "_scoped_allocator", - &scoped_allocator_attrs); + absl::Status s = GetNodeAttr(use_node->attrs(), "_scoped_allocator", + &scoped_allocator_attrs); if (!s.ok()) { VLOG(2) << "Failed to find expected ScopedAllocator attr on " << use_node->name(); @@ -351,10 +351,10 @@ void GraphView::SetScopedAllocatorAttrs( } namespace { -Status InferAllocAttr(const Node* n, const Node* dst, - const DeviceNameUtils::ParsedName& local_dev_name, - AllocatorAttributes* attr) { - Status s; +absl::Status InferAllocAttr(const Node* n, const Node* dst, + const DeviceNameUtils::ParsedName& local_dev_name, + AllocatorAttributes* attr) { + absl::Status s; // Note that it's possible for *n to be a Recv and *dst to be a Send, // so these two cases are not mutually exclusive. if (IsRecv(n)) { @@ -418,8 +418,8 @@ Status InferAllocAttr(const Node* n, const Node* dst, } } // namespace -Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) { - Status s; +absl::Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) { + absl::Status s; const DeviceNameUtils::ParsedName& local_dev_name = device->parsed_name(); std::vector scoped_allocator_instances; diff --git a/tensorflow/core/common_runtime/graph_view.h b/tensorflow/core/common_runtime/graph_view.h index ed9b14cfa1f73d..d1fe278a3443a9 100644 --- a/tensorflow/core/common_runtime/graph_view.h +++ b/tensorflow/core/common_runtime/graph_view.h @@ -211,8 +211,8 @@ class GraphView { GraphView() : space_(nullptr) {} ~GraphView(); - Status Initialize(const Graph* g); - Status SetAllocAttrs(const Graph* g, const Device* device); + absl::Status Initialize(const Graph* g); + absl::Status SetAllocAttrs(const Graph* g, const Device* device); void SetScopedAllocatorAttrs(const std::vector& sa_nodes); // Returns a mutable pointer to the `NodeItem` with the given `id` if it diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc index dca693dc594b8e..a47795c4a8d5cb 100644 --- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc +++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc @@ -75,7 +75,7 @@ int HierarchicalTreeBroadcaster::GetDeviceTask( return -1; } -Status HierarchicalTreeBroadcaster::InitializeCollectiveParams( +absl::Status HierarchicalTreeBroadcaster::InitializeCollectiveParams( CollectiveParams* col_params) { CHECK_EQ(col_params->instance.type, BROADCAST_COLLECTIVE); CHECK_EQ(col_params->instance.impl_details.collective_name, @@ -185,7 +185,7 @@ Status HierarchicalTreeBroadcaster::InitializeCollectiveParams( return absl::OkStatus(); } -Status HierarchicalTreeBroadcaster::InitializeCollectiveContext( +absl::Status HierarchicalTreeBroadcaster::InitializeCollectiveContext( std::shared_ptr col_ctx) { CHECK(col_ctx->dev_mgr); col_ctx_ = col_ctx; @@ -322,7 +322,7 @@ void HierarchicalTreeBroadcaster::RunTree() { int recv_from_rank = TreeRecvFrom(*col_params_, si); Notification note; DispatchRecv(si, recv_from_rank, my_rank, col_ctx_->output, - [this, &mu, ¬e](const Status& s) { + [this, &mu, ¬e](const absl::Status& s) { mutex_lock l(mu); status_.Update(s); note.Notify(); @@ -344,16 +344,17 @@ void HierarchicalTreeBroadcaster::RunTree() { mutex_lock l(mu); ++pending_count; } - DispatchSend(si, target_rank, my_rank, - (is_source_ ? col_ctx_->input : col_ctx_->output), - [this, &mu, &pending_count, &all_done](const Status& s) { - mutex_lock l(mu); - status_.Update(s); - --pending_count; - if (pending_count == 0) { - all_done.notify_all(); - } - }); + DispatchSend( + si, target_rank, my_rank, + (is_source_ ? col_ctx_->input : col_ctx_->output), + [this, &mu, &pending_count, &all_done](const absl::Status& s) { + mutex_lock l(mu); + status_.Update(s); + --pending_count; + if (pending_count == 0) { + all_done.notify_all(); + } + }); } } @@ -380,7 +381,7 @@ void HierarchicalTreeBroadcaster::RunTree() { col_ctx_->op_ctx->input_alloc_attr(0), col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input, col_ctx_->output, 0, /*stream_index*/ - [this, &mu, &pending_count, &all_done](const Status& s) { + [this, &mu, &pending_count, &all_done](const absl::Status& s) { mutex_lock l(mu); status_.Update(s); --pending_count; diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h index f058a10c5a0f78..fd5ee9855cf260 100644 --- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h +++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h @@ -35,11 +35,12 @@ class HierarchicalTreeBroadcaster : public CollectiveImplementationInterface { // The first subdiv comprises one device per task which gets the tensor on // each task. Subdiv i+1 corresponds to a task-local tree-broadcast for task // i. - Status InitializeCollectiveParams(CollectiveParams* col_params) override; + absl::Status InitializeCollectiveParams( + CollectiveParams* col_params) override; // Initializes members of CollectiveContext not yet initialized, i.e. device // and device_locality. Also saves the CollectiveContext in this object. - Status InitializeCollectiveContext( + absl::Status InitializeCollectiveContext( std::shared_ptr col_ctx) override; // Begins async execution of the hierarchical tree broadcast. @@ -78,7 +79,7 @@ class HierarchicalTreeBroadcaster : public CollectiveImplementationInterface { std::shared_ptr col_ctx_; const CollectiveParams* col_params_; // Not owned StatusCallback done_; - Status status_; + absl::Status status_; bool is_source_; }; diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc index c79c3fc159be08..ba419077d2774e 100644 --- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc +++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc @@ -244,7 +244,7 @@ class HierarchicalTreeBroadcasterTest : public ::testing::Test { Tensor output_tensor_; Device* device_; core::RefCountPtr col_params_; - Status status_; + absl::Status status_; }; // class DeviceInstance std::unique_ptr test_env_; diff --git a/tensorflow/core/common_runtime/immutable_executor_state.cc b/tensorflow/core/common_runtime/immutable_executor_state.cc index e3a2435505e041..6eef9e802d862e 100644 --- a/tensorflow/core/common_runtime/immutable_executor_state.cc +++ b/tensorflow/core/common_runtime/immutable_executor_state.cc @@ -81,7 +81,7 @@ ImmutableExecutorState::FrameInfo* ImmutableExecutorState::EnsureFrameInfo( } } -Status ImmutableExecutorState::Initialize(const Graph& graph) { +absl::Status ImmutableExecutorState::Initialize(const Graph& graph) { TF_RETURN_IF_ERROR(gview_.Initialize(&graph)); // Build the information about frames in this subgraph. @@ -129,7 +129,7 @@ Status ImmutableExecutorState::Initialize(const Graph& graph) { item->input_start = frame_info->total_inputs; frame_info->total_inputs += n->num_inputs(); - Status s = params_.create_kernel(n->properties(), &item->kernel); + absl::Status s = params_.create_kernel(n->properties(), &item->kernel); if (!s.ok()) { params_.delete_kernel(item->kernel); item->kernel = nullptr; @@ -282,8 +282,8 @@ bool ExtractScopedAllocatorAttr(const std::vector& sc_attr, } } // namespace -Status ImmutableExecutorState::BuildControlFlowInfo(const Graph* g, - ControlFlowInfo* cf_info) { +absl::Status ImmutableExecutorState::BuildControlFlowInfo( + const Graph* g, ControlFlowInfo* cf_info) { const int num_nodes = g->num_node_ids(); cf_info->frame_names.resize(num_nodes); std::vector parent_nodes; diff --git a/tensorflow/core/common_runtime/immutable_executor_state.h b/tensorflow/core/common_runtime/immutable_executor_state.h index a1fca080ca6c5c..6a12bc1fb0b0c0 100644 --- a/tensorflow/core/common_runtime/immutable_executor_state.h +++ b/tensorflow/core/common_runtime/immutable_executor_state.h @@ -78,12 +78,12 @@ class ImmutableExecutorState { : params_(p), gview_() {} ~ImmutableExecutorState(); - Status Initialize(const Graph& graph); + absl::Status Initialize(const Graph& graph); // Process all Nodes in the current graph, attempting to infer the // memory allocation attributes to be used wherever they may allocate // a tensor buffer. - Status SetAllocAttrs(); + absl::Status SetAllocAttrs(); const LocalExecutorParams& params() const { return params_; } const GraphView& graph_view() const { return gview_; } @@ -122,8 +122,8 @@ class ImmutableExecutorState { std::vector frame_names; }; - static Status BuildControlFlowInfo(const Graph* graph, - ControlFlowInfo* cf_info); + static absl::Status BuildControlFlowInfo(const Graph* graph, + ControlFlowInfo* cf_info); void InitializePending(const Graph* graph, const ControlFlowInfo& cf_info); FrameInfo* EnsureFrameInfo(const string& fname); diff --git a/tensorflow/core/common_runtime/inline_function_utils.cc b/tensorflow/core/common_runtime/inline_function_utils.cc index 67250c6abbfc13..c1fea615fba655 100644 --- a/tensorflow/core/common_runtime/inline_function_utils.cc +++ b/tensorflow/core/common_runtime/inline_function_utils.cc @@ -100,7 +100,7 @@ static Node* AddNoOp(StringPiece name, Graph* g) { NodeDef ndef; ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name))); ndef.set_op("NoOp"); - Status s; + absl::Status s; Node* ret = g->AddNode(ndef, &s); TF_CHECK_OK(s); return ret; @@ -113,7 +113,7 @@ static Node* AddIdentity(StringPiece name, Graph* g, Endpoint input) { ndef.set_op("Identity"); ndef.add_input(input.name()); AddNodeAttr("T", BaseType(input.dtype()), &ndef); - Status s; + absl::Status s; Node* ret = g->AddNode(ndef, &s); TF_CHECK_OK(s); g->AddEdge(input.node, input.index, ret, 0); @@ -271,7 +271,7 @@ InlinedFunctionBodyPlacer::MultiDevicePlacer(const Graph& graph, namespace { -Status ValidateNoInline(const FunctionBody* fbody) { +absl::Status ValidateNoInline(const FunctionBody* fbody) { const auto attr = AttrSlice(&fbody->record->fdef().attr()); bool noinline = false; if (TryGetNodeAttr(attr, kNoInlineAttr, &noinline) && noinline) { @@ -332,8 +332,8 @@ string InlineFunctionBodyOptions::DebugString() const { ", uniquify_frame_names=", true_false(uniquify_frame_names)); } -Status ValidateInlining(const Node* node, const FunctionBody* fbody, - const InlineFunctionBodyOptions& options) { +absl::Status ValidateInlining(const Node* node, const FunctionBody* fbody, + const InlineFunctionBodyOptions& options) { // TODO(ezhulenev): Currently common_runtime function inlining can't guarantee // that all side-effectful ops will be executed after inlining. See Grappler // function_optimizer for details. Unify all function inlining mechanism. @@ -476,9 +476,10 @@ Status ValidateInlining(const Node* node, const FunctionBody* fbody, // a single device). // // TODO(ezhulenev): Documentation above is ahead of implementation below. -Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g, - Node* caller, const FunctionBody* fbody, - const InlineFunctionBodyOptions& options) { +absl::Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, + Graph* g, Node* caller, + const FunctionBody* fbody, + const InlineFunctionBodyOptions& options) { VLOG(3) << "Inline function call: " << SummarizeNode(*caller) << " [" << options.DebugString() << "]"; VLOG(4) << "Inlining function: " @@ -486,7 +487,7 @@ Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g, VLOG(4) << "Current graphdef: " << g->ToGraphDefDebug().DebugString(); VLOG(4) << "Caller: " << caller->DebugString(); - Status validation = ValidateInlining(caller, fbody, options); + absl::Status validation = ValidateInlining(caller, fbody, options); if (!validation.ok()) { return errors::Internal("Inlining mismatch: ", validation.message()); } @@ -628,7 +629,7 @@ Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g, TF_RETURN_IF_ERROR( MaybeAddPrefixToColocationConstraints(fn_nodes, prefix, &ndef)); - Status added_node; + absl::Status added_node; Node* clone = g->AddNode(std::move(ndef), &added_node); TF_CHECK_OK(added_node); node_map[n->id()] = clone; @@ -878,7 +879,7 @@ bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph, continue; } FunctionLibraryRuntime::Handle handle; - Status s = InstantiateFunctionCall(node->def(), lib, &handle); + absl::Status s = InstantiateFunctionCall(node->def(), lib, &handle); if (!s.ok()) { LOG(ERROR) << "Failed to instantiate a function: " << s.message(); continue; @@ -890,10 +891,10 @@ bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph, bool inlined_any = false; for (const auto& p : candidates) { - Status inlined = InlineFunctionBody(*fld, graph, p.first, p.second, - p.first->IsPartitionedCall() - ? options.multi_device_options - : options.native_options); + absl::Status inlined = InlineFunctionBody(*fld, graph, p.first, p.second, + p.first->IsPartitionedCall() + ? options.multi_device_options + : options.native_options); if (inlined.ok()) { inlined_any = true; } else { diff --git a/tensorflow/core/common_runtime/inline_function_utils.h b/tensorflow/core/common_runtime/inline_function_utils.h index a11f2e8b770b2a..94c118fe882a20 100644 --- a/tensorflow/core/common_runtime/inline_function_utils.h +++ b/tensorflow/core/common_runtime/inline_function_utils.h @@ -160,8 +160,8 @@ struct InlineFunctionBodyOptions { // // If function can't be safely inlined, returns error message with details why // inlining is not possible or safe. -Status ValidateInlining(const Node* node, const FunctionBody* fbody, - const InlineFunctionBodyOptions& options); +absl::Status ValidateInlining(const Node* node, const FunctionBody* fbody, + const InlineFunctionBodyOptions& options); // Given a "caller" in graph "g", which is a function call of a function // to "fbody". Replaces the "caller" with fbody->graph and connects @@ -171,9 +171,10 @@ Status ValidateInlining(const Node* node, const FunctionBody* fbody, // Returns 'OkStatus()' if function was successfully inlined into the graph. // If function inlining is not possible returns an error with a reason, and // leaves the graph in unmodified state. -Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g, - Node* caller, const FunctionBody* fbody, - const InlineFunctionBodyOptions& options); +absl::Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, + Graph* g, Node* caller, + const FunctionBody* fbody, + const InlineFunctionBodyOptions& options); // There are three types of function calls that could be invoked during // *Tensorflow graph execution*: diff --git a/tensorflow/core/common_runtime/inspecting_placer.cc b/tensorflow/core/common_runtime/inspecting_placer.cc index 8a0eb150dd497d..96799bcf1e4be8 100644 --- a/tensorflow/core/common_runtime/inspecting_placer.cc +++ b/tensorflow/core/common_runtime/inspecting_placer.cc @@ -94,7 +94,7 @@ class ColocationGraphToIOColocationGroups { } } - Status FillGroups(std::vector* group_devices) { + absl::Status FillGroups(std::vector* group_devices) { group_devices->resize(group_ids_.size()); for (const auto& it : group_ids_) { int assigned_group_id = it.second; @@ -125,8 +125,8 @@ InspectingPlacer::InspectingPlacer(const FunctionStack& stack, allow_soft_placement_(allow_soft_placement), log_device_placement_(log_device_placement) {} -Status InspectingPlacer::ComputeIOColocationGroups(const Node& node, - IOColocationGroups* groups) { +absl::Status InspectingPlacer::ComputeIOColocationGroups( + const Node& node, IOColocationGroups* groups) { core::RefCountPtr fdef; NameAttrList func; TF_RETURN_IF_ERROR(GetFunctionDefAndAttrs(flib_def_, node, &fdef, &func)); diff --git a/tensorflow/core/common_runtime/inspecting_placer.h b/tensorflow/core/common_runtime/inspecting_placer.h index b9d19a48917d94..90df36c58139fd 100644 --- a/tensorflow/core/common_runtime/inspecting_placer.h +++ b/tensorflow/core/common_runtime/inspecting_placer.h @@ -76,8 +76,8 @@ class InspectingPlacer { // `node` must be // PlacerInspectionRequiredOpsChecker::IsPlacerInspectionRequired. - Status ComputeIOColocationGroups(const Node& node, - IOColocationGroups* groups); + absl::Status ComputeIOColocationGroups(const Node& node, + IOColocationGroups* groups); private: const FunctionStack stack_; diff --git a/tensorflow/core/common_runtime/int32_fulltype.cc b/tensorflow/core/common_runtime/int32_fulltype.cc index ab2ef6867d122b..094eaa2a9773ec 100644 --- a/tensorflow/core/common_runtime/int32_fulltype.cc +++ b/tensorflow/core/common_runtime/int32_fulltype.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/int32_fulltype.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/full_type.pb.h" #include "tensorflow/core/framework/function.h" @@ -23,18 +24,18 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { -Status Int32FulltypePass::Int32FullTypeForTensor(DataType dtype, - FullTypeDef* tensor_t, - bool set_only_int32, - Node* node, int output_idx) { +absl::Status Int32FulltypePass::Int32FullTypeForTensor(DataType dtype, + FullTypeDef* tensor_t, + bool set_only_int32, + Node* node, + int output_idx) { if (tensor_t->type_id() == TFT_TENSOR) { if (tensor_t->args_size() != 1) { if (node != nullptr) { - return Status( + return absl::Status( absl::StatusCode::kInvalidArgument, absl::StrCat("Full type for node='", node->name(), "' (op='", node->op_def().name(), "') in '", debug_location_, @@ -42,10 +43,11 @@ Status Int32FulltypePass::Int32FullTypeForTensor(DataType dtype, tensor_t->args_size(), " args instead of 1.\n got:\n", tensor_t->DebugString())); } else { - return Status(absl::StatusCode::kInvalidArgument, - absl::StrCat("TFT_TENSOR has ", tensor_t->args_size(), - " args instead of 1.\n got:\n", - tensor_t->DebugString())); + return absl::Status( + absl::StatusCode::kInvalidArgument, + absl::StrCat("TFT_TENSOR has ", tensor_t->args_size(), + " args instead of 1.\n got:\n", + tensor_t->DebugString())); } } if (tensor_t->args(0).type_id() == TFT_INT32) { @@ -65,7 +67,8 @@ static bool is_host_memory_int32(MemoryType mtype, DataType dtype) { return (mtype == HOST_MEMORY) && (dtype == DT_INT32); } -Status Int32FulltypePass::ProcessGraph(Graph* graph, bool ints_on_device) { +absl::Status Int32FulltypePass::ProcessGraph(Graph* graph, + bool ints_on_device) { for (Node* n : graph->op_nodes()) { auto output_types = n->output_types(); bool needs_annotation = false; @@ -83,7 +86,7 @@ Status Int32FulltypePass::ProcessGraph(Graph* graph, bool ints_on_device) { if (n->def().has_experimental_type()) { FullTypeDef* node_t = n->mutable_def()->mutable_experimental_type(); if (node_t->type_id() != TFT_PRODUCT) { - return Status( + return absl::Status( absl::StatusCode::kInvalidArgument, absl::StrCat("Full type for node='", n->name(), "' (op='", n->op_def().name(), @@ -91,7 +94,7 @@ Status Int32FulltypePass::ProcessGraph(Graph* graph, bool ints_on_device) { node_t->DebugString())); } if (node_t->args_size() != output_types.size()) { - return Status( + return absl::Status( absl::StatusCode::kInvalidArgument, absl::StrCat("Full type for node='", n->name(), "' (op='", n->op_def().name(), "') has ", node_t->args_size(), diff --git a/tensorflow/core/common_runtime/int32_fulltype.h b/tensorflow/core/common_runtime/int32_fulltype.h index f2be1c87e3f090..1a55e0bc6a1e7c 100644 --- a/tensorflow/core/common_runtime/int32_fulltype.h +++ b/tensorflow/core/common_runtime/int32_fulltype.h @@ -39,7 +39,7 @@ class Int32FulltypePass { // eager execution). // // This method is not thread-safe. - Status ProcessGraph(Graph* graph, bool ints_on_device); + absl::Status ProcessGraph(Graph* graph, bool ints_on_device); // Update full type information for int32 tensors that are in HOST_MEMORY // to use TFT_SHAPE_TENSOR. The type_id of TENSOR_T is expected to be @@ -51,9 +51,9 @@ class Int32FulltypePass { // of a node, so it does have an outer TFT_PRODUCT. NODE and OUTPUT_IDX are // optional and only used in an error message to say that the tensor is output // OUTPUT_IDX of node NODE. - Status Int32FullTypeForTensor(DataType dtype, FullTypeDef* tensor_t, - bool set_only_int32, Node* node = nullptr, - int output_idx = 0); + absl::Status Int32FullTypeForTensor(DataType dtype, FullTypeDef* tensor_t, + bool set_only_int32, Node* node = nullptr, + int output_idx = 0); private: // Location of where annotations were added for debug messages. diff --git a/tensorflow/core/common_runtime/int32_fulltype_test.cc b/tensorflow/core/common_runtime/int32_fulltype_test.cc index 5d2c0e0b9bdb46..8cfb991cdacd38 100644 --- a/tensorflow/core/common_runtime/int32_fulltype_test.cc +++ b/tensorflow/core/common_runtime/int32_fulltype_test.cc @@ -66,7 +66,7 @@ class Int32FulltypeTest : public ::testing::Test { // Builds the given graph, and (if successful) indexes the node // names for use in placement, and later lookup. - Status BuildGraph(const GraphDefBuilder& builder, Graph* out_graph) { + absl::Status BuildGraph(const GraphDefBuilder& builder, Graph* out_graph) { TF_RETURN_IF_ERROR(GraphDefBuilderToGraph(builder, out_graph)); RebuildNodeNameMap(*out_graph); return absl::OkStatus(); @@ -87,7 +87,8 @@ class Int32FulltypeTest : public ::testing::Test { // Invokes the automatic annotator on "graph" // // REQUIRES: "*graph" was produced by the most recent call to BuildGraph. - Status Int32FulltypeAnnotate(Graph* graph, bool ints_on_device = false) { + absl::Status Int32FulltypeAnnotate(Graph* graph, + bool ints_on_device = false) { Int32FulltypePass int32_fulltype; return int32_fulltype.ProcessGraph(graph, ints_on_device); } diff --git a/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass.cc b/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass.cc index 7945b5cc3cf1c6..890c175b007557 100644 --- a/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass.cc +++ b/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass.cc @@ -23,7 +23,7 @@ limitations under the License. namespace tensorflow { -Status IsolatePlacerInspectionRequiredOpsPass::Run( +absl::Status IsolatePlacerInspectionRequiredOpsPass::Run( const GraphOptimizationPassOptions& options) { if (options.graph == nullptr) { VLOG(1) << "Not running IsolatePlacerInspectionRequiredOpsPass because no " @@ -38,7 +38,8 @@ Status IsolatePlacerInspectionRequiredOpsPass::Run( DumpGraphToFile("isolate_deep_ops_before", *graph, nullptr, "/tmp"); } - Status status = IsolatePlacerInspectionRequiredOps(*options.flib_def, graph); + absl::Status status = + IsolatePlacerInspectionRequiredOps(*options.flib_def, graph); if (VLOG_IS_ON(3) && status.ok()) { DumpGraphToFile("isolate_deep_ops_after", *graph, nullptr, "/tmp"); diff --git a/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass.h b/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass.h index 3d86c4538f0d17..1bcdc001cc2726 100644 --- a/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass.h +++ b/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass.h @@ -55,7 +55,7 @@ namespace tensorflow { // to it. class IsolatePlacerInspectionRequiredOpsPass : public GraphOptimizationPass { public: - Status Run(const GraphOptimizationPassOptions& options) override; + absl::Status Run(const GraphOptimizationPassOptions& options) override; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/local_executor_params.h b/tensorflow/core/common_runtime/local_executor_params.h index 6c201d79dafa8e..a363f11351ac5a 100644 --- a/tensorflow/core/common_runtime/local_executor_params.h +++ b/tensorflow/core/common_runtime/local_executor_params.h @@ -43,8 +43,8 @@ struct LocalExecutorParams { // create_kernel returns an instance of op kernel based on NodeDef. // delete_kernel is called for every kernel used by the executor // when the executor is deleted. - std::function&, - OpKernel**)> + std::function&, + OpKernel**)> create_kernel; std::function delete_kernel; diff --git a/tensorflow/core/common_runtime/lower_case_op.cc b/tensorflow/core/common_runtime/lower_case_op.cc index 4727c941ec6efc..39d1d150fa8a1b 100644 --- a/tensorflow/core/common_runtime/lower_case_op.cc +++ b/tensorflow/core/common_runtime/lower_case_op.cc @@ -42,18 +42,18 @@ class CaseBuilder { bool keep_node_fetchable, Graph* graph); // Constructs the basic conditional control flow using switch and merge nodes. - Status CreatePivotNodes(); + absl::Status CreatePivotNodes(); // Adds the inputs from the if node to the merge nodes of the lowered if. - Status AddInputs(); + absl::Status AddInputs(); // Adds the outputs from the if node to the merge nodes of the lowered if. // Note: no inputs can be added once outputs are added as the then and else // nodes are finalized while adding outputs. - Status AddOutputs(); + absl::Status AddOutputs(); // Builds an identity node with the same outputs as Case. - Status BuildLoweredCaseOutput(); + absl::Status BuildLoweredCaseOutput(); private: // Returns unique name containing the name of the Case op being rewritten @@ -61,7 +61,7 @@ class CaseBuilder { string NewName(const string& infix); // Adds input to both the then and else nodes from src:src_output. - Status AddInput(Node* src, int src_output); + absl::Status AddInput(Node* src, int src_output); // The merged outputs of the then and else nodes. std::vector outputs_; @@ -115,7 +115,7 @@ CaseBuilder::CaseBuilder(Node* case_op, TF_CHECK_OK(case_op_->input_tensor(0, &branch_index_)); } -Status CaseBuilder::CreatePivotNodes() { +absl::Status CaseBuilder::CreatePivotNodes() { // Construct the basic case body (consisting of feeding in the val to // create pivot nodes). Node* branch_index; @@ -143,7 +143,7 @@ string CaseBuilder::NewName(const string& infix) { return graph_->NewName(strings::StrCat(name_, "/", infix)); } -Status CaseBuilder::AddInput(Node* src, int src_output) { +absl::Status CaseBuilder::AddInput(Node* src, int src_output) { Node* input; NodeDebugInfo debug_info(*src); // Colocate the Switch node with the `src` node. @@ -169,7 +169,7 @@ Status CaseBuilder::AddInput(Node* src, int src_output) { return absl::OkStatus(); } -Status CaseBuilder::AddInputs() { +absl::Status CaseBuilder::AddInputs() { // Add input data edges. std::vector edges; TF_RETURN_IF_ERROR(case_op_->input_edges(&edges)); @@ -187,7 +187,7 @@ Status CaseBuilder::AddInputs() { return absl::OkStatus(); } -Status CaseBuilder::AddOutputs() { +absl::Status CaseBuilder::AddOutputs() { // Construct the call nodes for each branch. call_nodes_.resize(num_branches_, nullptr); for (int b = 0; b < num_branches_; b++) { @@ -250,7 +250,7 @@ Status CaseBuilder::AddOutputs() { return absl::OkStatus(); } -Status CaseBuilder::BuildLoweredCaseOutput() { +absl::Status CaseBuilder::BuildLoweredCaseOutput() { // If outputs are empty, it means that we might have only output control // edges (already connected to the `branch_executed_node`). Furthermore it's // illegal to have an IdentityN with empty inputs. @@ -267,7 +267,7 @@ Status CaseBuilder::BuildLoweredCaseOutput() { } // namespace -Status RewriteCaseNode(Node* n, Graph* g, bool keep_node_fetchable) { +absl::Status RewriteCaseNode(Node* n, Graph* g, bool keep_node_fetchable) { VLOG(2) << "Lower Case node (keep_node_fetchable=" << keep_node_fetchable << "): " << SummarizeNode(*n); const AttrValue* branches_attr = n->attrs().Find("branches"); diff --git a/tensorflow/core/common_runtime/lower_case_op.h b/tensorflow/core/common_runtime/lower_case_op.h index 110ac20a929d89..65b56e51d977fd 100644 --- a/tensorflow/core/common_runtime/lower_case_op.h +++ b/tensorflow/core/common_runtime/lower_case_op.h @@ -24,7 +24,7 @@ class Graph; class Node; // Replaces Case node `n` with a lowered form that uses _SwitchN/Merge nodes. -Status RewriteCaseNode(Node* n, Graph* g, bool keep_node_fetchable); +absl::Status RewriteCaseNode(Node* n, Graph* g, bool keep_node_fetchable); } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/lower_case_op_test.cc b/tensorflow/core/common_runtime/lower_case_op_test.cc index 185d7b9d50289f..eb5033cd75b000 100644 --- a/tensorflow/core/common_runtime/lower_case_op_test.cc +++ b/tensorflow/core/common_runtime/lower_case_op_test.cc @@ -51,7 +51,7 @@ SessionOptions SessionOptionsWithInlining() { return session_options; } -Status Rewrite(std::unique_ptr* graph) { +absl::Status Rewrite(std::unique_ptr* graph) { FunctionLibraryDefinition flib_def((*graph)->flib_def()); GraphOptimizationPassOptions opt_options; SessionOptions session_options = SessionOptionsWithInlining(); diff --git a/tensorflow/core/common_runtime/lower_function_call_op.cc b/tensorflow/core/common_runtime/lower_function_call_op.cc index c5226b4eefcc85..bede5151197000 100644 --- a/tensorflow/core/common_runtime/lower_function_call_op.cc +++ b/tensorflow/core/common_runtime/lower_function_call_op.cc @@ -34,9 +34,9 @@ namespace tensorflow { using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode; using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource; -Status RewriteFunctionCallNode(Node* n, Graph* g, - const FunctionLibraryDefinition& flib_def, - bool keep_caller_fetchable) { +absl::Status RewriteFunctionCallNode(Node* n, Graph* g, + const FunctionLibraryDefinition& flib_def, + bool keep_caller_fetchable) { VLOG(2) << "Lower function call node: " << SummarizeNode(*n); // We support lowering of two types of functions that could be invoked by the @@ -103,7 +103,7 @@ Status RewriteFunctionCallNode(Node* n, Graph* g, VLOG(2) << "Pruning disabled before inlining"; } - Status can_inline_function_call = + absl::Status can_inline_function_call = ValidateInlining(n, fbody.get(), inline_options); if (can_inline_function_call.ok()) { TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/common_runtime/lower_function_call_op.h b/tensorflow/core/common_runtime/lower_function_call_op.h index 84f6f68dd71b65..71d5e807c8809b 100644 --- a/tensorflow/core/common_runtime/lower_function_call_op.h +++ b/tensorflow/core/common_runtime/lower_function_call_op.h @@ -28,9 +28,9 @@ class Node; // InlineFunctionBody from `common_runtime/function.{h,cc}`. If function // inlining is not possible or safe (see ValidateInlining), leaves the graph in // unmodified state and returns OkStatus(); -Status RewriteFunctionCallNode(Node* n, Graph* g, - const FunctionLibraryDefinition& flib_def, - bool keep_caller_fetchable); +absl::Status RewriteFunctionCallNode(Node* n, Graph* g, + const FunctionLibraryDefinition& flib_def, + bool keep_caller_fetchable); } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/lower_function_call_op_test.cc b/tensorflow/core/common_runtime/lower_function_call_op_test.cc index ea3de9500b9d3c..d276c7c43abbb7 100644 --- a/tensorflow/core/common_runtime/lower_function_call_op_test.cc +++ b/tensorflow/core/common_runtime/lower_function_call_op_test.cc @@ -57,7 +57,7 @@ SessionOptions SessionOptionsWithInlining() { return session_options; } -Status Rewrite(std::unique_ptr* graph) { +absl::Status Rewrite(std::unique_ptr* graph) { FunctionLibraryDefinition flib_def((*graph)->flib_def()); GraphOptimizationPassOptions opt_options; SessionOptions session_options = SessionOptionsWithInlining(); diff --git a/tensorflow/core/common_runtime/lower_functional_ops.cc b/tensorflow/core/common_runtime/lower_functional_ops.cc index 62c712a0d15360..7cf5af392d518f 100644 --- a/tensorflow/core/common_runtime/lower_functional_ops.cc +++ b/tensorflow/core/common_runtime/lower_functional_ops.cc @@ -98,7 +98,7 @@ bool IsPropagatableDevice(StringPiece device_string) { } // namespace -Status LowerFunctionalOpsPass::Run( +absl::Status LowerFunctionalOpsPass::Run( const GraphOptimizationPassOptions& options) { if (options.partition_graphs != nullptr) { return errors::Internal( diff --git a/tensorflow/core/common_runtime/lower_functional_ops.h b/tensorflow/core/common_runtime/lower_functional_ops.h index c372dfbad40ed6..a849550a1a3459 100644 --- a/tensorflow/core/common_runtime/lower_functional_ops.h +++ b/tensorflow/core/common_runtime/lower_functional_ops.h @@ -34,7 +34,7 @@ class LowerFunctionalOpsPass : public GraphOptimizationPass { public: LowerFunctionalOpsPass() = default; - Status Run(const GraphOptimizationPassOptions& options) override; + absl::Status Run(const GraphOptimizationPassOptions& options) override; static constexpr const char* const kLowerUsingSwitchMergeAttr = LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr; diff --git a/tensorflow/core/common_runtime/lower_functional_ops_test.cc b/tensorflow/core/common_runtime/lower_functional_ops_test.cc index 53a1bf47a72841..057cc4fe4c3e8c 100644 --- a/tensorflow/core/common_runtime/lower_functional_ops_test.cc +++ b/tensorflow/core/common_runtime/lower_functional_ops_test.cc @@ -53,7 +53,7 @@ SessionOptions SessionOptionsWithInlining() { return session_options; } -Status Rewrite(std::unique_ptr* graph) { +absl::Status Rewrite(std::unique_ptr* graph) { FunctionLibraryDefinition flib_def((*graph)->flib_def()); GraphOptimizationPassOptions opt_options; SessionOptions session_options = SessionOptionsWithInlining(); diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc index a2875c7c823b52..e46ef4ff3de543 100644 --- a/tensorflow/core/common_runtime/lower_if_op.cc +++ b/tensorflow/core/common_runtime/lower_if_op.cc @@ -44,18 +44,18 @@ class CondBuilder { Graph* graph); // Constructs the basic conditional control flow using switch and merge nodes. - Status CreatePivotNodes(); + absl::Status CreatePivotNodes(); // Adds the inputs from the if node to the merge nodes of the lowered if. - Status AddInputs(); + absl::Status AddInputs(); // Adds the outputs from the if node to the merge nodes of the lowered if. // Note: no inputs can be added once outputs are added as the then and else // nodes are finalized while adding outputs. - Status AddOutputs(); + absl::Status AddOutputs(); // Builds an identity node with the same outputs as If. - Status BuildLoweredIfOutput(); + absl::Status BuildLoweredIfOutput(); private: // Returns unique name containing the name of the If op being rewritten @@ -63,12 +63,12 @@ class CondBuilder { string NewName(const string& infix); // Adds input to both the then and else nodes from src:src_output. - Status AddInput(Node* src, int src_output); + absl::Status AddInput(Node* src, int src_output); // Finalizes the node described by `node_builder`. If `coloc_attr_` is not // nullptr, adds the colocation attr to the node before finalizing it. - Status SetColocationAndFinalize(NodeBuilder node_builder, Graph* graph, - Node** created_node); + absl::Status SetColocationAndFinalize(NodeBuilder node_builder, Graph* graph, + Node** created_node); // The merged outputs of the then and else nodes. std::vector outputs_; @@ -136,16 +136,16 @@ CondBuilder::CondBuilder(Node* if_op, const NameAttrList& then_fn, } } -Status CondBuilder::SetColocationAndFinalize(NodeBuilder node_builder, - Graph* graph, - Node** created_node) { +absl::Status CondBuilder::SetColocationAndFinalize(NodeBuilder node_builder, + Graph* graph, + Node** created_node) { if (coloc_attr_ != nullptr) { node_builder = node_builder.Attr(kColocationAttrName, *coloc_attr_); } return node_builder.Finalize(graph, created_node); } -Status CondBuilder::CreatePivotNodes() { +absl::Status CondBuilder::CreatePivotNodes() { // Construct the basic cond body (consisting of feeding in the predicate to // create pivot nodes). Node* switch_pred; @@ -176,7 +176,7 @@ string CondBuilder::NewName(const string& infix) { return graph_->NewName(strings::StrCat(name_, "/", infix)); } -Status CondBuilder::AddInput(Node* src, int src_output) { +absl::Status CondBuilder::AddInput(Node* src, int src_output) { Node* input; NodeDebugInfo debug_info(*src); // Colocate the Switch node with the `src` node. @@ -205,7 +205,7 @@ Status CondBuilder::AddInput(Node* src, int src_output) { return absl::OkStatus(); } -Status CondBuilder::AddInputs() { +absl::Status CondBuilder::AddInputs() { // Add input data edges. std::vector edges; TF_RETURN_IF_ERROR(if_op_->input_edges(&edges)); @@ -223,7 +223,7 @@ Status CondBuilder::AddInputs() { return absl::OkStatus(); } -Status CondBuilder::AddOutputs() { +absl::Status CondBuilder::AddOutputs() { // Construct the then and else nodes. // NOTE(rachelim): Here, we don't use `CondBuilder::SetColocationAndFinalize` // because the colocation for branch nodes is applied in python. @@ -279,7 +279,7 @@ Status CondBuilder::AddOutputs() { return absl::OkStatus(); } -Status CondBuilder::BuildLoweredIfOutput() { +absl::Status CondBuilder::BuildLoweredIfOutput() { // If outputs are empty, it means that we might have only output control // edges (already connected to the `branch_executed_node`). Furthermore it's // illegal to have an IdentityN with empty inputs. @@ -297,7 +297,7 @@ Status CondBuilder::BuildLoweredIfOutput() { } // namespace -Status RewriteIfNode(Node* n, Graph* g, bool keep_node_fetchable) { +absl::Status RewriteIfNode(Node* n, Graph* g, bool keep_node_fetchable) { VLOG(2) << "Lower If node (keep_node_fetchable=" << keep_node_fetchable << "): " << SummarizeNode(*n); diff --git a/tensorflow/core/common_runtime/lower_if_op.h b/tensorflow/core/common_runtime/lower_if_op.h index 55b7b91b56f460..c125a1977df5f2 100644 --- a/tensorflow/core/common_runtime/lower_if_op.h +++ b/tensorflow/core/common_runtime/lower_if_op.h @@ -24,7 +24,7 @@ class Graph; class Node; // Replaces If node `n` with its lowered form that uses Switch and Merge nodes. -Status RewriteIfNode(Node* n, Graph* g, bool keep_node_fetchable); +absl::Status RewriteIfNode(Node* n, Graph* g, bool keep_node_fetchable); } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/lower_if_op_test.cc b/tensorflow/core/common_runtime/lower_if_op_test.cc index cf7d35409bb078..91bddb27b452be 100644 --- a/tensorflow/core/common_runtime/lower_if_op_test.cc +++ b/tensorflow/core/common_runtime/lower_if_op_test.cc @@ -49,7 +49,7 @@ SessionOptions SessionOptionsWithInlining() { return session_options; } -Status Rewrite(std::unique_ptr* graph) { +absl::Status Rewrite(std::unique_ptr* graph) { FunctionLibraryDefinition flib_def((*graph)->flib_def()); GraphOptimizationPassOptions opt_options; SessionOptions session_options = SessionOptionsWithInlining(); diff --git a/tensorflow/core/common_runtime/lower_while_op.cc b/tensorflow/core/common_runtime/lower_while_op.cc index 67cab4576b44ab..86880d06d4177c 100644 --- a/tensorflow/core/common_runtime/lower_while_op.cc +++ b/tensorflow/core/common_runtime/lower_while_op.cc @@ -67,10 +67,11 @@ constexpr const char* const kLowerAsMultiDeviceFunctionAttr = // consumer class LowerWhileHelper { public: - static Status Run(Node* while_op, const NameAttrList& cond_fn, - const NameAttrList& body_fn, int parallel_iterations, - Graph* graph, const FunctionLibraryDefinition* flib_def, - bool keep_node_fetchable) { + static absl::Status Run(Node* while_op, const NameAttrList& cond_fn, + const NameAttrList& body_fn, int parallel_iterations, + Graph* graph, + const FunctionLibraryDefinition* flib_def, + bool keep_node_fetchable) { LowerWhileHelper helper(while_op, cond_fn, body_fn, parallel_iterations, graph, flib_def, keep_node_fetchable); return helper.RunInternal(); @@ -85,49 +86,49 @@ class LowerWhileHelper { Graph* graph, const FunctionLibraryDefinition* flib_def, bool keep_node_fetchable); - Status RunInternal(); + absl::Status RunInternal(); void InitializeInputOutputToLoweredNodeMap(); // Creates an Enter node for each `while_op_` input and adds them to // `enter_nodes_`. If the `while_op_` has an incoming control edge from a // `src` node we add a control edge from `src` to each Enter node. - Status CreateEnterNodes(); + absl::Status CreateEnterNodes(); // Creates a Merge node for each Enter node and adds to `merge_nodes_`. // Initially now both inputs of a Merge node are the Enter node. Input at // index 1 is later updated to the output of NextIteration node in // `UpdateMergeNodes`. - Status CreateMergeNodes(); + absl::Status CreateMergeNodes(); // Creates the call node for cond func and stores in `cond_call_node_`. - Status CreateCondFuncCallNode(); + absl::Status CreateCondFuncCallNode(); // Creates a Switch node for each loop var and adds to `switch_nodes_`. // Output at index 1(true) of a Switch node is fed into the loop body. // Output at index 0(false) of a Switch node is fed into the Exit nodes. - Status CreateSwitchNodes(); + absl::Status CreateSwitchNodes(); // Creates the call node for body func and stores in `body_call_node_`. - Status CreateBodyFuncCallNode(); + absl::Status CreateBodyFuncCallNode(); // Creates an Exit node for each loop var and adds to `exit_nodes_`. These // are fed into the consumers of the `while_op_`. - Status CreateExitNodes(); + absl::Status CreateExitNodes(); // Creates an NextIteration node for each loop var and adds to // `next_iteration_nodes_`. - Status CreateNextIterationNodes(); + absl::Status CreateNextIterationNodes(); // Updates input at index 1 of each merge node created in `CreateMergeNodes` // to use the output of NextIteration node created in // `CreateNextIterationNodes` instead. - Status UpdateMergeNodes(); + absl::Status UpdateMergeNodes(); // Updates consumers of the original `while_op_` to instead use the outputs // from the exit nodes in `exit_nodes_`. Also updates any outgoing control // edges to depend on `lowered_while_executed_` instead. - Status UpdateConsumers(); + absl::Status UpdateConsumers(); // Returns unique name containing the name of the While op being rewritten // (name_), infix and a suffix to ensure it is unique within the graph. @@ -225,7 +226,7 @@ LowerWhileHelper::LowerWhileHelper(Node* while_op, const NameAttrList& cond_fn, .enable_colocation_key_propagation_in_while_op_lowering.value(); } -Status LowerWhileHelper::RunInternal() { +absl::Status LowerWhileHelper::RunInternal() { InitializeInputOutputToLoweredNodeMap(); TF_RETURN_IF_ERROR(CreateEnterNodes()); TF_RETURN_IF_ERROR(CreateMergeNodes()); @@ -248,7 +249,7 @@ void LowerWhileHelper::InitializeInputOutputToLoweredNodeMap() { } } -Status LowerWhileHelper::CreateEnterNodes() { +absl::Status LowerWhileHelper::CreateEnterNodes() { // Note: `Node::input_edge` runs in O(num_inputs) so we use // `Node::input_edges` instead so that below loop runs in O(num_inputs) time // and not O(num_inputs^2). @@ -304,7 +305,7 @@ Status LowerWhileHelper::CreateEnterNodes() { return absl::OkStatus(); } -Status LowerWhileHelper::CreateMergeNodes() { +absl::Status LowerWhileHelper::CreateMergeNodes() { for (Node* enter_node : enter_nodes_) { bool is_constant = enter_node->attrs().FindByString("is_constant")->b(); if (is_constant && enter_node->output_type(0) == DT_RESOURCE) { @@ -328,7 +329,7 @@ Status LowerWhileHelper::CreateMergeNodes() { return absl::OkStatus(); } -Status LowerWhileHelper::CreateCondFuncCallNode() { +absl::Status LowerWhileHelper::CreateCondFuncCallNode() { for (int i = 0; i < num_loop_inputs_; i++) { if (IsLoopCarriedResource(i)) { cond_call_builder_.Input(NodeOut(enter_nodes_[i], 0)); @@ -357,7 +358,7 @@ Status LowerWhileHelper::CreateCondFuncCallNode() { return absl::OkStatus(); } -Status LowerWhileHelper::CreateSwitchNodes() { +absl::Status LowerWhileHelper::CreateSwitchNodes() { for (int i = 0; i < num_loop_inputs_; i++) { if (IsLoopCarriedResource(i)) { continue; @@ -393,7 +394,7 @@ Status LowerWhileHelper::CreateSwitchNodes() { return absl::OkStatus(); } -Status LowerWhileHelper::CreateBodyFuncCallNode() { +absl::Status LowerWhileHelper::CreateBodyFuncCallNode() { for (int i = 0; i < num_loop_inputs_; i++) { if (IsLoopCarriedResource(i)) { body_call_builder_.Input(NodeOut(enter_nodes_[i], 0)); @@ -431,7 +432,7 @@ Status LowerWhileHelper::CreateBodyFuncCallNode() { return absl::OkStatus(); } -Status LowerWhileHelper::CreateExitNodes() { +absl::Status LowerWhileHelper::CreateExitNodes() { std::vector outputs; outputs.reserve(num_loop_inputs_); for (int i = 0; i < num_loop_inputs_; i++) { @@ -507,7 +508,7 @@ Status LowerWhileHelper::CreateExitNodes() { return absl::OkStatus(); } -Status LowerWhileHelper::CreateNextIterationNodes() { +absl::Status LowerWhileHelper::CreateNextIterationNodes() { for (int i = 0; i < num_loop_inputs_; i++) { Node* next_iteration; if (IsLoopCarriedResource(i)) { @@ -533,7 +534,7 @@ Status LowerWhileHelper::CreateNextIterationNodes() { return absl::OkStatus(); } -Status LowerWhileHelper::UpdateMergeNodes() { +absl::Status LowerWhileHelper::UpdateMergeNodes() { for (int i = 0; i < merge_nodes_.size(); i++) { TF_RETURN_IF_ERROR( graph_->UpdateEdge(next_iterations_nodes_[i], 0, merge_nodes_[i], 1)); @@ -541,7 +542,7 @@ Status LowerWhileHelper::UpdateMergeNodes() { return absl::OkStatus(); } -Status LowerWhileHelper::UpdateConsumers() { +absl::Status LowerWhileHelper::UpdateConsumers() { for (const Edge* e : while_op_->out_edges()) { if (e->IsControlEdge()) { graph_->AddControlEdge(lowered_while_executed_, e->dst()); @@ -587,9 +588,9 @@ bool LowerWhileHelper::IsLoopCarriedResource(int index) { } // namespace -Status RewriteWhileNode(Node* n, Graph* g, - const FunctionLibraryDefinition* flib_def, - bool keep_node_fetchable) { +absl::Status RewriteWhileNode(Node* n, Graph* g, + const FunctionLibraryDefinition* flib_def, + bool keep_node_fetchable) { VLOG(2) << "Lower While node (keep_node_fetchable=" << keep_node_fetchable << "): " << SummarizeNode(*n); diff --git a/tensorflow/core/common_runtime/lower_while_op.h b/tensorflow/core/common_runtime/lower_while_op.h index 479cdc23a1d845..98095dee77ba46 100644 --- a/tensorflow/core/common_runtime/lower_while_op.h +++ b/tensorflow/core/common_runtime/lower_while_op.h @@ -26,9 +26,9 @@ class FunctionLibraryDefinition; // Replaces While node `n` with its lowered form that uses Enter, Exit, Switch, // Merge, NextIteration and LoopCond nodes. -Status RewriteWhileNode(Node* n, Graph* g, - const FunctionLibraryDefinition* flib_def, - bool keep_node_fetchable); +absl::Status RewriteWhileNode(Node* n, Graph* g, + const FunctionLibraryDefinition* flib_def, + bool keep_node_fetchable); } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/lower_while_op_test.cc b/tensorflow/core/common_runtime/lower_while_op_test.cc index 31c1e40b431e7b..4fe9337c942766 100644 --- a/tensorflow/core/common_runtime/lower_while_op_test.cc +++ b/tensorflow/core/common_runtime/lower_while_op_test.cc @@ -52,7 +52,7 @@ SessionOptions SessionOptionsWithInlining() { return session_options; } -Status Rewrite(std::unique_ptr* graph) { +absl::Status Rewrite(std::unique_ptr* graph) { FunctionLibraryDefinition flib_def((*graph)->flib_def()); GraphOptimizationPassOptions opt_options; SessionOptions session_options = SessionOptionsWithInlining(); diff --git a/tensorflow/core/common_runtime/memory_types.cc b/tensorflow/core/common_runtime/memory_types.cc index d90981a6e883fc..d22d72f1a57019 100644 --- a/tensorflow/core/common_runtime/memory_types.cc +++ b/tensorflow/core/common_runtime/memory_types.cc @@ -46,9 +46,10 @@ struct EndpointEq { } }; -static Status ProcessMemoryTypes( +static absl::Status ProcessMemoryTypes( const DeviceType& device_type, const Graph* g, - const std::function& fn) { + const std::function& + fn) { if (device_type != DEVICE_GPU && !DeviceFactory::IsPluggableDevice(device_type.type_string())) { // On non-GPU devices, HOST_MEMORY and DEVICE_MEMORY are always compatible. @@ -92,7 +93,8 @@ static Status ProcessMemoryTypes( return absl::OkStatus(); } -Status ValidateMemoryTypes(const DeviceType& device_type, const Graph* g) { +absl::Status ValidateMemoryTypes(const DeviceType& device_type, + const Graph* g) { return ProcessMemoryTypes( device_type, g, [](const Edge* e, MemoryType sm, MemoryType dm) { if (sm == dm) { @@ -153,8 +155,8 @@ static Node* Recv(Graph* g, const string& tensor_name, return ret; } -Status EnsureMemoryTypes(const DeviceType& device_type, - const string& device_name, Graph* g) { +absl::Status EnsureMemoryTypes(const DeviceType& device_type, + const string& device_name, Graph* g) { struct Item { const Edge* edge; MemoryType sm; @@ -214,8 +216,9 @@ Status EnsureMemoryTypes(const DeviceType& device_type, return ValidateMemoryTypes(device_type, g); } -Status MemoryTypeForOutput(const DeviceType& device_type, const Graph* g, - const Node* n, int index, MemoryType* memory_type) { +absl::Status MemoryTypeForOutput(const DeviceType& device_type, const Graph* g, + const Node* n, int index, + MemoryType* memory_type) { MemoryTypeVector inp_mvec; MemoryTypeVector out_mvec; TF_RETURN_IF_ERROR(MemoryTypesForNode(g->op_registry(), device_type, n->def(), diff --git a/tensorflow/core/common_runtime/memory_types.h b/tensorflow/core/common_runtime/memory_types.h index f854acfdc55d66..46a943c0a3836e 100644 --- a/tensorflow/core/common_runtime/memory_types.h +++ b/tensorflow/core/common_runtime/memory_types.h @@ -24,7 +24,7 @@ namespace tensorflow { // Returns an error iff *g running on a single device of 'device_type' // has memory type mismatch for any edge's source and destination. -Status ValidateMemoryTypes(const DeviceType& device_type, const Graph* g); +absl::Status ValidateMemoryTypes(const DeviceType& device_type, const Graph* g); // Updates '*g' so that every edge's source and destination has // compatible memory types by inserting proper HostSend/Recv and @@ -35,13 +35,14 @@ Status ValidateMemoryTypes(const DeviceType& device_type, const Graph* g); // Returns OK if '*g' is updated properly (ValidateMemoryTypes(g) must // be OK). Otherwise, returns an error and '*g' may be in an // invalidate state and the caller should discard it. -Status EnsureMemoryTypes(const DeviceType& device_type, - const string& device_name, Graph* g); +absl::Status EnsureMemoryTypes(const DeviceType& device_type, + const string& device_name, Graph* g); // Get the memory type for 'index'th output of node 'n' in graph 'g', when // running on 'device_type'. -Status MemoryTypeForOutput(const DeviceType& device_type, const Graph* g, - const Node* n, int index, MemoryType* memory_type); +absl::Status MemoryTypeForOutput(const DeviceType& device_type, const Graph* g, + const Node* n, int index, + MemoryType* memory_type); } // end namespace tensorflow diff --git a/tensorflow/core/common_runtime/next_pluggable_device/BUILD b/tensorflow/core/common_runtime/next_pluggable_device/BUILD index 57708229f106ad..1e46fb453f6e7d 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/BUILD +++ b/tensorflow/core/common_runtime/next_pluggable_device/BUILD @@ -336,12 +336,12 @@ tf_cc_test( "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/protobuf:coordination_config_proto_cc", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", "@local_xla//xla/tsl/distributed_runtime:call_options", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_client", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service_agent", "@local_xla//xla/tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/protobuf:coordination_config_proto_cc", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc", ], ) diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c/plugin_c_api_test.cc b/tensorflow/core/common_runtime/next_pluggable_device/c/plugin_c_api_test.cc index 79d19f0221b4ef..1ab2aef8b03fb6 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c/plugin_c_api_test.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/c/plugin_c_api_test.cc @@ -32,8 +32,8 @@ limitations under the License. namespace { struct CallbackParams { - std::function callback; - tensorflow::Status status; + std::function callback; + absl::Status status; const TFNPD_Api* api; TFNPD_DeviceEvent* event; @@ -107,7 +107,7 @@ TEST_F(PluginEventTestFixture, TestInvokeCallback) { std::string tennis_goat = "Sampras"; auto done = [result_avref = result_avref.CopyRef(), - &tennis_goat](const tensorflow::Status& status) { + &tennis_goat](const absl::Status& status) { result_avref.emplace(42); LOG(INFO) << "Invoking status callback. Tennis goat is: " << status.message(); @@ -117,7 +117,7 @@ TEST_F(PluginEventTestFixture, TestInvokeCallback) { TFNPD_DeviceEvent* event = example_plugin::CreateDeviceEventAndSetAvailable(host_.get()); - tensorflow::Status status(absl::StatusCode::kInternal, "Federer"); + absl::Status status(absl::StatusCode::kInternal, "Federer"); // CallbackParams stores the "done" callback function passed in by TF, and // status, which is "done"'s arg. We need to add another indirection since we diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_helper.cc b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_helper.cc index 00c899c47b8384..bff3e5ebe7016c 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_helper.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_helper.cc @@ -131,7 +131,7 @@ TF_RendezvousDoneCallbackImpl ToC( using CallbackType = std::function; auto c_callback = new CallbackType( [on_done](TF_RendezvousDoneCallback_Params* params) -> void { - Status status = tsl::StatusFromTF_Status(params->status); + absl::Status status = tsl::StatusFromTF_Status(params->status); // TODO: Pass args through. // auto sender_args = FromC(*params->sender_args); // auto recver_args = FromC(*params->recver_args); @@ -211,20 +211,20 @@ class TfCThunkRendezvous final : public ::tensorflow::RendezvousInterface { ~TfCThunkRendezvous() override = default; - Status Send(const ParsedKey& key, const Args& args, const Tensor& val, - bool is_dead) override; + absl::Status Send(const ParsedKey& key, const Args& args, const Tensor& val, + bool is_dead) override; void RecvAsync(const ParsedKey& key, const Args& args, DoneCallback done) override; - void StartAbort(const Status& status) override; + void StartAbort(const absl::Status& status) override; private: const TF_RendezvousThunk thunk_; }; -Status TfCThunkRendezvous::Send(const ParsedKey& key, const Args& args, - const Tensor& val, const bool is_dead) { +absl::Status TfCThunkRendezvous::Send(const ParsedKey& key, const Args& args, + const Tensor& val, const bool is_dead) { CHECK_OK_AND_ASSIGN(SendParamPtr params, SendParamsToC(key, args, val, is_dead)); thunk_.send_func(thunk_.rendezvous, params.get()); @@ -249,7 +249,7 @@ void TfCThunkRendezvous::RecvAsync(const ParsedKey& key, const Args& args, thunk_.async_recv_func(thunk_.rendezvous, params); } -void TfCThunkRendezvous::StartAbort(const Status& status) { +void TfCThunkRendezvous::StartAbort(const absl::Status& status) { std::unique_ptr> c_status( TF_NewStatus(), &TF_DeleteStatus); tsl::Set_TF_Status_from_Status(c_status.get(), status); diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_internal.cc b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_internal.cc index cb4826e4e114df..e533ffc2072d1f 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_internal.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_internal.cc @@ -114,7 +114,7 @@ DoneCallbackParamDeleter MakeDoneCallbackParamDeleter() { } absl::StatusOr DoneCallbackParamsToC( - const Status& status, const RendezvousInterface::Args& sender_args, + const absl::Status& status, const RendezvousInterface::Args& sender_args, const RendezvousInterface::Args& recver_args, const Tensor& tensor, const bool is_dead) { TF_RendezvousDoneCallback_Params* params = @@ -125,7 +125,7 @@ absl::StatusOr DoneCallbackParamsToC( // TODO: Pass args through. // params->sender_args = new TF_RendezvousArgsStruct(ToC(sender_args)); // params->recver_args = new TF_RendezvousArgsStruct(ToC(recver_args)); - Status tensor_status; + absl::Status tensor_status; params->tensor = TF_TensorFromTensor(tensor, &tensor_status); if (!tensor_status.ok()) { MakeDoneCallbackParamDeleter()(params); @@ -142,7 +142,7 @@ RendezvousInterface::DoneCallback FromC( } TF_RendezvousDoneCallback_Function callback = c_on_done.callback; void* context = c_on_done.context; - auto cpp_callback = [callback, context](const Status& status, + auto cpp_callback = [callback, context](const absl::Status& status, RendezvousInterface::Args sender_args, RendezvousInterface::Args recver_args, const Tensor& tensor, @@ -161,7 +161,7 @@ void SendFunctionThunk(void* opa_rendezvous, TF_RendezvousSend_Params* params) { RendezvousInterface::ParsedKey key = FromC(*params->key); RendezvousInterface::Args args = FromC(*params->args); Tensor tensor; - Status tensor_status = TF_TensorToTensor(params->tensor, &tensor); + absl::Status tensor_status = TF_TensorToTensor(params->tensor, &tensor); bool is_dead = params->is_dead; if (tensor_status.ok()) { tsl::Set_TF_Status_from_Status( @@ -181,7 +181,8 @@ void RecvFunctionThunk(void* opa_rendezvous, RendezvousInterface::Args args = FromC(*params->args); RendezvousInterface::DoneCallback on_done = [device_context = args.device_context, on_done = params->on_done]( - const Status& status, const RendezvousInterface::Args& send_args, + const absl::Status& status, + const RendezvousInterface::Args& send_args, const RendezvousInterface::Args& recv_args, const Tensor& tensor, const bool is_dead) { FromC(on_done)(status, send_args, recv_args, tensor, is_dead); diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_test.cc b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_test.cc index 02ea581b909bfd..7d60b2881a2ae9 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_test.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_test.cc @@ -74,7 +74,7 @@ class FakeAllocator : public Allocator { class FakeDevice : public Device { public: explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} - Status Sync() override { return absl::OkStatus(); } + absl::Status Sync() override { return absl::OkStatus(); } Allocator* GetAllocator(AllocatorAttributes) override { return allocator_.get(); } diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent.cc b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent.cc index 3ed4980b46de90..0e81a38910b6e2 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent.cc @@ -42,8 +42,8 @@ absl::StatusOr ProcessGetKeyValueResult(TF_Buffer* result_buf, } } // namespace -Status CPluginCoordinationServiceAgent::InsertKeyValue(std::string_view key, - std::string_view value) { +absl::Status CPluginCoordinationServiceAgent::InsertKeyValue( + std::string_view key, std::string_view value) { TF_StatusPtr c_status_ptr(TF_NewStatus()); TF_Status* status = c_status_ptr.get(); TF_CoordinationServiceInsertKeyValue(key.data(), key.size(), value.data(), @@ -78,7 +78,8 @@ absl::StatusOr CPluginCoordinationServiceAgent::TryGetKeyValue( return ProcessGetKeyValueResult(result_buf, status); } -Status CPluginCoordinationServiceAgent::DeleteKeyValue(std::string_view key) { +absl::Status CPluginCoordinationServiceAgent::DeleteKeyValue( + std::string_view key) { TF_StatusPtr c_status_ptr(TF_NewStatus()); TF_Status* status = c_status_ptr.get(); TF_CoordinationServiceDeleteKeyValue(key.data(), key.size(), agent_, status); diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent.h b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent.h index 35e61b11d8a007..a7bc27403f5e01 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent.h +++ b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent.h @@ -39,14 +39,15 @@ class CPluginCoordinationServiceAgent : public PluginCoordinationServiceAgent { return TF_CoordinationServiceIsInitialized(agent_); } - Status InsertKeyValue(std::string_view key, std::string_view value) override; + absl::Status InsertKeyValue(std::string_view key, + std::string_view value) override; absl::StatusOr GetKeyValue(std::string_view key) override; absl::StatusOr GetKeyValue(std::string_view key, absl::Duration timeout) override; absl::StatusOr TryGetKeyValue(std::string_view key) override; - Status DeleteKeyValue(std::string_view key) override; + absl::Status DeleteKeyValue(std::string_view key) override; private: TF_CoordinationServiceAgent* agent_; diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc index 5d62f8c58668c6..04c4f96249ca91 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc @@ -28,11 +28,11 @@ limitations under the License. #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tensorflow/core/platform/status.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/coordination_config.pb.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tensorflow { namespace { @@ -183,7 +183,7 @@ class CPluginCoordinationServiceAgentTest : public ::testing::Test { TF_ASSERT_OK(impl_->Initialize( tsl::Env::Default(), /*job_name=*/"test_job", /*task_id=*/0, config, std::move(client_), - /*error_fn=*/[](Status s) { + /*error_fn=*/[](absl::Status s) { LOG(ERROR) << "Coordination agent is set to error: " << s; })); } diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.cc b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.cc index 04fa84e1298071..18331eee70b4bb 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.cc @@ -55,8 +55,8 @@ constexpr int kInvalidLineNumber = -1; namespace tensorflow { // ------------------ CPluginOpKernelConstruction ---------------------------- -Status CPluginOpKernelConstruction::GetBoolAttr(std::string_view attr_name, - bool* value) const { +absl::Status CPluginOpKernelConstruction::GetBoolAttr( + std::string_view attr_name, bool* value) const { TF_StatusPtr c_status_ptr(TF_NewStatus()); TF_Status* status = c_status_ptr.get(); unsigned char bool_as_char; @@ -66,15 +66,15 @@ Status CPluginOpKernelConstruction::GetBoolAttr(std::string_view attr_name, return StatusFromTF_Status(status); } -Status CPluginOpKernelConstruction::GetInt32Attr(std::string_view attr_name, - int32_t* value) const { +absl::Status CPluginOpKernelConstruction::GetInt32Attr( + std::string_view attr_name, int32_t* value) const { TF_StatusPtr c_status_ptr(TF_NewStatus()); TF_Status* status = c_status_ptr.get(); TF_OpKernelConstruction_GetAttrInt32(ctx_, attr_name.data(), value, status); return StatusFromTF_Status(status); } -Status CPluginOpKernelConstruction::GetInt32AttrList( +absl::Status CPluginOpKernelConstruction::GetInt32AttrList( std::string_view attr_name, std::vector* value) const { TF_StatusPtr c_status_ptr(TF_NewStatus()); TF_Status* status = c_status_ptr.get(); @@ -91,16 +91,16 @@ Status CPluginOpKernelConstruction::GetInt32AttrList( return StatusFromTF_Status(status); } -Status CPluginOpKernelConstruction::GetInt64Attr(std::string_view attr_name, - int64_t* value) const { +absl::Status CPluginOpKernelConstruction::GetInt64Attr( + std::string_view attr_name, int64_t* value) const { TF_StatusPtr c_status_ptr(TF_NewStatus()); TF_Status* status = c_status_ptr.get(); TF_OpKernelConstruction_GetAttrInt64(ctx_, attr_name.data(), value, status); return StatusFromTF_Status(status); } -Status CPluginOpKernelConstruction::GetStringAttr(std::string_view attr_name, - std::string* value) const { +absl::Status CPluginOpKernelConstruction::GetStringAttr( + std::string_view attr_name, std::string* value) const { TF_StatusPtr c_status_ptr(TF_NewStatus()); TF_Status* status = c_status_ptr.get(); int list_size = 0, attr_string_size = 0; // list_size is not used. @@ -114,7 +114,7 @@ Status CPluginOpKernelConstruction::GetStringAttr(std::string_view attr_name, return StatusFromTF_Status(status); } -Status CPluginOpKernelConstruction::GetFunctionAttr( +absl::Status CPluginOpKernelConstruction::GetFunctionAttr( std::string_view attr_name, NameAttrList* function) const { TF_StatusPtr c_status_ptr(TF_NewStatus()); TF_Status* status = c_status_ptr.get(); @@ -126,12 +126,12 @@ Status CPluginOpKernelConstruction::GetFunctionAttr( return absl::OkStatus(); } -void CPluginOpKernelConstruction::CtxFailure(const Status& status) { +void CPluginOpKernelConstruction::CtxFailure(const absl::Status& status) { CtxFailure(/*file=*/"", /*line=*/kInvalidLineNumber, status); } void CPluginOpKernelConstruction::CtxFailure(const char* file, int line, - const Status& status) { + const absl::Status& status) { TF_StatusPtr c_status_ptr(TF_NewStatus()); tsl::Set_TF_Status_from_Status(c_status_ptr.get(), status); if (line != kInvalidLineNumber) { @@ -148,7 +148,7 @@ std::string_view CPluginOpKernelContext::GetResourceMgrDefaultContainerName() { return {default_container_name.data, default_container_name.len}; } -Status CPluginOpKernelContext::LookupOrCreateResource( +absl::Status CPluginOpKernelContext::LookupOrCreateResource( std::string_view container_name, std::string_view plugin_resource_name, void** result_plugin_resource, void* (*create_func)(void*), void* create_func_args, void (*delete_func)(void*)) { @@ -171,7 +171,7 @@ CPluginOpKernelContext::GetPluginCoordinationServiceAgent() const { return CreatePluginCoordinationServiceAgent(agent); } -Status CPluginOpKernelContext::CreatePluginVariable( +absl::Status CPluginOpKernelContext::CreatePluginVariable( int index, PluginVariable** variable) const { TF_StatusPtr c_status_ptr(TF_NewStatus()); TF_VariableInfo* c_var_info = @@ -183,7 +183,7 @@ Status CPluginOpKernelContext::CreatePluginVariable( return absl::OkStatus(); } -Status CPluginOpKernelContext::AllocateTempForPluginVariable( +absl::Status CPluginOpKernelContext::AllocateTempForPluginVariable( PluginVariable* variable) { TF_StatusPtr c_status_ptr(TF_NewStatus()); CPluginVariable* c_plugin_variable = @@ -234,8 +234,8 @@ absl::Status CPluginOpKernelContext::GetInput(const char* name, return status; } -Status CPluginOpKernelContext::GetInputRange(std::string_view name, - std::pair* range) const { +absl::Status CPluginOpKernelContext::GetInputRange( + std::string_view name, std::pair* range) const { TF_StatusPtr c_status_ptr(TF_NewStatus()); TF_InputRange_Args args; args.status = c_status_ptr.get(); @@ -266,7 +266,7 @@ std::string_view CPluginOpKernelContext::GetDeviceName() const { return {device_name.data, device_name.len}; } -Status CPluginOpKernelContext::GetConfigProto( +absl::Status CPluginOpKernelContext::GetConfigProto( const ConfigProto** config_proto) const { TF_BufferPtr serialized_config_proto_ptr(TF_NewBuffer()); TF_StatusPtr c_status_ptr(TF_NewStatus()); @@ -274,13 +274,13 @@ Status CPluginOpKernelContext::GetConfigProto( c_status_ptr.get()); TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status_ptr.get())); ConfigProto* config_proto_ptr = new ConfigProto(); - Status status = + absl::Status status = BufferToMessage(serialized_config_proto_ptr.get(), config_proto_ptr); *config_proto = config_proto_ptr; return status; } -Status CPluginOpKernelContext::GetFunctionLibraryDefinition( +absl::Status CPluginOpKernelContext::GetFunctionLibraryDefinition( const FunctionLibraryDefinition** flib_def) const { TF_BufferPtr serialized_function_library_ptr(TF_NewBuffer()); TF_StatusPtr c_status_ptr(TF_NewStatus()); @@ -297,7 +297,7 @@ Status CPluginOpKernelContext::GetFunctionLibraryDefinition( return absl::OkStatus(); } -Status CPluginOpKernelContext::GetResourceHandle( +absl::Status CPluginOpKernelContext::GetResourceHandle( int index, const ResourceHandle** handle) const { TF_BufferPtr serialized_resource_handle_ptr(TF_NewBuffer()); TF_StatusPtr c_status_ptr(TF_NewStatus()); @@ -315,9 +315,9 @@ Status CPluginOpKernelContext::GetResourceHandle( return absl::OkStatus(); } -Status CPluginOpKernelContext::AllocateOutput(int index, - const TensorShape& shape, - Tensor** out) { +absl::Status CPluginOpKernelContext::AllocateOutput(int index, + const TensorShape& shape, + Tensor** out) { TF_StatusPtr c_status_ptr(TF_NewStatus()); const auto num_dims = shape.dims(); int64_t* dim_array = new int64_t[num_dims]; @@ -333,22 +333,23 @@ Status CPluginOpKernelContext::AllocateOutput(int index, return TF_TensorToTensor(c_tensor_ptr.get(), *out); } -Status CPluginOpKernelContext::SetOutput(int index, const Tensor& tensor) { +absl::Status CPluginOpKernelContext::SetOutput(int index, + const Tensor& tensor) { TF_StatusPtr c_status_ptr(TF_NewStatus()); TF_TensorPtr c_tensor_ptr; - Status status; + absl::Status status; c_tensor_ptr.reset(TF_TensorFromTensor(tensor, &status)); TF_RETURN_IF_ERROR(status); TF_SetOutput(ctx_, index, c_tensor_ptr.get(), c_status_ptr.get()); return StatusFromTF_Status(c_status_ptr.get()); } -void CPluginOpKernelContext::CtxFailure(const Status& status) { +void CPluginOpKernelContext::CtxFailure(const absl::Status& status) { CtxFailure(/*file=*/"", /*line=*/kInvalidLineNumber, status); } void CPluginOpKernelContext::CtxFailure(const char* file, int line, - const Status& status) { + const absl::Status& status) { TF_StatusPtr c_status_ptr(TF_NewStatus()); tsl::Set_TF_Status_from_Status(c_status_ptr.get(), status); if (line != kInvalidLineNumber) { diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.h b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.h index 0ee0eb76a381d2..de3158bc122fe4 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.h +++ b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_op_kernel.h @@ -39,19 +39,22 @@ class CPluginOpKernelConstruction : public PluginOpKernelConstruction { explicit CPluginOpKernelConstruction(void* ctx) : ctx_(reinterpret_cast(ctx)) {} - Status GetBoolAttr(std::string_view attr_name, bool* value) const override; - Status GetInt32Attr(std::string_view attr_name, int* value) const override; - Status GetInt32AttrList(std::string_view attr_name, - std::vector* value) const override; - Status GetInt64Attr(std::string_view attr_name, - int64_t* value) const override; - Status GetStringAttr(std::string_view attr_name, - std::string* value) const override; - Status GetFunctionAttr(std::string_view attr_name, - NameAttrList* function) const override; - - void CtxFailure(const Status& status) override; - void CtxFailure(const char* file, int line, const Status& status) override; + absl::Status GetBoolAttr(std::string_view attr_name, + bool* value) const override; + absl::Status GetInt32Attr(std::string_view attr_name, + int* value) const override; + absl::Status GetInt32AttrList(std::string_view attr_name, + std::vector* value) const override; + absl::Status GetInt64Attr(std::string_view attr_name, + int64_t* value) const override; + absl::Status GetStringAttr(std::string_view attr_name, + std::string* value) const override; + absl::Status GetFunctionAttr(std::string_view attr_name, + NameAttrList* function) const override; + + void CtxFailure(const absl::Status& status) override; + void CtxFailure(const char* file, int line, + const absl::Status& status) override; void* GetContext() const override { return ctx_; } @@ -66,20 +69,20 @@ class CPluginOpKernelContext : public PluginOpKernelContext { std::string_view GetResourceMgrDefaultContainerName() override; - Status LookupOrCreateResource(std::string_view container_name, - std::string_view plugin_resource_name, - void** result_plugin_resource, - void* (*create_func)(void*), - void* create_func_args, - void (*delete_func)(void*)) override; + absl::Status LookupOrCreateResource(std::string_view container_name, + std::string_view plugin_resource_name, + void** result_plugin_resource, + void* (*create_func)(void*), + void* create_func_args, + void (*delete_func)(void*)) override; std::unique_ptr GetPluginCoordinationServiceAgent() const override; - Status CreatePluginVariable(int index, - PluginVariable** variable) const override; + absl::Status CreatePluginVariable(int index, + PluginVariable** variable) const override; - Status AllocateTempForPluginVariable(PluginVariable* variable) override; + absl::Status AllocateTempForPluginVariable(PluginVariable* variable) override; int NumInputs() const override { return TF_NumInputs(ctx_); } @@ -87,8 +90,8 @@ class CPluginOpKernelContext : public PluginOpKernelContext { absl::Status GetInput(const char* name, const Tensor** tensor) const override; - Status GetInputRange(std::string_view name, - std::pair* range) const override; + absl::Status GetInputRange(std::string_view name, + std::pair* range) const override; DataType GetInputDataType(int index) const override; @@ -111,7 +114,7 @@ class CPluginOpKernelContext : public PluginOpKernelContext { return ""; } - Status GetConfigProto(const ConfigProto** config_proto) const override; + absl::Status GetConfigProto(const ConfigProto** config_proto) const override; // Note: this function is only meant to clear up `config_proto` created by the // above `CPluginOpKernelContext::GetConfigProto()`. @@ -119,7 +122,7 @@ class CPluginOpKernelContext : public PluginOpKernelContext { delete config_proto; } - Status GetFunctionLibraryDefinition( + absl::Status GetFunctionLibraryDefinition( const FunctionLibraryDefinition** flib_def) const override; // Note: this function is only meant to clear up `flib_def` created by the @@ -129,8 +132,8 @@ class CPluginOpKernelContext : public PluginOpKernelContext { delete flib_def; } - Status GetResourceHandle(int index, - const ResourceHandle** handle) const override; + absl::Status GetResourceHandle(int index, + const ResourceHandle** handle) const override; // Note: this function is only meant to clear up `handle` created by the above // `CPluginOpKernelContext::GetResourceHandle()`. @@ -142,13 +145,14 @@ class CPluginOpKernelContext : public PluginOpKernelContext { return TF_GetGraphDefVersion(ctx_); } - Status AllocateOutput(int index, const TensorShape& shape, - Tensor** out) override; + absl::Status AllocateOutput(int index, const TensorShape& shape, + Tensor** out) override; - Status SetOutput(int index, const Tensor& tensor) override; + absl::Status SetOutput(int index, const Tensor& tensor) override; - void CtxFailure(const Status& status) override; - void CtxFailure(const char* file, int line, const Status& status) override; + void CtxFailure(const absl::Status& status) override; + void CtxFailure(const char* file, int line, + const absl::Status& status) override; void* GetContext() const override { return ctx_; } diff --git a/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_coordination_service_agent.h b/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_coordination_service_agent.h index edba423358448c..930efed4b43845 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_coordination_service_agent.h +++ b/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_coordination_service_agent.h @@ -38,7 +38,8 @@ class DirectPluginCoordinationServiceAgent return agent_->IsInitialized(); } - Status InsertKeyValue(std::string_view key, std::string_view value) override { + absl::Status InsertKeyValue(std::string_view key, + std::string_view value) override { return agent_->InsertKeyValue(key, value); } @@ -55,7 +56,7 @@ class DirectPluginCoordinationServiceAgent return agent_->TryGetKeyValue(key); } - Status DeleteKeyValue(std::string_view key) override { + absl::Status DeleteKeyValue(std::string_view key) override { return agent_->DeleteKeyValue(key); } diff --git a/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_op_kernel.cc b/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_op_kernel.cc index 61116b9e4cdcd4..7a37c3e277a32c 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_op_kernel.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_op_kernel.cc @@ -32,37 +32,37 @@ limitations under the License. namespace tensorflow { -Status DirectPluginOpKernelConstruction::GetBoolAttr(std::string_view attr_name, - bool* value) const { +absl::Status DirectPluginOpKernelConstruction::GetBoolAttr( + std::string_view attr_name, bool* value) const { return ctx_->GetAttr(attr_name, value); } -Status DirectPluginOpKernelConstruction::GetInt32Attr( +absl::Status DirectPluginOpKernelConstruction::GetInt32Attr( std::string_view attr_name, int* value) const { return ctx_->GetAttr(attr_name, value); } -Status DirectPluginOpKernelConstruction::GetInt32AttrList( +absl::Status DirectPluginOpKernelConstruction::GetInt32AttrList( std::string_view attr_name, std::vector* value) const { return ctx_->GetAttr(attr_name, value); } -Status DirectPluginOpKernelConstruction::GetInt64Attr( +absl::Status DirectPluginOpKernelConstruction::GetInt64Attr( std::string_view attr_name, int64_t* value) const { return ctx_->GetAttr(attr_name, value); } -Status DirectPluginOpKernelConstruction::GetStringAttr( +absl::Status DirectPluginOpKernelConstruction::GetStringAttr( std::string_view attr_name, std::string* value) const { return ctx_->GetAttr(attr_name, value); } -Status DirectPluginOpKernelConstruction::GetFunctionAttr( +absl::Status DirectPluginOpKernelConstruction::GetFunctionAttr( std::string_view attr_name, NameAttrList* function) const { return ctx_->GetAttr(attr_name, function); } -Status DirectPluginOpKernelContext::CreatePluginVariable( +absl::Status DirectPluginOpKernelContext::CreatePluginVariable( int index, PluginVariable** variable) const { const auto& arg_tensor = ctx_->input(index); if (arg_tensor.dtype() != DT_RESOURCE) { @@ -78,7 +78,7 @@ Status DirectPluginOpKernelContext::CreatePluginVariable( return absl::OkStatus(); } -Status DirectPluginOpKernelContext::AllocateTempForPluginVariable( +absl::Status DirectPluginOpKernelContext::AllocateTempForPluginVariable( PluginVariable* variable) { auto* direct_variable = reinterpret_cast(variable); if (direct_variable->var_info_.var() == nullptr) { @@ -96,7 +96,7 @@ DirectPluginOpKernelContext::GetResourceMgrDefaultContainerName() { return ctx_->resource_manager()->default_container(); } -Status DirectPluginOpKernelContext::LookupOrCreateResource( +absl::Status DirectPluginOpKernelContext::LookupOrCreateResource( std::string_view container_name, std::string_view plugin_resource_name, void** result_plugin_resource, void* (*create_func)(void*), void* create_func_args, void (*delete_func)(void*)) { @@ -126,12 +126,12 @@ absl::Status DirectPluginOpKernelContext::GetInput( return absl::OkStatus(); } -Status DirectPluginOpKernelContext::GetInput(const char* name, - const Tensor** tensor) const { +absl::Status DirectPluginOpKernelContext::GetInput( + const char* name, const Tensor** tensor) const { return ctx_->input(name, tensor); } -Status DirectPluginOpKernelContext::GetInputRange( +absl::Status DirectPluginOpKernelContext::GetInputRange( std::string_view name, std::pair* range) const { return ctx_->op_kernel().InputRange(name, &range->first, &range->second); } diff --git a/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_op_kernel.h b/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_op_kernel.h index 86b631a29723db..053bf7c7567cc1 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_op_kernel.h +++ b/tensorflow/core/common_runtime/next_pluggable_device/direct_plugin_op_kernel.h @@ -35,20 +35,25 @@ class DirectPluginOpKernelConstruction : public PluginOpKernelConstruction { explicit DirectPluginOpKernelConstruction(void* ctx) : ctx_(reinterpret_cast(ctx)) {} - Status GetBoolAttr(std::string_view attr_name, bool* value) const override; - Status GetInt32Attr(std::string_view attr_name, int* value) const override; - Status GetInt32AttrList(std::string_view attr_name, - std::vector* value) const override; - Status GetInt64Attr(std::string_view attr_name, - int64_t* value) const override; - Status GetStringAttr(std::string_view attr_name, - std::string* value) const override; - Status GetFunctionAttr(std::string_view attr_name, - NameAttrList* function) const override; - - void CtxFailure(const Status& status) override { ctx_->CtxFailure(status); } - - void CtxFailure(const char* file, int line, const Status& status) override { + absl::Status GetBoolAttr(std::string_view attr_name, + bool* value) const override; + absl::Status GetInt32Attr(std::string_view attr_name, + int* value) const override; + absl::Status GetInt32AttrList(std::string_view attr_name, + std::vector* value) const override; + absl::Status GetInt64Attr(std::string_view attr_name, + int64_t* value) const override; + absl::Status GetStringAttr(std::string_view attr_name, + std::string* value) const override; + absl::Status GetFunctionAttr(std::string_view attr_name, + NameAttrList* function) const override; + + void CtxFailure(const absl::Status& status) override { + ctx_->CtxFailure(status); + } + + void CtxFailure(const char* file, int line, + const absl::Status& status) override { ctx_->CtxFailure(file, line, status); } @@ -64,12 +69,12 @@ class DirectPluginOpKernelContext : public PluginOpKernelContext { std::string_view GetResourceMgrDefaultContainerName() override; - Status LookupOrCreateResource(std::string_view container_name, - std::string_view plugin_resource_name, - void** result_plugin_resource, - void* (*create_func)(void*), - void* create_func_args, - void (*delete_func)(void*)) override; + absl::Status LookupOrCreateResource(std::string_view container_name, + std::string_view plugin_resource_name, + void** result_plugin_resource, + void* (*create_func)(void*), + void* create_func_args, + void (*delete_func)(void*)) override; std::unique_ptr GetPluginCoordinationServiceAgent() const override { @@ -77,10 +82,10 @@ class DirectPluginOpKernelContext : public PluginOpKernelContext { ctx_->coordination_service_agent()); } - Status CreatePluginVariable(int index, - PluginVariable** variable) const override; + absl::Status CreatePluginVariable(int index, + PluginVariable** variable) const override; - Status AllocateTempForPluginVariable(PluginVariable* variable) override; + absl::Status AllocateTempForPluginVariable(PluginVariable* variable) override; int NumInputs() const override { return ctx_->num_inputs(); } @@ -88,8 +93,8 @@ class DirectPluginOpKernelContext : public PluginOpKernelContext { absl::Status GetInput(const char* name, const Tensor** tensor) const override; - Status GetInputRange(std::string_view name, - std::pair* range) const override; + absl::Status GetInputRange(std::string_view name, + std::pair* range) const override; DataType GetInputDataType(int index) const override { return ctx_->input_dtype(index); @@ -117,7 +122,7 @@ class DirectPluginOpKernelContext : public PluginOpKernelContext { return ctx_->session_metadata() ? ctx_->session_metadata()->name() : ""; } - Status GetConfigProto(const ConfigProto** config_proto) const override { + absl::Status GetConfigProto(const ConfigProto** config_proto) const override { *config_proto = ctx_->function_library()->config_proto(); return absl::OkStatus(); } @@ -127,7 +132,7 @@ class DirectPluginOpKernelContext : public PluginOpKernelContext { // from FunctionLibraryRuntime in `ctx_`. } - Status GetFunctionLibraryDefinition( + absl::Status GetFunctionLibraryDefinition( const FunctionLibraryDefinition** flib_def) const override { *flib_def = ctx_->function_library()->GetFunctionLibraryDefinition(); return absl::OkStatus(); @@ -139,8 +144,8 @@ class DirectPluginOpKernelContext : public PluginOpKernelContext { // is obtained from FunctionLibraryRuntime in `ctx_`. } - Status GetResourceHandle(int index, - const ResourceHandle** handle) const override { + absl::Status GetResourceHandle(int index, + const ResourceHandle** handle) const override { *handle = &HandleFromInput(ctx_, index); return absl::OkStatus(); } @@ -154,19 +159,22 @@ class DirectPluginOpKernelContext : public PluginOpKernelContext { return ctx_->function_library()->graph_def_version(); } - Status AllocateOutput(int index, const TensorShape& shape, - Tensor** out) override { + absl::Status AllocateOutput(int index, const TensorShape& shape, + Tensor** out) override { return ctx_->allocate_output(index, shape, out); } - Status SetOutput(int index, const Tensor& tensor) override { + absl::Status SetOutput(int index, const Tensor& tensor) override { ctx_->set_output(index, tensor); return absl::OkStatus(); } - void CtxFailure(const Status& status) override { ctx_->CtxFailure(status); } + void CtxFailure(const absl::Status& status) override { + ctx_->CtxFailure(status); + } - void CtxFailure(const char* file, int line, const Status& status) override { + void CtxFailure(const char* file, int line, + const absl::Status& status) override { LOG(WARNING) << "Plugin OP_REQUIRES failed at " << file << ": " << line << ": " << status; ctx_->CtxFailure(file, line, status); diff --git a/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device.cc b/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device.cc index 09a2fff9f37177..11ac09674d48c5 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device.cc @@ -105,28 +105,29 @@ void NextPluggableDevice::ComputeAsync(AsyncOpKernel* op_kernel, } // TODO(chuanhao): implement NextPluggableDevice::Sync(). -Status NextPluggableDevice::Sync() { return absl::OkStatus(); } +absl::Status NextPluggableDevice::Sync() { return absl::OkStatus(); } // TODO(chuanhao): implement NextPluggableDevice::Sync(). void NextPluggableDevice::Sync(const DoneCallback& done) { done(Sync()); } -Status NextPluggableDevice::TryGetDeviceContext(DeviceContext** out_context) { +absl::Status NextPluggableDevice::TryGetDeviceContext( + DeviceContext** out_context) { *out_context = device_context_.get(); (*out_context)->Ref(); return absl::OkStatus(); } -Status NextPluggableDevice::MakeTensorFromProto( +absl::Status NextPluggableDevice::MakeTensorFromProto( const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, Tensor* tensor) { Tensor parsed(tensor_proto.dtype()); if (!parsed.FromProto(cpu_allocator(), tensor_proto)) { - return Status(absl::StatusCode::kInvalidArgument, - absl::StrCat("Cannot parse tensor from proto: ", - tensor_proto.DebugString())); + return absl::Status(absl::StatusCode::kInvalidArgument, + absl::StrCat("Cannot parse tensor from proto: ", + tensor_proto.DebugString())); } - Status status; + absl::Status status; if (alloc_attrs.on_host()) { *tensor = parsed; VLOG(2) << "Allocated tensor at " << DMAHelper::base(tensor); @@ -154,17 +155,17 @@ Status NextPluggableDevice::MakeTensorFromProto( notifications.emplace_back(); Notification& n = *notifications.rbegin(); - StatusCallback done = [&n, &status](const Status& s) { + StatusCallback done = [&n, &status](const absl::Status& s) { if (status.ok()) { status.Update(s); } n.Notify(); }; if (!DMAHelper::CanUseDMA(&from)) { - Status err = - Status(absl::StatusCode::kInternal, - absl::StrCat("NextPluggableDevice copy from non-DMA ", - DataTypeString(from.dtype()), " tensor")); + absl::Status err = + absl::Status(absl::StatusCode::kInternal, + absl::StrCat("NextPluggableDevice copy from non-DMA ", + DataTypeString(from.dtype()), " tensor")); done(err); return err; } @@ -175,16 +176,17 @@ Status NextPluggableDevice::MakeTensorFromProto( // If the tensor is not initialized, we likely ran out of memory. if (!copy_dst->IsInitialized()) { delete copy_dst; - Status err = Status(absl::StatusCode::kResourceExhausted, - absl::StrCat("OOM when allocating tensor of shape ", - from.shape().DebugString(), " and type ", - DataTypeString(from.dtype()))); + absl::Status err = + absl::Status(absl::StatusCode::kResourceExhausted, + absl::StrCat("OOM when allocating tensor of shape ", + from.shape().DebugString(), " and type ", + DataTypeString(from.dtype()))); done(err); return err; } auto wrapped_done = [to, copy_dst, - done = std::move(done)](const Status& s) { + done = std::move(done)](const absl::Status& s) { if (s.ok()) { *to = std::move(*copy_dst); } @@ -198,7 +200,7 @@ Status NextPluggableDevice::MakeTensorFromProto( return absl::OkStatus(); }; - Status s; + absl::Status s; for (int64_t ix = 0; ix < parsed.NumElements(); ++ix) { s = VariantDeviceCopy(VariantDeviceCopyDirection::HOST_TO_DEVICE, from[ix], ©_variant[ix], copier); diff --git a/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device.h b/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device.h index 553b9c37cb6228..3e1e97be11d6e1 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device.h +++ b/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device.h @@ -68,15 +68,15 @@ class NextPluggableDevice : public PjRtBaseDevice { void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) override; - Status Sync() override; + absl::Status Sync() override; void Sync(const DoneCallback& done) override; - Status TryGetDeviceContext(DeviceContext** out_context) override; + absl::Status TryGetDeviceContext(DeviceContext** out_context) override; - Status MakeTensorFromProto(const TensorProto& tensor_proto, - AllocatorAttributes alloc_attrs, - Tensor* tensor) override; + absl::Status MakeTensorFromProto(const TensorProto& tensor_proto, + AllocatorAttributes alloc_attrs, + Tensor* tensor) override; int GetDeviceOrdinal() const { return device_ordinal_; } diff --git a/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_context.cc b/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_context.cc index ddc7949a1a7f39..206c1a4a517141 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_context.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_context.cc @@ -52,7 +52,7 @@ struct StatusCallbackInvocationParams { void InvokeStatusCallbackFn(void* arg) { StatusCallbackInvocationParams* params = reinterpret_cast(arg); - tensorflow::Status cc_status = StatusFromTF_Status(params->status); + absl::Status cc_status = StatusFromTF_Status(params->status); // Invokes the "done" callback here. params->callback(cc_status); // Explicitly delete the params after callback is done. @@ -76,7 +76,7 @@ void NextPluggableDeviceContext::CopyDeviceTensorToCPU( tsl::profiler::TraceMeProducer traceme( [] { return "NextPluggableDeviceContext::CopyDeviceTensorToCPU"; }, tsl::profiler::ContextType::kGeneric); - tensorflow::Status s; + absl::Status s; TF_Tensor* c_cpu_tensor = TF_TensorFromTensor(*cpu_tensor, &s); if (!s.ok()) { done(s); @@ -105,7 +105,7 @@ void NextPluggableDeviceContext::CopyCPUTensorToDevice( tsl::profiler::TraceMeProducer traceme( [] { return "NextPluggableDeviceContext::CopyCPUTensorToDevice"; }, tsl::profiler::ContextType::kGeneric); - tensorflow::Status s; + absl::Status s; TF_Tensor* c_cpu_tensor = TF_TensorFromTensor(*cpu_tensor, &s); if (!s.ok()) { done(s); diff --git a/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_factory.cc b/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_factory.cc index 54cfe76b084530..0540cc3084a726 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_factory.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_factory.cc @@ -49,14 +49,14 @@ absl::StatusOr DeviceShapeRepresentation( &c_xla_shape.value, type, use_fast_memory, ConvertToCXlaLayoutPreference(layout_preference), &c_device_shape.value, tf_status); - const Status status = StatusFromTF_Status(tf_status); + const absl::Status status = StatusFromTF_Status(tf_status); TF_DeleteStatus(tf_status); TF_RETURN_IF_ERROR(status); return c_device_shape.AsCpp(); } } // namespace -Status NextPluggableDeviceFactory::ListPhysicalDevices( +absl::Status NextPluggableDeviceFactory::ListPhysicalDevices( std::vector* devices) { TF_Status* c_status = TF_NewStatus(); int32_t device_count = api_->TFNPD_GetDeviceCount(c_status); @@ -72,7 +72,7 @@ Status NextPluggableDeviceFactory::ListPhysicalDevices( return absl::OkStatus(); } -Status NextPluggableDeviceFactory::CreateDevices( +absl::Status NextPluggableDeviceFactory::CreateDevices( const SessionOptions& session_options, const std::string& name_prefix, std::vector>* devices) { TF_Status* c_status = TF_NewStatus(); diff --git a/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_factory.h b/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_factory.h index 6d9d4d4d1ab16a..0fe4aa3ff59dc5 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_factory.h +++ b/tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_factory.h @@ -35,11 +35,11 @@ class NextPluggableDeviceFactory : public DeviceFactory { device_type_(device_type), compilation_device_name_(compilation_device_name) {} - Status ListPhysicalDevices(std::vector* devices) override; + absl::Status ListPhysicalDevices(std::vector* devices) override; - Status CreateDevices(const SessionOptions& session_options, - const std::string& name_prefix, - std::vector>* devices) override; + absl::Status CreateDevices( + const SessionOptions& session_options, const std::string& name_prefix, + std::vector>* devices) override; const std::string& compilation_device_name() const { return compilation_device_name_; diff --git a/tensorflow/core/common_runtime/next_pluggable_device/plugin_coordination_service_agent.h b/tensorflow/core/common_runtime/next_pluggable_device/plugin_coordination_service_agent.h index 794dc33d5e6266..4d3a1734d22966 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/plugin_coordination_service_agent.h +++ b/tensorflow/core/common_runtime/next_pluggable_device/plugin_coordination_service_agent.h @@ -32,15 +32,15 @@ class PluginCoordinationServiceAgent { virtual bool IsInitialized() const = 0; - virtual Status InsertKeyValue(std::string_view key, - std::string_view value) = 0; + virtual absl::Status InsertKeyValue(std::string_view key, + std::string_view value) = 0; virtual absl::StatusOr GetKeyValue(std::string_view key) = 0; virtual absl::StatusOr GetKeyValue(std::string_view key, absl::Duration timeout) = 0; virtual absl::StatusOr TryGetKeyValue(std::string_view key) = 0; - virtual Status DeleteKeyValue(std::string_view key) = 0; + virtual absl::Status DeleteKeyValue(std::string_view key) = 0; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/next_pluggable_device/plugin_op_kernel.h b/tensorflow/core/common_runtime/next_pluggable_device/plugin_op_kernel.h index a8dfe8b885cb5e..b0123999ada06c 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/plugin_op_kernel.h +++ b/tensorflow/core/common_runtime/next_pluggable_device/plugin_op_kernel.h @@ -66,19 +66,22 @@ class PluginOpKernelConstruction { PluginOpKernelConstruction() = default; virtual ~PluginOpKernelConstruction() = default; - virtual Status GetBoolAttr(std::string_view attr_name, bool* value) const = 0; - virtual Status GetInt32Attr(std::string_view attr_name, int* value) const = 0; - virtual Status GetInt32AttrList(std::string_view attr_name, - std::vector* value) const = 0; - virtual Status GetInt64Attr(std::string_view attr_name, - int64_t* value) const = 0; - virtual Status GetStringAttr(std::string_view attr_name, - std::string* value) const = 0; - virtual Status GetFunctionAttr(std::string_view attr_name, - NameAttrList* function) const = 0; - - virtual void CtxFailure(const Status& status) = 0; - virtual void CtxFailure(const char* file, int line, const Status& status) = 0; + virtual absl::Status GetBoolAttr(std::string_view attr_name, + bool* value) const = 0; + virtual absl::Status GetInt32Attr(std::string_view attr_name, + int* value) const = 0; + virtual absl::Status GetInt32AttrList(std::string_view attr_name, + std::vector* value) const = 0; + virtual absl::Status GetInt64Attr(std::string_view attr_name, + int64_t* value) const = 0; + virtual absl::Status GetStringAttr(std::string_view attr_name, + std::string* value) const = 0; + virtual absl::Status GetFunctionAttr(std::string_view attr_name, + NameAttrList* function) const = 0; + + virtual void CtxFailure(const absl::Status& status) = 0; + virtual void CtxFailure(const char* file, int line, + const absl::Status& status) = 0; virtual void* GetContext() const = 0; }; @@ -90,22 +93,21 @@ class PluginOpKernelContext { virtual std::string_view GetResourceMgrDefaultContainerName() = 0; - virtual Status LookupOrCreateResource(std::string_view container_name, - std::string_view plugin_resource_name, - void** result_plugin_resource, - void* (*create_func)(void*), - void* create_func_args, - void (*delete_func)(void*)) = 0; + virtual absl::Status LookupOrCreateResource( + std::string_view container_name, std::string_view plugin_resource_name, + void** result_plugin_resource, void* (*create_func)(void*), + void* create_func_args, void (*delete_func)(void*)) = 0; virtual std::unique_ptr GetPluginCoordinationServiceAgent() const = 0; // This method will allocate a new `PluginVariable`. Caller is responsible // for managing it's lifetime. - virtual Status CreatePluginVariable(int index, - PluginVariable** variable) const = 0; + virtual absl::Status CreatePluginVariable( + int index, PluginVariable** variable) const = 0; - virtual Status AllocateTempForPluginVariable(PluginVariable* variable) = 0; + virtual absl::Status AllocateTempForPluginVariable( + PluginVariable* variable) = 0; virtual int NumInputs() const = 0; @@ -114,8 +116,8 @@ class PluginOpKernelContext { virtual absl::Status GetInput(const char* name, const Tensor** tensor) const = 0; - virtual Status GetInputRange(std::string_view name, - std::pair* range) const = 0; + virtual absl::Status GetInputRange(std::string_view name, + std::pair* range) const = 0; virtual DataType GetInputDataType(int index) const = 0; @@ -135,32 +137,34 @@ class PluginOpKernelContext { virtual std::string GetSessionName() const = 0; - virtual Status GetConfigProto(const ConfigProto** config_proto) const = 0; + virtual absl::Status GetConfigProto( + const ConfigProto** config_proto) const = 0; virtual void MaybeDeleteConfigProto( const ConfigProto* config_proto) const = 0; - virtual Status GetFunctionLibraryDefinition( + virtual absl::Status GetFunctionLibraryDefinition( const FunctionLibraryDefinition** flib_def) const = 0; virtual void MaybeDeleteFunctionLibraryDefinition( const FunctionLibraryDefinition* flib_def) const = 0; - virtual Status GetResourceHandle(int index, - const ResourceHandle** handle) const = 0; + virtual absl::Status GetResourceHandle( + int index, const ResourceHandle** handle) const = 0; virtual void MaybeDeleteResourceHandle( const ResourceHandle* handle) const = 0; virtual int GetGraphDefVersion() const = 0; - virtual Status AllocateOutput(int index, const TensorShape& shape, - Tensor** out) = 0; + virtual absl::Status AllocateOutput(int index, const TensorShape& shape, + Tensor** out) = 0; - virtual Status SetOutput(int index, const Tensor& tensor) = 0; + virtual absl::Status SetOutput(int index, const Tensor& tensor) = 0; - virtual void CtxFailure(const Status& status) = 0; - virtual void CtxFailure(const char* file, int line, const Status& status) = 0; + virtual void CtxFailure(const absl::Status& status) = 0; + virtual void CtxFailure(const char* file, int line, + const absl::Status& status) = 0; virtual void* GetContext() const = 0; }; diff --git a/tensorflow/core/common_runtime/optimize_cross_host_control_deps.cc b/tensorflow/core/common_runtime/optimize_cross_host_control_deps.cc index 28c8309f823240..700a1d63c09046 100644 --- a/tensorflow/core/common_runtime/optimize_cross_host_control_deps.cc +++ b/tensorflow/core/common_runtime/optimize_cross_host_control_deps.cc @@ -65,6 +65,26 @@ Status BuildIdentityNNode(const Node& source, StringPiece name, return absl::OkStatus(); } +Status BuildIdentityNode(const Node& source, StringPiece name, + const string& device, Graph* graph, + std::vector& inputs, + Node** node) { + NodeDefBuilder builder(name, "Identity", NodeDebugInfo(source)); + if (!device.empty()) { + builder.Device(device); + } + builder.Input(inputs[0]); + + NodeDef def; + TF_RETURN_IF_ERROR(builder.Finalize(&def)); + + TF_ASSIGN_OR_RETURN(*node, graph->AddNode(def)); + if (!device.empty()) { + (*node)->set_assigned_device_name(device); + } + return absl::OkStatus(); +} + const string& RequestedOrAssignedDevice(const Node* n) { if (!n->assigned_device_name().empty()) { return n->assigned_device_name(); @@ -232,20 +252,45 @@ Status OptimizeCrossHostDataOutputEdges(Graph* graph, Node* data_after; std::vector inputs; inputs.reserve(pair.second.size()); - for (const Edge* edge : pair.second) { - inputs.emplace_back(edge->src()->name(), edge->src_output(), - edge->src()->output_type(edge->src_output())); - } - TF_RETURN_IF_ERROR(BuildIdentityNNode( - *n, graph->NewName(strings::StrCat(n->name(), "/", "data_after")), - device, graph, inputs, &data_after)); - - int i = 0; - for (const Edge* edge : pair.second) { - graph->AddEdge(edge->src(), edge->src_output(), data_after, i); - graph->AddEdge(data_after, i, edge->dst(), edge->dst_input()); - graph->RemoveEdge(edge); - i++; + const Edge* edge0 = pair.second[0]; + if (std::all_of(pair.second.begin(), pair.second.end(), + [edge0](const Edge* e) { + return e->src() == edge0->src() && + e->src_output() == edge0->src_output(); + })) { + // Handle the special case of all inputs being identical, which is when + // we only need an Identity op with one input. + // TODO(kramm): Can we break this up further? E.g. what if we have two + // sets of inputs that are both all identical? + inputs.emplace_back(edge0->src()->name(), edge0->src_output(), + edge0->src()->output_type(edge0->src_output())); + TF_RETURN_IF_ERROR(BuildIdentityNode( + *n, graph->NewName(strings::StrCat(n->name(), "/", "data_after")), + device, graph, inputs, &data_after)); + + graph->AddEdge(edge0->src(), edge0->src_output(), data_after, 0); + int i = 0; + for (const Edge* edge : pair.second) { + graph->AddEdge(data_after, 0, edge->dst(), edge->dst_input()); + graph->RemoveEdge(edge); + i++; + } + } else { + for (const Edge* edge : pair.second) { + inputs.emplace_back(edge->src()->name(), edge->src_output(), + edge->src()->output_type(edge->src_output())); + } + TF_RETURN_IF_ERROR(BuildIdentityNNode( + *n, graph->NewName(strings::StrCat(n->name(), "/", "data_after")), + device, graph, inputs, &data_after)); + + int i = 0; + for (const Edge* edge : pair.second) { + graph->AddEdge(data_after, i, edge->dst(), edge->dst_input()); + graph->AddEdge(edge->src(), edge->src_output(), data_after, i); + graph->RemoveEdge(edge); + i++; + } } } } diff --git a/tensorflow/core/common_runtime/optimize_cross_host_control_deps_test.cc b/tensorflow/core/common_runtime/optimize_cross_host_control_deps_test.cc index d8a803ba9aa8e8..8124acef392db9 100644 --- a/tensorflow/core/common_runtime/optimize_cross_host_control_deps_test.cc +++ b/tensorflow/core/common_runtime/optimize_cross_host_control_deps_test.cc @@ -80,9 +80,9 @@ TEST(OptimizeCrossHostControlDepsTest, OptimizeCrossHostDataOutputEdges) { auto b = ops::Identity(scope.WithOpName("b"), a[0]); b.node()->set_assigned_device_name("/job:worker/task:1/CPU:0"); - auto c = ops::Identity(scope.WithOpName("c"), a[0]); + auto c = ops::Identity(scope.WithOpName("c"), a[1]); c.node()->set_assigned_device_name("/job:worker/task:1/CPU:1"); - auto d = ops::Identity(scope.WithOpName("d"), a[1]); + auto d = ops::Identity(scope.WithOpName("d"), a[0]); d.node()->set_assigned_device_name("/job:worker/task:2/CPU:0"); auto e = ops::Identity(scope.WithOpName("e"), a[1]); e.node()->set_assigned_device_name("/job:worker/task:2/CPU:1"); @@ -116,7 +116,7 @@ TEST(OptimizeCrossHostControlDepsTest, OptimizeCrossHostDataOutputEdges) { "/job:worker/task:1/device:CPU:0"); EXPECT_EQ(data_after1->def().input_size(), 2); EXPECT_EQ(data_after1->def().input(0), "a"); - EXPECT_EQ(data_after1->def().input(1), "a"); + EXPECT_EQ(data_after1->def().input(1), "a:1"); EXPECT_EQ(data_after1->op_def().name(), "IdentityN"); ASSERT_NE(data_after2, nullptr); @@ -124,7 +124,7 @@ TEST(OptimizeCrossHostControlDepsTest, OptimizeCrossHostDataOutputEdges) { EXPECT_EQ(data_after2->assigned_device_name(), "/job:worker/task:2/device:CPU:0"); EXPECT_EQ(data_after2->def().input_size(), 2); - EXPECT_EQ(data_after2->def().input(0), "a:1"); + EXPECT_EQ(data_after2->def().input(0), "a"); EXPECT_EQ(data_after2->def().input(1), "a:1"); EXPECT_EQ(data_after2->op_def().name(), "IdentityN"); @@ -142,6 +142,54 @@ TEST(OptimizeCrossHostControlDepsTest, OptimizeCrossHostDataOutputEdges) { EXPECT_EQ(map["e"]->input(0), data_after2->name() + ":1"); } +TEST(OptimizeCrossHostControlDepsTest, + CreatesIdentityNodesWhenInputsIdentical) { + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + auto c1 = ops::Const(scope.WithOpName("c1"), 1.0f); + auto c2 = ops::Const(scope.WithOpName("c2"), 2.0f); + auto a = ops::IdentityN(scope.WithOpName("a"), {c1, c2}); + a.operation.node()->set_assigned_device_name("/job:worker/task:0/CPU:0"); + + auto b = ops::Identity(scope.WithOpName("b"), a[0]); + auto c = ops::Identity(scope.WithOpName("c"), a[0]); + auto d = ops::Identity(scope.WithOpName("d"), a[0]); + auto e = ops::Identity(scope.WithOpName("e"), a[0]); + b.node()->set_assigned_device_name("/job:worker/task:1/CPU:0"); + c.node()->set_assigned_device_name("/job:worker/task:1/CPU:0"); + d.node()->set_assigned_device_name("/job:worker/task:1/CPU:0"); + e.node()->set_assigned_device_name("/job:worker/task:1/CPU:0"); + + Graph graph(OpRegistry::Global()); + TF_ASSERT_OK(scope.ToGraph(&graph)); + ASSERT_EQ(graph.num_op_nodes(), 7); + + TF_ASSERT_OK(OptimizeCrossHostDataOutputEdges( + &graph, /*cross_host_edges_threshold=*/2)); + + ASSERT_EQ(graph.num_op_nodes(), 8); + + Node* data_after = GetNodeByName("a/data_after/_0", &graph); + + ASSERT_NE(data_after, nullptr); + EXPECT_EQ(data_after->op_def().name(), "Identity"); + EXPECT_EQ(data_after->assigned_device_name(), + "/job:worker/task:1/device:CPU:0"); + EXPECT_EQ(data_after->def().input_size(), 1); + EXPECT_EQ(data_after->def().input(0)[0], 'a'); + EXPECT_EQ(data_after->op_def().name(), "Identity"); + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + std::unordered_map map; + for (auto& node : graph_def.node()) { + map[node.name()] = &node; + } + EXPECT_EQ(map["b"]->input(0), data_after->name()); + EXPECT_EQ(map["c"]->input(0), data_after->name()); + EXPECT_EQ(map["d"]->input(0), data_after->name()); + EXPECT_EQ(map["e"]->input(0), data_after->name()); +} + TEST(OptimizeCrossHostControlDepsTest, OptimizeCrossHostControlInputEdges) { tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); auto a = ops::Const(scope.WithOpName("a"), 1.0f); diff --git a/tensorflow/core/common_runtime/optimized_function_graph_info.h b/tensorflow/core/common_runtime/optimized_function_graph_info.h index b15790dbeede36..c23d722176bf07 100644 --- a/tensorflow/core/common_runtime/optimized_function_graph_info.h +++ b/tensorflow/core/common_runtime/optimized_function_graph_info.h @@ -73,8 +73,8 @@ struct OptimizedFunctionGraphInfo { delete; OptimizedFunctionGraphInfo(OptimizedFunctionGraphInfo&& info) = default; // NOLINT - OptimizedFunctionGraphInfo& operator=(OptimizedFunctionGraphInfo&& info) = - default; // NOLINT + OptimizedFunctionGraphInfo& operator=( + OptimizedFunctionGraphInfo&& info) noexcept = default; // NOLINT // Converts from the struct to OptimizedFunctionGraph proto. static OptimizedFunctionGraph ToProto(const OptimizedFunctionGraphInfo& info); diff --git a/tensorflow/core/common_runtime/pluggable_device/BUILD b/tensorflow/core/common_runtime/pluggable_device/BUILD index 5121950354e86e..cc0ee4e16d5418 100644 --- a/tensorflow/core/common_runtime/pluggable_device/BUILD +++ b/tensorflow/core/common_runtime/pluggable_device/BUILD @@ -55,6 +55,7 @@ cc_library( "//tensorflow/core/common_runtime/device:device_event_mgr", "//tensorflow/core/platform:stream_executor", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@local_xla//xla/stream_executor", @@ -86,6 +87,7 @@ cc_library( "//tensorflow/compiler/jit:pjrt_device_context", "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration", "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -99,6 +101,8 @@ cc_library( "//tensorflow/core/common_runtime/next_pluggable_device:next_pluggable_device_factory", "//tensorflow/core/platform:stream_executor", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@local_xla//xla/pjrt:pjrt_api", "@local_xla//xla/pjrt/c:pjrt_c_api_hdrs", @@ -115,6 +119,7 @@ cc_library( linkstatic = 1, deps = [ "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -151,6 +156,9 @@ cc_library( "//tensorflow/core/common_runtime:bfc_allocator", "//tensorflow/core/common_runtime/device:device_id", "//tensorflow/core/common_runtime/device:device_mem_allocator", + "@com_google_absl//absl/log", + "@com_google_absl//absl/memory", + "@local_xla//xla/tsl/framework:bfc_allocator", ], ) @@ -168,6 +176,7 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/common_runtime/device:device_id", "//tensorflow/core/common_runtime/device:device_mem_allocator", + "//tensorflow/core/framework:allocator", ], ) diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device.cc index 161a21c5d44615..ef1f2a831ee9a5 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device.cc @@ -27,38 +27,40 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/strings/ascii.h" +#include "xla/stream_executor/stream_executor.h" #include "tensorflow/core/common_runtime/device/device_event_mgr.h" #include "tensorflow/core/common_runtime/device/device_id.h" #include "tensorflow/core/common_runtime/device/device_id_manager.h" -#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_id_utils.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h" -#include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.h" #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.h" #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.h" #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.h" #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/variant_op_registry.h" #include "tensorflow/core/graph/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/platform/threadpool.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session_options.h" -#include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/util/env_var.h" -#include "tensorflow/core/util/stream_executor_util.h" +#include "tsl/platform/errors.h" namespace tensorflow { @@ -197,7 +199,7 @@ PluggableDevice::~PluggableDevice() { device_context_->Unref(); } -Status PluggableDevice::Init(const SessionOptions& options) { +absl::Status PluggableDevice::Init(const SessionOptions& options) { se::Platform* platform = PluggableDeviceMachineManager(platform_name_); auto executor_status = DeviceIdUtil::ExecutorForTfDeviceId( DeviceType(device_type()), platform, tf_device_id_); @@ -335,7 +337,9 @@ void PluggableDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { // Based on the semantics of Device::Sync, this call should wait for // all streams not just the current one. -Status PluggableDevice::Sync() { return PluggableDeviceUtil::SyncAll(this); } +absl::Status PluggableDevice::Sync() { + return PluggableDeviceUtil::SyncAll(this); +} void PluggableDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, @@ -353,7 +357,7 @@ void PluggableDevice::ComputeAsync(AsyncOpKernel* op_kernel, op_kernel->ComputeAsync(context, std::move(done)); } -Status PluggableDevice::MaybeCopyTensorToPluggableDevice( +absl::Status PluggableDevice::MaybeCopyTensorToPluggableDevice( const AllocatorAttributes& alloc_attrs, const Tensor& from, Tensor* to, StatusCallback done) { if (alloc_attrs.on_host()) { @@ -362,8 +366,9 @@ Status PluggableDevice::MaybeCopyTensorToPluggableDevice( return absl::OkStatus(); } else { if (!DMAHelper::CanUseDMA(&from)) { - Status err = errors::Internal("PluggableDevice copy from non-DMA ", - DataTypeString(from.dtype()), " tensor"); + absl::Status err = + errors::Internal("PluggableDevice copy from non-DMA ", + DataTypeString(from.dtype()), " tensor"); done(err); return err; } @@ -374,14 +379,15 @@ Status PluggableDevice::MaybeCopyTensorToPluggableDevice( // If the tensor is not initialized, we likely ran out of memory. if (!copy->IsInitialized()) { delete copy; - Status err = errors::ResourceExhausted( + absl::Status err = errors::ResourceExhausted( "OOM when allocating tensor of shape ", from.shape().DebugString(), " and type ", DataTypeString(from.dtype())); done(err); return err; } - auto wrapped_done = [to, copy, done = std::move(done)](const Status& s) { + auto wrapped_done = [to, copy, + done = std::move(done)](const absl::Status& s) { if (s.ok()) { *to = std::move(*copy); } @@ -395,7 +401,7 @@ Status PluggableDevice::MaybeCopyTensorToPluggableDevice( } } -Status PluggableDevice::MakeTensorFromProto( +absl::Status PluggableDevice::MakeTensorFromProto( const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, Tensor* tensor) { AllocatorAttributes attr; @@ -415,7 +421,7 @@ Status PluggableDevice::MakeTensorFromProto( Variant* copy_variant = copy.flat().data(); std::list notifications; - Status copy_status; + absl::Status copy_status; auto copier = [this, &alloc_attrs, ¬ifications, ©_status]( const Tensor& from, Tensor* to) { // Copier isn't run in a multithreaded environment, so we don't @@ -423,14 +429,14 @@ Status PluggableDevice::MakeTensorFromProto( notifications.emplace_back(); Notification& n = *notifications.rbegin(); return MaybeCopyTensorToPluggableDevice( - alloc_attrs, from, to, [&n, ©_status](const Status& s) { + alloc_attrs, from, to, [&n, ©_status](const absl::Status& s) { if (copy_status.ok()) { copy_status.Update(s); } n.Notify(); }); }; - Status s; + absl::Status s; for (int64_t ix = 0; ix < parsed.NumElements(); ++ix) { s = VariantDeviceCopy(VariantDeviceCopyDirection::HOST_TO_DEVICE, from[ix], ©_variant[ix], copier); @@ -448,9 +454,9 @@ Status PluggableDevice::MakeTensorFromProto( return copy_status; } else { Notification n; - Status status; + absl::Status status; TF_RETURN_IF_ERROR(MaybeCopyTensorToPluggableDevice( - alloc_attrs, parsed, tensor, [&n, &status](const Status& s) { + alloc_attrs, parsed, tensor, [&n, &status](const absl::Status& s) { status = s; n.Notify(); })); diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device.h b/tensorflow/core/common_runtime/pluggable_device/pluggable_device.h index 74ad5893921ecc..80c46b0a865310 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device.h +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "xla/stream_executor/stream_executor.h" #include "tensorflow/core/common_runtime/device/device_event_mgr.h" #include "tensorflow/core/common_runtime/device/device_id.h" #include "tensorflow/core/common_runtime/device/device_id_manager.h" @@ -28,13 +29,16 @@ limitations under the License. #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h" #include "tensorflow/core/common_runtime/shared_counter.h" #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/platform/threadpool.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session_options.h" @@ -53,20 +57,20 @@ class PluggableDevice : public LocalDevice { ~PluggableDevice() override; // Initialize the device and return the status of initialization. - Status Init(const SessionOptions& options); + absl::Status Init(const SessionOptions& options); void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) override; void Compute(OpKernel* op_kernel, OpKernelContext* context) override; - Status Sync() override; + absl::Status Sync() override; Allocator* GetAllocator(AllocatorAttributes attr) override; - Status MakeTensorFromProto(const TensorProto& tensor_proto, - AllocatorAttributes alloc_attrs, - Tensor* tensor) override; + absl::Status MakeTensorFromProto(const TensorProto& tensor_proto, + AllocatorAttributes alloc_attrs, + Tensor* tensor) override; void CopyTensorInSameDevice(const Tensor* input_tensor, Tensor* output_tensor, const DeviceContext* device_context, @@ -84,7 +88,7 @@ class PluggableDevice : public LocalDevice { se::Stream* compute = nullptr; se::Stream* host_to_device = nullptr; se::Stream* device_to_host = nullptr; - gtl::InlinedVector device_to_device; + absl::InlinedVector device_to_device; }; class StreamGroupFactory; @@ -107,7 +111,7 @@ class PluggableDevice : public LocalDevice { // allocate memory or if the tensor "from" is not DMA-copyable. // If there is no error prior to enqueueing the copy, an OK status // is returned. - Status MaybeCopyTensorToPluggableDevice( + absl::Status MaybeCopyTensorToPluggableDevice( const AllocatorAttributes& alloc_attrs, const Tensor& from, Tensor* to, StatusCallback done); }; diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_bfc_allocator.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_bfc_allocator.cc index eba49bb8a1aff9..1523c64a1ae4ed 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_bfc_allocator.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_bfc_allocator.cc @@ -17,7 +17,12 @@ limitations under the License. #include -#include "tensorflow/core/lib/strings/strcat.h" +#include "absl/log/log.h" +#include "absl/memory/memory.h" +#include "xla/tsl/framework/bfc_allocator.h" +#include "tensorflow/core/common_runtime/device/device_mem_allocator.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" namespace tensorflow { diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.cc index ec3faf2d6329ca..2c67fd687a74ba 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.cc @@ -17,12 +17,14 @@ limitations under the License. #include -#include "tensorflow/core/common_runtime/device.h" +#include "absl/status/status.h" #include "tensorflow/core/common_runtime/device/device_event_mgr.h" #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" namespace tensorflow { @@ -50,8 +52,9 @@ void PluggableDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor, device, this, input_tensor, output_tensor, done); } -Status PluggableDeviceContext::ThenExecute(Device* device, se::Stream* stream, - std::function func) { +absl::Status PluggableDeviceContext::ThenExecute(Device* device, + se::Stream* stream, + std::function func) { const DeviceBase::AcceleratorDeviceInfo* device_info = device->tensorflow_accelerator_device_info(); device_info->event_mgr->ThenExecute(stream, func); diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h index 5798be1c13bc74..8ec93d3fd51fa7 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h @@ -20,7 +20,10 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/stringpiece.h" namespace stream_executor { class Stream; @@ -34,7 +37,7 @@ class PluggableDeviceContext : public DeviceContext { PluggableDeviceContext( int stream_id, se::Stream* stream, se::Stream* host_to_device_stream, se::Stream* device_to_host_stream, - gtl::InlinedVector device_to_device_stream) + absl::InlinedVector device_to_device_stream) : stream_id_(stream_id), stream_(stream), host_to_device_stream_(host_to_device_stream), @@ -66,8 +69,8 @@ class PluggableDeviceContext : public DeviceContext { void MaintainLifetimeOnStream(const Tensor* t, se::Stream* stream) const override {} - Status ThenExecute(Device* device, se::Stream* stream, - std::function func) override; + absl::Status ThenExecute(Device* device, se::Stream* stream, + std::function func) override; bool IsPluggableDevice() override; @@ -81,7 +84,7 @@ class PluggableDeviceContext : public DeviceContext { // The stream to use for copying data from PluggableDevice to host. se::Stream* device_to_host_stream_; // Streams to use for copying data between PluggableDevices. - gtl::InlinedVector device_to_device_stream_; + absl::InlinedVector device_to_device_stream_; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc index b42c0bf1affdaf..4d020d09248afa 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc @@ -82,10 +82,10 @@ int64_t MinSystemMemory(int64_t available_memory) { // Get the memory limit for the virtual device being created on PluggableDevice // with 'platform_device_id', when that virtual device is the only // virtual device being created on that PluggableDevice. -Status SingleVirtualDeviceMemoryLimit(const string& platform_name, - const GPUOptions& device_options, - PlatformDeviceId platform_device_id, - int64_t* memory_limit) { +absl::Status SingleVirtualDeviceMemoryLimit(const string& platform_name, + const GPUOptions& device_options, + PlatformDeviceId platform_device_id, + int64_t* memory_limit) { int64_t total_memory = 0; int64_t available_memory = 0; se::Platform* platform = PluggableDeviceMachineManager(platform_name); @@ -123,7 +123,7 @@ PluggableDeviceFactory::PluggableDeviceFactory(const string& device_type, const string& platform_name) : device_type_(device_type), platform_name_(platform_name) {} -Status PluggableDeviceFactory::ListPhysicalDevices( +absl::Status PluggableDeviceFactory::ListPhysicalDevices( std::vector* devices) { TF_RETURN_IF_ERROR(ValidatePluggableDeviceMachineManager(platform_name_)); se::Platform* platform = PluggableDeviceMachineManager(platform_name_); @@ -138,7 +138,7 @@ Status PluggableDeviceFactory::ListPhysicalDevices( return absl::OkStatus(); } -Status PluggableDeviceFactory::GetDeviceDetails( +absl::Status PluggableDeviceFactory::GetDeviceDetails( int device_index, std::unordered_map* details) { TF_RETURN_IF_ERROR(ValidatePluggableDeviceMachineManager(platform_name_)); se::Platform* platform = PluggableDeviceMachineManager(platform_name_); @@ -162,7 +162,7 @@ Status PluggableDeviceFactory::GetDeviceDetails( return absl::OkStatus(); } -Status PluggableDeviceFactory::CreateDevices( +absl::Status PluggableDeviceFactory::CreateDevices( const SessionOptions& options, const string& name_prefix, std::vector>* devices) { TF_RETURN_IF_ERROR(ValidatePluggableDeviceMachineManager(platform_name_)); @@ -221,7 +221,7 @@ static string GetShortDeviceDescription(PlatformDeviceId platform_device_id, ", pci bus id: ", desc.pci_bus_id()); } -Status PluggableDeviceFactory::CreatePluggableDevice( +absl::Status PluggableDeviceFactory::CreatePluggableDevice( const SessionOptions& options, const string& name_prefix, TfDeviceId tf_device_id, int64_t memory_limit, const DeviceLocality& dev_locality, @@ -277,7 +277,7 @@ Status PluggableDeviceFactory::CreatePluggableDevice( return absl::OkStatus(); } -Status PluggableDeviceFactory::GetDeviceLocalities( +absl::Status PluggableDeviceFactory::GetDeviceLocalities( int num_tf_devices, std::vector* device_localities) { for (int i = 0; i < num_tf_devices; ++i) { TfDeviceId tf_device_id(i); diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.h b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.h index c423dd2a1fc13a..02cc8d2cc0d023 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.h +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.h @@ -24,6 +24,9 @@ limitations under the License. #include "tensorflow/core/common_runtime/device/device_id.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/device_factory.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { @@ -31,26 +34,26 @@ class PluggableDeviceFactory : public DeviceFactory { public: PluggableDeviceFactory(const string& device_type, const string& platform_name); - Status ListPhysicalDevices(std::vector* devices) override; - Status CreateDevices(const SessionOptions& options, - const std::string& name_prefix, - std::vector>* devices) override; - Status GetDeviceDetails(int device_index, - std::unordered_map* details) override; + absl::Status ListPhysicalDevices(std::vector* devices) override; + absl::Status CreateDevices( + const SessionOptions& options, const std::string& name_prefix, + std::vector>* devices) override; + absl::Status GetDeviceDetails( + int device_index, std::unordered_map* details) override; private: // Populates *device_localities with the DeviceLocality descriptor for // every TfDeviceId. - Status GetDeviceLocalities(int num_tf_devices, - std::vector* device_localities); + absl::Status GetDeviceLocalities( + int num_tf_devices, std::vector* device_localities); // Create a PluggableDevice associated with 'tf_device_id', allocates // (strictly) 'memory_limit' bytes of PluggableDevice memory to it, and adds // it to the 'devices' vector. - Status CreatePluggableDevice(const SessionOptions& options, - const std::string& name_prefix, - TfDeviceId tf_device_id, int64_t memory_limit, - const DeviceLocality& dev_locality, - std::vector>* devices); + absl::Status CreatePluggableDevice( + const SessionOptions& options, const std::string& name_prefix, + TfDeviceId tf_device_id, int64_t memory_limit, + const DeviceLocality& dev_locality, + std::vector>* devices); const string device_type_; const string platform_name_; diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.cc index 0b7279d0098ac1..8c6eaea69aee4c 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.cc @@ -18,19 +18,14 @@ limitations under the License. #include #include "xla/stream_executor/platform_manager.h" -#include "tensorflow/core/common_runtime/device_factory.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/stream_executor_util.h" namespace tensorflow { -Status ValidatePluggableDeviceMachineManager(const string& platform_name) { +absl::Status ValidatePluggableDeviceMachineManager( + const string& platform_name) { return se::PlatformManager::PlatformWithName(platform_name).status(); } diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.h b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.h index 6362de9856ae38..0afd6063880bde 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.h +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" namespace stream_executor { class Platform; @@ -28,7 +29,7 @@ namespace tensorflow { // Initializes the PluggableDevice platform and returns OK if the // PluggableDevice platform could be initialized. -Status ValidatePluggableDeviceMachineManager(const string& platform_name); +absl::Status ValidatePluggableDeviceMachineManager(const string& platform_name); // Returns the PluggableDevice machine manager singleton, creating it and // initializing the PluggableDevices on the machine if needed the first time it diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc index a3879acd5daa08..ca2b3ca2574407 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "tensorflow/c/experimental/grappler/grappler_internal.h" #include "tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h" @@ -31,18 +33,20 @@ limitations under the License. #include "tensorflow/core/common_runtime/next_pluggable_device/next_pluggable_device_factory.h" #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.h" #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.h" +#include "tensorflow/core/framework/device_factory.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace tensorflow { -static Status InitDeviceModule(void* dso_handle) { +static absl::Status InitDeviceModule(void* dso_handle) { void* dso_symbol; tensorflow::Env* env = tensorflow::Env::Default(); - Status status = + absl::Status status = env->GetSymbolFromLibrary(dso_handle, "SE_InitPlugin", &dso_symbol); if (absl::IsNotFound(status)) { @@ -72,12 +76,12 @@ static Status InitDeviceModule(void* dso_handle) { } typedef const PJRT_Api* (*PjrtApiInitFn)(); -static Status InitNextPluggableDeviceModule(void* dso_handle) { +static absl::Status InitNextPluggableDeviceModule(void* dso_handle) { void* dso_symbol; tensorflow::Env* env = tensorflow::Env::Default(); // Loads the next pluggable device. - Status status = + absl::Status status = env->GetSymbolFromLibrary(dso_handle, "TFNPD_InitPlugin", &dso_symbol); if (absl::IsNotFound(status)) { VLOG(1) << "Next pluggable device module not found."; @@ -143,10 +147,10 @@ static Status InitNextPluggableDeviceModule(void* dso_handle) { return absl::OkStatus(); } -static Status InitGraphModule(void* dso_handle) { +static absl::Status InitGraphModule(void* dso_handle) { void* dso_symbol; tensorflow::Env* env = tensorflow::Env::Default(); - Status status = + absl::Status status = env->GetSymbolFromLibrary(dso_handle, "TF_InitGraph", &dso_symbol); if (absl::IsNotFound(status)) { @@ -163,10 +167,10 @@ static Status InitGraphModule(void* dso_handle) { } typedef void (*TFKernelInitFn)(); -static Status InitKernelModule(void* dso_handle) { +static absl::Status InitKernelModule(void* dso_handle) { void* dso_symbol; tensorflow::Env* env = tensorflow::Env::Default(); - Status status = + absl::Status status = env->GetSymbolFromLibrary(dso_handle, "TF_InitKernel", &dso_symbol); if (absl::IsNotFound(status)) { @@ -183,11 +187,11 @@ static Status InitKernelModule(void* dso_handle) { return absl::OkStatus(); } -static Status InitProfilerModule(void* dso_handle) { +static absl::Status InitProfilerModule(void* dso_handle) { void* dso_symbol; tensorflow::Env* env = tensorflow::Env::Default(); - Status status = + absl::Status status = env->GetSymbolFromLibrary(dso_handle, "TF_InitProfiler", &dso_symbol); if (absl::IsNotFound(status)) { @@ -204,7 +208,7 @@ static Status InitProfilerModule(void* dso_handle) { return absl::OkStatus(); } -Status RegisterPluggableDevicePlugin(void* dso_handle) { +absl::Status RegisterPluggableDevicePlugin(void* dso_handle) { // All modules are optional. Only return an error when a module is found but // has issues in loading / initializing. // Step 1 Init Device Module. diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.h b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.h index 23b9af1d58ad47..9676a70662ee4d 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.h +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.h @@ -20,7 +20,7 @@ limitations under the License. namespace tensorflow { -Status RegisterPluggableDevicePlugin(void* library_filename); +absl::Status RegisterPluggableDevicePlugin(void* library_filename); } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.cc index f9f0ad68977516..9a591508200e56 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.cc @@ -20,29 +20,33 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h" -#include "xla/stream_executor/integrations/device_mem_allocator.h" +#include "xla/stream_executor/stream_executor.h" #include "xla/tsl/framework/device_id_utils.h" +#include "tensorflow/core/common_runtime/bfc_allocator.h" #include "tensorflow/core/common_runtime/device/device_host_allocator.h" #include "tensorflow/core/common_runtime/device/device_id.h" #include "tensorflow/core/common_runtime/device/device_id_manager.h" #include "tensorflow/core/common_runtime/device/device_mem_allocator.h" -#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_id_utils.h" #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_bfc_allocator.h" #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_init.h" #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.h" -#include "tensorflow/core/common_runtime/pool_allocator.h" -#include "tensorflow/core/common_runtime/shared_counter.h" +#include "tensorflow/core/common_runtime/process_state.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/tracking_allocator.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/platform/numa.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/util/env_var.h" +#include "tsl/platform/status.h" namespace tensorflow { @@ -192,9 +196,9 @@ Allocator* PluggableDeviceProcessState::GetPluggableDeviceHostAllocator( se, numa_node, pluggable_device_host_alloc_visitors_[numa_node], pluggable_device_host_free_visitors_[numa_node]); int64_t pluggable_device_host_mem_limit_in_mb = -1; - Status status = ReadInt64FromEnvVar("TF_GPU_HOST_MEM_LIMIT_IN_MB", - 1LL << 17 /*128GB max by default*/, - &pluggable_device_host_mem_limit_in_mb); + absl::Status status = ReadInt64FromEnvVar( + "TF_GPU_HOST_MEM_LIMIT_IN_MB", 1LL << 17 /*128GB max by default*/, + &pluggable_device_host_mem_limit_in_mb); if (!status.ok()) { LOG(ERROR) << "GetPluggableDeviceHostAllocator: " << status.message(); } diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.cc index c8fee4b0ef6ec9..4fdc86fa045b50 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.cc @@ -16,7 +16,8 @@ limitations under the License. #include -#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/common_runtime/device/device_mem_allocator.h" +#include "tensorflow/core/framework/allocator.h" namespace tensorflow { diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.h b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.h index 3a46f766809449..dccb2548868f1e 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.h +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_simple_allocator.h @@ -22,6 +22,7 @@ limitations under the License. #include #include "tensorflow/core/common_runtime/device/device_mem_allocator.h" +#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/config.pb.h" diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.cc b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.cc index ecb08eec856dc9..7e0ccf4cf5176a 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.cc +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.cc @@ -15,26 +15,21 @@ limitations under the License. #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.h" -#include "tensorflow/core/common_runtime/copy_tensor.h" -#include "tensorflow/core/common_runtime/device.h" +#include "absl/status/status.h" +#include "xla/stream_executor/device_memory.h" #include "tensorflow/core/common_runtime/device/device_event_mgr.h" -#include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_context.h" -#include "tensorflow/core/common_runtime/pluggable_device/pluggable_device_process_state.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_reference.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/refcount.h" -#include "tensorflow/core/lib/hash/hash.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/stream_executor.h" -#include "tensorflow/core/platform/tensor_coding.h" -#include "tensorflow/core/util/util.h" +#include "tensorflow/core/platform/status.h" // IMPLEMENTATION NOTE: // @@ -52,10 +47,10 @@ namespace tensorflow { using se::DeviceMemoryBase; -static Status PrepareCopy(Device* device, const DeviceContext* ctx, - const Tensor& src, const Tensor* dst, - const DeviceBase::AcceleratorDeviceInfo** dev_info, - se::Stream** stream) { +static absl::Status PrepareCopy( + Device* device, const DeviceContext* ctx, const Tensor& src, + const Tensor* dst, const DeviceBase::AcceleratorDeviceInfo** dev_info, + se::Stream** stream) { if (device == nullptr) { return errors::Internal("Unexpected null device."); } @@ -113,8 +108,8 @@ void PluggableDeviceUtil::DeviceToDeviceCopy( int dev_to_dev_stream_index, StatusCallback done) { const DeviceBase::AcceleratorDeviceInfo* dev_info = nullptr; se::Stream* send_stream = nullptr; - Status s = PrepareCopy(src, send_dev_context, *input, output, &dev_info, - &send_stream); + absl::Status s = PrepareCopy(src, send_dev_context, *input, output, &dev_info, + &send_stream); if (!s.ok()) { done(s); return; @@ -188,8 +183,8 @@ void PluggableDeviceUtil::CopyPluggableDeviceTensorToCPU( VLOG(1) << "CopyPluggableDeviceTensorToCPU"; const DeviceBase::AcceleratorDeviceInfo* dev_info = nullptr; se::Stream* send_stream = nullptr; - Status s = PrepareCopy(device, device_context, *device_tensor, cpu_tensor, - &dev_info, &send_stream); + absl::Status s = PrepareCopy(device, device_context, *device_tensor, + cpu_tensor, &dev_info, &send_stream); if (!s.ok()) { done(s); return; @@ -244,8 +239,8 @@ void PluggableDeviceUtil::CopyCPUTensorToPluggableDevice( VLOG(1) << "CopyCPUTensorToPluggableDevice"; const DeviceBase::AcceleratorDeviceInfo* dev_info = nullptr; se::Stream* recv_stream = nullptr; - Status s = PrepareCopy(device, device_context, *cpu_tensor, device_tensor, - &dev_info, &recv_stream); + absl::Status s = PrepareCopy(device, device_context, *cpu_tensor, + device_tensor, &dev_info, &recv_stream); if (!s.ok()) { done(s); return; @@ -293,7 +288,7 @@ void PluggableDeviceUtil::CopyCPUTensorToPluggableDevice( }); } -Status PluggableDeviceUtil::Sync(Device* device) { +absl::Status PluggableDeviceUtil::Sync(Device* device) { VLOG(1) << "PluggableDeviceUtil::Sync"; auto* dev_info = device->tensorflow_accelerator_device_info(); if (!dev_info) { @@ -302,7 +297,7 @@ Status PluggableDeviceUtil::Sync(Device* device) { return dev_info->stream->BlockHostUntilDone(); } -Status PluggableDeviceUtil::SyncAll(Device* device) { +absl::Status PluggableDeviceUtil::SyncAll(Device* device) { VLOG(1) << "PluggableDeviceUtil::SyncAll"; auto* dev_info = device->tensorflow_accelerator_device_info(); if (!dev_info) { @@ -323,8 +318,8 @@ void PluggableDeviceUtil::CopyPluggableDeviceTensorToSameDevice( VLOG(1) << "CopyPluggableDeviceTensorToSameDevice"; const DeviceBase::AcceleratorDeviceInfo* dev_info = nullptr; se::Stream* send_stream = nullptr; - Status s = PrepareCopy(device, device_context, *src_device_tensor, - dst_device_tensor, &dev_info, &send_stream); + absl::Status s = PrepareCopy(device, device_context, *src_device_tensor, + dst_device_tensor, &dev_info, &send_stream); if (!s.ok()) { done(s); return; diff --git a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.h b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.h index 8cff5449c853f5..51770e5a6fec9c 100644 --- a/tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.h +++ b/tensorflow/core/common_runtime/pluggable_device/pluggable_device_util.h @@ -18,6 +18,9 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/stream_executor.h" @@ -39,12 +42,12 @@ class PluggableDeviceUtil { // Blocks until all operations queued on the stream associated with // 'device' at the time of the call have completed. Returns any // error pending on the stream at completion. - static Status Sync(Device* device); + static absl::Status Sync(Device* device); // Blocks until all operations queued on all streams associated with the // corresponding 'device' at the time of call have completed. // Returns any error pending on the stream at completion. - static Status SyncAll(Device* device); + static absl::Status SyncAll(Device* device); static void CopyCPUTensorToPluggableDevice( const Tensor* cpu_tensor, const DeviceContext* device_context, diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 52f8c3c8df00b4..3a8e949e2253be 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -106,7 +106,7 @@ void ProcessFunctionLibraryRuntime::FunctionData::DistributedInit( init_started_ = true; } parent->Instantiate(function_name, lib_def, attrs, options, &local_handle_, - [this, done](const Status& s) { + [this, done](const absl::Status& s) { init_done_.Notify(); done(s); }); @@ -147,7 +147,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( } /* static */ -Status ProcessFunctionLibraryRuntime::SendTensors( +absl::Status ProcessFunctionLibraryRuntime::SendTensors( const string& source_device, const string& target_device, const string& key_prefix, int64_t src_incarnation, absl::Span tensors_to_send, DeviceContext* device_context, @@ -184,7 +184,7 @@ void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync( received_tensors, std::move(done)); } -Status ProcessFunctionLibraryRuntime::GetRetTypes( +absl::Status ProcessFunctionLibraryRuntime::GetRetTypes( FunctionLibraryRuntime::Handle h, DataTypeVector* ret_types) { FunctionLibraryRuntime* flr = nullptr; { @@ -205,7 +205,7 @@ Status ProcessFunctionLibraryRuntime::GetRetTypes( return errors::InvalidArgument("Handle ", h, " not found."); } -Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation( +absl::Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation( const string& device_name, int64_t* incarnation) const { FunctionLibraryRuntime* flr = GetFLR(device_name); if (flr == nullptr) { @@ -215,7 +215,7 @@ Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation( return absl::OkStatus(); } -Status ProcessFunctionLibraryRuntime::GetDeviceContext( +absl::Status ProcessFunctionLibraryRuntime::GetDeviceContext( const string& device_name, DeviceContext** device_context) const { *device_context = nullptr; FunctionLibraryRuntime* flr = GetFLR(device_name); @@ -413,7 +413,7 @@ std::vector GetLocalArgs(absl::Span args) { FunctionLibraryRuntime::DoneCallback TensorsToFunctionRetsDoneCallback( std::vector* rets, std::vector* tensors, FunctionLibraryRuntime::DoneCallback done) { - return [rets, tensors, done = std::move(done)](const Status& s) { + return [rets, tensors, done = std::move(done)](const absl::Status& s) { if (s.ok()) { for (const auto& t : *tensors) { rets->push_back(t); @@ -425,8 +425,9 @@ FunctionLibraryRuntime::DoneCallback TensorsToFunctionRetsDoneCallback( } // Push Tensors in `function_rets` into `tensors`. -Status FunctionRetsToTensors(const std::vector* function_rets, - std::vector* tensors) { +absl::Status FunctionRetsToTensors( + const std::vector* function_rets, + std::vector* tensors) { for (const auto& ret : *function_rets) { if (ret.index() != 0) { return errors::Internal( @@ -493,7 +494,7 @@ void ProcessFunctionLibraryRuntime::PublishSubgraphs( stats_publishers_.push_back(std::move(stats_publisher)); } -Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( +absl::Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( const string& function_name, AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options, FunctionLibraryRuntime::Handle* handle) { @@ -637,7 +638,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( &data_lib_def, absl::StrCat(function_name, "_partitioned_", random::New64())); const int num_subgraphs = subgraphs->size(); - gtl::InlinedVector instantiate_status(num_subgraphs); + absl::InlinedVector instantiate_status(num_subgraphs); // Before instantiating component functions, determine synchronous execution. data->enable_sync_execution = false; @@ -660,7 +661,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( &data](const string& target, std::unique_ptr subgraph, ComponentFunctionData* comp_data, - std::function done) { + std::function done) { const string& device_type = dev_set->FindDeviceByName(target)->device_type(); @@ -669,7 +670,8 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( device_type == "XLA_GPU" || options.int_args_and_retvals_on_device); Int32FulltypePass int32_fulltype( "ProcessFunctionLibraryRuntime::InstantiateMultiDevice"); - Status s = int32_fulltype.ProcessGraph(subgraph.get(), ints_on_device); + absl::Status s = + int32_fulltype.ProcessGraph(subgraph.get(), ints_on_device); if (!s.ok()) { done(s); return; @@ -726,7 +728,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( auto* component_handle = new FunctionLibraryRuntime::Handle; auto wrapped_done = [this, comp_data, component_handle, &data, - done = std::move(done)](const Status& s) { + done = std::move(done)](const absl::Status& s) { VLOG(1) << "Finished instantiating component function " << comp_data->name << " with handle " << *component_handle << " status: " << s; if (s.ok()) { @@ -745,8 +747,8 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( FunctionLibraryRuntime* flr = GetFLR(opts.target); if (flr != nullptr) { // Initialize local function synchronously. - Status s = flr->Instantiate(comp_data->name, AttrSlice(&attrs), opts, - component_handle); + absl::Status s = flr->Instantiate(comp_data->name, AttrSlice(&attrs), + opts, component_handle); wrapped_done(s); } else { opts.ret_indices = comp_data->ret_indices; @@ -767,13 +769,13 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( num_subgraphs > GetParallelSubgraphThreshold()) { BlockingCounter counter(static_cast(num_subgraphs)); for (auto& pair : *subgraphs) { - Status* status = &instantiate_status[i]; + absl::Status* status = &instantiate_status[i]; ComponentFunctionData* comp_data = &data->glue_[pair.first]; comp_data->name = name_generator.GetName(); default_thread_pool_->Schedule( [&instantiate_component, &pair, comp_data, &counter, status]() { instantiate_component(pair.first, std::move(pair.second), comp_data, - [&counter, status](Status s) { + [&counter, status](absl::Status s) { status->Update(s); counter.DecrementCount(); }); @@ -784,11 +786,11 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( } else { for (auto& pair : *subgraphs) { Notification n; - Status* status = &instantiate_status[i]; + absl::Status* status = &instantiate_status[i]; ComponentFunctionData* comp_data = &data->glue_[pair.first]; comp_data->name = name_generator.GetName(); instantiate_component(pair.first, std::move(pair.second), comp_data, - [&n, status](Status s) { + [&n, status](absl::Status s) { status->Update(s); n.Notify(); }); @@ -823,7 +825,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice( return absl::OkStatus(); } -Status ProcessFunctionLibraryRuntime::GetOutputDevices( +absl::Status ProcessFunctionLibraryRuntime::GetOutputDevices( FunctionLibraryRuntime::Handle handle, std::vector* output_devices) const { MultiDeviceFunctionData* data = IsMultiDevice(handle); @@ -870,7 +872,7 @@ Status ProcessFunctionLibraryRuntime::GetOutputDevices( return absl::OkStatus(); } -Status ProcessFunctionLibraryRuntime::PrepareRunMultiDevice( +absl::Status ProcessFunctionLibraryRuntime::PrepareRunMultiDevice( const FunctionLibraryRuntime::Options& opts, FunctionLibraryRuntime::Handle handle, const MultiDeviceFunctionData** data) const { @@ -924,14 +926,15 @@ std::vector ProcessFunctionLibraryRuntime::GetOrderedSubgraphs( return subgraph_keys; } -Status ProcessFunctionLibraryRuntime::RunMultiDeviceSync( +absl::Status ProcessFunctionLibraryRuntime::RunMultiDeviceSync( const FunctionLibraryRuntime::Options& opts, FunctionLibraryRuntime::Handle outer_handle, std::vector* rets, - std::function + std::function get_component_args) const { const MultiDeviceFunctionData* data; - Status prepare_status = PrepareRunMultiDevice(opts, outer_handle, &data); + absl::Status prepare_status = + PrepareRunMultiDevice(opts, outer_handle, &data); if (!prepare_status.ok()) { return prepare_status; } @@ -963,7 +966,7 @@ Status ProcessFunctionLibraryRuntime::RunMultiDeviceSync( opts_copy.rets_alloc_attrs = comp_data.ret_alloc_attrs; InternalArgs comp_args; - Status args_status = get_component_args(comp_data, &comp_args); + absl::Status args_status = get_component_args(comp_data, &comp_args); if (!args_status.ok()) { VLOG(2) << "Failed to get component function arguments: " << args_status; return args_status; @@ -982,7 +985,7 @@ Status ProcessFunctionLibraryRuntime::RunMultiDeviceSync( VLOG(4) << " with " << opts_copy.DebugString(); std::vector comp_tensor_rets; - Status run_status = + absl::Status run_status = flr->RunSync(opts_copy, comp_handle, GetLocalArgs(comp_args.args), &comp_tensor_rets); if (!run_status.ok()) { @@ -1005,10 +1008,10 @@ Status ProcessFunctionLibraryRuntime::RunMultiDeviceSync( std::vector> cleanup_items; Notification n; - Status s; + absl::Status s; std::vector comp_rets; RunInternal(opts_copy, comp_handle, comp_args.args, &comp_rets, - &cleanup_items, [&n, &s](const Status& status) { + &cleanup_items, [&n, &s](const absl::Status& status) { s.Update(status); n.Notify(); }); @@ -1024,11 +1027,12 @@ void ProcessFunctionLibraryRuntime::RunMultiDeviceAsync( FunctionLibraryRuntime::Handle outer_handle, std::vector* rets, std::vector>* cleanup_items, FunctionLibraryRuntime::DoneCallback done, - std::function + std::function get_component_args) const { const MultiDeviceFunctionData* data; - Status prepare_status = PrepareRunMultiDevice(opts, outer_handle, &data); + absl::Status prepare_status = + PrepareRunMultiDevice(opts, outer_handle, &data); if (!prepare_status.ok()) { done(prepare_status); return; @@ -1059,7 +1063,7 @@ void ProcessFunctionLibraryRuntime::RunMultiDeviceAsync( opts_copy.cancellation_manager = cm; InternalArgs comp_args; - Status s = get_component_args(comp_data, &comp_args); + absl::Status s = get_component_args(comp_data, &comp_args); if (!s.ok()) { VLOG(2) << "Failed to get component function arguments: " << s; refcounted_done->UpdateStatus(s); @@ -1072,7 +1076,7 @@ void ProcessFunctionLibraryRuntime::RunMultiDeviceAsync( auto component_fn_callback = [comp_rets, rets, comp_data, refcounted_done, cm, local_cm, data, comp_handle, - target](const Status& status) { + target](const absl::Status& status) { if (!status.ok()) { VLOG(2) << "Component function execution on target " << target << " from " << data->function_name_ << " with handle " @@ -1129,7 +1133,7 @@ void ProcessFunctionLibraryRuntime::RunMultiDeviceAsync( refcounted_done->Unref(); } -Status ProcessFunctionLibraryRuntime::Instantiate( +absl::Status ProcessFunctionLibraryRuntime::Instantiate( const string& function_name, AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options, FunctionLibraryRuntime::Handle* handle) { @@ -1143,10 +1147,10 @@ Status ProcessFunctionLibraryRuntime::Instantiate( return flr->Instantiate(function_name, attrs, options, handle); } - Status status; + absl::Status status; Notification notification; InstantiateRemote(function_name, attrs, options, handle, - [&status, ¬ification](const Status& s) { + [&status, ¬ification](const absl::Status& s) { status = s; notification.Notify(); }); @@ -1154,7 +1158,7 @@ Status ProcessFunctionLibraryRuntime::Instantiate( return status; } -Status ProcessFunctionLibraryRuntime::IsCrossProcess( +absl::Status ProcessFunctionLibraryRuntime::IsCrossProcess( FunctionLibraryRuntime::Handle handle, bool* is_cross_process) const { tf_shared_lock l(mu_); const auto& mdevice_it = mdevice_data_.find(handle); @@ -1198,7 +1202,7 @@ void ProcessFunctionLibraryRuntime::InstantiateRemote( f->DistributedInit( parent_, function_name, options.lib_def == nullptr ? *lib_def_ : *options.lib_def, attrs, options, - [this, function_name, target, handle, done](const Status& s) { + [this, function_name, target, handle, done](const absl::Status& s) { VLOG(1) << "ProcessFLR Instantiate [success]: " << function_name << " on: " << target << " with handle: " << *handle << " (this: " << this << ")"; @@ -1206,7 +1210,7 @@ void ProcessFunctionLibraryRuntime::InstantiateRemote( }); } -Status ProcessFunctionLibraryRuntime::RemoveHandle( +absl::Status ProcessFunctionLibraryRuntime::RemoveHandle( FunctionLibraryRuntime::Handle handle) { mutex_lock l(mu_); table_.erase(function_data_[handle]->function_key()); @@ -1214,7 +1218,7 @@ Status ProcessFunctionLibraryRuntime::RemoveHandle( return absl::OkStatus(); } -Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle( +absl::Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle( FunctionLibraryRuntime::Handle handle) { std::unique_ptr mdata; { @@ -1231,7 +1235,7 @@ Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle( // If we are here we are releasing the last instantiation of `handle`. // Release all component function handles. - Status overall_status; + absl::Status overall_status; for (const auto& it : mdata->glue_) { const string& device = it.first; FunctionLibraryRuntime::Handle flr_handle = it.second.handle; @@ -1248,7 +1252,7 @@ Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle( "Failed to find FunctionLibraryRuntime for device ", device, " when releasing multi-device function handle ", handle); } - Status status = flr->ReleaseHandle(flr_handle); + absl::Status status = flr->ReleaseHandle(flr_handle); if (!status.ok()) { overall_status = status; } @@ -1257,7 +1261,7 @@ Status ProcessFunctionLibraryRuntime::ReleaseMultiDeviceHandle( return overall_status; } -Status ProcessFunctionLibraryRuntime::ReleaseHandle( +absl::Status ProcessFunctionLibraryRuntime::ReleaseHandle( FunctionLibraryRuntime::Handle handle) { // Return directly if all function handles has already been released. if (flr_map_ == nullptr) return absl::OkStatus(); @@ -1288,12 +1292,12 @@ ProcessFunctionLibraryRuntime::ApplyCleanUpToDoneCallback( tsl::core::RefCountPtr created_rendezvous) const { return [this, items, done = std::move(done), step_id = opts.step_id, created_rendezvous = - created_rendezvous.release()](const Status& status) { + created_rendezvous.release()](const absl::Status& status) { if (created_rendezvous != nullptr) { created_rendezvous->Unref(); } - auto* local_status = new Status(status); - CleanUp(items, [local_status, done](const Status& cleanup_status) { + auto* local_status = new absl::Status(status); + CleanUp(items, [local_status, done](const absl::Status& cleanup_status) { local_status->Update(cleanup_status); done(*local_status); delete local_status; @@ -1302,7 +1306,7 @@ ProcessFunctionLibraryRuntime::ApplyCleanUpToDoneCallback( }; } -Status ProcessFunctionLibraryRuntime::CreateRendezvous( +absl::Status ProcessFunctionLibraryRuntime::CreateRendezvous( FunctionLibraryRuntime::Options& opts, tsl::core::RefCountPtr* created_rendezvous) const { DCHECK(opts.rendezvous == nullptr); @@ -1312,7 +1316,8 @@ Status ProcessFunctionLibraryRuntime::CreateRendezvous( "ProcessFunctionLibraryRuntime was created without a rendezvous " "factory."); } - Status s = rendezvous_factory_(opts.step_id, device_mgr_, created_rendezvous); + absl::Status s = + rendezvous_factory_(opts.step_id, device_mgr_, created_rendezvous); if (s.ok()) { opts.rendezvous = created_rendezvous->get(); opts.create_rendezvous = false; @@ -1320,7 +1325,7 @@ Status ProcessFunctionLibraryRuntime::CreateRendezvous( return s; } -Status ProcessFunctionLibraryRuntime::GetComponentArgs( +absl::Status ProcessFunctionLibraryRuntime::GetComponentArgs( const absl::Span args, const ProcessFunctionLibraryRuntime::ComponentFunctionData& comp_data, ProcessFunctionLibraryRuntime::InternalArgs* comp_args) { @@ -1352,7 +1357,7 @@ Status ProcessFunctionLibraryRuntime::GetComponentArgs( } #if !defined(IS_MOBILE_PLATFORM) -Status ProcessFunctionLibraryRuntime::GetComponentArgs( +absl::Status ProcessFunctionLibraryRuntime::GetComponentArgs( const FunctionArgsInterface& args, const ProcessFunctionLibraryRuntime::ComponentFunctionData& comp_data, ProcessFunctionLibraryRuntime::InternalArgs* comp_args) { @@ -1382,7 +1387,7 @@ void ProcessFunctionLibraryRuntime::Run( FunctionLibraryRuntime::Options new_opts = opts; tsl::core::RefCountPtr created_rendezvous = nullptr; if (!opts.rendezvous) { - Status s = CreateRendezvous(new_opts, &created_rendezvous); + absl::Status s = CreateRendezvous(new_opts, &created_rendezvous); if (!s.ok()) { done(s); return; @@ -1393,8 +1398,8 @@ void ProcessFunctionLibraryRuntime::Run( done = ApplyCleanUpToDoneCallback(cleanup_items, std::move(done), new_opts, std::move(created_rendezvous)); std::vector* function_rets = new std::vector; - done = [rets, function_rets, done = std::move(done)](const Status& s) { - Status status = s; + done = [rets, function_rets, done = std::move(done)](const absl::Status& s) { + absl::Status status = s; if (status.ok()) { status.Update(FunctionRetsToTensors(function_rets, rets)); } @@ -1404,7 +1409,7 @@ void ProcessFunctionLibraryRuntime::Run( bool multi_device = HasMultiDeviceHandle(handle); if (multi_device) { auto get_component_args = [&args](const ComponentFunctionData& comp_data, - InternalArgs* comp_args) -> Status { + InternalArgs* comp_args) -> absl::Status { return GetComponentArgs(args, comp_data, comp_args); }; return RunMultiDeviceAsync(new_opts, handle, function_rets, cleanup_items, @@ -1453,7 +1458,7 @@ void ProcessFunctionLibraryRuntime::RunInternal( auto rendezvous = opts.rendezvous; string source_device = opts.source_device; DeviceContext* device_context; - Status s = GetDeviceContext(source_device, &device_context); + absl::Status s = GetDeviceContext(source_device, &device_context); if (!s.ok()) { done(s); return; @@ -1482,7 +1487,7 @@ void ProcessFunctionLibraryRuntime::RunInternal( flr->Run(opts, handle, local_args, remote_rets, [source_device, target_device, target_incarnation, rendezvous, device_context, rets_alloc_attrs, remote_rets, rets, - done = std::move(done)](const Status& status) mutable { + done = std::move(done)](const absl::Status& status) mutable { if (!status.ok()) { delete remote_rets; done(status); @@ -1521,7 +1526,7 @@ void ProcessFunctionLibraryRuntime::Run( args.reserve(frame->num_args()); for (size_t i = 0; i < frame->num_args(); ++i) { const Tensor* arg; - Status s = frame->GetArg(i, &arg); + absl::Status s = frame->GetArg(i, &arg); args.emplace_back(*arg); if (!s.ok()) { done(s); @@ -1532,7 +1537,7 @@ void ProcessFunctionLibraryRuntime::Run( Run(opts, handle, args, rets, - [frame, rets, done = std::move(done)](const Status& status) { + [frame, rets, done = std::move(done)](const absl::Status& status) { std::unique_ptr> rets_releaser(rets); if (!status.ok()) { @@ -1549,7 +1554,7 @@ void ProcessFunctionLibraryRuntime::Run( } for (size_t i = 0; i < frame->num_retvals(); ++i) { - Status s = frame->SetRetval(i, (*rets)[i]); + absl::Status s = frame->SetRetval(i, (*rets)[i]); if (!s.ok()) { done(s); return; @@ -1559,7 +1564,7 @@ void ProcessFunctionLibraryRuntime::Run( }); } -Status ProcessFunctionLibraryRuntime::RunSync( +absl::Status ProcessFunctionLibraryRuntime::RunSync( const FunctionLibraryRuntime::Options& orig_opts, FunctionLibraryRuntime::Handle handle, absl::Span args, std::vector* rets) const { @@ -1578,16 +1583,16 @@ Status ProcessFunctionLibraryRuntime::RunSync( return GetComponentArgs(args, comp_data, comp_args); }; - Status status = RunMultiDeviceSync(new_opts, handle, &function_rets, - std::move(get_component_args)); + absl::Status status = RunMultiDeviceSync(new_opts, handle, &function_rets, + std::move(get_component_args)); status.Update(FunctionRetsToTensors(&function_rets, rets)); return status; } else { // TODO(b/207484417): Either handle or avoid/delete this fallback path. metrics::IncrementTestCounter("pflr_runsync", "async"); Notification n; - Status s; - Run(orig_opts, handle, args, rets, [&n, &s](const Status& status) { + absl::Status s; + Run(orig_opts, handle, args, rets, [&n, &s](const absl::Status& status) { s.Update(status); n.Notify(); }); @@ -1596,13 +1601,13 @@ Status ProcessFunctionLibraryRuntime::RunSync( } } -Status ProcessFunctionLibraryRuntime::RunSync( +absl::Status ProcessFunctionLibraryRuntime::RunSync( const FunctionLibraryRuntime::Options& opts, FunctionLibraryRuntime::Handle handle, CallFrameInterface* frame) const { // TODO(b/207485199): Implement this as synchronous code. Notification n; - Status s; - Run(opts, handle, frame, [&n, &s](const Status& status) { + absl::Status s; + Run(opts, handle, frame, [&n, &s](const absl::Status& status) { s.Update(status); n.Notify(); }); @@ -1631,7 +1636,7 @@ void ProcessFunctionLibraryRuntime::Run( FunctionLibraryRuntime::Options new_opts = opts; tsl::core::RefCountPtr created_rendezvous = nullptr; if (!opts.rendezvous) { - Status s = CreateRendezvous(new_opts, &created_rendezvous); + absl::Status s = CreateRendezvous(new_opts, &created_rendezvous); if (!s.ok()) { done(s); return; @@ -1648,7 +1653,7 @@ void ProcessFunctionLibraryRuntime::Run( std::move(created_rendezvous)); auto get_component_args = [&args](const ComponentFunctionData& comp_data, - InternalArgs* comp_args) -> Status { + InternalArgs* comp_args) -> absl::Status { return GetComponentArgs(args, comp_data, comp_args); }; return RunMultiDeviceAsync(new_opts, handle, rets, cleanup_items, @@ -1670,7 +1675,7 @@ void ProcessFunctionLibraryRuntime::CleanUp( refcounted_done->Unref(); } else if (parent_ != nullptr) { parent_->CleanUp(item->step_id, item->local_handle, - [refcounted_done](const Status& status) { + [refcounted_done](const absl::Status& status) { if (!status.ok()) { refcounted_done->UpdateStatus(status); } @@ -1686,7 +1691,7 @@ void ProcessFunctionLibraryRuntime::CleanUp( refcounted_done->Unref(); } -Status ProcessFunctionLibraryRuntime::Clone( +absl::Status ProcessFunctionLibraryRuntime::Clone( Env* env, int graph_def_version, const OptimizerOptions& optimizer_options, std::unique_ptr* out_lib_def, std::unique_ptr* out_pflr, diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index 8c2aff62fca209..0b3b9dc00e6f6e 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -47,14 +47,14 @@ class FunctionArgsInterface { virtual bool HasRemoteOrPackedInputs() const = 0; - virtual Status GetLocalArg(const FunctionArgIndex& index, - Tensor* val) const = 0; + virtual absl::Status GetLocalArg(const FunctionArgIndex& index, + Tensor* val) const = 0; virtual std::vector GetLocalTensors() const = 0; #if !defined(IS_MOBILE_PLATFORM) - virtual Status GetRemoteArg(const FunctionArgIndex& index, - eager::RemoteTensorHandle* val) const { + virtual absl::Status GetRemoteArg(const FunctionArgIndex& index, + eager::RemoteTensorHandle* val) const { return errors::Unimplemented( "Serializing a remote argument is not implemented."); } @@ -92,13 +92,12 @@ class ProcessFunctionLibraryRuntime { // doing the sending. `alloc_attrs` should either be empty or be the size of // `tensors_to_send` and indicates how the input tensors are allocated. Method // takes references on each of the `tensors_to_send`. Method doesn't block. - static Status SendTensors(const string& source_device, - const string& target_device, - const string& key_prefix, int64_t src_incarnation, - absl::Span tensors_to_send, - DeviceContext* device_context, - const std::vector& alloc_attrs, - RendezvousInterface* rendezvous); + static absl::Status SendTensors( + const string& source_device, const string& target_device, + const string& key_prefix, int64_t src_incarnation, + absl::Span tensors_to_send, DeviceContext* device_context, + const std::vector& alloc_attrs, + RendezvousInterface* rendezvous); // Receives `received_tensors` from `target_device` (originally sent from // `source_device`) using `rendezvous`. Uses `key_prefix` to construct the @@ -119,12 +118,12 @@ class ProcessFunctionLibraryRuntime { FunctionLibraryRuntime* GetFLR(const string& device_name) const; // Returns the return types for the function identified by handle `h`. - Status GetRetTypes(FunctionLibraryRuntime::Handle h, - DataTypeVector* ret_types); + absl::Status GetRetTypes(FunctionLibraryRuntime::Handle h, + DataTypeVector* ret_types); // Returns the device incarnation for the given device_name. - Status GetDeviceIncarnation(const string& device_name, - int64_t* incarnation) const; + absl::Status GetDeviceIncarnation(const string& device_name, + int64_t* incarnation) const; // For a given canonicalized key signature of the function instantiated // on device `device_name` and a `local_handle`, creates a handle and returns @@ -154,20 +153,21 @@ class ProcessFunctionLibraryRuntime { // is set to nullptr. If some output is DT_RESOURCE, the corresponding Device* // is set to the device backing the resource. // REQUIRES: `handle` identifies a multi-device function. - Status GetOutputDevices(FunctionLibraryRuntime::Handle handle, - std::vector* output_devices) const; + absl::Status GetOutputDevices(FunctionLibraryRuntime::Handle handle, + std::vector* output_devices) const; // Instantiates the function. See framework/function.h for more details. // Allows for function_name to be instantiated on different devices // as specified in attrs. - Status Instantiate(const string& function_name, AttrSlice attrs, - const FunctionLibraryRuntime::InstantiateOptions& options, - FunctionLibraryRuntime::Handle* handle); + absl::Status Instantiate( + const string& function_name, AttrSlice attrs, + const FunctionLibraryRuntime::InstantiateOptions& options, + FunctionLibraryRuntime::Handle* handle); // Returns whether the function represented by the given handle needs to // execute cross process. - Status IsCrossProcess(FunctionLibraryRuntime::Handle handle, - bool* is_cross_process) const; + absl::Status IsCrossProcess(FunctionLibraryRuntime::Handle handle, + bool* is_cross_process) const; // Delegates to the local FLR that owns state corresponding to `handle` and // tells it to release it. If the `handle` isn't needed at all, the local FLR @@ -176,7 +176,7 @@ class ProcessFunctionLibraryRuntime { // For multi-device functions, calls ReleaseHandle on local FLRs for each // component function that is part of this multi-device function. // Each local FLR might call RemoveHandle on this. - Status ReleaseHandle(FunctionLibraryRuntime::Handle handle); + absl::Status ReleaseHandle(FunctionLibraryRuntime::Handle handle); // Runs the function with given `handle`. Function could have been // instantiated on any device. More details in framework/function.h @@ -193,13 +193,13 @@ class ProcessFunctionLibraryRuntime { const FunctionArgsInterface& args, std::vector* rets, FunctionLibraryRuntime::DoneCallback done) const; - Status RunSync(const FunctionLibraryRuntime::Options& opts, - FunctionLibraryRuntime::Handle handle, - absl::Span args, - std::vector* rets) const; - Status RunSync(const FunctionLibraryRuntime::Options& opts, - FunctionLibraryRuntime::Handle handle, - CallFrameInterface* frame) const; + absl::Status RunSync(const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::Handle handle, + absl::Span args, + std::vector* rets) const; + absl::Status RunSync(const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::Handle handle, + CallFrameInterface* frame) const; const DeviceMgr* device_mgr() { return device_mgr_; } @@ -343,15 +343,15 @@ class ProcessFunctionLibraryRuntime { // For a given device_name, returns a DeviceContext for copying // tensors to/from the device. - Status GetDeviceContext(const string& device_name, - DeviceContext** device_context) const; + absl::Status GetDeviceContext(const string& device_name, + DeviceContext** device_context) const; // Looks up the information for the given `handle` and returns the name // of the device where the function is registered. string GetDeviceName(FunctionLibraryRuntime::Handle handle) const; // Removes handle from the state owned by this object. - Status RemoveHandle(FunctionLibraryRuntime::Handle handle); + absl::Status RemoveHandle(FunctionLibraryRuntime::Handle handle); // Clones ProcessFunctionLibraryRuntime and FunctionLibraryDefinition // (transferring ownership of both to the caller). Note that the @@ -365,15 +365,15 @@ class ProcessFunctionLibraryRuntime { // FunctionLibraryDefinitions for its functions independently (and passes // these into the FunctionLibraryRuntime through an overlay), to avoid linear // runtime w.r.t. to number of functions in the current function library. - Status Clone(Env* env, int graph_def_version, - const OptimizerOptions& optimizer_options, - std::unique_ptr* out_lib_def, - std::unique_ptr* out_pflr, - bool skip_flib_def = false) const; + absl::Status Clone(Env* env, int graph_def_version, + const OptimizerOptions& optimizer_options, + std::unique_ptr* out_lib_def, + std::unique_ptr* out_pflr, + bool skip_flib_def = false) const; - Status ReleaseMultiDeviceHandle(FunctionLibraryRuntime::Handle handle); + absl::Status ReleaseMultiDeviceHandle(FunctionLibraryRuntime::Handle handle); - Status InstantiateMultiDevice( + absl::Status InstantiateMultiDevice( const string& function_name, AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options, FunctionLibraryRuntime::Handle* handle); @@ -397,7 +397,7 @@ class ProcessFunctionLibraryRuntime { std::vector>* cleanup_items, FunctionLibraryRuntime::DoneCallback done) const; - Status CreateRendezvous( + absl::Status CreateRendezvous( FunctionLibraryRuntime::Options& opts, tsl::core::RefCountPtr* created_rendezvous) const; @@ -410,28 +410,29 @@ class ProcessFunctionLibraryRuntime { void CleanUp(std::vector>* items, FunctionLibraryRuntime::DoneCallback done) const; - static Status GetComponentArgs(absl::Span args, - const ComponentFunctionData& comp_data, - InternalArgs* comp_args); + static absl::Status GetComponentArgs(absl::Span args, + const ComponentFunctionData& comp_data, + InternalArgs* comp_args); #if !defined(IS_MOBILE_PLATFORM) - static Status GetComponentArgs(const FunctionArgsInterface& args, - const ComponentFunctionData& comp_data, - InternalArgs* comp_args); + static absl::Status GetComponentArgs(const FunctionArgsInterface& args, + const ComponentFunctionData& comp_data, + InternalArgs* comp_args); #endif // IS_MOBILE_PLATFORM std::vector GetOrderedSubgraphs( const MultiDeviceFunctionData* data) const; - Status PrepareRunMultiDevice(const FunctionLibraryRuntime::Options& opts, - FunctionLibraryRuntime::Handle handle, - const MultiDeviceFunctionData** data) const; + absl::Status PrepareRunMultiDevice( + const FunctionLibraryRuntime::Options& opts, + FunctionLibraryRuntime::Handle handle, + const MultiDeviceFunctionData** data) const; - Status RunMultiDeviceSync( + absl::Status RunMultiDeviceSync( const FunctionLibraryRuntime::Options& opts, FunctionLibraryRuntime::Handle handle, std::vector* rets, - std::function + std::function get_component_args) const; void RunMultiDeviceAsync( @@ -439,8 +440,8 @@ class ProcessFunctionLibraryRuntime { FunctionLibraryRuntime::Handle handle, std::vector* rets, std::vector>* cleanup_items, FunctionLibraryRuntime::DoneCallback done, - std::function + std::function get_component_args) const; void PublishSubgraphs( @@ -487,7 +488,7 @@ class ProcessFunctionLibraryRuntime { const string function_key_; bool is_cross_process_ TF_GUARDED_BY(mu_) = false; bool init_started_ TF_GUARDED_BY(mu_) = false; - Status init_result_ TF_GUARDED_BY(mu_); + absl::Status init_result_ TF_GUARDED_BY(mu_); Notification init_done_; }; diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc index 6f77f22e5f18fa..ca3e4c871e7e53 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc @@ -129,7 +129,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { ->LookupDevice("/job:a/replica:0/task:0/device:CPU:2", &device2_ptr) .code()); // If no GPU is available, gpu_device_ will remain nullptr. - Status status = device_mgr_->LookupDevice( + absl::Status status = device_mgr_->LookupDevice( "/job:a/replica:0/task:0/device:GPU:0", &gpu_device_); if (!status.ok()) { CHECK_EQ(nullptr, gpu_device_); @@ -167,7 +167,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { proc_flr_->AddCompositeDevice(d); } - Status Instantiate( + absl::Status Instantiate( const string& name, test::function::Attrs attrs, const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts, FunctionLibraryRuntime::Handle* handle) { @@ -212,14 +212,15 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { } template - Status RunWithRuntime( + absl::Status RunWithRuntime( const string& name, FunctionLibraryRuntime::Options opts, test::function::Attrs attrs, const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts, const T& args, std::vector rets, ProcessFunctionLibraryRuntime* pflr) { FunctionLibraryRuntime::Handle handle; - Status status = pflr->Instantiate(name, attrs, instantiate_opts, &handle); + absl::Status status = + pflr->Instantiate(name, attrs, instantiate_opts, &handle); if (!status.ok()) { return status; } @@ -234,10 +235,11 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { Notification done; opts.runner = &runner; std::vector out; - pflr->Run(opts, handle, args, &out, [&status, &done](const Status& s) { - status = s; - done.Notify(); - }); + pflr->Run(opts, handle, args, &out, + [&status, &done](const absl::Status& s) { + status = s; + done.Notify(); + }); done.WaitForNotification(); if (!status.ok()) { return status; @@ -254,10 +256,11 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { return status; } Notification done2; - pflr->Run(opts, handle, args, &out, [&status, &done2](const Status& s) { - status = s; - done2.Notify(); - }); + pflr->Run(opts, handle, args, &out, + [&status, &done2](const absl::Status& s) { + status = s; + done2.Notify(); + }); done2.WaitForNotification(); EXPECT_TRUE(errors::IsNotFound(status)) << "Actual status: " << status; EXPECT_TRUE(absl::StrContains(status.message(), "not found.")); @@ -265,16 +268,17 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { return absl::OkStatus(); } - Status Run(const string& name, FunctionLibraryRuntime::Options opts, - test::function::Attrs attrs, - const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts, - const std::vector& args, std::vector rets, - ProcessFunctionLibraryRuntime* pflr = nullptr) { + absl::Status Run( + const string& name, FunctionLibraryRuntime::Options opts, + test::function::Attrs attrs, + const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts, + const std::vector& args, std::vector rets, + ProcessFunctionLibraryRuntime* pflr = nullptr) { return RunWithRuntime, Tensor>( name, opts, attrs, instantiate_opts, args, rets, proc_flr_.get()); } - Status RunWithPackedArgs( + absl::Status RunWithPackedArgs( const string& name, FunctionLibraryRuntime::Options opts, test::function::Attrs attrs, const FunctionLibraryRuntime::InstantiateOptions& instantiate_opts, @@ -284,23 +288,24 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { name, opts, attrs, instantiate_opts, args, rets, proc_flr_.get()); } - Status RunInstantiated(FunctionLibraryRuntime::Handle handle, - FunctionLibraryRuntime::Options opts, - const std::vector& args, - std::vector rets) { + absl::Status RunInstantiated(FunctionLibraryRuntime::Handle handle, + FunctionLibraryRuntime::Options opts, + const std::vector& args, + std::vector rets) { std::function)> runner = [](std::function fn) { test::function::FunctionTestSchedClosure(fn); }; opts.runner = &runner; - Status status; + absl::Status status; Notification done; std::vector out; - proc_flr_->Run(opts, handle, args, &out, [&status, &done](const Status& s) { - status = s; - done.Notify(); - }); + proc_flr_->Run(opts, handle, args, &out, + [&status, &done](const absl::Status& s) { + status = s; + done.Notify(); + }); done.WaitForNotification(); if (!status.ok()) { return status; @@ -396,8 +401,8 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, GetDeviceIncarnation) { &incarnation)); // Incarnation is a random number other than 0. EXPECT_NE(incarnation, 0); - Status s = proc_flr_->GetDeviceIncarnation("/job:a/replica:0/task:0/cpu:2", - &incarnation); + absl::Status s = proc_flr_->GetDeviceIncarnation( + "/job:a/replica:0/task:0/cpu:2", &incarnation); EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); } @@ -622,8 +627,8 @@ void TestTwoDeviceMult( auto x = test::AsTensor({1, 2, 3}); Tensor y_cpu; Tensor y_gpu; - Status status = fixture->Run("TwoDeviceMult", opts, {{"T", DT_FLOAT}}, - inst_opts, {x}, {&y_cpu, &y_gpu}); + absl::Status status = fixture->Run("TwoDeviceMult", opts, {{"T", DT_FLOAT}}, + inst_opts, {x}, {&y_cpu, &y_gpu}); if (!error.empty()) { EXPECT_TRUE(errors::IsInvalidArgument(status)) << "Actual status: " << status; @@ -782,7 +787,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ErrorWhenListInput) { const FunctionDef& def = test::function::FuncWithListInput(); Init({def}); FunctionLibraryRuntime::Handle handle; - Status status = proc_flr_->Instantiate( + absl::Status status = proc_flr_->Instantiate( "FuncWithListInput", test::function::Attrs({{"T", DT_FLOAT}, {"N", 1}}), MakeOptions("CPU:0", {"CPU:0"}, {}), &handle); ASSERT_TRUE(errors::IsInvalidArgument(status)) << "Actual status: " << status; @@ -803,7 +808,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, FullTypeForInt32) { TFT_TENSOR); Init({def}); FunctionLibraryRuntime::Handle handle; - Status status = + absl::Status status = proc_flr_->Instantiate("XTimesTwoInt32", test::function::Attrs({}), MakeOptions("CPU:0", {"CPU:0"}, {}), &handle); ASSERT_TRUE(errors::IsInvalidArgument(status)) << "Actual status: " << status; @@ -819,7 +824,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ErrorWhenListOutput) { const FunctionDef& def = test::function::FuncWithListOutput(); Init({def}); FunctionLibraryRuntime::Handle handle; - Status status = proc_flr_->Instantiate( + absl::Status status = proc_flr_->Instantiate( "FuncWithListOutput", test::function::Attrs({{"T", DT_FLOAT}, {"N", 1}}), MakeOptions("CPU:0", {}, {"CPU:0"}), &handle); ASSERT_TRUE(errors::IsInvalidArgument(status)) << "Actual status: " << status; @@ -925,7 +930,7 @@ FunctionDef AddVarAcrossDevices() { class TestFunctionPackedArgs : public FunctionArgsInterface { public: TestFunctionPackedArgs(const int index, - gtl::InlinedVector&& tensor_args) { + absl::InlinedVector&& tensor_args) { packed_args_.emplace(index, std::move(tensor_args)); } @@ -933,8 +938,8 @@ class TestFunctionPackedArgs : public FunctionArgsInterface { bool HasRemoteOrPackedInputs() const override { return true; }; - Status GetLocalArg(const FunctionArgIndex& index, - Tensor* val) const override { + absl::Status GetLocalArg(const FunctionArgIndex& index, + Tensor* val) const override { *val = *packed_args_.at(index.index).at(index.sub_index).tensor; return absl::OkStatus(); }; @@ -942,7 +947,7 @@ class TestFunctionPackedArgs : public FunctionArgsInterface { std::vector GetLocalTensors() const override { return {}; } private: - absl::flat_hash_map> packed_args_; + absl::flat_hash_map> packed_args_; }; TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_CompositeDevice) { @@ -967,7 +972,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_CompositeDevice) { GetResourceHandle("var", mgr1->default_container(), device1_->name()); // Create a CompositeDevice - Status s; + absl::Status s; std::unique_ptr composite_device = CompositeDevice::MakeDevice({device0_->name(), device1_->name()}, /*unique_device_id=*/0, @@ -985,7 +990,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_CompositeDevice) { // Packed TensorHandle { - gtl::InlinedVector handles; + absl::InlinedVector handles; handles.push_back(TensorValue(&resource_handle0)); handles.push_back(TensorValue(&resource_handle1)); TestFunctionPackedArgs args(0, std::move(handles)); @@ -1025,7 +1030,8 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_ResourceOutput_GPU) { *resource->tensor() = resource_value; resource->is_initialized = true; ResourceMgr* mgr = gpu_device_->resource_manager(); - Status status = mgr->Create(mgr->default_container(), "my_gpu_var", resource); + absl::Status status = + mgr->Create(mgr->default_container(), "my_gpu_var", resource); ASSERT_TRUE(status.ok()) << status.message(); // Run the function taking a resource and outputting it @@ -1068,7 +1074,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_PlacerError) { test::function::ReadResourceVariable()}); FunctionLibraryRuntime::Handle handle; - Status status = proc_flr_->Instantiate( + absl::Status status = proc_flr_->Instantiate( "ResourceOutput", test::function::Attrs({{"T", DT_FLOAT}}), inst_opts, &handle); ASSERT_TRUE(errors::IsInvalidArgument(status)) << "Actual status: " << status; @@ -1119,7 +1125,8 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_CreateKernelsEagerly) { // Instantiating the broken function while creating kernels eagerly should // fail. inst_opts.create_kernels_eagerly = true; - Status status = Instantiate("Broken", {{"T", DT_INT32}}, inst_opts, &handle); + absl::Status status = + Instantiate("Broken", {{"T", DT_INT32}}, inst_opts, &handle); EXPECT_TRUE(errors::IsInternal(status)); } @@ -1155,7 +1162,8 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, MultiDevice_StateHandle) { Var* resource = new Var(T); *resource->tensor() = resource_value; resource->is_initialized = true; - Status status = mgr->Create(mgr->default_container(), "my_gpu_var", resource); + absl::Status status = + mgr->Create(mgr->default_container(), "my_gpu_var", resource); ASSERT_TRUE(status.ok()) << status.message(); Tensor x = GetResourceHandle("my_gpu_var", mgr->default_container(), @@ -1304,7 +1312,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresent) { TEST_F(ProcessFunctionLibraryRuntimeTest, CompositeDevicesAfterCloning) { Init({AddVarAcrossDevices()}); - Status s; + absl::Status s; std::unique_ptr composite_device = CompositeDevice::MakeDevice({device0_->name(), device1_->name()}, /*unique_device_id=*/0, @@ -1339,7 +1347,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SessionMetadataPresentAfterCloning) { instantiate_opts.target = "/job:a/replica:0/task:0/cpu:0"; const auto x = test::AsTensor({17}); Tensor y; - Status s = RunWithRuntime, Tensor>( + absl::Status s = RunWithRuntime, Tensor>( "SessionMetadataReaderFn", opts, {}, instantiate_opts, {x}, {&y}, cloned_proc_flr.get()); TF_CHECK_OK(s); diff --git a/tensorflow/core/common_runtime/process_state.cc b/tensorflow/core/common_runtime/process_state.cc index 31fdee7a4339e5..a91eb74f1ef464 100644 --- a/tensorflow/core/common_runtime/process_state.cc +++ b/tensorflow/core/common_runtime/process_state.cc @@ -77,7 +77,7 @@ Allocator* ProcessState::GetCPUAllocator(int numa_node) { const bool alloc_visitors_defined = (!cpu_alloc_visitors_.empty() || !cpu_free_visitors_.empty()); bool use_bfc_allocator = false; - Status status = ReadBoolFromEnvVar( + absl::Status status = ReadBoolFromEnvVar( "TF_CPU_ALLOCATOR_USE_BFC", alloc_visitors_defined, &use_bfc_allocator); if (!status.ok()) { LOG(ERROR) << "GetCPUAllocator: " << status.message(); @@ -92,9 +92,9 @@ Allocator* ProcessState::GetCPUAllocator(int numa_node) { if (use_bfc_allocator) { // TODO(reedwm): evaluate whether 64GB by default is the best choice. int64_t cpu_mem_limit_in_mb = -1; - Status status = ReadInt64FromEnvVar("TF_CPU_BFC_MEM_LIMIT_IN_MB", - 1LL << 16 /*64GB max by default*/, - &cpu_mem_limit_in_mb); + absl::Status status = ReadInt64FromEnvVar( + "TF_CPU_BFC_MEM_LIMIT_IN_MB", 1LL << 16 /*64GB max by default*/, + &cpu_mem_limit_in_mb); if (!status.ok()) { LOG(ERROR) << "GetCPUAllocator: " << status.message(); } diff --git a/tensorflow/core/common_runtime/profile_handler.h b/tensorflow/core/common_runtime/profile_handler.h index 391dc8c19878b7..1d8856a344a452 100644 --- a/tensorflow/core/common_runtime/profile_handler.h +++ b/tensorflow/core/common_runtime/profile_handler.h @@ -56,7 +56,7 @@ class ProfileHandler { // - final_status: The status that this step finished with. virtual void StepDone(Microseconds start_time, Microseconds finish_time, Microseconds cleanup_time, int total_runops, - Status final_status) = 0; + absl::Status final_status) = 0; // Returns true if the caller should collect rpc activity. virtual bool should_collect_rpcs() = 0; diff --git a/tensorflow/core/common_runtime/propagator_state.h b/tensorflow/core/common_runtime/propagator_state.h index 680cb13ef3ecb4..e5f4fd6bfec0ac 100644 --- a/tensorflow/core/common_runtime/propagator_state.h +++ b/tensorflow/core/common_runtime/propagator_state.h @@ -34,7 +34,7 @@ limitations under the License. namespace tensorflow { -typedef gtl::InlinedVector AllocatorAttributeVec; +typedef absl::InlinedVector AllocatorAttributeVec; // Represents the ephemeral "edge state" associated with one invocation of // `Executor::Run()`. @@ -115,12 +115,12 @@ class PropagatorState { // TODO(b/152925936): Re-evaluate these constants with current usage // patterns. static constexpr int kSpillThreshold = 16384; - gtl::InlinedVector ready_; + absl::InlinedVector ready_; int front_index_; }; // TODO(b/152925936): Re-evaluate this constant with current usage patterns. - typedef gtl::InlinedVector TaggedNodeSeq; + typedef absl::InlinedVector TaggedNodeSeq; private: // The state of an iteration in a particular frame. @@ -283,7 +283,7 @@ class PropagatorState { private: // The active iteration states of this frame. - gtl::InlinedVector iterations; + absl::InlinedVector iterations; IterationState** const iterations_raw TF_GUARDED_BY(mu); IterationState* iterations_first TF_GUARDED_BY(mu); diff --git a/tensorflow/core/common_runtime/quantize_training.cc b/tensorflow/core/common_runtime/quantize_training.cc index 8f225405cf41d3..6117cccaa0cf4c 100644 --- a/tensorflow/core/common_runtime/quantize_training.cc +++ b/tensorflow/core/common_runtime/quantize_training.cc @@ -134,8 +134,8 @@ bool FindType(const Graph* graph, const Node* node, bool* signed_input, } // Find the Save op and inputs. -Status FindSaveOp(const Graph* graph, Node** save_op, - std::vector* in_edges, bool* found) { +absl::Status FindSaveOp(const Graph* graph, Node** save_op, + std::vector* in_edges, bool* found) { *found = false; for (Node* node : graph->op_nodes()) { if (node->type_string() == "SaveV2") { @@ -180,9 +180,9 @@ void FillStringTensor(Tensor* dst, const Tensor& src) { // Add the added_variables as an inputs to the Save op. // We change the inputs of the SaveV2 op to include the names of the added // variables. We also add the variables as inputs to the save op. -Status ConnectVariablesToSaveOp(Graph* graph, Node* save_op, - const std::vector& in_edges, - const std::vector& added_variables) { +absl::Status ConnectVariablesToSaveOp( + Graph* graph, Node* save_op, const std::vector& in_edges, + const std::vector& added_variables) { Node* tensor_names_op = in_edges[1]->src(); Node* shape_and_slices_op = in_edges[2]->src(); @@ -245,9 +245,9 @@ Status ConnectVariablesToSaveOp(Graph* graph, Node* save_op, // Assign----restore_all // | | // RestoreV2 Variable -Status AddRestoreVariableSubgraphs(Graph* graph, Node* save_op, - const std::vector& in_edges, - const std::vector& variables) { +absl::Status AddRestoreVariableSubgraphs( + Graph* graph, Node* save_op, const std::vector& in_edges, + const std::vector& variables) { Node* prefix_op = in_edges[0]->src(); StringPiece name_prefix = GetNodeNamePrefix(save_op); Node* restore_all = FindRestoreAllOp(graph, name_prefix); @@ -312,7 +312,8 @@ Status AddRestoreVariableSubgraphs(Graph* graph, Node* save_op, // Adds new variables to save and restore ops matching the Save and Restore // graphs created in tensorflow/python/training/saver.py. -Status AddSaveAndRestore(Graph* graph, const std::vector& variables) { +absl::Status AddSaveAndRestore(Graph* graph, + const std::vector& variables) { Node* save_op = nullptr; std::vector in_edges; bool found = false; @@ -328,8 +329,8 @@ Status AddSaveAndRestore(Graph* graph, const std::vector& variables) { // Sets output to the Node that computes reduction axes corresponding to all // dimensions of input and return. -Status MakeReductionAxes(Graph* graph, string name_prefix, Node* input, - Node** output) { +absl::Status MakeReductionAxes(Graph* graph, string name_prefix, Node* input, + Node** output) { name_prefix = strings::StrCat(name_prefix, "/ReductionAxes"); Node* start; Tensor zero_tensor(DT_INT32, TensorShape()); @@ -362,10 +363,10 @@ Status MakeReductionAxes(Graph* graph, string name_prefix, Node* input, } // Computes the exponential moving average of input, updated in update_variable. -Status MakeExponentialMovingAverage(Graph* graph, string name_prefix, - const NodeBuilder::NodeOut& input, - Node* decay, Node* update_variable, - Node** assign_value) { +absl::Status MakeExponentialMovingAverage(Graph* graph, string name_prefix, + const NodeBuilder::NodeOut& input, + Node* decay, Node* update_variable, + Node** assign_value) { // variable_t+1 = variable_t - [(variable_t - value) * (1 - decay)] name_prefix = strings::StrCat(name_prefix, "/EMA"); Node* one; @@ -415,10 +416,10 @@ Status MakeExponentialMovingAverage(Graph* graph, string name_prefix, // | EMA init_val // | \ / // +----------- assign -Status MakeInitializedEMAVariable(Graph* graph, const string& name, Node* decay, - Node* init_val, - std::vector* added_variables, - Node** var) { +absl::Status MakeInitializedEMAVariable(Graph* graph, const string& name, + Node* decay, Node* init_val, + std::vector* added_variables, + Node** var) { // TODO(suharshs): Update this to use ResourceVariables when they are ready. TF_RETURN_IF_ERROR( NodeBuilder(strings::StrCat(name, "/Variable"), "VariableV2") @@ -458,9 +459,9 @@ Status MakeInitializedEMAVariable(Graph* graph, const string& name, Node* decay, } // Computes the min and max EMA of input and stores them in min_var and max_var. -Status MakeEMAMinMaxVars(Graph* graph, const string& name_prefix, Node* input, - std::vector* added_variables, Node** min_var, - Node** max_var) { +absl::Status MakeEMAMinMaxVars(Graph* graph, const string& name_prefix, + Node* input, std::vector* added_variables, + Node** min_var, Node** max_var) { // TODO(suharshs): The decay will be constant, so we could make only one for // all quantize_and_dequantize ops to share, this would have to live outside // this function. @@ -497,10 +498,10 @@ Status MakeEMAMinMaxVars(Graph* graph, const string& name_prefix, Node* input, // Makes an input min and max constant if the range is given. Otherwise, makes // min and max variables that are updated by an EMA. -Status MakeInputMinMax(Graph* graph, const string& name_prefix, - const EdgeToConvert& edge, - std::vector* added_variables, Node** input_min, - Node** input_max) { +absl::Status MakeInputMinMax(Graph* graph, const string& name_prefix, + const EdgeToConvert& edge, + std::vector* added_variables, + Node** input_min, Node** input_max) { if (edge.range_given) { // Make constant nodes for the input_min and input_max if the range is // provided. @@ -531,10 +532,11 @@ Status MakeInputMinMax(Graph* graph, const string& name_prefix, // Adds a QuantizeAndDequantizeV2 or FakeQuantizeWithMinMaxVars op // (and required input nodes) based on edge. // The result is stored in convert_node. -Status MakeQuantizeOp(Graph* graph, const string& name_prefix, - const string& quant_op_type, const EdgeToConvert& edge, - std::vector* added_variables, - Node** convert_node) { +absl::Status MakeQuantizeOp(Graph* graph, const string& name_prefix, + const string& quant_op_type, + const EdgeToConvert& edge, + std::vector* added_variables, + Node** convert_node) { Node* input_min; Node* input_max; TF_RETURN_IF_ERROR(MakeInputMinMax(graph, name_prefix, edge, added_variables, @@ -563,8 +565,9 @@ Status MakeQuantizeOp(Graph* graph, const string& name_prefix, } // Insert conversion op, connect it to the graph and remove the old edge. -Status ProcessTargetEdges(Graph* graph, const string& quant_op_type, - const std::vector& target_edges) { +absl::Status ProcessTargetEdges( + Graph* graph, const string& quant_op_type, + const std::vector& target_edges) { // Remember previously converted ops to avoid duplicated conversion on the // same input. std::unordered_map name_index; @@ -593,8 +596,8 @@ Status ProcessTargetEdges(Graph* graph, const string& quant_op_type, } // namespace -Status DoQuantizeTraining(int32_t num_bits, const string& quant_op_type, - Graph* graph) { +absl::Status DoQuantizeTraining(int32_t num_bits, const string& quant_op_type, + Graph* graph) { if (graph == nullptr) { return errors::InvalidArgument("Cannot accept empty graph pointer."); } @@ -658,10 +661,10 @@ Status DoQuantizeTraining(int32_t num_bits, const string& quant_op_type, return absl::OkStatus(); } -Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef, - int32_t num_bits, - const string& quant_op_type, - GraphDef* result_graphdef) { +absl::Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef, + int32_t num_bits, + const string& quant_op_type, + GraphDef* result_graphdef) { Graph graph(OpRegistry::Global()); GraphConstructorOptions opts; TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, input_graphdef, &graph)); @@ -674,10 +677,9 @@ Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef, return absl::OkStatus(); } -Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph_string, - int32_t num_bits, - const string& quant_op_type, - string* result_graph_string) { +absl::Status DoQuantizeTrainingOnSerializedGraphDef( + const string& input_graph_string, int32_t num_bits, + const string& quant_op_type, string* result_graph_string) { // First create the graph from the GraphDef. GraphDef input_graphdef; if (!ParseProtoUnlimited(&input_graphdef, input_graph_string)) { diff --git a/tensorflow/core/common_runtime/quantize_training.h b/tensorflow/core/common_runtime/quantize_training.h index 4013caef412638..de3ed6b476b24a 100644 --- a/tensorflow/core/common_runtime/quantize_training.h +++ b/tensorflow/core/common_runtime/quantize_training.h @@ -35,22 +35,22 @@ namespace tensorflow { // - num_bits out of range. // - g is null. // - More than 1 unknown ops encountered. -Status DoQuantizeTraining(int32_t num_bits, const string& quant_op_type, - Graph* g); +absl::Status DoQuantizeTraining(int32_t num_bits, const string& quant_op_type, + Graph* g); // Converts the input serialized GraphDef and returns a rewritten serialized // GraphDef for quantized training. -Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph, - int32_t num_bits, - const string& quant_op_type, - string* result_graph); +absl::Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph, + int32_t num_bits, + const string& quant_op_type, + string* result_graph); // Converts the input GraphDef and returns a rewritten GraphDef for quantized // training. -Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef, - int32_t num_bits, - const string& quant_op_type, - GraphDef* result_graphdef); +absl::Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef, + int32_t num_bits, + const string& quant_op_type, + GraphDef* result_graphdef); } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/quantize_training_test.cc b/tensorflow/core/common_runtime/quantize_training_test.cc index 4031e7405049be..7f2e1b0e709d35 100644 --- a/tensorflow/core/common_runtime/quantize_training_test.cc +++ b/tensorflow/core/common_runtime/quantize_training_test.cc @@ -51,8 +51,8 @@ class QuantizeTrainingTest : public ::testing::Test { return test::graph::Constant(g_.get(), test::AsTensor(values, shape)); } - Status Placeholder(Graph* g, const string& name, TensorShape shape, - Node** out) { + absl::Status Placeholder(Graph* g, const string& name, TensorShape shape, + Node** out) { TF_RETURN_IF_ERROR(NodeBuilder(name, "Placeholder") .Attr("dtype", DT_FLOAT) .Attr("shape", shape) @@ -60,7 +60,7 @@ class QuantizeTrainingTest : public ::testing::Test { return absl::OkStatus(); } - Status FindNode(Graph* g, const string& name, Node** out) { + absl::Status FindNode(Graph* g, const string& name, Node** out) { for (Node* node : g->nodes()) { if (node->name() == name) { *out = node; @@ -214,8 +214,8 @@ TEST_F(QuantizeTrainingTest, WithBackwardNodes_QuantizeAndDequantize) { // Ensure that the backwards matmul input was not quantized. Node* found_node; - Status s = FindNode(g, strings::StrCat(d->name(), "/QuantizeAndDequantizeV2"), - &found_node); + absl::Status s = FindNode( + g, strings::StrCat(d->name(), "/QuantizeAndDequantizeV2"), &found_node); EXPECT_TRUE(absl::StrContains(s.ToString(), "not found")) << s; // Ensure that m1 and m2's inputs were quantized. @@ -268,8 +268,8 @@ TEST_F(QuantizeTrainingTest, WithBackwardNodes_FakeQuant) { // Ensure that the backwards matmul input was not quantized. Node* found_node; - Status s = FindNode(g, strings::StrCat(d->name(), "/FakeQuantWithMinMaxVars"), - &found_node); + absl::Status s = FindNode( + g, strings::StrCat(d->name(), "/FakeQuantWithMinMaxVars"), &found_node); EXPECT_TRUE(absl::StrContains(s.ToString(), "not found")) << s; // Ensure that m1 and m2's inputs were quantized. diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h index d67bdf228108f4..e4b4b8aedf0d09 100644 --- a/tensorflow/core/common_runtime/renamed_device.h +++ b/tensorflow/core/common_runtime/renamed_device.h @@ -97,16 +97,16 @@ class RenamedDevice : public Device { return underlying_device_->MakeGpuDevice(); } - Status ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device, - DeviceContext* dc, - Allocator* allocator) override { + absl::Status ReinitializeGpuDevice(OpKernelContext* context, + PerOpGpuDevice* device, DeviceContext* dc, + Allocator* allocator) override { return underlying_device_->ReinitializeGpuDevice(context, device, dc, allocator); } - Status MakeTensorFromProto(const TensorProto& tensor_proto, - const AllocatorAttributes alloc_attrs, - Tensor* tensor) override { + absl::Status MakeTensorFromProto(const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor) override { return underlying_device_->MakeTensorFromProto(tensor_proto, alloc_attrs, tensor); } @@ -129,13 +129,13 @@ class RenamedDevice : public Device { underlying_device_->ComputeAsync(op_kernel, context, std::move(done)); } - Status Sync() override { return underlying_device_->Sync(); } + absl::Status Sync() override { return underlying_device_->Sync(); } - Status MaybeRewriteGraph(std::unique_ptr* graph) override { + absl::Status MaybeRewriteGraph(std::unique_ptr* graph) override { return underlying_device_->MaybeRewriteGraph(graph); } - Status TryGetDeviceContext(DeviceContext** out_context) override { + absl::Status TryGetDeviceContext(DeviceContext** out_context) override { return underlying_device_->TryGetDeviceContext(out_context); } diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.cc b/tensorflow/core/common_runtime/rendezvous_mgr.cc index 80002db99d827a..f1a199ba97250d 100644 --- a/tensorflow/core/common_runtime/rendezvous_mgr.cc +++ b/tensorflow/core/common_runtime/rendezvous_mgr.cc @@ -76,7 +76,7 @@ void SameWorkerRecvDone(const DeviceMgr* device_mgr, } Device* src_device; - Status s = device_mgr->LookupDevice(parsed.src_device, &src_device); + absl::Status s = device_mgr->LookupDevice(parsed.src_device, &src_device); if (!s.ok()) { done(s); return; @@ -141,7 +141,7 @@ void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr, local->RecvAsync( parsed, recv_args, [device_mgr, parsed, done = std::move(done)]( - const Status& status, const Rendezvous::Args& send_args, + const absl::Status& status, const Rendezvous::Args& send_args, const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) mutable { // If "in" is an uninitialized tensor, do copy-construction to @@ -150,7 +150,7 @@ void IntraProcessRecvAsyncImpl(const DeviceMgr* device_mgr, Tensor* out = in.IsInitialized() ? new Tensor : new Tensor(in); auto final_callback = [send_args, recv_args, out, is_dead, - done = std::move(done)](const Status& s) { + done = std::move(done)](const absl::Status& s) { done(s, send_args, recv_args, *out, is_dead); delete out; }; @@ -175,10 +175,9 @@ RefCountedIntraProcessRendezvous::~RefCountedIntraProcessRendezvous() { VLOG(5) << "Destructor of IntraProcessRendezvous: " << this; } -Status RefCountedIntraProcessRendezvous::Send(const ParsedKey& key, - const Rendezvous::Args& args, - const Tensor& val, - const bool is_dead) { +absl::Status RefCountedIntraProcessRendezvous::Send( + const ParsedKey& key, const Rendezvous::Args& args, const Tensor& val, + const bool is_dead) { VLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey(); return local_.Send(key, args, val, is_dead); } @@ -190,12 +189,12 @@ void RefCountedIntraProcessRendezvous::RecvAsync(const ParsedKey& key, IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done)); } -void RefCountedIntraProcessRendezvous::StartAbort(const Status& s) { +void RefCountedIntraProcessRendezvous::StartAbort(const absl::Status& s) { VLOG(1) << "IntraProcessRendezvous start Abort " << this; local_.StartAbort(s); } -Status RefCountedIntraProcessRendezvous::GetLocalRendezvousStatus() { +absl::Status RefCountedIntraProcessRendezvous::GetLocalRendezvousStatus() { return local_.status(); } @@ -206,10 +205,10 @@ PrivateIntraProcessRendezvous::PrivateIntraProcessRendezvous( PrivateIntraProcessRendezvous::~PrivateIntraProcessRendezvous() {} -Status PrivateIntraProcessRendezvous::Send(const ParsedKey& key, - const Rendezvous::Args& args, - const Tensor& val, - const bool is_dead) { +absl::Status PrivateIntraProcessRendezvous::Send(const ParsedKey& key, + const Rendezvous::Args& args, + const Tensor& val, + const bool is_dead) { DVLOG(1) << "IntraProcessRendezvous Send " << this << " " << key.FullKey(); return local_.Send(key, args, val, is_dead); } @@ -222,7 +221,7 @@ void PrivateIntraProcessRendezvous::RecvAsync(const ParsedKey& key, IntraProcessRecvAsyncImpl(device_mgr_, &local_, key, args, std::move(done)); } -void PrivateIntraProcessRendezvous::StartAbort(const Status& s) { +void PrivateIntraProcessRendezvous::StartAbort(const absl::Status& s) { local_.StartAbort(s); } diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.h b/tensorflow/core/common_runtime/rendezvous_mgr.h index a6e70fcab91101..23c07b3db37159 100644 --- a/tensorflow/core/common_runtime/rendezvous_mgr.h +++ b/tensorflow/core/common_runtime/rendezvous_mgr.h @@ -49,14 +49,14 @@ class RefCountedIntraProcessRendezvous : public Rendezvous { // no other references to the RefCountedIntraProcessRendezvous object. // If the caller intend to keep a longer life time then it shall keep its own // reference to the RefCountedIntraProcessRendezvous. - Status Send(const ParsedKey& key, const Rendezvous::Args& args, - const Tensor& val, const bool is_dead) override; + absl::Status Send(const ParsedKey& key, const Rendezvous::Args& args, + const Tensor& val, const bool is_dead) override; void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args, DoneCallback done) override; - void StartAbort(const Status& status) override; + void StartAbort(const absl::Status& status) override; // Returns the member LocalRendezvous' status. - Status GetLocalRendezvousStatus(); + absl::Status GetLocalRendezvousStatus(); inline void UpdateDeviceManager(DeviceMgr* device_mgr) { device_mgr_ = device_mgr; @@ -87,11 +87,11 @@ class PrivateIntraProcessRendezvous : public RendezvousInterface { ~PrivateIntraProcessRendezvous() override; // Implementation of RendezvousInterface methods. - Status Send(const ParsedKey& key, const Rendezvous::Args& args, - const Tensor& val, const bool is_dead) override; + absl::Status Send(const ParsedKey& key, const Rendezvous::Args& args, + const Tensor& val, const bool is_dead) override; void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args, DoneCallback done) override; - void StartAbort(const Status& status) override; + void StartAbort(const absl::Status& status) override; private: const DeviceMgr* device_mgr_; diff --git a/tensorflow/core/common_runtime/rendezvous_util.cc b/tensorflow/core/common_runtime/rendezvous_util.cc index 9c6d949d18e16b..532f4e84a2f9f2 100644 --- a/tensorflow/core/common_runtime/rendezvous_util.cc +++ b/tensorflow/core/common_runtime/rendezvous_util.cc @@ -19,7 +19,7 @@ limitations under the License. namespace tensorflow { -Status SendTensorsToRendezvous( +absl::Status SendTensorsToRendezvous( RendezvousInterface* rendezvous, DeviceContext* device_context, const std::vector& alloc_attrs, const std::vector& keys, absl::Span tensors_to_send) { @@ -74,7 +74,7 @@ void RecvOutputsFromRendezvousAsync( arguments; for (int i = 0; i < keys.size(); ++i) { Rendezvous::ParsedKey parsed; - Status s = Rendezvous::ParseKey(keys[i], &parsed); + absl::Status s = Rendezvous::ParseKey(keys[i], &parsed); received_tensors->push_back(Tensor()); if (!s.ok()) { done(s); @@ -99,11 +99,11 @@ void RecvOutputsFromRendezvousAsync( status_cb->Ref(); rendezvous->RecvAsync( parsed, rendez_args, - [val, key, status_cb](const Status& s, + [val, key, status_cb](const absl::Status& s, const Rendezvous::Args& send_args, const Rendezvous::Args& recv_args, const Tensor& v, const bool is_dead) { - Status status = s; + absl::Status status = s; if (status.ok()) { *val = v; if (is_dead) { @@ -118,9 +118,9 @@ void RecvOutputsFromRendezvousAsync( status_cb->Unref(); } -Status RecvOutputsFromRendezvous(RendezvousInterface* rendezvous, - NamedTensors* out, - const Rendezvous::Args& args) { +absl::Status RecvOutputsFromRendezvous(RendezvousInterface* rendezvous, + NamedTensors* out, + const Rendezvous::Args& args) { // Receives values requested by the caller. Rendezvous::ParsedKey parsed; for (auto& p : *out) { diff --git a/tensorflow/core/common_runtime/rendezvous_util.h b/tensorflow/core/common_runtime/rendezvous_util.h index 51175373f4dd51..8ed1dd7a11ad16 100644 --- a/tensorflow/core/common_runtime/rendezvous_util.h +++ b/tensorflow/core/common_runtime/rendezvous_util.h @@ -23,14 +23,14 @@ limitations under the License. namespace tensorflow { typedef std::map NamedTensors; -typedef std::function StatusCallback; +typedef std::function StatusCallback; // Uses `rendezvous` to send tensors in `tensors_to_send`. `device_context` // should be the DeviceContext associated with the source of the tensors. // `alloc_attrs` contains information about how the `tensors_to_send` are // allocated. `alloc_attrs` should either be {} or should match the length of // `keys`. -Status SendTensorsToRendezvous( +absl::Status SendTensorsToRendezvous( RendezvousInterface* rendezvous, DeviceContext* device_context, const std::vector& alloc_attrs, const std::vector& keys, absl::Span tensors_to_send); @@ -45,9 +45,9 @@ void RecvOutputsFromRendezvousAsync( const std::vector& keys, std::vector* received_tensors, StatusCallback done); -Status RecvOutputsFromRendezvous(RendezvousInterface* rendezvous, - NamedTensors* out, - const Rendezvous::Args& args); +absl::Status RecvOutputsFromRendezvous(RendezvousInterface* rendezvous, + NamedTensors* out, + const Rendezvous::Args& args); } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/rendezvous_util_test.cc b/tensorflow/core/common_runtime/rendezvous_util_test.cc index d84625597b06c8..7c1f256dbdc3dc 100644 --- a/tensorflow/core/common_runtime/rendezvous_util_test.cc +++ b/tensorflow/core/common_runtime/rendezvous_util_test.cc @@ -60,7 +60,7 @@ TEST_F(RendezvousUtilTest, SendBeforeRecv) { std::vector received_keys; RecvOutputsFromRendezvousAsync( rendez_, nullptr, {}, {MakeStringKey("hello1"), MakeStringKey("hello2")}, - &received_keys, [&n](const Status& status) { n.Notify(); }); + &received_keys, [&n](const absl::Status& status) { n.Notify(); }); n.WaitForNotification(); EXPECT_EQ(2, received_keys.size()); @@ -74,7 +74,7 @@ TEST_F(RendezvousUtilTest, RecvBeforeSend) { std::vector received_keys; RecvOutputsFromRendezvousAsync( rendez_, nullptr, {}, {MakeStringKey("hello1"), MakeStringKey("hello2")}, - &received_keys, [&n](const Status& status) { n.Notify(); }); + &received_keys, [&n](const absl::Status& status) { n.Notify(); }); TF_ASSERT_OK(SendTensorsToRendezvous( rendez_, nullptr, {}, {MakeStringKey("hello1"), MakeStringKey("hello2")}, @@ -105,7 +105,7 @@ TEST(RendezvousUtilCallerThreadTest, RecvBeforeSend) { std::vector received_keys; RecvOutputsFromRendezvousAsync( rendez_, nullptr, {}, {MakeStringKey("hello1"), MakeStringKey("hello2")}, - &received_keys, [&n, rendez_](const Status& status) { + &received_keys, [&n, rendez_](const absl::Status& status) { rendez_->Unref(); n.Notify(); }); diff --git a/tensorflow/core/common_runtime/replicate_constants_pass.cc b/tensorflow/core/common_runtime/replicate_constants_pass.cc index 11d6f0d53864f4..93e4fd1c8f51f3 100644 --- a/tensorflow/core/common_runtime/replicate_constants_pass.cc +++ b/tensorflow/core/common_runtime/replicate_constants_pass.cc @@ -70,8 +70,8 @@ bool HasCpuDevice(const Node* node) { // Convert the CPU device name to the corresponding CPU device name. If // multiple local CPU devices are enabled, the CPU device name will also // contain the device id. -Status DeviceNameToCpuDeviceNameWithDeviceId(const string& device_name, - string* host_device_name) { +absl::Status DeviceNameToCpuDeviceNameWithDeviceId(const string& device_name, + string* host_device_name) { DeviceNameUtils::ParsedName device; if (!DeviceNameUtils::ParseFullName(device_name, &device)) { return absl::InternalError( @@ -94,7 +94,7 @@ Status DeviceNameToCpuDeviceNameWithDeviceId(const string& device_name, } // Get the CPU device on the same host as dst. -Status GetDestinationCpuDevice(const Node* dst, std::string* device) { +absl::Status GetDestinationCpuDevice(const Node* dst, std::string* device) { if (!dst->has_assigned_device_name()) return absl::AbortedError( absl::StrCat("Node name: ", dst->name(), " has no assigned device.")); @@ -104,7 +104,7 @@ Status GetDestinationCpuDevice(const Node* dst, std::string* device) { // Collect the successor edges of the constant. Group them by the device of the // successor. -Status GetSuccessorEdges( +absl::Status GetSuccessorEdges( Node* node, absl::btree_map>& device_to_edges) { for (const auto& edge : node->out_edges()) { @@ -140,7 +140,7 @@ void ReplicateToEachDevice( } // namespace -Status ReplicateConstantsPass::Run( +absl::Status ReplicateConstantsPass::Run( const GraphOptimizationPassOptions& options) { VLOG(1) << "replicate_constants_pass will replicate constants with " "number-of-elements <= " diff --git a/tensorflow/core/common_runtime/replicate_constants_pass.h b/tensorflow/core/common_runtime/replicate_constants_pass.h index b7b2f0fe98c0d2..b215d301b3ab50 100644 --- a/tensorflow/core/common_runtime/replicate_constants_pass.h +++ b/tensorflow/core/common_runtime/replicate_constants_pass.h @@ -42,7 +42,7 @@ namespace tensorflow { class ReplicateConstantsPass : public GraphOptimizationPass { public: - Status Run(const GraphOptimizationPassOptions& options) override; + absl::Status Run(const GraphOptimizationPassOptions& options) override; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/replicate_per_replica_nodes.cc b/tensorflow/core/common_runtime/replicate_per_replica_nodes.cc index 58bab38fb1d093..3f4cf1498769a0 100644 --- a/tensorflow/core/common_runtime/replicate_per_replica_nodes.cc +++ b/tensorflow/core/common_runtime/replicate_per_replica_nodes.cc @@ -33,7 +33,7 @@ constexpr int kOptimizeCrossHostDataEdgesTheshold = 2; class ReplicateHelper { public: // Initialize replicated nodes with nullptr. - Status InitializeNode(const Node* node, int num_allowed_devices) { + absl::Status InitializeNode(const Node* node, int num_allowed_devices) { if (replicated_nodes_map_.find(node) != replicated_nodes_map_.end()) { return errors::InvalidArgument("Node ", node->name(), " has been replicated."); @@ -44,9 +44,9 @@ class ReplicateHelper { } // Replicate the given node to an allowed device. - Status ReplicateNode(const Node* node, - const std::vector& allowed_devices, - int allowed_device_index, Graph* graph) { + absl::Status ReplicateNode(const Node* node, + const std::vector& allowed_devices, + int allowed_device_index, Graph* graph) { auto& replicated_nodes = replicated_nodes_map_.at(node); if (replicated_nodes[allowed_device_index] != nullptr) { return absl::OkStatus(); @@ -82,7 +82,7 @@ class ReplicateHelper { // Replace an edge (composite device -> composite device) with // N edges (allowed devices -> allowed devices). - Status ReplicateFromCompositeDeviceToCompositeDevice( + absl::Status ReplicateFromCompositeDeviceToCompositeDevice( const Edge* edge, const std::vector& allowed_devices, Graph* graph) { const std::vector& src_replicated_nodes = @@ -114,7 +114,7 @@ class ReplicateHelper { // one edge (one allowed device -> a regular device). // Control edge: replace an edge (composite device -> a regular device) with // N edges (allowed devices -> a regular device). - Status ReplicateFromCompositeDeviceToRegularDevice( + absl::Status ReplicateFromCompositeDeviceToRegularDevice( const Edge* edge, const std::vector& allowed_devices, Graph* graph) { const std::vector& src_replicated_nodes = @@ -197,9 +197,10 @@ class ReplicateHelper { }; // Replicate the nodes in cluster_nodes and update edges. -Status ReplicateNodesAndEdges(const std::vector& allowed_devices, - absl::flat_hash_map* cluster_nodes, - ReplicateHelper* helper, Graph* graph) { +absl::Status ReplicateNodesAndEdges( + const std::vector& allowed_devices, + absl::flat_hash_map* cluster_nodes, ReplicateHelper* helper, + Graph* graph) { // Contains nodes in cluster_nodes whose out nodes are all on physical // devices. std::queue nodes_ready_to_delete; @@ -251,7 +252,7 @@ Status ReplicateNodesAndEdges(const std::vector& allowed_devices, } // namespace -Status ReplicatePerReplicaNodesInFunctionGraph( +absl::Status ReplicatePerReplicaNodesInFunctionGraph( const absl::flat_hash_map*>& composite_devices, Graph* graph) { diff --git a/tensorflow/core/common_runtime/replicate_per_replica_nodes.h b/tensorflow/core/common_runtime/replicate_per_replica_nodes.h index fd696db4905e9a..4be95ea32ca44b 100644 --- a/tensorflow/core/common_runtime/replicate_per_replica_nodes.h +++ b/tensorflow/core/common_runtime/replicate_per_replica_nodes.h @@ -34,7 +34,7 @@ namespace tensorflow { // 3) Clusters assigned to different composite devices should have no data // dependency. // TODO(b/145922293): Register it as a POST_REWRITE_FOR_EXEC pass. -Status ReplicatePerReplicaNodesInFunctionGraph( +absl::Status ReplicatePerReplicaNodesInFunctionGraph( const absl::flat_hash_map*>& composite_devices, Graph* graph); diff --git a/tensorflow/core/common_runtime/replicate_per_replica_nodes_test.cc b/tensorflow/core/common_runtime/replicate_per_replica_nodes_test.cc index 9682ac45ed3dce..ff6fcb4b8bc735 100644 --- a/tensorflow/core/common_runtime/replicate_per_replica_nodes_test.cc +++ b/tensorflow/core/common_runtime/replicate_per_replica_nodes_test.cc @@ -263,7 +263,7 @@ TEST(ReplicatePerReplicaNodesTest, NestedFunctions) { TF_ASSERT_OK(NodeDefBuilder("func", "Func", &flib_def) .Input(arg.name(), 0, DT_RESOURCE) .Finalize(&def)); - Status status; + absl::Status status; Node* func = scope.graph()->AddNode(def, &status); TF_ASSERT_OK(status); scope.graph()->AddEdge(arg.node(), 0, func, 0); diff --git a/tensorflow/core/common_runtime/ring_alg.cc b/tensorflow/core/common_runtime/ring_alg.cc index 617127cadcc3f1..a12acfdf64c9dd 100644 --- a/tensorflow/core/common_runtime/ring_alg.cc +++ b/tensorflow/core/common_runtime/ring_alg.cc @@ -107,7 +107,7 @@ RingAlg::RingAlg(CollectiveType type, const string& name) num_subdivs_(-1) {} namespace { -Status GenerateSubdivsInCollectiveParams(CollectiveParams* col_params) { +absl::Status GenerateSubdivsInCollectiveParams(CollectiveParams* col_params) { // This function generates subdivision_offsets. Expect it to be empty when // called. DCHECK(col_params->instance.impl_details.subdiv_offsets.empty()); @@ -177,7 +177,7 @@ Status GenerateSubdivsInCollectiveParams(CollectiveParams* col_params) { } } // namespace -Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) { +absl::Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) { const string& device_name = col_params->group.members[col_params->default_rank].device.name(); // Each subdiv permutation is a ring formed by rotating each @@ -255,7 +255,7 @@ Status RingAlg::InitializeCollectiveParams(CollectiveParams* col_params) { return absl::OkStatus(); } -Status RingAlg::InitializeCollectiveContext( +absl::Status RingAlg::InitializeCollectiveContext( std::shared_ptr col_ctx) { DCHECK(col_ctx->dev_mgr); col_ctx_ = col_ctx; @@ -270,7 +270,7 @@ string RingAlg::TensorDebugString(const Tensor& tensor) { col_ctx_->op_ctx->device()->tensorflow_accelerator_device_info(); if (accelerator_device_info) { Tensor cpu_tensor(tensor.dtype(), tensor.shape()); - Status st = + absl::Status st = accelerator_device_info->default_context->CopyDeviceTensorToCPUSync( &tensor, "" /*tensor_name*/, col_ctx_->device, &cpu_tensor); DCHECK(st.ok()); @@ -280,7 +280,7 @@ string RingAlg::TensorDebugString(const Tensor& tensor) { } } -void RingAlg::StartAbort(const Status& s) { +void RingAlg::StartAbort(const absl::Status& s) { // In abort mode we stop issuing additional ProvideBuf // and ConsumeBuf calls, but we need to wait for all of the // outstanding callbacks to be invoked before quitting. @@ -312,7 +312,7 @@ void RingAlg::Finish(bool ok) { // Recover the output from the adaptor. ca_->ConsumeFinalValue(col_ctx_->output); } - Status s; + absl::Status s; { mutex_lock l(status_mu_); s = status_; diff --git a/tensorflow/core/common_runtime/ring_alg.h b/tensorflow/core/common_runtime/ring_alg.h index a5d5f2931d81f3..df9072581fa309 100644 --- a/tensorflow/core/common_runtime/ring_alg.h +++ b/tensorflow/core/common_runtime/ring_alg.h @@ -35,17 +35,18 @@ class RingAlg : public CollectiveImplementationInterface { // Establishes the requested number of subdivision permutations based on the // ring order implicit in the device order. - Status InitializeCollectiveParams(CollectiveParams* col_params) override; + absl::Status InitializeCollectiveParams( + CollectiveParams* col_params) override; // Initializes members of CollectiveContext not yet initialized, i.e. device // and device_locality. Also saves the CollectiveContext in this object. - Status InitializeCollectiveContext( + absl::Status InitializeCollectiveContext( std::shared_ptr col_ctx) override; protected: // Called when a bad status is received that implies we should terminate // execution and return a bad status. - void StartAbort(const Status& s); + void StartAbort(const absl::Status& s); void Finish(bool ok); // Current status of a RingField @@ -75,7 +76,7 @@ class RingAlg : public CollectiveImplementationInterface { bool is_final = false; // is the last field in the pass for this rank Tensor chunk; // alias to field values Tensor tmp_chunk; - Status status; + absl::Status status; string DebugString() const; }; virtual void InitRingField(RingField* rf, int chunk_idx, int subdiv_idx, @@ -112,7 +113,7 @@ class RingAlg : public CollectiveImplementationInterface { Notification group_size_tensor_ready_; std::unique_ptr ca_; mutex status_mu_; - Status status_ TF_GUARDED_BY(status_mu_); + absl::Status status_ TF_GUARDED_BY(status_mu_); std::vector rfv_; }; diff --git a/tensorflow/core/common_runtime/ring_gatherer.cc b/tensorflow/core/common_runtime/ring_gatherer.cc index 34de53f21f2513..3c8749a039fb7b 100644 --- a/tensorflow/core/common_runtime/ring_gatherer.cc +++ b/tensorflow/core/common_runtime/ring_gatherer.cc @@ -42,7 +42,8 @@ limitations under the License. #include "tensorflow/core/profiler/lib/traceme.h" namespace tensorflow { -Status RingGatherer::InitializeCollectiveParams(CollectiveParams* col_params) { +absl::Status RingGatherer::InitializeCollectiveParams( + CollectiveParams* col_params) { DCHECK_EQ(col_params->instance.type, GATHER_COLLECTIVE); DCHECK_EQ(col_params->instance.impl_details.collective_name, "RingGather"); // TODO(tucker): Maybe add subdiv support. It's only useful with @@ -102,14 +103,14 @@ void RingGatherer::Run(StatusCallback done) { tsl::profiler::TraceMe activity("MemCpyAsync", tsl::profiler::TraceMeLevel::kInfo); Notification note; - Status status; + absl::Status status; Tensor alias_chunk(ca_->ChunkAlias(col_params_->subdiv_rank[0])); CollectiveRemoteAccessLocal::MemCpyAsync( col_ctx_->op_ctx->op_device_context(), col_ctx_->op_ctx->op_device_context(), col_ctx_->device, col_ctx_->device, col_ctx_->op_ctx->input_alloc_attr(0), col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input, &alias_chunk, - 0 /*dev_to_dev_stream_index*/, [¬e, &status](const Status& s) { + 0 /*dev_to_dev_stream_index*/, [¬e, &status](const absl::Status& s) { status.Update(s); note.Notify(); }); @@ -148,7 +149,7 @@ bool RingGatherer::RunAsyncParts() { tsl::profiler::TraceMe activity("WaitForQueuedEvents", tsl::profiler::TraceMeLevel::kInfo); Notification note; - Status s = gpu_info->default_context->ThenExecute( + absl::Status s = gpu_info->default_context->ThenExecute( col_ctx_->device, gpu_info->stream, [¬e]() { note.Notify(); }); if (s.ok()) { note.WaitForNotification(); @@ -186,7 +187,8 @@ bool RingGatherer::RunAsyncParts() { case RF_INIT: if (rf->do_recv) { rf->action = RF_RECV; - auto requeue = [this, rf, &ready_queue, &aborted](Status s) { + auto requeue = [this, rf, &ready_queue, + &aborted](absl::Status s) { if (!s.ok()) { aborted = true; StartAbort(s); @@ -215,7 +217,7 @@ bool RingGatherer::RunAsyncParts() { if (rf->do_send) { rf->action = RF_SEND; auto send_complete = [this, rf, &ready_queue, - &aborted](Status s) { + &aborted](absl::Status s) { if (!s.ok()) { aborted = true; StartAbort(s); diff --git a/tensorflow/core/common_runtime/ring_gatherer.h b/tensorflow/core/common_runtime/ring_gatherer.h index ee9634834d2b6c..ac894a38a94400 100644 --- a/tensorflow/core/common_runtime/ring_gatherer.h +++ b/tensorflow/core/common_runtime/ring_gatherer.h @@ -33,7 +33,8 @@ class RingGatherer : public RingAlg { RingGatherer() : RingAlg(GATHER_COLLECTIVE, "Gather") {} ~RingGatherer() override {} - Status InitializeCollectiveParams(CollectiveParams* col_params) override; + absl::Status InitializeCollectiveParams( + CollectiveParams* col_params) override; // Begins async execution of the ring gather algorithm. // Must be called in a blockable thread. diff --git a/tensorflow/core/common_runtime/ring_gatherer_test.cc b/tensorflow/core/common_runtime/ring_gatherer_test.cc index e4b402528e8668..595ff502737b93 100644 --- a/tensorflow/core/common_runtime/ring_gatherer_test.cc +++ b/tensorflow/core/common_runtime/ring_gatherer_test.cc @@ -157,7 +157,7 @@ class RingGathererTest : public ::testing::Test { Tensor output_tensor_; Device* device_; core::RefCountPtr col_params_; - Status status_; + absl::Status status_; }; std::unique_ptr test_env_; diff --git a/tensorflow/core/common_runtime/ring_reducer.cc b/tensorflow/core/common_runtime/ring_reducer.cc index a01490ca2eb644..cf8f73c15b2955 100644 --- a/tensorflow/core/common_runtime/ring_reducer.cc +++ b/tensorflow/core/common_runtime/ring_reducer.cc @@ -45,7 +45,8 @@ namespace tensorflow { RingReducer::~RingReducer() { group_size_tensor_ready_.WaitForNotification(); } -Status RingReducer::InitializeCollectiveParams(CollectiveParams* col_params) { +absl::Status RingReducer::InitializeCollectiveParams( + CollectiveParams* col_params) { // TODO(b/113171733): change CHECKs to return errors. CHECK_EQ(col_params->instance.type, REDUCTION_COLLECTIVE); CHECK_EQ(col_params->instance.impl_details.collective_name, "RingReduce"); @@ -92,7 +93,7 @@ void RingReducer::Run(StatusCallback done) { // We are running in a blockable thread and the callback can't block so // just wait here on the copy. Notification note; - Status status; + absl::Status status; tsl::profiler::TraceMe activity("MemCpyAsync", tsl::profiler::TraceMeLevel::kInfo); CollectiveRemoteAccessLocal::MemCpyAsync( @@ -101,7 +102,7 @@ void RingReducer::Run(StatusCallback done) { col_ctx_->device, col_ctx_->op_ctx->input_alloc_attr(0), col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input, col_ctx_->output, 0 /*dev_to_dev_stream_index*/, - [¬e, &status](const Status& s) { + [¬e, &status](const absl::Status& s) { status.Update(s); note.Notify(); }); @@ -144,7 +145,7 @@ void RingReducer::ContinueAfterInputCopy() { DeviceContext* op_dev_ctx = col_ctx_->op_ctx->op_device_context(); op_dev_ctx->CopyCPUTensorToDevice( &group_size_val, col_ctx_->device, &group_size_tensor_, - [this](const Status& s) { + [this](const absl::Status& s) { if (!s.ok()) { StartAbort(s); } @@ -198,7 +199,7 @@ bool RingReducer::RunAsyncParts() { tsl::profiler::TraceMe activity("WaitForQueuedEvents", tsl::profiler::TraceMeLevel::kInfo); Notification note; - Status s = gpu_info->default_context->ThenExecute( + absl::Status s = gpu_info->default_context->ThenExecute( col_ctx_->device, gpu_info->stream, [¬e]() { note.Notify(); }); if (s.ok()) { note.WaitForNotification(); @@ -236,7 +237,8 @@ bool RingReducer::RunAsyncParts() { case RF_INIT: if (rf->do_recv) { rf->action = RF_RECV; - auto requeue = [this, rf, &ready_queue, &aborted](Status s) { + auto requeue = [this, rf, &ready_queue, + &aborted](absl::Status s) { if (!s.ok()) { aborted = true; StartAbort(s); @@ -255,7 +257,7 @@ bool RingReducer::RunAsyncParts() { --recv_pending_count; if (!rf->second_pass) { rf->action = RF_REDUCE; - Status s = collective_util::ComputeBinOp( + absl::Status s = collective_util::ComputeBinOp( col_ctx_->op_ctx, col_ctx_->op_params, col_ctx_->device, col_params_->merge_op, &rf->chunk, &rf->tmp_chunk); if (!s.ok()) { @@ -270,7 +272,7 @@ bool RingReducer::RunAsyncParts() { if (!rf->second_pass && col_params_->final_op && rf->is_final) { rf->action = RF_FINALIZE; group_size_tensor_ready_.WaitForNotification(); - Status s = collective_util::ComputeBinOp( + absl::Status s = collective_util::ComputeBinOp( col_ctx_->op_ctx, col_ctx_->op_params, col_ctx_->device, col_params_->final_op, &rf->chunk, &group_size_tensor_); if (!s.ok()) { @@ -288,7 +290,7 @@ bool RingReducer::RunAsyncParts() { if (rf->do_send) { rf->action = RF_SEND; auto send_complete = [this, rf, &ready_queue, - &aborted](Status s) { + &aborted](absl::Status s) { if (!s.ok()) { aborted = true; StartAbort(s); diff --git a/tensorflow/core/common_runtime/ring_reducer.h b/tensorflow/core/common_runtime/ring_reducer.h index e7b869398bf7fa..77317235e5484f 100644 --- a/tensorflow/core/common_runtime/ring_reducer.h +++ b/tensorflow/core/common_runtime/ring_reducer.h @@ -39,7 +39,8 @@ class RingReducer : public RingAlg { // collective threadpool. void Run(StatusCallback done) override; - Status InitializeCollectiveParams(CollectiveParams* col_params) override; + absl::Status InitializeCollectiveParams( + CollectiveParams* col_params) override; protected: void InitRingField(RingField* rf, int chunk_idx, int subdiv_idx, diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc index 2464ac38670052..d4baa4aaef652e 100644 --- a/tensorflow/core/common_runtime/ring_reducer_test.cc +++ b/tensorflow/core/common_runtime/ring_reducer_test.cc @@ -47,7 +47,7 @@ namespace tensorflow { std::unique_ptr GetKernel(const NodeDef& node, const DeviceType& device_type, DeviceBase* device) { - Status status; + absl::Status status; std::unique_ptr k = CreateOpKernel( device_type, device, device->GetAllocator(AllocatorAttributes()), node, TF_GRAPH_DEF_VERSION, &status); @@ -194,7 +194,7 @@ class RingReducerTest : public ::testing::Test { core::RefCountPtr col_params_; std::unique_ptr merge_op_; std::unique_ptr final_op_; - Status status_; + absl::Status status_; }; std::unique_ptr test_env_; diff --git a/tensorflow/core/common_runtime/scoped_allocator_mgr.cc b/tensorflow/core/common_runtime/scoped_allocator_mgr.cc index 36595a9afe2148..47ddfabbc27efe 100644 --- a/tensorflow/core/common_runtime/scoped_allocator_mgr.cc +++ b/tensorflow/core/common_runtime/scoped_allocator_mgr.cc @@ -19,7 +19,7 @@ limitations under the License. namespace tensorflow { -Status ScopedAllocatorContainer::AddScopedAllocator( +absl::Status ScopedAllocatorContainer::AddScopedAllocator( const Tensor& backing_tensor, int32_t scope_id, const string& scope_name, const absl::Span& fields, int32_t expected_call_count) { @@ -150,7 +150,7 @@ ScopedAllocatorContainer* ScopedAllocatorMgr::GetContainer(int64_t step_id) { return sac; } -Status ScopedAllocatorMgr::AddScopedAllocator( +absl::Status ScopedAllocatorMgr::AddScopedAllocator( const Tensor& backing_tensor, int64_t step_id, int32_t scope_id, const string& scope_name, const absl::Span& fields, diff --git a/tensorflow/core/common_runtime/scoped_allocator_mgr.h b/tensorflow/core/common_runtime/scoped_allocator_mgr.h index 60b954cdcff788..dbbf7c3249ae54 100644 --- a/tensorflow/core/common_runtime/scoped_allocator_mgr.h +++ b/tensorflow/core/common_runtime/scoped_allocator_mgr.h @@ -31,7 +31,7 @@ class ScopedAllocatorMgr; class ScopedAllocatorContainer : public core::RefCounted { public: // Establishes a reachable ScopedAllocator. - Status AddScopedAllocator( + absl::Status AddScopedAllocator( const Tensor& backing_tensor, int32_t scope_id, const std::string& scope_name, const absl::Span& fields, @@ -80,7 +80,7 @@ class ScopedAllocatorMgr { ScopedAllocatorContainer* GetContainer(int64_t step_id); // Establishes a reachable ScopedAllocator. - Status AddScopedAllocator( + absl::Status AddScopedAllocator( const Tensor& backing_tensor, int64_t step_id, int32_t scope_id, const std::string& scope_name, const absl::Span& fields, diff --git a/tensorflow/core/common_runtime/scoped_allocator_mgr_test.cc b/tensorflow/core/common_runtime/scoped_allocator_mgr_test.cc index a359924f05654d..b9bcdb69fd5374 100644 --- a/tensorflow/core/common_runtime/scoped_allocator_mgr_test.cc +++ b/tensorflow/core/common_runtime/scoped_allocator_mgr_test.cc @@ -36,7 +36,7 @@ class ScopedAllocatorMgrTest : public ::testing::Test { &fields_); } - Status AddScopedAllocator(int expected_use_count, int scope_id) { + absl::Status AddScopedAllocator(int expected_use_count, int scope_id) { VLOG(2) << "Adding ScopedAllocator step_id " << step_id_ << " scope_id " << scope_id_ << " #fields " << fields_.size() << " expected_use_count " << expected_use_count; @@ -45,7 +45,7 @@ class ScopedAllocatorMgrTest : public ::testing::Test { expected_use_count); } - Status PrepScopedAllocatorMgr(int expected_use_count) { + absl::Status PrepScopedAllocatorMgr(int expected_use_count) { InitTensor(); PopulateFields(); return AddScopedAllocator(expected_use_count, scope_id_); @@ -123,7 +123,7 @@ TEST_F(ScopedAllocatorMgrTest, PopulateFields) { TEST_F(ScopedAllocatorMgrTest, ContainerAddAllocator) { backing_tensor_shape_ = TensorShape({1024}); fields_shapes_ = std::vector({{512}, {512}}); - Status s = PrepScopedAllocatorMgr(2); + absl::Status s = PrepScopedAllocatorMgr(2); EXPECT_TRUE(s.ok()); // Need to call Allocate and Deallocate in order to use up the expected uses // for this allocator. Save the instances for now. @@ -150,7 +150,7 @@ TEST_F(ScopedAllocatorMgrTest, AllocatorSuccess) { EXPECT_EQ(other, nullptr); backing_tensor_shape_ = TensorShape({512 + 9 + 512 + 16}); fields_shapes_ = std::vector({{512}, {3, 3}, {2, 256}}); - Status s = PrepScopedAllocatorMgr(3); + absl::Status s = PrepScopedAllocatorMgr(3); other = sac->GetAllocator(scope_id_); ScopedAllocatorInstance* inst0 = sac->GetInstance(scope_id_ + 1); @@ -189,7 +189,7 @@ TEST_F(ScopedAllocatorMgrTest, AllocatorInitFail) { backing_tensor_shape_.num_elements() * 2 * sizeof(float); // fields[0].offset + fields[0].bytes_requested is larger than the size of the // backing tensor, so this check should fail - EXPECT_DEATH(Status s = AddScopedAllocator(1, scope_id_), ""); + EXPECT_DEATH(absl::Status s = AddScopedAllocator(1, scope_id_), ""); } // ScopedAllocator allocation should fail because we called more times than @@ -198,7 +198,7 @@ TEST_F(ScopedAllocatorMgrTest, AllocatorInitFail) { TEST_F(ScopedAllocatorMgrTest, AllocatorFail) { backing_tensor_shape_ = TensorShape({1024}); fields_shapes_ = std::vector({{512}, {512}}); - Status s = PrepScopedAllocatorMgr(2); + absl::Status s = PrepScopedAllocatorMgr(2); EXPECT_TRUE(s.ok()); // Save instances so that we can explicitly delete later on. In normal // operation the instances will be automatically deleted after single use, but diff --git a/tensorflow/core/common_runtime/session.cc b/tensorflow/core/common_runtime/session.cc index 6715ab20aad132..a48bfe33895022 100644 --- a/tensorflow/core/common_runtime/session.cc +++ b/tensorflow/core/common_runtime/session.cc @@ -36,27 +36,27 @@ Session::Session() {} Session::~Session() {} -Status Session::Run(const RunOptions& run_options, - const std::vector >& inputs, - const std::vector& output_tensor_names, - const std::vector& target_tensor_names, - std::vector* outputs, RunMetadata* run_metadata) { +absl::Status Session::Run(const RunOptions& run_options, + const std::vector >& inputs, + const std::vector& output_tensor_names, + const std::vector& target_tensor_names, + std::vector* outputs, + RunMetadata* run_metadata) { return errors::Unimplemented( "Run with options is not supported for this session."); } -Status Session::PRunSetup(const std::vector& input_names, - const std::vector& output_names, - const std::vector& target_nodes, - string* handle) { +absl::Status Session::PRunSetup(const std::vector& input_names, + const std::vector& output_names, + const std::vector& target_nodes, + string* handle) { return errors::Unimplemented( "Partial run is not supported for this session."); } -Status Session::PRun(const string& handle, - const std::vector >& inputs, - const std::vector& output_names, - std::vector* outputs) { +absl::Status Session::PRun( + const string& handle, const std::vector >& inputs, + const std::vector& output_names, std::vector* outputs) { return errors::Unimplemented( "Partial run is not supported for this session."); } @@ -67,7 +67,7 @@ Session* NewSession(const SessionOptions& options) { // currently a no-op. SetSessionCreatedMetric(); Session* out_session; - Status s = NewSession(options, &out_session); + absl::Status s = NewSession(options, &out_session); if (!s.ok()) { LOG(ERROR) << "Failed to create session: " << s; return nullptr; @@ -75,9 +75,9 @@ Session* NewSession(const SessionOptions& options) { return out_session; } -Status NewSession(const SessionOptions& options, Session** out_session) { +absl::Status NewSession(const SessionOptions& options, Session** out_session) { SessionFactory* factory; - Status s = SessionFactory::GetFactory(options, &factory); + absl::Status s = SessionFactory::GetFactory(options, &factory); if (!s.ok()) { *out_session = nullptr; LOG(ERROR) << "Failed to get session factory: " << s; @@ -95,8 +95,8 @@ Status NewSession(const SessionOptions& options, Session** out_session) { return s; } -Status Reset(const SessionOptions& options, - const std::vector& containers) { +absl::Status Reset(const SessionOptions& options, + const std::vector& containers) { SessionFactory* factory; TF_RETURN_IF_ERROR(SessionFactory::GetFactory(options, &factory)); return factory->Reset(options, containers); diff --git a/tensorflow/core/common_runtime/session_factory.cc b/tensorflow/core/common_runtime/session_factory.cc index 0ce81f9d9aed40..c21f1dc9483ee2 100644 --- a/tensorflow/core/common_runtime/session_factory.cc +++ b/tensorflow/core/common_runtime/session_factory.cc @@ -65,8 +65,8 @@ string SessionOptionsToString(const SessionOptions& options) { } } // namespace -Status SessionFactory::GetFactory(const SessionOptions& options, - SessionFactory** out_factory) { +absl::Status SessionFactory::GetFactory(const SessionOptions& options, + SessionFactory** out_factory) { mutex_lock l(*get_session_factory_lock()); // could use reader lock std::vector> candidate_factories; diff --git a/tensorflow/core/common_runtime/session_factory.h b/tensorflow/core/common_runtime/session_factory.h index 8976abd1a818d9..ffadb29ae21a6c 100644 --- a/tensorflow/core/common_runtime/session_factory.h +++ b/tensorflow/core/common_runtime/session_factory.h @@ -33,8 +33,8 @@ class SessionFactory { // Creates a new session and stores it in *out_session, or fails with an error // status if the Session could not be created. Caller takes ownership of // *out_session if this returns OkStatus(). - virtual Status NewSession(const SessionOptions& options, - Session** out_session) = 0; + virtual absl::Status NewSession(const SessionOptions& options, + Session** out_session) = 0; virtual bool AcceptsOptions(const SessionOptions& options) = 0; @@ -60,15 +60,15 @@ class SessionFactory { // listed explicitly. // // Sessions that support resource containers should override this function. - virtual Status Reset(const SessionOptions& options, - const std::vector& containers) { + virtual absl::Status Reset(const SessionOptions& options, + const std::vector& containers) { return errors::Unimplemented("Reset()"); } virtual ~SessionFactory() {} static void Register(const string& runtime_type, SessionFactory* factory); - static Status GetFactory(const SessionOptions& options, - SessionFactory** out_factory); + static absl::Status GetFactory(const SessionOptions& options, + SessionFactory** out_factory); }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/session_state.cc b/tensorflow/core/common_runtime/session_state.cc index 7bf14d304c3740..47341276fef563 100644 --- a/tensorflow/core/common_runtime/session_state.cc +++ b/tensorflow/core/common_runtime/session_state.cc @@ -23,7 +23,7 @@ namespace tensorflow { // kTensorHandleResourceTypeName. const char* SessionState::kTensorHandleResourceTypeName = "TensorHandle"; -Status SessionState::GetTensor(const string& handle, Tensor* tensor) { +absl::Status SessionState::GetTensor(const string& handle, Tensor* tensor) { mutex_lock l(state_lock_); auto it = tensors_.find(handle); if (it == tensors_.end()) { @@ -34,7 +34,8 @@ Status SessionState::GetTensor(const string& handle, Tensor* tensor) { return absl::OkStatus(); } -Status SessionState::AddTensor(const string& handle, const Tensor& tensor) { +absl::Status SessionState::AddTensor(const string& handle, + const Tensor& tensor) { mutex_lock l(state_lock_); if (!tensors_.insert({handle, tensor}).second) { return errors::InvalidArgument("Failed to add a tensor with handle '", @@ -43,7 +44,7 @@ Status SessionState::AddTensor(const string& handle, const Tensor& tensor) { return absl::OkStatus(); } -Status SessionState::DeleteTensor(const string& handle) { +absl::Status SessionState::DeleteTensor(const string& handle) { mutex_lock l(state_lock_); if (tensors_.erase(handle) == 0) { return errors::InvalidArgument("Failed to delete a tensor with handle '", @@ -57,7 +58,8 @@ int64_t SessionState::GetNewId() { return tensor_id_++; } -Status TensorStore::AddTensor(const string& name, const TensorAndKey& tk) { +absl::Status TensorStore::AddTensor(const string& name, + const TensorAndKey& tk) { mutex_lock l(lock_); if (!tensors_.insert({name, tk}).second) { return errors::InvalidArgument("Failed to add a tensor with name '", name, @@ -67,8 +69,8 @@ Status TensorStore::AddTensor(const string& name, const TensorAndKey& tk) { return absl::OkStatus(); } -Status TensorStore::SaveTensors(const std::vector& output_names, - SessionState* session_state) { +absl::Status TensorStore::SaveTensors(const std::vector& output_names, + SessionState* session_state) { mutex_lock l(lock_); if (!tensors_.empty()) { // Save only the tensors in output_names in the session. diff --git a/tensorflow/core/common_runtime/session_test.cc b/tensorflow/core/common_runtime/session_test.cc index 7c68f0f001e471..0c36356e65f434 100644 --- a/tensorflow/core/common_runtime/session_test.cc +++ b/tensorflow/core/common_runtime/session_test.cc @@ -30,7 +30,7 @@ TEST(SessionTest, InvalidTargetReturnsNull) { EXPECT_EQ(nullptr, tensorflow::NewSession(options)); Session* session; - Status s = tensorflow::NewSession(options, &session); + absl::Status s = tensorflow::NewSession(options, &session); EXPECT_EQ(s.code(), error::NOT_FOUND); EXPECT_TRUE(absl::StrContains( s.message(), @@ -47,8 +47,8 @@ class FakeSessionFactory : public SessionFactory { return absl::StartsWith(options.target, "fake"); } - Status NewSession(const SessionOptions& options, - Session** out_session) override { + absl::Status NewSession(const SessionOptions& options, + Session** out_session) override { *out_session = nullptr; return absl::OkStatus(); } @@ -67,7 +67,7 @@ TEST(SessionTest, MultipleFactoriesForTarget) { options.target = "fakesession"; Session* session; - Status s = tensorflow::NewSession(options, &session); + absl::Status s = tensorflow::NewSession(options, &session); EXPECT_EQ(s.code(), error::INTERNAL); EXPECT_TRUE(absl::StrContains(s.message(), "Multiple session factories")); EXPECT_TRUE(absl::StrContains(s.message(), "FAKE_SESSION_1")); diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index 06c353eb3669de..6a546835fd5f54 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -69,7 +69,7 @@ constexpr char kRetvalOp[] = "_Retval"; // Runs shape inference for the given node using the given ShapeRefiner. // The node must be a sub-node of a function node and the outer_context is // the inference context of that function node in the outer graph. -Status ShapeRefiner::InferShapesForFunctionSubNode( +absl::Status ShapeRefiner::InferShapesForFunctionSubNode( const Node* node, InferenceContext* outer_context) { TF_RETURN_IF_ERROR(AddNodeInternal(node, outer_context)); InferenceContext* node_context = CHECK_NOTNULL(GetContext(node)); @@ -158,9 +158,9 @@ Status ShapeRefiner::InferShapesForFunctionSubNode( // NOTE: Recursive user-defined functions are not supported. // Maybe we won't support recursive functions at all in TF, because of // other maintainability issues. -Status ShapeRefiner::InferShapesForFunction(const FunctionDef* function_def, - AttrSlice attributes, - InferenceContext* outer_context) { +absl::Status ShapeRefiner::InferShapesForFunction( + const FunctionDef* function_def, AttrSlice attributes, + InferenceContext* outer_context) { const Graph* graph; const string& fname = function_def->signature().name(); auto it = functions_.find(fname); @@ -185,7 +185,7 @@ Status ShapeRefiner::InferShapesForFunction(const FunctionDef* function_def, } absl::flat_hash_set function_nodes; - Status inference_status = absl::OkStatus(); + absl::Status inference_status = absl::OkStatus(); { auto node_shape_inference_lambda = [this, &outer_context, &function_nodes, &inference_status](const Node* node) { @@ -208,11 +208,11 @@ Status ShapeRefiner::InferShapesForFunction(const FunctionDef* function_def, return inference_status; } -Status ShapeRefiner::AddNode(const Node* node) { +absl::Status ShapeRefiner::AddNode(const Node* node) { return AddNodeInternal(node, /*outer_context=*/nullptr); } -Status ShapeRefiner::AddNodeInternal( +absl::Status ShapeRefiner::AddNodeInternal( const Node* node, shape_inference::InferenceContext* outer_context) { // Create the inference context for this node with the existing input shapes. std::unique_ptr ic(new InferenceContext( @@ -272,8 +272,8 @@ Status ShapeRefiner::AddNodeInternal( return absl::OkStatus(); } -Status ShapeRefiner::SetShape(const Node* node, int output_port, - ShapeHandle shape) { +absl::Status ShapeRefiner::SetShape(const Node* node, int output_port, + ShapeHandle shape) { auto c = GetContext(node); if (c == nullptr) { return errors::Internal("Could not find context for ", node->name()); @@ -305,7 +305,8 @@ Status ShapeRefiner::SetShape(const Node* node, int output_port, return absl::OkStatus(); } -Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) { +absl::Status ShapeRefiner::UpdateNode(const Node* node, bool relax, + bool* refined) { auto it = node_to_context_.find(node); if (it == node_to_context_.end()) { *refined = true; @@ -408,7 +409,7 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) { return RunShapeFn(node, op_reg_data, node_context); } -Status ShapeRefiner::EvaluateConstantTensorForEdge( +absl::Status ShapeRefiner::EvaluateConstantTensorForEdge( const Node* node, int dst_idx, bool* evaluated, Tensor* result, InferenceContext* outer_context) { const Edge* input_edge; @@ -462,7 +463,7 @@ Status ShapeRefiner::EvaluateConstantTensorForEdge( return absl::OkStatus(); } -Status ShapeRefiner::EvaluateConstantIntScalarEdge( +absl::Status ShapeRefiner::EvaluateConstantIntScalarEdge( const Node* node, int dst_idx, bool* evaluated, int64_t* result, shape_inference::InferenceContext* outer_context) { Tensor scalar; @@ -488,7 +489,7 @@ Status ShapeRefiner::EvaluateConstantIntScalarEdge( return absl::OkStatus(); } -Status ShapeRefiner::ConstantPartialShape( +absl::Status ShapeRefiner::ConstantPartialShape( InferenceContext* target_context, const Node* node, int dst_idx, ShapeHandle* result, shape_inference::InferenceContext* outer_context) { const Edge* input_edge; @@ -634,7 +635,7 @@ Status ShapeRefiner::ConstantPartialShape( return absl::OkStatus(); } -Status ShapeRefiner::PartialStridedSliceShape( +absl::Status ShapeRefiner::PartialStridedSliceShape( Node* slice_node, InferenceContext* ctx, ShapeHandle* result, shape_inference::InferenceContext* outer_context) { // Only attempt to evaluate if begin/end/strides all are scalars. @@ -707,10 +708,10 @@ Status ShapeRefiner::PartialStridedSliceShape( return absl::OkStatus(); } -Status ShapeRefiner::RunShapeFn(const Node* node, - const OpRegistrationData* op_reg_data, - InferenceContext* c, - InferenceContext* outer_context) { +absl::Status ShapeRefiner::RunShapeFn(const Node* node, + const OpRegistrationData* op_reg_data, + InferenceContext* c, + InferenceContext* outer_context) { // This will be filled in with real data in a second pass. std::vector input_tensors(node->num_inputs(), nullptr); std::vector real_tensors(node->num_inputs()); @@ -744,7 +745,7 @@ Status ShapeRefiner::RunShapeFn(const Node* node, const_tensor_map_.clear(); VLOG(4) << "Running shape inference for function \"" << function.name() << "\"."; - Status function_inference_status = InferShapesForFunction( + absl::Status function_inference_status = InferShapesForFunction( function_def, AttrSlice(&function.attr()), c); const_tensor_map_ = const_tensor_map_copy; VLOG(4) << "Shape inference for function \"" << function.name() diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h index a0b7c45140f726..580dafb0ed6cb0 100644 --- a/tensorflow/core/common_runtime/shape_refiner.h +++ b/tensorflow/core/common_runtime/shape_refiner.h @@ -57,15 +57,15 @@ class ShapeRefiner { // - the shape function for 'node' was not registered. // - 'node' was added before its inputs. // - The shape inference function returns an error. - Status AddNode(const Node* node); + absl::Status AddNode(const Node* node); // Sets 'node's 'output_port' output to have shape 'shape'. // // Returns an error if 'node' was not previously added to this // object, if 'output_port' is invalid, or if 'shape' is // not compatible with the existing shape of the output. - Status SetShape(const Node* node, int output_port, - shape_inference::ShapeHandle shape); + absl::Status SetShape(const Node* node, int output_port, + shape_inference::ShapeHandle shape); // Update the input shapes of node in case the shapes of the fan-ins of 'node' // have themselves been modified (For example, in case of incremental shape @@ -75,7 +75,7 @@ class ShapeRefiner { // changed (in their string representations). Note that shapes may have been // updated to newer versions (but with identical string representations) even // if <*refined> is set to false. - Status UpdateNode(const Node* node, bool relax, bool* refined); + absl::Status UpdateNode(const Node* node, bool relax, bool* refined); // Returns the InferenceContext for 'node', if present. shape_inference::InferenceContext* GetContext(const Node* node) const { @@ -139,14 +139,14 @@ class ShapeRefiner { // // On success: // - outer_context will contain output shapes inferred from input shapes - Status InferShapesForFunction( + absl::Status InferShapesForFunction( const FunctionDef* function_def, AttrSlice attributes, shape_inference::InferenceContext* outer_context); // Performs shape inference for a node inside a function. // // 'outer_context' is the 'InferenceContext' for the function's call op. - Status InferShapesForFunctionSubNode( + absl::Status InferShapesForFunctionSubNode( const Node* node, shape_inference::InferenceContext* outer_context); // Performs validation of 'node' and runs 'node's shape function, @@ -165,8 +165,8 @@ class ShapeRefiner { // - the shape function for 'node' was not registered. // - 'node' was added before its inputs. // - The shape inference function returns an error. - Status AddNodeInternal(const Node* node, - shape_inference::InferenceContext* outer_context); + absl::Status AddNodeInternal( + const Node* node, shape_inference::InferenceContext* outer_context); // Attempts to evaluate the 'dst_idx'-th input to 'node'. If the input edge // value can be evaluated, 'evaluated' is set to true and the value returned @@ -177,7 +177,7 @@ class ShapeRefiner { // otherwise). This gets used to perform constant propagation across Arg nodes // by requesting the constant of value of the incoming tensor from the // 'outer_context'. - Status EvaluateConstantTensorForEdge( + absl::Status EvaluateConstantTensorForEdge( const Node* node, int dst_idx, bool* evaluated, Tensor* result, shape_inference::InferenceContext* outer_context); @@ -190,7 +190,7 @@ class ShapeRefiner { // otherwise). This gets used to perform constant propagation across Arg nodes // by requesting the constant of value of the incoming tensor from the // 'outer_context'. - Status EvaluateConstantIntScalarEdge( + absl::Status EvaluateConstantIntScalarEdge( const Node* node, int dst_idx, bool* evaluated, int64_t* result, shape_inference::InferenceContext* outer_context); @@ -221,10 +221,10 @@ class ShapeRefiner { // otherwise). This gets used to perform constant propagation across Arg nodes // by requesting the constant of value of the incoming tensor from the // 'outer_context'. - Status ConstantPartialShape(shape_inference::InferenceContext* target_context, - const Node* node, int dst_idx, - shape_inference::ShapeHandle* result, - shape_inference::InferenceContext* outer_context); + absl::Status ConstantPartialShape( + shape_inference::InferenceContext* target_context, const Node* node, + int dst_idx, shape_inference::ShapeHandle* result, + shape_inference::InferenceContext* outer_context); // Implementation of ConstantPartialShape for StridedSlice nodes. // @@ -233,7 +233,7 @@ class ShapeRefiner { // otherwise). This gets used to perform constant propagation across Arg nodes // by requesting the constant of value of the incoming tensor from the // 'outer_context'. - Status PartialStridedSliceShape( + absl::Status PartialStridedSliceShape( Node* slice_node, shape_inference::InferenceContext* ctx, shape_inference::ShapeHandle* result, shape_inference::InferenceContext* outer_context); @@ -245,9 +245,10 @@ class ShapeRefiner { // otherwise). This gets used to perform constant propagation across Arg nodes // by requesting the constant of value of the incoming tensor from the // 'outer_context'. - Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data, - shape_inference::InferenceContext* context, - shape_inference::InferenceContext* outer_context = nullptr); + absl::Status RunShapeFn( + const Node* node, const OpRegistrationData* op_reg_data, + shape_inference::InferenceContext* context, + shape_inference::InferenceContext* outer_context = nullptr); int32 graph_def_version_; const OpRegistryInterface* const ops_registry_; diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc index b50973981fdb0e..89105e1b636129 100644 --- a/tensorflow/core/common_runtime/shape_refiner_test.cc +++ b/tensorflow/core/common_runtime/shape_refiner_test.cc @@ -162,7 +162,7 @@ TEST_F(ShapeRefinerTest, BadShapes) { TF_ASSERT_OK(m.AddNode(b.node())); // The shape of the inputs are not compatible, so we should expect // an error. - Status s = m.AddNode(mm.node()); + absl::Status s = m.AddNode(mm.node()); ASSERT_FALSE(s.ok()); ASSERT_TRUE(absl::StrContains(s.message(), "Dimensions must be equal, but are 1 and 2")); @@ -830,14 +830,14 @@ TEST_F(ShapeRefinerTest, ConstantValueVisitNodeTwice) { namespace { -Status TensorAsShapeShapeFn(shape_inference::InferenceContext* c) { +absl::Status TensorAsShapeShapeFn(shape_inference::InferenceContext* c) { shape_inference::ShapeHandle out; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0 /* input_idx */, &out)); c->set_output(0, out); return absl::OkStatus(); } -Status PartialTensorAsShapeShapeFn(shape_inference::InferenceContext* c) { +absl::Status PartialTensorAsShapeShapeFn(shape_inference::InferenceContext* c) { shape_inference::ShapeHandle out; const Tensor* t = c->input_tensor(0); if (t == nullptr || t->NumElements() != 1) { diff --git a/tensorflow/core/common_runtime/simple_propagator_state.h b/tensorflow/core/common_runtime/simple_propagator_state.h index 1a08b8b4d67fb1..9f465ef1e91d3d 100644 --- a/tensorflow/core/common_runtime/simple_propagator_state.h +++ b/tensorflow/core/common_runtime/simple_propagator_state.h @@ -99,12 +99,12 @@ class SimplePropagatorState { // TODO(b/152925936): Re-evaluate these constants with current usage // patterns. static constexpr int kSpillThreshold = 16384; - gtl::InlinedVector ready_; + absl::InlinedVector ready_; int front_index_; }; // TODO(b/152925936): Re-evaluate this constant with current usage patterns. - typedef gtl::InlinedVector TaggedNodeSeq; + typedef absl::InlinedVector TaggedNodeSeq; // Creates and adds a `TaggedNode` for each node in `roots` to `*ready`. void ActivateRoots(gtl::ArraySlice roots, diff --git a/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass.cc b/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass.cc index 3ac8bec6145dc3..1c86e29c9552e7 100644 --- a/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass.cc +++ b/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass.cc @@ -15,20 +15,304 @@ limitations under the License. #include "tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass.h" +#include +#include +#include + +#include "absl/container/flat_hash_map.h" #include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "xla/tsl/util/device_name_utils.h" +#include "tensorflow/core/common_runtime/function_utils.h" #include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/config/flag_defs.h" +#include "tensorflow/core/config/flags.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/platform/bfloat16.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/status.h" +#include "tensorflow/core/util/device_name_utils.h" +#include "tensorflow/core/util/dump_graph.h" +#include "tsl/platform/errors.h" namespace tensorflow { +namespace { + +constexpr absl::string_view kTpuExecute = "TPUExecute"; +constexpr absl::string_view kParallelExecuteIds = "_parallel_execution_ids"; +const char kICIWeightDistributionMlirBridgeMarker[] = + "_ici_weight_distribution_mlir_bridge_marker"; + +// Get the new op name which is used to replace the old op, the new op name +// contains the index of the input and the task id of the TPUExecute node. +std::string GetNewOpName(std::string op_name, int index, int task_id) { + return absl::StrCat(op_name, "_ici_specific_index_", std::to_string(index), + "_task_id_", std::to_string(task_id)); +} + +// Find all the TPUExecute nodes that is not on replica 0. In addition, return +// an empty vector if there is a parallel execute id that is not 0, which +// indicates SPMD case. In the meantime, we check if this is a SPMD case. +std::vector GetNonMainReplicaIciTPUExecuteNodes(Graph* graph, + bool& is_spmd) { + std::vector tpu_nodes; + for (Node* node : graph->nodes()) { + if (node->type_string() == kTpuExecute && + HasNodeAttr(node->def(), kParallelExecuteIds)) { + auto parallel_exec_ids = node->attrs().Find(kParallelExecuteIds)->s(); + std::vector group_vec = + absl::StrSplit(parallel_exec_ids, ','); + if (group_vec.empty()) return tpu_nodes; + std::vector replica_vec = absl::StrSplit(group_vec[0], ':'); + int replica_id = std::stoi(replica_vec[1]); + if (replica_id != 0) tpu_nodes.push_back(node); + if (group_vec.size() > 1) { + std::vector parallel_vec = + absl::StrSplit(group_vec[1], ':'); + int parallel_id = std::stoi(parallel_vec[1]); + if (parallel_id != 0) is_spmd = true; + } + } + } + return tpu_nodes; +} + +// Remove the edge from old_src_node to dst_node, and add the edge from +// new_src_node to dst_node. +void RedirectEdge(Graph* graph, Node* old_src_node, Node* dst_node, + Node* new_src_node, int input_index) { + const Edge* delete_edge; + for (auto edge : dst_node->in_edges()) { + if (edge->src() == old_src_node) { + delete_edge = edge; + break; + } + } + if (delete_edge == nullptr) return; + + graph->RemoveEdge(delete_edge); + graph->AddEdge(new_src_node, 0, dst_node, input_index); +} + +// Find the corresponding host device name from the TPU device name. +string GetHostDeviceName(Node* tpu_node) { + auto device_name = tpu_node->requested_device(); + if (device_name.empty()) device_name = tpu_node->assigned_device_name(); + DeviceNameUtils::ParsedName parsed_device_name; + DeviceNameUtils::ParseFullName(device_name, &parsed_device_name); + string host_device_name = DeviceNameUtils::FullName( + parsed_device_name.job, parsed_device_name.replica, + parsed_device_name.task, /*type=*/"CPU", /*id=*/0); + return host_device_name; +} + +std::optional> GetOutputShapeVec(Node* node) { + auto output_shapes = node->attrs().Find("_output_shapes"); + if (output_shapes == nullptr) return std::nullopt; + auto output_shape = output_shapes->list().shape()[0]; + std::vector output_shape_vec; + output_shape_vec.reserve(output_shape.dim_size()); + for (auto i = 0; i < output_shape.dim_size(); i++) { + output_shape_vec.push_back(output_shape.dim()[i].size()); + } + return output_shape_vec; +} + +int GetTPUTaskId(Node* tpu_node) { + auto device_name = tpu_node->requested_device(); + if (device_name.empty()) device_name = tpu_node->assigned_device_name(); + DeviceNameUtils::ParsedName parsed_device_name; + DeviceNameUtils::ParseFullName(device_name, &parsed_device_name); + return parsed_device_name.task; +} + +// Build the fill op. Its value is 0 and the fill op is put on the host device +// with the same task id as the TPUExecute node. +Node* BuildFillOp(GraphDefBuilder::Options& bopts, Node* tpu_node, + Node* in_node, int input_index, string host_device_name) { + // Find the output_shape vector + auto output_shape_vec = GetOutputShapeVec(in_node); + if (!output_shape_vec.has_value()) return nullptr; + + // Find the element type + auto dtype = in_node->attrs().Find("T")->type(); + + // Get TPU task id. + int tpu_task_id = GetTPUTaskId(tpu_node); + + TensorShape tensor_shape; + tensor_shape.AddDim(output_shape_vec.value().size()); + Tensor const_op_shape_tensor(DT_INT32, tensor_shape); + for (int i = 0; i < output_shape_vec.value().size(); i++) { + const_op_shape_tensor.flat()(i) = output_shape_vec.value()[i]; + } + + // Build dim of fill op + std::string const_1_name = GetNewOpName("const_1", input_index, tpu_task_id); + Node* fill_dim_input = + ops::SourceOp("Const", bopts.WithName(const_1_name) + .WithAttr("dtype", DT_INT32) + .WithAttr("value", const_op_shape_tensor)); + TensorShape fill_dim_output_shape; + fill_dim_output_shape.AddDim(output_shape_vec.value().size()); + fill_dim_input->AddAttr("_output_shapes", + std::vector{fill_dim_output_shape}); + + // Build value of fill op + std::string const_2_name = GetNewOpName("const_2", input_index, tpu_task_id); + auto scalar_tensor = Tensor(dtype, {}); + + if (dtype == DT_FLOAT) { + scalar_tensor.scalar()() = 0; + } else if (dtype == DT_BFLOAT16) { + scalar_tensor.scalar()() = bfloat16(0); + } else { + LOG(ERROR) << "Unsupported data type: ", DataTypeString(dtype); + return nullptr; + } + Node* fill_value_input = + ops::SourceOp("Const", bopts.WithName(const_2_name) + .WithAttr("dtype", dtype) + .WithAttr("value", scalar_tensor)); + TensorShape fill_value_output_shape; + fill_value_input->AddAttr("_output_shapes", + std::vector{fill_value_output_shape}); + + // Build fill op + std::string fill_name = GetNewOpName("fill", input_index, tpu_task_id); + Node* new_fill = + ops::BinaryOp("Fill", fill_dim_input, fill_value_input, + bopts.WithName(fill_name).WithAttr("T", dtype)); + + TensorShape new_output_shape; + for (auto output_shape : output_shape_vec.value()) { + new_output_shape.AddDim(output_shape); + } + new_fill->AddAttr("_output_shapes", + std::vector{new_output_shape}); + new_fill->AddAttr("_xla_inferred_shapes", + std::vector{new_output_shape}); + + // Set the device to each node. + fill_dim_input->set_requested_device(host_device_name); + fill_value_input->set_requested_device(host_device_name); + new_fill->set_requested_device(host_device_name); -Status SimplifyIciDummyVariablesPass::Run( + return new_fill; +} + +// Replace the ici dummy variable with one on the right task id. +absl::Status ReplaceIciDummyVariables(Graph* graph, int input_index, + std::vector tpu_nodes, + GraphDefBuilder::Options& bopts) { + absl::flat_hash_map device_to_node_map; + for (Node* tpu_node : tpu_nodes) { + Node* in_node; + TF_RETURN_IF_ERROR(tpu_node->input_node(input_index, &in_node)); + + if (!in_node->attrs().Find(kICIWeightDistributionMlirBridgeMarker)) { + continue; + } + + string host_device_name = GetHostDeviceName(tpu_node); + + // If the node corresponding to host_device_name is already in the graph, + // replace the edge from in_node to tpu_node with the edge from + // device_to_node_map[host_device_name] to tpu_node. + if (device_to_node_map.contains(host_device_name)) { + RedirectEdge(graph, in_node, tpu_node, + device_to_node_map[host_device_name], input_index); + continue; + } + + Node* new_fill = + BuildFillOp(bopts, tpu_node, in_node, input_index, host_device_name); + if (new_fill == nullptr) continue; + + device_to_node_map[host_device_name] = new_fill; + RedirectEdge(graph, in_node, tpu_node, device_to_node_map[host_device_name], + input_index); + } + return absl::OkStatus(); +} + +} // namespace + +bool ShouldRunPass(const GraphOptimizationPassOptions& options) { + if (!flags::Global().enable_tf2min_ici_weight.value()) { + VLOG(1) << "SimplifyIciDummyVariablesPass is disabled."; + return false; + } + VLOG(1) << "SimplifyIciDummyVariablesPass is enabled."; + + // find all potential nodes. + if (options.graph == nullptr) { + LOG(INFO) << "No graph in simplify_ici_dummy_variables_pass."; + return false; + } + return true; +} + +absl::Status SimplifyIciDummyVariablesPass::Run( const GraphOptimizationPassOptions& options) { + if (!ShouldRunPass(options)) { + return absl::OkStatus(); + } + + Graph* graph = options.graph->get(); + VLOG(1) << DumpGraphToFile("before_simplify_ici_dummy_variables_pass", *graph, + options.flib_def); + + absl::Status status; + GraphDefBuilder::Options bopts(graph, &status); + if (!status.ok()) { + LOG(ERROR) << "GraphDefBuilder::Option failed to initialize."; + return status; + } + + bool is_spmd = false; + + // Find all the qualified tpu_execute nodes which is not on replica 0. + std::vector tpu_nodes = + GetNonMainReplicaIciTPUExecuteNodes(graph, is_spmd); + + if (!is_spmd) { + VLOG(1) << "Not SPMD case, skip SimplifyIciDummyVariablesPass."; + return absl::OkStatus(); + } + + if (tpu_nodes.empty()) { + VLOG(1) << "tpu_nodes is empty, skip SimplifyIciDummyVariablesPass."; + return absl::OkStatus(); + } + + for (int i = 0; i < tpu_nodes[0]->num_inputs(); ++i) { + auto replace_status = ReplaceIciDummyVariables(graph, i, tpu_nodes, bopts); + if (!replace_status.ok()) { + LOG(ERROR) << "Replace ici dummy variables failed."; + return replace_status; + } + } + + // Remove the dead nodes that previously connected to the TPUExecute node. + RemoveDeadNodes(graph); + + VLOG(1) << DumpGraphToFile("after_simplify_ici_dummy_variables_pass", *graph, + options.flib_def); + return absl::OkStatus(); } -// REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 49, -// SimplifyIciDummyVariablesPass); +REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 49, + SimplifyIciDummyVariablesPass); } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass.h b/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass.h index f539793168e28b..553e298fe5ac36 100644 --- a/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass.h +++ b/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass.h @@ -36,7 +36,7 @@ limitations under the License. // node {name: "Identity0", op: "Identity", input: "fill0", // device: "/job:tpu_host_worker/replica:0/task:0/device:CPU:0" // attr { -// key: "ici_weight_distribution_mlir_bridge_marker", value {b: true} +// key: "_ici_weight_distribution_mlir_bridge_marker", value {b: true} // } // } // node {name: "const2", op: "Const"} @@ -45,13 +45,13 @@ limitations under the License. // node {name: "identity1", op: "Identity", input: "fill1" // device: "/job:tpu_host_worker/replica:0/task:0/device:CPU:0" // attr { -// key: "ici_weight_distribution_mlir_bridge_marker", value {b: true} +// key: "_ici_weight_distribution_mlir_bridge_marker", value {b: true} // } // } // node {name: "const4", op: "Const"} // node {name: "split0", op: "Split", input: "const4", input: "identity1" // attr { -// key: "ici_weight_distribution_mlir_bridge_marker" +// key: "_ici_weight_distribution_mlir_bridge_marker" // value {b: true} // } // } @@ -101,7 +101,7 @@ namespace tensorflow { // The dummy variables will be put on the same task as the TPUExecute Op. class SimplifyIciDummyVariablesPass : public GraphOptimizationPass { public: - Status Run(const GraphOptimizationPassOptions& options) override; + absl::Status Run(const GraphOptimizationPassOptions& options) override; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass_test.cc b/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass_test.cc index 6319eb336c7dd7..1836419c22820e 100644 --- a/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass_test.cc +++ b/tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass_test.cc @@ -16,29 +16,120 @@ limitations under the License. #include "tensorflow/core/common_runtime/simplify_ici_dummy_variables_pass.h" #include +#include +#include "absl/status/status.h" #include "tensorflow/cc/framework/scope.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_def_builder_util.h" #include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/config/flag_defs.h" +#include "tensorflow/core/config/flags.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/resource_loader.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" #include "tsl/platform/test.h" namespace tensorflow { -TEST(SimplifyIciDummyVariablesPassTest, SimplifyIciDummyVariables) { +// Return the node with the specified name +Node* GetNode(const Graph& graph, const std::string& name) { + for (Node* node : graph.nodes()) { + if (node->name() == name) return node; + } + return nullptr; +} + +std::string TestDataPath() { + return tensorflow::GetDataDependencyFilepath( + "tensorflow/core/common_runtime/testdata/"); +} + +// Test the case enable_tf2min_ici_weight is false. +TEST(SimplifyIciDummyVariablesPassTest, flag_is_false) { + flags::Global().enable_tf2min_ici_weight.reset(false); auto graph = std::make_unique(OpRegistry::Global()); - GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); - TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); - GraphDef before; - graph->ToGraphDef(&before); + std::string graph_path = + TestDataPath() + "simplify_ici_dummy_variables_pass_before.pbtxt"; + tensorflow::GraphDef graph_def; + absl::Status load_graph_status = + ReadTextProto(tensorflow::Env::Default(), graph_path, &graph_def); + EXPECT_EQ(load_graph_status.ok(), true); + TF_EXPECT_OK(ConvertGraphDefToGraph(GraphConstructorOptions(), graph_def, + graph.get())); + GraphOptimizationPassOptions options; options.graph = &graph; SimplifyIciDummyVariablesPass pass; TF_ASSERT_OK(pass.Run(options)); + + Node* fill_1_dim = GetNode(*graph, "const_1_ici_specific_index_0_task_id_2"); + Node* fill_1_value = + GetNode(*graph, "const_2_ici_specific_index_0_task_id_2"); + Node* fill_1 = GetNode(*graph, "fill_ici_specific_index_0_task_id_2"); + EXPECT_EQ(fill_1_dim, nullptr); + EXPECT_EQ(fill_1_value, nullptr); + EXPECT_EQ(fill_1, nullptr); + + Node* fill_2_dim = GetNode(*graph, "const_1_ici_specific_index_1_task_id_2"); + Node* fill_2_value = + GetNode(*graph, "const_2_ici_specific_index_1_task_id_2"); + Node* fill_2 = GetNode(*graph, "fill_ici_specific_index_1_task_id_2"); + EXPECT_EQ(fill_2_dim, nullptr); + EXPECT_EQ(fill_2_value, nullptr); + EXPECT_EQ(fill_2, nullptr); +} + +// Test the case enable_tf2min_ici_weight is true, graph after pass will have +// dummy variables on task 2. +TEST(SimplifyIciDummyVariablesPassTest, replace_dummy_variable) { + flags::Global().enable_tf2min_ici_weight.reset(true); + auto graph = std::make_unique(OpRegistry::Global()); + std::string graph_path = + TestDataPath() + "simplify_ici_dummy_variables_pass_before.pbtxt"; + tensorflow::GraphDef graph_def; + absl::Status load_graph_status = + ReadTextProto(tensorflow::Env::Default(), graph_path, &graph_def); + EXPECT_EQ(load_graph_status.ok(), true); + TF_EXPECT_OK(ConvertGraphDefToGraph(GraphConstructorOptions(), graph_def, + graph.get())); + + GraphOptimizationPassOptions options; + options.graph = &graph; + SimplifyIciDummyVariablesPass pass; + TF_ASSERT_OK(pass.Run(options)); + + Node* fill_1_dim = GetNode(*graph, "const_1_ici_specific_index_0_task_id_2"); + Node* fill_1_value = + GetNode(*graph, "const_2_ici_specific_index_0_task_id_2"); + Node* fill_1 = GetNode(*graph, "fill_ici_specific_index_0_task_id_2"); + EXPECT_NE(fill_1_dim, nullptr); + EXPECT_NE(fill_1_value, nullptr); + EXPECT_NE(fill_1, nullptr); + EXPECT_EQ(fill_1_dim->requested_device(), + "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"); + EXPECT_EQ(fill_1_value->requested_device(), + "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"); + EXPECT_EQ(fill_1->requested_device(), + "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"); + + Node* fill_2_dim = GetNode(*graph, "const_1_ici_specific_index_1_task_id_2"); + Node* fill_2_value = + GetNode(*graph, "const_2_ici_specific_index_1_task_id_2"); + Node* fill_2 = GetNode(*graph, "fill_ici_specific_index_1_task_id_2"); + EXPECT_NE(fill_2_dim, nullptr); + EXPECT_NE(fill_2_value, nullptr); + EXPECT_NE(fill_2, nullptr); + EXPECT_EQ(fill_2_dim->requested_device(), + "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"); + EXPECT_EQ(fill_2_value->requested_device(), + "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"); + EXPECT_EQ(fill_2->requested_device(), + "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/single_threaded_cpu_device.cc b/tensorflow/core/common_runtime/single_threaded_cpu_device.cc index ababa7d14aec33..84615bec3f7ddf 100644 --- a/tensorflow/core/common_runtime/single_threaded_cpu_device.cc +++ b/tensorflow/core/common_runtime/single_threaded_cpu_device.cc @@ -56,11 +56,11 @@ class SingleThreadedCpuDevice : public Device { ~SingleThreadedCpuDevice() override { eigen_device_.reset(); } - Status Sync() override { return absl::OkStatus(); } + absl::Status Sync() override { return absl::OkStatus(); } - Status MakeTensorFromProto(const TensorProto& tensor_proto, - const AllocatorAttributes alloc_attrs, - Tensor* tensor) override { + absl::Status MakeTensorFromProto(const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor) override { Tensor parsed(tensor_proto.dtype()); if (!parsed.FromProto(cpu_allocator(), tensor_proto)) { return errors::InvalidArgument("Cannot parse tensor from tensor_proto."); diff --git a/tensorflow/core/common_runtime/single_threaded_executor.cc b/tensorflow/core/common_runtime/single_threaded_executor.cc index 19b6a831382151..a7c30baec739ad 100644 --- a/tensorflow/core/common_runtime/single_threaded_executor.cc +++ b/tensorflow/core/common_runtime/single_threaded_executor.cc @@ -30,7 +30,7 @@ limitations under the License. namespace tensorflow { -Status ValidateOpIsSafeForSyncExecution( +absl::Status ValidateOpIsSafeForSyncExecution( const Node& n, bool allow_control_flow_sync_execution) { for (DataType dt : n.output_types()) { if (IsRefType(dt)) { @@ -62,8 +62,8 @@ Status ValidateOpIsSafeForSyncExecution( namespace { -typedef gtl::InlinedVector TensorValueVec; -typedef gtl::InlinedVector AllocatorAttributeVec; +typedef absl::InlinedVector TensorValueVec; +typedef absl::InlinedVector AllocatorAttributeVec; static const string& kSingleThreadedExecutor = *new string("SINGLE_THREADED_EXECUTOR"); @@ -82,7 +82,7 @@ class SingleThreadedExecutorImpl : public Executor { } } - Status Initialize(const Graph& graph) { + absl::Status Initialize(const Graph& graph) { // Topologicially sort `graph` to get a sequence of OpKernels. std::vector ordered_nodes; ordered_nodes.reserve(graph.num_nodes()); @@ -254,7 +254,7 @@ class SingleThreadedExecutorImpl : public Executor { return absl::OkStatus(); } - Status Run(const Args& args) override { + absl::Status Run(const Args& args) override { // The inputs to each kernel are stored contiguously in `inputs`. // // We use `kernels_[i].input_start_index` and `kernels_[i].num_inputs` to @@ -578,8 +578,9 @@ class SingleThreadedExecutorRegistrar { private: class Factory : public ExecutorFactory { - Status NewExecutor(const LocalExecutorParams& params, const Graph& graph, - std::unique_ptr* out_executor) override { + absl::Status NewExecutor(const LocalExecutorParams& params, + const Graph& graph, + std::unique_ptr* out_executor) override { Executor* ret; TF_RETURN_IF_ERROR(NewSingleThreadedExecutor(params, graph, &ret)); out_executor->reset(ret); @@ -591,8 +592,9 @@ static SingleThreadedExecutorRegistrar registrar; } // namespace -Status NewSingleThreadedExecutor(const LocalExecutorParams& params, - const Graph& graph, Executor** executor) { +absl::Status NewSingleThreadedExecutor(const LocalExecutorParams& params, + const Graph& graph, + Executor** executor) { auto impl = std::make_unique(params); TF_RETURN_IF_ERROR(impl->Initialize(graph)); *executor = impl.release(); diff --git a/tensorflow/core/common_runtime/single_threaded_executor.h b/tensorflow/core/common_runtime/single_threaded_executor.h index cd332b19124f0b..55749ed6f917f9 100644 --- a/tensorflow/core/common_runtime/single_threaded_executor.h +++ b/tensorflow/core/common_runtime/single_threaded_executor.h @@ -51,15 +51,15 @@ namespace tensorflow { // // The single-threaded executor is primarily suitable for executing simple // TensorFlow functions, such as one might find in a `tf.data` pipeline. -Status NewSingleThreadedExecutor(const LocalExecutorParams& params, - const Graph& graph, Executor** executor); +absl::Status NewSingleThreadedExecutor(const LocalExecutorParams& params, + const Graph& graph, Executor** executor); // Returns OkStatus() for ops which are compatible with synchronous execution, // and otherwise returns an error message appropriate for propagation if needed. // If `allow_control_flow_sync_execution` is set to `true` control // nodes are marked as safe for execution on the SingleThreadedExecutor. -Status ValidateOpIsSafeForSyncExecution(const Node& n, - bool allow_control_flow_sync_execution); +absl::Status ValidateOpIsSafeForSyncExecution( + const Node& n, bool allow_control_flow_sync_execution); } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/single_threaded_executor_test.cc b/tensorflow/core/common_runtime/single_threaded_executor_test.cc index a53e65d7a7a513..334ada5ad0a389 100644 --- a/tensorflow/core/common_runtime/single_threaded_executor_test.cc +++ b/tensorflow/core/common_runtime/single_threaded_executor_test.cc @@ -109,14 +109,14 @@ class ExecutorTest : public ::testing::Test { rendez_ = NewLocalRendezvous(); } - Status Run(Rendezvous* rendez) { + absl::Status Run(Rendezvous* rendez) { Executor::Args args; args.rendezvous = rendez; args.runner = runner_; return exec_->Run(args); } - Status Run(CallFrameInterface* call_frame) { + absl::Status Run(CallFrameInterface* call_frame) { Executor::Args args; args.call_frame = call_frame; args.runner = runner_; diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc index 3121a938bd0152..695b7d55217094 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.cc +++ b/tensorflow/core/common_runtime/step_stats_collector.cc @@ -462,7 +462,7 @@ string StepStatsCollector::ReportAllocsOnResourceExhausted( std::make_pair(dev_stat.first, alloc.first->allocator_name()); AllocStats& dev_allocs_stats = allocs_map[dev_allocator]; TrackingAllocator* tracking_alloc = alloc.second; - gtl::InlinedVector cur_records = + absl::InlinedVector cur_records = tracking_alloc->GetCurrentRecords(); int64_t cur_bytes = 0; for (const auto& r : cur_records) { diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h index df1e579f6d8932..277630cd40f9de 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.h +++ b/tensorflow/core/common_runtime/step_stats_collector.h @@ -124,7 +124,7 @@ class NodeExecStatsWrapper : public NodeExecStatsInterface { void AddAllocation(Allocator* allocator, TrackingAllocator* tracking_allocator); - gtl::InlinedVector, 2> + absl::InlinedVector, 2UL> allocations_; std::unique_ptr stats_; const NodeDef* const node_; // Not owned. diff --git a/tensorflow/core/common_runtime/test_collective_executor_mgr.h b/tensorflow/core/common_runtime/test_collective_executor_mgr.h index 0a47150b6a0dc3..0d0b190a6f3772 100644 --- a/tensorflow/core/common_runtime/test_collective_executor_mgr.h +++ b/tensorflow/core/common_runtime/test_collective_executor_mgr.h @@ -60,11 +60,11 @@ class TestParamResolver : public ParamResolverInterface { done(errors::Internal("Unimplemented")); } - Status LookupGroup(int32_t group_key, CollGroupParams* group) override { + absl::Status LookupGroup(int32_t group_key, CollGroupParams* group) override { return errors::Internal("Unimplemented"); } - void StartAbort(const Status& s) override {} + void StartAbort(const absl::Status& s) override {} }; class TestCollectiveExecutorMgr : public CollectiveExecutorMgrInterface { diff --git a/tensorflow/core/common_runtime/testdata/simplify_ici_dummy_variables_pass_before.pbtxt b/tensorflow/core/common_runtime/testdata/simplify_ici_dummy_variables_pass_before.pbtxt new file mode 100644 index 00000000000000..c29e4959bc8ccb --- /dev/null +++ b/tensorflow/core/common_runtime/testdata/simplify_ici_dummy_variables_pass_before.pbtxt @@ -0,0 +1,919 @@ +# proto-file: third_party/tensorflow/core/framework/graph.proto +# proto-message: GraphDef +node { + name: "unknown_2" + op: "_Arg" + device: "/job:tpu_host_worker/replica:0/task:0/device:CPU:0" + attr { + key: "T" + value { + type: DT_RESOURCE + } + } + attr { + key: "_handle_dtypes" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "_handle_shapes" + value { + list { + shape { + dim { + size: 1024 + } + } + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + unknown_rank: true + } + } + } + } + attr { + key: "_user_specified_name" + value { + s: "905" + } + } + attr { + key: "index" + value { + i: 4 + } + } +} + +node { + name: "unknown_17" + op: "_Arg" + device: "/job:tpu_host_worker/replica:0/task:0/device:CPU:0" + attr { + key: "T" + value { + type: DT_RESOURCE + } + } + attr { + key: "_handle_dtypes" + value { + list { type: DT_FLOAT } + } + } + attr { + key: "_handle_shapes" + value { + list { + shape { + dim { size: 128 } + dim { size: 1024 } + } + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { unknown_rank: true } + } + } + } + attr { + key: "_user_specified_name" + value { s: "935" } + } + attr { + key: "index" + value { i: 19 } + } +} + +node { + name: "tpu_compile_mlir" + op: "_TPUCompileMlir" + device: "/job:tpu_host_worker/replica:0/task:0/device:CPU:0" + attr { + key: "NumDynamicShapes" + value { i: 0 } + } + attr { + key: "_output_shapes" + value { + list { + shape {} + shape { dim { size: 3 } } + shape { dim { size: 3 } } + shape { dim { size: 3 } } + shape { dim { size: 3 } } + } + } + } + attr { + key: "metadata" + value { s: "" } + } + attr { + key: "mlir_module" + value { s: "" } + } + attr { + key: "num_computations" + value { i: 4 } + } +} + +node { + name: "readvariableop_1" + op: "ReadVariableOp" + input: "unknown_17" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { size: 128 } + dim { size: 1024 } + } + } + } + } + attr { + key: "dtype" + value { type: DT_FLOAT } + } +} + +node { + name: "identity_1" + op: "Identity" + input: "readvariableop_1" + device: "/job:tpu_host_worker/replica:0/task:0/device:CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { size: 128 } + dim { size: 1024 } + } + } + } + } + attr { + key: "_parallel_execution_ids" + value { + s: "r0:0" + } + } + attr { + key: "_ici_weight_distribution_mlir_bridge_marker" + value { b: true } + } +} + +node { + name: "const_1" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape {} + } + } + } + attr { + key: "_parallel_execution_ids" + value { s: "r0:0" } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "_ici_weight_distribution_mlir_bridge_marker" + value { b: true } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape {} + } + } + } +} + +node { + name: "split_1" + op: "Split" + input: "const_1" + input: "identity_1" + attr { + key: "T" + value { type: DT_FLOAT } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { size: 32 } + dim { size: 1024 } + } + shape { + dim { size: 32 } + dim { size: 1024 } + } + shape { + dim { size: 32 } + dim { size: 1024 } + } + shape { + dim { size: 32 } + dim { size: 1024 } + } + } + } + } + attr { + key: "_parallel_execution_ids" + value { + s: "r0:0" + } + } + attr { + key: "_ici_weight_distribution_mlir_bridge_marker" + value { b: true } + } + attr { + key: "num_split" + value { i: 4 } + } +} + +node { + name: "readvariableop_2" + op: "ReadVariableOp" + input: "unknown_2" + attr { + key: "_output_shapes" + value { + list { + shape { dim { size: 1024 } } + } + } + } + attr { + key: "dtype" + value { type: DT_FLOAT } + } +} + +node { + name: "identity_2" + op: "Identity" + input: "readvariableop_2" + device: "/job:tpu_host_worker/replica:0/task:0/device:CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { dim { size: 1024 } } + } + } + } + attr { + key: "_parallel_execution_ids" + value { + s: "r0:0" + } + } + attr { + key: "_ici_weight_distribution_mlir_bridge_marker" + value { b: true } + } +} + +node { + name: "tpu_execute_1" + op: "TPUExecute" + input: "split_1" + input: "identity_2" + input: "tpu_compile_mlir:1" + device: "/job:tpu_host_worker/replica:0/task:0/device:TPU:0" + attr { + key: "Targs" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + attr { + key: "Tresults" + value { + list { + type: DT_INT32 + type: DT_FLOAT + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape {} + shape { + dim { size: 32 } + dim { size: 1024 } + } + } + } + } + attr { + key: "_parallel_execution_ids" + value { + s: "r0:0,p0:0" + } + } +} + +node { + name: "tpu_execute_2" + op: "TPUExecute" + input: "split_1:1" + input: "identity_2" + input: "tpu_compile_mlir:2" + device: "/job:tpu_host_worker/replica:0/task:0/device:TPU:1" + attr { + key: "Targs" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + attr { + key: "Tresults" + value { + list { + type: DT_INT32 + type: DT_FLOAT + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape {} + shape { + dim { size: 32 } + dim { size: 1024 } + } + } + } + } + attr { + key: "_parallel_execution_ids" + value { + s: "r0:0,p0:1" + } + } +} + +node { + name: "const_3" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "_ici_weight_distribution_mlir_bridge_marker" + value { + b: true + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: "\200\000\000\000\000\000\000\000\000\004\000\000\000\000\000\000" + } + } + } +} + +node { + name: "const_4" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "_ici_weight_distribution_mlir_bridge_marker" + value { + b: true + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + } + } + } +} + +node { + name: "fill_1" + op: "Fill" + input: "const_3" + input: "const_4" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 128 + } + dim { + size: 1024 + } + } + } + } + } + attr { + key: "_ici_weight_distribution_mlir_bridge_marker" + value { + b: true + } + } + attr { + key: "index_type" + value { + type: DT_INT64 + } + } +} + +node { + name: "identity_3" + op: "Identity" + input: "fill_1" + device: "/job:tpu_host_worker/replica:0/task:2/device:CPU:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 128 + } + dim { + size: 1024 + } + } + } + } + } + attr { + key: "_parallel_execution_ids" + value { + s: "r0:1" + } + } + attr { + key: "_ici_weight_distribution_mlir_bridge_marker" + value { + b: true + } + } +} + +node { + name: "const_2" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "_parallel_execution_ids" + value { + s: "r0:1" + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "_ici_weight_distribution_mlir_bridge_marker" + value { + b: true + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + } + } + } +} + +node { + name: "split_2" + op: "Split" + input: "const_2" + input: "identity_3" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { size: 32 } + dim { size: 1024 } + } + shape { + dim { size: 32 } + dim { size: 1024 } + } + shape { + dim { size: 32 } + dim { size: 1024 } + } + shape { + dim { size: 32 } + dim { size: 1024 } + } + } + } + } + attr { + key: "_parallel_execution_ids" + value { s: "r0:1" } + } + attr { + key: "_ici_weight_distribution_mlir_bridge_marker" + value { b: true } + } + attr { + key: "num_split" + value { i: 4 } + } +} + +node { + name: "const_5" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT64 + } + } + attr { + key: "_ici_weight_distribution_mlir_bridge_marker" + value { + b: true + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT64 + tensor_shape { + dim { + size: 1 + } + } + int64_val: 1024 + } + } + } +} + +node { + name: "const_6" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "bcast_id" + value { + i: 4 + } + } + attr { + key: "dtype" + value { + type: DT_BFLOAT16 + } + } + attr { + key: "_ici_weight_distribution_mlir_bridge_marker" + value { + b: true + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_BFLOAT16 + tensor_shape { + } + } + } + } + experimental_debug_info { + original_node_names: "Identity" + original_func_names: "__inference__train_helper_851" + } +} + +node { + name: "fill_2" + op: "Fill" + input: "const_5" + input: "const_6" + attr { + key: "T" + value { + type: DT_BFLOAT16 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1024 + } + } + } + } + } + attr { + key: "_ici_weight_distribution_mlir_bridge_marker" + value { + b: true + } + } + attr { + key: "index_type" + value { + type: DT_INT64 + } + } +} + +node { + name: "identity_4" + op: "Identity" + input: "fill_2" + device: "/job:tpu_host_worker/replica:0/task:2/device:CPU:0" + attr { + key: "T" + value { + type: DT_BFLOAT16 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1024 + } + } + } + } + } + attr { + key: "_parallel_execution_ids" + value { + s: "r0:1" + } + } + attr { + key: "_ici_weight_distribution_mlir_bridge_marker" + value { + b: true + } + } +} + +node { + name: "tpu_execute_3" + op: "TPUExecute" + input: "split_2" + input: "identity_4" + input: "tpu_compile_mlir:1" + device: "/job:tpu_host_worker/replica:0/task:2/device:TPU:0" + attr { + key: "Targs" + value { + list { + type: DT_FLOAT + type: DT_BFLOAT16 + } + } + } + attr { + key: "Tresults" + value { + list { + type: DT_INT32 + type: DT_FLOAT + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape {} + shape { + dim { size: 32 } + dim { size: 1024 } + } + } + } + } + attr { + key: "_parallel_execution_ids" + value { + s: "r0:1,p0:0" + } + } +} + +node { + name: "tpu_execute_4" + op: "TPUExecute" + input: "split_2:1" + input: "identity_4" + input: "tpu_compile_mlir:2" + device: "/job:tpu_host_worker/replica:0/task:2/device:TPU:1" + attr { + key: "Targs" + value { + list { + type: DT_FLOAT + type: DT_BFLOAT16 + } + } + } + attr { + key: "Tresults" + value { + list { + type: DT_INT32 + type: DT_FLOAT + } + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + } + shape { + dim { + size: 32 + } + dim { + size: 1024 + } + } + } + } + } + attr { + key: "_parallel_execution_ids" + value { + s: "r0:1,p0:1" + } + } +} diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc index a06e9e90b7ba16..23166b69540083 100644 --- a/tensorflow/core/common_runtime/threadpool_device.cc +++ b/tensorflow/core/common_runtime/threadpool_device.cc @@ -118,7 +118,7 @@ Allocator* ThreadPoolDevice::GetScopedAllocator(AllocatorAttributes attr, return allocator_; } -Status ThreadPoolDevice::MakeTensorFromProto( +absl::Status ThreadPoolDevice::MakeTensorFromProto( const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, Tensor* tensor) { if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) { @@ -184,7 +184,7 @@ void ThreadPoolDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { op_kernel->Compute(context); if (context->status().ok() && node_file_writer_) { - Status s = node_file_writer_->RecordNodeExecution(op_kernel, context); + absl::Status s = node_file_writer_->RecordNodeExecution(op_kernel, context); if (!s.ok()) { LOG(ERROR) << s; context->SetStatus(s); diff --git a/tensorflow/core/common_runtime/threadpool_device.h b/tensorflow/core/common_runtime/threadpool_device.h index a2b062e0e6f727..08175ccb1f231c 100644 --- a/tensorflow/core/common_runtime/threadpool_device.h +++ b/tensorflow/core/common_runtime/threadpool_device.h @@ -36,14 +36,14 @@ class ThreadPoolDevice : public LocalDevice { ScopedAllocatorMgr* GetScopedAllocatorMgr() const override { return scoped_allocator_mgr_.get(); } - Status MakeTensorFromProto(const TensorProto& tensor_proto, - const AllocatorAttributes alloc_attrs, - Tensor* tensor) override; + absl::Status MakeTensorFromProto(const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor) override; void CopyTensorInSameDevice(const Tensor* input_tensor, Tensor* output_tensor, const DeviceContext* device_context, StatusCallback done) override; - Status Sync() override { return absl::OkStatus(); } + absl::Status Sync() override { return absl::OkStatus(); } void Compute(OpKernel* op_kernel, OpKernelContext* context) override; void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, diff --git a/tensorflow/core/common_runtime/threadpool_device_factory.cc b/tensorflow/core/common_runtime/threadpool_device_factory.cc index 4ebf3576167e5f..3ac8ea5ae8b68c 100644 --- a/tensorflow/core/common_runtime/threadpool_device_factory.cc +++ b/tensorflow/core/common_runtime/threadpool_device_factory.cc @@ -29,14 +29,15 @@ namespace tensorflow { // TODO(zhifengc/tucker): Figure out the bytes of available RAM. class ThreadPoolDeviceFactory : public DeviceFactory { public: - Status ListPhysicalDevices(std::vector* devices) override { + absl::Status ListPhysicalDevices(std::vector* devices) override { devices->push_back("/physical_device:CPU:0"); return absl::OkStatus(); } - Status CreateDevices(const SessionOptions& options, const string& name_prefix, - std::vector>* devices) override { + absl::Status CreateDevices( + const SessionOptions& options, const string& name_prefix, + std::vector>* devices) override { int num_numa_nodes = port::NUMANumNodes(); int n = 1; auto iter = options.config.device_count().find("CPU"); diff --git a/tensorflow/core/common_runtime/threadpool_device_test.cc b/tensorflow/core/common_runtime/threadpool_device_test.cc index 263b1deb0af2f1..bb4d3d809b4c9f 100644 --- a/tensorflow/core/common_runtime/threadpool_device_test.cc +++ b/tensorflow/core/common_runtime/threadpool_device_test.cc @@ -58,7 +58,7 @@ TEST(ThreadPoolDeviceTest, CopyTensor) { DeviceContext* device_context = new DeviceContext; Notification note; device.CopyTensorInSameDevice(&input, &output, device_context, - [¬e](const Status& s) { + [¬e](const absl::Status& s) { TF_ASSERT_OK(s); note.Notify(); }); diff --git a/tensorflow/core/config/flag_defs.h b/tensorflow/core/config/flag_defs.h index 23e9989a31edb7..d6bc4d9531173f 100644 --- a/tensorflow/core/config/flag_defs.h +++ b/tensorflow/core/config/flag_defs.h @@ -67,6 +67,8 @@ class Flags { TF_DECLARE_FLAG(enable_skip_encapsulation_for_non_tpu_graphs, false, "If true, TF2XLA encapsulation will be skipped for non-TPU " "graphs.") + TF_DECLARE_FLAG(enable_graph_debug_info_caching_for_stack_frames, true, + "If true, graph debug info will cache the stack frames.") // LINT.ThenChange(//tensorflow/core/config/flags_api_wrapper.cc) }; diff --git a/tensorflow/core/config/flags_api_wrapper.cc b/tensorflow/core/config/flags_api_wrapper.cc index 060ede3846df23..0ed77fb654463b 100644 --- a/tensorflow/core/config/flags_api_wrapper.cc +++ b/tensorflow/core/config/flags_api_wrapper.cc @@ -56,5 +56,6 @@ PYBIND11_MODULE(flags_pybind, m) { TF_PY_DECLARE_FLAG(enable_tf2min_ici_weight) TF_PY_DECLARE_FLAG(enable_function_pruning_before_inlining) TF_PY_DECLARE_FLAG(enable_skip_encapsulation_for_non_tpu_graphs) + TF_PY_DECLARE_FLAG(enable_graph_debug_info_caching_for_stack_frames) // LINT.ThenChange(//tensorflow/core/config/flag_defs.h) }; diff --git a/tensorflow/core/data/BUILD b/tensorflow/core/data/BUILD index 748dfc17ce213e..00a7e14585c929 100644 --- a/tensorflow/core/data/BUILD +++ b/tensorflow/core/data/BUILD @@ -612,7 +612,7 @@ tf_cc_test( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:protos_all_cc", + "@local_xla//xla/tsl/protobuf:protos_all_cc", ] + tf_protos_all(), ) diff --git a/tensorflow/core/data/captured_function.cc b/tensorflow/core/data/captured_function.cc index 2206cb09d7e9a5..3e06b11bdf47b2 100644 --- a/tensorflow/core/data/captured_function.cc +++ b/tensorflow/core/data/captured_function.cc @@ -118,8 +118,8 @@ class SimpleStepStatsCollector : public StepStatsCollectorInterface { int64_t processing_time_ TF_GUARDED_BY(mu_) = 0; }; -Status GetCapturedInput(const CapturedFunction* const func, int index, - const Tensor** out) { +absl::Status GetCapturedInput(const CapturedFunction* const func, int index, + const Tensor** out) { if (TF_PREDICT_FALSE(index >= func->captured_inputs().size())) { return errors::OutOfRange( "Out of range access to captured inputs for function ", @@ -130,10 +130,10 @@ Status GetCapturedInput(const CapturedFunction* const func, int index, return absl::OkStatus(); } -Status RunShortCircuit(const ShortCircuitInfo& info, - const std::vector& args, - const CapturedFunction* const func, - std::vector* rets) { +absl::Status RunShortCircuit(const ShortCircuitInfo& info, + const std::vector& args, + const CapturedFunction* const func, + std::vector* rets) { VLOG(3) << "Running function " << func->func().name() << " short circuit"; const int num_args = args.size(); rets->reserve(info.indices.size()); @@ -150,9 +150,10 @@ Status RunShortCircuit(const ShortCircuitInfo& info, return absl::OkStatus(); } -Status RunShortCircuit(const ShortCircuitInfo& info, std::vector&& args, - const CapturedFunction* const func, - std::vector* rets) { +absl::Status RunShortCircuit(const ShortCircuitInfo& info, + std::vector&& args, + const CapturedFunction* const func, + std::vector* rets) { VLOG(3) << "Running function " << func->func().name() << " short circuit"; const int num_args = args.size(); rets->reserve(info.indices.size()); @@ -173,16 +174,16 @@ Status RunShortCircuit(const ShortCircuitInfo& info, std::vector&& args, return absl::OkStatus(); } -Status CreateShortCircuitInfo(OpKernelConstruction* ctx, - const NameAttrList& func, - ShortCircuitInfo* info) { +absl::Status CreateShortCircuitInfo(OpKernelConstruction* ctx, + const NameAttrList& func, + ShortCircuitInfo* info) { auto& indices = info->indices; FunctionLibraryRuntime::Handle fn_handle; TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate( func.name(), AttrSlice(&func.attr()), &fn_handle)); auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() { - Status s = ctx->function_library()->ReleaseHandle(fn_handle); + absl::Status s = ctx->function_library()->ReleaseHandle(fn_handle); if (!s.ok()) { LOG(WARNING) << "Failed to release handle: " << s.message(); } @@ -232,7 +233,7 @@ Status CreateShortCircuitInfo(OpKernelConstruction* ctx, return absl::OkStatus(); } -Status CreateFunctionLibraryDefinition( +absl::Status CreateFunctionLibraryDefinition( const FunctionLibraryDefinition* lib_def, const string& func_name, std::unique_ptr* result) { DCHECK(lib_def != nullptr); @@ -246,8 +247,8 @@ Status CreateFunctionLibraryDefinition( return (*result)->CopyFunctionDefFrom(func_name, *lib_def); } -Status LookupFunction(const FunctionLibraryDefinition& lib_def, - const string& name, const FunctionDef** fdef) { +absl::Status LookupFunction(const FunctionLibraryDefinition& lib_def, + const string& name, const FunctionDef** fdef) { *fdef = lib_def.Find(name); if (*fdef == nullptr) { return errors::InvalidArgument( @@ -263,7 +264,7 @@ class CallFrameBase : public CallFrameInterface { : ret_types_(ret_types), retvals_(ret_types.size()) {} // Caller methods. - Status ConsumeRetvals(std::vector* retvals) { + absl::Status ConsumeRetvals(std::vector* retvals) { retvals->reserve(retvals_.size()); int i = 0; for (auto&& val : retvals_) { @@ -279,7 +280,7 @@ class CallFrameBase : public CallFrameInterface { size_t num_retvals() const override { return retvals_.size(); } // Callee methods. - Status SetRetval(int index, const Tensor& val) override { + absl::Status SetRetval(int index, const Tensor& val) override { const int retvals_size = retvals_.size(); if (index < retvals_size && val.dtype() == ret_types_[index] && !retvals_[index]) { @@ -320,7 +321,7 @@ class OwnedArgsCallFrame : public CallFrameBase { } // Callee methods. - Status GetArg(int index, const Tensor** val) override { + absl::Status GetArg(int index, const Tensor** val) override { const int args_size = args_.size(); const int captured_inputs_size = captured_inputs_->size(); if (index < args_size) { @@ -364,7 +365,7 @@ class BorrowedArgsCallFrame : public CallFrameBase { } // Callee methods. - Status GetArg(int index, const Tensor** val) override { + absl::Status GetArg(int index, const Tensor** val) override { const int args_size = args_.size(); const int captured_inputs_size = captured_inputs_->size(); if (index < args_size) { @@ -385,7 +386,7 @@ class BorrowedArgsCallFrame : public CallFrameBase { } // namespace -Status MakeIteratorFromInputElement( +absl::Status MakeIteratorFromInputElement( IteratorContext* ctx, const DatasetBaseIterator* parent, const std::vector& input_element, int64_t thread_index, const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix, @@ -395,7 +396,7 @@ Status MakeIteratorFromInputElement( /*node=*/nullptr); } -Status MakeIteratorFromInputElement( +absl::Status MakeIteratorFromInputElement( IteratorContext* ctx, const DatasetBaseIterator* parent, const std::vector& input_element, int64_t thread_index, const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix, @@ -431,7 +432,7 @@ Status MakeIteratorFromInputElement( } /* static */ -Status FunctionMetadata::Create( +absl::Status FunctionMetadata::Create( OpKernelConstruction* ctx, const string& func_name, Params params, std::shared_ptr* out_metadata) { NameAttrList func; @@ -439,7 +440,7 @@ Status FunctionMetadata::Create( return Create(ctx, std::move(func), params, out_metadata); } -Status FunctionMetadata::Create( +absl::Status FunctionMetadata::Create( OpKernelConstruction* ctx, NameAttrList&& func, Params params, std::shared_ptr* out_metadata) { out_metadata->reset(new FunctionMetadata(std::move(func), params)); @@ -483,7 +484,7 @@ Status FunctionMetadata::Create( } /* static */ -Status CapturedFunction::Create( +absl::Status CapturedFunction::Create( OpKernelContext* ctx, std::shared_ptr metadata, const string& argument_name, std::unique_ptr* out_function) { @@ -495,7 +496,7 @@ Status CapturedFunction::Create( } /* static */ -Status CapturedFunction::Create( +absl::Status CapturedFunction::Create( OpKernelContext* ctx, std::shared_ptr metadata, std::vector&& captured_inputs, std::unique_ptr* out_function) { @@ -504,7 +505,7 @@ Status CapturedFunction::Create( return absl::OkStatus(); } -Status CapturedFunction::AddToGraph( +absl::Status CapturedFunction::AddToGraph( SerializationContext* ctx, DatasetBase::DatasetGraphDefBuilder* b, std::vector* other_arguments, DataTypeVector* other_arguments_types) const { @@ -527,7 +528,7 @@ Status CapturedFunction::AddToGraph( return absl::OkStatus(); } -Status CapturedFunction::Instantiate( +absl::Status CapturedFunction::Instantiate( IteratorContext* ctx, std::unique_ptr* instantiated_captured_function) { return CapturedFunction::Instantiate(InstantiateCapturedFunctionParams(ctx), @@ -536,7 +537,7 @@ Status CapturedFunction::Instantiate( // TODO(b/190831948): Check whether the function creates a resource and if so, // produce a warning. -Status CapturedFunction::Instantiate( +absl::Status CapturedFunction::Instantiate( InstantiateCapturedFunctionParams params, std::unique_ptr* instantiated_captured_function) { @@ -686,7 +687,7 @@ Status CapturedFunction::Instantiate( return absl::OkStatus(); } -Status CapturedFunction::CheckExternalState() const { +absl::Status CapturedFunction::CheckExternalState() const { for (const auto& name : lib_def()->ListFunctionNames()) { TF_RETURN_IF_ERROR( IsFunctionStateful(*lib_def(), *(lib_def()->Find(name)))); @@ -700,8 +701,8 @@ CapturedFunction::CapturedFunction( : metadata_(std::move(metadata)), captured_inputs_(std::move(captured_inputs)) {} -Status CapturedFunction::IsMultiDevice(FunctionLibraryRuntime* flr, - bool* is_multi_device) const { +absl::Status CapturedFunction::IsMultiDevice(FunctionLibraryRuntime* flr, + bool* is_multi_device) const { if (!metadata_->use_multi_device_function()) { *is_multi_device = false; return absl::OkStatus(); @@ -786,13 +787,13 @@ InstantiatedCapturedFunction::InstantiatedCapturedFunction( captured_func_(captured_func), is_multi_device_(is_multi_device) {} -Status InstantiatedCapturedFunction::Run(IteratorContext* ctx, - std::vector&& args, - std::vector* rets) const { +absl::Status InstantiatedCapturedFunction::Run( + IteratorContext* ctx, std::vector&& args, + std::vector* rets) const { return Run(ctx, std::move(args), rets, /*node=*/nullptr); } -Status InstantiatedCapturedFunction::Run( +absl::Status InstantiatedCapturedFunction::Run( IteratorContext* ctx, std::vector&& args, std::vector* rets, const std::shared_ptr& node) const { auto& info = captured_func_->short_circuit_info(); @@ -849,13 +850,13 @@ Status InstantiatedCapturedFunction::Run( return frame.ConsumeRetvals(rets); } -Status InstantiatedCapturedFunction::RunWithBorrowedArgs( +absl::Status InstantiatedCapturedFunction::RunWithBorrowedArgs( IteratorContext* ctx, const std::vector& args, std::vector* ret) const { return RunWithBorrowedArgs(ctx, args, ret, /*node=*/nullptr); } -Status InstantiatedCapturedFunction::RunWithBorrowedArgs( +absl::Status InstantiatedCapturedFunction::RunWithBorrowedArgs( IteratorContext* ctx, const std::vector& args, std::vector* rets, const std::shared_ptr& node) const { auto& info = captured_func_->short_circuit_info(); @@ -911,7 +912,7 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs( return frame.ConsumeRetvals(rets); } -Status InstantiatedCapturedFunction::RunInstantiated( +absl::Status InstantiatedCapturedFunction::RunInstantiated( const std::vector& args, std::vector* rets) { auto& info = captured_func_->short_circuit_info(); if (!info.indices.empty()) { @@ -953,7 +954,8 @@ void InstantiatedCapturedFunction::RunAsync( // Run the `done` callback on a threadpool thread, because it will // potentially do a non-trivial amount of (e.g. copying) work, and we may // want to run that concurrently with the next invocation. - Status s = RunShortCircuit(info, std::move(args), captured_func_, rets); + absl::Status s = + RunShortCircuit(info, std::move(args), captured_func_, rets); runner( std::bind([s](FunctionLibraryRuntime::DoneCallback& done) { done(s); }, std::move(done))); @@ -996,7 +998,7 @@ void InstantiatedCapturedFunction::RunAsync( const FunctionLibraryRuntime::DoneCallback& done, const std::shared_ptr& stats_collector, // Begin unbound arguments. - Status s) { + absl::Status s) { delete step_container; delete raw_cancellation_manager; if (s.ok()) { diff --git a/tensorflow/core/data/captured_function.h b/tensorflow/core/data/captured_function.h index 854d9fc22cad53..b3ee90ee951878 100644 --- a/tensorflow/core/data/captured_function.h +++ b/tensorflow/core/data/captured_function.h @@ -44,7 +44,7 @@ class InstantiatedCapturedFunction; // Creates an iterator for a dataset which is created by applying the given // function to the given input element. -Status MakeIteratorFromInputElement( +absl::Status MakeIteratorFromInputElement( IteratorContext* ctx, const DatasetBaseIterator* parent, const std::vector& input_element, int64_t thread_index, const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix, @@ -53,7 +53,7 @@ Status MakeIteratorFromInputElement( // Creates an iterator for a dataset which is created by applying the given // function to the given input element. Pass non-null `node` to record // processing time for modeling Iterator's GetNext() resource usage. -Status MakeIteratorFromInputElement( +absl::Status MakeIteratorFromInputElement( IteratorContext* ctx, const DatasetBaseIterator* parent, const std::vector& input_element, int64_t thread_index, const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix, @@ -75,15 +75,15 @@ class FunctionMetadata { // Creates a new instance of the `FunctionMetadata` class, fetching function // from a context argument. - static Status Create(tensorflow::OpKernelConstruction* ctx, - const string& func_name, Params params, - std::shared_ptr* out_metadata); + static absl::Status Create(tensorflow::OpKernelConstruction* ctx, + const string& func_name, Params params, + std::shared_ptr* out_metadata); // Creates a new instance of the `FunctionMetadata` class, using the provided // function. - static Status Create(tensorflow::OpKernelConstruction* ctx, - NameAttrList&& func, Params params, - std::shared_ptr* out_metadata); + static absl::Status Create(tensorflow::OpKernelConstruction* ctx, + NameAttrList&& func, Params params, + std::shared_ptr* out_metadata); // Returns the named list of function arguments. const NameAttrList& func() const { return func_; } @@ -148,38 +148,38 @@ class CapturedFunction { public: // Creates a new instance using a list of named attributes, fetching captured // inputs from a context argument. - static Status Create(OpKernelContext* ctx, - std::shared_ptr metadata, - const string& argument_name, - std::unique_ptr* out_function); + static absl::Status Create(OpKernelContext* ctx, + std::shared_ptr metadata, + const string& argument_name, + std::unique_ptr* out_function); // Creates a new instance using a list of named attributes, using provided // captured inputs. - static Status Create(OpKernelContext* ctx, - std::shared_ptr metadata, - std::vector&& captured_inputs, - std::unique_ptr* out_function); + static absl::Status Create(OpKernelContext* ctx, + std::shared_ptr metadata, + std::vector&& captured_inputs, + std::unique_ptr* out_function); // Adds the definition of this captured function into the given graph, // returning its captured inputs and types through the respective output // arguments. - Status AddToGraph(SerializationContext* ctx, - DatasetBase::DatasetGraphDefBuilder* b, - std::vector* other_arguments, - DataTypeVector* other_arguments_types) const; + absl::Status AddToGraph(SerializationContext* ctx, + DatasetBase::DatasetGraphDefBuilder* b, + std::vector* other_arguments, + DataTypeVector* other_arguments_types) const; // Instantiates this function for use in the given context, providing an // InstantiatedCapturedFunction that can be used to execute functions. - Status Instantiate(IteratorContext* ctx, - std::unique_ptr* - instantiated_captured_function); + absl::Status Instantiate(IteratorContext* ctx, + std::unique_ptr* + instantiated_captured_function); - Status Instantiate(InstantiateCapturedFunctionParams params, - std::unique_ptr* - instantiated_captured_function); + absl::Status Instantiate(InstantiateCapturedFunctionParams params, + std::unique_ptr* + instantiated_captured_function); // Determines whether the captured function is stateful. - Status CheckExternalState() const; + absl::Status CheckExternalState() const; // Returns the additional captured inputs that will be passed to the function. const std::vector& captured_inputs() const { @@ -211,8 +211,8 @@ class CapturedFunction { CapturedFunction(std::shared_ptr metadata, std::vector captured_inputs); - Status IsMultiDevice(FunctionLibraryRuntime* flr, - bool* is_multi_device) const; + absl::Status IsMultiDevice(FunctionLibraryRuntime* flr, + bool* is_multi_device) const; const std::shared_ptr metadata_; const std::vector captured_inputs_; @@ -237,8 +237,8 @@ class InstantiatedCapturedFunction { // the tensors in `args`, in order to be able to deallocate them as early as // possible. Use `RunWithBorrowedArgs()` if the caller needs to retain // ownership of the `args`. - Status Run(IteratorContext* ctx, std::vector&& args, - std::vector* rets) const; + absl::Status Run(IteratorContext* ctx, std::vector&& args, + std::vector* rets) const; // Runs the instantiated captured function. This method takes ownership of // the tensors in `args`, in order to be able to deallocate them as early as @@ -247,16 +247,16 @@ class InstantiatedCapturedFunction { // for modeling Iterator's GetNext() resource usage. When non-null node is // provided, the pre-requisite is that the calling thread has previously // called `DatasetBaseIterator::RecordStart(). - Status Run(IteratorContext* ctx, std::vector&& args, - std::vector* rets, - const std::shared_ptr& node) const; + absl::Status Run(IteratorContext* ctx, std::vector&& args, + std::vector* rets, + const std::shared_ptr& node) const; // Synchronously runs the captured function on the given `args`, and stores // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when // possible. - Status RunWithBorrowedArgs(IteratorContext* ctx, - const std::vector& args, - std::vector* rets) const; + absl::Status RunWithBorrowedArgs(IteratorContext* ctx, + const std::vector& args, + std::vector* rets) const; // Synchronously runs the captured function on the given `args`, and stores // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when @@ -264,10 +264,10 @@ class InstantiatedCapturedFunction { // Iterator's GetNext() resource usage. When non-null node is provided, the // pre-requisite is that the calling thread has previously called // `DatasetBaseIterator::RecordStart(). - Status RunWithBorrowedArgs(IteratorContext* ctx, - const std::vector& args, - std::vector* rets, - const std::shared_ptr& node) const; + absl::Status RunWithBorrowedArgs( + IteratorContext* ctx, const std::vector& args, + std::vector* rets, + const std::shared_ptr& node) const; // Synchronously runs the captured function on the given `args`, and stores // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when @@ -275,8 +275,8 @@ class InstantiatedCapturedFunction { // an `IteratorContext*` is not available (such as a destructor). // // TODO(b/144278100): Avoid running functions without IteratorContext. - Status RunInstantiated(const std::vector& args, - std::vector* rets); + absl::Status RunInstantiated(const std::vector& args, + std::vector* rets); // Asynchronously runs the captured function on the given `args`, stores the // results in `*rets`, and calls the given `done` callback when the function diff --git a/tensorflow/core/data/compression_utils.cc b/tensorflow/core/data/compression_utils.cc index 68c8f27f3127ad..3b06c6e113dade 100644 --- a/tensorflow/core/data/compression_utils.cc +++ b/tensorflow/core/data/compression_utils.cc @@ -61,8 +61,8 @@ class Iov { size_t num_bytes_; }; -Status CompressElement(const std::vector& element, - CompressedElement* out) { +absl::Status CompressElement(const std::vector& element, + CompressedElement* out) { // First pass: preprocess the non`memcpy`able tensors. size_t num_string_tensors = 0; size_t num_string_tensor_strings = 0; @@ -135,8 +135,8 @@ Status CompressElement(const std::vector& element, return absl::OkStatus(); } -Status UncompressElement(const CompressedElement& compressed, - std::vector* out) { +absl::Status UncompressElement(const CompressedElement& compressed, + std::vector* out) { if (compressed.version() != kCompressedElementVersion) { return errors::Internal("Unsupported compressed element version: ", compressed.version()); diff --git a/tensorflow/core/data/compression_utils.h b/tensorflow/core/data/compression_utils.h index 0d309fc5c05d0a..21f29d843427da 100644 --- a/tensorflow/core/data/compression_utils.h +++ b/tensorflow/core/data/compression_utils.h @@ -30,12 +30,12 @@ namespace data { // out the per-component metadata for the `CompressedElement`. // // Returns an error if the uncompressed size of the element exceeds 4GB. -Status CompressElement(const std::vector& element, - CompressedElement* out); +absl::Status CompressElement(const std::vector& element, + CompressedElement* out); // Uncompresses a `CompressedElement` into a vector of tensor components. -Status UncompressElement(const CompressedElement& compressed, - std::vector* out); +absl::Status UncompressElement(const CompressedElement& compressed, + std::vector* out); } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/dataset_utils.cc b/tensorflow/core/data/dataset_utils.cc index 19345990f355f8..718fbb12c6f57e 100644 --- a/tensorflow/core/data/dataset_utils.cc +++ b/tensorflow/core/data/dataset_utils.cc @@ -257,8 +257,8 @@ std::pair MaybeOverrideSeeds( return seeds; } -Status VerifyTypeMatch(const DataType& expected, const DataType& received, - int index) { +absl::Status VerifyTypeMatch(const DataType& expected, const DataType& received, + int index) { if (expected != received) { return errors::InvalidArgument("Data type mismatch at component ", index, ": expected ", DataTypeString(expected), @@ -267,8 +267,8 @@ Status VerifyTypeMatch(const DataType& expected, const DataType& received, return absl::OkStatus(); } -Status VerifyTypesMatch(const DataTypeVector& expected, - const DataTypeVector& received) { +absl::Status VerifyTypesMatch(const DataTypeVector& expected, + const DataTypeVector& received) { if (expected.size() != received.size()) { return errors::InvalidArgument( "Number of components does not match: expected ", expected.size(), @@ -280,8 +280,8 @@ Status VerifyTypesMatch(const DataTypeVector& expected, return absl::OkStatus(); } -Status VerifyTypesMatch(const DataTypeVector& expected, - const std::vector& received) { +absl::Status VerifyTypesMatch(const DataTypeVector& expected, + const std::vector& received) { if (expected.size() != received.size()) { return errors::InvalidArgument( "Number of components does not match: expected ", expected.size(), @@ -293,8 +293,9 @@ Status VerifyTypesMatch(const DataTypeVector& expected, return absl::OkStatus(); } -Status VerifyShapeCompatible(const PartialTensorShape& expected, - const PartialTensorShape& received, int index) { +absl::Status VerifyShapeCompatible(const PartialTensorShape& expected, + const PartialTensorShape& received, + int index) { if (!expected.IsCompatibleWith(received)) { return errors::InvalidArgument("Incompatible shapes at component ", index, ": expected ", expected.DebugString(), @@ -303,8 +304,9 @@ Status VerifyShapeCompatible(const PartialTensorShape& expected, return absl::OkStatus(); } -Status VerifyShapesCompatible(const std::vector& expected, - const std::vector& received) { +absl::Status VerifyShapesCompatible( + const std::vector& expected, + const std::vector& received) { if (expected.size() != received.size()) { return errors::InvalidArgument( "Number of components does not match: expected ", expected.size(), @@ -317,8 +319,9 @@ Status VerifyShapesCompatible(const std::vector& expected, return absl::OkStatus(); } -Status VerifyShapesCompatible(const std::vector& expected, - const std::vector& received) { +absl::Status VerifyShapesCompatible( + const std::vector& expected, + const std::vector& received) { if (expected.size() != received.size()) { return errors::InvalidArgument( "Number of components does not match: expected ", expected.size(), @@ -332,8 +335,8 @@ Status VerifyShapesCompatible(const std::vector& expected, return absl::OkStatus(); } -Status AddToFunctionLibrary(FunctionLibraryDefinition* base, - const FunctionLibraryDefinition& to_add) { +absl::Status AddToFunctionLibrary(FunctionLibraryDefinition* base, + const FunctionLibraryDefinition& to_add) { for (const auto& fn : to_add.ListFunctionNames()) { if (auto found = base->Find(fn)) { if (!OpDefEqual(found->signature(), to_add.Find(fn)->signature())) { @@ -347,8 +350,8 @@ Status AddToFunctionLibrary(FunctionLibraryDefinition* base, return base->AddLibrary(to_add); } -Status AddToFunctionLibrary(FunctionLibraryDefinition* base, - const FunctionDefLibrary& to_add) { +absl::Status AddToFunctionLibrary(FunctionLibraryDefinition* base, + const FunctionDefLibrary& to_add) { for (const auto& fd : to_add.function()) { if (auto found = base->Find(fd.signature().name())) { if (!OpDefEqual(found->signature(), fd.signature())) { @@ -363,8 +366,8 @@ Status AddToFunctionLibrary(FunctionLibraryDefinition* base, return base->AddLibrary(to_add); } -Status IsFunctionStateful(const FunctionLibraryDefinition& library, - const FunctionDef& function_def) { +absl::Status IsFunctionStateful(const FunctionLibraryDefinition& library, + const FunctionDef& function_def) { if (!function_def.signature().is_stateful()) { return absl::OkStatus(); } @@ -375,8 +378,8 @@ Status IsFunctionStateful(const FunctionLibraryDefinition& library, return absl::OkStatus(); } -Status IsNodeStateful(const FunctionLibraryDefinition& library, - const NodeDef& node) { +absl::Status IsNodeStateful(const FunctionLibraryDefinition& library, + const NodeDef& node) { const OpDef* op_def; // TODO(jsimsa): Fix C++ unit tests so that we do not have to ignore @@ -436,8 +439,8 @@ std::function)> RunnerWithMaxParallelism( std::move(runner), std::placeholders::_1); } -Status DeterminismPolicy::FromString(const std::string& s, - DeterminismPolicy* out) { +absl::Status DeterminismPolicy::FromString(const std::string& s, + DeterminismPolicy* out) { DeterminismPolicy::Type type; if (s == DeterminismPolicy::kDeterministic) { type = DeterminismPolicy::Type::kDeterministic; @@ -632,8 +635,8 @@ void StripDevicePlacement(FunctionDefLibrary* library) { } } -Status CopyPartialBatch(int64_t num_elements, const Tensor& value, - Tensor* output) { +absl::Status CopyPartialBatch(int64_t num_elements, const Tensor& value, + Tensor* output) { switch (value.dtype()) { #define HANDLE_TYPE(type) \ case DataTypeToEnum::value: { \ @@ -653,9 +656,9 @@ Status CopyPartialBatch(int64_t num_elements, const Tensor& value, return absl::OkStatus(); } -Status ReadBatch(IteratorContext* ctx, IteratorStateReader* reader, - int64_t batch_size, const string& iterator_prefix, - const string& batch_prefix, std::vector* batch) { +absl::Status ReadBatch(IteratorContext* ctx, IteratorStateReader* reader, + int64_t batch_size, const string& iterator_prefix, + const string& batch_prefix, std::vector* batch) { int64_t output_size; TF_RETURN_IF_ERROR(reader->ReadScalar( FullName(iterator_prefix, @@ -686,9 +689,10 @@ Status ReadBatch(IteratorContext* ctx, IteratorStateReader* reader, return absl::OkStatus(); } -Status WriteBatch(int64_t batch_size, int64_t num_elements, - const string& iterator_prefix, const string& batch_prefix, - IteratorStateWriter* writer, std::vector* batch) { +absl::Status WriteBatch(int64_t batch_size, int64_t num_elements, + const string& iterator_prefix, + const string& batch_prefix, IteratorStateWriter* writer, + std::vector* batch) { TF_RETURN_IF_ERROR(writer->WriteScalar( FullName(iterator_prefix, strings::StrCat(batch_prefix, "_", kOutputSize)), @@ -711,8 +715,8 @@ Status WriteBatch(int64_t batch_size, int64_t num_elements, return absl::OkStatus(); } -Status ReadStatus(const string& iterator_prefix, const string& prefix, - IteratorStateReader* reader, Status* status) { +absl::Status ReadStatus(const string& iterator_prefix, const string& prefix, + IteratorStateReader* reader, absl::Status* status) { int64_t code_int; TF_RETURN_IF_ERROR(reader->ReadScalar( FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)), @@ -724,15 +728,16 @@ Status ReadStatus(const string& iterator_prefix, const string& prefix, TF_RETURN_IF_ERROR(reader->ReadScalar( FullName(iterator_prefix, strings::StrCat(prefix, "_", kMessage)), &error_message)); - *status = Status(code, error_message); + *status = absl::Status(code, error_message); } else { *status = absl::OkStatus(); } return absl::OkStatus(); } -Status WriteStatus(const string& iterator_prefix, const string& prefix, - const Status& status, IteratorStateWriter* writer) { +absl::Status WriteStatus(const string& iterator_prefix, const string& prefix, + const absl::Status& status, + IteratorStateWriter* writer) { TF_RETURN_IF_ERROR(writer->WriteScalar( FullName(iterator_prefix, strings::StrCat(prefix, "_", kCode)), static_cast(status.code()))); @@ -744,10 +749,10 @@ Status WriteStatus(const string& iterator_prefix, const string& prefix, return absl::OkStatus(); } -Status ProcessBatch(int64_t batch_size, int64_t num_elements, - bool drop_remainder, const Status& status, - IteratorContext* ctx, std::vector* output, - bool* end_of_sequence, std::vector* batch) { +absl::Status ProcessBatch(int64_t batch_size, int64_t num_elements, + bool drop_remainder, const absl::Status& status, + IteratorContext* ctx, std::vector* output, + bool* end_of_sequence, std::vector* batch) { if (num_elements == 0) { if (status.ok() || absl::IsOutOfRange(status)) { *end_of_sequence = true; @@ -787,9 +792,9 @@ Status ProcessBatch(int64_t batch_size, int64_t num_elements, return absl::OkStatus(); } -Status CopyBatch(AnyContext ctx, - std::vector>&& batch_elements, - bool parallel_copy, std::vector* out_tensors) { +absl::Status CopyBatch(AnyContext ctx, + std::vector>&& batch_elements, + bool parallel_copy, std::vector* out_tensors) { const size_t num_tuple_components = batch_elements.at(0).size(); out_tensors->reserve(num_tuple_components); const int64_t num_batch_elements = batch_elements.size(); @@ -835,7 +840,7 @@ Status CopyBatch(AnyContext ctx, // Use parallelism for creating the batch as long as the final batch is at // least 1MB. if (parallel_copy && total_bytes >= (1 << 20)) { - Status status; + absl::Status status; mutex status_mu; const auto num_threads = ctx.runner_threadpool_size; const auto slice_size = num_batch_elements / num_threads; @@ -849,7 +854,7 @@ Status CopyBatch(AnyContext ctx, if (i < num_batch_elements % num_threads) ++length; (*ctx.runner)([offset, length, &status, &status_mu, &counter, ©_element_fn]() { - Status s; + absl::Status s; for (size_t j = offset; j < offset + length; ++j) { s.Update(copy_element_fn(j)); } diff --git a/tensorflow/core/data/dataset_utils.h b/tensorflow/core/data/dataset_utils.h index 78fac87b213985..c849f1b1c9f49a 100644 --- a/tensorflow/core/data/dataset_utils.h +++ b/tensorflow/core/data/dataset_utils.h @@ -40,8 +40,9 @@ constexpr int kShardHint = -1; // Creates a resource handle with a unique name for the given resource where // the resource is managed by the Resource Manager. template -Status CreateWeakHandle(OpKernelContext* ctx, T* resource, - const string& container_name, ResourceHandle* handle) { +absl::Status CreateWeakHandle(OpKernelContext* ctx, T* resource, + const string& container_name, + ResourceHandle* handle) { static std::atomic resource_id_counter(0); string unique_name = strings::StrCat(container_name, resource_id_counter.fetch_add(1)); @@ -56,7 +57,8 @@ Status CreateWeakHandle(OpKernelContext* ctx, T* resource, // Creates a ref-counting resource handle for the given resource, where the // resource is owned by the handle. template -Status CreateHandle(OpKernelContext* ctx, T* resource, ResourceHandle* handle) { +absl::Status CreateHandle(OpKernelContext* ctx, T* resource, + ResourceHandle* handle) { ResourceMgr* mgr = ctx->resource_manager(); *handle = ResourceHandle::MakeRefCountingHandle(resource, ctx->device()->name()); @@ -119,7 +121,7 @@ class AnonymousResourceOp : public OpKernel { protected: virtual string name() = 0; - virtual Status CreateResource( + virtual absl::Status CreateResource( OpKernelContext* ctx, std::unique_ptr flib_def, std::unique_ptr pflr, FunctionLibraryRuntime* lib, T** resource) = 0; @@ -131,19 +133,21 @@ class AnonymousResourceOp : public OpKernel { // Returns OkStatus() if `expected` and `received` types match, // errors::InvalidArgument otherwise. -Status VerifyTypesMatch(const DataTypeVector& expected, - const DataTypeVector& received); +absl::Status VerifyTypesMatch(const DataTypeVector& expected, + const DataTypeVector& received); -Status VerifyTypesMatch(const DataTypeVector& expected, - const std::vector& received); +absl::Status VerifyTypesMatch(const DataTypeVector& expected, + const std::vector& received); // Returns OkStatus() if `expected` and `received` shapes are compatible, // errors::InvalidArgument otherwise. -Status VerifyShapesCompatible(const std::vector& expected, - const std::vector& received); +absl::Status VerifyShapesCompatible( + const std::vector& expected, + const std::vector& received); -Status VerifyShapesCompatible(const std::vector& expected, - const std::vector& received); +absl::Status VerifyShapesCompatible( + const std::vector& expected, + const std::vector& received); // Dataset op level determinism policy. class DeterminismPolicy { @@ -168,7 +172,7 @@ class DeterminismPolicy { // kNondeterministic, depending on the values of `is_deterministic`. explicit DeterminismPolicy(bool is_deterministic); - static Status FromString(const std::string& s, DeterminismPolicy* out); + static absl::Status FromString(const std::string& s, DeterminismPolicy* out); // Returns the string representing the determinism policy. This will be one of // the string constants defined above. @@ -196,18 +200,18 @@ std::pair MaybeOverrideSeeds( // Adds the functions in `to_add` to `base`. If a function with a matching // signature already exists in `base`, replaces it with the function from // `to_add`. -Status AddToFunctionLibrary(FunctionLibraryDefinition* base, - const FunctionLibraryDefinition& to_add); -Status AddToFunctionLibrary(FunctionLibraryDefinition* base, - const FunctionDefLibrary& to_add); +absl::Status AddToFunctionLibrary(FunctionLibraryDefinition* base, + const FunctionLibraryDefinition& to_add); +absl::Status AddToFunctionLibrary(FunctionLibraryDefinition* base, + const FunctionDefLibrary& to_add); // Determines whether the given function is stateful. -Status IsFunctionStateful(const FunctionLibraryDefinition& library, - const FunctionDef& function_def); +absl::Status IsFunctionStateful(const FunctionLibraryDefinition& library, + const FunctionDef& function_def); // Determines whether the given node is stateful. -Status IsNodeStateful(const FunctionLibraryDefinition& library, - const NodeDef& node); +absl::Status IsNodeStateful(const FunctionLibraryDefinition& library, + const NodeDef& node); // Creates a runner that runs functions with limited parallelism. std::function)> RunnerWithMaxParallelism( @@ -251,33 +255,35 @@ Tensor MaybeCopySubSlice(const Tensor& tensor, int64 index); void StripDevicePlacement(FunctionDefLibrary* library); // Copies partial of the batch output. -Status CopyPartialBatch(int64_t num_elements, const Tensor& value, - Tensor* output); +absl::Status CopyPartialBatch(int64_t num_elements, const Tensor& value, + Tensor* output); // Reads a batch when restoring the iterator. -Status ReadBatch(IteratorContext* ctx, IteratorStateReader* reader, - int64_t batch_size, const string& iterator_prefix, - const string& batch_prefix, std::vector* batch); +absl::Status ReadBatch(IteratorContext* ctx, IteratorStateReader* reader, + int64_t batch_size, const string& iterator_prefix, + const string& batch_prefix, std::vector* batch); // Writes a batch when saving the iterator. -Status WriteBatch(int64_t batch_size, int64_t num_elements, - const string& iterator_prefix, const string& batch_prefix, - IteratorStateWriter* writer, std::vector* batch); +absl::Status WriteBatch(int64_t batch_size, int64_t num_elements, + const string& iterator_prefix, + const string& batch_prefix, IteratorStateWriter* writer, + std::vector* batch); // Reads a status when restoring the iterator. -Status ReadStatus(const string& iterator_prefix, const string& prefix, - IteratorStateReader* reader, Status* status); +absl::Status ReadStatus(const string& iterator_prefix, const string& prefix, + IteratorStateReader* reader, absl::Status* status); // Writes a status when saving the iterator. -Status WriteStatus(const string& iterator_prefix, const string& prefix, - const Status& status, IteratorStateWriter* writer); +absl::Status WriteStatus(const string& iterator_prefix, const string& prefix, + const absl::Status& status, + IteratorStateWriter* writer); // Processes a batch to output. In the case a partial batch is encountered, copy // only partial of the batch. -Status ProcessBatch(int64_t batch_size, int64_t num_elements, - bool drop_remainder, const Status& status, - IteratorContext* ctx, std::vector* output, - bool* end_of_sequence, std::vector* batch); +absl::Status ProcessBatch(int64_t batch_size, int64_t num_elements, + bool drop_remainder, const absl::Status& status, + IteratorContext* ctx, std::vector* output, + bool* end_of_sequence, std::vector* batch); // Copies the input elements to a batch. // @@ -286,9 +292,9 @@ Status ProcessBatch(int64_t batch_size, int64_t num_elements, // copy. // The `out_tensors` argument will be used to store the resulting batch (one for // each component of the input). -Status CopyBatch(AnyContext ctx, - std::vector>&& batch_elements, - bool parallel_copy, std::vector* out_tensors); +absl::Status CopyBatch(AnyContext ctx, + std::vector>&& batch_elements, + bool parallel_copy, std::vector* out_tensors); // Computes the set of experiments to apply based on the job name, task id, // rollout percentage of registered experiments, and the diff --git a/tensorflow/core/data/hash_utils.cc b/tensorflow/core/data/hash_utils.cc index 8dbab84a533656..d7a9776da231cf 100644 --- a/tensorflow/core/data/hash_utils.cc +++ b/tensorflow/core/data/hash_utils.cc @@ -74,7 +74,7 @@ bool IsNodeOfType(const NodeDef& node, return false; } -Status GetSink(const GraphDef& graph_def, const NodeDef** sink) { +absl::Status GetSink(const GraphDef& graph_def, const NodeDef** sink) { for (auto& node : graph_def.node()) { if (node.op() == kRetvalOp) { *sink = &node; @@ -88,7 +88,7 @@ Status GetSink(const GraphDef& graph_def, const NodeDef** sink) { return absl::OkStatus(); } -Status ShouldIgnoreInput(const NodeDef& node, int i, bool* result) { +absl::Status ShouldIgnoreInput(const NodeDef& node, int i, bool* result) { *result = false; if (IsNodeOfType(node, kOpsWithSeed)) { const OpRegistrationData* reg; @@ -117,9 +117,10 @@ Status ShouldIgnoreInput(const NodeDef& node, int i, bool* result) { return absl::OkStatus(); } -Status ParseInputNodeName(absl::string_view input_name, - absl::string_view* node_name, - absl::string_view* suffix, bool* is_control_input) { +absl::Status ParseInputNodeName(absl::string_view input_name, + absl::string_view* node_name, + absl::string_view* suffix, + bool* is_control_input) { if (input_name[0] == '^') { *node_name = input_name.substr(1); *is_control_input = true; @@ -170,7 +171,7 @@ class GraphHasher { function_cache_(function_cache), attr_cache_(attr_cache) {} - Status Init() { + absl::Status Init() { // Construct a map of name -> NodeDef to avoid repeated linear searches. absl::flat_hash_map node_def_by_name; node_def_by_name.reserve(graph_->node_size()); @@ -240,14 +241,14 @@ class GraphHasher { return absl::OkStatus(); } - Status HashRoot(uint64* hash) { return HashNode(root_, hash); } + absl::Status HashRoot(uint64* hash) { return HashNode(root_, hash); } - Status CheckEqual(GraphHasher* that) { + absl::Status CheckEqual(GraphHasher* that) { return CheckNodesEqual(root_, that, that->root_); } private: - Status HashNode(const NodeDef* node, uint64* hash) { + absl::Status HashNode(const NodeDef* node, uint64* hash) { auto it = node_cache_->find(node); if (it != node_cache_->end()) { *hash = it->second; @@ -294,9 +295,9 @@ class GraphHasher { return absl::OkStatus(); } - Status CheckNodesEqual(const NodeDef* this_node, GraphHasher* that, - const NodeDef* that_node) { - Status s = CheckNodesEqualHelper(this_node, that, that_node); + absl::Status CheckNodesEqual(const NodeDef* this_node, GraphHasher* that, + const NodeDef* that_node) { + absl::Status s = CheckNodesEqualHelper(this_node, that, that_node); if (!s.ok()) { return errors::FailedPrecondition("Nodes ", this_node->name(), " and ", that_node->name(), @@ -305,8 +306,9 @@ class GraphHasher { return s; } - Status CheckNodesEqualHelper(const NodeDef* this_node, GraphHasher* that, - const NodeDef* that_node) { + absl::Status CheckNodesEqualHelper(const NodeDef* this_node, + GraphHasher* that, + const NodeDef* that_node) { TF_RETURN_IF_ERROR(CheckNodesEqualNonInput(this_node, that, that_node, /*compare_functions=*/true)); @@ -342,8 +344,8 @@ class GraphHasher { return absl::OkStatus(); } - Status HashNodeNonInput(const NodeDef* node, bool hash_functions, - uint64* hash) { + absl::Status HashNodeNonInput(const NodeDef* node, bool hash_functions, + uint64* hash) { auto iter = attr_cache_->find(std::make_pair(node, hash_functions)); if (iter != attr_cache_->end()) { *hash = iter->second; @@ -399,9 +401,10 @@ class GraphHasher { return absl::OkStatus(); } - Status CheckNodesEqualNonInput(const NodeDef* this_node, GraphHasher* that, - const NodeDef* that_node, - bool compare_functions) { + absl::Status CheckNodesEqualNonInput(const NodeDef* this_node, + GraphHasher* that, + const NodeDef* that_node, + bool compare_functions) { // We get the list of attrs from the op registry and then look // up their values in the NodeDef attr map. This avoids looping over // a map which is non-deterministic. @@ -454,8 +457,9 @@ class GraphHasher { return absl::OkStatus(); } - Status HashAttr(const std::string& attr_name, const AttrValue& attr_value, - bool hash_functions, uint64* hash) { + absl::Status HashAttr(const std::string& attr_name, + const AttrValue& attr_value, bool hash_functions, + uint64* hash) { uint64 value_hash = 0; if (attr_value.has_func()) { if (hash_functions) { @@ -476,9 +480,10 @@ class GraphHasher { return absl::OkStatus(); } - Status CheckAttrsEqual(const std::string& attr_name, - const AttrValue& this_attr, GraphHasher* that, - const AttrValue& that_attr, bool compare_functions) { + absl::Status CheckAttrsEqual(const std::string& attr_name, + const AttrValue& this_attr, GraphHasher* that, + const AttrValue& that_attr, + bool compare_functions) { if (this_attr.has_func() != that_attr.has_func()) { return errors::FailedPrecondition( "AttrValues are of different types: ", this_attr.DebugString(), @@ -523,12 +528,12 @@ class GraphHasher { return absl::OkStatus(); } - Status HashFunction(const NameAttrList& func, uint64* hash) { + absl::Status HashFunction(const NameAttrList& func, uint64* hash) { return HashFunction(func.name(), func.attr(), hash); } - Status HashFunction(const std::string& name, const AttrValueMap& attrs, - uint64* hash) { + absl::Status HashFunction(const std::string& name, const AttrValueMap& attrs, + uint64* hash) { const FunctionDef* fdef = flib_->Find(name); auto it = function_cache_->find(fdef); if (it != function_cache_->end()) { @@ -572,17 +577,19 @@ class GraphHasher { return absl::OkStatus(); } - Status CheckFunctionsEqual(const NameAttrList& this_func, GraphHasher* that, - const NameAttrList& that_func) { + absl::Status CheckFunctionsEqual(const NameAttrList& this_func, + GraphHasher* that, + const NameAttrList& that_func) { return CheckFunctionsEqual(this_func.name(), this_func.attr(), that, that_func.name(), that_func.attr()); } - Status CheckFunctionsEqual(const std::string& this_name, - const AttrValueMap& this_attrs, GraphHasher* that, - const std::string& that_name, - const AttrValueMap& that_attrs) { - Status s = CheckFunctionsEqualHelper(this_name, this_attrs, that, that_name, - that_attrs); + absl::Status CheckFunctionsEqual(const std::string& this_name, + const AttrValueMap& this_attrs, + GraphHasher* that, + const std::string& that_name, + const AttrValueMap& that_attrs) { + absl::Status s = CheckFunctionsEqualHelper(this_name, this_attrs, that, + that_name, that_attrs); if (!s.ok()) { return errors::FailedPrecondition("Functions ", this_name, " and ", that_name, " are not the same:\n", s); @@ -590,11 +597,11 @@ class GraphHasher { return s; } - Status CheckFunctionsEqualHelper(const std::string& this_name, - const AttrValueMap& this_attrs, - GraphHasher* that, - const std::string& that_name, - const AttrValueMap& that_attrs) { + absl::Status CheckFunctionsEqualHelper(const std::string& this_name, + const AttrValueMap& this_attrs, + GraphHasher* that, + const std::string& that_name, + const AttrValueMap& that_attrs) { const FunctionDef* this_fdef = flib_->Find(this_name); const FunctionDef* that_fdef = that->flib_->Find(that_name); @@ -641,8 +648,8 @@ class GraphHasher { return absl::OkStatus(); } - Status HashControlInputs(const std::vector& inputs, - uint64* hash) { + absl::Status HashControlInputs(const std::vector& inputs, + uint64* hash) { *hash = 0; for (const NodeDef* input : inputs) { uint64 node_hash = 0; @@ -653,7 +660,7 @@ class GraphHasher { return absl::OkStatus(); } - Status CheckControlInputsEqual( + absl::Status CheckControlInputsEqual( const std::vector& this_inputs, GraphHasher* that, const std::vector& that_inputs) { absl::flat_hash_map this_hashes; @@ -725,7 +732,7 @@ class GraphHasher { } // anonymous namespace -Status HashTensor(const Tensor& tensor, uint64* hash) { +absl::Status HashTensor(const Tensor& tensor, uint64* hash) { const tstring* s = nullptr; // Hash tensor type. *hash = Hash64Combine(0, tensor.dtype()); @@ -751,26 +758,27 @@ Status HashTensor(const Tensor& tensor, uint64* hash) { return absl::OkStatus(); } -Status HashNode(const GraphDef& graph, const NodeDef& node, uint64* hash) { +absl::Status HashNode(const GraphDef& graph, const NodeDef& node, + uint64* hash) { const FunctionLibraryDefinition flib_def(OpRegistry::Global(), graph.library()); return HashNode(graph, node, flib_def, hash); } -Status HashNode(const GraphDef& graph, const NodeDef& node, - const FunctionLibraryDefinition& flib_def, uint64* hash) { +absl::Status HashNode(const GraphDef& graph, const NodeDef& node, + const FunctionLibraryDefinition& flib_def, uint64* hash) { GraphHasher hasher(&graph, &node, &flib_def); TF_RETURN_IF_ERROR(hasher.Init()); return hasher.HashRoot(hash); } -Status HashGraph(const GraphDef& graph_def, uint64* hash) { +absl::Status HashGraph(const GraphDef& graph_def, uint64* hash) { const NodeDef* sink = nullptr; TF_RETURN_IF_ERROR(GetSink(graph_def, &sink)); return HashNode(graph_def, *sink, hash); } -Status CheckGraphsEqual(const GraphDef& a, const GraphDef& b) { +absl::Status CheckGraphsEqual(const GraphDef& a, const GraphDef& b) { const NodeDef* sink_a; TF_RETURN_IF_ERROR(GetSink(a, &sink_a)); const NodeDef* sink_b; @@ -778,8 +786,8 @@ Status CheckGraphsEqual(const GraphDef& a, const GraphDef& b) { return CheckSubgraphsEqual(a, sink_a, b, sink_b); } -Status CheckSubgraphsEqual(const GraphDef& a, const NodeDef* node_a, - const GraphDef& b, const NodeDef* node_b) { +absl::Status CheckSubgraphsEqual(const GraphDef& a, const NodeDef* node_a, + const GraphDef& b, const NodeDef* node_b) { const FunctionLibraryDefinition flib_def_a(OpRegistry::Global(), a.library()); GraphHasher hasher_a(&a, node_a, &flib_def_a); TF_RETURN_IF_ERROR(hasher_a.Init()); diff --git a/tensorflow/core/data/hash_utils.h b/tensorflow/core/data/hash_utils.h index 401f047e47003f..7469953f1c05ff 100644 --- a/tensorflow/core/data/hash_utils.h +++ b/tensorflow/core/data/hash_utils.h @@ -28,34 +28,34 @@ namespace data { // // NOTE: There is currently no guarantee that the hash of a subgraph will stay // the same between TensorFlow builds. -Status HashNode(const GraphDef& graph, const NodeDef& node, uint64* hash); -Status HashNode(const GraphDef& graph, const NodeDef& node, - const FunctionLibraryDefinition& flib_def, uint64* hash); +absl::Status HashNode(const GraphDef& graph, const NodeDef& node, uint64* hash); +absl::Status HashNode(const GraphDef& graph, const NodeDef& node, + const FunctionLibraryDefinition& flib_def, uint64* hash); // Returns a stable hash of the given tensor. // // NOTE: There is currently no guarantee that the hash of a subgraph will stay // the same between TensorFlow builds. -Status HashTensor(const Tensor& tensor, uint64* hash); +absl::Status HashTensor(const Tensor& tensor, uint64* hash); // Returns a stable hash of the given graph. // // NOTE: There is currently no guarantee that the hash of a subgraph will stay // the same between TensorFlow builds. -Status HashGraph(const GraphDef& graph, uint64* hash); +absl::Status HashGraph(const GraphDef& graph, uint64* hash); // Determines whether the given graphs are equal, following the same logic used // for HashGraph. Returns OK if the graphs can be determined to be equal, // otherwise returns an error message explaining why the graphs couldn't be // determined to be equal. -Status CheckGraphsEqual(const GraphDef& a, const GraphDef& b); +absl::Status CheckGraphsEqual(const GraphDef& a, const GraphDef& b); // Determines whether the subgraphs rooted at the given nodes are equal // following the same logic used for HashGraph. Returns OK if the graphs can be // determined to be equal, otherwise returns an error message explaining why the // graphs couldn't be determined to be equal. -Status CheckSubgraphsEqual(const GraphDef& a, const NodeDef* node_a, - const GraphDef& b, const NodeDef* node_b); +absl::Status CheckSubgraphsEqual(const GraphDef& a, const NodeDef* node_a, + const GraphDef& b, const NodeDef* node_b); } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/hash_utils_test.cc b/tensorflow/core/data/hash_utils_test.cc index 7219aff2db19c6..5edbe32c0a2b07 100644 --- a/tensorflow/core/data/hash_utils_test.cc +++ b/tensorflow/core/data/hash_utils_test.cc @@ -52,8 +52,8 @@ class DatasetHashUtilsTest : public ::testing::Test { return hash; } - Status CheckEqual(const FunctionDefLibrary& library, const FunctionDef& fn1, - const FunctionDef& fn2) { + absl::Status CheckEqual(const FunctionDefLibrary& library, + const FunctionDef& fn1, const FunctionDef& fn2) { // Construct nodes with a function as an attr. GraphDef graph_def; *graph_def.mutable_library() = library; @@ -132,7 +132,7 @@ TEST_F(DatasetHashUtilsTest, HashFunctionDifferentFunctions) { // The second op in `f2` is changed to "Add" EXPECT_NE(GetHash(fl, *f1), GetHash(fl, *f2)); - Status s = CheckEqual(fl, *f1, *f2); + absl::Status s = CheckEqual(fl, *f1, *f2); EXPECT_NE(s.code(), error::OK); EXPECT_THAT(s.message(), ContainsRegex("Add")); } @@ -274,7 +274,7 @@ TEST_F(DatasetHashUtilsTest, HashNodeDifferentGraphs) { uint64 hash2 = GetHash(gd, *n4); // We expect different hashes because the op has changed. EXPECT_NE(hash1, hash2); - Status s = CheckSubgraphsEqual(gd, n3, gd, n4); + absl::Status s = CheckSubgraphsEqual(gd, n3, gd, n4); EXPECT_NE(s.code(), error::OK); EXPECT_THAT(s.message(), ContainsRegex("Add")); EXPECT_THAT(s.message(), ContainsRegex("Mul")); @@ -435,7 +435,7 @@ TEST_F(DatasetHashUtilsTest, HashNodeReversedOrder) { uint64 hash2 = GetHash(gd, *n4); // We expect different hashes because the inputs of n3 are swapped. EXPECT_NE(hash1, hash2); - Status s = CheckSubgraphsEqual(gd, n3, gd, n4); + absl::Status s = CheckSubgraphsEqual(gd, n3, gd, n4); EXPECT_NE(s.code(), error::OK); EXPECT_THAT(s.message(), ContainsRegex("AttrValues are different")); } @@ -474,7 +474,7 @@ TEST_F(DatasetHashUtilsTest, HashNodeInputPortChanged) { // We expect different hashes because the input ports for nodes used by n3 // has changed. EXPECT_NE(hash1, hash2); - Status s = CheckSubgraphsEqual(gd, n3, gd, n4); + absl::Status s = CheckSubgraphsEqual(gd, n3, gd, n4); EXPECT_NE(s.code(), error::OK); EXPECT_THAT(s.message(), ContainsRegex("Node inputs")); } @@ -702,7 +702,7 @@ TEST_F(DatasetHashUtilsTest, HashNodeDifferentFunctionsOps) { uint64 hash1 = GetHash(gd, *n2); uint64 hash2 = GetHash(gd, *n3); EXPECT_NE(hash1, hash2); - Status s = CheckSubgraphsEqual(gd, n2, gd, n3); + absl::Status s = CheckSubgraphsEqual(gd, n2, gd, n3); EXPECT_NE(s.code(), error::OK); EXPECT_THAT( s.message(), @@ -773,7 +773,7 @@ TEST_F(DatasetHashUtilsTest, HashNodeDifferentFunctions) { uint64 hash1 = GetHash(gd, *n2); uint64 hash2 = GetHash(gd, *n3); EXPECT_NE(hash1, hash2); - Status s = CheckSubgraphsEqual(gd, n2, gd, n3); + absl::Status s = CheckSubgraphsEqual(gd, n2, gd, n3); EXPECT_NE(s.code(), error::OK); EXPECT_THAT( s.message(), @@ -846,7 +846,7 @@ TEST_F(DatasetHashUtilsTest, HashNodeDifferentFunctionLists) { uint64 hash1 = GetHash(gd, *n2); uint64 hash2 = GetHash(gd, *n3); EXPECT_NE(hash1, hash2); - Status s = CheckSubgraphsEqual(gd, n2, gd, n3); + absl::Status s = CheckSubgraphsEqual(gd, n2, gd, n3); EXPECT_NE(s.code(), error::OK); EXPECT_THAT( s.message(), @@ -892,7 +892,7 @@ TEST_F(DatasetHashUtilsTest, HashNodeDifferentControlInputs) { uint64 hash1 = GetHash(gd, *n4); uint64 hash2 = GetHash(gd, *n5); EXPECT_NE(hash1, hash2); - Status s = CheckSubgraphsEqual(gd, n4, gd, n5); + absl::Status s = CheckSubgraphsEqual(gd, n4, gd, n5); EXPECT_NE(s.code(), error::OK); EXPECT_THAT(s.message(), ContainsRegex("Control dependencies are different")); } diff --git a/tensorflow/core/data/rewrite_utils.cc b/tensorflow/core/data/rewrite_utils.cc index d1b4976ce80aad..dc13b751485b27 100644 --- a/tensorflow/core/data/rewrite_utils.cc +++ b/tensorflow/core/data/rewrite_utils.cc @@ -109,9 +109,10 @@ void RemoveFakeSinks(FunctionDef* function_def) { } } -Status ApplyRewrites(OpKernelContext* ctx, - const std::function config_factory, - GraphDef* graph_def, string* dataset_node) { +absl::Status ApplyRewrites( + OpKernelContext* ctx, + const std::function config_factory, + GraphDef* graph_def, string* dataset_node) { std::unique_ptr grappler_item = GetGrapplerItem(graph_def, dataset_node, /*add_fake_sinks=*/true); std::unordered_map device_map; @@ -166,10 +167,10 @@ RewriterConfig CreateRewriterConfig( return rewriter_config; } -Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input, - std::function config_factory, - bool record_fingerprint, - core::RefCountPtr* rewritten_input) { +absl::Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input, + std::function config_factory, + bool record_fingerprint, + core::RefCountPtr* rewritten_input) { std::vector> input_list; GraphDef graph_def; string output_node; @@ -224,7 +225,7 @@ Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input, return; } uint64 hash = 0; - Status s = HashNode(graph_def, *node_def, *lib_def, &hash); + absl::Status s = HashNode(graph_def, *node_def, *lib_def, &hash); if (!s.ok()) { VLOG(3) << "Failed to hash graph: " << s; return; @@ -232,7 +233,7 @@ Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input, for (const auto& pair : input_list) { hash = Hash64CombineUnordered(hash, Hash64(pair.first)); uint64 tensor_hash = 0; - Status s = HashTensor(pair.second, &tensor_hash); + absl::Status s = HashTensor(pair.second, &tensor_hash); if (s.ok()) { hash = Hash64CombineUnordered(hash, tensor_hash); } else { diff --git a/tensorflow/core/data/rewrite_utils.h b/tensorflow/core/data/rewrite_utils.h index 601023ef53fe74..b1701d8f3a82fc 100644 --- a/tensorflow/core/data/rewrite_utils.h +++ b/tensorflow/core/data/rewrite_utils.h @@ -47,10 +47,10 @@ RewriterConfig CreateRewriterConfig( // Rewrites the input dataset using the given config. The rewritten_input // stored in the core::RefCountPtr* output parameter is owned. -Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input, - std::function config_factory, - bool record_fingerprint, - core::RefCountPtr* rewritten_input); +absl::Status RewriteDataset(OpKernelContext* ctx, const DatasetBase* input, + std::function config_factory, + bool record_fingerprint, + core::RefCountPtr* rewritten_input); // Creates a grappler item for `graph_def`, which is required for graph // optimization. diff --git a/tensorflow/core/data/root_dataset.cc b/tensorflow/core/data/root_dataset.cc index 006174e7af28e8..a051cddfa8c51a 100644 --- a/tensorflow/core/data/root_dataset.cc +++ b/tensorflow/core/data/root_dataset.cc @@ -146,8 +146,8 @@ void AddTraceMetadata(const RootDataset::Params& params, const Options& options, } // namespace // static -Status RootDataset::FromOptions(const DatasetBase* input, - DatasetBase** output) { +absl::Status RootDataset::FromOptions(const DatasetBase* input, + DatasetBase** output) { Params params; SetRootDatasetParams(input->options(), ¶ms); *output = new RootDataset(input, params); @@ -158,8 +158,8 @@ Status RootDataset::FromOptions(const DatasetBase* input, return absl::OkStatus(); } -Status RootDataset::FromOptions(core::RefCountPtr input, - DatasetBase** output) { +absl::Status RootDataset::FromOptions(core::RefCountPtr input, + DatasetBase** output) { Params params; for (const auto& framework : input->options().framework_type()) { metrics::RecordTFDataFrameworkType(framework); @@ -194,7 +194,7 @@ class RootDataset::Iterator : public DatasetIterator { bool SymbolicCheckpointCompatible() const override { return true; } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { // prefetch_autotuner.h currently disregards `autotune` parameter // so no matter whether dataset()->params_.autotune is on or not // we need to pass ram_budget_manager_ to the downstream dataset operations @@ -230,8 +230,9 @@ class RootDataset::Iterator : public DatasetIterator { return absl::OkStatus(); } - Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { { tf_shared_lock l(mu_); if (model_ != nullptr && end_time_usec_ > 0) { @@ -258,14 +259,14 @@ class RootDataset::Iterator : public DatasetIterator { return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { IteratorContext iter_ctx(CreateParams(ctx)); TF_RETURN_IF_ERROR(RestoreInput(&iter_ctx, reader, input_impl_)); ctx->MergeCheckpoint(iter_ctx.checkpoint()); @@ -340,7 +341,7 @@ class RootDataset::Iterator : public DatasetIterator { return params; } - Status EnsureModelThreadStarted(IteratorContext* ctx) { + absl::Status EnsureModelThreadStarted(IteratorContext* ctx) { mutex_lock l(mu_); if (!model_thread_) { RunMode run_mode = ctx->run_mode(); @@ -354,7 +355,7 @@ class RootDataset::Iterator : public DatasetIterator { // Dynamic RAM budget should only apply to tf.data service. raw_ram_budget = params.ComputeInitialAutotuneRamBudget(); } - Status status = model_->OptimizeLoop( + absl::Status status = model_->OptimizeLoop( params.autotune_algorithm, params.autotune_cpu_budget_func, params.ram_budget_share, raw_ram_budget, *ram_budget_manager_, cancellation_manager_.get()); @@ -432,32 +433,32 @@ int64_t RootDataset::CardinalityInternal(CardinalityOptions options) const { return input_->Cardinality(options); } -Status RootDataset::Get(OpKernelContext* ctx, int64 index, - std::vector* out_tensors) const { +absl::Status RootDataset::Get(OpKernelContext* ctx, int64 index, + std::vector* out_tensors) const { std::vector inputs; TF_RETURN_IF_ERROR(this->InputDatasets(&inputs)); return inputs[0]->Get(ctx, index, out_tensors); } -Status RootDataset::InputDatasets( +absl::Status RootDataset::InputDatasets( std::vector* inputs) const { inputs->push_back(input_); return absl::OkStatus(); } -Status RootDataset::CheckExternalState() const { +absl::Status RootDataset::CheckExternalState() const { return input_->CheckExternalState(); } -Status RootDataset::AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const { +absl::Status RootDataset::AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const { return errors::Unimplemented("RootDataset does not support serialization."); } #if !defined(IS_MOBILE_PLATFORM) -Status FinalizeDataset(OpKernelContext* ctx, const DatasetBase* input, - DatasetBase** output) { +absl::Status FinalizeDataset(OpKernelContext* ctx, const DatasetBase* input, + DatasetBase** output) { const Options& options = input->options(); absl::flat_hash_set optimizations_enabled; absl::flat_hash_set optimizations_disabled; @@ -481,8 +482,9 @@ Status FinalizeDataset(OpKernelContext* ctx, const DatasetBase* input, return CreateRewriterConfig(optimizations, optimization_configs); }; core::RefCountPtr rewritten_output; - Status s = RewriteDataset(ctx, input, std::move(config_factory), - /*record_fingerprint=*/false, &rewritten_output); + absl::Status s = + RewriteDataset(ctx, input, std::move(config_factory), + /*record_fingerprint=*/false, &rewritten_output); *output = rewritten_output.get(); bool rewritten = (*output != input); diff --git a/tensorflow/core/data/root_dataset.h b/tensorflow/core/data/root_dataset.h index 870741ed9354b2..0734e9959a1855 100644 --- a/tensorflow/core/data/root_dataset.h +++ b/tensorflow/core/data/root_dataset.h @@ -50,9 +50,10 @@ class RootDataset : public DatasetBase { } }; - static Status FromOptions(const DatasetBase* input, DatasetBase** output); - static Status FromOptions(core::RefCountPtr input, - DatasetBase** output); + static absl::Status FromOptions(const DatasetBase* input, + DatasetBase** output); + static absl::Status FromOptions(core::RefCountPtr input, + DatasetBase** output); ~RootDataset() override; @@ -60,21 +61,22 @@ class RootDataset : public DatasetBase { const std::vector& output_shapes() const override; int64_t CardinalityInternal(CardinalityOptions options) const override; - Status Get(OpKernelContext* ctx, int64 index, - std::vector* out_tensors) const override; - Status CheckExternalState() const override; + absl::Status Get(OpKernelContext* ctx, int64 index, + std::vector* out_tensors) const override; + absl::Status CheckExternalState() const override; string DebugString() const override; - Status InputDatasets(std::vector* inputs) const override; + absl::Status InputDatasets( + std::vector* inputs) const override; std::unique_ptr MakeIteratorInternal( const string& prefix) const override; - Status RandomIndexingCompatible() const override { + absl::Status RandomIndexingCompatible() const override { return random_indexing_compatible_; } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override; + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override; private: class Iterator; @@ -87,7 +89,7 @@ class RootDataset : public DatasetBase { core::RefCountPtr owned_input_; const Params params_; TraceMeMetadata traceme_metadata_; - Status random_indexing_compatible_; + absl::Status random_indexing_compatible_; }; // Finalizes the `input` dataset, which is expected to be called before the @@ -95,8 +97,8 @@ class RootDataset : public DatasetBase { // optimizations or inject internal tf.data transformations responsible for // autotuning or threading configuration. The caller must ensure that the // input dataset to be finalized outlives the output. -Status FinalizeDataset(OpKernelContext* ctx, const DatasetBase* input, - DatasetBase** output); +absl::Status FinalizeDataset(OpKernelContext* ctx, const DatasetBase* input, + DatasetBase** output); } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/serialization_utils.cc b/tensorflow/core/data/serialization_utils.cc index fa4e5b42641c87..43e528a2e47cc7 100644 --- a/tensorflow/core/data/serialization_utils.cc +++ b/tensorflow/core/data/serialization_utils.cc @@ -47,9 +47,10 @@ constexpr char kIsDataset[] = ".is_dataset"; constexpr char kIteratorVariantTypeName[] = "tensorflow::Iterator"; constexpr char kOutputNode[] = ".output_node"; -Status FromGraphDef(FunctionLibraryRuntime* flr, const GraphDef& graph_def, - const std::vector>& input_list, - const string& output_node, Tensor* result) { +absl::Status FromGraphDef( + FunctionLibraryRuntime* flr, const GraphDef& graph_def, + const std::vector>& input_list, + const string& output_node, Tensor* result) { FunctionLibraryRuntime* cloned_flr = nullptr; std::unique_ptr pflr = nullptr; std::unique_ptr lib_def = nullptr; @@ -67,8 +68,8 @@ Status FromGraphDef(FunctionLibraryRuntime* flr, const GraphDef& graph_def, // FindStatefulOps searches `graph_def` for all of its stateful ops storing // their names in `stateful_op_names`. -Status FindStatefulOps(const GraphDef& graph_def, - std::vector* stateful_op_names) { +absl::Status FindStatefulOps(const GraphDef& graph_def, + std::vector* stateful_op_names) { FunctionLibraryDefinition lib_def(OpRegistry::Global(), graph_def.library()); // Iterate over all nodes in the graph. @@ -95,10 +96,9 @@ Status FindStatefulOps(const GraphDef& graph_def, } // namespace -Status ReadElementsFromCheckpoint(IteratorContext* ctx, - IteratorStateReader* reader, - StringPiece key_prefix, - std::vector>* elements) { +absl::Status ReadElementsFromCheckpoint( + IteratorContext* ctx, IteratorStateReader* reader, StringPiece key_prefix, + std::vector>* elements) { int64_t num_elements; TF_RETURN_IF_ERROR( reader->ReadScalar(key_prefix, kNumElements, &num_elements)); @@ -122,9 +122,9 @@ Status ReadElementsFromCheckpoint(IteratorContext* ctx, return absl::OkStatus(); } -Status WriteElement(IteratorStateWriter* writer, StringPiece key_prefix, - const std::vector>& elements, - int64_t index) { +absl::Status WriteElement(IteratorStateWriter* writer, StringPiece key_prefix, + const std::vector>& elements, + int64_t index) { const std::vector& element = elements[index]; std::string element_prefix = absl::StrCat(key_prefix, "::", index); TF_RETURN_IF_ERROR( @@ -136,7 +136,7 @@ Status WriteElement(IteratorStateWriter* writer, StringPiece key_prefix, return absl::OkStatus(); } -Status WriteElementsToCheckpoint( +absl::Status WriteElementsToCheckpoint( IteratorStateWriter* writer, StringPiece key_prefix, const std::vector>& elements) { TF_RETURN_IF_ERROR( @@ -147,7 +147,7 @@ Status WriteElementsToCheckpoint( return absl::OkStatus(); } -Status UpdateCheckpointElements( +absl::Status UpdateCheckpointElements( IteratorStateWriter* writer, StringPiece key_prefix, const std::vector>& elements, const absl::flat_hash_set& checkpoint_indices) { @@ -174,51 +174,57 @@ VariantTensorDataReader::VariantTensorDataReader( } } -Status VariantTensorDataReader::ReadScalar(StringPiece key, - int64_t* val) const { +absl::Status VariantTensorDataReader::ReadScalar(StringPiece key, + int64_t* val) const { string prefix; TF_RETURN_IF_ERROR(ExtractIteratorPrefix(key, &prefix)); return ReadScalar(prefix, key, val); } -Status VariantTensorDataReader::ReadScalar(StringPiece name, StringPiece key, - int64_t* val) const { +absl::Status VariantTensorDataReader::ReadScalar(StringPiece name, + StringPiece key, + int64_t* val) const { return ReadScalarInternal(name, key, val); } -Status VariantTensorDataReader::ReadScalar(StringPiece key, - tstring* val) const { +absl::Status VariantTensorDataReader::ReadScalar(StringPiece key, + tstring* val) const { string prefix; TF_RETURN_IF_ERROR(ExtractIteratorPrefix(key, &prefix)); return ReadScalar(prefix, key, val); } -Status VariantTensorDataReader::ReadScalar(StringPiece name, StringPiece key, - tstring* val) const { +absl::Status VariantTensorDataReader::ReadScalar(StringPiece name, + StringPiece key, + tstring* val) const { return ReadScalarInternal(name, key, val); } -Status VariantTensorDataReader::ReadTensor(StringPiece key, Tensor* val) const { +absl::Status VariantTensorDataReader::ReadTensor(StringPiece key, + Tensor* val) const { string prefix; TF_RETURN_IF_ERROR(ExtractIteratorPrefix(key, &prefix)); return ReadTensor(prefix, key, val); } -Status VariantTensorDataReader::ReadTensor(FunctionLibraryRuntime* flr, - StringPiece key, Tensor* val) const { +absl::Status VariantTensorDataReader::ReadTensor(FunctionLibraryRuntime* flr, + StringPiece key, + Tensor* val) const { string prefix; TF_RETURN_IF_ERROR(ExtractIteratorPrefix(key, &prefix)); return ReadTensorInternal(flr, prefix, key, val); } -Status VariantTensorDataReader::ReadTensor(StringPiece name, StringPiece key, - Tensor* val) const { +absl::Status VariantTensorDataReader::ReadTensor(StringPiece name, + StringPiece key, + Tensor* val) const { return ReadTensor(/*flr=*/nullptr, name, key, val); } -Status VariantTensorDataReader::ReadTensor(FunctionLibraryRuntime* flr, - StringPiece name, StringPiece key, - Tensor* val) const { +absl::Status VariantTensorDataReader::ReadTensor(FunctionLibraryRuntime* flr, + StringPiece name, + StringPiece key, + Tensor* val) const { return ReadTensorInternal(flr, name, key, val); } @@ -241,9 +247,9 @@ bool VariantTensorDataReader::Contains(StringPiece n, StringPiece key) const { } template -Status VariantTensorDataReader::ReadScalarInternal(StringPiece n, - StringPiece key, - T* val) const { +absl::Status VariantTensorDataReader::ReadScalarInternal(StringPiece n, + StringPiece key, + T* val) const { string name(n); auto it = map_.find(name); if (it == map_.end()) { @@ -258,10 +264,9 @@ Status VariantTensorDataReader::ReadScalarInternal(StringPiece n, return absl::OkStatus(); } -Status VariantTensorDataReader::ReadTensorInternal(FunctionLibraryRuntime* flr, - StringPiece n, - StringPiece key, - Tensor* val) const { +absl::Status VariantTensorDataReader::ReadTensorInternal( + FunctionLibraryRuntime* flr, StringPiece n, StringPiece key, + Tensor* val) const { if (Contains(n, strings::StrCat(key, kIsDataset))) { return ReadDatasetInternal(flr, n, key, val); } @@ -279,10 +284,9 @@ Status VariantTensorDataReader::ReadTensorInternal(FunctionLibraryRuntime* flr, return absl::OkStatus(); } -Status VariantTensorDataReader::ReadDatasetInternal(FunctionLibraryRuntime* flr, - StringPiece n, - StringPiece key, - Tensor* val) const { +absl::Status VariantTensorDataReader::ReadDatasetInternal( + FunctionLibraryRuntime* flr, StringPiece n, StringPiece key, + Tensor* val) const { if (flr == nullptr) { return errors::Internal( "Function library runtime is needed to restore a dataset."); @@ -312,39 +316,42 @@ std::map VariantTensorDataReader::ReadAllTensors() { return result; } -Status VariantTensorDataWriter::WriteScalar(StringPiece key, - const int64_t val) { +absl::Status VariantTensorDataWriter::WriteScalar(StringPiece key, + const int64_t val) { string prefix; TF_RETURN_IF_ERROR(ExtractIteratorPrefix(key, &prefix)); return WriteScalar(prefix, key, val); } -Status VariantTensorDataWriter::WriteScalar(StringPiece name, StringPiece key, - const int64_t val) { +absl::Status VariantTensorDataWriter::WriteScalar(StringPiece name, + StringPiece key, + const int64_t val) { return WriteScalarInternal(name, key, val); } -Status VariantTensorDataWriter::WriteScalar(StringPiece key, - const tstring& val) { +absl::Status VariantTensorDataWriter::WriteScalar(StringPiece key, + const tstring& val) { string prefix; TF_RETURN_IF_ERROR(ExtractIteratorPrefix(key, &prefix)); return WriteScalar(prefix, key, val); } -Status VariantTensorDataWriter::WriteScalar(StringPiece name, StringPiece key, - const tstring& val) { +absl::Status VariantTensorDataWriter::WriteScalar(StringPiece name, + StringPiece key, + const tstring& val) { return WriteScalarInternal(name, key, val); } -Status VariantTensorDataWriter::WriteTensor(StringPiece key, - const Tensor& val) { +absl::Status VariantTensorDataWriter::WriteTensor(StringPiece key, + const Tensor& val) { string prefix; TF_RETURN_IF_ERROR(ExtractIteratorPrefix(key, &prefix)); return WriteTensor(prefix, key, val); } -Status VariantTensorDataWriter::WriteTensor(StringPiece name, StringPiece key, - const Tensor& val) { +absl::Status VariantTensorDataWriter::WriteTensor(StringPiece name, + StringPiece key, + const Tensor& val) { return WriteTensorInternal(name, key, val); } @@ -385,9 +392,9 @@ void VariantTensorDataWriter::GetData( } template -Status VariantTensorDataWriter::WriteScalarInternal(StringPiece name, - StringPiece key, - const T& val) { +absl::Status VariantTensorDataWriter::WriteScalarInternal(StringPiece name, + StringPiece key, + const T& val) { if (is_flushed_) { return errors::FailedPrecondition( "Cannot call WriteScalar after GetData or ReleaseData is called"); @@ -397,9 +404,9 @@ Status VariantTensorDataWriter::WriteScalarInternal(StringPiece name, return WriteTensorInternal(name, key, val_t); } -Status VariantTensorDataWriter::WriteTensorInternal(StringPiece n, - StringPiece key, - const Tensor& val) { +absl::Status VariantTensorDataWriter::WriteTensorInternal(StringPiece n, + StringPiece key, + const Tensor& val) { DatasetBase* dataset; if (GetDatasetFromVariantTensor(val, &dataset).ok()) { return WriteDatasetInternal(n, key, dataset); @@ -422,7 +429,7 @@ Status VariantTensorDataWriter::WriteTensorInternal(StringPiece n, return absl::OkStatus(); } -Status VariantTensorDataWriter::WriteDatasetInternal( +absl::Status VariantTensorDataWriter::WriteDatasetInternal( StringPiece n, StringPiece key, const DatasetBase* dataset) { GraphDef graph_def; SerializationContext ctx((SerializationContext::Params())); @@ -453,7 +460,7 @@ IteratorStateVariant::IteratorStateVariant(const IteratorStateVariant& other) { } } -Status IteratorStateVariant::InitializeFromVariantData( +absl::Status IteratorStateVariant::InitializeFromVariantData( std::unique_ptr data) { data_ = std::move(data); return absl::OkStatus(); @@ -461,7 +468,7 @@ Status IteratorStateVariant::InitializeFromVariantData( void IteratorStateVariant::Encode(VariantTensorData* data) const { CompressedElement compressed_tensors; - Status s = CompressElement(data_->tensors(), &compressed_tensors); + absl::Status s = CompressElement(data_->tensors(), &compressed_tensors); if (!s.ok()) { LOG(WARNING) << "Failed to compress iterator state variant: " << s; *data = *data_; @@ -487,7 +494,7 @@ bool IteratorStateVariant::Decode(VariantTensorData data) { } std::vector tensors; - Status s = UncompressElement(*compressed, &tensors); + absl::Status s = UncompressElement(*compressed, &tensors); if (!s.ok()) { LOG(WARNING) << "Failed to uncompress iterator state variant: " << s; data_ = std::make_unique(std::move(data)); @@ -533,9 +540,10 @@ std::string IteratorStateVariant::DebugString() const { REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant, kIteratorVariantTypeName); -Status AsGraphDefForRewrite(OpKernelContext* ctx, const DatasetBase* input, - std::vector>* input_list, - GraphDef* result, string* dataset_node) { +absl::Status AsGraphDefForRewrite( + OpKernelContext* ctx, const DatasetBase* input, + std::vector>* input_list, GraphDef* result, + string* dataset_node) { SerializationContext::Params params(ctx); params.input_list = input_list; params.external_state_policy = ExternalStatePolicy::POLICY_IGNORE; @@ -552,9 +560,9 @@ Status AsGraphDefForRewrite(OpKernelContext* ctx, const DatasetBase* input, return absl::OkStatus(); } -Status AsGraphDef(const DatasetBase* dataset, - SerializationContext&& serialization_ctx, - GraphDef* graph_def) { +absl::Status AsGraphDef(const DatasetBase* dataset, + SerializationContext&& serialization_ctx, + GraphDef* graph_def) { if (serialization_ctx.external_state_policy() == ExternalStatePolicy::POLICY_FAIL) { TF_RETURN_IF_ERROR(dataset->CheckExternalState()); diff --git a/tensorflow/core/data/serialization_utils.h b/tensorflow/core/data/serialization_utils.h index 5f867efefdd10f..5eec6305b37c9b 100644 --- a/tensorflow/core/data/serialization_utils.h +++ b/tensorflow/core/data/serialization_utils.h @@ -37,16 +37,15 @@ namespace data { inline constexpr absl::string_view kRetvalOp = "_Retval"; // Reads dataset elements from the checkpoint reader using the given key prefix. -Status ReadElementsFromCheckpoint(IteratorContext* ctx, - IteratorStateReader* reader, - StringPiece key_prefix, - std::vector>* elements); +absl::Status ReadElementsFromCheckpoint( + IteratorContext* ctx, IteratorStateReader* reader, StringPiece key_prefix, + std::vector>* elements); // Writes dataset elements to the checkpoint writer using the given key prefix. // The elements can be read back by passing the same key prefix to // ReadElementsFromCheckpoint. Only one list of elements can be written under // the same key_prefix. -Status WriteElementsToCheckpoint( +absl::Status WriteElementsToCheckpoint( IteratorStateWriter* writer, StringPiece key_prefix, const std::vector>& elements); @@ -54,7 +53,7 @@ Status WriteElementsToCheckpoint( // using the given key prefix, assuming that vector of elements have // checkpointed these before. The elements can be read back by passing the same // key prefix to ReadElementsFromCheckpoint. -Status UpdateCheckpointElements( +absl::Status UpdateCheckpointElements( IteratorStateWriter* writer, StringPiece key_prefix, const std::vector>& elements, const absl::flat_hash_set& checkpoint_indices); @@ -68,27 +67,29 @@ class VariantTensorDataReader : public IteratorStateReader { bool Contains(StringPiece key) const override; bool Contains(StringPiece name, StringPiece key) const override; - Status ReadScalar(StringPiece key, int64_t* val) const override; - Status ReadScalar(StringPiece name, StringPiece key, - int64_t* val) const override; - Status ReadScalar(StringPiece key, tstring* val) const override; - Status ReadScalar(StringPiece name, StringPiece key, - tstring* val) const override; - Status ReadTensor(StringPiece key, Tensor* val) const override; - Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece key, - Tensor* val) const override; - Status ReadTensor(StringPiece name, StringPiece key, - Tensor* val) const override; - Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece name, - StringPiece key, Tensor* val) const override; + absl::Status ReadScalar(StringPiece key, int64_t* val) const override; + absl::Status ReadScalar(StringPiece name, StringPiece key, + int64_t* val) const override; + absl::Status ReadScalar(StringPiece key, tstring* val) const override; + absl::Status ReadScalar(StringPiece name, StringPiece key, + tstring* val) const override; + absl::Status ReadTensor(StringPiece key, Tensor* val) const override; + absl::Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece key, + Tensor* val) const override; + absl::Status ReadTensor(StringPiece name, StringPiece key, + Tensor* val) const override; + absl::Status ReadTensor(FunctionLibraryRuntime* flr, StringPiece name, + StringPiece key, Tensor* val) const override; private: template - Status ReadScalarInternal(StringPiece name, StringPiece key, T* val) const; - Status ReadTensorInternal(FunctionLibraryRuntime* flr, StringPiece name, - StringPiece key, Tensor* val) const; - Status ReadDatasetInternal(FunctionLibraryRuntime* flr, StringPiece name, - StringPiece key, Tensor* val) const; + absl::Status ReadScalarInternal(StringPiece name, StringPiece key, + T* val) const; + absl::Status ReadTensorInternal(FunctionLibraryRuntime* flr, StringPiece name, + StringPiece key, Tensor* val) const; + absl::Status ReadDatasetInternal(FunctionLibraryRuntime* flr, + StringPiece name, StringPiece key, + Tensor* val) const; // Produces all key/value pairs stored in this reader. Useful for debugging. std::map ReadAllTensors(); @@ -112,16 +113,17 @@ class VariantTensorDataReader : public IteratorStateReader { // Now the VariantTensorData objects can be used to serialize. class VariantTensorDataWriter : public IteratorStateWriter { public: - Status WriteScalar(StringPiece key, int64_t val) override; - Status WriteScalar(StringPiece name, StringPiece key, int64_t val) override; + absl::Status WriteScalar(StringPiece key, int64_t val) override; + absl::Status WriteScalar(StringPiece name, StringPiece key, + int64_t val) override; - Status WriteScalar(StringPiece key, const tstring& val) override; - Status WriteScalar(StringPiece name, StringPiece key, - const tstring& val) override; + absl::Status WriteScalar(StringPiece key, const tstring& val) override; + absl::Status WriteScalar(StringPiece name, StringPiece key, + const tstring& val) override; - Status WriteTensor(StringPiece key, const Tensor& val) override; - Status WriteTensor(StringPiece name, StringPiece key, - const Tensor& val) override; + absl::Status WriteTensor(StringPiece key, const Tensor& val) override; + absl::Status WriteTensor(StringPiece name, StringPiece key, + const Tensor& val) override; // Releases the built VariantTensorData's to `variants`. Clears out all // class state. @@ -135,11 +137,12 @@ class VariantTensorDataWriter : public IteratorStateWriter { void Reset(); template - Status WriteScalarInternal(StringPiece name, StringPiece key, const T& val); - Status WriteTensorInternal(StringPiece name, StringPiece key, - const Tensor& val); - Status WriteDatasetInternal(StringPiece name, StringPiece key, - const DatasetBase* dataset); + absl::Status WriteScalarInternal(StringPiece name, StringPiece key, + const T& val); + absl::Status WriteTensorInternal(StringPiece name, StringPiece key, + const Tensor& val); + absl::Status WriteDatasetInternal(StringPiece name, StringPiece key, + const DatasetBase* dataset); bool is_flushed_ = false; std::map> data_; @@ -180,7 +183,8 @@ class IteratorStateVariant { static std::string TypeName(); // Initializes `this` from a VariantTensorData object. - Status InitializeFromVariantData(std::unique_ptr data); + absl::Status InitializeFromVariantData( + std::unique_ptr data); // Returns a borrowed pointer to the underlying VariantTensorData. const VariantTensorData* GetData() const { return data_.get(); } @@ -207,9 +211,9 @@ class IteratorStateVariant { }; // Returns a GraphDef representation of the given dataset. -Status AsGraphDef(const DatasetBase* dataset, - SerializationContext&& serialization_ctx, - GraphDef* graph_def); +absl::Status AsGraphDef(const DatasetBase* dataset, + SerializationContext&& serialization_ctx, + GraphDef* graph_def); // Returns a GraphDef representation of the given dataset suitable for // optimization rewrites. It sets serialization parameters to export a minimum @@ -217,9 +221,10 @@ Status AsGraphDef(const DatasetBase* dataset, // state, not serializing data tensors, not failing if there are datasets which // do not have AsGraphDef implemented). Sets the `dataset_node` parameter to the // dataset's node name in the resulting GraphDef. -Status AsGraphDefForRewrite(OpKernelContext* ctx, const DatasetBase* input, - std::vector>* input_list, - GraphDef* result, string* dataset_node); +absl::Status AsGraphDefForRewrite( + OpKernelContext* ctx, const DatasetBase* input, + std::vector>* input_list, GraphDef* result, + string* dataset_node); // Analyzes the bytes of a tf.data iterator checkpoint to identify all of the // keys in the checkpoint along with their sizes in bytes. diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index 188418148f2fc3..8a76428a848dde 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -409,7 +409,7 @@ tf_cc_test( "@com_google_absl//absl/status", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/protobuf:protos_all_cc", + "@local_xla//xla/tsl/protobuf:protos_all_cc", ] + tf_grpc_cc_dependencies() + tf_protos_profiler_service(), ) @@ -1135,7 +1135,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@local_tsl//tsl/platform:status_to_from_proto", - "@local_tsl//tsl/protobuf:protos_all_cc", + "@local_xla//xla/tsl/protobuf:protos_all_cc", ] + tf_grpc_cc_dependencies(), ) diff --git a/tensorflow/core/data/service/client/BUILD b/tensorflow/core/data/service/client/BUILD index 60e2da30b5a8ac..59c323f7601384 100644 --- a/tensorflow/core/data/service/client/BUILD +++ b/tensorflow/core/data/service/client/BUILD @@ -58,7 +58,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@local_tsl//tsl/platform:retrying_utils", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -102,7 +102,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/protobuf:protos_all_cc", + "@local_xla//xla/tsl/protobuf:protos_all_cc", ], ) @@ -121,8 +121,8 @@ tf_cc_test( "//tensorflow/core/data/service:test_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/protobuf:protos_all_cc", "@local_xla//xla/tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/protobuf:protos_all_cc", ] + tf_grpc_cc_dependencies() + tf_protos_profiler_service(), ) diff --git a/tensorflow/core/data/service/client/data_service_client.cc b/tensorflow/core/data/service/client/data_service_client.cc index a323a02b6096bc..a7944af056b863 100644 --- a/tensorflow/core/data/service/client/data_service_client.cc +++ b/tensorflow/core/data/service/client/data_service_client.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/strings/ascii.h" #include "absl/strings/substitute.h" #include "absl/time/time.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/data/service/client/common.h" #include "tensorflow/core/data/service/client/validate_utils.h" #include "tensorflow/core/data/service/common.h" @@ -54,7 +55,6 @@ limitations under the License. #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/profiler/lib/traceme_encode.h" #include "tsl/platform/retrying_utils.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace data { @@ -89,7 +89,7 @@ DataServiceClient::~DataServiceClient() { << iteration_client_id_; task_thread_manager_.reset(); if (initialized_) { - Status s = dispatcher_->ReleaseIterationClient(iteration_client_id_); + absl::Status s = dispatcher_->ReleaseIterationClient(iteration_client_id_); if (!s.ok()) { LOG(WARNING) << "Failed to release iteration client id: " << s; } @@ -102,7 +102,7 @@ DataServiceClient::~DataServiceClient() { << iteration_client_id_; } -Status DataServiceClient::Initialize( +absl::Status DataServiceClient::Initialize( const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info, Allocator* allocator) { accelerator_device_info_ = accelerator_device_info; @@ -415,7 +415,7 @@ DataServiceClient::CreateWorkerClient(const TaskInfo& task_info) { return CreateGrpcWorkerClient(task_info); } -Status DataServiceClient::AddTask(const TaskInfo& task_info) +absl::Status DataServiceClient::AddTask(const TaskInfo& task_info) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { TF_ASSIGN_OR_RETURN(std::unique_ptr worker, CreateWorkerClient(task_info)); @@ -458,7 +458,7 @@ void DataServiceClient::Heartbeat() TF_LOCKS_EXCLUDED(mu_) { req.set_target_processing_time_nsec(target_processing_time_nsec); } ClientHeartbeatResponse resp; - Status s = dispatcher_->ClientHeartbeat(req, resp); + absl::Status s = dispatcher_->ClientHeartbeat(req, resp); if (!s.ok()) { if (IsPreemptedError(s)) { LOG(WARNING) @@ -526,7 +526,7 @@ void DataServiceClient::UpdateTasks(const ClientHeartbeatResponse& resp) should_finish_iteration_ = false; continue; } - Status s = AddTask(it->second); + absl::Status s = AddTask(it->second); if (!s.ok()) { status_ = s; get_next_cv_.notify_all(); @@ -655,9 +655,9 @@ void DataServiceClient::RunWorkerThread(std::function done) VLOG(3) << "Processing task " << task_to_process->info.task_id(); } int64_t deadline_micros = kint64max; - Status s = GetElementTraced(task_to_process.get(), deadline_micros, - /*enqueue_result=*/!IsCoordinatedRead(), - allow_skip, result); + absl::Status s = GetElementTraced(task_to_process.get(), deadline_micros, + /*enqueue_result=*/!IsCoordinatedRead(), + allow_skip, result); if (!s.ok()) { mutex_lock l(mu_); VLOG(1) << "Failed to get element from worker " @@ -754,8 +754,8 @@ void DataServiceClient::AdvanceTaskIndex() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { } } -Status DataServiceClient::TryGetElement(const Task& task, bool allow_skip, - GetElementResult& result) { +absl::Status DataServiceClient::TryGetElement(const Task& task, bool allow_skip, + GetElementResult& result) { GetElementRequest req; req.set_task_id(task.info.task_id()); req.set_skipped_previous_round(task.skipped_previous_round); @@ -797,10 +797,9 @@ void DataServiceClient::ProcessGetElementResponse( get_next_cv_.notify_all(); } -Status DataServiceClient::GetElementTraced(Task* task, int64_t deadline_micros, - bool enqueue_result, bool allow_skip, - std::shared_ptr result) - TF_LOCKS_EXCLUDED(mu_) { +absl::Status DataServiceClient::GetElementTraced( + Task* task, int64_t deadline_micros, bool enqueue_result, bool allow_skip, + std::shared_ptr result) TF_LOCKS_EXCLUDED(mu_) { VLOG(3) << "Getting an element for task id " << task->info.task_id(); tsl::profiler::TraceMe activity("GetDataServiceElement", tsl::profiler::TraceMeLevel::kInfo); @@ -817,15 +816,16 @@ Status DataServiceClient::GetElementTraced(Task* task, int64_t deadline_micros, {"round_index", task->round}}); }); } - Status s = + absl::Status s = GetElement(task, deadline_micros, enqueue_result, allow_skip, result); mutex_lock l(mu_); VLOG(3) << "Got an element for task id " << task->info.task_id(); return s; } -Status DataServiceClient::MaybeRemoveTask(Task& task, int64_t deadline_micros, - Result& result) +absl::Status DataServiceClient::MaybeRemoveTask(Task& task, + int64_t deadline_micros, + Result& result) TF_LOCKS_EXCLUDED(mu_) { bool removed; VLOG(1) << "Requesting task removal for worker " << task.info.worker_address() @@ -854,13 +854,13 @@ Status DataServiceClient::MaybeRemoveTask(Task& task, int64_t deadline_micros, return absl::OkStatus(); } -Status DataServiceClient::GetElement(Task* task, int64_t deadline_micros, - bool enqueue_result, bool allow_skip, - std::shared_ptr result) +absl::Status DataServiceClient::GetElement(Task* task, int64_t deadline_micros, + bool enqueue_result, bool allow_skip, + std::shared_ptr result) TF_LOCKS_EXCLUDED(mu_) { GetElementResult get_element_result; while (true) { - Status s = TryGetElement(*task, allow_skip, get_element_result); + absl::Status s = TryGetElement(*task, allow_skip, get_element_result); if (s.ok()) { task->num_retries = 0; break; diff --git a/tensorflow/core/data/service/client/data_service_client.h b/tensorflow/core/data/service/client/data_service_client.h index 0faa2a9ee19be3..461e74256cb6d9 100644 --- a/tensorflow/core/data/service/client/data_service_client.h +++ b/tensorflow/core/data/service/client/data_service_client.h @@ -80,7 +80,7 @@ class DataServiceClient { DataServiceClient& operator=(const DataServiceClient&) = delete; // Initializes the client. - Status Initialize( + absl::Status Initialize( const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info, Allocator* allocator); @@ -152,7 +152,7 @@ class DataServiceClient { void TaskThreadManager(); void TryBlockRound(int64_t round) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); void UpdateIterationFinished(bool iteration_finished); - Status AddTask(const TaskInfo& task_info); + absl::Status AddTask(const TaskInfo& task_info); absl::StatusOr> CreateWorkerClient( const TaskInfo& task_info); absl::StatusOr> CreateWorkerClient( @@ -176,17 +176,19 @@ class DataServiceClient { // task a chance to proceed. std::shared_ptr GetTaskToProcess(); void AdvanceTaskIndex(); - Status TryGetElement(const Task& task, bool allow_skip, - GetElementResult& result); + absl::Status TryGetElement(const Task& task, bool allow_skip, + GetElementResult& result); void ProcessGetElementResponse(bool enqueue_result, GetElementResult& get_element_result, std::shared_ptr result, Task& task); - Status GetElementTraced(Task* task, int64_t deadline_micros, + absl::Status GetElementTraced(Task* task, int64_t deadline_micros, + bool enqueue_result, bool allow_skip, + std::shared_ptr result); + absl::Status MaybeRemoveTask(Task& task, int64_t deadline_micros, + Result& result); + absl::Status GetElement(Task* task, int64_t deadline_micros, bool enqueue_result, bool allow_skip, std::shared_ptr result); - Status MaybeRemoveTask(Task& task, int64_t deadline_micros, Result& result); - Status GetElement(Task* task, int64_t deadline_micros, bool enqueue_result, - bool allow_skip, std::shared_ptr result); bool ResultReady() const; std::shared_ptr PopNextResult(); bool IsCoordinatedRead() const; @@ -233,7 +235,7 @@ class DataServiceClient { // A status to be returned from the next call to `GetNext`. This is set by // asynchronous threads when they encounter errors. - Status status_ TF_GUARDED_BY(mu_) = absl::OkStatus(); + absl::Status status_ TF_GUARDED_BY(mu_) = absl::OkStatus(); // A queue of results for `GetElement` requests to read from. When doing // strict round robin reads, the queue will contain placeholder results with // their `Result::ready` field false until their data has been retrieved diff --git a/tensorflow/core/data/service/client/utils.cc b/tensorflow/core/data/service/client/utils.cc index 1f5f7f8d4041e9..66a3f6d945897c 100644 --- a/tensorflow/core/data/service/client/utils.cc +++ b/tensorflow/core/data/service/client/utils.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/substitute.h" #include "absl/time/time.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/data/service/dispatcher.pb.h" #include "tensorflow/core/data/service/dispatcher_client.h" #include "tensorflow/core/data/service/grpc_util.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/protobuf/data_service.pb.h" #include "tsl/platform/errors.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace data { @@ -47,7 +47,7 @@ absl::StatusOr GetDataServiceMetadata( absl::Time deadline = absl::FromUnixMicros(EnvTime::NowMicros()) + kGetMetadataRetryTimeout; - Status status = grpc_util::Retry( + absl::Status status = grpc_util::Retry( [&]() { return client.GetDataServiceMetadata(dataset_id, metadata); }, absl::Substitute("Get data service metadata for dataset $0, " "with dispatcher at $1.", diff --git a/tensorflow/core/data/service/client/utils_test.cc b/tensorflow/core/data/service/client/utils_test.cc index 8729bff56cbcb2..aa7018f3368892 100644 --- a/tensorflow/core/data/service/client/utils_test.cc +++ b/tensorflow/core/data/service/client/utils_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/data/service/dispatcher_client.h" #include "tensorflow/core/data/service/test_cluster.h" #include "tensorflow/core/data/service/test_util.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/core/protobuf/data_service.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/status_matchers.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/data/service/client/validate_utils.cc b/tensorflow/core/data/service/client/validate_utils.cc index b4769ce1282eee..34e5535068e5cb 100644 --- a/tensorflow/core/data/service/client/validate_utils.cc +++ b/tensorflow/core/data/service/client/validate_utils.cc @@ -28,7 +28,8 @@ namespace data { namespace { // Validates local worker related parameters. -Status ValidateLocalWorkers(const DataServiceParams& data_service_params) { +absl::Status ValidateLocalWorkers( + const DataServiceParams& data_service_params) { if (data_service_params.target_workers != TARGET_WORKERS_LOCAL) { return absl::OkStatus(); } @@ -58,7 +59,8 @@ Status ValidateLocalWorkers(const DataServiceParams& data_service_params) { } // Validates cross-trainer cache related parameters. -Status ValidateCrossTrainerCache(const DataServiceParams& data_service_params) { +absl::Status ValidateCrossTrainerCache( + const DataServiceParams& data_service_params) { if (!data_service_params.cross_trainer_cache_options.has_value()) { return absl::OkStatus(); } @@ -88,7 +90,8 @@ Status ValidateCrossTrainerCache(const DataServiceParams& data_service_params) { } } // namespace -Status ValidateDataServiceParams(const DataServiceParams& data_service_params) { +absl::Status ValidateDataServiceParams( + const DataServiceParams& data_service_params) { TF_RETURN_IF_ERROR(ValidateLocalWorkers(data_service_params)); TF_RETURN_IF_ERROR(ValidateCrossTrainerCache(data_service_params)); return absl::OkStatus(); diff --git a/tensorflow/core/data/service/client/validate_utils.h b/tensorflow/core/data/service/client/validate_utils.h index b0eb370c483201..16091a072efd1e 100644 --- a/tensorflow/core/data/service/client/validate_utils.h +++ b/tensorflow/core/data/service/client/validate_utils.h @@ -22,7 +22,8 @@ namespace tensorflow { namespace data { // Validates data service dataset parameters. -Status ValidateDataServiceParams(const DataServiceParams& data_service_params); +absl::Status ValidateDataServiceParams( + const DataServiceParams& data_service_params); } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/service/common.cc b/tensorflow/core/data/service/common.cc index cf0d352baf5a81..adde241b38634c 100644 --- a/tensorflow/core/data/service/common.cc +++ b/tensorflow/core/data/service/common.cc @@ -52,7 +52,7 @@ bool IsStaticShard(const ProcessingModeDef& processing_mode) { processing_mode.sharding_policy() == ProcessingModeDef::HINT; } -Status ValidateProcessingMode(const ProcessingModeDef& processing_mode) { +absl::Status ValidateProcessingMode(const ProcessingModeDef& processing_mode) { if (!IsNoShard(processing_mode) && !IsDynamicShard(processing_mode) && !IsStaticShard(processing_mode)) { return errors::Internal( @@ -131,7 +131,7 @@ absl::StatusOr ParseDeploymentMode(absl::string_view s) { "COLOCATED, REMOTE, and HYBRID."); } -bool IsPreemptedError(const Status& status) { +bool IsPreemptedError(const absl::Status& status) { return errors::IsAborted(status) || errors::IsCancelled(status) || errors::IsUnavailable(status); } diff --git a/tensorflow/core/data/service/common.h b/tensorflow/core/data/service/common.h index 550cffeb7b9558..e9760e56f82845 100644 --- a/tensorflow/core/data/service/common.h +++ b/tensorflow/core/data/service/common.h @@ -67,7 +67,7 @@ bool IsDynamicShard(const ProcessingModeDef& processing_mode); bool IsStaticShard(const ProcessingModeDef& processing_mode); // Returns an internal error if `processing_mode` is invalid. -Status ValidateProcessingMode(const ProcessingModeDef& processing_mode); +absl::Status ValidateProcessingMode(const ProcessingModeDef& processing_mode); // Converts tf.data service `sharding_policy` to `AutoShardPolicy`. Returns an // internal error if `sharding_policy` is not supported. @@ -86,7 +86,7 @@ std::string TargetWorkersToString(TargetWorkers target_workers); absl::StatusOr ParseDeploymentMode(absl::string_view s); // Returns true if `status` is a retriable error that indicates preemption. -bool IsPreemptedError(const Status& status); +bool IsPreemptedError(const absl::Status& status); // Base class for data service clients. Data service clients are // threadsafe. @@ -104,11 +104,11 @@ class DataServiceClientBase { // first RPC will perform any necessary initialization. However, it can be // useful to call `Initialize()` proactively so that any errors that happen // during initialization can be surfaced earlier. - virtual Status Initialize() { return EnsureInitialized(); } + virtual absl::Status Initialize() { return EnsureInitialized(); } protected: // Initializes the client if it isn't already initialized. - virtual Status EnsureInitialized() = 0; + virtual absl::Status EnsureInitialized() = 0; const std::string address_; const std::string protocol_; diff --git a/tensorflow/core/data/service/common.proto b/tensorflow/core/data/service/common.proto index f0d9fd70a9e8e0..9d2825082efed1 100644 --- a/tensorflow/core/data/service/common.proto +++ b/tensorflow/core/data/service/common.proto @@ -2,10 +2,10 @@ syntax = "proto3"; package tensorflow.data; +import "xla/tsl/protobuf/status.proto"; import "tensorflow/core/framework/graph.proto"; import "tensorflow/core/protobuf/data_service.proto"; import "tensorflow/core/protobuf/snapshot.proto"; -import "tsl/protobuf/status.proto"; // Next tag: 2 message DatasetDef { diff --git a/tensorflow/core/data/service/credentials_factory.cc b/tensorflow/core/data/service/credentials_factory.cc index 9367296e80bbdf..721ce5b806e7af 100644 --- a/tensorflow/core/data/service/credentials_factory.cc +++ b/tensorflow/core/data/service/credentials_factory.cc @@ -49,8 +49,8 @@ void CredentialsFactory::Register(CredentialsFactory* factory) { } } -Status CredentialsFactory::Get(absl::string_view protocol, - CredentialsFactory** out) { +absl::Status CredentialsFactory::Get(absl::string_view protocol, + CredentialsFactory** out) { mutex_lock l(*get_lock()); auto it = credentials_factories().find(std::string(protocol)); if (it != credentials_factories().end()) { @@ -69,7 +69,7 @@ Status CredentialsFactory::Get(absl::string_view protocol, absl::StrJoin(available_types, ", "), " ]"); } -Status CredentialsFactory::CreateServerCredentials( +absl::Status CredentialsFactory::CreateServerCredentials( absl::string_view protocol, std::shared_ptr<::grpc::ServerCredentials>* out) { CredentialsFactory* factory; @@ -78,7 +78,7 @@ Status CredentialsFactory::CreateServerCredentials( return absl::OkStatus(); } -Status CredentialsFactory::CreateClientCredentials( +absl::Status CredentialsFactory::CreateClientCredentials( absl::string_view protocol, std::shared_ptr<::grpc::ChannelCredentials>* out) { CredentialsFactory* factory; @@ -97,13 +97,13 @@ class InsecureCredentialsFactory : public CredentialsFactory { public: std::string Protocol() override { return "grpc"; } - Status CreateServerCredentials( + absl::Status CreateServerCredentials( std::shared_ptr<::grpc::ServerCredentials>* out) override { *out = ::grpc::InsecureServerCredentials(); return absl::OkStatus(); } - Status CreateClientCredentials( + absl::Status CreateClientCredentials( std::shared_ptr<::grpc::ChannelCredentials>* out) override { *out = ::grpc::InsecureChannelCredentials(); return absl::OkStatus(); diff --git a/tensorflow/core/data/service/credentials_factory.h b/tensorflow/core/data/service/credentials_factory.h index 26ca43fafa38da..d6a3bff54d86e3 100644 --- a/tensorflow/core/data/service/credentials_factory.h +++ b/tensorflow/core/data/service/credentials_factory.h @@ -38,11 +38,11 @@ class CredentialsFactory { virtual std::string Protocol() = 0; // Stores server credentials to `*out`. - virtual Status CreateServerCredentials( + virtual absl::Status CreateServerCredentials( std::shared_ptr<::grpc::ServerCredentials>* out) = 0; // Stores client credentials to `*out`. - virtual Status CreateClientCredentials( + virtual absl::Status CreateClientCredentials( std::shared_ptr<::grpc::ChannelCredentials>* out) = 0; // Registers a credentials factory. @@ -50,13 +50,13 @@ class CredentialsFactory { // Creates server credentials using the credentials factory registered as // `protocol`, and stores them to `*out`. - static Status CreateServerCredentials( + static absl::Status CreateServerCredentials( absl::string_view protocol, std::shared_ptr<::grpc::ServerCredentials>* out); // Creates client credentials using the credentials factory registered as // `protocol`, and stores them to `*out`. - static Status CreateClientCredentials( + static absl::Status CreateClientCredentials( absl::string_view protocol, std::shared_ptr<::grpc::ChannelCredentials>* out); @@ -67,7 +67,8 @@ class CredentialsFactory { private: // Gets the credentials factory registered via `Register` for the specified // protocol, and stores it to `*out`. - static Status Get(const absl::string_view protocol, CredentialsFactory** out); + static absl::Status Get(const absl::string_view protocol, + CredentialsFactory** out); }; } // namespace data diff --git a/tensorflow/core/data/service/credentials_factory_test.cc b/tensorflow/core/data/service/credentials_factory_test.cc index 88a4ef086871cd..fdfbee21b990e2 100644 --- a/tensorflow/core/data/service/credentials_factory_test.cc +++ b/tensorflow/core/data/service/credentials_factory_test.cc @@ -34,12 +34,12 @@ class TestCredentialsFactory : public CredentialsFactory { public: std::string Protocol() override { return "test"; } - Status CreateServerCredentials( + absl::Status CreateServerCredentials( std::shared_ptr* out) override { return errors::Internal(kFailedToCreateServerCredentials); } - Status CreateClientCredentials( + absl::Status CreateClientCredentials( std::shared_ptr* out) override { return errors::Internal(kFailedToCreateClientCredentials); } @@ -70,8 +70,8 @@ TEST(CredentialsFactory, DefaultGrpcProtocol) { TEST(CredentialsFactory, MissingServerProtocol) { std::shared_ptr server_credentials; - Status s = CredentialsFactory::CreateServerCredentials("unknown_protocol", - &server_credentials); + absl::Status s = CredentialsFactory::CreateServerCredentials( + "unknown_protocol", &server_credentials); ASSERT_EQ(error::Code::NOT_FOUND, s.code()); ASSERT_TRUE( absl::StrContains(s.ToString(), @@ -81,8 +81,8 @@ TEST(CredentialsFactory, MissingServerProtocol) { TEST(CredentialsFactory, MissingClientProtocol) { std::shared_ptr client_credentials; - Status s = CredentialsFactory::CreateClientCredentials("unknown_protocol", - &client_credentials); + absl::Status s = CredentialsFactory::CreateClientCredentials( + "unknown_protocol", &client_credentials); ASSERT_EQ(error::Code::NOT_FOUND, s.code()); ASSERT_TRUE( absl::StrContains(s.ToString(), diff --git a/tensorflow/core/data/service/cross_trainer_cache.h b/tensorflow/core/data/service/cross_trainer_cache.h index 98c1725c2dce79..3ef48fe4a4f333 100644 --- a/tensorflow/core/data/service/cross_trainer_cache.h +++ b/tensorflow/core/data/service/cross_trainer_cache.h @@ -115,7 +115,7 @@ class CrossTrainerCache { // Cancels the cache with `status` and notifies the readers. After cancelling, // all `Get` calls will return `status`. // REQUIRES: !status.ok() - void Cancel(Status status); + void Cancel(absl::Status status); // Returns true if the cache has been cancelled. bool IsCancelled() const; @@ -143,7 +143,7 @@ class CrossTrainerCache { const std::string& trainer_id); // Reads a new element and writes it into the cache. - Status ExtendCache(); + absl::Status ExtendCache(); // Frees old elements to keep the cache size below `max_cache_size_bytes_`. // `new_element_size_bytes` is the size of the new element being inserted. @@ -163,7 +163,7 @@ class CrossTrainerCache { // If `status_` is non-OK, the cache is cancelled, and all method calls will // return this status. - Status status_ TF_GUARDED_BY(mu_) = absl::OkStatus(); + absl::Status status_ TF_GUARDED_BY(mu_) = absl::OkStatus(); // `cache_` stores the cached elements. std::deque> cache_ TF_GUARDED_BY(mu_); @@ -235,7 +235,7 @@ CrossTrainerCache::GetCacheQueryResult( } if (should_extend_cache) { - Status s = ExtendCache(); + absl::Status s = ExtendCache(); mutex_lock l(mu_); extending_cache_ = false; cv_.notify_all(); @@ -278,7 +278,8 @@ size_t CrossTrainerCache::GetElementIndex( } template -Status CrossTrainerCache::ExtendCache() TF_LOCKS_EXCLUDED(mu_) { +absl::Status CrossTrainerCache::ExtendCache() + TF_LOCKS_EXCLUDED(mu_) { TF_ASSIGN_OR_RETURN(ElementType element, cachable_sequence_->GetNext()); size_t new_element_size_bytes = cachable_sequence_->GetElementSizeBytes(element); @@ -317,7 +318,7 @@ void CrossTrainerCache::FreeSpace(size_t new_element_size_bytes) } template -void CrossTrainerCache::Cancel(Status status) +void CrossTrainerCache::Cancel(absl::Status status) TF_LOCKS_EXCLUDED(mu_) { DCHECK(!status.ok()) << "Cancelling CrossTrainerCache requires a non-OK status. Got " diff --git a/tensorflow/core/data/service/cross_trainer_cache_test.cc b/tensorflow/core/data/service/cross_trainer_cache_test.cc index 9426d24d917e04..cc4fb83bef8b26 100644 --- a/tensorflow/core/data/service/cross_trainer_cache_test.cc +++ b/tensorflow/core/data/service/cross_trainer_cache_test.cc @@ -362,7 +362,7 @@ TEST(CrossTrainerCacheTest, Cancel) { EXPECT_FALSE(cache.IsCancelled()); mutex mu; - Status status; // Guarded by `mu`. + absl::Status status; // Guarded by `mu`. std::vector> reader_threads; for (size_t i = 0; i < num_trainers; ++i) { reader_threads.push_back(absl::WrapUnique(Env::Default()->StartThread( diff --git a/tensorflow/core/data/service/data_transfer.cc b/tensorflow/core/data/service/data_transfer.cc index 67da0b4adbb70b..4f45b11d313e31 100644 --- a/tensorflow/core/data/service/data_transfer.cc +++ b/tensorflow/core/data/service/data_transfer.cc @@ -91,8 +91,9 @@ void DataTransferServer::Register(std::string name, ServerFactoryT factory) { } } -Status DataTransferServer::Build(std::string name, GetElementT get_element, - std::shared_ptr* out) { +absl::Status DataTransferServer::Build( + std::string name, GetElementT get_element, + std::shared_ptr* out) { mutex_lock l(*get_lock()); auto it = transfer_server_factories().find(name); if (it != transfer_server_factories().end()) { @@ -119,8 +120,8 @@ void DataTransferClient::Register(std::string name, ClientFactoryT factory) { } } -Status DataTransferClient::Build(std::string name, Config config, - std::unique_ptr* out) { +absl::Status DataTransferClient::Build( + std::string name, Config config, std::unique_ptr* out) { mutex_lock l(*get_lock()); auto it = transfer_client_factories().find(name); if (it != transfer_client_factories().end()) { diff --git a/tensorflow/core/data/service/data_transfer.h b/tensorflow/core/data/service/data_transfer.h index ac3c6f68e95140..cb5125b573ce97 100644 --- a/tensorflow/core/data/service/data_transfer.h +++ b/tensorflow/core/data/service/data_transfer.h @@ -74,12 +74,12 @@ class DataTransferClient { Allocator* allocator; }; using ClientFactoryT = - std::function*)>; + std::function*)>; virtual ~DataTransferClient() = default; // Fetches the next element. - virtual Status GetElement(const GetElementRequest& req, - GetElementResult& result) = 0; + virtual absl::Status GetElement(const GetElementRequest& req, + GetElementResult& result) = 0; // Makes a best effort to cancel all outstanding calls in progress for the // client, and causes further calls to return Cancelled status. @@ -89,8 +89,8 @@ class DataTransferClient { static void Register(std::string name, ClientFactoryT factory); // Builds a DataTransferClient from the factory registered under `name`. - static Status Build(std::string name, Config config, - std::unique_ptr* out); + static absl::Status Build(std::string name, Config config, + std::unique_ptr* out); // Returns a string describing properties of the client relevant for checking // compatibility with a server for a given protocol. @@ -100,7 +100,7 @@ class DataTransferClient { // Returns an error if the client is incompatible with a server which has the // properties described in `server_compatibility_info`. - virtual Status CheckCompatibility( + virtual absl::Status CheckCompatibility( const std::string& server_compatibility_info) const { return absl::OkStatus(); } @@ -113,13 +113,13 @@ class DataTransferClient { class DataTransferServer { public: using GetElementT = - std::function; - using ServerFactoryT = - std::function*)>; + std::function; + using ServerFactoryT = std::function*)>; virtual ~DataTransferServer() = default; // Starts DataTransferServer, it should be available for requests afterwards. - virtual Status Start(const experimental::WorkerConfig& config) = 0; + virtual absl::Status Start(const experimental::WorkerConfig& config) = 0; // Return the port that this server is listening on. virtual int Port() const = 0; @@ -128,8 +128,8 @@ class DataTransferServer { static void Register(std::string name, ServerFactoryT factory); // Builds a DataTransferServer from the factory registered with `name`. - static Status Build(std::string name, GetElementT get_element, - std::shared_ptr* out); + static absl::Status Build(std::string name, GetElementT get_element, + std::shared_ptr* out); // Returns a string describing properties of the server relevant for checking // compatibility with a client for a given protocol. diff --git a/tensorflow/core/data/service/data_transfer_test.cc b/tensorflow/core/data/service/data_transfer_test.cc index d07623f6b53853..4799773562546d 100644 --- a/tensorflow/core/data/service/data_transfer_test.cc +++ b/tensorflow/core/data/service/data_transfer_test.cc @@ -35,7 +35,7 @@ namespace { class TestDataTransferServer : public DataTransferServer { public: explicit TestDataTransferServer(bool* called) : called_(called) {} - Status Start(const experimental::WorkerConfig& unused_config) override { + absl::Status Start(const experimental::WorkerConfig& unused_config) override { *called_ = true; return absl::OkStatus(); } diff --git a/tensorflow/core/data/service/dataset_store.cc b/tensorflow/core/data/service/dataset_store.cc index e201ff9ef081ee..cc6430326f0f9a 100644 --- a/tensorflow/core/data/service/dataset_store.cc +++ b/tensorflow/core/data/service/dataset_store.cc @@ -34,14 +34,14 @@ namespace data { FileSystemDatasetStore::FileSystemDatasetStore(const std::string& datasets_dir) : datasets_dir_(datasets_dir) {} -Status FileSystemDatasetStore::Put(const std::string& key, - const DatasetDef& dataset) { +absl::Status FileSystemDatasetStore::Put(const std::string& key, + const DatasetDef& dataset) { std::string path_to_write = io::JoinPath(datasets_dir_, key); TF_RETURN_IF_ERROR(WriteDatasetDef(path_to_write, dataset)); return absl::OkStatus(); } -Status FileSystemDatasetStore::Get( +absl::Status FileSystemDatasetStore::Get( const std::string& key, std::shared_ptr& dataset_def) { std::string path = io::JoinPath(datasets_dir_, key); TF_RETURN_IF_ERROR(Env::Default()->FileExists(path)); @@ -51,15 +51,15 @@ Status FileSystemDatasetStore::Get( return absl::OkStatus(); } -Status MemoryDatasetStore::Put(const std::string& key, - const DatasetDef& dataset) { +absl::Status MemoryDatasetStore::Put(const std::string& key, + const DatasetDef& dataset) { auto& stored_dataset = datasets_[key]; stored_dataset = std::make_shared(dataset); return absl::OkStatus(); } -Status MemoryDatasetStore::Get(const std::string& key, - std::shared_ptr& dataset_def) { +absl::Status MemoryDatasetStore::Get( + const std::string& key, std::shared_ptr& dataset_def) { auto& stored_dataset = datasets_[key]; if (!stored_dataset) { return errors::NotFound("Dataset with key ", key, " not found"); diff --git a/tensorflow/core/data/service/dataset_store.h b/tensorflow/core/data/service/dataset_store.h index 437066d719fdaf..f79120bd9c09a6 100644 --- a/tensorflow/core/data/service/dataset_store.h +++ b/tensorflow/core/data/service/dataset_store.h @@ -35,10 +35,11 @@ class DatasetStore { // Stores the given dataset under the given key. Overwrites a dataset if it // already exists. - virtual Status Put(const std::string& key, const DatasetDef& dataset) = 0; + virtual absl::Status Put(const std::string& key, + const DatasetDef& dataset) = 0; // Gets the dataset for the given key, storing the dataset in `dataset_def`. - virtual Status Get(const std::string& key, - std::shared_ptr& dataset_def) = 0; + virtual absl::Status Get(const std::string& key, + std::shared_ptr& dataset_def) = 0; }; // Dataset store which reads and writes datasets within a directory. @@ -49,9 +50,9 @@ class FileSystemDatasetStore : public DatasetStore { FileSystemDatasetStore(const FileSystemDatasetStore&) = delete; FileSystemDatasetStore& operator=(const FileSystemDatasetStore&) = delete; - Status Put(const std::string& key, const DatasetDef& dataset) override; - Status Get(const std::string& key, - std::shared_ptr& dataset_def) override; + absl::Status Put(const std::string& key, const DatasetDef& dataset) override; + absl::Status Get(const std::string& key, + std::shared_ptr& dataset_def) override; private: const std::string datasets_dir_; @@ -65,9 +66,9 @@ class MemoryDatasetStore : public DatasetStore { MemoryDatasetStore(const MemoryDatasetStore&) = delete; MemoryDatasetStore& operator=(const MemoryDatasetStore&) = delete; - Status Put(const std::string& key, const DatasetDef& dataset) override; - Status Get(const std::string& key, - std::shared_ptr& dataset_def) override; + absl::Status Put(const std::string& key, const DatasetDef& dataset) override; + absl::Status Get(const std::string& key, + std::shared_ptr& dataset_def) override; private: // Mapping from key to dataset definition. diff --git a/tensorflow/core/data/service/dataset_store_test.cc b/tensorflow/core/data/service/dataset_store_test.cc index 46c1111e41ab7d..cf127475a68053 100644 --- a/tensorflow/core/data/service/dataset_store_test.cc +++ b/tensorflow/core/data/service/dataset_store_test.cc @@ -111,7 +111,7 @@ TEST_P(DatasetStoreTest, StoreAlreadyExists) { TEST_P(DatasetStoreTest, GetMissing) { std::unique_ptr store = MakeStore(GetParam()); std::shared_ptr result; - Status s = store->Get("missing", result); + absl::Status s = store->Get("missing", result); EXPECT_EQ(s.code(), error::NOT_FOUND); } diff --git a/tensorflow/core/data/service/dispatcher_client.cc b/tensorflow/core/data/service/dispatcher_client.cc index ff3b165899f562..d77cf3c21598c3 100644 --- a/tensorflow/core/data/service/dispatcher_client.cc +++ b/tensorflow/core/data/service/dispatcher_client.cc @@ -46,7 +46,7 @@ limitations under the License. namespace tensorflow { namespace data { -Status DataServiceDispatcherClient::Initialize() { +absl::Status DataServiceDispatcherClient::Initialize() { mutex_lock l(mu_); if (stub_) { return absl::OkStatus(); @@ -96,7 +96,7 @@ DataServiceDispatcherClient::WorkerHeartbeat( return response; } -Status DataServiceDispatcherClient::WorkerUpdate( +absl::Status DataServiceDispatcherClient::WorkerUpdate( const std::string& worker_address, std::vector& task_progress) { WorkerUpdateRequest req; @@ -113,8 +113,8 @@ Status DataServiceDispatcherClient::WorkerUpdate( return absl::OkStatus(); } -Status DataServiceDispatcherClient::GetDatasetDef(const std::string& dataset_id, - DatasetDef& dataset_def) { +absl::Status DataServiceDispatcherClient::GetDatasetDef( + const std::string& dataset_id, DatasetDef& dataset_def) { GetDatasetDefRequest req; req.set_dataset_id(dataset_id); GetDatasetDefResponse resp; @@ -127,11 +127,11 @@ Status DataServiceDispatcherClient::GetDatasetDef(const std::string& dataset_id, return absl::OkStatus(); } -Status DataServiceDispatcherClient::GetSplit(int64_t iteration_id, - int64_t repetition, - int64_t split_provider_index, - Tensor& split, - bool& end_of_splits) { +absl::Status DataServiceDispatcherClient::GetSplit(int64_t iteration_id, + int64_t repetition, + int64_t split_provider_index, + Tensor& split, + bool& end_of_splits) { TF_RETURN_IF_ERROR(EnsureInitialized()); GetSplitRequest req; req.set_iteration_id(iteration_id); @@ -152,7 +152,7 @@ Status DataServiceDispatcherClient::GetSplit(int64_t iteration_id, return absl::OkStatus(); } -Status DataServiceDispatcherClient::Snapshot( +absl::Status DataServiceDispatcherClient::Snapshot( const DatasetDef& dataset, const std::string& path, const experimental::DistributedSnapshotMetadata& metadata) { TF_RETURN_IF_ERROR(EnsureInitialized()); @@ -171,7 +171,7 @@ Status DataServiceDispatcherClient::Snapshot( return absl::OkStatus(); } -Status DataServiceDispatcherClient::GetSnapshotSplit( +absl::Status DataServiceDispatcherClient::GetSnapshotSplit( const std::string& worker_address, const std::string& base_path, int64_t stream_index, int64_t source_index, int64_t repetition_index, Tensor& split, int64_t& local_split_index, bool& end_of_splits) { @@ -200,7 +200,7 @@ Status DataServiceDispatcherClient::GetSnapshotSplit( return absl::OkStatus(); } -Status DataServiceDispatcherClient::RegisterDataset( +absl::Status DataServiceDispatcherClient::RegisterDataset( const DatasetDef& dataset, const DataServiceMetadata& metadata, const std::optional& requested_dataset_id, std::string& dataset_id) { @@ -222,7 +222,7 @@ Status DataServiceDispatcherClient::RegisterDataset( return absl::OkStatus(); } -Status DataServiceDispatcherClient::GetOrCreateJob( +absl::Status DataServiceDispatcherClient::GetOrCreateJob( const std::string& dataset_id, const ProcessingModeDef& processing_mode, const std::optional& job_name, std::optional num_consumers, bool use_cross_trainer_cache, @@ -252,7 +252,7 @@ Status DataServiceDispatcherClient::GetOrCreateJob( return absl::OkStatus(); } -Status DataServiceDispatcherClient::GetOrCreateIteration( +absl::Status DataServiceDispatcherClient::GetOrCreateIteration( int64_t job_id, int64_t repetition, int64_t& iteration_client_id) { TF_RETURN_IF_ERROR(EnsureInitialized()); GetOrCreateIterationRequest req; @@ -271,7 +271,7 @@ Status DataServiceDispatcherClient::GetOrCreateIteration( return absl::OkStatus(); } -Status DataServiceDispatcherClient::ReleaseIterationClient( +absl::Status DataServiceDispatcherClient::ReleaseIterationClient( int64_t iteration_client_id) { TF_RETURN_IF_ERROR(EnsureInitialized()); ReleaseIterationClientRequest req; @@ -288,10 +288,8 @@ Status DataServiceDispatcherClient::ReleaseIterationClient( return absl::OkStatus(); } -Status DataServiceDispatcherClient::MaybeRemoveTask(int64_t task_id, - int64_t consumer_index, - int64_t round, - bool& removed) { +absl::Status DataServiceDispatcherClient::MaybeRemoveTask( + int64_t task_id, int64_t consumer_index, int64_t round, bool& removed) { TF_RETURN_IF_ERROR(EnsureInitialized()); MaybeRemoveTaskRequest req; req.set_task_id(task_id); @@ -307,7 +305,7 @@ Status DataServiceDispatcherClient::MaybeRemoveTask(int64_t task_id, return absl::OkStatus(); } -Status DataServiceDispatcherClient::ClientHeartbeat( +absl::Status DataServiceDispatcherClient::ClientHeartbeat( ClientHeartbeatRequest& req, ClientHeartbeatResponse& resp) { TF_RETURN_IF_ERROR(EnsureInitialized()); grpc::ClientContext ctx; @@ -318,7 +316,7 @@ Status DataServiceDispatcherClient::ClientHeartbeat( return absl::OkStatus(); } -Status DataServiceDispatcherClient::GetWorkers( +absl::Status DataServiceDispatcherClient::GetWorkers( std::vector& workers) { TF_RETURN_IF_ERROR(EnsureInitialized()); GetWorkersRequest req; @@ -335,7 +333,7 @@ Status DataServiceDispatcherClient::GetWorkers( return absl::OkStatus(); } -Status DataServiceDispatcherClient::GetDataServiceMetadata( +absl::Status DataServiceDispatcherClient::GetDataServiceMetadata( const std::string& dataset_id, DataServiceMetadata& metadata) { TF_RETURN_IF_ERROR(EnsureInitialized()); GetDataServiceMetadataRequest req; @@ -350,7 +348,7 @@ Status DataServiceDispatcherClient::GetDataServiceMetadata( return absl::OkStatus(); } -Status DataServiceDispatcherClient::GetDataServiceConfig( +absl::Status DataServiceDispatcherClient::GetDataServiceConfig( DataServiceConfig& config) { TF_RETURN_IF_ERROR(EnsureInitialized()); GetDataServiceConfigRequest request; @@ -364,7 +362,7 @@ Status DataServiceDispatcherClient::GetDataServiceConfig( return absl::OkStatus(); } -Status DataServiceDispatcherClient::DisableCompressionAtRuntime( +absl::Status DataServiceDispatcherClient::DisableCompressionAtRuntime( const std::string& dataset_id, bool disable_compression_at_runtime, DisableCompressionAtRuntimeResponse& response) { TF_RETURN_IF_ERROR(EnsureInitialized()); @@ -380,7 +378,7 @@ Status DataServiceDispatcherClient::DisableCompressionAtRuntime( return absl::OkStatus(); } -Status DataServiceDispatcherClient::EnsureInitialized() { +absl::Status DataServiceDispatcherClient::EnsureInitialized() { return grpc_util::Retry([this] { return Initialize(); }, "Initialize dispatcher client", /*deadline_micros=*/kint64max); diff --git a/tensorflow/core/data/service/dispatcher_client.h b/tensorflow/core/data/service/dispatcher_client.h index 9f521bd210ac6a..253d8ec06c0714 100644 --- a/tensorflow/core/data/service/dispatcher_client.h +++ b/tensorflow/core/data/service/dispatcher_client.h @@ -44,7 +44,7 @@ class DataServiceDispatcherClient : public DataServiceClientBase { const std::string& protocol) : DataServiceClientBase(address, protocol) {} - Status Initialize() override; + absl::Status Initialize() override; // Sends a heartbeat to the dispatcher. If the worker wasn't already // registered with the dispatcher, this will register the worker. The @@ -54,91 +54,91 @@ class DataServiceDispatcherClient : public DataServiceClientBase { const WorkerHeartbeatRequest& request); // Updates the dispatcher with information about the worker's state. - Status WorkerUpdate(const std::string& worker_address, - std::vector& task_progress); + absl::Status WorkerUpdate(const std::string& worker_address, + std::vector& task_progress); // Gets a dataset definition for the given dataset id, and stores the // definition in `dataset_def`. - Status GetDatasetDef(const std::string& dataset_id, DatasetDef& dataset_def); + absl::Status GetDatasetDef(const std::string& dataset_id, + DatasetDef& dataset_def); // Gets the next split for the specified iteration id, repetition, and split // provider index. - Status GetSplit(int64_t iteration_id, int64_t repetition, - int64_t split_provider_index, Tensor& split, - bool& end_of_splits); + absl::Status GetSplit(int64_t iteration_id, int64_t repetition, + int64_t split_provider_index, Tensor& split, + bool& end_of_splits); // Gets the next split for the specified source of a stream of the snapshot in // `base_path`. If `end_of_splits` returns true, then there are no more splits // to be processed for the specified stream source. - virtual Status GetSnapshotSplit(const std::string& worker_address, - const std::string& base_path, - int64_t stream_index, int64_t source_index, - int64_t repetition_index, Tensor& split, - int64_t& local_split_index, - bool& end_of_splits); + virtual absl::Status GetSnapshotSplit( + const std::string& worker_address, const std::string& base_path, + int64_t stream_index, int64_t source_index, int64_t repetition_index, + Tensor& split, int64_t& local_split_index, bool& end_of_splits); // Initiates the process of materializing `dataset`'s output to `path`. - Status Snapshot(const DatasetDef& dataset, const std::string& path, - const experimental::DistributedSnapshotMetadata& metadata); + absl::Status Snapshot( + const DatasetDef& dataset, const std::string& path, + const experimental::DistributedSnapshotMetadata& metadata); // Registers a dataset with the tf.data service, and stores the generated // dataset id in `dataset_id`. - Status RegisterDataset(const DatasetDef& dataset, - const DataServiceMetadata& metadata, - const std::optional& requested_dataset_id, - std::string& dataset_id); + absl::Status RegisterDataset( + const DatasetDef& dataset, const DataServiceMetadata& metadata, + const std::optional& requested_dataset_id, + std::string& dataset_id); // If `job_name` is set, looks up a job matching `job_name`. // If `job_name` is absent or no matching job is found, creates a // new job. The resulting job id is stored in `job_id`. - Status GetOrCreateJob(const std::string& dataset_id, - const ProcessingModeDef& processing_mode, - const std::optional& job_name, - std::optional num_consumers, - bool use_cross_trainer_cache, - TargetWorkers target_workers, int64_t& job_id); + absl::Status GetOrCreateJob(const std::string& dataset_id, + const ProcessingModeDef& processing_mode, + const std::optional& job_name, + std::optional num_consumers, + bool use_cross_trainer_cache, + TargetWorkers target_workers, int64_t& job_id); // Looks up an iteration of a job, creating an iteration if one doesn't // already exist. The returned `iteration_client_id` can be used to query // information about the iteration. The client should call // `ReleaseIterationClient` when finished with the iteration, so that // resources can be reclaimed. - Status GetOrCreateIteration(int64_t job_id, int64_t repetition, - int64_t& iteration_client_id); + absl::Status GetOrCreateIteration(int64_t job_id, int64_t repetition, + int64_t& iteration_client_id); // Releases a iteration client id, indicating that the id will no longer be // used to read from the iteration. - Status ReleaseIterationClient(int64_t iteration_client_id); + absl::Status ReleaseIterationClient(int64_t iteration_client_id); // Attempts to remove a task. The task is removed if all consumers try to // remove the task in the same round. - Status MaybeRemoveTask(int64_t task_id, int64_t consumer_index, int64_t round, - bool& removed); + absl::Status MaybeRemoveTask(int64_t task_id, int64_t consumer_index, + int64_t round, bool& removed); // Heartbeats to the dispatcher, getting back the tasks that should be // running, and whether the iteration is finished. - Status ClientHeartbeat(ClientHeartbeatRequest& req, - ClientHeartbeatResponse& resp); + absl::Status ClientHeartbeat(ClientHeartbeatRequest& req, + ClientHeartbeatResponse& resp); // Queries the dispatcher for its registered workers. The worker info will be // stored in `workers`. - Status GetWorkers(std::vector& workers); + absl::Status GetWorkers(std::vector& workers); // Returns data service metadata for the registered dataset. - Status GetDataServiceMetadata(const std::string& dataset_id, - DataServiceMetadata& metadata); + absl::Status GetDataServiceMetadata(const std::string& dataset_id, + DataServiceMetadata& metadata); // Returns data service config of the data service cluster. - Status GetDataServiceConfig(DataServiceConfig& config); + absl::Status GetDataServiceConfig(DataServiceConfig& config); // Returns information about the decision to disable compression at runtime // for a given dataset. - Status DisableCompressionAtRuntime( + absl::Status DisableCompressionAtRuntime( const std::string& dataset_id, bool disable_compression_at_runtime, DisableCompressionAtRuntimeResponse& response); protected: - Status EnsureInitialized() override; + absl::Status EnsureInitialized() override; private: mutex mu_; diff --git a/tensorflow/core/data/service/dispatcher_client_test.cc b/tensorflow/core/data/service/dispatcher_client_test.cc index 013f514658db68..97253905b6957d 100644 --- a/tensorflow/core/data/service/dispatcher_client_test.cc +++ b/tensorflow/core/data/service/dispatcher_client_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/data_transfer.h" #include "tensorflow/core/data/service/dataset_store.h" @@ -41,7 +42,6 @@ limitations under the License. #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/data/service/dispatcher_impl.cc b/tensorflow/core/data/service/dispatcher_impl.cc index aaad842c26dd0c..6c1e100735c7e4 100644 --- a/tensorflow/core/data/service/dispatcher_impl.cc +++ b/tensorflow/core/data/service/dispatcher_impl.cc @@ -132,8 +132,9 @@ std::string DatasetsDir(const std::string& work_dir) { return io::JoinPath(work_dir, kDatasetsDir); } -Status CreateWorkerStub(const std::string& address, const std::string& protocol, - std::unique_ptr& stub) { +absl::Status CreateWorkerStub(const std::string& address, + const std::string& protocol, + std::unique_ptr& stub) { ::grpc::ChannelArguments args; args.SetMaxReceiveMessageSize(-1); std::shared_ptr<::grpc::ChannelCredentials> credentials; @@ -210,7 +211,7 @@ DataServiceDispatcherImpl::~DataServiceDispatcherImpl() { maintenance_thread_.reset(); } -Status DataServiceDispatcherImpl::Start() { +absl::Status DataServiceDispatcherImpl::Start() { mutex_lock l(mu_); if (config_.job_gc_timeout_ms() >= 0) { maintenance_thread_ = absl::WrapUnique(env_->StartThread( @@ -238,7 +239,7 @@ Status DataServiceDispatcherImpl::Start() { Update update; bool end_of_journal = false; FileJournalReader reader(env_, JournalDir(config_.work_dir())); - Status s = reader.Read(update, end_of_journal); + absl::Status s = reader.Read(update, end_of_journal); if (errors::IsNotFound(s)) { LOG(INFO) << "No journal found. Starting dispatcher from new state."; } else if (!s.ok()) { @@ -315,7 +316,7 @@ size_t DataServiceDispatcherImpl::NumActiveIterations() TF_LOCKS_EXCLUDED(mu_) { return count; } -Status DataServiceDispatcherImpl::RestoreSplitProviders( +absl::Status DataServiceDispatcherImpl::RestoreSplitProviders( const Iteration& iteration, std::vector>& restored) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { @@ -346,7 +347,7 @@ Status DataServiceDispatcherImpl::RestoreSplitProviders( return absl::OkStatus(); } -Status DataServiceDispatcherImpl::FindTasksToDelete( +absl::Status DataServiceDispatcherImpl::FindTasksToDelete( const absl::flat_hash_set& current_tasks, const std::vector>& assigned_tasks, WorkerHeartbeatResponse* response) { @@ -362,7 +363,7 @@ Status DataServiceDispatcherImpl::FindTasksToDelete( return absl::OkStatus(); } -Status DataServiceDispatcherImpl::FindNewTasks( +absl::Status DataServiceDispatcherImpl::FindNewTasks( const std::string& worker_address, const absl::flat_hash_set& current_tasks, std::vector>& assigned_tasks, @@ -404,14 +405,14 @@ void DataServiceDispatcherImpl::ReportProcessingTimesFromActiveTasks( << ". Time in nanoseconds: " << processing_time_nsec; std::shared_ptr task; - Status s = state_.TaskFromId(task_id, task); + absl::Status s = state_.TaskFromId(task_id, task); if (!s.ok()) { VLOG(1) << "Could not find task with id " << task_id << " in tf.data service dispatcher state: " << s; continue; } - Status auto_scaler_status = auto_scaler_.ReportProcessingTime( + absl::Status auto_scaler_status = auto_scaler_.ReportProcessingTime( task->iteration->iteration_id, worker_address, absl::Nanoseconds(processing_time_nsec)); if (!auto_scaler_status.ok()) { @@ -423,7 +424,7 @@ void DataServiceDispatcherImpl::ReportProcessingTimesFromActiveTasks( } } -Status DataServiceDispatcherImpl::WorkerHeartbeat( +absl::Status DataServiceDispatcherImpl::WorkerHeartbeat( const WorkerHeartbeatRequest* request, WorkerHeartbeatResponse* response) { TF_RETURN_IF_ERROR(CheckStarted()); VLOG(3) << "Received worker heartbeat request from worker " @@ -435,7 +436,7 @@ Status DataServiceDispatcherImpl::WorkerHeartbeat( absl::FromUnixMicros(env_->NowMicros()); // Assigned tasks from the perspective of the dispatcher. std::vector> assigned_tasks; - Status s = state_.TasksForWorker(worker_address, assigned_tasks); + absl::Status s = state_.TasksForWorker(worker_address, assigned_tasks); if (!s.ok()) { if (!errors::IsNotFound(s)) { return s; @@ -492,7 +493,7 @@ Status DataServiceDispatcherImpl::WorkerHeartbeat( return absl::OkStatus(); } -Status DataServiceDispatcherImpl::WorkerUpdate( +absl::Status DataServiceDispatcherImpl::WorkerUpdate( const WorkerUpdateRequest* request, WorkerUpdateResponse* response) { TF_RETURN_IF_ERROR(CheckStarted()); mutex_lock l(mu_); @@ -516,7 +517,7 @@ Status DataServiceDispatcherImpl::WorkerUpdate( return absl::OkStatus(); } -Status DataServiceDispatcherImpl::GetDatasetDef( +absl::Status DataServiceDispatcherImpl::GetDatasetDef( const GetDatasetDefRequest* request, GetDatasetDefResponse* response) { TF_RETURN_IF_ERROR(CheckStarted()); mutex_lock l(mu_); @@ -528,8 +529,8 @@ Status DataServiceDispatcherImpl::GetDatasetDef( return absl::OkStatus(); } -Status DataServiceDispatcherImpl::GetSplit(const GetSplitRequest* request, - GetSplitResponse* response) { +absl::Status DataServiceDispatcherImpl::GetSplit(const GetSplitRequest* request, + GetSplitResponse* response) { TF_RETURN_IF_ERROR(CheckStarted()); int64_t iteration_id = request->iteration_id(); int64_t repetition = request->repetition(); @@ -584,7 +585,7 @@ Status DataServiceDispatcherImpl::GetSplit(const GetSplitRequest* request, return absl::OkStatus(); } -Status DataServiceDispatcherImpl::MakeSplitProviders( +absl::Status DataServiceDispatcherImpl::MakeSplitProviders( const std::string& dataset_id, std::vector>& split_providers) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { @@ -596,13 +597,13 @@ Status DataServiceDispatcherImpl::MakeSplitProviders( return absl::OkStatus(); } -Status DataServiceDispatcherImpl::GetVersion(const GetVersionRequest* request, - GetVersionResponse* response) { +absl::Status DataServiceDispatcherImpl::GetVersion( + const GetVersionRequest* request, GetVersionResponse* response) { response->set_version(kDataServiceVersion); return absl::OkStatus(); } -Status DataServiceDispatcherImpl::GetOrRegisterDataset( +absl::Status DataServiceDispatcherImpl::GetOrRegisterDataset( const GetOrRegisterDatasetRequest* request, GetOrRegisterDatasetResponse* response) { TF_RETURN_IF_ERROR(CheckStarted()); @@ -633,7 +634,8 @@ DataServiceDispatcherImpl::FindDataset( const GetOrRegisterDatasetRequest& request) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { std::shared_ptr existing_dataset; - Status status = state_.DatasetFromId(request.dataset_id(), existing_dataset); + absl::Status status = + state_.DatasetFromId(request.dataset_id(), existing_dataset); if (errors::IsNotFound(status)) { return std::optional(); @@ -646,7 +648,7 @@ DataServiceDispatcherImpl::FindDataset( return std::optional(existing_dataset->dataset_id); } -Status DataServiceDispatcherImpl::RegisterDataset( +absl::Status DataServiceDispatcherImpl::RegisterDataset( const DatasetDef& dataset, const DataServiceMetadata& metadata, const std::string& requested_dataset_id, std::string& dataset_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { @@ -662,7 +664,7 @@ Status DataServiceDispatcherImpl::RegisterDataset( return Apply(update); } -Status DataServiceDispatcherImpl::GetDataServiceMetadata( +absl::Status DataServiceDispatcherImpl::GetDataServiceMetadata( const GetDataServiceMetadataRequest* request, GetDataServiceMetadataResponse* response) { TF_RETURN_IF_ERROR(CheckStarted()); @@ -677,7 +679,7 @@ Status DataServiceDispatcherImpl::GetDataServiceMetadata( return absl::OkStatus(); } -Status DataServiceDispatcherImpl::GetDataServiceConfig( +absl::Status DataServiceDispatcherImpl::GetDataServiceConfig( const GetDataServiceConfigRequest* request, GetDataServiceConfigResponse* response) { TF_RETURN_IF_ERROR(CheckStarted()); @@ -685,7 +687,7 @@ Status DataServiceDispatcherImpl::GetDataServiceConfig( return absl::OkStatus(); } -Status DataServiceDispatcherImpl::GetOrCreateJob( +absl::Status DataServiceDispatcherImpl::GetOrCreateJob( const GetOrCreateJobRequest* request, GetOrCreateJobResponse* response) { TF_RETURN_IF_ERROR(CheckStarted()); VLOG(3) << "GetOrCreateJob(" << request->DebugString() << ")"; @@ -699,7 +701,7 @@ Status DataServiceDispatcherImpl::GetOrCreateJob( job_name = absl::StrCat("anonymous_job_", state_.NextAvailableJobId(), "_", random::New64()); } - Status s = state_.JobByName(job_name, job); + absl::Status s = state_.JobByName(job_name, job); if (s.ok()) { TF_RETURN_IF_ERROR(ValidateMatchingJob(job, *request)); } else if (errors::IsNotFound(s)) { @@ -714,7 +716,7 @@ Status DataServiceDispatcherImpl::GetOrCreateJob( return absl::OkStatus(); } -Status DataServiceDispatcherImpl::GetOrCreateIteration( +absl::Status DataServiceDispatcherImpl::GetOrCreateIteration( const GetOrCreateIterationRequest* request, GetOrCreateIterationResponse* response) { TF_RETURN_IF_ERROR(CheckStarted()); @@ -726,7 +728,7 @@ Status DataServiceDispatcherImpl::GetOrCreateIteration( std::shared_ptr job; TF_RETURN_IF_ERROR(state_.JobFromId(request->job_id(), job)); IterationKey key(job->job_name, request->repetition()); - Status s = state_.IterationByKey(key, iteration); + absl::Status s = state_.IterationByKey(key, iteration); if (!s.ok() && !errors::IsNotFound(s)) { return s; } @@ -745,14 +747,14 @@ Status DataServiceDispatcherImpl::GetOrCreateIteration( return absl::OkStatus(); } -Status DataServiceDispatcherImpl::MaybeRemoveTask( +absl::Status DataServiceDispatcherImpl::MaybeRemoveTask( const MaybeRemoveTaskRequest* request, MaybeRemoveTaskResponse* response) { VLOG(1) << "Attempting to remove task. Request: " << request->DebugString(); std::shared_ptr remover; std::shared_ptr task; { mutex_lock l(mu_); - Status s = state_.TaskFromId(request->task_id(), task); + absl::Status s = state_.TaskFromId(request->task_id(), task); if (errors::IsNotFound(s)) { // Task is already removed. response->set_removed(true); @@ -784,7 +786,7 @@ Status DataServiceDispatcherImpl::MaybeRemoveTask( remove_task->set_task_id(request->task_id()); TF_RETURN_IF_ERROR(Apply(update)); } - Status auto_scaler_status = auto_scaler_.RemoveWorker( + absl::Status auto_scaler_status = auto_scaler_.RemoveWorker( task->iteration->iteration_id, task->worker_address); if (!auto_scaler_status.ok()) { VLOG(1) << "Failed to remove worker with address " << task->worker_address @@ -795,7 +797,7 @@ Status DataServiceDispatcherImpl::MaybeRemoveTask( return absl::OkStatus(); } -Status DataServiceDispatcherImpl::ReleaseIterationClient( +absl::Status DataServiceDispatcherImpl::ReleaseIterationClient( const ReleaseIterationClientRequest* request, ReleaseIterationClientResponse* response) { TF_RETURN_IF_ERROR(CheckStarted()); @@ -804,7 +806,7 @@ Status DataServiceDispatcherImpl::ReleaseIterationClient( std::shared_ptr iteration; TF_RETURN_IF_ERROR( state_.IterationForIterationClientId(iteration_client_id, iteration)); - Status auto_scaler_status = + absl::Status auto_scaler_status = auto_scaler_.RemoveConsumer(iteration->iteration_id, iteration_client_id); if (!auto_scaler_status.ok()) { VLOG(1) << "Failed to remove consumer with ID " << iteration_client_id @@ -821,7 +823,7 @@ Status DataServiceDispatcherImpl::ReleaseIterationClient( } // Validates that the job matches the requested processing mode. -Status DataServiceDispatcherImpl::ValidateMatchingJob( +absl::Status DataServiceDispatcherImpl::ValidateMatchingJob( std::shared_ptr job, const GetOrCreateJobRequest& request) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { std::string diff; @@ -853,7 +855,7 @@ Status DataServiceDispatcherImpl::ValidateMatchingJob( return absl::OkStatus(); } -Status DataServiceDispatcherImpl::CreateJob( +absl::Status DataServiceDispatcherImpl::CreateJob( const std::string& job_name, const GetOrCreateJobRequest& request, std::shared_ptr& job) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { TF_RETURN_IF_ERROR(ValidateProcessingMode(request.processing_mode_def())); @@ -878,7 +880,7 @@ Status DataServiceDispatcherImpl::CreateJob( return absl::OkStatus(); } -Status DataServiceDispatcherImpl::CreateIteration( +absl::Status DataServiceDispatcherImpl::CreateIteration( const GetOrCreateIterationRequest& request, std::shared_ptr& iteration) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { @@ -903,7 +905,7 @@ Status DataServiceDispatcherImpl::CreateIteration( return absl::OkStatus(); } -Status DataServiceDispatcherImpl::CreateTasksForWorker( +absl::Status DataServiceDispatcherImpl::CreateTasksForWorker( const std::string& worker_address) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { std::vector> iterations = state_.ListIterations(); @@ -921,7 +923,7 @@ Status DataServiceDispatcherImpl::CreateTasksForWorker( return absl::OkStatus(); } -Status DataServiceDispatcherImpl::AcquireIterationClientId( +absl::Status DataServiceDispatcherImpl::AcquireIterationClientId( const std::shared_ptr& iteration, int64_t& iteration_client_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { iteration_client_id = state_.NextAvailableIterationClientId(); @@ -936,7 +938,7 @@ Status DataServiceDispatcherImpl::AcquireIterationClientId( return absl::OkStatus(); } -Status DataServiceDispatcherImpl::CreateTasksForIteration( +absl::Status DataServiceDispatcherImpl::CreateTasksForIteration( std::shared_ptr iteration, std::vector>& tasks) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { @@ -951,7 +953,7 @@ Status DataServiceDispatcherImpl::CreateTasksForIteration( return absl::OkStatus(); } -Status DataServiceDispatcherImpl::CreatePendingTask( +absl::Status DataServiceDispatcherImpl::CreatePendingTask( std::shared_ptr iteration, const std::string& worker_address) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { int64_t task_id = state_.NextAvailableTaskId(); @@ -973,7 +975,7 @@ Status DataServiceDispatcherImpl::CreatePendingTask( return absl::OkStatus(); } -Status DataServiceDispatcherImpl::CreateTask( +absl::Status DataServiceDispatcherImpl::CreateTask( std::shared_ptr iteration, const std::string& worker_address, std::shared_ptr& task) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { @@ -995,7 +997,7 @@ Status DataServiceDispatcherImpl::CreateTask( return absl::OkStatus(); } -Status DataServiceDispatcherImpl::AssignTasks( +absl::Status DataServiceDispatcherImpl::AssignTasks( std::vector> tasks) TF_LOCKS_EXCLUDED(mu_) { for (const auto& task : tasks) { TF_RETURN_IF_ERROR(AssignTask(task)); @@ -1003,7 +1005,7 @@ Status DataServiceDispatcherImpl::AssignTasks( return absl::OkStatus(); } -Status DataServiceDispatcherImpl::GetOrCreateWorkerStub( +absl::Status DataServiceDispatcherImpl::GetOrCreateWorkerStub( const std::string& worker_address, WorkerService::Stub*& out_stub) TF_LOCKS_EXCLUDED(mu_) { { @@ -1029,8 +1031,8 @@ Status DataServiceDispatcherImpl::GetOrCreateWorkerStub( return absl::OkStatus(); } -Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr task) - TF_LOCKS_EXCLUDED(mu_) { +absl::Status DataServiceDispatcherImpl::AssignTask( + std::shared_ptr task) TF_LOCKS_EXCLUDED(mu_) { VLOG(2) << "Started assigning task " << task->task_id << " to worker " << task->worker_address; grpc::ClientContext client_ctx; @@ -1061,7 +1063,7 @@ Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr task) return absl::OkStatus(); } -Status DataServiceDispatcherImpl::ClientHeartbeat( +absl::Status DataServiceDispatcherImpl::ClientHeartbeat( const ClientHeartbeatRequest* request, ClientHeartbeatResponse* response) { TF_RETURN_IF_ERROR(CheckStarted()); mutex_lock l(mu_); @@ -1070,7 +1072,7 @@ Status DataServiceDispatcherImpl::ClientHeartbeat( latest_client_heartbeats_time_[request->iteration_client_id()] = absl::FromUnixMicros(env_->NowMicros()); std::shared_ptr iteration; - Status s = state_.IterationForIterationClientId( + absl::Status s = state_.IterationForIterationClientId( request->iteration_client_id(), iteration); if (errors::IsNotFound(s) && !config_.fault_tolerant_mode()) { return errors::NotFound( @@ -1137,7 +1139,7 @@ Status DataServiceDispatcherImpl::ClientHeartbeat( << iteration->iteration_id << " from iteration_client_id " << request->iteration_client_id() << ". Time in nanoseconds: " << request->target_processing_time_nsec(); - Status auto_scaler_status = auto_scaler_.ReportTargetProcessingTime( + absl::Status auto_scaler_status = auto_scaler_.ReportTargetProcessingTime( iteration->iteration_id, request->iteration_client_id(), absl::Nanoseconds(request->target_processing_time_nsec())); if (!auto_scaler_status.ok()) { @@ -1169,8 +1171,8 @@ Status DataServiceDispatcherImpl::ClientHeartbeat( return absl::OkStatus(); } -Status DataServiceDispatcherImpl::GetWorkers(const GetWorkersRequest* request, - GetWorkersResponse* response) { +absl::Status DataServiceDispatcherImpl::GetWorkers( + const GetWorkersRequest* request, GetWorkersResponse* response) { TF_RETURN_IF_ERROR(CheckStarted()); mutex_lock l(mu_); VLOG(3) << "Enter GetWorkers"; @@ -1184,8 +1186,8 @@ Status DataServiceDispatcherImpl::GetWorkers(const GetWorkersRequest* request, return absl::OkStatus(); } -Status DataServiceDispatcherImpl::Snapshot(const SnapshotRequest* request, - SnapshotResponse* response) { +absl::Status DataServiceDispatcherImpl::Snapshot(const SnapshotRequest* request, + SnapshotResponse* response) { if (!config_.fault_tolerant_mode()) { return errors::InvalidArgument( "tf.data distributed snapshot requires running tf.data service in the " @@ -1213,7 +1215,7 @@ Status DataServiceDispatcherImpl::Snapshot(const SnapshotRequest* request, return Apply(update); } -Status DataServiceDispatcherImpl::GetSnapshotStreams( +absl::Status DataServiceDispatcherImpl::GetSnapshotStreams( const GetSnapshotStreamsRequest* request, GetSnapshotStreamsResponse* response) { TF_RETURN_IF_ERROR(CheckStarted()); @@ -1230,7 +1232,7 @@ Status DataServiceDispatcherImpl::GetSnapshotStreams( return it->second->GetSnapshotStreams(*response); } -Status DataServiceDispatcherImpl::GetSnapshotSplit( +absl::Status DataServiceDispatcherImpl::GetSnapshotSplit( const GetSnapshotSplitRequest* request, GetSnapshotSplitResponse* response) { TF_RETURN_IF_ERROR(CheckStarted()); @@ -1278,7 +1280,7 @@ absl::Status DataServiceDispatcherImpl::RestoreSnapshots() return snapshot_status; } -Status DataServiceDispatcherImpl::DisableCompressionAtRuntime( +absl::Status DataServiceDispatcherImpl::DisableCompressionAtRuntime( const DisableCompressionAtRuntimeRequest* request, DisableCompressionAtRuntimeResponse* response) { TF_RETURN_IF_ERROR(CheckStarted()); @@ -1309,7 +1311,7 @@ Status DataServiceDispatcherImpl::DisableCompressionAtRuntime( return absl::OkStatus(); } -Status DataServiceDispatcherImpl::PopulateTaskDef( +absl::Status DataServiceDispatcherImpl::PopulateTaskDef( std::shared_ptr task, TaskDef* task_def) const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { task_def->set_dataset_id(task->iteration->job->dataset_id); @@ -1348,7 +1350,7 @@ Status DataServiceDispatcherImpl::PopulateTaskDef( return absl::OkStatus(); } -Status DataServiceDispatcherImpl::CheckStarted() TF_LOCKS_EXCLUDED(mu_) { +absl::Status DataServiceDispatcherImpl::CheckStarted() TF_LOCKS_EXCLUDED(mu_) { tf_shared_lock l(mu_); if (!started_) { return errors::Unavailable("Dispatcher has not started yet."); @@ -1356,7 +1358,7 @@ Status DataServiceDispatcherImpl::CheckStarted() TF_LOCKS_EXCLUDED(mu_) { return absl::OkStatus(); } -Status DataServiceDispatcherImpl::RecordSplitProduced( +absl::Status DataServiceDispatcherImpl::RecordSplitProduced( int64_t iteration_id, int64_t repetition, int64_t split_provider_index, bool finished) TF_LOCKS_EXCLUDED(mu_) { mutex_lock l(mu_); @@ -1369,12 +1371,12 @@ Status DataServiceDispatcherImpl::RecordSplitProduced( return Apply(update); } -Status DataServiceDispatcherImpl::ApplyWithoutJournaling(const Update& update) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { +absl::Status DataServiceDispatcherImpl::ApplyWithoutJournaling( + const Update& update) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { return state_.Apply(update); } -Status DataServiceDispatcherImpl::Apply(const Update& update) +absl::Status DataServiceDispatcherImpl::Apply(const Update& update) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (journal_writer_.has_value()) { TF_RETURN_IF_ERROR(journal_writer_.value()->Write(update)); @@ -1395,13 +1397,13 @@ void DataServiceDispatcherImpl::MaintenanceThread() { return; } { - Status s = ReleaseMissingClients(); + absl::Status s = ReleaseMissingClients(); if (!s.ok()) { LOG(WARNING) << "Error releasing missing clients: " << s; } } { - Status s = auto_scaler_.UpdateOptimalNumberOfWorkersMetric( + absl::Status s = auto_scaler_.UpdateOptimalNumberOfWorkersMetric( state_.GetNumberOfRegisteredWorkers()); if (!s.ok()) { VLOG(1) << "Error updating the optimal number of workers metric " @@ -1410,7 +1412,7 @@ void DataServiceDispatcherImpl::MaintenanceThread() { } } { - Status s = GcOldIterations(); + absl::Status s = GcOldIterations(); if (!s.ok()) { LOG(WARNING) << "Error garbage collecting old iterations: " << s; } @@ -1424,9 +1426,9 @@ void DataServiceDispatcherImpl::MaintenanceThread() { void DataServiceDispatcherImpl::RemoveClientFromAutoScaler(int64_t client_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { std::shared_ptr iteration; - Status s = state_.IterationForIterationClientId(client_id, iteration); + absl::Status s = state_.IterationForIterationClientId(client_id, iteration); if (s.ok()) { - Status auto_scaler_status = + absl::Status auto_scaler_status = auto_scaler_.RemoveConsumer(iteration->iteration_id, client_id); if (!auto_scaler_status.ok()) { VLOG(1) << "Failed to remove consumer with ID " << client_id @@ -1439,7 +1441,7 @@ void DataServiceDispatcherImpl::RemoveClientFromAutoScaler(int64_t client_id) } } -Status DataServiceDispatcherImpl::ReleaseMissingClients() +absl::Status DataServiceDispatcherImpl::ReleaseMissingClients() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { int64_t now = env_->NowMicros(); for (const auto& client_id : state_.ListActiveClientIds()) { @@ -1463,10 +1465,11 @@ Status DataServiceDispatcherImpl::ReleaseMissingClients() void DataServiceDispatcherImpl::RemoveWorkerFromAutoScaler( const std::string& worker_address) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { std::vector> tasks; - Status tasks_for_worker_status = state_.TasksForWorker(worker_address, tasks); + absl::Status tasks_for_worker_status = + state_.TasksForWorker(worker_address, tasks); if (tasks_for_worker_status.ok()) { for (const auto& task : tasks) { - Status auto_scaler_status = auto_scaler_.RemoveWorker( + absl::Status auto_scaler_status = auto_scaler_.RemoveWorker( task->iteration->iteration_id, worker_address); if (!auto_scaler_status.ok()) { VLOG(1) << "Failed to remove worker with address " << worker_address @@ -1499,7 +1502,7 @@ void DataServiceDispatcherImpl::DetectMissingWorkers() } } -Status DataServiceDispatcherImpl::GcOldIterations() +absl::Status DataServiceDispatcherImpl::GcOldIterations() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { std::vector> iterations = state_.ListIterations(); @@ -1512,7 +1515,7 @@ Status DataServiceDispatcherImpl::GcOldIterations() update.mutable_garbage_collect_iteration()->set_iteration_id( iteration->iteration_id); TF_RETURN_IF_ERROR(state_.Apply(update)); - Status auto_scaler_status = + absl::Status auto_scaler_status = auto_scaler_.UnregisterIteration(iteration->iteration_id); if (!auto_scaler_status.ok()) { VLOG(1) << "Failed to unregister Iteration " << iteration->iteration_id @@ -1538,7 +1541,7 @@ bool DataServiceDispatcherImpl::ShouldGcIteration(const Iteration& iteration, (config_.job_gc_timeout_ms() * 1000)); } -Status DataServiceDispatcherImpl::GetDatasetDef( +absl::Status DataServiceDispatcherImpl::GetDatasetDef( const std::string& dataset_id, std::shared_ptr& dataset_def) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { @@ -1547,7 +1550,7 @@ Status DataServiceDispatcherImpl::GetDatasetDef( return GetDatasetDef(*dataset, dataset_def); } -Status DataServiceDispatcherImpl::GetDatasetDef( +absl::Status DataServiceDispatcherImpl::GetDatasetDef( const Dataset& dataset, std::shared_ptr& dataset_def) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { return dataset_store_->Get(dataset.dataset_id, dataset_def); diff --git a/tensorflow/core/data/service/dispatcher_impl.h b/tensorflow/core/data/service/dispatcher_impl.h index 5f1f31315a49fd..b82b8cb0c89544 100644 --- a/tensorflow/core/data/service/dispatcher_impl.h +++ b/tensorflow/core/data/service/dispatcher_impl.h @@ -140,7 +140,7 @@ class DataServiceDispatcherImpl { // Starts the dispatcher. If there is a journal, this will read from the // journal to restore the dispatcher's state. - Status Start(); + absl::Status Start(); // Stops the dispatcher. After stopping, RPCs should return without blocking. void Stop(); @@ -151,41 +151,45 @@ class DataServiceDispatcherImpl { // See dispatcher.proto for API documentation. /// Worker-facing API. - Status WorkerHeartbeat(const WorkerHeartbeatRequest* request, - WorkerHeartbeatResponse* response); - Status WorkerUpdate(const WorkerUpdateRequest* request, - WorkerUpdateResponse* response); - Status GetDatasetDef(const GetDatasetDefRequest* request, - GetDatasetDefResponse* response); - Status GetSplit(const GetSplitRequest* request, GetSplitResponse* response); + absl::Status WorkerHeartbeat(const WorkerHeartbeatRequest* request, + WorkerHeartbeatResponse* response); + absl::Status WorkerUpdate(const WorkerUpdateRequest* request, + WorkerUpdateResponse* response); + absl::Status GetDatasetDef(const GetDatasetDefRequest* request, + GetDatasetDefResponse* response); + absl::Status GetSplit(const GetSplitRequest* request, + GetSplitResponse* response); /// Client-facing API. - Status GetVersion(const GetVersionRequest* request, - GetVersionResponse* response); - Status GetOrRegisterDataset(const GetOrRegisterDatasetRequest* request, - GetOrRegisterDatasetResponse* response); - Status GetDataServiceMetadata(const GetDataServiceMetadataRequest* request, - GetDataServiceMetadataResponse* response); - Status GetDataServiceConfig(const GetDataServiceConfigRequest* request, - GetDataServiceConfigResponse* response); - Status GetOrCreateJob(const GetOrCreateJobRequest* request, - GetOrCreateJobResponse* response); - Status GetOrCreateIteration(const GetOrCreateIterationRequest* request, - GetOrCreateIterationResponse* response); - Status ReleaseIterationClient(const ReleaseIterationClientRequest* request, - ReleaseIterationClientResponse* response); - Status MaybeRemoveTask(const MaybeRemoveTaskRequest* request, - MaybeRemoveTaskResponse* response); - Status ClientHeartbeat(const ClientHeartbeatRequest* request, - ClientHeartbeatResponse* response); - Status GetWorkers(const GetWorkersRequest* request, - GetWorkersResponse* response); - Status Snapshot(const SnapshotRequest* request, SnapshotResponse* response); - Status GetSnapshotSplit(const GetSnapshotSplitRequest* request, - GetSnapshotSplitResponse* response); - Status GetSnapshotStreams(const GetSnapshotStreamsRequest* request, - GetSnapshotStreamsResponse* response); - Status DisableCompressionAtRuntime( + absl::Status GetVersion(const GetVersionRequest* request, + GetVersionResponse* response); + absl::Status GetOrRegisterDataset(const GetOrRegisterDatasetRequest* request, + GetOrRegisterDatasetResponse* response); + absl::Status GetDataServiceMetadata( + const GetDataServiceMetadataRequest* request, + GetDataServiceMetadataResponse* response); + absl::Status GetDataServiceConfig(const GetDataServiceConfigRequest* request, + GetDataServiceConfigResponse* response); + absl::Status GetOrCreateJob(const GetOrCreateJobRequest* request, + GetOrCreateJobResponse* response); + absl::Status GetOrCreateIteration(const GetOrCreateIterationRequest* request, + GetOrCreateIterationResponse* response); + absl::Status ReleaseIterationClient( + const ReleaseIterationClientRequest* request, + ReleaseIterationClientResponse* response); + absl::Status MaybeRemoveTask(const MaybeRemoveTaskRequest* request, + MaybeRemoveTaskResponse* response); + absl::Status ClientHeartbeat(const ClientHeartbeatRequest* request, + ClientHeartbeatResponse* response); + absl::Status GetWorkers(const GetWorkersRequest* request, + GetWorkersResponse* response); + absl::Status Snapshot(const SnapshotRequest* request, + SnapshotResponse* response); + absl::Status GetSnapshotSplit(const GetSnapshotSplitRequest* request, + GetSnapshotSplitResponse* response); + absl::Status GetSnapshotStreams(const GetSnapshotStreamsRequest* request, + GetSnapshotStreamsResponse* response); + absl::Status DisableCompressionAtRuntime( const DisableCompressionAtRuntimeRequest* request, DisableCompressionAtRuntimeResponse* response); @@ -199,21 +203,21 @@ class DataServiceDispatcherImpl { // Restores split providers from the state in `iteration` and stores them in // `restored`. - Status RestoreSplitProviders( + absl::Status RestoreSplitProviders( const DispatcherState::Iteration& iteration, std::vector>& restored) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Makes split providers for the specified `dataset_id`, and stores them in // `split_providers`. - Status MakeSplitProviders( + absl::Status MakeSplitProviders( const std::string& dataset_id, std::vector>& split_providers) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Registers a dataset, storing the new dataset's id in `dataset_id`. - Status RegisterDataset(const DatasetDef& dataset, - const DataServiceMetadata& metadata, - const std::string& requested_dataset_id, - std::string& dataset_id) + absl::Status RegisterDataset(const DatasetDef& dataset, + const DataServiceMetadata& metadata, + const std::string& requested_dataset_id, + std::string& dataset_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Finds the dataset ID with the requested dataset ID. // Returns nullptr if no such dataset exists. @@ -222,34 +226,34 @@ class DataServiceDispatcherImpl { // Gets a worker's stub from `worker_stubs_`, or if none exists, creates a // stub and stores it in `worker_stubs_`. A borrowed pointer to the stub is // stored in `out_stub`. - Status GetOrCreateWorkerStub(const std::string& worker_address, - WorkerService::Stub*& out_stub) + absl::Status GetOrCreateWorkerStub(const std::string& worker_address, + WorkerService::Stub*& out_stub) TF_LOCKS_EXCLUDED(mu_); // Creates a job and stores it in `job`. - Status CreateJob(const std::string& job_name, - const GetOrCreateJobRequest& request, - std::shared_ptr& job) + absl::Status CreateJob(const std::string& job_name, + const GetOrCreateJobRequest& request, + std::shared_ptr& job) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Creates an iteration and stores it in `iteration`. This method updates the // dispatcher state with the new iteration, but does not assign tasks to // workers. - Status CreateIteration( + absl::Status CreateIteration( const GetOrCreateIterationRequest& request, std::shared_ptr& iteration) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Creates tasks for the specified worker, one task for every unfinished // iteration. - Status CreateTasksForWorker(const std::string& worker_address); + absl::Status CreateTasksForWorker(const std::string& worker_address); // Finds tasks that should be deleted from a worker, updating the heartbeat // response. - Status FindTasksToDelete( + absl::Status FindTasksToDelete( const absl::flat_hash_set& current_tasks, const std::vector>& assigned_tasks, WorkerHeartbeatResponse* response); // Finds new tasks that should be assigned to a worker and adds them to // the heartbeat response. - Status FindNewTasks( + absl::Status FindNewTasks( const std::string& worker_address, const absl::flat_hash_set& current_tasks, std::vector>& assigned_tasks, @@ -260,71 +264,72 @@ class DataServiceDispatcherImpl { const std::string& worker_address) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Acquires an iteration client id to read from the given iteration and sets // `iteration_client_id`. - Status AcquireIterationClientId( + absl::Status AcquireIterationClientId( const std::shared_ptr& iteration, int64_t& iteration_client_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Creates one task for each worker, for the given iteration. The created // tasks are stored in `tasks`. This method only updates dispatcher metadata // with the new tasks, but doesn't assign the tasks to the workers. - Status CreateTasksForIteration( + absl::Status CreateTasksForIteration( std::shared_ptr iteration, std::vector>& tasks) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Creates a new task for an iteration. The created task may be either // pending or active. - Status CreateTask(std::shared_ptr iteration, - const std::string& worker_address, - std::shared_ptr& task) + absl::Status CreateTask( + std::shared_ptr iteration, + const std::string& worker_address, + std::shared_ptr& task) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Creates a pending task for a round robin iteration. All consumers need to // agree on which round to add the task in before the pending task can be // promoted to a regular task. - Status CreatePendingTask( + absl::Status CreatePendingTask( std::shared_ptr iteration, const std::string& worker_address) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Creates a new active task for an iteration, storing the created task in // `task`. - Status CreateActiveTask( + absl::Status CreateActiveTask( std::shared_ptr iteration, const std::string& worker_address, std::shared_ptr& task); // Assigns the list of tasks to the workers indicated by their // `worker_address` fields. - Status AssignTasks( + absl::Status AssignTasks( std::vector> tasks) TF_LOCKS_EXCLUDED(mu_); // Assigns a task to the worker indicated by its `worker_address` field. - Status AssignTask(std::shared_ptr task) + absl::Status AssignTask(std::shared_ptr task) TF_LOCKS_EXCLUDED(mu_); // Validates that an existing job matches a given request. // Returns an error status describing any difference. - Status ValidateMatchingJob(std::shared_ptr job, - const GetOrCreateJobRequest& request) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + absl::Status ValidateMatchingJob( + std::shared_ptr job, + const GetOrCreateJobRequest& request) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Fills out a TaskDef with information about a task. - Status PopulateTaskDef(std::shared_ptr task, - TaskDef* task_def) const - TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + absl::Status PopulateTaskDef( + std::shared_ptr task, + TaskDef* task_def) const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Checks that the dispatcher has started, returning UNAVAILABLE if it hasn't. - Status CheckStarted() TF_LOCKS_EXCLUDED(mu_); + absl::Status CheckStarted() TF_LOCKS_EXCLUDED(mu_); // Restores ongoing tf.data snapshots. absl::Status RestoreSnapshots(); // Records that a split was produced by a call to `GetSplit`. - Status RecordSplitProduced(int64_t iteration_id, int64_t repetition, - int64_t split_provider_index, bool finished) + absl::Status RecordSplitProduced(int64_t iteration_id, int64_t repetition, + int64_t split_provider_index, bool finished) TF_LOCKS_EXCLUDED(mu_); // Applies a state update, updating both the journal and the in-memory state. - Status Apply(const Update& update) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + absl::Status Apply(const Update& update) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Applies a state update, but doesn't update the journal. Only meant to be // used when recovering state when the dispatcher starts. - Status ApplyWithoutJournaling(const Update& update) + absl::Status ApplyWithoutJournaling(const Update& update) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Removes the client with `client_id` from `auto_scaler_` void RemoveClientFromAutoScaler(int64_t client_id) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Releases iteration clients that haven't heartbeated recently. - Status ReleaseMissingClients() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + absl::Status ReleaseMissingClients() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Removes the worker with `worker_address` from `auto_scaler_`, which is // potentially associated with multiple iterations. void RemoveWorkerFromAutoScaler(const std::string& worker_address) @@ -333,19 +338,19 @@ class DataServiceDispatcherImpl { // snapshot managers. void DetectMissingWorkers() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Scans for old iterations and marks them as finished. - Status GcOldIterations() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + absl::Status GcOldIterations() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Returns true if an iteration should be garbage collected. bool ShouldGcIteration(const DispatcherState::Iteration& iteration, int64_t now_us) const; // Gets a `DatasetDef` from `dataset_store_` for the given dataset id, and // stores it in `dataset_def`. - Status GetDatasetDef(const std::string& dataset_id, - std::shared_ptr& dataset_def) + absl::Status GetDatasetDef(const std::string& dataset_id, + std::shared_ptr& dataset_def) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Gets a `DatasetDef` from `dataset_store_` for the given dataset, and // stores it in `dataset_def`. - Status GetDatasetDef(const DispatcherState::Dataset& dataset, - std::shared_ptr& dataset_def) + absl::Status GetDatasetDef(const DispatcherState::Dataset& dataset, + std::shared_ptr& dataset_def) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); const experimental::DispatcherConfig config_; diff --git a/tensorflow/core/data/service/dispatcher_state.cc b/tensorflow/core/data/service/dispatcher_state.cc index 22ab9ff2aeb988..72552f51a033d2 100644 --- a/tensorflow/core/data/service/dispatcher_state.cc +++ b/tensorflow/core/data/service/dispatcher_state.cc @@ -41,7 +41,7 @@ DispatcherState::DispatcherState( const experimental::DispatcherConfig& dispatcher_config) : worker_index_resolver_(dispatcher_config.worker_addresses()) {} -Status DispatcherState::Apply(const Update& update) { +absl::Status DispatcherState::Apply(const Update& update) { switch (update.update_type_case()) { case Update::kRegisterDataset: RegisterDataset(update.register_dataset()); @@ -134,8 +134,8 @@ void DispatcherState::CreateJob(const CreateJobUpdate& create_job) { next_available_job_id_ = std::max(next_available_job_id_, job_id + 1); } -Status DispatcherState::JobFromId(int64_t job_id, - std::shared_ptr& job) const { +absl::Status DispatcherState::JobFromId(int64_t job_id, + std::shared_ptr& job) const { auto it = jobs_by_id_.find(job_id); if (it == jobs_by_id_.end()) { return errors::NotFound("Job with id ", job_id, " not found"); @@ -144,8 +144,8 @@ Status DispatcherState::JobFromId(int64_t job_id, return absl::OkStatus(); } -Status DispatcherState::JobByName(const std::string& job_name, - std::shared_ptr& job) const { +absl::Status DispatcherState::JobByName(const std::string& job_name, + std::shared_ptr& job) const { auto it = jobs_by_name_.find(job_name); if (it == jobs_by_name_.end()) { return errors::NotFound("Job with name ", job_name, " not found"); @@ -323,7 +323,7 @@ void DispatcherState::UpdateNextAvailableDatasetId() { } } -Status DispatcherState::DatasetFromId( +absl::Status DispatcherState::DatasetFromId( const std::string& id, std::shared_ptr& dataset) const { auto it = datasets_by_id_.find(id); if (it == datasets_by_id_.end()) { @@ -333,7 +333,7 @@ Status DispatcherState::DatasetFromId( return absl::OkStatus(); } -Status DispatcherState::WorkerFromAddress( +absl::Status DispatcherState::WorkerFromAddress( const std::string& address, std::shared_ptr& worker) const { auto it = workers_.find(address); if (it == workers_.end()) { @@ -363,7 +363,7 @@ DispatcherState::ListIterations() const { return iterations; } -Status DispatcherState::IterationFromId( +absl::Status DispatcherState::IterationFromId( int64_t id, std::shared_ptr& iteration) const { auto it = iterations_.find(id); if (it == iterations_.end()) { @@ -373,7 +373,7 @@ Status DispatcherState::IterationFromId( return absl::OkStatus(); } -Status DispatcherState::IterationByKey( +absl::Status DispatcherState::IterationByKey( IterationKey iteration_key, std::shared_ptr& iteration) const { auto it = iterations_by_key_.find(iteration_key); @@ -393,7 +393,7 @@ int64_t DispatcherState::NextAvailableIterationId() const { return next_available_iteration_id_; } -Status DispatcherState::IterationForIterationClientId( +absl::Status DispatcherState::IterationForIterationClientId( int64_t iteration_client_id, std::shared_ptr& iteration) { iteration = iterations_for_client_ids_[iteration_client_id]; if (!iteration) { @@ -417,8 +417,8 @@ int64_t DispatcherState::NextAvailableIterationClientId() const { return next_available_iteration_client_id_; } -Status DispatcherState::TaskFromId(int64_t id, - std::shared_ptr& task) const { +absl::Status DispatcherState::TaskFromId( + int64_t id, std::shared_ptr& task) const { auto it = tasks_.find(id); if (it == tasks_.end()) { return errors::NotFound("Task ", id, " not found"); @@ -427,7 +427,7 @@ Status DispatcherState::TaskFromId(int64_t id, return absl::OkStatus(); } -Status DispatcherState::TasksForIteration( +absl::Status DispatcherState::TasksForIteration( int64_t iteration_id, std::vector>& tasks) const { auto it = tasks_by_iteration_.find(iteration_id); @@ -442,7 +442,7 @@ Status DispatcherState::TasksForIteration( return absl::OkStatus(); } -Status DispatcherState::TasksForWorker( +absl::Status DispatcherState::TasksForWorker( absl::string_view worker_address, std::vector>& tasks) const { tasks.clear(); @@ -463,7 +463,8 @@ int64_t DispatcherState::NextAvailableTaskId() const { return next_available_task_id_; } -Status DispatcherState::ValidateWorker(absl::string_view worker_address) const { +absl::Status DispatcherState::ValidateWorker( + absl::string_view worker_address) const { return worker_index_resolver_.ValidateWorker(worker_address); } diff --git a/tensorflow/core/data/service/dispatcher_state.h b/tensorflow/core/data/service/dispatcher_state.h index e64b48771400ad..054c32037b8b49 100644 --- a/tensorflow/core/data/service/dispatcher_state.h +++ b/tensorflow/core/data/service/dispatcher_state.h @@ -71,7 +71,7 @@ class DispatcherState { DispatcherState& operator=(const DispatcherState&) = delete; // Applies the given update to the dispatcher's state. - Status Apply(const Update& update); + absl::Status Apply(const Update& update); // A dataset registered with the dispatcher. struct Dataset { @@ -229,38 +229,38 @@ class DispatcherState { std::string NextAvailableDatasetId() const; // Gets a dataset by id. Returns NOT_FOUND if there is no such dataset. - Status DatasetFromId(const std::string& id, - std::shared_ptr& dataset) const; + absl::Status DatasetFromId(const std::string& id, + std::shared_ptr& dataset) const; // Gets a worker by address. Returns NOT_FOUND if there is no such worker. - Status WorkerFromAddress(const std::string& address, - std::shared_ptr& worker) const; + absl::Status WorkerFromAddress(const std::string& address, + std::shared_ptr& worker) const; // Lists all workers registered with the dispatcher. std::vector> ListWorkers() const; // Returns the next available job id. int64_t NextAvailableJobId() const; // Gets a job by id. Returns NOT_FOUND if there is no such job. - Status JobFromId(int64_t job_id, std::shared_ptr& job) const; + absl::Status JobFromId(int64_t job_id, std::shared_ptr& job) const; // Gets a job by name. Returns NOT_FOUND if there is no such job. - Status JobByName(const std::string& job_name, - std::shared_ptr& job) const; + absl::Status JobByName(const std::string& job_name, + std::shared_ptr& job) const; // Returns the next available iteration id. int64_t NextAvailableIterationId() const; // Returns a list of all iterations. std::vector> ListIterations() const; // Gets an iteration by id. Returns NOT_FOUND if there is no such iteration. - Status IterationFromId(int64_t id, - std::shared_ptr& iteration) const; + absl::Status IterationFromId( + int64_t id, std::shared_ptr& iteration) const; // Gets an iteration by key. Returns NOT_FOUND if there is no such iteration. - Status IterationByKey(IterationKey key, - std::shared_ptr& iteration) const; + absl::Status IterationByKey( + IterationKey key, std::shared_ptr& iteration) const; // Returns the iteration associated with the given iteration client id. // Returns NOT_FOUND if the iteration_client_id is unknown or has been // released. - Status IterationForIterationClientId( + absl::Status IterationForIterationClientId( int64_t iteration_client_id, std::shared_ptr& iteration); // Returns a list of all active client ids. std::vector ListActiveClientIds(); @@ -270,20 +270,21 @@ class DispatcherState { // Returns the next available task id. int64_t NextAvailableTaskId() const; // Gets a task by id. Returns NOT_FOUND if there is no such task. - Status TaskFromId(int64_t id, std::shared_ptr& task) const; + absl::Status TaskFromId(int64_t id, std::shared_ptr& task) const; // Stores a list of all tasks for the given iteration to `tasks`. Returns // NOT_FOUND if there is no such iteration. - Status TasksForIteration( + absl::Status TasksForIteration( int64_t iteration_id, std::vector>& tasks) const; // Stores a list of all tasks for the given worker to `tasks`. Returns // NOT_FOUND if there is no such worker. - Status TasksForWorker(const absl::string_view worker_address, - std::vector>& tasks) const; + absl::Status TasksForWorker( + const absl::string_view worker_address, + std::vector>& tasks) const; // If the dispatcher config explicitly specifies a list of workers, validates // `worker_address` is in the list. - Status ValidateWorker(absl::string_view worker_address) const; + absl::Status ValidateWorker(absl::string_view worker_address) const; // If the dispatcher config specifies worker addresses, `GetWorkerIndex` // returns the worker index according to the list. This is useful for diff --git a/tensorflow/core/data/service/dispatcher_state_test.cc b/tensorflow/core/data/service/dispatcher_state_test.cc index e561ecb4dd08c2..a2b8408402a8e8 100644 --- a/tensorflow/core/data/service/dispatcher_state_test.cc +++ b/tensorflow/core/data/service/dispatcher_state_test.cc @@ -46,21 +46,23 @@ using ::testing::SizeIs; using ::testing::UnorderedElementsAre; using ::tsl::testing::StatusIs; -Status RegisterDataset(const std::string& dataset_id, DispatcherState& state) { +absl::Status RegisterDataset(const std::string& dataset_id, + DispatcherState& state) { Update update; RegisterDatasetUpdate* register_dataset = update.mutable_register_dataset(); register_dataset->set_dataset_id(dataset_id); return state.Apply(update); } -Status RegisterWorker(std::string worker_address, DispatcherState& state) { +absl::Status RegisterWorker(std::string worker_address, + DispatcherState& state) { Update update; update.mutable_register_worker()->set_worker_address(worker_address); return state.Apply(update); } -Status CreateJob(int64_t job_id, const std::string& dataset_id, - const std::string& job_name, DispatcherState& state) { +absl::Status CreateJob(int64_t job_id, const std::string& dataset_id, + const std::string& job_name, DispatcherState& state) { Update update; CreateJobUpdate* create_job = update.mutable_create_job(); create_job->set_job_id(job_id); @@ -69,9 +71,10 @@ Status CreateJob(int64_t job_id, const std::string& dataset_id, return state.Apply(update); } -Status CreateIteration(int64_t iteration_id, const std::string& dataset_id, - const IterationKey& named_iteration_key, - DispatcherState& state) { +absl::Status CreateIteration(int64_t iteration_id, + const std::string& dataset_id, + const IterationKey& named_iteration_key, + DispatcherState& state) { int64_t job_id = state.NextAvailableJobId(); TF_RETURN_IF_ERROR( CreateJob(job_id, dataset_id, named_iteration_key.name, state)); @@ -83,15 +86,16 @@ Status CreateIteration(int64_t iteration_id, const std::string& dataset_id, return state.Apply(update); } -Status CreateIteration(int64_t iteration_id, const std::string& dataset_id, - DispatcherState& state) { +absl::Status CreateIteration(int64_t iteration_id, + const std::string& dataset_id, + DispatcherState& state) { IterationKey key(/*name=*/absl::StrCat(random::New64()), /*repetition=*/0); return CreateIteration(iteration_id, dataset_id, key, state); } -Status AcquireIterationClientId(int64_t iteration_id, - int64_t iteration_client_id, - DispatcherState& state) { +absl::Status AcquireIterationClientId(int64_t iteration_id, + int64_t iteration_client_id, + DispatcherState& state) { Update update; AcquireIterationClientUpdate* acquire_iteration_client = update.mutable_acquire_iteration_client(); @@ -100,8 +104,9 @@ Status AcquireIterationClientId(int64_t iteration_id, return state.Apply(update); } -Status ReleaseIterationClientId(int64_t iteration_client_id, - int64_t release_time, DispatcherState& state) { +absl::Status ReleaseIterationClientId(int64_t iteration_client_id, + int64_t release_time, + DispatcherState& state) { Update update; ReleaseIterationClientUpdate* release_iteration_client = update.mutable_release_iteration_client(); @@ -110,8 +115,9 @@ Status ReleaseIterationClientId(int64_t iteration_client_id, return state.Apply(update); } -Status CreateTask(int64_t task_id, int64_t iteration_id, - const std::string& worker_address, DispatcherState& state) { +absl::Status CreateTask(int64_t task_id, int64_t iteration_id, + const std::string& worker_address, + DispatcherState& state) { Update update; CreateTaskUpdate* create_task = update.mutable_create_task(); create_task->set_task_id(task_id); @@ -120,14 +126,14 @@ Status CreateTask(int64_t task_id, int64_t iteration_id, return state.Apply(update); } -Status FinishTask(int64_t task_id, DispatcherState& state) { +absl::Status FinishTask(int64_t task_id, DispatcherState& state) { Update update; FinishTaskUpdate* finish_task = update.mutable_finish_task(); finish_task->set_task_id(task_id); return state.Apply(update); } -Status Snapshot(const std::string& path, DispatcherState& state) { +absl::Status Snapshot(const std::string& path, DispatcherState& state) { Update update; SnapshotUpdate* snapshot = update.mutable_snapshot(); snapshot->set_path(path); @@ -205,7 +211,7 @@ TEST(DispatcherState, RegisterDatasetElementSpec) { TEST(DispatcherState, MissingDatasetId) { DispatcherState state; std::shared_ptr dataset; - Status s = state.DatasetFromId("missing_dataset_id", dataset); + absl::Status s = state.DatasetFromId("missing_dataset_id", dataset); EXPECT_EQ(s.code(), error::NOT_FOUND); } @@ -293,14 +299,14 @@ TEST(DispatcherState, ListWorkers) { TEST(DispatcherState, MissingWorker) { DispatcherState state; std::shared_ptr worker; - Status s = state.WorkerFromAddress("test_worker_address", worker); + absl::Status s = state.WorkerFromAddress("test_worker_address", worker); EXPECT_EQ(s.code(), error::NOT_FOUND); } TEST(DispatcherState, UnknownUpdate) { DispatcherState state; Update update; - Status s = state.Apply(update); + absl::Status s = state.Apply(update); EXPECT_EQ(s.code(), error::INTERNAL); } @@ -588,7 +594,7 @@ TEST(DispatcherState, ReleaseIterationClientId) { std::shared_ptr iteration; TF_EXPECT_OK(state.IterationFromId(iteration_id, iteration)); EXPECT_EQ(iteration->num_clients, 0); - Status s = + absl::Status s = state.IterationForIterationClientId(iteration_client_id, iteration); EXPECT_EQ(s.code(), error::NOT_FOUND); } diff --git a/tensorflow/core/data/service/graph_rewriters.cc b/tensorflow/core/data/service/graph_rewriters.cc index af2059ae89c707..c9d3c9b405bcbf 100644 --- a/tensorflow/core/data/service/graph_rewriters.cc +++ b/tensorflow/core/data/service/graph_rewriters.cc @@ -182,7 +182,7 @@ AutoShardRewriter::GetRewriteConfig() const { return config; } -Status WorkerIndexResolver::ValidateWorker( +absl::Status WorkerIndexResolver::ValidateWorker( absl::string_view worker_address) const { if (worker_addresses_.empty()) { return absl::OkStatus(); diff --git a/tensorflow/core/data/service/graph_rewriters.h b/tensorflow/core/data/service/graph_rewriters.h index 84c43a4f29d579..bdcf630bc88909 100644 --- a/tensorflow/core/data/service/graph_rewriters.h +++ b/tensorflow/core/data/service/graph_rewriters.h @@ -89,7 +89,7 @@ class WorkerIndexResolver { // Validates `worker_address`. Returns an error if the `worker_addresses` list // is non-empty and `worker_address` is not specified in the worker addresses // list (with optional port replacement). - Status ValidateWorker(absl::string_view worker_address) const; + absl::Status ValidateWorker(absl::string_view worker_address) const; // Processes a worker at address `worker_address`. Its index can be retrieved // by calling `GetWorkerIndex`. diff --git a/tensorflow/core/data/service/grpc_dispatcher_impl.cc b/tensorflow/core/data/service/grpc_dispatcher_impl.cc index 5b87a081c16fdf..45a46b94d9faff 100644 --- a/tensorflow/core/data/service/grpc_dispatcher_impl.cc +++ b/tensorflow/core/data/service/grpc_dispatcher_impl.cc @@ -33,7 +33,7 @@ GrpcDispatcherImpl::GrpcDispatcherImpl( VLOG(1) << "Registered data service dispatcher"; } -Status GrpcDispatcherImpl::Start() { return impl_.Start(); } +absl::Status GrpcDispatcherImpl::Start() { return impl_.Start(); } void GrpcDispatcherImpl::Stop() { impl_.Stop(); } diff --git a/tensorflow/core/data/service/grpc_dispatcher_impl.h b/tensorflow/core/data/service/grpc_dispatcher_impl.h index 4ef4dd24de912c..50d5e2c3eec8af 100644 --- a/tensorflow/core/data/service/grpc_dispatcher_impl.h +++ b/tensorflow/core/data/service/grpc_dispatcher_impl.h @@ -34,7 +34,7 @@ class GrpcDispatcherImpl : public DispatcherService::Service { ::grpc::ServerBuilder& server_builder); ~GrpcDispatcherImpl() override { Stop(); } - Status Start(); + absl::Status Start(); void Stop(); size_t NumActiveIterations(); diff --git a/tensorflow/core/data/service/grpc_dispatcher_impl_test.cc b/tensorflow/core/data/service/grpc_dispatcher_impl_test.cc index 393e754cdf524a..c04cdf7a718456 100644 --- a/tensorflow/core/data/service/grpc_dispatcher_impl_test.cc +++ b/tensorflow/core/data/service/grpc_dispatcher_impl_test.cc @@ -60,14 +60,14 @@ class GrpcDispatcherImplTest : public ::testing::Test { TF_ASSERT_OK(SetUpDispatcherClientStub()); } - Status SetUpDispatcherServer() { + absl::Status SetUpDispatcherServer() { experimental::DispatcherConfig config; config.set_protocol(kProtocol); TF_RETURN_IF_ERROR(NewDispatchServer(config, dispatcher_server_)); return dispatcher_server_->Start(); } - Status SetUpDispatcherClientStub() { + absl::Status SetUpDispatcherClientStub() { std::shared_ptr credentials; TF_RETURN_IF_ERROR( CredentialsFactory::CreateClientCredentials(kProtocol, &credentials)); diff --git a/tensorflow/core/data/service/grpc_util.cc b/tensorflow/core/data/service/grpc_util.cc index daa8cbe8edef29..c85312477b3a92 100644 --- a/tensorflow/core/data/service/grpc_util.cc +++ b/tensorflow/core/data/service/grpc_util.cc @@ -35,7 +35,8 @@ namespace grpc_util { constexpr char kStreamRemovedMessage[] = "Stream removed"; -Status WrapError(const std::string& message, const ::grpc::Status& status) { +absl::Status WrapError(const std::string& message, + const ::grpc::Status& status) { if (status.ok()) { return errors::Internal("Expected a non-ok grpc status. Wrapping message: ", message); @@ -45,24 +46,25 @@ Status WrapError(const std::string& message, const ::grpc::Status& status) { // errors use other status codes (b/258285154). // TODO(aaudibert): Upstream this to FromGrpcStatus. if (status.error_message() == kStreamRemovedMessage) { - return Status(absl::StatusCode::kUnavailable, kStreamRemovedMessage); + return absl::Status(absl::StatusCode::kUnavailable, + kStreamRemovedMessage); } - Status s = FromGrpcStatus(status); - return Status(s.code(), - absl::StrCat(message, ": ", status.error_message())); + absl::Status s = FromGrpcStatus(status); + return absl::Status(s.code(), + absl::StrCat(message, ": ", status.error_message())); } } -Status Retry(const std::function& f, const std::string& description, - int64_t deadline_micros) { +absl::Status Retry(const std::function& f, + const std::string& description, int64_t deadline_micros) { return Retry( f, [] { return true; }, description, deadline_micros); } -Status Retry(const std::function& f, - const std::function& should_retry, - const std::string& description, int64_t deadline_micros) { - Status s = f(); +absl::Status Retry(const std::function& f, + const std::function& should_retry, + const std::string& description, int64_t deadline_micros) { + absl::Status s = f(); for (int num_retries = 0;; ++num_retries) { if (!IsPreemptedError(s)) { return s; diff --git a/tensorflow/core/data/service/grpc_util.h b/tensorflow/core/data/service/grpc_util.h index 79c52f63e31486..8fff6312c63d8a 100644 --- a/tensorflow/core/data/service/grpc_util.h +++ b/tensorflow/core/data/service/grpc_util.h @@ -27,7 +27,8 @@ namespace data { namespace grpc_util { // Wraps a grpc::Status in a tensorflow::Status with the given message. -Status WrapError(const std::string& message, const ::grpc::Status& status); +absl::Status WrapError(const std::string& message, + const ::grpc::Status& status); // Retries the given function if the function produces UNAVAILABLE, ABORTED, or // CANCELLED status codes. We retry these codes because they can all indicate @@ -37,14 +38,14 @@ Status WrapError(const std::string& message, const ::grpc::Status& status); // being retried, e.g. "register dataset" The retry loop uses exponential // backoff between retries. `deadline_micros` is interpreted as microseconds // since the epoch. -Status Retry(const std::function& f, - const std::function& should_retry, - const std::string& description, int64_t deadline_micros); +absl::Status Retry(const std::function& f, + const std::function& should_retry, + const std::string& description, int64_t deadline_micros); // Same as `Retry` above, but with a `should_retry` callback that always returns // `true`. -Status Retry(const std::function& f, const std::string& description, - int64_t deadline_micros); +absl::Status Retry(const std::function& f, + const std::string& description, int64_t deadline_micros); } // namespace grpc_util } // namespace data diff --git a/tensorflow/core/data/service/grpc_util_test.cc b/tensorflow/core/data/service/grpc_util_test.cc index 47a0b4c4d89dec..066d895ef3644b 100644 --- a/tensorflow/core/data/service/grpc_util_test.cc +++ b/tensorflow/core/data/service/grpc_util_test.cc @@ -23,13 +23,13 @@ namespace grpc_util { TEST(GrpcUtil, WrapInvalidArgument) { grpc::Status s(grpc::StatusCode::INVALID_ARGUMENT, "test message"); - Status wrapped = WrapError("wrapping message", s); + absl::Status wrapped = WrapError("wrapping message", s); ASSERT_EQ(wrapped, errors::InvalidArgument("wrapping message: test message")); } TEST(GrpcUtil, WrapOk) { grpc::Status s; - Status wrapped = WrapError("wrapping message", s); + absl::Status wrapped = WrapError("wrapping message", s); ASSERT_EQ(wrapped, errors::Internal("Expected a non-ok grpc status. Wrapping " "message: wrapping message")); } diff --git a/tensorflow/core/data/service/grpc_worker_impl.cc b/tensorflow/core/data/service/grpc_worker_impl.cc index d83879bed599b9..a1f88e2ab456be 100644 --- a/tensorflow/core/data/service/grpc_worker_impl.cc +++ b/tensorflow/core/data/service/grpc_worker_impl.cc @@ -41,7 +41,7 @@ GrpcWorkerImpl::GrpcWorkerImpl(const experimental::WorkerConfig& config, VLOG(1) << "Registered data service worker"; } -Status GrpcWorkerImpl::Start( +absl::Status GrpcWorkerImpl::Start( const std::string& worker_address, const std::vector& transfer_servers) { worker_address_ = worker_address; diff --git a/tensorflow/core/data/service/grpc_worker_impl.h b/tensorflow/core/data/service/grpc_worker_impl.h index 9969aac48ed910..4513c0ca6ce8fb 100644 --- a/tensorflow/core/data/service/grpc_worker_impl.h +++ b/tensorflow/core/data/service/grpc_worker_impl.h @@ -41,11 +41,12 @@ class GrpcWorkerImpl : public WorkerService::Service { ::grpc::ServerBuilder& server_builder); ~GrpcWorkerImpl() override { Stop(); } - Status Start(const std::string& worker_address, - const std::vector& transfer_servers); + absl::Status Start( + const std::string& worker_address, + const std::vector& transfer_servers); void Stop(); - std::function + std::function get_element_getter() { return [this](const GetElementRequest* request, GetElementResult* result) { return impl_->GetElementResult(request, result); diff --git a/tensorflow/core/data/service/grpc_worker_impl_test.cc b/tensorflow/core/data/service/grpc_worker_impl_test.cc index 062117c94999d6..23eb6989c8cb1a 100644 --- a/tensorflow/core/data/service/grpc_worker_impl_test.cc +++ b/tensorflow/core/data/service/grpc_worker_impl_test.cc @@ -62,14 +62,14 @@ class GrpcWorkerImplTest : public ::testing::Test { TF_ASSERT_OK(SetUpWorkerClientStub()); } - Status SetUpDispatcherServer() { + absl::Status SetUpDispatcherServer() { experimental::DispatcherConfig config; config.set_protocol(kProtocol); TF_RETURN_IF_ERROR(NewDispatchServer(config, dispatcher_server_)); return dispatcher_server_->Start(); } - Status SetUpWorkerServer() { + absl::Status SetUpWorkerServer() { experimental::WorkerConfig config; config.set_protocol(kProtocol); config.set_dispatcher_address(GetDispatcherAddress()); @@ -78,7 +78,7 @@ class GrpcWorkerImplTest : public ::testing::Test { return worker_server_->Start(); } - Status SetUpWorkerClientStub() { + absl::Status SetUpWorkerClientStub() { std::shared_ptr credentials; TF_RETURN_IF_ERROR( CredentialsFactory::CreateClientCredentials(kProtocol, &credentials)); diff --git a/tensorflow/core/data/service/journal.cc b/tensorflow/core/data/service/journal.cc index d8ecd7adafce52..0462657e2363f7 100644 --- a/tensorflow/core/data/service/journal.cc +++ b/tensorflow/core/data/service/journal.cc @@ -36,8 +36,8 @@ namespace data { namespace { constexpr StringPiece kJournal = "journal"; -Status ParseSequenceNumber(const std::string& journal_file, - int64_t* sequence_number) { +absl::Status ParseSequenceNumber(const std::string& journal_file, + int64_t* sequence_number) { if (!RE2::FullMatch(journal_file, ".*_(\\d+)", sequence_number)) { return errors::InvalidArgument("Failed to parse journal file name: ", journal_file); @@ -55,7 +55,7 @@ std::string DataServiceJournalFile(const std::string& journal_dir, FileJournalWriter::FileJournalWriter(Env* env, const std::string& journal_dir) : env_(env), journal_dir_(journal_dir) {} -Status FileJournalWriter::EnsureInitialized() { +absl::Status FileJournalWriter::EnsureInitialized() { if (writer_) { return absl::OkStatus(); } @@ -76,7 +76,7 @@ Status FileJournalWriter::EnsureInitialized() { return absl::OkStatus(); } -Status FileJournalWriter::Write(const Update& update) { +absl::Status FileJournalWriter::Write(const Update& update) { TF_RETURN_IF_ERROR(EnsureInitialized()); std::string s = update.SerializeAsString(); if (s.empty()) { @@ -95,18 +95,18 @@ Status FileJournalWriter::Write(const Update& update) { FileJournalReader::FileJournalReader(Env* env, StringPiece journal_dir) : env_(env), journal_dir_(journal_dir) {} -Status FileJournalReader::EnsureInitialized() { +absl::Status FileJournalReader::EnsureInitialized() { if (reader_) { return absl::OkStatus(); } return UpdateFile(DataServiceJournalFile(journal_dir_, 0)); } -Status FileJournalReader::Read(Update& update, bool& end_of_journal) { +absl::Status FileJournalReader::Read(Update& update, bool& end_of_journal) { TF_RETURN_IF_ERROR(EnsureInitialized()); while (true) { tstring record; - Status s = reader_->ReadRecord(&record); + absl::Status s = reader_->ReadRecord(&record); if (absl::IsOutOfRange(s)) { sequence_number_++; std::string next_journal_file = @@ -132,7 +132,7 @@ Status FileJournalReader::Read(Update& update, bool& end_of_journal) { } } -Status FileJournalReader::UpdateFile(const std::string& filename) { +absl::Status FileJournalReader::UpdateFile(const std::string& filename) { VLOG(1) << "Reading from journal file " << filename; TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(filename, &file_)); io::RecordReaderOptions opts; diff --git a/tensorflow/core/data/service/journal.h b/tensorflow/core/data/service/journal.h index 3d944d31b8fc86..7e909a268860d3 100644 --- a/tensorflow/core/data/service/journal.h +++ b/tensorflow/core/data/service/journal.h @@ -36,9 +36,9 @@ class JournalWriter { public: virtual ~JournalWriter() = default; // Writes and syncs an update to the journal. - virtual Status Write(const Update& update) = 0; + virtual absl::Status Write(const Update& update) = 0; // Initializes the writer if it is not yet initialized. - virtual Status EnsureInitialized() = 0; + virtual absl::Status EnsureInitialized() = 0; }; // FileJournalWriter is not thread-safe, requiring external synchronization when @@ -66,8 +66,8 @@ class FileJournalWriter : public JournalWriter { FileJournalWriter(const FileJournalWriter&) = delete; FileJournalWriter& operator=(const FileJournalWriter&) = delete; - Status Write(const Update& update) override; - Status EnsureInitialized() override; + absl::Status Write(const Update& update) override; + absl::Status EnsureInitialized() override; private: Env* env_; @@ -82,7 +82,7 @@ class JournalReader { virtual ~JournalReader() = default; // Reads the next update from the journal. Sets `end_of_journal=true` if // there are no more updates left in the journal. - virtual Status Read(Update& update, bool& end_of_journal) = 0; + virtual absl::Status Read(Update& update, bool& end_of_journal) = 0; }; // JournalReader is not thread-safe, requiring external synchronization when @@ -96,13 +96,13 @@ class FileJournalReader : public JournalReader { FileJournalReader(const FileJournalReader&) = delete; FileJournalReader& operator=(const FileJournalReader&) = delete; - Status Read(Update& update, bool& end_of_journal) override; + absl::Status Read(Update& update, bool& end_of_journal) override; private: // Initializes the reader if it is not yet initialized. - Status EnsureInitialized(); + absl::Status EnsureInitialized(); // Updates the `FileJournalReader` to read from a new file. - Status UpdateFile(const std::string& filename); + absl::Status UpdateFile(const std::string& filename); Env* env_; const std::string journal_dir_; diff --git a/tensorflow/core/data/service/journal_test.cc b/tensorflow/core/data/service/journal_test.cc index 6ee4dd4bd3f4af..bb9132d81725aa 100644 --- a/tensorflow/core/data/service/journal_test.cc +++ b/tensorflow/core/data/service/journal_test.cc @@ -67,8 +67,8 @@ Update MakeRegisterDatasetUpdate() { return update; } -Status CheckJournalContent(StringPiece journal_dir, - const std::vector& expected) { +absl::Status CheckJournalContent(StringPiece journal_dir, + const std::vector& expected) { FileJournalReader reader(Env::Default(), journal_dir); for (const auto& update : expected) { Update result; @@ -121,7 +121,7 @@ TEST(Journal, MissingFile) { FileJournalReader reader(Env::Default(), journal_dir); Update result; bool end_of_journal = true; - Status s = reader.Read(result, end_of_journal); + absl::Status s = reader.Read(result, end_of_journal); EXPECT_TRUE(absl::IsNotFound(s)); } @@ -140,7 +140,7 @@ TEST(Journal, NonRecordData) { FileJournalReader reader(Env::Default(), journal_dir); Update result; bool end_of_journal = true; - Status s = reader.Read(result, end_of_journal); + absl::Status s = reader.Read(result, end_of_journal); EXPECT_THAT(s.message(), HasSubstr("corrupted record")); EXPECT_EQ(s.code(), error::DATA_LOSS); } @@ -161,7 +161,7 @@ TEST(Journal, InvalidRecordData) { FileJournalReader reader(Env::Default(), journal_dir); Update result; bool end_of_journal = true; - Status s = reader.Read(result, end_of_journal); + absl::Status s = reader.Read(result, end_of_journal); EXPECT_THAT(s.message(), HasSubstr("Failed to parse journal record")); EXPECT_EQ(s.code(), error::DATA_LOSS); } diff --git a/tensorflow/core/data/service/server_lib.cc b/tensorflow/core/data/service/server_lib.cc index a43892316c16d6..b49fdbcd651f74 100644 --- a/tensorflow/core/data/service/server_lib.cc +++ b/tensorflow/core/data/service/server_lib.cc @@ -54,7 +54,7 @@ GrpcDataServerBase::GrpcDataServerBase( bound_port_(port), server_options_(std::move(options)) {} -Status GrpcDataServerBase::Start() { +absl::Status GrpcDataServerBase::Start() { if (stopped_) { return errors::FailedPrecondition( "Server cannot be started after it has been stopped."); @@ -127,13 +127,13 @@ void DispatchGrpcDataServer::AddDataServiceToBuilder( service_ = std::make_unique(config_, builder).release(); } -Status DispatchGrpcDataServer::StartServiceInternal() { +absl::Status DispatchGrpcDataServer::StartServiceInternal() { return service_->Start(); } void DispatchGrpcDataServer::StopServiceInternal() { service_->Stop(); } -Status DispatchGrpcDataServer::NumWorkers(int* num_workers) { +absl::Status DispatchGrpcDataServer::NumWorkers(int* num_workers) { GetWorkersRequest req; GetWorkersResponse resp; ::grpc::ServerContext ctx; @@ -145,7 +145,7 @@ Status DispatchGrpcDataServer::NumWorkers(int* num_workers) { return absl::OkStatus(); } -Status DispatchGrpcDataServer::SnapshotStreams( +absl::Status DispatchGrpcDataServer::SnapshotStreams( const std::string& path, std::vector* streams) { GetSnapshotStreamsRequest req; req.set_path(path); @@ -192,9 +192,9 @@ void WorkerGrpcDataServer::MaybeStartAlternativeDataTransferServer( config_.data_transfer_protocol() == kGrpcTransferProtocol) { return; } - Status s = DataTransferServer::Build(config_.data_transfer_protocol(), - service_->get_element_getter(), - &transfer_server_); + absl::Status s = DataTransferServer::Build(config_.data_transfer_protocol(), + service_->get_element_getter(), + &transfer_server_); if (!s.ok()) { LOG(ERROR) << "failed to build " << config_.data_transfer_protocol() << " server for worker " << config_.worker_address() << ": " @@ -232,7 +232,7 @@ void WorkerGrpcDataServer::MaybeStartAlternativeDataTransferServer( transfer_servers.push_back(alternative_transfer_server); } -Status WorkerGrpcDataServer::StartServiceInternal() { +absl::Status WorkerGrpcDataServer::StartServiceInternal() { std::string base_address = config_.worker_address(); if (base_address.empty()) { base_address = absl::StrCat("localhost:", kPortPlaceholder); @@ -251,7 +251,7 @@ Status WorkerGrpcDataServer::StartServiceInternal() { void WorkerGrpcDataServer::StopServiceInternal() { service_->Stop(); } -Status WorkerGrpcDataServer::NumTasks(int* num_tasks) { +absl::Status WorkerGrpcDataServer::NumTasks(int* num_tasks) { GetWorkerTasksRequest req; GetWorkerTasksResponse resp; ::grpc::ServerContext ctx; @@ -263,7 +263,7 @@ Status WorkerGrpcDataServer::NumTasks(int* num_tasks) { return absl::OkStatus(); } -Status WorkerGrpcDataServer::SnapshotTaskProgresses( +absl::Status WorkerGrpcDataServer::SnapshotTaskProgresses( std::vector* snapshot_task_progresses) { GetSnapshotTaskProgressesRequest req; GetSnapshotTaskProgressesResponse resp; @@ -284,14 +284,16 @@ ServerStateExport WorkerGrpcDataServer::ExportState() const { return server_state_export; } -Status NewDispatchServer(const experimental::DispatcherConfig& config, - std::unique_ptr& out_server) { +absl::Status NewDispatchServer( + const experimental::DispatcherConfig& config, + std::unique_ptr& out_server) { out_server = std::make_unique(config); return absl::OkStatus(); } -Status NewWorkerServer(const experimental::WorkerConfig& config, - std::unique_ptr& out_server) { +absl::Status NewWorkerServer( + const experimental::WorkerConfig& config, + std::unique_ptr& out_server) { out_server = std::make_unique(config); return absl::OkStatus(); } diff --git a/tensorflow/core/data/service/server_lib.h b/tensorflow/core/data/service/server_lib.h index 8c76d6f4b9543a..56a8f8d94fc558 100644 --- a/tensorflow/core/data/service/server_lib.h +++ b/tensorflow/core/data/service/server_lib.h @@ -50,7 +50,7 @@ class GrpcDataServerBase { virtual ~GrpcDataServerBase() = default; // Starts the server running asynchronously. - Status Start(); + absl::Status Start(); // Stops the server. This will block until all outstanding requests complete. void Stop(); @@ -69,7 +69,7 @@ class GrpcDataServerBase { void AddProfilerServiceToBuilder(::grpc::ServerBuilder& builder); // Starts the service. This will be called after building the service, so // bound_port() will return the actual bound port. - virtual Status StartServiceInternal() = 0; + virtual absl::Status StartServiceInternal() = 0; virtual void StopServiceInternal() {} int bound_port() { return bound_port_; } @@ -106,19 +106,19 @@ class DispatchGrpcDataServer : public GrpcDataServerBase { ~DispatchGrpcDataServer() override; // Returns the number of workers registered with the dispatcher. - Status NumWorkers(int* num_workers); + absl::Status NumWorkers(int* num_workers); // Returns the number of active (non-finished) iterations running on the // dispatcher. size_t NumActiveIterations(); // Returns information about all the streams for the snapshot at `path`. - Status SnapshotStreams(const std::string& path, - std::vector* streams); + absl::Status SnapshotStreams(const std::string& path, + std::vector* streams); ServerStateExport ExportState() const override; protected: void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override; - Status StartServiceInternal() override; + absl::Status StartServiceInternal() override; void StopServiceInternal() override; private: @@ -147,18 +147,18 @@ class WorkerGrpcDataServer : public GrpcDataServerBase { ~WorkerGrpcDataServer() override; // Returns the number of tasks currently being executed by the worker. - Status NumTasks(int* num_tasks); + absl::Status NumTasks(int* num_tasks); // Returns the progresses of the snapshot tasks currently being executed by // the worker. - Status SnapshotTaskProgresses( + absl::Status SnapshotTaskProgresses( std::vector* snapshot_task_progresses); ServerStateExport ExportState() const override; protected: void AddDataServiceToBuilder(::grpc::ServerBuilder& builder) override; - Status StartServiceInternal() override; + absl::Status StartServiceInternal() override; void StopServiceInternal() override; private: @@ -175,12 +175,13 @@ class WorkerGrpcDataServer : public GrpcDataServerBase { }; // Creates a dispatch tf.data server and stores it in `out_server`. -Status NewDispatchServer(const experimental::DispatcherConfig& config, - std::unique_ptr& out_server); +absl::Status NewDispatchServer( + const experimental::DispatcherConfig& config, + std::unique_ptr& out_server); // Creates a worker tf.data server and stores it in `out_server`. -Status NewWorkerServer(const experimental::WorkerConfig& config, - std::unique_ptr& out_server); +absl::Status NewWorkerServer(const experimental::WorkerConfig& config, + std::unique_ptr& out_server); } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/service/snapshot/BUILD b/tensorflow/core/data/service/snapshot/BUILD index 2b7d59674fabc3..ffc34db5936595 100644 --- a/tensorflow/core/data/service/snapshot/BUILD +++ b/tensorflow/core/data/service/snapshot/BUILD @@ -81,7 +81,7 @@ tf_cc_test( "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:protos_all_cc", + "@local_xla//xla/tsl/protobuf:protos_all_cc", ], ) @@ -181,7 +181,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/protobuf:protos_all_cc", + "@local_xla//xla/tsl/protobuf:protos_all_cc", ], ) @@ -287,7 +287,7 @@ tf_cc_test( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/protobuf:protos_all_cc", + "@local_xla//xla/tsl/protobuf:protos_all_cc", ], ) @@ -321,7 +321,7 @@ cc_library( "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:status_to_from_proto", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/protobuf:protos_all_cc", + "@local_xla//xla/tsl/protobuf:protos_all_cc", ], ) @@ -344,7 +344,7 @@ tf_cc_test( "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:status_to_from_proto", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:protos_all_cc", + "@local_xla//xla/tsl/protobuf:protos_all_cc", ], ) @@ -394,7 +394,7 @@ cc_library( "@local_tsl//tsl/platform:status_to_from_proto", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:tstring", - "@local_tsl//tsl/protobuf:status_proto_cc", + "@local_xla//xla/tsl/protobuf:status_proto_cc", ], ) @@ -423,7 +423,7 @@ tf_cc_test( "@local_tsl//tsl/platform:status_to_from_proto", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:tstring", - "@local_tsl//tsl/protobuf:status_proto_cc", + "@local_xla//xla/tsl/protobuf:status_proto_cc", ], ) @@ -509,8 +509,8 @@ tf_cc_test( "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/protobuf:protos_all_cc", "@local_xla//xla/tsl/lib/monitoring:cell_reader", + "@local_xla//xla/tsl/protobuf:protos_all_cc", ], ) diff --git a/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc b/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc index 8974964c9b3a81..f95fafb9343669 100644 --- a/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc +++ b/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/time/time.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/io/compression.h" #include "tensorflow/core/data/service/dispatcher_client.h" #include "tensorflow/core/data/service/snapshot/path_utils.h" #include "tensorflow/core/data/service/snapshot/test_utils.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/protobuf/snapshot.pb.h" -#include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/status_matchers.h" diff --git a/tensorflow/core/data/service/snapshot/file_utils.cc b/tensorflow/core/data/service/snapshot/file_utils.cc index 0440b00b34f7f0..ec5397bdfd1ed9 100644 --- a/tensorflow/core/data/service/snapshot/file_utils.cc +++ b/tensorflow/core/data/service/snapshot/file_utils.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/protobuf/status.pb.h" #include "tensorflow/core/data/service/snapshot/path_utils.h" #include "tensorflow/core/data/snapshot_utils.h" #include "tensorflow/core/framework/dataset.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tsl/platform/protobuf.h" #include "tsl/platform/random.h" #include "tsl/platform/status_to_from_proto.h" -#include "tsl/protobuf/status.pb.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/data/service/snapshot/file_utils_test.cc b/tensorflow/core/data/service/snapshot/file_utils_test.cc index 9582cab18bc143..dc4efcc9497f22 100644 --- a/tensorflow/core/data/service/snapshot/file_utils_test.cc +++ b/tensorflow/core/data/service/snapshot/file_utils_test.cc @@ -19,20 +19,20 @@ limitations under the License. #include #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/io/compression.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/data/dataset_test_base.h" #include "tensorflow/core/data/service/test_util.h" #include "tensorflow/core/data/snapshot_utils.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" -#include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc b/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc index 43944c6a41b8f1..1623ac904c5484 100644 --- a/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc +++ b/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc @@ -31,10 +31,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/io/compression.h" #include "tensorflow/core/data/service/byte_size.h" #include "tensorflow/core/data/snapshot_utils.h" #include "tensorflow/core/framework/tensor.h" -#include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/status_matchers.h" diff --git a/tensorflow/core/data/service/snapshot/path_utils_test.cc b/tensorflow/core/data/service/snapshot/path_utils_test.cc index 5ee48efeecb5f8..f68d7eaff18040 100644 --- a/tensorflow/core/data/service/snapshot/path_utils_test.cc +++ b/tensorflow/core/data/service/snapshot/path_utils_test.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/data/service/snapshot/path_utils.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/data/service/snapshot/prefetched_split_provider.cc b/tensorflow/core/data/service/snapshot/prefetched_split_provider.cc index d28285966ff251..db6b7bd6733818 100644 --- a/tensorflow/core/data/service/snapshot/prefetched_split_provider.cc +++ b/tensorflow/core/data/service/snapshot/prefetched_split_provider.cc @@ -26,10 +26,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/lib/io/compression.h" #include "tensorflow/core/data/service/snapshot/file_utils.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/tensor.h" -#include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" diff --git a/tensorflow/core/data/service/snapshot/prefetched_split_provider_test.cc b/tensorflow/core/data/service/snapshot/prefetched_split_provider_test.cc index 1e019a1742651e..8e7c473840c004 100644 --- a/tensorflow/core/data/service/snapshot/prefetched_split_provider_test.cc +++ b/tensorflow/core/data/service/snapshot/prefetched_split_provider_test.cc @@ -32,13 +32,13 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/io/compression.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/split_provider.h" #include "tensorflow/core/data/service/test_util.h" #include "tensorflow/core/data/snapshot_utils.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/tensor.h" -#include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" diff --git a/tensorflow/core/data/service/snapshot/snapshot_chunk_provider.cc b/tensorflow/core/data/service/snapshot/snapshot_chunk_provider.cc index 731f9435ffcd13..ff1e2caea35b00 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_chunk_provider.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_chunk_provider.cc @@ -33,6 +33,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" +#include "xla/tsl/protobuf/status.pb.h" #include "tensorflow/core/data/service/snapshot/file_utils.h" #include "tensorflow/core/data/service/snapshot/path_utils.h" #include "tensorflow/core/framework/dataset.h" @@ -46,7 +47,6 @@ limitations under the License. #include "tsl/platform/status_to_from_proto.h" #include "tsl/platform/statusor.h" #include "tsl/platform/tstring.h" -#include "tsl/protobuf/status.pb.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc b/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc index e40fd0ad918387..e6fcd97ef6d5dd 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/status.pb.h" #include "tensorflow/core/data/serialization_utils.h" #include "tensorflow/core/data/service/snapshot/file_utils.h" #include "tensorflow/core/data/service/snapshot/path_utils.h" @@ -41,7 +42,6 @@ limitations under the License. #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" #include "tsl/platform/tstring.h" -#include "tsl/protobuf/status.pb.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/data/service/snapshot/snapshot_manager.cc b/tensorflow/core/data/service/snapshot/snapshot_manager.cc index 8ea22871945af2..fffd36c09139a5 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_manager.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_manager.cc @@ -34,6 +34,9 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" +#include "xla/tsl/lib/io/compression.h" +#include "xla/tsl/protobuf/error_codes.pb.h" +#include "xla/tsl/protobuf/status.pb.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/dispatcher.pb.h" #include "tensorflow/core/data/service/snapshot/file_utils.h" @@ -43,7 +46,6 @@ limitations under the License. #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/status.h" -#include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/mutex.h" @@ -51,8 +53,6 @@ limitations under the License. #include "tsl/platform/status_to_from_proto.h" #include "tsl/platform/thread_annotations.h" #include "tsl/platform/threadpool.h" -#include "tsl/protobuf/error_codes.pb.h" -#include "tsl/protobuf/status.pb.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/data/service/snapshot/snapshot_manager.h b/tensorflow/core/data/service/snapshot/snapshot_manager.h index 8c53ae98650878..5db495f16c87ce 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_manager.h +++ b/tensorflow/core/data/service/snapshot/snapshot_manager.h @@ -30,6 +30,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "absl/time/time.h" +#include "xla/tsl/protobuf/status.pb.h" #include "tensorflow/core/data/service/dispatcher.pb.h" #include "tensorflow/core/data/service/snapshot/prefetched_split_provider.h" #include "tensorflow/core/framework/dataset.h" @@ -37,7 +38,6 @@ limitations under the License. #include "tsl/platform/env.h" #include "tsl/platform/mutex.h" #include "tsl/platform/thread_annotations.h" -#include "tsl/protobuf/status.pb.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc b/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc index 65b3c59e8ecba4..cff201261b00c4 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/error_codes.pb.h" +#include "xla/tsl/protobuf/status.pb.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/dispatcher.pb.h" #include "tensorflow/core/data/service/snapshot/path_utils.h" @@ -30,8 +32,6 @@ limitations under the License. #include "tsl/platform/status_to_from_proto.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/error_codes.pb.h" -#include "tsl/protobuf/status.pb.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/data/service/snapshot/snapshot_split_provider_test.cc b/tensorflow/core/data/service/snapshot/snapshot_split_provider_test.cc index b9b9f3d3d4d8ed..18f39b446ea218 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_split_provider_test.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_split_provider_test.cc @@ -24,6 +24,8 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/time/time.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/io/compression.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/data/serialization_utils.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/dispatcher_client.h" @@ -35,13 +37,11 @@ limitations under the License. #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/protobuf/snapshot.pb.h" -#include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/data/service/snapshot/snapshot_stream_writer_checkpoint_test.cc b/tensorflow/core/data/service/snapshot/snapshot_stream_writer_checkpoint_test.cc index 071c7e1f1c72a1..dea3f3dd5785d8 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_stream_writer_checkpoint_test.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_stream_writer_checkpoint_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/time/time.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/io/compression.h" #include "tensorflow/core/data/service/byte_size.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/snapshot/path_utils.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/core/data/service/snapshot/test_utils.h" #include "tensorflow/core/data/service/task_runner.h" #include "tensorflow/core/data/service/test_util.h" -#include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/random.h" #include "tsl/platform/status_matchers.h" diff --git a/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc b/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc index 00dbdd947eefac..c557d1630194e7 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc @@ -26,7 +26,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/io/compression.h" #include "xla/tsl/lib/monitoring/cell_reader.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/data/service/byte_size.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/snapshot/file_utils.h" @@ -37,12 +39,10 @@ limitations under the License. #include "tensorflow/core/data/snapshot_utils.h" #include "tensorflow/core/data/standalone.h" #include "tensorflow/core/framework/tensor.h" -#include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/data/service/split_provider.cc b/tensorflow/core/data/service/split_provider.cc index ddf9d7c7a67138..8fdd913da8ff72 100644 --- a/tensorflow/core/data/service/split_provider.cc +++ b/tensorflow/core/data/service/split_provider.cc @@ -34,7 +34,8 @@ limitations under the License. namespace tensorflow { namespace data { -Status DataServiceSplitProvider::GetNext(Tensor* split, bool* end_of_splits) +absl::Status DataServiceSplitProvider::GetNext(Tensor* split, + bool* end_of_splits) TF_LOCKS_EXCLUDED(mu_) { mutex_lock l(mu_); if (!dispatcher_) { @@ -61,27 +62,27 @@ Status DataServiceSplitProvider::GetNext(Tensor* split, bool* end_of_splits) return absl::OkStatus(); } -Status DataServiceSplitProvider::Reset() TF_LOCKS_EXCLUDED(mu_) { +absl::Status DataServiceSplitProvider::Reset() TF_LOCKS_EXCLUDED(mu_) { mutex_lock l(mu_); repetition_++; return absl::OkStatus(); } -Status DataServiceSplitProvider::Save( +absl::Status DataServiceSplitProvider::Save( std::function full_name, IteratorStateWriter* writer) { return errors::Unimplemented( "Save is not implemented for DataServiceSplitProvider"); } -Status DataServiceSplitProvider::Restore( +absl::Status DataServiceSplitProvider::Restore( std::function full_name, IteratorStateReader* reader) { return errors::Unimplemented( "Restore is not implemented for DataServiceSplitProvider"); } -Status CreateSplitProviders( +absl::Status CreateSplitProviders( const DatasetDef& dataset_def, std::vector>& split_providers) { standalone::Dataset::Params params; diff --git a/tensorflow/core/data/service/split_provider.h b/tensorflow/core/data/service/split_provider.h index c87d3a2eb9ba9f..c426fe1a795507 100644 --- a/tensorflow/core/data/service/split_provider.h +++ b/tensorflow/core/data/service/split_provider.h @@ -44,12 +44,12 @@ class DataServiceSplitProvider : public SplitProvider { split_provider_index_(split_provider_index), timeout_ms_(timeout_ms) {} - Status GetNext(Tensor* split, bool* end_of_splits) override; - Status Reset() override; - Status Save(std::function full_name, - IteratorStateWriter* writer) override; - Status Restore(std::function full_name, - IteratorStateReader* reader) override; + absl::Status GetNext(Tensor* split, bool* end_of_splits) override; + absl::Status Reset() override; + absl::Status Save(std::function full_name, + IteratorStateWriter* writer) override; + absl::Status Restore(std::function full_name, + IteratorStateReader* reader) override; private: const std::string address_; @@ -64,7 +64,7 @@ class DataServiceSplitProvider : public SplitProvider { }; // Makes split providers for `dataset_def` and stores them in `split_providers`. -Status CreateSplitProviders( +absl::Status CreateSplitProviders( const DatasetDef& dataset_def, std::vector>& split_providers); diff --git a/tensorflow/core/data/service/task_runner.cc b/tensorflow/core/data/service/task_runner.cc index 01d412f04a45a0..2b85af5aa20b73 100644 --- a/tensorflow/core/data/service/task_runner.cc +++ b/tensorflow/core/data/service/task_runner.cc @@ -56,8 +56,8 @@ StandaloneTaskIterator::StandaloneTaskIterator( std::unique_ptr iterator) : dataset_(std::move(dataset)), iterator_(std::move(iterator)) {} -Status StandaloneTaskIterator::GetNext(std::vector& element, - bool& end_of_sequence) { +absl::Status StandaloneTaskIterator::GetNext(std::vector& element, + bool& end_of_sequence) { return iterator_->GetNext(&element, &end_of_sequence); } @@ -69,7 +69,7 @@ absl::StatusOr> StandaloneTaskIterator::Save() { return iterator_->Save(); } -Status StandaloneTaskIterator::Restore( +absl::Status StandaloneTaskIterator::Restore( const std::vector& saved_iterator) { return iterator_->Restore(saved_iterator); } @@ -78,10 +78,10 @@ std::shared_ptr StandaloneTaskIterator::model() const { return iterator_->model(); } -Status TaskRunner::Create(const experimental::WorkerConfig& worker_config, - const TaskDef& task_def, - std::unique_ptr iterator, - std::unique_ptr& out) { +absl::Status TaskRunner::Create(const experimental::WorkerConfig& worker_config, + const TaskDef& task_def, + std::unique_ptr iterator, + std::unique_ptr& out) { if (task_def.optional_num_consumers_case() == TaskDef::kNumConsumers) { int64_t cardinality = iterator->Cardinality(); if (cardinality != kInfiniteCardinality && @@ -116,8 +116,8 @@ FirstComeFirstServedTaskRunner::FirstComeFirstServedTaskRunner( FirstComeFirstServedTaskRunner::~FirstComeFirstServedTaskRunner() { Cancel(); } -Status FirstComeFirstServedTaskRunner::GetNext(const GetElementRequest& req, - GetElementResult& result) { +absl::Status FirstComeFirstServedTaskRunner::GetNext( + const GetElementRequest& req, GetElementResult& result) { if (req.allow_skip() && buffer_.Empty()) { result.skip = true; return absl::OkStatus(); @@ -125,12 +125,12 @@ Status FirstComeFirstServedTaskRunner::GetNext(const GetElementRequest& req, return GetNext(result); } -Status FirstComeFirstServedTaskRunner::GetNext(GetElementResult& result) { +absl::Status FirstComeFirstServedTaskRunner::GetNext(GetElementResult& result) { TF_ASSIGN_OR_RETURN(result, buffer_.Pop()); return absl::OkStatus(); } -Status FirstComeFirstServedTaskRunner::PrefetchFn() { +absl::Status FirstComeFirstServedTaskRunner::PrefetchFn() { while (true) { TF_RETURN_IF_ERROR(buffer_.Push(GetNextFromInputIterator())); } @@ -139,7 +139,7 @@ Status FirstComeFirstServedTaskRunner::PrefetchFn() { void FirstComeFirstServedTaskRunner::RunPrefetchThread() { auto prefetch_fn = [this] { - Status status = PrefetchFn(); + absl::Status status = PrefetchFn(); if (!status.ok()) { buffer_.Cancel(status); } @@ -188,8 +188,8 @@ CachingTaskRunner::CachingTaskRunner(std::unique_ptr iterator, CachingTaskRunner::~CachingTaskRunner() { Cancel(); } -Status CachingTaskRunner::GetNext(const GetElementRequest& req, - GetElementResult& result) { +absl::Status CachingTaskRunner::GetNext(const GetElementRequest& req, + GetElementResult& result) { TF_ASSIGN_OR_RETURN(std::shared_ptr element, cache_.Get(req.trainer_id())); result = element->Copy(); @@ -241,7 +241,8 @@ RoundRobinTaskRunner::RoundRobinTaskRunner( << num_consumers << " consumers"; } -Status RoundRobinTaskRunner::ValidateRequest(const GetElementRequest& req) { +absl::Status RoundRobinTaskRunner::ValidateRequest( + const GetElementRequest& req) { if (req.consumer_index() < 0 || req.round_index() < 0) { return errors::FailedPrecondition( "RoundRobinTaskRunner needs to know the consumer index and element " @@ -255,7 +256,7 @@ Status RoundRobinTaskRunner::ValidateRequest(const GetElementRequest& req) { return absl::OkStatus(); } -Status RoundRobinTaskRunner::PrepareFullRound(int64_t wait_us) +absl::Status RoundRobinTaskRunner::PrepareFullRound(int64_t wait_us) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { VLOG(1) << worker_address_ << ": Preparing full round for round " << current_round_; @@ -266,7 +267,7 @@ Status RoundRobinTaskRunner::PrepareFullRound(int64_t wait_us) return absl::OkStatus(); } -Status RoundRobinTaskRunner::PreparePartialRound() +absl::Status RoundRobinTaskRunner::PreparePartialRound() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { VLOG(1) << worker_address_ << ": Starting partial round " << first_round_ << " for " << requests_[first_round_].size() << " consumers"; @@ -284,7 +285,7 @@ Status RoundRobinTaskRunner::PreparePartialRound() return absl::OkStatus(); } -Status RoundRobinTaskRunner::PrepareRound(const GetElementRequest& req) { +absl::Status RoundRobinTaskRunner::PrepareRound(const GetElementRequest& req) { mutex_lock l(mu_); first_round_ = std::min(first_round_, req.round_index()); absl::flat_hash_map& round = @@ -325,8 +326,8 @@ Status RoundRobinTaskRunner::PrepareRound(const GetElementRequest& req) { return prefetch_thread_.GetStatus(); } -Status RoundRobinTaskRunner::GetNext(const GetElementRequest& req, - GetElementResult& result) { +absl::Status RoundRobinTaskRunner::GetNext(const GetElementRequest& req, + GetElementResult& result) { TF_RETURN_IF_ERROR(ValidateRequest(req)); result.end_of_sequence = false; VLOG(2) << worker_address_ << ": Received request from consumer index " @@ -394,7 +395,7 @@ void PrefetchThread::Run() { } std::vector element; bool end_of_sequence; - Status s = iterator_->GetNext(element, end_of_sequence); + absl::Status s = iterator_->GetNext(element, end_of_sequence); if (!s.ok()) { mutex_lock l(mu_); status_ = s; @@ -417,8 +418,8 @@ void PrefetchThread::Run() { } } -Status PrefetchThread::FillBuffer(int64_t wait_us, - std::vector>& out) { +absl::Status PrefetchThread::FillBuffer( + int64_t wait_us, std::vector>& out) { int64_t start_us = Env::Default()->NowMicros(); out.clear(); mutex_lock l(mu_); @@ -445,7 +446,7 @@ Status PrefetchThread::FillBuffer(int64_t wait_us, return absl::OkStatus(); } -Status PrefetchThread::GetStatus() { +absl::Status PrefetchThread::GetStatus() { mutex_lock l(mu_); return status_; } diff --git a/tensorflow/core/data/service/task_runner.h b/tensorflow/core/data/service/task_runner.h index 7867d2d9dc45d3..79d698f9edc65f 100644 --- a/tensorflow/core/data/service/task_runner.h +++ b/tensorflow/core/data/service/task_runner.h @@ -44,8 +44,8 @@ class TaskIterator { // If the iterator is not yet exhausted, `GetNext` stores the next element in // `element` and sets `end_of_sequence` to `false`. Otherwise, sets // `end_of_sequence to `true`. - virtual Status GetNext(std::vector& element, - bool& end_of_sequence) = 0; + virtual absl::Status GetNext(std::vector& element, + bool& end_of_sequence) = 0; // Reports the cardinality of the dataset that created this iterator. virtual int64_t Cardinality() const = 0; @@ -58,7 +58,7 @@ class TaskIterator { // Restores the iterator from a checkpoint. `saved_iterator` is the serialized // iterator saved by calling `Save()`. - virtual Status Restore(const std::vector& saved_iterator) { + virtual absl::Status Restore(const std::vector& saved_iterator) { return errors::Unimplemented( "Restoring from a tf.data service task iterator is unsupported."); } @@ -75,10 +75,11 @@ class StandaloneTaskIterator : public TaskIterator { // lives as long as `iterator`. StandaloneTaskIterator(std::unique_ptr dataset, std::unique_ptr iterator); - Status GetNext(std::vector& element, bool& end_of_sequence) override; + absl::Status GetNext(std::vector& element, + bool& end_of_sequence) override; int64_t Cardinality() const override; absl::StatusOr> Save() override; - Status Restore(const std::vector& saved_iterator) override; + absl::Status Restore(const std::vector& saved_iterator) override; std::shared_ptr model() const override; private: @@ -90,14 +91,14 @@ class StandaloneTaskIterator : public TaskIterator { class TaskRunner { public: // Creates a `TaskRunner` and stores it in `out`. - static Status Create(const experimental::WorkerConfig& worker_config, - const TaskDef& task_def, - std::unique_ptr iterator, - std::unique_ptr& out); + static absl::Status Create(const experimental::WorkerConfig& worker_config, + const TaskDef& task_def, + std::unique_ptr iterator, + std::unique_ptr& out); virtual ~TaskRunner() = default; // Gets the next element for the given request. - virtual Status GetNext(const GetElementRequest& req, - GetElementResult& result) = 0; + virtual absl::Status GetNext(const GetElementRequest& req, + GetElementResult& result) = 0; // Cancels in-progress `GetNext` requests. virtual void Cancel() = 0; // Returns the dataset model for performance analysis. @@ -113,9 +114,9 @@ class FirstComeFirstServedTaskRunner : public TaskRunner { ~FirstComeFirstServedTaskRunner() override; // Gets the next element. It may block if the element is not ready yet. - Status GetNext(const GetElementRequest& req, - GetElementResult& result) override; - Status GetNext(GetElementResult& result); + absl::Status GetNext(const GetElementRequest& req, + GetElementResult& result) override; + absl::Status GetNext(GetElementResult& result); void Cancel() override; @@ -124,7 +125,7 @@ class FirstComeFirstServedTaskRunner : public TaskRunner { private: // Function to continually prefetch the next element. Returns an error if the // task has been cancelled. - Status PrefetchFn(); + absl::Status PrefetchFn(); // Runs `PrefetchFn` on a dedicated thread. void RunPrefetchThread(); @@ -160,8 +161,8 @@ class CachingTaskRunner : public TaskRunner { // Gets the next element from the cross-trainer cache, blocking if the data is // not ready. // REQUIRES: !req.trainer_id().empty() - Status GetNext(const GetElementRequest& req, - GetElementResult& result) override; + absl::Status GetNext(const GetElementRequest& req, + GetElementResult& result) override; // Cancel the task runner. After cancelling, all the `GetNext` calls will // return a Cancelled status. @@ -215,10 +216,10 @@ class PrefetchThread { // Fills `out` with a round of data. Waits for up to `wait_us` microseconds // before giving up and returning with `out` empty. A negative `wait_us` // signals to wait indefinitely. - Status FillBuffer(int64_t wait_us, - std::vector>& out); + absl::Status FillBuffer(int64_t wait_us, + std::vector>& out); // Returns the status for any failures encountered by the prefetch thread. - Status GetStatus(); + absl::Status GetStatus(); // Returns the dataset model for performance analysis. std::shared_ptr model() const; @@ -230,7 +231,7 @@ class PrefetchThread { // Buffered results for the next round. std::vector> buffer_ TF_GUARDED_BY(mu_); // The status if the prefetch thread fails. - Status status_ TF_GUARDED_BY(mu_) = absl::OkStatus(); + absl::Status status_ TF_GUARDED_BY(mu_) = absl::OkStatus(); // Condition variable notified when elements are added to or removed from // `buffer_`, or when `status_` is changed. condition_variable cv_; @@ -262,21 +263,22 @@ class RoundRobinTaskRunner : public TaskRunner { RoundRobinTaskRunner(std::unique_ptr iterator, int64_t num_consumers, string worker_address); - Status GetNext(const GetElementRequest& req, - GetElementResult& result) override; + absl::Status GetNext(const GetElementRequest& req, + GetElementResult& result) override; void Cancel() override; std::shared_ptr model() const override; private: // Prepares a full round of data. `wait_us` indicates how long to wait before // skipping if a full round of data is not yet ready. - Status PrepareFullRound(int64_t wait_us) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + absl::Status PrepareFullRound(int64_t wait_us) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Prepares a partial round to get consumers back in sync. - Status PreparePartialRound() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); - Status ValidateRequest(const GetElementRequest& req); + absl::Status PreparePartialRound() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + absl::Status ValidateRequest(const GetElementRequest& req); // Prepares data for the next round, blocking until the round is ready to // start. - Status PrepareRound(const GetElementRequest& req); + absl::Status PrepareRound(const GetElementRequest& req); const int64_t num_consumers_; const string worker_address_; mutex mu_; diff --git a/tensorflow/core/data/service/task_runner_test.cc b/tensorflow/core/data/service/task_runner_test.cc index ec2b4fc8615aae..62b1ab63251083 100644 --- a/tensorflow/core/data/service/task_runner_test.cc +++ b/tensorflow/core/data/service/task_runner_test.cc @@ -61,7 +61,8 @@ class RangeIterator : public TaskIterator { explicit RangeIterator(const int64_t range, const bool repeat) : range_(range), repeat_(repeat) {} - Status GetNext(std::vector& element, bool& end_of_sequence) override { + absl::Status GetNext(std::vector& element, + bool& end_of_sequence) override { end_of_sequence = (next_ >= range_); if (end_of_sequence) { return absl::OkStatus(); @@ -87,7 +88,8 @@ class InfiniteRangeIterator : public TaskIterator { public: InfiniteRangeIterator() = default; - Status GetNext(std::vector& element, bool& end_of_sequence) override { + absl::Status GetNext(std::vector& element, + bool& end_of_sequence) override { element = {Tensor{next_++}}; return absl::OkStatus(); } @@ -104,7 +106,8 @@ class ElementOrErrorIterator : public TaskIterator { explicit ElementOrErrorIterator(const std::vector>& elements) : elements_(elements) {} - Status GetNext(std::vector& element, bool& end_of_sequence) override { + absl::Status GetNext(std::vector& element, + bool& end_of_sequence) override { end_of_sequence = (next_ >= elements_.size()); if (end_of_sequence) { return absl::OkStatus(); @@ -176,9 +179,9 @@ std::vector GetRange(const size_t range) { } // Reads from the task runner, storing results in `*output`. -Status RunConsumer(int64_t consumer_index, int64_t start_index, - int64_t end_index, TaskRunner& task_runner, - std::vector& output) { +absl::Status RunConsumer(int64_t consumer_index, int64_t start_index, + int64_t end_index, TaskRunner& task_runner, + std::vector& output) { for (int64_t next_index = start_index; next_index < end_index; ++next_index) { GetElementRequest request; request.set_round_index(next_index); @@ -418,7 +421,7 @@ TEST(CachingTaskRunnerTest, CancelConcurrentReaders) { GetElementRequest request; request.set_trainer_id(absl::StrCat("Trainer_", (j % 100))); GetElementResult result; - Status status = runner.GetNext(request, result); + absl::Status status = runner.GetNext(request, result); if (!status.ok()) { return; } @@ -510,15 +513,16 @@ TEST_P(ConsumeParallelTest, ConsumeParallel) { std::vector> per_consumer_results; std::vector> consumers; mutex mu; - Status error; + absl::Status error; for (int consumer = 0; consumer < num_consumers; ++consumer) { mutex_lock l(mu); per_consumer_results.emplace_back(); consumers.push_back(absl::WrapUnique(Env::Default()->StartThread( {}, absl::StrCat("consumer_", consumer), [&, consumer] { std::vector results; - Status s = RunConsumer(consumer, /*start_index=*/0, - /*end_index=*/num_elements, runner, results); + absl::Status s = + RunConsumer(consumer, /*start_index=*/0, + /*end_index=*/num_elements, runner, results); mutex_lock l(mu); if (!s.ok()) { error = s; @@ -558,15 +562,15 @@ TEST(RoundRobinTaskRunner, ConsumeParallelPartialRound) { std::vector> per_consumer_results; std::vector> consumers; mutex mu; - Status error; + absl::Status error; for (int consumer = 0; consumer < num_consumers; ++consumer) { mutex_lock l(mu); per_consumer_results.emplace_back(); consumers.push_back(absl::WrapUnique(Env::Default()->StartThread( {}, absl::StrCat("consumer_", consumer), [&, consumer] { std::vector results; - Status s = RunConsumer(consumer, starting_rounds[consumer], end_index, - runner, results); + absl::Status s = RunConsumer(consumer, starting_rounds[consumer], + end_index, runner, results); mutex_lock l(mu); if (!s.ok()) { error = s; diff --git a/tensorflow/core/data/service/test_cluster.cc b/tensorflow/core/data/service/test_cluster.cc index 9cda46ce0f1b0b..21c3976abda2a8 100644 --- a/tensorflow/core/data/service/test_cluster.cc +++ b/tensorflow/core/data/service/test_cluster.cc @@ -55,7 +55,7 @@ TestCluster::~TestCluster() { } } -Status TestCluster::Initialize() { +absl::Status TestCluster::Initialize() { if (initialized_) { return errors::FailedPrecondition( "Test cluster has already been initialized."); @@ -89,7 +89,7 @@ Status TestCluster::Initialize() { return absl::OkStatus(); } -Status TestCluster::AddWorker( +absl::Status TestCluster::AddWorker( std::optional port, std::optional data_transfer_protocol) { std::unique_ptr worker; diff --git a/tensorflow/core/data/service/test_cluster.h b/tensorflow/core/data/service/test_cluster.h index c071b9c2bffeae..a62669d7344d3d 100644 --- a/tensorflow/core/data/service/test_cluster.h +++ b/tensorflow/core/data/service/test_cluster.h @@ -67,9 +67,9 @@ class TestCluster { // Initializes the test cluster. This must be called before interacting with // the cluster. Initialize should be called only once. - Status Initialize(); + absl::Status Initialize(); // Adds a new worker to the cluster. - Status AddWorker( + absl::Status AddWorker( std::optional port = std::nullopt, std::optional data_transfer_protocol = std::nullopt); // Returns the number of workers in this cluster. diff --git a/tensorflow/core/data/service/thread_safe_buffer.h b/tensorflow/core/data/service/thread_safe_buffer.h index 6234de355eebca..570fb5cec5b46c 100644 --- a/tensorflow/core/data/service/thread_safe_buffer.h +++ b/tensorflow/core/data/service/thread_safe_buffer.h @@ -40,12 +40,12 @@ class ThreadSafeBuffer final { // Writes the next element. Blocks if the buffer is full. Returns an error if // the buffer has been cancelled. - Status Push(StatusOr value); + absl::Status Push(StatusOr value); // Cancels the buffer with `status` and notifies waiting threads. After // cancelling, all `Push` and `Pop` calls will return `status`. // REQUIRES: !status.ok() - void Cancel(Status status); + void Cancel(absl::Status status); // Returns whether the buffer is empty. bool Empty() const; @@ -57,7 +57,7 @@ class ThreadSafeBuffer final { condition_variable ready_to_pop_; condition_variable ready_to_push_; std::deque> results_ TF_GUARDED_BY(mu_); - Status status_ TF_GUARDED_BY(mu_) = absl::OkStatus(); + absl::Status status_ TF_GUARDED_BY(mu_) = absl::OkStatus(); ThreadSafeBuffer(const ThreadSafeBuffer&) = delete; void operator=(const ThreadSafeBuffer&) = delete; @@ -93,7 +93,7 @@ StatusOr ThreadSafeBuffer::Pop() { } template -Status ThreadSafeBuffer::Push(StatusOr value) { +absl::Status ThreadSafeBuffer::Push(StatusOr value) { mutex_lock l(mu_); while (status_.ok() && results_.size() >= buffer_size_) { ready_to_push_.wait(l); @@ -107,7 +107,7 @@ Status ThreadSafeBuffer::Push(StatusOr value) { } template -void ThreadSafeBuffer::Cancel(Status status) { +void ThreadSafeBuffer::Cancel(absl::Status status) { DCHECK(!status.ok()) << "Cancelling ThreadSafeBuffer requires a non-OK status. Got " << status; mutex_lock l(mu_); diff --git a/tensorflow/core/data/service/utils.cc b/tensorflow/core/data/service/utils.cc index 43ad439d21fc55..4f79b9384de3b7 100644 --- a/tensorflow/core/data/service/utils.cc +++ b/tensorflow/core/data/service/utils.cc @@ -27,7 +27,8 @@ limitations under the License. namespace tensorflow { namespace data { -Status WriteDatasetDef(const std::string& path, const DatasetDef& dataset_def) { +absl::Status WriteDatasetDef(const std::string& path, + const DatasetDef& dataset_def) { std::unique_ptr file; TF_RETURN_IF_ERROR(Env::Default()->NewWritableFile(path, &file)); io::RecordWriter writer(file.get()); @@ -35,7 +36,7 @@ Status WriteDatasetDef(const std::string& path, const DatasetDef& dataset_def) { return absl::OkStatus(); } -Status ReadDatasetDef(const std::string& path, DatasetDef& dataset_def) { +absl::Status ReadDatasetDef(const std::string& path, DatasetDef& dataset_def) { if (path.empty()) { return errors::InvalidArgument("Path is empty"); } diff --git a/tensorflow/core/data/service/utils.h b/tensorflow/core/data/service/utils.h index e7673ef7b55cb5..482d306efbeb55 100644 --- a/tensorflow/core/data/service/utils.h +++ b/tensorflow/core/data/service/utils.h @@ -28,11 +28,12 @@ namespace data { // Writes a dataset definition to the specified path. If the file already // exists, it will be overwritten. -Status WriteDatasetDef(const std::string& path, const DatasetDef& dataset_def); +absl::Status WriteDatasetDef(const std::string& path, + const DatasetDef& dataset_def); // Reads a dataset definition from specified path, and stores it in // `dataset_def`. Returns NOT_FOUND if the path cannot be found. -Status ReadDatasetDef(const std::string& path, DatasetDef& dataset_def); +absl::Status ReadDatasetDef(const std::string& path, DatasetDef& dataset_def); } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/service/utils_test.cc b/tensorflow/core/data/service/utils_test.cc index 75cf219a92ef9b..acfdb3862b0009 100644 --- a/tensorflow/core/data/service/utils_test.cc +++ b/tensorflow/core/data/service/utils_test.cc @@ -64,7 +64,7 @@ TEST(Utils, ReadDatasetNotFound) { std::string filename = testing::TmpDir(); ASSERT_TRUE(Env::Default()->CreateUniqueFileName(&filename, "journal_dir")); DatasetDef result; - Status s = ReadDatasetDef(filename, result); + absl::Status s = ReadDatasetDef(filename, result); EXPECT_EQ(s.code(), error::NOT_FOUND); } diff --git a/tensorflow/core/data/service/validate_utils.cc b/tensorflow/core/data/service/validate_utils.cc index e0e7c7f94d3d6f..df968694a53f58 100644 --- a/tensorflow/core/data/service/validate_utils.cc +++ b/tensorflow/core/data/service/validate_utils.cc @@ -43,9 +43,9 @@ absl::StatusOr DecodeElementSpec( return decoded_spec; } -Status ValidateElementSpec(const std::string& dataset_id, - const std::string& encoded_spec1, - const std::string& encoded_spec2) { +absl::Status ValidateElementSpec(const std::string& dataset_id, + const std::string& encoded_spec1, + const std::string& encoded_spec2) { if (encoded_spec1.empty() && encoded_spec2.empty()) { return absl::OkStatus(); } @@ -70,9 +70,9 @@ Status ValidateElementSpec(const std::string& dataset_id, return absl::OkStatus(); } -Status ValidateDatasetMetadata(const std::string& dataset_id, - const DataServiceMetadata& metadata1, - const DataServiceMetadata& metadata2) { +absl::Status ValidateDatasetMetadata(const std::string& dataset_id, + const DataServiceMetadata& metadata1, + const DataServiceMetadata& metadata2) { TF_RETURN_IF_ERROR(ValidateElementSpec(dataset_id, metadata1.element_spec(), metadata2.element_spec())); MessageDifferencer differ; @@ -94,9 +94,9 @@ Status ValidateDatasetMetadata(const std::string& dataset_id, } // namespace -Status ValidateMatchingDataset(const std::string& dataset_id, - const DataServiceMetadata& metadata1, - const DataServiceMetadata& metadata2) { +absl::Status ValidateMatchingDataset(const std::string& dataset_id, + const DataServiceMetadata& metadata1, + const DataServiceMetadata& metadata2) { return ValidateDatasetMetadata(dataset_id, metadata1, metadata2); } diff --git a/tensorflow/core/data/service/validate_utils.h b/tensorflow/core/data/service/validate_utils.h index 0a91f7609850b4..c42780230e432a 100644 --- a/tensorflow/core/data/service/validate_utils.h +++ b/tensorflow/core/data/service/validate_utils.h @@ -25,9 +25,9 @@ namespace data { // Verifies the datasets with the same ID have the same metadata. If the // metadata differs, returns an invalid argument error. -Status ValidateMatchingDataset(const std::string& dataset_id, - const DataServiceMetadata& metadata1, - const DataServiceMetadata& metadata2); +absl::Status ValidateMatchingDataset(const std::string& dataset_id, + const DataServiceMetadata& metadata1, + const DataServiceMetadata& metadata2); } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/service/worker_client.cc b/tensorflow/core/data/service/worker_client.cc index 98306e42bcfde8..18510e5da36276 100644 --- a/tensorflow/core/data/service/worker_client.cc +++ b/tensorflow/core/data/service/worker_client.cc @@ -74,13 +74,13 @@ CreateDataServiceWorkerClient( return client; } -Status DataServiceWorkerClient::GetElement(const GetElementRequest& req, - GetElementResult& result) { +absl::Status DataServiceWorkerClient::GetElement(const GetElementRequest& req, + GetElementResult& result) { TF_RETURN_IF_ERROR(EnsureInitialized()); return client_->GetElement(req, result); } -Status DataServiceWorkerClient::EnsureInitialized() { +absl::Status DataServiceWorkerClient::EnsureInitialized() { mutex_lock l(mu_); if (client_) { return absl::OkStatus(); @@ -112,8 +112,8 @@ class GrpcDataTransferClient : public DataTransferClient { stub_ = WorkerService::NewStub(channel); } - Status GetElement(const GetElementRequest& req, - GetElementResult& result) override { + absl::Status GetElement(const GetElementRequest& req, + GetElementResult& result) override { VLOG(3) << "GetElement for task " << req.task_id() << " from gRPC worker " << "server."; { @@ -215,14 +215,14 @@ class LocalDataTransferClient : public DataTransferClient { << "."; } - Status GetElement(const GetElementRequest& req, - GetElementResult& result) override { + absl::Status GetElement(const GetElementRequest& req, + GetElementResult& result) override { VLOG(3) << "GetElement for task " << req.task_id() << " from local worker."; TF_RETURN_IF_ERROR(VerifyClientIsNotCancelled()); TF_ASSIGN_OR_RETURN(std::shared_ptr worker, GetWorker(req)); int64_t start_time_us = env_->NowMicros(); - Status s = worker->GetElementResult(&req, &result); + absl::Status s = worker->GetElementResult(&req, &result); int64_t end_time_us = env_->NowMicros(); TF_RETURN_IF_ERROR(s); metrics::RecordTFDataServiceGetElementDuration(kLocalTransferProtocol, @@ -241,7 +241,7 @@ class LocalDataTransferClient : public DataTransferClient { } private: - Status VerifyClientIsNotCancelled() TF_LOCKS_EXCLUDED(mu_) { + absl::Status VerifyClientIsNotCancelled() TF_LOCKS_EXCLUDED(mu_) { mutex_lock l(mu_); if (cancelled_) { return errors::Cancelled(absl::Substitute( diff --git a/tensorflow/core/data/service/worker_client.h b/tensorflow/core/data/service/worker_client.h index 014afdc6a98d1c..2bb5328461f323 100644 --- a/tensorflow/core/data/service/worker_client.h +++ b/tensorflow/core/data/service/worker_client.h @@ -45,14 +45,15 @@ class DataServiceWorkerClient : public DataServiceClientBase { allocator_(allocator) {} // Fetches an element from the worker. - Status GetElement(const GetElementRequest& req, GetElementResult& result); + absl::Status GetElement(const GetElementRequest& req, + GetElementResult& result); // Makes a best effort to cancel all outstanding calls in progress for the // client, and causes further calls to return Cancelled status. void TryCancel(); // Returns an error if the client is incompatible with a server which has the // properties described in `compatibility_info`. - Status CheckCompatibility( + absl::Status CheckCompatibility( const std::string& server_compatibility_info) const { return client_->CheckCompatibility(server_compatibility_info); } @@ -61,7 +62,7 @@ class DataServiceWorkerClient : public DataServiceClientBase { std::string GetDataTransferProtocol() const; protected: - Status EnsureInitialized() override; + absl::Status EnsureInitialized() override; private: std::string transfer_protocol_; diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc index ebddaf184ce254..6d5c126cec0457 100644 --- a/tensorflow/core/data/service/worker_impl.cc +++ b/tensorflow/core/data/service/worker_impl.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "absl/time/time.h" +#include "xla/tsl/protobuf/status.pb.h" #include "tensorflow/core/data/service/byte_size.h" #include "tensorflow/core/data/service/common.h" #include "tensorflow/core/data/service/common.pb.h" @@ -67,7 +68,6 @@ limitations under the License. #include "tsl/platform/errors.h" #include "tsl/platform/status_to_from_proto.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/status.pb.h" namespace tensorflow { namespace data { @@ -82,8 +82,8 @@ using WorkerConfig = experimental::WorkerConfig; // Moves the element into the response. If the tensor contains a single // CompressedElement variant, the move will be zero-copy. Otherwise, the tensor // data will be serialized as TensorProtos. -Status MoveElementToResponse(std::vector&& element, - GetElementResponse& resp) { +absl::Status MoveElementToResponse(std::vector&& element, + GetElementResponse& resp) { if (element.size() != 1 || element[0].dtype() != DT_VARIANT || !TensorShapeUtils::IsScalar(element[0].shape())) { for (const auto& component : element) { @@ -170,7 +170,7 @@ DataServiceWorkerImpl::~DataServiceWorkerImpl() { heartbeat_cv_.notify_one(); } -Status DataServiceWorkerImpl::Start( +absl::Status DataServiceWorkerImpl::Start( const std::string& worker_address, const std::vector& transfer_servers) { VLOG(3) << "Starting tf.data service worker at address " << worker_address; @@ -226,7 +226,7 @@ void DataServiceWorkerImpl::Stop() { 1000); } -Status DataServiceWorkerImpl::ValidateWorkerConfig() const { +absl::Status DataServiceWorkerImpl::ValidateWorkerConfig() const { const bool any_tag_is_empty = absl::c_any_of( config_.worker_tags(), [](const std::string& worker_tag) { return worker_tag.empty(); }); @@ -255,7 +255,7 @@ DataServiceWorkerImpl::CreateDispatcherClient() const TF_LOCKS_EXCLUDED(mu_) { return dispatcher; } -Status DataServiceWorkerImpl::GetElementResult( +absl::Status DataServiceWorkerImpl::GetElementResult( const GetElementRequest* request, struct GetElementResult* result) { Task* task = nullptr; { @@ -310,15 +310,15 @@ Status DataServiceWorkerImpl::GetElementResult( return absl::OkStatus(); } -Status DataServiceWorkerImpl::ProcessTask(const ProcessTaskRequest* request, - ProcessTaskResponse* response) { +absl::Status DataServiceWorkerImpl::ProcessTask( + const ProcessTaskRequest* request, ProcessTaskResponse* response) { mutex_lock l(mu_); const TaskDef& task = request->task(); VLOG(3) << "Received request to process task " << task.task_id(); return ProcessTaskInternal(task); } -Status DataServiceWorkerImpl::ProcessTaskInternal(const TaskDef& task_def) +absl::Status DataServiceWorkerImpl::ProcessTaskInternal(const TaskDef& task_def) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { std::shared_ptr& task = tasks_[task_def.task_id()]; if (task) { @@ -333,7 +333,7 @@ Status DataServiceWorkerImpl::ProcessTaskInternal(const TaskDef& task_def) return absl::OkStatus(); } -Status DataServiceWorkerImpl::EnsureTaskInitialized( +absl::Status DataServiceWorkerImpl::EnsureTaskInitialized( DataServiceWorkerImpl::Task& task) { if (task.task_def.worker_address() != worker_address_) { return errors::Internal(absl::Substitute( @@ -367,7 +367,7 @@ absl::StatusOr DataServiceWorkerImpl::GetDatasetDef( return task_def.dataset_def(); case TaskDef::kPath: { DatasetDef def; - Status s = ReadDatasetDef(task_def.path(), def); + absl::Status s = ReadDatasetDef(task_def.path(), def); if (!s.ok()) { LOG(INFO) << "Failed to read dataset from " << task_def.path() << ": " << s << ". Falling back to reading from dispatcher."; @@ -476,8 +476,8 @@ void DataServiceWorkerImpl::StopTask(Task& task) TF_LOCKS_EXCLUDED(mu_) { } } -Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request, - GetElementResponse* response) { +absl::Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request, + GetElementResponse* response) { VLOG(3) << "Received GetElement request for task " << request->task_id(); struct GetElementResult result; TF_RETURN_IF_ERROR(GetElementResult(request, &result)); @@ -491,7 +491,7 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request, return absl::OkStatus(); } -Status DataServiceWorkerImpl::GetWorkerTasks( +absl::Status DataServiceWorkerImpl::GetWorkerTasks( const GetWorkerTasksRequest* request, GetWorkerTasksResponse* response) { mutex_lock l(mu_); for (const auto& it : tasks_) { @@ -504,7 +504,7 @@ Status DataServiceWorkerImpl::GetWorkerTasks( return absl::OkStatus(); } -Status DataServiceWorkerImpl::GetSnapshotTaskProgresses( +absl::Status DataServiceWorkerImpl::GetSnapshotTaskProgresses( const GetSnapshotTaskProgressesRequest* request, GetSnapshotTaskProgressesResponse* response) { for (const auto& snapshot_task_progress : GetSnapshotTaskProgress()) { @@ -525,7 +525,7 @@ void DataServiceWorkerImpl::TaskCompletionThread() TF_LOCKS_EXCLUDED(mu_) { return; } } - Status s = SendTaskUpdates(); + absl::Status s = SendTaskUpdates(); if (!s.ok()) { LOG(WARNING) << "Failed to send task updates to dispatcher: " << s; mutex_lock l(mu_); @@ -540,7 +540,7 @@ void DataServiceWorkerImpl::TaskCompletionThread() TF_LOCKS_EXCLUDED(mu_) { } } -Status DataServiceWorkerImpl::SendTaskUpdates() TF_LOCKS_EXCLUDED(mu_) { +absl::Status DataServiceWorkerImpl::SendTaskUpdates() TF_LOCKS_EXCLUDED(mu_) { std::vector task_progress; { mutex_lock l(mu_); @@ -585,14 +585,14 @@ void DataServiceWorkerImpl::HeartbeatThread() TF_LOCKS_EXCLUDED(mu_) { continue; } } - Status s = Heartbeat(); + absl::Status s = Heartbeat(); if (!s.ok()) { LOG(WARNING) << "Failed to send heartbeat to dispatcher: " << s; } } } -Status DataServiceWorkerImpl::Heartbeat() { +absl::Status DataServiceWorkerImpl::Heartbeat() { WorkerHeartbeatRequest request = BuildWorkerHeartbeatRequest(); TF_ASSIGN_OR_RETURN(WorkerHeartbeatResponse response, dispatcher_->WorkerHeartbeat(request)); @@ -699,7 +699,7 @@ void DataServiceWorkerImpl::UpdateTasks(const WorkerHeartbeatResponse& response) if (deleted_tasks_.contains(task.task_id())) { continue; } - Status s = ProcessTaskInternal(task); + absl::Status s = ProcessTaskInternal(task); if (!s.ok() && !errors::IsAlreadyExists(s)) { LOG(WARNING) << "Failed to start processing task " << task.task_id() << ": " << s; @@ -723,7 +723,7 @@ void DataServiceWorkerImpl::UpdateTasks(const WorkerHeartbeatResponse& response) } // TODO(yangchen): Figure out why `mutex_lock`s here are needed for sanitizers. -Status DataServiceWorkerImpl::UpdateSnapshotWriters( +absl::Status DataServiceWorkerImpl::UpdateSnapshotWriters( const WorkerHeartbeatResponse& response) TF_LOCKS_EXCLUDED(mu_) { absl::flat_hash_set assigned_snapshot_task_keys; for (const SnapshotTaskDef& snapshot_task : response.snapshot_tasks()) { diff --git a/tensorflow/core/data/service/worker_impl.h b/tensorflow/core/data/service/worker_impl.h index 6fa472212c903e..c256c88c226121 100644 --- a/tensorflow/core/data/service/worker_impl.h +++ b/tensorflow/core/data/service/worker_impl.h @@ -55,16 +55,17 @@ class DataServiceWorkerImpl { // constructor because the worker may be binding to port `0`, in which case // the address isn't known until the worker has started and decided which port // to bind to. - Status Start(const std::string& worker_address, - const std::vector& transfer_servers); + absl::Status Start( + const std::string& worker_address, + const std::vector& transfer_servers); // Stops the worker, attempting a clean shutdown by rejecting new requests // and waiting for outstanding requests to complete. void Stop(); // Serves a GetElement request, storing the result in `*result`. See // worker.proto for GetElement API documentation. - Status GetElementResult(const GetElementRequest* request, - GetElementResult* result); + absl::Status GetElementResult(const GetElementRequest* request, + GetElementResult* result); // Deletes the local task and iterator. Only called by local clients to delete // unused task iterators assuming the task is not read by remote clients. This @@ -74,15 +75,15 @@ class DataServiceWorkerImpl { // See worker.proto for API documentation. /// Dispatcher-facing API. - Status ProcessTask(const ProcessTaskRequest* request, - ProcessTaskResponse* response); + absl::Status ProcessTask(const ProcessTaskRequest* request, + ProcessTaskResponse* response); /// Client-facing API. - Status GetElement(const GetElementRequest* request, - GetElementResponse* response); - Status GetWorkerTasks(const GetWorkerTasksRequest* request, - GetWorkerTasksResponse* response); - Status GetSnapshotTaskProgresses( + absl::Status GetElement(const GetElementRequest* request, + GetElementResponse* response); + absl::Status GetWorkerTasks(const GetWorkerTasksRequest* request, + GetWorkerTasksResponse* response); + absl::Status GetSnapshotTaskProgresses( const GetSnapshotTaskProgressesRequest* request, GetSnapshotTaskProgressesResponse* response); @@ -121,16 +122,16 @@ class DataServiceWorkerImpl { }; // Validates the worker config. - Status ValidateWorkerConfig() const; + absl::Status ValidateWorkerConfig() const; // Creates and initializes a dispatcher client. absl::StatusOr> CreateDispatcherClient() const TF_LOCKS_EXCLUDED(mu_); // Sends task status to the dispatcher and checks for dispatcher commands. - Status SendTaskUpdates() TF_LOCKS_EXCLUDED(mu_); + absl::Status SendTaskUpdates() TF_LOCKS_EXCLUDED(mu_); // Creates an iterator to process a task. - Status ProcessTaskInternal(const TaskDef& task) + absl::Status ProcessTaskInternal(const TaskDef& task) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); - Status EnsureTaskInitialized(Task& task); + absl::Status EnsureTaskInitialized(Task& task); // Stops a task, cancelling the task's outstanding requests and waiting for // them to finish. void StopTask(Task& task) TF_LOCKS_EXCLUDED(mu_); @@ -139,7 +140,7 @@ class DataServiceWorkerImpl { // A thread for doing periodic heartbeats to the dispatcher. void HeartbeatThread() TF_LOCKS_EXCLUDED(mu_); // Performs a heartbeat to the dispatcher. - Status Heartbeat(); + absl::Status Heartbeat(); // Check with the dispatcher to see whether or not to disable compression. absl::StatusOr DisableCompressionAtRuntime( const std::string& dataset_id) const; @@ -155,7 +156,7 @@ class DataServiceWorkerImpl { void UpdateTasks(const WorkerHeartbeatResponse& response) TF_LOCKS_EXCLUDED(mu_); // Updates the distributed snapshot tasks according to the heartbeat response. - Status UpdateSnapshotWriters(const WorkerHeartbeatResponse& response) + absl::Status UpdateSnapshotWriters(const WorkerHeartbeatResponse& response) TF_LOCKS_EXCLUDED(mu_); // Creates an dataset iterator for snapshot writers. absl::StatusOr> diff --git a/tensorflow/core/data/snapshot_utils.cc b/tensorflow/core/data/snapshot_utils.cc index 8874484c835af0..50f3bc86cab92b 100644 --- a/tensorflow/core/data/snapshot_utils.cc +++ b/tensorflow/core/data/snapshot_utils.cc @@ -27,6 +27,8 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/lib/io/snappy/snappy_inputbuffer.h" +#include "xla/tsl/lib/io/snappy/snappy_outputbuffer.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/data/name_utils.h" #include "tensorflow/core/framework/dataset.h" @@ -47,8 +49,6 @@ limitations under the License. #include "tensorflow/core/platform/stringprintf.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/protobuf/snapshot.pb.h" -#include "tsl/lib/io/snappy/snappy_inputbuffer.h" -#include "tsl/lib/io/snappy/snappy_outputbuffer.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" @@ -115,10 +115,10 @@ std::string GetCheckpointFileName(const std::string& shard_directory, static_cast(checkpoint_id))); } -Status Writer::Create(Env* env, const std::string& filename, - const std::string& compression_type, int version, - const DataTypeVector& dtypes, - std::unique_ptr* out_writer) { +absl::Status Writer::Create(Env* env, const std::string& filename, + const std::string& compression_type, int version, + const DataTypeVector& dtypes, + std::unique_ptr* out_writer) { switch (version) { case 1: *out_writer = @@ -140,7 +140,7 @@ TFRecordWriter::TFRecordWriter(const std::string& filename, const std::string& compression_type) : filename_(filename), compression_type_(compression_type) {} -Status TFRecordWriter::Initialize(tensorflow::Env* env) { +absl::Status TFRecordWriter::Initialize(tensorflow::Env* env) { TF_RETURN_IF_ERROR(env->NewAppendableFile(filename_, &dest_)); record_writer_ = std::make_unique( @@ -149,7 +149,7 @@ Status TFRecordWriter::Initialize(tensorflow::Env* env) { return absl::OkStatus(); } -Status TFRecordWriter::WriteTensors(const std::vector& tensors) { +absl::Status TFRecordWriter::WriteTensors(const std::vector& tensors) { for (const auto& tensor : tensors) { TensorProto proto; tensor.AsProtoTensorContent(&proto); @@ -177,12 +177,12 @@ Status TFRecordWriter::WriteTensors(const std::vector& tensors) { return absl::OkStatus(); } -Status TFRecordWriter::Sync() { +absl::Status TFRecordWriter::Sync() { TF_RETURN_IF_ERROR(record_writer_->Flush()); return dest_->Flush(); } -Status TFRecordWriter::Close() { +absl::Status TFRecordWriter::Close() { if (record_writer_ != nullptr) { TF_RETURN_IF_ERROR(Sync()); TF_RETURN_IF_ERROR(record_writer_->Close()); @@ -194,7 +194,7 @@ Status TFRecordWriter::Close() { } TFRecordWriter::~TFRecordWriter() { - Status s = Close(); + absl::Status s = Close(); if (!s.ok()) { LOG(ERROR) << "Failed to close snapshot file " << filename_ << ": " << s; } @@ -207,7 +207,7 @@ CustomWriter::CustomWriter(const std::string& filename, compression_type_(compression_type), dtypes_(dtypes) {} -Status CustomWriter::Initialize(tensorflow::Env* env) { +absl::Status CustomWriter::Initialize(tensorflow::Env* env) { TF_RETURN_IF_ERROR(env->NewAppendableFile(filename_, &dest_)); #if defined(IS_SLIM_BUILD) if (compression_type_ != io::compression::kNone) { @@ -241,7 +241,7 @@ Status CustomWriter::Initialize(tensorflow::Env* env) { return absl::OkStatus(); } -Status CustomWriter::WriteTensors(const std::vector& tensors) { +absl::Status CustomWriter::WriteTensors(const std::vector& tensors) { if (compression_type_ != io::compression::kSnappy) { experimental::SnapshotRecord record; for (const auto& tensor : tensors) { @@ -324,9 +324,9 @@ Status CustomWriter::WriteTensors(const std::vector& tensors) { return absl::OkStatus(); } -Status CustomWriter::Sync() { return dest_->Sync(); } +absl::Status CustomWriter::Sync() { return dest_->Sync(); } -Status CustomWriter::Close() { +absl::Status CustomWriter::Close() { if (dest_ != nullptr) { TF_RETURN_IF_ERROR(dest_->Close()); dest_ = nullptr; @@ -339,13 +339,13 @@ Status CustomWriter::Close() { } CustomWriter::~CustomWriter() { - Status s = Close(); + absl::Status s = Close(); if (!s.ok()) { LOG(ERROR) << "Could not finish writing file: " << s; } } -Status CustomWriter::WriteRecord(const StringPiece& data) { +absl::Status CustomWriter::WriteRecord(const StringPiece& data) { char header[kHeaderSize]; core::EncodeFixed64(header, data.size()); TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header)))); @@ -353,7 +353,7 @@ Status CustomWriter::WriteRecord(const StringPiece& data) { } #if defined(TF_CORD_SUPPORT) -Status CustomWriter::WriteRecord(const absl::Cord& data) { +absl::Status CustomWriter::WriteRecord(const absl::Cord& data) { char header[kHeaderSize]; core::EncodeFixed64(header, data.size()); TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header)))); @@ -361,10 +361,10 @@ Status CustomWriter::WriteRecord(const absl::Cord& data) { } #endif // TF_CORD_SUPPORT -Status Reader::Create(Env* env, const std::string& filename, - const string& compression_type, int version, - const DataTypeVector& dtypes, - std::unique_ptr* out_reader) { +absl::Status Reader::Create(Env* env, const std::string& filename, + const string& compression_type, int version, + const DataTypeVector& dtypes, + std::unique_ptr* out_reader) { switch (version) { // CustomReader is able to read a legacy snapshot file format (v0) though // custom writer doesn't have the ability to write it any more since it is @@ -386,7 +386,7 @@ Status Reader::Create(Env* env, const std::string& filename, return (*out_reader)->Initialize(env); } -Status Reader::SkipRecords(int64_t num_records) { +absl::Status Reader::SkipRecords(int64_t num_records) { // TODO(frankchn): Optimize to not parse the entire Tensor and actually skip. for (int i = 0; i < num_records; ++i) { std::vector unused_tensors; @@ -418,16 +418,17 @@ class Reader::Dataset : public DatasetBase { std::string DebugString() const override { return "SnapshotDatasetReader"; } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { return absl::OkStatus(); } - Status CheckExternalState() const override { return absl::OkStatus(); } + absl::Status CheckExternalState() const override { return absl::OkStatus(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** node) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** node) const override { Node* shard_dir = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(shard_dir_, &shard_dir)); @@ -463,7 +464,7 @@ class Reader::Dataset : public DatasetBase { : DatasetIterator(params), start_index_(dataset()->start_index_) {} - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { // TODO(jsimsa): This only needs to happen when we are not restoring but // parallel_interleave op implementation caches IteratorContext (and thus // the is_restoring bit ends up being inaccurate). @@ -474,16 +475,16 @@ class Reader::Dataset : public DatasetBase { } protected: - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { *end_of_sequence = false; - Status s = reader_->ReadTensors(out_tensors); + absl::Status s = reader_->ReadTensors(out_tensors); if (!absl::IsOutOfRange(s)) { start_index_++; return s; } - Status status = AdvanceToNextFile(ctx->env()); + absl::Status status = AdvanceToNextFile(ctx->env()); if (absl::IsNotFound(status)) { *end_of_sequence = true; return absl::OkStatus(); @@ -491,8 +492,8 @@ class Reader::Dataset : public DatasetBase { return status; } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurrentCheckpointID), current_checkpoint_id_)); TF_RETURN_IF_ERROR( @@ -500,8 +501,8 @@ class Reader::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kCurrentCheckpointID), ¤t_checkpoint_id_)); TF_RETURN_IF_ERROR( @@ -514,7 +515,7 @@ class Reader::Dataset : public DatasetBase { } private: - Status AdvanceToNextFile(Env* env) { + absl::Status AdvanceToNextFile(Env* env) { start_index_ = 0; current_checkpoint_id_++; TF_RETURN_IF_ERROR(env->FileExists(GetCurrentFilename())); @@ -528,7 +529,7 @@ class Reader::Dataset : public DatasetBase { } // TODO(frankchn): Optimize this to not parse every single element. - Status AdvanceToStartIndex(IteratorContext* ctx) { + absl::Status AdvanceToStartIndex(IteratorContext* ctx) { for (int64_t i = 0; i < start_index_; ++i) { std::vector unused; TF_RETURN_IF_ERROR(reader_->ReadTensors(&unused)); @@ -594,17 +595,18 @@ class Reader::NestedDataset : public DatasetBase { return "SnapshotNestedDatasetReader"; } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->clear(); return absl::OkStatus(); } - Status CheckExternalState() const override { return absl::OkStatus(); } + absl::Status CheckExternalState() const override { return absl::OkStatus(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** node) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** node) const override { std::vector input_graph_nodes; input_graph_nodes.reserve(datasets_.size()); for (const auto& dataset : datasets_) { @@ -636,9 +638,9 @@ class Reader::NestedDataset : public DatasetBase { : DatasetIterator(params) {} protected: - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { const int64_t num_datasets = dataset()->datasets_.size(); *end_of_sequence = num_datasets == index_; if (!*end_of_sequence) { @@ -654,14 +656,14 @@ class Reader::NestedDataset : public DatasetBase { return absl::OkStatus(); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), index_)); return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kIndex), &index_)); return absl::OkStatus(); } @@ -689,13 +691,11 @@ void Reader::NestedDatasetOp::MakeDataset(OpKernelContext* ctx, (*output)->Initialize(/*metadata=*/{}); } -Status Reader::MakeNestedDataset(Env* env, - const std::vector& shard_dirs, - const string& compression_type, int version, - const DataTypeVector& dtypes, - const std::vector& shapes, - const int64_t start_index, - DatasetBase** output) { +absl::Status Reader::MakeNestedDataset( + Env* env, const std::vector& shard_dirs, + const string& compression_type, int version, const DataTypeVector& dtypes, + const std::vector& shapes, const int64_t start_index, + DatasetBase** output) { std::vector datasets; datasets.reserve(shard_dirs.size()); @@ -746,7 +746,7 @@ TFRecordReaderImpl::TFRecordReaderImpl( compression_(compression), output_buffer_size_(output_buffer_size) {} -Status TFRecordReaderImpl::Initialize(Env* env) { +absl::Status TFRecordReaderImpl::Initialize(Env* env) { TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename_, &file_)); auto options = io::RecordReaderOptions::CreateRecordReaderOptions( /*compression_type=*/compression_); @@ -798,7 +798,7 @@ absl::StatusOr TFRecordReaderImpl::Parse(const tstring& record) { return tensor; } -Status TFRecordReader::ReadTensors(std::vector* read_tensors) { +absl::Status TFRecordReader::ReadTensors(std::vector* read_tensors) { read_tensors->clear(); read_tensors->reserve(dtypes_.size()); for (int i = 0; i < dtypes_.size(); ++i) { @@ -816,7 +816,7 @@ CustomReader::CustomReader(const std::string& filename, version_(version), dtypes_(dtypes) {} -Status CustomReader::Initialize(Env* env) { +absl::Status CustomReader::Initialize(Env* env) { TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename_, &file_)); input_stream_ = std::make_unique(file_.get()); @@ -858,7 +858,7 @@ Status CustomReader::Initialize(Env* env) { return absl::OkStatus(); } -Status CustomReader::ReadTensors(std::vector* read_tensors) { +absl::Status CustomReader::ReadTensors(std::vector* read_tensors) { tsl::profiler::TraceMe activity( [&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); }, tsl::profiler::TraceMeLevel::kInfo); @@ -912,7 +912,7 @@ Status CustomReader::ReadTensors(std::vector* read_tensors) { return absl::OkStatus(); } -Status CustomReader::ReadTensorsV0(std::vector* read_tensors) { +absl::Status CustomReader::ReadTensorsV0(std::vector* read_tensors) { experimental::SnapshotRecord record; #if defined(PLATFORM_GOOGLE) absl::Cord c; @@ -933,7 +933,7 @@ Status CustomReader::ReadTensorsV0(std::vector* read_tensors) { return absl::OkStatus(); } -Status CustomReader::SnappyUncompress( +absl::Status CustomReader::SnappyUncompress( const experimental::SnapshotTensorMetadata* metadata, std::vector* simple_tensors, std::vector, size_t>>* @@ -983,7 +983,7 @@ Status CustomReader::SnappyUncompress( return absl::OkStatus(); } -Status CustomReader::ReadRecord(tstring* record) { +absl::Status CustomReader::ReadRecord(tstring* record) { tstring header; TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header)); uint64 length = core::DecodeFixed64(header.data()); @@ -991,7 +991,7 @@ Status CustomReader::ReadRecord(tstring* record) { } #if defined(TF_CORD_SUPPORT) -Status CustomReader::ReadRecord(absl::Cord* record) { +absl::Status CustomReader::ReadRecord(absl::Cord* record) { tstring header; TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(kHeaderSize, &header)); uint64 length = core::DecodeFixed64(header.data()); @@ -1008,8 +1008,9 @@ Status CustomReader::ReadRecord(absl::Cord* record) { } #endif // TF_CORD_SUPPORT -Status WriteMetadataFile(Env* env, const string& dir, - const experimental::SnapshotMetadataRecord* metadata) { +absl::Status WriteMetadataFile( + Env* env, const string& dir, + const experimental::SnapshotMetadataRecord* metadata) { string metadata_filename = io::JoinPath(dir, kMetadataFilename); TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(dir)); std::string tmp_filename = @@ -1018,7 +1019,7 @@ Status WriteMetadataFile(Env* env, const string& dir, return env->RenameFile(tmp_filename, metadata_filename); } -Status WriteMetadataFile( +absl::Status WriteMetadataFile( Env* env, const string& dir, const experimental::DistributedSnapshotMetadata* metadata) { string metadata_filename = io::JoinPath(dir, kMetadataFilename); @@ -1029,11 +1030,11 @@ Status WriteMetadataFile( return env->RenameFile(tmp_filename, metadata_filename); } -Status ReadMetadataFile(Env* env, const string& dir, - experimental::SnapshotMetadataRecord* metadata, - bool* file_exists) { +absl::Status ReadMetadataFile(Env* env, const string& dir, + experimental::SnapshotMetadataRecord* metadata, + bool* file_exists) { string metadata_filename = io::JoinPath(dir, kMetadataFilename); - Status s = env->FileExists(metadata_filename); + absl::Status s = env->FileExists(metadata_filename); *file_exists = s.ok(); if (*file_exists) { @@ -1043,11 +1044,11 @@ Status ReadMetadataFile(Env* env, const string& dir, } } -Status ReadMetadataFile(Env* env, const string& dir, - experimental::DistributedSnapshotMetadata* metadata, - bool* file_exists) { +absl::Status ReadMetadataFile( + Env* env, const string& dir, + experimental::DistributedSnapshotMetadata* metadata, bool* file_exists) { string metadata_filename = io::JoinPath(dir, kMetadataFilename); - Status s = env->FileExists(metadata_filename); + absl::Status s = env->FileExists(metadata_filename); *file_exists = s.ok(); if (*file_exists) { @@ -1057,8 +1058,8 @@ Status ReadMetadataFile(Env* env, const string& dir, } } -Status DumpDatasetGraph(Env* env, const std::string& path, uint64 hash, - const GraphDef* graph) { +absl::Status DumpDatasetGraph(Env* env, const std::string& path, uint64 hash, + const GraphDef* graph) { std::string hash_hex = strings::StrCat(strings::Hex(hash, strings::kZeroPad16)); std::string graph_file = @@ -1069,10 +1070,10 @@ Status DumpDatasetGraph(Env* env, const std::string& path, uint64 hash, return WriteTextProto(env, graph_file, *graph); } -Status DetermineOpState(const std::string& mode_string, bool file_exists, - const experimental::SnapshotMetadataRecord* metadata, - const uint64 pending_snapshot_expiry_seconds, - Mode* mode) { +absl::Status DetermineOpState( + const std::string& mode_string, bool file_exists, + const experimental::SnapshotMetadataRecord* metadata, + const uint64 pending_snapshot_expiry_seconds, Mode* mode) { if (mode_string == kModeRead) { // In read mode, we should expect a metadata file is written. if (!file_exists) { @@ -1124,7 +1125,7 @@ AsyncWriter::AsyncWriter(Env* env, int64_t file_index, const std::string& shard_directory, uint64 checkpoint_id, const std::string& compression, int64_t version, const DataTypeVector& output_types, - std::function done) { + std::function done) { thread_ = absl::WrapUnique(env->StartThread( ThreadOptions(), absl::StrCat("writer_thread_", file_index), [this, env, shard_directory, checkpoint_id, compression, version, @@ -1157,10 +1158,12 @@ void AsyncWriter::Consume(ElementOrEOF* be) { bool AsyncWriter::ElementAvailable() { return !deque_.empty(); } -Status AsyncWriter::WriterThread(Env* env, const std::string& shard_directory, - uint64 checkpoint_id, - const std::string& compression, - int64_t version, DataTypeVector output_types) { +absl::Status AsyncWriter::WriterThread(Env* env, + const std::string& shard_directory, + uint64 checkpoint_id, + const std::string& compression, + int64_t version, + DataTypeVector output_types) { std::unique_ptr writer; TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(shard_directory)); diff --git a/tensorflow/core/data/snapshot_utils.h b/tensorflow/core/data/snapshot_utils.h index 7e3b897bdddb0d..d543dcb3d29e40 100644 --- a/tensorflow/core/data/snapshot_utils.h +++ b/tensorflow/core/data/snapshot_utils.h @@ -84,25 +84,25 @@ std::string GetCheckpointFileName(const std::string& shard_directory, class Writer { public: // Creates a new writer object. - static Status Create(Env* env, const std::string& filename, - const std::string& compression_type, int version, - const DataTypeVector& dtypes, - std::unique_ptr* out_writer); + static absl::Status Create(Env* env, const std::string& filename, + const std::string& compression_type, int version, + const DataTypeVector& dtypes, + std::unique_ptr* out_writer); // Writes a vector of tensors to the snapshot writer file. - virtual Status WriteTensors(const std::vector& tensors) = 0; + virtual absl::Status WriteTensors(const std::vector& tensors) = 0; // Flushes any in-memory buffers to disk. - virtual Status Sync() = 0; + virtual absl::Status Sync() = 0; // Closes and finalizes the snapshot file. All calls to any other method will // be invalid after this call. - virtual Status Close() = 0; + virtual absl::Status Close() = 0; virtual ~Writer() = default; protected: - virtual Status Initialize(tensorflow::Env* env) = 0; + virtual absl::Status Initialize(tensorflow::Env* env) = 0; }; // Writes snapshots with the standard TFRecord file format. @@ -111,13 +111,13 @@ class TFRecordWriter : public Writer { TFRecordWriter(const std::string& filename, const std::string& compression_type); - Status Initialize(tensorflow::Env* env) override; + absl::Status Initialize(tensorflow::Env* env) override; - Status WriteTensors(const std::vector& tensors) override; + absl::Status WriteTensors(const std::vector& tensors) override; - Status Sync() override; + absl::Status Sync() override; - Status Close() override; + absl::Status Close() override; ~TFRecordWriter() override; @@ -142,22 +142,22 @@ class CustomWriter : public Writer { CustomWriter(const std::string& filename, const std::string& compression_type, const DataTypeVector& dtypes); - Status WriteTensors(const std::vector& tensors) override; + absl::Status WriteTensors(const std::vector& tensors) override; - Status Sync() override; + absl::Status Sync() override; - Status Close() override; + absl::Status Close() override; ~CustomWriter() override; protected: - Status Initialize(tensorflow::Env* env) override; + absl::Status Initialize(tensorflow::Env* env) override; private: - Status WriteRecord(const StringPiece& data); + absl::Status WriteRecord(const StringPiece& data); #if defined(TF_CORD_SUPPORT) - Status WriteRecord(const absl::Cord& data); + absl::Status WriteRecord(const absl::Cord& data); #endif // TF_CORD_SUPPORT std::unique_ptr dest_; @@ -209,38 +209,37 @@ class Reader { // Creates a new Reader object that reads data from `filename`. Note that // the `version`, `compression_type`, and `dtypes` arguments passed into // `Writer` and `Reader` must be the same for the reading to succeed. - static Status Create(Env* env, const std::string& filename, - const string& compression_type, int version, - const DataTypeVector& dtypes, - std::unique_ptr* out_reader); + static absl::Status Create(Env* env, const std::string& filename, + const string& compression_type, int version, + const DataTypeVector& dtypes, + std::unique_ptr* out_reader); // Returns a nested dataset for a set of given snapshot file names. // // This function takes a vector of snapshot files, and returns a nested // dataset. Each element within the nested dataset is itself a dataset, and // contains all the elements written out to each individual snapshot file. - static Status MakeNestedDataset(Env* env, - const std::vector& shard_dirs, - const string& compression_type, int version, - const DataTypeVector& dtypes, - const std::vector& shapes, - int64_t start_index, DatasetBase** output); + static absl::Status MakeNestedDataset( + Env* env, const std::vector& shard_dirs, + const string& compression_type, int version, const DataTypeVector& dtypes, + const std::vector& shapes, int64_t start_index, + DatasetBase** output); // Returns a nested dataset for the given datasets. static void MakeNestedDataset(const std::vector& datasets, DatasetBase** output); // Reads a vector of Tensors from the snapshot file. - virtual Status ReadTensors(std::vector* read_tensors) = 0; + virtual absl::Status ReadTensors(std::vector* read_tensors) = 0; // Skips `num_records`. Equivalent to calling `ReadTensors` `num_records` // times then discarding the results. - virtual Status SkipRecords(int64_t num_records); + virtual absl::Status SkipRecords(int64_t num_records); virtual ~Reader() = default; protected: - virtual Status Initialize(Env* env) = 0; + virtual absl::Status Initialize(Env* env) = 0; class Dataset; class NestedDataset; @@ -251,7 +250,7 @@ class TFRecordReaderImpl { // Constructs a `TFRecordReaderImpl`. // `filename` is the file to read from. // `compression_type` is the compression method, as defined in - // tensorflow/tsl/lib/io/compression.h. + // tensorflow/compiler/xla/tsl/lib/io/compression.h. // `output_buffer_size` specifies the buffer size required by Snappy/Zlib // compression algorithms. Ignored if compression is not enabled. TFRecordReaderImpl(const std::string& filename, const string& compression, @@ -259,7 +258,7 @@ class TFRecordReaderImpl { // Initializes the reader. Callers must initialize the reader before calling // `GetNext` or `GetTensors`. - Status Initialize(Env* env); + absl::Status Initialize(Env* env); // Reads the next Tensor in the input file. absl::StatusOr GetNext(); @@ -295,11 +294,13 @@ class TFRecordReader : public Reader { // Initializes the reader. Callers must initialize the reader before calling // `ReadTensors`. - Status Initialize(Env* env) override { return reader_impl_.Initialize(env); } + absl::Status Initialize(Env* env) override { + return reader_impl_.Initialize(env); + } // Reads Tensors into `read_tensors`. Returns OK on success, OutOfRange for // end of file, or an error status if there is an error. - Status ReadTensors(std::vector* read_tensors) override; + absl::Status ReadTensors(std::vector* read_tensors) override; // Returns the number of bytes read. uint64_t BytesRead() const { return reader_impl_.BytesRead(); } @@ -330,26 +331,26 @@ class CustomReader : public Reader { CustomReader(const std::string& filename, const string& compression_type, int version, const DataTypeVector& dtypes); - Status ReadTensors(std::vector* read_tensors) override; + absl::Status ReadTensors(std::vector* read_tensors) override; ~CustomReader() override = default; protected: - Status Initialize(Env* env) override; + absl::Status Initialize(Env* env) override; private: - Status ReadTensorsV0(std::vector* read_tensors); + absl::Status ReadTensorsV0(std::vector* read_tensors); - Status SnappyUncompress( + absl::Status SnappyUncompress( const experimental::SnapshotTensorMetadata* metadata, std::vector* simple_tensors, std::vector, size_t>>* tensor_proto_strs); - Status ReadRecord(tstring* record); + absl::Status ReadRecord(tstring* record); #if defined(TF_CORD_SUPPORT) - Status ReadRecord(absl::Cord* record); + absl::Status ReadRecord(absl::Cord* record); #endif std::string filename_; @@ -364,36 +365,38 @@ class CustomReader : public Reader { }; // Writes snapshot metadata to the given directory. -Status WriteMetadataFile(Env* env, const string& dir, - const experimental::SnapshotMetadataRecord* metadata); +absl::Status WriteMetadataFile( + Env* env, const string& dir, + const experimental::SnapshotMetadataRecord* metadata); // Writes distributed snapshot metadata to the given directory. An error is // returned if `dir` is unable to be created or if `metadata` is unable to be // written. -Status WriteMetadataFile( +absl::Status WriteMetadataFile( Env* env, const string& dir, const experimental::DistributedSnapshotMetadata* metadata); // Reads snapshot metadata from the given directory. -Status ReadMetadataFile(Env* env, const string& dir, - experimental::SnapshotMetadataRecord* metadata, - bool* file_exists); +absl::Status ReadMetadataFile(Env* env, const string& dir, + experimental::SnapshotMetadataRecord* metadata, + bool* file_exists); // Reads distributed snapshot metadata from the given directory. If the file // doesn't exist in `dir`, `file_exists` is set to true and an ok status is // returned. If the file exists in `dir` but is unable to be opened, an error // is returned. -Status ReadMetadataFile(Env* env, const string& dir, - experimental::DistributedSnapshotMetadata* metadata, - bool* file_exists); +absl::Status ReadMetadataFile( + Env* env, const string& dir, + experimental::DistributedSnapshotMetadata* metadata, bool* file_exists); // Writes a dataset graph to the given directory. -Status DumpDatasetGraph(Env* env, const std::string& path, uint64 hash, - const GraphDef* graph); +absl::Status DumpDatasetGraph(Env* env, const std::string& path, uint64 hash, + const GraphDef* graph); -Status DetermineOpState(const std::string& mode_string, bool file_exists, - const experimental::SnapshotMetadataRecord* metadata, - uint64 pending_snapshot_expiry_seconds, Mode* mode); +absl::Status DetermineOpState( + const std::string& mode_string, bool file_exists, + const experimental::SnapshotMetadataRecord* metadata, + uint64 pending_snapshot_expiry_seconds, Mode* mode); // Represents a dataset element or EOF. struct ElementOrEOF { @@ -420,7 +423,7 @@ class AsyncWriter { const std::string& shard_directory, uint64 checkpoint_id, const std::string& compression, int64_t version, const DataTypeVector& output_types, - std::function done); + std::function done); // Writes the given tensors. The method is non-blocking and returns without // waiting for the element to be written. @@ -433,9 +436,10 @@ class AsyncWriter { private: void Consume(ElementOrEOF* be) TF_LOCKS_EXCLUDED(mu_); bool ElementAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); - Status WriterThread(Env* env, const std::string& shard_directory, - uint64 checkpoint_id, const std::string& compression, - int64_t version, DataTypeVector output_types); + absl::Status WriterThread(Env* env, const std::string& shard_directory, + uint64 checkpoint_id, + const std::string& compression, int64_t version, + DataTypeVector output_types); mutex mu_; std::deque deque_ TF_GUARDED_BY(mu_); diff --git a/tensorflow/core/data/split_utils_test.cc b/tensorflow/core/data/split_utils_test.cc index 060f2b99b75018..5ee3569709d796 100644 --- a/tensorflow/core/data/split_utils_test.cc +++ b/tensorflow/core/data/split_utils_test.cc @@ -32,7 +32,7 @@ std::string full_name(const std::string& name) { return FullName("test", name); } -Status SaveAndRestore(SplitProvider* split_provider) { +absl::Status SaveAndRestore(SplitProvider* split_provider) { VariantTensorDataWriter writer; TF_RETURN_IF_ERROR(split_provider->Save(full_name, &writer)); std::vector variants; @@ -42,8 +42,8 @@ Status SaveAndRestore(SplitProvider* split_provider) { return absl::OkStatus(); } -Status CheckOutput(SplitProvider* split_provider, - std::vector expected) { +absl::Status CheckOutput(SplitProvider* split_provider, + std::vector expected) { int64_t next = 0; bool end_of_splits = false; while (!end_of_splits) { diff --git a/tensorflow/core/data/standalone.cc b/tensorflow/core/data/standalone.cc index 1d80acceae95ac..5f7f8ec7fd3aa6 100644 --- a/tensorflow/core/data/standalone.cc +++ b/tensorflow/core/data/standalone.cc @@ -94,7 +94,8 @@ Iterator::~Iterator() { } } -Status Iterator::GetNext(std::vector* outputs, bool* end_of_input) { +absl::Status Iterator::GetNext(std::vector* outputs, + bool* end_of_input) { return iterator_->GetNext(ctx_.get(), outputs, end_of_input); } @@ -115,7 +116,7 @@ absl::StatusOr> Iterator::Save() { return serialized; } -Status Iterator::Restore(const std::vector& saved_iterator) { +absl::Status Iterator::Restore(const std::vector& saved_iterator) { std::vector data; data.reserve(saved_iterator.size()); for (int i = 0; i < saved_iterator.size(); ++i) { @@ -135,8 +136,8 @@ Status Iterator::Restore(const std::vector& saved_iterator) { std::shared_ptr Iterator::model() const { return ctx_->model(); } -Status Dataset::FromGraph(Params params, const GraphDef& graph_def, - std::unique_ptr* result) { +absl::Status Dataset::FromGraph(Params params, const GraphDef& graph_def, + std::unique_ptr* result) { Graph graph(OpRegistry::Global()); TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr)); @@ -195,7 +196,7 @@ Status Dataset::FromGraph(Params params, const GraphDef& graph_def, return absl::OkStatus(); } // static -Status Dataset::MakeIterator( +absl::Status Dataset::MakeIterator( std::vector> split_providers, std::unique_ptr* result) { // Create an `IteratorContext`, which bundles together the necessary runtime @@ -234,11 +235,11 @@ Status Dataset::MakeIterator( return absl::OkStatus(); } -Status Dataset::MakeIterator(std::unique_ptr* result) { +absl::Status Dataset::MakeIterator(std::unique_ptr* result) { return MakeIterator(/*split_providers=*/{}, result); } -Status Dataset::MakeSplitProviders( +absl::Status Dataset::MakeSplitProviders( std::vector>* result) { return finalized_dataset_->MakeSplitProviders(result); } diff --git a/tensorflow/core/data/standalone.h b/tensorflow/core/data/standalone.h index c2e257953a1b29..5b2b2b2c15ba1e 100644 --- a/tensorflow/core/data/standalone.h +++ b/tensorflow/core/data/standalone.h @@ -84,7 +84,7 @@ class Iterator { // Returns the next element of the input pipeline (if there is one) and an // indication of whether the end of the input pipeline has been reached. - Status GetNext(std::vector* outputs, bool* end_of_input); + absl::Status GetNext(std::vector* outputs, bool* end_of_input); // Saves a checkpoint of the iterator. Returns Tensors that can be called with // `Restore()`. @@ -92,7 +92,7 @@ class Iterator { // Restores the iterator from a checkpoint. `saved_iterator` is the serialized // iterator saved by calling `Save()`. - Status Restore(const std::vector& saved_iterator); + absl::Status Restore(const std::vector& saved_iterator); // Returns the dataset model for performance analysis. std::shared_ptr model() const; @@ -119,20 +119,20 @@ class Dataset { }; // Creates a new `Dataset` instance by running the given dataset graph. - static Status FromGraph(Params params, const GraphDef& graph_def, - std::unique_ptr* result); + static absl::Status FromGraph(Params params, const GraphDef& graph_def, + std::unique_ptr* result); ~Dataset(); // Creates an iterator for this dataset. - Status MakeIterator(std::unique_ptr* result); + absl::Status MakeIterator(std::unique_ptr* result); // Creates an iterator, optionally with a split provider. - Status MakeIterator( + absl::Status MakeIterator( std::vector> split_providers, std::unique_ptr* result); // Creates split providers for this dataset. - Status MakeSplitProviders( + absl::Status MakeSplitProviders( std::vector>* result); // Returns a pointer to the underlying dataset. const DatasetBase* Get() const; diff --git a/tensorflow/core/data/standalone_save_restore_test.cc b/tensorflow/core/data/standalone_save_restore_test.cc index 9798021302614f..431c6c4a08cee0 100644 --- a/tensorflow/core/data/standalone_save_restore_test.cc +++ b/tensorflow/core/data/standalone_save_restore_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/test_util.h" #include "tensorflow/core/data/standalone.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/data/tf_data_memory_logger.cc b/tensorflow/core/data/tf_data_memory_logger.cc index f1edc3fe9d08a8..11b88b07c5f148 100644 --- a/tensorflow/core/data/tf_data_memory_logger.cc +++ b/tensorflow/core/data/tf_data_memory_logger.cc @@ -57,7 +57,7 @@ void LogDatasetMemoryUsage() { int64_t total_buffered_bytes = metric_collector->GetModel()->output()->TotalBufferedBytes(); model::ModelProto model_proto; - Status s = metric_collector->GetModel()->ToProto(&model_proto); + absl::Status s = metric_collector->GetModel()->ToProto(&model_proto); if (!s.ok()) { LOG(ERROR) << "Failed to convert model to proto: " << s; } diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD index cfd9fb86d55d6f..38dc34ed2588a7 100644 --- a/tensorflow/core/debug/BUILD +++ b/tensorflow/core/debug/BUILD @@ -287,13 +287,11 @@ tf_cc_binary( # py_proto_library( # name = "debug_service_py_pb2", # has_services = 1, -# api_version = 2, # deps = [":debug_service_proto"], # ) # # py_proto_library( # name = "debugger_event_metadata_py_pb2", -# api_version = 2, # deps = [":debugger_event_metadata_proto"], # ) # copybara:uncomment_end diff --git a/tensorflow/core/debug/bfc_dump_reader.cc b/tensorflow/core/debug/bfc_dump_reader.cc index dcb0164e999ebb..5c780c7c9ae09b 100644 --- a/tensorflow/core/debug/bfc_dump_reader.cc +++ b/tensorflow/core/debug/bfc_dump_reader.cc @@ -24,7 +24,7 @@ limitations under the License. namespace tensorflow { MemoryDump ReadDumpFile(const string& fname) { - Status status; + absl::Status status; uint64 file_size = 0; status = Env::Default()->GetFileSize(fname, &file_size); if (!status.ok()) { diff --git a/tensorflow/core/debug/debug_graph_utils.cc b/tensorflow/core/debug/debug_graph_utils.cc index 47edb02aeb61dc..1a40b13b227fd5 100644 --- a/tensorflow/core/debug/debug_graph_utils.cc +++ b/tensorflow/core/debug/debug_graph_utils.cc @@ -30,7 +30,7 @@ namespace tensorflow { namespace { // TODO(cais): Switch to safe_strtob when available. -Status ParseBoolString(const string& bool_str, bool* bool_val) { +absl::Status ParseBoolString(const string& bool_str, bool* bool_val) { const string lower_bool_str = absl::AsciiStrToLower(bool_str); if (lower_bool_str == "false" || lower_bool_str == "f" || lower_bool_str == "0") { @@ -48,7 +48,7 @@ Status ParseBoolString(const string& bool_str, bool* bool_val) { } // namespace // static -Status DebugNodeInserter::InsertNodes( +absl::Status DebugNodeInserter::InsertNodes( const protobuf::RepeatedPtrField& watches, Graph* graph, Device* device) { // TODO(cais): This method is getting too large in size. @@ -89,15 +89,16 @@ Status DebugNodeInserter::InsertNodes( watch.debug_urls().begin(), watch.debug_urls().end()); } else { - return Status(absl::StatusCode::kFailedPrecondition, - strings::StrCat( - "output_slot is expected to be -1 for wildcard ", - "node name (\"*\"), but got ", watch.output_slot())); + return absl::Status( + absl::StatusCode::kFailedPrecondition, + strings::StrCat("output_slot is expected to be -1 for wildcard ", + "node name (\"*\"), but got ", + watch.output_slot())); } continue; } else { if (watch.output_slot() < 0) { - return Status( + return absl::Status( absl::StatusCode::kFailedPrecondition, strings::StrCat("A negative output_slot in DebugTensorWatch is ", "valid only for the wildcard node name (\"*\"), ", @@ -183,12 +184,12 @@ Status DebugNodeInserter::InsertNodes( explicit_tensor_match ? tensor_watch_urls[tensor_name] : default_debug_urls; Node* copy_node; - Status copy_s = + absl::Status copy_s = CreateCopyNode(graph, device_type, memory_type == HOST_MEMORY, src_node->name(), src_output_slot, src_dt, tensor_name, debug_ops, debug_urls, ©_node); if (!copy_s.ok()) { - return Status( + return absl::Status( absl::StatusCode::kFailedPrecondition, strings::StrCat("Failed to create Copy/CopyHost node for tensor ", tensor_name, ", due to: ", copy_s.message())); @@ -203,9 +204,9 @@ Status DebugNodeInserter::InsertNodes( const string& debug_op_name = debug_ops[i]; Node* debug_node; - Status debug_s = CreateDebugNode(graph, *device, copy_node->name(), - src_dt, tensor_name, debug_urls, i, - debug_op_name, &debug_node); + absl::Status debug_s = CreateDebugNode( + graph, *device, copy_node->name(), src_dt, tensor_name, debug_urls, + i, debug_op_name, &debug_node); if (debug_s.ok()) { graph->AddEdge(copy_node, 0, debug_node, 0); debug_nodes.push_back(debug_node); @@ -215,7 +216,7 @@ Status DebugNodeInserter::InsertNodes( << "tensor name = " << tensor_name << "; " << "debug op name = " << debug_op_name; } else { - return Status( + return absl::Status( absl::StatusCode::kFailedPrecondition, strings::StrCat("Failed to create debug node ", debug_op_name, " for tensor ", tensor_name, @@ -299,7 +300,7 @@ const string DebugNodeInserter::GetDebugNodeName(const string& tensor_name, } // static -Status DebugNodeInserter::CreateCopyNode( +absl::Status DebugNodeInserter::CreateCopyNode( Graph* graph, const DeviceType device_type, const bool is_host_memory, const string& src_node_name, const int src_output, const DataType src_dt, const string& tensor_name, const std::vector& debug_ops, @@ -338,30 +339,31 @@ Status DebugNodeInserter::CreateCopyNode( .Attr("debug_ops_spec", debug_ops_spec); if (!builder.Finalize(&node_def).ok()) { - return Status( + return absl::Status( absl::StatusCode::kFailedPrecondition, strings::StrCat("Failed to create node definition ", "for copy op ", copy_node_name, " on watched tensor ", tensor_name)); } - Status s = FindKernelDef(device_type, node_def, &kdef, nullptr); + absl::Status s = FindKernelDef(device_type, node_def, &kdef, nullptr); if (!s.ok()) { - return Status( + return absl::Status( absl::StatusCode::kFailedPrecondition, strings::StrCat("Failed to find kernel definition ", "for copy op ", copy_node_name, " on watched tensor ", tensor_name)); } if (!NodeBuilder(builder).Finalize(graph, copy_node).ok()) { - return Status(absl::StatusCode::kFailedPrecondition, - strings::StrCat("Failed to create copy node ", copy_node_name, - " on watched tensor ", tensor_name)); + return absl::Status( + absl::StatusCode::kFailedPrecondition, + strings::StrCat("Failed to create copy node ", copy_node_name, + " on watched tensor ", tensor_name)); } return absl::OkStatus(); } // static -Status DebugNodeInserter::ParseDebugOpName( +absl::Status DebugNodeInserter::ParseDebugOpName( const string& debug_op_name, string* debug_op_name_proper, std::unordered_map* attributes) { const size_t l_index = debug_op_name.find('('); @@ -413,7 +415,7 @@ Status DebugNodeInserter::ParseDebugOpName( } // static -Status DebugNodeInserter::SetDebugNodeAttributes( +absl::Status DebugNodeInserter::SetDebugNodeAttributes( Node* debug_node, const std::unordered_map& attributes) { std::unordered_set unfulfilled_keys; for (const auto& item : attributes) { @@ -470,7 +472,7 @@ Status DebugNodeInserter::SetDebugNodeAttributes( } // static -Status DebugNodeInserter::CreateDebugNode( +absl::Status DebugNodeInserter::CreateDebugNode( Graph* graph, const Device& device, const string& src_copy_node_name, const DataType src_dt, const string& tensor_name, const std::vector& debug_urls, const int debug_op_num, diff --git a/tensorflow/core/debug/debug_graph_utils.h b/tensorflow/core/debug/debug_graph_utils.h index 86dc90a13483fb..27cfb357e2b9d9 100644 --- a/tensorflow/core/debug/debug_graph_utils.h +++ b/tensorflow/core/debug/debug_graph_utils.h @@ -71,7 +71,7 @@ class DebugNodeInserter { // // If the nodes (A, B and C) are located on GPU and the edges from A to B or C // is HOST_MEMORY, then the CopyHost op will be used instead of the Copy op. - static Status InsertNodes( + static absl::Status InsertNodes( const protobuf::RepeatedPtrField& watches, Graph* graph, Device* device); @@ -91,7 +91,7 @@ class DebugNodeInserter { const string& debug_op_name); private: - static Status CreateCopyNode( + static absl::Status CreateCopyNode( Graph* graph, const DeviceType device_type, const bool is_host_memory, const string& src_node_name, const int src_output, const DataType src_dt, const string& tensor_name, const std::vector& debug_ops, @@ -103,20 +103,18 @@ class DebugNodeInserter { // connected with an equal sign ("="). Multiple key-value pairs are separated // with semicolons (";"), which optional whitespace in between, e.g., // "DebugNumericSummary(mute_if_healthy=true, lower_bound=-100.0)". - static Status ParseDebugOpName( + static absl::Status ParseDebugOpName( const string& debug_op_name, string* debug_op_name_proper, std::unordered_map* attributes); - static Status SetDebugNodeAttributes( + static absl::Status SetDebugNodeAttributes( Node* debug_node, const std::unordered_map& attributes); - static Status CreateDebugNode(Graph* graph, const Device& device, - const string& src_copy_node_name, - const DataType src_dt, - const string& tensor_name, - const std::vector& debug_urls, - const int debug_op_num, - const string& debug_op_name, Node** debug_node); + static absl::Status CreateDebugNode( + Graph* graph, const Device& device, const string& src_copy_node_name, + const DataType src_dt, const string& tensor_name, + const std::vector& debug_urls, const int debug_op_num, + const string& debug_op_name, Node** debug_node); // TODO(cais): Cut down the number of args to this method. friend class DebugGraphUtilsTest; diff --git a/tensorflow/core/debug/debug_graph_utils_test.cc b/tensorflow/core/debug/debug_graph_utils_test.cc index 8033698a1795f5..5ffee94043a002 100644 --- a/tensorflow/core/debug/debug_graph_utils_test.cc +++ b/tensorflow/core/debug/debug_graph_utils_test.cc @@ -24,9 +24,9 @@ namespace tensorflow { class DebugGraphUtilsTest : public ::testing::Test { protected: - Status ParseDebugOpName(const string& debug_op_name, - string* debug_op_name_proper, - std::unordered_map* attributes) { + absl::Status ParseDebugOpName( + const string& debug_op_name, string* debug_op_name_proper, + std::unordered_map* attributes) { return DebugNodeInserter::ParseDebugOpName( debug_op_name, debug_op_name_proper, attributes); } @@ -45,8 +45,8 @@ TEST_F(DebugGraphUtilsTest, TestMalformedDebugOpName) { string debug_op_name_proper; std::unordered_map attributes; - Status s = ParseDebugOpName("(mute_if_healthy=true)", &debug_op_name_proper, - &attributes); + absl::Status s = ParseDebugOpName("(mute_if_healthy=true)", + &debug_op_name_proper, &attributes); ASSERT_TRUE(errors::IsInvalidArgument(s)); s = ParseDebugOpName("DebugNumericSummary(", &debug_op_name_proper, @@ -62,8 +62,8 @@ TEST_F(DebugGraphUtilsTest, TestDebugOpNameWithMalformedAttributes) { string debug_op_name_proper; std::unordered_map attributes; - Status s = ParseDebugOpName("DebugNumericSummary(=)", &debug_op_name_proper, - &attributes); + absl::Status s = ParseDebugOpName("DebugNumericSummary(=)", + &debug_op_name_proper, &attributes); ASSERT_TRUE(errors::IsInvalidArgument(s)); s = ParseDebugOpName("DebugNumericSummary(mute_if_healthy=)", @@ -130,7 +130,7 @@ TEST_F(DebugGraphUtilsTest, TestValidDebugOpNameWithMoreThanOneAttributes) { TEST_F(DebugGraphUtilsTest, TestValidDebugOpNameWithMoreDuplicateAttributes) { string debug_op_name_proper; std::unordered_map attributes; - Status s = ParseDebugOpName( + absl::Status s = ParseDebugOpName( "DebugNumericSummary(mute_if_healthy=true; lower_bound=3; " "mute_if_healthy=false;)", &debug_op_name_proper, &attributes); diff --git a/tensorflow/core/debug/debug_grpc_io_utils_test.cc b/tensorflow/core/debug/debug_grpc_io_utils_test.cc index 87aea157cdb04c..24cde3c1c38488 100644 --- a/tensorflow/core/debug/debug_grpc_io_utils_test.cc +++ b/tensorflow/core/debug/debug_grpc_io_utils_test.cc @@ -89,7 +89,7 @@ TEST_F(GrpcDebugTest, ConnectionTimeoutWorks) { strings::StrCat("grpc://localhost:", testing::PickUnusedPortOrDie()); Tensor tensor(DT_FLOAT, TensorShape({1, 1})); tensor.flat()(0) = 42.0; - Status publish_status = DebugIO::PublishDebugTensor( + absl::Status publish_status = DebugIO::PublishDebugTensor( DebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", "foo_tensor", 0, "DebugIdentity"), tensor, Env::Default()->NowMicros(), {kInvalidGrpcUrl}); @@ -112,7 +112,7 @@ TEST_F(GrpcDebugTest, ConnectionToDelayedStartingServerWorks) { tensor.flat()(0) = 42.0; const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", "foo_tensor", 0, "DebugIdentity"); - Status publish_status = DebugIO::PublishDebugTensor( + absl::Status publish_status = DebugIO::PublishDebugTensor( kDebugNodeKey, tensor, Env::Default()->NowMicros(), {server_data.url}); ASSERT_TRUE(publish_status.ok()); TF_ASSERT_OK(DebugIO::CloseDebugURL(server_data.url)); @@ -151,7 +151,7 @@ TEST_F(GrpcDebugTest, SendDebugTensorWithLargeStringAtIndex0ViaGrpcTest) { tensor.flat()(0) = string(5000 * 1024, 'A'); const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", "foo_tensor", 0, "DebugIdentity"); - const Status status = DebugIO::PublishDebugTensor( + const absl::Status status = DebugIO::PublishDebugTensor( kDebugNodeKey, tensor, Env::Default()->NowMicros(), {server_data_.url}); ASSERT_FALSE(status.ok()); ASSERT_NE(status.message().find("string value at index 0 from debug " @@ -167,7 +167,7 @@ TEST_F(GrpcDebugTest, SendDebugTensorWithLargeStringAtIndex1ViaGrpcTest) { tensor.flat()(1) = string(5000 * 1024, 'A'); const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", "foo_tensor", 0, "DebugIdentity"); - const Status status = DebugIO::PublishDebugTensor( + const absl::Status status = DebugIO::PublishDebugTensor( kDebugNodeKey, tensor, Env::Default()->NowMicros(), {server_data_.url}); ASSERT_FALSE(status.ok()); ASSERT_NE(status.message().find("string value at index 1 from debug " @@ -194,7 +194,7 @@ TEST_F(GrpcDebugTest, SendMultipleDebugTensorsSynchronizedViaGrpcTest) { mutex mu; Notification all_done; int tensor_count TF_GUARDED_BY(mu) = 0; - std::vector statuses TF_GUARDED_BY(mu); + std::vector statuses TF_GUARDED_BY(mu); const std::vector urls({server_data_.url}); @@ -210,7 +210,7 @@ TEST_F(GrpcDebugTest, SendMultipleDebugTensorsSynchronizedViaGrpcTest) { // Different concurrent tasks will send different tensors. const uint64 wall_time = Env::Default()->NowMicros(); - Status publish_status = DebugIO::PublishDebugTensor( + absl::Status publish_status = DebugIO::PublishDebugTensor( DebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", strings::StrCat("synchronized_node_", this_count), 0, "DebugIdentity"), @@ -235,11 +235,11 @@ TEST_F(GrpcDebugTest, SendMultipleDebugTensorsSynchronizedViaGrpcTest) { delete tp; // Close the debug gRPC stream. - Status close_status = DebugIO::CloseDebugURL(server_data_.url); + absl::Status close_status = DebugIO::CloseDebugURL(server_data_.url); ASSERT_TRUE(close_status.ok()); // Check all statuses from the PublishDebugTensor calls(). - for (const Status& status : statuses) { + for (const absl::Status& status : statuses) { TF_ASSERT_OK(status); } @@ -286,7 +286,7 @@ TEST_F(GrpcDebugTest, SendDebugTensorsThroughMultipleRoundsUsingGrpcGating) { kDebugNodeKey); // Close the debug gRPC stream. - Status close_status = DebugIO::CloseDebugURL(server_data_.url); + absl::Status close_status = DebugIO::CloseDebugURL(server_data_.url); ASSERT_TRUE(close_status.ok()); // Check dumped files according to the expected gating results. @@ -336,7 +336,7 @@ TEST_F(GrpcDebugTest, SendDebugTensorsThroughMultipleRoundsUnderReadWriteMode) { kDebugNodeKey); // Close the debug gRPC stream. - Status close_status = DebugIO::CloseDebugURL(server_data_.url); + absl::Status close_status = DebugIO::CloseDebugURL(server_data_.url); ASSERT_TRUE(close_status.ok()); // Check dumped files according to the expected gating results. diff --git a/tensorflow/core/debug/debug_grpc_testlib.cc b/tensorflow/core/debug/debug_grpc_testlib.cc index 2bc06061c459f7..d0d779b7d9378b 100644 --- a/tensorflow/core/debug/debug_grpc_testlib.cc +++ b/tensorflow/core/debug/debug_grpc_testlib.cc @@ -169,11 +169,11 @@ bool PollTillFirstRequestSucceeds(const string& server_url, while (n_attempts++ < max_attempts) { const uint64 wall_time = Env::Default()->NowMicros(); - Status publish_s = DebugIO::PublishDebugTensor( + absl::Status publish_s = DebugIO::PublishDebugTensor( DebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", "prep_node", 0, "DebugIdentity"), prep_tensor, wall_time, {server_url}); - Status close_s = DebugIO::CloseDebugURL(server_url); + absl::Status close_s = DebugIO::CloseDebugURL(server_url); if (publish_s.ok() && close_s.ok()) { success = true; diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc index 0e4e11d81e9d12..9698076c36aba1 100644 --- a/tensorflow/core/debug/debug_io_utils.cc +++ b/tensorflow/core/debug/debug_io_utils.cc @@ -130,11 +130,11 @@ const size_t StringValMaxBytesInProto(const string& str) { // Breaks a string Tensor (represented as a TensorProto) as a vector of Event // protos. -Status WrapStringTensorAsEvents(const DebugNodeKey& debug_node_key, - const uint64 wall_time_us, - const size_t chunk_size_limit, - TensorProto* tensor_proto, - std::vector* events) { +absl::Status WrapStringTensorAsEvents(const DebugNodeKey& debug_node_key, + const uint64 wall_time_us, + const size_t chunk_size_limit, + TensorProto* tensor_proto, + std::vector* events) { const protobuf::RepeatedPtrField& strs = tensor_proto->string_val(); const size_t num_strs = strs.size(); const size_t chunk_size_ub = chunk_size_limit > 0 @@ -190,10 +190,10 @@ Status WrapStringTensorAsEvents(const DebugNodeKey& debug_node_key, // proto the field summary.tensor carries the content of the tensor. // If chunk_size_limit <= 0, the tensor will not be broken into chunks, i.e., a // length-1 vector will be returned, regardless of the size of the tensor. -Status WrapTensorAsEvents(const DebugNodeKey& debug_node_key, - const Tensor& tensor, const uint64 wall_time_us, - const size_t chunk_size_limit, - std::vector* events) { +absl::Status WrapTensorAsEvents(const DebugNodeKey& debug_node_key, + const Tensor& tensor, const uint64 wall_time_us, + const size_t chunk_size_limit, + std::vector* events) { TensorProto tensor_proto; if (tensor.dtype() == DT_STRING) { // Treat DT_STRING specially, so that tensor_util.MakeNdarray in Python can @@ -251,10 +251,10 @@ string AppendTimestampToFilePath(const string& in, const uint64 timestamp) { #ifndef PLATFORM_WINDOWS // Publishes encoded GraphDef through a gRPC debugger stream, in chunks, // conforming to the gRPC message size limit. -Status PublishEncodedGraphDefInChunks(const string& encoded_graph_def, - const string& device_name, - const int64_t wall_time, - const string& debug_url) { +absl::Status PublishEncodedGraphDefInChunks(const string& encoded_graph_def, + const string& device_name, + const int64_t wall_time, + const string& debug_url) { const uint64 hash = ::tensorflow::Hash64(encoded_graph_def); const size_t total_length = encoded_graph_def.size(); const size_t num_chunks = @@ -274,7 +274,7 @@ Status PublishEncodedGraphDefInChunks(const string& encoded_graph_def, event.set_graph_def(strings::StrCat(hash, ",", device_name, ",", wall_time, "|", i, "|", num_chunks, "|", encoded_graph_def.substr(pos, len))); - const Status s = DebugGrpcIO::SendEventProtoThroughGrpcStream( + const absl::Status s = DebugGrpcIO::SendEventProtoThroughGrpcStream( event, debug_url, num_chunks - 1 == i); if (!s.ok()) { return errors::FailedPrecondition( @@ -297,13 +297,13 @@ const char* const DebugIO::kGraphTag = "graph_"; const char* const DebugIO::kHashTag = "hash"; -Status ReadEventFromFile(const string& dump_file_path, Event* event) { +absl::Status ReadEventFromFile(const string& dump_file_path, Event* event) { Env* env(Env::Default()); string content; uint64 file_size = 0; - Status s = env->GetFileSize(dump_file_path, &file_size); + absl::Status s = env->GetFileSize(dump_file_path, &file_size); if (!s.ok()) { return s; } @@ -331,7 +331,7 @@ const char* const DebugIO::kGrpcURLScheme = "grpc://"; const char* const DebugIO::kMemoryURLScheme = "memcbk://"; // Publishes debug metadata to a set of debug URLs. -Status DebugIO::PublishDebugMetadata( +absl::Status DebugIO::PublishDebugMetadata( const int64_t global_step, const int64_t session_run_index, const int64_t executor_step_index, const std::vector& input_names, const std::vector& output_names, @@ -376,7 +376,7 @@ Status DebugIO::PublishDebugMetadata( LogMessage* log_message = event.mutable_log_message(); log_message->set_message(json_metadata); - Status status; + absl::Status status; for (const string& url : debug_urls) { if (absl::StartsWith(absl::AsciiStrToLower(url), kGrpcURLScheme)) { #ifndef PLATFORM_WINDOWS @@ -418,14 +418,12 @@ Status DebugIO::PublishDebugMetadata( return status; } -Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key, - const Tensor& tensor, - const uint64 wall_time_us, - const absl::Span debug_urls, - const bool gated_grpc, - const int64_t step_id) { +absl::Status DebugIO::PublishDebugTensor( + const DebugNodeKey& debug_node_key, const Tensor& tensor, + const uint64 wall_time_us, const absl::Span debug_urls, + const bool gated_grpc, const int64_t step_id) { int32_t num_failed_urls = 0; - std::vector fail_statuses; + std::vector fail_statuses; for (const string& url : debug_urls) { if (absl::StartsWith(absl::AsciiStrToLower(url), kFileURLScheme)) { const string dump_root_dir = url.substr(strlen(kFileURLScheme)); @@ -443,20 +441,20 @@ Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key, "variable TFDBG_DISK_BYTES_LIMIT to set a higher limit."); } - Status s = debug_node_key.io_of_node.empty() - ? DebugFileIO::DumpTensorToDir(debug_node_key, tensor, - wall_time_us, dump_root_dir, - nullptr) - : DebugFileIO::DumpTensorToDirForNodeDumping( - debug_node_key, tensor, wall_time_us, dump_root_dir, - nullptr, step_id); + absl::Status s = debug_node_key.io_of_node.empty() + ? DebugFileIO::DumpTensorToDir( + debug_node_key, tensor, wall_time_us, + dump_root_dir, nullptr) + : DebugFileIO::DumpTensorToDirForNodeDumping( + debug_node_key, tensor, wall_time_us, + dump_root_dir, nullptr, step_id); if (!s.ok()) { num_failed_urls++; fail_statuses.push_back(s); } } else if (absl::StartsWith(absl::AsciiStrToLower(url), kGrpcURLScheme)) { #ifndef PLATFORM_WINDOWS - Status s = DebugGrpcIO::SendTensorThroughGrpcStream( + absl::Status s = DebugGrpcIO::SendTensorThroughGrpcStream( debug_node_key, tensor, wall_time_us, url, gated_grpc); if (!s.ok()) { @@ -473,8 +471,8 @@ Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key, CHECK(callback) << "No callback registered for: " << dump_root_dir; (*callback)(debug_node_key, tensor); } else { - return Status(absl::StatusCode::kUnavailable, - strings::StrCat("Invalid debug target URL: ", url)); + return absl::Status(absl::StatusCode::kUnavailable, + strings::StrCat("Invalid debug target URL: ", url)); } } @@ -484,25 +482,25 @@ Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key, string error_message = strings::StrCat( "Publishing to ", num_failed_urls, " of ", debug_urls.size(), " debug target URLs failed, due to the following errors:"); - for (Status& status : fail_statuses) { + for (absl::Status& status : fail_statuses) { error_message = strings::StrCat(error_message, " ", status.message(), ";"); } - return Status(absl::StatusCode::kInternal, error_message); + return absl::Status(absl::StatusCode::kInternal, error_message); } } -Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key, - const Tensor& tensor, - const uint64 wall_time_us, - const absl::Span debug_urls) { +absl::Status DebugIO::PublishDebugTensor( + const DebugNodeKey& debug_node_key, const Tensor& tensor, + const uint64 wall_time_us, const absl::Span debug_urls) { return PublishDebugTensor(debug_node_key, tensor, wall_time_us, debug_urls, false); } -Status DebugIO::PublishGraph(const Graph& graph, const string& device_name, - const std::unordered_set& debug_urls) { +absl::Status DebugIO::PublishGraph( + const Graph& graph, const string& device_name, + const std::unordered_set& debug_urls) { GraphDef graph_def; graph.ToGraphDef(&graph_def); @@ -514,7 +512,7 @@ Status DebugIO::PublishGraph(const Graph& graph, const string& device_name, event.set_wall_time(static_cast(now_micros)); event.set_graph_def(buf); - Status status = absl::OkStatus(); + absl::Status status = absl::OkStatus(); for (const string& debug_url : debug_urls) { if (absl::StartsWith(debug_url, kFileURLScheme)) { const string dump_root_dir = @@ -591,7 +589,7 @@ bool DebugIO::IsDebugURLGateOpen(const string& watch_key, #endif } -Status DebugIO::CloseDebugURL(const string& debug_url) { +absl::Status DebugIO::CloseDebugURL(const string& debug_url) { if (absl::StartsWith(debug_url, DebugIO::kGrpcURLScheme)) { #ifndef PLATFORM_WINDOWS return DebugGrpcIO::CloseGrpcStream(debug_url); @@ -604,11 +602,11 @@ Status DebugIO::CloseDebugURL(const string& debug_url) { } } -Status DebugFileIO::DumpTensorToDir(const DebugNodeKey& debug_node_key, - const Tensor& tensor, - const uint64 wall_time_us, - const string& dump_root_dir, - string* dump_file_path) { +absl::Status DebugFileIO::DumpTensorToDir(const DebugNodeKey& debug_node_key, + const Tensor& tensor, + const uint64 wall_time_us, + const string& dump_root_dir, + string* dump_file_path) { const string file_path = GetDumpFilePath(dump_root_dir, debug_node_key, wall_time_us); @@ -619,7 +617,7 @@ Status DebugFileIO::DumpTensorToDir(const DebugNodeKey& debug_node_key, return DumpTensorToEventFile(debug_node_key, tensor, wall_time_us, file_path); } -Status DebugFileIO::DumpTensorToDirForNodeDumping( +absl::Status DebugFileIO::DumpTensorToDirForNodeDumping( const DebugNodeKey& debug_node_key, const Tensor& tensor, const uint64 wall_time_us, const string& dump_root_dir, string* dump_file_path, const int64_t step_id) { @@ -656,16 +654,16 @@ string DebugFileIO::GetDumpFilePathForNodeDumping( wall_time_us); } -Status DebugFileIO::DumpEventProtoToFile(const Event& event_proto, - const string& dir_name, - const string& file_name) { +absl::Status DebugFileIO::DumpEventProtoToFile(const Event& event_proto, + const string& dir_name, + const string& file_name) { Env* env(Env::Default()); - Status s = RecursiveCreateDir(env, dir_name); + absl::Status s = RecursiveCreateDir(env, dir_name); if (!s.ok()) { - return Status(absl::StatusCode::kFailedPrecondition, - strings::StrCat("Failed to create directory ", dir_name, - ", due to: ", s.message())); + return absl::Status(absl::StatusCode::kFailedPrecondition, + strings::StrCat("Failed to create directory ", + dir_name, ", due to: ", s.message())); } const string file_path = io::JoinPath(dir_name, file_name); @@ -681,10 +679,9 @@ Status DebugFileIO::DumpEventProtoToFile(const Event& event_proto, return absl::OkStatus(); } -Status DebugFileIO::DumpTensorToEventFile(const DebugNodeKey& debug_node_key, - const Tensor& tensor, - const uint64 wall_time_us, - const string& file_path) { +absl::Status DebugFileIO::DumpTensorToEventFile( + const DebugNodeKey& debug_node_key, const Tensor& tensor, + const uint64 wall_time_us, const string& file_path) { std::vector events; TF_RETURN_IF_ERROR( WrapTensorAsEvents(debug_node_key, tensor, wall_time_us, 0, &events)); @@ -692,7 +689,7 @@ Status DebugFileIO::DumpTensorToEventFile(const DebugNodeKey& debug_node_key, string(io::Basename(file_path))); } -Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) { +absl::Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) { if (env->FileExists(dir).ok() && env->IsDirectory(dir).ok()) { // The path already exists as a directory. Return OK right away. return absl::OkStatus(); @@ -701,18 +698,19 @@ Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) { string parent_dir(io::Dirname(dir)); if (!env->FileExists(parent_dir).ok()) { // The parent path does not exist yet, create it first. - Status s = RecursiveCreateDir(env, parent_dir); // Recursive call + absl::Status s = RecursiveCreateDir(env, parent_dir); // Recursive call if (!s.ok()) { - return Status( + return absl::Status( absl::StatusCode::kFailedPrecondition, strings::StrCat("Failed to create directory ", parent_dir)); } } else if (env->FileExists(parent_dir).ok() && !env->IsDirectory(parent_dir).ok()) { // The path exists, but it is a file. - return Status(absl::StatusCode::kFailedPrecondition, - strings::StrCat("Failed to create directory ", parent_dir, - " because the path exists as a file ")); + return absl::Status( + absl::StatusCode::kFailedPrecondition, + strings::StrCat("Failed to create directory ", parent_dir, + " because the path exists as a file ")); } env->CreateDir(dir).IgnoreError(); @@ -721,8 +719,9 @@ Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) { if (env->FileExists(dir).ok() && env->IsDirectory(dir).ok()) { return absl::OkStatus(); } else { - return Status(absl::StatusCode::kAborted, - strings::StrCat("Failed to create directory ", parent_dir)); + return absl::Status( + absl::StatusCode::kAborted, + strings::StrCat("Failed to create directory ", parent_dir)); } } @@ -767,7 +766,7 @@ DebugGrpcChannel::DebugGrpcChannel(const string& server_stream_addr) : server_stream_addr_(server_stream_addr), url_(strings::StrCat(DebugIO::kGrpcURLScheme, server_stream_addr)) {} -Status DebugGrpcChannel::Connect(const int64_t timeout_micros) { +absl::Status DebugGrpcChannel::Connect(const int64_t timeout_micros) { ::grpc::ChannelArguments args; args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits::max()); // Avoid problems where default reconnect backoff is too long (e.g., 20 s). @@ -813,7 +812,7 @@ void DebugGrpcChannel::ReceiveAndProcessEventReplies(const size_t max_replies) { } } -Status DebugGrpcChannel::ReceiveServerRepliesAndClose() { +absl::Status DebugGrpcChannel::ReceiveServerRepliesAndClose() { reader_writer_->WritesDone(); // Read all EventReply messages (if any) from the server. ReceiveAndProcessEventReplies(0); @@ -821,8 +820,8 @@ Status DebugGrpcChannel::ReceiveServerRepliesAndClose() { if (reader_writer_->Finish().ok()) { return absl::OkStatus(); } else { - return Status(absl::StatusCode::kFailedPrecondition, - "Failed to close debug GRPC stream."); + return absl::Status(absl::StatusCode::kFailedPrecondition, + "Failed to close debug GRPC stream."); } } @@ -843,7 +842,7 @@ DebugGrpcIO::GetStreamChannels() { return stream_channels; } -Status DebugGrpcIO::SendTensorThroughGrpcStream( +absl::Status DebugGrpcIO::SendTensorThroughGrpcStream( const DebugNodeKey& debug_node_key, const Tensor& tensor, const uint64 wall_time_us, const string& grpc_stream_url, const bool gated) { @@ -870,7 +869,7 @@ Status DebugGrpcIO::SendTensorThroughGrpcStream( } } -Status DebugGrpcIO::ReceiveEventReplyProtoThroughGrpcStream( +absl::Status DebugGrpcIO::ReceiveEventReplyProtoThroughGrpcStream( EventReply* event_reply, const string& grpc_stream_url) { DebugGrpcChannel* debug_grpc_channel = nullptr; TF_RETURN_IF_ERROR( @@ -883,7 +882,7 @@ Status DebugGrpcIO::ReceiveEventReplyProtoThroughGrpcStream( } } -Status DebugGrpcIO::GetOrCreateDebugGrpcChannel( +absl::Status DebugGrpcIO::GetOrCreateDebugGrpcChannel( const string& grpc_stream_url, DebugGrpcChannel** debug_grpc_channel) { const string addr_with_path = absl::StartsWith(grpc_stream_url, DebugIO::kGrpcURLScheme) @@ -907,7 +906,7 @@ Status DebugGrpcIO::GetOrCreateDebugGrpcChannel( return absl::OkStatus(); } -Status DebugGrpcIO::SendEventProtoThroughGrpcStream( +absl::Status DebugGrpcIO::SendEventProtoThroughGrpcStream( const Event& event_proto, const string& grpc_stream_url, const bool receive_reply) { DebugGrpcChannel* debug_grpc_channel; @@ -946,7 +945,7 @@ bool DebugGrpcIO::IsWriteGateOpen(const string& grpc_debug_url, } } -Status DebugGrpcIO::CloseGrpcStream(const string& grpc_stream_url) { +absl::Status DebugGrpcIO::CloseGrpcStream(const string& grpc_stream_url) { mutex_lock l(streams_mu_); std::unordered_map>* @@ -954,7 +953,7 @@ Status DebugGrpcIO::CloseGrpcStream(const string& grpc_stream_url) { if (stream_channels->find(grpc_stream_url) != stream_channels->end()) { // Stream of the specified address exists. Close it and remove it from // record. - Status s = + absl::Status s = (*stream_channels)[grpc_stream_url]->ReceiveServerRepliesAndClose(); (*stream_channels).erase(grpc_stream_url); return s; diff --git a/tensorflow/core/debug/debug_io_utils.h b/tensorflow/core/debug/debug_io_utils.h index 1555f92a96df6e..95864c714682b6 100644 --- a/tensorflow/core/debug/debug_io_utils.h +++ b/tensorflow/core/debug/debug_io_utils.h @@ -36,7 +36,7 @@ limitations under the License. namespace tensorflow { -Status ReadEventFromFile(const string& dump_file_path, Event* event); +absl::Status ReadEventFromFile(const string& dump_file_path, Event* event); struct DebugWatchAndURLSpec { DebugWatchAndURLSpec(const string& watch_key, const string& url, @@ -61,7 +61,7 @@ class DebugIO { static const char* const kGrpcURLScheme; static const char* const kMemoryURLScheme; - static Status PublishDebugMetadata( + static absl::Status PublishDebugMetadata( const int64_t global_step, const int64_t session_run_index, const int64_t executor_step_index, const std::vector& input_names, const std::vector& output_names, @@ -80,25 +80,24 @@ class DebugIO { // "file:///foo/tfdbg_dump", "grpc://localhost:11011" // gated_grpc: Whether this call is subject to gRPC gating. // step_id: Step ID associated with the tensor. - static Status PublishDebugTensor(const DebugNodeKey& debug_node_key, - const Tensor& tensor, - const uint64 wall_time_us, - const absl::Span debug_urls, - bool gated_grpc, int64_t step_id = -1); + static absl::Status PublishDebugTensor( + const DebugNodeKey& debug_node_key, const Tensor& tensor, + const uint64 wall_time_us, const absl::Span debug_urls, + bool gated_grpc, int64_t step_id = -1); // Convenience overload of the method above for no gated_grpc by default. - static Status PublishDebugTensor(const DebugNodeKey& debug_node_key, - const Tensor& tensor, - const uint64 wall_time_us, - const absl::Span debug_urls); + static absl::Status PublishDebugTensor( + const DebugNodeKey& debug_node_key, const Tensor& tensor, + const uint64 wall_time_us, const absl::Span debug_urls); // Publishes a graph to a set of debug URLs. // // Args: // graph: The graph to be published. // debug_urls: The set of debug URLs to publish the graph to. - static Status PublishGraph(const Graph& graph, const string& device_name, - const std::unordered_set& debug_urls); + static absl::Status PublishGraph( + const Graph& graph, const string& device_name, + const std::unordered_set& debug_urls); // Determines whether a copy node needs to perform deep-copy of input tensor. // @@ -145,7 +144,7 @@ class DebugIO { static bool IsDebugURLGateOpen(const string& watch_key, const string& debug_url); - static Status CloseDebugURL(const string& debug_url); + static absl::Status CloseDebugURL(const string& debug_url); }; // Helper class for debug ops. @@ -170,13 +169,14 @@ class DebugFileIO { // execution. Unit: microseconds (us). // dump_root_dir: Root directory for dumping the tensor. // dump_file_path: The actual dump file path (passed as reference). - static Status DumpTensorToDir(const DebugNodeKey& debug_node_key, - const Tensor& tensor, const uint64 wall_time_us, - const string& dump_root_dir, - string* dump_file_path); + static absl::Status DumpTensorToDir(const DebugNodeKey& debug_node_key, + const Tensor& tensor, + const uint64 wall_time_us, + const string& dump_root_dir, + string* dump_file_path); // Similar to the above, but for node inputs/outputs dumping feature. - static Status DumpTensorToDirForNodeDumping( + static absl::Status DumpTensorToDirForNodeDumping( const DebugNodeKey& debug_node_key, const Tensor& tensor, uint64 wall_time_us, const string& dump_root_dir, string* dump_file_path, int64_t step_id); @@ -205,9 +205,9 @@ class DebugFileIO { // event_prot: The Event proto to be dumped. // dir_name: Directory path. // file_name: Base file name. - static Status DumpEventProtoToFile(const Event& event_proto, - const string& dir_name, - const string& file_name); + static absl::Status DumpEventProtoToFile(const Event& event_proto, + const string& dir_name, + const string& file_name); // Request additional bytes to be dumped to the file system. // @@ -231,15 +231,15 @@ class DebugFileIO { private: // Encapsulates the Tensor in an Event protobuf and write it to file. - static Status DumpTensorToEventFile(const DebugNodeKey& debug_node_key, - const Tensor& tensor, - const uint64 wall_time_us, - const string& file_path); + static absl::Status DumpTensorToEventFile(const DebugNodeKey& debug_node_key, + const Tensor& tensor, + const uint64 wall_time_us, + const string& file_path); // Implemented ad hoc here for now. // TODO(cais): Replace with shared implementation once http://b/30497715 is // fixed. - static Status RecursiveCreateDir(Env* env, const string& dir); + static absl::Status RecursiveCreateDir(Env* env, const string& dir); // Tracks how much disk has been used so far. static uint64 disk_bytes_used_; @@ -295,7 +295,7 @@ class DebugGrpcChannel { // Returns: // OK Status iff connection is successfully established before timeout, // otherwise return an error Status. - Status Connect(const int64_t timeout_micros); + absl::Status Connect(const int64_t timeout_micros); // Write an Event proto to the debug gRPC stream. // @@ -334,7 +334,7 @@ class DebugGrpcChannel { // Receive EventReplies from server (if any) and close the stream and the // channel. - Status ReceiveServerRepliesAndClose(); + absl::Status ReceiveServerRepliesAndClose(); private: string server_stream_addr_; @@ -354,11 +354,10 @@ class DebugGrpcIO { static const size_t kGrpcMaxVarintLengthSize; // Sends a tensor through a debug gRPC stream. - static Status SendTensorThroughGrpcStream(const DebugNodeKey& debug_node_key, - const Tensor& tensor, - const uint64 wall_time_us, - const string& grpc_stream_url, - const bool gated); + static absl::Status SendTensorThroughGrpcStream( + const DebugNodeKey& debug_node_key, const Tensor& tensor, + const uint64 wall_time_us, const string& grpc_stream_url, + const bool gated); // Sends an Event proto through a debug gRPC stream. // Thread-safety: Safe with respect to other calls to the same method and @@ -373,12 +372,12 @@ class DebugGrpcIO { // // Returns: // The Status of the operation. - static Status SendEventProtoThroughGrpcStream( + static absl::Status SendEventProtoThroughGrpcStream( const Event& event_proto, const string& grpc_stream_url, const bool receive_reply = false); // Receive an EventReply proto through a debug gRPC stream. - static Status ReceiveEventReplyProtoThroughGrpcStream( + static absl::Status ReceiveEventReplyProtoThroughGrpcStream( EventReply* event_reply, const string& grpc_stream_url); // Check whether a debug watch key is read-activated at a given gRPC URL. @@ -393,7 +392,7 @@ class DebugGrpcIO { // Closes a gRPC stream to the given address, if it exists. // Thread-safety: Safe with respect to other calls to the same method and // calls to SendTensorThroughGrpcStream(). - static Status CloseGrpcStream(const string& grpc_stream_url); + static absl::Status CloseGrpcStream(const string& grpc_stream_url); // Set the gRPC state of a debug node key. // TODO(cais): Include device information in watch_key. @@ -420,7 +419,7 @@ class DebugGrpcIO { // // Returns: // Status of this operation. - static Status GetOrCreateDebugGrpcChannel( + static absl::Status GetOrCreateDebugGrpcChannel( const string& grpc_stream_url, DebugGrpcChannel** debug_grpc_channel); // Returns a map from debug URL to a map from debug op name to enabled state. diff --git a/tensorflow/core/debug/debug_io_utils_test.cc b/tensorflow/core/debug/debug_io_utils_test.cc index 74d5758c306ef1..9170ae04c531fe 100644 --- a/tensorflow/core/debug/debug_io_utils_test.cc +++ b/tensorflow/core/debug/debug_io_utils_test.cc @@ -163,8 +163,8 @@ TEST_F(DebugIOUtilsTest, DumpStringTensorToFileSunnyDay) { const uint64 wall_time = env_->NowMicros(); string dump_file_name; - Status s = DebugFileIO::DumpTensorToDir(kDebugNodeKey, *tensor_b_, wall_time, - test_dir, &dump_file_name); + absl::Status s = DebugFileIO::DumpTensorToDir( + kDebugNodeKey, *tensor_b_, wall_time, test_dir, &dump_file_name); ASSERT_TRUE(s.ok()); // Read the file into a Event proto. @@ -240,8 +240,8 @@ TEST_F(DebugIOUtilsTest, DumpTensorToFileCannotCreateDirectory) { const uint64 wall_time = env_->NowMicros(); string dump_file_name; - Status s = DebugFileIO::DumpTensorToDir(kDebugNodeKey, *tensor_a_, wall_time, - test_dir, &dump_file_name); + absl::Status s = DebugFileIO::DumpTensorToDir( + kDebugNodeKey, *tensor_a_, wall_time, test_dir, &dump_file_name); ASSERT_FALSE(s.ok()); // Tear down temporary file and directories. @@ -279,7 +279,7 @@ TEST_F(DebugIOUtilsTest, PublishTensorToMultipleFileURLs) { ASSERT_NE(dump_roots[0], dump_roots[i]); } - Status s = + absl::Status s = DebugIO::PublishDebugTensor(kDebugNodeKey, *tensor_a_, wall_time, urls); ASSERT_TRUE(s.ok()); @@ -349,7 +349,7 @@ TEST_F(DebugIOUtilsTest, PublishTensorToMemoryCallback) { } }); - Status s = + absl::Status s = DebugIO::PublishDebugTensor(kDebugNodeKey, *tensor_a_, wall_time, urls); ASSERT_TRUE(s.ok()); ASSERT_TRUE(called); @@ -403,7 +403,7 @@ TEST_F(DebugIOUtilsTest, PublishTensorConcurrentlyToPartiallyOverlappingPaths) { std::vector urls; urls.push_back(debug_url); - Status s = + absl::Status s = DebugIO::PublishDebugTensor(kDebugNodeKey, *tensor_a_, wall_time, urls); ASSERT_TRUE(s.ok()); diff --git a/tensorflow/core/debug/debugger_state_impl.cc b/tensorflow/core/debug/debugger_state_impl.cc index df0404f8bb89ad..a1545ad1aa1516 100644 --- a/tensorflow/core/debug/debugger_state_impl.cc +++ b/tensorflow/core/debug/debugger_state_impl.cc @@ -38,7 +38,7 @@ DebuggerState::~DebuggerState() { } } -Status DebuggerState::PublishDebugMetadata( +absl::Status DebuggerState::PublishDebugMetadata( const int64_t global_step, const int64_t session_run_index, const int64_t executor_step_index, const std::vector& input_names, const std::vector& output_names, @@ -48,14 +48,14 @@ Status DebuggerState::PublishDebugMetadata( output_names, target_names, debug_urls_); } -Status DebugGraphDecorator::DecorateGraph(Graph* graph, Device* device) { +absl::Status DebugGraphDecorator::DecorateGraph(Graph* graph, Device* device) { DebugNodeInserter::DeparallelizeWhileLoops(graph, device); return DebugNodeInserter::InsertNodes( debug_options_.debug_tensor_watch_opts(), graph, device); } -Status DebugGraphDecorator::PublishGraph(const Graph& graph, - const string& device_name) { +absl::Status DebugGraphDecorator::PublishGraph(const Graph& graph, + const string& device_name) { std::unordered_set debug_urls; for (const DebugTensorWatch& watch : debug_options_.debug_tensor_watch_opts()) { diff --git a/tensorflow/core/debug/debugger_state_impl.h b/tensorflow/core/debug/debugger_state_impl.h index 4114d68549e2f0..c34aa8bb51a917 100644 --- a/tensorflow/core/debug/debugger_state_impl.h +++ b/tensorflow/core/debug/debugger_state_impl.h @@ -32,12 +32,11 @@ class DebuggerState : public DebuggerStateInterface { // // See the doc string of DebuggerStateInterface::PublishDebugMetadata() for // details. - Status PublishDebugMetadata(const int64_t global_step, - const int64_t session_run_count, - const int64_t executor_step_count, - const std::vector& input_names, - const std::vector& output_names, - const std::vector& target_names) override; + absl::Status PublishDebugMetadata( + const int64_t global_step, const int64_t session_run_count, + const int64_t executor_step_count, const std::vector& input_names, + const std::vector& output_names, + const std::vector& target_names) override; private: std::unordered_set debug_urls_; @@ -49,8 +48,9 @@ class DebugGraphDecorator : public DebugGraphDecoratorInterface { : debug_options_(debug_options) {} ~DebugGraphDecorator() override {} - Status DecorateGraph(Graph* graph, Device* device) override; - Status PublishGraph(const Graph& graph, const string& device_name) override; + absl::Status DecorateGraph(Graph* graph, Device* device) override; + absl::Status PublishGraph(const Graph& graph, + const string& device_name) override; private: DebugOptions debug_options_; diff --git a/tensorflow/core/debug/grpc_session_debug_test.cc b/tensorflow/core/debug/grpc_session_debug_test.cc index 72cbbe82e42b0a..9f23b6777da73c 100644 --- a/tensorflow/core/debug/grpc_session_debug_test.cc +++ b/tensorflow/core/debug/grpc_session_debug_test.cc @@ -255,7 +255,7 @@ TEST_F(GrpcSessionDebugTest, MultiDevices_String) { SetDevice(&def, a->name(), a_dev.name()); SetDevice(&def, b->name(), b_dev.name()); - Status s = session->Create(def); + absl::Status s = session->Create(def); if (s.ok()) { std::vector outputs; diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index 143681cee500ed..00515c71df7917 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -172,10 +172,10 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/activity_watcher", "//tensorflow/core/protobuf:worker_proto_cc", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service_agent", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service_rpc_handler", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc", ], ) @@ -392,7 +392,7 @@ cc_library( "//tensorflow/core/platform:regexp", "//tensorflow/core/protobuf:master_proto_cc", "//tensorflow/core/protobuf:worker_proto_cc", - "@local_tsl//tsl/protobuf:rpc_options_proto_cc", + "@local_xla//xla/tsl/protobuf:rpc_options_proto_cc", ], ) @@ -491,7 +491,7 @@ cc_library( ":worker_cache", "//tensorflow/core:protos_all_cc", "//tensorflow/core:session_options", - "@local_tsl//tsl/protobuf:rpc_options_proto_cc", + "@local_xla//xla/tsl/protobuf:rpc_options_proto_cc", ], ) diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc index a31794d7591899..bdc2acbcd5b5a0 100644 --- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc @@ -67,13 +67,13 @@ void BaseRendezvousMgr::RecvLocalAsync(int64_t step_id, FindOrCreate(step_id)->RecvLocalAsync(parsed, std::move(done)); } -Status BaseRendezvousMgr::RecvLocal(int64_t step_id, - const Rendezvous::ParsedKey& parsed, - Tensor* val, bool* is_dead) { - Status ret; +absl::Status BaseRendezvousMgr::RecvLocal(int64_t step_id, + const Rendezvous::ParsedKey& parsed, + Tensor* val, bool* is_dead) { + absl::Status ret; Notification n; RecvLocalAsync(step_id, parsed, - [val, is_dead, &ret, &n](const Status& s, + [val, is_dead, &ret, &n](const absl::Status& s, const Rendezvous::Args& send_args, const Rendezvous::Args& recv_args, const Tensor& v, const bool dead) { @@ -121,7 +121,7 @@ static bool IsImplicitLocalDevice( return !DeviceNameUtils::HasSomeDetails(parsed_device_name); } -Status BaseRemoteRendezvous::Initialize(WorkerSession* session) { +absl::Status BaseRemoteRendezvous::Initialize(WorkerSession* session) { CHECK_NE(session, nullptr) << "session must not be null!"; std::vector deferred_calls; { @@ -131,7 +131,7 @@ Status BaseRemoteRendezvous::Initialize(WorkerSession* session) { VLOG(1) << "Skipping rendezvous re-initialization."; return absl::OkStatus(); } - Status s = errors::Internal( + absl::Status s = errors::Internal( "Double init! Worker names would have changed from: ", session_->worker_name(), " -> ", session->worker_name()); LOG(WARNING) << s; @@ -156,9 +156,9 @@ bool BaseRemoteRendezvous::is_initialized() { return is_initialized_locked(); } -Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed, - const Rendezvous::Args& args, - const Tensor& val, const bool is_dead) { +absl::Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed, + const Rendezvous::Args& args, + const Tensor& val, const bool is_dead) { VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << parsed.FullKey(); WorkerSession* sess = nullptr; { @@ -179,8 +179,8 @@ Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed, return local_.Send(parsed, args, val, is_dead); } -Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed, - bool is_src) { +absl::Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed, + bool is_src) { // Cache session pointer to avoid repeatedly taking & releasing the lock // (e.g. calling session()) WorkerSession* sess = nullptr; @@ -235,7 +235,8 @@ void BaseRemoteRendezvous::SameWorkerRecvDone( WorkerSession* sess = session(); Device* src_device; - Status s = sess->device_mgr()->LookupDevice(parsed.src_device, &src_device); + absl::Status s = + sess->device_mgr()->LookupDevice(parsed.src_device, &src_device); if (!s.ok()) { done(s); return; @@ -287,7 +288,7 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed, const Rendezvous::Args& recv_args, DoneCallback done) { VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey(); - Status s = ValidateDevices(parsed, false /*!is_src*/); + absl::Status s = ValidateDevices(parsed, false /*!is_src*/); if (!s.ok()) { done(s, Args(), recv_args, Tensor(), false); return; @@ -309,13 +310,13 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed, local_.RecvAsync( parsed, recv_args, [this, parsed, done]( - const Status& status, const Rendezvous::Args& send_args, + const absl::Status& status, const Rendezvous::Args& send_args, const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) { VLOG(2) << "RemoteRendezvous Finished Local Recv " << this << " " << parsed.FullKey(); Tensor* out = new Tensor; StatusCallback final_callback = [done, send_args, recv_args, out, - is_dead](const Status& s) { + is_dead](const absl::Status& s) { done(s, send_args, recv_args, *out, is_dead); delete out; }; @@ -332,7 +333,7 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed, // Keep current rendezvous alive while the recv is inflight. this->Ref(); RecvFromRemoteAsync(parsed, recv_args, - [this, parsed, done](const Status& status, + [this, parsed, done](const absl::Status& status, const Rendezvous::Args& send_args, const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) { @@ -368,7 +369,7 @@ void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed, void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed, DoneCallback done) { - Status s = ValidateDevices(parsed, true /* is_src */); + absl::Status s = ValidateDevices(parsed, true /* is_src */); if (!s.ok()) { done(s, Args(), Args(), Tensor(), false); return; @@ -376,13 +377,13 @@ void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed, local_.RecvAsync(parsed, Args(), std::move(done)); } -void BaseRemoteRendezvous::StartAbort(const Status& s) { +void BaseRemoteRendezvous::StartAbort(const absl::Status& s) { CHECK(!s.ok()); // If the status passed in is a cancelled or aborted error, mark it as // "derived" for the rendezvous. Derived status messages are ignored when // aggregating errors across devices: this allows us to prefer our original // status message over any cancellation related errors. - Status derived_status = s; + absl::Status derived_status = s; if (absl::IsCancelled(s) || absl::IsAborted(s)) { derived_status = StatusGroup::MakeDerived(s); } diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h index 3b5c99b16988d3..4713f3be4efa83 100644 --- a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h +++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h @@ -89,8 +89,8 @@ class BaseRendezvousMgr : public RendezvousMgrInterface { Rendezvous::DoneCallback done) override; // Synchronous wrapper for RecvLocalAsync. - Status RecvLocal(int64_t step_id, const Rendezvous::ParsedKey& parsed, - Tensor* val, bool* is_dead) override; + absl::Status RecvLocal(int64_t step_id, const Rendezvous::ParsedKey& parsed, + Tensor* val, bool* is_dead) override; // Removes rendezvous for "step_id". void Cleanup(int64_t step_id) override { cache_->RemoveAndAbort(step_id); } @@ -125,7 +125,7 @@ class BaseRemoteRendezvous : public RemoteRendezvous { BaseRemoteRendezvous(const WorkerEnv* env, int64_t step_id); // Upgrades the BaseRemoteRendezvous to full initialization. - Status Initialize(WorkerSession* session) override; + absl::Status Initialize(WorkerSession* session) override; void SetRemoteEagerContextDefault() override { remote_eager_context_default_ = true; @@ -136,8 +136,8 @@ class BaseRemoteRendezvous : public RemoteRendezvous { // Forwards to local_, where the Tensor "val" will be buffered and // any waiting callback stored. - Status Send(const ParsedKey& key, const Rendezvous::Args& args, - const Tensor& val, const bool is_dead) override; + absl::Status Send(const ParsedKey& key, const Rendezvous::Args& args, + const Tensor& val, const bool is_dead) override; // This method is called only by the RecvOp. It tests to see // whether the value will be produced by a local or remote device @@ -146,7 +146,7 @@ class BaseRemoteRendezvous : public RemoteRendezvous { void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args, DoneCallback done) override; - void StartAbort(const Status& status) override; + void StartAbort(const absl::Status& status) override; // This method is called only by the local Worker, forwarded through // the same method on RendezvousMgr. This occurs when the Worker @@ -199,7 +199,7 @@ class BaseRemoteRendezvous : public RemoteRendezvous { mutable mutex calls_mu_; // Status given by StartAbort() if any. - Status status_ TF_GUARDED_BY(mu_); + absl::Status status_ TF_GUARDED_BY(mu_); WorkerSession* session_ TF_GUARDED_BY(mu_); // Not owned. @@ -257,7 +257,8 @@ class BaseRemoteRendezvous : public RemoteRendezvous { // If "is_src" is true, checks that the rendezvous key "parsed"'s // source is in this process. If "is_src" is false, checks that the // rendezvous key "parsed"'s destination is in this process. - Status ValidateDevices(const Rendezvous::ParsedKey& parsed, bool is_src); + absl::Status ValidateDevices(const Rendezvous::ParsedKey& parsed, + bool is_src); // Callback handling the case when a rendezvous has been // accomplished in local_ and the consumer is local to this process. @@ -282,9 +283,9 @@ class BaseRecvTensorCall { virtual void Start(std::function recv_done) = 0; - virtual void StartAbort(const Status& s) = 0; + virtual void StartAbort(const absl::Status& s) = 0; - virtual Status status() const = 0; + virtual absl::Status status() const = 0; private: BaseRecvTensorCall(const BaseRecvTensorCall&) = delete; diff --git a/tensorflow/core/distributed_runtime/cancellable_call.cc b/tensorflow/core/distributed_runtime/cancellable_call.cc index ed25c3a19474f2..875d35da14bfe5 100644 --- a/tensorflow/core/distributed_runtime/cancellable_call.cc +++ b/tensorflow/core/distributed_runtime/cancellable_call.cc @@ -25,7 +25,7 @@ void CancellableCall::Start(const StatusCallback& done) { const bool not_yet_cancelled = cancel_mgr_->RegisterCallback(token, [this]() { Cancel(); }); if (not_yet_cancelled) { - IssueCall([this, token, done](const Status& s) { + IssueCall([this, token, done](const absl::Status& s) { cancel_mgr_->DeregisterCallback(token); done(s); }); diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc index 04254815c59ac2..9fe115b02d3154 100644 --- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc +++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.cc @@ -35,7 +35,7 @@ limitations under the License. namespace tensorflow { /* static */ -Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph( +absl::Status ClusterFunctionLibraryRuntime::ConstructFunctionGraph( const OpDef& sig, AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options, const FunctionLibraryDefinition& flib_def, GraphDef* gdef, @@ -211,7 +211,7 @@ void ClusterFunctionLibraryRuntime::Instantiate( &gdef, send_keys, recv_keys)); return absl::OkStatus(); }; - Status s; + absl::Status s; if (options.lib_def) { s = construct_graph_fn(options.lib_def); } else { @@ -236,7 +236,7 @@ void ClusterFunctionLibraryRuntime::Instantiate( wi->RegisterGraphAsync( req, resp, [this, handle, req, resp, worker_cache, wi, function_name, target, - send_keys, recv_keys, done](const Status& status) { + send_keys, recv_keys, done](const absl::Status& status) { if (status.ok()) { mutex_lock l(mu_); *handle = function_data_.size(); @@ -294,8 +294,9 @@ void ClusterFunctionLibraryRuntime::Run( CallOptions* call_options = new CallOptions(); wi->RunGraphAsync( call_options, req, resp, - [call_options, req, resp, rets, recv_keys, done](const Status& status) { - Status* local_status = new Status(status); + [call_options, req, resp, rets, recv_keys, + done](const absl::Status& status) { + absl::Status* local_status = new absl::Status(status); auto cleanup = gtl::MakeCleanup([call_options, req, resp, local_status, done] { done(*local_status); @@ -348,16 +349,17 @@ void ClusterFunctionLibraryRuntime::Run( } } std::vector* ret_tensors = new std::vector; - return Run(opts, handle, tensors, ret_tensors, - [rets, ret_tensors, done = std::move(done)](const Status& s) { - if (s.ok()) { - for (const auto& t : *ret_tensors) { - rets->push_back(t); - } - } - delete ret_tensors; - done(s); - }); + return Run( + opts, handle, tensors, ret_tensors, + [rets, ret_tensors, done = std::move(done)](const absl::Status& s) { + if (s.ok()) { + for (const auto& t : *ret_tensors) { + rets->push_back(t); + } + } + delete ret_tensors; + done(s); + }); } void ClusterFunctionLibraryRuntime::CleanUp( @@ -381,7 +383,7 @@ void ClusterFunctionLibraryRuntime::CleanUp( CleanupGraphResponse* cleanup_resp = new CleanupGraphResponse; wi->CleanupGraphAsync( cleanup_req, cleanup_resp, - [cleanup_req, cleanup_resp, done](const Status& cleanup_status) { + [cleanup_req, cleanup_resp, done](const absl::Status& cleanup_status) { done(cleanup_status); delete cleanup_req; delete cleanup_resp; diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h index 397325529b9ff5..a016a5eea418df 100644 --- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h +++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime.h @@ -63,7 +63,7 @@ class ClusterFunctionLibraryRuntime : public DistributedFunctionLibraryRuntime { DeviceMgr* remote_device_mgr() const override { return remote_device_mgr_; } private: - static Status ConstructFunctionGraph( + static absl::Status ConstructFunctionGraph( const OpDef& sig, AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options, const FunctionLibraryDefinition& flib_def, GraphDef* g, diff --git a/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc b/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc index 0859a1607cf117..f61865fc4de540 100644 --- a/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc +++ b/tensorflow/core/distributed_runtime/cluster_function_library_runtime_test.cc @@ -67,7 +67,7 @@ class ClusterFunctionLibraryRuntimeTest : public ::testing::Test { worker_session_.get(), true, nullptr); } - Status ConstructFunctionGraphHelper( + absl::Status ConstructFunctionGraphHelper( const OpDef& sig, test::function::Attrs attrs, const FunctionLibraryRuntime::InstantiateOptions& options, const FunctionLibraryDefinition& lib_def, GraphDef* g, @@ -86,19 +86,20 @@ class ClusterFunctionLibraryRuntimeTest : public ::testing::Test { local_handle, done); } - Status InstantiateAndRun( + absl::Status InstantiateAndRun( const string& function_name, const FunctionLibraryDefinition& lib_def, test::function::Attrs attrs, const FunctionLibraryRuntime::InstantiateOptions& options, const std::vector& args, std::vector rets) { FunctionLibraryRuntime::LocalHandle handle; - Status status; + absl::Status status; Notification instantiate_done; - cluster_flr_->Instantiate(function_name, lib_def, attrs, options, &handle, - [&status, &instantiate_done](const Status& s) { - status = s; - instantiate_done.Notify(); - }); + cluster_flr_->Instantiate( + function_name, lib_def, attrs, options, &handle, + [&status, &instantiate_done](const absl::Status& s) { + status = s; + instantiate_done.Notify(); + }); instantiate_done.WaitForNotification(); if (!status.ok()) { return status; @@ -108,7 +109,7 @@ class ClusterFunctionLibraryRuntimeTest : public ::testing::Test { FunctionLibraryRuntime::Options opts; std::vector out; cluster_flr_->Run(opts, handle, args, &out, - [&status, &done](const Status& s) { + [&status, &done](const absl::Status& s) { status = s; done.Notify(); }); diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc index 92f7251d3ad2c5..ab13146b73bbbd 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc @@ -113,7 +113,7 @@ void CollectiveParamResolverDistributed::CompleteParamsAsync( if (cp->run_group_initialization) { CompleteGroupDistributed( device, &cp->group, cancel_mgr, - [this, device, cp, cancel_mgr, done](Status s) { + [this, device, cp, cancel_mgr, done](absl::Status s) { if (s.ok()) { std::vector devices; devices.reserve(cp->group.group_size); @@ -180,13 +180,13 @@ void CollectiveParamResolverDistributed::CompleteInstanceAsync( for (int32_t offset : request->subdiv_offset()) { cp->instance.impl_details.subdiv_offsets.push_back(offset); } - StatusCallback done_and_cleanup = [cp, done](const Status& s) { + StatusCallback done_and_cleanup = [cp, done](const absl::Status& s) { done(s); cp->Unref(); }; CompleteInstanceDistributed( request->device(), cp, cancel_mgr, - [this, cp, response, done_and_cleanup](Status status) { + [this, cp, response, done_and_cleanup](absl::Status status) { if (status.ok()) { // Now source_rank should be known, so retrieve it. bool created_irec; @@ -214,7 +214,7 @@ CollectiveParamResolverDistributed::GetCachedGroup(int32_t group_key) { return it->second.get(); } -Status CollectiveParamResolverDistributed::UpdateGroupCache( +absl::Status CollectiveParamResolverDistributed::UpdateGroupCache( const CompleteGroupResponse& resp) { // Build a new record from resp. std::unique_ptr gr(new GroupRec); @@ -295,10 +295,10 @@ void CollectiveParamResolverDistributed::CompleteGroupDistributed( return; } call->Start([this, device, group_params, call, cancel_mgr, abortion_token, - done](const Status& s) { + done](const absl::Status& s) { abortion_cancel_mgr_.DeregisterCallback(abortion_token); if (s.ok()) { - Status status = UpdateGroupCache(call->resp_); + absl::Status status = UpdateGroupCache(call->resp_); if (status.ok()) { CompleteGroupLocal(device, group_params, cancel_mgr, done); } else { @@ -327,7 +327,7 @@ bool CollectiveParamResolverDistributed::InstanceIsCached( return instance_it != group_it->second.end(); } -Status CollectiveParamResolverDistributed::UpdateInstanceCache( +absl::Status CollectiveParamResolverDistributed::UpdateInstanceCache( CollectiveParams* cp, const CompleteInstanceResponse& resp) { int32_t source_rank = resp.source_rank(); bool created_irec; @@ -384,7 +384,7 @@ void CollectiveParamResolverDistributed::CompleteInstanceDistributed( delete call; return; } - call->Start([this, device, cp, call, abortion_token, done](Status s) { + call->Start([this, device, cp, call, abortion_token, done](absl::Status s) { abortion_cancel_mgr_.DeregisterCallback(abortion_token); if (s.ok()) { s = UpdateInstanceCache(cp, call->resp_); @@ -400,7 +400,7 @@ void CollectiveParamResolverDistributed::CompleteInstanceDistributed( } } -void CollectiveParamResolverDistributed::StartAbort(const Status& s) { +void CollectiveParamResolverDistributed::StartAbort(const absl::Status& s) { { mutex_lock l(status_mu_); if (!status_.ok()) { diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h index 4d75e9d4e18e2f..63006c1253547e 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h @@ -48,7 +48,7 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal { CancellationManager* cancel_mgr, const StatusCallback& done) override; - void StartAbort(const Status& s) override; + void StartAbort(const absl::Status& s) override; protected: // Returns the cached group iff there's an entry for this group_key in the @@ -56,7 +56,7 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal { GroupRec* GetCachedGroup(int32_t group_key) TF_LOCKS_EXCLUDED(group_mu_); // Updates group_table_ with contents of resp. - Status UpdateGroupCache(const CompleteGroupResponse& resp) + absl::Status UpdateGroupCache(const CompleteGroupResponse& resp) TF_LOCKS_EXCLUDED(group_mu_); // Finds the GroupRec that corresponds to cp->group_key and also @@ -75,8 +75,8 @@ class CollectiveParamResolverDistributed : public CollectiveParamResolverLocal { TF_LOCKS_EXCLUDED(instance_mu_); // Updates instance_table_ with contents of resp. - Status UpdateInstanceCache(CollectiveParams* cp, - const CompleteInstanceResponse& resp) + absl::Status UpdateInstanceCache(CollectiveParams* cp, + const CompleteInstanceResponse& resp) TF_LOCKS_EXCLUDED(instance_mu_, group_mu_); // Finish populating *cp. Semantics are like those of diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc index d699b1d2275a2c..f48758f993677f 100644 --- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc +++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed_test.cc @@ -39,7 +39,7 @@ static std::unique_ptr NewDevice(const string& type, class FakeDevice : public Device { public: explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} - Status Sync() override { return absl::OkStatus(); } + absl::Status Sync() override { return absl::OkStatus(); } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } }; DeviceAttributes attr; @@ -75,7 +75,7 @@ class FakeCache : public TestWorkerCache { WorkerInterface* wi = it->second; GetStatusRequest req; GetStatusResponse resp; - Status status = wi->GetStatus(&req, &resp); + absl::Status status = wi->GetStatus(&req, &resp); if (!status.ok()) { done(status); return; @@ -101,7 +101,7 @@ class FakeNcclCommunicator : public NcclCommunicatorInterface { done(absl::OkStatus()); } - void StartAbort(const Status& s) override {} + void StartAbort(const absl::Status& s) override {} }; class DeviceResDistTest : public ::testing::Test { @@ -220,7 +220,7 @@ class DeviceResDistTest : public ::testing::Test { CHECK(cp_res); cp_res->CompleteParamsAsync( device->attributes(), cp, &cm_, - [this, device_name, group_size](const Status& s) { + [this, device_name, group_size](const absl::Status& s) { status_[device_name] = s; { mutex_lock l(mu_); @@ -312,7 +312,7 @@ class DeviceResDistTest : public ::testing::Test { absl::flat_hash_map> workers_; // Below are keyed by device names; absl::flat_hash_map cp_; - absl::flat_hash_map status_; + absl::flat_hash_map status_; mutex mu_; int num_done_ TF_GUARDED_BY(mu_); condition_variable done_; diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc index eae1750c0b7de1..1b4ba6296f4978 100644 --- a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc +++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc @@ -79,8 +79,8 @@ void PopulateTensorFromExtra(const RecvBufRespExtra& extra, } } -Status PopulateTensorFromResponse(const RecvBufResponse& response, - Tensor* cpu_tensor) { +absl::Status PopulateTensorFromResponse(const RecvBufResponse& response, + Tensor* cpu_tensor) { const bool has_transport_options = response.has_transport_options(); // If there are no transport options, then the tensor has already been @@ -129,8 +129,8 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer( }; State* state = new State; - Status s = dev_resolver_->GetDeviceAttributes(peer_device, - &state->server_attributes); + absl::Status s = dev_resolver_->GetDeviceAttributes( + peer_device, &state->server_attributes); if (!s.ok()) { delete state; done(s); @@ -144,7 +144,7 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer( // Use GPU-registered memory for the CPU tensor so the transfer // goes faster. - Status status = dev_mgr_->LookupDevice("CPU:0", &cpu_dev); + absl::Status status = dev_mgr_->LookupDevice("CPU:0", &cpu_dev); if (!status.ok()) { delete state; done(s); @@ -169,7 +169,7 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer( // Logic to be executed on the RecvBufAsync callback. auto recv_buf_callback = [this, state, to_device, to_alloc_attr, to_device_ctx, to_tensor, cpu_dev, - dev_to_dev_stream_index, dst_tensor, done](const Status& s) { + dev_to_dev_stream_index, dst_tensor, done](const absl::Status& s) { if (s.ok()) { // In this generic implementation the bytes come back in one of 2 // ways: @@ -183,7 +183,7 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer( // (NOP in 2nd case) In case the final to_tensor is on GPU, buf_ptr // points to a tmp CPU buffer and needs to be copied over to // to_tensor. - Status status = + absl::Status status = PopulateTensorFromResponse(state->call->resp_, dst_tensor); if (!status.ok()) { done(status); @@ -198,7 +198,7 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer( nullptr /*send_dev_ctx*/, to_device_ctx, cpu_dev, to_device, cpu_attr, to_alloc_attr, dst_tensor, to_tensor, dev_to_dev_stream_index, - [this, state, done](const Status& s) { + [this, state, done](const absl::Status& s) { delete state; // This callback must not block, so execute // done in another thread. @@ -222,11 +222,12 @@ void CollectiveRemoteAccessDistributed::RecvFromPeer( if (already_aborted) { recv_buf_callback(errors::Cancelled("collective ops already aborted")); } else { - state->call->Start([this, abortion_token, - done = std::move(recv_buf_callback)](const Status& s) { - abortion_cancel_mgr_.DeregisterCallback(abortion_token); - done(s); - }); + state->call->Start( + [this, abortion_token, + done = std::move(recv_buf_callback)](const absl::Status& s) { + abortion_cancel_mgr_.DeregisterCallback(abortion_token); + done(s); + }); } } @@ -258,7 +259,7 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth( // cancelled. wi->GetStatusAsync( opts, req, resp, /*fail_fast*/ true, - [this, opts, req, resp, wi, peer_task, done](Status s) { + [this, opts, req, resp, wi, peer_task, done](absl::Status s) { std::vector cached_attrs; if (s.ok()) { s = dev_resolver_->GetAllDeviceAttributes(peer_task, &cached_attrs); @@ -291,7 +292,7 @@ void CollectiveRemoteAccessDistributed::CheckPeerHealth( }); } -void CollectiveRemoteAccessDistributed::StartAbort(const Status& s) { +void CollectiveRemoteAccessDistributed::StartAbort(const absl::Status& s) { CollectiveRemoteAccessLocal::StartAbort(s); abortion_cancel_mgr_.StartCancel(); } diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.h b/tensorflow/core/distributed_runtime/collective_rma_distributed.h index 31eab9dba2f904..22d4d6f5a119e6 100644 --- a/tensorflow/core/distributed_runtime/collective_rma_distributed.h +++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.h @@ -49,7 +49,7 @@ class CollectiveRemoteAccessDistributed : public CollectiveRemoteAccessLocal { void CheckPeerHealth(const string& peer_task, int64_t timeout_in_ms, const StatusCallback& done) override; - void StartAbort(const Status& s) override; + void StartAbort(const absl::Status& s) override; protected: WorkerCacheInterface* worker_cache_; // Not owned diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc index f5497063e7b679..0338abeda899d3 100644 --- a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc +++ b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc @@ -62,7 +62,7 @@ static std::unique_ptr NewDevice(const string& type, const string& name, public: explicit FakeDevice(const DeviceAttributes& attr, Allocator* allocator) : Device(nullptr, attr), allocator_(allocator) {} - Status Sync() override { return absl::OkStatus(); } + absl::Status Sync() override { return absl::OkStatus(); } Allocator* GetAllocator(AllocatorAttributes) override { return allocator_; } private: @@ -132,9 +132,9 @@ class FakeWorker : public TestWorkerInterface { buf_rendezvous_.ConsumeBuf( request->buf_rendezvous_key(), request->src_device(), request->src_incarnation(), - [this, opts, request, response, done](const Status& status, + [this, opts, request, response, done](const absl::Status& status, BufRendezvous::Hook* h) { - Status s = status; + absl::Status s = status; if (s.ok()) { opts->ClearCancelCallback(); int64_t num_bytes = h->prod_value->TotalBytes(); @@ -196,7 +196,7 @@ class FakeCache : public TestWorkerCache { WorkerInterface* wi = it->second; GetStatusRequest req; GetStatusResponse resp; - Status status = wi->GetStatus(&req, &resp); + absl::Status status = wi->GetStatus(&req, &resp); if (!status.ok()) { done(status); return; @@ -375,14 +375,14 @@ TEST_P(CollRMADistTest, ProdFirstOK) { ResolveDeviceAttributes(); Notification consumer_note; Notification producer_note; - Status consumer_status; - Status producer_status; + absl::Status consumer_status; + absl::Status producer_status; FakeWorker* wi = workers_[1]; const string kBufKey = "fake_buf_key"; wi->buf_rendezvous()->ProvideBuf( kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &expected_value_, AllocatorAttributes(), - [&producer_note, &producer_status](const Status& s) { + [&producer_note, &producer_status](const absl::Status& s) { producer_status.Update(s); producer_note.Notify(); }, @@ -399,7 +399,7 @@ TEST_P(CollRMADistTest, ProdFirstOK) { kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_, device_locality_, 0 /*dev_to_dev_stream_index*/, nullptr /*cancellation_manager*/, - [&consumer_status, &consumer_note](const Status& s) { + [&consumer_status, &consumer_note](const absl::Status& s) { consumer_status = s; consumer_note.Notify(); }); @@ -414,8 +414,8 @@ TEST_P(CollRMADistTest, ConsFirstOK) { ResolveDeviceAttributes(); Notification consumer_note; Notification producer_note; - Status consumer_status; - Status producer_status; + absl::Status consumer_status; + absl::Status producer_status; FakeWorker* wi = workers_[1]; const string kBufKey = "fake_buf_key"; Device* dst_device = nullptr; @@ -430,14 +430,14 @@ TEST_P(CollRMADistTest, ConsFirstOK) { kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_, device_locality_, 0 /*dev_to_dev_stream_index*/, nullptr /*cancellation_manager*/, - [&consumer_status, &consumer_note](const Status& s) { + [&consumer_status, &consumer_note](const absl::Status& s) { consumer_status = s; consumer_note.Notify(); }); wi->buf_rendezvous()->ProvideBuf( kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &expected_value_, AllocatorAttributes(), - [&producer_note, &producer_status](const Status& s) { + [&producer_note, &producer_status](const absl::Status& s) { producer_status.Update(s); producer_note.Notify(); }, @@ -452,7 +452,7 @@ TEST_P(CollRMADistTest, ConsFirstOK) { TEST_P(CollRMADistTest, ConsFirstAbort) { ResolveDeviceAttributes(); Notification consumer_note; - Status consumer_status; + absl::Status consumer_status; const string kBufKey = "fake_buf_key"; Device* dst_device = nullptr; string dev_name = "CPU:0"; @@ -466,7 +466,7 @@ TEST_P(CollRMADistTest, ConsFirstAbort) { kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_, device_locality_, 0 /*dev_to_dev_stream_index*/, nullptr /*cancellation_manager*/, - [&consumer_status, &consumer_note](const Status& s) { + [&consumer_status, &consumer_note](const absl::Status& s) { consumer_status = s; consumer_note.Notify(); }); @@ -479,14 +479,14 @@ TEST_P(CollRMADistTest, ResponseTooLarge) { ResolveDeviceAttributes(); Notification consumer_note; Notification producer_note; - Status consumer_status; - Status producer_status; + absl::Status consumer_status; + absl::Status producer_status; FakeWorker* wi = workers_[1]; const string kBufKey = "fake_buf_key"; wi->buf_rendezvous()->ProvideBuf( kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &large_response_, AllocatorAttributes(), - [&producer_note, &producer_status](const Status& s) { + [&producer_note, &producer_status](const absl::Status& s) { producer_status.Update(s); producer_note.Notify(); }, @@ -503,7 +503,7 @@ TEST_P(CollRMADistTest, ResponseTooLarge) { kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_, device_locality_, 0 /*dev_to_dev_stream_index*/, nullptr /*cancellation_manager*/, - [&consumer_status, &consumer_note](const Status& s) { + [&consumer_status, &consumer_note](const absl::Status& s) { consumer_status = s; consumer_note.Notify(); }); @@ -519,8 +519,8 @@ TEST_P(CollRMADistTest, WorkerRestart) { ResolveDeviceAttributes(); Notification consumer_note; Notification producer_note; - Status consumer_status; - Status producer_status; + absl::Status consumer_status; + absl::Status producer_status; FakeWorker* wi = workers_[1]; const string buf_key = "fake_buf_key"; Device* dst_device = nullptr; @@ -535,14 +535,14 @@ TEST_P(CollRMADistTest, WorkerRestart) { buf_key, dst_device, to_device_ctx, alloc_attr_, &to_tensor_, device_locality_, 0 /*dev_to_dev_stream_index*/, nullptr /*cancellation_manager*/, - [&consumer_status, &consumer_note](const Status& s) { + [&consumer_status, &consumer_note](const absl::Status& s) { consumer_status = s; consumer_note.Notify(); }); wi->buf_rendezvous()->ProvideBuf( buf_key, nullptr /*device*/, nullptr /*dev_ctx*/, &expected_value_, AllocatorAttributes(), - [&producer_note, &producer_status](const Status& s) { + [&producer_note, &producer_status](const absl::Status& s) { producer_status.Update(s); producer_note.Notify(); }, @@ -563,7 +563,7 @@ TEST_P(CollRMADistTest, WorkerRestart) { buf_key, dst_device, to_device_ctx, alloc_attr_, &to_tensor_, device_locality_, 0 /*dev_to_dev_stream_index*/, nullptr /*cancellation_manager*/, - [&consumer_status, &post_restart_note](const Status& s) { + [&consumer_status, &post_restart_note](const absl::Status& s) { consumer_status = s; post_restart_note.Notify(); }); @@ -573,11 +573,11 @@ TEST_P(CollRMADistTest, WorkerRestart) { TEST_P(CollRMADistTest, CheckHealthOKWithCachedAttr) { ResolveDeviceAttributes(); - Status check_health_status; + absl::Status check_health_status; Notification check_health_done; rma_->CheckPeerHealth( "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0, - [&check_health_status, &check_health_done](const Status s) { + [&check_health_status, &check_health_done](const absl::Status s) { check_health_status = s; check_health_done.Notify(); }); @@ -586,11 +586,11 @@ TEST_P(CollRMADistTest, CheckHealthOKWithCachedAttr) { } TEST_P(CollRMADistTest, CheckHealthOKWithoutCachedAttr) { - Status check_health_status; + absl::Status check_health_status; Notification check_health_done; rma_->CheckPeerHealth( "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0, - [&check_health_status, &check_health_done](const Status s) { + [&check_health_status, &check_health_done](const absl::Status s) { check_health_status = s; check_health_done.Notify(); }); @@ -602,11 +602,11 @@ TEST_P(CollRMADistTest, CheckHealthRestarted) { ResolveDeviceAttributes(); RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1); - Status check_health_status; + absl::Status check_health_status; Notification check_health_done; rma_->CheckPeerHealth( "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0, - [&check_health_status, &check_health_done](const Status s) { + [&check_health_status, &check_health_done](const absl::Status s) { check_health_status = s; check_health_done.Notify(); }); @@ -619,11 +619,11 @@ TEST_P(CollRMADistTest, CheckHealthFailedPeer) { RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1, /*is_failed*/ true); - Status check_health_status; + absl::Status check_health_status; Notification check_health_done; rma_->CheckPeerHealth( "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0, - [&check_health_status, &check_health_done](const Status s) { + [&check_health_status, &check_health_done](const absl::Status s) { check_health_status = s; check_health_done.Notify(); }); @@ -634,11 +634,11 @@ TEST_P(CollRMADistTest, CheckHealthFailedPeer) { TEST_P(CollRMADistTest, CheckHealthRestartedWithDifferentDevices) { ResolveDeviceAttributes(); RestartWorker("/job:worker/replica:0/task:1", "GPU", /*num_devices*/ 1); - Status check_health_status; + absl::Status check_health_status; Notification check_health_done; rma_->CheckPeerHealth( "/job:worker/replica:0/task:1", /*timeout_in_ms=*/0, - [&check_health_status, &check_health_done](const Status s) { + [&check_health_status, &check_health_done](const absl::Status s) { check_health_status = s; check_health_done.Notify(); }); diff --git a/tensorflow/core/distributed_runtime/coordination/BUILD b/tensorflow/core/distributed_runtime/coordination/BUILD index a964cb8d1a0f43..6c0bb0e705517c 100644 --- a/tensorflow/core/distributed_runtime/coordination/BUILD +++ b/tensorflow/core/distributed_runtime/coordination/BUILD @@ -60,8 +60,8 @@ cc_library( "@com_google_absl//absl/time", "@local_tsl//tsl/profiler/lib:traceme", "@local_tsl//tsl/profiler/lib:traceme_encode", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service_agent", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc", ], ) @@ -81,11 +81,11 @@ tf_cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/protobuf:coordination_config_proto_cc", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", "@local_xla//xla/tsl/distributed_runtime:call_options", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_client", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service_agent", + "@local_xla//xla/tsl/protobuf:coordination_config_proto_cc", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc", ], ) diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy.cc b/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy.cc index d24ef5f03a2896..6196e9b12355b4 100644 --- a/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy.cc +++ b/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy.cc @@ -27,15 +27,15 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/status.h" #include "tsl/profiler/lib/traceme.h" #include "tsl/profiler/lib/traceme_encode.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tensorflow { -std::pair BarrierProxy::Wait() { +std::pair BarrierProxy::Wait() { mutex_lock l(mu_); if (status_set_) { return std::make_pair( @@ -97,10 +97,10 @@ size_t BarrierProxyManager::size() const { return barriers_.size(); } -Status BarrierProxyManager::Wait(tsl::CoordinationServiceAgent* agent, - const std::vector& tasks, - int num_local_threads, absl::string_view key, - absl::Duration timeout) { +absl::Status BarrierProxyManager::Wait( + tsl::CoordinationServiceAgent* agent, + const std::vector& tasks, int num_local_threads, + absl::string_view key, absl::Duration timeout) { // Only one device, no need to wait. if (tasks.size() == 1 && num_local_threads <= 1) return absl::OkStatus(); diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy.h b/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy.h index 5d9aeeec3debc4..3e0243ab2245bd 100644 --- a/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy.h +++ b/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy.h @@ -25,11 +25,11 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/thread_annotations.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tensorflow { @@ -72,7 +72,7 @@ class BarrierProxy { // Waits at the barrier. The first return value is the status when exiting the // barrier and the second returns `true` for precisely one caller, which may // then destroy the barrier. - std::pair Wait(); + std::pair Wait(); private: const std::string key_; @@ -85,7 +85,7 @@ class BarrierProxy { const int num_local_threads_; int num_entered_ TF_GUARDED_BY(mu_) = 0; int num_to_exit_ TF_GUARDED_BY(mu_) = 0; - Status status_ TF_GUARDED_BY(mu_); + absl::Status status_ TF_GUARDED_BY(mu_); bool status_set_ TF_GUARDED_BY(mu_) = false; }; @@ -108,9 +108,10 @@ class BarrierProxyManager { // `num_local_threads` specifies the number of threads in this task to // participate. If no tasks are specified, the barrier will block for all the // connected tasks. - Status Wait(tsl::CoordinationServiceAgent* agent, - const std::vector& tasks, int num_local_threads, - absl::string_view key, absl::Duration timeout); + absl::Status Wait(tsl::CoordinationServiceAgent* agent, + const std::vector& tasks, + int num_local_threads, absl::string_view key, + absl::Duration timeout); // The number of active BarrierProxies. size_t size() const; diff --git a/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc b/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc index b74535ccb44d93..81948b936c5bd2 100644 --- a/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc +++ b/tensorflow/core/distributed_runtime/coordination/coordination_service_barrier_proxy_test.cc @@ -32,13 +32,13 @@ limitations under the License. #include "xla/tsl/distributed_runtime/call_options.h" #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/threadpool.h" -#include "tsl/protobuf/coordination_config.pb.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tensorflow { namespace { @@ -51,20 +51,21 @@ using tsl::CoordinationServiceAgent; class MockCoordinationServiceAgent : public CoordinationServiceAgent { public: - MOCK_METHOD(Status, WaitAtBarrier, + MOCK_METHOD(absl::Status, WaitAtBarrier, (std::string_view barrier_id, absl::Duration timeout, const std::vector& tasks), (override)); - MOCK_METHOD(Status, CancelBarrier, (std::string_view barrier_id), (override)); + MOCK_METHOD(absl::Status, CancelBarrier, (std::string_view barrier_id), + (override)); // All the following member functions are not needed for testing. - MOCK_METHOD(Status, Initialize, + MOCK_METHOD(absl::Status, Initialize, (Env * env, std::string_view job_name, int task_id, const CoordinationServiceConfig& configs, std::unique_ptr leader_client, StatusCallback error_fn), (override)); - MOCK_METHOD(Status, Initialize, + MOCK_METHOD(absl::Status, Initialize, (Env * env, const CoordinatedTask& task, const CoordinationServiceConfig& configs, std::unique_ptr leader_client, @@ -73,17 +74,18 @@ class MockCoordinationServiceAgent : public CoordinationServiceAgent { MOCK_METHOD(bool, IsInitialized, (), (override)); MOCK_METHOD(bool, IsConnected, (), (override)); MOCK_METHOD(bool, IsError, (), (override)); - MOCK_METHOD(Status, Connect, (), (override)); - MOCK_METHOD(Status, WaitForAllTasks, (const DeviceInfo& local_devices), + MOCK_METHOD(absl::Status, Connect, (), (override)); + MOCK_METHOD(absl::Status, WaitForAllTasks, (const DeviceInfo& local_devices), (override)); MOCK_METHOD(const DeviceInfo&, GetClusterDeviceInfo, (), (override)); MOCK_METHOD(absl::StatusOr, GetOwnTask, (), (override)); MOCK_METHOD(absl::StatusOr>, GetTaskState, (const std::vector& task), (override)); - MOCK_METHOD(Status, ReportError, (const Status& error), (override)); - MOCK_METHOD(Status, Shutdown, (), (override)); - MOCK_METHOD(Status, Reset, (), (override)); + MOCK_METHOD(absl::Status, ReportError, (const absl::Status& error), + (override)); + MOCK_METHOD(absl::Status, Shutdown, (), (override)); + MOCK_METHOD(absl::Status, Reset, (), (override)); MOCK_METHOD(absl::StatusOr, GetKeyValue, (std::string_view key), (override)); MOCK_METHOD(absl::StatusOr, GetKeyValue, @@ -97,19 +99,19 @@ class MockCoordinationServiceAgent : public CoordinationServiceAgent { MOCK_METHOD(void, GetKeyValueDirAsync, (std::string_view key, StatusOrValueDirCallback done), (override)); - MOCK_METHOD(Status, InsertKeyValue, + MOCK_METHOD(absl::Status, InsertKeyValue, (std::string_view key, std::string_view value), (override)); - MOCK_METHOD(Status, InsertKeyValue, + MOCK_METHOD(absl::Status, InsertKeyValue, (std::string_view key, std::string_view value, bool allow_overwrite), (override)); - MOCK_METHOD(Status, DeleteKeyValue, (std::string_view key), (override)); - MOCK_METHOD(Status, UpdateKeyValue, + MOCK_METHOD(absl::Status, DeleteKeyValue, (std::string_view key), (override)); + MOCK_METHOD(absl::Status, UpdateKeyValue, (std::string_view key, std::string_view value), (override)); - MOCK_METHOD(Status, StartWatchKey, + MOCK_METHOD(absl::Status, StartWatchKey, (std::string_view key, ChangedKeyValuesCallback on_change), (override)); - MOCK_METHOD(Status, StopWatchKey, (std::string_view key), (override)); + MOCK_METHOD(absl::Status, StopWatchKey, (std::string_view key), (override)); MOCK_METHOD(void, WaitAtBarrierAsync, (std::string_view barrier_id, absl::Duration timeout, const std::vector& tasks, StatusCallback done), @@ -117,8 +119,8 @@ class MockCoordinationServiceAgent : public CoordinationServiceAgent { MOCK_METHOD(void, CancelBarrierAsync, (std::string_view barrier_id, StatusCallback done), (override)); MOCK_METHOD(absl::StatusOr, GetEnv, (), (override)); - MOCK_METHOD(void, SetError, (const Status& error), (override)); - MOCK_METHOD(Status, ActivateWatch, + MOCK_METHOD(void, SetError, (const absl::Status& error), (override)); + MOCK_METHOD(absl::Status, ActivateWatch, (std::string_view key, (const std::map&)), (override)); @@ -130,8 +132,8 @@ const int kThreadPoolSize = 32; void TestBarrierProxyWait( int num_tasks, int num_threads_planned, int num_threads_entered, - int expected_ok_count, std::optional agent_wait_status, - std::optional expected_same_exit_status_for_all_threads) { + int expected_ok_count, std::optional agent_wait_status, + std::optional expected_same_exit_status_for_all_threads) { auto agent = std::make_unique(); const std::vector tasks(num_tasks); BarrierProxy barrier(agent.get(), tasks, num_threads_planned, kTestKey, @@ -217,7 +219,7 @@ TEST(BarrierProxyTest, ExtraThreadsEnteringTheBarrierGetErrors) { void TestBarrierProxyManagerWaitSingleKey( int num_threads_planned, int num_threads_entered, - std::optional agent_wait_status, int expected_ok_count) { + std::optional agent_wait_status, int expected_ok_count) { auto agent = std::make_unique(); const std::vector tasks; BarrierProxyManager mgr; diff --git a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc index 9349ae54a15052..200eaaa4c67e9a 100644 --- a/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc +++ b/tensorflow/core/distributed_runtime/eager/cluster_function_library_runtime.cc @@ -51,7 +51,7 @@ void EagerClusterFunctionLibraryRuntime::Instantiate( FunctionLibraryRuntime::DoneCallback done) { auto target = options.target; auto released_op = std::make_unique(ctx_); - Status s = + absl::Status s = released_op->Reset(function_name.c_str(), target.c_str(), true, nullptr); if (!s.ok()) { done(s); @@ -100,7 +100,7 @@ void EagerClusterFunctionLibraryRuntime::Instantiate( /*call_opts=*/nullptr, request.get(), response.get(), [this, request, response, handle, released_op = released_op.release(), target, ret_indices, eager_client = eager_client.get(), - done](const Status& s) { + done](const absl::Status& s) { { mutex_lock l(mu_); *handle = function_data_.size(); @@ -121,8 +121,8 @@ void EagerClusterFunctionLibraryRuntime::Run( } std::vector* function_rets = new std::vector; Run(opts, handle, function_args, function_rets, - [rets, function_rets, done = std::move(done)](const Status& s) { - Status status = s; + [rets, function_rets, done = std::move(done)](const absl::Status& s) { + absl::Status status = s; if (status.ok()) { for (const auto& t : *function_rets) { if (t.index() == 0) { @@ -223,7 +223,7 @@ void EagerClusterFunctionLibraryRuntime::Run( eager_client->RunComponentFunctionAsync( call_opts.get(), request.get(), response.get(), [request, response, rets, call_opts, cm, token, - done = std::move(done)](const Status& s) { + done = std::move(done)](const absl::Status& s) { if (cm != nullptr) { cm->TryDeregisterCallback(token); } @@ -280,7 +280,7 @@ void EagerClusterFunctionLibraryRuntime::CleanUp( // enqueue done callback of Run(). So we don't use StreamingEnqueueAsync here. eager_client->EnqueueAsync( /*call_opts=*/nullptr, request.get(), response.get(), - [request, response, done](const Status& status) { done(status); }); + [request, response, done](const absl::Status& status) { done(status); }); } DistributedFunctionLibraryRuntime* CreateClusterFLR( diff --git a/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h b/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h index ca5eaa2526f6cb..a9b9ead8cf1233 100644 --- a/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h +++ b/tensorflow/core/distributed_runtime/eager/destroy_tensor_handle_node.h @@ -50,7 +50,7 @@ class DestroyTensorHandleNode : public tensorflow::AsyncEagerNode { // safe to ignore a failing destroy tensor handle request. eager_client_->EnqueueAsync( /*call_opts=*/nullptr, request_.get(), response, - [response, ready, done](const tensorflow::Status& s) { + [response, ready, done](const absl::Status& s) { // Omit the warning if: // 1. The remote tensor isn't ready. // 2. Lost connection to remote worker. In this case client will @@ -66,7 +66,7 @@ class DestroyTensorHandleNode : public tensorflow::AsyncEagerNode { }); } - void Abort(Status status) override {} + void Abort(absl::Status status) override {} // Remote node deletions are best effort bool Fatal() const override { return false; } diff --git a/tensorflow/core/distributed_runtime/eager/eager_client.h b/tensorflow/core/distributed_runtime/eager/eager_client.h index a2265d51a1f5d3..6fc956014ab666 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_client.h +++ b/tensorflow/core/distributed_runtime/eager/eager_client.h @@ -92,8 +92,8 @@ class EagerClientCache { // increment the refcount of the client. The reference ownership is // transferred to the caller, and the unref should automatically happen when // destructing the RefCountPtr object from the caller's side. - virtual Status GetClient(const string& target, - core::RefCountPtr* client) = 0; + virtual absl::Status GetClient(const string& target, + core::RefCountPtr* client) = 0; }; } // namespace eager diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc index d4a5effeb78a21..2a69d3cba16fe9 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/c/eager/immediate_execution_distributed_manager.h" #include "xla/tsl/distributed_runtime/preemption/preemption_notifier.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/context_distributed_manager.h" @@ -53,15 +54,13 @@ limitations under the License. #include "tensorflow/core/platform/stringprintf.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/lib/traceme.h" -#include "tsl/protobuf/coordination_config.pb.h" namespace tensorflow { namespace eager { namespace { -Status GetNumRetvals(FunctionLibraryDefinition* func_lib_def, - const string& op_name, - const google::protobuf::Map& attrs, - int* num_retvals) { +absl::Status GetNumRetvals( + FunctionLibraryDefinition* func_lib_def, const string& op_name, + const google::protobuf::Map& attrs, int* num_retvals) { const tensorflow::OpRegistrationData* op_reg_data = nullptr; auto status = tensorflow::OpRegistry::Global()->LookUp(op_name, &op_reg_data); if (absl::IsNotFound(status)) { @@ -96,11 +95,11 @@ Status GetNumRetvals(FunctionLibraryDefinition* func_lib_def, return absl::OkStatus(); } -Status GetEagerOperationAndNumRetvals(const Operation& operation, - EagerContext* eager_context, - EagerExecutor* eager_executor, - EagerOperation* eager_op, - int* num_retvals) { +absl::Status GetEagerOperationAndNumRetvals(const Operation& operation, + EagerContext* eager_context, + EagerExecutor* eager_executor, + EagerOperation* eager_op, + int* num_retvals) { const char* name = operation.name().c_str(); // Shorthand std::optional remote_func_params = std::nullopt; @@ -164,14 +163,14 @@ Status GetEagerOperationAndNumRetvals(const Operation& operation, num_retvals); } -Status TensorHandleProto(TensorHandle* handle, TensorProto* proto) { +absl::Status TensorHandleProto(TensorHandle* handle, TensorProto* proto) { const tensorflow::Tensor* t = nullptr; TF_RETURN_IF_ERROR(handle->Tensor(&t)); t->AsProtoTensorContent(proto); return absl::OkStatus(); } -Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) { +absl::Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) { const tensorflow::Tensor* t = nullptr; // TODO(nareshmodi): This call makes async calls sync calls. Fix this. @@ -188,7 +187,7 @@ Status TensorHandleShape(TensorHandle* handle, TensorShapeProto* proto) { return absl::OkStatus(); } -Status AddOpRetvalsToResponse( +absl::Status AddOpRetvalsToResponse( EagerContext* eager_context, int op_id, int num_retvals, const std::vector& output_nums, TensorHandle** retvals, std::function add_tensor_proto_fn, @@ -225,12 +224,12 @@ Status AddOpRetvalsToResponse( return sg.as_summary_status(); } -Status ResetAgentAndConnectToCoordinationService( +absl::Status ResetAgentAndConnectToCoordinationService( tsl::CoordinationServiceAgent* coord_agent) { // The error state should already be consumed when a new context is // created. It should be fine to reset the agent. if (coord_agent->IsError()) { - const Status s = coord_agent->Reset(); + const absl::Status s = coord_agent->Reset(); if (!s.ok()) { LOG(ERROR) << "Coordination Service agent reset failed " << s; return s; @@ -240,7 +239,7 @@ Status ResetAgentAndConnectToCoordinationService( // cannot be propagated. As a result, Coordination Service agent can still // have the status of being connected. We should not let it connect again. if (!coord_agent->IsConnected()) { - const Status s = coord_agent->Connect(); + const absl::Status s = coord_agent->Connect(); if (!s.ok()) { LOG(ERROR) << "Coordination Service agent connect failed " << s; return s; @@ -251,8 +250,8 @@ Status ResetAgentAndConnectToCoordinationService( } // namespace -Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, - CreateContextResponse* response) { +absl::Status EagerServiceImpl::CreateContext( + const CreateContextRequest* request, CreateContextResponse* response) { bool update_collective_executor_mgr = false; { mutex_lock l(contexts_mu_); @@ -388,7 +387,7 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, auto remote_mgr = std::make_unique(/*is_master=*/false, ctx); - Status s = ctx->InitializeRemoteWorker( + absl::Status s = ctx->InitializeRemoteWorker( std::move(remote_eager_workers), worker_session->remote_device_mgr(), remote_workers, request->context_id(), request->context_view_id(), std::move(rendezvous_creator), cluster_flr, std::move(remote_mgr), @@ -420,7 +419,7 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, [coord_agent](absl::StatusOr time_or_status) { if (time_or_status.ok()) { const auto coord_task = coord_agent->GetOwnTask().value(); - Status s = coord_agent->InsertKeyValue( + absl::Status s = coord_agent->InsertKeyValue( "TF_DEFAULT_PREEMPTION_NOTICE_KEY", absl::StrCat("/job:", coord_task.job_name(), "/task:", coord_task.task_id())); @@ -458,8 +457,8 @@ Status EagerServiceImpl::CreateContext(const CreateContextRequest* request, return absl::OkStatus(); } -Status EagerServiceImpl::UpdateContext(const UpdateContextRequest* request, - UpdateContextResponse* response) { +absl::Status EagerServiceImpl::UpdateContext( + const UpdateContextRequest* request, UpdateContextResponse* response) { // make sure env_ , env_->rendezvous_mgr available if (env_ == nullptr || env_->rendezvous_mgr == nullptr) { return tensorflow::errors::Internal( @@ -521,8 +520,8 @@ Status EagerServiceImpl::UpdateContext(const UpdateContextRequest* request, &remote_eager_workers)); ctx->ClearCachesAndThreadExecutors(); - Status s = ctx->UpdateRemoteWorker(std::move(remote_eager_workers), - remote_workers, request->context_id()); + absl::Status s = ctx->UpdateRemoteWorker( + std::move(remote_eager_workers), remote_workers, request->context_id()); if (!s.ok()) { VLOG(1) << "EagerContext::UpdateRemoteWorker failed with " << s.ToString(); return s; @@ -549,7 +548,7 @@ Status EagerServiceImpl::UpdateContext(const UpdateContextRequest* request, return absl::OkStatus(); } -Status EagerServiceImpl::CreateMasterContext( +absl::Status EagerServiceImpl::CreateMasterContext( const tensorflow::uint64 context_id, EagerContext* context) { { mutex_lock l(contexts_mu_); @@ -571,7 +570,7 @@ void EagerServiceImpl::RunComponentFunction( CallOptions* call_opts, const RunComponentFunctionRequest* request, RunComponentFunctionResponse* response, StatusCallback done) { ServerContext* context = nullptr; - Status s = GetServerContext(request->context_id(), &context); + absl::Status s = GetServerContext(request->context_id(), &context); if (!s.ok()) { done(s); return; @@ -633,9 +632,9 @@ void EagerServiceImpl::RunComponentFunction( op, retvals->data(), num_retvals, [op, op_id = operation.id(), num_retvals, retvals, output_nums, cm, call_opts, response, eager_context, context, - done = std::move(done)](const Status& status) { + done = std::move(done)](const absl::Status& status) { call_opts->ClearCancelCallback(); - auto wrapped_done = [&](const Status& status) { + auto wrapped_done = [&](const absl::Status& status) { context->Unref(); done(status); delete op; @@ -655,11 +654,11 @@ void EagerServiceImpl::RunComponentFunction( }); } -Status EagerServiceImpl::ExecuteOp(CallOptions* call_opts, - const Operation& operation, - EagerContext* eager_context, - EagerExecutor* eager_executor, - QueueResponse* queue_response) { +absl::Status EagerServiceImpl::ExecuteOp(CallOptions* call_opts, + const Operation& operation, + EagerContext* eager_context, + EagerExecutor* eager_executor, + QueueResponse* queue_response) { tensorflow::EagerOperation op(eager_context); int num_retvals = 0; TF_RETURN_IF_ERROR(GetEagerOperationAndNumRetvals( @@ -694,9 +693,10 @@ Status EagerServiceImpl::ExecuteOp(CallOptions* call_opts, std::move(add_device_fn)); } -Status EagerServiceImpl::Enqueue(CallOptions* call_opts, - const EnqueueRequest* request, - EnqueueResponse* response, uint64 stream_id) { +absl::Status EagerServiceImpl::Enqueue(CallOptions* call_opts, + const EnqueueRequest* request, + EnqueueResponse* response, + uint64 stream_id) { tsl::profiler::TraceMe activity( [&] { return absl::StrCat( @@ -712,7 +712,7 @@ Status EagerServiceImpl::Enqueue(CallOptions* call_opts, ? context->Context()->Executor() : context->Context()->RemoteMgr()->GetOrCreateExecutorForStream( stream_id); - Status s; + absl::Status s; for (const auto& item : request->queue()) { auto* queue_response = response->add_queue_response(); if (item.has_operation()) { @@ -750,8 +750,8 @@ Status EagerServiceImpl::Enqueue(CallOptions* call_opts, return absl::OkStatus(); } -Status EagerServiceImpl::WaitQueueDone(const WaitQueueDoneRequest* request, - WaitQueueDoneResponse* response) { +absl::Status EagerServiceImpl::WaitQueueDone( + const WaitQueueDoneRequest* request, WaitQueueDoneResponse* response) { ServerContext* context = nullptr; TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context)); core::ScopedUnref context_unref(context); @@ -764,8 +764,8 @@ Status EagerServiceImpl::WaitQueueDone(const WaitQueueDoneRequest* request, return context->Context()->Executor().WaitForAllPendingNodes(); } -Status EagerServiceImpl::KeepAlive(const KeepAliveRequest* request, - KeepAliveResponse* response) { +absl::Status EagerServiceImpl::KeepAlive(const KeepAliveRequest* request, + KeepAliveResponse* response) { ServerContext* context = nullptr; TF_RETURN_IF_ERROR(GetServerContext(request->context_id(), &context)); core::ScopedUnref context_unref(context); @@ -775,8 +775,8 @@ Status EagerServiceImpl::KeepAlive(const KeepAliveRequest* request, return absl::OkStatus(); } -Status EagerServiceImpl::CloseContext(const CloseContextRequest* request, - CloseContextResponse* response) { +absl::Status EagerServiceImpl::CloseContext(const CloseContextRequest* request, + CloseContextResponse* response) { ServerContext* context = nullptr; if (!GetServerContext(request->context_id(), &context).ok()) { // Swallow the error here. @@ -804,7 +804,7 @@ Status EagerServiceImpl::CloseContext(const CloseContextRequest* request, return absl::OkStatus(); } -Status EagerServiceImpl::RegisterFunction( +absl::Status EagerServiceImpl::RegisterFunction( const RegisterFunctionOp& register_function, EagerContext* eager_context) { // If the function is a component of a multi-device function, we only need to // register it locally. @@ -818,20 +818,20 @@ Status EagerServiceImpl::RegisterFunction( } } -Status EagerServiceImpl::RemoveFunction(const RemoveFunctionOp& remove_function, - EagerContext* eager_context) { +absl::Status EagerServiceImpl::RemoveFunction( + const RemoveFunctionOp& remove_function, EagerContext* eager_context) { return eager_context->RemoveFunction(remove_function.function_name()); } -Status EagerServiceImpl::CleanupFunction( +absl::Status EagerServiceImpl::CleanupFunction( const CleanupFunctionOp& cleanup_function) { env_->rendezvous_mgr->Cleanup(cleanup_function.step_id()); return absl::OkStatus(); } -Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor, - EagerContext* eager_context) { - tensorflow::gtl::InlinedVector tensors; +absl::Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor, + EagerContext* eager_context) { + absl::InlinedVector tensors; for (const auto& tensor_proto : send_tensor.tensors()) { Tensor tensor; if (!tensor.FromProto(tensor_proto)) { @@ -856,7 +856,7 @@ Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor, return absl::OkStatus(); } -Status EagerServiceImpl::SendPackedHandle( +absl::Status EagerServiceImpl::SendPackedHandle( const SendPackedHandleOp& send_packed_handle, EagerContext* eager_context) { if (send_packed_handle.handles().empty()) { return errors::InvalidArgument("Handles should not be empty."); @@ -902,7 +902,7 @@ Status EagerServiceImpl::SendPackedHandle( return absl::OkStatus(); } -tensorflow::Status EagerServiceImpl::GetServerContext( +absl::Status EagerServiceImpl::GetServerContext( uint64 context_id, ServerContext** server_context) { tf_shared_lock l(contexts_mu_); auto iter = contexts_.find(context_id); diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h index 8239b361b48ef5..924a99dd81c1ab 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl.h +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl.h @@ -76,36 +76,36 @@ class EagerServiceImpl { } } - Status CreateContext(const CreateContextRequest* request, - CreateContextResponse* response); + absl::Status CreateContext(const CreateContextRequest* request, + CreateContextResponse* response); - Status UpdateContext(const UpdateContextRequest* request, - UpdateContextResponse* response); + absl::Status UpdateContext(const UpdateContextRequest* request, + UpdateContextResponse* response); // Create a ServerContext for master eager context. - Status CreateMasterContext(const tensorflow::uint64 context_id, - EagerContext* context); + absl::Status CreateMasterContext(const tensorflow::uint64 context_id, + EagerContext* context); static constexpr uint64 kInvalidStreamId = 0; // Used by both Enqueue and StreamingEnqueue RPCs. - Status Enqueue(CallOptions* call_opts, const EnqueueRequest* request, - EnqueueResponse* response, - uint64 stream_id = kInvalidStreamId); + absl::Status Enqueue(CallOptions* call_opts, const EnqueueRequest* request, + EnqueueResponse* response, + uint64 stream_id = kInvalidStreamId); - Status WaitQueueDone(const WaitQueueDoneRequest* request, - WaitQueueDoneResponse* response); + absl::Status WaitQueueDone(const WaitQueueDoneRequest* request, + WaitQueueDoneResponse* response); void RunComponentFunction(CallOptions* call_opts, const RunComponentFunctionRequest* request, RunComponentFunctionResponse* response, StatusCallback done); - Status KeepAlive(const KeepAliveRequest* request, - KeepAliveResponse* response); + absl::Status KeepAlive(const KeepAliveRequest* request, + KeepAliveResponse* response); - Status CloseContext(const CloseContextRequest* request, - CloseContextResponse* response); + absl::Status CloseContext(const CloseContextRequest* request, + CloseContextResponse* response); protected: // This is the server-side execution context. All state regarding execution of @@ -166,7 +166,7 @@ class EagerServiceImpl { const bool is_master_; }; // The returned ServerContext will need to be Unrefed. - tensorflow::Status GetServerContext(uint64, ServerContext**); + absl::Status GetServerContext(uint64, ServerContext**); class ClientTensorHandleDeleteNode : public EagerNode { public: @@ -181,7 +181,7 @@ class EagerServiceImpl { ~ClientTensorHandleDeleteNode() override { context_->Unref(); } - Status Run() override { + absl::Status Run() override { VLOG(3) << "ServerContext: Deleting tensor handle " << handle_to_delete_->op_id << ":" << handle_to_delete_->output_num; @@ -189,7 +189,7 @@ class EagerServiceImpl { *handle_to_delete_); } - void Abort(Status status) override {} + void Abort(absl::Status status) override {} // Remote node deletions are best effort bool Fatal() const override { return false; } @@ -208,18 +208,19 @@ class EagerServiceImpl { }; private: - Status ExecuteOp(CallOptions* call_opts, const Operation& operation, - EagerContext* eager_context, EagerExecutor* eager_executor, - QueueResponse* queue_response); - Status SendTensor(const SendTensorOp& send_tensor, - EagerContext* eager_context); - Status SendPackedHandle(const SendPackedHandleOp& send_packed_handle, + absl::Status ExecuteOp(CallOptions* call_opts, const Operation& operation, + EagerContext* eager_context, + EagerExecutor* eager_executor, + QueueResponse* queue_response); + absl::Status SendTensor(const SendTensorOp& send_tensor, EagerContext* eager_context); - Status RegisterFunction(const RegisterFunctionOp& register_function, - EagerContext* eager_context); - Status RemoveFunction(const RemoveFunctionOp& remove_function, - EagerContext* eager_context); - Status CleanupFunction(const CleanupFunctionOp& cleanup_function); + absl::Status SendPackedHandle(const SendPackedHandleOp& send_packed_handle, + EagerContext* eager_context); + absl::Status RegisterFunction(const RegisterFunctionOp& register_function, + EagerContext* eager_context); + absl::Status RemoveFunction(const RemoveFunctionOp& remove_function, + EagerContext* eager_context); + absl::Status CleanupFunction(const CleanupFunctionOp& cleanup_function); WorkerEnv* const env_; // Not owned. diff --git a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc index 84cfe637697c70..b7d513bb74c013 100644 --- a/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc +++ b/tensorflow/core/distributed_runtime/eager/eager_service_impl_test.cc @@ -55,16 +55,16 @@ namespace { class TestEagerServiceImpl : public EagerServiceImpl { public: explicit TestEagerServiceImpl(WorkerEnv* env) : EagerServiceImpl(env) {} - Status GetEagerContext(const uint64 context_id, EagerContext** ctx) { + absl::Status GetEagerContext(const uint64 context_id, EagerContext** ctx) { ServerContext* context = nullptr; TF_RETURN_IF_ERROR(GetServerContext(context_id, &context)); core::ScopedUnref context_unref(context); *ctx = context->Context(); return absl::OkStatus(); } - Status GetTensorHandle(const uint64 context_id, - const RemoteTensorHandleInternal& remote_handle, - tensorflow::TensorHandle** handle) { + absl::Status GetTensorHandle(const uint64 context_id, + const RemoteTensorHandleInternal& remote_handle, + tensorflow::TensorHandle** handle) { ServerContext* context = nullptr; TF_RETURN_IF_ERROR(GetServerContext(context_id, &context)); core::ScopedUnref context_unref(context); @@ -135,8 +135,8 @@ class FakeEagerClient : public EagerClient { class DummyEagerClientCache : public EagerClientCache { public: DummyEagerClientCache() : client_(new FakeEagerClient) {} - Status GetClient(const string& target, - core::RefCountPtr* client) override { + absl::Status GetClient(const string& target, + core::RefCountPtr* client) override { client->reset(client_.get()); client_->Ref(); return absl::OkStatus(); @@ -147,7 +147,7 @@ class DummyEagerClientCache : public EagerClientCache { }; class FakeCache : public TestWorkerCache { - Status GetEagerClientCache( + absl::Status GetEagerClientCache( std::unique_ptr* eager_client_cache) override { *eager_client_cache = std::make_unique(); return absl::OkStatus(); @@ -586,7 +586,7 @@ class EagerServiceImplFunctionTest : public EagerServiceImplTest { } CallOptions call_opts; - Status status; + absl::Status status; Notification n; Env::Default()->SchedClosure([&] { status = eager_service_impl.Enqueue(&call_opts, &remote_enqueue_request, @@ -680,13 +680,13 @@ class EagerServiceImplFunctionTest : public EagerServiceImplTest { CallOptions call_opts; Notification n; - Status status; - eager_service_impl.RunComponentFunction(&call_opts, &run_comp_func_request, - &run_comp_func_response, - [&status, &n](const Status& s) { - status.Update(s); - n.Notify(); - }); + absl::Status status; + eager_service_impl.RunComponentFunction( + &call_opts, &run_comp_func_request, &run_comp_func_response, + [&status, &n](const absl::Status& s) { + status.Update(s); + n.Notify(); + }); if (test_cancel) { call_opts.StartCancel(); } @@ -846,13 +846,13 @@ TEST_F(EagerServiceImplFunctionTest, ComponentNestedFunctionWithNameClashTest) { CallOptions call_opts; Notification n; - Status status; - eager_service_impl.RunComponentFunction(&call_opts, &run_comp_func_request, - &run_comp_func_response, - [&status, &n](const Status& s) { - status.Update(s); - n.Notify(); - }); + absl::Status status; + eager_service_impl.RunComponentFunction( + &call_opts, &run_comp_func_request, &run_comp_func_response, + [&status, &n](const absl::Status& s) { + status.Update(s); + n.Notify(); + }); n.WaitForNotification(); TF_ASSERT_OK(status); @@ -888,13 +888,13 @@ TEST_F(EagerServiceImplFunctionTest, ComponentNestedFunctionWithNameClashTest) { CallOptions call_opts; Notification n; - Status status; - eager_service_impl.RunComponentFunction(&call_opts, &run_comp_func_request, - &run_comp_func_response, - [&status, &n](const Status& s) { - status.Update(s); - n.Notify(); - }); + absl::Status status; + eager_service_impl.RunComponentFunction( + &call_opts, &run_comp_func_request, &run_comp_func_response, + [&status, &n](const absl::Status& s) { + status.Update(s); + n.Notify(); + }); n.WaitForNotification(); TF_ASSERT_OK(status); @@ -934,21 +934,21 @@ class FunctionWithRemoteInputsTest : public EagerServiceImplTest { class TestExecuteNodeArgs : public EagerKernelArgs { public: TestExecuteNodeArgs( - gtl::InlinedVector&& tensor_args, - std::function + absl::InlinedVector&& tensor_args, + std::function serialize_remote_handle) : EagerKernelArgs(std::move(tensor_args)), serialize_remote_handle_(std::move(serialize_remote_handle)) {} bool HasRemoteOrPackedInputs() const override { return true; } - Status GetRemoteArg(const FunctionArgIndex& index, - eager::RemoteTensorHandle* val) const override { + absl::Status GetRemoteArg(const FunctionArgIndex& index, + eager::RemoteTensorHandle* val) const override { return serialize_remote_handle_(index.index, val); } private: - std::function + std::function serialize_remote_handle_; }; @@ -1087,7 +1087,7 @@ TEST_F(FunctionWithRemoteInputsTest, EagerPFLRTest) { const uint64 op_id = 2; opts.op_id = op_id; Notification done; - Status status; + absl::Status status; RemoteTensorHandle input; input.set_op_id(1); input.set_output_num(0); @@ -1095,15 +1095,15 @@ TEST_F(FunctionWithRemoteInputsTest, EagerPFLRTest) { input.set_device(local_device_); std::vector inputs = {input}; std::vector outputs; - gtl::InlinedVector tensor_args = {TensorValue()}; + absl::InlinedVector tensor_args = {TensorValue()}; TestExecuteNodeArgs args( std::move(tensor_args), - [&inputs](const int i, RemoteTensorHandle* handle) -> Status { + [&inputs](const int i, RemoteTensorHandle* handle) -> absl::Status { *handle = inputs.at(i); return absl::OkStatus(); }); eager_pflr_->Run(opts, handle, args, &outputs, - [&status, &done](const Status& s) { + [&status, &done](const absl::Status& s) { status = s; done.Notify(); }); @@ -1119,12 +1119,12 @@ TEST_F(FunctionWithRemoteInputsTest, // Instantiate MatMulFunction on remote_device. FunctionLibraryRuntime::Handle handle; EXPECT_TRUE(MatMulHasAttrWithDefaultValue(fdef_)); - Status status; + absl::Status status; Notification instantiate_done; eager_cluster_flr_->Instantiate( fdef_.signature().name(), func_lib_def_, AttrSlice(&fdef_.attr()), FunctionLibraryRuntime::InstantiateOptions(), &handle, - [&status, &instantiate_done](const Status& s) { + [&status, &instantiate_done](const absl::Status& s) { status = s; instantiate_done.Notify(); }); @@ -1152,7 +1152,7 @@ TEST_F(FunctionWithRemoteInputsTest, std::vector inputs = {*input_tensor}; std::vector outputs; eager_cluster_flr_->Run(opts, handle, inputs, &outputs, - [&status, &execute_done](const Status& s) { + [&status, &execute_done](const absl::Status& s) { status = s; execute_done.Notify(); }); @@ -1193,7 +1193,7 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) { TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr, std::nullopt)); // Run MatMulFunction on remote_device. - gtl::InlinedVector input_tensors = {TensorValue()}; + absl::InlinedVector input_tensors = {TensorValue()}; RemoteTensorHandle input; input.set_op_id(1); input.set_output_num(0); @@ -1202,7 +1202,8 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncTest) { std::vector remote_handles = {input}; TestExecuteNodeArgs inputs( std::move(input_tensors), - [&remote_handles](const int index, RemoteTensorHandle* handle) -> Status { + [&remote_handles](const int index, + RemoteTensorHandle* handle) -> absl::Status { *handle = remote_handles.at(index); return absl::OkStatus(); }); @@ -1248,7 +1249,7 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncAsyncTest) { TF_ASSERT_OK(kernel->InstantiateFunc({}, node_def, nullptr, std::nullopt)); // Run MatMulFunction on remote_device. - gtl::InlinedVector input_tensors = {TensorValue()}; + absl::InlinedVector input_tensors = {TensorValue()}; RemoteTensorHandle input; input.set_op_id(1); input.set_output_num(0); @@ -1257,19 +1258,20 @@ TEST_F(FunctionWithRemoteInputsTest, KernelAndDeviceFuncAsyncTest) { std::vector remote_handles = {input}; TestExecuteNodeArgs inputs( std::move(input_tensors), - [&remote_handles](const int index, RemoteTensorHandle* handle) -> Status { + [&remote_handles](const int index, + RemoteTensorHandle* handle) -> absl::Status { *handle = remote_handles.at(index); return absl::OkStatus(); }); std::vector outputs; - Status status; + absl::Status status; Notification n; kernel->RunAsync(/*step_container=*/nullptr, inputs, &outputs, /*cancellation_manager=*/nullptr, /*eager_func_params=*/std::nullopt, /*coordination_service_agent=*/nullptr, - [&status, &n](const Status& s) { + [&status, &n](const absl::Status& s) { status = s; n.Notify(); }); @@ -1483,8 +1485,8 @@ TEST_F(EagerServiceImplTest, RequestsToMasterTest) { SetTensorProto(send_tensor->add_tensors()); // Unable to handle the request since there is no eager context. - Status status = eager_service_impl.Enqueue(nullptr, &remote_enqueue_request, - &remote_enqueue_response); + absl::Status status = eager_service_impl.Enqueue( + nullptr, &remote_enqueue_request, &remote_enqueue_response); EXPECT_EQ(error::ABORTED, status.code()); EXPECT_TRUE(absl::StrContains( status.message(), @@ -1519,7 +1521,7 @@ TEST_F(EagerServiceImplTest, KeepAliveTest) { keep_alive_request.set_context_id(context_id); - Status status = + absl::Status status = eager_service_impl.KeepAlive(&keep_alive_request, &keep_alive_response); EXPECT_EQ(status.code(), error::ABORTED); diff --git a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc index b158bdadfebfba..c053c91f8fdb9d 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_copy_node.cc @@ -41,7 +41,7 @@ void PrepareRemoteOp(eager::Operation* remote_op, EagerOperation* op) { remote_op->set_device(op->DeviceName()); } -Status CreateUncachedKernelAndDeviceOp( +absl::Status CreateUncachedKernelAndDeviceOp( EagerOperation* op, core::RefCountPtr* kernel) { EagerContext& ctx = op->EagerContext(); Device* device = std::get(op->Device()); @@ -98,7 +98,7 @@ RemoteCopyNode::~RemoteCopyNode() { ctx_->Unref(); } -Status RemoteCopyNode::RunLocalSend(EagerOperation* op) { +absl::Status RemoteCopyNode::RunLocalSend(EagerOperation* op) { TF_RETURN_IF_ERROR(executor_->status()); TF_RETURN_IF_ERROR(op->AddInput(src_)); @@ -123,8 +123,8 @@ void RemoteCopyNode::StartSend() { // TODO(gjn): We should consider just using the low-level SendOp::Compute() // functionality here instead of constructing an Op. EagerOperation op(ctx_); - Status status = op.Reset("_Send", /*device_name=*/nullptr, - /*remote=*/false, /*executor=*/nullptr); + absl::Status status = op.Reset("_Send", /*device_name=*/nullptr, + /*remote=*/false, /*executor=*/nullptr); if (!status.ok()) { captured_state_->SetSendStatus(status); return; @@ -180,7 +180,7 @@ void RemoteCopyNode::StartSend() { eager_client->StreamingEnqueueAsync( ctx_->Executor().StreamingEnqueue(), /*call_opts=*/nullptr, &request, response, - [response, captured_state](const Status& s) { + [response, captured_state](const absl::Status& s) { captured_state->SetSendStatus(s); if (!s.ok()) { captured_state->recv_cancellation()->StartCancel(); @@ -190,8 +190,8 @@ void RemoteCopyNode::StartSend() { } } -Status RemoteCopyNode::RunLocalRecv(EagerOperation* op, - std::vector* outputs) { +absl::Status RemoteCopyNode::RunLocalRecv(EagerOperation* op, + std::vector* outputs) { TF_RETURN_IF_ERROR(executor_->status()); core::RefCountPtr kernel; @@ -228,7 +228,7 @@ void RemoteCopyNode::RunRemoteRecv(EagerOperation* op, StatusCallback done) { uint64 context_view_id = ctx_->GetContextViewId(); core::RefCountPtr eager_client; - Status status = ctx_->GetClient(recv_device_, &eager_client); + absl::Status status = ctx_->GetClient(recv_device_, &eager_client); if (!status.ok()) { captured_state_->dst()->PoisonRemote(status, recv_device_, context_view_id); done(status); @@ -240,7 +240,7 @@ void RemoteCopyNode::RunRemoteRecv(EagerOperation* op, StatusCallback done) { // - remote send will take some time, but remote->remote copy is // probably rare enough that we don't care much. // Blocks until send has completed. - Status send_status = captured_state_->GetSendStatus(); + absl::Status send_status = captured_state_->GetSendStatus(); if (!send_status.ok()) { captured_state_->dst()->PoisonRemote(status, recv_device_, context_view_id); done(send_status); @@ -254,9 +254,9 @@ void RemoteCopyNode::RunRemoteRecv(EagerOperation* op, StatusCallback done) { ctx_->Executor().StreamingEnqueue(), /*call_opts=*/nullptr, &request, response, [captured_state, response, recv_device, context_view_id, - done](const Status& s) { + done](const absl::Status& s) { if (s.ok()) { - Status status = captured_state->dst()->SetRemoteShape( + absl::Status status = captured_state->dst()->SetRemoteShape( response->queue_response(0).shape(0), recv_device, context_view_id); if (!status.ok()) { @@ -278,8 +278,8 @@ void RemoteCopyNode::StartRecv(StatusCallback done) { // TODO(gjn): We should consider just using the low-level RecvOp::Compute() // functionality here instead of constructing an Op. EagerOperation op(ctx_); - Status status = op.Reset("_Recv", /*device_name=*/nullptr, - /*remote=*/false, /*executor=*/nullptr); + absl::Status status = op.Reset("_Recv", /*device_name=*/nullptr, + /*remote=*/false, /*executor=*/nullptr); Device* recv_device = ctx_->CanonicalDevice(recv_device_); if (!status.ok()) { captured_state_->dst()->Poison(status, recv_device); @@ -316,9 +316,10 @@ void RemoteCopyNode::StartRecv(StatusCallback done) { } } -Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle, - const Device* target_device, EagerContext* ctx, - SendPackedHandleOp* op) { +absl::Status SerializePackedHandle(const uint64 op_id, + TensorHandle* packed_handle, + const Device* target_device, + EagerContext* ctx, SendPackedHandleOp* op) { op->set_op_id(op_id); op->set_device_name(packed_handle->DeviceOrHostCPU(*ctx)->name()); for (int i = 0; i < packed_handle->NumPackedHandles(); ++i) { @@ -360,7 +361,7 @@ Status SerializePackedHandle(const uint64 op_id, TensorHandle* packed_handle, } void RemoteCopyNode::StartSendPackedHandle(StatusCallback done) { - Status s; + absl::Status s; const uint64 context_view_id = ctx_->GetContextViewId(); if (!send_device_->IsLocal()) { s = errors::InvalidArgument( @@ -405,9 +406,9 @@ void RemoteCopyNode::StartSendPackedHandle(StatusCallback done) { ctx_->Executor().StreamingEnqueue(), /*call_opts=*/nullptr, &request, response, [captured_state, response, recv_device, context_view_id, - done](const Status& s) { + done](const absl::Status& s) { if (s.ok()) { - Status status = captured_state->dst()->SetRemoteShape( + absl::Status status = captured_state->dst()->SetRemoteShape( captured_state->GetSrcShape(), recv_device, context_view_id); if (!status.ok()) { LOG(ERROR) << "Ignoring an error encountered when setting remote " @@ -423,7 +424,7 @@ void RemoteCopyNode::StartSendPackedHandle(StatusCallback done) { } void RemoteCopyNode::StartRemoteSendTensor(StatusCallback done) { - Status s; + absl::Status s; EnqueueRequest request; uint64 context_id = ctx_->GetContextId(); request.set_context_id(context_id); @@ -458,9 +459,9 @@ void RemoteCopyNode::StartRemoteSendTensor(StatusCallback done) { ctx_->Executor().StreamingEnqueue(), /*call_opts=*/nullptr, &request, response, [captured_state, response, recv_device, context_view_id, - done](const Status& s) { + done](const absl::Status& s) { if (s.ok()) { - Status status = captured_state->dst()->SetRemoteShape( + absl::Status status = captured_state->dst()->SetRemoteShape( captured_state->GetSrcShape(), recv_device, context_view_id); if (!status.ok()) { LOG(ERROR) << "Ignoring an error encountered when setting remote " @@ -475,7 +476,7 @@ void RemoteCopyNode::StartRemoteSendTensor(StatusCallback done) { }); } -Status RemoteCopyNode::Prepare() { +absl::Status RemoteCopyNode::Prepare() { TF_RETURN_IF_ERROR(captured_state_->dst()->CopyInferenceShape(src_)); return absl::OkStatus(); } @@ -494,9 +495,9 @@ void RemoteCopyNode::RunAsync(StatusCallback done) { const std::shared_ptr& captured_state = captured_state_; auto done_wrapper = [captured_state, - done = std::move(done)](const Status& s) { + done = std::move(done)](const absl::Status& s) { if (!s.ok() && errors::IsCancelled(s)) { - Status send_status = captured_state->GetSendStatus(); + absl::Status send_status = captured_state->GetSendStatus(); if (!send_status.ok()) { // In this case, Recv is cancelled because the Send op failed. // Return the status of the Send op instead. @@ -512,7 +513,7 @@ void RemoteCopyNode::RunAsync(StatusCallback done) { StartRecv(std::move(done_wrapper)); } -void RemoteCopyNode::Abort(Status status) { +void RemoteCopyNode::Abort(absl::Status status) { if (!started_) { uint64 context_view_id = ctx_->GetContextViewId(); captured_state_->dst()->PoisonRemote(status, recv_device_, context_view_id); diff --git a/tensorflow/core/distributed_runtime/eager/remote_copy_node.h b/tensorflow/core/distributed_runtime/eager/remote_copy_node.h index 0f5297e855933e..32f3befdfdab14 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_copy_node.h +++ b/tensorflow/core/distributed_runtime/eager/remote_copy_node.h @@ -66,11 +66,11 @@ class RemoteCopyNode : public AsyncEagerNode { ~RemoteCopyNode() override; - Status Prepare() override; + absl::Status Prepare() override; void RunAsync(StatusCallback done) override; - void Abort(Status status) override; + void Abort(absl::Status status) override; string DebugString() const override { string out = "[RemoteCopyNode]"; @@ -90,7 +90,7 @@ class RemoteCopyNode : public AsyncEagerNode { void StartSend(); // Synchronously runs local send `op` and returns its status. - Status RunLocalSend(EagerOperation* op); + absl::Status RunLocalSend(EagerOperation* op); // Runs the _Recv operation locally or remotely. // An error return value indicates that _Recv did not run successfully. It @@ -106,7 +106,7 @@ class RemoteCopyNode : public AsyncEagerNode { // Synchronously runs local receive `op` and returns its status. // Does not wait for the send to complete before running receive. - Status RunLocalRecv(EagerOperation* op, std::vector* outputs); + absl::Status RunLocalRecv(EagerOperation* op, std::vector* outputs); // Waits for send to complete, then issues remote receive `op` and // returns its status. @@ -133,12 +133,12 @@ class RemoteCopyNode : public AsyncEagerNode { explicit CapturedSharedState(TensorHandle* d) : dst_(d) { dst_->Ref(); } ~CapturedSharedState() { dst_->Unref(); } - void SetSendStatus(Status status) { + void SetSendStatus(absl::Status status) { send_status_.Update(status); send_done_.Notify(); } - Status GetSendStatus() { + absl::Status GetSendStatus() { send_done_.WaitForNotification(); return send_status_; } @@ -156,7 +156,7 @@ class RemoteCopyNode : public AsyncEagerNode { CancellationManager recv_cancellation_; // send_status_ is safe to read only after send_done_.WaitForNotification() // has returned. - Status send_status_; + absl::Status send_status_; Notification send_done_; TensorShape src_shape_; }; diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc b/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc index af7e24d79b80f8..f118ecaeb2bbad 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.cc @@ -27,8 +27,8 @@ namespace eager { void RemoteExecuteNode::RunAsync(StatusCallback done) { auto response = std::make_shared(); - const gtl::InlinedVector& inputs = inputs_; - const gtl::InlinedVector& retvals = retvals_; + const absl::InlinedVector& inputs = inputs_; + const absl::InlinedVector& retvals = retvals_; Device* device = device_; // Filled and used only when VLOG(3) is on. @@ -60,7 +60,7 @@ void RemoteExecuteNode::RunAsync(StatusCallback done) { const bool already_cancelled = !cm->RegisterCallback( token, [call_opts, response, done]() { call_opts->StartCancel(); }); if (already_cancelled) { - Status s = errors::Cancelled("RemoteExecuteNode::RunAsync"); + absl::Status s = errors::Cancelled("RemoteExecuteNode::RunAsync"); for (size_t i = 0; i < retvals.size(); ++i) { retvals[i]->PoisonRemote(s, device, context_view_id_); } @@ -81,7 +81,7 @@ void RemoteExecuteNode::RunAsync(StatusCallback done) { request_.get(), response.get(), [inputs, retvals, call_opts, response, device, context_view_id = context_view_id_, rpc_description, cm, token, - done](const Status& status) { + done](const absl::Status& status) { if (cm != nullptr) { cm->TryDeregisterCallback(token); } @@ -100,7 +100,7 @@ void RemoteExecuteNode::RunAsync(StatusCallback done) { response->queue_response(0).device().empty() ? "" : response->queue_response(0).device(i); - Status s = retvals[i]->SetRemoteShapeAndDevice( + absl::Status s = retvals[i]->SetRemoteShapeAndDevice( response->queue_response(0).shape(i), device, context_view_id, output_device); diff --git a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h index 148e58a5b008c5..d1c5359d473900 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_execute_node.h +++ b/tensorflow/core/distributed_runtime/eager/remote_execute_node.h @@ -44,7 +44,7 @@ class RemoteExecuteNode : public AsyncRemoteExecuteNode { CancellationManager* cancellation_manager, const NodeDef& ndef, const FunctionLibraryDefinition* lib_def, - const gtl::InlinedVector& inputs, + const absl::InlinedVector& inputs, absl::Span retvals) : AsyncRemoteExecuteNode(), eager_context_(eager_context), @@ -92,15 +92,17 @@ class RemoteExecuteNode : public AsyncRemoteExecuteNode { eager_client_->Unref(); } - Status Prepare() override { + absl::Status Prepare() override { return RunShapeInference(ndef_, *lib_def_, inputs_, retvals_); } void RunAsync(StatusCallback done) override; - Status SyncExecutors() override { return eager_context_->SyncExecutors(); } + absl::Status SyncExecutors() override { + return eager_context_->SyncExecutors(); + } - void Abort(Status status) override { + void Abort(absl::Status status) override { int i = 0; for (auto handle : retvals_) { handle->PoisonRemote(status, device_, context_view_id_); @@ -133,8 +135,8 @@ class RemoteExecuteNode : public AsyncRemoteExecuteNode { CancellationManager* cancellation_manager_; const NodeDef ndef_; const FunctionLibraryDefinition* lib_def_; - gtl::InlinedVector inputs_; - gtl::InlinedVector retvals_; + absl::InlinedVector inputs_; + absl::InlinedVector retvals_; }; } // namespace eager diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc index 9799808face199..acd34fd9ccbc86 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc @@ -32,7 +32,7 @@ limitations under the License. namespace tensorflow { namespace { -Status WithErrorSourcePayload(Status error) { +absl::Status WithErrorSourcePayload(absl::Status error) { core::platform::ErrorSourceProto error_source_proto; error_source_proto.set_error_source( core::platform::ErrorSourceProto::EAGER_REMOTE_MGR); @@ -62,7 +62,7 @@ void RemoteMgr::AddOperationOutput(tensorflow::TensorHandle* handle, RemoteTensorHandleInternal(operation_id, output_num), handle); } -Status RemoteMgr::GetTensorHandleImpl( +absl::Status RemoteMgr::GetTensorHandleImpl( const RemoteTensorHandleInternal& remote_handle, tensorflow::TensorHandle** handle) { auto iter = remote_tensor_handle_map_.find(remote_handle); @@ -96,14 +96,14 @@ Status RemoteMgr::GetTensorHandleImpl( return absl::OkStatus(); } -Status RemoteMgr::GetTensorHandle( +absl::Status RemoteMgr::GetTensorHandle( const RemoteTensorHandleInternal& remote_handle, tensorflow::TensorHandle** handle) { tf_shared_lock l(remote_tensor_handle_mu_); return GetTensorHandleImpl(remote_handle, handle); } -Status RemoteMgr::GetMirroredResourceShape( +absl::Status RemoteMgr::GetMirroredResourceShape( const RemoteTensorHandleInternal& remote_handle, std::vector* handle) { tf_shared_lock l(mirrored_resource_shape_mu_); @@ -125,9 +125,9 @@ Status RemoteMgr::GetMirroredResourceShape( return absl::OkStatus(); } -Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle, - const bool wait_until_ready, - int64_t* op_id, int32* output_num) { +absl::Status RemoteMgr::GetRemoteTensorHandle( + const tensorflow::TensorHandle* handle, const bool wait_until_ready, + int64_t* op_id, int32* output_num) { TF_RETURN_IF_ERROR(handle->RemoteAddress(handle->device(), wait_until_ready, op_id, output_num)); tensorflow::TensorHandle* h; @@ -141,7 +141,7 @@ Status RemoteMgr::GetRemoteTensorHandle(const tensorflow::TensorHandle* handle, return absl::OkStatus(); } -Status RemoteMgr::DeleteTensorHandle( +absl::Status RemoteMgr::DeleteTensorHandle( const RemoteTensorHandleInternal& remote_handle) { { mutex_lock l(remote_tensor_handle_mu_); @@ -165,7 +165,7 @@ Status RemoteMgr::DeleteTensorHandle( remote_handle.op_id, ", Output num: ", remote_handle.output_num)); } -Status RemoteMgr::SerializeRemoteTensorHandle( +absl::Status RemoteMgr::SerializeRemoteTensorHandle( TensorHandle* in, const bool wait_until_ready, RemoteTensorHandle* out, Device* device, absl::string_view device_name, const bool serialize_resource_dtype_and_shape) { @@ -203,8 +203,8 @@ Status RemoteMgr::SerializeRemoteTensorHandle( return absl::OkStatus(); } -Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in, - TensorHandle** out) { +absl::Status RemoteMgr::DeserializeRemoteTensorHandle( + const RemoteTensorHandle& in, TensorHandle** out) { Device* device; if (parent_->local_device_mgr()->LookupDevice(in.op_device(), &device).ok() || parent_->local_device_mgr()->LookupDevice(in.device(), &device).ok()) { @@ -260,7 +260,7 @@ void RemoteMgr::DeleteExecutorForStream(uint64 stream_id) { if (it == executor_map_.end()) { return; } - Status s = it->second.ShutDown(); + absl::Status s = it->second.ShutDown(); if (!s.ok()) { LOG(ERROR) << "EagerExecutor shutdown with error " << s.message(); } diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr.h b/tensorflow/core/distributed_runtime/eager/remote_mgr.h index 6a1e9de756dad5..b62134cd6e5860 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_mgr.h +++ b/tensorflow/core/distributed_runtime/eager/remote_mgr.h @@ -50,10 +50,11 @@ class RemoteMgr { void AddOperationOutput(tensorflow::TensorHandle* handles, int64_t operation_id, int32_t output_num); - Status GetTensorHandle(const RemoteTensorHandleInternal& remote_handle, - tensorflow::TensorHandle** handle); + absl::Status GetTensorHandle(const RemoteTensorHandleInternal& remote_handle, + tensorflow::TensorHandle** handle); - Status DeleteTensorHandle(const RemoteTensorHandleInternal& remote_handle); + absl::Status DeleteTensorHandle( + const RemoteTensorHandleInternal& remote_handle); // Helper function to create monotonically increasing ids unique to this // context. @@ -66,15 +67,15 @@ class RemoteMgr { // Serialize a remote TensorHandle to a RemoteTensorHandle. // If wait_until_ready is true, block until the remote handle is ready on a // remote worker. - Status SerializeRemoteTensorHandle( + absl::Status SerializeRemoteTensorHandle( TensorHandle* in, const bool wait_until_ready, RemoteTensorHandle* out, Device* device, absl::string_view device_name = "", const bool serialize_resource_dtype_and_shape = false); // Deserialize a RemoteTensorHandle to a TensorHandle(local/remote). // The output holds a reference to the TensorHandle. - Status DeserializeRemoteTensorHandle(const RemoteTensorHandle& in, - TensorHandle** out); + absl::Status DeserializeRemoteTensorHandle(const RemoteTensorHandle& in, + TensorHandle** out); EagerExecutor& GetOrCreateExecutorForStream(uint64 stream_id); @@ -87,16 +88,17 @@ class RemoteMgr { private: // Returns the op_id and output_num if the given local TensorHandle exists in // remote_tensor_handle_map_. - Status GetRemoteTensorHandle(const tensorflow::TensorHandle* handle, - const bool wait_until_ready, int64_t* op_id, - int32* output_num) + absl::Status GetRemoteTensorHandle(const tensorflow::TensorHandle* handle, + const bool wait_until_ready, + int64_t* op_id, int32* output_num) TF_SHARED_LOCKS_REQUIRED(remote_tensor_handle_mu_); - Status GetTensorHandleImpl(const RemoteTensorHandleInternal& remote_handle, - tensorflow::TensorHandle** handle) + absl::Status GetTensorHandleImpl( + const RemoteTensorHandleInternal& remote_handle, + tensorflow::TensorHandle** handle) TF_SHARED_LOCKS_REQUIRED(remote_tensor_handle_mu_); - Status GetMirroredResourceShape( + absl::Status GetMirroredResourceShape( const RemoteTensorHandleInternal& remote_handle, std::vector* handle); diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc b/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc index 1d104abb306afb..ae05ce640cf0dc 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_mgr_test.cc @@ -169,7 +169,7 @@ TEST_F(RemoteMgrTest, ErrorSourcesShouldExist) { TF_ASSERT_OK(remote_mgr.DeleteTensorHandle(remote_handle_internal)); // Now that the tensor has been deleted, we cannot access the remote handle. - Status s = remote_mgr.DeleteTensorHandle(remote_handle_internal); + absl::Status s = remote_mgr.DeleteTensorHandle(remote_handle_internal); EXPECT_FALSE(s.ok()); EXPECT_TRUE(s.GetPayload(kErrorSource).has_value()); diff --git a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc index 7c6198f5578e94..73427ed1372ed8 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc +++ b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.cc @@ -40,7 +40,7 @@ void DestroyRemoteTensorHandle(EagerContext* ctx, const string& remote_task, } core::RefCountPtr eager_client; - Status status = ctx->GetClient(remote_task, &eager_client); + absl::Status status = ctx->GetClient(remote_task, &eager_client); if (!status.ok()) { LOG_EVERY_N_SEC(INFO, 60) << "Unable to destroy remote tensor handle because the target " @@ -61,7 +61,7 @@ void DestroyRemoteTensorHandle(EagerContext* ctx, const string& remote_task, std::move(request), std::move(eager_client), ready)); auto& executor = ctx->Executor(); if (executor.Async()) { - Status status = executor.AddOrExecute(std::move(node)); + absl::Status status = executor.AddOrExecute(std::move(node)); if (!status.ok()) { LOG_EVERY_N_SEC(WARNING, 60) << "Unable to destroy remote tensor handles. If you are " @@ -74,7 +74,7 @@ void DestroyRemoteTensorHandle(EagerContext* ctx, const string& remote_task, // to send out the destroy request in a new thread to avoid deadlock. auto* released_node = node.release(); (*ctx->runner())([ctx, released_node] { - Status status = + absl::Status status = ctx->Executor().AddOrExecute(absl::WrapUnique(released_node)); if (!status.ok()) { LOG_EVERY_N_SEC(WARNING, 60) @@ -125,7 +125,7 @@ RemoteTensorHandleData::~RemoteTensorHandleData() { } } -Status RemoteTensorHandleData::Shape(TensorShape* shape) const { +absl::Status RemoteTensorHandleData::Shape(TensorShape* shape) const { TF_RETURN_IF_ERROR(WaitReady("Shape")); tf_shared_lock l(mu_); @@ -134,7 +134,7 @@ Status RemoteTensorHandleData::Shape(TensorShape* shape) const { return absl::OkStatus(); } -Status RemoteTensorHandleData::NumDims(int* num_dims) const { +absl::Status RemoteTensorHandleData::NumDims(int* num_dims) const { TF_RETURN_IF_ERROR(WaitReady("NumDims")); tf_shared_lock l(mu_); @@ -143,7 +143,7 @@ Status RemoteTensorHandleData::NumDims(int* num_dims) const { return absl::OkStatus(); } -Status RemoteTensorHandleData::Dim(int dim_index, int64_t* dim) const { +absl::Status RemoteTensorHandleData::Dim(int dim_index, int64_t* dim) const { TF_RETURN_IF_ERROR(WaitReady("Dim")); tf_shared_lock l(mu_); @@ -152,7 +152,7 @@ Status RemoteTensorHandleData::Dim(int dim_index, int64_t* dim) const { return absl::OkStatus(); } -Status RemoteTensorHandleData::NumElements(int64_t* num_elements) const { +absl::Status RemoteTensorHandleData::NumElements(int64_t* num_elements) const { TF_RETURN_IF_ERROR(WaitReady("NumElements")); tf_shared_lock l(mu_); @@ -166,22 +166,22 @@ bool RemoteTensorHandleData::IsReady() const { return is_ready_; } -void RemoteTensorHandleData::Poison(Status status) { +void RemoteTensorHandleData::Poison(absl::Status status) { mutex_lock l(mu_); is_poisoned_ = status; is_ready_ = true; } -Status RemoteTensorHandleData::IsPoisoned() const { +absl::Status RemoteTensorHandleData::IsPoisoned() const { tf_shared_lock l(mu_); return is_poisoned_; } -Status RemoteTensorHandleData::SetShape(const TensorShape& shape) { +absl::Status RemoteTensorHandleData::SetShape(const TensorShape& shape) { return SetShapeAndRemoteTask(shape, /*remote_task=*/""); } -Status RemoteTensorHandleData::SetShapeAndRemoteTask( +absl::Status RemoteTensorHandleData::SetShapeAndRemoteTask( const TensorShape& shape, const string& remote_task) { // If `is_ready_` is set previously due to poisoning, return the original // error that poisoned this tensor. @@ -221,9 +221,8 @@ string RemoteTensorHandleData::DebugString() const { " output_num: ", output_num_); } -Status RemoteTensorHandleData::OpIdAndOutputNum(const bool wait_until_ready, - int64_t* op_id, - int32* output_num) const { +absl::Status RemoteTensorHandleData::OpIdAndOutputNum( + const bool wait_until_ready, int64_t* op_id, int32* output_num) const { if (wait_until_ready) { TF_RETURN_IF_ERROR(WaitReady("OpIdAndOutputNumUntilReady")); } @@ -232,7 +231,7 @@ Status RemoteTensorHandleData::OpIdAndOutputNum(const bool wait_until_ready, return absl::OkStatus(); } -Status RemoteTensorHandleData::WaitReady(const char* caller) const { +absl::Status RemoteTensorHandleData::WaitReady(const char* caller) const { tf_shared_lock l(mu_); if (!is_ready_) { tsl::profiler::TraceMe activity( diff --git a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h index 92f0a66ebbbba7..892d82bd5f7efe 100644 --- a/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h +++ b/tensorflow/core/distributed_runtime/eager/remote_tensor_handle_data.h @@ -41,33 +41,33 @@ class RemoteTensorHandleData { // A remote tensor handle does not have a Tensor object, hence it can only // support the shape requests. - Status Shape(TensorShape* shape) const; - Status NumDims(int* num_dims) const; - Status Dim(int dim_index, int64_t* dim) const; - Status NumElements(int64_t* num_elements) const; - Status Unprotect() { return absl::OkStatus(); } + absl::Status Shape(TensorShape* shape) const; + absl::Status NumDims(int* num_dims) const; + absl::Status Dim(int dim_index, int64_t* dim) const; + absl::Status NumElements(int64_t* num_elements) const; + absl::Status Unprotect() { return absl::OkStatus(); } bool IsReady() const; - Status WaitReady(const char* caller) const; - Status SetShape(const TensorShape& shape); - Status SetShapeAndRemoteTask(const TensorShape& shape, - const string& remote_task); - void Poison(Status status); - Status IsPoisoned() const; + absl::Status WaitReady(const char* caller) const; + absl::Status SetShape(const TensorShape& shape); + absl::Status SetShapeAndRemoteTask(const TensorShape& shape, + const string& remote_task); + void Poison(absl::Status status); + absl::Status IsPoisoned() const; string DebugString() const; // Return the op id and output num. If wait_until_ready is true, block until // the remote tensor is ready on a remote worker. - Status OpIdAndOutputNum(bool wait_until_ready, int64_t* op_id, - int32* output_num) const; + absl::Status OpIdAndOutputNum(bool wait_until_ready, int64_t* op_id, + int32* output_num) const; uint64 context_view_id() const { return context_view_id_; } private: mutable mutex mu_; bool is_ready_ TF_GUARDED_BY(mu_); - Status is_poisoned_ TF_GUARDED_BY(mu_); + absl::Status is_poisoned_ TF_GUARDED_BY(mu_); TensorShape shape_ TF_GUARDED_BY(mu_); // IDs required when this class is representing a remote tensor handle. diff --git a/tensorflow/core/distributed_runtime/integration_test/c_api_coordination_test.cc b/tensorflow/core/distributed_runtime/integration_test/c_api_coordination_test.cc index 356f0a08412fd9..9f803991417dce 100644 --- a/tensorflow/core/distributed_runtime/integration_test/c_api_coordination_test.cc +++ b/tensorflow/core/distributed_runtime/integration_test/c_api_coordination_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/platform/blocking_counter.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" -#include "tsl/protobuf/coordination_config.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_function_test.cc b/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_function_test.cc index 7a077cbe13b8d6..25eb3da148a23f 100644 --- a/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_function_test.cc +++ b/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_function_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/c/eager/tfe_context_internal.h" #include "tensorflow/c/eager/tfe_op_internal.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/common_runtime/eager/kernel_and_device.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" -#include "tsl/protobuf/coordination_config.pb.h" namespace { diff --git a/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_test.cc b/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_test.cc index cffe93d297a8df..01ab92e1939ceb 100644 --- a/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_test.cc +++ b/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_test.cc @@ -20,13 +20,13 @@ limitations under the License. #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/eager/tfe_context_internal.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/platform/strcat.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" -#include "tsl/protobuf/coordination_config.pb.h" namespace { diff --git a/tensorflow/core/distributed_runtime/integration_test/c_api_recoverable_jobs_test.cc b/tensorflow/core/distributed_runtime/integration_test/c_api_recoverable_jobs_test.cc index 3d9ff3c459181f..3c04a22afb174b 100644 --- a/tensorflow/core/distributed_runtime/integration_test/c_api_recoverable_jobs_test.cc +++ b/tensorflow/core/distributed_runtime/integration_test/c_api_recoverable_jobs_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/platform/strcat.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" -#include "tsl/protobuf/coordination_config.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/distributed_runtime/integration_test/c_api_session_coordination_test.cc b/tensorflow/core/distributed_runtime/integration_test/c_api_session_coordination_test.cc index e5e9aad6f06ddf..1d78c7008b91e1 100644 --- a/tensorflow/core/distributed_runtime/integration_test/c_api_session_coordination_test.cc +++ b/tensorflow/core/distributed_runtime/integration_test/c_api_session_coordination_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/c/c_test_util.h" #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/tf_datatype.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" -#include "tsl/protobuf/coordination_config.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/distributed_runtime/integration_test/coordination_test_opkernel_registration.cc b/tensorflow/core/distributed_runtime/integration_test/coordination_test_opkernel_registration.cc index af266e3845e21b..6e85d474250f70 100644 --- a/tensorflow/core/distributed_runtime/integration_test/coordination_test_opkernel_registration.cc +++ b/tensorflow/core/distributed_runtime/integration_test/coordination_test_opkernel_registration.cc @@ -147,8 +147,7 @@ class TestReportErrorToClusterOp : public OpKernel { "initialized properly.")); return; } - tensorflow::Status s(static_cast(error_code), - error_message); + absl::Status s(static_cast(error_code), error_message); s.SetPayload(tsl::CoordinationErrorPayloadKey(), absl::Cord("testing error payload")); OP_REQUIRES_OK(ctx, coord_agent->ReportError(s)); diff --git a/tensorflow/core/distributed_runtime/local_master.cc b/tensorflow/core/distributed_runtime/local_master.cc index a41f977b059f82..c014125ac8b9b9 100644 --- a/tensorflow/core/distributed_runtime/local_master.cc +++ b/tensorflow/core/distributed_runtime/local_master.cc @@ -23,9 +23,9 @@ limitations under the License. namespace tensorflow { namespace { -Status WaitForNotification(CallOptions* call_options, - const int64_t default_timeout_in_ms, - Notification* n) { +absl::Status WaitForNotification(CallOptions* call_options, + const int64_t default_timeout_in_ms, + Notification* n) { int64_t timeout_in_ms = call_options->GetTimeout(); if (timeout_in_ms == 0) { timeout_in_ms = default_timeout_in_ms; @@ -52,55 +52,58 @@ LocalMaster::LocalMaster(Master* master_impl, : master_impl_(master_impl), default_timeout_in_ms_(default_timeout_in_ms) {} -Status LocalMaster::CreateSession(CallOptions* call_options, - const CreateSessionRequest* request, - CreateSessionResponse* response) { +absl::Status LocalMaster::CreateSession(CallOptions* call_options, + const CreateSessionRequest* request, + CreateSessionResponse* response) { Notification n; - Status ret; - master_impl_->CreateSession(request, response, [&n, &ret](const Status& s) { - ret.Update(s); - n.Notify(); - }); + absl::Status ret; + master_impl_->CreateSession(request, response, + [&n, &ret](const absl::Status& s) { + ret.Update(s); + n.Notify(); + }); TF_RETURN_IF_ERROR( WaitForNotification(call_options, default_timeout_in_ms_, &n)); return ret; } -Status LocalMaster::ExtendSession(CallOptions* call_options, - const ExtendSessionRequest* request, - ExtendSessionResponse* response) { +absl::Status LocalMaster::ExtendSession(CallOptions* call_options, + const ExtendSessionRequest* request, + ExtendSessionResponse* response) { Notification n; - Status ret; - master_impl_->ExtendSession(request, response, [&n, &ret](const Status& s) { - ret.Update(s); - n.Notify(); - }); + absl::Status ret; + master_impl_->ExtendSession(request, response, + [&n, &ret](const absl::Status& s) { + ret.Update(s); + n.Notify(); + }); TF_RETURN_IF_ERROR( WaitForNotification(call_options, default_timeout_in_ms_, &n)); return ret; } -Status LocalMaster::PartialRunSetup(CallOptions* call_options, - const PartialRunSetupRequest* request, - PartialRunSetupResponse* response) { +absl::Status LocalMaster::PartialRunSetup(CallOptions* call_options, + const PartialRunSetupRequest* request, + PartialRunSetupResponse* response) { Notification n; - Status ret; - master_impl_->PartialRunSetup(request, response, [&n, &ret](const Status& s) { - ret.Update(s); - n.Notify(); - }); + absl::Status ret; + master_impl_->PartialRunSetup(request, response, + [&n, &ret](const absl::Status& s) { + ret.Update(s); + n.Notify(); + }); TF_RETURN_IF_ERROR( WaitForNotification(call_options, default_timeout_in_ms_, &n)); return ret; } -Status LocalMaster::RunStep(CallOptions* call_options, - RunStepRequestWrapper* request, - MutableRunStepResponseWrapper* response) { +absl::Status LocalMaster::RunStep(CallOptions* call_options, + RunStepRequestWrapper* request, + MutableRunStepResponseWrapper* response) { Notification n; - Status ret; + absl::Status ret; master_impl_->RunStep(call_options, request, response, - [&n, &ret](const Status& s) { + [&n, &ret](const absl::Status& s) { ret.Update(s); n.Notify(); }); @@ -117,40 +120,42 @@ MutableRunStepResponseWrapper* LocalMaster::CreateRunStepResponse() { return new InMemoryRunStepResponse; } -Status LocalMaster::CloseSession(CallOptions* call_options, - const CloseSessionRequest* request, - CloseSessionResponse* response) { +absl::Status LocalMaster::CloseSession(CallOptions* call_options, + const CloseSessionRequest* request, + CloseSessionResponse* response) { Notification n; - Status ret; - master_impl_->CloseSession(request, response, [&n, &ret](const Status& s) { - ret.Update(s); - n.Notify(); - }); + absl::Status ret; + master_impl_->CloseSession(request, response, + [&n, &ret](const absl::Status& s) { + ret.Update(s); + n.Notify(); + }); TF_RETURN_IF_ERROR( WaitForNotification(call_options, default_timeout_in_ms_, &n)); return ret; } -Status LocalMaster::ListDevices(CallOptions* call_options, - const ListDevicesRequest* request, - ListDevicesResponse* response) { +absl::Status LocalMaster::ListDevices(CallOptions* call_options, + const ListDevicesRequest* request, + ListDevicesResponse* response) { Notification n; - Status ret; - master_impl_->ListDevices(request, response, [&n, &ret](const Status& s) { - ret.Update(s); - n.Notify(); - }); + absl::Status ret; + master_impl_->ListDevices(request, response, + [&n, &ret](const absl::Status& s) { + ret.Update(s); + n.Notify(); + }); TF_RETURN_IF_ERROR( WaitForNotification(call_options, default_timeout_in_ms_, &n)); return ret; } -Status LocalMaster::Reset(CallOptions* call_options, - const ResetRequest* request, - ResetResponse* response) { +absl::Status LocalMaster::Reset(CallOptions* call_options, + const ResetRequest* request, + ResetResponse* response) { Notification n; - Status ret; - master_impl_->Reset(request, response, [&n, &ret](const Status& s) { + absl::Status ret; + master_impl_->Reset(request, response, [&n, &ret](const absl::Status& s) { ret.Update(s); n.Notify(); }); @@ -159,26 +164,27 @@ Status LocalMaster::Reset(CallOptions* call_options, return ret; } -Status LocalMaster::MakeCallable(CallOptions* call_options, - const MakeCallableRequest* request, - MakeCallableResponse* response) { +absl::Status LocalMaster::MakeCallable(CallOptions* call_options, + const MakeCallableRequest* request, + MakeCallableResponse* response) { Notification n; - Status ret; - master_impl_->MakeCallable(request, response, [&n, &ret](const Status& s) { - ret.Update(s); - n.Notify(); - }); + absl::Status ret; + master_impl_->MakeCallable(request, response, + [&n, &ret](const absl::Status& s) { + ret.Update(s); + n.Notify(); + }); TF_RETURN_IF_ERROR( WaitForNotification(call_options, default_timeout_in_ms_, &n)); return ret; } -Status LocalMaster::RunCallable(CallOptions* call_options, - const RunCallableRequest* request, - RunCallableResponse* response) { +absl::Status LocalMaster::RunCallable(CallOptions* call_options, + const RunCallableRequest* request, + RunCallableResponse* response) { Notification n; - Status ret; + absl::Status ret; master_impl_->RunCallable(call_options, request, response, - [&n, &ret](const Status& s) { + [&n, &ret](const absl::Status& s) { ret.Update(s); n.Notify(); }); @@ -186,15 +192,16 @@ Status LocalMaster::RunCallable(CallOptions* call_options, WaitForNotification(call_options, default_timeout_in_ms_, &n)); return ret; } -Status LocalMaster::ReleaseCallable(CallOptions* call_options, - const ReleaseCallableRequest* request, - ReleaseCallableResponse* response) { +absl::Status LocalMaster::ReleaseCallable(CallOptions* call_options, + const ReleaseCallableRequest* request, + ReleaseCallableResponse* response) { Notification n; - Status ret; - master_impl_->ReleaseCallable(request, response, [&n, &ret](const Status& s) { - ret.Update(s); - n.Notify(); - }); + absl::Status ret; + master_impl_->ReleaseCallable(request, response, + [&n, &ret](const absl::Status& s) { + ret.Update(s); + n.Notify(); + }); TF_RETURN_IF_ERROR( WaitForNotification(call_options, default_timeout_in_ms_, &n)); return ret; diff --git a/tensorflow/core/distributed_runtime/local_master.h b/tensorflow/core/distributed_runtime/local_master.h index 3fbbbf28a1fe7f..e4fc37e4f60f50 100644 --- a/tensorflow/core/distributed_runtime/local_master.h +++ b/tensorflow/core/distributed_runtime/local_master.h @@ -40,46 +40,47 @@ class LocalMaster : public MasterInterface { public: ~LocalMaster() override {} - Status CreateSession(CallOptions* call_options, - const CreateSessionRequest* request, - CreateSessionResponse* response) override; + absl::Status CreateSession(CallOptions* call_options, + const CreateSessionRequest* request, + CreateSessionResponse* response) override; - Status ExtendSession(CallOptions* call_options, - const ExtendSessionRequest* request, - ExtendSessionResponse* response) override; + absl::Status ExtendSession(CallOptions* call_options, + const ExtendSessionRequest* request, + ExtendSessionResponse* response) override; - Status PartialRunSetup(CallOptions* call_options, - const PartialRunSetupRequest* request, - PartialRunSetupResponse* response) override; + absl::Status PartialRunSetup(CallOptions* call_options, + const PartialRunSetupRequest* request, + PartialRunSetupResponse* response) override; - Status RunStep(CallOptions* call_options, RunStepRequestWrapper* request, - MutableRunStepResponseWrapper* response) override; + absl::Status RunStep(CallOptions* call_options, + RunStepRequestWrapper* request, + MutableRunStepResponseWrapper* response) override; MutableRunStepRequestWrapper* CreateRunStepRequest() override; MutableRunStepResponseWrapper* CreateRunStepResponse() override; - Status CloseSession(CallOptions* call_options, - const CloseSessionRequest* request, - CloseSessionResponse* response) override; + absl::Status CloseSession(CallOptions* call_options, + const CloseSessionRequest* request, + CloseSessionResponse* response) override; - Status ListDevices(CallOptions* call_options, - const ListDevicesRequest* request, - ListDevicesResponse* response) override; + absl::Status ListDevices(CallOptions* call_options, + const ListDevicesRequest* request, + ListDevicesResponse* response) override; // See tensorflow::Reset() and the comment on ResetRequest. - Status Reset(CallOptions* call_options, const ResetRequest* request, - ResetResponse* response) override; - - Status MakeCallable(CallOptions* call_options, - const MakeCallableRequest* request, - MakeCallableResponse* response) override; - Status RunCallable(CallOptions* call_options, - const RunCallableRequest* request, - RunCallableResponse* response) override; - Status ReleaseCallable(CallOptions* call_options, - const ReleaseCallableRequest* request, - ReleaseCallableResponse* response) override; + absl::Status Reset(CallOptions* call_options, const ResetRequest* request, + ResetResponse* response) override; + + absl::Status MakeCallable(CallOptions* call_options, + const MakeCallableRequest* request, + MakeCallableResponse* response) override; + absl::Status RunCallable(CallOptions* call_options, + const RunCallableRequest* request, + RunCallableResponse* response) override; + absl::Status ReleaseCallable(CallOptions* call_options, + const ReleaseCallableRequest* request, + ReleaseCallableResponse* response) override; // Registers the mapping from the given `target` to the given `master`. // diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc index f67f2b6052bcc1..0bcff072585273 100644 --- a/tensorflow/core/distributed_runtime/master.cc +++ b/tensorflow/core/distributed_runtime/master.cc @@ -35,6 +35,7 @@ limitations under the License. #include #include +#include "xla/tsl/protobuf/rpc_options.pb.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/distributed_runtime/remote_device.h" @@ -56,7 +57,6 @@ limitations under the License. #include "tensorflow/core/protobuf/worker.pb.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/util/device_name_utils.h" -#include "tsl/protobuf/rpc_options.pb.h" namespace tensorflow { @@ -137,7 +137,7 @@ MasterSession* Master::FindMasterSession(const string& handle) { class DeviceFinder { public: - static Status GetRemoteDevices( + static absl::Status GetRemoteDevices( const protobuf::RepeatedPtrField& device_filters, MasterEnv* env, WorkerCacheInterface* worker_cache, std::vector>* out_remote) { @@ -253,7 +253,7 @@ class DeviceFinder { // never be called. NewRemoteDevices( env_->env, worker_cache_, targets_[i], - [this, i](const Status& s, std::vector* devices) { + [this, i](const absl::Status& s, std::vector* devices) { WhenFound(i, s, devices); }); } @@ -264,7 +264,7 @@ class DeviceFinder { // responded. const int32 kLoggingPeriodMs = 10 * 1000; - Status Wait() { + absl::Status Wait() { mutex_lock l(mu_); // TODO(mrry): Propagate a timeout here, since `num_pending_` may // never become zero. @@ -314,9 +314,9 @@ class DeviceFinder { // heard from this target or not. std::vector targets_; std::vector seen_targets_ TF_GUARDED_BY(mu_); - Status status_; + absl::Status status_; - void WhenFound(int target_index, const Status& s, + void WhenFound(int target_index, const absl::Status& s, std::vector* devices) { mutex_lock l(mu_); seen_targets_[target_index] = true; @@ -364,7 +364,7 @@ class DeviceFinder { void Master::CreateSession(const CreateSessionRequest* req, CreateSessionResponse* resp, MyClosure done) { SchedClosure([this, req, resp, done]() { - Status status; + absl::Status status; WorkerCacheFactoryOptions worker_cache_factory_options; auto call_done = gtl::MakeCleanup([&status, &done] { done(status); }); status = ValidateExternalGraphDefSyntax(req->graph_def()); @@ -506,7 +506,7 @@ void Master::ExtendSession(const ExtendSessionRequest* req, } SchedClosure([session, req, resp, done]() { - Status status = ValidateExternalGraphDefSyntax(req->graph_def()); + absl::Status status = ValidateExternalGraphDefSyntax(req->graph_def()); if (status.ok()) { status = session->Extend(req, resp); } @@ -517,8 +517,8 @@ void Master::ExtendSession(const ExtendSessionRequest* req, void Master::PartialRunSetup(const PartialRunSetupRequest* req, PartialRunSetupResponse* resp, MyClosure done) { - Status s = recent_request_ids_.TrackUnique(req->request_id(), - "PartialRunSetup (Master)", *req); + absl::Status s = recent_request_ids_.TrackUnique( + req->request_id(), "PartialRunSetup (Master)", *req); if (!s.ok()) { done(s); return; @@ -530,7 +530,7 @@ void Master::PartialRunSetup(const PartialRunSetupRequest* req, } SchedClosure([session, req, resp, done]() { - Status s = session->PartialRunSetup(req, resp); + absl::Status s = session->PartialRunSetup(req, resp); session->Unref(); done(s); }); @@ -538,8 +538,8 @@ void Master::PartialRunSetup(const PartialRunSetupRequest* req, void Master::RunStep(CallOptions* opts, const RunStepRequestWrapper* req, MutableRunStepResponseWrapper* resp, MyClosure done) { - Status s = recent_request_ids_.TrackUnique(req->request_id(), - "RunStep (Master)", req); + absl::Status s = recent_request_ids_.TrackUnique(req->request_id(), + "RunStep (Master)", req); if (!s.ok()) { done(s); return; @@ -552,7 +552,7 @@ void Master::RunStep(CallOptions* opts, const RunStepRequestWrapper* req, } SchedClosure([this, start_time, session, opts, req, resp, done]() { - Status status = session->Run(opts, *req, resp); + absl::Status status = session->Run(opts, *req, resp); session->Unref(); uint64 done_time = env_->env->NowMicros(); done(status); @@ -585,7 +585,7 @@ void Master::CloseSession(const CloseSessionRequest* req, // Session Close() blocks on thread shutdown. Therefore, we need to // delete it in non-critical thread. SchedClosure([session, done]() { - Status s = session->Close(); + absl::Status s = session->Close(); session->Unref(); done(s); }); @@ -603,13 +603,13 @@ void Master::ListDevices(const ListDevicesRequest* req, return; } core::ScopedUnref ref(session); - Status s = session->ListDevices(resp); + absl::Status s = session->ListDevices(resp); done(s); return; } std::vector> remote_devices; - Status s = DeviceFinder::GetRemoteDevices({}, env_, env_->worker_cache, - &remote_devices); + absl::Status s = DeviceFinder::GetRemoteDevices( + {}, env_, env_->worker_cache, &remote_devices); if (s.ok()) { for (Device* dev : env_->local_devices) { *(resp->add_local_device()) = dev->attributes(); @@ -638,7 +638,7 @@ void Master::CleanupWorkers(const ResetRequest& reset) { auto worker = env_->worker_cache->GetOrCreateWorker(worker_name); if (worker) { worker->CleanupAllAsync( - &req, &resp[i], [this, &n, worker_name, worker, c](Status s) { + &req, &resp[i], [this, &n, worker_name, worker, c](absl::Status s) { if (!s.ok()) { LOG(ERROR) << "Worker CleanupAll failed: " << s; } @@ -674,7 +674,7 @@ void Master::Reset(const ResetRequest* req, ResetResponse* resp, CleanupWorkers(*req); SchedClosure([sessions_to_close, done]() { - Status s; + absl::Status s; for (MasterSession* session : sessions_to_close) { s.Update(session->Close()); session->Unref(); @@ -685,8 +685,8 @@ void Master::Reset(const ResetRequest* req, ResetResponse* resp, void Master::MakeCallable(const MakeCallableRequest* req, MakeCallableResponse* resp, MyClosure done) { - Status s = recent_request_ids_.TrackUnique(req->request_id(), - "MakeCallable (Master)", *req); + absl::Status s = recent_request_ids_.TrackUnique( + req->request_id(), "MakeCallable (Master)", *req); if (!s.ok()) { done(s); return; @@ -698,7 +698,7 @@ void Master::MakeCallable(const MakeCallableRequest* req, } SchedClosure([session, req, resp, done = std::move(done)]() { - Status s = session->MakeCallable(*req, resp); + absl::Status s = session->MakeCallable(*req, resp); session->Unref(); done(s); }); @@ -706,8 +706,8 @@ void Master::MakeCallable(const MakeCallableRequest* req, void Master::RunCallable(CallOptions* opts, const RunCallableRequest* req, RunCallableResponse* resp, MyClosure done) { - Status s = recent_request_ids_.TrackUnique(req->request_id(), - "RunCallable (Master)", *req); + absl::Status s = recent_request_ids_.TrackUnique( + req->request_id(), "RunCallable (Master)", *req); if (!s.ok()) { done(s); return; @@ -719,7 +719,7 @@ void Master::RunCallable(CallOptions* opts, const RunCallableRequest* req, } SchedClosure([session, opts, req, resp, done = std::move(done)]() { - Status s = session->RunCallable(opts, *req, resp); + absl::Status s = session->RunCallable(opts, *req, resp); session->Unref(); done(s); }); @@ -734,7 +734,7 @@ void Master::ReleaseCallable(const ReleaseCallableRequest* req, } SchedClosure([session, req, resp, done = std::move(done)]() { - Status s = session->ReleaseCallable(*req, resp); + absl::Status s = session->ReleaseCallable(*req, resp); session->Unref(); done(s); }); diff --git a/tensorflow/core/distributed_runtime/master.h b/tensorflow/core/distributed_runtime/master.h index 8cdbe6fccedd0a..a3930249b629ee 100644 --- a/tensorflow/core/distributed_runtime/master.h +++ b/tensorflow/core/distributed_runtime/master.h @@ -39,7 +39,7 @@ class Master { virtual ~Master(); // Convenient typedef for a closure passing a Status. - typedef std::function MyClosure; + typedef std::function MyClosure; void CreateSession(const CreateSessionRequest* req, CreateSessionResponse* resp, MyClosure done); diff --git a/tensorflow/core/distributed_runtime/master_env.h b/tensorflow/core/distributed_runtime/master_env.h index 633e21df361386..b8dcf1963df50d 100644 --- a/tensorflow/core/distributed_runtime/master_env.h +++ b/tensorflow/core/distributed_runtime/master_env.h @@ -19,12 +19,12 @@ limitations under the License. #include #include +#include "xla/tsl/protobuf/rpc_options.pb.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" #include "tensorflow/core/public/session_options.h" -#include "tsl/protobuf/rpc_options.pb.h" namespace tsl { class Env; @@ -99,8 +99,8 @@ struct MasterEnv { std::vector filtered_worker_list)> master_session_factory; - std::function + std::function worker_cache_factory; // Generates per-step CollectiveExecutors and has access to utilities diff --git a/tensorflow/core/distributed_runtime/master_interface.h b/tensorflow/core/distributed_runtime/master_interface.h index cde47fb9caf55f..df9894f7cbae78 100644 --- a/tensorflow/core/distributed_runtime/master_interface.h +++ b/tensorflow/core/distributed_runtime/master_interface.h @@ -33,27 +33,27 @@ namespace tensorflow { class MasterInterface { public: virtual ~MasterInterface() {} - virtual Status CreateSession(CallOptions* call_options, - const CreateSessionRequest* request, - CreateSessionResponse* response) = 0; + virtual absl::Status CreateSession(CallOptions* call_options, + const CreateSessionRequest* request, + CreateSessionResponse* response) = 0; - virtual Status ExtendSession(CallOptions* call_options, - const ExtendSessionRequest* request, - ExtendSessionResponse* response) = 0; + virtual absl::Status ExtendSession(CallOptions* call_options, + const ExtendSessionRequest* request, + ExtendSessionResponse* response) = 0; - virtual Status PartialRunSetup(CallOptions* call_options, - const PartialRunSetupRequest* request, - PartialRunSetupResponse* response) { + virtual absl::Status PartialRunSetup(CallOptions* call_options, + const PartialRunSetupRequest* request, + PartialRunSetupResponse* response) { return errors::Unimplemented("Partial run not implemented for this master"); } - virtual Status RunStep(CallOptions* call_options, - RunStepRequestWrapper* request, - MutableRunStepResponseWrapper* response) = 0; + virtual absl::Status RunStep(CallOptions* call_options, + RunStepRequestWrapper* request, + MutableRunStepResponseWrapper* response) = 0; - virtual Status RunStep(CallOptions* call_options, - const RunStepRequest* request, - RunStepResponse* response) { + virtual absl::Status RunStep(CallOptions* call_options, + const RunStepRequest* request, + RunStepResponse* response) { std::unique_ptr wrapped_request( new ProtoRunStepRequest(request)); std::unique_ptr wrapped_response( @@ -81,26 +81,27 @@ class MasterInterface { return new OwnedProtoRunStepResponse; } - virtual Status CloseSession(CallOptions* call_options, - const CloseSessionRequest* request, - CloseSessionResponse* response) = 0; - - virtual Status ListDevices(CallOptions* call_options, - const ListDevicesRequest* request, - ListDevicesResponse* response) = 0; - - virtual Status Reset(CallOptions* call_options, const ResetRequest* request, - ResetResponse* response) = 0; - - virtual Status MakeCallable(CallOptions* call_options, - const MakeCallableRequest* request, - MakeCallableResponse* response) = 0; - virtual Status RunCallable(CallOptions* call_options, - const RunCallableRequest* request, - RunCallableResponse* response) = 0; - virtual Status ReleaseCallable(CallOptions* call_options, - const ReleaseCallableRequest* request, - ReleaseCallableResponse* response) = 0; + virtual absl::Status CloseSession(CallOptions* call_options, + const CloseSessionRequest* request, + CloseSessionResponse* response) = 0; + + virtual absl::Status ListDevices(CallOptions* call_options, + const ListDevicesRequest* request, + ListDevicesResponse* response) = 0; + + virtual absl::Status Reset(CallOptions* call_options, + const ResetRequest* request, + ResetResponse* response) = 0; + + virtual absl::Status MakeCallable(CallOptions* call_options, + const MakeCallableRequest* request, + MakeCallableResponse* response) = 0; + virtual absl::Status RunCallable(CallOptions* call_options, + const RunCallableRequest* request, + RunCallableResponse* response) = 0; + virtual absl::Status ReleaseCallable(CallOptions* call_options, + const ReleaseCallableRequest* request, + ReleaseCallableResponse* response) = 0; protected: // NOTE: This should only be called by implementations of this diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index cb08a53815fd73..761188d14f9b40 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -25,6 +25,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" #include "tensorflow/core/common_runtime/process_util.h" #include "tensorflow/core/common_runtime/profile_handler.h" #include "tensorflow/core/common_runtime/stats_publisher_interface.h" @@ -64,7 +65,6 @@ limitations under the License. #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/util/device_name_utils.h" #include "tsl/platform/tracing.h" -#include "tsl/protobuf/coordination_config.pb.h" namespace tensorflow { @@ -148,15 +148,17 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { } LoggingResponse* resp = new LoggingResponse; Ref(); - p.worker->LoggingAsync(req, resp, [this, req, resp](const Status& s) { - delete req; - delete resp; - // ReffedClientGraph owns p.worker so we need to hold a ref to - // ensure that the method doesn't attempt to access p.worker after - // ReffedClient graph has deleted it. - // TODO(suharshs): Simplify this ownership model. - Unref(); - }); + p.worker->LoggingAsync(req, resp, + [this, req, resp](const absl::Status& s) { + delete req; + delete resp; + // ReffedClientGraph owns p.worker so we need to + // hold a ref to ensure that the method doesn't + // attempt to access p.worker after ReffedClient + // graph has deleted it. + // TODO(suharshs): Simplify this ownership model. + Unref(); + }); } } @@ -178,7 +180,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { LoggingResponse* resp = new LoggingResponse; p.worker->LoggingAsync( &req, resp, - [step_id, ss, resp, &scoped_mu, &all_done](const Status& s) { + [step_id, ss, resp, &scoped_mu, &all_done](const absl::Status& s) { { mutex_lock l(scoped_mu); if (s.ok()) { @@ -205,18 +207,22 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { // Partitions the graph into subgraphs and registers them on // workers. - Status RegisterPartitions(PartitionOptions popts); + absl::Status RegisterPartitions(PartitionOptions popts); // Runs one step of all partitions. - Status RunPartitions(const MasterEnv* env, int64_t step_id, - int64_t execution_count, PerStepState* pss, - CallOptions* opts, const RunStepRequestWrapper& req, - MutableRunStepResponseWrapper* resp, - CancellationManager* cm, const bool is_last_partial_run); - Status RunPartitions(const MasterEnv* env, int64_t step_id, - int64_t execution_count, PerStepState* pss, - CallOptions* call_opts, const RunCallableRequest& req, - RunCallableResponse* resp, CancellationManager* cm); + absl::Status RunPartitions(const MasterEnv* env, int64_t step_id, + int64_t execution_count, PerStepState* pss, + CallOptions* opts, + const RunStepRequestWrapper& req, + MutableRunStepResponseWrapper* resp, + CancellationManager* cm, + const bool is_last_partial_run); + absl::Status RunPartitions(const MasterEnv* env, int64_t step_id, + int64_t execution_count, PerStepState* pss, + CallOptions* call_opts, + const RunCallableRequest& req, + RunCallableResponse* resp, + CancellationManager* cm); // Calls workers to cleanup states for the step "step_id". Calls // `done` when all cleanup RPCs have completed. @@ -228,9 +234,9 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { void ProcessDeviceStats(ProfileHandler* ph, const DeviceStepStats& ds, bool is_rpc); // Checks that the requested fetches can be computed from the provided feeds. - Status CheckFetches(const RunStepRequestWrapper& req, - const RunState* run_state, - GraphExecutionState* execution_state); + absl::Status CheckFetches(const RunStepRequestWrapper& req, + const RunState* run_state, + GraphExecutionState* execution_state); private: const string session_handle_; @@ -290,7 +296,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { Notification init_done_; // init_result_ remembers the initialization error if any. - Status init_result_ TF_GUARDED_BY(mu_); + absl::Status init_result_ TF_GUARDED_BY(mu_); std::unique_ptr stats_publisher_; @@ -314,10 +320,10 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { const PartitionOptions& popts); // The actual graph partitioning and registration implementation. - Status DoBuildPartitions( + absl::Status DoBuildPartitions( PartitionOptions popts, ClientGraph* client_graph, std::unordered_map* out_partitions); - Status DoRegisterPartitions( + absl::Status DoRegisterPartitions( const PartitionOptions& popts, std::unordered_map graph_partitions); @@ -325,7 +331,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { // This is a generic method that handles Run, PartialRun, and RunCallable. template - Status RunPartitionsHelper( + absl::Status RunPartitionsHelper( const std::unordered_map& feeds, const FetchListType& fetches, const MasterEnv* env, int64_t step_id, int64_t execution_count, PerStepState* pss, CallOptions* call_opts, @@ -340,7 +346,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { void operator=(const ReffedClientGraph&) = delete; }; -Status MasterSession::ReffedClientGraph::RegisterPartitions( +absl::Status MasterSession::ReffedClientGraph::RegisterPartitions( PartitionOptions popts) { { // Ensure register once. mu_.lock(); @@ -354,7 +360,8 @@ Status MasterSession::ReffedClientGraph::RegisterPartitions( mu_.unlock(); std::unordered_map graph_defs; popts.flib_def = client_graph->flib_def.get(); - Status s = DoBuildPartitions(popts, client_graph.get(), &graph_defs); + absl::Status s = + DoBuildPartitions(popts, client_graph.get(), &graph_defs); if (s.ok()) { // NOTE(mrry): The pointers in `graph_defs_for_publishing` do not remain // valid after the call to DoRegisterPartitions begins, so @@ -376,7 +383,7 @@ Status MasterSession::ReffedClientGraph::RegisterPartitions( init_done_.WaitForNotification(); mu_.lock(); } - const Status result = init_result_; + const absl::Status result = init_result_; mu_.unlock(); return result; } @@ -429,7 +436,7 @@ void MasterSession::ReffedClientGraph::TrackFeedsAndFetches( } } -Status MasterSession::ReffedClientGraph::DoBuildPartitions( +absl::Status MasterSession::ReffedClientGraph::DoBuildPartitions( PartitionOptions popts, ClientGraph* client_graph, std::unordered_map* out_partitions) { if (popts.need_to_record_start_times) { @@ -445,11 +452,11 @@ Status MasterSession::ReffedClientGraph::DoBuildPartitions( return Partition(popts, &client_graph->graph, out_partitions); } -Status MasterSession::ReffedClientGraph::DoRegisterPartitions( +absl::Status MasterSession::ReffedClientGraph::DoRegisterPartitions( const PartitionOptions& popts, std::unordered_map graph_partitions) { partitions_.reserve(graph_partitions.size()); - Status s; + absl::Status s; for (auto& name_def : graph_partitions) { partitions_.emplace_back(); Part* part = &partitions_.back(); @@ -471,7 +478,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions( struct Call { RegisterGraphRequest req; RegisterGraphResponse resp; - Status status; + absl::Status status; }; const int num = partitions_.size(); absl::InlinedVector calls(num); @@ -490,7 +497,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions( callable_opts_.run_options().debug_options(); c->req.set_collective_graph_key(collective_graph_key_); VLOG(2) << "Register " << c->req.graph_def().DebugString(); - auto cb = [c, &done](const Status& s) { + auto cb = [c, &done](const absl::Status& s) { c->status = s; done.DecrementCount(); }; @@ -524,7 +531,7 @@ class RunManyGraphs { Call* get(int index) { return &calls_[index]; } // When the index-th call is done, updates the overall status. - void WhenDone(int index, const Status& s) { + void WhenDone(int index, const absl::Status& s) { TRACEPRINTF("Partition %d %v", index, s); Call* call = get(index); call->done = true; @@ -532,7 +539,7 @@ class RunManyGraphs { if (resp->status_code() != absl::StatusCode::kOk) { // resp->status_code will only be non-OK if s.ok(). mutex_lock l(mu_); - Status resp_status = call->resp->status(); + absl::Status resp_status = call->resp->status(); ReportBadStatus(errors::CreateWithUpdatedMessage( resp_status, strings::StrCat("From ", *call->worker_name, ":\n", resp_status.message()))); @@ -580,7 +587,7 @@ class RunManyGraphs { pending_.Wait(); } - Status status() const { + absl::Status status() const { mutex_lock l(mu_); // Concat status objects in this StatusGroup to get the aggregated status, // as each status in status_group_ is already summarized status. @@ -595,7 +602,7 @@ class RunManyGraphs { StatusGroup status_group_ TF_GUARDED_BY(mu_); bool cancel_issued_ TF_GUARDED_BY(mu_) = false; - void ReportBadStatus(const Status& s) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + void ReportBadStatus(const absl::Status& s) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { VLOG(1) << "Master received error status " << s; if (!cancel_issued_ && !StatusGroup::IsDerived(s)) { // Only start cancelling other workers upon receiving a non-derived @@ -615,15 +622,15 @@ class RunManyGraphs { void operator=(const RunManyGraphs&) = delete; }; -Status AddSendFromClientRequest(const RunStepRequestWrapper& client_req, - MutableRunGraphRequestWrapper* worker_req, - size_t index, const string& send_key) { +absl::Status AddSendFromClientRequest(const RunStepRequestWrapper& client_req, + MutableRunGraphRequestWrapper* worker_req, + size_t index, const string& send_key) { return worker_req->AddSendFromRunStepRequest(client_req, index, send_key); } -Status AddSendFromClientRequest(const RunCallableRequest& client_req, - MutableRunGraphRequestWrapper* worker_req, - size_t index, const string& send_key) { +absl::Status AddSendFromClientRequest(const RunCallableRequest& client_req, + MutableRunGraphRequestWrapper* worker_req, + size_t index, const string& send_key) { return worker_req->AddSendFromRunCallableRequest(client_req, index, send_key); } @@ -635,7 +642,7 @@ struct RunCallableResponseWrapper { RunMetadata* mutable_metadata() { return resp->mutable_metadata(); } - Status AddTensorFromRunGraphResponse( + absl::Status AddTensorFromRunGraphResponse( const string& tensor_name, MutableRunGraphResponseWrapper* worker_resp, size_t index) { return worker_resp->RecvValue(index, &fetch_key_to_protos[tensor_name]); @@ -645,7 +652,7 @@ struct RunCallableResponseWrapper { template -Status MasterSession::ReffedClientGraph::RunPartitionsHelper( +absl::Status MasterSession::ReffedClientGraph::RunPartitionsHelper( const std::unordered_map& feeds, const FetchListType& fetches, const MasterEnv* env, int64_t step_id, int64_t execution_count, PerStepState* pss, CallOptions* call_opts, @@ -770,7 +777,7 @@ Status MasterSession::ReffedClientGraph::RunPartitionsHelper( TF_RETURN_IF_ERROR(calls.status()); // Collects fetches and metadata. - Status status; + absl::Status status; for (int i = 0; i < num; ++i) { const Part& part = partitions_[i]; MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get(); @@ -810,7 +817,7 @@ Status MasterSession::ReffedClientGraph::RunPartitionsHelper( return status; } -Status MasterSession::ReffedClientGraph::RunPartitions( +absl::Status MasterSession::ReffedClientGraph::RunPartitions( const MasterEnv* env, int64_t step_id, int64_t execution_count, PerStepState* pss, CallOptions* call_opts, const RunStepRequestWrapper& req, MutableRunStepResponseWrapper* resp, CancellationManager* cm, @@ -835,7 +842,7 @@ Status MasterSession::ReffedClientGraph::RunPartitions( call_opts, req, resp, cm, is_last_partial_run); } -Status MasterSession::ReffedClientGraph::RunPartitions( +absl::Status MasterSession::ReffedClientGraph::RunPartitions( const MasterEnv* env, int64_t step_id, int64_t execution_count, PerStepState* pss, CallOptions* call_opts, const RunCallableRequest& req, RunCallableResponse* resp, CancellationManager* cm) { @@ -889,9 +896,9 @@ class CleanupBroadcastHelper { CleanupGraphResponse* response(int i) { return &resps_[i]; } // Called when the ith response is received. - void call_done(int i, const Status& s) { + void call_done(int i, const absl::Status& s) { bool run_callback = false; - Status status_copy; + absl::Status status_copy; { mutex_lock l(mu_); status_.Update(s); @@ -917,7 +924,7 @@ class CleanupBroadcastHelper { // Number of requests remaining to be collected. int num_pending_ TF_GUARDED_BY(mu_); // Aggregate status of the operation. - Status status_ TF_GUARDED_BY(mu_); + absl::Status status_ TF_GUARDED_BY(mu_); // Callback to be called when all operations complete. StatusCallback done_; @@ -937,7 +944,7 @@ void MasterSession::ReffedClientGraph::CleanupPartitionsAsync( const Part& part = partitions_[i]; part.worker->CleanupGraphAsync( helper->request(), helper->response(i), - [helper, i](const Status& s) { helper->call_done(i, s); }); + [helper, i](const absl::Status& s) { helper->call_done(i, s); }); } } @@ -1037,7 +1044,7 @@ void MasterSession::ReffedClientGraph::ProcessDeviceStats( // TODO(suharsh,mrry): Build a map from fetch target to set of feeds it depends // on once at setup time to prevent us from computing the dependencies // everytime. -Status MasterSession::ReffedClientGraph::CheckFetches( +absl::Status MasterSession::ReffedClientGraph::CheckFetches( const RunStepRequestWrapper& req, const RunState* run_state, GraphExecutionState* execution_state) { // Build the set of pending feeds that we haven't seen. @@ -1114,7 +1121,7 @@ void MasterSession::ReffedClientGraph::DeregisterPartitions() { const string name = part.name; WorkerInterface* w = part.worker; CHECK_NOTNULL(w); - auto cb = [worker_cache, c, name, w](const Status& s) { + auto cb = [worker_cache, c, name, w](const absl::Status& s) { if (!s.ok()) { // This error is potentially benign, so we don't log at the // error level. @@ -1264,8 +1271,8 @@ void MasterSession::UpdateLastAccessTime() { last_access_time_usec_.store(Env::Default()->NowMicros()); } -Status MasterSession::Create(GraphDef&& graph_def, - const ClusterDef& cluster_def) { +absl::Status MasterSession::Create(GraphDef&& graph_def, + const ClusterDef& cluster_def) { if (session_opts_.config.use_per_session_threads() || session_opts_.config.session_inter_op_thread_pool_size() > 0) { return errors::InvalidArgument( @@ -1290,7 +1297,8 @@ Status MasterSession::Create(GraphDef&& graph_def, return CreateWorkerSessions(cluster_def); } -Status MasterSession::CreateWorkerSessions(const ClusterDef& cluster_def) { +absl::Status MasterSession::CreateWorkerSessions( + const ClusterDef& cluster_def) { const std::vector worker_names = filtered_worker_list_; WorkerCacheInterface* worker_cache = get_worker_cache(); @@ -1304,7 +1312,7 @@ Status MasterSession::CreateWorkerSessions(const ClusterDef& cluster_def) { // Request and responses used for a given worker. CreateWorkerSessionRequest request; CreateWorkerSessionResponse response; - Status status = absl::OkStatus(); + absl::Status status = absl::OkStatus(); }; BlockingCounter done(worker_names.size()); std::vector workers(worker_names.size()); @@ -1325,7 +1333,7 @@ Status MasterSession::CreateWorkerSessions(const ClusterDef& cluster_def) { const int64_t client_device_incarnation = devices_->client_device()->attributes().incarnation(); - Status status = absl::OkStatus(); + absl::Status status = absl::OkStatus(); // Create all the workers & kick off the computations. for (size_t i = 0; i < worker_names.size(); ++i) { workers[i].name = &worker_names[i]; @@ -1408,7 +1416,7 @@ Status MasterSession::CreateWorkerSessions(const ClusterDef& cluster_def) { } for (size_t i = 0; i < worker_names.size(); ++i) { - auto cb = [i, &workers, &done](const Status& s) { + auto cb = [i, &workers, &done](const absl::Status& s) { workers[i].status = s; done.DecrementCount(); }; @@ -1423,7 +1431,7 @@ Status MasterSession::CreateWorkerSessions(const ClusterDef& cluster_def) { return status; } -Status MasterSession::DeleteWorkerSessions() { +absl::Status MasterSession::DeleteWorkerSessions() { WorkerCacheInterface* worker_cache = get_worker_cache(); const std::vector& worker_names = filtered_worker_list_; @@ -1439,7 +1447,7 @@ Status MasterSession::DeleteWorkerSessions() { // Request and responses used for a given worker. DeleteWorkerSessionRequest request; DeleteWorkerSessionResponse response; - Status status = absl::OkStatus(); + absl::Status status = absl::OkStatus(); }; BlockingCounter done(worker_names.size()); std::vector workers(worker_names.size()); @@ -1453,7 +1461,7 @@ Status MasterSession::DeleteWorkerSessions() { } }); - Status status = absl::OkStatus(); + absl::Status status = absl::OkStatus(); // Create all the workers & kick off the computations. for (size_t i = 0; i < worker_names.size(); ++i) { workers[i].name = &worker_names[i]; @@ -1465,7 +1473,7 @@ Status MasterSession::DeleteWorkerSessions() { } for (size_t i = 0; i < worker_names.size(); ++i) { - auto cb = [i, &workers, &done](const Status& s) { + auto cb = [i, &workers, &done](const absl::Status& s) { workers[i].status = s; done.DecrementCount(); }; @@ -1480,7 +1488,7 @@ Status MasterSession::DeleteWorkerSessions() { return status; } -Status MasterSession::ListDevices(ListDevicesResponse* resp) const { +absl::Status MasterSession::ListDevices(ListDevicesResponse* resp) const { if (worker_cache_) { // This is a ClusterSpec-propagated session, and thus env_->local_devices // are invalid. @@ -1504,8 +1512,8 @@ Status MasterSession::ListDevices(ListDevicesResponse* resp) const { return absl::OkStatus(); } -Status MasterSession::Extend(const ExtendSessionRequest* req, - ExtendSessionResponse* resp) { +absl::Status MasterSession::Extend(const ExtendSessionRequest* req, + ExtendSessionResponse* resp) { UpdateLastAccessTime(); std::unique_ptr extended_execution_state; { @@ -1540,9 +1548,10 @@ WorkerCacheInterface* MasterSession::get_worker_cache() const { return env_->worker_cache; } -Status MasterSession::StartStep(const BuildGraphOptions& opts, bool is_partial, - ReffedClientGraph** out_rcg, - int64_t* out_count) { +absl::Status MasterSession::StartStep(const BuildGraphOptions& opts, + bool is_partial, + ReffedClientGraph** out_rcg, + int64_t* out_count) { const uint64 hash = HashBuildGraphOptions(opts); { mutex_lock l(mu_); @@ -1597,9 +1606,9 @@ uint64 MasterSession::NewStepId(int64_t graph_key) { int32_t retry_count = 0; while (static_cast(step_id) == CollectiveExecutor::kInvalidId) { Notification note; - Status status; + absl::Status status; env_->collective_executor_mgr->RefreshStepIdSequenceAsync( - graph_key, [&status, ¬e](const Status& s) { + graph_key, [&status, ¬e](const absl::Status& s) { status = s; note.Notify(); }); @@ -1618,8 +1627,8 @@ uint64 MasterSession::NewStepId(int64_t graph_key) { } } -Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req, - PartialRunSetupResponse* resp) { +absl::Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req, + PartialRunSetupResponse* resp) { std::vector inputs, outputs, targets; for (const auto& feed : req->feed()) { inputs.push_back(feed); @@ -1657,8 +1666,9 @@ Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req, return absl::OkStatus(); } -Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req, - MutableRunStepResponseWrapper* resp) { +absl::Status MasterSession::Run(CallOptions* opts, + const RunStepRequestWrapper& req, + MutableRunStepResponseWrapper* resp) { UpdateLastAccessTime(); { mutex_lock l(mu_); @@ -1669,7 +1679,7 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req, // Note: all code paths must eventually call MarkRunCompletion() // in order to appropriate decrement the num_running_ counter. } - Status status; + absl::Status status; if (!req.partial_run_handle().empty()) { status = DoPartialRun(opts, req, resp); } else { @@ -1687,7 +1697,7 @@ void MasterSession::MarkRunCompletion() { } } -Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) { +absl::Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) { // Registers subgraphs if haven't done so. PartitionOptions popts; popts.node_to_loc = SplitByWorker; @@ -1730,9 +1740,9 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) { return absl::OkStatus(); } -Status MasterSession::DoPartialRun(CallOptions* opts, - const RunStepRequestWrapper& req, - MutableRunStepResponseWrapper* resp) { +absl::Status MasterSession::DoPartialRun(CallOptions* opts, + const RunStepRequestWrapper& req, + MutableRunStepResponseWrapper* resp) { auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); }); const string& prun_handle = req.partial_run_handle(); RunState* run_state = nullptr; @@ -1831,7 +1841,7 @@ Status MasterSession::DoPartialRun(CallOptions* opts, } bool is_last_partial_run = run_state->PendingDone(); - Status s = run_state->rcg->RunPartitions( + absl::Status s = run_state->rcg->RunPartitions( env_, run_state->step_id, run_state->count, &run_state->pss, opts, req, resp, &cancellation_manager_, is_last_partial_run); @@ -1846,7 +1856,7 @@ Status MasterSession::DoPartialRun(CallOptions* opts, req.options(), resp->mutable_metadata()); cleanup.release(); // MarkRunCompletion called in done closure. rcg->CleanupPartitionsAsync( - run_state->step_id, [this, rcg, prun_handle](const Status& s) { + run_state->step_id, [this, rcg, prun_handle](const absl::Status& s) { if (!s.ok()) { LOG(ERROR) << "Cleanup partition error: " << s; } @@ -1860,7 +1870,7 @@ Status MasterSession::DoPartialRun(CallOptions* opts, return s; } -Status MasterSession::CreateDebuggerState( +absl::Status MasterSession::CreateDebuggerState( const DebugOptions& debug_options, const RunStepRequestWrapper& req, int64_t rcg_execution_count, std::unique_ptr* debugger_state) { @@ -1922,14 +1932,12 @@ void MasterSession::FillPerStepState(MasterSession::ReffedClientGraph* rcg, } } -Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg, - uint64 step_id, - const RunOptions& run_options, - PerStepState* pss, - const std::unique_ptr& ph, - const Status& run_status, - RunMetadata* out_run_metadata) { - Status s = run_status; +absl::Status MasterSession::PostRunCleanup( + MasterSession::ReffedClientGraph* rcg, uint64 step_id, + const RunOptions& run_options, PerStepState* pss, + const std::unique_ptr& ph, const absl::Status& run_status, + RunMetadata* out_run_metadata) { + absl::Status s = run_status; if (s.ok()) { pss->end_micros = Env::Default()->NowMicros(); if (rcg->collective_graph_key() != @@ -1954,7 +1962,7 @@ Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg, } Ref(); rcg->Ref(); - rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) { + rcg->CleanupPartitionsAsync(step_id, [this, rcg](const absl::Status& s) { if (!s.ok()) { LOG(ERROR) << "Cleanup partition error: " << s; } @@ -1965,7 +1973,7 @@ Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg, return s; } -Status MasterSession::DoRunWithLocalExecution( +absl::Status MasterSession::DoRunWithLocalExecution( CallOptions* opts, const RunStepRequestWrapper& req, MutableRunStepResponseWrapper* resp) { VLOG(2) << "DoRunWithLocalExecution req: " << req.DebugString(); @@ -2007,16 +2015,16 @@ Status MasterSession::DoRunWithLocalExecution( "disable_output_partition_graphs is true."); } - Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp, - &cancellation_manager_, false); + absl::Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, + resp, &cancellation_manager_, false); cleanup.release(); // MarkRunCompletion called in PostRunCleanup(). return PostRunCleanup(rcg, step_id, req.options(), &pss, ph, s, resp->mutable_metadata()); } -Status MasterSession::MakeCallable(const MakeCallableRequest& req, - MakeCallableResponse* resp) { +absl::Status MasterSession::MakeCallable(const MakeCallableRequest& req, + MakeCallableResponse* resp) { UpdateLastAccessTime(); BuildGraphOptions opts; @@ -2038,7 +2046,7 @@ Status MasterSession::MakeCallable(const MakeCallableRequest& req, !should_delete_worker_sessions_); } - Status s = BuildAndRegisterPartitions(callable); + absl::Status s = BuildAndRegisterPartitions(callable); if (!s.ok()) { callable->Unref(); return s; @@ -2055,9 +2063,10 @@ Status MasterSession::MakeCallable(const MakeCallableRequest& req, return absl::OkStatus(); } -Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg, - const RunCallableRequest& req, - RunCallableResponse* resp) { +absl::Status MasterSession::DoRunCallable(CallOptions* opts, + ReffedClientGraph* rcg, + const RunCallableRequest& req, + RunCallableResponse* resp) { VLOG(2) << "DoRunCallable req: " << req.DebugString(); PerStepState pss; pss.start_micros = Env::Default()->NowMicros(); @@ -2077,16 +2086,16 @@ Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg, std::unique_ptr ph; FillPerStepState(rcg, run_options, step_id, count, &pss, &ph); - Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp, - &cancellation_manager_); + absl::Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, + resp, &cancellation_manager_); cleanup.release(); // MarkRunCompletion called in PostRunCleanup(). return PostRunCleanup(rcg, step_id, run_options, &pss, ph, s, resp->mutable_metadata()); } -Status MasterSession::RunCallable(CallOptions* opts, - const RunCallableRequest& req, - RunCallableResponse* resp) { +absl::Status MasterSession::RunCallable(CallOptions* opts, + const RunCallableRequest& req, + RunCallableResponse* resp) { UpdateLastAccessTime(); ReffedClientGraph* callable; { @@ -2111,8 +2120,8 @@ Status MasterSession::RunCallable(CallOptions* opts, return DoRunCallable(opts, callable, req, resp); } -Status MasterSession::ReleaseCallable(const ReleaseCallableRequest& req, - ReleaseCallableResponse* resp) { +absl::Status MasterSession::ReleaseCallable(const ReleaseCallableRequest& req, + ReleaseCallableResponse* resp) { UpdateLastAccessTime(); ReffedClientGraph* to_unref = nullptr; { @@ -2129,7 +2138,7 @@ Status MasterSession::ReleaseCallable(const ReleaseCallableRequest& req, return absl::OkStatus(); } -Status MasterSession::Close() { +absl::Status MasterSession::Close() { { mutex_lock l(mu_); closed_ = true; // All subsequent calls to Run() or Extend() will fail. @@ -2147,7 +2156,7 @@ Status MasterSession::Close() { } for (ReffedClientGraph* rcg : to_unref) rcg->Unref(); if (should_delete_worker_sessions_) { - Status s = DeleteWorkerSessions(); + absl::Status s = DeleteWorkerSessions(); if (!s.ok()) { LOG(WARNING) << s; } diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h index 0a568db96cd403..f7016518bca5a9 100644 --- a/tensorflow/core/distributed_runtime/master_session.h +++ b/tensorflow/core/distributed_runtime/master_session.h @@ -57,7 +57,7 @@ class MasterSession : public core::RefCounted { // Initialize the MasterSession for "def". Must be called before Extend(), // Run(), or Close(). - Status Create(GraphDef&& def, const ClusterDef& cluster_def); + absl::Status Create(GraphDef&& def, const ClusterDef& cluster_def); // Returns the session handle. const string& handle() const { return handle_; } @@ -76,32 +76,33 @@ class MasterSession : public core::RefCounted { // is "resp->new_graph_version". // // Extend() may block the caller thread for a long time. - Status Extend(const ExtendSessionRequest* req, ExtendSessionResponse* resp); + absl::Status Extend(const ExtendSessionRequest* req, + ExtendSessionResponse* resp); // Setup a partial run call. - Status PartialRunSetup(const PartialRunSetupRequest* req, - PartialRunSetupResponse* resp); + absl::Status PartialRunSetup(const PartialRunSetupRequest* req, + PartialRunSetupResponse* resp); // Run one step. - Status Run(CallOptions* opts, const RunStepRequestWrapper& req, - MutableRunStepResponseWrapper* resp); + absl::Status Run(CallOptions* opts, const RunStepRequestWrapper& req, + MutableRunStepResponseWrapper* resp); - Status ListDevices(ListDevicesResponse* resp) const; + absl::Status ListDevices(ListDevicesResponse* resp) const; - Status MakeCallable(const MakeCallableRequest& req, - MakeCallableResponse* resp); + absl::Status MakeCallable(const MakeCallableRequest& req, + MakeCallableResponse* resp); - Status RunCallable(CallOptions* opts, const RunCallableRequest& req, - RunCallableResponse* resp); + absl::Status RunCallable(CallOptions* opts, const RunCallableRequest& req, + RunCallableResponse* resp); - Status ReleaseCallable(const ReleaseCallableRequest& req, - ReleaseCallableResponse* resp); + absl::Status ReleaseCallable(const ReleaseCallableRequest& req, + ReleaseCallableResponse* resp); // Close this session and delete "*this". Returns OK if all known // states are cleanup successfully. // // Close() may block the caller thread for a long time. - Status Close(); + absl::Status Close(); // Close this session and release a reference on "*this". // @@ -217,39 +218,40 @@ class MasterSession : public core::RefCounted { // If this session is operating using the new ClusterSpec propagation behavior // call this method in order to propagate the cluster membership to all // workers. - Status CreateWorkerSessions(const ClusterDef& cluster_def); + absl::Status CreateWorkerSessions(const ClusterDef& cluster_def); bool should_delete_worker_sessions_ = false; - Status DeleteWorkerSessions(); + absl::Status DeleteWorkerSessions(); - Status StartStep(const BuildGraphOptions& opts, bool is_partial, - ReffedClientGraph** out_rcg, int64_t* out_count); + absl::Status StartStep(const BuildGraphOptions& opts, bool is_partial, + ReffedClientGraph** out_rcg, int64_t* out_count); void ClearRunsTable(std::vector* to_unref, RCGMap* rcg_map) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); void FillPerStepState(MasterSession::ReffedClientGraph* rcg, const RunOptions& run_options, uint64 step_id, int64_t count, PerStepState* out_pss, std::unique_ptr* out_ph); - Status DoRunWithLocalExecution(CallOptions* opts, - const RunStepRequestWrapper& req, - MutableRunStepResponseWrapper* resp); - Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req, - MutableRunStepResponseWrapper* resp); - Status DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg, - const RunCallableRequest& req, - RunCallableResponse* resp); - Status PostRunCleanup(MasterSession::ReffedClientGraph* rcg, uint64 step_id, - const RunOptions& run_options, PerStepState* pss, - const std::unique_ptr& ph, - const Status& run_status, - RunMetadata* out_run_metadata); + absl::Status DoRunWithLocalExecution(CallOptions* opts, + const RunStepRequestWrapper& req, + MutableRunStepResponseWrapper* resp); + absl::Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req, + MutableRunStepResponseWrapper* resp); + absl::Status DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg, + const RunCallableRequest& req, + RunCallableResponse* resp); + absl::Status PostRunCleanup(MasterSession::ReffedClientGraph* rcg, + uint64 step_id, const RunOptions& run_options, + PerStepState* pss, + const std::unique_ptr& ph, + const absl::Status& run_status, + RunMetadata* out_run_metadata); void MarkRunCompletion(); void UpdateLastAccessTime(); - Status BuildAndRegisterPartitions(ReffedClientGraph* rcg); + absl::Status BuildAndRegisterPartitions(ReffedClientGraph* rcg); - Status CreateDebuggerState( + absl::Status CreateDebuggerState( const DebugOptions& debug_options, const RunStepRequestWrapper& req, int64_t rcg_execution_count, std::unique_ptr* debugger_state); diff --git a/tensorflow/core/distributed_runtime/master_test.cc b/tensorflow/core/distributed_runtime/master_test.cc index 1e9e5545183191..cd46939b8357f4 100644 --- a/tensorflow/core/distributed_runtime/master_test.cc +++ b/tensorflow/core/distributed_runtime/master_test.cc @@ -64,15 +64,16 @@ class MasterTest : public ::testing::Test { // Helpers for MasterService.{CreateSession,RunStep,CloseSession} // rpc calls. - Status CreateSession(const GraphDef& def, string* handle, - int64_t* initial_version) { + absl::Status CreateSession(const GraphDef& def, string* handle, + int64_t* initial_version) { ::grpc::ClientContext ctx; CreateSessionRequest req; *(req.mutable_graph_def()) = def; // Invokes placement frequently. req.mutable_config()->set_placement_period(1); CreateSessionResponse resp; - const Status s = FromGrpcStatus(master_->CreateSession(&ctx, req, &resp)); + const absl::Status s = + FromGrpcStatus(master_->CreateSession(&ctx, req, &resp)); if (s.ok()) { *handle = resp.session_handle(); *initial_version = resp.graph_version(); @@ -80,24 +81,26 @@ class MasterTest : public ::testing::Test { return s; } - Status ExtendSession(const string& handle, const GraphDef& def, - int64_t current_version, int64_t* new_version) { + absl::Status ExtendSession(const string& handle, const GraphDef& def, + int64_t current_version, int64_t* new_version) { ::grpc::ClientContext ctx; ExtendSessionRequest req; req.set_session_handle(handle); *(req.mutable_graph_def()) = def; req.set_current_graph_version(current_version); ExtendSessionResponse resp; - const Status s = FromGrpcStatus(master_->ExtendSession(&ctx, req, &resp)); + const absl::Status s = + FromGrpcStatus(master_->ExtendSession(&ctx, req, &resp)); if (s.ok()) { *new_version = resp.new_graph_version(); } return s; } - Status RunStep(const string& handle, - const std::vector >& feed, - const std::map& fetch) { + absl::Status RunStep( + const string& handle, + const std::vector >& feed, + const std::map& fetch) { ::grpc::ClientContext ctx; RunStepRequest req; req.set_session_handle(handle); @@ -113,7 +116,7 @@ class MasterTest : public ::testing::Test { req.add_fetch(fetch_name); } RunStepResponse resp; - const Status s = FromGrpcStatus(master_->RunStep(&ctx, req, &resp)); + const absl::Status s = FromGrpcStatus(master_->RunStep(&ctx, req, &resp)); if (s.ok()) { for (const auto& fetch_resp : resp.tensor()) { auto it = fetch.find(fetch_resp.name()); @@ -124,7 +127,7 @@ class MasterTest : public ::testing::Test { return s; } - Status CloseSession(const string& handle) { + absl::Status CloseSession(const string& handle) { ::grpc::ClientContext ctx; CloseSessionRequest req; req.set_session_handle(handle); @@ -132,7 +135,7 @@ class MasterTest : public ::testing::Test { return FromGrpcStatus(master_->CloseSession(&ctx, req, &resp)); } - Status Reset() { + absl::Status Reset() { ::grpc::ClientContext ctx; ResetRequest req; ResetResponse resp; @@ -153,7 +156,7 @@ TEST_F(MasterTest, ListDevices) { ::grpc::ClientContext ctx; ListDevicesRequest req; ListDevicesResponse resp; - const Status s = FromGrpcStatus(master_->ListDevices(&ctx, req, &resp)); + const absl::Status s = FromGrpcStatus(master_->ListDevices(&ctx, req, &resp)); TF_EXPECT_OK(s); EXPECT_EQ(1, resp.local_device_size()); EXPECT_EQ("CPU", resp.local_device(0).device_type()); @@ -268,7 +271,8 @@ TEST_F(MasterTest, ConcurrentExtendOnlyOneSucceeds) { &failed]() { n.WaitForNotification(); int64_t new_version; - Status s = ExtendSession(handle, def_1, initial_version, &new_version); + absl::Status s = + ExtendSession(handle, def_1, initial_version, &new_version); EXPECT_TRUE(s.ok() || errors::IsAborted(s)); { mutex_lock l(mu); @@ -337,7 +341,7 @@ TEST_F(MasterTest, ConcurrentExtendAndRun) { // Concurrent with the Extend, we will either fail (as above), or // succeed (as below). while (!extend_done.HasBeenNotified()) { - Status s = RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}}); + absl::Status s = RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}}); EXPECT_TRUE(errors::IsNotFound(s) || s.ok()); } diff --git a/tensorflow/core/distributed_runtime/message_wrappers.cc b/tensorflow/core/distributed_runtime/message_wrappers.cc index 1212d187fbd6ff..60a264565dbb61 100644 --- a/tensorflow/core/distributed_runtime/message_wrappers.cc +++ b/tensorflow/core/distributed_runtime/message_wrappers.cc @@ -59,13 +59,14 @@ const string& InMemoryRunStepRequest::feed_name(size_t i) const { return feeds_[i].first; } -Status InMemoryRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const { +absl::Status InMemoryRunStepRequest::FeedValue(size_t i, + Tensor* out_tensor) const { *out_tensor = feeds_[i].second; return absl::OkStatus(); } -Status InMemoryRunStepRequest::FeedValue(size_t i, - TensorProto* out_tensor) const { +absl::Status InMemoryRunStepRequest::FeedValue(size_t i, + TensorProto* out_tensor) const { feeds_[i].second.AsProtoTensorContent(out_tensor); return absl::OkStatus(); } @@ -152,8 +153,8 @@ size_t MutableProtoRunStepRequest::num_feeds() const { const string& MutableProtoRunStepRequest::feed_name(size_t i) const { return request_.feed(i).name(); } -Status MutableProtoRunStepRequest::FeedValue(size_t i, - Tensor* out_tensor) const { +absl::Status MutableProtoRunStepRequest::FeedValue(size_t i, + Tensor* out_tensor) const { if (!ParseTensorProtoToTensor(request_.feed(i).tensor(), out_tensor)) { return errors::InvalidArgument("Invalid TensorProto for feed value ", i); } else { @@ -161,8 +162,8 @@ Status MutableProtoRunStepRequest::FeedValue(size_t i, } } -Status MutableProtoRunStepRequest::FeedValue(size_t i, - TensorProto* out_tensor) const { +absl::Status MutableProtoRunStepRequest::FeedValue( + size_t i, TensorProto* out_tensor) const { *out_tensor = request_.feed(i).tensor(); return absl::OkStatus(); } @@ -244,7 +245,8 @@ const string& ProtoRunStepRequest::feed_name(size_t i) const { return request_->feed(i).name(); } -Status ProtoRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const { +absl::Status ProtoRunStepRequest::FeedValue(size_t i, + Tensor* out_tensor) const { if (!ParseTensorProtoToTensor(request_->feed(i).tensor(), out_tensor)) { return errors::InvalidArgument("Invalid TensorProto for feed value ", i); } else { @@ -252,7 +254,8 @@ Status ProtoRunStepRequest::FeedValue(size_t i, Tensor* out_tensor) const { } } -Status ProtoRunStepRequest::FeedValue(size_t i, TensorProto* out_tensor) const { +absl::Status ProtoRunStepRequest::FeedValue(size_t i, + TensorProto* out_tensor) const { *out_tensor = request_->feed(i).tensor(); return absl::OkStatus(); } @@ -335,12 +338,13 @@ const string& InMemoryRunGraphRequest::send_key(size_t i) const { return sends_[i].first; } -Status InMemoryRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const { +absl::Status InMemoryRunGraphRequest::SendValue(size_t i, + Tensor* out_tensor) const { *out_tensor = sends_[i].second; return absl::OkStatus(); } -Status InMemoryRunGraphRequest::AddSendFromRunStepRequest( +absl::Status InMemoryRunGraphRequest::AddSendFromRunStepRequest( const RunStepRequestWrapper& run_step_request, size_t i, const string& send_key) { Tensor tensor; @@ -349,7 +353,7 @@ Status InMemoryRunGraphRequest::AddSendFromRunStepRequest( return absl::OkStatus(); } -Status InMemoryRunGraphRequest::AddSendFromRunCallableRequest( +absl::Status InMemoryRunGraphRequest::AddSendFromRunCallableRequest( const RunCallableRequest& run_callable_request, size_t i, const string& send_key) { Tensor tensor; @@ -475,8 +479,8 @@ const string& MutableProtoRunGraphRequest::send_key(size_t i) const { return request_.send(i).name(); } -Status MutableProtoRunGraphRequest::SendValue(size_t i, - Tensor* out_tensor) const { +absl::Status MutableProtoRunGraphRequest::SendValue(size_t i, + Tensor* out_tensor) const { if (!ParseTensorProtoToTensor(request_.send(i).tensor(), out_tensor)) { return errors::InvalidArgument("Invalid TensorProto for feed value ", i); } else { @@ -484,7 +488,7 @@ Status MutableProtoRunGraphRequest::SendValue(size_t i, } } -Status MutableProtoRunGraphRequest::AddSendFromRunStepRequest( +absl::Status MutableProtoRunGraphRequest::AddSendFromRunStepRequest( const RunStepRequestWrapper& run_step_request, size_t i, const string& send_key) { NamedTensorProto* send = request_.add_send(); @@ -493,7 +497,7 @@ Status MutableProtoRunGraphRequest::AddSendFromRunStepRequest( return absl::OkStatus(); } -Status MutableProtoRunGraphRequest::AddSendFromRunCallableRequest( +absl::Status MutableProtoRunGraphRequest::AddSendFromRunCallableRequest( const RunCallableRequest& run_callable_request, size_t i, const string& send_key) { NamedTensorProto* send = request_.add_send(); @@ -579,7 +583,8 @@ const string& ProtoRunGraphRequest::send_key(size_t i) const { return request_->send(i).name(); } -Status ProtoRunGraphRequest::SendValue(size_t i, Tensor* out_tensor) const { +absl::Status ProtoRunGraphRequest::SendValue(size_t i, + Tensor* out_tensor) const { if (!ParseTensorProtoToTensor(request_->send(i).tensor(), out_tensor)) { return errors::InvalidArgument("Invalid TensorProto for feed value ", i); } else { @@ -619,12 +624,13 @@ const string& InMemoryRunGraphResponse::recv_key(size_t i) const { return recvs_[i].first; } -Status InMemoryRunGraphResponse::RecvValue(size_t i, TensorProto* out_tensor) { +absl::Status InMemoryRunGraphResponse::RecvValue(size_t i, + TensorProto* out_tensor) { recvs_[i].second.AsProtoTensorContent(out_tensor); return absl::OkStatus(); } -Status InMemoryRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) { +absl::Status InMemoryRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) { *out_tensor = recvs_[i].second; return absl::OkStatus(); } @@ -641,13 +647,13 @@ CostGraphDef* InMemoryRunGraphResponse::mutable_cost_graph() { return &cost_graph_; } -Status InMemoryRunGraphResponse::status() const { return status_; } +absl::Status InMemoryRunGraphResponse::status() const { return status_; } -errors::Code InMemoryRunGraphResponse::status_code() const { - return static_cast(status_.code()); +absl::StatusCode InMemoryRunGraphResponse::status_code() const { + return static_cast(status_.code()); } -void InMemoryRunGraphResponse::set_status(const Status& status) { +void InMemoryRunGraphResponse::set_status(const absl::Status& status) { status_ = status; } @@ -677,13 +683,14 @@ const string& OwnedProtoRunGraphResponse::recv_key(size_t i) const { return response_.recv(i).name(); } -Status OwnedProtoRunGraphResponse::RecvValue(size_t i, - TensorProto* out_tensor) { +absl::Status OwnedProtoRunGraphResponse::RecvValue(size_t i, + TensorProto* out_tensor) { out_tensor->Swap(response_.mutable_recv(i)->mutable_tensor()); return absl::OkStatus(); } -Status OwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) { +absl::Status OwnedProtoRunGraphResponse::RecvValue(size_t i, + Tensor* out_tensor) { if (!ParseTensorProtoToTensor(response_.recv(i).tensor(), out_tensor)) { return errors::InvalidArgument("Invalid TensorProto for recv value ", i); } else { @@ -707,16 +714,16 @@ CostGraphDef* OwnedProtoRunGraphResponse::mutable_cost_graph() { return response_.mutable_cost_graph(); } -Status OwnedProtoRunGraphResponse::status() const { - return Status(static_cast(response_.status_code()), - response_.status_error_message()); +absl::Status OwnedProtoRunGraphResponse::status() const { + return absl::Status(static_cast(response_.status_code()), + response_.status_error_message()); } absl::StatusCode OwnedProtoRunGraphResponse::status_code() const { return static_cast(response_.status_code()); } -void OwnedProtoRunGraphResponse::set_status(const Status& status) { +void OwnedProtoRunGraphResponse::set_status(const absl::Status& status) { response_.set_status_code(static_cast(status.code())); response_.set_status_error_message(absl::StatusMessageAsCStr(status)); } @@ -749,13 +756,14 @@ const string& NonOwnedProtoRunGraphResponse::recv_key(size_t i) const { return response_->recv(i).name(); } -Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i, - TensorProto* out_tensor) { +absl::Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i, + TensorProto* out_tensor) { out_tensor->Swap(response_->mutable_recv(i)->mutable_tensor()); return absl::OkStatus(); } -Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i, Tensor* out_tensor) { +absl::Status NonOwnedProtoRunGraphResponse::RecvValue(size_t i, + Tensor* out_tensor) { if (!ParseTensorProtoToTensor(response_->recv(i).tensor(), out_tensor)) { return errors::InvalidArgument("Invalid TensorProto for recv value ", i); } else { @@ -779,16 +787,16 @@ CostGraphDef* NonOwnedProtoRunGraphResponse::mutable_cost_graph() { return response_->mutable_cost_graph(); } -Status NonOwnedProtoRunGraphResponse::status() const { - return Status(static_cast(response_->status_code()), - response_->status_error_message()); +absl::Status NonOwnedProtoRunGraphResponse::status() const { + return absl::Status(static_cast(response_->status_code()), + response_->status_error_message()); } absl::StatusCode NonOwnedProtoRunGraphResponse::status_code() const { return static_cast(response_->status_code()); } -void NonOwnedProtoRunGraphResponse::set_status(const Status& status) { +void NonOwnedProtoRunGraphResponse::set_status(const absl::Status& status) { response_->set_status_code(static_cast(status.code())); response_->set_status_error_message(absl::StatusMessageAsCStr(status)); } @@ -819,8 +827,8 @@ const string& InMemoryRunStepResponse::tensor_name(size_t i) const { return tensors_[i].first; } -Status InMemoryRunStepResponse::TensorValue(size_t i, - Tensor* out_tensor) const { +absl::Status InMemoryRunStepResponse::TensorValue(size_t i, + Tensor* out_tensor) const { *out_tensor = tensors_[i].second; return absl::OkStatus(); } @@ -829,7 +837,7 @@ const RunMetadata& InMemoryRunStepResponse::metadata() const { return metadata_; } -Status InMemoryRunStepResponse::AddTensorFromRunGraphResponse( +absl::Status InMemoryRunStepResponse::AddTensorFromRunGraphResponse( const string& name, MutableRunGraphResponseWrapper* wrapper, size_t i) { Tensor tensor; TF_RETURN_IF_ERROR(wrapper->RecvValue(i, &tensor)); @@ -839,13 +847,13 @@ Status InMemoryRunStepResponse::AddTensorFromRunGraphResponse( RunMetadata* InMemoryRunStepResponse::mutable_metadata() { return &metadata_; } -Status InMemoryRunStepResponse::status() const { return status_; } +absl::Status InMemoryRunStepResponse::status() const { return status_; } -errors::Code InMemoryRunStepResponse::status_code() const { - return static_cast(status_.code()); +absl::StatusCode InMemoryRunStepResponse::status_code() const { + return static_cast(status_.code()); } -void InMemoryRunStepResponse::set_status(const Status& status) { +void InMemoryRunStepResponse::set_status(const absl::Status& status) { status_ = status; } @@ -862,8 +870,8 @@ const string& OwnedProtoRunStepResponse::tensor_name(size_t i) const { return response_.tensor(i).name(); } -Status OwnedProtoRunStepResponse::TensorValue(size_t i, - Tensor* out_tensor) const { +absl::Status OwnedProtoRunStepResponse::TensorValue(size_t i, + Tensor* out_tensor) const { if (!ParseTensorProtoToTensor(response_.tensor(i).tensor(), out_tensor)) { return errors::InvalidArgument("Invalid TensorProto for fetch value ", i); } else { @@ -875,7 +883,7 @@ const RunMetadata& OwnedProtoRunStepResponse::metadata() const { return response_.metadata(); } -Status OwnedProtoRunStepResponse::AddTensorFromRunGraphResponse( +absl::Status OwnedProtoRunStepResponse::AddTensorFromRunGraphResponse( const string& name, MutableRunGraphResponseWrapper* run_graph_response, size_t i) { NamedTensorProto* response_tensor = response_.add_tensor(); @@ -887,16 +895,16 @@ RunMetadata* OwnedProtoRunStepResponse::mutable_metadata() { return response_.mutable_metadata(); } -Status OwnedProtoRunStepResponse::status() const { - return Status(static_cast(response_.status_code()), - response_.status_error_message()); +absl::Status OwnedProtoRunStepResponse::status() const { + return absl::Status(static_cast(response_.status_code()), + response_.status_error_message()); } absl::StatusCode OwnedProtoRunStepResponse::status_code() const { return static_cast(response_.status_code()); } -void OwnedProtoRunStepResponse::set_status(const Status& status) { +void OwnedProtoRunStepResponse::set_status(const absl::Status& status) { response_.set_status_code(static_cast(status.code())); response_.set_status_error_message(absl::StatusMessageAsCStr(status)); } @@ -915,8 +923,8 @@ const string& NonOwnedProtoRunStepResponse::tensor_name(size_t i) const { return response_->tensor(i).name(); } -Status NonOwnedProtoRunStepResponse::TensorValue(size_t i, - Tensor* out_tensor) const { +absl::Status NonOwnedProtoRunStepResponse::TensorValue( + size_t i, Tensor* out_tensor) const { if (!ParseTensorProtoToTensor(response_->tensor(i).tensor(), out_tensor)) { return errors::InvalidArgument("Invalid TensorProto for fetch value ", i); } else { @@ -928,7 +936,7 @@ const RunMetadata& NonOwnedProtoRunStepResponse::metadata() const { return response_->metadata(); } -Status NonOwnedProtoRunStepResponse::AddTensorFromRunGraphResponse( +absl::Status NonOwnedProtoRunStepResponse::AddTensorFromRunGraphResponse( const string& name, MutableRunGraphResponseWrapper* run_graph_response, size_t i) { NamedTensorProto* response_tensor = response_->add_tensor(); @@ -940,16 +948,16 @@ RunMetadata* NonOwnedProtoRunStepResponse::mutable_metadata() { return response_->mutable_metadata(); } -Status NonOwnedProtoRunStepResponse::status() const { - return Status(static_cast(response_->status_code()), - response_->status_error_message()); +absl::Status NonOwnedProtoRunStepResponse::status() const { + return absl::Status(static_cast(response_->status_code()), + response_->status_error_message()); } absl::StatusCode NonOwnedProtoRunStepResponse::status_code() const { return static_cast(response_->status_code()); } -void NonOwnedProtoRunStepResponse::set_status(const Status& status) { +void NonOwnedProtoRunStepResponse::set_status(const absl::Status& status) { response_->set_status_code(static_cast(status.code())); response_->set_status_error_message(absl::StatusMessageAsCStr(status)); } diff --git a/tensorflow/core/distributed_runtime/message_wrappers.h b/tensorflow/core/distributed_runtime/message_wrappers.h index 4be73f850cad88..d4b07fb51ce4a3 100644 --- a/tensorflow/core/distributed_runtime/message_wrappers.h +++ b/tensorflow/core/distributed_runtime/message_wrappers.h @@ -64,8 +64,8 @@ class RunStepRequestWrapper { virtual const string& feed_name(size_t i) const = 0; // Stores the content of the feed value at index `i` in `tensor`. - virtual Status FeedValue(size_t i, Tensor* out_tensor) const = 0; - virtual Status FeedValue(size_t i, TensorProto* out_tensor) const = 0; + virtual absl::Status FeedValue(size_t i, Tensor* out_tensor) const = 0; + virtual absl::Status FeedValue(size_t i, TensorProto* out_tensor) const = 0; // Fetches. A list of tensor names. The caller expects a tensor to // be returned for each fetch[i] (see RunStepResponse.tensor). The @@ -123,8 +123,8 @@ class InMemoryRunStepRequest : public MutableRunStepRequestWrapper { const string& partial_run_handle() const override; size_t num_feeds() const override; const string& feed_name(size_t i) const override; - Status FeedValue(size_t i, Tensor* out_tensor) const override; - Status FeedValue(size_t i, TensorProto* out_tensor) const override; + absl::Status FeedValue(size_t i, Tensor* out_tensor) const override; + absl::Status FeedValue(size_t i, TensorProto* out_tensor) const override; size_t num_fetches() const override; const string& fetch_name(size_t i) const override; size_t num_targets() const override; @@ -174,8 +174,8 @@ class MutableProtoRunStepRequest : public MutableRunStepRequestWrapper { const string& partial_run_handle() const override; size_t num_feeds() const override; const string& feed_name(size_t i) const override; - Status FeedValue(size_t i, Tensor* out_tensor) const override; - Status FeedValue(size_t i, TensorProto* out_tensor) const override; + absl::Status FeedValue(size_t i, Tensor* out_tensor) const override; + absl::Status FeedValue(size_t i, TensorProto* out_tensor) const override; size_t num_fetches() const override; const string& fetch_name(size_t i) const override; size_t num_targets() const override; @@ -215,8 +215,8 @@ class ProtoRunStepRequest : public RunStepRequestWrapper { const string& partial_run_handle() const override; size_t num_feeds() const override; const string& feed_name(size_t i) const override; - Status FeedValue(size_t i, Tensor* out_tensor) const override; - Status FeedValue(size_t i, TensorProto* out_tensor) const override; + absl::Status FeedValue(size_t i, Tensor* out_tensor) const override; + absl::Status FeedValue(size_t i, TensorProto* out_tensor) const override; size_t num_fetches() const override; const string& fetch_name(size_t i) const override; size_t num_targets() const override; @@ -277,7 +277,7 @@ class RunGraphRequestWrapper { // Sends the tensors in "send" into the graph before the run. virtual size_t num_sends() const = 0; virtual const string& send_key(size_t i) const = 0; - virtual Status SendValue(size_t i, Tensor* out_tensor) const = 0; + virtual absl::Status SendValue(size_t i, Tensor* out_tensor) const = 0; // Fetches the keys into `RunGraphResponse.recv` after the run. virtual size_t num_recvs() const = 0; @@ -315,10 +315,10 @@ class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper { // Stores the i^{th} feed value in `run_step_request` in this // request with the given `send_key`. - virtual Status AddSendFromRunStepRequest( + virtual absl::Status AddSendFromRunStepRequest( const RunStepRequestWrapper& run_step_request, size_t i, const string& send_key) = 0; - virtual Status AddSendFromRunCallableRequest( + virtual absl::Status AddSendFromRunCallableRequest( const RunCallableRequest& run_callable_request, size_t i, const string& send_key) = 0; @@ -339,7 +339,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { const ExecutorOpts& exec_opts() const override; size_t num_sends() const override; const string& send_key(size_t i) const override; - Status SendValue(size_t i, Tensor* out_tensor) const override; + absl::Status SendValue(size_t i, Tensor* out_tensor) const override; size_t num_recvs() const override; const string& recv_key(size_t i) const override; bool is_partial() const override; @@ -354,10 +354,10 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { void set_graph_handle(const string& handle) override; void set_step_id(int64_t step_id) override; ExecutorOpts* mutable_exec_opts() override; - Status AddSendFromRunStepRequest( + absl::Status AddSendFromRunStepRequest( const RunStepRequestWrapper& run_step_request, size_t i, const string& send_key) override; - Status AddSendFromRunCallableRequest( + absl::Status AddSendFromRunCallableRequest( const RunCallableRequest& run_callable_request, size_t i, const string& send_key) override; void add_recv_key(const string& recv_key) override; @@ -399,7 +399,7 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper { const ExecutorOpts& exec_opts() const override; size_t num_sends() const override; const string& send_key(size_t i) const override; - Status SendValue(size_t i, Tensor* out_tensor) const override; + absl::Status SendValue(size_t i, Tensor* out_tensor) const override; size_t num_recvs() const override; const string& recv_key(size_t i) const override; bool is_partial() const override; @@ -414,10 +414,10 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper { void set_graph_handle(const string& handle) override; void set_step_id(int64_t step_id) override; ExecutorOpts* mutable_exec_opts() override; - Status AddSendFromRunStepRequest( + absl::Status AddSendFromRunStepRequest( const RunStepRequestWrapper& run_step_request, size_t i, const string& send_key) override; - Status AddSendFromRunCallableRequest( + absl::Status AddSendFromRunCallableRequest( const RunCallableRequest& run_callable_request, size_t i, const string& send_key) override; void add_recv_key(const string& recv_key) override; @@ -442,7 +442,7 @@ class ProtoRunGraphRequest : public RunGraphRequestWrapper { const ExecutorOpts& exec_opts() const override; size_t num_sends() const override; const string& send_key(size_t i) const override; - Status SendValue(size_t i, Tensor* out_tensor) const override; + absl::Status SendValue(size_t i, Tensor* out_tensor) const override; size_t num_recvs() const override; const string& recv_key(size_t i) const override; bool is_partial() const override; @@ -483,8 +483,8 @@ class MutableRunGraphResponseWrapper { virtual const string& recv_key(size_t i) const = 0; // NOTE: The following methods may perform a destructive read, for // efficiency. - virtual Status RecvValue(size_t i, TensorProto* out_tensor) = 0; - virtual Status RecvValue(size_t i, Tensor* out_tensor) = 0; + virtual absl::Status RecvValue(size_t i, TensorProto* out_tensor) = 0; + virtual absl::Status RecvValue(size_t i, Tensor* out_tensor) = 0; virtual void AddRecv(const string& key, const Tensor& value) = 0; // Submessages that store performance statistics about the subgraph @@ -496,9 +496,9 @@ class MutableRunGraphResponseWrapper { virtual void AddPartitionGraph(const GraphDef& partition_graph) = 0; // Returned status if requested. - virtual Status status() const = 0; + virtual absl::Status status() const = 0; virtual absl::StatusCode status_code() const = 0; - virtual void set_status(const Status& status) = 0; + virtual void set_status(const absl::Status& status) = 0; protected: // Returns a mutable protobuf message that represents the contents of @@ -521,17 +521,17 @@ class InMemoryRunGraphResponse : public MutableRunGraphResponseWrapper { // MutableRunGraphResponseWrapper methods. size_t num_recvs() const override; const string& recv_key(size_t i) const override; - Status RecvValue(size_t i, TensorProto* out_tensor) override; - Status RecvValue(size_t i, Tensor* out_tensor) override; + absl::Status RecvValue(size_t i, TensorProto* out_tensor) override; + absl::Status RecvValue(size_t i, Tensor* out_tensor) override; void AddRecv(const string& key, const Tensor& value) override; StepStats* mutable_step_stats() override; CostGraphDef* mutable_cost_graph() override; size_t num_partition_graphs() const override; GraphDef* mutable_partition_graph(size_t i) override; void AddPartitionGraph(const GraphDef& partition_graph) override; - Status status() const override; + absl::Status status() const override; absl::StatusCode status_code() const override; - void set_status(const Status& status) override; + void set_status(const absl::Status& status) override; protected: // NOTE: This method is not implemented. See @@ -545,7 +545,7 @@ class InMemoryRunGraphResponse : public MutableRunGraphResponseWrapper { std::vector partition_graphs_; // Store the code and message separately so that they can be updated // independently by setters. - Status status_; + absl::Status status_; }; // Proto-based message wrapper for use on the client side of the RunGraph RPC. @@ -554,17 +554,17 @@ class OwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper { // MutableRunGraphResponseWrapper methods. size_t num_recvs() const override; const string& recv_key(size_t i) const override; - Status RecvValue(size_t i, TensorProto* out_tensor) override; - Status RecvValue(size_t i, Tensor* out_tensor) override; + absl::Status RecvValue(size_t i, TensorProto* out_tensor) override; + absl::Status RecvValue(size_t i, Tensor* out_tensor) override; void AddRecv(const string& key, const Tensor& value) override; StepStats* mutable_step_stats() override; CostGraphDef* mutable_cost_graph() override; size_t num_partition_graphs() const override; GraphDef* mutable_partition_graph(size_t i) override; void AddPartitionGraph(const GraphDef& partition_graph) override; - Status status() const override; + absl::Status status() const override; absl::StatusCode status_code() const override; - void set_status(const Status& status) override; + void set_status(const absl::Status& status) override; protected: RunGraphResponse* get_proto() override; @@ -581,17 +581,17 @@ class NonOwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper { // MutableRunGraphResponseWrapper methods. size_t num_recvs() const override; const string& recv_key(size_t i) const override; - Status RecvValue(size_t i, TensorProto* out_tensor) override; - Status RecvValue(size_t i, Tensor* out_tensor) override; + absl::Status RecvValue(size_t i, TensorProto* out_tensor) override; + absl::Status RecvValue(size_t i, Tensor* out_tensor) override; void AddRecv(const string& key, const Tensor& value) override; StepStats* mutable_step_stats() override; CostGraphDef* mutable_cost_graph() override; size_t num_partition_graphs() const override; GraphDef* mutable_partition_graph(size_t i) override; void AddPartitionGraph(const GraphDef& partition_graph) override; - Status status() const override; + absl::Status status() const override; absl::StatusCode status_code() const override; - void set_status(const Status& status) override; + void set_status(const absl::Status& status) override; protected: RunGraphResponse* get_proto() override; @@ -629,11 +629,11 @@ class MutableRunStepResponseWrapper { // the fetch order specified in RunStepRequest. virtual size_t num_tensors() const = 0; virtual const string& tensor_name(size_t i) const = 0; - virtual Status TensorValue(size_t i, Tensor* out_tensor) const = 0; + virtual absl::Status TensorValue(size_t i, Tensor* out_tensor) const = 0; // Stores the i^{th} recv value in `run_graph_response` in this // response with the given `name`. - virtual Status AddTensorFromRunGraphResponse( + virtual absl::Status AddTensorFromRunGraphResponse( const string& name, MutableRunGraphResponseWrapper* run_graph_response, size_t i) = 0; @@ -642,9 +642,9 @@ class MutableRunStepResponseWrapper { virtual RunMetadata* mutable_metadata() = 0; // Returned status if requested. - virtual Status status() const = 0; + virtual absl::Status status() const = 0; virtual absl::StatusCode status_code() const = 0; - virtual void set_status(const Status& status) = 0; + virtual void set_status(const absl::Status& status) = 0; protected: // Returns a mutable protobuf message that represents the contents of @@ -667,15 +667,15 @@ class InMemoryRunStepResponse : public MutableRunStepResponseWrapper { // MutableRunStepResponseWrapper methods. size_t num_tensors() const override; const string& tensor_name(size_t i) const override; - Status TensorValue(size_t i, Tensor* out_tensor) const override; - Status AddTensorFromRunGraphResponse( + absl::Status TensorValue(size_t i, Tensor* out_tensor) const override; + absl::Status AddTensorFromRunGraphResponse( const string& name, MutableRunGraphResponseWrapper* run_graph_response, size_t i) override; const RunMetadata& metadata() const override; RunMetadata* mutable_metadata() override; - Status status() const override; + absl::Status status() const override; absl::StatusCode status_code() const override; - void set_status(const Status& status) override; + void set_status(const absl::Status& status) override; protected: // NOTE: This method is not implemented. See @@ -687,7 +687,7 @@ class InMemoryRunStepResponse : public MutableRunStepResponseWrapper { RunMetadata metadata_; // Store the code and message separately so that they can be updated // independently by setters. - Status status_; + absl::Status status_; }; // Proto-based message wrapper for use on the client side of the RunStep RPC. @@ -696,15 +696,15 @@ class OwnedProtoRunStepResponse : public MutableRunStepResponseWrapper { // MutableRunStepResponseWrapper methods. size_t num_tensors() const override; const string& tensor_name(size_t i) const override; - Status TensorValue(size_t i, Tensor* out_tensor) const override; - Status AddTensorFromRunGraphResponse( + absl::Status TensorValue(size_t i, Tensor* out_tensor) const override; + absl::Status AddTensorFromRunGraphResponse( const string& name, MutableRunGraphResponseWrapper* run_graph_response, size_t i) override; const RunMetadata& metadata() const override; RunMetadata* mutable_metadata() override; - Status status() const override; + absl::Status status() const override; absl::StatusCode status_code() const override; - void set_status(const Status& status) override; + void set_status(const absl::Status& status) override; protected: RunStepResponse* get_proto() override; @@ -721,15 +721,15 @@ class NonOwnedProtoRunStepResponse : public MutableRunStepResponseWrapper { // MutableRunStepResponseWrapper methods. size_t num_tensors() const override; const string& tensor_name(size_t i) const override; - Status TensorValue(size_t i, Tensor* out_tensor) const override; - Status AddTensorFromRunGraphResponse( + absl::Status TensorValue(size_t i, Tensor* out_tensor) const override; + absl::Status AddTensorFromRunGraphResponse( const string& name, MutableRunGraphResponseWrapper* run_graph_response, size_t i) override; const RunMetadata& metadata() const override; RunMetadata* mutable_metadata() override; - Status status() const override; + absl::Status status() const override; absl::StatusCode status_code() const override; - void set_status(const Status& status) override; + void set_status(const absl::Status& status) override; protected: RunStepResponse* get_proto() override; diff --git a/tensorflow/core/distributed_runtime/partial_run_mgr.cc b/tensorflow/core/distributed_runtime/partial_run_mgr.cc index d77e437dbb2dea..27cf54536a4601 100644 --- a/tensorflow/core/distributed_runtime/partial_run_mgr.cc +++ b/tensorflow/core/distributed_runtime/partial_run_mgr.cc @@ -35,9 +35,10 @@ bool PartialRunMgr::FindOrCreate(int step_id, return true; } -void PartialRunMgr::ExecutorDone(int step_id, const Status& executor_status) { +void PartialRunMgr::ExecutorDone(int step_id, + const absl::Status& executor_status) { StatusCallback done; - Status callback_status; + absl::Status callback_status; { mutex_lock l(mu_); auto run_it = step_id_to_partial_run_.find(step_id); @@ -63,8 +64,8 @@ void PartialRunMgr::ExecutorDone(int step_id, const Status& executor_status) { } void PartialRunMgr::PartialRunDone(int step_id, StatusCallback done, - const Status& status) { - Status callback_status; + const absl::Status& status) { + absl::Status callback_status; { mutex_lock l(mu_); auto run_it = step_id_to_partial_run_.find(step_id); diff --git a/tensorflow/core/distributed_runtime/partial_run_mgr.h b/tensorflow/core/distributed_runtime/partial_run_mgr.h index 5c104125029691..bf2b2b1a0b5607 100644 --- a/tensorflow/core/distributed_runtime/partial_run_mgr.h +++ b/tensorflow/core/distributed_runtime/partial_run_mgr.h @@ -55,7 +55,7 @@ class PartialRunMgr { // Calls the final callback if the PartialRunRequest has already completed. // Otherwise stores the executor_status to be propagated when the // PartialRunRequest completes (PartialRunDone has been called). - void ExecutorDone(int step_id, const Status& executor_status); + void ExecutorDone(int step_id, const absl::Status& executor_status); // Calls done if the executor has already completed (ExecutorDone has been // called). Otherwise, stores the status and done callback, calling them when @@ -63,7 +63,8 @@ class PartialRunMgr { // thread of either PartialRunDone or ExecutorDone. // If executor_status in ExecutorDone is not OK, it takes precedence over // status and is passed to the done callback. - void PartialRunDone(int step_id, StatusCallback done, const Status& status); + void PartialRunDone(int step_id, StatusCallback done, + const absl::Status& status); private: // PartialRunState stores state associated with a pending partial run request. @@ -73,7 +74,7 @@ class PartialRunMgr { bool executor_done = false; StatusCallback final_callback = nullptr; - Status final_status; + absl::Status final_status; }; mutex mu_; diff --git a/tensorflow/core/distributed_runtime/partial_run_mgr_test.cc b/tensorflow/core/distributed_runtime/partial_run_mgr_test.cc index a2d56e5b2fd6ad..5ef771a55278e9 100644 --- a/tensorflow/core/distributed_runtime/partial_run_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/partial_run_mgr_test.cc @@ -65,22 +65,22 @@ TEST(PartialRunMgr, PartialRunRemoved) { int called = 0; partial_run_mgr.PartialRunDone( - step_id, [&called](Status status) { called++; }, absl::OkStatus()); + step_id, [&called](absl::Status status) { called++; }, absl::OkStatus()); partial_run_mgr.ExecutorDone(step_id, absl::OkStatus()); // Calling ExecutorDone and PartialRunDone on the step_id should still only // result in the callback being called once. // This proves that the original PartialRun has been removed. partial_run_mgr.PartialRunDone( - step_id, [&called](Status status) { called++; }, absl::OkStatus()); + step_id, [&called](absl::Status status) { called++; }, absl::OkStatus()); partial_run_mgr.ExecutorDone(step_id, absl::OkStatus()); EXPECT_EQ(1, called); } struct StatusTestParam { - Status executor_status; - Status partial_run_status; - Status expected_status; + absl::Status executor_status; + absl::Status partial_run_status; + absl::Status expected_status; }; class StatusPropagationTest : public ::testing::TestWithParam { @@ -89,15 +89,15 @@ class StatusPropagationTest : public ::testing::TestWithParam { // State to help keep track of when the callback is called. Notification invoked_; - Status status_; + absl::Status status_; - void set_status(const Status& status) { + void set_status(const absl::Status& status) { status_ = status; invoked_.Notify(); } // Blocks until status is set. - Status status() { + absl::Status status() { invoked_.WaitForNotification(); return status_; } @@ -112,9 +112,9 @@ TEST_P(StatusPropagationTest, ExecutorDoneFirst) { partial_run_mgr_.FindOrCreate(step_id, &cancellation_manager); partial_run_mgr_.ExecutorDone(step_id, param.executor_status); - partial_run_mgr_.PartialRunDone(step_id, - [this](Status status) { set_status(status); }, - param.partial_run_status); + partial_run_mgr_.PartialRunDone( + step_id, [this](absl::Status status) { set_status(status); }, + param.partial_run_status); EXPECT_EQ(status(), param.expected_status); } @@ -127,9 +127,9 @@ TEST_P(StatusPropagationTest, PartialRunDoneFirst) { CancellationManager* cancellation_manager; partial_run_mgr_.FindOrCreate(step_id, &cancellation_manager); - partial_run_mgr_.PartialRunDone(step_id, - [this](Status status) { set_status(status); }, - param.partial_run_status); + partial_run_mgr_.PartialRunDone( + step_id, [this](absl::Status status) { set_status(status); }, + param.partial_run_status); partial_run_mgr_.ExecutorDone(step_id, param.executor_status); EXPECT_EQ(status(), param.expected_status); @@ -137,8 +137,8 @@ TEST_P(StatusPropagationTest, PartialRunDoneFirst) { // Instantiate tests for all error orderings, for both call orders of // ExecutorDone and PartialRunDone. -Status ExecutorError() { return errors::Internal("executor error"); } -Status PartialRunError() { return errors::Internal("partial run error"); } +absl::Status ExecutorError() { return errors::Internal("executor error"); } +absl::Status PartialRunError() { return errors::Internal("partial run error"); } INSTANTIATE_TEST_SUITE_P( PartialRunMgr, StatusPropagationTest, ::testing::Values( diff --git a/tensorflow/core/distributed_runtime/recent_request_ids.cc b/tensorflow/core/distributed_runtime/recent_request_ids.cc index e7ea66286fb341..f75390b26bd338 100644 --- a/tensorflow/core/distributed_runtime/recent_request_ids.cc +++ b/tensorflow/core/distributed_runtime/recent_request_ids.cc @@ -60,9 +60,9 @@ bool RecentRequestIds::Insert(int64_t request_id) { return true; } -Status RecentRequestIds::TrackUnique(int64_t request_id, - const string& method_name, - const protobuf::Message& request) { +absl::Status RecentRequestIds::TrackUnique(int64_t request_id, + const string& method_name, + const protobuf::Message& request) { if (Insert(request_id)) { return absl::OkStatus(); } else { diff --git a/tensorflow/core/distributed_runtime/recent_request_ids.h b/tensorflow/core/distributed_runtime/recent_request_ids.h index bc2b1a3d615035..2eb35ac7266c6c 100644 --- a/tensorflow/core/distributed_runtime/recent_request_ids.h +++ b/tensorflow/core/distributed_runtime/recent_request_ids.h @@ -60,12 +60,12 @@ class RecentRequestIds { // num_tracked_request_ids insertions. For backwards compatibility, this // always returns OK for request_id 0. The method_name and the request's // ShortDebugString are added to returned errors. - Status TrackUnique(int64_t request_id, const string& method_name, - const protobuf::Message& request); + absl::Status TrackUnique(int64_t request_id, const string& method_name, + const protobuf::Message& request); // Overloaded version of the above function for wrapped protos. template - Status TrackUnique(int64_t request_id, const string& method_name, - const RequestWrapper* wrapper); + absl::Status TrackUnique(int64_t request_id, const string& method_name, + const RequestWrapper* wrapper); private: bool Insert(int64_t request_id); @@ -87,9 +87,9 @@ class RecentRequestIds { // Implementation details template -Status RecentRequestIds::TrackUnique(int64_t request_id, - const string& method_name, - const RequestWrapper* wrapper) { +absl::Status RecentRequestIds::TrackUnique(int64_t request_id, + const string& method_name, + const RequestWrapper* wrapper) { if (Insert(request_id)) { return absl::OkStatus(); } else { diff --git a/tensorflow/core/distributed_runtime/recent_request_ids_test.cc b/tensorflow/core/distributed_runtime/recent_request_ids_test.cc index cfa2f78b8cd0ed..11b85256ac634e 100644 --- a/tensorflow/core/distributed_runtime/recent_request_ids_test.cc +++ b/tensorflow/core/distributed_runtime/recent_request_ids_test.cc @@ -26,7 +26,8 @@ limitations under the License. namespace tensorflow { -Status TrackUnique(int64_t request_id, RecentRequestIds* recent_request_ids) { +absl::Status TrackUnique(int64_t request_id, + RecentRequestIds* recent_request_ids) { RecvTensorRequest request; request.set_request_id(request_id); return recent_request_ids->TrackUnique(request_id, "recent_request_ids_test", diff --git a/tensorflow/core/distributed_runtime/remote_device.cc b/tensorflow/core/distributed_runtime/remote_device.cc index 9e6479b0e77ba1..ad8ac2080ab833 100644 --- a/tensorflow/core/distributed_runtime/remote_device.cc +++ b/tensorflow/core/distributed_runtime/remote_device.cc @@ -39,7 +39,7 @@ class RemoteDevice : public Device { : Device(env, da), local_dev_name_(DeviceNameUtils::LocalName(da.name())) {} - Status Sync() override { return absl::OkStatus(); } + absl::Status Sync() override { return absl::OkStatus(); } Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; } ResourceMgr* resource_manager() override { @@ -91,8 +91,8 @@ void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache, }; Call* call = new Call; auto cb = [env, worker_cache, worker_name, done, wi, - call](const Status& status) { - Status s = status; + call](const absl::Status& status) { + absl::Status s = status; std::vector remote_devices; auto cleanup = gtl::MakeCleanup( [&worker_cache, &worker_name, &wi, &done, &remote_devices, &s, call] { diff --git a/tensorflow/core/distributed_runtime/remote_device.h b/tensorflow/core/distributed_runtime/remote_device.h index 2b7a8eb55b5d12..766c9d8e167f8d 100644 --- a/tensorflow/core/distributed_runtime/remote_device.h +++ b/tensorflow/core/distributed_runtime/remote_device.h @@ -36,7 +36,7 @@ class WorkerCacheInterface; // This callback should have the same definition as DeviceMgr::LookupDevice // It assigns *device with pointer to Device of the given 'name', where 'name' // is either a full device name, or just the replica-local suffix. -typedef std::function +typedef std::function LookupLocalDevice; // Creates Remote Devices for the provided device attributes. Helpful when the @@ -59,7 +59,7 @@ void AsRemoteDevices( // // Otherwise, the 'done' callback is given an error status and the // vector is empty. -typedef std::function*)> +typedef std::function*)> NewRemoteDevicesDone; void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache, const string& worker_name, NewRemoteDevicesDone done); diff --git a/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h b/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h index 2308d8abc9c46d..6ec759d45441bc 100644 --- a/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h +++ b/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h @@ -39,7 +39,7 @@ class WorkerSession; class RemoteRendezvous : public Rendezvous { public: // Fully construct the RemoteRendezvous. - virtual Status Initialize(WorkerSession* session) = 0; + virtual absl::Status Initialize(WorkerSession* session) = 0; // In remote eager, set current instance as context default rendezvous which // will be used for eager op-by-op execution. @@ -91,8 +91,9 @@ class RendezvousMgrInterface { Rendezvous::DoneCallback done) = 0; // Synchronous wrapper for RecvLocalAsync. - virtual Status RecvLocal(int64_t step_id, const Rendezvous::ParsedKey& parsed, - Tensor* val, bool* is_dead) = 0; + virtual absl::Status RecvLocal(int64_t step_id, + const Rendezvous::ParsedKey& parsed, + Tensor* val, bool* is_dead) = 0; // Removes rendezvous for "step_id". // diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index 959a8abae6518a..0b7b731e6f51fb 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -208,9 +208,9 @@ tf_cuda_library( "//tensorflow/core/profiler/lib:scoped_memory_debug_annotation", "//tensorflow/core/protobuf:worker_proto_cc", "@com_google_absl//absl/container:flat_hash_map", - "@local_tsl//tsl/protobuf:rpc_options_proto_cc", "@local_xla//xla/tsl/distributed_runtime/rpc:async_service_interface", "@local_xla//xla/tsl/distributed_runtime/rpc:grpc_call", + "@local_xla//xla/tsl/protobuf:rpc_options_proto_cc", ] + tf_grpc_cc_dependencies(), ) diff --git a/tensorflow/core/distributed_runtime/rpc/eager/BUILD b/tensorflow/core/distributed_runtime/rpc/eager/BUILD index fc26093c418cf3..96945529341a09 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/eager/BUILD @@ -17,7 +17,6 @@ cc_library( # copybara:uncomment copts = ["-Wthread-safety-analysis"], deps = [ "//tensorflow/core/protobuf:eager_service_cc_grpc_proto", - "@local_xla//xla/stream_executor/platform", ] + tf_grpc_cc_dependencies(), ) diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc index f69a34eeb21839..00141a5dc89f30 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.cc @@ -243,9 +243,9 @@ class GrpcEagerClient : public EagerClient { it->second.SendNextRequest(*request, response, std::move(done_wrapped)); } else { Notification n; - Status status; + absl::Status status; EnqueueAsync(call_opts, request, response, - [&n, &status](const Status& s) { + [&n, &status](const absl::Status& s) { status.Update(s); n.Notify(); }); @@ -268,7 +268,7 @@ class GrpcEagerClient : public EagerClient { StatusCallback callback_wrapper(StatusCallback done) { Ref(); - return [this, done = std::move(done)](const Status& status) { + return [this, done = std::move(done)](const absl::Status& status) { done(status); this->Unref(); if (TF_PREDICT_FALSE(!status.ok())) { @@ -304,8 +304,8 @@ class GrpcEagerClientCache : public EagerClientCache { ~GrpcEagerClientCache() override { threads_.clear(); } - Status GetClient(const string& target, - core::RefCountPtr* client) override { + absl::Status GetClient(const string& target, + core::RefCountPtr* client) override { mutex_lock l(clients_mu_); auto it = clients_.find(target); if (it == clients_.end()) { diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client_test.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client_test.cc index d68f403e634cbe..128da6b893add2 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client_test.cc @@ -46,7 +46,7 @@ TEST(GrpcEagerClientCache, TestGetClientThreadSafety) { Env::Default()->SchedClosure([&client_cache, i, &counter]() { string target = strings::StrCat("/job:worker/replica:0/task:", i); core::RefCountPtr eager_client; - Status s = client_cache->GetClient(target, &eager_client); + absl::Status s = client_cache->GetClient(target, &eager_client); // With 6 tasks added to the job, querying client for 0--5 should be OK, // and querying client for 6+ should give invalid argument error. error::Code expected_code = i <= 5 ? error::OK : error::INVALID_ARGUMENT; diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h index 4ef8131feef362..24cd17a441fd7d 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_H_ #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_H_ -#include "xla/stream_executor/platform/port.h" #include "tensorflow/core/protobuf/eager_service.grpc.pb.h" #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_EAGER_GRPC_EAGER_SERVICE_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc index 9599766e784e35..5d58d415c81470 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.cc @@ -32,10 +32,17 @@ GrpcEagerServiceImpl::GrpcEagerServiceImpl( local_impl_(env), enqueue_streaming_thread_(env_->env, "enqueue_streaming_thread", 1) { server_builder->RegisterService(&service_); + // gRPC by default will cancel requests that sit in a completion queue for + // more than 30s. See + // https://github.com/grpc/grpc/blob/e52e48b7ef83feeff56ed0894ce39841ea8bd483/include/grpc/impl/channel_arg_names.h#L106-L111 + // Extending this to 1 hour for Tensorflow since some graphs may have periods + // of heavy load which may cause the server to run into these cancellations. + server_builder->AddChannelArgument( + "grpc.server_max_unrequested_time_in_server", 3600); cq_ = server_builder->AddCompletionQueue(); } -Status GrpcEagerServiceImpl::CreateMasterContext( +absl::Status GrpcEagerServiceImpl::CreateMasterContext( const tensorflow::uint64 context_id, EagerContext* context) { return local_impl_.CreateMasterContext(context_id, context); } diff --git a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h index 88d49312fa87d1..7417b9a74a754d 100644 --- a/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service_impl.h @@ -45,8 +45,8 @@ class GrpcEagerServiceImpl : public tsl::AsyncServiceInterface { virtual ~GrpcEagerServiceImpl() {} // Create a master context in eager service. - Status CreateMasterContext(const tensorflow::uint64 context_id, - EagerContext* context); + absl::Status CreateMasterContext(const tensorflow::uint64 context_id, + EagerContext* context); void HandleRPCsLoop() override; void Shutdown() override; @@ -92,12 +92,12 @@ class GrpcEagerServiceImpl : public tsl::AsyncServiceInterface { env_->compute_pool->Schedule([this, call]() { auto call_opts = std::make_shared(); call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); - local_impl_.RunComponentFunction(call_opts.get(), &call->request, - &call->response, - [call, call_opts](const Status& s) { - call->ClearCancelCallback(); - call->SendResponse(ToGrpcStatus(s)); - }); + local_impl_.RunComponentFunction( + call_opts.get(), &call->request, &call->response, + [call, call_opts](const absl::Status& s) { + call->ClearCancelCallback(); + call->SendResponse(ToGrpcStatus(s)); + }); }); tsl::Call:: @@ -129,7 +129,7 @@ class GrpcEagerServiceImpl : public tsl::AsyncServiceInterface { // NOTE(fishx): Use the address of StreamingCall as the stream_id since we // reuse the same StreamingCall for multiple requests in the same // streaming connection. - Status status = local_impl_.Enqueue( + absl::Status status = local_impl_.Enqueue( /*call_opts=*/nullptr, &call->request(), call->mutable_response(), reinterpret_cast(static_cast(call))); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc index de9bd049d0d468..1039acd85ef9c2 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc @@ -54,6 +54,14 @@ class GrpcMasterService : public tsl::AsyncServiceInterface { is_shutdown_(false), default_session_config_(default_session_config) { builder->RegisterService(&master_service_); + // gRPC by default will cancel requests that sit in a completion queue for + // more than 30s. See + // https://github.com/grpc/grpc/blob/e52e48b7ef83feeff56ed0894ce39841ea8bd483/include/grpc/impl/channel_arg_names.h#L106-L111 + // Extending this to 1 hour for Tensorflow since some graphs may have + // periods of heavy load which may cause the server to run into these + // cancellations. + builder->AddChannelArgument("grpc.server_max_unrequested_time_in_server", + 3600); cq_ = builder->AddCompletionQueue(); } @@ -154,11 +162,12 @@ class GrpcMasterService : public tsl::AsyncServiceInterface { CreateSessionRequest* rewritten_req = new CreateSessionRequest; rewritten_req->mutable_config()->MergeFrom(default_session_config_); rewritten_req->MergeFrom(call->request); - master_impl_->CreateSession(rewritten_req, &call->response, - [call, rewritten_req](const Status& status) { - call->SendResponse(ToGrpcStatus(status)); - delete rewritten_req; - }); + master_impl_->CreateSession( + rewritten_req, &call->response, + [call, rewritten_req](const absl::Status& status) { + call->SendResponse(ToGrpcStatus(status)); + delete rewritten_req; + }); ENQUEUE_REQUEST(CreateSession, true); } @@ -166,7 +175,7 @@ class GrpcMasterService : public tsl::AsyncServiceInterface { void ExtendSessionHandler( MasterCall* call) { master_impl_->ExtendSession(&call->request, &call->response, - [call](const Status& status) { + [call](const absl::Status& status) { call->SendResponse(ToGrpcStatus(status)); }); ENQUEUE_REQUEST(ExtendSession, false); @@ -176,7 +185,7 @@ class GrpcMasterService : public tsl::AsyncServiceInterface { void PartialRunSetupHandler( MasterCall* call) { master_impl_->PartialRunSetup(&call->request, &call->response, - [call](const Status& status) { + [call](const absl::Status& status) { call->SendResponse(ToGrpcStatus(status)); }); ENQUEUE_REQUEST(PartialRunSetup, false); @@ -199,7 +208,7 @@ class GrpcMasterService : public tsl::AsyncServiceInterface { master_impl_->RunStep( call_opts, wrapped_request, wrapped_response, [call, call_opts, wrapped_request, wrapped_response, - trace](const Status& status) { + trace](const absl::Status& status) { call->ClearCancelCallback(); delete call_opts; delete wrapped_request; @@ -222,7 +231,7 @@ class GrpcMasterService : public tsl::AsyncServiceInterface { void CloseSessionHandler( MasterCall* call) { master_impl_->CloseSession(&call->request, &call->response, - [call](const Status& status) { + [call](const absl::Status& status) { call->SendResponse(ToGrpcStatus(status)); }); ENQUEUE_REQUEST(CloseSession, false); @@ -232,7 +241,7 @@ class GrpcMasterService : public tsl::AsyncServiceInterface { void ListDevicesHandler( MasterCall* call) { master_impl_->ListDevices(&call->request, &call->response, - [call](const Status& status) { + [call](const absl::Status& status) { call->SendResponse(ToGrpcStatus(status)); }); ENQUEUE_REQUEST(ListDevices, false); @@ -241,7 +250,7 @@ class GrpcMasterService : public tsl::AsyncServiceInterface { // RPC handler for resetting all sessions. void ResetHandler(MasterCall* call) { master_impl_->Reset(&call->request, &call->response, - [call](const Status& status) { + [call](const absl::Status& status) { call->SendResponse(ToGrpcStatus(status)); }); ENQUEUE_REQUEST(Reset, false); @@ -251,7 +260,7 @@ class GrpcMasterService : public tsl::AsyncServiceInterface { void MakeCallableHandler( MasterCall* call) { master_impl_->MakeCallable(&call->request, &call->response, - [call](const Status& status) { + [call](const absl::Status& status) { call->SendResponse(ToGrpcStatus(status)); }); ENQUEUE_REQUEST(MakeCallable, false); @@ -267,13 +276,14 @@ class GrpcMasterService : public tsl::AsyncServiceInterface { // `MasterSession` implementation. call_opts->SetTimeout(default_session_config_.operation_timeout_in_ms()); call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); - master_impl_->RunCallable(call_opts, &call->request, &call->response, - [call, call_opts, trace](const Status& status) { - call->ClearCancelCallback(); - delete call_opts; - delete trace; - call->SendResponse(ToGrpcStatus(status)); - }); + master_impl_->RunCallable( + call_opts, &call->request, &call->response, + [call, call_opts, trace](const absl::Status& status) { + call->ClearCancelCallback(); + delete call_opts; + delete trace; + call->SendResponse(ToGrpcStatus(status)); + }); ENQUEUE_REQUEST(RunCallable, false); } @@ -281,7 +291,7 @@ class GrpcMasterService : public tsl::AsyncServiceInterface { void ReleaseCallableHandler( MasterCall* call) { master_impl_->ReleaseCallable(&call->request, &call->response, - [call](const Status& status) { + [call](const absl::Status& status) { call->SendResponse(ToGrpcStatus(status)); }); ENQUEUE_REQUEST(ReleaseCallable, false); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc index d180b4dc236451..803d543aee63b7 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc @@ -45,69 +45,70 @@ class GrpcRemoteMaster : public MasterInterface { ~GrpcRemoteMaster() override {} - Status CreateSession(CallOptions* call_options, - const CreateSessionRequest* request, - CreateSessionResponse* response) override { + absl::Status CreateSession(CallOptions* call_options, + const CreateSessionRequest* request, + CreateSessionResponse* response) override { return CallWithRetry(call_options, request, response, &MasterServiceStub::CreateSession); } - Status ExtendSession(CallOptions* call_options, - const ExtendSessionRequest* request, - ExtendSessionResponse* response) override { + absl::Status ExtendSession(CallOptions* call_options, + const ExtendSessionRequest* request, + ExtendSessionResponse* response) override { return CallWithRetry(call_options, request, response, &MasterServiceStub::ExtendSession); } - Status PartialRunSetup(CallOptions* call_options, - const PartialRunSetupRequest* request, - PartialRunSetupResponse* response) override { + absl::Status PartialRunSetup(CallOptions* call_options, + const PartialRunSetupRequest* request, + PartialRunSetupResponse* response) override { return CallWithRetry(call_options, request, response, &MasterServiceStub::PartialRunSetup); } - Status RunStep(CallOptions* call_options, RunStepRequestWrapper* request, - MutableRunStepResponseWrapper* response) override { + absl::Status RunStep(CallOptions* call_options, + RunStepRequestWrapper* request, + MutableRunStepResponseWrapper* response) override { return CallWithRetry(call_options, &request->ToProto(), get_proto_from_wrapper(response), &MasterServiceStub::RunStep, "RunStep/Client"); } - Status CloseSession(CallOptions* call_options, - const CloseSessionRequest* request, - CloseSessionResponse* response) override { + absl::Status CloseSession(CallOptions* call_options, + const CloseSessionRequest* request, + CloseSessionResponse* response) override { return CallWithRetry(call_options, request, response, &MasterServiceStub::CloseSession); } - Status ListDevices(CallOptions* call_options, - const ListDevicesRequest* request, - ListDevicesResponse* response) override { + absl::Status ListDevices(CallOptions* call_options, + const ListDevicesRequest* request, + ListDevicesResponse* response) override { return CallWithRetry(call_options, request, response, &MasterServiceStub::ListDevices); } - Status Reset(CallOptions* call_options, const ResetRequest* request, - ResetResponse* response) override { + absl::Status Reset(CallOptions* call_options, const ResetRequest* request, + ResetResponse* response) override { return CallWithRetry(call_options, request, response, &MasterServiceStub::Reset); } - Status MakeCallable(CallOptions* call_options, - const MakeCallableRequest* request, - MakeCallableResponse* response) override { + absl::Status MakeCallable(CallOptions* call_options, + const MakeCallableRequest* request, + MakeCallableResponse* response) override { return CallWithRetry(call_options, request, response, &MasterServiceStub::MakeCallable); } - Status RunCallable(CallOptions* call_options, - const RunCallableRequest* request, - RunCallableResponse* response) override { + absl::Status RunCallable(CallOptions* call_options, + const RunCallableRequest* request, + RunCallableResponse* response) override { return CallWithRetry(call_options, request, response, &MasterServiceStub::RunCallable); } - Status ReleaseCallable(CallOptions* call_options, - const ReleaseCallableRequest* request, - ReleaseCallableResponse* response) override { + absl::Status ReleaseCallable(CallOptions* call_options, + const ReleaseCallableRequest* request, + ReleaseCallableResponse* response) override { return CallWithRetry(call_options, request, response, &MasterServiceStub::ReleaseCallable); } @@ -124,17 +125,17 @@ class GrpcRemoteMaster : public MasterInterface { } template - Status CallWithRetry(CallOptions* call_options, const Request* request, - Response* response, - ::grpc::Status (MasterServiceStub::*pfunc)( - ::grpc::ClientContext*, const Request&, Response*), - string trace_string = {}) { + absl::Status CallWithRetry( + CallOptions* call_options, const Request* request, Response* response, + ::grpc::Status (MasterServiceStub::*pfunc)(::grpc::ClientContext*, + const Request&, Response*), + string trace_string = {}) { absl::Duration timeout = absl::Milliseconds(call_options->GetTimeout()); absl::Time expired_time = absl::FromUnixMicros(Env::Default()->NowMicros()); if (timeout > absl::ZeroDuration()) { expired_time += timeout; } - Status s; + absl::Status s; for (int num_retries = 0;; ++num_retries) { ::grpc::ClientContext ctx; std::unique_ptr trace; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc index d1ad012a28f222..71a7e0ff1f5e2a 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc @@ -135,7 +135,7 @@ class GrpcRemoteWorker : public WorkerInterface { bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2); auto callback = [this, request, response, done, start_usec, - logging_active](Status s) { + logging_active](absl::Status s) { if (logging_active) { if (logger_->LoggingActive()) { int64_t end_usec = Env::Default()->NowMicros(); @@ -204,7 +204,7 @@ class GrpcRemoteWorker : public WorkerInterface { bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2); auto callback = [this, request, response, done, start_usec, - logging_active](Status s) { + logging_active](absl::Status s) { if (logging_active) { if (logger_->LoggingActive()) { int64_t end_usec = Env::Default()->NowMicros(); @@ -296,7 +296,7 @@ class GrpcRemoteWorker : public WorkerInterface { request.set_request_id(request_id); MarkRecvFinishedResponse* response = new MarkRecvFinishedResponse(); - auto done = [response](Status status) { delete response; }; + auto done = [response](absl::Status status) { delete response; }; IssueRequest(&request, response, markrecvfinished_, done); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index 61861a670b8e9e..5b78748a909c06 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -136,8 +136,8 @@ GrpcServer::~GrpcServer() { } // Look up the requested host name and port for this task in `server_def`. -Status GrpcServer::GetHostAndPort(const ServerDef& server_def, - string* host_name, int* port) const { +absl::Status GrpcServer::GetHostAndPort(const ServerDef& server_def, + string* host_name, int* port) const { *port = -1; *host_name = "localhost"; for (const auto& job : server_def.cluster().job()) { @@ -182,7 +182,7 @@ Status GrpcServer::GetHostAndPort(const ServerDef& server_def, return absl::OkStatus(); } -Status GrpcServer::Init(const GrpcServerOptions& opts) { +absl::Status GrpcServer::Init(const GrpcServerOptions& opts) { mutex_lock l(mu_); CHECK_EQ(state_, NEW); master_env_.env = env_; @@ -259,7 +259,7 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) { builder.SetMaxMessageSize(std::numeric_limits::max()); bool reuse_port = false; - const Status status = + const absl::Status status = ReadBoolFromEnvVar("TF_GRPC_REUSE_PORT", false, &reuse_port); if (!status.ok()) { LOG(ERROR) << status.message(); @@ -366,8 +366,8 @@ Status GrpcServer::Init(const GrpcServerOptions& opts) { return absl::OkStatus(); } -Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options, - GrpcChannelSpec* channel_spec) { +absl::Status GrpcServer::ParseChannelSpec( + const WorkerCacheFactoryOptions& options, GrpcChannelSpec* channel_spec) { for (const auto& job : options.cluster_def.job()) { std::map host_ports; for (const auto& task : job.tasks()) { @@ -393,10 +393,11 @@ Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options, return absl::OkStatus(); } -Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options, - WorkerCacheInterface** worker_cache) { +absl::Status GrpcServer::WorkerCacheFactory( + const WorkerCacheFactoryOptions& options, + WorkerCacheInterface** worker_cache) { if (options.job_name.empty()) { - Status s = errors::InvalidArgument( + absl::Status s = errors::InvalidArgument( "The master (current machine) is not included in the provided " "cluster_def. ", options.cluster_def.DebugString()); @@ -432,7 +433,7 @@ Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options, return absl::OkStatus(); } -Status GrpcServer::Start() { +absl::Status GrpcServer::Start() { mutex_lock l(mu_); switch (state_) { case NEW: { @@ -474,14 +475,14 @@ Status GrpcServer::Start() { } } -Status GrpcServer::AddMasterEagerContextToEagerService( +absl::Status GrpcServer::AddMasterEagerContextToEagerService( const tensorflow::uint64 context_id, tensorflow::EagerContext* context) { auto* eager_service = static_cast(eager_service_); return eager_service->CreateMasterContext(context_id, context); } -Status GrpcServer::UpdateServerDef(const ServerDef& server_def) { +absl::Status GrpcServer::UpdateServerDef(const ServerDef& server_def) { mutex_lock l(mu_); server_def_ = server_def; WorkerCacheInterface* worker_cache; @@ -514,7 +515,7 @@ Status GrpcServer::UpdateServerDef(const ServerDef& server_def) { // TODO(haoyuzhang): Remove this method once we have a mechanism to directly set // field inside the RPC coordination service handler. -Status GrpcServer::SetCoordinationServiceAgentInstance( +absl::Status GrpcServer::SetCoordinationServiceAgentInstance( tsl::CoordinationServiceAgent* agent) { auto* coord_service = static_cast(coordination_service_); @@ -522,7 +523,7 @@ Status GrpcServer::SetCoordinationServiceAgentInstance( return absl::OkStatus(); } -Status GrpcServer::SetCoordinationServiceInstance( +absl::Status GrpcServer::SetCoordinationServiceInstance( tsl::CoordinationServiceInterface* service) { auto* coord_service = static_cast(coordination_service_); @@ -530,7 +531,7 @@ Status GrpcServer::SetCoordinationServiceInstance( return absl::OkStatus(); } -Status GrpcServer::StopCoordinationService() { +absl::Status GrpcServer::StopCoordinationService() { // Note: the sequence of events is important here. // 1. Agent must be torn down before the service as it needs to notify the // service. @@ -544,7 +545,7 @@ Status GrpcServer::StopCoordinationService() { return absl::OkStatus(); } -Status GrpcServer::Stop() { +absl::Status GrpcServer::Stop() { mutex_lock l(mu_); switch (state_) { case NEW: @@ -561,7 +562,7 @@ Status GrpcServer::Stop() { } } -Status GrpcServer::Join() { +absl::Status GrpcServer::Join() { mutex_lock l(mu_); switch (state_) { case NEW: @@ -602,15 +603,15 @@ std::unique_ptr GrpcServer::CreateMaster(MasterEnv* master_env) { } /* static */ -Status GrpcServer::Create(const ServerDef& server_def, Env* env, - DeviceMgr* local_device_mgr, - std::unique_ptr* out_server) { +absl::Status GrpcServer::Create(const ServerDef& server_def, Env* env, + DeviceMgr* local_device_mgr, + std::unique_ptr* out_server) { std::unique_ptr ret( new GrpcServer(server_def, env == nullptr ? Env::Default() : env)); GrpcServerOptions options; options.rendezvous_mgr_func = NewRpcRendezvousMgr; options.local_device_mgr = local_device_mgr; - Status s = ret->Init(options); + absl::Status s = ret->Init(options); if (!s.ok()) { LOG(ERROR) << s; return s; @@ -620,16 +621,16 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env, } /* static */ -Status GrpcServer::Create(const ServerDef& server_def, Env* env, - std::unique_ptr* out_server) { +absl::Status GrpcServer::Create(const ServerDef& server_def, Env* env, + std::unique_ptr* out_server) { return Create(server_def, env, nullptr, out_server); } /* static */ -Status GrpcServer::Create(const ServerDef& server_def, Env* env, - std::unique_ptr* out_server) { +absl::Status GrpcServer::Create(const ServerDef& server_def, Env* env, + std::unique_ptr* out_server) { std::unique_ptr server; - Status s = Create(server_def, env, nullptr, &server); + absl::Status s = Create(server_def, env, nullptr, &server); if (!s.ok()) { return s; } @@ -645,8 +646,9 @@ class GrpcServerFactory : public ServerFactory { return server_def.protocol() == "grpc"; } - Status NewServer(const ServerDef& server_def, const Options& options, - std::unique_ptr* out_server) override { + absl::Status NewServer( + const ServerDef& server_def, const Options& options, + std::unique_ptr* out_server) override { return GrpcServer::Create(server_def, Env::Default(), options.local_device_mgr, out_server); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h index 383955e54675f2..ca162c193d3b15 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -87,23 +87,23 @@ class GrpcServer : public ServerInterface { int requested_port) {} public: - static Status Create(const ServerDef& server_def, Env* env, - std::unique_ptr* out_server); - static Status Create(const ServerDef& server_def, Env* env, - std::unique_ptr* out_server); + static absl::Status Create(const ServerDef& server_def, Env* env, + std::unique_ptr* out_server); + static absl::Status Create(const ServerDef& server_def, Env* env, + std::unique_ptr* out_server); // Reuse the local_device_mgr. - static Status Create(const ServerDef& server_def, Env* env, - DeviceMgr* local_device_mgr, - std::unique_ptr* out_server); + static absl::Status Create(const ServerDef& server_def, Env* env, + DeviceMgr* local_device_mgr, + std::unique_ptr* out_server); // Destruction is only supported in the factory method. Clean // shutdown is not currently implemented for this server type. virtual ~GrpcServer(); // Implementations of ServerInterface methods. - Status Start() override; - Status Stop() override; - Status Join() override; + absl::Status Start() override; + absl::Status Stop() override; + absl::Status Join() override; const string target() const override; WorkerEnv* worker_env() override { return &worker_env_; } @@ -111,22 +111,22 @@ class GrpcServer : public ServerInterface { // Add master eager context to local eager service in order to handle enqueue // requests from remote workers. - Status AddMasterEagerContextToEagerService( + absl::Status AddMasterEagerContextToEagerService( const tensorflow::uint64 context_id, tensorflow::EagerContext* context) override; // Update the set of workers that can be reached by the GRPC server - Status UpdateServerDef(const ServerDef& server_def) override; + absl::Status UpdateServerDef(const ServerDef& server_def) override; // Pass coordination service agent instance to server's RPC handler - Status SetCoordinationServiceAgentInstance( + absl::Status SetCoordinationServiceAgentInstance( tsl::CoordinationServiceAgent* agent) override; // TODO(hanyangtay): Remove this method once gRPC server clean shutdown is // supported. - Status StopCoordinationService() override; + absl::Status StopCoordinationService() override; protected: - virtual Status GetHostAndPort(const ServerDef& server_def, string* host_name, - int* port) const; - Status Init(const GrpcServerOptions& opts = GrpcServerOptions()); + virtual absl::Status GetHostAndPort(const ServerDef& server_def, + string* host_name, int* port) const; + absl::Status Init(const GrpcServerOptions& opts = GrpcServerOptions()); // A subclass can override this method to support secure credentials. virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials( @@ -137,8 +137,9 @@ class GrpcServer : public ServerInterface { virtual std::unique_ptr CreateMaster(MasterEnv* master_env); // Creates a WorkerCacheInterface for a session. - virtual Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options, - WorkerCacheInterface** worker_cache); + virtual absl::Status WorkerCacheFactory( + const WorkerCacheFactoryOptions& options, + WorkerCacheInterface** worker_cache); // Override to return extra services to be brought up and managed along with // the standard {master, worker, eager} services. The map key is an aribtrary @@ -159,8 +160,8 @@ class GrpcServer : public ServerInterface { } // Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec. - Status ParseChannelSpec(const WorkerCacheFactoryOptions& options, - GrpcChannelSpec* channel_spec); + absl::Status ParseChannelSpec(const WorkerCacheFactoryOptions& options, + GrpcChannelSpec* channel_spec); // Returns the port to which this server is bound. // This method may only be called after `this->Init()` returns successfully. @@ -173,7 +174,7 @@ class GrpcServer : public ServerInterface { GrpcWorker* worker_impl() const { return worker_impl_.get(); } GrpcWorkerEnv* grpc_worker_env() const { return grpc_worker_env_.get(); } - Status SetCoordinationServiceInstance( + absl::Status SetCoordinationServiceInstance( tsl::CoordinationServiceInterface* service); private: diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc index 7911ea2e59dc03..3d979d1b79ee52 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc @@ -43,8 +43,8 @@ GrpcSession::GrpcSession(const SessionOptions& options) GrpcSession::~GrpcSession() {} /* static */ -Status GrpcSession::Create(const SessionOptions& options, - std::unique_ptr* out_session) { +absl::Status GrpcSession::Create(const SessionOptions& options, + std::unique_ptr* out_session) { std::unique_ptr session(new GrpcSession(options)); std::unique_ptr master; // For testing, we enable the client to disable the use of the local @@ -103,7 +103,7 @@ void GrpcSession::SetHandleAndGraphVersion(string handle, current_graph_version_ = graph_version; } -Status GrpcSession::Handle(string* out_handle) { +absl::Status GrpcSession::Handle(string* out_handle) { mutex_lock l(mu_); if (handle_.empty()) { return errors::InvalidArgument("A session is not created yet...."); @@ -112,7 +112,8 @@ Status GrpcSession::Handle(string* out_handle) { return absl::OkStatus(); } -Status GrpcSession::CreateImpl(CallOptions* call_options, GraphDef graph) { +absl::Status GrpcSession::CreateImpl(CallOptions* call_options, + GraphDef graph) { { mutex_lock l(mu_); if (!handle_.empty()) { @@ -125,35 +126,37 @@ Status GrpcSession::CreateImpl(CallOptions* call_options, GraphDef graph) { req.set_target(options_.target); ReEncodeConsts(req.mutable_graph_def()); CreateSessionResponse resp; - Status s = master_->CreateSession(call_options, &req, &resp); + absl::Status s = master_->CreateSession(call_options, &req, &resp); if (s.ok()) { SetHandleAndGraphVersion(resp.session_handle(), resp.graph_version()); } return s; } -Status GrpcSession::Create(const GraphDef& graph) { +absl::Status GrpcSession::Create(const GraphDef& graph) { return Create(GraphDef(graph)); } -Status GrpcSession::Create(const RunOptions& run_options, - const GraphDef& graph) { +absl::Status GrpcSession::Create(const RunOptions& run_options, + const GraphDef& graph) { return Create(run_options, GraphDef(graph)); } -Status GrpcSession::Create(GraphDef&& graph) { +absl::Status GrpcSession::Create(GraphDef&& graph) { CallOptions call_options; call_options.SetTimeout(options_.config.operation_timeout_in_ms()); return CreateImpl(&call_options, std::move(graph)); } -Status GrpcSession::Create(const RunOptions& run_options, GraphDef&& graph) { +absl::Status GrpcSession::Create(const RunOptions& run_options, + GraphDef&& graph) { CallOptions call_options; call_options.SetTimeout(run_options.timeout_in_ms()); return CreateImpl(&call_options, std::move(graph)); } -Status GrpcSession::ExtendImpl(CallOptions* call_options, GraphDef graph) { +absl::Status GrpcSession::ExtendImpl(CallOptions* call_options, + GraphDef graph) { bool handle_is_empty; { mutex_lock l(mu_); @@ -169,35 +172,36 @@ Status GrpcSession::ExtendImpl(CallOptions* call_options, GraphDef graph) { req.mutable_graph_def()->Swap(&graph); req.set_current_graph_version(current_graph_version_); ExtendSessionResponse resp; - Status s = master_->ExtendSession(call_options, &req, &resp); + absl::Status s = master_->ExtendSession(call_options, &req, &resp); if (s.ok()) { current_graph_version_ = resp.new_graph_version(); } return s; } -Status GrpcSession::Extend(const GraphDef& graph) { +absl::Status GrpcSession::Extend(const GraphDef& graph) { return Extend(GraphDef(graph)); } -Status GrpcSession::Extend(const RunOptions& run_options, - const GraphDef& graph) { +absl::Status GrpcSession::Extend(const RunOptions& run_options, + const GraphDef& graph) { return Extend(run_options, GraphDef(graph)); } -Status GrpcSession::Extend(GraphDef&& graph) { +absl::Status GrpcSession::Extend(GraphDef&& graph) { CallOptions call_options; call_options.SetTimeout(options_.config.operation_timeout_in_ms()); return ExtendImpl(&call_options, std::move(graph)); } -Status GrpcSession::Extend(const RunOptions& run_options, GraphDef&& graph) { +absl::Status GrpcSession::Extend(const RunOptions& run_options, + GraphDef&& graph) { CallOptions call_options; call_options.SetTimeout(run_options.timeout_in_ms()); return ExtendImpl(&call_options, std::move(graph)); } -Status GrpcSession::RunHelper( +absl::Status GrpcSession::RunHelper( const RunOptions& run_options, const std::vector>& inputs, const std::vector& output_tensor_names, @@ -284,39 +288,40 @@ Status GrpcSession::RunHelper( return absl::OkStatus(); } -Status GrpcSession::Run(const RunOptions& run_options, - const std::vector>& inputs, - const std::vector& output_tensor_names, - const std::vector& target_node_names, - std::vector* outputs, - RunMetadata* run_metadata) { +absl::Status GrpcSession::Run( + const RunOptions& run_options, + const std::vector>& inputs, + const std::vector& output_tensor_names, + const std::vector& target_node_names, std::vector* outputs, + RunMetadata* run_metadata) { return RunHelper(run_options, inputs, output_tensor_names, target_node_names, outputs, run_metadata, /* prun_handle */ ""); } -Status GrpcSession::Run(const std::vector>& inputs, - const std::vector& output_tensor_names, - const std::vector& target_node_names, - std::vector* outputs) { +absl::Status GrpcSession::Run( + const std::vector>& inputs, + const std::vector& output_tensor_names, + const std::vector& target_node_names, + std::vector* outputs) { RunOptions run_options; run_options.set_timeout_in_ms(options_.config.operation_timeout_in_ms()); return Run(run_options, inputs, output_tensor_names, target_node_names, outputs, nullptr); } -Status GrpcSession::RunProto(CallOptions* call_options, - MutableRunStepRequestWrapper* req, - MutableRunStepResponseWrapper* resp) { +absl::Status GrpcSession::RunProto(CallOptions* call_options, + MutableRunStepRequestWrapper* req, + MutableRunStepResponseWrapper* resp) { string handle; TF_RETURN_IF_ERROR(Handle(&handle)); req->set_session_handle(handle); return master_->RunStep(call_options, req, resp); } -Status GrpcSession::PRunSetup(const std::vector& input_names, - const std::vector& output_names, - const std::vector& target_nodes, - string* handle) { +absl::Status GrpcSession::PRunSetup(const std::vector& input_names, + const std::vector& output_names, + const std::vector& target_nodes, + string* handle) { // Convert to proto PartialRunSetupRequest req; PartialRunSetupResponse resp; @@ -338,17 +343,16 @@ Status GrpcSession::PRunSetup(const std::vector& input_names, return absl::OkStatus(); } -Status GrpcSession::PRun(const string& handle, - const std::vector>& inputs, - const std::vector& output_names, - std::vector* outputs) { +absl::Status GrpcSession::PRun( + const string& handle, const std::vector>& inputs, + const std::vector& output_names, std::vector* outputs) { RunOptions run_options; run_options.set_timeout_in_ms(options_.config.operation_timeout_in_ms()); return RunHelper(run_options, inputs, output_names, /* targets */ {}, outputs, /* run_metadata */ nullptr, handle); } -Status GrpcSession::Close() { +absl::Status GrpcSession::Close() { CloseSessionRequest req; { mutex_lock l(mu_); @@ -364,7 +368,7 @@ Status GrpcSession::Close() { return master_->CloseSession(&call_options, &req, &resp); } -Status GrpcSession::ListDevices(std::vector* response) { +absl::Status GrpcSession::ListDevices(std::vector* response) { ListDevicesRequest req; { mutex_lock l(mu_); @@ -384,7 +388,7 @@ Status GrpcSession::ListDevices(std::vector* response) { ListDevicesResponse resp; CallOptions call_options; call_options.SetTimeout(options_.config.operation_timeout_in_ms()); - Status s = master_->ListDevices(&call_options, &req, &resp); + absl::Status s = master_->ListDevices(&call_options, &req, &resp); if (!s.ok()) { LOG(ERROR) << "Could not list devices: " << s; return s; @@ -406,8 +410,8 @@ void GrpcSession::SetRemoteMaster(std::unique_ptr master) { } // Static method. -Status GrpcSession::Reset(const SessionOptions& options, - const std::vector& containers) { +absl::Status GrpcSession::Reset(const SessionOptions& options, + const std::vector& containers) { SharedGrpcChannelPtr master_channel; TF_RETURN_IF_ERROR( NewHostPortGrpcChannel(options.target.substr(kSchemePrefixLength), @@ -419,13 +423,13 @@ Status GrpcSession::Reset(const SessionOptions& options, ResetResponse resp; CallOptions call_options; call_options.SetTimeout(options.config.operation_timeout_in_ms()); - Status ret = master->Reset(&call_options, &req, &resp); + absl::Status ret = master->Reset(&call_options, &req, &resp); delete master; return ret; } -Status GrpcSession::MakeCallable(const CallableOptions& callable_options, - CallableHandle* out_handle) { +absl::Status GrpcSession::MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle) { MakeCallableRequest req; TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle())); *req.mutable_options() = callable_options; @@ -438,10 +442,10 @@ Status GrpcSession::MakeCallable(const CallableOptions& callable_options, return absl::OkStatus(); } -Status GrpcSession::RunCallable(CallableHandle handle, - const std::vector& feed_tensors, - std::vector* fetch_tensors, - RunMetadata* run_metadata) { +absl::Status GrpcSession::RunCallable(CallableHandle handle, + const std::vector& feed_tensors, + std::vector* fetch_tensors, + RunMetadata* run_metadata) { RunCallableRequest req; TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle())); req.set_handle(handle); @@ -465,7 +469,7 @@ Status GrpcSession::RunCallable(CallableHandle handle, return absl::OkStatus(); } -Status GrpcSession::ReleaseCallable(CallableHandle handle) { +absl::Status GrpcSession::ReleaseCallable(CallableHandle handle) { ReleaseCallableRequest req; TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle())); req.set_handle(handle); @@ -481,8 +485,8 @@ class GrpcSessionFactory : public SessionFactory { return absl::StartsWith(options.target, kSchemePrefix); } - Status NewSession(const SessionOptions& options, - Session** out_session) override { + absl::Status NewSession(const SessionOptions& options, + Session** out_session) override { std::unique_ptr session; TF_RETURN_IF_ERROR(GrpcSession::Create(options, &session)); *out_session = session.release(); @@ -490,8 +494,8 @@ class GrpcSessionFactory : public SessionFactory { } // Invokes the session specific static method to reset containers. - Status Reset(const SessionOptions& options, - const std::vector& containers) override { + absl::Status Reset(const SessionOptions& options, + const std::vector& containers) override { return GrpcSession::Reset(options, containers); } }; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.h b/tensorflow/core/distributed_runtime/rpc/grpc_session.h index b0278b855d3b7c..fe92f7c073251e 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.h @@ -51,61 +51,64 @@ class GrpcSession : public Session { explicit GrpcSession(const SessionOptions& options); public: - static Status Create(const SessionOptions& options, - std::unique_ptr* out_session); + static absl::Status Create(const SessionOptions& options, + std::unique_ptr* out_session); // Resets the resource containers. - static Status Reset(const SessionOptions& options, - const std::vector& containers); + static absl::Status Reset(const SessionOptions& options, + const std::vector& containers); ~GrpcSession() override; // Creates a session with the "target". The session carries out // the graph computation defined by "graph", and will have version // number "initial_version". - Status Create(const GraphDef& graph) override; - Status Create(const RunOptions& run_options, const GraphDef& graph) override; - Status Create(GraphDef&& graph) override; - Status Create(const RunOptions& run_options, GraphDef&& graph) override; + absl::Status Create(const GraphDef& graph) override; + absl::Status Create(const RunOptions& run_options, + const GraphDef& graph) override; + absl::Status Create(GraphDef&& graph) override; + absl::Status Create(const RunOptions& run_options, GraphDef&& graph) override; // Runs with and without RunOptions. - Status Run(const std::vector >& inputs, - const std::vector& output_tensor_names, - const std::vector& target_node_names, - std::vector* outputs) override; - Status Run(const RunOptions& run_options, - const std::vector >& inputs, - const std::vector& output_tensor_names, - const std::vector& target_node_names, - std::vector* outputs, RunMetadata* run_metadata) override; - - Status Extend(const GraphDef& graph) override; - Status Extend(const RunOptions& run_options, const GraphDef& graph) override; - Status Extend(GraphDef&& graph) override; - Status Extend(const RunOptions& run_options, GraphDef&& graph) override; - - Status Close() override; + absl::Status Run(const std::vector >& inputs, + const std::vector& output_tensor_names, + const std::vector& target_node_names, + std::vector* outputs) override; + absl::Status Run(const RunOptions& run_options, + const std::vector >& inputs, + const std::vector& output_tensor_names, + const std::vector& target_node_names, + std::vector* outputs, + RunMetadata* run_metadata) override; + + absl::Status Extend(const GraphDef& graph) override; + absl::Status Extend(const RunOptions& run_options, + const GraphDef& graph) override; + absl::Status Extend(GraphDef&& graph) override; + absl::Status Extend(const RunOptions& run_options, GraphDef&& graph) override; + + absl::Status Close() override; // NOTE: This API is still experimental and may change. - Status PRunSetup(const std::vector& input_names, - const std::vector& output_names, - const std::vector& target_nodes, - string* handle) override; + absl::Status PRunSetup(const std::vector& input_names, + const std::vector& output_names, + const std::vector& target_nodes, + string* handle) override; // NOTE: This API is still experimental and may change. - Status PRun(const string& handle, - const std::vector >& inputs, - const std::vector& output_names, - std::vector* outputs) override; + absl::Status PRun(const string& handle, + const std::vector >& inputs, + const std::vector& output_names, + std::vector* outputs) override; - Status ListDevices(std::vector* response) override; + absl::Status ListDevices(std::vector* response) override; - Status MakeCallable(const CallableOptions& callable_options, - CallableHandle* out_handle) override; - Status RunCallable(CallableHandle handle, - const std::vector& feed_tensors, - std::vector* fetch_tensors, - RunMetadata* run_metadata) override; - Status ReleaseCallable(CallableHandle handle) override; + absl::Status MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle) override; + absl::Status RunCallable(CallableHandle handle, + const std::vector& feed_tensors, + std::vector* fetch_tensors, + RunMetadata* run_metadata) override; + absl::Status ReleaseCallable(CallableHandle handle) override; protected: // Takes ownership of `*master`. @@ -127,21 +130,22 @@ class GrpcSession : public Session { bool is_local_ = false; - Status Handle(string* out_handle) TF_LOCKS_EXCLUDED(mu_); + absl::Status Handle(string* out_handle) TF_LOCKS_EXCLUDED(mu_); - Status RunHelper(const RunOptions& run_options, - const std::vector >& inputs, - const std::vector& output_tensor_names, - const std::vector& target_node_names, - std::vector* outputs, RunMetadata* run_metadata, - const string& prun_handle); + absl::Status RunHelper(const RunOptions& run_options, + const std::vector >& inputs, + const std::vector& output_tensor_names, + const std::vector& target_node_names, + std::vector* outputs, + RunMetadata* run_metadata, const string& prun_handle); - Status RunProto(CallOptions* call_options, MutableRunStepRequestWrapper* req, - MutableRunStepResponseWrapper* resp); + absl::Status RunProto(CallOptions* call_options, + MutableRunStepRequestWrapper* req, + MutableRunStepResponseWrapper* resp); // Implementations for all the public interfaces. - Status CreateImpl(CallOptions* call_options, GraphDef graph); - Status ExtendImpl(CallOptions* call_options, GraphDef graph); + absl::Status CreateImpl(CallOptions* call_options, GraphDef graph); + absl::Status ExtendImpl(CallOptions* call_options, GraphDef graph); GrpcSession(const GrpcSession&) = delete; void operator=(const GrpcSession&) = delete; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc index c1026dc273136c..1b5ae927544a14 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc @@ -208,7 +208,7 @@ TEST(GrpcSessionTest, CallableWithOnDeviceFeedsAndFetches) { opts.mutable_fetch_devices()->insert({fetch, device_name}); Session::CallableHandle handle; - Status status = session->MakeCallable(opts, &handle); + absl::Status status = session->MakeCallable(opts, &handle); EXPECT_EQ(error::UNIMPLEMENTED, status.code()); TF_ASSERT_OK(session->Close()); } @@ -344,8 +344,8 @@ TEST(GrpcSessionTest, DisableOutputPartitionGraphs) { RunOptions run_options; run_options.set_output_partition_graphs(true); RunMetadata run_metadata; - Status s = session->Run(run_options, {}, {}, {node_names[2]}, nullptr, - &run_metadata); + absl::Status s = session->Run(run_options, {}, {}, {node_names[2]}, nullptr, + &run_metadata); EXPECT_TRUE(errors::IsInvalidArgument(s)); EXPECT_TRUE( absl::StrContains(s.message(), "disable_output_partition_graphs")); @@ -601,7 +601,7 @@ TEST(GrpcSessionTest, MultiDevices_String) { SetDevice(&def, a->name(), a_dev.name()); SetDevice(&def, b->name(), b_dev.name()); - Status s = session->Create(def); + absl::Status s = session->Create(def); if (s.ok()) { std::vector outputs; TF_ASSERT_OK(session->Run({}, {b->name()}, {}, &outputs)); @@ -727,7 +727,7 @@ TEST(GrpcSessionTest, Error) { TF_ASSERT_OK(session->Create(gdef)); { - Status status = session->Run({}, fetches, {}, nullptr); + absl::Status status = session->Run({}, fetches, {}, nullptr); EXPECT_FALSE(status.ok()); EXPECT_NE(status.ToString().find("fantasia!"), string::npos); } @@ -788,7 +788,7 @@ TEST(GrpcSessionTest, ErrorStatusLog) { TF_ASSERT_OK(session->Create(gdef)); { - Status status = session->Run({}, fetches, {}, nullptr); + absl::Status status = session->Run({}, fetches, {}, nullptr); EXPECT_FALSE(status.ok()); std::cerr << status << "\n"; EXPECT_NE(status.ToString().find("fantasia!"), string::npos); @@ -854,7 +854,7 @@ TEST(GrpcSessionTest, LongErrorMessage) { TF_ASSERT_OK(session->Create(gdef)); { - Status status = session->Run({}, fetches, {}, nullptr); + absl::Status status = session->Run({}, fetches, {}, nullptr); EXPECT_FALSE(status.ok()); EXPECT_NE(status.ToString().find("fantasia!"), string::npos); } @@ -1023,7 +1023,7 @@ void CreateInvalidGraph(const string& graph_def_ascii, std::unique_ptr session( NewRemote(Options(cluster->targets()[0], 1))); - Status s = session->Create(graph); + absl::Status s = session->Create(graph); ASSERT_FALSE(s.ok()); EXPECT_NE(s.message().find(error_substring), string::npos); @@ -1182,7 +1182,7 @@ TEST(SessionTest, ExtendValidation) { &extension); ASSERT_TRUE(success); - Status s = session->Extend(extension); + absl::Status s = session->Extend(extension); ASSERT_FALSE(s.ok()); EXPECT_NE(s.message().find("Illegal op input name"), string::npos); @@ -1228,7 +1228,7 @@ TEST(SessionTest, CreateTimeoutWithSessionOptions) { test::graph::Delay(&graph, b, Microseconds(1000000)); GraphDef gdef; test::graph::ToGraphDef(&graph, &gdef); - Status status = session->Create(gdef); + absl::Status status = session->Create(gdef); // Either error is possible, depending on the environment. EXPECT_TRUE(error::DEADLINE_EXCEEDED == status.code() || error::UNAVAILABLE == status.code()); @@ -1248,7 +1248,7 @@ TEST(SessionTest, CreateTimeoutWithRunOptions) { RunOptions run_options; // Sets RunOption timeout_in_ms to 20. run_options.set_timeout_in_ms(20); - Status status = session->Create(run_options, gdef); + absl::Status status = session->Create(run_options, gdef); // Either error is possible, depending on the environment. EXPECT_TRUE(error::DEADLINE_EXCEEDED == status.code() || error::UNAVAILABLE == status.code()); @@ -1278,7 +1278,7 @@ TEST(SessionTest, RunTimeoutWithSessionOptions) { // Verifies that Run() times out, and the error code is DEADLINE_EXCEEDED. std::vector> inputs; - Status status = session->Run(inputs, {}, {b_delay->name()}, nullptr); + absl::Status status = session->Run(inputs, {}, {b_delay->name()}, nullptr); // TODO(sherrym): Due to potentially a GRPC bug, we sometimes get // GRPC_CHTTP2_INTERNAL_ERROR which is mapped to error::INTERNAL. EXPECT_TRUE(error::DEADLINE_EXCEEDED == status.code() || @@ -1308,8 +1308,8 @@ TEST(SessionTest, RunTimeoutWithRunOptions) { std::vector> inputs; RunOptions run_options; run_options.set_timeout_in_ms(100); - Status status = session->Run(run_options, inputs, {}, {b_delay->name()}, - nullptr, nullptr); + absl::Status status = session->Run(run_options, inputs, {}, {b_delay->name()}, + nullptr, nullptr); // TODO(sherrym): Due to potentially a GRPC bug, we sometimes get // GRPC_CHTTP2_INTERNAL_ERROR which is mapped to error::INTERNAL. EXPECT_TRUE(error::DEADLINE_EXCEEDED == status.code() || @@ -1405,7 +1405,7 @@ TEST(GrpcSessionTest, ErrorAggregationTwoWorkersTwoErrors) { TF_ASSERT_OK(session->Create(gdef)); { std::vector outputs; - Status status = session->Run({}, fetches, {}, &outputs); + absl::Status status = session->Run({}, fetches, {}, &outputs); LOG(INFO) << status; EXPECT_FALSE(status.ok()); // Status contains the error either worker1 or worker2. @@ -1488,7 +1488,7 @@ TEST(GrpcSessionTest, ErrorAggregationTwoWorkerRace) { TF_ASSERT_OK(session->Create(gdef)); { std::vector outputs; - Status status = session->Run({}, fetches, targets, &outputs); + absl::Status status = session->Run({}, fetches, targets, &outputs); LOG(INFO) << status; EXPECT_FALSE(status.ok()); // assert status contains the root error @@ -1587,7 +1587,7 @@ TEST(GrpcSessionTest, ErrorAggregationThreeWorkerRaceVariant1) { TF_ASSERT_OK(session->Create(gdef)); { std::vector outputs; - Status status = session->Run({}, fetches, targets, &outputs); + absl::Status status = session->Run({}, fetches, targets, &outputs); LOG(INFO) << status; EXPECT_FALSE(status.ok()); // assert status contains the root error @@ -1688,7 +1688,7 @@ TEST(GrpcSessionTest, ErrorAggregationThreeWorkerRaceVariant2) { TF_ASSERT_OK(session->Create(gdef)); { std::vector outputs; - Status status = session->Run({}, fetches, targets, &outputs); + absl::Status status = session->Run({}, fetches, targets, &outputs); LOG(INFO) << status; EXPECT_FALSE(status.ok()); // assert status contains the root error diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_state.cc b/tensorflow/core/distributed_runtime/rpc/grpc_state.cc index c86ad552cd1a28..1476ce65adc356 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_state.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_state.cc @@ -54,7 +54,7 @@ void UntypedStreamingRPCState::Tag::OnCompleted(bool ok) { streaming_state_->Unref(); // Ref acquired when tag was handed to grpc. } -void Exchange::Complete(Status status) { +void Exchange::Complete(absl::Status status) { if (status.ok()) { if (!tsl::GrpcMaybeParseProto(&response_buf_, response_)) { status.Update(errors::Internal("could not parse rpc response")); @@ -160,7 +160,7 @@ void ExchangeQueue::Swap(ExchangeQueue* other) { std::swap(call_started_, other->call_started_); } -void ExchangeQueue::CompleteAll(Status status) { +void ExchangeQueue::CompleteAll(absl::Status status) { for (Exchange& exchange : exchanges_) { exchange.Complete(status); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_state.h b/tensorflow/core/distributed_runtime/rpc/grpc_state.h index 43d87b35cb2628..4c5f560e78f0c8 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_state.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_state.h @@ -124,7 +124,7 @@ class Exchange { // If `status` is success, completes this exchange by parsing the // response_buf_ and invoking cb_ with OkStatus(). Else, invokes the // callback with `status`. - void Complete(Status status); + void Complete(absl::Status status); const State& state() const { return state_; } @@ -198,7 +198,7 @@ class ExchangeQueue { void Swap(ExchangeQueue* other); // Completes all exchanges in this with `status`. - void CompleteAll(Status status); + void CompleteAll(absl::Status status); void CallStarted() { call_started_ = true; } @@ -250,7 +250,7 @@ class StreamingRPCState : public UntypedStreamingRPCState { ::grpc::ByteBuffer request_buf; ::grpc::Status s = tsl::GrpcMaybeUnparseProto(request, &request_buf); if (!s.ok()) { - Status status = FromGrpcStatus(s); + absl::Status status = FromGrpcStatus(s); LOG(ERROR) << "GrpcMaybeUnparseProto returned with non-ok status: " << status.ToString(); done(status); @@ -351,7 +351,7 @@ class StreamingRPCState : public UntypedStreamingRPCState { return; } - Status s = FromGrpcStatus(call_status_); + absl::Status s = FromGrpcStatus(call_status_); if (s.ok() && !ok) { s.Update( errors::Internal("GRPC status is okay but CompletionQueueStatus is " @@ -374,7 +374,7 @@ class StreamingRPCState : public UntypedStreamingRPCState { kDone, }; - void MarkDoneAndCompleteExchanges(Status status) + void MarkDoneAndCompleteExchanges(absl::Status status) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_UNLOCK_FUNCTION(mu_) { call_state_ = State::kDone; VLOG(2) << "Ending gRPC streaming call on the client side due to " diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc index e7c5d68bc1ca65..33f40b9d39fa63 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc @@ -164,7 +164,8 @@ void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val, bool require_ack, } else { // skeleton is the encoded TensorProto contents (dtype and shape), but // not the actual data - gtl::InlinedVector skeleton(SkeletonEncodingSizeUpperBound(val)); + absl::InlinedVector skeleton( + SkeletonEncodingSizeUpperBound(val)); io::ProtoEncodeHelper e_skeleton(skeleton.data(), skeleton.size()); EncodeSkeleton(val, &e_skeleton); @@ -196,7 +197,7 @@ void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val, bool require_ack, // Encode all but the actual "tdata", but including the tag and // varlength header for the "tdata" - gtl::InlinedVector space(encoder_size); + absl::InlinedVector space(encoder_size); io::ProtoEncodeHelper e(space.data(), space.size()); // (A) e.WriteRawBytes(header); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc index ef28ad6667291b..f4b36334237a09 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding_test.cc @@ -65,7 +65,7 @@ class GrpcTensorCodingTest : public ::testing::Test { } } void DoTestForStrings(DataType dt) { - gtl::InlinedVector v; + absl::InlinedVector v; for (int elems = 0; elems <= 10000; elems++) { if (elems < 100 || (elems % 1000 == 0)) { Tensor a(dt, TensorShape({1, static_cast(v.size())})); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc index 8c7a686dd002e6..77f7d11283044f 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc @@ -38,8 +38,8 @@ limitations under the License. namespace tensorflow { namespace { -Status FillServerDef(const string& cluster_spec, const string& job_name, - int task_index, ServerDef* options) { +absl::Status FillServerDef(const string& cluster_spec, const string& job_name, + int task_index, ServerDef* options) { options->set_protocol("grpc"); options->set_job_name(job_name); options->set_task_index(task_index); @@ -113,8 +113,8 @@ int main(int argc, char* argv[]) { return -1; } tensorflow::ServerDef server_def; - tensorflow::Status s = tensorflow::FillServerDef(cluster_spec, job_name, - task_index, &server_def); + absl::Status s = tensorflow::FillServerDef(cluster_spec, job_name, task_index, + &server_def); if (!s.ok()) { std::cerr << "ERROR: " << s.message() << std::endl; Usage(argv[0]); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc index 0e9769cefd866e..89557b91c2d95c 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc @@ -29,8 +29,9 @@ limitations under the License. namespace tensorflow { namespace test { -Status TestCluster::MakeTestCluster(const TestClusterConfig& config, - std::unique_ptr* out_cluster) { +absl::Status TestCluster::MakeTestCluster( + const TestClusterConfig& config, + std::unique_ptr* out_cluster) { std::string binary_path = !config.binary_path.empty() ? config.binary_path diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h index 43aa2b38c53611..9101ca92d06a81 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h @@ -69,8 +69,9 @@ class TestCluster { // processes `n`. On success, the test cluster is stored in // *out_cluster, and this function returns OK. Otherwise an error is // returned. - static Status MakeTestCluster(const TestClusterConfig& config, - std::unique_ptr* out_cluster); + static absl::Status MakeTestCluster( + const TestClusterConfig& config, + std::unique_ptr* out_cluster); ~TestCluster(); // Returns a vector of string ":" pairs that may be diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc index 2900259a83867d..f48ed0c11b73bc 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc @@ -35,9 +35,10 @@ limitations under the License. namespace tensorflow { namespace { -Status FillServerDef(const string& job_spec, const string& job_name, - int num_cpus, int num_gpus, int task_index, int replica, - std::string host_port, ServerDef* options) { +absl::Status FillServerDef(const string& job_spec, const string& job_name, + int num_cpus, int num_gpus, int task_index, + int replica, std::string host_port, + ServerDef* options) { options->set_protocol("grpc"); options->set_job_name(job_name); options->set_task_index(task_index); @@ -138,7 +139,7 @@ int main(int argc, char* argv[]) { } tensorflow::ServerDef def; - tensorflow::Status s = + absl::Status s = tensorflow::FillServerDef(job_spec, job_name, num_cpus, num_gpus, task_index, replica, host_port, &def); if (!s.ok()) { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc index 2a18d0d28fe885..4820646930521c 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc @@ -75,14 +75,15 @@ class GrpcWorkerCache : public WorkerCachePartial { } } - Status GetEagerClientCache( + absl::Status GetEagerClientCache( std::unique_ptr* eager_client_cache) override { eager_client_cache->reset(eager::NewGrpcEagerClientCache(channel_cache_)); return absl::OkStatus(); } - Status GetCoordinationClientCache(std::unique_ptr* - coordination_client_cache) override { + absl::Status GetCoordinationClientCache( + std::unique_ptr* coordination_client_cache) + override { coordination_client_cache->reset( NewGrpcCoordinationClientCache(channel_cache_)); return absl::OkStatus(); @@ -154,8 +155,8 @@ GrpcWorkerEnv::GrpcWorkerCacheThread::~GrpcWorkerCacheThread() { GrpcWorkerEnv* CreateGrpcWorkerEnv() { int num_cpus = port::NumSchedulableCPUs(); int64_t num_completion_queues; - Status status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_QUEUES", 64, - &num_completion_queues); + absl::Status status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_QUEUES", 64, + &num_completion_queues); if (!status.ok()) { LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_QUEUES: " << status; } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc index bb705a9b3d3f19..68abc533c1fa67 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "xla/tsl/distributed_runtime/rpc/async_service_interface.h" #include "xla/tsl/distributed_runtime/rpc/grpc_call.h" +#include "xla/tsl/protobuf/rpc_options.pb.h" #include "tensorflow/core/common_runtime/buf_rendezvous.h" #include "tensorflow/core/common_runtime/copy_tensor.h" #include "tensorflow/core/common_runtime/device.h" @@ -56,7 +57,6 @@ limitations under the License. #include "tensorflow/core/protobuf/transport_options.pb.h" #include "tensorflow/core/protobuf/worker.pb.h" #include "tsl/platform/tracing.h" -#include "tsl/protobuf/rpc_options.pb.h" namespace tensorflow { @@ -223,7 +223,7 @@ class GrpcWorkerServiceThread { WorkerCall* call) { Schedule([this, call]() { worker_->GetStepSequenceAsync( - &call->request, &call->response, [call](const Status& s) { + &call->request, &call->response, [call](const absl::Status& s) { VLOG(3) << "Bad response from GetStepSequence:" << s; call->SendResponse(ToGrpcStatus(s)); }); @@ -249,7 +249,7 @@ class GrpcWorkerServiceThread { call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response, [call, call_opts, wrapped_request, - wrapped_response](const Status& s) { + wrapped_response](const absl::Status& s) { VLOG(3) << "RunGraph::Done"; if (!s.ok()) { VLOG(3) << "Bad response from RunGraph:" << s; @@ -272,7 +272,7 @@ class GrpcWorkerServiceThread { worker_->GrpcRecvTensorAsync( call_opts, &call->request, &call->response, - [call, call_opts](const Status& s) { + [call, call_opts](const absl::Status& s) { call->ClearCancelCallback(); delete call_opts; if (!s.ok()) { @@ -289,7 +289,7 @@ class GrpcWorkerServiceThread { CallOptions* call_opts = new CallOptions; call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); worker_->RecvBufAsync(call_opts, &call->request, &call->response, - [call, call_opts](const Status& s) { + [call, call_opts](const absl::Status& s) { call->ClearCancelCallback(); delete call_opts; if (!s.ok()) { @@ -308,7 +308,7 @@ class GrpcWorkerServiceThread { call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); worker_->CompleteGroupAsync( call_opts, &call->request, &call->response, - [call, call_opts](const Status& s) { + [call, call_opts](const absl::Status& s) { call->ClearCancelCallback(); delete call_opts; if (!s.ok()) { @@ -327,7 +327,7 @@ class GrpcWorkerServiceThread { call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); worker_->CompleteInstanceAsync( call_opts, &call->request, &call->response, - [call, call_opts](const Status& s) { + [call, call_opts](const absl::Status& s) { call->ClearCancelCallback(); delete call_opts; if (!s.ok()) { @@ -371,6 +371,14 @@ class GrpcWorkerService : public tsl::AsyncServiceInterface { GrpcWorkerServiceOptions options) : is_shutdown_(false) { builder->RegisterService(&worker_service_); + // gRPC by default will cancel requests that sit in a completion queue for + // more than 30s. See + // https://github.com/grpc/grpc/blob/e52e48b7ef83feeff56ed0894ce39841ea8bd483/include/grpc/impl/channel_arg_names.h#L106-L111 + // Extending this to 1 hour for Tensorflow since some graphs may have + // periods of heavy load which may cause the server to run into these + // cancellations. + builder->AddChannelArgument("grpc.server_max_unrequested_time_in_server", + 3600); for (int i = 0; i < options.num_serving_threads; i++) { threads_.emplace_back(new GrpcWorkerServiceThread( @@ -447,9 +455,9 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts, bool cache_enabled = (response_cache_ != nullptr && request_id != 0); - auto do_response = [response, done, cache_enabled](const Tensor& tensor, - bool is_dead, - const Status& status) { + auto do_response = [response, done, cache_enabled]( + const Tensor& tensor, bool is_dead, + const absl::Status& status) { if (status.ok()) { grpc::EncodeTensorToByteBuffer(is_dead, tensor, cache_enabled, response); } @@ -467,7 +475,7 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts, auto rendezvous_done = [this, request_id, do_response, cache_enabled]( const Tensor& tensor, bool is_dead, - const Status& status) { + const absl::Status& status) { if (cache_enabled) { // Data is ready. Process all pending requests in the response cache. response_cache_->RequestFinished(request_id, tensor, is_dead, status); @@ -476,11 +484,11 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts, } }; - auto fail = [&rendezvous_done](const Status& status) { + auto fail = [&rendezvous_done](const absl::Status& status) { rendezvous_done(Tensor(), false, status); }; - Status s = recent_request_ids_.TrackUnique( + absl::Status s = recent_request_ids_.TrackUnique( request_id, "RecvTensor (GrpcWorker)", *request); if (!s.ok()) { fail(s); @@ -515,7 +523,7 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts, env_->rendezvous_mgr->RecvLocalAsync( step_id, parsed, [opts, rendezvous_done, src_dev, request]( - const Status& status, const Rendezvous::Args& send_args, + const absl::Status& status, const Rendezvous::Args& send_args, const Rendezvous::Args& recv_args, const Tensor& val, const bool is_dead) { opts->ClearCancelCallback(); @@ -543,7 +551,7 @@ void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts, << " gpu_info: " << src_dev->tensorflow_accelerator_device_info(); StatusCallback copy_ready = [rendezvous_done, copy, - is_dead](const Status& s) { + is_dead](const absl::Status& s) { // The value is now ready to be returned on the wire. rendezvous_done(*copy, is_dead, s); delete copy; @@ -590,7 +598,7 @@ void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, auto do_response = [this, response, done, cache_enabled]( const Tensor& tensor, bool is_dead, - const Status& status) { + const absl::Status& status) { if (status.ok()) { SetTensorInRecvBufResp(recv_buf_max_chunk_, &tensor, response); } @@ -609,7 +617,7 @@ void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, } auto rendezvous_done = [this, request_id, do_response, cache_enabled]( - const Tensor& tensor, const Status& status) { + const Tensor& tensor, const absl::Status& status) { if (cache_enabled) { // Data is ready. Process all pending requests in the response cache. response_cache_->RequestFinished(request_id, tensor, false, status); @@ -618,13 +626,13 @@ void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, } }; - auto fail = [&rendezvous_done](const Status& status) { + auto fail = [&rendezvous_done](const absl::Status& status) { rendezvous_done(Tensor(), status); }; // This is a generic, low performance implementation appropriate for grpc. - Status s = recent_request_ids_.TrackUnique(request_id, "RecvBuf (GrpcWorker)", - *request); + absl::Status s = recent_request_ids_.TrackUnique( + request_id, "RecvBuf (GrpcWorker)", *request); if (!s.ok()) { fail(s); return; @@ -634,9 +642,9 @@ void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, env_->collective_executor_mgr->FindOrCreate(step_id), true); CollectiveRemoteAccess* rma = ce_handle.get()->remote_access(); auto consumer_callback = [this, request, rendezvous_done]( - const Status& status, + const absl::Status& status, BufRendezvous::Hook* hook) { - Status s = status; + absl::Status s = status; if (s.ok()) { if (hook == nullptr) { s = errors::Internal("Invalid null hook for key ", @@ -678,7 +686,7 @@ void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request, hook->prod_value->dtype(), hook->prod_value->shape()); hook->prod_ctx->CopyDeviceTensorToCPU( hook->prod_value, "empty_name", hook->prod_dev, cpu_tensor, - [hook, cpu_tensor, rendezvous_done](const Status& s) { + [hook, cpu_tensor, rendezvous_done](const absl::Status& s) { rendezvous_done(*cpu_tensor, s); BufRendezvous::DoneWithHook(hook); delete cpu_tensor; diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc index 5f74b27223af7a..51cbbbac941437 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc @@ -102,7 +102,7 @@ class RpcRecvTensorCall : public BaseRecvTensorCall { StartRTCall(std::move(recv_done)); } - void StartAbort(const Status& s) override { + void StartAbort(const absl::Status& s) override { { mutex_lock l(mu_); status_.Update(s); @@ -110,7 +110,7 @@ class RpcRecvTensorCall : public BaseRecvTensorCall { opts_.StartCancel(); } - Status status() const override { + absl::Status status() const override { mutex_lock l(mu_); return status_; } @@ -138,7 +138,7 @@ class RpcRecvTensorCall : public BaseRecvTensorCall { resp_.InitAlloc(dst_device_, alloc_attrs_); auto abort_checked = std::make_shared(); auto cb = [this, abort_checked, - recv_done = std::move(recv_done)](const Status& s) { + recv_done = std::move(recv_done)](const absl::Status& s) { // Make sure the Rendezvous abort checking is finished before running the // callback, which might destroy the current call object. abort_checked->WaitForNotification(); @@ -155,7 +155,7 @@ class RpcRecvTensorCall : public BaseRecvTensorCall { // the `RecvTensorAsync` request registers its RPC cancellation to `opts_`. // In that case, the previous `StartAbort` would not trigger the // cancellation of this call. - Status s; + absl::Status s; { mutex_lock l(mu_); s = status_; @@ -179,7 +179,7 @@ class RpcRecvTensorCall : public BaseRecvTensorCall { Rendezvous::DoneCallback done_; mutable mutex mu_; - Status status_ TF_GUARDED_BY(mu_); + absl::Status status_ TF_GUARDED_BY(mu_); RpcRecvTensorCall(const RpcRecvTensorCall&) = delete; void operator=(const RpcRecvTensorCall&) = delete; @@ -234,7 +234,7 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync( const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args, DoneCallback done) { CHECK(is_initialized()); - Status s; + absl::Status s; // Prepare a RecvTensor call that can handle being aborted. RpcRecvTensorCall* call = get_call_freelist()->New(); @@ -294,7 +294,7 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync( DeregisterCall(call, recv_args); // If StartAbort was called prior to DeregisterCall, then the // current status should be bad. - Status s = call->status(); + absl::Status s = call->status(); // NOTE: `*session()` can potentially be deleted before we return from // `call->done()(...)`, so we must release the worker before calling the // callback. diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc index fa784d27e0be39..701ce3ed4e61a3 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc @@ -80,11 +80,11 @@ class DummyWorkerCache : public WorkerCacheInterface { } return dummy_remote_worker_; } - Status GetEagerClientCache( + absl::Status GetEagerClientCache( std::unique_ptr* eager_client_cache) override { return errors::Unimplemented("Unimplemented."); } - Status GetCoordinationClientCache( + absl::Status GetCoordinationClientCache( std::unique_ptr* coord_client_cache) override { return errors::Unimplemented("Unimplemented."); } @@ -103,7 +103,7 @@ static Device* CreateDevice(const char* type, const char* name) { class FakeDevice : public Device { public: explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {} - Status Sync() override { return absl::OkStatus(); } + absl::Status Sync() override { return absl::OkStatus(); } Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; } }; DeviceAttributes attr; @@ -272,7 +272,7 @@ TEST_F(RpcRendezvousMgrTest, TransferDummyDeviceContext) { Notification n; rmgr_.RecvLocalAsync( step_id, key, - [&n](const Status& s, const Rendezvous::Args send_args, + [&n](const absl::Status& s, const Rendezvous::Args send_args, const Rendezvous::Args recv_args, const Tensor& val, bool is_dead) { auto send_dev_context = @@ -320,21 +320,21 @@ TEST_F(RpcRendezvousMgrTest, RemoteRecvAsyncMany) { int num_requests = 10000; Tensor val(DT_STRING); mutex mu_; - Status status = absl::OkStatus(); + absl::Status status = absl::OkStatus(); BlockingCounter counter(num_requests); for (int i = 0; i < num_requests; i++) { - rendez->RecvAsync( - key, args, - [&mu_, &status, &counter](const Status& s, const Rendezvous::Args&, - const Rendezvous::Args&, const Tensor&, - const bool) { - { - mutex_lock l(mu_); - status.Update(s); - } - counter.DecrementCount(); - }); + rendez->RecvAsync(key, args, + [&mu_, &status, &counter](const absl::Status& s, + const Rendezvous::Args&, + const Rendezvous::Args&, + const Tensor&, const bool) { + { + mutex_lock l(mu_); + status.Update(s); + } + counter.DecrementCount(); + }); } counter.Wait(); TF_ASSERT_OK(status); diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_response_cache.cc b/tensorflow/core/distributed_runtime/rpc/rpc_response_cache.cc index 62a26218cbde7e..663adf385b0b50 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_response_cache.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_response_cache.cc @@ -75,7 +75,8 @@ bool RpcResponseCache::QueueRequest(int64_t request_id, int64_t step_id, } void RpcResponseCache::RequestFinished(int64_t request_id, const Tensor& tensor, - bool is_dead, const Status& status) { + bool is_dead, + const absl::Status& status) { ResponseCacheEntry entry_copy; { diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_response_cache.h b/tensorflow/core/distributed_runtime/rpc/rpc_response_cache.h index c7c7567fa2ba2f..0f31ddaf81ccfc 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_response_cache.h +++ b/tensorflow/core/distributed_runtime/rpc/rpc_response_cache.h @@ -41,7 +41,7 @@ namespace tensorflow { class RpcResponseCache { public: using FinishResponseCB = std::function; + const Tensor& tensor, bool is_dead, const absl::Status& status)>; // Add the given request to the cache. // If the request is in the cache, @@ -56,7 +56,7 @@ class RpcResponseCache { // Fill the response cache for the given request_id and respond to all // pending request. void RequestFinished(int64_t request_id, const Tensor& tensor, bool is_dead, - const Status& status); + const absl::Status& status); // Erase the cache entry with the given request_id void EraseRequestId(int64_t request_id); @@ -78,7 +78,7 @@ class RpcResponseCache { int64_t step_id = -1; Tensor tensor; bool is_dead = false; - Status response_status; + absl::Status response_status; void FinishResponse(const FinishResponseCB& cb) const { cb(tensor, is_dead, response_status); diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc index fd649181a38b97..1af67bdb51b3ca 100644 --- a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc +++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc @@ -87,7 +87,7 @@ void RpcCollectiveExecutorMgr::RefreshStepIdSequenceAsync( GetStepSequenceResponse* resp = new GetStepSequenceResponse; req->add_graph_key(graph_key); wi->GetStepSequenceAsync( - req, resp, [this, req, resp, done](const Status& s) { + req, resp, [this, req, resp, done](const absl::Status& s) { if (!s.ok()) { LOG(ERROR) << "Bad response [" << s << "] from GetStepSequenceAsync call to " @@ -128,7 +128,7 @@ void RpcCollectiveExecutorMgr::GetStepSequenceAsync( } } -Status RpcCollectiveExecutorMgr::UpdateStepSequences( +absl::Status RpcCollectiveExecutorMgr::UpdateStepSequences( const GetStepSequenceResponse& resp) { mutex_lock l(sequence_mu_); for (const StepSequence& ss : resp.step_sequence()) { diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h index 0e775093813e31..6836204cc1a289 100644 --- a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h +++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h @@ -65,7 +65,7 @@ class RpcCollectiveExecutorMgr : public CollectiveExecutorMgr { friend class RpcCollectiveExecutorMgrTest; private: - Status UpdateStepSequences(const GetStepSequenceResponse& resp); + absl::Status UpdateStepSequences(const GetStepSequenceResponse& resp); // This class maintains the step_id sequencing for a single // collective_graph_key. diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc index 057da719818200..a5522f8cab0cd4 100644 --- a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc @@ -86,12 +86,12 @@ TEST_F(RpcCollectiveExecutorMgrTest, NextStepId) { // Calling Refresh should generate a valid id. { Notification note; - Status status; - cme_->RefreshStepIdSequenceAsync(7, - [this, &status, ¬e](const Status& s) { - status = s; - note.Notify(); - }); + absl::Status status; + cme_->RefreshStepIdSequenceAsync( + 7, [this, &status, ¬e](const absl::Status& s) { + status = s; + note.Notify(); + }); EXPECT_TRUE(status.ok()); } x = cme_->NextStepId(7); @@ -109,12 +109,12 @@ TEST_F(RpcCollectiveExecutorMgrTest, NextStepId) { // Calling refresh should jump to a different point in the random space. { Notification note; - Status status; - cme_->RefreshStepIdSequenceAsync(7, - [this, &status, ¬e](const Status& s) { - status = s; - note.Notify(); - }); + absl::Status status; + cme_->RefreshStepIdSequenceAsync( + 7, [this, &status, ¬e](const absl::Status& s) { + status = s; + note.Notify(); + }); note.WaitForNotification(); EXPECT_TRUE(status.ok()); @@ -136,9 +136,9 @@ TEST_F(RpcCollectiveExecutorMgrTest, GetStepSequence) { request.add_graph_key(4); { Notification note; - Status status; + absl::Status status; cme_->GetStepSequenceAsync(&request, &response, - [this, &status, ¬e](const Status& s) { + [this, &status, ¬e](const absl::Status& s) { status = s; note.Notify(); }); @@ -156,9 +156,9 @@ TEST_F(RpcCollectiveExecutorMgrTest, GetStepSequence) { response.Clear(); { Notification note; - Status status; + absl::Status status; cme_->GetStepSequenceAsync(&request, &response, - [this, &status, ¬e](const Status& s) { + [this, &status, ¬e](const absl::Status& s) { status = s; note.Notify(); }); diff --git a/tensorflow/core/distributed_runtime/server_lib.cc b/tensorflow/core/distributed_runtime/server_lib.cc index a653a7999fed41..2f7cc4184662f4 100644 --- a/tensorflow/core/distributed_runtime/server_lib.cc +++ b/tensorflow/core/distributed_runtime/server_lib.cc @@ -46,8 +46,8 @@ void ServerFactory::Register(const string& server_type, } /* static */ -Status ServerFactory::GetFactory(const ServerDef& server_def, - ServerFactory** out_factory) { +absl::Status ServerFactory::GetFactory(const ServerDef& server_def, + ServerFactory** out_factory) { mutex_lock l(*get_server_factory_lock()); for (const auto& server_factory : *server_factories()) { if (server_factory.second->AcceptsOptions(server_def)) { @@ -69,8 +69,8 @@ Status ServerFactory::GetFactory(const ServerDef& server_def, // Creates a server based on the given `server_def`, and stores it in // `*out_server`. Returns OK on success, otherwise returns an error. -Status NewServer(const ServerDef& server_def, - std::unique_ptr* out_server) { +absl::Status NewServer(const ServerDef& server_def, + std::unique_ptr* out_server) { ServerFactory* factory; TF_RETURN_IF_ERROR(ServerFactory::GetFactory(server_def, &factory)); return factory->NewServer(server_def, ServerFactory::Options(), out_server); @@ -78,9 +78,9 @@ Status NewServer(const ServerDef& server_def, // Creates a server based on the given `server_def`, and stores it in // `*out_server`. Returns OK on success, otherwise returns an error. -Status NewServerWithOptions(const ServerDef& server_def, - const ServerFactory::Options& options, - std::unique_ptr* out_server) { +absl::Status NewServerWithOptions( + const ServerDef& server_def, const ServerFactory::Options& options, + std::unique_ptr* out_server) { ServerFactory* factory; TF_RETURN_IF_ERROR(ServerFactory::GetFactory(server_def, &factory)); return factory->NewServer(server_def, options, out_server); diff --git a/tensorflow/core/distributed_runtime/server_lib.h b/tensorflow/core/distributed_runtime/server_lib.h index 3a5ca1853889bc..cc92d0bae12b17 100644 --- a/tensorflow/core/distributed_runtime/server_lib.h +++ b/tensorflow/core/distributed_runtime/server_lib.h @@ -49,18 +49,18 @@ class ServerInterface { // Starts the server running asynchronously. Returns OK on success, otherwise // returns an error. - virtual Status Start() = 0; + virtual absl::Status Start() = 0; // Stops the server asynchronously. Returns OK on success, otherwise returns // an error. // // After calling `Stop()`, the caller may call `Join()` to block until the // server has stopped. - virtual Status Stop() = 0; + virtual absl::Status Stop() = 0; // Blocks until the server has stopped. Returns OK on success, otherwise // returns an error. - virtual Status Join() = 0; + virtual absl::Status Join() = 0; // Returns a target string that can be used to connect to this server using // `tensorflow::NewSession()`. @@ -70,20 +70,20 @@ class ServerInterface { virtual MasterEnv* master_env() = 0; // Update the set of workers that can be reached by the server - virtual Status UpdateServerDef(const ServerDef& server_def) = 0; + virtual absl::Status UpdateServerDef(const ServerDef& server_def) = 0; // Functions to operate on service-specific properties. // // Add master eager context to local eager service in order to handle enqueue // requests from remote workers. - virtual Status AddMasterEagerContextToEagerService( + virtual absl::Status AddMasterEagerContextToEagerService( const tensorflow::uint64 context_id, EagerContext* context) = 0; // Set coordination service agent instance to coordination service RPC handler - virtual Status SetCoordinationServiceAgentInstance( + virtual absl::Status SetCoordinationServiceAgentInstance( tsl::CoordinationServiceAgent* agent) = 0; // TODO(hanyangtay): Remove this method once gRPC server clean shutdown is // supported. - virtual Status StopCoordinationService() = 0; + virtual absl::Status StopCoordinationService() = 0; private: ServerInterface(const ServerInterface&) = delete; @@ -99,8 +99,9 @@ class ServerFactory { // Creates a new server based on the given `server_def`, and stores // it in `*out_server`. Returns OK on success, otherwise returns an // error. - virtual Status NewServer(const ServerDef& server_def, const Options& options, - std::unique_ptr* out_server) = 0; + virtual absl::Status NewServer( + const ServerDef& server_def, const Options& options, + std::unique_ptr* out_server) = 0; // Returns true if and only if this factory can create a server // based on the given `server_def`. @@ -117,17 +118,17 @@ class ServerFactory { // Looks up a factory that can create a server based on the given // `server_def`, and stores it in `*out_factory`. Returns OK on // success, otherwise returns an error. - static Status GetFactory(const ServerDef& server_def, - ServerFactory** out_factory); + static absl::Status GetFactory(const ServerDef& server_def, + ServerFactory** out_factory); }; // Creates a server based on the given `server_def`, and stores it in // `*out_server`. Returns OK on success, otherwise returns an error. -Status NewServer(const ServerDef& server_def, - std::unique_ptr* out_server); -Status NewServerWithOptions(const ServerDef& server_def, - const ServerFactory::Options& options, - std::unique_ptr* out_server); +absl::Status NewServer(const ServerDef& server_def, + std::unique_ptr* out_server); +absl::Status NewServerWithOptions(const ServerDef& server_def, + const ServerFactory::Options& options, + std::unique_ptr* out_server); } // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/server_lib_test.cc b/tensorflow/core/distributed_runtime/server_lib_test.cc index 49abd7e7a639e9..c38ca77b339dff 100644 --- a/tensorflow/core/distributed_runtime/server_lib_test.cc +++ b/tensorflow/core/distributed_runtime/server_lib_test.cc @@ -26,8 +26,9 @@ class TestServerFactory : public ServerFactory { return server_def.protocol() == "test_protocol"; } - Status NewServer(const ServerDef& server_def, const Options& options, - std::unique_ptr* out_server) override { + absl::Status NewServer( + const ServerDef& server_def, const Options& options, + std::unique_ptr* out_server) override { return absl::OkStatus(); } }; @@ -44,7 +45,7 @@ TEST(ServerLibTest, NewServerNoFactoriesAccept) { ServerDef server_def; server_def.set_protocol("fake_protocol"); std::unique_ptr server; - Status s = NewServer(server_def, &server); + absl::Status s = NewServer(server_def, &server); ASSERT_NE(s, absl::OkStatus()); EXPECT_TRUE(absl::StrContains( s.message(), "No server factory registered for the given ServerDef")); diff --git a/tensorflow/core/distributed_runtime/session_mgr.cc b/tensorflow/core/distributed_runtime/session_mgr.cc index a6b4df397b6b25..aa6399f55c01a0 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.cc +++ b/tensorflow/core/distributed_runtime/session_mgr.cc @@ -23,6 +23,9 @@ limitations under the License. #include "xla/tsl/distributed_runtime/coordination/coordination_service.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" +#include "xla/tsl/protobuf/distributed_runtime_payloads.pb.h" #include "tensorflow/core/activity_watcher/activity.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/renamed_device.h" @@ -35,9 +38,6 @@ limitations under the License. #include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" #include "tensorflow/core/util/device_name_utils.h" -#include "tsl/protobuf/coordination_config.pb.h" -#include "tsl/protobuf/coordination_service.pb.h" -#include "tsl/protobuf/distributed_runtime_payloads.pb.h" namespace tensorflow { namespace { @@ -112,16 +112,15 @@ std::string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) { "/task:", server_def.task_index()); } -Status SessionMgr::CreateSession(const std::string& session, - const ServerDef& server_def, - bool isolate_session_state, - StatusCallback coordination_error_callback) { +absl::Status SessionMgr::CreateSession( + const std::string& session, const ServerDef& server_def, + bool isolate_session_state, StatusCallback coordination_error_callback) { return CreateSession(session, server_def, {}, isolate_session_state, /*master_task=*/"", /*master_incarnation=*/0, coordination_error_callback); } -Status SessionMgr::CreateSession( +absl::Status SessionMgr::CreateSession( const std::string& session, const ServerDef& server_def, const protobuf::RepeatedPtrField& cluster_device_attributes, @@ -132,7 +131,7 @@ Status SessionMgr::CreateSession( /*master_incarnation=*/0); } -Status SessionMgr::CreateSession( +absl::Status SessionMgr::CreateSession( const std::string& session, const ServerDef& server_def, const protobuf::RepeatedPtrField& cluster_device_attributes, @@ -314,7 +313,7 @@ void SessionMgr::ResetDefaultWorkerCache(WorkerCacheInterface* worker_cache) { default_worker_cache_.reset(worker_cache); } -Status SessionMgr::UpdateSession( +absl::Status SessionMgr::UpdateSession( const std::string& session, const ServerDef& server_def, const protobuf::RepeatedPtrField& cluster_device_attributes) { @@ -377,7 +376,7 @@ Status SessionMgr::UpdateSession( return absl::OkStatus(); } -Status SessionMgr::DeleteSession(const std::string& session) { +absl::Status SessionMgr::DeleteSession(const std::string& session) { mutex_lock l(mu_); auto it = sessions_.find(session); if (it != sessions_.end()) { @@ -386,7 +385,7 @@ Status SessionMgr::DeleteSession(const std::string& session) { return absl::OkStatus(); } -Status SessionMgr::DeleteAllSessions() { +absl::Status SessionMgr::DeleteAllSessions() { std::map> tmp_sessions; { mutex_lock l(mu_); @@ -399,7 +398,7 @@ Status SessionMgr::DeleteAllSessions() { return absl::OkStatus(); } -Status SessionMgr::WorkerSessionForSessionLocked( +absl::Status SessionMgr::WorkerSessionForSessionLocked( const std::string& session_handle, std::shared_ptr* out_session) { if (session_handle.empty()) { @@ -422,7 +421,7 @@ Status SessionMgr::WorkerSessionForSessionLocked( return absl::OkStatus(); } -Status SessionMgr::WorkerSessionForSession( +absl::Status SessionMgr::WorkerSessionForSession( const std::string& session_handle, std::shared_ptr* out_session) { mutex_lock l(mu_); diff --git a/tensorflow/core/distributed_runtime/session_mgr.h b/tensorflow/core/distributed_runtime/session_mgr.h index 6daaa756b05bfb..55c64f45c9daeb 100644 --- a/tensorflow/core/distributed_runtime/session_mgr.h +++ b/tensorflow/core/distributed_runtime/session_mgr.h @@ -41,7 +41,7 @@ struct WorkerEnv; // SessionMgr is threadsafe. class SessionMgr { public: - typedef std::function + typedef std::function WorkerCacheFactory; explicit SessionMgr( @@ -52,13 +52,13 @@ class SessionMgr { ~SessionMgr() {} // Allocates state for a new session. - Status CreateSession( + absl::Status CreateSession( const std::string& session, const ServerDef& server_def, bool isolate_session_state, - StatusCallback coordination_error_callback = [](Status s) { + StatusCallback coordination_error_callback = [](absl::Status s) { LOG(ERROR) << "Coordination agent is set to error: " << s; }); - Status CreateSession( + absl::Status CreateSession( const std::string& session, const ServerDef& server_def, const protobuf::RepeatedPtrField& device_attributes, bool isolate_session_state); @@ -70,12 +70,12 @@ class SessionMgr { // master has restarted before deleting the sessions on worker. When it // happens, old sessions associated with the master will be automatically // removed before the new session is created. - Status CreateSession( + absl::Status CreateSession( const std::string& session, const ServerDef& server_def, const protobuf::RepeatedPtrField& device_attributes, bool isolate_session_state, std::string master_task, int64_t master_incarnation, - StatusCallback coordination_error_callback = [](Status s) { + StatusCallback coordination_error_callback = [](absl::Status s) { LOG(ERROR) << "Coordination agent is set to error: " << s; }); @@ -83,19 +83,21 @@ class SessionMgr { // Updates state (worker cache, devices) of worker session identified by // session name (`session`) based on a new server_def and set of devices. - Status UpdateSession(const std::string& session, const ServerDef& server_def, - const protobuf::RepeatedPtrField& - cluster_device_attributes); + absl::Status UpdateSession(const std::string& session, + const ServerDef& server_def, + const protobuf::RepeatedPtrField& + cluster_device_attributes); // Locates the worker session for a given session handle - Status WorkerSessionForSession(const std::string& session_handle, - std::shared_ptr* out_session); + absl::Status WorkerSessionForSession( + const std::string& session_handle, + std::shared_ptr* out_session); std::shared_ptr LegacySession(); - Status DeleteSession(const std::string& session); + absl::Status DeleteSession(const std::string& session); // Deletes all existing sessions. - Status DeleteAllSessions(); + absl::Status DeleteAllSessions(); // Provides access to the coordination service agent. This method should only // be called after the agent has been initialized during session creation, or @@ -142,7 +144,7 @@ class SessionMgr { // Not owned. And should only be used for setting the coordination service. tsl::CoordinationServiceRpcHandler* coordination_handler_ = nullptr; - Status WorkerSessionForSessionLocked( + absl::Status WorkerSessionForSessionLocked( const std::string& session_handle, std::shared_ptr* out_session) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); diff --git a/tensorflow/core/distributed_runtime/session_mgr_test.cc b/tensorflow/core/distributed_runtime/session_mgr_test.cc index 9e19a878750e77..0eab3f2aacf9c1 100644 --- a/tensorflow/core/distributed_runtime/session_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/session_mgr_test.cc @@ -33,7 +33,9 @@ class FakeDevice : public Device { : Device(nullptr, device_attributes) {} public: - Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); } + absl::Status Sync() override { + return errors::Unimplemented("FakeDevice::Sync()"); + } Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; } @@ -252,7 +254,7 @@ TEST_F(SessionMgrTest, LegacySession) { TEST_F(SessionMgrTest, UnknownSessionHandle) { std::string session_handle = "unknown_session_handle"; std::shared_ptr session; - Status s = mgr_.WorkerSessionForSession(session_handle, &session); + absl::Status s = mgr_.WorkerSessionForSession(session_handle, &session); EXPECT_TRUE(absl::IsAborted(s)); EXPECT_TRUE(absl::StrContains(s.message(), "Session handle is not found")); EXPECT_TRUE(s.GetPayload(kWorkerPossiblyRestarted).has_value()); diff --git a/tensorflow/core/distributed_runtime/tensor_coding.cc b/tensorflow/core/distributed_runtime/tensor_coding.cc index 4779fb5777742b..4b4c7e4d8f5c32 100644 --- a/tensorflow/core/distributed_runtime/tensor_coding.cc +++ b/tensorflow/core/distributed_runtime/tensor_coding.cc @@ -50,8 +50,8 @@ void TensorResponse::InitAlloc(DeviceBase* d, const AllocatorAttributes& aa) { allocator_ = device_->GetAllocator(alloc_attrs_); } -Status TensorResponse::InitFrom(RecvTensorResponse* response) { - Status s; +absl::Status TensorResponse::InitFrom(RecvTensorResponse* response) { + absl::Status s; meta_.Swap(response); if (on_host_) { if (!tensor_.FromProto(allocator_, meta_.tensor())) { @@ -79,7 +79,7 @@ void TensorResponse::InitPartial(const RecvTensorResponse& response, tensor_ = std::move(t); } -Status TensorResponse::ParseFrom(Source* source) { +absl::Status TensorResponse::ParseFrom(Source* source) { if (!on_host_) { protobuf::io::CodedInputStream input(source->contents()); @@ -87,7 +87,7 @@ Status TensorResponse::ParseFrom(Source* source) { if (!meta_.ParseFromCodedStream(&input) || !input.ConsumedEntireMessage()) { return errors::InvalidArgument("Cannot parse tensor from response"); } - Status s = + absl::Status s = device_->MakeTensorFromProto(meta_.tensor(), alloc_attrs_, &tensor_); // Reduce memory usage for big tensors. { diff --git a/tensorflow/core/distributed_runtime/tensor_coding.h b/tensorflow/core/distributed_runtime/tensor_coding.h index b517bac4222956..1fd40d957fa7ab 100644 --- a/tensorflow/core/distributed_runtime/tensor_coding.h +++ b/tensorflow/core/distributed_runtime/tensor_coding.h @@ -67,11 +67,11 @@ class TensorResponse { // Parse the RecvTensorResponse encoded in the data yielded by // source->contents() into *this. - Status ParseFrom(Source* source); + absl::Status ParseFrom(Source* source); // Initialize tensor from *response. // Leaves *response with unspecified contents. - Status InitFrom(RecvTensorResponse* response); + absl::Status InitFrom(RecvTensorResponse* response); // Initialize tensor metadata from response and allocate // uninitialized backing storage for actual contents. diff --git a/tensorflow/core/distributed_runtime/tensor_coding_test.cc b/tensorflow/core/distributed_runtime/tensor_coding_test.cc index 21d2cbb7a38960..a95e51b03486b7 100644 --- a/tensorflow/core/distributed_runtime/tensor_coding_test.cc +++ b/tensorflow/core/distributed_runtime/tensor_coding_test.cc @@ -92,7 +92,7 @@ class TensorResponseTest : public ::testing::Test { DummyDevice cpu_device(Env::Default()); response.InitAlloc(&cpu_device, AllocatorAttributes()); for (int i = 0; i < 2; i++) { // Twice so we exercise reuse of "response" - Status s = response.ParseFrom(&source); + absl::Status s = response.ParseFrom(&source); EXPECT_TRUE(s.ok()); const RecvTensorResponse& meta = response.metadata(); @@ -183,7 +183,7 @@ static void BM_TensorResponse(::testing::benchmark::State& state) { TensorResponse response; response.InitAlloc(&cpu_device, AllocatorAttributes()); StringSource source(&encoded, -1); - Status s = response.ParseFrom(&source); + absl::Status s = response.ParseFrom(&source); bytes = response.tensor().TotalBytes(); } state.SetLabel(strings::StrCat("Bytes: ", bytes)); diff --git a/tensorflow/core/distributed_runtime/test_utils.h b/tensorflow/core/distributed_runtime/test_utils.h index ec8ba7be22da8e..e7ad1041dd73ff 100644 --- a/tensorflow/core/distributed_runtime/test_utils.h +++ b/tensorflow/core/distributed_runtime/test_utils.h @@ -162,12 +162,12 @@ class TestWorkerCache : public WorkerCacheInterface { void ReleaseWorker(const string& target, WorkerInterface* worker) override {} - Status GetEagerClientCache( + absl::Status GetEagerClientCache( std::unique_ptr* eager_client_cache) override { return errors::Unimplemented("Unimplemented."); } - Status GetCoordinationClientCache( + absl::Status GetCoordinationClientCache( std::unique_ptr* coord_client_cache) override { return errors::Unimplemented("Unimplemented."); } diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc index 0922b04de0b0f8..9fb0a76ad866f9 100644 --- a/tensorflow/core/distributed_runtime/worker.cc +++ b/tensorflow/core/distributed_runtime/worker.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "xla/tsl/protobuf/distributed_runtime_payloads.pb.h" #include "tensorflow/core/common_runtime/collective_executor_mgr.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/process_util.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/core/framework/collective.h" #include "tensorflow/core/profiler/lib/device_profiler_session.h" #include "tsl/platform/tracing.h" -#include "tsl/protobuf/distributed_runtime_payloads.pb.h" namespace tensorflow { @@ -59,7 +59,7 @@ void Worker::GetStatusAsync(CallOptions* opts, const GetStatusRequest* request, void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request, CreateWorkerSessionResponse* response, StatusCallback done) { - Status s = env_->session_mgr->CreateSession( + absl::Status s = env_->session_mgr->CreateSession( request->session_handle(), request->server_def(), request->cluster_device_attributes(), request->isolate_session_state(), request->master_task(), request->master_incarnation()); @@ -70,7 +70,7 @@ void Worker::DeleteWorkerSessionAsync(CallOptions* opts, const DeleteWorkerSessionRequest* request, DeleteWorkerSessionResponse* response, StatusCallback done) { - Status s = env_->session_mgr->DeleteSession(request->session_handle()); + absl::Status s = env_->session_mgr->DeleteSession(request->session_handle()); done(s); } @@ -78,7 +78,7 @@ void Worker::RegisterGraphAsync(const RegisterGraphRequest* request, RegisterGraphResponse* response, StatusCallback done) { std::shared_ptr session; - Status s; + absl::Status s; if (request->create_worker_session_called()) { s = env_->session_mgr->WorkerSessionForSession(request->session_handle(), &session); @@ -99,7 +99,7 @@ void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request, DeregisterGraphResponse* response, StatusCallback done) { std::shared_ptr session; - Status s; + absl::Status s; if (request->create_worker_session_called()) { s = env_->session_mgr->WorkerSessionForSession(request->session_handle(), &session); @@ -128,9 +128,9 @@ void Worker::AbortStep(int64_t step_id) { }); } -Status Worker::PrepareRunGraph(RunGraphRequestWrapper* req, - GraphMgr::NamedTensors* in, - GraphMgr::NamedTensors* out) { +absl::Status Worker::PrepareRunGraph(RunGraphRequestWrapper* req, + GraphMgr::NamedTensors* in, + GraphMgr::NamedTensors* out) { static Tensor empty_tensor(DT_FLOAT); if (req->num_sends() > 0) { Tensor val; @@ -149,7 +149,7 @@ void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request, MutableRunGraphResponseWrapper* response, StatusCallback done) { if (request->store_errors_in_response_body()) { - done = [response, done](const Status& status) { + done = [response, done](const absl::Status& status) { response->set_status(status); done(absl::OkStatus()); }; @@ -174,8 +174,8 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, StatusCallback done) { const int64_t step_id = request->step_id(); TRACEPRINTF("RunGraph: %lld", step_id); - Status s = recent_request_ids_.TrackUnique(request->request_id(), - "RunGraph (Worker)", request); + absl::Status s = recent_request_ids_.TrackUnique( + request->request_id(), "RunGraph (Worker)", request); if (!s.ok()) { done(s); return; @@ -234,8 +234,8 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, request->graph_handle(), step_id, request->exec_opts(), in, session.get(), collector, response, cm, env_->session_mgr->GetCoordinationServiceAgent(), [this, step_id, response, session, cm, out, token, collector, - device_profiler_session, opts, done](const Status& status) { - Status s = status; + device_profiler_session, opts, done](const absl::Status& status) { + absl::Status s = status; if (s.ok()) { s = session->graph_mgr()->RecvOutputs(step_id, out); } @@ -273,7 +273,7 @@ void Worker::DoPartialRunGraph(CallOptions* opts, const int64_t step_id = request->step_id(); const string& graph_handle = request->graph_handle(); TRACEPRINTF("PartialRunGraph: %lld", step_id); - Status s = recent_request_ids_.TrackUnique( + absl::Status s = recent_request_ids_.TrackUnique( request->request_id(), "PartialRunGraph (Worker)", request); if (!s.ok()) { done(s); @@ -295,7 +295,7 @@ void Worker::DoPartialRunGraph(CallOptions* opts, GraphMgr::NamedTensors in; GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors; s = PrepareRunGraph(request, &in, out); - auto finish = [done, out, opts](const Status& s) { + auto finish = [done, out, opts](const absl::Status& s) { opts->ClearCancelCallback(); delete out; done(s); @@ -326,7 +326,7 @@ void Worker::DoPartialRunGraph(CallOptions* opts, graph_handle, step_id, request->exec_opts(), in, session.get(), /*collector=*/nullptr, /*response=*/nullptr, cm, env_->session_mgr->GetCoordinationServiceAgent(), - [this, token, step_id, session](Status s) { + [this, token, step_id, session](absl::Status s) { cancellation_manager_.DeregisterCallback(token); partial_run_mgr_.ExecutorDone(step_id, s); }); @@ -340,7 +340,8 @@ void Worker::DoPartialRunGraph(CallOptions* opts, } session->graph_mgr()->RecvOutputsAsync( - step_id, out, [this, out, request, response, step_id, finish](Status s) { + step_id, out, + [this, out, request, response, step_id, finish](absl::Status s) { if (s.ok()) { // Construct and return the resp. for (const auto& p : *out) { @@ -419,7 +420,8 @@ void Worker::CompleteGroupAsync(CallOptions* opts, group_params->device_type = DeviceType(request->device_type()); env_->collective_executor_mgr->GetParamResolver()->CompleteGroupAsync( request->device_attributes(), group_params, &cancellation_manager_, - [response, group_params, done = std::move(done)](const Status& s) { + [response, group_params, + done = std::move(done)](const absl::Status& s) { if (s.ok()) { response->set_group_key(group_params->group_key); response->set_group_size(group_params->group_size); @@ -469,8 +471,8 @@ void Worker::GetStepSequenceAsync(const GetStepSequenceRequest* request, // Helper for RecvTensor. Validates "key" and returns the source // device in "*src_dev". -Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed, - Device** src_dev) { +absl::Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed, + Device** src_dev) { // Figures out which device the tensor is hosted on. string local_name = DeviceNameUtils::LocalName(parsed.src_device); TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice(local_name, src_dev)); diff --git a/tensorflow/core/distributed_runtime/worker.h b/tensorflow/core/distributed_runtime/worker.h index 780976dcab3b19..4c55e1b9612fb2 100644 --- a/tensorflow/core/distributed_runtime/worker.h +++ b/tensorflow/core/distributed_runtime/worker.h @@ -112,8 +112,8 @@ class Worker : public WorkerInterface { WorkerEnv* const env_; // Not owned. RecentRequestIds recent_request_ids_; - Status PrepareRecvTensor(const Rendezvous::ParsedKey& parsed, - Device** src_dev); + absl::Status PrepareRecvTensor(const Rendezvous::ParsedKey& parsed, + Device** src_dev); void AbortStep(int64_t); @@ -122,9 +122,9 @@ class Worker : public WorkerInterface { CancellationManager cancellation_manager_; - Status PrepareRunGraph(RunGraphRequestWrapper* req, - GraphMgr::NamedTensors* in, - GraphMgr::NamedTensors* out); + absl::Status PrepareRunGraph(RunGraphRequestWrapper* req, + GraphMgr::NamedTensors* in, + GraphMgr::NamedTensors* out); void DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, MutableRunGraphResponseWrapper* response, diff --git a/tensorflow/core/distributed_runtime/worker_cache.h b/tensorflow/core/distributed_runtime/worker_cache.h index 542ac79334da17..1ac4de35d9788f 100644 --- a/tensorflow/core/distributed_runtime/worker_cache.h +++ b/tensorflow/core/distributed_runtime/worker_cache.h @@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" namespace tensorflow { -typedef std::function StatusCallback; +typedef std::function StatusCallback; class ChannelCache; class StepStats; @@ -75,11 +75,11 @@ class WorkerCacheInterface { // construct client cache of different types sharing the same underling RPC // channels, to replace the eager and coordination cache function. // Build and return a EagerClientCache object wrapping that channel. - virtual Status GetEagerClientCache( + virtual absl::Status GetEagerClientCache( std::unique_ptr* eager_client_cache) = 0; // Build and return a CoordinationClientCache object wrapping that channel. - virtual Status GetCoordinationClientCache( + virtual absl::Status GetCoordinationClientCache( std::unique_ptr* coordination_client_cache) = 0; // Start/stop logging activity. diff --git a/tensorflow/core/distributed_runtime/worker_cache_partial.cc b/tensorflow/core/distributed_runtime/worker_cache_partial.cc index f224094f5f7d5b..58b130228e00dd 100644 --- a/tensorflow/core/distributed_runtime/worker_cache_partial.cc +++ b/tensorflow/core/distributed_runtime/worker_cache_partial.cc @@ -43,7 +43,7 @@ void WorkerCachePartial::GetDeviceLocalityAsync(const string& device_name, if (!GetDeviceLocalityNonBlocking(device_name, locality)) { // If cache entry was empty, make one try to fill it by RPC. SchedClosure([this, &device_name, locality, done]() { - Status s = RefreshDeviceStatus(device_name); + absl::Status s = RefreshDeviceStatus(device_name); if (s.ok() && !GetDeviceLocalityNonBlocking(device_name, locality)) { s = errors::Unavailable("No known remote device: ", device_name); } @@ -54,10 +54,11 @@ void WorkerCachePartial::GetDeviceLocalityAsync(const string& device_name, done(absl::OkStatus()); } -Status WorkerCachePartial::RefreshDeviceStatus(const string& device_name) { +absl::Status WorkerCachePartial::RefreshDeviceStatus( + const string& device_name) { string task; string device; - Status s; + absl::Status s; if (!DeviceNameUtils::SplitDeviceName(device_name, &task, &device)) { s = errors::InvalidArgument("Bad device name to RefreshDeviceStatus: ", device_name); diff --git a/tensorflow/core/distributed_runtime/worker_cache_partial.h b/tensorflow/core/distributed_runtime/worker_cache_partial.h index 57f2aca898bef6..b5a500b86dae00 100644 --- a/tensorflow/core/distributed_runtime/worker_cache_partial.h +++ b/tensorflow/core/distributed_runtime/worker_cache_partial.h @@ -47,7 +47,7 @@ class WorkerCachePartial : public WorkerCacheInterface { // Initiate a GetStatusAsync to the remote task named by "task", and // update the cache with all the DeviceAttributes reported. - Status RefreshDeviceStatus(const string& device_name); + absl::Status RefreshDeviceStatus(const string& device_name); typedef std::unordered_map StatusMap; StatusMap device_status_cache_ TF_GUARDED_BY(mu_); diff --git a/tensorflow/core/distributed_runtime/worker_cache_wrapper.h b/tensorflow/core/distributed_runtime/worker_cache_wrapper.h index 195da534b45ce6..7f709b4fb5c1bb 100644 --- a/tensorflow/core/distributed_runtime/worker_cache_wrapper.h +++ b/tensorflow/core/distributed_runtime/worker_cache_wrapper.h @@ -54,13 +54,14 @@ class WorkerCacheWrapper : public WorkerCacheInterface { return wrapped_->ReleaseWorker(target, worker); } - Status GetEagerClientCache( + absl::Status GetEagerClientCache( std::unique_ptr* eager_client_cache) override { return wrapped_->GetEagerClientCache(eager_client_cache); } - Status GetCoordinationClientCache(std::unique_ptr* - coordination_client_cache) override { + absl::Status GetCoordinationClientCache( + std::unique_ptr* coordination_client_cache) + override { return wrapped_->GetCoordinationClientCache(coordination_client_cache); } diff --git a/tensorflow/core/distributed_runtime/worker_interface.h b/tensorflow/core/distributed_runtime/worker_interface.h index 7b759eef95b9df..382425bb51cb30 100644 --- a/tensorflow/core/distributed_runtime/worker_interface.h +++ b/tensorflow/core/distributed_runtime/worker_interface.h @@ -28,7 +28,7 @@ limitations under the License. namespace tensorflow { // Status callback. -typedef std::function StatusCallback; +typedef std::function StatusCallback; // Custom decoder for a response to RecvTensorAsync. class TensorResponse; @@ -68,7 +68,7 @@ class WorkerInterface { new NonOwnedProtoRunGraphResponse(response); RunGraphAsync(opts, wrapped_request, wrapped_response, [wrapped_request, wrapped_response, - done = std::move(done)](const Status& s) { + done = std::move(done)](const absl::Status& s) { done(s); delete wrapped_request; delete wrapped_response; @@ -129,12 +129,12 @@ class WorkerInterface { GetStepSequenceResponse* response, StatusCallback done) = 0; - Status GetStatus(const GetStatusRequest* request, - GetStatusResponse* response) { - Status ret; + absl::Status GetStatus(const GetStatusRequest* request, + GetStatusResponse* response) { + absl::Status ret; Notification n; GetStatusAsync(/*opts=*/nullptr, request, response, /*fail_fast=*/true, - [&ret, &n](const Status& s) { + [&ret, &n](const absl::Status& s) { ret = s; n.Notify(); }); @@ -142,47 +142,49 @@ class WorkerInterface { return ret; } - Status CreateWorkerSession(const CreateWorkerSessionRequest* request, - CreateWorkerSessionResponse* response) { + absl::Status CreateWorkerSession(const CreateWorkerSessionRequest* request, + CreateWorkerSessionResponse* response) { return CallAndWait(&ME::CreateWorkerSessionAsync, request, response); } - Status DeleteWorkerSession(const DeleteWorkerSessionRequest* request, - DeleteWorkerSessionResponse* response) { + absl::Status DeleteWorkerSession(const DeleteWorkerSessionRequest* request, + DeleteWorkerSessionResponse* response) { return CallAndWaitWithOptions(&ME::DeleteWorkerSessionAsync, request, response); } - Status RegisterGraph(const RegisterGraphRequest* request, - RegisterGraphResponse* response) { + absl::Status RegisterGraph(const RegisterGraphRequest* request, + RegisterGraphResponse* response) { return CallAndWait(&ME::RegisterGraphAsync, request, response); } - Status DeregisterGraph(const DeregisterGraphRequest* request, - DeregisterGraphResponse* response) { + absl::Status DeregisterGraph(const DeregisterGraphRequest* request, + DeregisterGraphResponse* response) { return CallAndWait(&ME::DeregisterGraphAsync, request, response); } - Status CleanupGraph(const CleanupGraphRequest* request, - CleanupGraphResponse* response) { + absl::Status CleanupGraph(const CleanupGraphRequest* request, + CleanupGraphResponse* response) { return CallAndWait(&ME::CleanupGraphAsync, request, response); } - Status CleanupAll(const CleanupAllRequest* request, - CleanupAllResponse* response) { + absl::Status CleanupAll(const CleanupAllRequest* request, + CleanupAllResponse* response) { return CallAndWait(&ME::CleanupAllAsync, request, response); } - Status Logging(const LoggingRequest* request, LoggingResponse* response) { + absl::Status Logging(const LoggingRequest* request, + LoggingResponse* response) { return CallAndWait(&ME::LoggingAsync, request, response); } - Status Tracing(const TracingRequest* request, TracingResponse* response) { + absl::Status Tracing(const TracingRequest* request, + TracingResponse* response) { return CallAndWait(&ME::TracingAsync, request, response); } - Status GetStepSequence(const GetStepSequenceRequest* request, - GetStepSequenceResponse* response) { + absl::Status GetStepSequence(const GetStepSequenceRequest* request, + GetStepSequenceResponse* response) { return CallAndWait(&ME::GetStepSequenceAsync, request, response); } @@ -204,10 +206,10 @@ class WorkerInterface { typedef WorkerInterface ME; template - Status CallAndWait(Method func, const Req* req, Resp* resp) { - Status ret; + absl::Status CallAndWait(Method func, const Req* req, Resp* resp) { + absl::Status ret; Notification n; - (this->*func)(req, resp, [&ret, &n](const Status& s) { + (this->*func)(req, resp, [&ret, &n](const absl::Status& s) { ret = s; n.Notify(); }); @@ -216,11 +218,11 @@ class WorkerInterface { } template - Status CallAndWaitWithOptions(Method func, const Req* req, Resp* resp) { + absl::Status CallAndWaitWithOptions(Method func, const Req* req, Resp* resp) { CallOptions call_opts; - Status ret; + absl::Status ret; Notification n; - (this->*func)(&call_opts, req, resp, [&ret, &n](const Status& s) { + (this->*func)(&call_opts, req, resp, [&ret, &n](const absl::Status& s) { ret = s; n.Notify(); }); diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc index 1bbf1a7bb6c329..1b4592a54d9dbf 100644 --- a/tensorflow/core/distributed_runtime/worker_session.cc +++ b/tensorflow/core/distributed_runtime/worker_session.cc @@ -76,13 +76,14 @@ class WorkerFreeListCache : public WorkerCacheInterface { } } - Status GetEagerClientCache( + absl::Status GetEagerClientCache( std::unique_ptr* eager_client_cache) override { return wrapped_->GetEagerClientCache(eager_client_cache); } - Status GetCoordinationClientCache(std::unique_ptr* - coordination_client_cache) override { + absl::Status GetCoordinationClientCache( + std::unique_ptr* coordination_client_cache) + override { return wrapped_->GetCoordinationClientCache(coordination_client_cache); } @@ -146,7 +147,7 @@ WorkerSession::WorkerSession( worker_session_created->GetCell()->Set(true); } -Status WorkerSession::UpdateWorkerCacheAndDevices( +absl::Status WorkerSession::UpdateWorkerCacheAndDevices( std::unique_ptr new_worker_cache, std::vector> added_remote_devices, const std::vector& removed_remote_devices) { @@ -197,7 +198,7 @@ WorkerSession::WorkerSession( WorkerSession::~WorkerSession() { if (graph_mgr_) { - Status s = graph_mgr_->DeregisterAll(); + absl::Status s = graph_mgr_->DeregisterAll(); if (!s.ok()) { LOG(WARNING) << "Error during worker session deletion: " << s; } diff --git a/tensorflow/core/distributed_runtime/worker_session.h b/tensorflow/core/distributed_runtime/worker_session.h index b97fe1a99292d9..e366accf18075b 100644 --- a/tensorflow/core/distributed_runtime/worker_session.h +++ b/tensorflow/core/distributed_runtime/worker_session.h @@ -90,7 +90,7 @@ class WorkerSession { // Update an existing worker session with new set of remote workers and // devices. Added devices will be owned by the worker session, and removed // devices will be freed by their names. - Status UpdateWorkerCacheAndDevices( + absl::Status UpdateWorkerCacheAndDevices( std::unique_ptr new_worker_cache, std::vector> added_remote_devices, const std::vector& removed_remote_devices); diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index 5b20f183865e58..272da5127eecb5 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -31,7 +31,7 @@ default_visibility = [ "//tensorflow/core:__subpackages__", "//tensorflow/security/fuzzing:__subpackages__", # TODO(pedaveeraiah): to be removed when summary.proto.h deps moves to TSL - "@local_tsl//tsl/lib:__subpackages__", + "@local_xla//xla/tsl/lib:__subpackages__", # copybara:uncomment "//learning/brain/tfrt/aot:__subpackages__", # copybara:uncomment "//platforms/xla/megascale/tensorflow:__subpackages__", # copybara:uncomment "//learning/brain/experimental/tfrt/native_lowering/graph_executor:__subpackages__", @@ -575,7 +575,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/status", ], ) @@ -719,9 +719,11 @@ cc_library( visibility = default_visibility + [ "//learning/brain/experimental/jax_pst/grpc/tf:__pkg__", "//learning/brain/google/data/core/kernels:__pkg__", + "//learning/brain/google/data/core/parse_proto:__pkg__", "//learning/deepmind/tensorflow/queues:__pkg__", "//learning/deepmind/tensorflow/sstable:__pkg__", "//learning/deepmind/video/tensorflow:__pkg__", + "//learning/sibyl/tfx/state/kernels:__pkg__", "//learning/sibyl/tfx/transformation/kernels:__pkg__", "//tensorflow/compiler/mlir/tools/kernel_gen:__pkg__", "//tensorflow/compiler/tf2xla:__pkg__", @@ -1410,7 +1412,7 @@ tf_cc_tests( "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", "@eigen_archive//:eigen3", - "@local_tsl//tsl/profiler/utils:xplane_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_utils", ], ) @@ -1569,6 +1571,7 @@ tf_proto_library( ":types_proto", ":versions_proto", ], + visibility = ["//visibility:public"], ) tf_proto_library( @@ -1801,13 +1804,15 @@ tf_proto_library( ":tensor_shape_proto", ":types_proto", ], + visibility = [ + "//tensorflow/python:__pkg__", + ] + default_visibility, ) # copybara:uncomment_begin(google-only) # py_proto_library( # name = "function_proto_py_pb2", # has_services = 0, -# api_version = 2, # deps = [ # ":function_proto", # ], @@ -1823,9 +1828,9 @@ tf_proto_library( ":tensor_proto", ":tensor_shape_proto", ":types_proto", - "@local_tsl//tsl/protobuf:histogram_proto", + "@local_xla//xla/tsl/protobuf:histogram_proto", ], - exports = ["@local_tsl//tsl/protobuf:histogram_proto"], + exports = ["@local_xla//xla/tsl/protobuf:histogram_proto"], ) tf_proto_library( diff --git a/tensorflow/core/framework/allocator_test.cc b/tensorflow/core/framework/allocator_test.cc index 6557a4cec7598e..f1dd62af9e4a3c 100644 --- a/tensorflow/core/framework/allocator_test.cc +++ b/tensorflow/core/framework/allocator_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tensorflow/core/framework/typed_allocator.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mem.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace tensorflow { diff --git a/tensorflow/core/framework/attr_value_util.h b/tensorflow/core/framework/attr_value_util.h index 843d024a999037..fa6b6dda979b1e 100644 --- a/tensorflow/core/framework/attr_value_util.h +++ b/tensorflow/core/framework/attr_value_util.h @@ -45,7 +45,7 @@ class NameAttrList; std::string SummarizeAttrValue(const AttrValue& attr_value); // Generates an error if attr_value doesn't have the indicated attr type. -Status AttrValueHasType(const AttrValue& attr_value, StringPiece type); +absl::Status AttrValueHasType(const AttrValue& attr_value, StringPiece type); // Converts a text proto value from "text" into the field of *out // indicated by "type" (e.g. from the type field of an AttrDef). diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h index ffbdfc0d038c8b..8fca00f0e3b515 100644 --- a/tensorflow/core/framework/collective.h +++ b/tensorflow/core/framework/collective.h @@ -169,16 +169,16 @@ class DeviceResolverInterface { virtual ~DeviceResolverInterface() {} // Populates *attributes with the DeviceAttributes of the specified device. - virtual Status GetDeviceAttributes(const string& device, - DeviceAttributes* attributes) = 0; + virtual absl::Status GetDeviceAttributes(const string& device, + DeviceAttributes* attributes) = 0; // Returns all device attributes of a task. - virtual Status GetAllDeviceAttributes( + virtual absl::Status GetAllDeviceAttributes( const string& task, std::vector* attributes) = 0; // Updates device attributes. It returns error if any device already // exists in the DeviceResolver and has a different incarnation. - virtual Status UpdateDeviceAttributes( + virtual absl::Status UpdateDeviceAttributes( const std::vector& attributes) = 0; }; @@ -213,10 +213,11 @@ class ParamResolverInterface { // Looks up a group. It returns an error if the group is not ready or not // found. - virtual Status LookupGroup(int32_t group_key, CollGroupParams* group) = 0; + virtual absl::Status LookupGroup(int32_t group_key, + CollGroupParams* group) = 0; // Aborts the resolver. After abortion the resolver can no longer be used. - virtual void StartAbort(const Status& s) = 0; + virtual void StartAbort(const absl::Status& s) = 0; }; // Graphs which utilize Collective Ops in a common instance must @@ -255,7 +256,7 @@ class NcclCommunicatorInterface; // instances and various distributed resolution capabilities. class CollectiveExecutorMgrInterface : public StepSequenceInterface { public: - virtual ~CollectiveExecutorMgrInterface() {} + ~CollectiveExecutorMgrInterface() override {} // Returns the step-specific CollectiveExecutor, creating if one does not // already exist. The caller assumes ownership of one Ref on the object. @@ -310,14 +311,14 @@ class CollectiveRemoteAccess { virtual BufRendezvous* buf_rendezvous() = 0; - virtual void StartAbort(const Status& s) = 0; + virtual void StartAbort(const absl::Status& s) = 0; }; // A step-specific object that can execute a collective operation completely // described by a CollectiveParams object. class CollectiveExecutor : public core::RefCounted { public: - virtual void StartAbort(const Status& s) {} + virtual void StartAbort(const absl::Status& s) {} virtual void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams* col_params, @@ -344,7 +345,7 @@ class CollectiveExecutor : public core::RefCounted { cancel_mgr, done); } - virtual Status LookupGroup(int32_t group_key, CollGroupParams* group) { + virtual absl::Status LookupGroup(int32_t group_key, CollGroupParams* group) { return cem_->GetParamResolver()->LookupGroup(group_key, group); } @@ -428,7 +429,7 @@ class NcclCommunicatorInterface { virtual void Enqueue(std::shared_ptr col_ctx, StatusCallback done) = 0; - virtual void StartAbort(const Status& s) = 0; + virtual void StartAbort(const absl::Status& s) = 0; }; // Interface of a Collective Op implementation. Each specific CollectiveOp will @@ -437,7 +438,7 @@ class NcclCommunicatorInterface { // common_runtime/hierarchical_tree_broadcaster for examples. class CollectiveImplementationInterface : public core::RefCounted { public: - virtual ~CollectiveImplementationInterface() = default; + ~CollectiveImplementationInterface() override = default; // Initializes the portions of `col_params` specific to this // implementation. Called exactly once for every Collective instance during @@ -447,13 +448,14 @@ class CollectiveImplementationInterface : public core::RefCounted { // `col_params` passed in and should not manipulate any data members. However // because it is virtual and needs to be implemented by every derived class we // do not mark it as static. - virtual Status InitializeCollectiveParams(CollectiveParams* col_params) = 0; + virtual absl::Status InitializeCollectiveParams( + CollectiveParams* col_params) = 0; // Prepares the CollectiveContext for executing this CollectiveImplementation. // Called from CollectiveExecutor right before calling Run(). The // CollectiveContext passed in must outlive the CollectiveImplementation // object. - virtual Status InitializeCollectiveContext( + virtual absl::Status InitializeCollectiveContext( std::shared_ptr col_ctx) = 0; // Processes and moves data according to the logic of this Collective @@ -471,14 +473,15 @@ class CollectiveRegistry { // Looks up a previously registered CollectiveImplementation under // `collective_name`. If found, creates an instance of the implementation and // assign to `implementation`. - static Status Lookup(const string& collective_name, - CollectiveImplementationInterface** implementation); + static absl::Status Lookup( + const string& collective_name, + CollectiveImplementationInterface** implementation); // Looks up a previously registered CollectiveImplementation under // `collective_name`. If found, returns the static instance of this // implementation via `implementation`. This instance should only be used to // call InitializateCollectiveParams. - static Status LookupParamResolverInstance( + static absl::Status LookupParamResolverInstance( const string& collective_name, CollectiveImplementationInterface** implementation); @@ -493,11 +496,11 @@ class CollectiveRegistry { // the CollectiveImplementation. Also creates a static instance of the // implementation - this instance is used during param resolution and should // only be used to call InitializeCollectiveParams. - static Status Register(const string& collective_name, Factory factory); + static absl::Status Register(const string& collective_name, Factory factory); - static Status LookupHelper(const string& collective_name, - CollectiveImplementationInterface** implementation, - bool param_resolver); + static absl::Status LookupHelper( + const string& collective_name, + CollectiveImplementationInterface** implementation, bool param_resolver); }; // Class used to call CollectiveRegistry::Register. This should only be used to diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 03470e6dd298f9..d50a831826b9df 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -1037,7 +1037,7 @@ struct AnyContext { // defined below. class IteratorBase : public Checkpointable { public: - virtual ~IteratorBase() { + ~IteratorBase() override { for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) { (*rit)(); } diff --git a/tensorflow/core/framework/device_base.cc b/tensorflow/core/framework/device_base.cc index 8e33fc7fc00226..ac2d383f96ef5d 100644 --- a/tensorflow/core/framework/device_base.cc +++ b/tensorflow/core/framework/device_base.cc @@ -34,14 +34,13 @@ DeviceBase::~DeviceBase() { eigen_cpu_devices_.clear(); } -Status DeviceContext::CopyDeviceTensorToCPUSync(const Tensor* device_tensor, - StringPiece tensor_name, - Device* device, - Tensor* cpu_tensor) { +absl::Status DeviceContext::CopyDeviceTensorToCPUSync( + const Tensor* device_tensor, StringPiece tensor_name, Device* device, + Tensor* cpu_tensor) { absl::Notification n; - Status status; + absl::Status status; CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor, - [&](const Status& s) { + [&](const absl::Status& s) { status = s; n.Notify(); }); @@ -49,13 +48,12 @@ Status DeviceContext::CopyDeviceTensorToCPUSync(const Tensor* device_tensor, return status; } -Status DeviceContext::CopyCPUTensorToDeviceSync(const Tensor* cpu_tensor, - Device* device, - Tensor* device_tensor) const { +absl::Status DeviceContext::CopyCPUTensorToDeviceSync( + const Tensor* cpu_tensor, Device* device, Tensor* device_tensor) const { absl::Notification n; - Status status; + absl::Status status; CopyCPUTensorToDevice(cpu_tensor, device, device_tensor, - [&](const Status& s) { + [&](const absl::Status& s) { status = s; n.Notify(); }); diff --git a/tensorflow/core/framework/full_type_inference_util.cc b/tensorflow/core/framework/full_type_inference_util.cc index 285b436e5d3eb2..029ca251b536c2 100644 --- a/tensorflow/core/framework/full_type_inference_util.cc +++ b/tensorflow/core/framework/full_type_inference_util.cc @@ -89,12 +89,13 @@ TypeInferenceFn Merge() { continue; } - return Status(absl::StatusCode::kInvalidArgument, - absl::StrCat("expected compatible input types, but input ", - i, ":\n", t.DebugString(), - " is neither a subtype nor a supertype of the " - "combined inputs preceding it:\n", - merged.DebugString())); + return absl::Status( + absl::StatusCode::kInvalidArgument, + absl::StrCat("expected compatible input types, but input ", i, ":\n", + t.DebugString(), + " is neither a subtype nor a supertype of the " + "combined inputs preceding it:\n", + merged.DebugString())); } FullTypeDef ret_type; @@ -138,9 +139,10 @@ TypeInferenceFn Decode(FullTypeId t, int i) { const FullTypeId enc_tid = GetArgDefaultUnset(in_t, 1).type_id(); if ((enc_tid != TFT_UNSET) && (enc_tid != t)) { - return Status(absl::StatusCode::kInvalidArgument, - absl::StrCat("expected encoded type ", t, " for input ", i, - ", got ", in_t.DebugString())); + return absl::Status( + absl::StatusCode::kInvalidArgument, + absl::StrCat("expected encoded type ", t, " for input ", i, ", got ", + in_t.DebugString())); } FullTypeDef ret_type; @@ -191,7 +193,7 @@ TypeInferenceFn UnaryContainerAdd(FullTypeId t, int container_idx, if (in_cont_t.type_id() != TFT_UNSET) { if (in_cont_t.type_id() != t) { - return Status( + return absl::Status( absl::StatusCode::kInvalidArgument, absl::StrCat("expected container type ", t, " for input ", container_idx, ", got ", in_cont_t.DebugString())); @@ -225,14 +227,15 @@ TypeInferenceFn UnaryContainerAdd(FullTypeId t, int container_idx, } if (homogeneous) { - return Status(absl::StatusCode::kInvalidArgument, - absl::StrCat("expected a subtype of ", el_t.DebugString(), - " for input ", element_idx, - " of a homogeneous container ", t, ", got ", - in_el_t.DebugString())); + return absl::Status( + absl::StatusCode::kInvalidArgument, + absl::StrCat("expected a subtype of ", el_t.DebugString(), + " for input ", element_idx, + " of a homogeneous container ", t, ", got ", + in_el_t.DebugString())); } else { // TODO(mdan): Implement if needed. - return Status( + return absl::Status( absl::StatusCode::kUnimplemented, absl::StrCat("need union types for heterogeneous containers.\n" "A homogeneous container would expect a subtype of ", @@ -287,9 +290,10 @@ TypeInferenceFn ContainerMap( return ret_type; } if (in_cont_t.type_id() != t) { - return Status(absl::StatusCode::kInvalidArgument, - absl::StrCat("expected type ", t, " for input ", input_idx, - ", got ", in_cont_t.DebugString())); + return absl::Status( + absl::StatusCode::kInvalidArgument, + absl::StrCat("expected type ", t, " for input ", input_idx, ", got ", + in_cont_t.DebugString())); } ret_type.set_type_id(TFT_PRODUCT); FullTypeDef* out_cont_t = ret_type.add_args(); @@ -299,9 +303,10 @@ TypeInferenceFn ContainerMap( return ret_type; } if (in_el_t.type_id() != TFT_PRODUCT) { - return Status(absl::StatusCode::kInvalidArgument, - absl::StrCat("expected PRODUCT element type for input ", - input_idx, ", got ", in_el_t.DebugString())); + return absl::Status( + absl::StatusCode::kInvalidArgument, + absl::StrCat("expected PRODUCT element type for input ", input_idx, + ", got ", in_el_t.DebugString())); } FullTypeDef* out_el_t = out_cont_t->add_args(); out_el_t->set_type_id(TFT_PRODUCT); @@ -324,9 +329,10 @@ TypeInferenceFn MapCovariant(FullTypeId t, FullTypeId u, int input_idx) { return ret_type; } if (in_t.type_id() != t) { - return Status(absl::StatusCode::kInvalidArgument, - absl::StrCat("expected type ", t, " for input ", - input_idx, ", got ", in_t.DebugString())); + return absl::Status( + absl::StatusCode::kInvalidArgument, + absl::StrCat("expected type ", t, " for input ", input_idx, + ", got ", in_t.DebugString())); } ret_type.set_type_id(TFT_PRODUCT); FullTypeDef* t = ret_type.add_args(); @@ -368,7 +374,7 @@ TypeInferenceFn Tuple(const std::vector& func_list) { return unset_type; } if (t.type_id() != TFT_PRODUCT) { - return Status( + return absl::Status( absl::StatusCode::kInvalidArgument, absl::StrCat("for Tuple type inference function, expected result " "of type inference function ", diff --git a/tensorflow/core/framework/full_type_util.h b/tensorflow/core/framework/full_type_util.h index 27d10e81d8d990..4039f3c812ad64 100644 --- a/tensorflow/core/framework/full_type_util.h +++ b/tensorflow/core/framework/full_type_util.h @@ -75,8 +75,8 @@ OpTypeConstructor VariadicTensorContainer(FullTypeId t, const string& var_name); // specified in an op definition. Such types are usually generic and dependent // on input types. This function resolves the output types based on the input // types specified in a given node def. -Status SpecializeType(const AttrSlice& attrs, const OpDef& op_def, - FullTypeDef& target); +absl::Status SpecializeType(const AttrSlice& attrs, const OpDef& op_def, + FullTypeDef& target); const FullTypeDef& GetArgDefaultUnset(const FullTypeDef& t, int i); const FullTypeDef& GetArgDefaultAny(const FullTypeDef& t, int i); diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 3cbfde22c75bc6..4c81d7b79ed457 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -32,6 +32,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/types/optional.h" #include "absl/types/variant.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/cancellation.h" @@ -51,7 +52,6 @@ limitations under the License. #include "tensorflow/core/platform/stack_frame.h" #include "tensorflow/core/platform/threadpool_interface.h" #include "tensorflow/core/protobuf/config.pb.h" -#include "tsl/protobuf/error_codes.pb.h" #if !defined(IS_MOBILE_PLATFORM) #include "tensorflow/core/protobuf/remote_tensor_handle.pb.h" #endif // IS_MOBILE_PLATFORM @@ -256,7 +256,7 @@ inline FunctionDefHelper::AttrValueWrapper::AttrValueWrapper(StringPiece val) { // GetFunctionSignature(func name, opdef) returns OK if the func name is found // and opdef is filled with a pointer to the corresponding signature // (a OpDef proto). Otherwise, returns an error. -typedef std::function +typedef std::function GetFunctionSignature; struct InstantiationResult { @@ -264,9 +264,9 @@ struct InstantiationResult { DataTypeVector ret_types; std::vector nodes; }; -Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, - GetFunctionSignature get_function, - InstantiationResult* result); +absl::Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, + GetFunctionSignature get_function, + InstantiationResult* result); // Returns a debug string for a function definition. // @@ -300,7 +300,7 @@ class CallFrameInterface { virtual size_t num_args() const = 0; virtual size_t num_retvals() const = 0; - virtual Status GetArg(int index, const Tensor** val) = 0; + virtual absl::Status GetArg(int index, const Tensor** val) = 0; // Optimized implementation of `GetArg()` that allows the caller to take // ownership of the tensor. This method may only be called once per @@ -313,7 +313,7 @@ class CallFrameInterface { } virtual bool CanConsumeArg(int index) const { return false; } - virtual Status SetRetval(int index, const Tensor& val) = 0; + virtual absl::Status SetRetval(int index, const Tensor& val) = 0; }; // Represents a function call frame. I.e., the data structure used to @@ -328,19 +328,20 @@ class FunctionCallFrame : public CallFrameInterface { ~FunctionCallFrame() override; // Caller methods. - Status SetArgs(absl::Span args); - Status GetRetvals(std::vector* rets) const; + absl::Status SetArgs(absl::Span args); + absl::Status GetRetvals(std::vector* rets) const; // Moves the return values from the frame to rets. If allow_dead_tensors is // false it will fail if any of the retvals do not have a value. - Status ConsumeRetvals(std::vector* rets, bool allow_dead_tensors); + absl::Status ConsumeRetvals(std::vector* rets, + bool allow_dead_tensors); size_t num_args() const override { return arg_types_.size(); } size_t num_retvals() const override { return ret_types_.size(); } // Callee methods. - Status GetArg(int index, const Tensor** val) override; - Status SetRetval(int index, const Tensor& val) override; + absl::Status GetArg(int index, const Tensor** val) override; + absl::Status SetRetval(int index, const Tensor& val) override; private: DataTypeVector arg_types_; @@ -449,12 +450,13 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // // Associates `graph` with a function `func_name`. Lifetime assumption: // `graph` has to outlive all instantiated graphs. - Status AddFunctionDef(const FunctionDef& fdef, - const StackTracesMap& stack_traces = {}) + absl::Status AddFunctionDef(const FunctionDef& fdef, + const StackTracesMap& stack_traces = {}) TF_LOCKS_EXCLUDED(mu_); - Status AddFunctionDef(FunctionDef&& fdef, StackTracesMap&& stack_traces = {}) + absl::Status AddFunctionDef(FunctionDef&& fdef, + StackTracesMap&& stack_traces = {}) TF_LOCKS_EXCLUDED(mu_); - Status AddFunctionRecord(core::RefCountPtr record) + absl::Status AddFunctionRecord(core::RefCountPtr record) TF_LOCKS_EXCLUDED(mu_); // Adds gradient definition 'grad' to this function library. @@ -462,27 +464,27 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // If 'grad' is successfully added, it will be accessible via 'FindGradient' // and included in the proto returned by 'ToProto'. // This operation is atomic. - Status AddGradientDef(const GradientDef& grad) TF_LOCKS_EXCLUDED(mu_); + absl::Status AddGradientDef(const GradientDef& grad) TF_LOCKS_EXCLUDED(mu_); // Replaces the function corresponding to `func` with `fdef`. Returns // a non-OK status if "func" was not found in the library, OK otherwise. // Please be careful when replacing function: make sure all previous pointers // returned by `Find()` are no longer in use. - Status ReplaceFunction(const std::string& func, const FunctionDef& fdef, - const StackTracesMap& stack_traces = {}) + absl::Status ReplaceFunction(const std::string& func, const FunctionDef& fdef, + const StackTracesMap& stack_traces = {}) TF_LOCKS_EXCLUDED(mu_); // Replaces the gradient corresponding to `grad.function_name()`. Returns // a non-OK status if "grad.function_name()" was not found in the library, OK // otherwise. - Status ReplaceGradient(const GradientDef& grad) TF_LOCKS_EXCLUDED(mu_); + absl::Status ReplaceGradient(const GradientDef& grad) TF_LOCKS_EXCLUDED(mu_); // Removes the function corresponding to 'func'. Returns a non-OK status if // 'func' was not found in the library, OK otherwise. // Please be careful when removing function: make sure there are no other // nodes using the function, and all previous pointers returned by `Find()` // are no longer in use. - Status RemoveFunction(const std::string& func) TF_LOCKS_EXCLUDED(mu_); + absl::Status RemoveFunction(const std::string& func) TF_LOCKS_EXCLUDED(mu_); // Removes all the functions and gradient functions. void Clear() TF_LOCKS_EXCLUDED(mu_); @@ -490,24 +492,26 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // Adds the functions and gradients in 'other' to this function library. // Duplicate functions and gradients are ignored. // This operation is atomic. - Status AddLibrary(const FunctionLibraryDefinition& other) + absl::Status AddLibrary(const FunctionLibraryDefinition& other) + TF_LOCKS_EXCLUDED(mu_); + absl::Status AddLibrary(FunctionLibraryDefinition&& other) TF_LOCKS_EXCLUDED(mu_); - Status AddLibrary(FunctionLibraryDefinition&& other) TF_LOCKS_EXCLUDED(mu_); // Adds the functions and gradients in 'lib_def' to this function library. // Duplicate functions and gradients are ignored. This overload adds the // functions with no stack traces. This operation is atomic. - Status AddLibrary(const FunctionDefLibrary& lib_def) TF_LOCKS_EXCLUDED(mu_); - Status AddLibrary(FunctionDefLibrary&& lib_def) TF_LOCKS_EXCLUDED(mu_); + absl::Status AddLibrary(const FunctionDefLibrary& lib_def) + TF_LOCKS_EXCLUDED(mu_); + absl::Status AddLibrary(FunctionDefLibrary&& lib_def) TF_LOCKS_EXCLUDED(mu_); // Adds the functions and gradients in 'lib_def' to this function library. // Duplicate functions and gradients are ignored. // This operation is atomic. - Status AddLibrary(const FunctionDefLibrary& lib_def, - const FunctionDefLibraryStackTraces& library_traces) + absl::Status AddLibrary(const FunctionDefLibrary& lib_def, + const FunctionDefLibraryStackTraces& library_traces) TF_LOCKS_EXCLUDED(mu_); - Status AddLibrary(FunctionDefLibrary&& lib_def, - const FunctionDefLibraryStackTraces& library_traces) + absl::Status AddLibrary(FunctionDefLibrary&& lib_def, + const FunctionDefLibraryStackTraces& library_traces) TF_LOCKS_EXCLUDED(mu_); // If the gradient function for 'func' is specified explicitly in @@ -524,8 +528,8 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // // NB: This function outputs a borrowed pointer, which can be invalidated by a // subsequent call to `ReplaceFunction()` with the given name. - Status LookUp(const std::string& op_type_name, - const OpRegistrationData** op_reg_data) const override + absl::Status LookUp(const std::string& op_type_name, + const OpRegistrationData** op_reg_data) const override TF_LOCKS_EXCLUDED(mu_); // Generates new function name with the specified prefix that is unique @@ -538,13 +542,15 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // iff the attribute is given by the function's definition. // TODO(irving): Remove; keep only the const Node& version. template - Status GetAttr(const NodeDef& ndef, const std::string& attr, T* value) const; + absl::Status GetAttr(const NodeDef& ndef, const std::string& attr, + T* value) const; // Given a node, inspects attributes of the callee function to derive the // attribute 'value' for 'attr'. Returns OK iff the attribute is given by the // function's definition. template - Status GetAttr(const Node& node, const std::string& attr, T* value) const; + absl::Status GetAttr(const Node& node, const std::string& attr, + T* value) const; // Returns a proto representation of the state of this function library. FunctionDefLibrary ToProto() const TF_LOCKS_EXCLUDED(mu_); @@ -579,8 +585,8 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // name `func` already exists in this function library, and has the same // implementation as in `other`. If the implementations conflict, an invalid // argument error is returned. - Status CopyFunctionDefFrom(const std::string& name, - const FunctionLibraryDefinition& other); + absl::Status CopyFunctionDefFrom(const std::string& name, + const FunctionLibraryDefinition& other); // Returns graph with debug stack traces for the given function, or `nullptr` // if none found. @@ -643,14 +649,15 @@ class FunctionLibraryDefinition : public OpRegistryInterface { std::string FindGradientHelper(const std::string& func) const TF_SHARED_LOCKS_REQUIRED(mu_); - Status AddHelper(FunctionRecord* registration, bool* added) + absl::Status AddHelper(FunctionRecord* registration, bool* added) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Same as AddFunctionDef/AddGradientDef except these methods set // `added` to true if the `fdef`/`grad` were actually added to this. - Status AddFunctionDefHelper(FunctionDef&& fdef, StackTracesMap&& stack_traces, - bool* added) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); - Status AddGradientDefHelper(const GradientDef& grad, bool* added) + absl::Status AddFunctionDefHelper(FunctionDef&& fdef, + StackTracesMap&& stack_traces, bool* added) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + absl::Status AddGradientDefHelper(const GradientDef& grad, bool* added) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Helper function for GetAttr. Returns the FunctionDef* to get the @@ -660,19 +667,19 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // Remove all functions in `funcs` and all gradients of functions in // `funcs_with_grads` from this library. - Status Remove(const std::vector& funcs, - const std::vector& funcs_with_grads) + absl::Status Remove(const std::vector& funcs, + const std::vector& funcs_with_grads) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Remove `func` from the library. Returns non-OK Status unless `func` is in // the library. This should only be called when there is a guarantee that the // function being removed hasn't been retrieved with `Find`. - Status RemoveFunctionHelper(const std::string& func) + absl::Status RemoveFunctionHelper(const std::string& func) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Remove gradient of function `func` from the library. Returns non-OK Status // unless `func` has a gradient. - Status RemoveGradient(const std::string& func) + absl::Status RemoveGradient(const std::string& func) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); mutable mutex mu_; @@ -708,7 +715,7 @@ struct FunctionArgIndex { class FunctionLibraryRuntime : public core::WeakRefCounted { public: - virtual ~FunctionLibraryRuntime() {} + ~FunctionLibraryRuntime() override {} // Instantiate a function with the given "attrs". // @@ -815,10 +822,10 @@ class FunctionLibraryRuntime : public core::WeakRefCounted { // If provided, this optimization function will be invoked before // the placer for multi-device functions. - std::function /*ret_node_names*/, - std::vector /*keep_node_names*/, - FunctionLibraryDefinition*, const DeviceSet&, - Device* /*cpu_device*/, std::unique_ptr*)> + std::function /*ret_node_names*/, + std::vector /*keep_node_names*/, + FunctionLibraryDefinition*, const DeviceSet&, + Device* /*cpu_device*/, std::unique_ptr*)> optimize_graph_fn; // If set, partitioned functions will be added to `graph_collector`. @@ -874,17 +881,18 @@ class FunctionLibraryRuntime : public core::WeakRefCounted { bool allow_soft_placement = false; }; typedef uint64 Handle; - virtual Status Instantiate(const std::string& function_name, AttrSlice attrs, - const InstantiateOptions& options, - Handle* handle) = 0; - Status Instantiate(const std::string& function_name, AttrSlice attrs, - Handle* handle) { + virtual absl::Status Instantiate(const std::string& function_name, + AttrSlice attrs, + const InstantiateOptions& options, + Handle* handle) = 0; + absl::Status Instantiate(const std::string& function_name, AttrSlice attrs, + Handle* handle) { auto opts = absl::make_unique(); return Instantiate(function_name, attrs, *opts, handle); } // Releases state associated with the handle. - virtual Status ReleaseHandle(Handle handle) = 0; + virtual absl::Status ReleaseHandle(Handle handle) = 0; // Returns the function body for the instantiated function given its // handle 'h'. Returns nullptr if "h" is not found. @@ -894,7 +902,7 @@ class FunctionLibraryRuntime : public core::WeakRefCounted { virtual const FunctionBody* GetFunctionBody(Handle h) = 0; // Returns the return types for the function identified by handle `h`. - virtual Status GetRetTypes(Handle h, DataTypeVector* ret_types) = 0; + virtual absl::Status GetRetTypes(Handle h, DataTypeVector* ret_types) = 0; // Asynchronously invokes the instantiated function identified by // "handle". @@ -962,24 +970,24 @@ class FunctionLibraryRuntime : public core::WeakRefCounted { // Returns a human readable representation of this. std::string DebugString() const; }; - typedef std::function DoneCallback; + typedef std::function DoneCallback; virtual void Run(const Options& opts, Handle handle, absl::Span args, std::vector* rets, DoneCallback done) = 0; virtual void Run(const Options& opts, Handle handle, CallFrameInterface* call_frame, DoneCallback done) = 0; - virtual Status RunSync(Options opts, Handle handle, - absl::Span args, - std::vector* rets) = 0; - virtual Status RunSync(Options opts, Handle handle, - CallFrameInterface* call_frame) = 0; + virtual absl::Status RunSync(Options opts, Handle handle, + absl::Span args, + std::vector* rets) = 0; + virtual absl::Status RunSync(Options opts, Handle handle, + CallFrameInterface* call_frame) = 0; // Creates a "kernel" for the given NodeProperties "props". // // If succeeds, returns OK and the caller takes the ownership of the // returned "*kernel". Otherwise, returns an error. - virtual Status CreateKernel( + virtual absl::Status CreateKernel( const std::shared_ptr& props, OpKernel** kernel) = 0; @@ -1040,10 +1048,10 @@ class FunctionLibraryRuntime : public core::WeakRefCounted { // FunctionLibraryDefinitions for its functions independently (and passes // these into the FunctionLibraryRuntime through an overlay), to avoid linear // runtime w.r.t. to number of functions in the current function library. - virtual Status Clone(std::unique_ptr* out_lib_def, - std::unique_ptr* out_pflr, - FunctionLibraryRuntime** out_flr, - bool skip_flib_def = false) = 0; + virtual absl::Status Clone( + std::unique_ptr* out_lib_def, + std::unique_ptr* out_pflr, + FunctionLibraryRuntime** out_flr, bool skip_flib_def = false) = 0; // Returns the name of the executor class (in the sense of // `ExecutorFactory::GetFactory()`) that will be used based on the given @@ -1084,7 +1092,7 @@ class CustomKernelCreator { const std::shared_ptr& props) const = 0; // Given a supported NodeDef, returns a kernel that computes the node. - virtual Status CreateKernel( + virtual absl::Status CreateKernel( FunctionLibraryRuntime* flr, const std::shared_ptr& props, std::unique_ptr* kernel) const = 0; @@ -1163,8 +1171,8 @@ class DistributedFunctionLibraryRuntime { // Otherwise (arg_def is a simple type T), *is_type_list is set to // false, and *dtypes is set to a single element vector, whose only // element is T. -Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def, - bool* is_type_list, DataTypeVector* dtypes); +absl::Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def, + bool* is_type_list, DataTypeVector* dtypes); // To register a gradient function for a builtin op, one should use // REGISTER_OP_GRADIENT(, ); @@ -1227,12 +1235,13 @@ Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def, namespace gradient { // Register a gradient creator for the "op". -typedef std::function Creator; +typedef std::function + Creator; bool RegisterOp(const std::string& op, Creator func); // Returns OK the gradient creator for the "op" is found (may be // nullptr if REGISTER_OP_NO_GRADIENT is used. -Status GetOpGradientCreator(const std::string& op, Creator* creator); +absl::Status GetOpGradientCreator(const std::string& op, Creator* creator); }; // namespace gradient // Declare explicit instantiations of GetAttr diff --git a/tensorflow/core/framework/function_handle_cache.h b/tensorflow/core/framework/function_handle_cache.h index 41c73e29bba815..1bd67138d1964f 100644 --- a/tensorflow/core/framework/function_handle_cache.h +++ b/tensorflow/core/framework/function_handle_cache.h @@ -34,13 +34,13 @@ class FunctionHandleCache { // // The cache retains the ownership of the handle. In particular, the caller // should not invoke `ReleaseHandle`. - Status Instantiate(const string& function_name, AttrSlice attrs, - FunctionLibraryRuntime::InstantiateOptions options, - FunctionLibraryRuntime::Handle* handle); + absl::Status Instantiate(const string& function_name, AttrSlice attrs, + FunctionLibraryRuntime::InstantiateOptions options, + FunctionLibraryRuntime::Handle* handle); // Releases all the handles in the cache, clearing out the state for all // functions involved. - Status Clear(); + absl::Status Clear(); private: mutex mu_; diff --git a/tensorflow/core/framework/function_test.cc b/tensorflow/core/framework/function_test.cc index e32c9852bc2902..150d09433c1bf8 100644 --- a/tensorflow/core/framework/function_test.cc +++ b/tensorflow/core/framework/function_test.cc @@ -69,7 +69,7 @@ class Attrs { typedef FunctionDefHelper FDH; -Status GetOpSig(const string& op, const OpDef** sig) { +absl::Status GetOpSig(const string& op, const OpDef** sig) { return OpRegistry::Global()->LookUpOpDef(op, sig); } @@ -633,7 +633,7 @@ TEST(TFunc, IntsOnDeviceArgSet) { EXPECT_EQ("_DeviceRetval", result.nodes[4].op()); } -static void HasError(const Status& s, const string& substr) { +static void HasError(const absl::Status& s, const string& substr) { EXPECT_TRUE(absl::StrContains(s.ToString(), substr)) << ">>" << s << "<<, expected substring >>" << substr << "<<"; } @@ -1109,7 +1109,7 @@ TEST(FunctionLibraryDefinitionTest, AddFunctionDef) { // Test that adding a function with same name as existing op fails. FunctionDef fdef = test::function::XTimesTwo(); fdef.mutable_signature()->set_name("Add"); - Status s = lib_def.AddFunctionDef(fdef); + absl::Status s = lib_def.AddFunctionDef(fdef); EXPECT_FALSE(s.ok()); EXPECT_EQ(s.message(), "Cannot add function 'Add' because an op with the same name " @@ -1151,7 +1151,7 @@ TEST(FunctionLibraryDefinitionTest, AddGradientDef) { // Test that adding a duplicate gradient fails grad.set_gradient_func(test::function::XTimes16().signature().name()); - Status s = lib_def.AddGradientDef(grad); + absl::Status s = lib_def.AddGradientDef(grad); EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT); EXPECT_EQ(s.message(), "Cannot assign gradient function 'XTimes16' to 'XTimesTwo' because " @@ -1162,7 +1162,7 @@ TEST(FunctionLibraryDefinitionTest, RemoveFunction) { FunctionLibraryDefinition lib_def(OpRegistry::Global(), FunctionDefLibrary()); TF_CHECK_OK(lib_def.AddFunctionDef(test::function::XTimesTwo())); - Status s = lib_def.RemoveFunction("XTimes16"); + absl::Status s = lib_def.RemoveFunction("XTimes16"); EXPECT_FALSE(s.ok()); EXPECT_EQ(s.message(), "Tried to remove non-existent function 'XTimes16'."); @@ -1200,7 +1200,7 @@ TEST(FunctionLibraryDefinitionTest, AddLibrary) { test::function::XTimesTwo().signature().name()); *proto.add_function() = fdef; FunctionLibraryDefinition lib_def2(OpRegistry::Global(), proto); - Status s = lib_def.AddLibrary(lib_def2); + absl::Status s = lib_def.AddLibrary(lib_def2); EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT); EXPECT_EQ(s.message(), "Cannot add function 'XTimesTwo' because a different function with " @@ -1248,7 +1248,7 @@ TEST(FunctionLibraryDefinitionTest, AddLibrary_Atomic) { FunctionLibraryDefinition lib_def(OpRegistry::Global(), FunctionDefLibrary()); // Try adding the two functions to lib_def - Status s = lib_def.AddLibrary(proto); + absl::Status s = lib_def.AddLibrary(proto); EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code()); EXPECT_EQ( "Cannot add function 'XTimesTwo' because a different function with " @@ -1299,7 +1299,7 @@ TEST(FunctionLibraryDefinitionTest, AddLibraryDefinition_Atomic_FuncConflict) { // Verify that adding lib_def2 will fail because of function conflict // and WXPlusB is not added. - Status s = lib_def.AddLibrary(lib_def2); + absl::Status s = lib_def.AddLibrary(lib_def2); EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code()); EXPECT_EQ( "Cannot add function 'XTimesTwo' because a different function " @@ -1335,7 +1335,7 @@ TEST(FunctionLibraryDefinitionTest, AddLibraryDefinition_Atomic_GradConflict) { // Verify that adding lib_def2 will fail because of gradient conflict // and WXPlusB is not added. - Status s = lib_def.AddLibrary(lib_def2); + absl::Status s = lib_def.AddLibrary(lib_def2); EXPECT_EQ(error::Code::INVALID_ARGUMENT, s.code()); EXPECT_EQ( "Cannot assign gradient function 'WXPlusB' to 'XTimesTwo'" diff --git a/tensorflow/core/framework/graph_def_util.h b/tensorflow/core/framework/graph_def_util.h index 9ebe610bdca295..a164ac310fe4ed 100644 --- a/tensorflow/core/framework/graph_def_util.h +++ b/tensorflow/core/framework/graph_def_util.h @@ -43,7 +43,7 @@ string SummarizeGraphDef(const GraphDef& graph_def); // DataInput = NodeName, ( ":", [1-9], [0-9] * ) ? // ControlInput = "^", NodeName // NodeName = [A-Za-z0-9.], [A-Za-z0-9_./] * -Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def); +absl::Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def); // Adds default attributes to NodeDefs in 'graph_def' starting // from the 'node_offset' node in 'graph_def'. @@ -54,15 +54,15 @@ Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def); // that cannot be found in 'op_registry'. // // REQUIRES: 'graph_def' and 'op_registry' are not nullptr. -Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, - const OpRegistryInterface& op_registry, - int node_offset); +absl::Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, + const OpRegistryInterface& op_registry, + int node_offset); // Same as above, except for the fact that it skips nodes that aren't found in // op_registry if skip_unknown_ops is true. -Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, - const OpRegistryInterface& op_registry, - int node_offset, bool skip_unknown_ops); +absl::Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, + const OpRegistryInterface& op_registry, + int node_offset, bool skip_unknown_ops); // Remove attrs from 'graph_def' that have the default value according // to 'producer_op_registry', but don't exist according to @@ -94,7 +94,7 @@ Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, // OpListOpRegistry producer_op_registry(producer_stripped_op_list); // TF_RETURN_IF_ERROR(RemoveNewDefaultAttrsFromGraphDef( // &graph_def, *OpRegistry::Global(), producer_op_registry, nullptr)); -Status RemoveNewDefaultAttrsFromGraphDef( +absl::Status RemoveNewDefaultAttrsFromGraphDef( GraphDef* graph_def, const OpRegistryInterface& consumer_op_registry, const OpRegistryInterface& producer_op_registry, std::set>* op_attr_removed); @@ -126,9 +126,9 @@ void OpsUsedByGraph(const GraphDef& graph_def, // // Most users will pass *OpRegistry::Global() for op_registry to strip against // the list of ops registered in this process. -Status StrippedOpListForGraph(const GraphDef& graph_def, - const OpRegistryInterface& op_registry, - OpList* stripped_op_list); +absl::Status StrippedOpListForGraph(const GraphDef& graph_def, + const OpRegistryInterface& op_registry, + OpList* stripped_op_list); } // namespace tensorflow diff --git a/tensorflow/core/framework/graph_def_util_test.cc b/tensorflow/core/framework/graph_def_util_test.cc index 14089aca51ef1d..78731dcee992db 100644 --- a/tensorflow/core/framework/graph_def_util_test.cc +++ b/tensorflow/core/framework/graph_def_util_test.cc @@ -28,9 +28,9 @@ limitations under the License. namespace tensorflow { namespace { -Status FinalizeOpDef(const OpDefBuilder& b, OpDef* op_def) { +absl::Status FinalizeOpDef(const OpDefBuilder& b, OpDef* op_def) { OpRegistrationData op_reg_data; - const Status s = b.Finalize(&op_reg_data); + const absl::Status s = b.Finalize(&op_reg_data); *op_def = op_reg_data.op_def; return s; } diff --git a/tensorflow/core/framework/graph_to_functiondef.h b/tensorflow/core/framework/graph_to_functiondef.h index 5af678645527bd..369b86ecea5e03 100644 --- a/tensorflow/core/framework/graph_to_functiondef.h +++ b/tensorflow/core/framework/graph_to_functiondef.h @@ -29,17 +29,17 @@ namespace tensorflow { // Graph to FunctionDef conversion. This code is closely modeled on the Python // function graph_to_function_def(), which is located in // tensorflow/python/framework/graph_to_function_def.py. -Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, - bool append_hash_to_fn_name, - bool set_stateful_from_nodes, - bool copy_placeholder_attrs_from_nodes, - const std::vector& body_nodes, - const std::vector& inputs, - const std::vector& outputs, - const std::vector& output_names, - const std::vector& control_outputs, - const std::vector& control_output_names, - const char* description, FunctionDef* fdef); +absl::Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, + bool append_hash_to_fn_name, + bool set_stateful_from_nodes, + bool copy_placeholder_attrs_from_nodes, + const std::vector& body_nodes, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& output_names, + const std::vector& control_outputs, + const std::vector& control_output_names, + const char* description, FunctionDef* fdef); // Converts 'graph' to a FunctionDef 'fdef', with name 'name': // @@ -49,19 +49,19 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, // be added to the function `control_ret` map (see FunctionDef) and // `control_output` in Op definition (see OpDef). Control output name must // be unique for all control output nodes. -Status GraphToFunctionDef( +absl::Status GraphToFunctionDef( const Graph& graph, const string& name, const std::function(const Node*)>& control_ret, FunctionDef* fdef); -Status GraphToFunctionDef(const Graph& graph, const string& name, - FunctionDef* fdef); +absl::Status GraphToFunctionDef(const Graph& graph, const string& name, + FunctionDef* fdef); -Status GraphToFunctionDef(const Graph& graph, const string& name, - const std::vector& output_names, - FunctionDef* fdef); +absl::Status GraphToFunctionDef(const Graph& graph, const string& name, + const std::vector& output_names, + FunctionDef* fdef); -Status GraphToFunctionDef( +absl::Status GraphToFunctionDef( std::unique_ptr graph, const string& name, const std::function(const Node*)>& control_ret, FunctionDef* fdef); diff --git a/tensorflow/core/framework/kernel_def_util.h b/tensorflow/core/framework/kernel_def_util.h index b973cefc4f4d24..b60b3b2c95a0f6 100644 --- a/tensorflow/core/framework/kernel_def_util.h +++ b/tensorflow/core/framework/kernel_def_util.h @@ -23,8 +23,8 @@ namespace tensorflow { // Returns whether the attrs satisfy the constraints in the kernel_def. Returns // an error if attrs in kernel_def are not found, or have a mismatching type. -Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs, - bool* match); +absl::Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs, + bool* match); } // namespace tensorflow diff --git a/tensorflow/core/framework/kernel_shape_util.h b/tensorflow/core/framework/kernel_shape_util.h index 551a863e3d38e5..6d444e18a6adf7 100644 --- a/tensorflow/core/framework/kernel_shape_util.h +++ b/tensorflow/core/framework/kernel_shape_util.h @@ -72,10 +72,10 @@ namespace tensorflow { // size and padding of each spatial dimension can be computed by calling // GetWindowedOutputSize separately for each dimension. // -Status GetWindowedOutputSize(int64_t input_size, int64_t filter_size, - int dilation_rate, int64_t stride, - Padding padding_type, int64_t* output_size, - int64_t* padding_size); +absl::Status GetWindowedOutputSize(int64_t input_size, int64_t filter_size, + int dilation_rate, int64_t stride, + Padding padding_type, int64_t* output_size, + int64_t* padding_size); // Returns the same output dimensions as in GetWindowedOutputSize, but returns // verbose padding dimensions (before/after), and EXPLICIT padding is supported. @@ -84,12 +84,10 @@ Status GetWindowedOutputSize(int64_t input_size, int64_t filter_size, // *padding_before and *padding_after are set by this function, and any // excess padding (caused by an odd padding size value) is added to the // 'padding_after' dimension. -Status GetWindowedOutputSizeVerbose(int64_t input_size, int64_t filter_size, - int64_t dilation_rate, int64_t stride, - Padding padding_type, - int64_t* output_size, - int64_t* padding_before, - int64_t* padding_after); +absl::Status GetWindowedOutputSizeVerbose( + int64_t input_size, int64_t filter_size, int64_t dilation_rate, + int64_t stride, Padding padding_type, int64_t* output_size, + int64_t* padding_before, int64_t* padding_after); // Given an input tensor, kernel, stride and padding type, populates the 3D size // of the output tensor and padding to be applied to the input tensor at the @@ -99,13 +97,13 @@ Status GetWindowedOutputSizeVerbose(int64_t input_size, int64_t filter_size, // padding is not supported. // The V2 version computes the same outputs with arbitrary dilation_rate. For // detailed equations, refer to the comments for GetWindowedOutputSize(). -Status Get3dOutputSizeV2(const std::array& input, - const std::array& window, - const std::array& dilations, - const std::array& strides, - Padding padding_type, - std::array* output_ptr, - std::array* padding_ptr); +absl::Status Get3dOutputSizeV2(const std::array& input, + const std::array& window, + const std::array& dilations, + const std::array& strides, + Padding padding_type, + std::array* output_ptr, + std::array* padding_ptr); } // namespace tensorflow #endif // TENSORFLOW_CORE_FRAMEWORK_KERNEL_SHAPE_UTIL_H_ diff --git a/tensorflow/core/framework/local_rendezvous.h b/tensorflow/core/framework/local_rendezvous.h index 48affe938776f3..332daaa6c02060 100644 --- a/tensorflow/core/framework/local_rendezvous.h +++ b/tensorflow/core/framework/local_rendezvous.h @@ -49,14 +49,14 @@ class LocalRendezvous { table_buckets_(std::make_unique(num_buckets_)) {} ~LocalRendezvous(); - Status Send(const Rendezvous::ParsedKey& key, - const Rendezvous::Args& send_args, const Tensor& val, - bool is_dead); + absl::Status Send(const Rendezvous::ParsedKey& key, + const Rendezvous::Args& send_args, const Tensor& val, + bool is_dead); void RecvAsync(const Rendezvous::ParsedKey& key, const Rendezvous::Args& recv_args, Rendezvous::DoneCallback done); - void StartAbort(const Status& status); - Status status(); + void StartAbort(const absl::Status& status); + absl::Status status(); // Releases all the references to the aborted rendezvous. Used in unit tests. static void ReleaseAbortedRendezvous() { @@ -65,7 +65,7 @@ class LocalRendezvous { } private: - void DoAbort(const Status& status); + void DoAbort(const absl::Status& status); tsl::core::RefCountPtr GetOwnerRefCountPtr(); @@ -101,7 +101,7 @@ class LocalRendezvous { // Immutable set of buckets. This uses less memory than std::vector. const std::unique_ptr table_buckets_; mutex mu_; - Status status_ TF_GUARDED_BY(mu_); + absl::Status status_ TF_GUARDED_BY(mu_); // We deliberately leak one reference of the aborted rendezvous here, so that // they won't be destructed, and lose the status_. diff --git a/tensorflow/core/framework/lookup_interface.h b/tensorflow/core/framework/lookup_interface.h index f0234af3110e7d..9d673fbca5769f 100644 --- a/tensorflow/core/framework/lookup_interface.h +++ b/tensorflow/core/framework/lookup_interface.h @@ -47,8 +47,8 @@ class LookupInterface : public ResourceBase { // fails. // - In addition, other implementations may provide another non-OK status // specific to their failure modes. - virtual Status Find(OpKernelContext* ctx, const Tensor& keys, Tensor* values, - const Tensor& default_value) = 0; + virtual absl::Status Find(OpKernelContext* ctx, const Tensor& keys, + Tensor* values, const Tensor& default_value) = 0; // Inserts elements into the table. Each element of the key tensor is // associated with the corresponding element in the value tensor. @@ -61,8 +61,8 @@ class LookupInterface : public ResourceBase { // - InvalidArgument: if any of the preconditions on the lookup key or value // fails. // - Unimplemented: if the table does not support insertions. - virtual Status Insert(OpKernelContext* ctx, const Tensor& keys, - const Tensor& values) = 0; + virtual absl::Status Insert(OpKernelContext* ctx, const Tensor& keys, + const Tensor& values) = 0; // Removes elements from the table. // This method is only implemented in mutable tables that can be updated over @@ -73,7 +73,7 @@ class LookupInterface : public ResourceBase { // - OK: when the remove finishes successfully. // - InvalidArgument: if any of the preconditions on the lookup key fails. // - Unimplemented: if the table does not support removals. - virtual Status Remove(OpKernelContext* ctx, const Tensor& keys) = 0; + virtual absl::Status Remove(OpKernelContext* ctx, const Tensor& keys) = 0; // Returns the number of elements in the table. virtual size_t size() const = 0; @@ -82,14 +82,14 @@ class LookupInterface : public ResourceBase { // Note that the shape of the tensors is completely up to the implementation // of the table and can be different than the tensors used for the Insert // function above. - virtual Status ExportValues(OpKernelContext* ctx) = 0; + virtual absl::Status ExportValues(OpKernelContext* ctx) = 0; // Imports previously exported keys and values. // As mentioned above, the shape of the keys and values tensors are determined // by the ExportValues function above and can be different than for the // Insert function. - virtual Status ImportValues(OpKernelContext* ctx, const Tensor& keys, - const Tensor& values) = 0; + virtual absl::Status ImportValues(OpKernelContext* ctx, const Tensor& keys, + const Tensor& values) = 0; // Returns the data type of the key. virtual DataType key_dtype() const = 0; @@ -110,19 +110,19 @@ class LookupInterface : public ResourceBase { // - DataType of the tensor values equals to the table value_dtype // - the values tensor has the required shape given keys and the tables's // value shape. - virtual Status CheckKeyAndValueTensorsForInsert(const Tensor& keys, - const Tensor& values); + virtual absl::Status CheckKeyAndValueTensorsForInsert(const Tensor& keys, + const Tensor& values); // Similar to the function above but instead checks eligibility for the Import // function. - virtual Status CheckKeyAndValueTensorsForImport(const Tensor& keys, - const Tensor& values); + virtual absl::Status CheckKeyAndValueTensorsForImport(const Tensor& keys, + const Tensor& values); // Check format of the key tensor for the Remove function. // Returns OK if all the following requirements are satisfied, otherwise it // returns InvalidArgument: // - DataType of the tensor keys equals to the table key_dtype - virtual Status CheckKeyTensorForRemove(const Tensor& keys); + virtual absl::Status CheckKeyTensorForRemove(const Tensor& keys); // Check the arguments of a find operation. Returns OK if all the following // requirements are satisfied, otherwise it returns InvalidArgument: @@ -130,7 +130,8 @@ class LookupInterface : public ResourceBase { // - DataType of the tensor default_value equals to the table value_dtype // - the default_value tensor has the required shape given keys and the // tables's value shape. - Status CheckFindArguments(const Tensor& keys, const Tensor& default_value); + absl::Status CheckFindArguments(const Tensor& keys, + const Tensor& default_value); string DebugString() const override { return strings::StrCat("A lookup table of size: ", size()); @@ -143,18 +144,18 @@ class LookupInterface : public ResourceBase { } protected: - virtual ~LookupInterface() = default; + ~LookupInterface() override = default; // Makes sure that the key and value tensor DataType's match the table // key_dtype and value_dtype. - Status CheckKeyAndValueTypes(const Tensor& keys, const Tensor& values); + absl::Status CheckKeyAndValueTypes(const Tensor& keys, const Tensor& values); // Makes sure that the provided shape is consistent with the table keys shape. - Status CheckKeyShape(const TensorShape& shape); + absl::Status CheckKeyShape(const TensorShape& shape); private: - Status CheckKeyAndValueTensorsHelper(const Tensor& keys, - const Tensor& values); + absl::Status CheckKeyAndValueTensorsHelper(const Tensor& keys, + const Tensor& values); }; } // namespace lookup diff --git a/tensorflow/core/framework/memory_types.h b/tensorflow/core/framework/memory_types.h index f719131bcb4781..e124722297f4b7 100644 --- a/tensorflow/core/framework/memory_types.h +++ b/tensorflow/core/framework/memory_types.h @@ -28,10 +28,11 @@ class NodeDef; // // REQUIRES: * '*_memory_types' is not nullptr. // * def has all attrs specified (e.g. using AddDefaultsToNodeDef()). -Status MemoryTypesForNode(const OpRegistryInterface* op_registry, - const DeviceType& device_type, const NodeDef& ndef, - MemoryTypeVector* input_memory_types, - MemoryTypeVector* output_memory_types); +absl::Status MemoryTypesForNode(const OpRegistryInterface* op_registry, + const DeviceType& device_type, + const NodeDef& ndef, + MemoryTypeVector* input_memory_types, + MemoryTypeVector* output_memory_types); } // namespace tensorflow diff --git a/tensorflow/core/framework/metrics.cc b/tensorflow/core/framework/metrics.cc index 9dd80d3d0c7835..d94447cbfa1910 100644 --- a/tensorflow/core/framework/metrics.cc +++ b/tensorflow/core/framework/metrics.cc @@ -25,9 +25,9 @@ limitations under the License. #include "xla/tsl/lib/monitoring/counter.h" #include "xla/tsl/lib/monitoring/gauge.h" #include "xla/tsl/lib/monitoring/sampler.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/protobuf/data_service.pb.h" #include "tsl/platform/types.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace metrics { diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h index ff9caa326ff4bb..4c78ec7a51cf56 100644 --- a/tensorflow/core/framework/model.h +++ b/tensorflow/core/framework/model.h @@ -601,12 +601,12 @@ class Node { TF_LOCKS_EXCLUDED(mu_); // Produces a proto for this node. Does not produce a proto for input nodes. - virtual Status ToProto(ModelProto::Node* node_proto) const; + virtual absl::Status ToProto(ModelProto::Node* node_proto) const; // Restores a node from the proto. Does not restore input nodes. - static Status FromProto(ModelProto::Node node_proto, - std::shared_ptr output, - std::shared_ptr* node); + static absl::Status FromProto(ModelProto::Node node_proto, + std::shared_ptr output, + std::shared_ptr* node); // Returns a vector of nodes of the subtree rooted in this node. The nodes are // either in breadth-first search or reverse breadth-first search order @@ -811,8 +811,8 @@ class Node { // Restores node from the proto. Note that this is not done recursively, i.e. // input nodes are not restored. - static Status FromProtoHelper(ModelProto::Node node_proto, - std::shared_ptr node); + static absl::Status FromProtoHelper(ModelProto::Node node_proto, + std::shared_ptr node); // Stores the time passed to the last call to `Node::record_start()` on the // current thread. @@ -974,12 +974,12 @@ class Model { // // To terminate the execution of the optimization loop, the caller needs to // invoke `cancellation_mgr->StartCancel()`. - Status OptimizeLoop(AutotuneAlgorithm algorithm, - std::function cpu_budget_func, - double ram_budget_share, - std::optional fixed_ram_budget, - RamBudgetManager& ram_budget_manager, - CancellationManager* cancellation_manager); + absl::Status OptimizeLoop(AutotuneAlgorithm algorithm, + std::function cpu_budget_func, + double ram_budget_share, + std::optional fixed_ram_budget, + RamBudgetManager& ram_budget_manager, + CancellationManager* cancellation_manager); // Uses the given algorithm and resource budgets to perform the autotuning // optimization. @@ -1006,21 +1006,21 @@ class Model { void RemoveNode(std::shared_ptr node) TF_LOCKS_EXCLUDED(mu_); // Produces a proto for this model. - Status ToProto(ModelProto* model_proto); + absl::Status ToProto(ModelProto* model_proto); // Restores a model from the proto. - static Status FromProto(ModelProto model_proto, - std::unique_ptr* model); + static absl::Status FromProto(ModelProto model_proto, + std::unique_ptr* model); // Saves this model with a given snapshot and its optimization parameters to a // file. Note that the file directory must already exist. - Status Save(const string& fname, std::shared_ptr snapshot, - const OptimizationParams& optimization_params); + absl::Status Save(const string& fname, std::shared_ptr snapshot, + const OptimizationParams& optimization_params); // Loads a model and its optimization parameters from a file with the given // name. - static Status Load(const string& fname, std::unique_ptr* model, - OptimizationParams* optimization_params); + static absl::Status Load(const string& fname, std::unique_ptr* model, + OptimizationParams* optimization_params); // Records gap time between consecutive `GetNext()` calls. void RecordIteratorGapTime(uint64_t duration_usec); diff --git a/tensorflow/core/framework/node_def_builder.h b/tensorflow/core/framework/node_def_builder.h index 84dcc9e5a8d6e3..5a19f774c7a199 100644 --- a/tensorflow/core/framework/node_def_builder.h +++ b/tensorflow/core/framework/node_def_builder.h @@ -34,8 +34,8 @@ limitations under the License. namespace tensorflow { class NodeDefBuilder; -typedef std::function +typedef std::function FakeInputFunctor; // This is a helper for creating a NodeDef. Automatically sets attrs @@ -137,7 +137,7 @@ class NodeDefBuilder { // and the builder will be left in an undefined state. // WARNING: Not all problems are detected! The resulting NodeDef may // not be valid! Call ValidateNodeDef() from node_def_utils to be sure. - Status Finalize(NodeDef* node_def, bool consume = false); + absl::Status Finalize(NodeDef* node_def, bool consume = false); // Accessors for the values set in the constructor. const string& node_name() const { return node_def_.name(); } diff --git a/tensorflow/core/framework/node_def_builder_test.cc b/tensorflow/core/framework/node_def_builder_test.cc index 8531027b232ed1..c89932b13ee518 100644 --- a/tensorflow/core/framework/node_def_builder_test.cc +++ b/tensorflow/core/framework/node_def_builder_test.cc @@ -53,7 +53,7 @@ class NodeDefBuilderTest : public ::testing::Test { DataTypeSlice expected_in_types, DataTypeSlice expected_out_types, StringPiece proto) { NodeDef node_def; - Status status = builder.Finalize(&node_def); + absl::Status status = builder.Finalize(&node_def); TF_EXPECT_OK(status); if (!status.ok()) return; NodeDef expected; @@ -80,7 +80,7 @@ class NodeDefBuilderTest : public ::testing::Test { void ExpectFailures(NodeDefBuilder& builder, // NOLINT const std::vector& messages) { NodeDef node_def; - Status status = builder.Finalize(&node_def); + absl::Status status = builder.Finalize(&node_def); EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def); if (status.ok()) return; for (const string& message : messages) { @@ -101,7 +101,7 @@ class NodeDefBuilderTest : public ::testing::Test { void ExpectInvalid(NodeDefBuilder& builder, // NOLINT const string& message) { NodeDef node_def; - Status status = builder.Finalize(&node_def); + absl::Status status = builder.Finalize(&node_def); if (status.ok()) { status = ValidateNodeDef(node_def, op_def_); } diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h index cac0a537b97f6f..b5eb424a89bd58 100644 --- a/tensorflow/core/framework/node_def_util.h +++ b/tensorflow/core/framework/node_def_util.h @@ -158,9 +158,9 @@ class AttrSlice { // Returns the attr_value for attr_name if found. Otherwise, returns a // NotFound status. - Status Find(StringPiece attr_name, const AttrValue** attr_value) const; - Status FindByString(const std::string& attr_name, - const AttrValue** attr_value) const; + absl::Status Find(StringPiece attr_name, const AttrValue** attr_value) const; + absl::Status FindByString(const std::string& attr_name, + const AttrValue** attr_value) const; // Helper class to avoid allocations in EqualAttrs. // TODO(irving): Will go away once NodeInfo is used. @@ -196,7 +196,8 @@ class AttrSlice { return ndef_ != nullptr ? &ndef_->attr() : attrs_; } - Status CheckFind(StringPiece attr_name, const AttrValue* attr_value) const; + absl::Status CheckFind(StringPiece attr_name, + const AttrValue* attr_value) const; const NodeDef* ndef_; const AttrValueMap* attrs_; @@ -208,53 +209,55 @@ bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name); // Look up the attr with name attr_name and set *value to its value. If no // attr with attr_name is found in node_def, or the attr does not have // a matching type, a non-ok status will be returned. -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - std::string* value); // type: "string" -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - tstring* value); // type: "tstring" -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - int64_t* value); // type: "int" -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - int32* value); // type: "int" -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - float* value); // type: "float" -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - bool* value); // type: "bool" -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - DataType* value); // type: "type" -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - TensorShapeProto* value); // type: "shape" -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - TensorShape* value); // type: "shape" -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - PartialTensorShape* value); // type: "shape" -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - Tensor* value); // type: "tensor" -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - std::vector* value); // type "list(string)" -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - std::vector* value); // type "list(tstring)" -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - std::vector* value); // type "list(int)" -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - std::vector* value); // type "list(int)" -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - std::vector* value); // type "list(float)" -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - std::vector* value); // type "list(bool)" -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - std::vector* value); // type "list(type)" -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - DataTypeVector* value); // type "list(type)" -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - std::vector* value); // type "list(shape)" -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - std::vector* value); // type "list(shape)" -Status GetNodeAttr( +absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + std::string* value); // type: "string" +absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + tstring* value); // type: "tstring" +absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + int64_t* value); // type: "int" +absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + int32* value); // type: "int" +absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + float* value); // type: "float" +absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + bool* value); // type: "bool" +absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + DataType* value); // type: "type" +absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + TensorShapeProto* value); // type: "shape" +absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + TensorShape* value); // type: "shape" +absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + PartialTensorShape* value); // type: "shape" +absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + Tensor* value); // type: "tensor" +absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + std::vector* value); // type "list(string)" +absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + std::vector* value); // type "list(tstring)" +absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + std::vector* value); // type "list(int)" +absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + std::vector* value); // type "list(int)" +absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + std::vector* value); // type "list(float)" +absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + std::vector* value); // type "list(bool)" +absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + std::vector* value); // type "list(type)" +absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + DataTypeVector* value); // type "list(type)" +absl::Status GetNodeAttr( + const AttrSlice& attrs, StringPiece attr_name, + std::vector* value); // type "list(shape)" +absl::Status GetNodeAttr( + const AttrSlice& attrs, StringPiece attr_name, + std::vector* value); // type "list(shape)" +absl::Status GetNodeAttr( const AttrSlice& attrs, StringPiece attr_name, std::vector* value); // type "list(shape)" -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - std::vector* value); // type: "list(tensor)" +absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + std::vector* value); // type: "list(tensor)" template StatusOr GetNodeAttr(const NodeDef& ndef, absl::string_view attr_name) { @@ -265,23 +268,24 @@ StatusOr GetNodeAttr(const NodeDef& ndef, absl::string_view attr_name) { // This version avoids copying the TensorProto. // REQUIRES: Must not use *value beyond the lifetime of node_def. -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - const TensorProto** value); // type: "tensor" +absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + const TensorProto** value); // type: "tensor" bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, const TensorProto** value); // type: "tensor" // This version avoids copying the NameAttrList. // REQUIRES: Must not use *value beyond the lifetime of node_def. -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - const NameAttrList** value); // type: "func" +absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + const NameAttrList** value); // type: "func" bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, const NameAttrList** value); // type: "func" // These versions copies the NameAttrList(s). -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - NameAttrList* value); // type: "func" -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - std::vector* value); // type: "list(func)" +absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + NameAttrList* value); // type: "func" +absl::Status GetNodeAttr( + const AttrSlice& attrs, StringPiece attr_name, + std::vector* value); // type: "list(func)" // Look up the attr with name attr_name and set *value to its value. If no // attr with attr_name is found in node_def, or the attr does not have @@ -334,36 +338,36 @@ const std::string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name); // Specialization to parse an attribute directly into a Padding enum. -Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, - Padding* value); +absl::Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, + Padding* value); // Computes the input type for a specific node input. // REQUIRES: ValidateOpDef(op_def).ok() -Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def, - int input_port, DataType* input_type); +absl::Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def, + int input_port, DataType* input_type); // Computes the input types for a specific node. // REQUIRES: ValidateOpDef(op_def).ok() -Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def, - DataTypeVector* inputs); +absl::Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def, + DataTypeVector* inputs); // Computes the output type for a specific node output. // REQUIRES: ValidateOpDef(op_def).ok() -Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def, - int output_port, DataType* output_type); +absl::Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def, + int output_port, DataType* output_type); // Computes the output types for a specific node. // REQUIRES: ValidateOpDef(op_def).ok() -Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def, - DataTypeVector* outputs); -Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def, - DataTypeVector* outputs); +absl::Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def, + DataTypeVector* outputs); +absl::Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def, + DataTypeVector* outputs); // Computes the input and output types for a specific node. // REQUIRES: ValidateOpDef(op_def).ok() -Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def, - DataTypeVector* inputs, DataTypeVector* outputs); +absl::Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def, + DataTypeVector* inputs, DataTypeVector* outputs); // Computes the number of outputs for a specific node. // REQUIRES: ValidateOpDef(op_def).ok() -Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def, - int* num_outputs); +absl::Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def, + int* num_outputs); // Map a node/op's input/output port_id to arg_id. // @@ -381,7 +385,7 @@ int OpPortIdToArgId(const NodeDef& node, // * All attrs satisfies constraints from the OpDef. // * Has a signature matching SignatureForNode(). // etc. -Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def); +absl::Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def); // Computes the mapping from input/output argument name to the // corresponding input/output index range. For example, @@ -393,8 +397,8 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def); // returned `NameRangeMap` objects. typedef gtl::FlatMap, hash> NameRangeMap; -Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def, - NameRangeMap* inputs, NameRangeMap* outputs); +absl::Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def, + NameRangeMap* inputs, NameRangeMap* outputs); // Adds default values to *node_def for unspecified attrs from op_def. void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def); @@ -413,30 +417,30 @@ void StripDefaultsFromNodeDef(const OpDef& op_def, NodeDef* node_def); // DataInput = NodeName, ( ":", [1-9], [0-9] * ) ? // ControlInput = "^", NodeName // NodeName = [A-Za-z0-9.], [A-Za-z0-9_./] * -Status ValidateExternalNodeDefSyntax(const NodeDef& node_def); +absl::Status ValidateExternalNodeDefSyntax(const NodeDef& node_def); // Returns "status" with formatted NodeDef attached as additional text // in the error message. If 'allow_multiple_formatted_node' is false and there // is already a formatted NodeDef present in 'status', we simply attach the name // of the NodeDef instead of the formatted string. -Status AttachDef(const Status& status, const NodeDef& node_def, - bool allow_multiple_formatted_node = false); +absl::Status AttachDef(const absl::Status& status, const NodeDef& node_def, + bool allow_multiple_formatted_node = false); // Appends the given prefix and suffix to the original node name in order to // make the name unique. If it's an "Enter" node and uniquify_frame_name is // true, use the same way to reset attribute "frame_name". -Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix, - NodeDef* node_def, - bool uniquify_frame_name = true); +absl::Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix, + NodeDef* node_def, + bool uniquify_frame_name = true); // Appends the given prefix to the colocation group name if the name exists // in `to_match`. -Status MaybeAddPrefixToColocationConstraints( +absl::Status MaybeAddPrefixToColocationConstraints( const std::unordered_set& match, StringPiece prefix, NodeDef* node_def); // Updates the colocation constraint name with the one provided in the map (if // it exists in the map) for node_def. -Status MaybeUpdateColocationConstraintsWithMap( +absl::Status MaybeUpdateColocationConstraintsWithMap( const std::map& node_name_map, NodeDef* node_def); diff --git a/tensorflow/core/framework/node_properties.h b/tensorflow/core/framework/node_properties.h index 88489f44fc4acc..91c495bb2c8c1c 100644 --- a/tensorflow/core/framework/node_properties.h +++ b/tensorflow/core/framework/node_properties.h @@ -47,9 +47,9 @@ struct NodeProperties { // from the given NodeDef. 'op_registry' is used to look up the OpDef // corresponding to node_def.op(). Returns an error if OpDef lookup or // creation failed. - static Status CreateFromNodeDef(NodeDef node_def, - const OpRegistryInterface* op_registry, - std::shared_ptr* props); + static absl::Status CreateFromNodeDef( + NodeDef node_def, const OpRegistryInterface* op_registry, + std::shared_ptr* props); const OpDef* op_def; // not owned. NodeDef node_def; diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h index 993ac59112a1a7..41b39fc2076469 100644 --- a/tensorflow/core/framework/op.h +++ b/tensorflow/core/framework/op.h @@ -48,12 +48,12 @@ class OpRegistryInterface { // Returns an error status and sets *op_reg_data to nullptr if no OpDef is // registered under that name, otherwise returns the registered OpDef. // Caller must not delete the returned pointer. - virtual Status LookUp(const std::string& op_type_name, - const OpRegistrationData** op_reg_data) const = 0; + virtual absl::Status LookUp(const std::string& op_type_name, + const OpRegistrationData** op_reg_data) const = 0; // Shorthand for calling LookUp to get the OpDef. - Status LookUpOpDef(const std::string& op_type_name, - const OpDef** op_def) const; + absl::Status LookUpOpDef(const std::string& op_type_name, + const OpDef** op_def) const; }; // The standard implementation of OpRegistryInterface, along with a @@ -68,14 +68,15 @@ class OpRegistryInterface { // }); class OpRegistry : public OpRegistryInterface { public: - typedef std::function OpRegistrationDataFactory; + typedef std::function + OpRegistrationDataFactory; OpRegistry(); void Register(const OpRegistrationDataFactory& op_data_factory); - Status LookUp(const std::string& op_type_name, - const OpRegistrationData** op_reg_data) const override; + absl::Status LookUp(const std::string& op_type_name, + const OpRegistrationData** op_reg_data) const override; // Returns OpRegistrationData* of registered op type, else returns nullptr. const OpRegistrationData* LookUp(const std::string& op_type_name) const; @@ -100,7 +101,7 @@ class OpRegistry : public OpRegistryInterface { // Registers a function that validates op registry. void RegisterValidator( - std::function validator) { + std::function validator) { op_registry_validator_ = std::move(validator); } @@ -110,7 +111,8 @@ class OpRegistry : public OpRegistryInterface { // obtained from building and adding the OpDef to the registry, and the OpDef // itself if it was successfully built. A watcher returns a Status which is in // turn returned as the final registration status. - typedef std::function Watcher; + typedef std::function + Watcher; // An OpRegistry object has only one watcher. This interface is not thread // safe, as different clients are free to set the watcher any time. @@ -122,13 +124,13 @@ class OpRegistry : public OpRegistryInterface { // SetWatcher(nullptr); // Returns a non-OK status if a non-null watcher is over-written by another // non-null watcher. - Status SetWatcher(const Watcher& watcher); + absl::Status SetWatcher(const Watcher& watcher); // Process the current list of deferred registrations. Note that calls to // Export, LookUp and DebugString would also implicitly process the deferred // registrations. Returns the status of the first failed op registration or // OkStatus() otherwise. - Status ProcessRegistrations() const; + absl::Status ProcessRegistrations() const; // Defer the registrations until a later call to a function that processes // deferred registrations are made. Normally, registrations that happen after @@ -148,13 +150,14 @@ class OpRegistry : public OpRegistryInterface { // Calls the functions in deferred_ and registers their OpDef's // It returns the Status of the first failed op registration or OkStatus() // otherwise. - Status CallDeferred() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + absl::Status CallDeferred() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Add 'def' to the registry with additional data 'data'. On failure, or if // there is already an OpDef with that name registered, returns a non-okay // status. - Status RegisterAlreadyLocked(const OpRegistrationDataFactory& op_data_factory) - const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + absl::Status RegisterAlreadyLocked( + const OpRegistrationDataFactory& op_data_factory) const + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); const OpRegistrationData* LookUpSlow(const std::string& op_type_name) const; @@ -169,7 +172,8 @@ class OpRegistry : public OpRegistryInterface { // Registry watcher. mutable Watcher watcher_ TF_GUARDED_BY(mu_); - std::function op_registry_validator_; + std::function + op_registry_validator_; }; // An adapter to allow an OpList to be used as an OpRegistryInterface. @@ -181,8 +185,8 @@ class OpListOpRegistry : public OpRegistryInterface { public: // Does not take ownership of op_list, *op_list must outlive *this. explicit OpListOpRegistry(const OpList* op_list); - Status LookUp(const std::string& op_type_name, - const OpRegistrationData** op_reg_data) const override; + absl::Status LookUp(const std::string& op_type_name, + const OpRegistrationData** op_reg_data) const override; // Returns OpRegistrationData* of op type in list, else returns nullptr. const OpRegistrationData* LookUp(const std::string& op_type_name) const; diff --git a/tensorflow/core/framework/op_compatibility_test.cc b/tensorflow/core/framework/op_compatibility_test.cc index 04adde1c50327d..da11e32498becf 100644 --- a/tensorflow/core/framework/op_compatibility_test.cc +++ b/tensorflow/core/framework/op_compatibility_test.cc @@ -93,7 +93,7 @@ class OpCompatibilityTest : public OpsTestBase { void ExpectIncompatible(const OpDef& old_op_def, const OpDef& new_op_def, const string& error) { // Test OpDefCompatible gives the same answer without the node_def. - Status status = OpDefCompatible(old_op_def, new_op_def); + absl::Status status = OpDefCompatible(old_op_def, new_op_def); if (status.ok()) { ADD_FAILURE() << SummarizeOpDef(old_op_def) << " vs. " << SummarizeOpDef(new_op_def); @@ -115,7 +115,7 @@ class OpCompatibilityTest : public OpsTestBase { AddDefaultsToNodeDef(*new_op_def, node_def()); // Validate that it does not pass validation. - Status status = ValidateNodeDef(*node_def(), *new_op_def); + absl::Status status = ValidateNodeDef(*node_def(), *new_op_def); if (status.ok()) { ADD_FAILURE() << SummarizeNodeDef(*node_def()); } else { @@ -174,7 +174,7 @@ class OpCompatibilityTest : public OpsTestBase { // Validate that the NodeDef is valid. TF_ASSERT_OK(ValidateNodeDef(*node_def(), *new_op_def)); - Status status = OpDefAttrDefaultsUnchanged(old_op_def, *new_op_def); + absl::Status status = OpDefAttrDefaultsUnchanged(old_op_def, *new_op_def); if (status.ok()) { ADD_FAILURE() << SummarizeOpDef(old_op_def) << " vs. " << SummarizeOpDef(*new_op_def); diff --git a/tensorflow/core/framework/op_def_builder.h b/tensorflow/core/framework/op_def_builder.h index bc6058767dcbc5..8009135d584188 100644 --- a/tensorflow/core/framework/op_def_builder.h +++ b/tensorflow/core/framework/op_def_builder.h @@ -33,7 +33,7 @@ limitations under the License. namespace tensorflow { // TODO(b/62899350): Refactor without proto dependencies. -typedef std::function OpTypeConstructor; +typedef std::function OpTypeConstructor; typedef std::vector> TypeRefVector; @@ -61,7 +61,7 @@ class FunctionDefHelper; namespace shape_inference { class InferenceContext; } -typedef std::function +typedef std::function OpShapeInferenceFn; struct OpRegistrationData { @@ -253,7 +253,7 @@ class OpDefBuilder { // // Note that OpDefBuilder only reports parsing errors. You should also // call ValidateOpDef() to detect other problems. - Status Finalize(OpRegistrationData* op_reg_data) const; + absl::Status Finalize(OpRegistrationData* op_reg_data) const; private: friend class FunctionDefHelper; diff --git a/tensorflow/core/framework/op_def_builder_test.cc b/tensorflow/core/framework/op_def_builder_test.cc index 59aeed913e4506..80d2d37545ebe2 100644 --- a/tensorflow/core/framework/op_def_builder_test.cc +++ b/tensorflow/core/framework/op_def_builder_test.cc @@ -43,7 +43,7 @@ class OpDefBuilderTest : public ::testing::Test { void ExpectSuccess(const OpDefBuilder& builder, StringPiece proto, OpShapeInferenceFn* shape_fn_out = nullptr) { OpRegistrationData op_reg_data; - Status status = builder.Finalize(&op_reg_data); + absl::Status status = builder.Finalize(&op_reg_data); TF_EXPECT_OK(status); OpDef& op_def = op_reg_data.op_def; if (status.ok()) { @@ -63,7 +63,7 @@ class OpDefBuilderTest : public ::testing::Test { void ExpectOrdered(const OpDefBuilder& builder, StringPiece proto) { OpRegistrationData op_reg_data; - Status status = builder.Finalize(&op_reg_data); + absl::Status status = builder.Finalize(&op_reg_data); TF_EXPECT_OK(status); OpDef& op_def = op_reg_data.op_def; if (status.ok()) { @@ -76,7 +76,7 @@ class OpDefBuilderTest : public ::testing::Test { void ExpectFailure(const OpDefBuilder& builder, const string& error) { OpRegistrationData op_reg_data; - Status status = builder.Finalize(&op_reg_data); + absl::Status status = builder.Finalize(&op_reg_data); EXPECT_FALSE(status.ok()); if (!status.ok()) { EXPECT_EQ(status.message(), error); diff --git a/tensorflow/core/framework/op_def_util.h b/tensorflow/core/framework/op_def_util.h index 09c6c0cf72e271..e116f89229dc54 100644 --- a/tensorflow/core/framework/op_def_util.h +++ b/tensorflow/core/framework/op_def_util.h @@ -30,16 +30,16 @@ limitations under the License. namespace tensorflow { // Performs a consistency check across the fields of the op_def. -Status ValidateOpDef(const OpDef& op_def); +absl::Status ValidateOpDef(const OpDef& op_def); // Check if an op is deprecated at the given GraphDef version. If the op is // deprecated at a future version, a warning will be logged. -Status CheckOpDeprecation(const OpDef& op_def, int graph_def_version); +absl::Status CheckOpDeprecation(const OpDef& op_def, int graph_def_version); // Validates that attr_value satisfies the type and constraints from attr. // REQUIRES: attr has already been validated. -Status ValidateAttrValue(const AttrValue& attr_value, - const OpDef::AttrDef& attr); +absl::Status ValidateAttrValue(const AttrValue& attr_value, + const OpDef::AttrDef& attr); // The following search through op_def for an attr with the indicated name. // Returns nullptr if no such attr is found. @@ -61,19 +61,20 @@ std::string SummarizeOpDef(const OpDef& op_def); // Returns an error if new_op is not backwards-compatible with (more // accepting than) old_op. // REQUIRES: old_op and new_op must pass validation. -Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op); +absl::Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op); // Returns an error if any attr in penultimate_op that is not in old_op // has a different default value in new_op. In general it is not safe // to change the default for an attr that has been added to an op. -Status OpDefAddedDefaultsUnchanged(const OpDef& old_op, - const OpDef& penultimate_op, - const OpDef& new_op); +absl::Status OpDefAddedDefaultsUnchanged(const OpDef& old_op, + const OpDef& penultimate_op, + const OpDef& new_op); // Returns an error if the default value for any attr is removed or modified // in new_op compared to old_op. Adding new default values is safe, and does // not raise an error. -Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op); +absl::Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, + const OpDef& new_op); // Remove all docs from *op_def / *op_list. void RemoveDescriptionsFromOpDef(OpDef* op_def); diff --git a/tensorflow/core/framework/op_def_util_test.cc b/tensorflow/core/framework/op_def_util_test.cc index 90a3f6720e14e4..faf958b00e45e2 100644 --- a/tensorflow/core/framework/op_def_util_test.cc +++ b/tensorflow/core/framework/op_def_util_test.cc @@ -41,11 +41,13 @@ OpDef::AttrDef ADef(const string& text) { class ValidateOpDefTest : public ::testing::Test { protected: - Status TestProto(const string& text) { return ValidateOpDef(FromText(text)); } + absl::Status TestProto(const string& text) { + return ValidateOpDef(FromText(text)); + } - Status TestBuilder(const OpDefBuilder& builder) { + absl::Status TestBuilder(const OpDefBuilder& builder) { OpRegistrationData op_reg_data; - Status status = builder.Finalize(&op_reg_data); + absl::Status status = builder.Finalize(&op_reg_data); TF_EXPECT_OK(status); if (!status.ok()) { return status; @@ -56,7 +58,7 @@ class ValidateOpDefTest : public ::testing::Test { }; namespace { -void ExpectFailure(const Status& status, const string& message) { +void ExpectFailure(const absl::Status& status, const string& message) { EXPECT_FALSE(status.ok()) << "Did not see error with: " << message; if (!status.ok()) { LOG(INFO) << "message: " << status; @@ -540,12 +542,12 @@ TEST(OpDefAttrDefaultsUnchangedTest, Foo) { TF_EXPECT_OK(OpDefAttrDefaultsUnchanged(op1, op2)); // Changing a default value: not ok. - Status changed_attr = OpDefAttrDefaultsUnchanged(op2, op3); + absl::Status changed_attr = OpDefAttrDefaultsUnchanged(op2, op3); ExpectFailure(changed_attr, "Attr 'n' has changed it's default value; from \"x\" to \"y\""); // Removing a default value: not ok. - Status removed_attr = OpDefAttrDefaultsUnchanged(op2, op1); + absl::Status removed_attr = OpDefAttrDefaultsUnchanged(op2, op1); ExpectFailure(removed_attr, "Attr 'n' has removed it's default; from \"x\" to no default"); } diff --git a/tensorflow/core/framework/op_gen_lib.h b/tensorflow/core/framework/op_gen_lib.h index c269e2df04973c..1db41eb401117f 100644 --- a/tensorflow/core/framework/op_gen_lib.h +++ b/tensorflow/core/framework/op_gen_lib.h @@ -62,20 +62,20 @@ class ApiDefMap { // definitions take precedence. // ApiDefs loaded from files must contain a subset of ops defined // in the OpList passed to the constructor. - Status LoadFileList(Env* env, const std::vector& filenames); + absl::Status LoadFileList(Env* env, const std::vector& filenames); // Load a single file. Api definitions are merged if the same // op definition is loaded multiple times. Later-loaded // definitions take precedence. // ApiDefs loaded from file must contain a subset of ops defined // in the OpList passed to the constructor. - Status LoadFile(Env* env, const string& filename); + absl::Status LoadFile(Env* env, const string& filename); // Load ApiDefs from string containing ApiDefs text proto. // api_def_file_contents is expected to be in "multiline format". // ApiDefs must contain a subset of ops defined in OpsList // passed to the constructor. - Status LoadApiDef(const string& api_def_file_contents); + absl::Status LoadApiDef(const string& api_def_file_contents); // Updates ApiDef docs. For example, if ApiDef renames an argument // or attribute, applies these renames to descriptions as well. diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index cd9c83bebc626f..5f72933a9b590b 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -1768,9 +1768,8 @@ Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry) { const OpRegistrationData* op_reg_data; const Status status = op_registry.LookUp(kernel_def.op(), &op_reg_data); if (!status.ok()) { - // TODO(josh11b): Make this a hard error. - LOG(ERROR) << "OpKernel ('" << kernel_def.ShortDebugString() - << "') for unknown op: " << kernel_def.op(); + LOG(WARNING) << "OpKernel ('" << kernel_def.ShortDebugString() + << "') for unknown op: " << kernel_def.op(); continue; } const OpDef& op_def = op_reg_data->op_def; diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index ff067cd9b61412..264b66471291ad 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -194,8 +194,9 @@ class OpKernel { return output_memory_types_; } - Status InputRange(StringPiece input_name, int* start, int* stop) const; - Status OutputRange(StringPiece output_name, int* start, int* stop) const; + absl::Status InputRange(StringPiece input_name, int* start, int* stop) const; + absl::Status OutputRange(StringPiece output_name, int* start, + int* stop) const; // Returns `true` if and only if this kernel uses deferred execution. bool is_deferred() const { return is_deferred_; } @@ -260,7 +261,7 @@ class OpKernelConstruction { const std::shared_ptr& props, const MemoryTypeSlice& input_memory_types, const MemoryTypeSlice& output_memory_types, - int graph_def_version, Status* status); + int graph_def_version, absl::Status* status); Env* env() const { return device_->env(); } @@ -276,10 +277,11 @@ class OpKernelConstruction { // Allocates a temporary Tensor of the specified type and shape. The // Tensor must not be used after kernel construction is // complete. See comment above. - Status allocate_temp(DataType type, const TensorShape& shape, - Tensor* out_temp); - Status allocate_temp(DataType type, const TensorShape& shape, - Tensor* out_temp, AllocatorAttributes allocator_attr); + absl::Status allocate_temp(DataType type, const TensorShape& shape, + Tensor* out_temp); + absl::Status allocate_temp(DataType type, const TensorShape& shape, + Tensor* out_temp, + AllocatorAttributes allocator_attr); // User-supplied configuration of this operation. const NodeDef& def() const { return props_->node_def; } @@ -305,18 +307,19 @@ class OpKernelConstruction { // If expected_inputs == inputs() and expected_outputs == output_types(), // returns OK, else returns INVALID_ARGUMENT with an error message. // Recommended for Ops with dynamic signatures. - Status MatchSignature(const DataTypeSlice expected_inputs, - const DataTypeSlice expected_outputs); + absl::Status MatchSignature(const DataTypeSlice expected_inputs, + const DataTypeSlice expected_outputs); // For recording configuration errors during construction. - void SetStatus(const Status& status); - const Status& status() const { return *status_; } + void SetStatus(const absl::Status& status); + const absl::Status& status() const { return *status_; } // Look up the attr with name attr_name and set *value to its value. If no // attr with attr_name is found in def(), or the attr does not have // a matching type, a non-ok status will be returned. template - Status GetAttr(StringPiece attr_name, T* value) const TF_ATTRIBUTE_NOINLINE; + absl::Status GetAttr(StringPiece attr_name, + T* value) const TF_ATTRIBUTE_NOINLINE; // Return true if the attr_name is defined in def(). bool HasAttr(StringPiece attr_name) const; @@ -336,10 +339,10 @@ class OpKernelConstruction { int graph_def_version() const { return graph_def_version_; } // Helper routines for the OP_REQUIRES macros - void CtxFailure(const Status& s); - void CtxFailureWithWarning(const Status& s); - void CtxFailure(const char* file, int line, const Status& s); - void CtxFailureWithWarning(const char* file, int line, const Status& s); + void CtxFailure(const absl::Status& s); + void CtxFailureWithWarning(const absl::Status& s); + void CtxFailure(const char* file, int line, const absl::Status& s); + void CtxFailureWithWarning(const char* file, int line, const absl::Status& s); // Unrecommended functions: these are functions that have some // current uses but are not recommended for use, and may go away at @@ -363,7 +366,7 @@ class OpKernelConstruction { MemoryTypeSlice input_memory_types_; MemoryTypeSlice output_memory_types_; const int graph_def_version_; - Status* status_; + absl::Status* status_; // Allow access from OpKernel ctor. friend class OpKernel; @@ -473,7 +476,7 @@ class OpOutputList { Tensor* operator[](int i); bool required(int i) const; DataType expected_output_dtype(int i) const; - Status allocate(int i, const TensorShape& shape, Tensor** output); + absl::Status allocate(int i, const TensorShape& shape, Tensor** output); void set(int i, const Tensor& tensor); void set(int i, Tensor&& tensor); void set_ref(int i, mutex* mu, Tensor* tensor_for_ref); @@ -730,7 +733,7 @@ class OpKernelContext { int num_inputs() const { return params_->inputs.size(); } DataType input_dtype(int index) const; - Status input_dtype(StringPiece name, DataType* dtype) const; + absl::Status input_dtype(StringPiece name, DataType* dtype) const; MemoryType input_memory_type(int index) const; int num_outputs() const { return outputs_.size(); } @@ -755,14 +758,14 @@ class OpKernelContext { // use mutable_input below. // REQUIRES: !IsRefType(input_dtype(index)) // REQUIRES: the named input must not be a list. - Status input(StringPiece name, const Tensor** tensor); + absl::Status input(StringPiece name, const Tensor** tensor); // Returns the named list-valued immutable input in "list", as // defined in the OpDef. If the named output is not list-valued, // returns a one-element list. May only be used for non-Ref // inputs. For Ref inputs use mutable_input below. // REQUIRES: !IsRefType(input_dtype(index)) - Status input_list(StringPiece name, OpInputList* list); + absl::Status input_list(StringPiece name, OpInputList* list); // For mutable inputs, use the following together to make sure there // is no concurrent access to mutable_input(), e.g.: @@ -772,7 +775,7 @@ class OpKernelContext { // // modify the values in t // } // REQUIRES: IsRefType(input_dtype(index)) - Status input_ref_mutex(StringPiece name, mutex** out_mutex); + absl::Status input_ref_mutex(StringPiece name, mutex** out_mutex); // Returns a mutable input tensor. Must be used to access Ref // inputs. REQUIRES: IsRefType(input_dtype(index)). The caller may @@ -790,7 +793,7 @@ class OpKernelContext { // the input mutex will be acquired before returning the Tensor. // REQUIRES: the named input must not be a list. // REQUIRES: the named input must be a ref tensor. - Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held); + absl::Status mutable_input(StringPiece name, Tensor* tensor, bool lock_held); // Returns the named list-valued mutable input in "list", as defined // in the OpDef. If the named input is not list-valued, returns a @@ -798,7 +801,7 @@ class OpKernelContext { // stored in the Tensor buffer may be modified, and modifications // will be visible to other Ops reading the same ref tensor. // REQUIRES: the named input must be a ref tensor. - Status mutable_input_list(StringPiece name, OpMutableInputList* list); + absl::Status mutable_input_list(StringPiece name, OpMutableInputList* list); // Replace the corresponding Ref Input to use the storage buffer // used by tensor. If !lock_held the input mutex will be acquired @@ -810,8 +813,8 @@ class OpKernelContext { // buffer used by tensor. If !lock_held the input mutex will be // acquired before returning the Tensor. // REQUIRES: IsRefType(input_dtype(index)). - Status replace_ref_input(StringPiece name, const Tensor& tensor, - bool lock_held); + absl::Status replace_ref_input(StringPiece name, const Tensor& tensor, + bool lock_held); // Deletes the Tensor object used as the Ref Input at // input_index. This is not usually necessary and should be used @@ -861,10 +864,9 @@ class OpKernelContext { bool forward_input_to_output_with_shape(int input_index, int output_index, const TensorShape& output_shape, Tensor** output) TF_MUST_USE_RESULT; - Status forward_input_to_output_with_shape(StringPiece input_name, - StringPiece output_name, - const TensorShape& output_shape, - Tensor** output) TF_MUST_USE_RESULT; + absl::Status forward_input_to_output_with_shape( + StringPiece input_name, StringPiece output_name, + const TensorShape& output_shape, Tensor** output) TF_MUST_USE_RESULT; // Returns a pointer to a Tensor aliasing the underlying buffer backing // input[input_index] iff @@ -905,11 +907,11 @@ class OpKernelContext { // forwarded input will be assign to output argument forwarded_input (if it's // not nullptr). If no inputs are forwarded, forwarded_input will be assigned // -1. - Status forward_input_or_allocate_output( + absl::Status forward_input_or_allocate_output( absl::Span candidate_input_indices, int output_index, const TensorShape& output_shape, Tensor** output, int* forwarded_input = nullptr) TF_MUST_USE_RESULT; - Status forward_input_or_allocate_output( + absl::Status forward_input_or_allocate_output( absl::Span candidate_input_names, StringPiece output_name, const TensorShape& output_shape, Tensor** output) TF_MUST_USE_RESULT; @@ -917,12 +919,12 @@ class OpKernelContext { // Tries to reuse one of the inputs given in input_indices as a temporary. // If none of the given inputs can be forwarded, calls // allocate_temp() to allocate a new temporary buffer. - Status forward_input_or_allocate_temp( + absl::Status forward_input_or_allocate_temp( absl::Span candidate_input_indices, DataType type, const TensorShape& shape, const AllocatorAttributes& allocator_attr, Tensor* out_temp) TF_MUST_USE_RESULT; - Status forward_input_or_allocate_temp( + absl::Status forward_input_or_allocate_temp( absl::Span candidate_input_indices, DataType type, const TensorShape& shape, Tensor* out_temp) TF_MUST_USE_RESULT { return forward_input_or_allocate_temp(candidate_input_indices, type, shape, @@ -933,7 +935,7 @@ class OpKernelContext { // Returns the named list-valued output in "list", as defined in the OpDef. // If the named output is not list-valued, returns a one-element list. - Status output_list(StringPiece name, OpOutputList* list); + absl::Status output_list(StringPiece name, OpOutputList* list); // If output_required(index) returns true, the OpKernel's Compute() method // should call allocate_output(index, ...), set_output(index, ...), @@ -993,49 +995,53 @@ class OpKernelContext { // If memory allocation fails, returns an error status. // // REQUIRES: !IsRefType(expected_output_dtype(index)) - Status allocate_output(int index, const TensorShape& shape, - Tensor** tensor) TF_MUST_USE_RESULT; - Status allocate_output(StringPiece name, const TensorShape& shape, - Tensor** tensor) TF_MUST_USE_RESULT; + absl::Status allocate_output(int index, const TensorShape& shape, + Tensor** tensor) TF_MUST_USE_RESULT; + absl::Status allocate_output(StringPiece name, const TensorShape& shape, + Tensor** tensor) TF_MUST_USE_RESULT; // The following methods use the supplied attributes instead of // those in output_attr_array. The caller is responsible for // ensuring that the attributes are "compatible" with the // output_attr_array, e.g. the tensor is allocated on the correct // device. See comment above. - Status allocate_output(int index, const TensorShape& shape, Tensor** tensor, - AllocatorAttributes attr) TF_MUST_USE_RESULT; - Status allocate_output(StringPiece name, const TensorShape& shape, - Tensor** tensor, - AllocatorAttributes attr) TF_MUST_USE_RESULT; + absl::Status allocate_output(int index, const TensorShape& shape, + Tensor** tensor, + AllocatorAttributes attr) TF_MUST_USE_RESULT; + absl::Status allocate_output(StringPiece name, const TensorShape& shape, + Tensor** tensor, + AllocatorAttributes attr) TF_MUST_USE_RESULT; // Allocates a temporary Tensor of the specified type and // shape. Devices such as GPUs that enqueue Ops for lazy execution // may retain references to the temporary tensors after the Op's // Compute method has run. See comment above. - Status allocate_temp(DataType type, const TensorShape& shape, - Tensor* out_temp, AllocatorAttributes allocator_attr, - const AllocationAttributes& allocation_attr); - Status allocate_temp(DataType type, const TensorShape& shape, - Tensor* out_temp, AllocatorAttributes allocator_attr); - Status allocate_temp(DataType type, const TensorShape& shape, - Tensor* out_temp); + absl::Status allocate_temp(DataType type, const TensorShape& shape, + Tensor* out_temp, + AllocatorAttributes allocator_attr, + const AllocationAttributes& allocation_attr); + absl::Status allocate_temp(DataType type, const TensorShape& shape, + Tensor* out_temp, + AllocatorAttributes allocator_attr); + absl::Status allocate_temp(DataType type, const TensorShape& shape, + Tensor* out_temp); // Copies a tensor (allocated by the caller) to the specified output // index. REQUIRES: !IsRefType(expected_output_dtype(index)) // REQUIRES: 'tensor' must have the same MemoryType as // output_memory_types[index]. See comment above. - Status set_output(StringPiece name, const Tensor& tensor); - Status set_output(StringPiece name, Tensor&& tensor); + absl::Status set_output(StringPiece name, const Tensor& tensor); + absl::Status set_output(StringPiece name, Tensor&& tensor); void set_output(int index, const Tensor& tensor); void set_output(int index, Tensor&& tensor); // To output a reference. Caller retains ownership of mu and tensor_for_ref, // and they must outlive all uses within the step. See comment above. // REQUIRES: IsRefType(expected_output_dtype(index)) - Status set_output_ref(StringPiece name, mutex* mu, Tensor* tensor_for_ref); + absl::Status set_output_ref(StringPiece name, mutex* mu, + Tensor* tensor_for_ref); // Returns nullptr if allocate_output() or set_output() have not been called. - Status mutable_output(StringPiece name, Tensor** tensor); + absl::Status mutable_output(StringPiece name, Tensor** tensor); // Return the DeviceContext that should be used for this Op. // @@ -1146,13 +1152,13 @@ class OpKernelContext { // returns OK, else returns INVALID_ARGUMENT with an error message. // Recommended for Ops with dynamic signatures, where validation can only // be performed at runtime. - Status MatchSignature(const DataTypeSlice expected_inputs, - const DataTypeSlice expected_outputs); + absl::Status MatchSignature(const DataTypeSlice expected_inputs, + const DataTypeSlice expected_outputs); // An OpKernel should call SetStatus() if Compute() encounters an // error. - void SetStatus(const Status& status); - const Status& status() const { return status_; } + void SetStatus(const absl::Status& status); + const absl::Status& status() const { return status_; } // Cancellation. // @@ -1183,10 +1189,10 @@ class OpKernelContext { } // Helper routines for the OP_REQUIRES macros - void CtxFailure(const Status& s); - void CtxFailureWithWarning(const Status& s); - void CtxFailure(const char* file, int line, const Status& s); - void CtxFailureWithWarning(const char* file, int line, const Status& s); + void CtxFailure(const absl::Status& s); + void CtxFailureWithWarning(const absl::Status& s); + void CtxFailure(const char* file, int line, const absl::Status& s); + void CtxFailureWithWarning(const char* file, int line, const absl::Status& s); // Unrecommended functions: these are functions that have some // current uses but are not recommended for use, and may go away at @@ -1271,16 +1277,17 @@ class OpKernelContext { bool record_memory_consumption_ = false; // Internal common method used when allocating tensor memory - Status allocate_tensor(DataType type, const TensorShape& shape, - Tensor* out_tensor, - AllocatorAttributes allocator_attr) { + absl::Status allocate_tensor(DataType type, const TensorShape& shape, + Tensor* out_tensor, + AllocatorAttributes allocator_attr) { return allocate_tensor(type, shape, out_tensor, allocator_attr, AllocationAttributes()); } - Status allocate_tensor(DataType type, const TensorShape& shape, - Tensor* out_tensor, AllocatorAttributes allocator_attr, - const AllocationAttributes& allocation_attr); + absl::Status allocate_tensor(DataType type, const TensorShape& shape, + Tensor* out_tensor, + AllocatorAttributes allocator_attr, + const AllocationAttributes& allocation_attr); // Helpers for `set_output()`. @@ -1289,14 +1296,14 @@ class OpKernelContext { void maybe_track_allocations_for_set_output(const Tensor& tensor); - Status get_input_index(StringPiece name, int* out_index) const; - Status get_output_index(StringPiece name, int* out_index) const; + absl::Status get_input_index(StringPiece name, int* out_index) const; + absl::Status get_output_index(StringPiece name, int* out_index) const; // Initialize the allocated_scope_ids_ set the first time this method is // called. void maybe_initialize_scope_id_set(); - Status status_; + absl::Status status_; friend class CollectiveExecutor; // for access to params_ Params* params_; // not owned absl::InlinedVector outputs_; @@ -1380,34 +1387,32 @@ const Eigen::GpuDevice& OpKernelContext::eigen_device() const; // of the returned pointer. // EXPECTED USAGE: unique_ptr op = CreateOpKernel(...); // REQUIRES: def has all attrs specified (e.g. using AddDefaultsToNodeDef()). -std::unique_ptr CreateOpKernel(DeviceType device_type, - DeviceBase* device, - Allocator* allocator, - const NodeDef& node_def, - int graph_def_version, Status* status); +std::unique_ptr CreateOpKernel( + DeviceType device_type, DeviceBase* device, Allocator* allocator, + const NodeDef& node_def, int graph_def_version, absl::Status* status); std::unique_ptr CreateOpKernel( DeviceType device_type, DeviceBase* device, Allocator* allocator, const std::shared_ptr& props, int graph_def_version, - Status* status); + absl::Status* status); -Status CreateOpKernel(DeviceType device_type, DeviceBase* device, - Allocator* allocator, FunctionLibraryRuntime* flib, - const std::shared_ptr& props, - int graph_def_version, OpKernel** kernel); +absl::Status CreateOpKernel(DeviceType device_type, DeviceBase* device, + Allocator* allocator, FunctionLibraryRuntime* flib, + const std::shared_ptr& props, + int graph_def_version, OpKernel** kernel); -Status CreateOpKernel(DeviceType device_type, DeviceBase* device, - Allocator* allocator, FunctionLibraryRuntime* flib, - ResourceMgr* resource_mgr, - const std::shared_ptr& props, - int graph_def_version, OpKernel** kernel); +absl::Status CreateOpKernel(DeviceType device_type, DeviceBase* device, + Allocator* allocator, FunctionLibraryRuntime* flib, + ResourceMgr* resource_mgr, + const std::shared_ptr& props, + int graph_def_version, OpKernel** kernel); // Returns into 'device_types' the subset of prioritized_types that this // binary has registered for the given NodeDef. // // REQUIRES: * 'device_types' is not nullptr. // * def has all attrs specified (e.g. using AddDefaultsToNodeDef()). -Status SupportedDeviceTypesForNode( +absl::Status SupportedDeviceTypesForNode( const std::vector& prioritized_types, const NodeDef& def, PrioritizedDeviceTypeVector* device_types, const DeviceNameUtils::ParsedName* local_address_spec = nullptr); @@ -1417,7 +1422,8 @@ Status SupportedDeviceTypesForNode( std::string KernelsRegisteredForOp(StringPiece op_name); // Call once after Op registration has completed. -Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry); +absl::Status ValidateKernelRegistrations( + const OpRegistryInterface& op_registry); // ----------------------------------------------------------------------------- // OpKernel registration implementation follows, please ignore. @@ -1504,7 +1510,7 @@ bool KernelDefAvailable(const DeviceType& device_type, const NodeDef& node_def); // node_attrs has a corresponding kernel registered on device_type, returns OK // and fill in the kernel def and kernel_class_name. and // may be null. -Status FindKernelDef( +absl::Status FindKernelDef( const DeviceType& device_type, StringPiece node_name, bool has_experimental_debug_info, const NodeDef_ExperimentalDebugInfo& experimental_debug_info, @@ -1514,8 +1520,9 @@ Status FindKernelDef( // If node_def has a corresponding kernel registered on device_type, // returns OK and fill in the kernel def and kernel_class_name. and // may be null. -Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def, - const KernelDef** def, std::string* kernel_class_name); +absl::Status FindKernelDef(const DeviceType& device_type, + const NodeDef& node_def, const KernelDef** def, + std::string* kernel_class_name); // Writes a list of all registered kernels to LOG(INFO), to help users debug // missing kernel errors. @@ -1582,7 +1589,8 @@ class OpKernelRegistrar { // Template and inline method implementations, please ignore template -Status OpKernelConstruction::GetAttr(StringPiece attr_name, T* value) const { +absl::Status OpKernelConstruction::GetAttr(StringPiece attr_name, + T* value) const { return GetNodeAttr(def(), attr_name, value); } @@ -1687,8 +1695,8 @@ inline DataType OpOutputList::expected_output_dtype(int i) const { return ctx_->expected_output_dtype(start_ + i); } -inline Status OpOutputList::allocate(int i, const TensorShape& shape, - Tensor** output) { +inline absl::Status OpOutputList::allocate(int i, const TensorShape& shape, + Tensor** output) { DCHECK_GE(i, 0); DCHECK_LT(i, stop_ - start_); return ctx_->allocate_output(start_ + i, shape, output); diff --git a/tensorflow/core/framework/op_kernel_test_base.h b/tensorflow/core/framework/op_kernel_test_base.h index 3227005326ff07..7b3951e56411be 100644 --- a/tensorflow/core/framework/op_kernel_test_base.h +++ b/tensorflow/core/framework/op_kernel_test_base.h @@ -74,7 +74,7 @@ class OpKernelBuilderTest : public ::testing::Test { const DeviceType& device_type, const std::vector& attrs, DataTypeSlice input_types = {}) { - Status status; + absl::Status status; NodeDef def = CreateNodeDef(op_type, attrs); for (size_t i = 0; i < input_types.size(); ++i) { def.add_input("a:0"); @@ -112,7 +112,7 @@ class OpKernelBuilderTest : public ::testing::Test { void ExpectFailure(const string& op_type, const DeviceType& device_type, const std::vector& attrs, error::Code code) { - Status status; + absl::Status status; const NodeDef def = CreateNodeDef(op_type, attrs); Env* env = Env::Default(); DeviceBase device(env); @@ -135,7 +135,7 @@ class OpKernelBuilderTest : public ::testing::Test { EXPECT_NE(dt.first, device_type); } } else { - Status status2 = + absl::Status status2 = SupportedDeviceTypesForNode(DeviceTypes(), def, &devices); EXPECT_EQ(status.code(), status2.code()); } @@ -153,7 +153,7 @@ class OpKernelBuilderTest : public ::testing::Test { const KernelDef* kernel_def = nullptr; string kernel_class_name; - const Status status = + const absl::Status status = FindKernelDef(device_type, def, &kernel_def, &kernel_class_name); if (status.ok()) { return kernel_class_name; diff --git a/tensorflow/core/framework/op_segment.h b/tensorflow/core/framework/op_segment.h index 8cbca548cf02df..10c4fa467e3228 100644 --- a/tensorflow/core/framework/op_segment.h +++ b/tensorflow/core/framework/op_segment.h @@ -56,10 +56,10 @@ class OpSegment { // error. // // OpSegment keeps the ownership of the returned "*kernel". - typedef std::function CreateKernelFn; - Status FindOrCreate(const std::string& session_handle, - const std::string& node_name, OpKernel** kernel, - CreateKernelFn create_fn); + typedef std::function CreateKernelFn; + absl::Status FindOrCreate(const std::string& session_handle, + const std::string& node_name, OpKernel** kernel, + CreateKernelFn create_fn); // Returns true if OpSegment should own the kernel. static bool ShouldOwnKernel(FunctionLibraryRuntime* lib, diff --git a/tensorflow/core/framework/op_segment_test.cc b/tensorflow/core/framework/op_segment_test.cc index af16e9f7ef96c4..de37f1921c1c1a 100644 --- a/tensorflow/core/framework/op_segment_test.cc +++ b/tensorflow/core/framework/op_segment_test.cc @@ -63,7 +63,7 @@ class OpSegmentTest : public ::testing::Test { OpSegment::CreateKernelFn GetFn(const NodeDef* ndef) { return [this, ndef](OpKernel** kernel) { - Status s; + absl::Status s; auto created = CreateOpKernel(DEVICE_CPU, &device_, cpu_allocator(), *ndef, TF_GRAPH_DEF_VERSION, &s); if (s.ok()) { @@ -115,7 +115,7 @@ TEST_F(OpSegmentTest, SessionNotFound) { OpSegment opseg; OpKernel* op; NodeDef def = float_nodedefs_[0]; - Status s = opseg.FindOrCreate("A", def.name(), &op, GetFn(&def)); + absl::Status s = opseg.FindOrCreate("A", def.name(), &op, GetFn(&def)); EXPECT_TRUE(errors::IsNotFound(s)) << s; } @@ -125,7 +125,7 @@ TEST_F(OpSegmentTest, CreateFailure) { NodeDef def = float_nodedefs_[0]; def.set_op("nonexistop"); opseg.AddHold("A"); - Status s = opseg.FindOrCreate("A", def.name(), &op, GetFn(&def)); + absl::Status s = opseg.FindOrCreate("A", def.name(), &op, GetFn(&def)); EXPECT_TRUE(errors::IsNotFound(s)) << s; opseg.RemoveHold("A"); } diff --git a/tensorflow/core/framework/ops_util.h b/tensorflow/core/framework/ops_util.h index 6d6868a4ac6399..ae73a562d40348 100644 --- a/tensorflow/core/framework/ops_util.h +++ b/tensorflow/core/framework/ops_util.h @@ -34,9 +34,9 @@ namespace tensorflow { // index and size for broadcast for that dimension are different from the // current index and kernel size. // This is mainly used by gradient algorithms for pooling operations. -Status GetBroadcastSize(const int index, const int in_size, const int ksize, - const int stride, const int pad_size, int* bindex, - int* bsize); +absl::Status GetBroadcastSize(const int index, const int in_size, + const int ksize, const int stride, + const int pad_size, int* bindex, int* bsize); // Converts Brain's Padding to Eigen's PaddingType. Eigen::PaddingType BrainPadding2EigenPadding(Padding padding); diff --git a/tensorflow/core/framework/queue_interface.h b/tensorflow/core/framework/queue_interface.h index 2093dd1f45df01..e916b5064aac6a 100644 --- a/tensorflow/core/framework/queue_interface.h +++ b/tensorflow/core/framework/queue_interface.h @@ -34,8 +34,8 @@ class QueueInterface : public ResourceBase { typedef AsyncOpKernel::DoneCallback DoneCallback; typedef std::function CallbackWithTuple; - virtual Status ValidateTuple(const Tuple& tuple) = 0; - virtual Status ValidateManyTuple(const Tuple& tuple) = 0; + virtual absl::Status ValidateTuple(const Tuple& tuple) = 0; + virtual absl::Status ValidateManyTuple(const Tuple& tuple) = 0; // Stashes a function object for future execution, that will eventually // enqueue the tuple of tensors into the queue, and returns immediately. The @@ -82,7 +82,7 @@ class QueueInterface : public ResourceBase { // Assuming *this represents a shared queue, verify that it matches // another instantiation indicated by node_def. - virtual Status MatchesNodeDef(const NodeDef& node_def) = 0; + virtual absl::Status MatchesNodeDef(const NodeDef& node_def) = 0; // Returns the number of elements in the queue. virtual int32 size() const = 0; diff --git a/tensorflow/core/framework/reader_interface.h b/tensorflow/core/framework/reader_interface.h index f78bce374925e3..6210b68fe17b45 100644 --- a/tensorflow/core/framework/reader_interface.h +++ b/tensorflow/core/framework/reader_interface.h @@ -65,7 +65,7 @@ class ReaderInterface : public ResourceBase { OpKernelContext* context) = 0; // Restore this reader to its newly-constructed state. - virtual Status Reset() = 0; + virtual absl::Status Reset() = 0; // Accessors virtual int64_t NumRecordsProduced() = 0; @@ -73,9 +73,9 @@ class ReaderInterface : public ResourceBase { // -- Serialization/Restoration support -- // Not all readers will support saving and restoring state. - virtual Status SerializeState(tstring* state) = 0; + virtual absl::Status SerializeState(tstring* state) = 0; // Note: Must Reset on error. - virtual Status RestoreState(const tstring& state) = 0; + virtual absl::Status RestoreState(const tstring& state) = 0; string DebugString() const override { return "a reader"; } diff --git a/tensorflow/core/framework/rendezvous.h b/tensorflow/core/framework/rendezvous.h index c4b5b99ec51ba0..87861994226707 100644 --- a/tensorflow/core/framework/rendezvous.h +++ b/tensorflow/core/framework/rendezvous.h @@ -89,8 +89,8 @@ class RendezvousInterface { // Send/Recv on the same worker. // // Send() never blocks. - virtual Status Send(const ParsedKey& key, const Args& args, const Tensor& val, - const bool is_dead) = 0; + virtual absl::Status Send(const ParsedKey& key, const Args& args, + const Tensor& val, const bool is_dead) = 0; // Callback provided by a tensor consumer waiting on the rendezvous. // It will be invoked when the tensor is available, or when a non-OK @@ -98,7 +98,7 @@ class RendezvousInterface { // two Rendezvous::Args, one provided by the sender, the other by the // receiver, which may be needed when a non-CPU device is in use // by either side. - typedef std::function DoneCallback; @@ -106,16 +106,16 @@ class RendezvousInterface { DoneCallback done) = 0; // Synchronous wrapper for RecvAsync. - Status Recv(const ParsedKey& key, const Args& args, Tensor* val, - bool* is_dead, int64_t timeout_ms); - Status Recv(const ParsedKey& key, const Args& args, Tensor* val, - bool* is_dead); + absl::Status Recv(const ParsedKey& key, const Args& args, Tensor* val, + bool* is_dead, int64_t timeout_ms); + absl::Status Recv(const ParsedKey& key, const Args& args, Tensor* val, + bool* is_dead); // Aborts all pending and future Send/Recv with the given "status". // // StartAbort() does not wait for ongoing calls to finish. // REQUIRES: !status.ok() - virtual void StartAbort(const Status& status) = 0; + virtual void StartAbort(const absl::Status& status) = 0; virtual ~RendezvousInterface(); @@ -135,22 +135,23 @@ class Rendezvous : public RendezvousInterface, public core::WeakRefCounted { // Default to a factory that evaluates to false. Factory() : valid_(false) {} - explicit Factory(std::function*)> - create_fn) + explicit Factory( + std::function*)> + create_fn) : valid_(true), create_fn_(std::move(create_fn)) {} explicit operator bool() const { return valid_; } - Status operator()(const int64_t step_id, const DeviceMgr* device_mgr, - tsl::core::RefCountPtr* rendez) const { + absl::Status operator()(const int64_t step_id, const DeviceMgr* device_mgr, + tsl::core::RefCountPtr* rendez) const { return create_fn_(step_id, device_mgr, rendez); } private: bool valid_; - std::function*)> + std::function*)> create_fn_; }; @@ -163,7 +164,7 @@ class Rendezvous : public RendezvousInterface, public core::WeakRefCounted { const std::string& name, const FrameAndIter& frame_iter); - static Status ParseKey(StringPiece key, ParsedKey* out); + static absl::Status ParseKey(StringPiece key, ParsedKey* out); }; // Returns a Rendezvous instance that is limited to use only by diff --git a/tensorflow/core/framework/resource_base.h b/tensorflow/core/framework/resource_base.h index 46a76ea094ce17..c22adb559f127c 100644 --- a/tensorflow/core/framework/resource_base.h +++ b/tensorflow/core/framework/resource_base.h @@ -52,7 +52,7 @@ class ResourceBase : public core::WeakRefCounted { // should not be tied to the graph that created it, since the graph may be // destroyed before the resource is used. To avoid this lifetime issue, you // can usually set a unique `shared_name` attribute for the resource. - virtual Status AsGraphDef(GraphDefBuilder* builder, Node** out) const { + virtual absl::Status AsGraphDef(GraphDefBuilder* builder, Node** out) const { return errors::Unimplemented("AsGraphDef not implemented for resource ", DebugString()); } diff --git a/tensorflow/core/framework/resource_handle.h b/tensorflow/core/framework/resource_handle.h index 93c62f44b8d36c..393a899862d0d4 100644 --- a/tensorflow/core/framework/resource_handle.h +++ b/tensorflow/core/framework/resource_handle.h @@ -49,8 +49,8 @@ class ResourceHandle { // Use this factory method if the `proto` comes from user controlled input, to // prevent a denial of service. - static Status BuildResourceHandle(const ResourceHandleProto& proto, - ResourceHandle* out); + static absl::Status BuildResourceHandle(const ResourceHandleProto& proto, + ResourceHandle* out); // Unique name for the device containing the resource. const std::string& device() const { return device_; } @@ -97,7 +97,7 @@ class ResourceHandle { // Conversion to and from ResourceHandleProto void AsProto(ResourceHandleProto* proto) const; - Status FromProto(const ResourceHandleProto& proto); + absl::Status FromProto(const ResourceHandleProto& proto); // Serialization via ResourceHandleProto std::string SerializeAsString() const; @@ -165,11 +165,11 @@ class ResourceHandle { // Validates that the resource type in `handle` is `T`. template - Status ValidateType() const { + absl::Status ValidateType() const { return ValidateType(TypeIndex::Make()); } - Status ValidateType(const TypeIndex& type_index) const; + absl::Status ValidateType(const TypeIndex& type_index) const; // Generates unique IDs (e.g. for names of anonymous variables) static int64_t GenerateUniqueId(); diff --git a/tensorflow/core/framework/resource_op_kernel.h b/tensorflow/core/framework/resource_op_kernel.h index 63ccd938314925..9982c02fc36a35 100644 --- a/tensorflow/core/framework/resource_op_kernel.h +++ b/tensorflow/core/framework/resource_op_kernel.h @@ -80,7 +80,7 @@ class ResourceOpKernel : public OpKernel { mgr->LookupOrCreate( cinfo_.container(), cinfo_.name(), &resource, [this](T** ret) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - Status s = CreateResource(ret); + absl::Status s = CreateResource(ret); if (!s.ok() && *ret != nullptr) { CHECK((*ret)->Unref()); } @@ -133,7 +133,7 @@ class ResourceOpKernel : public OpKernel { // Must return a T descendant allocated with new that ResourceOpKernel will // take ownership of. - virtual Status CreateResource(T** resource) + virtual absl::Status CreateResource(T** resource) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0; // During the first Compute(), resource is either created or looked up using @@ -141,7 +141,7 @@ class ResourceOpKernel : public OpKernel { // it is compatible with this op's configuration. The verification may fail in // cases such as two graphs asking queues of the same shared name to have // inconsistent capacities. - virtual Status VerifyResource(T* resource) { return absl::OkStatus(); } + virtual absl::Status VerifyResource(T* resource) { return absl::OkStatus(); } Tensor tensor_ TF_GUARDED_BY(mu_); diff --git a/tensorflow/core/framework/resource_op_kernel_test.cc b/tensorflow/core/framework/resource_op_kernel_test.cc index 9b6b7ea30eb824..5f0b01a174ccd9 100644 --- a/tensorflow/core/framework/resource_op_kernel_test.cc +++ b/tensorflow/core/framework/resource_op_kernel_test.cc @@ -57,12 +57,12 @@ class StubResourceOpKernel : public ResourceOpKernel { using ResourceOpKernel::ResourceOpKernel; private: - Status CreateResource(StubResource** resource) override { + absl::Status CreateResource(StubResource** resource) override { *resource = CHECK_NOTNULL(new StubResource); return GetNodeAttr(def(), "code", &(*resource)->code); } - Status VerifyResource(StubResource* resource) override { + absl::Status VerifyResource(StubResource* resource) override { int code; TF_RETURN_IF_ERROR(GetNodeAttr(def(), "code", &code)); if (code != resource->code) { @@ -93,7 +93,7 @@ class ResourceOpKernelTest : public ::testing::Test { .Attr("code", code) .Attr("shared_name", shared_name) .Finalize(&node_def)); - Status status; + absl::Status status; std::unique_ptr op(CreateOpKernel( DEVICE_CPU, &device_, device_.GetAllocator(AllocatorAttributes()), node_def, TF_GRAPH_DEF_VERSION, &status)); @@ -110,7 +110,7 @@ class ResourceOpKernelTest : public ::testing::Test { return resource_op; } - Status RunOpKernel(OpKernel* op) { + absl::Status RunOpKernel(OpKernel* op) { OpKernelContext::Params params; params.device = &device_; @@ -148,7 +148,7 @@ TEST_F(ResourceOpKernelTest, PrivateResource) { // Destroy the op kernel. Expect the resource to be released. op = nullptr; - Status s = + absl::Status s = mgr_.Lookup(mgr_.default_container(), key, &resource); EXPECT_FALSE(s.ok()); diff --git a/tensorflow/core/framework/resource_var.cc b/tensorflow/core/framework/resource_var.cc index c441a8e2648591..9cd77215af9549 100644 --- a/tensorflow/core/framework/resource_var.cc +++ b/tensorflow/core/framework/resource_var.cc @@ -20,7 +20,7 @@ limitations under the License. namespace tensorflow { -Status Var::AsGraphDef(GraphDefBuilder* builder, Node** out) const { +absl::Status Var::AsGraphDef(GraphDefBuilder* builder, Node** out) const { // Set a shared_name so that the created resource can outlive the graph that // created it. Node* var = ops::SourceOp( diff --git a/tensorflow/core/framework/resource_var.h b/tensorflow/core/framework/resource_var.h index d74435d33f5406..6c0a8d962f022a 100644 --- a/tensorflow/core/framework/resource_var.h +++ b/tensorflow/core/framework/resource_var.h @@ -86,7 +86,7 @@ class Var : public ResourceBase { is_initialized = false; } - Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override; + absl::Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override; std::string DebugString() const override { return strings::StrCat(DataTypeString(tensor_.dtype()), "/", diff --git a/tensorflow/core/framework/run_handler.cc b/tensorflow/core/framework/run_handler.cc index e68025643817d8..df0e45bee7bba5 100644 --- a/tensorflow/core/framework/run_handler.cc +++ b/tensorflow/core/framework/run_handler.cc @@ -837,11 +837,11 @@ class RunHandlerPool::Impl { thread_local std::unique_ptr< Eigen::MaxSizeVector> thread_work_sources = - std::unique_ptr>( - new Eigen::MaxSizeVector( - static_cast(ParamFromEnvWithDefault( - "TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS", - kMaxConcurrentHandlers)))); + std::make_unique>( + + static_cast(ParamFromEnvWithDefault( + "TF_RUN_HANDLER_MAX_CONCURRENT_HANDLERS", + kMaxConcurrentHandlers))); uint64 version; int num_active_requests; RunHandler::Impl* handler_impl; @@ -1059,7 +1059,7 @@ void RunHandler::Impl::ThreadPoolInterfaceWrapper::Schedule( RunHandler::Impl::Impl(RunHandlerPool::Impl* pool_impl) : pool_impl_(pool_impl) { - thread_pool_interface_.reset(new ThreadPoolInterfaceWrapper(this)); + thread_pool_interface_ = std::make_unique(this); Reset(0, RunOptions::Experimental::RunHandlerPoolOptions()); } diff --git a/tensorflow/core/framework/run_handler_test.cc b/tensorflow/core/framework/run_handler_test.cc index 582ed1c194c8fb..b6560dc45c73b9 100644 --- a/tensorflow/core/framework/run_handler_test.cc +++ b/tensorflow/core/framework/run_handler_test.cc @@ -662,8 +662,8 @@ TEST_F(RunHandlerTest, UseRunHandlerPoolEnableSubPool) { RunOptions run_options; run_options.mutable_experimental()->set_use_run_handler_pool(true); - Status s = session->Run(run_options, inputs, output_names, target_nodes, - &outputs, nullptr); + absl::Status s = session->Run(run_options, inputs, output_names, target_nodes, + &outputs, nullptr); EXPECT_EQ(absl::OkStatus(), s); ASSERT_EQ(1, outputs.size()); @@ -693,8 +693,8 @@ TEST_F(RunHandlerTest, TestConcurrencyUseRunHandlerPool) { std::vector> inputs; std::vector outputs; // Run the graph - Status s = session->Run(run_options, inputs, output_names, {}, &outputs, - nullptr); + absl::Status s = session->Run(run_options, inputs, output_names, {}, + &outputs, nullptr); EXPECT_EQ(absl::OkStatus(), s); ASSERT_EQ(1, outputs.size()); auto mat = outputs[0].matrix(); @@ -729,8 +729,8 @@ TEST_F(RunHandlerTest, UseRunHandlerPoolEnableSubPoolWithPriority) { ->mutable_run_handler_pool_options() ->set_priority(1); - Status s = session->Run(run_options, inputs, output_names, target_nodes, - &outputs, nullptr); + absl::Status s = session->Run(run_options, inputs, output_names, target_nodes, + &outputs, nullptr); EXPECT_EQ(absl::OkStatus(), s); ASSERT_EQ(1, outputs.size()); @@ -762,8 +762,8 @@ TEST_F(RunHandlerTest, TestConcurrencyUseRunHandlerPoolWithPriority) { std::vector> inputs; std::vector outputs; // Run the graph - Status s = session->Run(run_options, inputs, output_names, {}, &outputs, - nullptr); + absl::Status s = session->Run(run_options, inputs, output_names, {}, + &outputs, nullptr); EXPECT_EQ(absl::OkStatus(), s); ASSERT_EQ(1, outputs.size()); auto mat = outputs[0].matrix(); diff --git a/tensorflow/core/framework/session_state.h b/tensorflow/core/framework/session_state.h index 9d49951870d4f0..d102e153c0001f 100644 --- a/tensorflow/core/framework/session_state.h +++ b/tensorflow/core/framework/session_state.h @@ -31,13 +31,13 @@ namespace tensorflow { class SessionState { public: // Get a tensor from the session state. - Status GetTensor(const std::string& handle, Tensor* tensor); + absl::Status GetTensor(const std::string& handle, Tensor* tensor); // Store a tensor in the session state. - Status AddTensor(const std::string& handle, const Tensor& tensor); + absl::Status AddTensor(const std::string& handle, const Tensor& tensor); // Delete a tensor from the session state. - Status DeleteTensor(const std::string& handle); + absl::Status DeleteTensor(const std::string& handle); int64_t GetNewId(); @@ -68,11 +68,11 @@ class TensorStore { }; // Add the named tensor to the tensor store for this run. - Status AddTensor(const std::string& name, const TensorAndKey& tk); + absl::Status AddTensor(const std::string& name, const TensorAndKey& tk); // Save the tensors in the tensor store of this run to the session. - Status SaveTensors(const std::vector& output_names, - SessionState* session_state); + absl::Status SaveTensors(const std::vector& output_names, + SessionState* session_state); // Returns true if no tensors have been added to this store. bool empty() TF_NO_THREAD_SAFETY_ANALYSIS { return !dirty_; } diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 71d856eaeebb6b..9a34865b3810b7 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/shape_inference.h" #include +#include #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/full_type_util.h" @@ -68,7 +69,7 @@ InferenceContext::InferenceContext( if (v == nullptr) { continue; } - handle_data[i].reset(new std::vector(v->size())); + handle_data[i] = std::make_unique>(v->size()); auto& new_v = *handle_data[i]; for (int j = 0, end = v->size(); j < end; ++j) { const auto& p = (*v)[j]; @@ -100,10 +101,11 @@ InferenceContext::InferenceContext( InferenceContext::~InferenceContext() {} -Status InferenceContext::Run( - const std::function& fn) { +absl::Status InferenceContext::Run( + const std::function& + fn) { ForgetMerges(); - Status s = fn(this); + absl::Status s = fn(this); if (!s.ok()) { ForgetMerges(); return AttachContext(s); @@ -116,8 +118,8 @@ Status InferenceContext::Run( return s; } -Status InferenceContext::set_output(StringPiece output_name, - const std::vector& shapes) { +absl::Status InferenceContext::set_output( + StringPiece output_name, const std::vector& shapes) { auto result = output_name_map_.find(output_name); if (result == output_name_map_.end()) { return errors::InvalidArgument("Unknown output name: ", output_name); @@ -135,8 +137,8 @@ Status InferenceContext::set_output(StringPiece output_name, return absl::OkStatus(); } -Status InferenceContext::input(StringPiece input_name, - std::vector* output) const { +absl::Status InferenceContext::input(StringPiece input_name, + std::vector* output) const { const auto result = input_name_map_.find(input_name); if (result == input_name_map_.end()) { return errors::InvalidArgument("Unknown input name: ", input_name); @@ -149,8 +151,8 @@ Status InferenceContext::input(StringPiece input_name, return absl::OkStatus(); } -Status InferenceContext::output(StringPiece output_name, - std::vector* output) const { +absl::Status InferenceContext::output(StringPiece output_name, + std::vector* output) const { const auto result = output_name_map_.find(output_name); if (result == output_name_map_.end()) { return errors::InvalidArgument("Unknown output name: ", output_name); @@ -167,7 +169,7 @@ void InferenceContext::PreInputInit( const OpDef& op_def, const std::vector& input_tensors, const std::vector& input_tensors_as_shapes) { // TODO(mdan): This is also done at graph construction. Run only here instead? - Status s = full_type::SpecializeType(attrs_, op_def, ret_types_); + absl::Status s = full_type::SpecializeType(attrs_, op_def, ret_types_); if (!s.ok()) { construction_status_ = s; return; @@ -188,7 +190,7 @@ void InferenceContext::PreInputInit( output_handle_shapes_and_types_.resize(num_outputs); } -Status InferenceContext::ExpandOutputs(int new_output_size) { +absl::Status InferenceContext::ExpandOutputs(int new_output_size) { const int outputs_size = outputs_.size(); if (new_output_size < outputs_size) { return errors::InvalidArgument("Trying to reduce number of outputs of op."); @@ -312,8 +314,8 @@ string InferenceContext::DebugString( return strings::StrCat("[", absl::StrJoin(pieces, ","), "]"); } -Status InferenceContext::WithRank(ShapeHandle shape, int64_t rank, - ShapeHandle* out) { +absl::Status InferenceContext::WithRank(ShapeHandle shape, int64_t rank, + ShapeHandle* out) { if (rank > kint32max) { return errors::InvalidArgument("Rank cannot exceed kint32max"); } @@ -337,8 +339,8 @@ Status InferenceContext::WithRank(ShapeHandle shape, int64_t rank, existing); } -Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int64_t rank, - ShapeHandle* out) { +absl::Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int64_t rank, + ShapeHandle* out) { if (rank > kint32max) { return errors::InvalidArgument("Rank cannot exceed kint32max"); } @@ -352,8 +354,8 @@ Status InferenceContext::WithRankAtLeast(ShapeHandle shape, int64_t rank, " but is rank ", existing); } -Status InferenceContext::WithRankAtMost(ShapeHandle shape, int64_t rank, - ShapeHandle* out) { +absl::Status InferenceContext::WithRankAtMost(ShapeHandle shape, int64_t rank, + ShapeHandle* out) { if (rank > kint32max) { return errors::InvalidArgument("Rank cannot exceed kint32max"); } @@ -367,8 +369,8 @@ Status InferenceContext::WithRankAtMost(ShapeHandle shape, int64_t rank, " but is rank ", existing); } -Status InferenceContext::WithValue(DimensionHandle dim, int64_t value, - DimensionHandle* out) { +absl::Status InferenceContext::WithValue(DimensionHandle dim, int64_t value, + DimensionHandle* out) { const int64_t existing = Value(dim); if (existing == value) { *out = dim; @@ -409,8 +411,8 @@ void InferenceContext::Relax(DimensionHandle d_old, DimensionHandle d_new, } } -Status InferenceContext::Merge(DimensionHandle d0, DimensionHandle d1, - DimensionHandle* out) { +absl::Status InferenceContext::Merge(DimensionHandle d0, DimensionHandle d1, + DimensionHandle* out) { if (d0.SameHandle(d1)) { *out = d0; return absl::OkStatus(); @@ -432,9 +434,9 @@ Status InferenceContext::Merge(DimensionHandle d0, DimensionHandle d1, } } -Status InferenceContext::MergePrefix(ShapeHandle s, ShapeHandle prefix, - ShapeHandle* s_out, - ShapeHandle* prefix_out) { +absl::Status InferenceContext::MergePrefix(ShapeHandle s, ShapeHandle prefix, + ShapeHandle* s_out, + ShapeHandle* prefix_out) { *s_out = *prefix_out = nullptr; if (!RankKnown(prefix) || !RankKnown(s)) { *s_out = s; @@ -503,8 +505,8 @@ void InferenceContext::Relax(ShapeHandle s_old, ShapeHandle s_new, *out = MakeShape(dims); } -Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1, - ShapeHandle* out) { +absl::Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1, + ShapeHandle* out) { if (s0.SameHandle(s1)) { *out = s0; return absl::OkStatus(); @@ -563,7 +565,7 @@ Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1, TF_CHECK_OK(Merge(Dim(s0, i), Dim(s1, i), &dims[i])); } - Status s = ReturnCreatedShape(dims, out); + absl::Status s = ReturnCreatedShape(dims, out); if (s.ok()) { // Merge the new shape with s0. Since s0 and s1 are merged, this implies // that s1 and out are also merged. @@ -572,18 +574,19 @@ Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1, return s; } -Status InferenceContext::Subshape(ShapeHandle s, int64_t start, - ShapeHandle* out) { +absl::Status InferenceContext::Subshape(ShapeHandle s, int64_t start, + ShapeHandle* out) { return Subshape(s, start, std::numeric_limits::max() /* end */, out); } -Status InferenceContext::Subshape(ShapeHandle s, int64_t start, int64_t end, - ShapeHandle* out) { +absl::Status InferenceContext::Subshape(ShapeHandle s, int64_t start, + int64_t end, ShapeHandle* out) { return Subshape(s, start, end, 1 /* stride */, out); } -Status InferenceContext::Subshape(ShapeHandle s, int64_t start, int64_t end, - int64_t stride, ShapeHandle* out) { +absl::Status InferenceContext::Subshape(ShapeHandle s, int64_t start, + int64_t end, int64_t stride, + ShapeHandle* out) { int64_t start_in = start; int64_t end_in = end; @@ -642,8 +645,8 @@ Status InferenceContext::Subshape(ShapeHandle s, int64_t start, int64_t end, return ReturnCreatedShape(dims, out); } -Status InferenceContext::Concatenate(ShapeHandle s1, ShapeHandle s2, - ShapeHandle* out) { +absl::Status InferenceContext::Concatenate(ShapeHandle s1, ShapeHandle s2, + ShapeHandle* out) { if (!RankKnown(s1) || !RankKnown(s2)) { return ReturnUnknownShape(out); } @@ -657,8 +660,9 @@ Status InferenceContext::Concatenate(ShapeHandle s1, ShapeHandle s2, return ReturnCreatedShape(dims, out); } -Status InferenceContext::ReplaceDim(ShapeHandle s, int64_t dim_index_in, - DimensionHandle new_dim, ShapeHandle* out) { +absl::Status InferenceContext::ReplaceDim(ShapeHandle s, int64_t dim_index_in, + DimensionHandle new_dim, + ShapeHandle* out) { if (!RankKnown(s)) { return ReturnUnknownShape(out); } @@ -721,7 +725,8 @@ ShapeHandle InferenceContext::Matrix(DimensionOrConstant dim1, return MakeShape({dim1, dim2}); } -Status InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape( +absl::Status +InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape( int input_idx, ShapeHandle* out) { ShapeHandle input_shape; TF_RETURN_IF_ERROR(WithRankAtMost(input(input_idx), 1, &input_shape)); @@ -740,8 +745,8 @@ Status InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape( input_tensor(input_idx), input_shape, out); } -Status InferenceContext::MakeShapeFromShapeTensor(int input_idx, - ShapeHandle* out) { +absl::Status InferenceContext::MakeShapeFromShapeTensor(int input_idx, + ShapeHandle* out) { ShapeHandle input_shape; TF_RETURN_IF_ERROR(WithRank(input(input_idx), 1, &input_shape)); @@ -759,15 +764,15 @@ Status InferenceContext::MakeShapeFromShapeTensor(int input_idx, input_tensor(input_idx), input_shape, out); } -Status InferenceContext::MakeShapeFromTensor(const Tensor* t, - ShapeHandle tensor_shape, - ShapeHandle* out) { +absl::Status InferenceContext::MakeShapeFromTensor(const Tensor* t, + ShapeHandle tensor_shape, + ShapeHandle* out) { return InternalMakeShapeFromTensor( false /* treat_unknown_scalar_tensor_as_unknown_shape */, t, tensor_shape, out); } -Status InferenceContext::InternalMakeShapeFromTensor( +absl::Status InferenceContext::InternalMakeShapeFromTensor( bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t, ShapeHandle tensor_shape, ShapeHandle* out) { // Only callers who have set @@ -879,7 +884,7 @@ Status InferenceContext::InternalMakeShapeFromTensor( return ReturnCreatedShape(dims, out); } -Status InferenceContext::MakeShapeFromPartialTensorShape( +absl::Status InferenceContext::MakeShapeFromPartialTensorShape( const PartialTensorShape& partial_shape, ShapeHandle* out) { *out = nullptr; if (partial_shape.dims() == -1) { @@ -895,8 +900,8 @@ Status InferenceContext::MakeShapeFromPartialTensorShape( return ReturnCreatedShape(dims, out); } -Status InferenceContext::MakeShapeFromTensorShape(const TensorShape& shape, - ShapeHandle* out) { +absl::Status InferenceContext::MakeShapeFromTensorShape( + const TensorShape& shape, ShapeHandle* out) { return MakeShapeFromPartialTensorShape(PartialTensorShape(shape.dim_sizes()), out); } @@ -914,15 +919,16 @@ TensorShapeProto InferenceContext::ShapeHandleToProto(ShapeHandle handle) { return out; } -Status InferenceContext::MakeShapeFromShapeProto(const TensorShapeProto& proto, - ShapeHandle* out) { +absl::Status InferenceContext::MakeShapeFromShapeProto( + const TensorShapeProto& proto, ShapeHandle* out) { *out = nullptr; TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(proto)); PartialTensorShape partial_shape(proto); return MakeShapeFromPartialTensorShape(partial_shape, out); } -Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64_t* val) { +absl::Status InferenceContext::GetScalarFromTensor(const Tensor* t, + int64_t* val) { // Caller must ensure that is not NULL. const int rank = t->dims(); if (rank != 0) { @@ -944,8 +950,8 @@ Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64_t* val) { } } -Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64_t idx, - int64_t* val) { +absl::Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64_t idx, + int64_t* val) { // Caller must ensure that is not NULL. const int rank = t->dims(); if (rank != 1) { @@ -974,7 +980,8 @@ Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64_t idx, } // Returns a new dimension whose value is given by a scalar input tensor. -Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) { +absl::Status InferenceContext::MakeDimForScalarInput(int idx, + DimensionHandle* out) { int64_t val; const Tensor* t = input_tensor(idx); if (t == nullptr) { @@ -990,7 +997,7 @@ Status InferenceContext::MakeDimForScalarInput(int idx, DimensionHandle* out) { return absl::OkStatus(); } -Status InferenceContext::MakeDimForScalarInputWithNegativeIndexing( +absl::Status InferenceContext::MakeDimForScalarInputWithNegativeIndexing( int idx, int input_rank, DimensionHandle* out) { int64_t val; const Tensor* t = input_tensor(idx); @@ -1019,9 +1026,10 @@ Status InferenceContext::MakeDimForScalarInputWithNegativeIndexing( return absl::OkStatus(); } -Status InferenceContext::Divide(DimensionHandle dividend, - DimensionOrConstant divisor, - bool evenly_divisible, DimensionHandle* out) { +absl::Status InferenceContext::Divide(DimensionHandle dividend, + DimensionOrConstant divisor, + bool evenly_divisible, + DimensionHandle* out) { const int64_t divisor_value = Value(divisor); if (divisor_value == 1) { *out = dividend; @@ -1044,8 +1052,9 @@ Status InferenceContext::Divide(DimensionHandle dividend, return absl::OkStatus(); } -Status InferenceContext::Add(DimensionHandle first, DimensionOrConstant second, - DimensionHandle* out) { +absl::Status InferenceContext::Add(DimensionHandle first, + DimensionOrConstant second, + DimensionHandle* out) { const int64_t first_value = Value(first); const int64_t second_value = Value(second); // Special cases. @@ -1070,9 +1079,9 @@ Status InferenceContext::Add(DimensionHandle first, DimensionOrConstant second, return absl::OkStatus(); } -Status InferenceContext::Subtract(DimensionHandle first, - DimensionOrConstant second, - DimensionHandle* out) { +absl::Status InferenceContext::Subtract(DimensionHandle first, + DimensionOrConstant second, + DimensionHandle* out) { const int64_t first_value = Value(first); const int64_t second_value = Value(second); // Special cases. @@ -1093,9 +1102,9 @@ Status InferenceContext::Subtract(DimensionHandle first, return absl::OkStatus(); } -Status InferenceContext::Multiply(DimensionHandle first, - DimensionOrConstant second, - DimensionHandle* out) { +absl::Status InferenceContext::Multiply(DimensionHandle first, + DimensionOrConstant second, + DimensionHandle* out) { const int64_t first_value = Value(first); const int64_t second_value = Value(second); // Special cases. @@ -1122,8 +1131,9 @@ Status InferenceContext::Multiply(DimensionHandle first, return absl::OkStatus(); } -Status InferenceContext::Min(DimensionHandle first, DimensionOrConstant second, - DimensionHandle* out) { +absl::Status InferenceContext::Min(DimensionHandle first, + DimensionOrConstant second, + DimensionHandle* out) { const int64_t first_value = Value(first); const int64_t second_value = Value(second); if (first_value == 0) { @@ -1142,8 +1152,9 @@ Status InferenceContext::Min(DimensionHandle first, DimensionOrConstant second, return absl::OkStatus(); } -Status InferenceContext::Max(DimensionHandle first, DimensionOrConstant second, - DimensionHandle* out) { +absl::Status InferenceContext::Max(DimensionHandle first, + DimensionOrConstant second, + DimensionHandle* out) { const int64_t first_value = Value(first); const int64_t second_value = Value(second); if (first_value == kUnknownDim || second_value == kUnknownDim) { @@ -1158,7 +1169,7 @@ Status InferenceContext::Max(DimensionHandle first, DimensionOrConstant second, return absl::OkStatus(); } -Status InferenceContext::AttachContext(const Status& status) { +absl::Status InferenceContext::AttachContext(const absl::Status& status) { std::vector input_shapes; input_shapes.reserve(inputs_.size()); for (const ShapeHandle& input_shape : inputs_) { @@ -1243,8 +1254,8 @@ bool InferenceContext::MergeHandleShapesAndTypes( bool InferenceContext::MergeOutputHandleShapesAndTypes( int idx, const std::vector& shapes_and_types) { if (output_handle_shapes_and_types_[idx] == nullptr) { - output_handle_shapes_and_types_[idx].reset( - new std::vector(shapes_and_types)); + output_handle_shapes_and_types_[idx] = + std::make_unique>(shapes_and_types); return true; } return MergeHandleShapesAndTypes(shapes_and_types, @@ -1254,8 +1265,8 @@ bool InferenceContext::MergeOutputHandleShapesAndTypes( bool InferenceContext::MergeInputHandleShapesAndTypes( int idx, const std::vector& shapes_and_types) { if (input_handle_shapes_and_types_[idx] == nullptr) { - input_handle_shapes_and_types_[idx].reset( - new std::vector(shapes_and_types)); + input_handle_shapes_and_types_[idx] = + std::make_unique>(shapes_and_types); return true; } return MergeHandleShapesAndTypes(shapes_and_types, @@ -1293,8 +1304,8 @@ bool InferenceContext::RelaxOutputHandleShapesAndMergeTypes( << "Got idx: " << idx << " but only " << output_handle_shapes_and_types_.size() << " inputs."; if (output_handle_shapes_and_types_[idx] == nullptr) { - output_handle_shapes_and_types_[idx].reset( - new std::vector(shapes_and_types)); + output_handle_shapes_and_types_[idx] = + std::make_unique>(shapes_and_types); return true; } return RelaxHandleShapesAndMergeTypes( @@ -1308,8 +1319,8 @@ bool InferenceContext::RelaxInputHandleShapesAndMergeTypes( << "Got idx: " << idx << " but only " << input_handle_shapes_and_types_.size() << " inputs."; if (input_handle_shapes_and_types_[idx] == nullptr) { - input_handle_shapes_and_types_[idx].reset( - new std::vector(shapes_and_types)); + input_handle_shapes_and_types_[idx] = + std::make_unique>(shapes_and_types); return true; } return RelaxHandleShapesAndMergeTypes( diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index 6ed932e0c78189..4c02335ba82f82 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -275,8 +275,9 @@ class InferenceContext { // argument, returns the status of the inference. // // On error, additional context is provided in the error message. - Status Run( - const std::function& fn); + absl::Status Run( + const std::function& + fn); // Merge the stored shape of the input in position idx with according // to the following rules: @@ -339,7 +340,8 @@ class InferenceContext { void SetInput(int idx, ShapeHandle shape) { inputs_[idx] = shape; } ShapeHandle input(int64_t idx) const { return inputs_[idx]; } - Status input(StringPiece input_name, std::vector* output) const; + absl::Status input(StringPiece input_name, + std::vector* output) const; int num_inputs() const { return inputs_.size(); } // Returns the input tensor at index , or nullptr if the input tensor is @@ -392,16 +394,17 @@ class InferenceContext { ShapeHandle output(int64_t idx) const { return outputs_.at(idx); } void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; } - Status set_output(StringPiece output_name, - const std::vector& shapes); + absl::Status set_output(StringPiece output_name, + const std::vector& shapes); int num_outputs() const { return outputs_.size(); } ShapeHandle output(int idx) const { return outputs_.at(idx); } - Status output(StringPiece output_name, - std::vector* output) const; + absl::Status output(StringPiece output_name, + std::vector* output) const; // Returns the value for attribute named `attr_name`. - Status GetAttr(StringPiece attr_name, const AttrValue** attr_value) const { + absl::Status GetAttr(StringPiece attr_name, + const AttrValue** attr_value) const { return attrs_.Find(attr_name, attr_value); } const AttrValue* GetAttr(StringPiece attr_name) const { @@ -464,67 +467,69 @@ class InferenceContext { // the shape with asserted rank in <*out>. Otherwise return an error. // // Note that <*out> may be set to . - Status WithRank(ShapeHandle shape, int64_t rank, - ShapeHandle* out) TF_MUST_USE_RESULT; - Status WithRankAtLeast(ShapeHandle shape, int64_t rank, - ShapeHandle* out) TF_MUST_USE_RESULT; - Status WithRankAtMost(ShapeHandle shape, int64_t rank, + absl::Status WithRank(ShapeHandle shape, int64_t rank, ShapeHandle* out) TF_MUST_USE_RESULT; + absl::Status WithRankAtLeast(ShapeHandle shape, int64_t rank, + ShapeHandle* out) TF_MUST_USE_RESULT; + absl::Status WithRankAtMost(ShapeHandle shape, int64_t rank, + ShapeHandle* out) TF_MUST_USE_RESULT; // If has value , or its value is unknown, returns OK and returns // the dimension with asserted value in <*out>. Otherwise returns an error. // // Note that <*out> may be set to . - Status WithValue(DimensionHandle dim, int64_t value, - DimensionHandle* out) TF_MUST_USE_RESULT; + absl::Status WithValue(DimensionHandle dim, int64_t value, + DimensionHandle* out) TF_MUST_USE_RESULT; // Merges and and returns the merged shape in <*out>. See // 'MergeInput' function for full details and examples. - Status Merge(ShapeHandle s0, ShapeHandle s1, - ShapeHandle* out) TF_MUST_USE_RESULT; + absl::Status Merge(ShapeHandle s0, ShapeHandle s1, + ShapeHandle* out) TF_MUST_USE_RESULT; // Asserts that 's rank >= 's rank, and the first // dimensions of are compatible with the dimensions of // . // Returns the merged results in <*s_out> and <*prefix_out>. - Status MergePrefix(ShapeHandle s, ShapeHandle prefix, ShapeHandle* s_out, - ShapeHandle* prefix_out) TF_MUST_USE_RESULT; + absl::Status MergePrefix(ShapeHandle s, ShapeHandle prefix, + ShapeHandle* s_out, + ShapeHandle* prefix_out) TF_MUST_USE_RESULT; // Merges and and returns the merged dimension in <*out>. If // and have incompatible values, returns an error. // // Note that <*out> may be set to or . - Status Merge(DimensionHandle d0, DimensionHandle d1, - DimensionHandle* out) TF_MUST_USE_RESULT; + absl::Status Merge(DimensionHandle d0, DimensionHandle d1, + DimensionHandle* out) TF_MUST_USE_RESULT; // Returns in <*out> a sub-shape of with dimensions [start:]. // can be negative to index from the end of the shape. If > // rank of , then an empty subshape is returned. - Status Subshape(ShapeHandle s, int64_t start, - ShapeHandle* out) TF_MUST_USE_RESULT; + absl::Status Subshape(ShapeHandle s, int64_t start, + ShapeHandle* out) TF_MUST_USE_RESULT; // Returns in <*out> a sub-shape of , with dimensions [start:end]. // and can be negative, to index from the end of the shape. // and are set to the rank of if > rank of . - Status Subshape(ShapeHandle s, int64_t start, int64_t end, - ShapeHandle* out) TF_MUST_USE_RESULT; + absl::Status Subshape(ShapeHandle s, int64_t start, int64_t end, + ShapeHandle* out) TF_MUST_USE_RESULT; // Returns in <*out> a sub-shape of , with dimensions [start:end:stride]. // and can be negative, to index from the end of the shape. // and are set to the rank of if > rank of . // can be negative, to reverse the . - Status Subshape(ShapeHandle s, int64_t start, int64_t end, int64_t stride, - ShapeHandle* out) TF_MUST_USE_RESULT; + absl::Status Subshape(ShapeHandle s, int64_t start, int64_t end, + int64_t stride, ShapeHandle* out) TF_MUST_USE_RESULT; // Returns in <*out> the result of appending the dimensions of to those // of . - Status Concatenate(ShapeHandle s1, ShapeHandle s2, - ShapeHandle* out) TF_MUST_USE_RESULT; + absl::Status Concatenate(ShapeHandle s1, ShapeHandle s2, + ShapeHandle* out) TF_MUST_USE_RESULT; // Returns in the shape from replacing with // . - Status ReplaceDim(ShapeHandle s, int64_t dim_index, DimensionHandle new_dim, - ShapeHandle* out) TF_MUST_USE_RESULT; + absl::Status ReplaceDim(ShapeHandle s, int64_t dim_index, + DimensionHandle new_dim, + ShapeHandle* out) TF_MUST_USE_RESULT; // Returns a new shape with the given dims. The returned value is owned by // this context. @@ -549,24 +554,25 @@ class InferenceContext { // Returns in a new shape whose dimension sizes come from input tensor // . The tensor must be a 1-dimensional int32 or int64 tensor. If // the input tensor is NULL, then an unknown shape is returned. - Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out); + absl::Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out); // Like the function above, but treats scalar values as unknown // shapes. **NOTE** If the scalar is statically known, its value // must be -1 or an error is returned. - Status MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx, - ShapeHandle* out); + absl::Status MakeShapeFromShapeTensorTreatScalarAsUnknownShape( + int input_idx, ShapeHandle* out); // Returns in a new shape corresponding to . - Status MakeShapeFromShapeProto(const TensorShapeProto& proto, - ShapeHandle* out); + absl::Status MakeShapeFromShapeProto(const TensorShapeProto& proto, + ShapeHandle* out); // Returns in a new shape corresponding to . - Status MakeShapeFromPartialTensorShape( + absl::Status MakeShapeFromPartialTensorShape( const PartialTensorShape& partial_shape, ShapeHandle* out); // Returns in a new shape corresponding to . - Status MakeShapeFromTensorShape(const TensorShape& shape, ShapeHandle* out); + absl::Status MakeShapeFromTensorShape(const TensorShape& shape, + ShapeHandle* out); absl::StatusOr MakeShapeFromShapeTensor( const TensorShape& shape); @@ -581,61 +587,62 @@ class InferenceContext { // Returns in a scalar value from an input tensor . The input tensor // must be a 0-dimensional int32 or int64 tensor. Caller must ensure that the // input tensor is not NULL. - Status GetScalarFromTensor(const Tensor* t, int64_t* val); + absl::Status GetScalarFromTensor(const Tensor* t, int64_t* val); // Returns in a scalar value from a 1D input tensor with int32 or // int64 elements. Caller must ensure that the input tensor is not NULL. - Status GetScalarFromTensor(const Tensor* t, int64_t idx, int64_t* val); + absl::Status GetScalarFromTensor(const Tensor* t, int64_t idx, int64_t* val); // Returns a new dimension whose value is given by a scalar input tensor. // The input tensor must be in host memory, since it is dereferenced to get // the value. - Status MakeDimForScalarInput(int idx, DimensionHandle* out); + absl::Status MakeDimForScalarInput(int idx, DimensionHandle* out); // Returns a new dimension whose value is given by a scalar input tensor. // This allows for a negative input dimension given the rank of a separate // tensor. This rank can be negative if unknown. // The input tensor must be in host memory, since it is dereferenced to get // the value. - Status MakeDimForScalarInputWithNegativeIndexing(int idx, int input_rank, - DimensionHandle* out); + absl::Status MakeDimForScalarInputWithNegativeIndexing(int idx, + int input_rank, + DimensionHandle* out); // Look up the attr being evaluated with name attr_name and set *value to its // value. If no attr with attr_name is found in def(), or the attr does not // have a matching type, a non-ok status will be returned. template - Status GetAttr(StringPiece attr_name, T* value) const; + absl::Status GetAttr(StringPiece attr_name, T* value) const; // Returns in the result of dividing by . // Returns an error if is not positive or if // and does not evenly divide . - Status Divide(DimensionHandle dividend, DimensionOrConstant divisor, - bool evenly_divisible, DimensionHandle* out); + absl::Status Divide(DimensionHandle dividend, DimensionOrConstant divisor, + bool evenly_divisible, DimensionHandle* out); // Returns in the sum of and . - Status Add(DimensionHandle first, DimensionOrConstant second, - DimensionHandle* out); + absl::Status Add(DimensionHandle first, DimensionOrConstant second, + DimensionHandle* out); // Returns in the dimension that is minus . - Status Subtract(DimensionHandle first, DimensionOrConstant second, - DimensionHandle* out); + absl::Status Subtract(DimensionHandle first, DimensionOrConstant second, + DimensionHandle* out); // Returns in the product of and . - Status Multiply(DimensionHandle first, DimensionOrConstant second, - DimensionHandle* out); + absl::Status Multiply(DimensionHandle first, DimensionOrConstant second, + DimensionHandle* out); // Returns in the minimum of and . If either or // is zero the results is zero. Otherwise, if either or // is unknown the results is unknown. - Status Min(DimensionHandle first, DimensionOrConstant second, - DimensionHandle* out); + absl::Status Min(DimensionHandle first, DimensionOrConstant second, + DimensionHandle* out); // Returns in the maximum of and . If either or // is unknown the results is unknown. - Status Max(DimensionHandle first, DimensionOrConstant second, - DimensionHandle* out); + absl::Status Max(DimensionHandle first, DimensionOrConstant second, + DimensionHandle* out); - Status construction_status() const { return construction_status_; } + absl::Status construction_status() const { return construction_status_; } // Methods to propagate shape and dtype on edges of handles. Handles are the // dtype DT_RESOURCE which can be used to access state stored in a @@ -727,8 +734,8 @@ class InferenceContext { // Returns in a new shape whose dimension sizes come from tensor . // The tensor must be a 1-dimensional int32 or int64 tensor. If is NULL, // then an unknown shape is returned. - Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape, - ShapeHandle* out); + absl::Status MakeShapeFromTensor(const Tensor* t, ShapeHandle tensor_shape, + ShapeHandle* out); int graph_def_version() const { return graph_def_version_; } @@ -741,7 +748,7 @@ class InferenceContext { } // Adds new outputs; useful when mutating the graph. - Status ExpandOutputs(int new_output_size); + absl::Status ExpandOutputs(int new_output_size); private: // Creates and stores shapes for use in InferenceContext. @@ -786,18 +793,18 @@ class InferenceContext { void PostInputInit(std::vector>> input_handle_data); - Status ReturnUnknownShape(ShapeHandle* out) { + absl::Status ReturnUnknownShape(ShapeHandle* out) { *out = UnknownShape(); return absl::OkStatus(); } - Status ReturnCreatedShape(const std::vector& dims, - ShapeHandle* out) { + absl::Status ReturnCreatedShape(const std::vector& dims, + ShapeHandle* out) { *out = MakeShape(dims); return absl::OkStatus(); } // Adds additional context to the given status. - Status AttachContext(const Status& status); + absl::Status AttachContext(const absl::Status& status); // Relaxes an existing value with a new value and returns the // relaxed dimension in <*out>. If and have incompatible @@ -829,7 +836,7 @@ class InferenceContext { } // Helper method for MakeShapeFromTensor and MakeShapeFromShapeTensor. - Status InternalMakeShapeFromTensor( + absl::Status InternalMakeShapeFromTensor( bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t, ShapeHandle tensor_shape, ShapeHandle* out); @@ -871,7 +878,7 @@ class InferenceContext { // An error set during construction. TODO(cwhipkey): remove when test // constructor is removed. - Status construction_status_; + absl::Status construction_status_; // Pair of shape or dim handles that are equivalent, ie that represent the // same underlying shape of dimension. Note that for each pair at least one of @@ -912,7 +919,7 @@ inline DimensionOrConstant::DimensionOrConstant(int64_t val) : val(val) { } template -Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const { +absl::Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const { return GetNodeAttr(attrs_, attr_name, value); } diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc index e550ff8fb88374..c5dbc299b86540 100644 --- a/tensorflow/core/framework/shape_inference_test.cc +++ b/tensorflow/core/framework/shape_inference_test.cc @@ -1106,7 +1106,7 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) { InferenceContext c(kVersion, def, MakeOpDef(1, 0), {Unknown()}, {t}, {}, {}); ShapeHandle out; - Status s = c.MakeShapeFromShapeTensor(0, &out); + absl::Status s = c.MakeShapeFromShapeTensor(0, &out); if (s.ok()) { return c.DebugString(out); } else { diff --git a/tensorflow/core/framework/shape_inference_testutil.cc b/tensorflow/core/framework/shape_inference_testutil.cc index 34574b6e54ede1..b4cd528a4470c6 100644 --- a/tensorflow/core/framework/shape_inference_testutil.cc +++ b/tensorflow/core/framework/shape_inference_testutil.cc @@ -30,9 +30,9 @@ namespace shape_inference { using errors::Unknown; -Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op, - const string& ins, - const string& expected_outs) { +absl::Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op, + const string& ins, + const string& expected_outs) { const OpRegistrationData* op_reg_data; TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUp(op.name, &op_reg_data)); @@ -230,7 +230,7 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op, } // static -Status ShapeInferenceTestutil::MakeShapeFromString( +absl::Status ShapeInferenceTestutil::MakeShapeFromString( InferenceContext::ShapeManager* manager, const string& spec, ShapeHandle* output) { if (spec == "?") { diff --git a/tensorflow/core/framework/shape_inference_testutil.h b/tensorflow/core/framework/shape_inference_testutil.h index 769dcde453a221..d65965b43c2b51 100644 --- a/tensorflow/core/framework/shape_inference_testutil.h +++ b/tensorflow/core/framework/shape_inference_testutil.h @@ -15,14 +15,13 @@ limitations under the License. #ifndef TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_ #define TENSORFLOW_CORE_FRAMEWORK_SHAPE_INFERENCE_TESTUTIL_H_ +#include #include -#include "absl/strings/string_view.h" +#include "absl/status/status.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/version.h" @@ -68,35 +67,35 @@ class ShapeInferenceTestutil { // the second is which dimension in that input it corresponds to. // can be "e"; this is used to indicate that shape inference // should have failed. - static Status InferShapes(ShapeInferenceTestOp op, const string& ins, - const string& expected_outs); + static absl::Status InferShapes(ShapeInferenceTestOp op, const string& ins, + const string& expected_outs); private: - ShapeInferenceTestutil() {} + ShapeInferenceTestutil() = default; // Makes a shape out of 'spec'. - static Status MakeShapeFromString(InferenceContext::ShapeManager* manager, - const string& spec, ShapeHandle* output); + static absl::Status MakeShapeFromString( + InferenceContext::ShapeManager* manager, const string& spec, + ShapeHandle* output); }; } // namespace shape_inference -#define INFER_OK(op, i, o) \ - EXPECT_EQ( \ - ::tensorflow::shape_inference::ShapeInferenceTestutil::InferShapes( \ - op, i, o), \ - ::tensorflow::OkStatus()) -#define INFER_ERROR(error_substring, op, i) \ - { \ - tensorflow::Status status = \ - (::tensorflow::shape_inference::ShapeInferenceTestutil::InferShapes( \ - op, i, "e")); \ - std::string error_message = status.ToString(); \ - const std::string substring = std::string(error_substring); \ - EXPECT_NE(status, ::tensorflow::OkStatus()); \ - EXPECT_TRUE(absl::StrContains(error_message, substring)) \ - << "Expected to see '" << substring << "' in '" << error_message \ - << "'"; \ +#define INFER_OK(op, i, o) \ + EXPECT_EQ(tensorflow::shape_inference::ShapeInferenceTestutil::InferShapes( \ + op, i, o), \ + absl::OkStatus()) + +#define INFER_ERROR(error_substring, op, i) \ + { \ + absl::Status status = \ + (tensorflow::shape_inference::ShapeInferenceTestutil::InferShapes( \ + op, i, "e")); \ + std::string error_message = status.ToString(); \ + EXPECT_NE(status, absl::OkStatus()); \ + EXPECT_TRUE(absl::StrContains(error_message, error_substring)) \ + << "Expected to see '" << error_substring << "' in '" << error_message \ + << "'"; \ } } // namespace tensorflow diff --git a/tensorflow/core/framework/stats_aggregator.h b/tensorflow/core/framework/stats_aggregator.h index 83fdf7829c3f2b..5b89a82f861b9a 100644 --- a/tensorflow/core/framework/stats_aggregator.h +++ b/tensorflow/core/framework/stats_aggregator.h @@ -60,7 +60,8 @@ class StatsAggregator { virtual void EncodeToProto(Summary* out_summary) = 0; // Sets a `summary_writer` with this stats_aggregator. - virtual Status SetSummaryWriter(SummaryWriterInterface* summary_writer) = 0; + virtual absl::Status SetSummaryWriter( + SummaryWriterInterface* summary_writer) = 0; // Increment the `label` cell of metrics mapped with `name` by given `value`. virtual void IncrementCounter(const string& name, const string& label, diff --git a/tensorflow/core/framework/summary.proto b/tensorflow/core/framework/summary.proto index c6b515abea2517..9e219b027fc26a 100644 --- a/tensorflow/core/framework/summary.proto +++ b/tensorflow/core/framework/summary.proto @@ -2,7 +2,7 @@ syntax = "proto3"; package tensorflow; -import public "tsl/protobuf/histogram.proto"; +import public "xla/tsl/protobuf/histogram.proto"; import "tensorflow/core/framework/tensor.proto"; diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index c1278e7187d3e9..f2cd323101c625 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -685,8 +685,8 @@ TensorBuffer* FromProtoField(Allocator* a, // the remaining elements up to n to be the default ResourceHandle() value. const int64_t real_n = n < in_n ? n : in_n; for (int64_t i = 0; i < real_n; ++i) { - Status s = ResourceHandle::BuildResourceHandle(in.resource_handle_val(i), - &data[i]); + absl::Status s = ResourceHandle::BuildResourceHandle( + in.resource_handle_val(i), &data[i]); if (!s.ok()) { LOG(ERROR) << "Could not decode resource handle from proto \"" << in.resource_handle_val(i).ShortDebugString() @@ -861,8 +861,8 @@ std::ostream& operator<<(std::ostream& out, const Tensor& tensor) { return out; } -Status Tensor::BitcastFrom(const Tensor& other, DataType dtype, - const TensorShape& shape) { +absl::Status Tensor::BitcastFrom(const Tensor& other, DataType dtype, + const TensorShape& shape) { int in_size = DataTypeSize(other.dtype()); int out_size = DataTypeSize(dtype); if (in_size == 0) { @@ -992,8 +992,8 @@ Tensor::Tensor(Allocator* a, DataType type, const TensorShape& shape, } } -Status Tensor::BuildTensor(DataType type, const TensorShape& shape, - Tensor* out_tensor) { +absl::Status Tensor::BuildTensor(DataType type, const TensorShape& shape, + Tensor* out_tensor) { // Avoid crashes due to invalid or unsupported types. CASES_WITH_DEFAULT( type, {}, return errors::InvalidArgument("Type not set"), diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index 3f46d40bb03f5a..6ca65799276f0a 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -51,15 +51,17 @@ class TensorProto; class Var; namespace batch_util { -Status CopyElementToSlice(Tensor element, Tensor* parent, int64_t index); -Status CopySliceToElement(const Tensor& parent, Tensor* element, int64_t index); -Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64_t index); -Status CopyContiguousSlices(const Tensor& src, int64_t src_offset, - int64_t dst_offset, int64_t num_slices, - Tensor* dst); -Status MaybeMoveContiguousSlices(Tensor& src, int64_t src_offset, - int64_t dst_offset, int64_t num_slices, - Tensor* dst); +absl::Status CopyElementToSlice(Tensor element, Tensor* parent, int64_t index); +absl::Status CopySliceToElement(const Tensor& parent, Tensor* element, + int64_t index); +absl::Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, + int64_t index); +absl::Status CopyContiguousSlices(const Tensor& src, int64_t src_offset, + int64_t dst_offset, int64_t num_slices, + Tensor* dst); +absl::Status MaybeMoveContiguousSlices(Tensor& src, int64_t src_offset, + int64_t dst_offset, int64_t num_slices, + Tensor* dst); } // namespace batch_util /// @ingroup core @@ -185,8 +187,8 @@ class Tensor { /// validate that the `DataType` is valid and supported. /// /// The underlying buffer is allocated using a `CPUAllocator`. - static Status BuildTensor(DataType type, const TensorShape& shape, - Tensor* out_tensor); + static absl::Status BuildTensor(DataType type, const TensorShape& shape, + Tensor* out_tensor); private: // A tag type for selecting the `Tensor` constructor overload that creates a @@ -657,8 +659,8 @@ class Tensor { /// /// If any of the requirements are not met, errors::InvalidArgument is /// returned. - Status BitcastFrom(const Tensor& other, DataType dtype, - const TensorShape& shape); + absl::Status BitcastFrom(const Tensor& other, DataType dtype, + const TensorShape& shape); /// Like BitcastFrom, but CHECK fails if any preconditions are not met. /// @@ -705,20 +707,20 @@ class Tensor { friend class CastOpBase; // For access to set_dtype. friend class ScopedAllocator; // For access to buf_. friend class PjRtTensorBufferUtil; // For access to buf_. - friend Status batch_util::CopyElementToSlice( + friend absl::Status batch_util::CopyElementToSlice( Tensor element, Tensor* parent, int64_t index); // For access to base(). - friend Status batch_util::CopySliceToElement( + friend absl::Status batch_util::CopySliceToElement( const Tensor& parent, Tensor* element, int64_t index); // For access to base(). - friend Status batch_util::MaybeMoveSliceToElement( + friend absl::Status batch_util::MaybeMoveSliceToElement( Tensor* parent, Tensor* element, int64_t index); // For access to base(). - friend Status batch_util::CopyContiguousSlices( + friend absl::Status batch_util::CopyContiguousSlices( const Tensor& src, int64_t src_offset, int64_t dst_offset, int64_t num_slices, Tensor* dst); // For access to base(). - friend Status batch_util::MaybeMoveContiguousSlices( + friend absl::Status batch_util::MaybeMoveContiguousSlices( Tensor& src, int64_t src_offset, int64_t dst_offset, int64_t num_slices, Tensor* dst); // For access to base(). diff --git a/tensorflow/core/framework/tensor_fuzz.cc b/tensorflow/core/framework/tensor_fuzz.cc index 49f91b021bf9fc..ef04128e0d8328 100644 --- a/tensorflow/core/framework/tensor_fuzz.cc +++ b/tensorflow/core/framework/tensor_fuzz.cc @@ -27,7 +27,7 @@ namespace { void BuildTensorAlwaysSucceedsWithValidTensorShape(DataType type, const TensorShape& shape) { Tensor out; - Status status = Tensor::BuildTensor(type, shape, &out); + absl::Status status = Tensor::BuildTensor(type, shape, &out); TF_EXPECT_OK(status); } FUZZ_TEST(TensorFuzz, BuildTensorAlwaysSucceedsWithValidTensorShape) diff --git a/tensorflow/core/framework/tensor_shape.cc b/tensorflow/core/framework/tensor_shape.cc index 9f0c4fa1681d95..35c628216ed3c6 100644 --- a/tensorflow/core/framework/tensor_shape.cc +++ b/tensorflow/core/framework/tensor_shape.cc @@ -94,7 +94,8 @@ bool TensorShapeBase::IsValid(const TensorShapeProto& proto) { } template -Status TensorShapeBase::IsValidShape(const TensorShapeProto& proto) { +absl::Status TensorShapeBase::IsValidShape( + const TensorShapeProto& proto) { // NOTE(irving): Unfortunately, TensorShape allows parsing protos with // unknown_shape() set, and it seems hard to remove this without backwards // compatibility issues. @@ -155,7 +156,7 @@ TensorShapeBase::TensorShapeBase(const TensorShapeProto& proto) { } template -Status TensorShapeBase::BuildTensorShapeBase( +absl::Status TensorShapeBase::BuildTensorShapeBase( const TensorShapeProto& proto, TensorShapeBase* out) { out->set_tag(REP16); out->set_data_type(DT_INVALID); @@ -169,7 +170,7 @@ Status TensorShapeBase::BuildTensorShapeBase( out->set_ndims_byte(0); out->set_num_elements(1); int64_t num_elements_excluding_zero_dims = 1; - Status s = absl::OkStatus(); + absl::Status s = absl::OkStatus(); for (const auto& d : proto.dim()) { s = out->AddDimWithStatus(d.size()); if (!s.ok()) { @@ -202,7 +203,7 @@ TensorShapeBase::TensorShapeBase(absl::Span dim_sizes) { } template -Status TensorShapeBase::BuildTensorShapeBase( +absl::Status TensorShapeBase::BuildTensorShapeBase( absl::Span dim_sizes, TensorShapeBase* out) { out->set_tag(REP16); out->set_data_type(DT_INVALID); @@ -224,7 +225,8 @@ static inline bool Set16(bool partial, uint16* dst, int dim, int64_t val) { } template -Status TensorShapeBase::InitDims(absl::Span dim_sizes) { +absl::Status TensorShapeBase::InitDims( + absl::Span dim_sizes) { DCHECK_EQ(tag(), REP16); // Allow sizes that are under kint64max^0.25 so that 4-way multiplication @@ -298,7 +300,7 @@ Status TensorShapeBase::InitDims(absl::Span dim_sizes) { set_ndims_byte(0); set_num_elements(1); - Status status = absl::OkStatus(); + absl::Status status = absl::OkStatus(); for (int64_t s : dim_sizes) { status.Update(AddDimWithStatus(internal::SubtleMustCopy(s))); if (!status.ok()) { @@ -384,7 +386,7 @@ void TensorShapeRep::ClearAllButDataType() { } template -Status TensorShapeBase::RecomputeNumElements() { +absl::Status TensorShapeBase::RecomputeNumElements() { if (unknown_rank()) { set_num_elements(-1); return absl::OkStatus(); @@ -422,7 +424,7 @@ void TensorShapeBase::AddDim(int64_t size) { } template -Status TensorShapeBase::AddDimWithStatus(int64_t size) { +absl::Status TensorShapeBase::AddDimWithStatus(int64_t size) { if (!kIsPartial) { if (TF_PREDICT_FALSE(size < 0)) { return errors::InvalidArgument("Expected a non-negative size, got ", @@ -506,9 +508,9 @@ void TensorShapeBase::AppendShape(const TensorShapeBase& shape) { } template -Status TensorShapeBase::AppendShapeWithStatus( +absl::Status TensorShapeBase::AppendShapeWithStatus( const TensorShapeBase& shape) { - Status s = absl::OkStatus(); + absl::Status s = absl::OkStatus(); for (auto d : shape) { s.Update(AddDimWithStatus(d.size)); if (!s.ok()) { @@ -534,7 +536,7 @@ void TensorShapeBase::InsertDim(int d, int64_t size) { } template -Status TensorShapeBase::InsertDimWithStatus(int d, int64_t size) { +absl::Status TensorShapeBase::InsertDimWithStatus(int d, int64_t size) { if (!kIsPartial) { if (TF_PREDICT_FALSE(size < 0)) { return errors::InvalidArgument("Expected a non-negative size, got ", @@ -560,7 +562,7 @@ Status TensorShapeBase::InsertDimWithStatus(int d, int64_t size) { vals.insert(vals.begin() + d, size); ClearAllButDataType(); - Status s = absl::OkStatus(); + absl::Status s = absl::OkStatus(); for (auto dval : vals) { s.Update(AddDimWithStatus(dval)); if (!s.ok()) { @@ -608,7 +610,7 @@ void TensorShapeBase::set_dim(int d, int64_t size) { } template -Status TensorShapeBase::SetDimWithStatus(int d, int64_t size) { +absl::Status TensorShapeBase::SetDimWithStatus(int d, int64_t size) { if (TF_PREDICT_FALSE(d < 0)) { return errors::InvalidArgument("Index must be non-negative, got ", d); } @@ -635,7 +637,7 @@ Status TensorShapeBase::SetDimWithStatus(int d, int64_t size) { vals[d] = size; ClearAllButDataType(); - Status s = absl::OkStatus(); + absl::Status s = absl::OkStatus(); for (auto dval : vals) { s.Update(AddDimWithStatus(dval)); if (!s.ok()) { @@ -668,7 +670,8 @@ void TensorShapeBase::RemoveDimRange(int begin, int end) { } template -Status TensorShapeBase::RemoveDimRangeWithStatus(int begin, int end) { +absl::Status TensorShapeBase::RemoveDimRangeWithStatus(int begin, + int end) { if (unknown_rank()) { return absl::OkStatus(); } @@ -700,7 +703,7 @@ Status TensorShapeBase::RemoveDimRangeWithStatus(int begin, int end) { vals.erase(vals.begin() + begin, vals.begin() + end); ClearAllButDataType(); - Status s = absl::OkStatus(); + absl::Status s = absl::OkStatus(); for (auto dval : vals) { s.Update(AddDimWithStatus(dval)); if (!s.ok()) { @@ -809,7 +812,7 @@ bool TensorShapeUtils::EndsWith(const TensorShape& shape, } template -Status MakeShapeHelper(const T* dims, int64_t n, Shape* out) { +absl::Status MakeShapeHelper(const T* dims, int64_t n, Shape* out) { out->Clear(); if (n > TensorShape::MaxDimensions()) { return errors::InvalidArgument("Too many dimensions"); @@ -879,7 +882,7 @@ PartialTensorShape PartialTensorShape::Concatenate(int64_t size) const { return out; } -Status PartialTensorShape::ConcatenateWithStatus( +absl::Status PartialTensorShape::ConcatenateWithStatus( int64_t size, PartialTensorShape* out) const { *out = *this; return out->AddDimWithStatus(size); @@ -895,7 +898,7 @@ PartialTensorShape PartialTensorShape::Concatenate( return out; } -Status PartialTensorShape::ConcatenateWithStatus( +absl::Status PartialTensorShape::ConcatenateWithStatus( const PartialTensorShape& shape, PartialTensorShape* out) const { if (unknown_rank() || shape.unknown_rank()) { *out = PartialTensorShape(); @@ -903,15 +906,15 @@ Status PartialTensorShape::ConcatenateWithStatus( } *out = *this; for (auto dim : shape) { - Status s = out->AddDimWithStatus(dim.size); + absl::Status s = out->AddDimWithStatus(dim.size); if (!s.ok()) return s; } return absl::OkStatus(); } -Status PartialTensorShape::MergeWith(const PartialTensorShape& shape, - PartialTensorShape* result) const { +absl::Status PartialTensorShape::MergeWith(const PartialTensorShape& shape, + PartialTensorShape* result) const { if (unknown_rank()) { *result = shape; return absl::OkStatus(); @@ -933,7 +936,7 @@ Status PartialTensorShape::MergeWith(const PartialTensorShape& shape, } result->Clear(); - Status s = absl::OkStatus(); + absl::Status s = absl::OkStatus(); for (int i = 0; i < dims_; ++i) { const int64_t dim0 = dim_size(i); const int64_t dim1 = shape.dim_size(i); @@ -1024,8 +1027,8 @@ bool PartialTensorShapeUtils::AreIdentical( } } -Status TensorShapeUtils::NumElements(absl::Span shape, - int64_t* num_elements) { +absl::Status TensorShapeUtils::NumElements(absl::Span shape, + int64_t* num_elements) { int64_t n = 1; for (auto dim : shape) { n = MultiplyWithoutOverflow(n, dim); diff --git a/tensorflow/core/framework/tensor_shape.h b/tensorflow/core/framework/tensor_shape.h index b1b7bbade0d18a..0bcf1fc54af844 100644 --- a/tensorflow/core/framework/tensor_shape.h +++ b/tensorflow/core/framework/tensor_shape.h @@ -181,14 +181,14 @@ class TensorShapeBase : public TensorShapeRep { // an array of sizes if calling code cannot validate that the sizes specify a // valid `TensorShape`. // The value in `*out` is valid iff the returned value is `Status::OK`. - static Status BuildTensorShapeBase(absl::Span dim_sizes, - TensorShapeBase* out); - static Status BuildTensorShapeBase(std::initializer_list dim_sizes, - TensorShapeBase* out) { + static absl::Status BuildTensorShapeBase(absl::Span dim_sizes, + TensorShapeBase* out); + static absl::Status BuildTensorShapeBase( + std::initializer_list dim_sizes, TensorShapeBase* out) { return BuildTensorShapeBase(absl::Span(dim_sizes), out); } - static Status BuildTensorShapeBase(const TensorShapeProto& proto, - TensorShapeBase* out); + static absl::Status BuildTensorShapeBase(const TensorShapeProto& proto, + TensorShapeBase* out); /// Returns `true` iff `proto` is a valid tensor shape. // For TensorShape, the proto shape must be fully defined. @@ -196,7 +196,7 @@ class TensorShapeBase : public TensorShapeRep { /// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error /// status otherwise. - static Status IsValidShape(const TensorShapeProto& proto); + static absl::Status IsValidShape(const TensorShapeProto& proto); /// Returns `true` iff this is a valid tensor shape. bool IsValid(); @@ -207,14 +207,14 @@ class TensorShapeBase : public TensorShapeRep { /// Same as `AddDim` but returns a `Status`. /// Use if unsure is `size >= 0`, to prevent `CHECK`-crashes. - Status AddDimWithStatus(int64_t size); + absl::Status AddDimWithStatus(int64_t size); /// Appends all the dimensions from `shape`. void AppendShape(const TensorShapeBase& shape); /// Same as `RemoveDim` but returns a `Status`. /// Use if you cannot validate all invariants, to prevent `CHECK`-fail. - Status AppendShapeWithStatus(const TensorShapeBase& shape); + absl::Status AppendShapeWithStatus(const TensorShapeBase& shape); /// \brief Insert a dimension somewhere in the `TensorShape`. /// REQUIRES: `0 <= d <= dims()` @@ -224,7 +224,7 @@ class TensorShapeBase : public TensorShapeRep { /// Same as `InsertDim` but returns a `Status`. /// Use if unsure if requirements in `InsertDim` are satistified, to prevent /// `CHECK`-fail crashes. - Status InsertDimWithStatus(int d, int64_t size); + absl::Status InsertDimWithStatus(int d, int64_t size); /// \brief Modifies the size of the dimension `d` to be `size` /// REQUIRES: `0 <= d < dims()` @@ -234,7 +234,7 @@ class TensorShapeBase : public TensorShapeRep { /// Same as `set_dim` but returns a `Status`. /// Use if unsure if requirements in `set_dim` are satistified, to prevent /// `CHECK`-fail crashes. - Status SetDimWithStatus(int d, int64_t size); + absl::Status SetDimWithStatus(int d, int64_t size); /// \brief Removes dimension `d` from the `TensorShape`. /// REQUIRES: `0 <= d < dims()` @@ -245,7 +245,7 @@ class TensorShapeBase : public TensorShapeRep { /// Same as `RemoveDim` but returns a `Status`. /// Use if unsure is `0 <= d < dims()`, to prevent `CHECK`-crashes. - Status RemoveDimWithStatus(int64_t d) { + absl::Status RemoveDimWithStatus(int64_t d) { if (TF_PREDICT_FALSE(d < 0)) { return errors::Internal( "Expected dimension index to be non-negative, got ", d); @@ -262,7 +262,7 @@ class TensorShapeBase : public TensorShapeRep { /// Same as `RemoveLastDims` but returns a `Status`. /// Use if unsure is `0 <= n <= dims()`, to prevent `CHECK`-crashes. - Status RemoveLastDimsWithStatus(int64_t n) { + absl::Status RemoveLastDimsWithStatus(int64_t n) { if (TF_PREDICT_FALSE(n > dims())) { return errors::Internal("Expected dimension index to be at most ", dims(), " got ", n); @@ -280,7 +280,7 @@ class TensorShapeBase : public TensorShapeRep { /// Same as `RemoveDimRange` but returns a `Status`. /// Use if unsure if requirements in `RemoveDimRange` are satistified, to /// prevent `CHECK`-fail crashes. - Status RemoveDimRangeWithStatus(int begin, int end); + absl::Status RemoveDimRangeWithStatus(int begin, int end); /// Return whether the rank is unknown bool unknown_rank() const { @@ -324,8 +324,8 @@ class TensorShapeBase : public TensorShapeRep { explicit TensorShapeBase(DataType dt); private: - Status RecomputeNumElements(); - Status InitDims(absl::Span dim_sizes); + absl::Status RecomputeNumElements(); + absl::Status InitDims(absl::Span dim_sizes); // True for PartialTensorShape, false for TensorShape static constexpr bool kIsPartial = @@ -338,7 +338,7 @@ class TensorShapeBase : public TensorShapeRep { // For use by TensorShapeUtils::MakeShape template - friend Status MakeShapeHelper(const T*, int64_t, S*); + friend absl::Status MakeShapeHelper(const T*, int64_t, S*); }; /// Outputs `TensorShapeBase` to `std::ostream`. @@ -364,16 +364,16 @@ class TensorShape : public TensorShapeBase { // an array of sizes if calling code cannot validate that the sizes specify a // valid `TensorShape`. // The value in `*out` is valid iff the returned value is `Status::OK`. - static Status BuildTensorShape(absl::Span dim_sizes, - TensorShape* out) { + static absl::Status BuildTensorShape(absl::Span dim_sizes, + TensorShape* out) { return BuildTensorShapeBase(dim_sizes, out); } - static Status BuildTensorShape(std::initializer_list dim_sizes, - TensorShape* out) { + static absl::Status BuildTensorShape(std::initializer_list dim_sizes, + TensorShape* out) { return BuildTensorShape(absl::Span(dim_sizes), out); } - static Status BuildTensorShape(const TensorShapeProto& proto, - TensorShape* out) { + static absl::Status BuildTensorShape(const TensorShapeProto& proto, + TensorShape* out) { return BuildTensorShapeBase(proto, out); } @@ -402,7 +402,8 @@ class TensorShape : public TensorShapeBase { // not equal to `dims()`. // Caller must take ownership of `out`. template - Status AsEigenDSizesWithStatus(Eigen::DSizes* out) const; + absl::Status AsEigenDSizesWithStatus( + Eigen::DSizes* out) const; /// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in /// which case we pad the rest of the sizes with 1. @@ -416,7 +417,7 @@ class TensorShape : public TensorShapeBase { // not equal to `dims()`. // Caller must take ownership of `out`. template - Status AsEigenDSizesWithPaddingWithStatus( + absl::Status AsEigenDSizesWithPaddingWithStatus( Eigen::DSizes* out) const; private: @@ -506,18 +507,21 @@ class TensorShapeUtils { /// \brief Returns a `TensorShape` whose dimensions are /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`. - static Status MakeShape(const int32* dims, int64_t n, TensorShape* out); - static Status MakeShape(const int64_t* dims, int64_t n, TensorShape* out); - static Status MakeShape(absl::Span shape, TensorShape* out); - static Status MakeShape(absl::Span shape, TensorShape* out); - static Status MakeShape(const int32* dims, int64_t n, - PartialTensorShape* out); - static Status MakeShape(const int64_t* dims, int64_t n, - PartialTensorShape* out); - static Status MakeShape(absl::Span shape, - PartialTensorShape* out); - static Status MakeShape(absl::Span shape, - PartialTensorShape* out); + static absl::Status MakeShape(const int32* dims, int64_t n, TensorShape* out); + static absl::Status MakeShape(const int64_t* dims, int64_t n, + TensorShape* out); + static absl::Status MakeShape(absl::Span shape, + TensorShape* out); + static absl::Status MakeShape(absl::Span shape, + TensorShape* out); + static absl::Status MakeShape(const int32* dims, int64_t n, + PartialTensorShape* out); + static absl::Status MakeShape(const int64_t* dims, int64_t n, + PartialTensorShape* out); + static absl::Status MakeShape(absl::Span shape, + PartialTensorShape* out); + static absl::Status MakeShape(absl::Span shape, + PartialTensorShape* out); static std::string ShapeListString( const absl::Span& shapes); @@ -531,8 +535,8 @@ class TensorShapeUtils { /// \brief Returns the product of values in an int64 array, /// or a failing Status if the array represents a value larger than /// a `TensorShape` can hold. - static Status NumElements(absl::Span shape, - int64_t* num_elements); + static absl::Status NumElements(absl::Span shape, + int64_t* num_elements); }; /// Manages the partially known dimensions of a Tensor and their sizes. @@ -545,16 +549,16 @@ class PartialTensorShape : public TensorShapeBase { // an array of sizes if calling code cannot validate that the sizes specify a // valid `PartialTensorShape`. // The value in `*out` is valid iff the returned value is `Status::OK`. - static Status BuildPartialTensorShape(absl::Span dim_sizes, - PartialTensorShape* out) { + static absl::Status BuildPartialTensorShape( + absl::Span dim_sizes, PartialTensorShape* out) { return BuildTensorShapeBase(dim_sizes, out); } - static Status BuildPartialTensorShape( + static absl::Status BuildPartialTensorShape( std::initializer_list dim_sizes, PartialTensorShape* out) { return BuildPartialTensorShape(absl::Span(dim_sizes), out); } - static Status BuildPartialTensorShape(const TensorShapeProto& proto, - PartialTensorShape* out) { + static absl::Status BuildPartialTensorShape(const TensorShapeProto& proto, + PartialTensorShape* out) { return BuildTensorShapeBase(proto, out); } @@ -573,7 +577,8 @@ class PartialTensorShape : public TensorShapeBase { /// Similar to `Concatenate` but returning `Status`. /// Use if calling code cannot validate all requirements and if `CHECK`-fails /// are to be avoided. - Status ConcatenateWithStatus(int64_t size, PartialTensorShape* out) const; + absl::Status ConcatenateWithStatus(int64_t size, + PartialTensorShape* out) const; /// Appends all the dimensions from `shape`. Returns a new /// PartialTensorShape. @@ -582,14 +587,14 @@ class PartialTensorShape : public TensorShapeBase { /// Similar to `Concatenate` but returning `Status`. /// Use if calling code cannot validate all requirements and if `CHECK`-fails /// are to be avoided. - Status ConcatenateWithStatus(const PartialTensorShape& shape, - PartialTensorShape* out) const; + absl::Status ConcatenateWithStatus(const PartialTensorShape& shape, + PartialTensorShape* out) const; /// Merges all the dimensions from `shape`. Returns /// `InvalidArgument` error if either `shape` has a different rank /// or if any of the dimensions are incompatible. - Status MergeWith(const PartialTensorShape& shape, - PartialTensorShape* result) const; + absl::Status MergeWith(const PartialTensorShape& shape, + PartialTensorShape* result) const; /// Exact equality test. Returns true iff the ranks match (i.e., both are /// unknown, or both are known and equal), and all dimensions are equal (i.e., @@ -611,8 +616,8 @@ class PartialTensorShape : public TensorShapeBase { /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`. Values of -1 are /// considered "unknown". template - static Status MakePartialShape(const T* dims, int n, - PartialTensorShape* out) { + static absl::Status MakePartialShape(const T* dims, int n, + PartialTensorShape* out) { return TensorShapeUtils::MakeShape(dims, n, out); } }; @@ -670,7 +675,7 @@ Eigen::DSizes TensorShape::AsEigenDSizes() const { } template -Status TensorShape::AsEigenDSizesWithStatus( +absl::Status TensorShape::AsEigenDSizesWithStatus( Eigen::DSizes* out) const { if (TF_PREDICT_FALSE(NDIMS != dims())) { return errors::Internal("Asking for tensor of ", NDIMS, @@ -688,7 +693,7 @@ Eigen::DSizes TensorShape::AsEigenDSizesWithPadding() const { } template -Status TensorShape::AsEigenDSizesWithPaddingWithStatus( +absl::Status TensorShape::AsEigenDSizesWithPaddingWithStatus( Eigen::DSizes* out) const { if (TF_PREDICT_FALSE(NDIMS < dims())) { return errors::Internal("Asking for tensor of at most ", NDIMS, diff --git a/tensorflow/core/framework/tensor_shape_fuzz.cc b/tensorflow/core/framework/tensor_shape_fuzz.cc index d14284e5530c96..d5958e4a96cd8c 100644 --- a/tensorflow/core/framework/tensor_shape_fuzz.cc +++ b/tensorflow/core/framework/tensor_shape_fuzz.cc @@ -28,7 +28,7 @@ namespace { void FuzzTensorShape(const std::vector& dim_sizes) { TensorShape out; - Status status = TensorShape::BuildTensorShape(dim_sizes, &out); + absl::Status status = TensorShape::BuildTensorShape(dim_sizes, &out); if (!dim_sizes.empty() && dim_sizes.size() < 5) { const auto [min, max] = std::minmax_element(dim_sizes.begin(), dim_sizes.end()); @@ -42,7 +42,8 @@ FUZZ_TEST(TensorShapeFuzz, FuzzTensorShape); void FuzzPartialTensorShape(const std::vector& dim_sizes) { PartialTensorShape out; - Status status = PartialTensorShape::BuildPartialTensorShape(dim_sizes, &out); + absl::Status status = + PartialTensorShape::BuildPartialTensorShape(dim_sizes, &out); if (!dim_sizes.empty() && dim_sizes.size() < 5) { const auto [min, max] = std::minmax_element(dim_sizes.begin(), dim_sizes.end()); @@ -58,7 +59,7 @@ void FuzzSetDimWithStatus(TensorShape shape, int dim, int64_t value) { int initial_rank = shape.dims(); bool should_be_ok = shape.dims() == 2 && shape.dim_size(0) <= 100 && shape.dim_size(1) <= 100 && dim < 2 && value < 100; - Status status = shape.SetDimWithStatus(dim, value); + absl::Status status = shape.SetDimWithStatus(dim, value); if (status.ok()) { EXPECT_EQ(initial_rank, shape.dims()); EXPECT_EQ(value, shape.dim_size(dim)); @@ -74,7 +75,7 @@ void FuzzRemoveDimWithStatus(TensorShape shape, int dim) { auto initial_rank = shape.dims(); bool should_be_ok = shape.dims() == 2 && shape.dim_size(0) <= 100 && shape.dim_size(1) <= 100 && dim >= 0 && dim < 2; - Status status = shape.RemoveDimWithStatus(dim); + absl::Status status = shape.RemoveDimWithStatus(dim); if (status.ok()) { EXPECT_EQ(shape.dims(), initial_rank - 1); } else { diff --git a/tensorflow/core/framework/tensor_shape_test.cc b/tensorflow/core/framework/tensor_shape_test.cc index c13a16fd3c8004..57c2a34862aae4 100644 --- a/tensorflow/core/framework/tensor_shape_test.cc +++ b/tensorflow/core/framework/tensor_shape_test.cc @@ -427,7 +427,7 @@ TEST(TensorShapeTest, ostream) { TEST(TensorShapeTest, AddDimWithStatus) { TensorShape s({10, 5, 20}); - Status status = s.AddDimWithStatus(400); + absl::Status status = s.AddDimWithStatus(400); EXPECT_TRUE(status.ok()); EXPECT_EQ(400000, s.num_elements()); ASSERT_EQ(4, s.dims()); @@ -458,7 +458,7 @@ TEST(TensorShapeTest, AppendShapeWithStatus) { TEST(TensorShapeTest, Factory) { TensorShape s; - Status status = TensorShape::BuildTensorShapeBase({10, 5, 20}, &s); + absl::Status status = TensorShape::BuildTensorShapeBase({10, 5, 20}, &s); EXPECT_TRUE(status.ok()); EXPECT_EQ(1000, s.num_elements()); ASSERT_EQ(3, s.dims()); @@ -547,7 +547,7 @@ class TensorShapeOld { /// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error /// status otherwise. - static Status IsValidShape(const TensorShapeProto& proto); + static absl::Status IsValidShape(const TensorShapeProto& proto); /// Clear a tensor shape void Clear(); @@ -675,7 +675,7 @@ bool TensorShapeOld::IsValid(const TensorShapeProto& proto) { return true; } -Status TensorShapeOld::IsValidShape(const TensorShapeProto& proto) { +absl::Status TensorShapeOld::IsValidShape(const TensorShapeProto& proto) { int64_t num_elements = 1; for (const auto& d : proto.dim()) { if (d.size() < 0) { diff --git a/tensorflow/core/framework/tensor_slice.cc b/tensorflow/core/framework/tensor_slice.cc index c467590ea5e523..c64f4157c57561 100644 --- a/tensorflow/core/framework/tensor_slice.cc +++ b/tensorflow/core/framework/tensor_slice.cc @@ -47,8 +47,8 @@ TensorSlice::TensorSlice( } } -Status TensorSlice::BuildTensorSlice(const TensorSliceProto& proto, - TensorSlice* output) { +absl::Status TensorSlice::BuildTensorSlice(const TensorSliceProto& proto, + TensorSlice* output) { output->Clear(); output->starts_.reserve(proto.extent_size()); output->lengths_.reserve(proto.extent_size()); @@ -75,7 +75,7 @@ Status TensorSlice::BuildTensorSlice(const TensorSliceProto& proto, return absl::OkStatus(); } -Status TensorSlice::Parse(const string& str, TensorSlice* slice) { +absl::Status TensorSlice::Parse(const string& str, TensorSlice* slice) { std::vector items = str_util::Split(str, ':', str_util::SkipEmpty()); slice->starts_.reserve(items.size()); slice->lengths_.reserve(items.size()); @@ -267,8 +267,8 @@ int64_t TensorSlice::GetExtentLength(const TensorSliceProto::Extent& extent) { return extent.length(); } -Status TensorSlice::SliceTensorShape(const TensorShape& shape, - TensorShape* result_shape) const { +absl::Status TensorSlice::SliceTensorShape(const TensorShape& shape, + TensorShape* result_shape) const { result_shape->Clear(); // Mismatching ranks: we can't apply the slice at all. if (shape.dims() != dims()) { diff --git a/tensorflow/core/framework/tensor_slice.h b/tensorflow/core/framework/tensor_slice.h index d0fd25432728d6..4ada28d1f20109 100644 --- a/tensorflow/core/framework/tensor_slice.h +++ b/tensorflow/core/framework/tensor_slice.h @@ -51,13 +51,13 @@ class TensorSlice { // This factory methods should be used instead of the constructor that takes a // `TensorSliceProto` if calling code cannot validate that the sizes specify a // valid `TensorSlice`. - static Status BuildTensorSlice(const TensorSliceProto& proto, - TensorSlice* output); + static absl::Status BuildTensorSlice(const TensorSliceProto& proto, + TensorSlice* output); - static Status Parse(const string& str, TensorSlice* output); + static absl::Status Parse(const string& str, TensorSlice* output); static TensorSlice ParseOrDie(const string& str) { TensorSlice ret; - Status s = Parse(str, &ret); + absl::Status s = Parse(str, &ret); if (!s.ok()) { LOG(FATAL) << "Could not parse TensorSlice"; } @@ -151,8 +151,8 @@ class TensorSlice { // Requires that the shape and *this have the same rank. // For example, given a tensor shape of {3, 4, 5}, and a slice of // 1,2:-:0,2, the result shape is {2, 4, 2}. - Status SliceTensorShape(const TensorShape& shape, - TensorShape* result_shape) const; + absl::Status SliceTensorShape(const TensorShape& shape, + TensorShape* result_shape) const; // Given slice "sub" where "sub" is fully contained in *this, // (meaning that the intersection of "sub" and *this equals "sub"), computes diff --git a/tensorflow/core/framework/tensor_slice_test.cc b/tensorflow/core/framework/tensor_slice_test.cc index 8e6ce1013e8e40..1818c0b3f27c3c 100644 --- a/tensorflow/core/framework/tensor_slice_test.cc +++ b/tensorflow/core/framework/tensor_slice_test.cc @@ -87,7 +87,7 @@ TEST(TensorSliceTest, Serialization) { // Failed parsing { TensorSlice slice; - Status s = TensorSlice::Parse("-:-:1,3:4:5", &slice); + absl::Status s = TensorSlice::Parse("-:-:1,3:4:5", &slice); EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); EXPECT_TRUE( absl::StrContains(s.message(), @@ -96,7 +96,7 @@ TEST(TensorSliceTest, Serialization) { } { TensorSlice slice; - Status s = TensorSlice::Parse("-:-1,3", &slice); + absl::Status s = TensorSlice::Parse("-:-1,3", &slice); EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); EXPECT_TRUE(absl::StrContains( s.message(), @@ -122,7 +122,7 @@ TEST(TensorSliceTest, Serialization) { // int64 parsing failure { TensorSlice slice; - Status s = + absl::Status s = TensorSlice::Parse("19223372036854775808,19223372036854775808", &slice); EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); EXPECT_TRUE(absl::StrContains( @@ -246,7 +246,7 @@ TEST(TensorSliceTest, SliceTensorShape) { TensorSlice a = TensorSlice::ParseOrDie("1,1:1,4:-:-"); TensorShape x({2, 4, 5, 8}); TensorShape y; - Status s = a.SliceTensorShape(x, &y); + absl::Status s = a.SliceTensorShape(x, &y); EXPECT_EQ(s.code(), error::INTERNAL); EXPECT_TRUE(absl::StrContains(s.message(), "Extent in dimension 1 out of bounds: " diff --git a/tensorflow/core/framework/tensor_util.cc b/tensorflow/core/framework/tensor_util.cc index 33b38f6032e67a..f9131de632827a 100644 --- a/tensorflow/core/framework/tensor_util.cc +++ b/tensorflow/core/framework/tensor_util.cc @@ -56,7 +56,7 @@ void DeepCopy(const Tensor& input, Tensor* output) { } } -Status Concat(const absl::Span tensors, Tensor* result) { +absl::Status Concat(const absl::Span tensors, Tensor* result) { if (tensors.empty()) { return errors::InvalidArgument("Cannot concatenate zero tensors"); } @@ -119,8 +119,8 @@ Status Concat(const absl::Span tensors, Tensor* result) { return absl::OkStatus(); } -Status Split(const Tensor& tensor, const absl::Span sizes, - std::vector* result) { +absl::Status Split(const Tensor& tensor, const absl::Span sizes, + std::vector* result) { if (tensor.dims() == 0) { return errors::InvalidArgument("Cannot split a zero-dimensional tensor"); } @@ -425,7 +425,7 @@ bool CompressTensorProtoInPlace(int64_t min_num_elements, #undef HANDLE_COMPRESS_CASE -Status MakeShape(const Tensor& shape, TensorShape* out) { +absl::Status MakeShape(const Tensor& shape, TensorShape* out) { if (!TensorShapeUtils::IsVector(shape.shape())) { return errors::InvalidArgument( "shape must be a vector of {int32,int64}, got shape ", diff --git a/tensorflow/core/framework/tensor_util.h b/tensorflow/core/framework/tensor_util.h index 6112db80bf4144..ee607ff5b8d5be 100644 --- a/tensorflow/core/framework/tensor_util.h +++ b/tensorflow/core/framework/tensor_util.h @@ -49,8 +49,8 @@ void DeepCopy(const Tensor& input, Tensor* output); // REQUIRES: Each member of 'tensors' must point to data stored in CPU memory. // REQUIRES: Each member of 'tensors' must be a Tensor of a copy-able type if it // is not appropriately memory-aligned. -Status Concat(absl::Span tensors, - Tensor* result) TF_MUST_USE_RESULT; +absl::Status Concat(absl::Span tensors, + Tensor* result) TF_MUST_USE_RESULT; // Splits 'tensor' into 'sizes.size()' individual tensors, along the 0th // dimension. The ith output tensor has 0th-dimension size 'sizes[i]'. @@ -62,8 +62,8 @@ Status Concat(absl::Span tensors, // appropriately memory-aligned. // // Split() and Concat() are inverse operations. -Status Split(const Tensor& tensor, absl::Span sizes, - std::vector* result) TF_MUST_USE_RESULT; +absl::Status Split(const Tensor& tensor, absl::Span sizes, + std::vector* result) TF_MUST_USE_RESULT; namespace internal { void SetTensorProtoShape(absl::Span shape, @@ -351,7 +351,7 @@ inline bool CompressTensorProtoInPlace(TensorProto* tensor) { // Make a TensorShape from the contents of shape_t. Shape_t must be a // 1-dimensional tensor of type int32 or int64. -Status MakeShape(const Tensor& shape_t, TensorShape* out); +absl::Status MakeShape(const Tensor& shape_t, TensorShape* out); } // namespace tensor } // namespace tensorflow diff --git a/tensorflow/core/framework/variant_op_copy_test.cc b/tensorflow/core/framework/variant_op_copy_test.cc index 5a836deeccabbd..9cab23446f40b0 100644 --- a/tensorflow/core/framework/variant_op_copy_test.cc +++ b/tensorflow/core/framework/variant_op_copy_test.cc @@ -66,21 +66,21 @@ struct StoredTensorValue { stored = data.tensors_[0]; return true; } - static Status CopyCPUToGPU( + static absl::Status CopyCPUToGPU( const StoredTensorValue& from, StoredTensorValue* to, - const std::function& copy) { + const std::function& copy) { ++*GetCopyCPUToGPUCounter(); return copy(from.stored, &(to->stored)); } - static Status CopyGPUToCPU( + static absl::Status CopyGPUToCPU( const StoredTensorValue& from, StoredTensorValue* to, - const std::function& copy) { + const std::function& copy) { ++*GetCopyGPUToCPUCounter(); return copy(from.stored, &(to->stored)); } - static Status CopyGPUToGPU( + static absl::Status CopyGPUToGPU( const StoredTensorValue& from, StoredTensorValue* to, - const std::function& copy) { + const std::function& copy) { ++*GetCopyGPUToGPUCounter(); return copy(from.stored, &(to->stored)); } @@ -259,7 +259,7 @@ TEST(VariantOpCopyTest, CreateConstOnGPUFailsGracefully) { TF_ASSERT_OK(root.status()); ClientSession session(root); std::vector outputs; - Status s = session.Run({create_const}, &outputs); + absl::Status s = session.Run({create_const}, &outputs); EXPECT_TRUE( absl::StrContains(s.message(), "GPU copy from non-DMA string tensor")) << s.ToString(); @@ -364,7 +364,7 @@ TEST(VariantOpCopyTest, CreateCopyCPUToGPUStringFailsSafely) { ClientSession session(root); std::vector outputs; - Status err = session.Run({create_op, identity}, &outputs); + absl::Status err = session.Run({create_op, identity}, &outputs); EXPECT_TRUE(errors::IsInvalidArgument(err)); EXPECT_TRUE( absl::StrContains(err.message(), diff --git a/tensorflow/core/framework/variant_op_registry.cc b/tensorflow/core/framework/variant_op_registry.cc index 0261feed3f22f6..306f6a6fec743d 100644 --- a/tensorflow/core/framework/variant_op_registry.cc +++ b/tensorflow/core/framework/variant_op_registry.cc @@ -125,7 +125,7 @@ REGISTER_VARIANT_DECODE_TYPE(double); #undef REGISTER_VARIANT_DECODE_TYPE -Status VariantDeviceCopy( +absl::Status VariantDeviceCopy( const VariantDeviceCopyDirection direction, const Variant& from, Variant* to, const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn) { @@ -143,7 +143,7 @@ Status VariantDeviceCopy( namespace { template -Status DeviceCopyPrimitiveType( +absl::Status DeviceCopyPrimitiveType( const T& in, T* out, const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier) { // Dummy copy, we don't actually bother copying to the device and back for @@ -174,8 +174,8 @@ REGISTER_VARIANT_DEVICE_COPY_TYPE(bool); namespace { template -Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t, - T* t_out) { +absl::Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t, + T* t_out) { *t_out = T(0); return absl::OkStatus(); } @@ -196,8 +196,8 @@ REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool); namespace { template -Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b, - T* out) { +absl::Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, + const T& b, T* out) { *out = a + b; return absl::OkStatus(); } diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h index 0584fb7c573244..f75177f712be74 100644 --- a/tensorflow/core/framework/variant_op_registry.h +++ b/tensorflow/core/framework/variant_op_registry.h @@ -66,10 +66,11 @@ extern UnaryVariantOpRegistry* UnaryVariantOpRegistryGlobal(); class UnaryVariantOpRegistry { public: typedef std::function VariantDecodeFn; - typedef std::function + typedef std::function VariantUnaryOpFn; - typedef std::function + typedef std::function VariantBinaryOpFn; // An AsyncTensorDeviceCopyFn is a function provided to @@ -89,14 +90,14 @@ class UnaryVariantOpRegistry { // Any failure of the copy itself will update the underlying // stream status and propagate through the runtime independent // of the caller. - typedef std::function + typedef std::function AsyncTensorDeviceCopyFn; // The AsyncVariantDeviceCopyFn is the signature of the 'device_copy_fn' // expected to be passed to the registration macro // INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION. - typedef std::function + typedef std::function AsyncVariantDeviceCopyFn; // Add a decode function to the registry. @@ -296,7 +297,7 @@ bool DecodeUnaryVariant(Variant* variant); // REQUIRES: // 'to' is not null. // -Status VariantDeviceCopy( +absl::Status VariantDeviceCopy( const VariantDeviceCopyDirection direction, const Variant& from, Variant* to, const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn); @@ -310,8 +311,8 @@ Status VariantDeviceCopy( // v_out is not null. // template -Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v, - Variant* v_out) { +absl::Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, + const Variant& v, Variant* v_out) { const std::string& device = DeviceName::value; UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn = UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId()); @@ -334,8 +335,9 @@ Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v, // out is not null. // template -Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op, - const Variant& a, const Variant& b, Variant* out) { +absl::Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op, + const Variant& a, const Variant& b, + Variant* out) { if (a.TypeId() != b.TypeId()) { return errors::Internal( "BinaryOpVariants: Variants a and b have different " @@ -385,8 +387,8 @@ class UnaryVariantDecodeRegistration { template class UnaryVariantDeviceCopyRegistration { public: - typedef std::function + typedef std::function LocalVariantDeviceCopyFn; UnaryVariantDeviceCopyRegistration( const VariantDeviceCopyDirection direction, const TypeIndex& type_index, @@ -398,7 +400,7 @@ class UnaryVariantDeviceCopyRegistration { [type_index_name, device_copy_fn]( const Variant& from, Variant* to, UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn - device_copy_tensor_fn) -> Status { + device_copy_tensor_fn) -> absl::Status { DCHECK_NE(to, nullptr); *to = T(); if (from.get() == nullptr) { @@ -415,7 +417,8 @@ class UnaryVariantDeviceCopyRegistration { template class UnaryVariantUnaryOpRegistration { - typedef std::function + typedef std::function LocalVariantUnaryOpFn; public: @@ -427,7 +430,7 @@ class UnaryVariantUnaryOpRegistration { UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn( op, device, type_index, [type_index_name, unary_op_fn](OpKernelContext* ctx, const Variant& v, - Variant* v_out) -> Status { + Variant* v_out) -> absl::Status { DCHECK_NE(v_out, nullptr); *v_out = T(); if (v.get() == nullptr) { @@ -444,8 +447,8 @@ class UnaryVariantUnaryOpRegistration { template class UnaryVariantBinaryOpRegistration { - typedef std::function + typedef std::function LocalVariantBinaryOpFn; public: @@ -459,7 +462,7 @@ class UnaryVariantBinaryOpRegistration { op, device, type_index, [type_index_name, binary_op_fn](OpKernelContext* ctx, const Variant& a, const Variant& b, - Variant* out) -> Status { + Variant* out) -> absl::Status { DCHECK_NE(out, nullptr); *out = T(); if (a.get() == nullptr) { diff --git a/tensorflow/core/framework/variant_op_registry_test.cc b/tensorflow/core/framework/variant_op_registry_test.cc index 2ef7be5cc7fa27..594e9a6682ddfa 100644 --- a/tensorflow/core/framework/variant_op_registry_test.cc +++ b/tensorflow/core/framework/variant_op_registry_test.cc @@ -39,41 +39,43 @@ namespace { struct VariantValue { string TypeName() const { return "TEST VariantValue"; } - static Status CPUZerosLikeFn(OpKernelContext* ctx, const VariantValue& v, - VariantValue* v_out) { + static absl::Status CPUZerosLikeFn(OpKernelContext* ctx, + const VariantValue& v, + VariantValue* v_out) { if (v.early_exit) { return errors::InvalidArgument("early exit zeros_like!"); } v_out->value = 1; // CPU return absl::OkStatus(); } - static Status GPUZerosLikeFn(OpKernelContext* ctx, const VariantValue& v, - VariantValue* v_out) { + static absl::Status GPUZerosLikeFn(OpKernelContext* ctx, + const VariantValue& v, + VariantValue* v_out) { if (v.early_exit) { return errors::InvalidArgument("early exit zeros_like!"); } v_out->value = 2; // GPU return absl::OkStatus(); } - static Status CPUAddFn(OpKernelContext* ctx, const VariantValue& a, - const VariantValue& b, VariantValue* out) { + static absl::Status CPUAddFn(OpKernelContext* ctx, const VariantValue& a, + const VariantValue& b, VariantValue* out) { if (a.early_exit) { return errors::InvalidArgument("early exit add!"); } out->value = a.value + b.value; // CPU return absl::OkStatus(); } - static Status GPUAddFn(OpKernelContext* ctx, const VariantValue& a, - const VariantValue& b, VariantValue* out) { + static absl::Status GPUAddFn(OpKernelContext* ctx, const VariantValue& a, + const VariantValue& b, VariantValue* out) { if (a.early_exit) { return errors::InvalidArgument("early exit add!"); } out->value = -(a.value + b.value); // GPU return absl::OkStatus(); } - static Status CPUToGPUCopyFn( + static absl::Status CPUToGPUCopyFn( const VariantValue& from, VariantValue* to, - const std::function& copier) { + const std::function& copier) { TF_RETURN_IF_ERROR(copier(Tensor(), nullptr)); to->value = 0xdeadbeef; return absl::OkStatus(); @@ -168,7 +170,7 @@ TEST(VariantOpCopyToGPURegistryTest, TestBasic) { Variant v_out; bool dummy_executed = false; auto dummy_copy_fn = [&dummy_executed](const Tensor& from, - Tensor* to) -> Status { + Tensor* to) -> absl::Status { dummy_executed = true; return absl::OkStatus(); }; @@ -203,8 +205,8 @@ TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) { Variant v_out = VariantValue(); OpKernelContext* null_context_pointer = nullptr; - Status s0 = UnaryOpVariant(null_context_pointer, - ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out); + absl::Status s0 = UnaryOpVariant( + null_context_pointer, ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out); EXPECT_FALSE(s0.ok()); EXPECT_TRUE(absl::StrContains(s0.message(), "early exit zeros_like")); @@ -275,7 +277,7 @@ TEST(VariantOpAddRegistryTest, TestBasicCPU) { Variant v_out = VariantValue(); OpKernelContext* null_context_pointer = nullptr; - Status s0 = BinaryOpVariants( + absl::Status s0 = BinaryOpVariants( null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out); EXPECT_FALSE(s0.ok()); EXPECT_TRUE(absl::StrContains(s0.message(), "early exit add")); diff --git a/tensorflow/core/framework/versions.cc b/tensorflow/core/framework/versions.cc index 7291f572b32055..38ed437ff461c3 100644 --- a/tensorflow/core/framework/versions.cc +++ b/tensorflow/core/framework/versions.cc @@ -20,8 +20,9 @@ limitations under the License. namespace tensorflow { -Status CheckVersions(const VersionDef& versions, int consumer, int min_producer, - const char* upper_name, const char* lower_name) { +absl::Status CheckVersions(const VersionDef& versions, int consumer, + int min_producer, const char* upper_name, + const char* lower_name) { // Guard against the caller misordering the arguments if (consumer < min_producer) { return errors::Internal(upper_name, " version check has consumer ", diff --git a/tensorflow/core/framework/versions.h b/tensorflow/core/framework/versions.h index 2e043fedeffeed..a63ff7035698bb 100644 --- a/tensorflow/core/framework/versions.h +++ b/tensorflow/core/framework/versions.h @@ -31,8 +31,9 @@ class VersionDef; // TF_RETURN_IF_ERROR(CheckVersions(versions, TF_GRAPH_DEF_VERSION, // TF_GRAPH_DEF_VERSION_MIN_PRODUCER, // "GraphDef", "graph")); -Status CheckVersions(const VersionDef& versions, int consumer, int min_producer, - const char* upper_name, const char* lower_name); +absl::Status CheckVersions(const VersionDef& versions, int consumer, + int min_producer, const char* upper_name, + const char* lower_name); } // namespace tensorflow diff --git a/tensorflow/core/function/polymorphism/BUILD b/tensorflow/core/function/polymorphism/BUILD index 0a333bb9d73ffd..3289406a109e26 100644 --- a/tensorflow/core/function/polymorphism/BUILD +++ b/tensorflow/core/function/polymorphism/BUILD @@ -100,7 +100,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "function_type_py_pb2", -# api_version = 2, # visibility = ["//visibility:private"], # deps = [":function_type_proto"], # ) diff --git a/tensorflow/core/function/runtime_client/BUILD b/tensorflow/core/function/runtime_client/BUILD index 970bfd92d7accd..e2576d14ebc18e 100644 --- a/tensorflow/core/function/runtime_client/BUILD +++ b/tensorflow/core/function/runtime_client/BUILD @@ -125,14 +125,14 @@ tf_python_pybind_extension( "//tensorflow/core/protobuf:eager_service_proto_cc", "//tensorflow/core/protobuf:master_proto_cc", "//tensorflow/core/protobuf:worker_proto_cc", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc", ], otherwise = [ "//tensorflow/core/framework:function_proto_cc_headers_only", "//tensorflow/core/protobuf:eager_service_proto_cc_headers_only", "//tensorflow/core/protobuf:master_proto_cc_headers_only", "//tensorflow/core/protobuf:worker_proto_cc_headers_only", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc_headers_only", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc_headers_only", ], ), ) diff --git a/tensorflow/core/function/runtime_client/runtime_client.cc b/tensorflow/core/function/runtime_client/runtime_client.cc index f146e310bb5952..126873c92b24e7 100644 --- a/tensorflow/core/function/runtime_client/runtime_client.cc +++ b/tensorflow/core/function/runtime_client/runtime_client.cc @@ -66,7 +66,7 @@ EagerContext& GlobalEagerContext() { static EagerContext* global_ctx = []() { SessionOptions opts; std::vector> devices; - Status&& device_init_status = DeviceFactory::AddDevices( + absl::Status&& device_init_status = DeviceFactory::AddDevices( opts, "/job:localhost/replica:0/task:0", &devices); CHECK(device_init_status.ok()); // Crash OK @@ -94,14 +94,15 @@ absl::StatusOr Runtime::GetFunctionProto(StringPiece name) { const FunctionDef* f = ctx.FindFunctionDef(std::string(name)); if (f == nullptr) { - return Status(absl::StatusCode::kInvalidArgument, - absl::StrCat("Could not find an attribute for key ", name)); + return absl::Status( + absl::StatusCode::kInvalidArgument, + absl::StrCat("Could not find an attribute for key ", name)); } return *f; } -Status Runtime::CreateFunction(const FunctionDef& fdef) { +absl::Status Runtime::CreateFunction(const FunctionDef& fdef) { const auto& fname = fdef.signature().name(); if (this->eager_ctx_.FindFunctionByName(fname)) { TF_RETURN_WITH_CONTEXT_IF_ERROR(this->eager_ctx_.RemoveFunction(fname), @@ -110,14 +111,14 @@ Status Runtime::CreateFunction(const FunctionDef& fdef) { return this->eager_ctx_.AddFunctionDef(fdef); } -Status Runtime::CreateFunction(OpaqueTfgGraphFuncOp* fop) { +absl::Status Runtime::CreateFunction(OpaqueTfgGraphFuncOp* fop) { mlir::tfg::GraphFuncOp fop_proper = *reinterpret_cast(fop); return mlir::tfg::ConvertToFunctionDef(fop_proper, *this->eager_ctx_.FuncLibDef()); } -Status Runtime::CreateFunction(OpaqueTfFuncOp* fop) { +absl::Status Runtime::CreateFunction(OpaqueTfFuncOp* fop) { mlir::func::FuncOp fop_proper = *reinterpret_cast(fop); const auto& fname = fop_proper.getName().str(); GraphExportConfig config; @@ -129,8 +130,9 @@ Status Runtime::CreateFunction(OpaqueTfFuncOp* fop) { return CreateFunction(fdef); } -Status Runtime::TransformFunction(StringPiece name, StringPiece pipeline_name, - Dialect dialect) { +absl::Status Runtime::TransformFunction(StringPiece name, + StringPiece pipeline_name, + Dialect dialect) { // TODO(mdan): Use a longer-lived context. mlir::MLIRContext ctx; mlir::PassManager pm(&ctx); @@ -140,9 +142,9 @@ Status Runtime::TransformFunction(StringPiece name, StringPiece pipeline_name, // StringPiece doesn't seem to always be compatible with StringRef. if (mlir::failed(mlir::parsePassPipeline(std::string(pipeline_name), pm, error_stream))) { - return Status(absl::StatusCode::kInvalidArgument, - absl::StrCat("locating pass pipeline ", pipeline_name, ": ", - error_stream.str())); + return absl::Status(absl::StatusCode::kInvalidArgument, + absl::StrCat("locating pass pipeline ", pipeline_name, + ": ", error_stream.str())); } // For now, we roundtrip from proto. Once we have a permanent MLIR @@ -161,9 +163,9 @@ Status Runtime::TransformFunction(StringPiece name, StringPiece pipeline_name, mlir::StatusScopedDiagnosticHandler diagnostics_handler(&ctx); if (failed(pm.run(mlir_fn->get()))) { - return diagnostics_handler.Combine( - Status(absl::StatusCode::kInvalidArgument, - absl::StrCat("running pass pipeline ", pipeline_name, ": "))); + return diagnostics_handler.Combine(absl::Status( + absl::StatusCode::kInvalidArgument, + absl::StrCat("running pass pipeline ", pipeline_name, ": "))); } for (auto fn : mlir_fn->get().getBody()->getOps()) { @@ -175,7 +177,7 @@ Status Runtime::TransformFunction(StringPiece name, StringPiece pipeline_name, } if (dialect == Dialect::TF) { - Status status; + absl::Status status; FunctionLibraryDefinition& flib_def = *this->eager_ctx_.FuncLibDef(); std::unique_ptr fbody; status = FunctionDefToBodyHelper(*fn, AttrSlice(), &flib_def, &fbody); @@ -187,9 +189,9 @@ Status Runtime::TransformFunction(StringPiece name, StringPiece pipeline_name, mlir::StatusScopedDiagnosticHandler diagnostics_handler(&ctx); if (failed(pm.run(mlir_fn->get()))) { - return diagnostics_handler.Combine( - Status(absl::StatusCode::kInvalidArgument, - absl::StrCat("running pass pipeline ", pipeline_name, ": "))); + return diagnostics_handler.Combine(absl::Status( + absl::StatusCode::kInvalidArgument, + absl::StrCat("running pass pipeline ", pipeline_name, ": "))); } for (auto fn : mlir_fn->get().getBody()->getOps()) { @@ -200,7 +202,7 @@ Status Runtime::TransformFunction(StringPiece name, StringPiece pipeline_name, return absl::OkStatus(); } - return Status( + return absl::Status( absl::StatusCode::kInvalidArgument, absl::StrCat("Unsupported dialect: ", dialect, ". Supported dialects are Dialect::TFG and Dialect::TF.")); diff --git a/tensorflow/core/function/runtime_client/runtime_client.h b/tensorflow/core/function/runtime_client/runtime_client.h index e2cffdf4d74796..d26c09b3a9db3b 100644 --- a/tensorflow/core/function/runtime_client/runtime_client.h +++ b/tensorflow/core/function/runtime_client/runtime_client.h @@ -73,17 +73,17 @@ class Runtime { absl::StatusOr GetFunctionProto(StringPiece name); // TODO(mdan): Enforce creation or rename to SetFunction. - Status CreateFunction(const FunctionDef& fdef); + absl::Status CreateFunction(const FunctionDef& fdef); // TODO(mdan): Change to mlir::tfg::GraphFuncOp once pybind can depend on it. - Status CreateFunction(OpaqueTfgGraphFuncOp* fop); + absl::Status CreateFunction(OpaqueTfgGraphFuncOp* fop); // TODO(xjun): Change to mlir::func::FuncOp once pybind can depend on it. - Status CreateFunction(OpaqueTfFuncOp* fop); + absl::Status CreateFunction(OpaqueTfFuncOp* fop); // Applies a MLIR pipeline to an existing function. // The pipeline may rename the function. If it does so, the old function // remains unchanged. If the new name specifies an existing function, it will // be overwritten. - Status TransformFunction(StringPiece name, StringPiece pipeline_name, - Dialect dialect = Dialect::TFG); + absl::Status TransformFunction(StringPiece name, StringPiece pipeline_name, + Dialect dialect = Dialect::TFG); absl::StatusOr CallFunction( StringPiece name, absl::Span args); diff --git a/tensorflow/core/function/runtime_client/runtime_client_test.cc b/tensorflow/core/function/runtime_client/runtime_client_test.cc index 91effd88b7f701..487e1ee62c233d 100644 --- a/tensorflow/core/function/runtime_client/runtime_client_test.cc +++ b/tensorflow/core/function/runtime_client/runtime_client_test.cc @@ -46,7 +46,7 @@ namespace { EagerContextPtr TestingEagerCtx() { SessionOptions opts; std::vector> devices; - Status&& device_init_status = DeviceFactory::AddDevices( + absl::Status&& device_init_status = DeviceFactory::AddDevices( opts, "/job:localhost/replica:0/task:0", &devices); CHECK(device_init_status.ok()); // Crash OK @@ -62,7 +62,7 @@ EagerContextPtr TestingEagerCtx() { } int IntValue(ImmediateExecutionTensorHandle& h) { - Status status; + absl::Status status; AbstractTensorPtr t(h.Resolve(&status)); DCHECK(status.ok()); switch (h.DataType()) { diff --git a/tensorflow/core/function/trace_type/BUILD b/tensorflow/core/function/trace_type/BUILD index 788f4278a187b6..cf8502fbe77d1c 100644 --- a/tensorflow/core/function/trace_type/BUILD +++ b/tensorflow/core/function/trace_type/BUILD @@ -186,21 +186,18 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "serialization_py_pb2", -# api_version = 2, # visibility = ["//tensorflow:internal"], # deps = [":serialization_proto"], # ) # # py_proto_library( # name = "serialization_test_py_pb2", -# api_version = 2, # visibility = ["//tensorflow:internal"], # deps = [":serialization_test_proto"], # ) # # py_proto_library( # name = "default_types_py_pb2", -# api_version = 2, # visibility = ["//tensorflow:internal"], # deps = [":default_types_proto"], # ) diff --git a/tensorflow/core/graph/collective_order.cc b/tensorflow/core/graph/collective_order.cc index 7b17d07480c6dc..9f8a498d88b47e 100644 --- a/tensorflow/core/graph/collective_order.cc +++ b/tensorflow/core/graph/collective_order.cc @@ -23,11 +23,11 @@ namespace { // Find all CollectiveReduce nodes and the existing data dependencies between // them. -Status DiscoverDataDependencies( +absl::Status DiscoverDataDependencies( const Graph* graph, std::vector* collective_nodes, std::vector* instance_keys, absl::flat_hash_map>* data_dependencies) { - Status s; + absl::Status s; // Algorithm: do Reverse DFS starting at sink. `node_leave` is called when // all parents of `node` have been visited. At that point, // `data_dependencies[node]` is a list containing `instance_key` of every @@ -40,7 +40,7 @@ Status DiscoverDataDependencies( bool enter_node = node->IsCollective() && node->type_string() == "CollectiveReduce"; if (enter_node) { - Status get_attr_status = + absl::Status get_attr_status = GetNodeAttr(node->attrs(), "instance_key", &instance_key); s.Update(get_attr_status); collective_nodes->push_back(node); @@ -67,7 +67,7 @@ Status DiscoverDataDependencies( // collective nodes, create control dependencies between concurrent collectives // and store in `dependency_edges`. // If there exists an edge a -> b then `dependency_edges[a]` contains `b` -Status CreateControlDependencies( +absl::Status CreateControlDependencies( const std::vector& collective_nodes, const std::vector& instance_keys, absl::flat_hash_map>* data_dependencies, @@ -144,7 +144,7 @@ Status CreateControlDependencies( // Insert control dependencies defined by `dependency_edges` in `graph`. If // `order_type` is `kEdges`, insert explicit control edges, else if `order_type` // is `kAttrs`, encode dependencies as an attribute on collective node. -Status InsertControlDependencies( +absl::Status InsertControlDependencies( Graph* graph, GraphCollectiveOrder order_type, const absl::flat_hash_map>& dependency_edges) { @@ -181,7 +181,7 @@ Status InsertControlDependencies( } // namespace -Status OrderCollectives(Graph* graph, GraphCollectiveOrder order_type) { +absl::Status OrderCollectives(Graph* graph, GraphCollectiveOrder order_type) { // `instance_keys[i]` corresponds to `collective_nodes[i]` std::vector collective_nodes; std::vector instance_keys; diff --git a/tensorflow/core/graph/collective_order.h b/tensorflow/core/graph/collective_order.h index 67a1427a96635f..c62017bbed6344 100644 --- a/tensorflow/core/graph/collective_order.h +++ b/tensorflow/core/graph/collective_order.h @@ -29,7 +29,7 @@ enum class GraphCollectiveOrder { kNone, kEdges, kAttrs }; // control edges between collective graph nodes. If `order_type` is `kAttrs`, // add an attribute to the node which may be used by collective executor to // ensure the required ordering. -Status OrderCollectives(Graph* graph, GraphCollectiveOrder order_type); +absl::Status OrderCollectives(Graph* graph, GraphCollectiveOrder order_type); } // namespace tensorflow diff --git a/tensorflow/core/graph/collective_order_test.cc b/tensorflow/core/graph/collective_order_test.cc index 0f5c424c29e092..46333535cbbaad 100644 --- a/tensorflow/core/graph/collective_order_test.cc +++ b/tensorflow/core/graph/collective_order_test.cc @@ -129,7 +129,7 @@ std::unique_ptr InitGraph() { CollectiveReduceNode(&builder, id1, "c3_1", dev1, 3); std::unique_ptr graph = absl::make_unique(OpRegistry::Global()); - Status s = GraphDefBuilderToGraph(builder, graph.get()); + absl::Status s = GraphDefBuilderToGraph(builder, graph.get()); if (!s.ok()) { LOG(FATAL) << "Error building graph " << s; } @@ -177,7 +177,7 @@ std::unique_ptr InitGraph2() { CollectiveReduceNode(&builder, id, "c3", dev0, 3); std::unique_ptr graph = absl::make_unique(OpRegistry::Global()); - Status s = GraphDefBuilderToGraph(builder, graph.get()); + absl::Status s = GraphDefBuilderToGraph(builder, graph.get()); if (!s.ok()) { LOG(FATAL) << "Error building graph " << s; } @@ -216,7 +216,7 @@ std::unique_ptr InitGraphForPruning() { CollectiveReduceNode(&builder, z, "c4", dev0, 4); std::unique_ptr graph = absl::make_unique(OpRegistry::Global()); - Status s = GraphDefBuilderToGraph(builder, graph.get()); + absl::Status s = GraphDefBuilderToGraph(builder, graph.get()); if (!s.ok()) { LOG(FATAL) << "Error building graph " << s; } diff --git a/tensorflow/core/graph/control_flow.cc b/tensorflow/core/graph/control_flow.cc index 581bd8d730561a..4cd9316a4607e3 100644 --- a/tensorflow/core/graph/control_flow.cc +++ b/tensorflow/core/graph/control_flow.cc @@ -38,8 +38,8 @@ struct Frame { }; // Verify that the ControlFlowInfo of the graph has valid loop structure. -Status ValidateControlFlowInfo(const Graph* graph, - const std::vector& cf_info) { +absl::Status ValidateControlFlowInfo( + const Graph* graph, const std::vector& cf_info) { std::unordered_map frames; for (const Node* node : graph->op_nodes()) { const ControlFlowInfo& cf = cf_info[node->id()]; @@ -83,8 +83,9 @@ Status ValidateControlFlowInfo(const Graph* graph, } } // namespace -Status BuildControlFlowInfo(const Graph* g, std::vector* info, - std::vector* unreachable_nodes) { +absl::Status BuildControlFlowInfo(const Graph* g, + std::vector* info, + std::vector* unreachable_nodes) { info->clear(); info->resize(g->num_node_ids()); diff --git a/tensorflow/core/graph/control_flow.h b/tensorflow/core/graph/control_flow.h index cbef1c24cb768f..c1e2db339122df 100644 --- a/tensorflow/core/graph/control_flow.h +++ b/tensorflow/core/graph/control_flow.h @@ -52,8 +52,9 @@ struct ControlFlowInfo { // NOTE(yuanbyu): For now, we require all sends/recvs have iteration level 0. // This essentially means there can't be multiple serial Nexts in an iteration, // which all sane front-ends should satisfy. -Status BuildControlFlowInfo(const Graph* g, std::vector* info, - std::vector* unreachable_nodes = nullptr); +absl::Status BuildControlFlowInfo( + const Graph* g, std::vector* info, + std::vector* unreachable_nodes = nullptr); } // namespace tensorflow diff --git a/tensorflow/core/graph/control_flow_test.cc b/tensorflow/core/graph/control_flow_test.cc index 1594e085f12571..80709fbd866717 100644 --- a/tensorflow/core/graph/control_flow_test.cc +++ b/tensorflow/core/graph/control_flow_test.cc @@ -26,20 +26,22 @@ limitations under the License. namespace tensorflow { namespace { -Status LessThanTenCond(const Scope& scope, const std::vector& inputs, - Output* output) { +absl::Status LessThanTenCond(const Scope& scope, + const std::vector& inputs, + Output* output) { *output = ops::Less(scope, inputs[0], 10); return scope.status(); } -Status AddOneBody(const Scope& scope, const std::vector& inputs, - std::vector* outputs) { +absl::Status AddOneBody(const Scope& scope, const std::vector& inputs, + std::vector* outputs) { outputs->push_back(ops::AddN(scope, {inputs[0], 1})); return scope.status(); } -Status NestedLoopBody(const Scope& scope, const std::vector& inputs, - std::vector* outputs) { +absl::Status NestedLoopBody(const Scope& scope, + const std::vector& inputs, + std::vector* outputs) { return ops::BuildWhileLoop(scope.NewSubScope("inner"), inputs, LessThanTenCond, AddOneBody, "inner_loop", outputs); @@ -58,7 +60,7 @@ TEST(ValidateControlFlowTest, InputsFromDifferentFrames) { // {inner/Enter', 'outer/Switch'} --> 'inner/Merge'. 'inner/Enter' is in frame // 'inner_loop'. 'outer/Switch' is in frame 'outer_loop'. std::vector info; - Status status = BuildControlFlowInfo(graph.get(), &info); + absl::Status status = BuildControlFlowInfo(graph.get(), &info); EXPECT_FALSE(status.ok()); EXPECT_TRUE( absl::StrContains(status.message(), "has inputs from different frames")) @@ -96,7 +98,7 @@ TEST(ValidateControlFlowTest, MismatchedParentFrames) { (*enter.mutable_attr())["T"].set_type(DT_INT32); (*enter.mutable_attr())["frame_name"].set_s("test_loop"); *enter.add_input() = "Enter"; - Status status; + absl::Status status; Node* enter_2 = graph->AddNode(enter, &status); TF_ASSERT_OK(status); graph->AddControlEdge(enter_1, enter_2); @@ -129,7 +131,7 @@ TEST(ValidateControlFlowTest, TwoLoopCond) { std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(scope.ToGraph(graph.get())); std::vector info; - Status status = BuildControlFlowInfo(graph.get(), &info); + absl::Status status = BuildControlFlowInfo(graph.get(), &info); EXPECT_FALSE(status.ok()); EXPECT_TRUE( absl::StrContains(status.message(), "more than one LoopCond node")) diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index b5408c37cda47a..a06187cdfeb8e5 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -162,7 +162,7 @@ void Node::Clear() { void Node::UpdateProperties() { DataTypeVector inputs; DataTypeVector outputs; - Status status = + absl::Status status = InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs); if (!status.ok()) { LOG(ERROR) << "Failed at updating node: " << status; @@ -188,9 +188,9 @@ void Node::ClearTypeInfo() { } } -Status Node::ShrinkTypeInfo(const absl::flat_hash_map& index_mapping, - const string& type_attr_name, - bool update_full_type) { +absl::Status Node::ShrinkTypeInfo( + const absl::flat_hash_map& index_mapping, + const string& type_attr_name, bool update_full_type) { std::vector dtypes; TF_RETURN_IF_ERROR(GetNodeAttr(def(), type_attr_name, &dtypes)); @@ -316,7 +316,7 @@ void Node::set_original_func_names(const std::vector& names) { } } -Status Node::input_edge(int idx, const Edge** e) const { +absl::Status Node::input_edge(int idx, const Edge** e) const { if (idx < 0 || idx >= num_inputs()) { return errors::InvalidArgument("Invalid input_edge index: ", idx, ", Node ", name(), " only has ", num_inputs(), @@ -343,7 +343,7 @@ Status Node::input_edge(int idx, const Edge** e) const { } // Returns a vector of the non-control input edges to a node, indexed by ID. -Status Node::input_edges(std::vector* input_edges) const { +absl::Status Node::input_edges(std::vector* input_edges) const { input_edges->clear(); input_edges->resize(num_inputs(), nullptr); @@ -367,7 +367,7 @@ Status Node::input_edges(std::vector* input_edges) const { return absl::OkStatus(); } -Status Node::input_node(int idx, Node** n) const { +absl::Status Node::input_node(int idx, Node** n) const { const Edge* e; TF_RETURN_IF_ERROR(input_edge(idx, &e)); if (e == nullptr) { @@ -378,14 +378,14 @@ Status Node::input_node(int idx, Node** n) const { return absl::OkStatus(); } -Status Node::input_node(int idx, const Node** const_n) const { +absl::Status Node::input_node(int idx, const Node** const_n) const { Node* n; TF_RETURN_IF_ERROR(input_node(idx, &n)); *const_n = n; return absl::OkStatus(); } -Status Node::input_tensor(int idx, OutputTensor* t) const { +absl::Status Node::input_tensor(int idx, OutputTensor* t) const { const Edge* e; TF_RETURN_IF_ERROR(input_edge(idx, &e)); DCHECK(e != nullptr); @@ -449,7 +449,7 @@ Graph::Graph(const OpRegistryInterface* ops) NodeDef def; def.set_name("_SOURCE"); def.set_op("NoOp"); - Status status; + absl::Status status; Node* source = AddNode(def, &status); TF_CHECK_OK(status); CHECK_EQ(source->id(), kSourceId); @@ -468,7 +468,7 @@ Graph::Graph(const FunctionLibraryDefinition& flib_def) if (flib_def.num_functions() > 0 && versions_->min_consumer() < 12) { versions_->set_min_consumer(12); } - Status s = ops_.AddLibrary(flib_def); + absl::Status s = ops_.AddLibrary(flib_def); CHECK(s.ok()) << s.message(); } @@ -537,13 +537,13 @@ void Graph::Copy(const Graph& src) { } absl::StatusOr Graph::AddNode(NodeDef node_def) { - Status s; + absl::Status s; Node* out = AddNode(std::move(node_def), &s); TF_RETURN_IF_ERROR(s); return out; } -Node* Graph::AddNode(NodeDef node_def, Status* status) { +Node* Graph::AddNode(NodeDef node_def, absl::Status* status) { const OpRegistrationData* op_reg_data; status->Update(ops_.LookUp(node_def.op(), &op_reg_data)); if (!status->ok()) return nullptr; @@ -567,7 +567,7 @@ Node* Graph::AddNode(NodeDef node_def, Status* status) { } else { if (op_reg_data->type_ctor != nullptr) { VLOG(3) << "AddNode: found type constructor for " << node_def.name(); - Status s = + absl::Status s = full_type::SpecializeType(AttrSlice(node_def), op_reg_data->op_def, *(node_def.mutable_experimental_type())); if (!s.ok()) { @@ -733,8 +733,8 @@ const Edge* FindEdge(const Node* dst, int index) { } } // namespace -Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst, - int dst_index) { +absl::Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst, + int dst_index) { TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index)); TF_RETURN_IF_ERROR(IsValidInputTensor(dst, dst_index)); const Edge* e = FindEdge(dst, dst_index); @@ -760,7 +760,8 @@ void Graph::AddInput(NodeDef* dst, StringPiece src_name, int src_slot) { } } -Status Graph::AddWhileInputHack(Node* new_src, int new_src_index, Node* dst) { +absl::Status Graph::AddWhileInputHack(Node* new_src, int new_src_index, + Node* dst) { if (!dst->IsWhileNode()) { return errors::Internal( "dst argument to AddWhileEdgeHack should be a While op, got: ", @@ -782,13 +783,13 @@ Status Graph::AddWhileInputHack(Node* new_src, int new_src_index, Node* dst) { return absl::OkStatus(); } -Status Graph::AddFunctionLibrary( +absl::Status Graph::AddFunctionLibrary( const FunctionDefLibrary& fdef_lib, const FunctionDefLibraryStackTraces& library_traces) { return AddFunctionLibrary(FunctionDefLibrary(fdef_lib), library_traces); } -Status Graph::AddFunctionLibrary( +absl::Status Graph::AddFunctionLibrary( FunctionDefLibrary&& fdef_lib, const FunctionDefLibraryStackTraces& library_traces) { // Need a new-enough consumer to support the functions we add to the graph. @@ -798,16 +799,16 @@ Status Graph::AddFunctionLibrary( return ops_.AddLibrary(std::move(fdef_lib), library_traces); } -Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) { +absl::Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) { return AddFunctionLibrary(fdef_lib, /*library_traces=*/{}); } -Status Graph::AddFunctionLibrary(FunctionDefLibrary&& fdef_lib) { +absl::Status Graph::AddFunctionLibrary(FunctionDefLibrary&& fdef_lib) { return AddFunctionLibrary(std::move(fdef_lib), /*library_traces=*/{}); } -Status Graph::AddFunctionDef(const FunctionDef& fdef, - const StackTracesMap& stack_traces) { +absl::Status Graph::AddFunctionDef(const FunctionDef& fdef, + const StackTracesMap& stack_traces) { // Need a new-enough consumer to support the functions we add to the graph. if (versions_->min_consumer() < 12) { versions_->set_min_consumer(12); @@ -815,7 +816,7 @@ Status Graph::AddFunctionDef(const FunctionDef& fdef, return ops_.AddFunctionDef(fdef, stack_traces); } -Status Graph::AddGradientDef(const GradientDef& gdef) { +absl::Status Graph::AddGradientDef(const GradientDef& gdef) { // Need a new-enough consumer to support the functions we add to the graph. if (versions_->min_consumer() < 12) { versions_->set_min_consumer(12); @@ -914,7 +915,7 @@ std::string Graph::NewName(StringPiece prefix) { return strings::StrCat(prefix, "/_", name_counter_++); } -Status Graph::IsValidNode(const Node* node) const { +absl::Status Graph::IsValidNode(const Node* node) const { if (node == nullptr) { return errors::InvalidArgument("Node is null"); } @@ -934,7 +935,7 @@ Status Graph::IsValidNode(const Node* node) const { return absl::OkStatus(); } -Status Graph::IsValidOutputTensor(const Node* node, int idx) const { +absl::Status Graph::IsValidOutputTensor(const Node* node, int idx) const { TF_RETURN_IF_ERROR(IsValidNode(node)); if (idx >= node->num_outputs() || idx < 0) { return errors::OutOfRange("Node '", node->name(), "' (type: '", @@ -945,7 +946,7 @@ Status Graph::IsValidOutputTensor(const Node* node, int idx) const { return absl::OkStatus(); } -Status Graph::IsValidInputTensor(const Node* node, int idx) const { +absl::Status Graph::IsValidInputTensor(const Node* node, int idx) const { TF_RETURN_IF_ERROR(IsValidNode(node)); if (idx >= node->num_inputs() || idx < 0) { return errors::OutOfRange("Node '", node->name(), "' (type: '", @@ -1004,13 +1005,13 @@ int Graph::InternDeviceName(const std::string& device_name) { return index; } -Status Graph::AddWhileContext(StringPiece frame_name, - std::vector enter_nodes, - std::vector exit_nodes, - OutputTensor cond_output, - std::vector body_inputs, - std::vector body_outputs, - WhileContext** result) { +absl::Status Graph::AddWhileContext(StringPiece frame_name, + std::vector enter_nodes, + std::vector exit_nodes, + OutputTensor cond_output, + std::vector body_inputs, + std::vector body_outputs, + WhileContext** result) { auto pair = while_ctxs_.insert(std::pair( std::string(frame_name), WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes), diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index c7a4f696bf126d..68905818f403f9 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -228,20 +228,20 @@ class Node { void ClearAttr(const std::string& name); // Returns into '*e' the edge connecting to the 'idx' input of this Node. - Status input_edge(int idx, const Edge** e) const; + absl::Status input_edge(int idx, const Edge** e) const; // Returns into '*edges' the input data edges of this Node, indexed by input // number. Does not return control edges. - Status input_edges(std::vector* edges) const; + absl::Status input_edges(std::vector* edges) const; // Returns into '*n' the node that has an output connected to the // 'idx' input of this Node. - Status input_node(int idx, const Node** n) const; - Status input_node(int idx, Node** n) const; + absl::Status input_node(int idx, const Node** n) const; + absl::Status input_node(int idx, Node** n) const; // Returns into '*t' the idx-th input tensor of this node, represented as the // output tensor of input_node(idx). - Status input_tensor(int idx, OutputTensor* t) const; + absl::Status input_tensor(int idx, OutputTensor* t) const; WhileContext* while_ctx() const { return while_ctx_; } void set_while_ctx(WhileContext* while_ctx) { @@ -276,8 +276,9 @@ class Node { // removed. dtype information in the TYPE_ATTR_NAME attr is always updated. // Use UPDATE_FULL_TYPE=true when this changes the node's outputs to also // update the node's full type information (if present). - Status ShrinkTypeInfo(const absl::flat_hash_map& index_mapping, - const string& type_attr_name, bool update_full_type); + absl::Status ShrinkTypeInfo( + const absl::flat_hash_map& index_mapping, + const string& type_attr_name, bool update_full_type); // Called after an incident non-control edge has changed. Does nothing if not // all input edges are defined. @@ -560,7 +561,7 @@ class Graph { // Adds a new node to this graph, and returns it. Infers the Op and // input/output types for the node. *this owns the returned instance. // Returns nullptr and sets *status on error. - Node* AddNode(NodeDef node_def, Status* status); + Node* AddNode(NodeDef node_def, absl::Status* status); // Same as above, but using StatusOr. This method is always preferred. absl::StatusOr AddNode(NodeDef node_def); @@ -613,7 +614,8 @@ class Graph { // Updates the input to a node. The existing edge to `dst` is removed and an // edge from `new_src` to `dst` is created. The NodeDef associated with `dst` // is also updated. - Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, int dst_index); + absl::Status UpdateEdge(Node* new_src, int new_src_index, Node* dst, + int dst_index); // Add an input to dst that comes from the "src_slot" output of the // node named by "src_name". @@ -622,35 +624,37 @@ class Graph { // Like AddEdge but updates dst's NodeDef. Used to add an input edge to a // "While" op during gradient construction, see AddInputWhileHack in // python_api.h for more details. - Status AddWhileInputHack(Node* new_src, int new_src_index, Node* dst); + absl::Status AddWhileInputHack(Node* new_src, int new_src_index, Node* dst); // Adds the function and gradient definitions in `fdef_lib` to this graph's op // registry. Ignores duplicate functions, and returns a bad status if an // imported function differs from an existing function or op with the same // name. This overload adds the function definitions with no stack traces. - Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib); - Status AddFunctionLibrary(FunctionDefLibrary&& fdef_lib); + absl::Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib); + absl::Status AddFunctionLibrary(FunctionDefLibrary&& fdef_lib); // Adds the function and gradient definitions in `fdef_lib` to this graph's op // registry. Ignores duplicate functions, and returns a bad status if an // imported function differs from an existing function or op with the same // name. - Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib, - const FunctionDefLibraryStackTraces& stack_traces); - Status AddFunctionLibrary(FunctionDefLibrary&& fdef_lib, - const FunctionDefLibraryStackTraces& stack_traces); + absl::Status AddFunctionLibrary( + const FunctionDefLibrary& fdef_lib, + const FunctionDefLibraryStackTraces& stack_traces); + absl::Status AddFunctionLibrary( + FunctionDefLibrary&& fdef_lib, + const FunctionDefLibraryStackTraces& stack_traces); // Adds the function definition and its stacktraces to this graph's op // registry. Ignores duplicate functions, and returns a bad status if an // imported function differs from an existing function or op with the same // name. - Status AddFunctionDef(const FunctionDef& fdef, - const StackTracesMap& stack_traces); + absl::Status AddFunctionDef(const FunctionDef& fdef, + const StackTracesMap& stack_traces); // Adds the gradient definition to this graph's op registry. Ignores duplicate // gradients of the same function, and returns a bad status if an imported // gradient differs from an existing gradient of the same function name. - Status AddGradientDef(const GradientDef& gdef); + absl::Status AddGradientDef(const GradientDef& gdef); // The number of live nodes in the graph. // @@ -777,25 +781,26 @@ class Graph { } // Returns OK if `node` is non-null and belongs to this graph - Status IsValidNode(const Node* node) const; + absl::Status IsValidNode(const Node* node) const; // Returns OK if IsValidNode(`node`) and `idx` is a valid output. Does not // accept control outputs. - Status IsValidOutputTensor(const Node* node, int idx) const; + absl::Status IsValidOutputTensor(const Node* node, int idx) const; // Returns OK if IsValidNode(`node`) and `idx` a valid input. Does not accept // control inputs. - Status IsValidInputTensor(const Node* node, int idx) const; + absl::Status IsValidInputTensor(const Node* node, int idx) const; // Create and return a new WhileContext owned by this graph. This is called // when a new while loop is created. `frame_name` must be unique among // WhileContexts in this graph. - Status AddWhileContext(StringPiece frame_name, std::vector enter_nodes, - std::vector exit_nodes, - OutputTensor cond_output, - std::vector body_inputs, - std::vector body_outputs, - WhileContext** result); + absl::Status AddWhileContext(StringPiece frame_name, + std::vector enter_nodes, + std::vector exit_nodes, + OutputTensor cond_output, + std::vector body_inputs, + std::vector body_outputs, + WhileContext** result); // Builds a node name to node pointer index for all nodes in the graph. std::unordered_map BuildNodeNameIndex() const; @@ -1027,7 +1032,7 @@ inline bool NodeIter::operator!=(const NodeIter& rhs) const { } inline void NodeIter::operator++() { - while (1) { + while (true) { DCHECK_LE(id_, graph_->num_node_ids()); ++id_; if (id_ >= graph_->num_node_ids() || graph_->FindNodeId(id_) != nullptr) { diff --git a/tensorflow/core/graph/graph_debug_info_builder.cc b/tensorflow/core/graph/graph_debug_info_builder.cc index 015494c181ed70..b539fa1d5c04a5 100644 --- a/tensorflow/core/graph/graph_debug_info_builder.cc +++ b/tensorflow/core/graph/graph_debug_info_builder.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "tensorflow/core/config/flag_defs.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/framework/logging.h" @@ -178,10 +179,19 @@ void GraphDebugInfoBuilder::AccumulateStackTrace( AppendToStackTraceProto(stack_frame, stack_trace_proto); } } else { - frame_to_index_.reserve(frame_to_index_.size() + - trace->ToFrames().size()); - for (const auto& stack_frame : trace->ToFrames()) { - AppendToStackTraceProto(stack_frame, stack_trace_proto); + if (flags::Global() + .enable_graph_debug_info_caching_for_stack_frames.value()) { + frame_to_index_.reserve(frame_to_index_.size() + + trace->ToFrames().size()); + for (const auto& stack_frame : trace->ToFrames()) { + AppendToStackTraceProto(stack_frame, stack_trace_proto); + } + } else { + frame_to_index_.reserve(frame_to_index_.size() + + trace->ToUncachedFrames().size()); + for (const auto& stack_frame : trace->ToUncachedFrames()) { + AppendToStackTraceProto(stack_frame, stack_trace_proto); + } } } } diff --git a/tensorflow/core/graph/graph_def_builder.cc b/tensorflow/core/graph/graph_def_builder.cc index 0b6c50373471ed..b8734f662c5fe8 100644 --- a/tensorflow/core/graph/graph_def_builder.cc +++ b/tensorflow/core/graph/graph_def_builder.cc @@ -22,7 +22,7 @@ limitations under the License. namespace tensorflow { -GraphDefBuilder::Options::Options(Graph* graph, Status* status) +GraphDefBuilder::Options::Options(Graph* graph, absl::Status* status) : graph_(graph), status_(status) {} GraphDefBuilder::Options::~Options() {} @@ -64,7 +64,7 @@ GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputsImpl( return *this; } -Status GraphDefBuilder::ToGraphDef(GraphDef* graph_def) const { +absl::Status GraphDefBuilder::ToGraphDef(GraphDef* graph_def) const { if (status_.ok()) { graph_.ToGraphDef(graph_def); *graph_def->mutable_library() = flib_def_.ToProto(); @@ -89,7 +89,7 @@ Node* GraphDefBuilder::Options::FinalizeBuilder(NodeBuilder* builder) const { return returned_node; } -void GraphDefBuilder::Options::UpdateStatus(const Status& status) const { +void GraphDefBuilder::Options::UpdateStatus(const absl::Status& status) const { if (status_ == nullptr) { TF_CHECK_OK(status); } else { diff --git a/tensorflow/core/graph/graph_def_builder.h b/tensorflow/core/graph/graph_def_builder.h index 8935cf46485247..bc44649302172f 100644 --- a/tensorflow/core/graph/graph_def_builder.h +++ b/tensorflow/core/graph/graph_def_builder.h @@ -74,7 +74,7 @@ class GraphDefBuilder { // Sets the Graph (that Nodes will be added to) and the status. The // status may be set to nullptr, in which case errors cause CHECK // failures. The graph and status must outlive *this. - Options(Graph* graph, Status* status); + Options(Graph* graph, absl::Status* status); ~Options(); // Methods for setting options. These are const methods: they @@ -119,7 +119,7 @@ class GraphDefBuilder { Node* FinalizeBuilder(NodeBuilder* builder) const; // Updates the associated status, if any, or calls TF_CHECK_OK if none. - void UpdateStatus(const Status& status) const; + void UpdateStatus(const absl::Status& status) const; // Accessor const OpRegistryInterface* op_registry() const { @@ -139,7 +139,7 @@ class GraphDefBuilder { } Graph* const graph_; - Status* const status_; + absl::Status* const status_; string name_; string device_; std::vector control_inputs_; @@ -164,13 +164,13 @@ class GraphDefBuilder { // Once all the nodes have been added, call this to get whether it was // successful, and if so fill *graph_def. - Status ToGraphDef(GraphDef* graph_def) const; + absl::Status ToGraphDef(GraphDef* graph_def) const; // Adds the function and gradient definitions in `fdef_lib` to this graph's op // registry. Ignores duplicate functions, and returns a bad status if an // imported function differs from an existing function or op with the same // name. - Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) { + absl::Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) { return flib_def_.AddLibrary(fdef_lib); } @@ -183,7 +183,7 @@ class GraphDefBuilder { private: Graph graph_; FunctionLibraryDefinition flib_def_; - Status status_; + absl::Status status_; Options opts_; }; diff --git a/tensorflow/core/graph/graph_node_util.cc b/tensorflow/core/graph/graph_node_util.cc index acfbd3fbbf9463..3bf14ed2944394 100644 --- a/tensorflow/core/graph/graph_node_util.cc +++ b/tensorflow/core/graph/graph_node_util.cc @@ -31,13 +31,13 @@ string FormatNodeForError(const Node& node) { return FormatNodeDefForError(node.def()); } -Status NameRangesForNode(const Node& node, const OpDef& op_def, - NameRangeMap* inputs, NameRangeMap* outputs) { +absl::Status NameRangesForNode(const Node& node, const OpDef& op_def, + NameRangeMap* inputs, NameRangeMap* outputs) { return NameRangesForNode(node.def(), op_def, inputs, outputs); } -Status AttachDef(const Status& status, const Node& node, - bool allow_multiple_formatted_node) { +absl::Status AttachDef(const absl::Status& status, const Node& node, + bool allow_multiple_formatted_node) { return AttachDef(status, node.def(), allow_multiple_formatted_node); } diff --git a/tensorflow/core/graph/graph_node_util.h b/tensorflow/core/graph/graph_node_util.h index f78564c91b0ffa..146c4c07ca833a 100644 --- a/tensorflow/core/graph/graph_node_util.h +++ b/tensorflow/core/graph/graph_node_util.h @@ -50,15 +50,15 @@ void MergeDebugInfo(const NodeDef& from, NodeDef* to); // space, the returned `NameRangeMap` objects borrow the input/output // argument names from `op_def`. The `op_def` must outlive the // returned `NameRangeMap` objects. -Status NameRangesForNode(const Node& node, const OpDef& op_def, - NameRangeMap* inputs, NameRangeMap* outputs); +absl::Status NameRangesForNode(const Node& node, const OpDef& op_def, + NameRangeMap* inputs, NameRangeMap* outputs); // Returns "status" with formatted Node attached as additional text // in the error message. If 'allow_multiple_formatted_node' is false and there // is already a formatted Node present in 'status', we simply attach the name // of the Node instead of the formatted string. -Status AttachDef(const Status& status, const Node& node, - bool allow_multiple_formatted_node = false); +absl::Status AttachDef(const absl::Status& status, const Node& node, + bool allow_multiple_formatted_node = false); } // namespace tensorflow #endif // TENSORFLOW_CORE_GRAPH_GRAPH_NODE_UTIL_H_ diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index eac0fa367e5577..8e31106e70a58f 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -184,7 +184,7 @@ void SetSendRecvAttrs(const PartitionOptions& opts, const Edge* edge, NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info, GraphDef* gdef, const Edge* edge, NodeDefBuilder::NodeOut send_from, int64_t start_time, - const string& tensor_name_attr, Status* status) { + const string& tensor_name_attr, absl::Status* status) { const DataType dtype = send_from.data_type; const DataType cast_dtype = opts.should_cast ? opts.should_cast(edge) : dtype; const Node* src = edge->src(); @@ -241,7 +241,7 @@ NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info, NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info, GraphDef* gdef, const Edge* edge, NodeDef** real_recv, - const string& tensor_name_attr, Status* status) { + const string& tensor_name_attr, absl::Status* status) { const DataType dtype = EdgeType(edge); const Node* src = edge->src(); const Node* dst = edge->dst(); @@ -324,7 +324,7 @@ NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info, } NodeDef* AddDummyConst(const PartitionOptions& opts, GraphDef* gdef, - const Edge* edge, Status* status) { + const Edge* edge, absl::Status* status) { const Node* src = edge->src(); Tensor tensor(DT_FLOAT, TensorShape({0})); NodeDef* result = gdef->add_node(); @@ -340,7 +340,7 @@ NodeDef* AddDummyConst(const PartitionOptions& opts, GraphDef* gdef, // A dummy node for scheduling. NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef, const string& assigned_device_name, int64_t epoch, - int64_t starttime, Status* status) { + int64_t starttime, absl::Status* status) { NodeDef* result = gdef->add_node(); *status = NodeDefBuilder(opts.new_name(strings::StrCat("synch_", epoch)), "ControlTrigger") @@ -410,7 +410,7 @@ bool IsControlLoop(const Node* node) { // An enter node for control flow. Node* AddControlEnter(Graph* g, const string& node_name, const string& device_name, const string& frame_name, - const int parallel_iterations, Status* status) { + const int parallel_iterations, absl::Status* status) { NodeBuilder node_builder(node_name, "Enter", g->op_registry()); node_builder.Input({"dummy", 0, DT_FLOAT}); node_builder.Attr("frame_name", frame_name); @@ -425,7 +425,7 @@ Node* AddControlEnter(Graph* g, const string& node_name, // A merge node for control flow. Node* AddControlMerge(const string& in_name1, const string& in_name2, Graph* g, const string& node_name, const string& device_name, - Status* status) { + absl::Status* status) { NodeBuilder node_builder(node_name, "Merge", g->op_registry()); node_builder.Input({{in_name1, 0, DT_FLOAT}, {in_name2, 0, DT_FLOAT}}); Node* res_node; @@ -506,11 +506,11 @@ void AddControlFlowInfo(const Node* node, const Node* src, // switch node will be connected to the LoopCond node. The merge node will // be connected to all the recvs of the same frame by control edges when // the actual partitioning happens. -Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src, - const Edge* edge, Node* loop_cond, - std::vector* cf_info, - ControlLoop* loop) { - Status status; +absl::Status AddControlLoop(const PartitionOptions& opts, Graph* g, + const Node* src, const Edge* edge, Node* loop_cond, + std::vector* cf_info, + ControlLoop* loop) { + absl::Status status; GraphDefBuilder::Options bopts(g, &status); const ControlFlowInfo& src_info = (*cf_info)[src->id()]; const string& device_name = edge->dst()->assigned_device_name(); @@ -562,7 +562,7 @@ Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src, // Build memory and device type info for every node in the graph. // TODO(yuanbyu): It might be simpler if we convert MemoryType to // DeviceType for the inputs/outputs of each node. -Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) { +absl::Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) { MemoryTypeVector input_memory_types; MemoryTypeVector output_memory_types; @@ -621,9 +621,9 @@ const Node* OutputFrame(const Node* node, // // TODO(yuanbyu): The correctness of this construction is rather subtle. I got // it wrong many times so it would be nice to write a proof to be sure. -Status AddControlFlow(const PartitionOptions& opts, Graph* g, - GraphInfo* g_info) { - Status status; +absl::Status AddControlFlow(const PartitionOptions& opts, Graph* g, + GraphInfo* g_info) { + absl::Status status; GraphDefBuilder::Options bopts(g, &status); std::vector& cf_info = g_info->cf_info; @@ -798,7 +798,7 @@ struct PriorityTopoSortNodeGreater { // // Note that graph_partition_test.cc accesses this function for testing, even // though it's not declared in the header. -Status TopologicalSortNodesWithTimePriority( +absl::Status TopologicalSortNodesWithTimePriority( const GraphDef* gdef, std::vector>* nodes, std::unordered_map* node_to_start_time_out) { @@ -873,9 +873,9 @@ Status TopologicalSortNodesWithTimePriority( return absl::OkStatus(); } -Status AddControlEdges(const PartitionOptions& opts, - std::unordered_map* partitions) { - Status status; +absl::Status AddControlEdges(const PartitionOptions& opts, + std::unordered_map* partitions) { + absl::Status status; // TODO(yuanbyu): Very naive for now. To be improved. const int num_epochs = 100; const int prefetch = 6; @@ -968,10 +968,10 @@ void SetIncarnation(const PartitionOptions& opts, GraphDef* gdef) { } } -Status Partition(const PartitionOptions& opts, Graph* g, - std::unordered_map* partitions) { +absl::Status Partition(const PartitionOptions& opts, Graph* g, + std::unordered_map* partitions) { // TODO(b/290689453) Refactor this into smaller functions - Status status; + absl::Status status; absl::flat_hash_map> debug_info_builders; partitions->clear(); diff --git a/tensorflow/core/graph/graph_partition.h b/tensorflow/core/graph/graph_partition.h index 8b81fc536dc284..59e9fe0e61c35d 100644 --- a/tensorflow/core/graph/graph_partition.h +++ b/tensorflow/core/graph/graph_partition.h @@ -95,14 +95,14 @@ struct PartitionOptions { // generate node names. // // Stores the partitions in *partitions. -Status Partition(const PartitionOptions& opts, Graph* input, - std::unordered_map* partitions); +absl::Status Partition(const PartitionOptions& opts, Graph* input, + std::unordered_map* partitions); // Add control edges to the partitions to control the ordering // and timing of the recv nodes based on the start times calculated // using some scheduling algorithm. -Status AddControlEdges(const PartitionOptions& opts, - std::unordered_map* partitions); +absl::Status AddControlEdges(const PartitionOptions& opts, + std::unordered_map* partitions); } // namespace tensorflow diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc index a6ad2a01031af4..2e807b2e21340d 100644 --- a/tensorflow/core/graph/graph_partition_test.cc +++ b/tensorflow/core/graph/graph_partition_test.cc @@ -51,7 +51,7 @@ limitations under the License. namespace tensorflow { // from graph_partition.cc -extern Status TopologicalSortNodesWithTimePriority( +extern absl::Status TopologicalSortNodesWithTimePriority( const GraphDef* gdef, std::vector>* nodes, std::unordered_map* node_to_start_time_out); @@ -103,7 +103,7 @@ void Partition(const GraphDef& graph_def, popts.get_incarnation = [](const string& name) { return (name[0] - 'A') + 100; }; - Status s = Partition(popts, &g, partitions); + absl::Status s = Partition(popts, &g, partitions); CHECK(s.ok()) << s; // Check versions. @@ -466,7 +466,7 @@ TEST_F(GraphPartitionTest, PartitionIncompleteGraph) { )EOF", &ndef); ASSERT_TRUE(parsed); - Status status; + absl::Status status; g.AddNode(ndef, &status); TF_ASSERT_OK(status); diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc index c0582f7092f40d..13602cec25ab56 100644 --- a/tensorflow/core/graph/graph_test.cc +++ b/tensorflow/core/graph/graph_test.cc @@ -118,7 +118,7 @@ class GraphTest : public ::testing::Test { NodeDef node_def; TF_CHECK_OK(builder.Finalize(&node_def)); - Status s; + absl::Status s; Node* node = graph_.AddNode(node_def, &s); TF_CHECK_OK(s); return node; @@ -280,7 +280,7 @@ TEST_F(GraphTest, NodeByIndex) { graph_.RemoveNode(a); // 'c's input_node entry should be invalidated. - Status s = c->input_node(0, &a_copy); + absl::Status s = c->input_node(0, &a_copy); EXPECT_FALSE(s.ok()); // Add two new nodes. @@ -446,7 +446,7 @@ TEST_F(GraphTest, IsValidNode) { TF_CHECK_OK(NodeBuilder("g2_node2", "NoOp").Finalize(&graph2, &g2_node2)); // nullptr - Status s = graph_.IsValidNode(nullptr); + absl::Status s = graph_.IsValidNode(nullptr); EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); EXPECT_EQ(string("Node is null"), s.message()); @@ -586,7 +586,7 @@ TEST_F(GraphTest, UpdateEdge) { EXPECT_EQ("0->1;0->2;2->1;2->3;2->4;2->5;4->1;", EdgeIter(graph_)); // Update a's 1st output which is out of range. - Status s = graph_.UpdateEdge(a, 1, d, 0); + absl::Status s = graph_.UpdateEdge(a, 1, d, 0); EXPECT_FALSE(s.ok()); EXPECT_EQ( s.message(), @@ -624,12 +624,12 @@ TEST_F(GraphTest, EdgeDebugString) { EXPECT_EQ(s1, "[id=0 :0 -> :0]"); // Print edge with null src node - auto e2 = BuildEdge(2, 0, b, 1, 1); + auto e2 = BuildEdge(2, nullptr, b, 1, 1); auto s2 = e2->DebugString(); EXPECT_EQ(s2, "[id=2 :1 -> B:1]"); // Print edge with null dst node - auto e3 = BuildEdge(3, a, 0, 2, 1); + auto e3 = BuildEdge(3, a, nullptr, 2, 1); auto s3 = e3->DebugString(); EXPECT_EQ(s3, "[id=3 A:2 -> :1]"); } @@ -652,7 +652,7 @@ TEST_F(GraphTest, AddFunctionLibrary) { FunctionDefLibrary error_proto = proto; *error_proto.mutable_function(0)->add_node_def() = error_proto.function(0).node_def(0); - Status s = graph_.AddFunctionLibrary(error_proto); + absl::Status s = graph_.AddFunctionLibrary(error_proto); EXPECT_FALSE(s.ok()); EXPECT_EQ(s.message(), "Cannot add function 'XTimesTwo' because a different function with " @@ -744,7 +744,7 @@ TEST_F(GraphTest, NodeShrinkTypeOutput) { NodeDef node_def; TF_CHECK_OK(builder.Finalize(&node_def)); - Status s; + absl::Status s; Node* node = graph_.AddNode(node_def, &s); TF_CHECK_OK(s); @@ -792,7 +792,7 @@ TEST_F(GraphTest, NodeShrinkTypeInput) { NodeDef node_def; TF_CHECK_OK(builder.Finalize(&node_def)); - Status s; + absl::Status s; Node* node = graph_.AddNode(node_def, &s); TF_CHECK_OK(s); diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc index dbd8fafd1ea523..96e5768941228d 100644 --- a/tensorflow/core/graph/node_builder.cc +++ b/tensorflow/core/graph/node_builder.cc @@ -123,7 +123,8 @@ absl::StatusOr NodeBuilder::Finalize(Graph* graph, bool consume) { return out; } -Status NodeBuilder::Finalize(Graph* graph, Node** created_node, bool consume) { +absl::Status NodeBuilder::Finalize(Graph* graph, Node** created_node, + bool consume) { // In case of error, set *created_node to nullptr. if (created_node != nullptr) { *created_node = nullptr; diff --git a/tensorflow/core/graph/node_builder.h b/tensorflow/core/graph/node_builder.h index 6b7d67b9f1be5a..0d5bf9fb9a240c 100644 --- a/tensorflow/core/graph/node_builder.h +++ b/tensorflow/core/graph/node_builder.h @@ -123,7 +123,8 @@ class NodeBuilder { // *created_node will be set to the new node (or nullptr on error). // If `consume` is true, the builder state will be moved into `node_def`, // and the builder will be left in an undefined state. - Status Finalize(Graph* graph, Node** created_node, bool consume = false); + absl::Status Finalize(Graph* graph, Node** created_node, + bool consume = false); // Same as `Finalize` above, but using StatusOr to return value. Preferred // form. diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc index b101810f2f1a81..8c73691fd6ba56 100644 --- a/tensorflow/core/graph/subgraph.cc +++ b/tensorflow/core/graph/subgraph.cc @@ -54,7 +54,7 @@ typedef std::unordered_map NameIndex; // Return true on success. On error, return false and sets *error to // an appropriate error message (and *g is left in an indeterminate // state). -Status FeedInputs( +absl::Status FeedInputs( Graph* g, const std::vector>& feed_rewrites, NameIndex* name_index, DataTypeVector* out_feed_types) { out_feed_types->clear(); @@ -119,7 +119,7 @@ Status FeedInputs( return absl::OkStatus(); } -Status FetchOutputs( +absl::Status FetchOutputs( Graph* g, const std::vector>& fetch_rewrites, NameIndex* name_index, std::vector* out_fetch_nodes, DataTypeVector* out_fetch_types) { @@ -187,9 +187,9 @@ bool AddNodeToTargets(const string& node_or_tensor_name, return true; } -Status PruneForTargets(Graph* g, const NameIndex& name_index, - const std::vector& fetch_nodes, - const absl::Span& target_nodes) { +absl::Status PruneForTargets(Graph* g, const NameIndex& name_index, + const std::vector& fetch_nodes, + const absl::Span& target_nodes) { string not_found; std::unordered_set targets; for (Node* n : fetch_nodes) { @@ -216,8 +216,8 @@ Status PruneForTargets(Graph* g, const NameIndex& name_index, } // namespace -Status ArgFeedRewrite::AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor, - Node** out_node) { +absl::Status ArgFeedRewrite::AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor, + Node** out_node) { // NOTE(mrry): We must include the index as part of the node // name, because _Arg is a "stateful" kernel and therefore // its name must uniquely identify a kernel instance across all @@ -233,8 +233,9 @@ Status ArgFeedRewrite::AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor, return absl::OkStatus(); } -Status RecvFeedRewrite::AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor, - Node** out_node) { +absl::Status RecvFeedRewrite::AddNode(Graph* g, + NodeBuilder::NodeOut feed_tensor, + Node** out_node) { TF_RETURN_IF_ERROR( NodeBuilder(strings::StrCat("_recv_", feed_tensor.node->name(), "_", feed_tensor.index), @@ -253,8 +254,9 @@ Status RecvFeedRewrite::AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor, return absl::OkStatus(); } -Status RetvalFetchRewrite::AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor, - Node** out_node) { +absl::Status RetvalFetchRewrite::AddNode(Graph* g, + NodeBuilder::NodeOut fetch_tensor, + Node** out_node) { // NOTE(mrry): We must include the index as part of the node // name, because _Retval is a "stateful" kernel and therefore // its name must uniquely identify a kernel instance across all @@ -272,8 +274,9 @@ Status RetvalFetchRewrite::AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor, return absl::OkStatus(); } -Status SendFetchRewrite::AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor, - Node** out_node) { +absl::Status SendFetchRewrite::AddNode(Graph* g, + NodeBuilder::NodeOut fetch_tensor, + Node** out_node) { TF_RETURN_IF_ERROR( NodeBuilder(strings::StrCat("_send_", fetch_tensor.node->name(), "_", fetch_tensor.index), @@ -290,7 +293,7 @@ Status SendFetchRewrite::AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor, return absl::OkStatus(); } -Status RewriteGraphForExecution( +absl::Status RewriteGraphForExecution( Graph* g, const absl::Span& fed_outputs, const absl::Span& fetch_outputs, const absl::Span& target_node_names, @@ -335,7 +338,7 @@ std::vector ConvertToVector(StringContainer field) { } } // namespace -Status RewriteGraphForExecution( +absl::Status RewriteGraphForExecution( Graph* g, const std::vector>& feed_rewrites, const std::vector>& fetch_rewrites, const absl::Span& target_node_names, diff --git a/tensorflow/core/graph/subgraph.h b/tensorflow/core/graph/subgraph.h index 824c89e5482ed4..37013b8f7d09ee 100644 --- a/tensorflow/core/graph/subgraph.h +++ b/tensorflow/core/graph/subgraph.h @@ -56,8 +56,8 @@ class PruneRewrite { // Creates a new node whose output replaces the given `tensor` in graph `g`. // The node will be assigned to the device named in `device_info`. - virtual Status AddNode(Graph* g, NodeBuilder::NodeOut tensor, - Node** out_node) = 0; + virtual absl::Status AddNode(Graph* g, NodeBuilder::NodeOut tensor, + Node** out_node) = 0; // Returns the name of the tensor to which this rewrite applies. const string& endpoint_name() { return *endpoint_name_; } @@ -97,7 +97,7 @@ class PruneRewrite { // - fed output "node:output_index" does not exist in "*g" // - fetch output "node:output_index" does not exist in "*g" // - target node "node" does not exist in "*g" -Status RewriteGraphForExecution( +absl::Status RewriteGraphForExecution( Graph* g, const absl::Span& fed_outputs, const absl::Span& fetch_outputs, const absl::Span& target_node_names, @@ -106,7 +106,7 @@ Status RewriteGraphForExecution( // A more general version of the above function that supports // customizable rewriting actions for each fed and fetched tensor. -Status RewriteGraphForExecution( +absl::Status RewriteGraphForExecution( Graph* g, const std::vector>& feed_rewrites, const std::vector>& fetch_rewrites, const absl::Span& target_node_names, @@ -122,8 +122,8 @@ class ArgFeedRewrite : public PruneRewrite { ArgFeedRewrite(const string* endpoint_name, const DeviceAttributes* device_info, int32_t arg_index) : PruneRewrite(endpoint_name, device_info), arg_index_(arg_index) {} - Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor, - Node** out_node) override; + absl::Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor, + Node** out_node) override; private: const int32 arg_index_; @@ -133,8 +133,8 @@ class ArgFeedRewrite : public PruneRewrite { class RecvFeedRewrite : public PruneRewrite { public: using PruneRewrite::PruneRewrite; - Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor, - Node** out_node) override; + absl::Status AddNode(Graph* g, NodeBuilder::NodeOut feed_tensor, + Node** out_node) override; }; // A rewrite action that adds a _Retval node for a fetched tensor. @@ -143,8 +143,8 @@ class RetvalFetchRewrite : public PruneRewrite { RetvalFetchRewrite(const string* endpoint_name, const DeviceAttributes* device_info, int32_t retval_index) : PruneRewrite(endpoint_name, device_info), retval_index_(retval_index) {} - Status AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor, - Node** out_node) override; + absl::Status AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor, + Node** out_node) override; private: const int32 retval_index_; @@ -155,8 +155,8 @@ class RetvalFetchRewrite : public PruneRewrite { class SendFetchRewrite : public PruneRewrite { public: using PruneRewrite::PruneRewrite; - Status AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor, - Node** out_node) override; + absl::Status AddNode(Graph* g, NodeBuilder::NodeOut fetch_tensor, + Node** out_node) override; }; } // namespace subgraph diff --git a/tensorflow/core/graph/subgraph_test.cc b/tensorflow/core/graph/subgraph_test.cc index 571da3b62e57cd..9d86672dd94f37 100644 --- a/tensorflow/core/graph/subgraph_test.cc +++ b/tensorflow/core/graph/subgraph_test.cc @@ -117,7 +117,7 @@ class SubgraphTest : public ::testing::Test { str_util::Split(targets_str, ',', str_util::SkipEmpty()); subgraph::RewriteGraphMetadata metadata; - Status s = subgraph::RewriteGraphForExecution( + absl::Status s = subgraph::RewriteGraphForExecution( subgraph, fed, fetch, targets, device_info_, use_function_convention, &metadata); if (!s.ok()) { diff --git a/tensorflow/core/graph/validate.cc b/tensorflow/core/graph/validate.cc index db98b0ba9bb953..154d9f26c80cf5 100644 --- a/tensorflow/core/graph/validate.cc +++ b/tensorflow/core/graph/validate.cc @@ -28,9 +28,9 @@ limitations under the License. namespace tensorflow { namespace graph { -Status ValidateGraphDef(const GraphDef& graph_def, - const OpRegistryInterface& op_registry) { - Status s; +absl::Status ValidateGraphDef(const GraphDef& graph_def, + const OpRegistryInterface& op_registry) { + absl::Status s; const int version = graph_def.versions().producer(); for (const NodeDef& node_def : graph_def.node()) { // Look up the OpDef for the node_def's op name. @@ -43,15 +43,15 @@ Status ValidateGraphDef(const GraphDef& graph_def, return s; } -Status ValidateGraphDefAgainstOpRegistry( +absl::Status ValidateGraphDefAgainstOpRegistry( const GraphDef& graph_def, const OpRegistryInterface& op_registry) { GraphDef copy(graph_def); TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(©, op_registry, 0)); return ValidateGraphDef(copy, op_registry); } -Status ValidateGraphDefAgainstOpList(const GraphDef& graph_def, - const OpList& op_list) { +absl::Status ValidateGraphDefAgainstOpList(const GraphDef& graph_def, + const OpList& op_list) { OpListOpRegistry registry(&op_list); return ValidateGraphDefAgainstOpRegistry(graph_def, registry); } @@ -61,7 +61,7 @@ void GetOpListForValidation(OpList* op_list, const OpRegistry& op_registry) { RemoveDescriptionsFromOpList(op_list); } -Status ValidateGraphHasNoCycle(const Graph& graph) { +absl::Status ValidateGraphHasNoCycle(const Graph& graph) { // A node is ready when all of its inputs have been visited. std::vector ready; std::vector pending_count(graph.num_node_ids(), 0); @@ -115,7 +115,7 @@ Status ValidateGraphHasNoCycle(const Graph& graph) { return absl::OkStatus(); } -Status VerifyNoDuplicateNodeNames(const GraphDef& graph) { +absl::Status VerifyNoDuplicateNodeNames(const GraphDef& graph) { absl::flat_hash_set nodes; for (const auto& node : graph.node()) { if (nodes.contains(node.name())) { diff --git a/tensorflow/core/graph/validate.h b/tensorflow/core/graph/validate.h index bfb3a25ac91761..3d59219b6d2346 100644 --- a/tensorflow/core/graph/validate.h +++ b/tensorflow/core/graph/validate.h @@ -31,21 +31,21 @@ namespace graph { // REQUIRES: // * `op_registry` is not nullptr. // * `graph_def` has default attrs filled in (see AddDefaultAttrsToGraphDef()). -Status ValidateGraphDef(const GraphDef& graph_def, - const OpRegistryInterface& op_registry); +absl::Status ValidateGraphDef(const GraphDef& graph_def, + const OpRegistryInterface& op_registry); // Like ValidateGraphDef() except it makes a copy of `graph_def` and calls // AddDefaultAttrsToGraphDef() on the copy, removing that requirement from the // caller. -Status ValidateGraphDefAgainstOpRegistry( +absl::Status ValidateGraphDefAgainstOpRegistry( const GraphDef& graph_def, const OpRegistryInterface& op_registry); // Like ValidateGraphDefAgainstOpRegistry() except it takes an OpList // instead of an OpRegistryInterface. Note that the OpList need not // have descriptions, which can be a big space savings, see // GetOpListForValidation() below. -Status ValidateGraphDefAgainstOpList(const GraphDef& graph_def, - const OpList& op_list); +absl::Status ValidateGraphDefAgainstOpList(const GraphDef& graph_def, + const OpList& op_list); // Get an OpList from `*op_registry` with all the descriptions removed. void GetOpListForValidation( @@ -57,10 +57,10 @@ void GetOpListForValidation( // all been visited, and counts the total number of visited nodes. If there is a // cycle, nodes in the cycle will never be visited, and the visited count will // be less than the total node count. -Status ValidateGraphHasNoCycle(const Graph& graph); +absl::Status ValidateGraphHasNoCycle(const Graph& graph); // Returns OK if the graph has no duplicate node names. -Status VerifyNoDuplicateNodeNames(const GraphDef& graph); +absl::Status VerifyNoDuplicateNodeNames(const GraphDef& graph); } // namespace graph } // namespace tensorflow diff --git a/tensorflow/core/graph/validate_test.cc b/tensorflow/core/graph/validate_test.cc index 4a57af42326402..f9ef6367a09fc2 100644 --- a/tensorflow/core/graph/validate_test.cc +++ b/tensorflow/core/graph/validate_test.cc @@ -61,7 +61,7 @@ TEST(ValidateGraphDefTest, GraphWithUnspecifiedDefaultAttr) { GraphDef graph_def; auto parser = protobuf::TextFormat::Parser(); CHECK(parser.MergeFromString(graph_def_str, &graph_def)) << graph_def_str; - Status s = graph::ValidateGraphDef(graph_def, *OpRegistry::Global()); + absl::Status s = graph::ValidateGraphDef(graph_def, *OpRegistry::Global()); EXPECT_FALSE(s.ok()); EXPECT_TRUE(absl::StrContains(s.ToString(), "NodeDef missing attr")); @@ -84,7 +84,7 @@ TEST(ValidateGraphDefTest, GraphWithUnspecifiedRequiredAttr) { GraphDef graph_def; auto parser = protobuf::TextFormat::Parser(); CHECK(parser.MergeFromString(graph_def_str, &graph_def)) << graph_def_str; - Status s = graph::ValidateGraphDef(graph_def, *OpRegistry::Global()); + absl::Status s = graph::ValidateGraphDef(graph_def, *OpRegistry::Global()); EXPECT_FALSE(s.ok()); EXPECT_TRUE(absl::StrContains(s.ToString(), "NodeDef missing attr")); @@ -231,7 +231,7 @@ Node* AddNodeFromNodeDef(Graph& graph, const string& name, NodeDef node_def; TF_CHECK_OK(builder.Finalize(&node_def)); - Status s; + absl::Status s; Node* node = graph.AddNode(node_def, &s); TF_CHECK_OK(s); return node; diff --git a/tensorflow/core/grappler/clusters/cluster.h b/tensorflow/core/grappler/clusters/cluster.h index f1c5f210f13a00..a3a3708cd3e164 100644 --- a/tensorflow/core/grappler/clusters/cluster.h +++ b/tensorflow/core/grappler/clusters/cluster.h @@ -52,13 +52,13 @@ class Cluster { // TensorFlow session successfully created. Returns an error otherwise. // There is no graceful degradation to handle the case where only a subset // of the requested resources are available. - virtual Status Provision() = 0; + virtual absl::Status Provision() = 0; // Attempts to shutdown the cluster. // Returns OK iff there are no pending calls to the Run() method and all the // resources used by the cluster could be released. Returns an error // otherwise. - virtual Status Shutdown() { return absl::OkStatus(); } + virtual absl::Status Shutdown() { return absl::OkStatus(); } // Whether soft placement is allowed. If allow_soft_placement is true, // an op will be placed on CPU if there's no GPU implementation for the OP @@ -106,14 +106,14 @@ class Cluster { // Enables collecting the allocator stats. If called, must be called before // Provision(). - virtual Status EnablePeakMemoryStats() { + virtual absl::Status EnablePeakMemoryStats() { return absl::UnimplementedError(strings ::StrCat( "Peak Memory Stats are not supported on ", type(), " clusters")); } // Returns peak memory of all devices during the session creation and session // runs. - virtual Status GetPeakMemoryUsage( + virtual absl::Status GetPeakMemoryUsage( std::unordered_map* device_peak_memory) const { return absl::UnimplementedError( "GetPeakMemoryUsage is not implemented for this type of cluster."); @@ -121,16 +121,16 @@ class Cluster { // Prepare the session to run the specified grappler item. This include // initializing all the model variables. - virtual Status Initialize(const GrapplerItem& item) = 0; + virtual absl::Status Initialize(const GrapplerItem& item) = 0; // Run the specified graph_def and return the corresponding metadata. - virtual Status Run(const GraphDef& graph_def, - const std::vector>& feed, - const std::vector& fetch, - RunMetadata* metadata) = 0; + virtual absl::Status Run(const GraphDef& graph_def, + const std::vector>& feed, + const std::vector& fetch, + RunMetadata* metadata) = 0; // Run the specified GrapplerItem and return the corresponding metadata. - virtual Status Run(const GrapplerItem& item, RunMetadata* metadata) { + virtual absl::Status Run(const GrapplerItem& item, RunMetadata* metadata) { return Run(item.graph, item.feed, item.fetch, metadata); } diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc index 92f17cc30d1a42..3fb2787f034e35 100644 --- a/tensorflow/core/grappler/clusters/single_machine.cc +++ b/tensorflow/core/grappler/clusters/single_machine.cc @@ -70,7 +70,7 @@ SingleMachine::~SingleMachine() { thread_pool_.reset(); } -Status SingleMachine::Provision() { +absl::Status SingleMachine::Provision() { // This is really ugly: to avoid leaking variables, we need to reset the tf // session every time we're done processing a grappler item. However, // variables are global, and therefore we can't have more than 1 session alive @@ -96,7 +96,7 @@ Status SingleMachine::Provision() { } TfDeviceId tf_device_id(parsed.id); PlatformDeviceId platform_device_id; - Status s = + absl::Status s = GpuIdManager::TfToPlatformDeviceId(tf_device_id, &platform_device_id); if (!s.ok()) { return absl::UnavailableError( @@ -123,7 +123,7 @@ Status SingleMachine::Provision() { return absl::OkStatus(); } -Status SingleMachine::Initialize(const GrapplerItem& item) { +absl::Status SingleMachine::Initialize(const GrapplerItem& item) { mutex_lock l(this->last_graph_mu_); if (last_graph_ != &item.graph || last_graph_id_ != item.id) { init_ops_ = item.init_ops; @@ -135,7 +135,7 @@ Status SingleMachine::Initialize(const GrapplerItem& item) { return absl::OkStatus(); } -Status SingleMachine::Shutdown() { +absl::Status SingleMachine::Shutdown() { TF_RETURN_IF_ERROR(ShutdownSession()); mutex_lock l(this->last_graph_mu_); @@ -145,10 +145,10 @@ Status SingleMachine::Shutdown() { return absl::OkStatus(); } -Status SingleMachine::Run(const GraphDef& graph_def, - const std::vector>& feed, - const std::vector& fetch, - RunMetadata* metadata) { +absl::Status SingleMachine::Run( + const GraphDef& graph_def, + const std::vector>& feed, + const std::vector& fetch, RunMetadata* metadata) { mutex_lock l(this->last_graph_mu_); if (last_graph_ != &graph_def) { TF_RETURN_IF_ERROR(ResetSession()); @@ -206,20 +206,20 @@ Status SingleMachine::Run(const GraphDef& graph_def, return absl::OkStatus(); } -Status SingleMachine::EnablePeakMemoryStats() { +absl::Status SingleMachine::EnablePeakMemoryStats() { EnableCPUAllocatorStats(); cpu_allocator_stats_enabled_ = true; // No need to enable GPU allocator stats since its stats are always collected. return absl::OkStatus(); } -Status SingleMachine::GetPeakMemoryUsage( +absl::Status SingleMachine::GetPeakMemoryUsage( std::unordered_map* device_peak_memory) const { // Cpu_allocator->TracksAllocationSizes() returns true doesn't always mean the // the AllocatorStats would be collected. if (!cpu_allocator_stats_enabled_) { - return Status(absl::StatusCode::kInvalidArgument, - "Tracking allocation for CPU is not enabled."); + return absl::Status(absl::StatusCode::kInvalidArgument, + "Tracking allocation for CPU is not enabled."); } const DeviceMgr* device_mgr; @@ -230,8 +230,8 @@ Status SingleMachine::GetPeakMemoryUsage( for (Device* device : devices) { auto* allocator = device->GetAllocator(AllocatorAttributes()); if (!allocator->TracksAllocationSizes()) { - return Status(absl::StatusCode::kInvalidArgument, - "Tracking allocation is not enabled."); + return absl::Status(absl::StatusCode::kInvalidArgument, + "Tracking allocation is not enabled."); } absl::optional stats = allocator->GetStats(); (*device_peak_memory)[device->name()] = @@ -241,13 +241,13 @@ Status SingleMachine::GetPeakMemoryUsage( return absl::OkStatus(); } -Status SingleMachine::RunWithTimeout( +absl::Status SingleMachine::RunWithTimeout( const std::vector>& feed, const std::vector& fetch, RunMetadata* run_metadata) { return RunWithTimeout(feed, fetch, run_metadata, timeout_s_); } -Status SingleMachine::RunWithTimeout( +absl::Status SingleMachine::RunWithTimeout( const std::vector>& feed, const std::vector& fetch, RunMetadata* run_metadata, int64_t timeout_s) { @@ -257,7 +257,7 @@ Status SingleMachine::RunWithTimeout( CHECK(!closing_); } - auto status = std::make_shared(); + auto status = std::make_shared(); auto local_metadata = std::make_shared(); const bool executed_in_time = ExecuteWithTimeout( [this, status, local_metadata, feed, fetch]() { @@ -274,7 +274,7 @@ Status SingleMachine::RunWithTimeout( return *status; } -Status SingleMachine::CloseSession(bool use_timeout) { +absl::Status SingleMachine::CloseSession(bool use_timeout) { if (!session_ || !thread_pool_) { return absl::OkStatus(); } @@ -320,7 +320,7 @@ Status SingleMachine::CloseSession(bool use_timeout) { return absl::OkStatus(); } -Status SingleMachine::ShutdownSession() { +absl::Status SingleMachine::ShutdownSession() { TF_RETURN_IF_ERROR(CloseSession(true /*use_timeout*/)); // Delete the threadpool: this ensures that all the pending closures complete @@ -346,7 +346,7 @@ Status SingleMachine::ShutdownSession() { return absl::OkStatus(); } -Status SingleMachine::ResetSession() { +absl::Status SingleMachine::ResetSession() { if (session_) { LOG(INFO) << "Cleaning up previous session"; @@ -444,12 +444,12 @@ void SingleMachine::MergeCosts(CostGraphDef* graph_costs, } } -Status SingleMachine::ClearAllocatorStats() const { +absl::Status SingleMachine::ClearAllocatorStats() const { // Cpu_allocator->TracksAllocationSizes() returns true doesn't always mean the // the AllocatorStats would be collected. if (!cpu_allocator_stats_enabled_) { - return Status(absl::StatusCode::kInvalidArgument, - "Tracking allocation for CPU is not enabled."); + return absl::Status(absl::StatusCode::kInvalidArgument, + "Tracking allocation for CPU is not enabled."); } const DeviceMgr* device_mgr; @@ -459,11 +459,11 @@ Status SingleMachine::ClearAllocatorStats() const { for (Device* device : devices) { auto* allocator = device->GetAllocator(AllocatorAttributes()); if (!allocator->TracksAllocationSizes()) { - return Status(absl::StatusCode::kInvalidArgument, - "Tracking allocation is not enabled."); + return absl::Status(absl::StatusCode::kInvalidArgument, + "Tracking allocation is not enabled."); } if (!allocator->ClearStats()) { - return Status( + return absl::Status( absl::StatusCode::kInvalidArgument, absl::StrCat("Clearing allocation stats is not supported for ", device->name())); diff --git a/tensorflow/core/grappler/clusters/single_machine.h b/tensorflow/core/grappler/clusters/single_machine.h index 2a78bce9146486..e049ca2fe09765 100644 --- a/tensorflow/core/grappler/clusters/single_machine.h +++ b/tensorflow/core/grappler/clusters/single_machine.h @@ -35,36 +35,38 @@ class SingleMachine : public Cluster { string type() const override { return "single_machine"; } - Status Provision() override; - Status Shutdown() override; + absl::Status Provision() override; + absl::Status Shutdown() override; - Status Initialize(const GrapplerItem& item) override; - Status Run(const GraphDef& item, - const std::vector>& feed, - const std::vector& fetch, RunMetadata* metadata) override; + absl::Status Initialize(const GrapplerItem& item) override; + absl::Status Run(const GraphDef& item, + const std::vector>& feed, + const std::vector& fetch, + RunMetadata* metadata) override; const DeviceSet* GetDeviceSet() const override { return device_set_.get(); } - Status EnablePeakMemoryStats() override; + absl::Status EnablePeakMemoryStats() override; // It requires EnableAllocatorStats(true) be called before Provision(). - Status GetPeakMemoryUsage( + absl::Status GetPeakMemoryUsage( std::unordered_map* device_peak_memory) const override; private: - Status RunWithTimeout(const std::vector>& feed, - const std::vector& fetch, - RunMetadata* run_metadata); - Status RunWithTimeout(const std::vector>& feed, - const std::vector& fetch, - RunMetadata* run_metadata, int64_t timeout_s); - Status ResetSession(); - Status CloseSession(bool use_timeout); - Status ShutdownSession(); + absl::Status RunWithTimeout( + const std::vector>& feed, + const std::vector& fetch, RunMetadata* run_metadata); + absl::Status RunWithTimeout( + const std::vector>& feed, + const std::vector& fetch, RunMetadata* run_metadata, + int64_t timeout_s); + absl::Status ResetSession(); + absl::Status CloseSession(bool use_timeout); + absl::Status ShutdownSession(); void MergeCosts(CostGraphDef* graph_costs, const CostGraphDef& init_costs, const CostGraphDef& queue_costs); - Status ClearAllocatorStats() const; + absl::Status ClearAllocatorStats() const; std::unique_ptr session_; std::vector queue_runner_defs_; diff --git a/tensorflow/core/grappler/clusters/single_machine_test.cc b/tensorflow/core/grappler/clusters/single_machine_test.cc index ac9f58fdcf54ee..99de8acf2404e4 100644 --- a/tensorflow/core/grappler/clusters/single_machine_test.cc +++ b/tensorflow/core/grappler/clusters/single_machine_test.cc @@ -246,9 +246,9 @@ TEST_F(SingleMachineTest, TimeOuts) { TF_CHECK_OK(cluster_->Initialize(item)); RunMetadata metadata; - Status s1 = cluster_->Run(item.graph, item.feed, item.fetch, &metadata); + absl::Status s1 = cluster_->Run(item.graph, item.feed, item.fetch, &metadata); EXPECT_TRUE(errors::IsDeadlineExceeded(s1)); - Status s2 = cluster_->Run(item.graph, item.feed, item.fetch, &metadata); + absl::Status s2 = cluster_->Run(item.graph, item.feed, item.fetch, &metadata); EXPECT_TRUE(errors::IsDeadlineExceeded(s2)); } @@ -334,7 +334,7 @@ static void RunInfiniteTFLoop() { TF_CHECK_OK(cluster.Provision()); TF_CHECK_OK(cluster.Initialize(item)); - Status s1 = cluster.Run(item.graph, item.feed, item.fetch, nullptr); + absl::Status s1 = cluster.Run(item.graph, item.feed, item.fetch, nullptr); if (!errors::IsDeadlineExceeded(s1)) { LOG(ERROR) << "Expected 'deadline exceeded' error, got " << s1; // Exit to break the infinite loop @@ -342,7 +342,7 @@ static void RunInfiniteTFLoop() { } // Attempt to shutdown the cluster and make sure we get the proper error code. - Status s2 = cluster.Shutdown(); + absl::Status s2 = cluster.Shutdown(); if (!errors::IsUnavailable(s2)) { LOG(ERROR) << "Expected 'unavailable' error, got " << s2; // Exit to break the infinite loop @@ -628,7 +628,7 @@ TEST_F(SingleMachineTest, PeakMemoryStatsNotEnabled) { TF_CHECK_OK(cluster.Initialize(item)); std::unordered_map device_peak_memory; - Status s = cluster.GetPeakMemoryUsage(&device_peak_memory); + absl::Status s = cluster.GetPeakMemoryUsage(&device_peak_memory); TF_CHECK_OK(cluster.Shutdown()); ASSERT_FALSE(s.ok()); EXPECT_TRUE(errors::IsInvalidArgument(s)); diff --git a/tensorflow/core/grappler/clusters/utils.cc b/tensorflow/core/grappler/clusters/utils.cc index 648a0668f6265a..685506ce99bfce 100644 --- a/tensorflow/core/grappler/clusters/utils.cc +++ b/tensorflow/core/grappler/clusters/utils.cc @@ -157,7 +157,7 @@ DeviceProperties GetDeviceInfo(const DeviceNameUtils::ParsedName& device) { if (device.has_id) { TfDeviceId tf_device_id(device.id); PlatformDeviceId platform_device_id; - Status s = + absl::Status s = GpuIdManager::TfToPlatformDeviceId(tf_device_id, &platform_device_id); if (!s.ok()) { LOG(ERROR) << s; diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.cc b/tensorflow/core/grappler/clusters/virtual_cluster.cc index 03dea8dd13a378..0f2b6a6d2fdfff 100644 --- a/tensorflow/core/grappler/clusters/virtual_cluster.cc +++ b/tensorflow/core/grappler/clusters/virtual_cluster.cc @@ -57,16 +57,15 @@ VirtualCluster::VirtualCluster(const DeviceSet* device_set) VirtualCluster::~VirtualCluster() {} -Status VirtualCluster::Provision() { return absl::OkStatus(); } +absl::Status VirtualCluster::Provision() { return absl::OkStatus(); } -Status VirtualCluster::Initialize(const GrapplerItem& item) { +absl::Status VirtualCluster::Initialize(const GrapplerItem& item) { return absl::OkStatus(); } -Status VirtualCluster::Run(const GraphDef& graph, - const std::vector>& feed, - const std::vector& fetch, - RunMetadata* metadata) { +absl::Status VirtualCluster::Run( + const GraphDef& graph, const std::vector>& feed, + const std::vector& fetch, RunMetadata* metadata) { GrapplerItem item; item.graph = graph; item.feed = feed; @@ -74,7 +73,8 @@ Status VirtualCluster::Run(const GraphDef& graph, return Run(item, metadata); } -Status VirtualCluster::Run(const GrapplerItem& item, RunMetadata* metadata) { +absl::Status VirtualCluster::Run(const GrapplerItem& item, + RunMetadata* metadata) { // Initializes an analytical cost estimator to estimate the graph cost. Makes // sure to use static shape inference to prevent the virtual scheduler from // calling the Run method on the cluster and creating an infinite loop. diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.h b/tensorflow/core/grappler/clusters/virtual_cluster.h index fa547a4ed2faf3..f42e1047ce2373 100644 --- a/tensorflow/core/grappler/clusters/virtual_cluster.h +++ b/tensorflow/core/grappler/clusters/virtual_cluster.h @@ -44,12 +44,13 @@ class VirtualCluster : public Cluster { string type() const override { return "virtual"; } - Status Provision() override; - Status Initialize(const GrapplerItem& item) override; - Status Run(const GraphDef& graph, - const std::vector>& feed, - const std::vector& fetch, RunMetadata* metadata) override; - Status Run(const GrapplerItem& item, RunMetadata* metadata) override; + absl::Status Provision() override; + absl::Status Initialize(const GrapplerItem& item) override; + absl::Status Run(const GraphDef& graph, + const std::vector>& feed, + const std::vector& fetch, + RunMetadata* metadata) override; + absl::Status Run(const GrapplerItem& item, RunMetadata* metadata) override; const DeviceSet* GetDeviceSet() const override { return device_set_; } private: diff --git a/tensorflow/core/grappler/clusters/virtual_cluster_test.cc b/tensorflow/core/grappler/clusters/virtual_cluster_test.cc index c922ba0817405b..a774b5e6ccc8af 100644 --- a/tensorflow/core/grappler/clusters/virtual_cluster_test.cc +++ b/tensorflow/core/grappler/clusters/virtual_cluster_test.cc @@ -116,7 +116,7 @@ TEST_F(VirtualClusterTest, OutOfMemory) { item.fetch.push_back("i2"); TF_CHECK_OK(cluster_->Initialize(item)); - Status s = cluster_->Run(item.graph, item.feed, item.fetch, nullptr); + absl::Status s = cluster_->Run(item.graph, item.feed, item.fetch, nullptr); EXPECT_EQ(error::RESOURCE_EXHAUSTED, s.code()); } diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index 59f08fedafae7d..9576fa5c2dc693 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -416,7 +416,6 @@ tf_cc_test( # py_proto_library( # name = "op_performance_data_py_pb2", # has_services = 0, -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":op_performance_data"], # ) diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc index 10bf5ea4fc1aae..a56b4a23d093e1 100644 --- a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc @@ -37,11 +37,11 @@ namespace grappler { namespace { // Helper function in PredictCosts() to add cost node to cost_graph. -Status AddCostNode(ReadyNodeManager* node_manager, const OpContext& op_context, - int node_id, const Costs& node_costs, - gtl::FlatMap* name_to_cost_node, - gtl::FlatMap* name_to_id, - CostGraphDef* cost_graph) { +absl::Status AddCostNode( + ReadyNodeManager* node_manager, const OpContext& op_context, int node_id, + const Costs& node_costs, + gtl::FlatMap* name_to_cost_node, + gtl::FlatMap* name_to_id, CostGraphDef* cost_graph) { const string& op_name = op_context.name; auto it = name_to_cost_node->find(op_name); CostGraphDef::Node* node; @@ -149,14 +149,14 @@ AnalyticalCostEstimator::AnalyticalCostEstimator( node_manager_.get(), std::move(placer)); } -Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) { +absl::Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) { item_ = &item; return absl::OkStatus(); } -Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph, - RunMetadata* run_metadata, - Costs* costs) const { +absl::Status AnalyticalCostEstimator::PredictCosts( + const GraphDef& optimized_graph, RunMetadata* run_metadata, + Costs* costs) const { std::unique_ptr item_storage; const GrapplerItem* item; // Many callers to PredictCosts() pass the same optimized_graph as was used @@ -209,7 +209,7 @@ Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph, // TODO(pcma): Add unit tests for generating CostGraphDef. if (cost_graph) { - Status s = + absl::Status s = AddCostNode(node_manager_.get(), op_context, node_id++, node_costs, &name_to_cost_node, &name_to_id, cost_graph); if (!s.ok()) { diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.h b/tensorflow/core/grappler/costs/analytical_cost_estimator.h index 8387a886c6784d..b31ce39ef6324a 100644 --- a/tensorflow/core/grappler/costs/analytical_cost_estimator.h +++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.h @@ -56,13 +56,14 @@ class AnalyticalCostEstimator : public CostEstimator { ~AnalyticalCostEstimator() override {} // This implementation always returns OK. - Status Initialize(const GrapplerItem& item) override; + absl::Status Initialize(const GrapplerItem& item) override; // Predict the performance of each node of the optimized graph and annotate // the RunMetadata with the corresponding estimates. Also returns the // expected cost for the whole graph. - Status PredictCosts(const GraphDef& optimized_graph, - RunMetadata* run_metadata, Costs* cost) const override; + absl::Status PredictCosts(const GraphDef& optimized_graph, + RunMetadata* run_metadata, + Costs* cost) const override; const VirtualScheduler* GetScheduler() const { return scheduler_.get(); } diff --git a/tensorflow/core/grappler/costs/cost_estimator.h b/tensorflow/core/grappler/costs/cost_estimator.h index 22b273986421ce..b133b3695df576 100644 --- a/tensorflow/core/grappler/costs/cost_estimator.h +++ b/tensorflow/core/grappler/costs/cost_estimator.h @@ -238,7 +238,7 @@ class CostEstimator { // Initializes the estimator for the specified grappler item. // The estimator shouldn't be used if this function returns any status other // that OK. - virtual Status Initialize(const GrapplerItem& item) = 0; + virtual absl::Status Initialize(const GrapplerItem& item) = 0; // Predicts the cost of running the given optimized version of the grappler // item. @@ -248,8 +248,9 @@ class CostEstimator { // overall cost of running the graph (e.g. the latency of the computation). // Returns a status that indicate is the performance could be estimated or // not. - virtual Status PredictCosts(const GraphDef& optimized_graph, - RunMetadata* run_metadata, Costs* cost) const = 0; + virtual absl::Status PredictCosts(const GraphDef& optimized_graph, + RunMetadata* run_metadata, + Costs* cost) const = 0; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/costs/graph_memory.cc b/tensorflow/core/grappler/costs/graph_memory.cc index 4099c2495edc00..e5e6638b07d622 100644 --- a/tensorflow/core/grappler/costs/graph_memory.cc +++ b/tensorflow/core/grappler/costs/graph_memory.cc @@ -31,13 +31,13 @@ limitations under the License. namespace tensorflow { namespace grappler { -Status GraphMemory::InferStatically( +absl::Status GraphMemory::InferStatically( const std::unordered_map& devices) { VirtualCluster cluster(devices); TF_RETURN_IF_ERROR(cluster.Provision()); TF_RETURN_IF_ERROR(cluster.Initialize(item_)); RunMetadata metadata; - Status s = cluster.Run(item_, &metadata); + absl::Status s = cluster.Run(item_, &metadata); // The virtual cluster returns the RESOURCE_EXHAUSTED error when it detects // that the model would run out of memory. We still get the metadata we need // out of the simulation, so we just ignore this error. @@ -48,7 +48,7 @@ Status GraphMemory::InferStatically( return absl::OkStatus(); } -Status GraphMemory::InferDynamically(Cluster* cluster) { +absl::Status GraphMemory::InferDynamically(Cluster* cluster) { if (!cluster->DetailedStatsEnabled()) { return errors::Unavailable("Detailed stats collection must be enabled"); } diff --git a/tensorflow/core/grappler/costs/graph_memory.h b/tensorflow/core/grappler/costs/graph_memory.h index 6e2520cf4f73e3..fcd9eaeba5a4e6 100644 --- a/tensorflow/core/grappler/costs/graph_memory.h +++ b/tensorflow/core/grappler/costs/graph_memory.h @@ -43,9 +43,9 @@ class GraphMemory { explicit GraphMemory(const GrapplerItem& item) : item_(item), unknown_usage_({-1, {}}) {} - Status InferStatically( + absl::Status InferStatically( const std::unordered_map& devices); - Status InferDynamically(Cluster* cluster); + absl::Status InferDynamically(Cluster* cluster); // Worst case memory usage in bytes, or -1 if the usage is unknown. If there // are multiple devices, returns the highest per device memory usage. diff --git a/tensorflow/core/grappler/costs/graph_memory_test.cc b/tensorflow/core/grappler/costs/graph_memory_test.cc index bcb20098575be9..0a290bf06c2e48 100644 --- a/tensorflow/core/grappler/costs/graph_memory_test.cc +++ b/tensorflow/core/grappler/costs/graph_memory_test.cc @@ -49,7 +49,7 @@ TEST_F(GraphMemoryTest, Basic) { item.feed.clear(); GraphMemory memory(item); - Status s = memory.InferStatically(devices_); + absl::Status s = memory.InferStatically(devices_); TF_CHECK_OK(s); const GraphMemory::MemoryUsage& mem_usage = memory.GetPeakMemoryUsage("/CPU:0"); @@ -77,7 +77,7 @@ TEST_F(GraphMemoryTest, UnknownBatchSize) { item.feed.clear(); GraphMemory memory(item); - Status s = memory.InferStatically(devices_); + absl::Status s = memory.InferStatically(devices_); TF_CHECK_OK(s); // Same maths as before, except that batch size is unknown and therefore // assumed to be one. @@ -104,7 +104,7 @@ TEST_F(GraphMemoryTest, MultiDevice) { item.feed.clear(); GraphMemory memory(item); - Status s = memory.InferStatically(devices_); + absl::Status s = memory.InferStatically(devices_); TF_CHECK_OK(s); const GraphMemory::MemoryUsage& cpu_mem = memory.GetPeakMemoryUsage("/CPU:0"); @@ -143,7 +143,7 @@ TEST_F(GraphMemoryTest, GpuSwapping) { { // Estimate the max memory usage for the graph. GraphMemory memory(item); - Status s = memory.InferStatically(devices_); + absl::Status s = memory.InferStatically(devices_); TF_CHECK_OK(s); const GraphMemory::MemoryUsage& gpu_mem = @@ -171,7 +171,7 @@ TEST_F(GraphMemoryTest, GpuSwapping) { } } GraphMemory memory(item); - Status s = memory.InferStatically(devices_); + absl::Status s = memory.InferStatically(devices_); TF_CHECK_OK(s); const GraphMemory::MemoryUsage& new_gpu_mem = memory.GetPeakMemoryUsage("/GPU:0"); @@ -207,7 +207,7 @@ TEST_F(GraphMemoryTest, CtrlDependencies) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); GraphMemory memory(item); - Status status = memory.InferStatically(devices_); + absl::Status status = memory.InferStatically(devices_); TF_CHECK_OK(status); const GraphMemory::MemoryUsage& mem = memory.GetPeakMemoryUsage("/CPU:0"); diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 600b829fbfa89a..613b12bb18ae3a 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -96,7 +96,7 @@ struct Processor { // Extract the shape or dim denoted by the handle. void ExtractValue(ShapeHandle h, ShapeHandle* result) { *result = h; } // Merge the shapes or dims. - Status Merge(ShapeHandle h1, ShapeHandle h2, ShapeHandle* result) { + absl::Status Merge(ShapeHandle h1, ShapeHandle h2, ShapeHandle* result) { if (InferenceContext::RankKnown(*result)) { // The result was initialized in a previous merge to a shape of known // rank, make sure we preserve that information. @@ -135,7 +135,7 @@ struct Processor { // Merge the dimensions d1 and d2. Return the known shape if there is one, // otherwise look for a symbolic shape. If there is no symbolic shape and no // known shape, the shape if fully unknown so return -1. - Status Merge(DimensionHandle d1, DimensionHandle d2, int64_t* result) { + absl::Status Merge(DimensionHandle d1, DimensionHandle d2, int64_t* result) { const int64_t dim1 = InferenceContext::Value(d1); const int64_t dim2 = InferenceContext::Value(d2); @@ -159,7 +159,7 @@ struct Processor { } private: - Status RefineDim(int64_t dim, int64_t* result) { + absl::Status RefineDim(int64_t dim, int64_t* result) { if (*result >= 0) { if (!(*result == dim || dim < 0)) { return errors::InvalidArgument("Inconsistent dimensions detected"); @@ -187,7 +187,7 @@ class DisjointSet { } } - Status Merge(Handle x, Handle y); + absl::Status Merge(Handle x, Handle y); const typename HandleToObject::Object GetMergedValue(Handle value); private: @@ -225,7 +225,7 @@ DisjointSet::GetMergedValue(Handle value) { } template -Status DisjointSet::Merge(Handle x, Handle y) { +absl::Status DisjointSet::Merge(Handle x, Handle y) { Rep* x_root = Find(x); Rep* y_root = Find(y); @@ -785,7 +785,7 @@ class SymbolicShapeRefiner { // // In the event of an error, UpdateNode will simply set `function_node`'s // output shape to be Unknown. - Status UpdateFunction(const NodeDef* function_node) { + absl::Status UpdateFunction(const NodeDef* function_node) { NameAttrList function; TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(*function_node, &function)); auto it = fun_to_grappler_function_item_.find(function.name()); @@ -988,7 +988,7 @@ class SymbolicShapeRefiner { // Prepares input shapes/values/handles, then runs shape inference, and // finally sets output shapes/values/handles. - Status UpdateNode(const NodeDef* node, bool* refined) { + absl::Status UpdateNode(const NodeDef* node, bool* refined) { NodeContext* ctx = GetNodeContext(node); if (ctx == nullptr) { TF_RETURN_IF_ERROR(AddNode(node)); @@ -1131,7 +1131,7 @@ class SymbolicShapeRefiner { return InferShapes(*node, ctx); } - Status SetUnknownShape(const NodeDef* node, int output_port) { + absl::Status SetUnknownShape(const NodeDef* node, int output_port) { shape_inference::ShapeHandle shape = GetUnknownOutputShape(node, output_port); InferenceContext* ctx = GetContext(node); @@ -1305,8 +1305,8 @@ class SymbolicShapeRefiner { return true; } - Status AddFunction(const NodeDef* function_node, - const std::string& function_name) { + absl::Status AddFunction(const NodeDef* function_node, + const std::string& function_name) { auto it = fun_to_grappler_function_item_.find(function_name); if (it != fun_to_grappler_function_item_.end()) { return absl::OkStatus(); @@ -1315,7 +1315,7 @@ class SymbolicShapeRefiner { const FunctionDef* function_def = CHECK_NOTNULL(function_library_.Find(function_name)); GrapplerFunctionItem grappler_function_item; - Status function_instantiated = + absl::Status function_instantiated = MakeGrapplerFunctionItem(*function_def, function_library_, graph_def_version_, &grappler_function_item); @@ -1351,7 +1351,7 @@ class SymbolicShapeRefiner { return absl::OkStatus(); } - Status AddNode(const NodeDef* node) { + absl::Status AddNode(const NodeDef* node) { NodeContext& node_ctx = node_to_context_[node]; NameAttrList function; TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(*node, &function)); @@ -1380,7 +1380,7 @@ class SymbolicShapeRefiner { graph_def_version_, *node, node_ctx.op_data->op_def, input_shapes, input_tensors, input_tensors_as_shapes, std::move(input_handle_shapes_and_types))); - const Status s = node_ctx.inference_context->construction_status(); + const absl::Status s = node_ctx.inference_context->construction_status(); if (!s.ok()) { node_ctx.inference_context.reset(nullptr); } @@ -1581,7 +1581,8 @@ class SymbolicShapeRefiner { // Run a node to infer output shapes and values, and add it to the // NodeContext. - Status UpdateOutputShapesAndValues(const NodeDef& node, NodeContext* c) { + absl::Status UpdateOutputShapesAndValues(const NodeDef& node, + NodeContext* c) { InferenceContext* ic = c->inference_context.get(); // Input to EvaluateNode() @@ -1636,8 +1637,8 @@ class SymbolicShapeRefiner { // Currently only handle nodes with static shapes, i.e. shapes do not change // during execution. // TODO(andiryxu): Use annotated shapes in Enter/Merge etc as well. - Status UpdateOutputShapesUsingAnnotatedInformation(const NodeDef& node, - NodeContext* c) const { + absl::Status UpdateOutputShapesUsingAnnotatedInformation( + const NodeDef& node, NodeContext* c) const { const auto& attr = node.attr(); if (attr.count(kOutputSame) == 0 || !attr.at(kOutputSame).b() || attr.count(kOutputShapes) == 0) @@ -1697,8 +1698,8 @@ class SymbolicShapeRefiner { return absl::OkStatus(); } - Status MaybeUpdateNodeContextOutput(const NodeDef& node, const bool is_fed, - NodeContext* c) { + absl::Status MaybeUpdateNodeContextOutput(const NodeDef& node, + const bool is_fed, NodeContext* c) { // Propagate tensors and shape tensors unless the node is fed. // TODO(bsteiner) We should still propagate the shapes to the ports that // aren't fed in the case of a ShapeN node. @@ -1918,7 +1919,7 @@ class SymbolicShapeRefiner { return absl::OkStatus(); } - Status InferShapes(const NodeDef& node, NodeContext* c) { + absl::Status InferShapes(const NodeDef& node, NodeContext* c) { // Infer the shapes of output tensors. if (!c->op_data || c->op_data->shape_inference_fn == nullptr || !c->inference_context->Run(c->op_data->shape_inference_fn).ok()) { @@ -1929,7 +1930,7 @@ class SymbolicShapeRefiner { TF_RETURN_IF_ERROR( c->inference_context->Run(shape_inference::UnknownShape)); } - Status status = absl::OkStatus(); + absl::Status status = absl::OkStatus(); auto it = fed_ports_.find(node.name()); const bool is_fed = it != fed_ports_.end(); if (is_fed) { @@ -2071,7 +2072,7 @@ class SymbolicShapeManager { public: SymbolicShapeManager() {} - Status Merge(ShapeHandle s1, ShapeHandle s2) { + absl::Status Merge(ShapeHandle s1, ShapeHandle s2) { if (!s1.IsSet() || !s2.IsSet()) { return absl::OkStatus(); } @@ -2085,7 +2086,7 @@ class SymbolicShapeManager { } return absl::OkStatus(); } - Status Merge(DimensionHandle d1, DimensionHandle d2) { + absl::Status Merge(DimensionHandle d1, DimensionHandle d2) { if (!d1.IsSet() || !d2.IsSet()) { return absl::OkStatus(); } @@ -2137,9 +2138,9 @@ class SymbolicShapeManager { // Checks whether there is any conflict in merged shapes and dims in // SymbolicShapeManager. -Status ValidateSymbolicShapeManager(const GraphDef& graph_def, - SymbolicShapeRefiner* refiner, - SymbolicShapeManager* shape_manager) { +absl::Status ValidateSymbolicShapeManager(const GraphDef& graph_def, + SymbolicShapeRefiner* refiner, + SymbolicShapeManager* shape_manager) { if (!VLOG_IS_ON(1)) { return absl::OkStatus(); } @@ -2186,9 +2187,9 @@ Status ValidateSymbolicShapeManager(const GraphDef& graph_def, } // Log shape inference and its merged shapes. -Status VerboseShapeInferenceLogging(const GraphDef& graph_def, - SymbolicShapeRefiner* refiner, - SymbolicShapeManager* shape_manager) { +absl::Status VerboseShapeInferenceLogging(const GraphDef& graph_def, + SymbolicShapeRefiner* refiner, + SymbolicShapeManager* shape_manager) { // As logging all the nodes would generate too many lines, we by default // skip this detailed logging. Users may add nodes of interest to // node_names_for_logging to enable detailed logging. @@ -2234,7 +2235,7 @@ Status VerboseShapeInferenceLogging(const GraphDef& graph_def, return absl::OkStatus(); } -Status GraphProperties::RelaxEnqueueShapesAndMergeTypes( +absl::Status GraphProperties::RelaxEnqueueShapesAndMergeTypes( SymbolicShapeRefiner* shape_refiner, const NodeDef* qnode, const std::vector& shapes_and_types, std::vector* queue_shapes_and_types) { @@ -2259,9 +2260,9 @@ Status GraphProperties::RelaxEnqueueShapesAndMergeTypes( // Compute the output shape of the merge node as the union of the available // input shapes. -Status GraphProperties::UpdateMerge(SymbolicShapeRefiner* shape_refiner, - const NodeDef* node, - bool* new_shapes) const { +absl::Status GraphProperties::UpdateMerge(SymbolicShapeRefiner* shape_refiner, + const NodeDef* node, + bool* new_shapes) const { InferenceContext* ic = shape_refiner->GetContext(node); if (!ic) { // Now we can run shape inference @@ -2312,8 +2313,9 @@ Status GraphProperties::UpdateMerge(SymbolicShapeRefiner* shape_refiner, } // Manually propagate the input shape for Enter nodes. -Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner, - const NodeDef* node, bool* new_shapes) { +absl::Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner, + const NodeDef* node, + bool* new_shapes) { InferenceContext* ic = shape_refiner->GetContext(node); if (!ic) { TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(node, new_shapes)); @@ -2339,7 +2341,7 @@ Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner, return absl::OkStatus(); } -Status GraphProperties::UpdateShapes( +absl::Status GraphProperties::UpdateShapes( SymbolicShapeRefiner* shape_refiner, const absl::flat_hash_map& resource_handles, const NodeDef* n, bool* new_shapes) const { @@ -2368,7 +2370,7 @@ Status GraphProperties::UpdateShapes( } // Propagates the shapes in the transitive fan-out of . -Status GraphProperties::PropagateShapes( +absl::Status GraphProperties::PropagateShapes( SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes, const absl::flat_hash_map& resource_handles, int num_loops) const { @@ -2420,9 +2422,9 @@ Status GraphProperties::PropagateShapes( return absl::OkStatus(); } -Status GraphProperties::UpdateQueue(const NodeDef* queue_node, - SymbolicShapeRefiner* shape_refiner, - bool* new_shapes) { +absl::Status GraphProperties::UpdateQueue(const NodeDef* queue_node, + SymbolicShapeRefiner* shape_refiner, + bool* new_shapes) { auto* ctx = shape_refiner->GetNodeContext(queue_node); if (!ctx) { TF_RETURN_IF_ERROR(shape_refiner->AddNode(queue_node)); @@ -2467,7 +2469,7 @@ Status GraphProperties::UpdateQueue(const NodeDef* queue_node, return shape_refiner->UpdateNode(queue_node, &dummy_new_shapes); } -Status GraphProperties::UpdateEnqueue( +absl::Status GraphProperties::UpdateEnqueue( const NodeDef* enqueue_node, const absl::flat_hash_map& resource_handles, SymbolicShapeRefiner* shape_refiner, bool* new_shapes) { @@ -2514,10 +2516,9 @@ Status GraphProperties::UpdateEnqueue( return absl::OkStatus(); } -Status GraphProperties::InferStatically(bool assume_valid_feeds, - bool aggressive_shape_inference, - bool include_input_tensor_values, - bool include_output_tensor_values) { +absl::Status GraphProperties::InferStatically( + bool assume_valid_feeds, bool aggressive_shape_inference, + bool include_input_tensor_values, bool include_output_tensor_values) { FunctionLibraryDefinition function_library(OpRegistry::Global(), item_.graph.library()); absl::flat_hash_map> fed_ports; @@ -2589,7 +2590,8 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds, } std::vector topo_order; - Status s = ComputeTopologicalOrder(item_.graph, extra_deps, &topo_order); + absl::Status s = + ComputeTopologicalOrder(item_.graph, extra_deps, &topo_order); if (!s.ok()) { if (extra_deps.empty()) { return s; @@ -2759,7 +2761,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds, return absl::OkStatus(); } -Status GraphProperties::InferDynamically(Cluster* cluster) { +absl::Status GraphProperties::InferDynamically(Cluster* cluster) { TF_RETURN_IF_ERROR(cluster->Initialize(item_)); // Runs the model once to collect the shapes in the cost model. @@ -2770,7 +2772,8 @@ Status GraphProperties::InferDynamically(Cluster* cluster) { return InferFromCostGraph(metadata.cost_graph()); } -Status GraphProperties::AnnotateOutputShapes(GraphDef* output_graph_def) const { +absl::Status GraphProperties::AnnotateOutputShapes( + GraphDef* output_graph_def) const { *output_graph_def = item_.graph; for (int i = 0; i < output_graph_def->node_size(); i++) { auto node = output_graph_def->mutable_node(i); @@ -2786,7 +2789,8 @@ Status GraphProperties::AnnotateOutputShapes(GraphDef* output_graph_def) const { return absl::OkStatus(); } -Status GraphProperties::InferFromCostGraph(const CostGraphDef& cost_graph) { +absl::Status GraphProperties::InferFromCostGraph( + const CostGraphDef& cost_graph) { if (cost_graph.node_size() == 0) { LOG(WARNING) << "cost_graph is empty: nothing can be inferred!"; } diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h index fbe8096d40e3b7..1d9575e1e5c805 100644 --- a/tensorflow/core/grappler/costs/graph_properties.h +++ b/tensorflow/core/grappler/costs/graph_properties.h @@ -96,33 +96,33 @@ class GraphProperties { // will included in the input properties. // If include_output_tensor_values is true, the values of constant tensors // will be included in the output properties. - Status InferStatically(bool assume_valid_feeds, - bool aggressive_shape_inference, - bool include_input_tensor_values, - bool include_output_tensor_values); - Status InferStatically(bool assume_valid_feeds, - bool aggressive_shape_inference, - bool include_tensor_values) { + absl::Status InferStatically(bool assume_valid_feeds, + bool aggressive_shape_inference, + bool include_input_tensor_values, + bool include_output_tensor_values); + absl::Status InferStatically(bool assume_valid_feeds, + bool aggressive_shape_inference, + bool include_tensor_values) { return InferStatically( assume_valid_feeds, /*aggressive_shape_inference=*/aggressive_shape_inference, /*include_input_tensor_values=*/include_tensor_values, /*include_output_tensor_values=*/include_tensor_values); } - Status InferStatically(bool assume_valid_feeds) { + absl::Status InferStatically(bool assume_valid_feeds) { return InferStatically(assume_valid_feeds, /*aggressive_shape_inference=*/false, /*include_tensor_values=*/true); } // Infer the shape by running the graph on the specified cluster and recording // the shapes of the processed tensors. - Status InferDynamically(Cluster* cluster); + absl::Status InferDynamically(Cluster* cluster); // Extract the properties from a cost graph. For testing only since there is // no way to ensure that the cost graph match the item. - Status InferFromCostGraph(const CostGraphDef& cost_graph); + absl::Status InferFromCostGraph(const CostGraphDef& cost_graph); // Stores `item_.graph` with the inferred output shapes to `output_graph_def`. - Status AnnotateOutputShapes(GraphDef* output_graph_def) const; + absl::Status AnnotateOutputShapes(GraphDef* output_graph_def) const; // Return the properties of node inputs/outputs, including data types and // shapes. Note that the dimensions in the shapes can be negative. We use the @@ -161,40 +161,41 @@ class GraphProperties { private: // Relaxes shapes , determined from an EnqueueV2 node, into // <*queue_shapes_and_types>. - static Status RelaxEnqueueShapesAndMergeTypes( + static absl::Status RelaxEnqueueShapesAndMergeTypes( SymbolicShapeRefiner* shape_refiner, const NodeDef* qnode, const std::vector& shapes_and_types, std::vector* queue_shapes_and_types); // Update the shapes of the enqueue node, port them over to the corresponding // queue, and schedule the reprocessing of the queue if needed. - static Status UpdateEnqueue( + static absl::Status UpdateEnqueue( const NodeDef* enqueue_node, const absl::flat_hash_map& resource_handles, SymbolicShapeRefiner* shape_refiner, bool* new_shapes); // Update the shapes and types of the Queue node, if not set by Enqueue node. - static Status UpdateQueue(const NodeDef* queue_node, - SymbolicShapeRefiner* shape_refiner, - bool* new_shapes); + static absl::Status UpdateQueue(const NodeDef* queue_node, + SymbolicShapeRefiner* shape_refiner, + bool* new_shapes); // Update the output shapes of a Merge node, and enqueue its fanout in // new_shapes if needed. - Status UpdateMerge(SymbolicShapeRefiner* shape_refiner, const NodeDef* node, - bool* new_shapes) const; + absl::Status UpdateMerge(SymbolicShapeRefiner* shape_refiner, + const NodeDef* node, bool* new_shapes) const; // Process the Enter node, and enqueue its fanout in new_shapes if needed. - static Status UpdateEnter(SymbolicShapeRefiner* shape_refiner, - const NodeDef* node, bool* new_shapes); + static absl::Status UpdateEnter(SymbolicShapeRefiner* shape_refiner, + const NodeDef* node, bool* new_shapes); // Update the shapes for node 'n'. If output shapes for n have changed, // enqueue its fanout in 'new_shapes'. - Status UpdateShapes(SymbolicShapeRefiner* shape_refiner, - const absl::flat_hash_map& - resource_handles, - const NodeDef* n, bool* new_shapes) const; + absl::Status UpdateShapes( + SymbolicShapeRefiner* shape_refiner, + const absl::flat_hash_map& + resource_handles, + const NodeDef* n, bool* new_shapes) const; // Propagate the shapes for the nodes enqueued in new_shapes and their // transitive fanout until a fixed point is reached. - Status PropagateShapes( + absl::Status PropagateShapes( SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes, const absl::flat_hash_map& resource_handles, diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index c64a322d4b9dc3..53fe8bef1f6f85 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -157,7 +157,7 @@ TEST_F(GraphPropertiesTest, StaticProperties) { CHECK(fake_input.NextItem(&item)); GraphProperties properties(item); - Status s = properties.InferStatically(true); + absl::Status s = properties.InferStatically(true); TF_ASSERT_OK(s); for (const auto& node : item.graph.node()) { @@ -198,7 +198,7 @@ TEST_F(GraphPropertiesTest, ClearProperties) { CHECK(fake_input.NextItem(&item)); GraphProperties properties(item); - Status s = properties.InferStatically(true); + absl::Status s = properties.InferStatically(true); TF_ASSERT_OK(s); for (const auto& node : item.graph.node()) { @@ -225,7 +225,7 @@ TEST_F(GraphPropertiesTest, Clear) { CHECK(fake_input.NextItem(&item)); GraphProperties properties(item); - Status s = properties.InferStatically(true); + absl::Status s = properties.InferStatically(true); TF_ASSERT_OK(s); EXPECT_TRUE(properties.has_properties()); @@ -241,7 +241,7 @@ TEST_F(GraphPropertiesTest, DynamicProperties) { GraphProperties properties(item); TF_ASSERT_OK(cluster_->Initialize(item)); - Status s = properties.InferDynamically(cluster_.get()); + absl::Status s = properties.InferDynamically(cluster_.get()); TF_ASSERT_OK(s); for (const auto& node : item.graph.node()) { @@ -2066,7 +2066,7 @@ TEST_F(GraphPropertiesTest, FedNodes) { { // Conservative shape analysis: the shape of fed ports should be unknown GraphProperties properties(item); - Status s = properties.InferStatically(false); + absl::Status s = properties.InferStatically(false); TF_ASSERT_OK(s); for (const auto& node : item.graph.node()) { if (node.op() == "Const") { @@ -2097,7 +2097,7 @@ TEST_F(GraphPropertiesTest, FedNodes) { // Optimistic shape analysis: the shape of fed ports should be derived from // the shape of the fanin. GraphProperties properties(item); - Status s = properties.InferStatically(true); + absl::Status s = properties.InferStatically(true); TF_ASSERT_OK(s); for (const auto& node : item.graph.node()) { if (node.op() == "Square" || node.op() == "AddN") { diff --git a/tensorflow/core/grappler/costs/measuring_cost_estimator.cc b/tensorflow/core/grappler/costs/measuring_cost_estimator.cc index e7c330cbe22e6c..43c4a26fb832bc 100644 --- a/tensorflow/core/grappler/costs/measuring_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/measuring_cost_estimator.cc @@ -44,15 +44,15 @@ MeasuringCostEstimator::MeasuringCostEstimator(Cluster* cluster, cluster_ = cluster; } -Status MeasuringCostEstimator::Initialize(const GrapplerItem& item) { +absl::Status MeasuringCostEstimator::Initialize(const GrapplerItem& item) { feed_ = item.feed; fetch_ = item.fetch; return cluster_->Initialize(item); } -Status MeasuringCostEstimator::PredictCosts(const GraphDef& optimized_graph, - RunMetadata* run_metadata, - Costs* costs) const { +absl::Status MeasuringCostEstimator::PredictCosts( + const GraphDef& optimized_graph, RunMetadata* run_metadata, + Costs* costs) const { CostGraphDef* cost_graph = nullptr; if (run_metadata) { cost_graph = run_metadata->mutable_cost_graph(); @@ -63,13 +63,13 @@ Status MeasuringCostEstimator::PredictCosts(const GraphDef& optimized_graph, BlockingCounter barrier(measurement_steps_); mutex status_mu; - Status status; + absl::Status status; auto measurement_fn = [&](const int step) { const Costs::MicroSeconds start = Env::Default()->NowMicros(); RunMetadata metadata; - const Status local_status = + const absl::Status local_status = cluster_->Run(optimized_graph, feed_, fetch_, &metadata); { mutex_lock lock(status_mu); diff --git a/tensorflow/core/grappler/costs/measuring_cost_estimator.h b/tensorflow/core/grappler/costs/measuring_cost_estimator.h index 67145f5241ef8a..5da9bac98538a4 100644 --- a/tensorflow/core/grappler/costs/measuring_cost_estimator.h +++ b/tensorflow/core/grappler/costs/measuring_cost_estimator.h @@ -52,14 +52,15 @@ class MeasuringCostEstimator : public CostEstimator { // Initializes the estimator for the specified grappler item. // This implementation always returns OK. - Status Initialize(const GrapplerItem& item) override; + absl::Status Initialize(const GrapplerItem& item) override; // Runs the optimized version of the graph on the cluster, measures // the runtimes of each operation, and annotates the CostGraphDef of // RunMetadata with the corresponding measurements. // Returns the average latency for the whole graph. - Status PredictCosts(const GraphDef& optimized_graph, - RunMetadata* run_metadata, Costs* cost) const override; + absl::Status PredictCosts(const GraphDef& optimized_graph, + RunMetadata* run_metadata, + Costs* cost) const override; private: Cluster* cluster_; // Not owned. diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc index 42cfafa2cd33c4..b869ac8a09bf56 100644 --- a/tensorflow/core/grappler/costs/utils.cc +++ b/tensorflow/core/grappler/costs/utils.cc @@ -136,7 +136,7 @@ static void ExtractExtraProperties( Env* env = Env::Default(); FileStatistics stat; - Status s = env->Stat(filename, &stat); + absl::Status s = env->Stat(filename, &stat); if (!s.ok()) { continue; } @@ -250,7 +250,7 @@ DeviceProperties GetDeviceInfo(const string& device_str) { if (parsed.type == "GPU") { TfDeviceId tf_device_id(parsed.id); PlatformDeviceId platform_device_id; - Status s = + absl::Status s = GpuIdManager::TfToPlatformDeviceId(tf_device_id, &platform_device_id); if (!s.ok()) { // We are probably running simulation without linking cuda libraries. diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 43941e62226648..8ff152dc83fc44 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -169,7 +169,7 @@ HeapReadyManager::HeapReadyManager() : ReadyNodeManager() { std::make_heap(nodes_.begin(), nodes_.end()); } -Status HeapReadyManager::Init( +absl::Status HeapReadyManager::Init( const std::unordered_map* node_map) { // Resets the node state since different instances of the scheduler can reuse // the same node_manager. @@ -266,7 +266,7 @@ void PriorityReadyManager::AddNode(const NodeDef* node) { HeapReadyManager::AddNode(node); } -Status PriorityReadyManager::SetPriority( +absl::Status PriorityReadyManager::SetPriority( const std::unordered_map& node_priority) { node_priority_ = node_priority; return absl::OkStatus(); @@ -275,7 +275,7 @@ Status PriorityReadyManager::SetPriority( CompositeNodeManager::CompositeNodeManager() : ReadyNodeManager(), send_manager_(), recv_manager_() {} -Status CompositeNodeManager::Init( +absl::Status CompositeNodeManager::Init( const std::unordered_map* node_map) { node_map_ = node_map; TF_RETURN_IF_ERROR(send_manager_.Init(node_map)); @@ -403,9 +403,9 @@ SchedulerState::SchedulerState(const bool use_static_shapes, track_mem_usage_snapshot_ = VLOG_IS_ON(1); } -Status SchedulerState::Init(const GrapplerItem* item, - std::vector* initial_nodes, - bool create_explicit_channel_device) { +absl::Status SchedulerState::Init(const GrapplerItem* item, + std::vector* initial_nodes, + bool create_explicit_channel_device) { initialized_ = false; // Clear all internal states so that the SchedulerState is reusable for @@ -1398,7 +1398,7 @@ VirtualScheduler::VirtualScheduler( std::unique_ptr scheduler_state) : scheduler_state_(std::move(scheduler_state)), ready_nodes_(ready_nodes) {} -Status VirtualScheduler::Init(const GrapplerItem* item) { +absl::Status VirtualScheduler::Init(const GrapplerItem* item) { // SchedulerState::Init() preprocesses the input grappler_item and // graph_properties to extract necessary information for emulating tensorflow // op scheduling and construct internal data structures (NodeState and diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index 12aaa1ea7da325..f574832b1857c8 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -172,7 +172,7 @@ class ReadyNodeManager { public: ReadyNodeManager() {} virtual ~ReadyNodeManager() {} - virtual Status Init( + virtual absl::Status Init( const std::unordered_map* node_map) { return absl::OkStatus(); } @@ -226,7 +226,7 @@ class LIFOManager : public ReadyNodeManager { class HeapReadyManager : public ReadyNodeManager { public: HeapReadyManager(); - Status Init( + absl::Status Init( const std::unordered_map* node_map) override; ~HeapReadyManager() override {} void AddNode(const NodeDef* node) override; @@ -274,7 +274,8 @@ class PriorityReadyManager : public HeapReadyManager { void AddNode(const NodeDef* node) override; // Note this should be called after Init(). - Status SetPriority(const std::unordered_map& node_priority); + absl::Status SetPriority( + const std::unordered_map& node_priority); protected: std::function Greater() override; @@ -296,7 +297,7 @@ class CompositeNodeManager : public ReadyNodeManager { CompositeNodeManager(); ~CompositeNodeManager() override {} - Status Init( + absl::Status Init( const std::unordered_map* node_map) override; void AddNode(const NodeDef* node) override; const NodeDef* GetCurrNode() override; @@ -352,9 +353,9 @@ class SchedulerState { // initial_nodes is the set of nodes (primary inputs) discovered by Init() // which may be added by a ReadyNodeManager (or related/derivative scheduler) // to begin node schedule and graph simulation. - Status Init(const GrapplerItem* item, - std::vector* initial_nodes, - bool create_explicit_channel_device = true); + absl::Status Init(const GrapplerItem* item, + std::vector* initial_nodes, + bool create_explicit_channel_device = true); virtual Costs Summary() const; // Like the above, but writes detailed stats to RunMetadata. @@ -487,7 +488,7 @@ class VirtualScheduler { // This function should be called at least once after the scheduler is // constructed. An uninitialized or failed-to-initialize scheduler will cause // undefined behavior. - virtual Status Init(const GrapplerItem* item); + virtual absl::Status Init(const GrapplerItem* item); // Gets the current scheduled node for execution; the caller of this function // can accordingly simulate the execution of the current scheduled node. diff --git a/tensorflow/core/grappler/graph_topology_view.cc b/tensorflow/core/grappler/graph_topology_view.cc index 7dfb15c2c07f6a..a6cc0dc29fa91f 100644 --- a/tensorflow/core/grappler/graph_topology_view.cc +++ b/tensorflow/core/grappler/graph_topology_view.cc @@ -38,7 +38,7 @@ inline void SortAndRemoveDuplicates(T* v) { } // namespace -Status GraphTopologyView::InitializeFromGraph( +absl::Status GraphTopologyView::InitializeFromGraph( const GraphDef& graph, const absl::Span ephemeral_edges, bool ignore_control_edges) { @@ -140,20 +140,20 @@ Status GraphTopologyView::InitializeFromGraph( return absl::OkStatus(); } -Status GraphTopologyView::InitializeFromGraph( +absl::Status GraphTopologyView::InitializeFromGraph( const GraphDef& graph, const absl::Span ephemeral_edges) { return InitializeFromGraph(graph, ephemeral_edges, /*ignore_control_edges=*/false); } -Status GraphTopologyView::InitializeFromGraph(const GraphDef& graph, - bool ignore_control_edges) { +absl::Status GraphTopologyView::InitializeFromGraph(const GraphDef& graph, + bool ignore_control_edges) { return InitializeFromGraph(graph, absl::Span(), ignore_control_edges); } -Status GraphTopologyView::InitializeFromGraph(const GraphDef& graph) { +absl::Status GraphTopologyView::InitializeFromGraph(const GraphDef& graph) { return InitializeFromGraph(graph, absl::Span(), /*ignore_control_edges*/ false); } diff --git a/tensorflow/core/grappler/graph_topology_view.h b/tensorflow/core/grappler/graph_topology_view.h index cdb2eeb92a9726..91cbfa2a1cef9a 100644 --- a/tensorflow/core/grappler/graph_topology_view.h +++ b/tensorflow/core/grappler/graph_topology_view.h @@ -55,13 +55,14 @@ class GraphTopologyView { // computing graph topology. Example: Tensorflow runtime allows concurrent // execution of dequeue/enqueue ops from the same queue resource, but we might // want to enforce ordering between them for the purpose of graph analysis. - Status InitializeFromGraph(const GraphDef& graph, - absl::Span ephemeral_edges, - bool ignore_control_edges); - Status InitializeFromGraph(const GraphDef& graph, - absl::Span ephemeral_edges); - Status InitializeFromGraph(const GraphDef& graph, bool ignore_control_edges); - Status InitializeFromGraph(const GraphDef& graph); + absl::Status InitializeFromGraph( + const GraphDef& graph, absl::Span ephemeral_edges, + bool ignore_control_edges); + absl::Status InitializeFromGraph( + const GraphDef& graph, absl::Span ephemeral_edges); + absl::Status InitializeFromGraph(const GraphDef& graph, + bool ignore_control_edges); + absl::Status InitializeFromGraph(const GraphDef& graph); bool is_initialized() const { return graph_ != nullptr; } int num_nodes() const { return num_nodes_; } diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h index 6c3aa407bd91df..4b7e8cfe71a798 100644 --- a/tensorflow/core/grappler/graph_view.h +++ b/tensorflow/core/grappler/graph_view.h @@ -320,7 +320,7 @@ class GraphViewInternal { protected: explicit GraphViewInternal(GraphDefT* graph) : graph_(graph) {} - Status AddUniqueNode(NodeDefT* node) { + absl::Status AddUniqueNode(NodeDefT* node) { auto inserted = nodes_.emplace(node->name(), node); return inserted.second ? absl::OkStatus() @@ -330,7 +330,7 @@ class GraphViewInternal { // TODO(ezhulenev): Remove this function. void AddUniqueNodeOrDie(NodeDefT* node) { - Status st = AddUniqueNode(node); + absl::Status st = AddUniqueNode(node); CHECK(st.ok()) << st.message(); } diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc index 17e01a67bf3793..0143b66c5ac228 100644 --- a/tensorflow/core/grappler/grappler_item.cc +++ b/tensorflow/core/grappler/grappler_item.cc @@ -171,7 +171,7 @@ const std::unordered_set& GrapplerItem::devices() const { return devices_; } -Status GrapplerItem::AddDevice(const string& device) { +absl::Status GrapplerItem::AddDevice(const string& device) { DeviceNameUtils::ParsedName name; if (!DeviceNameUtils::ParseFullName(device, &name)) { @@ -187,10 +187,10 @@ Status GrapplerItem::AddDevice(const string& device) { return absl::OkStatus(); } -Status GrapplerItem::AddDevices(const GrapplerItem& other) { +absl::Status GrapplerItem::AddDevices(const GrapplerItem& other) { std::vector invalid_devices; for (const string& device : other.devices()) { - Status added = AddDevice(device); + absl::Status added = AddDevice(device); if (!added.ok()) invalid_devices.emplace_back(device); } return invalid_devices.empty() @@ -200,10 +200,10 @@ Status GrapplerItem::AddDevices(const GrapplerItem& other) { "]"); } -Status GrapplerItem::InferDevicesFromGraph() { +absl::Status GrapplerItem::InferDevicesFromGraph() { absl::flat_hash_set invalid_devices; for (const NodeDef& node : graph.node()) { - Status added = AddDevice(node.device()); + absl::Status added = AddDevice(node.device()); if (!added.ok()) invalid_devices.insert(node.device()); } VLOG(2) << "Inferred device set: [" << absl::StrJoin(devices_, ", ") << "]"; diff --git a/tensorflow/core/grappler/grappler_item.h b/tensorflow/core/grappler/grappler_item.h index 6778d1f3047335..36bc4f1552e4be 100644 --- a/tensorflow/core/grappler/grappler_item.h +++ b/tensorflow/core/grappler/grappler_item.h @@ -112,13 +112,13 @@ struct GrapplerItem { // Adds a device to a set of available devices, only if it's a valid fully // defined device name. Returns `OkStatus()` if successfully added a device, // and an error otherwise. - Status AddDevice(const string& device); + absl::Status AddDevice(const string& device); // Adds all valid devices from the other Grappler item to the device set. - Status AddDevices(const GrapplerItem& other); + absl::Status AddDevices(const GrapplerItem& other); // Adds all valid devices from the nodes of the graph to the device set. // Returns `OkStatus()` if all device annotations found in a graph are valid // fully defined device names, and an error otherwise. - Status InferDevicesFromGraph(); + absl::Status InferDevicesFromGraph(); // Clears a set of available devices. void ClearDevices(); diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index 35198461cc8033..54c8883db7cd2d 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -80,7 +80,7 @@ void InitializeTensor(DataType type, Tensor* tensor) { // Applies the same graph pruning logic to the graph as Session.Run in TF. // If the returned status is not OK, item state may be inconsistent. -Status PruneGraph(GrapplerItem* item) { +absl::Status PruneGraph(GrapplerItem* item) { ModelPruner pruner; GraphDef pruned_graph; Cluster* cluster = nullptr; // ModelPruner doesn't check cluster. @@ -91,10 +91,10 @@ Status PruneGraph(GrapplerItem* item) { // Replace any unknown dimensions in a shape with // cfg.placeholder_unknown_output_shape_dim if it is no less than 0. -Status ReplaceUnknownShapeDim(const ItemConfig& cfg, - const TensorShapeProto& shape_pb_in, - TensorShapeProto* shape_pb_out, - TensorShape* shape_out) { +absl::Status ReplaceUnknownShapeDim(const ItemConfig& cfg, + const TensorShapeProto& shape_pb_in, + TensorShapeProto* shape_pb_out, + TensorShape* shape_out) { std::vector dims; for (const auto& dim_proto : shape_pb_in.dim()) { if (cfg.placeholder_unknown_output_shape_dim >= 0 && @@ -115,7 +115,7 @@ Status ReplaceUnknownShapeDim(const ItemConfig& cfg, // the Placeholder node has _output_shapes. // Otherwise keep it intact to keep compatible with shape annotation // (b/134092018). -Status UpdatePlaceholderShape( +absl::Status UpdatePlaceholderShape( const ItemConfig& cfg, const std::unordered_set& signature_feed_nodes, GrapplerItem* new_item, NodeDef* node) { @@ -140,7 +140,7 @@ Status UpdatePlaceholderShape( // shape is not empty if the shape is partially defined. TensorShape shape; TensorShapeProto shape_proto; - Status make_shape_status = ReplaceUnknownShapeDim( + absl::Status make_shape_status = ReplaceUnknownShapeDim( cfg, node->attr().at("shape").shape(), &shape_proto, &shape); if (!make_shape_status.ok()) { return absl::InternalError( @@ -208,9 +208,9 @@ Status UpdatePlaceholderShape( } // namespace -Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg, - GraphDef* output_graph_def, - const ItemConfig& cfg) { +absl::Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg, + GraphDef* output_graph_def, + const ItemConfig& cfg) { // This is a temporary change that optimizes the graph in context of a single // gpu machine. Down the line, we may want to make grappler_item_builder aware // of the cluster type (E.g: single cpu, multiple gpu, etc) being simulated @@ -369,8 +369,8 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( NodeName(input.name()))) { TensorShape shape; TensorShapeProto shape_proto; - Status s = ReplaceUnknownShapeDim(cfg, input.tensor_shape(), - &shape_proto, &shape); + absl::Status s = ReplaceUnknownShapeDim(cfg, input.tensor_shape(), + &shape_proto, &shape); if (!s.ok()) { LOG(ERROR) << "Invalid shape for signature input " << input.name() << ": " << s << ", skipping this input"; @@ -551,8 +551,8 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( for (auto& node : *new_item->graph.mutable_node()) { if (IsPlaceholder(node) && node.op() != "PlaceholderWithDefault") { - Status s = UpdatePlaceholderShape(cfg, signature_feed_nodes, - new_item.get(), &node); + absl::Status s = UpdatePlaceholderShape(cfg, signature_feed_nodes, + new_item.get(), &node); if (!s.ok()) return nullptr; } else if (IsConstant(node)) { auto it = asset_node_to_value.find(node.name()); @@ -619,7 +619,7 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( } // Instantiate all the missing attributes with their default values. - Status attr_status = AddDefaultAttrsToGraphDef( + absl::Status attr_status = AddDefaultAttrsToGraphDef( &new_item->graph, FunctionLibraryDefinition(OpRegistry::Global(), new_item->graph.library()), @@ -633,7 +633,7 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( // Optimize the graph (function inlining, l1 optimizations, etc). VLOG(1) << "Number of nodes in graph before RuntimeGraphOptimizer: " << new_item->graph.node_size(); - Status optimize_status = + absl::Status optimize_status = RuntimeGraphOptimizer(new_item->graph, &new_item->graph, cfg); if (!optimize_status.ok()) { LOG(ERROR) << "Graph preprocessing failed: " << optimize_status; diff --git a/tensorflow/core/grappler/grappler_item_builder.h b/tensorflow/core/grappler/grappler_item_builder.h index d967fa180aad21..00661da0253c0d 100644 --- a/tensorflow/core/grappler/grappler_item_builder.h +++ b/tensorflow/core/grappler/grappler_item_builder.h @@ -64,8 +64,9 @@ struct ItemConfig { // Method for optimizing the graph def (including function inlining and other // optimizations). This is optimizations from common_runtime, NOT Grappler // function optimizer. -Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg, - GraphDef* output_graph_def, const ItemConfig& cfg); +absl::Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg, + GraphDef* output_graph_def, + const ItemConfig& cfg); // Factory method for creating a GrapplerItem from a MetaGraphDef. // Returns nullptr if the given meta_graph cannot be converted. diff --git a/tensorflow/core/grappler/inputs/file_input_yielder.cc b/tensorflow/core/grappler/inputs/file_input_yielder.cc index e9ffb92cf0a929..2df0378441df9c 100644 --- a/tensorflow/core/grappler/inputs/file_input_yielder.cc +++ b/tensorflow/core/grappler/inputs/file_input_yielder.cc @@ -74,7 +74,7 @@ bool FileInputYielder::NextItem(GrapplerItem* item) { LOG(INFO) << "Loading model from " << filename; MetaGraphDef metagraph; - Status s = ReadBinaryProto(Env::Default(), filename, &metagraph); + absl::Status s = ReadBinaryProto(Env::Default(), filename, &metagraph); if (!s.ok()) { s = ReadTextProto(Env::Default(), filename, &metagraph); } diff --git a/tensorflow/core/grappler/inputs/utils.cc b/tensorflow/core/grappler/inputs/utils.cc index 580a526d0b1b3d..6b2f380bd6a06d 100644 --- a/tensorflow/core/grappler/inputs/utils.cc +++ b/tensorflow/core/grappler/inputs/utils.cc @@ -26,7 +26,8 @@ limitations under the License. namespace tensorflow { namespace grappler { -bool FilesExist(const std::vector& files, std::vector* status) { +bool FilesExist(const std::vector& files, + std::vector* status) { return Env::Default()->FilesExist(files, status); } @@ -34,22 +35,23 @@ bool FilesExist(const std::set& files) { return FilesExist(std::vector(files.begin(), files.end()), nullptr); } -bool FileExists(const string& file, Status* status) { +bool FileExists(const string& file, absl::Status* status) { *status = Env::Default()->FileExists(file); return status->ok(); } -Status ReadGraphDefFromFile(const string& graph_def_path, GraphDef* result) { - Status status; +absl::Status ReadGraphDefFromFile(const string& graph_def_path, + GraphDef* result) { + absl::Status status; if (!ReadBinaryProto(Env::Default(), graph_def_path, result).ok()) { return ReadTextProto(Env::Default(), graph_def_path, result); } return status; } -Status ReadMetaGraphDefFromFile(const string& graph_def_path, - MetaGraphDef* result) { - Status status; +absl::Status ReadMetaGraphDefFromFile(const string& graph_def_path, + MetaGraphDef* result) { + absl::Status status; if (!ReadBinaryProto(Env::Default(), graph_def_path, result).ok()) { return ReadTextProto(Env::Default(), graph_def_path, result); } diff --git a/tensorflow/core/grappler/inputs/utils.h b/tensorflow/core/grappler/inputs/utils.h index 0f69913ed6fa99..589dbc00f4560c 100644 --- a/tensorflow/core/grappler/inputs/utils.h +++ b/tensorflow/core/grappler/inputs/utils.h @@ -29,17 +29,18 @@ namespace tensorflow { namespace grappler { bool FilesExist(const std::vector& files, - std::vector* status = nullptr); + std::vector* status = nullptr); bool FilesExist(const std::set& files); -bool FileExists(const string& file, Status* status); +bool FileExists(const string& file, absl::Status* status); // Reads GraphDef from file in either text or raw serialized format. -Status ReadGraphDefFromFile(const string& graph_def_path, GraphDef* result); +absl::Status ReadGraphDefFromFile(const string& graph_def_path, + GraphDef* result); // Reads MetaGraphDef from file in either text or raw serialized format. -Status ReadMetaGraphDefFromFile(const string& meta_graph_def_path, - MetaGraphDef* result); +absl::Status ReadMetaGraphDefFromFile(const string& meta_graph_def_path, + MetaGraphDef* result); } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/inputs/utils_test.cc b/tensorflow/core/grappler/inputs/utils_test.cc index ea38f6c4d8545b..51a1c48b6adf5c 100644 --- a/tensorflow/core/grappler/inputs/utils_test.cc +++ b/tensorflow/core/grappler/inputs/utils_test.cc @@ -81,7 +81,7 @@ TEST_F(UtilsTest, FilesExist) { FilesExist(std::vector{{non_existent_file_}, {actual_file_}})); EXPECT_TRUE(FilesExist(std::vector{{actual_file_}})); - std::vector status; + std::vector status; EXPECT_FALSE(FilesExist( std::vector{{non_existent_file_}, {actual_file_}}, &status)); EXPECT_EQ(status.size(), 2); diff --git a/tensorflow/core/grappler/mutable_graph_view.cc b/tensorflow/core/grappler/mutable_graph_view.cc index cf159922c51daa..a801e68e701bb0 100644 --- a/tensorflow/core/grappler/mutable_graph_view.cc +++ b/tensorflow/core/grappler/mutable_graph_view.cc @@ -250,13 +250,13 @@ bool HasFanoutValue(const FanoutsMap& fanouts, const FanoutsMap::iterator& it) { return it != fanouts.end() && !it->second.empty(); } -Status MutationError(absl::string_view function_name, absl::string_view params, - absl::string_view msg) { +absl::Status MutationError(absl::string_view function_name, + absl::string_view params, absl::string_view msg) { return errors::InvalidArgument(absl::Substitute( "MutableGraphView::$0($1) error: $2.", function_name, params, msg)); } -using ErrorHandler = std::function; +using ErrorHandler = std::function; ErrorHandler UpdateFanoutsError(absl::string_view from_node_name, absl::string_view to_node_name) { @@ -267,7 +267,7 @@ ErrorHandler UpdateFanoutsError(absl::string_view from_node_name, }; } -Status CheckFaninIsRegular(const TensorId& fanin, ErrorHandler handler) { +absl::Status CheckFaninIsRegular(const TensorId& fanin, ErrorHandler handler) { if (!IsTensorIdRegular(fanin)) { return handler(absl::Substitute("fanin '$0' must be a regular tensor id", fanin.ToString())); @@ -275,7 +275,7 @@ Status CheckFaninIsRegular(const TensorId& fanin, ErrorHandler handler) { return absl::OkStatus(); } -Status CheckFaninIsValid(const TensorId& fanin, ErrorHandler handler) { +absl::Status CheckFaninIsValid(const TensorId& fanin, ErrorHandler handler) { if (!IsTensorIdPortValid(fanin)) { return handler(absl::Substitute("fanin '$0' must be a valid tensor id", fanin.ToString())); @@ -283,8 +283,9 @@ Status CheckFaninIsValid(const TensorId& fanin, ErrorHandler handler) { return absl::OkStatus(); } -Status CheckAddingFaninToSelf(absl::string_view node_name, - const TensorId& fanin, ErrorHandler handler) { +absl::Status CheckAddingFaninToSelf(absl::string_view node_name, + const TensorId& fanin, + ErrorHandler handler) { if (node_name == fanin.node()) { return handler( absl::Substitute("can't add fanin '$0' to self", fanin.ToString())); @@ -292,8 +293,9 @@ Status CheckAddingFaninToSelf(absl::string_view node_name, return absl::OkStatus(); } -Status CheckRemovingFaninFromSelf(absl::string_view node_name, - const TensorId& fanin, ErrorHandler handler) { +absl::Status CheckRemovingFaninFromSelf(absl::string_view node_name, + const TensorId& fanin, + ErrorHandler handler) { if (node_name == fanin.node()) { return handler(absl::Substitute("can't remove fanin '$0' from self", fanin.ToString())); @@ -305,15 +307,15 @@ string NodeMissingErrorMsg(absl::string_view node_name) { return absl::Substitute("node '$0' was not found", node_name); } -Status CheckNodeExists(absl::string_view node_name, NodeDef* node, - ErrorHandler handler) { +absl::Status CheckNodeExists(absl::string_view node_name, NodeDef* node, + ErrorHandler handler) { if (node == nullptr) { return handler(NodeMissingErrorMsg(node_name)); } return absl::OkStatus(); } -Status CheckPortRange(int port, int min, int max, ErrorHandler handler) { +absl::Status CheckPortRange(int port, int min, int max, ErrorHandler handler) { if (port < min || port > max) { if (max < min) { return handler("no available ports as node has no regular fanins"); @@ -462,7 +464,7 @@ NodeDef* MutableGraphView::AddNode(NodeDef&& node) { return node_in_graph; } -Status MutableGraphView::AddSubgraph(GraphDef&& subgraph) { +absl::Status MutableGraphView::AddSubgraph(GraphDef&& subgraph) { // 1. Add all new functions and check that functions with the same name // have identical definition. const int function_size = subgraph.library().function_size(); @@ -511,7 +513,7 @@ Status MutableGraphView::AddSubgraph(GraphDef&& subgraph) { return absl::OkStatus(); } -Status MutableGraphView::UpdateNode( +absl::Status MutableGraphView::UpdateNode( absl::string_view node_name, absl::string_view op, absl::string_view device, absl::Span> attrs) { auto error_status = [node_name, op, device, attrs](absl::string_view msg) { @@ -565,9 +567,9 @@ Status MutableGraphView::UpdateNode( return absl::OkStatus(); } -Status MutableGraphView::UpdateNodeName(absl::string_view from_node_name, - absl::string_view to_node_name, - bool update_fanouts) { +absl::Status MutableGraphView::UpdateNodeName(absl::string_view from_node_name, + absl::string_view to_node_name, + bool update_fanouts) { auto error_status = [from_node_name, to_node_name, update_fanouts](absl::string_view msg) { string params = absl::Substitute( @@ -608,9 +610,9 @@ Status MutableGraphView::UpdateNodeName(absl::string_view from_node_name, return absl::OkStatus(); } -Status MutableGraphView::SwapNodeNames(absl::string_view from_node_name, - absl::string_view to_node_name, - bool update_fanouts) { +absl::Status MutableGraphView::SwapNodeNames(absl::string_view from_node_name, + absl::string_view to_node_name, + bool update_fanouts) { auto error_status = [from_node_name, to_node_name, update_fanouts](absl::string_view msg) { string params = absl::Substitute( @@ -753,8 +755,8 @@ Status MutableGraphView::SwapNodeNames(absl::string_view from_node_name, return absl::OkStatus(); } -Status MutableGraphView::UpdateFanouts(absl::string_view from_node_name, - absl::string_view to_node_name) { +absl::Status MutableGraphView::UpdateFanouts(absl::string_view from_node_name, + absl::string_view to_node_name) { NodeDef* from_node = GetNode(from_node_name); TF_RETURN_IF_ERROR( CheckNodeExists(from_node_name, from_node, @@ -766,8 +768,8 @@ Status MutableGraphView::UpdateFanouts(absl::string_view from_node_name, return UpdateFanoutsInternal(from_node, to_node); } -Status MutableGraphView::UpdateFanoutsInternal(NodeDef* from_node, - NodeDef* to_node) { +absl::Status MutableGraphView::UpdateFanoutsInternal(NodeDef* from_node, + NodeDef* to_node) { VLOG(2) << absl::Substitute("Update fanouts from '$0' to '$1'.", from_node->name(), to_node->name()); if (from_node == to_node) { @@ -911,8 +913,8 @@ bool MutableGraphView::AddFaninInternal(NodeDef* node, return true; } -Status MutableGraphView::AddRegularFanin(absl::string_view node_name, - const TensorId& fanin) { +absl::Status MutableGraphView::AddRegularFanin(absl::string_view node_name, + const TensorId& fanin) { auto error_status = [node_name, fanin](absl::string_view msg) { string params = absl::Substitute("node_name='$0', fanin='$1'", node_name, fanin.ToString()); @@ -930,9 +932,8 @@ Status MutableGraphView::AddRegularFanin(absl::string_view node_name, return absl::OkStatus(); } -Status MutableGraphView::AddRegularFaninByPort(absl::string_view node_name, - int port, - const TensorId& fanin) { +absl::Status MutableGraphView::AddRegularFaninByPort( + absl::string_view node_name, int port, const TensorId& fanin) { auto error_status = [node_name, port, fanin](absl::string_view msg) { string params = absl::Substitute("node_name='$0', port=$1, fanin='$2'", node_name, port, fanin.ToString()); @@ -1033,8 +1034,8 @@ NodeDef* MutableGraphView::GetOrCreateIdentityConsumingSwitch( return identity_node; } -Status MutableGraphView::AddControllingFanin(absl::string_view node_name, - const TensorId& fanin) { +absl::Status MutableGraphView::AddControllingFanin(absl::string_view node_name, + const TensorId& fanin) { auto error_status = [node_name, fanin](absl::string_view msg) { string params = absl::Substitute("node_name='$0', fanin='$1'", node_name, fanin.ToString()); @@ -1120,8 +1121,8 @@ bool MutableGraphView::RemoveRegularFaninInternal(NodeDef* node, return modified; } -Status MutableGraphView::RemoveRegularFanin(absl::string_view node_name, - const TensorId& fanin) { +absl::Status MutableGraphView::RemoveRegularFanin(absl::string_view node_name, + const TensorId& fanin) { auto error_status = [node_name, fanin](absl::string_view msg) { string params = absl::Substitute("node_name='$0', fanin='$1'", node_name, fanin.ToString()); @@ -1140,8 +1141,8 @@ Status MutableGraphView::RemoveRegularFanin(absl::string_view node_name, return absl::OkStatus(); } -Status MutableGraphView::RemoveRegularFaninByPort(absl::string_view node_name, - int port) { +absl::Status MutableGraphView::RemoveRegularFaninByPort( + absl::string_view node_name, int port) { auto error_status = [node_name, port](absl::string_view msg) { string params = absl::Substitute("node_name='$0', port=$1", node_name, port); @@ -1201,7 +1202,7 @@ bool MutableGraphView::RemoveControllingFaninInternal(NodeDef* node, return false; } -Status MutableGraphView::RemoveControllingFanin( +absl::Status MutableGraphView::RemoveControllingFanin( absl::string_view node_name, absl::string_view fanin_node_name) { auto error_status = [node_name, fanin_node_name](absl::string_view msg) { string params = absl::Substitute("node_name='$0', fanin_node_name='$1'", @@ -1221,8 +1222,8 @@ Status MutableGraphView::RemoveControllingFanin( return absl::OkStatus(); } -Status MutableGraphView::RemoveAllFanins(absl::string_view node_name, - bool keep_controlling_fanins) { +absl::Status MutableGraphView::RemoveAllFanins(absl::string_view node_name, + bool keep_controlling_fanins) { NodeDef* node = GetNode(node_name); if (node == nullptr) { string params = @@ -1253,9 +1254,9 @@ Status MutableGraphView::RemoveAllFanins(absl::string_view node_name, return absl::OkStatus(); } -Status MutableGraphView::UpdateFanin(absl::string_view node_name, - const TensorId& from_fanin, - const TensorId& to_fanin) { +absl::Status MutableGraphView::UpdateFanin(absl::string_view node_name, + const TensorId& from_fanin, + const TensorId& to_fanin) { auto error_status = [node_name, from_fanin, to_fanin](absl::string_view msg) { string params = absl::Substitute("node_name='$0', from_fanin='$1', to_fanin='$2'", @@ -1343,9 +1344,8 @@ Status MutableGraphView::UpdateFanin(absl::string_view node_name, return absl::OkStatus(); } -Status MutableGraphView::UpdateRegularFaninByPort(absl::string_view node_name, - int port, - const TensorId& fanin) { +absl::Status MutableGraphView::UpdateRegularFaninByPort( + absl::string_view node_name, int port, const TensorId& fanin) { auto error_status = [node_name, port, fanin](absl::string_view msg) { string params = absl::Substitute("node_name='$0', port=$1, fanin='$2'", node_name, port, fanin.ToString()); @@ -1387,8 +1387,8 @@ Status MutableGraphView::UpdateRegularFaninByPort(absl::string_view node_name, return absl::OkStatus(); } -Status MutableGraphView::SwapRegularFaninsByPorts(absl::string_view node_name, - int from_port, int to_port) { +absl::Status MutableGraphView::SwapRegularFaninsByPorts( + absl::string_view node_name, int from_port, int to_port) { auto error_status = [node_name, from_port, to_port](absl::string_view msg) { string params = absl::Substitute("node_name='$0', from_port=$1, to_port=$2", node_name, from_port, to_port); @@ -1431,7 +1431,7 @@ Status MutableGraphView::SwapRegularFaninsByPorts(absl::string_view node_name, return absl::OkStatus(); } -Status MutableGraphView::UpdateAllRegularFaninsToControlling( +absl::Status MutableGraphView::UpdateAllRegularFaninsToControlling( absl::string_view node_name) { auto error_status = [node_name](absl::string_view msg) { string params = absl::Substitute("node_name='$0'", node_name); @@ -1502,7 +1502,7 @@ Status MutableGraphView::UpdateAllRegularFaninsToControlling( return absl::OkStatus(); } -Status MutableGraphView::CheckNodesCanBeDeleted( +absl::Status MutableGraphView::CheckNodesCanBeDeleted( const absl::flat_hash_set& nodes_to_delete) { std::vector missing_nodes; std::vector nodes_with_fanouts; @@ -1565,7 +1565,7 @@ Status MutableGraphView::CheckNodesCanBeDeleted( return absl::OkStatus(); } -Status MutableGraphView::DeleteNodes( +absl::Status MutableGraphView::DeleteNodes( const absl::flat_hash_set& nodes_to_delete) { TF_RETURN_IF_ERROR(CheckNodesCanBeDeleted(nodes_to_delete)); diff --git a/tensorflow/core/grappler/mutable_graph_view.h b/tensorflow/core/grappler/mutable_graph_view.h index 68946a65ab4d49..fdd4fa322f342e 100644 --- a/tensorflow/core/grappler/mutable_graph_view.h +++ b/tensorflow/core/grappler/mutable_graph_view.h @@ -74,22 +74,23 @@ class MutableGraphView : public internal::GraphViewInternal { // // IMPORTANT: All nodes and functions of the given subgraph moved into the // underlying graph, which leaves subgraph in valid but undefined state. - Status AddSubgraph(GraphDef&& subgraph); + absl::Status AddSubgraph(GraphDef&& subgraph); // Updates node `node_name` op, device, and attributes. This will clear any // existing attributes. If it is not possible to update the node or if the // node does not exist, an error will be returned and nothing will be modified // in the graph. - Status UpdateNode(absl::string_view node_name, absl::string_view op, - absl::string_view device, - absl::Span> attrs); + absl::Status UpdateNode(absl::string_view node_name, absl::string_view op, + absl::string_view device, + absl::Span> attrs); // Updates node `from_node_name` name to `to_node_name`. If `to_node_name` is // in use, node `from_node_name` does not exist, or node `from_node_name` has // fanouts and `update_fanouts` is set to false, an error will be returned and // nothing will be modified in the graph. - Status UpdateNodeName(absl::string_view from_node_name, - absl::string_view to_node_name, bool update_fanouts); + absl::Status UpdateNodeName(absl::string_view from_node_name, + absl::string_view to_node_name, + bool update_fanouts); // Swap node names `from_node_name` and `to_node_name`. Self loops of one node // are removed by updating the inputs introducing self loops to use the other @@ -115,8 +116,9 @@ class MutableGraphView : public internal::GraphViewInternal { // If it is not possible to swap node names (i.e. nodes do not exist or Switch // control dependency may be introduced), an error will be returned and // nothing will be modified in the graph. - Status SwapNodeNames(absl::string_view from_node_name, - absl::string_view to_node_name, bool update_fanouts); + absl::Status SwapNodeNames(absl::string_view from_node_name, + absl::string_view to_node_name, + bool update_fanouts); // Updates all fanouts (input ports fetching output tensors) from // `from_node_name` to the `to_node_name`, including control dependencies. @@ -130,15 +132,16 @@ class MutableGraphView : public internal::GraphViewInternal { // 1. foo1(new_bar:0, new_bar:1, other:0) // 2. foo2(new_bar:1, other:1) // 3. foo3(other:2, ^new_bar) - Status UpdateFanouts(absl::string_view from_node_name, - absl::string_view to_node_name); + absl::Status UpdateFanouts(absl::string_view from_node_name, + absl::string_view to_node_name); // Adds regular fanin `fanin` to node `node_name`. If the node or fanin do not // exist in the graph, nothing will be modified in the graph. Otherwise fanin // will be added after existing non control dependency fanins. Control // dependencies will be deduped. To add control dependencies, use // AddControllingFanin. - Status AddRegularFanin(absl::string_view node_name, const TensorId& fanin); + absl::Status AddRegularFanin(absl::string_view node_name, + const TensorId& fanin); // Adds regular fanin `fanin` to node `node_name` at port `port`. If the node // or fanin do not exist in the graph, nothing will be modified in the graph. @@ -148,8 +151,8 @@ class MutableGraphView : public internal::GraphViewInternal { // If the port is not a valid port (less than 0 or greater than the number of // regular fanins), this will result in an error and the node will not be // modified. - Status AddRegularFaninByPort(absl::string_view node_name, int port, - const TensorId& fanin); + absl::Status AddRegularFaninByPort(absl::string_view node_name, int port, + const TensorId& fanin); // Adds control dependency `fanin` to the target node named `node_name`. To // add regular fanins, use AddRegularFanin. @@ -172,8 +175,8 @@ class MutableGraphView : public internal::GraphViewInternal { // If the control dependency being added is redundant (control dependency // already exists or control dependency can be deduped from regular fanins), // this will not result in an error and the node will not be modified. - Status AddControllingFanin(absl::string_view node_name, - const TensorId& fanin); + absl::Status AddControllingFanin(absl::string_view node_name, + const TensorId& fanin); // Removes regular fanin `fanin` from node `node_name`. If the node or fanin // do not exist in the graph, nothing will be modified in the graph. If there @@ -182,7 +185,8 @@ class MutableGraphView : public internal::GraphViewInternal { // // If the fanin being removed doesn't exist in the node's inputs, this will // not result in an error and the node will not be modified. - Status RemoveRegularFanin(absl::string_view node_name, const TensorId& fanin); + absl::Status RemoveRegularFanin(absl::string_view node_name, + const TensorId& fanin); // Removes regular fanin at port `port` from node `node_name`. If the node // does not exist in the graph, nothing will be modified in the graph. @@ -191,7 +195,7 @@ class MutableGraphView : public internal::GraphViewInternal { // If the port is not a valid port (less than 0 or greater than the last index // of the regular fanins), this will result in an error and the node will not // be modified. - Status RemoveRegularFaninByPort(absl::string_view node_name, int port); + absl::Status RemoveRegularFaninByPort(absl::string_view node_name, int port); // Removes control dependency `fanin_node_name` from the target node named // `node_name`. If the node or fanin do not exist in the graph, nothing will @@ -199,16 +203,16 @@ class MutableGraphView : public internal::GraphViewInternal { // // If the fanin being removed doesn't exist in the node's inputs, this will // not result in an error and the node will not be modified. - Status RemoveControllingFanin(absl::string_view node_name, - absl::string_view fanin_node_name); + absl::Status RemoveControllingFanin(absl::string_view node_name, + absl::string_view fanin_node_name); // Removes all fanins from node `node_name`. Control dependencies will be // retained if keep_controlling_fanins is true. // // If no fanins are removed, this will not result in an error and the node // will not be modified. - Status RemoveAllFanins(absl::string_view node_name, - bool keep_controlling_fanins); + absl::Status RemoveAllFanins(absl::string_view node_name, + bool keep_controlling_fanins); // Replaces all fanins `from_fanin` with `to_fanin` in node `node_name`. If // the fanins or node do not exist, nothing will be modified in the graph. @@ -216,8 +220,9 @@ class MutableGraphView : public internal::GraphViewInternal { // // If the fanin being updated doesn't exist in the node's inputs, this will // not result in an error and the node will not be modified. - Status UpdateFanin(absl::string_view node_name, const TensorId& from_fanin, - const TensorId& to_fanin); + absl::Status UpdateFanin(absl::string_view node_name, + const TensorId& from_fanin, + const TensorId& to_fanin); // Replaces fanin at port `port` in node `node_name` with fanin `fanin`. If // the fanins or node do not exist, nothing will be modified in the graph. @@ -226,8 +231,8 @@ class MutableGraphView : public internal::GraphViewInternal { // If the port is not a valid port (less than 0 or greater than the last index // of the regular fanins), this will result in an error and the node will not // be modified. - Status UpdateRegularFaninByPort(absl::string_view node_name, int port, - const TensorId& fanin); + absl::Status UpdateRegularFaninByPort(absl::string_view node_name, int port, + const TensorId& fanin); // Swaps fanins at ports `from_port` and `to_port` in node `node_name`. If the // node does not exist, nothing will be modified in the graph. @@ -235,18 +240,18 @@ class MutableGraphView : public internal::GraphViewInternal { // If the ports are not a valid port (less than 0 or greater than the last // index of the regular fanins), this will result in an error and the node // will not be modified. - Status SwapRegularFaninsByPorts(absl::string_view node_name, int from_port, - int to_port); + absl::Status SwapRegularFaninsByPorts(absl::string_view node_name, + int from_port, int to_port); // Updates all regular fanins to equivalent controlling fanins. If it is not // possible, an error will be returned and nothing will be modified in the // graph. - Status UpdateAllRegularFaninsToControlling(absl::string_view node_name); + absl::Status UpdateAllRegularFaninsToControlling(absl::string_view node_name); // Deletes nodes from the graph. If a node can't be safely removed, // specifically if a node still has fanouts, an error will be returned. Nodes // that can't be found are ignored. - Status DeleteNodes(const absl::flat_hash_set& nodes_to_delete); + absl::Status DeleteNodes(const absl::flat_hash_set& nodes_to_delete); private: // Adds fanouts for fanins of node to graph, while deduping control @@ -281,7 +286,7 @@ class MutableGraphView : public internal::GraphViewInternal { // // IMPORTANT: If `from_node` or `to_node` is not in the underlying graph, the // behavior is undefined. - Status UpdateFanoutsInternal(NodeDef* from_node, NodeDef* to_node); + absl::Status UpdateFanoutsInternal(NodeDef* from_node, NodeDef* to_node); // Adds fanin to node. If fanin is a control dependency, existing control // dependencies will be checked first before adding. Otherwise fanin will be @@ -313,7 +318,7 @@ class MutableGraphView : public internal::GraphViewInternal { // Checks if nodes to be deleted are missing or have any fanouts that will // remain in the graph. If node is removed in either case, the graph will // enter an invalid state. - Status CheckNodesCanBeDeleted( + absl::Status CheckNodesCanBeDeleted( const absl::flat_hash_set& nodes_to_delete); // Removes fanins of the deleted node from internal state. Control diff --git a/tensorflow/core/grappler/mutable_graph_view_test.cc b/tensorflow/core/grappler/mutable_graph_view_test.cc index 3916f992f6f144..c05d9f0ad0351b 100644 --- a/tensorflow/core/grappler/mutable_graph_view_test.cc +++ b/tensorflow/core/grappler/mutable_graph_view_test.cc @@ -202,7 +202,7 @@ TEST(MutableGraphViewTest, AddSubgraphAndFailIfFunctionDifferent) { FunctionDef x_times_two = test::function::XTimesTwo(); GraphDef subgraph = test::function::GDef({}, {x_times_two}); - Status status = graph.AddSubgraph(std::move(subgraph)); + absl::Status status = graph.AddSubgraph(std::move(subgraph)); EXPECT_FALSE(status.ok()); EXPECT_EQ(status.message(), "MutableGraphView::AddSubgraph(function_size=1) error: Found " @@ -289,7 +289,7 @@ TEST(MutableGraphViewTest, UpdateNodeSwitchControlDependency) { AttrValue attr; attr.set_type(DT_FLOAT); - Status s = graph.UpdateNode("foo", "Switch", kDevice, {{"T", attr}}); + absl::Status s = graph.UpdateNode("foo", "Switch", kDevice, {{"T", attr}}); EXPECT_FALSE(s.ok()); string expected_msg = "MutableGraphView::UpdateNodeOp(node_name='foo', op='Switch', " @@ -356,7 +356,8 @@ void TestUpdateNodeName(absl::string_view from_node_name, bool node_exists, absl::flat_hash_map> unmodified_node_inputs = GetNodeInputsFromGraph(graph_def, from_node_name); - Status s = graph.UpdateNodeName(from_node_name, to_node_name, update_fanouts); + absl::Status s = + graph.UpdateNodeName(from_node_name, to_node_name, update_fanouts); EXPECT_EQ(s.ok(), success); string updated_node_name; if (success) { @@ -678,7 +679,8 @@ void TestSwapNodeNamesError(absl::string_view from_node_name, MutableGraphView graph(&graph_def); - Status s = graph.SwapNodeNames(from_node_name, to_node_name, update_fanouts); + absl::Status s = + graph.SwapNodeNames(from_node_name, to_node_name, update_fanouts); EXPECT_EQ(s.ok(), false); EXPECT_EQ(s.message(), error_msg); @@ -842,7 +844,7 @@ TEST(MutableGraphViewTest, UpdateFanoutsToSwitchWithControlFromSwitch) { MutableGraphView graph(&graph_def); - Status s = graph.UpdateFanouts("a", "b"); + absl::Status s = graph.UpdateFanouts("a", "b"); EXPECT_FALSE(s.ok()); string expected_msg = "MutableGraphView::UpdateFanouts(from_node_name='a', to_node_name='b') " @@ -923,7 +925,7 @@ void TestAddRegularFanin(absl::string_view node_name, bool node_exists, absl::flat_hash_map> unmodified_node_inputs = GetNodeInputsFromGraph(graph_def, node_name); - Status s = graph.AddRegularFanin(node_name, fanin_to_add); + absl::Status s = graph.AddRegularFanin(node_name, fanin_to_add); EXPECT_EQ(s.ok(), success); if (!success) { EXPECT_EQ(s.message(), error_msg); @@ -1055,7 +1057,7 @@ void TestAddRegularFaninByPort(absl::string_view node_name, bool node_exists, absl::flat_hash_map> unmodified_node_inputs = GetNodeInputsFromGraph(graph_def, node_name); - Status s = graph.AddRegularFaninByPort(node_name, port, fanin_to_add); + absl::Status s = graph.AddRegularFaninByPort(node_name, port, fanin_to_add); EXPECT_EQ(s.ok(), success); if (!success) { EXPECT_EQ(s.message(), error_msg); @@ -1209,7 +1211,7 @@ void TestRemoveRegularFanin(absl::string_view node_name, bool node_exists, absl::flat_hash_map> unmodified_node_inputs = GetNodeInputsFromGraph(graph_def, node_name); - Status s = graph.RemoveRegularFanin(node_name, fanin_to_remove); + absl::Status s = graph.RemoveRegularFanin(node_name, fanin_to_remove); EXPECT_EQ(s.ok(), success); if (!success) { EXPECT_EQ(s.message(), error_msg); @@ -1351,7 +1353,7 @@ void TestRemoveRegularFaninByPort(absl::string_view node_name, bool node_exists, absl::flat_hash_map> unmodified_node_inputs = GetNodeInputsFromGraph(graph_def, node_name); - Status s = graph.RemoveRegularFaninByPort(node_name, port); + absl::Status s = graph.RemoveRegularFaninByPort(node_name, port); EXPECT_EQ(s.ok(), success); if (!success) { EXPECT_EQ(s.message(), error_msg); @@ -1456,7 +1458,7 @@ void TestRemoveAllFanins(absl::string_view node_name, bool node_exists, absl::flat_hash_map> unmodified_node_inputs = GetNodeInputsFromGraph(graph_def, node_name); - Status s = graph.RemoveAllFanins(node_name, keep_controlling_nodes); + absl::Status s = graph.RemoveAllFanins(node_name, keep_controlling_nodes); EXPECT_EQ(s.ok(), success); if (!success) { EXPECT_EQ(s.message(), error_msg); @@ -1552,7 +1554,7 @@ void TestUpdateFanin(absl::string_view node_name, bool node_exists, absl::flat_hash_map> unmodified_node_inputs = GetNodeInputsFromGraph(graph_def, node_name); - Status s = graph.UpdateFanin(node_name, from_fanin, to_fanin); + absl::Status s = graph.UpdateFanin(node_name, from_fanin, to_fanin); EXPECT_EQ(s.ok(), success); if (!success) { EXPECT_EQ(s.message(), error_msg); @@ -1687,7 +1689,7 @@ void TestUpdateFaninFromFaninToNodeAsSwitchControl(const TensorId& fanin) { MutableGraphView graph(&graph_def); - Status s = graph.UpdateFanin("c", fanin, {"b", Graph::kControlSlot}); + absl::Status s = graph.UpdateFanin("c", fanin, {"b", Graph::kControlSlot}); EXPECT_FALSE(s.ok()); string expected_msg = absl::Substitute( "MutableGraphView::UpdateFanin(node_name='c', from_fanin='$0', " @@ -1730,7 +1732,7 @@ void TestUpdateRegularFaninByPort(absl::string_view node_name, bool node_exists, absl::flat_hash_map> unmodified_node_inputs = GetNodeInputsFromGraph(graph_def, node_name); - Status s = graph.UpdateRegularFaninByPort(node_name, port, fanin); + absl::Status s = graph.UpdateRegularFaninByPort(node_name, port, fanin); EXPECT_EQ(s.ok(), success); if (!success) { EXPECT_EQ(s.message(), error_msg); @@ -1894,7 +1896,8 @@ void TestSwapRegularFaninsByPorts(absl::string_view node_name, bool node_exists, absl::flat_hash_map> unmodified_node_inputs = GetNodeInputsFromGraph(graph_def, node_name); - Status s = graph.SwapRegularFaninsByPorts(node_name, from_port, to_port); + absl::Status s = + graph.SwapRegularFaninsByPorts(node_name, from_port, to_port); EXPECT_EQ(s.ok(), success); if (!success) { EXPECT_EQ(s.message(), error_msg); @@ -2392,7 +2395,7 @@ TEST(MutableGraphViewTest, AddControllingFaninMissing) { MutableGraphView graph(&graph_def); // Missing fanin. - Status s = graph.AddControllingFanin("a", {"c", Graph::kControlSlot}); + absl::Status s = graph.AddControllingFanin("a", {"c", Graph::kControlSlot}); EXPECT_FALSE(s.ok()); string expected_msg = "MutableGraphView::AddControllingFanin(node_name='a', fanin='^c') error: " @@ -2464,7 +2467,7 @@ TEST(MutableGraphViewTest, AddControllingFaninSwitch) { MutableGraphView graph(&graph_def); - Status s = graph.AddControllingFanin("a", {"b", Graph::kControlSlot}); + absl::Status s = graph.AddControllingFanin("a", {"b", Graph::kControlSlot}); EXPECT_FALSE(s.ok()); string expected_msg = "MutableGraphView::AddControllingFanin(node_name='a', fanin='^b') error: " @@ -2558,7 +2561,7 @@ void TestAddControllingFaninSelfLoops(absl::string_view node_name, MutableGraphView graph(&graph_def); - Status s = graph.AddControllingFanin(node_name, fanin); + absl::Status s = graph.AddControllingFanin(node_name, fanin); EXPECT_FALSE(s.ok()); EXPECT_EQ(s.message(), error_msg); @@ -2616,7 +2619,8 @@ TEST(MutableGraphViewTest, AddControllingFaninSelfLoopsGeneratedIdentity) { // node, with name `ConstantFoldingCtrl/b_1`. As the input node is of the same // name, we will introduce a self loop, so no control dependency should be // added. - Status s = graph.AddControllingFanin("ConstantFoldingCtrl/b_1", {"b", 1}); + absl::Status s = + graph.AddControllingFanin("ConstantFoldingCtrl/b_1", {"b", 1}); EXPECT_FALSE(s.ok()); string expected_msg = "MutableGraphView::AddControllingFanin(node_name='ConstantFoldingCtrl/" @@ -2709,7 +2713,7 @@ TEST(MutableGraphViewTest, RemoveControllingFaninSelfLoop) { MutableGraphView graph(&graph_def); - Status s = graph.RemoveControllingFanin("c", "c"); + absl::Status s = graph.RemoveControllingFanin("c", "c"); EXPECT_FALSE(s.ok()); string expected_msg = "MutableGraphView::RemoveControllingFanin(node_name='c', " @@ -2753,7 +2757,7 @@ void TestUpdateAllRegularFaninsToControlling( absl::flat_hash_map> unmodified_node_inputs = GetNodeInputsFromGraph(graph_def, node_name); - Status s = graph.UpdateAllRegularFaninsToControlling(node_name); + absl::Status s = graph.UpdateAllRegularFaninsToControlling(node_name); EXPECT_EQ(s.ok(), success); if (!success) { EXPECT_EQ(s.message(), error_msg); @@ -2935,7 +2939,7 @@ TEST(MutableGraphViewTest, DeleteNodesWithError) { MutableGraphView graph(&graph_def); - Status s = graph.DeleteNodes({"b", "a"}); + absl::Status s = graph.DeleteNodes({"b", "a"}); EXPECT_FALSE(s.ok()); string error_msg = "MutableGraphView::DeleteNodes(nodes_to_delete={a, b}) error: can't " @@ -2969,7 +2973,7 @@ TEST(MutableGraphViewTest, DeleteNodesWithLargeError) { MutableGraphView graph(&graph_def); - Status s = graph.DeleteNodes({"a", "b", "c", "d", "e", "f"}); + absl::Status s = graph.DeleteNodes({"a", "b", "c", "d", "e", "f"}); EXPECT_FALSE(s.ok()); string error_msg = "MutableGraphView::DeleteNodes(nodes_to_delete={a, b, c, d, e, ...}) " diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 2bf4de1ba86033..6ecf9b6e903611 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -682,7 +682,7 @@ bool IsPersistent(const NodeDef& node) { bool HasRefInput(const NodeDef& node) { const OpDef* op_def; - Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); + absl::Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); if (!status.ok()) { return false; } @@ -705,7 +705,7 @@ bool IsDataset(const NodeDef& node) { bool IsStateful(const NodeDef node, const OpRegistryInterface* op_registry) { const OpDef* op_def = nullptr; const string& op_name = node.op(); - Status status = op_registry->LookUpOpDef(op_name, &op_def); + absl::Status status = op_registry->LookUpOpDef(op_name, &op_def); if (!status.ok()) { LOG(WARNING) << "Failed to lookup OpDef for " << op_name << ". Error: " << status.message(); @@ -726,7 +726,7 @@ bool IsFreeOfSideEffect(const NodeDef& node, } const OpDef* op_def = nullptr; const string& op_name = node.op(); - Status status = op_registry->LookUpOpDef(op_name, &op_def); + absl::Status status = op_registry->LookUpOpDef(op_name, &op_def); if (!status.ok()) { return false; } diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index f3079b745d029c..27c8acfc854cd3 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -2015,7 +2015,7 @@ class RemoveRedundantReshapeOrBroadcastTo : public ArithmeticOptimizerStage { // chain of unary elementwise ops that are not outputs. if (IsReshape(*node)) { bool skip = false; - gtl::InlinedVector nodes_in_chain; + absl::InlinedVector nodes_in_chain; const auto predicate_fn = [this, node, &skip, &nodes_in_chain](const NodeDef& input) { nodes_in_chain.push_back(&input); @@ -3838,7 +3838,7 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { } auto copy_tensor_values_to_vector = - [node](const Tensor& t, gtl::InlinedVector* vec) { + [node](const Tensor& t, absl::InlinedVector* vec) { if (t.dtype() == DT_INT32) { auto t_flat = t.flat(); vec->assign(&t_flat(0), &t_flat(t.NumElements())); @@ -3853,8 +3853,8 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { return absl::OkStatus(); }; - gtl::InlinedVector slice_begin_vec; - gtl::InlinedVector slice_size_vec; + absl::InlinedVector slice_begin_vec; + absl::InlinedVector slice_size_vec; TF_RETURN_IF_ERROR( copy_tensor_values_to_vector(slice_begin_t, &slice_begin_vec)); TF_RETURN_IF_ERROR( @@ -3958,9 +3958,9 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { bool is_identity; bool is_simple_slice; bool slice_dim0; - gtl::InlinedVector slice_begin_vec; - gtl::InlinedVector slice_end_vec; - gtl::InlinedVector slice_strides_vec; + absl::InlinedVector slice_begin_vec; + absl::InlinedVector slice_end_vec; + absl::InlinedVector slice_strides_vec; TF_RETURN_IF_ERROR(ValidateStridedSliceOp( &slice_begin_t, &slice_end_t, slice_strides_t, pack_output_shape, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask, diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index bd10921cb877a1..22bfd0fea50aa6 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -4496,7 +4496,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveStackStridedSliceSameAxis) { } else if (node.name() == "pc_slice_out") { ASSERT_EQ(node.input_size(), 1); EXPECT_EQ(node.input(0), "c"); - } else if (str_util::EndsWith(node.name(), "_out")) { + } else if (absl::EndsWith(node.name(), "_out")) { ASSERT_EQ(node.input_size(), 1); EXPECT_EQ( absl::StrCat(node.input(0), "_out"), diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc index 4da6f454d2486e..34c03b61890e7b 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc @@ -2312,9 +2312,11 @@ Status AutoMixedPrecision::Optimize(Cluster* cluster, const GrapplerItem& item, << " graph optimizer"; return absl::OkStatus(); } - // Check if CPU supports FP16 + // Check if CPU supports FP16, oneDNN supports FP16 on + // some platforms by converting to and from FP32 if (mode_ == AutoMixedPrecisionMode::FP16_CPU && - !IsAMXDataTypeSupportedByOneDNNOnThisCPU(DT_HALF)) { + !IsAMXDataTypeSupportedByOneDNNOnThisCPU(DT_HALF) && + !IsAVXConvertSupportedByOneDNNOnThisCPU()) { VLOG(1) << "No support for " << name() << " graph optimizer on CPU"; return absl::OkStatus(); } diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h b/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h index ee1d968e0248ad..813875bcc43499 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h @@ -102,7 +102,7 @@ class AutoMixedPrecisionListsFp16 : public AutoMixedPrecisionLists { TF_CHECK_OK( ReadStringFromEnvVar("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "", &optimization_level)); - optimization_level = str_util::Uppercase(optimization_level); + optimization_level = absl::AsciiStrToUpper(optimization_level); return optimization_level == "TENSOR_CORES_ONLY"; } @@ -154,6 +154,7 @@ class AutoMixedPrecisionListsFp16 : public AutoMixedPrecisionLists { list.insert("TmlpV3"); list.insert("Pmlp"); list.insert("FastUnsortedSegmentMax"); + list.insert("VoxelMax"); } #if TENSORFLOW_USE_ROCM if (true) { diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc index 0b855be91f8099..f3def370cee3a6 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc @@ -132,7 +132,10 @@ class AutoMixedPrecisionTest : public GrapplerTest { bool is_fp16_enabled_on_cpu = false; #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) - is_fp16_enabled_on_cpu = IsAMXDataTypeSupportedByOneDNNOnThisCPU(DT_HALF); + // oneDNN supports FP16 on some platforms by converting to and from FP32 + is_fp16_enabled_on_cpu = + IsAMXDataTypeSupportedByOneDNNOnThisCPU(DT_HALF) || + IsAVXConvertSupportedByOneDNNOnThisCPU(); #endif // INTEL_MKL && ENABLE_ONEDNN_V3 if (!IsMKLEnabled() || !is_fp16_enabled_on_cpu) { GTEST_SKIP() << "This device doesn't support FP16"; diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 90eb63a22bcaa7..ad69a3f5bd80c2 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -61,7 +61,7 @@ limitations under the License. namespace tensorflow { namespace grappler { -using TensorVector = gtl::InlinedVector; +using TensorVector = absl::InlinedVector; // We only fold/materialize constants smaller than 100kB. const int64_t kMaxConstantSize = 100 * 1024; diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 9305aa09764f3d..5a31b65717f91b 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -89,8 +89,8 @@ class ConstantFolding : public GraphOptimizer { const GraphProperties* properties) const; Status EvaluateNode(const NodeDef& node, - const gtl::InlinedVector& inputs, - gtl::InlinedVector* output) const; + const absl::InlinedVector& inputs, + absl::InlinedVector* output) const; Status EvaluateOneFoldable(const NodeDef& node, std::vector* outputs, bool* result_too_large); @@ -232,7 +232,8 @@ class ConstantFolding : public GraphOptimizer { // input dimensions to reduce along are all of size 1 and keep_dims is true). bool IsReductionSimplifiableToIdentity( const NodeDef& node, const TensorShapeProto& input_shape, bool keep_dims, - const gtl::InlinedVector& reduction_indices_vector) const; + const absl::InlinedVector& reduction_indices_vector) + const; // Changes a reduction into an Identity op, returning true on success. bool ReplaceReductionWithIdentity(NodeDef* node) const; diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index 6f867024bb9000..3f7c18f9e01e34 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -895,7 +895,7 @@ tf_cc_test( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/platform:status_matchers", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -993,7 +993,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ] + tf_protos_all(), alwayslink = 1, ) diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.h b/tensorflow/core/grappler/optimizers/data/fusion_utils.h index 19b7002dcd8562..f7da097d4b1b09 100644 --- a/tensorflow/core/grappler/optimizers/data/fusion_utils.h +++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.h @@ -34,7 +34,7 @@ using SetFunctionSignatureFn = std::function; -using StringCollection = gtl::InlinedVector; +using StringCollection = absl::InlinedVector; // These functions are invoked with nodes from second function that were // previously taking arguments as input. The `arg_num` tells which diff --git a/tensorflow/core/grappler/optimizers/data/remove_compression_map.cc b/tensorflow/core/grappler/optimizers/data/remove_compression_map.cc index 4753892e7f28b4..a726f167d57d89 100644 --- a/tensorflow/core/grappler/optimizers/data/remove_compression_map.cc +++ b/tensorflow/core/grappler/optimizers/data/remove_compression_map.cc @@ -62,10 +62,9 @@ absl::StatusOr GetCompressionMapNode(const GraphDef& graph) { } // namespace -Status RemoveCompressionMap::OptimizeAndCollectStats(Cluster* cluster, - const GrapplerItem& item, - GraphDef* output, - OptimizationStats* stats) { +absl::Status RemoveCompressionMap::OptimizeAndCollectStats( + Cluster* cluster, const GrapplerItem& item, GraphDef* output, + OptimizationStats* stats) { *output = item.graph; TF_ASSIGN_OR_RETURN(NodeDef compression_map_node, GetCompressionMapNode(*output)); diff --git a/tensorflow/core/grappler/optimizers/data/remove_compression_map.h b/tensorflow/core/grappler/optimizers/data/remove_compression_map.h index 6306cca768894f..550436f4e3d234 100644 --- a/tensorflow/core/grappler/optimizers/data/remove_compression_map.h +++ b/tensorflow/core/grappler/optimizers/data/remove_compression_map.h @@ -30,14 +30,15 @@ class RemoveCompressionMap : public TFDataOptimizerBase { bool UsesFunctionLibrary() const override { return false; } - Status Init( + absl::Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { return absl::OkStatus(); } - Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, - GraphDef* output, - OptimizationStats* stats) override; + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; }; } // namespace grappler diff --git a/tensorflow/core/grappler/optimizers/data/remove_compression_map_test.cc b/tensorflow/core/grappler/optimizers/data/remove_compression_map_test.cc index 2060b0ed4c83e8..3503e0a4c7d635 100644 --- a/tensorflow/core/grappler/optimizers/data/remove_compression_map_test.cc +++ b/tensorflow/core/grappler/optimizers/data/remove_compression_map_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/graph.pb.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/platform/status_matchers.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace grappler { diff --git a/tensorflow/core/grappler/optimizers/data/replicate_on_split.cc b/tensorflow/core/grappler/optimizers/data/replicate_on_split.cc index 43bcca49e2a7a5..21de626e2c5e63 100644 --- a/tensorflow/core/grappler/optimizers/data/replicate_on_split.cc +++ b/tensorflow/core/grappler/optimizers/data/replicate_on_split.cc @@ -26,10 +26,9 @@ limitations under the License. namespace tensorflow { namespace grappler { -Status ReplicateOnSplit::OptimizeAndCollectStats(Cluster* cluster, - const GrapplerItem& item, - GraphDef* output, - OptimizationStats* stats) { +absl::Status ReplicateOnSplit::OptimizeAndCollectStats( + Cluster* cluster, const GrapplerItem& item, GraphDef* output, + OptimizationStats* stats) { VLOG(1) << "Running replicate on split optimization"; *output = item.graph; MutableGraphView graph(output); diff --git a/tensorflow/core/grappler/optimizers/data/replicate_on_split.h b/tensorflow/core/grappler/optimizers/data/replicate_on_split.h index 338ef29b3fdcc3..cffcbd18588973 100644 --- a/tensorflow/core/grappler/optimizers/data/replicate_on_split.h +++ b/tensorflow/core/grappler/optimizers/data/replicate_on_split.h @@ -30,14 +30,15 @@ class ReplicateOnSplit : public TFDataOptimizerBase { bool UsesFunctionLibrary() const override { return false; } - Status Init( + absl::Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { return absl::OkStatus(); } - Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, - GraphDef* output, - OptimizationStats* stats) override; + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; }; } // namespace grappler diff --git a/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch.cc b/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch.cc index 11e4916e5a18e5..c3c68085771ae7 100644 --- a/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch.cc +++ b/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch.cc @@ -167,7 +167,8 @@ NodeDef CreateBufferSizeNode(DataType dtype, return node; } -Status CreateAndAppendPrefetchNode(MutableGraphView* graph, FunctionDef& fdef) { +absl::Status CreateAndAppendPrefetchNode(MutableGraphView* graph, + FunctionDef& fdef) { auto get_last_dataset_op_node = [&]() -> const NodeDef* { // Find the input node of fdef's ret value. const auto& output_arg = fdef.signature().output_arg(0).name(); @@ -251,10 +252,10 @@ Status CreateAndAppendPrefetchNode(MutableGraphView* graph, FunctionDef& fdef) { return absl::OkStatus(); } -Status AddInterleaveNode(MutableGraphView* graph, - const NodeDef& parallel_interleave_node, - const std::string& interleave_map_func_name, - absl::flat_hash_set& nodes_to_delete) { +absl::Status AddInterleaveNode(MutableGraphView* graph, + const NodeDef& parallel_interleave_node, + const std::string& interleave_map_func_name, + absl::flat_hash_set& nodes_to_delete) { NodeDef interleave_node; interleave_node.set_op(kInterleaveDatasetOpName); graph_utils::SetUniqueGraphNodeName( @@ -321,7 +322,7 @@ Status AddInterleaveNode(MutableGraphView* graph, } } // namespace -Status SeqInterleavePrefetch::OptimizeAndCollectStats( +absl::Status SeqInterleavePrefetch::OptimizeAndCollectStats( Cluster* cluster, const GrapplerItem& item, GraphDef* output, OptimizationStats* stats) { *output = item.graph; diff --git a/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch.h b/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch.h index 00cfed1ed78abd..c881d9aa1babc0 100644 --- a/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch.h +++ b/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch.h @@ -35,14 +35,15 @@ class SeqInterleavePrefetch : public TFDataOptimizerBase { // library. bool UsesFunctionLibrary() const override { return true; } - Status Init( + absl::Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { return absl::OkStatus(); } - Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, - GraphDef* output, - OptimizationStats* stats) override; + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; protected: bool autotune_ = true; diff --git a/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch_test.cc b/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch_test.cc index eba3fced1876c4..80741eca2a2c94 100644 --- a/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch_test.cc +++ b/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch_test.cc @@ -267,8 +267,8 @@ bool IsInterleaveNode(const NodeDef &node) { } // namespace -Status OptimizeWithInjectInterleavePrefetch(const GrapplerItem &item, - GraphDef *output) { +absl::Status OptimizeWithInjectInterleavePrefetch(const GrapplerItem &item, + GraphDef *output) { SeqInterleavePrefetch optimizer; return optimizer.Optimize(nullptr, item, output); } diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc index 5fabf42bf03872..bd82281b1cf3f1 100644 --- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc @@ -41,10 +41,10 @@ constexpr char kShuffleAndRepeatDatasetV2[] = "ShuffleAndRepeatDatasetV2"; constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration"; -Status FuseShuffleV1AndRepeat(const NodeDef& shuffle_node, - const NodeDef& repeat_node, - MutableGraphView* graph, GraphDef* output, - NodeDef* fused_node) { +absl::Status FuseShuffleV1AndRepeat(const NodeDef& shuffle_node, + const NodeDef& repeat_node, + MutableGraphView* graph, GraphDef* output, + NodeDef* fused_node) { fused_node->set_op(kShuffleAndRepeatDataset); graph_utils::SetUniqueGraphNodeName(kShuffleAndRepeatDataset, output, fused_node); @@ -75,10 +75,10 @@ Status FuseShuffleV1AndRepeat(const NodeDef& shuffle_node, return absl::OkStatus(); } -Status FuseShuffleV2AndRepeat(const NodeDef& shuffle_node, - const NodeDef& repeat_node, - MutableGraphView* graph, GraphDef* output, - NodeDef* fused_node) { +absl::Status FuseShuffleV2AndRepeat(const NodeDef& shuffle_node, + const NodeDef& repeat_node, + MutableGraphView* graph, GraphDef* output, + NodeDef* fused_node) { fused_node->set_op(kShuffleAndRepeatDatasetV2); graph_utils::SetUniqueGraphNodeName(kShuffleAndRepeatDatasetV2, output, fused_node); @@ -115,10 +115,10 @@ Status FuseShuffleV2AndRepeat(const NodeDef& shuffle_node, return absl::OkStatus(); } -Status FuseShuffleV3AndRepeat(const NodeDef& shuffle_node, - const NodeDef& repeat_node, - MutableGraphView* graph, GraphDef* output, - NodeDef* fused_node) { +absl::Status FuseShuffleV3AndRepeat(const NodeDef& shuffle_node, + const NodeDef& repeat_node, + MutableGraphView* graph, GraphDef* output, + NodeDef* fused_node) { fused_node->set_op(kShuffleAndRepeatDatasetV2); graph_utils::SetUniqueGraphNodeName(kShuffleAndRepeatDataset, output, fused_node); @@ -154,7 +154,7 @@ Status FuseShuffleV3AndRepeat(const NodeDef& shuffle_node, } // namespace -Status ShuffleAndRepeatFusion::OptimizeAndCollectStats( +absl::Status ShuffleAndRepeatFusion::OptimizeAndCollectStats( Cluster* cluster, const GrapplerItem& item, GraphDef* output, OptimizationStats* stats) { *output = item.graph; diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h index 5ce38242bbe3b0..ba30ca63a0aeec 100644 --- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h +++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h @@ -30,14 +30,15 @@ class ShuffleAndRepeatFusion : public TFDataOptimizerBase { bool UsesFunctionLibrary() const override { return false; } - Status Init( + absl::Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { return absl::OkStatus(); } - Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, - GraphDef* output, - OptimizationStats* stats) override; + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; }; } // namespace grappler diff --git a/tensorflow/core/grappler/optimizers/data/slack.cc b/tensorflow/core/grappler/optimizers/data/slack.cc index c83a371973609c..9c42f088d13d98 100644 --- a/tensorflow/core/grappler/optimizers/data/slack.cc +++ b/tensorflow/core/grappler/optimizers/data/slack.cc @@ -77,8 +77,8 @@ constexpr std::array kPassThroughOps = { } // namespace -Status Slack::RecursivelyHandleOp(const MutableGraphView& graph, - NodeDef* dataset_node) { +absl::Status Slack::RecursivelyHandleOp(const MutableGraphView& graph, + NodeDef* dataset_node) { if (dataset_node->op() == kPrefetchDatasetOp) { if (HasNodeAttr(*dataset_node, "slack_period")) { (*dataset_node->mutable_attr())["slack_period"].set_i(slack_period_); @@ -105,10 +105,10 @@ Status Slack::RecursivelyHandleOp(const MutableGraphView& graph, return absl::OkStatus(); } -Status Slack::OptimizeAndCollectStats(Cluster* cluster, - const GrapplerItem& item, - GraphDef* output, - OptimizationStats* stats) { +absl::Status Slack::OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) { if (slack_period_ < 1) return errors::InvalidArgument("Invalid `slack_period` parameter: ", slack_period_); diff --git a/tensorflow/core/grappler/optimizers/data/slack.h b/tensorflow/core/grappler/optimizers/data/slack.h index b39cfc65094567..af70d314697a36 100644 --- a/tensorflow/core/grappler/optimizers/data/slack.h +++ b/tensorflow/core/grappler/optimizers/data/slack.h @@ -35,7 +35,7 @@ class Slack : public TFDataOptimizerBase { bool UsesFunctionLibrary() const override { return false; } - Status Init( + absl::Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { if (!config) return errors::InvalidArgument("Config parameter required."); @@ -48,15 +48,16 @@ class Slack : public TFDataOptimizerBase { return absl::OkStatus(); } - Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, - GraphDef* output, - OptimizationStats* stats) override; + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; private: int64_t slack_period_ = -1; - Status RecursivelyHandleOp(const MutableGraphView& graph, - NodeDef* dataset_node); + absl::Status RecursivelyHandleOp(const MutableGraphView& graph, + NodeDef* dataset_node); }; } // namespace grappler diff --git a/tensorflow/core/grappler/optimizers/data/slack_test.cc b/tensorflow/core/grappler/optimizers/data/slack_test.cc index bc1205cbae8c1b..518a237afe4da4 100644 --- a/tensorflow/core/grappler/optimizers/data/slack_test.cc +++ b/tensorflow/core/grappler/optimizers/data/slack_test.cc @@ -89,7 +89,7 @@ TEST(SlackTest, TestFailWithoutInit) { GrapplerItem item; Slack optimizer; GraphDef output; - Status result = optimizer.Optimize(nullptr, item, &output); + absl::Status result = optimizer.Optimize(nullptr, item, &output); EXPECT_FALSE(result.ok()); EXPECT_TRUE(absl::IsInvalidArgument(result)); @@ -105,7 +105,7 @@ TEST(SlackTest, TestFailWithInvalidSlackEveryParam) { TF_ASSERT_OK(optimizer.Init(&config)); GraphDef output; - Status result = optimizer.Optimize(nullptr, item, &output); + absl::Status result = optimizer.Optimize(nullptr, item, &output); EXPECT_FALSE(result.ok()); EXPECT_TRUE(absl::IsInvalidArgument(result)); diff --git a/tensorflow/core/grappler/optimizers/data/split_utils.cc b/tensorflow/core/grappler/optimizers/data/split_utils.cc index 1798f1de44f054..54bf9bcd1660f6 100644 --- a/tensorflow/core/grappler/optimizers/data/split_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/split_utils.cc @@ -108,7 +108,7 @@ class InputRewriter { // *new_input_str will be set to the empty string if the input should be // removed, which occurs if it is a control dependency for a node in the first // function. - Status RewriteInput(absl::string_view input_str, string* new_input_str); + absl::Status RewriteInput(absl::string_view input_str, string* new_input_str); private: bool IsInFirstFunction(absl::string_view node_name) { @@ -116,26 +116,27 @@ class InputRewriter { } // Rewrite a control input. input_str is in the form "^node_name" - Status RewriteControlInput(absl::string_view input_str, - string* new_input_str); + absl::Status RewriteControlInput(absl::string_view input_str, + string* new_input_str); // Rewrite an input that is an argument to original_function_. input_str is in // the form "fun_in" or "fun_in:number". - Status RewriteArgumentInput(absl::string_view input_str, - string* new_input_str); + absl::Status RewriteArgumentInput(absl::string_view input_str, + string* new_input_str); // Rewrite an input that is the output of a node. input_str is in the form // "node:out" or "node:out:number" - Status RewriteNodeInput(absl::string_view input_str, string* new_input_str); + absl::Status RewriteNodeInput(absl::string_view input_str, + string* new_input_str); // Rewrites an input, `input_str`, where the node producing `input_str` is in // first_function_ and the node consuming `input_str` is in second_function_. // This function adds an output argument to first_function_ and an input // argument to second_function_. "input_arg_def" is the ArgDef corresponding // to input_str, and must have the type() field set. - Status RewriteCrossFunctionInput(absl::string_view input_str, - const OpDef::ArgDef& input_arg_def, - string* new_input_str); + absl::Status RewriteCrossFunctionInput(absl::string_view input_str, + const OpDef::ArgDef& input_arg_def, + string* new_input_str); string unique_name(const std::string& name) { if (used_names_.count(name) == 0) { @@ -173,8 +174,8 @@ class InputRewriter { std::unordered_set used_names_; }; -Status InputRewriter::RewriteInput(absl::string_view input_str, - string* new_input_str) { +absl::Status InputRewriter::RewriteInput(absl::string_view input_str, + string* new_input_str) { auto iter = input_map_.find(input_str); if (iter != input_map_.end()) { *new_input_str = iter->second; @@ -192,8 +193,8 @@ Status InputRewriter::RewriteInput(absl::string_view input_str, return absl::OkStatus(); } -Status InputRewriter::RewriteControlInput(absl::string_view input_str, - string* new_input_str) { +absl::Status InputRewriter::RewriteControlInput(absl::string_view input_str, + string* new_input_str) { DCHECK_EQ(input_str.at(0), '^'); absl::string_view node_name = input_str.substr(1); if (IsInFirstFunction(node_name)) { @@ -204,8 +205,8 @@ Status InputRewriter::RewriteControlInput(absl::string_view input_str, return absl::OkStatus(); } -Status InputRewriter::RewriteArgumentInput(absl::string_view input_str, - string* new_input_str) { +absl::Status InputRewriter::RewriteArgumentInput(absl::string_view input_str, + string* new_input_str) { std::vector components = absl::StrSplit(input_str, ':'); if (components.size() != 1 && components.size() != 2) { return errors::Internal("Found node with invalid argument input: ", @@ -254,8 +255,8 @@ Status InputRewriter::RewriteArgumentInput(absl::string_view input_str, return RewriteCrossFunctionInput(input_str, *found_arg_def, new_input_str); } -Status InputRewriter::RewriteNodeInput(absl::string_view input_str, - string* new_input_str) { +absl::Status InputRewriter::RewriteNodeInput(absl::string_view input_str, + string* new_input_str) { std::vector components = absl::StrSplit(input_str, ':'); if (components.size() != 2 && components.size() != 3) { return errors::Internal("Found node with invalid node input: ", input_str); @@ -321,7 +322,7 @@ Status InputRewriter::RewriteNodeInput(absl::string_view input_str, return RewriteCrossFunctionInput(input_str, found_arg_def, new_input_str); } -Status InputRewriter::RewriteCrossFunctionInput( +absl::Status InputRewriter::RewriteCrossFunctionInput( absl::string_view input_str, const OpDef::ArgDef& input_arg_def, string* new_input_str) { DCHECK(input_arg_def.type() != DT_INVALID); diff --git a/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.cc b/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.cc index 917c66939d8208..d9519e27b2d0f4 100644 --- a/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.cc +++ b/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.cc @@ -34,10 +34,9 @@ constexpr char kModelDataset[] = "ModelDataset"; } // namespace -Status UsePrivateThreadPool::OptimizeAndCollectStats(Cluster* cluster, - const GrapplerItem& item, - GraphDef* output, - OptimizationStats* stats) { +absl::Status UsePrivateThreadPool::OptimizeAndCollectStats( + Cluster* cluster, const GrapplerItem& item, GraphDef* output, + OptimizationStats* stats) { *output = item.graph; MutableGraphView graph(output); diff --git a/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.h b/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.h index f515b3afb41371..b886d36ae80650 100644 --- a/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.h +++ b/tensorflow/core/grappler/optimizers/data/use_private_thread_pool.h @@ -31,14 +31,15 @@ class UsePrivateThreadPool : public TFDataOptimizerBase { bool UsesFunctionLibrary() const override { return false; } - Status Init( + absl::Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { return absl::OkStatus(); } - Status OptimizeAndCollectStats(Cluster* cluster, const GrapplerItem& item, - GraphDef* output, - OptimizationStats* stats) override; + absl::Status OptimizeAndCollectStats(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output, + OptimizationStats* stats) override; }; } // namespace grappler diff --git a/tensorflow/core/grappler/optimizers/debug_stripper.cc b/tensorflow/core/grappler/optimizers/debug_stripper.cc index d7d32c4e05a7b1..f79f5cdf592305 100644 --- a/tensorflow/core/grappler/optimizers/debug_stripper.cc +++ b/tensorflow/core/grappler/optimizers/debug_stripper.cc @@ -26,8 +26,8 @@ limitations under the License. namespace tensorflow { namespace grappler { -Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item, - GraphDef* output) { +absl::Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) { bool can_optimize = false; for (const NodeDef& node : item.graph.node()) { if (IsAssert(node) || IsCheckNumerics(node) || IsPrint(node)) { diff --git a/tensorflow/core/grappler/optimizers/debug_stripper.h b/tensorflow/core/grappler/optimizers/debug_stripper.h index ab081f85a7b87f..c94257f5a5af7d 100644 --- a/tensorflow/core/grappler/optimizers/debug_stripper.h +++ b/tensorflow/core/grappler/optimizers/debug_stripper.h @@ -32,8 +32,8 @@ class DebugStripper : public GraphOptimizer { bool UsesFunctionLibrary() const override { return false; } - Status Optimize(Cluster* cluster, const GrapplerItem& item, - GraphDef* output) override; + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) override; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc index 7eeb154f46fb2c..8d214e41ce41d8 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc @@ -139,7 +139,7 @@ bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) const { return false; } const OpDef* op_def = nullptr; - Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); + absl::Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); if (!status.ok() || op_def->output_arg_size() == 0) { return false; } @@ -472,7 +472,7 @@ void DependencyOptimizer::CleanControlInputs() { } } -Status DependencyOptimizer::OptimizeDependencies() { +absl::Status DependencyOptimizer::OptimizeDependencies() { SetVector nodes_to_simplify; std::set nodes_to_delete; for (int i = 0; i < optimized_graph_->node_size(); ++i) { @@ -532,7 +532,7 @@ void LongestPathsLowerBounds( } // namespace -Status DependencyOptimizer::TransitiveReduction() { +absl::Status DependencyOptimizer::TransitiveReduction() { // PRECONDITION: optimized_graph_ must be sorted topologically. const int num_nodes = optimized_graph_->node_size(); // Set up a compressed version of the graph to save a constant factor in the @@ -540,7 +540,7 @@ Status DependencyOptimizer::TransitiveReduction() { // highest index of a target of any control output from each node. int num_controls = 0; std::vector> outputs(num_nodes); - std::vector, 2>> control_outputs( + std::vector, 2UL>> control_outputs( num_nodes); // target_range[i] contains the range of node indices for which to compute // longest paths starting from node i. @@ -747,8 +747,9 @@ void DependencyOptimizer::GroupCrossDeviceControlEdges(bool host_granularity) { } } -Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, - GraphDef* optimized_graph) { +absl::Status DependencyOptimizer::Optimize(Cluster* cluster, + const GrapplerItem& item, + GraphDef* optimized_graph) { optimized_graph_ = optimized_graph; *optimized_graph_ = item.graph; nodes_to_preserve_ = item.NodesToPreserve(); @@ -758,7 +759,7 @@ Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, const int num_iterations = 2; for (int iteration = 0; iteration < num_iterations; ++iteration) { GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED(); - Status topo_sort_status; + absl::Status topo_sort_status; // Perform topological sort to prepare the graph for transitive reduction. topo_sort_status = TopologicalSort(optimized_graph_); // Set up index-based graph datastructures to speed up analysis steps below. diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.h b/tensorflow/core/grappler/optimizers/dependency_optimizer.h index 4251a4a559efdd..cc8d704337ba01 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer.h +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.h @@ -37,8 +37,8 @@ class DependencyOptimizer : public GraphOptimizer { bool UsesFunctionLibrary() const override { return false; } - Status Optimize(Cluster* cluster, const GrapplerItem& item, - GraphDef* optimized_graph) override; + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; private: // Returns true if bypassing node does not increase the number of edges or @@ -64,9 +64,9 @@ class DependencyOptimizer : public GraphOptimizer { std::set* nodes_to_delete); // Eliminates redundant control dependencies by computing the transitive // reduction of the graph. - Status TransitiveReduction(); + absl::Status TransitiveReduction(); // Main driver of dependency optimizations. - Status OptimizeDependencies(); + absl::Status OptimizeDependencies(); // Replaces multiple cross-device control edges from the same device with a // single control edge. If `host_granularity` is true then group control // edges from all devices on the same host. diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc index a072a6a93db105..9539466c71611c 100644 --- a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc @@ -60,7 +60,7 @@ TEST_F(DependencyOptimizerTest, NoOp) { DependencyOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); + absl::Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); VerifyGraphsEqual(item.graph, output, __FUNCTION__); @@ -85,7 +85,7 @@ TEST_F(DependencyOptimizerTest, DependenciesDrivenByConstants) { DependencyOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); + absl::Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); // Run the optimizer twice to make sure the rewrite is idempotent. item.graph.Swap(&output); @@ -120,7 +120,7 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop) { DependencyOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); + absl::Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); // Run the optimizer twice to make sure the rewrite is idempotent. item.graph.Swap(&output); @@ -186,7 +186,7 @@ TEST_F(DependencyOptimizerTest, FullTypeForKeptNoop) { DependencyOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); + absl::Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); // Run the optimizer twice to make sure the rewrite is idempotent. item.graph.Swap(&output); @@ -240,7 +240,7 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_RepeatedInput) { DependencyOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); + absl::Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); // Run the optimizer twice to make sure the rewrite is idempotent. item.graph.Swap(&output); @@ -298,7 +298,7 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_SwitchIdentity) { DependencyOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); + absl::Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); EXPECT_EQ(item.graph.node_size() - 1, output.node_size()); @@ -332,7 +332,7 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_NoFetch) { DependencyOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); + absl::Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); TF_CHECK_OK(TopologicalSort(&item.graph)); @@ -352,7 +352,7 @@ TEST_F(DependencyOptimizerTest, RemoveNoOps_EmptyInputOrOutput) { DependencyOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); + absl::Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); // Run the optimizer twice to make sure the rewrite is idempotent. item.graph.Swap(&output); @@ -396,7 +396,7 @@ TEST_F(DependencyOptimizerTest, RemoveNoOps_DeviceBoundaries) { DependencyOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); + absl::Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); // The optimization should be disabled to prevent increasing the number of @@ -428,7 +428,7 @@ TEST_F(DependencyOptimizerTest, RemoveIdentityOps_DeviceBoundaries) { DependencyOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); + absl::Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); // The optimization should be disabled to prevent increasing the number of @@ -451,7 +451,7 @@ TEST_F(DependencyOptimizerTest, RemoveIdentityOps_IdenticalDevices) { DependencyOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); + absl::Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); EXPECT_EQ(item.graph.node_size() - 1, output.node_size()); @@ -483,7 +483,7 @@ TEST_F(DependencyOptimizerTest, RemoveNoOps_SingleInputOrOutput) { DependencyOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); + absl::Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); // Run the optimizer twice to make sure the rewrite is idempotent. item.graph.Swap(&output); @@ -540,7 +540,7 @@ TEST_F(DependencyOptimizerTest, RemoveIdentity) { DependencyOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); + absl::Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); EXPECT_EQ(item.graph.node_size() - 3, output.node_size()); @@ -607,7 +607,7 @@ TEST_F(DependencyOptimizerTest, RemoveIdentity_RepeatedInputs) { item.fetch.push_back("or2"); DependencyOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); + absl::Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); EXPECT_EQ(item.graph.node_size() - 1, output.node_size()); @@ -651,7 +651,7 @@ TEST_F(DependencyOptimizerTest, Transitive_Reduction_Simple) { item.fetch.push_back("neg2"); DependencyOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); + absl::Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); EXPECT_EQ(4, output.node_size()); EXPECT_EQ("neg2", output.node(3).name()); @@ -687,7 +687,7 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_Identity) { DependencyOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); + absl::Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); EXPECT_EQ(item.graph.node_size() - 2, output.node_size()); @@ -728,7 +728,7 @@ TEST_F(DependencyOptimizerTest, IdentityInputs) { DependencyOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); + absl::Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); EXPECT_EQ(6, output.node_size()); @@ -765,7 +765,7 @@ TEST_F(DependencyOptimizerTest, RemoveIdentityN_SwitchInput) { DependencyOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); + absl::Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); EXPECT_EQ(8, output.node_size()); @@ -808,7 +808,7 @@ TEST_F(DependencyOptimizerTest, DoNotRemoveIdentityNWithControlDependency) { DependencyOptimizer optimizer; GraphDef optimized_graph_def; - Status status = optimizer.Optimize(nullptr, item, &optimized_graph_def); + absl::Status status = optimizer.Optimize(nullptr, item, &optimized_graph_def); TF_EXPECT_OK(status); EXPECT_EQ(6, optimized_graph_def.node_size()); @@ -831,7 +831,7 @@ TEST_F(DependencyOptimizerTest, item.fetch = {"result"}; DependencyOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); + absl::Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); VerifyGraphsEqual(item.graph, output, __FUNCTION__); @@ -853,7 +853,7 @@ TEST_F(DependencyOptimizerTest, Identity_DeviceCrossing_ConsumerOnSameDevice) { item.fetch = {"result"}; DependencyOptimizer optimizer; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); + absl::Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); EXPECT_EQ(3, output.node_size()); for (const auto& node : output.node()) { diff --git a/tensorflow/core/grappler/optimizers/evaluation_utils.cc b/tensorflow/core/grappler/optimizers/evaluation_utils.cc index 855635aefc1bb2..865835bc9f588a 100644 --- a/tensorflow/core/grappler/optimizers/evaluation_utils.cc +++ b/tensorflow/core/grappler/optimizers/evaluation_utils.cc @@ -26,7 +26,7 @@ limitations under the License. namespace tensorflow { namespace grappler { -using TensorVector = gtl::InlinedVector; +using TensorVector = absl::InlinedVector; // In order to avoid the overhead of creating a large thread pool, we set a // small default thread count. This value should be revised should DeviceSimple @@ -49,9 +49,9 @@ DeviceSimple::~DeviceSimple() { delete eigen_worker_threads_.workers; } -Status DeviceSimple::MakeTensorFromProto(const TensorProto& tensor_proto, - const AllocatorAttributes alloc_attrs, - Tensor* tensor) { +absl::Status DeviceSimple::MakeTensorFromProto( + const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs, + Tensor* tensor) { Tensor parsed(tensor_proto.dtype()); if (!parsed.FromProto(cpu_allocator(), tensor_proto)) { return errors::InvalidArgument("Cannot parse tensor from tensor_proto."); @@ -60,10 +60,10 @@ Status DeviceSimple::MakeTensorFromProto(const TensorProto& tensor_proto, return absl::OkStatus(); } -Status EvaluateNode(const NodeDef& node, const TensorVector& inputs, - DeviceBase* cpu_device, ResourceMgr* resource_mgr, - TensorVector* output) { - Status status; +absl::Status EvaluateNode(const NodeDef& node, const TensorVector& inputs, + DeviceBase* cpu_device, ResourceMgr* resource_mgr, + TensorVector* output) { + absl::Status status; std::unique_ptr device; if (cpu_device == nullptr) { device.reset(new DeviceSimple()); @@ -81,7 +81,7 @@ Status EvaluateNode(const NodeDef& node, const TensorVector& inputs, params.op_kernel = op_kernel.get(); params.resource_manager = resource_mgr; - gtl::InlinedVector output_attrs; + absl::InlinedVector output_attrs; const int num_outputs = op_kernel->num_outputs(); for (int i = 0; i < num_outputs; i++) { AllocatorAttributes attr; diff --git a/tensorflow/core/grappler/optimizers/evaluation_utils.h b/tensorflow/core/grappler/optimizers/evaluation_utils.h index a146c9a5cad1ef..9ae5cb22bac42b 100644 --- a/tensorflow/core/grappler/optimizers/evaluation_utils.h +++ b/tensorflow/core/grappler/optimizers/evaluation_utils.h @@ -38,9 +38,9 @@ class DeviceSimple : public DeviceBase { DeviceSimple(); ~DeviceSimple(); - Status MakeTensorFromProto(const TensorProto& tensor_proto, - const AllocatorAttributes alloc_attrs, - Tensor* tensor) override; + absl::Status MakeTensorFromProto(const TensorProto& tensor_proto, + const AllocatorAttributes alloc_attrs, + Tensor* tensor) override; Allocator* GetAllocator(AllocatorAttributes attr) override { return cpu_allocator(); @@ -54,10 +54,10 @@ class DeviceSimple : public DeviceBase { const std::string device_type_ = DEVICE_CPU; }; -Status EvaluateNode(const NodeDef& node, - const gtl::InlinedVector& inputs, - DeviceBase* cpu_device, ResourceMgr* resource_mgr, - gtl::InlinedVector* output); +absl::Status EvaluateNode(const NodeDef& node, + const absl::InlinedVector& inputs, + DeviceBase* cpu_device, ResourceMgr* resource_mgr, + absl::InlinedVector* output); } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/function_api_info.cc b/tensorflow/core/grappler/optimizers/function_api_info.cc index d84c742f8bbc6d..7f9206f2f9254c 100644 --- a/tensorflow/core/grappler/optimizers/function_api_info.cc +++ b/tensorflow/core/grappler/optimizers/function_api_info.cc @@ -26,7 +26,7 @@ namespace grappler { FunctionApiInfo::FunctionApiInfo() {} FunctionApiInfo::~FunctionApiInfo() {} -Status FunctionApiInfo::Init(const FunctionDef& function_def) { +absl::Status FunctionApiInfo::Init(const FunctionDef& function_def) { function_type_ = FunctionApiInfo::FunctionType::INFERENCE; for (const auto& attr : function_def.attr()) { if (attr.first == "api_preferred_device") { @@ -120,9 +120,10 @@ bool IsSameSignature(const FunctionDef& f1, const FunctionDef& f2, return true; } -Status ValidateSignature(const string& interface_name, - const std::vector& equiv_funcs, - const FunctionApiInfo::FunctionType function_type) { +absl::Status ValidateSignature( + const string& interface_name, + const std::vector& equiv_funcs, + const FunctionApiInfo::FunctionType function_type) { if (equiv_funcs.size() < 2) return absl::OkStatus(); for (size_t k = 1; k < equiv_funcs.size(); ++k) { const bool check_input = @@ -142,7 +143,7 @@ Status ValidateSignature(const string& interface_name, return absl::OkStatus(); } -Status ValidateSignatures( +absl::Status ValidateSignatures( const std::unordered_map>& intf_to_func, const FunctionApiInfo::FunctionType function_type) { @@ -153,7 +154,7 @@ Status ValidateSignatures( } } // namespace -Status FunctionLibraryApiInfo::Init( +absl::Status FunctionLibraryApiInfo::Init( const FunctionDefLibrary& function_library) { std::unordered_map> infer_funcs; std::unordered_map> fwd_funcs; @@ -197,7 +198,7 @@ Status FunctionLibraryApiInfo::Init( return absl::OkStatus(); } -Status FunctionLibraryApiInfo::GetEquivalentImplementations( +absl::Status FunctionLibraryApiInfo::GetEquivalentImplementations( const string& function_name, std::vector* other_functions) const { const auto func_it = func_info_.find(function_name); if (func_it == func_info_.end()) return absl::OkStatus(); diff --git a/tensorflow/core/grappler/optimizers/function_api_info.h b/tensorflow/core/grappler/optimizers/function_api_info.h index c32cd5c9ca4c04..e2ae234fbb0d59 100644 --- a/tensorflow/core/grappler/optimizers/function_api_info.h +++ b/tensorflow/core/grappler/optimizers/function_api_info.h @@ -39,7 +39,7 @@ class FunctionApiInfo { BACKWARD, }; - Status Init(const FunctionDef& function_def); + absl::Status Init(const FunctionDef& function_def); const string& interface_name() const; const string& preferred_device() const; @@ -75,9 +75,9 @@ class FunctionLibraryApiInfo { FunctionLibraryApiInfo(); virtual ~FunctionLibraryApiInfo(); // Populate the internal field for the functions within the function_library. - Status Init(const FunctionDefLibrary& function_library); + absl::Status Init(const FunctionDefLibrary& function_library); - Status GetEquivalentImplementations( + absl::Status GetEquivalentImplementations( const string& function_name, std::vector* other_functions) const; const FunctionApiInfo* GetApiInfo(const string& function_name) const; diff --git a/tensorflow/core/grappler/optimizers/function_api_info_test.cc b/tensorflow/core/grappler/optimizers/function_api_info_test.cc index 57c10a18460518..a03edb39f35b15 100644 --- a/tensorflow/core/grappler/optimizers/function_api_info_test.cc +++ b/tensorflow/core/grappler/optimizers/function_api_info_test.cc @@ -107,7 +107,7 @@ bool CheckEquivImpl(const FunctionLibraryApiInfo& lib_api_info, const string& func_name, const std::vector& expected_other) { std::vector other_impl; - Status status = + absl::Status status = lib_api_info.GetEquivalentImplementations(func_name, &other_impl); EXPECT_EQ(status, absl::OkStatus()); const std::unordered_set actual(other_impl.begin(), other_impl.end()); @@ -181,7 +181,7 @@ TEST(FunctionApiInfoTest, MismatchedArguments) { FunctionDefLibrary func_lib; PopulateSampleLibrary(/* mismatch_args */ true, &func_lib); FunctionLibraryApiInfo lib_api_info; - const Status ret = lib_api_info.Init(func_lib); + const absl::Status ret = lib_api_info.Init(func_lib); EXPECT_FALSE(ret.ok()); } diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc index b6920f8f3e6374..330cb62e19c3a8 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc @@ -116,7 +116,7 @@ class FakeDevice : public Device { public: FakeDevice(Env* env, const string& device) : Device(env, attr(device)) {} explicit FakeDevice(const string& device) : FakeDevice(nullptr, device) {} - Status Sync() override { return absl::OkStatus(); } + absl::Status Sync() override { return absl::OkStatus(); } private: static DeviceAttributes attr(const string& device) { @@ -465,11 +465,11 @@ FunctionDefLibrary PruneFunctionLibrary(const FunctionLibraryDefinition& flib, } // Push all constant inputs of an instantiating node into the function body. -Status PushDownConstInputs(const NodeDef& func_node, - const FunctionOptimizerContext& ctx, - GrapplerFunctionItem* item, - absl::flat_hash_set* const_inputs, - absl::flat_hash_set* control_deps) { +absl::Status PushDownConstInputs(const NodeDef& func_node, + const FunctionOptimizerContext& ctx, + GrapplerFunctionItem* item, + absl::flat_hash_set* const_inputs, + absl::flat_hash_set* control_deps) { // Record node control dependencies in the control_deps set. const auto record_control_deps = [&](const NodeDef* const_input) { for (int i = const_input->input_size() - 1; i >= 0; --i) { @@ -585,10 +585,9 @@ void RemoveUnusedOutputsTypes(const FunctionSpecialization& specialization, } } -Status UpdateSpecializedFunctionCallSite(const FunctionDef& func, - const NodeDef& func_node, - const string& specialized_func_name, - NodeDef* specialized_func_node) { +absl::Status UpdateSpecializedFunctionCallSite( + const FunctionDef& func, const NodeDef& func_node, + const string& specialized_func_name, NodeDef* specialized_func_node) { if (IsDirectFunctionCall(func, func_node)) { specialized_func_node->set_op(specialized_func_name); @@ -607,7 +606,7 @@ Status UpdateSpecializedFunctionCallSite(const FunctionDef& func, // function specialization. Function specialization might change the number of // inputs and outputs, so we have to make sure that graph node is updated // accordingly. -Status UpdateSpecializedFunctionNode( +absl::Status UpdateSpecializedFunctionNode( const FunctionDef& func, const NodeDef& func_node, const FunctionSpecialization& specialization, NodeDef* specialized_func_node) { @@ -643,7 +642,7 @@ Status UpdateSpecializedFunctionNode( return absl::OkStatus(); } -Status InitializeFunctionSpecializationSignature( +absl::Status InitializeFunctionSpecializationSignature( const NodeDef& func_node, const FunctionDef& func, const AttrSlice& func_instantiation_attr, const FunctionOptimizerContext& ctx, FunctionSpecializationSignature* sig) { @@ -683,9 +682,10 @@ string SpecializedFunctionName(const FunctionOptimizerContext& ctx, absl::StrReplaceAll(func_node.name(), {{"/", "_"}}), ctx.item().id); } -Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func, - FunctionOptimizerContext* ctx, - GraphDef* optimized_graph) { +absl::Status SpecializeFunction(const NodeDef& func_node, + const FunctionDef& func, + FunctionOptimizerContext* ctx, + GraphDef* optimized_graph) { VLOG(2) << "Specialize function call: " << SummarizeNodeDef(func_node); const AttrSlice func_instantiation_attr = @@ -880,7 +880,7 @@ const bool IsExemptFromSideEffectsExecutionValidation(const string& op) { // // When function executed via FunctionLibraryRuntime we do not have to check // this, because `PruneFunctionBody` has special pruning rules for stateful ops. -Status ValidateSideEffectsExecution( +absl::Status ValidateSideEffectsExecution( const FunctionBody& fbody, OutputControlSource output_control_source, bool has_outgoing_control_edges, bool validate_outgoing_control_edge = true) { @@ -947,8 +947,8 @@ Status ValidateSideEffectsExecution( } // Validates that no dead tensor can reach function output. -Status ValidateNoDeadOutputs(const FunctionLibraryDefinition& flib_def, - const FunctionBody& fbody) { +absl::Status ValidateNoDeadOutputs(const FunctionLibraryDefinition& flib_def, + const FunctionBody& fbody) { absl::flat_hash_set output_nodes = {fbody.ret_nodes.begin(), fbody.ret_nodes.end()}; @@ -1012,15 +1012,15 @@ Status ValidateNoDeadOutputs(const FunctionLibraryDefinition& flib_def, } // Makes an instance of FunctionBody for inlining from a Node. -Status MakeFunctionBodyForInlining(const Node& node, - const FunctionLibraryDefinition& flib_def, - std::unique_ptr* fbody) { +absl::Status MakeFunctionBodyForInlining( + const Node& node, const FunctionLibraryDefinition& flib_def, + std::unique_ptr* fbody) { VLOG(3) << "Make function body for inlining: " << SummarizeNode(node); // Finds a FunctionDef in a library and verifies that it exists. const auto find_fdef = [&flib_def, &node]( const string& name, - const FunctionDef** fdef) -> Status { + const FunctionDef** fdef) -> absl::Status { if ((*fdef = flib_def.Find(name)) == nullptr) { return absl::InternalError(absl::StrCat( "Was not able to find a function definition (name=", name, @@ -1208,10 +1208,10 @@ void AddFrameForwardingControlEdge(const std::vector& info, // ops (Switch/Merge/...). // // Runs a placer after inlining, to keep all nodes in a graph placed. -Status InlineFunctionCalls(const GrapplerItem& item, - const RewriterConfig::Toggle opt_level, - const bool lower_control_flow, - GraphDef* output_graph) { +absl::Status InlineFunctionCalls(const GrapplerItem& item, + const RewriterConfig::Toggle opt_level, + const bool lower_control_flow, + GraphDef* output_graph) { bool is_aggressive = opt_level == RewriterConfig::AGGRESSIVE; VLOG(2) << "Inline function calls: grappler_item_id=" << item.id << " (aggressive_mode=" << is_aggressive << ")"; @@ -1330,7 +1330,7 @@ Status InlineFunctionCalls(const GrapplerItem& item, } // Basic validation rules defined in common_runtime shared by all functions. - Status can_inline_function_call = + absl::Status can_inline_function_call = ValidateInlining(n, fbody.get(), inline_options); // Additional validation rules defined only in Grappler. @@ -1448,7 +1448,7 @@ void RestoreTensorMapping(const FunctionOptimizerContext& ctx, } // namespace -Status FunctionOptimizer::RunFunctionOptimizerPass( +absl::Status FunctionOptimizer::RunFunctionOptimizerPass( const GrapplerItem& item, GraphDef* optimized_graph) const { VLOG(3) << "Run function optimizer pass: grappler_item_id=" << item.id; @@ -1498,7 +1498,8 @@ Status FunctionOptimizer::RunFunctionOptimizerPass( if (specialization_worthy && !no_specialize) { // TODO(ezhulenev): Specialize function call if input has a known shape. // Specialize function body for its instantiation attributes and inputs. - Status status = SpecializeFunction(node, *func, &ctx, optimized_graph); + absl::Status status = + SpecializeFunction(node, *func, &ctx, optimized_graph); if (!status.ok() && is_graph_modified()) { return status; } else if (!status.ok() && !is_graph_modified()) { @@ -1523,8 +1524,8 @@ Status FunctionOptimizer::RunFunctionOptimizerPass( return absl::OkStatus(); } -Status FunctionOptimizer::Optimize(Cluster*, const GrapplerItem& item, - GraphDef* optimized_graph) { +absl::Status FunctionOptimizer::Optimize(Cluster*, const GrapplerItem& item, + GraphDef* optimized_graph) { // Nothing to do here. if (item.graph.library().function_size() == 0) { return absl::AbortedError("Nothing to do."); diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.h b/tensorflow/core/grappler/optimizers/function_optimizer.h index 7b2712c9a7b281..8f8eb7325326fe 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer.h +++ b/tensorflow/core/grappler/optimizers/function_optimizer.h @@ -35,8 +35,8 @@ class FunctionOptimizer : public GraphOptimizer { bool UsesFunctionLibrary() const override { return true; } - Status Optimize(Cluster* cluster, const GrapplerItem& item, - GraphDef* optimized_graph) override; + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; private: friend class FunctionOptimizerTest; @@ -46,8 +46,8 @@ class FunctionOptimizer : public GraphOptimizer { // `optimized_graph`. Function call nodes inlined or specialized, and // instantiated function body or specialized function call nodes will be added // to the `optimized_graph`. - Status RunFunctionOptimizerPass(const GrapplerItem& item, - GraphDef* optimized_graph) const; + absl::Status RunFunctionOptimizerPass(const GrapplerItem& item, + GraphDef* optimized_graph) const; RewriterConfig::Toggle opt_level_; bool lower_control_flow_; diff --git a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc index 9e03bbad6b91a3..1c14f48402c352 100644 --- a/tensorflow/core/grappler/optimizers/function_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/function_optimizer_test.cc @@ -128,7 +128,7 @@ TEST_F(FunctionOptimizerTest, InlineFunction_FixedTypeFunction) { }); GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); + absl::Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); // Calls to XTimesTwo were removed from the graph. @@ -478,7 +478,7 @@ TEST_F(FunctionOptimizerTest, InlineSymbolicGradientNoInlineFunc) { *item.graph.mutable_library()->add_function() = func; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); + absl::Status status = optimizer.Optimize(nullptr, item, &output); // The optimizer should succeed but the graphs should be the same. TF_EXPECT_OK(status); CompareGraphs(item.graph, output); @@ -2095,7 +2095,7 @@ TEST_F(FunctionOptimizerTest, PruningUselessLibraryFunctions) { test::function::XTimes16(), }); GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); + absl::Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); ASSERT_EQ(output.library().function().size(), 1); @@ -2138,7 +2138,7 @@ TEST_F(FunctionOptimizerTest, PreserveSaverDefFunctions) { item.restore_op = "Restore"; item.save_op = "Save"; GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); + absl::Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); ASSERT_EQ(output.library().function().size(), 3); diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer.cc index 726bd0b3325a22..5cd5bafc759148 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer.cc @@ -190,8 +190,8 @@ inline std::pair GetSrcAndDstDataFormats( return {src_format, dst_format}; } -Status ExpandLayoutSensitiveOp(TransposeContext* context, - TransposerFactory* transposer_factory) { +absl::Status ExpandLayoutSensitiveOp(TransposeContext* context, + TransposerFactory* transposer_factory) { const int num_nodes = context->num_nodes; for (int i = 0; i < num_nodes; ++i) { auto* node_view = context->graph_view->GetNode(i); @@ -200,7 +200,7 @@ Status ExpandLayoutSensitiveOp(TransposeContext* context, std::shared_ptr transposer = transposer_factory->GetTransposer(*node_def); if (transposer == nullptr) { - return Status( + return absl::Status( absl::StatusCode::kNotFound, absl::StrCat( "Layout sensitive operation should have a transposer. Node: ", @@ -212,8 +212,8 @@ Status ExpandLayoutSensitiveOp(TransposeContext* context, return absl::OkStatus(); } -Status ExpandLayoutAgnosticOp(TransposeContext* context, - TransposerFactory* transposer_factory) { +absl::Status ExpandLayoutAgnosticOp(TransposeContext* context, + TransposerFactory* transposer_factory) { const int num_nodes = context->num_nodes; for (int i = 0; i < num_nodes; ++i) { auto* node_view = context->graph_view->GetNode(i); @@ -221,7 +221,7 @@ Status ExpandLayoutAgnosticOp(TransposeContext* context, if (IsLayoutAgnosticOp(*node_def)) { const auto& transposer = transposer_factory->GetTransposer(*node_def); if (transposer == nullptr) { - return Status( + return absl::Status( absl::StatusCode::kNotFound, absl::StrCat( "Layout agnostic operation should have a transposer. Node: ", @@ -297,7 +297,7 @@ inline bool IsCancellableNodePair( IsCancellableDataFormatNodePair(fanout_transpose, fanin_transpose); } -Status EraseCancellableNodes(TransposeContext* context) { +absl::Status EraseCancellableNodes(TransposeContext* context) { const int original_num_nodes = context->num_nodes; utils::MutableGraphView* graph_view = context->graph_view.get(); utils::Mutation* mutation = graph_view->GetMutationBuilder(); @@ -344,7 +344,7 @@ Status EraseCancellableNodes(TransposeContext* context) { // // From: Transpose[NHWC->NCHW] -> Pad[paddings] -> Transpose[NCHW->NHWC] // To: Pad[Permute(paddings)] -Status EraseCancellableNodesAroundPad(TransposeContext* context) { +absl::Status EraseCancellableNodesAroundPad(TransposeContext* context) { utils::MutableGraphView* graph_view = context->graph_view.get(); utils::Mutation* mutation = graph_view->GetMutationBuilder(); @@ -452,7 +452,7 @@ Status EraseCancellableNodesAroundPad(TransposeContext* context) { return mutation->Apply(); } -Status EraseOutputShapeAttrs(TransposeContext* context) { +absl::Status EraseOutputShapeAttrs(TransposeContext* context) { utils::MutableGraphView* graph_view = context->graph_view.get(); utils::Mutation* mutation = graph_view->GetMutationBuilder(); const int num_nodes = graph_view->NumNodes(); @@ -473,9 +473,9 @@ Status EraseOutputShapeAttrs(TransposeContext* context) { // When there is only CPU, there will be no conversion by default, unless user // chose to convert the graph to a desired format. Currently, NCHW -> NHWC // format conversion is available on CPU. -Status GenericLayoutOptimizer::Optimize(Cluster* cluster, - const GrapplerItem& item, - GraphDef* output) { +absl::Status GenericLayoutOptimizer::Optimize(Cluster* cluster, + const GrapplerItem& item, + GraphDef* output) { if (cluster == nullptr) { LOG(WARNING) << "generic layout optimizer was called with cluster == nullptr"; @@ -483,7 +483,7 @@ Status GenericLayoutOptimizer::Optimize(Cluster* cluster, } if (!enforced_layout_.empty() && enforced_layout_ != "NHWC" && enforced_layout_ != "NCHW") { - return Status( + return absl::Status( absl::StatusCode::kInvalidArgument, absl::StrCat("Invalid value for enforced_layout: ", enforced_layout_, ". Supported layouts: 'NHWC', 'NCHW'.")); diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer.h b/tensorflow/core/grappler/optimizers/generic_layout_optimizer.h index a97cf4abe676d6..61a578fabbef72 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer.h +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer.h @@ -47,8 +47,8 @@ class GenericLayoutOptimizer : public GraphOptimizer { bool UsesFunctionLibrary() const override { return false; } - Status Optimize(Cluster* cluster, const GrapplerItem& item, - GraphDef* output) override; + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) override; private: RewriterConfig::Toggle opt_level_; diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc index fc74b1ea21d8ef..6578d30df2035b 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_test.cc @@ -284,7 +284,7 @@ TEST_F(GenericLayoutOptimizerTest, OptimizeSimpleConv2DGraph) { GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); - Status status; + absl::Status status; utils::GraphView graph_view(&output, &status); TF_ASSERT_OK(status); @@ -311,7 +311,7 @@ TEST_F(GenericLayoutOptimizerTest, PreserveFetch) { GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); - Status status; + absl::Status status; utils::GraphView graph_view(&output, &status); TF_ASSERT_OK(status); auto* conv_node = graph_view.GetNode("Conv2D"); @@ -330,7 +330,7 @@ TEST_F(GenericLayoutOptimizerTest, EmptyDevice) { GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); - Status status; + absl::Status status; utils::GraphView graph_view(&output, &status); TF_ASSERT_OK(status); auto* conv_node = graph_view.GetNode("Conv2D"); @@ -354,7 +354,7 @@ TEST_F(GenericLayoutOptimizerTest, GPUDevice) { GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); - Status status; + absl::Status status; utils::GraphView graph_view(&output, &status); TF_ASSERT_OK(status); auto* conv_node = graph_view.GetNode("Conv2D"); @@ -373,7 +373,7 @@ TEST_F(GenericLayoutOptimizerTest, CPUDevice) { GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); - Status status; + absl::Status status; utils::GraphView graph_view(&output, &status); TF_ASSERT_OK(status); auto* conv_node = graph_view.GetNode("Conv2D"); @@ -396,7 +396,7 @@ TEST_F(GenericLayoutOptimizerTest, NoOptimizeIntegerConvolution) { GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); - Status status; + absl::Status status; utils::GraphView graph_view(&output, &status); TF_ASSERT_OK(status); auto* conv_node = graph_view.GetNode("Conv2D"); @@ -418,7 +418,7 @@ TEST_F(GenericLayoutOptimizerTest, Connectivity) { // middle are layout agnostic). If the graph is already in topological order, // the problem is easier, where layout optimizer only needs to check // single-hop connectivity. - Status status; + absl::Status status; utils::GraphView graph_view_original(&item.graph, &status); const int i1_index = graph_view_original.GetNode("i1")->node_index(); const int i2_index = graph_view_original.GetNode("i2")->node_index(); @@ -452,7 +452,7 @@ TEST_F(GenericLayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) { GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); - Status status; + absl::Status status; utils::GraphView graph_view(&output, &status); TF_ASSERT_OK(status); auto* conv2d_backprop_node = graph_view.GetNode("Conv2DBackpropInput"); @@ -487,7 +487,7 @@ TEST_F(GenericLayoutOptimizerTest, Conv2DDataFormatVecPermuteCollapse) { // // Graph after collapsion: // input -> T -> conv2d -> shape -> fill -> T' -> output - Status status; + absl::Status status; utils::GraphView graph_view(&output, &status); TF_ASSERT_OK(status); auto* conv2d_node = graph_view.GetNode("Conv2D"); @@ -564,7 +564,7 @@ TEST_F(GenericLayoutOptimizerTest, DoNotPruneNonAddedCancellableTransposes) { GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); - Status status; + absl::Status status; utils::GraphView graph_view(&output, &status); TF_ASSERT_OK(status); @@ -710,7 +710,7 @@ TEST_F(GenericLayoutOptimizerTest, PreserveInputShapes) { GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); - Status status; + absl::Status status; utils::GraphView graph_view(&output, &status); TF_ASSERT_OK(status); @@ -735,7 +735,7 @@ TEST_F(GenericLayoutOptimizerTest, OptimizeSimpleConv3DGraph_CPU) { GraphDef output; TF_ASSERT_OK(optimizer.Optimize(virtual_cluster_.get(), item, &output)); - Status status; + absl::Status status; utils::GraphView graph_view(&output, &status); TF_ASSERT_OK(status); diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc index 70653a0a643606..2854810e3c040f 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.cc @@ -155,7 +155,7 @@ bool IsHostMemory(const NodeDef& node, int output_port) { DeviceNameUtils::ParsedName parsed_name; if (DeviceNameUtils::ParseFullName(node.device(), &parsed_name)) { DeviceType device_type(parsed_name.type); - Status s = FindKernelDef(device_type, node, nullptr, nullptr); + absl::Status s = FindKernelDef(device_type, node, nullptr, nullptr); if (s.ok()) { tensorflow::MemoryTypeVector in_mtypes; tensorflow::MemoryTypeVector out_mtypes; @@ -237,17 +237,16 @@ class ScopedDataFormatUpgrader { // TransposeContext. -Status TransposeContext::InitializeTransposeContext(bool assume_valid_feeds, - const GrapplerItem& item, - const Cluster* cluster, - TransposeContext* context) { +absl::Status TransposeContext::InitializeTransposeContext( + bool assume_valid_feeds, const GrapplerItem& item, const Cluster* cluster, + TransposeContext* context) { DCHECK(context != nullptr); context->graph_properties = std::make_unique(item); TF_RETURN_IF_ERROR( context->graph_properties->InferStatically(assume_valid_feeds)); TF_RETURN_IF_ERROR( context->graph_properties->AnnotateOutputShapes(&context->graph)); - Status status; + absl::Status status; context->graph_view = std::make_unique(&context->graph, &status); TF_RETURN_IF_ERROR(status); @@ -299,12 +298,10 @@ bool Transposer::ShouldProcess(const TransposeContext& context, !(node.NumRegularFanouts() == 0 && node.NumControlledFanouts() == 0); } -Status Transposer::CreateConstPermNode(TransposeContext* context, - absl::string_view node_name, - absl::string_view device, - absl::Span permutation, - absl::string_view control_node_name, - utils::MutationNewNode* added_node) { +absl::Status Transposer::CreateConstPermNode( + TransposeContext* context, absl::string_view node_name, + absl::string_view device, absl::Span permutation, + absl::string_view control_node_name, utils::MutationNewNode* added_node) { auto* graph_view = context->graph_view.get(); DCHECK(!graph_view->HasNode(node_name)); @@ -329,13 +326,13 @@ Status Transposer::CreateConstPermNode(TransposeContext* context, tensor.AsProtoTensorContent(attr_tensor.mutable_tensor()); node.mutable_attr()->insert({"value", attr_tensor}); - Status status; + absl::Status status; *added_node = graph_view->GetMutationBuilder()->AddNode(std::move(node), &status); return status; } -Status Transposer::CreateTransposeNode( +absl::Status Transposer::CreateTransposeNode( TransposeContext* context, absl::string_view name_format, const DataType& data_type, absl::string_view device, TensorShapeProto fanin_shape, absl::Span permutation, @@ -380,16 +377,15 @@ Status Transposer::CreateTransposeNode( // Connect const_perm_node to 2nd input of transpose_node. node.add_input(const_perm_node_name); - Status status; + absl::Status status; *added_node = graph_view->GetMutationBuilder()->AddNode(std::move(node), &status); return status; } -Status Transposer::UpdateFaninEdgesWithOp(TransposeContext* context, - absl::Span dst_ports, - utils::MutableNodeView* dst_node, - absl::string_view op) { +absl::Status Transposer::UpdateFaninEdgesWithOp( + TransposeContext* context, absl::Span dst_ports, + utils::MutableNodeView* dst_node, absl::string_view op) { const bool is_in_frame = context->frames.IsInFrame(*dst_node->node()); for (int dst_port : dst_ports) { auto& fanin_port = dst_node->GetRegularFanin(dst_port); @@ -406,10 +402,9 @@ Status Transposer::UpdateFaninEdgesWithOp(TransposeContext* context, return absl::OkStatus(); } -Status Transposer::UpdateFanoutEdgesWithOp(TransposeContext* context, - absl::Span src_ports, - utils::MutableNodeView* src_node, - absl::string_view op) { +absl::Status Transposer::UpdateFanoutEdgesWithOp( + TransposeContext* context, absl::Span src_ports, + utils::MutableNodeView* src_node, absl::string_view op) { // Update attr _output_shapes for output ports. const auto* output_shape_attr = src_node->GetAttr(kAttrOutputShape); AttrValue shape_attr_copy; @@ -452,7 +447,7 @@ Status Transposer::UpdateFanoutEdgesWithOp(TransposeContext* context, return absl::OkStatus(); } -Status Transposer::CreateDataFormatNode( +absl::Status Transposer::CreateDataFormatNode( TransposeContext* context, absl::string_view node_name, absl::string_view op, absl::string_view device, const DataType& data_type, bool is_fanin_on_host, bool is_src_format_to_dst_format, @@ -491,13 +486,13 @@ Status Transposer::CreateDataFormatNode( // Add place holder for 1st input field. node.add_input(""); - Status status; + absl::Status status; *added_node = graph_view->GetMutationBuilder()->AddNode(std::move(node), &status); return status; } -Status Transposer::UpdateEdge( +absl::Status Transposer::UpdateEdge( TransposeContext* context, absl::string_view name_format, absl::string_view op, const AttrValue* input_shape, bool is_in_frame, bool is_src_format_to_dst_format, const int src_port, const int dst_port, @@ -551,10 +546,10 @@ Status Transposer::UpdateEdge( is_src_format_to_dst_format, &added_node)); added_node_name = node_name; } else { - return Status(absl::StatusCode::kInvalidArgument, - absl::StrCat("Unsupported op \"", op, - "\". Supported ops are Transpose, " - "DataFormatVecPerm, DataFormatDimMap.")); + return absl::Status(absl::StatusCode::kInvalidArgument, + absl::StrCat("Unsupported op \"", op, + "\". Supported ops are Transpose, " + "DataFormatVecPerm, DataFormatDimMap.")); } // Connect src_node to 1st input of added_node. @@ -702,8 +697,8 @@ inline string GetLayoutSensitiveNodeDataFormat( return ""; } -Status LayoutSensitiveOpTransposer::UpdateNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status LayoutSensitiveOpTransposer::UpdateNode( + TransposeContext* context, utils::MutableNodeView* node) { utils::Mutation* mutation = context->graph_view->GetMutationBuilder(); AttrValue data_format_attr; data_format_attr.set_s(context->dst_format); @@ -742,7 +737,7 @@ Status LayoutSensitiveOpTransposer::UpdateNode(TransposeContext* context, return absl::OkStatus(); } -Status DefaultLayoutSensitiveOpTransposer::TransposeNode( +absl::Status DefaultLayoutSensitiveOpTransposer::TransposeNode( TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsDefaultLayoutSensitiveOp(*node->node())); const int rank = GetFanoutPortRank(*node, 0); @@ -762,8 +757,8 @@ Status DefaultLayoutSensitiveOpTransposer::TransposeNode( return context->graph_view->GetMutationBuilder()->Apply(); } -Status AvgPoolGradTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status AvgPoolGradTransposer::TransposeNode( + TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsAvgPoolGrad(*node->node())); if (!ShouldProcess(*context, *node) || !IsFaninPortRankN(*node, 1, 4)) { return absl::OkStatus(); @@ -779,8 +774,8 @@ Status AvgPoolGradTransposer::TransposeNode(TransposeContext* context, return context->graph_view->GetMutationBuilder()->Apply(); } -Status BiasAddTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status BiasAddTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { // This TransposeNode allows for BiasAdd but not BiasAddV1, since BiasAdd // supports different data format. DCHECK(IsBiasAddV2(*node->node())); @@ -805,8 +800,8 @@ Status BiasAddTransposer::TransposeNode(TransposeContext* context, return context->graph_view->GetMutationBuilder()->Apply(); } -Status BiasAddGradTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status BiasAddGradTransposer::TransposeNode( + TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsBiasAddGrad(*node->node())); const int rank = GetFaninPortRank(*node, 0); if (rank != 4 && rank != 5) { @@ -831,7 +826,7 @@ Status BiasAddGradTransposer::TransposeNode(TransposeContext* context, return context->graph_view->GetMutationBuilder()->Apply(); } -Status Conv2DBackpropFilterTransposer::TransposeNode( +absl::Status Conv2DBackpropFilterTransposer::TransposeNode( TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsConv2DBackpropFilter(*node->node()) || IsDepthwiseConv2dNativeBackpropFilter(*node->node())); @@ -850,7 +845,7 @@ Status Conv2DBackpropFilterTransposer::TransposeNode( return context->graph_view->GetMutationBuilder()->Apply(); } -Status Conv2DBackpropInputTransposer::TransposeNode( +absl::Status Conv2DBackpropInputTransposer::TransposeNode( TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsConv2DBackpropInput(*node->node()) || IsDepthwiseConv2dNativeBackpropInput(*node->node())); @@ -883,8 +878,8 @@ Status Conv2DBackpropInputTransposer::TransposeNode( return context->graph_view->GetMutationBuilder()->Apply(); } -Status Conv3DTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status Conv3DTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { DCHECK(IsConv3D(*node->node())); const int rank = GetFanoutPortRank(*node, 0); if (rank != 5) { @@ -903,7 +898,7 @@ Status Conv3DTransposer::TransposeNode(TransposeContext* context, return context->graph_view->GetMutationBuilder()->Apply(); } -Status Conv3DBackpropFilterTransposer::TransposeNode( +absl::Status Conv3DBackpropFilterTransposer::TransposeNode( TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsConv3DBackpropFilterV2(*node->node())); const int rank = GetFanoutPortRank(*node, 0); @@ -926,7 +921,7 @@ Status Conv3DBackpropFilterTransposer::TransposeNode( return context->graph_view->GetMutationBuilder()->Apply(); } -Status Conv3DBackpropInputTransposer::TransposeNode( +absl::Status Conv3DBackpropInputTransposer::TransposeNode( TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsConv3DBackpropInputV2(*node->node())); const int rank = GetFanoutPortRank(*node, 0); @@ -948,8 +943,8 @@ Status Conv3DBackpropInputTransposer::TransposeNode( return context->graph_view->GetMutationBuilder()->Apply(); } -Status FusedBatchNormExTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status FusedBatchNormExTransposer::TransposeNode( + TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsFusedBatchNormEx(*node->node())); if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) { return absl::OkStatus(); @@ -978,7 +973,7 @@ bool FusedBatchNormGradTransposer::IsTraining( return false; } -Status FusedBatchNormGradTransposer::TransposeNode( +absl::Status FusedBatchNormGradTransposer::TransposeNode( TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsFusedBatchNormGrad(*node->node())); const int rank = GetFanoutPortRank(*node, 0); @@ -999,8 +994,8 @@ Status FusedBatchNormGradTransposer::TransposeNode( return context->graph_view->GetMutationBuilder()->Apply(); } -Status MaxPoolV2Transposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status MaxPoolV2Transposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { DCHECK(IsMaxPoolV2(*node->node())); // We check data_input's shape instead, because the shape inference of // MaxPoolV2 is not able to infer the shape when ksize or strides is not @@ -1022,8 +1017,8 @@ Status MaxPoolV2Transposer::TransposeNode(TransposeContext* context, return context->graph_view->GetMutationBuilder()->Apply(); } -Status MaxPool3DTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status MaxPool3DTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { DCHECK(IsMaxPool3D(*node->node())); // We check data_input's shape instead, because the shape inference of // MaxPool3D is not able to infer the shape when ksize or strides is not @@ -1044,8 +1039,8 @@ Status MaxPool3DTransposer::TransposeNode(TransposeContext* context, return context->graph_view->GetMutationBuilder()->Apply(); } -Status MaxPoolGradTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status MaxPoolGradTransposer::TransposeNode( + TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsMaxPoolGrad(*node->node()) || IsMaxPoolGradGradV1(*node->node())); if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) { return absl::OkStatus(); @@ -1060,8 +1055,8 @@ Status MaxPoolGradTransposer::TransposeNode(TransposeContext* context, return context->graph_view->GetMutationBuilder()->Apply(); } -Status MaxPoolGradV2Transposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status MaxPoolGradV2Transposer::TransposeNode( + TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsMaxPoolGradV2(*node->node()) || IsMaxPoolGradGradV2(*node->node())); if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4)) { return absl::OkStatus(); @@ -1186,7 +1181,7 @@ std::vector LayoutAgnosticOpTransposer::GetVariadicNDFaninPorts( return ports; } -Status DefaultLayoutAgnosticOpTransposer::TransposeNode( +absl::Status DefaultLayoutAgnosticOpTransposer::TransposeNode( TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsDefaultLayoutAgnosticOp(*node->node())); const int rank = GetFanoutPortRank(*node, 0); @@ -1206,8 +1201,8 @@ Status DefaultLayoutAgnosticOpTransposer::TransposeNode( return context->graph_view->GetMutationBuilder()->Apply(); } -Status AddNTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status AddNTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { DCHECK(IsAddN(*node->node())); const int rank = GetFanoutPortRank(*node, 0); if (rank != 4 && rank != 5) { @@ -1252,7 +1247,7 @@ std::vector BinaryOpTransposer::GetNDDataFaninPorts( return values; } -Status BinaryOpTransposer::AddNodeReshape( +absl::Status BinaryOpTransposer::AddNodeReshape( utils::Mutation* mutation, absl::string_view node_name, absl::string_view node_device, absl::string_view input_name, absl::string_view shape_const_node_name, const DataType& data_type) { @@ -1271,12 +1266,12 @@ Status BinaryOpTransposer::AddNodeReshape( attr_type_params.set_type(data_type); new_node.mutable_attr()->insert({"T", attr_type_params}); - Status status; + absl::Status status; mutation->AddNode(std::move(new_node), &status); return status; } -Status BinaryOpTransposer::AddNodeShapeConst( +absl::Status BinaryOpTransposer::AddNodeShapeConst( utils::Mutation* mutation, absl::string_view node_name, absl::string_view node_device, bool node_in_frame, int num_channels, absl::string_view depended_node, int rank) { @@ -1304,14 +1299,13 @@ Status BinaryOpTransposer::AddNodeShapeConst( new_node.add_input(AsControlDependency(string(depended_node))); } - Status status; + absl::Status status; mutation->AddNode(std::move(new_node), &status); return status; } -Status BinaryOpTransposer::MaybeReshapeVectorFanin(TransposeContext* context, - utils::MutableNodeView* node, - int rank) { +absl::Status BinaryOpTransposer::MaybeReshapeVectorFanin( + TransposeContext* context, utils::MutableNodeView* node, int rank) { int vector_index = -1; if (IsNDOperateWithMD(*node, rank, 1)) { vector_index = 1; @@ -1352,8 +1346,8 @@ Status BinaryOpTransposer::MaybeReshapeVectorFanin(TransposeContext* context, return absl::OkStatus(); } -Status BinaryOpTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status BinaryOpTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { DCHECK(IsBinaryOp(*node->node())); const int rank = GetFanoutPortRank(*node, 0); if (rank != 4 && rank != 5) { @@ -1374,8 +1368,8 @@ Status BinaryOpTransposer::TransposeNode(TransposeContext* context, return context->graph_view->GetMutationBuilder()->Apply(); } -Status ConcatOpTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status ConcatOpTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { DCHECK(IsConcat(*node->node())); const int rank = GetFanoutPortRank(*node, 0); if (rank != 4 && rank != 5) { @@ -1404,8 +1398,8 @@ Status ConcatOpTransposer::TransposeNode(TransposeContext* context, return context->graph_view->GetMutationBuilder()->Apply(); } -Status FillOpTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status FillOpTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { DCHECK(IsFill(*node->node())); if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) || !IsFaninPortDimsNIfConst(*node, 0, {4}) || @@ -1418,8 +1412,8 @@ Status FillOpTransposer::TransposeNode(TransposeContext* context, return context->graph_view->GetMutationBuilder()->Apply(); } -Status IdentityNTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status IdentityNTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { DCHECK(IsIdentityN(*node->node())); const auto ports_4d = GetVariadicNDFaninPorts(*context, *node, 4); @@ -1468,8 +1462,8 @@ bool MergeTransposer::IsEveryFaninAfterDstToSrcTransform( return true; } -Status MergeTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status MergeTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { DCHECK(IsMerge(*node->node())); const int rank = GetFaninPortRank(*node, 0); if (rank != 4 && rank != 5) { @@ -1486,8 +1480,8 @@ Status MergeTransposer::TransposeNode(TransposeContext* context, return context->graph_view->GetMutationBuilder()->Apply(); } -Status PadTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status PadTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { DCHECK(IsMirrorPad(*node->node()) || IsMirrorPadGrad(*node->node()) || IsPad(*node->node())); if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) || @@ -1578,8 +1572,8 @@ bool ReduceTransposer::IsReduceAxisSupported(const TransposeContext& context, IsAlongAxis(tensor, indices({'C'}), 4); } -Status ReduceTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status ReduceTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { DCHECK(IsReduceOp(*node->node())); const int rank = GetFaninPortRank(*node, 0); if (rank != 4 && rank != 5) { @@ -1604,8 +1598,8 @@ Status ReduceTransposer::TransposeNode(TransposeContext* context, return context->graph_view->GetMutationBuilder()->Apply(); } -Status ReverseV2Transposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status ReverseV2Transposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { DCHECK(IsReverseV2(*node->node())); if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) || !IsAfterDstToSrcTransform(*context, *node)) { @@ -1634,8 +1628,8 @@ std::vector SelectTransposer::GetFaninPorts( return {1, 2}; } -Status SelectTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status SelectTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { DCHECK(IsSelect(*node->node())); const auto& regular_fanin_0 = node->GetRegularFanin(0); auto* regular_fanin_0_node = regular_fanin_0.node_view(); @@ -1651,8 +1645,8 @@ Status SelectTransposer::TransposeNode(TransposeContext* context, return context->graph_view->GetMutationBuilder()->Apply(); } -Status ShapeTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status ShapeTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { DCHECK(IsShape(*node->node())); const int rank = GetFaninPortRank(*node, 0); if (rank != 4 && rank != 5) { @@ -1672,8 +1666,8 @@ Status ShapeTransposer::TransposeNode(TransposeContext* context, return context->graph_view->GetMutationBuilder()->Apply(); } -Status ShapeNTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status ShapeNTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { DCHECK(IsShapeN(*node->node())); // ShapeN requires all input tensors to have the same dimensions. Therefore, // we simply use the 0th fanin port. @@ -1696,8 +1690,8 @@ Status ShapeNTransposer::TransposeNode(TransposeContext* context, return context->graph_view->GetMutationBuilder()->Apply(); } -Status SliceTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status SliceTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { DCHECK(IsSlice(*node->node())); const int rank = GetFanoutPortRank(*node, 0); if (rank != 4 && rank != 5) { @@ -1719,8 +1713,8 @@ Status SliceTransposer::TransposeNode(TransposeContext* context, return context->graph_view->GetMutationBuilder()->Apply(); } -Status SplitTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status SplitTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { DCHECK(IsSplit(*node->node())); const auto ports = GetDataFanoutPorts(*node); int rank = 4; @@ -1744,8 +1738,8 @@ Status SplitTransposer::TransposeNode(TransposeContext* context, return context->graph_view->GetMutationBuilder()->Apply(); } -Status SplitVTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status SplitVTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { DCHECK(IsSplitV(*node->node())); const auto ports = GetDataFanoutPorts(*node); int rank = 4; @@ -1834,8 +1828,8 @@ bool SqueezeTransposer::IsDimsSupported( IsAlongAxis(*squeeze_dims_attr, indices({'N', 'H', 'W'}), kRank)); } -Status SqueezeTransposer::UpdateSqueezeDims(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status SqueezeTransposer::UpdateSqueezeDims( + TransposeContext* context, utils::MutableNodeView* node) { const auto* squeeze_dims_attr = node->GetAttr(kAttrSqueezeDims); if (squeeze_dims_attr == nullptr) { return errors::InvalidArgument("Missing attribute ", kAttrSqueezeDims); @@ -1869,8 +1863,8 @@ Status SqueezeTransposer::UpdateSqueezeDims(TransposeContext* context, return absl::OkStatus(); } -Status SqueezeTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status SqueezeTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { DCHECK(IsSqueeze(*node->node())); if (!ShouldProcess(*context, *node) || !IsDimsSupported(*context, *node) || !IsInputConvertible(*context, *node) || @@ -1898,9 +1892,9 @@ bool StridedSliceTransposer::HasOnlyBeginEndMask( IsMaskZero(node, "shrink_axis_mask"); } -Status StridedSliceTransposer::PermuteMask(TransposeContext* context, - utils::MutableNodeView* node, - absl::string_view mask) { +absl::Status StridedSliceTransposer::PermuteMask(TransposeContext* context, + utils::MutableNodeView* node, + absl::string_view mask) { // Computers the permutation of the masks based on the src and dst format. // For example: // src_format = NHWC @@ -1927,8 +1921,8 @@ Status StridedSliceTransposer::PermuteMask(TransposeContext* context, return absl::OkStatus(); } -Status StridedSliceTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status StridedSliceTransposer::TransposeNode( + TransposeContext* context, utils::MutableNodeView* node) { DCHECK(IsStridedSlice(*node->node())); const int rank = GetFanoutPortRank(*node, 0); if (rank != 4 && rank != 5) { @@ -1955,8 +1949,8 @@ Status StridedSliceTransposer::TransposeNode(TransposeContext* context, return context->graph_view->GetMutationBuilder()->Apply(); } -Status SwitchTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status SwitchTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { DCHECK(IsSwitch(*node->node())); const int rank = GetFaninPortRank(*node, 0); if (rank != 4 && rank != 5) { @@ -1973,8 +1967,8 @@ Status SwitchTransposer::TransposeNode(TransposeContext* context, return context->graph_view->GetMutationBuilder()->Apply(); } -Status TernaryOpTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status TernaryOpTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { DCHECK(IsTernaryOp(*node->node())); const int rank = GetFanoutPortRank(*node, 0); if (rank != 4 && rank != 5) { @@ -1991,8 +1985,8 @@ Status TernaryOpTransposer::TransposeNode(TransposeContext* context, return context->graph_view->GetMutationBuilder()->Apply(); } -Status TileTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status TileTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { DCHECK(IsTile(*node->node())); if (!ShouldProcess(*context, *node) || !IsFanoutPortRankN(*node, 0, 4) || !IsFaninPortDimsNIfConst(*node, 1, {4}) || @@ -2006,8 +2000,8 @@ Status TileTransposer::TransposeNode(TransposeContext* context, return context->graph_view->GetMutationBuilder()->Apply(); } -Status UnaryGradTransposer::TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) { +absl::Status UnaryGradTransposer::TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) { DCHECK(IsUnaryGrad(*node->node())); const int rank = GetFanoutPortRank(*node, 0); if (rank != 4 && rank != 5) { diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h index e84114c59908ba..1c0c0134e51660 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer.h @@ -50,14 +50,14 @@ struct TransposeContext { // Initializes TransposeContext with given GrapplerItem. Because initializing // FrameMap and GraphProperties may return error, we initialize // TransposeContext outside constructor. - static Status InitializeTransposeContext(bool assume_valid_feeds, - const GrapplerItem& item, - const Cluster* cluster, - TransposeContext* context); - - static Status InitializeTransposeContext(const GrapplerItem& item, - const Cluster* cluster, - TransposeContext* context) { + static absl::Status InitializeTransposeContext(bool assume_valid_feeds, + const GrapplerItem& item, + const Cluster* cluster, + TransposeContext* context); + + static absl::Status InitializeTransposeContext(const GrapplerItem& item, + const Cluster* cluster, + TransposeContext* context) { return InitializeTransposeContext(false, item, cluster, context); } @@ -109,23 +109,23 @@ class Transposer { // Transposes given node from src format to dst format. Also perform other // necessary operations to guarantee the graph produce the same result. // Eg. Add Transpose node sets before fanin ports and after fanout ports. - virtual Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) = 0; + virtual absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) = 0; // Creates a Const node for permutation. If node with node_name already exits, // return and reuse it. - Status CreateConstPermNode(TransposeContext* context, - absl::string_view node_name, - absl::string_view device, - absl::Span permutation, - absl::string_view control_node_name, - utils::MutationNewNode* added_node); + absl::Status CreateConstPermNode(TransposeContext* context, + absl::string_view node_name, + absl::string_view device, + absl::Span permutation, + absl::string_view control_node_name, + utils::MutationNewNode* added_node); // Creates a TransposeNode with given properties. If node with node_name // already exits, return and reuse it. // A const perm node is also created and connected to the 2nd fanin. // control_node_name is ignored if it is empty. - Status CreateTransposeNode( + absl::Status CreateTransposeNode( TransposeContext* context, absl::string_view name_format, const DataType& data_type, absl::string_view device, TensorShapeProto fanin_shape, absl::Span permutation, @@ -134,26 +134,25 @@ class Transposer { // Update all edges between dst_node->fanin[dst_ports] and dst_node by // inserting an op node. - Status UpdateFaninEdgesWithOp(TransposeContext* context, - absl::Span dst_ports, - utils::MutableNodeView* dst_node, - absl::string_view op); + absl::Status UpdateFaninEdgesWithOp(TransposeContext* context, + absl::Span dst_ports, + utils::MutableNodeView* dst_node, + absl::string_view op); // Update all edges between src_node:src_ports and nodes take // src_node:src_ports as fanin. Also update attr _output_shape of src_node. - Status UpdateFanoutEdgesWithOp(TransposeContext* context, - absl::Span src_ports, - utils::MutableNodeView* src_node, - absl::string_view op); + absl::Status UpdateFanoutEdgesWithOp(TransposeContext* context, + absl::Span src_ports, + utils::MutableNodeView* src_node, + absl::string_view op); // Creates a DataFromat node with given properties. // DataFromat op is either DataFormatVecPermute or DataFormatDimMap. - Status CreateDataFormatNode(TransposeContext* context, - absl::string_view node_name, absl::string_view op, - absl::string_view device, - const DataType& data_type, bool is_fanin_on_host, - bool is_src_format_to_dst_format, - utils::MutationNewNode* added_node); + absl::Status CreateDataFormatNode( + TransposeContext* context, absl::string_view node_name, + absl::string_view op, absl::string_view device, const DataType& data_type, + bool is_fanin_on_host, bool is_src_format_to_dst_format, + utils::MutationNewNode* added_node); protected: int GetFanoutPortRank(const utils::MutableNodeView& node, int port) const; @@ -178,12 +177,12 @@ class Transposer { // Update all edges between dst_node->fanin[dst_ports] and dst_node. // A node with op is created and inserted between all edges. // op is one of Transpose, DataFormatVecPermute or DataFormatDimMap. - Status UpdateEdge(TransposeContext* context, absl::string_view name_format, - absl::string_view op, const AttrValue* input_shape, - bool is_in_frame, bool is_src_format_to_dst_format, - const int src_port, const int dst_port, - utils::MutableNodeView* src_node, - utils::MutableNodeView* dst_node); + absl::Status UpdateEdge(TransposeContext* context, + absl::string_view name_format, absl::string_view op, + const AttrValue* input_shape, bool is_in_frame, + bool is_src_format_to_dst_format, const int src_port, + const int dst_port, utils::MutableNodeView* src_node, + utils::MutableNodeView* dst_node); string GetFaninNameFormat(absl::string_view node_name, int port, absl::string_view src_format, absl::string_view dst_format); @@ -203,7 +202,8 @@ class LayoutSensitiveOpTransposer : public Transposer { // Updates attrs data_format, ksize, strides of the given node to dst_format. // _output_shape is updated during UpdateOutputEdges. - Status UpdateNode(TransposeContext* context, utils::MutableNodeView* node); + absl::Status UpdateNode(TransposeContext* context, + utils::MutableNodeView* node); }; // Layout sensitive op transposers. @@ -213,88 +213,88 @@ class DefaultLayoutSensitiveOpTransposer : public LayoutSensitiveOpTransposer { explicit DefaultLayoutSensitiveOpTransposer() : LayoutSensitiveOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class BiasAddTransposer : public LayoutSensitiveOpTransposer { public: explicit BiasAddTransposer() : LayoutSensitiveOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class AvgPoolGradTransposer : public LayoutSensitiveOpTransposer { public: explicit AvgPoolGradTransposer() : LayoutSensitiveOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class BiasAddGradTransposer : public LayoutSensitiveOpTransposer { public: explicit BiasAddGradTransposer() : LayoutSensitiveOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class Conv2DBackpropFilterTransposer : public LayoutSensitiveOpTransposer { public: explicit Conv2DBackpropFilterTransposer() : LayoutSensitiveOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class Conv2DBackpropInputTransposer : public LayoutSensitiveOpTransposer { public: explicit Conv2DBackpropInputTransposer() : LayoutSensitiveOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class Conv3DTransposer : public LayoutSensitiveOpTransposer { public: explicit Conv3DTransposer() : LayoutSensitiveOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class Conv3DBackpropFilterTransposer : public LayoutSensitiveOpTransposer { public: explicit Conv3DBackpropFilterTransposer() : LayoutSensitiveOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class Conv3DBackpropInputTransposer : public LayoutSensitiveOpTransposer { public: explicit Conv3DBackpropInputTransposer() : LayoutSensitiveOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class FusedBatchNormExTransposer : public LayoutSensitiveOpTransposer { public: explicit FusedBatchNormExTransposer() : LayoutSensitiveOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class FusedBatchNormGradTransposer : public LayoutSensitiveOpTransposer { public: explicit FusedBatchNormGradTransposer() : LayoutSensitiveOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; private: bool IsTraining(const utils::MutableNodeView& node) const; @@ -304,32 +304,32 @@ class MaxPoolV2Transposer : public LayoutSensitiveOpTransposer { public: explicit MaxPoolV2Transposer() : LayoutSensitiveOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class MaxPool3DTransposer : public LayoutSensitiveOpTransposer { public: explicit MaxPool3DTransposer() : LayoutSensitiveOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class MaxPoolGradTransposer : public LayoutSensitiveOpTransposer { public: explicit MaxPoolGradTransposer() : LayoutSensitiveOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class MaxPoolGradV2Transposer : public LayoutSensitiveOpTransposer { public: explicit MaxPoolGradV2Transposer() : LayoutSensitiveOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; // Layout agnostic op transposers. @@ -351,74 +351,75 @@ class DefaultLayoutAgnosticOpTransposer : public LayoutAgnosticOpTransposer { public: explicit DefaultLayoutAgnosticOpTransposer() : LayoutAgnosticOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class AddNTransposer : public LayoutAgnosticOpTransposer { public: explicit AddNTransposer() : LayoutAgnosticOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class BinaryOpTransposer : public LayoutAgnosticOpTransposer { public: explicit BinaryOpTransposer() : LayoutAgnosticOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; private: bool IsNDOperateWithMD(const utils::MutableNodeView& node, int n, int m); bool IsFaninShapeSupported(const utils::MutableNodeView& node, int rank); std::vector GetNDDataFaninPorts(const utils::MutableNodeView& node, int rank); - Status AddNodeShapeConst(utils::Mutation* mutation, - absl::string_view node_name, - absl::string_view node_device, bool node_in_frame, - int num_channels, absl::string_view depended_node, - int rank); - Status AddNodeReshape(utils::Mutation* mutation, absl::string_view node_name, - absl::string_view node_device, - absl::string_view input_name, - absl::string_view shape_const_node_name, - const DataType& data_type); - Status MaybeReshapeVectorFanin(TransposeContext* context, - utils::MutableNodeView* node, int rank); + absl::Status AddNodeShapeConst(utils::Mutation* mutation, + absl::string_view node_name, + absl::string_view node_device, + bool node_in_frame, int num_channels, + absl::string_view depended_node, int rank); + absl::Status AddNodeReshape(utils::Mutation* mutation, + absl::string_view node_name, + absl::string_view node_device, + absl::string_view input_name, + absl::string_view shape_const_node_name, + const DataType& data_type); + absl::Status MaybeReshapeVectorFanin(TransposeContext* context, + utils::MutableNodeView* node, int rank); }; class ConcatOpTransposer : public LayoutAgnosticOpTransposer { public: explicit ConcatOpTransposer() : LayoutAgnosticOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class FillOpTransposer : public LayoutAgnosticOpTransposer { public: explicit FillOpTransposer() : LayoutAgnosticOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class IdentityNTransposer : public LayoutAgnosticOpTransposer { public: explicit IdentityNTransposer() : LayoutAgnosticOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class MergeTransposer : public LayoutAgnosticOpTransposer { public: explicit MergeTransposer() : LayoutAgnosticOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; private: bool IsEveryFaninAfterDstToSrcTransform( @@ -430,16 +431,16 @@ class PadTransposer : public LayoutAgnosticOpTransposer { public: explicit PadTransposer() : LayoutAgnosticOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class ReduceTransposer : public LayoutAgnosticOpTransposer { public: explicit ReduceTransposer() : LayoutAgnosticOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; private: bool KeepDims(const utils::MutableNodeView& node); @@ -452,16 +453,16 @@ class ReverseV2Transposer : public LayoutAgnosticOpTransposer { public: explicit ReverseV2Transposer() : LayoutAgnosticOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class SelectTransposer : public LayoutAgnosticOpTransposer { public: explicit SelectTransposer() : LayoutAgnosticOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; protected: bool IsFaninScalarVector4D(const utils::MutableNodeView& fanin, int port); @@ -472,48 +473,48 @@ class ShapeTransposer : public LayoutAgnosticOpTransposer { public: explicit ShapeTransposer() : LayoutAgnosticOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class ShapeNTransposer : public LayoutAgnosticOpTransposer { public: explicit ShapeNTransposer() : LayoutAgnosticOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class SliceTransposer : public LayoutAgnosticOpTransposer { public: explicit SliceTransposer() : LayoutAgnosticOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class SplitTransposer : public LayoutAgnosticOpTransposer { public: explicit SplitTransposer() : LayoutAgnosticOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class SplitVTransposer : public LayoutAgnosticOpTransposer { public: explicit SplitVTransposer() : LayoutAgnosticOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class SqueezeTransposer : public LayoutAgnosticOpTransposer { public: explicit SqueezeTransposer() : LayoutAgnosticOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; private: bool IsInputConvertible(const TransposeContext& context, @@ -522,54 +523,55 @@ class SqueezeTransposer : public LayoutAgnosticOpTransposer { int rank) const; bool IsDimsSupported(const TransposeContext& context, const utils::MutableNodeView& node) const; - Status UpdateSqueezeDims(TransposeContext* context, - utils::MutableNodeView* node); + absl::Status UpdateSqueezeDims(TransposeContext* context, + utils::MutableNodeView* node); }; class StridedSliceTransposer : public LayoutAgnosticOpTransposer { public: explicit StridedSliceTransposer() : LayoutAgnosticOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; private: bool IsMaskZero(const utils::MutableNodeView& node, absl::string_view mask); bool HasOnlyBeginEndMask(const utils::MutableNodeView& node); - Status PermuteMask(TransposeContext* context, utils::MutableNodeView* node, - absl::string_view mask); + absl::Status PermuteMask(TransposeContext* context, + utils::MutableNodeView* node, + absl::string_view mask); }; class SwitchTransposer : public LayoutAgnosticOpTransposer { public: explicit SwitchTransposer() : LayoutAgnosticOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class TernaryOpTransposer : public LayoutAgnosticOpTransposer { public: explicit TernaryOpTransposer() : LayoutAgnosticOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class TileTransposer : public LayoutAgnosticOpTransposer { public: explicit TileTransposer() : LayoutAgnosticOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; class UnaryGradTransposer : public LayoutAgnosticOpTransposer { public: explicit UnaryGradTransposer() : LayoutAgnosticOpTransposer() {} - Status TransposeNode(TransposeContext* context, - utils::MutableNodeView* node) override; + absl::Status TransposeNode(TransposeContext* context, + utils::MutableNodeView* node) override; }; // Utils. @@ -577,15 +579,15 @@ class UnaryGradTransposer : public LayoutAgnosticOpTransposer { // Permutes elements according to permutation and replaces the original values. // Permutation and values must have same size. template -Status PermuteSingle(absl::string_view location, - absl::Span permutation, T* values) { +absl::Status PermuteSingle(absl::string_view location, + absl::Span permutation, T* values) { DCHECK(values != nullptr); int permutation_size = permutation.size(); if (values->size() != permutation_size) { - return Status(absl::StatusCode::kInvalidArgument, - absl::StrCat("Size of values ", values->size(), - " does not match size of permutation ", - permutation_size, " @ ", location)); + return absl::Status(absl::StatusCode::kInvalidArgument, + absl::StrCat("Size of values ", values->size(), + " does not match size of permutation ", + permutation_size, " @ ", location)); } typedef typename T::value_type V; std::vector elements(values->begin(), values->end()); @@ -599,15 +601,16 @@ Status PermuteSingle(absl::string_view location, // Permutes two elements at a time according to permutation and replaces the // original values. Values must be twice the size of permutation. template -Status PermuteDouble(absl::string_view location, - absl::Span permutation, T* values) { +absl::Status PermuteDouble(absl::string_view location, + absl::Span permutation, T* values) { DCHECK(values != nullptr); int permutation_size = permutation.size(); if (values->size() != permutation_size * 2) { - return Status(absl::StatusCode::kInvalidArgument, - absl::StrCat("Size of values ", values->size(), - " does not match twice the size of permutation ", - permutation_size, " @ ", location)); + return absl::Status( + absl::StatusCode::kInvalidArgument, + absl::StrCat("Size of values ", values->size(), + " does not match twice the size of permutation ", + permutation_size, " @ ", location)); } typedef typename T::value_type V; std::vector elements(values->begin(), values->end()); diff --git a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_test.cc b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_test.cc index c6c4a6127cbfd8..3403a83e5eaf59 100644 --- a/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_test.cc +++ b/tensorflow/core/grappler/optimizers/generic_layout_optimizer_transposer_test.cc @@ -66,7 +66,8 @@ constexpr char kOpTranspose[] = "Transpose"; class TransposerImpl : public Transposer { public: explicit TransposerImpl() : Transposer() {} - Status TransposeNode(TransposeContext*, utils::MutableNodeView*) override { + absl::Status TransposeNode(TransposeContext*, + utils::MutableNodeView*) override { return absl::OkStatus(); } }; @@ -115,8 +116,8 @@ Output SimpleConv2D(const Scope* scope, const DataType& data_type = DT_FLOAT) { return conv2d; } -Status CreateSimpleConv2DGraph(GraphDef* graph, - const DataType& data_type = DT_FLOAT) { +absl::Status CreateSimpleConv2DGraph(GraphDef* graph, + const DataType& data_type = DT_FLOAT) { Scope scope = Scope::NewRootScope(); auto conv2d = SimpleConv2D(&scope, data_type); auto output = ops::Identity(scope.WithOpName("output"), conv2d); @@ -124,8 +125,8 @@ Status CreateSimpleConv2DGraph(GraphDef* graph, return scope.ToGraphDef(graph); } -Status CreateSimpleFusedBatchNorm(GraphDef* graph, - const DataType& data_type = DT_FLOAT) { +absl::Status CreateSimpleFusedBatchNorm(GraphDef* graph, + const DataType& data_type = DT_FLOAT) { Scope scope = Scope::NewRootScope(); auto x = ops::RandomUniform(scope.WithOpName("x"), @@ -150,7 +151,7 @@ Status CreateSimpleFusedBatchNorm(GraphDef* graph, return scope.ToGraphDef(graph); } -Status CreateSimpleMaxPoolGrad(GraphDef* graph, bool use_grad_grad) { +absl::Status CreateSimpleMaxPoolGrad(GraphDef* graph, bool use_grad_grad) { Scope scope = Scope::NewRootScope(); auto input = ops::RandomUniform(scope.WithOpName("orig_input"), @@ -181,7 +182,7 @@ Status CreateSimpleMaxPoolGrad(GraphDef* graph, bool use_grad_grad) { return scope.ToGraphDef(graph); } -Status CreateSimpleBiasAddGrad(GraphDef* graph, const Input& shape) { +absl::Status CreateSimpleBiasAddGrad(GraphDef* graph, const Input& shape) { Scope scope = Scope::NewRootScope(); auto input = ops::RandomUniform(scope.WithOpName("input"), shape, DT_FLOAT); auto bag = @@ -192,9 +193,9 @@ Status CreateSimpleBiasAddGrad(GraphDef* graph, const Input& shape) { return scope.ToGraphDef(graph); } -Status CreateSimpleConv2DBackpropFilter(GraphDef* graph, - const DataType& data_type = DT_FLOAT, - absl::string_view padding = "SAME") { +absl::Status CreateSimpleConv2DBackpropFilter( + GraphDef* graph, const DataType& data_type = DT_FLOAT, + absl::string_view padding = "SAME") { Scope scope = Scope::NewRootScope(); auto input = ops::RandomUniform(scope.WithOpName("input"), @@ -227,8 +228,8 @@ Status CreateSimpleConv2DBackpropFilter(GraphDef* graph, return scope.ToGraphDef(graph); } -Status CreateSimpleConv2DBackpropInput(GraphDef* graph, - const DataType& data_type = DT_FLOAT) { +absl::Status CreateSimpleConv2DBackpropInput( + GraphDef* graph, const DataType& data_type = DT_FLOAT) { Scope scope = Scope::NewRootScope(); auto input_sizes = ops::Const(scope.WithOpName("input_sizes"), {kBatchSize, kHeight, kWidth, kDepthIn}); @@ -250,8 +251,8 @@ Status CreateSimpleConv2DBackpropInput(GraphDef* graph, return scope.ToGraphDef(graph); } -Status CreateSimpleFusedBatchNormGrad(GraphDef* graph, bool is_training, - const DataType& data_type = DT_FLOAT) { +absl::Status CreateSimpleFusedBatchNormGrad( + GraphDef* graph, bool is_training, const DataType& data_type = DT_FLOAT) { Scope scope = Scope::NewRootScope(); auto y_backprop = ops::RandomUniform(scope.WithOpName("y_backprop"), @@ -285,7 +286,7 @@ Status CreateSimpleFusedBatchNormGrad(GraphDef* graph, bool is_training, return scope.ToGraphDef(graph); } -Status CreateSimpleAddN(GraphDef* graph) { +absl::Status CreateSimpleAddN(GraphDef* graph) { Scope scope = Scope::NewRootScope(); auto input = ops::RandomUniform(scope.WithOpName("input"), @@ -309,7 +310,7 @@ Status CreateSimpleAddN(GraphDef* graph) { return scope.ToGraphDef(graph); } -Status CreateSimpleIdentityN(GraphDef* graph) { +absl::Status CreateSimpleIdentityN(GraphDef* graph) { Scope scope = Scope::NewRootScope(); auto conv2d_1_input = ops::RandomUniform(scope.WithOpName("conv2d_1_input"), @@ -573,7 +574,7 @@ TEST_F(TransposerTest, CreateTransposeNode) { EXPECT_EQ(transpose_node_name, "transpose_node-0-Transpose-NWCHToNCWH-LayoutOptimizer"); utils::Mutation* mutation = context.graph_view->GetMutationBuilder(); - Status status; + absl::Status status; // Placeholder node with empty name as transpose node is created with it's // first input not set. mutation->AddNode({}, &status); diff --git a/tensorflow/core/grappler/optimizers/gpu_swapping_kernels.cc b/tensorflow/core/grappler/optimizers/gpu_swapping_kernels.cc index 78df573e2ce902..a93468e482746a 100644 --- a/tensorflow/core/grappler/optimizers/gpu_swapping_kernels.cc +++ b/tensorflow/core/grappler/optimizers/gpu_swapping_kernels.cc @@ -46,7 +46,7 @@ class CopyFromGpuToHostKernel : public AsyncOpKernel { ctx->op_device_context()->CopyDeviceTensorToCPU( &input, "CopyFromGpuToHost", static_cast(ctx->device()), - output, [ctx, done](const Status& s) { + output, [ctx, done](const absl::Status& s) { ctx->SetStatus(s); done(); }); @@ -75,7 +75,7 @@ class CopyFromHostToGpuKernel : public AsyncOpKernel { ctx->op_device_context()->CopyCPUTensorToDevice( &input, static_cast(ctx->device()), output, - [ctx, done](const Status& s) { + [ctx, done](const absl::Status& s) { ctx->SetStatus(s); done(); }); diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer.h b/tensorflow/core/grappler/optimizers/graph_optimizer.h index cf38fd2c4475f4..6b7ba893035ebe 100644 --- a/tensorflow/core/grappler/optimizers/graph_optimizer.h +++ b/tensorflow/core/grappler/optimizers/graph_optimizer.h @@ -56,12 +56,12 @@ class GraphOptimizer { // A return value of error::Aborted() can be used signal early termination of // the optimizer, e.g. if the optimization turned out to be a no-op. In this // case the content of *optimized_graph is undefined. - virtual Status Optimize(Cluster* cluster, const GrapplerItem& item, - GraphDef* optimized_graph) = 0; + virtual absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) = 0; // Subclasses may define a version of Optimize that consumes item. - virtual Status Optimize(Cluster* cluster, GrapplerItem&& item, - GraphDef* optimized_graph) { + virtual absl::Status Optimize(Cluster* cluster, GrapplerItem&& item, + GraphDef* optimized_graph) { return Optimize(cluster, item, optimized_graph); } diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc index a1035c38bfc4ec..8442e3ee945780 100644 --- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc +++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.cc @@ -28,8 +28,8 @@ const NodeScopeAndName ParseNodeScopeAndName(const string& node_name) { } }; -Status GetInputNode(const GraphOptimizerContext& ctx, const string& input, - NodeDef** node) { +absl::Status GetInputNode(const GraphOptimizerContext& ctx, const string& input, + NodeDef** node) { string node_name = NodeName(input); NodeDef* node_by_name = ctx.node_map->GetNode(node_name); if (node_by_name == nullptr) { @@ -40,9 +40,9 @@ Status GetInputNode(const GraphOptimizerContext& ctx, const string& input, return absl::OkStatus(); } -Status GetTensorProperties(const GraphOptimizerContext& ctx, - const string& tensor, - const OpInfo::TensorProperties** properties) { +absl::Status GetTensorProperties(const GraphOptimizerContext& ctx, + const string& tensor, + const OpInfo::TensorProperties** properties) { if (ctx.graph_properties == nullptr) { return errors::InvalidArgument("Graph properties are unknown."); } diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h index d7a8672064bae9..ed5549abcf9abd 100644 --- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h +++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage.h @@ -66,11 +66,11 @@ struct GraphOptimizerContext { RewriterConfig::Toggle opt_level; }; -Status GetInputNode(const GraphOptimizerContext& ctx, const string& input, - NodeDef** node); -Status GetTensorProperties(const GraphOptimizerContext& ctx, - const string& tensor, - const OpInfo::TensorProperties** properties); +absl::Status GetInputNode(const GraphOptimizerContext& ctx, const string& input, + NodeDef** node); +absl::Status GetTensorProperties(const GraphOptimizerContext& ctx, + const string& tensor, + const OpInfo::TensorProperties** properties); NodeDef* AddCopyNode(const GraphOptimizerContext& ctx, const string& name, const NodeDef* node_to_copy); @@ -142,13 +142,13 @@ class GraphOptimizerStage { // TODO(ezhulenev): if it will appear that Result output parameter is not // sufficiently useful (used with a reason by most optimizers), get rid of it, // and remove template parameter. - virtual Status TrySimplify(NodeDef* node, Result* result) = 0; + virtual absl::Status TrySimplify(NodeDef* node, Result* result) = 0; // Return InvalidArgumentError if node is not supported by the optimizer // stage. // TODO(ezhulenev): make this check part of non-virtual public API // (TrySimplify), and make virtual implementation protected. - Status EnsureNodeIsSupported(const NodeDef* node) const { + absl::Status EnsureNodeIsSupported(const NodeDef* node) const { return IsSupported(node) ? absl::OkStatus() : errors::InvalidArgument( @@ -183,13 +183,13 @@ class GraphOptimizerStage { // Get a node by input name from a node map. Return an error if node was not // found. - Status GetInputNode(const string& input, NodeDef** node) const { + absl::Status GetInputNode(const string& input, NodeDef** node) const { return ::tensorflow::grappler::GetInputNode(ctx_, input, node); } // Lookup tensor properties by name. Tensor name might have non-zero port // number. Return an error if tensor node doesn't exists in a graph, or it // doesn't have properties defined for requested port. - Status GetTensorProperties( + absl::Status GetTensorProperties( const string& tensor, const OpInfo::TensorProperties** properties) const { return ::tensorflow::grappler::GetTensorProperties(ctx_, tensor, properties); @@ -257,7 +257,7 @@ class GraphOptimizerStagePipeline { bool PassThroughAllStages(NodeDef* node, Result* result) { for (auto& stage : stages_) { if (stage->IsSupported(node)) { - const Status stage_status = stage->TrySimplify(node, result); + const absl::Status stage_status = stage->TrySimplify(node, result); // Each stage must be "error safe" (just like exception safe). In // case of any error it must leave optimized graph unmodified. if (!stage_status.ok()) { @@ -275,12 +275,12 @@ class GraphOptimizerStagePipeline { // is true or a stage fails. // // Returns any stage failure status, or else OkStatus(). - Status PassThroughAllStagesWithStatus(NodeDef* node, Result* result) { + absl::Status PassThroughAllStagesWithStatus(NodeDef* node, Result* result) { for (auto& stage : stages_) { if (!stage->IsSupported(node)) { continue; } - const Status stage_status = stage->TrySimplify(node, result); + const absl::Status stage_status = stage->TrySimplify(node, result); if (!stage_status.ok()) { return stage_status; } else if (break_predicate_(*result)) { diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc b/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc index f36406225651a3..7e78f8e743dc81 100644 --- a/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc +++ b/tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc @@ -43,7 +43,7 @@ class FakeOptimizerStage : public GraphOptimizerStage { ~FakeOptimizerStage() override = default; bool IsSupported(const NodeDef* node) const override { return true; } - Status TrySimplify(NodeDef* node, FakeResult* result) override { + absl::Status TrySimplify(NodeDef* node, FakeResult* result) override { return absl::OkStatus(); } }; diff --git a/tensorflow/core/grappler/optimizers/implementation_selector.cc b/tensorflow/core/grappler/optimizers/implementation_selector.cc index d9fe63f873b514..3b6b3f2f3be12b 100644 --- a/tensorflow/core/grappler/optimizers/implementation_selector.cc +++ b/tensorflow/core/grappler/optimizers/implementation_selector.cc @@ -144,8 +144,9 @@ void UpdateForwardIdentityNodeDtype(utils::MutableNodeView* forward_node, } } -Status UpdateNodeDef(utils::MutableNodeView* node_view, const string& funcName, - const FunctionApiInfo& apiInfo) { +absl::Status UpdateNodeDef(utils::MutableNodeView* node_view, + const string& funcName, + const FunctionApiInfo& apiInfo) { NodeDef* node_def = node_view->node(); VLOG(3) << "Node def before swap is: " << node_def->DebugString(); @@ -229,13 +230,13 @@ Status UpdateNodeDef(utils::MutableNodeView* node_view, const string& funcName, return absl::OkStatus(); } -Status ImplementationSelector::LoadFunctions(const GraphDef& graph) { +absl::Status ImplementationSelector::LoadFunctions(const GraphDef& graph) { lib_info_ = std::make_unique(); TF_RETURN_IF_ERROR(lib_info_->Init(graph.library())); return absl::OkStatus(); } -Status ImplementationSelector::MaybeOptimizeFunctionCall( +absl::Status ImplementationSelector::MaybeOptimizeFunctionCall( utils::MutableNodeView* node_view) const { // There are two ways of calling functions: // 1. By specifying an op name as a function name, or @@ -302,8 +303,8 @@ Status ImplementationSelector::MaybeOptimizeFunctionCall( } // Finds the index of the device from the device name list. -Status FindDeviceIndex(const utils::MutableNodeView* device_index_node, - const string& device, int* index) { +absl::Status FindDeviceIndex(const utils::MutableNodeView* device_index_node, + const string& device, int* index) { DeviceNameUtils::ParsedName parsed_name; if (!DeviceNameUtils::ParseFullName(device, &parsed_name) || !parsed_name.has_type) { @@ -336,8 +337,8 @@ void RewriteDeviceIndexOp(utils::MutableNodeView* device_index_node, VLOG(2) << "Node after rewriting:" << node->DebugString(); } -Status ImplementationSelector::SelectDeviceIndex(GraphDef* graph) const { - Status status; +absl::Status ImplementationSelector::SelectDeviceIndex(GraphDef* graph) const { + absl::Status status; VLOG(2) << "graph before rewriting device index:" << graph->DebugString(); utils::MutableGraphView graph_view(graph, &status); TF_RETURN_IF_ERROR(status); @@ -360,7 +361,7 @@ Status ImplementationSelector::SelectDeviceIndex(GraphDef* graph) const { int index; // If any error is thrown out during device parsing, we simply skip // and do not modify the DeviceIndexNode. - Status status = + absl::Status status = FindDeviceIndex(node_view, fanout.node_view()->GetDevice(), &index); if (status.ok()) { RewriteDeviceIndexOp(node_view, index); @@ -371,7 +372,8 @@ Status ImplementationSelector::SelectDeviceIndex(GraphDef* graph) const { return absl::OkStatus(); } -Status ImplementationSelector::SelectImplementation(GraphDef* graph) const { +absl::Status ImplementationSelector::SelectImplementation( + GraphDef* graph) const { if (!graph->has_library()) { VLOG(2) << "Skipping graph since it does not have function def"; return absl::OkStatus(); @@ -381,7 +383,7 @@ Status ImplementationSelector::SelectImplementation(GraphDef* graph) const { return absl::OkStatus(); } - Status status; + absl::Status status; utils::MutableGraphView graph_view(graph, &status); TF_RETURN_IF_ERROR(status); @@ -393,9 +395,9 @@ Status ImplementationSelector::SelectImplementation(GraphDef* graph) const { return absl::OkStatus(); } -Status ImplementationSelector::Optimize(Cluster* cluster, - const GrapplerItem& item, - GraphDef* optimized_graph) { +absl::Status ImplementationSelector::Optimize(Cluster* cluster, + const GrapplerItem& item, + GraphDef* optimized_graph) { auto status = LoadFunctions(item.graph); // Eat up the error from function loading, since this optimizer might run // several times, and might try to run against functions generated by diff --git a/tensorflow/core/grappler/optimizers/implementation_selector.h b/tensorflow/core/grappler/optimizers/implementation_selector.h index 289063d701040c..8219e9b4a0f6ce 100644 --- a/tensorflow/core/grappler/optimizers/implementation_selector.h +++ b/tensorflow/core/grappler/optimizers/implementation_selector.h @@ -94,7 +94,7 @@ class ImplementationSelector : public CustomGraphOptimizer { public: ImplementationSelector() = default; ~ImplementationSelector() override = default; - Status Init( + absl::Status Init( const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { return absl::OkStatus(); } @@ -105,12 +105,13 @@ class ImplementationSelector : public CustomGraphOptimizer { bool UsesFunctionLibrary() const override { return false; } // This call is not thread-safe. - Status Optimize(Cluster* cluster, const GrapplerItem& item, - GraphDef* optimized_graph) override; + absl::Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; private: - Status LoadFunctions(const GraphDef& graph); - Status MaybeOptimizeFunctionCall(utils::MutableNodeView* node_view) const; + absl::Status LoadFunctions(const GraphDef& graph); + absl::Status MaybeOptimizeFunctionCall( + utils::MutableNodeView* node_view) const; // Finds all call sites for functions, then replace with the appropriate // implementation. @@ -123,7 +124,7 @@ class ImplementationSelector : public CustomGraphOptimizer { // may call into another function, so a function might have to be duplicated. // For simplicity, we do not change function bodies. Also, we do not change // gradients. - Status SelectImplementation(GraphDef* graph) const; + absl::Status SelectImplementation(GraphDef* graph) const; // Rewrites the DeviceIndex op with a Const op with value of the index of the // device the associcated Case op runs. @@ -185,7 +186,7 @@ class ImplementationSelector : public CustomGraphOptimizer { // device: "/device:GPU:0" // ... // } - Status SelectDeviceIndex(GraphDef* graph) const; + absl::Status SelectDeviceIndex(GraphDef* graph) const; std::unique_ptr lib_info_; diff --git a/tensorflow/core/grappler/optimizers/implementation_selector_test.cc b/tensorflow/core/grappler/optimizers/implementation_selector_test.cc index 36b721cdce4b83..959359a27329c7 100644 --- a/tensorflow/core/grappler/optimizers/implementation_selector_test.cc +++ b/tensorflow/core/grappler/optimizers/implementation_selector_test.cc @@ -52,7 +52,7 @@ TEST_F(ImplementationSelectorTest, NoUpdate) { TF_ASSERT_OK(optimizer->Init()); GraphDef output; - const Status status = optimizer->Optimize(nullptr, item, &output); + const absl::Status status = optimizer->Optimize(nullptr, item, &output); TF_EXPECT_OK(status); // This is a trivial graph so there is nothing to update. diff --git a/tensorflow/core/grappler/optimizers/inference/BUILD b/tensorflow/core/grappler/optimizers/inference/BUILD index 3b6e92e6a0434c..41f6c0728cc7de 100644 --- a/tensorflow/core/grappler/optimizers/inference/BUILD +++ b/tensorflow/core/grappler/optimizers/inference/BUILD @@ -25,7 +25,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "batch_op_rewriter_proto_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":batch_op_rewriter_proto"], # ) diff --git a/tensorflow/core/grappler/optimizers/inference/batch_op_rewriter.h b/tensorflow/core/grappler/optimizers/inference/batch_op_rewriter.h index 2f7a490f8e4e83..d15ff68bb595b0 100644 --- a/tensorflow/core/grappler/optimizers/inference/batch_op_rewriter.h +++ b/tensorflow/core/grappler/optimizers/inference/batch_op_rewriter.h @@ -45,17 +45,16 @@ using ::tensorflow::serving::BatchOpRewriteConfig; // allocating batch threads per batch-op. class BatchOpRewriter : public ::tensorflow::grappler::CustomGraphOptimizer { public: - ::tensorflow::Status Init( + absl::Status Init( const ::tensorflow::RewriterConfig_CustomGraphOptimizer* config) override; std::string name() const override { return "batch_op_rewriter"; } bool UsesFunctionLibrary() const override { return false; } - ::tensorflow::Status Optimize( - ::tensorflow::grappler::Cluster* cluster, - const ::tensorflow::grappler::GrapplerItem& item, - ::tensorflow::GraphDef* optimized_graph) override; + absl::Status Optimize(::tensorflow::grappler::Cluster* cluster, + const ::tensorflow::grappler::GrapplerItem& item, + ::tensorflow::GraphDef* optimized_graph) override; private: BatchOpRewriteConfig config_; diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc index 94c2c22f472f19..b32b6ab850467e 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc @@ -55,7 +55,7 @@ namespace tensorflow { namespace grappler { namespace { -using TensorVector = gtl::InlinedVector; +using TensorVector = absl::InlinedVector; class LoopInvariantNodeMotionOptimizer { public: diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index 1929925a0b6566..dbbab4e2c9c492 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -107,7 +107,7 @@ constexpr char kFill[] = "fill"; constexpr int kMissingIndex = -1; struct RemapperContext { - explicit RemapperContext(GrapplerItem* item, Status* status, + explicit RemapperContext(GrapplerItem* item, absl::Status* status, RewriterConfig::CpuLayout cpu_layout_conversion, bool xla_auto_clustering_on, bool xla_cpu_jit_disable_fusion) @@ -1519,7 +1519,43 @@ bool IsMatchedMatMulBiasAddAndGeluExact( } // Mul: "output" }; - // Pattern 2: + // Pattern 2: Erfc + // Const: 1/sqrt(2) Const: 1/2 + // \ \ + // * --> BiasAdd --> Neg --> Mul --> Erfc --> Mul --> Mul + // / \____________________________________/ + // MatMul + static utils::OpTypePattern* gelu_exact_pattern2 = new utils::OpTypePattern + {"Mul", "output", NodeStatus::kReplace, + { + {"Mul", "one_half_x_erfc", NodeStatus::kRemove, + { + {"Const", "one_half", NodeStatus::kRemain}, + {"Erfc", "erfc", NodeStatus::kRemove, + { + {"Mul", "neg_bias_add_x_sqrt_one_half", NodeStatus::kRemove, + { + {"Const", "sqrt_one_half", NodeStatus::kRemain}, + {"Neg", "neg", NodeStatus::kRemove, + {{"BiasAdd", "bias_add", NodeStatus::kRemove}} + }, // Neg: "neg" + } + } // Mul: "neg_bias_add_x_sqrt_one_half" + } // Erfc: "erfc" + } + } // Mul: "one_half_x_erfc" + }, + {"BiasAdd", "bias_add", NodeStatus::kRemove, + { + {"MatMul", "matmul", NodeStatus::kRemove}, + {"*", "bias", NodeStatus::kRemain} + } + } // BiasAdd: "bias_add" + } + }; // Mul: "output" + + + // Pattern 3: // Cast|Const: 1/sqrt(2) Cast|Const: 1 // \ \ // * --> BiasAdd --> Mul --> Erf --> Add|AddV2 --> Mul @@ -1527,7 +1563,7 @@ bool IsMatchedMatMulBiasAddAndGeluExact( // MatMul ----------------------------> Mul // / // Cast|Const: 1/2 - static utils::OpTypePattern* gelu_exact_pattern2 = new utils::OpTypePattern + static utils::OpTypePattern* gelu_exact_pattern3 = new utils::OpTypePattern {"Mul", "output", NodeStatus::kReplace, { {"Add|AddV2", "erf_plus_one", NodeStatus::kRemove, @@ -1567,17 +1603,18 @@ bool IsMatchedMatMulBiasAddAndGeluExact( std::set dummy_remove_node_indices; if (!matched_nodes_map) matched_nodes_map = &dummy_matched_nodes_map; if (!remove_node_indices) remove_node_indices = &dummy_remove_node_indices; - if (graph_matcher.GetMatchedNodes(*gelu_exact_pattern, ctx.nodes_to_preserve, - node_view, matched_nodes_map, - remove_node_indices)) { - return true; + auto patterns = {gelu_exact_pattern, gelu_exact_pattern2, + gelu_exact_pattern3}; + for (auto& pattern : patterns) { + matched_nodes_map->clear(); + remove_node_indices->clear(); + if (graph_matcher.GetMatchedNodes(*pattern, ctx.nodes_to_preserve, + node_view, matched_nodes_map, + remove_node_indices)) { + return true; + } } - // Pattern 1 not matched, check for pattern 2 - matched_nodes_map->clear(); - remove_node_indices->clear(); - return graph_matcher.GetMatchedNodes(*gelu_exact_pattern2, - ctx.nodes_to_preserve, node_view, - matched_nodes_map, remove_node_indices); + return false; } // Gelu in python api generates a number of nodes in the graph. Depending on the @@ -1742,8 +1779,12 @@ bool FindMatMulBiasAddAndGelu(RemapperContext* ctx, int node_index, } // Check if the matched constants have desired values. - std::map values_map = { - {"sqrt_one_half", 0.707106}, {"one", 1.0}, {"one_half", 0.5}}; + std::map values_map = {{"sqrt_one_half", 0.707106}, + {"one_half", 0.5}}; + // GeluExact Pattern 2 (Erfc) does not have constant "one". + if (matched_nodes_map->find("one") != matched_nodes_map->end()) { + values_map["one"] = 1.0; + } if (!VerifyConstants(ctx, matched_nodes_map, &values_map)) return false; } else if (found_gelu_approximate) { NodeDef* matmul_node = @@ -2195,7 +2236,7 @@ bool FindMklLayerNorm(RemapperContext* ctx, int node_index, // Additional check for LayerNorm if (found_op_type_match) { if (!ctx->inferred_graph_properties) { - Status s = ctx->graph_properties.InferStatically( + absl::Status s = ctx->graph_properties.InferStatically( /*assume_valid_feeds=*/true, /*aggressive_shape_inference=*/false, /*include_input_tensor_values=*/true, @@ -2908,7 +2949,7 @@ bool FindFusedBatchMatMul(RemapperContext* ctx, int node_index, // addend is 4D tensor with second dim_size = 1. if (!found_op_type_match) return false; if (!ctx->inferred_graph_properties) { - Status s = ctx->graph_properties.InferStatically( + absl::Status s = ctx->graph_properties.InferStatically( /*assume_valid_feeds=*/true, /*aggressive_shape_inference=*/false, /*include_input_tensor_values=*/false, @@ -3000,7 +3041,7 @@ bool FindInstanceNorm(RemapperContext* ctx, int node_index, // Additional checks for InstanceNorm if (!ctx->inferred_graph_properties) { - Status s = ctx->graph_properties.InferStatically( + absl::Status s = ctx->graph_properties.InferStatically( /*assume_valid_feeds=*/true, /*aggressive_shape_inference=*/false, /*include_input_tensor_values=*/false, @@ -3261,10 +3302,10 @@ void SetFusedOpAttributes(NodeDef* fused, SetAttrValue(epsilon, &(*attr)["epsilon"]); // required only for BatchNorm } -Status AddFusedContractionNode(RemapperContext* ctx, - const ContractionWithBiasAdd& matched, - std::vector* invalidated_nodes, - std::vector* nodes_to_delete) { +absl::Status AddFusedContractionNode(RemapperContext* ctx, + const ContractionWithBiasAdd& matched, + std::vector* invalidated_nodes, + std::vector* nodes_to_delete) { DCHECK(IsDeviceCompatible(*ctx, matched)) << "Unsupported fusion pattern"; const GraphDef* graph = ctx->graph_view.graph(); @@ -3298,7 +3339,7 @@ Status AddFusedContractionNode(RemapperContext* ctx, SetFusedOpAttributes(&fused_op, {"BiasAdd"}); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); - Status status; + absl::Status status; mutation->AddNode(std::move(fused_op), &status); TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(mutation->Apply()); @@ -3309,10 +3350,10 @@ Status AddFusedContractionNode(RemapperContext* ctx, return absl::OkStatus(); } -Status AddFusedContractionNode(RemapperContext* ctx, - const ContractionWithActivation& matched, - std::vector* invalidated_nodes, - std::vector* nodes_to_delete) { +absl::Status AddFusedContractionNode(RemapperContext* ctx, + const ContractionWithActivation& matched, + std::vector* invalidated_nodes, + std::vector* nodes_to_delete) { const GraphDef* graph = ctx->graph_view.graph(); const NodeDef& contraction = graph->node(matched.contraction); const NodeDef& activation = graph->node(matched.activation); @@ -3351,7 +3392,7 @@ Status AddFusedContractionNode(RemapperContext* ctx, fused_op.set_name(activation.name()); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); - Status status; + absl::Status status; mutation->AddNode(std::move(fused_op), &status); TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(mutation->Apply()); @@ -3362,7 +3403,7 @@ Status AddFusedContractionNode(RemapperContext* ctx, return absl::OkStatus(); } -Status AddFusedContractionNode( +absl::Status AddFusedContractionNode( RemapperContext* ctx, const ContractionWithBiasAddAndActivation& matched, std::vector* invalidated_nodes, std::vector* nodes_to_delete) { DCHECK(IsDeviceCompatible(*ctx, matched)) << "Unsupported fusion pattern"; @@ -3404,7 +3445,7 @@ Status AddFusedContractionNode( SetFusedOpAttributes(&fused_op, {"BiasAdd", activation.op()}); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); - Status status; + absl::Status status; mutation->AddNode(std::move(fused_op), &status); TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(mutation->Apply()); @@ -3416,10 +3457,10 @@ Status AddFusedContractionNode( return absl::OkStatus(); } -Status AddFusedConvNode(RemapperContext* ctx, - const ContractionWithSqueezeAndBiasAdd& matched, - std::vector* invalidated_nodes, - std::vector* nodes_to_delete) { +absl::Status AddFusedConvNode(RemapperContext* ctx, + const ContractionWithSqueezeAndBiasAdd& matched, + std::vector* invalidated_nodes, + std::vector* nodes_to_delete) { DCHECK(IsDeviceCompatible(*ctx, matched)) << "Unsupported fusion pattern"; const GraphDef* graph = ctx->graph_view.graph(); @@ -3457,7 +3498,7 @@ Status AddFusedConvNode(RemapperContext* ctx, remapped_squeeze.set_input(0, contraction.name()); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); - Status status; + absl::Status status; mutation->AddNode(std::move(fused_conv), &status); TF_RETURN_IF_ERROR(status); mutation->AddNode(std::move(remapped_squeeze), &status); @@ -3471,10 +3512,10 @@ Status AddFusedConvNode(RemapperContext* ctx, return absl::OkStatus(); } -Status AddFusedConv2DNode(RemapperContext* ctx, - const ContractionWithBatchNorm& matched, - std::vector* invalidated_nodes, - std::vector* nodes_to_delete) { +absl::Status AddFusedConv2DNode(RemapperContext* ctx, + const ContractionWithBatchNorm& matched, + std::vector* invalidated_nodes, + std::vector* nodes_to_delete) { const GraphDef* graph = ctx->graph_view.graph(); const NodeDef& contraction = graph->node(matched.contraction); DCHECK(IsConv2D(contraction)) << "Only Conv2D supported for now"; @@ -3499,7 +3540,7 @@ Status AddFusedConv2DNode(RemapperContext* ctx, /*num_args=*/4, /*epsilon=*/matched.epsilon); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); - Status status; + absl::Status status; mutation->AddNode(std::move(fused_conv2d), &status); TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(mutation->Apply()); @@ -3510,10 +3551,9 @@ Status AddFusedConv2DNode(RemapperContext* ctx, return absl::OkStatus(); } -Status AddFusedConv2DNode(RemapperContext* ctx, - const ContractionWithBatchNormAndActivation& matched, - std::vector* invalidated_nodes, - std::vector* nodes_to_delete) { +absl::Status AddFusedConv2DNode( + RemapperContext* ctx, const ContractionWithBatchNormAndActivation& matched, + std::vector* invalidated_nodes, std::vector* nodes_to_delete) { const GraphDef* graph = ctx->graph_view.graph(); const NodeDef& contraction = graph->node(matched.contraction); @@ -3543,7 +3583,7 @@ Status AddFusedConv2DNode(RemapperContext* ctx, /*num_args=*/4, /*epsilon=*/matched.epsilon); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); - Status status; + absl::Status status; mutation->AddNode(std::move(fused_conv2d), &status); TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(mutation->Apply()); @@ -3555,10 +3595,9 @@ Status AddFusedConv2DNode(RemapperContext* ctx, return absl::OkStatus(); } -Status AddFusedContractionNode(RemapperContext* ctx, - const ContractionWithBiasAddAndAdd& matched, - std::vector* invalidated_nodes, - std::vector* nodes_to_delete) { +absl::Status AddFusedContractionNode( + RemapperContext* ctx, const ContractionWithBiasAddAndAdd& matched, + std::vector* invalidated_nodes, std::vector* nodes_to_delete) { const GraphDef* graph = ctx->graph_view.graph(); const NodeDef& contraction = graph->node(matched.contraction); const NodeDef& bias_add = graph->node(matched.bias_add); @@ -3597,7 +3636,7 @@ Status AddFusedContractionNode(RemapperContext* ctx, SetFusedOpAttributes(&contraction_node, {"BiasAdd", "Add"}, 2); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); - Status status; + absl::Status status; mutation->AddNode(std::move(contraction_node), &status); TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(mutation->Apply()); @@ -3609,9 +3648,10 @@ Status AddFusedContractionNode(RemapperContext* ctx, return absl::OkStatus(); } -Status AddFusedConv3DNode(RemapperContext* ctx, const PadWithConv3D& matched, - std::vector* invalidated_nodes, - std::vector* nodes_to_delete) { +absl::Status AddFusedConv3DNode(RemapperContext* ctx, + const PadWithConv3D& matched, + std::vector* invalidated_nodes, + std::vector* nodes_to_delete) { const GraphDef* graph = ctx->graph_view.graph(); const NodeDef& contraction = graph->node(matched.contraction_idx); const NodeDef& pad_node_def = graph->node(matched.pad_idx); @@ -3651,7 +3691,7 @@ Status AddFusedConv3DNode(RemapperContext* ctx, const PadWithConv3D& matched, } utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); - Status status; + absl::Status status; mutation->AddNode(std::move(fused_node), &status); TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(mutation->Apply()); @@ -3661,7 +3701,7 @@ Status AddFusedConv3DNode(RemapperContext* ctx, const PadWithConv3D& matched, return absl::OkStatus(); } -Status AddFusedContractionNode( +absl::Status AddFusedContractionNode( RemapperContext* ctx, const ContractionWithBiasAndAddActivation& matched, std::vector* invalidated_nodes, std::vector* nodes_to_delete) { const GraphDef* graph = ctx->graph_view.graph(); @@ -3695,7 +3735,7 @@ Status AddFusedContractionNode( SetFusedOpAttributes(&fused_conv, {"BiasAdd", "Add", activation.op()}, 2); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); - Status status; + absl::Status status; mutation->AddNode(std::move(fused_conv), &status); TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(mutation->Apply()); @@ -3708,7 +3748,7 @@ Status AddFusedContractionNode( return absl::OkStatus(); } -Status FuseContractionWithBiasAddAndHardSwish( +absl::Status FuseContractionWithBiasAddAndHardSwish( RemapperContext* ctx, std::map* matched_nodes_map, std::set* remove_node_indices, std::vector* invalidated_nodes, std::vector* nodes_to_delete) { @@ -3737,7 +3777,7 @@ Status FuseContractionWithBiasAddAndHardSwish( SetFusedOpAttributes(&fused_node, {"BiasAdd", "_FusedHardSwish"}); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); - Status status; + absl::Status status; mutation->AddNode(std::move(fused_node), &status); TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(mutation->Apply()); @@ -3749,11 +3789,11 @@ Status FuseContractionWithBiasAddAndHardSwish( return absl::OkStatus(); } -Status FuseConv2DSwish(RemapperContext* ctx, - const std::map& matched_nodes_map, - const std::set& remove_node_indices, - std::vector* invalidated_nodes, - std::vector* nodes_to_delete) { +absl::Status FuseConv2DSwish(RemapperContext* ctx, + const std::map& matched_nodes_map, + const std::set& remove_node_indices, + std::vector* invalidated_nodes, + std::vector* nodes_to_delete) { const NodeDef* mul = ctx->graph_view.GetNode(matched_nodes_map.at("mulToswish"))->node(); const NodeDef* conv2d = @@ -3788,7 +3828,7 @@ Status FuseConv2DSwish(RemapperContext* ctx, CopyConv2DAttributes(*conv2d, &fused_op); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); - Status status; + absl::Status status; mutation->AddNode(std::move(fused_op), &status); TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(mutation->Apply()); @@ -3802,7 +3842,7 @@ Status FuseConv2DSwish(RemapperContext* ctx, return absl::OkStatus(); } -Status AddFusedMatMulBiasAddAndGelu( +absl::Status AddFusedMatMulBiasAddAndGelu( RemapperContext* ctx, const std::map& matched_nodes_map, const std::set& remove_node_indices, std::vector* invalidated_nodes, std::vector* nodes_to_delete, @@ -3833,7 +3873,7 @@ Status AddFusedMatMulBiasAddAndGelu( SetFusedOpAttributes(&fused_node, {"BiasAdd", "GeluExact"}); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); - Status status; + absl::Status status; mutation->AddNode(std::move(fused_node), &status); TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(mutation->Apply()); @@ -3845,13 +3885,13 @@ Status AddFusedMatMulBiasAddAndGelu( return absl::OkStatus(); } -Status AddMklLayerNorm(RemapperContext* ctx, - const std::map& matched_nodes_map, - const std::set& remove_node_indices, - const std::vector& input_node_names, - std::vector* invalidated_nodes, - std::vector* nodes_to_delete, - const float epsilon) { +absl::Status AddMklLayerNorm(RemapperContext* ctx, + const std::map& matched_nodes_map, + const std::set& remove_node_indices, + const std::vector& input_node_names, + std::vector* invalidated_nodes, + std::vector* nodes_to_delete, + const float epsilon) { auto* output_node = ctx->graph_view.GetNode(matched_nodes_map.at("output"))->node(); @@ -3866,7 +3906,7 @@ Status AddMklLayerNorm(RemapperContext* ctx, SetAttrValue(epsilon, &(*attr)["epsilon"]); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); - Status status; + absl::Status status; mutation->AddNode(std::move(fused_node), &status); TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(mutation->Apply()); @@ -3878,7 +3918,7 @@ Status AddMklLayerNorm(RemapperContext* ctx, return absl::OkStatus(); } -Status ReplaceMulMaximumWithLeakyRelu( +absl::Status ReplaceMulMaximumWithLeakyRelu( RemapperContext* ctx, const std::map& matched_nodes_map, const std::set& remove_node_indices, std::vector* invalidated_nodes, std::vector* nodes_to_delete, @@ -3914,7 +3954,7 @@ Status ReplaceMulMaximumWithLeakyRelu( SetAttrValue(alpha, &(*attr)["alpha"]); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); - Status status; + absl::Status status; mutation->AddNode(std::move(fused_op), &status); TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(mutation->Apply()); @@ -3928,7 +3968,7 @@ Status ReplaceMulMaximumWithLeakyRelu( return absl::OkStatus(); } -Status ReplaceSigmoidMulWithSwish( +absl::Status ReplaceSigmoidMulWithSwish( RemapperContext* ctx, const std::map& matched_nodes_map, const std::set& remove_node_indices, std::vector* invalidated_nodes, std::vector* nodes_to_delete) { @@ -3947,7 +3987,7 @@ Status ReplaceSigmoidMulWithSwish( (*attr)["T"] = mul->attr().at("T"); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); - Status status; + absl::Status status; mutation->AddNode(std::move(fused_op), &status); TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(mutation->Apply()); @@ -3960,10 +4000,10 @@ Status ReplaceSigmoidMulWithSwish( return absl::OkStatus(); } -Status AddFusedBatchNormExNode(RemapperContext* ctx, - const FusedBatchNormEx& matched, - std::vector* invalidated_nodes, - std::vector* nodes_to_delete) { +absl::Status AddFusedBatchNormExNode(RemapperContext* ctx, + const FusedBatchNormEx& matched, + std::vector* invalidated_nodes, + std::vector* nodes_to_delete) { const GraphDef* graph = ctx->graph_view.graph(); const NodeDef& fused_batch_norm = graph->node(matched.fused_batch_norm); const NodeDef& activation = graph->node(matched.activation); @@ -4014,7 +4054,7 @@ Status AddFusedBatchNormExNode(RemapperContext* ctx, (*identity_op.mutable_attr())["T"] = attrs->at("T"); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); - Status status; + absl::Status status; mutation->AddNode(std::move(fused_op), &status); TF_RETURN_IF_ERROR(status); mutation->AddNode(std::move(identity_op), &status); @@ -4030,10 +4070,10 @@ Status AddFusedBatchNormExNode(RemapperContext* ctx, return absl::OkStatus(); } -Status AddFusedBatchNormGradExNode(RemapperContext* ctx, - const FusedBatchNormGradEx& matched, - std::vector* invalidated_nodes, - std::vector* nodes_to_delete) { +absl::Status AddFusedBatchNormGradExNode(RemapperContext* ctx, + const FusedBatchNormGradEx& matched, + std::vector* invalidated_nodes, + std::vector* nodes_to_delete) { const GraphDef* graph = ctx->graph_view.graph(); const NodeDef& fused_batch_norm_grad = graph->node(matched.fused_batch_norm_grad); @@ -4084,7 +4124,7 @@ Status AddFusedBatchNormGradExNode(RemapperContext* ctx, (*identity_op.mutable_attr())["T"] = attrs->at("T"); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); - Status status; + absl::Status status; mutation->AddNode(std::move(fused_op), &status); TF_RETURN_IF_ERROR(status); if (matched.side_input_grad != kMissingIndex) { @@ -4103,7 +4143,8 @@ Status AddFusedBatchNormGradExNode(RemapperContext* ctx, return absl::OkStatus(); } -Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) { +absl::Status AddBatchNormNodes(RemapperContext* ctx, + const FusedBatchNorm& matched) { const GraphDef* graph = ctx->graph_view.graph(); const NodeDef& fused_node = graph->node(matched.fused_batch_norm); VLOG(2) << "Optimizing fused batch norm node " @@ -4116,7 +4157,7 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) { string variance = fused_node.input(4); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); - Status status; + absl::Status status; string x_format = fused_node.attr().at(kDataFormat).s(); if (x_format == "NCHW" || x_format == "NCDHW") { @@ -4298,10 +4339,10 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) { return mutation->Apply(); } -Status AddTensorToHashBucketNode(RemapperContext* ctx, - const TensorToHashBucket& matched, - std::vector* invalidated_nodes, - std::vector* nodes_to_delete) { +absl::Status AddTensorToHashBucketNode(RemapperContext* ctx, + const TensorToHashBucket& matched, + std::vector* invalidated_nodes, + std::vector* nodes_to_delete) { const GraphDef* graph = ctx->graph_view.graph(); const NodeDef& pre_as_string = graph->node(matched.pre_as_string); const NodeDef& as_string = graph->node(matched.as_string); @@ -4325,7 +4366,7 @@ Status AddTensorToHashBucketNode(RemapperContext* ctx, (*attr)["num_buckets"] = src_attr1.at("num_buckets"); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); - Status status; + absl::Status status; mutation->AddNode(std::move(fused_op), &status); TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(mutation->Apply()); @@ -4336,12 +4377,12 @@ Status AddTensorToHashBucketNode(RemapperContext* ctx, return absl::OkStatus(); } -Status AddFusedBatchMatMul(RemapperContext* ctx, - const std::map& matched_nodes_map, - const std::set& remove_node_indices, - const std::vector& input_node_names, - std::vector* invalidated_nodes, - std::vector* nodes_to_delete) { +absl::Status AddFusedBatchMatMul(RemapperContext* ctx, + const std::map& matched_nodes_map, + const std::set& remove_node_indices, + const std::vector& input_node_names, + std::vector* invalidated_nodes, + std::vector* nodes_to_delete) { auto* output_node = ctx->graph_view.GetNode(matched_nodes_map.at("output"))->node(); auto* batch_matmul_node = @@ -4357,7 +4398,7 @@ Status AddFusedBatchMatMul(RemapperContext* ctx, SetFusedOpAttributes(&fused_node, {"Mul", "Add"}, /*num_args=*/2); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); - Status status; + absl::Status status; mutation->AddNode(std::move(fused_node), &status); TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(mutation->Apply()); @@ -4383,12 +4424,12 @@ std::vector GetTensorValues(const Tensor& tensor) { return result_vector; } -Status AddMklFusedInstanceNorm(RemapperContext* ctx, - std::map* matched_nodes_map, - std::set* remove_node_indices, - std::vector* invalidated_nodes, - std::vector* nodes_to_delete, - bool fuse_activation) { +absl::Status AddMklFusedInstanceNorm(RemapperContext* ctx, + std::map* matched_nodes_map, + std::set* remove_node_indices, + std::vector* invalidated_nodes, + std::vector* nodes_to_delete, + bool fuse_activation) { auto* output_node = ctx->graph_view.GetNode(matched_nodes_map->at("output"))->node(); auto* input_node = @@ -4478,7 +4519,7 @@ Status AddMklFusedInstanceNorm(RemapperContext* ctx, } utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); - Status status; + absl::Status status; mutation->AddNode(std::move(fused_node), &status); TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(mutation->Apply()); @@ -4632,7 +4673,7 @@ bool FindSoftplusAndTanhAndMul(RemapperContext* ctx, int node_index, return found_op_type_match; } -Status ReplaceSoftplusTanhAndMulWithMish( +absl::Status ReplaceSoftplusTanhAndMulWithMish( RemapperContext* ctx, const std::map* matched_nodes_map, const std::set* remove_node_indices, std::vector* invalidated_nodes, std::vector* nodes_to_delete) { @@ -4652,7 +4693,7 @@ Status ReplaceSoftplusTanhAndMulWithMish( (*fused_node_attr)["T"] = old_mul_node->attr().at("T"); utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder(); - Status status; + absl::Status status; mutation->AddNode(std::move(fused_node), &status); TF_RETURN_IF_ERROR(status); TF_RETURN_IF_ERROR(mutation->Apply()); @@ -4831,10 +4872,10 @@ inline bool IsXlaCpuGlobalJitOn() { } } // namespace -Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, - GraphDef* optimized_graph) { +absl::Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) { GrapplerItem mutable_item = item; - Status status; + absl::Status status; bool xla_cpu_jit_disable_fusion = xla_auto_clustering_on_ && IsXlaCpuGlobalJitOn(); #ifdef DNNL_AARCH64_USE_ACL diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index bd2882cfd3e77d..1c52fa321ec0e7 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -435,7 +435,7 @@ void EraseNodesFromGraph(const std::set& nodes_to_delete, } \ break -Status SetTensorValue(DataType dtype, int value, Tensor* tensor) { +absl::Status SetTensorValue(DataType dtype, int value, Tensor* tensor) { // TODO(rmlarsen): Support more general shapes. // TODO(lyandy): Change `value` to be int64 once int64 -> qint32 is supported. if (tensor->NumElements() != 1) { @@ -470,7 +470,7 @@ Status SetTensorValue(DataType dtype, int value, Tensor* tensor) { #undef HANDLE_CASE -Status CheckAttrExists(const NodeDef& node, const string& key) { +absl::Status CheckAttrExists(const NodeDef& node, const string& key) { if (!HasNodeAttr(node, key)) { return errors::InvalidArgument("Node '", node.name(), "' lacks '", key, "' attr: ", node.ShortDebugString()); @@ -478,14 +478,15 @@ Status CheckAttrExists(const NodeDef& node, const string& key) { return absl::OkStatus(); } -Status CheckAttrsExist(const NodeDef& node, absl::Span keys) { +absl::Status CheckAttrsExist(const NodeDef& node, + absl::Span keys) { for (const string& key : keys) { TF_RETURN_IF_ERROR(CheckAttrExists(node, key)); } return absl::OkStatus(); } -Status IsKernelRegisteredForNode( +absl::Status IsKernelRegisteredForNode( absl::string_view node_name, bool has_experimental_debug_info, const NodeDef_ExperimentalDebugInfo& experimental_debug_info, absl::string_view node_op, absl::string_view node_device, @@ -500,7 +501,7 @@ Status IsKernelRegisteredForNode( node_op, node_device, node_attrs, nullptr, nullptr); } -Status IsKernelRegisteredForNode(const NodeDef& node) { +absl::Status IsKernelRegisteredForNode(const NodeDef& node) { return IsKernelRegisteredForNode(node.name(), node.has_experimental_debug_info(), node.experimental_debug_info(), node.op(), diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index bfa5655f80aae7..edba785f0b9d5d 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -379,10 +379,11 @@ int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map); void DedupControlInputs(NodeDef* node); // Returns an error if an attribute with the given key does not exist in node. -Status CheckAttrExists(const NodeDef& node, const string& key); +absl::Status CheckAttrExists(const NodeDef& node, const string& key); // Returns an error if attributes with the given keys do not exist in node. -Status CheckAttrsExist(const NodeDef& node, absl::Span keys); +absl::Status CheckAttrsExist(const NodeDef& node, + absl::Span keys); // Returns the data type in attribute `attr_name` of `node`. If that attribute // doesn't exist, returns DT_INVALID. @@ -407,14 +408,14 @@ void PermuteNodesInPlace(GraphDef* graph, std::vector* permutation, // Returns OkStatus() if a kernel is registered for node.op() on the device // type corresponding to node.device(). -Status IsKernelRegisteredForNode( +absl::Status IsKernelRegisteredForNode( absl::string_view node_name, bool has_experimental_debug_info, const NodeDef_ExperimentalDebugInfo& experimental_debug_info, absl::string_view node_op, absl::string_view node_device, AttrSlice node_attrs); -Status IsKernelRegisteredForNode(const NodeDef& node); +absl::Status IsKernelRegisteredForNode(const NodeDef& node); -Status SetTensorValue(DataType dtype, int value, Tensor* tensor); +absl::Status SetTensorValue(DataType dtype, int value, Tensor* tensor); void EraseNodesFromGraph(const std::set& nodes_to_delete, GraphDef* graph); diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc index 37bf0785b35923..9bc94d5f7b083e 100644 --- a/tensorflow/core/grappler/utils_test.cc +++ b/tensorflow/core/grappler/utils_test.cc @@ -432,7 +432,7 @@ TEST(CheckAttrExists, All) { TF_EXPECT_OK(CheckAttrsExist(node, {"apple", "pear"})); TF_EXPECT_OK(CheckAttrsExist(node, {"pear", "apple"})); - Status status = CheckAttrExists(node, "banana"); + absl::Status status = CheckAttrExists(node, "banana"); EXPECT_FALSE(status.ok()); EXPECT_EQ(status.code(), error::INVALID_ARGUMENT); EXPECT_TRUE(absl::StrContains( @@ -600,7 +600,7 @@ template void TestSetTensorValue(DataType type, int val, bool success, absl::string_view error_msg) { Tensor t(type, TensorShape({})); - Status s = SetTensorValue(t.dtype(), val, &t); + absl::Status s = SetTensorValue(t.dtype(), val, &t); EXPECT_EQ(s.ok(), success); if (s.ok()) { test::ExpectTensorEqual(Tensor(static_cast(val)), t); diff --git a/tensorflow/core/grappler/verifiers/graph_verifier.h b/tensorflow/core/grappler/verifiers/graph_verifier.h index 10fd201eadcfd3..53d62e4c986d68 100644 --- a/tensorflow/core/grappler/verifiers/graph_verifier.h +++ b/tensorflow/core/grappler/verifiers/graph_verifier.h @@ -46,7 +46,7 @@ class GraphVerifier { // Implement an algorithm to verify the specified graph. // The return value is a Status that represents a concatenation of Status of // each verification step. - virtual Status Verify(const GraphDef& graph) = 0; + virtual absl::Status Verify(const GraphDef& graph) = 0; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/verifiers/structure_verifier.cc b/tensorflow/core/grappler/verifiers/structure_verifier.cc index 819605d80db46e..b4568f5896b56e 100644 --- a/tensorflow/core/grappler/verifiers/structure_verifier.cc +++ b/tensorflow/core/grappler/verifiers/structure_verifier.cc @@ -31,7 +31,7 @@ namespace tensorflow { namespace grappler { // TODO(ashwinm): Expand this to add more structural checks. -Status StructureVerifier::Verify(const GraphDef& graph) { +absl::Status StructureVerifier::Verify(const GraphDef& graph) { StatusGroup status_group; FunctionLibraryDefinition function_library(OpRegistry::Global(), diff --git a/tensorflow/core/grappler/verifiers/structure_verifier.h b/tensorflow/core/grappler/verifiers/structure_verifier.h index ab719f1214eebb..de77933fedac10 100644 --- a/tensorflow/core/grappler/verifiers/structure_verifier.h +++ b/tensorflow/core/grappler/verifiers/structure_verifier.h @@ -34,7 +34,7 @@ class StructureVerifier : public GraphVerifier { string name() const override { return "structure_verifier"; }; - Status Verify(const GraphDef& graph) override; + absl::Status Verify(const GraphDef& graph) override; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/verifiers/structure_verifier_test.cc b/tensorflow/core/grappler/verifiers/structure_verifier_test.cc index 6ce108b92e98f2..95c0a759159c91 100644 --- a/tensorflow/core/grappler/verifiers/structure_verifier_test.cc +++ b/tensorflow/core/grappler/verifiers/structure_verifier_test.cc @@ -42,7 +42,7 @@ class StructureVerifierTest : public ::testing::Test { std::unique_ptr verifier_; }; -Status Scalars(shape_inference::InferenceContext* c) { +absl::Status Scalars(shape_inference::InferenceContext* c) { for (int i = 0; i < c->num_outputs(); ++i) { c->set_output(i, c->Scalar()); } @@ -84,7 +84,7 @@ TEST_F(StructureVerifierTest, OpNotRegistered) { "node { name: 'input' op: 'OpNotRegistered' }" "node { name: 't1' op: 'TestMul' input: [ 'input:0', 't2' ] }" "node { name: 't2' op: 'TestMul' input: [ 'input:1', 't1' ] }"); - Status status = verifier_->Verify(graph_); + absl::Status status = verifier_->Verify(graph_); EXPECT_TRUE(errors::IsNotFound(status)); EXPECT_TRUE(absl::StrContains(status.message(), "Op type not registered")); } @@ -93,7 +93,7 @@ TEST_F(StructureVerifierTest, DuplicateNodeNames) { SetGraph( "node { name: 'A' op: 'TestParams' }" "node { name: 'A' op: 'TestInput' }"); - Status status = verifier_->Verify(graph_); + absl::Status status = verifier_->Verify(graph_); EXPECT_TRUE(errors::IsAlreadyExists(status)); EXPECT_TRUE(absl::StrContains(status.message(), "Node already exists:")); } @@ -103,7 +103,7 @@ TEST_F(StructureVerifierTest, GraphWithInvalidCycle) { "node { name: 'input' op: 'TestInput' }" "node { name: 't1' op: 'TestMul' input: [ 'input:0', 't2' ] }" "node { name: 't2' op: 'TestMul' input: [ 'input:1', 't1' ] }"); - Status status = verifier_->Verify(graph_); + absl::Status status = verifier_->Verify(graph_); EXPECT_TRUE(errors::IsInvalidArgument(status)); EXPECT_TRUE(absl::StrContains( status.message(), "The graph couldn't be sorted in topological order")); diff --git a/tensorflow/core/ir/tf_op_wrapper.h b/tensorflow/core/ir/tf_op_wrapper.h index 383d6aa2f52e85..e295647c091061 100644 --- a/tensorflow/core/ir/tf_op_wrapper.h +++ b/tensorflow/core/ir/tf_op_wrapper.h @@ -20,7 +20,6 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/iterator_range.h" -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project diff --git a/tensorflow/core/ir/types/dialect.h b/tensorflow/core/ir/types/dialect.h index 7c1a1cda1bec94..b0b601e36d7b3a 100644 --- a/tensorflow/core/ir/types/dialect.h +++ b/tensorflow/core/ir/types/dialect.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index ed240a3e54936e..75d45dffebafa0 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1508,7 +1508,7 @@ tf_cc_test( "//tensorflow/core:testlib", "//tensorflow/core/platform:status_matchers", "@com_google_absl//absl/strings", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -3644,7 +3644,6 @@ cc_library( ]) + if_cuda([ "@local_xla//xla/stream_executor/cuda:cublas_lt_header", ]) + if_rocm([ - "@local_xla//xla/stream_executor/platform:dso_loader", "@local_xla//xla/stream_executor/rocm:hipblas_lt_header", ]) + if_static(["//tensorflow/core/platform:tensor_float_32_utils"]), ) diff --git a/tensorflow/core/kernels/aggregate_ops.cc b/tensorflow/core/kernels/aggregate_ops.cc index 67a022d2de7b99..af52bce3e8b5dc 100644 --- a/tensorflow/core/kernels/aggregate_ops.cc +++ b/tensorflow/core/kernels/aggregate_ops.cc @@ -170,7 +170,7 @@ class AddNOp : public OpKernel { // : ctx->input(ix).scalar()()) // This reduces (possibly expensive) copying of Variants from // the inputs into temp at the lowest levels of the summation tree. - static inline Status AddVariantTo( + static inline absl::Status AddVariantTo( OpKernelContext* ctx, const int lhs_ix, const int rhs_ix, absl::InlinedVector* temp, absl::InlinedVector* temp_filled) { diff --git a/tensorflow/core/kernels/as_string_op_test.cc b/tensorflow/core/kernels/as_string_op_test.cc index d3e5f405b42333..58b976e14345b3 100644 --- a/tensorflow/core/kernels/as_string_op_test.cc +++ b/tensorflow/core/kernels/as_string_op_test.cc @@ -30,9 +30,9 @@ namespace { class AsStringGraphTest : public OpsTestBase { protected: - Status Init(DataType input_type, const string& fill = "", int width = -1, - int precision = -1, bool scientific = false, - bool shortest = false) { + absl::Status Init(DataType input_type, const string& fill = "", + int width = -1, int precision = -1, bool scientific = false, + bool shortest = false) { TF_CHECK_OK(NodeDefBuilder("op", "AsString") .Input(FakeInput(input_type)) .Attr("fill", fill) @@ -171,16 +171,16 @@ TEST_F(AsStringGraphTest, Variant) { } TEST_F(AsStringGraphTest, OnlyOneOfScientificAndShortest) { - Status s = Init(DT_FLOAT, /*fill=*/"", /*width=*/-1, /*precision=*/-1, - /*scientific=*/true, /*shortest=*/true); + absl::Status s = Init(DT_FLOAT, /*fill=*/"", /*width=*/-1, /*precision=*/-1, + /*scientific=*/true, /*shortest=*/true); ASSERT_EQ(error::INVALID_ARGUMENT, s.code()); ASSERT_TRUE(absl::StrContains( s.message(), "Cannot select both scientific and shortest notation")); } TEST_F(AsStringGraphTest, NoShortestForNonFloat) { - Status s = Init(DT_INT32, /*fill=*/"", /*width=*/-1, /*precision=*/-1, - /*scientific=*/false, /*shortest=*/true); + absl::Status s = Init(DT_INT32, /*fill=*/"", /*width=*/-1, /*precision=*/-1, + /*scientific=*/false, /*shortest=*/true); ASSERT_EQ(error::INVALID_ARGUMENT, s.code()); ASSERT_TRUE(absl::StrContains( s.message(), @@ -188,8 +188,8 @@ TEST_F(AsStringGraphTest, NoShortestForNonFloat) { } TEST_F(AsStringGraphTest, NoScientificForNonFloat) { - Status s = Init(DT_INT32, /*fill=*/"", /*width=*/-1, /*precision=*/-1, - /*scientific=*/true); + absl::Status s = Init(DT_INT32, /*fill=*/"", /*width=*/-1, /*precision=*/-1, + /*scientific=*/true); ASSERT_EQ(error::INVALID_ARGUMENT, s.code()); ASSERT_TRUE(absl::StrContains( s.message(), @@ -197,14 +197,14 @@ TEST_F(AsStringGraphTest, NoScientificForNonFloat) { } TEST_F(AsStringGraphTest, NoPrecisionForNonFloat) { - Status s = Init(DT_INT32, /*fill=*/"", /*width=*/-1, /*precision=*/5); + absl::Status s = Init(DT_INT32, /*fill=*/"", /*width=*/-1, /*precision=*/5); ASSERT_EQ(error::INVALID_ARGUMENT, s.code()); ASSERT_TRUE( absl::StrContains(s.message(), "precision not supported for datatype")); } TEST_F(AsStringGraphTest, LongFill) { - Status s = Init(DT_INT32, /*fill=*/"asdf"); + absl::Status s = Init(DT_INT32, /*fill=*/"asdf"); ASSERT_EQ(error::INVALID_ARGUMENT, s.code()); ASSERT_TRUE(absl::StrContains(s.message(), "Fill string must be one or fewer characters")); @@ -241,13 +241,13 @@ TEST_F(AsStringGraphTest, FillWithChar1) { } TEST_F(AsStringGraphTest, FillWithChar3) { - Status s = Init(DT_INT32, /*fill=*/"s"); + absl::Status s = Init(DT_INT32, /*fill=*/"s"); ASSERT_EQ(error::INVALID_ARGUMENT, s.code()); ASSERT_TRUE(absl::StrContains(s.message(), "Fill argument not supported")); } TEST_F(AsStringGraphTest, FillWithChar4) { - Status s = Init(DT_INT32, /*fill=*/"n"); + absl::Status s = Init(DT_INT32, /*fill=*/"n"); ASSERT_EQ(error::INVALID_ARGUMENT, s.code()); ASSERT_TRUE(absl::StrContains(s.message(), "Fill argument not supported")); } diff --git a/tensorflow/core/kernels/barrier_ops.cc b/tensorflow/core/kernels/barrier_ops.cc index c3cb388cd65df9..85c3947fc8730e 100644 --- a/tensorflow/core/kernels/barrier_ops.cc +++ b/tensorflow/core/kernels/barrier_ops.cc @@ -84,7 +84,7 @@ class Barrier : public ResourceBase { queue_component_shapes, strings::StrCat(name_, "_queue")); } - Status Initialize() { return ready_queue_->Initialize(); } + absl::Status Initialize() { return ready_queue_->Initialize(); } template void TryInsertMany(const Tensor& keys, int component_index, @@ -304,10 +304,12 @@ class Barrier : public ResourceBase { protected: template - Status InsertOneLocked(OpKernelContext* ctx, const Tensor& keys, - const Tensor& values, const TensorShape& element_shape, - int component_index, int i, - std::vector* ready_tuples, bool* new_elements) + absl::Status InsertOneLocked(OpKernelContext* ctx, const Tensor& keys, + const Tensor& values, + const TensorShape& element_shape, + int component_index, int i, + std::vector* ready_tuples, + bool* new_elements) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { auto keys_vec = keys.flat(); auto values_matrix = values.flat_outer_dims(); @@ -459,7 +461,7 @@ class BarrierOp : public ResourceOpKernel { } private: - Status CreateResource(Barrier** barrier) override + absl::Status CreateResource(Barrier** barrier) override TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { *barrier = new Barrier(value_component_types_, value_component_shapes_, cinfo_.name()); @@ -469,7 +471,7 @@ class BarrierOp : public ResourceOpKernel { return (*barrier)->Initialize(); } - Status VerifyResource(Barrier* barrier) override + absl::Status VerifyResource(Barrier* barrier) override TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (barrier->component_types() != value_component_types_) { return errors::InvalidArgument( diff --git a/tensorflow/core/kernels/batch_kernel_test_util.cc b/tensorflow/core/kernels/batch_kernel_test_util.cc index bda3c25b182973..8e73e62f4e1b8d 100644 --- a/tensorflow/core/kernels/batch_kernel_test_util.cc +++ b/tensorflow/core/kernels/batch_kernel_test_util.cc @@ -33,7 +33,7 @@ bool BatchFunctionKernelTestAccess::enable_adaptive_batch_threads() const { return kernel_->enable_adaptive_batch_threads_; } -Status BatchFunctionKernelTestBase::Init(bool enable_adaptive_scheduler) { +absl::Status BatchFunctionKernelTestBase::Init(bool enable_adaptive_scheduler) { std::vector input_dtypes({DataType::DT_INT64, DataType::DT_INT64}); std::vector inputs( {NodeDefBuilder::NodeOut({"n1", 0, DataType::DT_INT64}), diff --git a/tensorflow/core/kernels/batch_kernel_test_util.h b/tensorflow/core/kernels/batch_kernel_test_util.h index e6b37e635ac0bc..2495580a05ad45 100644 --- a/tensorflow/core/kernels/batch_kernel_test_util.h +++ b/tensorflow/core/kernels/batch_kernel_test_util.h @@ -39,7 +39,7 @@ class BatchFunctionKernelTestBase : public OpsTestBase, public ::testing::WithParamInterface { public: // Init test fixture with a batch kernel instance. - Status Init(bool enable_adaptive_scheduler); + absl::Status Init(bool enable_adaptive_scheduler); }; } // namespace test_util diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc index bd93c1ec3a02a3..250ce16b500c5f 100644 --- a/tensorflow/core/kernels/batch_kernels.cc +++ b/tensorflow/core/kernels/batch_kernels.cc @@ -160,14 +160,14 @@ class BatchResource : public serving::BatchResourceBase { } }; - static Status Create(bool has_process_batch_function, - int32_t num_batch_threads, - int32_t max_execution_batch_size, - int32_t batch_timeout_micros, - int32_t max_enqueued_batches, - const std::vector& allowed_batch_sizes, - bool enable_large_batch_splitting, - std::unique_ptr* resource) { + static absl::Status Create(bool has_process_batch_function, + int32_t num_batch_threads, + int32_t max_execution_batch_size, + int32_t batch_timeout_micros, + int32_t max_enqueued_batches, + const std::vector& allowed_batch_sizes, + bool enable_large_batch_splitting, + std::unique_ptr* resource) { return Create(has_process_batch_function, num_batch_threads, max_execution_batch_size, batch_timeout_micros, max_enqueued_batches, allowed_batch_sizes, @@ -182,7 +182,7 @@ class BatchResource : public serving::BatchResourceBase { /*batch_padding_policy=*/"PAD_UP", resource); } - static Status Create( + static absl::Status Create( bool has_process_batch_function, int32_t num_batch_threads, int32_t max_execution_batch_size, int32_t batch_timeout_micros, int32_t max_enqueued_batches, @@ -213,7 +213,7 @@ class BatchResource : public serving::BatchResourceBase { return absl::OkStatus(); } - static Status Create( + static absl::Status Create( bool has_process_batch_function, AdaptiveBatcherT::Options adaptive_shared_batch_scheduler_options, int32_t max_batch_size, int32_t batch_timeout_micros, @@ -256,7 +256,7 @@ class BatchResource : public serving::BatchResourceBase { void ProcessFuncBatchImpl( const serving::BatchResourceBase::BatchTask& last_task, absl::Span inputs, std::vector* combined_outputs, - std::function done) const override { + std::function done) const override { auto* last_task_context = last_task.context; FunctionLibraryRuntime::Options opts; opts.step_container = last_task_context->step_container(); @@ -275,7 +275,7 @@ class BatchResource : public serving::BatchResourceBase { FunctionLibraryRuntime::Handle fhandle = down_cast(last_task).fhandle; flib->Run(opts, fhandle, inputs, combined_outputs, - [&](const Status& run_status) { + [&](const absl::Status& run_status) { done(run_status); done_notif.Notify(); }); @@ -352,7 +352,7 @@ void BatchFunctionKernel::ComputeAsync(OpKernelContext* c, DoneCallback done) { GetModelName(c)); RecordBatchParamNumBatchThreads(num_batch_threads_, GetModelName(c)); - std::function creator; + std::function creator; FunctionLibraryRuntime::Handle handle; OP_REQUIRES_OK_ASYNC(c, GetOrCreateFunctionHandle(c, &handle), done); @@ -461,7 +461,7 @@ void BatchFunctionKernel::ComputeAsync(OpKernelContext* c, DoneCallback done) { std::unique_ptr> { return {std::make_unique(handle)}; }; - Status status; + absl::Status status; if (serving::ShouldWarmupAllBatchSizes(c)) { status = br->RegisterWarmupInputs(guid, c, batcher_queue_, create_batch_task_fn, done); @@ -474,7 +474,7 @@ void BatchFunctionKernel::ComputeAsync(OpKernelContext* c, DoneCallback done) { // Assume br calls done, so nothing to do here. } -Status BatchFunctionKernel::InstantiateFunction( +absl::Status BatchFunctionKernel::InstantiateFunction( OpKernelContext* c, FunctionLibraryRuntime::Handle* handle) const { // TODO(b/173748062): Merge this instantiation logic with PartitionedCall. FunctionLibraryRuntime* flib = c->function_library(); @@ -535,7 +535,7 @@ Status BatchFunctionKernel::InstantiateFunction( handle); } -Status BatchFunctionKernel::GetOrCreateFunctionHandle( +absl::Status BatchFunctionKernel::GetOrCreateFunctionHandle( OpKernelContext* c, FunctionLibraryRuntime::Handle* handle) { mutex_lock ml(mu_); if (!fhandle_) { @@ -551,7 +551,7 @@ Status BatchFunctionKernel::GetOrCreateFunctionHandle( // If large batch split is not enabled, the last one must equal // `max_batch_size_`. otherwise the last element must be smaller than or equal // to `max_batch_size_`. -Status BatchFunctionKernel::ValidateAllowedBatchSizes() const { +absl::Status BatchFunctionKernel::ValidateAllowedBatchSizes() const { if (allowed_batch_sizes_.empty()) { return absl::OkStatus(); } @@ -679,20 +679,21 @@ class BatchKernel : public AsyncOpKernel { void ComputeAsync(OpKernelContext* c, DoneCallback done) final { BatchResource* br; - std::function creator = [this](BatchResource** r) { - std::unique_ptr new_resource; - TF_RETURN_IF_ERROR(BatchResource::Create( - /*has_process_batch_function=*/false, num_batch_threads_, - max_batch_size_, batch_timeout_micros_, max_enqueued_batches_, - allowed_batch_sizes_, false, &new_resource)); - *r = new_resource.release(); - return absl::OkStatus(); - }; + std::function creator = + [this](BatchResource** r) { + std::unique_ptr new_resource; + TF_RETURN_IF_ERROR(BatchResource::Create( + /*has_process_batch_function=*/false, num_batch_threads_, + max_batch_size_, batch_timeout_micros_, max_enqueued_batches_, + allowed_batch_sizes_, false, &new_resource)); + *r = new_resource.release(); + return absl::OkStatus(); + }; OP_REQUIRES_OK_ASYNC(c, c->resource_manager()->LookupOrCreate( container_, shared_name_, &br, creator), done); - const Status status = br->RegisterInput( + const absl::Status status = br->RegisterInput( random::New64(), c, batcher_queue_, []() -> absl::StatusOr< std::unique_ptr> { @@ -706,7 +707,7 @@ class BatchKernel : public AsyncOpKernel { // Validates 'allowed_batch_sizes_'. The entries must increase // monotonically, and the last one must equal 'max_batch_size_'. - Status ValidateAllowedBatchSizes() const { + absl::Status ValidateAllowedBatchSizes() const { if (allowed_batch_sizes_.empty()) { return absl::OkStatus(); } @@ -762,7 +763,8 @@ class UnbatchResource : public ResourceBase { string DebugString() const final { return "UnbatchResource"; } - Status Compute(OpKernelContext* context, AsyncOpKernel::DoneCallback done) { + absl::Status Compute(OpKernelContext* context, + AsyncOpKernel::DoneCallback done) { const Tensor& data_t = context->input(0); const Tensor& batch_index_t = context->input(1); @@ -808,7 +810,7 @@ class UnbatchResource : public ResourceBase { // Critical section. std::vector done_callbacks_to_call; - Status status = [&]() -> Status { + absl::Status status = [&]() -> absl::Status { mutex_lock ml(mu_); // Check to see whether the tensor we want is already ready. @@ -946,7 +948,7 @@ class UnbatchKernel : public AsyncOpKernel { void ComputeAsync(OpKernelContext* c, DoneCallback done) final { UnbatchResource* ubr; - std::function creator = + std::function creator = [this](UnbatchResource** r) { *r = new UnbatchResource(timeout_micros_); return absl::OkStatus(); @@ -978,8 +980,8 @@ class UnbatchGradResource : public ResourceBase { // Flushes the information for one batch, given its context and done // callback. Clears all information about it from the available_tensors_. - Status OutputBatch(OpKernelContext* context, - const AsyncOpKernel::DoneCallback& done) + absl::Status OutputBatch(OpKernelContext* context, + const AsyncOpKernel::DoneCallback& done) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { const Tensor& batch_index_t = context->input(1); auto batch_index = @@ -1012,8 +1014,8 @@ class UnbatchGradResource : public ResourceBase { } // Ingests data from one invocation of the op. - Status Compute(OpKernelContext* context, - const AsyncOpKernel::DoneCallback& done) { + absl::Status Compute(OpKernelContext* context, + const AsyncOpKernel::DoneCallback& done) { const Tensor& data_t = context->input(0); const Tensor& batch_index_t = context->input(1); const Tensor& grad_t = context->input(2); @@ -1142,7 +1144,7 @@ class UnbatchGradKernel : public AsyncOpKernel { void ComputeAsync(OpKernelContext* c, DoneCallback done) final { UnbatchGradResource* ubr; - std::function creator = + std::function creator = [](UnbatchGradResource** r) { *r = new UnbatchGradResource(); return absl::OkStatus(); @@ -1151,7 +1153,7 @@ class UnbatchGradKernel : public AsyncOpKernel { c->resource_manager()->LookupOrCreate( container_, shared_name_, &ubr, creator), done); - Status status = ubr->Compute(c, done); + absl::Status status = ubr->Compute(c, done); ubr->Unref(); OP_REQUIRES_OK_ASYNC(c, status, done); // Assume ubr calls done, so nothing to do here. diff --git a/tensorflow/core/kernels/batch_kernels.h b/tensorflow/core/kernels/batch_kernels.h index 11373af2048991..73baea3a9b13e3 100644 --- a/tensorflow/core/kernels/batch_kernels.h +++ b/tensorflow/core/kernels/batch_kernels.h @@ -78,16 +78,16 @@ class BatchFunctionKernel : public AsyncOpKernel { // If large batch split is not enabled, the last one must equal // `max_batch_size_`. otherwise the last element must be smaller than or equal // to `max_batch_size_`. - Status ValidateAllowedBatchSizes() const; + absl::Status ValidateAllowedBatchSizes() const; // Creates the function handle if it isn't initialized yet; and re-use it // afterwards. - Status GetOrCreateFunctionHandle(OpKernelContext* c, - FunctionLibraryRuntime::Handle* handle); + absl::Status GetOrCreateFunctionHandle( + OpKernelContext* c, FunctionLibraryRuntime::Handle* handle); // Instantiate the user-defined function and emits `handle`. - Status InstantiateFunction(OpKernelContext* c, - FunctionLibraryRuntime::Handle* handle) const; + absl::Status InstantiateFunction( + OpKernelContext* c, FunctionLibraryRuntime::Handle* handle) const; // Initialize vars by reading from op-kernel-construction. // Vars diff --git a/tensorflow/core/kernels/batch_kernels_auto_warmup_test.cc b/tensorflow/core/kernels/batch_kernels_auto_warmup_test.cc index 7e66f9b26726f4..06391a9e60d5cf 100644 --- a/tensorflow/core/kernels/batch_kernels_auto_warmup_test.cc +++ b/tensorflow/core/kernels/batch_kernels_auto_warmup_test.cc @@ -53,7 +53,7 @@ class BatchFunctionKernelTest : public test_util::BatchFunctionKernelTestBase { class BatchFunctionKernelParallelWarmupTestState : public OpsTestBase { public: // Init test fixture with a batch kernel instance. - Status Init(bool enable_splitting, bool check_output_shape) { + absl::Status Init(bool enable_splitting, bool check_output_shape) { static auto *const cpu_device = []() { auto device = DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"); diff --git a/tensorflow/core/kernels/batch_kernels_env_test.cc b/tensorflow/core/kernels/batch_kernels_env_test.cc index 5a8bfec9f90d46..cf4bbf3613c5e0 100644 --- a/tensorflow/core/kernels/batch_kernels_env_test.cc +++ b/tensorflow/core/kernels/batch_kernels_env_test.cc @@ -31,7 +31,7 @@ TEST_P(BatchFunctionKernelEnvTest, Basic) { tensorflow::setenv("TF_NUM_BATCH_THREADS", "0", 1 /* overwrite */); const bool adaptive_scheduler_enabled = GetParam(); - Status status = Init(adaptive_scheduler_enabled); + absl::Status status = Init(adaptive_scheduler_enabled); if (adaptive_scheduler_enabled) { EXPECT_THAT(status, tensorflow::testing::StatusIs( error::FAILED_PRECONDITION, diff --git a/tensorflow/core/kernels/batch_kernels_test.cc b/tensorflow/core/kernels/batch_kernels_test.cc index 9aaeb5ad5207c8..e09b3cdc6dc8f0 100644 --- a/tensorflow/core/kernels/batch_kernels_test.cc +++ b/tensorflow/core/kernels/batch_kernels_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/device_factory.h" #include "tensorflow/core/framework/function.h" @@ -48,7 +49,6 @@ limitations under the License. #include "tsl/platform/refcount.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h index 9098d5a76d1ee9..8be441b231387a 100644 --- a/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h @@ -136,7 +136,7 @@ class AdaptiveSharedBatchScheduler // Ownership is shared between the caller of Create() and any queues created // via AddQueue(). - static Status Create( + static absl::Status Create( const Options& options, std::shared_ptr>* scheduler); @@ -164,9 +164,10 @@ class AdaptiveSharedBatchScheduler // success, the caller can assume that all output_tasks will be scheduled. // Including this option allows the scheduler to pack batches better and // should usually improve overall throughput. - std::function* input_task, int first_size, - int max_batch_size, - std::vector>* output_tasks)> + std::function* input_task, int first_size, + int max_batch_size, + std::vector>* output_tasks)> split_input_task_func; // If true, the padding will not be appended. @@ -176,9 +177,9 @@ class AdaptiveSharedBatchScheduler using BatchProcessor = std::function>)>; // Adds queue (and its callback) to be managed by this scheduler. - Status AddQueue(const QueueOptions& options, - BatchProcessor process_batch_callback, - std::unique_ptr>* queue); + absl::Status AddQueue(const QueueOptions& options, + BatchProcessor process_batch_callback, + std::unique_ptr>* queue); double in_flight_batches_limit() { mutex_lock l(mu_); @@ -308,7 +309,7 @@ class ASBSQueue : public BatchScheduler { // Adds task to current batch. Fails if the task size is larger than the batch // size or if the current batch is full and this queue's number of outstanding // batches is at its maximum. - Status Schedule(std::unique_ptr* task) override; + absl::Status Schedule(std::unique_ptr* task) override; // Number of tasks waiting to be scheduled. size_t NumEnqueuedTasks() const override; @@ -381,7 +382,7 @@ template constexpr double AdaptiveSharedBatchScheduler::kMinStepSizeMultiplier; template -Status AdaptiveSharedBatchScheduler::Create( +absl::Status AdaptiveSharedBatchScheduler::Create( const Options& options, std::shared_ptr>* scheduler) { if (options.num_batch_threads < 1) { @@ -446,7 +447,7 @@ AdaptiveSharedBatchScheduler::AdaptiveSharedBatchScheduler( } template -Status AdaptiveSharedBatchScheduler::AddQueue( +absl::Status AdaptiveSharedBatchScheduler::AddQueue( const QueueOptions& options, BatchProcessor process_batch_callback, std::unique_ptr>* queue) { if (options.max_batch_size <= 0) { @@ -729,7 +730,7 @@ ASBSQueue::~ASBSQueue() { } template -Status ASBSQueue::Schedule(std::unique_ptr* task) { +absl::Status ASBSQueue::Schedule(std::unique_ptr* task) { size_t size = (*task)->size(); if (options_.split_input_task_func == nullptr && size > options_.max_batch_size) { diff --git a/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler_test.cc b/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler_test.cc index 8858390da400fc..f4290d02e383ab 100644 --- a/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler_test.cc +++ b/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler_test.cc @@ -45,9 +45,10 @@ class FakeTask : public BatchTask { // Creates a FakeTask of size 'task_size', and calls 'scheduler->Schedule()' on // that task. Returns the resulting status. -Status ScheduleTask(size_t task_size, BatchScheduler* scheduler) { +absl::Status ScheduleTask(size_t task_size, + BatchScheduler* scheduler) { std::unique_ptr task(new FakeTask(task_size)); - Status status = scheduler->Schedule(&task); + absl::Status status = scheduler->Schedule(&task); // Schedule() should have consumed 'task' iff it returned Status::OK. CHECK_EQ(status.ok(), task == nullptr); return status; diff --git a/tensorflow/core/kernels/batching_util/basic_batch_scheduler.h b/tensorflow/core/kernels/batching_util/basic_batch_scheduler.h index f27f20cf2b4b4f..a066550399c56d 100644 --- a/tensorflow/core/kernels/batching_util/basic_batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/basic_batch_scheduler.h @@ -239,9 +239,10 @@ class BasicBatchScheduler : public BatchScheduler { // NOTE: // Instantiations of `TaskType` may vary, so it's up to caller to define // how (e.g., which members to access) to split input tasks. - std::function* input_task, - int first_output_task_size, int input_batch_size_limit, - std::vector>* output_tasks)> + std::function* input_task, int first_output_task_size, + int input_batch_size_limit, + std::vector>* output_tasks)> split_input_task_func; // The maximum size of each enqueued batch (i.e., in `batches_`). @@ -263,14 +264,15 @@ class BasicBatchScheduler : public BatchScheduler { // The environment to use. Env* env = Env::Default(); }; - static Status Create(const Options& options, - std::function>)> - process_batch_callback, - std::unique_ptr* scheduler); + static absl::Status Create( + const Options& options, + std::function>)> + process_batch_callback, + std::unique_ptr* scheduler); ~BasicBatchScheduler() override = default; - Status Schedule(std::unique_ptr* task) override; + absl::Status Schedule(std::unique_ptr* task) override; size_t NumEnqueuedTasks() const override; size_t SchedulingCapacity() const override; @@ -294,7 +296,7 @@ class BasicBatchScheduler : public BatchScheduler { // Implementation details follow. API users need not read. template -Status BasicBatchScheduler::Create( +absl::Status BasicBatchScheduler::Create( const Options& options, std::function>)> process_batch_callback, @@ -338,7 +340,7 @@ Status BasicBatchScheduler::Create( } template -Status BasicBatchScheduler::Schedule( +absl::Status BasicBatchScheduler::Schedule( std::unique_ptr* task) { return shared_scheduler_queue_->Schedule(task); } diff --git a/tensorflow/core/kernels/batching_util/basic_batch_scheduler_test.cc b/tensorflow/core/kernels/batching_util/basic_batch_scheduler_test.cc index 0b8fac15894e43..6da8dbd0ca9cb8 100644 --- a/tensorflow/core/kernels/batching_util/basic_batch_scheduler_test.cc +++ b/tensorflow/core/kernels/batching_util/basic_batch_scheduler_test.cc @@ -44,9 +44,10 @@ class FakeTask : public BatchTask { // Creates a FakeTask of size 'task_size', and calls 'scheduler->Schedule()' // on that task. Returns the resulting status. -Status ScheduleTask(size_t task_size, BatchScheduler* scheduler) { +absl::Status ScheduleTask(size_t task_size, + BatchScheduler* scheduler) { std::unique_ptr task(new FakeTask(task_size)); - Status status = scheduler->Schedule(&task); + absl::Status status = scheduler->Schedule(&task); // Schedule() should have consumed 'task' iff it returned Status::OK. CHECK_EQ(status.ok(), task == nullptr); return status; diff --git a/tensorflow/core/kernels/batching_util/batch_input_task.h b/tensorflow/core/kernels/batching_util/batch_input_task.h index 908393008b41c1..4f50f1daf02bd0 100644 --- a/tensorflow/core/kernels/batching_util/batch_input_task.h +++ b/tensorflow/core/kernels/batching_util/batch_input_task.h @@ -145,7 +145,7 @@ template class BatchInputTask : public std::enable_shared_from_this> { public: - using SplitInputFunc = std::function* input_task, int first_output_task_size, int input_batch_size_limit, std::vector>* output_tasks)>; @@ -172,7 +172,8 @@ class BatchInputTask std::unique_ptr GetSplitTask(int split_id); - Status SplitBatches(std::vector>* output_tasks); + absl::Status SplitBatches( + std::vector>* output_tasks); std::unique_ptr input_task_; @@ -187,7 +188,7 @@ class BatchInputTask mutable absl::once_flag once_; std::vector> task_splits_; - Status split_status_; + absl::Status split_status_; }; // @@ -253,7 +254,7 @@ std::unique_ptr BatchInputTask::GetSplitTask(int split_id) { } template -Status BatchInputTask::SplitBatches( +absl::Status BatchInputTask::SplitBatches( std::vector>* output_tasks) { return split_func_(&input_task_, open_batch_remaining_slot_, batch_size_limit_, output_tasks); diff --git a/tensorflow/core/kernels/batching_util/batch_input_task_test.cc b/tensorflow/core/kernels/batching_util/batch_input_task_test.cc index 4fb1a9b8cbde19..ed0770c986825a 100644 --- a/tensorflow/core/kernels/batching_util/batch_input_task_test.cc +++ b/tensorflow/core/kernels/batching_util/batch_input_task_test.cc @@ -60,7 +60,7 @@ namespace { using TensorMatrix = std::vector>; -using SplitFunc = std::function* input_task, int first_output_task_size, int input_batch_size_limit, std::vector>* output_tasks)>; @@ -114,7 +114,7 @@ class BatchInputTaskTest : public ::testing::Test { device_ = DeviceFactory::NewDevice("CPU", SessionOptions{}, "/job:a/replica:0/task:0"); - Status op_kernel_creation_status; + absl::Status op_kernel_creation_status; batch_kernel_ = CreateOpKernel( DEVICE_CPU, device_.get(), device_->GetAllocator(AllocatorAttributes{}), CreateBatchKernelNodeDef(), TF_GRAPH_DEF_VERSION, diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.cc b/tensorflow/core/kernels/batching_util/batch_resource_base.cc index 81cbe417123074..4c1cfe162052c1 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.cc +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.cc @@ -337,7 +337,7 @@ string GetTensorNamesAndShapesString(const OpKernelContext* context, return out.str(); } -Status BatchResourceBase::RegisterWarmupInputs( +absl::Status BatchResourceBase::RegisterWarmupInputs( int64_t guid, OpKernelContext* context, const string& batcher_queue_name, const CreateBatchTaskFn& create_batch_task_fn, AsyncOpKernel::DoneCallback done) { @@ -355,7 +355,7 @@ Status BatchResourceBase::RegisterWarmupInputs( std::make_shared(allowed_batch_sizes_.size()); // Enqueue warmup batches. for (int i = 0; i < allowed_batch_sizes_.size(); ++i) { - Status status = RegisterInput( + absl::Status status = RegisterInput( guid, context, batcher_queue_name, create_batch_task_fn_share_status, [warmup_counter = warmup_counter.get()]() { warmup_counter->DecrementCount(); @@ -373,7 +373,7 @@ Status BatchResourceBase::RegisterWarmupInputs( }); } -Status BatchResourceBase::RegisterInput( +absl::Status BatchResourceBase::RegisterInput( int64_t guid, OpKernelContext* context, const string& batcher_queue_name, const CreateBatchTaskFn& create_batch_task_fn, AsyncOpKernel::DoneCallback done_callback, int forced_warmup_batch_size) { @@ -574,7 +574,8 @@ BatchResourceBase::GetBatcherQueueOptions( batcher_queue_options.split_input_task_func = [](std::unique_ptr* input_task, int open_batch_remaining_slot, int max_batch_size, - std::vector>* output_tasks) -> Status { + std::vector>* output_tasks) + -> absl::Status { return SplitInputTask(input_task, open_batch_remaining_slot, max_batch_size, output_tasks); }; @@ -616,7 +617,8 @@ BatchResourceBase::GetAdaptiveBatcherQueueOptions( batcher_queue_options.split_input_task_func = [](std::unique_ptr* input_task, int open_batch_remaining_slot, int max_batch_size, - std::vector>* output_tasks) -> Status { + std::vector>* output_tasks) + -> absl::Status { return SplitInputTask(input_task, open_batch_remaining_slot, max_batch_size, output_tasks); }; @@ -626,7 +628,7 @@ BatchResourceBase::GetAdaptiveBatcherQueueOptions( return batcher_queue_options; } -/*static*/ Status BatchResourceBase::ValidateBatch(const BatchT& batch) { +/*static*/ absl::Status BatchResourceBase::ValidateBatch(const BatchT& batch) { for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) { const BatchResourceBase::BatchTask& task = batch.task(task_idx); @@ -665,7 +667,7 @@ int BatchResourceBase::RoundToLowestAllowedBatchSize( batcher_queue_options_.disable_padding); } -Status BatchResourceBase::ConcatInputTensors( +absl::Status BatchResourceBase::ConcatInputTensors( const BatchT& batch, const std::vector>& unbatched_tasks, OpKernelContext* context, std::vector* concatenated_tensors) const { @@ -751,7 +753,7 @@ Status BatchResourceBase::ConcatInputTensors( } Tensor concatenated_tensor; - Status concat_status = + absl::Status concat_status = Concat(context, to_concatenate, &concatenated_tensor); TF_RETURN_IF_ERROR(concat_status); concatenated_tensors->push_back(concatenated_tensor); @@ -759,7 +761,7 @@ Status BatchResourceBase::ConcatInputTensors( return absl::OkStatus(); } -/*static*/ Status BatchResourceBase::SplitInputTask( +/*static*/ absl::Status BatchResourceBase::SplitInputTask( std::unique_ptr* input_task_ptr, int open_batch_remaining_slot, int max_batch_size, std::vector>* output_tasks) { BatchTask& input_task = *(*input_task_ptr); @@ -833,8 +835,8 @@ Status BatchResourceBase::ConcatInputTensors( // TODO(b/154140947): // Figure out the optimal implementation of Split, by using // 'Tensor::Slice' and eliminating unnecessary memcpy as much as possible. - const Status split_status = Split(input_task.context, input_tensor, - output_task_sizes, &split_tensors); + const absl::Status split_status = Split(input_task.context, input_tensor, + output_task_sizes, &split_tensors); if (!split_status.ok()) { return errors::Internal( "When splitting input, Tensor split operation failed: ", @@ -856,7 +858,7 @@ Status BatchResourceBase::ConcatInputTensors( return absl::OkStatus(); } -Status BatchResourceBase::SplitOutputTensors( +absl::Status BatchResourceBase::SplitOutputTensors( const std::vector& combined_outputs, BatchT* batch, std::vector>& unbatched_tasks) const { DCHECK_GE(batch->num_tasks(), 1); @@ -899,16 +901,21 @@ Status BatchResourceBase::SplitOutputTensors( return errors::FailedPrecondition( "Batched output tensor has 0 dimensions"); } - if (output_tensor.shape().dim_size(0) != + int64_t zeroth_dim_output_tensor_size = output_tensor.shape().dim_size(0); + if (zeroth_dim_output_tensor_size != static_cast(batch->size() + unbatched_tasks_size + padding_size)) { return errors::FailedPrecondition( "Batched output tensor's 0th dimension does not equal the sum of " - "the 0th dimension sizes of the input tensors"); + "the 0th dimension sizes of the input tensors. " + "0th dimension size: ", + zeroth_dim_output_tensor_size, "; batch size: ", batch->size(), + "; unbatched tasks size: ", unbatched_tasks_size, + "; padding size: ", padding_size); } std::vector split_tensor; - const Status split_status = tensor::Split( + const absl::Status split_status = tensor::Split( output_tensor, task_sizes_plus_optional_padding, &split_tensor); DCHECK(split_status.ok()) << split_status; if (!split_status.ok()) { @@ -944,8 +951,8 @@ Status BatchResourceBase::SplitOutputTensors( return absl::OkStatus(); } -void BatchResourceBase::CleanUpFunctionHelper(BatchTask& task, - const Status& status) const { +void BatchResourceBase::CleanUpFunctionHelper( + BatchTask& task, const absl::Status& status) const { WithContext wc(task.propagated_context); if (!status.ok()) { if (!absl::StrContains(status.message(), @@ -985,10 +992,10 @@ void BatchResourceBase::ProcessFuncBatch( // Regardless of the outcome, we need to propagate the status to the // individual tasks and signal that they are done. We use MakeCleanup() to // ensure that this happens no matter how we exit the method below. - Status status; + absl::Status status; bool cleanup_done = false; int64_t processed_size = batch->size(); - auto cleanup_fn = [&](const Status& status) { + auto cleanup_fn = [&](const absl::Status& status) { if (cleanup_done) { return; } @@ -1046,8 +1053,8 @@ void BatchResourceBase::ProcessFuncBatch( // library runtime will handle it now. finally.release(); ProcessFuncBatchImpl( - last_task, args, &combined_outputs, [&](const Status& run_status) { - Status final_status; + last_task, args, &combined_outputs, [&](const absl::Status& run_status) { + absl::Status final_status; auto run_finally = gtl::MakeCleanup([&]() { // We do the cleanup here as an optimization, so that // it runs in the underlying TF inter-op threadpool. @@ -1106,7 +1113,7 @@ void BatchResourceBase::ProcessBatch(std::unique_ptr batch) const { // All tasks should have the same number of input edges. const int num_input_edges = batch->task(0).inputs.size(); std::vector concatenated_tensors; - const Status concat_status = + const absl::Status concat_status = ConcatInputTensors(*batch, {}, last_task_context, &concatenated_tensors); processed_size = RoundToLowestAllowedBatchSize(batch->size()); OP_REQUIRES_OK_ASYNC(last_task_context, concat_status, last_task_callback); @@ -1158,9 +1165,8 @@ void BatchResourceBase::ProcessBatch(std::unique_ptr batch) const { } } -/*static*/ Status BatchResourceBase::EmitIndexTensor(OpKernelContext* context, - const BatchT& batch, - int output_index) { +/*static*/ absl::Status BatchResourceBase::EmitIndexTensor( + OpKernelContext* context, const BatchT& batch, int output_index) { const TensorShape index_shape({batch.num_tasks(), 3}); Tensor* index = nullptr; TF_RETURN_IF_ERROR( @@ -1191,10 +1197,9 @@ void BatchResourceBase::ProcessBatchCallBack( } } -Status BatchResourceBase::LookupOrCreateBatcherQueue(const string& queue_name, - const string& model_name, - const string& op_name, - BatcherQueueT** queue) { +absl::Status BatchResourceBase::LookupOrCreateBatcherQueue( + const string& queue_name, const string& model_name, const string& op_name, + BatcherQueueT** queue) { mutex_lock l(batcher_queues_mu_); auto it = batcher_queues_.find(queue_name); diff --git a/tensorflow/core/kernels/batching_util/batch_scheduler.h b/tensorflow/core/kernels/batching_util/batch_scheduler.h index 1c50c552d6fa66..a2731da30d0cb3 100644 --- a/tensorflow/core/kernels/batching_util/batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/batch_scheduler.h @@ -361,7 +361,7 @@ class BatchScheduler { // substantial amount of time. If the method returns Status::OK, the task is // processed asynchronously, and any errors that occur during the processing // of the batch that includes the task can be reported to 'task'. - virtual Status Schedule(std::unique_ptr* task) = 0; + virtual absl::Status Schedule(std::unique_ptr* task) = 0; // Returns the number of tasks that have been scheduled (i.e. accepted by // Schedule()), but have yet to be handed to a thread for execution as part of diff --git a/tensorflow/core/kernels/batching_util/concat_split_util.h b/tensorflow/core/kernels/batching_util/concat_split_util.h index adf6363979e7b7..b5354be35c70a9 100644 --- a/tensorflow/core/kernels/batching_util/concat_split_util.h +++ b/tensorflow/core/kernels/batching_util/concat_split_util.h @@ -35,8 +35,8 @@ typedef Eigen::GpuDevice GPUDevice; // 'output' using 'context' for the allocation to ensure proper device // placement. template -Status Concat(OpKernelContext* context, const absl::Span inputs, - Tensor* output) { +absl::Status Concat(OpKernelContext* context, + const absl::Span inputs, Tensor* output) { const int input_dims = inputs[0].dims(); const TensorShape& input_shape = inputs[0].shape(); @@ -91,10 +91,11 @@ Status Concat(OpKernelContext* context, const absl::Span inputs, } // Same as 'Concat' above, but handles Tensor dtype deduction automatically. -inline Status Concat(OpKernelContext* context, - const absl::Span inputs, Tensor* output) { +inline absl::Status Concat(OpKernelContext* context, + const absl::Span inputs, + Tensor* output) { const DataType type = inputs[0].dtype(); - Status concat_status; + absl::Status concat_status; switch (type) { #define CASE(type) \ case DataTypeToEnum::value: \ @@ -117,9 +118,9 @@ inline Status Concat(OpKernelContext* context, // Handles special cases that are cheap. Sets 'done==true' iff it found an // applicable special case and wrote to the outputs. Otherwise acts as a no-op. template -Status SplitEasyCases(OpKernelContext* context, const Tensor& input, - const absl::Span sizes, - std::vector* outputs, bool* done) { +absl::Status SplitEasyCases(OpKernelContext* context, const Tensor& input, + const absl::Span sizes, + std::vector* outputs, bool* done) { *done = false; int64_t total_size = 0; @@ -154,9 +155,9 @@ Status SplitEasyCases(OpKernelContext* context, const Tensor& input, // Handles the general case, on CPU. template -Status SplitCPU(OpKernelContext* context, const Tensor& input, - const absl::Span sizes, - std::vector* outputs) { +absl::Status SplitCPU(OpKernelContext* context, const Tensor& input, + const absl::Span sizes, + std::vector* outputs) { int64_t suffix_dim_size = 1; for (int i = 1; i < input.shape().dims(); ++i) { suffix_dim_size *= input.shape().dim_size(i); @@ -208,9 +209,9 @@ Status SplitGPU(OpKernelContext* context, const Tensor& input, // The outer function that dispatches to the various Split*() functions above. template -Status Split(OpKernelContext* context, const Tensor& input, - const absl::Span sizes, - std::vector* outputs) { +absl::Status Split(OpKernelContext* context, const Tensor& input, + const absl::Span sizes, + std::vector* outputs) { bool easy_cases_done; TF_RETURN_IF_ERROR( SplitEasyCases(context, input, sizes, outputs, &easy_cases_done)); @@ -227,11 +228,11 @@ Status Split(OpKernelContext* context, const Tensor& input, } // Same as 'Split' above, but handles Tensor dtype automatically. -inline Status Split(OpKernelContext* context, const Tensor& input, - const absl::Span sizes, - std::vector* outputs) { +inline absl::Status Split(OpKernelContext* context, const Tensor& input, + const absl::Span sizes, + std::vector* outputs) { const DataType type = input.dtype(); - Status split_status; + absl::Status split_status; switch (type) { #define CASE(type) \ case DataTypeToEnum::value: \ diff --git a/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h b/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h index 7340ace6317603..a7285077a5e599 100644 --- a/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h @@ -99,7 +99,7 @@ class SerialDeviceBatchScheduler : public std::enable_shared_from_this< // Ownership is shared between the caller of Create() and any queues created // via AddQueue(). - static Status Create( + static absl::Status Create( const Options& options, std::shared_ptr>* scheduler); @@ -113,9 +113,9 @@ class SerialDeviceBatchScheduler : public std::enable_shared_from_this< using BatchProcessor = std::function>)>; // Adds queue (and its callback) to be managed by this scheduler. - Status AddQueue(const QueueOptions& options, - BatchProcessor process_batch_callback, - std::unique_ptr>* queue); + absl::Status AddQueue(const QueueOptions& options, + BatchProcessor process_batch_callback, + std::unique_ptr>* queue); double in_flight_batches_limit() { mutex_lock l(mu_); @@ -212,7 +212,7 @@ class SDBSQueue : public BatchScheduler { // Adds task to current batch. Fails if the task size is larger than the batch // size or if the current batch is full and this queue's number of outstanding // batches is at its maximum. - Status Schedule(std::unique_ptr* task) override; + absl::Status Schedule(std::unique_ptr* task) override; // Number of tasks waiting to be scheduled. size_t NumEnqueuedTasks() const override; @@ -262,7 +262,7 @@ class SDBSBatch : public Batch { // ---------------- SerialDeviceBatchScheduler ---------------- template -Status SerialDeviceBatchScheduler::Create( +absl::Status SerialDeviceBatchScheduler::Create( const Options& options, std::shared_ptr>* scheduler) { if (options.num_batch_threads < 1) { @@ -332,7 +332,7 @@ SerialDeviceBatchScheduler::~SerialDeviceBatchScheduler() { } template -Status SerialDeviceBatchScheduler::AddQueue( +absl::Status SerialDeviceBatchScheduler::AddQueue( const QueueOptions& options, BatchProcessor process_batch_callback, std::unique_ptr>* queue) { if (options.max_batch_size <= 0) { @@ -487,7 +487,7 @@ SDBSQueue::~SDBSQueue() { } template -Status SDBSQueue::Schedule(std::unique_ptr* task) { +absl::Status SDBSQueue::Schedule(std::unique_ptr* task) { SDBSBatch* new_batch = nullptr; size_t size = (*task)->size(); if (size > options_.max_batch_size) { diff --git a/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler_test.cc b/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler_test.cc index 4433f588357237..502813be36d177 100644 --- a/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler_test.cc +++ b/tensorflow/core/kernels/batching_util/serial_device_batch_scheduler_test.cc @@ -42,9 +42,10 @@ class FakeTask : public BatchTask { // Creates a FakeTask of size 'task_size', and calls 'scheduler->Schedule()' on // that task. Returns the resulting status. -Status ScheduleTask(size_t task_size, BatchScheduler* scheduler) { +absl::Status ScheduleTask(size_t task_size, + BatchScheduler* scheduler) { std::unique_ptr task(new FakeTask(task_size)); - Status status = scheduler->Schedule(&task); + absl::Status status = scheduler->Schedule(&task); // Schedule() should have consumed 'task' iff it returned Status::OK. CHECK_EQ(status.ok(), task == nullptr); return status; diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h index 10f0656c98cb84..9aee0efa25abbd 100644 --- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h @@ -142,7 +142,7 @@ class SharedBatchScheduler }; // Ownership is shared between the caller of Create() and any queues created // via AddQueue(). - static Status Create( + static absl::Status Create( const Options& options, std::shared_ptr>* scheduler); @@ -205,9 +205,10 @@ class SharedBatchScheduler // NOTE: // Instantiations of `TaskType` may vary, so it's up to caller to define // how (e.g., which members to access) to split input tasks. - std::function* input_task, - int first_output_task_size, int input_batch_size_limit, - std::vector>* output_tasks)> + std::function* input_task, int first_output_task_size, + int input_batch_size_limit, + std::vector>* output_tasks)> split_input_task_func; // The maximum size of each enqueued batch (i.e., in @@ -268,9 +269,9 @@ class SharedBatchScheduler MixedPriorityBatchingPolicy::kLowPriorityPaddingWithMaxBatchSize; }; // This method is marked virtual for testing purposes only. - virtual Status AddQueue(const QueueOptions& options, - ProcessBatchCallback process_batch_callback, - std::unique_ptr>* queue); + virtual absl::Status AddQueue( + const QueueOptions& options, ProcessBatchCallback process_batch_callback, + std::unique_ptr>* queue); protected: explicit SharedBatchScheduler(const Options& options); @@ -287,7 +288,7 @@ class SharedBatchScheduler void ThreadLogic(); // Called by `AddQueue`. - Status AddQueueAfterRewritingOptions( + absl::Status AddQueueAfterRewritingOptions( const QueueOptions& options, ProcessBatchCallback process_batch_callback, std::unique_ptr>* queue); @@ -358,7 +359,7 @@ class Queue { ProcessBatchCallbackWithPaddingTasks>; using SchedulableBatchCallback = std::function; - using SplitInputTaskIntoSubtasksCallback = std::function* input_task, int open_batch_remaining_slot, int max_execution_batch_size, std::vector>* output_tasks)>; @@ -371,7 +372,7 @@ class Queue { // Submits a task to the queue, with the same semantics as // BatchScheduler::Schedule(). - Status Schedule(std::unique_ptr* task); + absl::Status Schedule(std::unique_ptr* task); // Returns the number of enqueued tasks, with the same semantics as // BatchScheduler::NumEnqueuedTasks(). @@ -439,7 +440,7 @@ class Queue { // Implementation of Schedule above. Enqueues `task` as it // is or split it inline (eagerly) to form batches to be processed by // `Queue::ProcessBatch` - Status ScheduleWithoutOrEagerSplitImpl(std::unique_ptr* task) + absl::Status ScheduleWithoutOrEagerSplitImpl(std::unique_ptr* task) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Closes the open batch residing at the back of std::deque, and inserts a @@ -447,7 +448,7 @@ class Queue { void StartNewBatch() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Split `input task` into `output_tasks` according to 'task_sizes'. - Status SplitInputBatchIntoSubtasks( + absl::Status SplitInputBatchIntoSubtasks( std::unique_ptr* input_task, std::vector>* output_tasks) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); @@ -469,7 +470,7 @@ class Queue { // Returns an error if queue doesn't have capacity for this task. // // `task` must outlive this method. - Status ValidateBatchTaskQueueCapacity(TaskType* task) const + absl::Status ValidateBatchTaskQueueCapacity(TaskType* task) const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // Returns an error if the low priority task queue doesn't have capacity for @@ -478,7 +479,7 @@ class Queue { // single task does not it exceed input batch size limit and the total size of // the tasks in the queue does not exceed the max batch size * max enqueued // batch sizes. - Status ValidateLowPriorityTaskQueueCapacity(const TaskType& task) const + absl::Status ValidateLowPriorityTaskQueueCapacity(const TaskType& task) const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); // The task size of the last batch in the queue. @@ -536,13 +537,6 @@ class Queue { // TaskQueue low_priority_tasks_ TF_GUARDED_BY(mu_); - // The enqueued batches for low priority input. - // Each element corresponds to a task to be dequeued and processed by - // `Queue::ProcessBatch`. - // - std::deque>> low_priority_batches_ - TF_GUARDED_BY(mu_); - // The enqueued batches for high priority input. // Each element corresponds to a task to be dequeued and processed by // `Queue::ProcessBatch`. @@ -588,7 +582,7 @@ class QueueHandle : public BatchScheduler { Queue* queue); ~QueueHandle() override; - Status Schedule(std::unique_ptr* task) override; + absl::Status Schedule(std::unique_ptr* task) override; size_t NumEnqueuedTasks() const override; size_t SchedulingCapacity() const override; @@ -609,7 +603,7 @@ class QueueHandle : public BatchScheduler { } // namespace internal template -Status SharedBatchScheduler::Create( +absl::Status SharedBatchScheduler::Create( const Options& options, std::shared_ptr>* scheduler) { if (options.num_batch_threads < 1) { @@ -640,7 +634,7 @@ SharedBatchScheduler::~SharedBatchScheduler() { } template -Status SharedBatchScheduler::AddQueue( +absl::Status SharedBatchScheduler::AddQueue( const QueueOptions& options, ProcessBatchCallback process_batch_callback, std::unique_ptr>* queue) { QueueOptions rewrite_options = options; @@ -659,7 +653,7 @@ Status SharedBatchScheduler::AddQueue( } template -Status SharedBatchScheduler::AddQueueAfterRewritingOptions( +absl::Status SharedBatchScheduler::AddQueueAfterRewritingOptions( const QueueOptions& options, ProcessBatchCallback process_batch_callback, std::unique_ptr>* queue) { if (options.input_batch_size_limit == 0) { @@ -856,7 +850,7 @@ bool Queue::IsLowPriorityTask(std::unique_ptr* task) { } template -Status Queue::ScheduleWithoutOrEagerSplitImpl( +absl::Status Queue::ScheduleWithoutOrEagerSplitImpl( std::unique_ptr* task) { // TODO(b/161857471): // Add test coverage when when concurrent incoming batches arrives and @@ -902,7 +896,7 @@ Status Queue::ScheduleWithoutOrEagerSplitImpl( } template -Status Queue::Schedule(std::unique_ptr* task) { +absl::Status Queue::Schedule(std::unique_ptr* task) { const bool large_batch_splitting = options_.enable_large_batch_splitting; tsl::profiler::TraceMe trace_me([task, large_batch_splitting] { return profiler::TraceMeEncode( @@ -974,7 +968,8 @@ size_t Queue::SchedulingCapacityInternal() const { } template -Status Queue::ValidateBatchTaskQueueCapacity(TaskType* task) const { +absl::Status Queue::ValidateBatchTaskQueueCapacity( + TaskType* task) const { // Check if the task size is larger than the batch size limit, regardless of // the batch capacity. if (task->size() > options_.input_batch_size_limit) { @@ -1020,7 +1015,7 @@ Status Queue::ValidateBatchTaskQueueCapacity(TaskType* task) const { } template -Status Queue::ValidateLowPriorityTaskQueueCapacity( +absl::Status Queue::ValidateLowPriorityTaskQueueCapacity( const TaskType& task) const { // Unlike the high priority batch capacity validation where having only // input_batch_size_limit without max_execution_batch_size is allowed, it @@ -1237,7 +1232,7 @@ void Queue::StartNewBatch() { } template -Status Queue::SplitInputBatchIntoSubtasks( +absl::Status Queue::SplitInputBatchIntoSubtasks( std::unique_ptr* input_task, std::vector>* output_tasks) { const int open_batch_remaining_slot = @@ -1328,7 +1323,7 @@ QueueHandle::~QueueHandle() { } template -Status QueueHandle::Schedule(std::unique_ptr* task) { +absl::Status QueueHandle::Schedule(std::unique_ptr* task) { return queue_->Schedule(task); } diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc b/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc index 64f28209c6a98f..fc6abf65ad2248 100644 --- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc +++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc @@ -93,19 +93,19 @@ class FakeTaskWithoutCriticality { using Queue = BatchScheduler; using Scheduler = SharedBatchScheduler; using QueueOptions = Scheduler::QueueOptions; -using SplitFunc = - std::function* input_task, - int first_output_task_size, int input_batch_size_limit, - std::vector>* output_tasks)>; +using SplitFunc = std::function* input_task, int first_output_task_size, + int input_batch_size_limit, + std::vector>* output_tasks)>; // Creates a FakeTask of size 'task_size' and 'criticality', and calls // 'scheduler->Schedule()' on that task. Returns the resulting status. // 'criticality' defaults to kCritical. -Status ScheduleTask(size_t task_size, BatchScheduler* scheduler, - tsl::criticality::Criticality criticality = - tsl::criticality::Criticality::kCritical) { +absl::Status ScheduleTask(size_t task_size, BatchScheduler* scheduler, + tsl::criticality::Criticality criticality = + tsl::criticality::Criticality::kCritical) { std::unique_ptr task(new FakeTask(task_size, criticality)); - Status status = scheduler->Schedule(&task); + absl::Status status = scheduler->Schedule(&task); // Schedule() should have consumed 'task' iff it returned Status::OK. CHECK_EQ(status.ok(), task == nullptr); return status; @@ -114,11 +114,11 @@ Status ScheduleTask(size_t task_size, BatchScheduler* scheduler, // Helper function similar to the function above. Creates a FakeTask of size // 'task_size' and calls 'scheduler->Schedule()' on that task. Returns the // resulting status. -Status ScheduleTaskWithoutCriticality( +absl::Status ScheduleTaskWithoutCriticality( size_t task_size, BatchScheduler* scheduler) { std::unique_ptr task( new FakeTaskWithoutCriticality(task_size)); - Status status = scheduler->Schedule(&task); + absl::Status status = scheduler->Schedule(&task); // Schedule() should have consumed 'task' iff it returned Status::OK. CHECK_EQ(status.ok(), task == nullptr); return status; @@ -208,7 +208,8 @@ class SharedBatchSchedulerTestBase { return [](std::unique_ptr* input_task, int open_batch_remaining_slot, int max_batch_size, - std::vector>* output_tasks) -> Status { + std::vector>* output_tasks) + -> absl::Status { std::unique_ptr owned_input_task = std::move(*input_task); const int input_task_size = owned_input_task->size(); @@ -453,7 +454,7 @@ TEST_P( [](std::unique_ptr* input_task, int open_batch_remaining_slot, int max_batch_size, std::vector>* - output_tasks) -> Status { + output_tasks) -> absl::Status { std::unique_ptr owned_input_task = std::move(*input_task); const int input_task_size = owned_input_task->size(); @@ -901,7 +902,7 @@ TEST_P(SharedBatchSchedulerTest, OneFullQueueDoesntBlockOtherQueues) { // Clog up queue 0. TF_ASSERT_OK(ScheduleTask(1, queue_0.get())); queue_0_processing.WaitForNotification(); - Status queue_0_status; + absl::Status queue_0_status; do { queue_0_status = ScheduleTask(1, queue_0.get()); } while (queue_0_status.ok()); @@ -1650,7 +1651,7 @@ void CreateQueues() { auto split_func_for_size_one_task = [](std::unique_ptr* input_task, int open_batch_remaining_slot, int max_batch_size, - std::vector>* output_tasks) -> Status { + std::vector>* output_tasks) -> absl::Status { output_tasks->push_back(std::move(*input_task)); Notification notify; diff --git a/tensorflow/core/kernels/batching_util/threadsafe_status.cc b/tensorflow/core/kernels/batching_util/threadsafe_status.cc index fa5cda7161b4e0..fc4bd4c6c8e37e 100644 --- a/tensorflow/core/kernels/batching_util/threadsafe_status.cc +++ b/tensorflow/core/kernels/batching_util/threadsafe_status.cc @@ -21,17 +21,17 @@ limitations under the License. #include "tensorflow/core/platform/mutex.h" namespace tensorflow { -const Status& ThreadSafeStatus::status() const& { +const absl::Status& ThreadSafeStatus::status() const& { tf_shared_lock lock(mutex_); return status_; } -Status ThreadSafeStatus::status() && { +absl::Status ThreadSafeStatus::status() && { tf_shared_lock lock(mutex_); return std::move(status_); } -void ThreadSafeStatus::Update(const Status& new_status) { +void ThreadSafeStatus::Update(const absl::Status& new_status) { if (new_status.ok()) { return; } @@ -40,12 +40,12 @@ void ThreadSafeStatus::Update(const Status& new_status) { status_.Update(new_status); } -void ThreadSafeStatus::Update(Status&& new_status) { +void ThreadSafeStatus::Update(absl::Status&& new_status) { if (new_status.ok()) { return; } mutex_lock lock(mutex_); - status_.Update(std::forward(new_status)); + status_.Update(std::forward(new_status)); } } // namespace tensorflow diff --git a/tensorflow/core/kernels/batching_util/threadsafe_status.h b/tensorflow/core/kernels/batching_util/threadsafe_status.h index c14a8a907147bd..68e94f705f0d47 100644 --- a/tensorflow/core/kernels/batching_util/threadsafe_status.h +++ b/tensorflow/core/kernels/batching_util/threadsafe_status.h @@ -40,17 +40,17 @@ namespace tensorflow { // When updated in a multi-threading setup, only the first error is retained. class ThreadSafeStatus { public: - const Status& status() const& TF_LOCKS_EXCLUDED(mutex_); - Status status() && TF_LOCKS_EXCLUDED(mutex_); + const absl::Status& status() const& TF_LOCKS_EXCLUDED(mutex_); + absl::Status status() && TF_LOCKS_EXCLUDED(mutex_); // Retains the first error status: replaces the current status with // `new_status` if `new_status` is not OK and the previous status is OK. - void Update(const Status& new_status) TF_LOCKS_EXCLUDED(mutex_); - void Update(Status&& new_status) TF_LOCKS_EXCLUDED(mutex_); + void Update(const absl::Status& new_status) TF_LOCKS_EXCLUDED(mutex_); + void Update(absl::Status&& new_status) TF_LOCKS_EXCLUDED(mutex_); private: mutable mutex mutex_; - Status status_ TF_GUARDED_BY(mutex_); + absl::Status status_ TF_GUARDED_BY(mutex_); }; } // namespace tensorflow diff --git a/tensorflow/core/kernels/bincount_op.cc b/tensorflow/core/kernels/bincount_op.cc index d6f8d3dbad9ed0..4c01847dea65d1 100644 --- a/tensorflow/core/kernels/bincount_op.cc +++ b/tensorflow/core/kernels/bincount_op.cc @@ -42,11 +42,11 @@ namespace functor { template struct BincountFunctor { - static Status Compute(OpKernelContext* context, - const typename TTypes::ConstTensor& arr, - const typename TTypes::ConstTensor& weights, - typename TTypes::Tensor& output, - const Tidx num_bins) { + static absl::Status Compute(OpKernelContext* context, + const typename TTypes::ConstTensor& arr, + const typename TTypes::ConstTensor& weights, + typename TTypes::Tensor& output, + const Tidx num_bins) { Tensor all_nonneg_t; TF_RETURN_IF_ERROR(context->allocate_temp( DT_BOOL, TensorShape({}), &all_nonneg_t, AllocatorAttributes())); @@ -87,11 +87,11 @@ struct BincountFunctor { template struct BincountFunctor { - static Status Compute(OpKernelContext* context, - const typename TTypes::ConstTensor& arr, - const typename TTypes::ConstTensor& weights, - typename TTypes::Tensor& output, - const Tidx num_bins) { + static absl::Status Compute(OpKernelContext* context, + const typename TTypes::ConstTensor& arr, + const typename TTypes::ConstTensor& weights, + typename TTypes::Tensor& output, + const Tidx num_bins) { Tensor all_nonneg_t; TF_RETURN_IF_ERROR(context->allocate_temp( DT_BOOL, TensorShape({}), &all_nonneg_t, AllocatorAttributes())); @@ -170,11 +170,11 @@ struct BincountFunctor { template struct BincountReduceFunctor { - static Status Compute(OpKernelContext* context, - const typename TTypes::ConstTensor& in, - const typename TTypes::ConstTensor& weights, - typename TTypes::Tensor& out, - const Tidx num_bins) { + static absl::Status Compute(OpKernelContext* context, + const typename TTypes::ConstTensor& in, + const typename TTypes::ConstTensor& weights, + typename TTypes::Tensor& out, + const Tidx num_bins) { std::atomic err_neg_val = 0; const int num_rows = out.dimension(0); const int num_cols = in.dimension(1); @@ -325,7 +325,7 @@ class DenseBincountOp : public OpKernel { const int64_t num_rows = data.dim_size(0); auto weight_matrix = (weights.NumElements() == 0) - ? weights.shaped(gtl::InlinedVector(2, 0)) + ? weights.shaped(absl::InlinedVector(2, 0)) : weights.matrix(); OP_REQUIRES_OK( ctx, ctx->allocate_output(0, TensorShape({num_rows, size}), &out_t)); diff --git a/tensorflow/core/kernels/bincount_op.h b/tensorflow/core/kernels/bincount_op.h index 56ad0dbb7ab242..4884761788f74e 100644 --- a/tensorflow/core/kernels/bincount_op.h +++ b/tensorflow/core/kernels/bincount_op.h @@ -28,20 +28,20 @@ namespace functor { template struct BincountFunctor { - static Status Compute(OpKernelContext* context, - const typename TTypes::ConstTensor& arr, - const typename TTypes::ConstTensor& weights, - typename TTypes::Tensor& output, - const Tidx num_bins); + static absl::Status Compute(OpKernelContext* context, + const typename TTypes::ConstTensor& arr, + const typename TTypes::ConstTensor& weights, + typename TTypes::Tensor& output, + const Tidx num_bins); }; template struct BincountReduceFunctor { - static Status Compute(OpKernelContext* context, - const typename TTypes::ConstTensor& in, - const typename TTypes::ConstTensor& weights, - typename TTypes::Tensor& out, - const Tidx num_bins); + static absl::Status Compute(OpKernelContext* context, + const typename TTypes::ConstTensor& in, + const typename TTypes::ConstTensor& weights, + typename TTypes::Tensor& out, + const Tidx num_bins); }; } // end namespace functor diff --git a/tensorflow/core/kernels/bucketize_op.cc b/tensorflow/core/kernels/bucketize_op.cc index 179a930da5790c..d1bbfb24f28623 100644 --- a/tensorflow/core/kernels/bucketize_op.cc +++ b/tensorflow/core/kernels/bucketize_op.cc @@ -33,10 +33,10 @@ namespace functor { template struct BucketizeFunctor { // PRECONDITION: boundaries_vector must be sorted. - static Status Compute(OpKernelContext* context, - const typename TTypes::ConstTensor& input, - const std::vector& boundaries_vector, - typename TTypes::Tensor& output) { + static absl::Status Compute(OpKernelContext* context, + const typename TTypes::ConstTensor& input, + const std::vector& boundaries_vector, + typename TTypes::Tensor& output) { const int N = input.size(); for (int i = 0; i < N; i++) { auto first_bigger_it = std::upper_bound( diff --git a/tensorflow/core/kernels/bucketize_op.h b/tensorflow/core/kernels/bucketize_op.h index d26525fdc88bcb..9fb59c77ef7f29 100644 --- a/tensorflow/core/kernels/bucketize_op.h +++ b/tensorflow/core/kernels/bucketize_op.h @@ -29,10 +29,10 @@ namespace functor { template struct BucketizeFunctor { - static Status Compute(OpKernelContext* context, - const typename TTypes::ConstTensor& input, - const std::vector& boundaries_vector, - typename TTypes::Tensor& output); + static absl::Status Compute(OpKernelContext* context, + const typename TTypes::ConstTensor& input, + const std::vector& boundaries_vector, + typename TTypes::Tensor& output); }; } // namespace functor diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc index 79d4462489c583..0a4e011815b80d 100644 --- a/tensorflow/core/kernels/cast_op.cc +++ b/tensorflow/core/kernels/cast_op.cc @@ -112,7 +112,7 @@ void CastOpBase::Compute(OpKernelContext* ctx) { } } -Status CastOpBase::Unimplemented() { +absl::Status CastOpBase::Unimplemented() { return errors::Unimplemented("Cast ", DataTypeString(external_src_dtype_), " to ", DataTypeString(external_dst_dtype_), " is not supported"); @@ -122,7 +122,7 @@ CpuCastOp::CpuCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) { OP_REQUIRES_OK(ctx, Prepare()); } -Status CpuCastOp::Prepare() { +absl::Status CpuCastOp::Prepare() { if (external_src_dtype_ == external_dst_dtype_) { work_ = nullptr; // Identity return absl::OkStatus(); diff --git a/tensorflow/core/kernels/cast_op.h b/tensorflow/core/kernels/cast_op.h index c0b6eaf084089a..0c9556516a64d7 100644 --- a/tensorflow/core/kernels/cast_op.h +++ b/tensorflow/core/kernels/cast_op.h @@ -142,7 +142,7 @@ class CastOpBase : public OpKernel { DataType external_dst_dtype_; bool use_truncation_; CastFunctorType work_ = nullptr; - Status Unimplemented(); + absl::Status Unimplemented(); CastOpBase(const CastOpBase&) = delete; void operator=(const CastOpBase&) = delete; @@ -154,7 +154,7 @@ class CpuCastOp : public CastOpBase { explicit CpuCastOp(OpKernelConstruction* ctx); private: - Status Prepare(); + absl::Status Prepare(); }; namespace functor { diff --git a/tensorflow/core/kernels/check_numerics_op.cc b/tensorflow/core/kernels/check_numerics_op.cc index 6c6f0d593d22b4..d78e1422792ed5 100644 --- a/tensorflow/core/kernels/check_numerics_op.cc +++ b/tensorflow/core/kernels/check_numerics_op.cc @@ -33,9 +33,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#if GOOGLE_CUDA -#include "xla/stream_executor/gpu/scoped_activate_context.h" -#elif TENSORFLOW_USE_ROCM +#if TENSORFLOW_USE_ROCM #include "tensorflow/core/platform/rocm.h" #endif namespace tensorflow { @@ -261,14 +259,13 @@ class CheckNumericsOp : public AsyncOpKernel { auto check_cb = [this, stream, abnormal_detected_ref, abnormal_detected_host, context, done]() { { -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - se::gpu::ScopedActivateContext scoped_activation{stream->parent()}; -#endif + std::unique_ptr scoped_activation = + stream->parent()->Activate(); TTypes::Vec abnormal_detected_host_flat = abnormal_detected_host.flat(); abnormal_detected_ref.Unref(); checkForAnomalies(context, abnormal_detected_host_flat); - } // Release ScopedActivateContext to prevent deadlock when done + } // Release ActivateContext to prevent deadlock when done // inlines another Op kernel, which may assume the original cuda // Context. done(); diff --git a/tensorflow/core/kernels/checkpoint_callback_manager.cc b/tensorflow/core/kernels/checkpoint_callback_manager.cc index 26ae9ec0fb3cae..6140681bbe8145 100644 --- a/tensorflow/core/kernels/checkpoint_callback_manager.cc +++ b/tensorflow/core/kernels/checkpoint_callback_manager.cc @@ -66,7 +66,7 @@ void TriggerSaveCallbackIfFileNotExist(absl::string_view checkpoint_id, return; } - Status write_status = + absl::Status write_status = WriteStringToFile(Env::Default(), file_path, *save_content); if (!write_status.ok()) { LOG(WARNING) << write_status; @@ -86,7 +86,8 @@ void TriggerRestoreCallbackIfFileExists(absl::string_view checkpoint_id, return; } std::string payload; - Status read_status = ReadFileToString(Env::Default(), file_path, &payload); + absl::Status read_status = + ReadFileToString(Env::Default(), file_path, &payload); if (!read_status.ok()) { LOG(WARNING) << "Failed to read: " << read_status; return; @@ -94,7 +95,7 @@ void TriggerRestoreCallbackIfFileExists(absl::string_view checkpoint_id, LOG(INFO) << "Calling a restore callback: file_extension = " << file_extension << ", checkpoint_id = " << checkpoint_id; - Status callback_status = callback(checkpoint_id, payload); + absl::Status callback_status = callback(checkpoint_id, payload); if (!callback_status.ok()) { LOG(WARNING) << callback_status; } @@ -140,7 +141,7 @@ CheckpointCallbackManager::GetCheckpointIdAndPathFromPrefix( absl::StrCat("Failed to find a checkpoint id. prefix = ", prefix)); } -Status CheckpointCallbackManager::RegisterSaveCallback( +absl::Status CheckpointCallbackManager::RegisterSaveCallback( absl::string_view file_extension, SaveCallback callback) { SaveCallback lazy_callback = nullptr; std::string checkpoint_id; @@ -174,7 +175,7 @@ bool CheckpointCallbackManager::DoesSaveCallbackExist( return save_callbacks_.contains(file_extension); } -Status CheckpointCallbackManager::RegisterRestoreCallback( +absl::Status CheckpointCallbackManager::RegisterRestoreCallback( absl::string_view file_extension, RestoreCallback callback) { RestoreCallback lazy_callback = nullptr; std::string checkpoint_id; diff --git a/tensorflow/core/kernels/checkpoint_callback_manager.h b/tensorflow/core/kernels/checkpoint_callback_manager.h index 61389229d09c10..7e0d9d8fc3a97e 100644 --- a/tensorflow/core/kernels/checkpoint_callback_manager.h +++ b/tensorflow/core/kernels/checkpoint_callback_manager.h @@ -41,7 +41,7 @@ using SaveCallback = // Status restore_callback(absl::string_view checkpoint_id, // absl::string_view content_from_checkpoint); using RestoreCallback = - std::function; + std::function; // A class to save and restore additional information for checkpointing. class CheckpointCallbackManager : public ResourceBase { @@ -67,8 +67,8 @@ class CheckpointCallbackManager : public ResourceBase { // The callback should return a string content needs to be stored // as a part of a checkpoint, and then the content is stored as a file // with the registered the file_extension. - Status RegisterSaveCallback(absl::string_view file_extension, - SaveCallback callback); + absl::Status RegisterSaveCallback(absl::string_view file_extension, + SaveCallback callback); // Checks if a registered save callback exists for an extension. bool DoesSaveCallbackExist(absl::string_view file_extension); @@ -77,8 +77,8 @@ class CheckpointCallbackManager : public ResourceBase { // The passed file_extension is used to generate a file name together with // an identified checkpoint_id. If the file exists, the registered callback // is triggered with the content of the file. - Status RegisterRestoreCallback(absl::string_view file_extension, - RestoreCallback callback); + absl::Status RegisterRestoreCallback(absl::string_view file_extension, + RestoreCallback callback); // Checks if a registered restore callback exists for an extension. bool DoesRestoreCallbackExist(absl::string_view file_extension); diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc index f5a882b792ca4f..f77d6cd010c1ae 100644 --- a/tensorflow/core/kernels/collective_ops.cc +++ b/tensorflow/core/kernels/collective_ops.cc @@ -58,7 +58,7 @@ static std::unique_ptr BuildOpKernel(OpKernelConstruction* c, if (name.empty() || name == "Id") return k; sub_node->set_name(name); sub_node->set_op(name); - Status status; + absl::Status status; k = CreateOpKernel(c->device_type(), c->device(), c->device()->GetAllocator(AllocatorAttributes()), *sub_node, c->graph_def_version(), &status); @@ -128,7 +128,7 @@ class CollectiveOpV1Kernel : public AsyncOpKernel { << col_params_->instance.instance_key; col_exec->CompleteParamsAsync( c->device()->attributes(), col_params_, c->cancellation_manager(), - [this, c, done](const Status& s) { + [this, c, done](const absl::Status& s) { if (s.ok()) { col_params_->instance.impl_details.dependencies = dependencies_; ComputeAsync(c, done); @@ -202,7 +202,8 @@ class CollectiveGatherOpKernel : public CollectiveOpV1Kernel { } if (!CanProceedWithCompute(c, col_exec, done)) return; - auto actual_done = [c, col_params = col_params_, done](const Status& s) { + auto actual_done = [c, col_params = col_params_, + done](const absl::Status& s) { VLOG(1) << "CollectiveGatherOpKernel ExecuteAsync done for collective " << c->op_kernel().name() << " device " << c->device()->name() << " group " << col_params->group.group_key << " instance " @@ -311,7 +312,8 @@ class CollectiveReduceOpKernel : public CollectiveOpV1Kernel { } if (!CanProceedWithCompute(c, col_exec, done)) return; - auto actual_done = [c, col_params = col_params_, done](const Status& s) { + auto actual_done = [c, col_params = col_params_, + done](const absl::Status& s) { VLOG(1) << "CollectiveReduceOpKernel ExecuteAsync done for collective " << c->op_kernel().name() << " device " << c->device()->name() << " group " << col_params->group.group_key << " instance " @@ -391,7 +393,8 @@ class CollectiveBcastSendOpKernel : public CollectiveOpV1Kernel { " does not match shape of input"), done); - auto actual_done = [c, col_params = col_params_, done](const Status& s) { + auto actual_done = [c, col_params = col_params_, + done](const absl::Status& s) { VLOG(1) << "CollectiveBcastSendOpKernel ExecuteAsync done for collective " << c->op_kernel().name() << " device " << c->device()->name() << " group " << col_params->group.group_key << " instance " @@ -462,7 +465,8 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpV1Kernel { } if (!CanProceedWithCompute(c, col_exec, done)) return; - auto actual_done = [c, col_params = col_params_, done](const Status& s) { + auto actual_done = [c, col_params = col_params_, + done](const absl::Status& s) { VLOG(1) << "CollectiveBcastRecvOpKernel ExecuteAsync done for collective " << c->op_kernel().name() << " device " << c->device()->name() << " group " << col_params->group.group_key << " instance_key " @@ -532,10 +536,10 @@ class CollectiveAssignGroupV2OpKernel : public OpKernel { } private: - static Status ComputeGroupKey(const Tensor& group_assignment, - const int32_t device_index, - const int32_t base_key, Tensor* group_size, - Tensor* group_key) { + static absl::Status ComputeGroupKey(const Tensor& group_assignment, + const int32_t device_index, + const int32_t base_key, + Tensor* group_size, Tensor* group_key) { group_size->flat()(0) = group_assignment.dim_size(1); for (int group_id = 0; group_id < group_assignment.dim_size(0); @@ -595,10 +599,12 @@ class CollectiveOpV2Kernel : public AsyncOpKernel { // Fills common parts of CollectiveParams according to the Op, *excluding // output_shape*. Kernels should further work on the CollectiveParams if they // need to set additional fields. - Status FillCollectiveParams(CollectiveParams* col_params, OpKernelContext* c, - CollectiveType collective_type, - const Tensor& group_size, const Tensor& group_key, - const Tensor& instance_key) { + absl::Status FillCollectiveParams(CollectiveParams* col_params, + OpKernelContext* c, + CollectiveType collective_type, + const Tensor& group_size, + const Tensor& group_key, + const Tensor& instance_key) { if (group_size.dims() > 0) { return errors::InvalidArgument( "Unexpected dimensions on input group_size, got ", @@ -704,7 +710,7 @@ class CollectiveOpV2Kernel : public AsyncOpKernel { col_exec->CompleteParamsAsync( c->device()->attributes(), col_params, c->cancellation_manager(), [c, activity_id, xprof_ctx_id, done = std::move(done), col_params, - col_exec](const Status& s) mutable { + col_exec](const absl::Status& s) mutable { tsl::profiler::TraceMeConsumer consumer( [&] { return tsl::profiler::TraceMeEncode( @@ -715,7 +721,8 @@ class CollectiveOpV2Kernel : public AsyncOpKernel { if (s.ok()) { auto actual_done = [c, activity_id, col_params, xprof_ctx_id, - done = std::move(done)](const Status& s) { + done = + std::move(done)](const absl::Status& s) { tsl::profiler::TraceMeConsumer consumer( [&] { return tsl::profiler::TraceMeEncode( @@ -1038,7 +1045,8 @@ class CollectiveInitializeCommunicatorOpKernel : public AsyncOpKernel { device_type_ = c->device_type(); } - Status CheckInputs(Tensor group_size_t, Tensor group_key_t, Tensor rank_t) { + absl::Status CheckInputs(Tensor group_size_t, Tensor group_key_t, + Tensor rank_t) { if (group_size_t.dims() > 0) { return errors::InvalidArgument( "Unexpected dimensions on input group_size. " @@ -1120,7 +1128,7 @@ class CollectiveInitializeCommunicatorOpKernel : public AsyncOpKernel { << group_params->group_key; col_exec->CompleteGroupAsync( c->device()->attributes(), group_params, c->cancellation_manager(), - [c, done = std::move(done), group_params](const Status& s) { + [c, done = std::move(done), group_params](const absl::Status& s) { if (s.ok()) { VLOG(1) << "Collective Group initialization done for device " << c->device()->name() << " group " @@ -1167,10 +1175,10 @@ class CollectiveOpV3Kernel : public AsyncOpKernel { // Fills common parts of CollectiveParams according to the Op, *excluding // output_shape*. Kernels should further work on the CollectiveParams if they // need to set additional fields. - Status FillCollectiveParams(CollectiveParams* col_params, - const Tensor& group_assignment, - CollectiveType collective_type, - CollectiveGroupResource* resource) { + absl::Status FillCollectiveParams(CollectiveParams* col_params, + const Tensor& group_assignment, + CollectiveType collective_type, + CollectiveGroupResource* resource) { int64 group_id; int64 group_size; if (group_assignment.NumElements() == 0) { @@ -1223,10 +1231,10 @@ class CollectiveOpV3Kernel : public AsyncOpKernel { col_exec->CompleteParamsAsync( c->device()->attributes(), col_params, c->cancellation_manager(), [c, done = std::move(done), col_params, - col_exec](const Status& s) mutable { + col_exec](const absl::Status& s) mutable { if (s.ok()) { - auto actual_done = [c, col_params, - done = std::move(done)](const Status& s) { + auto actual_done = [c, col_params, done = std::move(done)]( + const absl::Status& s) { VLOG(1) << "Collective ExecuteAsync done for " << col_params->name << " device " << c->device()->name() << " group " << col_params->group.group_key diff --git a/tensorflow/core/kernels/conditional_accumulator.h b/tensorflow/core/kernels/conditional_accumulator.h index 13343a99b64acd..d2578a55701672 100644 --- a/tensorflow/core/kernels/conditional_accumulator.h +++ b/tensorflow/core/kernels/conditional_accumulator.h @@ -65,7 +65,7 @@ class ConditionalAccumulator functor::SetZeroFunctor set_zero_functor_; - Status ValidateShape(const Tensor* tensor) + absl::Status ValidateShape(const Tensor* tensor) TF_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) { // Must be compatible with accumulated gradient if available if (counter_ > 0) { diff --git a/tensorflow/core/kernels/conditional_accumulator_base.cc b/tensorflow/core/kernels/conditional_accumulator_base.cc index 8a0c73d0bdbca2..6bd50a7e65dcb2 100644 --- a/tensorflow/core/kernels/conditional_accumulator_base.cc +++ b/tensorflow/core/kernels/conditional_accumulator_base.cc @@ -29,7 +29,8 @@ ConditionalAccumulatorBase::ConditionalAccumulatorBase( current_global_step_ = 0; } -Status ConditionalAccumulatorBase::MatchesNodeDef(const NodeDef& node_def) { +absl::Status ConditionalAccumulatorBase::MatchesNodeDef( + const NodeDef& node_def) { // TODO(xinghao@): implement the checks for the node definition return absl::OkStatus(); } @@ -39,7 +40,8 @@ Status ConditionalAccumulatorBase::MatchesNodeDef(const NodeDef& node_def) { * step. Logs warning if the accumulator's time step is already larger than the * provided time step. */ -Status ConditionalAccumulatorBase::SetGlobalStep(int64_t new_global_step) { +absl::Status ConditionalAccumulatorBase::SetGlobalStep( + int64_t new_global_step) { mutex_lock lock(mu_); if (new_global_step < current_global_step_) { LOG(WARNING) << "Attempt to set current_global_step_ to smaller value: " diff --git a/tensorflow/core/kernels/conditional_accumulator_base.h b/tensorflow/core/kernels/conditional_accumulator_base.h index 2e2fa9c441a1cf..683e667e3536cb 100644 --- a/tensorflow/core/kernels/conditional_accumulator_base.h +++ b/tensorflow/core/kernels/conditional_accumulator_base.h @@ -73,9 +73,9 @@ class ConditionalAccumulatorBase : public ResourceBase { // SetGlobalStep is a modifier method for current_global_step. // It returns an InvalidArgument error if the new_global_step is less than // current_global_step. - Status SetGlobalStep(int64_t new_global_step); + absl::Status SetGlobalStep(int64_t new_global_step); - Status MatchesNodeDef(const NodeDef& node_def); + absl::Status MatchesNodeDef(const NodeDef& node_def); protected: // Virtual methods to be implemented by sub-classes for different datatypes. diff --git a/tensorflow/core/kernels/conditional_accumulator_base_op.cc b/tensorflow/core/kernels/conditional_accumulator_base_op.cc index 3b6eb16c7ae02b..cf094c7f3f78e2 100644 --- a/tensorflow/core/kernels/conditional_accumulator_base_op.cc +++ b/tensorflow/core/kernels/conditional_accumulator_base_op.cc @@ -49,7 +49,7 @@ class AccumulatorSetGlobalStepOp new_global_step_tensor->shape().DebugString())); } - Status status = + absl::Status status = accumulator->SetGlobalStep(new_global_step_tensor->scalar()()); if (!status.ok()) ctx->CtxFailureWithWarning(status); } diff --git a/tensorflow/core/kernels/conditional_accumulator_base_op.h b/tensorflow/core/kernels/conditional_accumulator_base_op.h index 73504fbd495327..c0d1c9a6c8e7f5 100644 --- a/tensorflow/core/kernels/conditional_accumulator_base_op.h +++ b/tensorflow/core/kernels/conditional_accumulator_base_op.h @@ -76,10 +76,10 @@ class ConditionalAccumulatorBaseOp : public OpKernel { virtual void SetHandleToOutput(OpKernelContext* ctx) TF_SHARED_LOCKS_REQUIRED(mu_) = 0; - virtual Status CheckSignature(OpKernelContext* ctx) = 0; + virtual absl::Status CheckSignature(OpKernelContext* ctx) = 0; protected: - typedef std::function Creator; + typedef std::function Creator; // Subclasses must override this virtual Creator GetCreator() const = 0; @@ -94,7 +94,7 @@ class ConditionalAccumulatorBaseOp : public OpKernel { bool accumulator_set_ TF_GUARDED_BY(mu_); private: - Status SetAccumulatorHandle(OpKernelContext* ctx) + absl::Status SetAccumulatorHandle(OpKernelContext* ctx) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { TF_RETURN_IF_ERROR(cinfo_.Init(ctx->resource_manager(), def())); diff --git a/tensorflow/core/kernels/conditional_accumulator_op.cc b/tensorflow/core/kernels/conditional_accumulator_op.cc index 732847ff200db7..c0d171f2fed12e 100644 --- a/tensorflow/core/kernels/conditional_accumulator_op.cc +++ b/tensorflow/core/kernels/conditional_accumulator_op.cc @@ -41,7 +41,7 @@ class ConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp { }; } - Status CheckSignature(OpKernelContext* ctx) override { + absl::Status CheckSignature(OpKernelContext* ctx) override { TF_RETURN_IF_ERROR(ctx->MatchSignature({}, {DT_STRING_REF})); return absl::OkStatus(); } @@ -79,7 +79,7 @@ class ResourceConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp { }; } - Status CheckSignature(OpKernelContext* ctx) override { + absl::Status CheckSignature(OpKernelContext* ctx) override { TF_RETURN_IF_ERROR(ctx->MatchSignature({}, {DT_RESOURCE})); return absl::OkStatus(); } diff --git a/tensorflow/core/kernels/constant_op_test.cc b/tensorflow/core/kernels/constant_op_test.cc index a2052c146da7a3..33aba406c94bfd 100644 --- a/tensorflow/core/kernels/constant_op_test.cc +++ b/tensorflow/core/kernels/constant_op_test.cc @@ -60,7 +60,7 @@ void ConstantOpTest::PersistentMemoryTrackingTest(bool on_gpu) { std::unique_ptr device(DeviceFactory::NewDevice( device_string, {}, "/job:worker/replica:0/task:0")); - Status status; + absl::Status status; std::unique_ptr op(CreateOpKernel(device_type, device.get(), cpu_allocator(), const_node, TF_GRAPH_DEF_VERSION, &status)); diff --git a/tensorflow/core/kernels/control_flow_ops_test.cc b/tensorflow/core/kernels/control_flow_ops_test.cc index a60e6a022327b1..a58500ed951160 100644 --- a/tensorflow/core/kernels/control_flow_ops_test.cc +++ b/tensorflow/core/kernels/control_flow_ops_test.cc @@ -146,7 +146,7 @@ static void add_identity_nodes(Node* node, Graph& graph, } // Runs type inference pass on graph -static Status type_inference(Graph& graph) { +static absl::Status type_inference(Graph& graph) { GraphOptimizationPassOptions opt_options; std::unique_ptr graph_ptr(new Graph(OpRegistry::Global())); graph_ptr->Copy(graph); diff --git a/tensorflow/core/kernels/conv_grad_shape_utils.cc b/tensorflow/core/kernels/conv_grad_shape_utils.cc index 0be69d2689e7be..aeafb0db6745c5 100644 --- a/tensorflow/core/kernels/conv_grad_shape_utils.cc +++ b/tensorflow/core/kernels/conv_grad_shape_utils.cc @@ -50,7 +50,7 @@ int ConvBackpropDimensions::SpatialPadding(const Padding& padding, namespace { -Status ConvBackpropExtractAndVerifyDimension( +absl::Status ConvBackpropExtractAndVerifyDimension( StringPiece label, const TensorShape& input_shape, const TensorShape& filter_shape, const TensorShape& output_shape, const absl::Span dilations, const std::vector& strides, @@ -92,7 +92,7 @@ Status ConvBackpropExtractAndVerifyDimension( } // namespace -Status ConvBackpropComputeDimensionsV2( +absl::Status ConvBackpropComputeDimensionsV2( StringPiece label, int num_spatial_dims, const TensorShape& input_shape, const TensorShape& filter_shape, const TensorShape& out_backprop_shape, const absl::Span dilations, const std::vector& strides, @@ -157,13 +157,11 @@ Status ConvBackpropComputeDimensionsV2( return absl::OkStatus(); } -Status ConvBackpropComputeDimensions(StringPiece label, int num_spatial_dims, - const TensorShape& input_shape, - const TensorShape& filter_shape, - const TensorShape& out_backprop_shape, - const std::vector& strides, - Padding padding, TensorFormat data_format, - ConvBackpropDimensions* dims) { +absl::Status ConvBackpropComputeDimensions( + StringPiece label, int num_spatial_dims, const TensorShape& input_shape, + const TensorShape& filter_shape, const TensorShape& out_backprop_shape, + const std::vector& strides, Padding padding, + TensorFormat data_format, ConvBackpropDimensions* dims) { static constexpr std::array one_dilations = {{1, 1, 1, 1, 1}}; return ConvBackpropComputeDimensionsV2( label, num_spatial_dims, input_shape, filter_shape, out_backprop_shape, @@ -171,11 +169,10 @@ Status ConvBackpropComputeDimensions(StringPiece label, int num_spatial_dims, dims); } -Status Conv2DBackpropComputeInputShape(const Tensor& input_sizes, - const TensorShape& filter_shape, - const TensorShape& out_backprop_shape, - const TensorFormat& data_format, - TensorShape* input_shape) { +absl::Status Conv2DBackpropComputeInputShape( + const Tensor& input_sizes, const TensorShape& filter_shape, + const TensorShape& out_backprop_shape, const TensorFormat& data_format, + TensorShape* input_shape) { if (!TensorShapeUtils::IsVector(input_sizes.shape())) { return errors::InvalidArgument( "Conv2DBackpropInput: input_sizes input must be 1-dim, not ", diff --git a/tensorflow/core/kernels/conv_grad_shape_utils.h b/tensorflow/core/kernels/conv_grad_shape_utils.h index f61f53ee13cc38..9fdc0ce9bcabdc 100644 --- a/tensorflow/core/kernels/conv_grad_shape_utils.h +++ b/tensorflow/core/kernels/conv_grad_shape_utils.h @@ -66,18 +66,16 @@ struct ConvBackpropDimensions { // Common code between implementations of Conv?DBackpropInput and // Conv?DBackpropFilter. Verifies that the dimensions all match, and computes // sizes/padding for the spatial dimensions. Does not support explicit padding. -Status ConvBackpropComputeDimensions(StringPiece label, int num_spatial_dims, - const TensorShape& input_shape, - const TensorShape& filter_shape, - const TensorShape& out_backprop_shape, - const std::vector& strides, - Padding padding, TensorFormat data_format, - ConvBackpropDimensions* dims); +absl::Status ConvBackpropComputeDimensions( + StringPiece label, int num_spatial_dims, const TensorShape& input_shape, + const TensorShape& filter_shape, const TensorShape& out_backprop_shape, + const std::vector& strides, Padding padding, + TensorFormat data_format, ConvBackpropDimensions* dims); // The V2 version computes the same outputs with arbitrary dilation rate and // supports explicit padding. // TODO(b/67112639): Merge V2 versions and the original versions eventually. -Status ConvBackpropComputeDimensionsV2( +absl::Status ConvBackpropComputeDimensionsV2( StringPiece label, int num_spatial_dims, const TensorShape& input_shape, const TensorShape& filter_shape, const TensorShape& out_backprop_shape, absl::Span dilations, const std::vector& strides, @@ -85,11 +83,10 @@ Status ConvBackpropComputeDimensionsV2( TensorFormat data_format, ConvBackpropDimensions* dims); // Computes the shape of the in_backprop. -Status Conv2DBackpropComputeInputShape(const Tensor& input_sizes, - const TensorShape& filter_shape, - const TensorShape& out_backprop_shape, - const TensorFormat& data_format, - TensorShape* input_shape); +absl::Status Conv2DBackpropComputeInputShape( + const Tensor& input_sizes, const TensorShape& filter_shape, + const TensorShape& out_backprop_shape, const TensorFormat& data_format, + TensorShape* input_shape); } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_CONV_GRAD_SHAPE_UTILS_H_ diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index 4f613b91fd73a7..d976e0a55c084c 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -39,6 +39,7 @@ typedef Eigen::GpuDevice GPUDevice; if (!TF_PREDICT_TRUE(EXP)) return (STATUS); \ } while (false) +<<<<<<< HEAD bool UseNhwcLayoutForConvOnRocm(se::Stream* stream) { #if TENSORFLOW_USE_ROCM bool is_enabled = se::gpu::UseNhwcLayoutForRocm(); @@ -51,6 +52,10 @@ bool UseNhwcLayoutForConvOnRocm(se::Stream* stream) { Status InitConv2DParameters(const OpKernelConstruction* context, Conv2DParameters* params) { +======= +absl::Status InitConv2DParameters(const OpKernelConstruction* context, + Conv2DParameters* params) { +>>>>>>> master TF_RETURN_IF_ERROR(context->GetAttr("dilations", ¶ms->dilations)); TF_RETURN_IF_ERROR(context->GetAttr("strides", ¶ms->strides)); TF_RETURN_IF_ERROR(context->GetAttr("padding", ¶ms->padding)); @@ -104,9 +109,9 @@ Status InitConv2DParameters(const OpKernelConstruction* context, return absl::OkStatus(); } -Status ComputeConv2DDimension(const Conv2DParameters& params, - const Tensor& input, const Tensor& filter, - Conv2DDimensions* dimensions) { +absl::Status ComputeConv2DDimension(const Conv2DParameters& params, + const Tensor& input, const Tensor& filter, + Conv2DDimensions* dimensions) { int required_dims = params.data_format == TensorFormat::FORMAT_NCHW_VECT_C ? 5 : 4; // Check that 2D convolution input and filter have exactly required_dims. diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h index aaa49642279674..65c63fec1e439f 100644 --- a/tensorflow/core/kernels/conv_ops.h +++ b/tensorflow/core/kernels/conv_ops.h @@ -125,15 +125,15 @@ struct Conv2DDimensions { // Initializes and validates Conv2D parameters configured by OpKernel // attributes. -Status InitConv2DParameters(const OpKernelConstruction* context, - Conv2DParameters* params); +absl::Status InitConv2DParameters(const OpKernelConstruction* context, + Conv2DParameters* params); // Computes and validates convolutions dimensions from Conv2D parameters. If // parameters are valid, dimensions will be updated with derived convolution // dimensions, otherwise an error will be returned. -Status ComputeConv2DDimension(const Conv2DParameters& params, - const Tensor& input, const Tensor& filter, - Conv2DDimensions* dimensions); +absl::Status ComputeConv2DDimension(const Conv2DParameters& params, + const Tensor& input, const Tensor& filter, + Conv2DDimensions* dimensions); } // namespace tensorflow diff --git a/tensorflow/core/kernels/conv_ops_fused_image_transform.cc b/tensorflow/core/kernels/conv_ops_fused_image_transform.cc index d45ff0171dfe59..8887103240c9d7 100644 --- a/tensorflow/core/kernels/conv_ops_fused_image_transform.cc +++ b/tensorflow/core/kernels/conv_ops_fused_image_transform.cc @@ -341,8 +341,8 @@ class FusedResizeAndPadConvFunctor { // use TensorFlow's resource management to ensure that the memory will be // released when the session is over. Im2ColBufferResource* im2col_buffer_resource; - std::function**)> creator = - [](Im2ColBufferResource** resource) { + std::function**)> + creator = [](Im2ColBufferResource** resource) { *resource = new Im2ColBufferResource(); return absl::OkStatus(); }; @@ -378,7 +378,7 @@ class FusedResizeAndPadConvFunctor { (needed_resize_cache_count * sizeof(T1)) <= kResizeCacheSize, errors::InvalidArgument("Input too large for resize cache")); Im2ColBufferResource* resize_cache_resource; - std::function**)> + std::function**)> resize_creator = [](Im2ColBufferResource** resource) { *resource = new Im2ColBufferResource(); diff --git a/tensorflow/core/kernels/conv_ops_using_gemm.cc b/tensorflow/core/kernels/conv_ops_using_gemm.cc index 07efd8069b019f..3ebd3a4fa76d93 100644 --- a/tensorflow/core/kernels/conv_ops_using_gemm.cc +++ b/tensorflow/core/kernels/conv_ops_using_gemm.cc @@ -308,7 +308,7 @@ class Im2ColConvFunctor { // use TensorFlow's resource management to ensure that the memory will be // released when the session is over. Im2ColBufferResource* im2col_buffer_resource; - std::function**)> + std::function**)> creator = [](Im2ColBufferResource** resource) { *resource = new Im2ColBufferResource(); return absl::OkStatus(); diff --git a/tensorflow/core/kernels/count_ops.cc b/tensorflow/core/kernels/count_ops.cc index dd1d3db048046d..93b595f49bcdaa 100644 --- a/tensorflow/core/kernels/count_ops.cc +++ b/tensorflow/core/kernels/count_ops.cc @@ -38,8 +38,9 @@ using BatchedMap = std::vector>; namespace { // TODO(momernick): Extend this function to work with outputs of rank > 2. template -Status OutputSparse(const BatchedMap& per_batch_counts, int64_t num_values, - bool is_1d, OpKernelContext* context) { +absl::Status OutputSparse(const BatchedMap& per_batch_counts, + int64_t num_values, bool is_1d, + OpKernelContext* context) { int total_values = 0; int num_batches = per_batch_counts.size(); for (const auto& per_batch_count : per_batch_counts) { diff --git a/tensorflow/core/kernels/ctc_decoder_ops.cc b/tensorflow/core/kernels/ctc_decoder_ops.cc index 2480bc435bd01c..401f1572298d9b 100644 --- a/tensorflow/core/kernels/ctc_decoder_ops.cc +++ b/tensorflow/core/kernels/ctc_decoder_ops.cc @@ -56,11 +56,11 @@ class CTCDecodeHelper { inline int GetTopPaths() const { return top_paths_; } void SetTopPaths(int tp) { top_paths_ = tp; } - Status ValidateInputsGenerateOutputs( + absl::Status ValidateInputsGenerateOutputs( OpKernelContext* ctx, const Tensor** inputs, const Tensor** seq_len, Tensor** log_prob, OpOutputList* decoded_indices, OpOutputList* decoded_values, OpOutputList* decoded_shape) const { - Status status = ctx->input("inputs", inputs); + absl::Status status = ctx->input("inputs", inputs); if (!status.ok()) return status; status = ctx->input("sequence_length", seq_len); if (!status.ok()) return status; @@ -100,7 +100,7 @@ class CTCDecodeHelper { } } - Status s = ctx->allocate_output( + absl::Status s = ctx->allocate_output( "log_probability", TensorShape({batch_size, top_paths_}), log_prob); if (!s.ok()) return s; @@ -115,8 +115,8 @@ class CTCDecodeHelper { } // sequences[b][p][ix] stores decoded value "ix" of path "p" for batch "b". - Status StoreAllDecodedSequences( - const std::vector > >& sequences, + absl::Status StoreAllDecodedSequences( + const std::vector>>& sequences, OpOutputList* decoded_indices, OpOutputList* decoded_values, OpOutputList* decoded_shape) const { // Calculate the total number of entries for each path @@ -138,7 +138,7 @@ class CTCDecodeHelper { const int64_t p_num = num_entries[p]; - Status s = + absl::Status s = decoded_indices->allocate(p, TensorShape({p_num, 2}), &p_indices); if (!s.ok()) return s; s = decoded_values->allocate(p, TensorShape({p_num}), &p_values); diff --git a/tensorflow/core/kernels/ctc_loss_op.cc b/tensorflow/core/kernels/ctc_loss_op.cc index 1886ae27026974..63d31fcf62d46d 100644 --- a/tensorflow/core/kernels/ctc_loss_op.cc +++ b/tensorflow/core/kernels/ctc_loss_op.cc @@ -153,7 +153,7 @@ class CTCLossOp : public OpKernel { ctx, sparse::SparseTensor::Create(*labels_indices, *labels_values, labels_shape, order, &labels_sp)); - Status labels_sp_valid = labels_sp.IndicesValid(); + absl::Status labels_sp_valid = labels_sp.IndicesValid(); OP_REQUIRES(ctx, labels_sp_valid.ok(), errors::InvalidArgument("label SparseTensor is not valid: ", labels_sp_valid.message())); diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index c447763924a878..6a19a8fee82cda 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -1459,7 +1459,7 @@ tf_kernel_library( "//tensorflow/core:lib_internal", "//tensorflow/core/data:name_utils", "//tensorflow/core/data:utils", - "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/profiler/lib:traceme", ], ) diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc index cd8acd1042c2da..43892f1c6a80ae 100644 --- a/tensorflow/core/kernels/data/batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/batch_dataset_op.cc @@ -135,17 +135,18 @@ class BatchDatasetOp::Dataset : public DatasetBase { return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } - Status Get(OpKernelContext* ctx, int64 index, - std::vector* out_tensors) const override { + absl::Status Get(OpKernelContext* ctx, int64 index, + std::vector* out_tensors) const override { const int64 cardinality = Cardinality(); if (index < 0 || index >= cardinality) { return errors::OutOfRange("Index out of range [0, ", cardinality, @@ -170,9 +171,9 @@ class BatchDatasetOp::Dataset : public DatasetBase { } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* batch_size = nullptr; @@ -195,14 +196,14 @@ class BatchDatasetOp::Dataset : public DatasetBase { bool SymbolicCheckpointCompatible() const override { return true; } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { tsl::mutex_lock l(mu_); return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { // Each row of `batch_elements` is a tuple of tensors from the // input iterator. std::vector> batch_elements; @@ -274,8 +275,8 @@ class BatchDatasetOp::Dataset : public DatasetBase { return model::MakeKnownRatioNode(std::move(args), dataset()->batch_size_); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar( prefix(), kInputImplEmpty, static_cast(!input_impl_))); @@ -285,8 +286,8 @@ class BatchDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); int64_t input_empty; TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/kernels/data/batch_dataset_op_test.cc b/tensorflow/core/kernels/data/batch_dataset_op_test.cc index b38897bcc75354..7b838bcfc0a5c6 100644 --- a/tensorflow/core/kernels/data/batch_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/batch_dataset_op_test.cc @@ -316,7 +316,7 @@ static void add_identity_nodes(Node* node, Graph& graph, } // Runs type inference pass on graph -static Status type_inference(Graph& graph) { +static absl::Status type_inference(Graph& graph) { GraphOptimizationPassOptions opt_options; std::unique_ptr graph_ptr(new Graph(OpRegistry::Global())); graph_ptr->Copy(graph); diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index b77af19b8a0ea5..aa0a364988331f 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -85,8 +85,8 @@ class DatasetRandomAccessCache { // Extends the temporary cache up to a given index and then updates // out_tensors with the element at that index. - Status Get(OpKernelContext* ctx, int64 index, - std::vector* out_tensors) { + absl::Status Get(OpKernelContext* ctx, int64 index, + std::vector* out_tensors) { if (!iter_resource_) { TF_ASSIGN_OR_RETURN(iter_resource_, GetIteratorResourceFromDataset(ctx, input_)); @@ -103,7 +103,7 @@ class DatasetRandomAccessCache { std::vector> GetCacheData() { return cache_; } private: - Status ExtendTempCacheToIndex(int64 index, OpKernelContext* ctx) { + absl::Status ExtendTempCacheToIndex(int64 index, OpKernelContext* ctx) { bool end_of_sequence; while (cache_.size() <= index) { std::vector out_tensors; @@ -211,12 +211,13 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { return input_->Cardinality(options); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } @@ -247,14 +248,14 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { } } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { mutex_lock l(mu_); return InitializeIterator(ctx); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); return iterator_->GetNext(ctx, out_tensors, end_of_sequence); } @@ -266,14 +267,14 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kMode, mode_)); return SaveInput(ctx, writer, iterator_); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); { int64_t temp; @@ -331,7 +332,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { if (!dataset()->env_->FileExists(MetaFilename(filename_)).ok()) { LOG(WARNING) << kIncompleteCacheErrorMessage; std::vector cache_files; - Status s = dataset()->env_->GetMatchingPaths( + absl::Status s = dataset()->env_->GetMatchingPaths( strings::StrCat(filename_, "*"), &cache_files); if (!s.ok()) { LOG(WARNING) << "Failed to get matching files on " << filename_ @@ -347,14 +348,14 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { } } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); *end_of_sequence = false; TF_RETURN_IF_ERROR(EnsureLockFileExists(end_of_sequence)); @@ -364,7 +365,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { TF_RETURN_IF_ERROR(writer_->status()); if (cur_index_ >= kMaxItems) { // As a courtesy, close the [truncated] cache file. - Status s = Finish(); + absl::Status s = Finish(); if (!s.ok()) { LOG(ERROR) << s; } @@ -406,8 +407,8 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR( writer->WriteScalar(prefix(), kCurIndex, cur_index_)); @@ -442,8 +443,8 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); int64_t temp; // TODO(b/78048575): Update this when saving size_t tensors directly @@ -479,7 +480,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { } private: - Status EnsureLockFileExists(bool* end_of_sequence) + absl::Status EnsureLockFileExists(bool* end_of_sequence) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (iteration_completed_) { *end_of_sequence = true; @@ -541,7 +542,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { return absl::OkStatus(); } - Status Finish() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + absl::Status Finish() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { iteration_completed_ = true; // Flush the current bundle. TF_RETURN_IF_ERROR(writer_->Finish()); @@ -593,9 +594,9 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { reader_(dataset()->env_, dataset()->filename_), iterator_restored_(false) {} - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); *end_of_sequence = false; TF_RETURN_IF_ERROR(reader_.status()); @@ -636,15 +637,15 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR( writer->WriteScalar(prefix(), kCurIndex, cur_index_)); return absl::OkStatus(); } - Status RestoreInternal( + absl::Status RestoreInternal( IteratorContext* ctx, IteratorStateReader* iterator_state_reader) override { mutex_lock l(mu_); @@ -674,7 +675,7 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase { bool iterator_restored_ TF_GUARDED_BY(mu_); }; // FileReaderIterator - Status InitializeIterator(IteratorContext* ctx) + absl::Status InitializeIterator(IteratorContext* ctx) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { // We intentionally use the same prefix for both `FileReaderIterator` and // `FileWriterIterator`. Since at any time there will be at most one of @@ -718,9 +719,9 @@ class CacheDatasetOp::FileDataset : public CacheDatasetOp::FileDatasetBase { using FileDatasetBase::FileDatasetBase; protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph)); Node* filename = nullptr; @@ -739,9 +740,9 @@ class CacheDatasetOp::FileDatasetV2 : public CacheDatasetOp::FileDatasetBase { resource_handle_(resource_handle) {} protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); Node* filename_node = nullptr; @@ -798,8 +799,8 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { return input_->Cardinality(options); }; - Status Get(OpKernelContext* ctx, int64 index, - std::vector* out_tensors) const override { + absl::Status Get(OpKernelContext* ctx, int64 index, + std::vector* out_tensors) const override { mutex_lock l(mu_); CardinalityOptions options; @@ -818,8 +819,8 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { return dataset_random_access_cache_->Get(ctx, index, out_tensors); } - Status Get(AnyContext ctx, int64 index, - std::vector* out_tensors) const override { + absl::Status Get(AnyContext ctx, int64 index, + std::vector* out_tensors) const override { mutex_lock l(mu_); if (!iterator_random_access_cache_) { iterator_random_access_cache_ = @@ -828,12 +829,13 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { return iterator_random_access_cache_->Get(ctx, index, out_tensors); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } @@ -849,14 +851,14 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { cache_(cache), global_shuffle_iterator_(dataset()) {} - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { mutex_lock l(mu_); return InitializeIterator(ctx); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { if (ctx->index_mapper() != nullptr) { return global_shuffle_iterator_.GetNext(ctx, out_tensors, end_of_sequence); @@ -872,8 +874,8 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); if (cache_->IsCompleted()) { TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kCacheCompleted, "")); @@ -884,8 +886,8 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { return SaveInput(ctx, writer, iterator_); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { if (ctx->restored_element_count().has_value()) { return global_shuffle_iterator_.Restore(prefix(), ctx, reader); } @@ -916,14 +918,14 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { } } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR( input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); @@ -951,8 +953,8 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); if (!cache_->IsCompleted()) { TF_RETURN_IF_ERROR( @@ -961,8 +963,8 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { return SaveInput(ctx, writer, input_impl_); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); if (!reader->Contains(prefix(), kCacheCompleted)) { TF_RETURN_IF_ERROR( @@ -985,7 +987,7 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { cache_(cache), index_(0) {} - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { // The memory allocated for the cache is owned by the parent // dataset but performance modeling uses the iterator abstraction and // thus we record the memory allocated for the cache here. The caveat @@ -998,9 +1000,9 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { return absl::OkStatus(); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); if (index_ < cache_->size()) { const std::vector& cache_tensors = cache_->at(index_); @@ -1022,15 +1024,15 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kIndex, index_)); return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); { // kIndex will not be set if we are restoring from a checkpoint @@ -1050,7 +1052,7 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { size_t index_ TF_GUARDED_BY(mu_); }; // MemoryReaderIterator - Status InitializeIterator(IteratorContext* ctx) + absl::Status InitializeIterator(IteratorContext* ctx) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (cache_->IsCompleted()) { iterator_ = std::make_unique( @@ -1097,7 +1099,7 @@ class CacheDatasetOp::MemoryDataset : public CacheDatasetOp::MemoryDatasetBase { ~MemoryDataset() override { manager_->Unref(); - Status s = resource_mgr_->Delete( + absl::Status s = resource_mgr_->Delete( resource_handle_.container(), resource_handle_.name()); if (!s.ok()) { LOG(WARNING) << "Failed to delete cache resource: " << s.ToString(); @@ -1105,9 +1107,9 @@ class CacheDatasetOp::MemoryDataset : public CacheDatasetOp::MemoryDatasetBase { } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); Node* filename_node = nullptr; @@ -1141,7 +1143,7 @@ class CacheDatasetOp::MemoryDatasetV2 ~MemoryDatasetV2() override { manager_->Unref(); if (owns_resource_) { - Status s = resource_mgr_->Delete( + absl::Status s = resource_mgr_->Delete( resource_handle_.container(), resource_handle_.name()); if (!s.ok()) { LOG(WARNING) << "Failed to delete cache resource: " << s.ToString(); @@ -1150,9 +1152,9 @@ class CacheDatasetOp::MemoryDatasetV2 } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); Node* filename_node = nullptr; @@ -1191,7 +1193,7 @@ void CacheDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, bool owns_resource = false; MemoryCacheManager* manager = nullptr; auto handle = HandleFromInput(ctx, 2); - Status s = ctx->resource_manager()->Lookup( + absl::Status s = ctx->resource_manager()->Lookup( handle.container(), handle.name(), &manager); if (errors::IsNotFound(s)) { owns_resource = true; diff --git a/tensorflow/core/kernels/data/cache_dataset_ops_test.cc b/tensorflow/core/kernels/data/cache_dataset_ops_test.cc index 0f8f7f8824b2b5..1ddfe5eae0dec5 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops_test.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops_test.cc @@ -49,12 +49,12 @@ class CacheDatasetParams : public DatasetParams { return {filename_tensor}; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { *input_names = {CacheDatasetOp::kInputDataset, CacheDatasetOp::kFileName}; return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { *attr_vector = {{"output_types", output_dtypes_}, {"output_shapes", output_shapes_}, {"metadata", ""}}; @@ -71,7 +71,7 @@ class CacheDatasetParams : public DatasetParams { class CacheDatasetOpTest : public DatasetOpsTestBase { public: - Status Initialize(const DatasetParams& dataset_params) { + absl::Status Initialize(const DatasetParams& dataset_params) { TF_RETURN_IF_ERROR(DatasetOpsTestBase::Initialize(dataset_params)); auto params = static_cast(dataset_params); cache_filename_ = params.filename(); @@ -81,7 +81,7 @@ class CacheDatasetOpTest : public DatasetOpsTestBase { ~CacheDatasetOpTest() override { if (!cache_filename_.empty()) { std::vector cache_files; - Status s = device_->env()->GetMatchingPaths( + absl::Status s = device_->env()->GetMatchingPaths( strings::StrCat(cache_filename_, "*"), &cache_files); if (!s.ok()) { LOG(WARNING) << "Failed to get matching files on " << cache_filename_ diff --git a/tensorflow/core/kernels/data/cache_ops.cc b/tensorflow/core/kernels/data/cache_ops.cc index 002d3876e61ef0..4a422cee5722cb 100644 --- a/tensorflow/core/kernels/data/cache_ops.cc +++ b/tensorflow/core/kernels/data/cache_ops.cc @@ -76,7 +76,7 @@ AnonymousMemoryCacheHandleOp::AnonymousMemoryCacheHandleOp( string AnonymousMemoryCacheHandleOp::name() { return kMemoryCache; } -Status AnonymousMemoryCacheHandleOp::CreateResource( +absl::Status AnonymousMemoryCacheHandleOp::CreateResource( OpKernelContext* ctx, std::unique_ptr flib_def, std::unique_ptr pflr, FunctionLibraryRuntime* lib, MemoryCacheManager** manager) { @@ -87,7 +87,7 @@ Status AnonymousMemoryCacheHandleOp::CreateResource( void DeleteMemoryCacheOp::Compute(OpKernelContext* ctx) { const ResourceHandle& handle = ctx->input(0).flat()(0); // The resource might have been already deleted by the dataset. - Status s = ctx->resource_manager()->Delete(handle); + absl::Status s = ctx->resource_manager()->Delete(handle); if (!errors::IsNotFound(s)) { OP_REQUIRES_OK(ctx, s); } diff --git a/tensorflow/core/kernels/data/cache_ops.h b/tensorflow/core/kernels/data/cache_ops.h index 523b5ee2343e06..e1e58ae9c1df89 100644 --- a/tensorflow/core/kernels/data/cache_ops.h +++ b/tensorflow/core/kernels/data/cache_ops.h @@ -78,11 +78,10 @@ class AnonymousMemoryCacheHandleOp private: string name() override; - Status CreateResource(OpKernelContext* ctx, - std::unique_ptr flib_def, - std::unique_ptr pflr, - FunctionLibraryRuntime* lib, - MemoryCacheManager** manager) override; + absl::Status CreateResource( + OpKernelContext* ctx, std::unique_ptr flib_def, + std::unique_ptr pflr, + FunctionLibraryRuntime* lib, MemoryCacheManager** manager) override; }; // Deletes an instance of cache resource. diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc index 77ed1ce1ca6a69..6e7f2fdeef050d 100644 --- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc +++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc @@ -108,8 +108,8 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { this, name_utils::IteratorPrefix(kDatasetType, prefix)}); } - Status MakeSplitProviders(std::vector>* - split_providers) const override { + absl::Status MakeSplitProviders(std::vector>* + split_providers) const override { TF_ASSIGN_OR_RETURN(*split_providers, GetSplitProviders(this)); return absl::OkStatus(); } @@ -141,19 +141,20 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { return input_cardinality + to_concatenate_cardinality; } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); inputs->push_back(to_concatenate_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { TF_RETURN_IF_ERROR(input_->CheckExternalState()); return to_concatenate_->CheckExternalState(); } - Status Get(OpKernelContext* ctx, int64 index, - std::vector* out_tensors) const override { + absl::Status Get(OpKernelContext* ctx, int64 index, + std::vector* out_tensors) const override { TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index)); if (index < input_cardinality_) { TF_RETURN_IF_ERROR(input_->Get(ctx, index, out_tensors)); @@ -169,9 +170,9 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph)); Node* to_concatenate_graph = nullptr; @@ -190,7 +191,7 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { bool SymbolicCheckpointCompatible() const override { return true; } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { mutex_lock l(mu_); input_impls_.resize(2); @@ -204,9 +205,9 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); if (!input_impls_[0] && !input_impls_[1]) { *end_of_sequence = true; @@ -327,8 +328,8 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kIndex, i_)); TF_RETURN_IF_ERROR( @@ -348,8 +349,8 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); int64_t input_uninitialized[2]; @@ -443,9 +444,9 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { size_t element_count_ TF_GUARDED_BY(mu_) = 0; }; - Status MostSpecificCompatibleShape(const PartialTensorShape& ts1, - const PartialTensorShape& ts2, - PartialTensorShape* output_tensorshape) { + absl::Status MostSpecificCompatibleShape( + const PartialTensorShape& ts1, const PartialTensorShape& ts2, + PartialTensorShape* output_tensorshape) { if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank()) return absl::OkStatus(); auto dims1 = ts1.dim_sizes(); diff --git a/tensorflow/core/kernels/data/dataset_ops.cc b/tensorflow/core/kernels/data/dataset_ops.cc index e1d14f3ecd9fa8..b3c114ce833a08 100644 --- a/tensorflow/core/kernels/data/dataset_ops.cc +++ b/tensorflow/core/kernels/data/dataset_ops.cc @@ -92,7 +92,8 @@ void DatasetToGraphOp::Compute(OpKernelContext* ctx) { params.external_state_policy = external_state_policy_; GraphDef graph_def; - Status s = AsGraphDef(dataset, SerializationContext(params), &graph_def); + absl::Status s = + AsGraphDef(dataset, SerializationContext(params), &graph_def); if (!s.ok()) { ctx->CtxFailure(errors::FailedPrecondition( "Failed to serialize the input pipeline graph: ", s.message())); @@ -178,7 +179,8 @@ void DatasetFingerprintOp::Compute(OpKernelContext* ctx) { SerializationContext::Params params(ctx); GraphDef graph_def; - Status s = AsGraphDef(dataset, SerializationContext(params), &graph_def); + absl::Status s = + AsGraphDef(dataset, SerializationContext(params), &graph_def); if (!s.ok()) { ctx->CtxFailure(absl::FailedPreconditionError(absl::StrFormat( diff --git a/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc index 936ad1f9d4c357..c734b8b754d0bd 100644 --- a/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc @@ -79,12 +79,13 @@ class AssertCardinalityDatasetOp::Dataset : public DatasetBase { return cardinality_; } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } @@ -103,9 +104,9 @@ class AssertCardinalityDatasetOp::Dataset : public DatasetBase { } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* cardinality_node = nullptr; @@ -123,13 +124,13 @@ class AssertCardinalityDatasetOp::Dataset : public DatasetBase { bool SymbolicCheckpointCompatible() const override { return true; } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { TF_RETURN_IF_ERROR( input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); if (!*end_of_sequence) { @@ -163,16 +164,16 @@ class AssertCardinalityDatasetOp::Dataset : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("num_elements"), num_elements_)); TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { TF_RETURN_IF_ERROR( reader->ReadScalar(full_name("num_elements"), &num_elements_)); return RestoreInput(ctx, reader, input_impl_); diff --git a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc index 6dac458703c872..810e2b5dfa8a4a 100644 --- a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc @@ -66,19 +66,20 @@ class AssertNextDatasetOp::Dataset : public DatasetBase { return input_->Cardinality(options); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* transformations_node = nullptr; @@ -94,7 +95,7 @@ class AssertNextDatasetOp::Dataset : public DatasetBase { explicit Iterator(const Params& params) : DatasetIterator(params) {} - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { std::vector tokens = absl::StrSplit(prefix(), ':', absl::SkipEmpty()); if (dataset()->transformations_.size() > tokens.size() - 2) { @@ -116,9 +117,9 @@ class AssertNextDatasetOp::Dataset : public DatasetBase { return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { return input_impl_->GetNext(ctx, out_tensors, end_of_sequence); } @@ -129,14 +130,14 @@ class AssertNextDatasetOp::Dataset : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); return absl::OkStatus(); } diff --git a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op_test.cc index 96a08f9acf2a2b..ae3b159009b92d 100644 --- a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op_test.cc @@ -45,14 +45,14 @@ class AssertNextDatasetParams : public DatasetParams { transformations_)}; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { input_names->reserve(input_dataset_params_.size() + 1); input_names->emplace_back(AssertNextDatasetOp::kInputDataset); input_names->emplace_back(AssertNextDatasetOp::kTransformations); return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { *attr_vector = {{AssertNextDatasetOp::kOutputShapes, output_shapes_}, {AssertNextDatasetOp::kOutputTypes, output_dtypes_}}; return absl::OkStatus(); diff --git a/tensorflow/core/kernels/data/experimental/assert_prev_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_prev_dataset_op.cc index 134437a3137358..dbd2159d4af0d4 100644 --- a/tensorflow/core/kernels/data/experimental/assert_prev_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/assert_prev_dataset_op.cc @@ -69,7 +69,8 @@ absl::StatusOr GetPreviousDataset( } // Checks `dataset`'s op name against that in `assertions`. -Status CheckOpName(const DatasetBase& dataset, const NameAttrList& assertions) { +absl::Status CheckOpName(const DatasetBase& dataset, + const NameAttrList& assertions) { if (!MatchesAnyVersion(assertions.name(), dataset.type_string())) { return errors::InvalidArgument("Asserted transformation matching '", assertions.name(), "', but found '", @@ -91,8 +92,8 @@ absl::StatusOr GetDatasetNode(const DatasetBase& dataset, } // Checks `dataset`'s attrs against those in `assertions`. -Status CheckAttributes(const DatasetBase& dataset, - const NameAttrList& assertions) { +absl::Status CheckAttributes(const DatasetBase& dataset, + const NameAttrList& assertions) { if (assertions.attr().empty()) return absl::OkStatus(); TF_ASSIGN_OR_RETURN(NodeDef node, GetDatasetNode(dataset, assertions.name())); std::vector attrs_not_found; @@ -121,8 +122,8 @@ Status CheckAttributes(const DatasetBase& dataset, } // Checks `dataset`'s op name and attrs against those in `transformation`. -Status CheckTransformation(const DatasetBase& dataset, - const tstring& transformation) { +absl::Status CheckTransformation(const DatasetBase& dataset, + const tstring& transformation) { TF_ASSIGN_OR_RETURN(NameAttrList assertions, GetAssertions(transformation)); TF_RETURN_IF_ERROR(CheckOpName(dataset, assertions)); TF_RETURN_IF_ERROR(CheckAttributes(dataset, assertions)); @@ -166,19 +167,20 @@ class AssertPrevDatasetOp::Dataset : public DatasetBase { return input_->Cardinality(options); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* transformations_node = nullptr; @@ -194,7 +196,7 @@ class AssertPrevDatasetOp::Dataset : public DatasetBase { explicit Iterator(const Params& params) : DatasetIterator(params) {} - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { const DatasetBase* current_dataset = dataset(); for (int i = 0; i < dataset()->transformations_.size(); ++i) { absl::StatusOr previous_dataset = @@ -205,8 +207,8 @@ class AssertPrevDatasetOp::Dataset : public DatasetBase { " transformations but encountered only ", i, "."); } - Status s = CheckTransformation(**previous_dataset, - dataset()->transformations_[i]); + absl::Status s = CheckTransformation(**previous_dataset, + dataset()->transformations_[i]); if (!s.ok()) { return errors::InvalidArgument( "Failure checking transformations at offset ", i, ": ", @@ -218,9 +220,9 @@ class AssertPrevDatasetOp::Dataset : public DatasetBase { return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { return input_impl_->GetNext(ctx, out_tensors, end_of_sequence); } @@ -231,14 +233,14 @@ class AssertPrevDatasetOp::Dataset : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); return absl::OkStatus(); } diff --git a/tensorflow/core/kernels/data/experimental/assert_prev_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/assert_prev_dataset_op_test.cc index cb7bd224bb23f2..3e9608f00d5daf 100644 --- a/tensorflow/core/kernels/data/experimental/assert_prev_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/experimental/assert_prev_dataset_op_test.cc @@ -67,14 +67,14 @@ class AssertPrevDatasetParams : public DatasetParams { transformations_)}; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { input_names->reserve(input_dataset_params_.size() + 1); input_names->emplace_back(AssertPrevDatasetOp::kInputDataset); input_names->emplace_back(AssertPrevDatasetOp::kTransformations); return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { *attr_vector = {{AssertPrevDatasetOp::kOutputShapes, output_shapes_}, {AssertPrevDatasetOp::kOutputTypes, output_dtypes_}}; return absl::OkStatus(); diff --git a/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op_test.cc index e7b2925b0dca25..7cab05d86bd917 100644 --- a/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/experimental/auto_shard_dataset_op_test.cc @@ -50,7 +50,7 @@ class AutoShardDatasetParams : public DatasetParams { return CreateTensors(TensorShape({}), {{num_workers_}, {index_}}); } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { input_names->clear(); input_names->emplace_back(AutoShardDatasetOp::kInputDataset); input_names->emplace_back(AutoShardDatasetOp::kNumWorkers); @@ -58,7 +58,7 @@ class AutoShardDatasetParams : public DatasetParams { return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { attr_vector->clear(); attr_vector->emplace_back(AutoShardDatasetOp::kAutoShardPolicy, auto_shard_policy_); @@ -215,7 +215,7 @@ static void add_identity_nodes(Node* node, Graph& graph, } // Runs type inference pass on graph -static Status type_inference(Graph& graph) { +static absl::Status type_inference(Graph& graph) { GraphOptimizationPassOptions opt_options; std::unique_ptr graph_ptr(new Graph(OpRegistry::Global())); graph_ptr->Copy(graph); diff --git a/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc index 2d2b239205302b..50429905a0e6f1 100644 --- a/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/choose_fastest_branch_dataset_op.cc @@ -55,16 +55,17 @@ class WrapperDataset : public DatasetBase { string DebugString() const override { return "WrapperDataset"; } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { return absl::OkStatus(); } - Status CheckExternalState() const override { return absl::OkStatus(); } + absl::Status CheckExternalState() const override { return absl::OkStatus(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** node) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** node) const override { return errors::Unimplemented(DebugString(), "::AsGraphDefInternal"); } @@ -86,7 +87,7 @@ class WrapperDataset : public DatasetBase { explicit WrapperIterator(const Params& params, bool error) : DatasetIterator(params), error_(error) {} - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { if (error_) { return errors::InvalidArgument( "Cannot create more than one WrapperIterator per WrapperDataset. " @@ -96,9 +97,9 @@ class WrapperDataset : public DatasetBase { return absl::OkStatus(); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { return dataset()->real_iterator_->GetNext(ctx, out_tensors, end_of_sequence); } @@ -109,13 +110,13 @@ class WrapperDataset : public DatasetBase { return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1.0); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { return absl::OkStatus(); } @@ -249,13 +250,13 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel { return static_cast(n) * ratio_numerator_ / ratio_denominator_; } - Status InputDatasets( + absl::Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { for (const auto& captured_func : captured_funcs_) { TF_RETURN_IF_ERROR(captured_func->CheckExternalState()); } @@ -263,9 +264,9 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); @@ -341,7 +342,7 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel { instantiated_captured_funcs_(dataset()->captured_funcs_.size()), histograms_(dataset()->captured_funcs_.size()) {} - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_)); @@ -357,9 +358,9 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel { // The first num_elements_per_branch * num_branches iterations, we run // experiments on the branches, using (branch_index_, experiment_counter_) // to keep track of which experiment we're on. - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { { // Locking scope mutex_lock l(mu_); if (branch_index_ < dataset()->captured_funcs_.size()) { @@ -370,7 +371,8 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel { /*is_get_next=*/true)); } - Status s = GetNextFromExperiment(ctx, out_tensors, end_of_sequence); + absl::Status s = + GetNextFromExperiment(ctx, out_tensors, end_of_sequence); experiment_counter_++; if (experiment_counter_ >= dataset()->num_elements_per_branch_) { @@ -406,8 +408,8 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel { // TODO(rachelim): Save and restore histogram state as well. Currently, // if an iterator is saved and restored, the histograms start recording // from scratch. - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("experiment_counter"), @@ -425,8 +427,8 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("experiment_counter"), @@ -453,15 +455,15 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel { } private: - Status GetNextFromExperiment(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) + absl::Status GetNextFromExperiment(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { DCHECK_GE(branch_index_, 0); DCHECK_LT(branch_index_, histograms_.size()); int64_t start = EnvTime::NowNanos(); - Status s = + absl::Status s = current_iterator_->GetNext(ctx, out_tensors, end_of_sequence); if (experiment_counter_ > 0) { @@ -495,8 +497,9 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel { << " as the fastest index."; } - Status MakeCurrentIterator(IteratorContext* ctx, int64_t branch_index, - bool is_experiment, bool is_get_next) + absl::Status MakeCurrentIterator(IteratorContext* ctx, + int64_t branch_index, bool is_experiment, + bool is_get_next) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { DCHECK_GE(branch_index, 0); DCHECK_LT(branch_index, histograms_.size()); diff --git a/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc b/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc index 41352a2cd40f5a..68ee0a5cec7878 100644 --- a/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/choose_fastest_dataset_op.cc @@ -160,7 +160,7 @@ class ChooseFastestDatasetOp : public DatasetOpKernel { return cardinality_; } - Status InputDatasets( + absl::Status InputDatasets( std::vector* inputs) const override { for (const auto& input : inputs_) { inputs->push_back(input); @@ -168,7 +168,7 @@ class ChooseFastestDatasetOp : public DatasetOpKernel { return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { for (const auto& input : inputs_) { TF_RETURN_IF_ERROR(input->CheckExternalState()); } @@ -176,9 +176,9 @@ class ChooseFastestDatasetOp : public DatasetOpKernel { } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { std::vector input_nodes; input_nodes.reserve(inputs_.size()); for (const auto& input : inputs_) { @@ -201,7 +201,7 @@ class ChooseFastestDatasetOp : public DatasetOpKernel { : DatasetIterator(params), histograms_(dataset()->inputs_.size()) {} - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { mutex_lock l(mu_); input_impls_.resize(dataset()->inputs_.size()); for (size_t i = 0, num_inputs = dataset()->inputs_.size(); @@ -213,9 +213,9 @@ class ChooseFastestDatasetOp : public DatasetOpKernel { return absl::OkStatus(); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); // The first num_experiments_ iterations, we fire up a thread for @@ -248,8 +248,8 @@ class ChooseFastestDatasetOp : public DatasetOpKernel { // TODO(rachelim): Save and restore histogram state as well. Currently, // if an iterator is saved and restored, the histograms start recording // from scratch. - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("experiment_counter"), experiment_counter_)); @@ -269,8 +269,8 @@ class ChooseFastestDatasetOp : public DatasetOpKernel { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("experiment_counter"), &experiment_counter_)); @@ -295,7 +295,7 @@ class ChooseFastestDatasetOp : public DatasetOpKernel { private: struct InvocationResult { Notification notification; - Status status; + absl::Status status; bool end_of_sequence; std::vector out_tensors; }; @@ -332,8 +332,8 @@ class ChooseFastestDatasetOp : public DatasetOpKernel { RecordStart(ctx); auto cleanup = gtl::MakeCleanup([this, ctx]() { RecordStop(ctx); }); int64_t start = EnvTime::NowNanos(); - Status s = input_impls_[i]->GetNext(ctx, &result->out_tensors, - &result->end_of_sequence); + absl::Status s = input_impls_[i]->GetNext(ctx, &result->out_tensors, + &result->end_of_sequence); histograms_[i].Add(static_cast(EnvTime::NowNanos() - start)); result->status = s; diff --git a/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc index 80c717806f7be7..1850a5fbde2e41 100644 --- a/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc @@ -200,18 +200,20 @@ class CSVDatasetOp : public DatasetOpKernel { string DebugString() const override { return "CSVDatasetOp::Dataset"; } - Status CheckExternalState() const override { return absl::OkStatus(); } + absl::Status CheckExternalState() const override { + return absl::OkStatus(); + } - Status InputDatasets( + absl::Status InputDatasets( std::vector* inputs) const override { inputs->clear(); return absl::OkStatus(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* filenames = nullptr; Node* compression_type = nullptr; Node* buffer_size = nullptr; @@ -278,16 +280,16 @@ class CSVDatasetOp : public DatasetOpKernel { explicit Iterator(const Params& params) : DatasetIterator(params) {} - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); bool select_all = dataset()->select_cols_.empty() && dataset()->exclude_cols_.empty(); do { // We are currently processing a file, so try to read the next record if (input_stream_) { - Status s = + absl::Status s = ReadRecord(ctx, out_tensors, select_all, dataset()->select_cols_, dataset()->exclude_cols_); if (s.ok()) { @@ -326,8 +328,8 @@ class CSVDatasetOp : public DatasetOpKernel { return model::MakeSourceNode(std::move(args)); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"), current_file_index_)); @@ -343,8 +345,8 @@ class CSVDatasetOp : public DatasetOpKernel { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); ResetStreamsLocked(); int64_t current_file_index; @@ -364,7 +366,7 @@ class CSVDatasetOp : public DatasetOpKernel { num_buffer_reads_ = size_t(num_buffer_reads - 1); // Restores the most recently held buffer - Status s = input_stream_->SkipNBytes( + absl::Status s = input_stream_->SkipNBytes( num_buffer_reads_ * dataset()->options_.input_buffer_size); if (!s.ok() && !errors::IsOutOfRange(s)) { // We might get out of range error here if the size of the file @@ -374,7 +376,7 @@ class CSVDatasetOp : public DatasetOpKernel { return s; } - Status s2 = FillBuffer(&buffer_); + absl::Status s2 = FillBuffer(&buffer_); if (!s2.ok() && !errors::IsOutOfRange(s2)) { return s2; } @@ -392,9 +394,10 @@ class CSVDatasetOp : public DatasetOpKernel { // character of the record in buffer_, or past the end of the buffer. // Note: ctx and out_tensors are only used in this function // when fields are included in the record. - Status ReadRecord(IteratorContext* ctx, std::vector* out_tensors, - bool select_all, const std::vector& selected, - const std::vector& excluded) + absl::Status ReadRecord(IteratorContext* ctx, + std::vector* out_tensors, bool select_all, + const std::vector& selected, + const std::vector& excluded) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (pos_ >= buffer_.size()) { // At the end of the file, this will return errors::OutOfRange @@ -410,7 +413,7 @@ class CSVDatasetOp : public DatasetOpKernel { size_t num_selected_parsed = 0; size_t num_excluded_parsed = 0; - Status result; + absl::Status result; while (!end_of_record) { // Read till we reach \n, \r or EOF bool explicit_exclude = num_excluded_parsed < excluded.size() && @@ -436,14 +439,14 @@ class CSVDatasetOp : public DatasetOpKernel { // Parses one field from position pos_ in the buffer. Fields are // delimited by delim, CRLF, or EOF. Advances pos_ to the first char of // the next field. - Status ParseOneField(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_record, bool include) + absl::Status ParseOneField(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_record, bool include) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (pos_ >= buffer_.size()) { // If we get here, this means the previous field's end coincided // with the end of the buffer. We can fill the buffer without abandon. - Status s = FillBuffer(&buffer_); + absl::Status s = FillBuffer(&buffer_); if (errors::IsOutOfRange(s)) { // Reached EOF, and last field is empty @@ -480,8 +483,8 @@ class CSVDatasetOp : public DatasetOpKernel { // Given that pos_ exceeds the buffer, saves the relevant part of the // current buffer (if necessary), fills the buffer, and resets indices to // 0. - Status SaveAndFillBuffer(std::vector* earlier_pieces, - size_t* start, bool include) + absl::Status SaveAndFillBuffer(std::vector* earlier_pieces, + size_t* start, bool include) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { tstring temp_buffer; @@ -499,18 +502,19 @@ class CSVDatasetOp : public DatasetOpKernel { // reads from buffer until end of field is reached (delim, CRLF, or EOF). // Advances pos_ to keep track of our position in the buffer as we go, // stopping at the first character of the next field. - Status ParseQuotedField(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_record, bool include) + absl::Status ParseQuotedField(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_record, bool include) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { std::vector earlier_pieces; size_t start = pos_; pos_++; // Starting quotation mark - Status parse_result; + absl::Status parse_result; while (true) { // Each iter reads 1 char, filling buffer if necessary if (pos_ >= buffer_.size()) { - Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + absl::Status s = + SaveAndFillBuffer(&earlier_pieces, &start, include); if (errors::IsOutOfRange(s)) { return errors::InvalidArgument( "Reached end of file without closing quoted field in " @@ -526,7 +530,8 @@ class CSVDatasetOp : public DatasetOpKernel { // decide what to do pos_++; if (pos_ >= buffer_.size()) { - Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + absl::Status s = + SaveAndFillBuffer(&earlier_pieces, &start, include); if (errors::IsOutOfRange(s)) { // This was the last field. We are done *end_of_record = true; @@ -570,10 +575,10 @@ class CSVDatasetOp : public DatasetOpKernel { // Converts quoted field to an output tensor, removing the starting // and ending quotes from it and unescaping double quotations if // necessary. - Status QuotedFieldToOutput(IteratorContext* ctx, StringPiece field, - std::vector* out_tensors, - const std::vector& earlier_pieces, - bool include) + absl::Status QuotedFieldToOutput(IteratorContext* ctx, StringPiece field, + std::vector* out_tensors, + const std::vector& earlier_pieces, + bool include) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (!include) return absl::OkStatus(); @@ -633,17 +638,18 @@ class CSVDatasetOp : public DatasetOpKernel { // reads from buffer until end of field is reached (delim, CRLF, or EOF). // Advances pos_ to keep track of our position in the buffer as we go, // stopping at the first character of the next field. - Status ParseUnquotedField(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_record, bool include) + absl::Status ParseUnquotedField(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_record, bool include) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { std::vector earlier_pieces; size_t start = pos_; - Status parse_result; + absl::Status parse_result; while (true) { // Each iter reads 1 char, filling buffer if necessary if (pos_ >= buffer_.size()) { - Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + absl::Status s = + SaveAndFillBuffer(&earlier_pieces, &start, include); // Handle errors if (errors::IsOutOfRange(s)) { // Whatever we have is the last field of the last record @@ -687,10 +693,11 @@ class CSVDatasetOp : public DatasetOpKernel { } } - Status FillBuffer(tstring* result) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + absl::Status FillBuffer(tstring* result) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { result->clear(); ++num_buffer_reads_; - Status s = input_stream_->ReadNBytes( + absl::Status s = input_stream_->ReadNBytes( dataset()->options_.input_buffer_size, result); if (errors::IsOutOfRange(s) && !result->empty()) { @@ -701,8 +708,8 @@ class CSVDatasetOp : public DatasetOpKernel { } // Given a field, converts it to the right output tensor type - Status FieldToOutput(IteratorContext* ctx, StringPiece field, - std::vector* out_tensors) { + absl::Status FieldToOutput(IteratorContext* ctx, StringPiece field, + std::vector* out_tensors) { size_t output_idx = out_tensors->size(); if (output_idx >= dataset()->out_type_.size()) { // We can get here if we're selecting all columns, but the number of @@ -806,7 +813,7 @@ class CSVDatasetOp : public DatasetOpKernel { // linebreak, and ignore it if so. void SkipNewLineIfNecessary() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (pos_ >= buffer_.size()) { - Status s = FillBuffer(&buffer_); + absl::Status s = FillBuffer(&buffer_); pos_ = 0; // If we failed to fill buffer, it doesn't matter because we're done // with the record @@ -820,10 +827,10 @@ class CSVDatasetOp : public DatasetOpKernel { // Given a string field, and its index in the output, // converts it to a Tensor of the right type and adds it to the // out_tensors vector. - Status UnquotedFieldToOutput(IteratorContext* ctx, StringPiece field, - std::vector* out_tensors, - const std::vector& earlier_pieces, - bool include) + absl::Status UnquotedFieldToOutput( + IteratorContext* ctx, StringPiece field, + std::vector* out_tensors, + const std::vector& earlier_pieces, bool include) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (!include) return absl::OkStatus(); @@ -847,7 +854,8 @@ class CSVDatasetOp : public DatasetOpKernel { } // Sets up reader streams to read from the file at `current_file_index_`. - Status SetupStreamsLocked(Env* env) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + absl::Status SetupStreamsLocked(Env* env) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (current_file_index_ >= dataset()->filenames_.size()) { return errors::InvalidArgument( "current_file_index_:", current_file_index_, @@ -878,7 +886,7 @@ class CSVDatasetOp : public DatasetOpKernel { // the first newline because it might contain quoted fields with // newlines in the header as well std::vector empty; - Status s = ReadRecord(nullptr, nullptr, false, empty, empty); + absl::Status s = ReadRecord(nullptr, nullptr, false, empty, empty); if (!s.ok()) { return errors::InvalidArgument("Can't read header of file"); } diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc index a46aa2937b6f1e..7c4e83714c3307 100644 --- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc @@ -152,7 +152,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { ~Dataset() override { iteration_counter_->Unref(); if (owns_resource_) { - Status s = resource_mgr_->Delete( + absl::Status s = resource_mgr_->Delete( iteration_counter_handle_.container(), iteration_counter_handle_.name()); if (!s.ok()) { @@ -190,21 +190,22 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { is_coordinated_read_); } - Status CheckExternalState() const override { - return Status( + absl::Status CheckExternalState() const override { + return absl::Status( absl::StatusCode::kFailedPrecondition, strings::StrCat(DebugString(), " does not yet support serialization.")); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->clear(); return absl::OkStatus(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { // Inputs std::vector inputs; @@ -337,19 +338,21 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { } } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { TF_RETURN_IF_ERROR(RegisterCancellationCallback( ctx->cancellation_manager(), [this]() { data_service_client_.Cancel(); }, &deregister_fn_)); tsl::AllocatorAttributes attrs; - attrs.set_gpu_compatible(ctx->options()->service_options().pinned()); + if (ctx->options() != nullptr) { + attrs.set_gpu_compatible(ctx->options()->service_options().pinned()); + } return data_service_client_.Initialize(ctx->accelerator_device_info(), ctx->allocator(attrs)); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { auto ctx_factory = [ctx, this]() { return std::make_unique( ctx, this, buffer_size_, model_node()); @@ -372,13 +375,13 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { /*max=*/std::numeric_limits::max())}); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { return errors::Unimplemented("SaveInternal is not yet supported"); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { return errors::Unimplemented("RestoreInternal is not yet supported"); } @@ -658,7 +661,7 @@ void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx, OP_REQUIRES_OK( ctx, HandleFromInput(ctx, kIterationCounter, &iteration_counter_handle)); IterationCounter* iteration_counter = nullptr; - Status s = ctx->resource_manager()->Lookup( + absl::Status s = ctx->resource_manager()->Lookup( iteration_counter_handle.container(), iteration_counter_handle.name(), &iteration_counter); bool owns_resource = false; diff --git a/tensorflow/core/kernels/data/experimental/data_service_ops.cc b/tensorflow/core/kernels/data/experimental/data_service_ops.cc index f3434f04625d96..f548afb29af188 100644 --- a/tensorflow/core/kernels/data/experimental/data_service_ops.cc +++ b/tensorflow/core/kernels/data/experimental/data_service_ops.cc @@ -103,8 +103,8 @@ void RegisterDatasetOp::Compute(OpKernelContext* ctx) { params.external_state_policy = external_state_policy_; SerializationContext serialization_ctx(params); DatasetDef dataset_def; - Status s = AsGraphDef(dataset, std::move(serialization_ctx), - dataset_def.mutable_graph()); + absl::Status s = AsGraphDef(dataset, std::move(serialization_ctx), + dataset_def.mutable_graph()); if (!s.ok()) { OP_REQUIRES_OK( ctx, diff --git a/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc index 91b596c6273b66..f8b33402e35363 100644 --- a/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc @@ -120,20 +120,20 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { return n / batch_size_ + (n % batch_size_ == 0 ? 0 : 1); } - Status InputDatasets( + absl::Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); Node* batch_size_node; @@ -156,14 +156,14 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { explicit Iterator(const typename Iterator::Params& params) : DatasetIterator>(params) {} - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { return DatasetIterator>::dataset()->input_->MakeIterator( ctx, this, DatasetIterator>::prefix(), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { // Each row of the output SparseTensor is an individual tensor // from the input iterator. std::vector batch_elements; @@ -295,15 +295,15 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { DatasetIterator>::dataset()->batch_size_); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(Iterator::SaveInput(ctx, writer, input_impl_)); return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(Iterator::RestoreInput(ctx, reader, input_impl_)); return absl::OkStatus(); diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc index 059105c0aaa39b..81b97f275bc3de 100644 --- a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc @@ -82,8 +82,8 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { this, name_utils::IteratorPrefix(kDatasetType, prefix)}); } - Status MakeSplitProviders(std::vector>* - split_providers) const override { + absl::Status MakeSplitProviders(std::vector>* + split_providers) const override { TF_ASSIGN_OR_RETURN(*split_providers, GetSplitProviders(this)); return absl::OkStatus(); } @@ -112,7 +112,8 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { return kUnknownCardinality; } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(selector_input_); for (const auto& data_input : data_inputs_) { inputs->push_back(data_input); @@ -120,7 +121,7 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { for (const auto& input : data_inputs_) { TF_RETURN_IF_ERROR(input->CheckExternalState()); } @@ -128,9 +129,9 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* selector_input_node; TF_RETURN_IF_ERROR( b->AddInputDataset(ctx, selector_input_, &selector_input_node)); @@ -163,7 +164,7 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { bool SymbolicCheckpointCompatible() const override { return true; } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { mutex_lock l(mu_); TF_ASSIGN_OR_RETURN(input_contexts_, CreateInputIteratorContexts(ctx, dataset())); @@ -181,9 +182,9 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); if (!selector_input_impl_) { *end_of_sequence = true; @@ -251,8 +252,8 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { {model::MakeNonTunableParameter(kCycleLength, /*value=*/1)}); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR( writer->WriteScalar(full_name(kSelectorInputImplEmpty), @@ -272,8 +273,8 @@ class DirectedInterleaveDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); int64_t input_empty; TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op_test.cc index e79c64d5750e72..15e2e529b2eb6e 100644 --- a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op_test.cc @@ -49,7 +49,7 @@ class DirectedInterleaveDatasetParams : public DatasetParams { std::vector GetInputTensors() const override { return {}; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { input_names->clear(); input_names->emplace_back( DirectedInterleaveDatasetOp::kSelectorInputDataset); @@ -60,7 +60,7 @@ class DirectedInterleaveDatasetParams : public DatasetParams { return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { attr_vector->clear(); attr_vector->emplace_back(DirectedInterleaveDatasetOp::kOutputTypes, output_dtypes_); diff --git a/tensorflow/core/kernels/data/experimental/distributed_save_op.cc b/tensorflow/core/kernels/data/experimental/distributed_save_op.cc index 7e737026c8d4a4..1a635b2b293758 100644 --- a/tensorflow/core/kernels/data/experimental/distributed_save_op.cc +++ b/tensorflow/core/kernels/data/experimental/distributed_save_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/time/time.h" +#include "xla/tsl/lib/io/compression.h" #include "tensorflow/core/data/serialization_utils.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/dispatcher_client.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/protobuf/snapshot.pb.h" -#include "tsl/lib/io/compression.h" namespace tensorflow { namespace data { @@ -78,8 +78,8 @@ void DistributedSaveOp::Compute(OpKernelContext* ctx) { SerializationContext::Params params(ctx); SerializationContext serialization_ctx(params); DatasetDef dataset_def; - Status s = AsGraphDef(dataset, std::move(serialization_ctx), - dataset_def.mutable_graph()); + absl::Status s = AsGraphDef(dataset, std::move(serialization_ctx), + dataset_def.mutable_graph()); if (!s.ok()) { OP_REQUIRES_OK( ctx, diff --git a/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc index 059719f214f3f0..9f17e21a62032d 100644 --- a/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/group_by_reducer_dataset_op.cc @@ -113,7 +113,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { return "GroupByReducerDatasetOp::Dataset"; } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { TF_RETURN_IF_ERROR(captured_key_func_->CheckExternalState()); TF_RETURN_IF_ERROR(captured_init_func_->CheckExternalState()); TF_RETURN_IF_ERROR(captured_reduce_func_->CheckExternalState()); @@ -122,9 +122,9 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); @@ -200,7 +200,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { explicit Iterator(const Params& params) : DatasetIterator(params) {} - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_)); TF_RETURN_IF_ERROR(dataset()->captured_key_func_->Instantiate( @@ -214,9 +214,9 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { return absl::OkStatus(); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); // Iterate through the input dataset, keying input elements to reducers. @@ -286,8 +286,8 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { return model::MakeUnknownRatioNode(std::move(args)); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( dataset()->captured_key_func_->CheckExternalState())); TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( @@ -344,8 +344,8 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); diff --git a/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc index 508a281f43f49d..de370e0cc46cd4 100644 --- a/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/group_by_window_dataset_op.cc @@ -115,13 +115,13 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { return kUnknownCardinality; } - Status InputDatasets( + absl::Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { TF_RETURN_IF_ERROR(captured_key_func_->CheckExternalState()); TF_RETURN_IF_ERROR(captured_reduce_func_->CheckExternalState()); TF_RETURN_IF_ERROR(captured_window_size_func_->CheckExternalState()); @@ -129,9 +129,9 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); @@ -193,7 +193,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { explicit Iterator(const Params& params) : DatasetIterator(params) {} - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_)); TF_RETURN_IF_ERROR(dataset()->captured_key_func_->Instantiate( @@ -205,9 +205,9 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { return absl::OkStatus(); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); do { if (current_group_iterator_) { @@ -310,8 +310,8 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { return model::MakeUnknownRatioNode(std::move(args)); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( dataset()->captured_key_func_->CheckExternalState())); TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( @@ -374,8 +374,8 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); @@ -435,8 +435,8 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { } private: - Status SaveGroup(IteratorStateWriter* writer, const string& name, - const std::vector>& group) + absl::Status SaveGroup(IteratorStateWriter* writer, const string& name, + const std::vector>& group) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { TF_RETURN_IF_ERROR( writer->WriteScalar(strings::StrCat(name, "_size"), group.size())); @@ -451,9 +451,9 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { return absl::OkStatus(); } - Status RestoreGroup(IteratorContext* ctx, IteratorStateReader* reader, - const string& name, - std::vector>* group) + absl::Status RestoreGroup(IteratorContext* ctx, + IteratorStateReader* reader, const string& name, + std::vector>* group) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { int64_t group_size; TF_RETURN_IF_ERROR( @@ -473,7 +473,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { return absl::OkStatus(); } - Status StartFlushingGroup(IteratorContext* ctx, int64_t key) + absl::Status StartFlushingGroup(IteratorContext* ctx, int64_t key) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { DatasetBase* group_dataset; TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc index 5d0c9ff554b531..ff6757fe7ec129 100644 --- a/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc @@ -68,20 +68,20 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { return kUnknownCardinality; } - Status InputDatasets( + absl::Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); AttrValue log_warning_attr; @@ -98,15 +98,15 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { explicit Iterator(const Params& params) : DatasetIterator(params) {} - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { - Status s; + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + absl::Status s; { tf_shared_lock l(mu_); if (!input_impl_) { @@ -136,8 +136,8 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { /*ratio=*/1); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); if (input_impl_) TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); @@ -147,8 +147,8 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); if (reader->Contains(full_name("input_impls_empty"))) input_impl_.reset(); diff --git a/tensorflow/core/kernels/data/experimental/list_dataset_op.cc b/tensorflow/core/kernels/data/experimental/list_dataset_op.cc index 2852b443205205..fa0faf25b55734 100644 --- a/tensorflow/core/kernels/data/experimental/list_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/list_dataset_op.cc @@ -64,8 +64,8 @@ class ListDatasetOp::Dataset : public DatasetBase { this, name_utils::IteratorPrefix(kDatasetType, prefix)}); } - Status MakeSplitProviders(std::vector>* - split_providers) const override { + absl::Status MakeSplitProviders(std::vector>* + split_providers) const override { split_providers->push_back( std::make_unique(num_elements_)); return absl::OkStatus(); @@ -85,11 +85,12 @@ class ListDatasetOp::Dataset : public DatasetBase { return num_elements_; } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { return absl::OkStatus(); } - Status CheckExternalState() const override { return absl::OkStatus(); } + absl::Status CheckExternalState() const override { return absl::OkStatus(); } absl::Status RandomIndexingCompatible() const override { return absl::OkStatus(); @@ -112,9 +113,9 @@ class ListDatasetOp::Dataset : public DatasetBase { } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { std::vector tensors; tensors.reserve(tensors_.size()); for (const Tensor& t : tensors_) { @@ -144,7 +145,7 @@ class ListDatasetOp::Dataset : public DatasetBase { bool SymbolicCheckpointCompatible() const override { return true; } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { if (ctx->split_providers().empty()) { split_provider_ = std::make_shared(dataset()->num_elements_); @@ -155,9 +156,9 @@ class ListDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { if (ctx->index_mapper() != nullptr) { return global_shuffle_iterator_.GetNext(ctx, out_tensors, end_of_sequence); @@ -184,16 +185,16 @@ class ListDatasetOp::Dataset : public DatasetBase { return model::MakeSourceNode(std::move(args)); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(split_provider_->Save( [this](const std::string& key) { return full_name(key); }, writer)); TF_RETURN_IF_ERROR(global_shuffle_iterator_.Save(prefix(), ctx, writer)); return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { if (ctx->restored_element_count().has_value()) { return global_shuffle_iterator_.Restore(prefix(), ctx, reader); } diff --git a/tensorflow/core/kernels/data/experimental/list_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/list_dataset_op_test.cc index 44e25cdb334a71..6399c5e95723e2 100644 --- a/tensorflow/core/kernels/data/experimental/list_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/experimental/list_dataset_op_test.cc @@ -47,7 +47,7 @@ class ListDatasetParams : public DatasetParams { std::vector GetInputTensors() const override { return tensors_; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { input_names->reserve(tensors_.size()); for (int i = 0; i < tensors_.size(); ++i) { input_names->emplace_back(absl::StrCat("tensors_", i)); @@ -55,7 +55,7 @@ class ListDatasetParams : public DatasetParams { return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { *attr_vector = {{"Tinput_types", input_types_}, {"output_types", output_dtypes_}, {"output_shapes", output_shapes_}, diff --git a/tensorflow/core/kernels/data/experimental/load_dataset_op.cc b/tensorflow/core/kernels/data/experimental/load_dataset_op.cc index a28743164d43a5..6e8c065714e17f 100644 --- a/tensorflow/core/kernels/data/experimental/load_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/load_dataset_op.cc @@ -81,19 +81,20 @@ class LoadDatasetOp::Dataset : public DatasetBase { return metadata_.num_elements(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return captured_reader_func_->CheckExternalState(); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->clear(); return absl::OkStatus(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* path_node = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(path_, &path_node)); @@ -137,7 +138,7 @@ class LoadDatasetOp::Dataset : public DatasetBase { } } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(dataset()->captured_reader_func_->Instantiate( ctx, &instantiated_captured_reader_func_)); @@ -145,9 +146,9 @@ class LoadDatasetOp::Dataset : public DatasetBase { return input_->MakeIterator(ctx, this, prefix(), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); return input_impl_->GetNext(ctx, out_tensors, end_of_sequence); } @@ -158,20 +159,20 @@ class LoadDatasetOp::Dataset : public DatasetBase { return model::MakeUnknownRatioNode(std::move(args)); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); return this->SaveInput(ctx, writer, input_impl_); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); return this->RestoreInput(ctx, reader, input_impl_); } private: - Status InitializeInput(IteratorContext* ctx) + absl::Status InitializeInput(IteratorContext* ctx) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { auto run_dir = snapshot_util::RunDirectory( TranslateFileName(dataset()->path_), dataset()->metadata_.run_id()); diff --git a/tensorflow/core/kernels/data/experimental/lookup_ops.cc b/tensorflow/core/kernels/data/experimental/lookup_ops.cc index 7fff4076d71bd8..904e6b04c8c501 100644 --- a/tensorflow/core/kernels/data/experimental/lookup_ops.cc +++ b/tensorflow/core/kernels/data/experimental/lookup_ops.cc @@ -53,7 +53,7 @@ class DatasetIterator ~DatasetIterator() override {} - Status Init(OpKernelContext* ctx) { + absl::Status Init(OpKernelContext* ctx) { data::IteratorContext::Params params(ctx); function_handle_cache_ = std::make_unique(params.flr); params.function_handle_cache = function_handle_cache_.get(); @@ -88,7 +88,7 @@ class DatasetIterator const Tensor& values() const override { return tensors_[1]; } - Status status() const override { return status_; } + absl::Status status() const override { return status_; } int64_t total_size() const override { int64_t size = dataset_->Cardinality(); @@ -106,7 +106,7 @@ class DatasetIterator std::unique_ptr cancellation_manager_; std::unique_ptr iterator_; std::vector tensors_; - Status status_; + absl::Status status_; }; std::unique_ptr MakeDatasetInitializerSerializer( @@ -170,7 +170,7 @@ void InitializeTableFromDataset(OpKernelContext* ctx, dataset_shapes[1].DebugString())); DatasetIterator iter(dataset); OP_REQUIRES_OK(ctx, iter.Init(ctx)); - Status s = + absl::Status s = table->Initialize(iter, MakeDatasetInitializerSerializer(ctx, dataset)); if (errors::IsFailedPrecondition(s) && table->is_initialized()) { LOG(INFO) << "Table already initialized from dataset."; diff --git a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc index d88506b4176a29..335d18ab0905f3 100644 --- a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op.cc @@ -135,20 +135,21 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { TF_RETURN_IF_ERROR(captured_func_->CheckExternalState()); return input_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* batch_size_node; @@ -216,7 +217,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { bool SymbolicCheckpointCompatible() const override { return true; } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { mutex_lock l(*mu_); interleave_depth_ = ctx->interleave_depth(); @@ -241,9 +242,9 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { std::shared_ptr result; { mutex_lock l(*mu_); @@ -290,8 +291,8 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { /*max=*/ctx->runner_threadpool_size())}); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( dataset()->captured_func_->CheckExternalState())); if (ctx->symbolic_checkpoint()) { @@ -316,8 +317,8 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(*mu_); DCHECK(!runner_thread_); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); @@ -381,7 +382,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { // (required to make the behavior is observably identical to a // sequential execution of map followed by batch), we must also keep // track of the offset into the batch that produced `s`. - void UpdateStatus(const Status& s, int64_t offset) { + void UpdateStatus(const absl::Status& s, int64_t offset) { if (TF_PREDICT_FALSE(!s.ok())) { mutex_lock l(mu); if (status.ok() || offset < status_offset) { @@ -396,7 +397,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { int64_t num_elements TF_GUARDED_BY(mu); std::vector output; bool output_allocated TF_GUARDED_BY(mu); - Status status TF_GUARDED_BY(mu); + absl::Status status TF_GUARDED_BY(mu); int64_t status_offset TF_GUARDED_BY(mu); // Counts the number of outstanding calls for this batch. int64_t num_calls TF_GUARDED_BY(&Iterator::mu_); @@ -431,7 +432,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { // Get the next input element. std::vector input_element; bool end_of_input = false; - Status status = + absl::Status status = input_impl_->GetNext(ctx.get(), &input_element, &end_of_input); bool return_early; { @@ -448,7 +449,8 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { std::shared_ptr> return_values = std::make_shared>(); - auto done = [this, ctx, result, return_values, offset](Status status) { + auto done = [this, ctx, result, return_values, + offset](absl::Status status) { if (dataset()->preserve_cardinality_ && errors::IsOutOfRange(status)) { // To guarantee that the transformation preserves the cardinality of // the dataset, we convert `OutOfRange` to `InvalidArgument` as the @@ -459,7 +461,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { } result->UpdateStatus(status, offset); if (status.ok()) { - Status allocate_status = + absl::Status allocate_status = EnsureOutputAllocated(ctx, result, return_values); if (!allocate_status.ok()) { result->UpdateStatus(allocate_status, offset); @@ -483,7 +485,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { // TODO(mrry): Add a version of DoParallelConcat that allows us // to move `tensor` where possible, to speed up string tensor // batching. - Status copy_status = batch_util::CopyElementToSlice( + absl::Status copy_status = batch_util::CopyElementToSlice( std::move(tensor), batch, offset); if (!copy_status.ok()) { result->UpdateStatus(copy_status, offset); @@ -527,7 +529,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { } } - Status EnsureOutputAllocated( + absl::Status EnsureOutputAllocated( const std::shared_ptr& ctx, const std::shared_ptr& result, const std::shared_ptr>& return_values) { @@ -620,8 +622,9 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { } } - Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader, - size_t index) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { + absl::Status ReadBatchResult(IteratorContext* ctx, + IteratorStateReader* reader, size_t index) + TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { batch_results_.push_back( std::make_shared(dataset()->batch_size_, ctx)); std::shared_ptr result = batch_results_.back(); @@ -649,7 +652,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status WriteBatchResult(IteratorStateWriter* writer, size_t index) + absl::Status WriteBatchResult(IteratorStateWriter* writer, size_t index) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { std::shared_ptr result = batch_results_[index]; string batch_prefix = strings::StrCat(kBatchResults, "_", index); diff --git a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op_test.cc index a340bd20758a48..e0e58aa0391379 100644 --- a/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/experimental/map_and_batch_dataset_op_test.cc @@ -55,7 +55,7 @@ class MapAndBatchDatasetParams : public DatasetParams { return inputs; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { input_names->reserve(input_dataset_params_.size() + other_arguments_.size() + 3); input_names->emplace_back(MapAndBatchDatasetOp::kInputDataset); @@ -70,7 +70,7 @@ class MapAndBatchDatasetParams : public DatasetParams { return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { *attr_vector = {{"f", func_}, {"Targuments", type_arguments_}, {"output_shapes", output_shapes_}, diff --git a/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc b/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc index 403439b1604277..4c0184e1b4b36e 100644 --- a/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/matching_files_dataset_op.cc @@ -82,17 +82,19 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { return "MatchingFilesDatasetOp::Dataset"; } - Status InputDatasets( + absl::Status InputDatasets( std::vector* inputs) const override { return absl::OkStatus(); } - Status CheckExternalState() const override { return absl::OkStatus(); } + absl::Status CheckExternalState() const override { + return absl::OkStatus(); + } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* patterns_node = nullptr; TF_RETURN_IF_ERROR(b->AddVector(patterns_, &patterns_node)); TF_RETURN_IF_ERROR(b->AddDataset(this, {patterns_node}, output)); @@ -105,9 +107,9 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { explicit Iterator(const Params& params) : DatasetIterator(params) {} - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); FileSystem* fs; @@ -197,8 +199,8 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { return model::MakeSourceNode(std::move(args)); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar( full_name("current_pattern_index"), current_pattern_index_)); @@ -229,8 +231,8 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); int64_t current_pattern_index; TF_RETURN_IF_ERROR(reader->ReadScalar( @@ -272,15 +274,15 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { } private: - Status UpdateIterator(IteratorContext* ctx, FileSystem* fs, - const string& dir, const string& eval_pattern) + absl::Status UpdateIterator(IteratorContext* ctx, FileSystem* fs, + const string& dir, const string& eval_pattern) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { StringPiece fixed_prefix = StringPiece(eval_pattern) .substr(0, eval_pattern.find_first_of("*?[\\")); filepath_queue_.push(PathStatus(dir, true)); - Status ret; // Status to return + absl::Status ret; // Status to return // DFS to find the first element in the iterator. while (!filepath_queue_.empty()) { @@ -312,7 +314,7 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { // three possible values: OK for true; FAILED_PRECONDITION for false; // CANCELLED if we don't calculate IsDirectory (we might do that // because there isn't any point in exploring that child path). - std::vector children_dir_status; + std::vector children_dir_status; children_dir_status.resize(children.size()); // This IsDirectory call can be expensive for some FS. Parallelizing @@ -342,7 +344,7 @@ class MatchingFilesDatasetOp : public DatasetOpKernel { for (int i = 0; i < children.size(); i++) { const string& child_dir_path = io::JoinPath(current_dir, children[i]); - const Status& child_dir_status = children_dir_status[i]; + const absl::Status& child_dir_status = children_dir_status[i]; // If the IsDirectory call was cancelled we bail. if (child_dir_status.code() == tensorflow::error::CANCELLED) { diff --git a/tensorflow/core/kernels/data/experimental/non_serializable_dataset_op.cc b/tensorflow/core/kernels/data/experimental/non_serializable_dataset_op.cc index 1649bb7d54a93b..dd9725f0d58ef2 100644 --- a/tensorflow/core/kernels/data/experimental/non_serializable_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/non_serializable_dataset_op.cc @@ -69,20 +69,20 @@ class NonSerializableDatasetOp : public UnaryDatasetOpKernel { return "NonSerializableDatasetOp::Dataset"; } - Status InputDatasets( + absl::Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { return errors::Unimplemented(DebugString(), " does not support serialization."); } @@ -97,14 +97,14 @@ class NonSerializableDatasetOp : public UnaryDatasetOpKernel { explicit Iterator(const Params& params) : DatasetIterator(params) {} - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { return input_impl_->GetNext(ctx, out_tensors, end_of_sequence); } @@ -114,14 +114,14 @@ class NonSerializableDatasetOp : public UnaryDatasetOpKernel { return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); return absl::OkStatus(); } diff --git a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc index cd9bb2aa6d4147..662adc5295bfc4 100644 --- a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc @@ -149,20 +149,21 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { return name_utils::DatasetDebugString(kDatasetType, params); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { TF_RETURN_IF_ERROR(captured_func_->CheckExternalState()); return input_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { std::vector> inputs; std::vector>> list_inputs; int input_index = 0; @@ -290,7 +291,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { // TODO(jsimsa): Register cancellation callback once the implementation is // refactored not to hold mu_ while calling `GetNext` on the input. - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { cancellation_manager_ = std::make_unique(); IteratorContext::Params params(ctx); params.cancellation_manager = cancellation_manager_.get(); @@ -303,9 +304,9 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { // It is implemented so that it matches the deterministic interleave // unless getting the next element would block and we are allowed to be // nondeterministic. - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(EnsureWorkerThreadsStarted(ctx)); while (!cancelled_) { @@ -339,7 +340,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { block_count_ = 0; } *end_of_sequence = false; - Status s = current_worker->outputs.front().status; + absl::Status s = current_worker->outputs.front().status; tsl::profiler::TraceMe traceme([&] { return tsl::profiler::TraceMeEncode( "ParallelInterleaveConsume", @@ -365,7 +366,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { // Start prefetching a new iterator. std::vector args; bool end_of_input = false; - Status s = input_impl_->GetNext(ctx, &args, &end_of_input); + absl::Status s = input_impl_->GetNext(ctx, &args, &end_of_input); if (end_of_input) { input_impl_.reset(); } else { @@ -420,8 +421,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { kDeterministic, deterministic_ ? 1.0 : 0.0)}); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( dataset()->captured_func_->CheckExternalState())); // The order of locking is important here to avoid deadlock. @@ -465,8 +466,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { { // The order of locking is important here to avoid deadlock. mutex_lock l(mu_); @@ -494,12 +495,13 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { } std::unique_ptr threadpool = ctx->CreateThreadPool( "read_worker_thread_state", dataset()->num_threads()); - Status s = absl::OkStatus(); + absl::Status s = absl::OkStatus(); BlockingCounter counter(dataset()->num_threads()); for (size_t i = 0; i < dataset()->num_threads(); ++i) { threadpool->Schedule([this, i, ctx, reader, &s, &counter] { WorkerThreadState state; - Status result = ReadWorkerThreadStateLocked(ctx, reader, i, &state); + absl::Status result = + ReadWorkerThreadStateLocked(ctx, reader, i, &state); mutex_lock l(mu_); mutex_lock ckpt_l(ckpt_mu_); if (!result.ok()) { @@ -585,13 +587,13 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { struct OutputElem { // The output iterator sets `status` if getting the output element // fails. - Status status; + absl::Status status; // The buffered data element. std::vector output; int64_t id = -1; - explicit OutputElem(const Status& s) : status(s) {} - OutputElem(const Status& s, int64_t id) : status(s), id(id) {} + explicit OutputElem(const absl::Status& s) : status(s) {} + OutputElem(const absl::Status& s, int64_t id) : status(s), id(id) {} }; // Worker threads operate on their relevant WorkerState structs. @@ -620,7 +622,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { } // Sets inputs for a worker thread and notifies it to start processing. - void SetInputs(const Status& s, std::vector input_arguments) { + void SetInputs(const absl::Status& s, + std::vector input_arguments) { if (s.ok()) { DCHECK(!MayHaveElements()) << "Tried to start inputs, despite already producing!"; @@ -648,7 +651,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { bool end_of_sequence = false; // Status returned from `MakeIteratorFromInputElement`. - Status iterator_creation_status; + absl::Status iterator_creation_status; // The arguments to be used to construct `iterator`. std::vector input; @@ -667,14 +670,14 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { } } - Status EnsureWorkerThreadsStarted(IteratorContext* ctx) + absl::Status EnsureWorkerThreadsStarted(IteratorContext* ctx) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (worker_threads_.empty() && input_impl_) { worker_threads_.reserve(dataset()->num_threads()); for (int64_t i = 0; i < dataset()->num_threads(); ++i) { std::vector args; bool end_of_input = false; - Status s = input_impl_->GetNext(ctx, &args, &end_of_input); + absl::Status s = input_impl_->GetNext(ctx, &args, &end_of_input); if (end_of_input) { input_impl_.reset(); return absl::OkStatus(); @@ -748,7 +751,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { // this function for details. while (true) { // Whether creation of the iterator succeeded. - Status iterator_creation_status; + absl::Status iterator_creation_status; // 1. Build a new iterator or use the existing one. if (make_new_iterator) { // 1a. Get new input tensors or use the exiting ones. @@ -944,7 +947,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { } } - Status WriteWorkerStateLocked(IteratorStateWriter* writer, int index) + absl::Status WriteWorkerStateLocked(IteratorStateWriter* writer, int index) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) { string iterator_name = strings::StrCat(prefix(), "::", kWorker, "_", index); @@ -969,8 +972,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status ReadWorkerStateLocked(IteratorContext* ctx, - IteratorStateReader* reader, int index) + absl::Status ReadWorkerStateLocked(IteratorContext* ctx, + IteratorStateReader* reader, int index) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) { string worker_prefix = strings::StrCat(prefix(), "::", kWorker, "_", index); @@ -1002,8 +1005,9 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status WriteWorkerThreadStateLocked(SerializationContext* ctx, - IteratorStateWriter* writer, int index) + absl::Status WriteWorkerThreadStateLocked(SerializationContext* ctx, + IteratorStateWriter* writer, + int index) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) { string iterator_name = strings::StrCat(prefix(), "::", kWorkerThread, "_", index); @@ -1035,9 +1039,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status ReadWorkerThreadStateLocked(IteratorContext* ctx, - IteratorStateReader* reader, int index, - WorkerThreadState* state) { + absl::Status ReadWorkerThreadStateLocked(IteratorContext* ctx, + IteratorStateReader* reader, + int index, + WorkerThreadState* state) { string worker_prefix = strings::StrCat(prefix(), "::", kWorkerThread, "_", index); // Restore inputs. @@ -1076,10 +1081,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status WriteOutputElemLocked(IteratorStateWriter* writer, - const OutputElem& output_elem, - const string& iterator_name, - const string& prefix) + absl::Status WriteOutputElemLocked(IteratorStateWriter* writer, + const OutputElem& output_elem, + const string& iterator_name, + const string& prefix) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) { TF_RETURN_IF_ERROR(WriteStatusLocked( writer, iterator_name, strings::StrCat(prefix, "_", kStatus), @@ -1095,11 +1100,11 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status ReadOutputElemLocked(IteratorContext* ctx, - IteratorStateReader* reader, - OutputElem* output_elem, - const string& iterator_name, - const string& prefix) { + absl::Status ReadOutputElemLocked(IteratorContext* ctx, + IteratorStateReader* reader, + OutputElem* output_elem, + const string& iterator_name, + const string& prefix) { TF_RETURN_IF_ERROR(ReadStatusLocked(reader, iterator_name, strings::StrCat(prefix, "_", kStatus), &output_elem->status)); @@ -1118,9 +1123,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status WriteStatusLocked(IteratorStateWriter* writer, - const string& iterator_name, const string& prefix, - const Status& status) + absl::Status WriteStatusLocked(IteratorStateWriter* writer, + const string& iterator_name, + const string& prefix, + const absl::Status& status) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) { TF_RETURN_IF_ERROR(writer->WriteScalar( iterator_name, strings::StrCat(prefix, "_", kCode), @@ -1133,9 +1139,9 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status ReadStatusLocked(IteratorStateReader* reader, - const string& iterator_name, const string& prefix, - Status* status) { + absl::Status ReadStatusLocked(IteratorStateReader* reader, + const string& iterator_name, + const string& prefix, absl::Status* status) { int64_t code_int; TF_RETURN_IF_ERROR(reader->ReadScalar( iterator_name, strings::StrCat(prefix, "_", kCode), &code_int)); @@ -1146,7 +1152,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(reader->ReadScalar( iterator_name, strings::StrCat(prefix, "_", KMessage), &error_message)); - *status = Status(code, error_message); + *status = absl::Status(code, error_message); } else { *status = absl::OkStatus(); } diff --git a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op_test.cc index 1ba4c37e0513a8..26de0a17c9ec48 100644 --- a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op_test.cc @@ -66,7 +66,7 @@ class ParallelInterleaveDatasetParams : public DatasetParams { return input_tensors; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { input_names->emplace_back(ParallelInterleaveDatasetOp::kInputDataset); for (int i = 0; i < other_arguments_.size(); ++i) { input_names->emplace_back( @@ -81,7 +81,7 @@ class ParallelInterleaveDatasetParams : public DatasetParams { return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { *attr_vector = {{"f", func_}, {"deterministic", deterministic_}, {"Targuments", type_arguments_}, diff --git a/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc index 03a1a42c1f7a57..5c7c6013ae8aad 100644 --- a/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/parse_example_dataset_op.cc @@ -273,20 +273,20 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { return input_->Cardinality(options); } - Status InputDatasets( + absl::Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); @@ -379,7 +379,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { if (deregister_fn_) deregister_fn_(); } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { mutex_lock l(*mu_); if (num_parallel_calls_->value == model::kAutotune) { num_parallel_calls_->value = GetAutotuneDefaultParallelism(ctx); @@ -391,9 +391,9 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { std::shared_ptr result; { mutex_lock l(*mu_); @@ -427,8 +427,8 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { /*max=*/ctx->runner_threadpool_size())}); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(*mu_); // Wait for all in-flight calls to complete. while (num_calls_ > 0) { @@ -465,8 +465,8 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(*mu_); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); int64_t invocation_results_size; @@ -537,7 +537,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { explicit InvocationResult(int64_t id) : id(id) {} Notification notification; - Status status; + absl::Status status; std::vector return_values; bool end_of_input = false; int64_t id = -1; @@ -594,7 +594,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { return; } - auto done = [this, ctx, result](Status status) { + auto done = [this, ctx, result](absl::Status status) { result->status.Update(status); CallCompleted(ctx, result); }; @@ -626,8 +626,8 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { RecordStart(ctx.get()); } - Status CheckOutputTensor(const Tensor& tensor, size_t value_index, - size_t output_index) const { + absl::Status CheckOutputTensor(const Tensor& tensor, size_t value_index, + size_t output_index) const { if (tensor.dtype() != dataset()->output_dtypes()[output_index]) { return errors::InvalidArgument( "Got wrong type for FastParseExample return value ", value_index, @@ -646,8 +646,8 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { return absl::OkStatus(); } - Status ParseExample(IteratorContext* ctx, std::vector input, - std::vector* output) { + absl::Status ParseExample(IteratorContext* ctx, std::vector input, + std::vector* output) { thread::ThreadPool* device_threadpool = ctx->flr()->device()->tensorflow_cpu_worker_threads()->workers; std::vector slice_vec; @@ -725,10 +725,10 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { return absl::OkStatus(); } - Status ProcessResult(IteratorContext* ctx, - const std::shared_ptr& result, - std::vector* out_tensors, - bool* end_of_sequence) TF_LOCKS_EXCLUDED(*mu_) { + absl::Status ProcessResult( + IteratorContext* ctx, const std::shared_ptr& result, + std::vector* out_tensors, bool* end_of_sequence) + TF_LOCKS_EXCLUDED(*mu_) { if (!result->end_of_input && result->status.ok()) { *out_tensors = std::move(result->return_values); RecordBufferDequeue(ctx, *out_tensors); @@ -849,8 +849,8 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { } } - Status WriteStatusLocked(IteratorStateWriter* writer, size_t index, - const Status& status) + absl::Status WriteStatusLocked(IteratorStateWriter* writer, size_t index, + const absl::Status& status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { TF_RETURN_IF_ERROR(writer->WriteScalar( CodeKey(index), static_cast(status.code()))); @@ -861,8 +861,8 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { return absl::OkStatus(); } - Status ReadStatusLocked(IteratorStateReader* reader, size_t index, - Status* status) + absl::Status ReadStatusLocked(IteratorStateReader* reader, size_t index, + absl::Status* status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { int64_t code_int; TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int)); @@ -872,7 +872,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { tstring error_message; TF_RETURN_IF_ERROR( reader->ReadScalar(ErrorMessageKey(index), &error_message)); - *status = Status(code, error_message); + *status = absl::Status(code, error_message); } else { *status = absl::OkStatus(); } diff --git a/tensorflow/core/kernels/data/experimental/random_access_ops.cc b/tensorflow/core/kernels/data/experimental/random_access_ops.cc index b4c26b8136ed2c..e7b6a391ce94f3 100644 --- a/tensorflow/core/kernels/data/experimental/random_access_ops.cc +++ b/tensorflow/core/kernels/data/experimental/random_access_ops.cc @@ -27,7 +27,7 @@ namespace tensorflow { namespace data { namespace experimental { -Status GetElementAtIndexOp::DoCompute(OpKernelContext* ctx) { +absl::Status GetElementAtIndexOp::DoCompute(OpKernelContext* ctx) { DatasetBase* dataset; TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset)); diff --git a/tensorflow/core/kernels/data/experimental/random_access_ops.h b/tensorflow/core/kernels/data/experimental/random_access_ops.h index ddc5e9dabf53f3..293cb99c12380e 100644 --- a/tensorflow/core/kernels/data/experimental/random_access_ops.h +++ b/tensorflow/core/kernels/data/experimental/random_access_ops.h @@ -49,7 +49,7 @@ class GetElementAtIndexOp : public AsyncOpKernel { } protected: - Status DoCompute(OpKernelContext* ctx); + absl::Status DoCompute(OpKernelContext* ctx); private: UnboundedThreadPool unbounded_threadpool_; diff --git a/tensorflow/core/kernels/data/experimental/random_dataset_op.cc b/tensorflow/core/kernels/data/experimental/random_dataset_op.cc index 31443f9cf76d2d..dd699bcdee272f 100644 --- a/tensorflow/core/kernels/data/experimental/random_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/random_dataset_op.cc @@ -70,7 +70,7 @@ class RandomDatasetOp::Dataset : public DatasetBase { ~Dataset() override { manager_->Unref(); if (owns_resource_) { - Status s = resource_mgr_->Delete( + absl::Status s = resource_mgr_->Delete( resource_handle_.container(), resource_handle_.name()); if (!s.ok()) { LOG(WARNING) << "Failed to delete RNG resource: " << s; @@ -78,8 +78,8 @@ class RandomDatasetOp::Dataset : public DatasetBase { } } - Status MakeSplitProviders(std::vector>* - split_providers) const override { + absl::Status MakeSplitProviders(std::vector>* + split_providers) const override { // We use kint64 to generate an effectively infinite number of "splits". // These splits aren't actually used during iteration. // TODO(aaudibert): Avoid sending dummy splits over RPC when using tf.data @@ -118,16 +118,17 @@ class RandomDatasetOp::Dataset : public DatasetBase { return kInfiniteCardinality; } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { return absl::OkStatus(); } - Status CheckExternalState() const override { return absl::OkStatus(); } + absl::Status CheckExternalState() const override { return absl::OkStatus(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* seed_node = nullptr; Node* seed2_node = nullptr; TF_RETURN_IF_ERROR(b->AddScalar(seeds_.input_seed(), &seed_node)); @@ -159,16 +160,16 @@ class RandomDatasetOp::Dataset : public DatasetBase { bool SymbolicCheckpointCompatible() const override { return true; } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { mutex_lock l(mu_); seed_generator_->GenerateSeeds(&seed_, &seed2_); ResetRngs(); return absl::OkStatus(); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { out_tensors->reserve(1); mutex_lock l(mu_); out_tensors->emplace_back(ctx->allocator({}), DT_INT64, TensorShape({})); @@ -182,8 +183,8 @@ class RandomDatasetOp::Dataset : public DatasetBase { return model::MakeSourceNode(std::move(args)); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); // Save state needed to restore the random number generators. TF_RETURN_IF_ERROR( @@ -196,8 +197,8 @@ class RandomDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); // Restore the random number generators. int64_t num_random_samples; @@ -278,7 +279,7 @@ void RandomDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) { bool owns_resource = true; if (op_version_ == 2) { OP_REQUIRES_OK(ctx, HandleFromInput(ctx, 2, &handle)); - Status s = ctx->resource_manager()->Lookup( + absl::Status s = ctx->resource_manager()->Lookup( handle.container(), handle.name(), &manager); owns_resource = false; if (errors::IsNotFound(s)) { diff --git a/tensorflow/core/kernels/data/experimental/random_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/random_dataset_op_test.cc index 35bd077fb8bd8f..b099c7caea2365 100644 --- a/tensorflow/core/kernels/data/experimental/random_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/experimental/random_dataset_op_test.cc @@ -98,7 +98,7 @@ class RandomDatasetParams : public DatasetParams { return {seed_, seed2_, seed_generator_resource_}; } - virtual Status GetInputNames( + virtual absl::Status GetInputNames( std::vector* input_names) const override { *input_names = {RandomDatasetOp::kSeed, RandomDatasetOp::kSeed2}; if (op_version_ == 2) { @@ -107,7 +107,8 @@ class RandomDatasetParams : public DatasetParams { return absl::OkStatus(); } - virtual Status GetAttributes(AttributeVector* attributes) const override { + virtual absl::Status GetAttributes( + AttributeVector* attributes) const override { *attributes = {{"output_types", output_dtypes_}, {"output_shapes", output_shapes_}, {"metadata", ""}}; diff --git a/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc index e6c02d226756ed..711dcc0fe30437 100644 --- a/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc @@ -94,20 +94,20 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel { return name_utils::DatasetDebugString(kDatasetTypeV1, params); } - Status InputDatasets( + absl::Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* num_replicas = nullptr; @@ -125,14 +125,14 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel { ~Iterator() override {} - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); *end_of_sequence = false; if (slice_number_ % dataset()->num_replicas_ == 0) { @@ -189,8 +189,8 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel { } protected: - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); if (!input_impl_) { TF_RETURN_IF_ERROR( @@ -212,8 +212,8 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); if (!reader->Contains(full_name("input_impl_empty"))) { TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); @@ -355,20 +355,20 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel { return name_utils::DatasetDebugString(kDatasetTypeV2); } - Status InputDatasets( + absl::Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* batch_sizes = nullptr; @@ -388,14 +388,14 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel { ~Iterator() override {} - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); if (end_of_sequence_) { *end_of_sequence = true; @@ -571,8 +571,8 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel { } protected: - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); if (!input_impl_) { TF_RETURN_IF_ERROR( @@ -592,8 +592,8 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); if (!reader->Contains(full_name("input_impl_empty"))) { TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); @@ -621,7 +621,7 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel { } private: - Status ValidateInputTensors() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + absl::Status ValidateInputTensors() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { for (size_t i = 0; i < tensors_.size(); ++i) { if (tensors_[i].dims() == 0) { return errors::InvalidArgument( diff --git a/tensorflow/core/kernels/data/experimental/sampling_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sampling_dataset_op.cc index ca84cf49ee33f6..71ff3626b290d3 100644 --- a/tensorflow/core/kernels/data/experimental/sampling_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/sampling_dataset_op.cc @@ -70,19 +70,20 @@ class SamplingDatasetOp::Dataset : public DatasetBase { return name_utils::DatasetDebugString(kDatasetType); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* rate = nullptr; @@ -105,13 +106,13 @@ class SamplingDatasetOp::Dataset : public DatasetBase { parent_generator_(seeds_.first, seeds_.second), generator_(&parent_generator_) {} - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { bool rand_val_hit; do { { @@ -150,8 +151,8 @@ class SamplingDatasetOp::Dataset : public DatasetBase { generator_.Skip(num_random_samples_); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); // Save state needed to restore the random number generators. TF_RETURN_IF_ERROR(writer->WriteScalar( @@ -170,8 +171,8 @@ class SamplingDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); // Restore the random number generators. TF_RETURN_IF_ERROR(reader->ReadScalar( diff --git a/tensorflow/core/kernels/data/experimental/sampling_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/sampling_dataset_op_test.cc index 3c0f3dac7b9bc9..6a17952ca867f5 100644 --- a/tensorflow/core/kernels/data/experimental/sampling_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/experimental/sampling_dataset_op_test.cc @@ -46,14 +46,14 @@ class SamplingDatasetParams : public DatasetParams { return {rate, seed_tensor, seed2_tensor}; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { *input_names = {SamplingDatasetOp::kInputDataset, SamplingDatasetOp::kRate, SamplingDatasetOp::kSeed, SamplingDatasetOp::kSeed2}; return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { *attr_vector = {{SamplingDatasetOp::kOutputTypes, output_dtypes_}, {SamplingDatasetOp::kOutputShapes, output_shapes_}}; return absl::OkStatus(); diff --git a/tensorflow/core/kernels/data/experimental/save_dataset_op.cc b/tensorflow/core/kernels/data/experimental/save_dataset_op.cc index 0110618143c85f..2a6364c7e61c5a 100644 --- a/tensorflow/core/kernels/data/experimental/save_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/save_dataset_op.cc @@ -66,7 +66,7 @@ SaveDatasetOp::SaveDatasetOp(OpKernelConstruction* ctx) OP_REQUIRES_OK(ctx, ctx->GetAttr(kUseShardFunc, &use_shard_func_)); } -Status SaveDatasetOp::DoCompute(OpKernelContext* ctx) { +absl::Status SaveDatasetOp::DoCompute(OpKernelContext* ctx) { metrics::RecordTFDataFetchOp("SaveDatasetOp"); DatasetBase* dataset; TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset)); @@ -95,10 +95,10 @@ Status SaveDatasetOp::DoCompute(OpKernelContext* ctx) { return absl::OkStatus(); } -Status SaveDatasetOp::WriteData(OpKernelContext* ctx, DatasetBase* dataset, - std::unique_ptr captured_func, - const std::string& run_dir, - uint64* num_elements) { +absl::Status SaveDatasetOp::WriteData( + OpKernelContext* ctx, DatasetBase* dataset, + std::unique_ptr captured_func, const std::string& run_dir, + uint64* num_elements) { IteratorContext::Params params(ctx); auto function_handle_cache = std::make_unique(params.flr); @@ -121,7 +121,7 @@ Status SaveDatasetOp::WriteData(OpKernelContext* ctx, DatasetBase* dataset, &iter_ctx, /*parent=*/nullptr, "Save", &iterator)); mutex mu; - Status status; + absl::Status status; absl::flat_hash_map> writers; while (true) { @@ -148,7 +148,7 @@ Status SaveDatasetOp::WriteData(OpKernelContext* ctx, DatasetBase* dataset, auto writer_thread = std::make_unique( ctx->env(), shard_index, snapshot_shard_directory, /*checkpoint_id=*/0, compression_, kFileFormatVersion, - finalized_dataset->output_dtypes(), [&mu, &status](Status s) { + finalized_dataset->output_dtypes(), [&mu, &status](absl::Status s) { mutex_lock l(mu); status.Update(s); }); @@ -167,10 +167,9 @@ Status SaveDatasetOp::WriteData(OpKernelContext* ctx, DatasetBase* dataset, return status; } -Status SaveDatasetOp::GetShardIndex(IteratorContext* ctx, - InstantiatedCapturedFunction* function, - const std::vector& element, - int64_t* shard_index) { +absl::Status SaveDatasetOp::GetShardIndex( + IteratorContext* ctx, InstantiatedCapturedFunction* function, + const std::vector& element, int64_t* shard_index) { if (!use_shard_func_) { *shard_index = (*shard_index + 1) % GetCpuBudget(); return absl::OkStatus(); @@ -187,10 +186,9 @@ Status SaveDatasetOp::GetShardIndex(IteratorContext* ctx, return absl::OkStatus(); } -Status SaveDatasetOp::WriteMetadataFile(Env* env, const std::string& path, - uint64 run_id, - const DataTypeVector& output_dtypes, - uint64 num_elements, bool finalized) { +absl::Status SaveDatasetOp::WriteMetadataFile( + Env* env, const std::string& path, uint64 run_id, + const DataTypeVector& output_dtypes, uint64 num_elements, bool finalized) { SnapshotMetadataRecord metadata; metadata.set_creation_timestamp(EnvTime::NowMicros()); metadata.set_run_id( @@ -242,19 +240,20 @@ class SaveDatasetV2Op::Dataset : public DatasetBase { return input_->Cardinality(options); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); @@ -318,7 +317,7 @@ class SaveDatasetV2Op::Dataset : public DatasetBase { SignalEOF(true); } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR( dataset()->shard_func_->Instantiate(ctx, &instantiated_shard_func_)); @@ -337,9 +336,9 @@ class SaveDatasetV2Op::Dataset : public DatasetBase { return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { *end_of_sequence = false; snapshot_util::AsyncWriter* current_writer; @@ -384,7 +383,8 @@ class SaveDatasetV2Op::Dataset : public DatasetBase { auto writer = std::make_unique( ctx->env(), shard_index, snapshot_shard_directory, current_checkpoint_id_, dataset()->compression_, - kFileFormatVersion, dataset()->output_dtypes(), [this](Status s) { + kFileFormatVersion, dataset()->output_dtypes(), + [this](absl::Status s) { if (!s.ok()) { mutex_lock l(writer_status_mu_); writer_status_ = s; @@ -400,8 +400,8 @@ class SaveDatasetV2Op::Dataset : public DatasetBase { } protected: - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunId), static_cast(run_id_))); @@ -414,8 +414,8 @@ class SaveDatasetV2Op::Dataset : public DatasetBase { return SaveInput(ctx, writer, input_impl_); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); int64_t run_id_signed; int64_t current_checkpoint_id; @@ -440,10 +440,10 @@ class SaveDatasetV2Op::Dataset : public DatasetBase { } private: - Status GetShardIndex(IteratorContext* ctx, - InstantiatedCapturedFunction* function, - const std::vector& element, - bool use_shard_func, int64_t* shard_index) + absl::Status GetShardIndex(IteratorContext* ctx, + InstantiatedCapturedFunction* function, + const std::vector& element, + bool use_shard_func, int64_t* shard_index) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (!use_shard_func) { *shard_index = (*shard_index + 1) % GetCpuBudget(); @@ -462,9 +462,10 @@ class SaveDatasetV2Op::Dataset : public DatasetBase { return absl::OkStatus(); } - Status WriteMetadataFile(Env* env, const std::string& path, uint64 run_id, - const DataTypeVector& output_dtypes, - uint64 num_elements, bool finalized) + absl::Status WriteMetadataFile(Env* env, const std::string& path, + uint64 run_id, + const DataTypeVector& output_dtypes, + uint64 num_elements, bool finalized) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { SnapshotMetadataRecord metadata; metadata.set_creation_timestamp(EnvTime::NowMicros()); @@ -497,7 +498,7 @@ class SaveDatasetV2Op::Dataset : public DatasetBase { absl::flat_hash_map> writers_ TF_GUARDED_BY(mu_); - Status writer_status_ TF_GUARDED_BY(writer_status_mu_); + absl::Status writer_status_ TF_GUARDED_BY(writer_status_mu_); bool writers_closed_ TF_GUARDED_BY(mu_); uint64 run_id_ TF_GUARDED_BY(mu_); diff --git a/tensorflow/core/kernels/data/experimental/save_dataset_op.h b/tensorflow/core/kernels/data/experimental/save_dataset_op.h index 212a1d7b62b34d..77478d4ee7e5a7 100644 --- a/tensorflow/core/kernels/data/experimental/save_dataset_op.h +++ b/tensorflow/core/kernels/data/experimental/save_dataset_op.h @@ -42,25 +42,26 @@ class SaveDatasetOp : public HybridAsyncOpKernel { explicit SaveDatasetOp(OpKernelConstruction* ctx); - Status DoCompute(OpKernelContext* ctx) override; + absl::Status DoCompute(OpKernelContext* ctx) override; private: static constexpr const int kFileFormatVersion = 2; - Status ConsumeElement(); + absl::Status ConsumeElement(); - Status GetShardIndex(IteratorContext* ctx, - InstantiatedCapturedFunction* function, - const std::vector& element, - int64_t* shard_index); + absl::Status GetShardIndex(IteratorContext* ctx, + InstantiatedCapturedFunction* function, + const std::vector& element, + int64_t* shard_index); - Status WriteData(OpKernelContext* ctx, DatasetBase* dataset, - std::unique_ptr captured_func, - const std::string& run_dir, uint64* num_elements); + absl::Status WriteData(OpKernelContext* ctx, DatasetBase* dataset, + std::unique_ptr captured_func, + const std::string& run_dir, uint64* num_elements); - Status WriteMetadataFile(Env* env, const std::string& path, uint64 run_id, - const DataTypeVector& output_dtypes, - uint64 num_elements, bool finalized); + absl::Status WriteMetadataFile(Env* env, const std::string& path, + uint64 run_id, + const DataTypeVector& output_dtypes, + uint64 num_elements, bool finalized); bool use_shard_func_; std::string compression_; diff --git a/tensorflow/core/kernels/data/experimental/save_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/save_dataset_op_test.cc index f0b7745a46e93d..fe2315e35bd6a4 100644 --- a/tensorflow/core/kernels/data/experimental/save_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/experimental/save_dataset_op_test.cc @@ -59,14 +59,14 @@ class SaveDatasetV2Params : public DatasetParams { return input_tensors; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { input_names->clear(); input_names->emplace_back(SaveDatasetV2Op::kInputDataset); input_names->emplace_back(SaveDatasetV2Op::kPath); return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { attr_vector->clear(); attr_vector->emplace_back(SaveDatasetV2Op::kCompression, compression_); attr_vector->emplace_back(SaveDatasetV2Op::kShardFunc, shard_func_); @@ -97,7 +97,7 @@ class SaveDatasetV2Params : public DatasetParams { class SaveDatasetV2OpTest : public DatasetOpsTestBase { public: - Status Initialize(const DatasetParams& dataset_params) { + absl::Status Initialize(const DatasetParams& dataset_params) { TF_RETURN_IF_ERROR(DatasetOpsTestBase::Initialize(dataset_params)); auto params = static_cast(dataset_params); save_filename_ = params.path(); diff --git a/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc b/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc index a4aa38277870ee..1fc21550244599 100644 --- a/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/scan_dataset_op.cc @@ -114,21 +114,21 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { } } - Status InputDatasets( + absl::Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { TF_RETURN_IF_ERROR(captured_func_->CheckExternalState()); return input_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); std::vector initial_state_nodes; @@ -173,16 +173,16 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { bool SymbolicCheckpointCompatible() const override { return true; } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_)); return dataset()->captured_func_->Instantiate( ctx, &instantiated_captured_func_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); std::vector next_element; @@ -202,7 +202,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { state_and_output.reserve(dataset()->state_types_.size() + output_dtypes().size()); - Status s = instantiated_captured_func_->Run( + absl::Status s = instantiated_captured_func_->Run( ctx, std::move(args), &state_and_output, model_node()); DCHECK(state_and_output.size() <= dataset()->state_types_.size() + output_dtypes().size()); @@ -261,8 +261,8 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { /*ratio=*/1); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( dataset()->captured_func_->CheckExternalState())); mutex_lock l(mu_); @@ -276,8 +276,8 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); int64_t size; diff --git a/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc index 5c73ed46cdfae0..e0a90ac5f04e63 100644 --- a/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc @@ -59,7 +59,8 @@ class StatsAggregatorWithTagAndPrefix : public StatsAggregator { } } - Status SetSummaryWriter(SummaryWriterInterface* summary_writer) override { + absl::Status SetSummaryWriter( + SummaryWriterInterface* summary_writer) override { return wrapped_->SetSummaryWriter(summary_writer); } @@ -143,20 +144,20 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { return input_->Cardinality(options); } - Status InputDatasets( + absl::Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* resource_handle_node = nullptr; @@ -177,15 +178,15 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { explicit Iterator(const Params& params) : DatasetIterator(params) {} - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { IteratorContext iter_ctx = ContextWithAggregator(ctx); return dataset()->input_->MakeIterator(&iter_ctx, this, prefix(), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); IteratorContext iter_ctx = ContextWithAggregator(ctx); return input_impl_->GetNext(&iter_ctx, out_tensors, end_of_sequence); @@ -210,14 +211,14 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { /*ratio=*/1); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); return SaveInput(ctx, writer, input_impl_); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); return RestoreInput(ctx, reader, input_impl_); } diff --git a/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc index a8d8a7fa44228f..1781ac0a345884 100644 --- a/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/sleep_dataset_op.cc @@ -70,20 +70,20 @@ class SleepDatasetOp : public UnaryDatasetOpKernel { return input_->Cardinality(options); } - Status InputDatasets( + absl::Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); @@ -115,7 +115,7 @@ class SleepDatasetOp : public UnaryDatasetOpKernel { } } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { TF_RETURN_IF_ERROR(RegisterCancellationCallback( ctx->cancellation_manager(), [this]() { @@ -127,9 +127,9 @@ class SleepDatasetOp : public UnaryDatasetOpKernel { &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); RecordStop(ctx); bool cancelled = mu_.AwaitWithDeadline( @@ -150,13 +150,13 @@ class SleepDatasetOp : public UnaryDatasetOpKernel { /*ratio=*/1); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { return SaveInput(ctx, writer, input_impl_); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { return RestoreInput(ctx, reader, input_impl_); } diff --git a/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc b/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc index fa522ea8e74bdf..1657cef0a092a9 100644 --- a/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/sliding_window_dataset_op.cc @@ -118,20 +118,20 @@ class SlidingWindowDatasetOp : public UnaryDatasetOpKernel { return (drop_remainder_ ? n : n + window_shift_ - 1) / window_shift_; } - Status InputDatasets( + absl::Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* window_size = nullptr; @@ -157,14 +157,14 @@ class SlidingWindowDatasetOp : public UnaryDatasetOpKernel { explicit Iterator(const Params& params) : DatasetIterator(params) {} - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { const int64_t window_size = dataset()->window_size_; const int64_t window_shift = dataset()->window_shift_; const int64_t window_stride = dataset()->window_stride_; @@ -258,8 +258,8 @@ class SlidingWindowDatasetOp : public UnaryDatasetOpKernel { dataset()->window_shift_); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); if (!input_impl_) { TF_RETURN_IF_ERROR( @@ -281,8 +281,8 @@ class SlidingWindowDatasetOp : public UnaryDatasetOpKernel { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); if (!reader->Contains(full_name("input_impl_empty"))) { TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc index 314a3c905bd244..a8f3e1ed9a38fc 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc @@ -124,8 +124,8 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { Iterator::Params{this, absl::StrCat(prefix, "::Snapshot")}); } - Status MakeSplitProviders(std::vector>* - split_providers) const override { + absl::Status MakeSplitProviders(std::vector>* + split_providers) const override { return errors::Unimplemented( "Splitting is not implemented for snapshot datasets."); } @@ -146,19 +146,20 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { return input_->Cardinality(); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); @@ -242,7 +243,7 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { Reader(const Params& params, int64_t start_index) : DatasetIterator(params), start_index_(start_index) {} - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(dataset()->reader_func_->Instantiate( @@ -293,24 +294,24 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { return input_->MakeIterator(ctx, this, prefix(), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); return input_impl_->GetNext(ctx, out_tensors, end_of_sequence); } protected: - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { // We do not need to checkpoint the reader as we are rebuilding the // reader datasets from information that is already saved by the main // iterator. return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { return absl::OkStatus(); } @@ -345,7 +346,7 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { SignalEOF(true); } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR( dataset()->shard_func_->Instantiate(ctx, &instantiated_shard_func_)); @@ -355,9 +356,9 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { *end_of_sequence = false; snapshot_util::AsyncWriter* current_writer; @@ -416,7 +417,8 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { auto writer = std::make_unique( ctx->env(), shard_index, snapshot_shard_directory, current_checkpoint_id_, dataset()->compression_, - kFileFormatVersion, dataset()->output_dtypes(), [this](Status s) { + kFileFormatVersion, dataset()->output_dtypes(), + [this](absl::Status s) { if (!s.ok()) { LOG(ERROR) << "AsyncWriter in snapshot writer failed: " << s; mutex_lock l(writer_status_mu_); @@ -433,8 +435,8 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { } protected: - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kRunId), static_cast(run_id_))); @@ -447,8 +449,8 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { return SaveInput(ctx, writer, input_impl_); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); int64_t run_id_signed; int64_t current_checkpoint_id; @@ -469,9 +471,9 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { } private: - Status GetShardIndex(IteratorContext* ctx, - const std::vector& tensors, - int64_t* shard_index) + absl::Status GetShardIndex(IteratorContext* ctx, + const std::vector& tensors, + int64_t* shard_index) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { std::vector output_tensors; @@ -491,7 +493,7 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { return absl::OkStatus(); } - Status WriteMetadataFile(Env* env, bool finalized) + absl::Status WriteMetadataFile(Env* env, bool finalized) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { DCHECK(!run_dir_.empty()); @@ -530,7 +532,7 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { absl::flat_hash_map> writers_ TF_GUARDED_BY(mu_); - Status writer_status_ TF_GUARDED_BY(writer_status_mu_); + absl::Status writer_status_ TF_GUARDED_BY(writer_status_mu_); bool writers_closed_ TF_GUARDED_BY(mu_); uint64 run_id_ TF_GUARDED_BY(mu_); @@ -551,24 +553,24 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { explicit Passthrough(const Params& params) : DatasetIterator(params) {} - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { return input_impl_->GetNext(ctx, out_tensors, end_of_sequence); } protected: - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { return SaveInput(ctx, writer, input_impl_); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { return RestoreInput(ctx, reader, input_impl_); } @@ -589,17 +591,17 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { hash_dir_(snapshot_util::HashDirectory(dataset()->path_, dataset()->hash_)) {} - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { return ctx->env()->RecursivelyCreateDir( io::JoinPath(dataset()->writer_prefix_, hash_dir_)); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); if (iterator_ == nullptr) { - Status s = InitializeIterator(ctx, /*reader=*/nullptr); + absl::Status s = InitializeIterator(ctx, /*reader=*/nullptr); if (!s.ok()) { iterator_.reset(); return s; @@ -610,8 +612,8 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { } protected: - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); if (iterator_ != nullptr) { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, iterator_)); @@ -624,8 +626,8 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); if (reader->Contains(full_name(kIteratorMode))) { TF_RETURN_IF_ERROR(InitializeIterator(ctx, reader)); @@ -635,7 +637,8 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase { } private: - Status InitializeIterator(IteratorContext* ctx, IteratorStateReader* reader) + absl::Status InitializeIterator(IteratorContext* ctx, + IteratorStateReader* reader) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (reader != nullptr) { // Check whether the computed hash directory is the same. @@ -915,7 +918,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { uint64 hash; OP_REQUIRES_OK(ctx, ComputeDatasetHash(graph_def, path, &hash)); - Status dump_status = + absl::Status dump_status = snapshot_util::DumpDatasetGraph(ctx->env(), path, hash, &graph_def); if (!dump_status.ok()) { LOG(WARNING) << "Unable to write graphdef to disk, error: " @@ -975,8 +978,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { Iterator::Params{this, absl::StrCat(prefix, "::Snapshot")}); } - Status MakeSplitProviders(std::vector>* - split_providers) const override { + absl::Status MakeSplitProviders(std::vector>* + split_providers) const override { return errors::Unimplemented( "Splitting is not implemented for snapshot datasets."); } @@ -995,20 +998,20 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { return input_->Cardinality(options); } - Status InputDatasets( + absl::Status InputDatasets( std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); @@ -1106,13 +1109,13 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { // Initialize at first and at that point we don't know which iterator // (Reader / Writer / Passthrough) we need to restore as this info is part // of the checkpoint. - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { return absl::OkStatus(); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); if (iterator_ == nullptr) { experimental::SnapshotMetadataRecord metadata; @@ -1129,8 +1132,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { } protected: - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); if (iterator_ != nullptr) { TF_RETURN_IF_ERROR(SaveInput(ctx, writer, iterator_)); @@ -1142,8 +1145,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); tstring hash_dir; TF_RETURN_IF_ERROR(reader->ReadScalar(full_name(kHashDir), &hash_dir)); @@ -1167,7 +1170,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { // This method expects that state_ is populated and it will create the // correct Reader / Writer / Passthrough iterator and initialize it. - Status InitializeIterator( + absl::Status InitializeIterator( IteratorContext* ctx, const experimental::SnapshotMetadataRecord& metadata) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { @@ -1245,7 +1248,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { } } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { mutex_lock l(mu_); thread_pool_ = ctx->CreateThreadPool(kSnapshotReaderWorkerPool, dataset()->num_reader_threads_); @@ -1280,9 +1283,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { return absl::OkStatus(); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { absl::Time start = absl::Now(); mutex_lock l(mu_); if (!background_threads_started_) { @@ -1318,7 +1321,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { } if (!buffer_.empty()) { - Status s = buffer_.front().status; + absl::Status s = buffer_.front().status; if (s.ok()) { *end_of_sequence = false; *out_tensors = std::move(buffer_.front().value); @@ -1367,8 +1370,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { } protected: - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR( writer->WriteScalar(full_name(kHashDir), hash_dir_)); @@ -1402,8 +1405,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); tstring hash_dir, run_id, run_dir; TF_RETURN_IF_ERROR( @@ -1471,7 +1474,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { private: // Reads one file end to end. - Status ReadFile(Env* env, const string& filename) { + absl::Status ReadFile(Env* env, const string& filename) { std::unique_ptr reader; TF_RETURN_IF_ERROR(snapshot_util::Reader::Create( env, filename, dataset()->compression_, version_, @@ -1492,7 +1495,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { } } std::vector read_tensors; - Status s = reader->ReadTensors(&read_tensors); + absl::Status s = reader->ReadTensors(&read_tensors); if (s.ok()) { tsl::profiler::TraceMe activity( [&]() { return absl::StrCat(prefix(), kSeparator, kParse); }, @@ -1541,7 +1544,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { } VLOG(2) << "Starting to read: " << filename; } - Status s = ReadFile(env, filename); + absl::Status s = ReadFile(env, filename); // If we get to the end of the file, it's a clean termination and // we are at the end of the file. If all files have been processed, // then we insert an end_of_sequence marker in the buffer and @@ -1568,8 +1571,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { } } - Status WriteStatus(IteratorStateWriter* writer, size_t index, - const Status& status) + absl::Status WriteStatus(IteratorStateWriter* writer, size_t index, + const absl::Status& status) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { TF_RETURN_IF_ERROR(writer->WriteScalar( CodeKey(index), static_cast(status.code()))); @@ -1580,8 +1583,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { return absl::OkStatus(); } - Status ReadStatus(IteratorStateReader* reader, size_t index, - Status* status) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + absl::Status ReadStatus(IteratorStateReader* reader, size_t index, + absl::Status* status) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { int64_t code_int; TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int)); absl::StatusCode code = static_cast(code_int); @@ -1590,7 +1594,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { tstring error_message; TF_RETURN_IF_ERROR( reader->ReadScalar(ErrorMessageKey(index), &error_message)); - *status = Status(code, error_message); + *status = absl::Status(code, error_message); } else { *status = absl::OkStatus(); } @@ -1607,7 +1611,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { } struct BufferElement { - Status status; + absl::Status status; std::vector value; }; @@ -1665,16 +1669,16 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { } } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { thread_pool_ = ctx->CreateThreadPool(kSnapshotWriterWorkerPool, dataset()->num_writer_threads_); return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { absl::Time start = absl::Now(); bool first_call; @@ -1780,8 +1784,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { } protected: - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); if (end_of_sequence_) { @@ -1834,8 +1838,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); buffer_.clear(); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); @@ -1953,7 +1957,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { return snapshot_data_filename; } - Status FillBuffer(IteratorContext* ctx) TF_LOCKS_EXCLUDED(mu_) { + absl::Status FillBuffer(IteratorContext* ctx) TF_LOCKS_EXCLUDED(mu_) { snapshot_util::ElementOrEOF elem; TF_RETURN_IF_ERROR( input_impl_->GetNext(ctx, &elem.value, &elem.end_of_sequence)); @@ -1994,10 +1998,10 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { return absl::OkStatus(); } - Status ProcessOneElement(Env* env, int64_t* bytes_written, - string* snapshot_data_filename, - std::unique_ptr* writer, - bool* end_of_processing) { + absl::Status ProcessOneElement( + Env* env, int64_t* bytes_written, string* snapshot_data_filename, + std::unique_ptr* writer, + bool* end_of_processing) { tsl::profiler::TraceMe activity( [&]() { return absl::StrCat(prefix(), kSeparator, kProcessOneElement); @@ -2094,7 +2098,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { int64_t bytes_written = 0; string snapshot_data_filename = GetSnapshotFilename(); std::unique_ptr writer; - Status s = snapshot_util::Writer::Create( + absl::Status s = snapshot_util::Writer::Create( env, snapshot_data_filename, dataset()->compression_, kCurrentVersion, dataset()->output_dtypes(), &writer); if (!s.ok()) { @@ -2108,7 +2112,7 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { bool end_of_processing = false; while (!end_of_processing) { - Status s = + absl::Status s = ProcessOneElement(env, &bytes_written, &snapshot_data_filename, &writer, &end_of_processing); if (!s.ok()) { @@ -2123,10 +2127,10 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { } } - Status ShouldCloseWriter(Env* env, const string& filename, - uint64 bytes_written, - snapshot_util::Writer* writer, - bool* should_close) { + absl::Status ShouldCloseWriter(Env* env, const string& filename, + uint64 bytes_written, + snapshot_util::Writer* writer, + bool* should_close) { // If the compression ratio has been estimated, use it to decide // whether the file should be closed. We avoid estimating the // compression ratio repeatedly because it requires syncing the file, @@ -2199,25 +2203,25 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { explicit SnapshotPassthroughIterator(const Params& params) : DatasetIterator(params) {} - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { return input_impl_->GetNext(ctx, out_tensors, end_of_sequence); } protected: - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { return SaveInput(ctx, writer, input_impl_); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { return RestoreInput(ctx, reader, input_impl_); } @@ -2256,8 +2260,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel { const std::string snapshot_name_; }; - Status ComputeDatasetHash(const GraphDef& graph_def, const std::string& path, - uint64* hash) { + absl::Status ComputeDatasetHash(const GraphDef& graph_def, + const std::string& path, uint64* hash) { TF_RETURN_IF_ERROR(HashGraph(graph_def, hash)); // Adding path, compression, reader / writer path prefix, shard size // bytes to the fp as they effect the data written on disk. diff --git a/tensorflow/core/kernels/data/experimental/sql/query_connection.h b/tensorflow/core/kernels/data/experimental/sql/query_connection.h index 40f13d54f35b44..031a87253e5fad 100644 --- a/tensorflow/core/kernels/data/experimental/sql/query_connection.h +++ b/tensorflow/core/kernels/data/experimental/sql/query_connection.h @@ -48,10 +48,10 @@ class QueryConnection { // The client must call `Close()` to release the connection resources, even // if `Open()` fails. `Close()` must be called before making another call // to `Open()`. - virtual Status Open(const string& data_source_name, const string& query, - const DataTypeVector& output_types) = 0; + virtual absl::Status Open(const string& data_source_name, const string& query, + const DataTypeVector& output_types) = 0; // Closes an opened connection. - virtual Status Close() = 0; + virtual absl::Status Close() = 0; // Retrieves the next row of the result set of the query from the most recent // call to `Open()`. // @@ -61,8 +61,9 @@ class QueryConnection { // If there are no more rows in the result set, then instead `true` will be // stored in `*end_of_sequence`, and the content of `*out_tensors` will be // undefined. - virtual Status GetNext(IteratorContext* ctx, std::vector* out_tensors, - bool* end_of_sequence) = 0; + virtual absl::Status GetNext(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) = 0; }; } // namespace sql diff --git a/tensorflow/core/kernels/data/experimental/sql/sqlite_query_connection.cc b/tensorflow/core/kernels/data/experimental/sql/sqlite_query_connection.cc index 8b63d10a6e86e5..a3ebb8a6f5fc83 100644 --- a/tensorflow/core/kernels/data/experimental/sql/sqlite_query_connection.cc +++ b/tensorflow/core/kernels/data/experimental/sql/sqlite_query_connection.cc @@ -29,9 +29,9 @@ SqliteQueryConnection::~SqliteQueryConnection() { if (db_ != nullptr) db_->Unref(); } -Status SqliteQueryConnection::Open(const string& data_source_name, - const string& query, - const DataTypeVector& output_types) { +absl::Status SqliteQueryConnection::Open(const string& data_source_name, + const string& query, + const DataTypeVector& output_types) { if (db_ != nullptr) { return errors::FailedPrecondition( "Failed to open query connection: Connection already opened."); @@ -43,16 +43,16 @@ Status SqliteQueryConnection::Open(const string& data_source_name, return absl::OkStatus(); } -Status SqliteQueryConnection::Close() { +absl::Status SqliteQueryConnection::Close() { stmt_ = SqliteStatement(); db_->Unref(); db_ = nullptr; return absl::OkStatus(); } -Status SqliteQueryConnection::GetNext(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) { +absl::Status SqliteQueryConnection::GetNext(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) { if (!stmt_) TF_RETURN_IF_ERROR(PrepareQuery()); TF_RETURN_IF_ERROR(stmt_.Step(end_of_sequence)); if (!*end_of_sequence) { @@ -66,7 +66,7 @@ Status SqliteQueryConnection::GetNext(IteratorContext* ctx, return absl::OkStatus(); } -Status SqliteQueryConnection::PrepareQuery() { +absl::Status SqliteQueryConnection::PrepareQuery() { TF_RETURN_IF_ERROR(db_->Prepare(query_, &stmt_)); int column_count = stmt_.ColumnCount(); if (column_count != static_cast(output_types_.size())) { diff --git a/tensorflow/core/kernels/data/experimental/sql/sqlite_query_connection.h b/tensorflow/core/kernels/data/experimental/sql/sqlite_query_connection.h index 42526c7668a2a6..4cf2608c22f02c 100644 --- a/tensorflow/core/kernels/data/experimental/sql/sqlite_query_connection.h +++ b/tensorflow/core/kernels/data/experimental/sql/sqlite_query_connection.h @@ -30,15 +30,15 @@ class SqliteQueryConnection : public QueryConnection { public: SqliteQueryConnection(); ~SqliteQueryConnection() override; - Status Open(const string& data_source_name, const string& query, - const DataTypeVector& output_types) override; - Status Close() override; - Status GetNext(IteratorContext* ctx, std::vector* out_tensors, - bool* end_of_sequence) override; + absl::Status Open(const string& data_source_name, const string& query, + const DataTypeVector& output_types) override; + absl::Status Close() override; + absl::Status GetNext(IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence) override; private: // Prepares the query string `query_`. - Status PrepareQuery(); + absl::Status PrepareQuery(); // Fills `tensor` with the column_index_th element of the current row of // `stmt_`. void FillTensorWithResultSetEntry(const DataType& data_type, int column_index, diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc index 425b38b458149e..31b032f571417a 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op.cc @@ -79,20 +79,21 @@ class FilterDatasetOp::Dataset : public DatasetBase { return name_utils::DatasetDebugString(kDatasetType); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { TF_RETURN_IF_ERROR(captured_func_->CheckExternalState()); return input_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); std::vector other_arguments; @@ -120,7 +121,7 @@ class FilterDatasetOp::Dataset : public DatasetBase { bool SymbolicCheckpointCompatible() const override { return true; } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_)); return dataset()->captured_func_->Instantiate( @@ -130,9 +131,9 @@ class FilterDatasetOp::Dataset : public DatasetBase { // NOTE(mrry): This method is thread-safe as long as `input_impl_` and `f` // are thread-safe. However, if multiple threads enter this method, // outputs may be observed in a non-deterministic order. - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { auto stats_aggregator = ctx->stats_aggregator(); bool matched; do { @@ -213,8 +214,8 @@ class FilterDatasetOp::Dataset : public DatasetBase { return model::MakeUnknownRatioNode(std::move(args)); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( dataset()->captured_func_->CheckExternalState())); mutex_lock l(mu_); @@ -230,8 +231,8 @@ class FilterDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); int64_t input_empty; TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/kernels/data/filter_dataset_op_test.cc b/tensorflow/core/kernels/data/filter_dataset_op_test.cc index e325b604c60dda..14b7f571a7bba7 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op_test.cc @@ -46,7 +46,7 @@ class FilterDatasetParams : public DatasetParams { return other_arguments_; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { input_names->clear(); input_names->reserve(input_dataset_params_.size() + other_arguments_.size()); @@ -59,7 +59,7 @@ class FilterDatasetParams : public DatasetParams { return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { *attr_vector = {{"predicate", pred_func_}, {"Targuments", type_arguments_}, {"output_shapes", output_shapes_}, diff --git a/tensorflow/core/kernels/data/finalize_dataset_op_test.cc b/tensorflow/core/kernels/data/finalize_dataset_op_test.cc index efd135c0e24839..2077cc28c161ec 100644 --- a/tensorflow/core/kernels/data/finalize_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/finalize_dataset_op_test.cc @@ -40,12 +40,12 @@ class FinalizeDatasetParams : public DatasetParams { std::vector GetInputTensors() const override { return {}; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { input_names->emplace_back(FinalizeDatasetOp::kInputDataset); return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { *attr_vector = {{FinalizeDatasetOp::kHasCapturedRef, has_captured_ref_}, {FinalizeDatasetOp::kOutputTypes, output_dtypes_}, {FinalizeDatasetOp::kOutputShapes, output_shapes_}}; @@ -245,7 +245,7 @@ TEST_F(FinalizeDatasetOpTest, MaxIntraOpParallelismNodeName) { auto test_case_params = MaxIntraOpParallelismParams(); TF_ASSERT_OK(Initialize(test_case_params)); std::vector inputs; - Status s = dataset_->InputDatasets(&inputs); + absl::Status s = dataset_->InputDatasets(&inputs); TF_ASSERT_OK(CheckDatasetNodeName(test_case_params.node_name())); CheckDatasetPipelineTypeStrings( {"MaxIntraOpParallelismDataset", "OptionsDataset", "RangeDataset"}); @@ -255,7 +255,7 @@ TEST_F(FinalizeDatasetOpTest, PrivateThreadPoolNodeName) { auto test_case_params = PrivateThreadPoolParams(); TF_ASSERT_OK(Initialize(test_case_params)); std::vector inputs; - Status s = dataset_->InputDatasets(&inputs); + absl::Status s = dataset_->InputDatasets(&inputs); TF_ASSERT_OK(CheckDatasetNodeName(test_case_params.node_name())); CheckDatasetPipelineTypeStrings( {"PrivateThreadPoolDataset", "OptionsDataset", "RangeDataset"}); @@ -265,7 +265,7 @@ TEST_F(FinalizeDatasetOpTest, ModelNodeName) { auto test_case_params = ModelParams(); TF_ASSERT_OK(Initialize(test_case_params)); std::vector inputs; - Status s = dataset_->InputDatasets(&inputs); + absl::Status s = dataset_->InputDatasets(&inputs); TF_ASSERT_OK(CheckDatasetNodeName(test_case_params.node_name())); CheckDatasetPipelineTypeStrings( {"ModelDataset", "OptionsDataset", "RangeDataset"}); @@ -275,7 +275,7 @@ TEST_F(FinalizeDatasetOpTest, OptimizationsDefaultNodeName) { auto test_case_params = OptimizationsDefaultParams(); TF_ASSERT_OK(Initialize(test_case_params)); std::vector inputs; - Status s = dataset_->InputDatasets(&inputs); + absl::Status s = dataset_->InputDatasets(&inputs); TF_ASSERT_OK(CheckDatasetNodeName(test_case_params.node_name())); CheckDatasetPipelineTypeStrings({"PrivateThreadPoolDataset", "MaxIntraOpParallelismDataset", @@ -286,7 +286,7 @@ TEST_F(FinalizeDatasetOpTest, AllChainedDatasetsNodeName) { auto test_case_params = AllChainedDatasetsParams(); TF_ASSERT_OK(Initialize(test_case_params)); std::vector inputs; - Status s = dataset_->InputDatasets(&inputs); + absl::Status s = dataset_->InputDatasets(&inputs); TF_ASSERT_OK(CheckDatasetNodeName(test_case_params.node_name())); CheckDatasetPipelineTypeStrings( {"PrefetchDataset", "ModelDataset", "PrivateThreadPoolDataset", diff --git a/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc b/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc index 5c55e34d270b54..54a1aae03d00c9 100644 --- a/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc +++ b/tensorflow/core/kernels/data/fixed_length_record_dataset_op.cc @@ -94,16 +94,17 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { return name_utils::DatasetDebugString(kDatasetType, params); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { return absl::OkStatus(); } - Status CheckExternalState() const override { return absl::OkStatus(); } + absl::Status CheckExternalState() const override { return absl::OkStatus(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* filenames = nullptr; Node* header_bytes = nullptr; Node* record_bytes = nullptr; @@ -130,9 +131,9 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { explicit UncompressedIterator(const Params& params) : DatasetIterator(params) {} - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); do { // We are currently processing a file, so try to read the next record. @@ -196,8 +197,8 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { } protected: - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kCurrentFileIndex, current_file_index_)); @@ -211,8 +212,8 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); int64_t current_file_index; TF_RETURN_IF_ERROR( @@ -256,9 +257,9 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { explicit CompressedIterator(const Params& params) : DatasetIterator(params) {} - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { static monitoring::CounterCell* bytes_counter = metrics::GetTFDataBytesReadCounter(kDatasetType); mutex_lock l(mu_); @@ -283,7 +284,7 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { } } else { tstring record; - Status s = buffered_input_stream_->ReadNBytes( + absl::Status s = buffered_input_stream_->ReadNBytes( dataset()->record_bytes_, &record); if (s.ok()) { bytes_counter->IncrementBy(dataset()->record_bytes_); @@ -384,8 +385,8 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { return model::MakeSourceNode(std::move(args)); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kCurrentFileIndex, current_file_index_)); @@ -400,8 +401,8 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); int64_t current_file_index; TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/kernels/data/fixed_length_record_dataset_op_test.cc b/tensorflow/core/kernels/data/fixed_length_record_dataset_op_test.cc index ad4fff98fe5bbc..16236a9f427a8c 100644 --- a/tensorflow/core/kernels/data/fixed_length_record_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/fixed_length_record_dataset_op_test.cc @@ -55,7 +55,7 @@ class FixedLengthRecordDatasetParams : public DatasetParams { CreateTensor(TensorShape({}), {ToString(compression_type_)})}; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { input_names->clear(); *input_names = {FixedLengthRecordDatasetOp::kFileNames, FixedLengthRecordDatasetOp::kHeaderBytes, @@ -66,7 +66,7 @@ class FixedLengthRecordDatasetParams : public DatasetParams { return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { attr_vector->clear(); attr_vector->emplace_back("metadata", ""); return absl::OkStatus(); @@ -87,9 +87,9 @@ class FixedLengthRecordDatasetParams : public DatasetParams { class FixedLengthRecordDatasetOpTest : public DatasetOpsTestBase {}; -Status CreateTestFiles(const std::vector& filenames, - const std::vector& contents, - CompressionType compression_type) { +absl::Status CreateTestFiles(const std::vector& filenames, + const std::vector& contents, + CompressionType compression_type) { if (filenames.size() != contents.size()) { return tensorflow::errors::InvalidArgument( "The number of files does not match with the contents"); diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc index cd03a090febdad..cc9f372792a45f 100644 --- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc @@ -130,12 +130,13 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { return *cardinality; } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { TF_RETURN_IF_ERROR(captured_func_->CheckExternalState()); return input_->CheckExternalState(); } @@ -150,9 +151,9 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); std::vector other_arguments; @@ -181,7 +182,7 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { bool SymbolicCheckpointCompatible() const override { return true; } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { mutex_lock l(mu_); input_ckpt_ = std::make_unique(ctx->id_registry()); TF_RETURN_IF_ERROR( @@ -190,9 +191,9 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { ctx, &instantiated_captured_func_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { if (ctx->index_mapper()) { return Get(ctx, out_tensors, end_of_sequence); } @@ -253,8 +254,9 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { // LINT.ThenChange(:SkipInternal) } - Status SkipInternal(IteratorContext* ctx, int num_to_skip, - bool* end_of_sequence, int* num_skipped) override { + absl::Status SkipInternal(IteratorContext* ctx, int num_to_skip, + bool* end_of_sequence, + int* num_skipped) override { // LINT.IfChange(SkipInternal) mutex_lock l(mu_); *num_skipped = 0; @@ -418,8 +420,8 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { {model::MakeNonTunableParameter(kCycleLength, /*value=*/1)}); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override TF_LOCKS_EXCLUDED(mu_) { TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( dataset()->captured_func_->CheckExternalState())); @@ -446,8 +448,8 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override TF_LOCKS_EXCLUDED(mu_) { if (ctx->restored_element_count().has_value()) { return RestoreForGlobalShuffle(ctx, reader); @@ -482,8 +484,8 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreForGlobalShuffle(IteratorContext* ctx, - IteratorStateReader* reader) + absl::Status RestoreForGlobalShuffle(IteratorContext* ctx, + IteratorStateReader* reader) TF_LOCKS_EXCLUDED(mu_) { mutex_lock l(mu_); element_count_ = *ctx->restored_element_count(); @@ -530,8 +532,8 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { } private: - Status BuildCurrentElementIteratorLocked(IteratorContext* ctx, - bool is_get_next) + absl::Status BuildCurrentElementIteratorLocked(IteratorContext* ctx, + bool is_get_next) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { // NOTE: We intentionally ignore resource modeling outside GetNext(). std::shared_ptr node = is_get_next ? model_node() : nullptr; @@ -540,8 +542,8 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { prefix(), ¤t_element_iterator_, node); } - Status RestoreCurrentElementIterator(IteratorContext* ctx, - IteratorStateReader* reader) + absl::Status RestoreCurrentElementIterator(IteratorContext* ctx, + IteratorStateReader* reader) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (ctx->symbolic_checkpoint()) { return RestoreCurrentElementIteratorSymbolic(ctx, reader); @@ -567,8 +569,8 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreCurrentElementIteratorSymbolic(IteratorContext* ctx, - IteratorStateReader* reader) + absl::Status RestoreCurrentElementIteratorSymbolic( + IteratorContext* ctx, IteratorStateReader* reader) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { bool end_of_sequence; auto input_ctx = std::make_unique(*ctx); diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op_test.cc b/tensorflow/core/kernels/data/flat_map_dataset_op_test.cc index b6a68065c93845..c3da5e1b4ea4e4 100644 --- a/tensorflow/core/kernels/data/flat_map_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/flat_map_dataset_op_test.cc @@ -46,7 +46,7 @@ class FlatMapDatasetParams : public DatasetParams { return other_arguments_; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { input_names->emplace_back(FlatMapDatasetOp::kInputDataset); for (int i = 0; i < other_arguments_.size(); ++i) { input_names->emplace_back( @@ -55,7 +55,7 @@ class FlatMapDatasetParams : public DatasetParams { return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { *attr_vector = {{"f", func_}, {"Targuments", type_arguments_}, {"output_shapes", output_shapes_}, diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc index 401a8e50284859..986e427c1739a6 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.cc +++ b/tensorflow/core/kernels/data/generator_dataset_op.cc @@ -75,20 +75,21 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { return name_utils::DatasetDebugString(kDatasetType); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { TF_RETURN_IF_ERROR(init_func_->CheckExternalState()); TF_RETURN_IF_ERROR(next_func_->CheckExternalState()); return finalize_func_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { return errors::Unimplemented(DebugString(), " does not support serialization"); } @@ -102,7 +103,7 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { ~Iterator() override { if (!finalized_ && initialized_) { std::vector ignored; - Status s = + absl::Status s = instantiated_finalize_func_->RunInstantiated(state_, &ignored); if (!s.ok()) { LOG(WARNING) @@ -112,7 +113,7 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { } } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { TF_RETURN_IF_ERROR( dataset()->init_func_->Instantiate(ctx, &instantiated_init_func_)); TF_RETURN_IF_ERROR( @@ -122,9 +123,9 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); if (!initialized_) { @@ -138,7 +139,7 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status s = instantiated_next_func_->RunWithBorrowedArgs( + absl::Status s = instantiated_next_func_->RunWithBorrowedArgs( ctx, state_, out_tensors, model_node()); if (s.ok()) { *end_of_sequence = false; @@ -163,14 +164,14 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { return model::MakeSourceNode(std::move(args)); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { return errors::Unimplemented( "GeneratorDataset does not support checkpointing."); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { return errors::Unimplemented( "GeneratorDataset does not support checkpointing."); } diff --git a/tensorflow/core/kernels/data/get_options_op_test.cc b/tensorflow/core/kernels/data/get_options_op_test.cc index 8f5ae9d7ea7d8b..6e40665fcaf282 100644 --- a/tensorflow/core/kernels/data/get_options_op_test.cc +++ b/tensorflow/core/kernels/data/get_options_op_test.cc @@ -42,12 +42,12 @@ class GetOptionsParams : public DatasetParams { std::vector GetInputTensors() const override { return {}; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { input_names->emplace_back(OptionsDatasetOp::kInputDataset); return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { return absl::OkStatus(); } diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc index 772f1fdb416d50..a5added76bff09 100644 --- a/tensorflow/core/kernels/data/interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc @@ -104,20 +104,21 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { return name_utils::DatasetDebugString(kDatasetType); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { TF_RETURN_IF_ERROR(captured_func_->CheckExternalState()); return input_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); Node* cycle_length_node; @@ -149,7 +150,7 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { bool SymbolicCheckpointCompatible() const override { return true; } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { mutex_lock l(mu_); input_ckpt_ = std::make_unique(ctx->id_registry()); @@ -164,7 +165,8 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_; } - Status AdvancePosition(int num_elements) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + absl::Status AdvancePosition(int num_elements) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { block_index_ += num_elements; if (block_index_ == dataset()->block_length_) { AdvanceToNextInCycle(); @@ -184,9 +186,9 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { } } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); while (!end_of_input_ || num_open_ > 0) { if (current_elements_[cycle_index_]) { @@ -231,8 +233,9 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status SkipInternal(IteratorContext* ctx, int num_to_skip, - bool* end_of_sequence, int* num_skipped) override { + absl::Status SkipInternal(IteratorContext* ctx, int num_to_skip, + bool* end_of_sequence, + int* num_skipped) override { mutex_lock l(mu_); *num_skipped = 0; while (!end_of_input_ || num_open_ > 0) { @@ -288,8 +291,8 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { kCycleLength, dataset()->cycle_length_)}); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( dataset()->captured_func_->CheckExternalState())); mutex_lock l(mu_); @@ -312,8 +315,8 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); int64_t cycle_index; @@ -385,8 +388,8 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { : (cycle_index); } - Status SaveCurrentElements(SerializationContext* ctx, - IteratorStateWriter* writer) + absl::Status SaveCurrentElements(SerializationContext* ctx, + IteratorStateWriter* writer) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { for (int idx = 0; idx < current_elements_.size(); idx++) { TF_RETURN_IF_ERROR(writer->WriteScalar( @@ -476,7 +479,7 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { // 5. `input_impl_->GetNext()` -> put the result at [2] as args // // 6. ... and so on. - Status RestoreArgsListAndInputOffsetCycleIdxMap( + absl::Status RestoreArgsListAndInputOffsetCycleIdxMap( IteratorContext& ctx, std::vector& input_element_indices, std::vector>& checkpoints, std::vector>& args, @@ -572,7 +575,7 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreCurrentElements( + absl::Status RestoreCurrentElements( IteratorContext* ctx, IteratorStateReader* reader, std::vector& input_element_indices, std::vector>&& checkpoints, @@ -633,7 +636,7 @@ class InterleaveDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status MoveToNextElement(IteratorContext* ctx) + absl::Status MoveToNextElement(IteratorContext* ctx) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (!end_of_input_) { // Get the next element from the input dataset, and create diff --git a/tensorflow/core/kernels/data/interleave_dataset_op_test.cc b/tensorflow/core/kernels/data/interleave_dataset_op_test.cc index e27e147ae7c6da..73c75b84496b40 100644 --- a/tensorflow/core/kernels/data/interleave_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/interleave_dataset_op_test.cc @@ -54,7 +54,7 @@ class InterleaveDatasetParams : public DatasetParams { return input_tensors; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { input_names->clear(); input_names->reserve(input_dataset_params_.size() + other_arguments_.size() + 2); @@ -68,7 +68,7 @@ class InterleaveDatasetParams : public DatasetParams { return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { *attr_vector = {{"f", func_}, {"Targuments", type_arguments_}, {"output_shapes", output_shapes_}, diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 499dfcfbb2c1cb..6384bc35d20856 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -107,9 +107,9 @@ IteratorResource::~IteratorResource() { VLOG(2) << "destroying iterator resource"; } -Status IteratorResource::GetNext(OpKernelContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) { +absl::Status IteratorResource::GetNext(OpKernelContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) { std::shared_ptr captured_state; { tf_shared_lock l(mu_); @@ -180,9 +180,9 @@ absl::Status IteratorResource::GetModelProto(std::string& model_proto) { return absl::OkStatus(); } -Status IteratorResource::Save(OpKernelContext* ctx, - ExternalStatePolicy external_state_policy, - IteratorStateWriter* writer) { +absl::Status IteratorResource::Save(OpKernelContext* ctx, + ExternalStatePolicy external_state_policy, + IteratorStateWriter* writer) { std::shared_ptr captured_state; { tf_shared_lock l(mu_); @@ -214,8 +214,8 @@ Status IteratorResource::Save(OpKernelContext* ctx, return iterator->Save(&serialization_ctx, writer); } -Status IteratorResource::Restore(OpKernelContext* ctx, - IteratorStateReader* reader) { +absl::Status IteratorResource::Restore(OpKernelContext* ctx, + IteratorStateReader* reader) { const DatasetBase* dataset; std::shared_ptr new_state; const DatasetBase* input_dataset; @@ -278,8 +278,8 @@ Status IteratorResource::Restore(OpKernelContext* ctx, return absl::OkStatus(); } -Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx, - const DatasetBase* dataset) { +absl::Status IteratorResource::SetIteratorFromDataset( + OpKernelContext* ctx, const DatasetBase* dataset) { std::shared_ptr new_state; { tf_shared_lock l(mu_); @@ -376,9 +376,9 @@ class IteratorVariantSerializer { // Calls `Save` on the iterator_resource to build up the list of // IteratorStateVariant objects. - Status InitializeFromIterator(OpKernelContext* ctx, - ExternalStatePolicy external_state_policy, - IteratorResource* iterator_resource) { + absl::Status InitializeFromIterator(OpKernelContext* ctx, + ExternalStatePolicy external_state_policy, + IteratorResource* iterator_resource) { VariantTensorDataWriter writer; TF_RETURN_IF_ERROR( iterator_resource->Save(ctx, external_state_policy, &writer)); @@ -397,7 +397,7 @@ class IteratorVariantSerializer { } // Initializes `this` from `serialized_t` while restoring the iterator state. - Status InitFromTensor(const Tensor* serialized_t) { + absl::Status InitFromTensor(const Tensor* serialized_t) { int64_t num_tensors = serialized_t->dim_size(0); auto serialized_vec = serialized_t->vec(); std::vector data; @@ -421,7 +421,7 @@ class IteratorVariantSerializer { // Stores the IteratorStateVariant list into a pre-allocated tensor. Expects // that InitializeFromIterator was called before. - Status Serialize(Tensor* serialized) { + absl::Status Serialize(Tensor* serialized) { if (!can_serialize_) { return errors::InvalidArgument( "Please call InitializeFromIterator before calling Serialize."); @@ -514,7 +514,7 @@ void IteratorHandleOp::Compute(OpKernelContext* context) return absl::OkStatus(); })); - Status s = VerifyResource(resource); + absl::Status s = VerifyResource(resource); if (TF_PREDICT_FALSE(!s.ok())) { resource->Unref(); context->SetStatus(s); @@ -529,7 +529,7 @@ void IteratorHandleOp::Compute(OpKernelContext* context) TypeIndex::Make())); } -Status IteratorHandleOp::VerifyResource(IteratorResource* resource) { +absl::Status IteratorHandleOp::VerifyResource(IteratorResource* resource) { TF_RETURN_IF_ERROR( VerifyTypesMatch(output_dtypes_, resource->output_dtypes())); TF_RETURN_IF_ERROR( @@ -585,7 +585,7 @@ AnonymousIteratorHandleOp::AnonymousIteratorHandleOp( string AnonymousIteratorHandleOp::name() { return kAnonymousIterator; } -Status AnonymousIteratorHandleOp::CreateResource( +absl::Status AnonymousIteratorHandleOp::CreateResource( OpKernelContext* ctx, std::unique_ptr flib_def, std::unique_ptr pflr, FunctionLibraryRuntime* lib, IteratorResource** resource) { @@ -613,7 +613,7 @@ void HybridAsyncOpKernel::Compute(OpKernelContext* ctx) { ctx->SetStatus(DoCompute(ctx)); } -Status MakeIteratorOp::DoCompute(OpKernelContext* ctx) { +absl::Status MakeIteratorOp::DoCompute(OpKernelContext* ctx) { tensorflow::ResourceTagger tag(kTFDataResourceTag, ctx->op_kernel().type_string()); DatasetBase* dataset; @@ -625,7 +625,7 @@ Status MakeIteratorOp::DoCompute(OpKernelContext* ctx) { return iterator_resource->SetIteratorFromDataset(ctx, dataset); } -Status DeleteIteratorOp::DoCompute(OpKernelContext* ctx) { +absl::Status DeleteIteratorOp::DoCompute(OpKernelContext* ctx) { tensorflow::ResourceTagger tag(kTFDataResourceTag, ctx->op_kernel().type_string()); const ResourceHandle& handle = ctx->input(0).flat()(0); @@ -660,7 +660,7 @@ class ToSingleElementOp : public AsyncOpKernel { } private: - Status DoCompute(OpKernelContext* ctx) { + absl::Status DoCompute(OpKernelContext* ctx) { tsl::profiler::TraceMe traceme( [&] { return tsl::profiler::TraceMeEncode("ToSingleElementOp::DoCompute", @@ -780,7 +780,7 @@ class OneShotIteratorOp : public AsyncOpKernel { void Init(OpKernelContext* ctx, const DoneCallback& done) { IteratorResource* iterator = nullptr; ContainerInfo cinfo; - Status s = TryInit(ctx, &iterator, &cinfo); + absl::Status s = TryInit(ctx, &iterator, &cinfo); std::vector> callbacks_to_run; { @@ -799,8 +799,8 @@ class OneShotIteratorOp : public AsyncOpKernel { ProduceOutput(ctx, done); } - Status TryInit(OpKernelContext* ctx, IteratorResource** iterator, - ContainerInfo* cinfo) { + absl::Status TryInit(OpKernelContext* ctx, IteratorResource** iterator, + ContainerInfo* cinfo) { TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def())); FunctionLibraryRuntime* flr; @@ -866,7 +866,7 @@ class OneShotIteratorOp : public AsyncOpKernel { Tensor* handle; OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &handle), done); - Status s; + absl::Status s; { mutex_lock l(mu_); s = initialization_status_; @@ -891,7 +891,7 @@ class OneShotIteratorOp : public AsyncOpKernel { IteratorResource* iterator_resource_ TF_GUARDED_BY(mu_) = nullptr; bool initialization_started_ TF_GUARDED_BY(mu_) = false; - Status initialization_status_ TF_GUARDED_BY(mu_); + absl::Status initialization_status_ TF_GUARDED_BY(mu_); std::vector> done_callbacks_ TF_GUARDED_BY(mu_); const int graph_def_version_; @@ -914,7 +914,7 @@ void RecordElementSize(const std::vector element, }); } -Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) { +absl::Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) { VLOG(3) << "IteratorGetNextOp enter. iter_id=" << ctx->frame_iter().iter_id; auto cleanup = gtl::MakeCleanup([ctx] { VLOG(3) << "IteratorGetNextOp exit. iter_id=" << ctx->frame_iter().iter_id; @@ -953,7 +953,7 @@ Status IteratorGetNextOp::DoCompute(OpKernelContext* ctx) { return absl::OkStatus(); } -Status IteratorGetModelProtoOp::DoCompute(OpKernelContext* ctx) { +absl::Status IteratorGetModelProtoOp::DoCompute(OpKernelContext* ctx) { IteratorResource* iterator = nullptr; TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, 0), &iterator)); core::ScopedUnref unref_iterator(iterator); @@ -967,7 +967,7 @@ Status IteratorGetModelProtoOp::DoCompute(OpKernelContext* ctx) { return absl::OkStatus(); } -Status IteratorGetNextAsOptionalOp::DoCompute(OpKernelContext* ctx) { +absl::Status IteratorGetNextAsOptionalOp::DoCompute(OpKernelContext* ctx) { VLOG(3) << "IteratorGetNextAsOptionalOp enter. iter_id=" << ctx->frame_iter().iter_id; auto cleanup = gtl::MakeCleanup([ctx] { @@ -1137,7 +1137,7 @@ void DeserializeIteratorOp::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized_t)); IteratorVariantSerializer serializer; OP_REQUIRES_OK(ctx, serializer.InitFromTensor(serialized_t)); - Status s = iterator_resource->Restore(ctx, serializer.GetReader()); + absl::Status s = iterator_resource->Restore(ctx, serializer.GetReader()); if (!s.ok()) { OP_REQUIRES_OK( ctx, diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h index d28b86a5d35c00..a2b134114cd32a 100644 --- a/tensorflow/core/kernels/data/iterator_ops.h +++ b/tensorflow/core/kernels/data/iterator_ops.h @@ -53,25 +53,26 @@ class IteratorResource : public ResourceBase { // // If no more outputs remain, `true` will be stored in `*end_of_sequence`, and // the content of `*out_tensors` will be undefined. - Status GetNext(OpKernelContext* ctx, std::vector* out_tensors, - bool* end_of_sequence); + absl::Status GetNext(OpKernelContext* ctx, std::vector* out_tensors, + bool* end_of_sequence); absl::Status GetModelProto(std::string& model_proto); // Saves a checkpoint of the state of the iterator through the given `writer`. - Status Save(OpKernelContext* ctx, ExternalStatePolicy external_state_policy, - IteratorStateWriter* writer); + absl::Status Save(OpKernelContext* ctx, + ExternalStatePolicy external_state_policy, + IteratorStateWriter* writer); // Restores the state of the iterator from a checkpoint created by `Save`. - Status Restore(OpKernelContext* ctx, IteratorStateReader* reader); + absl::Status Restore(OpKernelContext* ctx, IteratorStateReader* reader); // Creates an iterator for `dataset`, and associates the iterator with this // iterator resource. // // `SetIteratorFromDataset` should be called before calling `GetNext`, `Save`, // or `Restore`. - Status SetIteratorFromDataset(OpKernelContext* ctx, - const DatasetBase* dataset); + absl::Status SetIteratorFromDataset(OpKernelContext* ctx, + const DatasetBase* dataset); string DebugString() const override { return "Iterator resource"; } @@ -180,7 +181,7 @@ class IteratorHandleOp : public OpKernel { // it is compatible with this op's configuration. The verification may fail in // cases such as two graphs asking queues of the same shared name to have // inconsistent capacities. - Status VerifyResource(IteratorResource* resource); + absl::Status VerifyResource(IteratorResource* resource); FunctionLibraryRuntime* CreatePrivateFLR( OpKernelContext* ctx, std::unique_ptr* device_mgr, @@ -207,11 +208,10 @@ class AnonymousIteratorHandleOp : public AnonymousResourceOp { private: string name() override; - Status CreateResource(OpKernelContext* ctx, - std::unique_ptr flib_def, - std::unique_ptr pflr, - FunctionLibraryRuntime* lib, - IteratorResource** resource) override; + absl::Status CreateResource( + OpKernelContext* ctx, std::unique_ptr flib_def, + std::unique_ptr pflr, + FunctionLibraryRuntime* lib, IteratorResource** resource) override; DataTypeVector output_dtypes_; std::vector output_shapes_; @@ -242,7 +242,7 @@ class HybridAsyncOpKernel : public AsyncOpKernel { void ComputeAsync(OpKernelContext* ctx, DoneCallback done) final; protected: - virtual Status DoCompute(OpKernelContext* ctx) = 0; + virtual absl::Status DoCompute(OpKernelContext* ctx) = 0; private: BackgroundWorker background_worker_; @@ -254,7 +254,7 @@ class MakeIteratorOp : public HybridAsyncOpKernel { : HybridAsyncOpKernel(ctx, "tf_data_make_iterator") {} protected: - Status DoCompute(OpKernelContext* ctx) override; + absl::Status DoCompute(OpKernelContext* ctx) override; }; class IteratorGetNextOp : public HybridAsyncOpKernel { @@ -268,7 +268,7 @@ class IteratorGetNextOp : public HybridAsyncOpKernel { AsyncOpKernel* AsAsync() override; protected: - Status DoCompute(OpKernelContext* ctx) override; + absl::Status DoCompute(OpKernelContext* ctx) override; private: DataTypeVector output_types_; @@ -283,7 +283,7 @@ class IteratorGetModelProtoOp : public HybridAsyncOpKernel { /*background_worker_name=*/"tf_data_iterator_get_model_proto") {} protected: - Status DoCompute(OpKernelContext* ctx) override; + absl::Status DoCompute(OpKernelContext* ctx) override; }; class DeleteIteratorOp : public HybridAsyncOpKernel { @@ -292,7 +292,7 @@ class DeleteIteratorOp : public HybridAsyncOpKernel { : HybridAsyncOpKernel(ctx, "tf_data_delete_iterator") {} protected: - Status DoCompute(OpKernelContext* ctx) override; + absl::Status DoCompute(OpKernelContext* ctx) override; }; class IteratorGetNextAsOptionalOp : public HybridAsyncOpKernel { @@ -304,7 +304,7 @@ class IteratorGetNextAsOptionalOp : public HybridAsyncOpKernel { } protected: - Status DoCompute(OpKernelContext* ctx) override; + absl::Status DoCompute(OpKernelContext* ctx) override; private: DataTypeVector output_types_; diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index 35c5ca04cde741..bf034a569733f5 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -88,18 +88,19 @@ class MapDatasetOp::Dataset : public DatasetBase { } } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { TF_RETURN_IF_ERROR(captured_func_->CheckExternalState()); return input_->CheckExternalState(); } - Status Get(OpKernelContext* ctx, int64 index, - std::vector* out_tensors) const override { + absl::Status Get(OpKernelContext* ctx, int64 index, + std::vector* out_tensors) const override { TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index)); std::vector args; TF_RETURN_IF_ERROR(input_->Get(ctx, index, &args)); @@ -116,9 +117,9 @@ class MapDatasetOp::Dataset : public DatasetBase { } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); @@ -168,7 +169,7 @@ class MapDatasetOp::Dataset : public DatasetBase { bool SymbolicCheckpointCompatible() const override { return true; } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_)); return dataset()->captured_func_->Instantiate( @@ -178,17 +179,17 @@ class MapDatasetOp::Dataset : public DatasetBase { // NOTE(mrry): This method is thread-safe as long as `input_impl_` and `f` // are thread-safe. However, if multiple threads enter this method, // outputs may be observed in a non-deterministic order. - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { std::vector args; TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &args, end_of_sequence)); if (*end_of_sequence) { return absl::OkStatus(); } - Status s = instantiated_captured_func_->Run(ctx, std::move(args), - out_tensors, model_node()); + absl::Status s = instantiated_captured_func_->Run( + ctx, std::move(args), out_tensors, model_node()); if (errors::IsOutOfRange(s)) { if (dataset()->preserve_cardinality_) { // To guarantee that the transformation preserves the cardinality of @@ -215,16 +216,16 @@ class MapDatasetOp::Dataset : public DatasetBase { return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( dataset()->captured_func_->CheckExternalState())); TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); return absl::OkStatus(); } diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc index 3171116e7ae404..3bf82dbc4117e7 100644 --- a/tensorflow/core/kernels/data/map_defun_op.cc +++ b/tensorflow/core/kernels/data/map_defun_op.cc @@ -92,7 +92,7 @@ class MapDefunOp::MapFunctionCallFrame : public CallFrameInterface { return static_cast(kernel_->num_outputs()); } - Status GetArg(int index, const Tensor** val) override { + absl::Status GetArg(int index, const Tensor** val) override { if (index < 0 || index >= compute_opts_->args.size() + compute_opts_->captured_inputs.size()) { return errors::InvalidArgument("Mismatch in number of function inputs."); @@ -121,7 +121,7 @@ class MapDefunOp::MapFunctionCallFrame : public CallFrameInterface { return absl::OkStatus(); } - Status SetRetval(int index, const Tensor& val) override { + absl::Status SetRetval(int index, const Tensor& val) override { if (index < 0 || index >= kernel_->num_outputs()) { return errors::InvalidArgument("Mismatch in number of function outputs."); } @@ -193,7 +193,7 @@ void MapDefunOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { OP_REQUIRES_OK_ASYNC(ctx, SetupArgs(ctx, &compute_opts), done); - Status s = SetupOutputs(ctx, compute_opts); + absl::Status s = SetupOutputs(ctx, compute_opts); if (!s.ok()) delete compute_opts; OP_REQUIRES_OK_ASYNC(ctx, s, done); @@ -203,7 +203,7 @@ void MapDefunOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { // Run loop StatusCallback callback = std::bind( [](OpKernelContext* ctx, ComputeOptions* compute_opts, DoneCallback& done, - const Status& status) { + const absl::Status& status) { delete compute_opts; ctx->SetStatus(status); done(); @@ -226,7 +226,7 @@ void MapDefunOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { refcounted->Ref(); ctx->function_library()->Run( opts, func_handle_, call_frame, - [call_frame, refcounted, c_mgr](const Status& func_status) { + [call_frame, refcounted, c_mgr](const absl::Status& func_status) { delete c_mgr; delete call_frame; refcounted->UpdateStatus(func_status); @@ -254,8 +254,8 @@ void MapDefunOp::SetRunOptions(OpKernelContext* ctx, opts->run_all_kernels_inline = ctx->run_all_kernels_inline(); } -Status MapDefunOp::SetupArgs(OpKernelContext* ctx, - ComputeOptions** compute_opts) { +absl::Status MapDefunOp::SetupArgs(OpKernelContext* ctx, + ComputeOptions** compute_opts) { OpInputList arguments; TF_RETURN_IF_ERROR(ctx->input_list(kArguments, &arguments)); OpInputList captured_inputs; @@ -290,7 +290,8 @@ Status MapDefunOp::SetupArgs(OpKernelContext* ctx, return absl::OkStatus(); } -Status MapDefunOp::SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts) { +absl::Status MapDefunOp::SetupOutputs(OpKernelContext* ctx, + ComputeOptions* opts) { mutex_lock l(opts->mu); TF_RETURN_IF_ERROR(ctx->output_list(kOutput, &opts->output)); diff --git a/tensorflow/core/kernels/data/map_defun_op.h b/tensorflow/core/kernels/data/map_defun_op.h index cf2ed99c6ba1dd..fc4adde992e844 100644 --- a/tensorflow/core/kernels/data/map_defun_op.h +++ b/tensorflow/core/kernels/data/map_defun_op.h @@ -59,9 +59,9 @@ class MapDefunOp : public AsyncOpKernel { ComputeOptions* compute_opts, bool always_collect_stats); // Get inputs to Compute and check that they are valid. - Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts); + absl::Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts); - Status SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts); + absl::Status SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts); FunctionLibraryRuntime::Handle func_handle_; std::vector output_shapes_; diff --git a/tensorflow/core/kernels/data/map_defun_op_test.cc b/tensorflow/core/kernels/data/map_defun_op_test.cc index aaf292ef365020..5fd5f9ae667627 100644 --- a/tensorflow/core/kernels/data/map_defun_op_test.cc +++ b/tensorflow/core/kernels/data/map_defun_op_test.cc @@ -47,7 +47,7 @@ class MapDefunOpParams : public DatasetParams { return input_tensors; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { input_names->clear(); input_names->reserve(arguments_.size() + captured_inputs_.size()); @@ -62,7 +62,7 @@ class MapDefunOpParams : public DatasetParams { return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { *attr_vector = { {MapDefunOp::kTarguments, type_arguments_}, {MapDefunOp::kTcaptured, type_captured_}, @@ -90,8 +90,9 @@ class MapDefunOpParams : public DatasetParams { class MapDefunOpTest : public DatasetOpsTestBase { protected: // Creates a new `MapDefun` op kernel - Status CreateMapDefunOpKernel(const MapDefunOpParams& params, - std::unique_ptr* map_defun_kernel) { + absl::Status CreateMapDefunOpKernel( + const MapDefunOpParams& params, + std::unique_ptr* map_defun_kernel) { std::vector input_namess; TF_RETURN_IF_ERROR(params.GetInputNames(&input_namess)); AttributeVector attributes; @@ -104,7 +105,7 @@ class MapDefunOpTest : public DatasetOpsTestBase { } // Creates a new `MapDefun` op kernel context. - Status CreateMapDefunContext( + absl::Status CreateMapDefunContext( OpKernel* const op_kernel, absl::InlinedVector* const inputs, std::unique_ptr* context) { diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc index 15a19b992d66ee..35bcff7ac47f66 100644 --- a/tensorflow/core/kernels/data/model_dataset_op.cc +++ b/tensorflow/core/kernels/data/model_dataset_op.cc @@ -93,19 +93,20 @@ class ModelDatasetOp::Dataset : public DatasetBase { return input_->Cardinality(options); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); AttrValue algorithm_attr; @@ -140,14 +141,14 @@ class ModelDatasetOp::Dataset : public DatasetBase { ~Iterator() override { cancellation_manager_->StartCancel(); } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { return dataset()->input_->MakeIterator(IteratorContext(CreateParams(ctx)), this, prefix(), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { if (!ctx->model()) { mutex_lock l(mu_); TF_RETURN_IF_ERROR(EnsureOptimizationLoopThreadStarted(ctx)); @@ -163,13 +164,13 @@ class ModelDatasetOp::Dataset : public DatasetBase { /*ratio=*/1); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { return SaveInput(ctx, writer, input_impl_); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { return RestoreInput(IteratorContext(CreateParams(ctx)), reader, input_impl_); } @@ -187,7 +188,7 @@ class ModelDatasetOp::Dataset : public DatasetBase { return params; } - Status EnsureOptimizationLoopThreadStarted(IteratorContext* ctx) + absl::Status EnsureOptimizationLoopThreadStarted(IteratorContext* ctx) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (!model_thread_) { auto ram_budget_manager = ctx->ram_budget_manager(); @@ -195,7 +196,7 @@ class ModelDatasetOp::Dataset : public DatasetBase { ctx->StartThread("tf_data_model", [this, ram_budget_manager]() { int64_t captured_cpu_budget = cpu_budget_; int64_t captured_ram_budget = ram_budget_; - Status status = model_->OptimizeLoop( + absl::Status status = model_->OptimizeLoop( dataset()->algorithm_, [captured_cpu_budget]() { return captured_cpu_budget; }, 1.0, captured_ram_budget, *ram_budget_manager, diff --git a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc index 117a9611ec5285..f6dd98365b50b5 100644 --- a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc +++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc @@ -55,7 +55,7 @@ const char kOutputShapes[] = "output_shapes"; const char kOutputTypes[] = "output_types"; struct HostBufferElement { - Status status; + absl::Status status; bool end_of_sequence; std::vector value; }; @@ -108,8 +108,9 @@ class MultiDeviceIterator : public ResourceBase { " devices"); } - Status Init(std::unique_ptr iterator, int64_t max_buffer_size, - int64_t* incarnation_id, DatasetBase* dataset) { + absl::Status Init(std::unique_ptr iterator, + int64_t max_buffer_size, int64_t* incarnation_id, + DatasetBase* dataset) { if (iterator) { TF_RETURN_IF_ERROR( VerifyTypesMatch(output_types_, iterator->output_dtypes())); @@ -133,9 +134,9 @@ class MultiDeviceIterator : public ResourceBase { return absl::OkStatus(); } - Status GetNextFromShard(OpKernelContext* ctx, int shard_num, - int64_t incarnation_id, - MultiDeviceIteratorCallback callback) { + absl::Status GetNextFromShard(OpKernelContext* ctx, int shard_num, + int64_t incarnation_id, + MultiDeviceIteratorCallback callback) { tsl::profiler::TraceMe traceme([&] { return tsl::profiler::TraceMeEncode( absl::StrCat("GetNextFromShard", shard_num), @@ -557,7 +558,7 @@ class MultiDeviceIteratorHandleOp : public OpKernel { flr, std::move(function_handle_cache)); return absl::OkStatus(); })); - Status s = VerifyResource(resource); + absl::Status s = VerifyResource(resource); if (TF_PREDICT_FALSE(!s.ok())) { resource->Unref(); context->SetStatus(s); @@ -578,7 +579,7 @@ class MultiDeviceIteratorHandleOp : public OpKernel { // it is compatible with this op's configuration. The verification may fail in // cases such as two graphs asking queues of the same shared name to have // inconsistent capacities. - Status VerifyResource(MultiDeviceIterator* resource) { + absl::Status VerifyResource(MultiDeviceIterator* resource) { TF_RETURN_IF_ERROR( VerifyTypesMatch(output_types_, resource->output_types())); TF_RETURN_IF_ERROR( @@ -618,11 +619,10 @@ class AnonymousMultiDeviceIteratorOp private: string name() override { return kAnonymousMultiDeviceIterator; } - Status CreateResource(OpKernelContext* ctx, - std::unique_ptr flib_def, - std::unique_ptr pflr, - FunctionLibraryRuntime* lib, - MultiDeviceIterator** resource) override { + absl::Status CreateResource( + OpKernelContext* ctx, std::unique_ptr flib_def, + std::unique_ptr pflr, + FunctionLibraryRuntime* lib, MultiDeviceIterator** resource) override { auto function_handle_cache = std::make_unique(lib); *resource = new MultiDeviceIterator(ctx->env(), output_dtypes_, output_shapes_, @@ -723,7 +723,7 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel { [ctx, iterator, start_time, &n](const HostBufferElement& elem) { iterator->metrics_collector().RecordStop(start_time, elem.value); - Status s = elem.status; + absl::Status s = elem.status; if (!s.ok()) { ctx->SetStatus(s); } else if (elem.end_of_sequence) { @@ -737,8 +737,8 @@ class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel { }, std::placeholders::_1); - Status s = iterator->GetNextFromShard(ctx, shard_num, incarnation_id, - std::move(callback)); + absl::Status s = iterator->GetNextFromShard( + ctx, shard_num, incarnation_id, std::move(callback)); if (!s.ok()) { ctx->SetStatus(s); iterator->Unref(); diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc index 72524e30813f40..d06cb1ffe419f0 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op.cc +++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc @@ -83,8 +83,8 @@ void MakeDatasetHelper(OpKernelContext* ctx, }; core::RefCountPtr rewritten; - Status s = RewriteDataset(ctx, input, std::move(config_factory), - /*record_fingerprint=*/false, &rewritten); + absl::Status s = RewriteDataset(ctx, input, std::move(config_factory), + /*record_fingerprint=*/false, &rewritten); *output = rewritten.release(); if (errors::IsDeadlineExceeded(s)) { // Ignore DeadlineExceeded as it implies that the attempted rewrite took too diff --git a/tensorflow/core/kernels/data/optimize_dataset_op_test.cc b/tensorflow/core/kernels/data/optimize_dataset_op_test.cc index 6e03748b970582..a5e86718f12a92 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/optimize_dataset_op_test.cc @@ -44,13 +44,13 @@ class OptimizeDatasetParams : public DatasetParams { return {CreateTensor(TensorShape({1}), {optimizations_})}; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { *input_names = {OptimizeDatasetOp::kInputDataset, OptimizeDatasetOp::kOptimizations}; return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { *attr_vector = { {OptimizeDatasetOp::kOutputShapes, output_shapes_}, {OptimizeDatasetOp::kOutputTypes, output_dtypes_}, diff --git a/tensorflow/core/kernels/data/optional_ops.cc b/tensorflow/core/kernels/data/optional_ops.cc index 07967a672135e3..5b9fcef4359c0c 100644 --- a/tensorflow/core/kernels/data/optional_ops.cc +++ b/tensorflow/core/kernels/data/optional_ops.cc @@ -26,7 +26,7 @@ namespace tensorflow { namespace data { namespace { -static Status OptionalDeviceCopy( +static absl::Status OptionalDeviceCopy( const OptionalVariant& from, OptionalVariant* to, const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) { if (from.has_value()) { @@ -135,8 +135,9 @@ void OptionalGetValueOp::Compute(OpKernelContext* ctx) { } } -Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index, - std::vector value) { +absl::Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, + int output_index, + std::vector value) { OptionalVariant v(std::move(value)); Tensor* variant_t; AllocatorAttributes cpu_alloc; @@ -147,7 +148,7 @@ Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index, return absl::OkStatus(); } -Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index) { +absl::Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index) { OptionalVariant v; Tensor* variant_t; AllocatorAttributes cpu_alloc; diff --git a/tensorflow/core/kernels/data/optional_ops.h b/tensorflow/core/kernels/data/optional_ops.h index 4e69243dbb111d..8006b00b5a4452 100644 --- a/tensorflow/core/kernels/data/optional_ops.h +++ b/tensorflow/core/kernels/data/optional_ops.h @@ -27,22 +27,23 @@ namespace data { // Stores a DT_VARIANT value representing an Optional with the given value // in the `output_index`^th output of the given kernel execution context. -Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index, - std::vector value); +absl::Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, + int output_index, + std::vector value); // Stores a DT_VARIANT value representing an Optional with no value // in the `output_index`^th output of the given kernel execution context. -Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index); +absl::Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index); template -Status OptionalZerosLike(OpKernelContext* ctx, const OptionalVariant& x, - OptionalVariant* y) { +absl::Status OptionalZerosLike(OpKernelContext* ctx, const OptionalVariant& x, + OptionalVariant* y) { return OptionalZerosLike(ctx, x, y, ZerosLikeTensor); } template -Status OptionalBinaryAdd(OpKernelContext* ctx, const OptionalVariant& a, - const OptionalVariant& b, OptionalVariant* out) { +absl::Status OptionalBinaryAdd(OpKernelContext* ctx, const OptionalVariant& a, + const OptionalVariant& b, OptionalVariant* out) { return OptionalBinaryAdd(ctx, a, b, out, BinaryAddTensors); } diff --git a/tensorflow/core/kernels/data/optional_ops_util.cc b/tensorflow/core/kernels/data/optional_ops_util.cc index c504c99c7a528d..bd8ae76e67e344 100644 --- a/tensorflow/core/kernels/data/optional_ops_util.cc +++ b/tensorflow/core/kernels/data/optional_ops_util.cc @@ -26,11 +26,11 @@ limitations under the License. namespace tensorflow { namespace data { -Status OptionalZerosLike(OpKernelContext* ctx, const OptionalVariant& x, - OptionalVariant* y, - std::function - zeros_like_func) { +absl::Status OptionalZerosLike( + OpKernelContext* ctx, const OptionalVariant& x, OptionalVariant* y, + std::function + zeros_like_func) { if (!x.has_value()) { return absl::OkStatus(); } @@ -44,11 +44,11 @@ Status OptionalZerosLike(OpKernelContext* ctx, const OptionalVariant& x, return absl::OkStatus(); } -Status OptionalBinaryAdd( +absl::Status OptionalBinaryAdd( OpKernelContext* ctx, const OptionalVariant& a, const OptionalVariant& b, OptionalVariant* out, - std::function + std::function binary_add_func) { // TODO(skyewm): should adding a value to a non-value be a no-op instead? if (a.has_value() != b.has_value()) { diff --git a/tensorflow/core/kernels/data/optional_ops_util.h b/tensorflow/core/kernels/data/optional_ops_util.h index 5e3ce9141b6d03..3ee3742f8304e5 100644 --- a/tensorflow/core/kernels/data/optional_ops_util.h +++ b/tensorflow/core/kernels/data/optional_ops_util.h @@ -98,17 +98,17 @@ class OptionalVariant { std::shared_ptr> values_; }; -Status OptionalZerosLike(OpKernelContext* ctx, const OptionalVariant& x, - OptionalVariant* y, - std::function - zeros_like_func); +absl::Status OptionalZerosLike( + OpKernelContext* ctx, const OptionalVariant& x, OptionalVariant* y, + std::function + zeros_like_func); -Status OptionalBinaryAdd( +absl::Status OptionalBinaryAdd( OpKernelContext* ctx, const OptionalVariant& a, const OptionalVariant& b, OptionalVariant* out, - std::function + std::function binary_add_func); } // namespace data diff --git a/tensorflow/core/kernels/data/options_dataset_op.cc b/tensorflow/core/kernels/data/options_dataset_op.cc index 06523c8c793067..52aa8ce2ef85bd 100644 --- a/tensorflow/core/kernels/data/options_dataset_op.cc +++ b/tensorflow/core/kernels/data/options_dataset_op.cc @@ -76,8 +76,8 @@ class OptionsDatasetOp::Dataset : public DatasetBase { return input_->Cardinality(options); } - Status Get(OpKernelContext* ctx, int64 index, - std::vector* out_tensors) const override { + absl::Status Get(OpKernelContext* ctx, int64 index, + std::vector* out_tensors) const override { return input_->Get(ctx, index, out_tensors); } @@ -85,12 +85,13 @@ class OptionsDatasetOp::Dataset : public DatasetBase { return name_utils::DatasetDebugString(kDatasetType); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } @@ -99,9 +100,9 @@ class OptionsDatasetOp::Dataset : public DatasetBase { } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); AttrValue serialized_options_attr; diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc index 1c8ef0caef8b04..8bc8d91a93cc35 100644 --- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc @@ -121,19 +121,20 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase { return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* batch_size = nullptr; @@ -189,13 +190,13 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase { bool SymbolicCheckpointCompatible() const override { return true; } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { // Each row of `batch_elements` is a tuple of tensors from the // input iterator. std::vector> batch_elements; @@ -244,8 +245,8 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase { return model::MakeKnownRatioNode(std::move(args), dataset()->batch_size_); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar( prefix(), kExhausted, static_cast(!input_impl_))); @@ -255,8 +256,8 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); int64_t input_exhausted; TF_RETURN_IF_ERROR( @@ -283,9 +284,10 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase { // potentially read the input values in-place into their respective slice // locations. This would require a different GetNext() overload that // supports zero-copy, and might make sense in an optimization pass. - Status CopyBatch(IteratorContext* ctx, - const std::vector>& batch_elements, - std::vector* out_tensors) { + absl::Status CopyBatch( + IteratorContext* ctx, + const std::vector>& batch_elements, + std::vector* out_tensors) { const size_t num_tuple_components = batch_elements[0].size(); const int64_t num_batch_elements = batch_elements.size(); for (size_t component_index = 0; component_index < num_tuple_components; @@ -371,7 +373,7 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase { if (dataset()->parallel_copy_ && (batch_component.AllocatedBytes() / num_batch_elements) >= (1 << 15)) { BlockingCounter counter(num_batch_elements); - Status status; + absl::Status status; mutex status_mu; const auto num_threads = ctx->runner_threadpool_size(); const auto slice_size = num_batch_elements / num_threads; @@ -386,7 +388,7 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase { ©_element_fn]() { for (size_t j = offset; j < offset + length; ++j) { { - Status s = copy_element_fn(j); + absl::Status s = copy_element_fn(j); mutex_lock l(status_mu); status.Update(s); } diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op_test.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op_test.cc index 6edacb8b7e3444..dd74a0f00b320e 100644 --- a/tensorflow/core/kernels/data/padded_batch_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/padded_batch_dataset_op_test.cc @@ -62,7 +62,7 @@ class PaddedBatchDatasetParams : public DatasetParams { return input_tensors; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { *input_names = {PaddedBatchDatasetOp::kInputDataset, PaddedBatchDatasetOp::kBatchSize}; // Create the input names for the input padded_shapes. @@ -79,7 +79,7 @@ class PaddedBatchDatasetParams : public DatasetParams { return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { *attr_vector = {{"parallel_copy", parallel_copy_}, {"Toutput_types", output_dtypes_}, {"output_shapes", output_shapes_}, diff --git a/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc b/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc index 7527226d94a8f3..84de715927d369 100644 --- a/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_batch_dataset_op.cc @@ -137,19 +137,20 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { // Input: input_dataset Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); @@ -203,7 +204,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { bool SymbolicCheckpointCompatible() const override { return true; } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { mutex_lock l(*mu_); interleave_depth_ = ctx->interleave_depth(); @@ -232,9 +233,9 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { std::shared_ptr result; { mutex_lock l(*mu_); @@ -279,8 +280,8 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { /*max=*/ctx->runner_threadpool_size())}); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { if (ctx->symbolic_checkpoint()) { return writer->WriteScalar(prefix(), kBatchResultsSize, 0); } @@ -299,8 +300,8 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(*mu_); DCHECK(!runner_thread_); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); @@ -354,7 +355,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { bool end_of_input TF_GUARDED_BY(mu); int64_t num_elements TF_GUARDED_BY(mu); std::vector output TF_GUARDED_BY(mu); - Status status TF_GUARDED_BY(mu); + absl::Status status TF_GUARDED_BY(mu); bool call_finished TF_GUARDED_BY(&Iterator::mu_); bool output_allocated TF_GUARDED_BY(mu); const int64_t uid = -1; @@ -394,8 +395,8 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { bool end_of_input = false; for (int i = 0; i < dataset()->batch_size_ && !end_of_input; ++i) { std::vector batch_element_tuple; - Status status = input_impl_->GetNext(ctx.get(), &batch_element_tuple, - &end_of_input); + absl::Status status = input_impl_->GetNext( + ctx.get(), &batch_element_tuple, &end_of_input); { mutex_lock l(result->mu); result->end_of_input = result->end_of_input || end_of_input; @@ -420,7 +421,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { auto copy_elements_fn = [this, ctx, result, batch_elements = std::move(batch_elements)]() mutable { - Status status; + absl::Status status; { mutex_lock l(result->mu); status = CopyBatch(AnyContext(ctx.get()), std::move(batch_elements), @@ -543,8 +544,9 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { return true; } - Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader, - size_t index) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { + absl::Status ReadBatchResult(IteratorContext* ctx, + IteratorStateReader* reader, size_t index) + TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { batch_results_.push_back(std::make_shared(ctx)); std::shared_ptr result = batch_results_.back(); string batch_prefix = strings::StrCat(kBatchResults, "_", index); @@ -570,7 +572,7 @@ class ParallelBatchDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status WriteBatchResult(IteratorStateWriter* writer, size_t index) + absl::Status WriteBatchResult(IteratorStateWriter* writer, size_t index) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { std::shared_ptr result = batch_results_[index]; string batch_prefix = strings::StrCat(kBatchResults, "_", index); diff --git a/tensorflow/core/kernels/data/parallel_batch_dataset_op_test.cc b/tensorflow/core/kernels/data/parallel_batch_dataset_op_test.cc index f4fc0c519e79e6..574b7b99da5510 100644 --- a/tensorflow/core/kernels/data/parallel_batch_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/parallel_batch_dataset_op_test.cc @@ -55,7 +55,7 @@ class ParallelBatchDatasetParams : public DatasetParams { return {batch_size, num_parallel_calls, drop_remainder}; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { *input_names = {ParallelBatchDatasetOp::kInputDataset, ParallelBatchDatasetOp::kBatchSize, ParallelBatchDatasetOp::kNumParallelCalls, @@ -63,7 +63,7 @@ class ParallelBatchDatasetParams : public DatasetParams { return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { *attr_vector = { {"parallel_copy", parallel_copy_}, {"output_types", output_dtypes_}, diff --git a/tensorflow/core/kernels/data/parallel_filter_dataset_op.cc b/tensorflow/core/kernels/data/parallel_filter_dataset_op.cc index 53b488958dc1e4..e8a16c0a08e20a 100644 --- a/tensorflow/core/kernels/data/parallel_filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_filter_dataset_op.cc @@ -93,20 +93,21 @@ class ParallelFilterDatasetOp::Dataset : public DatasetBase { return name_utils::DatasetDebugString(kDatasetType); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { TF_RETURN_IF_ERROR(captured_func_->CheckExternalState()); return input_->CheckExternalState(); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); std::vector other_arguments; @@ -151,7 +152,7 @@ class ParallelFilterDatasetOp::Dataset : public DatasetBase { if (deregister_fn_) deregister_fn_(); } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { mutex_lock l(*mu_); interleave_depth_ = ctx->interleave_depth(); if (num_parallel_calls_->value == model::kAutotune) { @@ -169,9 +170,9 @@ class ParallelFilterDatasetOp::Dataset : public DatasetBase { ctx, &instantiated_captured_func_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { std::shared_ptr result; { mutex_lock l(*mu_); @@ -201,8 +202,8 @@ class ParallelFilterDatasetOp::Dataset : public DatasetBase { /*max=*/ctx->runner_threadpool_size())}); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( dataset()->captured_func_->CheckExternalState())); mutex_lock l(*mu_); @@ -238,8 +239,8 @@ class ParallelFilterDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(*mu_); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); int64_t invocation_results_size; @@ -296,7 +297,7 @@ class ParallelFilterDatasetOp::Dataset : public DatasetBase { InvocationResult() : uid(tensorflow::EnvTime::NowNanos()) {} Notification notification; - Status status; + absl::Status status; std::vector return_values; std::vector predicate_values; bool end_of_input = false; @@ -349,7 +350,7 @@ class ParallelFilterDatasetOp::Dataset : public DatasetBase { return; } result->return_values = input_element; - auto done = [this, ctx, result](Status status) { + auto done = [this, ctx, result](absl::Status status) { result->status.Update(status); // Callback is not a predicate function, set the error status of this // result. @@ -382,7 +383,7 @@ class ParallelFilterDatasetOp::Dataset : public DatasetBase { std::move(input_element)); (*ctx->runner())( [this, ctx, fn = std::move(fn), done = std::move(done)]() { - Status s; + absl::Status s; // Check whether we are already recording to prevent invalid // nesting of `RecordStart` calls. if (IsRecording(ctx.get())) { @@ -397,10 +398,10 @@ class ParallelFilterDatasetOp::Dataset : public DatasetBase { } } - Status ProcessResult(IteratorContext* ctx, - const std::shared_ptr& result, - std::vector* out_tensors, - bool* end_of_sequence) TF_LOCKS_EXCLUDED(*mu_) { + absl::Status ProcessResult(IteratorContext* ctx, + const std::shared_ptr& result, + std::vector* out_tensors, + bool* end_of_sequence) TF_LOCKS_EXCLUDED(*mu_) { if (!result->end_of_input && result->status.ok()) { *out_tensors = std::move(result->return_values); *end_of_sequence = false; @@ -522,9 +523,9 @@ class ParallelFilterDatasetOp::Dataset : public DatasetBase { return true; } - Status WriteComponentsLocked(IteratorStateWriter* writer, - const std::string& prefix, - const std::vector& values) + absl::Status WriteComponentsLocked(IteratorStateWriter* writer, + const std::string& prefix, + const std::vector& values) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { TF_RETURN_IF_ERROR(writer->WriteScalar(prefix, kSize, values.size())); for (size_t j = 0; j < values.size(); j++) { @@ -534,10 +535,10 @@ class ParallelFilterDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status ReadComponentsLocked(IteratorContext* ctx, - IteratorStateReader* reader, - const std::string& prefix, - std::vector* values) + absl::Status ReadComponentsLocked(IteratorContext* ctx, + IteratorStateReader* reader, + const std::string& prefix, + std::vector* values) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { int64_t size; TF_RETURN_IF_ERROR(reader->ReadScalar(prefix, kSize, &size)); @@ -556,8 +557,9 @@ class ParallelFilterDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status WriteStatusLocked(IteratorStateWriter* writer, - const std::string& key, const Status& status) + absl::Status WriteStatusLocked(IteratorStateWriter* writer, + const std::string& key, + const absl::Status& status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { TF_RETURN_IF_ERROR(writer->WriteScalar( key, kErrorCode, static_cast(status.code()))); @@ -568,8 +570,9 @@ class ParallelFilterDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status ReadStatusLocked(IteratorStateReader* reader, const std::string& key, - Status* status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { + absl::Status ReadStatusLocked(IteratorStateReader* reader, + const std::string& key, absl::Status* status) + TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { int64_t code_int; TF_RETURN_IF_ERROR(reader->ReadScalar(key, kErrorCode, &code_int)); absl::StatusCode code = static_cast(code_int); @@ -578,7 +581,7 @@ class ParallelFilterDatasetOp::Dataset : public DatasetBase { tstring error_message; TF_RETURN_IF_ERROR( reader->ReadScalar(key, kErrorMessage, &error_message)); - *status = Status(code, error_message); + *status = absl::Status(code, error_message); } else { *status = absl::OkStatus(); } diff --git a/tensorflow/core/kernels/data/parallel_filter_dataset_op_test.cc b/tensorflow/core/kernels/data/parallel_filter_dataset_op_test.cc index 9537ea1c1aef2c..3b27c6da5669e9 100644 --- a/tensorflow/core/kernels/data/parallel_filter_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/parallel_filter_dataset_op_test.cc @@ -56,7 +56,7 @@ class ParallelFilterDatasetParams : public DatasetParams { return input_tensors; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { input_names->clear(); input_names->reserve(input_dataset_params_.size() + other_arguments_.size()); @@ -69,7 +69,7 @@ class ParallelFilterDatasetParams : public DatasetParams { return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { *attr_vector = { {"predicate", pred_func_}, {"Targuments", type_arguments_}, {"output_shapes", output_shapes_}, {"output_types", output_dtypes_}, diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index c585aeba983f6f..84b06ad05f5072 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -288,7 +288,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { DatasetGraphDefBuilder* b, Node** output) const override { std::vector> inputs; - std::vector>> list_inputs; + std::vector>> list_inputs; int input_index = 0; Node* input_node; diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op_test.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op_test.cc index 7649e7aa996996..bd257808dd29a0 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op_test.cc @@ -73,7 +73,7 @@ class ParallelInterleaveDatasetParams : public DatasetParams { return input_tensors; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { input_names->emplace_back(ParallelInterleaveDatasetOp::kInputDataset); for (int i = 0; i < other_arguments_.size(); ++i) { input_names->emplace_back( @@ -89,7 +89,7 @@ class ParallelInterleaveDatasetParams : public DatasetParams { return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { *attr_vector = {{"f", func_}, {"deterministic", deterministic_}, {"Targuments", type_arguments_}, diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index 2774d25747340b..5c25b52f48b71c 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -162,8 +162,8 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { } } - Status Get(OpKernelContext* ctx, int64 index, - std::vector* out_tensors) const override { + absl::Status Get(OpKernelContext* ctx, int64 index, + std::vector* out_tensors) const override { TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index)); absl::call_once(instantiated_captured_func_once_, [this, ctx] { instantiated_captured_func_status_ = captured_func_->Instantiate( @@ -175,12 +175,13 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { return instantiated_captured_func_->RunInstantiated(args, out_tensors); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { TF_RETURN_IF_ERROR(captured_func_->CheckExternalState()); return input_->CheckExternalState(); } @@ -190,9 +191,9 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { // Input: input_dataset Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); @@ -287,7 +288,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { return deterministic_; } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { mutex_lock l(*mu_); interleave_depth_ = ctx->interleave_depth(); if (use_unbounded_threadpool_) { @@ -315,9 +316,9 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { std::shared_ptr result; { mutex_lock l(*mu_); @@ -378,8 +379,8 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { /*is_legacy_prefetch_autotuned=*/false, estimated_element_size); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( dataset()->captured_func_->CheckExternalState())); if (ctx->symbolic_checkpoint()) { @@ -418,8 +419,8 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(*mu_); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); DCHECK(invocation_results_.empty()); @@ -499,7 +500,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { checkpoint(MemoryCheckpoint{ctx->id_registry()}) {} Notification notification; - Status status; + absl::Status status; std::vector return_values; bool end_of_input = false; const int64_t uid; @@ -558,7 +559,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { return; } - auto done = [this, ctx, result](Status status) { + auto done = [this, ctx, result](absl::Status status) { if (!status.ok()) { result->status = AddErrorContext(status); } @@ -593,7 +594,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { std::move(input_element)); (*ctx->runner())( [this, ctx, fn = std::move(fn), done = std::move(done)]() { - Status s; + absl::Status s; // Check whether we are already recording to prevent invalid // nesting of `RecordStart` calls. if (IsRecording(ctx.get())) { @@ -608,10 +609,10 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { } } - Status ProcessResult(IteratorContext* ctx, - const std::shared_ptr& result, - std::vector* out_tensors, - bool* end_of_sequence) TF_LOCKS_EXCLUDED(*mu_) { + absl::Status ProcessResult(IteratorContext* ctx, + const std::shared_ptr& result, + std::vector* out_tensors, + bool* end_of_sequence) TF_LOCKS_EXCLUDED(*mu_) { ctx->MergeCheckpoint(&result->checkpoint); if (!result->end_of_input && result->status.ok()) { *out_tensors = std::move(result->return_values); @@ -739,8 +740,9 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { } } - Status WriteStatusLocked(IteratorStateWriter* writer, - const std::string& prefix, const Status& status) + absl::Status WriteStatusLocked(IteratorStateWriter* writer, + const std::string& prefix, + const absl::Status& status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { TF_RETURN_IF_ERROR( writer->WriteScalar(prefix, absl::StrCat("_", kErrorCode), @@ -753,8 +755,9 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status ReadStatusLocked(IteratorStateReader* reader, - const std::string& prefix, Status* status) + absl::Status ReadStatusLocked(IteratorStateReader* reader, + const std::string& prefix, + absl::Status* status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { int64_t code_int; TF_RETURN_IF_ERROR( @@ -765,7 +768,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { tstring error_message; TF_RETURN_IF_ERROR(reader->ReadScalar( prefix, absl::StrCat("_", kErrorMessage), &error_message)); - *status = Status(code, error_message); + *status = absl::Status(code, error_message); } else { *status = absl::OkStatus(); } diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc index cedc0e8adad743..cee0daa7e3161e 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc @@ -61,7 +61,7 @@ class ParallelMapDatasetParams : public DatasetParams { return input_tensors; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { input_names->emplace_back(ParallelMapDatasetOp::kInputDataset); for (int i = 0; i < other_arguments_.size(); ++i) { input_names->emplace_back( @@ -71,7 +71,7 @@ class ParallelMapDatasetParams : public DatasetParams { return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { *attr_vector = {{"f", func_}, {"Targuments", type_arguments_}, {"output_shapes", output_shapes_}, diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc index c6238a0f987a1d..7bd9bec0237b99 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc @@ -108,17 +108,18 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { return input_->Cardinality(options); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } - Status Get(OpKernelContext* ctx, int64 index, - std::vector* out_tensors) const override { + absl::Status Get(OpKernelContext* ctx, int64 index, + std::vector* out_tensors) const override { return input_->Get(ctx, index, out_tensors); } @@ -127,9 +128,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* buffer_size = nullptr; @@ -175,7 +176,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { bool SymbolicCheckpointCompatible() const override { return true; } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { mutex_lock l(*mu_); auto_tuner_ = std::make_unique( dataset()->buffer_size_, dataset()->buffer_size_min_, @@ -201,9 +202,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { const auto& stats_aggregator = ctx->stats_aggregator(); { mutex_lock l(*mu_); @@ -266,8 +267,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { /*is_legacy_prefetch_autotuned=*/legacy_autotune_); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { if (ctx->symbolic_checkpoint()) { return absl::OkStatus(); } @@ -295,8 +296,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock input_l(input_mu_); mutex_lock l(*mu_); DCHECK(!prefetch_thread_); @@ -362,7 +363,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { checkpoint(MemoryCheckpoint{ctx->id_registry()}) {} // The producer sets `status` if getting the input element fails. - Status status; + absl::Status status; // The buffered data element. std::vector value; int64_t created_us; @@ -370,8 +371,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { MemoryCheckpoint checkpoint; }; - Status RestoreBuffer(IteratorContext* const ctx, - IteratorStateReader* const reader) + absl::Status RestoreBuffer(IteratorContext* const ctx, + IteratorStateReader* const reader) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { size_t buffer_size; { @@ -420,8 +421,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { cond_var_->notify_all(); } - Status Consume(IteratorContext* ctx, std::vector* out_tensors, - bool* end_of_sequence) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + absl::Status Consume(IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { const auto& stats_aggregator = ctx->stats_aggregator(); if (stats_aggregator) { double buffer_limit_ = buffer_limit(); @@ -439,7 +441,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { } // A new element is available. Forward the status from computing it, and // (if we successfully got an element) the output values. - Status s = buffer_.front().status; + absl::Status s = buffer_.front().status; if (s.ok()) { int64_t buffer_element_id = buffer_.front().uid; tsl::profiler::TraceMe traceme( @@ -494,7 +496,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { return s; } - Status EnsureThreadsStarted(IteratorContext* ctx) + absl::Status EnsureThreadsStarted(IteratorContext* ctx) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!prefetch_thread_) { std::shared_ptr new_ctx = @@ -576,8 +578,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { } } - Status WriteStatus(IteratorStateWriter* writer, size_t index, - const Status& status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { + absl::Status WriteStatus(IteratorStateWriter* writer, size_t index, + const absl::Status& status) + TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { TF_RETURN_IF_ERROR( writer->WriteScalar(absl::StrCat(prefix(), "::", index), CodeKey(), static_cast(status.code()))); @@ -589,7 +592,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status ReadStatus(IteratorStateReader* reader, size_t index, Status* status) + absl::Status ReadStatus(IteratorStateReader* reader, size_t index, + absl::Status* status) TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) { int64_t code_int; TF_RETURN_IF_ERROR(reader->ReadScalar(absl::StrCat(prefix(), "::", index), @@ -601,7 +605,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( reader->ReadScalar(absl::StrCat(prefix(), "::", index), ErrorMessageKey(), &error_message)); - *status = Status(code, error_message); + *status = absl::Status(code, error_message); } else { *status = absl::OkStatus(); } diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op_test.cc b/tensorflow/core/kernels/data/prefetch_dataset_op_test.cc index 7fe3674db2aca4..9f4576001042f3 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op_test.cc @@ -46,14 +46,14 @@ class PrefetchDatasetParams : public DatasetParams { return {CreateTensor(TensorShape({}), {buffer_size_})}; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { input_names->clear(); input_names->emplace_back(PrefetchDatasetOp::kInputDataset); input_names->emplace_back(PrefetchDatasetOp::kBufferSize); return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { attr_vector->clear(); attr_vector->emplace_back("output_types", output_dtypes_); attr_vector->emplace_back("output_shapes", output_shapes_); diff --git a/tensorflow/core/kernels/data/random_seed_ops.cc b/tensorflow/core/kernels/data/random_seed_ops.cc index 61566cab3fabd3..ab8a0eaecb9cab 100644 --- a/tensorflow/core/kernels/data/random_seed_ops.cc +++ b/tensorflow/core/kernels/data/random_seed_ops.cc @@ -80,7 +80,7 @@ void AnonymousSeedGeneratorHandleOp::Compute(OpKernelContext* ctx) { std::string AnonymousSeedGeneratorHandleOp::name() { return kSeedGenerator; } -Status AnonymousSeedGeneratorHandleOp::CreateResource( +absl::Status AnonymousSeedGeneratorHandleOp::CreateResource( OpKernelContext* ctx, std::unique_ptr flib_def, std::unique_ptr pflr, FunctionLibraryRuntime* lib, SeedGeneratorManager** manager) diff --git a/tensorflow/core/kernels/data/random_seed_ops.h b/tensorflow/core/kernels/data/random_seed_ops.h index 638f779e4bc166..f0afa739bfd59b 100644 --- a/tensorflow/core/kernels/data/random_seed_ops.h +++ b/tensorflow/core/kernels/data/random_seed_ops.h @@ -136,11 +136,10 @@ class AnonymousSeedGeneratorHandleOp private: string name() override; - Status CreateResource(OpKernelContext* ctx, - std::unique_ptr flib_def, - std::unique_ptr pflr, - FunctionLibraryRuntime* lib, - SeedGeneratorManager** manager) override; + absl::Status CreateResource( + OpKernelContext* ctx, std::unique_ptr flib_def, + std::unique_ptr pflr, + FunctionLibraryRuntime* lib, SeedGeneratorManager** manager) override; mutex mu_; std::unique_ptr seeds_ TF_GUARDED_BY(mu_); diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc index 5834494e5a043f..dfa6d4b43cd70e 100644 --- a/tensorflow/core/kernels/data/range_dataset_op.cc +++ b/tensorflow/core/kernels/data/range_dataset_op.cc @@ -52,8 +52,8 @@ constexpr char kHasSplitProvider[] = "has_split_provider"; constexpr char kSlash[] = "/"; constexpr char kSplitProvider[] = "split_provider"; -Status ConvertOutputTypes(const tensorflow::DataTypeVector& output_dtypes, - std::vector* out_tensors, int64 value) { +absl::Status ConvertOutputTypes(const tensorflow::DataTypeVector& output_dtypes, + std::vector* out_tensors, int64 value) { switch (output_dtypes[0]) { #define HANDLE_TYPE(type) \ case DataTypeToEnum::value: { \ @@ -144,7 +144,7 @@ class RangeDatasetOp::RangeSplitProvider : public SplitProvider { RangeSplitProvider(int64_t start, int64_t stop, int64_t step) : counter_(start, stop, step) {} - Status GetNext(Tensor* split, bool* end_of_splits) override { + absl::Status GetNext(Tensor* split, bool* end_of_splits) override { int64_t next = counter_.GetNext(end_of_splits); if (*end_of_splits) { return absl::OkStatus(); @@ -154,20 +154,20 @@ class RangeDatasetOp::RangeSplitProvider : public SplitProvider { return absl::OkStatus(); } - Status Reset() override { + absl::Status Reset() override { counter_.Reset(); return absl::OkStatus(); } - Status Save(std::function key_name_fn, - IteratorStateWriter* writer) override { + absl::Status Save(std::function key_name_fn, + IteratorStateWriter* writer) override { TF_RETURN_IF_ERROR( writer->WriteScalar(key_name_fn(kNext), counter_.Peek())); return absl::OkStatus(); } - Status Restore(std::function key_name_fn, - IteratorStateReader* reader) override { + absl::Status Restore(std::function key_name_fn, + IteratorStateReader* reader) override { int64_t next; TF_RETURN_IF_ERROR(reader->ReadScalar(key_name_fn(kNext), &next)); counter_.SetNext(next); @@ -221,36 +221,37 @@ class RangeDatasetOp::Dataset : public DatasetBase { return RangeCardinality(start_, stop_, step_); } - Status MakeSplitProviders(std::vector>* - split_providers) const override { + absl::Status MakeSplitProviders(std::vector>* + split_providers) const override { split_providers->push_back( std::make_unique(start_, stop_, step_)); return absl::OkStatus(); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->clear(); return absl::OkStatus(); } - Status CheckExternalState() const override { return absl::OkStatus(); } + absl::Status CheckExternalState() const override { return absl::OkStatus(); } - Status Get(OpKernelContext* ctx, int64 index, - std::vector* out_tensors) const override { + absl::Status Get(OpKernelContext* ctx, int64 index, + std::vector* out_tensors) const override { return Get(AnyContext(ctx), index, out_tensors); } - Status Get(AnyContext ctx, int64 index, - std::vector* out_tensors) const override { + absl::Status Get(AnyContext ctx, int64 index, + std::vector* out_tensors) const override { TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index)); return ConvertOutputTypes(output_dtypes(), out_tensors, start_ + (index * step_)); } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* start = nullptr; Node* stop = nullptr; Node* step = nullptr; @@ -276,7 +277,7 @@ class RangeDatasetOp::Dataset : public DatasetBase { bool SymbolicCheckpointCompatible() const override { return true; } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { if (ctx->split_providers().empty() || dataset()->replicate_on_split_) { counter_ = std::make_unique( dataset()->start_, dataset()->stop_, dataset()->step_); @@ -287,9 +288,9 @@ class RangeDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { if (ctx->index_mapper() != nullptr) { return global_shuffle_iterator_.GetNext(ctx, out_tensors, end_of_sequence); @@ -318,8 +319,8 @@ class RangeDatasetOp::Dataset : public DatasetBase { return model::MakeSourceNode(std::move(args)); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { if (split_provider_) { TF_RETURN_IF_ERROR( writer->WriteScalar(prefix(), kHasSplitProvider, true)); @@ -336,8 +337,8 @@ class RangeDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { if (ctx->restored_element_count().has_value()) { return global_shuffle_iterator_.Restore(prefix(), ctx, reader); } diff --git a/tensorflow/core/kernels/data/reduce_dataset_op.cc b/tensorflow/core/kernels/data/reduce_dataset_op.cc index c4937bf0cd2ea8..bca2891d3c7bdf 100644 --- a/tensorflow/core/kernels/data/reduce_dataset_op.cc +++ b/tensorflow/core/kernels/data/reduce_dataset_op.cc @@ -41,7 +41,7 @@ ReduceDatasetOp::ReduceDatasetOp(OpKernelConstruction* ctx) OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_)); } -Status ReduceDatasetOp::DoCompute(OpKernelContext* ctx) { +absl::Status ReduceDatasetOp::DoCompute(OpKernelContext* ctx) { tsl::profiler::TraceMe traceme( [&] { return tsl::profiler::TraceMeEncode("ReduceDatasetOp::DoCompute", diff --git a/tensorflow/core/kernels/data/reduce_dataset_op.h b/tensorflow/core/kernels/data/reduce_dataset_op.h index 524e7e304bd9c4..73e1814480ea99 100644 --- a/tensorflow/core/kernels/data/reduce_dataset_op.h +++ b/tensorflow/core/kernels/data/reduce_dataset_op.h @@ -30,7 +30,7 @@ class ReduceDatasetOp : public HybridAsyncOpKernel { explicit ReduceDatasetOp(OpKernelConstruction* ctx); protected: - Status DoCompute(OpKernelContext* ctx) override; + absl::Status DoCompute(OpKernelContext* ctx) override; std::shared_ptr func_metadata_ = nullptr; DataTypeVector output_types_; diff --git a/tensorflow/core/kernels/data/reduce_dataset_op_test.cc b/tensorflow/core/kernels/data/reduce_dataset_op_test.cc index 65e9e0a4ba5e6c..779a9ab82104d1 100644 --- a/tensorflow/core/kernels/data/reduce_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/reduce_dataset_op_test.cc @@ -54,7 +54,7 @@ class ReduceDatasetParams : public DatasetParams { return input_tensors; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { input_names->clear(); input_names->emplace_back("input_dataset"); for (int i = 0; i < initial_state_.size(); ++i) { @@ -66,7 +66,7 @@ class ReduceDatasetParams : public DatasetParams { return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { attr_vector->clear(); *attr_vector = {{"f", func_}, {"Tstate", type_state_}, diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc index 555bdbaad31322..571a609423283a 100644 --- a/tensorflow/core/kernels/data/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc @@ -73,7 +73,7 @@ bool HasDataServiceInput(const DatasetBase* dataset) { return true; } std::vector inputs; - Status s = dataset->InputDatasets(&inputs); + absl::Status s = dataset->InputDatasets(&inputs); if (!s.ok()) { return false; } @@ -204,17 +204,18 @@ class RepeatDatasetOp::Dataset : public DatasetBase { return count_ * n; } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } - Status Get(OpKernelContext* ctx, int64 index, - std::vector* out_tensors) const override { + absl::Status Get(OpKernelContext* ctx, int64 index, + std::vector* out_tensors) const override { TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index)); return input_->Get(ctx, index % input_->Cardinality(), out_tensors); } @@ -224,9 +225,9 @@ class RepeatDatasetOp::Dataset : public DatasetBase { } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* count = nullptr; @@ -243,9 +244,9 @@ class RepeatDatasetOp::Dataset : public DatasetBase { bool SymbolicCheckpointCompatible() const override { return true; } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { *end_of_sequence = true; return absl::OkStatus(); } @@ -257,12 +258,12 @@ class RepeatDatasetOp::Dataset : public DatasetBase { /*ratio=*/kKnownRatio); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { return absl::OkStatus(); } }; @@ -274,15 +275,15 @@ class RepeatDatasetOp::Dataset : public DatasetBase { bool SymbolicCheckpointCompatible() const override { return true; } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { mutex_lock l(mu_); return dataset()->input_->MakeIterator( ctx, this, nested_prefix(prefix(), i_), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); // TODO(mrry): Make locking less conservative. if (!input_impl_) { *end_of_sequence = true; @@ -341,8 +342,8 @@ class RepeatDatasetOp::Dataset : public DatasetBase { /*ratio=*/kKnownRatio); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kCurIteration, i_)); TF_RETURN_IF_ERROR(writer->WriteScalar( @@ -353,8 +354,8 @@ class RepeatDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); int64_t input_empty; TF_RETURN_IF_ERROR( @@ -414,15 +415,15 @@ class RepeatDatasetOp::Dataset : public DatasetBase { bool SymbolicCheckpointCompatible() const override { return true; } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { mutex_lock l(mu_); return dataset()->input_->MakeIterator( ctx, this, nested_prefix(prefix(), i_), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); // TODO(mrry): Make locking less conservative. do { if (!input_impl_) { @@ -463,8 +464,8 @@ class RepeatDatasetOp::Dataset : public DatasetBase { /*ratio=*/kKnownRatio); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kCurIteration, i_)); TF_RETURN_IF_ERROR(writer->WriteScalar( @@ -475,8 +476,8 @@ class RepeatDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kCurIteration, &i_)); int64_t input_empty; diff --git a/tensorflow/core/kernels/data/repeat_dataset_op_test.cc b/tensorflow/core/kernels/data/repeat_dataset_op_test.cc index 451d7c107de04c..77a4b3a472d8a3 100644 --- a/tensorflow/core/kernels/data/repeat_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/repeat_dataset_op_test.cc @@ -47,14 +47,14 @@ class RepeatDatasetParams : public DatasetParams { return {CreateTensor(TensorShape({}), {count_})}; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { input_names->clear(); input_names->emplace_back(RepeatDatasetOp::kInputDataset); input_names->emplace_back(RepeatDatasetOp::kCount); return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { attr_vector->clear(); attr_vector->emplace_back("output_types", output_dtypes_); attr_vector->emplace_back("output_shapes", output_shapes_); diff --git a/tensorflow/core/kernels/data/rewrite_dataset_op_test.cc b/tensorflow/core/kernels/data/rewrite_dataset_op_test.cc index bde5797ffd2d69..700966242c3024 100644 --- a/tensorflow/core/kernels/data/rewrite_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/rewrite_dataset_op_test.cc @@ -42,13 +42,13 @@ class RewriteDatasetParams : public DatasetParams { return {CreateTensor(TensorShape({}), {rewrite_name_})}; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { *input_names = {RewriteDatasetOp::kInputDataset, RewriteDatasetOp::kRewriteName}; return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { attr_vector->emplace_back("output_types", output_dtypes_); attr_vector->emplace_back("output_shapes", output_shapes_); return absl::OkStatus(); diff --git a/tensorflow/core/kernels/data/shard_dataset_op.cc b/tensorflow/core/kernels/data/shard_dataset_op.cc index 607e3fc2d0c556..f027267e95cad6 100644 --- a/tensorflow/core/kernels/data/shard_dataset_op.cc +++ b/tensorflow/core/kernels/data/shard_dataset_op.cc @@ -111,17 +111,18 @@ class ShardDatasetOp::Dataset : public DatasetBase { return n / num_shards_ + (index_ < n % num_shards_ ? 1 : 0); } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } - Status Get(OpKernelContext* ctx, int64 index, - std::vector* out_tensors) const override { + absl::Status Get(OpKernelContext* ctx, int64 index, + std::vector* out_tensors) const override { TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index)); return input_->Get(ctx, index_ + (num_shards_ * index), out_tensors); } @@ -131,9 +132,9 @@ class ShardDatasetOp::Dataset : public DatasetBase { } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* num_shards = nullptr; @@ -158,7 +159,7 @@ class ShardDatasetOp::Dataset : public DatasetBase { bool SymbolicCheckpointCompatible() const override { return true; } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { if (dataset()->num_shards_ == kShardHint) { return errors::FailedPrecondition( "`tf.data.Dataset.shard(SHARD_HINT, ...)` can only be used in " @@ -171,9 +172,9 @@ class ShardDatasetOp::Dataset : public DatasetBase { return dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); *end_of_sequence = false; if (!input_impl_) { @@ -210,8 +211,9 @@ class ShardDatasetOp::Dataset : public DatasetBase { if (dataset()->require_non_empty_ && next_index_ < dataset()->num_shards_) { int num_skipped; - Status s = input_impl_->Skip(ctx, dataset()->num_shards_ - next_index_, - end_of_sequence, &num_skipped); + absl::Status s = + input_impl_->Skip(ctx, dataset()->num_shards_ - next_index_, + end_of_sequence, &num_skipped); if (*end_of_sequence || errors::IsOutOfRange(s)) { // `dataset()->require_non_empty_` implies that this transformation // was introduced by auto_sharding rewrite, so it's acceptable @@ -232,8 +234,8 @@ class ShardDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status Get(IteratorContext* ctx, std::vector* out_tensors, - bool* end_of_sequence) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + absl::Status Get(IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { IteratorContextWithIndexMapper ctx_with_index_mapper(ctx, this); auto merge_checkpoint = gtl::MakeCleanup([&ctx_with_index_mapper] { ctx_with_index_mapper.MergeCheckpoint(); @@ -274,8 +276,8 @@ class ShardDatasetOp::Dataset : public DatasetBase { std::move(args), 1.0 / static_cast(dataset()->num_shards_)); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar( prefix(), kInputImplEmpty, static_cast(!input_impl_))); @@ -287,8 +289,8 @@ class ShardDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); if (ctx->restored_element_count().has_value()) { element_count_ = *ctx->restored_element_count(); diff --git a/tensorflow/core/kernels/data/shard_dataset_op_test.cc b/tensorflow/core/kernels/data/shard_dataset_op_test.cc index d593bc6be3accb..5ef18421bcd159 100644 --- a/tensorflow/core/kernels/data/shard_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/shard_dataset_op_test.cc @@ -41,7 +41,7 @@ class ShardDatasetParams : public DatasetParams { return CreateTensors(TensorShape({}), {{num_shards_}, {index_}}); } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { input_names->clear(); input_names->emplace_back(ShardDatasetOp::kInputDataset); input_names->emplace_back(ShardDatasetOp::kNumShards); @@ -49,7 +49,7 @@ class ShardDatasetParams : public DatasetParams { return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { attr_vector->clear(); attr_vector->emplace_back("require_non_empty", require_non_empty_); attr_vector->emplace_back("output_types", output_dtypes_); diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index c1002e7c7c64e5..d84fa4ccd8d30c 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -124,17 +124,18 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { } } - Status InputDatasets(std::vector* inputs) const override { + absl::Status InputDatasets( + std::vector* inputs) const override { inputs->push_back(input_); return absl::OkStatus(); } - Status CheckExternalState() const override { + absl::Status CheckExternalState() const override { return input_->CheckExternalState(); } - Status Get(OpKernelContext* ctx, int64 index, - std::vector* out_tensors) const override { + absl::Status Get(OpKernelContext* ctx, int64 index, + std::vector* out_tensors) const override { TF_RETURN_IF_ERROR(CheckRandomAccessCompatible(index)); { mutex_lock l(mu_); @@ -201,7 +202,7 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { bool SymbolicCheckpointCompatible() const override { return true; } - Status Initialize(IteratorContext* ctx) override { + absl::Status Initialize(IteratorContext* ctx) override { mutex_lock l(mu_); seed_generator_->GenerateSeeds(&seed_, &seed2_); ResetRngs(); @@ -214,9 +215,9 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { return absl::OkStatus(); } - Status GetNextInternal(IteratorContext* ctx, - std::vector* out_tensors, - bool* end_of_sequence) override { + absl::Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { mutex_lock l(mu_); TF_RETURN_IF_ERROR(FillBuffer(ctx)); if (num_elements_ == 0) { @@ -259,8 +260,8 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { generator_.Skip(num_random_samples_); } - Status SaveInternal(SerializationContext* ctx, - IteratorStateWriter* writer) override { + absl::Status SaveInternal(SerializationContext* ctx, + IteratorStateWriter* writer) override { mutex_lock l(mu_); // Save state needed to restore the random number generators. TF_RETURN_IF_ERROR( @@ -319,8 +320,8 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { return absl::OkStatus(); } - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { + absl::Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { mutex_lock l(mu_); // Restore the random number generators. int64_t num_random_samples; @@ -440,7 +441,8 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { } // Fills the shuffle buffer, preparing the buffer for sampling. - Status FillBuffer(IteratorContext* ctx) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + absl::Status FillBuffer(IteratorContext* ctx) + TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { int64_t start_micros = EnvTime::NowMicros(); int64_t num_log_entries = 0; while (ShouldFillBuffer()) { @@ -503,7 +505,7 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase { return num_elements_ < buffer_->size(); } - Status PrepareNextEpoch(IteratorContext* ctx) + absl::Status PrepareNextEpoch(IteratorContext* ctx) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (epoch_ == 0) { slices_.push_back(std::make_unique(0, 0, false)); @@ -610,7 +612,7 @@ class ShuffleDatasetOp::Dataset : public ShuffleDatasetBase { ~Dataset() override { manager_->Unref(); - Status s = resource_mgr_->Delete( + absl::Status s = resource_mgr_->Delete( resource_handle_.container(), resource_handle_.name()); if (!s.ok()) { LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString(); @@ -620,9 +622,9 @@ class ShuffleDatasetOp::Dataset : public ShuffleDatasetBase { string op_type() const override { return kDatasetType; } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* buffer_size_node = nullptr; @@ -669,7 +671,7 @@ class ShuffleDatasetOp::DatasetV2 : public ShuffleDatasetBase { ~DatasetV2() override { manager_->Unref(); if (owns_resource_) { - Status s = resource_mgr_->Delete( + absl::Status s = resource_mgr_->Delete( resource_handle_.container(), resource_handle_.name()); if (!s.ok()) { LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString(); @@ -680,9 +682,9 @@ class ShuffleDatasetOp::DatasetV2 : public ShuffleDatasetBase { string op_type() const override { return kDatasetType; } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* buffer_size_node = nullptr; @@ -724,7 +726,7 @@ class ShuffleDatasetOp::DatasetV3 : public ShuffleDatasetBase { ~DatasetV3() override { manager_->Unref(); if (owns_resource_) { - Status s = resource_mgr_->Delete( + absl::Status s = resource_mgr_->Delete( resource_handle_.container(), resource_handle_.name()); if (!s.ok()) { LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString(); @@ -735,9 +737,9 @@ class ShuffleDatasetOp::DatasetV3 : public ShuffleDatasetBase { string op_type() const override { return kDatasetType; } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* buffer_size_node = nullptr; @@ -805,7 +807,7 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, if (op_version_ == 3) { auto handle = HandleFromInput(ctx, 4); SeedGeneratorManager* manager = nullptr; - Status s = ctx->resource_manager()->Lookup( + absl::Status s = ctx->resource_manager()->Lookup( handle.container(), handle.name(), &manager); int64_t seed; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kSeed, &seed)); @@ -842,7 +844,7 @@ void ShuffleDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, } else if (op_version_ == 2) { auto handle = HandleFromInput(ctx, 2); SeedGeneratorManager* manager = nullptr; - Status s = ctx->resource_manager()->Lookup( + absl::Status s = ctx->resource_manager()->Lookup( handle.container(), handle.name(), &manager); bool owns_resource = false; if (errors::IsNotFound(s)) { @@ -917,7 +919,7 @@ class ShuffleAndRepeatDatasetOp::Dataset : public ShuffleDatasetBase { ~Dataset() override { manager_->Unref(); - Status s = resource_mgr_->Delete( + absl::Status s = resource_mgr_->Delete( resource_handle_.container(), resource_handle_.name()); if (!s.ok()) { LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString(); @@ -927,9 +929,9 @@ class ShuffleAndRepeatDatasetOp::Dataset : public ShuffleDatasetBase { string op_type() const override { return kDatasetType; } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* buffer_size = nullptr; @@ -974,7 +976,7 @@ class ShuffleAndRepeatDatasetOp::DatasetV2 : public ShuffleDatasetBase { ~DatasetV2() override { manager_->Unref(); if (owns_resource_) { - Status s = resource_mgr_->Delete( + absl::Status s = resource_mgr_->Delete( resource_handle_.container(), resource_handle_.name()); if (!s.ok()) { LOG(WARNING) << "Failed to delete RNG resource: " << s.ToString(); @@ -985,9 +987,9 @@ class ShuffleAndRepeatDatasetOp::DatasetV2 : public ShuffleDatasetBase { string op_type() const override { return kDatasetType; } protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { + absl::Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* buffer_size_node = nullptr; @@ -1070,7 +1072,7 @@ void ShuffleAndRepeatDatasetOp::MakeDataset(OpKernelContext* ctx, if (op_version_ == 2) { auto handle = HandleFromInput(ctx, 5); SeedGeneratorManager* manager = nullptr; - Status s = ctx->resource_manager()->Lookup( + absl::Status s = ctx->resource_manager()->Lookup( handle.container(), handle.name(), &manager); bool owns_resource = false; if (errors::IsNotFound(s)) { diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op_test.cc b/tensorflow/core/kernels/data/shuffle_dataset_op_test.cc index 99f907de2c4dc7..40d2d976731b79 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op_test.cc @@ -59,7 +59,7 @@ class ShuffleDatasetParams : public DatasetParams { return input_tensors; } - Status GetInputNames(std::vector* input_names) const override { + absl::Status GetInputNames(std::vector* input_names) const override { input_names->clear(); input_names->emplace_back(ShuffleDatasetOpBase::kInputDataset); input_names->emplace_back(ShuffleDatasetOpBase::kBufferSize); @@ -71,7 +71,7 @@ class ShuffleDatasetParams : public DatasetParams { return absl::OkStatus(); } - Status GetAttributes(AttributeVector* attr_vector) const override { + absl::Status GetAttributes(AttributeVector* attr_vector) const override { attr_vector->clear(); attr_vector->emplace_back("output_types", output_dtypes_); attr_vector->emplace_back("output_shapes", output_shapes_); diff --git a/tensorflow/core/kernels/data/tf_record_dataset_op.cc b/tensorflow/core/kernels/data/tf_record_dataset_op.cc index 33a3ce97d91e9d..f13524b8e0ef34 100644 --- a/tensorflow/core/kernels/data/tf_record_dataset_op.cc +++ b/tensorflow/core/kernels/data/tf_record_dataset_op.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/lib/io/zlib_inputstream.h" #include "tensorflow/core/platform/logging.h" +#include "tsl/profiler/lib/traceme.h" namespace tensorflow { namespace data { @@ -265,6 +266,14 @@ class TFRecordDatasetOp::Dataset : public DatasetBase { } // Actually move on to next file. + tsl::profiler::TraceMe traceme( + [&, current_file_index = current_file_index_] { + return tsl::profiler::TraceMeEncode( + "TFRecordDatasetOp::Iterator::SetupStreamsLocked", + {{"filename", dataset()->filenames_[current_file_index]}}); + }, + tsl::profiler::kInfo); + TF_RETURN_IF_ERROR(env->NewRandomAccessFile( TranslateFileName(dataset()->filenames_[current_file_index_]), &file_)); diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h index 15ff88c5c229ff..92607656b52a00 100644 --- a/tensorflow/core/kernels/debug_ops.h +++ b/tensorflow/core/kernels/debug_ops.h @@ -179,11 +179,11 @@ class BaseDebugOp : public OpKernel { // Publish a tensor to all debug URLs of the debug op. // Log an error if the publishing failed. - Status PublishTensor(const Tensor& tensor, int64_t step_id = -1) { + absl::Status PublishTensor(const Tensor& tensor, int64_t step_id = -1) { if (debug_urls_.empty()) { return absl::OkStatus(); } else { - Status status = DebugIO::PublishDebugTensor( + absl::Status status = DebugIO::PublishDebugTensor( *debug_watch_key_, tensor, Env::Default()->NowMicros(), debug_urls_, gated_grpc_, step_id); if (!status.ok()) { diff --git a/tensorflow/core/kernels/debug_ops_test.cc b/tensorflow/core/kernels/debug_ops_test.cc index 102881b1fa4023..b554f78d32e607 100644 --- a/tensorflow/core/kernels/debug_ops_test.cc +++ b/tensorflow/core/kernels/debug_ops_test.cc @@ -40,7 +40,8 @@ namespace tensorflow { class DebugIdentityOpTest : public OpsTestBase { protected: - Status Init(DataType input_type, const std::vector& debug_urls) { + absl::Status Init(DataType input_type, + const std::vector& debug_urls) { env_ = Env::Default(); TF_CHECK_OK(NodeDefBuilder("op", "DebugIdentity") @@ -51,7 +52,7 @@ class DebugIdentityOpTest : public OpsTestBase { return InitOp(); } - Status Init(DataType input_type) { + absl::Status Init(DataType input_type) { std::vector empty_debug_urls; return Init(input_type, empty_debug_urls); } @@ -178,7 +179,7 @@ TEST_F(DebugIdentityOpTest, StringSuccess) { // Tests for DebugNanCountOp class DebugNanCountOpTest : public OpsTestBase { protected: - Status Init(DataType input_type) { + absl::Status Init(DataType input_type) { TF_CHECK_OK(NodeDefBuilder("op", "DebugNanCount") .Input(FakeInput(input_type)) .Attr("tensor_name", "FakeTensor:0") @@ -241,7 +242,7 @@ TEST_F(DebugNanCountOpTest, Double_no_NaNs) { // Tests for DebugNumericSummaryOp class DebugNumericSummaryOpTest : public OpsTestBase { protected: - Status Init(DataType input_type) { + absl::Status Init(DataType input_type) { TF_CHECK_OK(NodeDefBuilder("op", "DebugNumericSummary") .Input(FakeInput(input_type)) .Attr("tensor_name", "FakeTensor:0") @@ -249,7 +250,8 @@ class DebugNumericSummaryOpTest : public OpsTestBase { return InitOp(); } - Status InitGated(DataType input_type, const std::vector& debug_urls) { + absl::Status InitGated(DataType input_type, + const std::vector& debug_urls) { TF_CHECK_OK(NodeDefBuilder("op", "DebugNumericSummary") .Input(FakeInput(input_type)) .Attr("tensor_name", "FakeTensor:0") @@ -632,7 +634,7 @@ TEST_F(DebugNumericSummaryOpTest, DisabledDueToNonMatchingWatchKey) { // Tests for DebugNumericSummaryOp class DebugNumericSummaryOpCustomLowerBoundTest : public OpsTestBase { protected: - Status Init(DataType input_type) { + absl::Status Init(DataType input_type) { TF_CHECK_OK(NodeDefBuilder("op", "DebugNumericSummary") .Input(FakeInput(input_type)) .Attr("tensor_name", "FakeTensor:0") @@ -684,7 +686,7 @@ TEST_F(DebugNumericSummaryOpCustomLowerBoundTest, Float_full_house) { // Tests for DebugNumericSummaryOp class DebugNumericSummaryOpCustomLowerUpperBoundsTest : public OpsTestBase { protected: - Status Init(DataType input_type) { + absl::Status Init(DataType input_type) { TF_CHECK_OK(NodeDefBuilder("op", "DebugNumericSummary") .Input(FakeInput(input_type)) .Attr("tensor_name", "FakeTensor:0") diff --git a/tensorflow/core/kernels/decode_compressed_op.cc b/tensorflow/core/kernels/decode_compressed_op.cc index 407746a9e20b02..3f1338ab4c026f 100644 --- a/tensorflow/core/kernels/decode_compressed_op.cc +++ b/tensorflow/core/kernels/decode_compressed_op.cc @@ -34,14 +34,14 @@ class MemoryInputStream : public io::InputStreamInterface { ~MemoryInputStream() override {} - Status ReadNBytes(int64_t bytes_to_read, tstring* result) override { + absl::Status ReadNBytes(int64_t bytes_to_read, tstring* result) override { result->clear(); if (bytes_to_read < 0) { return errors::InvalidArgument("Can't read a negative number of bytes: ", bytes_to_read); } int64_t bytes = bytes_to_read; - Status s = absl::OkStatus(); + absl::Status s = absl::OkStatus(); if (pos_ + bytes_to_read > len_) { bytes = len_ - pos_; s = errors::OutOfRange("reached end of file"); @@ -56,7 +56,7 @@ class MemoryInputStream : public io::InputStreamInterface { int64_t Tell() const override { return pos_; } - Status Reset() override { + absl::Status Reset() override { pos_ = 0; return absl::OkStatus(); } @@ -107,7 +107,7 @@ class DecodeCompressedOp : public OpKernel { input_stream.get(), static_cast(kBufferSize), static_cast(kBufferSize), zlib_options)); tstring output_string; - Status s = zlib_stream->ReadNBytes(INT_MAX, &output_string); + absl::Status s = zlib_stream->ReadNBytes(INT_MAX, &output_string); OP_REQUIRES(context, (s.ok() || errors::IsOutOfRange(s)), s); output_flat(i) = std::move(output_string); } diff --git a/tensorflow/core/kernels/decode_proto_op.cc b/tensorflow/core/kernels/decode_proto_op.cc index 9ec7691a01f00e..eae7d4543785b0 100644 --- a/tensorflow/core/kernels/decode_proto_op.cc +++ b/tensorflow/core/kernels/decode_proto_op.cc @@ -86,7 +86,8 @@ struct DefaultValue { // value: the default value as obtained from the FieldDescriptor // result: the object to initialize template -Status InitDefaultValue(DataType dtype, const T value, DefaultValue* result) { +absl::Status InitDefaultValue(DataType dtype, const T value, + DefaultValue* result) { result->dtype = dtype; switch (dtype) { case DT_BOOL: @@ -126,8 +127,8 @@ Status InitDefaultValue(DataType dtype, const T value, DefaultValue* result) { } template <> -Status InitDefaultValue(DataType dtype, const tstring value, - DefaultValue* result) { +absl::Status InitDefaultValue(DataType dtype, const tstring value, + DefaultValue* result) { // These are sanity checks that should never trigger given the code that // leads here. if (TF_PREDICT_FALSE(dtype != DT_STRING)) { @@ -141,9 +142,8 @@ Status InitDefaultValue(DataType dtype, const tstring value, // Initializes a default value from the output data type and the field // descriptor. -Status InitDefaultValueFromFieldDescriptor(DataType dtype, - const FieldDescriptor* field_desc, - DefaultValue* result) { +absl::Status InitDefaultValueFromFieldDescriptor( + DataType dtype, const FieldDescriptor* field_desc, DefaultValue* result) { switch (field_desc->type()) { case WireFormatLite::TYPE_DOUBLE: return InitDefaultValue(dtype, field_desc->default_value_double(), @@ -248,7 +248,7 @@ class CountCollector { explicit CountCollector(int32* count) : count_ptr_(count) {} // Reads (in this case counts) a single value. - Status ReadValue(CodedInputStream* input, const FieldInfo& field) { + absl::Status ReadValue(CodedInputStream* input, const FieldInfo& field) { // Only repeated fields can have count > 1. if (*count_ptr_ == 0 || field.is_repeated) { (*count_ptr_)++; @@ -262,8 +262,8 @@ class CountCollector { } // Reads (in this case counts) a length-delimited list of values. - Status ReadPackedValues(CodedInputStream* input, const FieldInfo& field, - size_t buf_size) { + absl::Status ReadPackedValues(CodedInputStream* input, const FieldInfo& field, + size_t buf_size) { if (buf_size == 0) { return absl::OkStatus(); } @@ -286,7 +286,7 @@ class CountCollector { // Dispatch to the appropriately typed field reader based on the schema // type. - Status st; + absl::Status st; switch (field.type) { case WireFormatLite::TYPE_DOUBLE: st = CountPackedFixed(buf, buf_size); @@ -367,7 +367,7 @@ class CountCollector { // Counts the number of packed varints in an array. The end of a varint is // signaled by a value < 0x80, so counting them requires parsing the // bytestream. It is the caller's responsibility to ensure that len > 0. - Status CountPackedVarint(const uint8* buf, size_t len) { + absl::Status CountPackedVarint(const uint8* buf, size_t len) { const uint8* bound = buf + len; int count; @@ -396,7 +396,7 @@ class CountCollector { // Counts the number of fixed-size values in a packed field. This can be done // without actually parsing anything. template - Status CountPackedFixed(const uint8* unused_buf, size_t len) { + absl::Status CountPackedFixed(const uint8* unused_buf, size_t len) { int count = len / sizeof(T); if (count * sizeof(T) != len) { return errors::DataLoss( @@ -483,7 +483,7 @@ class DenseCollector { // Always inlining gave a ~50% speedup on microbenchmarks at one point. // TODO(nix): try removing it to see if that still holds. // TODO(jsimsa): ABSL_ATTRIBUTE_ALWAYS_INLINE - Status ReadValue(CodedInputStream* input, const FieldInfo& field) { + absl::Status ReadValue(CodedInputStream* input, const FieldInfo& field) { // For required and optional fields, we overwrite values[0] with // the latest one in the wire stream. // See https://developers.google.com/protocol-buffers/docs/encoding#optional @@ -501,8 +501,8 @@ class DenseCollector { } // Reads and stores a length-delimited list of values. - Status ReadPackedValues(CodedInputStream* input, const FieldInfo& field, - const size_t buf_size) { + absl::Status ReadPackedValues(CodedInputStream* input, const FieldInfo& field, + const size_t buf_size) { const void* buf; int unused_max_buf_size; input->GetDirectBufferPointerInline(&buf, &unused_max_buf_size); @@ -533,7 +533,7 @@ class DenseCollector { // Fills in any missing values in the output array with defaults. Dispatches // to the appropriately typed field default based on the runtime type tag. - Status FillWithDefaults() { + absl::Status FillWithDefaults() { switch (default_value_.dtype) { case DataType::DT_BOOL: return FillDefault(absl::get(default_value_.value)); @@ -569,7 +569,7 @@ class DenseCollector { // uses next_repeat_index_ which counts the number of parsed values for the // field. template - Status FillDefault(const T& default_value) { + absl::Status FillDefault(const T& default_value) { for (int i = next_repeat_index_; i < max_repeat_count_; i++) { reinterpret_cast(datap_)[i] = default_value; } @@ -839,7 +839,7 @@ class DecodeProtoOp : public OpKernel { counters.emplace_back(&field_sizes[i]); } - Status st = Collect(&input, absl::MakeSpan(counters)); + absl::Status st = Collect(&input, absl::MakeSpan(counters)); if (st.ok() && !input.ConsumedEntireMessage()) { st = errors::DataLoss("CountFields: Failed to consume entire buffer"); } @@ -928,7 +928,7 @@ class DecodeProtoOp : public OpKernel { // Fill in output tensors from the wire. CodedInputStream input(reinterpret_cast(buf.c_str()), buf.size()); - Status st = Collect(&input, absl::MakeSpan(collectors)); + absl::Status st = Collect(&input, absl::MakeSpan(collectors)); if (st.ok() && !input.ConsumedEntireMessage()) { st = errors::DataLoss( "AccumulateFields: Failed to consume entire buffer"); @@ -952,8 +952,8 @@ class DecodeProtoOp : public OpKernel { // Traverses a serialized protobuf, dispatching values to the collectors. template - Status Collect(CodedInputStream* input, - absl::Span collectors) { + absl::Status Collect(CodedInputStream* input, + absl::Span collectors) { // At the beginning of each loop, the last field number that was seen, // regardless of whether it was collected or not, or -1 if no field has // been seen before. @@ -1033,9 +1033,10 @@ class DecodeProtoOp : public OpKernel { // Collects values for a single field. template - Status CollectField(const FieldInfo& field, - WireFormatLite::WireType wire_type, - CodedInputStream* input, CollectorClass* collector) { + absl::Status CollectField(const FieldInfo& field, + WireFormatLite::WireType wire_type, + CodedInputStream* input, + CollectorClass* collector) { // The wire format library defines the same constants used in // descriptor.proto. This static_cast is safe because they are guaranteed to // stay in sync. diff --git a/tensorflow/core/kernels/dense_update_functor.h b/tensorflow/core/kernels/dense_update_functor.h index 77087957cbecf1..c16db9361f6753 100644 --- a/tensorflow/core/kernels/dense_update_functor.h +++ b/tensorflow/core/kernels/dense_update_functor.h @@ -66,14 +66,15 @@ struct DenseUpdate { } // end namespace functor template -Status VariantCopyFn(OpKernelContext* context, const Tensor& from, Tensor* to); +absl::Status VariantCopyFn(OpKernelContext* context, const Tensor& from, + Tensor* to); template <> -Status VariantCopyFn(OpKernelContext* context, const Tensor& from, - Tensor* to); +absl::Status VariantCopyFn(OpKernelContext* context, + const Tensor& from, Tensor* to); template <> -Status VariantCopyFn(OpKernelContext* context, const Tensor& from, - Tensor* to); +absl::Status VariantCopyFn(OpKernelContext* context, + const Tensor& from, Tensor* to); } // end namespace tensorflow diff --git a/tensorflow/core/kernels/deserialize_sparse_string_op.cc b/tensorflow/core/kernels/deserialize_sparse_string_op.cc index 8ec0cee143a34a..c4bad2a3e1852e 100644 --- a/tensorflow/core/kernels/deserialize_sparse_string_op.cc +++ b/tensorflow/core/kernels/deserialize_sparse_string_op.cc @@ -213,7 +213,7 @@ class DeserializeSparseOp : public OpKernel { } private: - Status Deserialize(const tstring& serialized, Tensor* result) { + absl::Status Deserialize(const tstring& serialized, Tensor* result) { TensorProto proto; if (!ParseProtoUnlimited(&proto, serialized)) { return errors::InvalidArgument("Could not parse serialized proto"); @@ -226,7 +226,7 @@ class DeserializeSparseOp : public OpKernel { return absl::OkStatus(); } - Status GetAndValidateSparseTensor( + absl::Status GetAndValidateSparseTensor( const tstring& serialized_indices, const tstring& serialized_values, const tstring& serialized_shape, DataType values_dtype, int index, Tensor* output_indices, Tensor* output_values, Tensor* output_shape) { diff --git a/tensorflow/core/kernels/deserialize_sparse_variant_op.cc b/tensorflow/core/kernels/deserialize_sparse_variant_op.cc index 7169d0f061a110..c46488401b751f 100644 --- a/tensorflow/core/kernels/deserialize_sparse_variant_op.cc +++ b/tensorflow/core/kernels/deserialize_sparse_variant_op.cc @@ -274,10 +274,11 @@ class DeserializeSparseOp : public OpKernel { } private: - Status GetAndValidateSparseTensorShape(const Variant& serialized_values, - const Variant& serialized_shape, - int index, const Tensor** output_shape, - int64_t* output_num_non_zeros) { + absl::Status GetAndValidateSparseTensorShape(const Variant& serialized_values, + const Variant& serialized_shape, + int index, + const Tensor** output_shape, + int64_t* output_num_non_zeros) { // Deserialize and validate the shape. *output_shape = serialized_shape.get(); if (*output_shape == nullptr) { @@ -300,7 +301,7 @@ class DeserializeSparseOp : public OpKernel { return absl::OkStatus(); } - Status GetAndValidateSparseTensorIndicesAndValues( + absl::Status GetAndValidateSparseTensorIndicesAndValues( const Variant& serialized_indices, const Variant& serialized_values, int index, int expected_rank, const Tensor** output_indices, const Tensor** output_values) { diff --git a/tensorflow/core/kernels/diag_op.cc b/tensorflow/core/kernels/diag_op.cc index 0b7a957bd40937..c3a550a51994a6 100644 --- a/tensorflow/core/kernels/diag_op.cc +++ b/tensorflow/core/kernels/diag_op.cc @@ -61,7 +61,7 @@ class DiagOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output_tensor)); functor::DiagFunctor diagFunc; - Status s = + absl::Status s = diagFunc(context, diagonal.NumElements(), diagonal.flat().data(), output_tensor->flat().data()); OP_REQUIRES_OK(context, s); @@ -98,8 +98,9 @@ class DiagPartOp : public OpKernel { Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); functor::DiagPartFunctor diagPartFunc; - Status s = diagPartFunc(context, out_shape.num_elements(), - tensor.flat().data(), output->flat().data()); + absl::Status s = + diagPartFunc(context, out_shape.num_elements(), tensor.flat().data(), + output->flat().data()); OP_REQUIRES_OK(context, s); } }; @@ -126,9 +127,9 @@ class DiagPartOp : public OpKernel { namespace functor { template struct DiagFunctor { - EIGEN_ALWAYS_INLINE Status operator()(OpKernelContext* context, - const int64_t size, const T* in, - T* out) { + EIGEN_ALWAYS_INLINE absl::Status operator()(OpKernelContext* context, + const int64_t size, const T* in, + T* out) { // This subprocess is responsible for writing values in index range // [start*size, limit*size) auto subDiag = [in, out, size](int64_t start, int64_t limit) { @@ -148,9 +149,9 @@ struct DiagFunctor { template struct DiagPartFunctor { - EIGEN_ALWAYS_INLINE Status operator()(OpKernelContext* context, - const int64_t size, const T* in, - T* out) { + EIGEN_ALWAYS_INLINE absl::Status operator()(OpKernelContext* context, + const int64_t size, const T* in, + T* out) { // This subprocess is responsible for extracting values in index range // [start, limit) auto subDiagPart = [in, out, size](int64_t start, int64_t limit) { diff --git a/tensorflow/core/kernels/diag_op.h b/tensorflow/core/kernels/diag_op.h index e00857da530d3f..c41da62def3369 100644 --- a/tensorflow/core/kernels/diag_op.h +++ b/tensorflow/core/kernels/diag_op.h @@ -26,14 +26,14 @@ namespace functor { template struct DiagFunctor { - Status operator()(OpKernelContext* context, const int64_t size, const T* in, - T* out); + absl::Status operator()(OpKernelContext* context, const int64_t size, + const T* in, T* out); }; template struct DiagPartFunctor { - Status operator()(OpKernelContext* context, const int64_t size, const T* in, - T* out); + absl::Status operator()(OpKernelContext* context, const int64_t size, + const T* in, T* out); }; } // namespace functor diff --git a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc index 092d62a917f7e9..ae3fe6d140fef8 100644 --- a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc +++ b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc @@ -48,12 +48,8 @@ limitations under the License. #include "tensorflow/core/util/gpu_kernel_helper.h" #include "tensorflow/core/util/transform_output_iterator.h" -#if GOOGLE_CUDA -#include "xla/stream_executor/gpu/scoped_activate_context.h" -using stream_executor::gpu::ScopedActivateContext; -#elif TENSORFLOW_USE_ROCM +#if TENSORFLOW_USE_ROCM #include "tensorflow/core/platform/rocm.h" -using stream_executor::gpu::ScopedActivateContext; #endif // GOOGLE_CUDA namespace tensorflow { @@ -298,7 +294,8 @@ class DynamicPartitionOpGPU : public AsyncOpKernel { partition_ref, cpu_tensor, done]() { { auto stream = c->op_device_context()->stream(); - ScopedActivateContext scoped_activation{stream->parent()}; + std::unique_ptr scoped_activation = + stream->parent()->Activate(); OpOutputList outputs; this->AllocateOutputs(c, &data, &partitions, &cpu_tensor, &outputs, @@ -311,7 +308,7 @@ class DynamicPartitionOpGPU : public AsyncOpKernel { int64 slice_size = data.NumElements() / N; this->GatherSlices(c, &data, &indices_out, N, slice_size, outputs); partition_ref.Unref(); - } // Release ScopedActivateContext to prevent deadlock when done + } // Release ActivateContext to prevent deadlock when done // inlines another Op kernel, which may assume the original cuda // Context. diff --git a/tensorflow/core/kernels/dynamic_partition_op_test.cc b/tensorflow/core/kernels/dynamic_partition_op_test.cc index 1af0c04d8d1939..b8d0fbee44220a 100644 --- a/tensorflow/core/kernels/dynamic_partition_op_test.cc +++ b/tensorflow/core/kernels/dynamic_partition_op_test.cc @@ -153,7 +153,7 @@ TEST_F(DynamicPartitionOpTest, Error_IndexOutOfRange) { AddInputFromArray(TensorShape({5, 3}), {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}); AddInputFromArray(TensorShape({5}), {0, 2, 99, 2, 2}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE( absl::StrContains(s.ToString(), "partitions[2] = 99 is not in [0, 4)")) << s; diff --git a/tensorflow/core/kernels/dynamic_stitch_op_test.cc b/tensorflow/core/kernels/dynamic_stitch_op_test.cc index a4983e3fe04c13..26852d3dde38b4 100644 --- a/tensorflow/core/kernels/dynamic_stitch_op_test.cc +++ b/tensorflow/core/kernels/dynamic_stitch_op_test.cc @@ -102,7 +102,7 @@ TEST_F(DynamicStitchOpTest, Error_IndicesMultiDimensional) { AddInputFromArray(TensorShape({1, 5}), {1, 6, 2, 3, 5}); AddInputFromArray(TensorShape({3}), {0, 40, 70}); AddInputFromArray(TensorShape({5}), {10, 60, 20, 30, 50}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains( s.ToString(), "data[1].shape = [5] does not start with indices[1].shape = [1,5]")) @@ -117,7 +117,7 @@ TEST_F(DynamicStitchOpTest, Error_DataNumDimsMismatch) { AddInputFromArray(TensorShape({5}), {1, 6, 2, 3, 5}); AddInputFromArray(TensorShape({3}), {0, 40, 70}); AddInputFromArray(TensorShape({1, 5}), {10, 60, 20, 30, 50}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains( s.ToString(), "data[1].shape = [1,5] does not start with indices[1].shape = [5]")) @@ -133,7 +133,7 @@ TEST_F(DynamicStitchOpTest, Error_DataDimSizeMismatch) { AddInputFromArray(TensorShape({3, 1}), {0, 40, 70}); AddInputFromArray(TensorShape({4, 2}), {10, 11, 60, 61, 20, 21, 30, 31}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE( absl::StrContains(s.ToString(), "Need data[0].shape[1:] = data[1].shape[1:], got " @@ -149,7 +149,7 @@ TEST_F(DynamicStitchOpTest, Error_DataAndIndicesSizeMismatch) { AddInputFromArray(TensorShape({5}), {1, 6, 2, 3, 5}); AddInputFromArray(TensorShape({3}), {0, 40, 70}); AddInputFromArray(TensorShape({4}), {10, 60, 20, 30}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains( s.ToString(), "data[1].shape = [4] does not start with indices[1].shape = [5]")) diff --git a/tensorflow/core/kernels/edit_distance_op.cc b/tensorflow/core/kernels/edit_distance_op.cc index b495f2f80524e4..42054edaf87195 100644 --- a/tensorflow/core/kernels/edit_distance_op.cc +++ b/tensorflow/core/kernels/edit_distance_op.cc @@ -35,11 +35,13 @@ namespace tensorflow { namespace { -Status ValidateShapes(OpKernelContext* ctx, const Tensor& hypothesis_indices, - const Tensor& hypothesis_values, - const Tensor& hypothesis_shape, - const Tensor& truth_indices, const Tensor& truth_values, - const Tensor& truth_shape) { +absl::Status ValidateShapes(OpKernelContext* ctx, + const Tensor& hypothesis_indices, + const Tensor& hypothesis_values, + const Tensor& hypothesis_shape, + const Tensor& truth_indices, + const Tensor& truth_values, + const Tensor& truth_shape) { if (!TensorShapeUtils::IsMatrix(hypothesis_indices.shape())) return errors::InvalidArgument( "hypothesis_indices should be a matrix, but got shape: ", diff --git a/tensorflow/core/kernels/encode_proto_op.cc b/tensorflow/core/kernels/encode_proto_op.cc index 50329ad8db720f..5148e849b307bd 100644 --- a/tensorflow/core/kernels/encode_proto_op.cc +++ b/tensorflow/core/kernels/encode_proto_op.cc @@ -246,8 +246,9 @@ size_t TotalPackedSize( template -Status WriteField(const FieldDescriptor& field_desc, const Tensor& input, - int message_index, int size, CodedOutputStream* output) { +absl::Status WriteField(const FieldDescriptor& field_desc, const Tensor& input, + int message_index, int size, + CodedOutputStream* output) { auto wire_type = WireFormatLite::WireTypeForFieldType( WireFormatLite::FieldType(field_desc.type())); @@ -282,9 +283,9 @@ Status WriteField(const FieldDescriptor& field_desc, const Tensor& input, // Writes a possibly repeated string, bytes, or message field. template -Status WriteVarLenField(const FieldDescriptor& field_desc, const Tensor& input, - int message_index, int size, - CodedOutputStream* output) { +absl::Status WriteVarLenField(const FieldDescriptor& field_desc, + const Tensor& input, int message_index, int size, + CodedOutputStream* output) { auto input_t = input.flat_inner_dims(); for (int64_t i = 0; i < size; i++) { const T& value = input_t(static_cast(message_index), i); @@ -319,8 +320,9 @@ static void WriteBytesAdapter(int field_number, const tstring& value, // Writes a group field. Groups are treated like submessages, but tag-delimited // instead of length-delimited. WireFormatLite handles this differently so we // code it ourselves. -Status WriteGroup(const FieldDescriptor& field_desc, const Tensor& input, - int message_index, int size, CodedOutputStream* output) { +absl::Status WriteGroup(const FieldDescriptor& field_desc, const Tensor& input, + int message_index, int size, + CodedOutputStream* output) { auto input_t = input.flat_inner_dims(); for (int64_t i = 0; i < size; i++) { const string& value = input_t(static_cast(message_index), i); @@ -338,8 +340,9 @@ Status WriteGroup(const FieldDescriptor& field_desc, const Tensor& input, // responsibility to ensure that the type of the input tensor is compatible with // the type of the proto field descriptor, and that (message_index, size-1) is // within bounds. -Status WriteField(const FieldDescriptor& field_desc, const Tensor& input, - int message_index, int size, CodedOutputStream* output) { +absl::Status WriteField(const FieldDescriptor& field_desc, const Tensor& input, + int message_index, int size, + CodedOutputStream* output) { DataType dtype = input.dtype(); switch (field_desc.type()) { diff --git a/tensorflow/core/kernels/example_parsing_ops.cc b/tensorflow/core/kernels/example_parsing_ops.cc index b7ea4c9f9bf592..163e89bc0b4b0f 100644 --- a/tensorflow/core/kernels/example_parsing_ops.cc +++ b/tensorflow/core/kernels/example_parsing_ops.cc @@ -102,8 +102,8 @@ class ParseExampleOp : public OpKernel { protected: // Copies keys from tensor to std::vector. - Status GetTensorKeys(OpKernelContext* ctx, StringPiece input_name, - std::vector* keys) const { + absl::Status GetTensorKeys(OpKernelContext* ctx, StringPiece input_name, + std::vector* keys) const { const Tensor* key_t; TF_RETURN_IF_ERROR(ctx->input(input_name, &key_t)); keys->reserve(key_t->NumElements()); @@ -115,8 +115,8 @@ class ParseExampleOp : public OpKernel { } // Copies keys from OpInputList of scalar to std::vector. - Status GetInputListKeys(OpKernelContext* ctx, StringPiece input_name, - std::vector* keys) const { + absl::Status GetInputListKeys(OpKernelContext* ctx, StringPiece input_name, + std::vector* keys) const { OpInputList key_list; TF_RETURN_IF_ERROR(ctx->input_list(input_name, &key_list)); keys->reserve(key_list.size()); @@ -127,11 +127,12 @@ class ParseExampleOp : public OpKernel { } // Validates the shapes of input tensors. - Status CheckInputShapes(const Tensor* serialized, const Tensor* names, - const OpInputList& dense_defaults, - const std::vector& dense_keys_t, - const std::vector& sparse_keys_t, - const std::vector& ragged_keys_t) const { + absl::Status CheckInputShapes( + const Tensor* serialized, const Tensor* names, + const OpInputList& dense_defaults, + const std::vector& dense_keys_t, + const std::vector& sparse_keys_t, + const std::vector& ragged_keys_t) const { if (op_version_ == 2) { if (TensorShapeUtils::IsMatrixOrHigher(serialized->shape())) { return errors::InvalidArgument( @@ -235,18 +236,19 @@ class ParseExampleOp : public OpKernel { } // Parses a single example. - Status ParseExampleScalar(const example::FastParseExampleConfig& config, - const Tensor* serialized, OpKernelContext* ctx, - example::Result* result) const { + absl::Status ParseExampleScalar(const example::FastParseExampleConfig& config, + const Tensor* serialized, + OpKernelContext* ctx, + example::Result* result) const { const tstring& serialized_proto = serialized->scalar()(); return FastParseSingleExample(config, serialized_proto, result); } // Parses a vector of examples. - Status ParseExampleVector(const example::FastParseExampleConfig& config, - const Tensor* serialized, const Tensor* names, - OpKernelContext* ctx, - example::Result* result) const { + absl::Status ParseExampleVector(const example::FastParseExampleConfig& config, + const Tensor* serialized, const Tensor* names, + OpKernelContext* ctx, + example::Result* result) const { auto serialized_t = serialized->flat(); auto names_t = names->flat(); absl::Span slice(serialized_t.data(), serialized_t.size()); @@ -256,8 +258,8 @@ class ParseExampleOp : public OpKernel { ctx->device()->tensorflow_cpu_worker_threads()->workers, result); } - Status WriteOutput(const example::Result& result, - OpKernelContext* ctx) const { + absl::Status WriteOutput(const example::Result& result, + OpKernelContext* ctx) const { OpOutputList dense_values; OpOutputList sparse_indices; OpOutputList sparse_values; @@ -488,7 +490,7 @@ class ParseSequenceExampleOp : public OpKernel { } protected: - Status CheckInputShapes( + absl::Status CheckInputShapes( const Tensor* serialized, const Tensor* names, const OpInputList& context_dense_defaults, @@ -686,10 +688,10 @@ class ParseSequenceExampleOp : public OpKernel { return config; } - Status WriteOutput(const example::Result& context_result, - const example::Result& feature_list_result, - const std::vector& dense_feature_lengths, - OpKernelContext* ctx) const { + absl::Status WriteOutput(const example::Result& context_result, + const example::Result& feature_list_result, + const std::vector& dense_feature_lengths, + OpKernelContext* ctx) const { OpOutputList context_sparse_indices; OpOutputList context_sparse_values; OpOutputList context_sparse_shapes; diff --git a/tensorflow/core/kernels/fifo_queue.cc b/tensorflow/core/kernels/fifo_queue.cc index 27adb244c4d5e9..80ea9b7febe70e 100644 --- a/tensorflow/core/kernels/fifo_queue.cc +++ b/tensorflow/core/kernels/fifo_queue.cc @@ -86,10 +86,9 @@ void FIFOQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx, } /* static */ -Status FIFOQueue::GetElementComponentFromBatch(const FIFOQueue::Tuple& tuple, - int64_t index, int component, - OpKernelContext* ctx, - Tensor* out_tensor) { +absl::Status FIFOQueue::GetElementComponentFromBatch( + const FIFOQueue::Tuple& tuple, int64_t index, int component, + OpKernelContext* ctx, Tensor* out_tensor) { TensorShape element_shape(tuple[component].shape()); element_shape.RemoveDim(0); TF_RETURN_IF_ERROR( @@ -231,8 +230,8 @@ void FIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, // an optimized case where the queue 'knows' what attributes to // use, and plumbs them through here. Tensor element; - Status status = ctx->allocate_temp(component_dtypes_[i], - ManyOutShape(i, 0), &element); + absl::Status status = ctx->allocate_temp(component_dtypes_[i], + ManyOutShape(i, 0), &element); if (!status.ok()) { ctx->SetStatus(status); callback(Tuple()); @@ -270,7 +269,7 @@ void FIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, i >= 0; --i) { for (int j = 0; j < num_components(); ++j) { Tensor element; - Status s = GetElementComponentFromBatch( + absl::Status s = GetElementComponentFromBatch( attempt->tuple, i, j, attempt->context, &element); if (!s.ok()) { attempt->context->SetStatus( @@ -355,7 +354,7 @@ void FIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, } } -Status FIFOQueue::MatchesNodeDef(const NodeDef& node_def) { +absl::Status FIFOQueue::MatchesNodeDef(const NodeDef& node_def) { if (!MatchesNodeDefOp(node_def, "FIFOQueue").ok() && !MatchesNodeDefOp(node_def, "FIFOQueueV2").ok()) { return errors::InvalidArgument("Expected FIFOQueue, found ", node_def.op()); @@ -375,7 +374,7 @@ FIFOQueueOp::FIFOQueueOp(OpKernelConstruction* context) OP_REQUIRES_OK(context, context->GetAttr("shapes", &component_shapes_)); } -Status FIFOQueueOp::CreateResource(QueueInterface** ret) { +absl::Status FIFOQueueOp::CreateResource(QueueInterface** ret) { FIFOQueue* queue = new FIFOQueue(capacity_, component_types_, component_shapes_, cinfo_.name()); return CreateTypedQueue(queue, ret); diff --git a/tensorflow/core/kernels/fifo_queue.h b/tensorflow/core/kernels/fifo_queue.h index c6faff14962b31..6648fe271dcdea 100644 --- a/tensorflow/core/kernels/fifo_queue.h +++ b/tensorflow/core/kernels/fifo_queue.h @@ -47,7 +47,7 @@ class FIFOQueue : public TypedQueue > { void TryDequeueMany(int num_elements, OpKernelContext* ctx, bool allow_small_batch, CallbackWithTuple callback) override; - Status MatchesNodeDef(const NodeDef& node_def) override; + absl::Status MatchesNodeDef(const NodeDef& node_def) override; int32 size() const override { mutex_lock lock(mu_); @@ -61,10 +61,10 @@ class FIFOQueue : public TypedQueue > { void DequeueLocked(OpKernelContext* ctx, Tuple* tuple) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); - static Status GetElementComponentFromBatch(const Tuple& tuple, int64_t index, - int component, - OpKernelContext* ctx, - Tensor* out_tensor); + static absl::Status GetElementComponentFromBatch(const Tuple& tuple, + int64_t index, int component, + OpKernelContext* ctx, + Tensor* out_tensor); private: FIFOQueue(const FIFOQueue&) = delete; @@ -80,7 +80,7 @@ class FIFOQueueOp : public TypedQueueOp { explicit FIFOQueueOp(OpKernelConstruction* context); private: - Status CreateResource(QueueInterface** ret) override + absl::Status CreateResource(QueueInterface** ret) override TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); std::vector component_shapes_; diff --git a/tensorflow/core/kernels/fill_empty_rows_functor.h b/tensorflow/core/kernels/fill_empty_rows_functor.h index c5e19846745703..2298ed92d0f8d7 100644 --- a/tensorflow/core/kernels/fill_empty_rows_functor.h +++ b/tensorflow/core/kernels/fill_empty_rows_functor.h @@ -30,19 +30,21 @@ namespace functor { template struct FillEmptyRows { // Note that the done callback is only used by the GPU implementation. - Status operator()(OpKernelContext* context, const Tensor& default_value_t, - const Tensor& indices_t, const Tensor& values_t, - const Tensor& dense_shape_t, - typename AsyncOpKernel::DoneCallback done = nullptr); + absl::Status operator()(OpKernelContext* context, + const Tensor& default_value_t, + const Tensor& indices_t, const Tensor& values_t, + const Tensor& dense_shape_t, + typename AsyncOpKernel::DoneCallback done = nullptr); }; template struct FillEmptyRows { static constexpr int IndicesRank = RaggedOperands ? 1 : 2; - Status operator()(OpKernelContext* context, const Tensor& default_value_t, - const Tensor& indices_t, const Tensor& values_t, - const Tensor& dense_shape_t, - typename AsyncOpKernel::DoneCallback done) { + absl::Status operator()(OpKernelContext* context, + const Tensor& default_value_t, + const Tensor& indices_t, const Tensor& values_t, + const Tensor& dense_shape_t, + typename AsyncOpKernel::DoneCallback done) { (void)done; // Unused (only used in GPU implementation) const int kOutputIndicesOutput = 0; const int kOutputValuesOutput = 1; @@ -210,20 +212,20 @@ struct FillEmptyRows { template struct FillEmptyRowsGrad { - Status operator()(OpKernelContext* context, - typename TTypes::ConstVec reverse_index_map, - typename TTypes::ConstVec grad_values, - typename TTypes::Vec d_values, - typename TTypes::Scalar d_default_value); + absl::Status operator()(OpKernelContext* context, + typename TTypes::ConstVec reverse_index_map, + typename TTypes::ConstVec grad_values, + typename TTypes::Vec d_values, + typename TTypes::Scalar d_default_value); }; template struct FillEmptyRowsGrad { - Status operator()(OpKernelContext* context, - typename TTypes::ConstVec reverse_index_map, - typename TTypes::ConstVec grad_values, - typename TTypes::Vec d_values, - typename TTypes::Scalar d_default_value) { + absl::Status operator()(OpKernelContext* context, + typename TTypes::ConstVec reverse_index_map, + typename TTypes::ConstVec grad_values, + typename TTypes::Vec d_values, + typename TTypes::Scalar d_default_value) { const CPUDevice& device = context->eigen_device(); const Tindex N = reverse_index_map.dimension(0); const Tindex N_full = grad_values.dimension(0); diff --git a/tensorflow/core/kernels/fill_empty_rows_functor_gpu.cu.cc b/tensorflow/core/kernels/fill_empty_rows_functor_gpu.cu.cc index 83e5f872720789..983517226163c1 100644 --- a/tensorflow/core/kernels/fill_empty_rows_functor_gpu.cu.cc +++ b/tensorflow/core/kernels/fill_empty_rows_functor_gpu.cu.cc @@ -18,7 +18,6 @@ limitations under the License. #define EIGEN_USE_GPU #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive -#include "xla/stream_executor/gpu/scoped_activate_context.h" #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_types.h" @@ -31,8 +30,6 @@ limitations under the License. #include "tensorflow/core/util/gpu_kernel_helper.h" #include "tensorflow/core/util/gpu_solvers.h" // For ScratchSpace -using stream_executor::gpu::ScopedActivateContext; - namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; @@ -338,7 +335,8 @@ struct FillEmptyRows { // Ensure that within the callback, the proper GPU settings are // configured. auto stream = context->op_device_context()->stream(); - ScopedActivateContext scoped_activation{stream->parent()}; + std::unique_ptr scoped_activation = + stream->parent()->Activate(); int first_invalid_index = *first_invalid_index_host.data(); OP_REQUIRES_ASYNC( @@ -394,7 +392,7 @@ struct FillEmptyRows { output_indices, output_values), done); } - } // Release ScopedActivateContext to prevent deadlock when done + } // Release ActivateContext to prevent deadlock when done // inlines another Op kernel, which may assume the original cuda // Context. done(); diff --git a/tensorflow/core/kernels/fingerprint_op_test.cc b/tensorflow/core/kernels/fingerprint_op_test.cc index 2ba9640d09e963..1786bd240862bd 100644 --- a/tensorflow/core/kernels/fingerprint_op_test.cc +++ b/tensorflow/core/kernels/fingerprint_op_test.cc @@ -31,7 +31,7 @@ limitations under the License. namespace tensorflow { namespace { -Status MakeNodeDef(DataType dtype, NodeDef* node_def) { +absl::Status MakeNodeDef(DataType dtype, NodeDef* node_def) { return NodeDefBuilder("fingerprint", "Fingerprint") .Input(FakeInput(dtype)) .Input(FakeInput(DT_STRING)) @@ -40,11 +40,11 @@ Status MakeNodeDef(DataType dtype, NodeDef* node_def) { class FingerprintOpTest : public OpsTestBase { protected: - Status MakeFingerprintOp(Tensor* tensor) { + absl::Status MakeFingerprintOp(Tensor* tensor) { return MakeFingerprintOp(tensor, "farmhash64"); } - Status MakeFingerprintOp(Tensor* data, const string& method) { + absl::Status MakeFingerprintOp(Tensor* data, const string& method) { TF_RETURN_IF_ERROR(MakeNodeDef(data->dtype(), node_def())); TF_RETURN_IF_ERROR(InitOp()); @@ -195,7 +195,7 @@ TEST_F(FingerprintOpTest, SupportedMethods) { Tensor tensor(DT_STRING, TensorShape{1}); TF_ASSERT_OK(MakeFingerprintOp(&tensor, "unsupported_method")); - const Status status = RunOpKernel(); + const absl::Status status = RunOpKernel(); EXPECT_FALSE(status.ok()); EXPECT_NE(status.message().find("unsupported_method"), string::npos); } diff --git a/tensorflow/core/kernels/fixed_length_record_reader_op.cc b/tensorflow/core/kernels/fixed_length_record_reader_op.cc index 8db930d7795480..def1dc12d70877 100644 --- a/tensorflow/core/kernels/fixed_length_record_reader_op.cc +++ b/tensorflow/core/kernels/fixed_length_record_reader_op.cc @@ -48,7 +48,7 @@ class FixedLengthRecordReader : public ReaderBase { // On success: // * buffered_inputstream_ != nullptr, // * buffered_inputstream_->Tell() == header_bytes_ - Status OnWorkStartedLocked() override { + absl::Status OnWorkStartedLocked() override { record_number_ = 0; lookahead_cache_.clear(); @@ -72,13 +72,13 @@ class FixedLengthRecordReader : public ReaderBase { return absl::OkStatus(); } - Status OnWorkFinishedLocked() override { + absl::Status OnWorkFinishedLocked() override { buffered_inputstream_.reset(nullptr); return absl::OkStatus(); } - Status ReadLocked(tstring* key, tstring* value, bool* produced, - bool* at_end) override { + absl::Status ReadLocked(tstring* key, tstring* value, bool* produced, + bool* at_end) override { // We will always "hop" the hop_bytes_ except the first record // where record_number_ == 0 if (record_number_ != 0) { @@ -92,7 +92,8 @@ class FixedLengthRecordReader : public ReaderBase { // as the cache_size has been skipped through cache. int64_t cache_size = lookahead_cache_.size(); lookahead_cache_.clear(); - Status s = buffered_inputstream_->SkipNBytes(hop_bytes_ - cache_size); + absl::Status s = + buffered_inputstream_->SkipNBytes(hop_bytes_ - cache_size); if (!s.ok()) { if (!errors::IsOutOfRange(s)) { return s; @@ -105,7 +106,7 @@ class FixedLengthRecordReader : public ReaderBase { // Fill up lookahead_cache_ to record_bytes_ + footer_bytes_ int bytes_to_read = record_bytes_ + footer_bytes_ - lookahead_cache_.size(); - Status s = buffered_inputstream_->ReadNBytes(bytes_to_read, value); + absl::Status s = buffered_inputstream_->ReadNBytes(bytes_to_read, value); if (!s.ok()) { value->clear(); if (!errors::IsOutOfRange(s)) { @@ -127,7 +128,7 @@ class FixedLengthRecordReader : public ReaderBase { return absl::OkStatus(); } - Status ResetLocked() override { + absl::Status ResetLocked() override { record_number_ = 0; buffered_inputstream_.reset(nullptr); lookahead_cache_.clear(); diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index b2930d4b45a670..864855de1d69f6 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -256,21 +256,23 @@ class SymbolicGradientOp : public AsyncOpKernel { } std::vector* rets = new std::vector; tsl::profiler::TraceMe trace_me("SymbolicGradientOp"); - lib->Run(opts, handle, args, rets, [ctx, done, rets](const Status& status) { - if (!status.ok()) { - ctx->SetStatus(status); - } else if (rets->size() != ctx->num_outputs()) { - ctx->SetStatus(errors::InvalidArgument( - "SymGrad expects to return ", ctx->num_outputs(), - " tensor(s), but get ", rets->size(), " tensor(s) instead.")); - } else { - for (size_t i = 0; i < rets->size(); ++i) { - ctx->set_output(i, std::move((*rets)[i])); - } - } - delete rets; - done(); - }); + lib->Run( + opts, handle, args, rets, + [ctx, done, rets](const absl::Status& status) { + if (!status.ok()) { + ctx->SetStatus(status); + } else if (rets->size() != ctx->num_outputs()) { + ctx->SetStatus(errors::InvalidArgument( + "SymGrad expects to return ", ctx->num_outputs(), + " tensor(s), but get ", rets->size(), " tensor(s) instead.")); + } else { + for (size_t i = 0; i < rets->size(); ++i) { + ctx->set_output(i, std::move((*rets)[i])); + } + } + delete rets; + done(); + }); } private: @@ -408,7 +410,8 @@ void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { lib->Run( opts, handle, args, rets, [rets, done = std::move(done), func_name, ctx, cancel_mgr, - target_device = std::move(function_target.first)](const Status& status) { + target_device = + std::move(function_target.first)](const absl::Status& status) { tsl::profiler::TraceMe activity( [&] { return tsl::profiler::TraceMeEncode( diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc index 62e5754943135c..0c1de592cbd99f 100644 --- a/tensorflow/core/kernels/functional_ops.cc +++ b/tensorflow/core/kernels/functional_ops.cc @@ -42,13 +42,13 @@ typedef std::vector TensorVec; namespace { // Helper to instantiate function "func" in the library "lib". -Status Instantiate(FunctionLibraryRuntime* lib, const NameAttrList& func, - FunctionLibraryRuntime::Handle* handle) { +absl::Status Instantiate(FunctionLibraryRuntime* lib, const NameAttrList& func, + FunctionLibraryRuntime::Handle* handle) { return lib->Instantiate(func.name(), AttrSlice(&func.attr()), handle); } -Status Instantiate(OpKernelContext* ctx, const NameAttrList& func, - FunctionLibraryRuntime::Handle* handle) { +absl::Status Instantiate(OpKernelContext* ctx, const NameAttrList& func, + FunctionLibraryRuntime::Handle* handle) { FunctionLibraryRuntime::InstantiateOptions opts; opts.executor_type = ctx->executor_type(); return ctx->function_library()->Instantiate( @@ -56,7 +56,7 @@ Status Instantiate(OpKernelContext* ctx, const NameAttrList& func, } // If "t" is a scalar of a supported type, returns t != 0 in "*v". -Status ToBool(absl::Span t, bool* v) { +absl::Status ToBool(absl::Span t, bool* v) { if (t.size() != 1) { return errors::InvalidArgument( "Expected a single scalar which can be converted to a boolean, got ", @@ -95,8 +95,8 @@ Status ToBool(absl::Span t, bool* v) { // Sets "rets" to be the output of "ctx". Validates rets' types based // on "kernel". -Status SetOutputs(const OpKernel* kernel, OpKernelContext* ctx, - absl::Span rets) { +absl::Status SetOutputs(const OpKernel* kernel, OpKernelContext* ctx, + absl::Span rets) { if (rets.size() != ctx->num_outputs()) { return errors::Internal("Expect to produce ", ctx->num_outputs(), " tensors, but only get ", rets.size()); @@ -141,12 +141,12 @@ class IfOp : public AsyncOpKernel { LOG(INFO) << "FunctionLibraryRuntime already destroyed."; continue; } - Status then_status = lib->ReleaseHandle(it.second.first.first); + absl::Status then_status = lib->ReleaseHandle(it.second.first.first); if (!then_status.ok()) { LOG(INFO) << "Ignoring error while destructing IfOp then function: " << then_status; } - Status else_status = lib->ReleaseHandle(it.second.first.second); + absl::Status else_status = lib->ReleaseHandle(it.second.first.second); if (!else_status.ok()) { LOG(INFO) << "Ignoring error while destructing IfOp else function: " << else_status; @@ -204,7 +204,7 @@ class IfOp : public AsyncOpKernel { // Evaluate one of the branch. opts_, handle, args_, &rets_, // Done callback - [this](Status s) { + [this](absl::Status s) { if (s.ok()) { s = SetOutputs(kernel_, ctx_, rets_); } @@ -228,8 +228,8 @@ class IfOp : public AsyncOpKernel { TensorVec rets_; }; - Status GetHandles(OpKernelContext* ctx, FHandle* then_handle, - FHandle* else_handle) { + absl::Status GetHandles(OpKernelContext* ctx, FHandle* then_handle, + FHandle* else_handle) { // TODO(b/37549631): Because this op has `SetIsStateful()` in its // op registration, this kernel may be shared by multiple // subgraphs, which have different associated @@ -284,7 +284,7 @@ class CaseOp : public AsyncOpKernel { } for (const auto& handle : it.second.first) { - Status status = lib->ReleaseHandle(handle); + absl::Status status = lib->ReleaseHandle(handle); if (!status.ok()) { LOG(INFO) << "Ignoring error while destructing CaseOp branch function: " @@ -314,8 +314,8 @@ class CaseOp : public AsyncOpKernel { tsl::core::WeakPtr>> handles_ ABSL_GUARDED_BY(mu_); - Status GetHandles(OpKernelContext* ctx, - std::vector& branch_handles) { + absl::Status GetHandles(OpKernelContext* ctx, + std::vector& branch_handles) { // TODO(b/37549631): Because this op has `SetIsStateful()` in its // op registration, this kernel may be shared by multiple // subgraphs, which have different associated @@ -383,7 +383,7 @@ class CaseOp : public AsyncOpKernel { // Evaluate one of the branch. opts_, branch_handles_[branch], args_, &rets_, // Done callback - [this](Status s) { + [this](absl::Status s) { if (s.ok()) { s = SetOutputs(kernel_, ctx_, rets_); } @@ -444,13 +444,13 @@ class WhileOp : public AsyncOpKernel { LOG(INFO) << "FunctionLibraryRuntime already destroyed."; continue; } - Status cond_status = lib->ReleaseHandle(it.second.first.first); + absl::Status cond_status = lib->ReleaseHandle(it.second.first.first); if (!cond_status.ok()) { LOG(INFO) << "Ignoring error while destructing WhileOp condition function: " << cond_status; } - Status body_status = lib->ReleaseHandle(it.second.first.second); + absl::Status body_status = lib->ReleaseHandle(it.second.first.second); if (!body_status.ok()) { LOG(INFO) << "Ignoring error while destructing WhileOp body function: " << body_status; @@ -476,7 +476,7 @@ class WhileOp : public AsyncOpKernel { void Compute(OpKernelContext* ctx) override { // Use the non-callback-based implementation when the synchronous Compute() // method is invoked, because the caller is explicitly donating a thread. - Status s = DoComputeSync(ctx); + absl::Status s = DoComputeSync(ctx); // NOTE: Unfortunately, we cannot use OP_REQUIRES_OK here, because this is // still an AsyncOpKernel, and there is a run-time check to avoid calling // OP_REQUIRES_OK in AsyncOpKernel::ComputeAsync() (which would deadlock in @@ -496,9 +496,9 @@ class WhileOp : public AsyncOpKernel { tsl::core::WeakPtr>> handles_ ABSL_GUARDED_BY(mu_); - static Status CondResultToBool(OpKernelContext* ctx, - const FunctionLibraryRuntime::Options& opts, - const Tensor& cond_t, bool* out_result) { + static absl::Status CondResultToBool( + OpKernelContext* ctx, const FunctionLibraryRuntime::Options& opts, + const Tensor& cond_t, bool* out_result) { bool is_pluggable = ctx->op_device_context() && ctx->op_device_context()->IsPluggableDevice(); const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info = @@ -553,7 +553,7 @@ class WhileOp : public AsyncOpKernel { size_t num_args() const override { return args_->size(); } size_t num_retvals() const override { return retvals_->size(); } - Status GetArg(int index, const Tensor** val) override { + absl::Status GetArg(int index, const Tensor** val) override { if (index < args_->size()) { *val = &(*args_)[index]; return absl::OkStatus(); @@ -571,7 +571,7 @@ class WhileOp : public AsyncOpKernel { return index >= 0 && index < args_->size(); } - Status SetRetval(int index, const Tensor& val) override { + absl::Status SetRetval(int index, const Tensor& val) override { if (TF_PREDICT_FALSE(index < 0)) { return errors::InvalidArgument( "Expected non-negative return value index, but got: ", index, "."); @@ -638,7 +638,7 @@ class WhileOp : public AsyncOpKernel { // Evaluate the condition. opts_, cond_handle_, args_, &rets_, // Done cb. - [this](const Status& s) { + [this](const absl::Status& s) { if (!s.ok()) { return Finish(s); } @@ -647,7 +647,7 @@ class WhileOp : public AsyncOpKernel { } void StartBody() { - Status s; + absl::Status s; if (rets_.size() != 1) { s = errors::InvalidArgument( "Expected a single scalar return value from WhileOp cond, got ", @@ -674,7 +674,7 @@ class WhileOp : public AsyncOpKernel { // Evaluate the body. opts_, body_handle_, body_frame_.get(), // Done callback - [this](const Status& s) { + [this](const absl::Status& s) { if (!s.ok()) { return Finish(s); } @@ -690,7 +690,7 @@ class WhileOp : public AsyncOpKernel { }); } - void Finish(Status s) { + void Finish(absl::Status s) { if (s.ok()) { s = SetOutputs(kernel_, ctx_, args_); } @@ -700,7 +700,7 @@ class WhileOp : public AsyncOpKernel { } }; - Status DoComputeSync(OpKernelContext* ctx) { + absl::Status DoComputeSync(OpKernelContext* ctx) { FHandle cond_handle; FHandle body_handle; TF_RETURN_IF_ERROR(GetHandles(ctx, &cond_handle, &body_handle)); @@ -755,8 +755,8 @@ class WhileOp : public AsyncOpKernel { } while (true); } - Status GetHandles(OpKernelContext* ctx, FHandle* cond_handle, - FHandle* body_handle) { + absl::Status GetHandles(OpKernelContext* ctx, FHandle* cond_handle, + FHandle* body_handle) { // TODO(b/37549631): Because this op has `SetIsStateful()` in its // op registration, this kernel may be shared by multiple // subgraphs, which have different associated @@ -817,8 +817,8 @@ class ToBoolOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("ToBool").Device(DEVICE_CPU), ToBoolOp); -Status GetScalar(OpKernelContext* ctx, int index, int32* value, - const char* label) { +absl::Status GetScalar(OpKernelContext* ctx, int index, int32* value, + const char* label) { Tensor t = ctx->input(index); if (!TensorShapeUtils::IsScalar(t.shape())) { return errors::InvalidArgument(label, " must be a scalar, but ", @@ -843,7 +843,7 @@ class ForOp : public AsyncOpKernel { LOG(INFO) << "FunctionLibraryRuntime already destroyed."; continue; } - Status status = lib->ReleaseHandle(it.second.first); + absl::Status status = lib->ReleaseHandle(it.second.first); if (!status.ok()) { LOG(INFO) << "Ignoring error while destructing ForOp body function: " << status; @@ -866,7 +866,7 @@ class ForOp : public AsyncOpKernel { std::pair>> handles_ ABSL_GUARDED_BY(mu_); - Status GetHandles(OpKernelContext* ctx, FHandle* body_handle) { + absl::Status GetHandles(OpKernelContext* ctx, FHandle* body_handle) { // TODO(b/37549631): Because this op has `SetIsStateful()` in its // op registration, this kernel may be shared by multiple // subgraphs, which have different associated @@ -922,7 +922,7 @@ class ForOp : public AsyncOpKernel { ~State() = default; void Start() { - Status s = StartLoop(); + absl::Status s = StartLoop(); if (!s.ok()) Finish(s); } @@ -942,7 +942,7 @@ class ForOp : public AsyncOpKernel { // If an error e is returned, caller must call Finish(e). // If OK is returned, the async loop execution has been started. - Status StartLoop() { + absl::Status StartLoop() { SetRunOptions(ctx_, &opts_, false /* always_collect_stats */); TF_RETURN_IF_ERROR(GetScalar(ctx_, 0, iter_, "start")); @@ -983,17 +983,18 @@ class ForOp : public AsyncOpKernel { } rets_.clear(); tsl::profiler::TraceMe trace_me("ForOp"); - lib_->Run(opts_, body_handle_, args_, &rets_, [this](const Status& s) { - if (s.ok()) { - *iter_ += delta_; - RunNext(); - } else { - Finish(s); - } - }); + lib_->Run(opts_, body_handle_, args_, &rets_, + [this](const absl::Status& s) { + if (s.ok()) { + *iter_ += delta_; + RunNext(); + } else { + Finish(s); + } + }); } - void Finish(Status s) { + void Finish(absl::Status s) { if (s.ok()) { s = SetOutputs(kernel_, ctx_, rets_); } diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc index 241ecfe789f371..3d510e4b50dadd 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/fused_batch_norm_op.cc @@ -69,8 +69,9 @@ string ToString(FusedBatchNormActivationMode activation_mode) { } } -Status ParseActivationMode(OpKernelConstruction* context, - FusedBatchNormActivationMode* activation_mode) { +absl::Status ParseActivationMode( + OpKernelConstruction* context, + FusedBatchNormActivationMode* activation_mode) { string activation_mode_str; TF_RETURN_IF_ERROR(context->GetAttr("activation_mode", &activation_mode_str)); @@ -230,7 +231,7 @@ struct FusedBatchNorm { if (tensor_format == FORMAT_NCHW) { // Perform NHWC to NCHW const std::array perm = {0, 3, 1, 2}; - const Status s = ::tensorflow::DoTranspose( + const absl::Status s = ::tensorflow::DoTranspose( context->eigen_device(), transformed_y, perm, y_output); if (!s.ok()) { context->SetStatus(errors::InvalidArgument("Transpose failed: ", s)); @@ -350,7 +351,7 @@ struct FusedBatchNorm { if (tensor_format == FORMAT_NCHW) { // Perform NHWC to NCHW const std::array perm = {0, 3, 1, 2}; - const Status s = ::tensorflow::DoTranspose( + const absl::Status s = ::tensorflow::DoTranspose( context->eigen_device(), transformed_y, perm, y_output); if (!s.ok()) { context->SetStatus(errors::InvalidArgument("Transpose failed: ", s)); diff --git a/tensorflow/core/kernels/fused_batch_norm_op.h b/tensorflow/core/kernels/fused_batch_norm_op.h index 2e2d32b8f5ab41..e50d80aef15081 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op.h +++ b/tensorflow/core/kernels/fused_batch_norm_op.h @@ -32,8 +32,8 @@ enum class FusedBatchNormActivationMode { kIdentity, kRelu }; std::string ToString(FusedBatchNormActivationMode activation_mode); -Status ParseActivationMode(OpKernelConstruction* context, - FusedBatchNormActivationMode* activation_mode); +absl::Status ParseActivationMode(OpKernelConstruction* context, + FusedBatchNormActivationMode* activation_mode); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM diff --git a/tensorflow/core/kernels/fused_eigen_output_kernels.cc b/tensorflow/core/kernels/fused_eigen_output_kernels.cc index 39054713db5ac2..5af4c9cb6f7028 100644 --- a/tensorflow/core/kernels/fused_eigen_output_kernels.cc +++ b/tensorflow/core/kernels/fused_eigen_output_kernels.cc @@ -21,7 +21,7 @@ limitations under the License. namespace tensorflow { -Status InitializeFusedComputation( +absl::Status InitializeFusedComputation( OpKernelConstruction* context, const std::string& kernel_name, const std::vector& patterns, FusedComputationType* fused_computation, diff --git a/tensorflow/core/kernels/fused_eigen_output_kernels.h b/tensorflow/core/kernels/fused_eigen_output_kernels.h index 21bcf17df3e9d6..84a0d27ba60840 100644 --- a/tensorflow/core/kernels/fused_eigen_output_kernels.h +++ b/tensorflow/core/kernels/fused_eigen_output_kernels.h @@ -66,7 +66,7 @@ struct FusedComputationPattern { // Parse attributes from the kernel construction context, and verifies that they // specify valid fused computation pattern. -Status InitializeFusedComputation( +absl::Status InitializeFusedComputation( OpKernelConstruction* context, const string& kernel_name, const std::vector& patterns, FusedComputationType* fused_computation, @@ -409,8 +409,8 @@ template using WithFusedBatchNormAndLeakyRelu = FusedBatchNormOutputKernel; template -Status InitBiasAddArgs(OpKernelContext* context, BiasAddArgs* args, - const float* leakyrelu_alpha = nullptr) { +absl::Status InitBiasAddArgs(OpKernelContext* context, BiasAddArgs* args, + const float* leakyrelu_alpha = nullptr) { // Bias of the following dimensions: [ output_depth ] const Tensor& bias = context->input(2); @@ -432,9 +432,9 @@ Status InitBiasAddArgs(OpKernelContext* context, BiasAddArgs* args, } template -Status InitFusedBatchNormArgs(OpKernelContext* context, float epsilon, - FusedBatchNormArgs* args, - const float* leakyrelu_alpha = nullptr) { +absl::Status InitFusedBatchNormArgs(OpKernelContext* context, float epsilon, + FusedBatchNormArgs* args, + const float* leakyrelu_alpha = nullptr) { const Tensor& scale = context->input(2); const Tensor& offset = context->input(3); const Tensor& estimated_mean = context->input(4); diff --git a/tensorflow/core/kernels/gather_nd_op.h b/tensorflow/core/kernels/gather_nd_op.h index e79ec675e26dec..b53e1348713edc 100644 --- a/tensorflow/core/kernels/gather_nd_op.h +++ b/tensorflow/core/kernels/gather_nd_op.h @@ -45,7 +45,7 @@ struct GatherNdSlice { }; template -Status DoGatherNd( +absl::Status DoGatherNd( OpKernelContext* c, const Tensor& params, const Tensor& indices, Tensor* out, BadIndicesPolicy bad_indices_policy = BadIndicesPolicy::kDefault) { diff --git a/tensorflow/core/kernels/gather_nd_op_test.cc b/tensorflow/core/kernels/gather_nd_op_test.cc index 2758fbb3a57fe1..8fd4d26453decc 100644 --- a/tensorflow/core/kernels/gather_nd_op_test.cc +++ b/tensorflow/core/kernels/gather_nd_op_test.cc @@ -89,7 +89,7 @@ TEST_F(GatherNdOpTest, Error_OutOfRange) { // Feed and run AddInputFromArray(TensorShape({5}), {0, 1, 2, 8, 4}); AddInputFromArray(TensorShape({2, 1}), {3, 5}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains( s.message(), "indices[1] = [5] does not index into param shape [5]")) << s.message(); diff --git a/tensorflow/core/kernels/gather_op_test.cc b/tensorflow/core/kernels/gather_op_test.cc index 636c3cf18fd336..f3cbce2fb249bb 100644 --- a/tensorflow/core/kernels/gather_op_test.cc +++ b/tensorflow/core/kernels/gather_op_test.cc @@ -184,7 +184,7 @@ TEST_F(GatherOpTest, Error_IndexOutOfRange) { {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}); AddInputFromArray(TensorShape({4}), {0, 4, 99, 2}); AddInputFromArray(TensorShape({}), {0}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE( absl::StrContains(s.ToString(), "indices[2] = 99 is not in [0, 5)")) << s; @@ -198,7 +198,7 @@ TEST_F(GatherOpTest, Error_BatchDimsOutOfRange) { {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}); AddInputFromArray(TensorShape({4}), {0, 4, 99, 2}); AddInputFromArray(TensorShape({}), {0}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains( s.ToString(), "Expected batch_dims in the range [-1, 1], but got 10")) << s; diff --git a/tensorflow/core/kernels/guarantee_const_op_test.cc b/tensorflow/core/kernels/guarantee_const_op_test.cc index 75ef6decf0144a..09e74987c72598 100644 --- a/tensorflow/core/kernels/guarantee_const_op_test.cc +++ b/tensorflow/core/kernels/guarantee_const_op_test.cc @@ -30,7 +30,7 @@ namespace { class GuaranteeConstOpTest : public OpsTestBase { protected: - Status Init(DataType input_type) { + absl::Status Init(DataType input_type) { TF_CHECK_OK(NodeDefBuilder("op", "GuaranteeConst") .Input(FakeInput(input_type)) .Finalize(node_def())); diff --git a/tensorflow/core/kernels/hinge-loss.h b/tensorflow/core/kernels/hinge-loss.h index 8d2b1d33b683d9..51f11e049cfb26 100644 --- a/tensorflow/core/kernels/hinge-loss.h +++ b/tensorflow/core/kernels/hinge-loss.h @@ -106,7 +106,7 @@ class HingeLossUpdater : public DualLossUpdater { // Converts binary example labels from 0.0 or 1.0 to -1.0 or 1.0 respectively // as expected by hinge loss. - Status ConvertLabel(float* const example_label) const final { + absl::Status ConvertLabel(float* const example_label) const final { if (*example_label == 0.0) { *example_label = -1; return absl::OkStatus(); diff --git a/tensorflow/core/kernels/identity_n_op_test.cc b/tensorflow/core/kernels/identity_n_op_test.cc index b5bc5e62a95e1f..1cde97d5301188 100644 --- a/tensorflow/core/kernels/identity_n_op_test.cc +++ b/tensorflow/core/kernels/identity_n_op_test.cc @@ -29,7 +29,7 @@ namespace { class IdentityNOpTest : public OpsTestBase { protected: - Status Init(DataType input0_type, DataType input1_type) { + absl::Status Init(DataType input0_type, DataType input1_type) { TF_CHECK_OK(NodeDefBuilder("op", "IdentityN") .Input(FakeInput({input0_type, input1_type})) .Finalize(node_def())); diff --git a/tensorflow/core/kernels/identity_op_test.cc b/tensorflow/core/kernels/identity_op_test.cc index 8b23aedd07d6fe..408b28e9d805cb 100644 --- a/tensorflow/core/kernels/identity_op_test.cc +++ b/tensorflow/core/kernels/identity_op_test.cc @@ -29,7 +29,7 @@ namespace { class IdentityOpTest : public OpsTestBase { protected: - Status Init(DataType input_type) { + absl::Status Init(DataType input_type) { TF_CHECK_OK(NodeDefBuilder("op", "Identity") .Input(FakeInput(input_type)) .Finalize(node_def())); diff --git a/tensorflow/core/kernels/identity_reader_op.cc b/tensorflow/core/kernels/identity_reader_op.cc index 6dee08d1a60cc6..292cede60c537f 100644 --- a/tensorflow/core/kernels/identity_reader_op.cc +++ b/tensorflow/core/kernels/identity_reader_op.cc @@ -33,8 +33,8 @@ class IdentityReader : public ReaderBase { explicit IdentityReader(const string& node_name) : ReaderBase(strings::StrCat("IdentityReader '", node_name, "'")) {} - Status ReadLocked(tstring* key, tstring* value, bool* produced, - bool* at_end) override { + absl::Status ReadLocked(tstring* key, tstring* value, bool* produced, + bool* at_end) override { *key = current_work(); *value = current_work(); *produced = true; @@ -44,14 +44,14 @@ class IdentityReader : public ReaderBase { // Stores state in a ReaderBaseState proto, since IdentityReader has // no additional state beyond ReaderBase. - Status SerializeStateLocked(tstring* state) override { + absl::Status SerializeStateLocked(tstring* state) override { ReaderBaseState base_state; SaveBaseState(&base_state); SerializeToTString(base_state, state); return absl::OkStatus(); } - Status RestoreStateLocked(const tstring& state) override { + absl::Status RestoreStateLocked(const tstring& state) override { ReaderBaseState base_state; if (!ParseProtoUnlimited(&base_state, state)) { return errors::InvalidArgument("Could not parse state for ", name(), ": ", diff --git a/tensorflow/core/kernels/image/BUILD b/tensorflow/core/kernels/image/BUILD index d871c0868262e0..8733c058c7acf0 100644 --- a/tensorflow/core/kernels/image/BUILD +++ b/tensorflow/core/kernels/image/BUILD @@ -197,7 +197,6 @@ tf_kernel_library( "//tensorflow/core/util:determinism_for_kernels", ] + if_cuda_or_rocm([ "//tensorflow/core/platform:stream_executor", - "@local_xla//xla/stream_executor/gpu:scoped_activate_context", ]), ) diff --git a/tensorflow/core/kernels/image/crop_and_resize_op.cc b/tensorflow/core/kernels/image/crop_and_resize_op.cc index 6624aa2523da97..cd7501a4966625 100644 --- a/tensorflow/core/kernels/image/crop_and_resize_op.cc +++ b/tensorflow/core/kernels/image/crop_and_resize_op.cc @@ -41,12 +41,8 @@ limitations under the License. #include "tensorflow/core/platform/stream_executor.h" #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#if GOOGLE_CUDA -#include "xla/stream_executor/gpu/scoped_activate_context.h" -using stream_executor::gpu::ScopedActivateContext; -#elif TENSORFLOW_USE_ROCM +#if TENSORFLOW_USE_ROCM #include "tensorflow/core/platform/rocm.h" -using stream_executor::gpu::ScopedActivateContext; #endif namespace tensorflow { @@ -56,9 +52,9 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; using Callback = std::function; -static inline Status ParseAndCheckBoxSizes(const Tensor& boxes, - const Tensor& box_index, - int* num_boxes) { +static inline absl::Status ParseAndCheckBoxSizes(const Tensor& boxes, + const Tensor& box_index, + int* num_boxes) { if (boxes.NumElements() == 0 && box_index.NumElements() == 0) { *num_boxes = 0; return absl::OkStatus(); @@ -888,7 +884,8 @@ inline void RunIfBoxIndexIsValid( compute, done]() { { auto stream = context->op_device_context()->stream(); - ScopedActivateContext scoped_activation{stream->parent()}; + std::unique_ptr scoped_activation = + stream->parent()->Activate(); const bool isvalid = isvalid_host_tensor.scalar()(); isvalid_dev_ref.Unref(); OP_REQUIRES_ASYNC( @@ -896,7 +893,7 @@ inline void RunIfBoxIndexIsValid( errors::OutOfRange("box_index has values outside [0, batch_size)"), done); compute(); - } // Release ScopedActivateContext to prevent deadlock when done + } // Release ActivateContext to prevent deadlock when done // inlines another Op kernel, which may assume the original cuda Context. done(); diff --git a/tensorflow/core/kernels/image/crop_and_resize_op_test.cc b/tensorflow/core/kernels/image/crop_and_resize_op_test.cc index 70b9dc77ca6117..b82df065927a82 100644 --- a/tensorflow/core/kernels/image/crop_and_resize_op_test.cc +++ b/tensorflow/core/kernels/image/crop_and_resize_op_test.cc @@ -379,7 +379,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidInputShape) { AddInputFromArray(TensorShape({1, 4}), {0, 0, 1, 1}); AddInputFromArray(TensorShape({1}), {0}); AddInputFromArray(TensorShape({2}), {4, 4}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); ASSERT_FALSE(s.ok()); EXPECT_TRUE(absl::StrContains(s.ToString(), "input image must be 4-D")) << s; } @@ -390,7 +390,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) { AddInputFromArray(TensorShape({1, 4}), {0, 0, 1, 1}); AddInputFromArray(TensorShape({2}), {0, 0}); AddInputFromArray(TensorShape({2}), {4, 4}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); ASSERT_FALSE(s.ok()); EXPECT_TRUE( absl::StrContains(s.ToString(), "box_index has incompatible shape")) @@ -403,7 +403,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndex) { AddInputFromArray(TensorShape({1, 4}), {0, 0, 1, 1}); AddInputFromArray(TensorShape({1}), {1}); AddInputFromArray(TensorShape({2}), {3, 3}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); ASSERT_FALSE(s.ok()); EXPECT_TRUE(absl::StrContains(s.ToString(), "box_index has values outside [0, batch_size)")) diff --git a/tensorflow/core/kernels/image/decode_image_op.cc b/tensorflow/core/kernels/image/decode_image_op.cc index afb653191e3e8a..d9c4def13bc044 100644 --- a/tensorflow/core/kernels/image/decode_image_op.cc +++ b/tensorflow/core/kernels/image/decode_image_op.cc @@ -273,7 +273,7 @@ class DecodeImageV2Op : public OpKernel { input.data(), input.size(), flags, nullptr /* nwarn */, [&](int width, int height, int channels) -> uint8* { buffer_size = height * width * channels; - Status status; + absl::Status status; // By the existing API, we support decoding JPEG with `DecodeGif` // op. We need to make sure to return 4-D shapes when using // `DecodeGif`. @@ -465,7 +465,7 @@ class DecodeImageV2Op : public OpKernel { buffer_size = static_cast(num_frames) * height * width * channels; - Status status; + absl::Status status; // By the existing API, we support decoding GIF with `decode_jpeg` or // with `decode_png` if the GIF is a single-frame GIF (non-animated). // We need to make sure to return 3-D shapes when using in this case. diff --git a/tensorflow/core/kernels/image/encode_jpeg_op_test.cc b/tensorflow/core/kernels/image/encode_jpeg_op_test.cc index 5a97d734a00791..922a3aff5f72b0 100644 --- a/tensorflow/core/kernels/image/encode_jpeg_op_test.cc +++ b/tensorflow/core/kernels/image/encode_jpeg_op_test.cc @@ -38,7 +38,7 @@ TEST_F(EncodeJpegWithVariableQualityTest, FailsForInvalidQuality) { AddInputFromArray(TensorShape({2, 2, 3}), {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); AddInputFromArray(TensorShape({}), {200}); - Status status = RunOpKernel(); + absl::Status status = RunOpKernel(); EXPECT_TRUE(errors::IsInvalidArgument(status)); EXPECT_TRUE(absl::StartsWith(status.message(), "quality must be in [0,100]")); } diff --git a/tensorflow/core/kernels/image/non_max_suppression_op_test.cc b/tensorflow/core/kernels/image/non_max_suppression_op_test.cc index 580e6ceb0eb7bf..69331e44434d20 100644 --- a/tensorflow/core/kernels/image/non_max_suppression_op_test.cc +++ b/tensorflow/core/kernels/image/non_max_suppression_op_test.cc @@ -174,7 +174,7 @@ TEST_F(NonMaxSuppressionOpTest, TestInconsistentBoxAndScoreShapes) { 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101}); AddInputFromArray(TensorShape({5}), {.9f, .75f, .6f, .95f, .5f}); AddInputFromArray(TensorShape({}), {30}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); ASSERT_FALSE(s.ok()); EXPECT_TRUE(absl::StrContains(s.ToString(), "scores has incompatible shape")) @@ -186,7 +186,7 @@ TEST_F(NonMaxSuppressionOpTest, TestInvalidIOUThreshold) { AddInputFromArray(TensorShape({1, 4}), {0, 0, 1, 1}); AddInputFromArray(TensorShape({1}), {.9f}); AddInputFromArray(TensorShape({}), {3}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); ASSERT_FALSE(s.ok()); EXPECT_TRUE( @@ -334,7 +334,7 @@ TEST_F(NonMaxSuppressionV2OpTest, TestInconsistentBoxAndScoreShapes) { AddInputFromArray(TensorShape({5}), {.9f, .75f, .6f, .95f, .5f}); AddInputFromArray(TensorShape({}), {30}); AddInputFromArray(TensorShape({}), {.5f}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); ASSERT_FALSE(s.ok()); EXPECT_TRUE(absl::StrContains(s.ToString(), "scores has incompatible shape")) @@ -347,7 +347,7 @@ TEST_F(NonMaxSuppressionV2OpTest, TestInvalidIOUThreshold) { AddInputFromArray(TensorShape({1}), {.9f}); AddInputFromArray(TensorShape({}), {3}); AddInputFromArray(TensorShape({}), {1.2f}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); ASSERT_FALSE(s.ok()); EXPECT_TRUE( @@ -583,7 +583,7 @@ TYPED_TEST(NonMaxSuppressionV3OpTest, TestInconsistentBoxAndScoreShapes) { this->template AddInputFromList(TensorShape({}), {30}); this->template AddInputFromList(TensorShape({}), {0.5}); this->template AddInputFromList(TensorShape({}), {0}); - Status s = this->RunOpKernel(); + absl::Status s = this->RunOpKernel(); ASSERT_FALSE(s.ok()); EXPECT_TRUE(absl::StrContains(s.ToString(), "scores has incompatible shape")) @@ -599,7 +599,7 @@ TYPED_TEST(NonMaxSuppressionV3OpTest, TestInvalidIOUThreshold) { this->template AddInputFromList(TensorShape({}), {3}); this->template AddInputFromList(TensorShape({}), {1.2f}); this->template AddInputFromList(TensorShape({}), {0}); - Status s = this->RunOpKernel(); + absl::Status s = this->RunOpKernel(); ASSERT_FALSE(s.ok()); EXPECT_TRUE( @@ -950,7 +950,7 @@ TEST_F(NonMaxSuppressionWithOverlapsOpTest, TestInconsistentBoxAndScoreShapes) { AddInputFromArray(TensorShape({}), {30}); AddInputFromArray(TensorShape({}), {.5f}); AddInputFromArray(TensorShape({}), {0.0f}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); ASSERT_FALSE(s.ok()); EXPECT_TRUE(absl::StrContains(s.ToString(), "scores has incompatible shape")) @@ -964,7 +964,7 @@ TEST_F(NonMaxSuppressionWithOverlapsOpTest, TestInvalidOverlapsShape) { AddInputFromArray(TensorShape({}), {30}); AddInputFromArray(TensorShape({}), {0.f}); AddInputFromArray(TensorShape({}), {0.0f}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); ASSERT_FALSE(s.ok()); EXPECT_TRUE(absl::StrContains(s.ToString(), "overlaps must be square")) << s; diff --git a/tensorflow/core/kernels/image/resize_bicubic_op_test.cc b/tensorflow/core/kernels/image/resize_bicubic_op_test.cc index 209dbbdd60761d..c80a518c4e01d4 100644 --- a/tensorflow/core/kernels/image/resize_bicubic_op_test.cc +++ b/tensorflow/core/kernels/image/resize_bicubic_op_test.cc @@ -227,7 +227,7 @@ TEST_F(ResizeBicubicOpTest, TestBicubic2x2To0x0) { AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); AddInputFromArray(TensorShape({2}), {0, 0}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); EXPECT_TRUE( absl::StrContains(s.message(), "output dimensions must be positive")) diff --git a/tensorflow/core/kernels/image/resize_bilinear_op_test.cc b/tensorflow/core/kernels/image/resize_bilinear_op_test.cc index 61e02fbd686ced..d8b3d6c779f92b 100644 --- a/tensorflow/core/kernels/image/resize_bilinear_op_test.cc +++ b/tensorflow/core/kernels/image/resize_bilinear_op_test.cc @@ -493,7 +493,7 @@ TEST_P(ResizeBilinearOpTest, Test6_3c) { TestResize(1, 304, 303, 3, 299, 299); } TEST_P(ResizeBilinearOpTest, TestInvalidOutputSize) { AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); AddInputFromArray(TensorShape({2}), {0, 0}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); EXPECT_TRUE( absl::StrContains(s.message(), "output dimensions must be positive")) @@ -503,7 +503,7 @@ TEST_P(ResizeBilinearOpTest, TestInvalidOutputSize) { TEST_P(ResizeBilinearOpTest, TestInvalidInputShape) { AddInputFromArray(TensorShape({2, 2, 1}), {1, 2, 3, 4}); AddInputFromArray(TensorShape({2}), {4, 4}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); EXPECT_TRUE(absl::StrContains(s.message(), "input must be 4-dimensional")) << s; @@ -512,7 +512,7 @@ TEST_P(ResizeBilinearOpTest, TestInvalidInputShape) { TEST_P(ResizeBilinearOpTest, TestInvalidSizeDim) { AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); AddInputFromArray(TensorShape({2, 1}), {4, 4}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); EXPECT_TRUE(absl::StrContains(s.message(), "shape_t must be 1-dimensional")) << s; @@ -521,7 +521,7 @@ TEST_P(ResizeBilinearOpTest, TestInvalidSizeDim) { TEST_P(ResizeBilinearOpTest, TestInvalidSizeElements) { AddInputFromArray(TensorShape({1, 2, 2, 1}), {1, 2, 3, 4}); AddInputFromArray(TensorShape({3}), {4, 4, 1}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); EXPECT_TRUE(absl::StrContains(s.message(), "shape_t must have two elements")) << s; diff --git a/tensorflow/core/kernels/image/resize_op_benchmark_test.cc b/tensorflow/core/kernels/image/resize_op_benchmark_test.cc index 6736f04eda5b42..acae1f0b49f2f9 100644 --- a/tensorflow/core/kernels/image/resize_op_benchmark_test.cc +++ b/tensorflow/core/kernels/image/resize_op_benchmark_test.cc @@ -33,10 +33,10 @@ static Graph* Resize(const char* algorithm, int batches, int width, out_size_flat(1) = height * 2; Node* ret; - Status s = NodeBuilder(g->NewName("n"), algorithm) - .Input(test::graph::Constant(g, in)) - .Input(test::graph::Constant(g, out_size)) - .Finalize(g, &ret); + absl::Status s = NodeBuilder(g->NewName("n"), algorithm) + .Input(test::graph::Constant(g, in)) + .Input(test::graph::Constant(g, out_size)) + .Finalize(g, &ret); assert(s.ok()); return g; } diff --git a/tensorflow/core/kernels/image/scale_and_translate_op.cc b/tensorflow/core/kernels/image/scale_and_translate_op.cc index f1f8b357684e1e..3cacc10229495d 100644 --- a/tensorflow/core/kernels/image/scale_and_translate_op.cc +++ b/tensorflow/core/kernels/image/scale_and_translate_op.cc @@ -49,10 +49,11 @@ inline const T& Clamp(const T& low, const T& high, const T& value) { } template -Status ComputeSpansCore(OpKernelContext* context, const Kernel& kernel, - const int64_t output_size, const int64_t input_size, - const float scale, const float translate, - const bool antialias, Spans* spans) { +absl::Status ComputeSpansCore(OpKernelContext* context, const Kernel& kernel, + const int64_t output_size, + const int64_t input_size, const float scale, + const float translate, const bool antialias, + Spans* spans) { // When sampling, we need the inverse scale and translation, to map from an // output to an input pixel. const float inv_scale = 1.0 / scale; @@ -124,10 +125,10 @@ Status ComputeSpansCore(OpKernelContext* context, const Kernel& kernel, return absl::OkStatus(); } -Status ComputeGradSpansCore(OpKernelContext* context, const Spans& spans, - const int64_t forward_output_size, - const int64_t forward_input_size, - Spans* grad_spans) { +absl::Status ComputeGradSpansCore(OpKernelContext* context, const Spans& spans, + const int64_t forward_output_size, + const int64_t forward_input_size, + Spans* grad_spans) { struct GradComponent { int index; float weight; @@ -188,11 +189,11 @@ Status ComputeGradSpansCore(OpKernelContext* context, const Spans& spans, // input_size transformed by scale and translate to an output dimension of // length output_size. Note that there's no requirement that; // output_size = input_size * scale. -Status ComputeSpans(OpKernelContext* context, - const functor::SamplingKernelType kernel_type, - const int64_t output_size, const int64_t input_size, - const float scale, const float translate, - const bool antialias, Spans* spans) { +absl::Status ComputeSpans(OpKernelContext* context, + const functor::SamplingKernelType kernel_type, + const int64_t output_size, const int64_t input_size, + const float scale, const float translate, + const bool antialias, Spans* spans) { switch (kernel_type) { case functor::Lanczos1Kernel: { return ComputeSpansCore(context, CreateLanczos1Kernel(), output_size, @@ -236,12 +237,12 @@ Status ComputeSpans(OpKernelContext* context, // Computes the grad spans for the passed kernel. // forward_input_size and forward_output_size are the input and output size from // the forward operation. -Status ComputeGradSpans(OpKernelContext* context, - const functor::SamplingKernelType kernel_type, - const int64_t forward_output_size, - const int64_t forward_input_size, const float scale, - const float translate, const bool antialias, - Spans* grad_spans) { +absl::Status ComputeGradSpans(OpKernelContext* context, + const functor::SamplingKernelType kernel_type, + const int64_t forward_output_size, + const int64_t forward_input_size, + const float scale, const float translate, + const bool antialias, Spans* grad_spans) { Spans spans; TF_RETURN_IF_ERROR(ComputeSpans(context, kernel_type, forward_output_size, forward_input_size, scale, translate, diff --git a/tensorflow/core/kernels/image/scale_and_translate_op_test.cc b/tensorflow/core/kernels/image/scale_and_translate_op_test.cc index 4ae5accf4c83a1..55bd559ddd6776 100644 --- a/tensorflow/core/kernels/image/scale_and_translate_op_test.cc +++ b/tensorflow/core/kernels/image/scale_and_translate_op_test.cc @@ -248,7 +248,7 @@ class ScaleAndTranslateOpTest : public OpsTestBase { {output_image_height, output_image_width}); AddInputFromArray(TensorShape({2}), {scale[1], scale[0]}); AddInputFromArray(TensorShape({2}), {translate[1], translate[0]}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); const int batch_size = GetOutput(0)->dim_size(0); const int channels = GetOutput(0)->dim_size(3); Tensor expected(allocator(), DT_FLOAT, diff --git a/tensorflow/core/kernels/immutable_constant_op.cc b/tensorflow/core/kernels/immutable_constant_op.cc index 2b9f8f34f5d4c1..be0194413a3b81 100644 --- a/tensorflow/core/kernels/immutable_constant_op.cc +++ b/tensorflow/core/kernels/immutable_constant_op.cc @@ -26,7 +26,7 @@ class MemmappedTensorAllocator : public Allocator { public: MemmappedTensorAllocator() {} - Status InitializeFromRegion(const string& name, Env* env) { + absl::Status InitializeFromRegion(const string& name, Env* env) { const auto status = env->NewReadOnlyMemoryRegionFromFile(name, &memory_region_); if (!status.ok()) { @@ -60,7 +60,7 @@ class MemmappedTensorAllocator : public Allocator { delete this; } } - const Status& allocation_status() const { return allocation_status_; } + const absl::Status& allocation_status() const { return allocation_status_; } void set_delete_on_deallocate() { delete_on_deallocate_ = true; } @@ -73,7 +73,7 @@ class MemmappedTensorAllocator : public Allocator { private: std::unique_ptr memory_region_; // If there is an error during allocation we keep it in this status. - Status allocation_status_; + absl::Status allocation_status_; // When the allocator is owned by TensorBuffer it will be deleted on // de-allocation. diff --git a/tensorflow/core/kernels/immutable_constant_op_test.cc b/tensorflow/core/kernels/immutable_constant_op_test.cc index a91932a240b292..e1edbadf9f210f 100644 --- a/tensorflow/core/kernels/immutable_constant_op_test.cc +++ b/tensorflow/core/kernels/immutable_constant_op_test.cc @@ -64,7 +64,7 @@ class TestFileSystem : public NullFileSystem { // import non-transactional method from the base class using NullFileSystem::NewReadOnlyMemoryRegionFromFile; - Status NewReadOnlyMemoryRegionFromFile( + absl::Status NewReadOnlyMemoryRegionFromFile( const string& fname, TransactionToken* token, std::unique_ptr* result) override { float val = 0; @@ -146,8 +146,8 @@ TEST(ImmutableConstantOpTest, ExecutionError) { error::INTERNAL); } -Status CreateTempFileFloat(Env* env, float value, uint64 size, - string* filename) { +absl::Status CreateTempFileFloat(Env* env, float value, uint64 size, + string* filename) { const string dir = testing::TmpDir(); *filename = io::JoinPath(dir, strings::StrCat("file_", value)); std::unique_ptr file; @@ -191,8 +191,8 @@ TEST(ImmutableConstantOpTest, FromFile) { EXPECT_EQ(outputs.front().flat()(2), 2.0f * 3.0f); } -Status CreateTempFileBadString(Env* env, char value, uint64 size, - const string suffix, string* filename) { +absl::Status CreateTempFileBadString(Env* env, char value, uint64 size, + const string suffix, string* filename) { const string dir = testing::TmpDir(); *filename = io::JoinPath(dir, strings::StrCat("file_", suffix)); std::unique_ptr file; diff --git a/tensorflow/core/kernels/initializable_lookup_table.cc b/tensorflow/core/kernels/initializable_lookup_table.cc index 7c295b970ab538..33ae127ac94031 100644 --- a/tensorflow/core/kernels/initializable_lookup_table.cc +++ b/tensorflow/core/kernels/initializable_lookup_table.cc @@ -21,9 +21,9 @@ limitations under the License. namespace tensorflow { namespace lookup { -Status InitializableLookupTable::Find(OpKernelContext* ctx, const Tensor& keys, - Tensor* values, - const Tensor& default_value) { +absl::Status InitializableLookupTable::Find(OpKernelContext* ctx, + const Tensor& keys, Tensor* values, + const Tensor& default_value) { if (!is_initialized()) { return errors::FailedPrecondition("Table not initialized."); } @@ -33,9 +33,9 @@ Status InitializableLookupTable::Find(OpKernelContext* ctx, const Tensor& keys, return DoFind(keys, values, default_value); } -Status InitializableLookupTable::ImportValues(OpKernelContext* ctx, - const Tensor& keys, - const Tensor& values) { +absl::Status InitializableLookupTable::ImportValues(OpKernelContext* ctx, + const Tensor& keys, + const Tensor& values) { lookup::KeyValueTensorIterator iter(&keys, &values); auto serializer = std::make_unique( [keys, values](GraphDefBuilder* builder, Node* table, Node** out) { @@ -60,11 +60,11 @@ Status InitializableLookupTable::ImportValues(OpKernelContext* ctx, return Initialize(iter, std::move(serializer)); } -Status InitializableLookupTable::Initialize(InitTableIterator& iter) { +absl::Status InitializableLookupTable::Initialize(InitTableIterator& iter) { return Initialize(iter, /*serializer=*/nullptr); } -Status InitializableLookupTable::Initialize( +absl::Status InitializableLookupTable::Initialize( InitTableIterator& iter, std::unique_ptr serializer) { if (!iter.Valid()) { @@ -101,8 +101,8 @@ Status InitializableLookupTable::Initialize( return absl::OkStatus(); } -Status InitializableLookupTable::AreEntriesSame(const InitTableIterator& iter, - bool* result) { +absl::Status InitializableLookupTable::AreEntriesSame( + const InitTableIterator& iter, bool* result) { *result = static_cast(iter.total_size()) == size(); return absl::OkStatus(); } diff --git a/tensorflow/core/kernels/initializable_lookup_table.h b/tensorflow/core/kernels/initializable_lookup_table.h index 010febb73e8cca..c190fbd3e158fe 100644 --- a/tensorflow/core/kernels/initializable_lookup_table.h +++ b/tensorflow/core/kernels/initializable_lookup_table.h @@ -44,30 +44,30 @@ class InitializableLookupTable : public LookupInterface { // fails. // - In addition, other implementations may provide another non-OK status // specific to their failure modes. - Status Find(OpKernelContext* ctx, const Tensor& keys, Tensor* values, - const Tensor& default_value) final; + absl::Status Find(OpKernelContext* ctx, const Tensor& keys, Tensor* values, + const Tensor& default_value) final; // Returns errors::Unimplemented. - Status Insert(OpKernelContext* ctx, const Tensor& keys, - const Tensor& values) final { + absl::Status Insert(OpKernelContext* ctx, const Tensor& keys, + const Tensor& values) final { return errors::Unimplemented( "Insert not supported by InitializableLookupTable implementations"); } // Returns errors::Unimplemented. - Status Remove(OpKernelContext* ctx, const Tensor& keys) final { + absl::Status Remove(OpKernelContext* ctx, const Tensor& keys) final { return errors::Unimplemented( "Remove not supported by InitializableLookupTable implementations"); } - Status ExportValues(OpKernelContext* context) override { + absl::Status ExportValues(OpKernelContext* context) override { return errors::Unimplemented( "ExportValues not supported by InitializableLookupTable " "implementations"); } - Status ImportValues(OpKernelContext* ctx, const Tensor& keys, - const Tensor& values) final; + absl::Status ImportValues(OpKernelContext* ctx, const Tensor& keys, + const Tensor& values) final; TensorShape key_shape() const final { return TensorShape(); } @@ -91,14 +91,14 @@ class InitializableLookupTable : public LookupInterface { // fail_if_initialized is set to true. // - In addition, other implementations may provide another non-OK status // specific to their failure modes. - Status Initialize(InitTableIterator& iter); + absl::Status Initialize(InitTableIterator& iter); // Initializes the table from the given init table iterator. `serializer` may // specify how to serialize the table initializer, so that the table can be // serialized using its metadata (as opposed to serializing a handle to the // table). - Status Initialize(InitTableIterator& iter, - std::unique_ptr serializer); + absl::Status Initialize(InitTableIterator& iter, + std::unique_ptr serializer); // Basic iterator to initialize lookup tables. // It yields a sequence of pairs of `keys()` and `values()` Tensors, so that @@ -127,7 +127,7 @@ class InitializableLookupTable : public LookupInterface { virtual const Tensor& values() const = 0; // Returns an error if one has occurred, otherwise returns Status::OK. - virtual Status status() const = 0; + virtual absl::Status status() const = 0; // Returns the total number of elements that the iterator will produce. // It might return -1 in case of error. @@ -149,8 +149,8 @@ class InitializableLookupTable : public LookupInterface { public: // A function which builds a graph so that executing `*out` will initialize // `table`. - using SerializeFn = std::function; + using SerializeFn = std::function; // A function which performs any necessary cleanup for the serializer. using CleanupFn = std::function; @@ -166,7 +166,7 @@ class InitializableLookupTable : public LookupInterface { ~InitializerSerializer() { cleanup_(); } // Builds a graph so that executing `*out` will initialize `table`. - Status AsGraphDef(GraphDefBuilder* builder, Node* table, Node** out) { + absl::Status AsGraphDef(GraphDefBuilder* builder, Node* table, Node** out) { return serialize_(builder, table, out); } @@ -178,11 +178,11 @@ class InitializableLookupTable : public LookupInterface { protected: // Prepares and allocates the underlying data structure to store the given // number of expected elements. - virtual Status DoPrepare(size_t expected_num_elements) = 0; + virtual absl::Status DoPrepare(size_t expected_num_elements) = 0; // Same as DoPrepare() but derived implementations might choose to skip // calling get_expected_num_elements if size is not needed for DoPrepare. - virtual Status DoLazyPrepare( + virtual absl::Status DoLazyPrepare( std::function get_expected_num_elements) { int64_t expected_num_elements = get_expected_num_elements(); if (expected_num_elements < 0) { @@ -193,13 +193,14 @@ class InitializableLookupTable : public LookupInterface { // Populates the table in batches given keys and values as tensors into the // underlying data structure. - virtual Status DoInsert(const Tensor& keys, const Tensor& values) = 0; + virtual absl::Status DoInsert(const Tensor& keys, const Tensor& values) = 0; // Performs the batch find operation on the underlying data structure. - virtual Status DoFind(const Tensor& keys, Tensor* values, - const Tensor& default_value) = 0; + virtual absl::Status DoFind(const Tensor& keys, Tensor* values, + const Tensor& default_value) = 0; - virtual Status AreEntriesSame(const InitTableIterator& iter, bool* result); + virtual absl::Status AreEntriesSame(const InitTableIterator& iter, + bool* result); mutex mu_; @@ -248,7 +249,7 @@ class KeyValueTensorIterator const Tensor& values() const override { return *values_; } - Status status() const override { return status_; } + absl::Status status() const override { return status_; } int64_t total_size() const override { return keys_ == nullptr ? -1 : keys_->NumElements(); @@ -261,7 +262,7 @@ class KeyValueTensorIterator const Tensor* keys_; // Doesn't own it. const Tensor* values_; // Doesn't own it. bool valid_; // true if the iterator points to an existing range. - Status status_; + absl::Status status_; }; } // namespace lookup diff --git a/tensorflow/core/kernels/inplace_ops.cc b/tensorflow/core/kernels/inplace_ops.cc index 5c547ad74786c3..45db7d3b2d3f49 100644 --- a/tensorflow/core/kernels/inplace_ops.cc +++ b/tensorflow/core/kernels/inplace_ops.cc @@ -30,8 +30,8 @@ typedef Eigen::ThreadPoolDevice CPUDevice; namespace functor { template -Status DoParallelConcatUpdate(const Device& d, const Tensor& value, int32_t loc, - Tensor* output) { +absl::Status DoParallelConcatUpdate(const Device& d, const Tensor& value, + int32_t loc, Tensor* output) { auto Tvalue = value.shaped({1, value.NumElements()}); auto Toutput = output->flat_outer_dims(); auto nrows = Toutput.dimension(0); @@ -41,8 +41,8 @@ Status DoParallelConcatUpdate(const Device& d, const Tensor& value, int32_t loc, } template <> -Status DoParallelConcat(const CPUDevice& d, const Tensor& value, int32_t loc, - Tensor* output) { +absl::Status DoParallelConcat(const CPUDevice& d, const Tensor& value, + int32_t loc, Tensor* output) { CHECK_EQ(value.dtype(), output->dtype()); switch (value.dtype()) { #define CASE(type) \ @@ -240,8 +240,8 @@ class InplaceOpBase : public OpKernel { } protected: - virtual Status DoCompute(OpKernelContext* ctx, const Tensor& i, - const Tensor& v, Tensor* y) = 0; + virtual absl::Status DoCompute(OpKernelContext* ctx, const Tensor& i, + const Tensor& v, Tensor* y) = 0; }; } // end namespace @@ -285,8 +285,8 @@ void DoInplaceStringUpdateOp(const CPUDevice& d, const Tensor& i, } template <> -Status DoInplace(const CPUDevice& device, InplaceOpType op, const Tensor& i, - const Tensor& v, Tensor* y) { +absl::Status DoInplace(const CPUDevice& device, InplaceOpType op, + const Tensor& i, const Tensor& v, Tensor* y) { CHECK_EQ(v.dtype(), y->dtype()); if (op == I_UPDATE) { if (v.dtype() == DT_STRING) { @@ -320,8 +320,8 @@ class InplaceOp : public InplaceOpBase { explicit InplaceOp(OpKernelConstruction* ctx) : InplaceOpBase(ctx) {} protected: - Status DoCompute(OpKernelContext* ctx, const Tensor& i, const Tensor& v, - Tensor* y) override { + absl::Status DoCompute(OpKernelContext* ctx, const Tensor& i, const Tensor& v, + Tensor* y) override { const auto& d = ctx->eigen_device(); return ::tensorflow::functor::DoInplace(d, op, i, v, y); } @@ -339,8 +339,8 @@ class CopyOpBase : public OpKernel { } protected: - virtual Status DoCompute(OpKernelContext* ctx, const Tensor& x, - Tensor* y) = 0; + virtual absl::Status DoCompute(OpKernelContext* ctx, const Tensor& x, + Tensor* y) = 0; }; template @@ -349,7 +349,8 @@ class CopyOp : public CopyOpBase { explicit CopyOp(OpKernelConstruction* ctx) : CopyOpBase(ctx) {} protected: - Status DoCompute(OpKernelContext* ctx, const Tensor& x, Tensor* y) override { + absl::Status DoCompute(OpKernelContext* ctx, const Tensor& x, + Tensor* y) override { const auto& d = ctx->eigen_device(); return ::tensorflow::functor::DoCopy(d, x, y); } @@ -362,7 +363,7 @@ namespace functor { typedef Eigen::ThreadPoolDevice CPUDevice; template <> -Status DoCopy(const CPUDevice& device, const Tensor& x, Tensor* y) { +absl::Status DoCopy(const CPUDevice& device, const Tensor& x, Tensor* y) { CHECK_EQ(x.dtype(), y->dtype()); switch (x.dtype()) { #define CASE(type) \ diff --git a/tensorflow/core/kernels/inplace_ops_functor.h b/tensorflow/core/kernels/inplace_ops_functor.h index f591fc4f032e94..e1707824158114 100644 --- a/tensorflow/core/kernels/inplace_ops_functor.h +++ b/tensorflow/core/kernels/inplace_ops_functor.h @@ -23,8 +23,8 @@ namespace tensorflow { namespace functor { template -Status DoParallelConcat(const Device& device, const Tensor& value, int32_t loc, - Tensor* output); +absl::Status DoParallelConcat(const Device& device, const Tensor& value, + int32_t loc, Tensor* output); // Inplace update/add/sub values in 'y'. It computes // y[i, :] = v if op is I_UPDATE @@ -37,11 +37,11 @@ enum InplaceOpType { I_SUB, // x -= y }; template -Status DoInplace(const Device& device, InplaceOpType op, const Tensor& i, - const Tensor& v, Tensor* y); +absl::Status DoInplace(const Device& device, InplaceOpType op, const Tensor& i, + const Tensor& v, Tensor* y); // Copies x into y. template -Status DoCopy(const Device& device, const Tensor& x, Tensor* y); +absl::Status DoCopy(const Device& device, const Tensor& x, Tensor* y); } // end namespace functor } // end namespace tensorflow diff --git a/tensorflow/core/kernels/isotonic_regression_op_test.cc b/tensorflow/core/kernels/isotonic_regression_op_test.cc index 4d0ba5af3218d5..0baf0fe69bdc53 100644 --- a/tensorflow/core/kernels/isotonic_regression_op_test.cc +++ b/tensorflow/core/kernels/isotonic_regression_op_test.cc @@ -117,7 +117,7 @@ static void BM_IncreasingSequence(benchmark::State& state) { helper.MakeOp(DT_FLOAT_REF); helper.AddIncreasingInput(batch_size, input_size); state.ResumeTiming(); - Status stat = helper.RunOpKernel(); + absl::Status stat = helper.RunOpKernel(); } state.SetItemsProcessed( static_cast(batch_size * input_size * state.iterations())); diff --git a/tensorflow/core/kernels/linalg/einsum_op_impl.h b/tensorflow/core/kernels/linalg/einsum_op_impl.h index 2dfd62fc943dbd..3987478fe115f2 100644 --- a/tensorflow/core/kernels/linalg/einsum_op_impl.h +++ b/tensorflow/core/kernels/linalg/einsum_op_impl.h @@ -52,12 +52,12 @@ namespace tensorflow { using CPUDevice = Eigen::ThreadPoolDevice; using GPUDevice = Eigen::GpuDevice; -using ShapeVec = gtl::InlinedVector; -using Labels = gtl::InlinedVector; -using OperandLabels = gtl::InlinedVector; -using LabelCounts = gtl::InlinedVector; -using OperandLabelCounts = gtl::InlinedVector; -using LabelToDimSizes = gtl::InlinedVector; +using ShapeVec = absl::InlinedVector; +using Labels = absl::InlinedVector; +using OperandLabels = absl::InlinedVector; +using LabelCounts = absl::InlinedVector; +using OperandLabelCounts = absl::InlinedVector; +using LabelToDimSizes = absl::InlinedVector; struct EinsumHelper { // Insert new (unnamed) broadcasting labels at the location of ellipsis. @@ -77,9 +77,9 @@ struct EinsumHelper { // Record and validate the label to dimension mapping. Must be a named // (non-broadcasting) label as broadcasting labels don't have a fixed // dimension. - static Status RecordLabelToDimension(const int label, const int axis, - const Tensor& input, - LabelToDimSizes* label_to_dim_sizes) { + static absl::Status RecordLabelToDimension( + const int label, const int axis, const Tensor& input, + LabelToDimSizes* label_to_dim_sizes) { const int64_t input_dim = input.dim_size(axis); // We know that label_to_dim_sizes has the size to accommodate named labels. if (label_to_dim_sizes->at(label) != 0 && @@ -95,9 +95,9 @@ struct EinsumHelper { // Validate input dimensions and populate unnamed labels and their label // counts. - static Status ProcessDimensions( + static absl::Status ProcessDimensions( const OpInputList& inputs, - const gtl::InlinedVector& input_has_ellipsis, + const absl::InlinedVector& input_has_ellipsis, const bool output_has_ellipsis, OperandLabels* input_labels, Labels* output_labels, std::vector* label_types, OperandLabelCounts* input_label_counts, LabelCounts* output_label_counts, @@ -192,8 +192,8 @@ struct EinsumHelper { } // Returns a reshaped input Tensor. The underlying buffer is not copied. - static Status CopyFrom(const Tensor& input, const TensorShape& shape, - Tensor* output) { + static absl::Status CopyFrom(const Tensor& input, const TensorShape& shape, + Tensor* output) { if (output->CopyFrom(input, shape)) return absl::OkStatus(); return errors::Internal( "Encountered error while reshaping a Tensor of shape ", @@ -214,9 +214,10 @@ struct EinsumHelper { // Transpose the input given a permutation. Returns a reference to the input // if transposing is not necessary. template - static Status TransposeOperand(OpKernelContext* ctx, const Tensor& input, - const std::vector& permutation, - Tensor* output) { + static absl::Status TransposeOperand(OpKernelContext* ctx, + const Tensor& input, + const std::vector& permutation, + Tensor* output) { if (!ShouldTranspose(input.shape(), permutation)) { return CopyFrom(input, input.shape(), output); } @@ -240,10 +241,11 @@ struct EinsumHelper { // If there are repeated labels in either the input or output, then this // strides the input (e.g. iii->i) or inflates it (e.g. i->iii), respectively. template - static Status StrideOrInflate(OpKernelContext* ctx, const Tensor& input, - const Labels& labels, - const LabelCounts& label_counts, - const bool should_inflate, Tensor* output) { + static absl::Status StrideOrInflate(OpKernelContext* ctx, const Tensor& input, + const Labels& labels, + const LabelCounts& label_counts, + const bool should_inflate, + Tensor* output) { // Return early if there are no repeated indices. if (absl::c_all_of(label_counts, [](int c) { return c <= 1; })) { return CopyFrom(input, input.shape(), output); @@ -321,7 +323,7 @@ struct EinsumHelper { const std::vector& label_types) { // Check that ordering is according to dimension type, with the role of // free and contract dimensions swapped. - gtl::InlinedVector remap = {0, 1, 3, 2, 4}; + absl::InlinedVector remap = {0, 1, 3, 2, 4}; for (int i = 0; i + 1 < labels.size(); ++i) { const int dimtype_a = remap[label_types[labels[i]]]; const int dimtype_b = remap[label_types[labels[i + 1]]]; @@ -334,7 +336,7 @@ struct EinsumHelper { } template - static Status ReduceOperand( + static absl::Status ReduceOperand( OpKernelContext* ctx, const Tensor& input, const std::vector& label_types, const LabelCounts& label_counts, Labels* labels, Labels* free_labels, @@ -372,7 +374,7 @@ struct EinsumHelper { // Reshape denotes the rank-5 shape [broadcast, batch, free, contract, // reduce] where we've compacted the dimensions of each EinsumDimensionType. - gtl::InlinedVector reshape(5, 1); + absl::InlinedVector reshape(5, 1); // The output shape is [batch shape] + [free size, contract size] // That is, the batch shape is preserved (for broadcasting while // contracting) while the free dims and contract dims are compressed to one @@ -417,8 +419,8 @@ struct EinsumHelper { } // Reshapes a Tensor of shape [b0,b1...bk,N,M] to [prod(b0,b1...bk),N,M]. - static Status ReshapeToRank3(const Tensor& input, int batch_size, - Tensor* output) { + static absl::Status ReshapeToRank3(const Tensor& input, int batch_size, + Tensor* output) { const int rank = input.dims(); TensorShape output_shape = {batch_size, input.dim_size(rank - 2), input.dim_size(rank - 1)}; @@ -433,10 +435,9 @@ struct EinsumHelper { // functor would be very inefficient. The functor should detect if this is the // case and perform componentwise multiplication functor instead. template - static Status ContractOperands(OpKernelContext* ctx, - absl::Span inputs, - absl::Span swap_free_and_contract, - Tensor* output) { + static absl::Status ContractOperands( + OpKernelContext* ctx, absl::Span inputs, + absl::Span swap_free_and_contract, Tensor* output) { if (inputs.size() == 1) return CopyFrom(inputs[0], inputs[0].shape(), output); MatMulBCast bcast(inputs[0].shape().dim_sizes(), @@ -513,8 +514,8 @@ class EinsumOp : public OpKernel { // dimensions, respectively. const int num_inputs = inputs.size(); OperandLabels free_labels(num_inputs); - gtl::InlinedVector inputs_reduced(num_inputs); - gtl::InlinedVector swap_free_and_contract(num_inputs); + absl::InlinedVector inputs_reduced(num_inputs); + absl::InlinedVector swap_free_and_contract(num_inputs); for (int i = 0; i < num_inputs; ++i) { OP_REQUIRES_OK(ctx, EinsumHelper::ReduceOperand( @@ -603,7 +604,7 @@ class EinsumOp : public OpKernel { Tensor output; OP_REQUIRES_OK(ctx, EinsumHelper::TransposeOperand( ctx, output_inflated, output_permutation, &output)); - ctx->set_output(0, output); + ctx->set_output(0, std::move(output)); } string TraceString(const OpKernelContext& ctx, bool verbose) const override { @@ -627,7 +628,7 @@ class EinsumOp : public OpKernel { std::vector label_types_; OperandLabelCounts input_label_counts_; LabelCounts output_label_counts_; - gtl::InlinedVector input_has_ellipsis_; + absl::InlinedVector input_has_ellipsis_; bool output_has_ellipsis_ = false; }; diff --git a/tensorflow/core/kernels/linalg/linalg_ops_common.h b/tensorflow/core/kernels/linalg/linalg_ops_common.h index d4d66bd4c8f809..d774ad9e6bced5 100644 --- a/tensorflow/core/kernels/linalg/linalg_ops_common.h +++ b/tensorflow/core/kernels/linalg/linalg_ops_common.h @@ -43,7 +43,7 @@ class LinearAlgebraOp : public OpKernel { void Compute(OpKernelContext* context) override; protected: - using TensorShapes = gtl::InlinedVector; + using TensorShapes = absl::InlinedVector; // Returns the number of leading inputs that are to be treated as matrix // inputs. By default this is all the inputs. Derived classes can override // this to tell the base class to ignore one or more trailing inputs. @@ -152,8 +152,8 @@ class LinearAlgebraOp : public OpKernel { OutputMatrixMaps* outputs) = 0; private: - using TensorInputs = gtl::InlinedVector; - using TensorOutputs = gtl::InlinedVector; + using TensorInputs = absl::InlinedVector; + using TensorOutputs = absl::InlinedVector; // This function maps 2-d slices (matrices) of the input and output tensors // using Eigen::Map and calls ComputeMatrix implemented in terms of the // Eigen::MatrixBase API by the derived class. diff --git a/tensorflow/core/kernels/linalg/lu_op.cc b/tensorflow/core/kernels/linalg/lu_op.cc index 770c5d8fe6c67c..e1525bf5937eb6 100644 --- a/tensorflow/core/kernels/linalg/lu_op.cc +++ b/tensorflow/core/kernels/linalg/lu_op.cc @@ -32,8 +32,8 @@ class LuOp : public OpKernel { explicit LuOp(OpKernelConstruction* context) : OpKernel(context) {} protected: - using TensorShapes = gtl::InlinedVector; - using TensorOutputs = gtl::InlinedVector; + using TensorShapes = absl::InlinedVector; + using TensorOutputs = absl::InlinedVector; using Matrix = Eigen::Matrix; diff --git a/tensorflow/core/kernels/list_kernels.cc b/tensorflow/core/kernels/list_kernels.cc index fa4b082c9779c2..51c0d4b6654034 100644 --- a/tensorflow/core/kernels/list_kernels.cc +++ b/tensorflow/core/kernels/list_kernels.cc @@ -46,7 +46,7 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; -Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out) { +absl::Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out) { if (t.shape() == TensorShape({})) { if ((t.dtype() == DT_INT32 && t.scalar()() == -1) || (t.dtype() == DT_INT64 && t.scalar()() == -1)) { @@ -72,9 +72,9 @@ Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out) { DataTypeString(t.dtype())); } -Status GetElementShapeFromInput(OpKernelContext* c, - const TensorList& tensor_list, int index, - PartialTensorShape* element_shape) { +absl::Status GetElementShapeFromInput(OpKernelContext* c, + const TensorList& tensor_list, int index, + PartialTensorShape* element_shape) { TF_RETURN_IF_ERROR(TensorShapeFromTensor(c->input(index), element_shape)); // Check that `element_shape` and `tensor_list.element_shape` are // compatible and store the merged shape in `element_shape`. @@ -83,7 +83,8 @@ Status GetElementShapeFromInput(OpKernelContext* c, return absl::OkStatus(); } -Status GetInputList(OpKernelContext* c, int index, const TensorList** list) { +absl::Status GetInputList(OpKernelContext* c, int index, + const TensorList** list) { if (!TensorShapeUtils::IsScalar(c->input(index).shape())) { return errors::InvalidArgument("Input list must be a scalar saw: ", c->input(index).shape().DebugString()); @@ -98,10 +99,11 @@ Status GetInputList(OpKernelContext* c, int index, const TensorList** list) { return absl::OkStatus(); } -Status ForwardInputOrCreateNewList(OpKernelContext* c, int32_t input_index, - int32_t output_index, - const TensorList& input_list, - TensorList** output_list) { +absl::Status ForwardInputOrCreateNewList(OpKernelContext* c, + int32_t input_index, + int32_t output_index, + const TensorList& input_list, + TensorList** output_list) { // Attempt to forward the input tensor to the output if possible. std::unique_ptr maybe_output = c->forward_input( input_index, output_index, DT_VARIANT, TensorShape{}, @@ -697,7 +699,7 @@ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, TensorList, TensorListZerosLike); -static Status TensorListDeviceCopy( +static absl::Status TensorListDeviceCopy( const TensorList& from, TensorList* to, const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) { to->element_shape = from.element_shape; diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h index fec3ebab2aa27f..9837b08716afae 100644 --- a/tensorflow/core/kernels/list_kernels.h +++ b/tensorflow/core/kernels/list_kernels.h @@ -56,18 +56,20 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; -Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out); +absl::Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out); -Status GetElementShapeFromInput(OpKernelContext* c, - const TensorList& tensor_list, int index, - PartialTensorShape* element_shape); +absl::Status GetElementShapeFromInput(OpKernelContext* c, + const TensorList& tensor_list, int index, + PartialTensorShape* element_shape); -Status GetInputList(OpKernelContext* c, int index, const TensorList** list); +absl::Status GetInputList(OpKernelContext* c, int index, + const TensorList** list); -Status ForwardInputOrCreateNewList(OpKernelContext* c, int32_t input_index, - int32_t output_index, - const TensorList& input_list, - TensorList** output_list); +absl::Status ForwardInputOrCreateNewList(OpKernelContext* c, + int32_t input_index, + int32_t output_index, + const TensorList& input_list, + TensorList** output_list); // TODO(penporn): Move this to a proper place. inline bool IsPluggableDevice(OpKernelContext* c) { @@ -825,8 +827,8 @@ class TensorListFromTensor : public OpKernel { // Scatters values in `value` into `list`. Assumes that `indices` are valid. template -Status Scatter(OpKernelContext* c, const Tensor& value, const Tensor& indices, - TensorList* list) { +absl::Status Scatter(OpKernelContext* c, const Tensor& value, + const Tensor& indices, TensorList* list) { const auto copy_tensor = IsPluggableDevice(c) ? &CopyTensorPluggableDevice : &CopyTensor; for (int index = 0; index < indices.NumElements(); ++index) { @@ -978,14 +980,14 @@ class TensorListScatter : public OpKernel { }; template -Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a, - const TensorList& b, TensorList* out) { +absl::Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a, + const TensorList& b, TensorList* out) { return TensorListBinaryAdd(c, a, b, out, BinaryAddTensors); } template -Status TensorListZerosLike(OpKernelContext* c, const TensorList& x, - TensorList* y) { +absl::Status TensorListZerosLike(OpKernelContext* c, const TensorList& x, + TensorList* y) { return TensorListZerosLike(c, x, y, ZerosLikeTensor); } diff --git a/tensorflow/core/kernels/load_and_remap_matrix_op.cc b/tensorflow/core/kernels/load_and_remap_matrix_op.cc index fb2f9d40495c94..7f57e939555d6f 100644 --- a/tensorflow/core/kernels/load_and_remap_matrix_op.cc +++ b/tensorflow/core/kernels/load_and_remap_matrix_op.cc @@ -34,7 +34,7 @@ namespace tensorflow { namespace { // Returning a Status instead of using OP_REQUIRES directly since that doesn't // seem to work outside the main OpKernel functions. -Status RemapVectorToMap( +absl::Status RemapVectorToMap( const TTypes::Vec& remapping, std::vector* id_present, std::unordered_map* old_id_to_new_id) { id_present->clear(); diff --git a/tensorflow/core/kernels/logging_ops.cc b/tensorflow/core/kernels/logging_ops.cc index cdc764ea1cccd4..a9640f553da2b8 100644 --- a/tensorflow/core/kernels/logging_ops.cc +++ b/tensorflow/core/kernels/logging_ops.cc @@ -37,15 +37,15 @@ static mutex* file_mutex = new mutex(); // Appends the given data to the specified file. It will create the file if it // doesn't already exist. -Status AppendStringToFile(const std::string& fname, StringPiece data, - Env* env) { +absl::Status AppendStringToFile(const std::string& fname, StringPiece data, + Env* env) { // TODO(ckluk): If opening and closing on every log causes performance issues, // we can reimplement using reference counters. mutex_lock l(*file_mutex); std::unique_ptr file; TF_RETURN_IF_ERROR(env->NewAppendableFile(fname, &file)); - Status a = file->Append(data); - Status c = file->Close(); + absl::Status a = file->Append(data); + absl::Status c = file->Close(); return a.ok() ? c : a; } diff --git a/tensorflow/core/kernels/logging_ops_test.cc b/tensorflow/core/kernels/logging_ops_test.cc index fdb85fda2d70a0..885bd748aeea8d 100644 --- a/tensorflow/core/kernels/logging_ops_test.cc +++ b/tensorflow/core/kernels/logging_ops_test.cc @@ -34,7 +34,7 @@ namespace { class PrintingV2GraphTest : public OpsTestBase { protected: - Status Init(const string& output_stream = "log(warning)") { + absl::Status Init(const string& output_stream = "log(warning)") { TF_CHECK_OK(NodeDefBuilder("op", "PrintV2") .Input(FakeInput(DT_STRING)) .Attr("output_stream", output_stream) @@ -61,8 +61,8 @@ TEST_F(PrintingV2GraphTest, InvalidInputRank) { class PrintingGraphTest : public OpsTestBase { protected: - Status Init(DataType input_type1, DataType input_type2, string msg = "", - int first_n = -1, int summarize = 3) { + absl::Status Init(DataType input_type1, DataType input_type2, string msg = "", + int first_n = -1, int summarize = 3) { TF_CHECK_OK(NodeDefBuilder("op", "Print") .Input(FakeInput(input_type1)) .Input(FakeInput(2, input_type2)) @@ -132,7 +132,7 @@ TEST_F(PrintingGraphTest, FirstNSuccess) { class TimestampTest : public OpsTestBase { protected: - Status Init() { + absl::Status Init() { TF_CHECK_OK(NodeDefBuilder("op", "Timestamp").Finalize(node_def())); return InitOp(); } diff --git a/tensorflow/core/kernels/logistic-loss.h b/tensorflow/core/kernels/logistic-loss.h index d06b1228f62a59..d848a1f33b30bc 100644 --- a/tensorflow/core/kernels/logistic-loss.h +++ b/tensorflow/core/kernels/logistic-loss.h @@ -96,7 +96,7 @@ class LogisticLossUpdater : public DualLossUpdater { // Converts binary example labels from 0.0 or 1.0 to -1.0 or 1.0 respectively // as expected by logistic regression. - Status ConvertLabel(float* const example_label) const final { + absl::Status ConvertLabel(float* const example_label) const final { if (*example_label == 0.0) { *example_label = -1; return absl::OkStatus(); diff --git a/tensorflow/core/kernels/lookup_table_init_op.h b/tensorflow/core/kernels/lookup_table_init_op.h index 6f72775eb442f9..e94db921bfd237 100644 --- a/tensorflow/core/kernels/lookup_table_init_op.h +++ b/tensorflow/core/kernels/lookup_table_init_op.h @@ -22,10 +22,11 @@ namespace tensorflow { namespace lookup { // Helper function to initialize an InitializableLookupTable from a text file. -Status InitializeTableFromTextFile(const string& filename, int64_t vocab_size, - char delimiter, int32_t key_index, - int32_t value_index, Env* env, - InitializableLookupTable* table); +absl::Status InitializeTableFromTextFile(const string& filename, + int64_t vocab_size, char delimiter, + int32_t key_index, int32_t value_index, + Env* env, + InitializableLookupTable* table); } // namespace lookup } // namespace tensorflow diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc index 78a2716f0b95d2..49a28dc324b9fb 100644 --- a/tensorflow/core/kernels/lookup_table_op.cc +++ b/tensorflow/core/kernels/lookup_table_op.cc @@ -60,8 +60,8 @@ class MutableHashTableOfScalars final : public LookupInterface { return table_.size(); } - Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value, - const Tensor& default_value) override { + absl::Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value, + const Tensor& default_value) override { const auto key_values = key.flat(); auto value_values = value->flat(); const auto default_flat = default_value.flat(); @@ -86,7 +86,7 @@ class MutableHashTableOfScalars final : public LookupInterface { return absl::OkStatus(); } - Status DoInsert(bool clear, const Tensor& keys, const Tensor& values) { + absl::Status DoInsert(bool clear, const Tensor& keys, const Tensor& values) { const auto key_values = keys.flat(); const auto value_values = values.flat(); @@ -101,12 +101,12 @@ class MutableHashTableOfScalars final : public LookupInterface { return absl::OkStatus(); } - Status Insert(OpKernelContext* ctx, const Tensor& keys, - const Tensor& values) override { + absl::Status Insert(OpKernelContext* ctx, const Tensor& keys, + const Tensor& values) override { return DoInsert(false, keys, values); } - Status Remove(OpKernelContext* ctx, const Tensor& keys) override { + absl::Status Remove(OpKernelContext* ctx, const Tensor& keys) override { const auto key_values = keys.flat(); mutex_lock l(mu_); @@ -116,12 +116,12 @@ class MutableHashTableOfScalars final : public LookupInterface { return absl::OkStatus(); } - Status ImportValues(OpKernelContext* ctx, const Tensor& keys, - const Tensor& values) override { + absl::Status ImportValues(OpKernelContext* ctx, const Tensor& keys, + const Tensor& values) override { return DoInsert(true, keys, values); } - Status ExportValues(OpKernelContext* ctx) override { + absl::Status ExportValues(OpKernelContext* ctx) override { tf_shared_lock l(mu_); int64_t size = table_.size(); @@ -157,7 +157,7 @@ class MutableHashTableOfScalars final : public LookupInterface { return sizeof(MutableHashTableOfScalars) + ret; } - Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override { + absl::Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override { tf_shared_lock l(mu_); int64_t size = table_.size(); Tensor keys(key_dtype(), TensorShape({size})); @@ -231,8 +231,8 @@ class MutableHashTableOfTensors final : public LookupInterface { return table_.size(); } - Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value, - const Tensor& default_value) override { + absl::Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value, + const Tensor& default_value) override { const auto default_flat = default_value.flat_inner_dims(); const auto key_values = key.flat(); auto value_values = value->flat_inner_dims(); @@ -267,7 +267,7 @@ class MutableHashTableOfTensors final : public LookupInterface { return absl::OkStatus(); } - Status DoInsert(bool clear, const Tensor& keys, const Tensor& values) { + absl::Status DoInsert(bool clear, const Tensor& keys, const Tensor& values) { const auto key_values = keys.flat(); const auto value_values = values.flat_inner_dims(); int64_t value_dim = value_shape_.dim_size(0); @@ -288,12 +288,12 @@ class MutableHashTableOfTensors final : public LookupInterface { return absl::OkStatus(); } - Status Insert(OpKernelContext* ctx, const Tensor& keys, - const Tensor& values) override { + absl::Status Insert(OpKernelContext* ctx, const Tensor& keys, + const Tensor& values) override { return DoInsert(false, keys, values); } - Status Remove(OpKernelContext* ctx, const Tensor& keys) override { + absl::Status Remove(OpKernelContext* ctx, const Tensor& keys) override { const auto key_values = keys.flat(); mutex_lock l(mu_); @@ -303,12 +303,12 @@ class MutableHashTableOfTensors final : public LookupInterface { return absl::OkStatus(); } - Status ImportValues(OpKernelContext* ctx, const Tensor& keys, - const Tensor& values) override { + absl::Status ImportValues(OpKernelContext* ctx, const Tensor& keys, + const Tensor& values) override { return DoInsert(true, keys, values); } - Status ExportValues(OpKernelContext* ctx) override { + absl::Status ExportValues(OpKernelContext* ctx) override { tf_shared_lock l(mu_); int64_t size = table_.size(); int64_t value_dim = value_shape_.dim_size(0); @@ -345,7 +345,7 @@ class MutableHashTableOfTensors final : public LookupInterface { return sizeof(MutableHashTableOfTensors) + ret; } - Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override { + absl::Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override { tf_shared_lock l(mu_); int64_t size = table_.size(); Tensor keys(key_dtype(), TensorShape({size})); @@ -496,8 +496,9 @@ class MutableDenseHashTable final : public LookupInterface { return num_entries_; } - Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value, - const Tensor& default_value) override TF_LOCKS_EXCLUDED(mu_) { + absl::Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value, + const Tensor& default_value) override + TF_LOCKS_EXCLUDED(mu_) { const int64_t num_elements = (key.dims() == 0) ? 1 : key.dim_size(0); const int64_t key_size = key_shape_.num_elements(); const int64_t value_size = value_shape_.num_elements(); @@ -563,8 +564,8 @@ class MutableDenseHashTable final : public LookupInterface { return absl::OkStatus(); } - Status Insert(OpKernelContext* ctx, const Tensor& key, - const Tensor& value) override TF_LOCKS_EXCLUDED(mu_) { + absl::Status Insert(OpKernelContext* ctx, const Tensor& key, + const Tensor& value) override TF_LOCKS_EXCLUDED(mu_) { const int64_t batch_size = (key.dims() == 0) ? 1 : key.dim_size(0); if (key.NumElements() != batch_size * key_shape_.num_elements()) { TensorShape expected_shape({batch_size}); @@ -589,7 +590,7 @@ class MutableDenseHashTable final : public LookupInterface { return DoInsert(ctx, key, value, false); } - Status Remove(OpKernelContext* ctx, const Tensor& key) override + absl::Status Remove(OpKernelContext* ctx, const Tensor& key) override TF_LOCKS_EXCLUDED(mu_) { if (key.NumElements() != key.dim_size(0) * key_shape_.num_elements()) { TensorShape expected_shape({key.dim_size(0)}); @@ -602,8 +603,9 @@ class MutableDenseHashTable final : public LookupInterface { return DoRemove(ctx, key); } - Status ImportValues(OpKernelContext* ctx, const Tensor& keys, - const Tensor& values) override TF_LOCKS_EXCLUDED(mu_) { + absl::Status ImportValues(OpKernelContext* ctx, const Tensor& keys, + const Tensor& values) override + TF_LOCKS_EXCLUDED(mu_) { mutex_lock l(mu_); num_buckets_ = keys.dim_size(0); key_buckets_ = keys; @@ -626,15 +628,16 @@ class MutableDenseHashTable final : public LookupInterface { return absl::OkStatus(); } - Status ExportValues(OpKernelContext* ctx) override TF_LOCKS_EXCLUDED(mu_) { + absl::Status ExportValues(OpKernelContext* ctx) override + TF_LOCKS_EXCLUDED(mu_) { tf_shared_lock l(mu_); TF_RETURN_IF_ERROR(ctx->set_output("keys", key_buckets_)); TF_RETURN_IF_ERROR(ctx->set_output("values", value_buckets_)); return absl::OkStatus(); } - Status CheckKeyAndValueTensorsForImport(const Tensor& keys, - const Tensor& values) override { + absl::Status CheckKeyAndValueTensorsForImport(const Tensor& keys, + const Tensor& values) override { TF_RETURN_IF_ERROR(CheckKeyAndValueTypes(keys, values)); TF_RETURN_IF_ERROR(CheckKeyShape(keys.shape())); @@ -673,8 +676,8 @@ class MutableDenseHashTable final : public LookupInterface { } private: - Status DoInsert(OpKernelContext* ctx, const Tensor& key, const Tensor& value, - bool ignore_empty_and_deleted_key) + absl::Status DoInsert(OpKernelContext* ctx, const Tensor& key, + const Tensor& value, bool ignore_empty_and_deleted_key) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { const int64_t num_elements = (key.dims() == 0) ? 1 : key.dim_size(0); const int64_t value_size = value_shape_.num_elements(); @@ -743,7 +746,7 @@ class MutableDenseHashTable final : public LookupInterface { return absl::OkStatus(); } - Status DoRemove(OpKernelContext* ctx, const Tensor& key) + absl::Status DoRemove(OpKernelContext* ctx, const Tensor& key) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { const int64_t num_elements = key.dim_size(0); const int64_t key_size = key_shape_.num_elements(); @@ -794,7 +797,7 @@ class MutableDenseHashTable final : public LookupInterface { return absl::OkStatus(); } - Status AllocateBuckets(OpKernelContext* ctx, int64_t new_num_buckets) + absl::Status AllocateBuckets(OpKernelContext* ctx, int64_t new_num_buckets) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (new_num_buckets < 4 || ((new_num_buckets & (new_num_buckets - 1)) != 0)) { @@ -832,7 +835,7 @@ class MutableDenseHashTable final : public LookupInterface { return absl::OkStatus(); } - Status Rebucket(OpKernelContext* ctx, int64_t num_new_buckets) + absl::Status Rebucket(OpKernelContext* ctx, int64_t num_new_buckets) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { Tensor old_key_buckets = key_buckets_; Tensor old_value_buckets = value_buckets_; @@ -889,7 +892,7 @@ class LookupTableOpKernel : public OpKernel { : DT_STRING_REF) {} protected: - Status GetTable(OpKernelContext* ctx, lookup::LookupInterface** table) { + absl::Status GetTable(OpKernelContext* ctx, lookup::LookupInterface** table) { if (expected_input_0_ == DT_RESOURCE) { return GetResourceLookupTable("table_handle", ctx, table); } else { diff --git a/tensorflow/core/kernels/lookup_table_op.h b/tensorflow/core/kernels/lookup_table_op.h index f4855a22d73665..daa7f6e32dc9dd 100644 --- a/tensorflow/core/kernels/lookup_table_op.h +++ b/tensorflow/core/kernels/lookup_table_op.h @@ -221,7 +221,7 @@ class HashTable : public InitializableLookupTable { public: HashTable(OpKernelContext* ctx, OpKernel* kernel) {} - Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override { + absl::Status AsGraphDef(GraphDefBuilder* builder, Node** out) const override { // We set use_node_name_sharing with a unique node name so that the resource // can outlive the HashTableV2 kernel. This means that the lifetime of the // HashTable resource will be tied to the lifetime of the resource manager @@ -261,7 +261,7 @@ class HashTable : public InitializableLookupTable { return table_.size(); } - Status ExportValues(OpKernelContext* context) override { + absl::Status ExportValues(OpKernelContext* context) override { if (!is_initialized()) { return errors::Aborted("HashTable is not initialized."); } @@ -290,7 +290,7 @@ class HashTable : public InitializableLookupTable { DataType value_dtype() const override { return DataTypeToEnum::v(); } protected: - Status DoPrepare(size_t size) override { + absl::Status DoPrepare(size_t size) override { if (is_initialized()) { return errors::Aborted("HashTable already initialized."); } @@ -300,11 +300,11 @@ class HashTable : public InitializableLookupTable { return absl::OkStatus(); }; - Status DoLazyPrepare(std::function size_fn) override { + absl::Status DoLazyPrepare(std::function size_fn) override { return DoPrepare(size_fn()); } - Status DoInsert(const Tensor& keys, const Tensor& values) override { + absl::Status DoInsert(const Tensor& keys, const Tensor& values) override { const auto key_values = keys.flat(); const auto value_values = values.flat(); for (int64_t i = 0; i < key_values.size(); ++i) { @@ -320,8 +320,8 @@ class HashTable : public InitializableLookupTable { return absl::OkStatus(); } - Status DoFind(const Tensor& key, Tensor* value, - const Tensor& default_value) override { + absl::Status DoFind(const Tensor& key, Tensor* value, + const Tensor& default_value) override { const V default_val = default_value.flat()(0); const auto key_values = key.flat(); auto value_values = value->flat(); diff --git a/tensorflow/core/kernels/lookup_util.cc b/tensorflow/core/kernels/lookup_util.cc index 44b93c75e5c988..c2b29cb96bf5c8 100644 --- a/tensorflow/core/kernels/lookup_util.cc +++ b/tensorflow/core/kernels/lookup_util.cc @@ -305,8 +305,8 @@ Status GetReferenceLookupTable(StringPiece input_name, OpKernelContext* ctx, return ctx->resource_manager()->Lookup(container, table_handle, table); } -Status GetLookupTable(StringPiece input_name, OpKernelContext* ctx, - LookupInterface** table) { +absl::Status GetLookupTable(StringPiece input_name, OpKernelContext* ctx, + LookupInterface** table) { DataType handle_dtype; TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &handle_dtype)); if (handle_dtype == DT_RESOURCE) { diff --git a/tensorflow/core/kernels/lookup_util.h b/tensorflow/core/kernels/lookup_util.h index 2094f894bc5b4b..ca0e93833b04cb 100644 --- a/tensorflow/core/kernels/lookup_util.h +++ b/tensorflow/core/kernels/lookup_util.h @@ -33,33 +33,38 @@ namespace lookup { // passed by attribute with name input_name, returns null if the table // doesn't exist. Use GetResourceLookupTable() or GetReferenceLookupTable() if // the input dtype is known. -Status GetLookupTable(StringPiece input_name, OpKernelContext* ctx, - LookupInterface** table); -Status GetResourceLookupTable(StringPiece input_name, OpKernelContext* ctx, - LookupInterface** table); -Status GetReferenceLookupTable(StringPiece input_name, OpKernelContext* ctx, - LookupInterface** table); +absl::Status GetLookupTable(StringPiece input_name, OpKernelContext* ctx, + LookupInterface** table); +absl::Status GetResourceLookupTable(StringPiece input_name, + OpKernelContext* ctx, + LookupInterface** table); +absl::Status GetReferenceLookupTable(StringPiece input_name, + OpKernelContext* ctx, + LookupInterface** table); // Gets the InitializableLookupTable stored in the // ctx->resource_manager() with key passed by attribute with name // input_name, returns null if the table doesn't exist. -Status GetInitializableLookupTable(StringPiece input_name, OpKernelContext* ctx, - InitializableLookupTable** table); +absl::Status GetInitializableLookupTable(StringPiece input_name, + OpKernelContext* ctx, + InitializableLookupTable** table); // Verify that the given key_dtype and value_dtype matches the corresponding // table's data types. -Status CheckTableDataTypes(const LookupInterface& table, DataType key_dtype, - DataType value_dtype, const string& table_name); +absl::Status CheckTableDataTypes(const LookupInterface& table, + DataType key_dtype, DataType value_dtype, + const string& table_name); // Initializes `table` from `filename`. -Status InitializeTableFromTextFile(const string& filename, int64_t vocab_size, - char delimiter, int32_t key_index, - int32_t value_index, int64_t offset, - Env* env, InitializableLookupTable* table); +absl::Status InitializeTableFromTextFile(const string& filename, + int64_t vocab_size, char delimiter, + int32_t key_index, int32_t value_index, + int64_t offset, Env* env, + InitializableLookupTable* table); // Initializes `table` from `filename`. `func` may specify how to represent the // initializer as a graphdef, so that the table can be serialized as metadata. -Status InitializeTableFromTextFile( +absl::Status InitializeTableFromTextFile( const string& filename, int64_t vocab_size, char delimiter, int32_t key_index, int32_t value_index, int64_t offset, Env* env, std::unique_ptr serializer, diff --git a/tensorflow/core/kernels/loss.h b/tensorflow/core/kernels/loss.h index 7db348800e92a3..85893ba8042983 100644 --- a/tensorflow/core/kernels/loss.h +++ b/tensorflow/core/kernels/loss.h @@ -52,7 +52,7 @@ class DualLossUpdater { // Converts binary example labels from 0.0 or 1.0 to appropriate range for // each loss function. - virtual Status ConvertLabel(float* const example_label) const = 0; + virtual absl::Status ConvertLabel(float* const example_label) const = 0; }; } // namespace tensorflow diff --git a/tensorflow/core/kernels/loss_test.cc b/tensorflow/core/kernels/loss_test.cc index cc5e1cd9cc74fe..fcff157a630655 100644 --- a/tensorflow/core/kernels/loss_test.cc +++ b/tensorflow/core/kernels/loss_test.cc @@ -193,7 +193,7 @@ TEST(HingeLoss, ComputeDualLoss) { TEST(HingeLoss, ConvertLabel) { HingeLossUpdater loss_updater; float example_label = 1.0; - Status status; + absl::Status status; // A label with value 1.0 should remain intact. TF_EXPECT_OK(loss_updater.ConvertLabel(&example_label)); @@ -338,7 +338,7 @@ TEST(PoissonLoss, ConvertLabel) { PoissonLossUpdater loss_updater; float example_label = -1.0; // Negative label should throw an error. - Status status = loss_updater.ConvertLabel(&example_label); + absl::Status status = loss_updater.ConvertLabel(&example_label); EXPECT_FALSE(status.ok()); } diff --git a/tensorflow/core/kernels/map_kernels.h b/tensorflow/core/kernels/map_kernels.h index ad01ef15932661..6949ff554a286b 100644 --- a/tensorflow/core/kernels/map_kernels.h +++ b/tensorflow/core/kernels/map_kernels.h @@ -22,8 +22,8 @@ limitations under the License. namespace tensorflow { -inline Status GetInputMap(OpKernelContext* ctx, int index, - const TensorMap** ret_map) { +inline absl::Status GetInputMap(OpKernelContext* ctx, int index, + const TensorMap** ret_map) { if (!TensorShapeUtils::IsScalar(ctx->input(index).shape())) { return errors::InvalidArgument("Input map must be a scalar. Saw: ", ctx->input(index).shape().DebugString()); @@ -39,11 +39,11 @@ inline Status GetInputMap(OpKernelContext* ctx, int index, } // TODO(kattian): change into templated function -inline Status ForwardInputOrCreateNewMap(OpKernelContext* ctx, - int32_t input_index, - int32_t output_index, - const TensorMap& input_map, - TensorMap** output_map) { +inline absl::Status ForwardInputOrCreateNewMap(OpKernelContext* ctx, + int32_t input_index, + int32_t output_index, + const TensorMap& input_map, + TensorMap** output_map) { // Attempt to forward the input tensor to the output if possible. std::unique_ptr maybe_output = ctx->forward_input( input_index, output_index, DT_VARIANT, TensorShape{}, @@ -223,8 +223,8 @@ class TensorMapStackKeys : public OpKernel { }; template -Status TensorMapBinaryAdd(OpKernelContext* ctx, const TensorMap& a, - const TensorMap& b, TensorMap* out) { +absl::Status TensorMapBinaryAdd(OpKernelContext* ctx, const TensorMap& a, + const TensorMap& b, TensorMap* out) { // Binary add returns a map containing the union of keys. // Values with keys in the intersection are added. out->tensors() = a.tensors(); @@ -244,8 +244,8 @@ Status TensorMapBinaryAdd(OpKernelContext* ctx, const TensorMap& a, } template -Status TensorMapZerosLike(OpKernelContext* ctx, const TensorMap& x, - TensorMap* y) { +absl::Status TensorMapZerosLike(OpKernelContext* ctx, const TensorMap& x, + TensorMap* y) { // Zeros like returns an empty map. return absl::OkStatus(); } diff --git a/tensorflow/core/kernels/map_stage_op.cc b/tensorflow/core/kernels/map_stage_op.cc index bf16340081466a..e8076eb02b8393 100644 --- a/tensorflow/core/kernels/map_stage_op.cc +++ b/tensorflow/core/kernels/map_stage_op.cc @@ -163,10 +163,10 @@ class StagingMap : public ResourceBase { } // Check that the index is within bounds - Status check_index(const Tensor& key, std::size_t index) + absl::Status check_index(const Tensor& key, std::size_t index) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (index >= dtypes_.size()) { - return Status(errors::InvalidArgument( + return absl::Status(errors::InvalidArgument( "Index '", index, "' for key '", key.scalar()(), "' was out of bounds '", dtypes_.size(), "'.")); } @@ -174,9 +174,9 @@ class StagingMap : public ResourceBase { return absl::OkStatus(); } - Status copy_or_move_tensors(OptionalTuple* map_tuple, const Tensor& key, - const Tensor& indices, Tuple* output, - bool copy = false) + absl::Status copy_or_move_tensors(OptionalTuple* map_tuple, const Tensor& key, + const Tensor& indices, Tuple* output, + bool copy = false) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { auto findices = indices.flat(); @@ -188,7 +188,7 @@ class StagingMap : public ResourceBase { // Insist on a value present at the specified index if (!(*map_tuple)[index].has_value()) { - return Status(errors::InvalidArgument( + return absl::Status(errors::InvalidArgument( "Tensor at index '", index, "' for key '", key.scalar()(), "' has already been removed.")); } @@ -208,8 +208,8 @@ class StagingMap : public ResourceBase { // Check that the optional value at the specified index // is uninitialized - Status check_index_uninitialized(const Tensor& key, std::size_t index, - const OptionalTuple& tuple) + absl::Status check_index_uninitialized(const Tensor& key, std::size_t index, + const OptionalTuple& tuple) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (tuple[index].has_value()) { return errors::InvalidArgument("The tensor for index '", index, @@ -222,7 +222,7 @@ class StagingMap : public ResourceBase { } // Check that the indices are strictly ordered - Status check_index_ordering(const Tensor& indices) { + absl::Status check_index_ordering(const Tensor& indices) { if (indices.NumElements() == 0) { return errors::InvalidArgument("Indices are empty"); } @@ -241,7 +241,7 @@ class StagingMap : public ResourceBase { } // Check bytes are within memory limits memory limits - Status check_memory_limit(std::size_t bytes) + absl::Status check_memory_limit(std::size_t bytes) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (has_memory_limit() && bytes > memory_limit_) { return errors::ResourceExhausted( @@ -254,8 +254,9 @@ class StagingMap : public ResourceBase { } // Insert incomplete data into the Barrier - Status put_incomplete(const KeyType& key, const Tensor& indices, - OptionalTuple* tuple, tensorflow::mutex_lock* lock) + absl::Status put_incomplete(const KeyType& key, const Tensor& indices, + OptionalTuple* tuple, + tensorflow::mutex_lock* lock) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { auto findices = indices.flat(); @@ -331,7 +332,7 @@ class StagingMap : public ResourceBase { } // Does the insertion into the actual staging area - Status put_complete(const KeyType& key, OptionalTuple* tuple) + absl::Status put_complete(const KeyType& key, OptionalTuple* tuple) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { // Insert key and tuples into the map map_.insert({key, std::move(*tuple)}); @@ -350,7 +351,7 @@ class StagingMap : public ResourceBase { memory_limit_(memory_limit), current_bytes_(0) {} - Status put(KeyType* key, const Tensor* indices, OptionalTuple* tuple) { + absl::Status put(KeyType* key, const Tensor* indices, OptionalTuple* tuple) { tensorflow::mutex_lock lock(mu_); // Sanity check the indices @@ -379,7 +380,7 @@ class StagingMap : public ResourceBase { return absl::OkStatus(); } - Status get(const KeyType* key, const Tensor* indices, Tuple* tuple) { + absl::Status get(const KeyType* key, const Tensor* indices, Tuple* tuple) { tensorflow::mutex_lock lock(mu_); // Sanity check the indices @@ -401,7 +402,7 @@ class StagingMap : public ResourceBase { return absl::OkStatus(); } - Status pop(const KeyType* key, const Tensor* indices, Tuple* tuple) { + absl::Status pop(const KeyType* key, const Tensor* indices, Tuple* tuple) { tensorflow::mutex_lock lock(mu_); // Sanity check the indices @@ -432,7 +433,7 @@ class StagingMap : public ResourceBase { return absl::OkStatus(); } - Status popitem(KeyType* key, const Tensor* indices, Tuple* tuple) { + absl::Status popitem(KeyType* key, const Tensor* indices, Tuple* tuple) { tensorflow::mutex_lock lock(mu_); // Sanity check the indices @@ -467,7 +468,7 @@ class StagingMap : public ResourceBase { return absl::OkStatus(); } - Status clear() { + absl::Status clear() { tensorflow::mutex_lock lock(mu_); map_.clear(); incomplete_.clear(); @@ -492,13 +493,13 @@ class StagingMap : public ResourceBase { }; template -Status GetStagingMap(OpKernelContext* ctx, const NodeDef& ndef, - StagingMap** map) { +absl::Status GetStagingMap(OpKernelContext* ctx, const NodeDef& ndef, + StagingMap** map) { auto rm = ctx->resource_manager(); ContainerInfo cinfo; // Lambda for creating the Staging Area - auto create_fn = [&ndef](StagingMap** ret) -> Status { + auto create_fn = [&ndef](StagingMap** ret) -> absl::Status { DataTypeVector dtypes; int64_t capacity; int64_t memory_limit; diff --git a/tensorflow/core/kernels/matmul_op_impl.h b/tensorflow/core/kernels/matmul_op_impl.h index 3c24bdc6965c86..70fc941e80fa15 100644 --- a/tensorflow/core/kernels/matmul_op_impl.h +++ b/tensorflow/core/kernels/matmul_op_impl.h @@ -916,7 +916,7 @@ class BaseBatchMatMulOp : public OpKernel { const Tensor& in0 = ctx->input(0); const Tensor& in1 = ctx->input(1); - const Status s = ValidateInputTensors(ctx, in0, in1); + const absl::Status s = ValidateInputTensors(ctx, in0, in1); if (!s.ok()) { ctx->SetStatus(s); return; @@ -1020,8 +1020,9 @@ class BaseBatchMatMulOp : public OpKernel { } protected: - virtual Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0, - const Tensor& in1) = 0; + virtual absl::Status ValidateInputTensors(OpKernelContext* ctx, + const Tensor& in0, + const Tensor& in1) = 0; private: // TODO(171979567) Make the ops take both adj and transpose attributes. @@ -1052,8 +1053,8 @@ class BatchMatMulOp : public BaseBatchMatMulOp { ~BatchMatMulOp() override {} private: - Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0, - const Tensor& in1) override { + absl::Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0, + const Tensor& in1) override { // Disallow broadcasting support. Ensure that all batch dimensions of the // input tensors match. if (in0.dims() != in1.dims()) { @@ -1097,8 +1098,8 @@ class BatchMatMulV2Op : public BaseBatchMatMulOp { ~BatchMatMulV2Op() override {} private: - Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0, - const Tensor& in1) override { + absl::Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0, + const Tensor& in1) override { // Enable broadcasting support. Validity of broadcasting is checked in // BaseBatchMatMulOp. if (in0.dims() < 2) { diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc index 64035190934adf..bc99ad59db4543 100644 --- a/tensorflow/core/kernels/maxpooling_op.cc +++ b/tensorflow/core/kernels/maxpooling_op.cc @@ -15,6 +15,8 @@ limitations under the License. // See docs in ../ops/nn_ops.cc. +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #define EIGEN_USE_THREADS #include "tensorflow/core/kernels/maxpooling_op.h" @@ -889,6 +891,13 @@ class MaxPoolingNoMaskV2Op : public OpKernel { OP_REQUIRES(context, ksize_.size() == 4, errors::InvalidArgument("Sliding window ksize field must " "specify 4 dimensions")); + OP_REQUIRES( + context, + ksize_[0] > 0 && ksize_[1] > 0 && ksize_[2] > 0 && ksize_[3] > 0, + errors::InvalidArgument( + absl::StrCat("Sliding window ksize must be positive. The " + "specified or inferred ksize is: ", + absl::StrJoin(ksize_, ",")))); OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); OP_REQUIRES(context, stride_.size() == 4, errors::InvalidArgument("Sliding window stride field must " @@ -1116,6 +1125,14 @@ class MaxPoolingGradWithArgmaxOp : public OpKernel { OP_REQUIRES(context, ksize_.size() == 4, errors::InvalidArgument("Sliding window ksize field must " "specify 4 dimensions")); + OP_REQUIRES( + context, + ksize_[0] > 0 && ksize_[1] > 0 && ksize_[2] > 0 && ksize_[3] > 0, + errors::InvalidArgument( + absl::StrCat("Sliding window ksize must be positive. The " + "specified or inferred ksize is: ", + absl::StrJoin(ksize_, ",")))); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); OP_REQUIRES(context, stride_.size() == 4, errors::InvalidArgument("Sliding window stride field must " @@ -1261,6 +1278,14 @@ class MaxPoolingNoMaskOp : public OpKernel { OP_REQUIRES(context, ksize_.size() == 4, errors::InvalidArgument("Sliding window ksize field must " "specify 4 dimensions")); + OP_REQUIRES( + context, + ksize_[0] > 0 && ksize_[1] > 0 && ksize_[2] > 0 && ksize_[3] > 0, + errors::InvalidArgument( + absl::StrCat("Sliding window ksize must be positive. The " + "specified or inferred ksize is: ", + absl::StrJoin(ksize_, ",")))); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); OP_REQUIRES(context, stride_.size() == 4, errors::InvalidArgument("Sliding window stride field must " diff --git a/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h index c80fb797566f0b..da031d5cd1246d 100644 --- a/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h +++ b/tensorflow/core/kernels/mkl/mkl_pooling_ops_common.h @@ -507,8 +507,10 @@ class MklPoolingOpBase : public OpKernel { "specify 4 or 5 dimensions")); for (int i = 0; i < this->ksize_.size(); ++i) { OP_REQUIRES(context, this->ksize_[i] > 0, - absl::InvalidArgumentError(absl::StrCat( - "Sliding window ksize for dimension ", i, " was zero."))); + errors::InvalidArgument( + absl::StrCat("Sliding window ksize must be positive. The " + "specified or inferred ksize is: ", + absl::StrJoin(ksize_, ",")))); } OP_REQUIRES_OK(context, context->GetAttr("strides", &this->stride_)); diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD index b88ee74c0999ef..428b0a5fa8d8c3 100644 --- a/tensorflow/core/kernels/mlir_generated/BUILD +++ b/tensorflow/core/kernels/mlir_generated/BUILD @@ -171,7 +171,7 @@ tf_kernel_library( ":gpu_sinh_kernels", ":gpu_tan_kernels", ":gpu_tanh_kernels", - ]), + ]) + ["//tensorflow/core/framework:types_proto_cc"], ) tf_kernel_library( @@ -253,7 +253,7 @@ tf_kernel_library( ":gpu_greater_equal_kernels", ":gpu_less_equal_kernels", ":gpu_less_kernels", - ]), + ]) + ["//tensorflow/core/framework:types_proto_cc"], ) tf_kernel_library( @@ -302,7 +302,7 @@ tf_kernel_library( ":gpu_ones_like_kernels", ":gpu_zeros_like_kernels", "@eigen_archive//:eigen3", - ]), + ]) + ["//tensorflow/core/framework:types_proto_cc"], ) tf_kernel_library( @@ -323,7 +323,7 @@ tf_kernel_library( deps = if_mlir_generated_gpu_kernels_enabled([ ":base_gpu_op", ":gpu_next_after_kernels", - ]), + ]) + ["//tensorflow/core/framework:types_proto_cc"], ) tf_kernel_library( @@ -349,7 +349,7 @@ tf_kernel_library( ":gpu_relu_kernels", ":gpu_selu_kernels", "@eigen_archive//:eigen3", - ]), + ]) + ["//tensorflow/core/framework:types_proto_cc"], ) tf_kernel_library( @@ -371,7 +371,7 @@ tf_kernel_library( ":base_gpu_op", ":gpu_softplus_kernels", "@eigen_archive//:eigen3", - ]), + ]) + ["//tensorflow/core/framework:types_proto_cc"], ) tf_kernel_library( @@ -393,7 +393,7 @@ tf_kernel_library( ":base_gpu_op", ":gpu_softsign_kernels", "@eigen_archive//:eigen3", - ]), + ]) + ["//tensorflow/core/framework:types_proto_cc"], ) tf_kernel_library( @@ -460,8 +460,16 @@ tf_cuda_cc_test( deps = [ ":base_ops_test", ":base_unary_ops_test", + "//tensorflow/core:framework", + "//tensorflow/core:framework_types_hdr", + "//tensorflow/core:lib", "//tensorflow/core/common_runtime:device", "//tensorflow/core/common_runtime:device_factory", + "//tensorflow/core/framework:types_proto_cc", + "@com_google_googletest//:gtest_main", + "@eigen_archive//:eigen3", + "@llvm-project//llvm:Support", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -481,8 +489,11 @@ tf_cuda_cc_test( deps = [ ":base_ops_test", ":base_unary_ops_test", + "//tensorflow/core:framework", + "//tensorflow/core:framework_types_hdr", "//tensorflow/core/common_runtime:device", "//tensorflow/core/common_runtime:device_factory", + "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_logical_and.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_logical_and.cc index cc0de521eb9c91..7bfa3e0686c6d8 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_logical_and.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_logical_and.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_logical_not.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_logical_not.cc index fdce45d41ab7ca..0b9ff8ee78b1b9 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_logical_not.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_logical_not.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_logical_or.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_logical_or.cc index d4a5cb63e156b5..6406e602dd317c 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_logical_or.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_logical_or.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_maximum.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_maximum.cc index 6a8ed996bc7ee8..64e9342e4e75b4 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_maximum.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_maximum.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_minimum.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_minimum.cc index 48af5bcdc557ae..68da3737d810ea 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_minimum.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_minimum.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_mul.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_mul.cc index 512f3ca658b517..aba7bfa571b0f7 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_mul.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_mul.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_mul_no_nan.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_mul_no_nan.cc index 72f6cec8879bc8..a171b7a4b5b734 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_mul_no_nan.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_mul_no_nan.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_neg.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_neg.cc index ee40545ec858a6..7bcaf4b70d9c18 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_neg.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_neg.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_next_after.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_next_after.cc index cfadd3a596d13d..7e3686766a4a83 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_next_after.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_next_after.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_not_equal.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_not_equal.cc index 680fd0cee39a36..887ca37c2242dc 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_not_equal.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_not_equal.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_ones_like.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_ones_like.cc index cdb306f8378dab..93f2316a5c8dfc 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_ones_like.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_ones_like.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_polygamma.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_polygamma.cc index 1e15d28e3db44d..87adf03abbd315 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_polygamma.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_polygamma.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_pow.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_pow.cc index b30e4a9506fe29..256943985b5200 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_pow.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_pow.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_real.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_real.cc index fe95325a19dd2f..2c400aca27abb1 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_real.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_real.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_reciprocal.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_reciprocal.cc index 172edd2e353bb0..7348466384e401 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_reciprocal.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_reciprocal.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_relu.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_relu.cc index 3b817f1eee74b2..df58f50a86a116 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_relu.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_relu.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_right_shift.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_right_shift.cc index fad9cbaebbf5fd..00e06ab237dcd1 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_right_shift.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_right_shift.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_rint.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_rint.cc index be7f9dbf4ddab1..2dec8800ec6e98 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_rint.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_rint.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_round.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_round.cc index b45462bb03a220..56aa174253724b 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_round.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_round.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_rsqrt.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_rsqrt.cc index 3419b22e8785ca..315b69e4a246d7 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_rsqrt.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_rsqrt.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_select.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_select.cc index 994414da853daf..dc5e704d18dbec 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_select.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_select.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_selu.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_selu.cc index 3a3f5c2d08ec0d..5c6605bb3f8f26 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_selu.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_selu.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_sigmoid.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_sigmoid.cc index 2573c339fa6b39..84033a79d5c6b0 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_sigmoid.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_sigmoid.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_sign.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_sign.cc index 1ef9e48e2a66d2..70dbd0c17d2e8c 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_sign.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_sign.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_sin.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_sin.cc index 2f417ed7b126e2..8ce140d4eda64f 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_sin.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_sin.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_sinh.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_sinh.cc index b42e23157ea33b..fb9a5f6e3e194c 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_sinh.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_sinh.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_softplus.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_softplus.cc index be5ac82ba9777a..35e5d6c12cc0e2 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_softplus.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_softplus.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_softsign.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_softsign.cc index 35b3217ce700dc..4ce186b052c9c6 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_softsign.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_softsign.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_sqrt.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_sqrt.cc index c304ed96dfa53c..61afccb02346c2 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_sqrt.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_sqrt.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_square.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_square.cc index 2874c946a73c32..11af03569b3393 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_square.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_square.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_squared_difference.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_squared_difference.cc index fdc82a8d9128eb..45cc2d18fe3282 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_squared_difference.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_squared_difference.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_sub.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_sub.cc index b0af6ca3d42d0d..b2f2f355969f68 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_sub.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_sub.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_tan.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_tan.cc index 00d4e3e9943788..7b5b2ac20fcdc7 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_tan.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_tan.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_tanh.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_tanh.cc index acd875e3fa46c4..7f1f8698ce9e1a 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_tanh.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_tanh.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_truncate_div.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_truncate_div.cc index 7626402b6152c8..cd633250fb8cc1 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_truncate_div.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_truncate_div.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_xdivy.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_xdivy.cc index d87014b73f3c27..798f3a56d76131 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_xdivy.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_xdivy.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_xlog1py.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_xlog1py.cc index 24176c50f96f08..176ae30c047d12 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_xlog1py.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_xlog1py.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_xlogy.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_xlogy.cc index 98ca95abdcedcc..ce0d3739cb9da4 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_xlogy.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_xlogy.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_zeros_like.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_zeros_like.cc index 11f24ad5258d1a..c30f17fbaacd63 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_zeros_like.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_zeros_like.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_op_zeta.cc b/tensorflow/core/kernels/mlir_generated/gpu_op_zeta.cc index 3c597c6c805236..2726d6b89e4f89 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_op_zeta.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_op_zeta.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_gpu_op.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_large_tensor_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_large_tensor_test.cc index b8f46b76d50321..4b783595441535 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_large_tensor_test.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_large_tensor_test.cc @@ -17,6 +17,10 @@ limitations under the License. #include #include +#include +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/device_factory.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/mlir_generated/base_ops_test.h" #include "tensorflow/core/kernels/mlir_generated/base_unary_ops_test.h" diff --git a/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc b/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc index eb9414a0563537..4a2e9744215fcd 100644 --- a/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc +++ b/tensorflow/core/kernels/mlir_generated/gpu_unary_ops_test.cc @@ -23,8 +23,17 @@ limitations under the License. #include #include +#include +#include "Eigen/Core" // from @eigen_archive +#include "llvm/ADT/STLExtras.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/device_factory.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/mlir_generated/base_ops_test.h" #include "tensorflow/core/kernels/mlir_generated/base_unary_ops_test.h" +#include "tensorflow/core/platform/env.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/kernels/mutex_ops.cc b/tensorflow/core/kernels/mutex_ops.cc index 6147afc73e58a0..19d2d909658311 100644 --- a/tensorflow/core/kernels/mutex_ops.cc +++ b/tensorflow/core/kernels/mutex_ops.cc @@ -76,7 +76,7 @@ class Mutex : public ResourceBase { void AcquireAsync( OpKernelContext* c, - std::function fn) { + std::function fn) { CancellationManager* cm = c->cancellation_manager(); CancellationToken token{}; bool* cancelled = nullptr; @@ -98,7 +98,8 @@ class Mutex : public ResourceBase { } thread_pool_->Schedule(std::bind( [this, cm, cancelled, - token](std::function + token](std::function fn_) { bool local_locked; { @@ -158,7 +159,7 @@ class MutexLockOp : public AsyncOpKernel { c, std::bind( [c, variant, mutex](DoneCallback done_, // End of bound arguments. - const Status& s, + const absl::Status& s, Mutex::SharedLockReleaser&& lock) { VLOG(2) << "Finished locking mutex " << mutex << " with lock: " << lock.shared_ptr.get() diff --git a/tensorflow/core/kernels/nn_ops_test.cc b/tensorflow/core/kernels/nn_ops_test.cc index 405b58edbfc979..545239f3078092 100644 --- a/tensorflow/core/kernels/nn_ops_test.cc +++ b/tensorflow/core/kernels/nn_ops_test.cc @@ -817,7 +817,7 @@ static void BM_LRNFloat(::testing::benchmark::State& state, int depth, int cols, .Attr("beta", 0.5) .Finalize(&lrn_node_def)); - Status status; + absl::Status status; std::unique_ptr op(CreateOpKernel(DEVICE_CPU, device.get(), cpu_allocator(), lrn_node_def, TF_GRAPH_DEF_VERSION, &status)); @@ -890,12 +890,13 @@ static void BM_AvgPool(::testing::benchmark::State& state, int batch_size, // AvgPooling op. NodeDef avgpool_node_def; CHECK_EQ(kernel_rows, kernel_cols); - Status status = NodeDefBuilder("avgpool_op", "AvgPool") - .Input(FakeInput(DT_FLOAT)) - .Attr("ksize", {1, kernel_rows, kernel_cols, 1}) - .Attr("strides", {1, stride, stride, 1}) - .Attr("padding", padding == VALID ? "VALID" : "SAME") - .Finalize(&avgpool_node_def); + absl::Status status = + NodeDefBuilder("avgpool_op", "AvgPool") + .Input(FakeInput(DT_FLOAT)) + .Attr("ksize", {1, kernel_rows, kernel_cols, 1}) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", padding == VALID ? "VALID" : "SAME") + .Finalize(&avgpool_node_def); TF_CHECK_OK(status); std::unique_ptr op(CreateOpKernel(DEVICE_CPU, device.get(), @@ -993,13 +994,14 @@ static void BM_AvgPoolBk(::testing::benchmark::State& state, int batch_size, // AvgPoolGrad op. NodeDef avgpool_grad_node_def; - Status status = NodeDefBuilder("avgpool_grad_op", "AvgPoolGrad") - .Input(FakeInput()) - .Input(FakeInput(DT_FLOAT)) - .Attr("ksize", {1, kernel_rows, kernel_cols, 1}) - .Attr("strides", {1, stride, stride, 1}) - .Attr("padding", padding == VALID ? "VALID" : "SAME") - .Finalize(&avgpool_grad_node_def); + absl::Status status = + NodeDefBuilder("avgpool_grad_op", "AvgPoolGrad") + .Input(FakeInput()) + .Input(FakeInput(DT_FLOAT)) + .Attr("ksize", {1, kernel_rows, kernel_cols, 1}) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", padding == VALID ? "VALID" : "SAME") + .Finalize(&avgpool_grad_node_def); TF_CHECK_OK(status); std::unique_ptr op( CreateOpKernel(DEVICE_CPU, nullptr, cpu_allocator(), @@ -1085,12 +1087,13 @@ static void BM_MaxPool(::testing::benchmark::State& state, int batch_size, // MaxPooling op. NodeDef maxpool_node_def; CHECK_EQ(kernel_rows, kernel_cols); - Status status = NodeDefBuilder("maxpool_op", "MaxPool") - .Input(FakeInput()) - .Attr("ksize", {1, kernel_rows, kernel_cols, 1}) - .Attr("strides", {1, stride, stride, 1}) - .Attr("padding", padding == VALID ? "VALID" : "SAME") - .Finalize(&maxpool_node_def); + absl::Status status = + NodeDefBuilder("maxpool_op", "MaxPool") + .Input(FakeInput()) + .Attr("ksize", {1, kernel_rows, kernel_cols, 1}) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", padding == VALID ? "VALID" : "SAME") + .Finalize(&maxpool_node_def); TF_CHECK_OK(status); std::unique_ptr op(CreateOpKernel(DEVICE_CPU, device.get(), cpu_allocator(), maxpool_node_def, @@ -1270,9 +1273,9 @@ static void BM_ReluFloat(::testing::benchmark::State& state, int batch_size, // Reluing op. NodeDef relu_node_def; - Status status = NodeDefBuilder("relu_op", "Relu") - .Input(FakeInput(DT_FLOAT)) - .Finalize(&relu_node_def); + absl::Status status = NodeDefBuilder("relu_op", "Relu") + .Input(FakeInput(DT_FLOAT)) + .Finalize(&relu_node_def); TF_CHECK_OK(status); std::unique_ptr op(CreateOpKernel(DEVICE_CPU, device.get(), cpu_allocator(), relu_node_def, @@ -1341,9 +1344,9 @@ static void BM_SoftplusFloat(::testing::benchmark::State& state, int batch_size, // Softplusing op. NodeDef softplus_node_def; - Status status = NodeDefBuilder("softplus_op", "Softplus") - .Input(FakeInput(DT_FLOAT)) - .Finalize(&softplus_node_def); + absl::Status status = NodeDefBuilder("softplus_op", "Softplus") + .Input(FakeInput(DT_FLOAT)) + .Finalize(&softplus_node_def); TF_CHECK_OK(status); std::unique_ptr op( CreateOpKernel(DEVICE_CPU, device.get(), cpu_allocator(), diff --git a/tensorflow/core/kernels/ops_testutil.cc b/tensorflow/core/kernels/ops_testutil.cc index 06762dbf9652eb..4efbac731bcaf2 100644 --- a/tensorflow/core/kernels/ops_testutil.cc +++ b/tensorflow/core/kernels/ops_testutil.cc @@ -135,11 +135,11 @@ void OpsTestBase::set_node_def(const NodeDef& node_def) { NodeDef* OpsTestBase::node_def() { return &node_def_; } -Status OpsTestBase::InitOp() { +absl::Status OpsTestBase::InitOp() { return InitOpWithGraphVersion(TF_GRAPH_DEF_VERSION); } -Status OpsTestBase::InitOpWithGraphVersion(int graph_def_version) { +absl::Status OpsTestBase::InitOpWithGraphVersion(int graph_def_version) { std::shared_ptr props; TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef( node_def_, OpRegistry::Global(), &props)); @@ -189,7 +189,7 @@ void OpsTestBase::CreateContext() { context_.reset(new OpKernelContext(params_.get())); } -Status OpsTestBase::RunOpKernel() { +absl::Status OpsTestBase::RunOpKernel() { CreateContext(); device_->Compute(kernel_.get(), context_.get()); return context_->status(); diff --git a/tensorflow/core/kernels/ops_testutil.h b/tensorflow/core/kernels/ops_testutil.h index d098d1e6265add..ef4a7cd5142cde 100644 --- a/tensorflow/core/kernels/ops_testutil.h +++ b/tensorflow/core/kernels/ops_testutil.h @@ -86,10 +86,10 @@ class OpsTestBase : public ::testing::Test { // and output types as output. // // Returns the status of initialization. - Status InitOp(); + absl::Status InitOp(); // Only use this directly if you have a deprecated op that you need to test. - Status InitOpWithGraphVersion(int graph_def_version); + absl::Status InitOpWithGraphVersion(int graph_def_version); // Adds an input for every element described by the shape. // 'input_mapping' maps an index (0...NumElements(shape)) to a @@ -133,7 +133,7 @@ class OpsTestBase : public ::testing::Test { // Runs an operation producing 'num_outputs' outputs. // // Returns the context's status after running the operation. - Status RunOpKernel(); + absl::Status RunOpKernel(); // Returns the tensor input for 'input_index'. // diff --git a/tensorflow/core/kernels/ops_util_test.cc b/tensorflow/core/kernels/ops_util_test.cc index 82275979f758e1..f225de18e5142d 100644 --- a/tensorflow/core/kernels/ops_util_test.cc +++ b/tensorflow/core/kernels/ops_util_test.cc @@ -71,7 +71,7 @@ class OpsUtilTest : public ::testing::Test { static void VerifyGet2dOutputSizeBoundaries(padding_struct pad_struct, error::Code code) { int64_t new_height, new_width, pad_rows, pad_cols; - Status status = GetWindowedOutputSize( + absl::Status status = GetWindowedOutputSize( pad_struct.input.in_height, pad_struct.input.filter_height, /*dilation_rate=*/1, pad_struct.input.row_stride, pad_struct.input.padding, &new_height, &pad_rows); @@ -86,7 +86,7 @@ class OpsUtilTest : public ::testing::Test { static void VerifyGet2dOutputSizeValues(padding_struct pad_struct, error::Code code) { int64_t new_height, new_width, pad_rows, pad_cols; - Status status = GetWindowedOutputSize( + absl::Status status = GetWindowedOutputSize( pad_struct.input.in_height, pad_struct.input.filter_height, /*dilation_rate=*/1, pad_struct.input.row_stride, pad_struct.input.padding, &new_height, &pad_rows); @@ -105,7 +105,7 @@ class OpsUtilTest : public ::testing::Test { static void VerifyGet2dOutputVerboseSizeValues(padding_struct pad_struct, error::Code code) { int64_t new_height, new_width, pad_top, pad_bottom, pad_left, pad_right; - Status status = GetWindowedOutputSizeVerbose( + absl::Status status = GetWindowedOutputSizeVerbose( pad_struct.input.in_height, pad_struct.input.filter_height, /*dilation_rate=*/1, pad_struct.input.row_stride, pad_struct.input.padding, &new_height, &pad_top, &pad_bottom); @@ -125,7 +125,7 @@ class OpsUtilTest : public ::testing::Test { static void VerifyBoundaries(bcast_struct bcast, error::Code code) { int new_index, new_size; - Status status = GetBroadcastSize( + absl::Status status = GetBroadcastSize( bcast.input.index, bcast.input.in_size, bcast.input.ksize, bcast.input.stride, bcast.input.pad_size, &new_index, &new_size); EXPECT_EQ(status.code(), code) << status; diff --git a/tensorflow/core/kernels/padding_fifo_queue.cc b/tensorflow/core/kernels/padding_fifo_queue.cc index 317f19712d0e8d..3b50099fb9997c 100644 --- a/tensorflow/core/kernels/padding_fifo_queue.cc +++ b/tensorflow/core/kernels/padding_fifo_queue.cc @@ -41,8 +41,8 @@ PaddingFIFOQueue::PaddingFIFOQueue( ConvertShapesPartialDimensionsToZero(component_shapes), name), partial_shapes_(component_shapes) {} -Status PaddingFIFOQueue::Initialize() { - Status s = FIFOQueue::Initialize(); +absl::Status PaddingFIFOQueue::Initialize() { + absl::Status s = FIFOQueue::Initialize(); if (!s.ok()) return s; if (component_dtypes_.size() != partial_shapes_.size()) { @@ -56,7 +56,7 @@ Status PaddingFIFOQueue::Initialize() { } /* static */ -Status PaddingFIFOQueue::GetElementComponent( +absl::Status PaddingFIFOQueue::GetElementComponent( const PaddingFIFOQueue::Tuple& tuple, int component, OpKernelContext* ctx, Tensor* out_tensor) { TensorShape element_shape(tuple[component].shape()); @@ -108,8 +108,8 @@ void PaddingFIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, for (int64_t i = attempt->tuples.size() - 1; i >= 0; --i) { for (int j = 0; j < num_components(); ++j) { Tensor element; - Status s = GetElementComponent(attempt->tuples[i], j, - attempt->context, &element); + absl::Status s = GetElementComponent( + attempt->tuples[i], j, attempt->context, &element); if (!s.ok()) { attempt->context->SetStatus( errors::DataLoss("Failed to restore element from " @@ -233,7 +233,7 @@ void PaddingFIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, } } -Status PaddingFIFOQueue::ValidateTuple(const Tuple& tuple) { +absl::Status PaddingFIFOQueue::ValidateTuple(const Tuple& tuple) { TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple)); for (size_t i = 0; i < tuple.size(); ++i) { if (!partial_shapes_[i].IsCompatibleWith(tuple[i].shape())) { @@ -246,7 +246,7 @@ Status PaddingFIFOQueue::ValidateTuple(const Tuple& tuple) { return absl::OkStatus(); } -Status PaddingFIFOQueue::ValidateManyTuple(const Tuple& tuple) { +absl::Status PaddingFIFOQueue::ValidateManyTuple(const Tuple& tuple) { TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple)); const int64_t batch_size = tuple[0].dim_size(0); for (size_t i = 0; i < tuple.size(); ++i) { @@ -263,7 +263,7 @@ Status PaddingFIFOQueue::ValidateManyTuple(const Tuple& tuple) { return absl::OkStatus(); } -Status PaddingFIFOQueue::CompatibleNodeDefShapes( +absl::Status PaddingFIFOQueue::CompatibleNodeDefShapes( const NodeDef& node_def) const { std::vector requested_shapes; TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes)); @@ -279,7 +279,7 @@ Status PaddingFIFOQueue::CompatibleNodeDefShapes( } } -Status PaddingFIFOQueue::MatchesNodeDef(const NodeDef& node_def) { +absl::Status PaddingFIFOQueue::MatchesNodeDef(const NodeDef& node_def) { if (!MatchesNodeDefOp(node_def, "PaddingFIFOQueue").ok() && !MatchesNodeDefOp(node_def, "PaddingFIFOQueueV2").ok()) { return errors::InvalidArgument("Expected PaddingFIFOQueue, found ", @@ -291,8 +291,8 @@ Status PaddingFIFOQueue::MatchesNodeDef(const NodeDef& node_def) { return absl::OkStatus(); } -static Status ValidateElementToLargerSlice(const Tensor& element, - Tensor* parent) { +static absl::Status ValidateElementToLargerSlice(const Tensor& element, + Tensor* parent) { DCHECK_NE(parent->dim_size(0), 0); if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) { TensorShape chip_shape = parent->shape(); @@ -307,9 +307,9 @@ static Status ValidateElementToLargerSlice(const Tensor& element, } template -Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent, - int index) { - Status s = ValidateElementToLargerSlice(element, parent); +absl::Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent, + int index) { + absl::Status s = ValidateElementToLargerSlice(element, parent); if (!s.ok()) { return s; } @@ -332,8 +332,8 @@ Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent, namespace { template -Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent, - int index) { +absl::Status HandleElementToLargerSliceWithRank(const Tensor& element, + Tensor* parent, int index) { #define HANDLE_TYPE(T) \ case DataTypeToEnum::value: { \ return HandleElementToLargerSlice(element, parent, index); \ @@ -351,8 +351,9 @@ Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent, } // namespace -Status PaddingFIFOQueue::CopyElementToLargerSlice(const Tensor& element, - Tensor* parent, int index) { +absl::Status PaddingFIFOQueue::CopyElementToLargerSlice(const Tensor& element, + Tensor* parent, + int index) { if (parent->dims() != element.dims() + 1) { return errors::Internal( "Mismatched ranks. Element's rank is: ", element.dims(), @@ -381,7 +382,7 @@ Status PaddingFIFOQueue::CopyElementToLargerSlice(const Tensor& element, } // Static method -Status PaddingFIFOQueue::SetElementZero(Tensor* element) { +absl::Status PaddingFIFOQueue::SetElementZero(Tensor* element) { #define HANDLE_TYPE(T) \ if (element->dtype() == DataTypeToEnum::value) { \ element->flat().setConstant(T()); \ diff --git a/tensorflow/core/kernels/padding_fifo_queue.h b/tensorflow/core/kernels/padding_fifo_queue.h index d124f0a03c049a..74107e80b1977b 100644 --- a/tensorflow/core/kernels/padding_fifo_queue.h +++ b/tensorflow/core/kernels/padding_fifo_queue.h @@ -38,19 +38,19 @@ class PaddingFIFOQueue : public FIFOQueue { const std::vector& component_shapes, const string& name); - Status Initialize() override; + absl::Status Initialize() override; // Implementations of QueueInterface methods -------------------------------- void TryDequeueMany(int num_elements, OpKernelContext* ctx, bool allow_small_batch, CallbackWithTuple callback) override; - Status MatchesNodeDef(const NodeDef& node_def) override; + absl::Status MatchesNodeDef(const NodeDef& node_def) override; protected: - Status ValidateManyTuple(const Tuple& tuple) override; - Status ValidateTuple(const Tuple& tuple) override; - Status CompatibleNodeDefShapes(const NodeDef& node_def) const; + absl::Status ValidateManyTuple(const Tuple& tuple) override; + absl::Status ValidateTuple(const Tuple& tuple) override; + absl::Status CompatibleNodeDefShapes(const NodeDef& node_def) const; // Convert a list of PartialTensorShape to a list of // TensorShape. @@ -60,26 +60,26 @@ class PaddingFIFOQueue : public FIFOQueue { absl::Span partial_shapes); // Sets the values in the given element to zero. - static Status SetElementZero(Tensor* element); + static absl::Status SetElementZero(Tensor* element); // Copies element into the index^th slice (in the first dimension) // of parent. Allows for the parent's slice to have a larger size // than the element, and copies the element into the upper left hand // corner of the slice. - static Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent, - int index); + static absl::Status CopyElementToLargerSlice(const Tensor& element, + Tensor* parent, int index); std::vector partial_shapes_; private: ~PaddingFIFOQueue() override {} - static Status GetElementComponent(const PaddingFIFOQueue::Tuple& tuple, - int component, OpKernelContext* ctx, - Tensor* out_tensor); + static absl::Status GetElementComponent(const PaddingFIFOQueue::Tuple& tuple, + int component, OpKernelContext* ctx, + Tensor* out_tensor); - static Status IsSameSizeExceptZerosInFirst(const TensorShape& first, - const TensorShape& second); + static absl::Status IsSameSizeExceptZerosInFirst(const TensorShape& first, + const TensorShape& second); PaddingFIFOQueue(const PaddingFIFOQueue&) = delete; void operator=(const PaddingFIFOQueue&) = delete; diff --git a/tensorflow/core/kernels/padding_fifo_queue_op.cc b/tensorflow/core/kernels/padding_fifo_queue_op.cc index 177aaf223d7764..030216e13ec4b5 100644 --- a/tensorflow/core/kernels/padding_fifo_queue_op.cc +++ b/tensorflow/core/kernels/padding_fifo_queue_op.cc @@ -53,7 +53,7 @@ class PaddingFIFOQueueOp : public TypedQueueOp { } private: - Status CreateResource(QueueInterface** ret) override + absl::Status CreateResource(QueueInterface** ret) override TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { PaddingFIFOQueue* queue = new PaddingFIFOQueue( capacity_, component_types_, component_shapes_, cinfo_.name()); diff --git a/tensorflow/core/kernels/parse_tensor_test.cc b/tensorflow/core/kernels/parse_tensor_test.cc index ed2327b8c360ae..1473eff064e3ea 100644 --- a/tensorflow/core/kernels/parse_tensor_test.cc +++ b/tensorflow/core/kernels/parse_tensor_test.cc @@ -49,7 +49,7 @@ class SerializeTensorOpTest : public OpsTestBase { DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")); absl::InlinedVector inputs; inputs.push_back({nullptr, serialized}); - Status status; + absl::Status status; std::unique_ptr op(CreateOpKernel(DEVICE_CPU, device.get(), cpu_allocator(), parse_node_def, TF_GRAPH_DEF_VERSION, &status)); diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc index d07b4b92dd2db5..7970738a96b713 100644 --- a/tensorflow/core/kernels/partitioned_function_ops.cc +++ b/tensorflow/core/kernels/partitioned_function_ops.cc @@ -75,7 +75,7 @@ PartitionedCallOp::PartitionedCallOp(OpKernelConstruction* ctx) PartitionedCallOp::~PartitionedCallOp() { for (const auto& it : handles_) { - Status status = it.first->ReleaseHandle(it.second); + absl::Status status = it.first->ReleaseHandle(it.second); if (!status.ok()) { LOG(INFO) << "Ignoring error while destructing PartitionedCallOp: " << status.ToString(); @@ -131,7 +131,7 @@ void PartitionedCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { RunFunction(handle, inputs, lib, ctx, done); } -Status PartitionedCallOp::FillOutputDevices( +absl::Status PartitionedCallOp::FillOutputDevices( const FunctionLibraryRuntime& lib, const Device& cpu_device, AttrSlice attrs, FunctionLibraryRuntime::InstantiateOptions* opts) { const FunctionLibraryDefinition* flib = lib.GetFunctionLibraryDefinition(); @@ -165,10 +165,9 @@ Status PartitionedCallOp::FillOutputDevices( return absl::OkStatus(); } -Status PartitionedCallOp::Instantiate(FunctionLibraryRuntime* lib, - OpKernelContext* ctx, - std::vector* inputs, - FunctionLibraryRuntime::Handle* handle) { +absl::Status PartitionedCallOp::Instantiate( + FunctionLibraryRuntime* lib, OpKernelContext* ctx, + std::vector* inputs, FunctionLibraryRuntime::Handle* handle) { FunctionLibraryRuntime::InstantiateOptions opts; const auto* config = (ctx->function_library()) ? ctx->function_library()->config_proto() @@ -260,7 +259,7 @@ void PartitionedCallOp::RunFunction(FunctionLibraryRuntime::Handle handle, tsl::profiler::TraceMe trace_me("PartitionedCallOp"); lib->Run(run_opts, handle, inputs, rets, [rets, done = std::move(done), ctx, func_name, - step_container](const Status& status) { + step_container](const absl::Status& status) { if (!status.ok()) { const string function_and_msg = strings::StrCat(errors::FormatFunctionForError(func_name), diff --git a/tensorflow/core/kernels/partitioned_function_ops.h b/tensorflow/core/kernels/partitioned_function_ops.h index 6ae267b90a9414..2b2ec8ea959f7c 100644 --- a/tensorflow/core/kernels/partitioned_function_ops.h +++ b/tensorflow/core/kernels/partitioned_function_ops.h @@ -41,13 +41,13 @@ class PartitionedCallOp : public AsyncOpKernel { void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override; protected: - Status FillOutputDevices(const FunctionLibraryRuntime& lib, - const Device& cpu_device, AttrSlice attrs, - FunctionLibraryRuntime::InstantiateOptions* opts); + absl::Status FillOutputDevices( + const FunctionLibraryRuntime& lib, const Device& cpu_device, + AttrSlice attrs, FunctionLibraryRuntime::InstantiateOptions* opts); - Status Instantiate(FunctionLibraryRuntime* lib, OpKernelContext* ctx, - std::vector* inputs, - FunctionLibraryRuntime::Handle* handle); + absl::Status Instantiate(FunctionLibraryRuntime* lib, OpKernelContext* ctx, + std::vector* inputs, + FunctionLibraryRuntime::Handle* handle); void RunFunction(FunctionLibraryRuntime::Handle handle, const std::vector& inputs, diff --git a/tensorflow/core/kernels/poisson-loss.h b/tensorflow/core/kernels/poisson-loss.h index 7cdee546dd2b17..d946b066ee9dfd 100644 --- a/tensorflow/core/kernels/poisson-loss.h +++ b/tensorflow/core/kernels/poisson-loss.h @@ -78,7 +78,7 @@ class PoissonLossUpdater : public DualLossUpdater { // Setting this at 1 for now, it only impacts the adaptive sampling. double SmoothnessConstant() const final { return 1; } - Status ConvertLabel(float* const example_label) const final { + absl::Status ConvertLabel(float* const example_label) const final { if (*example_label < 0.0) { return errors::InvalidArgument( "Only non-negative labels can be used with the Poisson log loss. " diff --git a/tensorflow/core/kernels/pooling_ops_3d.cc b/tensorflow/core/kernels/pooling_ops_3d.cc index aa22224f89bd68..28e24e79fe0bcf 100644 --- a/tensorflow/core/kernels/pooling_ops_3d.cc +++ b/tensorflow/core/kernels/pooling_ops_3d.cc @@ -88,7 +88,7 @@ Pool3dParameters::Pool3dParameters(OpKernelContext* context, col_stride, padding, &out_width, &pad_cols)); } -Status Pool3dParameters::forward_output_shape(TensorShape* shape) { +absl::Status Pool3dParameters::forward_output_shape(TensorShape* shape) { return ShapeFromFormatWithStatus(data_format, tensor_in_batch, {{out_plane, out_height, out_width}}, depth, shape); diff --git a/tensorflow/core/kernels/pooling_ops_3d.h b/tensorflow/core/kernels/pooling_ops_3d.h index eab41a60d694fc..c0a589ff95092a 100644 --- a/tensorflow/core/kernels/pooling_ops_3d.h +++ b/tensorflow/core/kernels/pooling_ops_3d.h @@ -45,7 +45,7 @@ struct Pool3dParameters { const TensorShape& tensor_in_shape); // Returns the shape of the output for "forward" pooling operations. - Status forward_output_shape(TensorShape* shape); + absl::Status forward_output_shape(TensorShape* shape); int depth; diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc index bade5e15b2ea8e..4ccca647c154aa 100644 --- a/tensorflow/core/kernels/pooling_ops_common.cc +++ b/tensorflow/core/kernels/pooling_ops_common.cc @@ -86,9 +86,9 @@ struct PadInputWithNegativeInf { } // namespace -Status CheckPaddingSize(int64_t window_rows, int64_t window_cols, - int64_t pad_top, int64_t pad_bottom, int64_t pad_left, - int64_t pad_right) { +absl::Status CheckPaddingSize(int64_t window_rows, int64_t window_cols, + int64_t pad_top, int64_t pad_bottom, + int64_t pad_left, int64_t pad_right) { if (!FastBoundsCheck(pad_top, window_rows)) { return errors::InvalidArgument("Top padding ", pad_top, " needs to be smaller than the " @@ -210,7 +210,7 @@ PoolParameters::PoolParameters(OpKernelContext* context, } } -Status PoolParameters::forward_output_shape(TensorShape* shape) { +absl::Status PoolParameters::forward_output_shape(TensorShape* shape) { if (depth_window == 1) { // Spatial pooling return ShapeFromFormatWithStatus(data_format, tensor_in_batch, out_height, diff --git a/tensorflow/core/kernels/pooling_ops_common.h b/tensorflow/core/kernels/pooling_ops_common.h index 8513ea1644d199..bb5dda562af672 100644 --- a/tensorflow/core/kernels/pooling_ops_common.h +++ b/tensorflow/core/kernels/pooling_ops_common.h @@ -53,7 +53,7 @@ struct PoolParameters { TensorFormat data_format, const TensorShape& tensor_in_shape); // Returns the shape of the output for "forward" pooling operations. - Status forward_output_shape(TensorShape* shape); + absl::Status forward_output_shape(TensorShape* shape); int depth; @@ -107,11 +107,13 @@ class MaxPoolingOp : public OpKernel { OP_REQUIRES(context, ksize_.size() == 4, errors::InvalidArgument("Sliding window ksize field must " "specify 4 dimensions")); - for (int i = 0; i < ksize_.size(); ++i) { - OP_REQUIRES(context, ksize_[i] > 0, - errors::InvalidArgument("Sliding window ksize for dimension ", - i, " was zero.")); - } + OP_REQUIRES( + context, + ksize_[0] > 0 && ksize_[1] > 0 && ksize_[2] > 0 && ksize_[3] > 0, + errors::InvalidArgument( + absl::StrCat("Sliding window ksize must be positive. The " + "specified or inferred ksize is: ", + absl::StrJoin(ksize_, ",")))); OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); OP_REQUIRES(context, stride_.size() == 4, errors::InvalidArgument("Sliding window stride field must " diff --git a/tensorflow/core/kernels/priority_queue.cc b/tensorflow/core/kernels/priority_queue.cc index dde3ad973610b4..56ea77fdbcf2ca 100644 --- a/tensorflow/core/kernels/priority_queue.cc +++ b/tensorflow/core/kernels/priority_queue.cc @@ -40,8 +40,8 @@ PriorityQueue::PriorityQueue(int32_t capacity, const string& name) : TypedQueue(capacity, component_dtypes, component_shapes, name) {} -Status PriorityQueue::Initialize() { - Status s = TypedQueue::Initialize(); +absl::Status PriorityQueue::Initialize() { + absl::Status s = TypedQueue::Initialize(); if (!s.ok()) return s; mutex_lock lock(mu_); @@ -115,7 +115,7 @@ void PriorityQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx, } /* static */ -Status PriorityQueue::GetElementComponentFromBatch( +absl::Status PriorityQueue::GetElementComponentFromBatch( const PriorityQueue::Tuple& tuple, int index, int component, OpKernelContext* ctx, Tensor* out_element) { TensorShape element_shape(tuple[component].shape()); @@ -273,8 +273,8 @@ void PriorityQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, // an optimized case where the queue 'knows' what attributes to // use, and plumbs them through here. Tensor element; - Status status = ctx->allocate_temp(component_dtypes_[i], - ManyOutShape(i, 0), &element); + absl::Status status = ctx->allocate_temp(component_dtypes_[i], + ManyOutShape(i, 0), &element); if (!status.ok()) { ctx->SetStatus(status); callback(Tuple()); @@ -384,7 +384,7 @@ void PriorityQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, } } -Status PriorityQueue::MatchesNodeDef(const NodeDef& node_def) { +absl::Status PriorityQueue::MatchesNodeDef(const NodeDef& node_def) { if (!MatchesNodeDefOp(node_def, "PriorityQueue").ok() && !MatchesNodeDefOp(node_def, "PriorityQueueV2").ok()) { return errors::InvalidArgument("Expected PriorityQueue, found ", @@ -396,7 +396,7 @@ Status PriorityQueue::MatchesNodeDef(const NodeDef& node_def) { return absl::OkStatus(); } -Status PriorityQueue::MatchesPriorityNodeDefTypes( +absl::Status PriorityQueue::MatchesPriorityNodeDefTypes( const NodeDef& node_def) const { DataTypeVector requested_dtypes; TF_RETURN_IF_ERROR( @@ -412,7 +412,7 @@ Status PriorityQueue::MatchesPriorityNodeDefTypes( return absl::OkStatus(); } -Status PriorityQueue::MatchesPriorityNodeDefShapes( +absl::Status PriorityQueue::MatchesPriorityNodeDefShapes( const NodeDef& node_def) const { std::vector requested_shapes; TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes)); diff --git a/tensorflow/core/kernels/priority_queue.h b/tensorflow/core/kernels/priority_queue.h index c571846a37d118..f7ca800a66bf7a 100644 --- a/tensorflow/core/kernels/priority_queue.h +++ b/tensorflow/core/kernels/priority_queue.h @@ -52,7 +52,8 @@ class PriorityQueue const std::vector& component_shapes, const string& name); - Status Initialize() override; // Must be called before any other method. + absl::Status Initialize() + override; // Must be called before any other method. // Implementations of QueueInterface methods -------------------------------- @@ -64,9 +65,9 @@ class PriorityQueue void TryDequeueMany(int num_elements, OpKernelContext* ctx, bool allow_small_batch, CallbackWithTuple callback) override; - Status MatchesNodeDef(const NodeDef& node_def) override; - Status MatchesPriorityNodeDefTypes(const NodeDef& node_def) const; - Status MatchesPriorityNodeDefShapes(const NodeDef& node_def) const; + absl::Status MatchesNodeDef(const NodeDef& node_def) override; + absl::Status MatchesPriorityNodeDefTypes(const NodeDef& node_def) const; + absl::Status MatchesPriorityNodeDefShapes(const NodeDef& node_def) const; int32 size() const override { mutex_lock lock(mu_); @@ -80,10 +81,10 @@ class PriorityQueue void DequeueLocked(OpKernelContext* ctx, Tuple* tuple) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); - static Status GetElementComponentFromBatch(const Tuple& tuple, int index, - int component, - OpKernelContext* ctx, - Tensor* out_element); + static absl::Status GetElementComponentFromBatch(const Tuple& tuple, + int index, int component, + OpKernelContext* ctx, + Tensor* out_element); PriorityQueue(const PriorityQueue&) = delete; void operator=(const PriorityQueue&) = delete; diff --git a/tensorflow/core/kernels/priority_queue_op.cc b/tensorflow/core/kernels/priority_queue_op.cc index 2b304694870000..ea91bde1cf214c 100644 --- a/tensorflow/core/kernels/priority_queue_op.cc +++ b/tensorflow/core/kernels/priority_queue_op.cc @@ -50,7 +50,7 @@ class PriorityQueueOp : public TypedQueueOp { } private: - Status CreateResource(QueueInterface** ret) override + absl::Status CreateResource(QueueInterface** ret) override TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { PriorityQueue* queue = new PriorityQueue(capacity_, component_types_, component_shapes_, cinfo_.name()); diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc b/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc index ad6f5353173ac5..967e107e553634 100644 --- a/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc +++ b/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc @@ -735,7 +735,7 @@ TEST_F(QuantizeAndDequantizeTest, Invalid_range_given) { AddInputFromArray(TensorShape({}), {1.0}); // Min AddInputFromArray(TensorShape({}), {0.0}); // Max - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains(s.ToString(), "Invalid range: input_min 1 > input_max 0")) << s; @@ -757,7 +757,7 @@ TEST_F(QuantizeAndDequantizeTest, Invalid_range_given_V3) { AddInputFromArray(TensorShape({}), {0.0}); // Max AddInputFromArray(TensorShape({}), {8}); // num_bits - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains(s.ToString(), "Invalid range: input_min 1 > input_max 0")) << s; diff --git a/tensorflow/core/kernels/quantized_concat_op.cc b/tensorflow/core/kernels/quantized_concat_op.cc index 997bc0293ba7a5..02a4a84b203ec8 100644 --- a/tensorflow/core/kernels/quantized_concat_op.cc +++ b/tensorflow/core/kernels/quantized_concat_op.cc @@ -78,7 +78,7 @@ class QuantizedConcatOp : public OpKernel { explicit QuantizedConcatOp(OpKernelConstruction* c) : OpKernel(c) {} - Status CalculateInputAndOutputRange( + absl::Status CalculateInputAndOutputRange( const OpInputList& input_mins, const OpInputList& input_maxes, const size_t N, std::vector>* input_mins_and_maxes, @@ -130,12 +130,13 @@ class QuantizedConcatOp : public OpKernel { return inputs_flat_dim0; } - Status CalculateConcatDims(const size_t N, const TensorShape& input_shape, - int input_dims, const OpInputList& values, - const int32_t concat_dim, - const int64_t inputs_flat_dim0, - ConstMatrixVector* inputs_flat, - int* output_concat_dim) { + absl::Status CalculateConcatDims(const size_t N, + const TensorShape& input_shape, + int input_dims, const OpInputList& values, + const int32_t concat_dim, + const int64_t inputs_flat_dim0, + ConstMatrixVector* inputs_flat, + int* output_concat_dim) { // Note that we reduce the concat of n-dimensional tensors into a two // dimensional concat. Assuming the dimensions of any input/output // tensor are {x0, x1,...,xn-1, y0, y1,...,ym-1}, where the concat is along diff --git a/tensorflow/core/kernels/quantized_conv_ops.cc b/tensorflow/core/kernels/quantized_conv_ops.cc index 956e5b9f6c5d48..3f3e2743d674f4 100644 --- a/tensorflow/core/kernels/quantized_conv_ops.cc +++ b/tensorflow/core/kernels/quantized_conv_ops.cc @@ -275,7 +275,7 @@ class Im2ColConvFunctor { // use TensorFlow's resource management to ensure that the memory will be // released when the session is over. Im2ColBufferResource* im2col_buffer_resource; - std::function**)> + std::function**)> creator = [](Im2ColBufferResource** resource) { #ifdef _MSC_VER // MSVC complains about the capture of chunk_value_count which oddly diff --git a/tensorflow/core/kernels/quantized_instance_norm_test.cc b/tensorflow/core/kernels/quantized_instance_norm_test.cc index ab729e3d6cd205..9569de8d03bd97 100644 --- a/tensorflow/core/kernels/quantized_instance_norm_test.cc +++ b/tensorflow/core/kernels/quantized_instance_norm_test.cc @@ -91,7 +91,7 @@ void Expect(const Tensor& input, float x_min, float x_max, root, input_ph, x_min, x_max, QuantizedInstanceNorm::Attrs().VarianceEpsilon(variance_eps)); - Status s = root.status(); + absl::Status s = root.status(); EXPECT_TRUE(s.ok()); ClientSession session(root); diff --git a/tensorflow/core/kernels/queue_base.cc b/tensorflow/core/kernels/queue_base.cc index f9f86e599a8b0f..b6b9e2b980c674 100644 --- a/tensorflow/core/kernels/queue_base.cc +++ b/tensorflow/core/kernels/queue_base.cc @@ -28,8 +28,8 @@ namespace tensorflow { namespace { template -Status HandleSliceToElement(const Tensor& parent, Tensor* element, - int64_t index) { +absl::Status HandleSliceToElement(const Tensor& parent, Tensor* element, + int64_t index) { typedef typename EnumToDataType
::Type T; DCHECK_NE(parent.dim_size(0), 0); DCHECK_GE(index, 0); @@ -60,7 +60,7 @@ QueueBase::QueueBase(int32_t capacity, const DataTypeVector& component_dtypes, QueueBase::~QueueBase() {} -Status QueueBase::ValidateTupleCommon(const Tuple& tuple) const { +absl::Status QueueBase::ValidateTupleCommon(const Tuple& tuple) const { if (tuple.size() != static_cast(num_components())) { return errors::InvalidArgument( "Wrong number of components in tuple. Expected ", num_components(), @@ -89,8 +89,8 @@ string QueueBase::ShapeListString(const absl::Span& shapes) { return result; } -Status QueueBase::MatchesNodeDefOp(const NodeDef& node_def, - const string& op) const { +absl::Status QueueBase::MatchesNodeDefOp(const NodeDef& node_def, + const string& op) const { if (node_def.op() != op) { return errors::InvalidArgument("Shared queue '", name_, "' has type '", op, "' that does not match type of Node '", @@ -99,8 +99,8 @@ Status QueueBase::MatchesNodeDefOp(const NodeDef& node_def, return absl::OkStatus(); } -Status QueueBase::MatchesNodeDefCapacity(const NodeDef& node_def, - int32_t capacity) const { +absl::Status QueueBase::MatchesNodeDefCapacity(const NodeDef& node_def, + int32_t capacity) const { int32_t requested_capacity = -1; TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "capacity", &requested_capacity)); if (requested_capacity < 0) requested_capacity = kUnbounded; @@ -112,7 +112,7 @@ Status QueueBase::MatchesNodeDefCapacity(const NodeDef& node_def, return absl::OkStatus(); } -Status QueueBase::MatchesNodeDefTypes(const NodeDef& node_def) const { +absl::Status QueueBase::MatchesNodeDefTypes(const NodeDef& node_def) const { DataTypeVector requested_dtypes; TF_RETURN_IF_ERROR( GetNodeAttr(node_def, "component_types", &requested_dtypes)); @@ -126,7 +126,7 @@ Status QueueBase::MatchesNodeDefTypes(const NodeDef& node_def) const { return absl::OkStatus(); } -Status QueueBase::MatchesNodeDefShapes(const NodeDef& node_def) const { +absl::Status QueueBase::MatchesNodeDefShapes(const NodeDef& node_def) const { std::vector requested_shapes; TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes)); if (requested_shapes != component_shapes_) { @@ -141,7 +141,7 @@ Status QueueBase::MatchesNodeDefShapes(const NodeDef& node_def) const { // TODO(mrry): If these checks become a bottleneck, find a way to // reduce the number of times that they are called. -Status QueueBase::ValidateTuple(const Tuple& tuple) { +absl::Status QueueBase::ValidateTuple(const Tuple& tuple) { TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple)); if (specified_shapes()) { for (size_t i = 0; i < tuple.size(); ++i) { @@ -158,7 +158,7 @@ Status QueueBase::ValidateTuple(const Tuple& tuple) { // TODO(mrry): If these checks become a bottleneck, find a way to // reduce the number of times that they are called. -Status QueueBase::ValidateManyTuple(const Tuple& tuple) { +absl::Status QueueBase::ValidateManyTuple(const Tuple& tuple) { TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple)); const int64_t batch_size = tuple[0].dim_size(0); if (specified_shapes()) { @@ -334,14 +334,14 @@ void QueueBase::FlushUnlocked() { } } -Status QueueBase::CopySliceToElement(const Tensor& parent, Tensor* element, - int64_t index) { +absl::Status QueueBase::CopySliceToElement(const Tensor& parent, + Tensor* element, int64_t index) { return batch_util::CopySliceToElement(parent, element, index); } /* static */ -Status QueueBase::CopyElementToSlice(const Tensor& element, Tensor* parent, - int64_t index) { +absl::Status QueueBase::CopyElementToSlice(const Tensor& element, + Tensor* parent, int64_t index) { return batch_util::CopyElementToSlice(element, parent, index); } diff --git a/tensorflow/core/kernels/queue_base.h b/tensorflow/core/kernels/queue_base.h index e47ddb4dd28894..d39ab45498b843 100644 --- a/tensorflow/core/kernels/queue_base.h +++ b/tensorflow/core/kernels/queue_base.h @@ -53,8 +53,8 @@ class QueueBase : public QueueInterface { return component_dtypes_; } - Status ValidateTuple(const Tuple& tuple) override; - Status ValidateManyTuple(const Tuple& tuple) override; + absl::Status ValidateTuple(const Tuple& tuple) override; + absl::Status ValidateManyTuple(const Tuple& tuple) override; void Close(OpKernelContext* ctx, bool cancel_pending_enqueues, DoneCallback callback) override; @@ -72,8 +72,8 @@ class QueueBase : public QueueInterface { } // Copies the index^th slice (in the first dimension) of parent into element. - static Status CopySliceToElement(const Tensor& parent, Tensor* element, - int64_t index); + static absl::Status CopySliceToElement(const Tensor& parent, Tensor* element, + int64_t index); // Copies element into the index^th slice (in the first dimension) of parent. // NOTE(mrry): This method is deprecated. Use @@ -82,8 +82,8 @@ class QueueBase : public QueueInterface { ABSL_DEPRECATED( "Use `tensorflow::batch_util::CopySliceToElement()` defined in " "\"./batch_util.h\" instead.") - static Status CopyElementToSlice(const Tensor& element, Tensor* parent, - int64_t index); + static absl::Status CopyElementToSlice(const Tensor& element, Tensor* parent, + int64_t index); protected: enum Action { kEnqueue, kDequeue }; @@ -110,7 +110,7 @@ class QueueBase : public QueueInterface { bool specified_shapes() const { return component_shapes_.size() > 0; } // Code common to Validate*Tuple(). - Status ValidateTupleCommon(const Tuple& tuple) const; + absl::Status ValidateTupleCommon(const Tuple& tuple) const; TensorShape ManyOutShape(int i, int64_t batch_size) { TensorShape shape({batch_size}); @@ -136,11 +136,12 @@ class QueueBase : public QueueInterface { // Helpers for implementing MatchesNodeDef(). static string ShapeListString(const absl::Span& shapes); - Status MatchesNodeDefOp(const NodeDef& node_def, const string& op) const; - Status MatchesNodeDefCapacity(const NodeDef& node_def, - int32_t capacity) const; - Status MatchesNodeDefTypes(const NodeDef& node_def) const; - Status MatchesNodeDefShapes(const NodeDef& node_def) const; + absl::Status MatchesNodeDefOp(const NodeDef& node_def, + const string& op) const; + absl::Status MatchesNodeDefCapacity(const NodeDef& node_def, + int32_t capacity) const; + absl::Status MatchesNodeDefTypes(const NodeDef& node_def) const; + absl::Status MatchesNodeDefShapes(const NodeDef& node_def) const; protected: const int32 capacity_; diff --git a/tensorflow/core/kernels/queue_op.cc b/tensorflow/core/kernels/queue_op.cc index 493e3367820730..e16c6034de4596 100644 --- a/tensorflow/core/kernels/queue_op.cc +++ b/tensorflow/core/kernels/queue_op.cc @@ -44,11 +44,10 @@ void QueueOp::Compute(OpKernelContext* context) { } } -Status QueueOp::VerifyResource(QueueInterface* queue) { +absl::Status QueueOp::VerifyResource(QueueInterface* queue) { return queue->MatchesNodeDef(def()); } - QueueOpKernel::QueueOpKernel(OpKernelConstruction* context) : AsyncOpKernel(context) {} diff --git a/tensorflow/core/kernels/queue_op.h b/tensorflow/core/kernels/queue_op.h index 5583b79c786e95..57a771d91fcb50 100644 --- a/tensorflow/core/kernels/queue_op.h +++ b/tensorflow/core/kernels/queue_op.h @@ -43,7 +43,7 @@ class QueueOp : public ResourceOpKernel { DataTypeVector component_types_; private: - Status VerifyResource(QueueInterface* queue) override; + absl::Status VerifyResource(QueueInterface* queue) override; }; class TypedQueueOp : public QueueOp { @@ -52,7 +52,7 @@ class TypedQueueOp : public QueueOp { protected: template - Status CreateTypedQueue(TypedQueue* queue, QueueInterface** ret) { + absl::Status CreateTypedQueue(TypedQueue* queue, QueueInterface** ret) { if (queue == nullptr) { return errors::ResourceExhausted("Failed to allocate queue."); } diff --git a/tensorflow/core/kernels/ragged_cross_op.cc b/tensorflow/core/kernels/ragged_cross_op.cc index ba8d01248a9ab9..bc4442e7534df7 100644 --- a/tensorflow/core/kernels/ragged_cross_op.cc +++ b/tensorflow/core/kernels/ragged_cross_op.cc @@ -378,12 +378,12 @@ class RaggedCrossOp : public OpKernel { private: // Validates input tensors. - Status ValidateInput(const OpInputList& ragged_values_list, - const OpInputList& ragged_splits_list, - const OpInputList& sparse_indices_list, - const OpInputList& sparse_values_list, - const OpInputList& sparse_shape_list, - const OpInputList& dense_list) { + absl::Status ValidateInput(const OpInputList& ragged_values_list, + const OpInputList& ragged_splits_list, + const OpInputList& sparse_indices_list, + const OpInputList& sparse_values_list, + const OpInputList& sparse_shape_list, + const OpInputList& dense_list) { const auto num_ragged = ragged_values_list.size(); const auto num_sparse = sparse_indices_list.size(); @@ -459,12 +459,13 @@ class RaggedCrossOp : public OpKernel { } // Build a feature reader for each input tensor, and store them in `features`. - Status BuildFeatureReaders(const OpInputList& ragged_values_list, - const OpInputList& ragged_splits_list, - const OpInputList& sparse_indices_list, - const OpInputList& sparse_values_list, - const OpInputList& dense_list, int64_t batch_size, - FeatureReaders* features) { + absl::Status BuildFeatureReaders(const OpInputList& ragged_values_list, + const OpInputList& ragged_splits_list, + const OpInputList& sparse_indices_list, + const OpInputList& sparse_values_list, + const OpInputList& dense_list, + int64_t batch_size, + FeatureReaders* features) { features->reserve(input_order_.size()); int next_ragged = 0; @@ -522,9 +523,9 @@ class RaggedCrossOp : public OpKernel { } // Builds a RaggedReatureReader - static Status BuildRaggedFeatureReader(const Tensor& values, - const Tensor& splits, - FeatureReaders* features) { + static absl::Status BuildRaggedFeatureReader(const Tensor& values, + const Tensor& splits, + FeatureReaders* features) { if (values.dtype() != DT_INT64 && values.dtype() != DT_STRING) { return errors::InvalidArgument("Unexpected dtype for input ", (features->size() + 1), ": ", @@ -556,8 +557,8 @@ class RaggedCrossOp : public OpKernel { } // Builds a DenseFaggedReatureReader. - static Status BuildDenseFeatureReader(const Tensor& values, - FeatureReaders* features) { + static absl::Status BuildDenseFeatureReader(const Tensor& values, + FeatureReaders* features) { if (values.dtype() == DT_INT64) { features->emplace_back(new DenseFeatureReader(values)); } else if (values.dtype() == DT_STRING) { @@ -571,10 +572,10 @@ class RaggedCrossOp : public OpKernel { } // Builds a SparseFaggedReatureReader. - static Status BuildSparseFeatureReader(const Tensor& indices, - const Tensor& values, - int64_t batch_size, - FeatureReaders* features) { + static absl::Status BuildSparseFeatureReader(const Tensor& indices, + const Tensor& values, + int64_t batch_size, + FeatureReaders* features) { if (values.dtype() == DT_INT64) { features->emplace_back( new SparseFeatureReader(indices, values, batch_size)); @@ -590,9 +591,10 @@ class RaggedCrossOp : public OpKernel { } // Allocates output tensors with proper size, and populates row_splits_out. - Status BuildOutputTensors(const FeatureReaders& features, int64_t batch_size, - OpKernelContext* context, Tensor** values_out, - Tensor** row_splits_out) { + absl::Status BuildOutputTensors(const FeatureReaders& features, + int64_t batch_size, OpKernelContext* context, + Tensor** values_out, + Tensor** row_splits_out) { // Allocate and populate the row_splits output tensor. TF_RETURN_IF_ERROR(context->allocate_output( 1, TensorShape({batch_size + 1}), row_splits_out)); diff --git a/tensorflow/core/kernels/ragged_gather_op.cc b/tensorflow/core/kernels/ragged_gather_op.cc index 0252be8de06803..d902e8424d7486 100644 --- a/tensorflow/core/kernels/ragged_gather_op.cc +++ b/tensorflow/core/kernels/ragged_gather_op.cc @@ -100,8 +100,8 @@ class RaggedGatherOpBase : public OpKernel { using ConstFlatType = typename TTypes::ConstFlat; // Check if any indices are out-of-bounds. - ::tensorflow::Status ValidateIndices(const Tensor& indices_in, - SPLITS_TYPE num_params) { + absl::Status ValidateIndices(const Tensor& indices_in, + SPLITS_TYPE num_params) { const auto& indices = indices_in.flat(); for (SPLITS_TYPE i = 0; i < indices.size(); ++i) { SPLITS_TYPE index = indices(i); @@ -118,7 +118,7 @@ class RaggedGatherOpBase : public OpKernel { // Also find the slices of values that need to be copied, and store them // in `value_slices`. The total number of values that will be copied (which // we need for allocating the output values tensor) is stored in `num_values`. - ::tensorflow::Status MakeSplits( + absl::Status MakeSplits( const Tensor& indices_in, const OpInputList& params_nested_splits_in, SPLITS_TYPE num_params_dense_values, std::vector>* out_splits, @@ -191,7 +191,7 @@ class RaggedGatherOpBase : public OpKernel { return absl::OkStatus(); } - ::tensorflow::Status ValidateSplits( + absl::Status ValidateSplits( const std::vector& params_nested_splits, SPLITS_TYPE num_params_dense_values) { // Validate @@ -219,7 +219,7 @@ class RaggedGatherOpBase : public OpKernel { return absl::OkStatus(); } - ::tensorflow::Status WriteSplits( + absl::Status WriteSplits( const std::vector>& out_splits, OpKernelContext* context) { OpOutputList splits_out; @@ -237,7 +237,7 @@ class RaggedGatherOpBase : public OpKernel { return absl::OkStatus(); } - ::tensorflow::Status WriteValues( + absl::Status WriteValues( const Tensor& params_dense_values_in, const std::vector>& value_slices, int values_index, SPLITS_TYPE num_values, diff --git a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc index 6dc50c9435c66f..65469260e3fa09 100644 --- a/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc +++ b/tensorflow/core/kernels/ragged_tensor_from_variant_op.cc @@ -30,7 +30,7 @@ namespace { /* Extracts the components of the variant-encoded tensor `encoded_variant` * into a flat vector of `RaggedTensorVariant` objects. */ -Status RaggedComponentsFromVariant( +absl::Status RaggedComponentsFromVariant( const Tensor& encoded_variant, int input_ragged_rank, int output_ragged_rank, DataType value_dtype, DataType split_dtype, std::vector* decoded_ragged) { @@ -92,7 +92,7 @@ Status RaggedComponentsFromVariant( * This should only be used when input_ragged_rank=0 and output_ragged_rank=0. */ template -Status StackNonRaggedTensors( +absl::Status StackNonRaggedTensors( const std::vector& ragged_components, RaggedTensorVariant* output_ragged) { if (ragged_components.empty()) { @@ -125,7 +125,7 @@ Status StackNonRaggedTensors( } template -Status NestedStackRaggedTensors( +absl::Status NestedStackRaggedTensors( const std::vector& ragged_components, const std::vector& nested_dim_sizes, const int input_ragged_rank, const int output_ragged_rank, RaggedTensorVariant* output_ragged) { diff --git a/tensorflow/core/kernels/ragged_tensor_to_sparse_kernel.cc b/tensorflow/core/kernels/ragged_tensor_to_sparse_kernel.cc index 41667581eab3d1..7f92a50133ce99 100644 --- a/tensorflow/core/kernels/ragged_tensor_to_sparse_kernel.cc +++ b/tensorflow/core/kernels/ragged_tensor_to_sparse_kernel.cc @@ -154,7 +154,7 @@ class RaggedTensorToSparseOp : public OpKernel { private: // Validate `rt_nested_splits` to ensure we don't get any segfaults. - static ::tensorflow::Status ValidateInputs( + static absl::Status ValidateInputs( std::vector rt_nested_splits, const Tensor& rt_dense_values_in) { for (int i = 0; i < rt_nested_splits.size(); ++i) { diff --git a/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc b/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc index a3dbeb9aac9f84..516a0cddcb6acc 100644 --- a/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc +++ b/tensorflow/core/kernels/ragged_tensor_to_tensor_op.cc @@ -84,7 +84,8 @@ class RaggedTensorToTensorBaseOp : public OpKernel { } } - Status GetMaxWidth(OpKernelContext* c, int dimension, INDEX_TYPE* result) { + absl::Status GetMaxWidth(OpKernelContext* c, int dimension, + INDEX_TYPE* result) { const RowPartitionTensor row_partition_tensor = GetRowPartitionTensor(c, dimension - 1); switch (GetRowPartitionTypeByDimension(dimension - 1)) { @@ -137,8 +138,8 @@ class RaggedTensorToTensorBaseOp : public OpKernel { return std::max(index_length - first_equal_index, max_width); } - Status CalculateOutputSize(INDEX_TYPE first_dim, OpKernelContext* c, - vector* result) { + absl::Status CalculateOutputSize(INDEX_TYPE first_dim, OpKernelContext* c, + vector* result) { TensorShapeProto value_shape_proto; c->input(kValueInputIndex).shape().AsProto(&value_shape_proto); @@ -207,7 +208,7 @@ class RaggedTensorToTensorBaseOp : public OpKernel { DCHECK_EQ(result->size(), first_dimension); } - Status CalculateOutputIndexRowSplit( + absl::Status CalculateOutputIndexRowSplit( const RowPartitionTensor& row_split, const vector& parent_output_index, INDEX_TYPE output_index_multiplier, INDEX_TYPE output_size, @@ -260,7 +261,7 @@ class RaggedTensorToTensorBaseOp : public OpKernel { // result[6] = -1 because parent_output_index[value_rowids[6]] == -1 // result[7] = -1 because parent_output_index[value_rowids[6]] == -1 // result[8] = parent_output_index[value_rowids[7]] - Status CalculateOutputIndexValueRowID( + absl::Status CalculateOutputIndexValueRowID( const RowPartitionTensor& value_rowids, const vector& parent_output_index, INDEX_TYPE output_index_multiplier, INDEX_TYPE output_size, @@ -315,11 +316,11 @@ class RaggedTensorToTensorBaseOp : public OpKernel { return absl::OkStatus(); } - Status CalculateOutputIndex(OpKernelContext* context, int dimension, - const vector& parent_output_index, - INDEX_TYPE output_index_multiplier, - INDEX_TYPE output_size, - vector* result) { + absl::Status CalculateOutputIndex( + OpKernelContext* context, int dimension, + const vector& parent_output_index, + INDEX_TYPE output_index_multiplier, INDEX_TYPE output_size, + vector* result) { const RowPartitionTensor row_partition_tensor = GetRowPartitionTensor(context, dimension); auto partition_type = GetRowPartitionTypeByDimension(dimension); @@ -345,7 +346,8 @@ class RaggedTensorToTensorBaseOp : public OpKernel { } } - Status GetFirstDimensionSize(OpKernelContext* context, INDEX_TYPE* result) { + absl::Status GetFirstDimensionSize(OpKernelContext* context, + INDEX_TYPE* result) { const Tensor first_partition_tensor = context->input(kFirstPartitionInputIndex); if (row_partition_types_.empty()) { diff --git a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc index 04237d8ecb7f99..b4d7fc8395b614 100644 --- a/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc +++ b/tensorflow/core/kernels/ragged_tensor_to_variant_op.cc @@ -36,7 +36,7 @@ namespace tensorflow { namespace { template -Status UnbatchDenseZerothDim( +absl::Status UnbatchDenseZerothDim( const RaggedTensorVariant& batched_ragged, std::vector* ragged_components) { Tensor batched_values = batched_ragged.values(); @@ -65,7 +65,7 @@ Status UnbatchDenseZerothDim( } template -Status UnbatchRaggedZerothDim( +absl::Status UnbatchRaggedZerothDim( const RaggedTensorVariant& batched_ragged, std::vector* ragged_components) { // Set up the component Ragged Tensors. diff --git a/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc b/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc index ac04580f4dec11..8532a20e129cb5 100644 --- a/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc +++ b/tensorflow/core/kernels/ragged_tensor_to_variant_op_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/strings/match.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/shape_inference.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/status_matchers.h" #include "tensorflow/core/platform/test.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/kernels/ragged_tensor_variant.cc b/tensorflow/core/kernels/ragged_tensor_variant.cc index 57b5ee36f2622b..b6b70a283c7c48 100644 --- a/tensorflow/core/kernels/ragged_tensor_variant.cc +++ b/tensorflow/core/kernels/ragged_tensor_variant.cc @@ -52,7 +52,7 @@ bool RaggedTensorVariant::Decode(const VariantTensorData& data) { namespace { -Status RaggedTensorVariantDeviceCopy( +absl::Status RaggedTensorVariantDeviceCopy( const RaggedTensorVariant& from, RaggedTensorVariant* to, const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) { // RaggedTensorVariant is only used by kernels that run on the CPU, so we diff --git a/tensorflow/core/kernels/ragged_tensor_variant.h b/tensorflow/core/kernels/ragged_tensor_variant.h index db35b8bcf0d35b..1d2066b0dcf457 100644 --- a/tensorflow/core/kernels/ragged_tensor_variant.h +++ b/tensorflow/core/kernels/ragged_tensor_variant.h @@ -68,9 +68,9 @@ class RaggedTensorVariant { }; template -Status RaggedTensorVariantZerosLike(OpKernelContext* c, - const RaggedTensorVariant& x, - RaggedTensorVariant* y) { +absl::Status RaggedTensorVariantZerosLike(OpKernelContext* c, + const RaggedTensorVariant& x, + RaggedTensorVariant* y) { y->set_nested_splits(x.nested_splits()); TF_RETURN_IF_ERROR( ZerosLikeTensor(c, x.values(), y->mutable_values())); @@ -78,10 +78,10 @@ Status RaggedTensorVariantZerosLike(OpKernelContext* c, } template -Status RaggedTensorVariantBinaryAdd(OpKernelContext* c, - const RaggedTensorVariant& x, - const RaggedTensorVariant& y, - RaggedTensorVariant* out) { +absl::Status RaggedTensorVariantBinaryAdd(OpKernelContext* c, + const RaggedTensorVariant& x, + const RaggedTensorVariant& y, + RaggedTensorVariant* out) { if (x.values().dtype() != y.values().dtype()) { return errors::InvalidArgument( "Can't add RaggedTensorVariants of different dtypes. One is ", diff --git a/tensorflow/core/kernels/ragged_utils.h b/tensorflow/core/kernels/ragged_utils.h index f91f1da343993f..3ccd34a523e51d 100644 --- a/tensorflow/core/kernels/ragged_utils.h +++ b/tensorflow/core/kernels/ragged_utils.h @@ -26,9 +26,9 @@ namespace tensorflow { // Verifies that the splits are valid for ragged tensor template -Status RaggedTensorVerifySplits(const Tensor& ragged_splits, - bool check_last_element, - int64_t num_ragged_values) { +absl::Status RaggedTensorVerifySplits(const Tensor& ragged_splits, + bool check_last_element, + int64_t num_ragged_values) { auto flat_ragged_splits = ragged_splits.flat(); if (ragged_splits.dims() != 1) { diff --git a/tensorflow/core/kernels/random_index_shuffle_ops.cc b/tensorflow/core/kernels/random_index_shuffle_ops.cc index 0f71b067e4f526..f2fdd41c6832a2 100644 --- a/tensorflow/core/kernels/random_index_shuffle_ops.cc +++ b/tensorflow/core/kernels/random_index_shuffle_ops.cc @@ -45,8 +45,8 @@ std::array CastSeedFrom(const Tensor& seed_t, const int row) { static_cast(seed_vals(3 * row + 2))}; } -Status GetSeed(const Tensor& seed_t, const int row, - std::array* seed) { +absl::Status GetSeed(const Tensor& seed_t, const int row, + std::array* seed) { if (seed_t.dtype() == DT_INT32) { *seed = CastSeedFrom(seed_t, row); } else if (seed_t.dtype() == DT_UINT32) { diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc index 129d5fc4379909..7624b56b50b587 100644 --- a/tensorflow/core/kernels/random_op.cc +++ b/tensorflow/core/kernels/random_op.cc @@ -51,8 +51,9 @@ typedef Eigen::GpuDevice GPUDevice; namespace { -static Status AllocateOutputWithShape(OpKernelContext* ctx, const Tensor& shape, - int index, Tensor** output) { +static absl::Status AllocateOutputWithShape(OpKernelContext* ctx, + const Tensor& shape, int index, + Tensor** output) { TensorShape tensor_shape; TF_RETURN_IF_ERROR(tensor::MakeShape(shape, &tensor_shape)); return ctx->allocate_output(index, tensor_shape, output); diff --git a/tensorflow/core/kernels/random_shuffle_queue_op.cc b/tensorflow/core/kernels/random_shuffle_queue_op.cc index bed97066d14d03..856357489bdfab 100644 --- a/tensorflow/core/kernels/random_shuffle_queue_op.cc +++ b/tensorflow/core/kernels/random_shuffle_queue_op.cc @@ -47,7 +47,8 @@ class RandomShuffleQueue : public TypedQueue > { const std::vector& component_shapes, const string& name); - Status Initialize() override; // Must be called before any other method. + absl::Status Initialize() + override; // Must be called before any other method. // Implementations of QueueInterface methods -------------------------------- void TryEnqueue(const Tuple& tuple, OpKernelContext* ctx, @@ -58,7 +59,7 @@ class RandomShuffleQueue : public TypedQueue > { void TryDequeueMany(int num_elements, OpKernelContext* ctx, bool allow_small_batch, CallbackWithTuple callback) override; - Status MatchesNodeDef(const NodeDef& node_def) override; + absl::Status MatchesNodeDef(const NodeDef& node_def) override; int32 size() const override { mutex_lock lock(mu_); @@ -72,10 +73,10 @@ class RandomShuffleQueue : public TypedQueue > { void DequeueLocked(OpKernelContext* ctx, Tuple* tuple) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); - static Status GetElementComponentFromBatch(const Tuple& tuple, int64_t index, - int component, - OpKernelContext* ctx, - Tensor* out_tensor); + static absl::Status GetElementComponentFromBatch(const Tuple& tuple, + int64_t index, int component, + OpKernelContext* ctx, + Tensor* out_tensor); const int32 min_after_dequeue_; const int64_t original_seed_; @@ -106,7 +107,7 @@ RandomShuffleQueue::RandomShuffleQueue( parent_generator_ = random::PhiloxRandom(seed, seed2); } -Status RandomShuffleQueue::Initialize() { +absl::Status RandomShuffleQueue::Initialize() { TF_RETURN_IF_ERROR(TypedQueue::Initialize()); mutex_lock lock(mu_); @@ -165,11 +166,9 @@ void RandomShuffleQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx, } /* static */ -Status RandomShuffleQueue::GetElementComponentFromBatch(const Tuple& tuple, - int64_t index, - int component, - OpKernelContext* ctx, - Tensor* out_tensor) { +absl::Status RandomShuffleQueue::GetElementComponentFromBatch( + const Tuple& tuple, int64_t index, int component, OpKernelContext* ctx, + Tensor* out_tensor) { TensorShape element_shape(tuple[component].shape()); element_shape.RemoveDim(0); TF_RETURN_IF_ERROR( @@ -314,8 +313,8 @@ void RandomShuffleQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, // an optimized case where the queue 'knows' what attributes to // use, and plumbs them through here. Tensor element; - Status s = ctx->allocate_temp(component_dtypes_[i], ManyOutShape(i, 0), - &element); + absl::Status s = ctx->allocate_temp(component_dtypes_[i], + ManyOutShape(i, 0), &element); if (!s.ok()) { ctx->SetStatus(s); callback(Tuple()); @@ -351,7 +350,7 @@ void RandomShuffleQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, i >= 0; --i) { for (int j = 0; j < num_components(); ++j) { Tensor element; - Status s = GetElementComponentFromBatch( + absl::Status s = GetElementComponentFromBatch( attempt->tuple, i, j, attempt->context, &element); if (!s.ok()) { attempt->context->SetStatus( @@ -437,7 +436,7 @@ void RandomShuffleQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx, } } -Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) { +absl::Status RandomShuffleQueue::MatchesNodeDef(const NodeDef& node_def) { if (!MatchesNodeDefOp(node_def, "RandomShuffleQueue").ok() && !MatchesNodeDefOp(node_def, "RandomShuffleQueueV2").ok()) { return errors::InvalidArgument("Expected RandomShuffleQueue, found ", @@ -496,7 +495,7 @@ class RandomShuffleQueueOp : public TypedQueueOp { } private: - Status CreateResource(QueueInterface** ret) override + absl::Status CreateResource(QueueInterface** ret) override TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { RandomShuffleQueue* queue = new RandomShuffleQueue( capacity_, min_after_dequeue_, seed_, seed2_, component_types_, diff --git a/tensorflow/core/kernels/range_sampler.cc b/tensorflow/core/kernels/range_sampler.cc index 449e0ccb879253..971b849ce71a59 100644 --- a/tensorflow/core/kernels/range_sampler.cc +++ b/tensorflow/core/kernels/range_sampler.cc @@ -241,8 +241,8 @@ FixedUnigramSampler::FixedUnigramSampler(int64_t range, float distortion, FillReservedIds(num_reserved_ids); } -Status FixedUnigramSampler::SetDistributionSampler(Env* env, - const string& vocab_file) { +absl::Status FixedUnigramSampler::SetDistributionSampler( + Env* env, const string& vocab_file) { TF_RETURN_IF_ERROR(LoadFromFile(env, vocab_file, distortion_)); if (!TF_PREDICT_TRUE(FixedUnigramSampler::range() == weights_.size())) return (errors::InvalidArgument("range is ", FixedUnigramSampler::range(), @@ -252,7 +252,7 @@ Status FixedUnigramSampler::SetDistributionSampler(Env* env, return absl::OkStatus(); } -Status FixedUnigramSampler::SetDistributionSampler( +absl::Status FixedUnigramSampler::SetDistributionSampler( const std::vector& unigrams) { LoadFromUnigrams(unigrams, distortion_); if (!TF_PREDICT_TRUE(FixedUnigramSampler::range() == weights_.size())) @@ -280,8 +280,9 @@ void FixedUnigramSampler::FillReservedIds(int32_t num_reserved_ids) { } } -Status FixedUnigramSampler::LoadFromFile(Env* env, const string& vocab_file, - float distortion) { +absl::Status FixedUnigramSampler::LoadFromFile(Env* env, + const string& vocab_file, + float distortion) { std::unique_ptr file; TF_RETURN_IF_ERROR(env->NewRandomAccessFile(vocab_file, &file)); diff --git a/tensorflow/core/kernels/range_sampler.h b/tensorflow/core/kernels/range_sampler.h index 94a0801a43be25..c49bbcc5b1eede 100644 --- a/tensorflow/core/kernels/range_sampler.h +++ b/tensorflow/core/kernels/range_sampler.h @@ -208,8 +208,8 @@ class FixedUnigramSampler : public RangeSampler { int32_t num_shards, int32_t shard); // The vocab_file is assumed to be a CSV, with the last entry of each row a // value representing the counts or probabilities for the corresponding ID. - Status SetDistributionSampler(Env* env, const string& vocab_file); - Status SetDistributionSampler(const std::vector& unigrams); + absl::Status SetDistributionSampler(Env* env, const string& vocab_file); + absl::Status SetDistributionSampler(const std::vector& unigrams); float Probability(int64_t value) const override; int64_t Sample(random::SimplePhilox* rnd) const override; @@ -232,7 +232,8 @@ class FixedUnigramSampler : public RangeSampler { void FillReservedIds(int32_t num_reserved_ids); // Load IDs to sample from a CSV file. It is assumed that the last item of // each row contains a count or probability for the corresponding ID. - Status LoadFromFile(Env* env, const string& vocab_file, float distortion); + absl::Status LoadFromFile(Env* env, const string& vocab_file, + float distortion); // Load from an in-memory array. void LoadFromUnigrams(const std::vector& unigrams, float distortion); }; diff --git a/tensorflow/core/kernels/range_sampler_test.cc b/tensorflow/core/kernels/range_sampler_test.cc index 32aed624a8926f..1aeadc634ccea3 100644 --- a/tensorflow/core/kernels/range_sampler_test.cc +++ b/tensorflow/core/kernels/range_sampler_test.cc @@ -171,7 +171,7 @@ TEST_F(RangeSamplerTest, FixedUnigramNoExistingFilename) { Env* env = Env::Default(); string fname = "NoExistingFile"; FixedUnigramSampler* test_sampler = new FixedUnigramSampler(9, 0.8, 0, 1, 0); - Status s = test_sampler->SetDistributionSampler(env, fname); + absl::Status s = test_sampler->SetDistributionSampler(env, fname); sampler_.reset(test_sampler); EXPECT_TRUE(absl::IsNotFound(s)) << s; } @@ -180,7 +180,7 @@ TEST_F(RangeSamplerTest, FixedUnigramNoMatchingRangeWeights) { string fname = io::JoinPath(testing::TmpDir(), "vocab_file"); TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent)); FixedUnigramSampler* test_sampler = new FixedUnigramSampler(8, 0.8, 0, 1, 0); - Status s = test_sampler->SetDistributionSampler(env, fname); + absl::Status s = test_sampler->SetDistributionSampler(env, fname); sampler_.reset(test_sampler); EXPECT_TRUE(absl::IsInvalidArgument(s)) << s; } diff --git a/tensorflow/core/kernels/record_yielder.cc b/tensorflow/core/kernels/record_yielder.cc index 6bba98ab2c1ae9..ae3f9fab6fcba2 100644 --- a/tensorflow/core/kernels/record_yielder.cc +++ b/tensorflow/core/kernels/record_yielder.cc @@ -44,7 +44,7 @@ RecordYielder::~RecordYielder() { delete thread_; } -Status RecordYielder::YieldOne(tstring* value) { +absl::Status RecordYielder::YieldOne(tstring* value) { mutex_lock l(mu_); while (!BufEnough() && status_.ok()) { buf_enough_.wait(l); @@ -72,17 +72,17 @@ struct RecordYielder::Shard { int index; // Shard index. std::vector filenames; // File names given to this shard. Notification done; // Notified when this shard is done. - Status status; // Shard status. + absl::Status status; // Shard status. }; -bool RecordYielder::ShouldFinish(const Status& s) { +bool RecordYielder::ShouldFinish(const absl::Status& s) { mutex_lock l(mu_); status_.Update(s); return stop_ || !status_.ok(); } -static Status MatchFiles(const string& patterns, - std::vector* filenames) { +static absl::Status MatchFiles(const string& patterns, + std::vector* filenames) { for (const auto& file_pattern : str_util::Split(patterns, ',')) { std::vector tmp_filenames; TF_RETURN_IF_ERROR( @@ -102,7 +102,7 @@ void RecordYielder::MainLoop() { // Finds all files. std::vector filenames; - Status s = MatchFiles(opts_.file_pattern, &filenames); + absl::Status s = MatchFiles(opts_.file_pattern, &filenames); if (filenames.empty()) { s = errors::NotFound("Found no files at ", opts_.file_pattern); @@ -201,7 +201,7 @@ void RecordYielder::ShardLoop(Shard* shard) { for (const string& filename : shard->filenames) { std::unique_ptr file; if (ShouldFinish(absl::OkStatus())) break; - Status s = Env::Default()->NewRandomAccessFile(filename, &file); + absl::Status s = Env::Default()->NewRandomAccessFile(filename, &file); if (!s.ok()) { shard->status = errors::InvalidArgument("Can't open ", filename); break; @@ -213,7 +213,7 @@ void RecordYielder::ShardLoop(Shard* shard) { uint64 offset = 0; tstring record; while (true) { - Status s = rdr.ReadRecord(&offset, &record); + absl::Status s = rdr.ReadRecord(&offset, &record); if (s.ok()) { values.emplace_back(std::move(record)); if (values.size() >= kRecords && Add(&values)) { diff --git a/tensorflow/core/kernels/record_yielder.h b/tensorflow/core/kernels/record_yielder.h index adb975a5a3e095..7e4c0f5ac04e02 100644 --- a/tensorflow/core/kernels/record_yielder.h +++ b/tensorflow/core/kernels/record_yielder.h @@ -90,7 +90,7 @@ class RecordYielder { RecordYielder& operator=(const RecordYielder&) = delete; // Yields one 'value'. - Status YieldOne(tstring* value); + absl::Status YieldOne(tstring* value); // Returns the current epoch number. int64_t current_epoch() const { return epoch_; } @@ -110,7 +110,7 @@ class RecordYielder { // Turned to true when this is deleted. bool stop_ TF_GUARDED_BY(mu_) = false; - Status status_ TF_GUARDED_BY(mu_); + absl::Status status_ TF_GUARDED_BY(mu_); // PRG used for randomization. std::mt19937_64 rnd_ TF_GUARDED_BY(mu_); @@ -151,7 +151,7 @@ class RecordYielder { void MainLoop(); struct Shard; void ShardLoop(Shard* shard); - bool ShouldFinish(const Status& s); + bool ShouldFinish(const absl::Status& s); bool Add(std::vector* values); }; diff --git a/tensorflow/core/kernels/reduction_ops_common.cc b/tensorflow/core/kernels/reduction_ops_common.cc index babb3fbb465178..60f5b9462f8366 100644 --- a/tensorflow/core/kernels/reduction_ops_common.cc +++ b/tensorflow/core/kernels/reduction_ops_common.cc @@ -58,8 +58,8 @@ absl::InlinedVector ReductionHelper::permutation() { } template -Status SimplifyHelper(const Tensor& data, const Tensor& axis, - absl::InlinedVector& bitmap) { +absl::Status SimplifyHelper(const Tensor& data, const Tensor& axis, + absl::InlinedVector& bitmap) { auto axis_vec = axis.flat(); for (int64_t i = 0; i < axis.NumElements(); ++i) { Tperm index = axis_vec(i); @@ -79,8 +79,8 @@ Status SimplifyHelper(const Tensor& data, const Tensor& axis, return absl::OkStatus(); } -Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis, - const bool keep_dims) { +absl::Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis, + const bool keep_dims) { // bitmap[i] indicates whether to reduce data along i-th axis. absl::InlinedVector bitmap(data.dims(), false); if (axis.dtype() == DT_INT32) { diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h index a4789936cebfa1..6ce777f748a777 100644 --- a/tensorflow/core/kernels/reduction_ops_common.h +++ b/tensorflow/core/kernels/reduction_ops_common.h @@ -71,7 +71,8 @@ class ReductionHelper { public: ReductionHelper() : reduce_first_axis_(false) {} - Status Simplify(const Tensor& data, const Tensor& axis, const bool keep_dims); + absl::Status Simplify(const Tensor& data, const Tensor& axis, + const bool keep_dims); // We need to do roughly: // tmp_out = allocate(out_reshape()) diff --git a/tensorflow/core/kernels/regex_replace_op.cc b/tensorflow/core/kernels/regex_replace_op.cc index a3e5041d189f83..bceedf4d8912a6 100644 --- a/tensorflow/core/kernels/regex_replace_op.cc +++ b/tensorflow/core/kernels/regex_replace_op.cc @@ -30,8 +30,8 @@ namespace { // Context requirements: // - "input" string Tensor at input_index=0 // - "output" string Tensor at output_index=0 -Status InternalCompute(const RE2& regex, const string& rewrite, - const bool replace_global, OpKernelContext* ctx) { +absl::Status InternalCompute(const RE2& regex, const string& rewrite, + const bool replace_global, OpKernelContext* ctx) { const Tensor* input_tensor; TF_RETURN_IF_ERROR(ctx->input("input", &input_tensor)); Tensor* output_tensor; diff --git a/tensorflow/core/kernels/reshape_op.h b/tensorflow/core/kernels/reshape_op.h index 5f0a46b0381207..dd603374c0fe73 100644 --- a/tensorflow/core/kernels/reshape_op.h +++ b/tensorflow/core/kernels/reshape_op.h @@ -116,9 +116,9 @@ class ReshapeOp : public OpKernel { private: template - Status ValidateSizes(const Tensor& sizes, int64_t* product, - int* unknown_index, TensorShape* shape, - bool* has_zero_dim) { + absl::Status ValidateSizes(const Tensor& sizes, int64_t* product, + int* unknown_index, TensorShape* shape, + bool* has_zero_dim) { *product = 1; *unknown_index = -1; *has_zero_dim = false; diff --git a/tensorflow/core/kernels/reshape_util.cc b/tensorflow/core/kernels/reshape_util.cc index d5383b3319db4b..eff4212bf536ef 100644 --- a/tensorflow/core/kernels/reshape_util.cc +++ b/tensorflow/core/kernels/reshape_util.cc @@ -40,10 +40,11 @@ namespace functor { template <> struct ReshapeSparseTensorFunctor { - Status operator()(OpKernelContext *context, const TensorShape &input_shape, - const TensorShape &output_shape, - typename TTypes::ConstMatrix input_indices, - typename TTypes::Matrix output_indices) const { + absl::Status operator()( + OpKernelContext *context, const TensorShape &input_shape, + const TensorShape &output_shape, + typename TTypes::ConstMatrix input_indices, + typename TTypes::Matrix output_indices) const { (void)context; // Unused (only used in GPU implementation) const int64_t input_rank = input_shape.dims(); const int64_t output_rank = output_shape.dims(); diff --git a/tensorflow/core/kernels/reshape_util.h b/tensorflow/core/kernels/reshape_util.h index 28d1909c07bda1..1945712c7db6b8 100644 --- a/tensorflow/core/kernels/reshape_util.h +++ b/tensorflow/core/kernels/reshape_util.h @@ -38,10 +38,11 @@ namespace functor { template struct ReshapeSparseTensorFunctor { - Status operator()(OpKernelContext *context, const TensorShape &input_shape, - const TensorShape &output_shape, - typename TTypes::ConstMatrix input_indices, - typename TTypes::Matrix output_indices) const; + absl::Status operator()( + OpKernelContext *context, const TensorShape &input_shape, + const TensorShape &output_shape, + typename TTypes::ConstMatrix input_indices, + typename TTypes::Matrix output_indices) const; }; } // namespace functor diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index 97f19a918d745f..d362d9d66cccb9 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -99,10 +99,11 @@ ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) { namespace { -Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) { +absl::Status CopyVariable(int output_idx, OpKernelContext* ctx, + const Tensor* t) { Tensor* output; Notification n; - Status status; + absl::Status status; AllocatorAttributes attr; if (t->dtype() == DT_VARIANT) { attr.set_on_host(true); @@ -116,7 +117,7 @@ Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) { // OpKernelContext Device* device = down_cast(ctx->device()); ctx->op_device_context()->CopyTensorInSameDevice( - t, device, output, [&n, &status](const Status& s) { + t, device, output, [&n, &status](const absl::Status& s) { status = s; n.Notify(); }); @@ -357,7 +358,7 @@ DestroyResourceOp::DestroyResourceOp(OpKernelConstruction* ctx) void DestroyResourceOp::Compute(OpKernelContext* ctx) { const ResourceHandle& p = HandleFromInput(ctx, 0); - Status status = DeleteResource(ctx, p); + absl::Status status = DeleteResource(ctx, p); if (ignore_lookup_error_ && errors::IsNotFound(status)) { return; } @@ -688,7 +689,8 @@ class VarIsInitializedOp : public OpKernel { context->allocate_output(0, TensorShape({}), &output)); auto output_tensor = output->tensor(); core::RefCountPtr variable; - Status s = LookupResource(context, HandleFromInput(context, 0), &variable); + absl::Status s = + LookupResource(context, HandleFromInput(context, 0), &variable); if (!s.ok()) { output_tensor() = false; return; @@ -973,12 +975,14 @@ bool ValidateInput(const Tensor& updates) { } template -Status DoScatter(OpKernelContext* c, Tensor* params, const Tensor& indices, - const Tensor& updates, Index num_indices); +absl::Status DoScatter(OpKernelContext* c, Tensor* params, + const Tensor& indices, const Tensor& updates, + Index num_indices); template -Status DoScatterOnCpu(OpKernelContext* c, Tensor* params, const Tensor& indices, - const Tensor& updates, Index num_indices); +absl::Status DoScatterOnCpu(OpKernelContext* c, Tensor* params, + const Tensor& indices, const Tensor& updates, + Index num_indices); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -1046,8 +1050,9 @@ Status DoScatterOnCpu(OpKernelContext* c, Tensor* params, const Tensor& indices, #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM template -Status DoScatter(OpKernelContext* c, Tensor* params, const Tensor& indices, - const Tensor& updates, Index num_indices) { +absl::Status DoScatter(OpKernelContext* c, Tensor* params, + const Tensor& indices, const Tensor& updates, + Index num_indices) { #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM if (std::is_same::value && tensorflow::OpDeterminismRequired() && !DisableScatterOpDeterminism()) { @@ -1107,7 +1112,7 @@ class ResourceScatterUpdateOp : public OpKernel { explicit ResourceScatterUpdateOp(OpKernelConstruction* c) : OpKernel(c) { // We use the same kernel for many operations. // Each operation has a different set of attributes defined in its nodes. - Status s = c->GetAttr("use_locking", &use_exclusive_lock_); + absl::Status s = c->GetAttr("use_locking", &use_exclusive_lock_); if (!s.ok()) { use_exclusive_lock_ = false; } diff --git a/tensorflow/core/kernels/resource_variable_util.cc b/tensorflow/core/kernels/resource_variable_util.cc index d66ac4e3e1d47f..fcc4aa31b39c59 100644 --- a/tensorflow/core/kernels/resource_variable_util.cc +++ b/tensorflow/core/kernels/resource_variable_util.cc @@ -19,8 +19,8 @@ limitations under the License. namespace tensorflow { -Status ValidateAssignUpdateVariableOpShapes(const TensorShape& variable_shape, - const TensorShape& value_shape) { +absl::Status ValidateAssignUpdateVariableOpShapes( + const TensorShape& variable_shape, const TensorShape& value_shape) { if (!variable_shape.IsSameSize(value_shape)) { return errors::InvalidArgument( "Cannot update variable with shape ", variable_shape.DebugString(), diff --git a/tensorflow/core/kernels/resource_variable_util.h b/tensorflow/core/kernels/resource_variable_util.h index 72ffd4b5028fb1..1222b4eb5c6f0b 100644 --- a/tensorflow/core/kernels/resource_variable_util.h +++ b/tensorflow/core/kernels/resource_variable_util.h @@ -20,8 +20,8 @@ limitations under the License. namespace tensorflow { -Status ValidateAssignUpdateVariableOpShapes(const TensorShape& variable_shape, - const TensorShape& value_shape); +absl::Status ValidateAssignUpdateVariableOpShapes( + const TensorShape& variable_shape, const TensorShape& value_shape); } // namespace tensorflow diff --git a/tensorflow/core/kernels/restore_op_test.cc b/tensorflow/core/kernels/restore_op_test.cc index a08732cd3b8bce..15dacaf6d93c45 100644 --- a/tensorflow/core/kernels/restore_op_test.cc +++ b/tensorflow/core/kernels/restore_op_test.cc @@ -87,7 +87,7 @@ TEST_F(RestoreOpTest, RestoreSimple) { absl::InlinedVector inputs; - Status status; + absl::Status status; std::unique_ptr op(CreateOpKernel(DEVICE_CPU, device.get(), cpu_allocator(), save, TF_GRAPH_DEF_VERSION, &status)); @@ -391,7 +391,7 @@ TEST_F(RestoreSliceOpTest, RestoreInt) { absl::InlinedVector inputs; - Status status; + absl::Status status; std::unique_ptr op(CreateOpKernel(DEVICE_CPU, device.get(), cpu_allocator(), save, TF_GRAPH_DEF_VERSION, &status)); diff --git a/tensorflow/core/kernels/restore_v2_op_test.cc b/tensorflow/core/kernels/restore_v2_op_test.cc index 3a3649429e490b..b9f289f01bb90f 100644 --- a/tensorflow/core/kernels/restore_v2_op_test.cc +++ b/tensorflow/core/kernels/restore_v2_op_test.cc @@ -98,7 +98,7 @@ class RestoreV2OpTest : public OpsTestBase { absl::InlinedVector inputs; - Status status; + absl::Status status; std::unique_ptr op( CreateOpKernel(DEVICE_CPU, device.get(), cpu_allocator(), save, TF_GRAPH_DEF_VERSION, &status)); diff --git a/tensorflow/core/kernels/roll_op_test.cc b/tensorflow/core/kernels/roll_op_test.cc index 6e3bc3970ed60c..5ea034fb0f4a9f 100644 --- a/tensorflow/core/kernels/roll_op_test.cc +++ b/tensorflow/core/kernels/roll_op_test.cc @@ -373,7 +373,7 @@ TEST_F(RollOpTest, Error_InputMustBeVectorOrHigher) { AddInputFromArray(TensorShape({}), {7}); AddInputFromArray(TensorShape({}), {1}); AddInputFromArray(TensorShape({}), {0}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains(s.ToString(), "input must be 1-D or higher")) << s; } @@ -385,7 +385,7 @@ TEST_F(RollOpTest, Error_AxisMustBeScalarOrVector) { AddInputFromArray(TensorShape({2, 2}), {1, 2, 3, 4}); AddInputFromArray(TensorShape({}), {1}); AddInputFromArray(TensorShape({1, 2}), {0, 1}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE( absl::StrContains(s.ToString(), "axis must be a scalar or a 1-D vector")) << s; @@ -398,7 +398,7 @@ TEST_F(RollOpTest, Error_ShiftMustBeScalarOrVector) { AddInputFromArray(TensorShape({2, 2}), {1, 2, 3, 4}); AddInputFromArray(TensorShape({1, 2}), {0, 1}); AddInputFromArray(TensorShape({}), {1}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE( absl::StrContains(s.ToString(), "shift must be a scalar or a 1-D vector")) << s; @@ -411,7 +411,7 @@ TEST_F(RollOpTest, Error_ShiftAndAxisMustBeSameSize) { AddInputFromArray(TensorShape({2, 2}), {1, 2, 3, 4}); AddInputFromArray(TensorShape({1}), {1}); AddInputFromArray(TensorShape({2}), {0, 1}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE( absl::StrContains(s.ToString(), "shift and axis must have the same size")) << s; @@ -424,7 +424,7 @@ TEST_F(RollOpTest, Error_AxisOutOfRange) { AddInputFromArray(TensorShape({4}), {1, 2, 3, 4}); AddInputFromArray(TensorShape({}), {1}); AddInputFromArray(TensorShape({}), {1}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains(s.ToString(), "is out of range")) << s; } diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc index 8cda48097cf9b8..5ad128c323bcc5 100644 --- a/tensorflow/core/kernels/save_restore_tensor.cc +++ b/tensorflow/core/kernels/save_restore_tensor.cc @@ -88,7 +88,7 @@ void SaveTensors( checkpoint::TensorSliceWriter writer(filename_t.flat()(0), std::move(builder_func)); - Status s; + absl::Status s; auto tensor_names_flat = tensor_names_t.flat(); // Process tensors in sorted name order. This allows us to avoid seeking @@ -294,7 +294,7 @@ struct RestoreOp { status = run(&reader); } - Status run(BundleReader* reader) { + absl::Status run(BundleReader* reader) { TensorShape restored_full_shape; TF_RETURN_IF_ERROR( reader->LookupTensorShape(tensor_name, &restored_full_shape)); @@ -357,15 +357,15 @@ struct RestoreOp { string reader_prefix; DataType dtype; - ::tensorflow::Status status; + absl::Status status; }; } // namespace -Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix, - const Tensor& tensor_names, - const Tensor& shape_and_slices, - absl::Span dtypes) { +absl::Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix, + const Tensor& tensor_names, + const Tensor& shape_and_slices, + absl::Span dtypes) { const string& prefix_string = prefix.scalar()(); const auto& tensor_names_flat = tensor_names.flat(); diff --git a/tensorflow/core/kernels/save_restore_tensor.h b/tensorflow/core/kernels/save_restore_tensor.h index 6e58b90c6c4d87..f5fac54138f550 100644 --- a/tensorflow/core/kernels/save_restore_tensor.h +++ b/tensorflow/core/kernels/save_restore_tensor.h @@ -63,10 +63,10 @@ void RestoreTensor(OpKernelContext* context, // * "prefix" has 1 element, DT_STRING. // * "tensor_names" and "shape_and_slices" shaped {N}, both DT_STRING. // * "dtypes" has N elements, the datatypes of the to-restore tensors. -Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix, - const Tensor& tensor_names, - const Tensor& shape_and_slices, - absl::Span dtypes); +absl::Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix, + const Tensor& tensor_names, + const Tensor& shape_and_slices, + absl::Span dtypes); } // namespace tensorflow diff --git a/tensorflow/core/kernels/save_restore_v2_ops.cc b/tensorflow/core/kernels/save_restore_v2_ops.cc index 53329049936d66..04344dd8a63dcc 100644 --- a/tensorflow/core/kernels/save_restore_v2_ops.cc +++ b/tensorflow/core/kernels/save_restore_v2_ops.cc @@ -296,7 +296,7 @@ class MergeV2Checkpoints : public OpKernel { for (const string& input_prefix : input_prefixes) { const string dirname(io::Dirname(input_prefix)); if (dirname == merged_dir) continue; - Status status = env->DeleteDir(dirname); + absl::Status status = env->DeleteDir(dirname); // For sharded save, only the first delete will go through and all // others will hit NotFound. Use vlog to be less verbose. if (!status.ok()) VLOG(1) << status; diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc index f359e3b3a55e5c..7d61e1aa2f257e 100644 --- a/tensorflow/core/kernels/scatter_nd_op.cc +++ b/tensorflow/core/kernels/scatter_nd_op.cc @@ -61,9 +61,10 @@ namespace functor { template -Status DoScatterNd(OpKernelContext* c, const Tensor& indices, - const Tensor& updates, const TensorShape& shape, Tensor* out, - bool allocate, BadIndicesPolicy bad_indices_policy); +absl::Status DoScatterNd(OpKernelContext* c, const Tensor& indices, + const Tensor& updates, const TensorShape& shape, + Tensor* out, bool allocate, + BadIndicesPolicy bad_indices_policy); } // namespace functor // Returns true if the three tensors have valid number of elements @@ -838,10 +839,10 @@ TF_CALL_COMPLEX_TYPES(REGISTER_SCATTER_ND_TENSOR_GPU); namespace functor { template -Status PrepareAndValidateInputs(const TensorShape& params_shape, - const Tensor& indices, const Tensor& updates, - int64_t* slice_dim, Index* num_updates, - Index* slice_size) { +absl::Status PrepareAndValidateInputs(const TensorShape& params_shape, + const Tensor& indices, + const Tensor& updates, int64_t* slice_dim, + Index* num_updates, Index* slice_size) { const TensorShape& indices_shape(indices.shape()); const TensorShape& updates_shape(updates.shape()); @@ -924,10 +925,10 @@ namespace { template -Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices, - const Tensor& updates, const TensorShape& shape, - Tensor* out, bool allocate, - BadIndicesPolicy bad_indices_policy) { +absl::Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices, + const Tensor& updates, const TensorShape& shape, + Tensor* out, bool allocate, + BadIndicesPolicy bad_indices_policy) { int64_t slice_dim; Index num_updates; Index slice_size; @@ -1011,23 +1012,23 @@ Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices, template -Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices, - const Tensor& updates, const TensorShape& shape, - Tensor* out, bool allocate) { +absl::Status DoScatterNdImpl(OpKernelContext* c, const Tensor& indices, + const Tensor& updates, const TensorShape& shape, + Tensor* out, bool allocate) { return DoScatterNdImpl( c, indices, updates, shape, out, allocate, BadIndicesPolicy::kDefault); } template -Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices, - const Tensor& updates, const TensorShape& shape, - Tensor* out, bool allocate, - BadIndicesPolicy bad_indices_policy); +absl::Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices, + const Tensor& updates, const TensorShape& shape, + Tensor* out, bool allocate, + BadIndicesPolicy bad_indices_policy); template -Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices, - const Tensor& updates, const TensorShape& shape, - Tensor* out, bool allocate) { +absl::Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices, + const Tensor& updates, const TensorShape& shape, + Tensor* out, bool allocate) { return DoScatterNdOnCpu(c, indices, updates, shape, out, allocate, BadIndicesPolicy::kDefault); } @@ -1104,9 +1105,10 @@ Status DoScatterNdOnCpu(OpKernelContext* c, const Tensor& indices, template -Status DoScatterNd(OpKernelContext* c, const Tensor& indices, - const Tensor& updates, const TensorShape& shape, Tensor* out, - bool allocate, BadIndicesPolicy bad_indices_policy) { +absl::Status DoScatterNd(OpKernelContext* c, const Tensor& indices, + const Tensor& updates, const TensorShape& shape, + Tensor* out, bool allocate, + BadIndicesPolicy bad_indices_policy) { #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM if (std::is_same::value && tensorflow::OpDeterminismRequired() && !DisableScatterOpDeterminism()) { @@ -1129,9 +1131,9 @@ Status DoScatterNd(OpKernelContext* c, const Tensor& indices, template -Status DoScatterNd(OpKernelContext* c, const Tensor& indices, - const Tensor& updates, const TensorShape& shape, Tensor* out, - bool allocate) { +absl::Status DoScatterNd(OpKernelContext* c, const Tensor& indices, + const Tensor& updates, const TensorShape& shape, + Tensor* out, bool allocate) { return DoScatterNd( c, indices, updates, shape, out, allocate, BadIndicesPolicy::kDefault); } diff --git a/tensorflow/core/kernels/scatter_nd_op.h b/tensorflow/core/kernels/scatter_nd_op.h index f9a2ce0ed6e12b..b736d4b0aafcf7 100644 --- a/tensorflow/core/kernels/scatter_nd_op.h +++ b/tensorflow/core/kernels/scatter_nd_op.h @@ -64,9 +64,9 @@ struct ScatterNdFunctor { // before the scatter is executed. template -Status DoScatterNd(OpKernelContext* c, const Tensor& indices, - const Tensor& updates, const TensorShape& shape, Tensor* out, - bool allocate); +absl::Status DoScatterNd(OpKernelContext* c, const Tensor& indices, + const Tensor& updates, const TensorShape& shape, + Tensor* out, bool allocate); } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/scatter_nd_op_test.cc b/tensorflow/core/kernels/scatter_nd_op_test.cc index 02fa44f193b28f..addebb6115e005 100644 --- a/tensorflow/core/kernels/scatter_nd_op_test.cc +++ b/tensorflow/core/kernels/scatter_nd_op_test.cc @@ -146,7 +146,7 @@ TEST_F(TensorScatterUpdateOpTest, Error_IndexOutOfRange) { AddInputFromArray(TensorShape({3, 1}), {0, 99, 4}); AddInputFromArray(TensorShape({3, 3}), {100, 101, 102, 777, 778, 779, 10000, 10001, 10002}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains( s.ToString(), "indices[1] = [99] does not index into shape [5,3]")) << s; @@ -174,7 +174,7 @@ TEST_F(TensorScatterUpdateOpErrorOnBadIndicesTest, Error_IndexOutOfRange) { AddInputFromArray(TensorShape({3, 1}), {0, 99, 4}); AddInputFromArray(TensorShape({3, 3}), {100, 101, 102, 777, 778, 779, 10000, 10001, 10002}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains( s.ToString(), "indices[1] = [99] does not index into shape [5,3]")) << s; @@ -361,7 +361,7 @@ TEST_F(ScatterNdUpdateOpTest, Error_IndexOutOfRange) { AddInputFromArray(TensorShape({3, 1}), {0, 99, 4}); AddInputFromArray(TensorShape({3, 3}), {100, 101, 102, 777, 778, 779, 10000, 10001, 10002}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains( s.ToString(), "indices[1] = [99] does not index into shape [5,3]")) << s; @@ -375,7 +375,7 @@ TEST_F(ScatterNdUpdateOpTest, Error_WrongDimsIndices) { AddInputFromArray(TensorShape({1, 3, 1}), {0, 4, 99}); AddInputFromArray(TensorShape({3, 3}), {100, 101, 102, 777, 778, 779, 10000, 10001, 10002}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains( s.ToString(), "Dimensions [0,1) of indices[shape=[1,3,1]] = 1 must match dimensions " @@ -393,7 +393,7 @@ TEST_F(ScatterNdUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) { AddInputFromArray( TensorShape({3, 4}), {100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains( s.ToString(), "Dimensions [1,2) of input[shape=[5,3]] must match dimensions [1,2) of " @@ -410,7 +410,7 @@ TEST_F(ScatterNdUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) { AddInputFromArray(TensorShape({3, 1}), {0, 4, 2}); AddInputFromArray(TensorShape({2, 3}), {100, 101, 102, 10000, 10001, 10002}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains( s.ToString(), "Dimensions [0,1) of indices[shape=[3,1]] = 3 must match dimensions [0,1)" @@ -440,7 +440,7 @@ TEST_F(ScatterNdUpdateOpErrorOnBadIndicesTest, Error_IndexOutOfRange) { AddInputFromArray(TensorShape({3, 1}), {0, 99, 4}); AddInputFromArray(TensorShape({3, 3}), {100, 101, 102, 777, 778, 779, 10000, 10001, 10002}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains( s.ToString(), "indices[1] = [99] does not index into shape [5,3]")) << s; @@ -533,7 +533,7 @@ TEST_F(ScatterNdOpTest, Error_IndexOutOfRange) { AddInputFromArray(TensorShape({3, 1}), {100, 101, 102}); // Shape: output tensor of 5x1 shape. AddInputFromArray(TensorShape({2}), {5, 1}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); // The valid index range is [0,5). Expect "5" to raise error. EXPECT_TRUE(absl::StrContains( s.ToString(), "indices[1] = [5] does not index into shape [5,1]")) @@ -564,7 +564,7 @@ TEST_F(ScatterNdOpErrorOnBadIndicesTest, Error_IndexOutOfRange) { AddInputFromArray(TensorShape({3, 1}), {100, 101, 102}); // Shape: output tensor of 5x1 shape. AddInputFromArray(TensorShape({2}), {5, 1}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); // The valid index range is [0,5). Expect "5" to raise error. EXPECT_TRUE(absl::StrContains( s.ToString(), "indices[1] = [5] does not index into shape [5,1]")) @@ -658,7 +658,7 @@ void BM_ScatterNdHelper(::testing::benchmark::State& state, int embedding_size, bm.AddInputFromArray(TensorShape({kNumUpdates, embedding_size}), updates); for (auto i : state) { - Status s = bm.RunOpKernel(); + absl::Status s = bm.RunOpKernel(); } state.SetItemsProcessed((static_cast(kNumUpdates) * embedding_size) * state.iterations()); diff --git a/tensorflow/core/kernels/scatter_nd_util.cc b/tensorflow/core/kernels/scatter_nd_util.cc index 4793e4ce99761c..73624d356c7713 100644 --- a/tensorflow/core/kernels/scatter_nd_util.cc +++ b/tensorflow/core/kernels/scatter_nd_util.cc @@ -19,9 +19,9 @@ limitations under the License. namespace tensorflow { -Status ValidateScatterNdUpdateShape(const TensorShape& params_shape, - const TensorShape& indices_shape, - const TensorShape& updates_shape) { +absl::Status ValidateScatterNdUpdateShape(const TensorShape& params_shape, + const TensorShape& indices_shape, + const TensorShape& updates_shape) { const int64_t slice_dim = (indices_shape.dims() > 1) ? indices_shape.dim_size(indices_shape.dims() - 1) diff --git a/tensorflow/core/kernels/scatter_nd_util.h b/tensorflow/core/kernels/scatter_nd_util.h index f0530048ef699a..5095e92582df0e 100644 --- a/tensorflow/core/kernels/scatter_nd_util.h +++ b/tensorflow/core/kernels/scatter_nd_util.h @@ -22,9 +22,9 @@ limitations under the License. namespace tensorflow { // Validates the input shapes for the ScatterNdUpdateOp -Status ValidateScatterNdUpdateShape(const TensorShape& params_shape, - const TensorShape& indices_shape, - const TensorShape& updates_shape); +absl::Status ValidateScatterNdUpdateShape(const TensorShape& params_shape, + const TensorShape& indices_shape, + const TensorShape& updates_shape); inline bool DisableScatterOpDeterminism() { static bool cached_disable = [] { diff --git a/tensorflow/core/kernels/scatter_op_test.cc b/tensorflow/core/kernels/scatter_op_test.cc index f3f8d8ea84c390..11e2150f64a274 100644 --- a/tensorflow/core/kernels/scatter_op_test.cc +++ b/tensorflow/core/kernels/scatter_op_test.cc @@ -180,7 +180,7 @@ TEST_F(ScatterUpdateOpTest, Error_IndexOutOfRange) { AddInputFromArray(TensorShape({3}), {0, 4, 99}); AddInputFromArray(TensorShape({3, 3}), {100, 101, 102, 777, 778, 779, 10000, 10001, 10002}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE( absl::StrContains(s.ToString(), "indices[2] = 99 is not in [0, 5)")) << s; @@ -193,7 +193,7 @@ TEST_F(ScatterSubOpTest, Error_IndexOutOfRange) { {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); AddInputFromArray(TensorShape({3}), {0, 1, 99}); AddInputFromArray(TensorShape({3}), {100, 101, 102}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE( absl::StrContains(s.ToString(), "indices[2] = 99 is not in [0, 14)")) << s; @@ -210,7 +210,7 @@ TEST_F(ScatterSubOpTest, StressIndexTest) { AddInputFromArray(TensorShape({kRows}), values); AddInputFromArray(TensorShape({kNumUpdates}), indices); AddInputFromArray(TensorShape({kNumUpdates}), updates); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); Tensor params_tensor = *mutable_input(0).tensor; Tensor expected(allocator(), DT_INT32, TensorShape({1})); test::FillValues(&expected, {-1000000}); @@ -225,7 +225,7 @@ TEST_F(ScatterUpdateOpTest, Error_WrongDimsIndices) { AddInputFromArray(TensorShape({1, 3}), {0, 4, 99}); AddInputFromArray(TensorShape({3, 3}), {100, 101, 102, 777, 778, 779, 10000, 10001, 10002}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains(s.ToString(), "Must have updates.shape = indices.shape + " "params.shape[1:] or updates.shape = [], got ")) @@ -242,7 +242,7 @@ TEST_F(ScatterUpdateOpTest, Error_MismatchedParamsAndUpdateDimensions) { AddInputFromArray( TensorShape({3, 4}), {100, 101, 102, 103, 777, 778, 779, 780, 10000, 10001, 10002, 10004}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains(s.ToString(), "Must have updates.shape = indices.shape + " "params.shape[1:] or updates.shape = [], got ")) @@ -259,7 +259,7 @@ TEST_F(ScatterUpdateOpTest, Error_MismatchedIndicesAndUpdateDimensions) { AddInputFromArray(TensorShape({3}), {0, 4, 2}); AddInputFromArray(TensorShape({2, 3}), {100, 101, 102, 10000, 10001, 10002}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains(s.ToString(), "Must have updates.shape = indices.shape + " "params.shape[1:] or updates.shape = [], got ")) @@ -307,7 +307,7 @@ void BM_ScatterHelper(::testing::benchmark::State& state, int embedding_size, bm.AddInputFromArray(TensorShape({kNumUpdates, embedding_size}), updates); for (auto i : state) { - Status s = bm.RunOpKernel(); + absl::Status s = bm.RunOpKernel(); } state.SetItemsProcessed((static_cast(kNumUpdates) * embedding_size) * state.iterations()); diff --git a/tensorflow/core/kernels/scoped_allocator_ops.cc b/tensorflow/core/kernels/scoped_allocator_ops.cc index 95b768cfc10fbf..93ac2a61a7eb6b 100644 --- a/tensorflow/core/kernels/scoped_allocator_ops.cc +++ b/tensorflow/core/kernels/scoped_allocator_ops.cc @@ -56,7 +56,7 @@ class ScopedAllocatorOp : public OpKernel { } Tensor* backing_tensor = nullptr; AllocatorAttributes attr = context->output_alloc_attr(0); - Status s = + absl::Status s = context->allocate_output(0, {num_elements_}, &backing_tensor, attr); VLOG(1) << "_ScopedAllocatorOp " << context->op_kernel().name() << " new backing tensor size " << backing_tensor->TotalBytes() diff --git a/tensorflow/core/kernels/scoped_allocator_ops_test.cc b/tensorflow/core/kernels/scoped_allocator_ops_test.cc index 4edc38e0f772a5..4c6c26d0802ca4 100644 --- a/tensorflow/core/kernels/scoped_allocator_ops_test.cc +++ b/tensorflow/core/kernels/scoped_allocator_ops_test.cc @@ -99,9 +99,9 @@ void PrepOp(DataType dtype, int32_t id, *backing_tensor = new Tensor(allocator, dtype, {num_elements}); int64_t step_id = 10; - Status s = sam->AddScopedAllocator(**backing_tensor, step_id, id, - "sa_" + op_name + "_test", *fields, - fields_shapes.size()); + absl::Status s = sam->AddScopedAllocator(**backing_tensor, step_id, id, + "sa_" + op_name + "_test", *fields, + fields_shapes.size()); TF_ASSERT_OK(s); ScopedAllocatorContainer* sac = sam->GetContainer(step_id); @@ -179,7 +179,7 @@ class ScopedAllocatorConcatOpTest : public OpsTestBase { // Check input and output are same tensor. const Tensor& input = context_->input(0); OpOutputList output_list; - Status s = context_->output_list("output", &output_list); + absl::Status s = context_->output_list("output", &output_list); TF_ASSERT_OK(s); const Tensor& output = *(output_list[0]); CHECK_EQ(DMAHelper::base(&input), DMAHelper::base(&output)); @@ -242,7 +242,7 @@ TEST_F(ScopedAllocatorConcatOpTest, FailNumElementsCheck) { AddInputFromArray({8}, {0, 1, 2, 3, 4, 5, 6, 7}); AddInputFromArray({4}, {0, 1, 2, 3}); AddInputFromArray({4}, {4, 5, 6, 7}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); } @@ -253,7 +253,7 @@ TEST_F(ScopedAllocatorConcatOpTest, FailBounds) { AddInputFromArray({8}, {0, 1, 2, 3, 4, 5, 6, 7}); AddInputFromArray({4}, {0, 1, 2, 3}); AddInputFromArray({4}, {4, 5, 6, 7}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); } @@ -301,7 +301,7 @@ class ScopedAllocatorSplitOpTest : public OpsTestBase { const char* lower_limit_c = static_cast(lower_limit); // for pointer arithmetic OpOutputList output_list; - Status s = context_->output_list("output", &output_list); + absl::Status s = context_->output_list("output", &output_list); TF_ASSERT_OK(s); for (int i = 0; i < output_list.size(); i++) { const Tensor& output = *(output_list[i]); @@ -334,7 +334,7 @@ TEST_F(ScopedAllocatorSplitOpTest, Success3) { TEST_F(ScopedAllocatorSplitOpTest, FailNLessThan2) { BuildNodeDef({4, 4}, DT_FLOAT, "test", 120, 1, {{4, 4}}); - Status s = InitOp(); + absl::Status s = InitOp(); EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); } @@ -348,7 +348,7 @@ TEST_F(ScopedAllocatorSplitOpTest, FailBounds) { AddInputFromArray({8}, {0, 1, 2, 3, 4, 5, 6, 7}); AddInputFromArray({4}, {0, 1, 2, 3}); AddInputFromArray({4}, {4, 5, 6, 7}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); } diff --git a/tensorflow/core/kernels/sdca_internal.cc b/tensorflow/core/kernels/sdca_internal.cc index fa33622eee71f7..d84744ec8ab7ba 100644 --- a/tensorflow/core/kernels/sdca_internal.cc +++ b/tensorflow/core/kernels/sdca_internal.cc @@ -92,7 +92,7 @@ void ModelWeights::UpdateDeltaWeights( } } -Status ModelWeights::Initialize(OpKernelContext* const context) { +absl::Status ModelWeights::Initialize(OpKernelContext* const context) { OpInputList sparse_indices_inputs; TF_RETURN_IF_ERROR( context->input_list("sparse_indices", &sparse_indices_inputs)); @@ -246,7 +246,7 @@ const ExampleStatistics Example::ComputeWxAndWeightedExampleNorm( } // Examples contains all the training examples that SDCA uses for a mini-batch. -Status Examples::SampleAdaptiveProbabilities( +absl::Status Examples::SampleAdaptiveProbabilities( const int num_loss_partitions, const Regularizations& regularization, const ModelWeights& model_weights, const TTypes::Matrix example_state_data, @@ -262,7 +262,7 @@ Status Examples::SampleAdaptiveProbabilities( const Example& example = examples_[example_id]; const double example_weight = example.example_weight(); float label = example.example_label(); - const Status conversion_status = loss_updater->ConvertLabel(&label); + const absl::Status conversion_status = loss_updater->ConvertLabel(&label); const ExampleStatistics example_statistics = example.ComputeWxAndWeightedExampleNorm(num_loss_partitions, model_weights, regularization, @@ -331,11 +331,11 @@ void Examples::RandomShuffle() { } // TODO(sibyl-Aix6ihai): Refactor/shorten this function. -Status Examples::Initialize(OpKernelContext* const context, - const ModelWeights& weights, - const int num_sparse_features, - const int num_sparse_features_with_values, - const int num_dense_features) { +absl::Status Examples::Initialize(OpKernelContext* const context, + const ModelWeights& weights, + const int num_sparse_features, + const int num_sparse_features_with_values, + const int num_dense_features) { num_features_ = num_sparse_features + num_dense_features; OpInputList sparse_example_indices_inputs; @@ -424,7 +424,7 @@ Status Examples::Initialize(OpKernelContext* const context, return absl::OkStatus(); } -Status Examples::CreateSparseFeatureRepresentation( +absl::Status Examples::CreateSparseFeatureRepresentation( const DeviceBase::CpuWorkerThreads& worker_threads, const int num_examples, const int num_sparse_features, const ModelWeights& weights, const OpInputList& sparse_example_indices_inputs, @@ -432,7 +432,7 @@ Status Examples::CreateSparseFeatureRepresentation( const OpInputList& sparse_feature_values_inputs, std::vector* const examples) { mutex mu; - Status result; // Guarded by mu + absl::Status result; // Guarded by mu auto parse_partition = [&](const int64_t begin, const int64_t end) { // The static_cast here is safe since begin and end can be at most // num_examples which is an int. @@ -511,13 +511,13 @@ Status Examples::CreateSparseFeatureRepresentation( return result; } -Status Examples::CreateDenseFeatureRepresentation( +absl::Status Examples::CreateDenseFeatureRepresentation( const DeviceBase::CpuWorkerThreads& worker_threads, const int num_examples, const int num_dense_features, const ModelWeights& weights, const OpInputList& dense_features_inputs, std::vector* const examples) { mutex mu; - Status result; // Guarded by mu + absl::Status result; // Guarded by mu auto parse_partition = [&](const int64_t begin, const int64_t end) { // The static_cast here is safe since begin and end can be at most // num_examples which is an int. @@ -543,12 +543,12 @@ Status Examples::CreateDenseFeatureRepresentation( return result; } -Status Examples::ComputeSquaredNormPerExample( +absl::Status Examples::ComputeSquaredNormPerExample( const DeviceBase::CpuWorkerThreads& worker_threads, const int num_examples, const int num_sparse_features, const int num_dense_features, std::vector* const examples) { mutex mu; - Status result; // Guarded by mu + absl::Status result; // Guarded by mu // Compute norm of examples. auto compute_example_norm = [&](const int64_t begin, const int64_t end) { // The static_cast here is safe since begin and end can be at most diff --git a/tensorflow/core/kernels/sdca_internal.h b/tensorflow/core/kernels/sdca_internal.h index 6e574640d0e585..8f5ac0384edccb 100644 --- a/tensorflow/core/kernels/sdca_internal.h +++ b/tensorflow/core/kernels/sdca_internal.h @@ -76,7 +76,7 @@ class Regularizations { Regularizations() {} // Initialize() must be called immediately after construction. - Status Initialize(OpKernelConstruction* const context) { + absl::Status Initialize(OpKernelConstruction* const context) { TF_RETURN_IF_ERROR(context->GetAttr("l1", &symmetric_l1_)); TF_RETURN_IF_ERROR(context->GetAttr("l2", &symmetric_l2_)); shrinkage_ = symmetric_l1_ / symmetric_l2_; @@ -294,7 +294,7 @@ class ModelWeights { const Eigen::ThreadPoolDevice& device, const Example& example, const std::vector& normalized_bounded_dual_delta); - Status Initialize(OpKernelContext* const context); + absl::Status Initialize(OpKernelContext* const context); const std::vector& sparse_weights() const { return sparse_weights_; @@ -327,7 +327,7 @@ class Examples { // Adaptive SDCA in the current implementation only works for // binary classification, where the input argument for num_weight_vectors // is 1. - Status SampleAdaptiveProbabilities( + absl::Status SampleAdaptiveProbabilities( const int num_loss_partitions, const Regularizations& regularization, const ModelWeights& model_weights, const TTypes::Matrix example_state_data, @@ -341,16 +341,16 @@ class Examples { int num_features() const { return num_features_; } // Initialize() must be called immediately after construction. - Status Initialize(OpKernelContext* const context, const ModelWeights& weights, - int num_sparse_features, - int num_sparse_features_with_values, - int num_dense_features); + absl::Status Initialize(OpKernelContext* const context, + const ModelWeights& weights, int num_sparse_features, + int num_sparse_features_with_values, + int num_dense_features); private: // Reads the input tensors, and builds the internal representation for sparse // features per example. This function modifies the |examples| passed in // to build the sparse representations. - static Status CreateSparseFeatureRepresentation( + static absl::Status CreateSparseFeatureRepresentation( const DeviceBase::CpuWorkerThreads& worker_threads, int num_examples, int num_sparse_features, const ModelWeights& weights, const OpInputList& sparse_example_indices_inputs, @@ -361,7 +361,7 @@ class Examples { // Reads the input tensors, and builds the internal representation for dense // features per example. This function modifies the |examples| passed in // to build the sparse representations. - static Status CreateDenseFeatureRepresentation( + static absl::Status CreateDenseFeatureRepresentation( const DeviceBase::CpuWorkerThreads& worker_threads, int num_examples, int num_dense_features, const ModelWeights& weights, const OpInputList& dense_features_inputs, @@ -369,7 +369,7 @@ class Examples { // Computes squared example norm per example i.e |x|^2. This function modifies // the |examples| passed in and adds the squared norm per example. - static Status ComputeSquaredNormPerExample( + static absl::Status ComputeSquaredNormPerExample( const DeviceBase::CpuWorkerThreads& worker_threads, int num_examples, int num_sparse_features, int num_dense_features, std::vector* const examples); diff --git a/tensorflow/core/kernels/sdca_ops.cc b/tensorflow/core/kernels/sdca_ops.cc index c302b0f188d365..99acf4d7fdb0b5 100644 --- a/tensorflow/core/kernels/sdca_ops.cc +++ b/tensorflow/core/kernels/sdca_ops.cc @@ -171,7 +171,7 @@ void DoCompute(const ComputeOptions& options, OpKernelContext* const context) { } struct { mutex mu; - Status value TF_GUARDED_BY(mu); + absl::Status value TF_GUARDED_BY(mu); } train_step_status; std::atomic atomic_index(-1); auto train_step = [&](const int64_t begin, const int64_t end) { @@ -183,7 +183,7 @@ void DoCompute(const ComputeOptions& options, OpKernelContext* const context) { const float dual = example_state_data(example_index, 0); const float example_weight = example.example_weight(); float example_label = example.example_label(); - const Status conversion_status = + const absl::Status conversion_status = options.loss_updater->ConvertLabel(&example_label); if (!conversion_status.ok()) { mutex_lock l(train_step_status.mu); diff --git a/tensorflow/core/kernels/searchsorted_op.cc b/tensorflow/core/kernels/searchsorted_op.cc index d955ae7d2ac022..809e4d3f0be102 100644 --- a/tensorflow/core/kernels/searchsorted_op.cc +++ b/tensorflow/core/kernels/searchsorted_op.cc @@ -35,11 +35,12 @@ typedef Eigen::GpuDevice GPUDevice; namespace functor { template struct UpperBoundFunctor { - static Status Compute(OpKernelContext* context, - const typename TTypes::ConstTensor& sorted_inputs, - const typename TTypes::ConstTensor& values, - int batch_size, int num_inputs, int num_values, - typename TTypes::Tensor* output) { + static absl::Status Compute( + OpKernelContext* context, + const typename TTypes::ConstTensor& sorted_inputs, + const typename TTypes::ConstTensor& values, int batch_size, + int num_inputs, int num_values, + typename TTypes::Tensor* output) { auto work_fn = [&](int64_t first, int64_t last) { for (int b = 0; b < batch_size; ++b) { const T* sorted_inputs_ptr = sorted_inputs.data() + b * num_inputs; @@ -64,11 +65,12 @@ struct UpperBoundFunctor { template struct LowerBoundFunctor { - static Status Compute(OpKernelContext* context, - const typename TTypes::ConstTensor& sorted_inputs, - const typename TTypes::ConstTensor& values, - int batch_size, int num_inputs, int num_values, - typename TTypes::Tensor* output) { + static absl::Status Compute( + OpKernelContext* context, + const typename TTypes::ConstTensor& sorted_inputs, + const typename TTypes::ConstTensor& values, int batch_size, + int num_inputs, int num_values, + typename TTypes::Tensor* output) { auto work_fn = [&](int64_t first, int64_t last) { for (int b = 0; b < batch_size; ++b) { const T* sorted_inputs_ptr = sorted_inputs.data() + b * num_inputs; @@ -117,13 +119,13 @@ class UpperBoundOp : public OpKernel { values_t.shape().dims(), " for `values` argument"))); // must have same batch dim_size for both OP_REQUIRES(ctx, sorted_inputs_t.dim_size(0) == values_t.dim_size(0), - Status(absl::StatusCode::kInvalidArgument, - "Leading dim_size of both tensors must match.")); + absl::Status(absl::StatusCode::kInvalidArgument, + "Leading dim_size of both tensors must match.")); // this is required because we do indexing in int32 on the GPU OP_REQUIRES(ctx, values_t.NumElements() < std::numeric_limits::max(), - Status(absl::StatusCode::kInvalidArgument, - "values tensor size must less than INT_MAX")); + absl::Status(absl::StatusCode::kInvalidArgument, + "values tensor size must less than INT_MAX")); Tensor* output_t; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, values_t.shape(), &output_t)); @@ -180,13 +182,13 @@ class LowerBoundOp : public OpKernel { values_t.shape().dims(), " for `values` argument"))); // must have same batch dim_size for both OP_REQUIRES(ctx, sorted_inputs_t.dim_size(0) == values_t.dim_size(0), - Status(absl::StatusCode::kInvalidArgument, - "Leading dim_size of both tensors must match.")); + absl::Status(absl::StatusCode::kInvalidArgument, + "Leading dim_size of both tensors must match.")); // this is required because we do indexing in int32 on the GPU OP_REQUIRES(ctx, values_t.NumElements() < std::numeric_limits::max(), - Status(absl::StatusCode::kInvalidArgument, - "values tensor size must less than INT_MAX")); + absl::Status(absl::StatusCode::kInvalidArgument, + "values tensor size must less than INT_MAX")); Tensor* output_t; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, values_t.shape(), &output_t)); diff --git a/tensorflow/core/kernels/searchsorted_op.h b/tensorflow/core/kernels/searchsorted_op.h index a139dfc23e9b33..fb4ade03f7dcb4 100644 --- a/tensorflow/core/kernels/searchsorted_op.h +++ b/tensorflow/core/kernels/searchsorted_op.h @@ -29,22 +29,24 @@ template struct UpperBoundFunctor { // Searches for values in sorted_inputs and returns the greatest possible // index where they maintain sorted order. - static Status Compute(OpKernelContext* context, - const typename TTypes::ConstTensor& sorted_inputs, - const typename TTypes::ConstTensor& values, - int batch_size, int num_inputs, int num_values, - typename TTypes::Tensor* output); + static absl::Status Compute( + OpKernelContext* context, + const typename TTypes::ConstTensor& sorted_inputs, + const typename TTypes::ConstTensor& values, int batch_size, + int num_inputs, int num_values, + typename TTypes::Tensor* output); }; template struct LowerBoundFunctor { // Searches for values in sorted_inputs and returns the lowest possible // index where they maintain sorted order. - static Status Compute(OpKernelContext* context, - const typename TTypes::ConstTensor& sorted_inputs, - const typename TTypes::ConstTensor& values, - int batch_size, int num_inputs, int num_values, - typename TTypes::Tensor* output); + static absl::Status Compute( + OpKernelContext* context, + const typename TTypes::ConstTensor& sorted_inputs, + const typename TTypes::ConstTensor& values, int batch_size, + int num_inputs, int num_values, + typename TTypes::Tensor* output); }; } // namespace functor diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h index 95c71006b6ee77..93aa9636110b9a 100644 --- a/tensorflow/core/kernels/segment_reduction_ops.h +++ b/tensorflow/core/kernels/segment_reduction_ops.h @@ -138,11 +138,12 @@ struct Highest { template struct SparseSegmentReductionFunctor { - Status operator()(OpKernelContext* context, bool is_mean, bool is_sqrtn, - T default_value, typename TTypes::ConstTensor input, - typename TTypes::ConstVec indices, - typename TTypes::ConstVec segment_ids, - typename TTypes::Tensor output); + absl::Status operator()(OpKernelContext* context, bool is_mean, bool is_sqrtn, + T default_value, + typename TTypes::ConstTensor input, + typename TTypes::ConstVec indices, + typename TTypes::ConstVec segment_ids, + typename TTypes::Tensor output); }; template diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h index a9bf175f205f4e..f0ba0ce2c27572 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h +++ b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.h @@ -33,9 +33,7 @@ limitations under the License. #include "tensorflow/core/util/gpu_solvers.h" // For ScratchSpace #include "tensorflow/core/util/permutation_input_iterator.h" -#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) -#include "xla/stream_executor/gpu/scoped_activate_context.h" -#elif (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) +#if (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) #include "tensorflow/core/platform/rocm.h" #endif @@ -1315,8 +1313,8 @@ struct SparseSegmentGradV2Functor { const GPUDevice& device = context->eigen_gpu_device(); Toffsets num_unique = (*last_idx_host.data()) + 1; - se::gpu::ScopedActivateContext scoped_activation{ - context->op_device_context()->stream()->parent()}; + std::unique_ptr scoped_activation = + context->op_device_context()->stream()->parent()->Activate(); TensorShape output_shape = dense_output_shape; OP_REQUIRES_OK_ASYNC(context, diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl.h b/tensorflow/core/kernels/segment_reduction_ops_impl.h index 0ec8be88961332..658d0d161ab3ee 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_impl.h +++ b/tensorflow/core/kernels/segment_reduction_ops_impl.h @@ -60,14 +60,11 @@ limitations under the License. #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #if GOOGLE_CUDA -#include "xla/stream_executor/gpu/scoped_activate_context.h" #include "tensorflow/core/util/gpu_solvers.h" -using stream_executor::gpu::ScopedActivateContext; #elif TENSORFLOW_USE_ROCM #include "tensorflow/core/platform/rocm.h" #include "tensorflow/core/util/gpu_solvers.h" -using stream_executor::gpu::ScopedActivateContext; #endif // GOOGLE_CUDA namespace tensorflow { @@ -77,18 +74,18 @@ typedef Eigen::GpuDevice GPUDevice; namespace internal { -Status ValidateSegmentReduction(OpKernelContext* c, const Tensor& input, - const Tensor& segment_ids); -Status ValidateUnsortedSegmentReduction(OpKernel* op_kernel, - OpKernelContext* context, - const Tensor& data, - const Tensor& segment_ids, - const Tensor& num_segments); -Status ValidateSparseSegmentReduction(OpKernelContext* context, - const Tensor& input, - const Tensor& indices, - const Tensor& segment_ids, - bool has_num_segments); +absl::Status ValidateSegmentReduction(OpKernelContext* c, const Tensor& input, + const Tensor& segment_ids); +absl::Status ValidateUnsortedSegmentReduction(OpKernel* op_kernel, + OpKernelContext* context, + const Tensor& data, + const Tensor& segment_ids, + const Tensor& num_segments); +absl::Status ValidateSparseSegmentReduction(OpKernelContext* context, + const Tensor& input, + const Tensor& indices, + const Tensor& segment_ids, + bool has_num_segments); } // namespace internal // This operator handles reducing segments along the first dimension. @@ -288,7 +285,8 @@ class SegmentReductionGPUOp : public AsyncOpKernel { // Ensure that within the callback, the proper GPU settings are // configured. auto stream = context->op_device_context()->stream(); - ScopedActivateContext scoped_activation{stream->parent()}; + std::unique_ptr scoped_activation = + stream->parent()->Activate(); Index output_rows = *output_rows_host.data(); output_rows++; @@ -921,7 +919,8 @@ class SparseSegmentReductionOpBase // Ensure that within the callback, the proper GPU settings are // configured. auto stream = context->op_device_context()->stream(); - ScopedActivateContext scoped_activation{stream->parent()}; + std::unique_ptr scoped_activation = + stream->parent()->Activate(); SegmentId last_segment_id = *last_segment_id_host.data(); SegmentId output_rows = last_segment_id + 1; @@ -1364,9 +1363,9 @@ class SparseSegmentSqrtNGradOp template class SparseSegmentGradV2OpCommon { public: - Status operator()(OpKernelContext* context, - SparseSegmentReductionOperation operation, - typename AsyncOpKernel::DoneCallback done = nullptr) { + absl::Status operator()(OpKernelContext* context, + SparseSegmentReductionOperation operation, + typename AsyncOpKernel::DoneCallback done = nullptr) { const Tensor& input = context->input(0); const Tensor& indices = context->input(1); const Tensor& segment_ids = context->input(2); diff --git a/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc b/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc index 24d5f94b60f987..d8e669a15c329f 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc +++ b/tensorflow/core/kernels/segment_reduction_ops_impl_1.cc @@ -19,8 +19,9 @@ limitations under the License. namespace tensorflow { namespace internal { // Static routines not in the templated class to reduce code size -Status ValidateSegmentReduction(OpKernelContext* context, const Tensor& input, - const Tensor& segment_ids) { +absl::Status ValidateSegmentReduction(OpKernelContext* context, + const Tensor& input, + const Tensor& segment_ids) { if (!TensorShapeUtils::IsVectorOrHigher(input.shape())) { return errors::InvalidArgument("input must be at least rank 1"); } @@ -38,11 +39,11 @@ Status ValidateSegmentReduction(OpKernelContext* context, const Tensor& input, } // check routines not in the templated class to reduce code size -Status ValidateUnsortedSegmentReduction(OpKernel* op_kernel, - OpKernelContext* context, - const Tensor& data, - const Tensor& segment_ids, - const Tensor& num_segments) { +absl::Status ValidateUnsortedSegmentReduction(OpKernel* op_kernel, + OpKernelContext* context, + const Tensor& data, + const Tensor& segment_ids, + const Tensor& num_segments) { if (!TensorShapeUtils::IsScalar(num_segments.shape())) { return errors::InvalidArgument( "num_segments should be a scalar, not shape ", @@ -58,11 +59,11 @@ Status ValidateUnsortedSegmentReduction(OpKernel* op_kernel, return absl::OkStatus(); } -Status ValidateSparseSegmentReduction(OpKernelContext* context, - const Tensor& input, - const Tensor& indices, - const Tensor& segment_ids, - bool has_num_segments) { +absl::Status ValidateSparseSegmentReduction(OpKernelContext* context, + const Tensor& input, + const Tensor& indices, + const Tensor& segment_ids, + bool has_num_segments) { if (has_num_segments) { const Tensor& num_segments_t = context->input(3); if (!TensorShapeUtils::IsScalar(num_segments_t.shape())) { diff --git a/tensorflow/core/kernels/segment_reduction_ops_test.cc b/tensorflow/core/kernels/segment_reduction_ops_test.cc index 2763d97f5e3a2f..df6e3f5b5a9f91 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_test.cc +++ b/tensorflow/core/kernels/segment_reduction_ops_test.cc @@ -67,7 +67,7 @@ static void BM_UnsortedSegmentReduction(::testing::benchmark::State& state, .Input(FakeInput(DT_INT32)) .Input(FakeInput(DT_INT32)) .Finalize(&reduction_node_def)); - Status status; + absl::Status status; std::unique_ptr reduction_op( CreateOpKernel(DEVICE_CPU, device.get(), cpu_allocator(), reduction_node_def, TF_GRAPH_DEF_VERSION, &status)); @@ -131,7 +131,7 @@ static void BM_SegmentReduction(::testing::benchmark::State& state, .Input(FakeInput(DT_FLOAT)) .Input(FakeInput(DataTypeToEnum::v())) .Finalize(&reduction_node_def)); - Status status; + absl::Status status; std::unique_ptr reduction_op( CreateOpKernel(DEVICE_CPU, device.get(), cpu_allocator(), reduction_node_def, TF_GRAPH_DEF_VERSION, &status)); diff --git a/tensorflow/core/kernels/sendrecv_ops.cc b/tensorflow/core/kernels/sendrecv_ops.cc index d15cc7feda3bd0..4ae504a2d60244 100644 --- a/tensorflow/core/kernels/sendrecv_ops.cc +++ b/tensorflow/core/kernels/sendrecv_ops.cc @@ -175,7 +175,7 @@ string RecvOp::TraceString(const OpKernelContext& ctx, bool verbose) const { namespace { Rendezvous::DoneCallback make_recv_callback(OpKernelContext* ctx, AsyncOpKernel::DoneCallback done) { - return [ctx, done = std::move(done)](const Status& s, + return [ctx, done = std::move(done)](const absl::Status& s, const Rendezvous::Args& send_args, const Rendezvous::Args& recv_args, const Tensor& val, bool is_dead) { diff --git a/tensorflow/core/kernels/sendrecv_ops_test.cc b/tensorflow/core/kernels/sendrecv_ops_test.cc index 10f23e418d9a57..bce377702d0b85 100644 --- a/tensorflow/core/kernels/sendrecv_ops_test.cc +++ b/tensorflow/core/kernels/sendrecv_ops_test.cc @@ -27,8 +27,8 @@ namespace { // implementations, and to avoid the duplicate-send or duplicate-recv // errors that would arise from running either benchmark in a loop. class DummyRendezvous : public Rendezvous { - Status Send(const ParsedKey& key, const Args& args, const Tensor& val, - const bool is_dead) override { + absl::Status Send(const ParsedKey& key, const Args& args, const Tensor& val, + const bool is_dead) override { return absl::OkStatus(); } void RecvAsync(const ParsedKey& key, const Args& args, @@ -36,7 +36,7 @@ class DummyRendezvous : public Rendezvous { static Tensor* t = new Tensor(DT_FLOAT, TensorShape({0})); done(absl::OkStatus(), args, args, *t, false); } - void StartAbort(const Status& status) override {} + void StartAbort(const absl::Status& status) override {} }; static Graph* Send() { diff --git a/tensorflow/core/kernels/serialize_sparse_op.cc b/tensorflow/core/kernels/serialize_sparse_op.cc index f55c5ecdb37e84..169b3f3a291030 100644 --- a/tensorflow/core/kernels/serialize_sparse_op.cc +++ b/tensorflow/core/kernels/serialize_sparse_op.cc @@ -51,8 +51,8 @@ class SerializeSparseOp : public OpKernel { bool IsExpensive() override; - Status Initialize(Tensor* result); - Status Serialize(const Tensor& input, T* result); + absl::Status Initialize(Tensor* result); + absl::Status Serialize(const Tensor& input, T* result); void Compute(OpKernelContext* context) override { const Tensor* input_indices; @@ -105,14 +105,14 @@ bool SerializeSparseOp::IsExpensive() { } template <> -Status SerializeSparseOp::Initialize(Tensor* result) { +absl::Status SerializeSparseOp::Initialize(Tensor* result) { *result = Tensor(DT_STRING, TensorShape({3})); return absl::OkStatus(); } template <> -Status SerializeSparseOp::Serialize(const Tensor& input, - tstring* result) { +absl::Status SerializeSparseOp::Serialize(const Tensor& input, + tstring* result) { TensorProto proto; input.AsProtoTensorContent(&proto); *result = proto.SerializeAsString(); @@ -125,14 +125,14 @@ REGISTER_KERNEL_BUILDER(Name("SerializeSparse") SerializeSparseOp); template <> -Status SerializeSparseOp::Initialize(Tensor* result) { +absl::Status SerializeSparseOp::Initialize(Tensor* result) { *result = Tensor(DT_VARIANT, TensorShape({3})); return absl::OkStatus(); } template <> -Status SerializeSparseOp::Serialize(const Tensor& input, - Variant* result) { +absl::Status SerializeSparseOp::Serialize(const Tensor& input, + Variant* result) { *result = input; return absl::OkStatus(); } @@ -147,9 +147,9 @@ struct SerializeGroups {}; template struct SerializeGroups { - Status operator()(sparse::GroupIterable* minibatch, - const Tensor& output_shape, int64_t N, int rank, - Tensor* serialized_sparse) { + absl::Status operator()(sparse::GroupIterable* minibatch, + const Tensor& output_shape, int64_t N, int rank, + Tensor* serialized_sparse) { auto serialized_sparse_t = serialized_sparse->matrix(); int64_t last_nonempty_group = -1; @@ -251,9 +251,9 @@ void CopyValues(const Eigen::half* src, Eigen::half* dest, template struct SerializeGroups { - Status operator()(sparse::GroupIterable* minibatch, - const Tensor& output_shape, int64_t N, int rank, - Tensor* serialized_sparse) { + absl::Status operator()(sparse::GroupIterable* minibatch, + const Tensor& output_shape, int64_t N, int rank, + Tensor* serialized_sparse) { auto serialized_sparse_t = serialized_sparse->template matrix(); int64_t last_nonempty_group = -1; diff --git a/tensorflow/core/kernels/set_kernels.cc b/tensorflow/core/kernels/set_kernels.cc index 8ba24346ab1cb5..c91aff5a0ae5f3 100644 --- a/tensorflow/core/kernels/set_kernels.cc +++ b/tensorflow/core/kernels/set_kernels.cc @@ -52,7 +52,8 @@ void CheckRankAtLeast2(OpKernelContext* ctx, const TensorShape& shape) { } // Return group shape, which is the 1st n-1 dimensions of shape. -Status GroupShape(const VarDimArray& input_shape, ShapeArray* grouped_shape) { +absl::Status GroupShape(const VarDimArray& input_shape, + ShapeArray* grouped_shape) { if (input_shape.size() < 2) { // TODO(irving): Why can't 2 be 1 here? return errors::InvalidArgument("Shape [", absl::StrJoin(input_shape, ","), @@ -65,9 +66,10 @@ Status GroupShape(const VarDimArray& input_shape, ShapeArray* grouped_shape) { // Build `SparseTensor` from indices, values, and shape in inputs // [base_index, base_index + 3), and validate its rank and indices. -Status SparseTensorFromContext(OpKernelContext* ctx, const int32_t base_index, - const bool validate_indices, - sparse::SparseTensor* tensor) { +absl::Status SparseTensorFromContext(OpKernelContext* ctx, + const int32_t base_index, + const bool validate_indices, + sparse::SparseTensor* tensor) { // Assume row-major order. TensorShape shape; const Tensor& shape_tensor = ctx->input(base_index + 2); @@ -80,7 +82,7 @@ Status SparseTensorFromContext(OpKernelContext* ctx, const int32_t base_index, std::vector order(shape.dims()); std::iota(order.begin(), order.end(), 0); - Status status = sparse::SparseTensor::Create( + absl::Status status = sparse::SparseTensor::Create( ctx->input(base_index), ctx->input(base_index + 1), shape, order, tensor); if (!validate_indices || !status.ok()) return status; @@ -419,7 +421,7 @@ void SetOperationOp::ApplySetOperation(const absl::flat_hash_set& set1, } // Validate shapes have the same dimensions. -Status CheckShapesMatch(VarDimArray shape1, VarDimArray shape2) { +absl::Status CheckShapesMatch(VarDimArray shape1, VarDimArray shape2) { if (shape1 != shape2) { return errors::InvalidArgument("Mismatched shapes [", absl::StrJoin(shape1, ","), "] vs [", @@ -430,8 +432,8 @@ Status CheckShapesMatch(VarDimArray shape1, VarDimArray shape2) { // Validate ranks are the same, and all but last dimension are the same. // Return GroupShape. -Status GroupShapeFromInputs(VarDimArray shape1, VarDimArray shape2, - ShapeArray* group_shape) { +absl::Status GroupShapeFromInputs(VarDimArray shape1, VarDimArray shape2, + ShapeArray* group_shape) { ShapeArray group_shape_1; TF_RETURN_IF_ERROR(GroupShape(shape1, &group_shape_1)); ShapeArray group_shape_2; diff --git a/tensorflow/core/kernels/shape_ops.h b/tensorflow/core/kernels/shape_ops.h index dcddddc5e38686..d9c64c76ba57f4 100644 --- a/tensorflow/core/kernels/shape_ops.h +++ b/tensorflow/core/kernels/shape_ops.h @@ -31,8 +31,8 @@ limitations under the License. namespace tensorflow { namespace shape_op_helpers { -inline Status GetShape(OpKernelContext* ctx, int input_index, - TensorShape* shape) { +inline absl::Status GetShape(OpKernelContext* ctx, int input_index, + TensorShape* shape) { *shape = ctx->input(input_index).shape(); return absl::OkStatus(); } diff --git a/tensorflow/core/kernels/shuffle_common.h b/tensorflow/core/kernels/shuffle_common.h index 0fd93bdfca9573..0eea7fd43a335e 100644 --- a/tensorflow/core/kernels/shuffle_common.h +++ b/tensorflow/core/kernels/shuffle_common.h @@ -61,9 +61,9 @@ static void IndexedShuffle(const int64_t size, const InT& input_mat, } template -Status RandomShuffle(OpKernelContext* context, const Tensor& input, - int output_idx, - std::function get_rng) { +absl::Status RandomShuffle( + OpKernelContext* context, const Tensor& input, int output_idx, + std::function get_rng) { if (input.NumElements() <= 1 || input.dim_size(0) <= 1) { // No shuffling is required, so copy input directly to output context->set_output(output_idx, input); diff --git a/tensorflow/core/kernels/smooth-hinge-loss.h b/tensorflow/core/kernels/smooth-hinge-loss.h index f1019b7c53cb7c..8dc2c8068bf66a 100644 --- a/tensorflow/core/kernels/smooth-hinge-loss.h +++ b/tensorflow/core/kernels/smooth-hinge-loss.h @@ -75,7 +75,7 @@ class SmoothHingeLossUpdater : public DualLossUpdater { // Converts binary example labels from 0.0 or 1.0 to -1.0 or 1.0 respectively // as expected by smooth hinge loss. - Status ConvertLabel(float* const example_label) const final { + absl::Status ConvertLabel(float* const example_label) const final { if (*example_label == 0.0) { *example_label = -1; return absl::OkStatus(); diff --git a/tensorflow/core/kernels/spacetobatch_functor.cc b/tensorflow/core/kernels/spacetobatch_functor.cc index bdc2f4680927f2..6e5319b94bbd06 100644 --- a/tensorflow/core/kernels/spacetobatch_functor.cc +++ b/tensorflow/core/kernels/spacetobatch_functor.cc @@ -90,7 +90,7 @@ template struct SpaceToBatchFunctor { using SpaceT = typename std::conditional::type; using BatchT = typename std::conditional::type; - Status operator()( + absl::Status operator()( const CPUDevice& d, typename TTypes::Tensor space_tensor, const int64_t block_shape_tensor[NUM_BLOCK_DIMS], diff --git a/tensorflow/core/kernels/spacetobatch_functor.h b/tensorflow/core/kernels/spacetobatch_functor.h index 4804703e25f018..7838b5e3dc55bf 100644 --- a/tensorflow/core/kernels/spacetobatch_functor.h +++ b/tensorflow/core/kernels/spacetobatch_functor.h @@ -100,7 +100,7 @@ struct SpaceToBatchFunctor { // then this is the input to the conversion. // // The caller must ensure that the dimensions of the tensors are correct. - Status operator()( + absl::Status operator()( const Device& d, typename TTypes::Tensor space_tensor, const int64_t block_shape[NUM_BLOCK_DIMS], diff --git a/tensorflow/core/kernels/spacetobatch_op.cc b/tensorflow/core/kernels/spacetobatch_op.cc index de5f9728592872..3495bf2836013e 100644 --- a/tensorflow/core/kernels/spacetobatch_op.cc +++ b/tensorflow/core/kernels/spacetobatch_op.cc @@ -42,10 +42,10 @@ typedef Eigen::GpuDevice GPUDevice; namespace { template -Status SpaceToBatchOpCompute(OpKernelContext* context, - const Tensor& orig_input_tensor, - const Tensor& orig_block_shape, - const Tensor& orig_paddings) { +absl::Status SpaceToBatchOpCompute(OpKernelContext* context, + const Tensor& orig_input_tensor, + const Tensor& orig_block_shape, + const Tensor& orig_paddings) { const int input_dims = orig_input_tensor.dims(); if (!TensorShapeUtils::IsVector(orig_block_shape.shape())) { return errors::InvalidArgument("block_shape rank should be 1 instead of ", diff --git a/tensorflow/core/kernels/sparse/BUILD b/tensorflow/core/kernels/sparse/BUILD index 814eb66a171f77..bf1f38975c1733 100644 --- a/tensorflow/core/kernels/sparse/BUILD +++ b/tensorflow/core/kernels/sparse/BUILD @@ -58,7 +58,6 @@ tf_kernel_library( "zeros_op.h", ], gpu_deps = [ - "@local_xla//xla/stream_executor/gpu:scoped_activate_context", "//tensorflow/core/kernels:gpu_device_array", ], gpu_srcs = [ diff --git a/tensorflow/core/kernels/sparse/add_op.cc b/tensorflow/core/kernels/sparse/add_op.cc index e27de2a1782b91..c454241c1574c2 100644 --- a/tensorflow/core/kernels/sparse/add_op.cc +++ b/tensorflow/core/kernels/sparse/add_op.cc @@ -49,8 +49,8 @@ class CSRSparseMatrixAddFunctor { const T beta) : ctx_(ctx), alpha_(alpha), beta_(beta) {} - Status operator()(const CSRSparseMatrix& a, const CSRSparseMatrix& b, - CSRSparseMatrix* c) { + absl::Status operator()(const CSRSparseMatrix& a, const CSRSparseMatrix& b, + CSRSparseMatrix* c) { TensorShape a_tensor_shape; TensorShape b_tensor_shape; TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape( diff --git a/tensorflow/core/kernels/sparse/conj_op.cc b/tensorflow/core/kernels/sparse/conj_op.cc index 0436ea0a85f889..7521e51e6a6bcb 100644 --- a/tensorflow/core/kernels/sparse/conj_op.cc +++ b/tensorflow/core/kernels/sparse/conj_op.cc @@ -47,7 +47,7 @@ class CSRSparseMatrixConjFunctor { public: explicit CSRSparseMatrixConjFunctor(OpKernelContext* ctx) : ctx_(ctx) {} - Status operator()(const CSRSparseMatrix& a, CSRSparseMatrix* b) { + absl::Status operator()(const CSRSparseMatrix& a, CSRSparseMatrix* b) { const int total_nnz = a.total_nnz(); Tensor b_values_t; TF_RETURN_IF_ERROR(ctx_->allocate_temp( diff --git a/tensorflow/core/kernels/sparse/csr_sparse_matrix_to_sparse_tensor_op.cc b/tensorflow/core/kernels/sparse/csr_sparse_matrix_to_sparse_tensor_op.cc index ff3543f4602c77..403af12bb8fb52 100644 --- a/tensorflow/core/kernels/sparse/csr_sparse_matrix_to_sparse_tensor_op.cc +++ b/tensorflow/core/kernels/sparse/csr_sparse_matrix_to_sparse_tensor_op.cc @@ -43,8 +43,8 @@ using CPUDevice = Eigen::ThreadPoolDevice; using GPUDevice = Eigen::GpuDevice; // Validate that CSR SparseMatrix has the expected dtype and rank 2 or 3. -Status ValidateCSRSparseMatrix(const CSRSparseMatrix& csr_sparse_matrix, - DataType expected_dtype) { +absl::Status ValidateCSRSparseMatrix(const CSRSparseMatrix& csr_sparse_matrix, + DataType expected_dtype) { if (csr_sparse_matrix.dtype() != expected_dtype) { return errors::InvalidArgument( "Expected a CSRSparseMatrix of type ", DataTypeString(expected_dtype), diff --git a/tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc b/tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc index 4109fe3de8017f..6e635d140ad7df 100644 --- a/tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc +++ b/tensorflow/core/kernels/sparse/dense_to_csr_sparse_matrix_op.cc @@ -34,11 +34,9 @@ limitations under the License. #include "tensorflow/core/kernels/sparse/sparse_matrix.h" #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/stream_executor/gpu/scoped_activate_context.h" #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/util/cuda_sparse.h" #include "tensorflow/core/util/gpu_solvers.h" -using ::stream_executor::gpu::ScopedActivateContext; #endif namespace tensorflow { @@ -222,7 +220,8 @@ class DenseToCSRSparseMatrixGPUOp : public AsyncOpKernel { { // Ensure that within the callback, the proper GPU settings are // configured. - ScopedActivateContext scoped_activation{stream->parent()}; + std::unique_ptr scoped_activation = + stream->parent()->Activate(); // Extract out the values. Tensor temp_values_t; @@ -337,7 +336,7 @@ class DenseToCSRSparseMatrixGPUOp : public AsyncOpKernel { c, c->allocate_output(0, TensorShape({}), &matrix_t, cpu_alloc), done); matrix_t->scalar()() = std::move(matrix); - } // Release ScopedActivateContext to prevent deadlock when done + } // Release ActivateContext to prevent deadlock when done // inlines another Op kernel, which may assume the original cuda // Context. diff --git a/tensorflow/core/kernels/sparse/kernels.cc b/tensorflow/core/kernels/sparse/kernels.cc index bbb96743d0be21..9ef6e4e5f11c78 100644 --- a/tensorflow/core/kernels/sparse/kernels.cc +++ b/tensorflow/core/kernels/sparse/kernels.cc @@ -27,7 +27,7 @@ limitations under the License. namespace tensorflow { namespace functor { -Status SparseTensorToCSRSparseMatrixCPUFunctor::operator()( +absl::Status SparseTensorToCSRSparseMatrixCPUFunctor::operator()( int64_t batch_size, int num_rows, int num_cols, TTypes::ConstMatrix indices, TTypes::Vec batch_ptr, TTypes::Vec csr_row_ptr, TTypes::Vec csr_col_ind) { diff --git a/tensorflow/core/kernels/sparse/kernels.h b/tensorflow/core/kernels/sparse/kernels.h index c609845cadf4fa..11dccc45a5af8f 100644 --- a/tensorflow/core/kernels/sparse/kernels.h +++ b/tensorflow/core/kernels/sparse/kernels.h @@ -36,8 +36,9 @@ namespace functor { // nnz_per_batch.dimension(0) == B template struct CalculateNNZPerBatchMatrixFromIndices { - Status operator()(OpKernelContext* c, TTypes::ConstMatrix indices, - TTypes::Vec nnz_per_batch); + absl::Status operator()(OpKernelContext* c, + TTypes::ConstMatrix indices, + TTypes::Vec nnz_per_batch); }; // Split a subset of a SparseTensors' indices into two vectors: @@ -82,12 +83,12 @@ struct SparseTensorToCOOSparseMatrix { // template struct COOSparseMatrixToSparseTensor { - Status operator()(OpKernelContext* c, - TTypes::ConstVec host_dense_shape, - TTypes::ConstVec host_batch_ptrs, - TTypes::Vec coo_row_ind, - TTypes::ConstVec coo_col_ind, - TTypes::Matrix indices); + absl::Status operator()(OpKernelContext* c, + TTypes::ConstVec host_dense_shape, + TTypes::ConstVec host_batch_ptrs, + TTypes::Vec coo_row_ind, + TTypes::ConstVec coo_col_ind, + TTypes::Matrix indices); }; // Convert a vector of coo row indices to csr row pointers. @@ -99,9 +100,9 @@ struct COOSparseMatrixToSparseTensor { // template struct COOSparseMatrixToCSRSparseMatrix { - Status operator()(OpKernelContext* c, const int rows, const int cols, - TTypes::UnalignedVec coo_row_ind, - TTypes::UnalignedVec csr_row_ptr); + absl::Status operator()(OpKernelContext* c, const int rows, const int cols, + TTypes::UnalignedVec coo_row_ind, + TTypes::UnalignedVec csr_row_ptr); }; // Convert a matrix of (batched) coo row and column indices to CSR SparseMatrix @@ -119,11 +120,11 @@ struct COOSparseMatrixToCSRSparseMatrix { // Also csr_row_ptr should be initially filled with zeros. // struct SparseTensorToCSRSparseMatrixCPUFunctor { - Status operator()(int64_t batch_size, int num_rows, int num_cols, - TTypes::ConstMatrix indices, - TTypes::Vec batch_ptr, - TTypes::Vec csr_row_ptr, - TTypes::Vec csr_col_ind); + absl::Status operator()(int64_t batch_size, int num_rows, int num_cols, + TTypes::ConstMatrix indices, + TTypes::Vec batch_ptr, + TTypes::Vec csr_row_ptr, + TTypes::Vec csr_col_ind); }; // Convert a vector of csr row pointers to coo row indices. @@ -135,9 +136,9 @@ struct SparseTensorToCSRSparseMatrixCPUFunctor { // template struct CSRSparseMatrixToCOOSparseMatrix { - Status operator()(OpKernelContext* c, - TTypes::UnalignedConstVec csr_row_ptr, - TTypes::UnalignedVec coo_row_ind); + absl::Status operator()(OpKernelContext* c, + TTypes::UnalignedConstVec csr_row_ptr, + TTypes::UnalignedVec coo_row_ind); }; // Calculates C = matmul(A, B) or C = matmul(A, B)^T, where A is in CSR format @@ -145,9 +146,9 @@ struct CSRSparseMatrixToCOOSparseMatrix { template struct CSRSparseMatrixMatMul { explicit CSRSparseMatrixMatMul(const bool transpose_output); - Status Compute(OpKernelContext* ctx, const ConstCSRComponent& a, - typename TTypes::ConstMatrix b, - typename TTypes::Matrix c); + absl::Status Compute(OpKernelContext* ctx, const ConstCSRComponent& a, + typename TTypes::ConstMatrix b, + typename TTypes::Matrix c); }; // Calculates y = A * x, y = A^T * x, or y = A^H * x, where A is in CSR format @@ -155,8 +156,8 @@ struct CSRSparseMatrixMatMul { template class CSRSparseMatrixMatVec { CSRSparseMatrixMatVec(bool transpose_a, bool adjoint_a); - Status Compute(OpKernelContext* ctx, const ConstCSRComponent& a, - const T* x, T* y); + absl::Status Compute(OpKernelContext* ctx, const ConstCSRComponent& a, + const T* x, T* y); }; // Calculates C = functor(A, B) where A and B are CSR and C is CSR @@ -165,20 +166,20 @@ template struct CSRStructureModifyingFunctor { virtual ~CSRStructureModifyingFunctor() {} - virtual Status Initialize() = 0; + virtual absl::Status Initialize() = 0; - virtual Status GetWorkspaceSize(const ConstCSRComponent& a, - const ConstCSRComponent& b, - size_t* bufferSize) = 0; + virtual absl::Status GetWorkspaceSize(const ConstCSRComponent& a, + const ConstCSRComponent& b, + size_t* bufferSize) = 0; - virtual Status GetOutputStructure(const ConstCSRComponent& a, - const ConstCSRComponent& b, - TTypes::UnalignedVec c_row_ptr, - int* output_nnz, void* workspace) = 0; + virtual absl::Status GetOutputStructure(const ConstCSRComponent& a, + const ConstCSRComponent& b, + TTypes::UnalignedVec c_row_ptr, + int* output_nnz, void* workspace) = 0; - virtual Status Compute(const ConstCSRComponent& a, - const ConstCSRComponent& b, CSRComponent* c, - void* workspace) = 0; + virtual absl::Status Compute(const ConstCSRComponent& a, + const ConstCSRComponent& b, + CSRComponent* c, void* workspace) = 0; }; // Calculates C = alpha * A + beta * B, where A and B are in CSR @@ -200,31 +201,31 @@ struct CSRSparseSparseMatrixMatMul // Calculates Y = transpose(X) where X and Y are CSR format components. template struct CSRSparseMatrixTransposeComponent { - Status operator()(OpKernelContext* ctx, const ConstCSRComponent& x, - CSRComponent* y); + absl::Status operator()(OpKernelContext* ctx, const ConstCSRComponent& x, + CSRComponent* y); }; // Calculates Y = transpose(X) where X and Y are in CSR format. template struct CSRSparseMatrixTranspose { - Status operator()(OpKernelContext* ctx, bool conjugate, - const CSRSparseMatrix& input_matrix, - CSRSparseMatrix* output_matrix); + absl::Status operator()(OpKernelContext* ctx, bool conjugate, + const CSRSparseMatrix& input_matrix, + CSRSparseMatrix* output_matrix); }; // Calculates Y = softmax(X) where X and Y are in CSR format; // missing coefficients in X are treates as -inf (logits of 0 probability). template struct CSRSparseMatrixSoftmax { - Status operator()(OpKernelContext* ctx, const CSRSparseMatrix& logits, - typename TTypes::Vec softmax_values); + absl::Status operator()(OpKernelContext* ctx, const CSRSparseMatrix& logits, + typename TTypes::Vec softmax_values); }; template struct CSRSparseMatrixSoftmaxGrad { - Status operator()(OpKernelContext* ctx, const CSRSparseMatrix& softmax, - const CSRSparseMatrix& grad_softmax, - typename TTypes::Vec gradient_values); + absl::Status operator()(OpKernelContext* ctx, const CSRSparseMatrix& softmax, + const CSRSparseMatrix& grad_softmax, + typename TTypes::Vec gradient_values); }; template @@ -232,8 +233,8 @@ class CSRSparseMatrixMulScalar { public: explicit CSRSparseMatrixMulScalar() {} - Status Compute(OpKernelContext* ctx, const CSRSparseMatrix& a, - typename TTypes::ConstScalar b, CSRSparseMatrix* c); + absl::Status Compute(OpKernelContext* ctx, const CSRSparseMatrix& a, + typename TTypes::ConstScalar b, CSRSparseMatrix* c); }; template @@ -241,8 +242,8 @@ class CSRSparseMatrixBatchMulVec { public: explicit CSRSparseMatrixBatchMulVec() {} - Status Compute(OpKernelContext* ctx, const CSRSparseMatrix& a, - typename TTypes::ConstFlat b, CSRSparseMatrix* c); + absl::Status Compute(OpKernelContext* ctx, const CSRSparseMatrix& a, + typename TTypes::ConstFlat b, CSRSparseMatrix* c); }; } // namespace functor diff --git a/tensorflow/core/kernels/sparse/mat_mul_op.h b/tensorflow/core/kernels/sparse/mat_mul_op.h index 37043f1adf11f1..3e55cfbc38f201 100644 --- a/tensorflow/core/kernels/sparse/mat_mul_op.h +++ b/tensorflow/core/kernels/sparse/mat_mul_op.h @@ -110,9 +110,9 @@ class CSRMatMulOp : public OpKernel { ~CSRMatMulOp() override {} - Status ValidateInputs(const CSRSparseMatrix& sparse_matrix_a, - const Tensor& dense_tensor_b, int* rank, - int64_t* batch_size) { + absl::Status ValidateInputs(const CSRSparseMatrix& sparse_matrix_a, + const Tensor& dense_tensor_b, int* rank, + int64_t* batch_size) { if (sparse_matrix_a.dtype() != dense_tensor_b.dtype()) { return absl::InvalidArgumentError(absl::StrCat( "Input types don't match. a.dtype == ", @@ -243,11 +243,12 @@ class CSRMatMulCPUOp : public CSRMatMulOp { // transpose_output is True, allocates a temporary buffer with the transposed // output. 'matmul_result' points to either output or output_transposed, based // on whether transpose_output is True. - Status AllocateOutput(OpKernelContext* ctx, const int32_t rank, - const int64_t batch_size, const int64_t num_rows, - const int64_t num_cols, const bool transpose_output, - Tensor** output, Tensor* output_transposed, - Tensor** matmul_result) { + absl::Status AllocateOutput(OpKernelContext* ctx, const int32_t rank, + const int64_t batch_size, const int64_t num_rows, + const int64_t num_cols, + const bool transpose_output, Tensor** output, + Tensor* output_transposed, + Tensor** matmul_result) { TensorShape output_shape; if (rank == 3) { TF_RETURN_IF_ERROR(output_shape.AddDimWithStatus(batch_size)); @@ -468,8 +469,9 @@ class CSRMatMulCPUOp : public CSRMatMulOp { // Transposes (and optionally, conjugates) a given Tensor. Also allocates the // required memory for the output Tensor. - Status TransposeAndConjugateTensor(OpKernelContext* ctx, const Tensor& input, - bool conjugate, Tensor* output) { + absl::Status TransposeAndConjugateTensor(OpKernelContext* ctx, + const Tensor& input, bool conjugate, + Tensor* output) { TensorShape transposed_shape = input.shape(); transposed_shape.set_dim(input.dims() - 1, input.dim_size(input.dims() - 2)); @@ -482,9 +484,10 @@ class CSRMatMulCPUOp : public CSRMatMulOp { // Transposes (and optionally, conjugates) a given Tensor. The output should // be already allocated. - Status TransposeAndConjugateAllocatedTensor(OpKernelContext* ctx, - const Tensor& input, - bool conjugate, Tensor* output) { + absl::Status TransposeAndConjugateAllocatedTensor(OpKernelContext* ctx, + const Tensor& input, + bool conjugate, + Tensor* output) { if (conjugate) { TF_RETURN_IF_ERROR(DoConjugateMatrixTranspose( ctx->eigen_device(), input, output)); @@ -595,7 +598,7 @@ class CSRMatMulGPUOp : public CSRMatMulOp { a_dense_shape_comp}; const T* b_i = b_base_ptr + i * b_slice_size; T* c_i = &c_t->template flat()(i * c_slice_size); - Status s = csr_spmv.Compute(ctx, a_comp, b_i, c_i); + absl::Status s = csr_spmv.Compute(ctx, a_comp, b_i, c_i); OP_REQUIRES_OK(ctx, s); } if (conjugate_output) { @@ -690,7 +693,7 @@ class CSRMatMulGPUOp : public CSRMatMulOp { {c_matrix_lhs, c_matrix_rhs}); ConstCSRComponent a_comp{a_row_ptr, a_col_ind, a_values, a_input_dense_shape_comp}; - Status s = csr_spmmadd.Compute(ctx, a_comp, b_i, c_mat_col_major_i); + absl::Status s = csr_spmmadd.Compute(ctx, a_comp, b_i, c_mat_col_major_i); OP_REQUIRES_OK(ctx, s); } diff --git a/tensorflow/core/kernels/sparse/sparse_cholesky_op.cc b/tensorflow/core/kernels/sparse/sparse_cholesky_op.cc index 90f4fbde158748..0453a1b97919a5 100644 --- a/tensorflow/core/kernels/sparse/sparse_cholesky_op.cc +++ b/tensorflow/core/kernels/sparse/sparse_cholesky_op.cc @@ -228,9 +228,9 @@ class CSRSparseCholeskyCPUOp : public OpKernel { } private: - Status ValidateInputs(const CSRSparseMatrix& sparse_matrix, - const Tensor& permutation_indices, int* batch_size, - int64_t* num_rows) { + absl::Status ValidateInputs(const CSRSparseMatrix& sparse_matrix, + const Tensor& permutation_indices, + int* batch_size, int64_t* num_rows) { if (sparse_matrix.dtype() != DataTypeToEnum::value) return errors::InvalidArgument( "Asked for a CSRSparseMatrix of type ", diff --git a/tensorflow/core/kernels/sparse/sparse_matrix.h b/tensorflow/core/kernels/sparse/sparse_matrix.h index 95b93443863ecc..8e5ff45f57d30a 100644 --- a/tensorflow/core/kernels/sparse/sparse_matrix.h +++ b/tensorflow/core/kernels/sparse/sparse_matrix.h @@ -144,22 +144,21 @@ class CSRSparseMatrix { return *this; } - static Status CreateCSRSparseMatrix(DataType dtype, - const Tensor& dense_shape, // on host - const Tensor& batch_pointers, // on host - const Tensor& row_pointers, - const Tensor& col_indices, - const Tensor& values, - CSRSparseMatrix* matrix) { + static absl::Status CreateCSRSparseMatrix( + DataType dtype, + const Tensor& dense_shape, // on host + const Tensor& batch_pointers, // on host + const Tensor& row_pointers, const Tensor& col_indices, + const Tensor& values, CSRSparseMatrix* matrix) { *matrix = CSRSparseMatrix(dtype, dense_shape, batch_pointers, row_pointers, col_indices, values); - Status s = matrix->Validate(); + absl::Status s = matrix->Validate(); matrix->metadata_.validated = s.ok(); matrix->SetupVecs(); return s; } - Status Validate() const { + absl::Status Validate() const { return ValidateTypesAndShapes(metadata_.dtype, dense_shape_, batch_pointers_, row_pointers_, col_indices_, values_); @@ -349,8 +348,8 @@ class CSRSparseMatrix { Tensor values(p.tensors_[4]); // Check that the validated bool is consistent with the data. - Status s = ValidateTypesAndShapes(dtype, dense_shape, batch_pointers, - row_pointers, col_indices, values); + absl::Status s = ValidateTypesAndShapes(dtype, dense_shape, batch_pointers, + row_pointers, col_indices, values); if (s.ok() != validated) return false; // Save to this object. @@ -381,7 +380,7 @@ class CSRSparseMatrix { // This static method copies CSRSparseMatrices in all directions: // Host->Device, Device->Host, and Device->Device. - static Status DeviceCopy( + static absl::Status DeviceCopy( const CSRSparseMatrix& from, CSRSparseMatrix* to, const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) { VLOG(2) << "DeviceCopy from type: " << DataTypeString(from.dtype()) @@ -423,12 +422,12 @@ class CSRSparseMatrix { col_indices_vec_.reset(); } - static Status ValidateTypesAndShapes(DataType dtype, - const Tensor& dense_shape, - const Tensor& batch_pointers, - const Tensor& row_pointers, - const Tensor& col_indices, - const Tensor& values) { + static absl::Status ValidateTypesAndShapes(DataType dtype, + const Tensor& dense_shape, + const Tensor& batch_pointers, + const Tensor& row_pointers, + const Tensor& col_indices, + const Tensor& values) { // TODO(ebrevdo): Consider adding support for other floating point types // (namely, float16). if (dtype != DT_FLOAT && dtype != DT_DOUBLE && dtype != DT_COMPLEX64 && @@ -547,10 +546,10 @@ class CSRSparseMatrix { // where T depends on a.dtype(). T will be one of: float, double, // complex64, complex128. template class BinaryFunctor> -Status CSRSparseMatrixBinaryHelper(OpKernelContext* ctx, - const CSRSparseMatrix& a, - const CSRSparseMatrix& b, - CSRSparseMatrix* c) { +absl::Status CSRSparseMatrixBinaryHelper(OpKernelContext* ctx, + const CSRSparseMatrix& a, + const CSRSparseMatrix& b, + CSRSparseMatrix* c) { DataType dt = a.dtype(); if (dt != b.dtype()) { return errors::InvalidArgument( @@ -587,9 +586,9 @@ Status CSRSparseMatrixBinaryHelper(OpKernelContext* ctx, // where T depends on a.dtype(). T will be one of: float, double, // complex64, complex128. template class UnaryFunctor> -Status CSRSparseMatrixUnaryHelper(OpKernelContext* ctx, - const CSRSparseMatrix& a, - CSRSparseMatrix* b) { +absl::Status CSRSparseMatrixUnaryHelper(OpKernelContext* ctx, + const CSRSparseMatrix& a, + CSRSparseMatrix* b) { DataType dt = a.dtype(); switch (dt) { case DT_FLOAT: { @@ -632,8 +631,8 @@ struct CSRComponent { }; template -Status ExtractVariantFromInput(OpKernelContext* ctx, int index, - const T** value) { +absl::Status ExtractVariantFromInput(OpKernelContext* ctx, int index, + const T** value) { const Tensor& input_t = ctx->input(index); if (!TensorShapeUtils::IsScalar(input_t.shape())) { return errors::InvalidArgument( diff --git a/tensorflow/core/kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc b/tensorflow/core/kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc index cfee2de9f4fc3f..e93e2b0a018845 100644 --- a/tensorflow/core/kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc +++ b/tensorflow/core/kernels/sparse/sparse_tensor_to_csr_sparse_matrix_op.cc @@ -35,11 +35,9 @@ limitations under the License. #include "tensorflow/core/kernels/sparse_utils.h" #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/stream_executor/gpu/scoped_activate_context.h" #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/util/cuda_sparse.h" #include "tensorflow/core/util/gpu_solvers.h" -using ::stream_executor::gpu::ScopedActivateContext; #endif namespace tensorflow { @@ -234,7 +232,8 @@ class SparseTensorToCSRSparseMatrixGPUOp : public AsyncOpKernel { // Ensure that within the callback, the proper GPU settings are // configured. { - ScopedActivateContext scoped_activation{stream->parent()}; + std::unique_ptr scoped_activation = + stream->parent()->Activate(); Tensor batch_ptr_t(cpu_allocator(), DT_INT32, TensorShape({batch_size + 1})); @@ -326,7 +325,7 @@ class SparseTensorToCSRSparseMatrixGPUOp : public AsyncOpKernel { c, c->allocate_output(0, TensorShape({}), &matrix_t, cpu_alloc), done); matrix_t->scalar()() = std::move(matrix); - } // Release ScopedActivateContext to prevent deadlock when done + } // Release ActivateContext to prevent deadlock when done // inlines another Op kernel, which may assume the original cuda // Context. diff --git a/tensorflow/core/kernels/sparse/transpose_op.cc b/tensorflow/core/kernels/sparse/transpose_op.cc index 30a887749ac7b6..74e0b85f393e40 100644 --- a/tensorflow/core/kernels/sparse/transpose_op.cc +++ b/tensorflow/core/kernels/sparse/transpose_op.cc @@ -50,8 +50,8 @@ typedef Eigen::GpuDevice GPUDevice; namespace { template -Status ValidateTransposeInputs(const ConstCSRComponent& input, - const CSRComponent& output) { +absl::Status ValidateTransposeInputs(const ConstCSRComponent& input, + const CSRComponent& output) { const int rank = input.dense_shape_host.size(); const int64_t nnz = input.col_ind.size(); const int num_rows = input.row_ptr.size() - 1; @@ -144,7 +144,7 @@ REGISTER_TRANSPOSE(GPU, complex128) namespace functor { template -Status CSRSparseMatrixTranspose::operator()( +absl::Status CSRSparseMatrixTranspose::operator()( OpKernelContext* ctx, bool conjugate, const CSRSparseMatrix& input_matrix, CSRSparseMatrix* output_matrix) { const int rank = input_matrix.dims(); @@ -213,8 +213,9 @@ template struct CSRSparseMatrixTransposeComponent { using SparseMatrix = Eigen::SparseMatrix; - Status operator()(OpKernelContext* ctx, const ConstCSRComponent& input, - CSRComponent* output) { + absl::Status operator()(OpKernelContext* ctx, + const ConstCSRComponent& input, + CSRComponent* output) { TF_RETURN_IF_ERROR(ValidateTransposeInputs(input, *output)); const int rank = input.dense_shape_host.size(); diff --git a/tensorflow/core/kernels/sparse/zeros_op.cc b/tensorflow/core/kernels/sparse/zeros_op.cc index b09f20db9c121b..e71a4a16c0b01d 100644 --- a/tensorflow/core/kernels/sparse/zeros_op.cc +++ b/tensorflow/core/kernels/sparse/zeros_op.cc @@ -65,9 +65,9 @@ class CSRZerosOp : public OpKernel { namespace { template -Status CSRSparseMatrixZerosLikeHelper(OpKernelContext* ctx, - const CSRSparseMatrix& x, - CSRSparseMatrix* y) { +absl::Status CSRSparseMatrixZerosLikeHelper(OpKernelContext* ctx, + const CSRSparseMatrix& x, + CSRSparseMatrix* y) { functor::CSRSparseMatrixZeros csr_sparse_matrix_zeros; return csr_sparse_matrix_zeros(ctx, x.dtype(), x.dense_shape(), y); } diff --git a/tensorflow/core/kernels/sparse/zeros_op.h b/tensorflow/core/kernels/sparse/zeros_op.h index 8df31337110275..2a86089e04e62e 100644 --- a/tensorflow/core/kernels/sparse/zeros_op.h +++ b/tensorflow/core/kernels/sparse/zeros_op.h @@ -40,8 +40,9 @@ namespace functor { template struct CSRSparseMatrixZeros { - Status operator()(OpKernelContext* c, DataType dtype, - const Tensor& dense_shape_t, CSRSparseMatrix* matrix) { + absl::Status operator()(OpKernelContext* c, DataType dtype, + const Tensor& dense_shape_t, + CSRSparseMatrix* matrix) { auto dense_shape = dense_shape_t.vec(); const int rank = dense_shape.size(); if (!(rank == 2 || rank == 3)) { diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator.h b/tensorflow/core/kernels/sparse_conditional_accumulator.h index e41caf8e0e4a45..9d45d52b55c398 100644 --- a/tensorflow/core/kernels/sparse_conditional_accumulator.h +++ b/tensorflow/core/kernels/sparse_conditional_accumulator.h @@ -69,7 +69,7 @@ class SparseConditionalAccumulator Eigen::Unaligned> SliceConstT; - Status ValidateShape( + absl::Status ValidateShape( std::tuple* tensor, bool has_known_shape) TF_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) { const Tensor* tensor_idx = std::get<0>(*tensor); diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc index 7a915cded37527..75834698465b8d 100644 --- a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc +++ b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc @@ -43,7 +43,7 @@ class SparseConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp { // TODO(tanzheny): actually switch it to resource. You won't be able to use // it with cond2 otherwise. - Status CheckSignature(OpKernelContext* ctx) override { + absl::Status CheckSignature(OpKernelContext* ctx) override { TF_RETURN_IF_ERROR(ctx->MatchSignature({}, {DT_STRING_REF})); return absl::OkStatus(); } diff --git a/tensorflow/core/kernels/sparse_cross_op.cc b/tensorflow/core/kernels/sparse_cross_op.cc index 73cec0e1e8c734..1f10def306145d 100644 --- a/tensorflow/core/kernels/sparse_cross_op.cc +++ b/tensorflow/core/kernels/sparse_cross_op.cc @@ -475,11 +475,11 @@ int64_t CalculateBatchSize(const OpInputList& shapes_list_in, } // Validates input tensors. -Status ValidateInput(const OpInputList& indices_list_in, - const OpInputList& values_list_in, - const OpInputList& shapes_list_in, - const OpInputList& dense_list_in, - const DataType& internal_type) { +absl::Status ValidateInput(const OpInputList& indices_list_in, + const OpInputList& values_list_in, + const OpInputList& shapes_list_in, + const OpInputList& dense_list_in, + const DataType& internal_type) { const auto size = indices_list_in.size(); // Only perform internal_type check for SparseCrossOp. // Check if the internal_type is not invalid before doing so. @@ -705,7 +705,7 @@ GenerateKeyedColumnsFromInput(const OpInputList& indices_list_in, // It also output_start_indices which contains the start indices for each // input in the output SparseTensor. template -Status CreateOutputTensors( +absl::Status CreateOutputTensors( const std::vector>>& columns, int64_t batch_size, OpKernelContext* context, Tensor** indices_out, Tensor** values_out, Tensor** shape_out, @@ -911,7 +911,7 @@ class SparseCrossHashedOp : public OpKernel { const auto salt = salt_t->flat(); OP_REQUIRES_OK( context, salt.size() == 2 - ? Status() + ? absl::Status() : errors::InvalidArgument( "Input \"salt\" must have length 2 but has length ", salt.size())); diff --git a/tensorflow/core/kernels/sparse_reduce_op.cc b/tensorflow/core/kernels/sparse_reduce_op.cc index b54a0e49a46652..97dd91523ebc7f 100644 --- a/tensorflow/core/kernels/sparse_reduce_op.cc +++ b/tensorflow/core/kernels/sparse_reduce_op.cc @@ -111,7 +111,7 @@ absl::StatusOr SparseTensorReduceHelper(const SparseTensor &sp, return reduction; } -Status ValidateInputs(const Tensor *shape_t, const Tensor *reduction_axes_t) { +absl::Status ValidateInputs(const Tensor *shape_t, const Tensor *reduction_axes_t) { // indices and values are validated in SparseTensor ctor. if (!TensorShapeUtils::IsVector(shape_t->shape())) { return errors::InvalidArgument( diff --git a/tensorflow/core/kernels/sparse_slice_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_slice_op_gpu.cu.cc index 7745123be21d68..10363a97c4c597 100644 --- a/tensorflow/core/kernels/sparse_slice_op_gpu.cu.cc +++ b/tensorflow/core/kernels/sparse_slice_op_gpu.cu.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive -#include "xla/stream_executor/gpu/scoped_activate_context.h" #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_types.h" @@ -33,8 +32,6 @@ limitations under the License. #include "tensorflow/core/util/gpu_kernel_helper.h" #include "tensorflow/core/util/gpu_solvers.h" // For ScratchSpace -using stream_executor::gpu::ScopedActivateContext; - namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; @@ -203,7 +200,8 @@ struct SparseSliceFunctor { // Ensure that within the callback, the proper GPU settings are // configured. auto stream = context->op_device_context()->stream(); - std::optional scoped_activation{stream->parent()}; + std::unique_ptr scoped_activation = + stream->parent()->Activate(); int64_t output_nnz = *output_nnz_host.data(); Tensor* output_indices = nullptr; @@ -220,7 +218,7 @@ struct SparseSliceFunctor { T* output_values_ptr = output_values->vec().data(); if (output_nnz == 0) { - // Release ScopedActivateContext to prevent deadlock when done + // Release ActivateContext to prevent deadlock when done // inlines another Op kernel, which may assume the original cuda // Context. scoped_activation.reset(); @@ -241,7 +239,7 @@ struct SparseSliceFunctor { selected_nonzeros_ptr, output_indices_ptr, output_values_ptr), done); - // Release ScopedActivateContext to prevent deadlock when done + // Release ActivateContext to prevent deadlock when done // inlines another Op kernel, which may assume the original cuda // Context. scoped_activation.reset(); diff --git a/tensorflow/core/kernels/sparse_split_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_split_op_gpu.cu.cc index df423c5a94fc90..abf53592c11bd5 100644 --- a/tensorflow/core/kernels/sparse_split_op_gpu.cu.cc +++ b/tensorflow/core/kernels/sparse_split_op_gpu.cu.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define EIGEN_USE_GPU #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive -#include "xla/stream_executor/gpu/scoped_activate_context.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -29,7 +29,6 @@ limitations under the License. #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/util/gpu_kernel_helper.h" #include "tensorflow/core/util/gpu_solvers.h" // For ScratchSpace -using stream_executor::gpu::ScopedActivateContext; namespace tensorflow { @@ -302,7 +301,8 @@ struct SparseSplitFunctor { // Ensure that within the callback, the proper GPU settings are // configured. auto stream = context->op_device_context()->stream(); - ScopedActivateContext scoped_activation{stream->parent()}; + std::unique_ptr scoped_activation = + stream->parent()->Activate(); GpuDeviceArrayOnHost output_indices(context, num_split); GpuDeviceArrayOnHost output_values(context, num_split); @@ -325,7 +325,7 @@ struct SparseSplitFunctor { input_values_ptr, dense_shape.dim_size(axis), output_indices.data(), output_values.data()), done); - } // Release ScopedActivateContext to prevent deadlock when done + } // Release ActivateContext to prevent deadlock when done // inlines another Op kernel, which may assume the original cuda // Context. diff --git a/tensorflow/core/kernels/sparse_tensor_dense_add_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_add_op.cc index 1d8b3b0156c756..e07df7085c240b 100644 --- a/tensorflow/core/kernels/sparse_tensor_dense_add_op.cc +++ b/tensorflow/core/kernels/sparse_tensor_dense_add_op.cc @@ -33,8 +33,8 @@ typedef Eigen::ThreadPoolDevice CPUDevice; namespace { template -Status ValidateInputs(const Tensor *a_indices, const Tensor *a_values, - const Tensor *a_shape, const Tensor *b) { +absl::Status ValidateInputs(const Tensor *a_indices, const Tensor *a_values, + const Tensor *a_shape, const Tensor *b) { if (!TensorShapeUtils::IsMatrix(a_indices->shape())) { return errors::InvalidArgument( "Input a_indices should be a matrix but received shape: ", diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc index 0a2ffcf495b5e9..ccdd6598dcf997 100644 --- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc +++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc @@ -239,20 +239,20 @@ REGISTER_KERNELS_GPU(complex128); namespace functor { namespace { -Status KOutOfBoundsError(int64_t k, std::size_t i, int rhs_index_a, - std::size_t lhs_right) { +absl::Status KOutOfBoundsError(int64_t k, std::size_t i, int rhs_index_a, + std::size_t lhs_right) { return errors::InvalidArgument("k (", k, ") from index[", i, ",", rhs_index_a, "] out of bounds (>=", lhs_right, ")"); } -Status MOutOfBoundsError(int64_t m, std::size_t i, int lhs_index_a, - int64_t out_dim0) { +absl::Status MOutOfBoundsError(int64_t m, std::size_t i, int lhs_index_a, + int64_t out_dim0) { return errors::InvalidArgument("m (", m, ") from index[", i, ",", lhs_index_a, "] out of bounds (>=", out_dim0, ")"); } template -Status SparseTensorDenseMatMulImpl( +absl::Status SparseTensorDenseMatMulImpl( typename TTypes::Matrix out, typename TTypes::ConstMatrix a_indices, typename TTypes::ConstVec a_values, typename TTypes::ConstMatrix b) { @@ -326,10 +326,11 @@ Status SparseTensorDenseMatMulImpl( template struct SparseTensorDenseMatMulFunctor { - static Status Compute(OpKernelContext* ctx, typename TTypes::Matrix out, - typename TTypes::ConstMatrix a_indices, - typename TTypes::ConstVec a_values, - typename TTypes::ConstMatrix b) { + static absl::Status Compute(OpKernelContext* ctx, + typename TTypes::Matrix out, + typename TTypes::ConstMatrix a_indices, + typename TTypes::ConstVec a_values, + typename TTypes::ConstMatrix b) { using Tsum = typename SumType::type; Tensor temp_out_t; if (!std::is_same::value) { diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h index 3cab997c567e87..fef151ea7e61b4 100644 --- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h +++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h @@ -29,7 +29,7 @@ namespace functor { template struct SparseTensorDenseMatMulFunctor { - static EIGEN_ALWAYS_INLINE Status Compute( + static EIGEN_ALWAYS_INLINE absl::Status Compute( OpKernelContext* ctx, typename TTypes::Matrix out, typename TTypes::ConstMatrix a_indices, typename TTypes::ConstVec a_values, typename TTypes::ConstMatrix b); diff --git a/tensorflow/core/kernels/sparse_tensors_map_ops.cc b/tensorflow/core/kernels/sparse_tensors_map_ops.cc index 198372204eb85c..1e0635b4d38059 100644 --- a/tensorflow/core/kernels/sparse_tensors_map_ops.cc +++ b/tensorflow/core/kernels/sparse_tensors_map_ops.cc @@ -49,8 +49,8 @@ class SparseTensorsMap : public ResourceBase { absl::InlinedVector shape; } PersistentSparseTensor; - Status AddSparseTensor(OpKernelContext* ctx, const SparseTensor& sp, - int64_t* handle) { + absl::Status AddSparseTensor(OpKernelContext* ctx, const SparseTensor& sp, + int64_t* handle) { Tensor ix; TF_RETURN_IF_ERROR( ctx->allocate_temp(sp.indices().dtype(), sp.indices().shape(), &ix)); @@ -72,7 +72,7 @@ class SparseTensorsMap : public ResourceBase { return absl::OkStatus(); } - Status RetrieveAndClearSparseTensors( + absl::Status RetrieveAndClearSparseTensors( OpKernelContext* ctx, const TTypes::ConstVec& handles, std::vector* sparse_tensors) { sparse_tensors->clear(); @@ -113,7 +113,7 @@ class SparseTensorsMap : public ResourceBase { class SparseTensorAccessingOp : public OpKernel { public: - typedef std::function CreatorCallback; + typedef std::function CreatorCallback; explicit SparseTensorAccessingOp(OpKernelConstruction* context) : OpKernel(context), sparse_tensors_map_(nullptr) {} @@ -123,8 +123,8 @@ class SparseTensorAccessingOp : public OpKernel { if (sparse_tensors_map_) sparse_tensors_map_->Unref(); } - Status GetMap(OpKernelContext* ctx, bool is_writing, - SparseTensorsMap** sparse_tensors_map) { + absl::Status GetMap(OpKernelContext* ctx, bool is_writing, + SparseTensorsMap** sparse_tensors_map) { mutex_lock l(mu_); if (sparse_tensors_map_) { diff --git a/tensorflow/core/kernels/sparse_to_dense_op.cc b/tensorflow/core/kernels/sparse_to_dense_op.cc index cad0d090034038..255f9ed65c93ab 100644 --- a/tensorflow/core/kernels/sparse_to_dense_op.cc +++ b/tensorflow/core/kernels/sparse_to_dense_op.cc @@ -46,10 +46,10 @@ namespace tensorflow { namespace { -Status CheckSparseToDenseShapes(const Tensor& indices, - const Tensor& output_shape, - const Tensor& sparse_values, - const Tensor& default_value) { +absl::Status CheckSparseToDenseShapes(const Tensor& indices, + const Tensor& output_shape, + const Tensor& sparse_values, + const Tensor& default_value) { // sparse_indices if (indices.dims() > 2) { return errors::InvalidArgument( diff --git a/tensorflow/core/kernels/sparse_to_dense_op_gpu.cu.cc b/tensorflow/core/kernels/sparse_to_dense_op_gpu.cu.cc index 0998bcbc43dfc5..6491d83b71ff7a 100644 --- a/tensorflow/core/kernels/sparse_to_dense_op_gpu.cu.cc +++ b/tensorflow/core/kernels/sparse_to_dense_op_gpu.cu.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/core/kernels/sparse_to_dense_op_gpu.h" #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive -#include "xla/stream_executor/gpu/scoped_activate_context.h" #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -194,7 +193,8 @@ void LaunchSparseToDense::operator()( // Ensure that within the callback, the proper GPU settings are // configured. auto stream = c->op_device_context()->stream(); - se::gpu::ScopedActivateContext scoped_activation{stream->parent()}; + std::unique_ptr scoped_activation = + stream->parent()->Activate(); OP_REQUIRES_ASYNC( c, valid_status.valid == INT_MAX, @@ -225,7 +225,7 @@ void LaunchSparseToDense::operator()( shape.flat().data(), num_dims, dense_ptr), done); - } // Release ScopedActivateContext to prevent deadlock when done + } // Release ActivateContext to prevent deadlock when done // inlines another Op kernel, which may assume the original cuda Context. done(); diff --git a/tensorflow/core/kernels/sparse_to_dense_op_test.cc b/tensorflow/core/kernels/sparse_to_dense_op_test.cc index 4082c570e0d641..96bff3a46ae0bf 100644 --- a/tensorflow/core/kernels/sparse_to_dense_op_test.cc +++ b/tensorflow/core/kernels/sparse_to_dense_op_test.cc @@ -240,7 +240,7 @@ static void BM_SparseToDense(::testing::benchmark::State& state) { .Input(FakeInput(DT_FLOAT)) .Finalize(&sparse_node_def)); - Status status; + absl::Status status; std::unique_ptr op(CreateOpKernel(DEVICE_CPU, device.get(), cpu_allocator(), sparse_node_def, TF_GRAPH_DEF_VERSION, &status)); diff --git a/tensorflow/core/kernels/sparse_utils.cc b/tensorflow/core/kernels/sparse_utils.cc index d9a2850e596519..d6f70dadc6a031 100644 --- a/tensorflow/core/kernels/sparse_utils.cc +++ b/tensorflow/core/kernels/sparse_utils.cc @@ -148,8 +148,9 @@ namespace { // Ensures indices, values, shape are all of the proper ranks and are // compatible. -Status ValidateSparseTensorShape(const Tensor& indices, const Tensor& values, - const Tensor& shape) { +absl::Status ValidateSparseTensorShape(const Tensor& indices, + const Tensor& values, + const Tensor& shape) { // Indices must be a matrix, and values/shape must be a vector. if (!TensorShapeUtils::IsMatrix(indices.shape())) { return errors::InvalidArgument("Sparse indices must be rank 2 but is rank ", @@ -196,8 +197,8 @@ string CreateIndexString(const IndexTensor& indices, int64_t row) { // Ensures all sparse indices are within correct bounds. template -Status ValidateSparseTensorIndicesUnordered(const Tensor& indices, - const Tensor& shape) { +absl::Status ValidateSparseTensorIndicesUnordered(const Tensor& indices, + const Tensor& shape) { // Ensure no index is out-of-bounds. const auto indices_mat = indices.flat_inner_dims(); const auto shape_vec = shape.flat(); @@ -221,8 +222,8 @@ Status ValidateSparseTensorIndicesUnordered(const Tensor& indices, // Ensures all sparse indices are within correct bounds and are // lexicographically ordered. template -Status ValidateSparseTensorIndicesOrdered(const Tensor& indices, - const Tensor& shape) { +absl::Status ValidateSparseTensorIndicesOrdered(const Tensor& indices, + const Tensor& shape) { const auto indices_mat = indices.flat_inner_dims(); const auto shape_vec = shape.flat(); int64_t nnz = indices.dim_size(0); @@ -288,9 +289,9 @@ Status ValidateSparseTensorIndicesOrdered(const Tensor& indices, } // namespace template -Status ValidateSparseTensor(const Tensor& indices, const Tensor& values, - const Tensor& shape, - IndexValidation index_validation) { +absl::Status ValidateSparseTensor(const Tensor& indices, const Tensor& values, + const Tensor& shape, + IndexValidation index_validation) { TF_RETURN_IF_ERROR(ValidateSparseTensorShape(indices, values, shape)); switch (index_validation) { case IndexValidation::kOrdered: diff --git a/tensorflow/core/kernels/sparse_utils.h b/tensorflow/core/kernels/sparse_utils.h index 4e6ab744691c28..8f86b5184e1def 100644 --- a/tensorflow/core/kernels/sparse_utils.h +++ b/tensorflow/core/kernels/sparse_utils.h @@ -78,9 +78,9 @@ enum class IndexValidation { // Validates the three component tensors of a sparse tensor have the proper // shapes. Also validates index values according to the method supplied. template -Status ValidateSparseTensor(const Tensor& indices, const Tensor& values, - const Tensor& shape, - IndexValidation index_validation); +absl::Status ValidateSparseTensor(const Tensor& indices, const Tensor& values, + const Tensor& shape, + IndexValidation index_validation); } // namespace sparse_utils } // namespace tensorflow diff --git a/tensorflow/core/kernels/sparse_utils_test.cc b/tensorflow/core/kernels/sparse_utils_test.cc index db7ea59056b608..035511107d608c 100644 --- a/tensorflow/core/kernels/sparse_utils_test.cc +++ b/tensorflow/core/kernels/sparse_utils_test.cc @@ -466,7 +466,7 @@ TEST_P(ValidateSparseTensorTest, IndexOutOfBoundsFails) { for (int64_t val : {static_cast(-1), test_shape.dim_size(dim)}) { indices_mat(row, dim) = val; - Status indices_valid = ValidateSparseTensor( + absl::Status indices_valid = ValidateSparseTensor( indices, values, shape, index_validation); if (index_validation == IndexValidation::kNone) { TF_EXPECT_OK(indices_valid); @@ -511,7 +511,7 @@ TEST_P(ValidateSparseTensorTest, IndexOutOfOrderFailsForOrderedValidation) { std::swap(indices_mat(row1, dim), indices_mat(row2, dim)); } - Status indices_valid = ValidateSparseTensor( + absl::Status indices_valid = ValidateSparseTensor( indices, values, shape, index_validation); if (ordered) { EXPECT_THAT( diff --git a/tensorflow/core/kernels/sparse_xent_op.cc b/tensorflow/core/kernels/sparse_xent_op.cc index 4ece900f6c5b95..06cefaa44e576d 100644 --- a/tensorflow/core/kernels/sparse_xent_op.cc +++ b/tensorflow/core/kernels/sparse_xent_op.cc @@ -33,7 +33,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; template -Status CheckInvalidLabelIndex(const Tensor& labels, int64_t max_index) { +absl::Status CheckInvalidLabelIndex(const Tensor& labels, int64_t max_index) { if (labels.NumElements() == 0) return absl::OkStatus(); const auto label_values = labels.vec(); int64_t bad_index; diff --git a/tensorflow/core/kernels/spectrogram_convert_test_data.cc b/tensorflow/core/kernels/spectrogram_convert_test_data.cc index 1878eb5999b505..18521b6182f0f9 100644 --- a/tensorflow/core/kernels/spectrogram_convert_test_data.cc +++ b/tensorflow/core/kernels/spectrogram_convert_test_data.cc @@ -25,7 +25,7 @@ namespace wav { // This takes a CSV file representing an array of complex numbers, and saves out // a version using a binary format to save space in the repository. -Status ConvertCsvToRaw(const string& input_filename) { +absl::Status ConvertCsvToRaw(const string& input_filename) { std::vector>> input_data; ReadCSVFileToComplexVectorOrDie(input_filename, &input_data); const string output_filename = input_filename + ".bin"; @@ -47,7 +47,7 @@ int main(int argc, char* argv[]) { return 1; } tensorflow::string filename(argv[1]); - tensorflow::Status status = tensorflow::wav::ConvertCsvToRaw(filename); + absl::Status status = tensorflow::wav::ConvertCsvToRaw(filename); if (!status.ok()) { LOG(ERROR) << "Error processing '" << filename << "':" << status; return 1; diff --git a/tensorflow/core/kernels/squared-loss.h b/tensorflow/core/kernels/squared-loss.h index 3a0f6d2abb2253..3b334d68e48a5f 100644 --- a/tensorflow/core/kernels/squared-loss.h +++ b/tensorflow/core/kernels/squared-loss.h @@ -63,7 +63,7 @@ class SquaredLossUpdater : public DualLossUpdater { inline double SmoothnessConstant() const final { return 1.0; } // Labels don't require conversion for linear regression. - Status ConvertLabel(float* const example_label) const final { + absl::Status ConvertLabel(float* const example_label) const final { return absl::OkStatus(); } }; diff --git a/tensorflow/core/kernels/stack.cc b/tensorflow/core/kernels/stack.cc index 90eaf2efebe1bd..34769627368cc3 100644 --- a/tensorflow/core/kernels/stack.cc +++ b/tensorflow/core/kernels/stack.cc @@ -54,7 +54,7 @@ class Stack : public ResourceBase { max_size_(max_size), closed_(false) {} - Status Push(const TensorAndAllocation& value) { + absl::Status Push(const TensorAndAllocation& value) { mutex_lock l(mu_); TF_RETURN_IF_ERROR(CheckNotClosed()); int stack_size = stack_.size(); @@ -66,7 +66,7 @@ class Stack : public ResourceBase { return absl::OkStatus(); } - Status Pop(TensorAndAllocation* value) { + absl::Status Pop(TensorAndAllocation* value) { mutex_lock l(mu_); TF_RETURN_IF_ERROR(CheckNotClosed()); if (stack_.empty()) { @@ -116,7 +116,7 @@ class Stack : public ResourceBase { bool closed_ TF_GUARDED_BY(mu_); std::vector stack_ TF_GUARDED_BY(mu_); - Status CheckNotClosed() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + absl::Status CheckNotClosed() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (closed_) { return errors::InvalidArgument("Stack[", stack_name_, "] has already been closed."); @@ -125,7 +125,7 @@ class Stack : public ResourceBase { } }; -Status GetStack(OpKernelContext* ctx, Stack** stack) { +absl::Status GetStack(OpKernelContext* ctx, Stack** stack) { if (ctx->input_dtype(0) == DT_RESOURCE) { return LookupResource(ctx, HandleFromInput(ctx, 0), stack); } else { @@ -258,7 +258,7 @@ void StackPushOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { new Tensor(cpu_allocator, tensor.dtype(), tensor.shape()); device_ctxt->CopyDeviceTensorToCPU( &tensor, "StackPush", device, cpu_tensor, - [cpu_tensor, stack, ctx, done](const Status& s) { + [cpu_tensor, stack, ctx, done](const absl::Status& s) { ctx->SetStatus(s); if (s.ok()) { AllocatorAttributes alloc_attrs = ctx->input_alloc_attr(1); @@ -307,7 +307,7 @@ void StackPopOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { new Tensor(gpu_allocator, cpu_tensor->dtype(), cpu_tensor->shape()); device_ctxt->CopyCPUTensorToDevice( cpu_tensor, device, device_tensor, - [device_tensor, ctx, done](const Status& s) { + [device_tensor, ctx, done](const absl::Status& s) { ctx->SetStatus(s); if (s.ok()) { ctx->set_output(0, *device_tensor); diff --git a/tensorflow/core/kernels/stage_op.cc b/tensorflow/core/kernels/stage_op.cc index 63d84513b3f5e1..64230037a8945e 100644 --- a/tensorflow/core/kernels/stage_op.cc +++ b/tensorflow/core/kernels/stage_op.cc @@ -38,14 +38,14 @@ class Buffer : public ResourceBase { : capacity_(capacity), memory_limit_(memory_limit), current_bytes_(0) {} // the Buffer takes ownership of the Tuple - Status Put(Tuple* tuple) { + absl::Status Put(Tuple* tuple) { std::unique_lock lock(mu_); std::size_t tuple_bytes = GetTupleBytes(*tuple); // Sanity check so that we don't block for ever below if (memory_limit_ > 0 && tuple_bytes > memory_limit_) { - return Status( + return absl::Status( errors::ResourceExhausted("Attempted to insert " "tensors with combined size of '", tuple_bytes, @@ -103,7 +103,7 @@ class Buffer : public ResourceBase { } // Return tuple at index - Status Peek(std::size_t index, Tuple* tuple) { + absl::Status Peek(std::size_t index, Tuple* tuple) { std::unique_lock lock(mu_); // Wait if the requested index is not available @@ -176,12 +176,13 @@ class Buffer : public ResourceBase { std::deque buf_; }; -Status GetBuffer(OpKernelContext* ctx, const NodeDef& ndef, Buffer** buf) { +absl::Status GetBuffer(OpKernelContext* ctx, const NodeDef& ndef, + Buffer** buf) { auto rm = ctx->resource_manager(); ContainerInfo cinfo; // Lambda for creating the Staging Area - auto create_fn = [&ndef](Buffer** ret) -> Status { + auto create_fn = [&ndef](Buffer** ret) -> absl::Status { int64_t capacity; int64_t memory_limit; TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "capacity", &capacity)); diff --git a/tensorflow/core/kernels/stateful_random_ops.cc b/tensorflow/core/kernels/stateful_random_ops.cc index a79af474016af1..e7f747eaa7a991 100644 --- a/tensorflow/core/kernels/stateful_random_ops.cc +++ b/tensorflow/core/kernels/stateful_random_ops.cc @@ -55,7 +55,7 @@ struct UpdateVariableAndFill_Philox { } // end namespace functor -Status CheckState(const Tensor& state) { +absl::Status CheckState(const Tensor& state) { if (state.dtype() != STATE_ELEMENT_DTYPE) { return errors::InvalidArgument("dtype of RNG state variable must be ", DataTypeString(STATE_ELEMENT_DTYPE), @@ -68,7 +68,7 @@ Status CheckState(const Tensor& state) { return absl::OkStatus(); } -Status CheckPhiloxState(const Tensor& state, int64_t alg_tag_skip = 0) { +absl::Status CheckPhiloxState(const Tensor& state, int64_t alg_tag_skip = 0) { static_assert(std::is_same::value, "StateElementType must be int64"); static_assert(std::is_same::value, @@ -113,7 +113,7 @@ absl::StatusOr GetAlg(OpKernelContext* ctx, } template -Status UpdateVariableAndFill( +absl::Status UpdateVariableAndFill( OpKernelContext* ctx, Distribution dist, int state_input_idx, bool read_alg_from_state, ConcreteRngAlgorithm alg, int64_t output_size, typename Distribution::ResultElementType* output_data) { @@ -190,7 +190,7 @@ class StatefulRandomOp : public OpKernel { }; template -Status GetScalar(const Tensor& tensor, int input_idx, T* result) { +absl::Status GetScalar(const Tensor& tensor, int input_idx, T* result) { auto dtype = DataTypeToEnum::v(); if (tensor.dims() != 0) { return errors::InvalidArgument("input ", std::to_string(input_idx), diff --git a/tensorflow/core/kernels/stateless_random_gamma_op.cc b/tensorflow/core/kernels/stateless_random_gamma_op.cc index 815cd84056aa64..dd4689f25afefa 100644 --- a/tensorflow/core/kernels/stateless_random_gamma_op.cc +++ b/tensorflow/core/kernels/stateless_random_gamma_op.cc @@ -59,11 +59,11 @@ namespace functor { template struct StatelessRandomGammaFunctor { - static Status Fill(OpKernelContext* ctx, const T* alpha_flat, - int64_t num_samples, int64_t num_alphas, - int64_t samples_per_alpha, const uint64* key, - const uint64* counter, random::PhiloxRandom random, - T* samples_flat) { + static absl::Status Fill(OpKernelContext* ctx, const T* alpha_flat, + int64_t num_samples, int64_t num_alphas, + int64_t samples_per_alpha, const uint64* key, + const uint64* counter, random::PhiloxRandom random, + T* samples_flat) { if (key != nullptr && counter != nullptr) { random = GetPhiloxRandomFromCounterKeyMem(counter, key); } diff --git a/tensorflow/core/kernels/stateless_random_gamma_op.h b/tensorflow/core/kernels/stateless_random_gamma_op.h index 4b81863cfe0b69..426dbd5e7117e4 100644 --- a/tensorflow/core/kernels/stateless_random_gamma_op.h +++ b/tensorflow/core/kernels/stateless_random_gamma_op.h @@ -29,11 +29,11 @@ namespace functor { // nullptr, they provide the input; otherwise `random` provides the input. template struct StatelessRandomGammaFunctor { - static Status Fill(OpKernelContext* ctx, const T* alpha_flat, - int64_t num_samples, int64_t num_alphas, - int64_t samples_per_alpha, const uint64* key, - const uint64* counter, const random::PhiloxRandom& random, - T* samples_flat); + static absl::Status Fill(OpKernelContext* ctx, const T* alpha_flat, + int64_t num_samples, int64_t num_alphas, + int64_t samples_per_alpha, const uint64* key, + const uint64* counter, + const random::PhiloxRandom& random, T* samples_flat); }; } // namespace functor diff --git a/tensorflow/core/kernels/stateless_random_ops.cc b/tensorflow/core/kernels/stateless_random_ops.cc index 1ce076e4c1a0d1..19cee8f704f553 100644 --- a/tensorflow/core/kernels/stateless_random_ops.cc +++ b/tensorflow/core/kernels/stateless_random_ops.cc @@ -30,8 +30,8 @@ namespace tensorflow { using CPUDevice = Eigen::ThreadPoolDevice; using GPUDevice = Eigen::GpuDevice; -Status GenerateKey(Tensor seed, random::PhiloxRandom::Key* out_key, - random::PhiloxRandom::ResultType* out_counter) { +absl::Status GenerateKey(Tensor seed, random::PhiloxRandom::Key* out_key, + random::PhiloxRandom::ResultType* out_counter) { // Grab the two seeds uint64 seed0; uint64 seed1; diff --git a/tensorflow/core/kernels/stateless_random_ops.h b/tensorflow/core/kernels/stateless_random_ops.h index 40f758abc272b9..42ce3bffe39164 100644 --- a/tensorflow/core/kernels/stateless_random_ops.h +++ b/tensorflow/core/kernels/stateless_random_ops.h @@ -27,8 +27,8 @@ namespace tensorflow { // // REQUIRES: `seed_t` must be a length-2 vector of type DT_INT{32,64}. // `out_key` and `out_counter` must be non-null. -Status GenerateKey(Tensor seed_t, random::PhiloxRandom::Key* out_key, - random::PhiloxRandom::ResultType* out_counter); +absl::Status GenerateKey(Tensor seed_t, random::PhiloxRandom::Key* out_key, + random::PhiloxRandom::ResultType* out_counter); // A base class for kernels of stateless RNG ops that take shape and seed as the // first 2 inputs. diff --git a/tensorflow/core/kernels/stateless_random_ops_v2.h b/tensorflow/core/kernels/stateless_random_ops_v2.h index b566f490fdd6fb..0b5b8945c5f1a5 100644 --- a/tensorflow/core/kernels/stateless_random_ops_v2.h +++ b/tensorflow/core/kernels/stateless_random_ops_v2.h @@ -22,9 +22,9 @@ limitations under the License. namespace tensorflow { -inline Status CheckKeyCounterShape(int minimum_counter_size, - TensorShape const& key_shape, - TensorShape const& counter_shape) { +inline absl::Status CheckKeyCounterShape(int minimum_counter_size, + TensorShape const& key_shape, + TensorShape const& counter_shape) { if (!(key_shape.dims() == 1 && key_shape.dim_size(0) == RNG_KEY_SIZE)) { return errors::InvalidArgument( "key must have shape [", RNG_KEY_SIZE, "], not ", diff --git a/tensorflow/core/kernels/stateless_random_ops_v2_util.h b/tensorflow/core/kernels/stateless_random_ops_v2_util.h index 9d814cfbb5744b..a57983426fb21f 100644 --- a/tensorflow/core/kernels/stateless_random_ops_v2_util.h +++ b/tensorflow/core/kernels/stateless_random_ops_v2_util.h @@ -28,7 +28,7 @@ limitations under the License. namespace tensorflow { template -Status GetScalar(const Tensor& tensor, int input_idx, T* result) { +absl::Status GetScalar(const Tensor& tensor, int input_idx, T* result) { auto dtype = DataTypeToEnum::v(); if (tensor.dims() != 0) { return errors::InvalidArgument("input ", std::to_string(input_idx), diff --git a/tensorflow/core/kernels/stochastic_cast_op_test.cc b/tensorflow/core/kernels/stochastic_cast_op_test.cc index 10d9eae13249ff..9543afda020307 100644 --- a/tensorflow/core/kernels/stochastic_cast_op_test.cc +++ b/tensorflow/core/kernels/stochastic_cast_op_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "Eigen/Core" // from @eigen_archive #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/random/philox_random.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/core/lib/random/random_distributions.h" #include "tensorflow/core/platform/bfloat16.h" #include "tensorflow/core/platform/logging.h" -#include "tsl/lib/random/philox_random.h" namespace Eigen { namespace internal { diff --git a/tensorflow/core/kernels/string_format_op_test.cc b/tensorflow/core/kernels/string_format_op_test.cc index 55d688b530ea55..58571b288c9a89 100644 --- a/tensorflow/core/kernels/string_format_op_test.cc +++ b/tensorflow/core/kernels/string_format_op_test.cc @@ -29,9 +29,9 @@ namespace { class StringFormatGraphTest : public OpsTestBase { protected: - Status Init(int num_inputs, DataType input_type, - const string& template_ = "%s", const string& placeholder = "%s", - int summarize = 3) { + absl::Status Init(int num_inputs, DataType input_type, + const string& template_ = "%s", + const string& placeholder = "%s", int summarize = 3) { TF_CHECK_OK(NodeDefBuilder("op", "StringFormat") .Input(FakeInput(num_inputs, input_type)) .Attr("template", template_) diff --git a/tensorflow/core/kernels/string_util.cc b/tensorflow/core/kernels/string_util.cc index 105a89f589a0fe..d8514369bc1d8e 100644 --- a/tensorflow/core/kernels/string_util.cc +++ b/tensorflow/core/kernels/string_util.cc @@ -19,7 +19,8 @@ limitations under the License. namespace tensorflow { // Sets unit value based on str. -Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding) { +absl::Status ParseUnicodeEncoding(const string& str, + UnicodeEncoding* encoding) { if (str == "UTF-8") { *encoding = UnicodeEncoding::UTF8; } else if (str == "UTF-16-BE") { @@ -35,7 +36,7 @@ Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding) { } // Sets unit value based on str. -Status ParseCharUnit(const string& str, CharUnit* unit) { +absl::Status ParseCharUnit(const string& str, CharUnit* unit) { if (str == "BYTE") { *unit = CharUnit::BYTE; } else if (str == "UTF8_CHAR") { diff --git a/tensorflow/core/kernels/string_util.h b/tensorflow/core/kernels/string_util.h index 6b32fbf83fe2ec..9dda609a5b7d62 100644 --- a/tensorflow/core/kernels/string_util.h +++ b/tensorflow/core/kernels/string_util.h @@ -33,10 +33,10 @@ enum class CharUnit { BYTE, UTF8_CHAR }; inline bool IsTrailByte(char x) { return static_cast(x) < -0x40; } // Sets `encoding` based on `str`. -Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding); +absl::Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding); // Sets `unit` value based on `str`. -Status ParseCharUnit(const string& str, CharUnit* unit); +absl::Status ParseCharUnit(const string& str, CharUnit* unit); // Returns the number of Unicode characters in a UTF-8 string. // Result may be incorrect if the input string is not valid UTF-8. diff --git a/tensorflow/core/kernels/summary_image_op.cc b/tensorflow/core/kernels/summary_image_op.cc index a68bf724cf9efc..b434e424490414 100644 --- a/tensorflow/core/kernels/summary_image_op.cc +++ b/tensorflow/core/kernels/summary_image_op.cc @@ -142,9 +142,10 @@ class SummaryImageOp : public OpKernel { // differently in the float and uint8 cases: the float case needs a temporary // buffer which can be shared across calls to ith_image, but the uint8 case // does not. - Status AddImages(const string& tag, int batch_size, int w, int h, int depth, - const std::function& ith_image, - Summary* s) { + absl::Status AddImages(const string& tag, int batch_size, int w, int h, + int depth, + const std::function& ith_image, + Summary* s) { const int N = std::min(max_images_, batch_size); for (int i = 0; i < N; ++i) { Summary::Value* v = s->add_value(); diff --git a/tensorflow/core/kernels/summary_interface.h b/tensorflow/core/kernels/summary_interface.h index 374fa71dd60c87..f423d4abaa5808 100644 --- a/tensorflow/core/kernels/summary_interface.h +++ b/tensorflow/core/kernels/summary_interface.h @@ -32,28 +32,31 @@ class SummaryWriterInterface : public ResourceBase { virtual ~SummaryWriterInterface() override {} // Flushes all unwritten messages in the queue. - virtual Status Flush() = 0; + virtual absl::Status Flush() = 0; // These are called in the OpKernel::Compute methods for the summary ops. - virtual Status WriteTensor(int64_t global_step, Tensor t, const string& tag, - const string& serialized_metadata) = 0; + virtual absl::Status WriteTensor(int64_t global_step, Tensor t, + const string& tag, + const string& serialized_metadata) = 0; - virtual Status WriteScalar(int64_t global_step, Tensor t, - const string& tag) = 0; + virtual absl::Status WriteScalar(int64_t global_step, Tensor t, + const string& tag) = 0; - virtual Status WriteHistogram(int64_t global_step, Tensor t, - const string& tag) = 0; + virtual absl::Status WriteHistogram(int64_t global_step, Tensor t, + const string& tag) = 0; - virtual Status WriteImage(int64_t global_step, Tensor t, const string& tag, - int max_images, Tensor bad_color) = 0; + virtual absl::Status WriteImage(int64_t global_step, Tensor t, + const string& tag, int max_images, + Tensor bad_color) = 0; - virtual Status WriteAudio(int64_t global_step, Tensor t, const string& tag, - int max_outputs_, float sample_rate) = 0; + virtual absl::Status WriteAudio(int64_t global_step, Tensor t, + const string& tag, int max_outputs_, + float sample_rate) = 0; - virtual Status WriteGraph(int64_t global_step, - std::unique_ptr graph) = 0; + virtual absl::Status WriteGraph(int64_t global_step, + std::unique_ptr graph) = 0; - virtual Status WriteEvent(std::unique_ptr e) = 0; + virtual absl::Status WriteEvent(std::unique_ptr e) = 0; }; } // namespace tensorflow diff --git a/tensorflow/core/kernels/summary_op_test.cc b/tensorflow/core/kernels/summary_op_test.cc index 9c9e87581c69e0..cd6b96b0f68be6 100644 --- a/tensorflow/core/kernels/summary_op_test.cc +++ b/tensorflow/core/kernels/summary_op_test.cc @@ -151,7 +151,7 @@ TEST_F(SummaryHistoOpTest, Error_WrongDimsTags) { // Feed and run AddInputFromArray(TensorShape({2, 1}), {"tag1", "tag2"}); AddInputFromArray(TensorShape({2}), {1.0f, -0.73f}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains(s.ToString(), "tags must be scalar")) << s; } @@ -161,7 +161,7 @@ TEST_F(SummaryHistoOpTest, Error_TooManyTagValues) { // Feed and run AddInputFromArray(TensorShape({2}), {"tag1", "tag2"}); AddInputFromArray(TensorShape({2, 1}), {1.0f, -0.73f}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains(s.ToString(), "tags must be scalar")) << s; } @@ -260,7 +260,7 @@ TEST_F(SummaryMergeOpTest, Error_MismatchedSize) { "value { tag: \"tagduplicate\" simple_value: 1.0 } ", &s2)); AddInputFromArray(TensorShape({2}), {s1.SerializeAsString(), s2.SerializeAsString()}); - Status s = RunOpKernel(); + absl::Status s = RunOpKernel(); EXPECT_TRUE(absl::StrContains(s.ToString(), "Duplicate tag")) << s; } diff --git a/tensorflow/core/kernels/tensor_array.cc b/tensorflow/core/kernels/tensor_array.cc index fa24b716a9c822..46c26e356d785d 100644 --- a/tensorflow/core/kernels/tensor_array.cc +++ b/tensorflow/core/kernels/tensor_array.cc @@ -81,8 +81,8 @@ TF_CALL_COMPLEX_TYPES(TENSOR_ARRAY_SET_ZERO_GPU); std::atomic TensorArray::tensor_array_counter{0}; -Status TensorArray::CopyShapesFrom(TensorArray* rhs, - const TensorShape* shape_to_prepend) { +absl::Status TensorArray::CopyShapesFrom(TensorArray* rhs, + const TensorShape* shape_to_prepend) { mutex_lock l(mu_); mutex_lock l_rhs(rhs->mu_); TF_RETURN_IF_ERROR(LockedReturnIfClosed()); diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h index 1081c2be8a08a8..aef4a97b01fc69 100644 --- a/tensorflow/core/kernels/tensor_array.h +++ b/tensorflow/core/kernels/tensor_array.h @@ -42,8 +42,8 @@ namespace tensor_array { // Full implementations are in tensor_array.cc template -Status AddToTensor(OpKernelContext* ctx, Tensor* sum, const Tensor* current, - const Tensor* add) { +absl::Status AddToTensor(OpKernelContext* ctx, Tensor* sum, + const Tensor* current, const Tensor* add) { return errors::InvalidArgument( "tensor_array::AddToTensor type not supported: ", DataTypeString(DataTypeToEnum::value)); @@ -70,7 +70,7 @@ TF_CALL_COMPLEX_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_GPU); #undef TENSOR_ARRAY_WRITE_OR_ADD template -Status TensorSetZero(OpKernelContext* ctx, Tensor* value) { +absl::Status TensorSetZero(OpKernelContext* ctx, Tensor* value) { return errors::InvalidArgument( "tensor_array::TensorSetZero type not supported: ", DataTypeString(DataTypeToEnum::value)); @@ -185,20 +185,21 @@ class TensorArray : public ResourceBase { // Note, value is passed as a pointer because we its underlying // Tensor's shape is accessed. Otherwise it is not modified. template - Status WriteOrAggregate(OpKernelContext* ctx, const int32_t index, - const Tensor* value) { + absl::Status WriteOrAggregate(OpKernelContext* ctx, const int32_t index, + const Tensor* value) { mutex_lock l(mu_); return LockedWriteOrAggregate(ctx, index, value); } template - Status WriteOrAggregateMany(OpKernelContext* ctx, - const std::vector& indices, - std::vector* values) { + absl::Status WriteOrAggregateMany(OpKernelContext* ctx, + const std::vector& indices, + std::vector* values) { mutex_lock l(mu_); int32_t i = 0; for (const int32_t ix : indices) { - Status s = LockedWriteOrAggregate(ctx, ix, &(*values)[i]); + absl::Status s = + LockedWriteOrAggregate(ctx, ix, &(*values)[i]); ++i; TF_RETURN_IF_ERROR(s); } @@ -221,20 +222,20 @@ class TensorArray : public ResourceBase { // the returned '*value'. // * The index is marked as read (it cannot be rewritten to). template - Status Read(OpKernelContext* ctx, const int32_t index, Tensor* value) { + absl::Status Read(OpKernelContext* ctx, const int32_t index, Tensor* value) { mutex_lock l(mu_); return LockedRead(ctx, index, value); } template - Status ReadMany(OpKernelContext* ctx, const std::vector& indices, - std::vector* values) { + absl::Status ReadMany(OpKernelContext* ctx, const std::vector& indices, + std::vector* values) { mutex_lock l(mu_); values->clear(); values->resize(indices.size()); int32_t i = 0; for (const int32_t ix : indices) { - Status s = LockedRead(ctx, ix, &(*values)[i]); + absl::Status s = LockedRead(ctx, ix, &(*values)[i]); ++i; if (!s.ok()) return s; } @@ -248,10 +249,10 @@ class TensorArray : public ResourceBase { return element_shape_; } - Status SetElemShape(const PartialTensorShape& candidate) { + absl::Status SetElemShape(const PartialTensorShape& candidate) { mutex_lock l(mu_); PartialTensorShape new_element_shape_; - Status s = element_shape_.MergeWith(candidate, &new_element_shape_); + absl::Status s = element_shape_.MergeWith(candidate, &new_element_shape_); if (!s.ok()) { return s; } @@ -271,7 +272,7 @@ class TensorArray : public ResourceBase { } // Return the size of the TensorArray. - Status Size(int32* size) { + absl::Status Size(int32* size) { mutex_lock l(mu_); TF_RETURN_IF_ERROR(LockedReturnIfClosed()); *size = tensors_.size(); @@ -279,7 +280,7 @@ class TensorArray : public ResourceBase { } // Record the size of the TensorArray after an unpack or split. - Status SetMarkedSize(int32_t size) { + absl::Status SetMarkedSize(int32_t size) { mutex_lock l(mu_); TF_RETURN_IF_ERROR(LockedReturnIfClosed()); if (!is_grad_) { @@ -289,7 +290,7 @@ class TensorArray : public ResourceBase { } // Return the marked size of the TensorArray. - Status MarkedSize(int32* size) { + absl::Status MarkedSize(int32* size) { mutex_lock l(mu_); TF_RETURN_IF_ERROR(LockedReturnIfClosed()); *size = marked_size_; @@ -297,7 +298,7 @@ class TensorArray : public ResourceBase { } // Return the size that should be used by pack or concat op. - Status PackOrConcatSize(int32* size) { + absl::Status PackOrConcatSize(int32* size) { mutex_lock l(mu_); TF_RETURN_IF_ERROR(LockedReturnIfClosed()); *size = is_grad_ ? marked_size_ : tensors_.size(); @@ -332,7 +333,8 @@ class TensorArray : public ResourceBase { // zero-tensors, which will be replaced by future aggregate writes, // or instantiated by future reads. Requires a non-const pointer // to the rhs to access its mutex. - Status CopyShapesFrom(TensorArray* rhs, const TensorShape* shape_to_prepend); + absl::Status CopyShapesFrom(TensorArray* rhs, + const TensorShape* shape_to_prepend); // Clear the TensorArray, including any Tensor references, and mark as closed. void ClearAndMarkClosed() { @@ -350,19 +352,19 @@ class TensorArray : public ResourceBase { } private: - Status LockedWrite(OpKernelContext* ctx, const int32_t index, Tensor* value) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + absl::Status LockedWrite(OpKernelContext* ctx, const int32_t index, + Tensor* value) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); template - Status LockedWriteOrAggregate(OpKernelContext* ctx, const int32_t index, - const Tensor* value) + absl::Status LockedWriteOrAggregate(OpKernelContext* ctx, const int32_t index, + const Tensor* value) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); template - Status LockedRead(OpKernelContext* ctx, const int32_t index, Tensor* value) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + absl::Status LockedRead(OpKernelContext* ctx, const int32_t index, + Tensor* value) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); - Status LockedReturnIfClosed() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + absl::Status LockedReturnIfClosed() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (closed_) { return errors::InvalidArgument("TensorArray ", handle_.vec()(1), " has already been closed."); @@ -438,9 +440,9 @@ class TensorArray : public ResourceBase { }; template -Status TensorArray::LockedWriteOrAggregate(OpKernelContext* ctx, - const int32_t index, - const Tensor* value) { +absl::Status TensorArray::LockedWriteOrAggregate(OpKernelContext* ctx, + const int32_t index, + const Tensor* value) { TF_RETURN_IF_ERROR(LockedReturnIfClosed()); size_t index_size = static_cast(index); if (index < 0 || (!dynamic_size_ && index_size >= tensors_.size())) { @@ -514,15 +516,15 @@ Status TensorArray::LockedWriteOrAggregate(OpKernelContext* ctx, Tensor* existing_t = &t.tensor; if (t.local_copy) { - Status s = tensor_array::AddToTensor(ctx, existing_t, - existing_t, value); + absl::Status s = tensor_array::AddToTensor(ctx, existing_t, + existing_t, value); TF_RETURN_IF_ERROR(s); } else { Tensor local_tensor; TF_RETURN_IF_ERROR( ctx->allocate_temp(dtype_, existing_t->shape(), &local_tensor)); - Status s = tensor_array::AddToTensor(ctx, &local_tensor, - existing_t, value); + absl::Status s = tensor_array::AddToTensor(ctx, &local_tensor, + existing_t, value); TF_RETURN_IF_ERROR(s); t.tensor = local_tensor; t.local_copy = true; @@ -540,8 +542,8 @@ Status TensorArray::LockedWriteOrAggregate(OpKernelContext* ctx, } template -Status TensorArray::LockedRead(OpKernelContext* ctx, const int32_t index, - Tensor* value) { +absl::Status TensorArray::LockedRead(OpKernelContext* ctx, const int32_t index, + Tensor* value) { TF_RETURN_IF_ERROR(LockedReturnIfClosed()); if ((index < 0) || (!is_grad_ && (static_cast(index) >= tensors_.size()))) { @@ -606,7 +608,7 @@ Status TensorArray::LockedRead(OpKernelContext* ctx, const int32_t index, // return zeros of the appropriate shape. TF_RETURN_IF_ERROR(ctx->allocate_temp(dtype_, t.shape, &t.tensor)); if (t.shape.num_elements() > 0) { - Status s = tensor_array::TensorSetZero(ctx, &t.tensor); + absl::Status s = tensor_array::TensorSetZero(ctx, &t.tensor); if (!s.ok()) return s; } } diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc index 7ccd9b5afe12a6..fe318a58803fb6 100644 --- a/tensorflow/core/kernels/tensor_array_ops.cc +++ b/tensorflow/core/kernels/tensor_array_ops.cc @@ -57,7 +57,8 @@ typedef Eigen::GpuDevice GPUDevice; namespace tensorflow { -Status GetHandle(OpKernelContext* ctx, string* container, string* ta_handle) { +absl::Status GetHandle(OpKernelContext* ctx, string* container, + string* ta_handle) { { Tensor tensor; // Assuming that handle is the input at index 0. @@ -78,7 +79,7 @@ Status GetHandle(OpKernelContext* ctx, string* container, string* ta_handle) { return absl::OkStatus(); } -Status GetTensorArray(OpKernelContext* ctx, TensorArray** tensor_array) { +absl::Status GetTensorArray(OpKernelContext* ctx, TensorArray** tensor_array) { string container; string ta_handle; if (ctx->input_dtype(0) != DT_RESOURCE) { @@ -94,7 +95,7 @@ Status GetTensorArray(OpKernelContext* ctx, TensorArray** tensor_array) { } } -Status SetupFlowControlInputs(OpKernelContext* ctx, bool set_output) { +absl::Status SetupFlowControlInputs(OpKernelContext* ctx, bool set_output) { const Tensor* flow_control; TF_RETURN_IF_ERROR(ctx->input("flow_in", &flow_control)); if (set_output) { @@ -152,9 +153,9 @@ class TensorArrayCreationOp : public OpKernel { } protected: - virtual Status CreateTensorArray(OpKernelContext* ctx, ResourceMgr* rm, - Tensor* tensor_array_output_handle, - TensorArray** output_tensor_array) = 0; + virtual absl::Status CreateTensorArray(OpKernelContext* ctx, ResourceMgr* rm, + Tensor* tensor_array_output_handle, + TensorArray** output_tensor_array) = 0; private: const DeviceType device_type_; @@ -185,9 +186,9 @@ class TensorArrayOp : public TensorArrayCreationOp { if (tensor_array_name_.empty()) tensor_array_name_ = name(); } - Status CreateTensorArray(OpKernelContext* ctx, ResourceMgr* rm, - Tensor* tensor_array_output_handle, - TensorArray** output_tensor_array) override { + absl::Status CreateTensorArray(OpKernelContext* ctx, ResourceMgr* rm, + Tensor* tensor_array_output_handle, + TensorArray** output_tensor_array) override { const Tensor* tensor_size; TF_RETURN_IF_ERROR(ctx->input("size", &tensor_size)); @@ -310,9 +311,9 @@ class TensorArrayGradOp : public TensorArrayCreationOp { OP_REQUIRES_OK(context, context->GetAttr("source", &source_)); } - Status CreateTensorArray(OpKernelContext* ctx, ResourceMgr* rm, - Tensor* tensor_array_output_handle, - TensorArray** output_tensor_array) override { + absl::Status CreateTensorArray(OpKernelContext* ctx, ResourceMgr* rm, + Tensor* tensor_array_output_handle, + TensorArray** output_tensor_array) override { string container; string tensor_array_name; if (ctx->input_dtype(0) != DT_RESOURCE) { @@ -384,8 +385,8 @@ class TensorArrayGradOp : public TensorArrayCreationOp { const auto key = strings::StrCat(output_handle(0), output_handle(1)); auto creator = [key, tensor_array, array_size, marked_size, element_shape, - shape_to_prepend, - tensor_array_output_handle](TensorArray** ret) -> Status { + shape_to_prepend, tensor_array_output_handle]( + TensorArray** ret) -> absl::Status { *ret = new TensorArray( key, tensor_array->ElemType(), *tensor_array_output_handle, array_size, element_shape, tensor_array->HasIdenticalElementShapes(), @@ -395,7 +396,7 @@ class TensorArrayGradOp : public TensorArrayCreationOp { return (*ret)->CopyShapesFrom(tensor_array, &shape_to_prepend); }; - Status s = ctx->step_container()->LookupOrCreate( + absl::Status s = ctx->step_container()->LookupOrCreate( rm, key, output_tensor_array, creator); (*output_tensor_array)->Unref(); @@ -496,7 +497,7 @@ class TensorArrayWriteOp : public OpKernel { DataTypeString(tensor_array->ElemType()), " but Op is trying to write dtype ", DataTypeString(tensor_value->dtype()), ".")); - Status s = + absl::Status s = tensor_array->WriteOrAggregate(ctx, index, tensor_value); OP_REQUIRES_OK(ctx, s); } @@ -577,7 +578,7 @@ class TensorArrayReadOp : public OpKernel { "TensorArray dtype is ", DataTypeString(tensor_array->ElemType()), " but Op requested dtype ", DataTypeString(dtype_), ".")); Tensor value; - Status s = tensor_array->Read(ctx, index, &value); + absl::Status s = tensor_array->Read(ctx, index, &value); OP_REQUIRES_OK(ctx, s); ctx->set_output(0, value); } @@ -706,7 +707,7 @@ class TensorArrayPackOrGatherOp : public OpKernel { } // Read all the Tensors into a vector to keep track of their memory. - Status s = tensor_array->ReadMany(ctx, indices, &values); + absl::Status s = tensor_array->ReadMany(ctx, indices, &values); OP_REQUIRES_OK(ctx, s); const Tensor* value_0_t = &values[0]; @@ -912,7 +913,7 @@ class TensorArrayConcatOp : public OpKernel { std::vector values; std::vector indices(array_size); std::iota(indices.begin(), indices.end(), 0); - Status s = tensor_array->ReadMany(ctx, indices, &values); + absl::Status s = tensor_array->ReadMany(ctx, indices, &values); OP_REQUIRES_OK(ctx, s); Tensor* lengths_tensor = nullptr; @@ -1211,8 +1212,8 @@ class TensorArrayUnpackOrScatterOp : public OpKernel { OP_REQUIRES_OK(ctx, tensor_array->SetMarkedSize(array_size)); } - Status s = tensor_array->WriteOrAggregateMany(ctx, write_indices, - &write_values); + absl::Status s = tensor_array->WriteOrAggregateMany( + ctx, write_indices, &write_values); OP_REQUIRES_OK(ctx, s); } }; @@ -1404,8 +1405,8 @@ class TensorArraySplitOp : public OpKernel { std::vector indices(array_size); std::iota(indices.begin(), indices.end(), 0); - Status s = tensor_array->WriteOrAggregateMany(ctx, indices, - &write_values); + absl::Status s = tensor_array->WriteOrAggregateMany( + ctx, indices, &write_values); OP_REQUIRES_OK(ctx, s); } }; diff --git a/tensorflow/core/kernels/tensor_flag_utils.cc b/tensorflow/core/kernels/tensor_flag_utils.cc index 974c4622a69a89..c6f0d3add62c20 100644 --- a/tensorflow/core/kernels/tensor_flag_utils.cc +++ b/tensorflow/core/kernels/tensor_flag_utils.cc @@ -21,25 +21,26 @@ limitations under the License. namespace tensorflow { namespace tensor_flag_utils { -Status ValidateSparseMatrixShardingConfig(const Tensor& config) { +absl::Status ValidateSparseMatrixShardingConfig(const Tensor& config) { if (TensorShapeUtils::IsScalar(config.shape())) { const float scalar_config = config.template scalar()(); if (0 < scalar_config && scalar_config <= 1.0) { return absl::OkStatus(); } - return Status( + return absl::Status( absl::StatusCode::kInvalidArgument, absl::StrCat("Expected config to be in range (0, 1] but instead found ", scalar_config)); } if (!TensorShapeUtils::IsMatrix(config.shape())) { - return Status(absl::StatusCode::kInvalidArgument, - absl::StrCat("Expected config to be either scalar or matrix " - "but instead found tensor of rank ", - config.dims())); + return absl::Status( + absl::StatusCode::kInvalidArgument, + absl::StrCat("Expected config to be either scalar or matrix " + "but instead found tensor of rank ", + config.dims())); } if (config.dim_size(1) != 3) { - return Status( + return absl::Status( absl::StatusCode::kInvalidArgument, absl::StrCat( "Expected config matrix to have dim(1) = 3 but instead found ", @@ -85,25 +86,26 @@ MatrixType FindConfigValueForKey( return config_mat(last_row_index, 2); } -Status ValidateScalarQuantityShardingConfig(const Tensor& config) { +absl::Status ValidateScalarQuantityShardingConfig(const Tensor& config) { if (TensorShapeUtils::IsScalar(config.shape())) { const float scalar_config = config.template scalar()(); if (0 < scalar_config && scalar_config <= 1.0) { return absl::OkStatus(); } - return Status( + return absl::Status( absl::StatusCode::kInvalidArgument, absl::StrCat("Expected config to be in range (0, 1] but instead found ", scalar_config)); } if (!TensorShapeUtils::IsMatrix(config.shape())) { - return Status(absl::StatusCode::kInvalidArgument, - absl::StrCat("Expected config to be either scalar or matrix " - "but instead found tensor of rank ", - config.dims())); + return absl::Status( + absl::StatusCode::kInvalidArgument, + absl::StrCat("Expected config to be either scalar or matrix " + "but instead found tensor of rank ", + config.dims())); } if (config.dim_size(1) != 2) { - return Status( + return absl::Status( absl::StatusCode::kInvalidArgument, absl::StrCat( "Expected config matrix to have dim(1) = 2 but instead found ", diff --git a/tensorflow/core/kernels/tensor_flag_utils.h b/tensorflow/core/kernels/tensor_flag_utils.h index f60c192326c40e..f20ecad7083dae 100644 --- a/tensorflow/core/kernels/tensor_flag_utils.h +++ b/tensorflow/core/kernels/tensor_flag_utils.h @@ -39,11 +39,11 @@ std::vector ParseRowStartIndices( // [0, 1.0). If config is a matrix then config must have shape M x 3, all of // its entries must be positive, and entries in the last column may not // exceed 1.0. If config is a matrix then it may not be empty. -Status ValidateSparseMatrixShardingConfig(const Tensor& config); +absl::Status ValidateSparseMatrixShardingConfig(const Tensor& config); // Returns OkStatus() if and only if config is a float scalar or a non-empty // matrix with dimensions M x 2. -Status ValidateScalarQuantityShardingConfig(const Tensor& config); +absl::Status ValidateScalarQuantityShardingConfig(const Tensor& config); // Returns the last entry of the first row in config_mat for which the first // two entries are no smaller than the respective entries in key. If no such diff --git a/tensorflow/core/kernels/tensor_list_util.cc b/tensorflow/core/kernels/tensor_list_util.cc index 7dc0d01b56b61d..0882902ba571b3 100644 --- a/tensorflow/core/kernels/tensor_list_util.cc +++ b/tensorflow/core/kernels/tensor_list_util.cc @@ -25,11 +25,11 @@ limitations under the License. namespace tensorflow { -Status TensorListBinaryAdd( +absl::Status TensorListBinaryAdd( OpKernelContext* c, const TensorList& a, const TensorList& b, TensorList* out, - std::function + std::function binary_add_func) { if (a.element_dtype != b.element_dtype) { return errors::InvalidArgument( @@ -64,10 +64,10 @@ Status TensorListBinaryAdd( return absl::OkStatus(); } -Status TensorListZerosLike( +absl::Status TensorListZerosLike( OpKernelContext* c, const TensorList& x, TensorList* y, - std::function + std::function zeros_like_func) { y->element_dtype = x.element_dtype; y->element_shape = x.element_shape; diff --git a/tensorflow/core/kernels/tensor_list_util.h b/tensorflow/core/kernels/tensor_list_util.h index 784b508c5a90d7..7ffabce89a4c17 100644 --- a/tensorflow/core/kernels/tensor_list_util.h +++ b/tensorflow/core/kernels/tensor_list_util.h @@ -25,17 +25,17 @@ class OpKernelContext; class TensorList; class Tensor; -Status TensorListBinaryAdd( +absl::Status TensorListBinaryAdd( OpKernelContext* c, const TensorList& a, const TensorList& b, TensorList* out, - std::function + std::function binary_add_func); -Status TensorListZerosLike( +absl::Status TensorListZerosLike( OpKernelContext* c, const TensorList& x, TensorList* y, - std::function + std::function zeros_like_func); } // namespace tensorflow diff --git a/tensorflow/core/kernels/tensor_map.cc b/tensorflow/core/kernels/tensor_map.cc index a95d256cff92f4..78428aca178db3 100644 --- a/tensorflow/core/kernels/tensor_map.cc +++ b/tensorflow/core/kernels/tensor_map.cc @@ -43,7 +43,7 @@ void TensorMap::Encode(VariantTensorData* data) const { } } -static Status TensorMapDeviceCopy( +static absl::Status TensorMapDeviceCopy( const TensorMap& from, TensorMap* to, const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) { for (const std::pair& p : from.tensors()) { diff --git a/tensorflow/core/kernels/text_line_reader_op.cc b/tensorflow/core/kernels/text_line_reader_op.cc index 89b56cb1853bd7..c0be80b902ab3f 100644 --- a/tensorflow/core/kernels/text_line_reader_op.cc +++ b/tensorflow/core/kernels/text_line_reader_op.cc @@ -35,14 +35,14 @@ class TextLineReader : public ReaderBase { env_(env), line_number_(0) {} - Status OnWorkStartedLocked() override { + absl::Status OnWorkStartedLocked() override { line_number_ = 0; TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file_)); input_buffer_.reset(new io::InputBuffer(file_.get(), kBufferSize)); for (; line_number_ < skip_header_lines_; ++line_number_) { string line_contents; - Status status = input_buffer_->ReadLine(&line_contents); + absl::Status status = input_buffer_->ReadLine(&line_contents); if (absl::IsOutOfRange(status)) { // We ignore an end of file error when skipping header lines. // We will end up skipping this file. @@ -53,14 +53,14 @@ class TextLineReader : public ReaderBase { return absl::OkStatus(); } - Status OnWorkFinishedLocked() override { + absl::Status OnWorkFinishedLocked() override { input_buffer_.reset(nullptr); return absl::OkStatus(); } - Status ReadLocked(tstring* key, tstring* value, bool* produced, - bool* at_end) override { - Status status = input_buffer_->ReadLine(value); + absl::Status ReadLocked(tstring* key, tstring* value, bool* produced, + bool* at_end) override { + absl::Status status = input_buffer_->ReadLine(value); ++line_number_; if (status.ok()) { *key = strings::StrCat(current_work(), ":", line_number_); @@ -75,7 +75,7 @@ class TextLineReader : public ReaderBase { } } - Status ResetLocked() override { + absl::Status ResetLocked() override { line_number_ = 0; input_buffer_.reset(nullptr); return ReaderBase::ResetLocked(); diff --git a/tensorflow/core/kernels/tf_record_reader_op.cc b/tensorflow/core/kernels/tf_record_reader_op.cc index 9126139afc6b65..9b989c197c91ef 100644 --- a/tensorflow/core/kernels/tf_record_reader_op.cc +++ b/tensorflow/core/kernels/tf_record_reader_op.cc @@ -36,7 +36,7 @@ class TFRecordReader : public ReaderBase { offset_(0), compression_type_(compression_type) {} - Status OnWorkStartedLocked() override { + absl::Status OnWorkStartedLocked() override { offset_ = 0; TF_RETURN_IF_ERROR(env_->NewRandomAccessFile(current_work(), &file_)); @@ -46,16 +46,16 @@ class TFRecordReader : public ReaderBase { return absl::OkStatus(); } - Status OnWorkFinishedLocked() override { + absl::Status OnWorkFinishedLocked() override { reader_.reset(nullptr); file_.reset(nullptr); return absl::OkStatus(); } - Status ReadLocked(tstring* key, tstring* value, bool* produced, - bool* at_end) override { + absl::Status ReadLocked(tstring* key, tstring* value, bool* produced, + bool* at_end) override { *key = strings::StrCat(current_work(), ":", offset_); - Status status = reader_->ReadRecord(&offset_, value); + absl::Status status = reader_->ReadRecord(&offset_, value); if (absl::IsOutOfRange(status)) { *at_end = true; return absl::OkStatus(); @@ -65,7 +65,7 @@ class TFRecordReader : public ReaderBase { return absl::OkStatus(); } - Status ResetLocked() override { + absl::Status ResetLocked() override { offset_ = 0; reader_.reset(nullptr); file_.reset(nullptr); diff --git a/tensorflow/core/kernels/topk_op.cc b/tensorflow/core/kernels/topk_op.cc index 4470b6b3fcf237..4c4c79f843e503 100644 --- a/tensorflow/core/kernels/topk_op.cc +++ b/tensorflow/core/kernels/topk_op.cc @@ -117,7 +117,7 @@ class TopK : public OpKernel { auto values = values_out->flat_inner_dims(); auto indices = indices_out->flat_inner_dims(); - Status s = functor::TopKFunctor::Compute( + absl::Status s = functor::TopKFunctor::Compute( context, sorted_, k, input, num_rows, num_cols, values, indices); OP_REQUIRES_OK(context, s); } @@ -131,7 +131,7 @@ namespace functor { template struct TopKFunctor { - static EIGEN_ALWAYS_INLINE Status Compute( + static EIGEN_ALWAYS_INLINE absl::Status Compute( OpKernelContext* context, bool sorted, int k, const typename TTypes::ConstTensor& input, const int64_t num_rows, const int64_t num_cols, typename TTypes::Tensor values, diff --git a/tensorflow/core/kernels/topk_op.h b/tensorflow/core/kernels/topk_op.h index 97528db3dda82c..cdebb07f7cb885 100644 --- a/tensorflow/core/kernels/topk_op.h +++ b/tensorflow/core/kernels/topk_op.h @@ -28,11 +28,11 @@ namespace functor { template struct TopKFunctor { - static Status Compute(OpKernelContext* context, bool sorted, int k, - const typename TTypes::ConstTensor& input, - const int64_t num_rows, const int64_t num_cols, - typename TTypes::Tensor values, - typename TTypes::Tensor indices); + static absl::Status Compute(OpKernelContext* context, bool sorted, int k, + const typename TTypes::ConstTensor& input, + const int64_t num_rows, const int64_t num_cols, + typename TTypes::Tensor values, + typename TTypes::Tensor indices); }; } // end namespace functor diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index 94baba66659d55..48810b83e4fbe0 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -164,13 +164,13 @@ struct ApplyAdagradV2 { template struct SparseApplyAdagrad { - Status operator()(const CPUDevice& d, typename TTypes::Matrix var, - typename TTypes::Matrix accum, - typename TTypes::ConstScalar lr, - typename TTypes::ConstScalar epsilon, - typename TTypes::ConstMatrix grad, - typename TTypes::ConstVec indices, - int64_t inner_dim, bool update_slots) { + absl::Status operator()(const CPUDevice& d, typename TTypes::Matrix var, + typename TTypes::Matrix accum, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar epsilon, + typename TTypes::ConstMatrix grad, + typename TTypes::ConstVec indices, + int64_t inner_dim, bool update_slots) { const Tindex N = static_cast(indices.dimension(0)); if (N == 0) return absl::OkStatus(); const Tindex first_dim_size = static_cast(var.dimension(0)); @@ -272,14 +272,14 @@ struct ApplyProximalAdagrad { template struct SparseApplyProximalAdagrad { - Status operator()(const CPUDevice& d, typename TTypes::Matrix var, - typename TTypes::Matrix accum, - typename TTypes::ConstScalar lr, - typename TTypes::ConstScalar l1, - typename TTypes::ConstScalar l2, - typename TTypes::ConstMatrix grad, - typename TTypes::ConstVec indices, - int64_t inner_dim) { + absl::Status operator()(const CPUDevice& d, typename TTypes::Matrix var, + typename TTypes::Matrix accum, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar l1, + typename TTypes::ConstScalar l2, + typename TTypes::ConstMatrix grad, + typename TTypes::ConstVec indices, + int64_t inner_dim) { const Tindex N = static_cast(indices.dimension(0)); if (N == 0) return absl::OkStatus(); const Tindex first_dim_size = static_cast(var.dimension(0)); @@ -587,17 +587,18 @@ void ComputeFtrl(GradTy grad, template struct SparseApplyFtrl { - Status operator()(const CPUDevice& d, typename TTypes::Matrix var_flat, - typename TTypes::Matrix accum_flat, - typename TTypes::Matrix linear_flat, - typename TTypes::ConstScalar lr, - typename TTypes::ConstScalar l1, - typename TTypes::ConstScalar l2, - typename TTypes::ConstScalar l2_shrinkage, - typename TTypes::ConstScalar lr_power, - typename TTypes::ConstMatrix grad_flat, - typename TTypes::ConstVec indices_vec, - int64_t inner_dim, bool multiply_linear_by_lr) { + absl::Status operator()(const CPUDevice& d, + typename TTypes::Matrix var_flat, + typename TTypes::Matrix accum_flat, + typename TTypes::Matrix linear_flat, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar l1, + typename TTypes::ConstScalar l2, + typename TTypes::ConstScalar l2_shrinkage, + typename TTypes::ConstScalar lr_power, + typename TTypes::ConstMatrix grad_flat, + typename TTypes::ConstVec indices_vec, + int64_t inner_dim, bool multiply_linear_by_lr) { const Tindex N = static_cast(indices_vec.dimension(0)); if (N > 0) { T lr_scalar = lr(); diff --git a/tensorflow/core/kernels/training_ops.h b/tensorflow/core/kernels/training_ops.h index e0626c3d8dcd1a..8f986d13a9a712 100644 --- a/tensorflow/core/kernels/training_ops.h +++ b/tensorflow/core/kernels/training_ops.h @@ -107,13 +107,13 @@ struct ApplyAdagradDA { template struct SparseApplyAdagrad { // Note that epsilon is ignored if has_epsilon is false. - Status operator()(const Device& d, typename TTypes::Matrix var, - typename TTypes::Matrix accum, - typename TTypes::ConstScalar lr, - typename TTypes::ConstScalar epsilon, - typename TTypes::ConstMatrix grad, - typename TTypes::ConstVec indices, - int64_t inner_dim, bool update_slots); + absl::Status operator()(const Device& d, typename TTypes::Matrix var, + typename TTypes::Matrix accum, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar epsilon, + typename TTypes::ConstMatrix grad, + typename TTypes::ConstVec indices, + int64_t inner_dim, bool update_slots); }; template @@ -128,14 +128,14 @@ struct ApplyProximalAdagrad { template struct SparseApplyProximalAdagrad { - Status operator()(const Device& d, typename TTypes::Matrix var, - typename TTypes::Matrix accum, - typename TTypes::ConstScalar lr, - typename TTypes::ConstScalar l1, - typename TTypes::ConstScalar l2, - typename TTypes::ConstMatrix grad, - typename TTypes::ConstVec indices, - int64_t inner_dim); + absl::Status operator()(const Device& d, typename TTypes::Matrix var, + typename TTypes::Matrix accum, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar l1, + typename TTypes::ConstScalar l2, + typename TTypes::ConstMatrix grad, + typename TTypes::ConstVec indices, + int64_t inner_dim); }; template @@ -190,17 +190,17 @@ struct ApplyFtrlV2MultiplyLinearByLr { template struct SparseApplyFtrl { - Status operator()(const Device& d, typename TTypes::Matrix var_flat, - typename TTypes::Matrix accum_flat, - typename TTypes::Matrix linear_flat, - typename TTypes::ConstScalar lr, - typename TTypes::ConstScalar l1, - typename TTypes::ConstScalar l2, - typename TTypes::ConstScalar l2_shrinkage, - typename TTypes::ConstScalar lr_power, - typename TTypes::ConstMatrix grad_flat, - typename TTypes::ConstVec indices_vec, - int64_t inner_dim, bool multiply_linear_by_lr); + absl::Status operator()(const Device& d, typename TTypes::Matrix var_flat, + typename TTypes::Matrix accum_flat, + typename TTypes::Matrix linear_flat, + typename TTypes::ConstScalar lr, + typename TTypes::ConstScalar l1, + typename TTypes::ConstScalar l2, + typename TTypes::ConstScalar l2_shrinkage, + typename TTypes::ConstScalar lr_power, + typename TTypes::ConstMatrix grad_flat, + typename TTypes::ConstVec indices_vec, + int64_t inner_dim, bool multiply_linear_by_lr); }; template diff --git a/tensorflow/core/kernels/transpose_functor.h b/tensorflow/core/kernels/transpose_functor.h index 5dc5ffcbdb66d3..f4c905b1c27742 100644 --- a/tensorflow/core/kernels/transpose_functor.h +++ b/tensorflow/core/kernels/transpose_functor.h @@ -33,8 +33,8 @@ namespace tensorflow { // REQUIRES: in.dims() == perm.size() // REQUIRES: in.dim_size(perm[i]) == out->dim_size(i) template -Status DoTranspose(const Device& device, const Tensor& in, - const absl::Span perm, Tensor* out); +absl::Status DoTranspose(const Device& device, const Tensor& in, + const absl::Span perm, Tensor* out); // Conjugate and transpose tensor 'in' into tensor 'out' according to dimension // permutation 'perm'. @@ -44,19 +44,21 @@ Status DoTranspose(const Device& device, const Tensor& in, // REQUIRES: in.dims() == perm.size() // REQUIRES: in.dim_size(perm[i]) == out->dim_size(i) template -Status DoConjugateTranspose(const Device& device, const Tensor& in, - const absl::Span perm, Tensor* out); +absl::Status DoConjugateTranspose(const Device& device, const Tensor& in, + const absl::Span perm, + Tensor* out); // Convenience versions of DoTranspose that only swap the last (inner) two // dimensions. template -Status DoMatrixTranspose(const Device& device, const Tensor& in, Tensor* out); +absl::Status DoMatrixTranspose(const Device& device, const Tensor& in, + Tensor* out); // Convenience versions of DoConjugateTranspose that only swap the last (inner) // two dimensions. template -Status DoConjugateMatrixTranspose(const Device& device, const Tensor& in, - Tensor* out); +absl::Status DoConjugateMatrixTranspose(const Device& device, const Tensor& in, + Tensor* out); // Primary device specific functor to be specialized for each device and type. template @@ -164,9 +166,9 @@ void TransposeUsingEigen(const Device& d, const Tensor& in, } template -Status DoTransposeImpl(const Device& d, const Tensor& in, - const absl::Span perm, bool conjugate, - Tensor* out) { +absl::Status DoTransposeImpl(const Device& d, const Tensor& in, + const absl::Span perm, bool conjugate, + Tensor* out) { CHECK_EQ(in.dims(), out->dims()); CHECK_EQ(in.dims(), perm.size()); CHECK_EQ(in.dtype(), out->dtype()); @@ -239,8 +241,9 @@ Status DoTransposeImpl(const Device& d, const Tensor& in, } template -inline Status DoMatrixTransposeImpl(const Device& device, const Tensor& in, - bool conjugate, Tensor* out) { +inline absl::Status DoMatrixTransposeImpl(const Device& device, + const Tensor& in, bool conjugate, + Tensor* out) { const int ndims = in.dims(); if (ndims == 0) return absl::OkStatus(); TransposePermsVec perm(ndims); diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc index cca2f6da33abef..e94b68eaf52db9 100644 --- a/tensorflow/core/kernels/transpose_op.cc +++ b/tensorflow/core/kernels/transpose_op.cc @@ -93,8 +93,8 @@ REGISTER_KERNEL_BUILDER(Name("InvertPermutation") namespace { template -Status PermutationHelper(const Tensor& perm, const int dims, - std::vector* permutation) { +absl::Status PermutationHelper(const Tensor& perm, const int dims, + std::vector* permutation) { auto Vperm = perm.vec(); if (dims != Vperm.size()) { return errors::InvalidArgument("transpose expects a vector of size ", dims, @@ -187,17 +187,18 @@ void TransposeOp::Compute(OpKernelContext* ctx) { } } -Status TransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in, - absl::Span perm, Tensor* out) { +absl::Status TransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in, + absl::Span perm, + Tensor* out) { typedef Eigen::ThreadPoolDevice CPUDevice; return ::tensorflow::DoTranspose(ctx->eigen_device(), in, perm, out); } -Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx, - const Tensor& in, - absl::Span perm, - Tensor* out) { +absl::Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx, + const Tensor& in, + absl::Span perm, + Tensor* out) { typedef Eigen::ThreadPoolDevice CPUDevice; return ::tensorflow::DoConjugateTranspose(ctx->eigen_device(), in, perm, out); diff --git a/tensorflow/core/kernels/transpose_op.h b/tensorflow/core/kernels/transpose_op.h index f758c6ae66471f..8f0405b604f818 100644 --- a/tensorflow/core/kernels/transpose_op.h +++ b/tensorflow/core/kernels/transpose_op.h @@ -28,8 +28,9 @@ class TransposeOp : public OpKernel { void Compute(OpKernelContext* ctx) override; protected: - virtual Status DoTranspose(OpKernelContext* ctx, const Tensor& in, - absl::Span perm, Tensor* out) = 0; + virtual absl::Status DoTranspose(OpKernelContext* ctx, const Tensor& in, + absl::Span perm, + Tensor* out) = 0; virtual bool IsConjugate() const { return false; } }; @@ -38,8 +39,8 @@ class TransposeCpuOp : public TransposeOp { explicit TransposeCpuOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {} protected: - Status DoTranspose(OpKernelContext* ctx, const Tensor& in, - absl::Span perm, Tensor* out) override; + absl::Status DoTranspose(OpKernelContext* ctx, const Tensor& in, + absl::Span perm, Tensor* out) override; }; #if defined(INTEL_MKL) @@ -58,8 +59,8 @@ class TransposeGpuOp : public TransposeOp { explicit TransposeGpuOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {} protected: - Status DoTranspose(OpKernelContext* ctx, const Tensor& in, - absl::Span perm, Tensor* out) override; + absl::Status DoTranspose(OpKernelContext* ctx, const Tensor& in, + absl::Span perm, Tensor* out) override; }; @@ -70,8 +71,8 @@ class ConjugateTransposeCpuOp : public TransposeOp { : TransposeOp(ctx) {} protected: - Status DoTranspose(OpKernelContext* ctx, const Tensor& in, - absl::Span perm, Tensor* out) override; + absl::Status DoTranspose(OpKernelContext* ctx, const Tensor& in, + absl::Span perm, Tensor* out) override; bool IsConjugate() const override { return true; } }; @@ -94,8 +95,8 @@ class ConjugateTransposeGpuOp : public TransposeOp { : TransposeOp(ctx) {} protected: - Status DoTranspose(OpKernelContext* ctx, const Tensor& in, - absl::Span perm, Tensor* out) override; + absl::Status DoTranspose(OpKernelContext* ctx, const Tensor& in, + absl::Span perm, Tensor* out) override; bool IsConjugate() const override { return true; } }; diff --git a/tensorflow/core/kernels/typed_queue.h b/tensorflow/core/kernels/typed_queue.h index 2e67261841859d..e4c82f0ebde03a 100644 --- a/tensorflow/core/kernels/typed_queue.h +++ b/tensorflow/core/kernels/typed_queue.h @@ -36,7 +36,7 @@ class TypedQueue : public QueueBase { const std::vector& component_shapes, const string& name); - virtual Status Initialize(); // Must be called before any other method. + virtual absl::Status Initialize(); // Must be called before any other method. int64_t MemoryUsed() const override; @@ -51,7 +51,7 @@ TypedQueue::TypedQueue( : QueueBase(capacity, component_dtypes, component_shapes, name) {} template -Status TypedQueue::Initialize() { +absl::Status TypedQueue::Initialize() { if (component_dtypes_.empty()) { return errors::InvalidArgument("Empty component types for queue ", name_); } diff --git a/tensorflow/core/kernels/unary_ops_composition.cc b/tensorflow/core/kernels/unary_ops_composition.cc index 98684f382ecd21..112f32a8641fe1 100644 --- a/tensorflow/core/kernels/unary_ops_composition.cc +++ b/tensorflow/core/kernels/unary_ops_composition.cc @@ -56,8 +56,8 @@ struct UnaryOpsCompositionBase { private: friend class UnaryOpsComposition; - Status ExportComputeFns(const std::vector& op_names, - std::vector* fns, int* cost) { + absl::Status ExportComputeFns(const std::vector& op_names, + std::vector* fns, int* cost) { for (const string& op_name : op_names) { auto it = compute_fns.find(op_name); if (it == compute_fns.end()) diff --git a/tensorflow/core/kernels/unicode_ops.cc b/tensorflow/core/kernels/unicode_ops.cc index c18ca23791405e..d5dbe94ce5bb74 100644 --- a/tensorflow/core/kernels/unicode_ops.cc +++ b/tensorflow/core/kernels/unicode_ops.cc @@ -204,7 +204,7 @@ struct ErrorOptions { bool error_on_malformatting = false; }; -Status GetErrorOptions(OpKernelConstruction* ctx, ErrorOptions* out) { +absl::Status GetErrorOptions(OpKernelConstruction* ctx, ErrorOptions* out) { *out = ErrorOptions(); string error_policy; diff --git a/tensorflow/core/kernels/uniform_quant_ops/math_utils.cc b/tensorflow/core/kernels/uniform_quant_ops/math_utils.cc index 413882ca810835..7ff1dc10cd1d7e 100644 --- a/tensorflow/core/kernels/uniform_quant_ops/math_utils.cc +++ b/tensorflow/core/kernels/uniform_quant_ops/math_utils.cc @@ -26,8 +26,8 @@ using errors::InvalidArgument; // Reference: // https://github.com/tensorflow/tensorflow/blob/57946ceb4b6119d6d0f49abbb2e3d1636a3b83a0/tensorflow/lite/kernels/internal/quantization_util.cc#L53 // Where double_multiplier >= 0 and TFLITE_EMULATE_FLOAT is not defined. -Status QuantizeMultiplier(double double_multiplier, - int32_t& quantized_multiplier, int32_t& shift) { +absl::Status QuantizeMultiplier(double double_multiplier, + int32_t& quantized_multiplier, int32_t& shift) { if (!isfinite(double_multiplier) || double_multiplier <= 0) { return InvalidArgument( "double_multiplier must be a poisitive finite number. Given ", diff --git a/tensorflow/core/kernels/uniform_quant_ops/math_utils.h b/tensorflow/core/kernels/uniform_quant_ops/math_utils.h index 8d471f9d21139d..5cd9c1b4746d43 100644 --- a/tensorflow/core/kernels/uniform_quant_ops/math_utils.h +++ b/tensorflow/core/kernels/uniform_quant_ops/math_utils.h @@ -105,10 +105,11 @@ void AffineDequantize(const ConstTensorTin& input_tensor, float scale, // / scale + zero_point), while AffineQuantize() uses floor(input_val * // (1./scale) + 0.5) + zero_point template -Status AsymmetricQuantize(const ConstTensorTin& input_tensor, - int32_t quantization_min_val, - int32_t quantization_max_val, float& scale, - int32& zero_point, TensorTout quantized_tensor) { +absl::Status AsymmetricQuantize(const ConstTensorTin& input_tensor, + int32_t quantization_min_val, + int32_t quantization_max_val, float& scale, + int32& zero_point, + TensorTout quantized_tensor) { if (quantization_min_val >= quantization_max_val) { // NOLINTNEXTLINE return errors::InvalidArgument( @@ -177,8 +178,8 @@ Status AsymmetricQuantize(const ConstTensorTin& input_tensor, // // Output quantized_multiplier is clamped to range [0, INT32_MAX], // and shift is clamped to range [-31, 30]. -Status QuantizeMultiplier(double double_multiplier, - int32_t& quantized_multiplier, int32_t& shift); +absl::Status QuantizeMultiplier(double double_multiplier, + int32_t& quantized_multiplier, int32_t& shift); // Requantize input_val given quantized effective_muliplier|shift and // input|output zero_point. @@ -207,7 +208,7 @@ namespace internal { // Requantize from per-tensor to per-tensor. template -Status PerTensorToPerTensorRequantize( +absl::Status PerTensorToPerTensorRequantize( const Tensor& input, float input_scale, int32_t input_zero_point, float output_scale, int32_t output_zero_point, int32_t quantization_min_val, int32_t quantization_max_val, Tensor& output) { @@ -235,13 +236,14 @@ Status PerTensorToPerTensorRequantize( // - From per-axis to per-tensor. // - From per-axis to per-axis. template -Status PerAxisRequantize(OpKernelContext* context, const Tensor& input, - const Tensor& input_scales, - const Tensor& input_zero_points, - const Tensor& output_scales, - const Tensor& output_zero_points, - int quantization_axis, int32_t quantization_min_val, - int32_t quantization_max_val, Tensor& output) { +absl::Status PerAxisRequantize(OpKernelContext* context, const Tensor& input, + const Tensor& input_scales, + const Tensor& input_zero_points, + const Tensor& output_scales, + const Tensor& output_zero_points, + int quantization_axis, + int32_t quantization_min_val, + int32_t quantization_max_val, Tensor& output) { const bool input_per_axis_quantization = input_scales.dims() == 1; const bool output_per_axis_quantization = output_scales.dims() == 1; const auto& per_axis_scales_shape = input_per_axis_quantization @@ -304,14 +306,12 @@ Status PerAxisRequantize(OpKernelContext* context, const Tensor& input, } // namespace internal template -Status EvalRequantize(OpKernelContext* context, const Tensor& input, - const Tensor& input_scales, - const Tensor& input_zero_points, - const Tensor& output_scales, - const Tensor& output_zero_points, - int input_quantization_axis, int output_quantization_axis, - int32_t quantization_min_val, - int32_t quantization_max_val, Tensor& output) { +absl::Status EvalRequantize( + OpKernelContext* context, const Tensor& input, const Tensor& input_scales, + const Tensor& input_zero_points, const Tensor& output_scales, + const Tensor& output_zero_points, int input_quantization_axis, + int output_quantization_axis, int32_t quantization_min_val, + int32_t quantization_max_val, Tensor& output) { if (input_quantization_axis == -1 && output_quantization_axis == -1) { return internal::PerTensorToPerTensorRequantize( input, input_scales.scalar()(), diff --git a/tensorflow/core/kernels/uniform_quant_ops/tensor_utils.cc b/tensorflow/core/kernels/uniform_quant_ops/tensor_utils.cc index 842b3dd4c683ae..65f09f2da113b8 100644 --- a/tensorflow/core/kernels/uniform_quant_ops/tensor_utils.cc +++ b/tensorflow/core/kernels/uniform_quant_ops/tensor_utils.cc @@ -18,10 +18,10 @@ namespace tensorflow { using tensorflow::errors::InvalidArgument; -Status QuantizationAxisAndShapeValid(const TensorShape& data_shape, - const TensorShape& scales_shape, - const TensorShape& zero_points_shape, - int quantization_axis) { +absl::Status QuantizationAxisAndShapeValid(const TensorShape& data_shape, + const TensorShape& scales_shape, + const TensorShape& zero_points_shape, + int quantization_axis) { if (!scales_shape.IsSameSize(zero_points_shape)) { return InvalidArgument( "scales and zero_points shape must be same, but given scales shape ", diff --git a/tensorflow/core/kernels/uniform_quant_ops/tensor_utils.h b/tensorflow/core/kernels/uniform_quant_ops/tensor_utils.h index 9042325e20a01f..4a303a3ff91fd1 100644 --- a/tensorflow/core/kernels/uniform_quant_ops/tensor_utils.h +++ b/tensorflow/core/kernels/uniform_quant_ops/tensor_utils.h @@ -32,10 +32,10 @@ bool AllElementsPositive(const Tensor& tensor) { // Given data tensor's shape and quantization params, returns if the shapes are // valid. -Status QuantizationAxisAndShapeValid(const TensorShape& data_shape, - const TensorShape& scales_shape, - const TensorShape& zero_points_shape, - int quantization_axis); +absl::Status QuantizationAxisAndShapeValid(const TensorShape& data_shape, + const TensorShape& scales_shape, + const TensorShape& zero_points_shape, + int quantization_axis); // Given in_shape and perm to transpose, returns out shape after the transpose. // perm must be a permutation of [0, 1, ..., in_shape.rank - 1]. The caller is diff --git a/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_add_op.cc b/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_add_op.cc index 6cb4439b25230d..0cd8d4358d851d 100644 --- a/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_add_op.cc +++ b/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_add_op.cc @@ -114,16 +114,14 @@ void QuantizedAdd(const Tensor& lhs, const Tensor& rhs, } template -Status EvalQuantizedAdd(OpKernelContext* context, const Tensor& lhs, - const Tensor& rhs, const Tensor& lhs_scales, - const Tensor& lhs_zero_points, const Tensor& rhs_scales, - const Tensor& rhs_zero_points, - const Tensor& output_scales, - const Tensor& output_zero_points, - int output_quantization_min_val, - int output_quantization_max_val, - int lhs_quantization_axis, int rhs_quantization_axis, - int output_quantization_axis, Tensor& output) { +absl::Status EvalQuantizedAdd( + OpKernelContext* context, const Tensor& lhs, const Tensor& rhs, + const Tensor& lhs_scales, const Tensor& lhs_zero_points, + const Tensor& rhs_scales, const Tensor& rhs_zero_points, + const Tensor& output_scales, const Tensor& output_zero_points, + int output_quantization_min_val, int output_quantization_max_val, + int lhs_quantization_axis, int rhs_quantization_axis, + int output_quantization_axis, Tensor& output) { const DataType dtype = DataTypeToEnum::v(); Tensor zeros_of_output_scales_shape; diff --git a/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_convolution_ops.cc b/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_convolution_ops.cc index 1fa972a8c9801e..b8f5c58038e54c 100644 --- a/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_convolution_ops.cc +++ b/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_convolution_ops.cc @@ -270,7 +270,7 @@ void ConvWithAccFunctionAndOutFunction( // Quantized Conv on per-tensor quantized padded and dilated transposed lhs and // per-tensor quantized transposed rhs. template -Status EvalLhsPerTensorAndRhsPerTensorQuantizedConv( +absl::Status EvalLhsPerTensorAndRhsPerTensorQuantizedConv( const Tensor& lhs, const Tensor& rhs, const UniformQuantizedConvolutionParams& convolution_params, const float lhs_scale, const int32_t lhs_zero_point, const float rhs_scale, @@ -308,7 +308,7 @@ Status EvalLhsPerTensorAndRhsPerTensorQuantizedConv( // Quantized Conv on per-tensor quantized padded and dilated transposed lhs and // per-channel quantized transposed rhs. template -Status EvalLhsPerTensorAndRhsPerChannelQuantizedConv( +absl::Status EvalLhsPerTensorAndRhsPerChannelQuantizedConv( OpKernelContext* context, const Tensor& lhs, const Tensor& rhs, const UniformQuantizedConvolutionParams& convolution_params, const float lhs_scale, const int32_t lhs_zero_point, @@ -449,7 +449,7 @@ void EvalLhsPerBatchAndRhsPerChannelQuantizedConv( // Given quantized `lhs` and quantized `rhs`, performs quantized convolution and // writes to `out`. Assumes that `out` is already allocated with correct size. template -Status EvalQuantizedConv( +absl::Status EvalQuantizedConv( OpKernelContext* context, const Tensor& lhs, const Tensor& rhs, const UniformQuantizedConvolutionParams& convolution_params, const Tensor& lhs_scales, const Tensor& lhs_zero_points, @@ -519,7 +519,7 @@ Status EvalQuantizedConv( // For more details on `lhs` quantization policy, refer to the comment of class // UniformQuantizedConvolutionHybridOp below. template -Status EvalHybridConv( +absl::Status EvalHybridConv( OpKernelContext* context, const Tensor& lhs, const Tensor& rhs, const UniformQuantizedConvolutionParams& convolution_params, const Tensor& rhs_scales, const Tensor& rhs_zero_points, Tensor& out) { diff --git a/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_dot_ops.cc b/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_dot_ops.cc index f2d51987c6f40c..1200a09d3f86ff 100644 --- a/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_dot_ops.cc +++ b/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_dot_ops.cc @@ -23,8 +23,8 @@ namespace { using tensorflow::errors::InvalidArgument; // Given lhs and rhs shapes, returns if the shapes are valid for 2D X 2D dot. -Status DotInputShapeValid(const TensorShape& lhs_shape, - const TensorShape& rhs_shape) { +absl::Status DotInputShapeValid(const TensorShape& lhs_shape, + const TensorShape& rhs_shape) { if (lhs_shape.dims() != 2) { return InvalidArgument("lhs rank must be 2, but given lhs shape ", lhs_shape.DebugString()); @@ -80,7 +80,7 @@ void DotWithAccFunctionAndOutputFunction(const Tensor& lhs, const Tensor& rhs, // Performs dot on per-tensor quantized lhs and per-tensor quantized rhs. template -Status EvalLhsPerTensorAndRhsPerTensorQuantizedDot( +absl::Status EvalLhsPerTensorAndRhsPerTensorQuantizedDot( const Tensor& lhs, const Tensor& rhs, float lhs_scale, int32_t lhs_zero_point, float rhs_scale, int32_t rhs_zero_point, float output_scale, int32_t output_zero_point, @@ -115,7 +115,7 @@ Status EvalLhsPerTensorAndRhsPerTensorQuantizedDot( // Performs dot on per-tensor quantized lhs and per-channel (dimension 1) // quantized rhs. template -Status EvalLhsPerTensorAndRhsPerChannelQuantizedDot( +absl::Status EvalLhsPerTensorAndRhsPerChannelQuantizedDot( OpKernelContext* context, const Tensor& lhs, const Tensor& rhs, float lhs_scale, int32_t lhs_zero_point, const Tensor& rhs_scales, const Tensor& rhs_zero_points, const Tensor& output_scales, @@ -232,14 +232,15 @@ void EvalLhsPerBatchAndRhsPerChannelQuantizedDot( // and produce quantized output. Assumes that output is already allocated with // correct size. template -Status EvalQuantizedDot(OpKernelContext* context, const Tensor& lhs, - const Tensor& rhs, const Tensor& lhs_scales, - const Tensor& lhs_zero_points, const Tensor& rhs_scales, - const Tensor& rhs_zero_points, - const Tensor& output_scales, - const Tensor& output_zero_points, - int output_quantization_min_val, - int output_quantization_max_val, Tensor& output) { +absl::Status EvalQuantizedDot(OpKernelContext* context, const Tensor& lhs, + const Tensor& rhs, const Tensor& lhs_scales, + const Tensor& lhs_zero_points, + const Tensor& rhs_scales, + const Tensor& rhs_zero_points, + const Tensor& output_scales, + const Tensor& output_zero_points, + int output_quantization_min_val, + int output_quantization_max_val, Tensor& output) { const float lhs_scale = lhs_scales.scalar()(); const int32_t lhs_zero_point = lhs_zero_points.scalar()(); if (rhs_scales.dims() != 0) { @@ -265,9 +266,9 @@ Status EvalQuantizedDot(OpKernelContext* context, const Tensor& lhs, // For more details on lhs quantization policy, refer to the comment of class // UniformQuantizedDotHybridOp below. template -Status EvalHybridDot(OpKernelContext* context, const Tensor& lhs, - const Tensor& rhs, const Tensor& rhs_scales, - const Tensor& rhs_zero_points, Tensor& output) { +absl::Status EvalHybridDot(OpKernelContext* context, const Tensor& lhs, + const Tensor& rhs, const Tensor& rhs_scales, + const Tensor& rhs_zero_points, Tensor& output) { const int64_t batches = lhs.dim_size(0); Tensor lhs_quantized; diff --git a/tensorflow/core/kernels/unique_op_gpu.cu.h b/tensorflow/core/kernels/unique_op_gpu.cu.h index d76e1fe989be29..c2c66638edea9b 100644 --- a/tensorflow/core/kernels/unique_op_gpu.cu.h +++ b/tensorflow/core/kernels/unique_op_gpu.cu.h @@ -30,7 +30,10 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/util/gpu_kernel_helper.h" #include "tensorflow/core/util/gpu_solvers.h" // For ScratchSpace +<<<<<<< HEAD #include "xla/stream_executor/gpu/scoped_activate_context.h" +======= +>>>>>>> master #if TENSORFLOW_USE_ROCM #include "tensorflow/core/platform/rocm.h" @@ -329,8 +332,8 @@ class UniqueOpGPU : public AsyncOpKernel { const GPUDevice& device = context->eigen_gpu_device(); int64 uniq_size = (*last_idx_host.data()) + 1; - se::gpu::ScopedActivateContext scoped_activation{ - context->op_device_context()->stream()->parent()}; + std::unique_ptr scoped_activation = + context->op_device_context()->stream()->parent()->Activate(); Tensor unique_input_inds; TIndex* unique_input_inds_ptr = nullptr; diff --git a/tensorflow/core/kernels/variable_ops.cc b/tensorflow/core/kernels/variable_ops.cc index 870dfd01fabc07..cee2a8dda268a5 100644 --- a/tensorflow/core/kernels/variable_ops.cc +++ b/tensorflow/core/kernels/variable_ops.cc @@ -104,7 +104,7 @@ class TemporaryVariableOp : public OpKernel { } void Compute(OpKernelContext* context) override { - Status s; + absl::Status s; ResourceMgr* rm = context->resource_manager(); OP_REQUIRES(context, rm, errors::Internal("No per-step resource manager.")); auto unique_name = TemporaryVariableName(var_name_, context->frame_iter()); diff --git a/tensorflow/core/kernels/variant_ops_util.cc b/tensorflow/core/kernels/variant_ops_util.cc index 7b8f38b5264df3..947d21e13029e7 100644 --- a/tensorflow/core/kernels/variant_ops_util.cc +++ b/tensorflow/core/kernels/variant_ops_util.cc @@ -31,12 +31,12 @@ namespace tensorflow { // : ctx->input(ix).scalar()()) // This reduces (possibly expensive) copying of Variants from // the inputs into temp at the lowest levels of the summation tree. -static inline Status AddVariantTo( +static inline absl::Status AddVariantTo( OpKernelContext* ctx, const int lhs_ix, const int rhs_ix, absl::InlinedVector* temp, absl::InlinedVector* temp_filled, - std::function + std::function binary_add_variant) { Variant tmp; if (temp_filled->at(lhs_ix)) tmp = std::move(temp->at(lhs_ix)); @@ -53,8 +53,8 @@ static inline Status AddVariantTo( } void AddNVariant(OpKernelContext* ctx, - std::function + std::function binary_add_variant) { const Tensor& input0 = ctx->input(0); const int num = ctx->num_inputs(); diff --git a/tensorflow/core/kernels/variant_ops_util.h b/tensorflow/core/kernels/variant_ops_util.h index 7ebe4fb87dfb8d..d6d1e831253ca2 100644 --- a/tensorflow/core/kernels/variant_ops_util.h +++ b/tensorflow/core/kernels/variant_ops_util.h @@ -27,8 +27,8 @@ class Tensor; class Variant; void AddNVariant(OpKernelContext* ctx, - std::function + std::function binary_add_variant); } // namespace tensorflow diff --git a/tensorflow/core/kernels/where_op.cc b/tensorflow/core/kernels/where_op.cc index 9f3600ee922270..1421e24cbb0fdd 100644 --- a/tensorflow/core/kernels/where_op.cc +++ b/tensorflow/core/kernels/where_op.cc @@ -40,12 +40,8 @@ limitations under the License. #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/util/gpu_solvers.h" -#if GOOGLE_CUDA -#include "xla/stream_executor/gpu/scoped_activate_context.h" -using stream_executor::gpu::ScopedActivateContext; -#elif TENSORFLOW_USE_ROCM +#if TENSORFLOW_USE_ROCM #include "tensorflow/core/platform/rocm.h" -using stream_executor::gpu::ScopedActivateContext; #endif // TENSORFLOW_USE_ROCM #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -73,9 +69,9 @@ int64_t CountAccumulator(const bool* begin, const bool* end) { template struct NumTrue { - static Status Compute(OpKernelContext* ctx, const CPUDevice& d, - typename TTypes::ConstFlat input, - TTypes::UnalignedScalar num_true) { + static absl::Status Compute(OpKernelContext* ctx, const CPUDevice& d, + typename TTypes::ConstFlat input, + TTypes::UnalignedScalar num_true) { num_true() = CountAccumulator(input.data(), input.data() + input.size()); return absl::OkStatus(); } @@ -93,7 +89,7 @@ struct Where { } } - EIGEN_ALWAYS_INLINE static Status Compute( + EIGEN_ALWAYS_INLINE static absl::Status Compute( OpKernelContext* ctx, const CPUDevice& d, typename TTypes::ConstTensor input, typename TTypes::Matrix output, TIndex* found_true) { @@ -143,7 +139,7 @@ class WhereCPUOp : public OpKernel { int64_t num_true; TTypes::UnalignedScalar num_true_t(&num_true); - Status s = functor::NumTrue::Compute( + absl::Status s = functor::NumTrue::Compute( context, context->eigen_device(), input.flat(), num_true_t); OP_REQUIRES_OK(context, s); @@ -296,7 +292,8 @@ class WhereGPUOp : public AsyncOpKernel { // configured. auto stream = context->op_device_context()->stream(); { - ScopedActivateContext scoped_activation{stream->parent()}; + std::unique_ptr scoped_activation = + stream->parent()->Activate(); // TODO(ebrevdo): Properly copy back found_true value to CPU for // validation checking. Currently Where::Compute() @@ -348,7 +345,7 @@ class WhereGPUOp : public AsyncOpKernel { // num_true, " elements; but when writing their indices, saw ", // found_true, " elements."), // done); - } // Release ScopedActivateContext to prevent deadlock when done + } // Release ActivateContext to prevent deadlock when done // inlines another Op kernel, which may assume the original cuda // Context. diff --git a/tensorflow/core/kernels/where_op.h b/tensorflow/core/kernels/where_op.h index 9388c8a5557c80..fceea011c9ca52 100644 --- a/tensorflow/core/kernels/where_op.h +++ b/tensorflow/core/kernels/where_op.h @@ -38,7 +38,7 @@ namespace functor { template struct NumTrue { - EIGEN_ALWAYS_INLINE static Status Compute( + EIGEN_ALWAYS_INLINE static absl::Status Compute( OpKernelContext* ctx, const Device& d, typename TTypes::ConstFlat input, typename TTypes::UnalignedScalar num_true); @@ -52,7 +52,7 @@ struct Where { // *found_true != output.dimension(0), // then the input may have changed between the initial counting of // the true values and the call to Where. - EIGEN_ALWAYS_INLINE static Status Compute( + EIGEN_ALWAYS_INLINE static absl::Status Compute( OpKernelContext* ctx, const Device& d, typename TTypes::ConstTensor input, typename TTypes::Matrix output, TIndex* found_true); diff --git a/tensorflow/core/kernels/whole_file_read_ops.cc b/tensorflow/core/kernels/whole_file_read_ops.cc index f57e1d7c602ade..6eb97c95f5df0c 100644 --- a/tensorflow/core/kernels/whole_file_read_ops.cc +++ b/tensorflow/core/kernels/whole_file_read_ops.cc @@ -35,7 +35,8 @@ limitations under the License. namespace tensorflow { template -static Status ReadEntireFile(Env* env, const string& filename, T* contents) { +static absl::Status ReadEntireFile(Env* env, const string& filename, + T* contents) { std::unique_ptr file; TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file)); io::RandomAccessInputStream input_stream(file.get()); @@ -50,8 +51,8 @@ class WholeFileReader : public ReaderBase { : ReaderBase(strings::StrCat("WholeFileReader '", node_name, "'")), env_(env) {} - Status ReadLocked(tstring* key, tstring* value, bool* produced, - bool* at_end) override { + absl::Status ReadLocked(tstring* key, tstring* value, bool* produced, + bool* at_end) override { *key = current_work(); TF_RETURN_IF_ERROR(ReadEntireFile(env_, *key, value)); *produced = true; @@ -61,14 +62,14 @@ class WholeFileReader : public ReaderBase { // Stores state in a ReaderBaseState proto, since WholeFileReader has // no additional state beyond ReaderBase. - Status SerializeStateLocked(tstring* state) override { + absl::Status SerializeStateLocked(tstring* state) override { ReaderBaseState base_state; SaveBaseState(&base_state); SerializeToTString(base_state, state); return absl::OkStatus(); } - Status RestoreStateLocked(const tstring& state) override { + absl::Status RestoreStateLocked(const tstring& state) override { ReaderBaseState base_state; if (!ParseProtoUnlimited(&base_state, state)) { return errors::InvalidArgument("Could not parse state for ", name(), ": ", diff --git a/tensorflow/core/kernels/word2vec_kernels.cc b/tensorflow/core/kernels/word2vec_kernels.cc index bc81562edf3155..7f1dddce884009 100644 --- a/tensorflow/core/kernels/word2vec_kernels.cc +++ b/tensorflow/core/kernels/word2vec_kernels.cc @@ -177,7 +177,7 @@ class SkipgramOp : public OpKernel { *label = sentence_[label_pos_++]; } - Status Init(Env* env, const string& filename) { + absl::Status Init(Env* env, const string& filename) { string data; TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &data)); StringPiece input = data; diff --git a/tensorflow/core/lib/core/BUILD b/tensorflow/core/lib/core/BUILD index b5e24a8e81caa0..80104f9ff01c6e 100644 --- a/tensorflow/core/lib/core/BUILD +++ b/tensorflow/core/lib/core/BUILD @@ -7,11 +7,7 @@ load( package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//tensorflow/cc/saved_model:__subpackages__", - "//tensorflow/core:__subpackages__", - "@local_tsl//tsl:__subpackages__", - ], + default_visibility = ["//tensorflow/core:__subpackages__"], licenses = ["notice"], ) @@ -132,14 +128,9 @@ tf_proto_library( srcs = ["error_codes.proto"], make_default_target_header_only = True, protodeps = [ - "@local_tsl//tsl/protobuf:error_codes_proto_impl", - ], - visibility = [ - "//tensorflow/core:__subpackages__", - "//tensorflow/core/protobuf:__subpackages__", - "@local_tsl//tsl:__subpackages__", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl", ], - exports = ["@local_tsl//tsl/protobuf:error_codes_proto_impl"], + exports = ["@local_xla//xla/tsl/protobuf:error_codes_proto_impl"], ) # Export source files needed for mobile builds, which do not use granular targets. @@ -217,7 +208,7 @@ filegroup( srcs = [ "status.h", ], - visibility = ["//tensorflow/core:__pkg__"], + visibility = ["//visibility:private"], ) filegroup( diff --git a/tensorflow/core/lib/core/error_codes.proto b/tensorflow/core/lib/core/error_codes.proto index 4038d5935d7bfd..ccd552ed76dba6 100644 --- a/tensorflow/core/lib/core/error_codes.proto +++ b/tensorflow/core/lib/core/error_codes.proto @@ -1,3 +1,3 @@ syntax = "proto3"; -import public "tsl/protobuf/error_codes.proto"; +import public "xla/tsl/protobuf/error_codes.proto"; diff --git a/tensorflow/core/lib/core/status_test.cc b/tensorflow/core/lib/core/status_test.cc index af80f615baf65b..bfe401fdcc1cbf 100644 --- a/tensorflow/core/lib/core/status_test.cc +++ b/tensorflow/core/lib/core/status_test.cc @@ -28,57 +28,57 @@ TEST(Status, OK) { EXPECT_EQ(absl::OkStatus().message(), ""); TF_EXPECT_OK(absl::OkStatus()); TF_ASSERT_OK(absl::OkStatus()); - EXPECT_EQ(absl::OkStatus(), Status()); - Status s; + EXPECT_EQ(absl::OkStatus(), absl::Status()); + absl::Status s; EXPECT_TRUE(s.ok()); } TEST(DeathStatus, CheckOK) { - Status status(errors::InvalidArgument("Invalid")); + absl::Status status(errors::InvalidArgument("Invalid")); ASSERT_DEATH(TF_CHECK_OK(status), "Invalid"); } TEST(Status, Set) { - Status status; - status = Status(absl::StatusCode::kCancelled, "Error message"); + absl::Status status; + status = absl::Status(absl::StatusCode::kCancelled, "Error message"); EXPECT_EQ(status.code(), absl::StatusCode::kCancelled); EXPECT_EQ(status.message(), "Error message"); } TEST(Status, Copy) { - Status a(errors::InvalidArgument("Invalid")); - Status b(a); + absl::Status a(errors::InvalidArgument("Invalid")); + absl::Status b(a); ASSERT_EQ(a.ToString(), b.ToString()); } TEST(Status, Assign) { - Status a(errors::InvalidArgument("Invalid")); - Status b; + absl::Status a(errors::InvalidArgument("Invalid")); + absl::Status b; b = a; ASSERT_EQ(a.ToString(), b.ToString()); } TEST(Status, Move) { - Status a(errors::InvalidArgument("Invalid")); - Status b(std::move(a)); + absl::Status a(errors::InvalidArgument("Invalid")); + absl::Status b(std::move(a)); ASSERT_EQ("INVALID_ARGUMENT: Invalid", b.ToString()); } TEST(Status, MoveAssign) { - Status a(errors::InvalidArgument("Invalid")); - Status b; + absl::Status a(errors::InvalidArgument("Invalid")); + absl::Status b; b = std::move(a); ASSERT_EQ("INVALID_ARGUMENT: Invalid", b.ToString()); } TEST(Status, Update) { - Status s; + absl::Status s; s.Update(absl::OkStatus()); ASSERT_TRUE(s.ok()); - Status a(errors::InvalidArgument("Invalid")); + absl::Status a(errors::InvalidArgument("Invalid")); s.Update(a); ASSERT_EQ(s.ToString(), a.ToString()); - Status b(errors::Internal("Internal")); + absl::Status b(errors::Internal("Internal")); s.Update(b); ASSERT_EQ(s.ToString(), a.ToString()); s.Update(absl::OkStatus()); @@ -86,29 +86,29 @@ TEST(Status, Update) { ASSERT_FALSE(s.ok()); } -TEST(Status, EqualsOK) { ASSERT_EQ(absl::OkStatus(), Status()); } +TEST(Status, EqualsOK) { ASSERT_EQ(absl::OkStatus(), absl::Status()); } TEST(Status, EqualsSame) { - Status a(errors::InvalidArgument("Invalid")); - Status b(errors::InvalidArgument("Invalid")); + absl::Status a(errors::InvalidArgument("Invalid")); + absl::Status b(errors::InvalidArgument("Invalid")); ASSERT_EQ(a, b); } TEST(Status, EqualsCopy) { - const Status a(errors::InvalidArgument("Invalid")); - const Status b = a; + const absl::Status a(errors::InvalidArgument("Invalid")); + const absl::Status b = a; ASSERT_EQ(a, b); } TEST(Status, EqualsDifferentCode) { - const Status a(errors::InvalidArgument("message")); - const Status b(errors::Internal("message")); + const absl::Status a(errors::InvalidArgument("message")); + const absl::Status b(errors::Internal("message")); ASSERT_NE(a, b); } TEST(Status, EqualsDifferentMessage) { - const Status a(errors::InvalidArgument("message")); - const Status b(errors::InvalidArgument("another")); + const absl::Status a(errors::InvalidArgument("message")); + const absl::Status b(errors::InvalidArgument("another")); ASSERT_NE(a, b); } @@ -122,17 +122,17 @@ TEST(StatusGroup, OKStatusGroup) { TEST(StatusGroup, AggregateWithSingleErrorStatus) { StatusGroup c; - const Status internal(errors::Internal("Original error.")); + const absl::Status internal(errors::Internal("Original error.")); c.Update(internal); ASSERT_EQ(c.as_summary_status(), internal); - Status concat_status = c.as_concatenated_status(); + absl::Status concat_status = c.as_concatenated_status(); ASSERT_EQ(concat_status.code(), internal.code()); ASSERT_TRUE(absl::StrContains(concat_status.message(), internal.message())); // Add derived error status - const Status derived = + const absl::Status derived = StatusGroup::MakeDerived(errors::Internal("Derived error.")); c.Update(derived); @@ -145,22 +145,22 @@ TEST(StatusGroup, AggregateWithSingleErrorStatus) { TEST(StatusGroup, AggregateWithMultipleErrorStatus) { StatusGroup c; - const Status internal(errors::Internal("Original error.")); - const Status cancelled(errors::Cancelled("Cancelled after 10 steps.")); - const Status aborted(errors::Aborted("Aborted after 10 steps.")); + const absl::Status internal(errors::Internal("Original error.")); + const absl::Status cancelled(errors::Cancelled("Cancelled after 10 steps.")); + const absl::Status aborted(errors::Aborted("Aborted after 10 steps.")); c.Update(internal); c.Update(cancelled); c.Update(aborted); - Status summary = c.as_summary_status(); + absl::Status summary = c.as_summary_status(); ASSERT_EQ(summary.code(), internal.code()); ASSERT_TRUE(absl::StrContains(summary.message(), internal.message())); ASSERT_TRUE(absl::StrContains(summary.message(), cancelled.message())); ASSERT_TRUE(absl::StrContains(summary.message(), aborted.message())); - Status concat_status = c.as_concatenated_status(); + absl::Status concat_status = c.as_concatenated_status(); ASSERT_EQ(concat_status.code(), internal.code()); ASSERT_TRUE(absl::StrContains(concat_status.message(), internal.message())); ASSERT_TRUE(absl::StrContains(concat_status.message(), cancelled.message())); @@ -168,7 +168,7 @@ TEST(StatusGroup, AggregateWithMultipleErrorStatus) { } TEST(Status, InvalidPayloadGetsIgnored) { - Status s = Status(); + absl::Status s = absl::Status(); s.SetPayload("Invalid", absl::Cord("Invalid Val")); ASSERT_FALSE(s.GetPayload("Invalid").has_value()); bool is_err_erased = s.ErasePayload("Invalid"); @@ -176,7 +176,7 @@ TEST(Status, InvalidPayloadGetsIgnored) { } TEST(Status, SetPayloadSetsOrUpdatesIt) { - Status s(absl::StatusCode::kInternal, "Error message"); + absl::Status s(absl::StatusCode::kInternal, "Error message"); s.SetPayload("Error key", absl::Cord("Original")); ASSERT_EQ(s.GetPayload("Error key"), absl::Cord("Original")); s.SetPayload("Error key", absl::Cord("Updated")); @@ -184,7 +184,7 @@ TEST(Status, SetPayloadSetsOrUpdatesIt) { } TEST(Status, ErasePayloadRemovesIt) { - Status s(absl::StatusCode::kInternal, "Error message"); + absl::Status s(absl::StatusCode::kInternal, "Error message"); s.SetPayload("Error key", absl::Cord("Original")); bool is_err_erased = s.ErasePayload("Error key"); @@ -195,9 +195,9 @@ TEST(Status, ErasePayloadRemovesIt) { } static void BM_TF_CHECK_OK(::testing::benchmark::State& state) { - tensorflow::Status s = (state.max_iterations < 0) - ? errors::InvalidArgument("Invalid") - : absl::OkStatus(); + absl::Status s = (state.max_iterations < 0) + ? errors::InvalidArgument("Invalid") + : absl::OkStatus(); for (auto i : state) { TF_CHECK_OK(s); } diff --git a/tensorflow/core/lib/db/sqlite.cc b/tensorflow/core/lib/db/sqlite.cc index c0d366ec695421..65f6492e50cd9d 100644 --- a/tensorflow/core/lib/db/sqlite.cc +++ b/tensorflow/core/lib/db/sqlite.cc @@ -79,7 +79,7 @@ absl::StatusCode GetTfErrorCode(int code) { } template -Status PrintfStatus(int rc, const char* fmt, Args&&... args) { +absl::Status PrintfStatus(int rc, const char* fmt, Args&&... args) { return {GetTfErrorCode(rc), strings::Printf(fmt, std::forward(args)...)}; } @@ -91,7 +91,8 @@ sqlite3_stmt* PrepareRawOrDie(sqlite3* db, const char* sql) { return stmt; } -Status SetPragma(Sqlite* db, const char* pragma, const StringPiece& value) { +absl::Status SetPragma(Sqlite* db, const char* pragma, + const StringPiece& value) { if (value.empty()) return absl::OkStatus(); for (auto p = value.begin(); p < value.end(); ++p) { if (!(('0' <= *p && *p <= '9') || ('A' <= *p && *p <= 'Z') || @@ -111,7 +112,7 @@ const StringPiece GetEnv(const char* var) { return (val == nullptr) ? StringPiece() : StringPiece(val); } -Status EnvPragma(Sqlite* db, const char* pragma, const char* var) { +absl::Status EnvPragma(Sqlite* db, const char* pragma, const char* var) { TF_RETURN_WITH_CONTEXT_IF_ERROR(SetPragma(db, pragma, GetEnv(var)), "getenv(", var, ")"); return absl::OkStatus(); @@ -120,7 +121,7 @@ Status EnvPragma(Sqlite* db, const char* pragma, const char* var) { } // namespace /* static */ -Status Sqlite::Open(const string& path, int flags, Sqlite** db) { +absl::Status Sqlite::Open(const string& path, int flags, Sqlite** db) { flags |= SQLITE_OPEN_PRIVATECACHE; flags |= SQLITE_OPEN_URI; sqlite3* sqlite = nullptr; @@ -139,7 +140,7 @@ Status Sqlite::Open(const string& path, int flags, Sqlite** db) { sqlite3_stmt* commit = PrepareRawOrDie(sqlite, "COMMIT"); sqlite3_stmt* rollback = PrepareRawOrDie(sqlite, "ROLLBACK"); *db = new Sqlite(sqlite, begin, commit, rollback); - Status s = absl::OkStatus(); + absl::Status s = absl::OkStatus(); // Up until 2016 the default SQLite page_size was 1024. This ensures // the new default regardless of linkage unless configured otherwise. s.Update(SetPragma(*db, "page_size", "4096")); @@ -170,7 +171,7 @@ Sqlite::~Sqlite() { CHECK_EQ(SQLITE_OK, sqlite3_close(db_)); } -Status Sqlite::Prepare(const StringPiece& sql, SqliteStatement* stmt) { +absl::Status Sqlite::Prepare(const StringPiece& sql, SqliteStatement* stmt) { SqliteLock lock(*this); sqlite3_stmt* ps = nullptr; int rc = sqlite3_prepare_v2(db_, sql.data(), static_cast(sql.size()), @@ -184,7 +185,7 @@ Status Sqlite::Prepare(const StringPiece& sql, SqliteStatement* stmt) { return absl::OkStatus(); } -Status SqliteStatement::Step(bool* is_done) { +absl::Status SqliteStatement::Step(bool* is_done) { DCHECK(stmt_ != nullptr); if (TF_PREDICT_FALSE(bind_error_ != SQLITE_OK)) { *is_done = true; @@ -214,7 +215,7 @@ bool SqliteStatement::StepOrDie() { return !is_done; } -Status SqliteStatement::StepOnce() { +absl::Status SqliteStatement::StepOnce() { bool is_done; TF_RETURN_IF_ERROR(Step(&is_done)); if (TF_PREDICT_FALSE(is_done)) { @@ -228,9 +229,9 @@ const SqliteStatement& SqliteStatement::StepOnceOrDie() { return *this; } -Status SqliteStatement::StepAndReset() { +absl::Status SqliteStatement::StepAndReset() { bool is_done; - Status s = Step(&is_done); + absl::Status s = Step(&is_done); if (TF_PREDICT_FALSE(s.ok() && !is_done)) { s = errors::Internal("Unexpected row: ", sql()); } @@ -277,7 +278,7 @@ void SqliteTransaction::Begin() { } } -Status SqliteTransaction::Commit() { +absl::Status SqliteTransaction::Commit() { int rc = sqlite3_step(db_->commit_); if (rc != SQLITE_DONE) { return PrintfStatus(rc, "COMMIT failed: [%d] %s", rc, diff --git a/tensorflow/core/lib/db/sqlite.h b/tensorflow/core/lib/db/sqlite.h index 28029020aac6f1..35fc40d3e66ff2 100644 --- a/tensorflow/core/lib/db/sqlite.h +++ b/tensorflow/core/lib/db/sqlite.h @@ -79,7 +79,7 @@ class TF_LOCKABLE Sqlite : public core::RefCounted { /// /// This function sets PRAGMA values from TF_SQLITE_* environment /// variables. See sqlite.cc to learn more. - static Status Open(const string& path, int flags, Sqlite** db); + static absl::Status Open(const string& path, int flags, Sqlite** db); /// \brief Creates SQLite statement. /// @@ -89,7 +89,7 @@ class TF_LOCKABLE Sqlite : public core::RefCounted { /// routine will retry automatically and then possibly fail. /// /// The returned statement holds a reference to this object. - Status Prepare(const StringPiece& sql, SqliteStatement* stmt); + absl::Status Prepare(const StringPiece& sql, SqliteStatement* stmt); SqliteStatement PrepareOrDie(const StringPiece& sql); /// \brief Returns extended result code of last error. @@ -177,7 +177,7 @@ class SqliteStatement { /// /// This statement should be Reset() or destructed when finished with /// the result. - Status Step(bool* is_done); + absl::Status Step(bool* is_done); bool StepOrDie() TF_MUST_USE_RESULT; /// \brief Executes query when only one row is desired. @@ -187,14 +187,14 @@ class SqliteStatement { /// /// This statement should be Reset() or destructed when finished with /// the result. - Status StepOnce(); + absl::Status StepOnce(); const SqliteStatement& StepOnceOrDie(); /// \brief Executes query, ensures zero rows returned, then Reset(). /// /// If a row is returned, an internal error Status is returned that /// won't be reflected in the connection error state. - Status StepAndReset(); + absl::Status StepAndReset(); void StepAndResetOrDie(); /// \brief Resets statement so it can be executed again. @@ -430,7 +430,7 @@ class TF_SCOPED_LOCKABLE SqliteTransaction { /// /// If this is successful, a new transaction will be started, which /// is rolled back when exiting the scope. - Status Commit(); + absl::Status Commit(); private: void Begin(); diff --git a/tensorflow/core/lib/db/sqlite_test.cc b/tensorflow/core/lib/db/sqlite_test.cc index 0a0042fda80dcf..ec394f262c65e7 100644 --- a/tensorflow/core/lib/db/sqlite_test.cc +++ b/tensorflow/core/lib/db/sqlite_test.cc @@ -214,7 +214,7 @@ TEST_F(SqliteTest, Statement_MoveAssignment) { TEST_F(SqliteTest, PrepareFailed) { SqliteLock lock(*db_); SqliteStatement stmt; - Status s = db_->Prepare("SELECT", &stmt); + absl::Status s = db_->Prepare("SELECT", &stmt); ASSERT_FALSE(s.ok()); EXPECT_NE(string::npos, s.message().find("SELECT")); EXPECT_EQ(SQLITE_ERROR, db_->errcode()); @@ -223,7 +223,7 @@ TEST_F(SqliteTest, PrepareFailed) { TEST_F(SqliteTest, BindFailed) { auto stmt = db_->PrepareOrDie("INSERT INTO T (a) VALUES (123)"); stmt.BindInt(1, 123); - Status s = stmt.StepOnce(); + absl::Status s = stmt.StepOnce(); EXPECT_NE(string::npos, s.message().find("INSERT INTO T (a) VALUES (123)")) << s.message(); } diff --git a/tensorflow/core/lib/gtl/BUILD b/tensorflow/core/lib/gtl/BUILD index 868d05f0912fc8..801bf59f7dcf79 100644 --- a/tensorflow/core/lib/gtl/BUILD +++ b/tensorflow/core/lib/gtl/BUILD @@ -48,7 +48,7 @@ cc_library( name = "compactptrset", hdrs = ["compactptrset.h"], deps = [ - "@local_tsl//tsl/lib/gtl:compactptrset", + "@local_xla//xla/tsl/lib/gtl:compactptrset", ], ) @@ -75,7 +75,7 @@ cc_library( "//tensorflow/core/lib/hash", "//tensorflow/core/platform:logging", "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/gtl:flatmap", + "@local_xla//xla/tsl/lib/gtl:flatmap", ], ) @@ -83,7 +83,7 @@ cc_library( name = "flatrep", hdrs = ["flatrep.h"], deps = [ - "@local_tsl//tsl/lib/gtl:flatrep", + "@local_xla//xla/tsl/lib/gtl:flatrep", ], ) @@ -91,7 +91,7 @@ cc_library( name = "flatset", hdrs = ["flatset.h"], deps = [ - "@local_tsl//tsl/lib/gtl:flatset", + "@local_xla//xla/tsl/lib/gtl:flatset", ], ) @@ -102,7 +102,7 @@ cc_library( "//tensorflow/core/platform:macros", "//tensorflow/core/platform:types", "@com_google_absl//absl/container:inlined_vector", - "@local_tsl//tsl/lib/gtl:inlined_vector", + "@local_xla//xla/tsl/lib/gtl:inlined_vector", ], ) @@ -110,7 +110,7 @@ cc_library( name = "int_type", hdrs = ["int_type.h"], deps = [ - "@local_tsl//tsl/lib/gtl:int_type", + "@local_xla//xla/tsl/lib/gtl:int_type", ], ) @@ -118,7 +118,7 @@ cc_library( name = "iterator_range", hdrs = ["iterator_range.h"], deps = [ - "@local_tsl//tsl/lib/gtl:iterator_range", + "@local_xla//xla/tsl/lib/gtl:iterator_range", ], ) @@ -140,7 +140,7 @@ cc_library( hdrs = ["map_util.h"], deps = [ "//tensorflow/core/platform:hash", # TODO(dduneavy) examples/custom_ops_doc transitively depends on this - "@local_tsl//tsl/lib/gtl:map_util", + "@local_xla//xla/tsl/lib/gtl:map_util", ], ) @@ -167,7 +167,7 @@ filegroup( "inlined_vector.h", "iterator_range.h", "priority_queue_util.h", - "@local_tsl//tsl/lib/gtl:legacy_lib_gtl_headers", + "@local_xla//xla/tsl/lib/gtl:legacy_lib_gtl_headers", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -180,7 +180,7 @@ filegroup( "manual_constructor.h", "map_util.h", "top_n.h", - "@local_tsl//tsl/lib/gtl:legacy_lib_internal_public_gtl_headers", + "@local_xla//xla/tsl/lib/gtl:legacy_lib_internal_public_gtl_headers", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -189,7 +189,7 @@ filegroup( name = "legacy_lib_test_internal_headers", srcs = [ "manual_constructor.h", - "@local_tsl//tsl/lib/gtl:legacy_lib_test_internal_headers", + "@local_xla//xla/tsl/lib/gtl:legacy_lib_test_internal_headers", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -198,7 +198,7 @@ filegroup( name = "legacy_android_gif_internal_headers", srcs = [ "cleanup.h", - "@local_tsl//tsl/lib/gtl:legacy_android_gif_internal_headers", + "@local_xla//xla/tsl/lib/gtl:legacy_android_gif_internal_headers", ], visibility = [ "//tensorflow/core:__pkg__", @@ -215,7 +215,7 @@ filegroup( "flatrep.h", "inlined_vector.h", "top_n.h", - "@local_tsl//tsl/lib/gtl:mobile_srcs_no_runtime", + "@local_xla//xla/tsl/lib/gtl:mobile_srcs_no_runtime", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -232,7 +232,7 @@ filegroup( "map_util.h", "priority_queue_util.h", "//tensorflow/core/lib/gtl/subtle:map_traits", - "@local_tsl//tsl/lib/gtl:mobile_srcs_only_runtime", + "@local_xla//xla/tsl/lib/gtl:mobile_srcs_only_runtime", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -255,7 +255,7 @@ filegroup( "priority_queue_util.h", "top_n.h", "//tensorflow/core/lib/gtl/subtle:map_traits", - "@local_tsl//tsl/lib/gtl:legacy_lib_gtl_all_headers", + "@local_xla//xla/tsl/lib/gtl:legacy_lib_gtl_all_headers", ], visibility = ["//tensorflow/core:__pkg__"], ) diff --git a/tensorflow/core/lib/gtl/compactptrset.h b/tensorflow/core/lib/gtl/compactptrset.h index 326aca55d34f0e..6655ac92d99ec7 100644 --- a/tensorflow/core/lib/gtl/compactptrset.h +++ b/tensorflow/core/lib/gtl/compactptrset.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_ #define TENSORFLOW_CORE_LIB_GTL_COMPACTPTRSET_H_ -#include "tsl/lib/gtl/compactptrset.h" +#include "xla/tsl/lib/gtl/compactptrset.h" namespace tensorflow { namespace gtl { diff --git a/tensorflow/core/lib/gtl/edit_distance.h b/tensorflow/core/lib/gtl/edit_distance.h index 818ec69fd96fd7..94a5ad687d678a 100644 --- a/tensorflow/core/lib/gtl/edit_distance.h +++ b/tensorflow/core/lib/gtl/edit_distance.h @@ -44,9 +44,8 @@ namespace gtl { // int64 dist = LevenshteinDistance("hi", "bye", std::equal_to()); // template -inline int64_t LevenshteinDistance(const gtl::ArraySlice& s, - const gtl::ArraySlice& t, - const Cmp& cmp) { +inline int64_t LevenshteinDistance(const gtl::ArraySlice s, + const gtl::ArraySlice t, const Cmp& cmp) { const int64_t s_size = s.size(); const int64_t t_size = t.size(); diff --git a/tensorflow/core/lib/gtl/flatmap.h b/tensorflow/core/lib/gtl/flatmap.h index 15c02b381daa19..3b112a714cb883 100644 --- a/tensorflow/core/lib/gtl/flatmap.h +++ b/tensorflow/core/lib/gtl/flatmap.h @@ -16,11 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_ #define TENSORFLOW_CORE_LIB_GTL_FLATMAP_H_ +#include "xla/tsl/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatrep.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/gtl/flatmap.h" namespace tensorflow { namespace gtl { diff --git a/tensorflow/core/lib/gtl/flatrep.h b/tensorflow/core/lib/gtl/flatrep.h index dfc8af5ef01dfa..59caa4b086708a 100644 --- a/tensorflow/core/lib/gtl/flatrep.h +++ b/tensorflow/core/lib/gtl/flatrep.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_GTL_FLATREP_H_ #define TENSORFLOW_CORE_LIB_GTL_FLATREP_H_ -#include "tsl/lib/gtl/flatrep.h" +#include "xla/tsl/lib/gtl/flatrep.h" namespace tensorflow { namespace gtl { diff --git a/tensorflow/core/lib/gtl/flatset.h b/tensorflow/core/lib/gtl/flatset.h index 659bf98f74f9b1..fcb7ed96b9a166 100644 --- a/tensorflow/core/lib/gtl/flatset.h +++ b/tensorflow/core/lib/gtl/flatset.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ #define TENSORFLOW_CORE_LIB_GTL_FLATSET_H_ -#include "tsl/lib/gtl/flatset.h" +#include "xla/tsl/lib/gtl/flatset.h" namespace tensorflow { namespace gtl { diff --git a/tensorflow/core/lib/gtl/inlined_vector.h b/tensorflow/core/lib/gtl/inlined_vector.h index b94af28b839346..df9d1a245dbf9a 100644 --- a/tensorflow/core/lib/gtl/inlined_vector.h +++ b/tensorflow/core/lib/gtl/inlined_vector.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_ #define TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_ -#include "tsl/lib/gtl/inlined_vector.h" // IWYU pragma: export +#include "xla/tsl/lib/gtl/inlined_vector.h" // IWYU pragma: export // TODO(kramerb): This is kept only because lots of targets transitively depend // on it. Remove all targets' dependencies. #include "tensorflow/core/platform/macros.h" diff --git a/tensorflow/core/lib/gtl/int_type.h b/tensorflow/core/lib/gtl/int_type.h index 2259cb8bc3b84d..c161ee917e82cc 100644 --- a/tensorflow/core/lib/gtl/int_type.h +++ b/tensorflow/core/lib/gtl/int_type.h @@ -17,7 +17,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_GTL_INT_TYPE_H_ #define TENSORFLOW_CORE_LIB_GTL_INT_TYPE_H_ -#include "tsl/lib/gtl/int_type.h" +#include "xla/tsl/lib/gtl/int_type.h" namespace tensorflow { namespace gtl { diff --git a/tensorflow/core/lib/gtl/iterator_range.h b/tensorflow/core/lib/gtl/iterator_range.h index 4748761d8da0a8..ca980fd536b2d8 100644 --- a/tensorflow/core/lib/gtl/iterator_range.h +++ b/tensorflow/core/lib/gtl/iterator_range.h @@ -25,7 +25,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_GTL_ITERATOR_RANGE_H_ #define TENSORFLOW_CORE_LIB_GTL_ITERATOR_RANGE_H_ -#include "tsl/lib/gtl/iterator_range.h" +#include "xla/tsl/lib/gtl/iterator_range.h" namespace tensorflow { namespace gtl { diff --git a/tensorflow/core/lib/gtl/map_util.h b/tensorflow/core/lib/gtl/map_util.h index 3b2d767548809d..47d28e7dd23e1b 100644 --- a/tensorflow/core/lib/gtl/map_util.h +++ b/tensorflow/core/lib/gtl/map_util.h @@ -20,7 +20,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_GTL_MAP_UTIL_H_ #define TENSORFLOW_CORE_LIB_GTL_MAP_UTIL_H_ -#include "tsl/lib/gtl/map_util.h" +#include "xla/tsl/lib/gtl/map_util.h" namespace tensorflow { namespace gtl { diff --git a/tensorflow/core/lib/gtl/subtle/BUILD b/tensorflow/core/lib/gtl/subtle/BUILD index 2b79160c01ea0e..f74d6f7604eec5 100644 --- a/tensorflow/core/lib/gtl/subtle/BUILD +++ b/tensorflow/core/lib/gtl/subtle/BUILD @@ -12,7 +12,7 @@ filegroup( name = "map_traits", srcs = [ "map_traits.h", - "@local_tsl//tsl/lib/gtl/subtle:map_traits", + "@local_xla//xla/tsl/lib/gtl/subtle:map_traits", ], visibility = ["//tensorflow/core/lib/gtl:__pkg__"], ) diff --git a/tensorflow/core/lib/gtl/subtle/map_traits.h b/tensorflow/core/lib/gtl/subtle/map_traits.h index a5296b8b93a010..c4cca1fb644640 100644 --- a/tensorflow/core/lib/gtl/subtle/map_traits.h +++ b/tensorflow/core/lib/gtl/subtle/map_traits.h @@ -23,7 +23,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_GTL_SUBTLE_MAP_TRAITS_H_ #define TENSORFLOW_CORE_LIB_GTL_SUBTLE_MAP_TRAITS_H_ -#include "tsl/lib/gtl/subtle/map_traits.h" +#include "xla/tsl/lib/gtl/subtle/map_traits.h" namespace tensorflow { namespace gtl { diff --git a/tensorflow/core/lib/hash/BUILD b/tensorflow/core/lib/hash/BUILD index c2b6018d034e64..8c1e8cd471776d 100644 --- a/tensorflow/core/lib/hash/BUILD +++ b/tensorflow/core/lib/hash/BUILD @@ -26,7 +26,7 @@ cc_library( "//tensorflow/core/platform", "//tensorflow/core/platform:cord", "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/hash:crc32c", + "@local_xla//xla/tsl/lib/hash:crc32c", ], ) @@ -51,7 +51,7 @@ filegroup( name = "mobile_srcs_only_runtime", srcs = [ "crc32c.h", - "@local_tsl//tsl/lib/hash:mobile_srcs_only_runtime", + "@local_xla//xla/tsl/lib/hash:mobile_srcs_only_runtime", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -61,7 +61,7 @@ filegroup( srcs = [ "crc32c.h", "hash.h", - "@local_tsl//tsl/lib/hash:legacy_lib_hash_all_headers", + "@local_xla//xla/tsl/lib/hash:legacy_lib_hash_all_headers", ], visibility = ["//tensorflow/core:__pkg__"], ) diff --git a/tensorflow/core/lib/hash/crc32c.h b/tensorflow/core/lib/hash/crc32c.h index 07945b10a92b0e..7e8c8307af2e92 100644 --- a/tensorflow/core/lib/hash/crc32c.h +++ b/tensorflow/core/lib/hash/crc32c.h @@ -18,10 +18,10 @@ limitations under the License. #include +#include "xla/tsl/lib/hash/crc32c.h" #include "tensorflow/core/platform/cord.h" #include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/hash/crc32c.h" namespace tensorflow { namespace crc32c { diff --git a/tensorflow/core/lib/io/BUILD b/tensorflow/core/lib/io/BUILD index a525a92c43ba65..9da3498fbc2357 100644 --- a/tensorflow/core/lib/io/BUILD +++ b/tensorflow/core/lib/io/BUILD @@ -8,8 +8,7 @@ package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ "//tensorflow/c/experimental/filesystem:__pkg__", - "//tensorflow/c/experimental/filesystem/plugins/posix:__pkg__", - "@local_tsl//tsl/lib/io/snappy:__pkg__", + "@local_xla//xla/tsl/lib/io/snappy:__pkg__", "//third_party/py/tensorflow_io:__subpackages__", # tensorflow/core:lib effectively exposes all targets under tensorflow/core/lib/** "//tensorflow/core:__pkg__", @@ -31,7 +30,7 @@ cc_library( "//tensorflow/core/platform:status", "//tensorflow/core/platform:stringpiece", "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/io:block", + "@local_xla//xla/tsl/lib/io:block", ], ) @@ -41,14 +40,14 @@ cc_library( deps = [ ":inputstream_interface", "//tensorflow/core/platform:env", - "@local_tsl//tsl/lib/io:buffered_inputstream", + "@local_xla//xla/tsl/lib/io:buffered_inputstream", ], ) cc_library( name = "compression", hdrs = ["compression.h"], - deps = ["@local_tsl//tsl/lib/io:compression"], + deps = ["@local_xla//xla/tsl/lib/io:compression"], ) cc_library( @@ -60,7 +59,7 @@ cc_library( "//tensorflow/core/platform:macros", "//tensorflow/core/platform:status", "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/io:inputbuffer", + "@local_xla//xla/tsl/lib/io:inputbuffer", ], ) @@ -72,7 +71,7 @@ cc_library( "//tensorflow/core/platform:errors", "//tensorflow/core/platform:status", "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/io:inputstream_interface", + "@local_xla//xla/tsl/lib/io:inputstream_interface", ], ) @@ -82,7 +81,7 @@ cc_library( deps = [ "//tensorflow/core/platform:status", "//tensorflow/core/platform:stringpiece", - "@local_tsl//tsl/lib/io:iterator", + "@local_xla//xla/tsl/lib/io:iterator", ], ) @@ -100,7 +99,7 @@ cc_library( "//tensorflow/core/platform:logging", "//tensorflow/core/platform:protobuf", "//tensorflow/core/platform:stringpiece", - "@local_tsl//tsl/lib/io:proto_encode_helper", + "@local_xla//xla/tsl/lib/io:proto_encode_helper", ], ) @@ -111,7 +110,7 @@ cc_library( ":inputstream_interface", "//tensorflow/core/platform:cord", "//tensorflow/core/platform:env", - "@local_tsl//tsl/lib/io:random_inputstream", + "@local_xla//xla/tsl/lib/io:random_inputstream", ], ) @@ -127,7 +126,7 @@ cc_library( "//tensorflow/core/platform:macros", "//tensorflow/core/platform:stringpiece", "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/io:record_reader", + "@local_xla//xla/tsl/lib/io:record_reader", ], ) @@ -145,28 +144,28 @@ cc_library( "//tensorflow/core/platform:status", "//tensorflow/core/platform:stringpiece", "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/io:record_writer", + "@local_xla//xla/tsl/lib/io:record_writer", ], ) alias( name = "snappy_inputbuffer", - actual = "@local_tsl//tsl/lib/io/snappy:snappy_inputbuffer", + actual = "@local_xla//xla/tsl/lib/io/snappy:snappy_inputbuffer", ) alias( name = "snappy_inputstream", - actual = "@local_tsl//tsl/lib/io/snappy:snappy_inputstream", + actual = "@local_xla//xla/tsl/lib/io/snappy:snappy_inputstream", ) alias( name = "snappy_outputbuffer", - actual = "@local_tsl//tsl/lib/io/snappy:snappy_outputbuffer", + actual = "@local_xla//xla/tsl/lib/io/snappy:snappy_outputbuffer", ) alias( name = "snappy_compression_options", - actual = "@local_tsl//tsl/lib/io/snappy:snappy_compression_options", + actual = "@local_xla//xla/tsl/lib/io/snappy:snappy_compression_options", ) cc_library( @@ -174,7 +173,7 @@ cc_library( hdrs = ["cache.h"], deps = [ "//tensorflow/core/platform:stringpiece", - "@local_tsl//tsl/lib/io:cache", + "@local_xla//xla/tsl/lib/io:cache", ], ) @@ -186,14 +185,14 @@ cc_library( ], deps = [ ":iterator", - "@local_tsl//tsl/lib/io:table", + "@local_xla//xla/tsl/lib/io:table", ], ) cc_library( name = "table_options", hdrs = ["table_options.h"], - deps = ["@local_tsl//tsl/lib/io:table_options"], + deps = ["@local_xla//xla/tsl/lib/io:table_options"], ) cc_library( @@ -201,7 +200,7 @@ cc_library( hdrs = ["zlib_compression_options.h"], deps = [ "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/io:zlib_compression_options", + "@local_xla//xla/tsl/lib/io:zlib_compression_options", ], ) @@ -215,7 +214,7 @@ cc_library( "//tensorflow/core/platform:macros", "//tensorflow/core/platform:status", "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/io:zlib_inputstream", + "@local_xla//xla/tsl/lib/io:zlib_inputstream", ], ) @@ -229,7 +228,7 @@ cc_library( "//tensorflow/core/platform:status", "//tensorflow/core/platform:stringpiece", "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/io:zlib_outputbuffer", + "@local_xla//xla/tsl/lib/io:zlib_outputbuffer", ], ) diff --git a/tensorflow/core/lib/io/block.h b/tensorflow/core/lib/io/block.h index e6417881718060..d3cfb88f97e46f 100644 --- a/tensorflow/core/lib/io/block.h +++ b/tensorflow/core/lib/io/block.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_BLOCK_H_ #define TENSORFLOW_CORE_LIB_IO_BLOCK_H_ +#include "xla/tsl/lib/io/block.h" #include "tensorflow/core/lib/io/iterator.h" -#include "tsl/lib/io/block.h" namespace tensorflow { namespace table { diff --git a/tensorflow/core/lib/io/block_builder.h b/tensorflow/core/lib/io/block_builder.h index b83db6dbfa5726..b47278cba40e30 100644 --- a/tensorflow/core/lib/io/block_builder.h +++ b/tensorflow/core/lib/io/block_builder.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_BLOCK_BUILDER_H_ #define TENSORFLOW_CORE_LIB_IO_BLOCK_BUILDER_H_ +#include "xla/tsl/lib/io/block_builder.h" #include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/io/block_builder.h" namespace tensorflow { namespace table { diff --git a/tensorflow/core/lib/io/buffered_inputstream.h b/tensorflow/core/lib/io/buffered_inputstream.h index b211dc05efefc1..15023e6aa5d5b0 100644 --- a/tensorflow/core/lib/io/buffered_inputstream.h +++ b/tensorflow/core/lib/io/buffered_inputstream.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_BUFFERED_INPUTSTREAM_H_ #define TENSORFLOW_CORE_LIB_IO_BUFFERED_INPUTSTREAM_H_ +#include "xla/tsl/lib/io/buffered_inputstream.h" #include "tensorflow/core/lib/io/inputstream_interface.h" #include "tensorflow/core/platform/file_system.h" -#include "tsl/lib/io/buffered_inputstream.h" namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/io/cache.h b/tensorflow/core/lib/io/cache.h index 7c647d80090fdf..3afd011fdf79e5 100644 --- a/tensorflow/core/lib/io/cache.h +++ b/tensorflow/core/lib/io/cache.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_CACHE_H_ #define TENSORFLOW_CORE_LIB_IO_CACHE_H_ +#include "xla/tsl/lib/io/cache.h" #include "tensorflow/core/platform/stringpiece.h" -#include "tsl/lib/io/cache.h" namespace tensorflow { using tsl::Slice; // NOLINT(misc-unused-using-decls) diff --git a/tensorflow/core/lib/io/compression.h b/tensorflow/core/lib/io/compression.h index 326e1a17e5144d..628de3751edb04 100644 --- a/tensorflow/core/lib/io/compression.h +++ b/tensorflow/core/lib/io/compression.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_ #define TENSORFLOW_CORE_LIB_IO_COMPRESSION_H_ -#include "tsl/lib/io/compression.h" +#include "xla/tsl/lib/io/compression.h" namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/io/format.h b/tensorflow/core/lib/io/format.h index 0ea77614bc4a5c..49f96d1929c658 100644 --- a/tensorflow/core/lib/io/format.h +++ b/tensorflow/core/lib/io/format.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_FORMAT_H_ #define TENSORFLOW_CORE_LIB_IO_FORMAT_H_ +#include "xla/tsl/lib/io/format.h" #include "tensorflow/core/lib/io/table_builder.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" -#include "tsl/lib/io/format.h" namespace tensorflow { namespace table { diff --git a/tensorflow/core/lib/io/inputbuffer.h b/tensorflow/core/lib/io/inputbuffer.h index 0cc33cb1aac895..2573a81657c056 100644 --- a/tensorflow/core/lib/io/inputbuffer.h +++ b/tensorflow/core/lib/io/inputbuffer.h @@ -16,12 +16,12 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_INPUTBUFFER_H_ #define TENSORFLOW_CORE_LIB_IO_INPUTBUFFER_H_ +#include "xla/tsl/lib/io/inputbuffer.h" #include "tensorflow/core/platform/coding.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/io/inputbuffer.h" namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/io/inputstream_interface.h b/tensorflow/core/lib/io/inputstream_interface.h index 135043a4f356ba..f38489d55c6d86 100644 --- a/tensorflow/core/lib/io/inputstream_interface.h +++ b/tensorflow/core/lib/io/inputstream_interface.h @@ -16,11 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_INPUTSTREAM_INTERFACE_H_ #define TENSORFLOW_CORE_LIB_IO_INPUTSTREAM_INTERFACE_H_ +#include "xla/tsl/lib/io/inputstream_interface.h" #include "tensorflow/core/platform/cord.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/io/inputstream_interface.h" namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/io/iterator.h b/tensorflow/core/lib/io/iterator.h index a758cf51ef10e9..4f3c096086a4a5 100644 --- a/tensorflow/core/lib/io/iterator.h +++ b/tensorflow/core/lib/io/iterator.h @@ -26,9 +26,9 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_ITERATOR_H_ #define TENSORFLOW_CORE_LIB_IO_ITERATOR_H_ +#include "xla/tsl/lib/io/iterator.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" -#include "tsl/lib/io/iterator.h" namespace tensorflow { namespace table { diff --git a/tensorflow/core/lib/io/proto_encode_helper.h b/tensorflow/core/lib/io/proto_encode_helper.h index 97b98bac26630b..8ca1d5beb300da 100644 --- a/tensorflow/core/lib/io/proto_encode_helper.h +++ b/tensorflow/core/lib/io/proto_encode_helper.h @@ -16,11 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_PROTO_ENCODE_HELPER_H_ #define TENSORFLOW_CORE_LIB_IO_PROTO_ENCODE_HELPER_H_ +#include "xla/tsl/lib/io/proto_encode_helper.h" #include "tensorflow/core/platform/coding.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/stringpiece.h" -#include "tsl/lib/io/proto_encode_helper.h" namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/io/random_inputstream.h b/tensorflow/core/lib/io/random_inputstream.h index cb3c5ed6f98326..70651bc67f3d5c 100644 --- a/tensorflow/core/lib/io/random_inputstream.h +++ b/tensorflow/core/lib/io/random_inputstream.h @@ -16,10 +16,10 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_RANDOM_INPUTSTREAM_H_ #define TENSORFLOW_CORE_LIB_IO_RANDOM_INPUTSTREAM_H_ +#include "xla/tsl/lib/io/random_inputstream.h" #include "tensorflow/core/lib/io/inputstream_interface.h" #include "tensorflow/core/platform/cord.h" #include "tensorflow/core/platform/file_system.h" -#include "tsl/lib/io/random_inputstream.h" namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h index 51a332b6fbd0db..c2a06c6b666908 100644 --- a/tensorflow/core/lib/io/record_reader.h +++ b/tensorflow/core/lib/io/record_reader.h @@ -23,9 +23,9 @@ limitations under the License. #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/lib/io/zlib_inputstream.h" #endif // IS_SLIM_BUILD +#include "xla/tsl/lib/io/record_reader.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/io/record_reader.h" namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h index 63dd44427a2ea1..602de00ed872d5 100644 --- a/tensorflow/core/lib/io/record_writer.h +++ b/tensorflow/core/lib/io/record_writer.h @@ -24,10 +24,10 @@ limitations under the License. #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/lib/io/zlib_outputbuffer.h" #endif // IS_SLIM_BUILD +#include "xla/tsl/lib/io/record_writer.h" #include "tensorflow/core/platform/cord.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/io/record_writer.h" namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/io/table.h b/tensorflow/core/lib/io/table.h index 93466ddaa6e315..0045829a1af5c1 100644 --- a/tensorflow/core/lib/io/table.h +++ b/tensorflow/core/lib/io/table.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_TABLE_H_ #define TENSORFLOW_CORE_LIB_IO_TABLE_H_ +#include "xla/tsl/lib/io/table.h" #include "tensorflow/core/lib/io/iterator.h" -#include "tsl/lib/io/table.h" namespace tensorflow { namespace table { diff --git a/tensorflow/core/lib/io/table_builder.h b/tensorflow/core/lib/io/table_builder.h index a895387a2c3fbe..52e27e9af9ef94 100644 --- a/tensorflow/core/lib/io/table_builder.h +++ b/tensorflow/core/lib/io/table_builder.h @@ -24,10 +24,10 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_TABLE_BUILDER_H_ #define TENSORFLOW_CORE_LIB_IO_TABLE_BUILDER_H_ +#include "xla/tsl/lib/io/table_builder.h" #include "tensorflow/core/lib/io/table_options.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" -#include "tsl/lib/io/table_builder.h" namespace tensorflow { namespace table { diff --git a/tensorflow/core/lib/io/table_options.h b/tensorflow/core/lib/io/table_options.h index b751aefdbeaef3..c16d4aca7e30b6 100644 --- a/tensorflow/core/lib/io/table_options.h +++ b/tensorflow/core/lib/io/table_options.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_TABLE_OPTIONS_H_ #define TENSORFLOW_CORE_LIB_IO_TABLE_OPTIONS_H_ -#include "tsl/lib/io/table_options.h" +#include "xla/tsl/lib/io/table_options.h" namespace tensorflow { namespace table { diff --git a/tensorflow/core/lib/io/two_level_iterator.h b/tensorflow/core/lib/io/two_level_iterator.h index 357efd1d5993b0..c2b94de7f26439 100644 --- a/tensorflow/core/lib/io/two_level_iterator.h +++ b/tensorflow/core/lib/io/two_level_iterator.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_TWO_LEVEL_ITERATOR_H_ #define TENSORFLOW_CORE_LIB_IO_TWO_LEVEL_ITERATOR_H_ +#include "xla/tsl/lib/io/two_level_iterator.h" #include "tensorflow/core/lib/io/iterator.h" -#include "tsl/lib/io/two_level_iterator.h" namespace tensorflow { namespace table { diff --git a/tensorflow/core/lib/io/zlib_compression_options.h b/tensorflow/core/lib/io/zlib_compression_options.h index 643c041ec6efc0..a0d433782b69cb 100644 --- a/tensorflow/core/lib/io/zlib_compression_options.h +++ b/tensorflow/core/lib/io/zlib_compression_options.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_ #define TENSORFLOW_CORE_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_ +#include "xla/tsl/lib/io/zlib_compression_options.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/io/zlib_compression_options.h" namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/io/zlib_inputstream.h b/tensorflow/core/lib/io/zlib_inputstream.h index 75bef87ca38c2e..086493e31face5 100644 --- a/tensorflow/core/lib/io/zlib_inputstream.h +++ b/tensorflow/core/lib/io/zlib_inputstream.h @@ -16,13 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_ZLIB_INPUTSTREAM_H_ #define TENSORFLOW_CORE_LIB_IO_ZLIB_INPUTSTREAM_H_ +#include "xla/tsl/lib/io/zlib_inputstream.h" #include "tensorflow/core/lib/io/inputstream_interface.h" #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/io/zlib_inputstream.h" namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.h b/tensorflow/core/lib/io/zlib_outputbuffer.h index f68a594bda8551..7d3950f633abbe 100644 --- a/tensorflow/core/lib/io/zlib_outputbuffer.h +++ b/tensorflow/core/lib/io/zlib_outputbuffer.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_IO_ZLIB_OUTPUTBUFFER_H_ #define TENSORFLOW_CORE_LIB_IO_ZLIB_OUTPUTBUFFER_H_ +#include "xla/tsl/lib/io/zlib_outputbuffer.h" #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/file_system.h" @@ -23,7 +24,6 @@ limitations under the License. #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/io/zlib_outputbuffer.h" namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/math/BUILD b/tensorflow/core/lib/math/BUILD index 751af8e9d026ac..7c14708ed116ec 100644 --- a/tensorflow/core/lib/math/BUILD +++ b/tensorflow/core/lib/math/BUILD @@ -20,7 +20,7 @@ cc_library( deps = [ "//tensorflow/core/platform:logging", "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/math:math_util", + "@local_xla//xla/tsl/lib/math:math_util", ], ) diff --git a/tensorflow/core/lib/math/math_util.h b/tensorflow/core/lib/math/math_util.h index b92e421b6e2821..39bae7f4308a48 100644 --- a/tensorflow/core/lib/math/math_util.h +++ b/tensorflow/core/lib/math/math_util.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_MATH_MATH_UTIL_H_ #define TENSORFLOW_CORE_LIB_MATH_MATH_UTIL_H_ +#include "xla/tsl/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/math/math_util.h" namespace tensorflow { // NOLINTBEGIN(misc-unused-using-decls) diff --git a/tensorflow/core/lib/png/BUILD b/tensorflow/core/lib/png/BUILD index cdd3491276c9a3..b23730e0fa50fc 100644 --- a/tensorflow/core/lib/png/BUILD +++ b/tensorflow/core/lib/png/BUILD @@ -16,6 +16,7 @@ cc_library( srcs = ["png_io.cc"], hdrs = ["png_io.h"], features = ["-layering_check"], + visibility = ["//visibility:public"], deps = [ "//tensorflow/core/platform:byte_order", "//tensorflow/core/platform:logging", diff --git a/tensorflow/core/lib/random/BUILD b/tensorflow/core/lib/random/BUILD index db2c962671c3f0..ef6262c5876cef 100644 --- a/tensorflow/core/lib/random/BUILD +++ b/tensorflow/core/lib/random/BUILD @@ -13,7 +13,7 @@ package( cc_library( name = "exact_uniform_int", hdrs = ["exact_uniform_int.h"], - deps = ["@local_tsl//tsl/lib/random:exact_uniform_int"], + deps = ["@local_xla//xla/tsl/lib/random:exact_uniform_int"], ) cc_library( @@ -32,7 +32,7 @@ cc_library( "//tensorflow/core/platform:macros", "//tensorflow/core/platform:types", "@eigen_archive//:eigen3", - "@local_tsl//tsl/lib/random:philox", + "@local_xla//xla/tsl/lib/random:philox", ], ) @@ -43,7 +43,7 @@ cc_library( visibility = ["//visibility:private"], deps = [ ":philox_random", - "@local_tsl//tsl/lib/random:random_distributions_utils", + "@local_xla//xla/tsl/lib/random:random_distributions_utils", ], ) @@ -51,7 +51,7 @@ cc_library( name = "philox_random", hdrs = ["philox_random.h"], compatible_with = get_compatible_with_portable(), - deps = ["@local_tsl//tsl/lib/random:philox_random"], + deps = ["@local_xla//xla/tsl/lib/random:philox_random"], ) cc_library( @@ -74,7 +74,7 @@ cc_library( "//tensorflow/core/platform:logging", "//tensorflow/core/platform:macros", "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/random:weighted_picker", + "@local_xla//xla/tsl/lib/random:weighted_picker", ], ) @@ -90,7 +90,7 @@ filegroup( "random_distributions_utils.h", "simple_philox.h", "weighted_picker.h", - "@local_tsl//tsl/lib/random:mobile_srcs_only_runtime", + "@local_xla//xla/tsl/lib/random:mobile_srcs_only_runtime", ], ) @@ -102,7 +102,7 @@ filegroup( "random_distributions.h", "random_distributions_utils.h", "simple_philox.h", - "@local_tsl//tsl/lib/random:legacy_lib_random_headers", + "@local_xla//xla/tsl/lib/random:legacy_lib_random_headers", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -114,7 +114,7 @@ filegroup( "random_distributions.h", "random_distributions_utils.h", "weighted_picker.h", - "@local_tsl//tsl/lib/random:legacy_lib_internal_public_random_headers", + "@local_xla//xla/tsl/lib/random:legacy_lib_internal_public_random_headers", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -130,7 +130,7 @@ filegroup( "random_distributions_utils.h", "simple_philox.h", "weighted_picker.h", - "@local_tsl//tsl/lib/random:legacy_lib_random_all_headers", + "@local_xla//xla/tsl/lib/random:legacy_lib_random_all_headers", ], visibility = ["//tensorflow/core:__pkg__"], ) diff --git a/tensorflow/core/lib/random/distribution_sampler.h b/tensorflow/core/lib/random/distribution_sampler.h index 585def64e91b73..6218d8998fa1ab 100644 --- a/tensorflow/core/lib/random/distribution_sampler.h +++ b/tensorflow/core/lib/random/distribution_sampler.h @@ -31,12 +31,12 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ #define TENSORFLOW_CORE_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ +#include "xla/tsl/lib/random/distribution_sampler.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/random/simple_philox.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/random/distribution_sampler.h" namespace tensorflow { namespace random { diff --git a/tensorflow/core/lib/random/exact_uniform_int.h b/tensorflow/core/lib/random/exact_uniform_int.h index 5f02c664b74f0b..cd511d43f55510 100644 --- a/tensorflow/core/lib/random/exact_uniform_int.h +++ b/tensorflow/core/lib/random/exact_uniform_int.h @@ -18,7 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_RANDOM_EXACT_UNIFORM_INT_H_ #define TENSORFLOW_CORE_LIB_RANDOM_EXACT_UNIFORM_INT_H_ -#include "tsl/lib/random/exact_uniform_int.h" +#include "xla/tsl/lib/random/exact_uniform_int.h" namespace tensorflow { namespace random { diff --git a/tensorflow/core/lib/random/philox_random.h b/tensorflow/core/lib/random/philox_random.h index c2d44ecfd541de..2fe4120f9674b3 100644 --- a/tensorflow/core/lib/random/philox_random.h +++ b/tensorflow/core/lib/random/philox_random.h @@ -20,7 +20,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_RANDOM_PHILOX_RANDOM_H_ #define TENSORFLOW_CORE_LIB_RANDOM_PHILOX_RANDOM_H_ -#include "tsl/lib/random/philox_random.h" +#include "xla/tsl/lib/random/philox_random.h" namespace tensorflow { namespace random { diff --git a/tensorflow/core/lib/random/random_distributions.h b/tensorflow/core/lib/random/random_distributions.h index 0a9e4f94f3d72d..57ce99a07333b0 100644 --- a/tensorflow/core/lib/random/random_distributions.h +++ b/tensorflow/core/lib/random/random_distributions.h @@ -17,10 +17,10 @@ limitations under the License. #define TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "xla/tsl/lib/random/random_distributions.h" #include "tensorflow/core/lib/random/philox_random.h" #include "tensorflow/core/lib/random/random_distributions_utils.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/random/random_distributions.h" namespace tensorflow { namespace random { diff --git a/tensorflow/core/lib/random/random_distributions_utils.h b/tensorflow/core/lib/random/random_distributions_utils.h index 4b7267031bdd86..4c2680493bceae 100644 --- a/tensorflow/core/lib/random/random_distributions_utils.h +++ b/tensorflow/core/lib/random/random_distributions_utils.h @@ -20,8 +20,8 @@ limitations under the License. #include +#include "xla/tsl/lib/random/random_distributions_utils.h" #include "tensorflow/core/lib/random/philox_random.h" -#include "tsl/lib/random/random_distributions_utils.h" namespace tensorflow { namespace random { diff --git a/tensorflow/core/lib/random/simple_philox.h b/tensorflow/core/lib/random/simple_philox.h index fa7f49ebecfaaf..7c94ca21414459 100644 --- a/tensorflow/core/lib/random/simple_philox.h +++ b/tensorflow/core/lib/random/simple_philox.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_RANDOM_SIMPLE_PHILOX_H_ #define TENSORFLOW_CORE_LIB_RANDOM_SIMPLE_PHILOX_H_ +#include "xla/tsl/lib/random/simple_philox.h" #include "tensorflow/core/lib/random/philox_random.h" #include "tensorflow/core/lib/random/random_distributions.h" -#include "tsl/lib/random/simple_philox.h" namespace tensorflow { namespace random { diff --git a/tensorflow/core/lib/random/weighted_picker.h b/tensorflow/core/lib/random/weighted_picker.h index 58d4198ae97d4a..ae404814960096 100644 --- a/tensorflow/core/lib/random/weighted_picker.h +++ b/tensorflow/core/lib/random/weighted_picker.h @@ -29,10 +29,10 @@ limitations under the License. #include +#include "xla/tsl/lib/random/weighted_picker.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/random/weighted_picker.h" namespace tensorflow { namespace random { diff --git a/tensorflow/core/lib/strings/proto_serialization_test.cc b/tensorflow/core/lib/strings/proto_serialization_test.cc index 216075830c4fa0..0ffe00f36c9fd6 100644 --- a/tensorflow/core/lib/strings/proto_serialization_test.cc +++ b/tensorflow/core/lib/strings/proto_serialization_test.cc @@ -62,7 +62,7 @@ static void BM_ProtoSerializationToBuffer(::testing::benchmark::State& state) { const size_t size = graph_def.ByteSizeLong(); for (auto i : state) { - gtl::InlinedVector buf(size); + absl::InlinedVector buf(size); testing::DoNotOptimize( SerializeToBufferDeterministic(graph_def, buf.data(), size)); } diff --git a/tensorflow/core/lib/strings/proto_text_util.h b/tensorflow/core/lib/strings/proto_text_util.h index 44559578da02e9..af288e0738011f 100644 --- a/tensorflow/core/lib/strings/proto_text_util.h +++ b/tensorflow/core/lib/strings/proto_text_util.h @@ -85,8 +85,7 @@ class ProtoTextOutput { // Appends a string value, like my_field: "abc123". void AppendString(const char field_name[], const string& value) { - AppendFieldAndValue( - field_name, StrCat("\"", ::tensorflow::str_util::CEscape(value), "\"")); + AppendFieldAndValue(field_name, StrCat("\"", absl::CEscape(value), "\"")); } // Appends a string value, like my_field: "abc123", but only if value is not diff --git a/tensorflow/core/nccl/BUILD b/tensorflow/core/nccl/BUILD index 99dc7559d59982..ab490c04f068e3 100644 --- a/tensorflow/core/nccl/BUILD +++ b/tensorflow/core/nccl/BUILD @@ -32,7 +32,6 @@ cc_library( "@local_config_nccl//:nccl", "//tensorflow/core/platform:blocking_counter", "//tensorflow/core/platform:unbounded_work_queue", - "@local_xla//xla/stream_executor/gpu:scoped_activate_context", "//tensorflow/core/framework:tensor_proto_cc", ]) + if_rocm([ "@local_config_rocm//rocm:rccl", diff --git a/tensorflow/core/nccl/nccl_manager.cc b/tensorflow/core/nccl/nccl_manager.cc index cf3ceb670ba717..d48c1484ae7f32 100644 --- a/tensorflow/core/nccl/nccl_manager.cc +++ b/tensorflow/core/nccl/nccl_manager.cc @@ -28,15 +28,12 @@ limitations under the License. #include "tensorflow/core/profiler/lib/annotated_traceme.h" #include "tensorflow/core/profiler/lib/connected_traceme.h" #include "tensorflow/core/profiler/lib/traceme.h" -#if GOOGLE_CUDA -#include "xla/stream_executor/gpu/scoped_activate_context.h" -#elif TENSORFLOW_USE_ROCM +#if TENSORFLOW_USE_ROCM #include "tensorflow/core/platform/rocm.h" #endif namespace tensorflow { -using stream_executor::gpu::ScopedActivateContext; #if TENSORFLOW_USE_ROCM // Local hipify of cuda symbols #define cudaError_t hipError_t @@ -719,7 +716,8 @@ void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) { #else se::Stream* comm_stream = nccl_stream->stream.get(); #endif - ScopedActivateContext scoped_context(nccl_stream->executor); + std::unique_ptr scoped_context = + nccl_stream->executor->Activate(); cudaStream_t cu_stream = reinterpret_cast( comm_stream->platform_specific_handle().stream); diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc index b05c4125eaa9bd..b1f50645f1dafb 100644 --- a/tensorflow/core/ops/array_grad.cc +++ b/tensorflow/core/ops/array_grad.cc @@ -33,7 +33,7 @@ REGISTER_OP_NO_GRADIENT("FakeQuantWithMinMaxArgsGradient"); REGISTER_OP_NO_GRADIENT("FakeQuantWithMinMaxVarsGradient"); REGISTER_OP_NO_GRADIENT("FakeQuantWithMinMaxVarsPerChannelGradient"); -Status ReshapeGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status ReshapeGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FDH::Define( // Arg defs @@ -54,7 +54,7 @@ Status ReshapeGrad(const AttrSlice& attrs, FunctionDef* g) { REGISTER_OP_GRADIENT("Reshape", ReshapeGrad); REGISTER_OP_GRADIENT("ExpandDims", ReshapeGrad); -Status SqueezeGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status SqueezeGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FDH::Define( // Arg defs @@ -73,7 +73,7 @@ Status SqueezeGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Squeeze", SqueezeGrad); -Status IdentityGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status IdentityGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FDH::Define( // Arg defs @@ -92,7 +92,7 @@ Status IdentityGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Identity", IdentityGrad); -Status PackGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status PackGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FDH::Create( "_", @@ -118,7 +118,7 @@ Status PackGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Pack", PackGrad); -Status UnpackGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status UnpackGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FDH::Define( // Arg defs @@ -142,8 +142,8 @@ Status UnpackGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Unpack", UnpackGrad); -Status ConcatGradHelper(const AttrSlice& attrs, FunctionDef* g, - bool dim_is_last_arg) { +absl::Status ConcatGradHelper(const AttrSlice& attrs, FunctionDef* g, + bool dim_is_last_arg) { int N; TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "N", &N)); DataType T; @@ -215,18 +215,18 @@ Status ConcatGradHelper(const AttrSlice& attrs, FunctionDef* g, return absl::OkStatus(); } -Status ConcatGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status ConcatGrad(const AttrSlice& attrs, FunctionDef* g) { return ConcatGradHelper(attrs, g, false); } -Status ConcatGradV2(const AttrSlice& attrs, FunctionDef* g) { +absl::Status ConcatGradV2(const AttrSlice& attrs, FunctionDef* g) { return ConcatGradHelper(attrs, g, true); } REGISTER_OP_GRADIENT("Concat", ConcatGrad); REGISTER_OP_GRADIENT("ConcatV2", ConcatGradV2); -Status SplitGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status SplitGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FDH::Define( // Arg defs @@ -246,7 +246,7 @@ Status SplitGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Split", SplitGrad); -Status SplitVGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status SplitVGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FDH::Define( // Arg defs @@ -267,7 +267,7 @@ Status SplitVGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("SplitV", SplitVGrad); -Status ArrayToListGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status ArrayToListGrad(const AttrSlice& attrs, FunctionDef* g) { int N; TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "N", &N)); std::vector dys; @@ -294,7 +294,7 @@ Status ArrayToListGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("_ArrayToList", ArrayToListGrad); -Status ListToArrayGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status ListToArrayGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FDH::Define( // Arg defs @@ -314,7 +314,7 @@ Status ListToArrayGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("_ListToArray", ListToArrayGrad); -Status FillGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status FillGrad(const AttrSlice& attrs, FunctionDef* g) { *g = FDH::Define( // Arg defs {"dims: int32", "x: T", "dy: T"}, @@ -337,7 +337,7 @@ Status FillGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Fill", FillGrad); -Status TransposeGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status TransposeGrad(const AttrSlice& attrs, FunctionDef* g) { *g = FDH::Define( // Arg defs {"x: T", "p: int32", "dy: T"}, @@ -356,7 +356,7 @@ Status TransposeGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Transpose", TransposeGrad); -Status GatherNdGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status GatherNdGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FDH::Define( // Arg defs @@ -377,7 +377,7 @@ Status GatherNdGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("GatherNd", GatherNdGrad); -Status ConjugateTransposeGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status ConjugateTransposeGrad(const AttrSlice& attrs, FunctionDef* g) { *g = FDH::Define( // Arg defs {"x: T", "p: int32", "dy: T"}, @@ -396,7 +396,7 @@ Status ConjugateTransposeGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("ConjugateTranspose", ConjugateTransposeGrad); -Status ReverseGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status ReverseGrad(const AttrSlice& attrs, FunctionDef* g) { *g = FDH::Define( // Arg defs {"x: T", "d: bool", "dy: T"}, @@ -414,7 +414,7 @@ Status ReverseGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Reverse", ReverseGrad); -Status ReverseV2Grad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status ReverseV2Grad(const AttrSlice& attrs, FunctionDef* g) { DataType itype; TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Tidx", &itype)); if (itype != DT_INT32) { @@ -438,7 +438,7 @@ Status ReverseV2Grad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("ReverseV2", ReverseV2Grad); -Status SliceGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status SliceGrad(const AttrSlice& attrs, FunctionDef* g) { DataType itype; TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Index", &itype)); if (itype != DT_INT32) { @@ -473,7 +473,7 @@ Status SliceGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Slice", SliceGrad); -Status StridedSliceGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status StridedSliceGrad(const AttrSlice& attrs, FunctionDef* g) { DataType itype; TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Index", &itype)); if (itype != DT_INT32) { @@ -510,7 +510,7 @@ Status StridedSliceGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("StridedSlice", StridedSliceGrad); -Status StridedSliceGradGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status StridedSliceGradGrad(const AttrSlice& attrs, FunctionDef* g) { DataType itype; TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Index", &itype)); if (itype != DT_INT32) { @@ -552,7 +552,7 @@ Status StridedSliceGradGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("StridedSliceGrad", StridedSliceGradGrad); -Status BroadcastToGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status BroadcastToGrad(const AttrSlice& attrs, FunctionDef* g) { DataType itype; TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Tidx", &itype)); if (itype != DT_INT32) { diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 942a45b0e7b56e..8d53c6dbb38425 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -42,8 +42,8 @@ using shape_inference::UnchangedShape; namespace { -Status GetAxisForPackAndUnpack(InferenceContext* c, int32_t rank_after_pack, - int32* axis) { +absl::Status GetAxisForPackAndUnpack(InferenceContext* c, + int32_t rank_after_pack, int32* axis) { TF_RETURN_IF_ERROR(c->GetAttr("axis", axis)); if (*axis < -1 * rank_after_pack || *axis >= rank_after_pack) { return errors::InvalidArgument("Invalid axis: ", *axis, "; must be in [", @@ -65,8 +65,8 @@ std::vector AsInt64(const Tensor* tensor, int64_t num_elements) { } template -Status PadKnown(InferenceContext* c, ShapeHandle input, - const Tensor* paddings_t, int64_t num_dims) { +absl::Status PadKnown(InferenceContext* c, ShapeHandle input, + const Tensor* paddings_t, int64_t num_dims) { // paddings_t is known. std::vector dims(num_dims); auto paddings_data = paddings_t->matrix(); @@ -82,7 +82,7 @@ Status PadKnown(InferenceContext* c, ShapeHandle input, return absl::OkStatus(); } -Status PadShapeFn(InferenceContext* c) { +absl::Status PadShapeFn(InferenceContext* c) { // Paddings is a matrix of [input_rank, 2]. ShapeHandle paddings; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &paddings)); @@ -122,7 +122,7 @@ Status PadShapeFn(InferenceContext* c) { } } -Status TransposeShapeFn(InferenceContext* c) { +absl::Status TransposeShapeFn(InferenceContext* c) { ShapeHandle input = c->input(0); ShapeHandle perm_shape = c->input(1); const Tensor* perm = c->input_tensor(1); @@ -188,7 +188,7 @@ Status TransposeShapeFn(InferenceContext* c) { return absl::OkStatus(); } -Status SetOutputShapeForReshape(InferenceContext* c) { +absl::Status SetOutputShapeForReshape(InferenceContext* c) { ShapeHandle in = c->input(0); ShapeHandle out; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out)); @@ -1120,7 +1120,7 @@ REGISTER_OP("Fill") .Attr("index_type: {int32, int64} = DT_INT32") .SetShapeFn([](InferenceContext* c) { DataType index_type = DT_INT32; - Status s = c->GetAttr("index_type", &index_type); + absl::Status s = c->GetAttr("index_type", &index_type); if (!s.ok() && s.code() != error::NOT_FOUND) { return s; } @@ -1460,7 +1460,7 @@ REGISTER_OP("_MklConjugateTranspose") // -------------------------------------------------------------------------- namespace { -Status UniqueIdxShapeFn(InferenceContext* c) { +absl::Status UniqueIdxShapeFn(InferenceContext* c) { ShapeHandle input = c->input(0); const Tensor* axis_t = c->input_tensor(1); if (axis_t == nullptr || !c->RankKnown(input)) { @@ -1565,7 +1565,7 @@ REGISTER_OP("UniqueWithCountsV2") namespace { -Status ShapeShapeFn(InferenceContext* c) { +absl::Status ShapeShapeFn(InferenceContext* c) { for (int i = 0; i < c->num_inputs(); ++i) { DimensionHandle dim; if (c->RankKnown(c->input(i))) { @@ -1973,8 +1973,8 @@ REGISTER_OP("MirrorPad") // -------------------------------------------------------------------------- namespace { template -Status MirrorPadKnown(InferenceContext* c, ShapeHandle input, - const Tensor* paddings_t, int64_t input_rank) { +absl::Status MirrorPadKnown(InferenceContext* c, ShapeHandle input, + const Tensor* paddings_t, int64_t input_rank) { auto paddings_data = paddings_t->matrix(); std::vector dims(input_rank); for (int64_t i = 0; i < input_rank; ++i) { @@ -2244,11 +2244,12 @@ std::vector GetFlatInt64(const Tensor& t) { } } -Status SpaceToBatchShapeHelper(InferenceContext* c, ShapeHandle input_shape, - ShapeHandle block_shape_shape, - const Tensor* block_shape_t, - ShapeHandle paddings_shape, - const Tensor* paddings_t) { +absl::Status SpaceToBatchShapeHelper(InferenceContext* c, + ShapeHandle input_shape, + ShapeHandle block_shape_shape, + const Tensor* block_shape_t, + ShapeHandle paddings_shape, + const Tensor* paddings_t) { if (c->Rank(block_shape_shape) != 1) { return errors::InvalidArgument("block_shape must have rank 1."); } @@ -2320,10 +2321,12 @@ Status SpaceToBatchShapeHelper(InferenceContext* c, ShapeHandle input_shape, return absl::OkStatus(); } -Status BatchToSpaceShapeHelper(InferenceContext* c, ShapeHandle input_shape, - ShapeHandle block_shape_shape, - const Tensor* block_shape_t, - ShapeHandle crops_shape, const Tensor* crops_t) { +absl::Status BatchToSpaceShapeHelper(InferenceContext* c, + ShapeHandle input_shape, + ShapeHandle block_shape_shape, + const Tensor* block_shape_t, + ShapeHandle crops_shape, + const Tensor* crops_t) { if (c->Rank(block_shape_shape) != 1) { return errors::InvalidArgument("block_shape must have rank 1."); } @@ -3018,7 +3021,7 @@ REGISTER_OP("Dequantize") .Attr("dtype: {bfloat16, float} = DT_FLOAT") .SetShapeFn([](InferenceContext* c) { int axis = -1; - Status s = c->GetAttr("axis", &axis); + absl::Status s = c->GetAttr("axis", &axis); if (!s.ok() && s.code() != error::NOT_FOUND) { return s; } @@ -3126,7 +3129,7 @@ REGISTER_OP("QuantizedInstanceNorm") namespace { -Status ScatterNdTensorShape(InferenceContext* c) { +absl::Status ScatterNdTensorShape(InferenceContext* c) { ShapeHandle output_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &output_shape)); ShapeHandle indices_shape; diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index 0b42309b349608..5546a6c158e7f1 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -414,7 +414,7 @@ TEST(ArrayOpsTest, Shape_ShapeFn) { } // Runs type inference pass on graph -static Status type_inference(Graph& graph) { +static absl::Status type_inference(Graph& graph) { GraphOptimizationPassOptions opt_options; std::unique_ptr graph_ptr(new Graph(OpRegistry::Global())); graph_ptr->Copy(graph); diff --git a/tensorflow/core/ops/audio_ops.cc b/tensorflow/core/ops/audio_ops.cc index a8fc883e98d1dd..4d05a4645ccdf2 100644 --- a/tensorflow/core/ops/audio_ops.cc +++ b/tensorflow/core/ops/audio_ops.cc @@ -27,7 +27,7 @@ using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; -Status DecodeWavShapeFn(InferenceContext* c) { +absl::Status DecodeWavShapeFn(InferenceContext* c) { ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); @@ -60,7 +60,7 @@ Status DecodeWavShapeFn(InferenceContext* c) { return absl::OkStatus(); } -Status EncodeWavShapeFn(InferenceContext* c) { +absl::Status EncodeWavShapeFn(InferenceContext* c) { ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); @@ -68,7 +68,7 @@ Status EncodeWavShapeFn(InferenceContext* c) { return absl::OkStatus(); } -Status SpectrogramShapeFn(InferenceContext* c) { +absl::Status SpectrogramShapeFn(InferenceContext* c) { ShapeHandle input; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input)); int32_t window_size; @@ -110,7 +110,7 @@ Status SpectrogramShapeFn(InferenceContext* c) { return absl::OkStatus(); } -Status MfccShapeFn(InferenceContext* c) { +absl::Status MfccShapeFn(InferenceContext* c) { ShapeHandle spectrogram; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &spectrogram)); ShapeHandle unused; diff --git a/tensorflow/core/ops/candidate_sampling_ops.cc b/tensorflow/core/ops/candidate_sampling_ops.cc index 403fae4b19321a..c44114a787a123 100644 --- a/tensorflow/core/ops/candidate_sampling_ops.cc +++ b/tensorflow/core/ops/candidate_sampling_ops.cc @@ -24,7 +24,7 @@ using shape_inference::ShapeHandle; namespace { -Status CandidateSamplerShapeFn(InferenceContext* c) { +absl::Status CandidateSamplerShapeFn(InferenceContext* c) { int64_t num_sampled; TF_RETURN_IF_ERROR(c->GetAttr("num_sampled", &num_sampled)); int64_t num_true; diff --git a/tensorflow/core/ops/compat/op_compatibility_lib.cc b/tensorflow/core/ops/compat/op_compatibility_lib.cc index 7f0fa492b53196..9626cbb4ccabc5 100644 --- a/tensorflow/core/ops/compat/op_compatibility_lib.cc +++ b/tensorflow/core/ops/compat/op_compatibility_lib.cc @@ -51,13 +51,13 @@ static void AddNewOpToHistory(const OpDef& op, } } -static Status ReadOpHistory(Env* env, const string& file, - const string& directory, - OpCompatibilityLib::OpHistory* out) { +static absl::Status ReadOpHistory(Env* env, const string& file, + const string& directory, + OpCompatibilityLib::OpHistory* out) { // Read op history form `directory` if it exists there. std::vector matching_files; - Status status = env->GetMatchingPaths(io::JoinPath(directory, "*.pbtxt"), - &matching_files); + absl::Status status = env->GetMatchingPaths( + io::JoinPath(directory, "*.pbtxt"), &matching_files); if (status.ok() && !matching_files.empty()) { printf("Reading op history from %s/*.pbtxt...\n", directory.c_str()); std::sort(matching_files.begin(), matching_files.end()); @@ -110,9 +110,9 @@ OpCompatibilityLib::OpCompatibilityLib(const string& ops_prefix, OpRegistry::Global()->Export(false, &op_list_); } -Status OpCompatibilityLib::ValidateCompatible(Env* env, int* changed_ops, - int* added_ops, - OpHistory* out_op_history) { +absl::Status OpCompatibilityLib::ValidateCompatible(Env* env, int* changed_ops, + int* added_ops, + OpHistory* out_op_history) { *changed_ops = 0; *added_ops = 0; diff --git a/tensorflow/core/ops/compat/op_compatibility_lib.h b/tensorflow/core/ops/compat/op_compatibility_lib.h index 1693f2bd5cfe27..776a603966252b 100644 --- a/tensorflow/core/ops/compat/op_compatibility_lib.h +++ b/tensorflow/core/ops/compat/op_compatibility_lib.h @@ -70,8 +70,8 @@ class OpCompatibilityLib { // generate a new history adding all changed ops. Sets // *changed_ops/*added_ops to the number of changed/added ops // (ignoring doc changes). - Status ValidateCompatible(Env* env, int* changed_ops, int* added_ops, - OpHistory* out_op_history); + absl::Status ValidateCompatible(Env* env, int* changed_ops, int* added_ops, + OpHistory* out_op_history); private: const string ops_file_; diff --git a/tensorflow/core/ops/compat/update_ops_main.cc b/tensorflow/core/ops/compat/update_ops_main.cc index eae80b8a94f5ee..da5e17e1f69fe4 100644 --- a/tensorflow/core/ops/compat/update_ops_main.cc +++ b/tensorflow/core/ops/compat/update_ops_main.cc @@ -53,7 +53,7 @@ void WriteUpdateTo(const string& directory) { printf("%d changed ops\n%d added ops\n", changed_ops, added_ops); const string& history_dir = compatibility.op_history_directory(); - Status status = env->CreateDir(history_dir); + absl::Status status = env->CreateDir(history_dir); if (!errors::IsAlreadyExists(status)) { TF_QCHECK_OK(status); } diff --git a/tensorflow/core/ops/control_flow_ops.cc b/tensorflow/core/ops/control_flow_ops.cc index b7537609b08f44..4ff25885c5562b 100644 --- a/tensorflow/core/ops/control_flow_ops.cc +++ b/tensorflow/core/ops/control_flow_ops.cc @@ -27,7 +27,7 @@ using shape_inference::ShapeHandle; // -------------------------------------------------------------------------- namespace { -Status SwitchShape(InferenceContext* c) { +absl::Status SwitchShape(InferenceContext* c) { ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); ShapeHandle out = c->input(0); @@ -43,7 +43,7 @@ Status SwitchShape(InferenceContext* c) { return absl::OkStatus(); } -Status SwitchNShape(InferenceContext* c) { +absl::Status SwitchNShape(InferenceContext* c) { ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); ShapeHandle out = c->input(0); @@ -121,7 +121,7 @@ REGISTER_OP("RefSelect") // -------------------------------------------------------------------------- namespace { -Status MergeShape(InferenceContext* c) { +absl::Status MergeShape(InferenceContext* c) { ShapeHandle out = c->input(0); if (!c->RankKnown(out)) { out = c->UnknownShape(); diff --git a/tensorflow/core/ops/control_flow_ops_test.cc b/tensorflow/core/ops/control_flow_ops_test.cc index bf6620704da5f6..7e4a78dd72c9ed 100644 --- a/tensorflow/core/ops/control_flow_ops_test.cc +++ b/tensorflow/core/ops/control_flow_ops_test.cc @@ -105,7 +105,7 @@ TEST(ControlFlowOpsTest, RefSelect_ShapeFn) { } // Runs type inference pass on graph -static Status type_inference(Graph& graph) { +static absl::Status type_inference(Graph& graph) { GraphOptimizationPassOptions opt_options; std::unique_ptr graph_ptr(new Graph(OpRegistry::Global())); graph_ptr->Copy(graph); diff --git a/tensorflow/core/ops/count_ops.cc b/tensorflow/core/ops/count_ops.cc index a63e7b59801a35..1ea1d376c1faff 100644 --- a/tensorflow/core/ops/count_ops.cc +++ b/tensorflow/core/ops/count_ops.cc @@ -22,7 +22,7 @@ namespace tensorflow { using shape_inference::InferenceContext; using shape_inference::ShapeHandle; -Status DenseCountSparseOutputShapeFn(InferenceContext *c) { +absl::Status DenseCountSparseOutputShapeFn(InferenceContext *c) { auto values = c->input(0); auto weights = c->input(1); ShapeHandle output; @@ -40,7 +40,7 @@ Status DenseCountSparseOutputShapeFn(InferenceContext *c) { return absl::OkStatus(); } -Status SparseCountSparseOutputShapeFn(InferenceContext *c) { +absl::Status SparseCountSparseOutputShapeFn(InferenceContext *c) { ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); auto rank = c->Dim(c->input(0), 1); @@ -51,7 +51,7 @@ Status SparseCountSparseOutputShapeFn(InferenceContext *c) { return absl::OkStatus(); } -Status RaggedCountSparseOutputShapeFn(InferenceContext *c) { +absl::Status RaggedCountSparseOutputShapeFn(InferenceContext *c) { int32_t rank = c->Rank(c->input(1)); if (rank != c->kUnknownRank) { ++rank; // Add the ragged dimension diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index 3bfc79146fe7bd..8329f3963d7258 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -26,7 +26,7 @@ using shape_inference::ShapeHandle; namespace { -Status DequeueManyV2Shape(InferenceContext* c, ShapeHandle n_shape) { +absl::Status DequeueManyV2Shape(InferenceContext* c, ShapeHandle n_shape) { auto* t = c->input_handle_shapes_and_types(0); if (t != nullptr && t->size() == c->num_outputs()) { for (int i = 0; i < c->num_outputs(); ++i) { @@ -88,7 +88,7 @@ REGISTER_OP("DynamicPartition") namespace { -Status DynamicStitchShapeFunction(InferenceContext* c) { +absl::Status DynamicStitchShapeFunction(InferenceContext* c) { int32_t num_partitions; TF_RETURN_IF_ERROR(c->GetAttr("N", &num_partitions)); @@ -158,7 +158,7 @@ REGISTER_OP("ParallelDynamicStitch") // -------------------------------------------------------------------------- namespace { -Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) { +absl::Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) { ShapeHandle handle; DimensionHandle unused_handle; for (int i = 0; i < c->num_inputs(); ++i) { @@ -171,7 +171,7 @@ Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) { return absl::OkStatus(); } -Status TwoElementOutput(InferenceContext* c) { +absl::Status TwoElementOutput(InferenceContext* c) { c->set_output(0, c->Vector(2)); return absl::OkStatus(); } diff --git a/tensorflow/core/ops/functional_grad.cc b/tensorflow/core/ops/functional_grad.cc index 0c8f7b1dd98612..4b42f8baee9e0d 100644 --- a/tensorflow/core/ops/functional_grad.cc +++ b/tensorflow/core/ops/functional_grad.cc @@ -21,7 +21,7 @@ namespace tensorflow { typedef FunctionDefHelper FDH; -Status MapAccumulateGrad(const AttrSlice& attrs, FunctionDef* ret) { +absl::Status MapAccumulateGrad(const AttrSlice& attrs, FunctionDef* ret) { const NameAttrList* func; TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "f", &func)); DataType T; diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc index ae3cba006ea9b0..aecdd6c254187d 100644 --- a/tensorflow/core/ops/functional_ops.cc +++ b/tensorflow/core/ops/functional_ops.cc @@ -90,7 +90,7 @@ else_branch: A function that takes 'inputs' and returns a list of tensors. whose types are the same as what then_branch returns. )doc"); -Status IfShapeInferenceFn(shape_inference::InferenceContext* c) { +absl::Status IfShapeInferenceFn(shape_inference::InferenceContext* c) { std::vector output_shapes; TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); // If `output_shapes` attr is set use that as the shapes of the outputs @@ -135,7 +135,7 @@ REGISTER_OP("If") .SetIsStateful() .SetShapeFn(IfShapeInferenceFn); -Status CaseShapeInferenceFn(shape_inference::InferenceContext* c) { +absl::Status CaseShapeInferenceFn(shape_inference::InferenceContext* c) { std::vector output_shapes; TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); // If `output_shapes` attr is set use that as the shapes of the outputs @@ -207,7 +207,7 @@ body: A function that takes a list of tensors and returns another by T. )doc"); -Status WhileShapeInferenceFn(shape_inference::InferenceContext* c) { +absl::Status WhileShapeInferenceFn(shape_inference::InferenceContext* c) { std::vector output_shapes; TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); // If `output_shapes` attr is set use that as the shapes of the outputs diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index 1cdbe485b03486..72ffa938e49834 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -29,8 +29,10 @@ namespace { // Sets output[0] to shape [batch_dim,height,width,channel_dim], where // height and width come from the size_tensor. -Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim, - int size_input_idx, DimensionHandle channel_dim) { +absl::Status SetOutputToSizedImage(InferenceContext* c, + DimensionHandle batch_dim, + int size_input_idx, + DimensionHandle channel_dim) { // Verify shape of size input. ShapeHandle size; TF_RETURN_IF_ERROR(c->WithRank(c->input(size_input_idx), 1, &size)); @@ -61,14 +63,14 @@ Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim, return absl::OkStatus(); } -Status ResizeShapeFn(InferenceContext* c) { +absl::Status ResizeShapeFn(InferenceContext* c) { ShapeHandle input; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); return SetOutputToSizedImage(c, c->Dim(input, 0), 1 /* size_input_idx */, c->Dim(input, 3)); } -Status DecodeImageShapeFn(InferenceContext* c) { +absl::Status DecodeImageShapeFn(InferenceContext* c) { ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); DimensionHandle channels_dim; @@ -89,7 +91,7 @@ Status DecodeImageShapeFn(InferenceContext* c) { return absl::OkStatus(); } -Status DecodeImageV2ShapeFn(InferenceContext* c) { +absl::Status DecodeImageV2ShapeFn(InferenceContext* c) { ShapeHandle unused; int32_t channels; bool expand_animations; @@ -123,7 +125,7 @@ Status DecodeImageV2ShapeFn(InferenceContext* c) { } } -Status EncodeImageShapeFn(InferenceContext* c) { +absl::Status EncodeImageShapeFn(InferenceContext* c) { ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &unused)); c->set_output(0, c->Scalar()); @@ -131,7 +133,7 @@ Status EncodeImageShapeFn(InferenceContext* c) { } // Allow encoding batches of images. -Status BatchedEncodeImageShapeFn(InferenceContext* c) { +absl::Status BatchedEncodeImageShapeFn(InferenceContext* c) { ShapeHandle input; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input)); ShapeHandle s; @@ -140,7 +142,7 @@ Status BatchedEncodeImageShapeFn(InferenceContext* c) { return absl::OkStatus(); } -Status ColorspaceShapeFn(InferenceContext* c) { +absl::Status ColorspaceShapeFn(InferenceContext* c) { ShapeHandle input; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input)); @@ -154,7 +156,7 @@ Status ColorspaceShapeFn(InferenceContext* c) { return absl::OkStatus(); } -Status NMSShapeFn(InferenceContext* c) { +absl::Status NMSShapeFn(InferenceContext* c) { // Get inputs and validate ranks. ShapeHandle boxes; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes)); @@ -177,7 +179,7 @@ Status NMSShapeFn(InferenceContext* c) { return absl::OkStatus(); } -Status SoftNMSShapeFn(InferenceContext* c) { +absl::Status SoftNMSShapeFn(InferenceContext* c) { // Get inputs and validate ranks. ShapeHandle boxes; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes)); @@ -203,7 +205,7 @@ Status SoftNMSShapeFn(InferenceContext* c) { return absl::OkStatus(); } -Status CombinedNMSShapeFn(InferenceContext* c) { +absl::Status CombinedNMSShapeFn(InferenceContext* c) { // Get inputs and validate ranks ShapeHandle boxes; // boxes is a tensor of Dimensions [batch_size, num_anchors, q, 4] @@ -1133,7 +1135,7 @@ REGISTER_OP("GenerateBoundingBoxProposals") .Output("rois: float") .Output("roi_probabilities: float") .Attr("post_nms_topn: int = 300") - .SetShapeFn([](InferenceContext* c) -> Status { + .SetShapeFn([](InferenceContext* c) -> absl::Status { // make sure input tensors have are correct rank ShapeHandle scores, images, bounding_boxes, anchors, nms_threshold, n_pre_nms, min_box_size; diff --git a/tensorflow/core/ops/io_ops.cc b/tensorflow/core/ops/io_ops.cc index a1d2c0879e4a55..d2540c02bc9f0a 100644 --- a/tensorflow/core/ops/io_ops.cc +++ b/tensorflow/core/ops/io_ops.cc @@ -26,7 +26,7 @@ using shape_inference::ShapeHandle; namespace { -Status ScalarInputsAndOutputs(InferenceContext* c) { +absl::Status ScalarInputsAndOutputs(InferenceContext* c) { ShapeHandle unused; for (int i = 0; i < c->num_inputs(); ++i) { TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused)); @@ -37,7 +37,7 @@ Status ScalarInputsAndOutputs(InferenceContext* c) { return absl::OkStatus(); } -Status TwoElementVectorAndScalarOutputs(InferenceContext* c) { +absl::Status TwoElementVectorAndScalarOutputs(InferenceContext* c) { ShapeHandle handle; DimensionHandle unused_handle; for (int i = 0; i < c->num_inputs(); ++i) { @@ -50,7 +50,7 @@ Status TwoElementVectorAndScalarOutputs(InferenceContext* c) { return absl::OkStatus(); } -Status TwoElementOutput(InferenceContext* c) { +absl::Status TwoElementOutput(InferenceContext* c) { c->set_output(0, c->Vector(2)); return absl::OkStatus(); } diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc index 209953f6f56394..3bc382816ff41f 100644 --- a/tensorflow/core/ops/linalg_ops.cc +++ b/tensorflow/core/ops/linalg_ops.cc @@ -26,8 +26,8 @@ using shape_inference::ShapeHandle; namespace { // Return in the result of making the end of a square matrix. -Status MakeBatchSquareMatrix(InferenceContext* c, ShapeHandle input, - ShapeHandle* out) { +absl::Status MakeBatchSquareMatrix(InferenceContext* c, ShapeHandle input, + ShapeHandle* out) { ShapeHandle s; TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 2, &s)); @@ -40,7 +40,7 @@ Status MakeBatchSquareMatrix(InferenceContext* c, ShapeHandle input, return absl::OkStatus(); } -Status BatchUnchangedSquareShapeFn(InferenceContext* c) { +absl::Status BatchUnchangedSquareShapeFn(InferenceContext* c) { ShapeHandle out; TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &out)); c->set_output(0, out); @@ -48,7 +48,7 @@ Status BatchUnchangedSquareShapeFn(InferenceContext* c) { } // The first input is [...,K,M] and second input is [...,M,N]. -Status BandedTriangularSolveShapeFn(InferenceContext* c) { +absl::Status BandedTriangularSolveShapeFn(InferenceContext* c) { ShapeHandle lhs; ShapeHandle rhs; @@ -92,7 +92,7 @@ Status BandedTriangularSolveShapeFn(InferenceContext* c) { // The first input is [...,M,N] and second input is either [...,M,K] or [...,M]. // Output is [...,N,K] or [...,N]. If , then input is [...,M,M]. -Status MatrixSolveShapeFn(InferenceContext* c, bool square) { +absl::Status MatrixSolveShapeFn(InferenceContext* c, bool square) { ShapeHandle lhs; ShapeHandle rhs; if (square) { @@ -129,7 +129,7 @@ Status MatrixSolveShapeFn(InferenceContext* c, bool square) { // The first input is [...,M,M] and second input is [...,M,N]. // Output is [...,M,N]. -Status MatrixTriangularSolveShapeFn(InferenceContext* c) { +absl::Status MatrixTriangularSolveShapeFn(InferenceContext* c) { ShapeHandle lhs; ShapeHandle rhs; TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &lhs)); @@ -158,7 +158,7 @@ Status MatrixTriangularSolveShapeFn(InferenceContext* c) { // Input is [...,N,N]. Outputs are: // [...,N];[0], if compute_v is false, // [...,N];[...,N,N], if compute_v is true. -Status SelfAdjointEigV2ShapeFn(InferenceContext* c) { +absl::Status SelfAdjointEigV2ShapeFn(InferenceContext* c) { ShapeHandle input; TF_RETURN_IF_ERROR(MakeBatchSquareMatrix(c, c->input(0), &input)); DimensionHandle n; @@ -183,7 +183,7 @@ Status SelfAdjointEigV2ShapeFn(InferenceContext* c) { // Input is [...,N,N]. // First and second outputs are: // [...,N,N]; [...,N]. -Status LuShapeFn(InferenceContext* c) { +absl::Status LuShapeFn(InferenceContext* c) { ShapeHandle input; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input)); @@ -209,7 +209,7 @@ Status LuShapeFn(InferenceContext* c) { // [...,M,M]; [...,M,N], if full_matrices is true, // [...,M,P]; [...,P,N], if full_matrices is false, // where P = min(M,N). -Status QrShapeFn(InferenceContext* c) { +absl::Status QrShapeFn(InferenceContext* c) { ShapeHandle input; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input)); DimensionHandle m = c->Dim(input, -2); @@ -240,7 +240,7 @@ Status QrShapeFn(InferenceContext* c) { // [...,M,M]; [...,N,N], if compute_uv is true and full_matrices is true, // [...,M,P]; [...,N,P], if compute_uv is true and full_matrices is false, // where P = min(M,N). -Status SvdShapeFn(InferenceContext* c) { +absl::Status SvdShapeFn(InferenceContext* c) { ShapeHandle input; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input)); DimensionHandle m = c->Dim(input, -2); @@ -281,7 +281,7 @@ Status SvdShapeFn(InferenceContext* c) { // Inputs: [...,1,M], [...,1,M], [...,1,M],[...,M,N]. // Output is [...,M,N]. -Status TridiagonalMatMulShapeFn(InferenceContext* c) { +absl::Status TridiagonalMatMulShapeFn(InferenceContext* c) { ShapeHandle superdiag; ShapeHandle maindiag; ShapeHandle subdiag; @@ -329,7 +329,7 @@ Status TridiagonalMatMulShapeFn(InferenceContext* c) { // The first input is [...,3,M] and second input is [...,M,K]. // Output is [...,M,K]. -Status TridiagonalSolveShapeFn(InferenceContext* c) { +absl::Status TridiagonalSolveShapeFn(InferenceContext* c) { ShapeHandle lhs; ShapeHandle rhs; // Check that rank is at least 2. diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc index 46466a7916a036..bd363d46e0460c 100644 --- a/tensorflow/core/ops/list_ops.cc +++ b/tensorflow/core/ops/list_ops.cc @@ -24,7 +24,7 @@ namespace { // Verifies that `shapes_and_types` is a valid list handle and has the right // dtype. -Status VerifyHandleData( +absl::Status VerifyHandleData( shape_inference::InferenceContext* c, const std::vector& shapes_and_types, DataType element_dtype) { @@ -262,7 +262,7 @@ REGISTER_OP("TensorListStack") return absl::OkStatus(); }); -Status TensorListConcatShapeInference( +absl::Status TensorListConcatShapeInference( shape_inference::InferenceContext* c, shape_inference::ShapeHandle element_shape) { DataType element_dtype; diff --git a/tensorflow/core/ops/lookup_ops.cc b/tensorflow/core/ops/lookup_ops.cc index 57109811641b20..d85cad19425192 100644 --- a/tensorflow/core/ops/lookup_ops.cc +++ b/tensorflow/core/ops/lookup_ops.cc @@ -29,7 +29,7 @@ using shape_inference::ShapeHandle; // -------------------------------------------------------------------------- namespace { -Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) { +absl::Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) { ShapeHandle handle; DimensionHandle unused_handle; for (int i = 0; i < c->num_inputs(); ++i) { @@ -42,7 +42,8 @@ Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) { return absl::OkStatus(); } -Status ScalarAndTwoElementVectorInputsAndScalarOutputs(InferenceContext* c) { +absl::Status ScalarAndTwoElementVectorInputsAndScalarOutputs( + InferenceContext* c) { ShapeHandle handle; DimensionHandle unused_handle; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle)); @@ -56,12 +57,12 @@ Status ScalarAndTwoElementVectorInputsAndScalarOutputs(InferenceContext* c) { return absl::OkStatus(); } -Status TwoElementOutput(InferenceContext* c) { +absl::Status TwoElementOutput(InferenceContext* c) { c->set_output(0, c->Vector(2)); return absl::OkStatus(); } -Status ScalarOutput(InferenceContext* c) { +absl::Status ScalarOutput(InferenceContext* c) { c->set_output(0, c->Scalar()); return absl::OkStatus(); } @@ -87,11 +88,11 @@ REGISTER_OP("LookupTableFind") return absl::OkStatus(); }); -Status ValidateTableType(InferenceContext* c, - const ShapeAndType& key_shape_and_type, - const string& key_dtype_attr, - const ShapeAndType& value_shape_and_type, - const string& value_dtype_attr) { +absl::Status ValidateTableType(InferenceContext* c, + const ShapeAndType& key_shape_and_type, + const string& key_dtype_attr, + const ShapeAndType& value_shape_and_type, + const string& value_dtype_attr) { DataType key_dtype; TF_RETURN_IF_ERROR(c->GetAttr(key_dtype_attr, &key_dtype)); if (key_shape_and_type.dtype != key_dtype) { @@ -113,10 +114,10 @@ Status ValidateTableType(InferenceContext* c, return absl::OkStatus(); } -Status ValidateTableResourceHandle(InferenceContext* c, ShapeHandle keys, - const string& key_dtype_attr, - const string& value_dtype_attr, - ShapeAndType* output_shape_and_type) { +absl::Status ValidateTableResourceHandle(InferenceContext* c, ShapeHandle keys, + const string& key_dtype_attr, + const string& value_dtype_attr, + ShapeAndType* output_shape_and_type) { auto* handle_data = c->input_handle_shapes_and_types(0); if (handle_data == nullptr || handle_data->size() != 2) { output_shape_and_type->shape = c->UnknownShape(); @@ -316,8 +317,8 @@ REGISTER_OP("LookupTableImportV2") return absl::OkStatus(); }); -Status MutableHashTableShape(InferenceContext* c, const ShapeHandle& key, - const ShapeHandle& value) { +absl::Status MutableHashTableShape(InferenceContext* c, const ShapeHandle& key, + const ShapeHandle& value) { c->set_output(0, c->Scalar()); ShapeHandle key_s; @@ -336,12 +337,12 @@ Status MutableHashTableShape(InferenceContext* c, const ShapeHandle& key, return absl::OkStatus(); } -Status MutableHashTableShapeFn(InferenceContext* c) { +absl::Status MutableHashTableShapeFn(InferenceContext* c) { return MutableHashTableShape(c, /*key=*/c->Scalar(), /*value=*/c->Scalar()); } -Status MutableHashTableOfTensorsShapeFn(InferenceContext* c) { +absl::Status MutableHashTableOfTensorsShapeFn(InferenceContext* c) { PartialTensorShape value_p; TF_RETURN_IF_ERROR(c->GetAttr("value_shape", &value_p)); ShapeHandle value_s; @@ -349,7 +350,7 @@ Status MutableHashTableOfTensorsShapeFn(InferenceContext* c) { return MutableHashTableShape(c, /*key=*/c->Scalar(), /*value=*/value_s); } -Status MutableDenseHashTableShapeFn(InferenceContext* c) { +absl::Status MutableDenseHashTableShapeFn(InferenceContext* c) { PartialTensorShape value_p; TF_RETURN_IF_ERROR(c->GetAttr("value_shape", &value_p)); ShapeHandle value_s; diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc index 647f33a7686c91..9494e676705048 100644 --- a/tensorflow/core/ops/math_grad.cc +++ b/tensorflow/core/ops/math_grad.cc @@ -146,7 +146,7 @@ REGISTER_OP_GRADIENT("_FusedMulAdd2", FusedMulAdd2Grad<0>); REGISTER_OP_GRADIENT("_FusedMulSub2", FusedMulAdd2Grad<1>); // Cwise binary ops -Status GradForUnaryCwise(FunctionDef* g, std::vector nodes) { +absl::Status GradForUnaryCwise(FunctionDef* g, std::vector nodes) { for (auto& n : nodes) { if (n.attr.empty()) { n.attr = {{"T", "$T"}}; @@ -164,7 +164,7 @@ Status GradForUnaryCwise(FunctionDef* g, std::vector nodes) { return absl::OkStatus(); } -Status AbsGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status AbsGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"sign"}, "Sign", {"x"}, {}, {"dy"}}, @@ -174,7 +174,7 @@ Status AbsGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Abs", AbsGrad); -Status NegGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status NegGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"dx"}, "Neg", {"dy"}}, @@ -183,7 +183,7 @@ Status NegGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Neg", NegGrad); -Status InvGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status InvGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"y"}, "Reciprocal", {"x"}}, @@ -196,7 +196,7 @@ Status InvGrad(const AttrSlice& attrs, FunctionDef* g) { REGISTER_OP_GRADIENT("Inv", InvGrad); REGISTER_OP_GRADIENT("Reciprocal", InvGrad); -Status SquareGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status SquareGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { FDH::Const("c", int64_t{2}), @@ -208,7 +208,7 @@ Status SquareGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Square", SquareGrad); -Status SqrtGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status SqrtGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"y"}, "Sqrt", {"x"}}, @@ -222,7 +222,7 @@ Status SqrtGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Sqrt", SqrtGrad); -Status RsqrtGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status RsqrtGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"x_inv"}, "Reciprocal", {"x"}, {}, {"dy"}}, @@ -237,7 +237,7 @@ Status RsqrtGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Rsqrt", RsqrtGrad); -Status ExpGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status ExpGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"y"}, "Exp", {"x"}}, @@ -247,7 +247,7 @@ Status ExpGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Exp", ExpGrad); -Status Expm1Grad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status Expm1Grad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"y"}, "Exp", {"x"}}, @@ -257,7 +257,7 @@ Status Expm1Grad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Expm1", Expm1Grad); -Status LogGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status LogGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"x_inv"}, "Reciprocal", {"x"}, {}, {"dy"}}, @@ -267,7 +267,7 @@ Status LogGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Log", LogGrad); -Status Log1pGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status Log1pGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { FDH::Const("const", 1.0f), @@ -279,7 +279,7 @@ Status Log1pGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Log1p", Log1pGrad); -Status SinhGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status SinhGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"cosh"}, "Cosh", {"x"}, {}, {"dy"}}, @@ -289,7 +289,7 @@ Status SinhGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Sinh", SinhGrad); -Status CoshGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status CoshGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"sinh"}, "Sinh", {"x"}, {}, {"dy"}}, @@ -299,7 +299,7 @@ Status CoshGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Cosh", CoshGrad); -Status TanhGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status TanhGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"y"}, "Tanh", {"x"}}, @@ -313,7 +313,7 @@ Status TanhGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Tanh", TanhGrad); -Status AsinhGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status AsinhGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"y"}, "Asinh", {"x"}}, @@ -324,7 +324,7 @@ Status AsinhGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Asinh", AsinhGrad); -Status AcoshGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status AcoshGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"y"}, "Acosh", {"x"}}, @@ -335,7 +335,7 @@ Status AcoshGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Acosh", AcoshGrad); -Status AtanhGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status AtanhGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"x2"}, "Square", {"x"}}, @@ -349,7 +349,7 @@ Status AtanhGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Atanh", AtanhGrad); -Status SigmoidGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status SigmoidGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"y"}, "Sigmoid", {"x"}}, @@ -363,7 +363,7 @@ Status SigmoidGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Sigmoid", SigmoidGrad); -Status SignGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status SignGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"s"}, "Shape", {"x"}}, @@ -375,7 +375,7 @@ Status SignGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Sign", SignGrad); -Status SinGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status SinGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"cos"}, "Cos", {"x"}, {}, {"dy"}}, @@ -385,7 +385,7 @@ Status SinGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Sin", SinGrad); -Status CosGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status CosGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"sin"}, "Sin", {"x"}, {}, {"dy"}}, @@ -396,7 +396,7 @@ Status CosGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Cos", CosGrad); -Status AcosGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status AcosGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"x2"}, "Square", {"x"}}, @@ -412,7 +412,7 @@ Status AcosGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Acos", AcosGrad); -Status AsinGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status AsinGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"x2"}, "Square", {"x"}}, @@ -427,7 +427,7 @@ Status AsinGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Asin", AsinGrad); -Status AtanGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status AtanGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"x2"}, "Square", {"x"}}, @@ -441,7 +441,7 @@ Status AtanGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Atan", AtanGrad); -Status TanGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status TanGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"cosx"}, "Cos", {"x"}}, @@ -453,7 +453,7 @@ Status TanGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Tan", TanGrad); -Status RealGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status RealGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { FDH::Const("zero", 0.f), @@ -463,7 +463,7 @@ Status RealGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Real", RealGrad); -Status ImagGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status ImagGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { FDH::Const("zero", 0.f), @@ -473,7 +473,7 @@ Status ImagGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Imag", ImagGrad); -Status AngleGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status AngleGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"re"}, "Real", {"x"}}, @@ -487,7 +487,7 @@ Status AngleGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Angle", AngleGrad); -Status ConjGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status ConjGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForUnaryCwise(g, { {{"dx"}, "Conj", {"dy"}}, @@ -496,7 +496,7 @@ Status ConjGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Conj", ConjGrad); -Status CastGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status CastGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FDH::Define( // Arg defs @@ -516,7 +516,7 @@ REGISTER_OP_GRADIENT("Cast", CastGrad); // // TODO(zhifengc): This can be arrange as a function in the standard // library. -Status GradForBinaryCwise(FunctionDef* g, std::vector body) { +absl::Status GradForBinaryCwise(FunctionDef* g, std::vector body) { // clang-format off std::vector nodes = { {{"sx"}, "Shape", {"x"}}, @@ -551,7 +551,7 @@ Status GradForBinaryCwise(FunctionDef* g, std::vector body) { return absl::OkStatus(); } -Status AddGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status AddGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForBinaryCwise(g, { {{"gx"}, "Identity", {"dz"}}, @@ -562,7 +562,7 @@ Status AddGrad(const AttrSlice& attrs, FunctionDef* g) { REGISTER_OP_GRADIENT("Add", AddGrad); REGISTER_OP_GRADIENT("AddV2", AddGrad); -Status SubGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status SubGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForBinaryCwise(g, { {{"gx"}, "Identity", {"dz"}}, @@ -572,7 +572,7 @@ Status SubGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Sub", SubGrad); -Status MulGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status MulGrad(const AttrSlice& attrs, FunctionDef* g) { DataType T; TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T)); if (T == DT_COMPLEX64 || T == DT_COMPLEX128) { @@ -594,7 +594,7 @@ Status MulGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Mul", MulGrad); -Status MulNoNanGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status MulNoNanGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForBinaryCwise(g, { {{"gx"}, "MulNoNan", {"y", "dz"}}, // y * dz @@ -604,7 +604,7 @@ Status MulNoNanGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("MulNoNan", MulGrad); -Status DivGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status DivGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForBinaryCwise(g, { {{"gx"}, "Div", {"dz", "y"}}, @@ -617,7 +617,7 @@ Status DivGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Div", DivGrad); -Status RealDivGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status RealDivGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForBinaryCwise(g, { {{"gx"}, "RealDiv", {"dz", "y"}}, @@ -630,7 +630,7 @@ Status RealDivGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("RealDiv", RealDivGrad); -Status DivNoNanGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status DivNoNanGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForBinaryCwise(g, { {{"gx"}, "DivNoNan", {"dz", "y"}}, @@ -643,7 +643,7 @@ Status DivNoNanGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("DivNoNan", DivNoNanGrad); -Status PowGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status PowGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off std::vector nodes = { {{"z"}, "Pow", {"x", "y"}}, @@ -684,7 +684,7 @@ Status PowGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Pow", PowGrad); -Status XlogyGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status XlogyGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForBinaryCwise(g, { {{"zeros"}, "ZerosLike", {"x"}}, @@ -700,7 +700,7 @@ Status XlogyGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Xlogy", XlogyGrad); -Status Xlog1pyGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status Xlog1pyGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForBinaryCwise(g, { FDH::Const("const", 1.0f), @@ -719,7 +719,7 @@ Status Xlog1pyGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Xlog1py", Xlog1pyGrad); -Status XdivyGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status XdivyGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForBinaryCwise(g, { {{"zeros"}, "ZerosLike", {"x"}}, @@ -737,7 +737,7 @@ Status XdivyGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Xdivy", XdivyGrad); -Status SquaredDifferenceGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status SquaredDifferenceGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForBinaryCwise(g, { FDH::Const("c", int64_t{2}), @@ -751,8 +751,8 @@ Status SquaredDifferenceGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("SquaredDifference", SquaredDifferenceGrad); -Status MaximumMinimumGradHelper(const string& comparator, - const AttrSlice& attrs, FunctionDef* g) { +absl::Status MaximumMinimumGradHelper(const string& comparator, + const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForBinaryCwise(g, { {{"c"}, comparator, {"x", "y"}, {}, {"dz"}}, @@ -763,17 +763,17 @@ Status MaximumMinimumGradHelper(const string& comparator, // clang-format on } -Status MaximumGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status MaximumGrad(const AttrSlice& attrs, FunctionDef* g) { return MaximumMinimumGradHelper("GreaterEqual", attrs, g); } REGISTER_OP_GRADIENT("Maximum", MaximumGrad); -Status MinimumGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status MinimumGrad(const AttrSlice& attrs, FunctionDef* g) { return MaximumMinimumGradHelper("LessEqual", attrs, g); } REGISTER_OP_GRADIENT("Minimum", MinimumGrad); -Status ComplexGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status ComplexGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForBinaryCwise(g, { {{"gx"}, "Real", {"dz"}}, @@ -784,7 +784,7 @@ Status ComplexGrad(const AttrSlice& attrs, FunctionDef* g) { REGISTER_OP_GRADIENT("Complex", ComplexGrad); // Cwise ternary ops. -Status SelectGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status SelectGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FDH::Define( {"c:bool", "x:T", "y:T", "dz:T"}, @@ -808,7 +808,7 @@ REGISTER_OP_GRADIENT("Select", SelectGrad); // // TODO(zhifengc): This helper is pretty ugly. Do something better. // TODO(zhifengc): This can be arrange as a function in the standard library. -Status GradForReductionOp(FunctionDef* g, std::vector body) { +absl::Status GradForReductionOp(FunctionDef* g, std::vector body) { // Shape manipulation nodes. // clang-format off @@ -855,7 +855,7 @@ Status GradForReductionOp(FunctionDef* g, std::vector body) { return absl::OkStatus(); } -Status SumGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status SumGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForReductionOp(g, { {{"dy_reshaped"}, "Reshape", {"dy", "y_shape:merged:0"}}, @@ -865,7 +865,7 @@ Status SumGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Sum", SumGrad); -Status MeanGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status MeanGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off return GradForReductionOp(g, { {{"factor"}, "Prod", {"tile_scaling:z:0", "zero:output:0"}, @@ -891,8 +891,8 @@ REGISTER_OP_GRADIENT("Mean", MeanGrad); // REGISTER_OP_GRADIENT("UnsortedSegmentSum", UnsortedSegmentSumGrad); // REGISTER_OP_GRADIENT("UnsortedSegmentMax", UnsortedSegmentMaxGrad); -Status MinMaxGradHelper(const string& op, const AttrSlice& attrs, - FunctionDef* g) { +absl::Status MinMaxGradHelper(const string& op, const AttrSlice& attrs, + FunctionDef* g) { // clang-format off *g = FDH::Define( // Arg defs @@ -918,22 +918,23 @@ Status MinMaxGradHelper(const string& op, const AttrSlice& attrs, return absl::OkStatus(); } -Status MaxGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status MaxGrad(const AttrSlice& attrs, FunctionDef* g) { return MinMaxGradHelper("Max", attrs, g); } REGISTER_OP_GRADIENT("Max", MaxGrad); -Status MinGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status MinGrad(const AttrSlice& attrs, FunctionDef* g) { return MinMaxGradHelper("Min", attrs, g); } REGISTER_OP_GRADIENT("Min", MinGrad); -static Status MatMulGradHelper(FunctionDef* g, const string& opname, - const string& attr_adj_x, - const string& attr_adj_y, const string& x0, - bool ax0, const string& x1, bool ax1, - const string& y0, bool ay0, const string& y1, - bool ay1, bool enable_broadcasting) { +static absl::Status MatMulGradHelper(FunctionDef* g, const string& opname, + const string& attr_adj_x, + const string& attr_adj_y, const string& x0, + bool ax0, const string& x1, bool ax1, + const string& y0, bool ay0, + const string& y1, bool ay1, + bool enable_broadcasting) { // The final outputs are "dx" and "dy". If we're broadcasting compute // intermediate nodes for now. std::vector nodes = { @@ -986,9 +987,9 @@ static Status MatMulGradHelper(FunctionDef* g, const string& opname, return absl::OkStatus(); } -Status MatMulGradCommon(const string& opname, const string& attr_adj_x, - const string& attr_adj_y, const AttrSlice& attrs, - FunctionDef* g, bool enable_broadcasting) { +absl::Status MatMulGradCommon(const string& opname, const string& attr_adj_x, + const string& attr_adj_y, const AttrSlice& attrs, + FunctionDef* g, bool enable_broadcasting) { DataType T; TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T)); if (T == DT_COMPLEX64 || T == DT_COMPLEX128) { @@ -1016,19 +1017,19 @@ Status MatMulGradCommon(const string& opname, const string& attr_adj_x, true, "dz", true, "x", true, enable_broadcasting); } -Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) { return MatMulGradCommon("MatMul", "transpose_a", "transpose_b", attrs, g, false /* enable_broadcasting */); } REGISTER_OP_GRADIENT("MatMul", MatMulGrad); -Status BatchMatMulGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status BatchMatMulGrad(const AttrSlice& attrs, FunctionDef* g) { return MatMulGradCommon("BatchMatMul", "adj_x", "adj_y", attrs, g, false /* enable_broadcasting */); } REGISTER_OP_GRADIENT("BatchMatMul", BatchMatMulGrad); -Status BatchMatMulV2Grad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status BatchMatMulV2Grad(const AttrSlice& attrs, FunctionDef* g) { return MatMulGradCommon("BatchMatMulV2", "adj_x", "adj_y", attrs, g, true /* enable_broadcasting */); } diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc index 27892348f74078..41bd7ac4a1820b 100644 --- a/tensorflow/core/ops/math_grad_test.cc +++ b/tensorflow/core/ops/math_grad_test.cc @@ -40,8 +40,8 @@ class MathGradTest : public ::testing::Test { protected: // Unary // dst is the output dtype of op_node. - Status Unary(const FDH::Node& op_node, const Tensor& x, const DataType dst, - Tensor* y) { + absl::Status Unary(const FDH::Node& op_node, const Tensor& x, + const DataType dst, Tensor* y) { const DataType src = x.dtype(); auto adef = [](const string& name, const DataType type) { // E.g., x:float, dy:double @@ -94,7 +94,7 @@ class MathGradTest : public ::testing::Test { return s; } - Status Unary(const string& op, const Tensor& x, Tensor* y) { + absl::Status Unary(const string& op, const Tensor& x, Tensor* y) { const FDH::Node op_node = {{"y"}, op, {"x"}, {{"T", x.dtype()}}}; return Unary(op_node, x, x.dtype(), y); } @@ -412,7 +412,7 @@ class MathGradTest : public ::testing::Test { } }; -void HasError(const Status& s, const string& substr) { +void HasError(const absl::Status& s, const string& substr) { EXPECT_TRUE(absl::StrContains(s.ToString(), substr)) << s << ", expected substring " << substr; } diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 98ebd5244e031d..1e263139ebc2c5 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1124,7 +1124,7 @@ REGISTER_OP("Max") namespace { -Status ArgOpShape(shape_inference::InferenceContext* c) { +absl::Status ArgOpShape(shape_inference::InferenceContext* c) { ShapeHandle dimension_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &dimension_shape)); @@ -1201,7 +1201,7 @@ REGISTER_OP("ArgMin") namespace { -Status SegmentReductionShapeFn(InferenceContext* c) { +absl::Status SegmentReductionShapeFn(InferenceContext* c) { ShapeHandle data_shape; ShapeHandle segment_ids_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape)); @@ -1217,7 +1217,7 @@ Status SegmentReductionShapeFn(InferenceContext* c) { return absl::OkStatus(); } -Status SparseSegmentReductionShapeFn(InferenceContext* c) { +absl::Status SparseSegmentReductionShapeFn(InferenceContext* c) { ShapeHandle data_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape)); @@ -1241,8 +1241,8 @@ Status SparseSegmentReductionShapeFn(InferenceContext* c) { return absl::OkStatus(); } -Status SparseSegmentReductionGradShapeFnImpl(InferenceContext* c, - bool outputs_unique_indices) { +absl::Status SparseSegmentReductionGradShapeFnImpl( + InferenceContext* c, bool outputs_unique_indices) { ShapeHandle data_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape)); @@ -1285,18 +1285,18 @@ Status SparseSegmentReductionGradShapeFnImpl(InferenceContext* c, return absl::OkStatus(); } -Status SparseSegmentReductionGradShapeFn(InferenceContext* c) { +absl::Status SparseSegmentReductionGradShapeFn(InferenceContext* c) { return SparseSegmentReductionGradShapeFnImpl( c, /*outputs_unique_indices=*/false); } -Status SparseSegmentReductionGradV2ShapeFn(InferenceContext* c) { +absl::Status SparseSegmentReductionGradV2ShapeFn(InferenceContext* c) { return SparseSegmentReductionGradShapeFnImpl(c, /*outputs_unique_indices=*/true); } -Status SparseSegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) { +absl::Status SparseSegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) { ShapeHandle data_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape)); @@ -1621,8 +1621,8 @@ REGISTER_OP("Any") namespace { template -Status RangeSize(const Tensor* start_t, const Tensor* limit_t, - const Tensor* delta_t, InferenceContext* const c) { +absl::Status RangeSize(const Tensor* start_t, const Tensor* limit_t, + const Tensor* delta_t, InferenceContext* const c) { T start = start_t->scalar()(); T limit = limit_t->scalar()(); T delta = delta_t->scalar()(); diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc index 86cc18ff5da49b..766eabfd77ff2b 100644 --- a/tensorflow/core/ops/math_ops_test.cc +++ b/tensorflow/core/ops/math_ops_test.cc @@ -226,14 +226,14 @@ TEST(MathOpsTest, Select_ShapeFn) { typedef std::vector> ShapeDtypeV; std::vector> handle_data; std::unique_ptr c; - auto run_inference_for_handles = [&]() -> Status { + auto run_inference_for_handles = [&]() -> absl::Status { CHECK(op_reg_data->shape_inference_fn != nullptr); c.reset(new shape_inference::InferenceContext( TF_GRAPH_DEF_VERSION, op.node_def, op_reg_data->op_def, {PartialTensorShape(), PartialTensorShape(), PartialTensorShape()}, {}, {}, handle_data)); TF_CHECK_OK(c->construction_status()); - Status s = c->Run(op_reg_data->shape_inference_fn); + absl::Status s = c->Run(op_reg_data->shape_inference_fn); LOG(INFO) << "Inference got " << s; return s; }; diff --git a/tensorflow/core/ops/nn_grad.cc b/tensorflow/core/ops/nn_grad.cc index c08821123bdbfd..f48c0e97ff9212 100644 --- a/tensorflow/core/ops/nn_grad.cc +++ b/tensorflow/core/ops/nn_grad.cc @@ -22,7 +22,7 @@ namespace tensorflow { typedef FunctionDefHelper FDH; -Status SoftmaxGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status SoftmaxGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FDH::Define( "SoftmaxGrad", @@ -47,7 +47,7 @@ Status SoftmaxGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Softmax", SoftmaxGrad); -Status LogSoftmaxGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status LogSoftmaxGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FDH::Define( "LogSoftmaxGrad", @@ -72,7 +72,7 @@ Status LogSoftmaxGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("LogSoftmax", LogSoftmaxGrad); -Status ReluGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status ReluGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FDH::Define( // Arg defs @@ -90,7 +90,7 @@ Status ReluGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Relu", ReluGrad); -Status Relu6Grad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status Relu6Grad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FDH::Define( // Arg defs @@ -108,7 +108,7 @@ Status Relu6Grad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Relu6", Relu6Grad); -Status CrossEntropyGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status CrossEntropyGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FDH::Define( // Arg defs @@ -137,6 +137,7 @@ Status CrossEntropyGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("CrossEntropy", CrossEntropyGrad); +<<<<<<< HEAD Status DropoutGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FDH::Define( @@ -157,6 +158,9 @@ Status DropoutGrad(const AttrSlice& attrs, FunctionDef* g) { REGISTER_OP_GRADIENT("Dropout", DropoutGrad); Status Conv2DGrad(const AttrSlice& attrs, FunctionDef* g) { +======= +absl::Status Conv2DGrad(const AttrSlice& attrs, FunctionDef* g) { +>>>>>>> master // clang-format off *g = FDH::Define( // Arg defs @@ -192,7 +196,7 @@ Status Conv2DGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Conv2D", Conv2DGrad); -Status MaxPoolGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status MaxPoolGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FDH::Define( // Arg defs @@ -223,7 +227,7 @@ Status MaxPoolGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("MaxPool", MaxPoolGrad); -Status AvgPoolGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status AvgPoolGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FDH::Define( // Arg defs @@ -249,7 +253,7 @@ Status AvgPoolGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("AvgPool", AvgPoolGrad); -Status MaxPoolGradGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status MaxPoolGradGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FDH::Define( // Arg defs @@ -280,7 +284,7 @@ Status MaxPoolGradGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("MaxPoolGrad", MaxPoolGradGrad); -Status BiasAddGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status BiasAddGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FDH::Define( // Arg defs diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 095f8ed0b331ee..3579417269595e 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -36,7 +36,7 @@ using shape_inference::ShapeHandle; namespace { -Status FractionalPoolShapeFn(InferenceContext* c) { +absl::Status FractionalPoolShapeFn(InferenceContext* c) { ShapeHandle input; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); @@ -543,7 +543,7 @@ create these operators. namespace { -Status CommonFusedConvCalculations(InferenceContext* c, bool has_resize) { +absl::Status CommonFusedConvCalculations(InferenceContext* c, bool has_resize) { ShapeHandle input; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); @@ -1432,7 +1432,7 @@ REGISTER_OP("InTopKV2") namespace { -Status TopKShapeFn(InferenceContext* c) { +absl::Status TopKShapeFn(InferenceContext* c) { ShapeHandle input; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input)); @@ -1478,7 +1478,7 @@ inline uint32_t log2_ceil(uint64_t value) { return value == 0 ? 0 : Log2Ceiling(value); } -Status ApproxTopKShape(shape_inference::InferenceContext* c) { +absl::Status ApproxTopKShape(shape_inference::InferenceContext* c) { int64_t k; int64_t reduction_dimension; float recall_target; @@ -1557,7 +1557,7 @@ Status ApproxTopKShape(shape_inference::InferenceContext* c) { c->set_output(1, output_shape); return absl::OkStatus(); } -// LINT.ThenChange(//tensorflow/compiler/xla/client/lib/approx_topk_shape.cc) +// LINT.ThenChange(//tensorflow/compiler/xla/hlo/builder/lib/approx_topk_shape.cc) } // namespace diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc index a3e801a87099c7..7975ba5ab43181 100644 --- a/tensorflow/core/ops/parsing_ops.cc +++ b/tensorflow/core/ops/parsing_ops.cc @@ -29,9 +29,9 @@ namespace { // Adds output shapes for dense tensors in Parse*Example ops. template // TensorShape or PartialTensorShape -Status AddDenseOutputShapes(const std::vector& dense_shapes, - const ShapeHandle& prefix, InferenceContext* c, - int* output_idx) { +absl::Status AddDenseOutputShapes( + const std::vector& dense_shapes, const ShapeHandle& prefix, + InferenceContext* c, int* output_idx) { for (const auto& dense_shape : dense_shapes) { ShapeHandle s; TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(dense_shape, &s)); @@ -62,9 +62,9 @@ void AddSparseOutputShapes(int num_sparse, const ShapeHandle input_shape, } // Adds output shapes for ragged tensors in Parse*Example ops. -Status AddRaggedOutputShapes(int num_ragged, bool ragged_rank_2, - const DimensionHandle& num_examples, - InferenceContext* c, int* output_idx) { +absl::Status AddRaggedOutputShapes(int num_ragged, bool ragged_rank_2, + const DimensionHandle& num_examples, + InferenceContext* c, int* output_idx) { DimensionHandle num_splits; TF_RETURN_IF_ERROR(c->Add(num_examples, 1, &num_splits)); // Values diff --git a/tensorflow/core/ops/ragged_array_ops.cc b/tensorflow/core/ops/ragged_array_ops.cc index b9538fc3567658..cb9896d77b7989 100644 --- a/tensorflow/core/ops/ragged_array_ops.cc +++ b/tensorflow/core/ops/ragged_array_ops.cc @@ -22,7 +22,7 @@ using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; -Status RaggedGatherShapeFn(InferenceContext* c); +absl::Status RaggedGatherShapeFn(InferenceContext* c); //============================================================================== // Registered Ops @@ -184,7 +184,7 @@ REGISTER_OP("RaggedFillEmptyRowsGrad") // Shape Functions //============================================================================== -Status RaggedGatherShapeFn(InferenceContext* c) { +absl::Status RaggedGatherShapeFn(InferenceContext* c) { int num_splits; int64_t PARAMS_RAGGED_RANK; TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/ops/ragged_conversion_ops.cc b/tensorflow/core/ops/ragged_conversion_ops.cc index de715a834fe210..3e28d57a6d970d 100644 --- a/tensorflow/core/ops/ragged_conversion_ops.cc +++ b/tensorflow/core/ops/ragged_conversion_ops.cc @@ -25,7 +25,7 @@ using shape_inference::InferenceContext; using shape_inference::ShapeHandle; namespace { -tensorflow::Status ValidateRowPartitionTypesAndShapes( +absl::Status ValidateRowPartitionTypesAndShapes( const std::vector& row_partition_types, InferenceContext* c) { // Note: the allowed types may be extended in the future. @@ -89,11 +89,11 @@ tensorflow::Status ValidateRowPartitionTypesAndShapes( } // namespace -Status RaggedTensorToSparseShapeFn(InferenceContext* c); -Status RaggedTensorToVariantShapeFn(InferenceContext* c); -Status RaggedTensorFromVariantShapeFn(InferenceContext* c); -Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c); -Status RaggedTensorToTensorShapeFn(InferenceContext* c); +absl::Status RaggedTensorToSparseShapeFn(InferenceContext* c); +absl::Status RaggedTensorToVariantShapeFn(InferenceContext* c); +absl::Status RaggedTensorFromVariantShapeFn(InferenceContext* c); +absl::Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c); +absl::Status RaggedTensorToTensorShapeFn(InferenceContext* c); //============================================================================== // Registered Ops @@ -157,7 +157,7 @@ REGISTER_OP("RaggedTensorToTensor") // Shape Functions //============================================================================== -Status RaggedTensorToSparseShapeFn(InferenceContext* c) { +absl::Status RaggedTensorToSparseShapeFn(InferenceContext* c) { int64_t num_splits; TF_RETURN_IF_ERROR(c->GetAttr("RAGGED_RANK", &num_splits)); // TODO(b/112274756): Allow ragged_rank to be 0. @@ -186,7 +186,7 @@ Status RaggedTensorToSparseShapeFn(InferenceContext* c) { return absl::OkStatus(); } -Status RaggedTensorToVariantShapeFn(InferenceContext* c) { +absl::Status RaggedTensorToVariantShapeFn(InferenceContext* c) { int64_t num_splits; TF_RETURN_IF_ERROR(c->GetAttr("RAGGED_RANK", &num_splits)); bool batched; @@ -208,7 +208,7 @@ Status RaggedTensorToVariantShapeFn(InferenceContext* c) { return absl::OkStatus(); } -Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c) { +absl::Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c) { ShapeHandle shape; TF_RETURN_IF_ERROR( c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(2, &shape)); @@ -216,7 +216,7 @@ Status RaggedTensorToVariantGradientShapeFn(InferenceContext* c) { return absl::OkStatus(); } -Status RaggedTensorFromVariantShapeFn(InferenceContext* c) { +absl::Status RaggedTensorFromVariantShapeFn(InferenceContext* c) { int64_t input_ragged_rank; TF_RETURN_IF_ERROR( c->GetAttr("input_ragged_rank", &input_ragged_rank)); @@ -236,7 +236,7 @@ Status RaggedTensorFromVariantShapeFn(InferenceContext* c) { return absl::OkStatus(); } -tensorflow::Status RaggedTensorToTensorShapeFn(InferenceContext* c) { +absl::Status RaggedTensorToTensorShapeFn(InferenceContext* c) { TensorShapeProto shape; { ShapeHandle shape_handle; diff --git a/tensorflow/core/ops/ragged_math_ops.cc b/tensorflow/core/ops/ragged_math_ops.cc index 08b76b1ed99862..bb927b5f2e0938 100644 --- a/tensorflow/core/ops/ragged_math_ops.cc +++ b/tensorflow/core/ops/ragged_math_ops.cc @@ -22,7 +22,7 @@ using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; -Status RaggedRangeShapeFn(InferenceContext* c); +absl::Status RaggedRangeShapeFn(InferenceContext* c); //============================================================================== // Registered Ops @@ -42,7 +42,7 @@ REGISTER_OP("RaggedRange") // Shape Functions //============================================================================== -Status RaggedRangeShapeFn(InferenceContext* c) { +absl::Status RaggedRangeShapeFn(InferenceContext* c) { // Check that all inputs (starts, limits, and deltas) have rank 0 or 1. ShapeHandle starts = c->input(0); ShapeHandle limits = c->input(1); diff --git a/tensorflow/core/ops/random_index_shuffle_ops.cc b/tensorflow/core/ops/random_index_shuffle_ops.cc index 8fa4381ed39df8..f4e0f651012b5e 100644 --- a/tensorflow/core/ops/random_index_shuffle_ops.cc +++ b/tensorflow/core/ops/random_index_shuffle_ops.cc @@ -27,7 +27,7 @@ using shape_inference::ShapeHandle; namespace { -static Status StatelessRandomPermuteShape(InferenceContext* c) { +static absl::Status StatelessRandomPermuteShape(InferenceContext* c) { ShapeHandle index_shape, seed_shape, max_index_shape, rounds_shape; // Basic constraints but unknown ranks will not raise errors here. diff --git a/tensorflow/core/ops/random_ops.cc b/tensorflow/core/ops/random_ops.cc index ff25e6765090c0..464849964dd0ca 100644 --- a/tensorflow/core/ops/random_ops.cc +++ b/tensorflow/core/ops/random_ops.cc @@ -45,7 +45,7 @@ REGISTER_OP("RandomUniformInt") .Attr("T: {int32, int64}") .SetShapeFn([](InferenceContext* c) { ShapeHandle unused; - Status s = c->WithRank(c->input(1), 0, &unused); + absl::Status s = c->WithRank(c->input(1), 0, &unused); if (!s.ok()) { return errors::InvalidArgument( "minval must be a scalar; got a tensor of shape ", diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc index 4f6d30458c7b0a..d999aa29e72541 100644 --- a/tensorflow/core/ops/resource_variable_ops.cc +++ b/tensorflow/core/ops/resource_variable_ops.cc @@ -31,14 +31,14 @@ namespace tensorflow { namespace { -Status ReadVariableShapeFn(InferenceContext* c) { +absl::Status ReadVariableShapeFn(InferenceContext* c) { // The user can add a "_shape" atribute to ReadVariableOp nodes. It is // useful for inferring shapes in a function, when no shape information // is passed about input resources. The user can annotate the graph using // the variable capture list of the function. // If the "_shape" attribute is found, it is used to set the output shape. PartialTensorShape p; - Status annotation_found_status = c->GetAttr("_shape", &p); + absl::Status annotation_found_status = c->GetAttr("_shape", &p); if (annotation_found_status.ok()) { ShapeHandle s; TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s)); @@ -58,7 +58,7 @@ Status ReadVariableShapeFn(InferenceContext* c) { return absl::OkStatus(); } -Status ReadVariablesShapeFn(InferenceContext* c) { +absl::Status ReadVariablesShapeFn(InferenceContext* c) { int n; TF_RETURN_IF_ERROR(c->GetAttr("N", &n)); DataTypeVector value_dtypes; @@ -160,7 +160,7 @@ REGISTER_OP("_ReadVariablesOp") .Attr("dtypes: list(type)") .SetShapeFn(ReadVariablesShapeFn); -Status ReadGrad(const AttrSlice& attrs, FunctionDef* g) { +absl::Status ReadGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FunctionDefHelper::Define( // Arg defs @@ -182,7 +182,7 @@ REGISTER_OP("DestroyResourceOp") .SetIsStateful() .SetShapeFn(shape_inference::NoOutputs); -Status CreateAssignShapeFn(InferenceContext* c) { +absl::Status CreateAssignShapeFn(InferenceContext* c) { std::vector handle_shape_and_type; TF_RETURN_IF_ERROR(shape_inference::ValidateVariableResourceHandle( c, &handle_shape_and_type)); @@ -232,7 +232,7 @@ REGISTER_OP("VarIsInitializedOp") .Output("is_initialized: bool") .SetShapeFn(tensorflow::shape_inference::ScalarShape); -Status VariableShapeShapeFn(InferenceContext* c) { +absl::Status VariableShapeShapeFn(InferenceContext* c) { auto* handle_data = c->input_handle_shapes_and_types(0); if (handle_data == nullptr || handle_data->empty()) { c->set_output(0, c->Vector(c->UnknownDim())); @@ -320,7 +320,7 @@ REGISTER_OP("ResourceGatherNd") namespace { -Status ResourceScatterUpdateShape(InferenceContext* c) { +absl::Status ResourceScatterUpdateShape(InferenceContext* c) { std::vector handle_shape_and_type; TF_RETURN_IF_ERROR(shape_inference::ValidateVariableResourceHandle( c, &handle_shape_and_type)); diff --git a/tensorflow/core/ops/risc_ops.cc b/tensorflow/core/ops/risc_ops.cc index 61df18ab80c0ea..4f703a8a640995 100644 --- a/tensorflow/core/ops/risc_ops.cc +++ b/tensorflow/core/ops/risc_ops.cc @@ -22,7 +22,8 @@ limitations under the License. namespace tensorflow { namespace { -Status RiscBinaryNonBroadcastOpShapeFn(shape_inference::InferenceContext* c) { +absl::Status RiscBinaryNonBroadcastOpShapeFn( + shape_inference::InferenceContext* c) { const auto rank = c->Rank(c->input(0)); if (rank != c->Rank(c->input(1))) { return errors::InvalidArgument("Mismatch rank for input."); diff --git a/tensorflow/core/ops/sdca_ops.cc b/tensorflow/core/ops/sdca_ops.cc index 57ba9b3f05abe5..0dffdd9f6cd574 100644 --- a/tensorflow/core/ops/sdca_ops.cc +++ b/tensorflow/core/ops/sdca_ops.cc @@ -23,7 +23,7 @@ using shape_inference::InferenceContext; using shape_inference::ShapeHandle; // -------------------------------------------------------------------------- -static Status ApplySdcaOptimizerShapeFn(InferenceContext* c) { +static absl::Status ApplySdcaOptimizerShapeFn(InferenceContext* c) { std::vector sparse_handles; if (c->input("sparse_weights", &sparse_handles).ok()) { TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/ops/sparse_csr_matrix_ops.cc b/tensorflow/core/ops/sparse_csr_matrix_ops.cc index 0a80e10f6a5313..4711b3183bfbe0 100644 --- a/tensorflow/core/ops/sparse_csr_matrix_ops.cc +++ b/tensorflow/core/ops/sparse_csr_matrix_ops.cc @@ -30,8 +30,8 @@ using shape_inference::InferenceContext; using shape_inference::ShapeAndType; using shape_inference::ShapeHandle; -Status GetVariantInput(InferenceContext* c, int index, - ShapeAndType* shape_and_type) { +absl::Status GetVariantInput(InferenceContext* c, int index, + ShapeAndType* shape_and_type) { ShapeHandle variant; TF_RETURN_IF_ERROR(c->WithRank(c->input(index), 0, &variant)); auto* shapes_and_types = c->input_handle_shapes_and_types(index); @@ -45,9 +45,9 @@ Status GetVariantInput(InferenceContext* c, int index, // Validates that a shape represents a (rank-2) square matrix or a (rank-3) // batch of square matrices. -Status ValidateSquareMatrixShape(InferenceContext* c, - const ShapeHandle& matrix_shape, - DimensionHandle* matrix_dimension) { +absl::Status ValidateSquareMatrixShape(InferenceContext* c, + const ShapeHandle& matrix_shape, + DimensionHandle* matrix_dimension) { ShapeHandle out; TF_RETURN_IF_ERROR(c->WithRankAtLeast(matrix_shape, 2, &out)); TF_RETURN_IF_ERROR(c->WithRankAtMost(matrix_shape, 3, &out)); diff --git a/tensorflow/core/ops/sparse_ops.cc b/tensorflow/core/ops/sparse_ops.cc index 77f9f23c067246..0110c3b2a31061 100644 --- a/tensorflow/core/ops/sparse_ops.cc +++ b/tensorflow/core/ops/sparse_ops.cc @@ -27,7 +27,7 @@ using shape_inference::ShapeHandle; namespace { -Status SparseSparseMinOrMaxShapeFn(InferenceContext* c) { +absl::Status SparseSparseMinOrMaxShapeFn(InferenceContext* c) { ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); // a_indices TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); // a_values diff --git a/tensorflow/core/ops/spectral_ops.cc b/tensorflow/core/ops/spectral_ops.cc index 57782f2abf81d6..f1f4a599eee663 100644 --- a/tensorflow/core/ops/spectral_ops.cc +++ b/tensorflow/core/ops/spectral_ops.cc @@ -92,7 +92,8 @@ REGISTER_OP("IFFTND") return shape_inference::UnchangedShapeWithRankAtLeast(c, 1); }); -Status RFFTShape(InferenceContext* c, const bool forward, const int rank) { +absl::Status RFFTShape(InferenceContext* c, const bool forward, + const int rank) { ShapeHandle out; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), rank, &out)); diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc index 15eae501f1d10e..985570ba1f16e7 100644 --- a/tensorflow/core/ops/state_ops.cc +++ b/tensorflow/core/ops/state_ops.cc @@ -113,7 +113,7 @@ REGISTER_OP("AssignSub") namespace { -Status ScatterUpdateShape(InferenceContext* c) { +absl::Status ScatterUpdateShape(InferenceContext* c) { ShapeHandle var_shape = c->input(0); ShapeHandle indices_shape = c->input(1); @@ -131,7 +131,7 @@ Status ScatterUpdateShape(InferenceContext* c) { return absl::OkStatus(); } -Status ScatterNdUpdateShape(InferenceContext* c) { +absl::Status ScatterNdUpdateShape(InferenceContext* c) { ShapeHandle input_shape = c->input(0); if (c->input_handle_shapes_and_types(0) != nullptr) { const auto& shape_and_type = *(c->input_handle_shapes_and_types(0)); diff --git a/tensorflow/core/ops/stateful_random_ops.cc b/tensorflow/core/ops/stateful_random_ops.cc index 1728143b6226f7..343280cc048cf9 100644 --- a/tensorflow/core/ops/stateful_random_ops.cc +++ b/tensorflow/core/ops/stateful_random_ops.cc @@ -20,7 +20,7 @@ limitations under the License. namespace tensorflow { -Status StatefulRandomShape(shape_inference::InferenceContext* c) { +absl::Status StatefulRandomShape(shape_inference::InferenceContext* c) { using shape_inference::ShapeHandle; // Check algorithm shape ShapeHandle unused; @@ -61,7 +61,7 @@ REGISTER_OP("StatefulUniformInt") // Check inputs ShapeHandle unused; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); - Status s = c->WithRank(c->input(3), 0, &unused); + absl::Status s = c->WithRank(c->input(3), 0, &unused); if (!s.ok()) { return errors::InvalidArgument( "minval must be a scalar; got a tensor of shape ", diff --git a/tensorflow/core/ops/stateless_random_ops.cc b/tensorflow/core/ops/stateless_random_ops.cc index ca4fa4eae9752e..f9b13b3035c6c4 100644 --- a/tensorflow/core/ops/stateless_random_ops.cc +++ b/tensorflow/core/ops/stateless_random_ops.cc @@ -22,7 +22,7 @@ using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; -static Status StatelessShape(InferenceContext* c) { +static absl::Status StatelessShape(InferenceContext* c) { // Check seed shape ShapeHandle seed; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seed)); @@ -63,7 +63,7 @@ REGISTER_OP("StatelessRandomUniformInt") .Attr("Tseed: {int32, int64} = DT_INT64") .SetShapeFn([](InferenceContext* c) { ShapeHandle unused; - Status s = c->WithRank(c->input(2), 0, &unused); + absl::Status s = c->WithRank(c->input(2), 0, &unused); if (!s.ok()) { return errors::InvalidArgument( "minval must be a scalar; got a tensor of shape ", diff --git a/tensorflow/core/ops/stateless_random_ops_v2.cc b/tensorflow/core/ops/stateless_random_ops_v2.cc index f9d2e6abe67cae..3769a8ae8f6a14 100644 --- a/tensorflow/core/ops/stateless_random_ops_v2.cc +++ b/tensorflow/core/ops/stateless_random_ops_v2.cc @@ -23,7 +23,7 @@ using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; -static Status StatelessShapeV2(InferenceContext* c) { +static absl::Status StatelessShapeV2(InferenceContext* c) { // Check key and counter shapes ShapeHandle key; ShapeHandle counter; @@ -70,7 +70,7 @@ REGISTER_OP("StatelessRandomUniformIntV2") .Attr("Tshape: {int32, int64} = DT_INT32") .SetShapeFn([](InferenceContext* c) { ShapeHandle unused; - Status s = c->WithRank(c->input(4), 0, &unused); + absl::Status s = c->WithRank(c->input(4), 0, &unused); if (!s.ok()) { return errors::InvalidArgument( "minval must be a scalar; got a tensor of shape ", diff --git a/tensorflow/core/ops/tpu_embedding_ops.cc b/tensorflow/core/ops/tpu_embedding_ops.cc index 4fe61a80451164..2ee7791cb2ed20 100644 --- a/tensorflow/core/ops/tpu_embedding_ops.cc +++ b/tensorflow/core/ops/tpu_embedding_ops.cc @@ -62,7 +62,7 @@ REGISTER_OP("RecvTPUEmbeddingActivations") .Attr("num_outputs: int >= 1") .Attr("config: string") .SetIsStateful() - .SetShapeFn([](shape_inference::InferenceContext* c) -> Status { + .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { std::string config_string; TF_RETURN_IF_ERROR(c->GetAttr("config", &config_string)); tpu::TPUEmbeddingConfiguration config; @@ -101,7 +101,7 @@ REGISTER_OP("SendTPUEmbeddingGradients") .Attr("NN: int >= 0 = 0") .Attr("config: string") .SetIsStateful() - .SetShapeFn([](shape_inference::InferenceContext* c) -> Status { + .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { int nn; TF_RETURN_IF_ERROR(c->GetAttr("NN", &nn)); std::vector learning_rates; @@ -136,7 +136,7 @@ REGISTER_OP("EnqueueTPUEmbeddingSparseBatch") .Attr("device_ordinal: int = -1") .Attr("combiners: list(string) = []") .SetIsStateful() - .SetShapeFn([](shape_inference::InferenceContext* c) -> Status { + .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { std::vector combiners; TF_RETURN_IF_ERROR(c->GetAttr("combiners", &combiners)); int n; diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc index c250df3eba3ac8..5f48cc42197e62 100644 --- a/tensorflow/core/ops/training_ops.cc +++ b/tensorflow/core/ops/training_ops.cc @@ -49,8 +49,8 @@ ShapeHandle ShapeOrHandleShape(InferenceContext* c, int input) { // is an input+output parameter, containing the current known input shape to // the gradient. template -static Status HandleGradAndIndicesInputs(InferenceContext* c, int grad_idx, - ShapeHandle* s) { +static absl::Status HandleGradAndIndicesInputs(InferenceContext* c, + int grad_idx, ShapeHandle* s) { ShapeHandle grad = ShapeOrHandleShape(c, grad_idx); if (!is_sparse) { TF_RETURN_IF_ERROR(c->Merge(*s, grad, s)); @@ -76,7 +76,7 @@ static Status HandleGradAndIndicesInputs(InferenceContext* c, int grad_idx, } template -static Status ApplyGradientDescentShapeFn(InferenceContext* c) { +static absl::Status ApplyGradientDescentShapeFn(InferenceContext* c) { ShapeHandle unused; ShapeHandle s = ShapeOrHandleShape(c, 0); // var TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); // alpha @@ -105,7 +105,7 @@ REGISTER_OP("ResourceApplyGradientDescent") .SetShapeFn(ApplyGradientDescentShapeFn); template -Status ApplyProximalGradientDescentShapeFn(InferenceContext* c) { +absl::Status ApplyProximalGradientDescentShapeFn(InferenceContext* c) { ShapeHandle unused; ShapeHandle s = ShapeOrHandleShape(c, 0); // var TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); // alpha @@ -170,7 +170,7 @@ REGISTER_OP("ResourceSparseApplyProximalGradientDescent") /*is_resource=*/true>); template -static Status ApplyAdadeltaShapeFn(InferenceContext* c) { +static absl::Status ApplyAdadeltaShapeFn(InferenceContext* c) { ShapeHandle unused; ShapeHandle s = ShapeOrHandleShape(c, 0); // var TF_RETURN_IF_ERROR( @@ -246,7 +246,7 @@ REGISTER_OP("ResourceSparseApplyAdadelta") .SetShapeFn(ApplyAdadeltaShapeFn); template -static Status ApplyAdagradShapeFn(InferenceContext* c) { +static absl::Status ApplyAdagradShapeFn(InferenceContext* c) { ShapeHandle unused; ShapeHandle s = ShapeOrHandleShape(c, 0); // var TF_RETURN_IF_ERROR( @@ -308,7 +308,7 @@ REGISTER_OP("ResourceSparseApplyAdagrad") .SetShapeFn(ApplyAdagradShapeFn); template -static Status ApplyAdagradV2ShapeFn(InferenceContext* c) { +static absl::Status ApplyAdagradV2ShapeFn(InferenceContext* c) { ShapeHandle unused; ShapeHandle s = ShapeOrHandleShape(c, 0); // var TF_RETURN_IF_ERROR( @@ -378,7 +378,7 @@ REGISTER_OP("ResourceSparseApplyAdagradV2") ApplyAdagradV2ShapeFn); template -static Status ApplyProximalAdagradShapeFn(InferenceContext* c) { +static absl::Status ApplyProximalAdagradShapeFn(InferenceContext* c) { ShapeHandle unused; ShapeHandle s = ShapeOrHandleShape(c, 0); // var TF_RETURN_IF_ERROR( @@ -449,7 +449,7 @@ REGISTER_OP("ResourceSparseApplyProximalAdagrad") ApplyProximalAdagradShapeFn); template -static Status ApplyAdagradDAShapeFn(InferenceContext* c) { +static absl::Status ApplyAdagradDAShapeFn(InferenceContext* c) { ShapeHandle unused; ShapeHandle s = ShapeOrHandleShape(c, 0); // var TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), @@ -532,7 +532,7 @@ REGISTER_OP("ResourceSparseApplyAdagradDA") ApplyAdagradDAShapeFn); template -static Status ApplyFtrlShapeFn(InferenceContext* c) { +static absl::Status ApplyFtrlShapeFn(InferenceContext* c) { ShapeHandle unused; ShapeHandle s = ShapeOrHandleShape(c, 0); // var TF_RETURN_IF_ERROR( @@ -681,7 +681,7 @@ REGISTER_OP("ResourceSparseApplyFtrlV2") .SetShapeFn(ApplyFtrlShapeFn); template -static Status ApplyMomentumShapeFn(InferenceContext* c) { +static absl::Status ApplyMomentumShapeFn(InferenceContext* c) { ShapeHandle unused; ShapeHandle s = ShapeOrHandleShape(c, 0); // var TF_RETURN_IF_ERROR( @@ -776,7 +776,7 @@ REGISTER_OP("ResourceSparseApplyKerasMomentum") .SetShapeFn(ApplyMomentumShapeFn); template -static Status ApplyAdamShapeFn(InferenceContext* c) { +static absl::Status ApplyAdamShapeFn(InferenceContext* c) { ShapeHandle unused; ShapeHandle s = ShapeOrHandleShape(c, 0); // var TF_RETURN_IF_ERROR( @@ -832,7 +832,7 @@ REGISTER_OP("ResourceApplyAdam") .SetShapeFn(ApplyAdamShapeFn); template -static Status ApplyAdamWithAmsgradShapeFn(InferenceContext* c) { +static absl::Status ApplyAdamWithAmsgradShapeFn(InferenceContext* c) { ShapeHandle unused; ShapeHandle s = ShapeOrHandleShape(c, 0); // var TF_RETURN_IF_ERROR( @@ -873,7 +873,7 @@ REGISTER_OP("ResourceApplyAdamWithAmsgrad") .SetShapeFn(ApplyAdamWithAmsgradShapeFn); template -static Status ApplyAdaMaxShapeFn(InferenceContext* c) { +static absl::Status ApplyAdaMaxShapeFn(InferenceContext* c) { ShapeHandle unused; ShapeHandle s = ShapeOrHandleShape(c, 0); // var TF_RETURN_IF_ERROR( @@ -924,7 +924,7 @@ REGISTER_OP("ResourceApplyAdaMax") .SetShapeFn(ApplyAdaMaxShapeFn); template -static Status ApplyRMSPropShapeFn(InferenceContext* c) { +static absl::Status ApplyRMSPropShapeFn(InferenceContext* c) { ShapeHandle unused; ShapeHandle s = ShapeOrHandleShape(c, 0); // var TF_RETURN_IF_ERROR( @@ -1003,7 +1003,7 @@ REGISTER_OP("ResourceSparseApplyRMSProp") .SetShapeFn(ApplyRMSPropShapeFn); template -static Status ApplyCenteredRMSPropShapeFn(InferenceContext* c) { +static absl::Status ApplyCenteredRMSPropShapeFn(InferenceContext* c) { ShapeHandle unused; ShapeHandle s = ShapeOrHandleShape(c, 0); // var TF_RETURN_IF_ERROR( @@ -1091,7 +1091,7 @@ REGISTER_OP("ResourceSparseApplyCenteredRMSProp") ApplyCenteredRMSPropShapeFn); template -static Status ApplyAddSignShapeFn(InferenceContext* c) { +static absl::Status ApplyAddSignShapeFn(InferenceContext* c) { ShapeHandle unused; ShapeHandle s = ShapeOrHandleShape(c, 0); // var TF_RETURN_IF_ERROR( @@ -1135,7 +1135,7 @@ REGISTER_OP("ResourceApplyAddSign") .SetShapeFn(ApplyAddSignShapeFn); template -static Status ApplyPowerSignShapeFn(InferenceContext* c) { +static absl::Status ApplyPowerSignShapeFn(InferenceContext* c) { ShapeHandle unused; ShapeHandle s = ShapeOrHandleShape(c, 0); // var TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/ops/uniform_quant_ops.cc b/tensorflow/core/ops/uniform_quant_ops.cc index c5fcb762dabd13..b5f0baccc78a8c 100644 --- a/tensorflow/core/ops/uniform_quant_ops.cc +++ b/tensorflow/core/ops/uniform_quant_ops.cc @@ -43,9 +43,10 @@ absl::StatusOr ToTensorShape(ShapeHandle shape_handle, return shape; } -Status ScalesZeroPointsShapeValid(shape_inference::InferenceContext* context, - DimensionHandle match_dimension_handle, - ShapeHandle scales, ShapeHandle zero_points) { +absl::Status ScalesZeroPointsShapeValid( + shape_inference::InferenceContext* context, + DimensionHandle match_dimension_handle, ShapeHandle scales, + ShapeHandle zero_points) { const int32_t scales_rank = shape_inference::InferenceContext::Rank(scales); const int32_t zero_points_rank = shape_inference::InferenceContext::Rank(zero_points); @@ -72,7 +73,7 @@ Status ScalesZeroPointsShapeValid(shape_inference::InferenceContext* context, return absl::OkStatus(); } -Status DotShape(shape_inference::InferenceContext* context) { +absl::Status DotShape(shape_inference::InferenceContext* context) { ShapeHandle lhs; TF_RETURN_IF_ERROR(context->WithRank(context->input(0), 2, &lhs)); ShapeHandle rhs; @@ -115,7 +116,7 @@ Status DotShape(shape_inference::InferenceContext* context) { return absl::OkStatus(); } -Status DotHybridShape(shape_inference::InferenceContext* context) { +absl::Status DotHybridShape(shape_inference::InferenceContext* context) { ShapeHandle lhs; TF_RETURN_IF_ERROR(context->WithRank(context->input(0), 2, &lhs)); ShapeHandle rhs; @@ -177,8 +178,8 @@ struct ShapeCommonParams { is_output_scales_zero_points_set(false) {} }; -Status ConvolutionShapeCommon(shape_inference::InferenceContext* context, - const ShapeCommonParams& params) { +absl::Status ConvolutionShapeCommon(shape_inference::InferenceContext* context, + const ShapeCommonParams& params) { const int32_t lhs_rank = shape_inference::InferenceContext::Rank(params.lhs); const int32_t rhs_rank = shape_inference::InferenceContext::Rank(params.rhs); @@ -237,7 +238,7 @@ Status ConvolutionShapeCommon(shape_inference::InferenceContext* context, return absl::OkStatus(); } -Status ConvolutionShape(shape_inference::InferenceContext* context) { +absl::Status ConvolutionShape(shape_inference::InferenceContext* context) { ShapeHandle lhs; TF_RETURN_IF_ERROR(context->WithRankAtLeast(context->input(0), 2, &lhs)); ShapeHandle rhs; @@ -268,7 +269,8 @@ Status ConvolutionShape(shape_inference::InferenceContext* context) { rhs_zero_points, output_scales, output_zero_points)); } -Status ConvolutionHybridShape(shape_inference::InferenceContext* context) { +absl::Status ConvolutionHybridShape( + shape_inference::InferenceContext* context) { ShapeHandle lhs; TF_RETURN_IF_ERROR(context->WithRankAtLeast(context->input(0), 2, &lhs)); ShapeHandle rhs; diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD index 0d23a089d6dbf9..a0b34533191e09 100644 --- a/tensorflow/core/platform/BUILD +++ b/tensorflow/core/platform/BUILD @@ -237,7 +237,6 @@ cc_library( compatible_with = [], deps = [ ":platform", - "@local_xla//xla/stream_executor/gpu:scoped_activate_context", ], ) @@ -685,7 +684,6 @@ cc_library( hdrs = if_rocm_is_configured(["rocm.h"]), deps = if_rocm_is_configured([ ":platform", - "@local_xla//xla/stream_executor/gpu:scoped_activate_context", ]), ) @@ -753,6 +751,8 @@ cc_library( ":strcat", ":stringprintf", ":types", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", "@local_tsl//tsl/platform:status", ], ) @@ -850,6 +850,7 @@ cc_library( hdrs = ["tensor_coding.h"], deps = [ ":coding", + ":logging", ":platform", ":protobuf", ":refcount", @@ -1084,13 +1085,12 @@ tf_cuda_library( deps = [ "//tensorflow/core:lib", "@local_tsl//tsl/platform", + "@local_tsl//tsl/platform:dso_loader", "@local_xla//xla/stream_executor", "@local_xla//xla/stream_executor:dnn", "@local_xla//xla/stream_executor:platform_manager", "@local_xla//xla/stream_executor/cuda:cuda_platform_id", - "@local_xla//xla/stream_executor/gpu:scoped_activate_context", "@local_xla//xla/stream_executor/host:host_platform_id", - "@local_xla//xla/stream_executor/platform:dso_loader", "@local_xla//xla/stream_executor/rocm:rocm_platform_id", ] + if_rocm_is_configured([ "@local_xla//xla/stream_executor/rocm:miopen_plugin", @@ -1107,13 +1107,13 @@ cc_library( ], features = ["-parse_headers"], deps = [ + "@local_tsl//tsl/platform:dso_loader", "@local_xla//xla/stream_executor", "@local_xla//xla/stream_executor:dnn", "@local_xla//xla/stream_executor:platform_manager", "@local_xla//xla/stream_executor/cuda:cuda_platform_id", "@local_xla//xla/stream_executor/host:host_platform", "@local_xla//xla/stream_executor/host:host_platform_id", - "@local_xla//xla/stream_executor/platform:dso_loader", "@local_xla//xla/stream_executor/rocm:rocm_platform_id", ], ) @@ -1268,7 +1268,8 @@ tf_cc_test( tf_cc_test( name = "fake_python_env_test", - size = "medium", + # Test size is marked as large because it showed a big runtime init overhead in build tests. + size = "large", srcs = ["fake_python_env_test.cc"], args = [ "/some/path/to/pythontest.runfiles/org_tensorflow/stuff/to/run.py", diff --git a/tensorflow/core/platform/build_config.bzl b/tensorflow/core/platform/build_config.bzl index de0453b6deac98..dd10f841a5235c 100644 --- a/tensorflow/core/platform/build_config.bzl +++ b/tensorflow/core/platform/build_config.bzl @@ -11,7 +11,7 @@ load( _tf_additional_rpc_deps = "tf_additional_rpc_deps", _tf_additional_tensor_coding_deps = "tf_additional_tensor_coding_deps", _tf_additional_test_deps = "tf_additional_test_deps", - _tf_cuda_libdevice_path_deps = "tf_cuda_libdevice_path_deps", + _tf_cuda_root_path_deps = "tf_cuda_root_path_deps", _tf_fingerprint_deps = "tf_fingerprint_deps", _tf_google_mobile_srcs_no_runtime = "tf_google_mobile_srcs_no_runtime", _tf_google_mobile_srcs_only_runtime = "tf_google_mobile_srcs_only_runtime", @@ -53,7 +53,7 @@ tf_additional_lib_hdrs = _tf_additional_lib_hdrs tf_additional_rpc_deps = _tf_additional_rpc_deps tf_additional_tensor_coding_deps = _tf_additional_tensor_coding_deps tf_additional_test_deps = _tf_additional_test_deps -tf_cuda_libdevice_path_deps = _tf_cuda_libdevice_path_deps +tf_cuda_root_path_deps = _tf_cuda_root_path_deps tf_fingerprint_deps = _tf_fingerprint_deps tf_google_mobile_srcs_no_runtime = _tf_google_mobile_srcs_no_runtime tf_google_mobile_srcs_only_runtime = _tf_google_mobile_srcs_only_runtime diff --git a/tensorflow/core/platform/build_config.default.bzl b/tensorflow/core/platform/build_config.default.bzl index b39d6b913740e5..4eeea4882183f3 100644 --- a/tensorflow/core/platform/build_config.default.bzl +++ b/tensorflow/core/platform/build_config.default.bzl @@ -45,7 +45,7 @@ def tf_protos_all(): Label("//tensorflow/core:protos_all_cc_impl"), "@local_xla//xla:autotune_results_proto_cc_impl", "@local_xla//xla:autotuning_proto_cc_impl", - "@local_tsl//tsl/protobuf:protos_all_cc_impl", + "@local_xla//xla/tsl/protobuf:protos_all_cc_impl", ], otherwise = [Label("//tensorflow/core:protos_all_cc")], ) diff --git a/tensorflow/core/platform/build_config_root.bzl b/tensorflow/core/platform/build_config_root.bzl index 76fb425ba6f313..744c9eada0d712 100644 --- a/tensorflow/core/platform/build_config_root.bzl +++ b/tensorflow/core/platform/build_config_root.bzl @@ -9,6 +9,7 @@ load( _if_llvm_powerpc_available = "if_llvm_powerpc_available", _if_llvm_system_z_available = "if_llvm_system_z_available", _if_llvm_x86_available = "if_llvm_x86_available", + _if_pywrap = "if_pywrap", _if_static = "if_static", _if_static_and_not_mobile = "if_static_and_not_mobile", _tf_additional_grpc_deps_py = "tf_additional_grpc_deps_py", @@ -35,6 +36,7 @@ if_llvm_system_z_available = _if_llvm_system_z_available if_llvm_x86_available = _if_llvm_x86_available if_dynamic_kernels = _if_dynamic_kernels if_static = _if_static +if_pywrap = _if_pywrap if_static_and_not_mobile = _if_static_and_not_mobile tf_additional_grpc_deps_py = _tf_additional_grpc_deps_py tf_additional_license_deps = _tf_additional_license_deps diff --git a/tensorflow/core/platform/build_config_root.default.bzl b/tensorflow/core/platform/build_config_root.default.bzl index b503d99729b04d..bcea4eebe6d64b 100644 --- a/tensorflow/core/platform/build_config_root.default.bzl +++ b/tensorflow/core/platform/build_config_root.default.bzl @@ -1,5 +1,7 @@ """TODO(jakeharmon): Write module docstring.""" +load("@local_tsl//third_party/py/rules_pywrap:pywrap.bzl", "use_pywrap_rules") + # unused in TSL def tf_additional_plugin_deps(): return select({ @@ -10,6 +12,10 @@ def tf_additional_plugin_deps(): }) def if_dynamic_kernels(extra_deps, otherwise = []): + # TODO(b/356020232): remove after migration is done + if use_pywrap_rules(): + return otherwise + return select({ str(Label("//tensorflow:dynamic_loaded_kernels")): extra_deps, "//conditions:default": otherwise, diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD index 5f4ae9b22bda89..e2b61164265cc5 100644 --- a/tensorflow/core/platform/cloud/BUILD +++ b/tensorflow/core/platform/cloud/BUILD @@ -155,6 +155,10 @@ cc_library( name = "http_request", hdrs = ["http_request.h"], copts = tsl_copts(), + visibility = [ + ":dependency_allowlist", + "//learning/brain/research/meta_architect/hub/pythia:__pkg__", + ], deps = [ "//tensorflow/core:framework_headers_lib", "//tensorflow/core:lib_internal", @@ -207,6 +211,10 @@ cc_library( "google_auth_provider.h", ], copts = tsl_copts(), + visibility = [ + ":dependency_allowlist", + "//learning/brain/research/meta_architect/hub/pythia:__pkg__", + ], deps = [ ":compute_engine_metadata_client", ":oauth_client", diff --git a/tensorflow/core/platform/cuda.h b/tensorflow/core/platform/cuda.h index f92e8e3b6ac750..d032f23a093fb2 100644 --- a/tensorflow/core/platform/cuda.h +++ b/tensorflow/core/platform/cuda.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PLATFORM_CUDA_H_ #define TENSORFLOW_CORE_PLATFORM_CUDA_H_ -#include "xla/stream_executor/gpu/scoped_activate_context.h" // IWYU pragma: keep #include "tensorflow/core/platform/platform.h" // IWYU pragma: keep #endif // TENSORFLOW_CORE_PLATFORM_CUDA_H_ diff --git a/tensorflow/core/platform/error_payloads.cc b/tensorflow/core/platform/error_payloads.cc index f208ce7d57a17f..257f80b908f733 100644 --- a/tensorflow/core/platform/error_payloads.cc +++ b/tensorflow/core/platform/error_payloads.cc @@ -20,8 +20,7 @@ namespace tsl { using ::tensorflow::core::platform::ErrorSourceProto; void OkOrSetErrorCounterPayload( - const ErrorSourceProto::ErrorSource& error_source, - tensorflow::Status& status) { + const ErrorSourceProto::ErrorSource& error_source, absl::Status& status) { if (!status.ok() && !status.GetPayload(tensorflow::kErrorSource).has_value()) { ErrorSourceProto error_source_proto; diff --git a/tensorflow/core/platform/error_payloads.h b/tensorflow/core/platform/error_payloads.h index b806b3cbb4b9f2..e976dfc0c470dc 100644 --- a/tensorflow/core/platform/error_payloads.h +++ b/tensorflow/core/platform/error_payloads.h @@ -38,7 +38,7 @@ constexpr char kErrorSource[] = void OkOrSetErrorCounterPayload( const tensorflow::core::platform::ErrorSourceProto::ErrorSource& error_source, - tensorflow::Status& status); + absl::Status& status); } // namespace tsl namespace tensorflow { diff --git a/tensorflow/core/platform/file_system_test.cc b/tensorflow/core/platform/file_system_test.cc index 41d4f5819ee996..b07e72b2b187c9 100644 --- a/tensorflow/core/platform/file_system_test.cc +++ b/tensorflow/core/platform/file_system_test.cc @@ -34,28 +34,30 @@ class InterPlanetaryFileSystem : public NullFileSystem { public: TF_USE_FILESYSTEM_METHODS_WITH_NO_TRANSACTION_SUPPORT; - Status FileExists(const string& fname, TransactionToken* token) override { + absl::Status FileExists(const string& fname, + TransactionToken* token) override { string parsed_path; ParsePath(fname, &parsed_path); if (BodyExists(parsed_path)) { return absl::OkStatus(); } - return Status(absl::StatusCode::kNotFound, "File does not exist"); + return absl::Status(absl::StatusCode::kNotFound, "File does not exist"); } // Adds the dir to the parent's children list and creates an entry for itself. - Status CreateDir(const string& dirname, TransactionToken* token) override { + absl::Status CreateDir(const string& dirname, + TransactionToken* token) override { string parsed_path; ParsePath(dirname, &parsed_path); // If the directory already exists, throw an error. if (celestial_bodies_.find(parsed_path) != celestial_bodies_.end()) { - return Status(absl::StatusCode::kAlreadyExists, - "dirname already exists."); + return absl::Status(absl::StatusCode::kAlreadyExists, + "dirname already exists."); } std::vector split_path = str_util::Split(parsed_path, '/'); // If the path is too long then we don't support it. if (split_path.size() > 3) { - return Status(absl::StatusCode::kInvalidArgument, "Bad dirname"); + return absl::Status(absl::StatusCode::kInvalidArgument, "Bad dirname"); } if (split_path.empty()) { return absl::OkStatus(); @@ -68,8 +70,8 @@ class InterPlanetaryFileSystem : public NullFileSystem { } if (split_path.size() == 2) { if (!BodyExists(split_path[0])) { - return Status(absl::StatusCode::kFailedPrecondition, - "Base dir not created"); + return absl::Status(absl::StatusCode::kFailedPrecondition, + "Base dir not created"); } celestial_bodies_[split_path[0]].insert(split_path[1]); celestial_bodies_.insert( @@ -79,18 +81,20 @@ class InterPlanetaryFileSystem : public NullFileSystem { if (split_path.size() == 3) { const string& parent_path = this->JoinPath(split_path[0], split_path[1]); if (!BodyExists(parent_path)) { - return Status(absl::StatusCode::kFailedPrecondition, - "Base dir not created"); + return absl::Status(absl::StatusCode::kFailedPrecondition, + "Base dir not created"); } celestial_bodies_[parent_path].insert(split_path[2]); celestial_bodies_.insert( std::pair>(parsed_path, {})); return absl::OkStatus(); } - return Status(absl::StatusCode::kFailedPrecondition, "Failed to create"); + return absl::Status(absl::StatusCode::kFailedPrecondition, + "Failed to create"); } - Status IsDirectory(const string& dirname, TransactionToken* token) override { + absl::Status IsDirectory(const string& dirname, + TransactionToken* token) override { string parsed_path; ParsePath(dirname, &parsed_path); // Simulate evil_directory has bad permissions by throwing a LOG(FATAL) @@ -99,16 +103,16 @@ class InterPlanetaryFileSystem : public NullFileSystem { } std::vector split_path = str_util::Split(parsed_path, '/'); if (split_path.size() > 2) { - return Status(absl::StatusCode::kFailedPrecondition, "Not a dir"); + return absl::Status(absl::StatusCode::kFailedPrecondition, "Not a dir"); } if (celestial_bodies_.find(parsed_path) != celestial_bodies_.end()) { return absl::OkStatus(); } - return Status(absl::StatusCode::kFailedPrecondition, "Not a dir"); + return absl::Status(absl::StatusCode::kFailedPrecondition, "Not a dir"); } - Status GetChildren(const string& dir, TransactionToken* token, - std::vector* result) override { + absl::Status GetChildren(const string& dir, TransactionToken* token, + std::vector* result) override { TF_RETURN_IF_ERROR(IsDirectory(dir, nullptr)); string parsed_path; ParsePath(dir, &parsed_path); @@ -154,8 +158,8 @@ class InterPlanetaryFileSystem : public NullFileSystem { // common prefix of BaseDir(). string Match(InterPlanetaryFileSystem* ipfs, const string& suffix_pattern) { std::vector results; - Status s = ipfs->GetMatchingPaths(ipfs->JoinPath(kPrefix, suffix_pattern), - nullptr, &results); + absl::Status s = ipfs->GetMatchingPaths( + ipfs->JoinPath(kPrefix, suffix_pattern), nullptr, &results); if (!s.ok()) { return s.ToString(); } else { @@ -285,16 +289,17 @@ TEST(InterPlanetaryFileSystemTest, CanCreateTempFile) { class TestFileSystem : public NullFileSystem { public: // Only allow for a single root directory. - Status IsDirectory(const string& dirname, TransactionToken* token) override { + absl::Status IsDirectory(const string& dirname, + TransactionToken* token) override { if (dirname == "." || dirname.empty()) { return absl::OkStatus(); } - return Status(absl::StatusCode::kFailedPrecondition, "Not a dir"); + return absl::Status(absl::StatusCode::kFailedPrecondition, "Not a dir"); } // Simulating a FS with a root dir and a single file underneath it. - Status GetChildren(const string& dir, TransactionToken* token, - std::vector* result) override { + absl::Status GetChildren(const string& dir, TransactionToken* token, + std::vector* result) override { if (dir == "." || dir.empty()) { result->push_back("test"); } diff --git a/tensorflow/core/platform/protobuf_internal.h b/tensorflow/core/platform/protobuf_internal.h index fa23b59604a356..b766b42b5975e1 100644 --- a/tensorflow/core/platform/protobuf_internal.h +++ b/tensorflow/core/platform/protobuf_internal.h @@ -26,8 +26,8 @@ namespace tensorflow { // Utility for parsing an Any value with full or lite protos. template -Status ParseAny(const google::protobuf::Any& any, T* message, - const string& type_name) { +absl::Status ParseAny(const google::protobuf::Any& any, T* message, + const string& type_name) { CHECK_EQ(type_name, message->GetTypeName()); if (!any.Is()) { return errors::FailedPrecondition( diff --git a/tensorflow/core/platform/resource_loader_test.cc b/tensorflow/core/platform/resource_loader_test.cc index 75bdca19452590..cbddb772d2666e 100644 --- a/tensorflow/core/platform/resource_loader_test.cc +++ b/tensorflow/core/platform/resource_loader_test.cc @@ -30,7 +30,7 @@ string DataDependencyPath() { TEST(ResourceLoaderTest, FindsAndOpensFile) { string filepath = GetDataDependencyFilepath(DataDependencyPath()); - Status s = Env::Default()->FileExists(filepath); + absl::Status s = Env::Default()->FileExists(filepath); EXPECT_TRUE(s.ok()) << "No file found at this location: " << filepath; } diff --git a/tensorflow/core/platform/rocm.h b/tensorflow/core/platform/rocm.h index 0695ba6750618f..8fc0fa9d9aa8a3 100644 --- a/tensorflow/core/platform/rocm.h +++ b/tensorflow/core/platform/rocm.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PLATFORM_ROCM_H_ #define TENSORFLOW_CORE_PLATFORM_ROCM_H_ -#include "xla/stream_executor/gpu/scoped_activate_context.h" // IWYU pragma: keep #include "tensorflow/core/platform/platform.h" // IWYU pragma: keep #endif // TENSORFLOW_CORE_PLATFORM_ROCM_H_ diff --git a/tensorflow/core/platform/status.h b/tensorflow/core/platform/status.h index f0b57ce207d1d5..99f66009508b82 100644 --- a/tensorflow/core/platform/status.h +++ b/tensorflow/core/platform/status.h @@ -16,25 +16,45 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PLATFORM_STATUS_H_ #define TENSORFLOW_CORE_PLATFORM_STATUS_H_ +#include "absl/base/macros.h" +#include "absl/status/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stack_frame.h" #include "tensorflow/core/platform/types.h" #include "tsl/platform/status.h" +#if !defined(ABSL_DEPRECATE_AND_INLINE) +#define ABSL_DEPRECATE_AND_INLINE() +#endif + namespace tensorflow { // NOLINTBEGIN(misc-unused-using-decls) +#ifdef SWIG using tsl::FromAbslStatus; using tsl::OkStatus; using tsl::Status; +using tsl::ToAbslStatus; +#else +ABSL_DEPRECATE_AND_INLINE() +inline ::absl::Status FromAbslStatus(const ::absl::Status& s) { return s; } +ABSL_DEPRECATE_AND_INLINE() +inline ::absl::Status ToAbslStatus(const ::absl::Status& s) { return s; } +ABSL_DEPRECATE_AND_INLINE() +inline ::absl::Status OkStatus() { return ::absl::OkStatus(); }; +using Status ABSL_DEPRECATE_AND_INLINE() = ::absl::Status; +#endif using tsl::StatusCallback; using tsl::StatusGroup; using tsl::TfCheckOpHelper; using tsl::TfCheckOpHelperOutOfLine; -using tsl::ToAbslStatus; namespace errors { +#ifdef SWIG using tsl::errors::Code; +#else +using Code ABSL_DEPRECATE_AND_INLINE() = ::absl::StatusCode; +#endif using tsl::errors::GetStackTrace; using tsl::errors::SetStackTrace; } // namespace errors diff --git a/tensorflow/core/platform/stream_executor.h b/tensorflow/core/platform/stream_executor.h index f72e3566645e59..58acf8ebd0ab74 100644 --- a/tensorflow/core/platform/stream_executor.h +++ b/tensorflow/core/platform/stream_executor.h @@ -22,7 +22,6 @@ limitations under the License. #include "xla/stream_executor/event.h" #include "xla/stream_executor/host/host_platform_id.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/platform/dso_loader.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/scratch_allocator.h" @@ -30,5 +29,6 @@ limitations under the License. #include "xla/stream_executor/stream_executor.h" #include "tensorflow/core/platform/platform.h" #include "tensorflow/core/platform/types.h" +#include "tsl/platform/dso_loader.h" #endif // TENSORFLOW_CORE_PLATFORM_STREAM_EXECUTOR_H_ diff --git a/tensorflow/core/platform/stream_executor_no_cuda.h b/tensorflow/core/platform/stream_executor_no_cuda.h index 53f5ccefed2616..e6013d76f672ee 100644 --- a/tensorflow/core/platform/stream_executor_no_cuda.h +++ b/tensorflow/core/platform/stream_executor_no_cuda.h @@ -22,12 +22,12 @@ limitations under the License. #include "xla/stream_executor/event.h" #include "xla/stream_executor/host/host_platform_id.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/platform/dso_loader.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "tensorflow/core/platform/platform.h" +#include "tsl/platform/dso_loader.h" #endif // TENSORFLOW_CORE_PLATFORM_STREAM_EXECUTOR_NO_CUDA_H_ diff --git a/tensorflow/core/platform/tensor_coding.cc b/tensorflow/core/platform/tensor_coding.cc index dd91086efaf4dc..38f1d26508722f 100644 --- a/tensorflow/core/platform/tensor_coding.cc +++ b/tensorflow/core/platform/tensor_coding.cc @@ -15,9 +15,12 @@ limitations under the License. #include "tensorflow/core/platform/tensor_coding.h" +#include +#include #include #include "tensorflow/core/platform/coding.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/strcat.h" #include "tensorflow/core/platform/stringpiece.h" @@ -34,6 +37,13 @@ void AssignRefCounted(StringPiece src, core::RefCounted* obj, string* out) { } void EncodeStringList(const tstring* strings, int64_t n, string* out) { + int64_t tot = n * sizeof(size_t); + for (int i = 0; i < n; ++i) { + tot += strings[i].size(); + } + if (tot > INT_MAX) { + LOG(FATAL) << "EncodeStringList size too large: " << tot; // Crash OK + } out->clear(); for (int i = 0; i < n; ++i) { core::PutVarint32(out, strings[i].size()); diff --git a/tensorflow/core/profiler/BUILD b/tensorflow/core/profiler/BUILD index 8d2d6f030b670a..8b26a517b4bc0b 100644 --- a/tensorflow/core/profiler/BUILD +++ b/tensorflow/core/profiler/BUILD @@ -105,10 +105,10 @@ cc_library( deps = [ "//tensorflow/core/profiler/lib:profiler_factory_impl", "//tensorflow/core/profiler/lib:profiler_session_impl", - "@local_tsl//tsl/profiler/utils:time_utils_impl", "@local_xla//xla/tsl/profiler/backends/cpu:annotation_stack_impl", "@local_xla//xla/tsl/profiler/backends/cpu:threadpool_listener", "@local_xla//xla/tsl/profiler/backends/cpu:traceme_recorder_impl", + "@local_xla//xla/tsl/profiler/utils:time_utils_impl", ], alwayslink = True, ) @@ -147,6 +147,11 @@ cc_library( deps = [ ":protos_all_cc", "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:framework_types_hdr", + "//tensorflow/core:portable_gif_internal", + "//tensorflow/core/lib/core:status", + "//tensorflow/core/platform:status", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", ], @@ -174,14 +179,12 @@ filegroup( # py_proto_library( # name = "profiler_analysis_proto_py_pb2", # has_services = 1, -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":profiler_analysis_proto"], # ) # # py_proto_library( # name = "protos_all_py_pb2", -# api_version = 2, # visibility = [":friends"], # deps = [":protos_all"], # ) diff --git a/tensorflow/core/profiler/backends/gpu/BUILD b/tensorflow/core/profiler/backends/gpu/BUILD index e83d4cae4c0501..c803eb6b16d1b4 100644 --- a/tensorflow/core/profiler/backends/gpu/BUILD +++ b/tensorflow/core/profiler/backends/gpu/BUILD @@ -43,10 +43,10 @@ tf_cuda_cc_test( "//tensorflow/core/profiler/utils:xplane_utils", "//tensorflow/core/profiler/utils:xplane_visitor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", "@local_xla//xla/backends/profiler/gpu:cuda_test", "@local_xla//xla/backends/profiler/gpu:cupti_collector", "@local_xla//xla/backends/profiler/gpu:device_tracer", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cupti_headers", diff --git a/tensorflow/core/profiler/backends/gpu/device_tracer_test.cc b/tensorflow/core/profiler/backends/gpu/device_tracer_test.cc index a00403f982be3f..690128400b8be1 100644 --- a/tensorflow/core/profiler/backends/gpu/device_tracer_test.cc +++ b/tensorflow/core/profiler/backends/gpu/device_tracer_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda_runtime.h" #include "xla/backends/profiler/gpu/cupti_collector.h" #endif // GOOGLE_CUDA +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "tensorflow/core/common_runtime/direct_session.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/graph.pb.h" @@ -50,7 +51,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/util/device_name_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" // TODO(b/186367334) #define CUPTI_NVBUG_3299481_WAR (10000 <= CUDA_VERSION && CUDA_VERSION < 11000) @@ -121,7 +121,7 @@ class DeviceTracerTest : public ::testing::Test { } protected: - void ExpectFailure(const Status& status, error::Code code) { + void ExpectFailure(const absl::Status& status, error::Code code) { EXPECT_FALSE(status.ok()) << status; if (!status.ok()) { LOG(INFO) << "Status message: " << status.message(); @@ -162,7 +162,7 @@ TEST_F(DeviceTracerTest, CollectBeforeStop) { if (!tracer) return; TF_EXPECT_OK(tracer->Start()); XSpace space; - Status status = tracer->CollectData(&space); + absl::Status status = tracer->CollectData(&space); ExpectFailure(status, tensorflow::error::FAILED_PRECONDITION); TF_EXPECT_OK(tracer->Stop()); } @@ -173,7 +173,7 @@ TEST_F(DeviceTracerTest, StartTwoTracers) { if (!tracer1 || !tracer2) return; TF_EXPECT_OK(tracer1->Start()); - Status status = tracer2->Start(); + absl::Status status = tracer2->Start(); ExpectFailure(status, tensorflow::error::UNAVAILABLE); TF_EXPECT_OK(tracer1->Stop()); TF_EXPECT_OK(tracer2->Start()); @@ -197,7 +197,7 @@ TEST_F(DeviceTracerTest, RunWithTracer) { std::vector outputs; TF_ASSERT_OK(tracer->Start()); - Status s = session->Run(inputs, output_names, target_nodes, &outputs); + absl::Status s = session->Run(inputs, output_names, target_nodes, &outputs); TF_ASSERT_OK(s); TF_ASSERT_OK(tracer->Stop()); ASSERT_EQ(1, outputs.size()); @@ -224,8 +224,8 @@ TEST_F(DeviceTracerTest, RunWithTraceOption) { RunOptions run_options; run_options.set_trace_level(RunOptions::FULL_TRACE); RunMetadata run_metadata; - Status s = session->Run(run_options, inputs, output_names, target_nodes, - &outputs, &run_metadata); + absl::Status s = session->Run(run_options, inputs, output_names, target_nodes, + &outputs, &run_metadata); TF_ASSERT_OK(s); ASSERT_TRUE(run_metadata.has_step_stats()); // Depending on whether this runs on CPU or GPU, we will have a @@ -393,15 +393,21 @@ TEST_F(DeviceTracerTest, CudaRuntimeResource) { host_plane.ForEachLine([&](const tensorflow::profiler::XLineVisitor& line) { VLOG(3) << "Line " << line.Id() << "\n"; line.ForEachEvent([&](const tensorflow::profiler::XEventVisitor& event) { - VLOG(3) << " Event " << *event.Type() << "\n"; - - absl::optional stat = - event.GetStat(expected_event_stat_type[event_idx]); - // The stat may not exist if we're looking at the wrong line. - if (stat.has_value()) { - event_idx += 1; - VLOG(3) << " Stat name=" << stat->Name() << " type=" << *stat->Type() - << " " << stat->ToString() << "\n"; + if (event_idx < expected_event_stat_type.size()) { + VLOG(3) << " Event " + << (event.Type().has_value() ? std::to_string(*event.Type()) + : "UNKNOWN_TYPE") + << "\n"; + + absl::optional stat = + event.GetStat(expected_event_stat_type[event_idx]); + // The stat may not exist if we're looking at the wrong line. + if (stat.has_value()) { + event_idx += 1; + VLOG(3) << " Stat name=" << stat->Name() << " type=" << *stat->Type() + << " " << stat->ToString() << ", event_idx:" << event_idx + << "\n"; + } } }); }); diff --git a/tensorflow/core/profiler/convert/BUILD b/tensorflow/core/profiler/convert/BUILD index 9233aded397147..d6e8e15016c413 100644 --- a/tensorflow/core/profiler/convert/BUILD +++ b/tensorflow/core/profiler/convert/BUILD @@ -30,10 +30,10 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/profiler/utils:tf_op_utils", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:timespan", - "@local_tsl//tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:tf_op_utils", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", ], ) @@ -139,9 +139,9 @@ cc_library( "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_utils", "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:format_utils", - "@local_tsl//tsl/profiler/utils:tf_op_utils", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:format_utils", + "@local_xla//xla/tsl/profiler/utils:tf_op_utils", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", ], ) @@ -240,8 +240,8 @@ cc_library( "//tensorflow/core/profiler/utils:op_metrics_db_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:format_utils", - "@local_tsl//tsl/profiler/utils:tf_op_utils", + "@local_xla//xla/tsl/profiler/utils:format_utils", + "@local_xla//xla/tsl/profiler/utils:tf_op_utils", "@local_xla//xla/tsl/util:stats_calculator_portable", ], ) @@ -299,7 +299,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", - "@local_tsl//tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:timespan", ], ) @@ -326,7 +326,6 @@ cc_library( "//tensorflow/core/profiler/utils:hardware_type_utils", "//tensorflow/core/profiler/utils:hlo_proto_map", "//tensorflow/core/profiler/utils:kernel_stats_utils", - "//tensorflow/core/profiler/utils:math_utils", "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_utils", "//tensorflow/core/profiler/utils:xplane_visitor", @@ -334,9 +333,11 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:tpu_xplane_utils", - "@local_tsl//tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:math_utils", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:xplane_utils", ], ) @@ -367,6 +368,7 @@ tf_cc_test( ":repository", ":step_events_to_steps_db", ":xplane_to_op_stats", + ":xplane_to_step_events", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:test", @@ -377,11 +379,13 @@ tf_cc_test( "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", "//tensorflow/core/profiler/protobuf:tf_function_proto_cc", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", + "//tensorflow/core/profiler/utils:op_metrics_db_utils", "//tensorflow/core/profiler/utils:xplane_builder", "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_test_utils", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:group_events", + "@local_xla//xla/tsl/profiler/utils:group_events", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", ], ) @@ -402,11 +406,11 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/profiler/utils:tf_op_utils", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:timespan", - "@local_tsl//tsl/profiler/utils:tpu_xplane_utils", - "@local_tsl//tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:tf_op_utils", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", ], ) @@ -425,7 +429,7 @@ tf_cc_test( "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_test_utils", "@com_google_absl//absl/container:flat_hash_map", - "@local_tsl//tsl/profiler/utils:group_events", + "@local_xla//xla/tsl/profiler/utils:group_events", ], ) @@ -444,8 +448,8 @@ cc_library( "//tensorflow/core/profiler/utils:trace_utils", "//tensorflow/core/profiler/utils:xplane_visitor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:tf_op_utils", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:tf_op_utils", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", ], ) @@ -485,8 +489,8 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:timespan", ], ) @@ -507,7 +511,7 @@ tf_cc_test( "//tensorflow/core/profiler/utils:xplane_utils", "//tensorflow/core/profiler/utils:xplane_visitor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", ], ) @@ -533,7 +537,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", ], ) @@ -552,7 +556,7 @@ tf_cc_test( "//tensorflow/core/profiler/utils:xplane_schema", "//tensorflow/core/profiler/utils:xplane_test_utils", "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:group_events", + "@local_xla//xla/tsl/profiler/utils:group_events", ], ) @@ -606,9 +610,9 @@ cc_library( "//tensorflow/core/profiler/utils:derived_timeline", "//tensorflow/core/profiler/utils:xplane_schema", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:group_events", - "@local_tsl//tsl/profiler/utils:preprocess_xplane", - "@local_tsl//tsl/profiler/utils:xplane_utils", + "@local_xla//xla/tsl/profiler/utils:group_events", + "@local_xla//xla/tsl/profiler/utils:preprocess_xplane", + "@local_xla//xla/tsl/profiler/utils:xplane_utils", ], ) @@ -675,10 +679,10 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:group_events", - "@local_tsl//tsl/profiler/utils:tf_op_utils", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:group_events", + "@local_xla//xla/tsl/profiler/utils:tf_op_utils", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:timespan", ], ) @@ -718,7 +722,7 @@ cc_library( "//tensorflow/core/profiler/utils:xplane_visitor", "@com_google_absl//absl/strings", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", ], ) @@ -796,8 +800,8 @@ cc_library( "//tensorflow/core/profiler/utils:hlo_proto_map", "@com_google_absl//absl/strings", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:file_system_utils", "@local_xla//xla/service:hlo_proto_cc", + "@local_xla//xla/tsl/profiler/utils:file_system_utils", ], ) @@ -877,12 +881,13 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core/platform:errors", "//tensorflow/core/platform:statusor", + "//tensorflow/core/profiler/utils:hlo_module_map", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:file_system_utils", + "@local_xla//xla/tsl/profiler/utils:file_system_utils", ], ) @@ -945,12 +950,12 @@ cc_library( "//tensorflow/core/profiler/utils:xplane_utils", "@com_google_absl//absl/strings", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:timespan", - "@local_tsl//tsl/profiler/utils:trace_utils", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_utils", - "@local_tsl//tsl/profiler/utils:xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:trace_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:xplane_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_visitor", ], ) @@ -961,8 +966,8 @@ cc_library( visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:xplane_visitor", ], ) @@ -972,9 +977,9 @@ tf_cc_test( deps = [ ":dcn_utils", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:xplane_builder", - "@local_tsl//tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:xplane_builder", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", ], ) @@ -989,9 +994,9 @@ cc_library( "//tensorflow/core/profiler/utils:xplane_visitor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:math_utils", - "@local_tsl//tsl/profiler/utils:tpu_xplane_utils", - "@local_tsl//tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:math_utils", + "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", ], ) @@ -1003,8 +1008,8 @@ cc_library( ":dcn_analysis", "//tensorflow/core/profiler/utils:xplane_utils", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:tpu_xplane_utils", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", ], ) @@ -1015,8 +1020,8 @@ tf_cc_test( ":dcn_analysis", ":dcn_utils", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", ], ) @@ -1038,17 +1043,17 @@ cc_library( "@local_tsl//tsl/platform:regexp", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:math_utils", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:timespan", - "@local_tsl//tsl/profiler/utils:tpu_xplane_utils", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_utils", - "@local_tsl//tsl/profiler/utils:xplane_visitor", "@local_xla//xla:shape_util", "@local_xla//xla:side_effect_util", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/tsl/profiler/utils:math_utils", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:xplane_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_visitor", ], ) @@ -1059,7 +1064,7 @@ cc_library( deps = [ "//tensorflow/core/profiler/protobuf:dcn_slack_analysis_proto_cc", "@com_google_absl//absl/container:flat_hash_map", - "@local_tsl//tsl/profiler/utils:math_utils", + "@local_xla//xla/tsl/profiler/utils:math_utils", ], ) diff --git a/tensorflow/core/profiler/convert/dcn_analysis.cc b/tensorflow/core/profiler/convert/dcn_analysis.cc index fd54adf88d3fb3..5c58cda325cf33 100644 --- a/tensorflow/core/profiler/convert/dcn_analysis.cc +++ b/tensorflow/core/profiler/convert/dcn_analysis.cc @@ -22,11 +22,11 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/math_utils.h" +#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/convert/dcn_utils.h" #include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/math_utils.h" -#include "tsl/profiler/utils/tpu_xplane_utils.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/dcn_analysis_test.cc b/tensorflow/core/profiler/convert/dcn_analysis_test.cc index f89df221444d4b..345b3752b637aa 100644 --- a/tensorflow/core/profiler/convert/dcn_analysis_test.cc +++ b/tensorflow/core/profiler/convert/dcn_analysis_test.cc @@ -19,9 +19,9 @@ limitations under the License. #include #include +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/convert/dcn_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.cc b/tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.cc index 32a91c836c1497..6806742f5cec8b 100644 --- a/tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.cc +++ b/tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/profiler/convert/dcn_slack_analysis_combiner.h" +#include "xla/tsl/profiler/utils/math_utils.h" #include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" -#include "tsl/profiler/utils/math_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/dcn_utils.cc b/tensorflow/core/profiler/convert/dcn_utils.cc index 98ecf7dc14106e..7b41905265c385 100644 --- a/tensorflow/core/profiler/convert/dcn_utils.cc +++ b/tensorflow/core/profiler/convert/dcn_utils.cc @@ -17,8 +17,8 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/string_view.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_visitor.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/dcn_utils.h b/tensorflow/core/profiler/convert/dcn_utils.h index 1149daa4b62be7..e0dd3a174df919 100644 --- a/tensorflow/core/profiler/convert/dcn_utils.h +++ b/tensorflow/core/profiler/convert/dcn_utils.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "tsl/profiler/utils/xplane_visitor.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/dcn_utils_test.cc b/tensorflow/core/profiler/convert/dcn_utils_test.cc index 27c74d79c66407..1d31fcd1502a6d 100644 --- a/tensorflow/core/profiler/convert/dcn_utils_test.cc +++ b/tensorflow/core/profiler/convert/dcn_utils_test.cc @@ -20,9 +20,9 @@ limitations under the License. #include #include -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc b/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc index ab40287a231aa6..f5d940c7d500ef 100644 --- a/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc +++ b/tensorflow/core/profiler/convert/hlo_proto_to_memory_visualization_utils.cc @@ -535,6 +535,7 @@ struct HeapSimulatorStats { // Update memory timelines and seen buffers. heap_size_bytes_timeline.push_back(heap_size_bytes); unpadded_heap_size_bytes_timeline.push_back(unpadded_heap_size_bytes); + hlo_instruction_name_timeline.push_back(event.instruction_name()); const LogicalBufferStruct* logical_buffer = wrapper.GetLogicalBuffer(event.buffer_id()); if (logical_buffer == nullptr) return; @@ -569,7 +570,8 @@ struct HeapSimulatorStats { } // Update stats when memory usage decrease. - Status DecreaseMemoryUsage(LogicalBufferStruct* canonical_logical_buffer) { + absl::Status DecreaseMemoryUsage( + LogicalBufferStruct* canonical_logical_buffer) { int64_t canonical_buffer_id = canonical_logical_buffer->proto.id(); logical_buffers.remove(canonical_buffer_id); heap_size_bytes -= canonical_logical_buffer->size(); @@ -587,10 +589,13 @@ struct HeapSimulatorStats { } // Finalize the memory usage stats from heap simulator trace. - Status FinalizeMemoryUsage() { + absl::Status FinalizeMemoryUsage() { // Add the final heap size after simulating the entire heap trace. heap_size_bytes_timeline.push_back(heap_size_bytes); unpadded_heap_size_bytes_timeline.push_back(unpadded_heap_size_bytes); + // Add an empty instruction name just so that this array is the same size as + // the other two. + hlo_instruction_name_timeline.push_back(""); if (seen_buffer_allocations.size() != 1) { return errors::InvalidArgument( @@ -627,6 +632,7 @@ struct HeapSimulatorStats { // Heap size timeline. std::vector heap_size_bytes_timeline; std::vector unpadded_heap_size_bytes_timeline; + std::vector hlo_instruction_name_timeline; // Position of peak memory usage in the timeline. int64_t peak_heap_size_position = 0; @@ -640,9 +646,9 @@ struct HeapSimulatorStats { int64_t simulator_trace_event_size; }; -Status ProcessHeapSimulatorTrace(const HloProtoBufferWrapper& wrapper, - const int64_t memory_color, - HeapSimulatorStats* stats) { +absl::Status ProcessHeapSimulatorTrace(const HloProtoBufferWrapper& wrapper, + const int64_t memory_color, + HeapSimulatorStats* stats) { int64_t heap_simulator_trace_id = wrapper.GetHeapSimulatorTraceId(memory_color); @@ -1047,6 +1053,8 @@ void GeneratePreprocessResult(const HloProtoBufferWrapper& wrapper, result->add_unpadded_heap_sizes( BytesToMiB(simulator_stats.unpadded_heap_size_bytes_timeline[i]) + add_mib); + result->add_hlo_instruction_names( + simulator_stats.hlo_instruction_name_timeline[i]); } result->set_peak_heap_mib(BytesToMiB(simulator_stats.peak_heap_size_bytes) + diff --git a/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.cc b/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.cc index 24c27bcf6d6204..d0cb6d46078eca 100644 --- a/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.cc +++ b/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.cc @@ -29,7 +29,7 @@ limitations under the License. namespace tensorflow { namespace profiler { -Status ConvertMultiXSpacesToCombinedOpStats( +absl::Status ConvertMultiXSpacesToCombinedOpStats( const SessionSnapshot& session_snapshot, const OpStatsOptions& options, OpStats* combined_op_stats) { // Read multiple XSpaces and convert to multiple OpStats. diff --git a/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.h b/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.h index ddacc41a6c60f6..51348097d321f3 100644 --- a/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.h +++ b/tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.h @@ -28,7 +28,7 @@ namespace profiler { // . // Return the first error status during conversion, or return OkStatus() if // there is no error. -Status ConvertMultiXSpacesToCombinedOpStats( +absl::Status ConvertMultiXSpacesToCombinedOpStats( const SessionSnapshot& session_snapshot, const OpStatsOptions& options, OpStats* combined_op_stats); diff --git a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc index c02c3d0067e49a..e13e0cb73a2ab5 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.cc @@ -28,6 +28,8 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/format_utils.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" #include "xla/tsl/util/stats_calculator.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/logging.h" @@ -45,8 +47,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/html_utils.h" #include "tensorflow/core/profiler/utils/math_utils.h" #include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" -#include "tsl/profiler/utils/format_utils.h" -#include "tsl/profiler/utils/tf_op_utils.h" namespace tensorflow { namespace profiler { @@ -595,6 +595,7 @@ StepSummary ComputeStepTimeSummaryInMs( // iterates over each core. for (const auto& coreid_and_stepinfo : coreid_stepinfo_map.step_info_per_core()) { + if (coreid_and_stepinfo.first >= kSparseCoreIndexStart) continue; const auto& step_info = coreid_and_stepinfo.second; max_per_step_stats_in_ms = std::max(step_info.duration_ps() / kNumPsPerMs, max_per_step_stats_in_ms); diff --git a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc b/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc index 0e0ad42b20a4da..57b974005c3001 100644 --- a/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc +++ b/tensorflow/core/profiler/convert/op_stats_to_overview_page.cc @@ -22,6 +22,9 @@ limitations under the License. #include "google/protobuf/any.pb.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/profiler/utils/format_utils.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/convert/op_metrics_to_record.h" #include "tensorflow/core/profiler/convert/op_stats_to_input_pipeline_analysis.h" @@ -42,9 +45,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tsl/profiler/utils/format_utils.h" -#include "tsl/profiler/utils/tf_op_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/preprocess_single_host_xplane.cc b/tensorflow/core/profiler/convert/preprocess_single_host_xplane.cc index 043ae143dc969b..760e6439e90e9a 100644 --- a/tensorflow/core/profiler/convert/preprocess_single_host_xplane.cc +++ b/tensorflow/core/profiler/convert/preprocess_single_host_xplane.cc @@ -16,11 +16,11 @@ limitations under the License. #include +#include "xla/tsl/profiler/utils/group_events.h" +#include "xla/tsl/profiler/utils/preprocess_xplane.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/derived_timeline.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/group_events.h" -#include "tsl/profiler/utils/preprocess_xplane.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/process_megascale_dcn.cc b/tensorflow/core/profiler/convert/process_megascale_dcn.cc index 947c5e54a19568..2d8313bfc9cb82 100644 --- a/tensorflow/core/profiler/convert/process_megascale_dcn.cc +++ b/tensorflow/core/profiler/convert/process_megascale_dcn.cc @@ -16,10 +16,10 @@ limitations under the License. #include +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" #include "tensorflow/core/profiler/convert/dcn_analysis.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/tpu_xplane_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/repository.cc b/tensorflow/core/profiler/convert/repository.cc index abc4a994325d39..fa6f52d3a76754 100644 --- a/tensorflow/core/profiler/convert/repository.cc +++ b/tensorflow/core/profiler/convert/repository.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" +#include "xla/tsl/profiler/utils/file_system_utils.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/path.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/core/platform/statusor.h" #include "tsl/platform/errors.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/file_system_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/repository.h b/tensorflow/core/profiler/convert/repository.h index 55d33af3d4bfbb..af990aa5cb073e 100644 --- a/tensorflow/core/profiler/convert/repository.h +++ b/tensorflow/core/profiler/convert/repository.h @@ -26,13 +26,14 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/file_system_utils.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/profiler/utils/hlo_module_map.h" #include "tsl/platform/env.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/file_system_utils.h" namespace tensorflow { namespace profiler { @@ -181,6 +182,18 @@ absl::Status ReadBinaryProto(const SessionSnapshot& session_snapshot, return session_snapshot.ReadBinaryProto(data_type, host, proto); } +// Process HloModuleMap from all XSpaces in a session. +inline absl::StatusOr ProcessHloModuleMap( + const SessionSnapshot& session_snapshot) { + HloModuleMap hlo_module_map; + for (int i = 0; i < session_snapshot.XSpaceSize(); i++) { + TF_ASSIGN_OR_RETURN(std::unique_ptr xspace, + session_snapshot.GetXSpace(i)); + ProcessHloModuleMapFromXSpace(hlo_module_map, xspace.get()); + } + return hlo_module_map; +} + } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/step_events_to_steps_db.cc b/tensorflow/core/profiler/convert/step_events_to_steps_db.cc index e04e1f47412adc..46fcb4d9473161 100644 --- a/tensorflow/core/profiler/convert/step_events_to_steps_db.cc +++ b/tensorflow/core/profiler/convert/step_events_to_steps_db.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/log/log.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" #include "tensorflow/core/profiler/utils/event_span.h" #include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/trace_viewer/BUILD b/tensorflow/core/profiler/convert/trace_viewer/BUILD index 85ee0784aada65..0d426cecf12e6b 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/BUILD +++ b/tensorflow/core/profiler/convert/trace_viewer/BUILD @@ -28,7 +28,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:timespan", ], ) @@ -40,7 +40,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/profiler/protobuf:trace_events_proto_cc", - "@local_tsl//tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:timespan", ], ) @@ -76,7 +76,7 @@ cc_library( "@com_google_absl//absl/types:optional", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/profiler/lib:context_types_hdrs", - "@local_tsl//tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:timespan", ], ) @@ -98,7 +98,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:timespan", ], ) @@ -125,9 +125,9 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/lib/io:iterator", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/profiler/lib:context_types_hdrs", - "@local_tsl//tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/lib/io:iterator", + "@local_xla//xla/tsl/profiler/utils:timespan", ], ) diff --git a/tensorflow/core/profiler/convert/trace_viewer/trace_events.cc b/tensorflow/core/profiler/convert/trace_viewer/trace_events.cc index cb0de415dba4f8..994de7aa3fc7c6 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/trace_events.cc +++ b/tensorflow/core/profiler/convert/trace_viewer/trace_events.cc @@ -31,6 +31,11 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "xla/tsl/lib/io/iterator.h" +#include "xla/tsl/lib/io/table.h" +#include "xla/tsl/lib/io/table_builder.h" +#include "xla/tsl/lib/io/table_options.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/platform/file_system.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -39,14 +44,9 @@ limitations under the License. #include "tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.h" #include "tensorflow/core/profiler/protobuf/trace_events.pb.h" #include "tensorflow/core/profiler/protobuf/trace_events_raw.pb.h" -#include "tsl/lib/io/iterator.h" -#include "tsl/lib/io/table.h" -#include "tsl/lib/io/table_builder.h" -#include "tsl/lib/io/table_options.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/file_system.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/trace_viewer/trace_events.h b/tensorflow/core/profiler/convert/trace_viewer/trace_events.h index 3b627417e6d706..cbed82e0e51142 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/trace_events.h +++ b/tensorflow/core/profiler/convert/trace_viewer/trace_events.h @@ -34,18 +34,18 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/lib/io/table.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/profiler/convert/trace_viewer/trace_events_filter_interface.h" #include "tensorflow/core/profiler/convert/trace_viewer/trace_events_util.h" #include "tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.h" #include "tensorflow/core/profiler/lib/context_types.h" #include "tensorflow/core/profiler/protobuf/task.pb.h" #include "tensorflow/core/profiler/protobuf/trace_events.pb.h" -#include "tsl/lib/io/table.h" #include "tsl/platform/errors.h" #include "tsl/platform/file_system.h" #include "tsl/platform/status.h" #include "tsl/profiler/lib/context_types.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h b/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h index 66de83fe1991b4..6c331e275d2efa 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h +++ b/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h @@ -36,6 +36,7 @@ limitations under the License. #include "absl/strings/strip.h" #include "absl/time/time.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/profiler/convert/trace_viewer/trace_events_util.h" #include "tensorflow/core/profiler/convert/trace_viewer/trace_viewer_color.h" #include "tensorflow/core/profiler/lib/context_types.h" @@ -44,7 +45,6 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/trace_events_raw.pb.h" #include "tsl/platform/protobuf.h" #include "tsl/profiler/lib/context_types.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { @@ -68,6 +68,8 @@ struct JsonTraceOptions { TraceEventsColorerInterface* colorer = nullptr; bool generate_stack_frames = true; + bool use_new_backend = false; + std::string code_link; }; // Counts generated JSON events by type. @@ -193,7 +195,7 @@ class JsonEventWriter { } switch (event.flow_entry_type()) { case TraceEvent::FLOW_NONE: - // The caller prevents this case from happenning. + // The caller prevents this case from happening. break; case TraceEvent::FLOW_START: output_->Append(R"(,"flow_out":true)"); @@ -222,7 +224,7 @@ class JsonEventWriter { } switch (event.flow_entry_type()) { case TraceEvent::FLOW_NONE: - // The caller prevents this case from happenning. + // The caller prevents this case from happening. break; case TraceEvent::FLOW_START: output_->Append(R"(,"ph":"b")"); @@ -514,8 +516,11 @@ void TraceEventsToJson(const JsonTraceOptions& options, // uses higher-precision when manipulating event times. Note that the // timestamps of trace events are always given in microseconds. output->Append( - R"({"displayTimeUnit":"ns","metadata":{"highres-ticks":true},)"); + R"({"displayTimeUnit":"ns","metadata":{"highres-ticks":true}, "codeLink":")", + options.code_link, R"(",)"); + output->Append(absl::StrFormat(R"("useNewBackend": %s,)", + options.use_new_backend ? "true" : "false")); WriteDetails(options.details, output); WriteSelectedDeviceIds(options.selected_device_ids, output); WriteReturnedEventsSize(events.NumEvents(), output); diff --git a/tensorflow/core/profiler/convert/trace_viewer/trace_events_util.cc b/tensorflow/core/profiler/convert/trace_viewer/trace_events_util.cc index e5f84d14efb270..1080db88ee9d0e 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/trace_events_util.cc +++ b/tensorflow/core/profiler/convert/trace_viewer/trace_events_util.cc @@ -20,8 +20,8 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/profiler/protobuf/trace_events.pb.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/trace_viewer/trace_events_util.h b/tensorflow/core/profiler/convert/trace_viewer/trace_events_util.h index 0d7e8721e2a6c6..4f0e1dc838b830 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/trace_events_util.h +++ b/tensorflow/core/profiler/convert/trace_viewer/trace_events_util.h @@ -22,8 +22,8 @@ limitations under the License. #include "absl/log/check.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/profiler/protobuf/trace_events.pb.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.cc b/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.cc index 5fe66cd7182f00..c51f18043aa480 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.cc +++ b/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.cc @@ -17,8 +17,8 @@ limitations under the License. #include #include "absl/log/check.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/profiler/protobuf/trace_events.pb.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.h b/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.h index 4257384bcf88b9..da503d417c81e7 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.h +++ b/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility.h @@ -24,9 +24,9 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/profiler/convert/trace_viewer/trace_events_filter_interface.h" #include "tensorflow/core/profiler/protobuf/trace_events.pb.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility_test.cc b/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility_test.cc index 60a3cdfd939801..e9c4dce6d17a4c 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility_test.cc +++ b/tensorflow/core/profiler/convert/trace_viewer/trace_viewer_visibility_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/profiler/protobuf/trace_events.pb.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_hlo.cc b/tensorflow/core/profiler/convert/xplane_to_hlo.cc index f4af2784f039f6..792a31701f3dd5 100644 --- a/tensorflow/core/profiler/convert/xplane_to_hlo.cc +++ b/tensorflow/core/profiler/convert/xplane_to_hlo.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/service/hlo.pb.h" +#include "xla/tsl/profiler/utils/file_system_utils.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tensorflow/core/profiler/convert/repository.h" #include "tensorflow/core/profiler/utils/hlo_proto_map.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/file_system_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc index 2f2349884d52e3..2f1f14045567df 100644 --- a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc +++ b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" @@ -27,8 +29,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/kernel_stats_utils.h" #include "tensorflow/core/profiler/utils/trace_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/tf_op_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc b/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc index b289f54baa67f0..612a40bc3f7a80 100644 --- a/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc +++ b/tensorflow/core/profiler/convert/xplane_to_memory_profile.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -38,7 +39,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" namespace tensorflow { namespace profiler { @@ -527,7 +527,8 @@ void ProcessMemoryProfileProto(int64_t max_num_snapshots, } template -Status ConvertProtoToJson(const Proto& proto_output, std::string* json_output) { +absl::Status ConvertProtoToJson(const Proto& proto_output, + std::string* json_output) { protobuf::util::JsonPrintOptions json_options; json_options.always_print_primitive_fields = true; auto status = protobuf::util::MessageToJsonString(proto_output, json_output, @@ -555,8 +556,8 @@ MemoryProfile ConvertXPlaneToMemoryProfile(const XPlane& host_plane, return memory_profile; } -Status ConvertXSpaceToMemoryProfileJson(const XSpace& xspace, - std::string* json_output) { +absl::Status ConvertXSpaceToMemoryProfileJson(const XSpace& xspace, + std::string* json_output) { if (const XPlane* host_plane = FindPlaneWithName(xspace, kHostThreadsPlaneName)) { MemoryProfile memory_profile = ConvertXPlaneToMemoryProfile(*host_plane); diff --git a/tensorflow/core/profiler/convert/xplane_to_memory_profile.h b/tensorflow/core/profiler/convert/xplane_to_memory_profile.h index 28b4d805601e2b..00f919d4dbd42e 100644 --- a/tensorflow/core/profiler/convert/xplane_to_memory_profile.h +++ b/tensorflow/core/profiler/convert/xplane_to_memory_profile.h @@ -32,8 +32,8 @@ namespace profiler { MemoryProfile ConvertXPlaneToMemoryProfile(const XPlane& host_plane, int64_t max_num_snapshots = 1000); -Status ConvertXSpaceToMemoryProfileJson(const XSpace& xspace, - std::string* json_output); +absl::Status ConvertXSpaceToMemoryProfileJson(const XSpace& xspace, + std::string* json_output); } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc b/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc index 527735095e088c..8d0415db234e94 100644 --- a/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_memory_profile_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/profiler/convert/xplane_to_memory_profile.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/group_events.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/memory_profile.pb.h" @@ -23,7 +24,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_builder.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_test_utils.h" -#include "tsl/profiler/utils/group_events.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc index 8b228479872bcc..0384b80e82d980 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc @@ -28,6 +28,10 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -41,10 +45,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/trace_utils.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/tf_op_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/timespan.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h index 126f0118da1b60..c5d2a229d52bc4 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.h @@ -18,13 +18,13 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/utils/op_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/tf_op_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc index 46f46fabe65f4b..cca87ccc3f668c 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats.cc @@ -22,6 +22,11 @@ limitations under the License. #include "absl/log/check.h" #include "absl/strings/match.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/math_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" #include "tensorflow/core/profiler/convert/step_events_to_steps_db.h" #include "tensorflow/core/profiler/convert/xplane_to_kernel_stats_db.h" @@ -38,19 +43,16 @@ limitations under the License. #include "tensorflow/core/profiler/utils/hardware_type_utils.h" #include "tensorflow/core/profiler/utils/hlo_proto_map.h" #include "tensorflow/core/profiler/utils/kernel_stats_utils.h" -#include "tensorflow/core/profiler/utils/math_utils.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/tpu_xplane_utils.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tensorflow { namespace profiler { namespace { +using tsl::profiler::FindPlanesWithPrefix; using tsl::profiler::FindTensorCorePlanes; std::string Hostname(const XSpace& space) { @@ -78,14 +80,20 @@ PerfEnv MakePerfEnv(double peak_tera_flops_per_second, PerfEnv GetPerfEnvFromXPlane(const XPlane& device_plane) { DeviceCapabilities cap = GetDeviceCaps(device_plane); if (!absl::StartsWith(device_plane.name(), kTpuPlanePrefix)) { - return MakePerfEnv( - tsl::profiler::GigaToTera(GetFlopMaxThroughputPerSM(cap)) * - cap.num_cores(), - // Ideally, the cap should report separate hbm BW, for now set to same. - {tsl::profiler::UniToGiga(cap.memory_bandwidth()), - tsl::profiler::UniToGiga(cap.memory_bandwidth()), - tsl::profiler::UniToGiga(cap.memory_bandwidth()), - tsl::profiler::UniToGiga(cap.memory_bandwidth())}); + double peak_tera_flops_per_second = + cap.num_cores() * + tsl::profiler::GigaToTera(GetFlopMaxThroughputPerSM(cap)); + double hbm_bw_giga_bytes_per_second = + tsl::profiler::UniToGiga(cap.memory_bandwidth()); + double shm_giga_bytes_per_second = + cap.num_cores() * + tsl::profiler::UniToGiga(GetSharedMemoryBandwidthPerSM(cap)); + // Note that treat SRAM_RD and SRAM_WR as the same. So in future, we could + // only use one for shared memory / L1 cache, one for another like L2. + return MakePerfEnv(peak_tera_flops_per_second, + {/*HBM_RW=*/hbm_bw_giga_bytes_per_second, + /*SRAM_RD=*/shm_giga_bytes_per_second, + /*SRAM_WR=*/shm_giga_bytes_per_second}); } else { XPlaneVisitor visitor = tsl::profiler::CreateTfXPlaneVisitor(&device_plane); auto peak_tera_flops_per_second = @@ -179,11 +187,6 @@ void SetProgramIdToNameMap(const HloProtoMap& hlo_proto_map, OpStats ConvertXSpaceToOpStats(const XSpace& space, const OpStatsOptions& options) { - std::vector device_planes = FindTensorCorePlanes(space); - bool is_tpu = !device_planes.empty(); - if (!is_tpu) { - device_planes = FindPlanesWithPrefix(space, kGpuPlanePrefix); - } OpStats op_stats; StepEvents step_events; PropagateXSpaceDiagnosticsToOpStats(space, &op_stats); @@ -194,6 +197,14 @@ OpStats ConvertXSpaceToOpStats(const XSpace& space, KernelReportMap reports; + // Handle device planes first. device_planes will contain either GPU or TPU. + std::vector device_planes = + FindPlanesWithPrefix(space, kTpuPlanePrefix); + const bool is_gpu = device_planes.empty(); + if (is_gpu) { + device_planes = FindPlanesWithPrefix(space, kGpuPlanePrefix); + } + const bool is_tpu = !is_gpu; // TODO(b/161942993) parallelize XPlane processing per thread. for (const XPlane* device_trace : device_planes) { XPlane aggregated_xplane; diff --git a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc b/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc index 64971ef99cbc25..68c0b29a9f4481 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_stats_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "xla/tsl/profiler/utils/group_events.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/convert/multi_xplanes_to_op_stats.h" @@ -36,7 +37,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_test_utils.h" #include "tsl/platform/status.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/group_events.h" namespace tensorflow { namespace profiler { @@ -86,12 +86,16 @@ TEST(ConvertXPlaneToOpStats, GpuPerfEnv) { TF_CHECK_OK(ConvertMultiXSpacesToCombinedOpStats(session_snapshot_or.value(), options, &op_stats)); const PerfEnv& perf_env = op_stats.perf_env(); - EXPECT_NEAR(141, perf_env.peak_tera_flops_per_second(), kMaxError); + // Change to lower flops number that we do not use sum of the tensor core peak + // flops and the cuda core peak flops together as peak flops. Only use the + // tensor core peak flops as all those white papers are using. + EXPECT_NEAR(125.34, perf_env.peak_tera_flops_per_second(), kMaxError); EXPECT_NEAR( 900, perf_env.peak_bws_giga_bytes_per_second(MemBwType::MEM_BW_TYPE_HBM_RW), kMaxError); - EXPECT_NEAR(156.67, perf_env.ridge_point(), kMaxError); + // Ridge point changed accordingly from above peak flops change. + EXPECT_NEAR(139.26, perf_env.ridge_point(), kMaxError); } TEST(ConvertXPlaneToOpStats, GpuRunEnvironment) { diff --git a/tensorflow/core/profiler/convert/xplane_to_step_events.cc b/tensorflow/core/profiler/convert/xplane_to_step_events.cc index 8e50ed1ad18a93..47d1aa8c5f3588 100644 --- a/tensorflow/core/profiler/convert/xplane_to_step_events.cc +++ b/tensorflow/core/profiler/convert/xplane_to_step_events.cc @@ -25,6 +25,11 @@ limitations under the License. #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" @@ -33,11 +38,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/trace_utils.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/tf_op_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/timespan.h" -#include "tsl/profiler/utils/tpu_xplane_utils.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tensorflow { namespace profiler { @@ -281,11 +281,14 @@ StepEvents ConvertDeviceTraceXPlaneToStepEvents(const XPlane& device_trace) { StepEvents device_step_events; XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(&device_trace); std::optional tpu_core_id = tsl::profiler::GetTensorCoreId(plane.Name()); + std::optional sc_core_id = tsl::profiler::GetSparseCoreId(plane.Name()); plane.ForEachLine([&](const XLineVisitor& line) { int64_t line_id = line.Id(); if (line_id == kThreadIdStepInfo || (tpu_core_id.has_value() && - line.Name() == tsl::profiler::kStepLineName)) { + line.Name() == tsl::profiler::kStepLineName) || + (sc_core_id.has_value() && + line.Name() == tsl::profiler::kSparseCoreStepLineName)) { StepEvents step_marker_events = ConvertDeviceStepInfoToStepMarkers(line); UnionCombineStepEvents(step_marker_events, &device_step_events); } else if (IsDerivedThreadId(line_id)) { @@ -300,6 +303,10 @@ StepEvents ConvertDeviceTraceXPlaneToStepEvents(const XPlane& device_trace) { stream_step_events = ConvertTpuDeviceTraceXLineToStepEvents(*tpu_core_id, line); IntersectCombineStepEvents(stream_step_events, &device_step_events); + } else if (sc_core_id.has_value()) { + stream_step_events = ConvertTpuDeviceTraceXLineToStepEvents( + kSparseCoreIndexStart + *sc_core_id, line); + IntersectCombineStepEvents(stream_step_events, &device_step_events); } else { stream_step_events = ConvertDeviceTraceXLineToStepEvents(plane.Id(), line); diff --git a/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc b/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc index cf4d2b0af40b06..7f6069ec0f511f 100644 --- a/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_step_events_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "xla/tsl/profiler/utils/group_events.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_builder.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_test_utils.h" -#include "tsl/profiler/utils/group_events.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_step_stats.cc b/tensorflow/core/profiler/convert/xplane_to_step_stats.cc index ff084d1e03e1e5..94c7a3da15adfe 100644 --- a/tensorflow/core/profiler/convert/xplane_to_step_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_step_stats.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc index 7159570eee4cfe..c566235840ffd3 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc +++ b/tensorflow/core/profiler/convert/xplane_to_tf_data_stats.cc @@ -25,15 +25,15 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/group_events.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/profiler/protobuf/tf_data_stats.pb.h" #include "tensorflow/core/profiler/utils/html_utils.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/group_events.h" -#include "tsl/profiler/utils/tf_op_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc b/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc index 533d937792362f..58b64a1696ed9d 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc +++ b/tensorflow/core/profiler/convert/xplane_to_tf_functions.cc @@ -26,6 +26,8 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -34,8 +36,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/math_utils.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc b/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc index ca85cb6005d32e..c7127b80212372 100644 --- a/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_tf_functions_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/profiler/protobuf/tf_function.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_test_utils.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xplane_to_trace_container.cc b/tensorflow/core/profiler/convert/xplane_to_trace_container.cc index cfb4f2ec20cf4b..d442c9eca5047a 100644 --- a/tensorflow/core/profiler/convert/xplane_to_trace_container.cc +++ b/tensorflow/core/profiler/convert/xplane_to_trace_container.cc @@ -22,16 +22,16 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "xla/tsl/profiler/utils/trace_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tensorflow/core/profiler/convert/trace_viewer/trace_event_arguments_builder.h" #include "tensorflow/core/profiler/convert/trace_viewer/trace_events_util.h" #include "tensorflow/core/profiler/protobuf/trace_events.pb.h" #include "tensorflow/core/profiler/protobuf/trace_events_raw.pb.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/timespan.h" -#include "tsl/profiler/utils/trace_utils.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc b/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc index 82cf25f4e2b180..3a5aeb78da5c14 100644 --- a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc +++ b/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.cc @@ -32,6 +32,13 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/shape_util.h" #include "xla/side_effect_util.h" +#include "xla/tsl/profiler/utils/math_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/profiler/protobuf/dcn_collective_info.pb.h" #include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" @@ -43,13 +50,6 @@ limitations under the License. #include "tsl/platform/regexp.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/math_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/timespan.h" -#include "tsl/profiler/utils/tpu_xplane_utils.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h b/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h index daac70f634abca..2f9e5551449cf5 100644 --- a/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h +++ b/tensorflow/core/profiler/convert/xspace_to_dcn_slack_analysis.h @@ -28,13 +28,13 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tensorflow/core/profiler/protobuf/dcn_collective_info.pb.h" #include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" #include "tensorflow/core/profiler/protobuf/topology.pb.h" #include "tensorflow/core/profiler/utils/hlo_proto_map.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/timespan.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/internal/print_model_analysis.cc b/tensorflow/core/profiler/internal/print_model_analysis.cc index e486529069a5a0..60dcd90ea131ab 100644 --- a/tensorflow/core/profiler/internal/print_model_analysis.cc +++ b/tensorflow/core/profiler/internal/print_model_analysis.cc @@ -48,7 +48,7 @@ string RunProfile(const string& command, const string& options, } Options opts; - tensorflow::Status s = Options::FromProtoStr(options, &opts); + absl::Status s = Options::FromProtoStr(options, &opts); if (!s.ok()) { absl::FPrintF(stderr, "%s\n", s.ToString()); return ""; diff --git a/tensorflow/core/profiler/internal/tfprof_code.cc b/tensorflow/core/profiler/internal/tfprof_code.cc index e7e9494361608a..5d7b27c8dd858d 100644 --- a/tensorflow/core/profiler/internal/tfprof_code.cc +++ b/tensorflow/core/profiler/internal/tfprof_code.cc @@ -290,12 +290,12 @@ class PprofProfileImpl : public PprofProfile { samples_->Add(leaf, reversed_call_ids); } - Status WritePprofProfile(const string& filename) override { + absl::Status WritePprofProfile(const string& filename) override { pprof::Profile profile_pb; Build(&profile_pb); std::unique_ptr file; - Status s = Env::Default()->NewWritableFile(filename, &file); + absl::Status s = Env::Default()->NewWritableFile(filename, &file); if (!s.ok()) return s; int32_t buf_size = 1024 * 1024; @@ -517,7 +517,7 @@ const ShowMultiNode* TFCode::ShowInternal(const Options& opts, pprof_profile_ = std::make_unique(&opts); Format(root, root->show_children, opts, &root->formatted_str, root->mutable_proto(), &call_ids); - Status s = pprof_profile_->WritePprofProfile( + absl::Status s = pprof_profile_->WritePprofProfile( opts.output_options.at(kPprofOpts[0])); if (!s.ok()) { absl::FPrintF(stderr, "%s\n", s.ToString()); diff --git a/tensorflow/core/profiler/internal/tfprof_code.h b/tensorflow/core/profiler/internal/tfprof_code.h index cfa8801bcee5e9..5664fb0c74e2b0 100644 --- a/tensorflow/core/profiler/internal/tfprof_code.h +++ b/tensorflow/core/profiler/internal/tfprof_code.h @@ -47,7 +47,7 @@ class PprofProfile { virtual void AddSample(const CodeNode* leaf, std::vector* call_ids) = 0; - virtual Status WritePprofProfile(const string& filename) = 0; + virtual absl::Status WritePprofProfile(const string& filename) = 0; }; class TFCode : public TFMultiShow { diff --git a/tensorflow/core/profiler/internal/tfprof_show.cc b/tensorflow/core/profiler/internal/tfprof_show.cc index 85209e48437e18..d44c5c886e9e14 100644 --- a/tensorflow/core/profiler/internal/tfprof_show.cc +++ b/tensorflow/core/profiler/internal/tfprof_show.cc @@ -38,9 +38,9 @@ const GraphNodeProto& TFShow::Show(const string& prefix, const Options& opts) { absl::PrintF("%s", (prefix + ret->formatted_str)); fflush(stdout); } else if (opts.output_type == kOutput[2]) { - Status s = WriteStringToFile(Env::Default(), - opts.output_options.at(kFileOpts[0]), - prefix + ret->formatted_str); + absl::Status s = WriteStringToFile(Env::Default(), + opts.output_options.at(kFileOpts[0]), + prefix + ret->formatted_str); if (!s.ok()) { absl::FPrintF(stderr, "%s\n", s.ToString()); } diff --git a/tensorflow/core/profiler/internal/tfprof_show_multi.cc b/tensorflow/core/profiler/internal/tfprof_show_multi.cc index 942cf25d145688..970f11140c0469 100644 --- a/tensorflow/core/profiler/internal/tfprof_show_multi.cc +++ b/tensorflow/core/profiler/internal/tfprof_show_multi.cc @@ -43,9 +43,9 @@ const MultiGraphNodeProto& TFMultiShow::Show(const string& prefix, absl::PrintF("%s%s", prefix, ret->formatted_str); fflush(stdout); } else if (opts.output_type == kOutput[2]) { - Status s = WriteStringToFile(Env::Default(), - opts.output_options.at(kFileOpts[0]), - prefix + ret->formatted_str); + absl::Status s = WriteStringToFile(Env::Default(), + opts.output_options.at(kFileOpts[0]), + prefix + ret->formatted_str); if (!s.ok()) { absl::FPrintF(stderr, "%s\n", s.ToString()); } diff --git a/tensorflow/core/profiler/internal/tfprof_stats.cc b/tensorflow/core/profiler/internal/tfprof_stats.cc index 6ca840cf71a12e..3ef535b775047f 100644 --- a/tensorflow/core/profiler/internal/tfprof_stats.cc +++ b/tensorflow/core/profiler/internal/tfprof_stats.cc @@ -82,7 +82,7 @@ TFStats::TFStats(const string& filename, miss_accelerator_stream_(false), ckpt_reader_(std::move(ckpt_reader)) { string str; - Status s = ReadFileToString(Env::Default(), filename, &str); + absl::Status s = ReadFileToString(Env::Default(), filename, &str); if (!s.ok()) { absl::FPrintF(stderr, "Failed to read profile: %s", s.ToString()); return; @@ -351,7 +351,7 @@ void TFStats::SerializeToString(string* content) { void TFStats::WriteProfile(const string& filename) { string content; SerializeToString(&content); - Status s = WriteStringToFile(Env::Default(), filename, content); + absl::Status s = WriteStringToFile(Env::Default(), filename, content); if (!s.ok()) { absl::FPrintF(stderr, "%s\n", s.ToString()); } diff --git a/tensorflow/core/profiler/internal/tfprof_timeline.cc b/tensorflow/core/profiler/internal/tfprof_timeline.cc index 7d8e58e81b44d7..b3a16605d4c416 100644 --- a/tensorflow/core/profiler/internal/tfprof_timeline.cc +++ b/tensorflow/core/profiler/internal/tfprof_timeline.cc @@ -331,7 +331,7 @@ void Timeline::GenerateCodeTimeline(const CodeNode* node) { void Timeline::OutputTimeline() { std::string outfile = absl::StrFormat("%s_%d", outfile_, step()); - Status s = + absl::Status s = WriteStringToFile(Env::Default(), outfile, chrome_formatter_.Format()); if (!s.ok()) { absl::FPrintF(stderr, "Failed to write timeline file: %s\nError: %s\n", diff --git a/tensorflow/core/profiler/internal/tfprof_utils.cc b/tensorflow/core/profiler/internal/tfprof_utils.cc index ad6759629ebf2c..4551e8746a9ea9 100644 --- a/tensorflow/core/profiler/internal/tfprof_utils.cc +++ b/tensorflow/core/profiler/internal/tfprof_utils.cc @@ -84,12 +84,12 @@ string StripQuote(const string& s) { return s.substr(start, end - start + 1); } -tensorflow::Status ReturnError(const std::vector& pieces, int idx) { +absl::Status ReturnError(const std::vector& pieces, int idx) { string val; if (pieces.size() > idx + 1) { val = pieces[idx + 1]; } - return tensorflow::Status( + return absl::Status( absl::StatusCode::kInvalidArgument, absl::StrCat("Invalid option '", pieces[idx], "' value: '", val, "'")); } @@ -115,15 +115,15 @@ bool StringToBool(StringPiece str, bool* value) { } } // namespace -tensorflow::Status ParseCmdLine(const string& line, string* cmd, - tensorflow::tfprof::Options* opts) { +absl::Status ParseCmdLine(const string& line, string* cmd, + tensorflow::tfprof::Options* opts) { std::vector pieces = absl::StrSplit(line, ' ', absl::SkipEmpty()); std::vector cmds_str(kCmds, kCmds + sizeof(kCmds) / sizeof(*kCmds)); if (std::find(cmds_str.begin(), cmds_str.end(), pieces[0]) == cmds_str.end()) { - return tensorflow::Status(absl::StatusCode::kInvalidArgument, - "First string must be a valid command."); + return absl::Status(absl::StatusCode::kInvalidArgument, + "First string must be a valid command."); } *cmd = pieces[0]; @@ -279,7 +279,7 @@ tensorflow::Status ParseCmdLine(const string& line, string* cmd, return ReturnError(pieces, i); } - tensorflow::Status s = + absl::Status s = ParseOutput(pieces[i + 1], &opts->output_type, &opts->output_options); if (!s.ok()) return s; ++i; diff --git a/tensorflow/core/profiler/internal/tfprof_utils.h b/tensorflow/core/profiler/internal/tfprof_utils.h index 057911d5e7d768..7f4e49bac3b3b0 100644 --- a/tensorflow/core/profiler/internal/tfprof_utils.h +++ b/tensorflow/core/profiler/internal/tfprof_utils.h @@ -33,30 +33,30 @@ string FormatMemory(int64_t bytes); string FormatShapes(const std::vector& shapes); -tensorflow::Status ParseCmdLine(const string& line, string* cmd, - tensorflow::tfprof::Options* opts); +absl::Status ParseCmdLine(const string& line, string* cmd, + tensorflow::tfprof::Options* opts); string StringReplace(const string& str, const string& oldsub, const string& newsub); template -Status ReadProtoFile(Env* env, const string& fname, T* proto, - bool binary_first) { +absl::Status ReadProtoFile(Env* env, const string& fname, T* proto, + bool binary_first) { string out; - Status s = ReadFileToString(env, fname, &out); + absl::Status s = ReadFileToString(env, fname, &out); if (!s.ok()) return s; if (binary_first) { if (ReadBinaryProto(tensorflow::Env::Default(), fname, proto).ok()) { - return Status(); + return absl::Status(); } else if (protobuf::TextFormat::ParseFromString(out, proto)) { - return Status(); + return absl::Status(); } } else { if (protobuf::TextFormat::ParseFromString(out, proto)) { - return Status(); + return absl::Status(); } else if (ReadBinaryProto(tensorflow::Env::Default(), fname, proto).ok()) { - return Status(); + return absl::Status(); } } return errors::InvalidArgument("Cannot parse proto file."); diff --git a/tensorflow/core/profiler/lib/BUILD b/tensorflow/core/profiler/lib/BUILD index 55d726723fc752..6b1ca8e6be8744 100644 --- a/tensorflow/core/profiler/lib/BUILD +++ b/tensorflow/core/profiler/lib/BUILD @@ -144,7 +144,7 @@ cc_library( "@local_tsl//tsl/profiler/lib:traceme_encode", ] + if_not_android([ "@local_xla//xla/tsl/profiler/backends/cpu:traceme_recorder", - "@local_tsl//tsl/profiler/utils:time_utils", + "@local_xla//xla/tsl/profiler/utils:time_utils", ]), ) diff --git a/tensorflow/core/profiler/lib/traceme.h b/tensorflow/core/profiler/lib/traceme.h index 51e7e8ba5fbbe7..23e48948095811 100644 --- a/tensorflow/core/profiler/lib/traceme.h +++ b/tensorflow/core/profiler/lib/traceme.h @@ -20,7 +20,7 @@ limitations under the License. #include "tsl/profiler/lib/traceme.h" #if !defined(IS_MOBILE_PLATFORM) -#include "tsl/profiler/utils/time_utils.h" +#include "xla/tsl/profiler/utils/time_utils.h" #endif // TODO: b/323943471 - This macro should eventually be provided by Abseil. diff --git a/tensorflow/core/profiler/protobuf/BUILD b/tensorflow/core/profiler/protobuf/BUILD index 7a79e4a8ba7939..13cce56d193865 100644 --- a/tensorflow/core/profiler/protobuf/BUILD +++ b/tensorflow/core/profiler/protobuf/BUILD @@ -200,28 +200,24 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "xplane_py_pb2", -# api_version = 2, # visibility = [":friends"], # deps = [":xplane_proto"], # ) # # py_proto_library( # name = "memory_viewer_preprocess_py_pb2", -# api_version = 2, # visibility = [":memory_viewer_friends"], # deps = [":memory_viewer_preprocess_proto"], # ) # # py_proto_library( # name = "op_profile_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":op_profile_proto"], # ) # # py_proto_library( # name = "op_metrics_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":op_metrics_proto"], # ) diff --git a/tensorflow/core/profiler/protobuf/memory_viewer_preprocess.proto b/tensorflow/core/profiler/protobuf/memory_viewer_preprocess.proto index 6ebefd183c0a4d..32bd5a7d8d4392 100644 --- a/tensorflow/core/profiler/protobuf/memory_viewer_preprocess.proto +++ b/tensorflow/core/profiler/protobuf/memory_viewer_preprocess.proto @@ -56,6 +56,9 @@ message PreprocessResult { // and dimensionality) at each HLO program point (the HLO sequential order). repeated double unpadded_heap_sizes = 2; + // The HloInstruction that was being processed at this HLO program point. + repeated string hlo_instruction_names = 20; + // Heap objects at the peak memory usage point ordered by HLO program "birth" // time. repeated HeapObject max_heap = 3; diff --git a/tensorflow/core/profiler/rpc/BUILD b/tensorflow/core/profiler/rpc/BUILD index 96d9a50408aa1a..89e9735fbc2190 100644 --- a/tensorflow/core/profiler/rpc/BUILD +++ b/tensorflow/core/profiler/rpc/BUILD @@ -53,9 +53,9 @@ cc_library( "@local_tsl//tsl/profiler/protobuf:profiler_service_cc_grpc_proto", "@local_tsl//tsl/profiler/protobuf:profiler_service_proto_cc", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:file_system_utils", - "@local_tsl//tsl/profiler/utils:time_utils", "@local_xla//xla/tsl/profiler/rpc:profiler_service_impl", + "@local_xla//xla/tsl/profiler/utils:file_system_utils", + "@local_xla//xla/tsl/profiler/utils:time_utils", ] + tf_grpc_cc_dependencies(), ) diff --git a/tensorflow/core/profiler/tfprof_options.cc b/tensorflow/core/profiler/tfprof_options.cc index 0f4ec58c540236..595d4190997baa 100644 --- a/tensorflow/core/profiler/tfprof_options.cc +++ b/tensorflow/core/profiler/tfprof_options.cc @@ -15,9 +15,13 @@ limitations under the License. #include "tensorflow/core/profiler/tfprof_options.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "absl/strings/str_split.h" -#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/tfprof_options.pb.h" namespace tensorflow { diff --git a/tensorflow/core/profiler/tfprof_options.h b/tensorflow/core/profiler/tfprof_options.h index d8704dd736bab7..c1f13ebf355b27 100644 --- a/tensorflow/core/profiler/tfprof_options.h +++ b/tensorflow/core/profiler/tfprof_options.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace tfprof { diff --git a/tensorflow/core/profiler/utils/BUILD b/tensorflow/core/profiler/utils/BUILD index fcd5c26b7e7104..a7fc62c6b90164 100644 --- a/tensorflow/core/profiler/utils/BUILD +++ b/tensorflow/core/profiler/utils/BUILD @@ -43,7 +43,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:timespan", ], ) @@ -56,7 +56,21 @@ cc_library( ":xplane_schema", "//tensorflow/core:lib", "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/strings", + "@local_xla//xla/tsl/profiler/utils:math_utils", + ], +) + +tf_cc_test( + name = "hardware_type_utils_test", + srcs = ["hardware_type_utils_test.cc"], + deps = [ + ":hardware_type_utils", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@local_xla//xla/tsl/profiler/utils:math_utils", ], ) @@ -65,7 +79,7 @@ cc_library( hdrs = ["math_utils.h"], deps = [ "@com_google_absl//absl/base:core_headers", - "@local_tsl//tsl/profiler/utils:math_utils", + "@local_xla//xla/tsl/profiler/utils:math_utils", ], ) @@ -90,9 +104,9 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/profiler/utils:tf_op_utils", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:tf_op_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:xplane_visitor", ], ) @@ -119,8 +133,8 @@ cc_library( "//tensorflow/core/profiler/convert:op_metrics_db_combiner", "//tensorflow/core/profiler/protobuf:op_metrics_proto_cc", "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:tf_op_utils", - "@local_tsl//tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:tf_op_utils", + "@local_xla//xla/tsl/profiler/utils:timespan", ], ) @@ -129,7 +143,7 @@ cc_library( hdrs = ["trace_utils.h"], copts = tf_profiler_copts(), deps = [ - "@local_tsl//tsl/profiler/utils:trace_utils", + "@local_xla//xla/tsl/profiler/utils:trace_utils", ], ) @@ -139,7 +153,7 @@ cc_library( copts = tf_profiler_copts(), visibility = [":friends"], deps = [ - "@local_tsl//tsl/profiler/utils:xplane_builder", + "@local_xla//xla/tsl/profiler/utils:xplane_builder", ], ) @@ -150,7 +164,7 @@ cc_library( visibility = [":friends"], deps = [ "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:xplane_schema", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", ], ) @@ -160,7 +174,7 @@ cc_library( copts = tf_profiler_copts(), visibility = [":friends"], deps = [ - "@local_tsl//tsl/profiler/utils:xplane_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_utils", ], ) @@ -176,7 +190,7 @@ cc_library( "//tensorflow/core:lib", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:variant", - "@local_tsl//tsl/profiler/utils:xplane_test_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_test_utils", ], ) @@ -186,7 +200,7 @@ cc_library( copts = tf_profiler_copts(), visibility = [":friends"], deps = [ - "@local_tsl//tsl/profiler/utils:xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:xplane_visitor", ], ) @@ -208,7 +222,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/profiler/utils:tf_op_utils", + "@local_xla//xla/tsl/profiler/utils:tf_op_utils", ], ) @@ -227,8 +241,8 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/profiler/utils:timespan", "@local_xla//xla:shape_util", + "@local_xla//xla/tsl/profiler/utils:timespan", ], ) @@ -255,14 +269,14 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:group_events", - "@local_tsl//tsl/profiler/utils:tf_op_utils", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:timespan", - "@local_tsl//tsl/profiler/utils:tpu_xplane_utils", - "@local_tsl//tsl/profiler/utils:trace_utils", - "@local_tsl//tsl/profiler/utils:xplane_schema", "@local_xla//xla/tsl/profiler/convert:xla_op_utils", + "@local_xla//xla/tsl/profiler/utils:group_events", + "@local_xla//xla/tsl/profiler/utils:tf_op_utils", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:tpu_xplane_utils", + "@local_xla//xla/tsl/profiler/utils:trace_utils", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", "@local_xla//xla/tsl/util:stats_calculator_portable", ], ) @@ -282,9 +296,10 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:group_events", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:xplane_schema", + "@com_google_googletest//:gtest_main", + "@local_xla//xla/tsl/profiler/utils:group_events", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:xplane_schema", ], ) @@ -343,7 +358,7 @@ cc_library( "//tensorflow/core/platform:types", "//tensorflow/core/profiler/protobuf:steps_db_proto_cc", "@com_google_absl//absl/container:flat_hash_map", - "@local_tsl//tsl/profiler/utils:timespan", + "@local_xla//xla/tsl/profiler/utils:timespan", ], ) @@ -371,7 +386,7 @@ cc_library( ":xplane_visitor", "//tensorflow/core/profiler/protobuf:hardware_types_proto_cc", "//tensorflow/core/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", ], ) @@ -404,9 +419,9 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", "@local_xla//xla/service:hlo_proto_cc", "@local_xla//xla/tsl/profiler/convert:xla_op_utils", + "@local_xla//xla/tsl/profiler/utils:tf_xplane_visitor", ], ) @@ -433,11 +448,13 @@ tf_cuda_library( ], visibility = [":friends"], deps = [ + ":hlo_proto_map", ":hlo_proto_to_module", "//tensorflow/core/platform:path", "//tensorflow/core/profiler/lib:traceme_encode", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", "@local_xla//xla/hlo/ir:hlo", "@local_xla//xla/service:hlo_cost_analysis", "@local_xla//xla/service:hlo_proto_cc", @@ -449,3 +466,37 @@ cc_library( hdrs = ["hlo_module_utils.h"], deps = ["@local_xla//xla/hlo/ir:hlo"], ) + +cc_library( + name = "xprof_gpu_cost_analysis", + srcs = ["xprof_gpu_cost_analysis.cc"], + hdrs = ["xprof_gpu_cost_analysis.h"], + visibility = [":friends"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@local_xla//xla:shape_util", + "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/service:hlo_cost_analysis", + "@local_xla//xla/service/gpu/model:gpu_hlo_cost_analysis", + ], +) + +tf_cc_test( + name = "xprof_gpu_cost_analysis_test", + srcs = ["xprof_gpu_cost_analysis_test.cc"], + deps = [ + ":xprof_gpu_cost_analysis", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_xla//xla:shape_util", + "@local_xla//xla:test_helpers", + "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/service:hlo_cost_analysis", + "@local_xla//xla/service/gpu/model:hlo_op_profiles", + "@local_xla//xla/tests:hlo_test_base", + "@local_xla//xla/tests:xla_internal_test_main", + ], +) diff --git a/tensorflow/core/profiler/utils/cost_utils.cc b/tensorflow/core/profiler/utils/cost_utils.cc index 2cbd2590f0c525..f1899f17ac30fd 100644 --- a/tensorflow/core/profiler/utils/cost_utils.cc +++ b/tensorflow/core/profiler/utils/cost_utils.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/strings/strip.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/tf_op_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/derived_timeline.cc b/tensorflow/core/profiler/utils/derived_timeline.cc index a895af6a00b259..9da9ac89efe9e9 100644 --- a/tensorflow/core/profiler/utils/derived_timeline.cc +++ b/tensorflow/core/profiler/utils/derived_timeline.cc @@ -26,6 +26,13 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/tsl/profiler/convert/xla_op_utils.h" +#include "xla/tsl/profiler/utils/group_events.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" +#include "xla/tsl/profiler/utils/trace_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "xla/tsl/util/stats_calculator.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" @@ -40,13 +47,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/group_events.h" -#include "tsl/profiler/utils/tf_op_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/timespan.h" -#include "tsl/profiler/utils/tpu_xplane_utils.h" -#include "tsl/profiler/utils/trace_utils.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tensorflow { namespace profiler { @@ -140,6 +140,10 @@ DerivedXLineBuilder::DerivedXLineBuilder( int64_t timestamp_ns, std::vector dependent_lines) : group_id_stat_metadata_( plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId))), + correlation_id_metadata_(plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kCorrelationId))), + cuda_graph_id_metadata_(plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kCudaGraphId))), line_(plane->GetOrCreateLine(line_id)), dependent_lines_(std::move(dependent_lines)) { line_.SetName(name); @@ -185,13 +189,31 @@ void DerivedXLineBuilder::ExpandOrAddLevelEvent( } } +void DerivedXLineBuilder::AddStatToLevelEvent(int level, + const XStatMetadata& metadata, + int64_t value) { + if (auto it = last_event_by_level_.find(level); + it != last_event_by_level_.end() && it->second.has_value()) { + it->second->SetOrAddStatValue(metadata, value); + } +} + +void DerivedXLineBuilder::AddStatToLevelEvent(int level, + const XStatMetadata& metadata, + uint64_t value) { + if (auto it = last_event_by_level_.find(level); + it != last_event_by_level_.end() && it->second.has_value()) { + it->second->SetOrAddStatValue(metadata, value); + } +} + // When deriving a bunch of events with the same timespan, there could be // indeterministic behavior of how trace viewer stacking these events. // This function will shrink the stack of events with the same timespan when -// necessary. Event at top of stack might shrink more than event at the bottom. -// Because the time unit in trace viewer is nanosecond, therefore the minimum -// difference is 1ns. However to prevent shrink induced inconsitency, we can -// not shrink more than the duration of event at the top of the stack. +// necessary. Event at top of stack might shrink more than event at the +// bottom. Because the time unit in trace viewer is nanosecond, therefore the +// minimum difference is 1ns. However to prevent shrink induced inconsitency, +// we can not shrink more than the duration of event at the top of the stack. void DerivedXLineBuilder::AdjustDurationForTraceViewer(int level) { if (level >= last_event_by_level_.size() || !last_event_by_level_[level]) return; @@ -286,8 +308,8 @@ void DeriveEventsFromAnnotations(const SymbolResolver& symbol_resolver, GetSortedEvents(plane_visitor)) { GpuEventStats stats(&event); // For HLO/TF op lines, only use kernel events (i.e. excluding memcpy or - // allocation events). Also CudaGraph executions are also treated as kernel - // events. + // allocation events). Also CudaGraph executions are also treated as + // kernel events. if (!stats.IsKernel() && !stats.IsCudaGraphExecution()) continue; tsl::profiler::Timespan event_span = event.GetTimespan(); @@ -300,9 +322,26 @@ void DeriveEventsFromAnnotations(const SymbolResolver& symbol_resolver, if (stats.IsXlaOp()) { auto symbol = symbol_resolver(stats.program_id, stats.hlo_module_name, stats.hlo_op_names.back()); - hlo_ops.ExpandOrAddEvents( - GetOrCreateHloOpEventsMetadata(plane_builder, stats, symbol), - event_span, stats.group_id); + auto hlo_events_metadata = + GetOrCreateHloOpEventsMetadata(plane_builder, stats, symbol); + hlo_ops.ExpandOrAddEvents(hlo_events_metadata, event_span, + stats.group_id); + // If the kernel event is nodes of a CudaGraph or a whole cuda graph + // exec, try to mark extra stats to to corresponding XLA op event here. + if (stats.cuda_graph_id_for_inner_node.has_value() && + *stats.cuda_graph_id_for_inner_node != 0) { + int level = static_cast(hlo_events_metadata.size()) - 1; + if (level >= 0) { + hlo_ops.AddStatToLevelEvent(level, *hlo_ops.GetCudaGraphIdMetadata(), + *stats.cuda_graph_id_for_inner_node); + if (stats.correlation_id.has_value()) { + hlo_ops.AddStatToLevelEvent(level, + *hlo_ops.GetCorrelationIdMetadata(), + *stats.correlation_id); + } + } + } + if (!symbol.tf_op_name.empty()) { ProcessTfOpEvent(symbol.tf_op_name, event_span, stats.group_id, plane_builder, diff --git a/tensorflow/core/profiler/utils/derived_timeline.h b/tensorflow/core/profiler/utils/derived_timeline.h index 72583c2a79d772..7fd06c9ab2f42a 100644 --- a/tensorflow/core/profiler/utils/derived_timeline.h +++ b/tensorflow/core/profiler/utils/derived_timeline.h @@ -24,10 +24,11 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/group_events.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/group_events.h" -#include "tsl/profiler/utils/timespan.h" +#include "tsl/profiler/protobuf/xplane.pb.h" namespace tensorflow { namespace profiler { @@ -46,6 +47,11 @@ class DerivedXEventBuilder { event_.SetTimespan(event_span); } + template + void SetOrAddStatValue(const XStatMetadata& metadata, ValueT&& value) { + event_.SetOrAddStatValue(metadata, std::forward(value)); + } + private: XEventBuilder event_; std::optional group_id_; @@ -79,6 +85,22 @@ class DerivedXLineBuilder { // Reset the last events lower than or equal to the given level. void ResetLastEvents(int level = 0); + // To avoid using templates while need hide its implementation in .cc file, + // use two functions to set stat value for int64_t and uint64_t here. + void AddStatToLevelEvent(int level, const XStatMetadata& metadata, + int64_t value); + + void AddStatToLevelEvent(int level, const XStatMetadata& metadata, + uint64_t value); + + const XStatMetadata* GetCorrelationIdMetadata() const { + return correlation_id_metadata_; + } + + const XStatMetadata* GetCudaGraphIdMetadata() const { + return cuda_graph_id_metadata_; + } + private: // If the last event of the given level has the same metadata, expands it to // include the time until the given event's end time. @@ -92,6 +114,9 @@ class DerivedXLineBuilder { void AdjustDurationForTraceViewer(int level); const XStatMetadata* group_id_stat_metadata_ = nullptr; + const XStatMetadata* correlation_id_metadata_ = nullptr; + const XStatMetadata* cuda_graph_id_metadata_ = nullptr; + XLineBuilder line_; absl::flat_hash_map> last_event_by_level_; diff --git a/tensorflow/core/profiler/utils/derived_timeline_test.cc b/tensorflow/core/profiler/utils/derived_timeline_test.cc index f1f0daca282358..edda0d4673cdda 100644 --- a/tensorflow/core/profiler/utils/derived_timeline_test.cc +++ b/tensorflow/core/profiler/utils/derived_timeline_test.cc @@ -17,8 +17,13 @@ limitations under the License. #include #include +#include +#include #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/group_events.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" @@ -27,9 +32,6 @@ limitations under the License. #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_test_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/group_events.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tensorflow { namespace profiler { @@ -291,6 +293,64 @@ TEST(DerivedTimelineTest, TfOpNameScopeShrinkTest) { } } +// Checks that XLA Ops mapping to CudaGraph launch has extra stats. +TEST(DerivedTimelineTest, XloOpHasCudaGraphStats) { + constexpr absl::string_view kModuleName = "module"; + constexpr absl::string_view kHloOpName = "op_level_2"; + constexpr absl::string_view kKernelDetails = "kernel_details"; + constexpr int64_t kGroupIdValue = 1; + constexpr int64_t kCorrelationIdValue = 10000; + const uint64_t kCudaGraphIdValue = 20; + XSpace space; + tsl::profiler::GroupMetadataMap group_metadata_map; + + // Build Input Plane/Line/Events and derive events from them. + XPlane& plane = *GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0); + XPlaneBuilder plane_builder(&plane); + auto line_builder = plane_builder.GetOrCreateLine(0); + CreateXEvent(&plane_builder, &line_builder, "op1", 0, 100, + {{StatType::kKernelDetails, kKernelDetails}, + {StatType::kGroupId, kGroupIdValue}, + {StatType::kHloModule, kModuleName}, + {StatType::kHloOp, kHloOpName}, + {StatType::kCorrelationId, kCorrelationIdValue}, + {StatType::kCudaGraphId, kCudaGraphIdValue}}); + CreateXEvent(&plane_builder, &line_builder, "op2", 200, 300, + {{StatType::kKernelDetails, kKernelDetails}, + {StatType::kGroupId, kGroupIdValue}, + {StatType::kHloModule, kModuleName}, + {StatType::kHloOp, kHloOpName}, + {StatType::kCorrelationId, kCorrelationIdValue}, + {StatType::kCudaGraphId, kCudaGraphIdValue}}); + GenerateDerivedTimeLines(group_metadata_map, &space); + + // Check that the HLO op line is added and has the extra stats for the first + // derived event. + size_t num_hlo_op_line = 0; + size_t num_events = 0; + std::optional correlation_id; + std::optional cuda_graph_id; + XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(&plane); + plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { + if (line_visitor.Id() == kThreadIdHloOp) { + num_hlo_op_line++; + if (num_hlo_op_line == 1) { + num_events = line_visitor.NumEvents(); + line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) { + correlation_id = event_visitor.GetStat(StatType::kCorrelationId); + cuda_graph_id = event_visitor.GetStat(StatType::kCudaGraphId); + }); + } + } + }); + EXPECT_EQ(num_hlo_op_line, 1); + EXPECT_EQ(num_events, 1); + ASSERT_TRUE(correlation_id.has_value()); + EXPECT_EQ(correlation_id->IntValue(), kCorrelationIdValue); + ASSERT_TRUE(cuda_graph_id.has_value()); + EXPECT_EQ(cuda_graph_id->UintValue(), kCudaGraphIdValue); +} + TEST(DerivedTimelineTest, DeriveLinesForXlaCpuOps) { XPlane xplane; XPlaneBuilder plane_builder(&xplane); diff --git a/tensorflow/core/profiler/utils/device_caps_utils.cc b/tensorflow/core/profiler/utils/device_caps_utils.cc index 5e8edea62493f8..b01bc35f72b3bc 100644 --- a/tensorflow/core/profiler/utils/device_caps_utils.cc +++ b/tensorflow/core/profiler/utils/device_caps_utils.cc @@ -15,10 +15,10 @@ limitations under the License. #include "tensorflow/core/profiler/utils/device_caps_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "tensorflow/core/profiler/utils/xplane_builder.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" namespace tensorflow { namespace profiler { @@ -81,7 +81,6 @@ DeviceCapabilities GetDeviceCaps(const XPlane& plane) { break; } }); - return caps; } diff --git a/tensorflow/core/profiler/utils/event_span.cc b/tensorflow/core/profiler/utils/event_span.cc index cc9e2ed044361b..27ddddf1e4d195 100644 --- a/tensorflow/core/profiler/utils/event_span.cc +++ b/tensorflow/core/profiler/utils/event_span.cc @@ -22,10 +22,10 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/event_span.h b/tensorflow/core/profiler/utils/event_span.h index 20c8643c5df722..4100390b88959b 100644 --- a/tensorflow/core/profiler/utils/event_span.h +++ b/tensorflow/core/profiler/utils/event_span.h @@ -21,10 +21,10 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" #include "tensorflow/core/profiler/protobuf/steps_db.pb.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/gpu_event_stats.cc b/tensorflow/core/profiler/utils/gpu_event_stats.cc index c5b880c4498fe2..cd81aea0842dd8 100644 --- a/tensorflow/core/profiler/utils/gpu_event_stats.cc +++ b/tensorflow/core/profiler/utils/gpu_event_stats.cc @@ -70,6 +70,9 @@ GpuEventStats::GpuEventStats(const XEventVisitor* event) { case StatType::kCudaGraphExecId: cuda_graph_exec_id = stat.UintValue(); break; + case StatType::kCudaGraphId: + cuda_graph_id_for_inner_node = stat.UintValue(); + break; default: break; } diff --git a/tensorflow/core/profiler/utils/gpu_event_stats.h b/tensorflow/core/profiler/utils/gpu_event_stats.h index 8b9ac5ae75c62d..7740e41ce5ac9c 100644 --- a/tensorflow/core/profiler/utils/gpu_event_stats.h +++ b/tensorflow/core/profiler/utils/gpu_event_stats.h @@ -56,7 +56,8 @@ struct GpuEventStats { // Stats derived by grouping. std::optional group_id; bool is_eager = false; - std::optional cuda_graph_exec_id; + std::optional cuda_graph_exec_id; + std::optional cuda_graph_id_for_inner_node; }; // Stats for a host-side GPU launch XEvent. diff --git a/tensorflow/core/profiler/utils/hardware_type_utils.cc b/tensorflow/core/profiler/utils/hardware_type_utils.cc index ad3682ad99289c..85cdb13a03ccea 100644 --- a/tensorflow/core/profiler/utils/hardware_type_utils.cc +++ b/tensorflow/core/profiler/utils/hardware_type_utils.cc @@ -15,7 +15,11 @@ limitations under the License. #include "tensorflow/core/profiler/utils/hardware_type_utils.h" +#include + +#include "absl/container/btree_map.h" #include "absl/strings/match.h" +#include "xla/tsl/profiler/utils/math_utils.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/hardware_types.pb.h" @@ -25,82 +29,262 @@ namespace tensorflow { namespace profiler { namespace { -// Get theoretical upperbound of single precision FMA throughput of the GPU per -// cycle per streaming multiprocessor. -// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#arithmetic-instructions__throughput-native-arithmetic-instructions -uint32 GetFmaMaxThroughputPerSMPerCycle(const DeviceCapabilities& device_cap) { - if (device_cap.device_vendor() == kDeviceVendorNvidia) { - uint32 n_fp32_cores = 0; - uint32 n_tc_cores = 0; - switch (device_cap.compute_capability().major()) { - case 2: - // Fermi - n_fp32_cores = 32; - break; - case 3: - // Kepler - n_fp32_cores = 192; - break; - case 5: - // Maxwell - n_fp32_cores = 128; - break; - case 6: - // Pascal - if (device_cap.compute_capability().minor() > 0) { - // Pascal SM61/62 - n_fp32_cores = 128; - } else { - // Pascal SM60 - n_fp32_cores = 64; - } - break; - case 7: - // Volta and Turing - n_fp32_cores = 64; - n_tc_cores = 8; - break; - case 8: - // Ampere - if (device_cap.compute_capability().minor() >= 6) { - // Ampere SM86 - n_fp32_cores = 128; - } else { - // Ampere SM80 - n_fp32_cores = 64; - } - n_tc_cores = 4; - break; - default: - LOG(ERROR) << "Invalid GPU compute capability."; - break; - } - // GPU TensorCore can execute 64 FMAs per cycle. - // https://devblogs.nvidia.com/programming-tensor-cores-cuda-9/ - return n_fp32_cores + n_tc_cores * 64; - } else if (device_cap.device_vendor() == kDeviceVendorAMD) { - uint32_t n_xdlops = 0; - uint32_t n_fp32_cores = 0; +// The calculation methods is referred from Nvidia developer forum: +// https://forums.developer.nvidia.com/t/how-to-calculate-the-tensor-core-fp16-performance-of-h100/244727 +// Below data are calculated from the various NVidia whitepapers/specs. + +// https://resources.nvidia.com/en-us-tensor-core/gtc22-whitepaper-hopper +const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_9_0 = { + .cuda_core = + { + .fp64_tflops = 128, + .fp32_tflops = 256, + .bf16_tflops = 512, + .fp16_tflops = 512, + .int8_tops = 1024, + }, + .tensor_core = + { + .fp64_tflops = 256, + .fp32_tflops = 2048, + .bf16_tflops = 4096, + .fp16_tflops = 4096, + .fp8_tflops = 8192, + .int8_tops = 8192, + }, + .has_tensor_core_sparsity_support = true, +}; + +// https://images.nvidia.com/aem-dam/Solutions/geforce/ada/nvidia-ada-gpu-architecture.pdf +const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_8_9 = { + .cuda_core = + { + .fp64_tflops = 128, + .fp32_tflops = 256, + .bf16_tflops = 256, + .fp16_tflops = 256, + .int8_tops = 512, + }, + .tensor_core = + { + .fp32_tflops = 512, + .bf16_tflops = 1024, + .fp16_tflops = 1024, + .fp8_tflops = 2048, + .int8_tops = 2048, + .int4_tops = 4096, + }, + .has_tensor_core_sparsity_support = true, +}; + +// https://www.nvidia.com/content/PDF/nvidia-ampere-ga-102-gpu-architecture-whitepaper-v2.1.pdf +const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_8_6 = { + .cuda_core = + { + .fp64_tflops = 128, + .fp32_tflops = 256, + .bf16_tflops = 256, + .fp16_tflops = 256, + .int8_tops = 512, + }, + .tensor_core = + { + .fp32_tflops = 256, + .bf16_tflops = 512, + .fp16_tflops = 1024, + .int8_tops = 2048, + .int4_tops = 4096, + }, + .has_tensor_core_sparsity_support = true, +}; - if (device_cap.compute_capability().major() <= 9) { - n_fp32_cores = 64; - } else { - n_fp32_cores = 32; +// https://www.nvidia.com/content/PDF/nvidia-ampere-ga-102-gpu-architecture-whitepaper-v2.1.pdf +const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_8_0 = { + .cuda_core = + { + .fp64_tflops = 64, + .fp32_tflops = 128, + .bf16_tflops = 256, + .fp16_tflops = 512, + .int8_tops = 512, + }, + .tensor_core = + { + .fp64_tflops = 128, + .fp32_tflops = 1024, + .bf16_tflops = 2048, + .fp16_tflops = 2048, + .int8_tops = 4096, + }, + .has_tensor_core_sparsity_support = true, +}; + +// https://images.nvidia.com/aem-dam/en-zz/Solutions/design-visualization/technologies/turing-architecture/NVIDIA-Turing-Architecture-Whitepaper.pdf +const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_7_5 = { + .cuda_core = + { + .fp64_tflops = 64, + .fp32_tflops = 128, + .fp16_tflops = 256, + .int8_tops = 512, + }, + .tensor_core = + { + .fp16_tflops = 1024, + .int8_tops = 2048, + .int4_tops = 4096, + }, + .has_tensor_core_sparsity_support = false, +}; + +// https://images.nvidia.com/content/volta-architecture/pdf/volta-architecture-whitepaper.pdf +const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_7_0 = { + .cuda_core = + { + .fp64_tflops = 64, + .fp32_tflops = 128, + .bf16_tflops = 0.0, + .fp16_tflops = 256, + .int8_tops = 512, + }, + .tensor_core = + { + .fp16_tflops = 1024, + }, + .has_tensor_core_sparsity_support = false, +}; + +// https://images.nvidia.com/content/pdf/tesla/whitepaper/pascal-architecture-whitepaper.pdf +const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_6_1 = { + .cuda_core = + { + .fp64_tflops = 8, + .fp32_tflops = 256, + .fp16_tflops = 4, + .int8_tops = 1024, + }, + .tensor_core = {}, + .has_tensor_core_sparsity_support = false, +}; + +// https://images.nvidia.com/content/pdf/tesla/whitepaper/pascal-architecture-whitepaper.pdf +const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_6_0 = { + .cuda_core = + { + .fp64_tflops = 64, + .fp32_tflops = 128, + .fp16_tflops = 256, + .int8_tops = 512, + }, + .tensor_core = {}, + .has_tensor_core_sparsity_support = false, +}; + +// https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-product-literature/NVIDIA-Kepler-GK110-GK210-Architecture-Whitepaper.pdf +const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_5_0 = { + .cuda_core = + { + .fp64_tflops = 4, + .fp32_tflops = 256, + }, + .tensor_core = {}, + .has_tensor_core_sparsity_support = false, +}; + +// https://www.nvidia.com/content/PDF/product-specifications/GeForce_GTX_680_Whitepaper_FINAL.pdf +const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_3_0 = { + .cuda_core = + { + .fp64_tflops = 128, + .fp32_tflops = 384, + }, + .tensor_core = {}, + .has_tensor_core_sparsity_support = false, +}; + +const GpuFlopCapabilities kComputeCap_PerSM_PerCycle_2_0 = { + .cuda_core = + { + .fp64_tflops = 8, + .fp32_tflops = 64, + }, + .tensor_core = {}, + .has_tensor_core_sparsity_support = false, +}; + +GpuFlopCapabilities GetNvidiaFlopCapsPerSMPerCycle(int major_comp_cap, + int minor_comp_cap) { + static const auto& kPerSMFlopCapsTable = + *new absl::btree_map{ + // TODO: Add incoming blackwell, and other old GPUS + {9000, &kComputeCap_PerSM_PerCycle_9_0}, + {8090, &kComputeCap_PerSM_PerCycle_8_9}, + {8060, &kComputeCap_PerSM_PerCycle_8_6}, + {8000, &kComputeCap_PerSM_PerCycle_8_0}, + {7050, &kComputeCap_PerSM_PerCycle_7_5}, + {7000, &kComputeCap_PerSM_PerCycle_7_0}, + {6010, &kComputeCap_PerSM_PerCycle_6_1}, + {6000, &kComputeCap_PerSM_PerCycle_6_0}, + {5000, &kComputeCap_PerSM_PerCycle_5_0}, + {3000, &kComputeCap_PerSM_PerCycle_3_0}, + {2000, &kComputeCap_PerSM_PerCycle_2_0}, + }; + + const int normalized_compute_cap = + major_comp_cap * 1000 + minor_comp_cap * 10; + GpuFlopCapabilities flops_cap{}; + auto it = kPerSMFlopCapsTable.lower_bound(normalized_compute_cap); + if (it == kPerSMFlopCapsTable.end()) { + LOG(WARNING) << "GPU compute capability " << major_comp_cap << "." + << minor_comp_cap << " is too old to support."; + } else { + flops_cap = *it->second; + if (it->first != normalized_compute_cap) { + LOG(WARNING) << "GPU compute capability " << major_comp_cap << "." + << minor_comp_cap + << " is not found. Use the highest compute cap known " + << (it->first / 1000) << "." << ((it->first % 1000) / 10) + << " instead."; } - // TODO(rocm-profiler): verify with new devices - return n_fp32_cores + n_xdlops * 1; + } + return flops_cap; +} + +GpuFlopCapabilities GetGpuFlopCapabilitiesPerSM( + const DeviceCapabilities& device_cap) { + GpuFlopCapabilities flops_cap{}; + if (device_cap.device_vendor() == kDeviceVendorNvidia) { + flops_cap = + GetNvidiaFlopCapsPerSMPerCycle(device_cap.compute_capability().major(), + device_cap.compute_capability().minor()); } else { - LOG(ERROR) << "Unknown device vendor " << device_cap.device_vendor(); - return 0; + LOG(WARNING) << "Unsupported device vendor " << device_cap.device_vendor(); } + + flops_cap.ScaleWith(device_cap.clock_rate_in_ghz()); + return flops_cap; } } // namespace double GetFlopMaxThroughputPerSM(const DeviceCapabilities& device_cap) { - // One FMA = 2 floating point operations, one multiply and one add. - return GetFmaMaxThroughputPerSMPerCycle(device_cap) * 2 * - device_cap.clock_rate_in_ghz(); + GpuFlopCapabilities sm_flops = GetGpuFlopCapabilitiesPerSM(device_cap); + double result = std::max( + {sm_flops.cuda_core.fp32_tflops, sm_flops.cuda_core.fp16_tflops, + sm_flops.tensor_core.fp32_tflops, sm_flops.tensor_core.fp16_tflops}); + VLOG(3) << "GetFlopMaxThroughputPerSM get result: " << result << " GFLOPs"; + return result; +} + +double GetSharedMemoryBandwidthPerSM(const DeviceCapabilities& device_cap) { + // https://docs.nvidia.com/gameworks/content/developertools/desktop/analysis/report/cudaexperiments/kernellevel/memorystatisticsshared.htm + // Compute capability 2.0, each bank has bandwidth of 4 bytes per 2 cycles. + // For compute capability 3.0 and above, each bank has bandwidth 8 bytes per + // cycle. Each SM has 32 banks. + double transaction_byts_per_cycle = + device_cap.compute_capability().major() <= 2 ? (32 * 4 / 2) : (32 * 8); + double GiBPS = transaction_byts_per_cycle * device_cap.clock_rate_in_ghz(); + return tsl::profiler::GigaToUni(GiBPS); } absl::string_view GpuModelName(const DeviceCapabilities& device_cap) { diff --git a/tensorflow/core/profiler/utils/hardware_type_utils.h b/tensorflow/core/profiler/utils/hardware_type_utils.h index 894b8c5753805e..41b1bd4b65471c 100644 --- a/tensorflow/core/profiler/utils/hardware_type_utils.h +++ b/tensorflow/core/profiler/utils/hardware_type_utils.h @@ -22,10 +22,48 @@ limitations under the License. namespace tensorflow { namespace profiler { +struct GpuFlopCapabilities { + struct FlopCapabilityOnPrecisions { + double fp64_tflops = 0; + double fp32_tflops = 0; // also for tf32 for nvidia tensor core + double bf16_tflops = 0; + double fp16_tflops = 0; + double fp8_tflops = 0; + double int8_tops = 0; + double fp4_tflops = 0; + double int4_tops = 0; + + void ScaleWith(double scale) { + fp64_tflops *= scale; + fp32_tflops *= scale; + bf16_tflops *= scale; + fp16_tflops *= scale; + fp8_tflops *= scale; + int8_tops *= scale; + fp4_tflops *= scale; + int4_tops *= scale; + } + }; + + FlopCapabilityOnPrecisions cuda_core; + FlopCapabilityOnPrecisions tensor_core; + bool has_tensor_core_sparsity_support = false; + + void ScaleWith(double scale) { + cuda_core.ScaleWith(scale); + tensor_core.ScaleWith(scale); + } +}; + // Get peak single precision throughput of the GPU in GFLOPS per // streaming multiprocessor. +// TODO: Need design on how to use the sparsity capability of FLOPs. double GetFlopMaxThroughputPerSM(const DeviceCapabilities& device_cap); +// for Nvidia GPU, return shared memory bandwidth in Bytes Per Second on +// one single SM given the GPU core freq in device_cap. +double GetSharedMemoryBandwidthPerSM(const DeviceCapabilities& device_cap); + // Returns the GPU model name from the given DeviceCapabilities. // For nvidia GPUs, the name is like "Nvidia GPU (Kepler)" or "Nvidia GPU // (Turing)". For AMD GPUs, the name is like "AMD GPU - gfx-10XX series". diff --git a/tensorflow/core/profiler/utils/hardware_type_utils_test.cc b/tensorflow/core/profiler/utils/hardware_type_utils_test.cc new file mode 100644 index 00000000000000..9476848a650dcc --- /dev/null +++ b/tensorflow/core/profiler/utils/hardware_type_utils_test.cc @@ -0,0 +1,66 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/profiler/utils/hardware_type_utils.h" + +#include "xla/tsl/profiler/utils/math_utils.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace profiler { +namespace { + +TEST(HardwareTypeUtilsTest, H100PeakComputTFlops) { + DeviceCapabilities device_cap; + // For NVIDIA H100 PCIe 80 GB, according to + // https://resources.nvidia.com/en-us-data-center-overview/gtc22-whitepaper-hopper + // https://www.techpowerup.com/gpu-specs/h100-pcie-80-gb.c3899 + device_cap.set_clock_rate_in_ghz(1.620); + device_cap.set_num_cores(114); + device_cap.set_memory_size_in_bytes( + tsl::profiler::GibiToGiga(tsl::profiler::GigaToUni(80))); + device_cap.set_memory_bandwidth(tsl::profiler::GigaToUni(2.04 * 1024)); + device_cap.set_device_vendor("Nvidia"); + device_cap.mutable_compute_capability()->set_major(9); + device_cap.mutable_compute_capability()->set_minor(0); + + // Get target TFLOPS per SM and check. + double peak_tflops = + GetFlopMaxThroughputPerSM(device_cap) * device_cap.num_cores() / 1000.0; + EXPECT_NEAR(peak_tflops, 756, /*abs_error=*/1.0); +} + +TEST(HardwareTypeUtilsTest, A100PeakComputTFlops) { + DeviceCapabilities device_cap; + // For NVIDIA A100 SXM4 80 GB, according to: + // https://images.nvidia.com/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf + // https://www.techpowerup.com/gpu-specs/a100-sxm4-80-gb.c3746 + device_cap.set_clock_rate_in_ghz(1.410); + device_cap.set_num_cores(108); + device_cap.set_memory_size_in_bytes( + tsl::profiler::GibiToGiga(tsl::profiler::GigaToUni(80))); + device_cap.set_memory_bandwidth(tsl::profiler::GigaToUni(2.04 * 1024)); + device_cap.set_device_vendor("Nvidia"); + device_cap.mutable_compute_capability()->set_major(8); + device_cap.mutable_compute_capability()->set_minor(0); + + double peak_tflops = + GetFlopMaxThroughputPerSM(device_cap) * device_cap.num_cores() / 1000.0; + EXPECT_NEAR(peak_tflops, 312, /*abs_error=*/1.0); +} + +} // namespace +} // namespace profiler +} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/hlo_module_map.cc b/tensorflow/core/profiler/utils/hlo_module_map.cc index 0fbc48011e84f8..e167c0e47f5136 100644 --- a/tensorflow/core/profiler/utils/hlo_module_map.cc +++ b/tensorflow/core/profiler/utils/hlo_module_map.cc @@ -27,6 +27,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/profiler/lib/traceme_encode.h" +#include "tensorflow/core/profiler/utils/hlo_proto_map.h" #include "tensorflow/core/profiler/utils/hlo_proto_to_module.h" namespace tensorflow { @@ -63,7 +64,14 @@ HloModuleWrapper::HloModuleWrapper( HloModuleWrapper::HloModuleWrapper( std::unique_ptr module, std::function shape_func) - : module_(std::move(module)) { + : HloModuleWrapper(module.get(), shape_func) { + owned_module_ = std::move(module); +} + +HloModuleWrapper::HloModuleWrapper( + const xla::HloModule* module, + std::function shape_func) + : module_(module) { if (module_ == nullptr) return; const xla::HloCostAnalysis* cost_analysis = nullptr; @@ -88,6 +96,7 @@ HloModuleWrapper::HloModuleWrapper( for (const xla::HloComputation* computation : module_->computations()) { for (const xla::HloInstruction* instr : computation->instructions()) { + if (instructions_by_name_.contains(instr->name())) continue; instructions_by_name_.try_emplace( instr->name(), HloInstructionWrapper(instr, cost_analysis)); } @@ -122,5 +131,12 @@ void AddHloProto(HloModuleMap& hlo_module_map, uint64_t program_id, /*shape_func=*/nullptr)); } +void ProcessHloModuleMapFromXSpace(HloModuleMap& hlo_module_map, + const XSpace* space) { + for (auto& [program_id, hlo_proto] : ParseHloProtosFromXSpace(*space)) { + AddHloProto(hlo_module_map, program_id, *hlo_proto); + } +} + } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/hlo_module_map.h b/tensorflow/core/profiler/utils/hlo_module_map.h index d3525ca014aee7..d5c796e24ffe62 100644 --- a/tensorflow/core/profiler/utils/hlo_module_map.h +++ b/tensorflow/core/profiler/utils/hlo_module_map.h @@ -45,6 +45,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_cost_analysis.h" +#include "tsl/profiler/protobuf/xplane.pb.h" namespace tensorflow { namespace profiler { @@ -99,6 +100,10 @@ class HloModuleWrapper { std::unique_ptr module, std::function shape_func); + explicit HloModuleWrapper( + const xla::HloModule* module, + std::function shape_func); + const HloInstructionWrapper* GetHloInstruction( absl::string_view hlo_name) const; @@ -107,7 +112,10 @@ class HloModuleWrapper { absl::string_view Name() const { return module_->name(); } private: - std::unique_ptr module_; + const xla::HloModule* module_; + + protected: + std::unique_ptr owned_module_; // Map of HloInstructionWrappers by name. using HloInstructionMap = @@ -122,11 +130,16 @@ using HloModuleMap = void AddHloProto(HloModuleMap& hlo_module_map, uint64_t program_id, const xla::HloProto& hlo_proto); +// Process HloModuleMap from single XSpace. +void ProcessHloModuleMapFromXSpace(HloModuleMap& hlo_module_map, + const XSpace* space); + // WARNING: The returned pointer will be invalidated if HloModuleMap is mutated. -inline const HloModuleWrapper* GetHloModule(const HloModuleMap& hlo_module_map, +inline const HloModuleWrapper* GetHloModule(const HloModuleMap* hlo_module_map, uint64_t program_id) { - auto iter = hlo_module_map.find(program_id); - if (iter == hlo_module_map.end()) return nullptr; + if (hlo_module_map == nullptr) return nullptr; + auto iter = hlo_module_map->find(program_id); + if (iter == hlo_module_map->end()) return nullptr; return &iter->second; } @@ -134,7 +147,7 @@ inline const HloInstructionWrapper* GetHloInstruction( const HloModuleMap& hlo_module_map, std::optional program_id, absl::string_view hlo_name) { if (!program_id.has_value()) return nullptr; - const auto* hlo_module = GetHloModule(hlo_module_map, *program_id); + const auto* hlo_module = GetHloModule(&hlo_module_map, *program_id); if (hlo_module == nullptr) return nullptr; return hlo_module->GetHloInstruction(hlo_name); } diff --git a/tensorflow/core/profiler/utils/hlo_proto_map.cc b/tensorflow/core/profiler/utils/hlo_proto_map.cc index 2269acb66eb4c5..bdb16fca3c3fa9 100644 --- a/tensorflow/core/profiler/utils/hlo_proto_map.cc +++ b/tensorflow/core/profiler/utils/hlo_proto_map.cc @@ -29,10 +29,10 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/service/hlo.pb.h" #include "xla/tsl/profiler/convert/xla_op_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/host_offload_utils.cc b/tensorflow/core/profiler/utils/host_offload_utils.cc index 44b2eca6dca1ad..312b47f168cc44 100644 --- a/tensorflow/core/profiler/utils/host_offload_utils.cc +++ b/tensorflow/core/profiler/utils/host_offload_utils.cc @@ -30,11 +30,11 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/profiler/utils/trace_utils.h" #include "tensorflow/core/profiler/utils/xplane_builder.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/math_utils.h b/tensorflow/core/profiler/utils/math_utils.h index 1cffd53aafef7b..380884eeb994af 100644 --- a/tensorflow/core/profiler/utils/math_utils.h +++ b/tensorflow/core/profiler/utils/math_utils.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "absl/base/macros.h" -#include "tsl/profiler/utils/math_utils.h" +#include "xla/tsl/profiler/utils/math_utils.h" // TODO: b/323943471 - This macro should eventually be provided by Abseil. #ifndef ABSL_DEPRECATE_AND_INLINE diff --git a/tensorflow/core/profiler/utils/op_metrics_db_utils.cc b/tensorflow/core/profiler/utils/op_metrics_db_utils.cc index c1206949c336f3..5c8f13e58e8e0d 100644 --- a/tensorflow/core/profiler/utils/op_metrics_db_utils.cc +++ b/tensorflow/core/profiler/utils/op_metrics_db_utils.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include @@ -24,18 +25,19 @@ limitations under the License. #include "absl/log/check.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" #include "tensorflow/core/profiler/utils/math_utils.h" -#include "tsl/profiler/utils/tf_op_utils.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tensorflow { namespace profiler { const absl::string_view kIdle = "IDLE"; +const uint32_t kSparseCoreIndexStart = 1000000; namespace { @@ -226,6 +228,19 @@ OpMetrics* OpMetricsDbBuilder::LookupOrInsertNewOpMetrics( void XEventsOpMetricsDbBuilder::AddOpMetric( const tsl::profiler::XEventVisitor& event) { OpKey key = GetOpKeyFromHloEventMetadata(event.Metadata()); + std::optional stat = event.GetStat(StatType::kStepIdleTimePs); + if (stat.has_value()) { + uint64_t idle_time_ps = stat->IntOrUintValue(); + OpMetrics op_metrics; + op_metrics.set_self_time_ps(event.DurationPs() - idle_time_ps); + op_metrics.set_name("sparse_core_busy_ops"); + // TODO: Make it meaningful after SC stats are available. + op_metrics.set_category("sparse_core_busy_ops"); + constexpr uint64_t kMaxProgramId = std::numeric_limits::max(); + constexpr uint64_t kMaxSymbolId = std::numeric_limits::max(); + flat_op_metric_[kMaxProgramId][kMaxSymbolId] = op_metrics; + SetOpMetricsFromHloEvent(event, &op_metrics); + } if (!key.program_id.has_value() || !key.symbol_id.has_value()) return; OpMetricBySymbol& op_metric_by_symbol = flat_op_metric_[key.program_id.value()]; diff --git a/tensorflow/core/profiler/utils/op_metrics_db_utils.h b/tensorflow/core/profiler/utils/op_metrics_db_utils.h index 27cdfb61fa7800..e3ff3fcc5f6205 100644 --- a/tensorflow/core/profiler/utils/op_metrics_db_utils.h +++ b/tensorflow/core/profiler/utils/op_metrics_db_utils.h @@ -24,16 +24,18 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tensorflow { namespace profiler { // The name of OpMetrics to represent the idle time. TF_CONST_INIT extern const absl::string_view kIdle; +// The core index to add to sparse core index in op metrics. +TF_CONST_INIT extern const uint32_t kSparseCoreIndexStart; // Helps build an op metrics database (borrowed). // Enables fast lookup of existing ops and prevents the creation of duplicate diff --git a/tensorflow/core/profiler/utils/op_utils.cc b/tensorflow/core/profiler/utils/op_utils.cc index 8c8fbfdceb9492..52cbd2192d36f2 100644 --- a/tensorflow/core/profiler/utils/op_utils.cc +++ b/tensorflow/core/profiler/utils/op_utils.cc @@ -20,11 +20,11 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" -#include "tsl/profiler/utils/tf_op_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/op_utils.h b/tensorflow/core/profiler/utils/op_utils.h index dea310aeae16fd..57ce74e0c6ba5c 100644 --- a/tensorflow/core/profiler/utils/op_utils.h +++ b/tensorflow/core/profiler/utils/op_utils.h @@ -17,11 +17,11 @@ limitations under the License. #define TENSORFLOW_CORE_PROFILER_UTILS_OP_UTILS_H_ #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/protobuf/op_metrics.pb.h" #include "tensorflow/core/profiler/utils/op_metrics_db_utils.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/step_intersection.cc b/tensorflow/core/profiler/utils/step_intersection.cc index 6fbf258b5e6b09..ed246abd9737ae 100644 --- a/tensorflow/core/profiler/utils/step_intersection.cc +++ b/tensorflow/core/profiler/utils/step_intersection.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/profiler/utils/step_intersection.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/logging.h" -#include "tsl/profiler/utils/timespan.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/tfstreamz_utils.cc b/tensorflow/core/profiler/utils/tfstreamz_utils.cc index 7d0c67b7d0f3ff..d19b327cc6284a 100644 --- a/tensorflow/core/profiler/utils/tfstreamz_utils.cc +++ b/tensorflow/core/profiler/utils/tfstreamz_utils.cc @@ -76,8 +76,8 @@ tfstreamz::Percentiles ToProto(const monitoring::Percentiles& percentiles) { } // namespace -Status SerializeToXPlane(const std::vector& snapshots, - XPlane* plane, uint64 line_start_time_ns) { +absl::Status SerializeToXPlane(const std::vector& snapshots, + XPlane* plane, uint64 line_start_time_ns) { XPlaneBuilder xplane(plane); XLineBuilder line = xplane.GetOrCreateLine(0); // This plane has single line. line.SetTimestampNs(line_start_time_ns); diff --git a/tensorflow/core/profiler/utils/tfstreamz_utils.h b/tensorflow/core/profiler/utils/tfstreamz_utils.h index 1ab21ed1b5ed51..25b7436c4eae63 100644 --- a/tensorflow/core/profiler/utils/tfstreamz_utils.h +++ b/tensorflow/core/profiler/utils/tfstreamz_utils.h @@ -32,8 +32,8 @@ struct TfStreamzSnapshot { uint64 end_time_ns; // time after collection. }; -Status SerializeToXPlane(const std::vector& snapshots, - XPlane* plane, uint64 line_start_time_ns); +absl::Status SerializeToXPlane(const std::vector& snapshots, + XPlane* plane, uint64 line_start_time_ns); } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/trace_utils.h b/tensorflow/core/profiler/utils/trace_utils.h index 735a0207db2c27..89e2b4cde93586 100644 --- a/tensorflow/core/profiler/utils/trace_utils.h +++ b/tensorflow/core/profiler/utils/trace_utils.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_TRACE_UTILS_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_TRACE_UTILS_H_ -#include "tsl/profiler/utils/trace_utils.h" +#include "xla/tsl/profiler/utils/trace_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/xplane_builder.h b/tensorflow/core/profiler/utils/xplane_builder.h index 873af726d37eab..c0e2c39b0dc6ac 100644 --- a/tensorflow/core/profiler/utils/xplane_builder.h +++ b/tensorflow/core/profiler/utils/xplane_builder.h @@ -22,7 +22,7 @@ limitations under the License. #include #include -#include "tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/xplane_schema.h b/tensorflow/core/profiler/utils/xplane_schema.h index d6efbd1cd7a1b1..cfa748bf04ab8a 100644 --- a/tensorflow/core/profiler/utils/xplane_schema.h +++ b/tensorflow/core/profiler/utils/xplane_schema.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_SCHEMA_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_SCHEMA_H_ -#include "tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/xplane_test_utils.h b/tensorflow/core/profiler/utils/xplane_test_utils.h index c3ed5de0f22237..c2619394d88445 100644 --- a/tensorflow/core/profiler/utils/xplane_test_utils.h +++ b/tensorflow/core/profiler/utils/xplane_test_utils.h @@ -19,10 +19,10 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/variant.h" +#include "xla/tsl/profiler/utils/xplane_test_utils.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/profiler/utils/xplane_builder.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_test_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/xplane_utils.h b/tensorflow/core/profiler/utils/xplane_utils.h index 75ed0d1b3ed330..9292ed6a6b8e30 100644 --- a/tensorflow/core/profiler/utils/xplane_utils.h +++ b/tensorflow/core/profiler/utils/xplane_utils.h @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "tsl/profiler/utils/xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/xplane_visitor.h b/tensorflow/core/profiler/utils/xplane_visitor.h index deebadbdee5c3b..81db4a4f1bd315 100644 --- a/tensorflow/core/profiler/utils/xplane_visitor.h +++ b/tensorflow/core/profiler/utils/xplane_visitor.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_VISITOR_H_ #define TENSORFLOW_CORE_PROFILER_UTILS_XPLANE_VISITOR_H_ -#include "tsl/profiler/utils/xplane_visitor.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.cc b/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.cc new file mode 100644 index 00000000000000..9df196901de411 --- /dev/null +++ b/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.cc @@ -0,0 +1,98 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/primitive_util.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/hlo_cost_analysis.h" + +namespace tensorflow { +namespace profiler { + +namespace { + +std::vector GetInputBitwidths(const xla::HloInstruction& hlo) { + std::vector input_bitwidths; + for (const auto& operand : hlo.operands()) { + switch (operand->shape().element_type()) { + case xla::PRIMITIVE_TYPE_INVALID: + case xla::TUPLE: + case xla::OPAQUE_TYPE: + case xla::TOKEN: + break; + default: + input_bitwidths.push_back( + xla::primitive_util::BitWidth(operand->shape().element_type())); + } + } + return input_bitwidths; +} + +} // namespace + +absl::Status XProfGpuCostAnalysis::Postprocess(const xla::HloInstruction* hlo) { + if (hlo == nullptr) { + return absl::OkStatus(); + } + + uint32_t flop_rate_adjustment = 1; + float model_flops = current_properties_[kFlopsKey]; + // Calculate adjustment of device flops based on input bit widths. + // This provide most general adjustment for all ops, and for all gpus. + // TODO: Add adjustment for specific GPUs. + std::vector input_bitwidths = GetInputBitwidths(*hlo); + if (!input_bitwidths.empty()) { + int max_input_bitwidth = + *std::max_element(input_bitwidths.begin(), input_bitwidths.end()); + if (model_flops) { + // for int8/fp8, 2x flops assumed comparing with fp16 flops(most of + // recent GPU models); for int4, 4x of model flops assumed comparing + // with fp16 flops. (like Nvidia T4, 3090). It will be more precise + // after adjustment based on specific GPUs mentioned above. + switch (max_input_bitwidth) { + case 8: + flop_rate_adjustment = 2; + break; + case 4: + flop_rate_adjustment = 4; + break; + } + } + } + current_properties_[kDeviceFlopsAdjustment] = + model_flops - model_flops / flop_rate_adjustment; + return xla::gpu::GpuHloCostAnalysis::Postprocess(hlo); +} + +std::unique_ptr +XProfGpuCostAnalysis::CreateNestedCostAnalysis() { + return std::make_unique(options_); +} + +int64_t XProfGpuCostAnalysis::GetDeviceFlopsAdjustment( + const xla::HloInstruction& hlo) { + return GetPropertyForHlo(hlo, kDeviceFlopsAdjustment, hlo_properties_); +} + +} // namespace profiler +} // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h b/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h new file mode 100644 index 00000000000000..6977295c76939b --- /dev/null +++ b/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h @@ -0,0 +1,53 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PROFILER_UTILS_XPROF_GPU_COST_ANALYSIS_H_ +#define TENSORFLOW_CORE_PROFILER_UTILS_XPROF_GPU_COST_ANALYSIS_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/hlo_cost_analysis.h" + +namespace tensorflow { +namespace profiler { + +// XProfGpuCostAnalysis provides additional cost analysis for XProf, which +// normalizes the flops to the device flops based on input bit widths. +class XProfGpuCostAnalysis : public xla::gpu::GpuHloCostAnalysis { + public: + explicit XProfGpuCostAnalysis(const xla::HloCostAnalysis::Options& options) + : xla::gpu::GpuHloCostAnalysis(options) {} + + absl::Status Postprocess(const xla::HloInstruction* hlo) override; + + int64_t GetDeviceFlopsAdjustment(const xla::HloInstruction& hlo); + + protected: + std::unique_ptr CreateNestedCostAnalysis() override; + + private: + static inline constexpr absl::string_view kDeviceFlopsAdjustment = + "device_flops_adjustment"; +}; + +} // namespace profiler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PROFILER_UTILS_XPROF_GPU_COST_ANALYSIS_H_ diff --git a/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis_test.cc b/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis_test.cc new file mode 100644 index 00000000000000..2586e131b53a44 --- /dev/null +++ b/tensorflow/core/profiler/utils/xprof_gpu_cost_analysis_test.cc @@ -0,0 +1,138 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/profiler/utils/xprof_gpu_cost_analysis.h" + +#include + +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/test_helpers.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { +namespace profiler { + +class XprofGpuHloCostAnalysisTest : public xla::HloTestBase { + xla::HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const { + return [&](const xla::Shape& shape) { + constexpr int64_t kPointerSize = 8; + return xla::ShapeUtil::ByteSizeOf(shape, kPointerSize); + }; + } + + public: + xla::HloCostAnalysis::Options options_{ + ShapeSizeBytesFunction(), + /*per_second_rates=*/{}, + /*min_latencies_seconds=*/{}, + /*count_multiple_input_accesses=*/true}; + XProfGpuCostAnalysis analysis_{options_}; + XprofGpuHloCostAnalysisTest() : xla::HloTestBase() {} +}; + +TEST_F(XprofGpuHloCostAnalysisTest, Fp16GemmNoAdjustment) { + absl::string_view hlo_string = R"( +HloModule r + +ENTRY e { + arg0 = f16[65536,32800] parameter(0) + arg1 = f16[32800,32] parameter(1) + gemm = (f16[65536,32], s8[0]) custom-call(arg0, arg1), + custom_call_target="__cublas$gemm", + backend_config="{ + \"gemm_backend_config\": { + \"alpha_real\":1, + \"beta\":0, + \"dot_dimension_numbers\":{ + \"lhs_contracting_dimensions\":[\"1\"], + \"rhs_contracting_dimensions\":[\"0\"], + \"lhs_batch_dimensions\":[], + \"rhs_batch_dimensions\":[] + }, + \"alpha_imag\":0, + \"precision_config\":{ + \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] + }, + \"epilogue\":\"DEFAULT\" + } + }" + ROOT get-tuple-element = f16[65536,32] + get-tuple-element((f16[65536,32], s8[0]) gemm), index=0 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); + xla::HloComputation* comp = module->entry_computation(); + const xla::HloInstruction* fp16gemm = comp->GetInstructionWithName("gemm"); + // flops of gemm A * B = rows(A) * cols(B) * cols(A) * 2 + // where 2 is for the add and multiply + int64_t gold_flops = 65536LL * 32800 * 32 * 2; + EXPECT_EQ(analysis_.flop_count(*fp16gemm), gold_flops); + EXPECT_EQ(analysis_.GetDeviceFlopsAdjustment(*fp16gemm), 0); +} + +TEST_F(XprofGpuHloCostAnalysisTest, S8GemmAdjustment) { + absl::string_view hlo_string = R"( +HloModule r + +ENTRY e { + arg0 = s8[65536,32800] parameter(0) + arg1 = s8[32800,32] parameter(1) + gemm = (s32[65536,32], s8[0]) custom-call(arg0, arg1), + custom_call_target="__cublas$gemm", + backend_config="{ + \"gemm_backend_config\": { + \"alpha_real\":1, + \"beta\":0, + \"dot_dimension_numbers\":{ + \"lhs_contracting_dimensions\":[\"1\"], + \"rhs_contracting_dimensions\":[\"0\"], + \"lhs_batch_dimensions\":[], + \"rhs_batch_dimensions\":[] + }, + \"alpha_imag\":0, + \"precision_config\":{ + \"operand_precision\":[\"DEFAULT\",\"DEFAULT\"] + }, + \"epilogue\":\"DEFAULT\" + } + }" + ROOT get-tuple-element = s32[65536,32] + get-tuple-element((s32[65536,32], s8[0]) gemm), index=0 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); + xla::HloComputation* comp = module->entry_computation(); + const xla::HloInstruction* s8gemm = comp->GetInstructionWithName("gemm"); + int64_t gold_flops = 65536LL * 32800 * 32 * 2; + EXPECT_EQ(analysis_.flop_count(*s8gemm), gold_flops); + // Matmul of int8 * int8 -> int32, normalized it to equivalent fp16 flops by + // dividing by 2 as all inputs are 8 bits + EXPECT_EQ(analysis_.GetDeviceFlopsAdjustment(*s8gemm), gold_flops / 2); +} + +} // namespace profiler +} // namespace tensorflow diff --git a/tensorflow/core/protobuf/BUILD b/tensorflow/core/protobuf/BUILD index 30a0b7283a2e57..f38b5c3847fc1f 100644 --- a/tensorflow/core/protobuf/BUILD +++ b/tensorflow/core/protobuf/BUILD @@ -71,7 +71,7 @@ tf_proto_library( srcs = ["conv_autotuning.proto"], make_default_target_header_only = True, protodeps = [ - "@local_tsl//tsl/protobuf:dnn_proto", + "@local_xla//xla/tsl/protobuf:dnn_proto", ], ) @@ -129,8 +129,9 @@ tf_proto_library( name = "error_codes_proto_impl", srcs = ["error_codes.proto"], make_default_target_header_only = True, - protodeps = ["@local_tsl//tsl/protobuf:error_codes_proto_impl"], - exports = ["@local_tsl//tsl/protobuf:error_codes_proto_impl"], + protodeps = ["@local_xla//xla/tsl/protobuf:error_codes_proto_impl"], + visibility = ["//visibility:public"], + exports = ["@local_xla//xla/tsl/protobuf:error_codes_proto_impl"], ) exports_files( @@ -200,16 +201,16 @@ tf_proto_library( ":error_codes_proto_impl", "//tensorflow/core/framework:protos_all", "@local_xla//xla/tsl/protobuf:bfc_memory_map_proto", - "@local_tsl//tsl/protobuf:coordination_config_proto", - "@local_tsl//tsl/protobuf:rpc_options_proto", - "@local_tsl//tsl/protobuf:status_proto", + "@local_xla//xla/tsl/protobuf:coordination_config_proto", + "@local_xla//xla/tsl/protobuf:rpc_options_proto", + "@local_xla//xla/tsl/protobuf:status_proto", ], tags = ["alt_dep=//third_party/tensorflow/core:protos_all"], visibility = ["//visibility:public"], exports = [ - "@local_tsl//tsl/protobuf:rpc_options_proto", - "@local_tsl//tsl/protobuf:status_proto", "@local_xla//xla/tsl/protobuf:bfc_memory_map_proto", + "@local_xla//xla/tsl/protobuf:rpc_options_proto", + "@local_xla//xla/tsl/protobuf:status_proto", ], ) diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index f5ee231fac28e4..7eb8842ebd6baa 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -2,6 +2,7 @@ syntax = "proto3"; package tensorflow; +import "xla/tsl/protobuf/coordination_config.proto"; import "tensorflow/core/framework/cost_graph.proto"; import "tensorflow/core/framework/graph.proto"; import "tensorflow/core/framework/step_stats.proto"; @@ -9,7 +10,6 @@ import "tensorflow/core/protobuf/cluster.proto"; import "tensorflow/core/protobuf/debug.proto"; import "tensorflow/core/protobuf/rewriter_config.proto"; import "tensorflow/core/protobuf/rpc_options.proto"; -import "tsl/protobuf/coordination_config.proto"; option cc_enable_arenas = true; option java_outer_classname = "ConfigProtos"; @@ -77,6 +77,11 @@ message GPUOptions { // name "/device:GPU:") are also called "TF GPU id"s. Please // refer to third_party/tensorflow/core/common_runtime/gpu/gpu_id.h // for more information. + // 3. The visible_device_list is also used for PluggableDevice. And + // different types of PluggableDevices share this field. In that case, + // the pluggable_device_type is used to distinguish them, making the + // visible_device_list a list of :, + // e.g. "PluggableDeviceA:0,PluggableDeviceA:1,PluggableDeviceB:0". string visible_device_list = 5; // In the event polling loop sleep this many microseconds between diff --git a/tensorflow/core/protobuf/conv_autotuning.proto b/tensorflow/core/protobuf/conv_autotuning.proto index 21f1c2adbf5613..47ed3a1174899b 100644 --- a/tensorflow/core/protobuf/conv_autotuning.proto +++ b/tensorflow/core/protobuf/conv_autotuning.proto @@ -4,7 +4,7 @@ syntax = "proto3"; package tensorflow; -import "tsl/protobuf/dnn.proto"; +import "xla/tsl/protobuf/dnn.proto"; option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"; diff --git a/tensorflow/core/protobuf/error_codes.proto b/tensorflow/core/protobuf/error_codes.proto index 2e61ab7fd45a32..6842c696f4f40d 100644 --- a/tensorflow/core/protobuf/error_codes.proto +++ b/tensorflow/core/protobuf/error_codes.proto @@ -6,6 +6,6 @@ syntax = "proto3"; // code for some users that use JS through J2CL. package tensorflow.error.dummy; -import public "tsl/protobuf/error_codes.proto"; +import public "xla/tsl/protobuf/error_codes.proto"; option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"; diff --git a/tensorflow/core/protobuf/fingerprint.proto b/tensorflow/core/protobuf/fingerprint.proto index 6ac5307ebacab7..1e3ce8d1513aaa 100644 --- a/tensorflow/core/protobuf/fingerprint.proto +++ b/tensorflow/core/protobuf/fingerprint.proto @@ -25,6 +25,8 @@ message FingerprintDef { uint64 saved_object_graph_hash = 4; // Hash of the checkpoint. uint64 checkpoint_hash = 5; + // An UUID for the model, chosen at random, not related to the hashes. + string uuid = 7; // Version specification of the fingerprint. VersionDef version = 6; // TODO(b/290068219): add USM version when GA diff --git a/tensorflow/core/protobuf/rpc_options.proto b/tensorflow/core/protobuf/rpc_options.proto index db9216a7e7bb4c..03593a682e81cb 100644 --- a/tensorflow/core/protobuf/rpc_options.proto +++ b/tensorflow/core/protobuf/rpc_options.proto @@ -2,6 +2,6 @@ syntax = "proto3"; package tensorflow.dummy; -import public "tsl/protobuf/rpc_options.proto"; +import public "xla/tsl/protobuf/rpc_options.proto"; option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"; diff --git a/tensorflow/core/protobuf/status.proto b/tensorflow/core/protobuf/status.proto index dd6f703d100ae2..d7df8cf3e05af0 100644 --- a/tensorflow/core/protobuf/status.proto +++ b/tensorflow/core/protobuf/status.proto @@ -6,6 +6,6 @@ syntax = "proto3"; // code for some users that use JS through J2CL. package tensorflow.dummy; -import public "tsl/protobuf/status.proto"; +import public "xla/tsl/protobuf/status.proto"; option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"; diff --git a/tensorflow/core/protobuf/tpu/BUILD b/tensorflow/core/protobuf/tpu/BUILD index 78d8761223b60b..f714172ae2b490 100644 --- a/tensorflow/core/protobuf/tpu/BUILD +++ b/tensorflow/core/protobuf/tpu/BUILD @@ -83,42 +83,36 @@ tf_pyclif_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "tpu_embedding_configuration_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":tpu_embedding_configuration_proto"], # ) # # py_proto_library( # name = "optimization_parameters_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":optimization_parameters_proto"], # ) # # py_proto_library( # name = "topology_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":topology_proto"], # ) # # py_proto_library( # name = "dynamic_padding_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":dynamic_padding_proto"], # ) # # py_proto_library( # name = "compilation_result_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":compilation_result_proto"], # ) # # py_proto_library( # name = "compile_metadata_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":compile_metadata_proto"], # ) diff --git a/tensorflow/core/protobuf/tpu/optimization_parameters.proto b/tensorflow/core/protobuf/tpu/optimization_parameters.proto index 80532f382b0123..1a6c0be5276c32 100644 --- a/tensorflow/core/protobuf/tpu/optimization_parameters.proto +++ b/tensorflow/core/protobuf/tpu/optimization_parameters.proto @@ -34,37 +34,32 @@ message SimulatedQuantization { int32 num_buckets = 3; } -// Dynamic learning rate specification in the TPUEmbeddingConfiguration. The -// actual learning rates are provided as a scalar input list to the +// Dynamic input specification for optimizers in the TPUEmbeddingConfiguration. +// The actual dynamic inputs are provided as a scalar input list to the // SendTPUEmbeddingGradients Op indexed by their tag specified through the // following proto. -message DynamicLearningRate { - // For tables where learning rates are dynamically computed and communicated - // to the TPU embedding program, a tag must be specified for the learning - // rate. +message OptimizerDynamicInput { + // For tables where dynamic inputs are needed (e.g., learning rates or other + // dynamic hyperparameters used in optimizers), a tag must be specified for + // the input. // - // The tag must be a non-negative integer. The total number of unique tags - // must be less than or equal to the number of tables in the TPU embedding - // configuration (a table does not specify any tag if it uses a constant - // learning rate, and specifies exactly one tag if it uses dynamic learning - // rates). - // - // All tags in the range [0, number_of_unique_tags) must be present in the TPU - // embedding configuration, i.e. a tag cannot be skipped if a different tag - // numerically greater than it is used in the configuration. + // The tag must be a non-negative integer. All tags in the range + // [0, number_of_unique_tags) must be present in the TPU embedding + // configuration, i.e. a tag cannot be skipped if a different tag numerically + // greater than it is used in the configuration. // // If multiple tables specify the same tag, they *MUST* have - // the same dynamic learning rate, for example, their dynamic learning rate - // could be computed by the same TensorFlow sub-graph. The partitioning of the + // the same dynamic input, for example, their dynamic learning rate could be + // computed by the same TensorFlow sub-graph. The partitioning of the // embedding layer would be more optimal if the number_of_unique_tags is as // *LOW* as possible, i.e., if many tables share the same tag. // - // The learning_rate input of the SendTPUEmbeddingGradients op is used to - // communicate dynamic learning rates to the TPU embedding program. - // The learning_rate input is a list of scalars where the size of the list is - // equal to the number of unique tags. The learning rate associated with a - // particular tag is specified by populating its corresponding index in the - // list of learning_rate scalars. + // The hyper_parameters input of the SendTPUEmbeddingGradients op is used to + // communicate dynamic hyper-parameters to the TPU embedding program. + // The hyper_parameters input is a list of scalars where the size of the list + // is equal to the number of unique tags. The hyper-parameter associated with + // a particular tag is specified by populating its corresponding index in the + // list of scalars. int32 tag = 1; } @@ -72,7 +67,7 @@ message DynamicLearningRate { message LearningRate { oneof learning_rate { float constant = 1; - DynamicLearningRate dynamic = 2; + OptimizerDynamicInput dynamic = 2; } } @@ -131,6 +126,53 @@ message BoundedAdagradParameters { float max_accumulator = 3; } +// Frequency Aware Adagrad optimizer. This optimizer implements the AdaGrad +// algorithm and further allows to: +// * Scale the learning rate based on frequency of the update. Sparsely updated +// rows are updated with a higher effective learning rate, and frequently +// updated rows are updated with a lower effective learning rate. +// * Decay the growth of the accumulator values. +// * Use L1 / L2 regularization for the weight updates. +// +// The optimization algorithm is shown below. +// counter(new) = counter(old) + 1 +// accum(new) = max(accumulator_decay * accum(old) + grad^2, +// initial_accumulator_value) +// lr_scale = min((step_counter / accum(new)) ^ probability_exponent, +// max_lr_multiplier) update = grad * lr_scale / sqrt(accum(new)) if +// (l1_regularization_strength > 0.0): +// update = update + l1_regularization_strength * sign(var(old)) +// if (l2_regularization_strength > 0.0): +// update = update + l2_regularization_strength * var(old) +// var(new) = var(old) - lr_scale * grad * update + +message FrequencyAwareAdagradParameters { + // The L1 regularization parameter for adjusting the update based on the sign + // of the variable. + float l1_regularization_strength = 1; + + // The L2 regularization parameter for adjusting the update based on the + // variable. + float l2_regularization_strength = 2; + + // The exponent used for scaling the learning rate based on the sparsity of + // updates. + float probability_exponent = 4; + + // The maximum value of the learning rate scale. + float max_lr_multiplier = 3; + + // The decay for the Adagrad accumulator. + float accumulator_decay = 5; + + // The initial and minimum value for the Adagrad accumulator. + float initial_accumulator_value = 6; + + // The tag for identifying the step counter used for the frequency aware + // Adagrad optimizer. + OptimizerDynamicInput step_counter = 7; +} + // https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/SGD // https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L629 message StochasticGradientDescentParameters {} @@ -502,7 +544,6 @@ message HotIdReplicationConfiguration { message OptimizationParameters { // Learning rate used for updating the embedding layer parameters. LearningRate learning_rate = 13; - reserved 1; // Old learning rate tag. // Limits to which to clip the weight values after the backward pass; not // present means no limits are applied. @@ -550,6 +591,7 @@ message OptimizationParameters { AdagradParameters adagrad = 3; AdagradMomentumParameters adagrad_momentum = 26; BoundedAdagradParameters bounded_adagrad = 19; + FrequencyAwareAdagradParameters frequency_aware_adagrad = 30; StochasticGradientDescentParameters stochastic_gradient_descent = 4; FtrlParameters ftrl = 5; AdamParameters adam = 6; @@ -567,9 +609,9 @@ message OptimizationParameters { AssignParameters assign = 25; } - reserved 15; // Old use_gradient_accumulation. + reserved 1, 15; - // NEXT_ID: 30 + // NEXT_ID: 31 } // Specification of an optimization algorithm's state variables (both the main diff --git a/tensorflow/core/protobuf/tpu/tpu_embedding_configuration.proto b/tensorflow/core/protobuf/tpu/tpu_embedding_configuration.proto index 3dee624de569d2..c092231f72b762 100644 --- a/tensorflow/core/protobuf/tpu/tpu_embedding_configuration.proto +++ b/tensorflow/core/protobuf/tpu/tpu_embedding_configuration.proto @@ -139,6 +139,11 @@ message TPUEmbeddingConfiguration { // Number of cores per replica. int32 num_cores_per_replica = 2; + + // If true, the tensors are manually partitioned. Otherwise, use the + // automatic SPMD partitioning. This should be true when users use + // `shard_map`. + bool use_manual_partitioning = 3; } SpmdSharding spmd_sharding = 11; diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index e7fe3c0842ed14..1bacc105d05a64 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -21,7 +21,7 @@ limitations under the License. // Also update tensorflow/tensorflow.bzl and // tensorflow/tools/pip_package/setup.py #define TF_MAJOR_VERSION 2 -#define TF_MINOR_VERSION 18 +#define TF_MINOR_VERSION 19 #define TF_PATCH_VERSION 0 // TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1", @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1990 // Updated: 2024/9/19 +#define TF_GRAPH_DEF_VERSION 2028 // Updated: 2024/10/27 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // diff --git a/tensorflow/core/runtime_fallback/util/gpu/gpu_utils.cc b/tensorflow/core/runtime_fallback/util/gpu/gpu_utils.cc deleted file mode 100644 index b2bd844e74babe..00000000000000 --- a/tensorflow/core/runtime_fallback/util/gpu/gpu_utils.cc +++ /dev/null @@ -1,87 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file implements gpu related utility functions. - -#include "tensorflow/core/runtime_fallback/util/gpu/gpu_utils.h" - -#include "tensorflow/c/tf_tensor.h" -#include "tensorflow/c/tf_tensor_internal.h" -#include "xla/stream_executor/cuda/cuda_driver.h" -#include "xla/stream_executor/platform.h" -#include "tensorflow/core/common_runtime/gpu/gpu_id.h" -#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" -#include "tensorflow/core/runtime_fallback/util/tensor_util.h" -#include "tensorflow/core/runtime_fallback/util/type_util.h" - -namespace tensorflow { -namespace tfd { - -// Helper to lookup GPU platform (CUDA vs ROCm) from a given TensorHandle. -static tfrt::gpu::wrapper::Platform GetTfrtGpuPlatformHelper( - tensorflow::TensorHandle* th) { - auto device = th->op_device(); - auto gpu_device = static_cast(device); - return GetTfrtGpuPlatform(gpu_device); -} - -tfrt::gpu::wrapper::Platform GetTfrtGpuPlatform( - tensorflow::BaseGPUDevice* device) { - auto platform_kind = device->executor()->platform_kind(); - if (platform_kind == stream_executor::PlatformKind::kCuda) { - return tfrt::gpu::wrapper::Platform::CUDA; - } else if (platform_kind == stream_executor::PlatformKind::kROCm) { - return tfrt::gpu::wrapper::Platform::ROCm; - } - return tfrt::gpu::wrapper::Platform::NONE; -} - -// Lookup GPU platform (CUDA vs ROCm) from a given TensorHandle. -tfrt::gpu::wrapper::Platform GetTfrtGpuPlatform(tensorflow::TensorHandle* th) { - // Cache lookup result assuming TF does not mix CUDA and ROCm tensor handles. - static auto gpu_platform = GetTfrtGpuPlatformHelper(th); - return gpu_platform; -} - -namespace { -struct TFManagedBufferDeleter { - void operator()(TF_ManagedBuffer* p) const { p->Unref(); } -}; -using OwnedTFManagedBuffer = - std::unique_ptr; -} // namespace - -// Moves one ref on GpuBuffer to tensorflow::Tensor. -tfrt::Expected MoveGpuBufferToTFTensor( - tfrt::AsyncValueRef gpu_buffer, tfrt::DType dtype, - tfrt::TensorShape shape) { - auto deallocator = [](void* data, size_t len, void* arg) { - auto* gpu_buffer = reinterpret_cast(arg); - gpu_buffer->DropRef(); - }; - - // `owns_memory` is used by tensorflow::Tensor::RefCountIsOne. - // One ref on `gpu_buffer` is transfered here to TF_ManagedBuffer. - OwnedTFManagedBuffer tf_managed_buffer{ - new TF_ManagedBuffer(gpu_buffer->pointer().raw(), gpu_buffer->size(), - deallocator, gpu_buffer.release(), - /*owns_memory=*/false)}; - tensorflow::Tensor tensor(GetTfDataType(dtype), GetTfShape(shape), - tf_managed_buffer.get()); - return std::move(tensor); -} - -} // namespace tfd -} // namespace tensorflow diff --git a/tensorflow/core/runtime_fallback/util/gpu/gpu_utils.h b/tensorflow/core/runtime_fallback/util/gpu/gpu_utils.h deleted file mode 100644 index 9ff8136a0145ac..00000000000000 --- a/tensorflow/core/runtime_fallback/util/gpu/gpu_utils.h +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This file declares gpu related utility functions. - -#ifndef TENSORFLOW_CORE_RUNTIME_FALLBACK_UTIL_GPU_GPU_UTILS_H_ -#define TENSORFLOW_CORE_RUNTIME_FALLBACK_UTIL_GPU_GPU_UTILS_H_ - -#include "tensorflow/core/common_runtime/eager/tensor_handle.h" -#include "tensorflow/core/common_runtime/gpu/gpu_device.h" -#include "tfrt/tensor/tensor.h" // from @tf_runtime - -namespace tensorflow { -namespace tfd { - -// Lookup GPU platform (CUDA vs ROCm) from a given tensorflow::TensorHandle. -tfrt::gpu::wrapper::Platform GetTfrtGpuPlatform(tensorflow::TensorHandle* th); - -tfrt::gpu::wrapper::Platform GetTfrtGpuPlatform( - tensorflow::BaseGPUDevice* device); - -// Moves one ref on GpuBuffer to tensorflow::Tensor. -tfrt::Expected MoveGpuBufferToTFTensor( - tfrt::AsyncValueRef gpu_buffer, tfrt::DType dtype, - tfrt::TensorShape shape); - -} // namespace tfd -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_RUNTIME_FALLBACK_UTIL_GPU_GPU_UTILS_H_ diff --git a/tensorflow/core/summary/loader.cc b/tensorflow/core/summary/loader.cc index 3af1f1b32dc1a7..8d06f49a66e507 100644 --- a/tensorflow/core/summary/loader.cc +++ b/tensorflow/core/summary/loader.cc @@ -99,7 +99,7 @@ int main(int argc, char* argv[]) { tstring record; while (true) { std::unique_ptr event = std::unique_ptr(new Event); - Status s = reader.ReadRecord(&offset, &record); + absl::Status s = reader.ReadRecord(&offset, &record); if (s.code() == error::OUT_OF_RANGE) break; TF_CHECK_OK(s); if (!ParseProtoUnlimited(event.get(), record)) { diff --git a/tensorflow/core/summary/schema.cc b/tensorflow/core/summary/schema.cc index 822e2fa3bfdaf2..3b6f3d6c5d3ce7 100644 --- a/tensorflow/core/summary/schema.cc +++ b/tensorflow/core/summary/schema.cc @@ -19,7 +19,7 @@ limitations under the License. namespace tensorflow { namespace { -Status Run(Sqlite* db, const char* sql) { +absl::Status Run(Sqlite* db, const char* sql) { SqliteStatement stmt; TF_RETURN_IF_ERROR(db->Prepare(sql, &stmt)); return stmt.StepAndReset(); @@ -27,7 +27,7 @@ Status Run(Sqlite* db, const char* sql) { } // namespace -Status SetupTensorboardSqliteDb(Sqlite* db) { +absl::Status SetupTensorboardSqliteDb(Sqlite* db) { // Note: GCC raw strings macros are broken. // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=55971 TF_RETURN_IF_ERROR( @@ -35,7 +35,7 @@ Status SetupTensorboardSqliteDb(Sqlite* db) { kTensorboardSqliteApplicationId)) .StepAndReset()); db->PrepareOrDie("PRAGMA user_version=0").StepAndResetOrDie(); - Status s; + absl::Status s; // Ids identify resources. // diff --git a/tensorflow/core/summary/schema.h b/tensorflow/core/summary/schema.h index 6305f8eabd7cac..4361088c8be7a0 100644 --- a/tensorflow/core/summary/schema.h +++ b/tensorflow/core/summary/schema.h @@ -26,7 +26,7 @@ constexpr uint32 kTensorboardSqliteApplicationId = 0xfeedabee; /// /// If they are already created, this has no effect. If schema /// migrations are necessary, they will be performed with logging. -Status SetupTensorboardSqliteDb(Sqlite* db); +absl::Status SetupTensorboardSqliteDb(Sqlite* db); } // namespace tensorflow diff --git a/tensorflow/core/summary/summary_converter.cc b/tensorflow/core/summary/summary_converter.cc index 4e9908a9fa0eab..53ed1dfded5b55 100644 --- a/tensorflow/core/summary/summary_converter.cc +++ b/tensorflow/core/summary/summary_converter.cc @@ -27,7 +27,7 @@ namespace tensorflow { namespace { template -Status TensorValueAt(Tensor t, int64_t i, T* out) { +absl::Status TensorValueAt(Tensor t, int64_t i, T* out) { #define CASE(I) \ case DataTypeToEnum::value: \ *out = static_cast(t.flat()(i)); \ @@ -71,9 +71,10 @@ typedef Eigen::Tensor Uint8Image; // differently in the float and uint8 cases: the float case needs a temporary // buffer which can be shared across calls to ith_image, but the uint8 case // does not. -Status AddImages(const string& tag, int max_images, int batch_size, int w, - int h, int depth, - const std::function& ith_image, Summary* s) { +absl::Status AddImages(const string& tag, int max_images, int batch_size, int w, + int h, int depth, + const std::function& ith_image, + Summary* s) { const int N = std::min(max_images, batch_size); for (int i = 0; i < N; ++i) { Summary::Value* v = s->add_value(); @@ -177,10 +178,10 @@ void NormalizeFloatImage(int hw, int depth, } template -Status NormalizeAndAddImages(const Tensor& tensor, int max_images, int h, int w, - int hw, int depth, int batch_size, - const string& base_tag, Tensor bad_color_tensor, - Summary* s) { +absl::Status NormalizeAndAddImages(const Tensor& tensor, int max_images, int h, + int w, int hw, int depth, int batch_size, + const string& base_tag, + Tensor bad_color_tensor, Summary* s) { // For float and half images, nans and infs are replaced with bad_color. if (bad_color_tensor.dim_size(0) < depth) { return errors::InvalidArgument( @@ -204,8 +205,8 @@ Status NormalizeAndAddImages(const Tensor& tensor, int max_images, int h, int w, } // namespace -Status AddTensorAsScalarToSummary(const Tensor& t, const string& tag, - Summary* s) { +absl::Status AddTensorAsScalarToSummary(const Tensor& t, const string& tag, + Summary* s) { Summary::Value* v = s->add_value(); v->set_tag(tag); float value; @@ -214,8 +215,8 @@ Status AddTensorAsScalarToSummary(const Tensor& t, const string& tag, return absl::OkStatus(); } -Status AddTensorAsHistogramToSummary(const Tensor& t, const string& tag, - Summary* s) { +absl::Status AddTensorAsHistogramToSummary(const Tensor& t, const string& tag, + Summary* s) { Summary::Value* v = s->add_value(); v->set_tag(tag); histogram::Histogram histo; @@ -234,9 +235,9 @@ Status AddTensorAsHistogramToSummary(const Tensor& t, const string& tag, return absl::OkStatus(); } -Status AddTensorAsImageToSummary(const Tensor& tensor, const string& tag, - int max_images, const Tensor& bad_color, - Summary* s) { +absl::Status AddTensorAsImageToSummary(const Tensor& tensor, const string& tag, + int max_images, const Tensor& bad_color, + Summary* s) { if (!(tensor.dims() == 4 && (tensor.dim_size(3) == 1 || tensor.dim_size(3) == 3 || tensor.dim_size(3) == 4))) { @@ -283,9 +284,9 @@ Status AddTensorAsImageToSummary(const Tensor& tensor, const string& tag, return absl::OkStatus(); } -Status AddTensorAsAudioToSummary(const Tensor& tensor, const string& tag, - int max_outputs, float sample_rate, - Summary* s) { +absl::Status AddTensorAsAudioToSummary(const Tensor& tensor, const string& tag, + int max_outputs, float sample_rate, + Summary* s) { if (sample_rate <= 0.0f) { return errors::InvalidArgument("sample_rate must be > 0"); } diff --git a/tensorflow/core/summary/summary_converter.h b/tensorflow/core/summary/summary_converter.h index dc005d2604ff16..d77d4c670e8d8d 100644 --- a/tensorflow/core/summary/summary_converter.h +++ b/tensorflow/core/summary/summary_converter.h @@ -22,16 +22,16 @@ limitations under the License. namespace tensorflow { // TODO(jart): Delete these methods in favor of new Python implementation. -Status AddTensorAsScalarToSummary(const Tensor& t, const string& tag, - Summary* s); -Status AddTensorAsHistogramToSummary(const Tensor& t, const string& tag, - Summary* s); -Status AddTensorAsImageToSummary(const Tensor& tensor, const string& tag, - int max_images, const Tensor& bad_color, - Summary* s); -Status AddTensorAsAudioToSummary(const Tensor& tensor, const string& tag, - int max_outputs, float sample_rate, - Summary* s); +absl::Status AddTensorAsScalarToSummary(const Tensor& t, const string& tag, + Summary* s); +absl::Status AddTensorAsHistogramToSummary(const Tensor& t, const string& tag, + Summary* s); +absl::Status AddTensorAsImageToSummary(const Tensor& tensor, const string& tag, + int max_images, const Tensor& bad_color, + Summary* s); +absl::Status AddTensorAsAudioToSummary(const Tensor& tensor, const string& tag, + int max_outputs, float sample_rate, + Summary* s); } // namespace tensorflow diff --git a/tensorflow/core/summary/summary_db_writer.cc b/tensorflow/core/summary/summary_db_writer.cc index 8772ac1c8a25d1..b2d12f5785f7af 100644 --- a/tensorflow/core/summary/summary_db_writer.cc +++ b/tensorflow/core/summary/summary_db_writer.cc @@ -99,7 +99,7 @@ string StringifyShape(const TensorShape& shape) { return result; } -Status CheckSupportedType(const Tensor& t) { +absl::Status CheckSupportedType(const Tensor& t) { #define CASE(T) \ case DataTypeToEnum::value: \ break; @@ -136,7 +136,8 @@ void PatchPluginName(SummaryMetadata* metadata, const char* name) { } } -Status SetDescription(Sqlite* db, int64_t id, const StringPiece& markdown) { +absl::Status SetDescription(Sqlite* db, int64_t id, + const StringPiece& markdown) { const char* sql = R"sql( INSERT OR REPLACE INTO Descriptions (id, description) VALUES (?, ?) )sql"; @@ -165,9 +166,9 @@ class IdAllocator { DCHECK(db_ != nullptr); } - Status CreateNewId(int64_t* id) TF_LOCKS_EXCLUDED(mu_) { + absl::Status CreateNewId(int64_t* id) TF_LOCKS_EXCLUDED(mu_) { mutex_lock lock(mu_); - Status s; + absl::Status s; SqliteStatement stmt; TF_RETURN_IF_ERROR(db_->Prepare("INSERT INTO Ids (id) VALUES (?)", &stmt)); for (int i = 0; i < kMaxIdCollisions; ++i) { @@ -216,9 +217,9 @@ class IdAllocator { class GraphWriter { public: - static Status Save(Sqlite* db, SqliteTransaction* txn, IdAllocator* ids, - GraphDef* graph, uint64 now, int64_t run_id, - int64_t* graph_id) + static absl::Status Save(Sqlite* db, SqliteTransaction* txn, IdAllocator* ids, + GraphDef* graph, uint64 now, int64_t run_id, + int64_t* graph_id) SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) { TF_RETURN_IF_ERROR(ids->CreateNewId(graph_id)); GraphWriter saver{db, txn, graph, now, *graph_id}; @@ -246,7 +247,7 @@ class GraphWriter { } } - Status SaveNodeInputs() { + absl::Status SaveNodeInputs() { const char* sql = R"sql( INSERT INTO NodeInputs ( graph_id, @@ -298,7 +299,7 @@ class GraphWriter { return absl::OkStatus(); } - Status SaveNodes() { + absl::Status SaveNodes() { const char* sql = R"sql( INSERT INTO Nodes ( graph_id, @@ -333,7 +334,7 @@ class GraphWriter { return absl::OkStatus(); } - Status SaveGraph(int64_t run_id) { + absl::Status SaveGraph(int64_t run_id) { const char* sql = R"sql( INSERT OR REPLACE INTO Graphs ( run_id, @@ -355,7 +356,7 @@ class GraphWriter { return insert.StepAndReset(); } - Status MaybeFlush() { + absl::Status MaybeFlush() { if (unflushed_bytes_ >= kFlushBytes) { TF_RETURN_WITH_CONTEXT_IF_ERROR(txn_->Commit(), "flushing ", unflushed_bytes_, " bytes"); @@ -404,9 +405,9 @@ class RunMetadata { return run_id_; } - Status SetGraph(Sqlite* db, uint64 now, double computed_time, - std::unique_ptr g) SQLITE_TRANSACTIONS_EXCLUDED(*db) - TF_LOCKS_EXCLUDED(mu_) { + absl::Status SetGraph(Sqlite* db, uint64 now, double computed_time, + std::unique_ptr g) + SQLITE_TRANSACTIONS_EXCLUDED(*db) TF_LOCKS_EXCLUDED(mu_) { int64_t run_id; { mutex_lock lock(mu_); @@ -420,9 +421,10 @@ class RunMetadata { return txn.Commit(); } - Status GetTagId(Sqlite* db, uint64 now, double computed_time, - const string& tag_name, int64_t* tag_id, - const SummaryMetadata& metadata) TF_LOCKS_EXCLUDED(mu_) { + absl::Status GetTagId(Sqlite* db, uint64 now, double computed_time, + const string& tag_name, int64_t* tag_id, + const SummaryMetadata& metadata) + TF_LOCKS_EXCLUDED(mu_) { mutex_lock lock(mu_); TF_RETURN_IF_ERROR(InitializeRun(db, now, computed_time)); auto e = tag_ids_.find(tag_name); @@ -466,7 +468,7 @@ class RunMetadata { } private: - Status InitializeUser(Sqlite* db, uint64 now) + absl::Status InitializeUser(Sqlite* db, uint64 now) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (user_id_ != kAbsent || user_name_.empty()) return absl::OkStatus(); const char* get_sql = R"sql( @@ -498,7 +500,8 @@ class RunMetadata { return absl::OkStatus(); } - Status InitializeExperiment(Sqlite* db, uint64 now, double computed_time) + absl::Status InitializeExperiment(Sqlite* db, uint64 now, + double computed_time) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (experiment_name_.empty()) return absl::OkStatus(); if (experiment_id_ == kAbsent) { @@ -565,7 +568,7 @@ class RunMetadata { return absl::OkStatus(); } - Status InitializeRun(Sqlite* db, uint64 now, double computed_time) + absl::Status InitializeRun(Sqlite* db, uint64 now, double computed_time) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (run_name_.empty()) return absl::OkStatus(); TF_RETURN_IF_ERROR(InitializeExperiment(db, now, computed_time)); @@ -635,19 +638,19 @@ class SeriesWriter { DCHECK(series_ > 0); } - Status Append(Sqlite* db, int64_t step, uint64 now, double computed_time, - const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) - TF_LOCKS_EXCLUDED(mu_) { + absl::Status Append(Sqlite* db, int64_t step, uint64 now, + double computed_time, const Tensor& t) + SQLITE_TRANSACTIONS_EXCLUDED(*db) TF_LOCKS_EXCLUDED(mu_) { mutex_lock lock(mu_); if (rowids_.empty()) { - Status s = Reserve(db, t); + absl::Status s = Reserve(db, t); if (!s.ok()) { rowids_.clear(); return s; } } int64_t rowid = rowids_.front(); - Status s = Write(db, rowid, step, computed_time, t); + absl::Status s = Write(db, rowid, step, computed_time, t); if (s.ok()) { ++count_; } @@ -655,7 +658,7 @@ class SeriesWriter { return s; } - Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db) + absl::Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db) TF_LOCKS_EXCLUDED(mu_) { mutex_lock lock(mu_); // Delete unused pre-allocated Tensors. @@ -678,8 +681,9 @@ class SeriesWriter { } private: - Status Write(Sqlite* db, int64_t rowid, int64_t step, double computed_time, - const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) { + absl::Status Write(Sqlite* db, int64_t rowid, int64_t step, + double computed_time, const Tensor& t) + SQLITE_TRANSACTIONS_EXCLUDED(*db) { if (t.dtype() == DT_STRING) { if (t.dims() == 0) { return Update(db, step, computed_time, t, t.scalar()(), rowid); @@ -695,8 +699,8 @@ class SeriesWriter { } } - Status Update(Sqlite* db, int64_t step, double computed_time, const Tensor& t, - const StringPiece& data, int64_t rowid) { + absl::Status Update(Sqlite* db, int64_t step, double computed_time, + const Tensor& t, const StringPiece& data, int64_t rowid) { const char* sql = R"sql( UPDATE OR REPLACE Tensors @@ -721,7 +725,7 @@ class SeriesWriter { return absl::OkStatus(); } - Status UpdateNdString(Sqlite* db, const Tensor& t, int64_t tensor_rowid) + absl::Status UpdateNdString(Sqlite* db, const Tensor& t, int64_t tensor_rowid) SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) { DCHECK_EQ(t.dtype(), DT_STRING); DCHECK_GT(t.dims(), 0); @@ -751,8 +755,8 @@ class SeriesWriter { return absl::OkStatus(); } - Status Reserve(Sqlite* db, const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) - TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + absl::Status Reserve(Sqlite* db, const Tensor& t) + SQLITE_TRANSACTIONS_EXCLUDED(*db) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { SqliteTransaction txn(*db); // only for performance unflushed_bytes_ = 0; if (t.dtype() == DT_STRING) { @@ -767,7 +771,7 @@ class SeriesWriter { return txn.Commit(); } - Status ReserveData(Sqlite* db, SqliteTransaction* txn, size_t size) + absl::Status ReserveData(Sqlite* db, SqliteTransaction* txn, size_t size) SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { int64_t space = @@ -776,8 +780,8 @@ class SeriesWriter { return ReserveTensors(db, txn, space); } - Status ReserveTensors(Sqlite* db, SqliteTransaction* txn, - int64_t reserved_bytes) + absl::Status ReserveTensors(Sqlite* db, SqliteTransaction* txn, + int64_t reserved_bytes) SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { const char* sql = R"sql( @@ -802,7 +806,7 @@ class SeriesWriter { return absl::OkStatus(); } - Status MaybeFlush(Sqlite* db, SqliteTransaction* txn) + absl::Status MaybeFlush(Sqlite* db, SqliteTransaction* txn) SQLITE_EXCLUSIVE_TRANSACTIONS_REQUIRED(*db) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (unflushed_bytes_ >= kFlushBytes) { @@ -835,14 +839,14 @@ class RunWriter { public: explicit RunWriter(RunMetadata* meta) : meta_{meta} {} - Status Append(Sqlite* db, int64_t tag_id, int64_t step, uint64 now, - double computed_time, const Tensor& t) + absl::Status Append(Sqlite* db, int64_t tag_id, int64_t step, uint64 now, + double computed_time, const Tensor& t) SQLITE_TRANSACTIONS_EXCLUDED(*db) TF_LOCKS_EXCLUDED(mu_) { SeriesWriter* writer = GetSeriesWriter(tag_id); return writer->Append(db, step, now, computed_time, t); } - Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db) + absl::Status Finish(Sqlite* db) SQLITE_TRANSACTIONS_EXCLUDED(*db) TF_LOCKS_EXCLUDED(mu_) { mutex_lock lock(mu_); if (series_writers_.empty()) return absl::OkStatus(); @@ -896,7 +900,7 @@ class SummaryDbWriter : public SummaryWriterInterface { ~SummaryDbWriter() override { core::ScopedUnref unref(db_); - Status s = run_.Finish(db_); + absl::Status s = run_.Finish(db_); if (!s.ok()) { // TODO(jart): Retry on transient errors here. LOG(ERROR) << s; @@ -918,10 +922,10 @@ class SummaryDbWriter : public SummaryWriterInterface { } } - Status Flush() override { return absl::OkStatus(); } + absl::Status Flush() override { return absl::OkStatus(); } - Status WriteTensor(int64_t global_step, Tensor t, const string& tag, - const string& serialized_metadata) override { + absl::Status WriteTensor(int64_t global_step, Tensor t, const string& tag, + const string& serialized_metadata) override { TF_RETURN_IF_ERROR(CheckSupportedType(t)); SummaryMetadata metadata; if (!metadata.ParseFromString(serialized_metadata)) { @@ -930,25 +934,26 @@ class SummaryDbWriter : public SummaryWriterInterface { return Write(global_step, t, tag, metadata); } - Status WriteScalar(int64_t global_step, Tensor t, - const string& tag) override { + absl::Status WriteScalar(int64_t global_step, Tensor t, + const string& tag) override { TF_RETURN_IF_ERROR(CheckSupportedType(t)); SummaryMetadata metadata; PatchPluginName(&metadata, kScalarPluginName); return Write(global_step, AsScalar(t), tag, metadata); } - Status WriteGraph(int64_t global_step, std::unique_ptr g) override { + absl::Status WriteGraph(int64_t global_step, + std::unique_ptr g) override { uint64 now = env_->NowMicros(); return meta_.SetGraph(db_, now, DoubleTime(now), std::move(g)); } - Status WriteEvent(std::unique_ptr e) override { + absl::Status WriteEvent(std::unique_ptr e) override { return MigrateEvent(std::move(e)); } - Status WriteHistogram(int64_t global_step, Tensor t, - const string& tag) override { + absl::Status WriteHistogram(int64_t global_step, Tensor t, + const string& tag) override { uint64 now = env_->NowMicros(); std::unique_ptr e{new Event}; e->set_step(global_step); @@ -958,8 +963,8 @@ class SummaryDbWriter : public SummaryWriterInterface { return MigrateEvent(std::move(e)); } - Status WriteImage(int64_t global_step, Tensor t, const string& tag, - int max_images, Tensor bad_color) override { + absl::Status WriteImage(int64_t global_step, Tensor t, const string& tag, + int max_images, Tensor bad_color) override { uint64 now = env_->NowMicros(); std::unique_ptr e{new Event}; e->set_step(global_step); @@ -969,8 +974,8 @@ class SummaryDbWriter : public SummaryWriterInterface { return MigrateEvent(std::move(e)); } - Status WriteAudio(int64_t global_step, Tensor t, const string& tag, - int max_outputs, float sample_rate) override { + absl::Status WriteAudio(int64_t global_step, Tensor t, const string& tag, + int max_outputs, float sample_rate) override { uint64 now = env_->NowMicros(); std::unique_ptr e{new Event}; e->set_step(global_step); @@ -983,8 +988,8 @@ class SummaryDbWriter : public SummaryWriterInterface { string DebugString() const override { return "SummaryDbWriter"; } private: - Status Write(int64_t step, const Tensor& t, const string& tag, - const SummaryMetadata& metadata) { + absl::Status Write(int64_t step, const Tensor& t, const string& tag, + const SummaryMetadata& metadata) { uint64 now = env_->NowMicros(); double computed_time = DoubleTime(now); int64_t tag_id; @@ -997,7 +1002,7 @@ class SummaryDbWriter : public SummaryWriterInterface { return absl::OkStatus(); } - Status MigrateEvent(std::unique_ptr e) { + absl::Status MigrateEvent(std::unique_ptr e) { switch (e->what_case()) { case Event::WhatCase::kSummary: { uint64 now = env_->NowMicros(); @@ -1024,7 +1029,7 @@ class SummaryDbWriter : public SummaryWriterInterface { return absl::OkStatus(); } - Status MigrateGraph(const Event* e, const string& graph_def) { + absl::Status MigrateGraph(const Event* e, const string& graph_def) { uint64 now = env_->NowMicros(); std::unique_ptr graph{new GraphDef}; if (!ParseProtoUnlimited(graph.get(), graph_def)) { @@ -1033,7 +1038,7 @@ class SummaryDbWriter : public SummaryWriterInterface { return meta_.SetGraph(db_, now, e->wall_time(), std::move(graph)); } - Status MigrateSummary(const Event* e, Summary::Value* s, uint64 now) { + absl::Status MigrateSummary(const Event* e, Summary::Value* s, uint64 now) { switch (s->value_case()) { case Summary::Value::ValueCase::kTensor: TF_RETURN_WITH_CONTEXT_IF_ERROR(MigrateTensor(e, s, now), "tensor"); @@ -1056,7 +1061,7 @@ class SummaryDbWriter : public SummaryWriterInterface { return absl::OkStatus(); } - Status MigrateTensor(const Event* e, Summary::Value* s, uint64 now) { + absl::Status MigrateTensor(const Event* e, Summary::Value* s, uint64 now) { Tensor t; if (!t.FromProto(s->tensor())) return errors::InvalidArgument("bad proto"); TF_RETURN_IF_ERROR(CheckSupportedType(t)); @@ -1068,7 +1073,7 @@ class SummaryDbWriter : public SummaryWriterInterface { // TODO(jart): Refactor Summary -> Tensor logic into separate file. - Status MigrateScalar(const Event* e, Summary::Value* s, uint64 now) { + absl::Status MigrateScalar(const Event* e, Summary::Value* s, uint64 now) { // See tensorboard/plugins/scalar/summary.py and data_compat.py Tensor t{DT_FLOAT, {}}; t.scalar()() = s->simple_value(); @@ -1079,7 +1084,7 @@ class SummaryDbWriter : public SummaryWriterInterface { return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); } - Status MigrateHistogram(const Event* e, Summary::Value* s, uint64 now) { + absl::Status MigrateHistogram(const Event* e, Summary::Value* s, uint64 now) { const HistogramProto& histo = s->histo(); int k = histo.bucket_size(); if (k != histo.bucket_limit_size()) { @@ -1110,7 +1115,7 @@ class SummaryDbWriter : public SummaryWriterInterface { return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); } - Status MigrateImage(const Event* e, Summary::Value* s, uint64 now) { + absl::Status MigrateImage(const Event* e, Summary::Value* s, uint64 now) { // See tensorboard/plugins/image/summary.py and data_compat.py Tensor t{DT_STRING, {3}}; auto img = s->mutable_image(); @@ -1124,7 +1129,7 @@ class SummaryDbWriter : public SummaryWriterInterface { return run_.Append(db_, tag_id, e->step(), now, e->wall_time(), t); } - Status MigrateAudio(const Event* e, Summary::Value* s, uint64 now) { + absl::Status MigrateAudio(const Event* e, Summary::Value* s, uint64 now) { // See tensorboard/plugins/audio/summary.py and data_compat.py Tensor t{DT_STRING, {1, 2}}; auto wav = s->mutable_audio(); @@ -1146,9 +1151,10 @@ class SummaryDbWriter : public SummaryWriterInterface { } // namespace -Status CreateSummaryDbWriter(Sqlite* db, const string& experiment_name, - const string& run_name, const string& user_name, - Env* env, SummaryWriterInterface** result) { +absl::Status CreateSummaryDbWriter(Sqlite* db, const string& experiment_name, + const string& run_name, + const string& user_name, Env* env, + SummaryWriterInterface** result) { *result = new SummaryDbWriter(env, db, experiment_name, run_name, user_name); return absl::OkStatus(); } diff --git a/tensorflow/core/summary/summary_db_writer.h b/tensorflow/core/summary/summary_db_writer.h index 5669afe7f67e10..9b4644b91bde24 100644 --- a/tensorflow/core/summary/summary_db_writer.h +++ b/tensorflow/core/summary/summary_db_writer.h @@ -33,9 +33,10 @@ namespace tensorflow { /// the future if support for other DBs is added to core. /// /// The result holds a new reference to db. -Status CreateSummaryDbWriter(Sqlite* db, const string& experiment_name, - const string& run_name, const string& user_name, - Env* env, SummaryWriterInterface** result); +absl::Status CreateSummaryDbWriter(Sqlite* db, const string& experiment_name, + const string& run_name, + const string& user_name, Env* env, + SummaryWriterInterface** result); } // namespace tensorflow diff --git a/tensorflow/core/summary/summary_db_writer_test.cc b/tensorflow/core/summary/summary_db_writer_test.cc index 722e99516cbfb3..8ddf4ebae66a48 100644 --- a/tensorflow/core/summary/summary_db_writer_test.cc +++ b/tensorflow/core/summary/summary_db_writer_test.cc @@ -65,7 +65,7 @@ class SummaryDbWriterTest : public ::testing::Test { int64_t QueryInt(const string& sql) { SqliteStatement stmt = db_->PrepareOrDie(sql); bool is_done; - Status s = stmt.Step(&is_done); + absl::Status s = stmt.Step(&is_done); if (!s.ok() || is_done) { LOG(ERROR) << s << " due to " << sql; return -1; @@ -76,7 +76,7 @@ class SummaryDbWriterTest : public ::testing::Test { double QueryDouble(const string& sql) { SqliteStatement stmt = db_->PrepareOrDie(sql); bool is_done; - Status s = stmt.Step(&is_done); + absl::Status s = stmt.Step(&is_done); if (!s.ok() || is_done) { LOG(ERROR) << s << " due to " << sql; return -1; @@ -87,7 +87,7 @@ class SummaryDbWriterTest : public ::testing::Test { string QueryString(const string& sql) { SqliteStatement stmt = db_->PrepareOrDie(sql); bool is_done; - Status s = stmt.Step(&is_done); + absl::Status s = stmt.Step(&is_done); if (!s.ok() || is_done) { LOG(ERROR) << s << " due to " << sql; return "MISSINGNO"; diff --git a/tensorflow/core/summary/summary_file_writer.cc b/tensorflow/core/summary/summary_file_writer.cc index 69903c6e791ede..89d6c2fb76ef4f 100644 --- a/tensorflow/core/summary/summary_file_writer.cc +++ b/tensorflow/core/summary/summary_file_writer.cc @@ -39,8 +39,8 @@ class SummaryFileWriter : public SummaryWriterInterface { flush_millis_(flush_millis), env_(env) {} - Status Initialize(const string& logdir, const string& filename_suffix) { - const Status is_dir = env_->IsDirectory(logdir); + absl::Status Initialize(const string& logdir, const string& filename_suffix) { + const absl::Status is_dir = env_->IsDirectory(logdir); if (!is_dir.ok()) { if (is_dir.code() != tensorflow::error::NOT_FOUND) { return is_dir; @@ -66,7 +66,7 @@ class SummaryFileWriter : public SummaryWriterInterface { return absl::OkStatus(); } - Status Flush() override { + absl::Status Flush() override { mutex_lock ml(mu_); if (!is_initialized_) { return errors::FailedPrecondition("Class was not properly initialized."); @@ -78,8 +78,8 @@ class SummaryFileWriter : public SummaryWriterInterface { (void)Flush(); // Ignore errors. } - Status WriteTensor(int64_t global_step, Tensor t, const string& tag, - const string& serialized_metadata) override { + absl::Status WriteTensor(int64_t global_step, Tensor t, const string& tag, + const string& serialized_metadata) override { std::unique_ptr e{new Event}; e->set_step(global_step); e->set_wall_time(GetWallTime()); @@ -101,8 +101,8 @@ class SummaryFileWriter : public SummaryWriterInterface { return WriteEvent(std::move(e)); } - Status WriteScalar(int64_t global_step, Tensor t, - const string& tag) override { + absl::Status WriteScalar(int64_t global_step, Tensor t, + const string& tag) override { std::unique_ptr e{new Event}; e->set_step(global_step); e->set_wall_time(GetWallTime()); @@ -111,8 +111,8 @@ class SummaryFileWriter : public SummaryWriterInterface { return WriteEvent(std::move(e)); } - Status WriteHistogram(int64_t global_step, Tensor t, - const string& tag) override { + absl::Status WriteHistogram(int64_t global_step, Tensor t, + const string& tag) override { std::unique_ptr e{new Event}; e->set_step(global_step); e->set_wall_time(GetWallTime()); @@ -121,8 +121,8 @@ class SummaryFileWriter : public SummaryWriterInterface { return WriteEvent(std::move(e)); } - Status WriteImage(int64_t global_step, Tensor t, const string& tag, - int max_images, Tensor bad_color) override { + absl::Status WriteImage(int64_t global_step, Tensor t, const string& tag, + int max_images, Tensor bad_color) override { std::unique_ptr e{new Event}; e->set_step(global_step); e->set_wall_time(GetWallTime()); @@ -131,8 +131,8 @@ class SummaryFileWriter : public SummaryWriterInterface { return WriteEvent(std::move(e)); } - Status WriteAudio(int64_t global_step, Tensor t, const string& tag, - int max_outputs, float sample_rate) override { + absl::Status WriteAudio(int64_t global_step, Tensor t, const string& tag, + int max_outputs, float sample_rate) override { std::unique_ptr e{new Event}; e->set_step(global_step); e->set_wall_time(GetWallTime()); @@ -141,8 +141,8 @@ class SummaryFileWriter : public SummaryWriterInterface { return WriteEvent(std::move(e)); } - Status WriteGraph(int64_t global_step, - std::unique_ptr graph) override { + absl::Status WriteGraph(int64_t global_step, + std::unique_ptr graph) override { std::unique_ptr e{new Event}; e->set_step(global_step); e->set_wall_time(GetWallTime()); @@ -150,7 +150,7 @@ class SummaryFileWriter : public SummaryWriterInterface { return WriteEvent(std::move(e)); } - Status WriteEvent(std::unique_ptr event) override { + absl::Status WriteEvent(std::unique_ptr event) override { mutex_lock ml(mu_); queue_.emplace_back(std::move(event)); if (queue_.size() > max_queue_ || @@ -167,7 +167,7 @@ class SummaryFileWriter : public SummaryWriterInterface { return static_cast(env_->NowMicros()) / 1.0e6; } - Status InternalFlush() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + absl::Status InternalFlush() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { for (const std::unique_ptr& e : queue_) { events_writer_->WriteEvent(*e); } @@ -193,12 +193,12 @@ class SummaryFileWriter : public SummaryWriterInterface { } // namespace -Status CreateSummaryFileWriter(int max_queue, int flush_millis, - const string& logdir, - const string& filename_suffix, Env* env, - SummaryWriterInterface** result) { +absl::Status CreateSummaryFileWriter(int max_queue, int flush_millis, + const string& logdir, + const string& filename_suffix, Env* env, + SummaryWriterInterface** result) { SummaryFileWriter* w = new SummaryFileWriter(max_queue, flush_millis, env); - const Status s = w->Initialize(logdir, filename_suffix); + const absl::Status s = w->Initialize(logdir, filename_suffix); if (!s.ok()) { w->Unref(); *result = nullptr; diff --git a/tensorflow/core/summary/summary_file_writer.h b/tensorflow/core/summary/summary_file_writer.h index 7d964516da3cee..6d58438de81b7a 100644 --- a/tensorflow/core/summary/summary_file_writer.h +++ b/tensorflow/core/summary/summary_file_writer.h @@ -33,10 +33,10 @@ namespace tensorflow { /// filename_suffix. The caller owns a reference to result if the /// returned status is ok. The Env object must not be destroyed until /// after the returned writer. -Status CreateSummaryFileWriter(int max_queue, int flush_millis, - const string& logdir, - const string& filename_suffix, Env* env, - SummaryWriterInterface** result); +absl::Status CreateSummaryFileWriter(int max_queue, int flush_millis, + const string& logdir, + const string& filename_suffix, Env* env, + SummaryWriterInterface** result); } // namespace tensorflow diff --git a/tensorflow/core/summary/summary_file_writer_test.cc b/tensorflow/core/summary/summary_file_writer_test.cc index a62876e82e0537..84f209f10256a8 100644 --- a/tensorflow/core/summary/summary_file_writer_test.cc +++ b/tensorflow/core/summary/summary_file_writer_test.cc @@ -42,9 +42,9 @@ class FakeClockEnv : public EnvWrapper { class SummaryFileWriterTest : public ::testing::Test { protected: - Status SummaryTestHelper( + absl::Status SummaryTestHelper( const string& test_name, - const std::function& writer_fn, + const std::function& writer_fn, const std::function& test_fn) { static std::set* tests = new std::set(); CHECK(tests->insert(test_name).second) << ": " << test_name; @@ -166,7 +166,7 @@ namespace { // Create a 1x1 monochrome image consisting of a single pixel oof the given // type. template -static Status CreateImage(SummaryWriterInterface* writer) { +static absl::Status CreateImage(SummaryWriterInterface* writer) { Tensor bad_color(DT_UINT8, TensorShape({1})); bad_color.scalar()() = 0; Tensor one(DataTypeToEnum::v(), TensorShape({1, 1, 1, 1})); diff --git a/tensorflow/core/tfrt/common/BUILD b/tensorflow/core/tfrt/common/BUILD index eb9724f707ec5f..248a3821085b5d 100644 --- a/tensorflow/core/tfrt/common/BUILD +++ b/tensorflow/core/tfrt/common/BUILD @@ -178,6 +178,7 @@ tf_cc_test( "@local_tsl//tsl/platform:statusor", "@local_xla//xla/pjrt:pjrt_client", "@local_xla//xla/pjrt/cpu:cpu_client", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -192,9 +193,9 @@ tf_cc_test( "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", "@local_xla//xla/pjrt/cpu:cpu_client", "@local_xla//xla/tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -216,11 +217,11 @@ tf_cuda_cc_test( "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", "@local_xla//xla/pjrt:pjrt_client", "@local_xla//xla/pjrt:tfrt_cpu_pjrt_client", "@local_xla//xla/service:gpu_plugin", "@local_xla//xla/tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) diff --git a/tensorflow/core/tfrt/common/pjrt_state_test.cc b/tensorflow/core/tfrt/common/pjrt_state_test.cc index 03dcdb7c8b9c23..3d40545051115e 100644 --- a/tensorflow/core/tfrt/common/pjrt_state_test.cc +++ b/tensorflow/core/tfrt/common/pjrt_state_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "xla/pjrt/cpu/cpu_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/protobuf/error_codes.pb.h" diff --git a/tensorflow/core/tfrt/common/pjrt_util_test.cc b/tensorflow/core/tfrt/common/pjrt_util_test.cc index 48f774388d355d..a13bd3d62c4b86 100644 --- a/tensorflow/core/tfrt/common/pjrt_util_test.cc +++ b/tensorflow/core/tfrt/common/pjrt_util_test.cc @@ -19,12 +19,12 @@ limitations under the License. #include "xla/pjrt/cpu/cpu_client.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/tfrt/common/pjrt_state.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/tfrt/fallback/BUILD b/tensorflow/core/tfrt/fallback/BUILD index ebdb44eeef8cc4..a7eedfa43bbe89 100644 --- a/tensorflow/core/tfrt/fallback/BUILD +++ b/tensorflow/core/tfrt/fallback/BUILD @@ -30,6 +30,7 @@ package_group( # copybara:uncomment "//learning/brain/tfrt/...", # copybara:uncomment "//learning/brain/mobile/lite/delegates/tfmrt/...", # copybara:uncomment "//learning/infra/mira/distributed/...", + # copybara:uncomment "//learning/infra/mira/experimental/orbax_model/...", ], ) @@ -84,10 +85,8 @@ cc_library( hdrs = ["op_kernel_runner.h"], features = tf_features_nolayering_check_if_ios(), visibility = [ - ":friends", # copybara:uncomment "//tensorflow/core/runtime_fallback:internal", - "//tensorflow/core/tfrt/graph_executor:__subpackages__", - "//tensorflow/lite/delegates/flex:__pkg__", + "//visibility:public", ], deps = [ "@com_google_absl//absl/container:inlined_vector", diff --git a/tensorflow/core/tfrt/fallback/op_kernel_runner.cc b/tensorflow/core/tfrt/fallback/op_kernel_runner.cc index 01439a4e92a746..5766387a74b82a 100644 --- a/tensorflow/core/tfrt/fallback/op_kernel_runner.cc +++ b/tensorflow/core/tfrt/fallback/op_kernel_runner.cc @@ -133,6 +133,11 @@ absl::StatusOr OpKernelRunner::Create( std::unique_ptr op_kernel; TF_RETURN_IF_ERROR(CreateOpKernel(function_library_runtime, std::move(node_def), &op_kernel)); + + if (!op_kernel) { + return absl::InternalError( + absl::StrCat("Failed to create OpKernel for op: ", op_name)); + } return OpKernelRunner(device, function_library_runtime, std::move(op_kernel)); } diff --git a/tensorflow/core/tfrt/graph_executor/BUILD b/tensorflow/core/tfrt/graph_executor/BUILD index 61d869fe2a767b..3c10ed139b9bbf 100644 --- a/tensorflow/core/tfrt/graph_executor/BUILD +++ b/tensorflow/core/tfrt/graph_executor/BUILD @@ -12,8 +12,10 @@ package( package_group( name = "friends", packages = [ + # copybara:uncomment "//cloud/ai/platform/prediction/...", # copybara:uncomment "//learning/brain/experimental/tfrt/native_lowering/...", # copybara:uncomment "//learning/brain/tfrt/...", + # copybara:uncomment "//learning/infra/mira/experimental/orbax_model/...", # copybara:uncomment "//learning/serving/servables/tfrt/...", # copybara:uncomment "//smartass/brain/inference/...", # copybara:uncomment "//tensorflow/compiler/mlir/tfrt/...", @@ -77,6 +79,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow:import_model", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tf2xla/api/v2:graph_to_tf_executor", "//tensorflow/compiler/mlir/tfrt:backend_compiler", "//tensorflow/compiler/mlir/tfrt:import_model", "//tensorflow/compiler/mlir/tfrt:tfrt_compile_options", @@ -221,7 +224,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "config_proto_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":config_proto"], # ) diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor.cc b/tensorflow/core/tfrt/graph_executor/graph_executor.cc index 23e90e00e6fae1..632cab735a1c46 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_executor.cc +++ b/tensorflow/core/tfrt/graph_executor/graph_executor.cc @@ -51,6 +51,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.h" #include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.h" #include "tensorflow/compiler/mlir/tfrt/transforms/update_op_cost_in_tfrt_mlir.h" #include "tensorflow/compiler/mlir/tfrt/translate/import_model.h" @@ -141,7 +142,7 @@ auto* graph_executor_mode = monitoring::Gauge::New( } // namespace -tensorflow::Status RunMlrtFunction( +absl::Status RunMlrtFunction( mlrt::bc::Function function, const mlrt::LoadedExecutable& loaded_executable, const tsl::RCReference& request_context, @@ -301,7 +302,7 @@ absl::StatusOr> CreateRequestInfo( return request_info; } -tensorflow::Status GraphExecutionRunOnFunction( +absl::Status GraphExecutionRunOnFunction( const GraphExecutionOptions& options, const GraphExecutionRunOptions& run_options, absl::string_view signature_name, const SymbolUids& symbol_uids, @@ -562,7 +563,7 @@ void CreateSortedNamesAndOriginalIndices(absl::Span names, } // namespace -tensorflow::Status GraphExecutor::Run( +absl::Status GraphExecutor::Run( const RunOptions& run_options, absl::Span> inputs, absl::Span output_tensor_names, @@ -677,7 +678,7 @@ tensorflow::Status GraphExecutor::Run( return absl::OkStatus(); } -tensorflow::Status GraphExecutor::Extend(const GraphDef& graph) { +absl::Status GraphExecutor::Extend(const GraphDef& graph) { return graph_execution_state_->Extend(graph); } @@ -857,15 +858,15 @@ GraphExecutor::ImportClientGraphToMlirModule( // Convert the optimized graph to an MLIR module. TF_ASSIGN_OR_RETURN( auto module, - tensorflow::ConvertGraphToMlir(*optimized_graph.graph, /*debug_info=*/{}, - optimized_graph.graph->flib_def(), - graph_import_config, context)); + tensorflow::tf2xla::v2::ConvertGraphToTfExecutor( + *optimized_graph.graph, /*debug_info=*/{}, + optimized_graph.graph->flib_def(), graph_import_config, context)); return std::make_pair(std::move(*optimized_graph.graph->mutable_flib_def()), std::move(module)); } -tensorflow::Status GraphExecutor::InitBef( +absl::Status GraphExecutor::InitBef( LoadedClientGraph* loaded_client_graph, tensorflow::tfrt_stub::WorkQueueInterface* work_queue) { auto* bef_file = loaded_client_graph->executable_context()->bef_file.get(); @@ -895,8 +896,7 @@ tensorflow::Status GraphExecutor::InitBef( return absl::OkStatus(); } -tensorflow::Status GraphExecutor::InitBytecode( - LoadedClientGraph* loaded_graph) { +absl::Status GraphExecutor::InitBytecode(LoadedClientGraph* loaded_graph) { TF_ASSIGN_OR_RETURN( auto request_info, CreateRequestInfo(options_, /*run_options=*/{}, @@ -990,7 +990,7 @@ GraphExecutor::GetOrCreateLoadedClientGraph( return {*loaded_client_graph_ptr}; } -tensorflow::Status GraphExecutor::RunWithSyncInterpreter( +absl::Status GraphExecutor::RunWithSyncInterpreter( const std::string& graph_name, absl::Span input_values, absl::Span input_names, absl::Span input_dtypes, @@ -1058,7 +1058,7 @@ CostRecorder* GraphExecutor::LoadedClientGraph::MaybeGetCostRecorder( return nullptr; } -Status GraphExecutor::LoadedClientGraph::UpdateCost( +absl::Status GraphExecutor::LoadedClientGraph::UpdateCost( const CostRecorder& cost_recorder, const Runtime& runtime) { LOG(INFO) << "TFRT updating op costs of loaded client graph (" << this << ") " << name_; @@ -1184,7 +1184,7 @@ void GraphExecutor::LoadedClientGraph::UpdateCostAnalysisData( } } -tensorflow::Status GraphExecutor::CompileGraph( +absl::Status GraphExecutor::CompileGraph( const std::string& graph_name, absl::Span input_tensor_names, absl::Span input_tensor_dtypes, diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor.h b/tensorflow/core/tfrt/graph_executor/graph_executor.h index 936d4ca891b441..5fa98297520822 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_executor.h +++ b/tensorflow/core/tfrt/graph_executor/graph_executor.h @@ -114,7 +114,7 @@ absl::StatusOr> CreateRequestInfo( // // TODO(chky): Refactor this function to take `LoadedClientGraph` instead of // having a long list of parameters. -tensorflow::Status GraphExecutionRunOnFunction( +absl::Status GraphExecutionRunOnFunction( const GraphExecutionOptions& options, const GraphExecutionRunOptions& run_options, absl::string_view signature_name, const SymbolUids& symbol_uids, @@ -133,7 +133,7 @@ tensorflow::Status GraphExecutionRunOnFunction( CostRecorder* cost_recorder = nullptr); // Runs a MLRT function for executing tensorflow graphs. -tensorflow::Status RunMlrtFunction( +absl::Status RunMlrtFunction( mlrt::bc::Function function, const mlrt::LoadedExecutable& loaded_executable, const tsl::RCReference& request_context, @@ -168,8 +168,8 @@ class GraphExecutor { CostRecorder* MaybeGetCostRecorder(absl::Time now, bool* do_recompilation); // Updates the op cost values in this `LoadedClientGraph` with records from // `cost_recorder`. - Status UpdateCost(const CostRecorder& cost_recorder, - const Runtime& runtime); + absl::Status UpdateCost(const CostRecorder& cost_recorder, + const Runtime& runtime); // Updates `cost_analysis_data_` to make it accurate for the next execution. // Assumes a cost update occurred this cycle. void UpdateCostAnalysisData(absl::Time now, bool do_recompilation); @@ -267,7 +267,7 @@ class GraphExecutor { std::unique_ptr kernel_registry); // Runs on the graph according to given input/output. - tensorflow::Status Run( + absl::Status Run( const RunOptions& run_options, absl::Span> inputs, absl::Span output_tensor_names, @@ -279,7 +279,7 @@ class GraphExecutor { // responsibility to ensure `graph_name` corresponds to logically different // graphs, since this name is used to lookup compiled graphs in the cache. The // graph is run synchronously with the TFRT interpreter. - tensorflow::Status RunWithSyncInterpreter( + absl::Status RunWithSyncInterpreter( const std::string& graph_name, absl::Span input_values, absl::Span input_names, absl::Span input_dtypes, @@ -288,7 +288,7 @@ class GraphExecutor { absl::Span outputs); // Extends the current graph by `graph`. - tensorflow::Status Extend(const GraphDef& graph); + absl::Status Extend(const GraphDef& graph); tensorflow::tfrt_stub::TfrtGraphExecutionState& graph_execution_state() const { @@ -308,7 +308,7 @@ class GraphExecutor { FallbackState& fallback_state() { return *fallback_state_; } // Compiles graph for `graph_name` and runs any initializers. - tensorflow::Status CompileGraph( + absl::Status CompileGraph( const std::string& graph_name, absl::Span input_tensor_names, absl::Span input_tensor_dtypes, @@ -337,11 +337,10 @@ class GraphExecutor { absl::StatusOr CompileMlirModuleToBef( mlir::ModuleOp module) const; - tensorflow::Status InitBef( - LoadedClientGraph* loaded_client_graph, - tensorflow::tfrt_stub::WorkQueueInterface* work_queue); + absl::Status InitBef(LoadedClientGraph* loaded_client_graph, + tensorflow::tfrt_stub::WorkQueueInterface* work_queue); - tensorflow::Status InitBytecode(LoadedClientGraph* loaded_graph); + absl::Status InitBytecode(LoadedClientGraph* loaded_graph); // Returns a `LoadedClientGraph` given input/output tensor info. If there is // no existing one yet, creates one first. diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc b/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc index c0e07b385f763c..08d8bf4c2a89e9 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc +++ b/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc @@ -76,7 +76,7 @@ class GraphExecutorForTestingCostAnalysis : public GraphExecutor { class GraphExecutorTest : public ::testing::TestWithParam {}; -tensorflow::Status GetSimpleGraphDef(GraphDef& graph_def) { +absl::Status GetSimpleGraphDef(GraphDef& graph_def) { auto scope = tensorflow::Scope::NewRootScope().WithDevice("/device:CPU:0"); auto input = ops::Placeholder(scope.WithOpName("input"), DT_INT32); diff --git a/tensorflow/core/tfrt/ifrt/BUILD b/tensorflow/core/tfrt/ifrt/BUILD index 69e129b4353c07..c7204a2756360f 100644 --- a/tensorflow/core/tfrt/ifrt/BUILD +++ b/tensorflow/core/tfrt/ifrt/BUILD @@ -121,11 +121,14 @@ cc_library( ":ifrt_device_utils", ":ifrt_loaded_variable_registry", ":ifrt_loaded_variable_utils", + ":ifrt_persistent_compilation_cache", ":ifrt_restore_tensor_registry", ":ifrt_serving_core_selector", ":ifrt_tensor_utils", ":sharding_utils", ":tf_host_callback", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tfrt:export", "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:extract_callback", "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:ifrt_types", @@ -159,12 +162,13 @@ cc_library( "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/ir:hlo", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", "@local_xla//xla/pjrt:host_callback", + "@local_xla//xla/pjrt:pjrt_compiler", "@local_xla//xla/pjrt:pjrt_executable", "@local_xla//xla/python/ifrt", "@local_xla//xla/python/ifrt/hlo:hlo_program", "@local_xla//xla/python/pjrt_ifrt", - "@local_xla//xla/python/pjrt_ifrt:xla_ifrt", "@local_xla//xla/service:computation_placer_hdr", "@local_xla//xla/tsl/concurrency:ref_count", "@local_xla//xla/tsl/framework:serving_device_selector", @@ -209,6 +213,26 @@ cc_library( ], ) +cc_library( + name = "ifrt_persistent_compilation_cache", + srcs = ["ifrt_persistent_compilation_cache.cc"], + hdrs = ["ifrt_persistent_compilation_cache.h"], + deps = [ + "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:ifrt_types", + "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:tf2hlo", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:IR", + "@local_xla//xla/pjrt:pjrt_executable", + "@local_xla//xla/python/ifrt", + "@local_xla//xla/python/ifrt/hlo:hlo_program", + "@local_xla//xla/python/pjrt_ifrt:xla_ifrt", + "@local_xla//xla/tsl/concurrency:ref_count", + ], +) + cc_library( name = "ifrt_loaded_variable_registry", srcs = ["ifrt_loaded_variable_registry.cc"], @@ -246,6 +270,7 @@ cc_library( deps = [ ":ifrt_executable_registry", ":ifrt_loaded_variable_registry", + ":ifrt_persistent_compilation_cache", ":ifrt_restore_tensor_registry", ":ifrt_serving_core_selector", "//tensorflow/compiler/tf2xla:xla_helpers", @@ -276,10 +301,9 @@ cc_library( "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", - "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/python/ifrt", - "@local_xla//xla/python/pjrt_ifrt", + "@local_xla//xla/python/pjrt_ifrt:pjrt_dtype", ], ) @@ -386,6 +410,7 @@ cc_library( ], deps = [ ":ifrt_loaded_variable_registry", + ":ifrt_persistent_compilation_cache", ":ifrt_restore_tensor_registry", ":ifrt_serving_core_selector", ":ifrt_serving_executable", @@ -619,7 +644,6 @@ cc_library( srcs = ["checkpoint_loader.cc"], hdrs = ["checkpoint_loader.h"], deps = [ - ":ifrt_loaded_variable_registry", ":ifrt_loaded_variable_utils", ":ifrt_restore_tensor_registry", "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:ifrt_types", @@ -629,6 +653,8 @@ cc_library( "//tensorflow/core/framework:node_def_util", "//tensorflow/core/framework:tensor", "//tensorflow/core/framework:types_proto_cc", + "//tensorflow/core/protobuf:for_core_protos_cc", + "//tensorflow/core/tfrt/fallback:fallback_state", "//tensorflow/core/tfrt/fallback:op_kernel_runner", "//tensorflow/core/tfrt/mlrt/bytecode", "//tensorflow/core/tfrt/mlrt/kernel:context", diff --git a/tensorflow/core/tfrt/ifrt/checkpoint_loader.cc b/tensorflow/core/tfrt/ifrt/checkpoint_loader.cc index dc4ad26e7c10f7..1e468b02760310 100644 --- a/tensorflow/core/tfrt/ifrt/checkpoint_loader.cc +++ b/tensorflow/core/tfrt/ifrt/checkpoint_loader.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" #include "xla/python/ifrt/future.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" @@ -42,7 +43,6 @@ limitations under the License. #include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" #include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h" #include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" -#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" #include "tensorflow/core/tfrt/mlrt/kernel/context.h" #include "tensorflow/core/tfrt/mlrt/kernel/kernel_runner_utils.h" #include "tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.h" @@ -138,10 +138,62 @@ absl::StatusOr Cast( return *(op_kernel_context.mutable_output(0)); } +void RunShardHelper(const tfrt_stub::OpKernelRunner& runner, + AsyncState* async_state, RestoreVariableShard shard) { + // Keep input tensor alive in `shard`. + auto* op_kernel_context_ptr = &async_state->context; + runner.Run(op_kernel_context_ptr); + + auto& op_kernel_context = async_state->context; + if (!op_kernel_context.status().ok()) { + for (auto& result : async_state->results) { + std::move(result).Set(op_kernel_context.status()); + } + return; + } + DCHECK_EQ(shard.var_handles.size(), op_kernel_context.num_outputs()); + DCHECK_EQ(shard.truncate_in_cast.size(), op_kernel_context.num_outputs()); + + // TODO(b/343964091): consider to run multiple casts in parallel. + for (int i = 0; i < op_kernel_context.num_outputs(); ++i) { + DCHECK(op_kernel_context.mutable_output(i)); + + if (op_kernel_context.mutable_output(i)->dtype() != + shard.restored_dtypes[i]) { + std::move(async_state->results[i]) + .Set(absl::InvalidArgumentError( + absl::StrCat("The restored tensor has a different dtype than the " + "variable handle: ", + op_kernel_context.mutable_output(i)->dtype(), + " vs. ", shard.restored_dtypes[i]))); + return; + } + const ResourceHandle& var_handle = + shard.var_handles[i].tensor().scalar()(); + + if (shard.restored_dtypes[i] == var_handle.dtypes_and_shapes()[0].dtype) { + std::move(async_state->results[i]) + .Set(*std::move(op_kernel_context.mutable_output(i))); + } else { + absl::StatusOr cast_output = + Cast(*op_kernel_context.mutable_output(i), shard.restored_dtypes[i], + var_handle.dtypes_and_shapes()[0].dtype, + shard.truncate_in_cast[i], async_state->device_manager, + async_state->process_function_library_runtime, + async_state->run_state.params); + if (!cast_output.ok()) { + std::move(async_state->results[i]).Set(cast_output.status()); + } else { + std::move(async_state->results[i]).Set(*std::move(cast_output)); + } + } + } +} + absl::Status RunShard(RestoreVariableShard shard, IfrtRestoreTensorRegistry* ifrt_restore_tensor_registry, tfrt::ConcurrentWorkQueue* checkpoint_loader_work_queue, - tf_mlrt::Context& context) { + tf_mlrt::Context& context, bool use_async_restore) { if (!ifrt_restore_tensor_registry) { return absl::InternalError("ifrt_restore_tensor_registry must not be null"); } @@ -218,60 +270,18 @@ absl::Status RunShard(RestoreVariableShard shard, } async_state->results.push_back(std::move(promise)); } + // Run the shard synchronously. + if (!use_async_restore) { + RunShardHelper(runner, async_state.get(), shard); + } else { + // Use dedicated work queue for restore operation. + checkpoint_loader_work_queue->AddTask([runner = std::move(runner), + async_state = std::move(async_state), + shard = std::move(shard)]() { + RunShardHelper(runner, async_state.get(), shard); + }); + } - // Use dedicated work queue for restore operation. - checkpoint_loader_work_queue->AddTask([runner = std::move(runner), - async_state = std::move(async_state), - shard = std::move(shard)]() { - // Keep input tensor alive in `shard`. - auto* op_kernel_context_ptr = &async_state->context; - runner.Run(op_kernel_context_ptr); - - auto& op_kernel_context = async_state->context; - if (!op_kernel_context.status().ok()) { - for (auto& result : async_state->results) { - std::move(result).Set(op_kernel_context.status()); - } - return; - } - DCHECK_EQ(shard.var_handles.size(), op_kernel_context.num_outputs()); - DCHECK_EQ(shard.truncate_in_cast.size(), op_kernel_context.num_outputs()); - - // TODO(b/343964091): consider to run multiple casts in parallel. - for (int i = 0; i < op_kernel_context.num_outputs(); ++i) { - DCHECK(op_kernel_context.mutable_output(i)); - - if (op_kernel_context.mutable_output(i)->dtype() != - shard.restored_dtypes[i]) { - std::move(async_state->results[i]) - .Set(absl::InvalidArgumentError(absl::StrCat( - "The restored tensor has a different dtype than the " - "variable handle: ", - op_kernel_context.mutable_output(i)->dtype(), " vs. ", - shard.restored_dtypes[i]))); - return; - } - const ResourceHandle& var_handle = - shard.var_handles[i].tensor().scalar()(); - - if (shard.restored_dtypes[i] == var_handle.dtypes_and_shapes()[0].dtype) { - std::move(async_state->results[i]) - .Set(*std::move(op_kernel_context.mutable_output(i))); - } else { - absl::StatusOr cast_output = - Cast(*op_kernel_context.mutable_output(i), shard.restored_dtypes[i], - var_handle.dtypes_and_shapes()[0].dtype, - shard.truncate_in_cast[i], async_state->device_manager, - async_state->process_function_library_runtime, - async_state->run_state.params); - if (!cast_output.ok()) { - std::move(async_state->results[i]).Set(cast_output.status()); - } else { - std::move(async_state->results[i]).Set(*std::move(cast_output)); - } - } - } - }); return absl::OkStatus(); } @@ -286,8 +296,7 @@ int64_t GetSizeFromVarHandle(const ResourceHandle& handle) { } // namespace -absl::Status CheckpointLoader::PrepareRestore( - mlir::OwningOpRef module) { +absl::Status CheckpointLoader::PrepareRestore(const PrepareRestoreArgs& args) { VLOG(1) << "Skip CheckpointLoader::PrepareRestore"; return absl::OkStatus(); } @@ -349,7 +358,8 @@ absl::Status CheckpointLoader::Load( } for (const auto& shard : shards) { TF_RETURN_IF_ERROR(RunShard(shard, ifrt_restore_tensor_registry_, - checkpoint_loader_work_queue_, context)); + checkpoint_loader_work_queue_, context, + use_async_restore_)); } return absl::OkStatus(); } diff --git a/tensorflow/core/tfrt/ifrt/checkpoint_loader.h b/tensorflow/core/tfrt/ifrt/checkpoint_loader.h index 4710c711cfcb0c..e47c78bbba0dc4 100644 --- a/tensorflow/core/tfrt/ifrt/checkpoint_loader.h +++ b/tensorflow/core/tfrt/ifrt/checkpoint_loader.h @@ -23,7 +23,10 @@ limitations under the License. #include "absl/status/status.h" #include "absl/types/span.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/tfrt/fallback/fallback_state.h" #include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" #include "tensorflow/core/tfrt/mlrt/kernel/context.h" @@ -38,15 +41,25 @@ namespace ifrt_serving { // Implement the `CheckpointLoaderInterface` by using RestoreV2. class CheckpointLoader { public: + struct PrepareRestoreArgs { + mlir::MLIRContext* context; + tensorflow::MetaGraphDef meta_graph_def; + tfrt_stub::FallbackState* fallback_state; + std::string saved_model_dir; + bool run_placer_grappler_on_functions; + }; + explicit CheckpointLoader( IfrtRestoreTensorRegistry* ifrt_restore_tensor_registry, - tfrt::ConcurrentWorkQueue* checkpoint_loader_work_queue) + tfrt::ConcurrentWorkQueue* checkpoint_loader_work_queue, + bool use_async_restore = true) : ifrt_restore_tensor_registry_(ifrt_restore_tensor_registry), - checkpoint_loader_work_queue_(checkpoint_loader_work_queue) {} + checkpoint_loader_work_queue_(checkpoint_loader_work_queue), + use_async_restore_(use_async_restore) {} virtual ~CheckpointLoader() = default; // Called before `Load` to do some preparation work. - virtual absl::Status PrepareRestore(mlir::OwningOpRef module); + virtual absl::Status PrepareRestore(const PrepareRestoreArgs& args); // Load the checkpoint. This API is designed to be compatible with the // `tf_mlrt.ifrt_restore_variable` kernel. @@ -61,6 +74,7 @@ class CheckpointLoader { protected: IfrtRestoreTensorRegistry* ifrt_restore_tensor_registry_; tfrt::ConcurrentWorkQueue* checkpoint_loader_work_queue_; + bool use_async_restore_ = true; }; } // namespace ifrt_serving diff --git a/tensorflow/core/tfrt/ifrt/ifrt_executable_registry.h b/tensorflow/core/tfrt/ifrt/ifrt_executable_registry.h index 5275b683027602..25b9e6c3810af7 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_executable_registry.h +++ b/tensorflow/core/tfrt/ifrt/ifrt_executable_registry.h @@ -88,6 +88,7 @@ class ServingExecutableRegistry { private: friend class Handle; + friend class IfrtBackendCompilerTest; static absl::Mutex mu_; diff --git a/tensorflow/core/tfrt/ifrt/ifrt_executable_registry_test.cc b/tensorflow/core/tfrt/ifrt/ifrt_executable_registry_test.cc index 8c5c83cc7a0000..2f3154e064efb7 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_executable_registry_test.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_executable_registry_test.cc @@ -96,7 +96,8 @@ CreateIfrtServingExecutable(mlir::MLIRContext& context, int64_t program_id) { &ifrt_restore_tensor_registry, work_queue.get(), device_mgr.get(), tensorflow::IdentityShapeRepresentationFn(), /*ifrt_serving_core_selector=*/nullptr, - /*compilation_environment_proto=*/nullptr); + /*compilation_environment_proto=*/nullptr, + /*persistent_compilation_cache=*/nullptr); } TEST(IfrtExecutableRegistry, Basic) { diff --git a/tensorflow/core/tfrt/ifrt/ifrt_model_context.h b/tensorflow/core/tfrt/ifrt/ifrt_model_context.h index bc8f802ab8c75c..f5bc05213590a9 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_model_context.h +++ b/tensorflow/core/tfrt/ifrt/ifrt_model_context.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_TFRT_IFRT_IFRT_MODEL_CONTEXT_H_ #define TENSORFLOW_CORE_TFRT_IFRT_IFRT_MODEL_CONTEXT_H_ +#include #include #include #include @@ -25,10 +26,12 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/topology.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/tfrt/ifrt/ifrt_executable_registry.h" #include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_persistent_compilation_cache.h" #include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" #include "tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h" #include "tsl/platform/protobuf.h" @@ -68,7 +71,8 @@ class IfrtModelContext { tsl::thread::ThreadPool* thread_pool, tensorflow::DeviceMgr* device_mgr, tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn, std::unique_ptr compilation_environment_proto, - std::shared_ptr topology) + std::shared_ptr topology, + IfrtPersistentCompilationCache* persistent_compilation_cache = nullptr) : client_(std::move(client)), topology_(topology), ifrt_serving_core_selector_(ifrt_serving_core_selector), @@ -76,7 +80,8 @@ class IfrtModelContext { device_mgr_(device_mgr), shape_representation_fn_(shape_representation_fn), compilation_environment_proto_( - std::move(compilation_environment_proto)) {} + std::move(compilation_environment_proto)), + persistent_compilation_cache_(persistent_compilation_cache) {} void RegisterHandle(ServingExecutableRegistry::Handle handle) { handles_.push_back(std::move(handle)); @@ -105,6 +110,10 @@ class IfrtModelContext { return restore_tensor_registry_; } + IfrtPersistentCompilationCache* GetPersistentCompilationCache() const { + return persistent_compilation_cache_; + } + tensorflow::DeviceMgr* GetDeviceMgr() const { return device_mgr_; } IfrtServingCoreSelector* GetIfrtServingCoreSelector() const { return ifrt_serving_core_selector_; @@ -153,6 +162,7 @@ class IfrtModelContext { IfrtLoadedVariableRegistry loaded_variable_registry_; IfrtRestoreTensorRegistry restore_tensor_registry_; + IfrtPersistentCompilationCache* persistent_compilation_cache_ = nullptr; bool frozen_ = false; }; diff --git a/tensorflow/core/tfrt/ifrt/ifrt_persistent_compilation_cache.cc b/tensorflow/core/tfrt/ifrt/ifrt_persistent_compilation_cache.cc new file mode 100644 index 00000000000000..7c971018575f40 --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/ifrt_persistent_compilation_cache.cc @@ -0,0 +1,70 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/tfrt/ifrt/ifrt_persistent_compilation_cache.h" + +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/hlo/hlo_program.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/program.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/tsl/concurrency/ref_count.h" + +namespace tensorflow { +namespace ifrt_serving { + +absl::StatusOr> +IfrtPersistentCompilationCache::LookupLoadedExecutableOrCreate( + std::unique_ptr hlo_program, + tsl::RCReference device_list, + const xla::CompileOptions& xla_compile_options, + const std::vector>& + loaded_host_callbacks, + xla::ifrt::Client* client, + absl::AnyInvocable< + absl::StatusOr>( + std::unique_ptr program, + std::unique_ptr options)> + value_fn) { + // No persistent cache implemented, compile directly. + auto ifrt_xla_compile_options = + std::make_unique(xla_compile_options, + loaded_host_callbacks); + return value_fn(std::move(hlo_program), std::move(ifrt_xla_compile_options)); + ; +} + +absl::StatusOr +IfrtPersistentCompilationCache::LookupTf2HloResultOrCreate( + Tf2HloArg tf2hlo_arg, tsl::RCReference device_list) { + // No tf2xla persistent cache is implemented, compile directly. + return CompileTfToHlo(tf2hlo_arg); +} + +} // namespace ifrt_serving +} // namespace tensorflow diff --git a/tensorflow/core/tfrt/ifrt/ifrt_persistent_compilation_cache.h b/tensorflow/core/tfrt/ifrt/ifrt_persistent_compilation_cache.h new file mode 100644 index 00000000000000..e1994aea777b98 --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/ifrt_persistent_compilation_cache.h @@ -0,0 +1,76 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_IFRT_IFRT_PERSISTENT_COMPILATION_CACHE_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_IFRT_PERSISTENT_COMPILATION_CACHE_H_ + +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/hlo/hlo_program.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/program.h" +#include "xla/tsl/concurrency/ref_count.h" +namespace tensorflow { +namespace ifrt_serving { + +class IfrtPersistentCompilationCache { + public: + IfrtPersistentCompilationCache() = default; + virtual ~IfrtPersistentCompilationCache() = default; + + // The implementation of this API should be thread-safe. It generates a key + // for looking up the executable in the persistent cache and it will return + // the LoadedExecutable if hits cache. Otherwise, it will call the `value_fn` + // to generate and return the LoadedExecutable. + virtual absl::StatusOr> + LookupLoadedExecutableOrCreate( + std::unique_ptr hlo_program, + tsl::RCReference device_list, + const xla::CompileOptions& xla_compile_options, + const std::vector>& + loaded_host_callbacks, + xla::ifrt::Client* client, + absl::AnyInvocable< + absl::StatusOr>( + std::unique_ptr program, + std::unique_ptr options)> + value_fn); + + // The implementation of this API should be thread-safe. It generates a key + // for looking up the Tf2HloResult in the persistent cache and it will return + // the Tf2HloResult if hits cache. Otherwise, it will call the `value_fn` to + // generate and return the Tf2HloResult. + virtual absl::StatusOr LookupTf2HloResultOrCreate( + Tf2HloArg tf2hlo_arg, + tsl::RCReference device_list); + + virtual bool IsXlaCompilationCacheEnabled() const { return false; } + virtual bool IsTf2HloCompilationCacheEnabled() const { return false; } +}; + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_IFRT_PERSISTENT_COMPILATION_CACHE_H_ diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc index 4bf833133849ee..5a076556e25300 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc @@ -39,7 +39,10 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/extract_callback.h" #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h" @@ -48,7 +51,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "xla/pjrt/host_callback.h" +#include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/client.h" @@ -58,10 +63,10 @@ limitations under the License. #include "xla/python/ifrt/future.h" #include "xla/python/ifrt/hlo/hlo_program.h" #include "xla/python/ifrt/host_callback.h" +#include "xla/python/ifrt/program.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" #include "xla/python/pjrt_ifrt/pjrt_host_callback.h" -#include "xla/python/pjrt_ifrt/xla_compiler.h" #include "xla/service/computation_placer.h" #include "xla/shape.h" #include "xla/tsl/concurrency/ref_count.h" @@ -79,6 +84,7 @@ limitations under the License. #include "tensorflow/core/tfrt/ifrt/ifrt_device_utils.h" #include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h" #include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_persistent_compilation_cache.h" #include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" #include "tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h" #include "tensorflow/core/tfrt/ifrt/ifrt_tensor_utils.h" @@ -184,6 +190,34 @@ absl::StatusOr> GetAssignedDevices( device_assignment_attr_val); } +absl::StatusOr< + absl::flat_hash_map>> +GetHostCallbackModulesAndRemoveHostFuncs(mlir::ModuleOp module) { + absl::flat_hash_map> + host_callback_modules; + llvm::DenseSet xla_host_compute_ops; + module->walk( + [&](mlir::TF::XlaHostComputeOp op) { xla_host_compute_ops.insert(op); }); + for (auto& op : xla_host_compute_ops) { + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef host_callback_module, + ExtractCallbackModule(module, op.getKey().str())); + auto [_, inserted] = host_callback_modules.insert( + {op.getKey().str(), std::move(host_callback_module)}); + if (!inserted) { + return absl::FailedPreconditionError( + absl::StrCat("Duplicate host callback key: ", op.getKey().str())); + } + auto func = mlir::SymbolTable::lookupNearestSymbolFrom( + module, op.getKeyAttr()); + if (!func) { + return absl::InternalError( + absl::StrCat("symbol not found: ", op.getKey().str())); + } + func->erase(); + } + return host_callback_modules; +} + } // namespace absl::StatusOr> @@ -198,7 +232,8 @@ IfrtServingExecutable::Create( tensorflow::DeviceMgr* device_mgr, tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn, IfrtServingCoreSelector* ifrt_serving_core_selector, - tsl::protobuf::Message* compilation_environement_proto) { + tsl::protobuf::Message* compilation_environment_proto, + IfrtPersistentCompilationCache* persistent_compilation_cache) { TF_ASSIGN_OR_RETURN( tensorflow::tpu::TPUCompileMetadataProto original_compile_metadata, GetCompileMetadata(*module, *client)); @@ -217,7 +252,7 @@ IfrtServingExecutable::Create( std::move(original_compile_metadata), xla::ifrt::BasicDeviceList::Create(xla::ifrt::BasicDeviceList::Devices( assigned_devices.begin(), assigned_devices.end())), - compilation_environement_proto)); + compilation_environment_proto, persistent_compilation_cache)); return executable; } @@ -280,7 +315,7 @@ GroupHostCallbackByKey(const Tf2HloResult& tf2hlo_result) { // TODO: shape propagation in module absl::StatusOr BuildHostCallback( absl::string_view key, const HostCallbackBuilderInfo& builder_info, - mlir::ModuleOp module, tensorflow::DeviceMgr* device_mgr, + mlir::ModuleOp callback_module, tensorflow::DeviceMgr* device_mgr, std::vector>& tf_host_callbacks) { VLOG(2) << "BuildHostCallback for key: " << key; @@ -329,13 +364,8 @@ absl::StatusOr BuildHostCallback( DtypeAndShape{.dtype = metadata.type(), .shape = metadata.shape()}); } - // TODO(b/332774825): reuse functions in BEF/MLRT once we switch to - // GraphExecutor. - TF_ASSIGN_OR_RETURN(mlir::OwningOpRef callback_module, - ExtractCallbackModule(module, key)); - TF_ASSIGN_OR_RETURN(std::vector function_defs, - BuildFunctionDef(*callback_module)); + BuildFunctionDef(callback_module)); TF_ASSIGN_OR_RETURN( std::unique_ptr tf_host_callback, @@ -352,7 +382,9 @@ absl::StatusOr BuildHostCallback( } absl::StatusOr> BuildHostCallbacks( - const Tf2HloResult& tf2hlo_result, mlir::ModuleOp module, + const Tf2HloResult& tf2hlo_result, + absl::flat_hash_map> + host_callback_modules, tensorflow::DeviceMgr* device_mgr, std::vector>& tf_host_callbacks) { TF_ASSIGN_OR_RETURN(auto host_callback_maps, @@ -361,8 +393,14 @@ absl::StatusOr> BuildHostCallbacks( std::vector host_callbacks; host_callbacks.reserve(host_callback_maps.size()); for (const auto& [entry_function, builder_info] : host_callback_maps) { + auto host_callback_module_it = host_callback_modules.find(entry_function); + if (host_callback_module_it == host_callback_modules.end()) { + return absl::NotFoundError(absl::StrCat( + "Host callback module not found for key: ", entry_function)); + } TF_ASSIGN_OR_RETURN(auto host_callback, - BuildHostCallback(entry_function, builder_info, module, + BuildHostCallback(entry_function, builder_info, + *host_callback_module_it->second, device_mgr, tf_host_callbacks)); host_callbacks.push_back(std::move(host_callback)); } @@ -375,11 +413,38 @@ IfrtServingExecutable::CreateExecutableSynchronously( mlir::OwningOpRef module_copy, const tensorflow::tpu::TPUCompileMetadataProto& compile_metadata, absl::Span dtypes_and_shapes) { + TF_ASSIGN_OR_RETURN(auto host_callback_modules, + GetHostCallbackModulesAndRemoveHostFuncs(*module_copy)); + if (VLOG_IS_ON(1)) { + tensorflow::DumpMlirOpToFile("module_for_bridge_phase2", *module_copy); + } + Tf2HloArg tf2hlo_arg{ + .module = module_copy.get(), + .input_dtypes_and_shapes = dtypes_and_shapes, + .entry_function_name = signature_name(), + .compile_metadata = compile_metadata, + .shape_representation_fn = shape_representation_fn_, + .platform_name = ifrt_client_->platform_name(), + }; + + if (tf2hlo_arg.platform_name != xla::CudaName()) { + TF_ASSIGN_OR_RETURN( + tf2hlo_arg.topology, + ifrt_client_->GetTopologyForDevices(assigned_device_list_)); + } + + TF_ASSIGN_OR_RETURN(Tf2HloResult tf2hlo_result, + persistent_compilation_cache_->LookupTf2HloResultOrCreate( + tf2hlo_arg, assigned_device_list_)); TF_ASSIGN_OR_RETURN( - Tf2HloResult tf2hlo_result, - CompileTfToHlo(*module_copy, dtypes_and_shapes, signature_name(), - *ifrt_client_, compile_metadata, - shape_representation_fn_)); + mlir::OwningOpRef mlir_hlo_module, + xla::ConvertHloToMlirHlo(*module_copy->getContext(), + &tf2hlo_result.hlo_module_proto)); + + if (VLOG_IS_ON(1)) { + tensorflow::DumpMlirOpToFile("ifrt_after_bridge_phase2", + mlir_hlo_module.get()); + } const int num_replicas = tf2hlo_result.compile_metadata.num_replicas(); const int num_partitions = tf2hlo_result.compile_metadata.num_cores_per_replica(); @@ -423,9 +488,10 @@ IfrtServingExecutable::CreateExecutableSynchronously( } std::vector> tf_host_callbacks; - TF_ASSIGN_OR_RETURN(auto host_callbacks, - BuildHostCallbacks(tf2hlo_result, *module_copy, - device_mgr_, tf_host_callbacks)); + TF_ASSIGN_OR_RETURN( + auto host_callbacks, + BuildHostCallbacks(tf2hlo_result, std::move(host_callback_modules), + device_mgr_, tf_host_callbacks)); std::vector> loaded_host_callbacks; @@ -436,17 +502,24 @@ IfrtServingExecutable::CreateExecutableSynchronously( ifrt_client_.get(), std::make_unique(host_callback))); } + auto hlo_program = + std::make_unique(mlir_hlo_module.get()); + std::unique_ptr ifrt_executable; + SharedCachedExecutableBundle executable_bundle = + std::make_shared(); TF_ASSIGN_OR_RETURN( - std::unique_ptr ifrt_executable, - ifrt_client_->GetDefaultCompiler()->Compile( - std::make_unique( - tf2hlo_result.mlir_hlo_module.get()), - std::make_unique( - xla_compile_options, loaded_host_callbacks))); + ifrt_executable, + persistent_compilation_cache_->LookupLoadedExecutableOrCreate( + std::move(hlo_program), assigned_device_list_, xla_compile_options, + loaded_host_callbacks, ifrt_client_.get(), + [&](std::unique_ptr program, + std::unique_ptr options) + -> absl::StatusOr> { + return ifrt_client_->GetDefaultCompiler()->Compile( + std::move(program), std::move(options)); + })); - SharedCachedExecutableBundle executable_bundle = - std::make_shared(); executable_bundle->ifrt_executable = std::move(ifrt_executable); executable_bundle->compile_metadata = std::move(tf2hlo_result.compile_metadata); @@ -639,10 +712,7 @@ absl::StatusOr> IfrtServingExecutable::Execute( TF_ASSIGN_OR_RETURN( auto execution_result, executable_bundle->ifrt_executable->Execute( - absl::MakeSpan(args), - /*options=*/ - {.untuple_result = true, - .use_major_to_minor_data_layout_for_callbacks = true}, + absl::MakeSpan(args), /*options=*/{.fill_status = true}, std::move(execution_device_list))); auto status = execution_result.status.Await(); diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h index 76b40f89323de4..34d46f1d3c3e08 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_CORE_TFRT_IFRT_IFRT_SERVING_EXECUTABLE_H_ #define TENSORFLOW_CORE_TFRT_IFRT_IFRT_SERVING_EXECUTABLE_H_ +#include + #include #include #include @@ -49,6 +51,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_persistent_compilation_cache.h" #include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" #include "tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h" #include "tensorflow/core/tfrt/ifrt/tf_host_callback.h" @@ -72,7 +75,8 @@ class IfrtServingExecutable { tensorflow::DeviceMgr* device_mgr, tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn, IfrtServingCoreSelector* ifrt_serving_core_selector, - tsl::protobuf::Message* compilation_environment_proto); + tsl::protobuf::Message* compilation_environment_proto, + IfrtPersistentCompilationCache* persistent_compilation_cache); // Movable but not copyable. IfrtServingExecutable(IfrtServingExecutable&& other) = default; @@ -99,6 +103,7 @@ class IfrtServingExecutable { } private: + friend class IfrtBackendCompilerTest; // In memory cache key. struct Key { std::vector input_shapes; @@ -145,7 +150,8 @@ class IfrtServingExecutable { IfrtServingCoreSelector* ifrt_serving_core_selector, tensorflow::tpu::TPUCompileMetadataProto original_compile_metadata, tsl::RCReference assigned_device_list, - tsl::protobuf::Message* compilation_environment_proto) + tsl::protobuf::Message* compilation_environment_proto, + IfrtPersistentCompilationCache* persistent_compilation_cache) : program_id_(program_id), model_name_(std::string(model_name)), signature_name_(std::string(signature_name)), @@ -160,7 +166,8 @@ class IfrtServingExecutable { device_mgr_(device_mgr), shape_representation_fn_(std::move(shape_representation_fn)), ifrt_serving_core_selector_(std::move(ifrt_serving_core_selector)), - compilation_environment_proto_(compilation_environment_proto) {} + compilation_environment_proto_(compilation_environment_proto), + persistent_compilation_cache_(persistent_compilation_cache) {} int64_t program_id_; using SharedCachedExecutableBundle = std::shared_ptr; @@ -194,6 +201,11 @@ class IfrtServingExecutable { bool is_frozen_ ABSL_GUARDED_BY(mutex_) = false; + // The persistent compilation cache is a global cache and is not owned by + // this executable. When it is nullptr, the persistent compilation cache is + // disabled at ifrt serving level. + IfrtPersistentCompilationCache* persistent_compilation_cache_; + // Asynchronously load the restored variable tensors to Ifrt array. absl::Status AsyncLoadIfrtArray( absl::Span inputs, diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test_util.cc b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test_util.cc index 004d98d6425102..ff327d7b1c5ef3 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test_util.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test_util.cc @@ -34,6 +34,7 @@ limitations under the License. #include "xla/tsl/framework/test_util/mock_serving_device_selector.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/resource_loader.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_persistent_compilation_cache.h" #include "tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h" #include "tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h" #include "tensorflow/core/tfrt/ifrt/tf_host_callback.h" @@ -75,6 +76,8 @@ IfrtServingExecutableTestHelper::IfrtServingExecutableTestHelper( mlir::registerAllDialects(registry_); mlir::RegisterAllTensorFlowDialects(registry_); context_ = std::make_unique(registry_); + ifrt_persistent_compilation_cache_ = + std::make_unique(); } std::unique_ptr @@ -87,7 +90,8 @@ IfrtServingExecutableTestHelper::MakeExecutable(int64_t program_id, thread_pool_.get(), &ifrt_loaded_variable_registry_, &ifrt_restore_tensor_registry_, work_queue_.get(), device_mgr_.get(), tensorflow::IdentityShapeRepresentationFn(), core_selector_.get(), - /*compilation_environment_proto=*/nullptr); + /*compilation_environment_proto=*/nullptr, + ifrt_persistent_compilation_cache_.get()); TF_CHECK_OK(executable_or.status()); return std::move(executable_or.value()); } diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test_util.h b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test_util.h index 44d84a0944a532..238e72b8c85db6 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test_util.h +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test_util.h @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_persistent_compilation_cache.h" #include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" #include "tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h" #include "tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h" @@ -72,6 +73,8 @@ class IfrtServingExecutableTestHelper { mlir::DialectRegistry registry_; std::unique_ptr context_; + std::unique_ptr + ifrt_persistent_compilation_cache_; }; // Returns the path to the MLIR module for the given module name. diff --git a/tensorflow/core/tfrt/ifrt/ifrt_tensor_utils.cc b/tensorflow/core/tfrt/ifrt/ifrt_tensor_utils.cc index b920efe932bd5c..f3ef48f0a2e3d6 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_tensor_utils.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_tensor_utils.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "xla/python/ifrt/dtype.h" #include "xla/python/ifrt/shape.h" -#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" diff --git a/tensorflow/core/tfrt/kernels/ifrt_program_ops_test.cc b/tensorflow/core/tfrt/kernels/ifrt_program_ops_test.cc index cd29511d4d982f..11c7157925dad3 100644 --- a/tensorflow/core/tfrt/kernels/ifrt_program_ops_test.cc +++ b/tensorflow/core/tfrt/kernels/ifrt_program_ops_test.cc @@ -20,16 +20,10 @@ limitations under the License. #include #include -#include "absl/log/check.h" -#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "xla/pjrt/cpu/cpu_client.h" -#include "xla/python/ifrt/array.h" -#include "xla/python/ifrt/client.h" #include "xla/python/ifrt/test_util.h" -#include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/tsl/framework/serving_device_selector.h" #include "xla/tsl/framework/test_util/mock_serving_device_selector.h" #include "xla/tsl/lib/core/status_test_util.h" diff --git a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc index e89dc042152756..cba087f976e968 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc +++ b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc @@ -370,7 +370,7 @@ mlrt::bc::Buffer CreateExecutableForIfrtLoadVariableOp( return buffer; } -class KernelTest : public ::testing::Test { +class KernelTest : public ::testing::TestWithParam { protected: void SetUp() override { mlrt::RegisterBuiltinKernels(registry_); @@ -410,7 +410,8 @@ class KernelTest : public ::testing::Test { ifrt_serving::kIfrtModelRestoreContextName, std::make_unique( &ifrt_model_context_->GetRestoreTensorRegistry(), - ifrt_model_context_->checkpoint_loader_queue())); + ifrt_model_context_->checkpoint_loader_queue(), + /*use_async_restore=*/GetParam())); serving_device_selector_ = std::make_unique(); @@ -440,7 +441,7 @@ class KernelTest : public ::testing::Test { tensorflow::ifrt_serving::IfrtModelContext* ifrt_model_context_; }; -TEST_F(KernelTest, IfrtLoadVariableOpCanGetTensorFromResourceManager) { +TEST_P(KernelTest, IfrtLoadVariableOpCanGetTensorFromResourceManager) { auto buffer = CreateExecutableForIfrtLoadVariableOp( /*redundant_ifrt_load_variable_op=*/false, /*used_by_host=*/true); @@ -488,7 +489,7 @@ TEST_F(KernelTest, IfrtLoadVariableOpCanGetTensorFromResourceManager) { TensorEq(input_tensor)); } -TEST_F(KernelTest, IfrtLoadVariableOp) { +TEST_P(KernelTest, IfrtLoadVariableOp) { auto buffer = CreateExecutableForIfrtLoadVariableOp(); mlrt::bc::Executable executable(buffer.data()); @@ -540,7 +541,7 @@ TEST_F(KernelTest, IfrtLoadVariableOp) { TensorEq(tensorflow::Tensor())); } -TEST_F(KernelTest, DuplicateIfrtLoadVariableOpShallSucceed) { +TEST_P(KernelTest, DuplicateIfrtLoadVariableOpShallSucceed) { auto buffer = CreateExecutableForIfrtLoadVariableOp( /*redundant_ifrt_load_variable_op=*/true); @@ -594,7 +595,7 @@ TEST_F(KernelTest, DuplicateIfrtLoadVariableOpShallSucceed) { TensorEq(tensorflow::Tensor())); } -TEST_F(KernelTest, IfrtRestoreVariableOp) { +TEST_P(KernelTest, IfrtRestoreVariableOp) { std::string checkpoint_prefix = tensorflow::GetDataDependencyFilepath( "tensorflow/core/tfrt/mlrt/kernel/testdata/" @@ -656,7 +657,7 @@ TEST_F(KernelTest, IfrtRestoreVariableOp) { EXPECT_THAT(*restored_tensor, TensorEq(AsTensor({1, 2, 3}, {3}))); } -TEST_F(KernelTest, IfrtRestoreVariableOp4Variables) { +TEST_P(KernelTest, IfrtRestoreVariableOp4Variables) { std::string checkpoint_prefix = tensorflow::GetDataDependencyFilepath( "tensorflow/core/tfrt/mlrt/kernel/testdata/" @@ -748,7 +749,7 @@ TEST_F(KernelTest, IfrtRestoreVariableOp4Variables) { TensorEq(AsTensor({10, 11, 12}, {3}))); } -TEST_F(KernelTest, IfrtRestoreVariableOpInValidInput) { +TEST_P(KernelTest, IfrtRestoreVariableOpInValidInput) { std::string checkpoint_prefix = tensorflow::GetDataDependencyFilepath( "tensorflow/core/tfrt/mlrt/kernel/testdata/" @@ -810,6 +811,8 @@ TEST_F(KernelTest, IfrtRestoreVariableOpInValidInput) { ::tsl::testing::StatusIs(absl::StatusCode::kInvalidArgument)); } +INSTANTIATE_TEST_SUITE_P(KernelTest, KernelTest, ::testing::Bool()); + } // namespace } // namespace tf_mlrt } // namespace tensorflow diff --git a/tensorflow/core/tfrt/runtime/BUILD b/tensorflow/core/tfrt/runtime/BUILD index dfdc99cde0ab9f..630336de3550dd 100644 --- a/tensorflow/core/tfrt/runtime/BUILD +++ b/tensorflow/core/tfrt/runtime/BUILD @@ -15,6 +15,7 @@ package_group( "//tensorflow/core/tfrt/...", "//tensorflow/core/runtime_fallback/...", # copybara:uncomment "//tensorflow_serving/...", + # copybara:uncomment "//cloud/ai/platform/prediction/...", # copybara:uncomment "//learning/brain/experimental/tfrt/...", # copybara:uncomment "//learning/brain/tfrt/...", # copybara:uncomment "//learning/infra/mira/...", diff --git a/tensorflow/core/tfrt/saved_model/BUILD b/tensorflow/core/tfrt/saved_model/BUILD index 5261546e6c6a0f..a133e53fe89098 100644 --- a/tensorflow/core/tfrt/saved_model/BUILD +++ b/tensorflow/core/tfrt/saved_model/BUILD @@ -10,6 +10,7 @@ package_group( name = "friends", packages = [ # Authorized users go here. + # copybara:uncomment "//cloud/ai/platform/prediction/...", # copybara:uncomment "//learning/brain/experimental/tfrt/...", # copybara:uncomment "//learning/brain/tfrt/...", # copybara:uncomment "//learning/infra/mira/...", @@ -104,12 +105,12 @@ cc_library( "//tensorflow/cc/saved_model:reader", "//tensorflow/compiler/jit:flags_headers", "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:import_model", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow:serialize_mlir_module_utils", "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/compiler/mlir/tensorflow:upgrade_graph", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", + "//tensorflow/compiler/mlir/tf2xla/api/v2:graph_to_tf_executor", "//tensorflow/compiler/mlir/tfrt:import_model", "//tensorflow/compiler/mlir/tfrt:saved_model", "//tensorflow/compiler/mlir/tfrt:tfrt_compile_options", diff --git a/tensorflow/core/tfrt/saved_model/saved_model.cc b/tensorflow/core/tfrt/saved_model/saved_model.cc index 1056fc05dc3a7d..0d16406120bcaa 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model.cc +++ b/tensorflow/core/tfrt/saved_model/saved_model.cc @@ -40,9 +40,8 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "tensorflow/compiler/jit/flags.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" +#include "tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.h" #include "tensorflow/compiler/mlir/tfrt/saved_model/saved_model.h" #include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.h" #include "tensorflow/compiler/mlir/tfrt/translate/import_model.h" @@ -99,6 +98,10 @@ namespace tfrt_stub { namespace { constexpr absl::string_view kSignatureJoiningDelimiter = "+"; +constexpr absl::string_view kXlaCallModuleOpName = "XlaCallModule"; +// TODO(b/374165187): Use enums for model types. +constexpr absl::string_view kJaxModelLabel = "JAX"; +constexpr absl::string_view kUnknownModelLabel = "UNKNOWN"; auto* lazy_loading_count = monitoring::Counter<3>::New( "/tensorflow/tfrt/lazy_loading_count", "The total number of lazy loadings.", @@ -112,6 +115,11 @@ auto* use_backend_compiler_count = monitoring::Counter<3>::New( "The total number of instances that use injected backend compiler.", "model_name", "model_version", "use_backend_compiler"); +auto* inferred_model_type_count = monitoring::Counter<3>::New( + "/tensorflow/tfrt/inferred_model_type", + "Count of SavedModels with their inferred model types (best-effort).", + "model_name", "model_version", "inferred_model_type"); + auto* saved_model_import_time_seconds = tensorflow::monitoring::Gauge::New( "/tensorflow/tfrt/saved_model/import_time", @@ -141,22 +149,19 @@ absl::Status PrepareRestore(mlir::MLIRContext* context, const std::string& saved_model_dir, const SavedModel::Options& options, ifrt_serving::CheckpointLoader* checkpoint_loader) { - // Import the global MLIR with `import_user_signatures` as true so that we can - // analysis the global MLIR to retrieve data needed for restore. - mlir::OwningOpRef mlir_module_restore_analysis; - ASSIGN_OR_RETURN_IN_IMPORT( - mlir_module_restore_analysis, - ImportSavedModel( - context, meta_graph_def, fallback_state, saved_model_dir, - /*import_user_signatures=*/true, - options.graph_execution_options.run_placer_grappler_on_functions)); - if (!checkpoint_loader) { return absl::InternalError("Missing checkpoint loader."); } - TF_RETURN_IF_ERROR(checkpoint_loader->PrepareRestore( - std::move(mlir_module_restore_analysis))); + ifrt_serving::CheckpointLoader::PrepareRestoreArgs args = { + .context = context, + .meta_graph_def = meta_graph_def, + .fallback_state = &fallback_state, + .saved_model_dir = saved_model_dir, + .run_placer_grappler_on_functions = + options.graph_execution_options.run_placer_grappler_on_functions}; + + TF_RETURN_IF_ERROR(checkpoint_loader->PrepareRestore(args)); LOG(INFO) << "Complete set restore metadata."; return absl::OkStatus(); @@ -390,6 +395,20 @@ bool AotPackageExists(absl::string_view saved_model_dir) { env->FileExists(aot_bef_path).ok(); } +std::string GetInferredModelType(const MetaGraphDef& meta_graph_def) { + bool found_xla_call_module_op = false; + for (const auto& function : meta_graph_def.graph_def().library().function()) { + for (const auto& node : function.node_def()) { + if (node.name() == kXlaCallModuleOpName) { + found_xla_call_module_op = true; + break; + } + } + } + return std::string(found_xla_call_module_op ? kJaxModelLabel + : kUnknownModelLabel); +} + } // namespace SavedModel::~SavedModel() = default; // Out-of-line C++ key function. @@ -761,6 +780,15 @@ absl::StatusOr> SavedModelImpl::LoadSavedModel( << persistent_cache_directory << ", and set it to read-only."; } + if (options.infer_model_type) { + inferred_model_type_count + ->GetCell(options.graph_execution_options.model_metadata.name(), + absl::StrCat( + options.graph_execution_options.model_metadata.version()), + GetInferredModelType(meta_graph_def)) + ->IncrementBy(1); + } + if (options.graph_execution_options.use_ifrt) { use_ifrt_count ->GetCell(options.graph_execution_options.model_metadata.name(), @@ -1023,7 +1051,7 @@ SavedModelImpl::ImportSubgraph( graph_import_config)); // Convert the optimized graph to an MLIR module. - return tensorflow::ConvertGraphToMlir( + return tensorflow::tf2xla::v2::ConvertGraphToTfExecutor( *optimization_result.graph, /*debug_info=*/{}, optimization_result.graph->flib_def(), graph_import_config, context); } diff --git a/tensorflow/core/tfrt/saved_model/saved_model.h b/tensorflow/core/tfrt/saved_model/saved_model.h index 297a0468130d7c..ca8619018cd6e3 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model.h +++ b/tensorflow/core/tfrt/saved_model/saved_model.h @@ -110,6 +110,10 @@ class SavedModel { // True if and only if SavedModel is being loaded to generate AOT results. bool aot_generation = false; + // Make a best-effort guess at the model type. E.g. detecting JAX models by + // looking for the `XlaCallModule` op in the MetagraphDef. + bool infer_model_type = false; + GraphExecutionOptions graph_execution_options; }; diff --git a/tensorflow/core/tfrt/saved_model/tests/BUILD b/tensorflow/core/tfrt/saved_model/tests/BUILD index c026800861ff2a..22bc6dec94a585 100644 --- a/tensorflow/core/tfrt/saved_model/tests/BUILD +++ b/tensorflow/core/tfrt/saved_model/tests/BUILD @@ -625,6 +625,7 @@ cc_library( "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", + "@local_xla//xla/tsl/lib/monitoring:cell_reader", "@tf_runtime//:core_runtime_alwayslink", "@tf_runtime//:hostcontext", ], diff --git a/tensorflow/core/tfrt/saved_model/tests/saved_model_test.cc b/tensorflow/core/tfrt/saved_model/tests/saved_model_test.cc index 605e4413ffffc1..647358bd5fd1fc 100644 --- a/tensorflow/core/tfrt/saved_model/tests/saved_model_test.cc +++ b/tensorflow/core/tfrt/saved_model/tests/saved_model_test.cc @@ -35,6 +35,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "tensorflow/compiler/mlir/tfrt/backend_compiler.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/monitoring/cell_reader.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" @@ -54,6 +55,8 @@ namespace tensorflow { namespace tfrt_stub { namespace { +using ::tsl::monitoring::testing::CellReader; + struct TestParams { bool enable_grappler = false; bool enable_lazy_loading = false; @@ -1181,6 +1184,33 @@ TEST(SavedModelTest, CustomCompiler) { EXPECT_EQ(test_context.signature_name, "toy"); } +// TODO(b/374165187): Add a test case for positive identification of JAX models. +// Currently we don't have those in our testdata. +TEST(SavedModelTest, InferModelType) { + // SavedModel toy contains a graph of a single 'tf.AddV2' op. It is generated + // using the following python code: + // x = tf.placeholder(tf.int32, shape=(3)) + // y = tf.compat.v1.get_variable(name='y', initializer=[1, 2, 3]) + // r = tf.matmul(x, y) + std::string saved_model_dir = tensorflow::GetDataDependencyFilepath( + "tensorflow/core/tfrt/saved_model/tests/toy_v1/1"); + + auto inferred_model_type_count_reader = + CellReader("/tensorflow/tfrt/inferred_model_type"); + + auto runtime = DefaultTfrtRuntime(/*num_threads=*/1); + auto options = DefaultSavedModelOptions(runtime.get()); + options.infer_model_type = true; + + auto saved_model = SavedModelImpl::LoadSavedModel(options, saved_model_dir, + /*tags=*/{"serve"}); + TF_CHECK_OK(saved_model.status()); + + // TODO(b/374165187): We currently get the model name from graph execution + // options but our test setup does not populate that. + EXPECT_EQ(inferred_model_type_count_reader.Delta("", "0", "UNKNOWN"), 1); +} + } // namespace } // namespace tfrt_stub } // namespace tensorflow diff --git a/tensorflow/core/tfrt/stubs/BUILD b/tensorflow/core/tfrt/stubs/BUILD index 8d7227fc74d7ae..422077b7c9469f 100644 --- a/tensorflow/core/tfrt/stubs/BUILD +++ b/tensorflow/core/tfrt/stubs/BUILD @@ -16,12 +16,10 @@ cc_library( deps = [ "//tensorflow/core/tfrt/graph_executor:executable_context", "//tensorflow/core/tfrt/graph_executor:sync_resource_state", - "//tensorflow/core/tfrt/mlrt/bytecode:executable", "//tensorflow/core/tfrt/mlrt/interpreter:context", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", "@tf_runtime//:hostcontext", ], ) diff --git a/tensorflow/core/tfrt/stubs/tfrt_native_lowering_stub.cc b/tensorflow/core/tfrt/stubs/tfrt_native_lowering_stub.cc index b6b20fad414fbc..417b5ea4c932e9 100644 --- a/tensorflow/core/tfrt/stubs/tfrt_native_lowering_stub.cc +++ b/tensorflow/core/tfrt/stubs/tfrt_native_lowering_stub.cc @@ -19,7 +19,6 @@ limitations under the License. #include "absl/status/statusor.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/core/tfrt/graph_executor/executable_context.h" #include "tensorflow/core/tfrt/graph_executor/sync_resource_state.h" #include "tensorflow/core/tfrt/mlrt/interpreter/context.h" @@ -65,11 +64,6 @@ void AddSyncContext(mlrt::ExecutionContext& execution_context, execution_context, host_context, sync_state); } -void AddNativeLoweringPasses(mlir::OpPassManager* pass_manager) { - GetTfrtNativeLoweringStubRegistry().Get().AddNativeLoweringPasses( - pass_manager); -} - absl::StatusOr> BuildExecutableContext( mlir::ModuleOp module, const mlrt::KernelRegistry& kernel_registry) { return GetTfrtNativeLoweringStubRegistry().Get().BuildExecutableContext( diff --git a/tensorflow/core/tfrt/stubs/tfrt_native_lowering_stub.h b/tensorflow/core/tfrt/stubs/tfrt_native_lowering_stub.h index ee690f532be1a8..d27fe02df12a3f 100644 --- a/tensorflow/core/tfrt/stubs/tfrt_native_lowering_stub.h +++ b/tensorflow/core/tfrt/stubs/tfrt_native_lowering_stub.h @@ -20,10 +20,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/core/tfrt/graph_executor/executable_context.h" #include "tensorflow/core/tfrt/graph_executor/sync_resource_state.h" -#include "tensorflow/core/tfrt/mlrt/bytecode/executable.h" #include "tensorflow/core/tfrt/mlrt/interpreter/context.h" #include "tfrt/host_context/execution_context.h" // from @tf_runtime #include "tfrt/host_context/host_context.h" // from @tf_runtime @@ -38,7 +36,6 @@ class TfrtNativeLoweringStub { virtual void AddSyncContext( mlrt::ExecutionContext& execution_context, HostContext& host_context, tensorflow::tfrt_stub::SyncResourceState* sync_state) {} - virtual void AddNativeLoweringPasses(mlir::OpPassManager* pass_manager) {} virtual absl::StatusOr< std::shared_ptr> BuildExecutableContext(mlir::ModuleOp module, @@ -54,8 +51,6 @@ void AddSyncContext(mlrt::ExecutionContext& execution_context, tfrt::HostContext& host_context, tensorflow::tfrt_stub::SyncResourceState* sync_state); -void AddNativeLoweringPasses(mlir::OpPassManager* pass_manager); - absl::StatusOr> BuildExecutableContext(mlir::ModuleOp module, const mlrt::KernelRegistry& kernel_registry); diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD index eea5f1d99631ca..41bf66bd57579e 100644 --- a/tensorflow/core/tpu/BUILD +++ b/tensorflow/core/tpu/BUILD @@ -25,8 +25,10 @@ cc_library( hdrs = ["tpu_embedding_configuration_utils.h"], visibility = ["//visibility:public"], deps = [ + ":tpu_embedding_optimization_parameters_utils", "//tensorflow/core/protobuf/tpu:optimization_parameters_proto_cc", "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", ], @@ -57,7 +59,7 @@ tf_cc_test( "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -72,6 +74,7 @@ cc_library( "//tensorflow/core:lib_proto_parsing", "//tensorflow/core/protobuf/tpu:optimization_parameters_proto_cc", "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@local_xla//xla:xla_data_proto_cc", @@ -103,7 +106,7 @@ cc_library( "@local_tsl//tsl/platform:logging", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], ) @@ -237,20 +240,29 @@ cc_library( "//tensorflow/compiler/tf2xla:layout_util", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", + "//tensorflow/core:portable_gif_internal", "//tensorflow/core/framework:attr_value_proto_cc", + "//tensorflow/core/platform:errors", "//tensorflow/core/platform:status", + "//tensorflow/core/platform:strcat", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "//tensorflow/core/tpu/kernels:tpu_compile_op_support", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:macros", - "@local_xla//xla:literal_util", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", "@local_xla//xla:shape_util", + "@local_xla//xla:status_macros", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/client:compile_only_client", + "@local_xla//xla/hlo/ir:hlo", ], ) diff --git a/tensorflow/core/tpu/graph_rewrite/BUILD b/tensorflow/core/tpu/graph_rewrite/BUILD index 73fbacd589160b..12096d3abcf27a 100644 --- a/tensorflow/core/tpu/graph_rewrite/BUILD +++ b/tensorflow/core/tpu/graph_rewrite/BUILD @@ -76,6 +76,7 @@ cc_library( "//tensorflow/core/platform:strcat", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", "@local_xla//xla:status_macros", ], @@ -163,6 +164,7 @@ tf_cc_test( "//tensorflow/core:testlib", "//tensorflow/core/common_runtime:optimization_registry", "//tensorflow/core/config:flag_defs", + "//tensorflow/core/framework:types_proto_cc", ], ) @@ -226,7 +228,7 @@ cc_library( "@local_xla//xla:status_macros", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla:xla_proto_cc", - "@local_xla//xla/client:sharding_builder", + "@local_xla//xla/hlo/builder:sharding_builder", "@local_xla//xla/service:computation_placer", "@local_xla//xla/stream_executor/tpu:c_api_decl", "@local_xla//xla/stream_executor/tpu:tpu_api", @@ -396,6 +398,7 @@ cc_library( srcs = ["tpu_embedding_software_deduplication_rewrite_pass.cc"], hdrs = ["tpu_embedding_software_deduplication_rewrite_pass.h"], deps = [ + ":tpu_embedding_rewrite_pass_utils", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -404,7 +407,6 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc", "//tensorflow/core/tpu:tpu_embedding_configuration_utils", - "//tensorflow/core/tpu/graph_rewrite:tpu_embedding_rewrite_pass_utils", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.cc index bca30520071c66..1862edab9cd38a 100644 --- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.cc +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" +#include "absl/strings/str_join.h" #include "xla/status_macros.h" #include "tensorflow/core/common_runtime/device_set.h" #include "tensorflow/core/framework/device.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/platform/strcat.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/device_name_utils.h" diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc index a13b3caba2fc17..588c0f9ce584fe 100644 --- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc @@ -39,6 +39,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/strings/escaping.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" @@ -54,7 +55,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "xla/array2d.h" #include "xla/array4d.h" -#include "xla/client/sharding_builder.h" +#include "xla/hlo/builder/sharding_builder.h" #include "xla/service/computation_placer.h" #include "xla/shape_util.h" #include "xla/status_macros.h" diff --git a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc index 9370cd6b01ab1c..bb643300dfc8e4 100644 --- a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc +++ b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc @@ -34,6 +34,7 @@ limitations under the License. #include "absl/container/node_hash_map.h" #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" @@ -62,7 +63,6 @@ limitations under the License. #include "tensorflow/core/graph/graph_node_util.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/flatset.h" -#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/public/session_options.h" diff --git a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass_test.cc b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass_test.cc index a21cdaec4dbc72..a08a5e6be10a01 100644 --- a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass_test.cc +++ b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/config/flag_defs.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/testlib.h" diff --git a/tensorflow/core/tpu/graph_rewrite/tpu_embedding_software_deduplication_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/tpu_embedding_software_deduplication_rewrite_pass.cc index 97f041ba30dd9d..65a3738ed76d74 100644 --- a/tensorflow/core/tpu/graph_rewrite/tpu_embedding_software_deduplication_rewrite_pass.cc +++ b/tensorflow/core/tpu/graph_rewrite/tpu_embedding_software_deduplication_rewrite_pass.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" @@ -36,7 +37,6 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/tpu/tpu_embedding_configuration.pb.h" @@ -57,7 +57,7 @@ absl::Status CheckNumInputsOrOutputs( const tpu::TPUEmbeddingConfiguration& tpu_embedding_config) { if (tpu_embedding_config.feature_descriptor_size() == 0 && num_input_or_outputs != tpu_embedding_config.table_descriptor_size()) { - return errors::InvalidArgument(absl::StrFormat( + return absl::InvalidArgumentError(absl::StrFormat( "Number of tables in the TPU embedding config: %d does not match the " "%s attribute: %d in the %s node.", tpu_embedding_config.table_descriptor_size(), attribute_name, @@ -66,7 +66,7 @@ absl::Status CheckNumInputsOrOutputs( if (tpu_embedding_config.feature_descriptor_size() > 0 && num_input_or_outputs != tpu_embedding_config.feature_descriptor_size()) { - return errors::InvalidArgument(absl::StrFormat( + return absl::InvalidArgumentError(absl::StrFormat( "Feature descriptor is set in tpu embedding config. But number of " "features in the TPU embedding config: %d does not match the " "%s attribute: %d in the %s node.", @@ -117,7 +117,7 @@ absl::StatusOr MakeRecvActivationsNodeDef( absl::Span data_inputs, absl::Span control_inputs) { if (!data_inputs.empty()) { - return errors::InvalidArgument( + return absl::InvalidArgumentError( absl::StrFormat("Expected to have zero inputs for " "RecvTPUEmbeddingActivations node, found %d inputs.", data_inputs.size())); @@ -126,7 +126,7 @@ absl::StatusOr MakeRecvActivationsNodeDef( tpu::TPUEmbeddingConfiguration tpu_embedding_config; if (!tpu_embedding_config.ParseFromString( std::string(tpu_embedding_config_str))) { // NOLINT - return errors::InvalidArgument( + return absl::InvalidArgumentError( "Malformed config attribute in the RecvTPUEmbeddingActivations node."); } @@ -181,9 +181,8 @@ absl::StatusOr MakeSendGradientsNodeDef( absl::Span data_inputs, absl::Span control_inputs) { tpu::TPUEmbeddingConfiguration tpu_embedding_config; - if (!tpu_embedding_config.ParseFromString( - std::string(tpu_embedding_config_str))) { // NOLINT - return errors::InvalidArgument( + if (!tpu_embedding_config.ParseFromString(tpu_embedding_config_str)) { + return absl::InvalidArgumentError( "Malformed config attribute in the SendTPUEmbeddingGradients node."); } @@ -195,37 +194,39 @@ absl::StatusOr MakeSendGradientsNodeDef( "SendTPUEmbeddingGradients", tpu_embedding_config)); - int32 learning_rate_tag_count = 0; + int32 dynamic_inputs_tag_count = 0; if (!GetNodeAttr(AttrSlice(old_gradients_node_def), "NN", - &learning_rate_tag_count) + &dynamic_inputs_tag_count) .ok()) { LOG(INFO) << "Missing the NN attribute (number of dynamic learning rate tags) in " "the SendTPUEmbeddingGradients node. Setting the value to 0."; } - auto status_or_lr_tag_count = - tpu::ComputeTotalTagCountForDynamicLearningRates(tpu_embedding_config); - if (!status_or_lr_tag_count.ok()) { - return errors::InvalidArgument(status_or_lr_tag_count.status().message()); + auto status_or_dynamic_inputs_tag_count = + tpu::ComputeTotalTagCountForOptimizerDynamicInputs(tpu_embedding_config); + if (!status_or_dynamic_inputs_tag_count.ok()) { + return absl::InvalidArgumentError( + status_or_dynamic_inputs_tag_count.status().message()); } - const int32 expected_learning_rate_tag_count = status_or_lr_tag_count.value(); + const int32 expected_dynamic_inputs_tag_count = + status_or_dynamic_inputs_tag_count.value(); - if (learning_rate_tag_count != expected_learning_rate_tag_count) { - return errors::InvalidArgument(absl::StrFormat( + if (dynamic_inputs_tag_count != expected_dynamic_inputs_tag_count) { + return absl::InvalidArgumentError(absl::StrFormat( "Number of dynamic learning rate tags in the TPU embedding config: %d " "does not match the NN attribute: %d in the SendTPUEmbeddingGradients " "node.", - expected_learning_rate_tag_count, learning_rate_tag_count)); + expected_dynamic_inputs_tag_count, dynamic_inputs_tag_count)); } if (data_inputs.size() != - static_cast(num_inputs + learning_rate_tag_count)) { - return errors::InvalidArgument(absl::StrFormat( + static_cast(num_inputs + dynamic_inputs_tag_count)) { + return absl::InvalidArgumentError(absl::StrFormat( "Mismatch in the number of inputs for SendTPUEmbeddingGradients node, " "expected: %d, actual: %d", - num_inputs + learning_rate_tag_count, data_inputs.size())); + num_inputs + dynamic_inputs_tag_count, data_inputs.size())); } NodeDefBuilder builder(old_gradients_node_def.name(), @@ -234,10 +235,10 @@ absl::StatusOr MakeSendGradientsNodeDef( builder.Device(device_name); } - // The Numtables here can be interpreted as num features if the feature + // The NumTables here can be interpreted as num features if the feature // descriptor is present in the config. builder.Attr("NumTables", num_inputs) - .Attr("NumLearningRateTags", learning_rate_tag_count) + .Attr("NumLearningRateTags", dynamic_inputs_tag_count) .Attr("config", tpu_embedding_config_str); if (!tpu_replicate_attr.empty()) { builder.Attr("_tpu_replicate", tpu_replicate_attr); @@ -250,7 +251,7 @@ absl::StatusOr MakeSendGradientsNodeDef( builder.Input(absl::MakeConstSpan(data_inputs.data(), num_inputs)) .Input(absl::MakeConstSpan(data_inputs.data() + num_inputs, - learning_rate_tag_count)) + dynamic_inputs_tag_count)) .Input(absl::StrCat(deduplication_data_node_name, ":output"), /*src_index=*/0, DT_VARIANT); for (const std::string& control_input : control_inputs) { @@ -281,7 +282,7 @@ struct SendRecvNodesMapKey { s.requested_device); } - const inline bool operator==(const SendRecvNodesMapKey& s) const { + inline bool operator==(const SendRecvNodesMapKey& s) const { return (tpu_replicate_attr == s.tpu_replicate_attr && requested_device == s.requested_device); } @@ -328,7 +329,7 @@ std::vector GetDataInputs(const Node* node, // from the graph nodes (activations_node and gradients_node). If both nodes are // present, ensure that the TPUEmbeddingConfiguration proto, assigned device // name, and index are the same on both nodes. -Status ValidateAndGetTPUEmbeddingConfiguration( +absl::Status ValidateAndGetTPUEmbeddingConfiguration( const Node* activations_node, const Node* gradients_node, absl::string_view tpu_replicate_attr, std::string* tpu_embedding_config_str) { @@ -341,7 +342,7 @@ Status ValidateAndGetTPUEmbeddingConfiguration( GetNodeAttr(gradients_node->def(), "config", &gradients_config_str)); if (activations_config_str != gradients_config_str) { - return errors::InvalidArgument(absl::StrFormat( + return absl::InvalidArgumentError(absl::StrFormat( "TPU embedding config attributes of RecvTPUEmbeddingActivations and " "SendTPUEmbeddingGradients nodes with the same tpu_replicate attr: " "%s are not identical.", @@ -349,7 +350,7 @@ Status ValidateAndGetTPUEmbeddingConfiguration( } if (activations_node->assigned_device_name() != gradients_node->assigned_device_name()) { - return errors::InvalidArgument(absl::StrFormat( + return absl::InvalidArgumentError(absl::StrFormat( "Mismatch in assigned device names for the " "RecvTPUEmbeddingActivations (%s) and SendTPUEmbeddingGradients (%s) " "nodes with the same tpu_replicate attr: %s.", @@ -358,7 +359,7 @@ Status ValidateAndGetTPUEmbeddingConfiguration( } if (activations_node->assigned_device_name_index() != gradients_node->assigned_device_name_index()) { - return errors::InvalidArgument(absl::StrFormat( + return absl::InvalidArgumentError(absl::StrFormat( "Mismatch in assigned device name indices for the " "RecvTPUEmbeddingActivations (%d) and SendTPUEmbeddingGradients (%d) " "nodes with the same tpu_replicate attr: %s.", @@ -368,7 +369,7 @@ Status ValidateAndGetTPUEmbeddingConfiguration( } if (activations_node == nullptr && gradients_node == nullptr) { - return errors::Internal(absl::StrFormat( + return absl::InternalError(absl::StrFormat( "Found tpu_replicate attr: %s with no corresponding " "RecvTPUEmbeddingActivations or SendTPUEmbeddingGradients nodes", tpu_replicate_attr)); @@ -388,13 +389,11 @@ Status ValidateAndGetTPUEmbeddingConfiguration( // (Op=RecvTPUEmbeddingActivations) and the old_gradients_node // (Op=SendTPUEmbeddingGradients) are copied over to the newly inserted node to // ensure that it has the same control frame. -Status AddRecvDeduplicationDataNode(const Node* old_activations_node, - const Node* old_gradients_node, - const std::string& requested_device, - absl::string_view tpu_replicate_attr, - absl::string_view tpu_embedding_config_str, - Node** deduplication_data_node, - Graph* graph) { +absl::Status AddRecvDeduplicationDataNode( + const Node* old_activations_node, const Node* old_gradients_node, + const std::string& requested_device, absl::string_view tpu_replicate_attr, + absl::string_view tpu_embedding_config_str, Node** deduplication_data_node, + Graph* graph) { // Note that control inputs added later while constructing the Node are copied // over automatically to the NodeDef, so we don't need to specify any control // inputs here. @@ -440,7 +439,7 @@ Status AddRecvDeduplicationDataNode(const Node* old_activations_node, // node (Op=XlaRecvTPUEmbeddingActivations) and initializes it with the // specified tpu_replicate and tpu_embedding_config_str attributes. Connects the // output of the deduplication_data_node to the input of the newly added node. -Status ReplaceRecvActivationsNodeAndAddDeduplicationInputs( +absl::Status ReplaceRecvActivationsNodeAndAddDeduplicationInputs( absl::string_view tpu_replicate_attr, absl::string_view tpu_embedding_config_str, Node* old_activations_node, Node* deduplication_data_node, Graph* graph) { @@ -482,7 +481,7 @@ Status ReplaceRecvActivationsNodeAndAddDeduplicationInputs( // node (Op=XlaSendTPUEmbeddingGradients) and initializes it with the specified // tpu_replicate and tpu_embedding_config_str attributes. Connects the output of // the deduplication_data_node to the last input of the newly added node. -Status ReplaceSendGradientsNodeAndAddDeduplicationInputs( +absl::Status ReplaceSendGradientsNodeAndAddDeduplicationInputs( absl::string_view tpu_replicate_attr, absl::string_view tpu_embedding_config_str, Node* old_gradients_node, Node* deduplication_data_node, Graph* graph) { @@ -523,7 +522,7 @@ Status ReplaceSendGradientsNodeAndAddDeduplicationInputs( } // Rewrites the graph for a particular _tpu_replicate attribute. -Status RewriteGraphForTpuReplicateAttrAndDevice( +absl::Status RewriteGraphForTpuReplicateAttrAndDevice( absl::string_view tpu_replicate_attr, const std::string& requested_device, Node* old_activations_node, Node* old_gradients_node, Graph* graph) { VLOG(1) << "Rewriting graph for _tpu_replicate attribute: " @@ -560,8 +559,8 @@ Status RewriteGraphForTpuReplicateAttrAndDevice( // Inserts a RecvTPUEmbeddingActivations node into the send_recv_nodes_map. This // map temporarily holds the RecvTPUEmbeddingActivations and // SendTPUEmbeddingGradients of the graph before they are rewritten. -Status InsertActivationsNodeIntoMap(Node* activations_node, - SendRecvNodesMap* send_recv_nodes_map) { +absl::Status InsertActivationsNodeIntoMap( + Node* activations_node, SendRecvNodesMap* send_recv_nodes_map) { std::string tpu_replicate_attr; TF_RETURN_IF_ERROR(GetNodeAttr(activations_node->def(), "_tpu_replicate", &tpu_replicate_attr)); @@ -576,13 +575,13 @@ Status InsertActivationsNodeIntoMap(Node* activations_node, const SendRecvNodesMap::iterator it = send_recv_nodes_map->find(key); if (it != send_recv_nodes_map->end()) { if (it->second.activations_node != nullptr) { - return errors::AlreadyExists(absl::StrFormat( + return absl::AlreadyExistsError(absl::StrFormat( "Found duplicate RecvTPUEmbeddingActivations node in graph with " "tpu_replicate attr: %s and requested_device: %s", tpu_replicate_attr, requested_device)); } if (it->second.gradients_node == nullptr) { - return errors::Internal(absl::StrFormat( + return absl::InternalError(absl::StrFormat( "Found map object with no RecvTPUEmbeddingActivations or " "SendTPUEmbeddingGradients nodes and tpu_replicate attr: %s and " "requested_device: %s", @@ -600,8 +599,8 @@ Status InsertActivationsNodeIntoMap(Node* activations_node, // Inserts a SendTPUEmbeddingGradients node into the send_recv_nodes_map. This // map temporarily holds the RecvTPUEmbeddingActivations and // SendTPUEmbeddingGradients of the graph before they are rewritten. -Status InsertGradientsNodeIntoMap(Node* gradients_node, - SendRecvNodesMap* send_recv_nodes_map) { +absl::Status InsertGradientsNodeIntoMap(Node* gradients_node, + SendRecvNodesMap* send_recv_nodes_map) { std::string tpu_replicate_attr; TF_RETURN_IF_ERROR(GetNodeAttr(gradients_node->def(), "_tpu_replicate", &tpu_replicate_attr)); @@ -616,13 +615,13 @@ Status InsertGradientsNodeIntoMap(Node* gradients_node, const SendRecvNodesMap::iterator it = send_recv_nodes_map->find(key); if (it != send_recv_nodes_map->end()) { if (it->second.gradients_node != nullptr) { - return errors::AlreadyExists(absl::StrFormat( + return absl::AlreadyExistsError(absl::StrFormat( "Found duplicate SendTPUEmbeddingGradients node in graph with " "tpu_replicate attr: %s and requested_device: %s", tpu_replicate_attr, requested_device)); } if (it->second.activations_node == nullptr) { - return errors::Internal(absl::StrFormat( + return absl::InternalError(absl::StrFormat( "Found map object with no RecvTPUEmbeddingActivations or " "SendTPUEmbeddingGradients nodes and tpu_replicate attr: %s and " "requested_device: %s", @@ -639,7 +638,7 @@ Status InsertGradientsNodeIntoMap(Node* gradients_node, // Groups the RecvTPUEmbeddingActivations and SendTPUEmbeddingGradients of the // graph using their _tpu_replicate attribute and requested device. -Status GroupSendRecvNodesByTpuReplicateAttrAndDevice( +absl::Status GroupSendRecvNodesByTpuReplicateAttrAndDevice( const Graph* graph, SendRecvNodesMap* send_recv_nodes_map) { VLOG(1) << "Grouping nodes by _tpu_replicate attribute"; for (Node* node : graph->nodes()) { @@ -658,7 +657,7 @@ Status GroupSendRecvNodesByTpuReplicateAttrAndDevice( // Rewrites the graph in the specified GraphOptimizationPassOptions object for // software deduplication. -Status RewriteGraph(Graph* graph) { +absl::Status RewriteGraph(Graph* graph) { SendRecvNodesMap send_recv_nodes_map; TF_RETURN_IF_ERROR(GroupSendRecvNodesByTpuReplicateAttrAndDevice( graph, &send_recv_nodes_map)); @@ -768,8 +767,8 @@ absl::StatusOr ComputeRewriterConfigForNodeDef( // tpu_replicate and tpu_embedding_config attributes are the same if the // final_rewriter_config has been partially populated. Aggregates the // control inputs of both configs as well. -Status MergeRewriterConfigs(const RewriterConfig& rewriter_config, - RewriterConfig* final_rewriter_config) { +absl::Status MergeRewriterConfigs(const RewriterConfig& rewriter_config, + RewriterConfig* final_rewriter_config) { if (final_rewriter_config->activations_node_def_name.empty() && final_rewriter_config->gradients_node_def_name.empty()) { final_rewriter_config->device_name = rewriter_config.device_name; @@ -779,13 +778,13 @@ Status MergeRewriterConfigs(const RewriterConfig& rewriter_config, rewriter_config.tpu_embedding_config_str; } else { if (final_rewriter_config->device_name != rewriter_config.device_name) { - return errors::InvalidArgument(absl::StrFormat( + return absl::InvalidArgumentError(absl::StrFormat( "Mismatch in device names for TPU embedding nodes: %s != %s", final_rewriter_config->device_name, rewriter_config.device_name)); } if (final_rewriter_config->tpu_replicate_attr != rewriter_config.tpu_replicate_attr) { - return errors::InvalidArgument( + return absl::InvalidArgumentError( absl::StrFormat("Mismatch in _tpu_replicate attributes for TPU " "embedding nodes: %s != %s", final_rewriter_config->tpu_replicate_attr, @@ -793,7 +792,7 @@ Status MergeRewriterConfigs(const RewriterConfig& rewriter_config, } if (final_rewriter_config->tpu_embedding_config_str != rewriter_config.tpu_embedding_config_str) { - return errors::InvalidArgument( + return absl::InvalidArgumentError( absl::StrFormat("Mismatch in config attributes for TPU " "embedding nodes: %s != %s", final_rewriter_config->tpu_embedding_config_str, @@ -805,7 +804,7 @@ Status MergeRewriterConfigs(const RewriterConfig& rewriter_config, if (!rewriter_config.activations_node_def_name.empty()) { if (!final_rewriter_config->activations_node_def_name.empty()) { - return errors::InvalidArgument( + return absl::InvalidArgumentError( absl::StrFormat("Found duplicate RecvTPUEmbeddingActivations nodes " "%s and %s in function.", rewriter_config.activations_node_def_name, @@ -816,7 +815,7 @@ Status MergeRewriterConfigs(const RewriterConfig& rewriter_config, } if (!rewriter_config.gradients_node_def_name.empty()) { if (!final_rewriter_config->gradients_node_def_name.empty()) { - return errors::InvalidArgument( + return absl::InvalidArgumentError( absl::StrFormat("Found duplicate SendTPUEmbeddingGradients nodes %s " "and %s in function.", rewriter_config.gradients_node_def_name, @@ -877,7 +876,7 @@ bool RewriterConfigsByDeviceHasEmbeddingOperations( // Rewrites the function defs in the specified GraphOptimizationPassOptions // object for software deduplication. -Status RewriteFunctionDefs(FunctionLibraryDefinition* flib_def) { +absl::Status RewriteFunctionDefs(FunctionLibraryDefinition* flib_def) { for (const std::string& fname : flib_def->ListFunctionNames()) { // The function def cannot be modified. Hence, make a copy, modify the copy // and then replace the original function def using the copy. @@ -958,7 +957,7 @@ Status RewriteFunctionDefs(FunctionLibraryDefinition* flib_def) { } // namespace -Status TPUEmbeddingSoftwareDeduplicationRewritePass::Run( +absl::Status TPUEmbeddingSoftwareDeduplicationRewritePass::Run( const GraphOptimizationPassOptions& options) { TF_RETURN_IF_ERROR(RewriteGraph(options.graph->get())); TF_RETURN_IF_ERROR(RewriteFunctionDefs(options.flib_def)); diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index 5c1dc5da889ede..dd30044db1e98b 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ b/tensorflow/core/tpu/kernels/BUILD @@ -197,9 +197,9 @@ cc_library( "@local_xla//xla:literal_util", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client:xla_computation", - "@local_xla//xla/client/lib:slicing", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder:xla_computation", + "@local_xla//xla/hlo/builder/lib:slicing", "@local_xla//xla/stream_executor/tpu:c_api_decl", "@local_xla//xla/stream_executor/tpu:tpu_api", "@local_xla//xla/stream_executor/tpu:tpu_ops_c_api_hdrs", @@ -247,13 +247,13 @@ tf_kernel_library( "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:tstring", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", "@local_xla//xla:util", "@local_xla//xla/stream_executor", "@local_xla//xla/stream_executor/tpu:proto_helper", "@local_xla//xla/stream_executor/tpu:status_helper", "@local_xla//xla/stream_executor/tpu:tpu_api", "@local_xla//xla/stream_executor/tpu:tpu_ops_c_api_hdrs", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ] + if_libtpu_tf_status(), alwayslink = 1, ) @@ -837,7 +837,7 @@ cc_library( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", - "@local_xla//xla/stream_executor/platform", + "@local_xla//xla/stream_executor/platform:initialize", "@local_xla//xla/stream_executor/tpu:tpu_ops_c_api_hdrs", ], alwayslink = 1, @@ -859,9 +859,9 @@ cc_library( "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:tstring", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", - "@local_xla//xla/stream_executor/platform", + "@local_xla//xla/stream_executor/platform:initialize", "@local_xla//xla/stream_executor/tpu:tpu_node_context", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], alwayslink = True, ) @@ -959,7 +959,7 @@ cc_library( "@local_xla//xla:shape_util", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], alwayslink = 1, ) @@ -972,8 +972,8 @@ cc_library( "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/core/tpu:tpu_defs", "@com_google_absl//absl/numeric:bits", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:arithmetic", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:arithmetic", ], alwayslink = 1, ) @@ -1003,7 +1003,7 @@ cc_library( "@local_xla//xla:literal_util", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/stream_executor/tpu:c_api_conversions", "@local_xla//xla/stream_executor/tpu:c_api_decl", "@local_xla//xla/stream_executor/tpu:proto_helper", @@ -1127,8 +1127,8 @@ cc_library( "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:statusor", "@local_xla//xla:shape_util", - "@local_xla//xla/client:xla_builder", - "@local_xla//xla/client/lib:constants", + "@local_xla//xla/hlo/builder:xla_builder", + "@local_xla//xla/hlo/builder/lib:constants", ], alwayslink = True, ) @@ -1445,7 +1445,7 @@ tf_cc_test( "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -1551,7 +1551,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "sparse_core_layout_py_pb2", -# api_version = 2, # deps = [":sparse_core_layout_proto"], # ) # copybara:uncomment_end diff --git a/tensorflow/core/tpu/kernels/cross_replica_ops.cc b/tensorflow/core/tpu/kernels/cross_replica_ops.cc index 24a8ee467f63a7..6a027cd1a391f4 100644 --- a/tensorflow/core/tpu/kernels/cross_replica_ops.cc +++ b/tensorflow/core/tpu/kernels/cross_replica_ops.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/tpu/kernels/image_resize_ops.cc b/tensorflow/core/tpu/kernels/image_resize_ops.cc index 2c442e0748fc20..f3d2248a92fc29 100644 --- a/tensorflow/core/tpu/kernels/image_resize_ops.cc +++ b/tensorflow/core/tpu/kernels/image_resize_ops.cc @@ -22,8 +22,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/core/tpu/kernels/sharding_util_ops_test.cc b/tensorflow/core/tpu/kernels/sharding_util_ops_test.cc index 084802a89b5137..3185293b52d484 100644 --- a/tensorflow/core/tpu/kernels/sharding_util_ops_test.cc +++ b/tensorflow/core/tpu/kernels/sharding_util_ops_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor.h" @@ -41,7 +42,6 @@ limitations under the License. #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" #include "tsl/platform/errors.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/tpu/kernels/sparse_core_layout.proto b/tensorflow/core/tpu/kernels/sparse_core_layout.proto index 6b7bbd9ed5ebe5..a7e162b228acb8 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_layout.proto +++ b/tensorflow/core/tpu/kernels/sparse_core_layout.proto @@ -94,6 +94,12 @@ message SparseCoreTableLayout { // partitions. // sparse_core_shard_rotation = table_index * sparse_cores_per_partition int64 sparse_core_shard_rotation = 9; + + // The batch size per sparsecore for this table. This combines the batch sizes + // of all the features pointing to this table. + int64 per_sparse_core_batch_size = 10; + // Number of features that refer to this table. + int64 num_features = 11; } message SparseCoreTableLayouts { diff --git a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc index 8548a92efe0495..e25889827a49f3 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc +++ b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.cc @@ -92,7 +92,9 @@ Status ValidateInputs(const Tensor& indices_or_row_splits, const Tensor& values, // tensor. } else if (indices_or_row_splits.dims() == 2 && indices_or_row_splits.NumElements() >= 0) { - // TODO(pineapplejuice233): Add checking logic for sparse tensor input. + // NOTE(mrry): Checking logic for SparseTensor inputs is in + // `ComputeRowIdsBeforePadding()`, to avoid an extra traversal of the + // indices matrix. } else if (indices_or_row_splits.dims() == 1 && indices_or_row_splits.NumElements() > 0) { // Ragged tensor. @@ -114,6 +116,7 @@ Status ValidateInputs(const Tensor& indices_or_row_splits, const Tensor& values, Status ComputeRowIdsBeforePadding(const Tensor& indices_or_row_splits, const int32 total_id_count, + const int32 sample_count, int32* row_ids_before_padding) { // The only difference between dense tensor, sparse tensor and ragged tensor // is the row ids output. @@ -140,7 +143,14 @@ Status ComputeRowIdsBeforePadding(const Tensor& indices_or_row_splits, if (current_row_id < previous_row_id) { return absl::InvalidArgumentError( "Invalid indices_or_row_splits input, indices of SparseTensor need " - "to be sorted in ascending order."); + "to be sorted in ascending (non-decreasing) order."); + } + if (current_row_id >= sample_count) { + return absl::InvalidArgumentError(absl::StrCat( + "Invalid indices_or_row_splits input, indices of SparseTensor " + "contained a row_id ", + current_row_id, " that was >= the sample count (", sample_count, + ").")); } *(row_ids_before_padding + i) = current_row_id; previous_row_id = current_row_id; @@ -309,9 +319,9 @@ class ConvertToCooTensorOp : public OpKernel { auto row_ids_before_dedup = std::make_unique(total_id_count); - OP_REQUIRES_OK( - ctx, ComputeRowIdsBeforePadding(*indices_or_row_splits, total_id_count, - row_ids_before_dedup.get())); + OP_REQUIRES_OK(ctx, ComputeRowIdsBeforePadding( + *indices_or_row_splits, total_id_count, + sample_count_, row_ids_before_dedup.get())); // Compute the rescaled gains for non-sum combiners. std::optional> gains_rescale = @@ -520,9 +530,8 @@ void GetMinibatchesInCsrWithPhysicalReplicaOp::Compute(OpKernelContext* ctx) { "The number of minibatches per sparse core is ", num_minibatch_per_sc, ". But the max minibatches per sparse core is set to be ", max_minibatches_per_sc_, " which is smaller."))); - VLOG(2) << "GetMinibatchesInCsrWithPhysicalReplicaOp: " - << "program_key = '" << program_key << "'" - << ", table_name = '" << table_name_ << "'" + VLOG(2) << "GetMinibatchesInCsrWithPhysicalReplicaOp: " << "program_key = '" + << program_key << "'" << ", table_name = '" << table_name_ << "'" << ", max_ids = " << max_ids_per_partition << ", max_uniques = " << max_unique_ids_per_partition << ", num_minibatch_per_sc = " << num_minibatch_per_sc; @@ -1213,9 +1222,9 @@ void ConvertToListOfSparseCoreCooTensorsOp::Compute(OpKernelContext* ctx) { auto row_ids_before_dedup = std::unique_ptr( new std::remove_extent_t[total_id_count]); - OP_REQUIRES_OK( - ctx, ComputeRowIdsBeforePadding(*indices_or_row_splits, total_id_count, - row_ids_before_dedup.get())); + OP_REQUIRES_OK(ctx, ComputeRowIdsBeforePadding(*indices_or_row_splits, + total_id_count, sample_count_, + row_ids_before_dedup.get())); // Compute the rescaled gains for non-sum combiners. std::optional> gains_rescale = diff --git a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h index ce43521cbc5147..d3651d04de2d6e 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h +++ b/tensorflow/core/tpu/kernels/sparse_core_preprocess_ops.h @@ -55,7 +55,7 @@ Status ValidateInputs(const Tensor& indices_or_row_splits, const Tensor& values, // Compute the row id list before padding. Status ComputeRowIdsBeforePadding(const Tensor& indices_or_row_splits, - int32 total_id_count, + int32 total_id_count, int32 sample_count, int32* row_ids_before_padding); class GetMinibatchesInCsrWithPhysicalReplicaOp : public OpKernel { diff --git a/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc b/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc index acccaecc4d0a53..fa57f936e57c29 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc +++ b/tensorflow/core/tpu/kernels/sparse_core_xla_ops.cc @@ -28,9 +28,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/literal_util.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/tensorflow/core/tpu/kernels/sparse_core_xla_ops.h b/tensorflow/core/tpu/kernels/sparse_core_xla_ops.h index 0a4b9183a04184..71995cb92480f7 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_xla_ops.h +++ b/tensorflow/core/tpu/kernels/sparse_core_xla_ops.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_XLA_OPS_H_ #define TENSORFLOW_CORE_TPU_KERNELS_SPARSE_CORE_XLA_OPS_H_ -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/xla_data.pb.h" #include "tsl/platform/macros.h" diff --git a/tensorflow/core/tpu/kernels/topk_ops.cc b/tensorflow/core/tpu/kernels/topk_ops.cc index b8dcf2ba55b987..16334632946c25 100644 --- a/tensorflow/core/tpu/kernels/topk_ops.cc +++ b/tensorflow/core/tpu/kernels/topk_ops.cc @@ -19,8 +19,8 @@ limitations under the License. #include "absl/numeric/bits.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/tpu/tpu_defs.h" namespace tensorflow { diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op.cc b/tensorflow/core/tpu/kernels/tpu_compile_op.cc index b4a462a1e20b72..f1f4cbb6d380e5 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op.cc +++ b/tensorflow/core/tpu/kernels/tpu_compile_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/stream_executor/platform/initialize.h" #include "xla/stream_executor/tpu/tpu_node_context.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "tsl/platform/tstring.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace tpu { diff --git a/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc index 60636369401c17..e91d4b90af39a0 100644 --- a/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc +++ b/tensorflow/core/tpu/kernels/tpu_configuration_ops.cc @@ -29,6 +29,7 @@ limitations under the License. #include "xla/stream_executor/tpu/status_helper.h" #include "xla/stream_executor/tpu/tpu_api.h" #include "xla/stream_executor/tpu/tpu_ops_c_api.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "xla/util.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/framework/device_base.h" @@ -57,7 +58,6 @@ limitations under the License. #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "tsl/platform/tstring.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { diff --git a/tensorflow/core/tpu/kernels/tpu_embedding_ops.cc b/tensorflow/core/tpu/kernels/tpu_embedding_ops.cc index f23215c086b513..4b0e60e024ad03 100644 --- a/tensorflow/core/tpu/kernels/tpu_embedding_ops.cc +++ b/tensorflow/core/tpu/kernels/tpu_embedding_ops.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/layout_util.h" #include "xla/literal_util.h" #include "xla/shape.h" diff --git a/tensorflow/core/tpu/kernels/xla/BUILD b/tensorflow/core/tpu/kernels/xla/BUILD index 62c0cf82925549..08d2623965fa50 100644 --- a/tensorflow/core/tpu/kernels/xla/BUILD +++ b/tensorflow/core/tpu/kernels/xla/BUILD @@ -38,8 +38,8 @@ cc_library( "@local_xla//xla:shape_util", "@local_xla//xla:side_effect_util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:sharding_builder", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:sharding_builder", + "@local_xla//xla/hlo/builder:xla_builder", ], alwayslink = 1, ) @@ -78,7 +78,7 @@ cc_library( "@local_xla//xla:shape_util", "@local_xla//xla:util", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:xla_builder", + "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/ir:hlo", "@local_xla//xla/stream_executor/tpu:c_api_conversions", "@local_xla//xla/stream_executor/tpu:c_api_decl", diff --git a/tensorflow/core/tpu/kernels/xla/get_item_op.cc b/tensorflow/core/tpu/kernels/xla/get_item_op.cc index 324a2fc8831db1..bd26278f885bb0 100644 --- a/tensorflow/core/tpu/kernels/xla/get_item_op.cc +++ b/tensorflow/core/tpu/kernels/xla/get_item_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/core/tpu/kernels/xla/host_compute_ops.cc b/tensorflow/core/tpu/kernels/xla/host_compute_ops.cc index 3156c1f3ca8242..98b28a01cd9bb7 100644 --- a/tensorflow/core/tpu/kernels/xla/host_compute_ops.cc +++ b/tensorflow/core/tpu/kernels/xla/host_compute_ops.cc @@ -27,8 +27,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/sharding_builder.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/sharding_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/side_effect_util.h" diff --git a/tensorflow/core/tpu/kernels/xla/infeed_op.cc b/tensorflow/core/tpu/kernels/xla/infeed_op.cc index 6eed665b1516d3..8481d5b27b92e3 100644 --- a/tensorflow/core/tpu/kernels/xla/infeed_op.cc +++ b/tensorflow/core/tpu/kernels/xla/infeed_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/layout_util.h" #include "xla/shape.h" diff --git a/tensorflow/core/tpu/kernels/xla/inplace_ops.cc b/tensorflow/core/tpu/kernels/xla/inplace_ops.cc index f4c5b8a4cf0d50..26a6610ba2ea78 100644 --- a/tensorflow/core/tpu/kernels/xla/inplace_ops.cc +++ b/tensorflow/core/tpu/kernels/xla/inplace_ops.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/core/tpu/kernels/xla/outfeed_ops.cc b/tensorflow/core/tpu/kernels/xla/outfeed_ops.cc index fac67ba2289c31..393ba47a909cb1 100644 --- a/tensorflow/core/tpu/kernels/xla/outfeed_ops.cc +++ b/tensorflow/core/tpu/kernels/xla/outfeed_ops.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/core/tpu/tpu_compile.cc b/tensorflow/core/tpu/tpu_compile.cc index 4a5096edea3cce..706f3a50a5c101 100644 --- a/tensorflow/core/tpu/tpu_compile.cc +++ b/tensorflow/core/tpu/tpu_compile.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/core/tpu/tpu_compile.h" #include +#include +#include #include #include #include @@ -25,29 +27,44 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/types/span.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/shape_inference.h" #include "tensorflow/compiler/tf2xla/layout_util.h" -#include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" #include "xla/client/compile_only_client.h" -#include "xla/literal_util.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/status_macros.h" #include "xla/xla_data.pb.h" +#include "tensorflow/core/common_runtime/function_body.h" #include "tensorflow/core/common_runtime/function_utils.h" #include "tensorflow/core/common_runtime/graph_constructor.h" +#include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" #include "tensorflow/core/tpu/tpu_defs.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace tensorflow { namespace tpu { @@ -433,7 +450,7 @@ Status RunShapeInferenceOnComputation( Status CompileTFFunctionToHlo( const FunctionLibraryDefinition& flib_def, int graph_def_version, const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, - const std::vector& arg_shapes, + const std::vector& arg_shapes, const DeviceType& device_type, const GuaranteedConsts& guaranteed_constants, const NameAttrList& function, const tpu::TPUCompileMetadataProto& metadata, xla::CompileOnlyClient* client, @@ -442,7 +459,7 @@ Status CompileTFFunctionToHlo( bool use_tuple_args, XlaCompiler::CompilationResult* compilation_result) { XlaCompiler::Options compiler_options; FunctionLibraryDefinition flib_definition(flib_def); - compiler_options.device_type = DeviceType(DEVICE_TPU_XLA_JIT); + compiler_options.device_type = device_type; compiler_options.client = client; compiler_options.flib_def = &flib_definition; compiler_options.allow_cpu_custom_calls = false; diff --git a/tensorflow/core/tpu/tpu_compile.h b/tensorflow/core/tpu/tpu_compile.h index a0340cb126708b..df60568310ff9c 100644 --- a/tensorflow/core/tpu/tpu_compile.h +++ b/tensorflow/core/tpu/tpu_compile.h @@ -20,8 +20,14 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/jit/shape_inference.h" +#include "tensorflow/compiler/tf2xla/layout_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "xla/client/compile_only_client.h" +#include "xla/shape.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" @@ -43,7 +49,7 @@ Status RunShapeInferenceOnComputation( Status CompileTFFunctionToHlo( const FunctionLibraryDefinition& flib_def, int graph_def_version, const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns, - const std::vector& arg_shapes, + const std::vector& arg_shapes, const DeviceType& device_type, const GuaranteedConsts& guaranteed_constants, const NameAttrList& function, const tpu::TPUCompileMetadataProto& metadata, xla::CompileOnlyClient* client, diff --git a/tensorflow/core/tpu/tpu_embedding_configuration_utils.cc b/tensorflow/core/tpu/tpu_embedding_configuration_utils.cc index 8ee5d873c047ec..3da9697b2e84f4 100644 --- a/tensorflow/core/tpu/tpu_embedding_configuration_utils.cc +++ b/tensorflow/core/tpu/tpu_embedding_configuration_utils.cc @@ -15,23 +15,27 @@ limitations under the License. #include "tensorflow/core/tpu/tpu_embedding_configuration_utils.h" +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "tensorflow/core/protobuf/tpu/optimization_parameters.pb.h" +#include "tensorflow/core/protobuf/tpu/tpu_embedding_configuration.pb.h" +#include "tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h" namespace tensorflow { namespace tpu { -absl::StatusOr ComputeTotalTagCountForDynamicLearningRates( +absl::StatusOr ComputeTotalTagCountForOptimizerDynamicInputs( const tensorflow::tpu::TPUEmbeddingConfiguration& tpu_embedding_config) { // Ordering of tag elements helps make the subsequent error checking simpler. std::set tag_set; - for (const auto& table_descriptor : tpu_embedding_config.table_descriptor()) { - const auto& lr_spec = - table_descriptor.optimization_parameters().learning_rate(); - if (lr_spec.has_dynamic()) { - tag_set.insert(lr_spec.dynamic().tag()); - } + const auto& opt_params = table_descriptor.optimization_parameters(); + const auto tags_for_table = GetOptimizerDynamicInputTags(opt_params); + tag_set.insert(tags_for_table.begin(), tags_for_table.end()); } // Traverse the tag set to determine that tags are contiguous. @@ -46,7 +50,7 @@ absl::StatusOr ComputeTotalTagCountForDynamicLearningRates( ++next_tag; } - return tag_set.size(); + return static_cast(tag_set.size()); } } // namespace tpu diff --git a/tensorflow/core/tpu/tpu_embedding_configuration_utils.h b/tensorflow/core/tpu/tpu_embedding_configuration_utils.h index 5de698b1e05478..3ac55d17ff64c7 100644 --- a/tensorflow/core/tpu/tpu_embedding_configuration_utils.h +++ b/tensorflow/core/tpu/tpu_embedding_configuration_utils.h @@ -16,17 +16,19 @@ limitations under the License. #ifndef TENSORFLOW_CORE_TPU_TPU_EMBEDDING_CONFIGURATION_UTILS_H_ #define TENSORFLOW_CORE_TPU_TPU_EMBEDDING_CONFIGURATION_UTILS_H_ +#include + #include "absl/status/statusor.h" #include "tensorflow/core/protobuf/tpu/tpu_embedding_configuration.pb.h" namespace tensorflow { namespace tpu { -// Returns the total number of unique dynamic learning rate tags. If the tag -// specific is erroneous, returns an invalid argument error. For correct tag -// specification, see the comment next to the DynamicLearningRate proto in +// Returns the total number of unique dynamic input tags used in optimizers. If +// the tag specific is erroneous, returns an invalid argument error. For correct +// tag specification, see the comment next to the OptimizerDynamicInput proto in // //third_party/tensorflow/core/protobuf/tpu/optimization_parameters.proto. -absl::StatusOr ComputeTotalTagCountForDynamicLearningRates( +absl::StatusOr ComputeTotalTagCountForOptimizerDynamicInputs( const tensorflow::tpu::TPUEmbeddingConfiguration& tpu_embedding_config); } // namespace tpu diff --git a/tensorflow/core/tpu/tpu_embedding_errors_test.cc b/tensorflow/core/tpu/tpu_embedding_errors_test.cc index f0a8d869b797ef..261ca0ac63e2df 100644 --- a/tensorflow/core/tpu/tpu_embedding_errors_test.cc +++ b/tensorflow/core/tpu/tpu_embedding_errors_test.cc @@ -22,9 +22,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/platform/errors.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow::tpu { namespace { diff --git a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc index 19c0612efb4414..46f0fa991b6777 100644 --- a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc +++ b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.cc @@ -19,13 +19,13 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "xla/service/hlo.pb.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringprintf.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/tpu/optimization_parameters.pb.h" @@ -42,6 +42,8 @@ std::string GetOptimizationAlgorithmName(OptimizationAlgorithm alg) { return "AdagradMomentum"; case OptimizationAlgorithm::kBoundedAdagrad: return "BoundedAdagrad"; + case OptimizationAlgorithm::kFrequencyAwareAdagrad: + return "FrequencyAwareAdagrad"; case OptimizationAlgorithm::kStochasticGradientDescent: return "StochasticGradientDescent"; case OptimizationAlgorithm::kFtrl: @@ -86,6 +88,8 @@ std::string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg) { return "Adagrad with Momentum"; case OptimizationAlgorithm::kBoundedAdagrad: return "Bounded Adagrad"; + case OptimizationAlgorithm::kFrequencyAwareAdagrad: + return "Frequency Aware Adagrad"; case OptimizationAlgorithm::kStochasticGradientDescent: return "stochastic gradient descent"; case OptimizationAlgorithm::kFtrl: @@ -125,8 +129,8 @@ std::string GetOptimizationAlgorithmFriendlyName(OptimizationAlgorithm alg) { // Returns the number of optimization parameter vectors used by the optimization // algorithm, excluding the weights themselves and assuming no gradient // accumulation. -Status GetBaseAuxiliaryParameterCount(const OptimizationParameters& params, - int* count) { +absl::Status GetBaseAuxiliaryParameterCount( + const OptimizationParameters& params, int* count) { switch (params.parameters_case()) { case OptimizationAlgorithm::kAdagrad: *count = 1; @@ -137,6 +141,9 @@ Status GetBaseAuxiliaryParameterCount(const OptimizationParameters& params, case OptimizationAlgorithm::kBoundedAdagrad: *count = 1; return absl::OkStatus(); + case OptimizationAlgorithm::kFrequencyAwareAdagrad: + *count = 2; + return absl::OkStatus(); case OptimizationAlgorithm::kStochasticGradientDescent: *count = 0; return absl::OkStatus(); @@ -205,8 +212,9 @@ Status GetBaseAuxiliaryParameterCount(const OptimizationParameters& params, return errors::InvalidArgument("No optimization algorithm specified"); } -Status GetGradientAccumulationSupport(const OptimizationParameters& params, - GradientAccumulationSupport* support) { +absl::Status GetGradientAccumulationSupport( + const OptimizationParameters& params, + GradientAccumulationSupport* support) { int auxiliary_parameter_count; TF_RETURN_IF_ERROR( GetBaseAuxiliaryParameterCount(params, &auxiliary_parameter_count)); @@ -216,8 +224,8 @@ Status GetGradientAccumulationSupport(const OptimizationParameters& params, return absl::OkStatus(); } -Status UseGradientAccumulation(const OptimizationParameters& params, - bool* use_gradient_accumulation) { +absl::Status UseGradientAccumulation(const OptimizationParameters& params, + bool* use_gradient_accumulation) { GradientAccumulationSupport support; TF_RETURN_IF_ERROR(GetGradientAccumulationSupport(params, &support)); bool raw_gradient_accumulation_status = false; @@ -260,7 +268,7 @@ Status UseGradientAccumulation(const OptimizationParameters& params, return absl::OkStatus(); } -Status GetOptimizationAlgorithmStateVariables( +absl::Status GetOptimizationAlgorithmStateVariables( const OptimizationParameters& params, std::vector* state_variables) { // The parameter set for the weights themselves is required to be named @@ -295,6 +303,12 @@ Status GetOptimizationAlgorithmStateVariables( add_state_variable("accumulators"); break; } + case OptimizationAlgorithm::kFrequencyAwareAdagrad: { + add_state_variable("parameters"); + add_state_variable("accumulators"); + add_state_variable("counters"); + break; + } case OptimizationAlgorithm::kStochasticGradientDescent: { add_state_variable("parameters"); break; @@ -406,11 +420,39 @@ Status GetOptimizationAlgorithmStateVariables( return absl::OkStatus(); } +absl::flat_hash_set GetOptimizerDynamicInputTags( + const OptimizationParameters& params) { + absl::flat_hash_set tags; + if (params.learning_rate().has_dynamic()) { + tags.insert(params.learning_rate().dynamic().tag()); + } + tags.merge(GetOptimizerHyperParameterTags(params)); + return tags; +} + +absl::flat_hash_set GetOptimizerHyperParameterTags( + const OptimizationParameters& params) { + absl::flat_hash_set tags; + switch (params.parameters_case()) { + case OptimizationAlgorithm::kFrequencyAwareAdagrad: + tags.insert(params.frequency_aware_adagrad().step_counter().tag()); + break; + default: + break; + } + return tags; +} + +bool UsesDynamicInputsInOptimizer(const OptimizationParameters& params) { + return !GetOptimizerDynamicInputTags(params).empty(); +} + std::vector GetOptimizationAlgorithms() { return { OptimizationAlgorithm::kAdagrad, OptimizationAlgorithm::kAdagradMomentum, OptimizationAlgorithm::kBoundedAdagrad, + OptimizationAlgorithm::kFrequencyAwareAdagrad, OptimizationAlgorithm::kStochasticGradientDescent, OptimizationAlgorithm::kFtrl, OptimizationAlgorithm::kAdam, @@ -429,7 +471,7 @@ std::vector GetOptimizationAlgorithms() { }; } -Status LoadOpShapeFunction::operator()( +absl::Status LoadOpShapeFunction::operator()( shape_inference::InferenceContext* c) const { int table_id; TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id)); @@ -459,7 +501,7 @@ Status LoadOpShapeFunction::operator()( return absl::OkStatus(); } -Status RetrieveOpShapeFunction::operator()( +absl::Status RetrieveOpShapeFunction::operator()( shape_inference::InferenceContext* c) const { int table_id; TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id)); diff --git a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h index 3956852fdf753a..43643fbdb90781 100644 --- a/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h +++ b/tensorflow/core/tpu/tpu_embedding_optimization_parameters_utils.h @@ -21,8 +21,9 @@ limitations under the License. #include #include "absl/base/casts.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "tensorflow/core/framework/op_def_builder.h" -#include "tensorflow/core/platform/status.h" #include "tensorflow/core/protobuf/tpu/optimization_parameters.pb.h" namespace tensorflow { @@ -50,28 +51,45 @@ enum class GradientAccumulationSupport { // Returns the number of optimization parameter vectors used by the optimization // algorithm, excluding the weights themselves and assuming no gradient // accumulation. -Status GetBaseAuxiliaryParameterCount(const OptimizationParameters ¶ms, - int *count); +absl::Status GetBaseAuxiliaryParameterCount( + const OptimizationParameters ¶ms, int *count); // Returns whether (and how) an optimization algorithm supports gradient // accumulation. -Status GetGradientAccumulationSupport(const OptimizationParameters ¶ms, - GradientAccumulationSupport *support); +absl::Status GetGradientAccumulationSupport( + const OptimizationParameters ¶ms, GradientAccumulationSupport *support); // Returns whether both the given set of optimization parameters has gradient // accumulation turned on and that the algorithm used supports it or should // ignore that setting. Returns an error if gradient accumulation is enabled and // the algorithm does not support it. -Status UseGradientAccumulation(const OptimizationParameters ¶ms, - bool *use_gradient_accumulation); +absl::Status UseGradientAccumulation(const OptimizationParameters ¶ms, + bool *use_gradient_accumulation); // Returns the parameter specifications for the optimization algorithm (the main // parameters first, followed by any auxiliary parameters such as Adagrad // accumulators). -Status GetOptimizationAlgorithmStateVariables( +absl::Status GetOptimizationAlgorithmStateVariables( const OptimizationParameters ¶ms, std::vector *state_variables); +// Returns the set of dynamic input tags used by the optimization algorithm. +// This includes both dynamic learning rates and other hyperparameters (e.g., +// step counters for the frequency aware Adagrad optimizer). +absl::flat_hash_set GetOptimizerDynamicInputTags( + const OptimizationParameters ¶ms); + +// Returns the set of dynamic hyperparameter tags used by the optimization +// algorithm. This includes other hyperparameters used by the optimization +// algorithm (e.g., step counters for the frequency aware Adagrad optimizer). It +// excludes the dynamic learning rate tag. +absl::flat_hash_set GetOptimizerHyperParameterTags( + const OptimizationParameters ¶ms); + +// Returns true if the optimization algorithm uses dynamic inputs in its +// computation. +bool UsesDynamicInputsInOptimizer(const OptimizationParameters ¶ms); + // Maximum value of auxiliary_parametery_count for any optimization algorithm. // This count is used by TPU embedding load/retrieve and needs to be independent // of any particular TPU version and hence, we take the maximum across all TPU @@ -102,14 +120,14 @@ inline float GradientAccumulatorInitialValue() { class LoadOpShapeFunction { public: // Computes resulting shape and does parameter checking. - Status operator()(shape_inference::InferenceContext *c) const; + absl::Status operator()(shape_inference::InferenceContext *c) const; }; // Generic shape function for per-optimization-algorithm retrieve ops. class RetrieveOpShapeFunction { public: // Computes resulting shape and does parameter checking. - Status operator()(shape_inference::InferenceContext *c) const; + absl::Status operator()(shape_inference::InferenceContext *c) const; }; } // namespace tpu diff --git a/tensorflow/core/tpu/tpu_embedding_output_layout_utils.cc b/tensorflow/core/tpu/tpu_embedding_output_layout_utils.cc index 57cbd5aa8d998d..faddf8d106f832 100644 --- a/tensorflow/core/tpu/tpu_embedding_output_layout_utils.cc +++ b/tensorflow/core/tpu/tpu_embedding_output_layout_utils.cc @@ -28,10 +28,11 @@ namespace tpu { Status ComputeOutputTensorShapes( const tensorflow::tpu::TPUEmbeddingConfiguration& config, std::vector* shapes) { - const int64_t core_count_per_replica = - config.spmd_sharding().enabled() - ? config.spmd_sharding().num_cores_per_replica() - : 1; + int64_t core_count_per_replica = 1; + if (config.spmd_sharding().enabled() && + !config.spmd_sharding().use_manual_partitioning()) { + core_count_per_replica = config.spmd_sharding().num_cores_per_replica(); + } if (config.feature_descriptor_size() > 0) { for (const TPUEmbeddingConfiguration::FeatureDescriptor& feature : config.feature_descriptor()) { diff --git a/tensorflow/core/tpu/tpu_embedding_spmd_sharding_utils.cc b/tensorflow/core/tpu/tpu_embedding_spmd_sharding_utils.cc index 80a31ecff2d880..64720ad47fe01f 100644 --- a/tensorflow/core/tpu/tpu_embedding_spmd_sharding_utils.cc +++ b/tensorflow/core/tpu/tpu_embedding_spmd_sharding_utils.cc @@ -19,7 +19,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_format.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/platform/statusor.h" diff --git a/tensorflow/core/tpu/tpu_embedding_spmd_sharding_utils.h b/tensorflow/core/tpu/tpu_embedding_spmd_sharding_utils.h index 586c060bbfc7c3..957de62bacc1e9 100644 --- a/tensorflow/core/tpu/tpu_embedding_spmd_sharding_utils.h +++ b/tensorflow/core/tpu/tpu_embedding_spmd_sharding_utils.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_TPU_TPU_EMBEDDING_SPMD_SHARDING_UTILS_H_ #define TENSORFLOW_CORE_TPU_TPU_EMBEDDING_SPMD_SHARDING_UTILS_H_ -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/platform/statusor.h" diff --git a/tensorflow/core/transforms/cse/pass.cc b/tensorflow/core/transforms/cse/pass.cc index f84fdfaaf568e3..f940a2bff65702 100644 --- a/tensorflow/core/transforms/cse/pass.cc +++ b/tensorflow/core/transforms/cse/pass.cc @@ -17,11 +17,11 @@ limitations under the License. #include +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/core/ir/dialect.h" #include "tensorflow/core/ir/ops.h" diff --git a/tensorflow/core/util/BUILD b/tensorflow/core/util/BUILD index 9779a4ea8ee6be..72d271e0c5300d 100644 --- a/tensorflow/core/util/BUILD +++ b/tensorflow/core/util/BUILD @@ -744,7 +744,6 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", - "@local_xla//xla/stream_executor/platform:dso_loader", "@local_xla//xla/stream_executor/rocm:hipsolver_wrapper", "@local_xla//xla/stream_executor/rocm:rocblas_plugin", "@local_xla//xla/stream_executor/rocm:rocblas_wrapper", diff --git a/tensorflow/core/util/autotune_maps/BUILD b/tensorflow/core/util/autotune_maps/BUILD index e428e14e170b40..7be71e3ab3e7bb 100644 --- a/tensorflow/core/util/autotune_maps/BUILD +++ b/tensorflow/core/util/autotune_maps/BUILD @@ -52,8 +52,8 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/protobuf:dnn_proto_cc", "@local_xla//xla/tsl/lib/strings:proto_serialization", + "@local_xla//xla/tsl/protobuf:dnn_proto_cc", ], ) @@ -66,8 +66,8 @@ tf_cc_test( ":conv_parameters_proto_cc", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:dnn_proto_cc", "@local_xla//xla:test", + "@local_xla//xla/tsl/protobuf:dnn_proto_cc", ], ) @@ -78,7 +78,7 @@ tf_proto_library( ], protodeps = [ "//tensorflow/core/framework:types_proto", - "@local_tsl//tsl/protobuf:dnn_proto", + "@local_xla//xla/tsl/protobuf:dnn_proto", ], ) @@ -136,7 +136,7 @@ tf_proto_library( ], protodeps = [ "//tensorflow/core/util/autotune_maps:conv_parameters_proto", - "@local_tsl//tsl/protobuf:dnn_proto", + "@local_xla//xla/tsl/protobuf:dnn_proto", ], visibility = [ "//waymo/ml/deploy/benchmark:__subpackages__", @@ -148,7 +148,6 @@ tf_proto_library( # py_proto_library( # name = "autotune_map_py_pb2", # has_services = 0, -# api_version = 2, # visibility = ["//waymo/ml/deploy/system/autotuning:__subpackages__"], # deps = [":autotune_map_proto"], # ) @@ -180,12 +179,13 @@ tf_cuda_library( "//tensorflow/core:framework", "//tensorflow/core/platform:status", "//tensorflow/core/platform:str_util", - "@local_tsl//tsl/protobuf:dnn_proto_cc", + "@com_google_absl//absl/strings:string_view", "@local_xla//xla:status_macros", "@local_xla//xla/stream_executor:dnn", "@local_xla//xla/stream_executor:platform_manager", "@local_xla//xla/stream_executor/gpu:gpu_init", "@local_xla//xla/tsl/lib/strings:proto_serialization", + "@local_xla//xla/tsl/protobuf:dnn_proto_cc", ], ) diff --git a/tensorflow/core/util/autotune_maps/autotune_map.proto b/tensorflow/core/util/autotune_maps/autotune_map.proto index c655b3c1a5927d..79192075761bd9 100644 --- a/tensorflow/core/util/autotune_maps/autotune_map.proto +++ b/tensorflow/core/util/autotune_maps/autotune_map.proto @@ -21,8 +21,8 @@ syntax = "proto3"; package tensorflow; +import "xla/tsl/protobuf/dnn.proto"; import "tensorflow/core/util/autotune_maps/conv_parameters.proto"; -import "tsl/protobuf/dnn.proto"; message ConvMapProto { message Entry { diff --git a/tensorflow/core/util/autotune_maps/autotune_serialize.cc b/tensorflow/core/util/autotune_maps/autotune_serialize.cc index c601502a0d0512..51e4a365821767 100644 --- a/tensorflow/core/util/autotune_maps/autotune_serialize.cc +++ b/tensorflow/core/util/autotune_maps/autotune_serialize.cc @@ -26,13 +26,13 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_init.h" #include "xla/stream_executor/platform_manager.h" #include "xla/tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/util/activation_mode.h" #include "tensorflow/core/util/autotune_maps/autotune_map.pb.h" #include "tensorflow/core/util/autotune_maps/conv_autotune_maps.h" #include "tensorflow/core/util/autotune_maps/conv_parameters.h" #include "tensorflow/core/util/autotune_maps/conv_parameters.pb.h" -#include "tsl/protobuf/dnn.pb.h" namespace tensorflow { @@ -173,7 +173,7 @@ Status PopulateConvMap( } // namespace #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM -Status SerializeAutotuneMaps(std::string *output) { +absl::Status SerializeAutotuneMaps(std::string *output) { AutotuneMapsProto proto; #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM TF_ASSIGN_OR_RETURN(*proto.mutable_conv_map(), @@ -185,7 +185,7 @@ Status SerializeAutotuneMaps(std::string *output) { return absl::OkStatus(); } -Status LoadSerializedAutotuneMaps(absl::string_view s) { +absl::Status LoadSerializedAutotuneMaps(absl::string_view s) { #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM AutotuneMapsProto proto; // The explicit string conversion here is a workaround for diff --git a/tensorflow/core/util/autotune_maps/autotune_serialize.h b/tensorflow/core/util/autotune_maps/autotune_serialize.h index 9aedc6de0abce3..8c8bdc2f7e13a7 100644 --- a/tensorflow/core/util/autotune_maps/autotune_serialize.h +++ b/tensorflow/core/util/autotune_maps/autotune_serialize.h @@ -27,6 +27,7 @@ limitations under the License. #include +#include "absl/strings/string_view.h" #include "tensorflow/core/platform/status.h" namespace tensorflow { @@ -34,11 +35,11 @@ namespace tensorflow { // TODO(b/189530096) Support autotune maps for more ops. // Loads autotune maps from string output by SerializeAutotuneMaps and uses // them to update the runtime autotune maps. -Status LoadSerializedAutotuneMaps(absl::string_view s); +absl::Status LoadSerializedAutotuneMaps(absl::string_view s); // Serializes all the autotune maps into a string that can be decoded by // LoadSerializedAutotuneMaps. -Status SerializeAutotuneMaps(std::string* output); +absl::Status SerializeAutotuneMaps(std::string* output); // Resets all autotune maps. For test use only. void ResetAutotuneMaps(); diff --git a/tensorflow/core/util/autotune_maps/conv_map_wrapper.cc b/tensorflow/core/util/autotune_maps/conv_map_wrapper.cc index 0bd1122c132238..8044441680501b 100644 --- a/tensorflow/core/util/autotune_maps/conv_map_wrapper.cc +++ b/tensorflow/core/util/autotune_maps/conv_map_wrapper.cc @@ -20,9 +20,9 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" #include "xla/tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "tensorflow/core/util/autotune_maps/autotune_map.pb.h" #include "tensorflow/core/util/autotune_maps/conv_parameters.pb.h" -#include "tsl/protobuf/dnn.pb.h" namespace tensorflow { diff --git a/tensorflow/core/util/autotune_maps/conv_map_wrapper_test.cc b/tensorflow/core/util/autotune_maps/conv_map_wrapper_test.cc index 6279cd03ae25ac..5443e8e28c7193 100644 --- a/tensorflow/core/util/autotune_maps/conv_map_wrapper_test.cc +++ b/tensorflow/core/util/autotune_maps/conv_map_wrapper_test.cc @@ -20,10 +20,10 @@ limitations under the License. #include #include "xla/test.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "tensorflow/core/util/autotune_maps/autotune_map.pb.h" #include "tensorflow/core/util/autotune_maps/conv_parameters.pb.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/dnn.pb.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/util/autotune_maps/conv_parameters.proto b/tensorflow/core/util/autotune_maps/conv_parameters.proto index 03a9cfd005d6f6..aee217e80968c9 100644 --- a/tensorflow/core/util/autotune_maps/conv_parameters.proto +++ b/tensorflow/core/util/autotune_maps/conv_parameters.proto @@ -22,8 +22,8 @@ syntax = "proto3"; package tensorflow; +import "xla/tsl/protobuf/dnn.proto"; import "tensorflow/core/framework/types.proto"; -import "tsl/protobuf/dnn.proto"; // LINT.IfChange diff --git a/tensorflow/core/util/batch_util.cc b/tensorflow/core/util/batch_util.cc index 24823886285488..5a82c0af6c6959 100644 --- a/tensorflow/core/util/batch_util.cc +++ b/tensorflow/core/util/batch_util.cc @@ -33,8 +33,8 @@ namespace batch_util { namespace { -Status ValidateInput(const Tensor& parent, const Tensor& element, - int64_t index) { +absl::Status ValidateInput(const Tensor& parent, const Tensor& element, + int64_t index) { DCHECK_NE(parent.dim_size(0), 0); DCHECK_GE(index, 0); if (element.NumElements() != (parent.NumElements() / parent.dim_size(0))) { @@ -50,8 +50,8 @@ Status ValidateInput(const Tensor& parent, const Tensor& element, } template -Status HandleElementToSlice(const Tensor& /* element */, T* src, T* dest, - int64_t num_values) { +absl::Status HandleElementToSlice(const Tensor& /* element */, T* src, T* dest, + int64_t num_values) { static_assert(tsl::is_simple_type::value, "Memcpy requires a simple type."); memcpy(dest, src, num_values * sizeof(T)); @@ -59,8 +59,8 @@ Status HandleElementToSlice(const Tensor& /* element */, T* src, T* dest, } template <> -Status HandleElementToSlice(const Tensor& element, tstring* src, - tstring* dest, int64_t num_values) { +absl::Status HandleElementToSlice(const Tensor& element, tstring* src, + tstring* dest, int64_t num_values) { if (element.RefCountIsOne()) { for (int64_t i = 0; i < num_values; ++i) { *dest++ = std::move(*src++); @@ -72,8 +72,8 @@ Status HandleElementToSlice(const Tensor& element, tstring* src, } template <> -Status HandleElementToSlice(const Tensor& element, Variant* src, - Variant* dest, int64_t num_values) { +absl::Status HandleElementToSlice(const Tensor& element, Variant* src, + Variant* dest, int64_t num_values) { if (element.RefCountIsOne()) { for (int64_t i = 0; i < num_values; ++i) { *dest++ = std::move(*src++); @@ -85,18 +85,19 @@ Status HandleElementToSlice(const Tensor& element, Variant* src, } template <> -Status HandleElementToSlice(const Tensor& /* element */, - ResourceHandle* src, - ResourceHandle* dest, - int64_t num_values) { +absl::Status HandleElementToSlice(const Tensor& /* element */, + ResourceHandle* src, + ResourceHandle* dest, + int64_t num_values) { std::copy_n(src, num_values, dest); return absl::OkStatus(); } template <> -Status HandleElementToSlice(const Tensor& /* element */, - Eigen::half* src, Eigen::half* dest, - int64_t num_values) { +absl::Status HandleElementToSlice(const Tensor& /* element */, + Eigen::half* src, + Eigen::half* dest, + int64_t num_values) { std::copy_n(src, num_values, dest); return absl::OkStatus(); } @@ -180,7 +181,7 @@ void HandleSliceToElement(Tensor* parent, Eigen::half* src, } // namespace // Copies element into the index^th slice of parent (in the 0th dimension). -Status CopyElementToSlice(Tensor element, Tensor* parent, int64_t index) { +absl::Status CopyElementToSlice(Tensor element, Tensor* parent, int64_t index) { TF_RETURN_IF_ERROR(ValidateInput(*parent, element, index)); const int64_t num_values = element.NumElements(); #define HANDLE_TYPE(T) \ @@ -201,8 +202,8 @@ Status CopyElementToSlice(Tensor element, Tensor* parent, int64_t index) { } // Copies the index^th slice of parent (in the 0th dimension) into element. -Status CopySliceToElement(const Tensor& parent, Tensor* element, - int64_t index) { +absl::Status CopySliceToElement(const Tensor& parent, Tensor* element, + int64_t index) { TF_RETURN_IF_ERROR(ValidateInput(parent, *element, index)); const int64_t num_values = element->NumElements(); @@ -226,9 +227,9 @@ Status CopySliceToElement(const Tensor& parent, Tensor* element, // Does the same thing as `CopyContiguousSlices` except it might move // the underlying data from `src` to `dst` when possible. -Status MaybeMoveContiguousSlices(Tensor& src, int64_t src_offset, - int64_t dst_offset, int64_t num_slices, - Tensor* dst) { +absl::Status MaybeMoveContiguousSlices(Tensor& src, int64_t src_offset, + int64_t dst_offset, int64_t num_slices, + Tensor* dst) { if (src.dtype() != dst->dtype()) { return absl::FailedPreconditionError(absl::StrCat( "MaybeMoveContiguousSlices cannot perform copy: src and dst have " @@ -299,9 +300,9 @@ Status MaybeMoveContiguousSlices(Tensor& src, int64_t src_offset, } } -Status CopyContiguousSlices(const Tensor& src, int64_t src_offset, - int64_t dst_offset, int64_t num_slices, - Tensor* dst) { +absl::Status CopyContiguousSlices(const Tensor& src, int64_t src_offset, + int64_t dst_offset, int64_t num_slices, + Tensor* dst) { if (src.dtype() != dst->dtype()) { return errors::FailedPrecondition( "CopyContiguousSlices cannot perform copy: src and dst have different " @@ -375,7 +376,8 @@ Status CopyContiguousSlices(const Tensor& src, int64_t src_offset, // // NOTE(mrry): The implementation may be able to optimize the copy to a move. // This is particularly important for DT_STRING tensors. -Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64_t index) { +absl::Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, + int64_t index) { TF_RETURN_IF_ERROR(ValidateInput(*parent, *element, index)); const int64_t num_values = element->NumElements(); @@ -400,7 +402,8 @@ Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64_t index) { // The following five functions are copied from padding_fifo_queue.cc. // TODO(mrry): Reconcile these functions with the similar methods in the // queue implementation. -Status ValidateElementToLargerSlice(const Tensor& element, Tensor* parent) { +absl::Status ValidateElementToLargerSlice(const Tensor& element, + Tensor* parent) { DCHECK_NE(parent->dim_size(0), 0); if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) { TensorShape chip_shape = parent->shape(); @@ -415,8 +418,8 @@ Status ValidateElementToLargerSlice(const Tensor& element, Tensor* parent) { } template -Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent, - int index) { +absl::Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent, + int index) { TF_RETURN_IF_ERROR(ValidateElementToLargerSlice(element, parent)); if (element.NumElements() == 0) { return absl::OkStatus(); @@ -435,8 +438,8 @@ Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent, } template -Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent, - int index) { +absl::Status HandleElementToLargerSliceWithRank(const Tensor& element, + Tensor* parent, int index) { #define HANDLE_TYPE(T) \ case DataTypeToEnum::value: { \ return HandleElementToLargerSlice(element, parent, index); \ @@ -452,8 +455,8 @@ Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent, } } -Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent, - int index) { +absl::Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent, + int index) { if (parent->dims() != element.dims() + 1) { return errors::Internal( "Mismatched ranks. Element's rank is: ", element.dims(), @@ -482,7 +485,7 @@ Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent, } } -Status SetElementZero(Tensor* element, const Tensor& padding) { +absl::Status SetElementZero(Tensor* element, const Tensor& padding) { #define HANDLE_TYPE(T) \ if (element->dtype() == DataTypeToEnum::value) { \ element->flat().setConstant(padding.scalar()()); \ diff --git a/tensorflow/core/util/batch_util.h b/tensorflow/core/util/batch_util.h index d31fd647ca1283..176c229ad80fd5 100644 --- a/tensorflow/core/util/batch_util.h +++ b/tensorflow/core/util/batch_util.h @@ -27,10 +27,11 @@ namespace batch_util { // to move the `element` argument into this function, and the implementation // may be able to optimize the copy to a move. This is particularly important // for DT_STRING tensors. -Status CopyElementToSlice(Tensor element, Tensor* parent, int64_t index); +absl::Status CopyElementToSlice(Tensor element, Tensor* parent, int64_t index); // Copies the index^th slice of parent (in the 0th dimension) into element. -Status CopySliceToElement(const Tensor& parent, Tensor* element, int64_t index); +absl::Status CopySliceToElement(const Tensor& parent, Tensor* element, + int64_t index); // Copies 'num_slices' contiguous slices from 'src' tensor starting from index // 'src_offset' into target tensor 'dst', and places them into slices @@ -40,31 +41,32 @@ Status CopySliceToElement(const Tensor& parent, Tensor* element, int64_t index); // requires cum_prod(src.shape[1:] == cum_prod(dst->shape[1:]). For example if // source is of shape [x, 2, 1] and dst is a tensor of shape [y, 1, 2], this // function can still proceed successfully. -Status CopyContiguousSlices(const Tensor& src, int64_t src_offset, - int64_t dst_offset, int64_t num_slices, - Tensor* dst); +absl::Status CopyContiguousSlices(const Tensor& src, int64_t src_offset, + int64_t dst_offset, int64_t num_slices, + Tensor* dst); // Copies the index^th slice of parent (in the 0th dimension) into element. // // NOTE(mrry): The implementation may be able to optimize the copy to a move. // This is particularly important for DT_STRING tensors. -Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64_t index); +absl::Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, + int64_t index); // Moves `src` Tensor's data in [src_offset, src_offset+num_slices) along // the first dimension if possible. Otherwise, copy them into `dst`. -Status MaybeMoveContiguousSlices(Tensor& src, int64_t src_offset, - int64_t dst_offset, int64_t num_slices, - Tensor* dst); +absl::Status MaybeMoveContiguousSlices(Tensor& src, int64_t src_offset, + int64_t dst_offset, int64_t num_slices, + Tensor* dst); // Zero-initializes the tensor `element` using the scalar stored in `padding`. // Both `element` and `padding` must have matching `dtype`. -Status SetElementZero(Tensor* element, const Tensor& padding); +absl::Status SetElementZero(Tensor* element, const Tensor& padding); // Copies `element` into a (0th dimension) slice of `parent`, assuming // the shape of `element` is strictly not larger along any axis than a // slice. -Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent, - int index); +absl::Status CopyElementToLargerSlice(const Tensor& element, Tensor* parent, + int index); } // namespace batch_util } // namespace tensorflow diff --git a/tensorflow/core/util/ctc/ctc_beam_search.h b/tensorflow/core/util/ctc/ctc_beam_search.h index c325426aed3559..a592d7a36ec226 100644 --- a/tensorflow/core/util/ctc/ctc_beam_search.h +++ b/tensorflow/core/util/ctc/ctc_beam_search.h @@ -99,10 +99,10 @@ class CTCBeamSearchDecoder : public CTCDecoder { ~CTCBeamSearchDecoder() override {} // Run the hibernating beam search algorithm on the given input. - Status Decode(const typename CTCDecoder::SequenceLength& seq_len, - const std::vector::Input>& input, - std::vector::Output>* output, - typename CTCDecoder::ScoreOutput* scores) override; + absl::Status Decode(const typename CTCDecoder::SequenceLength& seq_len, + const std::vector::Input>& input, + std::vector::Output>* output, + typename CTCDecoder::ScoreOutput* scores) override; // Calculate the next step of the beam search and update the internal state. template @@ -129,8 +129,8 @@ class CTCBeamSearchDecoder : public CTCDecoder { void Reset(); // Extract the top n paths at current time step - Status TopPaths(int n, std::vector>* paths, - std::vector* log_probs, bool merge_repeated) const; + absl::Status TopPaths(int n, std::vector>* paths, + std::vector* log_probs, bool merge_repeated) const; private: int beam_width_; @@ -157,7 +157,7 @@ class CTCBeamSearchDecoder : public CTCDecoder { }; template -Status CTCBeamSearchDecoder::Decode( +absl::Status CTCBeamSearchDecoder::Decode( const typename CTCDecoder::SequenceLength& seq_len, const std::vector::Input>& input, std::vector::Output>* output, @@ -198,7 +198,7 @@ Status CTCBeamSearchDecoder::Decode( leaves_.push(entry); } - Status status = + absl::Status status = TopPaths(top_n, &beams, &beam_log_probabilities, this->merge_repeated_); if (!status.ok()) { return status; @@ -400,7 +400,7 @@ void CTCBeamSearchDecoder::Reset() { } template -Status CTCBeamSearchDecoder::TopPaths( +absl::Status CTCBeamSearchDecoder::TopPaths( int n, std::vector>* paths, std::vector* log_probs, bool merge_repeated) const { CHECK_NOTNULL(paths)->clear(); diff --git a/tensorflow/core/util/ctc/ctc_decoder.h b/tensorflow/core/util/ctc/ctc_decoder.h index 088722adb892b8..8e6b3477d2baac 100644 --- a/tensorflow/core/util/ctc/ctc_decoder.h +++ b/tensorflow/core/util/ctc/ctc_decoder.h @@ -56,9 +56,10 @@ class CTCDecoder { // - input[t].rows(b) - t = 0 to timesteps; b = 0 t batch_size_ // - output.size() specifies the number of beams to be returned. // - scores(b, i) - b = 0 to batch_size; i = 0 to output.size() - virtual Status Decode(const SequenceLength& seq_len, - const std::vector& input, - std::vector* output, ScoreOutput* scores) = 0; + virtual absl::Status Decode(const SequenceLength& seq_len, + const std::vector& input, + std::vector* output, + ScoreOutput* scores) = 0; int batch_size() { return batch_size_; } int num_classes() { return num_classes_; } @@ -79,10 +80,10 @@ class CTCGreedyDecoder : public CTCDecoder { CTCGreedyDecoder(int num_classes, int batch_size, bool merge_repeated) : CTCDecoder(num_classes, batch_size, merge_repeated) {} - Status Decode(const typename CTCDecoder::SequenceLength& seq_len, - const std::vector::Input>& input, - std::vector::Output>* output, - typename CTCDecoder::ScoreOutput* scores) override { + absl::Status Decode(const typename CTCDecoder::SequenceLength& seq_len, + const std::vector::Input>& input, + std::vector::Output>* output, + typename CTCDecoder::ScoreOutput* scores) override { if (output->empty() || (*output)[0].size() < Decoder::batch_size_) { return errors::InvalidArgument( "output needs to be of size at least (1, batch_size)."); diff --git a/tensorflow/core/util/ctc/ctc_loss_calculator.h b/tensorflow/core/util/ctc/ctc_loss_calculator.h index 5f4311c9ee31b0..12c4ac0a96530c 100644 --- a/tensorflow/core/util/ctc/ctc_loss_calculator.h +++ b/tensorflow/core/util/ctc/ctc_loss_calculator.h @@ -65,13 +65,12 @@ class CTCLossCalculator { template - Status CalculateLoss(const VectorIn& seq_len, const LabelSequences& labels, - const std::vector& inputs, - bool preprocess_collapse_repeated, - bool ctc_merge_repeated, - bool ignore_longer_outputs_than_inputs, VectorOut* loss, - std::vector* gradients, - DeviceBase::CpuWorkerThreads* workers = nullptr) const; + absl::Status CalculateLoss( + const VectorIn& seq_len, const LabelSequences& labels, + const std::vector& inputs, bool preprocess_collapse_repeated, + bool ctc_merge_repeated, bool ignore_longer_outputs_than_inputs, + VectorOut* loss, std::vector* gradients, + DeviceBase::CpuWorkerThreads* workers = nullptr) const; private: void CalculateForwardVariables(const std::vector& l_prime, @@ -94,11 +93,13 @@ class CTCLossCalculator { // batch. Return value: // max_{b in batch_size} l_primes[b].size() template - Status PopulateLPrimes(bool preprocess_collapse_repeated, - bool ignore_longer_outputs_than_inputs, int batch_size, - int num_classes, const Vector& seq_len, - const LabelSequences& labels, size_t* max_u_prime, - LabelSequences* l_primes) const; + absl::Status PopulateLPrimes(bool preprocess_collapse_repeated, + bool ignore_longer_outputs_than_inputs, + int batch_size, int num_classes, + const Vector& seq_len, + const LabelSequences& labels, + size_t* max_u_prime, + LabelSequences* l_primes) const; // Utility indices for the CTC algorithm. int blank_index_; @@ -111,7 +112,7 @@ class CTCLossCalculator { template template -Status CTCLossCalculator::CalculateLoss( +absl::Status CTCLossCalculator::CalculateLoss( const VectorIn& seq_len, const LabelSequences& labels, const std::vector& inputs, bool preprocess_collapse_repeated, bool ctc_merge_repeated, bool ignore_longer_outputs_than_inputs, @@ -164,7 +165,7 @@ Status CTCLossCalculator::CalculateLoss( // and calculate the maximum necessary allocation size. LabelSequences l_primes(batch_size); size_t max_u_prime = 0; - Status l_p_ret = PopulateLPrimes( + absl::Status l_p_ret = PopulateLPrimes( preprocess_collapse_repeated, ignore_longer_outputs_than_inputs, batch_size, num_classes, seq_len, labels, &max_u_prime, &l_primes); if (!l_p_ret.ok()) { @@ -284,7 +285,7 @@ Status CTCLossCalculator::CalculateLoss( template template -Status CTCLossCalculator::PopulateLPrimes( +absl::Status CTCLossCalculator::PopulateLPrimes( bool preprocess_collapse_repeated, bool ignore_longer_outputs_than_inputs, int batch_size, int num_classes, const Vector& seq_len, const LabelSequences& labels, size_t* max_u_prime, diff --git a/tensorflow/core/util/cuda_solvers.cc b/tensorflow/core/util/cuda_solvers.cc index 2e990d9dc99125..679e0d25f54d08 100644 --- a/tensorflow/core/util/cuda_solvers.cc +++ b/tensorflow/core/util/cuda_solvers.cc @@ -21,7 +21,6 @@ #include "third_party/gpus/cuda/include/cublas_v2.h" #include "third_party/gpus/cuda/include/cusolverDn.h" -#include "xla/stream_executor/gpu/scoped_activate_context.h" #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" @@ -98,8 +97,6 @@ using trsm_Z = cublasStatus_t(cublasContext*, cublasSideMode_t, namespace tensorflow { namespace { -using se::gpu::ScopedActivateContext; - inline bool CopyHostToDevice(OpKernelContext* context, void* dst, const void* src, uint64 bytes) { auto stream = context->op_device_context()->stream(); @@ -228,7 +225,8 @@ void GpuSolver::CheckLapackInfoAndDeleteSolverAsync( std::function&)> info_checker_callback, std::vector host_lapack_infos) { - ScopedActivateContext scoped_activation{stream->parent()}; + std::unique_ptr scoped_activation = + stream->parent()->Activate(); Status status; for (const auto& host_lapack_info : host_lapack_infos) { for (int i = 0; i < host_lapack_info.size() && status.ok(); ++i) { diff --git a/tensorflow/core/util/debug_events_writer.cc b/tensorflow/core/util/debug_events_writer.cc index db0f83ed8a0526..9790422adc2701 100644 --- a/tensorflow/core/util/debug_events_writer.cc +++ b/tensorflow/core/util/debug_events_writer.cc @@ -44,7 +44,7 @@ SingleDebugEventFileWriter::SingleDebugEventFileWriter(const string& file_path) num_outstanding_events_(0), writer_mu_() {} -Status SingleDebugEventFileWriter::Init() { +absl::Status SingleDebugEventFileWriter::Init() { if (record_writer_ != nullptr) { // TODO(cais): We currently don't check for file deletion. When the need // arises, check and fix it. @@ -83,7 +83,7 @@ void SingleDebugEventFileWriter::WriteSerializedDebugEvent( } } -Status SingleDebugEventFileWriter::Flush() { +absl::Status SingleDebugEventFileWriter::Flush() { const int num_outstanding = num_outstanding_events_.load(); if (num_outstanding == 0) { return absl::OkStatus(); @@ -106,10 +106,10 @@ Status SingleDebugEventFileWriter::Flush() { return absl::OkStatus(); } -Status SingleDebugEventFileWriter::Close() { - Status status = Flush(); +absl::Status SingleDebugEventFileWriter::Close() { + absl::Status status = Flush(); if (writable_file_ != nullptr) { - Status close_status = writable_file_->Close(); + absl::Status close_status = writable_file_->Close(); if (!close_status.ok()) { status = close_status; } @@ -142,7 +142,7 @@ DebugEventsWriter* DebugEventsWriter::GetDebugEventsWriter( } // static -Status DebugEventsWriter::LookUpDebugEventsWriter( +absl::Status DebugEventsWriter::LookUpDebugEventsWriter( const string& dump_root, DebugEventsWriter** debug_events_writer) { mutex_lock l(DebugEventsWriter::factory_mu_); std::unordered_map>* writer_pool = @@ -155,7 +155,7 @@ Status DebugEventsWriter::LookUpDebugEventsWriter( return absl::OkStatus(); } -Status DebugEventsWriter::Init() { +absl::Status DebugEventsWriter::Init() { mutex_lock l(initialization_mu_); // TODO(cais): We currently don't check for file deletion. When the need @@ -205,33 +205,34 @@ Status DebugEventsWriter::Init() { return absl::OkStatus(); } -Status DebugEventsWriter::WriteSourceFile(SourceFile* source_file) { +absl::Status DebugEventsWriter::WriteSourceFile(SourceFile* source_file) { DebugEvent debug_event; debug_event.set_allocated_source_file(source_file); return SerializeAndWriteDebugEvent(&debug_event, SOURCE_FILES); } -Status DebugEventsWriter::WriteStackFrameWithId( +absl::Status DebugEventsWriter::WriteStackFrameWithId( StackFrameWithId* stack_frame_with_id) { DebugEvent debug_event; debug_event.set_allocated_stack_frame_with_id(stack_frame_with_id); return SerializeAndWriteDebugEvent(&debug_event, STACK_FRAMES); } -Status DebugEventsWriter::WriteGraphOpCreation( +absl::Status DebugEventsWriter::WriteGraphOpCreation( GraphOpCreation* graph_op_creation) { DebugEvent debug_event; debug_event.set_allocated_graph_op_creation(graph_op_creation); return SerializeAndWriteDebugEvent(&debug_event, GRAPHS); } -Status DebugEventsWriter::WriteDebuggedGraph(DebuggedGraph* debugged_graph) { +absl::Status DebugEventsWriter::WriteDebuggedGraph( + DebuggedGraph* debugged_graph) { DebugEvent debug_event; debug_event.set_allocated_debugged_graph(debugged_graph); return SerializeAndWriteDebugEvent(&debug_event, GRAPHS); } -Status DebugEventsWriter::WriteExecution(Execution* execution) { +absl::Status DebugEventsWriter::WriteExecution(Execution* execution) { if (circular_buffer_size_ <= 0) { // No cyclic-buffer behavior. DebugEvent debug_event; @@ -254,7 +255,7 @@ Status DebugEventsWriter::WriteExecution(Execution* execution) { } } -Status DebugEventsWriter::WriteGraphExecutionTrace( +absl::Status DebugEventsWriter::WriteGraphExecutionTrace( GraphExecutionTrace* graph_execution_trace) { TF_RETURN_IF_ERROR(Init()); if (circular_buffer_size_ <= 0) { @@ -279,7 +280,7 @@ Status DebugEventsWriter::WriteGraphExecutionTrace( } } -Status DebugEventsWriter::WriteGraphExecutionTrace( +absl::Status DebugEventsWriter::WriteGraphExecutionTrace( const string& tfdbg_context_id, const string& device_name, const string& op_name, int32_t output_slot, int32_t tensor_debug_mode, const Tensor& tensor_value) { @@ -356,7 +357,7 @@ int DebugEventsWriter::RegisterDeviceAndGetId(const string& device_name) { return device_id; } -Status DebugEventsWriter::FlushNonExecutionFiles() { +absl::Status DebugEventsWriter::FlushNonExecutionFiles() { TF_RETURN_IF_ERROR(Init()); if (source_files_writer_ != nullptr) { TF_RETURN_IF_ERROR(source_files_writer_->Flush()); @@ -370,7 +371,7 @@ Status DebugEventsWriter::FlushNonExecutionFiles() { return absl::OkStatus(); } -Status DebugEventsWriter::FlushExecutionFiles() { +absl::Status DebugEventsWriter::FlushExecutionFiles() { TF_RETURN_IF_ERROR(Init()); if (execution_writer_ != nullptr) { @@ -409,7 +410,7 @@ string DebugEventsWriter::FileName(DebugEventFileType type) { return GetFileNameInternal(type); } -Status DebugEventsWriter::Close() { +absl::Status DebugEventsWriter::Close() { { mutex_lock l(initialization_mu_); if (!is_initialized_) { @@ -495,7 +496,7 @@ DebugEventsWriter::DebugEventsWriter(const string& dump_root, device_name_to_id_(), device_mu_() {} -Status DebugEventsWriter::InitNonMetadataFile(DebugEventFileType type) { +absl::Status DebugEventsWriter::InitNonMetadataFile(DebugEventFileType type) { std::unique_ptr* writer = nullptr; SelectWriter(type, &writer); const string filename = GetFileNameInternal(type); @@ -513,8 +514,8 @@ Status DebugEventsWriter::InitNonMetadataFile(DebugEventFileType type) { return absl::OkStatus(); } -Status DebugEventsWriter::SerializeAndWriteDebugEvent(DebugEvent* debug_event, - DebugEventFileType type) { +absl::Status DebugEventsWriter::SerializeAndWriteDebugEvent( + DebugEvent* debug_event, DebugEventFileType type) { std::unique_ptr* writer = nullptr; SelectWriter(type, &writer); if (writer != nullptr) { diff --git a/tensorflow/core/util/debug_events_writer.h b/tensorflow/core/util/debug_events_writer.h index 79aad9488c9438..1fa4718d45e30e 100644 --- a/tensorflow/core/util/debug_events_writer.h +++ b/tensorflow/core/util/debug_events_writer.h @@ -51,12 +51,12 @@ class SingleDebugEventFileWriter { public: explicit SingleDebugEventFileWriter(const string& file_path); - Status Init(); + absl::Status Init(); void WriteSerializedDebugEvent(tensorflow::StringPiece debug_event_str); - Status Flush(); - Status Close(); + absl::Status Flush(); + absl::Status Close(); const string FileName(); @@ -115,7 +115,7 @@ class DebugEventsWriter { // If no DebugEventsWriter has been created at the dump_root, a non-OK // Status will be returned. Else an OK status will be returned, with // the pointer to the existing instance provided by reference. - static Status LookUpDebugEventsWriter( + static absl::Status LookUpDebugEventsWriter( const string& dump_root, DebugEventsWriter** debug_events_writer); ~DebugEventsWriter(); @@ -126,31 +126,32 @@ class DebugEventsWriter { // Idempotent: if the metadata file exists and is open, this is a no-op. // If on the other hand the file was opened, but has since disappeared (e.g. // deleted by another process), this will open a new file. - Status Init(); + absl::Status Init(); // The four DebugEvent fields below are written _without_ the circular // buffer. Source file contents are written to the *.source_files file. // Takes ownership of source_file. - Status WriteSourceFile(SourceFile* source_file); + absl::Status WriteSourceFile(SourceFile* source_file); // Stack frames are written to the *.code_locations file. // Takes ownership of stack_frame_with_id. - Status WriteStackFrameWithId(StackFrameWithId* stack_frame_with_id); + absl::Status WriteStackFrameWithId(StackFrameWithId* stack_frame_with_id); // Graph op creation events are written to the *.graphs file. // Takes ownership of graph_op_creation. - Status WriteGraphOpCreation(GraphOpCreation* graph_op_creation); + absl::Status WriteGraphOpCreation(GraphOpCreation* graph_op_creation); // Debugged graphs are written to the *.graphs file. // Takes ownership of debugged_graph. - Status WriteDebuggedGraph(DebuggedGraph* debugged_graph); + absl::Status WriteDebuggedGraph(DebuggedGraph* debugged_graph); // The two DebugEvent fields below are written to the circular buffer // and saved to disk only at the FlushExecutionFiles() call. // Execution events (eager execution of an op or a tf.function) are written // to the *.execution file. Takes ownership of execution. - Status WriteExecution(Execution* execution); + absl::Status WriteExecution(Execution* execution); // Graph execution traces (graph-internal tensor values or their summaries) // are written to the *.graph_execution_traces file. // Takes ownership of graph_execution_trace. - Status WriteGraphExecutionTrace(GraphExecutionTrace* graph_execution_trace); + absl::Status WriteGraphExecutionTrace( + GraphExecutionTrace* graph_execution_trace); // Write a graph execution trace without using a protocol buffer. // Instead, pass the raw values related to the graph execution trace. @@ -167,11 +168,12 @@ class DebugEventsWriter { // tensor(s) // that this trace is concerned with. The semantics of this tensor value // depends on the value of `tensor_debug_mode`. - Status WriteGraphExecutionTrace(const string& tfdbg_context_id, - const string& device_name, - const string& op_name, int32_t output_slot, - int32_t tensor_debug_mode, - const Tensor& tensor_value); + absl::Status WriteGraphExecutionTrace(const string& tfdbg_context_id, + const string& device_name, + const string& op_name, + int32_t output_slot, + int32_t tensor_debug_mode, + const Tensor& tensor_value); // Writes a serialized DebugEvent to one of the debug-events files // concerned with the non-execution events: the SOURCE_FILES, STACK_FRAMES @@ -200,15 +202,15 @@ class DebugEventsWriter { // and/or check for success. // FlushNonExecutionFiles() pushes outstanding DebugEvents not written // events to the circular buffer to their respective files. - Status FlushNonExecutionFiles(); + absl::Status FlushNonExecutionFiles(); // Writes current contents of the circular buffers to their respective // debug event files and clears the circular buffers. - Status FlushExecutionFiles(); + absl::Status FlushExecutionFiles(); // Close() calls FlushNonExecutionFiles() and FlushExecutionFiles() // and then closes the current debug events files. - Status Close(); + absl::Status Close(); private: static std::unordered_map>* @@ -228,10 +230,10 @@ class DebugEventsWriter { string FileName(DebugEventFileType type); // Initialize the TFRecord writer for non-metadata file type. - Status InitNonMetadataFile(DebugEventFileType type); + absl::Status InitNonMetadataFile(DebugEventFileType type); - Status SerializeAndWriteDebugEvent(DebugEvent* debug_event, - DebugEventFileType type); + absl::Status SerializeAndWriteDebugEvent(DebugEvent* debug_event, + DebugEventFileType type); void SelectWriter(DebugEventFileType type, std::unique_ptr** writer); diff --git a/tensorflow/core/util/debug_events_writer_test.cc b/tensorflow/core/util/debug_events_writer_test.cc index 5ac30597608900..2021010735aebb 100644 --- a/tensorflow/core/util/debug_events_writer_test.cc +++ b/tensorflow/core/util/debug_events_writer_test.cc @@ -63,7 +63,7 @@ class DebugEventsWriterTest : public ::testing::Test { static bool ReadDebugEventProto(io::RecordReader* reader, uint64* offset, DebugEvent* proto) { tstring record; - Status s = reader->ReadRecord(offset, &record); + absl::Status s = reader->ReadRecord(offset, &record); if (!s.ok()) { return false; } diff --git a/tensorflow/core/util/dump_graph.cc b/tensorflow/core/util/dump_graph.cc index 2bec2edfe5bd5a..c8eb3d48060d71 100644 --- a/tensorflow/core/util/dump_graph.cc +++ b/tensorflow/core/util/dump_graph.cc @@ -74,9 +74,9 @@ struct GraphDumperConfig { // The dumper and suffix configured. struct Config { bool IsSet() const { return dumper != nullptr; } - std::function + std::function dumper = nullptr; string suffix = ".pbtxt"; } config TF_GUARDED_BY(mu); @@ -95,7 +95,8 @@ GraphDumperConfig& GetGraphDumperConfig() { string GetDumpGraphFormatLowerCase() { string fmt; - Status status = tsl::ReadStringFromEnvVar("TF_DUMP_GRAPH_FMT", "TXT", &fmt); + absl::Status status = + tsl::ReadStringFromEnvVar("TF_DUMP_GRAPH_FMT", "TXT", &fmt); if (!status.ok()) { LOG(WARNING) << "Failed to read TF_DUMP_GRAPH_FMT: " << status; return "txt"; @@ -120,33 +121,34 @@ class StderrWritableFile : public WritableFile { public: StderrWritableFile() = default; - Status Append(StringPiece data) override { + absl::Status Append(StringPiece data) override { fprintf(stderr, "%.*s", static_cast(data.size()), data.data()); return absl::OkStatus(); } - Status Close() override { return absl::OkStatus(); } + absl::Status Close() override { return absl::OkStatus(); } - Status Flush() override { + absl::Status Flush() override { fflush(stderr); return absl::OkStatus(); } - Status Name(StringPiece* result) const override { + absl::Status Name(StringPiece* result) const override { *result = "stderr"; return absl::OkStatus(); } - Status Sync() override { return absl::OkStatus(); } + absl::Status Sync() override { return absl::OkStatus(); } - Status Tell(int64_t* position) override { + absl::Status Tell(int64_t* position) override { return errors::Unimplemented("Stream not seekable"); } }; -Status CreateWritableFile(Env* env, const string& dirname, const string& name, - const string& suffix, string* filepath, - std::unique_ptr* file) { +absl::Status CreateWritableFile(Env* env, const string& dirname, + const string& name, const string& suffix, + string* filepath, + std::unique_ptr* file) { string dir; if (!dirname.empty()) { dir = dirname; @@ -183,8 +185,8 @@ Status CreateWritableFile(Env* env, const string& dirname, const string& name, return env->NewWritableFile(*filepath, file); } -Status WriteProtoToUniqueFile(const tensorflow::protobuf::Message& proto, - WritableFile* file) { +absl::Status WriteProtoToUniqueFile(const tensorflow::protobuf::Message& proto, + WritableFile* file) { string s; string format = GetDumpGraphFormatLowerCase(); if (format == "txt" && @@ -205,8 +207,8 @@ Status WriteProtoToUniqueFile(const tensorflow::protobuf::Message& proto, return file->Close(); } -Status WriteProtoToUniqueFile(const tensorflow::protobuf::MessageLite& proto, - WritableFile* file) { +absl::Status WriteProtoToUniqueFile( + const tensorflow::protobuf::MessageLite& proto, WritableFile* file) { string s; if (!SerializeToStringDeterministic(proto, &s)) { return errors::Internal("Failed to serialize proto to string."); @@ -223,11 +225,11 @@ Status WriteProtoToUniqueFile(const tensorflow::protobuf::MessageLite& proto, string DumpToFile(const string& name, const string& dirname, const string& suffix, absl::string_view type_name, - std::function dumper) { + std::function dumper) { string filepath; std::unique_ptr file; - Status status = CreateWritableFile(Env::Default(), dirname, name, suffix, - &filepath, &file); + absl::Status status = CreateWritableFile(Env::Default(), dirname, name, + suffix, &filepath, &file); if (!status.ok()) { return StrCat("(failed to create writable file: ", status.ToString(), ")"); } @@ -242,9 +244,9 @@ string DumpToFile(const string& name, const string& dirname, } void SetGraphDumper( - std::function + std::function dumper, string suffix) { GraphDumperConfig& dumper_config = GetGraphDumperConfig(); diff --git a/tensorflow/core/util/dump_graph.h b/tensorflow/core/util/dump_graph.h index b70aec362ca640..0d0c55754dd8ff 100644 --- a/tensorflow/core/util/dump_graph.h +++ b/tensorflow/core/util/dump_graph.h @@ -70,9 +70,9 @@ string DumpProtoToFile(const string& name, // instead via DumpGraphToFile. As the custom dumper may not produce protobufs, // allow specifying a file suffix/extension too. void SetGraphDumper( - std::function + std::function dumper, string suffix = ".pbtxt"); @@ -81,7 +81,7 @@ void SetGraphDumper( // The dumper callback will be responsible for writing data to the file. string DumpToFile(const string& name, const string& dirname, const string& suffix, absl::string_view type_name, - std::function dumper); + std::function dumper); } // namespace tensorflow diff --git a/tensorflow/core/util/einsum_op_util.cc b/tensorflow/core/util/einsum_op_util.cc index ff133d9df18530..55151c724af993 100644 --- a/tensorflow/core/util/einsum_op_util.cc +++ b/tensorflow/core/util/einsum_op_util.cc @@ -27,7 +27,7 @@ limitations under the License. namespace tensorflow { -Status ValidateEinsumEquation( +absl::Status ValidateEinsumEquation( const string& equation, absl::InlinedVector* input_subscripts, string* output_subscript) { absl::InlinedVector inputs_and_output_subscripts = @@ -81,13 +81,12 @@ void MapToLabels(const string& subscript, Labels* labels, } } -Status ParseEinsumEquation(const string& equation, OperandLabels* input_labels, - Labels* output_labels, - std::vector* label_types, - OperandLabelCounts* input_label_counts, - LabelCounts* output_label_counts, - absl::InlinedVector* input_has_ellipsis, - bool* output_has_ellipsis) { +absl::Status ParseEinsumEquation( + const string& equation, OperandLabels* input_labels, Labels* output_labels, + std::vector* label_types, + OperandLabelCounts* input_label_counts, LabelCounts* output_label_counts, + absl::InlinedVector* input_has_ellipsis, + bool* output_has_ellipsis) { absl::InlinedVector input_str; string output_str; TF_RETURN_IF_ERROR(ValidateEinsumEquation(equation, &input_str, &output_str)); diff --git a/tensorflow/core/util/einsum_op_util.h b/tensorflow/core/util/einsum_op_util.h index 664cb22b9fac08..6155b8a08d663b 100644 --- a/tensorflow/core/util/einsum_op_util.h +++ b/tensorflow/core/util/einsum_op_util.h @@ -52,7 +52,7 @@ enum EinsumDimensionType { }; // Parses and validates an einsum equation in explicit form. -Status ValidateEinsumEquation( +absl::Status ValidateEinsumEquation( const string& equation, absl::InlinedVector* input_subscripts, string* output_subscript); @@ -60,13 +60,12 @@ Status ValidateEinsumEquation( // labels are integerized and we populate input and output label subscripts // and corresponding counts. Also create the mapping from (named) labels to // their EinsumDimensionType. -Status ParseEinsumEquation(const string& equation, OperandLabels* input_labels, - Labels* output_labels, - std::vector* label_types, - OperandLabelCounts* input_label_counts, - LabelCounts* output_label_counts, - absl::InlinedVector* input_has_ellipsis, - bool* output_has_ellipsis); +absl::Status ParseEinsumEquation( + const string& equation, OperandLabels* input_labels, Labels* output_labels, + std::vector* label_types, + OperandLabelCounts* input_label_counts, LabelCounts* output_label_counts, + absl::InlinedVector* input_has_ellipsis, + bool* output_has_ellipsis); } // namespace tensorflow diff --git a/tensorflow/core/util/events_writer.cc b/tensorflow/core/util/events_writer.cc index 649ed3d28dc819..6be31c499d33ae 100644 --- a/tensorflow/core/util/events_writer.cc +++ b/tensorflow/core/util/events_writer.cc @@ -42,14 +42,14 @@ EventsWriter::~EventsWriter() { Close().IgnoreError(); // Autoclose in destructor. } -Status EventsWriter::Init() { return InitWithSuffix(""); } +absl::Status EventsWriter::Init() { return InitWithSuffix(""); } -Status EventsWriter::InitWithSuffix(const string& suffix) { +absl::Status EventsWriter::InitWithSuffix(const string& suffix) { file_suffix_ = suffix; return InitIfNeeded(); } -Status EventsWriter::InitIfNeeded() { +absl::Status EventsWriter::InitIfNeeded() { if (recordio_writer_ != nullptr) { CHECK(!filename_.empty()); if (!FileStillExists().ok()) { @@ -125,7 +125,7 @@ void EventsWriter::WriteEvent(const Event& event) { WriteSerializedEvent(record); } -Status EventsWriter::Flush() { +absl::Status EventsWriter::Flush() { if (num_outstanding_events_ == 0) return absl::OkStatus(); CHECK(recordio_file_ != nullptr) << "Unexpected NULL file"; @@ -140,10 +140,10 @@ Status EventsWriter::Flush() { return absl::OkStatus(); } -Status EventsWriter::Close() { - Status status = Flush(); +absl::Status EventsWriter::Close() { + absl::Status status = Flush(); if (recordio_file_ != nullptr) { - Status close_status = recordio_file_->Close(); + absl::Status close_status = recordio_file_->Close(); if (!close_status.ok()) { status = close_status; } @@ -154,7 +154,7 @@ Status EventsWriter::Close() { return status; } -Status EventsWriter::FileStillExists() { +absl::Status EventsWriter::FileStillExists() { if (env_->FileExists(filename_).ok()) { return absl::OkStatus(); } diff --git a/tensorflow/core/util/events_writer.h b/tensorflow/core/util/events_writer.h index 8a934e8cc39e29..06eaee845eb6a6 100644 --- a/tensorflow/core/util/events_writer.h +++ b/tensorflow/core/util/events_writer.h @@ -55,8 +55,8 @@ class EventsWriter { // and is open this is a no-op. If on the other hand the file was opened, // but has since disappeared (e.g. deleted by another process), this will open // a new file with a new timestamp in its filename. - Status Init(); - Status InitWithSuffix(const std::string& suffix); + absl::Status Init(); + absl::Status InitWithSuffix(const std::string& suffix); // Returns the filename for the current events file: // filename_ = [file_prefix_].out.events.[timestamp].[hostname][suffix] @@ -78,12 +78,12 @@ class EventsWriter { // be written too. // Close() calls Flush() and then closes the current events file. // Returns true only if both the flush and the closure were successful. - Status Flush(); - Status Close(); + absl::Status Flush(); + absl::Status Close(); private: - Status FileStillExists(); // OK if event_file_path_ exists. - Status InitIfNeeded(); + absl::Status FileStillExists(); // OK if event_file_path_ exists. + absl::Status InitIfNeeded(); Env* env_; const std::string file_prefix_; diff --git a/tensorflow/core/util/events_writer_test.cc b/tensorflow/core/util/events_writer_test.cc index 1c3185b1924c88..81b2aa400b027c 100644 --- a/tensorflow/core/util/events_writer_test.cc +++ b/tensorflow/core/util/events_writer_test.cc @@ -57,7 +57,7 @@ void WriteFile(EventsWriter* writer) { static bool ReadEventProto(io::RecordReader* reader, uint64* offset, Event* proto) { tstring record; - Status s = reader->ReadRecord(offset, &record); + absl::Status s = reader->ReadRecord(offset, &record); if (!s.ok()) { return false; } diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc index a3c336fc6fa2e0..fafafa94ef0bda 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.cc +++ b/tensorflow/core/util/example_proto_fast_parsing.cc @@ -127,7 +127,7 @@ class Feature { Feature() = default; explicit Feature(StringPiece serialized) : serialized_(serialized) {} - Status ParseDataType(DataType* dtype) { + absl::Status ParseDataType(DataType* dtype) { DCHECK(dtype != nullptr); if (serialized_.empty()) { *dtype = DT_INVALID; @@ -589,7 +589,7 @@ void LogSparseFeatureDataLoss(StringPiece feature_name) { duplicated_sparse_feature->GetCell()->IncrementBy(1); } -Status FastParseSerializedExample( +absl::Status FastParseSerializedExample( const tstring& serialized_example, const tstring& example_name, const size_t example_index, const Config& config, const PresizedCuckooMap>& config_index, @@ -950,7 +950,7 @@ Status FastParseSerializedExample( return absl::OkStatus(); } -Status CheckConfigDataType(DataType dtype) { +absl::Status CheckConfigDataType(DataType dtype) { switch (dtype) { case DT_INT64: case DT_FLOAT: @@ -970,7 +970,7 @@ inline void ReportUnexpectedDataType(DataType dtype) { << "in variable that should have been checked by CheckConfigDataType()."; } -Status CheckConfigDataTypes(const Config& config) { +absl::Status CheckConfigDataTypes(const Config& config) { // Check config so we can safely CHECK(false) in switches on config.*.dtype for (auto& c : config.sparse) { TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype)); @@ -1133,10 +1133,10 @@ void CopySparseBufferToTensor(DataType dtype, size_t offset, SparseBuffer* src, } // namespace -Status FastParseExample(const Config& config, - absl::Span serialized, - absl::Span example_names, - thread::ThreadPool* thread_pool, Result* result) { +absl::Status FastParseExample(const Config& config, + absl::Span serialized, + absl::Span example_names, + thread::ThreadPool* thread_pool, Result* result) { DCHECK(result != nullptr); // Check config so we can safely CHECK(false) in switches on config.*.dtype TF_RETURN_IF_ERROR(CheckConfigDataTypes(config)); @@ -1230,7 +1230,7 @@ Status FastParseExample(const Config& config, std::vector> sparse_buffers(num_minibatches); std::vector> varlen_dense_buffers(num_minibatches); std::vector> ragged_buffers(num_minibatches); - std::vector status_of_minibatch(num_minibatches); + std::vector status_of_minibatch(num_minibatches); auto ProcessMiniBatch = [&](size_t minibatch) { sparse_buffers[minibatch].resize(config.sparse.size()); varlen_dense_buffers[minibatch].resize(config.dense.size()); @@ -1254,7 +1254,7 @@ Status FastParseExample(const Config& config, ParallelFor(ProcessMiniBatch, num_minibatches, thread_pool); - for (Status& status : status_of_minibatch) { + for (absl::Status& status : status_of_minibatch) { TF_RETURN_IF_ERROR(status); } @@ -1447,8 +1447,8 @@ Status FastParseExample(const Config& config, return absl::OkStatus(); } -Status FastParseSingleExample(const Config& config, StringPiece serialized, - Result* result) { +absl::Status FastParseSingleExample(const Config& config, + StringPiece serialized, Result* result) { DCHECK(result != nullptr); // Check config so we can safely CHECK(false) in switches on config.*.dtype TF_RETURN_IF_ERROR(CheckConfigDataTypes(config)); @@ -2098,7 +2098,7 @@ inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream, } // Reads an example proto, and extracts a StringPiece pointer to each feature. -Status ExtractFeaturesFromSequenceExamples( +absl::Status ExtractFeaturesFromSequenceExamples( const absl::Span examples, const absl::Span example_names, FeatureProtosMap* context_features, FeatureProtosMap* sequence_features) { @@ -2167,8 +2167,9 @@ Status ExtractFeaturesFromSequenceExamples( // Populates context_features[k].length based on context_features[k].protos // (for all k). -Status GetContextFeatureLengths(const absl::Span example_names, - FeatureProtosMap* context_features) { +absl::Status GetContextFeatureLengths( + const absl::Span example_names, + FeatureProtosMap* context_features) { for (auto& c : *context_features) { FeatureProtos& feature = c.second; for (int d = 0; d < feature.protos.size(); ++d) { @@ -2202,8 +2203,9 @@ Status GetContextFeatureLengths(const absl::Span example_names, // Populates sequence_features[k].length and sequence_features[k].num_rows based // on sequence_features[k].protos (for all k). -Status GetSequenceFeatureLengths(const absl::Span example_names, - FeatureProtosMap* sequence_features) { +absl::Status GetSequenceFeatureLengths( + const absl::Span example_names, + FeatureProtosMap* sequence_features) { for (auto& c : *sequence_features) { FeatureProtos& feature = c.second; for (int d = 0; d < feature.protos.size(); ++d) { @@ -2299,11 +2301,11 @@ void CopyTensorIntoTensor(DataType dtype, const Tensor& src, Tensor* dst, // Parses dense features in `context_features`, and writes their parsed // values to `context_results`. -Status ParseContextDenseFeatures(const FeatureProtosMap& context_features, - const FastParseExampleConfig& context_config, - absl::Span example_names, - bool is_batch, int num_examples, - Allocator* allocator, Result* context_result) { +absl::Status ParseContextDenseFeatures( + const FeatureProtosMap& context_features, + const FastParseExampleConfig& context_config, + absl::Span example_names, bool is_batch, int num_examples, + Allocator* allocator, Result* context_result) { for (int t = 0; t < context_config.dense.size(); ++t) { const auto& c = context_config.dense[t]; const FeatureProtos& feature = @@ -2362,12 +2364,11 @@ Status ParseContextDenseFeatures(const FeatureProtosMap& context_features, // Parses sparse features in `context_features`, and writes their parsed // values to `context_results`. -Status ParseContextSparseFeatures(const FeatureProtosMap& context_features, - const FastParseExampleConfig& context_config, - absl::Span example_names, - bool is_batch, int num_examples, - Allocator* allocator, - Result* context_result) { +absl::Status ParseContextSparseFeatures( + const FeatureProtosMap& context_features, + const FastParseExampleConfig& context_config, + absl::Span example_names, bool is_batch, int num_examples, + Allocator* allocator, Result* context_result) { for (int t = 0; t < context_config.sparse.size(); ++t) { const auto& c = context_config.sparse[t]; const FeatureProtos& feature = @@ -2424,12 +2425,11 @@ Status ParseContextSparseFeatures(const FeatureProtosMap& context_features, // Parses ragged features in `context_features`, and writes their parsed // values to `context_results`. -Status ParseContextRaggedFeatures(const FeatureProtosMap& context_features, - const FastParseExampleConfig& context_config, - absl::Span example_names, - bool is_batch, int num_examples, - Allocator* allocator, - Result* context_result) { +absl::Status ParseContextRaggedFeatures( + const FeatureProtosMap& context_features, + const FastParseExampleConfig& context_config, + absl::Span example_names, bool is_batch, int num_examples, + Allocator* allocator, Result* context_result) { for (int t = 0; t < context_config.ragged.size(); ++t) { const auto& c = context_config.ragged[t]; const FeatureProtos& feature = @@ -2502,12 +2502,12 @@ Status ParseContextRaggedFeatures(const FeatureProtosMap& context_features, // Parses dense features in `sequence_features`, and writes their parsed // values to `sequence_result`. -Status ParseSequenceDenseFeatures(const FeatureProtosMap& sequence_features, - const FastParseExampleConfig& sequence_config, - absl::Span example_names, - bool is_batch, int num_examples, - Allocator* allocator, Result* sequence_result, - std::vector* dense_feature_lengths) { +absl::Status ParseSequenceDenseFeatures( + const FeatureProtosMap& sequence_features, + const FastParseExampleConfig& sequence_config, + absl::Span example_names, bool is_batch, int num_examples, + Allocator* allocator, Result* sequence_result, + std::vector* dense_feature_lengths) { TensorShape dense_length_shape; if (is_batch) { dense_length_shape.AddDim(num_examples); @@ -2656,7 +2656,7 @@ Status ParseSequenceDenseFeatures(const FeatureProtosMap& sequence_features, // Parses sparse features in `sequence_features`, and writes their parsed // values to `sequence_result`. -Status ParseSequenceSparseFeatures( +absl::Status ParseSequenceSparseFeatures( const FeatureProtosMap& sequence_features, const FastParseExampleConfig& sequence_config, absl::Span example_names, bool is_batch, int num_examples, @@ -2784,7 +2784,7 @@ Status ParseSequenceSparseFeatures( // Parses ragged features in `sequence_features`, and writes their parsed // values to `sequence_result`. -Status ParseSequenceRaggedFeatures( +absl::Status ParseSequenceRaggedFeatures( const FeatureProtosMap& sequence_features, const FastParseExampleConfig& sequence_config, absl::Span example_names, bool is_batch, int num_examples, @@ -2931,14 +2931,13 @@ Status ParseSequenceRaggedFeatures( // TODO(sundberg): Use the threadpool to parallelize example parsing. // TODO(b/111553342): Support extracting feature statistics from the examples. -Status FastParseSequenceExample(const FastParseExampleConfig& context_config, - const FastParseExampleConfig& sequence_config, - absl::Span serialized, - absl::Span example_names, - thread::ThreadPool* thread_pool, - Result* context_result, Result* sequence_result, - std::vector* dense_feature_lengths, - bool is_batch) { +absl::Status FastParseSequenceExample( + const FastParseExampleConfig& context_config, + const FastParseExampleConfig& sequence_config, + absl::Span serialized, + absl::Span example_names, thread::ThreadPool* thread_pool, + Result* context_result, Result* sequence_result, + std::vector* dense_feature_lengths, bool is_batch) { int num_examples = serialized.size(); DCHECK(context_result != nullptr); DCHECK(sequence_result != nullptr); diff --git a/tensorflow/core/util/example_proto_fast_parsing.h b/tensorflow/core/util/example_proto_fast_parsing.h index 9c14aa9b6b3844..edc72f47e773ca 100644 --- a/tensorflow/core/util/example_proto_fast_parsing.h +++ b/tensorflow/core/util/example_proto_fast_parsing.h @@ -134,23 +134,23 @@ struct Result { // according to given config. // Given example names have to either be empty or the same size as serialized. // example_names are used only for error messages. -Status FastParseExample(const FastParseExampleConfig& config, - absl::Span serialized, - absl::Span example_names, - thread::ThreadPool* thread_pool, Result* result); +absl::Status FastParseExample(const FastParseExampleConfig& config, + absl::Span serialized, + absl::Span example_names, + thread::ThreadPool* thread_pool, Result* result); // TODO(mrry): Move the hash table construction into the config object. typedef FastParseExampleConfig FastParseSingleExampleConfig; -Status FastParseSingleExample(const FastParseSingleExampleConfig& config, - StringPiece serialized, Result* result); +absl::Status FastParseSingleExample(const FastParseSingleExampleConfig& config, + StringPiece serialized, Result* result); // Parses a batch of serialized SequenceExample protos and converts them into // result according to given config. // Given example names have to either be empty or the same size as serialized. // example_names are used only for error messages. // (If batch=true, then this parses a single SequenceExample.) -Status FastParseSequenceExample( +absl::Status FastParseSequenceExample( const example::FastParseExampleConfig& context_config, const example::FastParseExampleConfig& sequence_config, absl::Span serialized, diff --git a/tensorflow/core/util/example_proto_fast_parsing_test.cc b/tensorflow/core/util/example_proto_fast_parsing_test.cc index 4d19be4a732932..1eab3903134ed7 100644 --- a/tensorflow/core/util/example_proto_fast_parsing_test.cc +++ b/tensorflow/core/util/example_proto_fast_parsing_test.cc @@ -423,7 +423,7 @@ TEST(TestFastParseExample, Empty) { Result result; FastParseExampleConfig config; config.sparse.push_back({"test", DT_STRING}); - Status status = + absl::Status status = FastParseExample(config, absl::Span(), absl::Span(), nullptr, &result); EXPECT_TRUE(status.ok()) << status; diff --git a/tensorflow/core/util/example_proto_helper.cc b/tensorflow/core/util/example_proto_helper.cc index a9f0ec1712afba..be7dc745cfdb3e 100644 --- a/tensorflow/core/util/example_proto_helper.cc +++ b/tensorflow/core/util/example_proto_helper.cc @@ -29,7 +29,7 @@ limitations under the License. namespace tensorflow { -Status CheckValidType(const DataType& dtype) { +absl::Status CheckValidType(const DataType& dtype) { switch (dtype) { case DT_INT64: case DT_FLOAT: @@ -41,8 +41,8 @@ Status CheckValidType(const DataType& dtype) { } } -Status CheckTypesMatch(const Feature& feature, const DataType& dtype, - bool* match) { +absl::Status CheckTypesMatch(const Feature& feature, const DataType& dtype, + bool* match) { switch (dtype) { case DT_INT64: *match = (feature.kind_case() == Feature::kInt64List); @@ -60,10 +60,10 @@ Status CheckTypesMatch(const Feature& feature, const DataType& dtype, return absl::OkStatus(); } -Status FeatureDenseCopy(const std::size_t out_index, const string& name, - const string& key, const DataType& dtype, - const TensorShape& shape, const Feature& feature, - Tensor* out) { +absl::Status FeatureDenseCopy(const std::size_t out_index, const string& name, + const string& key, const DataType& dtype, + const TensorShape& shape, const Feature& feature, + Tensor* out) { const std::size_t num_elements = shape.num_elements(); const std::size_t offset = out_index * num_elements; @@ -217,7 +217,7 @@ void RowDenseCopy(const std::size_t& out_index, const DataType& dtype, } } -Status SingleExampleProtoToTensors( +absl::Status SingleExampleProtoToTensors( const Example& example, const string& example_name, const int batch_index, const std::vector& fixed_len_features, const std::vector& var_len_features, @@ -300,10 +300,10 @@ Status SingleExampleProtoToTensors( return absl::OkStatus(); } -Status GetSparseTensorShapes(const VarLenFeature& var_len_feature, - const std::vector& sparse_values_tmp, - const int batch_size, - VarLenFeatureBatchShapes* output_shapes) { +absl::Status GetSparseTensorShapes(const VarLenFeature& var_len_feature, + const std::vector& sparse_values_tmp, + const int batch_size, + VarLenFeatureBatchShapes* output_shapes) { int64_t total_num_features = 0; int64_t max_num_features = 0; for (int b = 0; b < batch_size; ++b) { @@ -319,7 +319,7 @@ Status GetSparseTensorShapes(const VarLenFeature& var_len_feature, return absl::OkStatus(); } -Status BatchExampleProtoToTensors( +absl::Status BatchExampleProtoToTensors( const std::vector& examples, const std::vector& names, const std::vector& fixed_len_features, @@ -407,7 +407,7 @@ Status BatchExampleProtoToTensors( return absl::OkStatus(); } -Status ParseExampleAttrs::FinishInit(int op_version) { +absl::Status ParseExampleAttrs::FinishInit(int op_version) { switch (op_version) { case 1: num_ragged = 0; @@ -457,7 +457,7 @@ Status ParseExampleAttrs::FinishInit(int op_version) { return absl::OkStatus(); } -Status ParseSingleExampleAttrs::FinishInit() { +absl::Status ParseSingleExampleAttrs::FinishInit() { if (sparse_keys.size() != sparse_types.size()) { return errors::InvalidArgument("len(sparse_keys) != len(sparse_types)"); } @@ -476,7 +476,7 @@ Status ParseSingleExampleAttrs::FinishInit() { return absl::OkStatus(); } -Status ParseSequenceExampleAttrs::FinishInit(int op_version) { +absl::Status ParseSequenceExampleAttrs::FinishInit(int op_version) { switch (op_version) { case 1: num_context_ragged = 0; @@ -593,7 +593,7 @@ Status ParseSequenceExampleAttrs::FinishInit(int op_version) { return absl::OkStatus(); } -Status ParseSingleSequenceExampleAttrs::FinishInit() { +absl::Status ParseSingleSequenceExampleAttrs::FinishInit() { if (static_cast(num_context_sparse) != context_sparse_types.size()) { return errors::InvalidArgument( "len(context_sparse_keys) != len(context_sparse_types)"); @@ -632,9 +632,9 @@ Status ParseSingleSequenceExampleAttrs::FinishInit() { return absl::OkStatus(); } -Status GetDenseShapes(const std::vector& dense_shapes, - std::vector* variable_length, - std::vector* elements_per_stride) { +absl::Status GetDenseShapes(const std::vector& dense_shapes, + std::vector* variable_length, + std::vector* elements_per_stride) { // Temporary check until we start allowing a variable length outer // dimension. for (int i = 0; i < dense_shapes.size(); ++i) { diff --git a/tensorflow/core/util/example_proto_helper.h b/tensorflow/core/util/example_proto_helper.h index 030dfc396ff5b8..801aae375e0d5f 100644 --- a/tensorflow/core/util/example_proto_helper.h +++ b/tensorflow/core/util/example_proto_helper.h @@ -74,7 +74,7 @@ struct VarLenFeature { // GetSparseTensorShape can be used to calculate the final shapes and // CopyIntoSparseTensor can be used to copy from the temporary vector // into the final allocated tensors. -Status SingleExampleProtoToTensors( +absl::Status SingleExampleProtoToTensors( const Example& example, const string& name, int batch_index, const std::vector& fixed_len_features, const std::vector& var_len_features, @@ -92,10 +92,10 @@ struct VarLenFeatureBatchShapes { // Get the shape of the sparse values and indices tensors for the batch, // given how many of the tensors in the temporary sparse values vector // are actually filled. -Status GetSparseTensorShapes(const VarLenFeature& var_len_feature, - const std::vector& sparse_values_tmp, - int batch_size, - VarLenFeatureBatchShapes* output_shapes); +absl::Status GetSparseTensorShapes(const VarLenFeature& var_len_feature, + const std::vector& sparse_values_tmp, + int batch_size, + VarLenFeatureBatchShapes* output_shapes); // A method to convert a batch of tensorflow::Example protos into output // tensors. This method is useful if there already is a batch of deserialized @@ -107,7 +107,7 @@ Status GetSparseTensorShapes(const VarLenFeature& var_len_feature, // // Note that unlike SingleExampleProtoToTensors, output tensors are // allocated using a provided Allocator within this method. -Status BatchExampleProtoToTensors( +absl::Status BatchExampleProtoToTensors( const std::vector& examples, const std::vector& names, const std::vector& fixed_len_features, @@ -119,19 +119,19 @@ Status BatchExampleProtoToTensors( // Check that the given dtype is one that is compatible with // tensorflow::Example protocol buffer feature values. -Status CheckValidType(const DataType& dtype); +absl::Status CheckValidType(const DataType& dtype); // Check that the provided Feature proto message's oneof value // matches that of the provided dtype. -Status CheckTypesMatch(const Feature& feature, const DataType& dtype, - bool* match); +absl::Status CheckTypesMatch(const Feature& feature, const DataType& dtype, + bool* match); // For a single Example, copy a dense feature value into an output // dense value tensor Out at the provided out_index offset. -Status FeatureDenseCopy(std::size_t out_index, const string& name, - const string& key, const DataType& dtype, - const TensorShape& shape, const Feature& feature, - Tensor* out); +absl::Status FeatureDenseCopy(std::size_t out_index, const string& name, + const string& key, const DataType& dtype, + const TensorShape& shape, const Feature& feature, + Tensor* out); // Copy the value a provided Tensor into an output dense_value tensor Out // at the provided out_index offset. @@ -153,16 +153,16 @@ int64_t CopyIntoSparseTensor(const Tensor& in, int batch, int64_t offset, // Check that each dense_shape has known rank and inner dimensions; and // update variable_length (whether the outer dimension is None) and // elements_per_stride for each denes_shape. -Status GetDenseShapes(const std::vector& dense_shapes, - std::vector* variable_length, - std::vector* elements_per_stride); +absl::Status GetDenseShapes(const std::vector& dense_shapes, + std::vector* variable_length, + std::vector* elements_per_stride); // Parses the attributes passed to ParseExample. // REQUIRES: Init must be called after construction. struct ParseExampleAttrs { public: template - Status Init(ContextType* ctx, int op_version = 1) { + absl::Status Init(ContextType* ctx, int op_version = 1) { TF_RETURN_IF_ERROR(ctx->GetAttr("sparse_types", &sparse_types)); TF_RETURN_IF_ERROR(ctx->GetAttr("Tdense", &dense_types)); TF_RETURN_IF_ERROR(ctx->GetAttr("dense_shapes", &dense_shapes)); @@ -198,7 +198,8 @@ struct ParseExampleAttrs { std::vector elements_per_stride; private: - Status FinishInit(int op_version); // for context-independent parts of Init. + absl::Status FinishInit( + int op_version); // for context-independent parts of Init. }; // Parses the attributes passed to ParseSingleExample. @@ -206,7 +207,7 @@ struct ParseExampleAttrs { struct ParseSingleExampleAttrs { public: template - Status Init(ContextType* ctx) { + absl::Status Init(ContextType* ctx) { TF_RETURN_IF_ERROR(ctx->GetAttr("sparse_keys", &sparse_keys)); TF_RETURN_IF_ERROR(ctx->GetAttr("sparse_types", &sparse_types)); TF_RETURN_IF_ERROR(ctx->GetAttr("dense_keys", &dense_keys)); @@ -235,7 +236,7 @@ struct ParseSingleExampleAttrs { std::vector elements_per_stride; private: - Status FinishInit(); // for context-independent parts of Init. + absl::Status FinishInit(); // for context-independent parts of Init. }; // Parses the attributes passed to ParseSequenceExample. @@ -243,7 +244,7 @@ struct ParseSingleExampleAttrs { struct ParseSequenceExampleAttrs { public: template - Status Init(ContextType* ctx, int op_version = 1) { + absl::Status Init(ContextType* ctx, int op_version = 1) { switch (op_version) { case 1: { std::vector missing_empty_vector; @@ -318,7 +319,8 @@ struct ParseSequenceExampleAttrs { std::vector feature_list_ragged_split_types; private: - Status FinishInit(int op_version); // for context-independent parts of Init. + absl::Status FinishInit( + int op_version); // for context-independent parts of Init. }; // Parses the attributes passed to ParseSingleSequenceExample. @@ -326,7 +328,7 @@ struct ParseSequenceExampleAttrs { struct ParseSingleSequenceExampleAttrs { public: template - Status Init(ContextType* ctx) { + absl::Status Init(ContextType* ctx) { TF_RETURN_IF_ERROR( ctx->GetAttr("context_sparse_types", &context_sparse_types)); TF_RETURN_IF_ERROR(ctx->GetAttr("Ncontext_dense", &num_context_dense)); @@ -359,7 +361,7 @@ struct ParseSingleSequenceExampleAttrs { std::vector feature_list_dense_shapes; private: - Status FinishInit(); // for context-independent parts of Init. + absl::Status FinishInit(); // for context-independent parts of Init. }; } // namespace tensorflow diff --git a/tensorflow/core/util/gpu_device_functions.h b/tensorflow/core/util/gpu_device_functions.h index 8f945874f4f432..1a3909afe32ee7 100644 --- a/tensorflow/core/util/gpu_device_functions.h +++ b/tensorflow/core/util/gpu_device_functions.h @@ -73,7 +73,7 @@ static std::string cudaGetErrorString(int err) { return std::to_string(err); } #define TF_RETURN_IF_CUDA_ERROR(result) \ do { \ cudaError_t error(result); \ - if (!SE_PREDICT_TRUE(error == cudaSuccess)) { \ + if (!TF_PREDICT_TRUE(error == cudaSuccess)) { \ return absl::InternalError( \ absl::StrCat("Cuda call failed with ", cudaGetErrorString(error))); \ } \ @@ -82,7 +82,7 @@ static std::string cudaGetErrorString(int err) { return std::to_string(err); } #define TF_OP_REQUIRES_CUDA_SUCCESS(context, result) \ do { \ cudaError_t error(result); \ - if (!SE_PREDICT_TRUE(error == cudaSuccess)) { \ + if (!TF_PREDICT_TRUE(error == cudaSuccess)) { \ context->SetStatus(absl::InternalError( \ absl::StrCat("Cuda call failed with", cudaGetErrorString(error)))); \ return; \ diff --git a/tensorflow/core/util/guarded_philox_random.cc b/tensorflow/core/util/guarded_philox_random.cc index ede9d381cec972..b0b26d6c9b8563 100644 --- a/tensorflow/core/util/guarded_philox_random.cc +++ b/tensorflow/core/util/guarded_philox_random.cc @@ -20,7 +20,7 @@ limitations under the License. namespace tensorflow { -Status GuardedPhiloxRandom::Init(OpKernelConstruction* context) { +absl::Status GuardedPhiloxRandom::Init(OpKernelConstruction* context) { // Grab seed Attrs. int64_t seed, seed2; auto status = context->GetAttr("seed", &seed); diff --git a/tensorflow/core/util/guarded_philox_random.h b/tensorflow/core/util/guarded_philox_random.h index db28b26f01a692..225d68fda71e87 100644 --- a/tensorflow/core/util/guarded_philox_random.h +++ b/tensorflow/core/util/guarded_philox_random.h @@ -45,7 +45,7 @@ class GuardedPhiloxRandom { // Initialize the generator from attributes "seed" and "seed2". // If both seeds are unspecified, use random seeds. // Must be called exactly once. - Status Init(OpKernelConstruction* context); + absl::Status Init(OpKernelConstruction* context); // Initialize with given seeds. void Init(int64_t seed, int64_t seed2); diff --git a/tensorflow/core/util/matmul_autotune.cc b/tensorflow/core/util/matmul_autotune.cc index e8704bb4af7c1f..f040b398d57f5a 100644 --- a/tensorflow/core/util/matmul_autotune.cc +++ b/tensorflow/core/util/matmul_autotune.cc @@ -22,7 +22,7 @@ limitations under the License. namespace tensorflow { bool MatmulAutotuneEnable() { bool value; - Status status = + absl::Status status = ReadBoolFromEnvVar("TF_MATMUL_AUTOTUNE_ENABLE", false, &value); if (!status.ok()) { LOG(ERROR) << status.message(); @@ -40,7 +40,7 @@ bool MatmulDoFP32ComputationFP16Input() { // user-set-true, user-set-false, user-no-setting. In the calling sites, // check the compatibilities. Note that user-set-false with compute // capability <= 5.2 will cause an error in the later cublasGemmEx() call. - Status status = + absl::Status status = ReadBoolFromEnvVar("TF_FP16_MATMUL_USE_FP32_COMPUTE", true, &value); if (!status.ok()) { LOG(ERROR) << status.message(); diff --git a/tensorflow/core/util/memmapped_file_system.cc b/tensorflow/core/util/memmapped_file_system.cc index 97d749028d16e0..2dd82aeeff12b8 100644 --- a/tensorflow/core/util/memmapped_file_system.cc +++ b/tensorflow/core/util/memmapped_file_system.cc @@ -61,24 +61,25 @@ class RandomAccessFileFromMemmapped : public RandomAccessFile { ~RandomAccessFileFromMemmapped() override = default; - Status Name(StringPiece* result) const override { + absl::Status Name(StringPiece* result) const override { return errors::Unimplemented( "RandomAccessFileFromMemmapped does not support Name()"); } - Status Read(uint64 offset, size_t to_read, StringPiece* result, - char* scratch) const override { + absl::Status Read(uint64 offset, size_t to_read, StringPiece* result, + char* scratch) const override { if (offset >= length_) { *result = StringPiece(scratch, 0); - return Status(absl::StatusCode::kOutOfRange, "Read after file end"); + return absl::Status(absl::StatusCode::kOutOfRange, "Read after file end"); } const uint64 region_left = std::min(length_ - offset, static_cast(to_read)); *result = StringPiece(reinterpret_cast(data_) + offset, region_left); - return (region_left == to_read) ? absl::OkStatus() - : Status(absl::StatusCode::kOutOfRange, - "Read less bytes than requested"); + return (region_left == to_read) + ? absl::OkStatus() + : absl::Status(absl::StatusCode::kOutOfRange, + "Read less bytes than requested"); } private: @@ -91,8 +92,8 @@ class RandomAccessFileFromMemmapped : public RandomAccessFile { MemmappedFileSystem::MemmappedFileSystem() = default; -Status MemmappedFileSystem::FileExists(const string& fname, - TransactionToken* token) { +absl::Status MemmappedFileSystem::FileExists(const string& fname, + TransactionToken* token) { if (!mapped_memory_) { return errors::FailedPrecondition("MemmappedEnv is not initialized"); } @@ -103,7 +104,7 @@ Status MemmappedFileSystem::FileExists(const string& fname, return errors::NotFound(fname, " not found"); } -Status MemmappedFileSystem::NewRandomAccessFile( +absl::Status MemmappedFileSystem::NewRandomAccessFile( const string& filename, TransactionToken* token, std::unique_ptr* result) { if (!mapped_memory_) { @@ -119,7 +120,7 @@ Status MemmappedFileSystem::NewRandomAccessFile( return absl::OkStatus(); } -Status MemmappedFileSystem::NewReadOnlyMemoryRegionFromFile( +absl::Status MemmappedFileSystem::NewReadOnlyMemoryRegionFromFile( const string& filename, TransactionToken* token, std::unique_ptr* result) { if (!mapped_memory_) { @@ -135,8 +136,9 @@ Status MemmappedFileSystem::NewReadOnlyMemoryRegionFromFile( return absl::OkStatus(); } -Status MemmappedFileSystem::GetFileSize(const string& filename, - TransactionToken* token, uint64* size) { +absl::Status MemmappedFileSystem::GetFileSize(const string& filename, + TransactionToken* token, + uint64* size) { if (!mapped_memory_) { return errors::FailedPrecondition("MemmappedEnv is not initialized"); } @@ -148,8 +150,9 @@ Status MemmappedFileSystem::GetFileSize(const string& filename, return absl::OkStatus(); } -Status MemmappedFileSystem::Stat(const string& fname, TransactionToken* token, - FileStatistics* stat) { +absl::Status MemmappedFileSystem::Stat(const string& fname, + TransactionToken* token, + FileStatistics* stat) { uint64 size; auto status = GetFileSize(fname, token, &size); if (status.ok()) { @@ -158,49 +161,49 @@ Status MemmappedFileSystem::Stat(const string& fname, TransactionToken* token, return status; } -Status MemmappedFileSystem::NewWritableFile(const string& filename, - TransactionToken* token, - std::unique_ptr* wf) { +absl::Status MemmappedFileSystem::NewWritableFile( + const string& filename, TransactionToken* token, + std::unique_ptr* wf) { return errors::Unimplemented("memmapped format doesn't support writing"); } -Status MemmappedFileSystem::NewAppendableFile( +absl::Status MemmappedFileSystem::NewAppendableFile( const string& filename, TransactionToken* token, std::unique_ptr* result) { return errors::Unimplemented("memmapped format doesn't support writing"); } -Status MemmappedFileSystem::GetChildren(const string& filename, - TransactionToken* token, - std::vector* strings) { +absl::Status MemmappedFileSystem::GetChildren(const string& filename, + TransactionToken* token, + std::vector* strings) { return errors::Unimplemented("memmapped format doesn't support GetChildren"); } -Status MemmappedFileSystem::GetMatchingPaths(const string& pattern, - TransactionToken* token, - std::vector* results) { +absl::Status MemmappedFileSystem::GetMatchingPaths( + const string& pattern, TransactionToken* token, + std::vector* results) { return errors::Unimplemented( "memmapped format doesn't support GetMatchingPaths"); } -Status MemmappedFileSystem::DeleteFile(const string& filename, - TransactionToken* token) { +absl::Status MemmappedFileSystem::DeleteFile(const string& filename, + TransactionToken* token) { return errors::Unimplemented("memmapped format doesn't support DeleteFile"); } -Status MemmappedFileSystem::CreateDir(const string& dirname, - TransactionToken* token) { +absl::Status MemmappedFileSystem::CreateDir(const string& dirname, + TransactionToken* token) { return errors::Unimplemented("memmapped format doesn't support CreateDir"); } -Status MemmappedFileSystem::DeleteDir(const string& dirname, - TransactionToken* token) { +absl::Status MemmappedFileSystem::DeleteDir(const string& dirname, + TransactionToken* token) { return errors::Unimplemented("memmapped format doesn't support DeleteDir"); } -Status MemmappedFileSystem::RenameFile(const string& filename_from, - const string& filename_to, - TransactionToken* token) { +absl::Status MemmappedFileSystem::RenameFile(const string& filename_from, + const string& filename_to, + TransactionToken* token) { return errors::Unimplemented("memmapped format doesn't support RenameFile"); } @@ -211,8 +214,8 @@ const void* MemmappedFileSystem::GetMemoryWithOffset(uint64 offset) const { constexpr const char MemmappedFileSystem::kMemmappedPackagePrefix[]; constexpr const char MemmappedFileSystem::kMemmappedPackageDefaultGraphDef[]; -Status MemmappedFileSystem::InitializeFromFile(Env* env, - const string& filename) { +absl::Status MemmappedFileSystem::InitializeFromFile(Env* env, + const string& filename) { TF_RETURN_IF_ERROR( env->NewReadOnlyMemoryRegionFromFile(filename, &mapped_memory_)); directory_.clear(); @@ -287,8 +290,8 @@ bool MemmappedFileSystem::IsWellFormedMemmappedPackageFilename( MemmappedEnv::MemmappedEnv(Env* env) : EnvWrapper(env) {} -Status MemmappedEnv::GetFileSystemForFile(const string& fname, - FileSystem** result) { +absl::Status MemmappedEnv::GetFileSystemForFile(const string& fname, + FileSystem** result) { if (MemmappedFileSystem::IsMemmappedPackageFilename(fname)) { if (!memmapped_file_system_) { return errors::FailedPrecondition( @@ -300,7 +303,7 @@ Status MemmappedEnv::GetFileSystemForFile(const string& fname, return EnvWrapper::GetFileSystemForFile(fname, result); } -Status MemmappedEnv::GetRegisteredFileSystemSchemes( +absl::Status MemmappedEnv::GetRegisteredFileSystemSchemes( std::vector* schemes) { const auto status = EnvWrapper::GetRegisteredFileSystemSchemes(schemes); if (status.ok()) { @@ -309,7 +312,7 @@ Status MemmappedEnv::GetRegisteredFileSystemSchemes( return status; } -Status MemmappedEnv::InitializeFromFile(const string& package_filename) { +absl::Status MemmappedEnv::InitializeFromFile(const string& package_filename) { std::unique_ptr file_system_ptr(new MemmappedFileSystem); const auto status = file_system_ptr->InitializeFromFile(target(), package_filename); diff --git a/tensorflow/core/util/memmapped_file_system.h b/tensorflow/core/util/memmapped_file_system.h index 6b5e098a2d7217..225defc49ceceb 100644 --- a/tensorflow/core/util/memmapped_file_system.h +++ b/tensorflow/core/util/memmapped_file_system.h @@ -64,39 +64,41 @@ class MemmappedFileSystem : public FileSystem { TF_USE_FILESYSTEM_METHODS_WITH_NO_TRANSACTION_SUPPORT; - Status FileExists(const string& fname, TransactionToken* token) override; - Status NewRandomAccessFile( + absl::Status FileExists(const string& fname, + TransactionToken* token) override; + absl::Status NewRandomAccessFile( const string& filename, TransactionToken* token, std::unique_ptr* result) override; - Status NewReadOnlyMemoryRegionFromFile( + absl::Status NewReadOnlyMemoryRegionFromFile( const string& filename, TransactionToken* token, std::unique_ptr* result) override; // All these functions return Unimplemented error, the memmapped storage is // read only. - Status NewWritableFile(const string& fname, TransactionToken* token, - std::unique_ptr* result) override; - Status NewAppendableFile(const string& fname, TransactionToken* token, - std::unique_ptr* result) override; - Status GetChildren(const string& dir, TransactionToken* token, - std::vector* r) override; - Status GetMatchingPaths(const string& pattern, TransactionToken* token, - std::vector* results) override; - Status DeleteFile(const string& f, TransactionToken* token) override; - Status CreateDir(const string& d, TransactionToken* token) override; - Status DeleteDir(const string& d, TransactionToken* token) override; - Status RenameFile(const string& s, const string& t, - TransactionToken* token) override; + absl::Status NewWritableFile(const string& fname, TransactionToken* token, + std::unique_ptr* result) override; + absl::Status NewAppendableFile( + const string& fname, TransactionToken* token, + std::unique_ptr* result) override; + absl::Status GetChildren(const string& dir, TransactionToken* token, + std::vector* r) override; + absl::Status GetMatchingPaths(const string& pattern, TransactionToken* token, + std::vector* results) override; + absl::Status DeleteFile(const string& f, TransactionToken* token) override; + absl::Status CreateDir(const string& d, TransactionToken* token) override; + absl::Status DeleteDir(const string& d, TransactionToken* token) override; + absl::Status RenameFile(const string& s, const string& t, + TransactionToken* token) override; // These functions are implemented. - Status GetFileSize(const string& f, TransactionToken* token, - uint64* s) override; + absl::Status GetFileSize(const string& f, TransactionToken* token, + uint64* s) override; // Currently just returns size. - Status Stat(const string& fname, TransactionToken* token, - FileStatistics* stat) override; + absl::Status Stat(const string& fname, TransactionToken* token, + FileStatistics* stat) override; // Initializes filesystem from a file in memmapped format. - Status InitializeFromFile(Env* env, const string& filename); + absl::Status InitializeFromFile(Env* env, const string& filename); // Checks if the filename has a correct prefix. static bool IsMemmappedPackageFilename(const string& filename); @@ -126,10 +128,11 @@ class MemmappedEnv : public EnvWrapper { public: explicit MemmappedEnv(Env* env); ~MemmappedEnv() override = default; - Status GetFileSystemForFile(const string& fname, - FileSystem** result) override; - Status GetRegisteredFileSystemSchemes(std::vector* schemes) override; - Status InitializeFromFile(const string& filename); + absl::Status GetFileSystemForFile(const string& fname, + FileSystem** result) override; + absl::Status GetRegisteredFileSystemSchemes( + std::vector* schemes) override; + absl::Status InitializeFromFile(const string& filename); protected: std::unique_ptr memmapped_file_system_; diff --git a/tensorflow/core/util/memmapped_file_system_test.cc b/tensorflow/core/util/memmapped_file_system_test.cc index 8c0e2ea3480ce2..26e15450921e01 100644 --- a/tensorflow/core/util/memmapped_file_system_test.cc +++ b/tensorflow/core/util/memmapped_file_system_test.cc @@ -38,8 +38,9 @@ constexpr char kTensor2FileName[] = "memmapped_package://t2"; constexpr char kProtoFileName[] = "memmapped_package://b"; constexpr int kTestGraphDefVersion = 666; -Status CreateMemmappedFileSystemFile(const string& filename, bool corrupted, - Tensor* test_tensor) { +absl::Status CreateMemmappedFileSystemFile(const string& filename, + bool corrupted, + Tensor* test_tensor) { Env* env = Env::Default(); MemmappedFileSystemWriter writer; TF_RETURN_IF_ERROR(writer.InitializeToFile(env, filename)); diff --git a/tensorflow/core/util/memmapped_file_system_writer.cc b/tensorflow/core/util/memmapped_file_system_writer.cc index 102e3f6b3e9b97..411dbc51733a48 100644 --- a/tensorflow/core/util/memmapped_file_system_writer.cc +++ b/tensorflow/core/util/memmapped_file_system_writer.cc @@ -18,8 +18,8 @@ limitations under the License. namespace tensorflow { -Status MemmappedFileSystemWriter::InitializeToFile(Env* env, - const string& filename) { +absl::Status MemmappedFileSystemWriter::InitializeToFile( + Env* env, const string& filename) { auto status = env->NewWritableFile(filename, &output_file_); if (status.ok()) { output_file_offset_ = 0; @@ -27,8 +27,8 @@ Status MemmappedFileSystemWriter::InitializeToFile(Env* env, return status; } -Status MemmappedFileSystemWriter::SaveTensor(const Tensor& tensor, - const string& element_name) { +absl::Status MemmappedFileSystemWriter::SaveTensor(const Tensor& tensor, + const string& element_name) { if (!output_file_) { return errors::FailedPrecondition( "MemmappedEnvWritter: saving tensor into not opened file"); @@ -55,7 +55,7 @@ Status MemmappedFileSystemWriter::SaveTensor(const Tensor& tensor, return result; } -Status MemmappedFileSystemWriter::SaveProtobuf( +absl::Status MemmappedFileSystemWriter::SaveProtobuf( const protobuf::MessageLite& message, const string& element_name) { if (!output_file_) { return errors::FailedPrecondition( @@ -89,7 +89,7 @@ StringPiece EncodeUint64LittleEndian(uint64 val, char* output_buffer) { } // namespace -Status MemmappedFileSystemWriter::FlushAndClose() { +absl::Status MemmappedFileSystemWriter::FlushAndClose() { if (!output_file_) { return errors::FailedPrecondition( "MemmappedEnvWritter: flushing into not opened file"); @@ -109,7 +109,7 @@ Status MemmappedFileSystemWriter::FlushAndClose() { return absl::OkStatus(); } -Status MemmappedFileSystemWriter::AdjustAlignment(uint64 alignment) { +absl::Status MemmappedFileSystemWriter::AdjustAlignment(uint64 alignment) { const uint64 alignment_rest = output_file_offset_ % alignment; const uint64 to_write_for_alignment = (alignment_rest == 0) ? 0 : alignment - (output_file_offset_ % alignment); diff --git a/tensorflow/core/util/memmapped_file_system_writer.h b/tensorflow/core/util/memmapped_file_system_writer.h index 2b15a7e476ff9f..9d0db92758252d 100644 --- a/tensorflow/core/util/memmapped_file_system_writer.h +++ b/tensorflow/core/util/memmapped_file_system_writer.h @@ -31,15 +31,15 @@ class MemmappedFileSystemWriter { public: MemmappedFileSystemWriter() = default; ~MemmappedFileSystemWriter() = default; - Status InitializeToFile(Env* env, const string& filename); - Status SaveTensor(const Tensor& tensor, const string& element_name); - Status SaveProtobuf(const protobuf::MessageLite& message, - const string& element_name); + absl::Status InitializeToFile(Env* env, const string& filename); + absl::Status SaveTensor(const Tensor& tensor, const string& element_name); + absl::Status SaveProtobuf(const protobuf::MessageLite& message, + const string& element_name); // Writes out the directory of regions and closes the output file. - Status FlushAndClose(); + absl::Status FlushAndClose(); private: - Status AdjustAlignment(uint64 alignment); + absl::Status AdjustAlignment(uint64 alignment); void AddToDirectoryElement(const string& element_name, uint64 length); MemmappedFileSystemDirectory directory_; // The current offset in the file, to support alignment. diff --git a/tensorflow/core/util/mirror_pad_mode.cc b/tensorflow/core/util/mirror_pad_mode.cc index d6be854e9168bb..067996c69d07ef 100644 --- a/tensorflow/core/util/mirror_pad_mode.cc +++ b/tensorflow/core/util/mirror_pad_mode.cc @@ -22,8 +22,8 @@ limitations under the License. namespace tensorflow { -Status GetNodeAttr(const NodeDef& node_def, StringPiece attr_name, - MirrorPadMode* value) { +absl::Status GetNodeAttr(const NodeDef& node_def, StringPiece attr_name, + MirrorPadMode* value) { string str_value; TF_RETURN_IF_ERROR(GetNodeAttr(node_def, attr_name, &str_value)); if (str_value == "REFLECT") { diff --git a/tensorflow/core/util/mirror_pad_mode.h b/tensorflow/core/util/mirror_pad_mode.h index b8eb3751fb0c2b..5675a22739cc82 100644 --- a/tensorflow/core/util/mirror_pad_mode.h +++ b/tensorflow/core/util/mirror_pad_mode.h @@ -45,8 +45,8 @@ string GetMirrorPadModeAttrString(); class NodeDef; // Specialization to parse an attribute directly into a MirrorPadMode enum. -Status GetNodeAttr(const NodeDef& node_def, StringPiece attr_name, - MirrorPadMode* value); +absl::Status GetNodeAttr(const NodeDef& node_def, StringPiece attr_name, + MirrorPadMode* value); } // end namespace tensorflow diff --git a/tensorflow/core/util/padding.cc b/tensorflow/core/util/padding.cc index 28805364b9dd96..e502d5eafae769 100644 --- a/tensorflow/core/util/padding.cc +++ b/tensorflow/core/util/padding.cc @@ -22,7 +22,7 @@ limitations under the License. namespace tensorflow { -Status GetPaddingFromString(StringPiece str_value, Padding* value) { +absl::Status GetPaddingFromString(StringPiece str_value, Padding* value) { if (str_value == "SAME") { *value = SAME; } else if (str_value == "VALID") { @@ -35,9 +35,9 @@ Status GetPaddingFromString(StringPiece str_value, Padding* value) { return absl::OkStatus(); } -Status CheckValidPadding(Padding padding_type, - const std::vector& explicit_paddings, - int num_dims, TensorFormat data_format) { +absl::Status CheckValidPadding(Padding padding_type, + const std::vector& explicit_paddings, + int num_dims, TensorFormat data_format) { if (padding_type == Padding::EXPLICIT) { const int num_paddings = explicit_paddings.size(); if (num_paddings != 2 * num_dims) { diff --git a/tensorflow/core/util/padding.h b/tensorflow/core/util/padding.h index 1ae0e38537692c..9c0cf543a0dc4f 100644 --- a/tensorflow/core/util/padding.h +++ b/tensorflow/core/util/padding.h @@ -47,9 +47,9 @@ enum Padding { }; // Returns an error if the padding attributes are invalid. -Status CheckValidPadding(Padding padding_type, - const std::vector& explicit_paddings, - int num_dims, TensorFormat data_format); +absl::Status CheckValidPadding(Padding padding_type, + const std::vector& explicit_paddings, + int num_dims, TensorFormat data_format); // Return the string containing the list of valid padding types, that can be // used as an Attr() in REGISTER_OP. @@ -61,7 +61,7 @@ std::string GetPaddingAttrStringWithExplicit(); std::string GetExplicitPaddingsAttrString(); // Sets padding value based on the given string padding value. -Status GetPaddingFromString(StringPiece str_value, Padding* value); +absl::Status GetPaddingFromString(StringPiece str_value, Padding* value); } // end namespace tensorflow diff --git a/tensorflow/core/util/proto/decode.h b/tensorflow/core/util/proto/decode.h index 894c69b8b4d294..7d43e34b35ce50 100644 --- a/tensorflow/core/util/proto/decode.h +++ b/tensorflow/core/util/proto/decode.h @@ -323,7 +323,8 @@ inline int ReadPackedPrimitives(const void* bufp, const size_t len, // to the desired type for TensorFlow and stored. template -inline Status ReadPrimitive(CodedInputStream* input, int index, void* data) { +inline absl::Status ReadPrimitive(CodedInputStream* input, int index, + void* data) { ValueType v; if (!WireFormatLite::ReadPrimitive(input, &v)) { return errors::DataLoss("Failed reading primitive"); @@ -336,7 +337,7 @@ inline Status ReadPrimitive(CodedInputStream* input, int index, void* data) { // Reads a string, submessage, or other variable-length field from a // serialized proto. // May read all or part of a repeated field. -inline Status ReadBytes(CodedInputStream* input, int index, void* datap) { +inline absl::Status ReadBytes(CodedInputStream* input, int index, void* datap) { tstring* data = reinterpret_cast(datap) + index; uint32 length; @@ -354,8 +355,8 @@ inline Status ReadBytes(CodedInputStream* input, int index, void* datap) { // Reads a tag-delimited field (TYPE_GROUP) from a serialized proto, // as a bytestring. -inline Status ReadGroupBytes(CodedInputStream* input, int field_number, - int index, void* datap) { +inline absl::Status ReadGroupBytes(CodedInputStream* input, int field_number, + int index, void* datap) { // WireFormatLite::SkipField has an option to emit the // skipped bytes to an output stream. We could do better by implementing our // own scanner but this is simpler for now. @@ -386,9 +387,10 @@ inline Status ReadGroupBytes(CodedInputStream* input, int field_number, } // Reads a single field value from a CodedInputStream into a tensor. -inline Status ReadValue(CodedInputStream* input, - WireFormatLite::FieldType field_type, int field_number, - DataType dtype, int index, void* datap) { +inline absl::Status ReadValue(CodedInputStream* input, + WireFormatLite::FieldType field_type, + int field_number, DataType dtype, int index, + void* datap) { // Dispatch to the appropriately typed field reader based on the schema type. switch (field_type) { case WireFormatLite::TYPE_DOUBLE: @@ -502,10 +504,10 @@ inline Status ReadValue(CodedInputStream* input, } // Reads and stores a length-delimited list of values. -inline Status ReadPackedFromArray(const void* buf, size_t buf_size, - const WireFormatLite::FieldType field_type, - const int field_number, const DataType dtype, - const int stride, int* index, void* data) { +inline absl::Status ReadPackedFromArray( + const void* buf, size_t buf_size, + const WireFormatLite::FieldType field_type, const int field_number, + const DataType dtype, const int stride, int* index, void* data) { // Dispatch to the appropriately typed field reader based on the schema type. switch (field_type) { case WireFormatLite::TYPE_DOUBLE: diff --git a/tensorflow/core/util/proto/descriptor_pool_registry.h b/tensorflow/core/util/proto/descriptor_pool_registry.h index 66c20e9e413372..59c709ea150e87 100644 --- a/tensorflow/core/util/proto/descriptor_pool_registry.h +++ b/tensorflow/core/util/proto/descriptor_pool_registry.h @@ -30,7 +30,7 @@ namespace tensorflow { class DescriptorPoolRegistry { public: - typedef std::function* owned_desc_pool)> DescriptorPoolFn; diff --git a/tensorflow/core/util/proto/descriptor_pool_registry_test.cc b/tensorflow/core/util/proto/descriptor_pool_registry_test.cc index 97f9f31c1835e4..abc266ccd243e7 100644 --- a/tensorflow/core/util/proto/descriptor_pool_registry_test.cc +++ b/tensorflow/core/util/proto/descriptor_pool_registry_test.cc @@ -21,7 +21,7 @@ namespace tensorflow { namespace { struct Value { - static Status Function( + static absl::Status Function( tensorflow::protobuf::DescriptorPool const** desc_pool, std::unique_ptr* owned_desc_pool) { return absl::OkStatus(); diff --git a/tensorflow/core/util/proto/descriptors.cc b/tensorflow/core/util/proto/descriptors.cc index 09137504664e1b..31942145fe32fa 100644 --- a/tensorflow/core/util/proto/descriptors.cc +++ b/tensorflow/core/util/proto/descriptors.cc @@ -26,8 +26,9 @@ limitations under the License. namespace tensorflow { namespace { -Status CreatePoolFromSet(const protobuf::FileDescriptorSet& set, - std::unique_ptr* out_pool) { +absl::Status CreatePoolFromSet( + const protobuf::FileDescriptorSet& set, + std::unique_ptr* out_pool) { *out_pool = absl::make_unique(); for (const auto& file : set.file()) { if ((*out_pool)->BuildFile(file) == nullptr) { @@ -43,10 +44,10 @@ Status CreatePoolFromSet(const protobuf::FileDescriptorSet& set, // // The file must contain a serialized `FileDescriptorSet`. See // `GetDescriptorPool()` for more information. -Status GetDescriptorPoolFromFile( +absl::Status GetDescriptorPoolFromFile( tensorflow::Env* env, const string& filename, std::unique_ptr* owned_desc_pool) { - Status st = env->FileExists(filename); + absl::Status st = env->FileExists(filename); if (!st.ok()) { return st; } @@ -64,7 +65,7 @@ Status GetDescriptorPoolFromFile( return CreatePoolFromSet(descs, owned_desc_pool); } -Status GetDescriptorPoolFromBinary( +absl::Status GetDescriptorPoolFromBinary( const string& source, std::unique_ptr* owned_desc_pool) { if (!absl::StartsWith(source, "bytes://")) { @@ -86,7 +87,7 @@ Status GetDescriptorPoolFromBinary( } // namespace -Status GetDescriptorPool( +absl::Status GetDescriptorPool( Env* env, string const& descriptor_source, protobuf::DescriptorPool const** desc_pool, std::unique_ptr* owned_desc_pool) { @@ -98,7 +99,7 @@ Status GetDescriptorPool( // If there is no pool function registered for the given source, let the // runtime find the file or URL. - Status status = + absl::Status status = GetDescriptorPoolFromFile(env, descriptor_source, owned_desc_pool); if (status.ok()) { *desc_pool = owned_desc_pool->get(); diff --git a/tensorflow/core/util/proto/descriptors.h b/tensorflow/core/util/proto/descriptors.h index 36cf62f47e5c84..3402ed0504410e 100644 --- a/tensorflow/core/util/proto/descriptors.h +++ b/tensorflow/core/util/proto/descriptors.h @@ -45,7 +45,7 @@ using tsl::Env; // // Custom schemas can be supported by registering a handler with the // `DescriptorPoolRegistry`. -Status GetDescriptorPool( +absl::Status GetDescriptorPool( Env* env, string const& descriptor_source, protobuf::DescriptorPool const** desc_pool, std::unique_ptr* owned_desc_pool); diff --git a/tensorflow/core/util/proto/local_descriptor_pool_registration.cc b/tensorflow/core/util/proto/local_descriptor_pool_registration.cc index 6929976e4ba10e..1e941b6ab3bb2f 100644 --- a/tensorflow/core/util/proto/local_descriptor_pool_registration.cc +++ b/tensorflow/core/util/proto/local_descriptor_pool_registration.cc @@ -21,7 +21,7 @@ namespace tensorflow { namespace { struct LocalDescriptorPool { - static Status Function( + static absl::Status Function( tensorflow::protobuf::DescriptorPool const** desc_pool, std::unique_ptr* owned_desc_pool) { *desc_pool = ::tensorflow::protobuf::DescriptorPool::generated_pool(); diff --git a/tensorflow/core/util/proto/proto_utils.cc b/tensorflow/core/util/proto/proto_utils.cc index 282be94a774d08..be13bdd876776a 100644 --- a/tensorflow/core/util/proto/proto_utils.cc +++ b/tensorflow/core/util/proto/proto_utils.cc @@ -69,21 +69,21 @@ bool IsCompatibleType(FieldDescriptor::Type field_type, DataType dtype) { } } -Status ParseTextFormatFromString(absl::string_view input, - protobuf::Message* output) { +absl::Status ParseTextFormatFromString(absl::string_view input, + protobuf::Message* output) { DCHECK(output != nullptr) << "output must be non NULL"; // When checks are disabled, instead log the error and return an error status. if (output == nullptr) { LOG(ERROR) << "output must be non NULL"; - return Status(absl::StatusCode::kInvalidArgument, - "output must be non NULL"); + return absl::Status(absl::StatusCode::kInvalidArgument, + "output must be non NULL"); } string err; StringErrorCollector err_collector(&err, /*one-indexing=*/true); protobuf::TextFormat::Parser parser; parser.RecordErrorsTo(&err_collector); if (!parser.ParseFromString(string(input), output)) { - return Status(absl::StatusCode::kInvalidArgument, err); + return absl::Status(absl::StatusCode::kInvalidArgument, err); } return absl::OkStatus(); } diff --git a/tensorflow/core/util/proto/proto_utils.h b/tensorflow/core/util/proto/proto_utils.h index f0347a84cbe429..01b8ad0d7479cf 100644 --- a/tensorflow/core/util/proto/proto_utils.h +++ b/tensorflow/core/util/proto/proto_utils.h @@ -37,8 +37,8 @@ bool IsCompatibleType(FieldDescriptor::Type field_type, DataType dtype); // Parses a text-formatted protobuf from a string into the given Message* output // and returns status OK if valid, or INVALID_ARGUMENT with an accompanying // parser error message if the text format is invalid. -Status ParseTextFormatFromString(absl::string_view input, - protobuf::Message* output); +absl::Status ParseTextFormatFromString(absl::string_view input, + protobuf::Message* output); class StringErrorCollector : public protobuf::io::ErrorCollector { public: diff --git a/tensorflow/core/util/proto/proto_utils_test.cc b/tensorflow/core/util/proto/proto_utils_test.cc index 00153d2b6f25ac..64d0e5a859744b 100644 --- a/tensorflow/core/util/proto/proto_utils_test.cc +++ b/tensorflow/core/util/proto/proto_utils_test.cc @@ -34,7 +34,7 @@ TEST(ParseTextFormatFromStringTest, Success) { TEST(ParseTextFormatFromStringTest, ErrorOnInvalidSyntax) { protobuf::DescriptorProto output; - Status status = ParseTextFormatFromString("name: foo", &output); + absl::Status status = ParseTextFormatFromString("name: foo", &output); EXPECT_EQ(status.code(), error::INVALID_ARGUMENT); EXPECT_THAT(status.message(), ContainsRegex("foo")); EXPECT_FALSE(output.has_name()); @@ -42,7 +42,7 @@ TEST(ParseTextFormatFromStringTest, ErrorOnInvalidSyntax) { TEST(ParseTextFormatFromStringTest, ErrorOnUnknownFieldName) { protobuf::DescriptorProto output; - Status status = ParseTextFormatFromString("badname: \"foo\"", &output); + absl::Status status = ParseTextFormatFromString("badname: \"foo\"", &output); EXPECT_EQ(status.code(), error::INVALID_ARGUMENT); EXPECT_THAT(status.message(), ContainsRegex("badname")); EXPECT_FALSE(output.has_name()); diff --git a/tensorflow/core/util/rocm_solvers.cc b/tensorflow/core/util/rocm_solvers.cc index 24f3fb5ae79e46..6fb71185e19af8 100644 --- a/tensorflow/core/util/rocm_solvers.cc +++ b/tensorflow/core/util/rocm_solvers.cc @@ -37,11 +37,8 @@ Unmqr // ---- // rocsolver_Xunmqr // hipsolverXunmqr #include #include -#include "xla/stream_executor/gpu/gpu_executor.h" -#include "xla/stream_executor/gpu/scoped_activate_context.h" -#include "xla/stream_executor/platform/default/dso_loader.h" -#include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/rocm/rocblas_wrapper.h" +#include "xla/stream_executor/stream_executor.h" #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" @@ -57,8 +54,7 @@ Unmqr // ---- // rocsolver_Xunmqr // hipsolverXunmqr namespace tensorflow { namespace { -using stream_executor::gpu::GpuExecutor; -using stream_executor::gpu::ScopedActivateContext; +using stream_executor::StreamExecutor; inline bool CopyHostToDevice(OpKernelContext* context, void* dst, const void* src, uint64 bytes) { @@ -68,7 +64,13 @@ inline bool CopyHostToDevice(OpKernelContext* context, void* dst, } struct GpuSolverHandles { +<<<<<<< HEAD explicit GpuSolverHandles(hipStream_t stream) { +======= + explicit GpuSolverHandles(StreamExecutor* parent, hipStream_t stream) { + parent_ = parent; + std::unique_ptr sac = parent_->Activate(); +>>>>>>> master #if TF_ROCM_VERSION >= 40500 CHECK(se::wrap::hipsolverCreate(&hipsolver_handle) == rocblas_status_success) @@ -83,6 +85,10 @@ struct GpuSolverHandles { } ~GpuSolverHandles() { +<<<<<<< HEAD +======= + std::unique_ptr sac = parent_->Activate(); +>>>>>>> master CHECK(se::wrap::rocblas_destroy_handle(rocm_blas_handle) == rocblas_status_success) << "Failed to destroy rocBlas instance."; @@ -92,6 +98,10 @@ struct GpuSolverHandles { << "Failed to destroy hipsolver instance."; #endif } +<<<<<<< HEAD +======= + StreamExecutor* parent_; +>>>>>>> master rocblas_handle rocm_blas_handle; #if TF_ROCM_VERSION >= 40500 hipsolverHandle_t hipsolver_handle; @@ -114,6 +124,11 @@ static mutex handle_map_mutex(LINKER_INITIALIZED); GpuSolver::GpuSolver(OpKernelContext* context) : context_(context) { mutex_lock lock(handle_map_mutex); +<<<<<<< HEAD +======= + StreamExecutor* gpu_executor = + context->op_device_context()->stream()->parent(); +>>>>>>> master hip_stream_ = reinterpret_cast( CHECK_NOTNULL(context->op_device_context() ->stream() @@ -176,7 +191,8 @@ void GpuSolver::CheckLapackInfoAndDeleteSolverAsync( std::function&)> info_checker_callback, std::vector host_lapack_infos) { - ScopedActivateContext scoped_activation{stream->parent()}; + std::unique_ptr scoped_activation = + stream->parent()->Activate(); Status status; for (const auto& host_lapack_info : host_lapack_infos) { for (int i = 0; i < host_lapack_info.size() && status.ok(); ++i) { @@ -772,7 +788,11 @@ TF_CALL_HIP_LAPACK_TYPES_NO_COMPLEX(GESVD_INSTANCE); TF_CALL_LAPACK_TYPES_NO_COMPLEX(TRSV_INSTANCE); template +<<<<<<< HEAD static inline Status TrsmImpl(SolverFnT solver, +======= +static inline Status TrsmImpl(StreamExecutor* gpu_executor, SolverFnT solver, +>>>>>>> master rocblas_handle rocm_blas_handle, rocblas_side side, rocblas_fill uplo, rocblas_operation trans, rocblas_diagonal diag, @@ -782,6 +802,11 @@ static inline Status TrsmImpl(SolverFnT solver, mutex_lock lock(handle_map_mutex); using ROCmScalar = typename ROCmComplexT::type; +<<<<<<< HEAD +======= + std::unique_ptr sac = + gpu_executor->Activate(); +>>>>>>> master TF_RETURN_IF_ROCBLAS_ERROR(solver(rocm_blas_handle, side, uplo, trans, diag, m, n, reinterpret_cast(alpha), @@ -791,6 +816,7 @@ static inline Status TrsmImpl(SolverFnT solver, return OkStatus(); } +<<<<<<< HEAD #define TRSM_INSTANCE(Scalar, type_prefix) \ template <> \ Status GpuSolver::Trsm( \ @@ -801,12 +827,30 @@ static inline Status TrsmImpl(SolverFnT solver, return TrsmImpl(BLAS_SOLVER_FN(trsm, type_prefix), \ rocm_blas_handle_, side, uplo, trans, diag, m, n, alpha, \ A, lda, B, ldb); \ +======= +#define TRSM_INSTANCE(Scalar, type_prefix) \ + template <> \ + Status GpuSolver::Trsm( \ + rocblas_side side, rocblas_fill uplo, rocblas_operation trans, \ + rocblas_diagonal diag, int m, int n, \ + const Scalar* alpha, /* host or device pointer */ \ + const Scalar* A, int lda, Scalar* B, int ldb) { \ + StreamExecutor* gpu_executor = \ + context_->op_device_context()->stream()->parent(); \ + return TrsmImpl(gpu_executor, BLAS_SOLVER_FN(trsm, type_prefix), \ + rocm_blas_handle_, side, uplo, trans, diag, m, n, alpha, \ + A, lda, B, ldb); \ +>>>>>>> master } TF_CALL_LAPACK_TYPES_NO_COMPLEX(TRSM_INSTANCE); template +<<<<<<< HEAD Status MatInvBatchedImpl(SolverFnT solver, +======= +Status MatInvBatchedImpl(StreamExecutor* gpu_executor, SolverFnT solver, +>>>>>>> master rocblas_handle rocm_blas_handle, int n, const Scalar* const host_a_dev_ptrs[], int lda, int* dev_pivots, @@ -815,6 +859,11 @@ Status MatInvBatchedImpl(SolverFnT solver, int batch_size) { mutex_lock lock(handle_map_mutex); using ROCmScalar = typename ROCmComplexT::type; +<<<<<<< HEAD +======= + std::unique_ptr sac = + gpu_executor->Activate(); +>>>>>>> master GetrfBatched(n, host_a_dev_ptrs, lda, dev_pivots, dev_lapack_info, batch_size); @@ -825,6 +874,7 @@ Status MatInvBatchedImpl(SolverFnT solver, return OkStatus(); } +<<<<<<< HEAD #define MATINVBATCHED_INSTANCE(Scalar, type_prefix) \ template <> \ Status GpuSolver::MatInvBatched( \ @@ -840,6 +890,25 @@ Status MatInvBatchedImpl(SolverFnT solver, BLAS_SOLVER_FN(matinvbatched, type_prefix), \ rocm_blas_handle_, n, host_a_dev_ptrs, lda, dev_pivots, \ host_a_inverse_dev_ptrs, ldainv, dev_lapack_info, batch_size); \ +======= +#define MATINVBATCHED_INSTANCE(Scalar, type_prefix) \ + template <> \ + Status GpuSolver::MatInvBatched( \ + int n, const Scalar* const host_a_dev_ptrs[], int lda, \ + const Scalar* const host_a_inverse_dev_ptrs[], int ldainv, \ + DeviceLapackInfo* dev_lapack_info, int batch_size) { \ + StreamExecutor* gpu_executor = \ + context_->op_device_context()->stream()->parent(); \ + Tensor pivots; \ + context_->allocate_scoped_tensor(DataTypeToEnum::value, \ + TensorShape{batch_size, n}, &pivots); \ + auto pivots_mat = pivots.template matrix(); \ + int* dev_pivots = pivots_mat.data(); \ + return MatInvBatchedImpl( \ + gpu_executor, BLAS_SOLVER_FN(matinvbatched, type_prefix), \ + rocm_blas_handle_, n, host_a_dev_ptrs, lda, dev_pivots, \ + host_a_inverse_dev_ptrs, ldainv, dev_lapack_info, batch_size); \ +>>>>>>> master } #define TRSM_BATCHED_INSTANCE(Scalar, type_prefix) \ @@ -876,7 +945,11 @@ Status MatInvBatchedImpl(SolverFnT solver, TF_CALL_LAPACK_TYPES_NO_COMPLEX(TRSM_BATCHED_INSTANCE); template +<<<<<<< HEAD Status GeamImpl(SolverFnT solver, +======= +Status GeamImpl(StreamExecutor* gpu_executor, SolverFnT solver, +>>>>>>> master rocblas_handle rocm_blas_handle, rocblas_operation transa, rocblas_operation transb, int m, int n, const Scalar* alpha, /* host or device pointer */ const Scalar* A, int lda, @@ -886,6 +959,11 @@ Status GeamImpl(SolverFnT solver, mutex_lock lock(handle_map_mutex); using ROCmScalar = typename ROCmComplexT::type; +<<<<<<< HEAD +======= + std::unique_ptr sac = + gpu_executor->Activate(); +>>>>>>> master TF_RETURN_IF_ROCBLAS_ERROR(solver(rocm_blas_handle, transa, transb, m, n, reinterpret_cast(alpha), reinterpret_cast(A), lda, @@ -895,6 +973,7 @@ Status GeamImpl(SolverFnT solver, return OkStatus(); } +<<<<<<< HEAD #define GEAM_INSTANCE(Scalar, type_prefix) \ template <> \ Status GpuSolver::Geam( \ @@ -904,6 +983,19 @@ Status GeamImpl(SolverFnT solver, return GeamImpl(BLAS_SOLVER_FN(geam, type_prefix), \ rocm_blas_handle_, transa, transb, m, n, alpha, A, lda, \ beta, B, ldb, C, ldc); \ +======= +#define GEAM_INSTANCE(Scalar, type_prefix) \ + template <> \ + Status GpuSolver::Geam( \ + rocblas_operation transa, rocblas_operation transb, int m, int n, \ + const Scalar* alpha, const Scalar* A, int lda, const Scalar* beta, \ + const Scalar* B, int ldb, Scalar* C, int ldc) { \ + StreamExecutor* gpu_executor = \ + context_->op_device_context()->stream()->parent(); \ + return GeamImpl(gpu_executor, BLAS_SOLVER_FN(geam, type_prefix), \ + rocm_blas_handle_, transa, transb, m, n, alpha, A, lda, \ + beta, B, ldb, C, ldc); \ +>>>>>>> master } TF_CALL_LAPACK_TYPES_NO_COMPLEX(GEAM_INSTANCE); diff --git a/tensorflow/core/util/sparse/sparse_tensor.cc b/tensorflow/core/util/sparse/sparse_tensor.cc index 75dffd02fed286..aa63d6d039779f 100644 --- a/tensorflow/core/util/sparse/sparse_tensor.cc +++ b/tensorflow/core/util/sparse/sparse_tensor.cc @@ -27,21 +27,21 @@ int UnsafeGetDimsFromIx(const Tensor& ix) { return ix.dim_size(1); } -Status GetDimsFromIx(const Tensor& ix, int* result) { +absl::Status GetDimsFromIx(const Tensor& ix, int* result) { if (!TensorShapeUtils::IsMatrix(ix.shape())) { return errors::InvalidArgument("indices must be a matrix, but got: ", ix.shape().DebugString()); } *result = UnsafeGetDimsFromIx(ix); - return Status(); + return absl::Status(); } } // namespace -/* static */ Status SparseTensor::Create(Tensor ix, Tensor vals, - const VarDimArray shape, - const VarDimArray order, - SparseTensor* result) { +/* static */ absl::Status SparseTensor::Create(Tensor ix, Tensor vals, + const VarDimArray shape, + const VarDimArray order, + SparseTensor* result) { if (ix.dtype() != DT_INT64) { return errors::InvalidArgument("indices must be type int64 but got: ", ix.dtype()); @@ -73,24 +73,24 @@ Status GetDimsFromIx(const Tensor& ix, int* result) { return absl::OkStatus(); } -/* static */ Status SparseTensor::Create(Tensor ix, Tensor vals, - const TensorShape& shape, - SparseTensor* result) { +/* static */ absl::Status SparseTensor::Create(Tensor ix, Tensor vals, + const TensorShape& shape, + SparseTensor* result) { return Create(std::move(ix), std::move(vals), TensorShapeToVector(shape), UndefinedOrder(TensorShapeToVector(shape)), result); } -/* static */ Status SparseTensor::Create(Tensor ix, Tensor vals, - const VarDimArray shape, - SparseTensor* result) { +/* static */ absl::Status SparseTensor::Create(Tensor ix, Tensor vals, + const VarDimArray shape, + SparseTensor* result) { return Create(std::move(ix), std::move(vals), shape, UndefinedOrder(shape), result); } -/* static */ Status SparseTensor::Create(Tensor ix, Tensor vals, - const TensorShape& shape, - const VarDimArray order, - SparseTensor* result) { +/* static */ absl::Status SparseTensor::Create(Tensor ix, Tensor vals, + const TensorShape& shape, + const VarDimArray order, + SparseTensor* result) { return Create(std::move(ix), std::move(vals), TensorShapeToVector(shape), order, result); } @@ -222,7 +222,7 @@ bool SparseTensor::IndicesValidMatrix32BitFastPath() const { } template -Status SparseTensor::IndicesValidHelper() const { +absl::Status SparseTensor::IndicesValidHelper() const { const auto ix_t = ix_.matrix(); const int64_t* const shape_ptr = shape_.data(); @@ -275,7 +275,7 @@ Status SparseTensor::IndicesValidHelper() const { return absl::OkStatus(); } -Status SparseTensor::IndicesValid() const { +absl::Status SparseTensor::IndicesValid() const { if (shape_.size() == 1 && IndicesValidVectorFastPath()) { return absl::OkStatus(); } diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h index a2468f6f4a610d..b7b7b17b3190e9 100644 --- a/tensorflow/core/util/sparse/sparse_tensor.h +++ b/tensorflow/core/util/sparse/sparse_tensor.h @@ -44,17 +44,17 @@ class SparseTensor { typedef absl::Span VarDimArray; typedef absl::InlinedVector ShapeArray; - static Status Create(Tensor ix, Tensor vals, const VarDimArray shape, - const VarDimArray order, SparseTensor* result); + static absl::Status Create(Tensor ix, Tensor vals, const VarDimArray shape, + const VarDimArray order, SparseTensor* result); - static Status Create(Tensor ix, Tensor vals, const TensorShape& shape, - SparseTensor* result); + static absl::Status Create(Tensor ix, Tensor vals, const TensorShape& shape, + SparseTensor* result); - static Status Create(Tensor ix, Tensor vals, const VarDimArray shape, - SparseTensor* result); + static absl::Status Create(Tensor ix, Tensor vals, const VarDimArray shape, + SparseTensor* result); - static Status Create(Tensor ix, Tensor vals, const TensorShape& shape, - const VarDimArray order, SparseTensor* result); + static absl::Status Create(Tensor ix, Tensor vals, const TensorShape& shape, + const VarDimArray order, SparseTensor* result); SparseTensor() : dims_(0) {} @@ -113,7 +113,7 @@ class SparseTensor { DataType dtype() const { return vals_.dtype(); } - Status IndicesValid() const; + absl::Status IndicesValid() const; VarDimArray shape() const { return shape_; } @@ -170,8 +170,9 @@ class SparseTensor { // isn't an integer multiple of split_dim, we add one extra dimension for // each slice. template - static Status Split(const SparseTensor& tensor, const int split_dim, - const int num_split, std::vector* result); + static absl::Status Split(const SparseTensor& tensor, const int split_dim, + const int num_split, + std::vector* result); // Slice() will slice the input SparseTensor into a SparseTensor based on // specified start and size. Both start and size are 1-D array with each @@ -212,7 +213,7 @@ class SparseTensor { bool IndicesValidMatrix32BitFastPath() const; template - Status IndicesValidHelper() const; + absl::Status IndicesValidHelper() const; // Helper for ToDense() template @@ -493,9 +494,10 @@ inline SparseTensor SparseTensor::Concat( } template -inline Status SparseTensor::Split(const SparseTensor& input_tensor, - const int split_dim, const int num_split, - std::vector* result) { +inline absl::Status SparseTensor::Split(const SparseTensor& input_tensor, + const int split_dim, + const int num_split, + std::vector* result) { std::vector output_indices; std::vector output_values; std::vector output_shapes; @@ -567,7 +569,7 @@ inline Status SparseTensor::Split(const SparseTensor& input_tensor, result->reserve(num_split); for (int i = 0; i < num_split; ++i) { SparseTensor tensor; - Status create_status = + absl::Status create_status = Create(output_indices[i], output_values[i], output_shapes[i], &tensor); if (!create_status.ok()) { return create_status; diff --git a/tensorflow/core/util/sparse/sparse_tensor_test.cc b/tensorflow/core/util/sparse/sparse_tensor_test.cc index 586b1790e55de6..4126c6d55ef093 100644 --- a/tensorflow/core/util/sparse/sparse_tensor_test.cc +++ b/tensorflow/core/util/sparse/sparse_tensor_test.cc @@ -190,7 +190,7 @@ TEST(SparseTensorTest, SparseTensorConstruction) { std::vector order{0, 1, 2}; SparseTensor st; TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st)); - Status st_indices_valid = st.IndicesValid(); + absl::Status st_indices_valid = st.IndicesValid(); EXPECT_FALSE(st_indices_valid.ok()); EXPECT_EQ( "indices[2] = [2,0,0] is out of order. " @@ -296,7 +296,7 @@ TEST(SparseTensorTest, ValidateIndicesFindsInvalid) { TF_ASSERT_OK(SparseTensor::Create(ix, vals, shape, order, &st)); st.Reorder(order); - Status st_indices_valid = st.IndicesValid(); + absl::Status st_indices_valid = st.IndicesValid(); EXPECT_FALSE(st_indices_valid.ok()); EXPECT_EQ("indices[1] = [0,0,0] is repeated", st_indices_valid.message()); @@ -337,7 +337,7 @@ TEST(SparseTensorTest, SparseTensorCheckBoundaries) { ix_t(0, 0) = 11; ix.matrix() = ix_t; st.Reorder(order); - Status st_indices_valid = st.IndicesValid(); + absl::Status st_indices_valid = st.IndicesValid(); EXPECT_FALSE(st_indices_valid.ok()); // Error message references index 4 because of the call to Reorder. EXPECT_EQ("[11,0,0] is out of bounds: need 0 <= index < [10,10,10]", diff --git a/tensorflow/core/util/stats_calculator.h b/tensorflow/core/util/stats_calculator.h index 20c997ced374a7..7f188fe07c6133 100644 --- a/tensorflow/core/util/stats_calculator.h +++ b/tensorflow/core/util/stats_calculator.h @@ -33,6 +33,7 @@ namespace tensorflow { using tsl::Stat; using tsl::StatsCalculator; +using tsl::StatWithPercentiles; } // namespace tensorflow diff --git a/tensorflow/core/util/tensor_bundle/BUILD b/tensorflow/core/util/tensor_bundle/BUILD index 4ca9b222fb114b..9aad78e3890057 100644 --- a/tensorflow/core/util/tensor_bundle/BUILD +++ b/tensorflow/core/util/tensor_bundle/BUILD @@ -61,7 +61,7 @@ cc_library( "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/lib/io:buffered_file", + "@local_xla//xla/tsl/lib/io:buffered_file", "@local_xla//xla/tsl/util:byte_swap_array", ], ) diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc index 87038432ae9428..c97356202bcd93 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/base/call_once.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/lib/io/buffered_file.h" #include "xla/tsl/util/byte_swap_array.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.pb.h" @@ -54,7 +55,6 @@ limitations under the License. #include "tensorflow/core/util/tensor_bundle/byte_swap_tensor.h" #include "tensorflow/core/util/tensor_bundle/naming.h" #include "tensorflow/core/util/tensor_slice_util.h" -#include "tsl/lib/io/buffered_file.h" #ifdef PLATFORM_WINDOWS #undef DeleteFile diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.h b/tensorflow/core/util/tensor_bundle/tensor_bundle.h index ba1a4f7053aac6..e3d8bb590ce411 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.h +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.h @@ -72,6 +72,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/functional/function_ref.h" #include "absl/strings/string_view.h" +#include "xla/tsl/lib/io/buffered_file.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" @@ -87,7 +88,6 @@ limitations under the License. #include "tensorflow/core/platform/tstring.h" #include "tensorflow/core/protobuf/tensor_bundle.pb.h" #include "tensorflow/core/util/tensor_slice_set.h" -#include "tsl/lib/io/buffered_file.h" #include "tsl/platform/errors.h" namespace tensorflow { diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h index 0f1dc463a4eed3..46ed95309bfdec 100644 --- a/tensorflow/core/util/tensor_format.h +++ b/tensorflow/core/util/tensor_format.h @@ -520,9 +520,9 @@ std::string GetConvnetDataFormat2D3DAttrString(); // FORMAT_NCHW: (N, C, spatial); rank = spatial.size() + 2 // FORMAT_NCHW_VECT_C: (N, C, spatial, InnerC); rank = spatial.size() + 3 // FORMAT_NHWC_VECT_W: (N, spatial, C, InnerW); rank = spatial.size() + 3 -inline Status ShapeFromFormatWithStatus(TensorFormat format, int64_t N, - absl::Span spatial, - int64_t C, TensorShape* shape) { +inline absl::Status ShapeFromFormatWithStatus(TensorFormat format, int64_t N, + absl::Span spatial, + int64_t C, TensorShape* shape) { const int dims = GetTensorDimsFromSpatialDims(spatial.size(), format); absl::InlinedVector dim_sizes(dims); dim_sizes[GetTensorBatchDimIndex(dims, format)] = N; @@ -583,9 +583,9 @@ inline TensorShape ShapeFromFilterTensorFormat( } // Return a tensor shape of the specified 'format', and dimensions. -inline Status ShapeFromFormatWithStatus(TensorFormat format, int64_t N, - int64_t H, int64_t W, int64_t C, - TensorShape* shape) { +inline absl::Status ShapeFromFormatWithStatus(TensorFormat format, int64_t N, + int64_t H, int64_t W, int64_t C, + TensorShape* shape) { return ShapeFromFormatWithStatus(format, N, {H, W}, C, shape); } @@ -606,10 +606,10 @@ inline TensorShape ShapeFromFilterTensorFormat(FilterTensorFormat format, // Returns a copy of the specified tensor 'src_shape' converted from // 'src_format' to 'dst_format'. -inline Status ShapeFromFormatWithStatus(TensorFormat dst_format, - const TensorShape& src_shape, - TensorFormat src_format, - TensorShape* shape) { +inline absl::Status ShapeFromFormatWithStatus(TensorFormat dst_format, + const TensorShape& src_shape, + TensorFormat src_format, + TensorShape* shape) { if (src_format == dst_format) { *shape = src_shape; return absl::OkStatus(); diff --git a/tensorflow/core/util/tensor_ops_util.h b/tensorflow/core/util/tensor_ops_util.h index 123bf0f35d0702..ccb0c8ec85612f 100644 --- a/tensorflow/core/util/tensor_ops_util.h +++ b/tensorflow/core/util/tensor_ops_util.h @@ -31,7 +31,8 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; template -Status ZerosLikeTensor(OpKernelContext* ctx, const Tensor& x, Tensor* out) { +absl::Status ZerosLikeTensor(OpKernelContext* ctx, const Tensor& x, + Tensor* out) { AllocatorAttributes attr; if (x.dtype() == DT_VARIANT) { attr.set_on_host(true); @@ -69,8 +70,8 @@ Status ZerosLikeTensor(OpKernelContext* ctx, const Tensor& x, Tensor* out) { } template -Status BinaryAddTensors(OpKernelContext* ctx, const Tensor& a, const Tensor& b, - Tensor* out) { +absl::Status BinaryAddTensors(OpKernelContext* ctx, const Tensor& a, + const Tensor& b, Tensor* out) { if (a.dtype() == DT_INVALID) { *out = b; return absl::OkStatus(); diff --git a/tensorflow/core/util/tensor_slice_reader.cc b/tensorflow/core/util/tensor_slice_reader.cc index 8e84ca617072f9..6911b58a563a2b 100644 --- a/tensorflow/core/util/tensor_slice_reader.cc +++ b/tensorflow/core/util/tensor_slice_reader.cc @@ -70,12 +70,12 @@ class TensorSliceReaderTable : public TensorSliceReader::Table { }; } // namespace -Status OpenTableTensorSliceReader(const string& fname, - TensorSliceReader::Table** result) { +absl::Status OpenTableTensorSliceReader(const string& fname, + TensorSliceReader::Table** result) { *result = nullptr; Env* env = Env::Default(); std::unique_ptr f; - Status s = env->NewRandomAccessFile(fname, &f); + absl::Status s = env->NewRandomAccessFile(fname, &f); if (s.ok()) { uint64 file_size; s = env->GetFileSize(fname, &file_size); @@ -113,7 +113,7 @@ TensorSliceReader::TensorSliceReader(const string& filepattern, int preferred_shard) : filepattern_(filepattern), open_function_(std::move(open_function)) { VLOG(1) << "TensorSliceReader for " << filepattern; - Status s = Env::Default()->GetMatchingPaths(filepattern, &fnames_); + absl::Status s = Env::Default()->GetMatchingPaths(filepattern, &fnames_); if (!s.ok()) { status_ = errors::InvalidArgument( "Unsuccessful TensorSliceReader constructor: " @@ -151,7 +151,7 @@ void TensorSliceReader::LoadShard(int shard) const { const string fname = fnames_[shard]; VLOG(1) << "Reading meta data from file " << fname << "..."; Table* table; - Status s = open_function_(fname, &table); + absl::Status s = open_function_(fname, &table); if (!s.ok()) { status_ = errors::DataLoss("Unable to open table file ", fname, ": ", s.ToString()); @@ -233,7 +233,7 @@ bool TensorSliceReader::HasTensor(const string& name, TensorShape* shape, } } -Status TensorSliceReader::GetTensor( +absl::Status TensorSliceReader::GetTensor( const string& name, std::unique_ptr* out_tensor) const { DataType type; TensorShape shape; @@ -256,7 +256,7 @@ Status TensorSliceReader::GetTensor( } std::unique_ptr t(new tensorflow::Tensor); - Status s = tensorflow::Tensor::BuildTensor(type, shape, t.get()); + absl::Status s = tensorflow::Tensor::BuildTensor(type, shape, t.get()); if (!s.ok()) return s; for (const auto d : shape.dim_sizes()) { diff --git a/tensorflow/core/util/tensor_slice_reader.h b/tensorflow/core/util/tensor_slice_reader.h index 91ee7cd7d76ab9..be6a0b6f1c1f13 100644 --- a/tensorflow/core/util/tensor_slice_reader.h +++ b/tensorflow/core/util/tensor_slice_reader.h @@ -62,7 +62,7 @@ class TensorSliceReader { virtual ~Table(); virtual bool Get(const string& key, string* value) = 0; }; - typedef std::function OpenTableFunction; + typedef std::function OpenTableFunction; static constexpr int kLoadAllShards = -1; TensorSliceReader(const string& filepattern); @@ -78,7 +78,7 @@ class TensorSliceReader { int num_files() const { return sss_.size(); } // Get the status of the reader. - Status status() const { return status_; } + absl::Status status() const { return status_; } // Checks if the reader contains any slice of a tensor. In case the reader // does contain the tensor, if "shape" is not nullptr, fill "shape" with the @@ -101,8 +101,8 @@ class TensorSliceReader { // Returns value for one tensor. Only single slice checkpoints are supported // at the moment. - Status GetTensor(const string& name, - std::unique_ptr* out_tensor) const; + absl::Status GetTensor(const string& name, + std::unique_ptr* out_tensor) const; typedef std::unordered_map VarToShapeMap; typedef std::unordered_map VarToDataTypeMap; @@ -136,14 +136,14 @@ class TensorSliceReader { mutable bool all_shards_loaded_ = false; mutable std::vector> sss_; mutable std::unordered_map tensors_; - mutable Status status_; + mutable absl::Status status_; TensorSliceReader(const TensorSliceReader&) = delete; void operator=(const TensorSliceReader&) = delete; }; -Status OpenTableTensorSliceReader(const string& fname, - TensorSliceReader::Table** result); +absl::Status OpenTableTensorSliceReader(const string& fname, + TensorSliceReader::Table** result); template bool TensorSliceReader::CopySliceData(const string& name, @@ -187,7 +187,7 @@ bool TensorSliceReader::CopySliceData(const string& name, } // Ensure the TensorSlice contains the expected amount of data. TensorShape shp_s; - Status s = slice_s.SliceTensorShape(tss->shape(), &shp_s); + absl::Status s = slice_s.SliceTensorShape(tss->shape(), &shp_s); if (!s.ok()) { VLOG(1) << "Failed to slice tensor " << name << ", slice " << slice_s.DebugString() << ": " << s; diff --git a/tensorflow/core/util/tensor_slice_reader_cache.h b/tensorflow/core/util/tensor_slice_reader_cache.h index 59426f97de19d7..cc85b7f2e353dd 100644 --- a/tensorflow/core/util/tensor_slice_reader_cache.h +++ b/tensorflow/core/util/tensor_slice_reader_cache.h @@ -67,7 +67,8 @@ class TensorSliceReaderCache { private: // Need to use a regular function type in the key map as std::function does // not support ==. - typedef Status (*OpenFuncType)(const string&, TensorSliceReader::Table**); + typedef absl::Status (*OpenFuncType)(const string&, + TensorSliceReader::Table**); // Protects attributes below. mutex mu_; diff --git a/tensorflow/core/util/tensor_slice_set.cc b/tensorflow/core/util/tensor_slice_set.cc index 176bc1bcc25ab9..1396a27abc0262 100644 --- a/tensorflow/core/util/tensor_slice_set.cc +++ b/tensorflow/core/util/tensor_slice_set.cc @@ -33,7 +33,8 @@ TensorSliceSet::TensorSliceSet(const TensorShape& shape, DataType type) TensorSliceSet::~TensorSliceSet() = default; -Status TensorSliceSet::Register(const TensorSlice& slice, const string& tag) { +absl::Status TensorSliceSet::Register(const TensorSlice& slice, + const string& tag) { TensorShape result_shape; TF_RETURN_IF_ERROR(slice.SliceTensorShape(shape_, &result_shape)); string str = slice.DebugString(); @@ -64,7 +65,7 @@ bool TensorSliceSet::QueryMeta( const TensorSlice& slice, std::vector>* results) const { results->clear(); - Status s; + absl::Status s; string str = slice.DebugString(); // First we check if there is an exactly match (this is the dominant case). const TensorSliceSet::SliceInfo* info = gtl::FindOrNull(slices_, str); @@ -79,7 +80,7 @@ bool TensorSliceSet::QueryMeta( // intersections cover the entire slice. We rely on the fact that the // existing slices don't have any intersection among themselves. TensorShape target_shape; - Status s; + absl::Status s; s = slice.SliceTensorShape(shape_, &target_shape); if (!s.ok()) { LOG(WARNING) << s; @@ -112,7 +113,7 @@ bool TensorSliceSet::QueryMeta( } } -Status RegisterTensorSlice( +absl::Status RegisterTensorSlice( const string& name, const TensorShape& shape, DataType type, const string& tag, const TensorSlice& slice, std::unordered_map* tensor_slices) { diff --git a/tensorflow/core/util/tensor_slice_set.h b/tensorflow/core/util/tensor_slice_set.h index 4887321de959f6..f7b3d08db584f1 100644 --- a/tensorflow/core/util/tensor_slice_set.h +++ b/tensorflow/core/util/tensor_slice_set.h @@ -47,7 +47,7 @@ class TensorSliceSet { // associated with the slice (in one application it denotes the name of the // file that contains the slice); the "data" points to the data of the tensor // slice (it can be a nullptr). - Status Register(const TensorSlice& slice, const string& tag); + absl::Status Register(const TensorSlice& slice, const string& tag); // Alternative way of querying about a new slice: instead of copying the // data, it returns a list of meta data about the stored slices that will @@ -82,7 +82,7 @@ class TensorSliceSet { // "name". Other arguments are used for validations. Does not modify the map // or its values on non-OK. // REQUIRES: tensor_slices != nullptr -Status RegisterTensorSlice( +absl::Status RegisterTensorSlice( const string& name, const TensorShape& shape, DataType type, const string& tag, const TensorSlice& slice, std::unordered_map* tensor_slices); diff --git a/tensorflow/core/util/tensor_slice_util.h b/tensorflow/core/util/tensor_slice_util.h index 7507d5ae6135ca..b58ecf55116339 100644 --- a/tensorflow/core/util/tensor_slice_util.h +++ b/tensorflow/core/util/tensor_slice_util.h @@ -151,7 +151,7 @@ static bool CopyDataFromTensorSliceToTensorSlice(const TensorShape& shape, // We need to compute the applied shapes after applying slice_s and // slice_d. TensorShape shp_s, shp_d; - Status s; + absl::Status s; s = slice_s.SliceTensorShape(shape, &shp_s); if (!s.ok()) { LOG(WARNING) << s; diff --git a/tensorflow/core/util/tensor_slice_writer.cc b/tensorflow/core/util/tensor_slice_writer.cc index e06e18c46a8ea3..35fd86b5a86af9 100644 --- a/tensorflow/core/util/tensor_slice_writer.cc +++ b/tensorflow/core/util/tensor_slice_writer.cc @@ -44,9 +44,9 @@ class TableBuilder : public TensorSliceWriter::Builder { void Add(StringPiece key, StringPiece val) override { builder_->Add(key, val); } - Status Finish(int64_t* file_size) override { + absl::Status Finish(int64_t* file_size) override { *file_size = -1; - Status s = builder_->Finish(); + absl::Status s = builder_->Finish(); if (s.ok()) { s = file_->Close(); if (s.ok()) { @@ -69,11 +69,11 @@ class TableBuilder : public TensorSliceWriter::Builder { }; } // anonymous namespace -Status CreateTableTensorSliceBuilder(const string& name, - TensorSliceWriter::Builder** builder) { +absl::Status CreateTableTensorSliceBuilder( + const string& name, TensorSliceWriter::Builder** builder) { *builder = nullptr; std::unique_ptr f; - Status s = Env::Default()->NewWritableFile(name, &f); + absl::Status s = Env::Default()->NewWritableFile(name, &f); if (s.ok()) { *builder = new TableBuilder(name, f.release()); return absl::OkStatus(); @@ -88,7 +88,7 @@ TensorSliceWriter::TensorSliceWriter(const string& filename, create_builder_(std::move(create_builder)), slices_(0) { Env* env = Env::Default(); - Status status = env->CanCreateTempFile(filename_, &use_temp_file_); + absl::Status status = env->CanCreateTempFile(filename_, &use_temp_file_); if (!status.ok()) { LOG(ERROR) << "Failed to get CanCreateTempFile attribute: " << filename_; use_temp_file_ = true; @@ -103,9 +103,9 @@ TensorSliceWriter::TensorSliceWriter(const string& filename, versions->set_min_consumer(TF_CHECKPOINT_VERSION_MIN_CONSUMER); } -Status TensorSliceWriter::Finish() { +absl::Status TensorSliceWriter::Finish() { Builder* b; - Status s = create_builder_(data_filename_, &b); + absl::Status s = create_builder_(data_filename_, &b); if (!s.ok()) { delete b; return s; @@ -199,8 +199,8 @@ size_t TensorSliceWriter::MaxBytesPerElementOrZero(DataType dt) { } template <> -Status TensorSliceWriter::SaveData(const tstring* data, int64_t num_elements, - SavedSlice* ss) { +absl::Status TensorSliceWriter::SaveData(const tstring* data, + int64_t num_elements, SavedSlice* ss) { size_t size_bound = ss->ByteSize() + kTensorProtoHeaderBytes + (num_elements * MaxBytesPerElement(DT_INT32)); for (int64_t i = 0; i < num_elements; ++i) { diff --git a/tensorflow/core/util/tensor_slice_writer.h b/tensorflow/core/util/tensor_slice_writer.h index 1409518f965621..bd13b55d6de471 100644 --- a/tensorflow/core/util/tensor_slice_writer.h +++ b/tensorflow/core/util/tensor_slice_writer.h @@ -49,9 +49,10 @@ class TensorSliceWriter { public: virtual ~Builder() = default; virtual void Add(StringPiece key, StringPiece value) = 0; - virtual Status Finish(int64_t* file_size) = 0; + virtual absl::Status Finish(int64_t* file_size) = 0; }; - typedef std::function CreateBuilderFunction; + typedef std::function + CreateBuilderFunction; TensorSliceWriter(const string& filename, CreateBuilderFunction create_builder); @@ -59,14 +60,15 @@ class TensorSliceWriter { // Adds a slice. We support float and int32 for now. // TODO(yangke): add more supports template - Status Add(const string& name, const TensorShape& shape, - const TensorSlice& slice, const T* data); - Status Finish(); + absl::Status Add(const string& name, const TensorShape& shape, + const TensorSlice& slice, const T* data); + absl::Status Finish(); // Allocate "num_elements" elements in "ss" and save the data in "data" // there. template - static Status SaveData(const T* data, int64_t num_elements, SavedSlice* ss); + static absl::Status SaveData(const T* data, int64_t num_elements, + SavedSlice* ss); static size_t MaxBytesPerElement(DataType dt); @@ -102,8 +104,9 @@ class TensorSliceWriter { }; template -Status TensorSliceWriter::Add(const string& name, const TensorShape& shape, - const TensorSlice& slice, const T* data) { +absl::Status TensorSliceWriter::Add(const string& name, + const TensorShape& shape, + const TensorSlice& slice, const T* data) { // The tensor and the slice have to be compatible if (shape.dims() != slice.dims()) { return errors::Internal("Incompatible tensor shape and slice: ", "shape = ", @@ -167,8 +170,8 @@ Status TensorSliceWriter::Add(const string& name, const TensorShape& shape, } template -Status TensorSliceWriter::SaveData(const T* data, int64_t num_elements, - SavedSlice* ss) { +absl::Status TensorSliceWriter::SaveData(const T* data, int64_t num_elements, + SavedSlice* ss) { size_t max_bytes_per_element = MaxBytesPerElementOrZero(DataTypeToEnum::value); if (max_bytes_per_element == 0) { @@ -190,15 +193,15 @@ Status TensorSliceWriter::SaveData(const T* data, int64_t num_elements, } template <> -Status TensorSliceWriter::SaveData(const tstring* data, int64_t num_elements, - SavedSlice* ss); +absl::Status TensorSliceWriter::SaveData(const tstring* data, + int64_t num_elements, SavedSlice* ss); // Create a table builder that will write to "filename" in // tensorflow::io::Table format. If successful, return OK // and set "*builder" to the allocated builder. Otherwise, return a // non-OK status. -Status CreateTableTensorSliceBuilder(const string& filename, - TensorSliceWriter::Builder** builder); +absl::Status CreateTableTensorSliceBuilder( + const string& filename, TensorSliceWriter::Builder** builder); } // namespace checkpoint diff --git a/tensorflow/core/util/tensor_slice_writer_test.cc b/tensorflow/core/util/tensor_slice_writer_test.cc index 65b636ea36ac0e..4aa0948db78687 100644 --- a/tensorflow/core/util/tensor_slice_writer_test.cc +++ b/tensorflow/core/util/tensor_slice_writer_test.cc @@ -346,7 +346,7 @@ TEST(TensorSliceWriteTest, SizeErrors) { TensorShape shape({300, 1000000}); TensorSlice slice = TensorSlice::ParseOrDie("-:-"); const std::vector data(300000000, -1); - Status s = writer.Add("test1", shape, slice, data.data()); + absl::Status s = writer.Add("test1", shape, slice, data.data()); EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); EXPECT_TRUE(absl::StrContains(s.message(), "Tensor slice is too large to serialize")); @@ -357,7 +357,7 @@ TEST(TensorSliceWriteTest, SizeErrors) { TensorShape shape({256, 1024}); TensorSlice slice = TensorSlice::ParseOrDie("-:-"); const std::vector data(256 * 1024, std::string(8192, 'f')); - Status s = writer.Add("test2", shape, slice, data.data()); + absl::Status s = writer.Add("test2", shape, slice, data.data()); EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); EXPECT_TRUE(absl::StrContains(s.message(), "Tensor slice is too large to serialize")); @@ -368,7 +368,7 @@ TEST(TensorSliceWriterTest, InvalidInput) { SavedSlice ss; std::array data; std::fill(data.begin(), data.end(), 1234); - Status s = TensorSliceWriter::SaveData(data.data(), data.size(), &ss); + absl::Status s = TensorSliceWriter::SaveData(data.data(), data.size(), &ss); EXPECT_EQ(s.code(), error::INVALID_ARGUMENT); EXPECT_TRUE(absl::StrContains( s.message(), "Tensor slice serialization not implemented for dtype")); diff --git a/tensorflow/core/util/util.cc b/tensorflow/core/util/util.cc index 05f5d0f9636d04..e197f0cf90c86c 100644 --- a/tensorflow/core/util/util.cc +++ b/tensorflow/core/util/util.cc @@ -189,4 +189,14 @@ bool IsAMXDataTypeSupportedByOneDNNOnThisCPU(const DataType& dt) { return result; } +// Check if oneDNN supports AVX-NE-CONVERT on CPU +bool IsAVXConvertSupportedByOneDNNOnThisCPU() { + bool result = false; +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + using port::TestCPUFeature; + result = TestCPUFeature(port::CPUFeature::AVX_NE_CONVERT); +#endif // INTEL_MKL && ENABLE_ONEDNN_V3 + return result; +} + } // namespace tensorflow diff --git a/tensorflow/core/util/util.h b/tensorflow/core/util/util.h index 8bcbba6c6cf52b..701c423045da8f 100644 --- a/tensorflow/core/util/util.h +++ b/tensorflow/core/util/util.h @@ -71,6 +71,8 @@ bool IsDataTypeSupportedByOneDNNOnThisCPU(const DataType& dt); // Check if input type supports AMX on CPU when oneDNN is enabled bool IsAMXDataTypeSupportedByOneDNNOnThisCPU(const DataType& dt); +bool IsAVXConvertSupportedByOneDNNOnThisCPU(); + } // namespace tensorflow #endif // TENSORFLOW_CORE_UTIL_UTIL_H_ diff --git a/tensorflow/distribute/experimental/rpc/kernels/BUILD b/tensorflow/distribute/experimental/rpc/kernels/BUILD index 611339854ec09c..a71c8e92418f60 100644 --- a/tensorflow/distribute/experimental/rpc/kernels/BUILD +++ b/tensorflow/distribute/experimental/rpc/kernels/BUILD @@ -20,7 +20,6 @@ cc_library( "//tensorflow/distribute/experimental/rpc/proto:tf_rpc_service_cc_grpc_proto", "//tensorflow/distribute/experimental/rpc/proto:tf_rpc_service_proto_cc", "@com_github_grpc_grpc//:grpc++", - "@local_xla//xla/stream_executor/platform", ], alwayslink = 1, ) @@ -66,7 +65,6 @@ tf_kernel_library( "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@local_xla//xla/stream_executor/platform", ], alwayslink = 1, ) diff --git a/tensorflow/distribute/experimental/rpc/kernels/grpc_rpc_service.h b/tensorflow/distribute/experimental/rpc/kernels/grpc_rpc_service.h index 004b4f294c67ea..c479234fc1dd30 100644 --- a/tensorflow/distribute/experimental/rpc/kernels/grpc_rpc_service.h +++ b/tensorflow/distribute/experimental/rpc/kernels/grpc_rpc_service.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef TENSORFLOW_DISTRIBUTE_EXPERIMENTAL_RPC_KERNELS_GRPC_RPC_SERVICE_H_ #define TENSORFLOW_DISTRIBUTE_EXPERIMENTAL_RPC_KERNELS_GRPC_RPC_SERVICE_H_ -#include "xla/stream_executor/platform/port.h" #include "tensorflow/distribute/experimental/rpc/proto/tf_rpc_service.grpc.pb.h" #include "tensorflow/distribute/experimental/rpc/proto/tf_rpc_service.pb.h" diff --git a/tensorflow/distribute/experimental/rpc/kernels/rpc_ops.cc b/tensorflow/distribute/experimental/rpc/kernels/rpc_ops.cc index 2adcc2d649d882..bfc62f173ecf55 100644 --- a/tensorflow/distribute/experimental/rpc/kernels/rpc_ops.cc +++ b/tensorflow/distribute/experimental/rpc/kernels/rpc_ops.cc @@ -187,12 +187,11 @@ class FunctionRegistry { return debug_string; } - tensorflow::Status Register(const std::string& method, - FunctionLibraryRuntime* lib, - FunctionLibraryRuntime::Handle fn_handle, - std::vector captured_inputs, - const StructuredValue& input_specs, - const StructuredValue& output_specs) { + absl::Status Register(const std::string& method, FunctionLibraryRuntime* lib, + FunctionLibraryRuntime::Handle fn_handle, + std::vector captured_inputs, + const StructuredValue& input_specs, + const StructuredValue& output_specs) { mutex_lock l(mu_); FunctionMetadata fn_metadata; fn_metadata.handle = fn_handle; @@ -209,8 +208,8 @@ class FunctionRegistry { return absl::OkStatus(); } - tensorflow::Status LookUp(const std::string& method, - FunctionMetadata* output) const { + absl::Status LookUp(const std::string& method, + FunctionMetadata* output) const { mutex_lock l(mu_); auto it = registered_methods_.find(method); if (it == registered_methods_.end()) { @@ -273,18 +272,19 @@ class RpcServiceImpl : public grpc::RpcService::Service { std::vector* rets = new std::vector; Notification notification; - fn_lib->Run(opts, handle, args, rets, - [rets, response, ¬ification, &status](const Status& st) { - status = st; - if (status.ok()) { - for (size_t i = 0; i < rets->size(); ++i) { - auto t = response->add_output_tensors(); - (*rets)[i].AsProtoField(t); - } - } - delete rets; - notification.Notify(); - }); + fn_lib->Run( + opts, handle, args, rets, + [rets, response, ¬ification, &status](const absl::Status& st) { + status = st; + if (status.ok()) { + for (size_t i = 0; i < rets->size(); ++i) { + auto t = response->add_output_tensors(); + (*rets)[i].AsProtoField(t); + } + } + delete rets; + notification.Notify(); + }); notification.WaitForNotification(); return ToGrpcStatus(status); @@ -327,12 +327,11 @@ class RpcServer : public ResourceBase { return absl::StrCat("RpcServer resource with ", registry_.DebugString()); } - tensorflow::Status Register(const std::string& method, - FunctionLibraryRuntime* lib, - FunctionLibraryRuntime::Handle fn_handle, - std::vector captured_inputs, - const StructuredValue& input_specs, - const StructuredValue& output_specs) { + absl::Status Register(const std::string& method, FunctionLibraryRuntime* lib, + FunctionLibraryRuntime::Handle fn_handle, + std::vector captured_inputs, + const StructuredValue& input_specs, + const StructuredValue& output_specs) { mutex_lock m(mu_); if (server_started_) { return tensorflow::errors::FailedPrecondition( @@ -466,7 +465,7 @@ class RpcClient : public ResourceBase { }; class RpcFutureResource : public ResourceBase { - typedef std::function + typedef std::function FutureCallBack; public: @@ -490,20 +489,20 @@ class RpcFutureResource : public ResourceBase { done_ = true; } - void set_status(Status status) { status_.Update(status); } - Status get_status() { return status_; } + void set_status(absl::Status status) { status_.Update(status); } + absl::Status get_status() { return status_; } CallResponse* get_response() { return &response_; } private: CallResponse response_; bool done_ TF_GUARDED_BY(mu_); - Status status_; + absl::Status status_; std::vector call_backs_ TF_GUARDED_BY(mu_); mutable mutex mu_; }; -Status ExtractServerAddressFromInput(OpKernelContext* ctx, - std::string* address) { +absl::Status ExtractServerAddressFromInput(OpKernelContext* ctx, + std::string* address) { const Tensor* server_address; auto status = ctx->input("server_address", &server_address); if (status.ok()) { @@ -593,7 +592,7 @@ void RpcClientOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { } auto* response = new ListResponse(); client->ListAsync( - response, [ctx, response, done](const Status& status) { + response, [ctx, response, done](const absl::Status& status) { if (!status.ok()) { ctx->SetStatus(status); } else { @@ -754,7 +753,7 @@ void RpcCallOp::Compute(OpKernelContext* ctx) { client->CallAsync( method, args, response, - [future_resource_ptr](const Status& status) { + [future_resource_ptr](const absl::Status& status) { future_resource_ptr->set_status(status); future_resource_ptr->OperationFinished(); future_resource_ptr->Unref(); @@ -784,7 +783,8 @@ void RpcCheckStatusOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { } future_resource->AddDoneCallback( - [ctx, done, handle](const Status& status, const CallResponse& response) { + [ctx, done, handle](const absl::Status& status, + const CallResponse& response) { Tensor error_code(DT_INT64, TensorShape({})), error_message(DT_STRING, TensorShape({})); error_code.scalar()() = status.raw_code(); @@ -818,7 +818,8 @@ void RpcGetValueOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { } future_resource->AddDoneCallback( - [ctx, done, handle](const Status& status, const CallResponse& response) { + [ctx, done, handle](const absl::Status& status, + const CallResponse& response) { if (!status.ok()) { ctx->SetStatus(status); } else { diff --git a/tensorflow/distribute/experimental/rpc/proto/BUILD b/tensorflow/distribute/experimental/rpc/proto/BUILD index 097b4b38797619..6acb10c5140d30 100644 --- a/tensorflow/distribute/experimental/rpc/proto/BUILD +++ b/tensorflow/distribute/experimental/rpc/proto/BUILD @@ -28,7 +28,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "tf_rpc_service_py_pb2", -# api_version = 2, # deps = [":tf_rpc_service_proto"], # ) # copybara:uncomment_end diff --git a/tensorflow/dtensor/cc/BUILD b/tensorflow/dtensor/cc/BUILD index cf69454562b81e..5a10958f6dd1ef 100644 --- a/tensorflow/dtensor/cc/BUILD +++ b/tensorflow/dtensor/cc/BUILD @@ -267,8 +267,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:convert_type", "//tensorflow/compiler/mlir/tensorflow:device_util", "//tensorflow/compiler/mlir/tensorflow:error_util", - "//tensorflow/compiler/mlir/tensorflow:import_model", - "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util", + "//tensorflow/compiler/mlir/tf2xla/api/v2:graph_to_tf_executor", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", diff --git a/tensorflow/dtensor/cc/dstatus.h b/tensorflow/dtensor/cc/dstatus.h index 01d85ca2eaad7b..62de9e96bc80b8 100644 --- a/tensorflow/dtensor/cc/dstatus.h +++ b/tensorflow/dtensor/cc/dstatus.h @@ -33,13 +33,14 @@ namespace dtensor { template using StatusOr = tsl::StatusOr; -inline Status WithContext(const Status& ds, absl::string_view file, - int line_number, absl::string_view context = "") { +inline absl::Status WithContext(const absl::Status& ds, absl::string_view file, + int line_number, + absl::string_view context = "") { if (ds.ok()) { return ds; } - return Status(ds.code(), absl::StrCat(ds.message(), "\n", file, ":", - line_number, " :: ", context)); + return absl::Status(ds.code(), absl::StrCat(ds.message(), "\n", file, ":", + line_number, " :: ", context)); } template @@ -49,9 +50,9 @@ inline StatusOr WithContext(StatusOr&& ds, absl::string_view file, if (ds.ok()) { return ds; } - return Status(ds.status().code(), - absl::StrCat(ds.status().message(), "\n", file, ":", - line_number, " :: ", context)); + return absl::Status(ds.status().code(), + absl::StrCat(ds.status().message(), "\n", file, ":", + line_number, " :: ", context)); } #define DT_CTX(dstatus, ...) \ diff --git a/tensorflow/dtensor/cc/dtensor_device.cc b/tensorflow/dtensor/cc/dtensor_device.cc index 6600a6b23ebd9d..3fb0bf78a22785 100644 --- a/tensorflow/dtensor/cc/dtensor_device.cc +++ b/tensorflow/dtensor/cc/dtensor_device.cc @@ -232,8 +232,8 @@ class DTensorDevice { default_mesh_ = global_default_mesh_; } - Status SetTPUCoreIDs(const std::string& mesh_name, - const std::vector& tpu_core_ids) { + absl::Status SetTPUCoreIDs(const std::string& mesh_name, + const std::vector& tpu_core_ids) { if (VLOG_IS_ON(1)) { LOG(INFO) << "Setting TPU core IDs for " << (mesh_name.empty() ? "default mesh" : mesh_name) << ": "; @@ -871,7 +871,7 @@ StatusOr FetchAttributes(const TFE_OpAttrs* attributes) { if (TF_GetCode(status) == TF_OK) { TF_DeleteStatus(status); } else { - Status failure_status = StatusFromTF_Status(status); + absl::Status failure_status = StatusFromTF_Status(status); TF_DeleteStatus(status); return failure_status; } @@ -979,7 +979,7 @@ TFE_TensorHandle* DTensorDevice::ToTensorHandle(TFE_Context* context, TF_SetStatus(status, TF_INTERNAL, tensor.status().ToString().c_str()); return tensor_handle; } - Status tf_tensor_from_tensor_status; + absl::Status tf_tensor_from_tensor_status; TF_Tensor* tf_tensor = TF_TensorFromTensor(*tensor, &tf_tensor_from_tensor_status); if (!tf_tensor_from_tensor_status.ok()) { @@ -1518,7 +1518,7 @@ StatusOr> SelectGraphToExecute( // Adds processed graph to run for each mesh computation in // `execution_functions` to function definition library. -Status AddExecutionFunctionDefsToFunctionDefLibrary( +absl::Status AddExecutionFunctionDefsToFunctionDefLibrary( const std::string doperation_name, const StackTracesMap& stack_traces, const absl::flat_hash_set& control_ret_nodes, TFE_Context* context, const Graph& graph, ExecutionFunctions* execution_functions) { @@ -2070,7 +2070,7 @@ void DTensorDevice::ExecuteRegularOperation( // for DeviceId. This is done as the first arg is always DeviceId, and it // isn't mapped to input Tensors. const int resource_index_to_update = entry.first - 1; - const Status s = + const absl::Status s = llvm::cast(inputs[resource_index_to_update]) ->UpdateLayout(entry.second); if (!s.ok()) { @@ -2223,7 +2223,8 @@ void DTensorDevice::ExecuteRegularOperation( for (int i = 0; i < result->size(); ++i) { auto& result_tensor = (*result)[i]; const std::vector* result_tensor_shape; - Status shape_status = result_tensor->Shape(&result_tensor_shape); + absl::Status shape_status = + result_tensor->Shape(&result_tensor_shape); if (!shape_status.ok()) { Set_TF_Status_from_Status(status, shape_status); return; @@ -2588,7 +2589,7 @@ bool DTensorDevice::ShouldFastExecuteEagerPureOperation( // Fetch the ops registy to get the output dtype for the op. Certain dtypes // like string are not supported by the broadcast. const OpDef* op_def = nullptr; - Status status = + absl::Status status = OpRegistry::Global()->LookUpOpDef(dtensor_operation.name, &op_def); if (!status.ok()) return false; for (const auto& output_arg : op_def->output_arg()) { diff --git a/tensorflow/dtensor/cc/dtensor_device_util.cc b/tensorflow/dtensor/cc/dtensor_device_util.cc index 8bb49f31d64d64..1457d890b85895 100644 --- a/tensorflow/dtensor/cc/dtensor_device_util.cc +++ b/tensorflow/dtensor/cc/dtensor_device_util.cc @@ -203,7 +203,7 @@ std::unique_ptr BroadcastResourceTensor( PartialTensorShape partial_shape = r->dtypes_and_shapes().begin()->shape; // Set the shape/type of the tensor that the resource points to // so that the graph has correct shape/type information that we can use. - const Status s = + const absl::Status s = llvm::cast((*result).get()) ->UpdateShapeAndDType(partial_shape.AsProto(), r->dtypes_and_shapes().begin()->dtype); @@ -239,9 +239,9 @@ bool LayoutsAreCompatible(std::optional first_layout, } // Parse a pair of attribute of (indices, layouts) into a map. -Status ParseAttrMap(const Node& node, absl::string_view indices_attr, - absl::string_view layout_attr, - std::map* indices_layout_map) { +absl::Status ParseAttrMap(const Node& node, absl::string_view indices_attr, + absl::string_view layout_attr, + std::map* indices_layout_map) { std::vector layouts; if (!TryGetNodeAttr(node.attrs(), layout_attr, &layouts)) { return absl::OkStatus(); @@ -265,14 +265,14 @@ Status ParseAttrMap(const Node& node, absl::string_view indices_attr, return absl::OkStatus(); } -Status ParseResourceArgumentLayouts( +absl::Status ParseResourceArgumentLayouts( const Node& node, std::map* inferred_resource_input_layouts) { return ParseAttrMap(node, kNewResourceLayoutIndices, kNewResourceArgLayouts, inferred_resource_input_layouts); } -Status ParseShapeInputLayouts(const Node& node, - std::map* shape_output_metadata) { +absl::Status ParseShapeInputLayouts( + const Node& node, std::map* shape_output_metadata) { return ParseAttrMap(node, kShapeOpInputLayoutIndices, kShapeOpInputLayout, shape_output_metadata); } @@ -325,7 +325,7 @@ StatusOr> GetTensorShapeAsVector( StatusOr> GetTensorShapeAsVector( TFE_TensorHandle* tensor) { tensorflow::PartialTensorShape shape; - const Status status = tensorflow::unwrap(tensor)->Shape(&shape); + const absl::Status status = tensorflow::unwrap(tensor)->Shape(&shape); if (status.ok()) { return GetTensorShapeAsVector(shape); } else { @@ -516,7 +516,7 @@ StatusOr SummarizeValues( std::string TensorWithLayoutTf::SummarizeValue() const { std::string value_summary; - Status status; + absl::Status status; if (layout_.IsSingleDevice() || layout_.IsFullyReplicated()) { status = tensorflow::unwrap(tensors_[0].get())->SummarizeValue(value_summary); @@ -632,9 +632,9 @@ std::string SparseTensorWithLayout::SummarizeValue() const { std::string values_summary; std::string dense_shapes_summary; - Status indices_status; - Status values_status; - Status dense_shapes_status; + absl::Status indices_status; + absl::Status values_status; + absl::Status dense_shapes_status; if (layout().IsFullyReplicated()) { indices_status = tensorflow::unwrap(indices_->tensor(0)) @@ -724,12 +724,12 @@ StatusOr ExecutableManager::ShouldFoldInput( "manager)"); } -Status InferOutputLayouts(const DTensorOperation& doperation, - const NameAttrList& attributes, - const std::optional& default_layout, - tensorflow::Graph* graph, - std::vector* output_layouts) { - tensorflow::Status status; +absl::Status InferOutputLayouts(const DTensorOperation& doperation, + const NameAttrList& attributes, + const std::optional& default_layout, + tensorflow::Graph* graph, + std::vector* output_layouts) { + absl::Status status; tensorflow::NodeDef op_node_def; op_node_def.set_op(doperation.name); op_node_def.set_name("eager_operation"); @@ -756,7 +756,7 @@ Status InferOutputLayouts(const DTensorOperation& doperation, return absl::OkStatus(); } -Status PrepareGraphForMlir( +absl::Status PrepareGraphForMlir( const ExecutableManager>& module_manager, const std::vector& inputs, const DTensorOperation& doperation, @@ -770,7 +770,7 @@ Status PrepareGraphForMlir( // determine default layouts. ShapeRefiner shape_refiner(TF_GRAPH_DEF_VERSION, &flib_def); shape_refiner.set_function_library_for_shape_inference(&flib_def); - tensorflow::Status status; + absl::Status status; { // We include an _Arg node for the device ID, but this isn't used by the // initial function. It will be provided a value, though, so it's @@ -962,7 +962,7 @@ StatusOr> GetNumLocalOutputs(Node* node) { } namespace { -Status SetMultiDeviceFunctionOutputs( +absl::Status SetMultiDeviceFunctionOutputs( TranslatedFunction& function, Node* node, const std::vector& global_output_shapes) { const AttrValue* serialized_layouts = (node->attrs()).Find(kLayoutAttr); @@ -1107,11 +1107,12 @@ StatusOr IdentifyAllFunctionsToExecute( // be dropped during MLIR lowering. // TODO(b/171265131): fix the underlying issue to avoid inserting identity // nodes. -Status MaybeInsertIdentityNodes(const FunctionDef* function_def, Graph* graph) { +absl::Status MaybeInsertIdentityNodes(const FunctionDef* function_def, + Graph* graph) { if (function_def == nullptr || function_def->control_ret().empty()) { return absl::OkStatus(); } - tensorflow::Status status; + absl::Status status; for (Node* n : graph->nodes()) { if (!n->IsRetval()) { continue; diff --git a/tensorflow/dtensor/cc/dtensor_device_util.h b/tensorflow/dtensor/cc/dtensor_device_util.h index 604d33f9691cd3..97ddf29823a7e0 100644 --- a/tensorflow/dtensor/cc/dtensor_device_util.h +++ b/tensorflow/dtensor/cc/dtensor_device_util.h @@ -580,14 +580,14 @@ StatusOr> GetTensorShapeAsVector( // Returns the shape of a given tensor. StatusOr> GetTensorShapeAsVector(TFE_TensorHandle* tensor); -Status InferOutputLayouts(const DTensorOperation& doperation, - const NameAttrList& attributes, - const std::optional& default_layout, - tensorflow::Graph* graph, - std::vector* output_layouts); +absl::Status InferOutputLayouts(const DTensorOperation& doperation, + const NameAttrList& attributes, + const std::optional& default_layout, + tensorflow::Graph* graph, + std::vector* output_layouts); // Creates a Graph with _Arg and _Retval nodes surrounding an // `operation_name`-type node. -Status PrepareGraphForMlir( +absl::Status PrepareGraphForMlir( const ExecutableManager>& module_manager, const std::vector& inputs, const DTensorOperation& doperation, @@ -608,7 +608,8 @@ StatusOr IdentifyAllFunctionsToExecute( // be dropped during MLIR lowering. // TODO(b/171265131): fix the underlying issue to avoid inserting identity // nodes. -Status MaybeInsertIdentityNodes(const FunctionDef* function_def, Graph* graph); +absl::Status MaybeInsertIdentityNodes(const FunctionDef* function_def, + Graph* graph); // Add DTensor specific function attributes to be compatible with eager runtime. void AddDTensorFunctionAttr(FunctionDef& function_def); diff --git a/tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.cc b/tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.cc index 2281b31374e1e6..9c5a35278d9179 100644 --- a/tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.cc +++ b/tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.cc @@ -31,10 +31,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.h" #include "xla/status_macros.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/function.h" @@ -96,7 +96,8 @@ DTensorMlirPassRunner::ImportGraphToMlir( // Imports GraphDef to TF MLIR. absl::StatusOr> module_ref = - ConvertGraphToMlir(graph, debug_info, flib_def, import_config, &context_); + tensorflow::tf2xla::v2::ConvertGraphToTfExecutor( + graph, debug_info, flib_def, import_config, &context_); // Adds DTensor attributes to ModuleOp. mlir::ModuleOp module = module_ref.value().get(); @@ -140,7 +141,7 @@ DTensorMlirPassRunner::ImportGraphToMlir( return module_ref; } -Status DTensorMlirPassRunner::Run(mlir::ModuleOp module) { +absl::Status DTensorMlirPassRunner::Run(mlir::ModuleOp module) { // Executes and collects results from the passes. mlir::StatusScopedDiagnosticHandler diag_handler(&context_); diff --git a/tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.h b/tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.h index df033425e162c5..f9ea51b20b5481 100644 --- a/tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.h +++ b/tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.h @@ -42,7 +42,7 @@ class DTensorMlirPassRunner { Fprint128 cache_key); // Transforms input MLIR module with DTensor Pass pipeline. - Status Run(mlir::ModuleOp module); + absl::Status Run(mlir::ModuleOp module); private: // N.B. op_registration_ must be initialized before context/pass-manager to diff --git a/tensorflow/dtensor/cc/dtensor_meta_ops.cc b/tensorflow/dtensor/cc/dtensor_meta_ops.cc index b4a51359209044..7b3f04947a4717 100644 --- a/tensorflow/dtensor/cc/dtensor_meta_ops.cc +++ b/tensorflow/dtensor/cc/dtensor_meta_ops.cc @@ -52,7 +52,7 @@ REGISTER_OP("DTensorAllScatter") "int64, uint64, bool, string}") .Attr("input_layout: string") .Attr("output_layout: string") - .SetShapeFn([](shape_inference::InferenceContext* c) -> Status { + .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { shape_inference::ShapeHandle in = c->input(0); if (!c->RankKnown(in)) { // Input shape unknown, so set unknown output shape. @@ -112,7 +112,7 @@ REGISTER_OP("DTensorAllGather") "bool}") .Attr("input_layout: string") .Attr("output_layout: string") - .SetShapeFn([](shape_inference::InferenceContext* c) -> Status { + .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { shape_inference::ShapeHandle in = c->input(0); if (!c->RankKnown(in)) { // Input shape unknown, so set unknown output shape. @@ -168,7 +168,7 @@ REGISTER_OP("DTensorAllToAll") .Attr("T: {half, bfloat16, float, float64, int32, uint32, int64, bool}") .Attr("input_layout: string") .Attr("output_layout: string") - .SetShapeFn([](shape_inference::InferenceContext* c) -> Status { + .SetShapeFn([](shape_inference::InferenceContext* c) -> absl::Status { shape_inference::ShapeHandle in = c->input(0); if (!c->RankKnown(in)) { // Input shape unknown, so set unknown output shape. diff --git a/tensorflow/dtensor/cc/dtensor_operation.h b/tensorflow/dtensor/cc/dtensor_operation.h index 002ccaea8a60bf..406c7dfa627181 100644 --- a/tensorflow/dtensor/cc/dtensor_operation.h +++ b/tensorflow/dtensor/cc/dtensor_operation.h @@ -45,7 +45,7 @@ struct DTensorOperation { return false; } const OpDef* op_def = nullptr; - Status status = OpRegistry::Global()->LookUpOpDef(name, &op_def); + absl::Status status = OpRegistry::Global()->LookUpOpDef(name, &op_def); DCHECK(status.ok()); // Not found. This really shouldn't happen. if (!status.ok()) { return false; diff --git a/tensorflow/dtensor/cc/dtensor_tpu_kernels.cc b/tensorflow/dtensor/cc/dtensor_tpu_kernels.cc index ed228eecd30eaa..9b4eac2fb53059 100644 --- a/tensorflow/dtensor/cc/dtensor_tpu_kernels.cc +++ b/tensorflow/dtensor/cc/dtensor_tpu_kernels.cc @@ -52,10 +52,10 @@ namespace dtensor { // Returns OK if the deletion succeeded, or if the resource was not found. Else // return the deletion error. template -Status DeleteIfExists(ResourceMgr* resource_manager, - const char* resource_name) { +absl::Status DeleteIfExists(ResourceMgr* resource_manager, + const char* resource_name) { VLOG(1) << "Removing resource " << resource_name << " if it exists"; - Status status = resource_manager->Delete( + absl::Status status = resource_manager->Delete( resource_manager->default_container(), resource_name); if (status.ok()) { VLOG(1) << "Removed existing resource " << resource_name; @@ -129,9 +129,9 @@ class ConfigureAndInitializeGlobalTPUOpKernel : public OpKernel { bool use_tfrt_host_runtime_; - static Status InitializeInternal(OpKernelContext* ctx, ResourceMgr* rmgr, - absl::Duration retry_timeout, - std::vector* core_id_output_vec) { + static absl::Status InitializeInternal( + OpKernelContext* ctx, ResourceMgr* rmgr, absl::Duration retry_timeout, + std::vector* core_id_output_vec) { // Reset the TPU embedding engine interface if we are not the master. // We need to reset the interface before initializing the host because the // resetting process reset the TPU platform. @@ -220,7 +220,7 @@ class ConfigureAndInitializeGlobalTPUOpKernel : public OpKernel { tpu_mesh)); VLOG(1) << "Removing existing proto compilation cache lookup if it exists"; - Status resource_delete_status = + absl::Status resource_delete_status = rmgr->Delete( rmgr->default_container(), tpu::kCompiledProtoCacheResourceName); @@ -253,7 +253,7 @@ class ShutdownTPUSystemOpKernel : public OpKernel { void Compute(OpKernelContext* ctx) override { LOG(INFO) << "ShutdownTPUSystemOpKernel op"; - Status status; + absl::Status status; TpuSystemInterface* tpu_system = GetPreferredTpuSystem(); if (tpu_system == nullptr) { VLOG(1) << "Shutting down the default TPU system."; diff --git a/tensorflow/dtensor/cc/slice_util.cc b/tensorflow/dtensor/cc/slice_util.cc index f9f54b807897ab..a11d5583da0409 100644 --- a/tensorflow/dtensor/cc/slice_util.cc +++ b/tensorflow/dtensor/cc/slice_util.cc @@ -138,7 +138,7 @@ std::optional Token::GetLocalToken(int64_t dim_size, return std::nullopt; } -Status TokenProcessor::Run(const std::vector& tokens) { +absl::Status TokenProcessor::Run(const std::vector& tokens) { int64_t input_rank = input_rank_; int64_t output_rank; TF_ASSIGN_OR_RETURN(int64_t ellipsis_size, diff --git a/tensorflow/dtensor/cc/slice_util.h b/tensorflow/dtensor/cc/slice_util.h index c16a18227a9495..4d2d5c7250eae0 100644 --- a/tensorflow/dtensor/cc/slice_util.h +++ b/tensorflow/dtensor/cc/slice_util.h @@ -82,7 +82,7 @@ class TokenProcessor { explicit TokenProcessor(int64_t input_rank) : input_rank_(input_rank) {} virtual ~TokenProcessor() = default; - Status Run(const std::vector& tokens); + absl::Status Run(const std::vector& tokens); protected: // Loop for an ellipsis or the unconsumed axes in the end. @@ -105,7 +105,8 @@ class TokenProcessor { virtual void PrepareResults(int64_t spec_rank, int64_t input_rank, int64_t output_rank) = 0; - virtual Status FinalizeResults(int64_t input_rank, int64_t output_rank) = 0; + virtual absl::Status FinalizeResults(int64_t input_rank, + int64_t output_rank) = 0; private: const int64_t input_rank_; @@ -182,7 +183,8 @@ class ForwardLayoutInference : public TokenProcessor { expander_value_sharding_.reserve(output_rank); } - Status FinalizeResults(int64_t input_rank, int64_t output_rank) override { + absl::Status FinalizeResults(int64_t input_rank, + int64_t output_rank) override { DCHECK_EQ(expander_input_sharding_.size(), input_rank); DCHECK_EQ(expander_value_sharding_.size(), output_rank); TF_ASSIGN_OR_RETURN( @@ -282,7 +284,8 @@ class BackwardLayoutInference : public TokenProcessor { expander_value_sharding_.reserve(output_rank); } - Status FinalizeResults(int64_t input_rank, int64_t output_rank) override { + absl::Status FinalizeResults(int64_t input_rank, + int64_t output_rank) override { DCHECK_EQ(expander_input_sharding_.size(), input_rank); DCHECK_EQ(expander_value_sharding_.size(), output_rank); TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/dtensor/cc/tpu_system_interface.h b/tensorflow/dtensor/cc/tpu_system_interface.h index c14065d7023f9d..bfcebf34d0c4f8 100644 --- a/tensorflow/dtensor/cc/tpu_system_interface.h +++ b/tensorflow/dtensor/cc/tpu_system_interface.h @@ -37,12 +37,12 @@ class TpuSystemInterface { public: virtual ~TpuSystemInterface() = default; - virtual Status Initialize(OpKernelContext* ctx, ResourceMgr* rmgr, - absl::Duration retry_timeout, - std::vector* core_id_output_vec, - bool use_tfrt_host_runtime) = 0; + virtual absl::Status Initialize(OpKernelContext* ctx, ResourceMgr* rmgr, + absl::Duration retry_timeout, + std::vector* core_id_output_vec, + bool use_tfrt_host_runtime) = 0; - virtual Status Shutdown() = 0; + virtual absl::Status Shutdown() = 0; virtual std::vector> TPUCoreIDsToLocations( TFE_Context* context, const std::vector& tpu_core_ids) = 0; diff --git a/tensorflow/dtensor/mlir/BUILD b/tensorflow/dtensor/mlir/BUILD index f304d843096efb..754b85be360632 100644 --- a/tensorflow/dtensor/mlir/BUILD +++ b/tensorflow/dtensor/mlir/BUILD @@ -283,7 +283,7 @@ cc_library( "@llvm-project//mlir:Transforms", "@local_tsl//tsl/platform:status", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:sharding_builder", + "@local_xla//xla/hlo/builder:sharding_builder", ], alwayslink = 1, ) @@ -550,7 +550,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], alwayslink = True, ) diff --git a/tensorflow/dtensor/mlir/dtensor_allreduce_combine_optimization.cc b/tensorflow/dtensor/mlir/dtensor_allreduce_combine_optimization.cc index 55fc9c3258d818..c6e5061cd19354 100644 --- a/tensorflow/dtensor/mlir/dtensor_allreduce_combine_optimization.cc +++ b/tensorflow/dtensor/mlir/dtensor_allreduce_combine_optimization.cc @@ -598,7 +598,7 @@ createSubgroupsByTopoDist( // between two ops for (auto& all_reduce_group : all_reduce_groups) { std::vector new_group; - Status status = absl::OkStatus(); + absl::Status status = absl::OkStatus(); // Sort AllReduces by topological level as the input order may not reflect // their dependencies on the operands in the compute graph. diff --git a/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc b/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc index 572f0870374c7e..9928d5d2f37bf6 100644 --- a/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc +++ b/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc @@ -745,8 +745,9 @@ mlir::LogicalResult BuildOuterMainFunc( return mlir::success(); } -Status ExtractResultLayouts(mlir::Operation* op, mlir::func::ReturnOp return_op, - std::vector& expanded_results) { +absl::Status ExtractResultLayouts( + mlir::Operation* op, mlir::func::ReturnOp return_op, + std::vector& expanded_results) { if (!return_op || (return_op.getNumOperands() == 0)) { return absl::OkStatus(); } @@ -838,7 +839,7 @@ struct DTensorMultiDeviceExpansion return_op ? return_op->getNumOperands() : 0); for (const mlir::TF::StatefulPartitionedCallOp& stateful_call_op : stateful_call_ops) { - const Status status = + const absl::Status status = ExtractResultLayouts(stateful_call_op, return_op, expanded_results); const StatusOr> mesh = status.ok() ? ExtractDeviceMeshFromOp(stateful_call_op) : status; diff --git a/tensorflow/dtensor/mlir/expansions/concat_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/concat_spmd_expander.cc index 88791958dec709..fb55163108280c 100644 --- a/tensorflow/dtensor/mlir/expansions/concat_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/concat_spmd_expander.cc @@ -38,8 +38,8 @@ namespace tensorflow { namespace dtensor { namespace { -Status VerifyConcatLayout(mlir::Value concat_dim_operand, - const Layout& concat_layout) { +absl::Status VerifyConcatLayout(mlir::Value concat_dim_operand, + const Layout& concat_layout) { TF_ASSIGN_OR_RETURN(int64_t concat_dim_value, ExtractConstIntFromValue(concat_dim_operand)); for (const auto& shard_and_dimension : diff --git a/tensorflow/dtensor/mlir/expansions/conv_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/conv_spmd_expander.cc index 5957d73bc7e7ca..88d185d894257b 100644 --- a/tensorflow/dtensor/mlir/expansions/conv_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/conv_spmd_expander.cc @@ -51,8 +51,8 @@ namespace dtensor { namespace { template -Status VerifyConvLayout(const Layout& input_layout, const Layout& filter_layout, - ConvOp conv_op) { +absl::Status VerifyConvLayout(const Layout& input_layout, + const Layout& filter_layout, ConvOp conv_op) { if (!filter_layout.IsFullyReplicated()) return errors::InvalidArgument( "Filter for convolution must have fully replicated layout."); @@ -333,7 +333,7 @@ StatusOr HandleConvBackpropInput( } llvm::SmallVector global_shape; - Status extract_status = + absl::Status extract_status = ExtractConstVectorFromValue(conv_op.getInputSizes(), &global_shape); // If the input is dynamic size, we expect the output is all so dynamic size diff --git a/tensorflow/dtensor/mlir/expansions/dtensor_op_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/dtensor_op_spmd_expander.cc index 8659a2d25699f3..ce40b37db11d66 100644 --- a/tensorflow/dtensor/mlir/expansions/dtensor_op_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/dtensor_op_spmd_expander.cc @@ -52,8 +52,8 @@ namespace { // 3. Src/target layouts are from different mesh. // 4. One of scr/target layout is from host mesh cluster. // 5. CPU host cluster mesh has 1 device. -Status ValidateSendRecvLayoutConfiguration(mlir::TF::DTensorSend dtensor_send, - mlir::TF::DTensorRecv dtensor_recv) { +absl::Status ValidateSendRecvLayoutConfiguration( + mlir::TF::DTensorSend dtensor_send, mlir::TF::DTensorRecv dtensor_recv) { // If either one of the send/recv ops has already been lowered, then send/recv // configuration has already been verified. if (!dtensor_send || !dtensor_recv) return absl::OkStatus(); diff --git a/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.cc index a89a07521eb939..a244089a2bffea 100644 --- a/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.cc @@ -98,7 +98,7 @@ StatusOr EinsumSPMDExpander::ExpandOp(mlir::Operation* op) { // input_mappings: for each equation input, the map from the equation labels // to the tensor dimension of that label. // output_mapping: as above, but for the equation output. -Status ExtractEquationRelations( +absl::Status ExtractEquationRelations( absl::string_view equation, absl::flat_hash_set& reduced_dims, std::vector>>& input_mappings, absl::flat_hash_map>& output_mapping) { @@ -388,7 +388,7 @@ StatusOr> EinsumSPMDExpander::ComputeLayoutBackward( // for x is sharded. If both are sharded, we can compute the einsum on the // diagonal machines in the mesh and 0s on the off diagonals and then all // the much smaller matrix. -Status EinsumSPMDExpander::MaybeRelayoutInputs( +absl::Status EinsumSPMDExpander::MaybeRelayoutInputs( const std::vector& input_layouts, mlir::Operation* op, const Layout& output_layout, absl::flat_hash_set& reduce_dims, Layout& einsum_layout, std::vector& new_inputs) { diff --git a/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.h b/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.h index 46e7ad23a83614..421ed4ccffe9ef 100644 --- a/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.h +++ b/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.h @@ -52,11 +52,11 @@ class EinsumSPMDExpander : public SPMDExpanderBase { // * The resulting output layout of the einsum operation, so we can insert an // AllConcat/split to make the output have the desired layout. // * The new inputs to fed into the einsum. - Status MaybeRelayoutInputs(const std::vector& input_layouts, - mlir::Operation* op, const Layout& output_layout, - absl::flat_hash_set& reduce_dims, - Layout& einsum_layout, - std::vector& new_inputs); + absl::Status MaybeRelayoutInputs( + const std::vector& input_layouts, mlir::Operation* op, + const Layout& output_layout, + absl::flat_hash_set& reduce_dims, Layout& einsum_layout, + std::vector& new_inputs); }; } // namespace dtensor diff --git a/tensorflow/dtensor/mlir/expansions/fft_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/fft_spmd_expander.cc index 31304474bb0c42..85b027d717ebce 100644 --- a/tensorflow/dtensor/mlir/expansions/fft_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/fft_spmd_expander.cc @@ -70,8 +70,8 @@ bool IsComplexFFT(mlir::Value input) { return mlir::isa(data_type); } -Status IsProperFFTLength(mlir::Operation* op, - const llvm::SmallVector& fft_length_vec) { +absl::Status IsProperFFTLength( + mlir::Operation* op, const llvm::SmallVector& fft_length_vec) { TF_ASSIGN_OR_RETURN(auto input_layout, ExtractRequiredLayoutFromOperand(op->getOperand(0))); const Mesh& mesh = input_layout.mesh(); @@ -160,7 +160,7 @@ StatusOr EmitTransposeRelayout(mlir::OpBuilder& builder, return transposed_input; } -Status NormalizeAxes(std::vector& transform_axes, int input_rank) { +absl::Status NormalizeAxes(std::vector& transform_axes, int input_rank) { std::sort(transform_axes.begin(), transform_axes.end()); for (int i = 0; i < transform_axes.size(); ++i) { if (transform_axes[i] >= input_rank) { diff --git a/tensorflow/dtensor/mlir/expansions/gather_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/gather_spmd_expander.cc index 285c9a40045f9c..57d11d9a5927c7 100644 --- a/tensorflow/dtensor/mlir/expansions/gather_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/gather_spmd_expander.cc @@ -245,11 +245,10 @@ StatusOr GatherNdGetOutputLayoutFromInput( return Layout::GetLayout(output_specs, mesh); } -Status GatherNdGetInputLayoutFromOutput(const Layout& output_layout, - Layout* params_layout, int params_rank, - Layout* indices_layout, - int indices_rank, int index_dimensions, - const Mesh& mesh) { +absl::Status GatherNdGetInputLayoutFromOutput( + const Layout& output_layout, Layout* params_layout, int params_rank, + Layout* indices_layout, int indices_rank, int index_dimensions, + const Mesh& mesh) { // We copy the first indices_rank - 1 dimensions of the output layout to // indices_layout (with the last dimensions replicated) and the remaining // dimensions to params_layout (with the first index_dimensions dimensions diff --git a/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.cc index 61d51226141168..4a82b0f62c601c 100644 --- a/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.cc @@ -189,7 +189,7 @@ StatusOr MatMulSPMDExpander::OutputLayoutAndReducedDims( // * The resulting layout of the matmul tensor, so we can insert an AllConcat/ // split to make the output have the desired layout. // * The left and right value for use as input to the matmul. -Status MatMulSPMDExpander::MaybeRelayoutInputs( +absl::Status MatMulSPMDExpander::MaybeRelayoutInputs( mlir::Operation* op, const Layout& left_layout, bool left_transposed, const Layout& right_layout, bool right_transposed, const Layout& output_layout, std::string& reduced_dim, diff --git a/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.h b/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.h index fedbe357254b5c..10fdd3c3cfb1e3 100644 --- a/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.h +++ b/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.h @@ -62,11 +62,11 @@ class MatMulSPMDExpander : public SPMDExpanderBase { // matmul_layout will be set to the layout of the output of the local matmul // (after the above reduction). This may be different from the desired output // layout. - Status MaybeRelayoutInputs(mlir::Operation* op, const Layout& left_layout, - bool left_transposed, const Layout& right_layout, - bool right_transposed, const Layout& output_layout, - std::string& reduced_dim, Layout& matmul_layout, - mlir::Value& left, mlir::Value& right); + absl::Status MaybeRelayoutInputs( + mlir::Operation* op, const Layout& left_layout, bool left_transposed, + const Layout& right_layout, bool right_transposed, + const Layout& output_layout, std::string& reduced_dim, + Layout& matmul_layout, mlir::Value& left, mlir::Value& right); }; } // namespace dtensor diff --git a/tensorflow/dtensor/mlir/expansions/meta_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/meta_spmd_expander.cc index 7c7012c140b51c..959d61d75b0085 100644 --- a/tensorflow/dtensor/mlir/expansions/meta_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/meta_spmd_expander.cc @@ -248,9 +248,9 @@ StatusOr> UnpackSPMDExpander::ComputeLayoutBackward( namespace { -Status VerifyPaddedDimensionNotSharded(const Layout& layout, - mlir::Value pad_input, - mlir::Value pad_output) { +absl::Status VerifyPaddedDimensionNotSharded(const Layout& layout, + mlir::Value pad_input, + mlir::Value pad_output) { auto input_type = mlir::dyn_cast(pad_input.getType()); auto output_type = mlir::dyn_cast(pad_output.getType()); @@ -350,8 +350,8 @@ StatusOr> PadSPMDExpander::ComputeLayoutBackward( namespace { -Status VerifyTileOperandLayout(const Layout& operand_layout, - llvm::ArrayRef static_multiples) { +absl::Status VerifyTileOperandLayout(const Layout& operand_layout, + llvm::ArrayRef static_multiples) { for (const auto& tensor_dim_and_multiple : llvm::enumerate(static_multiples)) { const auto& index = tensor_dim_and_multiple.index(); @@ -962,9 +962,9 @@ TransposeSPMDExpander::ComputeLayoutBackward( namespace { -Status RelayoutOneHotInput(const absl::optional& input_layout, - const absl::optional& output_layout, - const int axis, mlir::TF::OneHotOp& one_hot) { +absl::Status RelayoutOneHotInput(const absl::optional& input_layout, + const absl::optional& output_layout, + const int axis, mlir::TF::OneHotOp& one_hot) { if (!input_layout || !output_layout) return errors::InvalidArgument( "layout for tf.OneHot operation inputs and outputs must be known before" diff --git a/tensorflow/dtensor/mlir/expansions/random_op_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/random_op_spmd_expander.cc index 785b36f23edf1d..f2706fbb01c015 100644 --- a/tensorflow/dtensor/mlir/expansions/random_op_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/random_op_spmd_expander.cc @@ -47,7 +47,7 @@ namespace tensorflow { namespace dtensor { namespace { -Status CheckLayoutIsSupported(const Layout& layout) { +absl::Status CheckLayoutIsSupported(const Layout& layout) { // Currently we support small mesh rank for arbitrary layout. if (layout.mesh().rank() > 3) return errors::InvalidArgument("Large mesh rank size is not supported", @@ -56,7 +56,7 @@ Status CheckLayoutIsSupported(const Layout& layout) { return absl::OkStatus(); } -Status ValidateShapeAndGetNewShape( +absl::Status ValidateShapeAndGetNewShape( const llvm::SmallVector& op_shape, const Layout& layout, llvm::SmallVectorImpl& new_random_shape) { TF_RETURN_IF_ERROR(CheckLayoutIsSupported(layout)); diff --git a/tensorflow/dtensor/mlir/expansions/reduce_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/reduce_spmd_expander.cc index 699733ec49abee..0ef6a854050459 100644 --- a/tensorflow/dtensor/mlir/expansions/reduce_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/reduce_spmd_expander.cc @@ -55,7 +55,7 @@ absl::string_view DefiningOpName(mlir::Value operand) { return StringRefToView(operand.getDefiningOp()->getName().getStringRef()); } -Status AssertReplicated(mlir::Value operand) { +absl::Status AssertReplicated(mlir::Value operand) { TF_ASSIGN_OR_RETURN(auto layout, ExtractLayoutFromOperand(operand)); if (!layout) return absl::OkStatus(); @@ -80,9 +80,9 @@ absl::flat_hash_set ReducedMeshDimensions( } template -Status ExtractDims(mlir::Operation* op, - llvm::SmallVector* reduced_dims, bool* keep_dims, - bool* matched) { +absl::Status ExtractDims(mlir::Operation* op, + llvm::SmallVector* reduced_dims, + bool* keep_dims, bool* matched) { if (!llvm::isa(op)) return absl::OkStatus(); auto reduce_op = llvm::cast(op); *keep_dims = reduce_op.getKeepDims(); @@ -95,7 +95,7 @@ Status ExtractDims(mlir::Operation* op, } template <> -Status ExtractDims( +absl::Status ExtractDims( mlir::Operation* op, llvm::SmallVector* reduced_dims, bool* keep_dims, bool* matched) { if (!llvm::isa(op)) return absl::OkStatus(); @@ -111,7 +111,7 @@ Status ExtractDims( } template <> -Status ExtractDims( +absl::Status ExtractDims( mlir::Operation* op, llvm::SmallVector* reduced_dims, bool* keep_dims, bool* matched) { if (!llvm::isa(op)) return absl::OkStatus(); @@ -138,7 +138,7 @@ Status ExtractDims( } template <> -Status ExtractDims( +absl::Status ExtractDims( mlir::Operation* op, llvm::SmallVector* reduced_dims, bool* keep_dims, bool* matched) { if (!llvm::isa(op)) return absl::OkStatus(); @@ -148,9 +148,9 @@ Status ExtractDims( return absl::OkStatus(); } -Status ExtractReductionParameters(mlir::Operation* op, - absl::flat_hash_set& reduced_dims_set, - bool& keep_dims) { +absl::Status ExtractReductionParameters( + mlir::Operation* op, absl::flat_hash_set& reduced_dims_set, + bool& keep_dims) { llvm::SmallVector reduced_dims; bool matched = false; TF_RETURN_IF_ERROR(ExtractDims(op, &reduced_dims, diff --git a/tensorflow/dtensor/mlir/expansions/resource_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/resource_spmd_expander.cc index 7b375fa4f704ae..b6734e12b6f6c1 100644 --- a/tensorflow/dtensor/mlir/expansions/resource_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/resource_spmd_expander.cc @@ -105,10 +105,9 @@ StatusOr ExpandSummaryWriterOp(mlir::Operation* op) { return InferSPMDExpandedLocalShape(op); } -Status ValidateAndAssignResourceInputLayout(mlir::tf_device::ClusterOp op, - const std::string& layout_string, - const int resource_arg_index, - mlir::OpBuilder* builder) { +absl::Status ValidateAndAssignResourceInputLayout( + mlir::tf_device::ClusterOp op, const std::string& layout_string, + const int resource_arg_index, mlir::OpBuilder* builder) { const auto add_layout_as_attributes = [&](std::vector new_resource_layouts, std::vector new_resource_indices, int resource_arg_index, diff --git a/tensorflow/dtensor/mlir/expansions/save_restore_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/save_restore_spmd_expander.cc index 737d1f562bb8ff..51acf884abe5db 100644 --- a/tensorflow/dtensor/mlir/expansions/save_restore_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/save_restore_spmd_expander.cc @@ -205,7 +205,7 @@ StatusOr ConditionalSave( // the extraction failed to just ignore those values and work as if those are // empty. llvm::SmallVector original_shape_and_slices; - const Status extraction_status = ExtractConstStringVectorFromValue( + const absl::Status extraction_status = ExtractConstStringVectorFromValue( original_save.getShapeAndSlices(), original_shape_and_slices); if (extraction_status.ok()) { for (const std::string& shape_and_slice : original_shape_and_slices) { diff --git a/tensorflow/dtensor/mlir/expansions/slice_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/slice_spmd_expander.cc index ed9b0bc4f49bed..4f39293b61f700 100644 --- a/tensorflow/dtensor/mlir/expansions/slice_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/slice_spmd_expander.cc @@ -45,11 +45,11 @@ namespace tensorflow { namespace dtensor { namespace { -Status GetSliceOpArguments(mlir::TF::SliceOp slice_op, - llvm::SmallVector& begins, - bool& dynamic_begins, - llvm::SmallVector& sizes) { - Status begins_result = +absl::Status GetSliceOpArguments(mlir::TF::SliceOp slice_op, + llvm::SmallVector& begins, + bool& dynamic_begins, + llvm::SmallVector& sizes) { + absl::Status begins_result = ExtractConstVectorFromValue(slice_op.getBegin(), &begins); dynamic_begins = !begins_result.ok(); diff --git a/tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.cc index 5ea020e28470ae..f8113dcd57389e 100644 --- a/tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.cc @@ -115,11 +115,12 @@ StatusOr ComputeGlobalReduce( // Takes a sharded logits and compute both the shifted exponentiation of the // logits and its sum. Assumes that builder's insertion point is after logits. -Status ComputeExpAndSum(mlir::OpBuilder& builder, const mlir::Value& logits, - const Layout& logits_layout, - mlir::Value& shifted_logits, - mlir::Value& exp_of_shifted_logits, - mlir::Value& sum_of_exp) { +absl::Status ComputeExpAndSum(mlir::OpBuilder& builder, + const mlir::Value& logits, + const Layout& logits_layout, + mlir::Value& shifted_logits, + mlir::Value& exp_of_shifted_logits, + mlir::Value& sum_of_exp) { auto loc = logits.getLoc(); if (logits_layout.rank() == 0) diff --git a/tensorflow/dtensor/mlir/expansions/strided_slice_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/strided_slice_spmd_expander.cc index ab32210770ae0e..5a1bdc6e9b4f05 100644 --- a/tensorflow/dtensor/mlir/expansions/strided_slice_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/strided_slice_spmd_expander.cc @@ -129,8 +129,8 @@ StatusOr> TokenizeOp(T strided_slice) { // this is the only meaningful change when a global Token vector is converted // to the local Token vector. template -Status UpdateOpFromTokens(T strided_slice, - const std::vector& tokens) { +absl::Status UpdateOpFromTokens(T strided_slice, + const std::vector& tokens) { mlir::OpBuilder builder(strided_slice); llvm::SmallVector end; end.reserve(tokens.size()); diff --git a/tensorflow/dtensor/mlir/layout_propagation_v2.cc b/tensorflow/dtensor/mlir/layout_propagation_v2.cc index a4d7aba21a2b30..9809c480529834 100644 --- a/tensorflow/dtensor/mlir/layout_propagation_v2.cc +++ b/tensorflow/dtensor/mlir/layout_propagation_v2.cc @@ -1355,7 +1355,7 @@ void FindRootsAndEmitError( // Runs an iteration of layout propagation, where we merge producer and consumer // requests and then recompute recommended layouts on all operations that // are connected to an updated layout. -Status RunOneIteration( +absl::Status RunOneIteration( llvm::DenseSet& is_locked, llvm::DenseSet& is_updated, llvm::DenseMap>& producer_request, @@ -1395,9 +1395,10 @@ Status RunOneIteration( // Compares every value's layouts in `merged_a` with the ones in `merged_b`, // and store the values that differ in `changed`. -Status CompareMergedLayouts(const llvm::DenseMap& merged_a, - const llvm::DenseMap& merged_b, - llvm::DenseSet& changed) { +absl::Status CompareMergedLayouts( + const llvm::DenseMap& merged_a, + const llvm::DenseMap& merged_b, + llvm::DenseSet& changed) { if (merged_a.size() != merged_b.size()) return errors::Internal( "Both merged_layouts did not have the same number of set layouts."); @@ -1490,16 +1491,16 @@ struct DLayoutPropagationPassV2 int stage = 0; llvm::DenseMap merged_layouts; - Status status; + absl::Status status; while (!is_updated.empty() && stage < kLayoutPropagationMaxStages) { ++stage; int steps = 0; // Step 1. Run the layout propagation v2 until convergence or max steps. while (!is_updated.empty() && steps < LayoutPropagationMaxSteps()) { - Status status = RunOneIteration(is_locked, is_updated, producer_request, - consumer_requests, producers, consumers, - merged_layouts, module, stage, &steps); + absl::Status status = RunOneIteration( + is_locked, is_updated, producer_request, consumer_requests, + producers, consumers, merged_layouts, module, stage, &steps); if (!status.ok()) { module.emitOpError() << "Failure running iteration."; return signalPassFailure(); diff --git a/tensorflow/dtensor/mlir/set_default_sharding.cc b/tensorflow/dtensor/mlir/set_default_sharding.cc index caca44e4a877ea..2a8c2fb58e4015 100644 --- a/tensorflow/dtensor/mlir/set_default_sharding.cc +++ b/tensorflow/dtensor/mlir/set_default_sharding.cc @@ -26,7 +26,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" -#include "xla/client/sharding_builder.h" +#include "xla/hlo/builder/sharding_builder.h" #include "tensorflow/dtensor/cc/constants.h" namespace tensorflow { diff --git a/tensorflow/dtensor/mlir/shape_utils.cc b/tensorflow/dtensor/mlir/shape_utils.cc index 7b28d8e0e28d01..d0d7c3adc5de8a 100644 --- a/tensorflow/dtensor/mlir/shape_utils.cc +++ b/tensorflow/dtensor/mlir/shape_utils.cc @@ -230,7 +230,7 @@ mlir::LogicalResult InferShapeOfTFOpWithCustomOperandConstantFn( } // namespace -Status InferSPMDExpandedLocalShapeForResourceOutput( +absl::Status InferSPMDExpandedLocalShapeForResourceOutput( mlir::OpResult* op_result, const Layout& output_layout, mlir::MLIRContext* context) { if (llvm::isa( diff --git a/tensorflow/dtensor/mlir/shape_utils.h b/tensorflow/dtensor/mlir/shape_utils.h index 910bd5abec9dba..f525965b3b0262 100644 --- a/tensorflow/dtensor/mlir/shape_utils.h +++ b/tensorflow/dtensor/mlir/shape_utils.h @@ -35,9 +35,9 @@ StatusOr> ExtractGlobalOutputShape( // If result is a resource, the shape of the result should be adjusted to // local value of the resource, based on the layout for output. -Status InferSPMDExpandedLocalShapeForResourceOutput(mlir::OpResult* op_result, - const Layout& output_layout, - mlir::MLIRContext* context); +absl::Status InferSPMDExpandedLocalShapeForResourceOutput( + mlir::OpResult* op_result, const Layout& output_layout, + mlir::MLIRContext* context); // Returns op with recalculated local shape of `op` given all it's operands. mlir::Operation* InferSPMDExpandedLocalShape(mlir::Operation* op); diff --git a/tensorflow/dtensor/mlir/sparse_expander.cc b/tensorflow/dtensor/mlir/sparse_expander.cc index 87ed96c8a9b371..e827a55d122d1c 100644 --- a/tensorflow/dtensor/mlir/sparse_expander.cc +++ b/tensorflow/dtensor/mlir/sparse_expander.cc @@ -53,7 +53,7 @@ InitOnStartupMarker SparseExpanderRegistry::RegisterSparseExpansionFn( return {}; } -Status RunSparseExpansion(mlir::Operation* op, mlir::Operation** output) { +absl::Status RunSparseExpansion(mlir::Operation* op, mlir::Operation** output) { // Only expand if there are any SparseTensor inputs. if (HasAnySparseInput(op)) { SparseExpanderBase* expander = diff --git a/tensorflow/dtensor/mlir/sparse_expander.h b/tensorflow/dtensor/mlir/sparse_expander.h index fabbaed4df40f9..ba9e1ac5cd69a5 100644 --- a/tensorflow/dtensor/mlir/sparse_expander.h +++ b/tensorflow/dtensor/mlir/sparse_expander.h @@ -47,7 +47,7 @@ class SparseExpanderBase { }; // Computes the Sparse expansion for `op`. -Status RunSparseExpansion(mlir::Operation* op, mlir::Operation** output); +absl::Status RunSparseExpansion(mlir::Operation* op, mlir::Operation** output); // A registry of sparse SPMD expanders. This map is statically stored and // initialized with all the registered sparse SPMD expanders. diff --git a/tensorflow/dtensor/mlir/spmd_expander.cc b/tensorflow/dtensor/mlir/spmd_expander.cc index ce1bb02d4e954b..9ec8177b92c549 100644 --- a/tensorflow/dtensor/mlir/spmd_expander.cc +++ b/tensorflow/dtensor/mlir/spmd_expander.cc @@ -64,8 +64,9 @@ namespace { // descendent nodes. // User should not explicitly set a output parted layout and expect it to affect // the layout of ancestor nodes. -Status AdjustPartedLayout(const llvm::DenseMap& input_layouts, - llvm::DenseMap* computed_layouts) { +absl::Status AdjustPartedLayout( + const llvm::DenseMap& input_layouts, + llvm::DenseMap* computed_layouts) { // If any input has parted layout, propagate the parted layout to the layout // of all the computed values. bool input_has_parted_layout = false; @@ -149,8 +150,8 @@ InitOnStartupMarker SPMDExpanderRegistry::RegisterPropagateFn( return {}; } -Status SPMDExpanderBase::ExpandOpAndSetLayout(mlir::Operation* op, - mlir::Operation** output) { +absl::Status SPMDExpanderBase::ExpandOpAndSetLayout(mlir::Operation* op, + mlir::Operation** output) { TF_ASSIGN_OR_RETURN(std::vector> computed_layout, ExtractLayoutFromOp(op)); @@ -298,7 +299,7 @@ StatusOr> SPMDExpanderBase::ComputeLayoutBackward( return ComputeLayoutBackward(op, output_layouts); } -Status RunSPMDExpansion(mlir::Operation* op, mlir::Operation** output) { +absl::Status RunSPMDExpansion(mlir::Operation* op, mlir::Operation** output) { SPMDExpanderBase* expander = SPMDExpanderRegistry::Global()->GetPropagateFnForOp(op); if (expander != nullptr) { diff --git a/tensorflow/dtensor/mlir/spmd_expander.h b/tensorflow/dtensor/mlir/spmd_expander.h index 52916c832cb58a..e5711bef5e06fc 100644 --- a/tensorflow/dtensor/mlir/spmd_expander.h +++ b/tensorflow/dtensor/mlir/spmd_expander.h @@ -115,14 +115,15 @@ class SPMDExpanderBase { // Run ExpandOp() and set layout from the computed layout from original op. // Returns the expanded op in output. - Status ExpandOpAndSetLayout(mlir::Operation* op, mlir::Operation** output); + absl::Status ExpandOpAndSetLayout(mlir::Operation* op, + mlir::Operation** output); }; // Computes the SPMD expansion for `op`. // // Prior to this call, all inputs to `op` have been lowered to local operations // & shapes. The lowered op must emit a type compatible with the local shape. -Status RunSPMDExpansion(mlir::Operation* op, mlir::Operation** output); +absl::Status RunSPMDExpansion(mlir::Operation* op, mlir::Operation** output); // A registry of SPMD expanders. This map is statically stored and initialized // with all the registered SPMD expanders. diff --git a/tensorflow/dtensor/mlir/spmd_expander_common.cc b/tensorflow/dtensor/mlir/spmd_expander_common.cc index 596f590ed95eed..8defe605d9d81b 100644 --- a/tensorflow/dtensor/mlir/spmd_expander_common.cc +++ b/tensorflow/dtensor/mlir/spmd_expander_common.cc @@ -114,9 +114,10 @@ StatusOr GlobalTypeFromLocalType( return new_output_type; } -Status CreateSplitOp(const int num_split, const int split_dimension, - const mlir::Location location, mlir::Value src_input, - mlir::OpBuilder* builder, mlir::TF::SplitOp* split_op) { +absl::Status CreateSplitOp(const int num_split, const int split_dimension, + const mlir::Location location, mlir::Value src_input, + mlir::OpBuilder* builder, + mlir::TF::SplitOp* split_op) { // Creates a const op to hold split dimension value. auto split_dim_type = mlir::RankedTensorType::get({}, builder->getIntegerType(32)); @@ -693,8 +694,8 @@ mlir::StringAttr GetUniqueControlflowFnName(const std::string& prefix, absl::StrCat(prefix, "_dtensor_function_", unique_id)); } -Status SetBuilderInsertionAfterValue(mlir::Value value, - mlir::OpBuilder& builder) { +absl::Status SetBuilderInsertionAfterValue(mlir::Value value, + mlir::OpBuilder& builder) { if (mlir::isa(value)) { builder.setInsertionPointAfterValue(value); return absl::OkStatus(); @@ -714,7 +715,8 @@ Status SetBuilderInsertionAfterValue(mlir::Value value, return absl::OkStatus(); } -Status PrintTensor(mlir::Value value, const std::string& format_string = "%s") { +absl::Status PrintTensor(mlir::Value value, + const std::string& format_string = "%s") { mlir::OpBuilder builder(value.getContext()); builder.setInsertionPointAfterValue(value); TF_ASSIGN_OR_RETURN(mlir::Value device_id, DeviceId(value)); @@ -731,7 +733,7 @@ Status PrintTensor(mlir::Value value, const std::string& format_string = "%s") { return absl::OkStatus(); } -Status ExtractConstStringVectorFromValue( +absl::Status ExtractConstStringVectorFromValue( mlir::Value value, llvm::SmallVectorImpl& out_vector) { value = GetForwardedDTensorLayoutInput(value); if (mlir::isa(value)) diff --git a/tensorflow/dtensor/mlir/spmd_expander_common.h b/tensorflow/dtensor/mlir/spmd_expander_common.h index 0a35ce8032b07b..4140fe8840a924 100644 --- a/tensorflow/dtensor/mlir/spmd_expander_common.h +++ b/tensorflow/dtensor/mlir/spmd_expander_common.h @@ -72,9 +72,10 @@ StatusOr GlobalTypeFromLocalType( // Creates a tf::SplitOp that splits 'src_input' into 'num_splits' ways // in 'split_dimension' dimension and returns the split values. -Status CreateSplitOp(int num_split, int split_dimension, - mlir::Location location, mlir::Value src_input, - mlir::OpBuilder* builder, mlir::TF::SplitOp* split_op); +absl::Status CreateSplitOp(int num_split, int split_dimension, + mlir::Location location, mlir::Value src_input, + mlir::OpBuilder* builder, + mlir::TF::SplitOp* split_op); // Given layouts + shapes, determines if the two are broadcast compatible. // See source file for more documentation. @@ -171,15 +172,15 @@ mlir::StringAttr GetUniqueControlflowFnName(const std::string& prefix, // argument, this checks that all users of the value are in the same cluster. // If not it errors out. If they are then it sets the inserition point to the // top of the cluster. -Status SetBuilderInsertionAfterValue(mlir::Value value, - mlir::OpBuilder& builder); +absl::Status SetBuilderInsertionAfterValue(mlir::Value value, + mlir::OpBuilder& builder); // Inserts a StringFormat and Print op, should only be used for debugging // on CPU. -Status PrintTensor(mlir::Value value, const std::string& format_string); +absl::Status PrintTensor(mlir::Value value, const std::string& format_string); // Extract a vector of string from mlir value. -Status ExtractConstStringVectorFromValue( +absl::Status ExtractConstStringVectorFromValue( mlir::Value value, llvm::SmallVectorImpl& out_vector); StatusOr ExtractConstScalarStringFromValue(mlir::Value value); diff --git a/tensorflow/dtensor/mlir/tpu_integration.cc b/tensorflow/dtensor/mlir/tpu_integration.cc index dd2694fc004e46..0fde0fc26e1d06 100644 --- a/tensorflow/dtensor/mlir/tpu_integration.cc +++ b/tensorflow/dtensor/mlir/tpu_integration.cc @@ -40,7 +40,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" -#include "xla/client/sharding_builder.h" +#include "xla/hlo/builder/sharding_builder.h" #include "tensorflow/dtensor/cc/constants.h" #include "tensorflow/dtensor/cc/tensor_layout.h" #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h" diff --git a/tensorflow/dtensor/mlir/value_utils.cc b/tensorflow/dtensor/mlir/value_utils.cc index b3eeccd78c8411..5c1d9a80cc23c4 100644 --- a/tensorflow/dtensor/mlir/value_utils.cc +++ b/tensorflow/dtensor/mlir/value_utils.cc @@ -36,13 +36,13 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/dtensor/cc/dstatus.h" #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" #include "tensorflow/dtensor/mlir/op_utils.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace dtensor { diff --git a/tensorflow/dtensor/proto/BUILD b/tensorflow/dtensor/proto/BUILD index da9cd8002c8d87..8694130645b933 100644 --- a/tensorflow/dtensor/proto/BUILD +++ b/tensorflow/dtensor/proto/BUILD @@ -27,7 +27,6 @@ alias( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "layout_proto_py_pb2", -# api_version = 2, # deps = [":layout_proto"], # ) # copybara:uncomment_end diff --git a/tensorflow/dtensor/python/tests/BUILD b/tensorflow/dtensor/python/tests/BUILD index 8e61e6083a849d..4fcf858ba7c345 100644 --- a/tensorflow/dtensor/python/tests/BUILD +++ b/tensorflow/dtensor/python/tests/BUILD @@ -394,6 +394,9 @@ dtensor_test( "tpu", ], main = "layout_test.py", + tags = [ + "no_windows", + ], deps = [ ":test_util", "//tensorflow/dtensor/python:api", @@ -460,6 +463,9 @@ dtensor_test( TPU_V3_DONUT_BACKEND, GPU_2DEVS_BACKEND, ], + env = { + "TF_FORCE_GPU_ALLOW_GROWTH": "true", + }, deps = [ ":test_util", "//tensorflow/dtensor/python:config", diff --git a/tensorflow/dtensor/tests/BUILD b/tensorflow/dtensor/tests/BUILD index 8cc18e61e97d16..464d110468de23 100644 --- a/tensorflow/dtensor/tests/BUILD +++ b/tensorflow/dtensor/tests/BUILD @@ -79,7 +79,7 @@ tf_cc_test( "//tensorflow/dtensor/cc:tensor_layout", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) diff --git a/tensorflow/dtensor/tests/executable_manager_test.cc b/tensorflow/dtensor/tests/executable_manager_test.cc index e1d39218056b70..2784d3a1fe8e50 100644 --- a/tensorflow/dtensor/tests/executable_manager_test.cc +++ b/tensorflow/dtensor/tests/executable_manager_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/platform/refcount.h" @@ -22,7 +23,6 @@ limitations under the License. #include "tensorflow/dtensor/cc/dtensor_operation.h" #include "tensorflow/dtensor/cc/tensor_layout.h" #include "tsl/platform/status_matchers.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace dtensor { diff --git a/tensorflow/examples/custom_ops_doc/multiplex_2/multiplex_2_kernel.cc b/tensorflow/examples/custom_ops_doc/multiplex_2/multiplex_2_kernel.cc index bf42b49b55cd36..10249ce5a83f98 100644 --- a/tensorflow/examples/custom_ops_doc/multiplex_2/multiplex_2_kernel.cc +++ b/tensorflow/examples/custom_ops_doc/multiplex_2/multiplex_2_kernel.cc @@ -15,6 +15,10 @@ limitations under the License. #include "tensorflow/examples/custom_ops_doc/multiplex_2/multiplex_2_kernel.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" + // Please use the appropriate namespace for your project namespace tensorflow { namespace custom_op_examples { diff --git a/tensorflow/examples/custom_ops_doc/multiplex_2/multiplex_2_op.cc b/tensorflow/examples/custom_ops_doc/multiplex_2/multiplex_2_op.cc index 0748d0efed8c2e..1b5c5070c23d94 100644 --- a/tensorflow/examples/custom_ops_doc/multiplex_2/multiplex_2_op.cc +++ b/tensorflow/examples/custom_ops_doc/multiplex_2/multiplex_2_op.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/platform/errors.h" // Use a namespace when registering by prepending the // package's name to the op’s name and separate with a '>'. diff --git a/tensorflow/examples/custom_ops_doc/multiplex_4/multiplex_4_op.cc b/tensorflow/examples/custom_ops_doc/multiplex_4/multiplex_4_op.cc index 4a1f5b9a877dd2..fd26eb0f07ca3b 100644 --- a/tensorflow/examples/custom_ops_doc/multiplex_4/multiplex_4_op.cc +++ b/tensorflow/examples/custom_ops_doc/multiplex_4/multiplex_4_op.cc @@ -30,7 +30,7 @@ namespace custom_op_examples { using ::tensorflow::shape_inference::InferenceContext; -Status MultiplexShapeFunction(InferenceContext* c) { +absl::Status MultiplexShapeFunction(InferenceContext* c) { int64_t num_cond_a; TF_RETURN_IF_ERROR(c->GetAttr("N", &num_cond_a)); tensorflow::shape_inference::ShapeHandle unused; diff --git a/tensorflow/examples/multibox_detector/main.cc b/tensorflow/examples/multibox_detector/main.cc index 9bc1aa36b5326d..3ed053a7e58627 100644 --- a/tensorflow/examples/multibox_detector/main.cc +++ b/tensorflow/examples/multibox_detector/main.cc @@ -47,12 +47,12 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/numbers.h" #include "tensorflow/core/platform/path.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/util/command_line_flags.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" -#include "tsl/platform/status.h" #include "tsl/platform/types.h" // These are all common classes it's handy to reference with no namespace. diff --git a/tensorflow/examples/speech_commands/accuracy_utils.cc b/tensorflow/examples/speech_commands/accuracy_utils.cc index cbd91c817204f2..c647dff760c7a0 100644 --- a/tensorflow/examples/speech_commands/accuracy_utils.cc +++ b/tensorflow/examples/speech_commands/accuracy_utils.cc @@ -24,13 +24,12 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/numbers.h" -#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { -Status ReadGroundTruthFile(const string& file_name, - std::vector>* result) { +absl::Status ReadGroundTruthFile( + const string& file_name, std::vector>* result) { std::ifstream file(file_name); if (!file) { return tensorflow::errors::NotFound("Ground truth file '", file_name, diff --git a/tensorflow/examples/speech_commands/accuracy_utils.h b/tensorflow/examples/speech_commands/accuracy_utils.h index d40a441f60e1a7..4fbeab14870855 100644 --- a/tensorflow/examples/speech_commands/accuracy_utils.h +++ b/tensorflow/examples/speech_commands/accuracy_utils.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/status/status.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" @@ -40,8 +41,8 @@ struct StreamingAccuracyStats { // Takes a file name, and loads a list of expected word labels and times from // it, as comma-separated variables. -Status ReadGroundTruthFile(const string& file_name, - std::vector>* result); +absl::Status ReadGroundTruthFile( + const string& file_name, std::vector>* result); // Given ground truth labels and corresponding predictions found by a model, // figure out how many were correct. Takes a time limit, so that only diff --git a/tensorflow/examples/speech_commands/label_wav.cc b/tensorflow/examples/speech_commands/label_wav.cc index 7ea4e72c165218..7ecd7e9cf6d7e9 100644 --- a/tensorflow/examples/speech_commands/label_wav.cc +++ b/tensorflow/examples/speech_commands/label_wav.cc @@ -27,12 +27,12 @@ limitations under the License. #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/tstring.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/util/command_line_flags.h" #include "tsl/platform/env.h" -#include "tsl/platform/status.h" #include "tsl/platform/types.h" // These are all common classes it's handy to reference with no namespace. diff --git a/tensorflow/examples/speech_commands/recognize_commands.cc b/tensorflow/examples/speech_commands/recognize_commands.cc index f209b5567a2594..abe2558363b074 100644 --- a/tensorflow/examples/speech_commands/recognize_commands.cc +++ b/tensorflow/examples/speech_commands/recognize_commands.cc @@ -18,7 +18,6 @@ limitations under the License. #include "absl/status/status.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -38,11 +37,9 @@ RecognizeCommands::RecognizeCommands(const std::vector& labels, previous_top_label_time_ = std::numeric_limits::min(); } -Status RecognizeCommands::ProcessLatestResults(const Tensor& latest_results, - const int64_t current_time_ms, - string* found_command, - float* score, - bool* is_new_command) { +absl::Status RecognizeCommands::ProcessLatestResults( + const Tensor& latest_results, const int64_t current_time_ms, + string* found_command, float* score, bool* is_new_command) { if (latest_results.NumElements() != labels_count_) { return errors::InvalidArgument( "The results for recognition should contain ", labels_count_, diff --git a/tensorflow/examples/speech_commands/recognize_commands.h b/tensorflow/examples/speech_commands/recognize_commands.h index a075b65d408af3..a2d4c4191b0c0b 100644 --- a/tensorflow/examples/speech_commands/recognize_commands.h +++ b/tensorflow/examples/speech_commands/recognize_commands.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" @@ -55,10 +56,10 @@ class RecognizeCommands { int32_t minimum_count = 3); // Call this with the results of running a model on sample data. - Status ProcessLatestResults(const Tensor& latest_results, - const int64_t current_time_ms, - string* found_command, float* score, - bool* is_new_command); + absl::Status ProcessLatestResults(const Tensor& latest_results, + const int64_t current_time_ms, + string* found_command, float* score, + bool* is_new_command); private: // Configuration diff --git a/tensorflow/examples/speech_commands/test_streaming_accuracy.cc b/tensorflow/examples/speech_commands/test_streaming_accuracy.cc index 8beeaffdd9ae6c..672f4070ff750b 100644 --- a/tensorflow/examples/speech_commands/test_streaming_accuracy.cc +++ b/tensorflow/examples/speech_commands/test_streaming_accuracy.cc @@ -79,13 +79,13 @@ bazel run tensorflow/examples/speech_commands:test_streaming_accuracy -- \ #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" #include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/examples/speech_commands/accuracy_utils.h" #include "tensorflow/examples/speech_commands/recognize_commands.h" #include "tsl/platform/env.h" -#include "tsl/platform/status.h" #include "tsl/platform/types.h" // These are all common classes it's handy to reference with no namespace. diff --git a/tensorflow/go/README.md b/tensorflow/go/README.md index 01be92188df6cb..6f1f34024eb59e 100644 --- a/tensorflow/go/README.md +++ b/tensorflow/go/README.md @@ -10,7 +10,7 @@ Construct and execute TensorFlow graphs in Go. # WARNING: -The TensorFlow team is not currently maintaning the Documentation for installing the Go bindings for TensorFlow. +The TensorFlow team is not currently maintaining the Documentation for installing the Go bindings for TensorFlow. The instructions has been maintained by the third party contributor: @wamuir diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index c86ebca819cf24..a0bfea1d34accc 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -1,6 +1,7 @@ # Description: # TensorFlow Java API. +load("@rules_java//java:defs.bzl", "java_library", "java_plugin") load( "//tensorflow:tensorflow.bzl", "VERSION", diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index 35ed2d517e9241..f9c092b138fb76 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -560,7 +560,7 @@ Status OpGenerator::Run(const OpList& op_list, const string& base_package, } } } - return OkStatus(); + return absl::OkStatus(); } } // namespace java diff --git a/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD b/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD index c8e8abbf1c4947..c1e7724652148b 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD +++ b/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD @@ -1,6 +1,8 @@ # Description: # TensorFlow Java examples. +load("@rules_java//java:defs.bzl", "java_binary") + package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = ["//visibility:private"], diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index 5bf1aaae11fc7a..9e700fcecfa792 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -153,6 +153,8 @@ filegroup( ":minimal_logging_android.cc", ":mutable_op_resolver.cc", ":mutable_op_resolver.h", + ":mutable_op_resolver_utils.cc", + ":mutable_op_resolver_utils.h", ":op_resolver.h", ":portable_type_to_tflitetype.h", ":stderr_reporter.cc", @@ -276,6 +278,9 @@ cc_test( "tflite_not_portable_android", ], deps = [ + ":cc_api_stable", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:subgraph_test_util", "//tensorflow/lite/testing:util", "@com_google_googletest//:gtest_main", @@ -1226,6 +1231,7 @@ cc_library( compatible_with = get_compatible_with_portable(), copts = tflite_copts_warnings() + tflite_copts(), deps = [ + "//tensorflow/lite/c:common", "//tensorflow/lite/core/c:common", ], ) diff --git a/tensorflow/lite/CMakeLists.txt b/tensorflow/lite/CMakeLists.txt index bce9627fbd3381..719bcc7572664d 100644 --- a/tensorflow/lite/CMakeLists.txt +++ b/tensorflow/lite/CMakeLists.txt @@ -205,6 +205,16 @@ if(TFLITE_ENABLE_XNNPACK) "${CMAKE_BINARY_DIR}/pthreadpool") endif() list(APPEND TFLITE_TARGET_DEPENDENCIES pthreadpool) + + IF(NOT DEFINED FP16_SOURCE_DIR) + MESSAGE(STATUS "Downloading FP16 to ${CMAKE_BINARY_DIR}/FP16-source (define FP16_SOURCE_DIR to avoid it)") + CONFIGURE_FILE(cmake/DownloadFP16.cmake "${CMAKE_BINARY_DIR}/FP16-download/CMakeLists.txt") + EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . + WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/FP16-download") + EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" --build . + WORKING_DIRECTORY "${CMAKE_BINARY_DIR}/FP16-download") + SET(FP16_SOURCE_DIR "${CMAKE_BINARY_DIR}/FP16-source" CACHE STRING "FP16 source directory") + ENDIF() endif() set(TF_TARGET_PRIVATE_OPTIONS "") if(CMAKE_CXX_COMPILER_ID MATCHES "Clang$") diff --git a/tensorflow/lite/README.md b/tensorflow/lite/README.md index 589d4f93481e50..471176f51bf3d3 100644 --- a/tensorflow/lite/README.md +++ b/tensorflow/lite/README.md @@ -4,5 +4,6 @@ TensorFlow Lite is TensorFlow's lightweight solution for mobile and embedded devices. It enables low-latency inference of on-device machine learning models with a small binary size and fast performance supporting hardware acceleration. -See the documentation: https://www.tensorflow.org/lite/ -Documentation edits can be made here: [tensorflow/lite/g3doc](./g3doc/) +- See the documentation: https://www.tensorflow.org/lite/ + +- Documentation edits can be made here: [tensorflow/lite/g3doc](./g3doc/) diff --git a/tensorflow/lite/acceleration/configuration/configuration.proto b/tensorflow/lite/acceleration/configuration/configuration.proto index 657ae5e12daf7b..29d911bd1b05b1 100644 --- a/tensorflow/lite/acceleration/configuration/configuration.proto +++ b/tensorflow/lite/acceleration/configuration/configuration.proto @@ -330,7 +330,9 @@ enum XNNPackFlags { message XNNPackSettings { optional int32 num_threads = 1; - optional XNNPackFlags flags = 2 [default = TFLITE_XNNPACK_DELEGATE_NO_FLAGS]; + // If flags is unset or zero, it means use the default XNNPack delegate flags. + // Any other value means use exactly (and only) the flags specified. + optional XNNPackFlags flags = 2; // Path to the XNNPack cache file. XNNPack packed buffers are saved to and // reloaded from this cache which can reduce initialization time and the // packing memory footprint. @@ -787,6 +789,17 @@ message MtkNeuronSettings { // Optional path to the platform-dependent Neuron configuration file. // See docs at https://neuropilot.mediatek.com/ for more details. optional string neuron_config_path = 10; + + // The deadline time duration (in ms) of the inference (waiting + execution). + // The scheduler would adjust scheduling based on this value. Note that + // setting this value to zero implies no deadline requirement. + optional int32 inference_deadline_ms = 11; + + // The maximum inference (waiting + execution) time duration (in ms). The + // scheduler would abort the inference if the inference time dutation exceed + // the time specified. Note that setting this value to zero implies no abort + // time requirement. + optional int32 inference_abort_time_ms = 12; } // How to configure TFLite. @@ -1129,4 +1142,4 @@ message BenchmarkEventStorage { optional BenchmarkEvent benchmark_event = 2; } -// LINT.ThenChange(//tensorflow/lite/acceleration/configuration/testdata/configuration.proto_prev) +// LINT.ThenChange(//tensorflow/lite/acceleration/configuration/testdata/configuration.proto_prev:all) diff --git a/tensorflow/lite/acceleration/configuration/configuration_generated.h b/tensorflow/lite/acceleration/configuration/configuration_generated.h index 2820f7c43b522c..4cb4861e78f4f4 100644 --- a/tensorflow/lite/acceleration/configuration/configuration_generated.h +++ b/tensorflow/lite/acceleration/configuration/configuration_generated.h @@ -528,27 +528,20 @@ inline const XNNPackFlags (&EnumValuesXNNPackFlags())[10] { return values; } -inline const char * const *EnumNamesXNNPackFlags() { - static const char * const names[11] = { - "TFLITE_XNNPACK_DELEGATE_NO_FLAGS", - "TFLITE_XNNPACK_DELEGATE_FLAG_QS8", - "TFLITE_XNNPACK_DELEGATE_FLAG_QU8", - "TFLITE_XNNPACK_DELEGATE_FLAG_QS8_QU8", - "TFLITE_XNNPACK_DELEGATE_FLAG_FORCE_FP16", - "TFLITE_XNNPACK_DELEGATE_FLAG_DYNAMIC_FULLY_CONNECTED", - "TFLITE_XNNPACK_DELEGATE_FLAG_VARIABLE_OPERATORS", - "TFLITE_XNNPACK_DELEGATE_FLAG_TRANSIENT_INDIRECTION_BUFFER", - "TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS", - "TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_SUBGRAPH_RESHAPING", - nullptr - }; - return names; -} - inline const char *EnumNameXNNPackFlags(XNNPackFlags e) { - if (::flatbuffers::IsOutRange(e, XNNPackFlags_TFLITE_XNNPACK_DELEGATE_NO_FLAGS, XNNPackFlags_TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_SUBGRAPH_RESHAPING)) return ""; - const size_t index = static_cast(e); - return EnumNamesXNNPackFlags()[index]; + switch (e) { + case XNNPackFlags_TFLITE_XNNPACK_DELEGATE_NO_FLAGS: return "TFLITE_XNNPACK_DELEGATE_NO_FLAGS"; + case XNNPackFlags_TFLITE_XNNPACK_DELEGATE_FLAG_QS8: return "TFLITE_XNNPACK_DELEGATE_FLAG_QS8"; + case XNNPackFlags_TFLITE_XNNPACK_DELEGATE_FLAG_QU8: return "TFLITE_XNNPACK_DELEGATE_FLAG_QU8"; + case XNNPackFlags_TFLITE_XNNPACK_DELEGATE_FLAG_QS8_QU8: return "TFLITE_XNNPACK_DELEGATE_FLAG_QS8_QU8"; + case XNNPackFlags_TFLITE_XNNPACK_DELEGATE_FLAG_FORCE_FP16: return "TFLITE_XNNPACK_DELEGATE_FLAG_FORCE_FP16"; + case XNNPackFlags_TFLITE_XNNPACK_DELEGATE_FLAG_DYNAMIC_FULLY_CONNECTED: return "TFLITE_XNNPACK_DELEGATE_FLAG_DYNAMIC_FULLY_CONNECTED"; + case XNNPackFlags_TFLITE_XNNPACK_DELEGATE_FLAG_VARIABLE_OPERATORS: return "TFLITE_XNNPACK_DELEGATE_FLAG_VARIABLE_OPERATORS"; + case XNNPackFlags_TFLITE_XNNPACK_DELEGATE_FLAG_TRANSIENT_INDIRECTION_BUFFER: return "TFLITE_XNNPACK_DELEGATE_FLAG_TRANSIENT_INDIRECTION_BUFFER"; + case XNNPackFlags_TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS: return "TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS"; + case XNNPackFlags_TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_SUBGRAPH_RESHAPING: return "TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_SUBGRAPH_RESHAPING"; + default: return ""; + } } namespace CoreMLSettings_ { @@ -2835,6 +2828,8 @@ struct MtkNeuronSettingsT : public ::flatbuffers::NativeTable { std::vector compile_options{}; std::vector accelerator_names{}; std::string neuron_config_path{}; + int32_t inference_deadline_ms = 0; + int32_t inference_abort_time_ms = 0; }; struct MtkNeuronSettings FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table { @@ -2850,7 +2845,9 @@ struct MtkNeuronSettings FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table VT_USE_CACHEABLE_BUFFER = 16, VT_COMPILE_OPTIONS = 18, VT_ACCELERATOR_NAMES = 20, - VT_NEURON_CONFIG_PATH = 22 + VT_NEURON_CONFIG_PATH = 22, + VT_INFERENCE_DEADLINE_MS = 24, + VT_INFERENCE_ABORT_TIME_MS = 26 }; tflite::MtkNeuronSettings_::ExecutionPreference execution_preference() const { return static_cast(GetField(VT_EXECUTION_PREFERENCE, 0)); @@ -2882,6 +2879,12 @@ struct MtkNeuronSettings FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table const ::flatbuffers::String *neuron_config_path() const { return GetPointer(VT_NEURON_CONFIG_PATH); } + int32_t inference_deadline_ms() const { + return GetField(VT_INFERENCE_DEADLINE_MS, 0); + } + int32_t inference_abort_time_ms() const { + return GetField(VT_INFERENCE_ABORT_TIME_MS, 0); + } bool Verify(::flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_EXECUTION_PREFERENCE, 4) && @@ -2900,6 +2903,8 @@ struct MtkNeuronSettings FLATBUFFERS_FINAL_CLASS : private ::flatbuffers::Table verifier.VerifyVectorOfStrings(accelerator_names()) && VerifyOffset(verifier, VT_NEURON_CONFIG_PATH) && verifier.VerifyString(neuron_config_path()) && + VerifyField(verifier, VT_INFERENCE_DEADLINE_MS, 4) && + VerifyField(verifier, VT_INFERENCE_ABORT_TIME_MS, 4) && verifier.EndTable(); } MtkNeuronSettingsT *UnPack(const ::flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -2941,6 +2946,12 @@ struct MtkNeuronSettingsBuilder { void add_neuron_config_path(::flatbuffers::Offset<::flatbuffers::String> neuron_config_path) { fbb_.AddOffset(MtkNeuronSettings::VT_NEURON_CONFIG_PATH, neuron_config_path); } + void add_inference_deadline_ms(int32_t inference_deadline_ms) { + fbb_.AddElement(MtkNeuronSettings::VT_INFERENCE_DEADLINE_MS, inference_deadline_ms, 0); + } + void add_inference_abort_time_ms(int32_t inference_abort_time_ms) { + fbb_.AddElement(MtkNeuronSettings::VT_INFERENCE_ABORT_TIME_MS, inference_abort_time_ms, 0); + } explicit MtkNeuronSettingsBuilder(::flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2963,8 +2974,12 @@ inline ::flatbuffers::Offset CreateMtkNeuronSettings( bool use_cacheable_buffer = true, ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> compile_options = 0, ::flatbuffers::Offset<::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>> accelerator_names = 0, - ::flatbuffers::Offset<::flatbuffers::String> neuron_config_path = 0) { + ::flatbuffers::Offset<::flatbuffers::String> neuron_config_path = 0, + int32_t inference_deadline_ms = 0, + int32_t inference_abort_time_ms = 0) { MtkNeuronSettingsBuilder builder_(_fbb); + builder_.add_inference_abort_time_ms(inference_abort_time_ms); + builder_.add_inference_deadline_ms(inference_deadline_ms); builder_.add_neuron_config_path(neuron_config_path); builder_.add_accelerator_names(accelerator_names); builder_.add_compile_options(compile_options); @@ -2989,7 +3004,9 @@ inline ::flatbuffers::Offset CreateMtkNeuronSettingsDirect( bool use_cacheable_buffer = true, const std::vector<::flatbuffers::Offset<::flatbuffers::String>> *compile_options = nullptr, const std::vector<::flatbuffers::Offset<::flatbuffers::String>> *accelerator_names = nullptr, - const char *neuron_config_path = nullptr) { + const char *neuron_config_path = nullptr, + int32_t inference_deadline_ms = 0, + int32_t inference_abort_time_ms = 0) { auto optimization_hints__ = optimization_hints ? _fbb.CreateVector(*optimization_hints) : 0; auto compile_options__ = compile_options ? _fbb.CreateVector<::flatbuffers::Offset<::flatbuffers::String>>(*compile_options) : 0; auto accelerator_names__ = accelerator_names ? _fbb.CreateVector<::flatbuffers::Offset<::flatbuffers::String>>(*accelerator_names) : 0; @@ -3005,7 +3022,9 @@ inline ::flatbuffers::Offset CreateMtkNeuronSettingsDirect( use_cacheable_buffer, compile_options__, accelerator_names__, - neuron_config_path__); + neuron_config_path__, + inference_deadline_ms, + inference_abort_time_ms); } ::flatbuffers::Offset CreateMtkNeuronSettings(::flatbuffers::FlatBufferBuilder &_fbb, const MtkNeuronSettingsT *_o, const ::flatbuffers::rehasher_function_t *_rehasher = nullptr); @@ -5534,7 +5553,9 @@ inline bool operator==(const MtkNeuronSettingsT &lhs, const MtkNeuronSettingsT & (lhs.use_cacheable_buffer == rhs.use_cacheable_buffer) && (lhs.compile_options == rhs.compile_options) && (lhs.accelerator_names == rhs.accelerator_names) && - (lhs.neuron_config_path == rhs.neuron_config_path); + (lhs.neuron_config_path == rhs.neuron_config_path) && + (lhs.inference_deadline_ms == rhs.inference_deadline_ms) && + (lhs.inference_abort_time_ms == rhs.inference_abort_time_ms); } inline bool operator!=(const MtkNeuronSettingsT &lhs, const MtkNeuronSettingsT &rhs) { @@ -5561,6 +5582,8 @@ inline void MtkNeuronSettings::UnPackTo(MtkNeuronSettingsT *_o, const ::flatbuff { auto _e = compile_options(); if (_e) { _o->compile_options.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->compile_options[_i] = _e->Get(_i)->str(); } } else { _o->compile_options.resize(0); } } { auto _e = accelerator_names(); if (_e) { _o->accelerator_names.resize(_e->size()); for (::flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->accelerator_names[_i] = _e->Get(_i)->str(); } } else { _o->accelerator_names.resize(0); } } { auto _e = neuron_config_path(); if (_e) _o->neuron_config_path = _e->str(); } + { auto _e = inference_deadline_ms(); _o->inference_deadline_ms = _e; } + { auto _e = inference_abort_time_ms(); _o->inference_abort_time_ms = _e; } } inline ::flatbuffers::Offset MtkNeuronSettings::Pack(::flatbuffers::FlatBufferBuilder &_fbb, const MtkNeuronSettingsT* _o, const ::flatbuffers::rehasher_function_t *_rehasher) { @@ -5581,6 +5604,8 @@ inline ::flatbuffers::Offset CreateMtkNeuronSettings(::flatbu auto _compile_options = _o->compile_options.size() ? _fbb.CreateVectorOfStrings(_o->compile_options) : 0; auto _accelerator_names = _o->accelerator_names.size() ? _fbb.CreateVectorOfStrings(_o->accelerator_names) : 0; auto _neuron_config_path = _o->neuron_config_path.empty() ? 0 : _fbb.CreateString(_o->neuron_config_path); + auto _inference_deadline_ms = _o->inference_deadline_ms; + auto _inference_abort_time_ms = _o->inference_abort_time_ms; return tflite::CreateMtkNeuronSettings( _fbb, _execution_preference, @@ -5592,7 +5617,9 @@ inline ::flatbuffers::Offset CreateMtkNeuronSettings(::flatbu _use_cacheable_buffer, _compile_options, _accelerator_names, - _neuron_config_path); + _neuron_config_path, + _inference_deadline_ms, + _inference_abort_time_ms); } diff --git a/tensorflow/lite/acceleration/configuration/flatbuffer_to_proto.cc b/tensorflow/lite/acceleration/configuration/flatbuffer_to_proto.cc index 87f822cf40c308..8ed43c971e1ad0 100644 --- a/tensorflow/lite/acceleration/configuration/flatbuffer_to_proto.cc +++ b/tensorflow/lite/acceleration/configuration/flatbuffer_to_proto.cc @@ -469,6 +469,10 @@ proto::MtkNeuronSettings ConvertMtkNeuronSettings( proto_settings.set_neuron_config_path(settings.neuron_config_path()->str()); } + proto_settings.set_inference_deadline_ms(settings.inference_deadline_ms()); + proto_settings.set_inference_abort_time_ms( + settings.inference_abort_time_ms()); + return proto_settings; } diff --git a/tensorflow/lite/acceleration/configuration/flatbuffer_to_proto_test.cc b/tensorflow/lite/acceleration/configuration/flatbuffer_to_proto_test.cc index bd5ef446b088f5..7ac336c848ed18 100644 --- a/tensorflow/lite/acceleration/configuration/flatbuffer_to_proto_test.cc +++ b/tensorflow/lite/acceleration/configuration/flatbuffer_to_proto_test.cc @@ -572,6 +572,8 @@ TEST_F(ConversionTest, MtkNeuronSettings) { input_settings->compile_options = {"TEST_COMPILE_OPTIONS"}; input_settings->accelerator_names = {"TEST_ACCELERATOR_NAME"}; input_settings->neuron_config_path = "TEST_NEURON_CONFIG_PATH"; + input_settings->inference_deadline_ms = 1337; + input_settings->inference_abort_time_ms = 42; const proto::ComputeSettings compute = ConvertFromFlatbuffer(settings_); const proto::MtkNeuronSettings& output_settings = @@ -596,6 +598,8 @@ TEST_F(ConversionTest, MtkNeuronSettings) { EXPECT_EQ(output_settings.accelerator_names().size(), 1); EXPECT_EQ(output_settings.accelerator_names().at(0), "TEST_ACCELERATOR_NAME"); EXPECT_EQ(output_settings.neuron_config_path(), "TEST_NEURON_CONFIG_PATH"); + EXPECT_EQ(output_settings.inference_deadline_ms(), 1337); + EXPECT_EQ(output_settings.inference_abort_time_ms(), 42); } TEST_F(ConversionTest, MiniBenchmarkSettings) { diff --git a/tensorflow/lite/acceleration/configuration/proto_to_flatbuffer.cc b/tensorflow/lite/acceleration/configuration/proto_to_flatbuffer.cc index 366869babe63d3..a35c259e206163 100644 --- a/tensorflow/lite/acceleration/configuration/proto_to_flatbuffer.cc +++ b/tensorflow/lite/acceleration/configuration/proto_to_flatbuffer.cc @@ -428,7 +428,8 @@ Offset ConvertMtkNeuronSettings( settings.compile_options().end()), builder.CreateVectorOfStrings(settings.accelerator_names().begin(), settings.accelerator_names().end()), - builder.CreateString(settings.neuron_config_path())); + builder.CreateString(settings.neuron_config_path()), + settings.inference_deadline_ms(), settings.inference_abort_time_ms()); } Offset ConvertCoralSettings(const proto::CoralSettings& settings, diff --git a/tensorflow/lite/acceleration/configuration/proto_to_flatbuffer_test.cc b/tensorflow/lite/acceleration/configuration/proto_to_flatbuffer_test.cc index e504f09c6825fe..c0db23fad047e0 100644 --- a/tensorflow/lite/acceleration/configuration/proto_to_flatbuffer_test.cc +++ b/tensorflow/lite/acceleration/configuration/proto_to_flatbuffer_test.cc @@ -184,6 +184,8 @@ TEST(ConversionTest, MtkNeuronSettings) { const std::string kCompileOptions = "TEST_COMPILE_OPTIONS"; const std::string kAcceleratorName = "TEST_ACCELERATOR_NAME"; const std::string kNeuronConfigPath = "TEST_NEURON_CONFIG_PATH"; + const int32_t kInferenceDeadlineMs = 1337; + const int32_t kInferenceAbortTimeMs = 42; // Create the proto settings. proto::TFLiteSettings input_settings; @@ -198,6 +200,8 @@ TEST(ConversionTest, MtkNeuronSettings) { mtk_neuron_settings->add_compile_options(kCompileOptions); mtk_neuron_settings->add_accelerator_names(kAcceleratorName); mtk_neuron_settings->set_neuron_config_path(kNeuronConfigPath); + mtk_neuron_settings->set_inference_deadline_ms(kInferenceDeadlineMs); + mtk_neuron_settings->set_inference_abort_time_ms(kInferenceAbortTimeMs); flatbuffers::FlatBufferBuilder flatbuffers_builder; // Convert. @@ -231,6 +235,10 @@ TEST(ConversionTest, MtkNeuronSettings) { kAcceleratorName); EXPECT_EQ(output_mtk_neuron_settings->neuron_config_path()->str(), kNeuronConfigPath); + EXPECT_EQ(output_mtk_neuron_settings->inference_deadline_ms(), + kInferenceDeadlineMs); + EXPECT_EQ(output_mtk_neuron_settings->inference_abort_time_ms(), + kInferenceAbortTimeMs); } } // namespace diff --git a/tensorflow/lite/acceleration/configuration/testdata/configuration.proto_prev b/tensorflow/lite/acceleration/configuration/testdata/configuration.proto_prev index 98adc6d31f05b9..569042d3c88e7b 100644 --- a/tensorflow/lite/acceleration/configuration/testdata/configuration.proto_prev +++ b/tensorflow/lite/acceleration/configuration/testdata/configuration.proto_prev @@ -310,11 +310,27 @@ enum XNNPackFlags { TFLITE_XNNPACK_DELEGATE_FLAG_QS8_QU8 = 3; // Force 16-bit floating point inference. TFLITE_XNNPACK_DELEGATE_FLAG_FORCE_FP16 = 4; + // Enable XNNPACK acceleration for FULLY_CONNECTED operator with dynamic + // weights. + TFLITE_XNNPACK_DELEGATE_FLAG_DYNAMIC_FULLY_CONNECTED = 8; + // Enable XNNPACK acceleration for VAR_HANDLE, READ_VARIABLE, and + // ASSIGN_VARIABLE operators. + TFLITE_XNNPACK_DELEGATE_FLAG_VARIABLE_OPERATORS = 16; + // Enable transient indirection buffer to reduce memory usage in selected + // operators. + TFLITE_XNNPACK_DELEGATE_FLAG_TRANSIENT_INDIRECTION_BUFFER = 32; + // Enable the latest XNNPACK operators and features in the delegate which have + // not yet been enabled by default. + TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_LATEST_OPERATORS = 64; + // Enable XNNPack subgraph reshaping. + TFLITE_XNNPACK_DELEGATE_FLAG_ENABLE_SUBGRAPH_RESHAPING = 128; } message XNNPackSettings { optional int32 num_threads = 1; - optional XNNPackFlags flags = 2 [default = TFLITE_XNNPACK_DELEGATE_NO_FLAGS]; + // If flags is unset or zero, it means use the default XNNPack delegate flags. + // Any other value means use exactly (and only) the flags specified. + optional XNNPackFlags flags = 2; // Path to the XNNPack cache file. XNNPack packed buffers are saved to and // reloaded from this cache which can reduce initialization time and the // packing memory footprint. diff --git a/tensorflow/lite/allocation_test.cc b/tensorflow/lite/allocation_test.cc index c17f52f25a338b..fcc6e9ccc54e67 100644 --- a/tensorflow/lite/allocation_test.cc +++ b/tensorflow/lite/allocation_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/allocation.h" +#include "tensorflow/compiler/mlir/lite/allocation.h" #if defined(__linux__) #include diff --git a/tensorflow/lite/arena_planner_subgraph_test.cc b/tensorflow/lite/arena_planner_subgraph_test.cc index 41769abb03e2c6..cee6b516e3c088 100644 --- a/tensorflow/lite/arena_planner_subgraph_test.cc +++ b/tensorflow/lite/arena_planner_subgraph_test.cc @@ -15,6 +15,9 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/subgraph_test_util.h" namespace tflite { diff --git a/tensorflow/lite/array.cc b/tensorflow/lite/array.cc index 1b1ff2e4557537..21d704a76c4232 100644 --- a/tensorflow/lite/array.cc +++ b/tensorflow/lite/array.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/lite/array.h" +#include "tensorflow/lite/c/common.h" + namespace tflite { namespace array_internal { diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl index 8ebe89096a8128..a1998987547a5f 100644 --- a/tensorflow/lite/build_def.bzl +++ b/tensorflow/lite/build_def.bzl @@ -1,7 +1,7 @@ """Build macros for TF Lite.""" load("//tensorflow:strict.default.bzl", "py_strict_test") -load("//tensorflow:tensorflow.bzl", "clean_dep", "if_oss", "tf_binary_additional_srcs", "tf_cc_shared_object") +load("//tensorflow:tensorflow.bzl", "if_oss", "tf_binary_additional_srcs", "tf_cc_shared_object") load("//tensorflow/lite:special_rules.bzl", "tflite_copts_extra") load("//tensorflow/lite/java:aar_with_jni.bzl", "aar_with_jni") load("@build_bazel_rules_android//android:rules.bzl", "android_library") @@ -11,6 +11,17 @@ load("@bazel_skylib//rules:build_test.bzl", "build_test") def register_extension_info(**kwargs): pass +def clean_dep(target): + """Returns string to 'target' in @litert repository. + + Use this function when referring to targets in the @litert + repository from macros that may be called from external repositories. + """ + + # A repo-relative label is resolved relative to the file in which the + # Label() call appears, i.e. @tsl. + return str(Label(target)) + def tflite_copts(): """Defines common compile time flags for TFLite libraries.""" copts = [ @@ -211,7 +222,8 @@ def tflite_jni_binary( tags = [], srcs = [], visibility = None, # 'None' means use the default visibility. - local_defines = []): + local_defines = [], + exec_properties = {}): """Builds a jni binary for TFLite.""" linkopts = linkopts + select({ clean_dep("//tensorflow:macos"): [ @@ -239,6 +251,7 @@ def tflite_jni_binary( testonly = testonly, visibility = visibility, local_defines = local_defines, + exec_properties = exec_properties, ) def tflite_cc_shared_object( diff --git a/tensorflow/lite/c/BUILD b/tensorflow/lite/c/BUILD index f1664849f36e50..db826d47cb40df 100644 --- a/tensorflow/lite/c/BUILD +++ b/tensorflow/lite/c/BUILD @@ -403,6 +403,7 @@ cc_library( # For use with library targets that can't use relative paths. # LINT.IfChange(exported_headers) exports_files([ + "builtin_op_data.h", "c_api.h", "c_api_experimental.h", "c_api_opaque.h", @@ -413,6 +414,7 @@ exports_files([ filegroup( name = "tensorflowlite_c_api_hdrs_filegroup", srcs = [ + "builtin_op_data.h", "c_api.h", "c_api_types.h", "common.h", diff --git a/tensorflow/lite/cmake/DownloadFP16.cmake b/tensorflow/lite/cmake/DownloadFP16.cmake new file mode 100644 index 00000000000000..e2a8aea65332d0 --- /dev/null +++ b/tensorflow/lite/cmake/DownloadFP16.cmake @@ -0,0 +1,35 @@ +# +# Copyright 2024 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +CMAKE_MINIMUM_REQUIRED(VERSION 3.5 FATAL_ERROR) + +PROJECT(fp16-download NONE) + +# Set file timestamps to the time of extraction. +IF(POLICY CMP0135) + CMAKE_POLICY(SET CMP0135 NEW) +ENDIF() + +INCLUDE(ExternalProject) +ExternalProject_Add(fp16 + URL https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip + URL_HASH SHA256=e66e65515fa09927b348d3d584c68be4215cfe664100d01c9dbc7655a5716d70 + SOURCE_DIR "${CMAKE_BINARY_DIR}/FP16-source" + BINARY_DIR "${CMAKE_BINARY_DIR}/FP16" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" +) diff --git a/tensorflow/lite/core/BUILD b/tensorflow/lite/core/BUILD index 0f23c4b599e407..48b296170f00df 100644 --- a/tensorflow/lite/core/BUILD +++ b/tensorflow/lite/core/BUILD @@ -164,10 +164,7 @@ cc_library( "signature_runner.h", ], compatible_with = get_compatible_with_portable(), - visibility = [ - "//research/drishti/benchmarking/async:__subpackages__", - "//tensorflow/lite:__subpackages__", - ], + visibility = ["//tensorflow/lite:__subpackages__"], deps = [ ":model_builder", ":signature_runner", @@ -306,7 +303,6 @@ cc_library( compatible_with = get_compatible_with_portable(), visibility = [ "//tensorflow/lite:__pkg__", - "//tensorflow/lite/c:__subpackages__", "//tensorflow/lite/core:__subpackages__", ], deps = [ diff --git a/tensorflow/lite/core/acceleration/configuration/BUILD b/tensorflow/lite/core/acceleration/configuration/BUILD index cd2c147603d710..2ddd858c7f14bc 100644 --- a/tensorflow/lite/core/acceleration/configuration/BUILD +++ b/tensorflow/lite/core/acceleration/configuration/BUILD @@ -19,6 +19,7 @@ cc_library( deps = [ "//tensorflow/lite/acceleration/configuration:configuration_fbs", "//tensorflow/lite/core/c:common", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/synchronization", ], ) @@ -38,6 +39,7 @@ cc_library( "//tensorflow/lite/core/c:common", "//tensorflow/lite/delegates/nnapi:nnapi_delegate", "//tensorflow/lite/nnapi:nnapi_implementation_headers", + "//tensorflow/lite/nnapi:nnapi_lib", "@com_google_absl//absl/memory", ], alwayslink = 1, # For registration to always run. @@ -61,6 +63,9 @@ cc_test( "//tensorflow/lite/delegates/nnapi:nnapi_delegate", "//tensorflow/lite/delegates/nnapi:nnapi_delegate_mock_test", "//tensorflow/lite/kernels:test_util", + "//tensorflow/lite/nnapi:nnapi_implementation_headers", + "//tensorflow/lite/nnapi:nnapi_lib", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest_main", "@flatbuffers", ], @@ -76,6 +81,7 @@ cc_library( deps = [ "//tensorflow/lite/core/acceleration/configuration/c:stable_delegate", "//tensorflow/lite/core/shims:tflite_use_opaque_delegate", # buildcleaner: keep + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/synchronization", ], ) @@ -85,6 +91,7 @@ cc_test( srcs = ["stable_delegate_registry_test.cc"], deps = [ ":stable_delegate_registry", + "//tensorflow/lite/core/acceleration/configuration/c:stable_delegate", "@com_google_googletest//:gtest_main", ], ) @@ -99,7 +106,9 @@ cc_library( deps = [ "//tensorflow/lite:minimal_logging", "//tensorflow/lite/acceleration/configuration:configuration_fbs", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core/acceleration/configuration:delegate_registry", + "//tensorflow/lite/core/c:common", "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", "@com_google_absl//absl/base:log_severity", "@com_google_absl//absl/memory", diff --git a/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.cc b/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.cc index f05dae0a9a373f..1133b1b69c0e84 100644 --- a/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.cc +++ b/tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.cc @@ -17,8 +17,6 @@ limitations under the License. #include "tensorflow/lite/core/acceleration/configuration/c/xnnpack_plugin.h" -#include - #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" #include "tensorflow/lite/core/acceleration/configuration/c/delegate_plugin.h" #include "tensorflow/lite/core/c/common.h" @@ -30,6 +28,8 @@ static TfLiteDelegate* CreateDelegate(const void* settings) { const ::tflite::TFLiteSettings* tflite_settings = static_cast(settings); auto options(TfLiteXNNPackDelegateOptionsDefault()); + // The following code block is duplicated in the C++ XNNPack delegate plugin. + // LINT.IfChange(tflite_settings_to_xnnpack_delegate_options) const auto* xnnpack_settings = tflite_settings->xnnpack_settings(); if (xnnpack_settings) { options.num_threads = xnnpack_settings->num_threads(); @@ -45,6 +45,7 @@ static TfLiteDelegate* CreateDelegate(const void* settings) { xnnpack_settings->weight_cache_file_path()->c_str(); } } + // LINT.ThenChange(../xnnpack_plugin.cc:tflite_settings_to_xnnpack_delegate_options) return TfLiteXNNPackDelegateCreate(&options); } diff --git a/tensorflow/lite/core/acceleration/configuration/delegate_registry.cc b/tensorflow/lite/core/acceleration/configuration/delegate_registry.cc index b28759b6af77b9..71ee43e2b5f935 100644 --- a/tensorflow/lite/core/acceleration/configuration/delegate_registry.cc +++ b/tensorflow/lite/core/acceleration/configuration/delegate_registry.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/synchronization/mutex.h" +#include "tensorflow/lite/acceleration/configuration/configuration_generated.h" namespace tflite { namespace delegates { diff --git a/tensorflow/lite/core/acceleration/configuration/delegate_registry.h b/tensorflow/lite/core/acceleration/configuration/delegate_registry.h index e3dc41e5dd707f..742e74389927a5 100644 --- a/tensorflow/lite/core/acceleration/configuration/delegate_registry.h +++ b/tensorflow/lite/core/acceleration/configuration/delegate_registry.h @@ -27,6 +27,7 @@ limitations under the License. #include #include +#include "absl/base/thread_annotations.h" #include "absl/synchronization/mutex.h" #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" #include "tensorflow/lite/core/c/common.h" diff --git a/tensorflow/lite/core/acceleration/configuration/nnapi_plugin.cc b/tensorflow/lite/core/acceleration/configuration/nnapi_plugin.cc index 2adfaa5b2ae1ff..34dd7bbed229f6 100644 --- a/tensorflow/lite/core/acceleration/configuration/nnapi_plugin.cc +++ b/tensorflow/lite/core/acceleration/configuration/nnapi_plugin.cc @@ -17,6 +17,8 @@ limitations under the License. #include "tensorflow/lite/core/acceleration/configuration/nnapi_plugin.h" +#include "tensorflow/lite/core/acceleration/configuration/delegate_registry.h" + namespace tflite { namespace delegates { diff --git a/tensorflow/lite/core/acceleration/configuration/nnapi_plugin.h b/tensorflow/lite/core/acceleration/configuration/nnapi_plugin.h index 03721d73a98df4..8b86801be3d28c 100644 --- a/tensorflow/lite/core/acceleration/configuration/nnapi_plugin.h +++ b/tensorflow/lite/core/acceleration/configuration/nnapi_plugin.h @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/lite/core/acceleration/configuration/delegate_registry.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h" +#include "tensorflow/lite/nnapi/NeuralNetworksTypes.h" #include "tensorflow/lite/nnapi/nnapi_implementation.h" namespace tflite { diff --git a/tensorflow/lite/core/acceleration/configuration/nnapi_plugin_test.cc b/tensorflow/lite/core/acceleration/configuration/nnapi_plugin_test.cc index 75179c020dc3cf..57a3042737600a 100644 --- a/tensorflow/lite/core/acceleration/configuration/nnapi_plugin_test.cc +++ b/tensorflow/lite/core/acceleration/configuration/nnapi_plugin_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include +#include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers -#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" #include "tensorflow/lite/core/acceleration/configuration/delegate_registry.h" #include "tensorflow/lite/core/c/common.h" @@ -29,6 +29,9 @@ limitations under the License. #include "tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h" #include "tensorflow/lite/delegates/nnapi/nnapi_delegate_mock_test.h" #include "tensorflow/lite/kernels/test_util.h" +#include "tensorflow/lite/nnapi/NeuralNetworksTypes.h" +#include "tensorflow/lite/nnapi/nnapi_implementation.h" +#include "tensorflow/lite/schema/schema_generated.h" // Tests for checking that the NNAPI Delegate plugin correctly handles all the // options from the flatbuffer. diff --git a/tensorflow/lite/core/acceleration/configuration/stable_delegate_registry.cc b/tensorflow/lite/core/acceleration/configuration/stable_delegate_registry.cc index e5203762d2affa..87284f3bcfe074 100644 --- a/tensorflow/lite/core/acceleration/configuration/stable_delegate_registry.cc +++ b/tensorflow/lite/core/acceleration/configuration/stable_delegate_registry.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "absl/synchronization/mutex.h" +#include "tensorflow/lite/core/acceleration/configuration/c/stable_delegate.h" namespace tflite { namespace delegates { diff --git a/tensorflow/lite/core/acceleration/configuration/stable_delegate_registry.h b/tensorflow/lite/core/acceleration/configuration/stable_delegate_registry.h index ede67164500795..25ac647290fb49 100644 --- a/tensorflow/lite/core/acceleration/configuration/stable_delegate_registry.h +++ b/tensorflow/lite/core/acceleration/configuration/stable_delegate_registry.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/base/thread_annotations.h" #include "absl/synchronization/mutex.h" #include "tensorflow/lite/core/acceleration/configuration/c/stable_delegate.h" diff --git a/tensorflow/lite/core/acceleration/configuration/stable_delegate_registry_test.cc b/tensorflow/lite/core/acceleration/configuration/stable_delegate_registry_test.cc index a3b6725599e33d..c3a8335345ebdf 100644 --- a/tensorflow/lite/core/acceleration/configuration/stable_delegate_registry_test.cc +++ b/tensorflow/lite/core/acceleration/configuration/stable_delegate_registry_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/core/acceleration/configuration/stable_delegate_registry.h" #include +#include "tensorflow/lite/core/acceleration/configuration/c/stable_delegate.h" namespace { diff --git a/tensorflow/lite/core/acceleration/configuration/xnnpack_plugin.cc b/tensorflow/lite/core/acceleration/configuration/xnnpack_plugin.cc index 81566bcd7f186c..ffd2566e3248c6 100644 --- a/tensorflow/lite/core/acceleration/configuration/xnnpack_plugin.cc +++ b/tensorflow/lite/core/acceleration/configuration/xnnpack_plugin.cc @@ -14,9 +14,10 @@ limitations under the License. ==============================================================================*/ #include -#include "absl/memory/memory.h" #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/acceleration/configuration/delegate_registry.h" +#include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" namespace tflite { @@ -34,11 +35,23 @@ class XNNPackPlugin : public DelegatePluginInterface { } explicit XNNPackPlugin(const TFLiteSettings& tflite_settings) : options_(TfLiteXNNPackDelegateOptionsDefault()) { + // LINT.IfChange(tflite_settings_to_xnnpack_delegate_options) const auto* xnnpack_settings = tflite_settings.xnnpack_settings(); if (xnnpack_settings) { options_.num_threads = xnnpack_settings->num_threads(); - options_.flags = xnnpack_settings->flags(); + // If xnnpack_settings->flags is zero, then leave options.flags + // unmodified, i.e. use the default flags (not zero). + // If xnnpack_settings->flags is nonzero, then use exactly + // those flags (i.e. discard the default flags). + if (xnnpack_settings->flags()) { + options_.flags = xnnpack_settings->flags(); + } + if (xnnpack_settings->weight_cache_file_path()) { + options_.weight_cache_file_path = + xnnpack_settings->weight_cache_file_path()->c_str(); + } } + // LINT.ThenChange(c/xnnpack_plugin.cc:tflite_settings_to_xnnpack_delegate_options) } private: diff --git a/tensorflow/lite/core/acceleration/configuration/xnnpack_plugin_test.cc b/tensorflow/lite/core/acceleration/configuration/xnnpack_plugin_test.cc index 2aa1d95a44f10d..cd8fd91a56fa37 100644 --- a/tensorflow/lite/core/acceleration/configuration/xnnpack_plugin_test.cc +++ b/tensorflow/lite/core/acceleration/configuration/xnnpack_plugin_test.cc @@ -15,9 +15,11 @@ limitations under the License. // Some very simple unit tests of the (C++) XNNPack Delegate Plugin. -#include +#include + #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "pthreadpool.h" // from @pthreadpool #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" #include "tensorflow/lite/core/acceleration/configuration/delegate_registry.h" @@ -28,20 +30,16 @@ namespace tflite { class XnnpackPluginTest : public testing::Test { public: static constexpr int kNumThreadsForTest = 7; - static constexpr tflite::XNNPackFlags kFlagsForTest = - tflite::XNNPackFlags::XNNPackFlags_TFLITE_XNNPACK_DELEGATE_FLAG_QS8_QU8; void SetUp() override { // Construct a FlatBuffer that contains // TFLiteSettings { // delegate: Delegate.XNNPACK, - // XNNPackSettings { num_threads: kNumThreadsForTest - // flags: TFLITE_XNNPACK_DELEGATE_FLAG_QS8 | - // TFLITE_XNNPACK_DELEGATE_FLAG_QU8 + // XNNPackSettings { + // num_threads: kNumThreadsForTest // } // }. XNNPackSettingsBuilder xnnpack_settings_builder(flatbuffer_builder_); xnnpack_settings_builder.add_num_threads(kNumThreadsForTest); - xnnpack_settings_builder.add_flags(kFlagsForTest); flatbuffers::Offset xnnpack_settings = xnnpack_settings_builder.Finish(); TFLiteSettingsBuilder tflite_settings_builder(flatbuffer_builder_); @@ -58,7 +56,7 @@ class XnnpackPluginTest : public testing::Test { ASSERT_NE(delegate_plugin_, nullptr); } void TearDown() override { delegate_plugin_.reset(); } - ~XnnpackPluginTest() override {} + ~XnnpackPluginTest() override = default; protected: // settings_ points into storage owned by flatbuffer_builder_. @@ -88,4 +86,88 @@ TEST_F(XnnpackPluginTest, SetsCorrectThreadCount) { EXPECT_EQ(thread_count, kNumThreadsForTest); } +TEST_F(XnnpackPluginTest, UsesDefaultFlagsByDefault) { + delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create(); + int flags = TfLiteXNNPackDelegateGetFlags(delegate.get()); + EXPECT_EQ(flags, TfLiteXNNPackDelegateOptionsDefault().flags); +} + +TEST_F(XnnpackPluginTest, UsesSpecifiedFlagsWhenNonzero) { + XNNPackSettingsBuilder xnnpack_settings_builder(flatbuffer_builder_); + xnnpack_settings_builder.add_flags( + tflite::XNNPackFlags_TFLITE_XNNPACK_DELEGATE_FLAG_QS8); + flatbuffers::Offset xnnpack_settings = + xnnpack_settings_builder.Finish(); + TFLiteSettingsBuilder tflite_settings_builder(flatbuffer_builder_); + tflite_settings_builder.add_xnnpack_settings(xnnpack_settings); + flatbuffers::Offset tflite_settings = + tflite_settings_builder.Finish(); + flatbuffer_builder_.Finish(tflite_settings); + tflite_settings_ = flatbuffers::GetRoot( + flatbuffer_builder_.GetBufferPointer()); + delegate_plugin_ = delegates::DelegatePluginRegistry::CreateByName( + "XNNPackPlugin", *tflite_settings_); + + delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create(); + int flags = TfLiteXNNPackDelegateGetFlags(delegate.get()); + EXPECT_EQ(flags, tflite::XNNPackFlags_TFLITE_XNNPACK_DELEGATE_FLAG_QS8); +} + +// Settings flags to XNNPackFlags_TFLITE_XNNPACK_DELEGATE_NO_FLAGS (zero) +// causes flags to be set to their default values, not zero. +// This is potentially confusing behaviour, but we can't distinguish +// the case when flags isn't set from the case when flags is set to zero. +TEST_F(XnnpackPluginTest, UsesDefaultFlagsWhenZero) { + XNNPackSettingsBuilder xnnpack_settings_builder(flatbuffer_builder_); + xnnpack_settings_builder.add_flags( + tflite::XNNPackFlags_TFLITE_XNNPACK_DELEGATE_NO_FLAGS); + flatbuffers::Offset xnnpack_settings = + xnnpack_settings_builder.Finish(); + TFLiteSettingsBuilder tflite_settings_builder(flatbuffer_builder_); + tflite_settings_builder.add_xnnpack_settings(xnnpack_settings); + flatbuffers::Offset tflite_settings = + tflite_settings_builder.Finish(); + flatbuffer_builder_.Finish(tflite_settings); + tflite_settings_ = flatbuffers::GetRoot( + flatbuffer_builder_.GetBufferPointer()); + delegate_plugin_ = delegates::DelegatePluginRegistry::CreateByName( + "XNNPackPlugin", *tflite_settings_); + + delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create(); + int flags = TfLiteXNNPackDelegateGetFlags(delegate.get()); + EXPECT_EQ(flags, TfLiteXNNPackDelegateOptionsDefault().flags); +} + +TEST_F(XnnpackPluginTest, DoesNotSetWeightCacheFilePathByDefault) { + delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create(); + const TfLiteXNNPackDelegateOptions *options = + TfLiteXNNPackDelegateGetOptions(delegate.get()); + EXPECT_EQ(options->weight_cache_file_path, nullptr); +} + +TEST_F(XnnpackPluginTest, HonoursWeightCacheFilePathSetting) { + const char *const kWeightCachePath = "/tmp/wcfp"; + const auto weight_cache_file_path_string = + flatbuffer_builder_.CreateString(kWeightCachePath); + XNNPackSettingsBuilder xnnpack_settings_builder(flatbuffer_builder_); + xnnpack_settings_builder.add_weight_cache_file_path( + weight_cache_file_path_string); + flatbuffers::Offset xnnpack_settings = + xnnpack_settings_builder.Finish(); + TFLiteSettingsBuilder tflite_settings_builder(flatbuffer_builder_); + tflite_settings_builder.add_xnnpack_settings(xnnpack_settings); + flatbuffers::Offset tflite_settings = + tflite_settings_builder.Finish(); + flatbuffer_builder_.Finish(tflite_settings); + tflite_settings_ = flatbuffers::GetRoot( + flatbuffer_builder_.GetBufferPointer()); + delegate_plugin_ = delegates::DelegatePluginRegistry::CreateByName( + "XNNPackPlugin", *tflite_settings_); + + delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create(); + const TfLiteXNNPackDelegateOptions *options = + TfLiteXNNPackDelegateGetOptions(delegate.get()); + EXPECT_STREQ(options->weight_cache_file_path, kWeightCachePath); +} + } // namespace tflite diff --git a/tensorflow/lite/core/c/BUILD b/tensorflow/lite/core/c/BUILD index 6e0066185483ff..c2b1715dba4934 100644 --- a/tensorflow/lite/core/c/BUILD +++ b/tensorflow/lite/core/c/BUILD @@ -15,7 +15,6 @@ load( package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ - "//tensorflow/compiler/mlir/lite/experimental/lrt:__subpackages__", "//tensorflow/lite:__subpackages__", ], licenses = ["notice"], @@ -45,6 +44,7 @@ bzl_library( filegroup( name = "headers_filegroup", srcs = [ + "builtin_op_data.h", "c_api.h", "c_api_types.h", "common.h", @@ -55,7 +55,6 @@ filegroup( filegroup( name = "tflite_internal_cc_3p_api_deps_src", srcs = [ - "builtin_op_data.h", "common.cc", "common.h", ], @@ -335,10 +334,7 @@ tflite_cc_library_with_c_headers_test( ], compatible_with = get_compatible_with_portable(), copts = tflite_copts(), - visibility = [ - "//tensorflow/lite:__subpackages__", - "@org_tensorflow_lite_support//tensorflow_lite_support/custom_ops:__subpackages__", - ] + common_header_visibility_allowlist(), + visibility = ["//tensorflow/lite:__subpackages__"] + common_header_visibility_allowlist(), deps = [ ":c_api_types", "//tensorflow/lite:tflite_kernel_use_xnnpack_optional", diff --git a/tensorflow/lite/core/c/c_api_types.h b/tensorflow/lite/core/c/c_api_types.h index f0b76bde0258cb..79a00319709300 100644 --- a/tensorflow/lite/core/c/c_api_types.h +++ b/tensorflow/lite/core/c/c_api_types.h @@ -110,6 +110,13 @@ typedef enum TfLiteStatus { // TODO(b/250636993): Cancellation triggered by `SetCancellationFunction` // should also return this status code. kTfLiteCancelled = 8, + + // This status is returned by Prepare when the output shape cannot be + // determined but the size of the output tensor is known. For example, the + // output of reshape is always the same size as the input. This means that + // such ops may be + // done in place. + kTfLiteOutputShapeNotKnown = 9, } TfLiteStatus; /// Types supported by tensor diff --git a/tensorflow/lite/core/c/common.cc b/tensorflow/lite/core/c/common.cc index d2cb82199bfce0..09e71d578f4c25 100644 --- a/tensorflow/lite/core/c/common.cc +++ b/tensorflow/lite/core/c/common.cc @@ -295,7 +295,8 @@ TfLiteStatus TfLiteTensorResizeMaybeCopy(size_t num_bytes, TfLiteTensor* tensor, #ifdef TF_LITE_TENSORFLOW_PROFILER tflite::PauseHeapMonitoring(/*pause=*/true); #endif - size_t alloc_bytes = num_bytes; + // This buffer may be consumed by XNNPack. + size_t alloc_bytes = num_bytes + /*XNN_EXTRA_BYTES=*/16; // TODO(b/145340303): Tensor data should be aligned. if (!tensor->data.data) { tensor->data.data = (char*)malloc(alloc_bytes); diff --git a/tensorflow/lite/core/interpreter.cc b/tensorflow/lite/core/interpreter.cc index 4609fac67f5524..00f1a93ec6567e 100644 --- a/tensorflow/lite/core/interpreter.cc +++ b/tensorflow/lite/core/interpreter.cc @@ -521,7 +521,7 @@ void Interpreter::AddProfiler(std::unique_ptr profiler) { } impl::SignatureRunner* Interpreter::GetSignatureRunner( - const char* signature_key_, bool apply_default_delegates) { + const char* signature_key_) { auto [signature_key, empty_signature_fallback] = ReplaceWithPlaceholderSignatureKeyIfNeeded(signature_key_); if (!signature_key) { @@ -533,13 +533,11 @@ impl::SignatureRunner* Interpreter::GetSignatureRunner( return &(iter->second); } - if (apply_default_delegates) { - // Default delegates are applied once for all subgraphs. Only returns error - // when the status is kTfLiteError. For other statuses, it will fall back to - // the default implementation. - if (ApplyLazyDelegateProviders() == kTfLiteError) { - return nullptr; - } + // Default delegates are applied once for all subgraphs. Only returns error + // when the status is kTfLiteError. For other statuses, it will fall back to + // the default implementation. + if (ApplyLazyDelegateProviders() == kTfLiteError) { + return nullptr; } if (empty_signature_fallback) { diff --git a/tensorflow/lite/core/interpreter.h b/tensorflow/lite/core/interpreter.h index f413d2b2f125d1..feb4995b0e96f4 100644 --- a/tensorflow/lite/core/interpreter.h +++ b/tensorflow/lite/core/interpreter.h @@ -341,13 +341,12 @@ class Interpreter { /// given signature_key is not valid. Note, the returned SignatureRunner /// instance is owned by and has the same lifetime as the Interpreter object; /// additionally, class SignatureRunner is *not* thread-safe. - /// This function will additionally apply default delegates unless - /// `apply_default_delegate` is set to false. /// If you need to specify delegates, you have to do that before calling this - /// function or provide `apply_default_delegate` as false and applying - /// delegates later. - SignatureRunner* GetSignatureRunner(const char* signature_key, - bool apply_default_delegate = true); + /// function. This function will additionally apply default delegates. Thus, + /// applying delegates after that might lead to undesirable behaviors. + /// If you need `SignatureRunner` without applying default delegates, + /// use `BuiltinOpResolverWithoutDefaultDelegates`. + SignatureRunner* GetSignatureRunner(const char* signature_key); /// \warning Experimental interface, subject to change. \n \brief Returns a /// pointer to the AsyncSignatureRunner instance to run the part of the graph diff --git a/tensorflow/lite/core/shims/cc_library_with_tflite.bzl b/tensorflow/lite/core/shims/cc_library_with_tflite.bzl index a8483fb57bb974..320e5a009222ec 100644 --- a/tensorflow/lite/core/shims/cc_library_with_tflite.bzl +++ b/tensorflow/lite/core/shims/cc_library_with_tflite.bzl @@ -3,9 +3,9 @@ load("@bazel_skylib//rules:build_test.bzl", "build_test") load("@build_bazel_rules_android//android:rules.bzl", "android_binary", "android_library") load("@rules_java//java:defs.bzl", "java_library", "java_test") -load("//tensorflow:tensorflow.bzl", "clean_dep") load( "//tensorflow/lite:build_def.bzl", + "clean_dep", "tflite_copts_warnings", "tflite_custom_c_library", "tflite_jni_binary", diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc index 0c5d20d0c02d34..dbd250364a3d82 100644 --- a/tensorflow/lite/core/subgraph.cc +++ b/tensorflow/lite/core/subgraph.cc @@ -1506,7 +1506,8 @@ TfLiteStatus Subgraph::PrepareOpsStartingAt( node_index); #endif // TF_LITE_TENSORFLOW_PROFILER const TfLiteStatus op_prepare_status = OpPrepare(registration, &node); - if (op_prepare_status != kTfLiteOk) { + if (op_prepare_status != kTfLiteOk && + op_prepare_status != kTfLiteOutputShapeNotKnown) { ReportOpError(&context_, node, registration, node_index, "failed to prepare"); return op_prepare_status; @@ -1517,7 +1518,8 @@ TfLiteStatus Subgraph::PrepareOpsStartingAt( // Discontinue if the node has dynamic outputs. Note that we don't // stop for dynamic temporary tensors since they won't affect the // sizes of other tensors in the graph. - if (HasDynamicTensor(context_, node.outputs, &dynamic_tensor_index_)) { + if (HasDynamicTensor(context_, node.outputs, &dynamic_tensor_index_) || + op_prepare_status == kTfLiteOutputShapeNotKnown) { has_dynamic_tensors_ = true; return kTfLiteOk; } diff --git a/tensorflow/lite/core/tools/verifier.cc b/tensorflow/lite/core/tools/verifier.cc index d2e110a1358f81..6dc5647ada468c 100644 --- a/tensorflow/lite/core/tools/verifier.cc +++ b/tensorflow/lite/core/tools/verifier.cc @@ -15,11 +15,14 @@ limitations under the License. #include "tensorflow/lite/core/tools/verifier.h" +#include + #include #include #include #include #include +#include #include "absl/container/flat_hash_set.h" #include "absl/types/optional.h" @@ -57,7 +60,7 @@ void ReportError(ErrorReporter* error_reporter, const char* format, ...) { } // Returns the int32_t value pointed by ptr. -const uint32_t GetIntPtr(const char* ptr) { +uint32_t GetIntPtr(const char* ptr) { #if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && \ __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ return flatbuffers::EndianScalar(*reinterpret_cast(ptr)); diff --git a/tensorflow/lite/core/tools/verifier_internal.cc b/tensorflow/lite/core/tools/verifier_internal.cc index 706d534d6320bf..1f0b537acc5001 100644 --- a/tensorflow/lite/core/tools/verifier_internal.cc +++ b/tensorflow/lite/core/tools/verifier_internal.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/lite/core/tools/verifier_internal.h" +#include +#include + #include "flatbuffers/verifier.h" // from @flatbuffers #include "tensorflow/lite/schema/schema_generated.h" diff --git a/tensorflow/lite/delegates/flex/build_def.bzl b/tensorflow/lite/delegates/flex/build_def.bzl index bd62777f46c9dd..069fd0edca025f 100644 --- a/tensorflow/lite/delegates/flex/build_def.bzl +++ b/tensorflow/lite/delegates/flex/build_def.bzl @@ -3,7 +3,6 @@ load("@build_bazel_rules_android//android:rules.bzl", "android_library") load( "//tensorflow:tensorflow.bzl", - "clean_dep", "if_android", "if_ios", "if_mobile", @@ -17,6 +16,7 @@ load( ) load( "//tensorflow/lite:build_def.bzl", + "clean_dep", "tflite_cc_shared_object", "tflite_copts", "tflite_jni_binary", diff --git a/tensorflow/lite/delegates/flex/test/BUILD b/tensorflow/lite/delegates/flex/test/BUILD index 65467a9a84f903..92f81d68892b69 100644 --- a/tensorflow/lite/delegates/flex/test/BUILD +++ b/tensorflow/lite/delegates/flex/test/BUILD @@ -1,5 +1,6 @@ load("@bazel_skylib//rules:build_test.bzl", "build_test") load("@build_bazel_rules_apple//apple:ios.bzl", "ios_static_framework") +load("@rules_java//java:defs.bzl", "java_library", "java_test") load("//tensorflow/java:build_defs.bzl", "JAVACOPTS") load("//tensorflow/lite/delegates/flex:build_def.bzl", "tflite_flex_jni_library") diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD index d66d66b544a608..2fe82d4df684db 100644 --- a/tensorflow/lite/delegates/gpu/BUILD +++ b/tensorflow/lite/delegates/gpu/BUILD @@ -69,6 +69,11 @@ config_setting( config_setting( name = "tflite_gpu_extra_gles_deps", + # copybara:uncomment_begin(google-only) + # constraint_values = [ + # "//third_party/bazel_platforms/os:linux", + # ], + # copybara:uncomment_end values = { "copt": "-DTFLITE_GPU_EXTRA_GLES_DEPS", "cpu": "k8", diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD index b84cb9a71a46f0..ecc22a1fef7cc4 100644 --- a/tensorflow/lite/delegates/gpu/cl/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/BUILD @@ -351,11 +351,9 @@ cc_library( ":cl_context", ":cl_device", ":program_cache", - ":util", "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:gpu_info", "//tensorflow/lite/delegates/gpu/common:precision", - "//tensorflow/lite/delegates/gpu/common:shape", "//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:tensor", "//tensorflow/lite/delegates/gpu/common/task:tensor_desc", diff --git a/tensorflow/lite/delegates/gpu/cl/environment.cc b/tensorflow/lite/delegates/gpu/cl/environment.cc index ed5b895e4a8164..2ec9c243027896 100644 --- a/tensorflow/lite/delegates/gpu/cl/environment.cc +++ b/tensorflow/lite/delegates/gpu/cl/environment.cc @@ -15,12 +15,16 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/cl/environment.h" -#include #include #include -#include "tensorflow/lite/delegates/gpu/cl/util.h" -#include "tensorflow/lite/delegates/gpu/common/shape.h" +#include "tensorflow/lite/delegates/gpu/cl/cl_command_queue.h" +#include "tensorflow/lite/delegates/gpu/cl/cl_context.h" +#include "tensorflow/lite/delegates/gpu/cl/cl_device.h" +#include "tensorflow/lite/delegates/gpu/common/gpu_info.h" +#include "tensorflow/lite/delegates/gpu/common/precision.h" +#include "tensorflow/lite/delegates/gpu/common/status.h" +#include "tensorflow/lite/delegates/gpu/common/task/tensor_desc.h" namespace tflite { namespace gpu { @@ -252,8 +256,11 @@ bool CanUseSubBufferForImage2d(const GpuInfo& gpu_info) { if (!gpu_info.IsCL11OrHigher()) { return false; } - if (gpu_info.IsPowerVR()) { - // driver issue + if (gpu_info.IsPowerVR() && + gpu_info.powervr_info.driver_version.branch_main <= 23) { + // 24.2@6603887 - works. + // 1.15@6133110 - doesn't work. + // Segfaults, wrong results at model level. return false; } if (gpu_info.IsNvidia()) { diff --git a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h index cbdf46a0b85055..92b5cb12ef2ad3 100644 --- a/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h +++ b/tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h @@ -544,6 +544,18 @@ typedef cl_int(CL_API_CALL *PFN_clEnqueueCommandBufferKHR)( cl_uint /*num_events_in_wait_list*/, const cl_event * /*event_wait_list*/, cl_event * /*event*/); +#if CL_KHR_COMMAND_BUFFER_EXTENSION_VERSION >= CL_MAKE_VERSION(0, 9, 5) +typedef cl_int(CL_API_CALL *PFN_clCommandNDRangeKernelKHR)( + cl_command_buffer_khr /*command_buffer*/, + cl_command_queue /*command_queue*/, + const cl_command_properties_khr * /*properties*/, cl_kernel /*kernel*/, + cl_uint /*work_dim*/, const size_t * /*global_work_offset*/, + const size_t * /*global_work_size*/, const size_t * /*local_work_size*/, + cl_uint /*num_sync_points_in_wait_list*/, + const cl_sync_point_khr * /*sync_point_wait_list*/, + cl_sync_point_khr * /*sync_point*/, + cl_mutable_command_khr * /*mutable_handle*/); +#else typedef cl_int(CL_API_CALL *PFN_clCommandNDRangeKernelKHR)( cl_command_buffer_khr /*command_buffer*/, cl_command_queue /*command_queue*/, @@ -555,6 +567,7 @@ typedef cl_int(CL_API_CALL *PFN_clCommandNDRangeKernelKHR)( const cl_sync_point_khr * /*sync_point_wait_list*/, cl_sync_point_khr * /*sync_point*/, cl_mutable_command_khr * /*mutable_handle*/); +#endif typedef cl_int(CL_API_CALL *PFN_clGetCommandBufferInfoKHR)( cl_command_buffer_khr /*command_buffer*/, diff --git a/tensorflow/lite/delegates/gpu/common/BUILD b/tensorflow/lite/delegates/gpu/common/BUILD index 195124f269c7c8..df4cbb4306790d 100644 --- a/tensorflow/lite/delegates/gpu/common/BUILD +++ b/tensorflow/lite/delegates/gpu/common/BUILD @@ -533,6 +533,8 @@ cc_test( name = "winograd_util_test", srcs = ["winograd_util_test.cc"], deps = [ + ":operations", + ":shape", ":winograd_util", "@com_google_googletest//:gtest_main", ], diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.cc b/tensorflow/lite/delegates/gpu/common/gpu_info.cc index 2627adda13c6bd..0b7bc03136c842 100644 --- a/tensorflow/lite/delegates/gpu/common/gpu_info.cc +++ b/tensorflow/lite/delegates/gpu/common/gpu_info.cc @@ -386,37 +386,46 @@ int AdrenoInfo::GetComputeUnitsCount() const { } AppleInfo::AppleInfo(const std::string& gpu_description) { - const std::map kMapping = { - {"apple a7 gpu", AppleGpu::kA7}, - {"apple a8 gpu", AppleGpu::kA8}, - {"apple a8x gpu", AppleGpu::kA8X}, - {"apple a9 gpu", AppleGpu::kA9}, - {"apple a9x gpu", AppleGpu::kA9X}, - {"apple a10 gpu", AppleGpu::kA10}, - {"apple a10x gpu", AppleGpu::kA10X}, - {"apple a11 gpu", AppleGpu::kA11}, - {"apple a12 gpu", AppleGpu::kA12}, - {"apple a12x gpu", AppleGpu::kA12X}, - {"apple a12z gpu", AppleGpu::kA12Z}, - {"apple a13 gpu", AppleGpu::kA13}, - {"apple a14 gpu", AppleGpu::kA14}, - {"apple a15 gpu", AppleGpu::kA15}, - {"apple a16 gpu", AppleGpu::kA16}, - {"apple a17 pro gpu", AppleGpu::kA17Pro}, - // on tablets we have metal device name "apple m1 gpu" - // and on notebooks "apple m1" - {"apple m1 gpu", AppleGpu::kM1}, + const std::vector> kMapping = { + {"apple a7", AppleGpu::kA7}, + {"apple a8", AppleGpu::kA8}, + {"apple a8x", AppleGpu::kA8X}, + {"apple a9", AppleGpu::kA9}, + {"apple a9x", AppleGpu::kA9X}, + {"apple a10", AppleGpu::kA10}, + {"apple a10x", AppleGpu::kA10X}, + {"apple a11", AppleGpu::kA11}, + {"apple a12", AppleGpu::kA12}, + {"apple a12x", AppleGpu::kA12X}, + {"apple a12z", AppleGpu::kA12Z}, + {"apple a13", AppleGpu::kA13}, + {"apple a14", AppleGpu::kA14}, + {"apple a15", AppleGpu::kA15}, + {"apple a16", AppleGpu::kA16}, + {"apple a17 pro", AppleGpu::kA17Pro}, + {"apple a18", AppleGpu::kA18}, + {"apple a18 pro", AppleGpu::kA18Pro}, {"apple m1", AppleGpu::kM1}, {"apple m1 pro", AppleGpu::kM1Pro}, {"apple m1 max", AppleGpu::kM1Max}, {"apple m1 ultra", AppleGpu::kM1Ultra}, {"apple m2", AppleGpu::kM2}, + {"apple m2 pro", AppleGpu::kM2Pro}, + {"apple m2 max", AppleGpu::kM2Max}, + {"apple m2 ultra", AppleGpu::kM2Ultra}, + {"apple m3", AppleGpu::kM3}, + {"apple m3 pro", AppleGpu::kM3Pro}, + {"apple m3 max", AppleGpu::kM3Max}, + {"apple m4", AppleGpu::kM4}, }; - auto it = kMapping.find(gpu_description); - if (it != kMapping.end()) { - gpu_type = it->second; - } else { - gpu_type = AppleGpu::kUnknown; + gpu_type = AppleGpu::kUnknown; + std::string gpu_name = ""; + for (const auto& v : kMapping) { + if (gpu_description.find(v.first) != std::string::npos && + v.first.size() > gpu_name.size()) { + gpu_name = v.first; + gpu_type = v.second; + } } gpu_family = GetGpuFamily(); } @@ -439,9 +448,10 @@ AppleInfo::Family AppleInfo::GetGpuFamily() const { } else if (gpu_type == AppleGpu::kA14 || IsM1Series()) { return AppleInfo::Family::kApple7; } else if (gpu_type == AppleGpu::kA15 || gpu_type == AppleGpu::kA16 || - gpu_type == AppleGpu::kM2) { + IsM2Series()) { return AppleInfo::Family::kApple8; - } else if (gpu_type == AppleGpu::kA17Pro) { + } else if (gpu_type == AppleGpu::kA17Pro || gpu_type == AppleGpu::kA18 || + gpu_type == AppleGpu::kA18Pro || IsM3Series() || IsM4Series()) { return AppleInfo::Family::kApple9; } return AppleInfo::Family::kApple1; @@ -496,27 +506,28 @@ bool AppleInfo::IsM1Series() const { gpu_type == AppleGpu::kM1Max || gpu_type == AppleGpu::kM1Ultra; } +bool AppleInfo::IsM2Series() const { + return gpu_type == AppleGpu::kM2 || gpu_type == AppleGpu::kM2Pro || + gpu_type == AppleGpu::kM2Max || gpu_type == AppleGpu::kM2Ultra; +} + +bool AppleInfo::IsM3Series() const { + return gpu_type == AppleGpu::kM3 || gpu_type == AppleGpu::kM3Pro || + gpu_type == AppleGpu::kM3Max; +} + +bool AppleInfo::IsM4Series() const { return gpu_type == AppleGpu::kM4; } + bool AppleInfo::IsBionic() const { - return gpu_type == AppleGpu::kA11 || gpu_type == AppleGpu::kA12 || - gpu_type == AppleGpu::kA12X || gpu_type == AppleGpu::kA12Z || - gpu_type == AppleGpu::kA13 || gpu_type == AppleGpu::kA14 || - gpu_type == AppleGpu::kA15 || gpu_type == AppleGpu::kA16 || - gpu_type == AppleGpu::kA17Pro || gpu_type == AppleGpu::kM1 || - gpu_type == AppleGpu::kM1Pro || gpu_type == AppleGpu::kM1Max || - gpu_type == AppleGpu::kM1Ultra || gpu_type == AppleGpu::kM2; + return gpu_family >= AppleInfo::Family::kApple4; } bool AppleInfo::IsSIMDMatMulSupported() const { - return gpu_type == AppleGpu::kA14 || gpu_type == AppleGpu::kA15 || - gpu_type == AppleGpu::kA16 || gpu_type == AppleGpu::kA17Pro || - gpu_type == AppleGpu::kM1 || gpu_type == AppleGpu::kM1Pro || - gpu_type == AppleGpu::kM1Max || gpu_type == AppleGpu::kM1Ultra || - gpu_type == AppleGpu::kM2; + return gpu_family >= AppleInfo::Family::kApple7; } bool AppleInfo::IsSIMDMatMulFp32Perf2x() const { - return gpu_type == AppleGpu::kA15 || gpu_type == AppleGpu::kA16 || - gpu_type == AppleGpu::kA17Pro || gpu_type == AppleGpu::kM2; + return gpu_family >= AppleInfo::Family::kApple8 || IsM1Series(); } bool AppleInfo::IsRoundToNearestSupported() const { return IsBionic(); } @@ -560,6 +571,10 @@ int AppleInfo::GetComputeUnitsCount() const { return 5; case AppleGpu::kA17Pro: return 6; + case AppleGpu::kA18: + return 5; + case AppleGpu::kA18Pro: + return 6; case AppleGpu::kM1: // approximate, can be 7 or 8 return 8; @@ -573,7 +588,28 @@ int AppleInfo::GetComputeUnitsCount() const { // approximate, 64 is max possible return 64; case AppleGpu::kM2: - // approximate, 10 is max possible + // approximate + return 10; + case AppleGpu::kM2Pro: + // approximate + return 19; + case AppleGpu::kM2Max: + // approximate + return 38; + case AppleGpu::kM2Ultra: + // approximate + return 76; + case AppleGpu::kM3: + // approximate + return 10; + case AppleGpu::kM3Pro: + // approximate + return 18; + case AppleGpu::kM3Max: + // approximate + return 40; + case AppleGpu::kM4: + // approximate return 10; case AppleGpu::kUnknown: return 4; diff --git a/tensorflow/lite/delegates/gpu/common/gpu_info.h b/tensorflow/lite/delegates/gpu/common/gpu_info.h index f5d73a2f341e28..c1b4eb6454e4a1 100644 --- a/tensorflow/lite/delegates/gpu/common/gpu_info.h +++ b/tensorflow/lite/delegates/gpu/common/gpu_info.h @@ -175,11 +175,20 @@ enum class AppleGpu { kA15, kA16, kA17Pro, + kA18, + kA18Pro, kM1, kM1Pro, kM1Max, kM1Ultra, kM2, + kM2Pro, + kM2Max, + kM2Ultra, + kM3, + kM3Pro, + kM3Max, + kM4, }; struct AppleInfo { @@ -216,6 +225,9 @@ struct AppleInfo { bool IsBionic() const; bool IsM1Series() const; + bool IsM2Series() const; + bool IsM3Series() const; + bool IsM4Series() const; bool IsSIMDMatMulSupported() const; // Often, fp32 alu performance is 1/2 of fp16 alu performance diff --git a/tensorflow/lite/delegates/gpu/common/tasks/fully_connected.cc b/tensorflow/lite/delegates/gpu/common/tasks/fully_connected.cc index 43eab75e99b91d..ce2152fa29a279 100644 --- a/tensorflow/lite/delegates/gpu/common/tasks/fully_connected.cc +++ b/tensorflow/lite/delegates/gpu/common/tasks/fully_connected.cc @@ -243,8 +243,6 @@ FullyConnected CreateFullyConnected(const GpuInfo& gpu_info, std::move(bias_tensor_desc))); return result; - - return result; } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/winograd_util.cc b/tensorflow/lite/delegates/gpu/common/winograd_util.cc index c499d1e9e3dd0b..3ebaef9a38503e 100644 --- a/tensorflow/lite/delegates/gpu/common/winograd_util.cc +++ b/tensorflow/lite/delegates/gpu/common/winograd_util.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/lite/delegates/gpu/common/data_type.h" +#include "tensorflow/lite/delegates/gpu/common/operations.h" #include "tensorflow/lite/delegates/gpu/common/shape.h" #include "tensorflow/lite/delegates/gpu/common/tensor.h" diff --git a/tensorflow/lite/delegates/gpu/common/winograd_util_test.cc b/tensorflow/lite/delegates/gpu/common/winograd_util_test.cc index 81fb643d399a82..1c694488a33937 100644 --- a/tensorflow/lite/delegates/gpu/common/winograd_util_test.cc +++ b/tensorflow/lite/delegates/gpu/common/winograd_util_test.cc @@ -15,8 +15,9 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/winograd_util.h" -#include #include +#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/shape.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/BUILD b/tensorflow/lite/delegates/gpu/gl/compiler/BUILD index 8b13e0ff92cdbb..795dc219f9037d 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/BUILD +++ b/tensorflow/lite/delegates/gpu/gl/compiler/BUILD @@ -25,6 +25,7 @@ cc_test( ], deps = [ ":preprocessor", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", ], ) @@ -36,6 +37,7 @@ cc_library( deps = [ ":preprocessor", ":variable_accessor", + "//tensorflow/lite/delegates/gpu/common:access_type", "//tensorflow/lite/delegates/gpu/common:data_type", "//tensorflow/lite/delegates/gpu/common:types", "//tensorflow/lite/delegates/gpu/gl:object", @@ -53,8 +55,10 @@ cc_test( ], deps = [ ":object_accessor", + ":preprocessor", ":variable_accessor", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:object", "//tensorflow/lite/delegates/gpu/gl:variable", "@com_google_absl//absl/types:variant", "@com_google_googletest//:gtest_main", @@ -178,6 +182,7 @@ cc_library( "//tensorflow/lite/delegates/gpu/common:model_transformer", "//tensorflow/lite/delegates/gpu/common:operations", "//tensorflow/lite/delegates/gpu/common:types", + "//tensorflow/lite/delegates/gpu/gl:node_shader", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:any", @@ -196,6 +201,9 @@ cc_test( deps = [ ":compiled_node", ":fuse_auto_input", + "//tensorflow/lite/delegates/gpu/common:model", + "//tensorflow/lite/delegates/gpu/common:model_transformer", + "//tensorflow/lite/delegates/gpu/gl:node_shader", "@com_google_absl//absl/types:any", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", @@ -225,6 +233,7 @@ cc_test( "tflite_not_portable_ios", ], deps = [ + ":preprocessor", ":variable_accessor", "//tensorflow/lite/delegates/gpu/common:types", "@com_google_googletest//:gtest_main", diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.cc b/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.cc index 75928dae5f204c..58cf0af1967136 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/rename.h" diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.cc b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.cc index 761fb8b4602246..985da96ebff678 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input.cc @@ -23,11 +23,11 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" #include "absl/types/any.h" -#include "absl/types/variant.h" #include "tensorflow/lite/delegates/gpu/common/model.h" -#include "tensorflow/lite/delegates/gpu/common/operations.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input_test.cc b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input_test.cc index 403617366912ee..61c3114a3a0d88 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_auto_input_test.cc @@ -20,7 +20,10 @@ limitations under the License. #include #include #include "absl/types/any.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inline.cc b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inline.cc index f227ab2147847a..6fae121e02cf36 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inline.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inline.cc @@ -23,12 +23,12 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" #include "absl/types/any.h" +#include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h" -#include "tensorflow/lite/delegates/gpu/gl/compiler/shader_code.h" #include "tensorflow/lite/delegates/gpu/gl/node_shader.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inplace.cc b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inplace.cc index 1e27404b741bde..19e520be166f04 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inplace.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/fuse_inplace.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/any.h" #include "tensorflow/lite/delegates/gpu/common/model.h" +#include "tensorflow/lite/delegates/gpu/common/model_transformer.h" #include "tensorflow/lite/delegates/gpu/common/types.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/compiled_node.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc index 11228fd2efe58b..43e9fa83e4c9e1 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.cc @@ -21,12 +21,16 @@ limitations under the License. #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "absl/types/variant.h" +#include "tensorflow/lite/delegates/gpu/common/access_type.h" #include "tensorflow/lite/delegates/gpu/common/data_type.h" #include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h" +#include "tensorflow/lite/delegates/gpu/gl/object.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h index 318709fe7ff235..74273a6864193e 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h +++ b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h" #include "tensorflow/lite/delegates/gpu/gl/object.h" diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor_test.cc b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor_test.cc index 4bf9482436506a..fbca570d892f2f 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/object_accessor_test.cc @@ -23,7 +23,9 @@ limitations under the License. #include #include "absl/types/variant.h" #include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h" +#include "tensorflow/lite/delegates/gpu/gl/object.h" #include "tensorflow/lite/delegates/gpu/gl/variable.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.cc b/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.cc index 16db8945dece21..173c281e331fcd 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "tensorflow/lite/delegates/gpu/common/status.h" namespace tflite { diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor_test.cc b/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor_test.cc index 95fcf6244606f4..d4b7cf4157916c 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/preprocessor_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/rename.cc b/tensorflow/lite/delegates/gpu/gl/compiler/rename.cc index fd418770104444..1a05bfa2d87050 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/rename.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/rename.cc @@ -21,13 +21,16 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/object_accessor.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h" +#include "tensorflow/lite/delegates/gpu/gl/node_shader.h" #include "tensorflow/lite/delegates/gpu/gl/object.h" #include "tensorflow/lite/delegates/gpu/gl/variable.h" diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.cc b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.cc index b55c480654146f..d1a7fd78e1a87b 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.cc @@ -22,8 +22,11 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/variant.h" #include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" +#include "tensorflow/lite/delegates/gpu/gl/variable.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h index f6d5344d3b345e..0eb01c0ea284f5 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h +++ b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" #include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" #include "tensorflow/lite/delegates/gpu/gl/variable.h" diff --git a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor_test.cc b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor_test.cc index 0e8be2a577ba75..20ac0368c66644 100644 --- a/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor_test.cc +++ b/tensorflow/lite/delegates/gpu/gl/compiler/variable_accessor_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include -#include #include #include "tensorflow/lite/delegates/gpu/common/types.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler/preprocessor.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/gpu/tflite_profile.cc b/tensorflow/lite/delegates/gpu/tflite_profile.cc index f0b95553845db4..babb52c2b7768d 100644 --- a/tensorflow/lite/delegates/gpu/tflite_profile.cc +++ b/tensorflow/lite/delegates/gpu/tflite_profile.cc @@ -16,6 +16,7 @@ limitations under the License. #include "absl/time/time.h" #include "tensorflow/lite/core/api/profiler.h" +#include "tensorflow/lite/delegates/gpu/common/task/profiling_info.h" namespace tflite { namespace gpu { diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api.h b/tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api.h index 62dfdbf9345d83..1ece2fbb74da4d 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api.h +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_c_api.h @@ -23,7 +23,7 @@ extern "C" { // Use TfLiteNnapiDelegateOptionsDefault() for Default options. // WARNING: This is an experimental API and subject to change. -typedef struct TFL_CAPI_EXPORT TfLiteNnapiDelegateOptions { +typedef struct TfLiteNnapiDelegateOptions { // Preferred Power/perf trade-off. For more details please see // ANeuralNetworksCompilation_setPreference documentation in : // https://developer.android.com/ndk/reference/group/neural-networks.html @@ -84,8 +84,8 @@ typedef struct TFL_CAPI_EXPORT TfLiteNnapiDelegateOptions { // Returns a delegate that uses NNAPI for ops execution. // Must outlive the interpreter. // WARNING: This is an experimental API and subject to change. -TfLiteDelegate* TFL_CAPI_EXPORT -TfLiteNnapiDelegateCreate(const TfLiteNnapiDelegateOptions* options); +TFL_CAPI_EXPORT TfLiteDelegate* TfLiteNnapiDelegateCreate( + const TfLiteNnapiDelegateOptions* options); // Returns TfLiteNnapiDelegateOptions populated with default values. // WARNING: This is an experimental API and subject to change. @@ -93,7 +93,7 @@ TFL_CAPI_EXPORT TfLiteNnapiDelegateOptions TfLiteNnapiDelegateOptionsDefault(); // Does any needed cleanup and deletes 'delegate'. // WARNING: This is an experimental API and subject to change. -void TFL_CAPI_EXPORT TfLiteNnapiDelegateDelete(TfLiteDelegate* delegate); +TFL_CAPI_EXPORT void TfLiteNnapiDelegateDelete(TfLiteDelegate* delegate); #ifdef __cplusplus } diff --git a/tensorflow/lite/delegates/utils/experimental/stable_delegate/BUILD b/tensorflow/lite/delegates/utils/experimental/stable_delegate/BUILD index f37fd78e0f613f..208370b318fef6 100644 --- a/tensorflow/lite/delegates/utils/experimental/stable_delegate/BUILD +++ b/tensorflow/lite/delegates/utils/experimental/stable_delegate/BUILD @@ -140,7 +140,6 @@ cc_library( "//tensorflow/lite/kernels:test_util", "//tensorflow/lite/kernels:test_util_delegate_providers", "//tensorflow/lite/testing:util", - "@com_google_benchmark//:benchmark", "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/lite/delegates/utils/experimental/stable_delegate/kernel_test_main.cc b/tensorflow/lite/delegates/utils/experimental/stable_delegate/kernel_test_main.cc index f3fe76d395a79e..65bf981c962383 100644 --- a/tensorflow/lite/delegates/utils/experimental/stable_delegate/kernel_test_main.cc +++ b/tensorflow/lite/delegates/utils/experimental/stable_delegate/kernel_test_main.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include "benchmark/benchmark.h" // from @com_google_benchmark #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/kernels/acceleration_test_util.h" #include "tensorflow/lite/kernels/acceleration_test_util_internal.h" @@ -90,11 +89,11 @@ void ValidateAcceleration(const SingleOpModel& model) { // We only want to check the delegate is working properly, so an error due // to incompatibility between the model and the delegate is not considered a // failure here. - EXPECT_THAT(model.GetDelegateApplicationStatus().value_or(kTfLiteOk), + ASSERT_THAT(model.GetDelegateApplicationStatus().value_or(kTfLiteOk), testing::AnyOf(kTfLiteOk, kTfLiteApplicationError)); return; } else { - EXPECT_EQ(model.GetDelegateApplicationStatus().value_or(kTfLiteOk), + ASSERT_EQ(model.GetDelegateApplicationStatus().value_or(kTfLiteOk), kTfLiteOk); } @@ -108,10 +107,10 @@ void ValidateAcceleration(const SingleOpModel& model) { return; } TFLITE_LOG(INFO) << "Validating acceleration with the stable delegate"; - EXPECT_EQ(model.CountNumberOfDelegatedPartitions(), 1) + ASSERT_GT(num_applied_delegates, 0) << "No delegates were applied."; + ASSERT_EQ(model.CountNumberOfDelegatedPartitions(), 1) << "Expecting operation to be accelerated but cannot find a partition " "associated to the stable delegate"; - EXPECT_GT(num_applied_delegates, 0) << "No delegates were applied."; } bool InitKernelTest(int* argc, char** argv) { @@ -155,7 +154,6 @@ int main(int argc, char** argv) { tflite::LogToStderr(); if (tflite::InitKernelTest(&argc, argv)) { testing::InitGoogleTest(&argc, argv); - benchmark::RunSpecifiedBenchmarks(); int ret = RUN_ALL_TESTS(); tflite::DestroyKernelTest(); return ret; diff --git a/tensorflow/lite/delegates/xnnpack/BUILD b/tensorflow/lite/delegates/xnnpack/BUILD index 02f60403dbc347..8a4cb392ac9621 100644 --- a/tensorflow/lite/delegates/xnnpack/BUILD +++ b/tensorflow/lite/delegates/xnnpack/BUILD @@ -720,8 +720,10 @@ cc_library( srcs = ["quantized_binary_elementwise_tester.cc"], hdrs = ["quantized_binary_elementwise_tester.h"], deps = [ + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core:framework", "//tensorflow/lite/core/c:common", "//tensorflow/lite/core/kernels:builtin_ops", @@ -739,8 +741,10 @@ cc_library( hdrs = ["quantized_conv_2d_tester.h"], deps = [ ":xnnpack_delegate_test_mode", + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core:framework", "//tensorflow/lite/core/c:common", "//tensorflow/lite/core/kernels:builtin_ops", @@ -758,8 +762,10 @@ cc_library( hdrs = ["quantized_depthwise_conv_2d_tester.h"], deps = [ ":xnnpack_delegate_test_mode", + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core:framework", "//tensorflow/lite/core/c:common", "//tensorflow/lite/core/kernels:builtin_ops", @@ -777,8 +783,10 @@ cc_library( hdrs = ["quantized_fully_connected_tester.h"], deps = [ ":xnnpack_delegate_test_mode", + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core:framework", "//tensorflow/lite/core/c:common", "//tensorflow/lite/core/kernels:builtin_ops", @@ -795,6 +803,7 @@ cc_library( srcs = ["quantized_leaky_relu_tester.cc"], hdrs = ["quantized_leaky_relu_tester.h"], deps = [ + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/core:framework", @@ -813,6 +822,7 @@ cc_library( srcs = ["quantized_pad_tester.cc"], hdrs = ["quantized_pad_tester.h"], deps = [ + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/core:framework", @@ -831,26 +841,10 @@ cc_library( srcs = ["quantized_pool_2d_tester.cc"], hdrs = ["quantized_pool_2d_tester.h"], deps = [ + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", - "//tensorflow/lite/core:framework", - "//tensorflow/lite/core/c:common", - "//tensorflow/lite/core/kernels:builtin_ops", - "//tensorflow/lite/schema:schema_conversion_utils", - "//tensorflow/lite/schema:schema_fbs", - "@com_google_googletest//:gtest", - "@flatbuffers", - ], -) - -cc_library( - name = "quantized_reduce_tester", - testonly = 1, - srcs = ["quantized_reduce_tester.cc"], - hdrs = ["quantized_reduce_tester.h"], - deps = [ - "//tensorflow/lite:framework", - "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core:framework", "//tensorflow/lite/core/c:common", "//tensorflow/lite/core/kernels:builtin_ops", @@ -867,6 +861,7 @@ cc_library( srcs = ["quantized_resize_bilinear_tester.cc"], hdrs = ["quantized_resize_bilinear_tester.h"], deps = [ + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/core:framework", @@ -885,6 +880,7 @@ cc_library( srcs = ["quantized_unary_elementwise_tester.cc"], hdrs = ["quantized_unary_elementwise_tester.h"], deps = [ + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/core:framework", @@ -904,8 +900,10 @@ cc_library( hdrs = ["quantized_variable_ops_tester.h"], deps = [ ":xnnpack_delegate_test_mode", + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/core:cc_api_stable", "//tensorflow/lite/core/c:common", "//tensorflow/lite/core/kernels:builtin_ops", "//tensorflow/lite/schema:schema_conversion_utils", @@ -926,8 +924,10 @@ cc_library( hdrs = ["quantized_variable_ops_tester.h"], deps = [ ":xnnpack_delegate", + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/core:cc_api_stable", "//tensorflow/lite/core/c:common", "//tensorflow/lite/core/kernels:builtin_ops", "//tensorflow/lite/schema:schema_conversion_utils", @@ -943,6 +943,7 @@ cc_library( srcs = ["reduce_tester.cc"], hdrs = ["reduce_tester.h"], deps = [ + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/core:framework", @@ -961,6 +962,7 @@ cc_library( srcs = ["reshape_tester.cc"], hdrs = ["reshape_tester.h"], deps = [ + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/core:framework", @@ -979,6 +981,7 @@ cc_library( srcs = ["resize_bilinear_tester.cc"], hdrs = ["resize_bilinear_tester.h"], deps = [ + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/core:framework", @@ -997,8 +1000,10 @@ cc_library( srcs = ["slice_tester.cc"], hdrs = ["slice_tester.h"], deps = [ + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", + "//tensorflow/lite/core:cc_api_stable", "//tensorflow/lite/core/c:common", "//tensorflow/lite/core/kernels:builtin_ops", "//tensorflow/lite/schema:schema_conversion_utils", @@ -1143,6 +1148,7 @@ cc_library( hdrs = ["quantized_transpose_conv_tester.h"], deps = [ ":xnnpack_delegate_test_mode", + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/lite:framework", "//tensorflow/lite:schema_fbs_version", "//tensorflow/lite/core:framework", @@ -1629,8 +1635,8 @@ cc_test( ) cc_test( - name = "mean_test", - srcs = ["mean_test.cc"], + name = "reduce_test", + srcs = ["reduce_test.cc"], linkopts = select({ "//tensorflow:emscripten": EMSCRIPTEN_LINKOPTS, "//conditions:default": [], @@ -1645,21 +1651,6 @@ cc_test( ], ) -cc_test( - name = "sum_test", - srcs = ["sum_test.cc"], - linkopts = select({ - "//tensorflow:emscripten": EMSCRIPTEN_LINKOPTS, - "//conditions:default": [], - }), - deps = [ - ":reduce_tester", - ":test_main", - ":xnnpack_delegate_test_mode", - "@com_google_googletest//:gtest", - ], -) - cc_test( name = "minimum_test", srcs = ["minimum_test.cc"], @@ -1891,6 +1882,8 @@ cc_test( ":test_main", ":unary_elementwise_tester", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -1906,6 +1899,8 @@ cc_test( ":test_main", ":unary_elementwise_tester", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -1921,6 +1916,8 @@ cc_test( ":test_main", ":unary_elementwise_tester", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -1936,6 +1933,8 @@ cc_test( ":reshape_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -1951,6 +1950,8 @@ cc_test( ":reshape_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -1981,6 +1982,7 @@ cc_test( ":resize_bilinear_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", "@com_google_googletest//:gtest", ], ) @@ -1996,6 +1998,8 @@ cc_test( ":test_main", ":unary_elementwise_tester", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2028,6 +2032,7 @@ cc_test( ":dequantize_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", "@com_google_googletest//:gtest", ], ) @@ -2043,6 +2048,8 @@ cc_test( ":quantized_binary_elementwise_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2058,6 +2065,8 @@ cc_test( ":concatenation_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2073,6 +2082,7 @@ cc_test( ":quantized_conv_2d_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", "@com_google_googletest//:gtest", ], ) @@ -2088,6 +2098,7 @@ cc_test( ":quantized_depthwise_conv_2d_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", "@com_google_googletest//:gtest", ], ) @@ -2103,6 +2114,8 @@ cc_test( ":depth_to_space_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2118,6 +2131,8 @@ cc_test( ":quantized_unary_elementwise_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2133,6 +2148,7 @@ cc_test( ":quantized_fully_connected_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", "@com_google_googletest//:gtest", ], ) @@ -2148,6 +2164,7 @@ cc_test( ":quantized_leaky_relu_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", "@com_google_googletest//:gtest", ], ) @@ -2163,6 +2180,8 @@ cc_test( ":quantized_unary_elementwise_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2178,21 +2197,8 @@ cc_test( ":quantized_pool_2d_tester", ":test_main", ":xnnpack_delegate_test_mode", - "@com_google_googletest//:gtest", - ], -) - -cc_test( - name = "signed_quantized_mean_test", - srcs = ["signed_quantized_mean_test.cc"], - linkopts = select({ - "//tensorflow:emscripten": EMSCRIPTEN_LINKOPTS, - "//conditions:default": [], - }), - deps = [ - ":quantized_reduce_tester", - ":test_main", - ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2208,6 +2214,8 @@ cc_test( ":quantized_binary_elementwise_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2223,6 +2231,7 @@ cc_test( ":quantized_pad_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", "@com_google_googletest//:gtest", ], ) @@ -2238,6 +2247,7 @@ cc_test( ":quantized_resize_bilinear_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", "@com_google_googletest//:gtest", ], ) @@ -2253,6 +2263,8 @@ cc_test( ":slice_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2268,6 +2280,8 @@ cc_test( ":space_to_depth_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2283,6 +2297,8 @@ cc_test( ":split_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2298,6 +2314,7 @@ cc_test( ":strided_slice_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2313,6 +2330,8 @@ cc_test( ":quantized_binary_elementwise_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2328,6 +2347,8 @@ cc_test( ":quantized_unary_elementwise_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2343,6 +2364,8 @@ cc_test( ":test_main", ":transpose_tester", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2358,6 +2381,7 @@ cc_test( ":quantized_transpose_conv_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", "@com_google_googletest//:gtest", ], ) @@ -2377,6 +2401,7 @@ cc_test( ":xnnpack_delegate", ":quantized_variable_ops_tester_no_test_mode", "@com_google_googletest//:gtest", + "//tensorflow/lite/c:c_api_types", ], ) @@ -2391,6 +2416,7 @@ cc_test( ":quantized_variable_ops_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", "@com_google_googletest//:gtest", ], ) @@ -2406,6 +2432,8 @@ cc_test( ":slice_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], ) @@ -2421,6 +2449,7 @@ cc_test( ":softmax_tester", ":test_main", ":xnnpack_delegate_test_mode", + "//tensorflow/lite/c:c_api_types", "@com_google_googletest//:gtest", ], ) @@ -2725,21 +2754,6 @@ cc_test( ], ) -cc_test( - name = "unsigned_quantized_mean_test", - srcs = ["unsigned_quantized_mean_test.cc"], - linkopts = select({ - "//tensorflow:emscripten": EMSCRIPTEN_LINKOPTS, - "//conditions:default": [], - }), - deps = [ - ":quantized_reduce_tester", - ":test_main", - ":xnnpack_delegate_test_mode", - "@com_google_googletest//:gtest", - ], -) - cc_test( name = "unsigned_quantized_mul_test", srcs = ["unsigned_quantized_mul_test.cc"], @@ -2956,6 +2970,7 @@ cc_test( name = "weight_cache_test", srcs = ["weight_cache_test.cc"], deps = [ + ":file_util", ":test_main", ":weight_cache", ":weight_cache_schema", diff --git a/tensorflow/lite/delegates/xnnpack/quantized_binary_elementwise_tester.cc b/tensorflow/lite/delegates/xnnpack/quantized_binary_elementwise_tester.cc index a2665d612d3864..0109427159c729 100644 --- a/tensorflow/lite/delegates/xnnpack/quantized_binary_elementwise_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/quantized_binary_elementwise_tester.cc @@ -26,10 +26,14 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/quantized_conv_2d_tester.cc b/tensorflow/lite/delegates/xnnpack/quantized_conv_2d_tester.cc index 266a8a22ac7b0b..3e91c73d09934c 100644 --- a/tensorflow/lite/delegates/xnnpack/quantized_conv_2d_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/quantized_conv_2d_tester.cc @@ -25,10 +25,16 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "flatbuffers/vector.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/quantized_depthwise_conv_2d_tester.cc b/tensorflow/lite/delegates/xnnpack/quantized_depthwise_conv_2d_tester.cc index 017713b4761f3e..162037f9a74f68 100644 --- a/tensorflow/lite/delegates/xnnpack/quantized_depthwise_conv_2d_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/quantized_depthwise_conv_2d_tester.cc @@ -25,10 +25,16 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "flatbuffers/vector.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/quantized_fully_connected_tester.cc b/tensorflow/lite/delegates/xnnpack/quantized_fully_connected_tester.cc index c42f8d78f97a21..fdcb8999565394 100644 --- a/tensorflow/lite/delegates/xnnpack/quantized_fully_connected_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/quantized_fully_connected_tester.cc @@ -26,10 +26,15 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/quantized_leaky_relu_tester.cc b/tensorflow/lite/delegates/xnnpack/quantized_leaky_relu_tester.cc index 88dead108e7b9b..410c1dbf21c872 100644 --- a/tensorflow/lite/delegates/xnnpack/quantized_leaky_relu_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/quantized_leaky_relu_tester.cc @@ -24,11 +24,13 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/quantized_pad_tester.cc b/tensorflow/lite/delegates/xnnpack/quantized_pad_tester.cc index aa1f2391613684..545f5cfd761a46 100644 --- a/tensorflow/lite/delegates/xnnpack/quantized_pad_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/quantized_pad_tester.cc @@ -24,11 +24,13 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/quantized_pool_2d_tester.cc b/tensorflow/lite/delegates/xnnpack/quantized_pool_2d_tester.cc index 4918b34aeb7d7f..f1cd0249e7d0c3 100644 --- a/tensorflow/lite/delegates/xnnpack/quantized_pool_2d_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/quantized_pool_2d_tester.cc @@ -23,10 +23,14 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/quantized_reduce_tester.cc b/tensorflow/lite/delegates/xnnpack/quantized_reduce_tester.cc deleted file mode 100644 index c4a8cf1b381db5..00000000000000 --- a/tensorflow/lite/delegates/xnnpack/quantized_reduce_tester.cc +++ /dev/null @@ -1,194 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/lite/delegates/xnnpack/quantized_reduce_tester.h" - -#include -#include -#include -#include -#include -#include -#include - -#include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers -#include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" -#include "tensorflow/lite/schema/schema_generated.h" -#include "tensorflow/lite/version.h" - -namespace tflite { -namespace xnnpack { - -template -void QuantizedReduceTester::Test(Interpreter* delegate_interpreter, - Interpreter* default_interpreter) const { - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto input_rng = std::bind( - std::uniform_int_distribution(std::numeric_limits::min(), - std::numeric_limits::max()), - std::ref(rng)); - - T* default_input_data = default_interpreter->typed_input_tensor(0); - std::generate_n(default_input_data, InputSize(), std::ref(input_rng)); - - T* delegate_input_data = delegate_interpreter->typed_input_tensor(0); - std::copy_n(default_input_data, InputSize(), delegate_input_data); - - ASSERT_EQ(default_interpreter->Invoke(), kTfLiteOk); - ASSERT_EQ(delegate_interpreter->Invoke(), kTfLiteOk); - - T* default_output_data = default_interpreter->typed_output_tensor(0); - T* delegate_output_data = delegate_interpreter->typed_output_tensor(0); - - const int32_t output_size = OutputSize(); - for (size_t i = 0; i < output_size; i++) { - ASSERT_LE(std::abs(static_cast(default_output_data[i]) - - static_cast(delegate_output_data[i])), - 1) - << "default " << static_cast(default_output_data[i]) - << ", delegate " << static_cast(delegate_output_data[i]) - << " at index " << i << " / " << output_size; - } -} - -void QuantizedReduceTester::Test(tflite::BuiltinOperator reduce_op, - TfLiteDelegate* delegate) const { - std::vector buffer = CreateTfLiteModel(reduce_op); - const Model* model = GetModel(buffer.data()); - - std::unique_ptr delegate_interpreter; - ASSERT_EQ( - InterpreterBuilder( - model, - ::tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates())( - &delegate_interpreter), - kTfLiteOk); - std::unique_ptr default_interpreter; - ASSERT_EQ( - InterpreterBuilder( - model, - ::tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates())( - &default_interpreter), - kTfLiteOk); - - ASSERT_TRUE(delegate_interpreter); - ASSERT_TRUE(default_interpreter); - - ASSERT_EQ(delegate_interpreter->inputs().size(), 1); - ASSERT_EQ(default_interpreter->inputs().size(), 1); - - ASSERT_EQ(delegate_interpreter->outputs().size(), 1); - ASSERT_EQ(default_interpreter->outputs().size(), 1); - - ASSERT_EQ(delegate_interpreter->AllocateTensors(), kTfLiteOk); - ASSERT_EQ(default_interpreter->AllocateTensors(), kTfLiteOk); - - ASSERT_EQ(delegate_interpreter->ModifyGraphWithDelegate(delegate), kTfLiteOk); - - if (Unsigned()) { - Test(delegate_interpreter.get(), default_interpreter.get()); - } else { - Test(delegate_interpreter.get(), default_interpreter.get()); - } -} - -std::vector QuantizedReduceTester::CreateTfLiteModel( - tflite::BuiltinOperator reduce_op) const { - flatbuffers::FlatBufferBuilder builder; - flatbuffers::Offset operator_code = - CreateOperatorCode(builder, reduce_op); - - const std::array, 2> buffers{{ - CreateBuffer(builder, builder.CreateVector({})), - CreateBuffer(builder, builder.CreateVector( - reinterpret_cast(Axes().data()), - sizeof(int32_t) * Axes().size())), - }}; - - const std::vector output_shape = OutputShape(); - const std::array axes_shape{ - {static_cast(Axes().size())}}; - const std::array, 3> tensors{{ - CreateTensor(builder, - builder.CreateVector(InputShape().data(), - InputShape().size()), - Unsigned() ? TensorType_UINT8 : TensorType_INT8, - /*buffer=*/0, /*name=*/0, - CreateQuantizationParameters( - builder, /*min=*/0, /*max=*/0, - builder.CreateVector({InputScale()}), - builder.CreateVector({InputZeroPoint()}))), - CreateTensor( - builder, - builder.CreateVector(axes_shape.data(), axes_shape.size()), - TensorType_INT32, /*buffer=*/1), - CreateTensor(builder, - builder.CreateVector(output_shape.data(), - output_shape.size()), - Unsigned() ? TensorType_UINT8 : TensorType_INT8, - /*buffer=*/0, /*name=*/0, - CreateQuantizationParameters( - builder, /*min=*/0, /*max=*/0, - builder.CreateVector({OutputScale()}), - builder.CreateVector({OutputZeroPoint()}))), - }}; - - const flatbuffers::Offset reducer_options = - CreateReducerOptions(builder, KeepDims()); - - const std::array op_inputs{{0, 1}}; - const std::array op_outputs{{2}}; - flatbuffers::Offset op = CreateOperator( - builder, /*opcode_index=*/0, - builder.CreateVector(op_inputs.data(), op_inputs.size()), - builder.CreateVector(op_outputs.data(), op_outputs.size()), - tflite::BuiltinOptions_ReducerOptions, reducer_options.Union()); - - const std::array subgraph_inputs{{0}}; - const std::array subgraph_outputs{{2}}; - flatbuffers::Offset subgraph = CreateSubGraph( - builder, builder.CreateVector(tensors.data(), tensors.size()), - builder.CreateVector(subgraph_inputs.data(), - subgraph_inputs.size()), - builder.CreateVector(subgraph_outputs.data(), - subgraph_outputs.size()), - builder.CreateVector(&op, 1)); - - flatbuffers::Offset description = - builder.CreateString("Quantized Reduce model"); - - flatbuffers::Offset model_buffer = CreateModel( - builder, TFLITE_SCHEMA_VERSION, builder.CreateVector(&operator_code, 1), - builder.CreateVector(&subgraph, 1), description, - builder.CreateVector(buffers.data(), buffers.size())); - - builder.Finish(model_buffer); - - return std::vector(builder.GetBufferPointer(), - builder.GetBufferPointer() + builder.GetSize()); -} - -int32_t QuantizedReduceTester::ComputeSize(const std::vector& shape) { - return std::accumulate(shape.cbegin(), shape.cend(), 1, - std::multiplies()); -} - -} // namespace xnnpack -} // namespace tflite diff --git a/tensorflow/lite/delegates/xnnpack/quantized_reduce_tester.h b/tensorflow/lite/delegates/xnnpack/quantized_reduce_tester.h deleted file mode 100644 index d8199ec1f4c2cd..00000000000000 --- a/tensorflow/lite/delegates/xnnpack/quantized_reduce_tester.h +++ /dev/null @@ -1,155 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_QUANTIZED_REDUCE_TESTER_H_ -#define TENSORFLOW_LITE_DELEGATES_XNNPACK_QUANTIZED_REDUCE_TESTER_H_ - -#include -#include -#include - -#include -#include "tensorflow/lite/core/c/common.h" -#include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/schema/schema_generated.h" - -namespace tflite { -namespace xnnpack { - -class QuantizedReduceTester { - public: - QuantizedReduceTester() = default; - QuantizedReduceTester(const QuantizedReduceTester&) = delete; - QuantizedReduceTester& operator=(const QuantizedReduceTester&) = delete; - - inline QuantizedReduceTester& InputShape( - std::initializer_list shape) { - for (auto it = shape.begin(); it != shape.end(); ++it) { - EXPECT_GT(*it, 0); - } - input_shape_ = std::vector(shape.begin(), shape.end()); - input_size_ = QuantizedReduceTester::ComputeSize(input_shape_); - return *this; - } - - inline const std::vector& InputShape() const { return input_shape_; } - - inline int32_t InputSize() const { return input_size_; } - - inline QuantizedReduceTester& Axes(std::initializer_list axes) { - for (auto it = axes.begin(); it != axes.end(); ++it) { - EXPECT_GE(*it, 0); - } - axes_ = std::vector(axes.begin(), axes.end()); - return *this; - } - - inline const std::vector& Axes() const { return axes_; } - - inline QuantizedReduceTester& KeepDims(bool keep_dims) { - keep_dims_ = keep_dims; - return *this; - } - - inline bool KeepDims() const { return keep_dims_; } - - inline std::vector OutputShape() const { - std::vector output_shape; - output_shape.reserve(InputShape().size()); - std::unordered_set axes_set(Axes().cbegin(), Axes().cend()); - for (int32_t i = 0; i < InputShape().size(); i++) { - if (axes_set.count(i) != 0) { - if (KeepDims()) { - output_shape.push_back(1); - } - } else { - output_shape.push_back(InputShape()[i]); - } - } - return output_shape; - } - - inline int32_t OutputSize() const { - int32_t output_size = 1; - std::unordered_set axes_set(Axes().cbegin(), Axes().cend()); - for (int32_t i = 0; i < InputShape().size(); i++) { - if (axes_set.count(i) == 0) { - output_size *= InputShape()[i]; - } - } - return output_size; - } - - inline QuantizedReduceTester& InputZeroPoint(int32_t input_zero_point) { - input_zero_point_ = input_zero_point; - return *this; - } - - inline int32_t InputZeroPoint() const { return input_zero_point_; } - - inline QuantizedReduceTester& OutputZeroPoint(int32_t output_zero_point) { - output_zero_point_ = output_zero_point; - return *this; - } - - inline int32_t OutputZeroPoint() const { return output_zero_point_; } - - inline QuantizedReduceTester& InputScale(float input_scale) { - input_scale_ = input_scale; - return *this; - } - - inline float InputScale() const { return input_scale_; } - - inline QuantizedReduceTester& OutputScale(float output_scale) { - output_scale_ = output_scale; - return *this; - } - - inline float OutputScale() const { return output_scale_; } - - inline QuantizedReduceTester& Unsigned(bool is_unsigned) { - unsigned_ = is_unsigned; - return *this; - } - - inline bool Unsigned() const { return unsigned_; } - - template - void Test(Interpreter* delegate_interpreter, - Interpreter* default_interpreter) const; - - void Test(tflite::BuiltinOperator reduce_op, TfLiteDelegate* delegate) const; - - private: - std::vector CreateTfLiteModel(tflite::BuiltinOperator reduce_op) const; - - static int32_t ComputeSize(const std::vector& shape); - - std::vector input_shape_; - std::vector axes_; - int32_t input_size_; - bool keep_dims_ = true; - int32_t input_zero_point_ = 1; - int32_t output_zero_point_ = 2; - float input_scale_ = 1.25f; - float output_scale_ = 0.75f; - bool unsigned_ = false; -}; - -} // namespace xnnpack -} // namespace tflite - -#endif // TENSORFLOW_LITE_DELEGATES_XNNPACK_QUANTIZED_REDUCE_TESTER_H_ diff --git a/tensorflow/lite/delegates/xnnpack/quantized_resize_bilinear_tester.cc b/tensorflow/lite/delegates/xnnpack/quantized_resize_bilinear_tester.cc index d9ab3f5359547d..484b25f2c68bd1 100644 --- a/tensorflow/lite/delegates/xnnpack/quantized_resize_bilinear_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/quantized_resize_bilinear_tester.cc @@ -24,11 +24,13 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/quantized_transpose_conv_tester.cc b/tensorflow/lite/delegates/xnnpack/quantized_transpose_conv_tester.cc index e9a9a19d856bcb..f5bc843ca9479e 100644 --- a/tensorflow/lite/delegates/xnnpack/quantized_transpose_conv_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/quantized_transpose_conv_tester.cc @@ -26,13 +26,14 @@ limitations under the License. #include #include -#include "fp16.h" // from @FP16 -#include "flatbuffers/flatbuffers.h" // from @flatbuffers -#include "tensorflow/lite/core/c/builtin_op_data.h" +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" +#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/quantized_unary_elementwise_tester.cc b/tensorflow/lite/delegates/xnnpack/quantized_unary_elementwise_tester.cc index b8e297ae46175a..6efcbafca015d3 100644 --- a/tensorflow/lite/delegates/xnnpack/quantized_unary_elementwise_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/quantized_unary_elementwise_tester.cc @@ -24,11 +24,13 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/quantized_variable_ops_tester.cc b/tensorflow/lite/delegates/xnnpack/quantized_variable_ops_tester.cc index 71e7f3e7630a07..61ba2b60f2b8bb 100644 --- a/tensorflow/lite/delegates/xnnpack/quantized_variable_ops_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/quantized_variable_ops_tester.cc @@ -25,11 +25,14 @@ limitations under the License. #include #include +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/mean_test.cc b/tensorflow/lite/delegates/xnnpack/reduce_test.cc similarity index 73% rename from tensorflow/lite/delegates/xnnpack/mean_test.cc rename to tensorflow/lite/delegates/xnnpack/reduce_test.cc index b3b4cd0039c9c9..ab7792a0487ebc 100644 --- a/tensorflow/lite/delegates/xnnpack/mean_test.cc +++ b/tensorflow/lite/delegates/xnnpack/reduce_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,6 +17,9 @@ limitations under the License. #include #include #include +#include +#include +#include #include #include "tensorflow/lite/c/c_api_types.h" @@ -27,7 +30,53 @@ limitations under the License. namespace tflite { namespace xnnpack { -TEST(Mean, 4DReduceBatchSqueezeDims) { +struct TestParam { + using Tuple = std::tuple; + explicit TestParam(const Tuple& t) + : op(std::get<0>(t)), quantization(std::get<1>(t)) {} + BuiltinOperator op; + enum ReduceTester::Quantization quantization; +}; + +class ReduceTest : public testing::TestWithParam { + public: + static std::string GetName(const testing::TestParamInfo& i) { + std::stringstream sstr; + switch (i.param.op) { + case BuiltinOperator_MEAN: + sstr << "mean"; + break; + case BuiltinOperator_SUM: + sstr << "sum"; + break; + default: + sstr << "unknown"; + break; + } + switch (i.param.quantization) { + case ReduceTester::Quantization::None: + break; + case ReduceTester::Quantization::Signed: + sstr << "_signed_quantized"; + break; + case ReduceTester::Quantization::Unsigned: + sstr << "_unsigned_quantized"; + break; + } + return sstr.str(); + } +}; + +INSTANTIATE_TEST_SUITE_P( + Reduce, ReduceTest, + testing::ConvertGenerator(testing::Combine( + testing::Values(BuiltinOperator_MEAN, BuiltinOperator_SUM), + testing::Values(ReduceTester::Quantization::None, + ReduceTester::Quantization::Signed, + ReduceTester::Quantization::Unsigned))), + ReduceTest::GetName); + +TEST_P(ReduceTest, 4DReduceBatchSqueezeDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -42,13 +91,14 @@ TEST(Mean, 4DReduceBatchSqueezeDims) { const auto channels = shape_rng(); ReduceTester() + .Quantization(GetParam().quantization) .InputShape({batch, height, width, channels}) .Axes({0}) .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); + .Test(GetParam().op, xnnpack_delegate.get()); } -TEST(Mean, 4DReduceBatchKeepDims) { +TEST_P(ReduceTest, 4DReduceBatchKeepDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -63,13 +113,14 @@ TEST(Mean, 4DReduceBatchKeepDims) { const auto channels = shape_rng(); ReduceTester() + .Quantization(GetParam().quantization) .InputShape({batch, height, width, channels}) .Axes({0}) .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); + .Test(GetParam().op, xnnpack_delegate.get()); } -TEST(Mean, 4DReduceHeightSqueezeDims) { +TEST_P(ReduceTest, 4DReduceHeightSqueezeDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -84,13 +135,14 @@ TEST(Mean, 4DReduceHeightSqueezeDims) { const auto channels = shape_rng(); ReduceTester() + .Quantization(GetParam().quantization) .InputShape({batch, height, width, channels}) .Axes({1}) .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); + .Test(GetParam().op, xnnpack_delegate.get()); } -TEST(Mean, 4DReduceHeightKeepDims) { +TEST_P(ReduceTest, 4DReduceHeightKeepDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -105,13 +157,14 @@ TEST(Mean, 4DReduceHeightKeepDims) { const auto channels = shape_rng(); ReduceTester() + .Quantization(GetParam().quantization) .InputShape({batch, height, width, channels}) .Axes({1}) .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); + .Test(GetParam().op, xnnpack_delegate.get()); } -TEST(Mean, 4DReduceWidthSqueezeDims) { +TEST_P(ReduceTest, 4DReduceWidthSqueezeDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -126,13 +179,14 @@ TEST(Mean, 4DReduceWidthSqueezeDims) { const auto channels = shape_rng(); ReduceTester() + .Quantization(GetParam().quantization) .InputShape({batch, height, width, channels}) .Axes({2}) .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); + .Test(GetParam().op, xnnpack_delegate.get()); } -TEST(Mean, 4DReduceWidthKeepDims) { +TEST_P(ReduceTest, 4DReduceWidthKeepDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -147,13 +201,14 @@ TEST(Mean, 4DReduceWidthKeepDims) { const auto channels = shape_rng(); ReduceTester() + .Quantization(GetParam().quantization) .InputShape({batch, height, width, channels}) .Axes({2}) .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); + .Test(GetParam().op, xnnpack_delegate.get()); } -TEST(Mean, 4DReduceHeightWidthSqueezeDims) { +TEST_P(ReduceTest, 4DReduceHeightWidthSqueezeDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -168,19 +223,21 @@ TEST(Mean, 4DReduceHeightWidthSqueezeDims) { const auto channels = shape_rng(); ReduceTester() + .Quantization(GetParam().quantization) .InputShape({batch, height, width, channels}) .Axes({1, 2}) .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); + .Test(GetParam().op, xnnpack_delegate.get()); ReduceTester() + .Quantization(GetParam().quantization) .InputShape({batch, height, width, channels}) .Axes({2, 1}) .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); + .Test(GetParam().op, xnnpack_delegate.get()); } -TEST(Mean, 4DReduceHeightWidthKeepDims) { +TEST_P(ReduceTest, 4DReduceHeightWidthKeepDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -195,19 +252,21 @@ TEST(Mean, 4DReduceHeightWidthKeepDims) { const auto channels = shape_rng(); ReduceTester() + .Quantization(GetParam().quantization) .InputShape({batch, height, width, channels}) .Axes({1, 2}) .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); + .Test(GetParam().op, xnnpack_delegate.get()); ReduceTester() + .Quantization(GetParam().quantization) .InputShape({batch, height, width, channels}) .Axes({2, 1}) .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); + .Test(GetParam().op, xnnpack_delegate.get()); } -TEST(Mean, 4DReduceChannelsSqueezeDims) { +TEST_P(ReduceTest, 4DReduceChannelsSqueezeDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -222,13 +281,14 @@ TEST(Mean, 4DReduceChannelsSqueezeDims) { const auto channels = shape_rng(); ReduceTester() + .Quantization(GetParam().quantization) .InputShape({batch, height, width, channels}) .Axes({3}) .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); + .Test(GetParam().op, xnnpack_delegate.get()); } -TEST(Mean, 4DReduceChannelsKeepDims) { +TEST_P(ReduceTest, 4DReduceChannelsKeepDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -243,13 +303,14 @@ TEST(Mean, 4DReduceChannelsKeepDims) { const auto channels = shape_rng(); ReduceTester() + .Quantization(GetParam().quantization) .InputShape({batch, height, width, channels}) .Axes({3}) .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); + .Test(GetParam().op, xnnpack_delegate.get()); } -TEST(Mean, 3DReduceBatchSqueezeDims) { +TEST_P(ReduceTest, 3DReduceBatchSqueezeDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -263,13 +324,14 @@ TEST(Mean, 3DReduceBatchSqueezeDims) { const auto channels = shape_rng(); ReduceTester() + .Quantization(GetParam().quantization) .InputShape({batch, width, channels}) .Axes({0}) .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); + .Test(GetParam().op, xnnpack_delegate.get()); } -TEST(Mean, 3DReduceBatchKeepDims) { +TEST_P(ReduceTest, 3DReduceBatchKeepDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -283,13 +345,14 @@ TEST(Mean, 3DReduceBatchKeepDims) { const auto channels = shape_rng(); ReduceTester() + .Quantization(GetParam().quantization) .InputShape({batch, width, channels}) .Axes({0}) .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); + .Test(GetParam().op, xnnpack_delegate.get()); } -TEST(Mean, 3DReduceWidthSqueezeDims) { +TEST_P(ReduceTest, 3DReduceWidthSqueezeDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -303,13 +366,14 @@ TEST(Mean, 3DReduceWidthSqueezeDims) { const auto channels = shape_rng(); ReduceTester() + .Quantization(GetParam().quantization) .InputShape({batch, width, channels}) .Axes({1}) .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); + .Test(GetParam().op, xnnpack_delegate.get()); } -TEST(Mean, 3DReduceWidthKeepDims) { +TEST_P(ReduceTest, 3DReduceWidthKeepDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -323,13 +387,14 @@ TEST(Mean, 3DReduceWidthKeepDims) { const auto channels = shape_rng(); ReduceTester() + .Quantization(GetParam().quantization) .InputShape({batch, width, channels}) .Axes({1}) .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); + .Test(GetParam().op, xnnpack_delegate.get()); } -TEST(Mean, 3DReduceChannelsSqueezeDims) { +TEST_P(ReduceTest, 3DReduceChannelsSqueezeDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -343,13 +408,14 @@ TEST(Mean, 3DReduceChannelsSqueezeDims) { const auto channels = shape_rng(); ReduceTester() + .Quantization(GetParam().quantization) .InputShape({batch, width, channels}) .Axes({2}) .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); + .Test(GetParam().op, xnnpack_delegate.get()); } -TEST(Mean, 3DReduceChannelsKeepDims) { +TEST_P(ReduceTest, 3DReduceChannelsKeepDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -363,13 +429,14 @@ TEST(Mean, 3DReduceChannelsKeepDims) { const auto channels = shape_rng(); ReduceTester() + .Quantization(GetParam().quantization) .InputShape({batch, width, channels}) .Axes({2}) .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); + .Test(GetParam().op, xnnpack_delegate.get()); } -TEST(Mean, 2DReduceBatchSqueezeDims) { +TEST_P(ReduceTest, 2DReduceBatchSqueezeDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -382,13 +449,14 @@ TEST(Mean, 2DReduceBatchSqueezeDims) { const auto channels = shape_rng(); ReduceTester() + .Quantization(GetParam().quantization) .InputShape({batch, channels}) .Axes({0}) .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); + .Test(GetParam().op, xnnpack_delegate.get()); } -TEST(Mean, 2DReduceBatchKeepDims) { +TEST_P(ReduceTest, 2DReduceBatchKeepDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -401,13 +469,14 @@ TEST(Mean, 2DReduceBatchKeepDims) { const auto channels = shape_rng(); ReduceTester() + .Quantization(GetParam().quantization) .InputShape({batch, channels}) .Axes({0}) .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); + .Test(GetParam().op, xnnpack_delegate.get()); } -TEST(Mean, 2DReduceChannelsSqueezeDims) { +TEST_P(ReduceTest, 2DReduceChannelsSqueezeDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -420,13 +489,14 @@ TEST(Mean, 2DReduceChannelsSqueezeDims) { const auto channels = shape_rng(); ReduceTester() + .Quantization(GetParam().quantization) .InputShape({batch, channels}) .Axes({1}) .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); + .Test(GetParam().op, xnnpack_delegate.get()); } -TEST(Mean, 2DReduceChannelsKeepDims) { +TEST_P(ReduceTest, 2DReduceChannelsKeepDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -439,13 +509,14 @@ TEST(Mean, 2DReduceChannelsKeepDims) { const auto channels = shape_rng(); ReduceTester() + .Quantization(GetParam().quantization) .InputShape({batch, channels}) .Axes({1}) .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); + .Test(GetParam().op, xnnpack_delegate.get()); } -TEST(Mean, 1DSqueezeDims) { +TEST_P(ReduceTest, 1DSqueezeDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -456,11 +527,15 @@ TEST(Mean, 1DSqueezeDims) { std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); const auto batch = shape_rng(); - ReduceTester().InputShape({batch}).Axes({0}).KeepDims(false).Test( - BuiltinOperator_MEAN, xnnpack_delegate.get()); + ReduceTester() + .Quantization(GetParam().quantization) + .InputShape({batch}) + .Axes({0}) + .KeepDims(false) + .Test(GetParam().op, xnnpack_delegate.get()); } -TEST(Mean, 1DKeepDims) { +TEST_P(ReduceTest, 1DKeepDims) { std::unique_ptr xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), TfLiteXNNPackDelegateDelete); @@ -471,11 +546,15 @@ TEST(Mean, 1DKeepDims) { std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); const auto batch = shape_rng(); - ReduceTester().InputShape({batch}).Axes({0}).KeepDims(true).Test( - BuiltinOperator_MEAN, xnnpack_delegate.get()); + ReduceTester() + .Quantization(GetParam().quantization) + .InputShape({batch}) + .Axes({0}) + .KeepDims(true) + .Test(GetParam().op, xnnpack_delegate.get()); } -TEST(Mean, MultiThreading) { +TEST_P(ReduceTest, MultiThreading) { TfLiteXNNPackDelegateOptions delegate_options = TfLiteXNNPackDelegateOptionsDefault(); delegate_options.num_threads = 2; @@ -493,10 +572,11 @@ TEST(Mean, MultiThreading) { const auto channels = shape_rng(); ReduceTester() + .Quantization(GetParam().quantization) .InputShape({batch, height, width, channels}) .Axes({1, 2}) .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); + .Test(GetParam().op, xnnpack_delegate.get()); } } // namespace xnnpack diff --git a/tensorflow/lite/delegates/xnnpack/reduce_tester.cc b/tensorflow/lite/delegates/xnnpack/reduce_tester.cc index cc6f69066f84be..656375b56a596b 100644 --- a/tensorflow/lite/delegates/xnnpack/reduce_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/reduce_tester.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,34 +15,86 @@ limitations under the License. #include "tensorflow/lite/delegates/xnnpack/reduce_tester.h" -#include #include +#include #include +#include #include +#include #include #include #include +#include +#include #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" namespace tflite { namespace xnnpack { -void ReduceTester::Test(tflite::BuiltinOperator reduce_op, - TfLiteDelegate* delegate) const { +template +struct UniformDistribution { + static std::uniform_int_distribution Get() { + return std::uniform_int_distribution( + std::numeric_limits::min(), std::numeric_limits::max()); + } +}; + +template <> +struct UniformDistribution { + static std::uniform_real_distribution Get() { return {}; } +}; + +template +void ReduceTester::Test(Interpreter* delegate_interpreter, + Interpreter* default_interpreter) const { std::random_device random_device; auto rng = std::mt19937(random_device()); - auto input_rng = - std::bind(std::uniform_real_distribution(), std::ref(rng)); + auto input_rng = std::bind(UniformDistribution::Get(), std::ref(rng)); + + T* default_input_data = default_interpreter->typed_input_tensor(0); + std::generate_n(default_input_data, InputSize(), std::ref(input_rng)); + + T* delegate_input_data = delegate_interpreter->typed_input_tensor(0); + std::copy_n(default_input_data, InputSize(), delegate_input_data); + + ASSERT_EQ(default_interpreter->Invoke(), kTfLiteOk); + ASSERT_EQ(delegate_interpreter->Invoke(), kTfLiteOk); + T* default_output_data = default_interpreter->typed_output_tensor(0); + T* delegate_output_data = delegate_interpreter->typed_output_tensor(0); + + const int32_t output_size = OutputSize(); + if constexpr (std::is_floating_point_v) { + for (size_t i = 0; i < output_size; i++) { + ASSERT_NEAR( + default_output_data[i], delegate_output_data[i], + std::numeric_limits::epsilon() * + std::max(std::abs(default_output_data[i]) * RelativeTolerance(), + 1.0f)); + } + } else { + for (size_t i = 0; i < output_size; i++) { + ASSERT_LE(std::abs(default_output_data[i] - delegate_output_data[i]), 1) + << "default " << +default_output_data[i] << ", delegate " + << +delegate_output_data[i] << " at index " << i << " / " + << output_size; + } + } +} + +void ReduceTester::Test(tflite::BuiltinOperator reduce_op, + TfLiteDelegate* delegate) const { std::vector buffer = CreateTfLiteModel(reduce_op); const Model* model = GetModel(buffer.data()); @@ -75,31 +127,34 @@ void ReduceTester::Test(tflite::BuiltinOperator reduce_op, ASSERT_EQ(delegate_interpreter->ModifyGraphWithDelegate(delegate), kTfLiteOk); - float* default_input_data = default_interpreter->typed_input_tensor(0); - std::generate_n(default_input_data, InputSize(), std::ref(input_rng)); - - float* delegate_input_data = - delegate_interpreter->typed_input_tensor(0); - std::copy_n(default_input_data, InputSize(), delegate_input_data); - - ASSERT_EQ(default_interpreter->Invoke(), kTfLiteOk); - ASSERT_EQ(delegate_interpreter->Invoke(), kTfLiteOk); + switch (Quantization()) { + case Quantization::None: + Test(delegate_interpreter.get(), default_interpreter.get()); + break; + case Quantization::Signed: + Test(delegate_interpreter.get(), default_interpreter.get()); + break; + case Quantization::Unsigned: + Test(delegate_interpreter.get(), default_interpreter.get()); + break; + } +} - float* default_output_data = - default_interpreter->typed_output_tensor(0); - float* delegate_output_data = - delegate_interpreter->typed_output_tensor(0); +namespace { - const int32_t output_size = OutputSize(); - for (size_t i = 0; i < output_size; i++) { - ASSERT_NEAR( - default_output_data[i], delegate_output_data[i], - std::numeric_limits::epsilon() * - std::max(std::abs(default_output_data[i]) * RelativeTolerance(), - 1.0f)); +TensorType GetTensorType(enum ReduceTester::Quantization q) { + switch (q) { + case ReduceTester::Quantization::None: + return TensorType_FLOAT32; + case ReduceTester::Quantization::Signed: + return TensorType_INT8; + case ReduceTester::Quantization::Unsigned: + return TensorType_UINT8; } } +} // namespace + std::vector ReduceTester::CreateTfLiteModel( tflite::BuiltinOperator reduce_op) const { flatbuffers::FlatBufferBuilder builder; @@ -116,11 +171,27 @@ std::vector ReduceTester::CreateTfLiteModel( const std::vector output_shape = OutputShape(); const std::array axes_shape{ {static_cast(Axes().size())}}; + + const flatbuffers::Offset input_quantization = + Quantization() == Quantization::None + ? 0 + : CreateQuantizationParameters( + builder, /*min=*/0, /*max=*/0, + builder.CreateVector({InputScale()}), + builder.CreateVector({InputZeroPoint()})); + const flatbuffers::Offset output_quantization = + Quantization() == Quantization::None + ? 0 + : CreateQuantizationParameters( + builder, /*min=*/0, /*max=*/0, + builder.CreateVector({OutputScale()}), + builder.CreateVector({OutputZeroPoint()})); const std::array, 3> tensors{{ CreateTensor(builder, builder.CreateVector(InputShape().data(), InputShape().size()), - TensorType_FLOAT32), + GetTensorType(Quantization()), /*buffer=*/0, /*name=*/0, + input_quantization), CreateTensor( builder, builder.CreateVector(axes_shape.data(), axes_shape.size()), @@ -128,7 +199,8 @@ std::vector ReduceTester::CreateTfLiteModel( CreateTensor(builder, builder.CreateVector(output_shape.data(), output_shape.size()), - TensorType_FLOAT32), + GetTensorType(Quantization()), /*buffer=*/0, /*name=*/0, + output_quantization), }}; const flatbuffers::Offset reducer_options = @@ -152,8 +224,12 @@ std::vector ReduceTester::CreateTfLiteModel( subgraph_outputs.size()), builder.CreateVector(&op, 1)); + std::string model_description = "Reduce model"; + if (Quantization() != Quantization::None) { + model_description = "Quantized reduce model"; + } flatbuffers::Offset description = - builder.CreateString("Reduce model"); + builder.CreateString(model_description); flatbuffers::Offset model_buffer = CreateModel( builder, TFLITE_SCHEMA_VERSION, builder.CreateVector(&operator_code, 1), diff --git a/tensorflow/lite/delegates/xnnpack/reduce_tester.h b/tensorflow/lite/delegates/xnnpack/reduce_tester.h index 149b6303080ba8..b3a05e0c9b882e 100644 --- a/tensorflow/lite/delegates/xnnpack/reduce_tester.h +++ b/tensorflow/lite/delegates/xnnpack/reduce_tester.h @@ -17,11 +17,13 @@ limitations under the License. #define TENSORFLOW_LITE_DELEGATES_XNNPACK_REDUCE_TESTER_H_ #include +#include #include #include #include #include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/schema/schema_generated.h" namespace tflite { @@ -29,6 +31,8 @@ namespace xnnpack { class ReduceTester { public: + enum class Quantization { None, Signed, Unsigned }; + ReduceTester() = default; ReduceTester(const ReduceTester&) = delete; ReduceTester& operator=(const ReduceTester&) = delete; @@ -97,6 +101,45 @@ class ReduceTester { inline float RelativeTolerance() const { return relative_tolerance_; } + inline ReduceTester& InputZeroPoint(int32_t input_zero_point) { + input_zero_point_ = input_zero_point; + return *this; + } + + inline int32_t InputZeroPoint() const { return input_zero_point_; } + + inline ReduceTester& OutputZeroPoint(int32_t output_zero_point) { + output_zero_point_ = output_zero_point; + return *this; + } + + inline int32_t OutputZeroPoint() const { return output_zero_point_; } + + inline ReduceTester& InputScale(float input_scale) { + input_scale_ = input_scale; + return *this; + } + + inline float InputScale() const { return input_scale_; } + + inline ReduceTester& OutputScale(float output_scale) { + output_scale_ = output_scale; + return *this; + } + + inline float OutputScale() const { return output_scale_; } + + inline ReduceTester& Quantization(Quantization q) { + quantization_ = q; + return *this; + } + + inline enum Quantization Quantization() const { return quantization_; } + + template + void Test(Interpreter* delegate_interpreter, + Interpreter* default_interpreter) const; + void Test(tflite::BuiltinOperator reduce_op, TfLiteDelegate* delegate) const; private: @@ -109,6 +152,11 @@ class ReduceTester { int32_t input_size_; bool keep_dims_ = true; float relative_tolerance_ = 10.0f; + int32_t input_zero_point_ = 1; + int32_t output_zero_point_ = 2; + float input_scale_ = 1.25f; + float output_scale_ = 0.75f; + enum Quantization quantization_ = Quantization::None; }; } // namespace xnnpack diff --git a/tensorflow/lite/delegates/xnnpack/relu6_test.cc b/tensorflow/lite/delegates/xnnpack/relu6_test.cc index 75f32dcfd39116..5f2de211ec4ef9 100644 --- a/tensorflow/lite/delegates/xnnpack/relu6_test.cc +++ b/tensorflow/lite/delegates/xnnpack/relu6_test.cc @@ -19,8 +19,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/unary_elementwise_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/relu_n1_to_1_test.cc b/tensorflow/lite/delegates/xnnpack/relu_n1_to_1_test.cc index 9e799577e6ed73..07aab082aefc0a 100644 --- a/tensorflow/lite/delegates/xnnpack/relu_n1_to_1_test.cc +++ b/tensorflow/lite/delegates/xnnpack/relu_n1_to_1_test.cc @@ -19,8 +19,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/unary_elementwise_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/relu_test.cc b/tensorflow/lite/delegates/xnnpack/relu_test.cc index 8996ff5d04b8c4..b088a2b9053e18 100644 --- a/tensorflow/lite/delegates/xnnpack/relu_test.cc +++ b/tensorflow/lite/delegates/xnnpack/relu_test.cc @@ -19,8 +19,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/unary_elementwise_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/reshape_test.cc b/tensorflow/lite/delegates/xnnpack/reshape_test.cc index fc8d240f120ff5..56c252f461eef6 100644 --- a/tensorflow/lite/delegates/xnnpack/reshape_test.cc +++ b/tensorflow/lite/delegates/xnnpack/reshape_test.cc @@ -21,8 +21,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/reshape_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/reshape_tester.cc b/tensorflow/lite/delegates/xnnpack/reshape_tester.cc index e2f4fe2e63e9ad..a3c5a17fd38105 100644 --- a/tensorflow/lite/delegates/xnnpack/reshape_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/reshape_tester.cc @@ -25,11 +25,12 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/resize_bilinear_test.cc b/tensorflow/lite/delegates/xnnpack/resize_bilinear_test.cc index b1fc49ca93fbd1..c66004e3205617 100644 --- a/tensorflow/lite/delegates/xnnpack/resize_bilinear_test.cc +++ b/tensorflow/lite/delegates/xnnpack/resize_bilinear_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/resize_bilinear_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" diff --git a/tensorflow/lite/delegates/xnnpack/resize_bilinear_tester.cc b/tensorflow/lite/delegates/xnnpack/resize_bilinear_tester.cc index e4ee08280f4260..c2832b0c64d68b 100644 --- a/tensorflow/lite/delegates/xnnpack/resize_bilinear_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/resize_bilinear_tester.cc @@ -25,11 +25,13 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" -#include "tensorflow/lite/core/model.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/round_test.cc b/tensorflow/lite/delegates/xnnpack/round_test.cc index 0481762ca1947b..1e4f861bf12b6a 100644 --- a/tensorflow/lite/delegates/xnnpack/round_test.cc +++ b/tensorflow/lite/delegates/xnnpack/round_test.cc @@ -19,8 +19,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/unary_elementwise_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_dequantize_test.cc b/tensorflow/lite/delegates/xnnpack/signed_dequantize_test.cc index f86379326d0bc6..927152db28af04 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_dequantize_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_dequantize_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/dequantize_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_add_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_add_test.cc index 87e669c5d0cde2..ad159deb61b55c 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_add_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_add_test.cc @@ -20,8 +20,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_binary_elementwise_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_concatenation_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_concatenation_test.cc index 35377adfb86e26..c7590f21f8a944 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_concatenation_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_concatenation_test.cc @@ -21,8 +21,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/concatenation_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_conv_2d_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_conv_2d_test.cc index a43b3c42fbf40f..f67ba714b01cc8 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_conv_2d_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_conv_2d_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_conv_2d_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_depth_to_space_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_depth_to_space_test.cc index f6003e99398949..d85a9cfb1ceac4 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_depth_to_space_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_depth_to_space_test.cc @@ -20,8 +20,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/depth_to_space_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_depthwise_conv_2d_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_depthwise_conv_2d_test.cc index 33ff96aa84594f..3acfbaaf34778e 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_depthwise_conv_2d_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_depthwise_conv_2d_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_depthwise_conv_2d_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_elu_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_elu_test.cc index b66434f7454d82..676fc06bdf4fe5 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_elu_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_elu_test.cc @@ -19,8 +19,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_unary_elementwise_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_fully_connected_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_fully_connected_test.cc index bf341b8bd8e00f..1be48daba79655 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_fully_connected_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_fully_connected_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_fully_connected_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_leaky_relu_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_leaky_relu_test.cc index c03c767eb7a64f..4aa74580b6b827 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_leaky_relu_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_leaky_relu_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_leaky_relu_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_logistic_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_logistic_test.cc index 02f3383954ce24..9067ffebf02dfd 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_logistic_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_logistic_test.cc @@ -20,8 +20,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_unary_elementwise_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_max_pool_2d_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_max_pool_2d_test.cc index 509f2cd1e72849..4a12e817039b32 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_max_pool_2d_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_max_pool_2d_test.cc @@ -19,8 +19,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_pool_2d_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_mean_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_mean_test.cc deleted file mode 100644 index 87fcf35bb87972..00000000000000 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_mean_test.cc +++ /dev/null @@ -1,501 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include - -#include -#include "tensorflow/lite/delegates/xnnpack/quantized_reduce_tester.h" -#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" - -namespace tflite { -namespace xnnpack { - -TEST(SignedQuantizedMean, DISABLED_4DReduceBatchSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({0}) - .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(SignedQuantizedMean, DISABLED_4DReduceBatchKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({0}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(SignedQuantizedMean, DISABLED_4DReduceHeightSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({1}) - .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(SignedQuantizedMean, DISABLED_4DReduceHeightKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({1}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(SignedQuantizedMean, 4DReduceWidthSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({2}) - .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(SignedQuantizedMean, 4DReduceWidthKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({2}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(SignedQuantizedMean, 4DReduceHeightWidthSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({1, 2}) - .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); - - QuantizedReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({2, 1}) - .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(SignedQuantizedMean, 4DReduceHeightWidthKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({1, 2}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); - - QuantizedReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({2, 1}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(SignedQuantizedMean, DISABLED_4DReduceChannelsSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({3}) - .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(SignedQuantizedMean, DISABLED_4DReduceChannelsKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({3}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(SignedQuantizedMean, DISABLED_3DReduceBatchSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .InputShape({batch, width, channels}) - .Axes({0}) - .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(SignedQuantizedMean, DISABLED_3DReduceBatchKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .InputShape({batch, width, channels}) - .Axes({0}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(SignedQuantizedMean, DISABLED_3DReduceWidthSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .InputShape({batch, width, channels}) - .Axes({1}) - .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(SignedQuantizedMean, DISABLED_3DReduceWidthKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .InputShape({batch, width, channels}) - .Axes({1}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(SignedQuantizedMean, DISABLED_3DReduceChannelsSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .InputShape({batch, width, channels}) - .Axes({2}) - .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(SignedQuantizedMean, DISABLED_3DReduceChannelsKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .InputShape({batch, width, channels}) - .Axes({2}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(SignedQuantizedMean, DISABLED_2DReduceBatchSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .InputShape({batch, channels}) - .Axes({0}) - .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(SignedQuantizedMean, DISABLED_2DReduceBatchKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .InputShape({batch, channels}) - .Axes({0}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(SignedQuantizedMean, DISABLED_2DReduceChannelsSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .InputShape({batch, channels}) - .Axes({1}) - .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(SignedQuantizedMean, DISABLED_2DReduceChannelsKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .InputShape({batch, channels}) - .Axes({1}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(SignedQuantizedMean, DISABLED_1DSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - - QuantizedReduceTester().InputShape({batch}).Axes({0}).KeepDims(false).Test( - BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(SignedQuantizedMean, DISABLED_1DKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - - QuantizedReduceTester().InputShape({batch}).Axes({0}).KeepDims(true).Test( - BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(SignedQuantizedMean, MultiThreading) { - TfLiteXNNPackDelegateOptions delegate_options = - TfLiteXNNPackDelegateOptionsDefault(); - delegate_options.num_threads = 2; - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({1, 2}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -} // namespace xnnpack -} // namespace tflite diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_mul_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_mul_test.cc index fbab966d6229f5..b28ed665ed3542 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_mul_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_mul_test.cc @@ -20,8 +20,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_binary_elementwise_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_pad_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_pad_test.cc index 13c4ff2d2ade90..7ce3ad1a2b4653 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_pad_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_pad_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_pad_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_reshape_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_reshape_test.cc index 1f57d47cdba326..71f0843406535a 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_reshape_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_reshape_test.cc @@ -21,8 +21,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/reshape_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_resize_bilinear_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_resize_bilinear_test.cc index 8c77ba185f552a..c3cf1cef9dc3af 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_resize_bilinear_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_resize_bilinear_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_resize_bilinear_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_slice_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_slice_test.cc index 2b487ee4151cc0..48ca30e3adfbc5 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_slice_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_slice_test.cc @@ -22,8 +22,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/slice_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_space_to_depth_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_space_to_depth_test.cc index 3e23a7701f1a73..99d4ce31ea9a74 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_space_to_depth_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_space_to_depth_test.cc @@ -20,8 +20,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/space_to_depth_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite::xnnpack { namespace { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_split_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_split_test.cc index 3a5338da1726b1..2cf61c50ef6662 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_split_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_split_test.cc @@ -21,8 +21,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/split_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_strided_slice_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_strided_slice_test.cc index f934d56ad7e123..3540057a7d676a 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_strided_slice_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_strided_slice_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/lite/delegates/xnnpack/strided_slice_tester.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_sub_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_sub_test.cc index 8a0e5f5204c851..bd5e92dcf16582 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_sub_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_sub_test.cc @@ -20,8 +20,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_binary_elementwise_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_tanh_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_tanh_test.cc index 939d296fca6a72..708ac12112beca 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_tanh_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_tanh_test.cc @@ -20,8 +20,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_unary_elementwise_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_transpose_conv_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_transpose_conv_test.cc index 095256fff2d656..7daae13ebdea16 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_transpose_conv_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_transpose_conv_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_transpose_conv_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_transpose_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_transpose_test.cc index 71ddef900fdcd7..d32af38da21c61 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_transpose_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_transpose_test.cc @@ -20,8 +20,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/transpose_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_variable_ops_multiple_subgraph_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_variable_ops_multiple_subgraph_test.cc index 5e0df9cb63982b..22a59e76720a5d 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_variable_ops_multiple_subgraph_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_variable_ops_multiple_subgraph_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_variable_ops_tester.h" namespace tflite { diff --git a/tensorflow/lite/delegates/xnnpack/signed_quantized_variable_ops_test.cc b/tensorflow/lite/delegates/xnnpack/signed_quantized_variable_ops_test.cc index 2d54b207969d64..5c083a37570ce9 100644 --- a/tensorflow/lite/delegates/xnnpack/signed_quantized_variable_ops_test.cc +++ b/tensorflow/lite/delegates/xnnpack/signed_quantized_variable_ops_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/quantized_variable_ops_tester.h" namespace tflite { diff --git a/tensorflow/lite/delegates/xnnpack/slice_test.cc b/tensorflow/lite/delegates/xnnpack/slice_test.cc index 3a80181a6e74b7..3a1790b1143d85 100644 --- a/tensorflow/lite/delegates/xnnpack/slice_test.cc +++ b/tensorflow/lite/delegates/xnnpack/slice_test.cc @@ -22,8 +22,10 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/slice_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace xnnpack { diff --git a/tensorflow/lite/delegates/xnnpack/slice_tester.cc b/tensorflow/lite/delegates/xnnpack/slice_tester.cc index 5c1aa6c5921242..da97c89e983645 100644 --- a/tensorflow/lite/delegates/xnnpack/slice_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/slice_tester.cc @@ -25,9 +25,13 @@ limitations under the License. #include #include +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" #include "tensorflow/lite/interpreter.h" -#include "tensorflow/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" diff --git a/tensorflow/lite/delegates/xnnpack/softmax_test.cc b/tensorflow/lite/delegates/xnnpack/softmax_test.cc index ae33a1afad37af..f55d3c23f66019 100644 --- a/tensorflow/lite/delegates/xnnpack/softmax_test.cc +++ b/tensorflow/lite/delegates/xnnpack/softmax_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/delegates/xnnpack/softmax_tester.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" diff --git a/tensorflow/lite/delegates/xnnpack/sum_test.cc b/tensorflow/lite/delegates/xnnpack/sum_test.cc deleted file mode 100644 index 269b55b132cdf5..00000000000000 --- a/tensorflow/lite/delegates/xnnpack/sum_test.cc +++ /dev/null @@ -1,501 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include - -#include -#include "tensorflow/lite/delegates/xnnpack/reduce_tester.h" -#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" - -namespace tflite { -namespace xnnpack { - -TEST(Sum, DISABLED_4DReduceBatchSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - ReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({0}) - .KeepDims(false) - .Test(BuiltinOperator_SUM, xnnpack_delegate.get()); -} - -TEST(Sum, DISABLED_4DReduceBatchKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - ReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({0}) - .KeepDims(true) - .Test(BuiltinOperator_SUM, xnnpack_delegate.get()); -} - -TEST(Sum, DISABLED_4DReduceHeightSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - ReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({1}) - .KeepDims(false) - .Test(BuiltinOperator_SUM, xnnpack_delegate.get()); -} - -TEST(Sum, DISABLED_4DReduceHeightKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - ReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({1}) - .KeepDims(true) - .Test(BuiltinOperator_SUM, xnnpack_delegate.get()); -} - -TEST(Sum, 4DReduceWidthSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - ReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({2}) - .KeepDims(false) - .Test(BuiltinOperator_SUM, xnnpack_delegate.get()); -} - -TEST(Sum, 4DReduceWidthKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - ReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({2}) - .KeepDims(true) - .Test(BuiltinOperator_SUM, xnnpack_delegate.get()); -} - -TEST(Sum, 4DReduceHeightWidthSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - ReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({1, 2}) - .KeepDims(false) - .Test(BuiltinOperator_SUM, xnnpack_delegate.get()); - - ReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({2, 1}) - .KeepDims(false) - .Test(BuiltinOperator_SUM, xnnpack_delegate.get()); -} - -TEST(Sum, 4DReduceHeightWidthKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - ReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({1, 2}) - .KeepDims(true) - .Test(BuiltinOperator_SUM, xnnpack_delegate.get()); - - ReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({2, 1}) - .KeepDims(true) - .Test(BuiltinOperator_SUM, xnnpack_delegate.get()); -} - -TEST(Sum, DISABLED_4DReduceChannelsSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - ReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({3}) - .KeepDims(false) - .Test(BuiltinOperator_SUM, xnnpack_delegate.get()); -} - -TEST(Sum, DISABLED_4DReduceChannelsKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - ReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({3}) - .KeepDims(true) - .Test(BuiltinOperator_SUM, xnnpack_delegate.get()); -} - -TEST(Sum, DISABLED_3DReduceBatchSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - ReduceTester() - .InputShape({batch, width, channels}) - .Axes({0}) - .KeepDims(false) - .Test(BuiltinOperator_SUM, xnnpack_delegate.get()); -} - -TEST(Sum, DISABLED_3DReduceBatchKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - ReduceTester() - .InputShape({batch, width, channels}) - .Axes({0}) - .KeepDims(true) - .Test(BuiltinOperator_SUM, xnnpack_delegate.get()); -} - -TEST(Sum, DISABLED_3DReduceWidthSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - ReduceTester() - .InputShape({batch, width, channels}) - .Axes({1}) - .KeepDims(false) - .Test(BuiltinOperator_SUM, xnnpack_delegate.get()); -} - -TEST(Sum, DISABLED_3DReduceWidthKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - ReduceTester() - .InputShape({batch, width, channels}) - .Axes({1}) - .KeepDims(true) - .Test(BuiltinOperator_SUM, xnnpack_delegate.get()); -} - -TEST(Sum, DISABLED_3DReduceChannelsSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - ReduceTester() - .InputShape({batch, width, channels}) - .Axes({2}) - .KeepDims(false) - .Test(BuiltinOperator_SUM, xnnpack_delegate.get()); -} - -TEST(Sum, DISABLED_3DReduceChannelsKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - ReduceTester() - .InputShape({batch, width, channels}) - .Axes({2}) - .KeepDims(true) - .Test(BuiltinOperator_SUM, xnnpack_delegate.get()); -} - -TEST(Sum, DISABLED_2DReduceBatchSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto channels = shape_rng(); - - ReduceTester() - .InputShape({batch, channels}) - .Axes({0}) - .KeepDims(false) - .Test(BuiltinOperator_SUM, xnnpack_delegate.get()); -} - -TEST(Sum, DISABLED_2DReduceBatchKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto channels = shape_rng(); - - ReduceTester() - .InputShape({batch, channels}) - .Axes({0}) - .KeepDims(true) - .Test(BuiltinOperator_SUM, xnnpack_delegate.get()); -} - -TEST(Sum, DISABLED_2DReduceChannelsSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto channels = shape_rng(); - - ReduceTester() - .InputShape({batch, channels}) - .Axes({1}) - .KeepDims(false) - .Test(BuiltinOperator_SUM, xnnpack_delegate.get()); -} - -TEST(Sum, DISABLED_2DReduceChannelsKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto channels = shape_rng(); - - ReduceTester() - .InputShape({batch, channels}) - .Axes({1}) - .KeepDims(true) - .Test(BuiltinOperator_SUM, xnnpack_delegate.get()); -} - -TEST(Sum, DISABLED_1DSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - - ReduceTester().InputShape({batch}).Axes({0}).KeepDims(false).Test( - BuiltinOperator_SUM, xnnpack_delegate.get()); -} - -TEST(Sum, DISABLED_1DKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - - ReduceTester().InputShape({batch}).Axes({0}).KeepDims(true).Test( - BuiltinOperator_SUM, xnnpack_delegate.get()); -} - -TEST(Sum, MultiThreading) { - TfLiteXNNPackDelegateOptions delegate_options = - TfLiteXNNPackDelegateOptionsDefault(); - delegate_options.num_threads = 2; - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - ReduceTester() - .InputShape({batch, height, width, channels}) - .Axes({1, 2}) - .KeepDims(true) - .Test(BuiltinOperator_SUM, xnnpack_delegate.get()); -} - -} // namespace xnnpack -} // namespace tflite diff --git a/tensorflow/lite/delegates/xnnpack/unsigned_quantized_mean_test.cc b/tensorflow/lite/delegates/xnnpack/unsigned_quantized_mean_test.cc deleted file mode 100644 index 43433a06a33898..00000000000000 --- a/tensorflow/lite/delegates/xnnpack/unsigned_quantized_mean_test.cc +++ /dev/null @@ -1,528 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include - -#include -#include "tensorflow/lite/delegates/xnnpack/quantized_reduce_tester.h" -#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" - -namespace tflite { -namespace xnnpack { - -TEST(UnsignedQuantizedMean, DISABLED_4DReduceBatchSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .Unsigned(true) - .InputShape({batch, height, width, channels}) - .Axes({0}) - .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(UnsignedQuantizedMean, DISABLED_4DReduceBatchKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .Unsigned(true) - .InputShape({batch, height, width, channels}) - .Axes({0}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(UnsignedQuantizedMean, DISABLED_4DReduceHeightSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .Unsigned(true) - .InputShape({batch, height, width, channels}) - .Axes({1}) - .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(UnsignedQuantizedMean, DISABLED_4DReduceHeightKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .Unsigned(true) - .InputShape({batch, height, width, channels}) - .Axes({1}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(UnsignedQuantizedMean, 4DReduceWidthSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .Unsigned(true) - .InputShape({batch, height, width, channels}) - .Axes({2}) - .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(UnsignedQuantizedMean, 4DReduceWidthKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .Unsigned(true) - .InputShape({batch, height, width, channels}) - .Axes({2}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(UnsignedQuantizedMean, 4DReduceHeightWidthSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .Unsigned(true) - .InputShape({batch, height, width, channels}) - .Axes({1, 2}) - .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); - - QuantizedReduceTester() - .Unsigned(true) - .InputShape({batch, height, width, channels}) - .Axes({2, 1}) - .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(UnsignedQuantizedMean, 4DReduceHeightWidthKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .Unsigned(true) - .InputShape({batch, height, width, channels}) - .Axes({1, 2}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); - - QuantizedReduceTester() - .Unsigned(true) - .InputShape({batch, height, width, channels}) - .Axes({2, 1}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(UnsignedQuantizedMean, DISABLED_4DReduceChannelsSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .Unsigned(true) - .InputShape({batch, height, width, channels}) - .Axes({3}) - .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(UnsignedQuantizedMean, DISABLED_4DReduceChannelsKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .Unsigned(true) - .InputShape({batch, height, width, channels}) - .Axes({3}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(UnsignedQuantizedMean, DISABLED_3DReduceBatchSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .Unsigned(true) - .InputShape({batch, width, channels}) - .Axes({0}) - .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(UnsignedQuantizedMean, DISABLED_3DReduceBatchKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .Unsigned(true) - .InputShape({batch, width, channels}) - .Axes({0}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(UnsignedQuantizedMean, DISABLED_3DReduceWidthSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .Unsigned(true) - .InputShape({batch, width, channels}) - .Axes({1}) - .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(UnsignedQuantizedMean, DISABLED_3DReduceWidthKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .Unsigned(true) - .InputShape({batch, width, channels}) - .Axes({1}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(UnsignedQuantizedMean, DISABLED_3DReduceChannelsSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .Unsigned(true) - .InputShape({batch, width, channels}) - .Axes({2}) - .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(UnsignedQuantizedMean, DISABLED_3DReduceChannelsKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .Unsigned(true) - .InputShape({batch, width, channels}) - .Axes({2}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(UnsignedQuantizedMean, DISABLED_2DReduceBatchSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .Unsigned(true) - .InputShape({batch, channels}) - .Axes({0}) - .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(UnsignedQuantizedMean, DISABLED_2DReduceBatchKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .Unsigned(true) - .InputShape({batch, channels}) - .Axes({0}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(UnsignedQuantizedMean, DISABLED_2DReduceChannelsSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .Unsigned(true) - .InputShape({batch, channels}) - .Axes({1}) - .KeepDims(false) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(UnsignedQuantizedMean, DISABLED_2DReduceChannelsKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .Unsigned(true) - .InputShape({batch, channels}) - .Axes({1}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(UnsignedQuantizedMean, DISABLED_1DSqueezeDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - - QuantizedReduceTester().InputShape({batch}).Axes({0}).KeepDims(false).Test( - BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(UnsignedQuantizedMean, DISABLED_1DKeepDims) { - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - - QuantizedReduceTester() - .Unsigned(true) - .InputShape({batch}) - .Axes({0}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -TEST(UnsignedQuantizedMean, MultiThreading) { - TfLiteXNNPackDelegateOptions delegate_options = - TfLiteXNNPackDelegateOptionsDefault(); - delegate_options.num_threads = 2; - std::unique_ptr - xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options), - TfLiteXNNPackDelegateDelete); - - std::random_device random_device; - auto rng = std::mt19937(random_device()); - auto shape_rng = - std::bind(std::uniform_int_distribution(2, 5), std::ref(rng)); - const auto batch = shape_rng(); - const auto height = shape_rng(); - const auto width = shape_rng(); - const auto channels = shape_rng(); - - QuantizedReduceTester() - .Unsigned(true) - .InputShape({batch, height, width, channels}) - .Axes({1, 2}) - .KeepDims(true) - .Test(BuiltinOperator_MEAN, xnnpack_delegate.get()); -} - -} // namespace xnnpack -} // namespace tflite diff --git a/tensorflow/lite/delegates/xnnpack/weight_cache.cc b/tensorflow/lite/delegates/xnnpack/weight_cache.cc index ab70bc4b63a504..da220a3a31b241 100644 --- a/tensorflow/lite/delegates/xnnpack/weight_cache.cc +++ b/tensorflow/lite/delegates/xnnpack/weight_cache.cc @@ -66,7 +66,7 @@ limitations under the License. namespace tflite::xnnpack { namespace { -constexpr size_t kMinAlignment = 64; +constexpr size_t kMinAlignment = 128; // Checks if the given path is a special value to use an in-memory cache. bool IsInMemoryCachePath(const char* path) { @@ -132,6 +132,8 @@ bool FileExists(const char* path) { void swap(MMapHandle& a, MMapHandle& b) { using std::swap; swap(a.size_, b.size_); + swap(a.offset_, b.offset_); + swap(a.offset_page_adjustment_, b.offset_page_adjustment_); swap(a.data_, b.data_); } @@ -144,11 +146,12 @@ MMapHandle& MMapHandle::operator=(MMapHandle&& other) { return *this; } -bool MMapHandle::Map(const char* path) { - return this->Map(FileDescriptor::Open(path, O_RDONLY), path); +bool MMapHandle::Map(const char* path, const size_t offset) { + return this->Map(FileDescriptor::Open(path, O_RDONLY), offset, path); } -bool MMapHandle::Map(const FileDescriptor& fd, const char* const path) { +bool MMapHandle::Map(const FileDescriptor& fd, const size_t offset, + const char* const path) { this->UnMap(); XNNPACK_RETURN_CHECK(fd.IsValid(), @@ -162,15 +165,19 @@ bool MMapHandle::Map(const FileDescriptor& fd, const char* const path) { // This will reset data_ and size_ on return until is is deactivated. ScopeGuard unmap_on_error([this] { UnMap(); }); - size_ = file_stats.st_size; + size_ = file_stats.st_size - offset; + offset_ = offset; #if defined(_MSC_VER) // This allocation is freed in UnMap and in the desctructor. data_ = new uint8_t[size_]; + fd.SetPos(offset); XNNPACK_RETURN_CHECK(fd.Read(data_, size_), "could not read file ('%s'): %s.", path, strerror(errno)); #else - data_ = static_cast(mmap(/*addr=*/nullptr, size_, PROT_READ, - MAP_SHARED, fd.Value(), /*offset=*/0)); + offset_page_adjustment_ = offset_ % getpagesize(); + data_ = static_cast( + mmap(/*addr=*/nullptr, size_ + offset_page_adjustment_, PROT_READ, + MAP_SHARED, fd.Value(), offset_ - offset_page_adjustment_)); XNNPACK_RETURN_CHECK(data_ != MAP_FAILED, "could not mmap file (%s): %s.", path, strerror(errno)); #endif @@ -178,6 +185,25 @@ bool MMapHandle::Map(const FileDescriptor& fd, const char* const path) { return true; } +bool MMapHandle::Resize(size_t new_size) { +#if defined(__linux__) || defined(__ANDROID__) + void* const remapped_data = + mremap(data_, size_ + offset_page_adjustment_, + new_size + offset_page_adjustment_, /*flags=*/0); + if (remapped_data == MAP_FAILED) { + XNNPACK_RETURN_CHECK(errno == ENOMEM, "remap failed: %s", strerror(errno)); + return false; + } + size_ = new_size; + return true; +#else + // The current implementation uses new/delete which doesn't provide a way to + // modify an allocation size. Changing to malloc/realloc/free doesn't ensure + // that a memory allocation will not be moved when reallocating + return false; +#endif +} + void MMapHandle::UnMap() { if (data_) { #if defined(_MSC_VER) @@ -187,39 +213,46 @@ void MMapHandle::UnMap() { #endif } data_ = nullptr; + offset_ = 0; + offset_page_adjustment_ = 0; size_ = 0; } -void swap(WeightCacheBuilder& a, WeightCacheBuilder& b) { - using std::swap; - swap(a.schema_, b.schema_); - swap(a.data_, b.data_); - swap(a.capacity_, b.capacity_); - swap(a.fd_, b.fd_); - swap(a.file_path_, b.file_path_); -} - -WeightCacheBuilder::WeightCacheBuilder(WeightCacheBuilder&& other) { - swap(*this, other); -} +#define XNN_MOVE_CONSTRUCT_MEMBER(x) x(std::move(other.x)) +WeightCacheBuilder::WeightCacheBuilder(WeightCacheBuilder&& other) + : XNN_MOVE_CONSTRUCT_MEMBER(data_), + XNN_MOVE_CONSTRUCT_MEMBER(schema_), + XNN_MOVE_CONSTRUCT_MEMBER(capacity_), + XNN_MOVE_CONSTRUCT_MEMBER(build_segment_size_), + XNN_MOVE_CONSTRUCT_MEMBER(build_segment_start_), + XNN_MOVE_CONSTRUCT_MEMBER(first_write_done_), + XNN_MOVE_CONSTRUCT_MEMBER(fd_), + XNN_MOVE_CONSTRUCT_MEMBER(file_path_) {} +#undef XNN_MOVE_CONSTRUCT_MEMBER WeightCacheBuilder& WeightCacheBuilder::operator=(WeightCacheBuilder&& other) { - Reset(); - swap(*this, other); +#define XNN_MOVE_MEMBER(x) x = std::move(other.x) + XNN_MOVE_MEMBER(data_); + XNN_MOVE_MEMBER(schema_); + XNN_MOVE_MEMBER(capacity_); + XNN_MOVE_MEMBER(build_segment_size_); + XNN_MOVE_MEMBER(build_segment_start_); + XNN_MOVE_MEMBER(first_write_done_); + XNN_MOVE_MEMBER(fd_); + XNN_MOVE_MEMBER(file_path_); +#undef XNN_MOVE_MEMBER return *this; } -WeightCacheBuilder::~WeightCacheBuilder() { Reset(); } - bool WeightCacheBuilder::Start(const char* path) { - Reset(); - ScopeGuard reset_on_error([this] { Reset(); }); - + XNNPACK_RETURN_CHECK(!IsStarted()); file_path_ = path; + if (IsInMemoryCachePath(file_path_)) { fd_ = CreateInMemoryFileDescriptor("XNNPack in-memory weight cache"); } else { - fd_.Reset(open(file_path_.c_str(), O_CREAT | O_TRUNC | O_WRONLY, 0644)); + fd_ = FileDescriptor::Open(file_path_.c_str(), O_CREAT | O_TRUNC | O_RDWR, + 0644); } XNNPACK_RETURN_CHECK(fd_.IsValid(), "could not open file ('%s'): %s.", file_path_.c_str(), strerror(errno)); @@ -227,44 +260,63 @@ bool WeightCacheBuilder::Start(const char* path) { // Write data in the header, this will be overwritten in the `Finalize` call. // We explicitly set the header as invalid. If any error happens during // the build, reloading the cache file will fail. - const XNNPackCacheHeader header{XNNPackCacheHeader::kInvalidHeader}; + XNNPackCacheHeader header{XNNPackCacheHeader::kInvalidHeader}; + header.buffer_list_offset = sizeof(header); XNNPACK_RETURN_CHECK(fd_.Write(&header, sizeof(header)), - "could not write padding for flatbuffer offset in %s.", + "could not write initial cache header in %s.", file_path_.c_str()); schema_.base_offset = Align(sizeof(header), kMinAlignment); - - reset_on_error.Deactivate(); return true; } -void WeightCacheBuilder::Reset() { - fd_.Close(); - data_.reset(nullptr); - capacity_ = 0; - schema_ = cache::schema::BufferListT(); +bool WeightCacheBuilder::StartBuildStep() { + XNNPACK_RETURN_CHECK(IsStarted()); + + // Reload flatbuffer data. + XNNPackCacheHeader header; + fd_.SetPos(0); + XNNPACK_RETURN_CHECK(fd_.Read(&header, sizeof(header)), + "could not read cache file header."); + if (header.buffer_list_size) { + MMapHandle buffer_list_data; + XNNPACK_RETURN_CHECK(buffer_list_data.Map(fd_, header.buffer_list_offset), + "could not map buffer list mapping"); + cache::schema::GetBufferList(buffer_list_data.data())->UnPackTo(&schema_); + } + + // Move cursor to end of existing data. + build_segment_size_ = 0; + build_segment_start_ = fd_.SetPos(header.buffer_list_offset); + XNNPACK_RETURN_CHECK(build_segment_start_ != -1); + + is_build_step_ = true; + return true; } +void WeightCacheBuilder::Reset() { *this = WeightCacheBuilder(); } + void* WeightCacheBuilder::Reserve(size_t size) { if (size > capacity_) { // We don't care about the data when we are reserving space. We save memory // by deleting the existing buffer first. data_.reset(nullptr); - data_ = std::make_unique(size); + data_ = std::make_unique(size + kMinAlignment); capacity_ = size; } - return data_.get(); + return reinterpret_cast( + Align(reinterpret_cast(data_.get()), kMinAlignment)); } BufferLocation WeightCacheBuilder::Append(PackIdentifier pack_id, const void* data, uint64_t size) { - XNNPACK_ABORT_CHECK(IsStarted(), + XNNPACK_ABORT_CHECK(is_build_step_, "cannot append data to an unstarted builder."); // Add some padding so that the cache file can be mmaped and the buffer // stays aligned correctly. const size_t offset = Align(fd_.GetPos(), kMinAlignment); - if (fd_.SetPos(offset) != offset) { + if (fd_.SetPos(offset) == -1) { return BufferLocation::Invalid(); } @@ -278,20 +330,24 @@ BufferLocation WeightCacheBuilder::Append(PackIdentifier pack_id, schema_.buffers.push_back(std::make_unique(buffer)); if (!fd_.Write(data, size)) { - TFLITE_LOG_PROD(tflite::TFLITE_LOG_ERROR, file_path_.c_str(), + TFLITE_LOG_PROD(tflite::TFLITE_LOG_ERROR, "XNNPack weight cache: cannot append buffer to cache file"); return BufferLocation::Invalid(); } return loc; } -bool WeightCacheBuilder::ShouldFinalize() const { return fd_.IsValid(); } - -bool WeightCacheBuilder::Finalize() { +bool WeightCacheBuilder::StopBuildStep() { XNNPACK_RETURN_CHECK(fd_.IsValid(), "cache file ('%s') is not open for writing: %s.", file_path_.c_str(), strerror(errno)); + is_build_step_ = false; + if (fd_.GetPos() == build_segment_start_ && first_write_done_) { + // Nothing was written to the file, we can exit early. + return true; + } + flatbuffers::FlatBufferBuilder builder; // Add a fake size and the base offset to mutate them afterwards. Otherwise // space for it won't be added to the flatbuffer. @@ -321,16 +377,19 @@ bool WeightCacheBuilder::Finalize() { XNNPACK_RETURN_CHECK(fd_.Write(builder.GetBufferPointer(), builder.GetSize()), "cannot write buffer list to '%s'.", file_path_.c_str()); + // Save the segment size for that it can be individually mapped. + build_segment_size_ = fd_.GetPos() - build_segment_start_; + // Write the header at the beginning of the file. XNNPACK_RETURN_CHECK(fd_.SetPos(0) != -1, "could not move in the file to write header to %s", strerror(errno)); - XNNPACK_ABORT_CHECK(fd_.Write(&header, sizeof(header)), - "cannot write cache header to %s.", file_path_.c_str()); + XNNPACK_RETURN_CHECK(fd_.Write(&header, sizeof(header)), + "cannot write cache header to %s.", file_path_.c_str()); TFLITE_LOG_PROD(tflite::TFLITE_LOG_VERBOSE, "XNNPack weight cache: written to '%s'.", file_path_.c_str()); - Reset(); + first_write_done_ = true; return true; } @@ -349,7 +408,7 @@ MMapWeightCacheProvider& MMapWeightCacheProvider::operator=( swap(file_path_, other.file_path_); swap(buffer_address_to_identifier_, other.buffer_address_to_identifier_); swap(cache_key_to_offset_, other.cache_key_to_offset_); - swap(mmap_handle_, other.mmap_handle_); + swap(mmap_handles_, other.mmap_handles_); swap(mmap_buffer_base_offset_, other.mmap_buffer_base_offset_); swap(builder_, other.builder_); return *this; @@ -357,7 +416,7 @@ MMapWeightCacheProvider& MMapWeightCacheProvider::operator=( void MMapWeightCacheProvider::SetFilePath(const char* path) { XNNPACK_ABORT_CHECK( - !IsFinalized(), + !IsBuilding(), "Cannot change the path of a cache that has already been loaded."); // We try to keep file_path_'s data as stable as possible. Don't overwrite // if the path hasn't changed. @@ -374,7 +433,6 @@ bool MMapWeightCacheProvider::LoadOrStartBuild(const char* path) { } else if (StartBuild(path)) { TFLITE_LOG_PROD(tflite::TFLITE_LOG_VERBOSE, "XNNPack weight cache build for '%s' started.", path); - return true; } return false; @@ -382,7 +440,13 @@ bool MMapWeightCacheProvider::LoadOrStartBuild(const char* path) { bool MMapWeightCacheProvider::StartBuild(const char* path) { SetFilePath(path); - return builder_.Start(path); + building_run_ = builder_.Start(path); + if (IsInMemoryCachePath(file_path_)) { + // Duplicate the file descriptor to avoid loosing the temporary file when + // the builder is reset. + temporary_file_descriptor_ = builder_.GetFileDescriptor().Duplicate(); + } + return building_run_; } bool MMapWeightCacheProvider::Load(const std::string& path) { @@ -393,10 +457,13 @@ bool MMapWeightCacheProvider::Load(const std::string& path) { bool MMapWeightCacheProvider::Load() { mmap_buffer_base_offset_ = 0; cache_key_to_offset_.clear(); + mmap_handles_.resize(1); + MMapHandle& mmap_handle = mmap_handles_.front(); + ScopeGuard unmap_on_fail([this] { mmap_handles_.clear(); }); if (temporary_file_descriptor_.IsValid()) { - XNNPACK_RETURN_CHECK( - mmap_handle_.Map(temporary_file_descriptor_, file_path_.c_str())); + XNNPACK_RETURN_CHECK(mmap_handle.Map(temporary_file_descriptor_, + /*offset=*/0, file_path_.c_str())); } else { XNNPACK_ABORT_CHECK(!file_path_.empty(), "Path wasn't provided to weight cache provider."); @@ -406,24 +473,22 @@ bool MMapWeightCacheProvider::Load() { file_path_.c_str(), strerror(errno)); return false; } - - XNNPACK_RETURN_CHECK(mmap_handle_.Map(file_path_.c_str())); + XNNPACK_RETURN_CHECK(mmap_handle.Map(file_path_.c_str())); } - ScopeGuard unmap_on_fail([this] { mmap_handle_.UnMap(); }); - - XNNPACK_RETURN_CHECK(mmap_handle_.size() >= sizeof(XNNPackCacheHeader), + XNNPACK_RETURN_CHECK(mmap_handle.size() >= sizeof(XNNPackCacheHeader), "invalid cache file size."); - const XNNPackCacheHeader header = [this] { + const XNNPackCacheHeader header = [&mmap_handle] { XNNPackCacheHeader header; - memcpy(&header, mmap_handle_.data(), sizeof(header)); + memcpy(&header, mmap_handle.data(), sizeof(header)); return header; }(); - XNNPACK_RETURN_CHECK( - header.version == XNNPackCacheHeader::kVersion, - "incompatible header version. Cache needs to be built again."); + XNNPACK_RETURN_CHECK(header.version == XNNPackCacheHeader::kVersion, + "incompatible header version. Got %zd, expected %zd. " + "Cache needs to be built again.", + header.version, XNNPackCacheHeader::kVersion); XNNPACK_RETURN_CHECK(xnn_experimental_check_build_identifier( header.xnnpack_build_identifier, @@ -431,22 +496,22 @@ bool MMapWeightCacheProvider::Load() { "XNNPack weight cache: incompatible XNNPack version. " "Cache needs to be built again."); - XNNPACK_RETURN_CHECK(header.buffer_list_offset < mmap_handle_.size(), + XNNPACK_RETURN_CHECK(header.buffer_list_offset < mmap_handle.size(), "invalid offset for buffer list descriptor."); - XNNPACK_RETURN_CHECK(header.buffer_list_size == - mmap_handle_.size() - header.buffer_list_offset, - "invalid size for buffer list descriptor."); + XNNPACK_RETURN_CHECK( + header.buffer_list_size == mmap_handle.size() - header.buffer_list_offset, + "invalid size for buffer list descriptor."); // Verifiy the flabuffer part of the file. - flatbuffers::Verifier verifier( - mmap_handle_.data() + header.buffer_list_offset, header.buffer_list_size); + flatbuffers::Verifier verifier(mmap_handle.data() + header.buffer_list_offset, + header.buffer_list_size); XNNPACK_RETURN_CHECK(cache::schema::VerifyBufferListBuffer(verifier), "buffer list validation failed."); // Load flatbuffer. const cache::schema::BufferList* buffer_list = cache::schema::GetBufferList( - mmap_handle_.data() + header.buffer_list_offset); + mmap_handle.data() + header.buffer_list_offset); XNNPACK_RETURN_CHECK(buffer_list, "could not get packed weights from flatbuffer."); @@ -459,6 +524,9 @@ bool MMapWeightCacheProvider::Load() { /*weights_id=*/buffer->weights_id(), /*bias_id=*/buffer->bias_id()}, BufferLocation{/*offset=*/buffer->offset(), /*size=*/buffer->size()}); + offset_to_addr_.insert( + {buffer->offset(), + mmap_handle.data() + mmap_buffer_base_offset_ + buffer->offset()}); } } @@ -466,6 +534,87 @@ bool MMapWeightCacheProvider::Load() { return true; } +bool MMapWeightCacheProvider::LoadLastBuildStep() { + if (mmap_handles_.empty()) { + return Load(); + } + + if (builder_.LastBuildStepSize() == 0) { + return true; + } + + const XNNPackCacheHeader header = [this] { + XNNPackCacheHeader header; + memcpy(&header, mmap_handles_.front().data(), sizeof(header)); + return header; + }(); + + // Map last data segment: + // - either resize the last mmap handle; + // - or add a new mapping handle. + { + MMapHandle& last_mmap_handle = mmap_handles_.back(); + const int last_mmap_size = last_mmap_handle.size(); + if (!last_mmap_handle.Resize(last_mmap_size + + builder_.LastBuildStepSize())) { + mmap_handles_.emplace_back(); + if (temporary_file_descriptor_.IsValid()) { + XNNPACK_RETURN_CHECK( + mmap_handles_.back().Map(temporary_file_descriptor_, + /*offset=*/builder_.LastBuildStepStart()), + "could not map last build step"); + } else { + XNNPACK_RETURN_CHECK( + mmap_handles_.back().Map(file_path_.c_str(), + /*offset=*/builder_.LastBuildStepStart()), + "could not map last build step"); + } + } + } + // Read the updated buffer list. + MMapHandle& segment_mmap_handle = mmap_handles_.back(); + const size_t buffer_list_offset = + header.buffer_list_offset - segment_mmap_handle.offset(); + + flatbuffers::Verifier verifier( + segment_mmap_handle.data() + buffer_list_offset, header.buffer_list_size); + XNNPACK_RETURN_CHECK(cache::schema::VerifyBufferListBuffer(verifier), + "buffer list validation failed."); + + const cache::schema::BufferList* buffer_list = cache::schema::GetBufferList( + segment_mmap_handle.data() + buffer_list_offset); + XNNPACK_RETURN_CHECK(buffer_list, + "could not get packed weights from flatbuffer."); + + // Update offset_to_addr_ with new offsets + const ptrdiff_t offset_modifier = + buffer_list->base_offset() - segment_mmap_handle.offset(); + for (const auto* buffer : *(buffer_list->buffers())) { + const size_t offset = buffer->offset(); + if (!offset_to_addr_.count(offset)) { + offset_to_addr_.insert( + {offset, segment_mmap_handle.data() + offset + offset_modifier}); + } + } + return true; +} + +bool MMapWeightCacheProvider::StartBuildStep() { + XNNPACK_RETURN_CHECK(CanStartBuildStep(), + "cannot append data to an existing cache file."); + if (IsBuilding()) { + return true; + } + is_build_step_ = builder_.StartBuildStep(); + return is_build_step_; +} + +bool MMapWeightCacheProvider::StopBuildStep() { + XNNPACK_RETURN_CHECK(builder_.StopBuildStep()); + is_build_step_ = false; + return LoadLastBuildStep(); +} + void MMapWeightCacheProvider::MapTensorIdentifiers( const TfLiteTensor* tensors, const size_t size, const std::unordered_map& tensor_index_to_identifier) { @@ -497,8 +646,8 @@ size_t MMapWeightCacheProvider::LookUp( } void* MMapWeightCacheProvider::ReserveSpace(size_t size) { - XNNPACK_ABORT_CHECK(!IsFinalized(), - "Cannot reserve space in a finalized cache."); + XNNPACK_ABORT_CHECK(IsBuilding(), + "Cannot reserve space in a cache that isn't building."); return builder_.Reserve(size); } @@ -512,8 +661,8 @@ size_t MMapWeightCacheProvider::LookUpOrInsert( return offset_it->second.offset; } - XNNPACK_ABORT_CHECK(!IsFinalized(), - "Cannot insert a buffer in a finalized cache."); + XNNPACK_ABORT_CHECK( + IsBuilding(), "Cannot insert a buffer in a cache that is not building."); const BufferLocation location = builder_.Append(pack_id, ptr, size); XNNPACK_ABORT_CHECK(!location.IsInvalid(), @@ -526,42 +675,19 @@ void* MMapWeightCacheProvider::OffsetToAddr(const size_t offset) { // While the cache is being built, the buffer could grow and need to be // reallocated so we cannot ensure pointer stability. XNNPACK_ABORT_CHECK( - IsFinalized(), - "Cannot get the address of a buffer in a non finalized cache."); - return mmap_handle_.data() + mmap_buffer_base_offset_ + offset; + !IsBuilding(), + "Cannot get the address of a buffer in a cache during a building step."); + return offset_to_addr_[offset]; } void MMapWeightCacheProvider::Release() { buffer_address_to_identifier_.clear(); cache_key_to_offset_.clear(); - mmap_handle_ = MMapHandle(); + mmap_handles_.clear(); mmap_buffer_base_offset_ = 0; builder_ = WeightCacheBuilder(); } -bool MMapWeightCacheProvider::Finalize() { - if (IsFinalized()) { - return true; - } - XNNPACK_RETURN_CHECK(!file_path_.empty(), - "file path wasn't set. Cannot finalize the cache."); - if (IsInMemoryCachePath(file_path_)) { - // Duplicate the file descriptor to avoid loosing the temporary file when - // the builder is reset. - temporary_file_descriptor_ = builder_.GetFileDescriptor().Duplicate(); - } - if (!builder_.Finalize()) { - return false; - } - builder_ = WeightCacheBuilder(); - - return Load(); -} - -bool MMapWeightCacheProvider::IsFinalized() const { - return mmap_handle_.IsMapped(); -} - size_t MMapWeightCacheProvider::look_up( void* context, const xnn_weights_cache_look_up_key* cache_key) { return reinterpret_cast(context)->LookUp(cache_key); @@ -579,7 +705,7 @@ size_t MMapWeightCacheProvider::look_up_or_insert( } bool MMapWeightCacheProvider::is_finalized(void* context) { - return reinterpret_cast(context)->IsFinalized(); + return reinterpret_cast(context)->IsActive(); } void* MMapWeightCacheProvider::offset_to_addr(void* context, size_t offset) { diff --git a/tensorflow/lite/delegates/xnnpack/weight_cache.h b/tensorflow/lite/delegates/xnnpack/weight_cache.h index afdd4d02f068fd..3e2efed46a6c45 100644 --- a/tensorflow/lite/delegates/xnnpack/weight_cache.h +++ b/tensorflow/lite/delegates/xnnpack/weight_cache.h @@ -18,9 +18,11 @@ limitations under the License. #include #include #include +#include #include #include #include +#include #include "xnnpack.h" // from @XNNPACK #include "tensorflow/lite/c/common.h" @@ -111,13 +113,22 @@ class MMapHandle { // Maps the file at the given path. [[nodiscard /*Mapping a file can fail.*/]] - bool Map(const char* path); + bool Map(const char* path, size_t offset = 0); // Maps the fd associated to the file descriptor. // // The debug_path is printed along the error messages. [[nodiscard /*Mapping a file can fail.*/]] - bool Map(const FileDescriptor& fd, const char* debug_path = "unspecified"); + bool Map(const FileDescriptor& fd, size_t offset = 0, + const char* debug_path = "unspecified"); + + // Tries to resize the current mapping. + // + // Only succeeds if the mapping could be resized without being moved. + // + // WARNING: expects `IsMapped()` to be true. + [[nodiscard /*Resizing a file can fail.*/]] + bool Resize(size_t new_size); // Unmaps an existing mapping. void UnMap(); @@ -126,14 +137,16 @@ class MMapHandle { bool IsMapped() const { return data_ != nullptr; } // Returns the mapping buffer. - uint8_t* data() { return data_; } + uint8_t* data() { return data_ + offset_page_adjustment_; } // Returns the mapping buffer. - const uint8_t* data() const { return data_; } + const uint8_t* data() const { return data_ + offset_page_adjustment_; } // Returns the mapping size in bytes. size_t size() const { return size_; } + size_t offset() const { return offset_; } + uint8_t* begin() { return data(); } const uint8_t* begin() const { return data(); } @@ -146,6 +159,8 @@ class MMapHandle { private: size_t size_ = 0; + size_t offset_ = 0; + size_t offset_page_adjustment_ = 0; uint8_t* data_ = nullptr; }; @@ -156,7 +171,7 @@ class MMapHandle { class WeightCacheBuilder { public: WeightCacheBuilder() = default; - ~WeightCacheBuilder(); + ~WeightCacheBuilder() = default; // Non-copyable. WeightCacheBuilder(const WeightCacheBuilder&) = delete; @@ -174,6 +189,12 @@ class WeightCacheBuilder { return fd_.IsValid(); } + // Reopens the given file to add data to it. + // + // This should be only called from the weight cache provider. + [[nodiscard /*Starting a build step may fail.*/]] + bool StartBuildStep(); + // Resets the builder, discarding any data that hasn't been written. void Reset(); @@ -194,12 +215,25 @@ class WeightCacheBuilder { BufferLocation Append(PackIdentifier pack_id, const void* data, uint64_t size); - // Checks whether this builder has data that needs to be written to disk. - bool ShouldFinalize() const; - // Writes the flatbuffer to disk. [[nodiscard /*Writing the weight cache can fail.*/]] - bool Finalize(); + bool StopBuildStep(); + + // Get the offset in the cache file of the data written during the last step. + // + // This includes the buffers that were appended and the whole buffer mapping. + [[nodiscard]] + size_t LastBuildStepStart() const { + return build_segment_start_; + } + + // Get the size of the data written during the last step. + // + // This includes the buffers that were appended and the whole buffer mapping. + [[nodiscard]] + size_t LastBuildStepSize() const { + return build_segment_size_; + } // Returns the file descriptor. const FileDescriptor& GetFileDescriptor() const { return fd_; } @@ -218,15 +252,23 @@ class WeightCacheBuilder { // may be removed at any time. uint8_t* data() const { return data_.get(); } - friend void swap(WeightCacheBuilder& a, WeightCacheBuilder& b); - private: std::unique_ptr data_ = nullptr; cache::schema::BufferListT schema_; size_t capacity_ = 0; + // Size of the data written between StartBuildStep and StopBuildStep. + size_t build_segment_size_ = 0; + // Offset in the cache file when StartBuildStep was called. + size_t build_segment_start_ = 0; + // The call to StopBuildStep may short circuit when nothing was written to the + // cache. To ensure a smooth reloading, we need to ensure that the file header + // is correct. This flag lets us know if that has happened. + bool first_write_done_ = false; // Temporary file descriptor to write the weights to disk immediately. FileDescriptor fd_; std::string file_path_; + + bool is_build_step_ = false; }; // Allows XNNPack to directly load packed weights from disk instead of having to @@ -269,10 +311,25 @@ class MMapWeightCacheProvider { [[nodiscard /*Loading a cache file may fail.*/]] bool Load(const std::string& path); - // Loads the weight cache previouslt set with `SetFilePath`. + // Loads the weight cache previously set with `SetFilePath`. [[nodiscard /*Loading cache data may fail.*/]] bool Load(); + // Checks if the cache is currently being built or if it was loaded from a + // file. + [[nodiscard]] + bool CanStartBuildStep() const { + return building_run_; + }; + + // Prepares to add new data to the cache. + [[nodiscard /*Updating cache data may fail.*/]] + bool StartBuildStep(); + + // Prepares to use data that was added to the cache during a build step. + [[nodiscard /*Updating cache data may fail.*/]] + bool StopBuildStep(); + // Creates the tensor map. void MapTensorIdentifiers( const TfLiteTensor* tensors, size_t size, @@ -315,21 +372,17 @@ class MMapWeightCacheProvider { // Releases the weight cache's memory. void Release(); - // Ensures that the cache is ready. - // - // If the cache file already exists, this is a no-op. Otherwise, this writes - // the file to disk and reloads it. - [[nodiscard /*Writing the cache file may fail.*/]] - bool Finalize(); - - // Checks whether the cache is ready to be used. - bool IsFinalized() const; - // Returns true if any weights have been added to the underlying builder. - bool IsBuilding() const { return !IsFinalized() && !file_path_.empty(); }; + [[nodiscard]] + bool IsBuilding() const { + return is_build_step_; + }; // Returns true if a file is mapped or a file path is set. - bool IsActive() const { return IsFinalized() || !file_path_.empty(); }; + [[nodiscard]] + bool IsActive() const { + return !mmap_handles_.empty() || builder_.IsStarted(); + }; // Returns the cache provider expected by XNNPack. xnn_weights_cache_provider& GetCacheProvider() { return cache_provider_; } @@ -359,6 +412,10 @@ class MMapWeightCacheProvider { // Hashes a cache key to lookup in `cache_key_to_identifier_`. PackIdentifier BuildPackIdentifier(const xnn_weights_cache_look_up_key& key); + // Loads the data written by the last call to `builder_.BuildStepStop()`. + [[nodiscard /*Loading cache data may fail.*/]] + bool LoadLastBuildStep(); + // Cache provider implementation for XNNPack. xnn_weights_cache_provider cache_provider_{ /*context=*/this, @@ -382,7 +439,7 @@ class MMapWeightCacheProvider { cache_key_to_offset_; // MMap allocation handler. - MMapHandle mmap_handle_; + std::vector mmap_handles_; // The offset to the first buffer data in the MMap allocation. size_t mmap_buffer_base_offset_; @@ -393,6 +450,23 @@ class MMapWeightCacheProvider { // Used to build the cache. WeightCacheBuilder builder_; + + // True if the current run is the one building the cache file. + // + // We cannot distinguish between a wrong/outdated cache and one that is not + // fully done. To detect misuse, we still want to raise an error when XNNPack + // tries to append data to an existing file (i.e. when this is `false`). + bool building_run_ = false; + + // True between StartBuildStep and StopBuildStep. + // + // This is used to check whether the builder is active, which means that some + // of the buffers are not available/can't be retrieved. + bool is_build_step_ = false; + + // Stores the loaded buffer addresses corresponding to the given offset in the + // cache file. + std::map offset_to_addr_; }; } // namespace xnnpack diff --git a/tensorflow/lite/delegates/xnnpack/weight_cache_test.cc b/tensorflow/lite/delegates/xnnpack/weight_cache_test.cc index ecbc04dbe40073..ea3ab354fb3a59 100644 --- a/tensorflow/lite/delegates/xnnpack/weight_cache_test.cc +++ b/tensorflow/lite/delegates/xnnpack/weight_cache_test.cc @@ -19,12 +19,14 @@ limitations under the License. #include #include #include +#include #include #include #include #include #include #include +#include #include #include #include @@ -36,6 +38,7 @@ limitations under the License. #include "xnnpack.h" // from @XNNPACK #include "flatbuffers/verifier.h" // from @flatbuffers #include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/delegates/xnnpack/file_util.h" #include "tensorflow/lite/delegates/xnnpack/weight_cache_schema_generated.h" #include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" @@ -52,6 +55,47 @@ namespace { using testing::ElementsAreArray; using testing::Ge; +std::string GenerateRandomString(const size_t size) { + constexpr char chars[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZ abcdefghijklmnopqrstuvwxyz"; + std::mt19937 rg{std::random_device{}()}; + std::uniform_int_distribution pick(0, + sizeof(chars) - 1); + std::string str(size, 'a'); + std::generate(begin(str), end(str), [&] { return pick(rg); }); + return str; +}; + +template +class LightSpan { + public: + using value_type = T; + + LightSpan(const void* data, const size_t size) + : ptr_(reinterpret_cast(data)), size_(size) {} + + size_t size() const { return size(); } + const T* begin() const { return ptr_; } + const T* end() const { return ptr_ + size_; } + + friend std::ostream& operator<<(std::ostream& os, const LightSpan& s) { + os << '['; + auto it = s.begin(); + if (it != s.end()) { + os << +*it; + } + ++it; + for (; it != s.end(); ++it) { + os << ", " << +*it; + } + return os << ']'; + } + + private: + T* ptr_; + size_t size_; +}; + // Wraps a call to `mkstemp` to create temporary files. class TempFileDesc { public: @@ -184,6 +228,82 @@ TEST(MMapHandleTest, MoveConstructs) { EXPECT_THAT(handle2, ElementsAreArray(payload)); } +TEST(MMapHandleTest, Resize) { + const std::string payload = "This is some data in the file."; + + TempFileDesc tmp_file; + ASSERT_TRUE(tmp_file.IsOpen()); + ASSERT_EQ(write(tmp_file.GetFd(), payload.c_str(), size(payload)), + size(payload)); + tmp_file.Close(); + + MMapHandle handle; + ASSERT_TRUE(handle.Map(tmp_file.GetCPath())); + +#if defined(__linux__) || defined(__ANDROID__) + const size_t kMaxResizeTestCount = 20; + bool was_resized = true; + for (size_t i = 0; i < kMaxResizeTestCount && was_resized; ++i) { + was_resized = handle.Resize(payload.size() * 2); + EXPECT_TRUE(was_resized || errno == ENOMEM); + } +#else + EXPECT_FALSE(handle.Resize(payload.size())); +#endif +} + +TEST(MMapHandleTest, MapWithOffset) { + const std::string payload = "This is some data in the file."; + const std::string payload2 = "Some other data appended to the the offset."; + + TempFileDesc tmp_file; + ASSERT_TRUE(tmp_file.IsOpen()); + ASSERT_EQ(write(tmp_file.GetFd(), payload.c_str(), size(payload)), + size(payload)); + ASSERT_EQ(write(tmp_file.GetFd(), payload2.c_str(), size(payload2)), + size(payload2)); + tmp_file.Close(); + + MMapHandle handle; + ASSERT_TRUE(handle.Map(tmp_file.GetCPath(), /*offset=*/size(payload))); + EXPECT_EQ(handle.size(), size(payload2)); + EXPECT_THAT(std::string((const char*)handle.data(), handle.size()), + testing::StrEq(payload2)); +} + +TEST(MMapHandleTest, ResizeMapWithOffset) { + const std::string payload = "This is some data in the file."; + const std::string payload2 = "Some other data appended to the the offset."; + const std::string payload3 = + "Yet some other data written after the initial mapping."; + + TempFileDesc tmp_file; + ASSERT_TRUE(tmp_file.IsOpen()); + ASSERT_EQ(write(tmp_file.GetFd(), payload.c_str(), size(payload)), + size(payload)); + ASSERT_EQ(write(tmp_file.GetFd(), payload2.c_str(), size(payload2)), + size(payload2)); + + MMapHandle handle; + ASSERT_TRUE(handle.Map(tmp_file.GetCPath(), /*offset=*/size(payload))); + + ASSERT_EQ(write(tmp_file.GetFd(), payload3.c_str(), size(payload3)), + size(payload3)); + tmp_file.Close(); +#if defined(__linux__) || defined(__ANDROID__) + bool was_resized = handle.Resize(payload2.size() + payload3.size()); + if (was_resized) { + EXPECT_THAT(std::string((const char*)handle.data(), handle.size()), + testing::StrEq(payload2 + payload3)); + } else { + GTEST_SKIP() + << "This run did not end up in a resize of the mmaped interval."; + } +#else + GTEST_SKIP() << "Resize is not supported for this build."; +#endif +} + TEST(WeightCacheBuilderTest, ReserveAppendWriteWorks) { using std::size; @@ -193,6 +313,7 @@ TEST(WeightCacheBuilderTest, ReserveAppendWriteWorks) { WeightCacheBuilder builder; const std::string cache_path = testing::TempDir() + "/cache"; ASSERT_TRUE(builder.Start(cache_path.c_str())); + ASSERT_TRUE(builder.StartBuildStep()); const size_t payload_size = size(payload); void* buffer = builder.Reserve(payload_size); @@ -201,9 +322,8 @@ TEST(WeightCacheBuilderTest, ReserveAppendWriteWorks) { EXPECT_EQ(loc.size, payload_size); EXPECT_GE(builder.capacity(), payload_size); - EXPECT_TRUE(builder.ShouldFinalize()); - ASSERT_TRUE(builder.Finalize()); + ASSERT_TRUE(builder.StopBuildStep()); MMapHandle handle; ASSERT_TRUE(handle.Map(cache_path.c_str())); @@ -258,14 +378,14 @@ TEST(WeightCacheBuilderTest, AppendWithoutReserveWriteWorks) { const std::string cache_path = testing::TempDir() + "/cache"; WeightCacheBuilder builder; ASSERT_TRUE(builder.Start(cache_path.c_str())); + ASSERT_TRUE(builder.StartBuildStep()); const size_t payload_size = size(payload); auto loc = builder.Append(dummy_id, payload.c_str(), payload_size); EXPECT_EQ(loc.size, payload_size); - EXPECT_TRUE(builder.ShouldFinalize()); - ASSERT_TRUE(builder.Finalize()); + ASSERT_TRUE(builder.StopBuildStep()); MMapHandle handle; ASSERT_TRUE(handle.Map(cache_path.c_str())); @@ -341,6 +461,127 @@ TEST(WeightCacheBuilderTest, InMemoryCacheTriggeredByCorrectPrefix) { } } +TEST(WeightCacheBuilderTest, MultipleStepBuild) { + using std::size; + + const std::string payload1 = "This is some data in the file."; + const PackIdentifier dummy_id1{1, 2, 3}; + const std::string payload2 = "Other data in the file."; + const PackIdentifier dummy_id2{2, 3, 4}; + const std::string payload3 = + GenerateRandomString(/*10 MiB*/ 10 * 1024 * 1024); + const PackIdentifier dummy_id3{3, 4, 5}; + + TempFileDesc tmp_file{TempFileDesc::kAutoClose}; + + WeightCacheBuilder builder; + ASSERT_TRUE(builder.Start(tmp_file.GetCPath())); + ASSERT_TRUE(builder.StartBuildStep()); + + { + const size_t payload_size = size(payload1); + void* buffer = builder.Reserve(payload_size); + std::memcpy(buffer, payload1.c_str(), payload_size); + const auto loc = builder.Append(dummy_id1, buffer, payload_size); + EXPECT_EQ(loc.size, payload_size); + EXPECT_GE(builder.capacity(), payload_size); + } + { + const size_t payload_size = size(payload3); + void* buffer = builder.Reserve(payload_size); + std::memcpy(buffer, payload3.c_str(), payload_size); + const auto loc = builder.Append(dummy_id3, buffer, payload_size); + (void)loc; + } + + ASSERT_TRUE(builder.StopBuildStep()); + + MMapHandle handle; + ASSERT_TRUE(handle.Map(tmp_file.GetCPath())); + + ASSERT_TRUE(builder.StartBuildStep()); + { + const size_t payload_size = size(payload2); + void* buffer = builder.Reserve(payload_size); + std::memcpy(buffer, payload2.c_str(), payload_size); + const auto loc = builder.Append(dummy_id2, buffer, payload_size); + EXPECT_EQ(loc.size, payload_size); + EXPECT_GE(builder.capacity(), payload_size); + } + + ASSERT_TRUE(builder.StopBuildStep()); + + ASSERT_TRUE(handle.Map(tmp_file.GetCPath())); + + const XNNPackCacheHeader& header = + *reinterpret_cast(handle.data()); + + ASSERT_EQ(header.version, XNNPackCacheHeader::kVersion); + ASSERT_NE(header.buffer_list_offset, 0); + ASSERT_NE(header.buffer_list_size, 0); + ASSERT_LE(header.buffer_list_offset + header.buffer_list_size, handle.size()); + + const cache::schema::BufferList* const packed_weights = + cache::schema::GetBufferList(handle.data() + header.buffer_list_offset); + + ASSERT_NE(packed_weights, nullptr); + ASSERT_NE(packed_weights->buffers(), nullptr); + ASSERT_EQ(packed_weights->buffers()->size(), 3); + // Payload 1. + const auto* buffer1 = packed_weights->buffers()->Get(0); + ASSERT_NE(buffer1, nullptr); + ASSERT_EQ(buffer1->size(), size(payload1)); + ASSERT_EQ(buffer1->packing_algorithm_id(), dummy_id1.pack_algorithm_id); + ASSERT_EQ(buffer1->weights_id(), dummy_id1.weights_id); + ASSERT_EQ(buffer1->bias_id(), dummy_id1.bias_id); + + // Payload 3. + const auto* buffer3 = packed_weights->buffers()->Get(1); + ASSERT_NE(buffer3, nullptr); + ASSERT_EQ(buffer3->size(), size(payload3)); + ASSERT_EQ(buffer3->packing_algorithm_id(), dummy_id3.pack_algorithm_id); + ASSERT_EQ(buffer3->weights_id(), dummy_id3.weights_id); + ASSERT_EQ(buffer3->bias_id(), dummy_id3.bias_id); + + // Payload 2. + const auto* buffer2 = packed_weights->buffers()->Get(2); + ASSERT_NE(buffer2, nullptr); + ASSERT_EQ(buffer2->size(), size(payload2)); + ASSERT_EQ(buffer2->packing_algorithm_id(), dummy_id2.pack_algorithm_id); + ASSERT_EQ(buffer2->weights_id(), dummy_id2.weights_id); + ASSERT_EQ(buffer2->bias_id(), dummy_id2.bias_id); + + flatbuffers::Verifier verifier(handle.data() + header.buffer_list_offset, + header.buffer_list_size); + EXPECT_TRUE(cache::schema::VerifyBufferListBuffer(verifier)); + + // Payload 1. + ASSERT_LE(packed_weights->base_offset() + buffer1->offset(), size(handle)); + ASSERT_LE(packed_weights->base_offset() + buffer1->offset() + buffer1->size(), + size(handle)); + + // Payload 2. + ASSERT_LE(packed_weights->base_offset() + buffer2->offset(), size(handle)); + ASSERT_LE(packed_weights->base_offset() + buffer2->offset() + buffer2->size(), + size(handle)); + + // Payload 3. + ASSERT_LE(packed_weights->base_offset() + buffer3->offset(), size(handle)); + ASSERT_LE(packed_weights->base_offset() + buffer3->offset() + buffer3->size(), + size(handle)); + + auto GetBufferData = [&handle, &packed_weights](const auto* buffer) { + return std::tuple( + reinterpret_cast( + handle.data() + packed_weights->base_offset() + buffer->offset()), + buffer->size()); + }; + + EXPECT_THAT(GetBufferData(buffer1), ElementsAreArray(payload1)); + EXPECT_THAT(GetBufferData(buffer2), ElementsAreArray(payload2)); + EXPECT_THAT(GetBufferData(buffer3), ElementsAreArray(payload3)); +} + struct FakeContext { // Adds a new tensor and it's backing buffer to the context. // @@ -447,12 +688,12 @@ struct BuildMMapWeightCacheProviderTest : testing::Test { ctx.FinalizeTensors(); cache_provider.MapTensorIdentifiers(ctx.tensors.data(), ctx.tensors.size(), ctx.tensor_buffer_identifiers); - const std::string cache_path = testing::TempDir() + "/cache"; - ASSERT_TRUE(cache_provider.StartBuild(cache_path.c_str())); + ASSERT_TRUE(cache_provider.StartBuild(tmp_file.GetCPath())); } FakeContext ctx; MMapWeightCacheProvider cache_provider; + TempFileDesc tmp_file{TempFileDesc::kAutoClose}; }; TEST_F(BuildMMapWeightCacheProviderTest, LookUpFailsIfKeyDoesntMatch) { @@ -462,8 +703,10 @@ TEST_F(BuildMMapWeightCacheProviderTest, LookUpFailsIfKeyDoesntMatch) { TEST_F(BuildMMapWeightCacheProviderTest, LookUpSucceeds) { enum { kWeightIndex, kBiasIndex }; + ASSERT_TRUE(cache_provider.StartBuildStep()); const auto pack_id = ctx.PackTensors(&cache_provider.GetCacheProvider(), kAlgoSeed1, kWeightIndex, kBiasIndex); + EXPECT_TRUE(cache_provider.StopBuildStep()); const xnn_weights_cache_look_up_key look_up_key = ctx.LookUpKey(kAlgoSeed1, kWeightIndex, kBiasIndex); @@ -474,10 +717,12 @@ TEST_F(BuildMMapWeightCacheProviderTest, LookUpSucceeds) { TEST_F(BuildMMapWeightCacheProviderTest, DifferentAlgoSeedsSameTensorsDontConflict) { enum { kWeightIndex, kBiasIndex }; + ASSERT_TRUE(cache_provider.StartBuildStep()); const auto pack_id_1 = ctx.PackTensors(&cache_provider.GetCacheProvider(), kAlgoSeed1, kWeightIndex, kBiasIndex); const auto pack_id_2 = ctx.PackTensors(&cache_provider.GetCacheProvider(), kAlgoSeed2, kWeightIndex, kBiasIndex); + EXPECT_TRUE(cache_provider.StopBuildStep()); const xnn_weights_cache_look_up_key look_up_key_1 = ctx.LookUpKey(kAlgoSeed1, kWeightIndex, kBiasIndex); @@ -495,6 +740,7 @@ TEST_F(BuildMMapWeightCacheProviderTest, TEST_F(BuildMMapWeightCacheProviderTest, SameAlgoSeedDifferentTensorsDontConflict) { enum { kWeightIndex1, kWeightIndex2, kBiasIndex1, kBiasIndex2 }; + ASSERT_TRUE(cache_provider.StartBuildStep()); const auto pack_id_1 = ctx.PackTensors(&cache_provider.GetCacheProvider(), kAlgoSeed1, kWeightIndex1, kBiasIndex1); @@ -507,6 +753,7 @@ TEST_F(BuildMMapWeightCacheProviderTest, const auto pack_id_4 = ctx.PackTensors(&cache_provider.GetCacheProvider(), kAlgoSeed1, kWeightIndex2, kBiasIndex2); + EXPECT_TRUE(cache_provider.StopBuildStep()); const xnn_weights_cache_look_up_key look_up_key_1 = ctx.LookUpKey(kAlgoSeed1, kWeightIndex1, kBiasIndex1); @@ -540,10 +787,9 @@ TEST_F(BuildMMapWeightCacheProviderTest, cache_provider.LookUp(&look_up_key_4)); } -TEST_F(BuildMMapWeightCacheProviderTest, FinalizeWorks) { +TEST_F(BuildMMapWeightCacheProviderTest, BuildStepSequenceWorks) { enum { kWeightIndex1, kBiasIndex, kWeightIndex2 }; - TempFileDesc tmp_file(TempFileDesc::kAutoClose); - ASSERT_TRUE(cache_provider.StartBuild(tmp_file.GetCPath())); + ASSERT_TRUE(cache_provider.StartBuildStep()); ctx.PackTensors(&cache_provider.GetCacheProvider(), kAlgoSeed1, kWeightIndex1, kBiasIndex); @@ -552,9 +798,10 @@ TEST_F(BuildMMapWeightCacheProviderTest, FinalizeWorks) { EXPECT_TRUE(cache_provider.IsActive()); EXPECT_TRUE(cache_provider.IsBuilding()); - ASSERT_TRUE(cache_provider.Finalize()); + ASSERT_TRUE(cache_provider.StopBuildStep()); - ASSERT_TRUE(cache_provider.IsFinalized()); + ASSERT_TRUE(cache_provider.IsActive()); + EXPECT_FALSE(cache_provider.IsBuilding()); } struct LoadMMapWeightCacheProviderTest : BuildMMapWeightCacheProviderTest { @@ -562,15 +809,14 @@ struct LoadMMapWeightCacheProviderTest : BuildMMapWeightCacheProviderTest { void SetUp() override { BuildMMapWeightCacheProviderTest::SetUp(); - ASSERT_TRUE(cache_provider.StartBuild(tmp_file.GetCPath())); + ASSERT_TRUE(cache_provider.StartBuildStep()); pack_id_1 = ctx.PackTensors(&cache_provider.GetCacheProvider(), kAlgoSeed1, kWeightIndex1, kBiasIndex); pack_id_2 = ctx.PackTensors(&cache_provider.GetCacheProvider(), kAlgoSeed2, kWeightIndex2); - ASSERT_TRUE(cache_provider.Finalize()); - ASSERT_TRUE(cache_provider.IsFinalized()); + ASSERT_TRUE(cache_provider.StopBuildStep()); } xnn_weights_cache_look_up_key LookUpKey1() const { @@ -581,7 +827,6 @@ struct LoadMMapWeightCacheProviderTest : BuildMMapWeightCacheProviderTest { return ctx.LookUpKey(kAlgoSeed2, kWeightIndex2); } - TempFileDesc tmp_file; PackIdentifier pack_id_1; PackIdentifier pack_id_2; }; @@ -591,36 +836,6 @@ TEST_F(LoadMMapWeightCacheProviderTest, LookUpFailsIfKeyDoesntMatch) { EXPECT_EQ(cache_provider.LookUp(&look_up_key), SIZE_MAX); } -template -class LightSpan { - public: - using value_type = T; - - LightSpan(const void* data, const size_t size) - : ptr_(reinterpret_cast(data)), size_(size) {} - - size_t size() const { return size(); } - const T* begin() const { return ptr_; } - const T* end() const { return ptr_ + size_; } - - friend std::ostream& operator<<(std::ostream& os, const LightSpan& s) { - os << '['; - auto it = s.begin(); - if (it != s.end()) { - os << +*it; - } - ++it; - for (; it != s.end(); ++it) { - os << ", " << +*it; - } - return os << ']'; - } - - private: - T* ptr_; - size_t size_; -}; - TEST_F(LoadMMapWeightCacheProviderTest, LookUpSucceeds) { const auto& reference_1 = ctx.packed_buffers.find(pack_id_1)->second; const auto& reference_2 = ctx.packed_buffers.find(pack_id_2)->second; @@ -652,6 +867,8 @@ TEST(MMapWeightCacheProviderTest, XnnpackCApiJourney) { const int32_t fake_packing_algo_seed = 0xBA0BAB; const char packed_data_ref_1[] = "abcdefghij"; const char packed_data_ref_2[] = "klmnopqr"; + const std::string packed_data_ref_3 = + GenerateRandomString(/*10 MiB*/ 10 * 1024 * 1024); auto bytes = [](const auto& array) { return size(array) * sizeof(array[0]); }; constexpr int kBufferCount = 10; @@ -660,6 +877,8 @@ TEST(MMapWeightCacheProviderTest, XnnpackCApiJourney) { char fake_buffer_pointer[kBufferCount] = {0}; { // Build and reload scenario. + // This isn't factored between the two scenarios. When reloading the cache + // in another process, the buffer addresses will have changed. TfLiteTensor tensors[kBufferCount]; std::unordered_map tensor_buffer_identifiers; for (int i = 0; i < kBufferCount; ++i) { @@ -669,6 +888,8 @@ TEST(MMapWeightCacheProviderTest, XnnpackCApiJourney) { MMapWeightCacheProvider cache_provider; ASSERT_TRUE(cache_provider.StartBuild(temp_fd.GetCPath())); + // 1st build step. + ASSERT_TRUE(cache_provider.StartBuildStep()); xnn_weights_cache_t cache = &cache_provider.GetCacheProvider(); cache_provider.MapTensorIdentifiers(tensors, size(tensors), @@ -679,6 +900,11 @@ TEST(MMapWeightCacheProviderTest, XnnpackCApiJourney) { .kernel = tensors[0].data.data, .bias = tensors[1].data.data}; + const xnn_weights_cache_look_up_key look_up_key_3{ + .seed = fake_packing_algo_seed, + .kernel = tensors[3].data.data, + .bias = tensors[4].data.data}; + // Lookup non-packed tensor. ASSERT_EQ(cache->look_up(cache, &look_up_key_1), SIZE_MAX); // Reserve space, write data and add packed data. @@ -689,25 +915,50 @@ TEST(MMapWeightCacheProviderTest, XnnpackCApiJourney) { const size_t build_offset_1 = cache->look_up_or_insert( cache, &look_up_key_1, reserved_ptr, bytes(packed_data_ref_1)); - // Check that a second insertion with the same key returns the same offset. + // Check that a second insertion with the same key returns the same + // offset. const size_t build_offset_redundant = cache->look_up_or_insert( cache, &look_up_key_1, reserved_ptr, bytes(packed_data_ref_1)); EXPECT_EQ(build_offset_1, build_offset_redundant); + // Lookup and insert other tensor + ASSERT_EQ(cache->look_up(cache, &look_up_key_3), SIZE_MAX); + void* const reserved_ptr_3 = + cache->reserve_space(cache, bytes(packed_data_ref_3)); + ASSERT_NE(reserved_ptr_3, nullptr); + std::memcpy(reserved_ptr_3, packed_data_ref_3.data(), + bytes(packed_data_ref_3)); + const size_t build_offset_3 = cache->look_up_or_insert( + cache, &look_up_key_3, reserved_ptr_3, bytes(packed_data_ref_3)); + + ASSERT_TRUE(cache_provider.StopBuildStep()); + // Lookup newly packed tensor. ASSERT_EQ(cache->look_up(cache, &look_up_key_1), build_offset_1); + ASSERT_EQ(cache->look_up(cache, &look_up_key_3), build_offset_3); + + // 2nd build step. + ASSERT_TRUE(cache_provider.StartBuildStep()); // Add a tensor without reserving before. const xnn_weights_cache_look_up_key look_up_key_2{ .seed = fake_packing_algo_seed, .kernel = tensors[2].data.data, .bias = tensors[3].data.data}; + const size_t build_offset_2 = cache->look_up_or_insert( cache, &look_up_key_2, (void*)packed_data_ref_2, bytes(packed_data_ref_2)); + // Buffer inserted during build step 1 can be looked up. + EXPECT_EQ(cache->look_up(cache, &look_up_key_3), build_offset_3); + // Reinsert buffer inserted during build step 1 should be a no-op. + EXPECT_EQ(cache->look_up_or_insert(cache, &look_up_key_3, reserved_ptr_3, + bytes(packed_data_ref_3)), + build_offset_3); + // Save the cache to disk and reload. - ASSERT_TRUE(cache_provider.Finalize()); + ASSERT_TRUE(cache_provider.StopBuildStep()); ASSERT_TRUE(cache->is_finalized(cache)); @@ -730,6 +981,16 @@ TEST(MMapWeightCacheProviderTest, XnnpackCApiJourney) { EXPECT_THAT( LightSpan(loaded_packed_data_2, size(packed_data_ref_2)), ElementsAreArray(packed_data_ref_2)); + + const size_t reload_offset_3 = cache->look_up(cache, &look_up_key_3); + ASSERT_EQ(reload_offset_3, build_offset_3); + + const void* const loaded_packed_data_3 = + cache->offset_to_addr(cache, reload_offset_3); + ASSERT_NE(loaded_packed_data_3, nullptr); + EXPECT_THAT( + LightSpan(loaded_packed_data_3, size(packed_data_ref_3)), + ElementsAreArray(packed_data_ref_3)); } { // Load existing cache scenario. @@ -757,6 +1018,11 @@ TEST(MMapWeightCacheProviderTest, XnnpackCApiJourney) { .kernel = tensors[2].data.data, .bias = tensors[3].data.data}; + const xnn_weights_cache_look_up_key look_up_key_3{ + .seed = fake_packing_algo_seed, + .kernel = tensors[3].data.data, + .bias = tensors[4].data.data}; + ASSERT_TRUE(cache->is_finalized(cache)); const size_t offset_1 = cache->look_up(cache, &look_up_key_1); @@ -775,6 +1041,14 @@ TEST(MMapWeightCacheProviderTest, XnnpackCApiJourney) { EXPECT_THAT( LightSpan(loaded_packed_data_2, size(packed_data_ref_2)), ElementsAreArray(packed_data_ref_2)); + + const size_t offset_3 = cache->look_up(cache, &look_up_key_3); + const void* const loaded_packed_data_3 = + cache->offset_to_addr(cache, offset_3); + ASSERT_NE(loaded_packed_data_3, nullptr); + EXPECT_THAT( + LightSpan(loaded_packed_data_3, size(packed_data_ref_3)), + ElementsAreArray(packed_data_ref_3)); } } diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index 99e2605779f8bd..16e9330296cde4 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -880,6 +881,7 @@ class Subgraph { } switch (registration->builtin_code) { + case kTfLiteBuiltinExpandDims: case kTfLiteBuiltinMean: case kTfLiteBuiltinPad: case kTfLiteBuiltinSum: @@ -1149,9 +1151,26 @@ class Subgraph { if (context->profiler) { flags |= XNN_FLAG_BASIC_PROFILING; } + + if (delegate.weight_cache_provider_.IsActive() && + delegate.weight_cache_provider_.CanStartBuildStep()) { + if (!delegate.weight_cache_provider_.StartBuildStep()) { + TF_LITE_KERNEL_LOG( + context, "XNNPack delegate failed to start cache build step."); + return nullptr; + } + } status = xnn_create_runtime_v4(subgraph.get(), delegate.weights_cache(), delegate.workspace(), delegate.threadpool(), flags, &runtime_ptr); + if (delegate.weight_cache_provider_.IsActive() && + delegate.weight_cache_provider_.CanStartBuildStep()) { + if (!delegate.weight_cache_provider_.StopBuildStep()) { + TF_LITE_KERNEL_LOG(context, + "XNNPack delegate failed to stop cache build step."); + return nullptr; + } + } if (status != xnn_status_success) { TF_LITE_KERNEL_LOG(context, "failed to create XNNPACK runtime"); return nullptr; @@ -1165,17 +1184,6 @@ class Subgraph { bool enable_subgraph_reshaping, Delegate* delegate) { std::lock_guard lock(delegate->workspace_mutex_); - // The weights cache needs to be finalized only once. Prepare will be called - // for each partition after all the partitions have been created (therefore - // all the weights are known and have been packed). - if (delegate->weight_cache_provider_.IsActive()) { - if (!delegate->weight_cache_provider_.Finalize()) { - TF_LITE_KERNEL_LOG(context, - "XNNPack delegate failed to finalize cache."); - return kTfLiteError; - } - } - if (enable_subgraph_reshaping) { xnn_status status = xnn_status_invalid_state; for (int i = 0; i < inputs_.size(); ++i) { @@ -1232,6 +1240,7 @@ class Subgraph { TfLiteStatus Invoke(TfLiteContext* context, bool enable_subgraph_reshaping, Delegate* delegate) { std::lock_guard lock(delegate->workspace_mutex_); + bool any_pointers_changed = false; for (std::pair io_info : externals_) { const TfLiteTensor& tensor = context->tensors[io_info.first]; @@ -2468,24 +2477,6 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus CheckTensorNonDynamicAllocation( - const Delegate& delegate, TfLiteContext* context, - const TfLiteTensor& tensor, int tensor_index, int node_index) { - // TODO(b/149120844): remove checks once dynamic tensors are supported - if (delegate.enable_subgraph_reshaping()) { - return kTfLiteOk; - } - if (tensor.allocation_type == kTfLiteDynamic) { - TF_LITE_MAYBE_KERNEL_LOG( - context, - "invalid allocation type in tensor #%d in node #%d: " - "expected non-dynamic tensor", - tensor_index, node_index); - return kTfLiteError; - } - return kTfLiteOk; - } - static TfLiteStatus CheckTensorStaticAllocation(TfLiteContext* context, const TfLiteTensor& tensor, int tensor_index, @@ -2725,6 +2716,10 @@ class Subgraph { case kTfLiteBuiltinElu: return VisitEluNode(subgraph, delegate, logging_context, node_index, node, context->tensors, input_output_tensors); + case kTfLiteBuiltinExpandDims: + return VisitExpandDimsNode(subgraph, delegate, logging_context, + node_index, node, context->tensors, + input_output_tensors); case kTfLiteBuiltinFullyConnected: { // FullyConnected with sparse weight has version 8, which cannot be // delegated to XNNPack. @@ -2789,9 +2784,10 @@ class Subgraph { case kTfLiteBuiltinSum: { const TfLiteReducerParams* reducer_params = static_cast(node->builtin_data); - return VisitSumNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, reducer_params, - input_output_tensors); + return VisitReduceNode(BuiltinOperator_SUM, xnn_reduce_sum, subgraph, + delegate, logging_context, node_index, node, + context->tensors, reducer_params, + input_output_tensors); } case kTfLiteBuiltinMaximum: return VisitMaximumNode(subgraph, delegate, logging_context, node_index, @@ -2799,10 +2795,10 @@ class Subgraph { case kTfLiteBuiltinMean: { const TfLiteReducerParams* reducer_params = static_cast(node->builtin_data); - - return VisitMeanNode(subgraph, delegate, logging_context, node_index, - node, context->tensors, reducer_params, - input_output_tensors); + return VisitReduceNode(BuiltinOperator_MEAN, xnn_reduce_mean, subgraph, + delegate, logging_context, node_index, node, + context->tensors, reducer_params, + input_output_tensors); } case kTfLiteBuiltinMinimum: return VisitMinimumNode(subgraph, delegate, logging_context, node_index, @@ -3018,16 +3014,10 @@ class Subgraph { const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); if (subgraph != nullptr) { const xnn_status status = xnn_define_abs( @@ -3058,9 +3048,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input1_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, input1_tensor, node->inputs->data[0], - node_index)); TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, input1_tensor, /*min_num_dims=*/0, /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], @@ -3070,9 +3057,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input2_tensor, node->inputs->data[1], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, input2_tensor, node->inputs->data[1], - node_index)); TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, input2_tensor, /*min_num_dims=*/0, /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[1], @@ -3082,9 +3066,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); if (input1_tensor.type != input2_tensor.type || input1_tensor.type != output_tensor.type) { @@ -3167,16 +3148,10 @@ class Subgraph { const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); TF_LITE_ENSURE_STATUS(CheckPoolingParams(logging_context, pool_params, BuiltinOperator_AVERAGE_POOL_2D, @@ -3233,15 +3208,6 @@ class Subgraph { TfLiteContext* logging_context, int node_index, TfLiteNode* node, const TfLiteTensor* tensors, const TfLiteBatchMatMulParams* params, const std::unordered_map& input_output_tensors) { - // Check whether all required options are supported. - if (params->adj_x) { - TF_LITE_MAYBE_KERNEL_LOG( - logging_context, - "failed to delegate %s node #%d. adj_x is not supported", - EnumNameBuiltinOperator(BuiltinOperator_BATCH_MATMUL), node_index); - return kTfLiteError; - } - // Check the input tensor types. const TfLiteTensor& input_a = tensors[node->inputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( @@ -3263,12 +3229,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, output_tensor, node->outputs->data[0], node_index)); - // Check the input tensor non-dynamic allocations. - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, input_a, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, input_b, node->inputs->data[1], node_index)); - // Check whether the dimensions are compatible. const int num_dims_a = NumDimensions(&input_a); if (num_dims_a < 2) { @@ -3277,7 +3237,7 @@ class Subgraph { "failed to delegate %s node #%d. Unsupported number " "of dimensions %d for tensor #%d, must be at least 2", EnumNameBuiltinOperator(BuiltinOperator_BATCH_MATMUL), node_index, - node->inputs->data[0], num_dims_a); + num_dims_a, node->inputs->data[0]); return kTfLiteError; } const int num_dims_b = NumDimensions(&input_b); @@ -3287,19 +3247,37 @@ class Subgraph { "failed to delegate %s node #%d. Unsupported number " "of dimensions %d for tensor #%d, must be at least 2", EnumNameBuiltinOperator(BuiltinOperator_BATCH_MATMUL), node_index, - node->inputs->data[1], num_dims_b); - return kTfLiteError; - } - if (params->adj_x) { - TF_LITE_MAYBE_KERNEL_LOG( - logging_context, - "failed to delegate %s node #%d. adj_x is not supported", - EnumNameBuiltinOperator(BuiltinOperator_BATCH_MATMUL), node_index); + num_dims_b, node->inputs->data[1]); return kTfLiteError; } // Create and attach the subgraph nodes. if (subgraph != nullptr) { + uint32_t input1_id = input_output_tensors.at(node->inputs->data[0]); + if (params->adj_x) { + // XNNPack does not support transposed A. Insert a transpose node. + uint32_t new_id = XNN_INVALID_VALUE_ID; + std::array dims; + assert(num_dims_a <= XNN_MAX_TENSOR_DIMS); + for (int i = 0; i < num_dims_a; ++i) { + dims[i] = input_a.dims->data[i]; + } + xnn_status status = xnn_define_tensor_value( + subgraph, xnn_datatype_fp32, num_dims_a, dims.data(), + /*data=*/nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0, &new_id); + if (status != xnn_status_success) { + return kTfLiteError; + } + std::array perm; + std::iota(perm.begin(), perm.end(), 0); + std::swap(perm[num_dims_a - 1], perm[num_dims_a - 2]); + status = xnn_define_static_transpose(subgraph, num_dims_a, perm.data(), + input1_id, new_id, /*flags=*/0); + if (status != xnn_status_success) { + return kTfLiteError; + } + input1_id = new_id; + } const uint32_t flags = params->adj_y ? XNN_FLAG_TRANSPOSE_B : 0; // If we're using dynamic quantization, we first need to convert the first @@ -3385,10 +3363,9 @@ class Subgraph { } // Define the conversion op for the quantized input_a. - if (xnn_status status = xnn_define_convert( - subgraph, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - dq_input_a_id, /*flags=*/0); + if (xnn_status status = xnn_define_convert(subgraph, + /*input_id=*/input1_id, + dq_input_a_id, /*flags=*/0); status != xnn_status_success) { TF_LITE_KERNEL_LOG( logging_context, "failed to delegate %s node #%d", @@ -3412,7 +3389,7 @@ class Subgraph { } else { // No conversion of the inputs necessary, just send them on their way. if (xnn_status status = xnn_define_batch_matrix_multiply( - subgraph, input_output_tensors.at(node->inputs->data[0]), + subgraph, input1_id, input_output_tensors.at(node->inputs->data[1]), input_output_tensors.at(node->outputs->data[0]), flags); status != xnn_status_success) { @@ -3438,16 +3415,10 @@ class Subgraph { const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); if (subgraph != nullptr) { const xnn_status status = xnn_define_ceiling( @@ -3481,9 +3452,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); // Check dimensions if (output_tensor.type == kTfLiteUInt8) { @@ -3518,9 +3486,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQUInt8Type( delegate, logging_context, input_tensor, node->inputs->data[i], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, input_tensor, node->inputs->data[i], - node_index)); } if (subgraph != nullptr) { @@ -3590,9 +3555,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, input_tensor, 4, node->inputs->data[0], BuiltinOperator_CONV_2D, node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& filter_tensor = tensors[node->inputs->data[1]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQCInt8Type( @@ -3634,9 +3596,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, output_tensor, 4, node->outputs->data[0], BuiltinOperator_CONV_2D, node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); bool dynamically_quantized = (delegate.enable_latest_operators() && (input_tensor.type == kTfLiteFloat32 && @@ -3805,9 +3764,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, input_tensor, 4, node->inputs->data[0], BuiltinOperator_DEPTHWISE_CONV_2D, node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& filter_tensor = tensors[node->inputs->data[1]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQCInt8Type( @@ -3849,9 +3805,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, output_tensor, 4, node->outputs->data[0], BuiltinOperator_DEPTHWISE_CONV_2D, node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); if (input_tensor.type != output_tensor.type || input_tensor.type != filter_tensor.type) { @@ -3926,17 +3879,11 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); if (depth_to_space_params->block_size <= 1) { TF_LITE_MAYBE_KERNEL_LOG( @@ -3977,9 +3924,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorQInt8OrQUInt8Type(delegate, logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, input_tensor, /*min_num_dims=*/0, /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], @@ -3988,9 +3932,6 @@ class Subgraph { const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); if (subgraph != nullptr) { const xnn_status status = xnn_define_convert( @@ -4020,9 +3961,6 @@ class Subgraph { const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, input1_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, input1_tensor, node->inputs->data[0], - node_index)); TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, input1_tensor, /*min_num_dims=*/0, /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], @@ -4031,9 +3969,6 @@ class Subgraph { const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, input2_tensor, node->inputs->data[1], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, input2_tensor, node->inputs->data[1], - node_index)); TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, input2_tensor, /*min_num_dims=*/0, /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[1], @@ -4042,9 +3977,6 @@ class Subgraph { const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); float output_min = -std::numeric_limits::infinity(); float output_max = +std::numeric_limits::infinity(); @@ -4084,17 +4016,11 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQInt8Type(delegate, logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQInt8Type(delegate, logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); if (subgraph != nullptr) { const xnn_status status = xnn_define_elu( @@ -4113,6 +4039,68 @@ class Subgraph { return kTfLiteOk; } + static TfLiteStatus VisitExpandDimsNode( + xnn_subgraph_t subgraph, const Delegate& delegate, + TfLiteContext* logging_context, int node_index, TfLiteNode* node, + const TfLiteTensor* tensors, + const std::unordered_map& input_output_tensors) { + return kTfLiteError; + TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( + logging_context, node, 2, 1, BuiltinOperator_EXPAND_DIMS, node_index)); + const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; + TF_LITE_ENSURE_STATUS( + CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, + node->inputs->data[0], node_index)); + const TfLiteTensor& axis_tensor = tensors[node->inputs->data[1]]; + TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation( + logging_context, axis_tensor, node->inputs->data[1], + BuiltinOperator_EXPAND_DIMS, node_index)); + + const size_t num_new_axes = NumElements(&axis_tensor); + if (num_new_axes != 1) { + TF_LITE_MAYBE_KERNEL_LOG(logging_context, + "unexpected number of axes (%d) in node #%d: " + "TFLite only supports 1 new axes", + num_new_axes, node_index); + return kTfLiteError; + } + const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; + TF_LITE_ENSURE_STATUS( + CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, + node->outputs->data[0], node_index)); + + size_t axis_value; + switch (axis_tensor.type) { + case kTfLiteInt32: + axis_value = *GetTensorData(&axis_tensor); + break; + case kTfLiteInt64: + axis_value = *GetTensorData(&axis_tensor); + break; + default: + TF_LITE_MAYBE_KERNEL_LOG(logging_context, + "unexpected axis type (%d) in node #%d: " + "int32 or int64 are supported", + axis_tensor.type, node_index); + return kTfLiteError; + } + if (subgraph != nullptr) { + const xnn_status status = xnn_define_static_expand_dims( + subgraph, /*num_new_axes=*/1, &axis_value, + /*input_id=*/input_output_tensors.at(node->inputs->data[0]), + /*output_id=*/input_output_tensors.at(node->outputs->data[0]), + /*flags=*/0); + if (status != xnn_status_success) { + TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", + EnumNameBuiltinOperator(BuiltinOperator_EXPAND_DIMS), + node_index); + return kTfLiteError; + } + } + + return kTfLiteOk; + } + static TfLiteStatus VisitFullyConnectedNode( xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, @@ -4130,21 +4118,14 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& filter_tensor = tensors[node->inputs->data[1]]; TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, filter_tensor, 2, node->inputs->data[1], BuiltinOperator_FULLY_CONNECTED, node_index)); // Dynamic filter is supported, but only for FP32. - if (delegate.support_dynamic_fully_connected_operator() && - filter_tensor.type == kTfLiteFloat32) { - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, filter_tensor, node->inputs->data[1], - node_index)); - } else { + if (!(delegate.support_dynamic_fully_connected_operator() && + filter_tensor.type == kTfLiteFloat32)) { TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQCInt4OrQCInt8Type( delegate, logging_context, filter_tensor, /*expected_quantized_dimension=*/0, node->inputs->data[1], @@ -4165,12 +4146,8 @@ class Subgraph { if (bias_tensor_id >= 0) { const TfLiteTensor& bias_tensor = tensors[bias_tensor_id]; // Dynamic bias is supported, but only for FP32. - if (delegate.support_dynamic_fully_connected_operator() && - bias_tensor.type == kTfLiteFloat32) { - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, bias_tensor, node->inputs->data[2], - node_index)); - } else { + if (!(delegate.support_dynamic_fully_connected_operator() && + bias_tensor.type == kTfLiteFloat32)) { const int num_bias_elements = NumElements(&bias_tensor); if (num_bias_elements != output_channels) { TF_LITE_MAYBE_KERNEL_LOG( @@ -4196,9 +4173,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); bool dynamically_quantized = (delegate.enable_latest_operators() && (input_tensor.type == kTfLiteFloat32 && @@ -4213,14 +4187,6 @@ class Subgraph { return kTfLiteError; } - if (NumDimensions(&input_tensor) == 0) { - TF_LITE_MAYBE_KERNEL_LOG( - logging_context, - "unexpected number of shape dimensions %d in tensor #%d", - NumDimensions(&input_tensor), node->inputs->data[0]); - return kTfLiteError; - } - if (filter_tensor.type == kTfLiteInt4 && input_channels % 2 == 1) { TF_LITE_MAYBE_KERNEL_LOG( logging_context, @@ -4358,16 +4324,10 @@ class Subgraph { const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); if (subgraph != nullptr) { const xnn_status status = xnn_define_floor( @@ -4430,16 +4390,10 @@ class Subgraph { const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); if (subgraph != nullptr) { const xnn_status status = xnn_define_hardswish( @@ -4471,17 +4425,11 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); if (!std::isnormal(leaky_relu_params->alpha) || leaky_relu_params->alpha == 0.0f) { @@ -4545,17 +4493,11 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); if (subgraph != nullptr) { const xnn_status status = xnn_define_sigmoid( @@ -4586,17 +4528,11 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); TF_LITE_ENSURE_STATUS(CheckPoolingParams( logging_context, pool_params, BuiltinOperator_MAX_POOL_2D, node_index)); @@ -4646,23 +4582,20 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitSumNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, const TfLiteReducerParams* reducer_params, + static TfLiteStatus VisitReduceNode( + const tflite::BuiltinOperator tflite_operator, + const xnn_reduce_operator reduce_operator, xnn_subgraph_t subgraph, + const Delegate& delegate, TfLiteContext* logging_context, int node_index, + TfLiteNode* node, const TfLiteTensor* tensors, + const TfLiteReducerParams* reducer_params, const std::unordered_map& input_output_tensors) { TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 2, 1, BuiltinOperator_SUM, node_index)); + logging_context, node, 2, 1, tflite_operator, node_index)); const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( - logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 4, - node->inputs->data[0], - BuiltinOperator_SUM, node_index)); TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); + CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, + node->inputs->data[0], node_index)); const TfLiteTensor& axes_tensor = tensors[node->inputs->data[1]]; TF_LITE_ENSURE_STATUS(CheckTensorType(logging_context, axes_tensor, @@ -4671,88 +4604,40 @@ class Subgraph { TF_LITE_ENSURE_STATUS(CheckAxesTensorShape( logging_context, axes_tensor, node->inputs->data[1], node_index)); TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation( - logging_context, axes_tensor, node->inputs->data[1], - BuiltinOperator_SUM, node_index)); + logging_context, axes_tensor, node->inputs->data[1], tflite_operator, + node_index)); const int32_t* axes_data = reinterpret_cast(axes_tensor.data.data); const int num_reduction_axes = NumElements(&axes_tensor); - switch (num_reduction_axes) { - case 1: - if (axes_data[0] != 2) { - TF_LITE_MAYBE_KERNEL_LOG( - logging_context, - "unsupported SUM reduction along non-spatial " - "axis %d in node %d", - axes_data[0], node_index); - return kTfLiteError; - } - break; - case 2: - if (std::min(axes_data[0], axes_data[1]) != 1 || - std::max(axes_data[0], axes_data[1]) != 2) { - TF_LITE_MAYBE_KERNEL_LOG( - logging_context, - "unsupported SUM reduction along non-spatial " - "axes %d and %d in node %d", - std::min(axes_data[0], axes_data[1]), - std::max(axes_data[0], axes_data[1]), node_index); - return kTfLiteError; - } - break; - default: - TF_LITE_MAYBE_KERNEL_LOG( - logging_context, - "unsupported SUM reduction along %d axes in node %d", - SizeOfDimension(&axes_tensor, 0), node_index); - return kTfLiteError; - } - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, node->outputs->data[0], node_index)); - int expected_output_dims = 4; + uint32_t flags = 0; - if (!reducer_params->keep_dims) { - expected_output_dims -= num_reduction_axes; - } else { - flags = XNN_FLAG_KEEP_DIMS; + if (reducer_params->keep_dims) { + flags |= XNN_FLAG_KEEP_DIMS; } - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, output_tensor, expected_output_dims, - node->outputs->data[0], BuiltinOperator_SUM, node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); - - const float output_min = -std::numeric_limits::infinity(); - const float output_max = +std::numeric_limits::infinity(); if (subgraph != nullptr) { - xnn_status status = xnn_status_success; - switch (num_reduction_axes) { - case 1: - status = xnn_define_global_sum_pooling_1d( - subgraph, output_min, output_max, - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - flags); - break; - case 2: - status = xnn_define_global_sum_pooling_2d( - subgraph, output_min, output_max, + std::array reduction_axes; + for (int i = 0; i < num_reduction_axes; ++i) { + if (axes_data[i] < 0) { + reduction_axes[i] = axes_data[i] + NumDimensions(&input_tensor); + } else { + reduction_axes[i] = axes_data[i]; + } + } + std::sort(&reduction_axes[0], &reduction_axes[num_reduction_axes]); + if (xnn_define_static_reduce( + subgraph, reduce_operator, num_reduction_axes, + reduction_axes.data(), /*input_id=*/input_output_tensors.at(node->inputs->data[0]), /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - flags); - break; - default: - status = xnn_status_unsupported_parameter; - break; - } - if (status != xnn_status_success) { + flags) != xnn_status_success) { TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_SUM), + EnumNameBuiltinOperator(tflite_operator), node_index); return kTfLiteError; } @@ -4772,23 +4657,14 @@ class Subgraph { const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, input1_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, input1_tensor, node->inputs->data[0], - node_index)); const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, input2_tensor, node->inputs->data[1], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, input2_tensor, node->inputs->data[1], - node_index)); const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); if (subgraph != nullptr) { const xnn_status status = xnn_define_maximum2( @@ -4808,164 +4684,6 @@ class Subgraph { return kTfLiteOk; } - static TfLiteStatus VisitMeanNode( - xnn_subgraph_t subgraph, const Delegate& delegate, - TfLiteContext* logging_context, int node_index, TfLiteNode* node, - const TfLiteTensor* tensors, const TfLiteReducerParams* reducer_params, - const std::unordered_map& input_output_tensors) { - TF_LITE_ENSURE_STATUS(CheckNumInputsAndOutputs( - logging_context, node, 2, 1, BuiltinOperator_MEAN, node_index)); - - const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); - - const TfLiteTensor& axes_tensor = tensors[node->inputs->data[1]]; - TF_LITE_ENSURE_STATUS(CheckTensorType(logging_context, axes_tensor, - kTfLiteInt32, node->inputs->data[1], - node_index)); - TF_LITE_ENSURE_STATUS(CheckAxesTensorShape( - logging_context, axes_tensor, node->inputs->data[1], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation( - logging_context, axes_tensor, node->inputs->data[1], - BuiltinOperator_MEAN, node_index)); - - const int32_t* axes_data = - reinterpret_cast(axes_tensor.data.data); - const int num_reduction_axes = NumElements(&axes_tensor); - bool all_reductions_supported = false; - bool use_legacy_path = false; - if (input_tensor.type == kTfLiteFloat32) { - all_reductions_supported = true; - if (NumDimensions(&input_tensor) == 4) { - use_legacy_path = true; - } - } else { - TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 4, - node->inputs->data[0], - BuiltinOperator_MEAN, node_index)); - } - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, - node->outputs->data[0], node_index)); - switch (num_reduction_axes) { - case 1: - if (axes_data[0] != 2) { - if (all_reductions_supported) { - use_legacy_path = false; - } else { - TF_LITE_MAYBE_KERNEL_LOG( - logging_context, - "unsupported MEAN reduction along non-spatial " - "axis %d in node %d", - axes_data[0], node_index); - return kTfLiteError; - } - } - break; - case 2: - if (std::min(axes_data[0], axes_data[1]) != 1 || - std::max(axes_data[0], axes_data[1]) != 2) { - if (all_reductions_supported) { - use_legacy_path = false; - } else { - TF_LITE_MAYBE_KERNEL_LOG( - logging_context, - "unsupported MEAN reduction along non-spatial " - "axes %d and %d in node %d", - std::min(axes_data[0], axes_data[1]), - std::max(axes_data[0], axes_data[1]), node_index); - return kTfLiteError; - } - } - break; - default: - if (all_reductions_supported) { - use_legacy_path = false; - } else { - TF_LITE_MAYBE_KERNEL_LOG( - logging_context, - "unsupported MEAN reduction along %d axes in node %d", - SizeOfDimension(&axes_tensor, 0), node_index); - return kTfLiteError; - } - } - int expected_output_dims = 4; - if (!reducer_params->keep_dims) { - expected_output_dims -= num_reduction_axes; - } - if (NumDimensions(&output_tensor) != expected_output_dims) { - if (all_reductions_supported) { - use_legacy_path = false; - } else { - TF_LITE_ENSURE_STATUS(CheckTensorShape( - logging_context, output_tensor, expected_output_dims, - node->outputs->data[0], BuiltinOperator_MEAN, node_index)); - } - } - - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); - - if (subgraph != nullptr) { - uint32_t flags = reducer_params->keep_dims ? XNN_FLAG_KEEP_DIMS : 0; - xnn_status status = xnn_status_success; - if (all_reductions_supported && !use_legacy_path) { - std::array reduction_axes; - for (int i = 0; i < num_reduction_axes; ++i) { - if (axes_data[i] < 0) { - reduction_axes[i] = axes_data[i] + NumDimensions(&input_tensor); - } else { - reduction_axes[i] = axes_data[i]; - } - } - std::sort(&reduction_axes[0], &reduction_axes[num_reduction_axes]); - status = xnn_define_static_mean( - subgraph, num_reduction_axes, reduction_axes.data(), - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - flags); - } else { - switch (num_reduction_axes) { - case 1: - status = xnn_define_global_average_pooling_1d( - subgraph, - /*output_min=*/-std::numeric_limits::infinity(), - /*output_max=*/+std::numeric_limits::infinity(), - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - flags); - break; - case 2: - status = xnn_define_global_average_pooling_2d( - subgraph, - /*output_min=*/-std::numeric_limits::infinity(), - /*output_max=*/+std::numeric_limits::infinity(), - /*input_id=*/input_output_tensors.at(node->inputs->data[0]), - /*output_id=*/input_output_tensors.at(node->outputs->data[0]), - flags); - break; - default: - break; - } - } - if (status != xnn_status_success) { - TF_LITE_KERNEL_LOG(logging_context, "failed to delegate %s node #%d", - EnumNameBuiltinOperator(BuiltinOperator_MEAN), - node_index); - return kTfLiteError; - } - } - - return kTfLiteOk; - } - static TfLiteStatus VisitMediaPipeDeconvolutionNode( xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, @@ -4982,9 +4700,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 4, node->inputs->data[0], BuiltinOperator_CUSTOM, node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& filter_tensor = tensors[node->inputs->data[1]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( @@ -5016,9 +4731,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 4, node->outputs->data[0], BuiltinOperator_CUSTOM, node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); const int* input_tensor_dims = input_tensor.dims->data; const int input_height = input_tensor_dims[1]; @@ -5100,9 +4812,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 4, node->inputs->data[0], BuiltinOperator_CUSTOM, node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& output_value_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS( @@ -5111,17 +4820,11 @@ class Subgraph { TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_value_tensor, 4, node->outputs->data[0], BuiltinOperator_CUSTOM, node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_value_tensor, node->outputs->data[0], - node_index)); const TfLiteTensor& output_index_tensor = tensors[node->outputs->data[1]]; TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_index_tensor, 4, node->outputs->data[1], BuiltinOperator_CUSTOM, node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_index_tensor, node->outputs->data[1], - node_index)); TF_LITE_ENSURE_STATUS( CheckMediaPipePoolParams(logging_context, pool_params, node_index)); @@ -5169,17 +4872,11 @@ class Subgraph { TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_value_tensor, 4, node->inputs->data[0], BuiltinOperator_CUSTOM, node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, input_value_tensor, node->inputs->data[0], - node_index)); const TfLiteTensor& input_index_tensor = tensors[node->inputs->data[1]]; TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_index_tensor, 4, node->inputs->data[1], BuiltinOperator_CUSTOM, node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, input_index_tensor, node->inputs->data[1], - node_index)); const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( @@ -5187,9 +4884,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 4, node->outputs->data[0], BuiltinOperator_CUSTOM, node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); TF_LITE_ENSURE_STATUS( CheckMediaPipePoolParams(logging_context, pool_params, node_index)); @@ -5237,23 +4931,14 @@ class Subgraph { const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, input1_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, input1_tensor, node->inputs->data[0], - node_index)); const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, input2_tensor, node->inputs->data[1], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, input2_tensor, node->inputs->data[1], - node_index)); const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); if (subgraph != nullptr) { const xnn_status status = xnn_define_minimum2( @@ -5285,9 +4970,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input1_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, input1_tensor, node->inputs->data[0], - node_index)); TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, input1_tensor, /*min_num_dims=*/0, /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], @@ -5297,9 +4979,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input2_tensor, node->inputs->data[1], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, input2_tensor, node->inputs->data[1], - node_index)); TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, input2_tensor, /*min_num_dims=*/0, /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[1], @@ -5309,9 +4988,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); const float scale_min = 1.0f / 65536.0f; const float scale_max = 256.0f; @@ -5356,16 +5032,10 @@ class Subgraph { const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); if (subgraph != nullptr) { const xnn_status status = xnn_define_negate( @@ -5399,9 +5069,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, input_tensor, 1, XNN_MAX_TENSOR_DIMS, node->inputs->data[0], BuiltinOperator_PAD, node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& paddings_tensor = tensors[node->inputs->data[1]]; TF_LITE_ENSURE_STATUS(CheckTensorType(logging_context, paddings_tensor, @@ -5421,9 +5088,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, output_tensor, 1, XNN_MAX_TENSOR_DIMS, node->outputs->data[0], BuiltinOperator_PAD, node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); const int32_t* paddings_data = reinterpret_cast(paddings_tensor.data.data); @@ -5487,9 +5151,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, input_tensor, 1, XNN_MAX_TENSOR_DIMS, node->inputs->data[0], BuiltinOperator_PRELU, node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& slope_tensor = tensors[node->inputs->data[1]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( @@ -5509,9 +5170,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, output_tensor, 1, XNN_MAX_TENSOR_DIMS, node->outputs->data[0], BuiltinOperator_PRELU, node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); if (subgraph != nullptr) { const xnn_status status = xnn_define_prelu( @@ -5543,17 +5201,11 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS( CheckTensorQInt8OrQUInt8Type(delegate, logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, input_tensor, /*min_num_dims=*/0, /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], @@ -5667,16 +5319,10 @@ class Subgraph { const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); if (subgraph != nullptr) { const xnn_status status = xnn_define_clamp( @@ -5728,9 +5374,6 @@ class Subgraph { logging_context, input_tensor, /*min_num_dims=*/0, /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], BuiltinOperator_RESHAPE, node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); std::array new_shape; int num_new_dimensions; @@ -5772,9 +5415,6 @@ class Subgraph { logging_context, output_tensor, /*min_num_dims=*/0, /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->outputs->data[0], BuiltinOperator_RESHAPE, node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); if (output_tensor.type == kTfLiteUInt8 || output_tensor.type == kTfLiteInt8) { @@ -5831,9 +5471,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, input_tensor, 4, node->inputs->data[0], BuiltinOperator_RESIZE_BILINEAR, node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& shape_tensor = tensors[node->inputs->data[1]]; TF_LITE_ENSURE_STATUS(CheckTensorType(logging_context, shape_tensor, @@ -5859,9 +5496,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, output_tensor, 4, node->outputs->data[0], BuiltinOperator_RESIZE_BILINEAR, node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); const int32_t* shape_data = reinterpret_cast(shape_tensor.data.data); @@ -5910,16 +5544,10 @@ class Subgraph { const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); if (subgraph != nullptr) { const xnn_status status = xnn_define_bankers_rounding( @@ -5982,16 +5610,10 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, input_tensor_index, node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - input_tensor_index, node_index)); TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, output_tensor_index, node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, output_tensor_index, - node_index)); std::array begin; std::array size; @@ -6057,16 +5679,10 @@ class Subgraph { const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); if (subgraph != nullptr) { const xnn_status status = xnn_define_softmax( @@ -6099,17 +5715,11 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); const int block_size = space_to_depth_params->block_size; if (block_size <= 1) { @@ -6181,8 +5791,6 @@ class Subgraph { const TfLiteTensor& input_tensor = tensors[input_idx]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQUInt8Type( delegate, logging_context, input_tensor, input_idx, node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, input_tensor, input_idx, node_index)); int32_t split_dim = GetTensorData(&split_dim_tensor)[0]; @@ -6192,8 +5800,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQUInt8Type( delegate, logging_context, output_tensor, output_idx, node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, output_idx, node_index)); } if (subgraph != nullptr) { @@ -6246,9 +5852,6 @@ class Subgraph { const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, input_tensor, /*min_num_dims=*/0, /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], @@ -6257,9 +5860,6 @@ class Subgraph { const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); if (subgraph != nullptr) { const xnn_status status = xnn_define_square( @@ -6290,17 +5890,11 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); if (subgraph != nullptr) { const xnn_status status = xnn_define_tanh( @@ -6328,9 +5922,6 @@ class Subgraph { logging_context, node, 2, 1, BuiltinOperator_TRANSPOSE, node_index)); const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, @@ -6342,10 +5933,6 @@ class Subgraph { const int* perm_data = GetTensorData(&perm_tensor); - const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); const int dims_count = NumElements(&perm_tensor); std::array perm; for (int i = 0; i < dims_count; ++i) { @@ -6384,16 +5971,10 @@ class Subgraph { const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, input_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - node->inputs->data[0], node_index)); const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); if (subgraph != nullptr) { const xnn_status status = xnn_define_square_root( @@ -6456,9 +6037,6 @@ class Subgraph { const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, input1_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, input1_tensor, node->inputs->data[0], - node_index)); TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, input1_tensor, /*min_num_dims=*/0, /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], @@ -6467,9 +6045,6 @@ class Subgraph { const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, input2_tensor, node->inputs->data[1], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, input2_tensor, node->inputs->data[1], - node_index)); TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, input2_tensor, /*min_num_dims=*/0, /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[1], @@ -6478,9 +6053,6 @@ class Subgraph { const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); if (subgraph != nullptr) { const xnn_status status = xnn_define_squared_difference( @@ -6603,16 +6175,10 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor, input_tensor_index, node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - input_tensor_index, node_index)); TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, output_tensor_index, node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, output_tensor_index, - node_index)); auto begin_data = GetTensorData(&begin_tensor); auto end_data = GetTensorData(&end_tensor); @@ -7139,9 +6705,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input1_tensor, node->inputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, input1_tensor, node->inputs->data[0], - node_index)); TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, input1_tensor, /*min_num_dims=*/0, /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[0], @@ -7151,9 +6714,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input2_tensor, node->inputs->data[1], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, input2_tensor, node->inputs->data[1], - node_index)); TF_LITE_ENSURE_STATUS(CheckTensorShape( logging_context, input2_tensor, /*min_num_dims=*/0, /*max_num_dims=*/XNN_MAX_TENSOR_DIMS, node->inputs->data[1], @@ -7163,9 +6723,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor, node->outputs->data[0], node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, node->outputs->data[0], - node_index)); const float scale_min = 1.0f / 1024.0f; const float scale_max = 256.0f; @@ -7257,9 +6814,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorShape(logging_context, input_tensor, 4, input_tensor_index, BuiltinOperator_TRANSPOSE_CONV, node_index)); - TF_LITE_ENSURE_STATUS( - CheckTensorNonDynamicAllocation(delegate, logging_context, input_tensor, - input_tensor_index, node_index)); bool dynamically_quantized = (input_tensor.type == kTfLiteFloat32 && filter_tensor.type == kTfLiteInt8); @@ -7297,9 +6851,6 @@ class Subgraph { TF_LITE_ENSURE_STATUS( CheckTensorShape(logging_context, output_tensor, 4, output_tensor_index, BuiltinOperator_TRANSPOSE_CONV, node_index)); - TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation( - delegate, logging_context, output_tensor, output_tensor_index, - node_index)); const int* input_tensor_dims = input_tensor.dims->data; const int input_height = input_tensor_dims[1]; @@ -7563,7 +7114,6 @@ class Subgraph { TfLiteIntArray* Delegate::PrepareOpsToDelegate(TfLiteContext* context) { // Clear previous data, in case the delegate is reused without re-creation. static_unpacked_data_map_.clear(); - static_unpacked_data_.clear(); static_unpack_nodes_.clear(); static_sparse_weights_.clear(); variable_holder_.ClearTensorIdToGlobalId(); diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h index 9a2c3b34c680a5..6be79f0bea7aef 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_XNNPACK_DELEGATE_H_ #define TENSORFLOW_LITE_DELEGATES_XNNPACK_XNNPACK_DELEGATE_H_ +#include +#include + #include "tensorflow/lite/core/c/common.h" #ifdef __cplusplus diff --git a/tensorflow/lite/experimental/acceleration/configuration/BUILD b/tensorflow/lite/experimental/acceleration/configuration/BUILD index f4f92bcbe6860b..f4bcbb9616caa4 100644 --- a/tensorflow/lite/experimental/acceleration/configuration/BUILD +++ b/tensorflow/lite/experimental/acceleration/configuration/BUILD @@ -155,7 +155,7 @@ cc_library( srcs = ["hexagon_plugin.cc"], deps = [ ":configuration_fbs", - "//tensorflow/lite/core/experimental/acceleration/configuration:delegate_registry", + "//tensorflow/lite/core/acceleration/configuration:delegate_registry", "@com_google_absl//absl/memory", ] + select({ "@platforms//cpu:aarch64": [ @@ -208,7 +208,7 @@ cc_library( deps = [ ":configuration_fbs", "//tensorflow/lite:minimal_logging", - "//tensorflow/lite/core/experimental/acceleration/configuration:delegate_registry", + "//tensorflow/lite/core/acceleration/configuration:delegate_registry", "@com_google_absl//absl/memory", ] + select({ "//tensorflow:macos": [ diff --git a/tensorflow/lite/experimental/acceleration/configuration/coreml_plugin.cc b/tensorflow/lite/experimental/acceleration/configuration/coreml_plugin.cc index 855c6f7e4cac52..1db5ffb2e1251b 100644 --- a/tensorflow/lite/experimental/acceleration/configuration/coreml_plugin.cc +++ b/tensorflow/lite/experimental/acceleration/configuration/coreml_plugin.cc @@ -15,8 +15,8 @@ limitations under the License. #include #include "absl/memory/memory.h" -#include "tensorflow/lite/core/experimental/acceleration/configuration/delegate_registry.h" -#include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h" +#include "tensorflow/lite/acceleration/configuration/configuration_generated.h" +#include "tensorflow/lite/core/acceleration/configuration/delegate_registry.h" #include "tensorflow/lite/minimal_logging.h" // Guarding anyway although this file not expected to be compiled for non-Apple. diff --git a/tensorflow/lite/experimental/acceleration/configuration/hexagon_plugin.cc b/tensorflow/lite/experimental/acceleration/configuration/hexagon_plugin.cc index 809137bb0870be..daf6edb452c348 100644 --- a/tensorflow/lite/experimental/acceleration/configuration/hexagon_plugin.cc +++ b/tensorflow/lite/experimental/acceleration/configuration/hexagon_plugin.cc @@ -15,8 +15,8 @@ limitations under the License. #include #include "absl/memory/memory.h" -#include "tensorflow/lite/core/experimental/acceleration/configuration/delegate_registry.h" -#include "tensorflow/lite/experimental/acceleration/configuration/configuration_generated.h" +#include "tensorflow/lite/acceleration/configuration/configuration_generated.h" +#include "tensorflow/lite/core/acceleration/configuration/delegate_registry.h" #if defined(__ARM_ARCH) #include "tensorflow/lite/delegates/hexagon/hexagon_delegate.h" diff --git a/tensorflow/lite/experimental/genai/external_kvcache.cc b/tensorflow/lite/experimental/genai/external_kvcache.cc index 765a132bdabf33..e5340cffa6b2bd 100644 --- a/tensorflow/lite/experimental/genai/external_kvcache.cc +++ b/tensorflow/lite/experimental/genai/external_kvcache.cc @@ -109,13 +109,10 @@ TfLiteStatus ExternalKVCacheEval(TfLiteContext* context, TfLiteNode* node) { GetOutputSafe(context, node, kKeyTensor, &updated_k_cache)); TF_LITE_ENSURE_OK( context, GetOutputSafe(context, node, kValueTensor, &updated_v_cache)); - TF_LITE_ENSURE_EQ(context, k_cache->allocation_type, kTfLiteCustom); - TF_LITE_ENSURE_EQ(context, v_cache->allocation_type, kTfLiteCustom); - TF_LITE_ENSURE_EQ(context, updated_k_cache->allocation_type, kTfLiteCustom); - TF_LITE_ENSURE_EQ(context, updated_v_cache->allocation_type, kTfLiteCustom); - // If input and output buffers are not the same, copy the input buffer to the - // output buffer before inserting the new slice. + // Note: For the best performance, the following memcpys should be avoided. + // The way to avoid that is to take advantage of CustomAllocation and use + // the same buffer for both input and output. if (k_cache->data.raw != updated_k_cache->data.raw) { memcpy(updated_k_cache->data.data, k_cache->data.data, k_cache->bytes); } diff --git a/tensorflow/lite/experimental/litert/BUILD b/tensorflow/lite/experimental/litert/BUILD new file mode 100644 index 00000000000000..23b07d5602d7c8 --- /dev/null +++ b/tensorflow/lite/experimental/litert/BUILD @@ -0,0 +1,18 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], +) diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/c/BUILD b/tensorflow/lite/experimental/litert/build_common/BUILD similarity index 59% rename from tensorflow/compiler/mlir/lite/experimental/lrt/c/BUILD rename to tensorflow/lite/experimental/litert/build_common/BUILD index d7cd167f2e40f3..b6b545ed68e824 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/c/BUILD +++ b/tensorflow/lite/experimental/litert/build_common/BUILD @@ -14,22 +14,7 @@ package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/compiler/mlir/lite/experimental/lrt:__subpackages__"], + default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], ) -cc_library( - name = "lite_rt_c_api", - hdrs = [ - "lite_rt_common.h", - "lite_rt_compiler_plugin.h", - "lite_rt_model.h", - "lite_rt_op_code.h", - "lite_rt_support.h", - ], - deps = [ - "//tensorflow/lite:builtin_ops", - "//tensorflow/lite/core/c:c_api_types", - ], -) - -exports_files(srcs = glob(["lite_rt*.h"])) +exports_files(srcs = ["export_litert_only.lds"]) diff --git a/tensorflow/lite/experimental/litert/build_common/export_litert_only.lds b/tensorflow/lite/experimental/litert/build_common/export_litert_only.lds new file mode 100644 index 00000000000000..81b8a0b014bc20 --- /dev/null +++ b/tensorflow/lite/experimental/litert/build_common/export_litert_only.lds @@ -0,0 +1,36 @@ +VERS_1.0 { + + /* + Export abi-stable "vendor" implemented symbols. + + TODO: Add all vendor symbols. Also export qnn libc++ symbols + (statically linked) as "protected" as needed. + */ + + global: + + /* Compiler Plugin */ + + *LiteRtCompilerPlugin; + *LiteRtPluginInit; + *LiteRtPluginDestroy; + *LiteRtPluginSocManufacturer; + *LiteRtPluginNumSupportedSocModels; + *LiteRtPluginGetSupportedSocModel; + *LiteRtPluginPartitionModel; + *LiteRtPluginCompile; + + /* Compiled Result */ + + *LiteRtCompiledResult; + *LiteRtCompiledResultDestroy; + *LiteRtCompiledResultGetByteCode; + *LiteRtCompiledResultGetCallInfo; + *LiteRtCompiledResultGetNumCalls; + + local: + + /* Hide everything else */ + + *; +}; \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/build_common/litert_build_defs.bzl b/tensorflow/lite/experimental/litert/build_common/litert_build_defs.bzl new file mode 100644 index 00000000000000..d11484657357a9 --- /dev/null +++ b/tensorflow/lite/experimental/litert/build_common/litert_build_defs.bzl @@ -0,0 +1,194 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common LiteRT Build Utilities.""" + +#################################################################################################### +# Util + +_LRT_SO_PREFIX = "libLiteRt" +_SO_EXT = ".so" +_SHARED_LIB_SUFFIX = "_so" + +# Public + +def make_linkopt(opt): + return "-Wl,{}".format(opt) + +def make_rpaths(rpaths): + return make_linkopt("-rpath={}".format(":".join(rpaths))) + +def append_rule_kwargs(rule_kwargs, **append): + for k, v in append.items(): + append_to = rule_kwargs.pop(k, []) + append_to += v + rule_kwargs[k] = append_to + +# Private + +def _valild_shared_lib_name(name): + return name.endswith(_SHARED_LIB_SUFFIX) + +def _valid_so_name(name): + return name.startswith(_LRT_SO_PREFIX) and name.endswith(_SO_EXT) + +def _make_target_ref(name): + return ":{}".format(name) + +def _make_script_linkopt(script): + return make_linkopt("--version-script=$(location {})".format(script)) + +#################################################################################################### +# Explicitly Link System Libraries ("ungrte") + +_SYS_RPATHS_X86_64 = [ + "/usr/lib/x86_64-linux-gnu", + "/lib/x86_64-linux-gnu", +] +_SYS_RPATHS_LINKOPT_X86_64 = make_rpaths(_SYS_RPATHS_X86_64) + +_SYS_ELF_INTERPRETER_X86_64 = "/lib64/ld-linux-x86-64.so.2" +_SYS_ELF_INTERPRETER_LINKOPT_X86_64 = make_linkopt("--dynamic-linker={}".format(_SYS_ELF_INTERPRETER_X86_64)) + +#################################################################################################### +# Symbol Hiding + +_EXPORT_LRT_ONLY_SCRIPT = "//tensorflow/lite/experimental/litert/build_common:export_litert_only.lds" +_EXPORT_LRT_ONLY_LINKOPT = _make_script_linkopt(_EXPORT_LRT_ONLY_SCRIPT) + +#################################################################################################### +# Macros + +# Private + +def _litert_base( + rule, + ungrte = False, + **cc_rule_kwargs): + """ + Base rule for LiteRT targets. + + Args: + rule: The underlying rule to use (e.g., cc_test, cc_library). + ungrte: Whether to link against system libraries ("ungrte"). + **cc_rule_kwargs: Keyword arguments to pass to the underlying rule. + """ + if ungrte: + append_rule_kwargs( + cc_rule_kwargs, + linkopts = select({ + "//tensorflow:linux_x86_64": [_SYS_ELF_INTERPRETER_LINKOPT_X86_64, _SYS_RPATHS_LINKOPT_X86_64], + "//conditions:default": [], + }), + ) + rule(**cc_rule_kwargs) + +# Public + +def litert_test( + ungrte = False, + use_sys_malloc = False, + **cc_test_kwargs): + """ + LiteRT test rule. + + Args: + ungrte: Whether to link against system libraries ("ungrte"). + use_sys_malloc: Whether to use the system malloc. + **cc_test_kwargs: Keyword arguments to pass to the underlying rule. + """ + if use_sys_malloc: + # copybara:uncomment cc_test_kwargs["malloc"] = "//base:system_malloc" + pass + + append_rule_kwargs( + cc_test_kwargs, + deps = ["@com_google_googletest//:gtest_main"], + ) + + _litert_base( + native.cc_test, + ungrte, + **cc_test_kwargs + ) + +def litert_lib( + ungrte = False, + **cc_lib_kwargs): + """ + LiteRT library rule. + + Args: + ungrte: Whether to link against system libraries ("ungrte"). + **cc_lib_kwargs: Keyword arguments to pass to the underlying rule. + """ + _litert_base( + native.cc_library, + ungrte, + **cc_lib_kwargs + ) + +def litert_dynamic_lib( + name, + shared_lib_name, + so_name, + export_litert_only = False, + ungrte = False, + **cc_lib_kwargs): + """ + LiteRT dynamic library rule. + + Args: + name: The name of the library. + shared_lib_name: The name of the shared library. + so_name: The name of the shared object file. + export_litert_only: Whether to export only LiteRT symbols. + ungrte: Whether to link against system libraries ("ungrte"). + **cc_lib_kwargs: Keyword arguments to pass to the underlying rule. + """ + if not _valild_shared_lib_name(shared_lib_name): + fail("\"shared_lib_name\" must end with \"_so\"") + if not _valid_so_name(so_name): + fail("\"so_name\" must be \"libLiteRt*.so\"") + + lib_name = name + cc_lib_kwargs["name"] = lib_name + + lib_target_ref = _make_target_ref(lib_name) + + vis = cc_lib_kwargs.get("visibility", None) + + # Share tags for all targets. + tags = cc_lib_kwargs.get("tags", []) + + litert_lib( + ungrte = ungrte, + **cc_lib_kwargs + ) + + user_link_flags = [] + additional_linker_inputs = [] + if export_litert_only: + user_link_flags.append(_EXPORT_LRT_ONLY_LINKOPT) + additional_linker_inputs.append(_EXPORT_LRT_ONLY_SCRIPT) + + native.cc_shared_library( + name = shared_lib_name, + shared_lib_name = so_name, + user_link_flags = user_link_flags, + additional_linker_inputs = additional_linker_inputs, + tags = tags, + visibility = vis, + deps = [lib_target_ref], + ) diff --git a/tensorflow/lite/experimental/litert/c/BUILD b/tensorflow/lite/experimental/litert/c/BUILD new file mode 100644 index 00000000000000..69080949b6334d --- /dev/null +++ b/tensorflow/lite/experimental/litert/c/BUILD @@ -0,0 +1,133 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], +) + +cc_library( + name = "litert_logging", + srcs = [ + "litert_logging.cc", + ], + hdrs = [ + "litert_logging.h", + ], + deps = [ + ":litert_c_api", + "//tensorflow/lite:minimal_logging", + ], +) + +cc_test( + name = "litert_logging_test", + srcs = [ + "litert_logging_test.cc", + ], + deps = [ + ":litert_c_api", + ":litert_logging", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "litert_c_api", + hdrs = [ + "litert_common.h", + "litert_logging.h", + "litert_model.h", + "litert_op_code.h", + "litert_options.h", + "litert_support.h", + ], + deps = [ + "//tensorflow/lite:builtin_ops", + "//tensorflow/lite/core/c:c_api_types", + ], +) + +cc_library( + name = "litert_tensor_buffer", + srcs = [ + "litert_event.cc", + "litert_tensor_buffer.cc", + "litert_tensor_buffer_requirements.cc", + ], + hdrs = [ + "litert_event.h", + "litert_tensor_buffer.h", + "litert_tensor_buffer_requirements.h", + ], + deps = [ + ":litert_c_api", + ":litert_logging", + "//tensorflow/lite/experimental/litert/core:tensor_buffer", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "litert_tensor_buffer_test", + srcs = [ + "litert_tensor_buffer_test.cc", + ], + linkopts = select({ + "//tensorflow:android": ["-landroid"], + "//conditions:default": [], + }), + deps = [ + ":litert_c_api", + ":litert_tensor_buffer", + "//tensorflow/lite/experimental/litert/core:tensor_buffer", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "litert_tensor_buffer_requirements_test", + srcs = [ + "litert_tensor_buffer_requirements_test.cc", + ], + linkopts = select({ + "//tensorflow:android": ["-landroid"], + "//conditions:default": [], + }), + deps = [ + ":litert_c_api", + ":litert_tensor_buffer", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "litert_dispatch_delegate", + hdrs = [ + "litert_dispatch_delegate.h", + ], + deps = [ + ":litert_c_api", + "//tensorflow/lite/c:c_api", + "//tensorflow/lite/c:c_api_opaque", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/c:common", + "//tensorflow/lite/delegates/utils:simple_opaque_delegate", + "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", + ], +) + +exports_files(srcs = glob(["litert_*.h"])) diff --git a/tensorflow/lite/experimental/litert/c/litert_common.h b/tensorflow/lite/experimental/litert/c/litert_common.h new file mode 100644 index 00000000000000..38f30f748245ae --- /dev/null +++ b/tensorflow/lite/experimental/litert/c/litert_common.h @@ -0,0 +1,109 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMMON_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMMON_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Declares canonical opaque type. +#define LITERT_DEFINE_HANDLE(name) typedef struct name##T* name +// Declares an array of references to opaque type. `name` must be +// previously declared opaque type. +#define LITERT_DEFINE_HANDLE_ARRAY(name) typedef name* name##Array + +#if __ANDROID_API__ >= 26 +#define LITERT_HAS_AHWB_SUPPORT 1 +#else +#define LITERT_HAS_AHWB_SUPPORT 0 +#endif // __ANDROID_API__ >= 26 + +#if defined(__linux__) || defined(__ANDROID__) +#define LITERT_HAS_SYNC_FENCE_SUPPORT 1 +#else +#define LITERT_HAS_SYNC_FENCE_SUPPORT 0 +#endif + +#if defined(__ANDROID__) +#define LITERT_HAS_ION_SUPPORT 1 +#define LITERT_HAS_DMABUF_SUPPORT 1 +#define LITERT_HAS_FASTRPC_SUPPORT 1 +#else +#define LITERT_HAS_ION_SUPPORT 0 +#define LITERT_HAS_DMABUF_SUPPORT 0 +#define LITERT_HAS_FASTRPC_SUPPORT 0 +#endif + +typedef enum { + kLiteRtStatusOk = 0, + + // Generic errors. + kLiteRtStatusErrorInvalidArgument = 1, + kLiteRtStatusErrorMemoryAllocationFailure = 2, + kLiteRtStatusErrorRuntimeFailure = 3, + kLiteRtStatusErrorMissingInputTensor = 4, + kLiteRtStatusErrorUnsupported = 5, + kLiteRtStatusErrorNotFound = 6, + kLiteRtStatusErrorTimeoutExpired = 7, + + // File and loading related errors. + kLiteRtStatusErrorFileIO = 500, + kLiteRtStatusErrorInvalidFlatbuffer = 501, + kLiteRtStatusErrorDynamicLoading = 502, + kLiteRtStatusErrorSerialization = 503, + kLiteRtStatusErrorCompilationr = 504, + + // IR related errors. + kLiteRtStatusErrorIndexOOB = 1000, + kLiteRtStatusErrorInvalidIrType = 1001, + kLiteRtStatusErrorInvalidGraphInvariant = 1002, + kLiteRtStatusErrorGraphModification = 1003, + + // Tool related errors. + kLiteRtStatusErrorInvalidToolConfig = 1500, + + // Lealization related errors. + kLiteRtStatusLegalizeNoMatch = 2000, + kLiteRtStatusErrorInvalidLegalization = 2001, +} LiteRtStatus; + +typedef enum { + kLiteRtAnyTypeNone = 0, + kLiteRtAnyTypeBool = 1, + kLiteRtAnyTypeInt = 2, + kLiteRtAnyTypeReal = 3, + kLiteRtAnyTypeString = 8, + kLiteRtAnyTypeVoidPtr = 9, +} LiteRtAnyType; + +typedef struct { + LiteRtAnyType type; + union { + bool bool_value; + int64_t int_value; + double real_value; + const char* str_value; + const void* ptr_value; + }; +} LiteRtAny; + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_COMMON_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h b/tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h new file mode 100644 index 00000000000000..3f7828884e6e71 --- /dev/null +++ b/tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h @@ -0,0 +1,80 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_DISPATCH_DELEGATE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_DISPATCH_DELEGATE_H_ + +#include + +#include "tensorflow/lite/c/c_api_opaque.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" + +#ifdef __cplusplus +#include + +#include "tensorflow/lite/delegates/utils/simple_opaque_delegate.h" +#endif + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef struct LiteRtDispatchDelegateOptions LiteRtDispatchDelegateOptions; + +// Returns DispatchDelegateOptions populated with default values. +LiteRtDispatchDelegateOptions* LiteRtCreateDefaultDispatchDelegateOptions(); + +TfLiteStatus LiteRtAddDispatchDelegateOption( + LiteRtDispatchDelegateOptions* options, LiteRtDispatchOption option); + +// Add NPU executable information keyed by a provided tag. +TfLiteStatus LiteRtAddDispatchDelegateExecInfoOption( + LiteRtDispatchDelegateOptions* options, const char* exec_tag, + const void* bytecode_addr, size_t bytecode_size, const char* function_name); + +void LiteRtDestroyDispatchDelegateOptions( + LiteRtDispatchDelegateOptions* options); + +// Create a delegate that uses the Dispatch API for execution. Takes ownership +// of the passed `options`. Must outlive the TFL interpreter. +TfLiteOpaqueDelegate* LiteRtCreateDispatchDelegate( + LiteRtDispatchDelegateOptions* options); + +// Do any needed cleanup and delete 'delegate'. +void LiteRtDestroyDispatchDelegate(TfLiteOpaqueDelegate* delegate); + +#ifdef __cplusplus +} +#endif // __cplusplus + +#ifdef __cplusplus +namespace litert { + +using DispatchDelegateOptionsPtr = + std::unique_ptr; + +using DispatchDelegatePtr = tflite::TfLiteOpaqueDelegateUniquePtr; + +DispatchDelegateOptionsPtr CreateDispatchDelegateOptionsPtr(); + +DispatchDelegatePtr CreateDispatchDelegatePtr( + DispatchDelegateOptionsPtr&& options); + +} // namespace litert +#endif + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_DISPATCH_DELEGATE_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_event.cc b/tensorflow/lite/experimental/litert/c/litert_event.cc new file mode 100644 index 00000000000000..a1215f9a5a6e56 --- /dev/null +++ b/tensorflow/lite/experimental/litert/c/litert_event.cc @@ -0,0 +1,43 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/c/litert_event.h" + +#include +#include +#include + +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/core/event.h" + +#if LITERT_HAS_SYNC_FENCE_SUPPORT +LiteRtStatus LiteRtEventCreateFromSyncFenceFd(int sync_fence_fd, bool owns_fd, + LiteRtEvent* event) { + *event = new LiteRtEventT{.fd = sync_fence_fd, .owns_fd = owns_fd}; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtEventGetSyncFenceFd(LiteRtEvent event, int* sync_fence_fd) { + *sync_fence_fd = event->fd; + return kLiteRtStatusOk; +} +#endif + +LiteRtStatus LiteRtEventWait(LiteRtEvent event, int64_t timeout_in_ms) { + return event->Wait(timeout_in_ms); +} + +void LiteRtEventDestroy(LiteRtEvent event) { delete event; } diff --git a/tensorflow/lite/experimental/litert/c/litert_event.h b/tensorflow/lite/experimental/litert/c/litert_event.h new file mode 100644 index 00000000000000..60e6b265b254ba --- /dev/null +++ b/tensorflow/lite/experimental/litert/c/litert_event.h @@ -0,0 +1,44 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_EVENT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_EVENT_H_ + +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +LITERT_DEFINE_HANDLE(LiteRtEvent); + +#if LITERT_HAS_SYNC_FENCE_SUPPORT +LiteRtStatus LiteRtEventCreateFromSyncFenceFd(int sync_fence_fd, bool owns_fd, + LiteRtEvent* event); + +LiteRtStatus LiteRtEventGetSyncFenceFd(LiteRtEvent event, int* sync_fence_fd); +#endif // LITERT_HAS_SYNC_FENCE_SUPPORT + +// Pass -1 for timeout_in_ms for indefinite wait. +LiteRtStatus LiteRtEventWait(LiteRtEvent event, int64_t timeout_in_ms); + +void LiteRtEventDestroy(LiteRtEvent event); + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_EVENT_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_logging.cc b/tensorflow/lite/experimental/litert/c/litert_logging.cc new file mode 100644 index 00000000000000..f6cbf4ac7ec6a1 --- /dev/null +++ b/tensorflow/lite/experimental/litert/c/litert_logging.cc @@ -0,0 +1,107 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" + +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/logger.h" +#include "tensorflow/lite/minimal_logging.h" + +class LiteRtLoggerT { + public: + LiteRtLogSeverity GetMinSeverity() { + return ConvertSeverity( + tflite::logging_internal::MinimalLogger::GetMinimumLogSeverity()); + } + + void SetMinSeverity(LiteRtLogSeverity severity) { + tflite::logging_internal::MinimalLogger::SetMinimumLogSeverity( + ConvertSeverity(severity)); + } + + void Log(LiteRtLogSeverity severity, const char* format, va_list args) { + tflite::logging_internal::MinimalLogger::LogFormatted( + ConvertSeverity(severity), format, args); + } + + private: + static tflite::LogSeverity ConvertSeverity(LiteRtLogSeverity severity) { + return static_cast(severity); + } + + static LiteRtLogSeverity ConvertSeverity(tflite::LogSeverity severity) { + return static_cast(severity); + } +}; + +LiteRtStatus LiteRtCreateLogger(LiteRtLogger* logger) { + if (!logger) { + return kLiteRtStatusErrorInvalidArgument; + } + *logger = new LiteRtLoggerT; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetMinLoggerSeverity(LiteRtLogger logger, + LiteRtLogSeverity* min_severity) { + if (!logger || !min_severity) { + return kLiteRtStatusErrorInvalidArgument; + } + *min_severity = logger->GetMinSeverity(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtSetMinLoggerSeverity(LiteRtLogger logger, + LiteRtLogSeverity min_severity) { + if (!logger) { + return kLiteRtStatusErrorInvalidArgument; + } + logger->SetMinSeverity(min_severity); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtLoggerLog(LiteRtLogger logger, LiteRtLogSeverity severity, + const char* format, ...) { + if (!logger || !format) { + return kLiteRtStatusErrorInvalidArgument; + } + va_list args; + va_start(args, format); + logger->Log(severity, format, args); + va_end(args); + return kLiteRtStatusOk; +} + +void LiteRtDestroyLogger(LiteRtLogger logger) { + if (logger != nullptr) { + delete logger; + } +} + +namespace { +LiteRtLoggerT StaticLogger; +LiteRtLogger DefaultLogger = &StaticLogger; +} // namespace + +LiteRtStatus LiteRtSetDefaultLogger(LiteRtLogger logger) { + if (!logger) { + return kLiteRtStatusErrorInvalidArgument; + } + DefaultLogger = logger; + return kLiteRtStatusOk; +} + +LiteRtLogger LiteRtGetDefaultLogger() { return DefaultLogger; } diff --git a/tensorflow/lite/experimental/litert/c/litert_logging.h b/tensorflow/lite/experimental/litert/c/litert_logging.h new file mode 100644 index 00000000000000..41962c64dd82f3 --- /dev/null +++ b/tensorflow/lite/experimental/litert/c/litert_logging.h @@ -0,0 +1,87 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_LOGGING_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_LOGGING_H_ + +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +LITERT_DEFINE_HANDLE(LiteRtLogger); + +// WARNING: The values of the following enum are to be kept in sync with +// tflite::LogSeverity. +typedef enum { + kLiteRtLogSeverityVerbose = 0, + kLiteRtLogSeverityInfo = 1, + kLiteRtLogSeverityWarning = 2, + kLiteRtLogSeverityError = 3, + kLiteRtLogSeveritySilent = 4, +} LiteRtLogSeverity; + +#define LITERT_VERBOSE kLiteRtLogSeverityVerbose +#define LITERT_INFO kLiteRtLogSeverityInfo +#define LITERT_WARNING kLiteRtLogSeverityWarning +#define LITERT_ERROR kLiteRtLogSeverityError +#define LITERT_SILENT kLiteRtLogSeveritySilent + +LiteRtStatus LiteRtCreateLogger(LiteRtLogger* logger); +LiteRtStatus LiteRtGetMinLoggerSeverity(LiteRtLogger logger, + LiteRtLogSeverity* min_severity); +LiteRtStatus LiteRtSetMinLoggerSeverity(LiteRtLogger logger, + LiteRtLogSeverity min_severity); +LiteRtStatus LiteRtLoggerLog(LiteRtLogger logger, LiteRtLogSeverity severity, + const char* format, ...); +void LiteRtDestroyLogger(LiteRtLogger logger); + +LiteRtLogger LiteRtGetDefaultLogger(); +LiteRtStatus LiteRtSetDefaultLogger(LiteRtLogger logger); +LiteRtStatus LiteRtDefaultLoggerLog(LiteRtLogSeverity severity, + const char* format, ...); + +#ifdef __cplusplus +} +#endif // __cplusplus + +#define LITERT_LOGGER_LOG_PROD(logger, severity, format, ...) \ + { \ + LiteRtLogSeverity __min_severity__; \ + if (LiteRtGetMinLoggerSeverity(logger, &__min_severity__) != \ + kLiteRtStatusOk) { \ + __min_severity__ = kLiteRtLogSeverityVerbose; \ + } \ + if (severity >= __min_severity__) { \ + LiteRtLoggerLog(logger, severity, "[%s:%d] " format, __FILE__, __LINE__, \ + ##__VA_ARGS__); \ + } \ + } + +#ifndef NDEBUG +#define LITERT_LOGGER_LOG LITERT_LOGGER_LOG_PROD +#else +#define LITERT_LOGGER_LOG(logger, severity, format, ...) \ + do { \ + LITERT_LOGGER_LOG_PROD(logger, severity, format, ##__VA_ARGS__); \ + } while (false) +#endif + +#define LITERT_LOG(severity, format, ...) \ + LITERT_LOGGER_LOG(LiteRtGetDefaultLogger(), severity, format, ##__VA_ARGS__); + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_LOGGING_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_logging_test.cc b/tensorflow/lite/experimental/litert/c/litert_logging_test.cc new file mode 100644 index 00000000000000..148fc778f18915 --- /dev/null +++ b/tensorflow/lite/experimental/litert/c/litert_logging_test.cc @@ -0,0 +1,34 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" + +#include // NOLINT: Need when ANDROID_API_LEVEL >= 26 +#include "tensorflow/lite/experimental/litert/c/litert_common.h" + +TEST(Layout, Creation) { + LiteRtLogger logger; + ASSERT_EQ(LiteRtCreateLogger(&logger), kLiteRtStatusOk); + LiteRtDestroyLogger(logger); +} + +TEST(Layout, MinLogging) { + LiteRtLogger logger; + ASSERT_EQ(LiteRtCreateLogger(&logger), kLiteRtStatusOk); + ASSERT_EQ(LiteRtSetMinLoggerSeverity(logger, LITERT_SILENT), kLiteRtStatusOk); + LiteRtLogSeverity min_severity; + ASSERT_EQ(LiteRtGetMinLoggerSeverity(logger, &min_severity), kLiteRtStatusOk); + ASSERT_EQ(min_severity, LITERT_SILENT); + LiteRtDestroyLogger(logger); +} diff --git a/tensorflow/lite/experimental/litert/c/litert_model.h b/tensorflow/lite/experimental/litert/c/litert_model.h new file mode 100644 index 00000000000000..3259d1cda0ff3b --- /dev/null +++ b/tensorflow/lite/experimental/litert/c/litert_model.h @@ -0,0 +1,197 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_MODEL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_MODEL_H_ + +#include +#include + +#include "tensorflow/lite/core/c/c_api_types.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +LITERT_DEFINE_HANDLE(LiteRtWeights); + +LITERT_DEFINE_HANDLE(LiteRtTensor); +LITERT_DEFINE_HANDLE_ARRAY(LiteRtTensor); + +LITERT_DEFINE_HANDLE(LiteRtOp); +LITERT_DEFINE_HANDLE_ARRAY(LiteRtOp); + +LITERT_DEFINE_HANDLE(LiteRtSubgraph); +LITERT_DEFINE_HANDLE_ARRAY(LiteRtSubgraph); + +LITERT_DEFINE_HANDLE(LiteRtModel); + +// Append only list of ops. +LITERT_DEFINE_HANDLE(LiteRtOpList); + +// For indexing into litert collections or counting litert things. +typedef uint64_t LiteRtParamIndex; + +// +// Tensors +// + +typedef enum { + kLiteRtElementTypeNone = kTfLiteNoType, + kLiteRtElementTypeBool = kTfLiteBool, + kLiteRtElementTypeInt4 = kTfLiteInt4, + kLiteRtElementTypeInt8 = kTfLiteInt8, + kLiteRtElementTypeInt16 = kTfLiteInt16, + kLiteRtElementTypeInt32 = kTfLiteInt32, + kLiteRtElementTypeInt64 = kTfLiteInt64, + kLiteRtElementTypeUInt8 = kTfLiteUInt8, + kLiteRtElementTypeUInt16 = kTfLiteUInt16, + kLiteRtElementTypeUInt32 = kTfLiteUInt32, + kLiteRtElementTypeUInt64 = kTfLiteUInt64, + kLiteRtElementTypeFloat16 = kTfLiteFloat16, + kLiteRtElementTypeBFloat16 = kTfLiteBFloat16, + kLiteRtElementTypeFloat32 = kTfLiteFloat32, + kLiteRtElementTypeFloat64 = kTfLiteFloat64, + kLiteRtElementTypeComplex64 = kTfLiteComplex64, + kLiteRtElementTypeComplex128 = kTfLiteComplex128, + kLiteRtElementTypeTfResource = kTfLiteResource, + kLiteRtElementTypeTfString = kTfLiteString, + kLiteRtElementTypeTfVariant = kTfLiteVariant, +} LiteRtElementType; + +typedef struct { + uint32_t rank; + // TODO: b/365299994 - Decide on canonical type(s) for indices({s}32/64). Also + // representation of dynamic dim. + const int32_t* dimensions; + // Strides for a nomimal NWHC layout. NULL if unused. + const uint32_t* strides; +} LiteRtLayout; + +// Tensor whose rank is dynamic. +typedef struct { + LiteRtElementType element_type; +} LiteRtUnrankedTensorType; + +// Tensor whose rank is static but dimenions may be dynamic. +typedef struct { + LiteRtElementType element_type; + LiteRtLayout layout; +} LiteRtRankedTensorType; + +typedef enum { + kLiteRtRankedTensorType = 0, + kLiteRtUnrankedTensorType = 1, + // TODO: b/365299994 - q types. +} LiteRtTensorTypeId; + +// Get type identifier from tensor. +LiteRtStatus GetTensorTypeId(LiteRtTensor tensor, LiteRtTensorTypeId* type_id); + +// Get unranked tensor type info, return bad status if not unranked. +LiteRtStatus GetUrankedTensorType( + LiteRtTensor tensor, LiteRtUnrankedTensorType* unranked_tensor_type); + +// Get ranked tensor type info, return bad status if not ranked. +LiteRtStatus GetRankedTensorType(LiteRtTensor tensor, + LiteRtRankedTensorType* ranked_tensor_type); + +// Get opaque array from given tensor weights. +LiteRtStatus GetWeightsInfo(LiteRtWeights weights, size_t* size, + const void** addr); + +// Get static weights associated with a given tensor. All tensors have weights, +// null weights have size = 0; +LiteRtStatus GetTensorWeights(LiteRtTensor tensor, LiteRtWeights* weights); + +// Get all the ops that reference given tensor, and at what operand index. +LiteRtStatus GetTensorUses(LiteRtTensor tensor, LiteRtParamIndex* num_uses, + LiteRtOpArray* users, + LiteRtParamIndex** user_arg_inds); + +// Get the op that defines this tensor and the corresponding output index. If +// tensor is a subgraph input, defining op will be null. +LiteRtStatus GetTensorDefiningOp( + LiteRtTensor tensor, LiteRtOp* maybe_defining_op, + LiteRtParamIndex* maybe_defining_op_output_ind); + +// +// Op +// + +// Get output tensors of given op. +LiteRtStatus GetOpOutputs(LiteRtOp op, LiteRtParamIndex* num_outputs, + LiteRtTensorArray* output); + +// Get input tensors of given op. +LiteRtStatus GetOpInputs(LiteRtOp op, LiteRtParamIndex* num_inputs, + LiteRtTensorArray* inputs); + +// Get code corresponding to operation type for given op. +LiteRtStatus GetOpCode(LiteRtOp op, LiteRtOpCode* code); + +// +// Subgraph +// + +// Get input tensors for given subgraph. +LiteRtStatus GetSubgraphInputs(LiteRtSubgraph subgraph, + LiteRtParamIndex* num_inputs, + LiteRtTensorArray* inputs); + +// Get output tensors for given subgraph. +LiteRtStatus GetSubgraphOutputs(LiteRtSubgraph subgraph, + LiteRtParamIndex* num_outputs, + LiteRtTensorArray* outputs); + +// Get all ops in given subgraph in a topological order. +LiteRtStatus GetSubgraphOps(LiteRtSubgraph subgraph, LiteRtParamIndex* num_ops, + LiteRtOpArray* ops); + +// +// Model +// + +// Get the metadata buffer associated with given key if it exists. +LiteRtStatus LiteRtModelGetMetadata(LiteRtModel model, const char* metadata_key, + const void** metadata_buffer, + size_t* metadata_buffer_size); + +// Get number of subgraphs in model. +LiteRtStatus GetModelNumSubgraphs(LiteRtModel model, + LiteRtParamIndex* num_subgraphs); + +// Get subgraph at given index in model. +LiteRtStatus GetModelSubgraph(LiteRtModel model, + LiteRtParamIndex subgraph_index, + LiteRtSubgraph* subgraph); + +// Get the index of the entry subgraph. +// TODO: b/365299994 - Figure out signatures. +LiteRtStatus GetModelMainSubgraph(LiteRtModel model, + LiteRtParamIndex* main_subgraph_index); + +// +// Utility Types +// + +LiteRtStatus PushOp(LiteRtOpList op_list, LiteRtOp op); + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_MODEL_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_op_code.h b/tensorflow/lite/experimental/litert/c/litert_op_code.h new file mode 100644 index 00000000000000..529360e87dc415 --- /dev/null +++ b/tensorflow/lite/experimental/litert/c/litert_op_code.h @@ -0,0 +1,245 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OP_CODE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OP_CODE_H_ + +#include "tensorflow/lite/builtin_ops.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef enum { + kLiteRtOpCodeTflAdd = kTfLiteBuiltinAdd, + kLiteRtOpCodeTflAveragePool2d = kTfLiteBuiltinAveragePool2d, + kLiteRtOpCodeTflConcatenation = kTfLiteBuiltinConcatenation, + kLiteRtOpCodeTflConv2d = kTfLiteBuiltinConv2d, + kLiteRtOpCodeTflDepthwiseConv2d = kTfLiteBuiltinDepthwiseConv2d, + kLiteRtOpCodeTflDepthToSpace = kTfLiteBuiltinDepthToSpace, + kLiteRtOpCodeTflDequantize = kTfLiteBuiltinDequantize, + kLiteRtOpCodeTflEmbeddingLookup = kTfLiteBuiltinEmbeddingLookup, + kLiteRtOpCodeTflFloor = kTfLiteBuiltinFloor, + kLiteRtOpCodeTflFullyConnected = kTfLiteBuiltinFullyConnected, + kLiteRtOpCodeTflHashtableLookup = kTfLiteBuiltinHashtableLookup, + kLiteRtOpCodeTflL2Normalization = kTfLiteBuiltinL2Normalization, + kLiteRtOpCodeTflL2Pool2d = kTfLiteBuiltinL2Pool2d, + kLiteRtOpCodeTflLocalResponseNormalization = + kTfLiteBuiltinLocalResponseNormalization, + kLiteRtOpCodeTflLogistic = kTfLiteBuiltinLogistic, + kLiteRtOpCodeTflLshProjection = kTfLiteBuiltinLshProjection, + kLiteRtOpCodeTflLstm = kTfLiteBuiltinLstm, + kLiteRtOpCodeTflMaxPool2d = kTfLiteBuiltinMaxPool2d, + kLiteRtOpCodeTflMul = kTfLiteBuiltinMul, + kLiteRtOpCodeTflRelu = kTfLiteBuiltinRelu, + kLiteRtOpCodeTflReluN1To1 = kTfLiteBuiltinReluN1To1, + kLiteRtOpCodeTflRelu6 = kTfLiteBuiltinRelu6, + kLiteRtOpCodeTflReshape = kTfLiteBuiltinReshape, + kLiteRtOpCodeTflResizeBilinear = kTfLiteBuiltinResizeBilinear, + kLiteRtOpCodeTflRnn = kTfLiteBuiltinRnn, + kLiteRtOpCodeTflSoftmax = kTfLiteBuiltinSoftmax, + kLiteRtOpCodeTflSpaceToDepth = kTfLiteBuiltinSpaceToDepth, + kLiteRtOpCodeTflSvdf = kTfLiteBuiltinSvdf, + kLiteRtOpCodeTflTanh = kTfLiteBuiltinTanh, + kLiteRtOpCodeTflConcatEmbeddings = kTfLiteBuiltinConcatEmbeddings, + kLiteRtOpCodeTflSkipGram = kTfLiteBuiltinSkipGram, + kLiteRtOpCodeTflCall = kTfLiteBuiltinCall, + kLiteRtOpCodeTflCustom = kTfLiteBuiltinCustom, + kLiteRtOpCodeTflEmbeddingLookupSparse = kTfLiteBuiltinEmbeddingLookupSparse, + kLiteRtOpCodeTflPad = kTfLiteBuiltinPad, + kLiteRtOpCodeTflUnidirectionalSequenceRnn = + kTfLiteBuiltinUnidirectionalSequenceRnn, + kLiteRtOpCodeTflGather = kTfLiteBuiltinGather, + kLiteRtOpCodeTflBatchToSpaceNd = kTfLiteBuiltinBatchToSpaceNd, + kLiteRtOpCodeTflSpaceToBatchNd = kTfLiteBuiltinSpaceToBatchNd, + kLiteRtOpCodeTflTranspose = kTfLiteBuiltinTranspose, + kLiteRtOpCodeTflMean = kTfLiteBuiltinMean, + kLiteRtOpCodeTflSub = kTfLiteBuiltinSub, + kLiteRtOpCodeTflDiv = kTfLiteBuiltinDiv, + kLiteRtOpCodeTflSqueeze = kTfLiteBuiltinSqueeze, + kLiteRtOpCodeTflUnidirectionalSequenceLstm = + kTfLiteBuiltinUnidirectionalSequenceLstm, + kLiteRtOpCodeTflStridedSlice = kTfLiteBuiltinStridedSlice, + kLiteRtOpCodeTflBidirectionalSequenceRnn = + kTfLiteBuiltinBidirectionalSequenceRnn, + kLiteRtOpCodeTflExp = kTfLiteBuiltinExp, + kLiteRtOpCodeTflTopkV2 = kTfLiteBuiltinTopkV2, + kLiteRtOpCodeTflSplit = kTfLiteBuiltinSplit, + kLiteRtOpCodeTflLogSoftmax = kTfLiteBuiltinLogSoftmax, + kLiteRtOpCodeTflDelegate = kTfLiteBuiltinDelegate, + kLiteRtOpCodeTflBidirectionalSequenceLstm = + kTfLiteBuiltinBidirectionalSequenceLstm, + kLiteRtOpCodeTflCast = kTfLiteBuiltinCast, + kLiteRtOpCodeTflPrelu = kTfLiteBuiltinPrelu, + kLiteRtOpCodeTflMaximum = kTfLiteBuiltinMaximum, + kLiteRtOpCodeTflArgMax = kTfLiteBuiltinArgMax, + kLiteRtOpCodeTflMinimum = kTfLiteBuiltinMinimum, + kLiteRtOpCodeTflLess = kTfLiteBuiltinLess, + kLiteRtOpCodeTflNeg = kTfLiteBuiltinNeg, + kLiteRtOpCodeTflPadv2 = kTfLiteBuiltinPadv2, + kLiteRtOpCodeTflGreater = kTfLiteBuiltinGreater, + kLiteRtOpCodeTflGreaterEqual = kTfLiteBuiltinGreaterEqual, + kLiteRtOpCodeTflLessEqual = kTfLiteBuiltinLessEqual, + kLiteRtOpCodeTflSelect = kTfLiteBuiltinSelect, + kLiteRtOpCodeTflSlice = kTfLiteBuiltinSlice, + kLiteRtOpCodeTflSin = kTfLiteBuiltinSin, + kLiteRtOpCodeTflTransposeConv = kTfLiteBuiltinTransposeConv, + kLiteRtOpCodeTflSparseToDense = kTfLiteBuiltinSparseToDense, + kLiteRtOpCodeTflTile = kTfLiteBuiltinTile, + kLiteRtOpCodeTflExpandDims = kTfLiteBuiltinExpandDims, + kLiteRtOpCodeTflEqual = kTfLiteBuiltinEqual, + kLiteRtOpCodeTflNotEqual = kTfLiteBuiltinNotEqual, + kLiteRtOpCodeTflLog = kTfLiteBuiltinLog, + kLiteRtOpCodeTflSum = kTfLiteBuiltinSum, + kLiteRtOpCodeTflSqrt = kTfLiteBuiltinSqrt, + kLiteRtOpCodeTflRsqrt = kTfLiteBuiltinRsqrt, + kLiteRtOpCodeTflShape = kTfLiteBuiltinShape, + kLiteRtOpCodeTflPow = kTfLiteBuiltinPow, + kLiteRtOpCodeTflArgMin = kTfLiteBuiltinArgMin, + kLiteRtOpCodeTflFakeQuant = kTfLiteBuiltinFakeQuant, + kLiteRtOpCodeTflReduceProd = kTfLiteBuiltinReduceProd, + kLiteRtOpCodeTflReduceMax = kTfLiteBuiltinReduceMax, + kLiteRtOpCodeTflPack = kTfLiteBuiltinPack, + kLiteRtOpCodeTflLogicalOr = kTfLiteBuiltinLogicalOr, + kLiteRtOpCodeTflOneHot = kTfLiteBuiltinOneHot, + kLiteRtOpCodeTflLogicalAnd = kTfLiteBuiltinLogicalAnd, + kLiteRtOpCodeTflLogicalNot = kTfLiteBuiltinLogicalNot, + kLiteRtOpCodeTflUnpack = kTfLiteBuiltinUnpack, + kLiteRtOpCodeTflReduceMin = kTfLiteBuiltinReduceMin, + kLiteRtOpCodeTflFloorDiv = kTfLiteBuiltinFloorDiv, + kLiteRtOpCodeTflReduceAny = kTfLiteBuiltinReduceAny, + kLiteRtOpCodeTflSquare = kTfLiteBuiltinSquare, + kLiteRtOpCodeTflZerosLike = kTfLiteBuiltinZerosLike, + kLiteRtOpCodeTflFill = kTfLiteBuiltinFill, + kLiteRtOpCodeTflFloorMod = kTfLiteBuiltinFloorMod, + kLiteRtOpCodeTflRange = kTfLiteBuiltinRange, + kLiteRtOpCodeTflResizeNearestNeighbor = kTfLiteBuiltinResizeNearestNeighbor, + kLiteRtOpCodeTflLeakyRelu = kTfLiteBuiltinLeakyRelu, + kLiteRtOpCodeTflSquaredDifference = kTfLiteBuiltinSquaredDifference, + kLiteRtOpCodeTflMirrorPad = kTfLiteBuiltinMirrorPad, + kLiteRtOpCodeTflAbs = kTfLiteBuiltinAbs, + kLiteRtOpCodeTflSplitV = kTfLiteBuiltinSplitV, + kLiteRtOpCodeTflUnique = kTfLiteBuiltinUnique, + kLiteRtOpCodeTflCeil = kTfLiteBuiltinCeil, + kLiteRtOpCodeTflReverseV2 = kTfLiteBuiltinReverseV2, + kLiteRtOpCodeTflAddN = kTfLiteBuiltinAddN, + kLiteRtOpCodeTflGatherNd = kTfLiteBuiltinGatherNd, + kLiteRtOpCodeTflCos = kTfLiteBuiltinCos, + kLiteRtOpCodeTflWhere = kTfLiteBuiltinWhere, + kLiteRtOpCodeTflRank = kTfLiteBuiltinRank, + kLiteRtOpCodeTflElu = kTfLiteBuiltinElu, + kLiteRtOpCodeTflReverseSequence = kTfLiteBuiltinReverseSequence, + kLiteRtOpCodeTflMatrixDiag = kTfLiteBuiltinMatrixDiag, + kLiteRtOpCodeTflQuantize = kTfLiteBuiltinQuantize, + kLiteRtOpCodeTflMatrixSetDiag = kTfLiteBuiltinMatrixSetDiag, + kLiteRtOpCodeTflRound = kTfLiteBuiltinRound, + kLiteRtOpCodeTflHardSwish = kTfLiteBuiltinHardSwish, + kLiteRtOpCodeTflIf = kTfLiteBuiltinIf, + kLiteRtOpCodeTflWhile = kTfLiteBuiltinWhile, + kLiteRtOpCodeTflNonMaxSuppressionV4 = kTfLiteBuiltinNonMaxSuppressionV4, + kLiteRtOpCodeTflNonMaxSuppressionV5 = kTfLiteBuiltinNonMaxSuppressionV5, + kLiteRtOpCodeTflScatterNd = kTfLiteBuiltinScatterNd, + kLiteRtOpCodeTflSelectV2 = kTfLiteBuiltinSelectV2, + kLiteRtOpCodeTflDensify = kTfLiteBuiltinDensify, + kLiteRtOpCodeTflSegmentSum = kTfLiteBuiltinSegmentSum, + kLiteRtOpCodeTflBatchMatmul = kTfLiteBuiltinBatchMatmul, + kLiteRtOpCodeTflPlaceholderForGreaterOpCodeTfls = + kTfLiteBuiltinPlaceholderForGreaterOpCodes, + kLiteRtOpCodeTflCumsum = kTfLiteBuiltinCumsum, + kLiteRtOpCodeTflCallOnce = kTfLiteBuiltinCallOnce, + kLiteRtOpCodeTflBroadcastTo = kTfLiteBuiltinBroadcastTo, + kLiteRtOpCodeTflRfft2d = kTfLiteBuiltinRfft2d, + kLiteRtOpCodeTflConv3d = kTfLiteBuiltinConv3d, + kLiteRtOpCodeTflImag = kTfLiteBuiltinImag, + kLiteRtOpCodeTflReal = kTfLiteBuiltinReal, + kLiteRtOpCodeTflComplexAbs = kTfLiteBuiltinComplexAbs, + kLiteRtOpCodeTflHashtable = kTfLiteBuiltinHashtable, + kLiteRtOpCodeTflHashtableFind = kTfLiteBuiltinHashtableFind, + kLiteRtOpCodeTflHashtableImport = kTfLiteBuiltinHashtableImport, + kLiteRtOpCodeTflHashtableSize = kTfLiteBuiltinHashtableSize, + kLiteRtOpCodeTflReduceAll = kTfLiteBuiltinReduceAll, + kLiteRtOpCodeTflConv3dTranspose = kTfLiteBuiltinConv3dTranspose, + kLiteRtOpCodeTflVarHandle = kTfLiteBuiltinVarHandle, + kLiteRtOpCodeTflReadVariable = kTfLiteBuiltinReadVariable, + kLiteRtOpCodeTflAssignVariable = kTfLiteBuiltinAssignVariable, + kLiteRtOpCodeTflBroadcastArgs = kTfLiteBuiltinBroadcastArgs, + kLiteRtOpCodeTflRandomStandardNormal = kTfLiteBuiltinRandomStandardNormal, + kLiteRtOpCodeTflBucketize = kTfLiteBuiltinBucketize, + kLiteRtOpCodeTflRandomUniform = kTfLiteBuiltinRandomUniform, + kLiteRtOpCodeTflMultinomial = kTfLiteBuiltinMultinomial, + kLiteRtOpCodeTflGelu = kTfLiteBuiltinGelu, + kLiteRtOpCodeTflDynamicUpdateSlice = kTfLiteBuiltinDynamicUpdateSlice, + kLiteRtOpCodeTflRelu0To1 = kTfLiteBuiltinRelu0To1, + kLiteRtOpCodeTflUnsortedSegmentProd = kTfLiteBuiltinUnsortedSegmentProd, + kLiteRtOpCodeTflUnsortedSegmentMax = kTfLiteBuiltinUnsortedSegmentMax, + kLiteRtOpCodeTflUnsortedSegmentSum = kTfLiteBuiltinUnsortedSegmentSum, + kLiteRtOpCodeTflAtan2 = kTfLiteBuiltinAtan2, + kLiteRtOpCodeTflUnsortedSegmentMin = kTfLiteBuiltinUnsortedSegmentMin, + kLiteRtOpCodeTflSign = kTfLiteBuiltinSign, + kLiteRtOpCodeTflBitcast = kTfLiteBuiltinBitcast, + kLiteRtOpCodeTflBitwiseXor = kTfLiteBuiltinBitwiseXor, + kLiteRtOpCodeTflRightShift = kTfLiteBuiltinRightShift, + kLiteRtOpCodeShloLogistic = kTfLiteBuiltinStablehloLogistic, + kLiteRtOpCodeShloAdd = kTfLiteBuiltinStablehloAdd, + kLiteRtOpCodeShloDivide = kTfLiteBuiltinStablehloDivide, + kLiteRtOpCodeShloMultiply = kTfLiteBuiltinStablehloMultiply, + kLiteRtOpCodeShloMaximum = kTfLiteBuiltinStablehloMaximum, + kLiteRtOpCodeShloReshape = kTfLiteBuiltinStablehloReshape, + kLiteRtOpCodeShloClamp = kTfLiteBuiltinStablehloClamp, + kLiteRtOpCodeShloConcatenate = kTfLiteBuiltinStablehloConcatenate, + kLiteRtOpCodeShloBroadcastInDim = kTfLiteBuiltinStablehloBroadcastInDim, + kLiteRtOpCodeShloConvolution = kTfLiteBuiltinStablehloConvolution, + kLiteRtOpCodeShloSlice = kTfLiteBuiltinStablehloSlice, + kLiteRtOpCodeShloCustomCall = kTfLiteBuiltinStablehloCustomCall, + kLiteRtOpCodeShloReduce = kTfLiteBuiltinStablehloReduce, + kLiteRtOpCodeShloAbs = kTfLiteBuiltinStablehloAbs, + kLiteRtOpCodeShloAnd = kTfLiteBuiltinStablehloAnd, + kLiteRtOpCodeShloCosine = kTfLiteBuiltinStablehloCosine, + kLiteRtOpCodeShloExponential = kTfLiteBuiltinStablehloExponential, + kLiteRtOpCodeShloFloor = kTfLiteBuiltinStablehloFloor, + kLiteRtOpCodeShloLog = kTfLiteBuiltinStablehloLog, + kLiteRtOpCodeShloMinimum = kTfLiteBuiltinStablehloMinimum, + kLiteRtOpCodeShloNegate = kTfLiteBuiltinStablehloNegate, + kLiteRtOpCodeShloOr = kTfLiteBuiltinStablehloOr, + kLiteRtOpCodeShloPower = kTfLiteBuiltinStablehloPower, + kLiteRtOpCodeShloRemainder = kTfLiteBuiltinStablehloRemainder, + kLiteRtOpCodeShloRsqrt = kTfLiteBuiltinStablehloRsqrt, + kLiteRtOpCodeShloSelect = kTfLiteBuiltinStablehloSelect, + kLiteRtOpCodeShloSubtract = kTfLiteBuiltinStablehloSubtract, + kLiteRtOpCodeShloTanh = kTfLiteBuiltinStablehloTanh, + kLiteRtOpCodeShloScatter = kTfLiteBuiltinStablehloScatter, + kLiteRtOpCodeShloCompare = kTfLiteBuiltinStablehloCompare, + kLiteRtOpCodeShloConvert = kTfLiteBuiltinStablehloConvert, + kLiteRtOpCodeShloDynamicSlice = kTfLiteBuiltinStablehloDynamicSlice, + kLiteRtOpCodeShloDynamicUpdateSlice = + kTfLiteBuiltinStablehloDynamicUpdateSlice, + kLiteRtOpCodeShloPad = kTfLiteBuiltinStablehloPad, + kLiteRtOpCodeShloIota = kTfLiteBuiltinStablehloIota, + kLiteRtOpCodeShloGeneral = kTfLiteBuiltinStablehloDotGeneral, + kLiteRtOpCodeShloWindow = kTfLiteBuiltinStablehloReduceWindow, + kLiteRtOpCodeShloSort = kTfLiteBuiltinStablehloSort, + kLiteRtOpCodeShloWhile = kTfLiteBuiltinStablehloWhile, + kLiteRtOpCodeShloGather = kTfLiteBuiltinStablehloGather, + kLiteRtOpCodeShloTranspose = kTfLiteBuiltinStablehloTranspose, + kLiteRtOpCodeTflDilate = kTfLiteBuiltinDilate, + kLiteRtOpCodeShloRngBitGenerator = kTfLiteBuiltinStablehloRngBitGenerator, + kLiteRtOpCodeTflReduceWindow = kTfLiteBuiltinReduceWindow, + kLiteRtOpCodeShloComposite = kTfLiteBuiltinStablehloComposite, +} LiteRtOpCode; + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OP_CODE_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_options.h b/tensorflow/lite/experimental/litert/c/litert_options.h new file mode 100644 index 00000000000000..ae03e35afefaef --- /dev/null +++ b/tensorflow/lite/experimental/litert/c/litert_options.h @@ -0,0 +1,163 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OPTIONS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OPTIONS_H_ + +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +LITERT_DEFINE_HANDLE(LiteRtOp); + +//============================================================================== +// +// Get option APIs for LiteRt ADD op. +// Options: +// - FusedActivationOption : uint32_t +// +//============================================================================== +LiteRtStatus LiteRtAddGetFusedActivationOption(LiteRtOp op, + uint32_t* fused_activation); + +//============================================================================== +// +// Get option APIs for LiteRt BatchMatmul op. +// Options: +// - AdjXOption : bool +// - AdjYOption : bool +// - AsymmtericQuantizeInputOption : bool +// +//============================================================================== +LiteRtStatus LiteRtBatchMatmulGetAdjXOption(LiteRtOp op, bool* adj_x); +LiteRtStatus LiteRtBatchMatmulGetAdjYOption(LiteRtOp op, bool* adj_y); +LiteRtStatus LiteRtBatchMatmulGetAsymmetricQuantizeInputOption( + LiteRtOp op, bool* asymmetric_quantize_input); + +//============================================================================== +// +// Get option APIs for LiteRt Concatenation op. +// Options: +// - FusedActivationOption : uint32_t +// - AxisOption : int32_t +// +//============================================================================== +LiteRtStatus LiteRtConcatenationGetFusedActivationOption( + LiteRtOp op, uint32_t* fused_activation); +LiteRtStatus LiteRtConcatenationGetAxisOption(LiteRtOp op, int32_t* axis); + +//============================================================================== +// +// Get option APIs for LiteRt Div op. +// Options: +// - FusedActivationOption : uint32_t +// +//============================================================================== +LiteRtStatus LiteRtDivGetFusedActivationOption(LiteRtOp op, + uint32_t* fused_activation); + +//============================================================================== +// +// Get option APIs for LiteRt FullyConnected op. +// Options: +// - FusedActivationOption : uint32_t +// - WeightsFormatOption : uint32_t +// - KeepNumDimsOption : bool +// - QuantizedBiasTypeOption : uint32_t +// - AsymmtericQuantizeInputOption : bool +// +//============================================================================== +LiteRtStatus LiteRtFullyConnectedGetFusedActivationOption( + LiteRtOp op, uint32_t* fused_activation); +LiteRtStatus LiteRtFullyConnectedGetWeightsFormatOption( + LiteRtOp op, uint32_t* weights_format); +LiteRtStatus LiteRtFullyConnectedGetKeepNumDimsOption(LiteRtOp op, + bool* keep_num_dims); +LiteRtStatus LiteRtFullyConnectedGetQuantizedBiasTypeOption( + LiteRtOp op, uint32_t* quantized_bias_type); +LiteRtStatus LiteRtFullyConnectedGetAsymmetricQuantizeInputOption( + LiteRtOp op, bool* asymmetric_quantize_input); + +//============================================================================== +// +// Get option APIs for LiteRt Mul op. +// Options: +// - FusedActivationOption : uint32_t +// +//============================================================================== +LiteRtStatus LiteRtMulGetFusedActivationOption(LiteRtOp op, + uint32_t* fused_activation); + +//============================================================================== +// +// Get option APIs for LiteRt Softmax op. +// Options: +// - BetaOption : float +// +//============================================================================== +LiteRtStatus LiteRtSoftmaxGetBetaOption(LiteRtOp op, float* beta); + +//============================================================================== +// +// Get option APIs for LiteRt StridedSlice op. +// Options: +// - BeginMaskOption : int32_t +// - EndMaskOption : int32_t +// - EllipsisMaskOption : int32_t +// - NewAxisMaskOption : int32_t +// - ShrinkAxisMaskOption : int32_t +// - OffsetOption : bool + +//============================================================================== +LiteRtStatus LiteRtStridedSliceGetBeginMaskOption(LiteRtOp op, + int32_t* begin_mask); +LiteRtStatus LiteRtStridedSliceGetEndMaskOption(LiteRtOp op, int32_t* end_mask); +LiteRtStatus LiteRtStridedSliceGetEllipsisMaskOption(LiteRtOp op, + int32_t* ellipsis_mask); +LiteRtStatus LiteRtStridedSliceGetNewAxisMaskOption(LiteRtOp op, + int32_t* new_axis_mask); +LiteRtStatus LiteRtStridedSliceGetShrinkAxisMaskOption( + LiteRtOp op, int32_t* shrink_axis_mask); +LiteRtStatus LiteRtStridedSliceGetOffsetOption(LiteRtOp op, bool* offset); + +//============================================================================== +// +// Get option APIs for LiteRt Sub op. +// Options: +// - FusedActivationOption : uint32_t +// - (Not supported) PotScaleInt16Option : bool +// +//============================================================================== +LiteRtStatus LiteRtSubGetFusedActivationOption(LiteRtOp op, + uint32_t* fused_activation); + +//============================================================================== +// +// Get option APIs for LiteRt Reshape op. +// Options: +// - new_shape : int32_t[] +// +//============================================================================== +LiteRtStatus LiteRtReshapeGetNewShapeOption(LiteRtOp op, int32_t** new_shape, + int32_t* new_shape_size); + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_OPTIONS_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_support.h b/tensorflow/lite/experimental/litert/c/litert_support.h new file mode 100644 index 00000000000000..16489f635b057c --- /dev/null +++ b/tensorflow/lite/experimental/litert/c/litert_support.h @@ -0,0 +1,63 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_SUPPORT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_SUPPORT_H_ + +#include +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" // IWYU pragma: keep + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// #define LITERT_ABORT abort() +// TODO: b/365295276 - Find a fatal error approach that will pass kokoro. +#define LITERT_ABORT + +#define LITERT_FATAL(msg) \ + { \ + fprintf(stderr, "%s\n", (msg)); \ + LITERT_ABORT; \ + } + +#define LITERT_RETURN_STATUS_IF_NOT_OK(expr) \ + if (LiteRtStatus status = expr; status != kLiteRtStatusOk) return status; + +#define LITERT_RETURN_STATUS_IF_NOT_OK_OR_NOT_MATCHED(expr) \ + if (LiteRtStatus status = expr; \ + (status != kLiteRtStatusOk && status != kLiteRtStatusLegalizeNoMatch)) \ + return status; + +// TODO: b/365295276 - Add optional debug only print messages support +// to all macros. +#define LITERT_RETURN_STATUS_IF_NOT_OK_MSG(expr, d_msg) \ + LITERT_RETURN_STATUS_IF_NOT_OK(expr) + +#define LITERT_RETURN_VAL_IF_NOT_OK(expr, ret_val) \ + if (LiteRtStatus status = expr; status != kLiteRtStatusOk) return ret_val; + +#define LITERT_STACK_ARRAY(ty, var, size, init) \ + ty* var = (ty*)alloca(sizeof(ty) * size); \ + for (ty* e = var; e < var + size; ++e) { \ + *e = init; \ + } + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_SUPPORT_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.cc b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.cc new file mode 100644 index 00000000000000..393e51bb8915c4 --- /dev/null +++ b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.cc @@ -0,0 +1,299 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_event.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/core/tensor_buffer.h" + +LiteRtStatus LiteRtCreateTensorBufferFromHostMemory( + const LiteRtRankedTensorType* tensor_type, void* host_buffer_addr, + size_t size, LiteRtHostMemoryDeallocator deallocator, + LiteRtTensorBuffer* tensor_buffer) { + if (!tensor_type || !host_buffer_addr || !tensor_buffer) { + return kLiteRtStatusErrorInvalidArgument; + } + auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromHostMemory( + *tensor_type, + absl::MakeSpan(static_cast(host_buffer_addr), size), + deallocator); + if (!created_tensor_buffer.ok()) { + LITERT_LOG(LITERT_ERROR, "%s", created_tensor_buffer.status().message()); + return kLiteRtStatusErrorRuntimeFailure; + } + *tensor_buffer = created_tensor_buffer->release(); + return kLiteRtStatusOk; +} + +#if LITERT_HAS_AHWB_SUPPORT +LiteRtStatus LiteRtCreateTensorBufferFromAhwb( + const LiteRtRankedTensorType* tensor_type, AHardwareBuffer* ahwb, + size_t ahwb_offset, LiteRtAhwbDeallocator deallocator, + LiteRtTensorBuffer* tensor_buffer) { + if (!tensor_type || !ahwb || !tensor_buffer) { + return kLiteRtStatusErrorInvalidArgument; + } + auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromAhwb( + *tensor_type, ahwb, ahwb_offset, deallocator); + if (!created_tensor_buffer.ok()) { + LITERT_LOG(LITERT_ERROR, "%s", created_tensor_buffer.status().message()); + return kLiteRtStatusErrorRuntimeFailure; + } + *tensor_buffer = created_tensor_buffer->release(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetTensorBufferAhwb(LiteRtTensorBuffer tensor_buffer, + AHardwareBuffer** ahwb) { + if (!tensor_buffer || !ahwb) { + return kLiteRtStatusErrorInvalidArgument; + } + + auto ahwb_buffer = tensor_buffer->GetAhwbBuffer(); + if (!ahwb_buffer.ok()) { + LITERT_LOG(LITERT_ERROR, "%s", ahwb_buffer.status().message()); + return kLiteRtStatusErrorRuntimeFailure; + } + + *ahwb = *ahwb_buffer; + return kLiteRtStatusOk; +} +#endif // LITERT_HAS_AHWB_SUPPORT + +#if LITERT_HAS_ION_SUPPORT +LiteRtStatus LiteRtCreateTensorBufferFromIonBuffer( + const LiteRtRankedTensorType* tensor_type, void* ion_buffer_addr, + int ion_buffer_fd, size_t ion_buffer_size, size_t ion_buffer_offset, + LiteRtIonDeallocator deallocator, LiteRtTensorBuffer* tensor_buffer) { + if (!tensor_type || !tensor_buffer) { + return kLiteRtStatusErrorInvalidArgument; + } + auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromIonBuffer( + *tensor_type, ion_buffer_addr, ion_buffer_fd, ion_buffer_size, + ion_buffer_offset, deallocator); + if (!created_tensor_buffer.ok()) { + LITERT_LOG(LITERT_ERROR, "%s", created_tensor_buffer.status().message()); + return kLiteRtStatusErrorRuntimeFailure; + } + *tensor_buffer = created_tensor_buffer->release(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetTensorBufferIonBuffer(LiteRtTensorBuffer tensor_buffer, + void** ion_buffer_addr, + int* ion_buffer_fd) { + if (!tensor_buffer || !ion_buffer_addr || !ion_buffer_fd) { + return kLiteRtStatusErrorInvalidArgument; + } + + auto ion_buffer = tensor_buffer->GetIonBuffer(); + if (!ion_buffer.ok()) { + LITERT_LOG(LITERT_ERROR, "%s", ion_buffer.status().message()); + return kLiteRtStatusErrorRuntimeFailure; + } + + *ion_buffer_addr = ion_buffer->first; + *ion_buffer_fd = ion_buffer->second; + return kLiteRtStatusOk; +} +#endif // LITERT_HAS_ION_SUPPORT + +#if LITERT_HAS_DMABUF_SUPPORT +LiteRtStatus LiteRtCreateTensorBufferFromDmaBufBuffer( + const LiteRtRankedTensorType* tensor_type, void* dmabuf_buffer_addr, + int dmabuf_buffer_fd, size_t dmabuf_buffer_size, + size_t dmabuf_buffer_offset, LiteRtDmaBufDeallocator deallocator, + LiteRtTensorBuffer* tensor_buffer) { + if (!tensor_type || !tensor_buffer) { + return kLiteRtStatusErrorInvalidArgument; + } + auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromDmaBufBuffer( + *tensor_type, dmabuf_buffer_addr, dmabuf_buffer_fd, dmabuf_buffer_size, + dmabuf_buffer_offset, deallocator); + if (!created_tensor_buffer.ok()) { + LITERT_LOG(LITERT_ERROR, "%s", created_tensor_buffer.status().message()); + return kLiteRtStatusErrorRuntimeFailure; + } + *tensor_buffer = created_tensor_buffer->release(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetTensorBufferDmaBufBuffer(LiteRtTensorBuffer tensor_buffer, + void** dmabuf_buffer_addr, + int* dmabuf_buffer_fd) { + if (!tensor_buffer || !dmabuf_buffer_addr || !dmabuf_buffer_fd) { + return kLiteRtStatusErrorInvalidArgument; + } + + auto dmabuf_buffer = tensor_buffer->GetDmaBufBuffer(); + if (!dmabuf_buffer.ok()) { + LITERT_LOG(LITERT_ERROR, "%s", dmabuf_buffer.status().message()); + return kLiteRtStatusErrorRuntimeFailure; + } + + *dmabuf_buffer_addr = dmabuf_buffer->first; + *dmabuf_buffer_fd = dmabuf_buffer->second; + return kLiteRtStatusOk; +} +#endif // LITERT_HAS_DMABUF_SUPPORT + +#if LITERT_HAS_FASTRPC_SUPPORT +LiteRtStatus LiteRtCreateTensorBufferFromFastRpcBuffer( + const LiteRtRankedTensorType* tensor_type, void* fastrpc_buffer_addr, + int fastrpc_buffer_fd, size_t fastrpc_buffer_size, + size_t fastrpc_buffer_offset, LiteRtFastRpcDeallocator deallocator, + LiteRtTensorBuffer* tensor_buffer) { + if (!tensor_type || !tensor_buffer) { + return kLiteRtStatusErrorInvalidArgument; + } + auto created_tensor_buffer = LiteRtTensorBufferT::CreateFromFastRpcBuffer( + *tensor_type, fastrpc_buffer_addr, fastrpc_buffer_fd, fastrpc_buffer_size, + fastrpc_buffer_offset, deallocator); + if (!created_tensor_buffer.ok()) { + LITERT_LOG(LITERT_ERROR, "%s", created_tensor_buffer.status().message()); + return kLiteRtStatusErrorRuntimeFailure; + } + *tensor_buffer = created_tensor_buffer->release(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetTensorBufferFastRpcBuffer( + LiteRtTensorBuffer tensor_buffer, void** fastrpc_buffer_addr, + int* fastrpc_buffer_fd) { + if (!tensor_buffer || !fastrpc_buffer_addr || !fastrpc_buffer_fd) { + return kLiteRtStatusErrorInvalidArgument; + } + + auto fastrpc_buffer = tensor_buffer->GetFastRpcBuffer(); + if (!fastrpc_buffer.ok()) { + LITERT_LOG(LITERT_ERROR, "%s", fastrpc_buffer.status().message()); + return kLiteRtStatusErrorRuntimeFailure; + } + + *fastrpc_buffer_addr = fastrpc_buffer->first; + *fastrpc_buffer_fd = fastrpc_buffer->second; + return kLiteRtStatusOk; +} +#endif // LITERT_HAS_FASTRPC_SUPPORT + +LiteRtStatus LiteRtCreateManagedTensorBuffer( + LiteRtTensorBufferType buffer_type, + const LiteRtRankedTensorType* tensor_type, size_t buffer_size, + LiteRtTensorBuffer* tensor_buffer) { + if (!tensor_type || !tensor_buffer) { + return kLiteRtStatusErrorInvalidArgument; + } + auto created_tensor_buffer = LiteRtTensorBufferT::CreateManaged( + buffer_type, *tensor_type, buffer_size); + if (!created_tensor_buffer.ok()) { + LITERT_LOG(LITERT_ERROR, "%s", created_tensor_buffer.status().message()); + return kLiteRtStatusErrorRuntimeFailure; + } + *tensor_buffer = created_tensor_buffer->release(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetTensorBufferType(LiteRtTensorBuffer tensor_buffer, + LiteRtTensorBufferType* buffer_type) { + if (!tensor_buffer || !buffer_type) { + return kLiteRtStatusErrorInvalidArgument; + } + *buffer_type = tensor_buffer->buffer_type(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetTensorBufferTensorType( + LiteRtTensorBuffer tensor_buffer, LiteRtRankedTensorType* tensor_type) { + if (!tensor_buffer || !tensor_type) { + return kLiteRtStatusErrorInvalidArgument; + } + *tensor_type = tensor_buffer->tensor_type(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetTensorBufferSize(LiteRtTensorBuffer tensor_buffer, + size_t* buffer_size) { + if (!tensor_buffer || !buffer_size) { + return kLiteRtStatusErrorInvalidArgument; + } + *buffer_size = tensor_buffer->buffer_size(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetTensorBufferOffset(LiteRtTensorBuffer tensor_buffer, + size_t* buffer_offset) { + if (!tensor_buffer || !buffer_offset) { + return kLiteRtStatusErrorInvalidArgument; + } + *buffer_offset = tensor_buffer->buffer_offset(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetTensorBufferHostMemory(LiteRtTensorBuffer tensor_buffer, + void** host_memory_addr) { + if (!tensor_buffer || !host_memory_addr) { + return kLiteRtStatusErrorInvalidArgument; + } + + auto host_buffer = tensor_buffer->GetHostBuffer(); + if (!host_buffer.ok()) { + LITERT_LOG(LITERT_ERROR, "%s", host_buffer.status().message()); + return kLiteRtStatusErrorRuntimeFailure; + } + + *host_memory_addr = *host_buffer; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtLockTensorBuffer(LiteRtTensorBuffer tensor_buffer, + void** host_mem_addr, LiteRtEvent event) { + if (!tensor_buffer || !host_mem_addr) { + return kLiteRtStatusErrorInvalidArgument; + } + + auto mapped_addr = tensor_buffer->Lock(event); + if (!mapped_addr.ok()) { + LITERT_LOG(LITERT_ERROR, "%s", mapped_addr.status().message()); + return kLiteRtStatusErrorRuntimeFailure; + } + + *host_mem_addr = *mapped_addr; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtUnlockTensorBuffer(LiteRtTensorBuffer tensor_buffer) { + if (!tensor_buffer) { + return kLiteRtStatusErrorInvalidArgument; + } + + if (auto status = tensor_buffer->Unlock(); !status.ok()) { + LITERT_LOG(LITERT_ERROR, "%s", status.message()); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +void LiteRtDestroyTensorBuffer(LiteRtTensorBuffer tensor_buffer) { + delete tensor_buffer; +} diff --git a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h new file mode 100644 index 00000000000000..bbd33e532bfab9 --- /dev/null +++ b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h @@ -0,0 +1,182 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_H_ + +#include +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_event.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" + +#if LITERT_HAS_AHWB_SUPPORT +#include +#else +// Define a place holder AHardwareBuffer struct just to enable compilation. +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus +typedef struct AHardwareBuffer AHardwareBuffer; +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // LITERT_HAS_AHWB_SUPPORT + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +LITERT_DEFINE_HANDLE(LiteRtTensorBuffer); + +#define LITERT_HOST_MEMORY_BUFFER_ALIGNMENT 64 + +typedef enum { + kLiteRtTensorBufferTypeUnknown = 0, + kLiteRtTensorBufferTypeHostMemory = 1, + kLiteRtTensorBufferTypeAhwb = 2, + kLiteRtTensorBufferTypeIon = 3, + kLiteRtTensorBufferTypeDmaBuf = 4, + kLiteRtTensorBufferTypeFastRpc = 5, +} LiteRtTensorBufferType; + +typedef void (*LiteRtHostMemoryDeallocator)(void* addr); +typedef void (*LiteRtAhwbDeallocator)(AHardwareBuffer* ahwb); +typedef void (*LiteRtIonDeallocator)(void* ion_buffer_addr); +typedef void (*LiteRtDmaBufDeallocator)(void* dmabuf_buffer_addr); +typedef void (*LiteRtFastRpcDeallocator)(void* fastrpc_buffer_addr); + +// ///////////////////////////////////////////////////////////////////////////// +// TensorBuffers. +// ///////////////////////////////////////////////////////////////////////////// + +// Create a tensor buffer from an existing host memory buffer of a given size, +// with optional host memory buffer deallocator (it can be NULL). Return an +// error if the passed host memory buffer doesn't satisfy +// LITERT_HOST_MEMORY_BUFFER_ALIGNMENT alignment. +LiteRtStatus LiteRtCreateTensorBufferFromHostMemory( + const LiteRtRankedTensorType* tensor_type, void* host_buffer_addr, + size_t host_buffer_size, LiteRtHostMemoryDeallocator deallocator, + LiteRtTensorBuffer* buffer); + +// Return an error if the backing buffer is not allocated on the host memory. +LiteRtStatus LiteRtGetTensorBufferHostMemory(LiteRtTensorBuffer tensor_buffer, + void** host_memory_addr); + +#if LITERT_HAS_AHWB_SUPPORT +// Create a tensor buffer from an existing AHardwareBuffer, with optional +// AHardwareBuffer deallocator (it can be NULL). An non-zero `buffer_offset` can +// be used to specify multiple tensor buffers sharing the same underlying AHWB, +// in which case the provided AHWB must be sufficiently large to accomodate for +// the allocation needed for all tensor buffers sharing it. +LiteRtStatus LiteRtCreateTensorBufferFromAhwb( + const LiteRtRankedTensorType* tensor_type, AHardwareBuffer* ahwb, + size_t ahwb_offset, LiteRtAhwbDeallocator deallocator, + LiteRtTensorBuffer* buffer); + +// Return an error if the backing buffer is not an AhardwareBuffer. +LiteRtStatus LiteRtGetTensorBufferAhwb(LiteRtTensorBuffer tensor_buffer, + AHardwareBuffer** ahwb); +#endif // LITERT_HAS_AHWB_SUPPORT + +#if LITERT_HAS_ION_SUPPORT +// Create a tensor buffer from an existing ION buffer of a given size, with +// optional ION buffer deallocator (it can be NULL). An non-zero +// `ion_buffer_offset` can be used to specify multiple tensor buffers sharing +// the same underlying ION buffer, in which case parameter `ion_buffer_size` +// must be the entire size of the underlying ION memory buffer, including the +// allocation needed for all tensor buffers sharing it. +LiteRtStatus LiteRtCreateTensorBufferFromIonBuffer( + const LiteRtRankedTensorType* tensor_type, void* ion_buffer_addr, + int ion_buffer_fd, size_t ion_buffer_size, size_t ion_buffer_offset, + LiteRtIonDeallocator deallocator, LiteRtTensorBuffer* buffer); + +// Return an error if the backing buffer is not an ION buffer. +LiteRtStatus LiteRtGetTensorBufferIonBuffer(LiteRtTensorBuffer buffer, + void** ion_buffer_addr, + int* ion_buffer_fd); +#endif // LITERT_HAS_ION_SUPPORT + +#if LITERT_HAS_DMABUF_SUPPORT +// Create a tensor buffer from an existing DMA-BUF buffer of a given size, with +// optional DMA-BUF buffer deallocator (it can be NULL). An non-zero +// `dmabuf_buffer_offset` can be used to specify multiple tensor buffers sharing +// the same underlying ION buffer, in which case parameter `ion_buffer_size` +// must be the entire size of the underlying ION memory buffer, including the +// allocation needed for all tensor buffers sharing it. +LiteRtStatus LiteRtCreateTensorBufferFromDmaBufBuffer( + const LiteRtRankedTensorType* tensor_type, void* dmabuf_buffer_addr, + int dmabuf_buffer_fd, size_t dmabuf_buffer_size, + size_t dmabuf_buffer_offset, LiteRtDmaBufDeallocator deallocator, + LiteRtTensorBuffer* buffer); + +// Return an error if the backing buffer is not an DMA-BUF buffer. +LiteRtStatus LiteRtGetTensorBufferDmaBufBuffer(LiteRtTensorBuffer tensor_buffer, + void** dmabuf_buffer_addr, + int* dmabuf_buffer_fd); +#endif // LITERT_HAS_DMABUF_SUPPORT + +#if LITERT_HAS_FASTRPC_SUPPORT +// Create a tensor buffer from an existing FastRPC memory buffer of a given +// size, with optional FastRPC memory buffer deallocator (it can be NULL). An +// non-zero `fastrpc_buffer_offset` can be used to specify multiple tensor +// buffers sharing the same underlying FastRPC memory buffer, in which case +// parameter `fastrpc_buffer_size` must be the entire size of the underlying +// FastRPC memory buffer, including the allocation needed for all tensor buffers +// sharing it. +LiteRtStatus LiteRtCreateTensorBufferFromFastRpcBuffer( + const LiteRtRankedTensorType* tensor_type, void* fastrpc_buffer_addr, + int fastrpc_fd, size_t fastrpc_buffer_size, size_t fastrpc_buffer_offset, + LiteRtFastRpcDeallocator deallocator, LiteRtTensorBuffer* buffer); + +// Return an error if the backing buffer is not a FastRPC memory buffer. +LiteRtStatus LiteRtGetTensorBufferFastRpcBuffer( + LiteRtTensorBuffer tensor_buffer, void** fastrpc_buffer_addr, + int* fastrpc_buffer_fd); +#endif // LITERT_HAS_FASTRPC_SUPPORT + +// Create a buffer backed by managed memory for a given size. +LiteRtStatus LiteRtCreateManagedTensorBuffer( + LiteRtTensorBufferType buffer_type, + const LiteRtRankedTensorType* tensor_type, size_t buffer_size, + LiteRtTensorBuffer* buffer); + +LiteRtStatus LiteRtGetTensorBufferType(LiteRtTensorBuffer tensor_buffer, + LiteRtTensorBufferType* buffer_type); + +LiteRtStatus LiteRtGetTensorBufferTensorType( + LiteRtTensorBuffer tensor_buffer, LiteRtRankedTensorType* tensor_type); + +LiteRtStatus LiteRtGetTensorBufferSize(LiteRtTensorBuffer tensor_buffer, + size_t* size); + +LiteRtStatus LiteRtGetTensorBufferOffset(LiteRtTensorBuffer tensor_buffer, + size_t* offset); + +// Lock a tensor buffer and map it to host memory, optionally syncronizing on a +// given input event (parameter `event` can be NULL). +LiteRtStatus LiteRtLockTensorBuffer(LiteRtTensorBuffer tensor_buffer, + void** host_mem_addr, LiteRtEvent event); + +// Unlock a tensor buffer and (potentially) unmap it from host memory. +LiteRtStatus LiteRtUnlockTensorBuffer(LiteRtTensorBuffer buffer); + +void LiteRtDestroyTensorBuffer(LiteRtTensorBuffer buffer); + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.cc b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.cc new file mode 100644 index 00000000000000..d3efb164d44e06 --- /dev/null +++ b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.cc @@ -0,0 +1,89 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" + +#include +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" + +class LiteRtTensorBufferRequirementsT { + public: + LiteRtTensorBufferRequirementsT( + int num_supported_tensor_buffer_types, + const LiteRtTensorBufferType* supported_tensor_buffer_types, + size_t buffer_size) + : supported_buffer_types_( + supported_tensor_buffer_types, + supported_tensor_buffer_types + num_supported_tensor_buffer_types), + buffer_size_(buffer_size) {} + std::vector supported_buffer_types() const { + return supported_buffer_types_; + } + size_t buffer_size() const { return buffer_size_; } + + private: + std::vector supported_buffer_types_; + size_t buffer_size_; +}; + +LiteRtStatus LiteRtCreateTensorBufferRequirements( + int num_supported_tensor_buffer_types, + const LiteRtTensorBufferType* supported_tensor_buffer_types, + size_t buffer_size, LiteRtTensorBufferRequirements* requirements) { + if (num_supported_tensor_buffer_types < 1 || !supported_tensor_buffer_types || + !requirements) { + return kLiteRtStatusErrorInvalidArgument; + } + *requirements = new LiteRtTensorBufferRequirementsT( + num_supported_tensor_buffer_types, supported_tensor_buffer_types, + buffer_size); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetTensorBufferRequirementsNumSupportedTensorBufferTypes( + LiteRtTensorBufferRequirements requirements, int* num_types) { + if (!requirements || !num_types) { + return kLiteRtStatusErrorInvalidArgument; + } + *num_types = requirements->supported_buffer_types().size(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + LiteRtTensorBufferRequirements requirements, int type_index, + LiteRtTensorBufferType* type) { + if (!requirements || type_index < 0 || + type_index >= requirements->supported_buffer_types().size()) { + return kLiteRtStatusErrorInvalidArgument; + } + *type = requirements->supported_buffer_types()[type_index]; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtGetTensorBufferRequirementsBufferSize( + LiteRtTensorBufferRequirements requirements, size_t* buffer_size) { + if (!requirements || !buffer_size) { + return kLiteRtStatusErrorInvalidArgument; + } + *buffer_size = requirements->buffer_size(); + return kLiteRtStatusOk; +} + +void LiteRtDestroyTensorBufferRequirements( + LiteRtTensorBufferRequirements requirements) { + delete requirements; +} diff --git a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h new file mode 100644 index 00000000000000..0c69ddc461d884 --- /dev/null +++ b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h @@ -0,0 +1,51 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_REQUIREMENTS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_REQUIREMENTS_H_ + +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +LITERT_DEFINE_HANDLE(LiteRtTensorBufferRequirements); + +LiteRtStatus LiteRtCreateTensorBufferRequirements( + int num_supported_tensor_buffer_types, + const LiteRtTensorBufferType* supported_tensor_buffer_types, + size_t buffer_size, LiteRtTensorBufferRequirements* requirements); + +LiteRtStatus LiteRtGetTensorBufferRequirementsNumSupportedTensorBufferTypes( + LiteRtTensorBufferRequirements requirements, int* num_types); + +LiteRtStatus LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + LiteRtTensorBufferRequirements requirements, int type_index, + LiteRtTensorBufferType* type); + +LiteRtStatus LiteRtGetTensorBufferRequirementsBufferSize( + LiteRtTensorBufferRequirements requirements, size_t* buffer_size); + +void LiteRtDestroyTensorBufferRequirements( + LiteRtTensorBufferRequirements requirements); + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_C_LITERT_TENSOR_BUFFER_REQUIREMENTS_H_ diff --git a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements_test.cc b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements_test.cc new file mode 100644 index 00000000000000..72483802758905 --- /dev/null +++ b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements_test.cc @@ -0,0 +1,67 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" + +#include + +#include // NOLINT: Need when ANDROID_API_LEVEL >= 26 +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" + +namespace { + +constexpr const LiteRtTensorBufferType kSupportedTensorBufferTypes[] = { + kLiteRtTensorBufferTypeHostMemory, + kLiteRtTensorBufferTypeAhwb, + kLiteRtTensorBufferTypeIon, + kLiteRtTensorBufferTypeFastRpc, +}; + +constexpr const size_t kNumSupportedTensorBufferTypes = + sizeof(kSupportedTensorBufferTypes) / + sizeof(kSupportedTensorBufferTypes[0]); + +constexpr const size_t kBufferSize = 1234; + +} // namespace + +TEST(TensorBufferRequirements, SimpleTest) { + LiteRtTensorBufferRequirements requirements; + ASSERT_EQ(LiteRtCreateTensorBufferRequirements(kNumSupportedTensorBufferTypes, + kSupportedTensorBufferTypes, + kBufferSize, &requirements), + kLiteRtStatusOk); + + int num_types; + ASSERT_EQ(LiteRtGetTensorBufferRequirementsNumSupportedTensorBufferTypes( + requirements, &num_types), + kLiteRtStatusOk); + ASSERT_EQ(num_types, kNumSupportedTensorBufferTypes); + + for (auto i = 0; i < num_types; ++i) { + LiteRtTensorBufferType type; + ASSERT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + requirements, i, &type), + kLiteRtStatusOk); + ASSERT_EQ(type, kSupportedTensorBufferTypes[i]); + } + + size_t size; + ASSERT_EQ(LiteRtGetTensorBufferRequirementsBufferSize(requirements, &size), + kLiteRtStatusOk); + ASSERT_EQ(size, kBufferSize); + + LiteRtDestroyTensorBufferRequirements(requirements); +} diff --git a/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_test.cc b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_test.cc new file mode 100644 index 00000000000000..4dfe76b754dbbd --- /dev/null +++ b/tensorflow/lite/experimental/litert/c/litert_tensor_buffer_test.cc @@ -0,0 +1,299 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" + +#include +#include + +#include // NOLINT: Need when ANDROID_API_LEVEL >= 26 +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/core/ahwb_buffer.h" // IWYU pragma: keep +#include "tensorflow/lite/experimental/litert/core/dmabuf_buffer.h" // IWYU pragma: keep +#include "tensorflow/lite/experimental/litert/core/fastrpc_buffer.h" // IWYU pragma: keep +#include "tensorflow/lite/experimental/litert/core/ion_buffer.h" // IWYU pragma: keep + +namespace { +constexpr const float kTensorData[] = {10, 20, 30, 40}; + +constexpr const int32_t kTensorDimensions[] = {sizeof(kTensorData) / + sizeof(kTensorData[0])}; + +constexpr const LiteRtRankedTensorType kTensorType = { + /*.element_type=*/kLiteRtElementTypeFloat32, + /*.layout=*/{ + /*.rank=*/1, + /*.dimensions=*/kTensorDimensions, + /*.strides=*/nullptr, + }}; + +} // namespace + +TEST(TensorBuffer, HostMemory) { + constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeHostMemory; + + LiteRtTensorBuffer tensor_buffer; + ASSERT_EQ( + LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, + sizeof(kTensorData), &tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBufferType buffer_type; + ASSERT_EQ(LiteRtGetTensorBufferType(tensor_buffer, &buffer_type), + kLiteRtStatusOk); + ASSERT_EQ(buffer_type, kTensorBufferType); + + LiteRtRankedTensorType tensor_type; + ASSERT_EQ(LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type), + kLiteRtStatusOk); + ASSERT_EQ(tensor_type.element_type, kLiteRtElementTypeFloat32); + ASSERT_EQ(tensor_type.layout.rank, 1); + ASSERT_EQ(tensor_type.layout.dimensions[0], kTensorType.layout.dimensions[0]); + ASSERT_EQ(tensor_type.layout.strides, nullptr); + + size_t size; + ASSERT_EQ(LiteRtGetTensorBufferSize(tensor_buffer, &size), kLiteRtStatusOk); + ASSERT_EQ(size, sizeof(kTensorData)); + + size_t offset; + ASSERT_EQ(LiteRtGetTensorBufferOffset(tensor_buffer, &offset), + kLiteRtStatusOk); + ASSERT_EQ(offset, 0); + + void* host_mem_addr; + ASSERT_EQ( + LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr, /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTensorData, sizeof(kTensorData)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); + + ASSERT_EQ( + LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr, /*event=*/nullptr), + kLiteRtStatusOk); + ASSERT_EQ(std::memcmp(host_mem_addr, kTensorData, sizeof(kTensorData)), 0); + ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); + + LiteRtDestroyTensorBuffer(tensor_buffer); +} + +TEST(TensorBuffer, Ahwb) { + if (!litert::internal::AhwbBuffer::IsSupported()) { + GTEST_SKIP() << "AHardwareBuffers are not supported on this platform; " + "skipping the test"; + } + + constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeAhwb; + + LiteRtTensorBuffer tensor_buffer; + ASSERT_EQ( + LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, + sizeof(kTensorData), &tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBufferType buffer_type; + ASSERT_EQ(LiteRtGetTensorBufferType(tensor_buffer, &buffer_type), + kLiteRtStatusOk); + ASSERT_EQ(buffer_type, kTensorBufferType); + + LiteRtRankedTensorType tensor_type; + ASSERT_EQ(LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type), + kLiteRtStatusOk); + ASSERT_EQ(tensor_type.element_type, kLiteRtElementTypeFloat32); + ASSERT_EQ(tensor_type.layout.rank, 1); + ASSERT_EQ(tensor_type.layout.dimensions[0], kTensorType.layout.dimensions[0]); + ASSERT_EQ(tensor_type.layout.strides, nullptr); + + size_t size; + ASSERT_EQ(LiteRtGetTensorBufferSize(tensor_buffer, &size), kLiteRtStatusOk); + ASSERT_EQ(size, sizeof(kTensorData)); + + size_t offset; + ASSERT_EQ(LiteRtGetTensorBufferOffset(tensor_buffer, &offset), + kLiteRtStatusOk); + ASSERT_EQ(offset, 0); + + void* host_mem_addr; + ASSERT_EQ( + LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr, /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTensorData, sizeof(kTensorData)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); + + ASSERT_EQ( + LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr, /*event=*/nullptr), + kLiteRtStatusOk); + ASSERT_EQ(std::memcmp(host_mem_addr, kTensorData, sizeof(kTensorData)), 0); + ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); + + LiteRtDestroyTensorBuffer(tensor_buffer); +} + +TEST(TensorBuffer, Ion) { + if (!litert::internal::IonBuffer::IsSupported()) { + GTEST_SKIP() + << "ION buffers are not supported on this platform; skipping the test"; + } + + constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeIon; + + LiteRtTensorBuffer tensor_buffer; + ASSERT_EQ( + LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, + sizeof(kTensorData), &tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBufferType buffer_type; + ASSERT_EQ(LiteRtGetTensorBufferType(tensor_buffer, &buffer_type), + kLiteRtStatusOk); + ASSERT_EQ(buffer_type, kTensorBufferType); + + LiteRtRankedTensorType tensor_type; + ASSERT_EQ(LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type), + kLiteRtStatusOk); + ASSERT_EQ(tensor_type.element_type, kLiteRtElementTypeFloat32); + ASSERT_EQ(tensor_type.layout.rank, 1); + ASSERT_EQ(tensor_type.layout.dimensions[0], kTensorType.layout.dimensions[0]); + ASSERT_EQ(tensor_type.layout.strides, nullptr); + + size_t size; + ASSERT_EQ(LiteRtGetTensorBufferSize(tensor_buffer, &size), kLiteRtStatusOk); + ASSERT_EQ(size, sizeof(kTensorData)); + + size_t offset; + ASSERT_EQ(LiteRtGetTensorBufferOffset(tensor_buffer, &offset), + kLiteRtStatusOk); + ASSERT_EQ(offset, 0); + + void* host_mem_addr; + ASSERT_EQ( + LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr, /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTensorData, sizeof(kTensorData)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); + + ASSERT_EQ( + LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr, /*event=*/nullptr), + kLiteRtStatusOk); + ASSERT_EQ(std::memcmp(host_mem_addr, kTensorData, sizeof(kTensorData)), 0); + ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); + + LiteRtDestroyTensorBuffer(tensor_buffer); +} + +TEST(TensorBuffer, DmaBuf) { + if (!litert::internal::DmaBufBuffer::IsSupported()) { + GTEST_SKIP() + << "DMA-BUF buffers are not supported on this platform; skipping " + "the test"; + } + + constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeDmaBuf; + + LiteRtTensorBuffer tensor_buffer; + ASSERT_EQ( + LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, + sizeof(kTensorData), &tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBufferType buffer_type; + ASSERT_EQ(LiteRtGetTensorBufferType(tensor_buffer, &buffer_type), + kLiteRtStatusOk); + ASSERT_EQ(buffer_type, kTensorBufferType); + + LiteRtRankedTensorType tensor_type; + ASSERT_EQ(LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type), + kLiteRtStatusOk); + ASSERT_EQ(tensor_type.element_type, kLiteRtElementTypeFloat32); + ASSERT_EQ(tensor_type.layout.rank, 1); + ASSERT_EQ(tensor_type.layout.dimensions[0], kTensorType.layout.dimensions[0]); + ASSERT_EQ(tensor_type.layout.strides, nullptr); + + size_t size; + ASSERT_EQ(LiteRtGetTensorBufferSize(tensor_buffer, &size), kLiteRtStatusOk); + ASSERT_EQ(size, sizeof(kTensorData)); + + size_t offset; + ASSERT_EQ(LiteRtGetTensorBufferOffset(tensor_buffer, &offset), + kLiteRtStatusOk); + ASSERT_EQ(offset, 0); + + void* host_mem_addr; + ASSERT_EQ( + LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr, /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTensorData, sizeof(kTensorData)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); + + ASSERT_EQ( + LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr, /*event=*/nullptr), + kLiteRtStatusOk); + ASSERT_EQ(std::memcmp(host_mem_addr, kTensorData, sizeof(kTensorData)), 0); + ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); + + LiteRtDestroyTensorBuffer(tensor_buffer); +} + +TEST(TensorBuffer, FastRpc) { + if (!litert::internal::FastRpcBuffer::IsSupported()) { + GTEST_SKIP() + << "FastRPC buffers are not supported on this platform; skipping " + "the test"; + } + + constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeFastRpc; + + LiteRtTensorBuffer tensor_buffer; + ASSERT_EQ( + LiteRtCreateManagedTensorBuffer(kTensorBufferType, &kTensorType, + sizeof(kTensorData), &tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBufferType buffer_type; + ASSERT_EQ(LiteRtGetTensorBufferType(tensor_buffer, &buffer_type), + kLiteRtStatusOk); + ASSERT_EQ(buffer_type, kTensorBufferType); + + LiteRtRankedTensorType tensor_type; + ASSERT_EQ(LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type), + kLiteRtStatusOk); + ASSERT_EQ(tensor_type.element_type, kLiteRtElementTypeFloat32); + ASSERT_EQ(tensor_type.layout.rank, 1); + ASSERT_EQ(tensor_type.layout.dimensions[0], kTensorType.layout.dimensions[0]); + ASSERT_EQ(tensor_type.layout.strides, nullptr); + + size_t size; + ASSERT_EQ(LiteRtGetTensorBufferSize(tensor_buffer, &size), kLiteRtStatusOk); + ASSERT_EQ(size, sizeof(kTensorData)); + + size_t offset; + ASSERT_EQ(LiteRtGetTensorBufferOffset(tensor_buffer, &offset), + kLiteRtStatusOk); + ASSERT_EQ(offset, 0); + + void* host_mem_addr; + ASSERT_EQ( + LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr, /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTensorData, sizeof(kTensorData)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); + + ASSERT_EQ( + LiteRtLockTensorBuffer(tensor_buffer, &host_mem_addr, /*event=*/nullptr), + kLiteRtStatusOk); + ASSERT_EQ(std::memcmp(host_mem_addr, kTensorData, sizeof(kTensorData)), 0); + ASSERT_EQ(LiteRtUnlockTensorBuffer(tensor_buffer), kLiteRtStatusOk); + + LiteRtDestroyTensorBuffer(tensor_buffer); +} diff --git a/tensorflow/lite/experimental/litert/cc/BUILD b/tensorflow/lite/experimental/litert/cc/BUILD new file mode 100644 index 00000000000000..daa688caa92f6a --- /dev/null +++ b/tensorflow/lite/experimental/litert/cc/BUILD @@ -0,0 +1,178 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], +) + +cc_library( + name = "litert_cc_api", + hdrs = [ + "litert_any.h", + "litert_model.h", + "litert_support.h", + "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin.h", + ], + deps = [ + "//tensorflow/compiler/mlir/lite/core:model_builder_base", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "any_cc_test", + srcs = [ + "any_cc_test.cc", + ], + deps = [ + ":litert_cc_api", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "layout_cc_test", + srcs = [ + "layout_cc_test.cc", + ], + deps = [ + ":litert_cc_api", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "ranked_tensor_type_cc_test", + srcs = [ + "ranked_tensor_type_cc_test.cc", + ], + deps = [ + ":litert_cc_api", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "litert_tensor_buffer", + hdrs = [ + "litert_handle.h", + "litert_tensor_buffer.h", + "litert_tensor_buffer_requirements.h", + ], + deps = [ + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "tensor_buffer_test", + srcs = [ + "tensor_buffer_test.cc", + ], + linkopts = select({ + "//tensorflow:android": ["-landroid"], + "//conditions:default": [], + }), + deps = [ + ":litert_tensor_buffer", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/core:tensor_buffer", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "tensor_buffer_requirements_test", + srcs = [ + "tensor_buffer_requirements_test.cc", + ], + deps = [ + ":litert_tensor_buffer", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "litert_tensor", + srcs = ["litert_tensor.cc"], + hdrs = ["litert_tensor.h"], + deps = [ + ":litert_cc_api", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/core:graph_tools", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "litert_tensor_test", + srcs = ["litert_tensor_test.cc"], + data = [ + "//tensorflow/lite/experimental/litert/test:tflite_test_data", + ], + deps = [ + ":litert_tensor", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/core:graph_tools", + "//tensorflow/lite/experimental/litert/core:model", + "//tensorflow/lite/experimental/litert/test:common", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "litert_op", + srcs = ["litert_op.cc"], + hdrs = ["litert_op.h"], + deps = ["//tensorflow/lite/experimental/litert/c:litert_c_api"], +) + +cc_test( + name = "litert_op_test", + srcs = ["litert_op_test.cc"], + data = [ + "//tensorflow/lite/experimental/litert/test:tflite_test_data", + ], + deps = [ + ":litert_op", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/core:graph_tools", + "//tensorflow/lite/experimental/litert/test:common", + "@com_google_googletest//:gtest_main", + ], +) + +exports_files(srcs = glob(["litert_*.h"])) diff --git a/tensorflow/lite/experimental/litert/cc/any_cc_test.cc b/tensorflow/lite/experimental/litert/cc/any_cc_test.cc new file mode 100644 index 00000000000000..0d3b4db29537c9 --- /dev/null +++ b/tensorflow/lite/experimental/litert/cc/any_cc_test.cc @@ -0,0 +1,65 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include // NOLINT: Need when ANDROID_API_LEVEL >= 26 +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/cc/litert_any.h" + +TEST(Any, ConversionNone) { + EXPECT_FALSE( + litert::ToStdAny(LiteRtAny{/*.type=*/kLiteRtAnyTypeNone}).has_value()); +} + +TEST(Any, ConversionBool) { + ASSERT_EQ(std::any_cast(litert::ToStdAny(LiteRtAny{ + /*.type=*/kLiteRtAnyTypeBool, {/*.bool_value=*/true}})), + true); + ASSERT_EQ(std::any_cast(litert::ToStdAny(LiteRtAny{ + /*.type=*/kLiteRtAnyTypeBool, {/*.bool_value=*/false}})), + false); +} + +TEST(Any, ConversionInt) { + LiteRtAny litert_any; + litert_any.type = kLiteRtAnyTypeInt; + litert_any.int_value = 1234; + ASSERT_EQ(std::any_cast(litert::ToStdAny(litert_any)), 1234); +} + +TEST(Any, ConversionReal) { + LiteRtAny litert_any; + litert_any.type = kLiteRtAnyTypeReal; + litert_any.real_value = 123.4; + ASSERT_EQ(std::any_cast(litert::ToStdAny(litert_any)), 123.4); +} + +TEST(Any, ConversionString) { + constexpr const char* kTestString = "test"; + LiteRtAny litert_any; + litert_any.type = kLiteRtAnyTypeString; + litert_any.str_value = kTestString; + ASSERT_EQ(std::any_cast(litert::ToStdAny(litert_any)), + kTestString); +} + +TEST(Any, ConversionPtr) { + const void* kTestPtr = reinterpret_cast(1234); + LiteRtAny litert_any; + litert_any.type = kLiteRtAnyTypeVoidPtr; + litert_any.ptr_value = kTestPtr; + ASSERT_EQ(std::any_cast(litert::ToStdAny(litert_any)), kTestPtr); +} diff --git a/tensorflow/lite/experimental/litert/cc/layout_cc_test.cc b/tensorflow/lite/experimental/litert/cc/layout_cc_test.cc new file mode 100644 index 00000000000000..547596e1cc315c --- /dev/null +++ b/tensorflow/lite/experimental/litert/cc/layout_cc_test.cc @@ -0,0 +1,91 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include // NOLINT: Need when ANDROID_API_LEVEL >= 26 +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" + +namespace { + +constexpr const int32_t kTensorDimensions[] = {1, 2, 3}; +constexpr const auto kRank = + sizeof(kTensorDimensions) / sizeof(kTensorDimensions[0]); +constexpr const uint32_t kTensorStrides[] = {6, 3, 1}; + +} // namespace + +TEST(Layout, NoStrides) { + constexpr const LiteRtLayout kLayout = { + /*.rank=*/kRank, + /*.dimensions=*/kTensorDimensions, + /*.strides=*/nullptr, + }; + + litert::Layout layout(kLayout); + + ASSERT_EQ(layout.Rank(), kLayout.rank); + for (auto i = 0; i < layout.Rank(); ++i) { + ASSERT_EQ(layout.Dimensions()[i], kLayout.dimensions[i]); + } + ASSERT_FALSE(layout.HasStrides()); +} + +TEST(Layout, WithStrides) { + constexpr const LiteRtLayout kLayout = { + /*.rank=*/kRank, + /*.dimensions=*/kTensorDimensions, + /*.strides=*/kTensorStrides, + }; + + litert::Layout layout(kLayout); + + ASSERT_EQ(layout.Rank(), kLayout.rank); + for (auto i = 0; i < layout.Rank(); ++i) { + ASSERT_EQ(layout.Dimensions()[i], kLayout.dimensions[i]); + } + ASSERT_TRUE(layout.HasStrides()); + for (auto i = 0; i < layout.Rank(); ++i) { + ASSERT_EQ(layout.Strides()[i], kLayout.strides[i]); + } +} + +TEST(Layout, Equal) { + litert::Layout layout1({ + /*.rank=*/kRank, + /*.dimensions=*/kTensorDimensions, + /*.strides=*/kTensorStrides, + }); + litert::Layout layout2({ + /*.rank=*/kRank, + /*.dimensions=*/kTensorDimensions, + /*.strides=*/kTensorStrides, + }); + ASSERT_TRUE(layout1 == layout2); +} + +TEST(Layout, NotEqual) { + litert::Layout layout1({ + /*.rank=*/kRank, + /*.dimensions=*/kTensorDimensions, + /*.strides=*/nullptr, + }); + litert::Layout layout2({ + /*.rank=*/kRank, + /*.dimensions=*/kTensorDimensions, + /*.strides=*/kTensorStrides, + }); + ASSERT_FALSE(layout1 == layout2); +} diff --git a/tensorflow/lite/experimental/litert/cc/litert_any.h b/tensorflow/lite/experimental/litert/cc/litert_any.h new file mode 100644 index 00000000000000..4f724f85f52935 --- /dev/null +++ b/tensorflow/lite/experimental/litert/cc/litert_any.h @@ -0,0 +1,50 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ANY_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ANY_H_ + +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" + +namespace litert { + +inline std::any ToStdAny(LiteRtAny litert_any) { + std::any res; + switch (litert_any.type) { + case kLiteRtAnyTypeNone: + break; + case kLiteRtAnyTypeBool: + res = litert_any.bool_value; + break; + case kLiteRtAnyTypeInt: + res = litert_any.int_value; + break; + case kLiteRtAnyTypeReal: + res = litert_any.real_value; + break; + case kLiteRtAnyTypeString: + res = litert_any.str_value; + break; + case kLiteRtAnyTypeVoidPtr: + res = litert_any.ptr_value; + break; + } + return res; +} + +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_ANY_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_handle.h b/tensorflow/lite/experimental/litert/cc/litert_handle.h new file mode 100644 index 00000000000000..0b83030e70b944 --- /dev/null +++ b/tensorflow/lite/experimental/litert/cc/litert_handle.h @@ -0,0 +1,42 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_HANDLE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_HANDLE_H_ + +#include + +namespace litert { +namespace internal { + +// This class is used to wrap and manage the lifetime of opaque handles from the +// C API into an equivalent C++ object. The class is a wrapper on +// std::unique_ptr<> that has a default constructor and doesn't crash if the +// deleter is null. +template +class Handle : public std::unique_ptr { + public: + Handle() : std::unique_ptr(nullptr, DummyDeleter) {} + Handle(T* ptr, void (*deleter)(T*)) + : std::unique_ptr(ptr, + deleter ? deleter : DummyDeleter) {} + + private: + static void DummyDeleter(T*) {} +}; + +} // namespace internal +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_HANDLE_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_model.h b/tensorflow/lite/experimental/litert/cc/litert_model.h new file mode 100644 index 00000000000000..f6606ef1f598b2 --- /dev/null +++ b/tensorflow/lite/experimental/litert/cc/litert_model.h @@ -0,0 +1,132 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MODEL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MODEL_H_ + +#include +#include +#include +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" + +namespace litert { + +// Data type of tensor elements. C++ equivalent to LiteRtElementType. +enum class ElementType { + None = kLiteRtElementTypeNone, + Bool = kLiteRtElementTypeBool, + Int4 = kLiteRtElementTypeInt4, + Int8 = kLiteRtElementTypeInt8, + Int16 = kLiteRtElementTypeInt16, + Int32 = kLiteRtElementTypeInt32, + Int64 = kLiteRtElementTypeInt64, + UInt8 = kLiteRtElementTypeUInt8, + UInt16 = kLiteRtElementTypeUInt16, + UInt32 = kLiteRtElementTypeUInt32, + UInt64 = kLiteRtElementTypeUInt64, + Float16 = kLiteRtElementTypeFloat16, + BFloat16 = kLiteRtElementTypeBFloat16, + Float32 = kLiteRtElementTypeFloat32, + Float64 = kLiteRtElementTypeFloat64, + Complex64 = kLiteRtElementTypeComplex64, + Complex128 = kLiteRtElementTypeComplex128, + TfResource = kLiteRtElementTypeTfResource, + TfString = kLiteRtElementTypeTfString, + TfVariant = kLiteRtElementTypeTfVariant, +}; + +// Tensor layout. C++ equivalent to LiteRtLayout. +class Layout { + public: + explicit Layout(std::vector&& dimensions, + std::vector&& strides = std::vector()) + : dimensions_(std::move(dimensions)), strides_(std::move(strides)) {} + + explicit Layout(const LiteRtLayout& layout) + : dimensions_(layout.dimensions, layout.dimensions + layout.rank) { + if (layout.strides) { + strides_.reserve(layout.rank); + std::copy(layout.strides, layout.strides + layout.rank, + std::back_inserter(strides_)); + } + } + + explicit operator LiteRtLayout() const { + return LiteRtLayout{ + /*.rank=*/Rank(), + /*.dimensions=*/dimensions_.data(), + /*.strides=*/(HasStrides() ? strides_.data() : nullptr), + }; + } + + bool operator==(const Layout& other) const { + return dimensions_ == other.dimensions_ && strides_ == other.strides_; + } + + uint32_t Rank() const { return dimensions_.size(); } + + absl::Span Dimensions() const { + return absl::MakeSpan(dimensions_.data(), dimensions_.size()); + } + + bool HasStrides() const { return !strides_.empty(); } + + absl::Span Strides() const { + const uint32_t* data = HasStrides() ? strides_.data() : nullptr; + auto size = HasStrides() ? Rank() : 0; + return absl::MakeSpan(data, size); + } + + private: + std::vector dimensions_; + std::vector strides_; +}; + +// Type for tensors with known dimensions. C++ equivalent to +// LiteRtRankedTensorType. +class RankedTensorType { + public: + RankedTensorType(ElementType element_type, Layout&& layout) + : element_type_(element_type), layout_(std::move(layout)) {} + explicit RankedTensorType(const LiteRtRankedTensorType& type) + : element_type_(static_cast(type.element_type)), + layout_(type.layout) {} + + explicit operator LiteRtRankedTensorType() const { + return LiteRtRankedTensorType{ + /*.element_type=*/static_cast(element_type_), + /*layout=*/static_cast(layout_), + }; + } + + bool operator==(const RankedTensorType& other) const { + return ElementType() == other.ElementType() && Layout() == other.Layout(); + } + + ElementType ElementType() const { return element_type_; } + + const Layout& Layout() const { return layout_; } + + private: + enum ElementType element_type_; + class Layout layout_; +}; + +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_MODEL_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_op.cc b/tensorflow/lite/experimental/litert/cc/litert_op.cc new file mode 100644 index 00000000000000..bcf1b8474ddea7 --- /dev/null +++ b/tensorflow/lite/experimental/litert/cc/litert_op.cc @@ -0,0 +1,37 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" + +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/c/litert_support.h" + +namespace litert { + +LiteRtStatus LiteRtOpManager::MakeFromOp(LiteRtOp op, Unique& result) { + result = std::make_unique(); + LITERT_RETURN_STATUS_IF_NOT_OK(GetOpCode(op, &result->code_)); + result->op_ = op; + return kLiteRtStatusOk; +} + +LiteRtOpCode LiteRtOpManager::Code() const { return code_; } + +LiteRtOp LiteRtOpManager::Op() { return op_; } + +} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_op.h b/tensorflow/lite/experimental/litert/cc/litert_op.h new file mode 100644 index 00000000000000..e37a63690870e0 --- /dev/null +++ b/tensorflow/lite/experimental/litert/cc/litert_op.h @@ -0,0 +1,59 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_OP_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_OP_H_ + +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" + +namespace litert { + +// [WIP] Simple C++ wrapper over the C op api. Provided for convenience. +// +// NOTE ON USAGE: This "unpacks" upfront some of the data behind the LiteRtOp +// for efficiency and a cleaner interface (no status checks needed on getters). +// Because of this, it is required that `op : LiteRtOp` is stable and +// unmutated throughout the lifetime. This is guaranteed within (but not +// between) calls to an LiteRtCompilerPlugin. Plugins should close all +// LiteRtOpManagers before exiting a call and initialize fresh ones in later +// calls. +// +// This is an evolution of "graph_tools" and logic will be consolidated in +// the future. +// +// TODO: Expand this abstraction to handle options and edges (as +// LiteRtTensorManagers). +class LiteRtOpManager { + public: + using Unique = std::unique_ptr; + + static LiteRtStatus MakeFromOp(LiteRtOp op, Unique& result); + + LiteRtOpCode Code() const; + + LiteRtOp Op(); + + private: + LiteRtOp op_; + + LiteRtOpCode code_; +}; + +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_OP_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_op_test.cc b/tensorflow/lite/experimental/litert/cc/litert_op_test.cc new file mode 100644 index 00000000000000..3c80217806d59b --- /dev/null +++ b/tensorflow/lite/experimental/litert/cc/litert_op_test.cc @@ -0,0 +1,39 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" + +#include +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/core/graph_tools.h" +#include "tensorflow/lite/experimental/litert/test/common.h" + +namespace { + +using ::litert::LiteRtOpManager; + +TEST(TestLiteRtOp, SimpleSupportedOp) { + auto model = litert::testing::LoadTestFileModel("one_mul.tflite"); + ASSERT_RESULT_OK_ASSIGN(auto subgraph, + ::graph_tools::GetSubgraph(model.get())); + ASSERT_RESULT_OK_ASSIGN(auto ops, ::graph_tools::GetSubgraphOps(subgraph)); + + LiteRtOpManager::Unique op; + ASSERT_STATUS_OK(LiteRtOpManager::MakeFromOp(ops[0], op)); + + EXPECT_EQ(op->Code(), kLiteRtOpCodeTflMul); + EXPECT_EQ(op->Op(), ops[0]); +} + +} // namespace diff --git a/tensorflow/lite/experimental/litert/cc/litert_support.h b/tensorflow/lite/experimental/litert/cc/litert_support.h new file mode 100644 index 00000000000000..c1404ec4655c3d --- /dev/null +++ b/tensorflow/lite/experimental/litert/cc/litert_support.h @@ -0,0 +1,222 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_SUPPORT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_SUPPORT_H_ + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" // IWYU pragma: keep +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/c/litert_support.h" // IWYU pragma: export +#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" + +// Flatbuffer's raw char type. +typedef uint8_t FbCharT; + +// Const view of flatbuffer's raw buffer type. +typedef absl::Span FbConstBufferT; + +// Mutable view of flatbuffer's raw buffer type. +typedef absl::Span FbBufferT; + +// Convenience method to get raw string view from native flatbuffer buffer. +inline absl::string_view FbBufToStr(FbConstBufferT fb_buf) { + auto fb_buf_raw = reinterpret_cast(fb_buf.data()); + const size_t fb_buf_size = fb_buf.size(); + return absl::string_view(fb_buf_raw, fb_buf_size); +} + +// Mutable version of above. +inline absl::string_view FbBufToStr(FbBufferT fb_buf) { + auto fb_buf_raw = reinterpret_cast(fb_buf.data()); + const size_t fb_buf_size = fb_buf.size(); + return absl::string_view(fb_buf_raw, fb_buf_size); +} + +#define _CONCAT_NAME_IMPL(x, y) x##y + +#define _CONCAT_NAME(x, y) _CONCAT_NAME_IMPL(x, y) + +#define _RETURN_VAL(val) return val + +// TODO: b/365295276 - Put all smart pointer wrappers in support.h. +struct LiteRtCompilerPluginDeleter { + void operator()(LiteRtCompilerPlugin plugin) { + if (plugin != nullptr) { + LiteRtPluginDestroy(plugin); + } + } +}; + +using UniqueLiteRtCompilerPlugin = + std::unique_ptr; + +// `StatusOr` analog for litert. Very basic currently. +// TODO: b/365295276 - Figure out how to better infer template param +// and not require passing typing to macros. +template +class LiteRtResult { + public: + // TODO: b/365295276 - Implement emplace for LiteRtResult. + + static LiteRtResult FromValue(const T& value) { + LiteRtResult result; + result.data_ = value; + return result; + } + + static LiteRtResult TakeValue(T&& value) { + LiteRtResult result; + result.data_ = std::move(value); + return result; + } + + static LiteRtResult FromStatus(LiteRtStatus status) { + LiteRtResult result; + result.data_ = status; + return result; + } + + T& Value() { + if (!HasValue()) { + LITERT_FATAL("Result does not contain a value."); + } + return std::get(data_); + } + + LiteRtStatus Status() { + if (std::holds_alternative(data_)) { + return kLiteRtStatusOk; + } + return std::get(data_); + } + + bool HasValue() { return std::holds_alternative(data_); } + + private: + std::variant data_; +}; + +#ifdef NDEBUG +#define _LITERT_D_MSG(msg) +#else +#define _LITERT_D_MSG(msg) LITERT_LOG(LITERT_INFO, "%s", msg) +#endif + +#ifdef LITERT_RETURN_STATUS_IF_NOT_OK_MSG +#undef LITERT_RETURN_STATUS_IF_NOT_OK_MSG +#define LITERT_RETURN_STATUS_IF_NOT_OK_MSG(expr, d_msg) \ + if (LiteRtStatus status = expr; status != kLiteRtStatusOk) { \ + _LITERT_D_MSG(d_msg); \ + return status; \ + } +#endif + +// TODO: b/365295276 Make c friendly `CHECK` macro(s) and move to c api. +#define LITERT_CHECK_STATUS_HAS_CODE_MSG(expr, code, d_msg) \ + if (LiteRtStatus status = expr; status != code) { \ + _LITERT_D_MSG(d_msg); \ + ABSL_CHECK(false); \ + } + +#define LITERT_CHECK_STATUS_HAS_CODE(expr, code) \ + LITERT_CHECK_STATUS_HAS_CODE_MSG(expr, code, ""); + +#define LITERT_CHECK_STATUS_OK(expr) \ + LITERT_CHECK_STATUS_HAS_CODE(expr, kLiteRtStatusOk); + +#define LITERT_CHECK_STATUS_OK_MSG(expr, d_msg) \ + LITERT_CHECK_STATUS_HAS_CODE_MSG(expr, kLiteRtStatusOk, d_msg); + +// If expr doesn't retur ok status, wrap as result and return. +#define LITERT_RETURN_RESULT_IF_NOT_OK(expr, ty) \ + if (LiteRtStatus status = expr; status != kLiteRtStatusOk) \ + return LiteRtResult::FromStatus(status); + +#define _ASSIGN_OR_BLOCK(decl, expr, block, result) \ + auto result = (expr); \ + if (!result.HasValue()) { \ + block; \ + } \ + decl = result.Value(); + +#define _MOVE_OR_BLOCK(decl, expr, block, result) \ + auto result = (expr); \ + if (!result.HasValue()) { \ + block; \ + } \ + decl = std::move(result.Value()); + +#define _MOVE_OR_RETURN_VAL(decl, expr, val, result) \ + _MOVE_OR_BLOCK(decl, expr, _RETURN_VAL(val), result) + +#define _ASSIGN_OR_RETURN_VAL(decl, expr, val, result) \ + _ASSIGN_OR_BLOCK(decl, expr, _RETURN_VAL(val), result) + +// Assign value behind result returned from expr. If not ok, return val. +#define LITERT_ASSIGN_OR_RETURN_VAL(decl, expr, val) \ + _ASSIGN_OR_RETURN_VAL(decl, expr, val, _CONCAT_NAME(_result, __COUNTER__)) + +#define _STATUS_FROM_RESULT(result) result.Status(); + +#define _ASSIGN_OR_RETURN_STATUS(decl, expr, result) \ + _ASSIGN_OR_RETURN_VAL(decl, expr, _STATUS_FROM_RESULT(result), result) + +#define _MOVE_OR_RETURN_STATUS(decl, expr, result) \ + _MOVE_OR_RETURN_VAL(decl, expr, _STATUS_FROM_RESULT(result), result) + +// Assign value behind result returned from expr. If not ok, return status. +#define LITERT_ASSIGN_OR_RETURN_STATUS(decl, expr) \ + _ASSIGN_OR_RETURN_STATUS(decl, expr, _CONCAT_NAME(_result, __COUNTER__)) + +// Assign value behind result returned from expr. If not ok, return status. +#define LITERT_MOVE_OR_RETURN_STATUS(decl, expr) \ + _MOVE_OR_RETURN_STATUS(decl, expr, _CONCAT_NAME(_result, __COUNTER__)) + +#define _FORWARD_RESULT(result, ty) \ + LiteRtResult::FromStatus(result.Status()); + +#define _ASSIGN_OR_RETURN_RESULT(decl, expr, ty, result) \ + _ASSIGN_OR_RETURN_VAL(decl, expr, _FORWARD_RESULT(result, ty), result) + +// Assign value behind result returned from expr. If not ok, return result. +#define LITERT_ASSIGN_OR_RETURN_RESULT(decl, expr, ty) \ + _ASSIGN_OR_RETURN_RESULT(decl, expr, ty, _CONCAT_NAME(_result, __COUNTER__)) + +#define _MOVE_OR_RETURN_RESULT(decl, expr, ty, result) \ + _MOVE_OR_RETURN_VAL(decl, expr, _FORWARD_RESULT(result, ty), result) + +// Move value behind result returned from expr. If not ok, return result. +#define LITERT_MOVE_OR_RETURN_RESULT(decl, expr, ty) \ + _MOVE_OR_RETURN_RESULT(decl, expr, ty, _CONCAT_NAME(_result, __COUNTER__)) + +#define LITERT_ENSURE_SUPPORTED(cond, msg) \ + if (!(cond)) { \ + LITERT_LOG(LITERT_ERROR, "%s", msg); \ + return kLiteRtStatusErrorUnsupported; \ + } + +#define LITERT_ENSURE(expr, fail_stat, msg) \ + if (!(expr)) { \ + LITERT_LOG(LITERT_ERROR, "%s", msg); \ + return fail_stat; \ + } + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_SUPPORT_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_tensor.cc b/tensorflow/lite/experimental/litert/cc/litert_tensor.cc new file mode 100644 index 00000000000000..6ff03ed693c28c --- /dev/null +++ b/tensorflow/lite/experimental/litert/cc/litert_tensor.cc @@ -0,0 +1,76 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/cc/litert_tensor.h" + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/core/graph_tools.h" + +namespace litert { + +absl::Span LiteRtTensorManager::Dims() const { + return absl::MakeConstSpan(ranked_tensor_type_.layout.dimensions, Rank()); +} + +absl::Span LiteRtTensorManager::Strides() const { + if (ranked_tensor_type_.layout.strides) { + return absl::MakeConstSpan(ranked_tensor_type_.layout.strides, Rank()); + } else { + return {}; + } +} + +uint32_t LiteRtTensorManager::Rank() const { + return ranked_tensor_type_.layout.rank; +} + +LiteRtElementType LiteRtTensorManager::ElementType() const { + return ranked_tensor_type_.element_type; +} + +LiteRtTensor LiteRtTensorManager::Tensor() { return tensor_; } + +LiteRtStatus LiteRtTensorManager::MakeFromTensor(LiteRtTensor tensor, + Unique& result) { + result = std::make_unique(); + + LiteRtTensorTypeId type_id; + LITERT_RETURN_STATUS_IF_NOT_OK(GetTensorTypeId(tensor, &type_id)); + LITERT_ENSURE_SUPPORTED( + type_id == kLiteRtRankedTensorType, + "Only RankedTensorType currently supported in C++ api."); + + LITERT_RETURN_STATUS_IF_NOT_OK( + GetRankedTensorType(tensor, &result->ranked_tensor_type_)); + result->tensor_ = tensor; + + return kLiteRtStatusOk; +} + +bool LiteRtTensorManager::IsSubgraphOutput() const { + return ::graph_tools::MatchTensorNoUses(tensor_); +} + +bool LiteRtTensorManager::IsSubgraphInput() const { + return ::graph_tools::MatchTensorNoDefiningOp(tensor_) && + ::graph_tools::MatchNoWeights(tensor_); +} + +} // namespace litert diff --git a/tensorflow/lite/experimental/litert/cc/litert_tensor.h b/tensorflow/lite/experimental/litert/cc/litert_tensor.h new file mode 100644 index 00000000000000..f3ae75ea0b86f5 --- /dev/null +++ b/tensorflow/lite/experimental/litert/cc/litert_tensor.h @@ -0,0 +1,72 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_H_ + +#include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" + +namespace litert { + +// [WIP] Simple C++ wrapper over the C tensor api. Provided for convenience. +// Currently only supports LiteRtRankedTensors. +// +// NOTE ON USAGE: This "unpacks" upfront some of the data behind the +// LiteRtTensor for efficiency and a cleaner interface (no status checks needed +// on getters). Becasuse of this, it is required that `tensor : LiteRtTensor` is +// stable and unmutated throughout the lifetime. This is guaranteed within (but +// not between) calls to an LiteRtCompilerPlugin. Plugins should close all +// LiteRtTensorManagers before exiting a call and initialize fresh ones in later +// calls. +// +// This is an evolution of "graph_tools" and logic will be consolidated in +// the future. +// +// TODO Expand this abstraction +// to handle the union of possible tensor types cleanly as well as +// defining op/users. +class LiteRtTensorManager { + public: + using Unique = std::unique_ptr; + + static LiteRtStatus MakeFromTensor(LiteRtTensor tensor, Unique& result); + + uint32_t Rank() const; + + absl::Span Dims() const; + + bool HasStrides() const { + return ranked_tensor_type_.layout.strides != nullptr; + } + absl::Span Strides() const; + + LiteRtElementType ElementType() const; + + bool IsSubgraphOutput() const; + + bool IsSubgraphInput() const; + + LiteRtTensor Tensor(); + + private: + LiteRtTensor tensor_; + + LiteRtRankedTensorType ranked_tensor_type_; +}; + +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h b/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h new file mode 100644 index 00000000000000..fa31b05c2c54f9 --- /dev/null +++ b/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h @@ -0,0 +1,150 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_event.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/cc/litert_handle.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" + +namespace litert { + +// Tensor and associated backing buffer. C++ equivalent of LiteRtTensorBuffer. +class TensorBuffer { + public: + TensorBuffer() = default; + + // Parameter `owned` indicates if the created TensorBuffer object should take + // ownership of the provided `tensor_buffer` handle. + explicit TensorBuffer(LiteRtTensorBuffer tensor_buffer, bool owned = true) + : handle_(tensor_buffer, owned ? LiteRtDestroyTensorBuffer : nullptr) {} + + TensorBuffer(TensorBuffer&& other) { *this = std::move(other); } + + TensorBuffer& operator=(TensorBuffer&& other) { + std::swap(handle_, other.handle_); + return *this; + } + + static absl::StatusOr CreateManaged( + LiteRtTensorBufferType buffer_type, const RankedTensorType& tensor_type, + size_t buffer_size) { + LiteRtTensorBuffer tensor_buffer; + auto& litert_tensor_type = + static_cast(tensor_type); + if (auto status = LiteRtCreateManagedTensorBuffer( + buffer_type, &litert_tensor_type, buffer_size, &tensor_buffer); + status != kLiteRtStatusOk) { + return absl::InternalError("Failed to create managed tensor buffer"); + } + return TensorBuffer(tensor_buffer); + } + + // Return true if the underlying LiteRtTensorBuffer handle is valid. + explicit operator bool() const { return static_cast(handle_); } + + // Return the underlying LiteRtTensorBuffer handle. + explicit operator LiteRtTensorBuffer() { return handle_.get(); } + + absl::StatusOr BufferType() const { + LiteRtTensorBufferType tensor_buffer_type; + if (auto status = + LiteRtGetTensorBufferType(handle_.get(), &tensor_buffer_type); + status != kLiteRtStatusOk) { + return absl::InternalError("Failed to get tensor buffer type"); + } + return tensor_buffer_type; + } + + absl::StatusOr TensorType() const { + LiteRtRankedTensorType tensor_type; + if (auto status = + LiteRtGetTensorBufferTensorType(handle_.get(), &tensor_type); + status != kLiteRtStatusOk) { + return absl::InternalError("Failed to get tensor type"); + } + return RankedTensorType(tensor_type); + } + + absl::StatusOr Size() const { + size_t size; + if (auto status = LiteRtGetTensorBufferSize(handle_.get(), &size); + status != kLiteRtStatusOk) { + return absl::InternalError("Failed to get tensor size"); + } + return size; + } + + absl::StatusOr Offset() const { + size_t offset; + if (auto status = LiteRtGetTensorBufferOffset(handle_.get(), &offset); + status != kLiteRtStatusOk) { + return absl::InternalError("Failed to get tensor offset"); + } + return offset; + } + + absl::StatusOr Lock(LiteRtEvent event = nullptr) { + void* host_mem_addr; + if (auto status = + LiteRtLockTensorBuffer(handle_.get(), &host_mem_addr, event); + status != kLiteRtStatusOk) { + return absl::InternalError("Failed to lock the tensor buffer"); + } + return host_mem_addr; + } + + absl::Status Unlock() { + if (auto status = LiteRtUnlockTensorBuffer(handle_.get()); + status != kLiteRtStatusOk) { + return absl::InternalError("Failed to unlock the tensor buffer"); + } + return {}; + } + + private: + internal::Handle handle_; +}; + +class TensorBufferScopedLock { + public: + ~TensorBufferScopedLock() { (void)tensor_buffer_.Unlock(); } + + static absl::StatusOr> Create( + TensorBuffer& tensor_buffer, LiteRtEvent event = nullptr) { + auto addr = tensor_buffer.Lock(event); + if (!addr.ok()) { + return addr.status(); + } + return std::make_pair(TensorBufferScopedLock(tensor_buffer), *addr); + } + + private: + explicit TensorBufferScopedLock(TensorBuffer& tensor_buffer) + : tensor_buffer_(tensor_buffer) {} + TensorBuffer& tensor_buffer_; +}; + +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h b/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h new file mode 100644 index 00000000000000..3e99c28e864a66 --- /dev/null +++ b/tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h @@ -0,0 +1,102 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_REQUIREMENTS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_REQUIREMENTS_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" +#include "tensorflow/lite/experimental/litert/cc/litert_handle.h" + +namespace litert { + +// Requirements for allocating a TensorBuffer, typically specified by a HW +// accelerator for a given I/O tensor. C++ equivalent to +// LiteRtTensorBufferRequirements. +class TensorBufferRequirements { + public: + TensorBufferRequirements() = default; + + // Parameter `owned` indicates if the created TensorBufferRequirements object + // should take ownership of the provided `requirements` handle. + explicit TensorBufferRequirements(LiteRtTensorBufferRequirements requirements, + bool owned = true) + : handle_(requirements, + owned ? LiteRtDestroyTensorBufferRequirements : nullptr) {} + + static absl::StatusOr Create( + absl::Span buffer_types, + size_t buffer_size) { + LiteRtTensorBufferRequirements tensor_buffer_requirements; + if (auto status = LiteRtCreateTensorBufferRequirements( + buffer_types.size(), buffer_types.data(), buffer_size, + &tensor_buffer_requirements); + status != kLiteRtStatusOk) { + return absl::InternalError("Failed to create tensor buffer requirements"); + } + return TensorBufferRequirements(tensor_buffer_requirements); + } + + // Return true if the underlying LiteRtTensorBufferRequirements handle is + // valid. + explicit operator bool() const { return static_cast(handle_); } + + // Return the underlying LiteRtTensorBufferRequirements handle. + explicit operator LiteRtTensorBufferRequirements() { return handle_.get(); } + + absl::StatusOr> SupportedTypes() const { + int num_types; + if (auto status = + LiteRtGetTensorBufferRequirementsNumSupportedTensorBufferTypes( + handle_.get(), &num_types); + status != kLiteRtStatusOk) { + return absl::InternalError( + "Failed to get the number of supported tensor types"); + } + std::vector types(num_types); + for (auto i = 0; i < num_types; ++i) { + if (auto status = + LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + handle_.get(), i, &types[i]); + status != kLiteRtStatusOk) { + return absl::InternalError("Failed to get supported tensor type"); + } + } + return types; + } + + absl::StatusOr BufferSize() const { + size_t buffer_size; + if (auto status = LiteRtGetTensorBufferRequirementsBufferSize(handle_.get(), + &buffer_size); + status != kLiteRtStatusOk) { + return absl::InternalError("Failed to get tensor buffer size"); + } + return buffer_size; + } + + private: + internal::Handle handle_; +}; + +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CC_LITERT_TENSOR_BUFFER_REQUIREMENTS_H_ diff --git a/tensorflow/lite/experimental/litert/cc/litert_tensor_test.cc b/tensorflow/lite/experimental/litert/cc/litert_tensor_test.cc new file mode 100644 index 00000000000000..5ba48ac4e5d420 --- /dev/null +++ b/tensorflow/lite/experimental/litert/cc/litert_tensor_test.cc @@ -0,0 +1,124 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/cc/litert_tensor.h" + +#include + +#include +#include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/core/graph_tools.h" +#include "tensorflow/lite/experimental/litert/core/model.h" +#include "tensorflow/lite/experimental/litert/test/common.h" + +namespace { + +using ::litert::LiteRtTensorManager; + +TEST(TestLiteRtTensorManager, SimpleRankedTensorSubgraphInput) { + auto model = litert::testing::LoadTestFileModel("one_mul.tflite"); + + ASSERT_RESULT_OK_ASSIGN(auto subgraph, + ::graph_tools::GetSubgraph(model.get())); + ASSERT_RESULT_OK_ASSIGN(auto inputs, + ::graph_tools::GetSubgraphInputs(subgraph)); + + LiteRtTensorManager::Unique tensor; + ASSERT_STATUS_OK(LiteRtTensorManager::MakeFromTensor(inputs[0], tensor)); + + ASSERT_EQ(tensor->Rank(), 2); + EXPECT_EQ(tensor->Dims(), absl::MakeConstSpan({2, 2})); + EXPECT_EQ(tensor->ElementType(), kLiteRtElementTypeFloat32); + EXPECT_EQ(tensor->Tensor(), inputs[0]); + EXPECT_TRUE(tensor->IsSubgraphInput()); + EXPECT_FALSE(tensor->IsSubgraphOutput()); +} + +TEST(TestLiteRtTensorManager, SimpleRankedTensorSubgraphOutput) { + auto model = litert::testing::LoadTestFileModel("one_mul.tflite"); + + ASSERT_RESULT_OK_ASSIGN(auto subgraph, + ::graph_tools::GetSubgraph(model.get())); + ASSERT_RESULT_OK_ASSIGN(auto outputs, + ::graph_tools::GetSubgraphOutputs(subgraph)); + + LiteRtTensorManager::Unique tensor; + ASSERT_STATUS_OK(LiteRtTensorManager::MakeFromTensor(outputs[0], tensor)); + + ASSERT_EQ(tensor->Rank(), 2); + EXPECT_EQ(tensor->Dims(), absl::MakeConstSpan({2, 2})); + EXPECT_EQ(tensor->ElementType(), kLiteRtElementTypeFloat32); + EXPECT_EQ(tensor->Tensor(), outputs[0]); + EXPECT_TRUE(tensor->IsSubgraphOutput()); + EXPECT_FALSE(tensor->IsSubgraphInput()); +} + +TEST(TestLiteRtTensorManager, SimpleRankedTensor) { + auto model = litert::testing::LoadTestFileModel("simple_multi_op.tflite"); + + ASSERT_RESULT_OK_ASSIGN(auto subgraph, + ::graph_tools::GetSubgraph(model.get())); + ASSERT_RESULT_OK_ASSIGN(auto ops, ::graph_tools::GetSubgraphOps(subgraph)); + ASSERT_RESULT_OK_ASSIGN(auto op_outs, ::graph_tools::GetOpOuts(ops[1])); + + LiteRtTensorManager::Unique tensor; + ASSERT_STATUS_OK(LiteRtTensorManager::MakeFromTensor(op_outs[0], tensor)); + + ASSERT_EQ(tensor->Rank(), 2); + EXPECT_EQ(tensor->Dims(), absl::MakeConstSpan({2, 2})); + EXPECT_EQ(tensor->ElementType(), kLiteRtElementTypeFloat32); + EXPECT_EQ(tensor->Tensor(), op_outs[0]); + EXPECT_FALSE(tensor->IsSubgraphOutput()); + EXPECT_FALSE(tensor->IsSubgraphInput()); +} + +TEST(TestLiteRtTensorManager, NoStrides) { + int32_t dimensions[] = {1, 2, 3}; + + LiteRtTensorT tensor; + tensor.type_id = kLiteRtRankedTensorType; + tensor.type_detail.ranked_tensor_type.element_type = + kLiteRtElementTypeFloat32; + tensor.type_detail.ranked_tensor_type.layout.rank = + sizeof(dimensions) / sizeof(dimensions[0]); + tensor.type_detail.ranked_tensor_type.layout.dimensions = dimensions; + tensor.type_detail.ranked_tensor_type.layout.strides = nullptr; + + LiteRtTensorManager::Unique tensor_manager; + ASSERT_STATUS_OK( + LiteRtTensorManager::MakeFromTensor(&tensor, tensor_manager)); + EXPECT_FALSE(tensor_manager->HasStrides()); +} + +TEST(TestLiteRtTensorManager, Strides) { + int32_t dimensions[] = {1, 2, 3}; + uint32_t strides[] = {6, 3, 1}; + + LiteRtTensorT tensor; + tensor.type_id = kLiteRtRankedTensorType; + tensor.type_detail.ranked_tensor_type.element_type = + kLiteRtElementTypeFloat32; + tensor.type_detail.ranked_tensor_type.layout.rank = + sizeof(dimensions) / sizeof(dimensions[0]); + tensor.type_detail.ranked_tensor_type.layout.dimensions = dimensions; + tensor.type_detail.ranked_tensor_type.layout.strides = strides; + + LiteRtTensorManager::Unique tensor_manager; + ASSERT_STATUS_OK( + LiteRtTensorManager::MakeFromTensor(&tensor, tensor_manager)); + EXPECT_TRUE(tensor_manager->HasStrides()); +} + +} // namespace diff --git a/tensorflow/lite/experimental/litert/cc/ranked_tensor_type_cc_test.cc b/tensorflow/lite/experimental/litert/cc/ranked_tensor_type_cc_test.cc new file mode 100644 index 00000000000000..f9f310037d7570 --- /dev/null +++ b/tensorflow/lite/experimental/litert/cc/ranked_tensor_type_cc_test.cc @@ -0,0 +1,65 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include // NOLINT: Need when ANDROID_API_LEVEL >= 26 +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" + +namespace { + +constexpr const int32_t kTensorDimensions[] = {1, 2, 3}; +constexpr const size_t kRank = + sizeof(kTensorDimensions) / sizeof(kTensorDimensions[0]); + +constexpr const LiteRtLayout kLayout = { + /*.rank=*/kRank, + /*.dimensions=*/kTensorDimensions, + /*.strides=*/nullptr, +}; + +constexpr const LiteRtRankedTensorType kTensorType = { + /*.element_type=*/kLiteRtElementTypeFloat32, + /*.layout=*/kLayout, +}; + +} // namespace + +TEST(RankedTensorType, Accessors) { + litert::Layout layout(kLayout); + litert::RankedTensorType tensor_type(kTensorType); + ASSERT_EQ(tensor_type.ElementType(), + static_cast(kTensorType.element_type)); + ASSERT_TRUE(tensor_type.Layout() == layout); +} + +TEST(Layout, Equal) { + litert::RankedTensorType tensor_type1(kTensorType); + litert::RankedTensorType tensor_type2({ + /*.element_type=*/kLiteRtElementTypeFloat32, + /*.layout=*/kLayout, + }); + ASSERT_TRUE(tensor_type1 == tensor_type2); +} + +TEST(Layout, NotEqual) { + litert::RankedTensorType tensor_type1(kTensorType); + litert::RankedTensorType tensor_type2({ + /*.element_type=*/kLiteRtElementTypeFloat16, + /*.layout=*/kLayout, + }); + ASSERT_TRUE(tensor_type1 != tensor_type2); +} diff --git a/tensorflow/lite/experimental/litert/cc/tensor_buffer_requirements_test.cc b/tensorflow/lite/experimental/litert/cc/tensor_buffer_requirements_test.cc new file mode 100644 index 00000000000000..8fb5879223d52e --- /dev/null +++ b/tensorflow/lite/experimental/litert/cc/tensor_buffer_requirements_test.cc @@ -0,0 +1,85 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include // NOLINT: Need when ANDROID_API_LEVEL >= 26 +#include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" +#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" + +namespace { + +constexpr const LiteRtTensorBufferType kSupportedTensorBufferTypes[] = { + kLiteRtTensorBufferTypeHostMemory, + kLiteRtTensorBufferTypeAhwb, + kLiteRtTensorBufferTypeIon, + kLiteRtTensorBufferTypeFastRpc, +}; + +constexpr const size_t kNumSupportedTensorBufferTypes = + sizeof(kSupportedTensorBufferTypes) / + sizeof(kSupportedTensorBufferTypes[0]); + +constexpr const size_t kBufferSize = 1234; + +} // namespace + +TEST(TensorBufferRequirements, Owned) { + auto requirements = litert::TensorBufferRequirements::Create( + absl::MakeSpan(kSupportedTensorBufferTypes, + kNumSupportedTensorBufferTypes), + kBufferSize); + ASSERT_TRUE(requirements.ok()); + + auto supported_types = requirements->SupportedTypes(); + ASSERT_TRUE(supported_types.ok()); + ASSERT_EQ(supported_types->size(), kNumSupportedTensorBufferTypes); + for (auto i = 0; i < supported_types->size(); ++i) { + ASSERT_EQ((*supported_types)[i], kSupportedTensorBufferTypes[i]); + } + + auto size = requirements->BufferSize(); + ASSERT_TRUE(size.ok()); + ASSERT_EQ(*size, kBufferSize); +} + +TEST(TensorBufferRequirements, NotOwned) { + LiteRtTensorBufferRequirements litert_requirements; + ASSERT_EQ(LiteRtCreateTensorBufferRequirements( + kNumSupportedTensorBufferTypes, kSupportedTensorBufferTypes, + kBufferSize, &litert_requirements), + kLiteRtStatusOk); + + litert::TensorBufferRequirements requirements(litert_requirements, + /*owned=*/false); + + auto supported_types = requirements.SupportedTypes(); + ASSERT_TRUE(supported_types.ok()); + ASSERT_EQ(supported_types->size(), kNumSupportedTensorBufferTypes); + for (auto i = 0; i < supported_types->size(); ++i) { + ASSERT_EQ((*supported_types)[i], kSupportedTensorBufferTypes[i]); + } + + auto size = requirements.BufferSize(); + ASSERT_TRUE(size.ok()); + ASSERT_EQ(*size, kBufferSize); + + ASSERT_EQ(static_cast(requirements), + litert_requirements); + + LiteRtDestroyTensorBufferRequirements(litert_requirements); +} diff --git a/tensorflow/lite/experimental/litert/cc/tensor_buffer_test.cc b/tensorflow/lite/experimental/litert/cc/tensor_buffer_test.cc new file mode 100644 index 00000000000000..420a7e5a447305 --- /dev/null +++ b/tensorflow/lite/experimental/litert/cc/tensor_buffer_test.cc @@ -0,0 +1,298 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include // NOLINT: Need when ANDROID_API_LEVEL >= 26 +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/core/ahwb_buffer.h" // IWYU pragma: keep +#include "tensorflow/lite/experimental/litert/core/dmabuf_buffer.h" // IWYU pragma: keep +#include "tensorflow/lite/experimental/litert/core/fastrpc_buffer.h" // IWYU pragma: keep +#include "tensorflow/lite/experimental/litert/core/ion_buffer.h" // IWYU pragma: keep + +namespace { +constexpr const float kTensorData[] = {10, 20, 30, 40}; + +constexpr const int32_t kTensorDimensions[] = {sizeof(kTensorData) / + sizeof(kTensorData[0])}; + +constexpr const LiteRtRankedTensorType kTensorType = { + /*.element_type=*/kLiteRtElementTypeFloat32, + /*.layout=*/{ + /*.rank=*/1, + /*.dimensions=*/kTensorDimensions, + /*.strides=*/nullptr, + }}; +} // namespace + +TEST(TensorBuffer, HostMemory) { + const litert::RankedTensorType kTensorType(::kTensorType); + constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeHostMemory; + + auto tensor_buffer = litert::TensorBuffer::CreateManaged( + kTensorBufferType, kTensorType, sizeof(kTensorData)); + ASSERT_TRUE(tensor_buffer.ok()); + + auto tensor_buffer_type = tensor_buffer->BufferType(); + ASSERT_TRUE(tensor_buffer_type.ok()); + ASSERT_EQ(*tensor_buffer_type, kTensorBufferType); + + auto tensor_type = tensor_buffer->TensorType(); + ASSERT_TRUE(tensor_type.ok()); + + ASSERT_EQ(tensor_type->ElementType(), litert::ElementType::Float32); + ASSERT_EQ(tensor_type->Layout().Rank(), 1); + ASSERT_EQ(tensor_type->Layout().Dimensions()[0], + kTensorType.Layout().Dimensions()[0]); + ASSERT_FALSE(tensor_type->Layout().HasStrides()); + + auto size = tensor_buffer->Size(); + ASSERT_TRUE(size.ok()); + ASSERT_EQ(*size, sizeof(kTensorData)); + + auto offset = tensor_buffer->Offset(); + ASSERT_TRUE(offset.ok()); + ASSERT_EQ(*offset, 0); + + { + auto lock_and_addr = litert::TensorBufferScopedLock::Create(*tensor_buffer); + ASSERT_TRUE(lock_and_addr.ok()); + std::memcpy(lock_and_addr->second, kTensorData, sizeof(kTensorData)); + } + + { + auto lock_and_addr = litert::TensorBufferScopedLock::Create(*tensor_buffer); + ASSERT_TRUE(lock_and_addr.ok()); + ASSERT_EQ( + std::memcmp(lock_and_addr->second, kTensorData, sizeof(kTensorData)), + 0); + } +} + +TEST(TensorBuffer, Ahwb) { + if (!litert::internal::AhwbBuffer::IsSupported()) { + GTEST_SKIP() << "AHardwareBuffers are not supported on this platform; " + "skipping the test"; + } + + const litert::RankedTensorType kTensorType(::kTensorType); + constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeAhwb; + + auto tensor_buffer = litert::TensorBuffer::CreateManaged( + kTensorBufferType, kTensorType, sizeof(kTensorData)); + ASSERT_TRUE(tensor_buffer.ok()); + + auto tensor_buffer_type = tensor_buffer->BufferType(); + ASSERT_TRUE(tensor_buffer_type.ok()); + ASSERT_EQ(*tensor_buffer_type, kTensorBufferType); + + auto tensor_type = tensor_buffer->TensorType(); + ASSERT_TRUE(tensor_type.ok()); + + ASSERT_EQ(tensor_type->ElementType(), litert::ElementType::Float32); + ASSERT_EQ(tensor_type->Layout().Rank(), 1); + ASSERT_EQ(tensor_type->Layout().Dimensions()[0], + kTensorType.Layout().Dimensions()[0]); + ASSERT_FALSE(tensor_type->Layout().HasStrides()); + + auto size = tensor_buffer->Size(); + ASSERT_TRUE(size.ok()); + ASSERT_EQ(*size, sizeof(kTensorData)); + + auto offset = tensor_buffer->Offset(); + ASSERT_TRUE(offset.ok()); + ASSERT_EQ(*offset, 0); + + { + auto lock_and_addr = litert::TensorBufferScopedLock::Create(*tensor_buffer); + ASSERT_TRUE(lock_and_addr.ok()); + std::memcpy(lock_and_addr->second, kTensorData, sizeof(kTensorData)); + } + + { + auto lock_and_addr = litert::TensorBufferScopedLock::Create(*tensor_buffer); + ASSERT_TRUE(lock_and_addr.ok()); + ASSERT_EQ( + std::memcmp(lock_and_addr->second, kTensorData, sizeof(kTensorData)), + 0); + } +} + +TEST(TensorBuffer, Ion) { + if (!litert::internal::IonBuffer::IsSupported()) { + GTEST_SKIP() + << "ION buffers are not supported on this platform; skipping the test"; + } + + const litert::RankedTensorType kTensorType(::kTensorType); + constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeIon; + + auto tensor_buffer = litert::TensorBuffer::CreateManaged( + kTensorBufferType, kTensorType, sizeof(kTensorData)); + ASSERT_TRUE(tensor_buffer.ok()); + + auto tensor_buffer_type = tensor_buffer->BufferType(); + ASSERT_TRUE(tensor_buffer_type.ok()); + ASSERT_EQ(*tensor_buffer_type, kTensorBufferType); + + auto tensor_type = tensor_buffer->TensorType(); + ASSERT_TRUE(tensor_type.ok()); + + ASSERT_EQ(tensor_type->ElementType(), litert::ElementType::Float32); + ASSERT_EQ(tensor_type->Layout().Rank(), 1); + ASSERT_EQ(tensor_type->Layout().Dimensions()[0], + kTensorType.Layout().Dimensions()[0]); + ASSERT_FALSE(tensor_type->Layout().HasStrides()); + + auto size = tensor_buffer->Size(); + ASSERT_TRUE(size.ok()); + ASSERT_EQ(*size, sizeof(kTensorData)); + + auto offset = tensor_buffer->Offset(); + ASSERT_TRUE(offset.ok()); + ASSERT_EQ(*offset, 0); + + { + auto lock_and_addr = litert::TensorBufferScopedLock::Create(*tensor_buffer); + ASSERT_TRUE(lock_and_addr.ok()); + std::memcpy(lock_and_addr->second, kTensorData, sizeof(kTensorData)); + } + + { + auto lock_and_addr = litert::TensorBufferScopedLock::Create(*tensor_buffer); + ASSERT_TRUE(lock_and_addr.ok()); + ASSERT_EQ( + std::memcmp(lock_and_addr->second, kTensorData, sizeof(kTensorData)), + 0); + } +} + +TEST(TensorBuffer, DmaBuf) { + if (!litert::internal::DmaBufBuffer::IsSupported()) { + GTEST_SKIP() + << "DMA-BUF buffers are not supported on this platform; skipping " + "the test"; + } + + const litert::RankedTensorType kTensorType(::kTensorType); + constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeDmaBuf; + + auto tensor_buffer = litert::TensorBuffer::CreateManaged( + kTensorBufferType, kTensorType, sizeof(kTensorData)); + ASSERT_TRUE(tensor_buffer.ok()); + + auto tensor_buffer_type = tensor_buffer->BufferType(); + ASSERT_TRUE(tensor_buffer_type.ok()); + ASSERT_EQ(*tensor_buffer_type, kTensorBufferType); + + auto tensor_type = tensor_buffer->TensorType(); + ASSERT_TRUE(tensor_type.ok()); + + ASSERT_EQ(tensor_type->ElementType(), litert::ElementType::Float32); + ASSERT_EQ(tensor_type->Layout().Rank(), 1); + ASSERT_EQ(tensor_type->Layout().Dimensions()[0], + kTensorType.Layout().Dimensions()[0]); + ASSERT_FALSE(tensor_type->Layout().HasStrides()); + + auto size = tensor_buffer->Size(); + ASSERT_TRUE(size.ok()); + ASSERT_EQ(*size, sizeof(kTensorData)); + + auto offset = tensor_buffer->Offset(); + ASSERT_TRUE(offset.ok()); + ASSERT_EQ(*offset, 0); + + { + auto lock_and_addr = litert::TensorBufferScopedLock::Create(*tensor_buffer); + ASSERT_TRUE(lock_and_addr.ok()); + std::memcpy(lock_and_addr->second, kTensorData, sizeof(kTensorData)); + } + + { + auto lock_and_addr = litert::TensorBufferScopedLock::Create(*tensor_buffer); + ASSERT_TRUE(lock_and_addr.ok()); + ASSERT_EQ( + std::memcmp(lock_and_addr->second, kTensorData, sizeof(kTensorData)), + 0); + } +} + +TEST(TensorBuffer, FastRpc) { + if (!litert::internal::FastRpcBuffer::IsSupported()) { + GTEST_SKIP() + << "FastRPC buffers are not supported on this platform; skipping " + "the test"; + } + + const litert::RankedTensorType kTensorType(::kTensorType); + constexpr auto kTensorBufferType = kLiteRtTensorBufferTypeFastRpc; + + auto tensor_buffer = litert::TensorBuffer::CreateManaged( + kTensorBufferType, kTensorType, sizeof(kTensorData)); + ASSERT_TRUE(tensor_buffer.ok()); + + auto tensor_buffer_type = tensor_buffer->BufferType(); + ASSERT_TRUE(tensor_buffer_type.ok()); + ASSERT_EQ(*tensor_buffer_type, kTensorBufferType); + + auto tensor_type = tensor_buffer->TensorType(); + ASSERT_TRUE(tensor_type.ok()); + + ASSERT_EQ(tensor_type->ElementType(), litert::ElementType::Float32); + ASSERT_EQ(tensor_type->Layout().Rank(), 1); + ASSERT_EQ(tensor_type->Layout().Dimensions()[0], + kTensorType.Layout().Dimensions()[0]); + ASSERT_FALSE(tensor_type->Layout().HasStrides()); + + auto size = tensor_buffer->Size(); + ASSERT_TRUE(size.ok()); + ASSERT_EQ(*size, sizeof(kTensorData)); + + auto offset = tensor_buffer->Offset(); + ASSERT_TRUE(offset.ok()); + ASSERT_EQ(*offset, 0); + + { + auto lock_and_addr = litert::TensorBufferScopedLock::Create(*tensor_buffer); + ASSERT_TRUE(lock_and_addr.ok()); + std::memcpy(lock_and_addr->second, kTensorData, sizeof(kTensorData)); + } + + { + auto lock_and_addr = litert::TensorBufferScopedLock::Create(*tensor_buffer); + ASSERT_TRUE(lock_and_addr.ok()); + ASSERT_EQ( + std::memcmp(lock_and_addr->second, kTensorData, sizeof(kTensorData)), + 0); + } +} + +TEST(TensorBuffer, NotOwned) { + LiteRtTensorBuffer litert_tensor_buffer; + ASSERT_EQ(LiteRtCreateManagedTensorBuffer(kLiteRtTensorBufferTypeHostMemory, + &kTensorType, sizeof(kTensorData), + &litert_tensor_buffer), + kLiteRtStatusOk); + + litert::TensorBuffer tensor_buffer(litert_tensor_buffer, /*owned=*/false); + ASSERT_EQ(static_cast(tensor_buffer), + litert_tensor_buffer); + + LiteRtDestroyTensorBuffer(litert_tensor_buffer); +} diff --git a/tensorflow/lite/experimental/litert/core/BUILD b/tensorflow/lite/experimental/litert/core/BUILD new file mode 100644 index 00000000000000..14a04fcfe0d7e7 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/BUILD @@ -0,0 +1,297 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], +) + +cc_library( + name = "api_internal", + srcs = ["litert_common.cc"], + hdrs = [ + "//tensorflow/lite/experimental/litert/c:litert_common.h", + "//tensorflow/lite/experimental/litert/c:litert_logging.h", + "//tensorflow/lite/experimental/litert/c:litert_model.h", + "//tensorflow/lite/experimental/litert/c:litert_op_code.h", + "//tensorflow/lite/experimental/litert/c:litert_options.h", + "//tensorflow/lite/experimental/litert/c:litert_support.h", + "//tensorflow/lite/experimental/litert/cc:litert_support.h", + "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin.h", + ], + deps = [ + "//tensorflow/lite:builtin_ops", + "//tensorflow/lite/core/c:c_api_types", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "model", + srcs = [ + "model.cc", + ], + hdrs = [ + "model.h", + ], + deps = [ + ":api_internal", + "//tensorflow/compiler/mlir/lite/core:model_builder_base", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/experimental/litert/core/util:buffer_ref", + "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "option", + srcs = ["option.cc"], + hdrs = [ + "model.h", + ], + deps = [ + ":api_internal", + "//tensorflow/compiler/mlir/lite/core:model_builder_base", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/experimental/litert/core/util:buffer_ref", + "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "litert_model_init", + srcs = ["litert_model_init.cc"], + hdrs = ["litert_model_init.h"], + deps = [ + ":api_internal", + ":model", + ":option", + "//tensorflow/compiler/mlir/lite:allocation", + "//tensorflow/compiler/mlir/lite/core:model_builder_base", + "//tensorflow/lite:framework", + "//tensorflow/lite:stderr_reporter", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/experimental/litert/core/util:buffer_ref", + "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", + "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:absl_check", + "@flatbuffers//:runtime_cc", + ], +) + +cc_library( + name = "litert_model_serialize", + srcs = ["litert_model_serialize.cc"], + hdrs = ["litert_model_serialize.h"], + deps = [ + ":litert_model_init", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_test( + name = "litert_model_serialize_test", + srcs = ["litert_model_serialize_test.cc"], + data = [ + "//tensorflow/lite/experimental/litert/test:tflite_test_data", + ], + deps = [ + ":graph_tools", + ":litert_model_init", + ":litert_model_serialize", + ":model", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/core/util:buffer_ref", + "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", + "//tensorflow/lite/experimental/litert/test:common", + "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) + +cc_test( + name = "model_test", + srcs = ["model_test.cc"], + data = [ + "//tensorflow/lite/experimental/litert/test:tflite_test_data", + ], + tags = ["no_oss"], + deps = [ + ":api_internal", + ":graph_tools", + ":litert_model_init", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/core/util:buffer_ref", + "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", + "//tensorflow/lite/experimental/litert/test:common", + "//tensorflow/lite/schema:schema_fbs", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + ], +) + +cc_test( + name = "option_test", + srcs = ["option_test.cc"], + data = [ + "//tensorflow/lite/experimental/litert/test:tflite_test_data", + ], + tags = ["no_oss"], + deps = [ + ":api_internal", + ":graph_tools", + ":litert_model_init", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/test:common", + "//tensorflow/lite/schema:schema_fbs", + "@com_google_googletest//:gtest_main", + "@flatbuffers//:runtime_cc", + "@llvm-project//llvm:Support", + ], +) + +cc_library( + name = "graph_tools", + hdrs = [ + "graph_tools.h", + ], + deps = [ + ":api_internal", + "//tensorflow/compiler/mlir/lite/core:model_builder_base", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/experimental/litert/core/util:buffer_ref", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + ], +) + +cc_library( + name = "dynamic_loading", + srcs = ["dynamic_loading.cc"], + hdrs = ["dynamic_loading.h"], + linkopts = ["-ldl"], + deps = [ + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + ], +) + +# copybara:uncomment_begin(no OSS for unique-test-directory) +# cc_test( +# name = "dynamic_loading_test", +# srcs = ["dynamic_loading_test.cc"], +# tags = [ +# # Sanitizer runtimes are incompatible with RTLD_DEEPBIND. +# "noasan", +# "nomsan", +# "nosan", +# ], +# deps = [ +# ":dynamic_loading", +# "@com_google_googletest//:gtest_main", +# "//testing/base/public:unique-test-directory", +# "@com_google_absl//absl/strings:string_view", +# "//tensorflow/lite/experimental/litert/c:litert_c_api", +# "//tensorflow/lite/experimental/litert/test:common", +# ], +# ) +# copybara:uncomment_end + +cc_library( + name = "tensor_buffer", + srcs = [ + "ahwb_buffer.cc", + "dmabuf_buffer.cc", + "event.cc", + "fastrpc_buffer.cc", + "ion_buffer.cc", + "tensor_buffer.cc", + ], + hdrs = [ + "ahwb_buffer.h", + "dmabuf_buffer.h", + "event.h", + "fastrpc_buffer.h", + "ion_buffer.h", + "tensor_buffer.h", + "//tensorflow/lite/experimental/litert/c:litert_event.h", + "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer.h", + "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer_requirements.h", + ], + deps = [ + ":utils", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "utils", + srcs = [ + "utils.cc", + ], + hdrs = [ + "utils.h", + ], + deps = [ + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "tfl_utils", + srcs = [ + "tfl_utils.cc", + ], + hdrs = [ + "tfl_utils.h", + ], + deps = [ + "//tensorflow/lite/c:c_api", + "//tensorflow/lite/c:c_api_opaque", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/c:common", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) diff --git a/tensorflow/lite/experimental/litert/core/ahwb_buffer.cc b/tensorflow/lite/experimental/litert/core/ahwb_buffer.cc new file mode 100644 index 00000000000000..622a419ce78919 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/ahwb_buffer.cc @@ -0,0 +1,111 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/core/ahwb_buffer.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_event.h" + +namespace litert { +namespace internal { + +bool AhwbBuffer::IsSupported() { +#if LITERT_HAS_AHWB_SUPPORT + return true; +#else + return false; +#endif +} + +absl::StatusOr AhwbBuffer::Alloc(size_t size) { +#if LITERT_HAS_AHWB_SUPPORT + AHardwareBuffer* ahwb; + AHardwareBuffer_Desc ahwb_desc = { + .width = static_cast(size), + .height = 1, + .layers = 1, + .format = AHARDWAREBUFFER_FORMAT_BLOB, + .usage = AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY | + AHARDWAREBUFFER_USAGE_CPU_READ_RARELY | + AHARDWAREBUFFER_USAGE_GPU_DATA_BUFFER}; + if (AHardwareBuffer_allocate(&ahwb_desc, &ahwb) != 0) { + return absl::InternalError("Failed to allocate AHWB"); + } + return AhwbBuffer{/*.ahwb=*/ahwb}; +#else + return absl::InternalError( + "AHardwareBuffers are not supported on this platform"); +#endif // LITERT_HAS_AHWB_SUPPORT +} + +void AhwbBuffer::Free(AHardwareBuffer* ahwb) { +#if LITERT_HAS_AHWB_SUPPORT + AHardwareBuffer_release(ahwb); +#endif +} + +absl::StatusOr AhwbBuffer::GetSize(AHardwareBuffer* ahwb) { +#if LITERT_HAS_AHWB_SUPPORT + AHardwareBuffer_Desc ahwb_desc; + AHardwareBuffer_describe(ahwb, &ahwb_desc); + return static_cast(ahwb_desc.width) * ahwb_desc.height * + ahwb_desc.layers; +#else + return absl::InternalError( + "AHardwareBuffers are not supported on this platform"); +#endif // LITERT_HAS_AHWB_SUPPORT +} + +absl::StatusOr AhwbBuffer::Lock(AHardwareBuffer* ahwb, + LiteRtEvent event) { +#if LITERT_HAS_AHWB_SUPPORT + int fence = -1; + if (event) { + if (auto status = LiteRtEventGetSyncFenceFd(event, &fence); + status != kLiteRtStatusOk) { + return absl::InternalError("Failed to get sync fence fd from event"); + } + } + void* host_addr; + if (AHardwareBuffer_lock(ahwb, + AHARDWAREBUFFER_USAGE_CPU_READ_RARELY | + AHARDWAREBUFFER_USAGE_CPU_WRITE_RARELY, + fence, /*rect=*/nullptr, &host_addr) != 0) { + return absl::InternalError("Failed to lock AHWB"); + } + return host_addr; +#else + return absl::InternalError( + "AHardwareBuffers are not supported on this platform"); +#endif +} + +absl::Status AhwbBuffer::Unlock(AHardwareBuffer* ahwb) { +#if LITERT_HAS_AHWB_SUPPORT + if (AHardwareBuffer_unlock(ahwb, /*fence=*/nullptr) != 0) { + return absl::InternalError("Failed to unlock AHWB"); + } + return {}; +#else + return absl::InternalError( + "AHardwareBuffers are not supported on this platform"); +#endif +} + +} // namespace internal +} // namespace litert diff --git a/tensorflow/lite/experimental/litert/core/ahwb_buffer.h b/tensorflow/lite/experimental/litert/core/ahwb_buffer.h new file mode 100644 index 00000000000000..9053fbb409390c --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/ahwb_buffer.h @@ -0,0 +1,54 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_AHWB_BUFFER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_AHWB_BUFFER_H_ + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_event.h" + +#if LITERT_HAS_AHWB_SUPPORT +#include +#else +// Define a place holder AHardwareBuffer struct just to enable compilation. +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus +typedef struct AHardwareBuffer AHardwareBuffer; +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // LITERT_HAS_AHWB_SUPPORT + +#include "absl/status/statusor.h" + +namespace litert { +namespace internal { + +struct AhwbBuffer { + AHardwareBuffer* ahwb; + + static bool IsSupported(); + static absl::StatusOr Alloc(size_t size); + static void Free(AHardwareBuffer* ahwb); + static absl::StatusOr GetSize(AHardwareBuffer* ahwb); + static absl::StatusOr Lock(AHardwareBuffer* ahwb, + LiteRtEvent event = nullptr); + static absl::Status Unlock(AHardwareBuffer* ahwb); +}; + +} // namespace internal +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_AHWB_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/core/compiler_plugin/BUILD b/tensorflow/lite/experimental/litert/core/compiler_plugin/BUILD new file mode 100644 index 00000000000000..849d940deae3ff --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/compiler_plugin/BUILD @@ -0,0 +1,109 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], +) + +cc_library( + name = "compiler_plugin_hdr", + hdrs = ["compiler_plugin.h"], + deps = [ + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/core:dynamic_loading", + "//tensorflow/lite/experimental/litert/core:model", + "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin_api", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "compiler_plugin", + srcs = ["compiler_plugin.cc"], + hdrs = ["compiler_plugin.h"], + deps = [ + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/core:dynamic_loading", + "//tensorflow/lite/experimental/litert/core:model", + "//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin_api", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +# copybara:uncomment_begin(no OSS for unique-test-directory) +# cc_test( +# name = "compiler_plugin_test", +# srcs = ["compiler_plugin_test.cc"], +# data = [ +# "//tensorflow/lite/experimental/litert/test:tflite_test_data", +# "//tensorflow/lite/experimental/litert/vendors/examples:example_plugin_so", +# ], +# tags = [ +# # Sanitizer runtimes are incompatible with RTLD_DEEPBIND. +# "noasan", +# "nomsan", +# "nosan", +# ], +# deps = [ +# ":compiler_plugin", +# "@com_google_googletest//:gtest_main", +# "//testing/base/public:unique-test-directory", +# "@com_google_absl//absl/strings:string_view", +# "//tensorflow/lite/experimental/litert/core:graph_tools", +# "//tensorflow/lite/experimental/litert/test:common", +# "//tensorflow/lite/experimental/litert/tools:dump", +# ], +# ) +# copybara:uncomment_end + +cc_library( + name = "algo", + srcs = ["algo.cc"], + hdrs = ["algo.h"], + deps = [ + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/core:model", + "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/log:check", + "@llvm-project//llvm:Support", + ], +) + +cc_test( + name = "algo_test", + srcs = ["algo_test.cc"], + data = [ + "//tensorflow/lite/experimental/litert/test:tflite_test_data", + ], + deps = [ + ":algo", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/core:api_internal", + "//tensorflow/lite/experimental/litert/core:graph_tools", + "//tensorflow/lite/experimental/litert/core:model", + "//tensorflow/lite/experimental/litert/test:common", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/core/algo.h b/tensorflow/lite/experimental/litert/core/compiler_plugin/algo.cc similarity index 62% rename from tensorflow/compiler/mlir/lite/experimental/lrt/core/algo.h rename to tensorflow/lite/experimental/litert/core/compiler_plugin/algo.cc index d489b2d287c652..8a9a448dc0f433 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/core/algo.h +++ b/tensorflow/lite/experimental/litert/core/compiler_plugin/algo.cc @@ -12,27 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CORE_ALGO_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CORE_ALGO_H_ +#include "tensorflow/lite/experimental/litert/core/compiler_plugin/algo.h" #include #include #include +#include #include #include #include -#include "absl/log/check.h" #include "llvm/ADT/MapVector.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/model.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/core/model.h" #include "tensorflow/lite/schema/schema_generated.h" -// NOLINTBEGIN - -namespace algo { +namespace litert::internal { +namespace { // // flatlist to partition(s) @@ -40,18 +38,21 @@ namespace algo { class DisjointSets { public: - static std::vector> GetPartitionsFromFlatList( - const std::vector& flat_op_list); + static std::vector> GetPartitionsFromFlatList( + const std::vector& flat_op_list); private: - void Insert(LrtOp op, LrtOp parent); - std::vector> GetBuckets(); - LrtOp GetBucket(LrtOp op); - llvm::MapVector map_; + void Insert(LiteRtOp op, LiteRtOp parent); + std::vector> GetBuckets(); + LiteRtOp GetBucket(LiteRtOp op); + // NOLINTBEGIN + llvm::MapVector map_; + // NOLINTEND }; -inline std::vector> DisjointSets::GetPartitionsFromFlatList( - const std::vector& flat_op_list) { +inline std::vector> +DisjointSets::GetPartitionsFromFlatList( + const std::vector& flat_op_list) { DisjointSets disjoint_sets; for (auto* op : flat_op_list) { disjoint_sets.map_[op] = op; @@ -71,7 +72,7 @@ inline std::vector> DisjointSets::GetPartitionsFromFlatList( return disjoint_sets.GetBuckets(); } -inline void DisjointSets::Insert(LrtOp op, LrtOp parent) { +inline void DisjointSets::Insert(LiteRtOp op, LiteRtOp parent) { auto* parent_bucket = GetBucket(parent); auto* op_bucket = GetBucket(op); if (op_bucket == parent_bucket) { @@ -81,19 +82,21 @@ inline void DisjointSets::Insert(LrtOp op, LrtOp parent) { } // Get all disjoint sets. -inline std::vector> DisjointSets::GetBuckets() { - std::unordered_map> invert_map; +inline std::vector> DisjointSets::GetBuckets() { + // NOLINTBEGIN + std::unordered_map> invert_map; + // NOLINTEND for (const auto& entry : map_) { auto* bucket = GetBucket(entry.first); if (invert_map.find(bucket) == invert_map.end()) { - invert_map.insert_or_assign(bucket, std::vector{}); + invert_map.insert_or_assign(bucket, std::vector{}); } invert_map[bucket].push_back(entry.first); } - std::vector> res; + std::vector> res; res.reserve(invert_map.size()); for (auto& entry : invert_map) { @@ -105,7 +108,7 @@ inline std::vector> DisjointSets::GetBuckets() { // Gets the pointer which serves as the key for given ops bucket. Collapses // paths to amortize. -inline LrtOp DisjointSets::GetBucket(LrtOp op) { +inline LiteRtOp DisjointSets::GetBucket(LiteRtOp op) { auto* parent = map_[op]; if (op != parent) { parent = GetBucket(parent); @@ -120,21 +123,21 @@ inline LrtOp DisjointSets::GetBucket(LrtOp op) { // TODO: b/365339578 - Move helpers from algo.h to the internal model library. -inline void CloneOpData(const LrtOpT& old_op, LrtOpT& new_op) { +inline void CloneOpData(const LiteRtOpT& old_op, LiteRtOpT& new_op) { // TODO: b/365339578 - Support options in op clone. new_op.op_code = old_op.op_code; } -inline void CloneTensorData(const LrtTensorT& old_tensor, - LrtTensorT& new_tensor) { +inline void CloneTensorData(const LiteRtTensorT& old_tensor, + LiteRtTensorT& new_tensor) { new_tensor.type_id = old_tensor.type_id; new_tensor.type_detail = old_tensor.type_detail; - new_tensor.buffer.fb_buffer = std::make_unique(); + new_tensor.weights.fb_buffer = std::make_unique(); } -inline std::optional FindUseInd(LrtTensor tensor, - LrtOp user) { - for (lrt_param_index_t i = 0; i < tensor->users.size(); ++i) { +inline std::optional FindUseInd(LiteRtTensor tensor, + LiteRtOp user) { + for (LiteRtParamIndex i = 0; i < tensor->users.size(); ++i) { if (tensor->users[i] == user) { return i; } @@ -142,7 +145,7 @@ inline std::optional FindUseInd(LrtTensor tensor, return std::nullopt; } -inline void EraseUse(LrtTensor tensor, lrt_param_index_t use_ind) { +inline void EraseUse(LiteRtTensor tensor, LiteRtParamIndex use_ind) { if (use_ind < 0 || use_ind >= tensor->users.size()) { return; } @@ -152,58 +155,58 @@ inline void EraseUse(LrtTensor tensor, lrt_param_index_t use_ind) { tensor->user_arg_inds.pop_back(); } -inline void EraseUse(LrtTensor tensor, LrtOp user) { +inline void EraseUse(LiteRtTensor tensor, LiteRtOp user) { auto use_ind = FindUseInd(tensor, user); if (!use_ind.has_value()) { - _LRT_D_MSG("Trying to erase from tensor that doesn't use.") + _LITERT_D_MSG("Trying to erase from tensor that doesn't use.") return; } EraseUse(tensor, use_ind.value()); } // Push tensor to the end of ops arguments. -inline void AddUse(LrtTensorT& tensor, LrtOpT& op) { +inline void AddUse(LiteRtTensorT& tensor, LiteRtOpT& op) { op.inputs.push_back(&tensor); tensor.users.push_back(&op); tensor.user_arg_inds.push_back(op.inputs.size() - 1); } -inline void AddOutput(LrtOpT& op, LrtTensorT& tensor) { - DCHECK(tensor.defining_op == nullptr); +inline void AddOutput(LiteRtOpT& op, LiteRtTensorT& tensor) { op.outputs.push_back(&tensor); tensor.defining_op = &op; tensor.defining_op_out_ind = op.outputs.size() - 1; } -inline LrtTensor RequestNewTensor(LrtSubgraph subgraph, - const LrtTensorT& like) { +inline LiteRtTensor RequestNewTensor(LiteRtSubgraph subgraph, + const LiteRtTensorT& like) { auto& new_tensor = subgraph->tensors_storage.emplace_back(); CloneTensorData(like, new_tensor); return &new_tensor; } -inline LrtTensor RequestNewInput(LrtSubgraph subgraph, const LrtTensorT& like) { +inline LiteRtTensor RequestNewInput(LiteRtSubgraph subgraph, + const LiteRtTensorT& like) { auto new_tensor = RequestNewTensor(subgraph, like); subgraph->inputs.push_back(new_tensor); return new_tensor; } -inline LrtOp RequestNewOp(LrtSubgraph subgraph, const LrtOpT& like) { +inline LiteRtOp RequestNewOp(LiteRtSubgraph subgraph, const LiteRtOpT& like) { auto& new_op = subgraph->ops_storage.emplace_back(); CloneOpData(like, new_op); return &new_op; } -inline void AddOutput(LrtSubgraph subgraph, LrtTensor tensor) { +inline void AddOutput(LiteRtSubgraph subgraph, LiteRtTensor tensor) { subgraph->outputs.push_back(tensor); } -inline bool IsOutput(const LrtSubgraphT& subgraph, LrtTensor tensor) { +inline bool IsOutput(const LiteRtSubgraphT& subgraph, LiteRtTensor tensor) { return std::count(subgraph.outputs.begin(), subgraph.outputs.end(), tensor) > 0; } -inline void UpdateReferences(LrtSubgraphT& subgraph) { +inline void UpdateReferences(LiteRtSubgraphT& subgraph) { subgraph.tensors.clear(); subgraph.ops.clear(); for (auto& tensor : subgraph.tensors_storage) { @@ -214,7 +217,7 @@ inline void UpdateReferences(LrtSubgraphT& subgraph) { } } -inline void Drop(LrtOpT& op) { +inline void Drop(LiteRtOpT& op) { for (auto tensor : op.inputs) { EraseUse(tensor, &op); } @@ -226,7 +229,7 @@ inline void Drop(LrtOpT& op) { } // TODO expand dead code elimination to work recursively. This is a very simple. -inline void DCE(LrtSubgraphT& subgraph) { +inline void DCE(LiteRtSubgraphT& subgraph) { auto& ops = subgraph.ops_storage; for (auto it = ops.begin(); it != ops.end();) { if (it->inputs.empty() && it->outputs.empty()) { @@ -236,8 +239,11 @@ inline void DCE(LrtSubgraphT& subgraph) { } } - std::set inputs(subgraph.inputs.begin(), subgraph.inputs.end()); - std::set outputs(subgraph.outputs.begin(), subgraph.outputs.end()); + // NOLINTBEGIN + std::set inputs(subgraph.inputs.begin(), subgraph.inputs.end()); + std::set outputs(subgraph.outputs.begin(), + subgraph.outputs.end()); + // NOLINTEND auto& tensors = subgraph.tensors_storage; for (auto it = tensors.begin(); it != tensors.end();) { @@ -262,24 +268,28 @@ class GraphSlicer { // Slices "partitions" from "root" into the empty subgraph "slice". Assumes // the partition is a valid sub-DAG, and replaces it witha single // tfl.custom_op in "root". A reference to that op is returned. - static LrtOp SlicePartitionFromGraph(LrtSubgraphT& root, LrtSubgraph slice, - std::vector& partition); + static LiteRtOp SlicePartitionFromGraph(LiteRtSubgraphT& root, + LiteRtSubgraph slice, + std::vector& partition); private: - explicit GraphSlicer(LrtSubgraph slice) : slice_(slice) {} + explicit GraphSlicer(LiteRtSubgraph slice) : slice_(slice) {} - void CloneInto(const LrtOpT& op); + void CloneInto(const LiteRtOpT& op); - void RerouteTensorsThroughCustomOp(const LrtSubgraphT& root); + void RerouteTensorsThroughCustomOp(const LiteRtSubgraphT& root); - LrtSubgraph slice_; - // maps tensor in old subgraph to tensor in new subgraph. - llvm::MapVector tensor_map_; - LrtOp hal_cal_op_ = nullptr; + LiteRtSubgraph slice_; + // Maps tensor in old subgraph to tensor in new subgraph. + // NOLINTBEGIN + llvm::MapVector tensor_map_; + // NOLINTEND + LiteRtOp hal_cal_op_ = nullptr; }; -inline LrtOp GraphSlicer::SlicePartitionFromGraph( - LrtSubgraphT& root, LrtSubgraph slice, std::vector& partition) { +inline LiteRtOp GraphSlicer::SlicePartitionFromGraph( + LiteRtSubgraphT& root, LiteRtSubgraph slice, + std::vector& partition) { GraphSlicer slicer(slice); for (auto* op : partition) { @@ -293,7 +303,7 @@ inline LrtOp GraphSlicer::SlicePartitionFromGraph( // Reuse the storage from the last op in partition to maintain // toplogical order. slicer.hal_cal_op_ = partition.back(); - slicer.hal_cal_op_->op_code = kLrtOpCodeTflCustom; + slicer.hal_cal_op_->op_code = kLiteRtOpCodeTflCustom; UpdateReferences(*slicer.slice_); slicer.RerouteTensorsThroughCustomOp(root); @@ -302,10 +312,8 @@ inline LrtOp GraphSlicer::SlicePartitionFromGraph( return slicer.hal_cal_op_; } -// TODO replace this with iteration order sensitve one and fix the reversered -// arg order issue inline void GraphSlicer::RerouteTensorsThroughCustomOp( - const LrtSubgraphT& root) { + const LiteRtSubgraphT& root) { for (auto& [old_tensor, new_tensor] : tensor_map_) { // Reroute tensors which need to be passed into the scope of the new // subgraph to inputs of the custom op. @@ -317,20 +325,18 @@ inline void GraphSlicer::RerouteTensorsThroughCustomOp( // Reroute custom op as the definer of tensors within the removed partition // and referenced latern in the root graph. if (!old_tensor->users.empty() || IsOutput(root, old_tensor)) { - DCHECK(old_tensor->defining_op == nullptr) - << "Defining op should have been removed from the graph"; AddOutput(*hal_cal_op_, *old_tensor); AddOutput(slice_, new_tensor); } } } -inline void GraphSlicer::CloneInto(const LrtOpT& old_op) { +inline void GraphSlicer::CloneInto(const LiteRtOpT& old_op) { auto& new_op = *RequestNewOp(slice_, old_op); for (int i = 0; i < old_op.inputs.size(); ++i) { auto old_input = old_op.inputs[i]; - LrtTensor new_input; + LiteRtTensor new_input; if (tensor_map_.contains(old_input)) { // If old_input is already in the map then map[input] is its cloned @@ -356,8 +362,16 @@ inline void GraphSlicer::CloneInto(const LrtOpT& old_op) { } } -} // namespace algo +} // namespace + +std::vector> GroupPartitions( + const std::vector& ops) { + return DisjointSets::GetPartitionsFromFlatList(ops); +} -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CORE_ALGO_H_ +LiteRtOp OutlinePartition(LiteRtSubgraphT& root, LiteRtSubgraph slice, + std::vector& partition) { + return GraphSlicer::SlicePartitionFromGraph(root, slice, partition); +} -// NOLINTEND +} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/compiler_plugin/algo.h b/tensorflow/lite/experimental/litert/core/compiler_plugin/algo.h new file mode 100644 index 00000000000000..6ec06109df9773 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/compiler_plugin/algo.h @@ -0,0 +1,39 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_COMPILER_PLUGIN_ALGO_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_COMPILER_PLUGIN_ALGO_H_ + +#include + +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/core/model.h" + +namespace litert::internal { + +// Identifies sub-DAGs of ops connected w.r.t. the use-def chain. Expects +// all "ops" belong to the same Subgraph. The ops in the input +// and output will always be the same. +std::vector> GroupPartitions( + const std::vector& ops); + +// Outlines "partitin" from "root" into the empty subgraph "slice". Assumes +// the partition is a valid sub-DAG, and replaces it witha single +// tfl.custom_op in "root". A reference to that op is returned. +LiteRtOp OutlinePartition(LiteRtSubgraphT& root, LiteRtSubgraph slice, + std::vector& partition); + +} // namespace litert::internal + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_COMPILER_PLUGIN_ALGO_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/core/algo_test.cc b/tensorflow/lite/experimental/litert/core/compiler_plugin/algo_test.cc similarity index 67% rename from tensorflow/compiler/mlir/lite/experimental/lrt/core/algo_test.cc rename to tensorflow/lite/experimental/litert/core/compiler_plugin/algo_test.cc index 90c1b55e6e2abf..0bb66414bd7462 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/core/algo_test.cc +++ b/tensorflow/lite/experimental/litert/core/compiler_plugin/algo_test.cc @@ -12,33 +12,33 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/algo.h" +#include "tensorflow/lite/experimental/litert/core/compiler_plugin/algo.h" #include #include #include #include -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/graph_tools.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/test_data/test_data_util.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/core/graph_tools.h" +#include "tensorflow/lite/experimental/litert/core/model.h" +#include "tensorflow/lite/experimental/litert/test/common.h" namespace { -using ::algo::DisjointSets; -using ::algo::GraphSlicer; +using ::litert::internal::GroupPartitions; +using ::litert::internal::OutlinePartition; // NOLINTBEGIN -bool HasValidGeneralTopology(LrtSubgraph subgraph) { +bool HasValidGeneralTopology(LiteRtSubgraph subgraph) { if (!::graph_tools::ValidateTopology(subgraph->ops)) { - _LRT_D_MSG("Failed validate op tolopology"); + _LITERT_D_MSG("Failed validate op tolopology"); return false; } - std::unordered_set implied_subgraph_outs; + std::unordered_set implied_subgraph_outs; for (auto tensor : subgraph->tensors) { if (tensor->users.empty()) { implied_subgraph_outs.insert(tensor); @@ -46,33 +46,33 @@ bool HasValidGeneralTopology(LrtSubgraph subgraph) { } if (implied_subgraph_outs.size() != subgraph->outputs.size()) { - _LRT_D_MSG("Outs not same size"); + _LITERT_D_MSG("Outs not same size"); return false; } for (auto tensor : subgraph->outputs) { if (implied_subgraph_outs.find(tensor) == implied_subgraph_outs.end()) { - _LRT_D_MSG("Mismatched subgraph outs"); + _LITERT_D_MSG("Mismatched subgraph outs"); return false; } } - std::unordered_set implied_subgraph_ins; + std::unordered_set implied_subgraph_ins; for (auto tensor : subgraph->tensors) { if (tensor->defining_op == nullptr && - tensor->buffer.fb_buffer->data.empty()) { + tensor->weights.fb_buffer->data.empty()) { implied_subgraph_ins.insert(tensor); } } if (implied_subgraph_ins.size() != subgraph->inputs.size()) { - _LRT_D_MSG("Ins not same size"); + _LITERT_D_MSG("Ins not same size"); return false; } for (auto tensor : subgraph->inputs) { if (implied_subgraph_ins.find(tensor) == implied_subgraph_ins.end()) { - _LRT_D_MSG("Mismatched subgraph ins"); + _LITERT_D_MSG("Mismatched subgraph ins"); return false; } } @@ -82,7 +82,7 @@ bool HasValidGeneralTopology(LrtSubgraph subgraph) { // NOLINTEND TEST(TestPartitionsFromFlatList, SimpleMultiOp) { - auto model = LoadTestFileModel("simple_multi_op.tflite"); + auto model = litert::testing::LoadTestFileModel("simple_multi_op.tflite"); ASSERT_RESULT_OK_ASSIGN(auto subgraph, graph_tools::GetSubgraph(model.get())); ASSERT_RESULT_OK_ASSIGN(auto ops, graph_tools::GetSubgraphOps(subgraph)); @@ -95,11 +95,11 @@ TEST(TestPartitionsFromFlatList, SimpleMultiOp) { // return 3 { - std::vector partition; + std::vector partition; partition.push_back(ops[1]); partition.push_back(ops[2]); - auto partitions = DisjointSets::GetPartitionsFromFlatList(partition); + auto partitions = GroupPartitions(partition); ASSERT_EQ(partitions.size(), 1); ASSERT_EQ(partitions.front().size(), 2); @@ -108,11 +108,11 @@ TEST(TestPartitionsFromFlatList, SimpleMultiOp) { } { - std::vector partition; + std::vector partition; partition.push_back(ops[1]); partition.push_back(ops[3]); - auto partitions = DisjointSets::GetPartitionsFromFlatList(partition); + auto partitions = GroupPartitions(partition); ASSERT_EQ(partitions.size(), 2); ASSERT_EQ(partitions.front().size(), 1); ASSERT_EQ(partitions.back().size(), 1); @@ -120,26 +120,27 @@ TEST(TestPartitionsFromFlatList, SimpleMultiOp) { auto p1_op_code = partitions.front().front()->op_code; auto p2_op_code = partitions.back().front()->op_code; - ASSERT_TRUE( - (p1_op_code == kLrtOpCodeTflMul && p2_op_code == kLrtOpCodeTflAdd) || - (p1_op_code == kLrtOpCodeTflAdd && p2_op_code == kLrtOpCodeTflMul)); + ASSERT_TRUE((p1_op_code == kLiteRtOpCodeTflMul && + p2_op_code == kLiteRtOpCodeTflAdd) || + (p1_op_code == kLiteRtOpCodeTflAdd && + p2_op_code == kLiteRtOpCodeTflMul)); } { - std::vector partition; + std::vector partition; - auto partitions = DisjointSets::GetPartitionsFromFlatList(partition); + auto partitions = GroupPartitions(partition); ASSERT_EQ(partitions.size(), 0); } { - std::vector partition; + std::vector partition; partition.push_back(ops[0]); partition.push_back(ops[1]); partition.push_back(ops[2]); partition.push_back(ops[3]); - auto partitions = DisjointSets::GetPartitionsFromFlatList(partition); + auto partitions = GroupPartitions(partition); ASSERT_EQ(partitions.size(), 1); ASSERT_EQ(partitions.front().size(), 4); @@ -151,7 +152,7 @@ TEST(TestPartitionsFromFlatList, SimpleMultiOp) { } TEST(TestSliceSubgraphSimpleMultiOp, OnePartition) { - auto model = LoadTestFileModel("simple_multi_op.tflite"); + auto model = litert::testing::LoadTestFileModel("simple_multi_op.tflite"); ASSERT_RESULT_OK_ASSIGN(auto subgraph, graph_tools::GetSubgraph(model.get())); @@ -164,13 +165,12 @@ TEST(TestSliceSubgraphSimpleMultiOp, OnePartition) { // 3 = tfl.add 2, 2 // return 3 - std::vector partition; + std::vector partition; partition.push_back(ops[1]); partition.push_back(ops[2]); - LrtSubgraph sliced_graph = &model->subgraphs.emplace_back(); - auto* hal_cal_op = - GraphSlicer::SlicePartitionFromGraph(*subgraph, sliced_graph, partition); + LiteRtSubgraph sliced_graph = &model->subgraphs.emplace_back(); + auto* hal_cal_op = OutlinePartition(*subgraph, sliced_graph, partition); ASSERT_TRUE(HasValidGeneralTopology(sliced_graph)); ASSERT_TRUE(HasValidGeneralTopology(subgraph)); @@ -179,16 +179,16 @@ TEST(TestSliceSubgraphSimpleMultiOp, OnePartition) { graph_tools::GetSubgraphOps(subgraph)); ASSERT_EQ(edited_subgraph_ops.size(), 3); - ASSERT_EQ(edited_subgraph_ops[0]->op_code, kLrtOpCodeTflAdd); - ASSERT_EQ(edited_subgraph_ops[1]->op_code, kLrtOpCodeTflCustom); - ASSERT_EQ(edited_subgraph_ops[2]->op_code, kLrtOpCodeTflAdd); + ASSERT_EQ(edited_subgraph_ops[0]->op_code, kLiteRtOpCodeTflAdd); + ASSERT_EQ(edited_subgraph_ops[1]->op_code, kLiteRtOpCodeTflCustom); + ASSERT_EQ(edited_subgraph_ops[2]->op_code, kLiteRtOpCodeTflAdd); ASSERT_RESULT_OK_ASSIGN(auto sliced_subgraph_ops, graph_tools::GetSubgraphOps(sliced_graph)); ASSERT_EQ(sliced_subgraph_ops.size(), 2); - ASSERT_EQ(sliced_subgraph_ops[0]->op_code, kLrtOpCodeTflMul); - ASSERT_EQ(sliced_subgraph_ops[1]->op_code, kLrtOpCodeTflMul); + ASSERT_EQ(sliced_subgraph_ops[0]->op_code, kLiteRtOpCodeTflMul); + ASSERT_EQ(sliced_subgraph_ops[1]->op_code, kLiteRtOpCodeTflMul); ASSERT_EQ(hal_cal_op, edited_subgraph_ops[1]); @@ -228,12 +228,12 @@ TEST(TestSliceSubgraphSimpleMultiOp, OnePartition) { ASSERT_EQ(sliced_subgraph_outputs.size(), 1); ASSERT_TRUE(graph_tools::MatchTensorDefiningOp( sliced_subgraph_outputs[0], 0, sliced_subgraph_ops.back())); - ASSERT_TRUE(graph_tools::MatchkTensorNoUses(sliced_subgraph_outputs[0])); + ASSERT_TRUE(graph_tools::MatchTensorNoUses(sliced_subgraph_outputs[0])); } } TEST(TestSliceSubgraphSimpleMultiOp, TwoPartitions) { - auto model = LoadTestFileModel("simple_multi_op.tflite"); + auto model = litert::testing::LoadTestFileModel("simple_multi_op.tflite"); ASSERT_RESULT_OK_ASSIGN(auto subgraph, graph_tools::GetSubgraph(model.get())); ASSERT_RESULT_OK_ASSIGN(auto ops, graph_tools::GetSubgraphOps(subgraph)); @@ -245,21 +245,21 @@ TEST(TestSliceSubgraphSimpleMultiOp, TwoPartitions) { // 3 = tfl.add 2, 2 // return 3 - std::vector partition_1; + std::vector partition_1; partition_1.push_back(ops[0]); - LrtSubgraph sliced_graph_1 = &model->subgraphs.emplace_back(); - GraphSlicer::SlicePartitionFromGraph(*subgraph, sliced_graph_1, partition_1); + LiteRtSubgraph sliced_graph_1 = &model->subgraphs.emplace_back(); + OutlinePartition(*subgraph, sliced_graph_1, partition_1); ASSERT_TRUE(HasValidGeneralTopology(sliced_graph_1)); ASSERT_TRUE(HasValidGeneralTopology(subgraph)); - std::vector partition_2; + std::vector partition_2; partition_2.push_back(ops[2]); partition_2.push_back(ops[3]); - LrtSubgraph sliced_graph_2 = &model->subgraphs.emplace_back(); - GraphSlicer::SlicePartitionFromGraph(*subgraph, sliced_graph_2, partition_2); + LiteRtSubgraph sliced_graph_2 = &model->subgraphs.emplace_back(); + OutlinePartition(*subgraph, sliced_graph_2, partition_2); ASSERT_TRUE(HasValidGeneralTopology(sliced_graph_2)); ASSERT_TRUE(HasValidGeneralTopology(subgraph)); @@ -268,16 +268,16 @@ TEST(TestSliceSubgraphSimpleMultiOp, TwoPartitions) { graph_tools::GetSubgraphOps(subgraph)); ASSERT_EQ(edited_subgraph_ops.size(), 3); - ASSERT_EQ(edited_subgraph_ops[0]->op_code, kLrtOpCodeTflCustom); - ASSERT_EQ(edited_subgraph_ops[1]->op_code, kLrtOpCodeTflMul); - ASSERT_EQ(edited_subgraph_ops[2]->op_code, kLrtOpCodeTflCustom); + ASSERT_EQ(edited_subgraph_ops[0]->op_code, kLiteRtOpCodeTflCustom); + ASSERT_EQ(edited_subgraph_ops[1]->op_code, kLiteRtOpCodeTflMul); + ASSERT_EQ(edited_subgraph_ops[2]->op_code, kLiteRtOpCodeTflCustom); { ASSERT_RESULT_OK_ASSIGN(auto sliced_ops, graph_tools::GetSubgraphOps(sliced_graph_1)); ASSERT_EQ(sliced_ops.size(), 1); - ASSERT_EQ(sliced_ops[0]->op_code, kLrtOpCodeTflAdd); + ASSERT_EQ(sliced_ops[0]->op_code, kLiteRtOpCodeTflAdd); } { @@ -285,8 +285,8 @@ TEST(TestSliceSubgraphSimpleMultiOp, TwoPartitions) { graph_tools::GetSubgraphOps(sliced_graph_2)); ASSERT_EQ(sliced_ops.size(), 2); - ASSERT_EQ(sliced_ops[0]->op_code, kLrtOpCodeTflMul); - ASSERT_EQ(sliced_ops[1]->op_code, kLrtOpCodeTflAdd); + ASSERT_EQ(sliced_ops[0]->op_code, kLiteRtOpCodeTflMul); + ASSERT_EQ(sliced_ops[1]->op_code, kLiteRtOpCodeTflAdd); } } diff --git a/tensorflow/lite/experimental/litert/core/compiler_plugin/compiler_plugin.cc b/tensorflow/lite/experimental/litert/core/compiler_plugin/compiler_plugin.cc new file mode 100644 index 00000000000000..4b7fe66996f11f --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/compiler_plugin/compiler_plugin.cc @@ -0,0 +1,284 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/core/compiler_plugin/compiler_plugin.h" + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_support.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/core/dynamic_loading.h" +#include "tensorflow/lite/experimental/litert/core/model.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h" + +namespace litert::internal { + +// +// CompiledResult +// + +std::string CompiledResult::BytesT::String() const { + return std::string(data, size); +} + +LiteRtResult CompiledResult::ByteCode() const { + BytesT byte_code; + LITERT_RETURN_RESULT_IF_NOT_OK( + allocating_plugin_api_.compiled_result_get_byte_code( + compiled_result_handle_, + reinterpret_cast(&byte_code.data), &byte_code.size), + BytesT); + return LiteRtResult::FromValue(byte_code); +} + +LiteRtResult CompiledResult::NumCalls() const { + LiteRtParamIndex call_idx; + LITERT_RETURN_RESULT_IF_NOT_OK( + allocating_plugin_api_.compiled_result_get_num_calls( + compiled_result_handle_, &call_idx), + LiteRtParamIndex); + return LiteRtResult::FromValue(call_idx); +} + +LiteRtResult CompiledResult::CallInfo( + LiteRtParamIndex call_idx) const { + BytesT call_info; + LITERT_RETURN_RESULT_IF_NOT_OK( + allocating_plugin_api_.compiled_result_get_call_info( + compiled_result_handle_, call_idx, + reinterpret_cast(&call_info.data), &call_info.size), + std::string); + return LiteRtResult::FromValue(call_info.String()); +} + +CompiledResult::~CompiledResult() { + allocating_plugin_api_.compiled_result_destroy(compiled_result_handle_); +} + +// +// CompilerPlugin +// + +namespace { + +#define RESOLVE_API_FUNC(ty, name, dest) \ + LITERT_RETURN_STATUS_IF_NOT_OK(ResolveLibSymbol(lib_handle, name, &dest)); + +LiteRtStatus ResolvePluginApi(void* lib_handle, + LiteRtCompilerPluginApi& result) { + RESOLVE_API_FUNC(LiteRtPluginApiSocManufacturer, + "LiteRtPluginSocManufacturer", result.soc_manufacturer); + RESOLVE_API_FUNC(LiteRtPluginApiNumSupportedModels, + "LiteRtPluginNumSupportedSocModels", + result.num_supported_models); + RESOLVE_API_FUNC(LiteRtPluginApiGetSupportedSocModel, + "LiteRtPluginGetSupportedSocModel", + result.get_supported_soc_model); + + RESOLVE_API_FUNC(LiteRtPluginApiInit, "LiteRtPluginInit", result.init); + RESOLVE_API_FUNC(LiteRtPluginApiDestroy, "LiteRtPluginDestroy", + result.destroy); + + RESOLVE_API_FUNC(LiteRtPluginApiPartitionModel, "LiteRtPluginPartitionModel", + result.partition_model); + RESOLVE_API_FUNC(LiteRtPluginApiCompile, "LiteRtPluginCompile", + result.compile); + + RESOLVE_API_FUNC(LiteRtCompiledResultApiDestroy, + "LiteRtCompiledResultDestroy", + result.compiled_result_destroy); + RESOLVE_API_FUNC(LiteRtCompiledResultApiGetByteCode, + "LiteRtCompiledResultGetByteCode", + result.compiled_result_get_byte_code); + RESOLVE_API_FUNC(LiteRtCompiledResultApiGetCallInfo, + "LiteRtCompiledResultGetCallInfo", + result.compiled_result_get_call_info); + RESOLVE_API_FUNC(LiteRtCompiledResultApiGetNumCalls, + "LiteRtCompiledResultGetNumCalls", + result.compiled_result_get_num_calls); + return kLiteRtStatusOk; +} + +std::vector GetSocModels(const LiteRtCompilerPluginApi& api, + LiteRtCompilerPlugin plugin_handle) { + std::vector soc_models; + const LiteRtParamIndex num_models = api.num_supported_models(plugin_handle); + for (LiteRtParamIndex i = 0; i < num_models; ++i) { + const char* model; + if (api.get_supported_soc_model(plugin_handle, i, &model) != + kLiteRtStatusOk) { + continue; + } + soc_models.push_back(std::string(model)); + } + return soc_models; +} + +} // namespace + +CompilerPlugin::ResultT CompilerPlugin::LoadPlugin( + const absl::string_view lib_path) { + LITERT_LOG(LITERT_INFO, "Loading plugin at: %s", lib_path.data()); + CompilerPlugin plugin; + + if (OpenLib(lib_path, &plugin.lib_handle_) != kLiteRtStatusOk) { + LITERT_LOG(LITERT_WARNING, "Failed to load plugin at: %s", lib_path.data()); + return ResultT::FromStatus(kLiteRtStatusErrorDynamicLoading); + } + + if (ResolvePluginApi(plugin.lib_handle_, plugin.plugin_api_) != + kLiteRtStatusOk) { + LITERT_LOG(LITERT_WARNING, "Failed to resolve plugin api at: %s", + lib_path.data()); + return ResultT::FromStatus(kLiteRtStatusErrorDynamicLoading); + } + + if (plugin.plugin_api_.init(&plugin.plugin_handle_) != kLiteRtStatusOk) { + LITERT_LOG(LITERT_WARNING, "Failed to initialize plugin at: %s", + lib_path.data()); + if (CloseLib(plugin.lib_handle_) != kLiteRtStatusOk) { + LITERT_LOG(LITERT_WARNING, "Failed to close loaded library at: %s", + lib_path.data()); + } + return ResultT::FromStatus(kLiteRtStatusErrorDynamicLoading); + } + + // This should never change throughout the lifetime of the compiler + // plugin so save to avoid recalling. + plugin.soc_models_ = GetSocModels(plugin.plugin_api_, plugin.plugin_handle_); + + return ResultT::TakeValue(std::move(plugin)); +} + +CompilerPlugin::ResultVecT CompilerPlugin::LoadPlugins( + absl::Span lib_search_paths) { + std::vector plugin_lib_paths; + for (auto search_path : lib_search_paths) { + LITERT_RETURN_RESULT_IF_NOT_OK( + FindLiteRtSharedLibs(search_path, plugin_lib_paths), VecT); + } + + VecT loaded_plugins; + loaded_plugins.reserve(lib_search_paths.size()); + + for (const auto& lib_path : plugin_lib_paths) { + LITERT_LOG(LITERT_INFO, "Loading plugin at: %s", lib_path.c_str()); + auto result = LoadPlugin(lib_path); + if (!result.HasValue()) { + continue; + } + loaded_plugins.push_back(std::move(result.Value())); + } + + return ResultVecT::TakeValue(std::move(loaded_plugins)); +} + +CompilerPlugin::CompilerPlugin(CompilerPlugin&& other) + : soc_models_(std::move(other.soc_models_)), + lib_handle_(other.lib_handle_), + plugin_api_(std::move(other.plugin_api_)), + plugin_handle_(other.plugin_handle_) { + other.soc_models_ = {}; + other.plugin_api_ = {}; + other.lib_handle_ = nullptr; + other.plugin_handle_ = nullptr; +} + +CompilerPlugin& CompilerPlugin::operator=(CompilerPlugin&& other) { + if (this != &other) { + soc_models_ = std::move(other.soc_models_); + other.soc_models_ = {}; + + lib_handle_ = other.lib_handle_; + other.lib_handle_ = nullptr; + + plugin_api_ = std::move(other.plugin_api_); + other.plugin_api_ = {}; + + plugin_handle_ = other.plugin_handle_; + other.plugin_handle_ = nullptr; + } + return *this; +} + +CompilerPlugin::~CompilerPlugin() { + if (plugin_handle_ != nullptr) { + plugin_api_.destroy(plugin_handle_); + } + if (lib_handle_ != nullptr) { + if (kLiteRtStatusOk != CloseLib(lib_handle_)) { + LITERT_LOG(LITERT_WARNING, "%s", "Failed to close shared library\n"); + } + } +} + +LiteRtResult> CompilerPlugin::PartitionModel( + const LiteRtModelT& model) { + LiteRtOpListT ops; + // TODO: Use const where appropriate in the C compiler plugin api. + LiteRtModel c_model = const_cast(&model); + LITERT_RETURN_RESULT_IF_NOT_OK( + plugin_api_.partition_model(plugin_handle_, c_model, &ops), + std::vector); + return LiteRtResult>::TakeValue(ops.Vec()); +} + +LiteRtStatus CompilerPlugin::Compile( + const absl::string_view soc_model, + const std::vector& partitions, std::ostream& byte_code_out, + std::vector& call_info_out) { + CompiledResult result = MakeResult(); + + // Compile given partitions into result. + // TODO: Use const where appropriate in the C compiler plugin api. + LiteRtSubgraphArray partitions_arr = + const_cast(partitions.data()); + LITERT_RETURN_STATUS_IF_NOT_OK( + plugin_api_.compile(plugin_handle_, soc_model.data(), partitions_arr, + partitions.size(), &result.compiled_result_handle_)); + + // Parse call info from the result. + { + LITERT_ASSIGN_OR_RETURN_STATUS(auto num_call, result.NumCalls()); + LITERT_ENSURE( + num_call == partitions.size(), kLiteRtStatusErrorRuntimeFailure, + "Plugin didn't return call info for each partition compiled.\n"); + for (int i = 0; i < num_call; ++i) { + LITERT_ASSIGN_OR_RETURN_STATUS(call_info_out.emplace_back(), + result.CallInfo(i)); + } + } + + // Parse byte code from result. + { + LITERT_ASSIGN_OR_RETURN_STATUS(const CompiledResult::BytesT byte_code, + result.ByteCode()); + LITERT_LOG(LITERT_INFO, "Compiled %d partitions in %lu bytes", + partitions.size(), byte_code.size); + byte_code_out.write(byte_code.data, byte_code.size); + } + + return kLiteRtStatusOk; +} + +} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/compiler_plugin/compiler_plugin.h b/tensorflow/lite/experimental/litert/core/compiler_plugin/compiler_plugin.h new file mode 100644 index 00000000000000..a0b7203383b415 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/compiler_plugin/compiler_plugin.h @@ -0,0 +1,131 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_COMPILER_PLUGIN_COMPILER_PLUGIN_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_COMPILER_PLUGIN_COMPILER_PLUGIN_H_ + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/core/model.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h" + +namespace litert::internal { + +class CompiledResult { + friend class CompilerPlugin; + struct BytesT { + const char* data; + size_t size; + + std::string String() const; + }; + + // Get the single module of compiled byte code. This contains the + // compilation result for all entry points. + LiteRtResult ByteCode() const; + + // Get information regarding the "ith" entry points in the compiled module. + // There will be oe entry point for each subgraph compiled for. + LiteRtResult CallInfo(LiteRtParamIndex call_idx) const; + + // Get the number of entry points in the compiled module. This will be equal + // to the number of subgraphs passed to the compilation step. + LiteRtResult NumCalls() const; + + explicit CompiledResult(const LiteRtCompilerPluginApi& allocating_plugin_api) + : allocating_plugin_api_(allocating_plugin_api) {} + + CompiledResult(CompiledResult&& other) = default; + CompiledResult& operator=(CompiledResult&& other) = default; + CompiledResult(const CompiledResult& other) = delete; + CompiledResult& operator=(const CompiledResult& other) = delete; + + ~CompiledResult(); + + LiteRtCompilerPluginApi allocating_plugin_api_; + LiteRtCompiledResult compiled_result_handle_ = nullptr; +}; + +// Syntatic sugar around dynamically loaded LiteRtCompilerPlugin libraries. +// TODO turn this into a general C++ wraper for the whole compiler plugin api. +class CompilerPlugin { + public: + using VecT = std::vector; + using ResultT = LiteRtResult; + using ResultVecT = LiteRtResult; + + // Get the manufacturer associated with this plugin. NOTE: SocManufacturer + // string returned by the underlying plugin are expected to have static + // lifetime. + absl::string_view SocManufacturer() const { + return plugin_api_.soc_manufacturer(); + } + + // Get list of unique soc models targetable by this plugin. + const std::vector& SocModels() const { return soc_models_; } + + // Selects ops for the plugin to compile. + LiteRtResult> PartitionModel(const LiteRtModelT& model); + + // Compile given LiteRtSubgraphs for target "soc_model". Write compiled byte + // code to the given stream. For each given subgraph, write opaque data about + // the corresponding entry point to the given "call_info_out". + LiteRtStatus Compile(absl::string_view soc_model, + const std::vector& partitions, + std::ostream& byte_code_out, + std::vector& call_info_out); + + // Search for shared library files with prefix "libLiteRtPlugin" in the + // directories passed through "lib_search_paths". Populates "loaded_plugins" + // with resolved plugin apis for each found library that can be succesfully + // loaded. Additionally initializes the compiler plugin instances + // and stores handle. + static ResultVecT LoadPlugins( + absl::Span lib_search_paths); + + CompilerPlugin(CompilerPlugin&& other); + CompilerPlugin& operator=(CompilerPlugin&& other); + CompilerPlugin(const CompilerPlugin& other) = delete; + CompilerPlugin& operator=(const CompilerPlugin& other) = delete; + + // Destroys any living `LiteRtCompilerPlugin` and frees reference + // to dynamically loaded library. + ~CompilerPlugin(); + + private: + static ResultT LoadPlugin(absl::string_view lib_path); + CompilerPlugin() = default; + + std::vector soc_models_; + void* lib_handle_ = nullptr; + LiteRtCompilerPluginApi plugin_api_ = {}; + LiteRtCompilerPlugin plugin_handle_ = nullptr; + + // Internal LiteRtCompiledResult wrapper. + + CompiledResult MakeResult() const { return CompiledResult(plugin_api_); } +}; + +} // namespace litert::internal + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_COMPILER_PLUGIN_COMPILER_PLUGIN_H_ diff --git a/tensorflow/lite/experimental/litert/core/compiler_plugin/compiler_plugin_test.cc b/tensorflow/lite/experimental/litert/core/compiler_plugin/compiler_plugin_test.cc new file mode 100644 index 00000000000000..23795cf1cf005a --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/compiler_plugin/compiler_plugin_test.cc @@ -0,0 +1,149 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/core/compiler_plugin/compiler_plugin.h" + +#include +#include +#include +#include + +#include +#include +#include "testing/base/public/unique-test-directory.h" +#include "absl/strings/string_view.h" +#include "tensorflow/lite/experimental/litert/core/graph_tools.h" +#include "tensorflow/lite/experimental/litert/test/common.h" +#include "tensorflow/lite/experimental/litert/tools/dump.h" + +namespace { + +using ::litert::internal::CompilerPlugin; +using ::litert::testing::TouchTestFile; + +constexpr absl::string_view kTestPluginSearchPath = + "third_party/tensorflow/lite/experimental/litert/vendors/examples"; + +constexpr absl::string_view kTestManufacturer = "ExampleSocManufacturer"; +constexpr absl::string_view kTestModels = "ExampleSocModel"; + +TEST(CompilerPluginTest, LoadTestPlugin) { + ASSERT_RESULT_OK_MOVE(CompilerPlugin::VecT plugins, + CompilerPlugin::LoadPlugins({kTestPluginSearchPath})); + + ASSERT_EQ(plugins.size(), 1); + EXPECT_EQ(plugins.front().SocManufacturer(), kTestManufacturer); + ASSERT_EQ(plugins.front().SocModels().size(), 1); + EXPECT_EQ(plugins.front().SocModels().front(), kTestModels); +} + +TEST(CompilerPluginTest, LoadTestPluginWithMalformed) { + const auto dir = testing::UniqueTestDirectory(); + TouchTestFile("notLibLiteRt.so", dir); + + ASSERT_RESULT_OK_MOVE(CompilerPlugin::VecT plugins, + CompilerPlugin::LoadPlugins({kTestPluginSearchPath})); + + ASSERT_EQ(plugins.size(), 1); + EXPECT_EQ(plugins.front().SocManufacturer(), kTestManufacturer); +} + +TEST(CompilerPluginTest, MultipleValidPlugins) { + ASSERT_RESULT_OK_MOVE(CompilerPlugin::VecT plugins, + CompilerPlugin::LoadPlugins( + {kTestPluginSearchPath, kTestPluginSearchPath})); + + ASSERT_EQ(plugins.size(), 2); + EXPECT_EQ(plugins.front().SocManufacturer(), kTestManufacturer); + EXPECT_EQ(plugins.back().SocManufacturer(), kTestManufacturer); +} + +TEST(CompilerPluginTest, MoveAssign) { + ASSERT_RESULT_OK_MOVE(CompilerPlugin::VecT plugins, + CompilerPlugin::LoadPlugins({kTestPluginSearchPath})); + + ASSERT_EQ(plugins.size(), 1); + EXPECT_EQ(plugins.front().SocManufacturer(), kTestManufacturer); + + CompilerPlugin other = std::move(plugins.front()); + + EXPECT_EQ(other.SocManufacturer(), kTestManufacturer); +} + +TEST(CompilerPluginTest, MoveConstruct) { + ASSERT_RESULT_OK_MOVE(CompilerPlugin::VecT plugins, + CompilerPlugin::LoadPlugins({kTestPluginSearchPath})); + + ASSERT_EQ(plugins.size(), 1); + EXPECT_EQ(plugins.front().SocManufacturer(), kTestManufacturer); + + CompilerPlugin other(std::move(plugins.front())); + + EXPECT_EQ(other.SocManufacturer(), kTestManufacturer); +} + +TEST(CompilerPluginTest, SocModels) { + ASSERT_RESULT_OK_MOVE(CompilerPlugin::VecT plugins, + CompilerPlugin::LoadPlugins({kTestPluginSearchPath})); + ASSERT_EQ(plugins.size(), 1); + EXPECT_EQ(plugins.front().SocManufacturer(), kTestManufacturer); + + EXPECT_THAT(plugins.front().SocModels(), + ::testing::ElementsAreArray({kTestModels})); +} + +TEST(CompilerPluginTest, PartitionModel) { + ASSERT_RESULT_OK_MOVE(CompilerPlugin::VecT plugins, + CompilerPlugin::LoadPlugins({kTestPluginSearchPath})); + ASSERT_EQ(plugins.size(), 1); + EXPECT_EQ(plugins.front().SocManufacturer(), kTestManufacturer); + + auto model = litert::testing::LoadTestFileModel("mul_simple.tflite"); + + ASSERT_RESULT_OK_ASSIGN(auto ops, plugins.front().PartitionModel(*model)); + EXPECT_EQ(ops.size(), 2); +} + +TEST(CompilerPluginTest, CompileModel) { + ASSERT_RESULT_OK_MOVE(CompilerPlugin::VecT plugins, + CompilerPlugin::LoadPlugins({kTestPluginSearchPath})); + ASSERT_EQ(plugins.size(), 1); + EXPECT_EQ(plugins.front().SocManufacturer(), kTestManufacturer); + + auto model = litert::testing::LoadTestFileModel("mul_simple.tflite"); + ASSERT_RESULT_OK_ASSIGN(auto subgraph, graph_tools::GetSubgraph(model.get())); + + std::ostringstream byte_code_out; + std::vector call_info_out; + ASSERT_STATUS_OK(plugins.front().Compile(kTestModels, {subgraph}, + byte_code_out, call_info_out)); + + EXPECT_GT(byte_code_out.str().size(), 0); + EXPECT_EQ(call_info_out.size(), 1); +} + +TEST(CompilerPluginTest, Dump) { + ASSERT_RESULT_OK_MOVE(CompilerPlugin::VecT plugins, + CompilerPlugin::LoadPlugins({kTestPluginSearchPath})); + ASSERT_EQ(plugins.size(), 1); + + std::stringstream dump; + litert::internal::Dump(plugins.front(), dump); + + ASSERT_EQ(dump.view(), + "SocManufacturer: ExampleSocManufacturer\nSocModels: { " + "ExampleSocModel }\n"); +} + +} // namespace diff --git a/tensorflow/lite/experimental/litert/core/dispatch/BUILD b/tensorflow/lite/experimental/litert/core/dispatch/BUILD new file mode 100644 index 00000000000000..c94489ce122ada --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/dispatch/BUILD @@ -0,0 +1,141 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], +) + +cc_library( + name = "dispatch", + srcs = [ + "litert_dispatch.cc", + ], + hdrs = [ + ], + deps = [ + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/core:utils", + "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "dispatch_delegate", + srcs = [ + "dispatch_delegate.cc", + "dispatch_delegate_kernel.cc", + ], + hdrs = [ + "dispatch_delegate_kernel.h", + "dispatch_delegate_options.h", + ], + deps = [ + ":dispatch", + "//tensorflow/lite/c:c_api", + "//tensorflow/lite/c:c_api_opaque", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/c:common", + "//tensorflow/lite/delegates/utils:simple_opaque_delegate", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_dispatch_delegate", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/cc:litert_tensor_buffer", + "//tensorflow/lite/experimental/litert/core:tfl_utils", + "//tensorflow/lite/experimental/litert/core:utils", + "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "dispatch_delegate_google_tensor_test", + srcs = ["dispatch_delegate_google_tensor_test.cc"], + data = [ + "//tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch:dispatch_api_shared", + ], + linkopts = select({ + "//tensorflow:android": ["-landroid"], + "//conditions:default": [], + }), + deps = [ + ":dispatch", + ":dispatch_delegate", + "//tensorflow/lite:framework", + "//tensorflow/lite/c:c_api_opaque", + "//tensorflow/lite/c:common", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_dispatch_delegate", + "//tensorflow/lite/experimental/litert/test:common", + "//tensorflow/lite/experimental/litert/test:simple_model_npu", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:absl_log", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform", + ], +) + +cc_test( + name = "dispatch_delegate_qualcomm_test", + srcs = ["dispatch_delegate_qualcomm_test.cc"], + data = [ + "//tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch:dispatch_api_shared", + ], + linkopts = select({ + "//tensorflow:android": ["-landroid"], + "//conditions:default": [], + }), + deps = [ + ":dispatch", + ":dispatch_delegate", + "//tensorflow/lite:framework", + "//tensorflow/lite/c:c_api_opaque", + "//tensorflow/lite/c:common", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_dispatch_delegate", + "//tensorflow/lite/experimental/litert/test:common", + "//tensorflow/lite/experimental/litert/test:simple_model_npu", + "//tensorflow/lite/kernels:builtin_ops", + "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:absl_log", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform", + ], +) diff --git a/tensorflow/lite/experimental/litert/core/dispatch/dispatch_delegate.cc b/tensorflow/lite/experimental/litert/core/dispatch/dispatch_delegate.cc new file mode 100644 index 00000000000000..d00c446a065c1a --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/dispatch/dispatch_delegate.cc @@ -0,0 +1,170 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/lite/c/c_api_opaque.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/delegates/utils/simple_opaque_delegate.h" +#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/core/dispatch/dispatch_delegate_kernel.h" +#include "tensorflow/lite/experimental/litert/core/dispatch/dispatch_delegate_options.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" + +namespace { + +// A TFL Delegate that can recognize subgraphs that run on Dispatch API capable +// accelerators, e.g. TPU, DSP, ... It replaces such subgraphs and offloads +// their work through the Dispatch API. +class DispatchDelegate : public tflite::SimpleOpaqueDelegateInterface { + public: + static TfLiteOpaqueDelegate* Create(LiteRtDispatchDelegateOptions* options_) { + litert::DispatchDelegateOptionsPtr options( + options_, LiteRtDestroyDispatchDelegateOptions); + if (!options) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return nullptr; + } + + std::unique_ptr managed_sb_delegate( + new DispatchDelegate(std::move(options))); + return tflite::TfLiteOpaqueDelegateFactory::CreateSimpleDelegate( + std::move(managed_sb_delegate), + kTfLiteDelegateFlagsAllowDynamicTensors); + } + + bool IsNodeSupportedByDelegate(const TfLiteOperator* op, + const TfLiteOpaqueNode* node, + TfLiteOpaqueContext* context) const override; + + TfLiteStatus Initialize(TfLiteOpaqueContext* context) override; + + const char* Name() const override; + + std::unique_ptr + CreateDelegateKernelInterface() override; + + private: + static constexpr absl::string_view kDelegateName = "DispatchDelegate"; + static constexpr absl::string_view kDispatchNodeCustomCode = "dispatch_node"; + + explicit DispatchDelegate(litert::DispatchDelegateOptionsPtr&& options) + : options_(std::move(options)) {} + + litert::DispatchDelegateOptionsPtr options_; + int dispatch_graph_name_id_ = 0; +}; + +bool DispatchDelegate::IsNodeSupportedByDelegate( + const TfLiteOperator* op, const TfLiteOpaqueNode* node, + TfLiteOpaqueContext* context) const { + auto custom_code = absl::string_view(TfLiteOperatorGetCustomName(op)); + return custom_code == kDispatchNodeCustomCode; +} + +TfLiteStatus DispatchDelegate::Initialize(TfLiteOpaqueContext* context) { + return kTfLiteOk; +} + +const char* DispatchDelegate::Name() const { return kDelegateName.data(); } + +std::unique_ptr +DispatchDelegate::CreateDelegateKernelInterface() { + std::string dispatch_graph_name = + absl::StrFormat("DispatchGraph_%d", dispatch_graph_name_id_++); + + auto kernel = litert::internal::DispatchDelegateKernel::Create( + std::move(dispatch_graph_name), *options_); + if (kernel.ok()) { + return std::move(*kernel); + } else { + LITERT_LOG(LITERT_ERROR, "Failed to create a dispatch delegate kernel: %s", + kernel.status().message().data()); + return nullptr; + } +} + +} // namespace + +LiteRtDispatchDelegateOptions* LiteRtCreateDefaultDispatchDelegateOptions() { + return new LiteRtDispatchDelegateOptions; +} + +TfLiteStatus LiteRtAddDispatchDelegateOption( + LiteRtDispatchDelegateOptions* options, LiteRtDispatchOption option) { + if (!options) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kTfLiteError; + } + + options->AddOption(option); + return kTfLiteOk; +} + +TfLiteStatus LiteRtAddDispatchDelegateExecInfoOption( + LiteRtDispatchDelegateOptions* options, const char* exec_tag, + const void* bytecode_addr, size_t bytecode_size, + const char* function_name) { + if (!options || !exec_tag || !bytecode_addr) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kTfLiteError; + } + + LiteRtDispatchDelegateOptions::ExecInfo exec_info; + exec_info.bytecode = + absl::MakeSpan(static_cast(bytecode_addr), bytecode_size); + if (function_name) { + exec_info.function_name = function_name; + } + + options->AddExecInfo(exec_tag, std::move(exec_info)); + return kTfLiteOk; +} + +void LiteRtDestroyDispatchDelegateOptions( + LiteRtDispatchDelegateOptions* options) { + delete options; +} + +TfLiteDelegate* LiteRtCreateDispatchDelegate( + LiteRtDispatchDelegateOptions* options) { + return DispatchDelegate::Create(options); +} + +void LiteRtDestroyDispatchDelegate(TfLiteOpaqueDelegate* delegate) { + tflite::TfLiteOpaqueDelegateFactory::DeleteSimpleDelegate(delegate); +} + +namespace litert { + +DispatchDelegateOptionsPtr CreateDispatchDelegateOptionsPtr() { + return {LiteRtCreateDefaultDispatchDelegateOptions(), + LiteRtDestroyDispatchDelegateOptions}; +} + +DispatchDelegatePtr CreateDispatchDelegatePtr( + DispatchDelegateOptionsPtr&& options) { + return DispatchDelegatePtr(LiteRtCreateDispatchDelegate(options.release()), + LiteRtDestroyDispatchDelegate); +} +} // namespace litert diff --git a/tensorflow/lite/experimental/litert/core/dispatch/dispatch_delegate_google_tensor_test.cc b/tensorflow/lite/experimental/litert/core/dispatch/dispatch_delegate_google_tensor_test.cc new file mode 100644 index 00000000000000..080c3b950571f6 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/dispatch/dispatch_delegate_google_tensor_test.cc @@ -0,0 +1,108 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include +#include "absl/log/absl_log.h" +#include "absl/log/log.h" +#include "tensorflow/lite/c/c_api_opaque.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" +#include "tensorflow/lite/experimental/litert/test/common.h" +#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/interpreter_builder.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model_builder.h" +#include "tensorflow/lite/signature_runner.h" + +TEST(DispatchDelegate, GoogleTensor) { + auto npu_model_file_name = kGoogleTensorModelFileName; + auto npu_model = litert::testing::LoadBinaryFile(npu_model_file_name); + ASSERT_TRUE(npu_model.ok()); + ABSL_LOG(INFO) << "Loaded model " << npu_model_file_name << ", " + << npu_model->size() << " bytes"; + + auto tflite_file_name = + litert::testing::GetTestFilePath("simple_model_npu.tflite"); + auto model = tflite::FlatBufferModel::BuildFromFile(tflite_file_name.data()); + ASSERT_NE(model, nullptr); + + tflite::ops::builtin::BuiltinOpResolver resolver; + std::unique_ptr interpreter; + tflite::InterpreterBuilder(*model, resolver)(&interpreter); + ASSERT_NE(interpreter, nullptr); + + EXPECT_EQ(interpreter->nodes_size(), 1); + EXPECT_EQ(interpreter->inputs().size(), 2); + EXPECT_EQ(interpreter->outputs().size(), 1); + ASSERT_EQ(interpreter->execution_plan().size(), 1); + + auto dispatch_delegate_options = litert::CreateDispatchDelegateOptionsPtr(); + ASSERT_EQ( + LiteRtAddDispatchDelegateExecInfoOption( + dispatch_delegate_options.get(), "npu_bytecode", npu_model->data(), + npu_model->size(), /*function_name=*/nullptr), + kTfLiteOk); + auto dispatch_delegate = + litert::CreateDispatchDelegatePtr(std::move(dispatch_delegate_options)); + +#if !defined(__ANDROID__) + GTEST_SKIP() << "The rest of this test is specific to Android devices with a " + "GoogleTensor eTPU"; +#endif + + ASSERT_EQ(interpreter->ModifyGraphWithDelegate(dispatch_delegate.get()), + kTfLiteOk); + + // Get the list of signatures and check it. + auto signature_defs = interpreter->signature_keys(); + ASSERT_EQ(signature_defs.size(), 0); + + tflite::impl::SignatureRunner* runner = + interpreter->GetSignatureRunner(/*signature_key=*/nullptr); + ASSERT_NE(runner, nullptr); + + EXPECT_EQ(runner->AllocateTensors(), kTfLiteOk); + + // Fill model inputs. + ASSERT_STREQ(runner->input_names()[0], "arg0"); + auto input_0_tensor = runner->input_tensor("arg0"); + ASSERT_NE(input_0_tensor, nullptr); + auto* input_0 = input_0_tensor->data.f; + std::memcpy(input_0, kTestInput0Tensor, sizeof(kTestInput0Tensor)); + + ASSERT_STREQ(runner->input_names()[1], "arg1"); + auto input_1_tensor = runner->input_tensor("arg1"); + ASSERT_NE(input_1_tensor, nullptr); + auto* input_1 = input_1_tensor->data.f; + std::memcpy(input_1, kTestInput1Tensor, sizeof(kTestInput1Tensor)); + + EXPECT_EQ(runner->Invoke(), kTfLiteOk); + + // Check model output. + ASSERT_STREQ(runner->output_names()[0], "tfl.custom"); + auto output_tensor = runner->output_tensor("tfl.custom"); + ASSERT_NE(output_tensor, nullptr); + auto* output = output_tensor->data.f; + for (auto i = 0; i < kTestOutputSize; ++i) { + ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; + } + for (auto i = 0; i < kTestOutputSize; ++i) { + EXPECT_NEAR(output[i], kTestOutputTensor[i], 1e-5); + } +} diff --git a/tensorflow/lite/experimental/litert/core/dispatch/dispatch_delegate_kernel.cc b/tensorflow/lite/experimental/litert/core/dispatch/dispatch_delegate_kernel.cc new file mode 100644 index 00000000000000..c014229485db12 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/dispatch/dispatch_delegate_kernel.cc @@ -0,0 +1,446 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/core/dispatch/dispatch_delegate_kernel.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tensorflow/lite/c/c_api_opaque.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" +#include "tensorflow/lite/experimental/litert/core/dispatch/dispatch_delegate_options.h" +#include "tensorflow/lite/experimental/litert/core/tfl_utils.h" +#include "tensorflow/lite/experimental/litert/core/utils.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" + +namespace litert { +namespace internal { + +DispatchDelegateKernel::~DispatchDelegateKernel() { + for (size_t i = 0; i < input_tensor_buffer_handles_.size(); ++i) { + (void)LiteRtDispatchDetachInput(invocation_context_, i, + input_tensor_buffer_handles_[i]); + } + + for (size_t i = 0; i < output_tensor_buffer_handles_.size(); ++i) { + (void)LiteRtDispatchDetachOutput(invocation_context_, i, + output_tensor_buffer_handles_[i]); + } + + if (invocation_context_) { + (void)LiteRtDispatchInvocationContextDestroy(invocation_context_); + } + + for (auto& buffer_handle : input_tensor_buffer_handles_) { + (void)LiteRtDispatchUnregisterTensorBuffer(device_context_, buffer_handle); + } + + for (auto& buffer_handle : output_tensor_buffer_handles_) { + (void)LiteRtDispatchUnregisterTensorBuffer(device_context_, buffer_handle); + } + + if (device_context_) { + (void)LiteRtDispatchDeviceContextDestroy(device_context_); + } + + input_tensor_buffers_.clear(); + output_tensor_buffers_.clear(); +} + +absl::StatusOr DispatchDelegateKernel::Create( + std::string&& graph_name, const LiteRtDispatchDelegateOptions& options) { + auto dispatch_options = options.GetDispatchOptions(); + if (auto status = LiteRtDispatchInitialize(dispatch_options.data(), + dispatch_options.size()); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to initialize Dispatch API: %d", status); + return absl::InternalError("Failed to initialize Dispatch API"); + } + + const char* vendor_id; + if (auto status = LiteRtDispatchGetVendorId(&vendor_id); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to get Dispatch API vendor ID: %d", + status); + return absl::InternalError("Failed to get Dispatch API vendor ID"); + } + LITERT_LOG(LITERT_INFO, "Dispatch API vendor ID: %s", vendor_id); + + const char* build_id; + if (auto status = LiteRtDispatchGetBuildId(&build_id); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to get Dispatch API build ID: %d", status); + return absl::InternalError("Failed to get Dispatch API build ID"); + } + LITERT_LOG(LITERT_INFO, "Dispatch API build ID: %s", build_id); + + LiteRtDispatchApiVersion api_version; + if (auto status = LiteRtDispatchGetApiVersion(&api_version); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to get Dispatch API version: %d", status); + return absl::InternalError("Failed to get Dispatch API version"); + } + LITERT_LOG(LITERT_INFO, "Dispatch API version: %d.%d.%d", api_version.major, + api_version.minor, api_version.patch); + // Check if the versions mach. + if (api_version.major != LITERT_DISPATCH_API_VERSION_MAJOR || + api_version.minor < LITERT_DISPATCH_API_VERSION_MINOR) { + return absl::InternalError( + "Found Dispatch API with an unsupported version"); + } + + int capabilities; + if (auto status = LiteRtDispatchGetCapabilities(&capabilities); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to get Dispatch API capabilities: %d", + status); + return absl::InternalError("Failed to get Dispatch API capabilities"); + } + LITERT_LOG(LITERT_INFO, "Dispatch API capabilities: %d", capabilities); + + if (!(capabilities & kLiteRtDispatchCapabilitiesBasic)) { + return absl::InternalError("Dispatch API has insufficient capabilities"); + } + + LiteRtDispatchDeviceContext device_context; + if (auto status = LiteRtDispatchDeviceContextCreate(&device_context); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to get Dispatch API device context: %d", + status); + return absl::InternalError("Failed to create Dispatch API device context"); + } + + return Ptr(new DispatchDelegateKernel(options, std::move(graph_name), + device_context)); +} + +TfLiteStatus DispatchDelegateKernel::Init( + TfLiteOpaqueContext* context, const TfLiteOpaqueDelegateParams* params) { + LITERT_LOG(LITERT_INFO, "DispatchDelegateKernel::Init"); + if (params->nodes_to_replace->size != 1) { + LITERT_LOG(LITERT_ERROR, + "Models with more than one dispatch node are not yet supported"); + return kTfLiteError; + } + + auto node_id = params->nodes_to_replace->data[0]; + TfLiteOpaqueNode* node; + TfLiteOperator* op; + if (auto status = TfLiteOpaqueContextGetNodeAndRegistration(context, node_id, + &node, &op); + status != kTfLiteOk) { + LITERT_LOG(LITERT_ERROR, "Failed to get node and registration: %d", status); + return status; + } + + const void* init_data; + int init_data_size; + if (auto status = TfLiteOpaqueNodeGetCustomInitialData(node, &init_data, + &init_data_size); + status != kTfLiteOk) { + LITERT_LOG(LITERT_ERROR, "Failed to get custom initial data: %d", status); + return status; + } + if (!init_data || !init_data_size) { + LITERT_LOG(LITERT_ERROR, "Found custom op with missing initial data"); + return kTfLiteError; + } + + std::string custom_option(static_cast(init_data), + init_data_size); + auto exec_info = options_.GetExecInfo(custom_option); + if (!exec_info.ok()) { + LITERT_LOG(LITERT_ERROR, "Failed to fetch ExecInfo for %s: %s", + custom_option.data(), exec_info.status().message().data()); + return kTfLiteError; + } + + const char* function_name = exec_info->function_name.has_value() + ? exec_info->function_name->data() + : nullptr; + int num_inputs = params->input_tensors->size; + int num_outputs = params->output_tensors->size; + if (auto status = LiteRtDispatchInvocationContextCreate( + device_context_, kLiteRtDispatchExecutableTypeMlModel, + exec_info->bytecode.data(), exec_info->bytecode.size(), function_name, + num_inputs, num_outputs, &invocation_context_); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to create invocation context: %d", status); + return kTfLiteError; + } + + input_tensor_buffers_.resize(num_inputs); + input_tensor_buffer_handles_.resize(num_inputs); + input_tensor_buffer_used_size_.resize(num_inputs); + + output_tensor_buffers_.resize(num_outputs); + output_tensor_buffer_handles_.resize(num_outputs); + output_tensor_buffer_used_size_.resize(num_outputs); + + return kTfLiteOk; +} + +absl::StatusOr +DispatchDelegateKernel::GetBufferRequirements( + const RankedTensorType& tensor_type, int io_tensor_index, + bool is_input) const { + auto litert_tensor_type = static_cast(tensor_type); + LiteRtTensorBufferRequirements tensor_buffer_requirements; + if (is_input) { + if (auto status = LiteRtDispatchGetInputRequirements( + invocation_context_, /*input_index=*/io_tensor_index, + &litert_tensor_type, &tensor_buffer_requirements); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, + "Failed to get tensor buffer requirements for input %d: %d", + io_tensor_index, status); + return absl::InternalError( + "Failed to get tensor buffer requirements for input"); + } + + } else { + if (auto status = LiteRtDispatchGetOutputRequirements( + invocation_context_, /*output_index=*/io_tensor_index, + &litert_tensor_type, &tensor_buffer_requirements); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, + "Failed to get tensor buffer requirements for output %d: %d", + io_tensor_index, status); + return absl::InternalError( + "Failed to get tensor buffer requirements for output"); + } + } + + return TensorBufferRequirements(tensor_buffer_requirements, /*owned=*/false); +} + +TfLiteStatus DispatchDelegateKernel::SetBuffer( + const TfLiteOpaqueTensor* tfl_opaque_tensor, int buffer_index, + bool is_input) { + auto& cached_tensor_buffer = is_input ? input_tensor_buffers_[buffer_index] + : output_tensor_buffers_[buffer_index]; + auto& cached_tensor_buffer_handle = + is_input ? input_tensor_buffer_handles_[buffer_index] + : output_tensor_buffer_handles_[buffer_index]; + auto& used_size = is_input ? input_tensor_buffer_used_size_[buffer_index] + : output_tensor_buffer_used_size_[buffer_index]; + + auto tensor_type = ConvertTensorType(tfl_opaque_tensor); + if (!tensor_type.ok()) { + LITERT_LOG(LITERT_ERROR, "%s", tensor_type.status().message().data()); + return kTfLiteError; + } + + // Check if we can reuse a cached tensor buffer or we need to create a new + // one. + if (static_cast(cached_tensor_buffer)) { + if (auto cached_tensor_type = cached_tensor_buffer.TensorType(); + !cached_tensor_type.ok()) { + LITERT_LOG(LITERT_ERROR, "%s", + cached_tensor_type.status().message().data()); + return kTfLiteError; + } + + if (tensor_type->Layout() == cached_tensor_buffer.TensorType()->Layout()) { + // We can reuse the cached tensor buffer. + return kTfLiteOk; + } + + // We cannot reuse the cached tensor buffer; proceed below. + } + + auto tensor_buffer_requirements = + GetBufferRequirements(*tensor_type, buffer_index, is_input); + if (!tensor_buffer_requirements.ok()) { + LITERT_LOG(LITERT_ERROR, "%s", + tensor_buffer_requirements.status().message().data()); + return kTfLiteError; + } + + auto supported_tensor_buffer_types = + tensor_buffer_requirements->SupportedTypes(); + if (!supported_tensor_buffer_types.ok()) { + LITERT_LOG(LITERT_ERROR, "%s", + supported_tensor_buffer_types.status().message().data()); + return kTfLiteError; + } + + if (supported_tensor_buffer_types->empty()) { + LITERT_LOG(LITERT_ERROR, + "Insufficient number of supported tensor buffer types"); + return kTfLiteError; + } + + // For now we simply pick the first buffer type that's supported. + LiteRtTensorBufferType tensor_buffer_type = + (*supported_tensor_buffer_types)[0]; + + auto tensor_buffer_size = tensor_buffer_requirements->BufferSize(); + if (!tensor_buffer_size.ok()) { + LITERT_LOG(LITERT_ERROR, "%s", + tensor_buffer_size.status().message().data()); + return kTfLiteError; + } + + auto litert_tensor_type = static_cast(*tensor_type); + LiteRtTensorBuffer litert_tensor_buffer; + if (auto status = LiteRtCreateManagedTensorBuffer( + tensor_buffer_type, &litert_tensor_type, *tensor_buffer_size, + &litert_tensor_buffer); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to create managed tensor buffer: %d", + status); + return kTfLiteError; + } + + TensorBuffer tensor_buffer(litert_tensor_buffer); + + LiteRtTensorBufferHandle buffer_handle; + if (auto status = LiteRtDispatchRegisterTensorBuffer( + device_context_, static_cast(tensor_buffer), + &buffer_handle); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to register tensor buffer: %d", status); + return kTfLiteError; + } + + if (is_input) { + if (auto status = LiteRtDispatchAttachInput(invocation_context_, + buffer_index, buffer_handle); + status != kLiteRtStatusOk) { + (void)LiteRtDispatchUnregisterTensorBuffer(device_context_, + buffer_handle); + LITERT_LOG(LITERT_ERROR, "Failed to attach tensor buffer to input %d: %d", + buffer_index, status); + return kTfLiteError; + } + } else { + if (auto status = LiteRtDispatchAttachOutput(invocation_context_, + buffer_index, buffer_handle); + status != kLiteRtStatusOk) { + (void)LiteRtDispatchUnregisterTensorBuffer(device_context_, + buffer_handle); + LITERT_LOG(LITERT_ERROR, + "Failed to attach tensor buffer to output %d: %d", + buffer_index, status); + return kTfLiteError; + } + } + + auto num_bytes = internal::GetNumPackedBytes( + static_cast(*tensor_type)); + if (!num_bytes.ok()) { + LITERT_LOG(LITERT_ERROR, "%s", num_bytes.status().message().data()); + return kTfLiteError; + } + + cached_tensor_buffer = std::move(tensor_buffer); + cached_tensor_buffer_handle = buffer_handle; + used_size = *num_bytes; + + return kTfLiteOk; +} + +TfLiteStatus DispatchDelegateKernel::Prepare(TfLiteOpaqueContext* context, + TfLiteOpaqueNode* node) { + size_t num_node_inputs = TfLiteOpaqueNodeNumberOfInputs(node); + for (size_t i = 0; i < num_node_inputs; ++i) { + auto* tfl_opaque_tensor = TfLiteOpaqueNodeGetInput(context, node, i); + if (auto status = SetBuffer(tfl_opaque_tensor, i, /*is_input=*/true); + status != kTfLiteOk) { + return status; + } + } + + size_t num_node_outputs = TfLiteOpaqueNodeNumberOfOutputs(node); + for (size_t i = 0; i < num_node_outputs; ++i) { + auto* tfl_opaque_tensor = TfLiteOpaqueNodeGetOutput(context, node, i); + if (auto status = SetBuffer(tfl_opaque_tensor, i, /*is_input=*/false); + status != kTfLiteOk) { + return status; + } + } + + return kTfLiteOk; +} + +TfLiteStatus DispatchDelegateKernel::Eval(TfLiteOpaqueContext* context, + TfLiteOpaqueNode* node) { + size_t num_node_inputs = TfLiteOpaqueNodeNumberOfInputs(node); + if (num_node_inputs != input_tensor_buffers_.size()) { + LITERT_LOG(LITERT_ERROR, "Invalid number of inputs"); + return kTfLiteError; + } + + for (size_t i = 0; i < num_node_inputs; ++i) { + auto* tfl_opaque_tensor = TfLiteOpaqueNodeGetInput(context, node, i); + void* tensor_data = TfLiteOpaqueTensorData(tfl_opaque_tensor); + auto& tensor_buffer = input_tensor_buffers_[i]; + + auto lock_and_addr = TensorBufferScopedLock::Create(tensor_buffer); + if (!lock_and_addr.ok()) { + LITERT_LOG(LITERT_ERROR, "%s", lock_and_addr.status().message().data()); + return kTfLiteError; + } + + size_t buffer_size = input_tensor_buffer_used_size_[i]; + std::memcpy(lock_and_addr->second, tensor_data, buffer_size); + } + + if (auto status = LiteRtDispatchInvoke(invocation_context_); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to invoke context: %d", status); + return kTfLiteError; + } + + size_t num_node_outputs = TfLiteOpaqueNodeNumberOfOutputs(node); + if (num_node_outputs != output_tensor_buffers_.size()) { + LITERT_LOG(LITERT_ERROR, "Invalid number of outputs"); + return kTfLiteError; + } + + for (size_t i = 0; i < num_node_outputs; ++i) { + auto* tfl_opaque_tensor = TfLiteOpaqueNodeGetOutput(context, node, i); + void* tensor_data = TfLiteOpaqueTensorData(tfl_opaque_tensor); + auto& tensor_buffer = output_tensor_buffers_[i]; + + auto lock_and_addr = TensorBufferScopedLock::Create(tensor_buffer); + if (!lock_and_addr.ok()) { + LITERT_LOG(LITERT_ERROR, "%s", lock_and_addr.status().message().data()); + return kTfLiteError; + } + + size_t buffer_size = output_tensor_buffer_used_size_[i]; + std::memcpy(tensor_data, lock_and_addr->second, buffer_size); + } + + return kTfLiteOk; +} + +} // namespace internal +} // namespace litert diff --git a/tensorflow/lite/experimental/litert/core/dispatch/dispatch_delegate_kernel.h b/tensorflow/lite/experimental/litert/core/dispatch/dispatch_delegate_kernel.h new file mode 100644 index 00000000000000..d0244ed4083b9c --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/dispatch/dispatch_delegate_kernel.h @@ -0,0 +1,89 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DISPATCH_DISPATCH_DELEGATE_KERNEL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DISPATCH_DISPATCH_DELEGATE_KERNEL_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/delegates/utils/simple_opaque_delegate.h" +#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" +#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/cc/litert_tensor_buffer_requirements.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" + +namespace litert { +namespace internal { + +// A TFL kernel that the interpreter calls to dispatch execution through the +// Dispatch API. +class DispatchDelegateKernel + : public tflite::SimpleOpaqueDelegateKernelInterface { + public: + using Ptr = std::unique_ptr; + + ~DispatchDelegateKernel() override; + + static absl::StatusOr Create( + std::string&& graph_name, const LiteRtDispatchDelegateOptions& options); + + TfLiteStatus Init(TfLiteOpaqueContext* context, + const TfLiteOpaqueDelegateParams* params) override; + + TfLiteStatus Prepare(TfLiteOpaqueContext* context, + TfLiteOpaqueNode* node) override; + + TfLiteStatus Eval(TfLiteOpaqueContext* context, + TfLiteOpaqueNode* node) override; + + private: + DispatchDelegateKernel(const LiteRtDispatchDelegateOptions& options, + std::string&& graph_name, + LiteRtDispatchDeviceContext device_context) + : options_(options), + graph_name_(std::move(graph_name)), + device_context_(device_context) {} + + absl::StatusOr GetBufferRequirements( + const RankedTensorType& tensor_type, int io_tensor_index, + bool is_input) const; + TfLiteStatus SetBuffer(const TfLiteOpaqueTensor* tfl_opaque_tensor, + int buffer_index, bool is_input); + + const LiteRtDispatchDelegateOptions& options_; + std::string graph_name_; + LiteRtDispatchDeviceContext device_context_; + LiteRtDispatchInvocationContext invocation_context_ = nullptr; + + std::vector input_tensor_buffers_; + std::vector input_tensor_buffer_handles_; + std::vector input_tensor_buffer_used_size_; + + std::vector output_tensor_buffers_; + std::vector output_tensor_buffer_handles_; + std::vector output_tensor_buffer_used_size_; +}; + +} // namespace internal +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DISPATCH_DISPATCH_DELEGATE_KERNEL_H_ diff --git a/tensorflow/lite/experimental/litert/core/dispatch/dispatch_delegate_options.h b/tensorflow/lite/experimental/litert/core/dispatch/dispatch_delegate_options.h new file mode 100644 index 00000000000000..a3f6288a0f8c7e --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/dispatch/dispatch_delegate_options.h @@ -0,0 +1,66 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DISPATCH_DISPATCH_DELEGATE_OPTIONS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DISPATCH_DISPATCH_DELEGATE_OPTIONS_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" + +class LiteRtDispatchDelegateOptions { + public: + // Information about NPU binary, including the NPU binary bytecode and the + // name of the entry-point function. + struct ExecInfo { + absl::Span bytecode; + std::optional function_name; + }; + + void AddOption(LiteRtDispatchOption option) { options_.push_back(option); } + + // Store a given ExecInfo object and associated it to a given tag. + void AddExecInfo(absl::string_view exec_tag, ExecInfo&& exec_info) { + exec_infos_[std::string{exec_tag}] = std::move(exec_info); + } + + // Retrieve the ExecInfo object associated with a given tag. + absl::StatusOr GetExecInfo(const std::string& exec_tag) const { + if (auto iter = exec_infos_.find(exec_tag); iter != exec_infos_.end()) { + return iter->second; + } + return absl::NotFoundError("ExecInfo not found"); + } + + const std::vector& GetDispatchOptions() const { + return options_; + } + + private: + std::vector options_; + // ExecInfos are stored as (tag, ExecInfo) pairs. + std::map exec_infos_; +}; + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DISPATCH_DISPATCH_DELEGATE_OPTIONS_H_ diff --git a/tensorflow/lite/experimental/litert/core/dispatch/dispatch_delegate_qualcomm_test.cc b/tensorflow/lite/experimental/litert/core/dispatch/dispatch_delegate_qualcomm_test.cc new file mode 100644 index 00000000000000..fe161cc7d0dc34 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/dispatch/dispatch_delegate_qualcomm_test.cc @@ -0,0 +1,108 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include +#include "absl/log/absl_log.h" +#include "absl/log/log.h" +#include "tensorflow/lite/c/c_api_opaque.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/experimental/litert/c/litert_dispatch_delegate.h" +#include "tensorflow/lite/experimental/litert/test/common.h" +#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" +#include "tensorflow/lite/interpreter.h" +#include "tensorflow/lite/interpreter_builder.h" +#include "tensorflow/lite/kernels/register.h" +#include "tensorflow/lite/model_builder.h" +#include "tensorflow/lite/signature_runner.h" + +TEST(DispatchDelegate, Qualcomm) { + auto npu_model_file_name = kQualcommModelFileName; + auto npu_model = litert::testing::LoadBinaryFile(npu_model_file_name); + ASSERT_TRUE(npu_model.ok()); + ABSL_LOG(INFO) << "Loaded model " << npu_model_file_name << ", " + << npu_model->size() << " bytes"; + + auto tflite_file_name = + litert::testing::GetTestFilePath("simple_model_npu.tflite"); + auto model = tflite::FlatBufferModel::BuildFromFile(tflite_file_name.data()); + ASSERT_NE(model, nullptr); + + tflite::ops::builtin::BuiltinOpResolver resolver; + std::unique_ptr interpreter; + tflite::InterpreterBuilder(*model, resolver)(&interpreter); + ASSERT_NE(interpreter, nullptr); + + EXPECT_EQ(interpreter->nodes_size(), 1); + EXPECT_EQ(interpreter->inputs().size(), 2); + EXPECT_EQ(interpreter->outputs().size(), 1); + ASSERT_EQ(interpreter->execution_plan().size(), 1); + + auto dispatch_delegate_options = litert::CreateDispatchDelegateOptionsPtr(); + ASSERT_EQ( + LiteRtAddDispatchDelegateExecInfoOption( + dispatch_delegate_options.get(), "npu_bytecode", npu_model->data(), + npu_model->size(), /*function_name=*/"simple"), + kTfLiteOk); + auto dispatch_delegate = + litert::CreateDispatchDelegatePtr(std::move(dispatch_delegate_options)); + +#if !defined(__ANDROID__) + GTEST_SKIP() << "The rest of this test is specific to Android devices with a " + "Qualcomm HTP"; +#endif + + ASSERT_EQ(interpreter->ModifyGraphWithDelegate(dispatch_delegate.get()), + kTfLiteOk); + + // Get the list of signatures and check it. + auto signature_defs = interpreter->signature_keys(); + ASSERT_EQ(signature_defs.size(), 0); + + tflite::impl::SignatureRunner* runner = + interpreter->GetSignatureRunner(/*signature_key=*/nullptr); + ASSERT_NE(runner, nullptr); + + EXPECT_EQ(runner->AllocateTensors(), kTfLiteOk); + + // Fill model inputs. + ASSERT_STREQ(runner->input_names()[0], "arg0"); + auto input_0_tensor = runner->input_tensor("arg0"); + ASSERT_NE(input_0_tensor, nullptr); + auto* input_0 = input_0_tensor->data.f; + std::memcpy(input_0, kTestInput0Tensor, sizeof(kTestInput0Tensor)); + + ASSERT_STREQ(runner->input_names()[1], "arg1"); + auto input_1_tensor = runner->input_tensor("arg1"); + ASSERT_NE(input_1_tensor, nullptr); + auto* input_1 = input_1_tensor->data.f; + std::memcpy(input_1, kTestInput1Tensor, sizeof(kTestInput1Tensor)); + + EXPECT_EQ(runner->Invoke(), kTfLiteOk); + + // Check model output. + ASSERT_STREQ(runner->output_names()[0], "tfl.custom"); + auto output_tensor = runner->output_tensor("tfl.custom"); + ASSERT_NE(output_tensor, nullptr); + auto* output = output_tensor->data.f; + for (auto i = 0; i < kTestOutputSize; ++i) { + ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; + } + for (auto i = 0; i < kTestOutputSize; ++i) { + EXPECT_NEAR(output[i], kTestOutputTensor[i], 1e-5); + } +} diff --git a/tensorflow/lite/experimental/litert/core/dispatch/litert_dispatch.cc b/tensorflow/lite/experimental/litert/core/dispatch/litert_dispatch.cc new file mode 100644 index 00000000000000..755b066b524765 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/dispatch/litert_dispatch.cc @@ -0,0 +1,515 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" + +#include + +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_event.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch_api.h" + +#define INVOKE_FUNC(function, ...) \ + if (!TheApi.interface) { \ + LITERT_LOG(LITERT_ERROR, "Dispatch API interface not found"); \ + return kLiteRtStatusErrorRuntimeFailure; \ + } \ + if (!TheApi.interface->function) { \ + LITERT_LOG(LITERT_ERROR, #function " not found"); \ + return kLiteRtStatusErrorRuntimeFailure; \ + } \ + return TheApi.interface->function(__VA_ARGS__); + +#define INVOKE_ASYNC_FUNC(function, ...) \ + if (!TheApi.async_interface) { \ + LITERT_LOG(LITERT_ERROR, "Dispatch API async interface not found"); \ + return kLiteRtStatusErrorRuntimeFailure; \ + } \ + if (!TheApi.async_interface->function) { \ + LITERT_LOG(LITERT_ERROR, #function " not found"); \ + return kLiteRtStatusErrorRuntimeFailure; \ + } \ + return TheApi.async_interface->function(__VA_ARGS__); + +#define INVOKE_GRAPH_FUNC(function, ...) \ + if (!TheApi.graph_interface) { \ + LITERT_LOG(LITERT_ERROR, "Dispatch API graoh interface not found"); \ + return kLiteRtStatusErrorRuntimeFailure; \ + } \ + if (!TheApi.graph_interface->function) { \ + LITERT_LOG(LITERT_ERROR, #function " not found"); \ + return kLiteRtStatusErrorRuntimeFailure; \ + } \ + return TheApi.graph_interface->function(__VA_ARGS__); + +namespace { + +constexpr const char* kSharedLibName = "libLiteRtDispatch.so"; + +bool IsTheApiInitialized = false; +LiteRtDispatchApi TheApi = { + /*.version=*/{/*.major=*/0, /*.minor=*/0, /*.patch=*/0}, + /*.interface=*/nullptr, + /*.async_interface=*/nullptr, + /*.graph_interface=*/nullptr, +}; + +LiteRtStatus Initialize(const LiteRtDispatchOption* options, int num_options) { + INVOKE_FUNC(initialize, options, num_options); +} + +std::string GetSharedLibraryPath(const LiteRtDispatchOption* options, + int num_options) { + for (auto i = 0; i < num_options; ++i) { + auto& option = options[i]; + if (!strcmp(option.name, kDispatchOptionSharedLibraryDir)) { + return absl::StrFormat("%s/%s", option.value.str_value, kSharedLibName); + } + } + return kSharedLibName; +} + +} // namespace + +// ///////////////////////////////////////////////////////////////////////////// +// Basic Execution API +// ///////////////////////////////////////////////////////////////////////////// + +LiteRtStatus LiteRtDispatchInitialize(const LiteRtDispatchOption* options, + int num_options) { + if (IsTheApiInitialized) { + return kLiteRtStatusOk; + } + + auto shared_lib_path = GetSharedLibraryPath(options, num_options); + void* lib_handle = ::dlopen(shared_lib_path.data(), RTLD_NOW | RTLD_LOCAL); + if (!lib_handle) { + LITERT_LOG(LITERT_ERROR, "Failed to load dispatch library: %s", + ::dlerror()); + return kLiteRtStatusErrorRuntimeFailure; + } + + using LiteRtDispatchGetApi_t = LiteRtStatus (*)(LiteRtDispatchApi*); + auto LiteRtDispatchGetApi = reinterpret_cast( + ::dlsym(lib_handle, "LiteRtDispatchGetApi")); + if (!LiteRtDispatchGetApi) { + ::dlclose(lib_handle); + LITERT_LOG(LITERT_ERROR, "LiteRtDispatchGetApi not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + if (auto status = LiteRtDispatchGetApi(&TheApi); status != kLiteRtStatusOk) { + ::dlclose(lib_handle); + return status; + } + + if (!(TheApi.version.major == LITERT_DISPATCH_API_VERSION_MAJOR && + TheApi.version.minor <= LITERT_DISPATCH_API_VERSION_MINOR)) { + ::dlclose(lib_handle); + LITERT_LOG(LITERT_ERROR, + "Dispatch API runtime is too old, found version %d.%d.%d and " + "expected at least version %d.%d.%d", + TheApi.version.major, TheApi.version.minor, TheApi.version.patch, + LITERT_DISPATCH_API_VERSION_MAJOR, + LITERT_DISPATCH_API_VERSION_MINOR, + LITERT_DISPATCH_API_VERSION_PATCH); + return kLiteRtStatusErrorRuntimeFailure; + } + + auto status = Initialize(options, num_options); + if (status == kLiteRtStatusOk) { + IsTheApiInitialized = true; + } + return status; +} + +LiteRtStatus LiteRtDispatchGetApiVersion( + LiteRtDispatchApiVersion* api_version) { + if (!api_version) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + *api_version = TheApi.version; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtDispatchGetVendorId(const char** vendor_id) { + if (!vendor_id) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_FUNC(get_vendor_id, vendor_id); +} + +LiteRtStatus LiteRtDispatchGetBuildId(const char** build_id) { + if (!build_id) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_FUNC(get_build_id, build_id); +} + +LiteRtStatus LiteRtDispatchGetCapabilities(int* capabilities) { + if (!capabilities) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_FUNC(get_capabilities, capabilities); +} + +LiteRtStatus LiteRtDispatchDeviceContextCreate( + LiteRtDispatchDeviceContext* device_context) { + if (!device_context) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_FUNC(device_context_create, device_context); +} + +LiteRtStatus LiteRtDispatchDeviceContextDestroy( + LiteRtDispatchDeviceContext device_context) { + if (!device_context) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_FUNC(device_context_destroy, device_context); +} + +LiteRtStatus LiteRtDispatchGetInputRequirements( + LiteRtDispatchInvocationContext invocation_context, int input_index, + const LiteRtRankedTensorType* tensor_type, + LiteRtTensorBufferRequirements* tensor_buffer_requirements) { + if (!invocation_context || !tensor_type || !tensor_buffer_requirements) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_FUNC(get_input_requirements, invocation_context, input_index, + tensor_type, tensor_buffer_requirements); +} + +LiteRtStatus LiteRtDispatchGetOutputRequirements( + LiteRtDispatchInvocationContext invocation_context, int output_index, + const LiteRtRankedTensorType* tensor_type, + LiteRtTensorBufferRequirements* tensor_buffer_requirements) { + if (!invocation_context || !tensor_type || !tensor_buffer_requirements) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_FUNC(get_output_requirements, invocation_context, output_index, + tensor_type, tensor_buffer_requirements); +} + +LiteRtStatus LiteRtDispatchRegisterTensorBuffer( + LiteRtDispatchDeviceContext device_context, + LiteRtTensorBuffer tensor_buffer, + LiteRtTensorBufferHandle* tensor_buffer_handle) { + if (!device_context || !tensor_buffer || !tensor_buffer_handle) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_FUNC(register_tensor_buffer, device_context, tensor_buffer, + tensor_buffer_handle); +} + +LiteRtStatus LiteRtDispatchUnregisterTensorBuffer( + LiteRtDispatchDeviceContext device_context, + LiteRtTensorBufferHandle tensor_buffer_handle) { + if (!device_context) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_FUNC(unregister_tensor_buffer, device_context, tensor_buffer_handle); +} + +LiteRtStatus LiteRtDispatchInvocationContextCreate( + LiteRtDispatchDeviceContext device_context, + LiteRtDispatchExecutableType exec_type, const void* exec_bytecode_ptr, + size_t exec_bytecode_size, const char* function_name, int num_inputs, + int num_outputs, LiteRtDispatchInvocationContext* invocation_context) { + if (!device_context || !exec_bytecode_ptr || !invocation_context) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_FUNC(invocation_context_create, device_context, exec_type, + exec_bytecode_ptr, exec_bytecode_size, function_name, num_inputs, + num_outputs, invocation_context); +} + +LiteRtStatus LiteRtDispatchInvocationContextDestroy( + LiteRtDispatchInvocationContext invocation_context) { + if (!invocation_context) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_FUNC(invocation_context_destroy, invocation_context); +} + +LiteRtStatus LiteRtDispatchAttachInput( + LiteRtDispatchInvocationContext invocation_context, int graph_input_index, + LiteRtTensorBufferHandle tensor_buffer_handle) { + if (!invocation_context) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_FUNC(attach_input, invocation_context, graph_input_index, + tensor_buffer_handle); +} + +LiteRtStatus LiteRtDispatchAttachOutput( + LiteRtDispatchInvocationContext invocation_context, int graph_output_index, + LiteRtTensorBufferHandle tensor_buffer_handle) { + if (!invocation_context) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + if (!TheApi.interface) { + LITERT_LOG(LITERT_ERROR, "Dispatch API interface not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + if (!TheApi.interface->attach_output) { + LITERT_LOG(LITERT_ERROR, "attach_output_tensor_buffer not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + INVOKE_FUNC(attach_output, invocation_context, graph_output_index, + tensor_buffer_handle); +} + +LiteRtStatus LiteRtDispatchDetachInput( + LiteRtDispatchInvocationContext invocation_context, int graph_input_index, + LiteRtTensorBufferHandle tensor_buffer_handle) { + if (!invocation_context) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_FUNC(detach_input, invocation_context, graph_input_index, + tensor_buffer_handle); +} + +LiteRtStatus LiteRtDispatchDetachOutput( + LiteRtDispatchInvocationContext invocation_context, int graph_output_index, + LiteRtTensorBufferHandle tensor_buffer_handle) { + if (!invocation_context) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_FUNC(detach_output, invocation_context, graph_output_index, + tensor_buffer_handle); +} + +LiteRtStatus LiteRtDispatchInvoke( + LiteRtDispatchInvocationContext invocation_context) { + if (!invocation_context) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_FUNC(invoke, invocation_context); +} + +// ///////////////////////////////////////////////////////////////////////////// +// Async Execution API +// ///////////////////////////////////////////////////////////////////////////// + +LiteRtStatus LiteRtDispatchAttachInputEvent( + LiteRtDispatchInvocationContext invocation_context, int graph_input_index, + LiteRtEvent input_event) { + if (!invocation_context || !input_event) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_ASYNC_FUNC(attach_input_event, invocation_context, graph_input_index, + input_event); +} + +LiteRtStatus LiteRtDispatchInvokeAsync( + LiteRtDispatchInvocationContext invocation_context, int num_output_events, + LiteRtEvent* output_events) { + if (!invocation_context || !output_events) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_ASYNC_FUNC(invoke_async, invocation_context, num_output_events, + output_events); +} + +// ///////////////////////////////////////////////////////////////////////////// +// Graph Execution API +// ///////////////////////////////////////////////////////////////////////////// + +LiteRtStatus LiteRtDispatchGraphCreate( + LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph* graph) { + if (!device_context || !graph) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(graph_create, device_context, graph); +} + +LiteRtStatus LiteRtDispatchGraphDestroy(LiteRtDispatchGraph graph) { + if (!graph) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(graph_destroy, graph); +} + +LiteRtStatus LiteRtDispatchAddNode(LiteRtDispatchGraph graph, + LiteRtDispatchNodeId node_id, + LiteRtDispatchNodeType node_type) { + if (!graph) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(add_node, graph, node_id, node_type); +} + +LiteRtStatus LiteRtDispatchAddEdge(LiteRtDispatchGraph graph, + LiteRtDispatchEdgeId edge_id) { + if (!graph) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(add_edge, graph, edge_id); +} + +LiteRtStatus LiteRtDispatchConnectNodeInput(LiteRtDispatchGraph graph, + LiteRtDispatchNodeId node_id, + int input_index, + LiteRtDispatchEdgeId edge_id) { + if (!graph) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(connect_node_input, graph, node_id, input_index, edge_id); +} + +LiteRtStatus LiteRtDispatchConnectNodeOutput(LiteRtDispatchGraph graph, + LiteRtDispatchNodeId node_id, + int output_index, + LiteRtDispatchEdgeId edge_id) { + if (!graph) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(connect_node_output, graph, node_id, output_index, edge_id); +} + +LiteRtStatus LiteRtDispatchConnectGraphInput(LiteRtDispatchGraph graph, + int input_index, + LiteRtDispatchEdgeId edge_id) { + if (!graph) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(connect_graph_input, graph, input_index, edge_id); +} + +LiteRtStatus LiteRtDispatchConnectGraphOutput(LiteRtDispatchGraph graph, + int output_index, + LiteRtDispatchEdgeId edge_id) { + if (!graph) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(connect_graph_output, graph, output_index, edge_id); +} + +LiteRtStatus LiteRtDispatchLoadExecutable( + LiteRtDispatchDeviceContext device_context, + LiteRtDispatchExecutableType type, const void* bytecode, + size_t bytecode_size, LiteRtDispatchExecutableHandle* exec_handle) { + if (!device_context || !bytecode || !exec_handle) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + if (!TheApi.graph_interface) { + LITERT_LOG(LITERT_ERROR, "Dispatch API graph interface not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + if (!TheApi.graph_interface->load_executable) { + LITERT_LOG(LITERT_ERROR, "load_executable not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + INVOKE_GRAPH_FUNC(load_executable, device_context, type, bytecode, + bytecode_size, exec_handle); +} + +LiteRtStatus LiteRtDispatchUnloadExecutable( + LiteRtDispatchDeviceContext device_context, + LiteRtDispatchExecutableHandle exec_handle) { + if (!device_context) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(unload_executable, device_context, exec_handle); +} + +LiteRtStatus LiteRtDispatchAssignNodeFunction( + LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, + LiteRtDispatchExecutableHandle exec_handle, const char* function_name) { + if (!graph) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(assign_node_function, graph, node_id, exec_handle, + function_name); +} + +LiteRtStatus LiteRtDispatchAnnotateGraph(LiteRtDispatchGraph graph, + const char* key, const char* value) { + if (!graph) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(annotate_graph, graph, key, value); +} + +LiteRtStatus LiteRtDispatchAnnotateNode(LiteRtDispatchGraph graph, + LiteRtDispatchNodeId node_id, + const char* key, const char* value) { + if (!graph) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(annotate_node, graph, node_id, key, value); +} + +LiteRtStatus LiteRtDispatchAnnotateEdge(LiteRtDispatchGraph graph, + LiteRtDispatchEdgeId edge_id, + const char* key, const char* value) { + if (!graph) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(annotate_edge, graph, edge_id, key, value); +} + +LiteRtStatus LiteRtDispatchInvocationContextCreateFromGraph( + LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph graph, + LiteRtDispatchInvocationContext* invocation_context) { + if (!device_context || !graph || !invocation_context) { + LITERT_LOG(LITERT_ERROR, "Null input"); + return kLiteRtStatusErrorInvalidArgument; + } + INVOKE_GRAPH_FUNC(invocation_context_create_from_graph, device_context, graph, + invocation_context); +} diff --git a/tensorflow/lite/experimental/litert/core/dmabuf_buffer.cc b/tensorflow/lite/experimental/litert/core/dmabuf_buffer.cc new file mode 100644 index 00000000000000..cdb56dcf436c8f --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/dmabuf_buffer.cc @@ -0,0 +1,173 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/core/dmabuf_buffer.h" + +#include +#include + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/const_init.h" +#include "absl/container/node_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" + +namespace litert { +namespace internal { + +namespace { + +class DmaBufLibrary { + public: + using Ptr = std::unique_ptr; + + ~DmaBufLibrary() { + if (allocator_) { + free_allocator_(allocator_); + } + } + + static absl::StatusOr Create() { + DlHandle dlhandle(::dlopen("libdmabufheap.so", RTLD_LAZY | RTLD_LOCAL), + ::dlclose); + if (!dlhandle) { + return absl::InternalError("libdmabufheap.so not found"); + } + + auto create_allocator = reinterpret_cast( + ::dlsym(dlhandle.get(), "CreateDmabufHeapBufferAllocator")); + if (!create_allocator) { + return absl::InternalError("CreateDmabufHeapBufferAllocator not found"); + } + + auto free_allocator = reinterpret_cast( + ::dlsym(dlhandle.get(), "FreeDmabufHeapBufferAllocator")); + if (!free_allocator) { + return absl::InternalError("FreeDmabufHeapBufferAllocator not found"); + } + + auto alloc_buffer = reinterpret_cast( + ::dlsym(dlhandle.get(), "DmabufHeapAlloc")); + if (!alloc_buffer) { + return absl::InternalError("DmabufHeapAlloc not found"); + } + + void* allocator = create_allocator(); + if (!allocator) { + return absl::InternalError("CreateDmabufHeapBufferAllocator failed"); + } + + return Ptr(new DmaBufLibrary(std::move(dlhandle), allocator, free_allocator, + alloc_buffer)); + } + + absl::StatusOr Alloc(size_t size) { + int fd = alloc_buffer_(allocator_, kDmaBufHeap, size, /*flags=*/0, + /*legacy_align=*/0); + if (fd < 0) { + return absl::InternalError("Failed to allocate DMA-BUF buffer"); + } + void* addr = + ::mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (addr == MAP_FAILED) { + return absl::InternalError("Failed to mem-map DMA-BUF buffer"); + } + records_[addr] = Record{.fd = fd, .addr = addr, .size = size}; + return DmaBufBuffer{.fd = fd, .addr = addr}; + } + + void Free(void* addr) { + auto iter = records_.find(addr); + if (iter == records_.end()) { + return; + } + auto& record = iter->second; + ::munmap(record.addr, record.size); + ::close(record.fd); + records_.erase(iter); + } + + private: + static constexpr const char* kDmaBufHeap = "system"; + + struct Record { + int fd; + void* addr; + size_t size; + }; + + using DlHandle = std::unique_ptr; + using CreateAllocator = void* (*)(); + using FreeAllocator = void (*)(void*); + using AllocBuffer = int (*)(void*, const char*, size_t, unsigned int, size_t); + + DmaBufLibrary(DlHandle&& dlhandle, void* allocator, + FreeAllocator free_allocator, AllocBuffer alloc_buffer) + : dlhandle_(std::move(dlhandle)) { + allocator_ = allocator; + free_allocator_ = free_allocator; + alloc_buffer_ = alloc_buffer; + } + + DlHandle dlhandle_; + void* allocator_; + FreeAllocator free_allocator_; + AllocBuffer alloc_buffer_; + absl::node_hash_map records_; +}; + +DmaBufLibrary* TheDmaBufLibrary; +ABSL_CONST_INIT absl::Mutex TheMutex(absl::kConstInit); + +absl::Status InitLibraryIfNeededUnlocked() { + if (!TheDmaBufLibrary) { + if (auto library = DmaBufLibrary::Create(); library.ok()) { + TheDmaBufLibrary = library->release(); + } else { + return library.status(); + } + } + return {}; +} + +} // namespace + +bool DmaBufBuffer::IsSupported() { + absl::MutexLock lock(&TheMutex); + auto status = InitLibraryIfNeededUnlocked(); + return status.ok(); +} + +absl::StatusOr DmaBufBuffer::Alloc(size_t size) { + absl::MutexLock lock(&TheMutex); + if (auto status = InitLibraryIfNeededUnlocked(); !status.ok()) { + return status; + } + return TheDmaBufLibrary->Alloc(size); +} + +void DmaBufBuffer::Free(void* addr) { + absl::MutexLock lock(&TheMutex); + if (TheDmaBufLibrary) { + TheDmaBufLibrary->Free(addr); + } +} + +} // namespace internal +} // namespace litert diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_common.cc b/tensorflow/lite/experimental/litert/core/dmabuf_buffer.h similarity index 54% rename from tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_common.cc rename to tensorflow/lite/experimental/litert/core/dmabuf_buffer.h index 0af9d776ffdd40..83d50c2101f472 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_common.cc +++ b/tensorflow/lite/experimental/litert/core/dmabuf_buffer.h @@ -12,21 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h" +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DMABUF_BUFFER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DMABUF_BUFFER_H_ -struct LrtStatusT { - LrtStatusCode code; - // TODO: b/365295276 - Implement error message payloads for lrt status. -}; +#include "absl/status/statusor.h" + +namespace litert { +namespace internal { -LrtStatusCode GetStatusCode(LrtStatus status) { return status->code; } +struct DmaBufBuffer { + int fd; + void* addr; -void StatusDestroy(LrtStatus status) { delete status; } + static bool IsSupported(); + static absl::StatusOr Alloc(size_t size); + static void Free(void* addr); +}; -LrtStatus StatusCreate(LrtStatusCode code) { - auto* res = new LrtStatusT; - res->code = code; - return res; -} +} // namespace internal +} // namespace litert -LrtStatus StatusOk() { return StatusCreate(kLrtStatusOk); } +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DMABUF_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/core/dynamic_loading.cc b/tensorflow/lite/experimental/litert/core/dynamic_loading.cc new file mode 100644 index 00000000000000..fbb002dc1aa035 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/dynamic_loading.cc @@ -0,0 +1,107 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/core/dynamic_loading.h" + +#include + +#ifndef __ANDROID__ +#include +#include +#endif + +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" + +namespace litert { + +LiteRtStatus OpenLib(absl::string_view so_path, void** lib_handle) { +#ifdef __ANDROID__ + void* res = ::dlopen(so_path.data(), RTLD_NOW | RTLD_LOCAL); +#else + void* res = ::dlopen(so_path.data(), RTLD_NOW | RTLD_LOCAL | RTLD_DEEPBIND); +#endif + + if (res == nullptr) { + LITERT_LOG(LITERT_ERROR, + "Failed to load .so at path: %s, with error:\n\t %s\n", + so_path.data(), ::dlerror()); + + return kLiteRtStatusErrorDynamicLoading; + } + *lib_handle = res; + return kLiteRtStatusOk; +} + +LiteRtStatus CloseLib(void* lib_handle) { + if (0 != ::dlclose(lib_handle)) { + LITERT_LOG(LITERT_ERROR, "Failed to close .so with error: %s", ::dlerror()); + return kLiteRtStatusErrorDynamicLoading; + } + return kLiteRtStatusOk; +} + +LiteRtStatus MakePluginLibGlobPattern(absl::string_view search_path, + std::string& pattern) { + LITERT_ENSURE(!search_path.ends_with("/"), kLiteRtStatusErrorInvalidArgument, + "Search paths must not have trailing slash"); + + // NOTE: Compiler plugin shared libraries also have "Plugin" appended + // to the standard prefix. + constexpr absl::string_view kGlobPluginLibTemplate = "%s/%sPlugin*.so"; + pattern = absl::StrFormat(kGlobPluginLibTemplate, search_path, + kLiteRtSharedLibPrefix); + return kLiteRtStatusOk; +} + +LiteRtStatus FindLiteRtSharedLibs(absl::string_view search_path, + std::vector& results) { +#ifndef __ANDROID__ + std::string glob_pattern; + LITERT_RETURN_STATUS_IF_NOT_OK( + MakePluginLibGlobPattern(search_path, glob_pattern)); + + glob_t glob_result = {}; + const int glob_status = + glob(glob_pattern.c_str(), GLOB_ERR, nullptr, &glob_result); + if (glob_status == GLOB_NOMATCH) { + LITERT_LOG(LITERT_WARNING, "%s", "Didn't find any plugin libs to load\n"); + globfree(&glob_result); + return kLiteRtStatusOk; + } else if (glob_status != 0) { + LITERT_LOG(LITERT_ERROR, "Glob failed with code: %d\n", glob_status); + globfree(&glob_result); + return kLiteRtStatusErrorNotFound; + } + + for (size_t i = 0; i < glob_result.gl_pathc; ++i) { + results.emplace_back().assign(glob_result.gl_pathv[i]); + LITERT_LOG(LITERT_INFO, "Glob matched: %s\n", results.back().c_str()); + } + + globfree(&glob_result); + return kLiteRtStatusOk; +#endif + // TODO: Glob is not supported on android. + return kLiteRtStatusErrorUnsupported; +} + +} // namespace litert diff --git a/tensorflow/lite/experimental/litert/core/dynamic_loading.h b/tensorflow/lite/experimental/litert/core/dynamic_loading.h new file mode 100644 index 00000000000000..4636187e779456 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/dynamic_loading.h @@ -0,0 +1,57 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DYNAMIC_LOADING_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DYNAMIC_LOADING_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" + +namespace litert { + +constexpr absl::string_view kLiteRtSharedLibPrefix = "libLiteRt"; + +// Loads shared library at given path. +LiteRtStatus OpenLib(absl::string_view so_path, void** lib_handle); + +// Closes reference to loaded shared library held by lib_handle. +LiteRtStatus CloseLib(void* lib_handle); + +// Resolves a named symbol from given lib handle of type Sym. +template +inline static LiteRtStatus ResolveLibSymbol(void* lib_handle, + absl::string_view sym_name, + Sym* sym_handle) { + Sym ptr = (Sym)::dlsym(lib_handle, sym_name.data()); + if (ptr == nullptr) { + LITERT_LOG(LITERT_ERROR, "Faild to resolve symbol: %s, with err: %s\n", + sym_name, ::dlerror()); + return kLiteRtStatusErrorDynamicLoading; + } + *sym_handle = ptr; + return kLiteRtStatusOk; +} + +// All internal dynamically linked dependencies for litert should be prefixed +// "libLiteRt". Find all litert shared libraries in "search_path" +LiteRtStatus FindLiteRtSharedLibs(absl::string_view search_path, + std::vector& results); + +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_DYNAMIC_LOADING_H_ diff --git a/tensorflow/lite/experimental/litert/core/dynamic_loading_test.cc b/tensorflow/lite/experimental/litert/core/dynamic_loading_test.cc new file mode 100644 index 00000000000000..9800e845f7f661 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/dynamic_loading_test.cc @@ -0,0 +1,67 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/core/dynamic_loading.h" + +#include +#include + +#include +#include +#include "testing/base/public/unique-test-directory.h" +#include "absl/strings/string_view.h" +#include "tensorflow/lite/experimental/litert/test/common.h" + +namespace { + +using ::litert::testing::TouchTestFile; + +constexpr absl::string_view kNotLiteRtSo = "notLibLiteRt.so"; +constexpr absl::string_view kLiteRtSo1 = "libLiteRtPlugin_1.so"; +constexpr absl::string_view kLiteRtSo2 = "libLiteRtPlugin_2.so"; + +TEST(TestDynamicLoading, GlobNoMatch) { + const auto dir = testing::UniqueTestDirectory(); + TouchTestFile(kNotLiteRtSo, dir); + + std::vector results; + ASSERT_STATUS_OK(litert::FindLiteRtSharedLibs(dir, results)); + EXPECT_EQ(results.size(), 0); +} + +TEST(TestDynamicLoading, GlobOneMatch) { + const auto dir = testing::UniqueTestDirectory(); + TouchTestFile(kLiteRtSo1, dir); + TouchTestFile(kNotLiteRtSo, dir); + + std::vector results; + ASSERT_STATUS_OK(litert::FindLiteRtSharedLibs(dir, results)); + EXPECT_EQ(results.size(), 1); + EXPECT_TRUE(absl::string_view(results.front()).ends_with(kLiteRtSo1)); +} + +TEST(TestDynamicLoading, GlobMultiMatch) { + const auto dir = testing::UniqueTestDirectory(); + TouchTestFile(kLiteRtSo1, dir); + TouchTestFile(kLiteRtSo2, dir); + TouchTestFile(kNotLiteRtSo, dir); + + std::vector results; + ASSERT_STATUS_OK(litert::FindLiteRtSharedLibs(dir, results)); + EXPECT_EQ(results.size(), 2); + EXPECT_THAT(results, testing::Contains(testing::HasSubstr(kLiteRtSo1))); + EXPECT_THAT(results, testing::Contains(testing::HasSubstr(kLiteRtSo2))); +} + +} // namespace diff --git a/tensorflow/lite/experimental/litert/core/event.cc b/tensorflow/lite/experimental/litert/core/event.cc new file mode 100644 index 00000000000000..26723f428018c2 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/event.cc @@ -0,0 +1,71 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/core/event.h" + +#include +#include +#include + +#include +#include +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" + +LiteRtStatus LiteRtEventT::Wait(int64_t timeout_in_ms) { +#if LITERT_HAS_SYNC_FENCE_SUPPORT + struct pollfd fds = { + .fd = fd, + .events = POLLIN, + }; + + int ret; + do { + ret = ::poll(&fds, 1, timeout_in_ms); + if (ret == 1) { + break; + } else if (ret == 0) { + LITERT_LOG(LITERT_WARNING, "Timeout expired: %d", timeout_in_ms); + return kLiteRtStatusErrorTimeoutExpired; + } + } while (ret == -1 && (errno == EINTR || errno == EAGAIN)); + + if (ret < 0) { + LITERT_LOG(LITERT_ERROR, "Error waiting for fence: %s", ::strerror(errno)); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; + +#else + LITERT_LOG(LITERT_ERROR, "LiteRtEventWait not implemented for this platform"); + return kLiteRtStatusErrorUnsupported; +#endif +} + +namespace { +inline bool IsFdValid(int fd) { + return ::fcntl(fd, F_GETFD) != -1 || errno != EBADF; +} +} // namespace + +LiteRtEventT::~LiteRtEventT() { +#if LITERT_HAS_SYNC_FENCE_SUPPORT + if (owns_fd && IsFdValid(fd)) { + ::close(fd); + } +#endif +} diff --git a/tensorflow/lite/experimental/litert/core/event.h b/tensorflow/lite/experimental/litert/core/event.h new file mode 100644 index 00000000000000..40e46e08da5124 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/event.h @@ -0,0 +1,31 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_EVENT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_EVENT_H_ + +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" + +struct LiteRtEventT { +#if LITERT_HAS_SYNC_FENCE_SUPPORT + int fd; + bool owns_fd; +#endif + ~LiteRtEventT(); + LiteRtStatus Wait(int64_t timeout_in_ms); +}; + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_EVENT_H_ diff --git a/tensorflow/lite/experimental/litert/core/fastrpc_buffer.cc b/tensorflow/lite/experimental/litert/core/fastrpc_buffer.cc new file mode 100644 index 00000000000000..6f6d0c61c63f95 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/fastrpc_buffer.cc @@ -0,0 +1,139 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/core/fastrpc_buffer.h" + +#include + +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/const_init.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" + +namespace litert { +namespace internal { + +namespace { + +class FastRpcMemLibrary { + public: + using Ptr = std::unique_ptr; + + static absl::StatusOr Create() { + DlHandle dlhandle(::dlopen("libcdsprpc.so", RTLD_NOW | RTLD_LOCAL), + ::dlclose); + if (!dlhandle) { + return absl::InternalError("libcdsprpc.so not found"); + } + + auto rpcmem_alloc = + reinterpret_cast(::dlsym(dlhandle.get(), "rpcmem_alloc")); + if (!rpcmem_alloc) { + return absl::InternalError("rpcmem_alloc not found"); + } + + auto rpcmem_free = + reinterpret_cast(::dlsym(dlhandle.get(), "rpcmem_free")); + if (!rpcmem_free) { + return absl::InternalError("rpcmem_free not found"); + } + + auto rpcmem_to_fd = + reinterpret_cast(::dlsym(dlhandle.get(), "rpcmem_to_fd")); + if (!rpcmem_to_fd) { + return absl::InternalError("rpcmem_to_fd not found"); + } + + return Ptr(new FastRpcMemLibrary(std::move(dlhandle), rpcmem_alloc, + rpcmem_free, rpcmem_to_fd)); + } + + void* Alloc(size_t size) const { + return rpcmem_alloc_(kRpcmemHeapIdSystem, kRpcmemDefaultFlags, size); + } + + void Free(void* buffer) const { return rpcmem_free_(buffer); } + + int ToFd(void* buffer) const { return rpcmem_to_fd_(buffer); } + + private: + static constexpr int kRpcmemHeapIdSystem = 25; + static constexpr uint32_t kRpcmemDefaultFlags = 1; + + using DlHandle = std::unique_ptr; + using RpcMemAlloc = void* (*)(int, uint32_t, int); + using RpcMemFree = void (*)(void*); + using RpcMemToFd = int (*)(void*); + + FastRpcMemLibrary(DlHandle&& dlhandle, RpcMemAlloc rpcmem_alloc, + RpcMemFree rpcmem_free, RpcMemToFd rpcmem_to_fd) + : dlhandle_(std::move(dlhandle)) { + rpcmem_alloc_ = rpcmem_alloc; + rpcmem_free_ = rpcmem_free; + rpcmem_to_fd_ = rpcmem_to_fd; + } + + DlHandle dlhandle_; + RpcMemAlloc rpcmem_alloc_; + RpcMemFree rpcmem_free_; + RpcMemToFd rpcmem_to_fd_; +}; + +FastRpcMemLibrary* TheFastRpcMemLibrary; +ABSL_CONST_INIT absl::Mutex TheMutex(absl::kConstInit); + +absl::Status InitLibraryIfNeededUnlocked() { + if (!TheFastRpcMemLibrary) { + if (auto library = FastRpcMemLibrary::Create(); library.ok()) { + TheFastRpcMemLibrary = library->release(); + } else { + return library.status(); + } + } + return {}; +} + +} // namespace + +bool FastRpcBuffer::IsSupported() { + absl::MutexLock lock(&TheMutex); + auto status = InitLibraryIfNeededUnlocked(); + return status.ok(); +} + +absl::StatusOr FastRpcBuffer::Alloc(size_t size) { + absl::MutexLock lock(&TheMutex); + if (auto status = InitLibraryIfNeededUnlocked(); !status.ok()) { + return status; + } + void* addr = TheFastRpcMemLibrary->Alloc(size); + int fd = TheFastRpcMemLibrary->ToFd(addr); + return FastRpcBuffer{.fd = fd, .addr = addr}; +} + +void FastRpcBuffer::Free(void* addr) { + absl::MutexLock lock(&TheMutex); + if (TheFastRpcMemLibrary) { + TheFastRpcMemLibrary->Free(addr); + } +} + +} // namespace internal +} // namespace litert diff --git a/tensorflow/lite/experimental/litert/core/fastrpc_buffer.h b/tensorflow/lite/experimental/litert/core/fastrpc_buffer.h new file mode 100644 index 00000000000000..dcdbcaaa7b6e68 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/fastrpc_buffer.h @@ -0,0 +1,35 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_FASTRPC_BUFFER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_FASTRPC_BUFFER_H_ + +#include "absl/status/statusor.h" + +namespace litert { +namespace internal { + +struct FastRpcBuffer { + int fd; + void* addr; + + static bool IsSupported(); + static absl::StatusOr Alloc(size_t size); + static void Free(void* addr); +}; + +} // namespace internal +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_FASTRPC_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/core/graph_tools.h b/tensorflow/lite/experimental/litert/core/graph_tools.h new file mode 100644 index 00000000000000..bd4f171e0b73f2 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/graph_tools.h @@ -0,0 +1,396 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_GRAPH_TOOLS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_GRAPH_TOOLS_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/core/util/buffer_ref.h" + +#ifndef NDEBUG +#endif + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" + +#define _D_MATCH_TRUE(v) \ + { \ + if (!(v)) { \ + LITERT_LOG(LITERT_ERROR, "Failed MATCH_TRUE"); \ + return false; \ + } \ + } + +#define _D_MATCH_EQ(lhs, rhs) \ + { \ + if (lhs != rhs) { \ + LITERT_LOG(LITERT_ERROR, "Failed MATCH_EQ"); \ + return false; \ + } \ + } + +#define _MATCH_TRUE(v) \ + { \ + if (!(v)) return false; \ + } + +#define _MATCH_EQ(lhs, rhs) \ + { \ + if (lhs != rhs) return false; \ + } +#ifndef NDEBUG +#define MATCH_EQ(lhs, rhs) _D_MATCH_EQ(lhs, rhs) +#define MATCH_TRUE(v) _D_MATCH_TRUE(v) +#else +#define MATCH_EQ(lhs, rhs) _MATCH_EQ(lhs, rhs) +#define MATCH_TRUE(v) _MATCH_TRUE(v) +#endif + +namespace graph_tools { + +using RankedTypeInfo = std::tuple>; + +using TensorUseInfo = std::tuple; +using ::litert::BufferRef; + +//===----------------------------------------------------------------------===// +// Getters // +//===----------------------------------------------------------------------===// + +// TODO: b/365299994 - Switch llvm container types for mobile friendly ones. +// Likely will need to define them. + +// Get the ops that reference given tensor. +inline LiteRtResult> GetTensorUses( + LiteRtTensor tensor) { + LiteRtParamIndex num_uses; + LiteRtParamIndex* use_user_arg_ind; + LiteRtOpArray users = nullptr; + + LITERT_RETURN_RESULT_IF_NOT_OK( + GetTensorUses(tensor, &num_uses, &users, &use_user_arg_ind), + llvm::SmallVector); + + llvm::ArrayRef users_arr(users, num_uses); + llvm::ArrayRef user_arg_ind_arr(use_user_arg_ind, num_uses); + + auto results = llvm::zip(users_arr, user_arg_ind_arr); + llvm::SmallVector results_vec(results.begin(), results.end()); + + return LiteRtResult>::FromValue(results_vec); +} + +// Get the only user of given tensor, bad status if tensor doesn't have +// exactly one user. +inline LiteRtResult GetTensorOnlyUse(LiteRtTensor tensor) { + LITERT_ASSIGN_OR_RETURN_RESULT(auto uses, GetTensorUses(tensor), + TensorUseInfo); + if (uses.size() != 1) { + return LiteRtResult::FromStatus( + kLiteRtStatusErrorInvalidGraphInvariant); + } + return LiteRtResult::FromValue(uses[0]); +} + +// Get tensor inputs to given op. +inline LiteRtResult> GetOpIns(LiteRtOp op) { + LiteRtParamIndex num_inputs; + LiteRtTensorArray inputs = nullptr; + + LITERT_RETURN_RESULT_IF_NOT_OK(GetOpInputs(op, &num_inputs, &inputs), + llvm::ArrayRef); + + return LiteRtResult>::FromValue( + llvm::ArrayRef(inputs, num_inputs)); +} + +// Get the only tensor input to given op, bad status if op doesn't have +// exacty one input. +inline LiteRtResult GetOnlyOpIn(LiteRtOp op) { + LITERT_ASSIGN_OR_RETURN_RESULT(auto ins, GetOpIns(op), LiteRtTensor); + if (ins.size() != 1) { + return LiteRtResult::FromStatus( + kLiteRtStatusErrorInvalidGraphInvariant); + } + return LiteRtResult::FromValue(ins[0]); +} + +// Get tensors outputs to given op. +inline LiteRtResult> GetOpOuts(LiteRtOp op) { + LiteRtParamIndex num_outputs; + LiteRtTensorArray outputs = nullptr; + + LITERT_RETURN_RESULT_IF_NOT_OK(GetOpOutputs(op, &num_outputs, &outputs), + llvm::ArrayRef); + + return LiteRtResult>::FromValue( + llvm::ArrayRef(outputs, num_outputs)); +} + +// Get the only tensor output to given op, bad status if op doesn't have +// exactly one output. +inline LiteRtResult GetOnlyOpOut(LiteRtOp op) { + LITERT_ASSIGN_OR_RETURN_RESULT(auto outs, GetOpOuts(op), LiteRtTensor); + if (outs.size() != 1) { + return LiteRtResult::FromStatus( + kLiteRtStatusErrorInvalidGraphInvariant); + } + return LiteRtResult::FromValue(outs[0]); +} + +// Get all ops in given subgraph in topological order. +inline LiteRtResult> GetSubgraphOps( + LiteRtSubgraph subgraph) { + LiteRtParamIndex num_ops; + LiteRtOpArray ops = nullptr; + LITERT_RETURN_RESULT_IF_NOT_OK(GetSubgraphOps(subgraph, &num_ops, &ops), + llvm::ArrayRef); + + return LiteRtResult>::FromValue( + llvm::ArrayRef(ops, num_ops)); +} + +// Get tensor inputs to given subgraph. +inline LiteRtResult> GetSubgraphInputs( + LiteRtSubgraph subgraph) { + LiteRtParamIndex num_inputs; + LiteRtTensorArray inputs = nullptr; + LITERT_RETURN_RESULT_IF_NOT_OK( + GetSubgraphInputs(subgraph, &num_inputs, &inputs), + llvm::ArrayRef); + + return LiteRtResult>::FromValue( + llvm::ArrayRef(inputs, num_inputs)); +} + +// Get tensor outputs to given subgraph. +inline LiteRtResult> GetSubgraphOutputs( + LiteRtSubgraph subgraph) { + LiteRtParamIndex num_outputs; + LiteRtTensorArray outputs = nullptr; + LITERT_RETURN_RESULT_IF_NOT_OK( + GetSubgraphOutputs(subgraph, &num_outputs, &outputs), + llvm::ArrayRef); + + return LiteRtResult>::FromValue( + llvm::ArrayRef(outputs, num_outputs)); +} + +// Get only subgraph in given model, bad status if model doesn't have exactly +// one subgraph. +// TODO: b/365299994 - Add multi-subgraph getters for graph tools. +inline LiteRtResult GetSubgraph(LiteRtModel model) { + LiteRtParamIndex num_subgraphs; + LITERT_RETURN_RESULT_IF_NOT_OK(GetModelNumSubgraphs(model, &num_subgraphs), + LiteRtSubgraph); + + if (num_subgraphs != 1) { + return LiteRtResult::FromStatus( + kLiteRtStatusErrorUnsupported); + } + + LiteRtSubgraph subgraph = nullptr; + LITERT_RETURN_RESULT_IF_NOT_OK(GetModelSubgraph(model, 0, &subgraph), + LiteRtSubgraph); + + return LiteRtResult::FromValue(subgraph); +} + +// Get raw metadata buffer from model if it exists. +inline LiteRtResult> GetMetadata( + LiteRtModel model, const absl::string_view key) { + using ResT = LiteRtResult>; + const uint8_t* buf; + size_t size; + LITERT_RETURN_RESULT_IF_NOT_OK( + LiteRtModelGetMetadata(model, key.data(), + reinterpret_cast(&buf), &size), + BufferRef); + return ResT::FromValue(BufferRef(buf, size)); +} + +//===----------------------------------------------------------------------===// +// Matchers // +//===----------------------------------------------------------------------===// + +// Matches tensor type id, shape and element type for given tensor. +inline bool MatchRankedTensorType(LiteRtTensor tensor, + LiteRtElementType element_type, + llvm::ArrayRef shape) { + LiteRtTensorTypeId type_id; + LITERT_RETURN_VAL_IF_NOT_OK(GetTensorTypeId(tensor, &type_id), false); + MATCH_EQ(type_id, kLiteRtRankedTensorType); + + LiteRtRankedTensorType ranked_tensor_type; + LITERT_RETURN_VAL_IF_NOT_OK(GetRankedTensorType(tensor, &ranked_tensor_type), + false); + MATCH_EQ(ranked_tensor_type.element_type, element_type); + MATCH_EQ(ranked_tensor_type.layout.rank, shape.size()); + + for (int i = 0; i < shape.size(); ++i) { + MATCH_EQ(shape[i], ranked_tensor_type.layout.dimensions[i]); + } + + return true; +} + +// Matches users of given tensor (ordering doesn't matter). If strict is true, +// `use_info` must have same number of elements as tensor has uses. If not, +// it must be a subset. +inline bool MatchTensorHasUses(LiteRtTensor tensor, + llvm::ArrayRef use_info, + bool strict = true) { + // uses are unique so this is sufficient to check for equality. + LITERT_ASSIGN_OR_RETURN_VAL(auto uses, GetTensorUses(tensor), false); + MATCH_TRUE(!strict || (uses.size() == use_info.size())); + + llvm::SetVector unique_uses(uses.begin(), uses.end()); + + return llvm::all_of(use_info, + [&](auto use) { return unique_uses.contains(use); }); +} + +// Matches a tensor with no uses. +inline bool MatchTensorNoUses(LiteRtTensor tensor) { + LiteRtParamIndex num_uses; + LiteRtParamIndex* use_user_arg_ind; + LiteRtOpArray users = nullptr; + + LITERT_RETURN_VAL_IF_NOT_OK( + GetTensorUses(tensor, &num_uses, &users, &use_user_arg_ind), false); + + return num_uses == 0; +} + +// Matches a tensors defining op and output indice. +inline bool MatchTensorDefiningOp(LiteRtTensor tensor, + LiteRtParamIndex expected_defining_op_out_ind, + LiteRtOp expected_defining_op) { + LiteRtOp defining_op = nullptr; + LiteRtParamIndex defining_op_out_ind; + + LITERT_RETURN_VAL_IF_NOT_OK( + GetTensorDefiningOp(tensor, &defining_op, &defining_op_out_ind), false); + MATCH_EQ(defining_op, expected_defining_op); + + return expected_defining_op == nullptr || + expected_defining_op_out_ind == defining_op_out_ind; +} + +// Matches a tensor that is not the output of an op (subgraph inputs/consts). +inline bool MatchTensorNoDefiningOp(LiteRtTensor tensor) { + return MatchTensorDefiningOp(tensor, 0, nullptr); +} + +// Matches the op code and types of given ops inputs and outputs. +inline bool MatchOpType(LiteRtOp op, + llvm::ArrayRef input_type_info, + llvm::ArrayRef output_type_info, + LiteRtOpCode code) { + LiteRtOpCode actual_code; + LITERT_RETURN_VAL_IF_NOT_OK(GetOpCode(op, &actual_code), false); + MATCH_EQ(actual_code, code); + + const auto exptected_num_inputs = input_type_info.size(); + + LITERT_ASSIGN_OR_RETURN_VAL(auto inputs, GetOpIns(op), false); + for (int i = 0; i < exptected_num_inputs; ++i) { + const auto& [type, shape] = input_type_info[i]; + MATCH_TRUE(MatchRankedTensorType(inputs[i], type, shape)); + } + + const auto expected_num_outputs = output_type_info.size(); + + LITERT_ASSIGN_OR_RETURN_VAL(auto outputs, GetOpOuts(op), false); + for (int i = 0; i < expected_num_outputs; ++i) { + const auto& [type, shape] = output_type_info[i]; + MATCH_TRUE(MatchRankedTensorType(outputs[i], type, shape)); + } + + return true; +} + +// Checks that doubly linked structure of ops <-> tensors is valid. +inline bool ValidateTopology(llvm::ArrayRef ops) { + for (auto& op : ops) { + LITERT_ASSIGN_OR_RETURN_VAL(auto inputs, GetOpIns(op), false); + for (auto [input_ind, input] : llvm::enumerate(inputs)) { + MATCH_TRUE(MatchTensorHasUses(input, {{op, input_ind}}, false)); + } + + LITERT_ASSIGN_OR_RETURN_VAL(auto outputs, GetOpOuts(op), false); + for (auto [output_ind, output] : llvm::enumerate(outputs)) { + MATCH_TRUE(MatchTensorDefiningOp(output, output_ind, op)); + } + } + return true; +} + +// Get weights behind given tensor. +template +inline LiteRtResult> GetWeights(LiteRtTensor tensor) { + LiteRtWeights weights = nullptr; + LITERT_RETURN_RESULT_IF_NOT_OK(GetTensorWeights(tensor, &weights), + llvm::ArrayRef); + size_t size; + const void* data = nullptr; + LITERT_RETURN_RESULT_IF_NOT_OK(GetWeightsInfo(weights, &size, &data), + llvm::ArrayRef); + return LiteRtResult>::FromValue( + llvm::ArrayRef(static_cast(data), size)); +} + +// Match weights behind given tensor contains data. +template +inline bool MatchWeights(LiteRtTensor tensor, llvm::ArrayRef expected_data) { + LiteRtWeights weights = nullptr; + LITERT_RETURN_VAL_IF_NOT_OK(GetTensorWeights(tensor, &weights), false); + MATCH_TRUE(weights != nullptr); + + size_t size; + const void* data = nullptr; + LITERT_RETURN_VAL_IF_NOT_OK(GetWeightsInfo(weights, &size, &data), false); + MATCH_TRUE(data != nullptr); + + MATCH_EQ(size, expected_data.size() * sizeof(T)); + return llvm::ArrayRef(static_cast(data), expected_data.size()) == + expected_data; +} + +// Match given tensor having no (empty) weights. +inline bool MatchNoWeights(LiteRtTensor tensor) { + LiteRtWeights weights = nullptr; + LITERT_RETURN_VAL_IF_NOT_OK(GetTensorWeights(tensor, &weights), false); + MATCH_TRUE(weights != nullptr); + + size_t size; + const void* data = nullptr; + LITERT_RETURN_VAL_IF_NOT_OK(GetWeightsInfo(weights, &size, &data), false); + + return size == 0; +} +} // namespace graph_tools + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_GRAPH_TOOLS_H_ diff --git a/tensorflow/lite/experimental/litert/core/ion_buffer.cc b/tensorflow/lite/experimental/litert/core/ion_buffer.cc new file mode 100644 index 00000000000000..bbaf7978dc6985 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/ion_buffer.cc @@ -0,0 +1,175 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/core/ion_buffer.h" + +#include +#include + +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/const_init.h" +#include "absl/container/node_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" + +namespace litert { +namespace internal { + +namespace { + +class IonLibrary { + public: + using Ptr = std::unique_ptr; + + ~IonLibrary() { + if (client_fd_ > 0) { + ion_close_(client_fd_); + } + } + + static absl::StatusOr Create() { + DlHandle dlhandle(::dlopen("libion.so", RTLD_NOW | RTLD_LOCAL), ::dlclose); + if (!dlhandle) { + return absl::InternalError("libion.so not found"); + } + + auto ion_open = + reinterpret_cast(::dlsym(dlhandle.get(), "ion_open")); + if (!ion_open) { + return absl::InternalError("ion_open not found"); + } + + auto ion_close = + reinterpret_cast(::dlsym(dlhandle.get(), "ion_close")); + if (!ion_close) { + return absl::InternalError("ion_close not found"); + } + + auto ion_alloc_fd = + reinterpret_cast(::dlsym(dlhandle.get(), "ion_alloc_fd")); + if (!ion_alloc_fd) { + return absl::InternalError("ion_alloc_fd not found"); + } + + int client_fd = ion_open(); + if (client_fd < 0) { + return absl::InternalError("Failed to open ion device"); + } + + return Ptr(new IonLibrary(std::move(dlhandle), client_fd, ion_close, + ion_alloc_fd)); + } + + absl::StatusOr Alloc(size_t size, size_t alignment) { + int heap_id_mask = 1 << kIonHeapId; + int fd; + if (auto status = ion_alloc_fd_(client_fd_, size, alignment, heap_id_mask, + kIonFlags, &fd); + status != 0) { + return absl::InternalError("Failed to allocate DMA-BUF buffer"); + } + void* addr = + ::mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (addr == MAP_FAILED) { + return absl::InternalError("Failed to mem-map DMA-BUF buffer"); + } + records_[addr] = Record{.fd = fd, .addr = addr, .size = size}; + return IonBuffer{.fd = fd, .addr = addr}; + } + + void Free(void* addr) { + auto iter = records_.find(addr); + if (iter == records_.end()) { + return; + } + auto& record = iter->second; + ::munmap(record.addr, record.size); + ::close(record.fd); + records_.erase(iter); + } + + private: + static constexpr const int kIonHeapId = 25; + static constexpr const int kIonFlags = 1; + + struct Record { + int fd; + void* addr; + size_t size; + }; + + using DlHandle = std::unique_ptr; + using IonOpen = int (*)(); + using IonClose = int (*)(int); + using IonAllocFd = int (*)(int, size_t, size_t, unsigned int, unsigned int, + int*); + + IonLibrary(DlHandle&& dlhandle, int client_fd, IonClose ion_close, + IonAllocFd ion_alloc_fd) + : dlhandle_(std::move(dlhandle)), + client_fd_(client_fd), + ion_close_(ion_close), + ion_alloc_fd_(ion_alloc_fd) {} + + DlHandle dlhandle_; + int client_fd_; + IonClose ion_close_; + IonAllocFd ion_alloc_fd_; + absl::node_hash_map records_; +}; + +IonLibrary* TheIonLibrary; +ABSL_CONST_INIT absl::Mutex TheMutex(absl::kConstInit); + +absl::Status InitLibraryIfNeededUnlocked() { + if (!TheIonLibrary) { + if (auto library = IonLibrary::Create(); library.ok()) { + TheIonLibrary = library->release(); + } else { + return library.status(); + } + } + return {}; +} + +} // namespace + +bool IonBuffer::IsSupported() { + absl::MutexLock lock(&TheMutex); + auto status = InitLibraryIfNeededUnlocked(); + return status.ok(); +} + +absl::StatusOr IonBuffer::Alloc(size_t size, size_t alignment) { + absl::MutexLock lock(&TheMutex); + if (auto status = InitLibraryIfNeededUnlocked(); !status.ok()) { + return status; + } + return TheIonLibrary->Alloc(size, alignment); +} + +void IonBuffer::Free(void* addr) { + absl::MutexLock lock(&TheMutex); + if (TheIonLibrary) { + TheIonLibrary->Free(addr); + } +} + +} // namespace internal +} // namespace litert diff --git a/tensorflow/lite/experimental/litert/core/ion_buffer.h b/tensorflow/lite/experimental/litert/core/ion_buffer.h new file mode 100644 index 00000000000000..f5b74ce82ac36d --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/ion_buffer.h @@ -0,0 +1,35 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_ION_BUFFER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_ION_BUFFER_H_ + +#include "absl/status/statusor.h" + +namespace litert { +namespace internal { + +struct IonBuffer { + int fd; + void* addr; + + static bool IsSupported(); + static absl::StatusOr Alloc(size_t size, size_t alignment); + static void Free(void* addr); +}; + +} // namespace internal +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_ION_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/core/litert_common.cc b/tensorflow/lite/experimental/litert/core/litert_common.cc new file mode 100644 index 00000000000000..b29ce8b30d65d4 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/litert_common.cc @@ -0,0 +1,15 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_model_init.cc b/tensorflow/lite/experimental/litert/core/litert_model_init.cc similarity index 50% rename from tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_model_init.cc rename to tensorflow/lite/experimental/litert/core/litert_model_init.cc index 4d9d94f18de0fd..369a1ab30c0495 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_model_init.cc +++ b/tensorflow/lite/experimental/litert/core/litert_model_init.cc @@ -17,173 +17,172 @@ #define FLATBUFFERS_DEBUG_VERIFICATION_FAILURE #include "flatbuffers/flatbuffers.h" // from @flatbuffers // IWYU pragma: keep +#include "tensorflow/lite/experimental/litert/core/util/buffer_ref.h" #endif #include #include #include #include +#include #include #include #include -#include "absl/log/check.h" -#include "flatbuffers/verifier.h" // from @flatbuffers +#include "absl/container/flat_hash_map.h" +#include "absl/log/absl_check.h" +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "tensorflow/compiler/mlir/lite/allocation.h" #include "tensorflow/compiler/mlir/lite/core/model_builder_base.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_model_init.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/model.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/core/litert_model_init.h" +#include "tensorflow/lite/experimental/litert/core/model.h" +#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/stderr_reporter.h" -// NOLINTBEGIN -void SetFbVerifyOptions(flatbuffers::Verifier::Options& opts) { -#ifndef NDEBUG - opts.assert = true; -#endif -} +using ::litert::OwningBufferRef; +using ::litert::internal::VerifyFlatbuffer; -LrtStatus VerifyFlatbuffer(const uint8_t* buf, size_t buf_size) { - // TODO: b/365299994 - If buffer verification is slow, run only in debug. - // Also check file size. - flatbuffers::Verifier::Options options; - SetFbVerifyOptions(options); - flatbuffers::Verifier verifier(buf, buf_size, options); - if (!tflite::VerifyModelBuffer(verifier)) { - _LRT_D_MSG("Failed to verify fb"); - return StatusCreate(kLrtStatusFlatbufferFailedVerify); - } - return StatusOk(); -} - -LrtStatus IsOpSupported(const tflite::OperatorT& op) { +static LiteRtStatus IsOpSupported(const tflite::OperatorT& op) { // TODO: b/365299994 - Check for supported options. if (!op.custom_options.empty()) { // TODO: b/365299994 - Support custom options. - _LRT_D_MSG("Custom options not supported."); - return StatusCreate(kLrtStatusErrorUnsupported); + _LITERT_D_MSG("Custom options not supported."); + return kLiteRtStatusErrorUnsupported; } if (!op.intermediates.empty()) { // TODO: b/365299994 - Support intermediates. - _LRT_D_MSG("Intermediate tensors not supported."); - return StatusCreate(kLrtStatusErrorUnsupported); + _LITERT_D_MSG("Intermediate tensors not supported."); + return kLiteRtStatusErrorUnsupported; } if (op.large_custom_options_size != 0) { // TODO: b/365299994 - Support large custom options. - _LRT_D_MSG("Large custom options not supported."); - return StatusCreate(kLrtStatusErrorUnsupported); + _LITERT_D_MSG("Large custom options not supported."); + return kLiteRtStatusErrorUnsupported; } for (auto m_input : op.mutating_variable_inputs) { if (m_input) { // TODO: b/365299994 - Support mutating variable inputs. - _LRT_D_MSG("Mutating variable inputs not supported."); - return StatusCreate(kLrtStatusErrorUnsupported); + _LITERT_D_MSG("Mutating variable inputs not supported."); + return kLiteRtStatusErrorUnsupported; } } - return StatusOk(); + return kLiteRtStatusOk; } -LrtStatus IsBufferSupported(const tflite::BufferT& buffer) { +static LiteRtStatus IsBufferSupported(const tflite::BufferT& buffer) { if (buffer.offset != 0) { // TODO: b/365299994 - Support buffer with offset. - _LRT_D_MSG("Buffers with offset not supported."); - return StatusCreate(kLrtStatusErrorUnsupported); + _LITERT_D_MSG("Buffers with offset not supported."); + return kLiteRtStatusErrorUnsupported; } - return StatusOk(); + return kLiteRtStatusOk; } -LrtStatus IsTensorSupported(const tflite::TensorT& tensor) { +static LiteRtStatus IsTensorSupported(const tflite::TensorT& tensor) { if (!tensor.has_rank) { // TODO: b/365299994 - Support unranked tensors. - _LRT_D_MSG("Unranked tensors not supported."); - return StatusCreate(kLrtStatusErrorUnsupported); + _LITERT_D_MSG("Unranked tensors not supported."); + return kLiteRtStatusErrorUnsupported; } if (tensor.is_variable) { // TODO: b/365299994 - Support variable tensors. - _LRT_D_MSG("Variable tensors not supported."); - return StatusCreate(kLrtStatusErrorUnsupported); + _LITERT_D_MSG("Variable tensors not supported."); + return kLiteRtStatusErrorUnsupported; } if (!tensor.variant_tensors.empty()) { // TODO: b/365299994 - Support variant tensors. - _LRT_D_MSG("Variant tensors not supported."); - return StatusCreate(kLrtStatusErrorUnsupported); + _LITERT_D_MSG("Variant tensors not supported."); + return kLiteRtStatusErrorUnsupported; } if (!tensor.shape_signature.empty()) { // TODO: b/365299994 - Support shape signature. - _LRT_D_MSG("Shape signature not supported."); - return StatusCreate(kLrtStatusErrorUnsupported); + _LITERT_D_MSG("Shape signature not supported."); + return kLiteRtStatusErrorUnsupported; } if (tensor.sparsity) { // TODO: b/365299994 - Support sparsity tensors. - _LRT_D_MSG("Sparsity tensors not supported."); - return StatusCreate(kLrtStatusErrorUnsupported); + _LITERT_D_MSG("Sparsity tensors not supported."); + return kLiteRtStatusErrorUnsupported; } - if (tensor.type != tflite::TensorType_FLOAT32) { + if (tensor.type != tflite::TensorType_FLOAT32 && + tensor.type != tflite::TensorType_INT32 && + tensor.type != tflite::TensorType_BOOL) { // TODO: b/365299994 - Support all element types. - _LRT_D_MSG("Only f32 supported."); - return StatusCreate(kLrtStatusErrorUnsupported); + _LITERT_D_MSG("Only f32 supported."); + return kLiteRtStatusErrorUnsupported; } - return StatusOk(); + return kLiteRtStatusOk; } -LrtStatus SetDefaultOptions(tflite::BuiltinOptionsUnion& opts, LrtOpCode code) { +static LiteRtStatus SetDefaultOptions(tflite::BuiltinOptionsUnion& opts, + LiteRtOpCode code) { switch (code) { - case kLrtOpCodeTflMul: + case kLiteRtOpCodeTflMul: opts.Set(tflite::MulOptionsT()); break; - case kLrtOpCodeTflAdd: + case kLiteRtOpCodeTflAdd: opts.Set(tflite::AddOptionsT()); break; - case kLrtOpCodeTflCustom: - return StatusOk(); + case kLiteRtOpCodeTflCustom: + return kLiteRtStatusOk; default: - return StatusCreate(kLrtStatusErrorUnsupported); + return kLiteRtStatusErrorUnsupported; } - return StatusOk(); + return kLiteRtStatusOk; } -void SetCustomOptions(tflite::OperatorT& op, std::string_view options_data) { - const uint8_t* data = reinterpret_cast(options_data.data()); - op.custom_options.assign(data, data + options_data.size()); - op.custom_options_format = tflite::CustomOptionsFormat_FLEXBUFFERS; +LiteRtElementType MapElementType(tflite::TensorType type) { + switch (type) { + case tflite::TensorType_FLOAT32: + return kLiteRtElementTypeFloat32; + case tflite::TensorType_FLOAT16: + return kLiteRtElementTypeFloat16; + case tflite::TensorType_INT32: + return kLiteRtElementTypeInt32; + case tflite::TensorType_BOOL: + return kLiteRtElementTypeBool; + default: + return kLiteRtElementTypeNone; + } } - //===----------------------------------------------------------------------===// // Load // //===----------------------------------------------------------------------===// class ModelUnpacker { public: - static LrtStatus Unpack(LrtModel model); + static LiteRtStatus Unpack(LiteRtModel model); private: - explicit ModelUnpacker(LrtModel model) : model_(model) {} + explicit ModelUnpacker(LiteRtModel model) : model_(model) {} - LrtStatus ConvertTensor(const tflite::TensorT& tensor, LrtTensor target); + LiteRtStatus ConvertTensor(const tflite::TensorT& tensor, + LiteRtTensor target); - LrtStatus ConvertOp(const tflite::OperatorT& op, - std::vector& tensors, LrtOp target); + LiteRtStatus ConvertOp(const tflite::OperatorT& op, + std::vector& tensors, LiteRtOp target); - LrtStatus UnpackSubgraph(LrtSubgraph target); + LiteRtStatus UnpackSubgraph(LiteRtSubgraph target); - LrtOpCode GetOpCode(uint32_t ind) { - return static_cast(Fb().operator_codes[ind]->builtin_code); + LiteRtOpCode GetOpCode(uint32_t ind) { + return static_cast(Fb().operator_codes[ind]->builtin_code); } std::unique_ptr GetBuffer(uint32_t ind) { @@ -192,38 +191,44 @@ class ModelUnpacker { tflite::ModelT& Fb() { return *model_->flatbuffer_model; } - LrtModel model_; + LiteRtModel model_; }; -LrtStatus ModelUnpacker::ConvertTensor(const tflite::TensorT& tensor, - LrtTensor target) { - LRT_RETURN_STATUS_IF_NOT_OK(IsTensorSupported(tensor)); +LiteRtStatus ModelUnpacker::ConvertTensor(const tflite::TensorT& tensor, + LiteRtTensor target) { + LITERT_RETURN_STATUS_IF_NOT_OK(IsTensorSupported(tensor)); const auto buffer_ind = tensor.buffer; if (buffer_ind != 0) { - target->buffer.fb_buffer = GetBuffer(buffer_ind); - LRT_RETURN_STATUS_IF_NOT_OK(IsBufferSupported(*target->buffer.fb_buffer)); + target->weights.fb_buffer = GetBuffer(buffer_ind); + LITERT_RETURN_STATUS_IF_NOT_OK( + IsBufferSupported(*target->weights.fb_buffer)); } - target->type_id = kLrtRankedTensorType; + target->type_id = kLiteRtRankedTensorType; auto& ranked_tensor = target->type_detail.ranked_tensor_type; - ranked_tensor.layout.dimensions = tensor.shape.data(); + ranked_tensor.element_type = MapElementType(tensor.type); ranked_tensor.layout.rank = tensor.shape.size(); + ranked_tensor.layout.dimensions = tensor.shape.data(); + ranked_tensor.layout.strides = + nullptr; // TFL tensors don't support strides yet. - ranked_tensor.element_type = kLrtElementTypeFloat32; - - return StatusOk(); + return kLiteRtStatusOk; } -LrtStatus ModelUnpacker::ConvertOp(const tflite::OperatorT& op, - std::vector& tensors, - LrtOp target) { +LiteRtStatus ModelUnpacker::ConvertOp(const tflite::OperatorT& op, + std::vector& tensors, + LiteRtOp target) { target->op_code = GetOpCode(op.opcode_index); for (auto input : op.inputs) { + // Skipping optional input tensor. + if (input == -1) { + continue; + } auto& input_tensor = tensors[input]; input_tensor->users.push_back(target); @@ -240,21 +245,22 @@ LrtStatus ModelUnpacker::ConvertOp(const tflite::OperatorT& op, target->outputs.push_back(output_tensor); } + target->option = op.builtin_options; - return StatusOk(); + return kLiteRtStatusOk; } -LrtStatus ModelUnpacker::UnpackSubgraph(LrtSubgraph target) { +LiteRtStatus ModelUnpacker::UnpackSubgraph(LiteRtSubgraph target) { auto& subgraph = target->flatbuffer_subgraph; for (int i = 0; i < subgraph->tensors.size(); ++i) { auto& flatbuffer_tensor = *subgraph->tensors[i]; - LRT_RETURN_STATUS_IF_NOT_OK(IsTensorSupported(flatbuffer_tensor)); + LITERT_RETURN_STATUS_IF_NOT_OK(IsTensorSupported(flatbuffer_tensor)); auto& tensor = target->tensors_storage.emplace_back(); target->tensors.push_back(&tensor); - LRT_RETURN_STATUS_IF_NOT_OK(ConvertTensor(flatbuffer_tensor, &tensor)); + LITERT_RETURN_STATUS_IF_NOT_OK(ConvertTensor(flatbuffer_tensor, &tensor)); } for (int i = 0; i < subgraph->operators.size(); ++i) { @@ -263,7 +269,8 @@ LrtStatus ModelUnpacker::UnpackSubgraph(LrtSubgraph target) { auto& op = target->ops_storage.emplace_back(); target->ops.push_back(&op); - LRT_RETURN_STATUS_IF_NOT_OK(ConvertOp(flatbuffer_op, target->tensors, &op)); + LITERT_RETURN_STATUS_IF_NOT_OK( + ConvertOp(flatbuffer_op, target->tensors, &op)); } for (auto input : subgraph->inputs) { @@ -274,65 +281,68 @@ LrtStatus ModelUnpacker::UnpackSubgraph(LrtSubgraph target) { target->outputs.push_back(target->tensors[output]); } - return StatusOk(); + return kLiteRtStatusOk; } -LrtStatus ModelUnpacker::Unpack(LrtModel model) { +LiteRtStatus ModelUnpacker::Unpack(LiteRtModel model) { ModelUnpacker unpacker(model); if (unpacker.Fb().subgraphs.size() != 1) { // TODO: b/365299994 - Support multi subgraph. - _LRT_D_MSG("Only single subgraph models suported."); - return StatusCreate(kLrtStatusErrorUnsupported); + _LITERT_D_MSG("Only single subgraph models suported."); + return kLiteRtStatusErrorUnsupported; } auto& subgraph = model->subgraphs.emplace_back(); subgraph.flatbuffer_subgraph = std::move(unpacker.Fb().subgraphs[0]); - LRT_RETURN_STATUS_IF_NOT_OK(unpacker.UnpackSubgraph(&subgraph)); + LITERT_RETURN_STATUS_IF_NOT_OK(unpacker.UnpackSubgraph(&subgraph)); - return StatusOk(); + return kLiteRtStatusOk; } -LrtStatus RegisterCustomOpCode(LrtModel model, const char* new_op_code) { +LiteRtStatus RegisterCustomOpCode(LiteRtModel model, const char* new_op_code) { model->custom_op_code.assign(new_op_code); - return StatusOk(); + return kLiteRtStatusOk; } -LrtStatus LoadModel(std::unique_ptr flatbuffer, - LrtModel* model) { - auto lrt_model = std::make_unique(); - lrt_model->flatbuffer_model = std::move(flatbuffer); - lrt_model->subgraphs.reserve(100); +static LiteRtStatus LoadModel(std::unique_ptr flatbuffer, + LiteRtModel* model) { + auto litert_model = std::make_unique(); + litert_model->flatbuffer_model = std::move(flatbuffer); + litert_model->subgraphs.reserve(100); - LRT_RETURN_STATUS_IF_NOT_OK(ModelUnpacker::Unpack(lrt_model.get())); + LITERT_RETURN_STATUS_IF_NOT_OK(ModelUnpacker::Unpack(litert_model.get())); - lrt_model->flatbuffer_model->subgraphs.clear(); + litert_model->flatbuffer_model->subgraphs.clear(); // Set as empty string in case its not set explictly. - LRT_RETURN_STATUS_IF_NOT_OK(RegisterCustomOpCode(lrt_model.get(), "")); + LITERT_RETURN_STATUS_IF_NOT_OK(RegisterCustomOpCode(litert_model.get(), "")); - *model = lrt_model.release(); + *model = litert_model.release(); - return StatusOk(); + return kLiteRtStatusOk; } -LrtStatus LoadModel(const uint8_t* buf, size_t buf_size, LrtModel* model) { - LRT_RETURN_STATUS_IF_NOT_OK(VerifyFlatbuffer(buf, buf_size)); +LiteRtStatus LoadModel(const uint8_t* buf, size_t buf_size, + LiteRtModel* model) { + LITERT_ENSURE(VerifyFlatbuffer(buf, buf_size), + kLiteRtStatusErrorInvalidFlatbuffer, + "Failed to verify flatbuffer"); return LoadModel(tflite::UnPackModel(buf), model); } -LrtStatus LoadModelFromFile(const char* path, LrtModel* model) { +LiteRtStatus LoadModelFromFile(const char* path, LiteRtModel* model) { std::unique_ptr alloc = tflite::GetAllocationFromFile(path, tflite::DefaultErrorReporter()); if (!alloc->valid()) { - return StatusCreate(kLrtStatusBadFileOp); + return kLiteRtStatusErrorFileIO; } return LoadModel(reinterpret_cast(alloc->base()), alloc->bytes(), model); } -void ModelDestroy(LrtModel model) { delete model; } +void ModelDestroy(LiteRtModel model) { delete model; } //===----------------------------------------------------------------------===// // Serialize // @@ -340,23 +350,24 @@ void ModelDestroy(LrtModel model) { delete model; } class ModelRepacker { public: - static LrtStatus Repack(LrtModel model); + static LiteRtStatus Repack(LiteRtModel model); private: - static void BuildOpCodeMap(LrtModel model, - std::unordered_map& map); + static void BuildOpCodeMap(LiteRtModel model, + std::unordered_map& map); - explicit ModelRepacker(LrtModel model) : model_(model) { + explicit ModelRepacker(LiteRtModel model) : model_(model) { BuildOpCodeMap(model_, op_code_map_); } - LrtStatus SerializeTensor(LrtTensor tensor, tflite::TensorT& target); + LiteRtStatus SerializeTensor(LiteRtTensor tensor, tflite::TensorT& target); - LrtStatus SerializeOp( - LrtOp op, tflite::OperatorT& target, - const std::unordered_map& tensor_map); + LiteRtStatus SerializeOp( + LiteRtOp op, tflite::OperatorT& target, + const absl::flat_hash_map& tensor_map); - LrtStatus SerializeSubgraph(LrtSubgraph subgraph, tflite::SubGraphT& target); + LiteRtStatus SerializeSubgraph(LiteRtSubgraph subgraph, + tflite::SubGraphT& target); uint32_t SubmitBuffer(std::unique_ptr buffer) { OldFb().buffers.push_back(std::move(buffer)); @@ -365,12 +376,12 @@ class ModelRepacker { tflite::ModelT& OldFb() { return *model_->flatbuffer_model; } - LrtModel model_; - std::unordered_map op_code_map_; + LiteRtModel model_; + std::unordered_map op_code_map_; }; void ModelRepacker::BuildOpCodeMap( - LrtModel model, std::unordered_map& map) { + LiteRtModel model, std::unordered_map& map) { // Add the user set custom code to the flatbuffers known codes. auto& custom_code = model->flatbuffer_model->operator_codes.emplace_back( std::make_unique()); @@ -382,30 +393,34 @@ void ModelRepacker::BuildOpCodeMap( for (int i = 0; i < codes.size(); ++i) { const auto tfl_code = codes[i]->builtin_code; - map.insert({static_cast(tfl_code), i}); + map.insert({static_cast(tfl_code), i}); } } -LrtStatus ModelRepacker::SerializeTensor(LrtTensor tensor, - tflite::TensorT& target) { +LiteRtStatus ModelRepacker::SerializeTensor(LiteRtTensor tensor, + tflite::TensorT& target) { target.has_rank = true; const auto& type = tensor->type_detail.ranked_tensor_type; - // TODO: b/365299994 - Map lrt element types to flatbuffer elements types. + // TODO: b/365299994 - Map litert element types to flatbuffer elements types. target.type = tflite::TensorType_FLOAT32; for (int i = 0; i < type.layout.rank; ++i) { target.shape.push_back(type.layout.dimensions[i]); } - DCHECK(tensor->buffer.fb_buffer != nullptr) << "Submitting a null buffer"; - target.buffer = SubmitBuffer(std::move(tensor->buffer.fb_buffer)); + // TFL tensors don't support strides yet. + ABSL_DCHECK(type.layout.strides == nullptr); + + ABSL_DCHECK(tensor->weights.fb_buffer != nullptr) + << "Submitting a null buffer"; + target.buffer = SubmitBuffer(std::move(tensor->weights.fb_buffer)); - return StatusOk(); + return kLiteRtStatusOk; } -LrtStatus ModelRepacker::SerializeOp( - LrtOp op, tflite::OperatorT& target, - const std::unordered_map& tensor_map) { +LiteRtStatus ModelRepacker::SerializeOp( + LiteRtOp op, tflite::OperatorT& target, + const absl::flat_hash_map& tensor_map) { target.opcode_index = op_code_map_.at(op->op_code); for (auto in : op->inputs) { @@ -417,32 +432,33 @@ LrtStatus ModelRepacker::SerializeOp( } // TODO: b/365299994 - Support options in serialize. - LRT_RETURN_STATUS_IF_NOT_OK_MSG( + LITERT_RETURN_STATUS_IF_NOT_OK_MSG( SetDefaultOptions(target.builtin_options, op->op_code), "Failed serializing options"); - if (!op->custom_options.empty()) { - SetCustomOptions(target, op->custom_options); + if (op->custom_options.Size() != 0) { + target.custom_options = op->custom_options.ToVec(); + target.custom_options_format = tflite::CustomOptionsFormat_FLEXBUFFERS; } // TODO: b/365299994 - Support exotic op fields in serialize. - return StatusOk(); + return kLiteRtStatusOk; } -LrtStatus ModelRepacker::SerializeSubgraph(LrtSubgraph subgraph, - tflite::SubGraphT& target) { - std::unordered_map tensor_map; +LiteRtStatus ModelRepacker::SerializeSubgraph(LiteRtSubgraph subgraph, + tflite::SubGraphT& target) { + absl::flat_hash_map tensor_map; for (auto tensor : subgraph->tensors) { tensor_map.insert({tensor, tensor_map.size()}); target.tensors.push_back(std::make_unique()); - LRT_RETURN_STATUS_IF_NOT_OK( + LITERT_RETURN_STATUS_IF_NOT_OK( SerializeTensor(tensor, *target.tensors.back())); } for (auto op : subgraph->ops) { target.operators.push_back(std::make_unique()); - LRT_RETURN_STATUS_IF_NOT_OK( + LITERT_RETURN_STATUS_IF_NOT_OK( SerializeOp(op, *target.operators.back(), tensor_map)); } @@ -453,10 +469,10 @@ LrtStatus ModelRepacker::SerializeSubgraph(LrtSubgraph subgraph, target.outputs.push_back(tensor_map.at(out)); } - return StatusOk(); + return kLiteRtStatusOk; } -LrtStatus ModelRepacker::Repack(LrtModel model) { +LiteRtStatus ModelRepacker::Repack(LiteRtModel model) { ModelRepacker repacker(model); auto& target = repacker.OldFb(); @@ -478,7 +494,7 @@ LrtStatus ModelRepacker::Repack(LrtModel model) { for (auto& subgraph : model->subgraphs) { target.subgraphs.push_back(std::make_unique()); - LRT_RETURN_STATUS_IF_NOT_OK( + LITERT_RETURN_STATUS_IF_NOT_OK( repacker.SerializeSubgraph(&subgraph, *target.subgraphs.back())); } @@ -492,11 +508,11 @@ LrtStatus ModelRepacker::Repack(LrtModel model) { target.buffers.emplace_back(std::move(buf)); } - return StatusOk(); + return kLiteRtStatusOk; } -LrtStatus AppendMetadata(LrtModel model, const void* metadata, - size_t metadata_size, const char* metadata_name) { +LiteRtStatus AppendMetadata(LiteRtModel model, const void* metadata, + size_t metadata_size, const char* metadata_name) { const auto metadata_buffer_ind = model->flatbuffer_model->buffers.size(); auto& metadata_buffer = model->flatbuffer_model->buffers.emplace_back( @@ -510,35 +526,32 @@ LrtStatus AppendMetadata(LrtModel model, const void* metadata, fb_metadata->name.assign(metadata_name); fb_metadata->buffer = metadata_buffer_ind; - return StatusOk(); + return kLiteRtStatusOk; } -LrtStatus SerializeModel(LrtModel model, uint8_t** buf, size_t* size, - size_t* offset) { +LiteRtStatus SerializeModel(LiteRtModel model, uint8_t** buf, size_t* size, + size_t* offset) { // Destroy model before return. - UniqueLrtModel u_model(model); + UniqueLiteRtModel u_model(model); - LRT_RETURN_STATUS_IF_NOT_OK_MSG(ModelRepacker::Repack(model), - "Failed to repack model."); + LITERT_RETURN_STATUS_IF_NOT_OK_MSG(ModelRepacker::Repack(model), + "Failed to repack model."); flatbuffers::FlatBufferBuilder b; auto model_offset = tflite::Model::Pack(b, model->flatbuffer_model.get()); tflite::FinishModelBuffer(b, model_offset); - size_t new_buf_size; - size_t new_buf_offset; - - uint8_t* new_buf = b.ReleaseRaw(new_buf_size, new_buf_offset); + OwningBufferRef buffer; + auto [new_buf, new_size, new_offset] = buffer.GetWeak(); + new_buf = b.ReleaseRaw(new_size, new_offset); - LRT_RETURN_STATUS_IF_NOT_OK_MSG( - VerifyFlatbuffer(new_buf + new_buf_offset, new_buf_size - new_buf_offset), - "Failed to verify flatbuffer"); + LITERT_ENSURE(VerifyFlatbuffer(buffer.Span()), + kLiteRtStatusErrorInvalidFlatbuffer, + "Failed to verify flatbuffer"); - *buf = new_buf; - *size = new_buf_size; - *offset = new_buf_offset; + std::tie(*buf, *size, *offset) = buffer.Release(); - return StatusOk(); + return kLiteRtStatusOk; } // NOLINTEND diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_model_init.h b/tensorflow/lite/experimental/litert/core/litert_model_init.h similarity index 59% rename from tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_model_init.h rename to tensorflow/lite/experimental/litert/core/litert_model_init.h index 1e5219ccd7d050..29aaf99d5d3903 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_model_init.h +++ b/tensorflow/lite/experimental/litert/core/litert_model_init.h @@ -12,55 +12,55 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CORE_LITE_RT_MODEL_INIT_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CORE_LITE_RT_MODEL_INIT_H_ +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_LITERT_MODEL_INIT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_LITERT_MODEL_INIT_H_ -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" #ifdef __cplusplus extern "C" { #endif // __cplusplus // Load model from flatbuffer file. -LrtStatus LoadModelFromFile(const char* path, LrtModel* model); +LiteRtStatus LoadModelFromFile(const char* path, LiteRtModel* model); // Load model from flatbuffer memory. -LrtStatus LoadModel(const uint8_t* buf, size_t buf_size, LrtModel* model); +LiteRtStatus LoadModel(const uint8_t* buf, size_t buf_size, LiteRtModel* model); // Add a new custom code to the registry in this model. This will be associated // with all custom ops and should only can be set once. // TODO consider expanding this to allow for "custom op builder" hook. -LrtStatus RegisterCustomOpCode(LrtModel model, const char* new_op_code); +LiteRtStatus RegisterCustomOpCode(LiteRtModel model, const char* new_op_code); // Destroy model and any associated storage. -void ModelDestroy(LrtModel model); +void ModelDestroy(LiteRtModel model); -// Adds given metadata buffer to be serialized with the flatbuffer. Buffer can +// Adds given metadata buffer to be serialized with the flatbuffer. Weights can // be retrieved at runtime under `metadata_name`. -LrtStatus AppendMetadata(LrtModel model, const void* metadata, - size_t metadata_size, const char* metadata_name); +LiteRtStatus AppendMetadata(LiteRtModel model, const void* metadata, + size_t metadata_size, const char* metadata_name); // Serializes model to bytes. NOTE this destroys the model before it returns. // NOTE: Caller takes ownership of `buf`. Flatbuffers are packed into their // arrays back to front, so the valid flatbuffer is buf[offset, size]. -LrtStatus SerializeModel(LrtModel model, uint8_t** buf, size_t* size, - size_t* offset); +LiteRtStatus SerializeModel(LiteRtModel model, uint8_t** buf, size_t* size, + size_t* offset); #ifdef __cplusplus } #include -struct LrtModelDeleter { - void operator()(LrtModel model) { +struct LiteRtModelDeleter { + void operator()(LiteRtModel model) { if (model != nullptr) { ModelDestroy(model); } } }; -using UniqueLrtModel = std::unique_ptr; +using UniqueLiteRtModel = std::unique_ptr; #endif // __cplusplus -#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_LRT_CORE_LITE_RT_MODEL_INIT_H_ +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_LITERT_MODEL_INIT_H_ diff --git a/tensorflow/lite/experimental/litert/core/litert_model_serialize.cc b/tensorflow/lite/experimental/litert/core/litert_model_serialize.cc new file mode 100644 index 00000000000000..622e94e8535f54 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/litert_model_serialize.cc @@ -0,0 +1,77 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/core/litert_model_serialize.h" + +#include +#include + +#include "absl/strings/str_format.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/core/litert_model_init.h" + +// +// METADATA Strategy +// + +LiteRtStatus LiteRtModelAddByteCodeMetadata(LiteRtModel model, + const char* soc_manufacturer, + const char* soc_model, + const void* byte_code, + size_t byte_code_size) { + // Register custom code shared by all NPU dispatch ops. + LITERT_RETURN_STATUS_IF_NOT_OK( + RegisterCustomOpCode(model, kLiteRtDispatchOpCustomCode)); + + // Add the build tag to the model. + const std::string m_buffer = + absl::StrFormat(kLiteRtBuildTagTpl, soc_manufacturer, soc_model, + kLiteRtMetadataSerializationStrategy); + LITERT_RETURN_STATUS_IF_NOT_OK(AppendMetadata( + model, m_buffer.data(), m_buffer.size(), kLiteRtBuildTagKey)); + + // Add the raw byte code. + LITERT_RETURN_STATUS_IF_NOT_OK(AppendMetadata( + model, byte_code, byte_code_size, kLiteRtMetadataByteCodeKey)); + + return kLiteRtStatusOk; +} + +// +// APPEND Strategy +// + +LiteRtStatus LiteRtModelPrepareForByteCodeAppend(LiteRtModel model, + const char* soc_manufacturer, + const char* soc_model) { + // Register custom code shared by all NPU dispatch ops. + LITERT_RETURN_STATUS_IF_NOT_OK( + RegisterCustomOpCode(model, kLiteRtDispatchOpCustomCode)); + + // Add the build tag to the model. + const std::string m_buffer = + absl::StrFormat(kLiteRtBuildTagTpl, soc_manufacturer, soc_model, + kLiteRtAppendSerializationStrategy); + LITERT_RETURN_STATUS_IF_NOT_OK(AppendMetadata( + model, m_buffer.data(), m_buffer.size(), kLiteRtBuildTagKey)); + + // Add the byte code placeholder. + LITERT_RETURN_STATUS_IF_NOT_OK(AppendMetadata( + model, kLiteRtAppendedByteCodePlaceholder, + sizeof(kLiteRtAppendedByteCodePlaceholder), kLiteRtMetadataByteCodeKey)); + + return kLiteRtStatusOk; +} diff --git a/tensorflow/lite/experimental/litert/core/litert_model_serialize.h b/tensorflow/lite/experimental/litert/core/litert_model_serialize.h new file mode 100644 index 00000000000000..06d7458be6f43f --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/litert_model_serialize.h @@ -0,0 +1,114 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_LITERT_MODEL_SERIALIZE_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_LITERT_MODEL_SERIALIZE_H_ + +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// Shared "custom_code" for all dispatch ops. +static const char kLiteRtDispatchOpCustomCode[] = "NPU_OP"; + +// Template for build tag to add to the model, encodes context about how the +// model was generated. +static const char kLiteRtBuildTagTpl[] = + "soc_man:%s,soc_model:%s,serialization_strategy:%s"; + +// Metadata key to lookup the build tag. +static const char kLiteRtBuildTagKey[] = "LiteRtStamp"; + +// Serializaton strategy ID for adding raw byte code directly to the metadata. +static const char kLiteRtMetadataSerializationStrategy[] = "METADATA"; + +// Serialization strategy ID for appending byte code to the end of the file. +static const char kLiteRtAppendSerializationStrategy[] = "APPEND"; + +// NPU bytecode information for the append strategy. Placeholder +// for post-processing step, [,] padded to fixed length. +static const char kLiteRtAppendedByteCodePlaceholder[] = + "[**********,**********]"; + +// Metadata key for any NPU bytecode information. +static const char kLiteRtMetadataByteCodeKey[] = "LiteRtNpuByteCode"; + +//===----------------------------------------------------------------------===// +// +// << BYTE CODE PACKING >> +// +// Strategies for packaging LiteRtCompilerPlugin compilation output with the +// flatbuffer. These are different short-term approaches used for testing and/or +// development. +// +// < STRATEGIES > +// +// All serialization strategies add 2 metadata buffers to the model. The first +// is a build stamp, which indicates the make/model as well as the serialization +// strategy targeted during plugin appliction. The second contains information +// about the NPU bytecode, which may be a location to find it, or the raw data +// itself. +// +// "METADATA" strategy +// +// Adds the raw NPU bytecode directly in the flatbuffer in a standard metadata +// buffer. +// +// This is intented for use in testing as it may bloat the flatbuffer size. +// Packing the byte code in this way allows it to be rendered by existing tflite +// tooling. +// +// "APPEND" strategy +// +// Appends compiled byte code to the end of the flatbuffer. This avoids cases +// where embedding byte code directly would break 2gb flatbuffer limit. +// Offset into the file and where the byte code starts and size is stored in +// metadata. +// +// The actual value of the offset is written into serialized flatbuffer +// as a post processing step. This function populates the offset with a fixed +// size placeholder for size_t(s) which may be left padded by some filler +// characters. +// +//===----------------------------------------------------------------------===// + +// Adds NPU bytecode and build tag to metadata. +// Registers the "custom_code". +LiteRtStatus LiteRtModelAddByteCodeMetadata(LiteRtModel model, + const char* soc_manufacturer, + const char* soc_model, + const void* byte_code, + size_t byte_code_size); + +// Preps the model for future post processing step. A +// string with parts parseable as size_t (offset, size) is set in the metadata. +// A future step will find the prefix of this string and +// replace the size_t portions with the actual offset and size +// post-serializaiton. This post-process step cannot not change the length of +// the string, and therefore the result may be left-padded with filler +// characters. Also populates build tag and registers "custom_code". +LiteRtStatus LiteRtModelPrepareForByteCodeAppend(LiteRtModel model, + const char* soc_manufacturer, + const char* soc_model); + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_LITERT_MODEL_SERIALIZE_H_ diff --git a/tensorflow/lite/experimental/litert/core/litert_model_serialize_test.cc b/tensorflow/lite/experimental/litert/core/litert_model_serialize_test.cc new file mode 100644 index 00000000000000..fc72fe84a9e600 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/litert_model_serialize_test.cc @@ -0,0 +1,122 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/core/litert_model_serialize.h" + +#include +#include + +#include +#include +#include "absl/strings/string_view.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_support.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/core/graph_tools.h" +#include "tensorflow/lite/experimental/litert/core/litert_model_init.h" +#include "tensorflow/lite/experimental/litert/core/model.h" +#include "tensorflow/lite/experimental/litert/core/util/buffer_ref.h" +#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" +#include "tensorflow/lite/experimental/litert/test/common.h" +#include "tensorflow/lite/schema/schema_generated.h" + +namespace { + +using ::graph_tools::GetMetadata; +using ::litert::OwningBufferRef; +using ::litert::internal::VerifyFlatbuffer; +using ::litert::testing::LoadTestFileModel; +using ::testing::HasSubstr; + +static constexpr absl::string_view kSocModel = "TestSocModel"; +static constexpr absl::string_view kSocMan = "TestSocMan"; + +// Gets a test model with a single custom op with empty attributes. +UniqueLiteRtModel GetTestModel() { + static constexpr absl::string_view kTestModel = "one_mul.tflite"; + return LoadTestFileModel(kTestModel); +} + +UniqueLiteRtModel RoundTrip(UniqueLiteRtModel model) { + OwningBufferRef fb; + auto [buf, size, offset] = fb.GetWeak(); + LITERT_RETURN_VAL_IF_NOT_OK( + SerializeModel(model.release(), &buf, &size, &offset), {}); + LITERT_ENSURE(VerifyFlatbuffer(fb.Span()), {}, "Failed to verify flatbuffer"); + + LiteRtModel new_model; + LITERT_RETURN_VAL_IF_NOT_OK(LoadModel(fb.Data(), fb.Size(), &new_model), {}); + + return UniqueLiteRtModel(new_model); +} + +bool HasCustomCode(const LiteRtModelT& model, + const absl::string_view custom_code) { + const auto& fb = model.flatbuffer_model; + for (auto& c : fb->operator_codes) { + if (c->custom_code == custom_code && + c->builtin_code == tflite::BuiltinOperator_CUSTOM) { + return true; + } + } + return false; +} + +TEST(TestByteCodePacking, MetadataStrategy) { + static constexpr absl::string_view kByteCode = "some_byte_code"; + + auto model = GetTestModel(); + ASSERT_STATUS_OK(LiteRtModelAddByteCodeMetadata( + model.get(), kSocMan.data(), kSocModel.data(), kByteCode.data(), + kByteCode.size())); + + model = RoundTrip(std::move(model)); + ASSERT_NE(model, nullptr); + + EXPECT_TRUE(HasCustomCode(*model, kLiteRtDispatchOpCustomCode)); + + ASSERT_RESULT_OK_ASSIGN(auto build_tag, + GetMetadata(model.get(), kLiteRtBuildTagKey)); + EXPECT_EQ(build_tag.StrView(), + "soc_man:TestSocMan,soc_model:TestSocModel,serialization_strategy:" + "METADATA"); + + ASSERT_RESULT_OK_ASSIGN(auto byte_code, + GetMetadata(model.get(), kLiteRtMetadataByteCodeKey)); + EXPECT_EQ(byte_code.StrView(), kByteCode); +} + +TEST(TestByteCodePacking, AppendStrategy) { + auto model = GetTestModel(); + ASSERT_STATUS_OK(LiteRtModelPrepareForByteCodeAppend( + model.get(), kSocMan.data(), kSocModel.data())); + + model = RoundTrip(std::move(model)); + ASSERT_NE(model, nullptr); + + EXPECT_TRUE(HasCustomCode(*model, kLiteRtDispatchOpCustomCode)); + + ASSERT_RESULT_OK_ASSIGN(auto build_tag, + GetMetadata(model.get(), kLiteRtBuildTagKey)); + EXPECT_EQ(build_tag.StrView(), + "soc_man:TestSocMan,soc_model:TestSocModel,serialization_strategy:" + "APPEND"); + + ASSERT_RESULT_OK_ASSIGN(auto byte_code_placeholder, + GetMetadata(model.get(), kLiteRtMetadataByteCodeKey)); + EXPECT_THAT(byte_code_placeholder.StrView(), + HasSubstr(kLiteRtAppendedByteCodePlaceholder)); +} + +} // namespace diff --git a/tensorflow/lite/experimental/litert/core/model.cc b/tensorflow/lite/experimental/litert/core/model.cc new file mode 100644 index 00000000000000..4f7f35c5193d77 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/model.cc @@ -0,0 +1,214 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/core/model.h" + +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/core/util/buffer_ref.h" +#include "tensorflow/lite/schema/schema_generated.h" + +using ::litert::MutableBufferRef; + +// +// Model +// + +LiteRtStatus GetModelNumSubgraphs(LiteRtModel model, + LiteRtParamIndex* num_subgraphs) { + *num_subgraphs = model->subgraphs.size(); + return kLiteRtStatusOk; +} + +LiteRtStatus GetModelSubgraph(LiteRtModel model, + LiteRtParamIndex subgraph_index, + LiteRtSubgraph* subgraph) { + if (subgraph_index >= model->subgraphs.size()) { + return kLiteRtStatusErrorIndexOOB; + } + *subgraph = model->subgraphs.data() + subgraph_index; + return kLiteRtStatusOk; +} + +LiteRtStatus GetModelMainSubgraph(LiteRtModel model, + LiteRtParamIndex* main_subgraph_index) { + // TODO replace this with signature. + *main_subgraph_index = 0; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtModelGetMetadata(LiteRtModel model, const char* metadata_key, + const void** metadata_buffer, + size_t* metadata_buffer_size) { + LITERT_ASSIGN_OR_RETURN_STATUS(auto m_buf, model->FindMetadata(metadata_key)); + *metadata_buffer = m_buf.Data(); + *metadata_buffer_size = m_buf.Size(); + return kLiteRtStatusOk; +} + +void ModelDestroy(LiteRtModel model) { + if (model != nullptr) { + delete model; + } +} + +LiteRtStatus PushOp(LiteRtOpList op_list, LiteRtOp op) { + op_list->Push(op); + return kLiteRtStatusOk; +} + +LiteRtResult> LiteRtModelT::FindMetadata( + const absl::string_view key) const { + using ResT = LiteRtResult>; + + tflite::MetadataT* fb_metadata = nullptr; + for (auto& m : flatbuffer_model->metadata) { + if (m->name == key) { + fb_metadata = m.get(); + break; + } + } + if (fb_metadata == nullptr) { + return ResT::FromStatus(kLiteRtStatusErrorNotFound); + } + + const uint32_t m_buffer_idx = fb_metadata->buffer; + if (m_buffer_idx >= flatbuffer_model->buffers.size()) { + return ResT::FromStatus(kLiteRtStatusErrorIndexOOB); + } + tflite::BufferT* m_buffer = flatbuffer_model->buffers.at(m_buffer_idx).get(); + + return ResT::FromValue( + MutableBufferRef(m_buffer->data.data(), m_buffer->data.size())); +} + +// +// Subgraph +// + +LiteRtStatus GetSubgraphInputs(LiteRtSubgraph subgraph, + LiteRtParamIndex* num_inputs, + LiteRtTensorArray* inputs) { + *num_inputs = subgraph->inputs.size(); + *inputs = subgraph->inputs.data(); + return kLiteRtStatusOk; +} + +LiteRtStatus GetSubgraphOutputs(LiteRtSubgraph subgraph, + LiteRtParamIndex* num_outputs, + LiteRtTensorArray* outputs) { + *num_outputs = subgraph->outputs.size(); + *outputs = subgraph->outputs.data(); + return kLiteRtStatusOk; +} + +LiteRtStatus GetSubgraphOps(LiteRtSubgraph subgraph, LiteRtParamIndex* num_ops, + LiteRtOpArray* ops) { + *num_ops = subgraph->ops.size(); + *ops = subgraph->ops.data(); + return kLiteRtStatusOk; +} + +// +// Op +// + +LiteRtStatus GetOpOutputs(LiteRtOp op, LiteRtParamIndex* num_outputs, + LiteRtTensorArray* outputs) { + *num_outputs = op->outputs.size(); + *outputs = op->outputs.data(); + return kLiteRtStatusOk; +} + +LiteRtStatus GetOpInputs(LiteRtOp op, LiteRtParamIndex* num_inputs, + LiteRtTensorArray* inputs) { + *num_inputs = op->inputs.size(); + *inputs = op->inputs.data(); + return kLiteRtStatusOk; +} + +LiteRtStatus GetOpCode(LiteRtOp op, LiteRtOpCode* code) { + *code = op->op_code; + return kLiteRtStatusOk; +} + +// +// Tensor +// + +LiteRtStatus GetWeightsInfo(LiteRtWeights weights, size_t* size, + const void** addr) { + if (weights->fb_buffer == nullptr) { + *size = 0; + *addr = nullptr; + } else { + *size = weights->fb_buffer->data.size(); + *addr = weights->fb_buffer->data.data(); + } + return kLiteRtStatusOk; +} + +LiteRtStatus GetTensorWeights(LiteRtTensor tensor, LiteRtWeights* weights) { + *weights = &tensor->weights; + return kLiteRtStatusOk; +} + +LiteRtStatus GetTensorUses(LiteRtTensor tensor, LiteRtParamIndex* num_uses, + LiteRtOpArray* use_users, + LiteRtParamIndex** use_user_arg_inds) { + *num_uses = tensor->users.size(); + *use_users = tensor->users.data(); + *use_user_arg_inds = tensor->user_arg_inds.data(); + return kLiteRtStatusOk; +} + +// Null if subgraph input or constant. +LiteRtStatus GetTensorDefiningOp( + LiteRtTensor tensor, LiteRtOp* maybe_defining_op, + LiteRtParamIndex* maybe_defining_op_output_ind) { + if (tensor->defining_op != nullptr) { + *maybe_defining_op = tensor->defining_op; + *maybe_defining_op_output_ind = tensor->defining_op_out_ind; + } + return kLiteRtStatusOk; +} + +LiteRtStatus GetTensorTypeId(LiteRtTensor tensor, LiteRtTensorTypeId* type_id) { + *type_id = tensor->type_id; + return kLiteRtStatusOk; +} + +LiteRtStatus GetUrankedTensorType( + LiteRtTensor tensor, LiteRtUnrankedTensorType* unranked_tensor_type) { + if (tensor->type_id != kLiteRtUnrankedTensorType) { + return kLiteRtStatusErrorInvalidIrType; + } + *unranked_tensor_type = tensor->type_detail.unranked_tensor_type; + return kLiteRtStatusOk; +} + +LiteRtStatus GetRankedTensorType(LiteRtTensor tensor, + LiteRtRankedTensorType* ranked_tensor_type) { + if (tensor->type_id != kLiteRtRankedTensorType) { + return kLiteRtStatusErrorInvalidIrType; + } + *ranked_tensor_type = tensor->type_detail.ranked_tensor_type; + return kLiteRtStatusOk; +} diff --git a/tensorflow/lite/experimental/litert/core/model.h b/tensorflow/lite/experimental/litert/core/model.h new file mode 100644 index 00000000000000..5fa9eadadc57bc --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/model.h @@ -0,0 +1,159 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_H_ + +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/core/util/buffer_ref.h" +#include "tensorflow/lite/schema/schema_generated.h" + +// +// Tensor +// + +struct LiteRtWeightsT { + std::unique_ptr fb_buffer = nullptr; +}; + +typedef union { + LiteRtUnrankedTensorType unranked_tensor_type; + LiteRtRankedTensorType ranked_tensor_type; +} LiteRtTypeDetail; + +struct LiteRtTensorT { + // Empty if subgraph output. This is a reference. + std::vector users; + + // Which arg number for user i. + std::vector user_arg_inds; + + // Null if subgraph input or constant. This is a reference. + LiteRtOp defining_op = nullptr; + + // Which output ind from defining op made this tensor. + LiteRtParamIndex defining_op_out_ind; + + // Not a reference. + LiteRtWeightsT weights; + + LiteRtTensorTypeId type_id; + + LiteRtTypeDetail type_detail; +}; + +// +// Op +// + +struct LiteRtOpT { + // These are references. + std::vector inputs; + + // These are references. + std::vector outputs; + + LiteRtOpCode op_code; + + litert::OwningBufferRef custom_options; + + tflite::BuiltinOptionsUnion option; +}; + +// +// Subgraph +// + +struct LiteRtSubgraphT { + // Storage and views of tensors. Clients are only shown views. Facilitates + // efficient topological mutation. + std::list tensors_storage; + std::vector tensors; + + // Storage and vies of ops. + std::list ops_storage; + std::vector ops; + + // Shared view of initial flatbuffer data. + std::shared_ptr flatbuffer_subgraph; + + // These are references and a subset of `tensors`. + std::vector inputs; + + // These are references and a subset of `tensors`. + std::vector outputs; +}; + +// +// Model +// + +// A (partial) unpacking of the flatbuffer model into a list of subgraphs. +// Keeps a reference to the flatbuffer model. Lifetimes of all storage +// are linked to the containing model. +struct LiteRtModelT { + // Subgraphs that have been unpacked into usable types. + std::vector subgraphs; + + // TODO: b/365299994 - Delete this. + // Shared views of remaining unpacked flatbuffer data. + std::vector> flatbuffer_subgraphs; + + // Initial flatbuffer loaded in. "Subgraphs" field has been invalidated. + std::unique_ptr flatbuffer_model; + + // Custom code associated with all customs ops emitted during + // re-serialization. + std::string custom_op_code; + + // Look up metadata by key, getting a view of its buffer as a string + // if it exists. + LiteRtResult> FindMetadata( + absl::string_view key) const; +}; + +// +// Utils +// + +// Used for communicating selections of ops. +class LiteRtOpListT { + public: + void Push(LiteRtOp op) { ops_.push_back(op); } + + std::vector Vec() const { + std::vector res; + res.reserve(ops_.size()); + res.assign(ops_.begin(), ops_.end()); + return res; + } + + private: + // NOTE: This was originally a vector. Was encountering really odd + // segfaults when freeing after code on another side of a compilation boundary + // was doing pushes that resized. A list+copy to vector is not optimimal, + // revisit if bottleneck. + std::list ops_; +}; + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_MODEL_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/core/model_test.cc b/tensorflow/lite/experimental/litert/core/model_test.cc similarity index 63% rename from tensorflow/compiler/mlir/lite/experimental/lrt/core/model_test.cc rename to tensorflow/lite/experimental/litert/core/model_test.cc index 34923c6bc433e6..6ee71466a2e793 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/core/model_test.cc +++ b/tensorflow/lite/experimental/litert/core/model_test.cc @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include // NOLINTNEXTLINE #include #include @@ -23,56 +21,54 @@ #include // IWYU pragma: keep #include -#include "flatbuffers/verifier.h" // from @flatbuffers #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_common.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_model.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/c/lite_rt_op_code.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/cc/lite_rt_support.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/graph_tools.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/core/lite_rt_model_init.h" -#include "tensorflow/compiler/mlir/lite/experimental/lrt/test_data/test_data_util.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/core/graph_tools.h" +#include "tensorflow/lite/experimental/litert/core/litert_model_init.h" +#include "tensorflow/lite/experimental/litert/core/util/buffer_ref.h" +#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" +#include "tensorflow/lite/experimental/litert/test/common.h" #include "tensorflow/lite/schema/schema_generated.h" namespace { -inline bool VerifyFlatbuffer(const uint8_t* buf, size_t buf_size) { - flatbuffers::Verifier::Options options; - flatbuffers::Verifier verifier(buf, buf_size, options); - return tflite::VerifyModelBuffer(verifier); -} +using ::graph_tools::GetMetadata; +using ::litert::BufferRef; +using ::litert::OwningBufferRef; +using ::litert::internal::VerifyFlatbuffer; -inline UniqueLrtModel LoadModelThroughRoundTrip(std::string_view path) { - auto model = LoadTestFileModel(path); +inline UniqueLiteRtModel LoadModelThroughRoundTrip(std::string_view path) { + auto model = litert::testing::LoadTestFileModel(path); - uint8_t* buf = nullptr; - size_t buf_size; - size_t offset; + OwningBufferRef buf; + auto [data, size, offset] = buf.GetWeak(); - LRT_CHECK_STATUS_OK_MSG( - SerializeModel(model.release(), &buf, &buf_size, &offset), + LITERT_CHECK_STATUS_OK_MSG( + SerializeModel(model.release(), &data, &size, &offset), "Failed to serialize model"); // Reload model. - LrtModel result = nullptr; - LRT_CHECK_STATUS_OK_MSG(LoadModel(buf + offset, buf_size - offset, &result), - "Failed to re load model"); - delete[] buf; + LiteRtModel result = nullptr; + LITERT_CHECK_STATUS_OK_MSG(LoadModel(buf.Data(), buf.Size(), &result), + "Failed to re load model"); - return UniqueLrtModel(result); + return UniqueLiteRtModel(result); } class TestWithPath : public ::testing::TestWithParam {}; -class TopologyTest : public ::testing::TestWithParam { +class TopologyTest : public ::testing::TestWithParam { public: - static std::vector MakeTestModels( + static std::vector MakeTestModels( const std::vector& paths) { - std::vector result; + std::vector result; for (auto p : paths) { - result.push_back(LoadTestFileModel(p).release()); + result.push_back(litert::testing::LoadTestFileModel(p).release()); result.push_back(LoadModelThroughRoundTrip(p).release()); } @@ -80,13 +76,13 @@ class TopologyTest : public ::testing::TestWithParam { } }; -TEST(LrtModelTest, TestLoadTestDataBadFilepath) { - LrtModel model = nullptr; +TEST(LiteRtModelTest, TestLoadTestDataBadFilepath) { + LiteRtModel model = nullptr; ASSERT_STATUS_HAS_CODE(LoadModelFromFile("bad_path", &model), - kLrtStatusBadFileOp); + kLiteRtStatusErrorFileIO); } -TEST(LrtModelTest, TestLoadTestDataBadFileData) { +TEST(LiteRtModelTest, TestLoadTestDataBadFileData) { // NOLINTBEGIN #ifndef NDEBUG // In debug mode, flatbuffers will `assert` while verifying. This will @@ -101,41 +97,41 @@ TEST(LrtModelTest, TestLoadTestDataBadFileData) { bad_file << "not_tflite"; bad_file.close(); - LrtModel model = nullptr; + LiteRtModel model = nullptr; ASSERT_STATUS_HAS_CODE(LoadModelFromFile(test_file_path.c_str(), &model), - kLrtStatusFlatbufferFailedVerify); + kLiteRtStatusErrorInvalidFlatbuffer); // NOLINTEND } TEST(TestSerializeModel, TestAllocations) { - auto model = LoadTestFileModel("add_simple.tflite"); - - uint8_t* buf = nullptr; - size_t buf_size; - size_t offset; + auto model = litert::testing::LoadTestFileModel("add_simple.tflite"); - ASSERT_STATUS_OK(SerializeModel(model.release(), &buf, &buf_size, &offset)); + OwningBufferRef buf; + auto [data, size, offset] = buf.GetWeak(); - delete[] buf; + ASSERT_STATUS_OK(SerializeModel(model.release(), &data, &size, &offset)); + EXPECT_TRUE(VerifyFlatbuffer(data + offset, size - offset)); } TEST(TestSerializeModel, TestMetadata) { - auto model = LoadTestFileModel("add_simple.tflite"); + auto model = litert::testing::LoadTestFileModel("add_simple.tflite"); constexpr static std::string_view kMetadataName = "an_soc_manufacturer"; constexpr static std::string_view kMetadataData = "My_Meta_Data"; ASSERT_STATUS_OK(AppendMetadata(model.get(), kMetadataData.data(), kMetadataData.size(), kMetadataName.data())); + ASSERT_RESULT_OK_ASSIGN(auto m_buffer, + GetMetadata(model.get(), kMetadataName)); + EXPECT_EQ(m_buffer.StrView(), kMetadataData); - uint8_t* buf = nullptr; - size_t buf_size; - size_t offset; + OwningBufferRef buf; + auto [data, size, offset] = buf.GetWeak(); - ASSERT_STATUS_OK(SerializeModel(model.release(), &buf, &buf_size, &offset)); - EXPECT_TRUE(VerifyFlatbuffer(buf + offset, buf_size - offset)); + ASSERT_STATUS_OK(SerializeModel(model.release(), &data, &size, &offset)); + EXPECT_TRUE(VerifyFlatbuffer(buf.Span())); - auto new_model = tflite::UnPackModel(buf + offset); + auto new_model = tflite::UnPackModel(buf.Data()); ASSERT_NE(new_model, nullptr); ASSERT_GT(new_model->metadata.size(), 0); @@ -153,30 +149,24 @@ TEST(TestSerializeModel, TestMetadata) { tflite::BufferT* metadata_buffer = new_model->buffers.at(fb_metadata->buffer).get(); + BufferRef metadata_buf(metadata_buffer->data.data(), + metadata_buffer->data.size()); - std::string_view fb_metadata_data( - reinterpret_cast(metadata_buffer->data.data()), - metadata_buffer->data.size()); - - EXPECT_EQ(fb_metadata_data, kMetadataData); - - delete[] buf; + EXPECT_EQ(metadata_buf.StrView(), kMetadataData); } TEST(TestSerializeModel, TestCustomOpCode) { - auto model = LoadTestFileModel("add_simple.tflite"); + auto model = litert::testing::LoadTestFileModel("add_simple.tflite"); constexpr static std::string_view kCustomCode = "MyCustomCode"; ASSERT_STATUS_OK(RegisterCustomOpCode(model.get(), kCustomCode.data())); - uint8_t* buf = nullptr; - size_t buf_size; - size_t offset; - - ASSERT_STATUS_OK(SerializeModel(model.release(), &buf, &buf_size, &offset)); - EXPECT_TRUE(VerifyFlatbuffer(buf + offset, buf_size - offset)); + OwningBufferRef buf; + auto [data, size, offset] = buf.GetWeak(); - auto new_model = tflite::UnPackModel(buf + offset); + ASSERT_STATUS_OK(SerializeModel(model.release(), &data, &size, &offset)); + EXPECT_TRUE(VerifyFlatbuffer(buf.Span())); + auto new_model = tflite::UnPackModel(buf.Data()); tflite::OperatorCodeT* custom_op_code = nullptr; for (auto& c : new_model->operator_codes) { @@ -188,16 +178,14 @@ TEST(TestSerializeModel, TestCustomOpCode) { ASSERT_NE(custom_op_code, nullptr); ASSERT_EQ(custom_op_code->custom_code, kCustomCode); ASSERT_EQ(custom_op_code->builtin_code, tflite::BuiltinOperator_CUSTOM); - - delete[] buf; } TEST_P(TestWithPath, TestConstructDestroy) { - UniqueLrtModel model = LoadTestFileModel(GetParam()); + UniqueLiteRtModel model = litert::testing::LoadTestFileModel(GetParam()); } TEST_P(TestWithPath, TestConstructDestroyRoundTrip) { - UniqueLrtModel model = LoadModelThroughRoundTrip(GetParam()); + UniqueLiteRtModel model = LoadModelThroughRoundTrip(GetParam()); } INSTANTIATE_TEST_SUITE_P(InstTestWithPath, TestWithPath, @@ -208,14 +196,14 @@ INSTANTIATE_TEST_SUITE_P(InstTestWithPath, TestWithPath, using AddSimpleTest = TopologyTest; TEST_P(AddSimpleTest, TestBuildModelAddSimple) { - UniqueLrtModel model(GetParam()); + UniqueLiteRtModel model(GetParam()); // func(arg0) // output = tfl.add(arg0, arg0) // return(output) // - ASSERT_RESULT_OK_ASSIGN(LrtSubgraph subgraph, + ASSERT_RESULT_OK_ASSIGN(LiteRtSubgraph subgraph, graph_tools::GetSubgraph(model.get())); ASSERT_RESULT_OK_ASSIGN(auto subgraph_inputs, graph_tools::GetSubgraphInputs(subgraph)); @@ -231,9 +219,10 @@ TEST_P(AddSimpleTest, TestBuildModelAddSimple) { ASSERT_EQ(ops.size(), 1); auto op = ops[0]; - graph_tools::RankedTypeInfo float_2by2_type(kLrtElementTypeFloat32, {2, 2}); + graph_tools::RankedTypeInfo float_2by2_type(kLiteRtElementTypeFloat32, + {2, 2}); ASSERT_TRUE(graph_tools::MatchOpType(op, {float_2by2_type, float_2by2_type}, - {float_2by2_type}, kLrtOpCodeTflAdd)); + {float_2by2_type}, kLiteRtOpCodeTflAdd)); ASSERT_RESULT_OK_ASSIGN(auto op_inputs, graph_tools::GetOpIns(op)); ASSERT_EQ(op_inputs.size(), 2); @@ -243,8 +232,8 @@ TEST_P(AddSimpleTest, TestBuildModelAddSimple) { ASSERT_RESULT_OK_ASSIGN(auto op_out, graph_tools::GetOnlyOpOut(op)); ASSERT_EQ(op_out, subgraph_outputs[0]); - ASSERT_TRUE(graph_tools::MatchNoBuffer(subgraph_outputs[0])); - ASSERT_TRUE(graph_tools::MatchNoBuffer(subgraph_inputs[0])); + ASSERT_TRUE(graph_tools::MatchNoWeights(subgraph_outputs[0])); + ASSERT_TRUE(graph_tools::MatchNoWeights(subgraph_inputs[0])); } INSTANTIATE_TEST_SUITE_P( @@ -254,7 +243,7 @@ INSTANTIATE_TEST_SUITE_P( using AddCstTest = TopologyTest; TEST_P(AddCstTest, TestBuildModelAddCst) { - UniqueLrtModel model(GetParam()); + UniqueLiteRtModel model(GetParam()); // func(arg0) // cst = ConstantTensor([1, 2, 3, 4]) @@ -262,7 +251,7 @@ TEST_P(AddCstTest, TestBuildModelAddCst) { // return(output) // - ASSERT_RESULT_OK_ASSIGN(LrtSubgraph subgraph, + ASSERT_RESULT_OK_ASSIGN(LiteRtSubgraph subgraph, graph_tools::GetSubgraph(model.get())); ASSERT_RESULT_OK_ASSIGN(auto subgraph_inputs, graph_tools::GetSubgraphInputs(subgraph)); @@ -278,21 +267,21 @@ TEST_P(AddCstTest, TestBuildModelAddCst) { ASSERT_EQ(ops.size(), 1); auto op = ops[0]; - graph_tools::RankedTypeInfo float_2by2_type(kLrtElementTypeFloat32, {4}); + graph_tools::RankedTypeInfo float_2by2_type(kLiteRtElementTypeFloat32, {4}); ASSERT_TRUE(graph_tools::MatchOpType(op, {float_2by2_type, float_2by2_type}, - {float_2by2_type}, kLrtOpCodeTflAdd)); + {float_2by2_type}, kLiteRtOpCodeTflAdd)); ASSERT_RESULT_OK_ASSIGN(auto op_inputs, graph_tools::GetOpIns(op)); ASSERT_EQ(op_inputs.size(), 2); ASSERT_EQ(op_inputs[0], subgraph_inputs[0]); - ASSERT_TRUE(graph_tools::MatchBuffer( + ASSERT_TRUE(graph_tools::MatchWeights( op_inputs[1], llvm::ArrayRef{1.0, 2.0, 3.0, 4.0})); ASSERT_RESULT_OK_ASSIGN(auto op_out, graph_tools::GetOnlyOpOut(op)); ASSERT_EQ(op_out, subgraph_outputs[0]); - ASSERT_TRUE(graph_tools::MatchNoBuffer(subgraph_outputs[0])); - ASSERT_TRUE(graph_tools::MatchNoBuffer(subgraph_inputs[0])); + ASSERT_TRUE(graph_tools::MatchNoWeights(subgraph_outputs[0])); + ASSERT_TRUE(graph_tools::MatchNoWeights(subgraph_inputs[0])); } INSTANTIATE_TEST_SUITE_P( @@ -302,7 +291,7 @@ INSTANTIATE_TEST_SUITE_P( using SimpleMultiOpTest = TopologyTest; TEST_P(SimpleMultiOpTest, TestBuildModelSimpleMultiAdd) { - UniqueLrtModel model(GetParam()); + UniqueLiteRtModel model(GetParam()); // func.func @main(arg0) // 0 = tfl.add arg0, arg0 @@ -311,7 +300,7 @@ TEST_P(SimpleMultiOpTest, TestBuildModelSimpleMultiAdd) { // 3 = tfl.add 2, 2 // return 3 - ASSERT_RESULT_OK_ASSIGN(LrtSubgraph subgraph, + ASSERT_RESULT_OK_ASSIGN(LiteRtSubgraph subgraph, graph_tools::GetSubgraph(model.get())); ASSERT_RESULT_OK_ASSIGN(auto subgraph_inputs, graph_tools::GetSubgraphInputs(subgraph)); @@ -331,11 +320,12 @@ TEST_P(SimpleMultiOpTest, TestBuildModelSimpleMultiAdd) { ASSERT_EQ(inputs[0], inputs[1]); } - graph_tools::RankedTypeInfo float_2by2_type(kLrtElementTypeFloat32, {2, 2}); + graph_tools::RankedTypeInfo float_2by2_type(kLiteRtElementTypeFloat32, + {2, 2}); ASSERT_TRUE(graph_tools::MatchOpType(ops[2], {float_2by2_type, float_2by2_type}, - {float_2by2_type}, kLrtOpCodeTflMul)); + {float_2by2_type}, kLiteRtOpCodeTflMul)); } INSTANTIATE_TEST_SUITE_P(SimpleMultiOpTests, SimpleMultiOpTest, diff --git a/tensorflow/lite/experimental/litert/core/option.cc b/tensorflow/lite/experimental/litert/core/option.cc new file mode 100644 index 00000000000000..7c4a094e45e66f --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/option.cc @@ -0,0 +1,227 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/c/litert_options.h" +#include "tensorflow/lite/experimental/litert/core/model.h" + +// +// Op Options +// + +LiteRtStatus LiteRtAddGetFusedActivationOption(LiteRtOp op, + uint32_t* fused_activation) { + if (op->op_code != kLiteRtOpCodeTflAdd) { + return kLiteRtStatusErrorInvalidArgument; + } + *fused_activation = op->option.AsAddOptions()->fused_activation_function; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtBatchMatmulGetAdjXOption(LiteRtOp op, bool* adj_x) { + if (op->op_code != kLiteRtOpCodeTflBatchMatmul) { + return kLiteRtStatusErrorInvalidArgument; + } + *adj_x = op->option.AsBatchMatMulOptions()->adj_x; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtBatchMatmulGetAdjYOption(LiteRtOp op, bool* adj_y) { + if (op->op_code != kLiteRtOpCodeTflBatchMatmul) { + return kLiteRtStatusErrorInvalidArgument; + } + *adj_y = op->option.AsBatchMatMulOptions()->adj_y; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtBatchMatmulGetAsymmetricQuantizeInputOption( + LiteRtOp op, bool* asymmetric_quantize_input) { + if (op->op_code != kLiteRtOpCodeTflBatchMatmul) { + return kLiteRtStatusErrorInvalidArgument; + } + *asymmetric_quantize_input = + op->option.AsBatchMatMulOptions()->asymmetric_quantize_inputs; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtConcatenationGetFusedActivationOption( + LiteRtOp op, uint32_t* fused_activation) { + if (op->op_code != kLiteRtOpCodeTflConcatenation) { + return kLiteRtStatusErrorInvalidArgument; + } + *fused_activation = + op->option.AsConcatenationOptions()->fused_activation_function; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtConcatenationGetAxisOption(LiteRtOp op, int32_t* axis) { + if (op->op_code != kLiteRtOpCodeTflConcatenation) { + return kLiteRtStatusErrorInvalidArgument; + } + *axis = op->option.AsConcatenationOptions()->axis; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtDivGetFusedActivationOption(LiteRtOp op, + uint32_t* fused_activation) { + if (op->op_code != kLiteRtOpCodeTflDiv) { + return kLiteRtStatusErrorInvalidArgument; + } + *fused_activation = op->option.AsDivOptions()->fused_activation_function; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtFullyConnectedGetFusedActivationOption( + LiteRtOp op, uint32_t* fused_activation) { + if (op->op_code != kLiteRtOpCodeTflFullyConnected) { + return kLiteRtStatusErrorInvalidArgument; + } + *fused_activation = + op->option.AsFullyConnectedOptions()->fused_activation_function; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtFullyConnectedGetKeepNumDimsOption(LiteRtOp op, + bool* keep_num_dims) { + if (op->op_code != kLiteRtOpCodeTflFullyConnected) { + return kLiteRtStatusErrorInvalidArgument; + } + *keep_num_dims = op->option.AsFullyConnectedOptions()->keep_num_dims; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtFullyConnectedGetQuantizedBiasTypeOption( + LiteRtOp op, uint32_t* quantized_bias_type) { + if (op->op_code != kLiteRtOpCodeTflFullyConnected) { + return kLiteRtStatusErrorInvalidArgument; + } + *quantized_bias_type = + op->option.AsFullyConnectedOptions()->quantized_bias_type; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtFullyConnectedGetAsymmetricQuantizeInputOption( + LiteRtOp op, bool* asymmetric_quantize_input) { + if (op->op_code != kLiteRtOpCodeTflFullyConnected) { + return kLiteRtStatusErrorInvalidArgument; + } + *asymmetric_quantize_input = + op->option.AsFullyConnectedOptions()->asymmetric_quantize_inputs; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtFullyConnectedGetWeightsFormatOption( + LiteRtOp op, uint32_t* weights_format) { + if (op->op_code != kLiteRtOpCodeTflFullyConnected) { + return kLiteRtStatusErrorInvalidArgument; + } + *weights_format = op->option.AsFullyConnectedOptions()->weights_format; + return kLiteRtStatusOk; +} +LiteRtStatus LiteRtMulGetFusedActivationOption(LiteRtOp op, + uint32_t* fused_activation) { + if (op->op_code != kLiteRtOpCodeTflMul) { + return kLiteRtStatusErrorInvalidArgument; + } + *fused_activation = op->option.AsMulOptions()->fused_activation_function; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtSoftmaxGetBetaOption(LiteRtOp op, float* beta) { + if (op->op_code != kLiteRtOpCodeTflSoftmax) { + return kLiteRtStatusErrorInvalidArgument; + } + *beta = op->option.AsSoftmaxOptions()->beta; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtStridedSliceGetBeginMaskOption(LiteRtOp op, + int32_t* begin_mask) { + if (op->op_code != kLiteRtOpCodeTflStridedSlice) { + return kLiteRtStatusErrorInvalidArgument; + } + *begin_mask = op->option.AsStridedSliceOptions()->begin_mask; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtStridedSliceGetEndMaskOption(LiteRtOp op, + int32_t* end_mask) { + if (op->op_code != kLiteRtOpCodeTflStridedSlice) { + return kLiteRtStatusErrorInvalidArgument; + } + *end_mask = op->option.AsStridedSliceOptions()->end_mask; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtStridedSliceGetEllipsisMaskOption(LiteRtOp op, + int32_t* ellipsis_mask) { + if (op->op_code != kLiteRtOpCodeTflStridedSlice) { + return kLiteRtStatusErrorInvalidArgument; + } + *ellipsis_mask = op->option.AsStridedSliceOptions()->ellipsis_mask; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtStridedSliceGetNewAxisMaskOption(LiteRtOp op, + int32_t* new_axis_mask) { + if (op->op_code != kLiteRtOpCodeTflStridedSlice) { + return kLiteRtStatusErrorInvalidArgument; + } + *new_axis_mask = op->option.AsStridedSliceOptions()->new_axis_mask; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtStridedSliceGetShrinkAxisMaskOption( + LiteRtOp op, int32_t* shrink_axis_mask) { + if (op->op_code != kLiteRtOpCodeTflStridedSlice) { + return kLiteRtStatusErrorInvalidArgument; + } + *shrink_axis_mask = op->option.AsStridedSliceOptions()->shrink_axis_mask; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtStridedSliceGetOffsetOption(LiteRtOp op, bool* offset) { + if (op->op_code != kLiteRtOpCodeTflStridedSlice) { + return kLiteRtStatusErrorInvalidArgument; + } + *offset = op->option.AsStridedSliceOptions()->offset; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtSubGetFusedActivationOption(LiteRtOp op, + uint32_t* fused_activation) { + if (op->op_code != kLiteRtOpCodeTflSub) { + return kLiteRtStatusErrorInvalidArgument; + } + *fused_activation = op->option.AsSubOptions()->fused_activation_function; + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtReshapeGetNewShapeOption(LiteRtOp op, int32_t** new_shape, + int32_t* new_shape_size) { + if (op->op_code != kLiteRtOpCodeTflReshape) { + return kLiteRtStatusErrorInvalidArgument; + } + if (op->option.AsReshapeOptions() == nullptr) { + *new_shape_size = -1; + return kLiteRtStatusOk; + } else { + *new_shape = op->option.AsReshapeOptions()->new_shape.data(); + *new_shape_size = op->option.AsReshapeOptions()->new_shape.size(); + } + return kLiteRtStatusOk; +} diff --git a/tensorflow/lite/experimental/litert/core/option_test.cc b/tensorflow/lite/experimental/litert/core/option_test.cc new file mode 100644 index 00000000000000..80586fd4b72337 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/option_test.cc @@ -0,0 +1,209 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +// NOLINTNEXTLINE + +#include // IWYU pragma: keep +#include +#include "llvm/ADT/ArrayRef.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_options.h" +#include "tensorflow/lite/experimental/litert/core/graph_tools.h" +#include "tensorflow/lite/experimental/litert/test/common.h" + +namespace { +TEST(GetOpOptionTest, TestGetAddOptions) { + auto model = litert::testing::LoadTestFileModel("simple_add_op.tflite"); + ASSERT_RESULT_OK_ASSIGN(LiteRtSubgraph subgraph, + graph_tools::GetSubgraph(model.get())); + ASSERT_RESULT_OK_ASSIGN(auto ops, graph_tools::GetSubgraphOps(subgraph)); + auto op = ops[0]; + + uint32_t fused_activation; + ASSERT_STATUS_OK(LiteRtAddGetFusedActivationOption(op, &fused_activation)); + ASSERT_EQ(fused_activation, 0); +} + +TEST(GetOpOptionTest, TestGetBatchMatmulOptions) { + auto model = + litert::testing::LoadTestFileModel("simple_batch_matmul_op.tflite"); + ASSERT_RESULT_OK_ASSIGN(LiteRtSubgraph subgraph, + graph_tools::GetSubgraph(model.get())); + ASSERT_RESULT_OK_ASSIGN(auto ops, graph_tools::GetSubgraphOps(subgraph)); + auto op = ops[0]; + + bool adj_x; + ASSERT_STATUS_OK(LiteRtBatchMatmulGetAdjXOption(op, &adj_x)); + ASSERT_EQ(adj_x, false); + + bool adj_y; + ASSERT_STATUS_OK(LiteRtBatchMatmulGetAdjYOption(op, &adj_y)); + ASSERT_EQ(adj_y, false); + + bool asymmetric_quantize_input; + ASSERT_STATUS_OK(LiteRtBatchMatmulGetAsymmetricQuantizeInputOption( + op, &asymmetric_quantize_input)); + ASSERT_EQ(asymmetric_quantize_input, false); +} + +TEST(GetOpOptionTest, TestGetConcatenationOptions) { + auto model = + litert::testing::LoadTestFileModel("simple_concatenation_op.tflite"); + ASSERT_RESULT_OK_ASSIGN(LiteRtSubgraph subgraph, + graph_tools::GetSubgraph(model.get())); + ASSERT_RESULT_OK_ASSIGN(auto ops, graph_tools::GetSubgraphOps(subgraph)); + auto op = ops[0]; + + uint32_t fused_activation; + ASSERT_STATUS_OK( + LiteRtConcatenationGetFusedActivationOption(op, &fused_activation)); + ASSERT_EQ(fused_activation, 0); + + int32_t axis; + ASSERT_STATUS_OK(LiteRtConcatenationGetAxisOption(op, &axis)); + ASSERT_EQ(axis, 2); +} + +TEST(GetOpOptionTest, TestGetDivOptions) { + auto model = litert::testing::LoadTestFileModel("simple_div_op.tflite"); + ASSERT_RESULT_OK_ASSIGN(LiteRtSubgraph subgraph, + graph_tools::GetSubgraph(model.get())); + ASSERT_RESULT_OK_ASSIGN(auto ops, graph_tools::GetSubgraphOps(subgraph)); + auto op = ops[0]; + + uint32_t fused_activation; + ASSERT_STATUS_OK(LiteRtDivGetFusedActivationOption(op, &fused_activation)); + ASSERT_EQ(fused_activation, 0); +} + +TEST(GetOpOptionTest, TestGetFullyConnectedOptions) { + auto model = + litert::testing::LoadTestFileModel("simple_fully_connected_op.tflite"); + ASSERT_RESULT_OK_ASSIGN(LiteRtSubgraph subgraph, + graph_tools::GetSubgraph(model.get())); + ASSERT_RESULT_OK_ASSIGN(auto ops, graph_tools::GetSubgraphOps(subgraph)); + auto op = ops[0]; + + uint32_t fused_activation; + ASSERT_STATUS_OK( + LiteRtFullyConnectedGetFusedActivationOption(op, &fused_activation)); + ASSERT_EQ(fused_activation, 0); + + uint32_t weights_format; + ASSERT_STATUS_OK( + LiteRtFullyConnectedGetWeightsFormatOption(op, &weights_format)); + ASSERT_EQ(weights_format, 0); + + bool keep_num_dims; + ASSERT_STATUS_OK( + LiteRtFullyConnectedGetKeepNumDimsOption(op, &keep_num_dims)); + ASSERT_EQ(keep_num_dims, true); + + uint32_t quantized_bias_type; + ASSERT_STATUS_OK( + LiteRtFullyConnectedGetQuantizedBiasTypeOption(op, &quantized_bias_type)); + ASSERT_EQ(quantized_bias_type, 0); + + bool asymmetric_quantize_input; + ASSERT_STATUS_OK(LiteRtFullyConnectedGetAsymmetricQuantizeInputOption( + op, &asymmetric_quantize_input)); + ASSERT_EQ(asymmetric_quantize_input, false); +} + +TEST(GetOpOptionTest, TestGetMulOptions) { + auto model = litert::testing::LoadTestFileModel("simple_mul_op.tflite"); + ASSERT_RESULT_OK_ASSIGN(LiteRtSubgraph subgraph, + graph_tools::GetSubgraph(model.get())); + ASSERT_RESULT_OK_ASSIGN(auto ops, graph_tools::GetSubgraphOps(subgraph)); + auto op = ops[0]; + + uint32_t fused_activation; + ASSERT_STATUS_OK(LiteRtMulGetFusedActivationOption(op, &fused_activation)); + ASSERT_EQ(fused_activation, 0); +} + +TEST(GetOpOptionTest, TestGetSoftmaxOptions) { + auto model = litert::testing::LoadTestFileModel("simple_softmax_op.tflite"); + ASSERT_RESULT_OK_ASSIGN(LiteRtSubgraph subgraph, + graph_tools::GetSubgraph(model.get())); + ASSERT_RESULT_OK_ASSIGN(auto ops, graph_tools::GetSubgraphOps(subgraph)); + auto op = ops[0]; + + float beta; + ASSERT_STATUS_OK(LiteRtSoftmaxGetBetaOption(op, &beta)); + EXPECT_FLOAT_EQ(beta, 1.0); +} + +TEST(GetOpOptionTest, TestGetStridedSliceOptions) { + auto model = + litert::testing::LoadTestFileModel("simple_strided_slice_op.tflite"); + ASSERT_RESULT_OK_ASSIGN(LiteRtSubgraph subgraph, + graph_tools::GetSubgraph(model.get())); + ASSERT_RESULT_OK_ASSIGN(auto ops, graph_tools::GetSubgraphOps(subgraph)); + auto op = ops[0]; + + int32_t begin_mask; + ASSERT_STATUS_OK(LiteRtStridedSliceGetBeginMaskOption(op, &begin_mask)); + ASSERT_EQ(begin_mask, 0); + + int32_t end_mask; + ASSERT_STATUS_OK(LiteRtStridedSliceGetEndMaskOption(op, &end_mask)); + ASSERT_EQ(end_mask, 0); + + int32_t ellipsis_mask; + ASSERT_STATUS_OK(LiteRtStridedSliceGetEllipsisMaskOption(op, &ellipsis_mask)); + ASSERT_EQ(ellipsis_mask, 0); + + int32_t new_axis_mask; + ASSERT_STATUS_OK(LiteRtStridedSliceGetNewAxisMaskOption(op, &new_axis_mask)); + ASSERT_EQ(new_axis_mask, 0); + + int32_t shrink_axis_mask; + ASSERT_STATUS_OK( + LiteRtStridedSliceGetShrinkAxisMaskOption(op, &shrink_axis_mask)); + ASSERT_EQ(shrink_axis_mask, 0); + + bool offset; + ASSERT_STATUS_OK(LiteRtStridedSliceGetOffsetOption(op, &offset)); + ASSERT_EQ(offset, false); +} + +TEST(GetOpOptionTest, TestGetSubOptions) { + auto model = litert::testing::LoadTestFileModel("simple_sub_op.tflite"); + ASSERT_RESULT_OK_ASSIGN(LiteRtSubgraph subgraph, + graph_tools::GetSubgraph(model.get())); + ASSERT_RESULT_OK_ASSIGN(auto ops, graph_tools::GetSubgraphOps(subgraph)); + auto op = ops[0]; + + uint32_t fused_activation; + ASSERT_STATUS_OK(LiteRtSubGetFusedActivationOption(op, &fused_activation)); + ASSERT_EQ(fused_activation, 0); +} + +TEST(GetOpOptionTest, TestGetReshapeOptions) { + auto model = litert::testing::LoadTestFileModel("simple_reshape_op.tflite"); + ASSERT_RESULT_OK_ASSIGN(LiteRtSubgraph subgraph, + graph_tools::GetSubgraph(model.get())); + ASSERT_RESULT_OK_ASSIGN(auto ops, graph_tools::GetSubgraphOps(subgraph)); + auto op = ops[0]; + + int32_t* new_shape = nullptr; + int32_t new_shape_size; + ASSERT_STATUS_OK( + LiteRtReshapeGetNewShapeOption(op, &new_shape, &new_shape_size)); + ASSERT_EQ(new_shape_size, -1); +} + +} // namespace diff --git a/tensorflow/lite/experimental/litert/core/tensor_buffer.cc b/tensorflow/lite/experimental/litert/core/tensor_buffer.cc new file mode 100644 index 00000000000000..f16c481c0be2cd --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/tensor_buffer.cc @@ -0,0 +1,425 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/core/tensor_buffer.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_event.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/core/ahwb_buffer.h" +#include "tensorflow/lite/experimental/litert/core/dmabuf_buffer.h" +#include "tensorflow/lite/experimental/litert/core/event.h" +#include "tensorflow/lite/experimental/litert/core/fastrpc_buffer.h" +#include "tensorflow/lite/experimental/litert/core/ion_buffer.h" +#include "tensorflow/lite/experimental/litert/core/utils.h" + +namespace { + +template +void Copy(size_t array_size, const T*& array, std::vector& vec) { + vec.clear(); + vec.reserve(array_size); + std::copy(array, array + array_size, std::back_inserter(vec)); + array = vec.data(); +} + +} // namespace + +LiteRtTensorBufferT::LiteRtTensorBufferT( + const LiteRtRankedTensorType& tensor_type, + LiteRtTensorBufferType buffer_type, size_t buffer_size, + size_t buffer_offset) + : tensor_type_(tensor_type), + buffer_type_(buffer_type), + buffer_size_(buffer_size), + buffer_offset_(buffer_offset) { + // Copy local memory passed by the caller. + Copy(tensor_type_.layout.rank, tensor_type_.layout.dimensions, dimensions_); + if (tensor_type_.layout.strides) { + Copy(tensor_type_.layout.rank, tensor_type_.layout.strides, strides_); + } +} + +LiteRtTensorBufferT::~LiteRtTensorBufferT() { + switch (buffer_type()) { + case kLiteRtTensorBufferTypeUnknown: + // Nothing to do. + break; + case kLiteRtTensorBufferTypeHostMemory: + if (auto& buffer = std::get(buffer_); buffer.deallocator) { + buffer.deallocator(buffer.addr); + } + break; + case kLiteRtTensorBufferTypeAhwb: + if (auto& buffer = std::get(buffer_); buffer.deallocator) { + buffer.deallocator(buffer.ahwb); + } + break; + case kLiteRtTensorBufferTypeIon: + if (auto& buffer = std::get(buffer_); buffer.deallocator) { + buffer.deallocator(buffer.addr); + } + break; + case kLiteRtTensorBufferTypeDmaBuf: + if (auto& buffer = std::get(buffer_); buffer.deallocator) { + buffer.deallocator(buffer.addr); + } + break; + case kLiteRtTensorBufferTypeFastRpc: + if (auto& buffer = std::get(buffer_); buffer.deallocator) { + buffer.deallocator(buffer.addr); + } + break; + } +} + +absl::StatusOr +LiteRtTensorBufferT::CreateFromHostMemory( + const LiteRtRankedTensorType& tensor_type, absl::Span host_memory, + LiteRtHostMemoryDeallocator deallocator) { + Ptr tensor_buffer(new LiteRtTensorBufferT( + tensor_type, kLiteRtTensorBufferTypeHostMemory, host_memory.size())); + tensor_buffer->buffer_ = HostBuffer{ + .addr = host_memory.data(), + .deallocator = deallocator, + }; + + if (auto status = tensor_buffer->IsValid(); !status.ok()) { + return status; + } + + return tensor_buffer; +} + +absl::StatusOr +LiteRtTensorBufferT::CreateManagedOnHostMemory( + const LiteRtRankedTensorType& tensor_type, size_t buffer_size) { + void* host_memory_ptr; + if (auto rc = ::posix_memalign( + &host_memory_ptr, LITERT_HOST_MEMORY_BUFFER_ALIGNMENT, buffer_size); + rc) { + return absl::InternalError("Failed to allocate aligned memory"); + } + + LiteRtHostMemoryDeallocator deallocator = ::free; + auto tensor_buffer = CreateFromHostMemory( + tensor_type, + absl::MakeSpan(static_cast(host_memory_ptr), buffer_size), + deallocator); + if (!tensor_buffer.ok()) { + free(host_memory_ptr); + return tensor_buffer.status(); + } + + return std::move(*tensor_buffer); +} + +absl::StatusOr LiteRtTensorBufferT::CreateFromAhwb( + const LiteRtRankedTensorType& tensor_type, AHardwareBuffer* ahwb, + size_t ahwb_offset, LiteRtAhwbDeallocator deallocator) { + auto buffer_size = litert::internal::AhwbBuffer::GetSize(ahwb); + if (!buffer_size.ok()) { + return buffer_size.status(); + } + + Ptr tensor_buffer(new LiteRtTensorBufferT( + tensor_type, kLiteRtTensorBufferTypeAhwb, *buffer_size, ahwb_offset)); + tensor_buffer->buffer_ = AhwbBuffer{ + .ahwb = ahwb, + .deallocator = deallocator, + }; + + if (auto status = tensor_buffer->IsValid(); !status.ok()) { + return status; + } + + return tensor_buffer; +} + +absl::StatusOr +LiteRtTensorBufferT::CreateManagedAhwbBuffer( + const LiteRtRankedTensorType& tensor_type, size_t buffer_size) { + auto buffer = litert::internal::AhwbBuffer::Alloc(buffer_size); + if (!buffer.ok()) { + return buffer.status(); + } + return CreateFromAhwb(tensor_type, buffer->ahwb, /*ahwb_offset=*/0, + /*deallocator=*/litert::internal::AhwbBuffer::Free); +} + +absl::StatusOr +LiteRtTensorBufferT::CreateFromIonBuffer( + const LiteRtRankedTensorType& tensor_type, void* ion_buffer_addr, + int ion_buffer_fd, size_t ion_buffer_size, size_t ion_buffer_offset, + LiteRtIonDeallocator deallocator) { + if (!ion_buffer_addr) { + return absl::InvalidArgumentError("Invalid ION buffer address"); + } + if (ion_buffer_fd < 0) { + return absl::InvalidArgumentError("Invalid ION buffer fd"); + } + + Ptr tensor_buffer( + new LiteRtTensorBufferT(tensor_type, kLiteRtTensorBufferTypeIon, + ion_buffer_size, ion_buffer_offset)); + tensor_buffer->buffer_ = IonBuffer{ + .addr = ion_buffer_addr, + .fd = ion_buffer_fd, + .deallocator = deallocator, + }; + + if (auto status = tensor_buffer->IsValid(); !status.ok()) { + return status; + } + + return tensor_buffer; +} + +absl::StatusOr +LiteRtTensorBufferT::CreateManagedIonBuffer( + const LiteRtRankedTensorType& tensor_type, size_t buffer_size) { + auto buffer = litert::internal::IonBuffer::Alloc( + buffer_size, /*alignment=*/LITERT_HOST_MEMORY_BUFFER_ALIGNMENT); + if (!buffer.ok()) { + return buffer.status(); + } + return CreateFromIonBuffer(tensor_type, buffer->addr, buffer->fd, buffer_size, + /*ion_buffer_offset=*/0, + litert::internal::IonBuffer::Free); +} + +absl::StatusOr +LiteRtTensorBufferT::CreateFromDmaBufBuffer( + const LiteRtRankedTensorType& tensor_type, void* dmabuf_buffer_addr, + int dmabuf_buffer_fd, size_t dmabuf_buffer_size, + size_t dmabuf_buffer_offset, LiteRtDmaBufDeallocator deallocator) { + if (!dmabuf_buffer_addr) { + return absl::InvalidArgumentError("Invalid DMA-BUF buffer address"); + } + if (dmabuf_buffer_fd < 0) { + return absl::InvalidArgumentError("Invalid DMA-BUF buffer fd"); + } + + Ptr tensor_buffer( + new LiteRtTensorBufferT(tensor_type, kLiteRtTensorBufferTypeDmaBuf, + dmabuf_buffer_size, dmabuf_buffer_offset)); + tensor_buffer->buffer_ = DmaBufBuffer{ + .addr = dmabuf_buffer_addr, + .fd = dmabuf_buffer_fd, + .deallocator = deallocator, + }; + + if (auto status = tensor_buffer->IsValid(); !status.ok()) { + return status; + } + + return tensor_buffer; +} + +absl::StatusOr +LiteRtTensorBufferT::CreateManagedDmaBufBuffer( + const LiteRtRankedTensorType& tensor_type, size_t buffer_size) { + auto buffer = litert::internal::DmaBufBuffer::Alloc(buffer_size); + if (!buffer.ok()) { + return buffer.status(); + } + return CreateFromDmaBufBuffer(tensor_type, buffer->addr, buffer->fd, + buffer_size, /*dmabuf_buffer_offset=*/0, + litert::internal::DmaBufBuffer::Free); +} + +absl::StatusOr +LiteRtTensorBufferT::CreateFromFastRpcBuffer( + const LiteRtRankedTensorType& tensor_type, void* fastrpc_buffer_addr, + int fastrpc_buffer_fd, size_t fastrpc_buffer_size, + size_t fastrpc_buffer_offset, LiteRtFastRpcDeallocator deallocator) { + if (!fastrpc_buffer_addr) { + return absl::InvalidArgumentError("Invalid FastRPC buffer address"); + } + if (fastrpc_buffer_fd < 0) { + return absl::InvalidArgumentError("Invalid FastRPC buffer fd"); + } + + Ptr tensor_buffer( + new LiteRtTensorBufferT(tensor_type, kLiteRtTensorBufferTypeFastRpc, + fastrpc_buffer_size, fastrpc_buffer_offset)); + tensor_buffer->buffer_ = FastRpcBuffer{ + .addr = fastrpc_buffer_addr, + .fd = fastrpc_buffer_fd, + .deallocator = deallocator, + }; + + if (auto status = tensor_buffer->IsValid(); !status.ok()) { + return status; + } + + return tensor_buffer; +} + +absl::StatusOr +LiteRtTensorBufferT::CreateManagedFastRpcBuffer( + const LiteRtRankedTensorType& tensor_type, size_t buffer_size) { + auto buffer = litert::internal::FastRpcBuffer::Alloc(buffer_size); + if (!buffer.ok()) { + return buffer.status(); + } + return CreateFromFastRpcBuffer(tensor_type, buffer->addr, buffer->fd, + buffer_size, /*fastrpc_buffer_offset=*/0, + litert::internal::FastRpcBuffer::Free); +} + +absl::StatusOr LiteRtTensorBufferT::CreateManaged( + LiteRtTensorBufferType buffer_type, + const LiteRtRankedTensorType& tensor_type, size_t buffer_size) { + switch (buffer_type) { + case kLiteRtTensorBufferTypeHostMemory: + return CreateManagedOnHostMemory(tensor_type, buffer_size); + case kLiteRtTensorBufferTypeAhwb: + return CreateManagedAhwbBuffer(tensor_type, buffer_size); + case kLiteRtTensorBufferTypeIon: + return CreateManagedIonBuffer(tensor_type, buffer_size); + case kLiteRtTensorBufferTypeDmaBuf: + return CreateManagedDmaBufBuffer(tensor_type, buffer_size); + case kLiteRtTensorBufferTypeFastRpc: + return CreateManagedFastRpcBuffer(tensor_type, buffer_size); + default: + return absl::InvalidArgumentError("Unexpected tensor type"); + } +} + +absl::Status LiteRtTensorBufferT::IsValid() const { + // Check for static dimensions. + for (auto i = 0; i < tensor_type_.layout.rank; ++i) { + if (tensor_type_.layout.dimensions[i] <= 0) { + return absl::InternalError( + "TensorBuffer must have all static dimensions"); + } + } + + // Check for valid offset. + if (buffer_offset() >= buffer_size()) { + return absl::InternalError("Invalid buffer offset"); + } + + // Check for sufficient size. + if (auto num_bytes = litert::internal::GetNumPackedBytes(tensor_type_); + !num_bytes.ok()) { + return num_bytes.status(); + } else if (*num_bytes > buffer_size() - buffer_offset()) { + return absl::InternalError("Insufficient buffer size"); + } + + // Check for proper alignment. + if (buffer_type() == kLiteRtTensorBufferTypeHostMemory) { + auto host_buffer = GetHostBuffer(); + if (!host_buffer.ok()) { + return host_buffer.status(); + } + if (reinterpret_cast(*host_buffer) % + LITERT_HOST_MEMORY_BUFFER_ALIGNMENT) { + return absl::InternalError("Unaligned host memory pointer"); + } + } + + return {}; +} + +absl::StatusOr LiteRtTensorBufferT::GetHostBuffer() const { + if (buffer_type_ != kLiteRtTensorBufferTypeHostMemory) { + return absl::InternalError("Unexpected tensor buffer type"); + } + return std::get(buffer_).addr; +} + +absl::StatusOr LiteRtTensorBufferT::GetAhwbBuffer() const { + if (buffer_type_ != kLiteRtTensorBufferTypeAhwb) { + return absl::InternalError("Unexpected tensor buffer type"); + } + return std::get(buffer_).ahwb; +} + +absl::StatusOr> LiteRtTensorBufferT::GetIonBuffer() + const { + if (buffer_type_ != kLiteRtTensorBufferTypeIon) { + return absl::InternalError("Unexpected tensor buffer type"); + } + auto buffer = std::get(buffer_); + return std::make_pair(buffer.addr, buffer.fd); +} + +absl::StatusOr> LiteRtTensorBufferT::GetDmaBufBuffer() + const { + if (buffer_type_ != kLiteRtTensorBufferTypeDmaBuf) { + return absl::InternalError("Unexpected tensor buffer type"); + } + auto buffer = std::get(buffer_); + return std::make_pair(buffer.addr, buffer.fd); +} + +absl::StatusOr> LiteRtTensorBufferT::GetFastRpcBuffer() + const { + if (buffer_type_ != kLiteRtTensorBufferTypeFastRpc) { + return absl::InternalError("Unexpected tensor buffer type"); + } + auto buffer = std::get(buffer_); + return std::make_pair(buffer.addr, buffer.fd); +} + +absl::StatusOr LiteRtTensorBufferT::Lock(LiteRtEvent event) { + if (event) { + // Only AHWB supports waiting on an input sync fence when locking the + // buffer. For all other buffer types we wait here. + if (buffer_type() != kLiteRtTensorBufferTypeAhwb) { + if (auto status = event->Wait(/*timeout_in_ms*/ -1); + status != kLiteRtStatusOk) { + return absl::InternalError("Failed to wait on input event"); + } + } + } + + switch (buffer_type()) { + case kLiteRtTensorBufferTypeHostMemory: + return *GetHostBuffer(); + case kLiteRtTensorBufferTypeAhwb: + return litert::internal::AhwbBuffer::Lock(*GetAhwbBuffer(), event); + case kLiteRtTensorBufferTypeIon: + return GetIonBuffer()->first; + case kLiteRtTensorBufferTypeDmaBuf: + return GetDmaBufBuffer()->first; + case kLiteRtTensorBufferTypeFastRpc: + return GetFastRpcBuffer()->first; + default: + return absl::InternalError("Unexpected tensor buffer type"); + } +} + +absl::Status LiteRtTensorBufferT::Unlock() { + if (buffer_type() == kLiteRtTensorBufferTypeAhwb) { + auto ahwb = std::get(buffer_).ahwb; + return litert::internal::AhwbBuffer::Unlock(ahwb); + } + + return {}; +} diff --git a/tensorflow/lite/experimental/litert/core/tensor_buffer.h b/tensorflow/lite/experimental/litert/core/tensor_buffer.h new file mode 100644 index 00000000000000..7c2138873bf2f3 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/tensor_buffer.h @@ -0,0 +1,144 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_TENSOR_BUFFER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_TENSOR_BUFFER_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" + +class LiteRtTensorBufferT { + public: + using Ptr = std::unique_ptr; + + ~LiteRtTensorBufferT(); + + // Make this class non-copiable because it includes raw pointers and resource + // handles. + LiteRtTensorBufferT(const LiteRtTensorBufferT&) = delete; + LiteRtTensorBufferT(LiteRtTensorBufferT&&) = delete; + LiteRtTensorBufferT& operator=(const LiteRtTensorBufferT&) = delete; + LiteRtTensorBufferT& operator=(LiteRtTensorBufferT&&) = delete; + + static absl::StatusOr CreateFromHostMemory( + const LiteRtRankedTensorType& tensor_type, + absl::Span host_memory, + LiteRtHostMemoryDeallocator deallocator = nullptr); + + static absl::StatusOr CreateFromAhwb( + const LiteRtRankedTensorType& tensor_type, AHardwareBuffer* ahwb, + size_t ahwb_offset, LiteRtAhwbDeallocator deallocator = nullptr); + + static absl::StatusOr CreateFromIonBuffer( + const LiteRtRankedTensorType& tensor_type, void* ion_buffer_addr, + int ion_buffer_fd, size_t ion_buffer_size, size_t ion_buffer_offset, + LiteRtIonDeallocator deallocator = nullptr); + + static absl::StatusOr CreateFromDmaBufBuffer( + const LiteRtRankedTensorType& tensor_type, void* dmabuf_buffer_addr, + int dmabuf_buffer_fd, size_t dmabuf_buffer_size, + size_t dmabuf_buffer_offset, + LiteRtDmaBufDeallocator deallocator = nullptr); + + static absl::StatusOr CreateFromFastRpcBuffer( + const LiteRtRankedTensorType& tensor_type, void* fastrpc_buffer_addr, + int fastrpc_buffer_fd, size_t fastrpc_buffer_size, + size_t fastrpc_buffer_offset, + LiteRtFastRpcDeallocator deallocator = nullptr); + + static absl::StatusOr CreateManaged( + LiteRtTensorBufferType buffer_type, + const LiteRtRankedTensorType& tensor_type, size_t buffer_size); + + LiteRtRankedTensorType tensor_type() const { return tensor_type_; } + LiteRtTensorBufferType buffer_type() const { return buffer_type_; } + size_t buffer_size() const { return buffer_size_; } + size_t buffer_offset() const { return buffer_offset_; } + + absl::StatusOr GetHostBuffer() const; + absl::StatusOr GetAhwbBuffer() const; + absl::StatusOr> GetIonBuffer() const; + absl::StatusOr> GetDmaBufBuffer() const; + absl::StatusOr> GetFastRpcBuffer() const; + + absl::StatusOr Lock(LiteRtEvent event = nullptr); + absl::Status Unlock(); + + private: + struct HostBuffer { + void* addr; + LiteRtHostMemoryDeallocator deallocator; + }; + + struct AhwbBuffer { + AHardwareBuffer* ahwb; + LiteRtAhwbDeallocator deallocator; + }; + + struct IonBuffer { + void* addr; + int fd; + LiteRtIonDeallocator deallocator; + }; + + struct DmaBufBuffer { + void* addr; + int fd; + LiteRtDmaBufDeallocator deallocator; + }; + + struct FastRpcBuffer { + void* addr; + int fd; + LiteRtFastRpcDeallocator deallocator; + }; + + LiteRtTensorBufferT(const LiteRtRankedTensorType& tensor_type, + LiteRtTensorBufferType buffer_type, size_t buffer_size, + size_t buffer_offset = 0); + + static absl::StatusOr CreateManagedOnHostMemory( + const LiteRtRankedTensorType& tensor_type, size_t buffer_size); + + static absl::StatusOr CreateManagedAhwbBuffer( + const LiteRtRankedTensorType& tensor_type, size_t buffer_size); + + static absl::StatusOr CreateManagedIonBuffer( + const LiteRtRankedTensorType& tensor_type, size_t buffer_size); + + static absl::StatusOr CreateManagedDmaBufBuffer( + const LiteRtRankedTensorType& tensor_type, size_t buffer_size); + + static absl::StatusOr CreateManagedFastRpcBuffer( + const LiteRtRankedTensorType& tensor_type, size_t buffer_size); + + absl::Status IsValid() const; + + LiteRtRankedTensorType tensor_type_; + std::vector> dimensions_; + std::vector> strides_; + LiteRtTensorBufferType buffer_type_; + size_t buffer_size_; + size_t buffer_offset_; + std::variant + buffer_; +}; + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_TENSOR_BUFFER_H_ diff --git a/tensorflow/lite/experimental/litert/core/tfl_utils.cc b/tensorflow/lite/experimental/litert/core/tfl_utils.cc new file mode 100644 index 00000000000000..44d416d5287fdf --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/tfl_utils.cc @@ -0,0 +1,96 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/core/tfl_utils.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tensorflow/lite/c/c_api_opaque.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" + +namespace litert { +namespace internal { + +absl::StatusOr ConvertElementType(TfLiteType tfl_type) { + switch (tfl_type) { + case kTfLiteNoType: + return ElementType::None; + case kTfLiteBool: + return ElementType::Bool; + case kTfLiteInt4: + return ElementType::Int4; + case kTfLiteInt8: + return ElementType::Int8; + case kTfLiteInt16: + return ElementType::Int16; + case kTfLiteInt32: + return ElementType::Int32; + case kTfLiteInt64: + return ElementType::Int64; + case kTfLiteUInt8: + return ElementType::UInt8; + case kTfLiteUInt16: + return ElementType::UInt16; + case kTfLiteUInt32: + return ElementType::UInt32; + case kTfLiteUInt64: + return ElementType::UInt64; + case kTfLiteFloat16: + return ElementType::Float16; + case kTfLiteBFloat16: + return ElementType::BFloat16; + case kTfLiteFloat32: + return ElementType::Float32; + case kTfLiteFloat64: + return ElementType::Float64; + case kTfLiteComplex64: + return ElementType::Complex64; + case kTfLiteComplex128: + return ElementType::Complex128; + case kTfLiteResource: + return ElementType::TfResource; + case kTfLiteString: + return ElementType::TfString; + case kTfLiteVariant: + return ElementType::TfVariant; + default: + return absl::InternalError("Unsupported TfLiteType"); + } +} + +absl::StatusOr ConvertTensorType( + const TfLiteOpaqueTensor* tfl_opaque_tensor) { + auto tfl_type = TfLiteOpaqueTensorType(tfl_opaque_tensor); + auto element_type = ConvertElementType(tfl_type); + if (!element_type.ok()) { + return element_type.status(); + } + + size_t rank = TfLiteOpaqueTensorNumDims(tfl_opaque_tensor); + std::vector dimensions(rank); + for (size_t i = 0; i < rank; ++i) { + dimensions[i] = TfLiteOpaqueTensorDim(tfl_opaque_tensor, i); + } + + return RankedTensorType(*element_type, Layout(std::move(dimensions))); +} + +} // namespace internal +} // namespace litert diff --git a/tensorflow/lite/experimental/litert/core/tfl_utils.h b/tensorflow/lite/experimental/litert/core/tfl_utils.h new file mode 100644 index 00000000000000..c3603fbcc5d9c2 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/tfl_utils.h @@ -0,0 +1,34 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_TFL_UTILS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_TFL_UTILS_H_ + +#include "absl/status/statusor.h" +#include "tensorflow/lite/experimental/litert/cc/litert_model.h" + +struct TfLiteOpaqueTensor; + +namespace litert { +namespace internal { + +absl::StatusOr ConvertElementType(TfLiteType tfl_type); + +absl::StatusOr ConvertTensorType( + const TfLiteOpaqueTensor* tfl_opaque_tensor); + +} // namespace internal +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_TFL_UTILS_H_ diff --git a/tensorflow/lite/experimental/litert/core/util/BUILD b/tensorflow/lite/experimental/litert/core/util/BUILD new file mode 100644 index 00000000000000..8851cbd47ef74d --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/util/BUILD @@ -0,0 +1,56 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], +) + +cc_library( + name = "buffer_ref", + hdrs = [ + "buffer_ref.h", + ], + deps = [ + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "buffer_ref_test", + srcs = ["buffer_ref_test.cc"], + deps = [ + ":buffer_ref", + ":flatbuffer_tools", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "flatbuffer_tools", + srcs = ["flatbuffer_tools.cc"], + hdrs = [ + "flatbuffer_tools.h", + ], + deps = [ + "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@flatbuffers//:runtime_cc", + ], +) diff --git a/tensorflow/lite/experimental/litert/core/util/buffer_ref.h b/tensorflow/lite/experimental/litert/core/util/buffer_ref.h new file mode 100644 index 00000000000000..ce3a9e871f0d6c --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/util/buffer_ref.h @@ -0,0 +1,348 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_BUFFER_REF_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_BUFFER_REF_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" + +namespace litert { + +//===----------------------------------------------------------------------===// +// +// << BUFFER REF >> +// +// Read, read/write, and owning views of buffers of arbitrary byte width types. +// +// Serialized model artifacts and assets are frequently large strings that with +// (annoyingly) non-standard char type and left padded. The following classes +// simplify handling such buffers in an efficient copy free manner. They also +// provide read and write left-padded aware interpretebility through standard +// signed char strings types. This is used for making manual edits to flatbuffer +// metadata or dierctly to serialized flatbuffer. +// NOTE: std::basic_xxx not supported by our C++ toolchain. +// +// Pre-allocated buffers can be transferred to these classes or allocation can +// be internalized. XBufferRefs can be implictly upcasted to non-owning +// read/write or read-only to provide other routines with an appropriate view of +// the data. E.g.: +// +// ``` +// void ReadBuffer(BufferRef r_buf) { std::cerr << r_buf.StrView(); } +// void WriteToBuffer(MutableBufferRef rw_buf) { rw_buf.WriteTo("SomeData"); } +// ... +// OwningBuffer buf(size); +// WriteToBuffer(buf); // Implicitly convert to read/write with no ownership. +// ReadBuffer(buf); // Implicitly convert to read-only. +// ``` +// +//===----------------------------------------------------------------------===// + +// Allocation/Deallocation behavior for owning buffer refs. An allocator is a +// trivially constructible/destructible object that overrides () for allocating +// and freeing memory. + +// Malloc/free based memory. +template +struct Mallocator { + void operator()(ByteT* d) { + if (d != nullptr) { + free(d); + } + } + + ByteT* operator()(size_t bytes) { + return reinterpret_cast(malloc(bytes)); + } +}; + +// New/delete based memory. +template +struct Newlocator { + void operator()(ByteT* d) { + if (d != nullptr) { + delete[] d; + } + } + + ByteT* operator()(size_t bytes) { return new ByteT[bytes]; } +}; + +// +// Read-Only Bytes +// + +// Immutable and non-owning view of a buffer. +template +class BufferRef { + public: + using TupleT = std::tuple; + + // Null buffer. + explicit BufferRef() : size_(0), offset_(0), data_(nullptr) {} + + // Construct from already allocated buffer. Methods will only expose + // data[offset, offset + size]. + BufferRef(const ByteT* data, size_t size, size_t offset = 0) + : size_(size), offset_(offset), data_(const_cast(data)) {} + BufferRef(const void* data, size_t size, size_t offset = 0) + : size_(size), + offset_(offset), + data_(const_cast(reinterpret_cast(data))) {} + explicit BufferRef(absl::Span data) + : size_(data.size()), + offset_(0), + data_(const_cast(data.data())) {} + + // Start of actual data. + const ByteT* Data() const { return data_ + offset_; } + + // Size of actual data. + size_t Size() const { return size_ - offset_; } + + // Get buffer details in tuple form. + TupleT Get() const { return TupleT(data_, size_, offset_); } + + // Start of actual data as signed char. Might not be null terminated. + const char* StrData() const { return reinterpret_cast(Data()); } + + // Convenience view of actual data as a string. Makes null terminated. + absl::string_view StrView() const { + return absl::string_view(StrData(), Size()); + } + + // Const view of actual data. + absl::Span Span() const { + return absl::MakeConstSpan(Data(), Size()); + } + + // Copy the buffer data to a vector. + std::vector ToVec() const { + return std::vector(StrData(), StrData() + Size()); + } + + // Write the string data to a stream. + void WriteStr(std::ostream& out) const { out.write(StrData(), Size()); } + + // Print info about this buffer. + void Dump(std::ostream& out) const { + out << absl::StreamFormat("%s[%lu:%lu]\n", TypeName(), offset_, size_); + } + + BufferRef(const BufferRef& other) = default; + BufferRef& operator=(const BufferRef& other) = default; + + virtual ~BufferRef() = default; + + protected: + size_t size_; + size_t offset_; + ByteT* data_ = nullptr; + + // Debug name. + virtual absl::string_view TypeName() const { return "BufferRef"; } +}; +template +BufferRef(const ByteT*, size_t, size_t) -> BufferRef; + +// +// Read-Write Non-Owning Bytes +// + +// Writeable (but still non-owning) version of BufferRef. +template +class MutableBufferRef : public BufferRef { + public: + using TupleT = std::tuple; + + // Null buffer. + explicit MutableBufferRef() + : BufferRef((ByteT*)nullptr, /*size*/ 0, /*offset*/ 0) {} + + // Create a mutable view from pre-allocated non-const buffer. + MutableBufferRef(ByteT* data, size_t size, size_t offset = 0) + : BufferRef(data, size, offset) {} + MutableBufferRef(void* data, size_t size, size_t offset = 0) + : BufferRef(data, size, offset) {} + explicit MutableBufferRef(absl::Span data) : BufferRef(data) {} + explicit MutableBufferRef(absl::Span data) = delete; + MutableBufferRef(const ByteT*, size_t, size_t) = delete; + MutableBufferRef(const void*, size_t, size_t) = delete; + + // Mutable start of actual data. + ByteT* Data() { return this->data_ + this->offset_; } + + // Get buffer info in tuple form. + TupleT Get() { return TupleT(this->data_, this->size_, this->offset_); } + + // Mutable span of actual data. + absl::Span Span() { return absl::MakeSpan(Data(), this->Size()); } + + // Write string into the actual buffer at offset. Returns false if the entire + // string cannot fit into the actual buffer. + bool WriteInto(absl::string_view str, size_t offset = 0) { + if (str.size() > this->Size() - offset) { + return false; + } + std::memcpy(Data() + offset, str.data(), str.size()); + return true; + } + + MutableBufferRef(const MutableBufferRef& other) = default; + MutableBufferRef& operator=(const MutableBufferRef& other) = default; + + protected: + // Debug name. + absl::string_view TypeName() const override { return "MutableBufferRef"; } +}; +template +MutableBufferRef(ByteT*, size_t, size_t) -> MutableBufferRef; + +// +// Read-Write Owning Bytes +// + +// Writable and owning buffer reference. Can allocate new buffers internally and +// take ownership of existing buffers. Does not support resizing. +template > +class OwningBufferRef : public MutableBufferRef { + public: + using TupleT = std::tuple; + using WeakTupleT = std::tuple; + + // Null buffer. + explicit OwningBufferRef() + : MutableBufferRef(/*data*/ (ByteT*)nullptr, /*size*/ 0) {} + + // Initialize a new buffer reference and allocate internally. + explicit OwningBufferRef(size_t size) + : MutableBufferRef(/*data*/ (ByteT*)nullptr, size, /*offset*/ 0) { + this->data_ = (ByteT*)Allocator()(size); + } + + // Take ownership of given buffer. + OwningBufferRef(ByteT* data, size_t size, size_t offset = 0) + : MutableBufferRef(data, size, offset) {} + OwningBufferRef(void* data, size_t size, size_t offset = 0) + : MutableBufferRef(data, size, offset) {} + explicit OwningBufferRef(absl::Span data) + : MutableBufferRef(data) {} + + // Copy the given buffer. + OwningBufferRef(const ByteT* data, size_t size) + : MutableBufferRef(/*data*/ (ByteT*)nullptr, size, + /*offset*/ 0) { + this->data_ = (ByteT*)Allocator()(size); + std::memcpy(this->data_, data, size); + } + explicit OwningBufferRef(absl::Span data) + : OwningBufferRef(data.data(), data.size()) {} + + // Copy data from givens string. + explicit OwningBufferRef(absl::string_view data) + : OwningBufferRef( + reinterpret_cast(data.data()), data.size()) {} + + // Copy data from given c-style string. + explicit OwningBufferRef(const char* data) + : OwningBufferRef(absl::string_view(data)) {} + + // Drop reference to any owned memory. + void Drop() { + this->data_ = nullptr; + this->size_ = 0; + this->offset_ = 0; + } + + // Get the buffer details and drop references to them. + TupleT Release() { + auto res = std::make_tuple(this->data_, this->size_, this->offset_); + Drop(); + return res; + } + + // Get weak references to buffer data. Takes ownership of anything that + // is swapped in. + WeakTupleT GetWeak() { + return WeakTupleT(this->data_, this->size_, this->offset_); + } + + OwningBufferRef(OwningBufferRef&& other) + : MutableBufferRef(other.data_, other.size_, other.offset_) { + other.data_ = nullptr; + other.size_ = 0; + other.offset_ = 0; + } + + OwningBufferRef& operator=(OwningBufferRef&& other) { + if (this != &other) { + Allocator()(this->data_); + this->data_ = other.data_; + this->size_ = other.size_; + this->offset_ = other.offset_; + other.data_ = nullptr; + other.size_ = 0; + other.offset_ = 0; + } + return *this; + } + + OwningBufferRef(const OwningBufferRef& other) + : MutableBufferRef(/*data*/ (ByteT*)nullptr, other.size_, + other.offset_) { + this->data_ = (ByteT*)Allocator()(other.size_); + std::memcpy(this->data_, other.data_, other.size_); + } + + OwningBufferRef& operator=(const OwningBufferRef& other) { + Allocator()(this->data_); + this->size_ = other.size_; + this->data_ = (ByteT*)Allocator()(this->size_); + std::memcpy(this->data_, other.data_, this->size_); + this->offset_ = other.offset_; + } + + ~OwningBufferRef() override { + Allocator()(this->data_); + this->data_ = nullptr; + this->size_ = 0; + this->offset_ = 0; + } + + protected: + // Debug string. + absl::string_view TypeName() const override { return "OwningBufferRef"; } +}; +template > +OwningBufferRef(const ByteT*, size_t) -> OwningBufferRef; +template > +OwningBufferRef(ByteT*, size_t) -> OwningBufferRef; +template > +OwningBufferRef(const char*) -> OwningBufferRef; + +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_BUFFER_REF_H_ diff --git a/tensorflow/lite/experimental/litert/core/util/buffer_ref_test.cc b/tensorflow/lite/experimental/litert/core/util/buffer_ref_test.cc new file mode 100644 index 00000000000000..50f50eed0c9ffa --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/util/buffer_ref_test.cc @@ -0,0 +1,324 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/core/util/buffer_ref.h" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" + +namespace { + +using ::testing::ElementsAreArray; +using ::testing::Eq; +using ::testing::Pointwise; +using ::testing::StartsWith; + +using ::litert::BufferRef; +using ::litert::Mallocator; +using ::litert::MutableBufferRef; +using ::litert::Newlocator; +using ::litert::OwningBufferRef; +using ::litert::internal::FbBufToStr; + +static constexpr size_t kOffset = 4; + +static constexpr absl::string_view kData = "SomeRawBuffer"; +static constexpr absl::string_view kOtherData = "SOMERawBuffer"; + +absl::Span MakeConstFbData(absl::string_view data) { + const uint8_t* fb_data = reinterpret_cast(data.data()); + return absl::MakeConstSpan(fb_data, data.size()); +} + +absl::Span MakeFbData(absl::string_view data) { + const uint8_t* c_fb_data = reinterpret_cast(data.data()); + uint8_t* fb_data = const_cast(c_fb_data); + return absl::MakeSpan(fb_data, data.size()); +} + +std::vector MakeFbDataVec(absl::string_view data) { + const uint8_t* c_fb_data = reinterpret_cast(data.data()); + uint8_t* fb_data = const_cast(c_fb_data); + return std::vector(fb_data, fb_data + data.size()); +} + +template , typename ByteT = uint8_t> +absl::Span MakeInternalTestBuffer(absl::string_view data) { + ByteT* buffer = Allocator()(data.size()); + std::memcpy(buffer, data.data(), data.size()); + return absl::MakeSpan(reinterpret_cast(buffer), data.size()); +} + +// +// flatbuffer_tools.h +// + +TEST(FbBufToStringTest, ConstSpan) { + EXPECT_THAT(FbBufToStr(MakeConstFbData(kData)), Pointwise(Eq(), kData)); +} + +TEST(FbBufToStringTest, Span) { + EXPECT_THAT(FbBufToStr(MakeFbData(kData)), Pointwise(Eq(), kData)); +} + +TEST(FbBufToStringTest, ConstPointer) { + auto data = MakeConstFbData(kData); + EXPECT_THAT(FbBufToStr(data.data(), data.size()), Pointwise(Eq(), kData)); +} + +TEST(FbBufToStringTest, Pointer) { + auto data = MakeFbData(kData); + EXPECT_THAT(FbBufToStr(data.data(), data.size()), Pointwise(Eq(), kData)); +} + +// +// BufferRef (read-only) +// + +TEST(BufferRefTest, Dump) { + BufferRef buf(kData.data(), kData.size()); + std::stringstream out; + buf.Dump(out); + EXPECT_THAT(out.view(), StartsWith("BufferRef")); +} + +TEST(BufferRefTest, WithData) { + auto data = MakeConstFbData(kData); + BufferRef buf(data.data(), data.size()); + EXPECT_EQ(buf.Span(), data); + EXPECT_EQ(buf.StrView(), kData); +} + +TEST(BufferRefTest, WithDataAndOffset) { + auto data = MakeConstFbData(kData); + BufferRef buf(data.data(), data.size(), kOffset); + EXPECT_EQ(buf.Span(), data.subspan(kOffset, buf.Size())); + EXPECT_EQ(buf.StrView(), kData.substr(kOffset, buf.Size())); +} + +TEST(BufferRefTest, ToVec) { + auto data = MakeConstFbData(kData); + BufferRef buf(data.data(), data.size()); + EXPECT_THAT(buf.ToVec(), ElementsAreArray(data)); +} + +TEST(BufferRefTest, WriteStr) { + auto data = MakeConstFbData(kData); + BufferRef buf(data.data(), data.size()); + std::stringstream out; + buf.WriteStr(out); + EXPECT_EQ(out.view(), kData); +} + +TEST(BufferRefTest, WriteStrOffset) { + auto data = MakeConstFbData(kData); + BufferRef buf(data.data(), data.size(), kOffset); + std::stringstream out; + buf.WriteStr(out); + EXPECT_EQ(out.view(), kData.substr(kOffset, buf.Size())); +} + +TEST(BufferRefTest, TupleGet) { + auto input = MakeConstFbData(kData); + BufferRef buf(input); + auto [data, size, offset] = buf.Get(); + ASSERT_EQ(offset, 0); + EXPECT_EQ(input, buf.Span()); +} + +// +// MutableBufferRef (read/write) +// + +TEST(MutableBufferRefTest, Dump) { + MutableBufferRef buf; + std::stringstream out; + buf.Dump(out); + EXPECT_THAT(out.view(), StartsWith("MutableBufferRef")); +} + +TEST(MutableBufferRefTest, WriteInto) { + auto v_data = MakeFbDataVec(kOtherData); + MutableBufferRef buf(v_data.data(), v_data.size()); + ASSERT_TRUE(buf.WriteInto("Some")); + EXPECT_THAT(buf.Span(), ElementsAreArray(v_data)); + EXPECT_EQ(buf.StrView(), kData); +} + +TEST(MutableBufferRefTest, WriteIntoOffsetBuf) { + auto v_data = MakeFbDataVec(kOtherData); + static constexpr absl::string_view kExpData = "RAWBuffer"; + MutableBufferRef buf(v_data.data(), v_data.size(), kOffset); + ASSERT_TRUE(buf.WriteInto("RAW")); + EXPECT_THAT(buf.Span(), ElementsAreArray(MakeConstFbData(kExpData))); + EXPECT_EQ(buf.StrView(), kExpData); +} + +TEST(MutableBufferRefTest, WriteIntoOffsetData) { + auto v_data = MakeFbDataVec(kOtherData); + static constexpr absl::string_view kExpData = "SOMERAWBuffer"; + MutableBufferRef buf(v_data.data(), v_data.size()); + ASSERT_TRUE(buf.WriteInto("RAW", kOffset)); + EXPECT_THAT(buf.Span(), ElementsAreArray(MakeConstFbData(kExpData))); + EXPECT_EQ(buf.StrView(), kExpData); +} + +TEST(MutableBufferRefTest, TupleGet) { + auto input = MakeInternalTestBuffer("FOO"); + MutableBufferRef buf(input); + auto [data, size, offset] = buf.Get(); + *data = 'b'; + EXPECT_EQ(buf.StrView(), "bOO"); + delete[] input.data(); +} + +// +// OwningBufferRef (read/write with memory management) +// + +TEST(OwningBufferRefTest, Dump) { + OwningBufferRef buf; + std::stringstream out; + buf.Dump(out); + EXPECT_THAT(out.view(), StartsWith("OwningBufferRef")); +} + +TEST(OwningBufferRefTest, MoveCstor) { + auto raw = MakeInternalTestBuffer>(kData); + OwningBufferRef> buf(raw.data(), raw.size()); + OwningBufferRef> other(std::move(buf)); + EXPECT_EQ(other.StrView(), kData); +} + +TEST(OwningBufferRefTest, MoveAssign) { + auto raw = MakeInternalTestBuffer>(kData); + OwningBufferRef> buf(raw.data(), raw.size()); + OwningBufferRef> other = std::move(buf); + EXPECT_EQ(other.StrView(), kData); +} + +TEST(OwningBufferRefTest, CopyCstor) { + auto raw = MakeInternalTestBuffer>(kData); + OwningBufferRef> buf(raw.data(), raw.size()); + OwningBufferRef> other(buf); + other.WriteInto("SOME"); + EXPECT_EQ(buf.StrView(), kData); + EXPECT_EQ(other.StrView(), "SOMERawBuffer"); +} + +TEST(OwningBufferRefTest, CopyAssign) { + auto raw = MakeInternalTestBuffer>(kData); + OwningBufferRef> buf(raw.data(), raw.size()); + OwningBufferRef> other = buf; + other.WriteInto("SOME"); + EXPECT_EQ(buf.StrView(), kData); + EXPECT_EQ(other.StrView(), "SOMERawBuffer"); +} + +TEST(OwningBufferRefTest, InternalMalloc) { + OwningBufferRef> buf(kData.size()); + ASSERT_EQ(buf.Size(), kData.size()); + ASSERT_NE(buf.Data(), nullptr); + + buf.WriteInto(kData); + EXPECT_EQ(buf.StrView(), kData); +} + +TEST(OwningBufferRefTest, InternalNew) { + OwningBufferRef buf(kData.size()); + ASSERT_EQ(buf.Size(), kData.size()); + ASSERT_NE(buf.Data(), nullptr); + + buf.WriteInto(kData); + EXPECT_EQ(buf.StrView(), kData); +} + +TEST(OwningBufferRefTest, TakeOwnershipMalloc) { + auto malloc_buffer = MakeInternalTestBuffer>(kData); + OwningBufferRef> buf(malloc_buffer.data(), + malloc_buffer.size()); + EXPECT_EQ(buf.StrView(), kData); +} + +TEST(OwningBufferRefTest, TakeOwnershipNew) { + auto new_buffer = MakeInternalTestBuffer(kData); + OwningBufferRef buf(new_buffer.data(), new_buffer.size()); + EXPECT_EQ(buf.StrView(), kData); +} + +TEST(OwningBufferRefTest, TakeOwnershipOffset) { + auto malloc_buffer = MakeInternalTestBuffer>(kData); + OwningBufferRef> buf(malloc_buffer.data(), + malloc_buffer.size(), + /*offset=*/4); + EXPECT_EQ(buf.StrView(), "RawBuffer"); +} + +TEST(OwningBufferRefTest, CopyBuffer) { + auto const_buf = MakeConstFbData(kData); + OwningBufferRef buf(const_buf.data(), const_buf.size()); + buf.WriteInto("SOME"); + EXPECT_EQ(buf.StrView(), "SOMERawBuffer"); + EXPECT_EQ(FbBufToStr(const_buf), "SomeRawBuffer"); +} + +TEST(OwningBufferRefTest, ImplicitUpCasts) { + OwningBufferRef buf(kData.size()); + BufferRef c_buf = buf; + + buf.WriteInto(kData); + EXPECT_EQ(c_buf.StrView(), buf.StrView()); +} + +TEST(OwningBufferRefTest, TupleGetWeak) { + auto input = MakeInternalTestBuffer("FOO"); + + OwningBufferRef buf; + auto [data, size, offset] = buf.GetWeak(); + + data = input.data(); + size = input.size(); + offset = 0; + + ASSERT_EQ(buf.Size(), input.size()); + ASSERT_EQ(buf.Size(), input.size()); + + buf.WriteInto("BAR"); + + EXPECT_EQ(buf.StrView(), "BAR"); + EXPECT_EQ(buf.Span(), input); +} + +TEST(OwningBufferRefTest, TupleRelease) { + OwningBufferRef buf("BAZ"); + + auto [data, size, offset] = buf.Release(); + + EXPECT_EQ(buf.Size(), 0); + EXPECT_EQ(absl::string_view(data, size), "BAZ"); + + delete[] data; +} + +} // namespace diff --git a/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.cc b/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.cc new file mode 100644 index 00000000000000..6adbba0826c0aa --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.cc @@ -0,0 +1,69 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" + +#ifndef NDEBUG +// Make flatbuffers verifier `assert` in debug mode. +#define FLATBUFFERS_DEBUG_VERIFICATION_FAILURE + +#include "flatbuffers/flatbuffers.h" // from @flatbuffers // IWYU pragma: keep +#endif + +#include +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "flatbuffers/verifier.h" // from @flatbuffers +#include "tensorflow/lite/schema/schema_generated.h" + +namespace litert::internal { + +using ::flatbuffers::Verifier; +using ::tflite::VerifyModelBuffer; + +absl::string_view FbBufToStr(const uint8_t* fb_data, size_t size) { + auto fb_buf_raw = reinterpret_cast(fb_data); + return absl::string_view(fb_buf_raw, size); +} + +absl::string_view FbBufToStr(absl::Span fb_buf) { + auto fb_buf_raw = reinterpret_cast(fb_buf.data()); + const size_t fb_buf_size = fb_buf.size(); + return absl::string_view(fb_buf_raw, fb_buf_size); +} + +absl::Span FbBufToStr(absl::Span fb_buf) { + return absl::MakeSpan(reinterpret_cast(fb_buf.data()), fb_buf.size()); +} + +absl::Span FbBufToStr(uint8_t* fb_data, size_t size) { + return absl::MakeSpan(reinterpret_cast(fb_data), size); +} + +bool VerifyFlatbuffer(absl::Span buf) { + return VerifyFlatbuffer(buf.data(), buf.size()); +} + +bool VerifyFlatbuffer(const uint8_t* buf, size_t buf_size) { + flatbuffers::Verifier::Options options; +#ifndef NDEBUG + options.assert = true; +#endif + flatbuffers::Verifier verifier(buf, buf_size, options); + return VerifyModelBuffer(verifier); +} + +} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h b/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h new file mode 100644 index 00000000000000..8d5f2c2bbe7416 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h @@ -0,0 +1,48 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_FLATBUFFER_TOOLS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_FLATBUFFER_TOOLS_H_ + +#include + +#include "absl/strings/string_view.h" +#include "absl/types/span.h" + +namespace litert::internal { + +// Flatbuffer's native char type is unsigned char. + +// Convenience method to get string view from native flatbuffer chars. +absl::string_view FbBufToStr(const uint8_t* fb_data, size_t size); + +// Span version. +absl::string_view FbBufToStr(absl::Span fb_buf); + +// Convenience method to get mutable signed char span from native flatbuffer +// chars. +absl::Span FbBufToStr(uint8_t* fb_data, size_t size); + +// Span to span version. +absl::Span FbBufToStr(absl::Span fb_buf); + +// Verifies given serialized flatbuffer +bool VerifyFlatbuffer(const uint8_t* buf, size_t buf_size); + +// Override of above with view input. +bool VerifyFlatbuffer(absl::Span buf); + +} // namespace litert::internal + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTIL_FLATBUFFER_TOOLS_H_ diff --git a/tensorflow/lite/experimental/litert/core/utils.cc b/tensorflow/lite/experimental/litert/core/utils.cc new file mode 100644 index 00000000000000..5e5bb8d9d4b010 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/utils.cc @@ -0,0 +1,90 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/core/utils.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" + +namespace litert { +namespace internal { + +absl::StatusOr GetElementSize(LiteRtElementType element_type) { + switch (element_type) { + case kLiteRtElementTypeInt4: + return Ratio{1, 2}; + case kLiteRtElementTypeBool: + return Ratio{1, 1}; + case kLiteRtElementTypeInt8: + case kLiteRtElementTypeUInt8: + return Ratio{1, 1}; + case kLiteRtElementTypeInt16: + case kLiteRtElementTypeUInt16: + case kLiteRtElementTypeFloat16: + case kLiteRtElementTypeBFloat16: + return Ratio{2, 1}; + case kLiteRtElementTypeInt32: + case kLiteRtElementTypeUInt32: + case kLiteRtElementTypeFloat32: + return Ratio{4, 1}; + case kLiteRtElementTypeInt64: + case kLiteRtElementTypeUInt64: + case kLiteRtElementTypeFloat64: + return Ratio{8, 1}; + case kLiteRtElementTypeComplex64: + return Ratio{16, 1}; + case kLiteRtElementTypeComplex128: + return Ratio{32, 1}; + default: + return absl::InvalidArgumentError("Unexpected element type"); + } +} + +absl::StatusOr GetNumPackedBytes(const LiteRtRankedTensorType& type) { + auto element_size = GetElementSize(type.element_type); + if (!element_size.ok()) { + return element_size.status(); + } + + auto num_elements = GetNumElements(type); + if (!num_elements.ok()) { + return num_elements.status(); + } + + return ((*num_elements * element_size->num) + (element_size->denom - 1)) / + element_size->denom; +} + +absl::StatusOr GetNumElements( + const LiteRtRankedTensorType& tensor_type) { + size_t num_elements = 1; + for (auto i = 0; i < tensor_type.layout.rank; ++i) { + auto dim = tensor_type.layout.dimensions[i]; + if (dim < 0) { + return absl::InvalidArgumentError( + "Unexpected dynamic tensor passed as input"); + } else if (dim == 0) { + return absl::InvalidArgumentError("Unexpected 0 tensor dimension"); + } + num_elements *= dim; + } + + return num_elements; +} + +} // namespace internal +} // namespace litert diff --git a/tensorflow/lite/experimental/litert/core/utils.h b/tensorflow/lite/experimental/litert/core/utils.h new file mode 100644 index 00000000000000..309efb712cbb98 --- /dev/null +++ b/tensorflow/lite/experimental/litert/core/utils.h @@ -0,0 +1,48 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTILS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTILS_H_ + +#include + +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" + +namespace litert { +namespace internal { + +struct Ratio { + using Type = int; + Type num; + Type denom; + std::string ToString() const { return absl::StrCat(num, "/", denom); } +}; + +absl::StatusOr GetElementSize(LiteRtElementType element_type); + +absl::StatusOr GetNumElements( + const LiteRtRankedTensorType& tensor_type); + +// Get the number of bytes necessary to represent a tensor type, ignoring any +// stride information. +absl::StatusOr GetNumPackedBytes( + const LiteRtRankedTensorType& tensor_type); + +} // namespace internal +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_CORE_UTILS_H_ diff --git a/tensorflow/lite/experimental/litert/integration_test/BUILD b/tensorflow/lite/experimental/litert/integration_test/BUILD new file mode 100644 index 00000000000000..23b07d5602d7c8 --- /dev/null +++ b/tensorflow/lite/experimental/litert/integration_test/BUILD @@ -0,0 +1,18 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], +) diff --git a/tensorflow/lite/experimental/litert/test/BUILD b/tensorflow/lite/experimental/litert/test/BUILD new file mode 100644 index 00000000000000..697f7515f89af3 --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/BUILD @@ -0,0 +1,96 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], +) + +# TODO: b/365295276 - Make custom rule and move to `.sh`. +OUT_DIR = "$(RULEDIR)" + +CONVERTER = "//tensorflow/compiler/mlir/lite:tf_tfl_translate" + +CMD = """ +for mlir_file in $(SRCS); do + $(location {converter}) --input-mlir $$mlir_file --o={out_dir}/testdata/$$(basename $$mlir_file .mlir).tflite +done +""".format( + converter = CONVERTER, + out_dir = OUT_DIR, +) + +genrule( + name = "tflite_test_data", + srcs = glob(["testdata/*.mlir"]), + outs = [s.removesuffix(".mlir") + ".tflite" for s in glob(["testdata/*.mlir"])], + cmd = CMD, + tools = [CONVERTER], +) + +cc_library( + name = "common", + testonly = 1, + srcs = [ + "common.cc", + ], + hdrs = [ + "common.h", + ], + deps = [ + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/core:litert_model_init", + "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@flatbuffers//:runtime_cc", + "@local_tsl//tsl/platform", + ], +) + +cc_library( + name = "simple_model", + testonly = 1, + hdrs = [ + "testdata/simple_model_test_vectors.h", + ], + data = [ + "testdata/simple_model.tflite", + "testdata/simple_model_google_tensor.bin", + "testdata/simple_model_qualcomm.bin", + ], + deps = [ + "//tensorflow/lite/experimental/litert/c:litert_c_api", + ], +) + +cc_library( + name = "simple_model_npu", + testonly = 1, + srcs = [], + hdrs = [ + "testdata/simple_model_test_vectors.h", + ], + data = [ + "testdata/simple_model_google_tensor.bin", + "testdata/simple_model_npu.tflite", + "testdata/simple_model_qualcomm.bin", + ], + deps = [ + "//tensorflow/lite/experimental/litert/c:litert_c_api", + ], +) diff --git a/tensorflow/lite/experimental/litert/test/common.cc b/tensorflow/lite/experimental/litert/test/common.cc new file mode 100644 index 00000000000000..21b376ab25a1b7 --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/common.cc @@ -0,0 +1,83 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/test/common.h" + +// NOLINTNEXTLINE +#include +#include +#include +#include + +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/core/litert_model_init.h" +#include "tsl/platform/platform.h" + +namespace litert { +namespace testing { + +std::string GetTestFilePath(absl::string_view filename) { + static constexpr std::string_view kTestDataDir = + "tensorflow/lite/experimental/litert/" + "test/testdata/"; + + std::filesystem::path result_path; + if constexpr (!tsl::kIsOpenSource) { + result_path.append("third_party"); + } + + result_path.append(kTestDataDir); + result_path.append(filename.data()); + + return result_path.generic_string(); +} + +absl::StatusOr> LoadBinaryFile(absl::string_view filename) { + std::string model_path = GetTestFilePath(filename); + ABSL_CHECK(std::filesystem::exists(model_path)); + auto size = std::filesystem::file_size(model_path); + std::vector buffer(size); + std::ifstream f(model_path, std::ifstream::binary); + if (!f) { + return absl::InternalError("Failed to open file"); + } + f.read(buffer.data(), buffer.size()); + if (!f) { + return absl::InternalError("Failed to read file"); + } + f.close(); + return buffer; +} + +UniqueLiteRtModel LoadTestFileModel(absl::string_view filename) { + LiteRtModel model = nullptr; + LITERT_CHECK_STATUS_OK( + LoadModelFromFile(GetTestFilePath(filename).data(), &model)); + ABSL_CHECK_NE(model, nullptr); + return UniqueLiteRtModel(model); +} + +void TouchTestFile(absl::string_view filename, absl::string_view dir) { + std::filesystem::path path(dir.data()); + path.append(filename.data()); + std::ofstream f(path); +} + +} // namespace testing +} // namespace litert diff --git a/tensorflow/lite/experimental/litert/test/common.h b/tensorflow/lite/experimental/litert/test/common.h new file mode 100644 index 00000000000000..0cc3bf897f9a38 --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/common.h @@ -0,0 +1,64 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_COMMON_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_COMMON_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "tensorflow/lite/experimental/litert/core/litert_model_init.h" + +#define _ASSERT_RESULT_OK_ASSIGN(decl, expr, result) \ + auto result = (expr); \ + ASSERT_TRUE(result.HasValue()); \ + decl = result.Value(); + +#define ASSERT_RESULT_OK_ASSIGN(decl, expr) \ + _ASSERT_RESULT_OK_ASSIGN(decl, expr, _CONCAT_NAME(_result, __COUNTER__)) + +#define _ASSERT_RESULT_OK_MOVE(decl, expr, result) \ + auto result = (expr); \ + ASSERT_TRUE(result.HasValue()); \ + decl = std::move(result.Value()); + +#define ASSERT_RESULT_OK_MOVE(decl, expr) \ + _ASSERT_RESULT_OK_MOVE(decl, expr, _CONCAT_NAME(_result, __COUNTER__)) + +#define ASSERT_STATUS_HAS_CODE(expr, code) \ + { \ + LiteRtStatus status = (expr); \ + ASSERT_EQ(status, code); \ + } + +#define ASSERT_STATUS_OK(expr) ASSERT_STATUS_HAS_CODE(expr, kLiteRtStatusOk); + +namespace litert { +namespace testing { + +std::string GetTestFilePath(absl::string_view filename); + +absl::StatusOr> LoadBinaryFile(absl::string_view filename); + +UniqueLiteRtModel LoadTestFileModel(absl::string_view filename); + +void TouchTestFile(absl::string_view filename, absl::string_view dir); + +} // namespace testing +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_COMMON_H_ diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/test_data/add_cst.mlir b/tensorflow/lite/experimental/litert/test/testdata/add_cst.mlir similarity index 100% rename from tensorflow/compiler/mlir/lite/experimental/lrt/test_data/add_cst.mlir rename to tensorflow/lite/experimental/litert/test/testdata/add_cst.mlir diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/test_data/add_simple.mlir b/tensorflow/lite/experimental/litert/test/testdata/add_simple.mlir similarity index 100% rename from tensorflow/compiler/mlir/lite/experimental/lrt/test_data/add_simple.mlir rename to tensorflow/lite/experimental/litert/test/testdata/add_simple.mlir diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/test_data/mul_simple.mlir b/tensorflow/lite/experimental/litert/test/testdata/mul_simple.mlir similarity index 100% rename from tensorflow/compiler/mlir/lite/experimental/lrt/test_data/mul_simple.mlir rename to tensorflow/lite/experimental/litert/test/testdata/mul_simple.mlir diff --git a/tensorflow/lite/experimental/litert/test/testdata/one_mul.mlir b/tensorflow/lite/experimental/litert/test/testdata/one_mul.mlir new file mode 100644 index 00000000000000..afabf1903ee846 --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/one_mul.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_add_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_add_op.mlir new file mode 100644 index 00000000000000..0902f5966f8266 --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/simple_add_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x128x1xf32>, %arg1: tensor<1x128x1xf32>) -> tensor<1x128x1xf32> { + %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1x128x1xf32> + return %0 : tensor<1x128x1xf32> +} +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_batch_matmul_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_batch_matmul_op.mlir new file mode 100644 index 00000000000000..e756a0dab87cbc --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/simple_batch_matmul_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x4x256x128xf32>, %arg1: tensor<1x4x128x128xf32>) -> tensor<1x4x256x128xf32> { + %0 = "tfl.batch_matmul"(%arg0, %arg1) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<1x4x256x128xf32>, tensor<1x4x128x128xf32>) -> tensor<1x4x256x128xf32> + return %0 : tensor<1x4x256x128xf32> +} +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_concatenation_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_concatenation_op.mlir new file mode 100644 index 00000000000000..e1e9bd36ae01b0 --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/simple_concatenation_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<128x4x1x256xf32>, %arg1: tensor<128x4x1x256xf32>) -> tensor<128x4x2x256xf32> { + %0 = "tfl.concatenation"(%arg0, %arg1) <{axis = 2 : i32, fused_activation_function = "NONE"}> : (tensor<128x4x1x256xf32>, tensor<128x4x1x256xf32>) -> tensor<128x4x2x256xf32> + return %0 : tensor<128x4x2x256xf32> +} +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_div_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_div_op.mlir new file mode 100644 index 00000000000000..3748d45bcd5249 --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/simple_div_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x128x8x128xf32>, %arg1: tensor<1x128x8x128xf32>) -> tensor<1x128x8x128xf32> { + %0 = tfl.div %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1x128x8x128xf32> + return %0 : tensor<1x128x8x128xf32> +} +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_floor_mod_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_floor_mod_op.mlir new file mode 100644 index 00000000000000..6bd3f1fa79d77c --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/simple_floor_mod_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> tensor<5xf32> { + %0 = "tfl.floor_mod"(%arg0, %arg1) : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32> + return %0 : tensor<5xf32> +} +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_fully_connected_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_fully_connected_op.mlir new file mode 100644 index 00000000000000..5cad120662635e --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/simple_fully_connected_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<128x2048xf32>, %arg1: tensor<2304x2048xf32>, %arg2: none) -> tensor<128x2304xf32> { + %0 = "tfl.fully_connected"(%arg0, %arg1, %arg2) <{asymmetric_quantize_inputs = false, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"}> : (tensor<128x2048xf32>, tensor<2304x2048xf32>, none) -> tensor<128x2304xf32> + return %0 : tensor<128x2304xf32> +} +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_model_google_tensor.bin b/tensorflow/lite/experimental/litert/test/testdata/simple_model_google_tensor.bin new file mode 100644 index 00000000000000..208cb983671510 Binary files /dev/null and b/tensorflow/lite/experimental/litert/test/testdata/simple_model_google_tensor.bin differ diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_model_npu.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_model_npu.mlir new file mode 100644 index 00000000000000..3c2907c3b030eb --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/simple_model_npu.mlir @@ -0,0 +1,6 @@ +module { + func.func @main(%x: tensor<2xf32>, %y: tensor<2xf32>) -> tensor<2xf32> { + %out = "tfl.custom"(%x, %y) {custom_code = "dispatch_node", custom_option = #tfl} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + func.return %out : tensor<2xf32> + } +} diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_model_qualcomm.bin b/tensorflow/lite/experimental/litert/test/testdata/simple_model_qualcomm.bin new file mode 100644 index 00000000000000..a66f76296d7698 Binary files /dev/null and b/tensorflow/lite/experimental/litert/test/testdata/simple_model_qualcomm.bin differ diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h b/tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h new file mode 100644 index 00000000000000..dd850fa58b72f8 --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h @@ -0,0 +1,70 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_TESTDATA_SIMPLE_MODEL_TEST_VECTORS_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_TESTDATA_SIMPLE_MODEL_TEST_VECTORS_H_ + +#include +#include + +#include "tensorflow/lite/experimental/litert/c/litert_model.h" + +constexpr const char* kModelFileName = "simple_model.tflite"; +constexpr const char* kQualcommModelFileName = "simple_model_qualcomm.bin"; +constexpr const char* kGoogleTensorModelFileName = + "simple_model_google_tensor.bin"; + +constexpr const int32_t kTestInput0Dimensions[] = {2}; +constexpr const int32_t kNumTestInput0Dimensions = + sizeof(kTestInput0Dimensions) / sizeof(kTestInput0Dimensions[0]); +constexpr const int32_t kTestInput1Dimensions[] = {2}; +constexpr const int32_t kNumTestInput1Dimensions = + sizeof(kTestInput1Dimensions) / sizeof(kTestInput1Dimensions[0]); +constexpr const int32_t kTestOutputDimensions[] = {2}; +constexpr const int32_t kNumTestOutputDimensions = + sizeof(kTestOutputDimensions) / sizeof(kTestOutputDimensions[0]); + +constexpr const float kTestInput0Tensor[] = {1, 2}; +constexpr const float kTestInput1Tensor[] = {10, 20}; +constexpr const float kTestOutputTensor[] = {11, 22}; + +constexpr const size_t kTestInput0Size = + sizeof(kTestInput0Tensor) / sizeof(kTestInput0Tensor[0]); +constexpr const size_t kTestInput1Size = + sizeof(kTestInput1Tensor) / sizeof(kTestInput1Tensor[0]); +constexpr const size_t kTestOutputSize = + sizeof(kTestOutputTensor) / sizeof(kTestOutputTensor[0]); + +constexpr const LiteRtRankedTensorType kInput0TensorType = { + /*.element_type=*/kLiteRtElementTypeFloat32, + /*.layout=*/{ + /*.rank=*/kNumTestInput0Dimensions, + /*.dimensions=*/kTestInput0Dimensions, + }}; + +constexpr const LiteRtRankedTensorType kInput1TensorType = { + /*.element_type=*/kLiteRtElementTypeFloat32, + /*.layout=*/{ + /*.rank=*/kNumTestInput1Dimensions, + /*.dimensions=*/kTestInput1Dimensions, + }}; + +constexpr const LiteRtRankedTensorType kOutputTensorType = { + /*.element_type=*/kLiteRtElementTypeFloat32, + /*.layout=*/{ + /*.rank=*/kNumTestOutputDimensions, + /*.dimensions=*/kTestOutputDimensions, + }}; + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TEST_TESTDATA_SIMPLE_MODEL_TEST_VECTORS_H_ diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_mul_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_mul_op.mlir new file mode 100644 index 00000000000000..7fb5ac2d2187f0 --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/simple_mul_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x128x2304xf32>, %arg1: tensor<1x128x2304xf32>) -> tensor<1x128x2304xf32> { + %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1x128x2304xf32> + return %0 : tensor<1x128x2304xf32> +} +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/test_data/simple_multi_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_multi_op.mlir similarity index 100% rename from tensorflow/compiler/mlir/lite/experimental/lrt/test_data/simple_multi_op.mlir rename to tensorflow/lite/experimental/litert/test/testdata/simple_multi_op.mlir diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_reshape_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_reshape_op.mlir new file mode 100644 index 00000000000000..515db6e424e6a7 --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/simple_reshape_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x128x4x256xf32>, %arg1: tensor<4xi32>) -> tensor<128x4x1x256xf32> { + %0 = "tfl.reshape"(%arg0, %arg1) : (tensor<1x128x4x256xf32>, tensor<4xi32>) -> tensor<128x4x1x256xf32> + return %0 : tensor<128x4x1x256xf32> +} +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_rsqrt_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_rsqrt_op.mlir new file mode 100644 index 00000000000000..5083f3f3a30383 --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/simple_rsqrt_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x128x1xf32>) -> tensor<1x128x1xf32> { + %0 = "tfl.rsqrt"(%arg0) : (tensor<1x128x1xf32>) -> tensor<1x128x1xf32> + return %0 : tensor<1x128x1xf32> +} +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_select_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_select_op.mlir new file mode 100644 index 00000000000000..2405e5d3626893 --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/simple_select_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x128x8x128xi1>, %arg1: tensor<1x128x8x128xf32>, %arg2: tensor<1x128x8x128xf32>) -> tensor<1x128x8x128xf32> { + %0 = "tfl.select"(%arg0, %arg1, %arg2) : (tensor<1x128x8x128xi1>, tensor<1x128x8x128xf32>, tensor<1x128x8x128xf32>) -> tensor<1x128x8x128xf32> + return %0 : tensor<1x128x8x128xf32> +} +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_select_v2_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_select_v2_op.mlir new file mode 100644 index 00000000000000..0f6ca0a8ed6d5a --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/simple_select_v2_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x128x1x128xi1>, %arg1: tensor<1x128x8x128xf32>, %arg2: tensor<1x128x8x128xf32>) -> tensor<1x128x8x128xf32> { + %0 = "tfl.select_v2"(%arg0, %arg1, %arg2) : (tensor<1x128x1x128xi1>, tensor<1x128x8x128xf32>, tensor<1x128x8x128xf32>) -> tensor<1x128x8x128xf32> + return %0 : tensor<1x128x8x128xf32> +} +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_slice_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_slice_op.mlir new file mode 100644 index 00000000000000..117b9feb3758a6 --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/simple_slice_op.mlir @@ -0,0 +1,8 @@ +module { +func.func @main(%arg0: tensor<1x128x8x256xf32>) -> tensor<1x128x8x128xf32> { + %cst_0 = "tfl.pseudo_const"() <{value = dense<0> : tensor<4xi32>}> : () -> tensor<4xi32> + %cst_1 = "tfl.pseudo_const"() <{value = dense<[1, 128, 4, 128]> : tensor<4xi32>}> : () -> tensor<4xi32> + %0 = "tfl.slice"(%arg0, %cst_0, %cst_1) : (tensor<1x128x8x256xf32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x128x8x128xf32> + return %0 : tensor<1x128x8x128xf32> +} +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_softmax_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_softmax_op.mlir new file mode 100644 index 00000000000000..bb3a83a3787f6f --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/simple_softmax_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = "tfl.softmax"(%arg0) <{beta = 1.000000e+00 : f32}> : (tensor<8x128xf32>) -> tensor<8x128xf32> + return %0 : tensor<8x128xf32> +} +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_stablehlo_scatter_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_stablehlo_scatter_op.mlir new file mode 100644 index 00000000000000..9d098eb0b9f61d --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/simple_stablehlo_scatter_op.mlir @@ -0,0 +1,9 @@ +module { +func.func @main(%arg0: tensor<1x128x4x256xf32>, %arg1: tensor<131072x4xi32>, %arg2: tensor<131072xf32>) -> tensor<1x128x4x256xf32> { + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{scatter_dimension_numbers = #stablehlo.scatter}> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + stablehlo.return %arg4 : tensor + }) : (tensor<1x128x4x256xf32>, tensor<131072x4xi32>, tensor<131072xf32>) -> tensor<1x128x4x256xf32> + return %0 : tensor<1x128x4x256xf32> +} +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_strided_slice_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_strided_slice_op.mlir new file mode 100644 index 00000000000000..373eff80ff3cd8 --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/simple_strided_slice_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x128x4x256xf32>, %arg1: tensor<4xi32>, %arg2: tensor<4xi32>, %arg3: tensor<4xi32>) -> tensor<1x128x4x128xf32> { + %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32}> : (tensor<1x128x4x256xf32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) -> tensor<1x128x4x128xf32> + return %0 : tensor<1x128x4x128xf32> +} +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_sub_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_sub_op.mlir new file mode 100644 index 00000000000000..e1483fed87d802 --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/simple_sub_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x128x4x128xf32>, %arg1: tensor<1x128x4x128xf32>) -> tensor<1x128x4x128xf32> { + %0 = tfl.sub %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1x128x4x128xf32> + return %0 : tensor<1x128x4x128xf32> +} +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_sum_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_sum_op.mlir new file mode 100644 index 00000000000000..d494541a39d79c --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/simple_sum_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x128x2304xf32>, %arg1: tensor<1xi32>) -> tensor<1x128x1xf32> { + %0 = "tfl.sum"(%arg0, %arg1) <{keep_dims = true}> : (tensor<1x128x2304xf32>, tensor<1xi32>) -> tensor<1x128x1xf32> + return %0 : tensor<1x128x1xf32> +} +} diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_tanh_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_tanh_op.mlir new file mode 100644 index 00000000000000..ce1d0302c8a838 --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/simple_tanh_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<1x128x8x128xf32>) -> tensor<1x128x8x128xf32> { + %0 = "tfl.tanh"(%arg0) : (tensor<1x128x8x128xf32>) -> tensor<1x128x8x128xf32> + return %0 : tensor<1x128x8x128xf32> +} +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/test/testdata/simple_transpose_op.mlir b/tensorflow/lite/experimental/litert/test/testdata/simple_transpose_op.mlir new file mode 100644 index 00000000000000..456fa371d13a76 --- /dev/null +++ b/tensorflow/lite/experimental/litert/test/testdata/simple_transpose_op.mlir @@ -0,0 +1,6 @@ +module { +func.func @main(%arg0: tensor<128x4x2x128xf32>, %arg1: tensor<4xi32>) -> tensor<2x128x4x128xf32> { + %0 = "tfl.transpose"(%arg0, %arg1) : (tensor<128x4x2x128xf32>, tensor<4xi32>) -> tensor<2x128x4x128xf32> + return %0 : tensor<2x128x4x128xf32> +} +} \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/tools/BUILD b/tensorflow/lite/experimental/litert/tools/BUILD new file mode 100644 index 00000000000000..cad6dd970d81b1 --- /dev/null +++ b/tensorflow/lite/experimental/litert/tools/BUILD @@ -0,0 +1,136 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], +) + +cc_library( + name = "apply_plugin", + testonly = 1, + srcs = ["apply_plugin.cc"], + hdrs = ["apply_plugin.h"], + deps = [ + ":dump", + ":tool_display", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/core:litert_model_init", + "//tensorflow/lite/experimental/litert/core:litert_model_serialize", + "//tensorflow/lite/experimental/litert/core/compiler_plugin", + "//tensorflow/lite/experimental/litert/core/compiler_plugin:algo", + "//tensorflow/lite/experimental/litert/core/util:buffer_ref", + "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", + "//tensorflow/lite/experimental/litert/test:common", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "apply_plugin_test", + srcs = ["apply_plugin_test.cc"], + data = [ + "//tensorflow/lite/experimental/litert/test:tflite_test_data", + "//tensorflow/lite/experimental/litert/vendors/examples:example_plugin_so", + ], + tags = [ + "noasan", + "nomsan", + "nosan", + ], + deps = [ + ":apply_plugin", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/core:graph_tools", + "//tensorflow/lite/experimental/litert/core:litert_model_init", + "//tensorflow/lite/experimental/litert/core:litert_model_serialize", + "//tensorflow/lite/experimental/litert/core:model", + "//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools", + "//tensorflow/lite/experimental/litert/test:common", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) + +cc_binary( + name = "apply_plugin_main", + testonly = 1, + srcs = ["apply_plugin_main.cc"], + data = ["//tensorflow/lite/experimental/litert/vendors/examples:example_plugin_so"], + linkstatic = 1, + tags = [ + "noasan", + "nomsan", + "nosan", + ], + deps = [ + ":apply_plugin", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + ], +) + +cc_library( + name = "tool_display", + srcs = ["tool_display.cc"], + hdrs = ["tool_display.h"], + deps = [ + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "tool_display_test", + srcs = ["tool_display_test.cc"], + data = ["//tensorflow/lite/experimental/litert/test:tflite_test_data"], + deps = [ + ":tool_display", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "dump", + srcs = ["dump.cc"], + hdrs = ["dump.h"], + deps = [ + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/core:model", + "//tensorflow/lite/experimental/litert/core/compiler_plugin:compiler_plugin_hdr", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_test( + name = "dump_test", + srcs = ["dump_test.cc"], + data = ["//tensorflow/lite/experimental/litert/test:tflite_test_data"], + deps = [ + ":dump", + "//tensorflow/lite/experimental/litert/core:model", + "//tensorflow/lite/experimental/litert/test:common", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/lite/experimental/litert/tools/apply_plugin.cc b/tensorflow/lite/experimental/litert/tools/apply_plugin.cc new file mode 100644 index 00000000000000..7427cfa37439b6 --- /dev/null +++ b/tensorflow/lite/experimental/litert/tools/apply_plugin.cc @@ -0,0 +1,504 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/tools/apply_plugin.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/log/absl_check.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_support.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/core/compiler_plugin/algo.h" +#include "tensorflow/lite/experimental/litert/core/compiler_plugin/compiler_plugin.h" +#include "tensorflow/lite/experimental/litert/core/litert_model_init.h" +#include "tensorflow/lite/experimental/litert/core/litert_model_serialize.h" +#include "tensorflow/lite/experimental/litert/core/util/buffer_ref.h" +#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h" +#include "tensorflow/lite/experimental/litert/tools/dump.h" +#include "tensorflow/lite/experimental/litert/tools/tool_display.h" + +namespace litert::tools { + +using ::litert::internal::CompilerPlugin; +using ::litert::internal::Dump; +using ::litert::internal::GroupPartitions; +using ::litert::internal::OutlinePartition; +using ::litert::internal::VerifyFlatbuffer; +using ::litert::tools::ApplyPluginRun; + +#define _ENSURE_CONFIG(expr) \ + if (!(expr)) { \ + return kLiteRtStatusErrorInvalidToolConfig; \ + } + +namespace { + +static constexpr absl::string_view kArt = R"( + __ _ __ ____ __ + / / (_/ /____ / __ \/ /_ + / / / / __/ _ \/ /_/ / __/ + / /___/ / /_/ __/ _, _/ /_ +/_____/_/\__/\___/_/ |_|\__/ +)"; + +class Context { + public: + using Ptr = std::unique_ptr; + using ResultT = LiteRtResult; + + explicit Context(ApplyPluginRun::Ptr run) + : run_(std::move(run)), + display_(ToolDisplay(run_->dump_out, Context::CmdStr(run_->cmd))) {} + + ApplyPluginRun::Cmd Cmd() const { return run_->cmd; } + + absl::Span LibSearchPaths() const { + return absl::MakeConstSpan(run_->lib_search_paths.data(), + run_->lib_search_paths.size()); + } + + absl::string_view SocModelTarget() const { + ABSL_CHECK_EQ(run_->soc_models.size(), 1); + return run_->soc_models.front(); + } + + std::ostream& Out() { + ABSL_CHECK_EQ(run_->outs.size(), 1); + return run_->outs.front(); + } + + ApplyPluginRun::OutStreamT SwapOut(ApplyPluginRun::OutStreamT out) { + ABSL_CHECK_EQ(run_->outs.size(), 1); + auto res = run_->outs.front(); + run_->outs.at(0) = out; + return res; + } + + const ApplyPluginRun& Run() const { return *run_; } + ApplyPluginRun& Run() { return *run_; } + + ToolDisplay& Dump() { return display_; } + + void DumpPrelude(); + + static absl::string_view CmdStr(ApplyPluginRun::Cmd cmd); + + private: + ApplyPluginRun::Ptr run_; + ToolDisplay display_; +}; + +absl::string_view Context::CmdStr(ApplyPluginRun::Cmd cmd) { + switch (cmd) { + case ApplyPluginRun::Cmd::INFO: + return "INFO"; + case ApplyPluginRun::Cmd::NOOP: + return "NOOP"; + case ApplyPluginRun::Cmd::PARTITION: + return "PARTITION"; + case ApplyPluginRun::Cmd::COMPILE: + return "COMPILE"; + case ApplyPluginRun::Cmd::APPLY: + return "APPLY"; + } +} + +void Context::DumpPrelude() { + Dump().Display() << kArt << "\n"; + // TODO pretty print run struct. +} + +CompilerPlugin::ResultVecT LoadAllPlugins(Context* ctx) { + ctx->Dump().Start("Load Plugins"); + ctx->Dump().Labeled() << "Loading plugins from: "; + const auto paths = ctx->LibSearchPaths(); + for (auto it = paths.begin(); it < paths.end(); ++it) { + ctx->Dump().Display() << *it; + if (it < paths.end() - 1) { + ctx->Dump().Display() << ", "; + } + } + ctx->Dump().Display() << "\n"; + + auto plugins = CompilerPlugin::LoadPlugins(ctx->LibSearchPaths()); + if (!plugins.HasValue()) { + ctx->Dump().Fail(); + return plugins; + } + ctx->Dump().Labeled() << "Found plugins\n"; + ctx->Dump().Labeled() << absl::StreamFormat("Loaded %lu plugins\n", + plugins.Value().size()); + + ctx->Dump().Done(); + return plugins; +} + +CompilerPlugin::ResultT LoadPlugin(Context* ctx) { + LITERT_MOVE_OR_RETURN_RESULT(auto plugins, LoadAllPlugins(ctx), + CompilerPlugin); + ctx->Dump().Start("Select Plugin"); + + for (auto& plugin : plugins) { + if (plugin.SocManufacturer() == ctx->Run().soc_manufacturer) { + ctx->Dump().Done(); + return CompilerPlugin::ResultT::TakeValue(std::move(plugin)); + } + } + + ctx->Dump().Fail(); + return CompilerPlugin::ResultT::FromStatus(kLiteRtStatusErrorNotFound); +} + +LiteRtResult LoadModel(Context* ctx) { + ctx->Dump().Start("Load Model"); + ctx->Dump().Labeled() << absl::StreamFormat("Loading model from: %s\n", + ctx->Run().model.value()); + + LiteRtModel model; + if (LoadModelFromFile(ctx->Run().model->data(), &model) != kLiteRtStatusOk) { + ctx->Dump().Fail(); + return LiteRtResult::FromStatus( + kLiteRtStatusErrorFileIO); + } + + ctx->Dump().Labeled(); + Dump(*model, ctx->Dump().Display()); + + ctx->Dump().Done(); + return LiteRtResult::TakeValue(UniqueLiteRtModel(model)); +} + +LiteRtStatus SerializeModel(Context* ctx, UniqueLiteRtModel model) { + ctx->Dump().Start("Serialize Model"); + + OwningBufferRef buf; + auto [data, size, offset] = buf.GetWeak(); + + LITERT_RETURN_STATUS_IF_NOT_OK( + SerializeModel(model.release(), &data, &size, &offset)); + LITERT_ENSURE(VerifyFlatbuffer(buf.Span()), + kLiteRtStatusErrorInvalidFlatbuffer, + "Failed to verify flatbuffer."); + + buf.WriteStr(ctx->Out()); + ctx->Dump().Labeled() << "Serialized a model... "; + buf.Dump(ctx->Dump().Display()); + + ctx->Dump().Done(); + return kLiteRtStatusOk; +} + +std::vector ApplyPartition(Context* ctx, LiteRtModelT& model, + CompilerPlugin& plugin) { + ctx->Dump().Start("Partition Model"); + LITERT_RETURN_VAL_IF_NOT_OK( + RegisterCustomOpCode(&model, ctx->Run().soc_manufacturer->data()), {}); + + ctx->Dump().Labeled() << "Input model: \n"; + for (auto it = model.subgraphs.begin(); it < model.subgraphs.end(); ++it) { + ctx->Dump().Labeled(); + ctx->Dump().Indented() << "(input graph) "; + Dump(*it, ctx->Dump().Display()); + } + + auto partiion = plugin.PartitionModel(model); + if (!partiion.HasValue()) { + return {}; + } + auto grouped_partitions = GroupPartitions(partiion.Value()); + if (grouped_partitions.empty()) { + return {}; + } + ctx->Dump().Labeled() << absl::StreamFormat( + "Plugin selected %lu ops, yielding %lu partitions\n", + partiion.Value().size(), grouped_partitions.size()); + + std::vector res; + for (auto& partition : grouped_partitions) { + LiteRtOp custom_op = OutlinePartition( + model.subgraphs.front(), &model.subgraphs.emplace_back(), partition); + res.push_back(custom_op); + } + + ctx->Dump().Labeled() << "Partitioned model: \n"; + ctx->Dump().Labeled(); + ctx->Dump().Indented() << "(initial graph) "; + Dump(model.subgraphs.front(), ctx->Dump().Display()); + for (auto it = model.subgraphs.begin() + 1; it < model.subgraphs.end(); + ++it) { + ctx->Dump().Labeled(); + ctx->Dump().Indented() << "(new graph) "; + Dump(*it, ctx->Dump().Display()); + } + + ctx->Dump().Done(); + return res; +} + +LiteRtResult PartitionModel(Context* ctx, + UniqueLiteRtModel model, + CompilerPlugin& plugin) { + auto custom_ops = ApplyPartition(ctx, *model, plugin); + if (custom_ops.empty()) { + return LiteRtResult::FromStatus( + kLiteRtStatusErrorGraphModification); + } + return LiteRtResult::TakeValue(std::move(model)); +} + +LiteRtResult> CompilePartitions( + Context* ctx, std::vector& partitions, + CompilerPlugin& plugin) { + ctx->Dump().Start("Compile Model"); + ctx->Dump().Labeled() << absl::StreamFormat( + "Requesting compilation for target \"%s\" on %lu subgraphs\n", + ctx->SocModelTarget(), partitions.size()); + + std::vector call_info_out; + if (plugin.Compile(ctx->SocModelTarget(), partitions, ctx->Out(), + call_info_out) != kLiteRtStatusOk) { + ctx->Dump().Fail(); + return LiteRtResult>::FromStatus( + kLiteRtStatusErrorCompilationr); + } + + ctx->Dump().Labeled() << "Entry point info: "; + for (auto it = call_info_out.begin(); it < call_info_out.end(); ++it) { + ctx->Dump().Display() << absl::StreamFormat("\"%s\"", *it); + if (it < call_info_out.end() - 1) { + ctx->Dump().Display() << ", "; + } + } + ctx->Dump().Display() << "\n"; + + ctx->Dump().Done(); + return LiteRtResult>::TakeValue( + std::move(call_info_out)); +} + +// +// INFO Command +// + +LiteRtStatus ValidateInfoRun(const ApplyPluginRun& run) { + _ENSURE_CONFIG(!run.lib_search_paths.empty()); + _ENSURE_CONFIG(run.outs.size() == 1); + return kLiteRtStatusOk; +} + +LiteRtStatus Info(Context* ctx) { + LITERT_MOVE_OR_RETURN_STATUS(auto plugins, LoadAllPlugins(ctx)); + for (auto& plugin : plugins) { + ctx->Out() << absl::StreamFormat("< LiteRtCompilerPlugin > \"%s\" | ", + plugin.SocManufacturer()); + const auto& models = plugin.SocModels(); + for (auto it = models.begin(); it < models.end(); ++it) { + ctx->Out() << absl::StreamFormat("\"%s\"", *it); + if (it < models.end() - 1) { + ctx->Out() << ", "; + } + } + } + return kLiteRtStatusOk; +} + +// +// NOOP Command +// + +LiteRtStatus ValidateNoopRun(const ApplyPluginRun& run) { + _ENSURE_CONFIG(run.model.has_value()); + _ENSURE_CONFIG(run.outs.size() == 1); + return kLiteRtStatusOk; +} + +LiteRtStatus Noop(Context* ctx) { + LITERT_MOVE_OR_RETURN_STATUS(auto model, LoadModel(ctx)); + LITERT_RETURN_STATUS_IF_NOT_OK(SerializeModel(ctx, std::move(model))); + return kLiteRtStatusOk; +} + +// +// PARTITION Command +// + +LiteRtStatus ValidatePartitionRun(const ApplyPluginRun& run) { + _ENSURE_CONFIG(!run.lib_search_paths.empty()); + _ENSURE_CONFIG(run.model.has_value()); + _ENSURE_CONFIG(run.soc_manufacturer.has_value()); + _ENSURE_CONFIG(!run.outs.empty()); + return kLiteRtStatusOk; +} + +LiteRtStatus Partition(Context* ctx) { + LITERT_MOVE_OR_RETURN_STATUS(auto plugin, LoadPlugin(ctx)); + LITERT_MOVE_OR_RETURN_STATUS(auto model, LoadModel(ctx)); + + LITERT_MOVE_OR_RETURN_STATUS(auto new_model, + PartitionModel(ctx, std::move(model), plugin)); + LITERT_RETURN_STATUS_IF_NOT_OK(SerializeModel(ctx, std::move(new_model))); + return kLiteRtStatusOk; +} + +// +// COMPILE Command +// + +LiteRtStatus ValidateCompileRun(const ApplyPluginRun& run) { + _ENSURE_CONFIG(!run.lib_search_paths.empty()); + _ENSURE_CONFIG(run.model.has_value()); + _ENSURE_CONFIG(run.soc_manufacturer.has_value()); + _ENSURE_CONFIG(run.outs.size() == run.soc_models.size()); + // TODO: implement multi target compilation. + LITERT_ENSURE_SUPPORTED(run.soc_models.size() == 1, + "Multi target compilation not implemented."); + // TODO: implement append serialization. + LITERT_ENSURE_SUPPORTED( + run.serialization == ApplyPluginRun::Serialization::METADATA, + "Only metadata serialization currently supported."); + return kLiteRtStatusOk; +} + +LiteRtStatus Compile(Context* ctx) { + LITERT_MOVE_OR_RETURN_STATUS(auto model, LoadModel(ctx)); + LITERT_MOVE_OR_RETURN_STATUS(auto plugin, LoadPlugin(ctx)); + + std::vector compilation_input; + compilation_input.reserve(model->subgraphs.size()); + for (auto& subgraph : model->subgraphs) { + compilation_input.push_back(&subgraph); + } + LITERT_MOVE_OR_RETURN_STATUS( + auto entry_point_info, CompilePartitions(ctx, compilation_input, plugin)); + + return kLiteRtStatusOk; +} + +// +// APPLY Command +// + +LiteRtStatus ValidateApplyRun(const ApplyPluginRun& run) { + _ENSURE_CONFIG(!run.lib_search_paths.empty()); + _ENSURE_CONFIG(run.model.has_value()); + _ENSURE_CONFIG(run.soc_manufacturer.has_value()); + _ENSURE_CONFIG(run.outs.size() == run.soc_models.size()); + // TODO: implement multi target compilation. + LITERT_ENSURE_SUPPORTED(run.soc_models.size() == 1, + "Multi target compilation not implemented."); + // TODO: implement append serialization. + LITERT_ENSURE_SUPPORTED( + run.serialization == ApplyPluginRun::Serialization::METADATA, + "Only metadata serialization currently supported."); + return kLiteRtStatusOk; +} + +LiteRtStatus Apply(Context* ctx) { + LITERT_MOVE_OR_RETURN_STATUS(auto model, LoadModel(ctx)); + LITERT_MOVE_OR_RETURN_STATUS(auto plugin, LoadPlugin(ctx)); + static constexpr size_t kNumInputSubgraphs = 1; + LITERT_ENSURE_SUPPORTED(model->subgraphs.size() == kNumInputSubgraphs, + "Only single subgraph models currently supported."); + + // Query plugin for compilable ops and slice partitions out of the graph, + // replacing use with single custom op.. + auto custom_ops = ApplyPartition(ctx, *model, plugin); + LITERT_ENSURE(!custom_ops.empty(), kLiteRtStatusErrorGraphModification, + "Failed to partiion graph."); + // All new subgraphs to be compiled are appended to the model's subgraphs. + std::vector compilation_input; + for (auto it = model->subgraphs.begin() + kNumInputSubgraphs; + it < model->subgraphs.end(); ++it) { + compilation_input.push_back(&*it); + } + + // Call compilation method on the plugin. + std::stringstream compilation_out; + ApplyPluginRun::OutStreamT out = ctx->SwapOut(compilation_out); + LITERT_MOVE_OR_RETURN_STATUS( + auto call_info, CompilePartitions(ctx, compilation_input, plugin)); + + // Update custom op info the it's respective entry point info from the plugin. + LITERT_ENSURE(call_info.size() == custom_ops.size(), + kLiteRtStatusErrorCompilationr, + "Failed to verify entry point information."); + auto call_it = call_info.begin(); + auto custom_op_it = custom_ops.begin(); + for (; call_it < call_info.end() && custom_op_it < custom_ops.end();) { + (*custom_op_it)->custom_options = + OwningBufferRef(*call_it->c_str()); + ++call_it; + ++custom_op_it; + } + + model->subgraphs.resize(kNumInputSubgraphs); + + LITERT_RETURN_STATUS_IF_NOT_OK(LiteRtModelAddByteCodeMetadata( + model.get(), plugin.SocManufacturer().data(), + plugin.SocModels().front().data(), compilation_out.str().data(), + compilation_out.str().size())); + + ctx->SwapOut(out); + LITERT_RETURN_STATUS_IF_NOT_OK(SerializeModel(ctx, std::move(model))); + + return kLiteRtStatusOk; +} + +} // namespace + +LiteRtStatus ApplyPlugin(ApplyPluginRun::Ptr run) { + Context context(std::move(run)); + context.DumpPrelude(); + + switch (context.Cmd()) { + case ApplyPluginRun::Cmd::INFO: + LITERT_RETURN_STATUS_IF_NOT_OK(ValidateInfoRun(context.Run())); + return Info(&context); + + case ApplyPluginRun::Cmd::PARTITION: + LITERT_RETURN_STATUS_IF_NOT_OK(ValidatePartitionRun(context.Run())); + return Partition(&context); + + case ApplyPluginRun::Cmd::COMPILE: + LITERT_RETURN_STATUS_IF_NOT_OK(ValidateCompileRun(context.Run())); + return Compile(&context); + + case ApplyPluginRun::Cmd::APPLY: + LITERT_RETURN_STATUS_IF_NOT_OK(ValidateApplyRun(context.Run())); + return Apply(&context); + + case ApplyPluginRun::Cmd::NOOP: + LITERT_RETURN_STATUS_IF_NOT_OK(ValidateNoopRun(context.Run())); + return Noop(&context); + + default: + return kLiteRtStatusErrorInvalidArgument; + } + + return kLiteRtStatusOk; +} + +} // namespace litert::tools diff --git a/tensorflow/lite/experimental/litert/tools/apply_plugin.h b/tensorflow/lite/experimental/litert/tools/apply_plugin.h new file mode 100644 index 00000000000000..a336bcb3959cf2 --- /dev/null +++ b/tensorflow/lite/experimental/litert/tools/apply_plugin.h @@ -0,0 +1,182 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_APPLY_PLUGIN_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_APPLY_PLUGIN_H_ + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" + +namespace litert::tools { + +struct ApplyPluginRun { + // NOTE: All StrFlagT are expected to have static storage duration. + using StrFlagT = absl::string_view; + using StrFlagListT = std::vector; + using OptStrFlagT = std::optional; + using OutStreamT = std::reference_wrapper; + using OutStreamtListT = std::vector; + using OptOutStreamT = std::optional; + using Ptr = std::unique_ptr; + using ShrPtr = std::shared_ptr; + + // A specific command implemented by the tool to run. + enum class Cmd { + // Displays info about all plugins found in given search paths. + // + // FLAG SEMANTICS: + // "lib_search_paths": Required, at least one. + // "model": Ignored. + // "soc_manufacturer": Optional, filters plugins to display. + // "soc_models": Ignored. + // "outs": Required, must be size one. + // "dump_out": Optional. + // "serialization": Ignored. + INFO, + + // Does nothing and simply de-serializes and re-serializes the given model. + // This is intended for testing and internal debugging only. + // + // FLAG SEMANTICS: + // "lib_search_paths": Ignored. + // "model": Required. + // "soc_manufacturer": Ignored. + // "soc_models": Ignored. + // "outs": Required, must be size one. + // "dump_out": Optional. + // "serialization": Ignored. + NOOP, + + // Runs the entire end to end flow. This is the standard compiler plugin + // usage. A seperate compilation step will occur for each sco_model tag that + // is supported by the loaded plugin, and a new output model will be + // generated for each. Partitioning is invariant accross different soc_model + // targets from the same manufacturer, so only one compilation step will + // occur even if multiple targest are requested. + // + // FLAG SEMANTICS: + // "lib_search_paths": Required, at least one. + // "model": Required. + // "soc_manufacturer": Required. + // "soc_models": Required, at least one. + // "outs": Required, must be size equal to "soc_models". + // "dump_out": Optional. + // "serialization": Required. + // + // TODO: Support multi target compilation. + APPLY, + + // Only run the partiion step and skip compilation. Writes a ".tflite" model + // to "out" where selected partitions are manifested as new standard + // flatbuffer subgraphs added to the input model. + // The partitions original locations are replaced with a single custom op + // the contains an identifier to the corresponding partition (new subgraph). + // This is intended for testing and development. + // + // FLAG SEMANTICS: + // "lib_search_paths": Required, at least one. + // "model": Required. + // "soc_manufacturer": Required. + // "soc_models": Ignored. + // "outs": Required, must be size one. + // "dump_out": Optional. + // "serialization": Ignored. + PARTITION, + + // Skip partitioning and run the entire input model through compilation + // directly. Fails if any ops in the input model are unsupported by the + // plugin. Writes the raw compiled result to the "out" stream without any + // wrapping flatbuffer. Runs multi-target compilation as in "APPLY", + // Intended for testing and development. + // + // FLAG SEMANTICS: + // "lib_search_paths": Required, at least one. + // "model": Required. + // "soc_manufacturer": Required. + // "soc_models": Required, at least one. + // "out": Required, must be size equal to "soc_models". + // "dump_out": Optional. + // "serialization": Required. + // + // TODO: Support multi target compilation. + COMPILE, + }; + + // A command to run, see above. + Cmd cmd; + + // Collection of paths on local files system dictating where the tool should + // look for suitable LiteRtCompilerPlugin shared libraries. The tool will + // select the first ".so" file found with prefix "libLiteRtPlugin" that has + // the "soc_manufacturer" tag passed. Providing more than one plugin shared + // library for the same manufacturer results in an error. + StrFlagListT lib_search_paths = {}; + + // Path to ".tflite" model the tool should operated on. + OptStrFlagT model = {}; + + // A tag representing a manufacturer the tool should target for compilation. + // This is used to select the appropriate plugin if multiple plugins are found + // in "lib_search_paths". + OptStrFlagT soc_manufacturer = {}; + + // Collection of soc models tags the tool should target for compilation. + StrFlagListT soc_models = {}; + + // Where the tool should write its result file(s) to. If the command runs + // compilation, an "out" stream should be passed for each "soc_model" target + // requested for compilation. Output for the "ith" target will be written to + // the "ith" outs stream. + OutStreamtListT outs = {std::cout}; + + // Where to direct logging for this run. Passing nullopt here indicates + // "silent" behavior and should only be used when this tool is part of a + // larger pipeline like an end2end test. + OptOutStreamT dump_out = std::cerr; + + // Dictates how the final model with compiled assets should be serialized. + // Only relevant to runs with a compilation step. + enum class Serialization { + // Write the compiled module into a metadata buffer using the + // soc_manufacturer as a key. This is for testing and debugging as it allows + // the contents of the byte code to be rendered by exisitng flatbuffer + // tooling. Custom op options will contain only a string identifying the + // respective entry point. + METADATA, + + // Appends the compiled byte code to the end of the ".tflite" file. Custom + // op options will contain both an entry point string and an offset into the + // file where the byte code starts. Options will be a string of the form + // "\"\"". byte_offset is a size_t offset + // where the compiled module starts in the file. Currently only single + // shared byte code modules are supported and so all ops will have the same + // offset. + // TODO: Implement. + APPEND, + }; + + // Serialization strategy to use, see above. + Serialization serialization = Serialization::METADATA; +}; + +LiteRtStatus ApplyPlugin(ApplyPluginRun::Ptr run); + +} // namespace litert::tools + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_APPLY_PLUGIN_H_ diff --git a/tensorflow/lite/experimental/litert/tools/apply_plugin_main.cc b/tensorflow/lite/experimental/litert/tools/apply_plugin_main.cc new file mode 100644 index 00000000000000..23ac8936f087c0 --- /dev/null +++ b/tensorflow/lite/experimental/litert/tools/apply_plugin_main.cc @@ -0,0 +1,127 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/CommandLine.h" +#include "tensorflow/lite/experimental/litert/tools/apply_plugin.h" + +using ::litert::tools::ApplyPlugin; +using ::litert::tools::ApplyPluginRun; + +// NOLINTNEXTLINE +static llvm::cl::opt cmd( + llvm::cl::Positional, + llvm::cl::desc("Routine to run (apply, partition, compile, info, noop)."), + llvm::cl::init("partition")); + +// NOLINTNEXTLINE +static llvm::cl::opt model( + "model", llvm::cl::desc("Path to flatbuffer file."), llvm::cl::init("")); + +// TODO: b/366821557 - Support path to pre-compiled plugin in flags. +// NOLINTNEXTLINE +static llvm::cl::opt soc_manufacturer( + "soc_man", + llvm::cl::desc("String identifier of SoC manufacturer (e.g., GoogleTensor, " + "Qualcomm)."), + llvm::cl::init("ExampleSocManufacturer")); + +// TODO: Support multi target compilation. +// NOLINTNEXTLINE +static llvm::cl::opt soc_model("soc_model", + llvm::cl::desc("Target SoC model."), + llvm::cl::init("ExampleSocModel")); + +// NOLINTNEXTLINE +static llvm::cl::list libs( + "libs", + llvm::cl::desc("List of directories in which to search for suitable " + "compiler plugin shared libraries."), + llvm::cl::list_init(llvm::ArrayRef{ + "third_party/tensorflow/lite/experimental/litert/vendors/examples"})); + +// NOLINTNEXTLINE +static llvm::cl::opt out( + "o", + llvm::cl::desc("Path to file for output, \"-\" indicates standard out."), + llvm::cl::init("-")); + +// NOLINTNEXTLINE +static llvm::cl::opt err( + "err", + llvm::cl::desc("Path to file for error output, \"-\" indicates stdandard " + "error and \"none\" indicates silent."), + llvm::cl::init("-")); + +// NOLINTNEXTLINE +static llvm::cl::opt serialization( + "serialization", llvm::cl::desc("Serialization strategy to use."), + llvm::cl::init("METADATA")); + +ApplyPluginRun::Ptr ParseFlags() { + auto res = std::make_unique(); + + std::ofstream file_out; + if (out != "-") { + file_out.open(out); + res->outs.clear(); + res->outs.push_back(file_out); + } + + std::ofstream file_err; + if (err != "-") { + file_err.open(err); + res->dump_out.emplace(file_err); + } + + if (!model.empty()) { + res->model = model; + } + + res->soc_manufacturer = soc_manufacturer; + res->soc_models.push_back(soc_model); + + res->lib_search_paths.assign(libs.begin(), libs.end()); + + if (cmd == "apply") { + res->cmd = ApplyPluginRun::Cmd::APPLY; + } else if (cmd == "partition") { + res->cmd = ApplyPluginRun::Cmd::PARTITION; + } else if (cmd == "compile") { + res->cmd = ApplyPluginRun::Cmd::COMPILE; + } else if (cmd == "info") { + res->cmd = ApplyPluginRun::Cmd::INFO; + } else if (cmd == "noop") { + res->cmd = ApplyPluginRun::Cmd::NOOP; + } + + return res; +} + +int main(int argc, char* argv[]) { + llvm::cl::ParseCommandLineOptions(argc, argv); + + auto run = ParseFlags(); + if (run == nullptr) { + return 1; + } + + return ApplyPlugin(std::move(run)); +} diff --git a/tensorflow/lite/experimental/litert/tools/apply_plugin_test.cc b/tensorflow/lite/experimental/litert/tools/apply_plugin_test.cc new file mode 100644 index 00000000000000..52dc9836f103dd --- /dev/null +++ b/tensorflow/lite/experimental/litert/tools/apply_plugin_test.cc @@ -0,0 +1,174 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/tools/apply_plugin.h" + +#include +#include +#include +#include + +#include +#include +#include "absl/log/absl_check.h" +#include "absl/strings/string_view.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/core/graph_tools.h" +#include "tensorflow/lite/experimental/litert/core/litert_model_init.h" +#include "tensorflow/lite/experimental/litert/core/litert_model_serialize.h" +#include "tensorflow/lite/experimental/litert/core/model.h" +#include "tensorflow/lite/experimental/litert/test/common.h" + +namespace { + +using ::graph_tools::GetMetadata; +using ::litert::tools::ApplyPlugin; +using ::litert::tools::ApplyPluginRun; +using ::testing::HasSubstr; + +static constexpr absl::string_view kPluginSearchPath = + "third_party/tensorflow/lite/experimental/litert/vendors/examples"; + +static constexpr absl::string_view kSocManufacturer = "ExampleSocManufacturer"; + +static constexpr absl::string_view kSocModel = "ExampleSocModel"; + +absl::string_view TestModelPath() { + static char kModelPath[512] = {}; + if (kModelPath[0] == '\0') { + const auto model_path = + ::litert::testing::GetTestFilePath("one_mul.tflite"); + ABSL_CHECK(model_path.size() < 512); + model_path.copy(kModelPath, model_path.size(), 0); + } + return kModelPath; +} + +ApplyPluginRun::Ptr MakeBaseRun(ApplyPluginRun::Cmd cmd) { + auto run = std::make_unique(); + run->cmd = cmd; + run->lib_search_paths.push_back(kPluginSearchPath); + run->model.emplace(TestModelPath()); + run->soc_manufacturer.emplace(kSocManufacturer); + run->soc_models.push_back(kSocModel); + run->outs.clear(); + return run; +} + +TEST(TestApplyPluginTool, TestInfoBadConfig) { + auto run = MakeBaseRun(ApplyPluginRun::Cmd::INFO); + run->dump_out = {}; + run->lib_search_paths.clear(); + ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), + kLiteRtStatusErrorInvalidToolConfig); +} + +TEST(TestApplyPluginTool, TestInfo) { + auto run = MakeBaseRun(ApplyPluginRun::Cmd::INFO); + std::stringstream out; + run->outs.push_back(out); + ASSERT_STATUS_OK(ApplyPlugin(std::move(run))); + EXPECT_THAT(out.str(), + ::testing::HasSubstr( + "< LiteRtCompilerPlugin > \"ExampleSocManufacturer\" | " + "\"ExampleSocModel\"")); +} + +TEST(TestApplyPluginTool, TestNoopBadConfig) { + auto run = MakeBaseRun(ApplyPluginRun::Cmd::NOOP); + run->model.reset(); + ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), + kLiteRtStatusErrorInvalidToolConfig); +} + +TEST(TestApplyPluginTool, TestNoop) { + auto run = MakeBaseRun(ApplyPluginRun::Cmd::NOOP); + std::stringstream out; + run->outs.push_back(out); + ASSERT_STATUS_OK(ApplyPlugin(std::move(run))); + + LiteRtModel model; + ASSERT_STATUS_OK( + LoadModel(reinterpret_cast(out.view().data()), + out.view().size(), &model)); + UniqueLiteRtModel u_model(model); + + EXPECT_EQ(model->subgraphs.size(), 1); +} + +TEST(TestApplyPluginTool, TestPartitionBadConfig) { + auto run = MakeBaseRun(ApplyPluginRun::Cmd::PARTITION); + run->model.reset(); + ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), + kLiteRtStatusErrorInvalidToolConfig); +} + +TEST(TestApplyPluginTool, TestPartition) { + auto run = MakeBaseRun(ApplyPluginRun::Cmd::PARTITION); + std::stringstream out; + run->outs.push_back(out); + ASSERT_STATUS_OK(ApplyPlugin(std::move(run))); + EXPECT_FALSE(out.str().empty()); +} + +TEST(TestApplyPluginTool, TestCompileBadConfig) { + auto run = MakeBaseRun(ApplyPluginRun::Cmd::COMPILE); + run->model.reset(); + ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), + kLiteRtStatusErrorInvalidToolConfig); +} + +TEST(TestApplyPluginTool, TestCompile) { + auto run = MakeBaseRun(ApplyPluginRun::Cmd::COMPILE); + std::stringstream out; + run->outs.push_back(out); + ASSERT_STATUS_OK(ApplyPlugin(std::move(run))); + EXPECT_FALSE(out.str().empty()); + EXPECT_THAT(out.str(), HasSubstr("Partition_0_with_1_muls")); +} + +TEST(TestApplyPluginTool, TestApplyBadConfig) { + auto run = MakeBaseRun(ApplyPluginRun::Cmd::APPLY); + run->model.reset(); + ASSERT_STATUS_HAS_CODE(ApplyPlugin(std::move(run)), + kLiteRtStatusErrorInvalidToolConfig); +} + +TEST(TestApplyPluginTool, TestApply) { + auto run = MakeBaseRun(ApplyPluginRun::Cmd::APPLY); + std::stringstream out; + run->outs.push_back(out); + ASSERT_STATUS_OK(ApplyPlugin(std::move(run))); + + LiteRtModel model; + ASSERT_STATUS_OK( + LoadModel(reinterpret_cast(out.view().data()), + out.view().size(), &model)); + UniqueLiteRtModel u_model(model); + EXPECT_EQ(model->subgraphs.size(), 1); + + ASSERT_RESULT_OK_ASSIGN(auto byte_code_buffer, + GetMetadata(model, kLiteRtMetadataByteCodeKey)); + EXPECT_THAT(byte_code_buffer.StrView(), HasSubstr("Partition_0_with_1_muls")); + + ASSERT_RESULT_OK_ASSIGN(auto tag_buffer, + GetMetadata(model, kLiteRtBuildTagKey)); + EXPECT_EQ(tag_buffer.StrView(), + "soc_man:ExampleSocManufacturer,soc_model:ExampleSocModel," + "serialization_strategy:" + "METADATA"); +} + +} // namespace diff --git a/tensorflow/lite/experimental/litert/tools/dump.cc b/tensorflow/lite/experimental/litert/tools/dump.cc new file mode 100644 index 00000000000000..1df67c61525f58 --- /dev/null +++ b/tensorflow/lite/experimental/litert/tools/dump.cc @@ -0,0 +1,351 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/tools/dump.h" + +#include + +#ifndef __ANDROID__ +#include +#endif + +#include +#include +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/core/compiler_plugin/compiler_plugin.h" +#include "tensorflow/lite/experimental/litert/core/model.h" + +namespace litert::internal { + +namespace { + +void DumpNode(const LiteRtTensorT& tensor, std::ostream& out) { + switch (tensor.type_id) { + case kLiteRtRankedTensorType: + Dump(tensor.type_detail.ranked_tensor_type, out); + break; + case kLiteRtUnrankedTensorType: + Dump(tensor.type_detail.unranked_tensor_type.element_type, out); + break; + default: + out << "UKNOWN_TENSOR_TYPE" << tensor.type_id; + } +} + +void DumpNode(const LiteRtOpT& op, std::ostream& out) { Dump(op.op_code, out); } + +void DumpSignature(const std::vector& ins, + const std::vector& outs, std::ostream& out) { + out << "("; + for (auto it = ins.begin(); it < ins.end(); ++it) { + DumpNode(**it, out); + if (it != ins.end() - 1) { + out << ", "; + } + } + out << ")"; + + out << " -> "; + const bool paren_outs = outs.size() != 1; + if (paren_outs) { + out << "("; + } + for (auto it = outs.begin(); it < outs.end(); ++it) { + DumpNode(**it, out); + if (it != outs.end() - 1) { + out << ", "; + } + } + if (paren_outs) { + out << ")"; + } +} + +} // namespace + +void Dump(LiteRtOpCode code, std::ostream& out) { + switch (code) { + case kLiteRtOpCodeTflAdd: + out << "TFL_ADD"; + break; + case kLiteRtOpCodeTflMul: + out << "TFL_MUL"; + break; + case kLiteRtOpCodeTflCustom: + out << "TFL_CUSTOM_OP"; + break; + case kLiteRtOpCodeTflSlice: + out << "TFL_SLICE"; + break; + case kLiteRtOpCodeTflDiv: + out << "TFL_DIV"; + break; + case kLiteRtOpCodeTflRsqrt: + out << "TFL_RSQRT"; + break; + case kLiteRtOpCodeTflTanh: + out << "TFL_TANH"; + break; + case kLiteRtOpCodeTflSub: + out << "TFL_SUB"; + break; + case kLiteRtOpCodeTflReshape: + out << "TFL_RESHAPE"; + break; + case kLiteRtOpCodeTflBatchMatmul: + out << "TFL_BATCH_MATMUL"; + break; + default: + out << "UKNOWN_OP_CODE: " << code; + break; + } +}; + +// Dump details about the given LiteRtElementType to the given stream. +void Dump(LiteRtElementType type, std::ostream& out) { + switch (type) { + case kLiteRtElementTypeFloat32: + out << "f32"; + break; + case kLiteRtElementTypeInt32: + out << "i32"; + break; + case kLiteRtElementTypeFloat64: + out << "f64"; + break; + case kLiteRtElementTypeInt64: + out << "i64"; + break; + case kLiteRtElementTypeFloat16: + out << "f16"; + break; + case kLiteRtElementTypeInt16: + out << "i16"; + break; + case kLiteRtElementTypeInt8: + out << "i8"; + break; + case kLiteRtElementTypeUInt8: + out << "ui8"; + break; + case kLiteRtElementTypeBool: + out << "i1"; + break; + default: + out << "UKNNOWN_ELEMENT_TYPE: " << type; + } +} + +void Dump(const LiteRtRankedTensorType& type, std::ostream& out) { + out << "<"; + for (int i = 0; i < type.layout.rank; ++i) { + out << type.layout.dimensions[i] << "x"; + } + Dump(type.element_type, out); + out << ">"; +} + +void Dump(const LiteRtTensorT& tensor, std::ostream& out) { + out << "LiteRtTensor : "; + DumpNode(tensor, out); + out << " [ "; + if (tensor.defining_op == nullptr) { + out << "*"; + } else { + DumpNode(*tensor.defining_op, out); + } + out << " ] "; + + out << "("; + for (auto it = tensor.users.begin(); it < tensor.users.end(); ++it) { + DumpNode(**it, out); + if (it != tensor.users.end() - 1) { + out << ", "; + } + } + out << ")"; + out << "\n"; +} + +void Dump(const LiteRtOpT& op, std::ostream& out) { + out << "LiteRtOp : [ "; + DumpNode(op, out); + out << " ] "; + DumpSignature(op.inputs, op.outputs, out); + out << "\n"; +} + +void Dump(const LiteRtSubgraphT& subgraph, std::ostream& out) { + constexpr absl::string_view kSubgraphTpl = + "LiteRtSubgraph : [ #ops=%d #tensors=%d ] "; + out << absl::StreamFormat(kSubgraphTpl, subgraph.ops.size(), + subgraph.tensors.size()); + DumpSignature(subgraph.inputs, subgraph.outputs, out); + out << "\n"; +} + +void Dump(const CompilerPlugin& plugin, std::ostream& out) { + constexpr absl::string_view kPluginDumpTpl = + "SocManufacturer: %s\nSocModels: { "; + out << absl::StreamFormat(kPluginDumpTpl, plugin.SocManufacturer()); + + for (auto it = plugin.SocModels().begin(); it < plugin.SocModels().end(); + ++it) { + out << *it; + if (it != plugin.SocModels().end() - 1) { + out << ","; + } + out << " "; + } + + out << "}\n"; +} + +void Dump(void* lib_handle, std::ostream& out) { +#ifndef __ANDROID__ + out << "\n--- Lib Info ---\n"; + if (lib_handle == nullptr) { + out << "Handle is nullptr\n"; + return; + } + + Lmid_t dl_ns_idx; + if (0 != ::dlinfo(lib_handle, RTLD_DI_LMID, &dl_ns_idx)) { + return; + } + + std::string dl_origin; + dl_origin.resize(512); + if (0 != ::dlinfo(lib_handle, RTLD_DI_ORIGIN, dl_origin.data())) { + return; + } + + link_map* lm; + if (0 != ::dlinfo(lib_handle, RTLD_DI_LINKMAP, &lm)) { + return; + } + + out << "Lib Namespace: " << dl_ns_idx << "\n"; + out << "Lib Origin: " << dl_origin << "\n"; + + out << "loaded objects:\n"; + + auto* forward = lm->l_next; + auto* backward = lm->l_prev; + + while (forward != nullptr) { + out << " " << forward->l_name << "\n"; + forward = forward->l_next; + } + + out << "***" << lm->l_name << "\n"; + + while (backward != nullptr) { + out << " " << backward->l_name << "\n"; + backward = backward->l_prev; + } + + out << "\n"; +#endif +} + +void Dump(const LiteRtModelT& model, std::ostream& out) { + out << absl::StreamFormat("LiteRtModel : [ #subgraphs=%d ]\n", + model.subgraphs.size()); +} + +void DumpOptions(const LiteRtOpT& op, std::ostream& out) { + switch (op.op_code) { + case kLiteRtOpCodeTflAdd: + out << "fused_activation_function: " + << op.option.AsAddOptions()->fused_activation_function << "\n"; + break; + case kLiteRtOpCodeTflMul: + out << "fused_activation_function: " + << op.option.AsMulOptions()->fused_activation_function << "\n"; + break; + case kLiteRtOpCodeTflBatchMatmul: + out << "adj_x: " << op.option.AsBatchMatMulOptions()->adj_x << "\n"; + out << "adj_y: " << op.option.AsBatchMatMulOptions()->adj_y << "\n"; + out << "asymmetric_quantize_input: " + << op.option.AsBatchMatMulOptions()->asymmetric_quantize_inputs + << "\n"; + break; + case kLiteRtOpCodeTflConcatenation: + out << "fused_activation_function: " + << op.option.AsConcatenationOptions()->fused_activation_function + << "\n"; + out << "axis: " << op.option.AsConcatenationOptions()->axis << "\n"; + break; + case kLiteRtOpCodeTflDiv: + out << "fused_activation_function: " + << op.option.AsDivOptions()->fused_activation_function << "\n"; + break; + case kLiteRtOpCodeTflFullyConnected: + out << "fused_activation_function: " + << op.option.AsFullyConnectedOptions()->fused_activation_function + << "\n"; + out << "weights_format: " + << op.option.AsFullyConnectedOptions()->weights_format << "\n"; + out << "keep_num_dims: " + << op.option.AsFullyConnectedOptions()->keep_num_dims << "\n"; + out << "quantized_bias_type: " + << op.option.AsFullyConnectedOptions()->quantized_bias_type << "\n"; + out << "asymmetric_quantize_input: " + << op.option.AsFullyConnectedOptions()->asymmetric_quantize_inputs + << "\n"; + break; + case kLiteRtOpCodeTflSoftmax: + out << "beta: " << op.option.AsSoftmaxOptions()->beta << "\n"; + break; + case kLiteRtOpCodeTflStridedSlice: + out << "begin_mask: " << op.option.AsStridedSliceOptions()->begin_mask + << "\n"; + out << "end_mask: " << op.option.AsStridedSliceOptions()->end_mask + << "\n"; + out << "ellipsis_mask: " + << op.option.AsStridedSliceOptions()->ellipsis_mask << "\n"; + out << "new_axis_mask: " + << op.option.AsStridedSliceOptions()->new_axis_mask << "\n"; + out << "shrink_axis_mask: " + << op.option.AsStridedSliceOptions()->shrink_axis_mask << "\n"; + out << "offset: " << op.option.AsStridedSliceOptions()->offset << "\n"; + break; + case kLiteRtOpCodeTflSub: + out << "fused_activation_function: " + << op.option.AsSubOptions()->fused_activation_function << "\n"; + break; + case kLiteRtOpCodeTflReshape: + out << "new_shape: "; + if (op.option.AsReshapeOptions() != nullptr) { + const int32_t* new_shape = + op.option.AsReshapeOptions()->new_shape.data(); + int32_t new_shape_size = op.option.AsReshapeOptions()->new_shape.size(); + for (int i = 0; i < new_shape_size; ++i) { + out << new_shape[i] << " "; + } + } + break; + default: + out << "No options for op code: " << op.op_code; + break; + } +} +} // namespace litert::internal diff --git a/tensorflow/lite/experimental/litert/tools/dump.h b/tensorflow/lite/experimental/litert/tools/dump.h new file mode 100644 index 00000000000000..29b0de4f41bdc8 --- /dev/null +++ b/tensorflow/lite/experimental/litert/tools/dump.h @@ -0,0 +1,68 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_DUMP_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_DUMP_H_ + +#include +#include +#include + +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/core/compiler_plugin/compiler_plugin.h" +#include "tensorflow/lite/experimental/litert/core/model.h" + +namespace litert::internal { + +// +// LiteRt IR +// + +// Dump details about the given LiteRtOpT to the given stream. +void Dump(const LiteRtOpT& op, std::ostream& out = std::cerr); + +// Dump details about the given LiteRtSubgraphT to the given stream. +void Dump(const LiteRtSubgraphT& subgraph, std::ostream& out = std::cerr); + +// Dump details about the given LiteRtTensorT to the given stream. +void Dump(const LiteRtTensorT& tensor, std::ostream& out = std::cerr); + +// Dump details about the given LiteRtOpCode to the given stream. +void Dump(LiteRtOpCode code, std::ostream& out = std::cerr); + +// Dump details about the given LiteRtElementType to the given stream. +void Dump(LiteRtElementType type, std::ostream& out = std::cerr); + +// Dump details about the given LiteRtRankedTensorType to the given stream. +void Dump(const LiteRtRankedTensorType& type, std::ostream& out = std::cerr); + +// Dump details about the given LiteRtModel to the given stream. +void Dump(const LiteRtModelT& model, std::ostream& out = std::cerr); + +// Dump details about options +void DumpOptions(const LiteRtOpT& op, std::ostream& out = std::cerr); + +// +// Library Utilities +// + +// Dumps details about the loaded LiteRtCompilerPlugin library. +void Dump(const CompilerPlugin& plugin, std::ostream& out = std::cerr); + +// Dumps details about the dynamic library (see "dlinfo"). +void Dump(void* lib_handle, std::ostream& out = std::cerr); + +} // namespace litert::internal + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_DUMP_H_ diff --git a/tensorflow/lite/experimental/litert/tools/dump_test.cc b/tensorflow/lite/experimental/litert/tools/dump_test.cc new file mode 100644 index 00000000000000..7de40992f8d1ec --- /dev/null +++ b/tensorflow/lite/experimental/litert/tools/dump_test.cc @@ -0,0 +1,88 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/tools/dump.h" + +#include + +#include +#include "absl/strings/string_view.h" +#include "tensorflow/lite/experimental/litert/core/model.h" +#include "tensorflow/lite/experimental/litert/test/common.h" + +namespace { + +using ::litert::internal::Dump; +using ::litert::internal::DumpOptions; +using ::litert::testing::LoadTestFileModel; + +TEST(DumpTest, TestDump) { + auto model = LoadTestFileModel("one_mul.tflite"); + + { + std::ostringstream model_dump; + Dump(*model, model_dump); + EXPECT_EQ(model_dump.view(), "LiteRtModel : [ #subgraphs=1 ]\n"); + } + + { + const LiteRtTensorT& in_tensor = *model->subgraphs.front().inputs.front(); + std::ostringstream in_tensor_dump; + Dump(in_tensor, in_tensor_dump); + EXPECT_EQ(in_tensor_dump.view(), + "LiteRtTensor : <2x2xf32> [ * ] (TFL_MUL)\n"); + } + + { + const LiteRtTensorT& out_tensor = *model->subgraphs.front().outputs.front(); + std::ostringstream out_tensor_dump; + Dump(out_tensor, out_tensor_dump); + EXPECT_EQ(out_tensor_dump.view(), + "LiteRtTensor : <2x2xf32> [ TFL_MUL ] ()\n"); + } + + { + const LiteRtOpT& op = *model->subgraphs.front().ops.front(); + std::ostringstream op_dump; + Dump(op, op_dump); + EXPECT_EQ(op_dump.view(), + "LiteRtOp : [ TFL_MUL ] (<2x2xf32>, <2x2xf32>) -> <2x2xf32>\n"); + } + + { + const LiteRtSubgraphT& subgraph = model->subgraphs.front(); + std::ostringstream subgraph_dump; + Dump(subgraph, subgraph_dump); + EXPECT_EQ( + subgraph_dump.view(), + "LiteRtSubgraph : [ #ops=1 #tensors=3 ] (<2x2xf32>, <2x2xf32>) -> " + "<2x2xf32>\n"); + } +} + +TEST(DumpTest, TestDumpOptions) { + auto model = LoadTestFileModel("simple_strided_slice_op.tflite"); + const LiteRtOpT& op = *model->subgraphs.front().ops.front(); + std::ostringstream op_dump; + DumpOptions(op, op_dump); + EXPECT_EQ(op_dump.view(), + "begin_mask: 0\n" + "end_mask: 0\n" + "ellipsis_mask: 0\n" + "new_axis_mask: 0\n" + "shrink_axis_mask: 0\n" + "offset: 0\n"); +} + +} // namespace diff --git a/tensorflow/lite/experimental/litert/tools/temp.txt b/tensorflow/lite/experimental/litert/tools/temp.txt new file mode 100644 index 00000000000000..1dd6e74af40a90 --- /dev/null +++ b/tensorflow/lite/experimental/litert/tools/temp.txt @@ -0,0 +1,237 @@ +// std::pair GetModelAndPlugin() { +// std::vector plugins; +// LRT_CHECK_STATUS_OK(PluginManager::LoadPlugins({kPluginSearchPath}, +// plugins)); ABSL_CHECK_EQ(plugins.size(), 1); return +// {LoadTestFileModel(kModel), std::move(plugins.front())}; +// } + +// TEST(PluginToolTest, SerializeRoundTrip) { +// auto test_data = GetModelAndPlugin(); +// { +// ASSERT_EQ(test_data.first->subgraphs.size(), 1); +// const LiteRtSubgraphT& subgraph = test_data.first->subgraphs.front(); +// EXPECT_EQ(subgraph.inputs.size(), 2); +// EXPECT_EQ(subgraph.outputs.size(), 1); +// ASSERT_EQ(subgraph.ops.size(), 1); +// EXPECT_EQ(subgraph.ops.front()->op_code, kLiteRtOpCodeTflMul); +// } + +// PluginTool tool(std::move(test_data.first), std::move(test_data.second)); + +// std::stringstream serialized; +// ASSERT_STATUS_OK(tool.Serialize(serialized)); + +// LiteRtModel model; +// ASSERT_STATUS_OK( +// LoadModel(reinterpret_cast(serialized.str().data()), +// serialized.str().size(), &model)); +// UniqueLiteRtModel umodel(model); + +// { +// ASSERT_EQ(model->subgraphs.size(), 1); +// const LiteRtSubgraphT& subgraph = model->subgraphs.front(); +// EXPECT_EQ(subgraph.inputs.size(), 2); +// EXPECT_EQ(subgraph.outputs.size(), 1); +// ASSERT_EQ(subgraph.ops.size(), 1); +// EXPECT_EQ(subgraph.ops.front()->op_code, kLiteRtOpCodeTflMul); +// } +// } + +// TEST(PluginToolTest, DumpCompilationStats) { +// auto test_data = GetModelAndPlugin(); +// PluginTool tool(std::move(test_data.first), std::move(test_data.second)); + +// std::ostringstream dump_out; +// tool.DumpCompilationStats(dump_out); +// EXPECT_EQ(dump_out.view(), "LiteRtCompiledResult : +// \n"); +// } + +// TEST(PluginToolTest, TestPartition) { +// auto test_data = GetModelAndPlugin(); +// PluginTool tool(std::move(test_data.first), std::move(test_data.second)); +// ASSERT_STATUS_OK(tool.Partiion()); +// ASSERT_EQ(tool.Model().subgraphs.size(), 2); +// ASSERT_EQ(tool.MainSubgraph().ops.size(), 1); +// ASSERT_EQ(tool.Partitions().size(), 1); +// ASSERT_EQ(tool.Partitions().at(0).ops.size(), 1); +// } + +// TEST(PluginToolTest, TestDumpPartitionDetails) { +// auto test_data = GetModelAndPlugin(); +// PluginTool tool(std::move(test_data.first), std::move(test_data.second)); +// ASSERT_STATUS_OK(tool.Partiion()); +// std::ostringstream dump_out; +// tool.DumpPartitionDetails(dump_out); +// EXPECT_TRUE( +// absl::StrContains(dump_out.view(), +// "(main subgraph) LiteRtSubgraph : [ #ops=1 #tensors=3 ] +// " +// "(<2x2xf32>, <2x2xf32>) -> <2x2xf32>")); +// EXPECT_TRUE(absl::StrContains(dump_out.view(), +// "(partition) LiteRtSubgraph : [ #ops=1 +// #tensors=3 " +// "] (<2x2xf32>, <2x2xf32>) -> <2x2xf32>")); +// } + +// // Utility for applying various functions from given compiler +// // plugin to the given model. Writes details about the process to "dump". +// class PluginTool { +// public: +// // Perform the partition step. Plugin selects ops which are sliced from +// // the original graph. +// LiteRtStatus Partiion(); + +// // Perform the compilation step for "soc_model" provided. Writes +// // a new flatbuffer with embedded compiled module and custom ops to +// // the given stream. +// // NOTE: Currently this invalidates the underlying input model so it +// // cannot be called more than once. +// // TODO: Implement model copy to support compiling for multiple soc_models +// // in one run. +// LiteRtStatus Compile(const absl::string_view soc_model); + +// PluginTool(UniqueLiteRtModel model, internal::PluginManager plugin, +// std::ostream& dump = std::cerr) +// : model_(std::move(model)), plugin_(std::move(plugin)), dump_(dump) {} + +// PluginTool(const PluginTool&) = delete; +// PluginTool& operator=(const PluginTool&) = delete; +// PluginTool(PluginTool&&) = delete; +// PluginTool& operator=(PluginTool&&) = delete; + +// private: +// const LiteRtModelT& Model() const { return *model_; } +// LiteRtModelT& Model() { return *model_; } + +// const LiteRtSubgraphT& MainSubgraph() const { return +// Model().subgraphs.front(); } LiteRtSubgraphT& MainSubgraph() { return +// Model().subgraphs.front(); } + +// const absl::Span Partitions() const; + +// std::ostream& Dump() { return dump_; } +// std::ostream& dump_; + +// void DumpPartitionDetails() const; +// void DumpCompilationStats(const absl::string_view soc_model) const; + +// std::vector& CustomOps() { return custom_ops_; } +// std::vector custom_ops_; + +// UniqueLiteRtModel model_; + +// internal::PluginManager plugin_; +// }; + +// void PluginTool::DumpCompilationStats(const absl::string_view soc_model) +// const { +// static constexpr absl::string_view kCompiledResultTpl = +// "LiteRtCompiledResult : [ module_size=%lu (bytes), +// #compiled_partitions=%lu " +// "]\n"; +// static constexpr absl::string_view kCompiledResultErr = +// "LiteRtCompiledResult : \n"; +// if (plugin_.CompiledResultHandle(soc_model) == nullptr) { +// Dump() << kCompiledResultErr; +// return; +// } +// const void* byte_code; +// size_t byte_code_size; +// if (kLiteRtStatusOk != +// plugin_.Api().compiled_result_get_byte_code( +// plugin_.CompiledResultHandle(), &byte_code, &byte_code_size)) { +// Dump() << kCompiledResultErr; +// return; +// } + +// size_t num_compiled_partitions; +// if (kLiteRtStatusOk != +// plugin_.Api().compiled_result_get_num_calls( +// plugin_.CompiledResultHandle(), &num_compiled_partitions)) { +// Dump() << kCompiledResultErr; +// return; +// } + +// Dump() << absl::StreamFormat(kCompiledResultTpl, byte_code_size, +// num_compiled_partitions); +// } + +// void PluginTool::DumpPartitionDetails() const { +// Dump() << "[[ Partition Results ]]\n"; +// Dump() << "(main subgraph) "; +// litert::internal::Dump(MainSubgraph(), Dump()); +// for (const auto& partition : Partitions()) { +// Dump() << "(partition) "; +// litert::internal::Dump(partition, Dump()); +// } +// } + +// // Currently new partitioned subgraphs are appended to the model subgraphs +// and +// // there is only support of input models with one subgraph. +// const absl::Span PluginTool::Partitions() const { +// return absl::MakeConstSpan(model_->subgraphs.data() + 1, +// model_->subgraphs.size() - 1); +// } + +// LiteRtStatus PluginTool::Partiion() { +// LiteRtOpListT selected_ops; +// LRT_RETURN_STATUS_IF_NOT_OK(plugin_.Api().partition_model( +// plugin_.PluginHandle(), model_.get(), &selected_ops)); +// auto partitions = GroupPartitions(selected_ops.ops); + +// CustomOps().reserve(partitions.size()); + +// for (auto& partition : partitions) { +// LiteRtSubgraph new_subgraph = &model_->subgraphs.emplace_back(); +// CustomOps().push_back( +// OutlinePartition(MainSubgraph(), new_subgraph, partition)); +// } + +// return kLiteRtStatusOk; +// } + +// LiteRtStatus PluginTool::Compile(const absl::string_view soc_models) { +// LRT_RETURN_STATUS_IF_NOT_OK( +// plugin_.Api().compile(plugin_.PluginHandle(), soc_model.data(), +// slices.data(), slices.size(), &compiled_result)); + +// LiteRtParamIndex num_calls_compiled; +// LRT_RETURN_STATUS_IF_NOT_OK( +// LiteRtCompiledResultGetNumCalls(compiled_result, &num_calls_compiled)); + +// if (num_calls_compiled != slices.size()) { +// std::cerr +// << "Plugin must provide and entry point for each compiled +// partition\n"; +// return kLiteRtStatusErrorNotFound; +// } + +// for (int i = 0; i < num_calls_compiled; ++i) { +// const void* call_info; +// size_t call_info_size; + +// LRT_RETURN_STATUS_IF_NOT_OK(LiteRtCompiledResultGetCallInfo( +// compiled_result, i, &call_info, &call_info_size)); + +// auto* custom_op = custom_ops.at(i); +// custom_op->custom_options.assign(reinterpret_cast(call_info), +// call_info_size); +// } +// return kLiteRtStatusOk; +// } + +// LiteRtStatus PluginTool::Serialize(const absl::string_view soc_model, +// std::ostream& out) { +// uint8_t* buf; +// size_t size; +// size_t offset; +// LRT_RETURN_STATUS_IF_NOT_OK( +// SerializeModel(model_.release(), &buf, &size, &offset)); +// const char* cbuf = reinterpret_cast(buf); +// out.write(cbuf + offset, size - offset); +// delete[] buf; +// return kLiteRtStatusOk; +// } \ No newline at end of file diff --git a/tensorflow/lite/experimental/litert/tools/tool_display.cc b/tensorflow/lite/experimental/litert/tools/tool_display.cc new file mode 100644 index 00000000000000..4a1213fcf203d8 --- /dev/null +++ b/tensorflow/lite/experimental/litert/tools/tool_display.cc @@ -0,0 +1,61 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/tools/tool_display.h" + +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" + +namespace litert::tools { + +ToolDisplay::ToolDisplay(OptOstreamRefT display_stream, + const absl::string_view tool_label) + : display_(display_stream) { + label_ = absl::StrFormat( + "[LITERT_TOOLS%s] ", + tool_label.empty() ? tool_label : absl::StrFormat(":%s", tool_label)); +} + +std::ostream& ToolDisplay::Display() { + return display_.has_value() ? display_.value().get() : null_display_; +} + +std::ostream& ToolDisplay::Labeled() { + Display() << label_; + return Display(); +} + +std::ostream& ToolDisplay::Indented() { + Display() << "\t"; + return Display(); +} + +void ToolDisplay::Start(const absl::string_view start_label) { + Labeled() << absl::StreamFormat("Starting %s...\n", start_label); +} + +void ToolDisplay::Done() { + Labeled(); + Indented() << "Done!\n"; +} + +void ToolDisplay::Fail() { + Labeled(); + Indented() << "Failed\n"; +} + +} // namespace litert::tools diff --git a/tensorflow/lite/experimental/litert/tools/tool_display.h b/tensorflow/lite/experimental/litert/tools/tool_display.h new file mode 100644 index 00000000000000..65f3dafb3e5121 --- /dev/null +++ b/tensorflow/lite/experimental/litert/tools/tool_display.h @@ -0,0 +1,64 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_TOOL_DISPLAY_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_TOOL_DISPLAY_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" + +namespace litert::tools { + +// Utility class for interactive logging for usage in command line tools only. +// Allows user to explicitly set target stream. +class ToolDisplay { + using OptOstreamRefT = std::optional>; + + public: + // Construct configured ToolDisplay. Label is used for prefixing dumps + // in "LabeledStream". If "dump" is null, all printing through this class + // is silenced. + explicit ToolDisplay(OptOstreamRefT display_stream = std::nullopt, + absl::string_view tool_label = ""); + + // Get out stream. + std::ostream& Display(); + + // Get Display with label prefix. + std::ostream& Labeled(); + + // Get Display with indent. + std::ostream& Indented(); + + // Log string indicating a sub rountine is beginning. + void Start(absl::string_view start_label); + + // Log string indicating a sub rountine is done and succeeded. + void Done(); + + // Log string indicating a sub rountine is done and failed. + void Fail(); + + private: + std::string label_; + std::ostream null_display_ = std::ostream(nullptr); + OptOstreamRefT display_; +}; + +} // namespace litert::tools + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_TOOLS_TOOL_DISPLAY_H_ diff --git a/tensorflow/lite/experimental/litert/tools/tool_display_test.cc b/tensorflow/lite/experimental/litert/tools/tool_display_test.cc new file mode 100644 index 00000000000000..00580acdacb170 --- /dev/null +++ b/tensorflow/lite/experimental/litert/tools/tool_display_test.cc @@ -0,0 +1,83 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/tools/tool_display.h" + +#include + +#include +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" + +namespace { + +using ::litert::tools::ToolDisplay; + +static constexpr absl::string_view kToolName = "test-tool"; +static constexpr absl::string_view kLabel = "[LITERT_TOOLS:test-tool]"; +static constexpr absl::string_view kStartLabel = "Test Routine"; +static constexpr absl::string_view kDisplayInfo = "info"; + +TEST(TestToolDisplay, Display) { + std::stringstream out; + ToolDisplay display(out, kToolName); + display.Display() << kDisplayInfo; + EXPECT_EQ(out.view(), kDisplayInfo); +} + +TEST(TestToolDisplay, Indented) { + std::stringstream out; + ToolDisplay display(out, kToolName); + display.Indented() << kDisplayInfo; + EXPECT_EQ(out.view(), absl::StrFormat("\t%s", kDisplayInfo)); +} + +TEST(TestToolDisplay, Labeled) { + std::stringstream out; + ToolDisplay display(out, kToolName); + display.Labeled() << kDisplayInfo; + EXPECT_EQ(out.view(), absl::StrFormat("%s %s", kLabel, kDisplayInfo)); +} + +TEST(TestToolDisplay, LabeledNoToolName) { + std::stringstream out; + ToolDisplay display(out); + display.Labeled() << kDisplayInfo; + EXPECT_EQ(out.view(), + absl::StrFormat("%s %s", "[LITERT_TOOLS]", kDisplayInfo)); +} + +TEST(TestToolDisplay, Start) { + std::stringstream out; + ToolDisplay display(out, kToolName); + display.Start(kStartLabel); + EXPECT_EQ(out.view(), + absl::StrFormat("%s Starting %s...\n", kLabel, kStartLabel)); +} + +TEST(TestToolDisplay, Done) { + std::stringstream out; + ToolDisplay display(out, kToolName); + display.Done(); + EXPECT_EQ(out.view(), absl::StrFormat("%s \tDone!\n", kLabel)); +} + +TEST(TestToolDisplay, Fail) { + std::stringstream out; + ToolDisplay display(out, kToolName); + display.Fail(); + EXPECT_EQ(out.view(), absl::StrFormat("%s \tFailed\n", kLabel)); +} + +} // namespace diff --git a/tensorflow/lite/experimental/litert/vendors/c/BUILD b/tensorflow/lite/experimental/litert/vendors/c/BUILD new file mode 100644 index 00000000000000..7a482ff2b1449e --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/c/BUILD @@ -0,0 +1,47 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], +) + +cc_library( + name = "litert_compiler_plugin", + hdrs = ["litert_compiler_plugin.h"], + deps = ["//tensorflow/lite/experimental/litert/c:litert_c_api"], +) + +cc_library( + name = "litert_compiler_plugin_api", + hdrs = ["litert_compiler_plugin_api.h"], + deps = [ + ":litert_compiler_plugin", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + ], +) + +cc_library( + name = "litert_dispatch_c_api", + hdrs = [ + "litert_dispatch.h", + "litert_dispatch_api.h", + ], + deps = [ + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", + ], +) + +exports_files(srcs = glob(["litert_*.h"])) diff --git a/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h b/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h new file mode 100644 index 00000000000000..dfc210c83079b8 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h @@ -0,0 +1,99 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_H_ + +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +LITERT_DEFINE_HANDLE(LiteRtCompilerPlugin); + +// Artifact produced from compiling a selected partition of ops. +LITERT_DEFINE_HANDLE(LiteRtCompiledResult); + +// +// Plugin +// + +LiteRtStatus LiteRtPluginInit(LiteRtCompilerPlugin* compiler_plugin); + +void LiteRtPluginDestroy(LiteRtCompilerPlugin compiler_plugin); + +// Name associated with the manufacturer this plugin relates to (e.g, +// GoogleTensor, Qualcomm). +const char* LiteRtPluginSocManufacturer(); + +// Number of SoC models supported by this plugin. +LiteRtParamIndex LiteRtPluginNumSupportedSocModels( + LiteRtCompilerPlugin compiler_plugin); + +// Gets the name of the SoC model at the given index. The memory +// associated with the returned name is owned by the plugin. +LiteRtStatus LiteRtPluginGetSupportedSocModel( + LiteRtCompilerPlugin compiler_plugin, LiteRtParamIndex soc_model_idx, + const char** soc_model_name); + +// Select desired ops for compilation. This will be called only once +// during the plugin application flow, all ops should be selected during this +// call. +LiteRtStatus LiteRtPluginPartitionModel(LiteRtCompilerPlugin compiler_plugin, + LiteRtModel model, + LiteRtOpList selected_ops); + +// Prepare result to pass to the runtime for given partition and, optionally, +// for a given SoC model (parameter `soc_model` can be NULL). The given +// subgraphs are valid sub-DAG within the ops selected in partition step. +LiteRtStatus LiteRtPluginCompile(LiteRtCompilerPlugin compiler_plugin, + const char* soc_model, + LiteRtSubgraphArray partitions, + LiteRtParamIndex num_partitions, + LiteRtCompiledResult* compiled_result); + +// +// Compiled Partition +// + +void LiteRtCompiledResultDestroy(LiteRtCompiledResult result); + +// Get serialized result to compiled modules available to all custom ops. +// This could be one module with multiple entry points or multiple modules +// concat together. +LiteRtStatus LiteRtCompiledResultGetByteCode( + LiteRtCompiledResult compiled_result, const void** byte_code, + size_t* byte_code_size); + +// Get info to embed in a particular custom op. This could be any opaque data +// parsed in the custom op. +LiteRtStatus LiteRtCompiledResultGetCallInfo( + LiteRtCompiledResult compiled_result, LiteRtParamIndex call_idx, + const void** call_info, size_t* call_info_size); + +// Get the number of calls that will be made to the HAL for this graph. +// This should equal the number of partitions given for compilation which +// is equal to the number of custom ops in the final model. +LiteRtStatus LiteRtCompiledResultGetNumCalls( + LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_calls); + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h b/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h new file mode 100644 index 00000000000000..8e59a13a212173 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h @@ -0,0 +1,93 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_API_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_API_H_ + +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" + +// Wrapper for dynamically loaded LiteRtCompilerPlugin library. See +// "litert_compiler_plugin.h". + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// +// Api Interface +// + +typedef const char* (*LiteRtPluginApiSocManufacturer)(); + +typedef LiteRtStatus (*LiteRtPluginApiInit)(LiteRtCompilerPlugin*); + +typedef void (*LiteRtPluginApiDestroy)(LiteRtCompilerPlugin); + +typedef LiteRtParamIndex (*LiteRtPluginApiNumSupportedModels)( + LiteRtCompilerPlugin); + +typedef LiteRtStatus (*LiteRtPluginApiGetSupportedSocModel)( + LiteRtCompilerPlugin, LiteRtParamIndex soc_model_idx, + const char** soc_moel_idx); + +typedef LiteRtStatus (*LiteRtPluginApiPartitionModel)( + LiteRtCompilerPlugin, LiteRtModel model, LiteRtOpList selected_ops); + +typedef LiteRtStatus (*LiteRtPluginApiCompile)( + LiteRtCompilerPlugin, const char* soc_model, LiteRtSubgraphArray partitions, + LiteRtParamIndex num_partitions, LiteRtCompiledResult* compiled_result); + +typedef void (*LiteRtCompiledResultApiDestroy)(LiteRtCompiledResult); + +typedef LiteRtStatus (*LiteRtCompiledResultApiGetByteCode)( + LiteRtCompiledResult, const void** byte_code, size_t* byte_code_size); + +typedef LiteRtStatus (*LiteRtCompiledResultApiGetCallInfo)( + LiteRtCompiledResult, LiteRtParamIndex call_idx, const void** call_info, + size_t* call_info_size); + +typedef LiteRtStatus (*LiteRtCompiledResultApiGetNumCalls)( + LiteRtCompiledResult, LiteRtParamIndex* num_calls); + +// +// Function Pointer Container +// + +// Wraps all resolved functions from api interface. +struct LiteRtCompilerPluginApi { + LiteRtPluginApiInit init = nullptr; + LiteRtPluginApiDestroy destroy = nullptr; + + LiteRtPluginApiSocManufacturer soc_manufacturer = nullptr; + LiteRtPluginApiNumSupportedModels num_supported_models = nullptr; + LiteRtPluginApiGetSupportedSocModel get_supported_soc_model = nullptr; + + LiteRtPluginApiPartitionModel partition_model = nullptr; + LiteRtPluginApiCompile compile = nullptr; + + LiteRtCompiledResultApiDestroy compiled_result_destroy = nullptr; + LiteRtCompiledResultApiGetByteCode compiled_result_get_byte_code = nullptr; + LiteRtCompiledResultApiGetCallInfo compiled_result_get_call_info = nullptr; + LiteRtCompiledResultApiGetNumCalls compiled_result_get_num_calls = nullptr; +}; + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_COMPILER_PLUGIN_API_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h b/tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h new file mode 100644 index 00000000000000..3bb4ab89d3e93e --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h @@ -0,0 +1,284 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_DISPATCH_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_DISPATCH_H_ + +#include +#include +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_event.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// ///////////////////////////////////////////////////////////////////////////// +// Basic Execution API +// ///////////////////////////////////////////////////////////////////////////// + +#define LITERT_DISPATCH_API_VERSION_MAJOR 0 +#define LITERT_DISPATCH_API_VERSION_MINOR 1 +#define LITERT_DISPATCH_API_VERSION_PATCH 0 + +LITERT_DEFINE_HANDLE(LiteRtDispatchDeviceContext); +LITERT_DEFINE_HANDLE(LiteRtDispatchInvocationContext); + +typedef uint64_t LiteRtTensorBufferHandle; + +typedef struct LiteRtDispatchApiVersion { + int major; + int minor; + int patch; +} LiteRtDispatchApiVersion; + +typedef enum LiteRtDispatchCapabilities { + kLiteRtDispatchCapabilitiesNone = 0, + kLiteRtDispatchCapabilitiesBasic = 1, // The vendor supports the Basic API + kLiteRtDispatchCapabilitiesAsync = 2, // The vendor supports the Async API + kLiteRtDispatchCapabilitiesGraph = 4, // The vendor supports the Graph API +} LiteRtDispatchCapabilities; + +// Types of executable that can run on the HW accelerators. +typedef enum LiteRtDispatchExecutableType { + kLiteRtDispatchExecutableTypeUnknown = 0, + kLiteRtDispatchExecutableTypeDspLibrary = 1, // DSP library + kLiteRtDispatchExecutableTypeMlModel = 2, // Vendor-specific ML model +} LiteRtDispatchExecutableType; + +typedef struct LiteRtDispatchOption { + const char* name; + LiteRtAny value; +} LiteRtDispatchOption; + +// This option can be used to specify a directory from where to load shared +// libraries. +static const char* kDispatchOptionSharedLibraryDir = "shared_library_dir"; + +// Initialize the Dispatch API runtime. +// +// This function should be called before calling any other Dispatch API +// functions. +LiteRtStatus LiteRtDispatchInitialize(const LiteRtDispatchOption* options, + int num_options); + +// Return the version of the Dispatch API runtime. +LiteRtStatus LiteRtDispatchGetApiVersion(LiteRtDispatchApiVersion* api_version); + +// Return the vendor id of the Dispatch API runtime. +// +// This function returns a pointer to a statically allocated string that is the +// ID of vendor providing the Dispatch API runtime. +LiteRtStatus LiteRtDispatchGetVendorId(const char** vendor_id); + +// Return the build ID of the Dispatch API runtime. +// +// This function returns a pointer to a statically allocated string that is the +// ID of the Dispatch API runtime build. +LiteRtStatus LiteRtDispatchGetBuildId(const char** build_id); + +// Return the capabilities supported by the Dispatch API runtime as a set of the +// values specified in LiteRtDispatchCapabilities. +LiteRtStatus LiteRtDispatchGetCapabilities(int* capabilities); + +// Create a `LiteRtDispatchDeviceContext` object. +// +// The returned object is used to talk with the underlying HW. The caller owns +// the memory associated with the context and should call +// LiteRtDispatchDeviceContextDestroy() to release it. Return NULL in case of +// error. +LiteRtStatus LiteRtDispatchDeviceContextCreate( + LiteRtDispatchDeviceContext* device_context); + +// Release a `LiteRtDispatchDeviceContext` object. +// +// The given context should be release only after releasing all associated +// objects. +LiteRtStatus LiteRtDispatchDeviceContextDestroy( + LiteRtDispatchDeviceContext device_context); + +// Given a tensor type for an invocation context input, obtain the attributes +// the HW requires for the associated tensor buffer. The returned +// `tensor_buffer_requirements` object is owned by the caller. +LiteRtStatus LiteRtDispatchGetInputRequirements( + LiteRtDispatchInvocationContext invocation_context, int input_index, + const LiteRtRankedTensorType* tensor_type, + LiteRtTensorBufferRequirements* tensor_buffer_requirements); + +// Given a tensor type for an invocation context output, obtain the attributes +// the HW requires for the associated tensor buffer. The returned +// `tensor_buffer_requirements` object is owned by the caller. +LiteRtStatus LiteRtDispatchGetOutputRequirements( + LiteRtDispatchInvocationContext invocation_context, int output_index, + const LiteRtRankedTensorType* tensor_type, + LiteRtTensorBufferRequirements* tensor_buffer_requirements); + +// Registers a buffer with the given device context. +// Note: The memory backing the buffer should be valid until +// `LiteRtDispatchUnregisterTensorBuffer` is called. +LiteRtStatus LiteRtDispatchRegisterTensorBuffer( + LiteRtDispatchDeviceContext device_context, + LiteRtTensorBuffer tensor_buffer, + LiteRtTensorBufferHandle* tensor_buffer_handle); + +// Unregisters the registered buffer associated with the given +// `LiteRtTensorBufferHandle`. +// Note: The registered `LiteRtTensorBufferHandle` is supposed to be +// unregistered with this function before the associated `ThrContext` is deleted +// by calling `LiteRtDispatchDeviceContextDestroy`. +LiteRtStatus LiteRtDispatchUnregisterTensorBuffer( + LiteRtDispatchDeviceContext device_context, + LiteRtTensorBufferHandle tensor_buffer_handle); + +// Create an invocation context to run a given function from a given +// executable. Parameter `function_name` is required if the provided executable +// includes multiple functions. +LiteRtStatus LiteRtDispatchInvocationContextCreate( + LiteRtDispatchDeviceContext device_context, + LiteRtDispatchExecutableType exec_type, const void* exec_bytecode_ptr, + size_t exec_bytecode_size, const char* function_name, int num_inputs, + int num_outputs, LiteRtDispatchInvocationContext* invocation_context); + +LiteRtStatus LiteRtDispatchInvocationContextDestroy( + LiteRtDispatchInvocationContext invocation_context); + +LiteRtStatus LiteRtDispatchAttachInput( + LiteRtDispatchInvocationContext invocation_context, int graph_input_index, + LiteRtTensorBufferHandle tensor_buffer_handle); + +LiteRtStatus LiteRtDispatchAttachOutput( + LiteRtDispatchInvocationContext invocation_context, int graph_output_index, + LiteRtTensorBufferHandle tensor_buffer_handle); + +LiteRtStatus LiteRtDispatchDetachInput( + LiteRtDispatchInvocationContext invocation_context, int graph_input_index, + LiteRtTensorBufferHandle tensor_buffer_handle); + +LiteRtStatus LiteRtDispatchDetachOutput( + LiteRtDispatchInvocationContext invocation_context, int graph_output_index, + LiteRtTensorBufferHandle tensor_buffer_handle); + +LiteRtStatus LiteRtDispatchInvoke( + LiteRtDispatchInvocationContext invocation_context); + +// ///////////////////////////////////////////////////////////////////////////// +// Async Execution API +// ///////////////////////////////////////////////////////////////////////////// + +LiteRtStatus LiteRtDispatchAttachInputEvent( + LiteRtDispatchInvocationContext invocation_context, int graph_input_index, + LiteRtEvent input_event); + +LiteRtStatus LiteRtDispatchInvokeAsync( + LiteRtDispatchInvocationContext invocation_context, int num_output_events, + LiteRtEvent* output_events); + +// ///////////////////////////////////////////////////////////////////////////// +// Graph Execution API +// ///////////////////////////////////////////////////////////////////////////// + +typedef uint64_t LiteRtDispatchNodeId; +typedef uint64_t LiteRtDispatchEdgeId; +typedef uint64_t LiteRtDispatchExecutableHandle; + +LITERT_DEFINE_HANDLE(LiteRtDispatchGraph); + +// Types of graph nodes. +typedef enum LiteRtDispatchNodeType { + kLiteRtDispatchNodeTypeUnknown = 0, + kLiteRtDispatchNodeTypeDsp = + 1, // Can execute both ML models and Dsp libraries + kLiteRtDispatchNodeTypeNpu = 2, // Can execute only ML models +} LiteRtDispatchNodeType; + +LiteRtStatus LiteRtDispatchGraphCreate( + LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph** graph); + +LiteRtStatus LiteRtDispatchGraphDestroy(LiteRtDispatchGraph* graph); + +// Add a compute node to a given graph. Parameter node_id should be unique to +// the graph. +LiteRtStatus LiteRtDispatchAddNode(LiteRtDispatchGraph* graph, + LiteRtDispatchNodeId node_id, + LiteRtDispatchNodeType node_type); + +// Add an edge a given graph. Parameter edge_id should be unique to the graph. +LiteRtStatus LiteRtDispatchAddEdge(LiteRtDispatchGraph* graph, + LiteRtDispatchEdgeId edge_id); + +// Connect a given node's input. +LiteRtStatus LiteRtDispatchConnectNodeInput(LiteRtDispatchGraph* graph, + LiteRtDispatchNodeId node_id, + int input_index, + LiteRtDispatchEdgeId edge_id); + +// Connect a given node's output. +LiteRtStatus LiteRtDispatchConnectNodeOutput(LiteRtDispatchGraph* graph, + LiteRtDispatchNodeId node_id, + int output_index, + LiteRtDispatchEdgeId edge_id); + +// Connect a given graph's input. +LiteRtStatus LiteRtDispatchConnectGraphInput(LiteRtDispatchGraph* graph, + int input_index, + LiteRtDispatchEdgeId edge_id); + +// Connect a given graph's output. +LiteRtStatus LiteRtDispatchConnectGraphOutput(LiteRtDispatchGraph* graph, + int output_index, + LiteRtDispatchEdgeId edge_id); + +LiteRtStatus LiteRtDispatchLoadExecutable( + LiteRtDispatchDeviceContext device_context, + LiteRtDispatchExecutableType type, const void* bytecode, + size_t bytecode_size, LiteRtDispatchExecutableHandle* exec_handle); + +LiteRtStatus LiteRtDispatchUnloadExecutable( + LiteRtDispatchDeviceContext device_context, + LiteRtDispatchExecutableHandle exec_handle); + +// Assign an executable function to a graph node. Parameter `function_name` is +// mandatory if the given executable includes multiple functions. +LiteRtStatus LiteRtDispatchAssignNodeFunction( + LiteRtDispatchGraph* graph, LiteRtDispatchNodeId node_id, + LiteRtDispatchExecutableHandle exec_handle, const char* function_name); + +// Add an annotation to an entire graph. +LiteRtStatus LiteRtDispatchAnnotateGraph(LiteRtDispatchGraph* graph, + const char* key, const char* value); + +// Add an annotation to a specified node. +LiteRtStatus LiteRtDispatchAnnotateNode(LiteRtDispatchGraph* graph, + LiteRtDispatchNodeId node_id, + const char* key, const char* value); + +// Add an annotation to a specified edge. +LiteRtStatus LiteRtDispatchAnnotateEdge(LiteRtDispatchGraph* graph, + LiteRtDispatchEdgeId edge_id, + const char* key, const char* value); + +LiteRtStatus LiteRtDispatchInvocationContextCreateFromGraph( + LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph* graph, + LiteRtDispatchInvocationContext* invocation_context); + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_DISPATCH_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/c/litert_dispatch_api.h b/tensorflow/lite/experimental/litert/vendors/c/litert_dispatch_api.h new file mode 100644 index 00000000000000..9f1586e4698dcc --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/c/litert_dispatch_api.h @@ -0,0 +1,222 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_DISPATCH_API_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_DISPATCH_API_H_ + +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_event.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// ///////////////////////////////////////////////////////////////////////////// + +typedef LiteRtStatus (*LiteRtDispatchInitializeT)( + const LiteRtDispatchOption* options, int num_options); + +typedef LiteRtStatus (*LiteRtDispatchGetVendorIdT)(const char** vendor_id); + +typedef LiteRtStatus (*LiteRtDispatchGetBuildIdT)(const char** build_id); + +typedef LiteRtStatus (*LiteRtDispatchGetCapabilitiesT)(int* capabilities); + +typedef LiteRtStatus (*LiteRtDispatchDeviceContextCreateT)( + LiteRtDispatchDeviceContext* device_context); + +typedef LiteRtStatus (*LiteRtDispatchDeviceContextDestroyT)( + LiteRtDispatchDeviceContext device_context); + +typedef LiteRtStatus (*LiteRtDispatchGetInputRequirementsT)( + LiteRtDispatchInvocationContext invocation_context, int input_index, + const LiteRtRankedTensorType* tensor_type, + LiteRtTensorBufferRequirements* tensor_buffer_requirements); + +typedef LiteRtStatus (*LiteRtDispatchGetOutputRequirementsT)( + LiteRtDispatchInvocationContext invocation_context, int output_index, + const LiteRtRankedTensorType* tensor_type, + LiteRtTensorBufferRequirements* tensor_buffer_requirements); + +typedef LiteRtStatus (*LiteRtDispatchRegisterTensorBufferT)( + LiteRtDispatchDeviceContext device_context, + LiteRtTensorBuffer tensor_buffer, + LiteRtTensorBufferHandle* tensor_buffer_handle); + +typedef LiteRtStatus (*LiteRtDispatchUnregisterTensorBufferT)( + LiteRtDispatchDeviceContext device_context, + LiteRtTensorBufferHandle handle); + +typedef LiteRtStatus (*LiteRtDispatchInvocationContextCreateT)( + LiteRtDispatchDeviceContext device_context, + LiteRtDispatchExecutableType exec_type, const void* exec_bytecode, + size_t exec_bytecode_size, const char* function_name, int num_inputs, + int num_outputs, LiteRtDispatchInvocationContext* invocation_context); + +typedef LiteRtStatus (*LiteRtDispatchInvocationContextDestroyT)( + LiteRtDispatchInvocationContext invocation_context); + +typedef LiteRtStatus (*LiteRtDispatchAttachInputT)( + LiteRtDispatchInvocationContext invocation_context, int graph_input_index, + LiteRtTensorBufferHandle tensor_buffer_handle); + +typedef LiteRtStatus (*LiteRtDispatchAttachOutputT)( + LiteRtDispatchInvocationContext invocation_context, int graph_output_index, + LiteRtTensorBufferHandle tensor_buffer_handle); + +typedef LiteRtStatus (*LiteRtDispatchDetachInputT)( + LiteRtDispatchInvocationContext invocation_context, int graph_input_index, + LiteRtTensorBufferHandle tensor_buffer_handle); + +typedef LiteRtStatus (*LiteRtDispatchDetachOutputT)( + LiteRtDispatchInvocationContext invocation_context, int graph_output_index, + LiteRtTensorBufferHandle tensor_buffer_handle); + +typedef LiteRtStatus (*LiteRtDispatchInvokeT)( + LiteRtDispatchInvocationContext invocation_context); + +typedef struct LiteRtDispatchInterface { + LiteRtDispatchInitializeT initialize; + LiteRtDispatchGetVendorIdT get_vendor_id; + LiteRtDispatchGetBuildIdT get_build_id; + LiteRtDispatchGetCapabilitiesT get_capabilities; + LiteRtDispatchDeviceContextCreateT device_context_create; + LiteRtDispatchDeviceContextDestroyT device_context_destroy; + LiteRtDispatchGetInputRequirementsT get_input_requirements; + LiteRtDispatchGetOutputRequirementsT get_output_requirements; + LiteRtDispatchRegisterTensorBufferT register_tensor_buffer; + LiteRtDispatchUnregisterTensorBufferT unregister_tensor_buffer; + LiteRtDispatchInvocationContextCreateT invocation_context_create; + LiteRtDispatchInvocationContextDestroyT invocation_context_destroy; + LiteRtDispatchAttachInputT attach_input; + LiteRtDispatchAttachOutputT attach_output; + LiteRtDispatchDetachInputT detach_input; + LiteRtDispatchDetachOutputT detach_output; + LiteRtDispatchInvokeT invoke; +} LiteRtDispatchInterface; + +// ///////////////////////////////////////////////////////////////////////////// + +typedef LiteRtStatus (*LiteRtDispatchAttachInputEventT)( + LiteRtDispatchInvocationContext invocation_context, int graph_input_index, + LiteRtEvent input_event); + +typedef LiteRtStatus (*LiteRtDispatchInvokeAsyncT)( + LiteRtDispatchInvocationContext invocation_context, int num_output_events, + LiteRtEvent* output_events); + +typedef struct LiteRtDispatchAsyncInterface { + LiteRtDispatchAttachInputEventT attach_input_event; + LiteRtDispatchInvokeAsyncT invoke_async; +} LiteRtDispatchAsyncInterface; + +// ///////////////////////////////////////////////////////////////////////////// + +typedef LiteRtStatus (*LiteRtDispatchGraphCreateT)( + LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph* graph); + +typedef LiteRtStatus (*LiteRtDispatchGraphDestroyT)(LiteRtDispatchGraph graph); + +typedef LiteRtStatus (*LiteRtDispatchAddNodeT)( + LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, + LiteRtDispatchNodeType node_type); + +typedef LiteRtStatus (*LiteRtDispatchAddEdgeT)(LiteRtDispatchGraph graph, + LiteRtDispatchEdgeId edge_id); + +typedef LiteRtStatus (*LiteRtDispatchConnectNodeInputT)( + LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, int input_index, + LiteRtDispatchEdgeId edge_id); + +typedef LiteRtStatus (*LiteRtDispatchConnectNodeOutputT)( + LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, int output_index, + LiteRtDispatchEdgeId edge_id); + +typedef LiteRtStatus (*LiteRtDispatchConnectGraphInputT)( + LiteRtDispatchGraph graph, int input_index, LiteRtDispatchEdgeId edge_id); + +typedef LiteRtStatus (*LiteRtDispatchConnectGraphOutputT)( + LiteRtDispatchGraph graph, int output_index, LiteRtDispatchEdgeId edge_id); + +typedef LiteRtStatus (*LiteRtDispatchLoadExecutableT)( + LiteRtDispatchDeviceContext device_context, + LiteRtDispatchExecutableType type, const void* bytecode_ptr, + size_t bytecode_size, LiteRtDispatchExecutableHandle* exec_handle); + +typedef LiteRtStatus (*LiteRtDispatchUnloadExecutableT)( + LiteRtDispatchDeviceContext device_context, + LiteRtDispatchExecutableHandle exec_handle); + +typedef LiteRtStatus (*LiteRtDispatchAssignNodeFunctionT)( + LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, + LiteRtDispatchExecutableHandle exec_handle, const char* function_name); + +typedef LiteRtStatus (*LiteRtDispatchInvocationContextCreateFromGraphT)( + LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph graph, + LiteRtDispatchInvocationContext* invocation_context); + +typedef LiteRtStatus (*LiteRtDispatchAnnotateGraphT)(LiteRtDispatchGraph graph, + const char* key, + const char* value); + +typedef LiteRtStatus (*LiteRtDispatchAnnotateNodeT)( + LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, const char* key, + const char* value); + +typedef LiteRtStatus (*LiteRtDispatchAnnotateEdgeT)( + LiteRtDispatchGraph graph, LiteRtDispatchEdgeId edge_id, const char* key, + const char* value); + +typedef struct LiteRtDispatchGraphInterface { + LiteRtDispatchGraphCreateT graph_create; + LiteRtDispatchGraphDestroyT graph_destroy; + LiteRtDispatchAddNodeT add_node; + LiteRtDispatchAddEdgeT add_edge; + LiteRtDispatchConnectNodeInputT connect_node_input; + LiteRtDispatchConnectNodeOutputT connect_node_output; + LiteRtDispatchConnectGraphInputT connect_graph_input; + LiteRtDispatchConnectGraphOutputT connect_graph_output; + LiteRtDispatchLoadExecutableT load_executable; + LiteRtDispatchUnloadExecutableT unload_executable; + LiteRtDispatchAssignNodeFunctionT assign_node_function; + LiteRtDispatchAnnotateGraphT annotate_graph; + LiteRtDispatchAnnotateNodeT annotate_node; + LiteRtDispatchAnnotateEdgeT annotate_edge; + LiteRtDispatchInvocationContextCreateFromGraphT + invocation_context_create_from_graph; +} LiteRtDispatchGraphInterface; + +// ///////////////////////////////////////////////////////////////////////////// + +// FIXME See Vulkan and OpenCL extensions. +typedef struct LiteRtDispatchApi { + LiteRtDispatchApiVersion version; + LiteRtDispatchInterface* interface; + LiteRtDispatchAsyncInterface* async_interface; + LiteRtDispatchGraphInterface* graph_interface; +} LiteRtDispatchApi; + +LiteRtStatus LiteRtDispatchGetApi(LiteRtDispatchApi* api); + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_C_LITERT_DISPATCH_API_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/examples/BUILD b/tensorflow/lite/experimental/litert/vendors/examples/BUILD new file mode 100644 index 00000000000000..bba0ae96fff042 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/examples/BUILD @@ -0,0 +1,56 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "litert_dynamic_lib") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:private"], +) + +litert_dynamic_lib( + name = "example_plugin", + srcs = ["example_plugin.cc"], + hdrs = ["//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin.h"], + export_litert_only = True, + linkstatic = 1, + shared_lib_name = "example_plugin_so", + so_name = "libLiteRtPlugin_ExampleSocManufacturer_ExampleSocModel.so", + visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], + deps = [ + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/core:graph_tools", + "//tensorflow/lite/experimental/litert/core:model", + ], +) + +cc_test( + name = "example_plugin_test", + srcs = [ + "example_plugin_test.cc", + ], + data = ["//tensorflow/lite/experimental/litert/test:tflite_test_data"], + deps = [ + ":example_plugin", # buildcleaner: keep + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/core:graph_tools", + "//tensorflow/lite/experimental/litert/core:model", + "//tensorflow/lite/experimental/litert/test:common", + "@com_google_absl//absl/log:absl_check", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin.cc new file mode 100644 index 00000000000000..7089f8a62afc6b --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin.cc @@ -0,0 +1,188 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include +#include +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/core/graph_tools.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" + +// +// Configurations +// + +namespace { + +constexpr char kPluginManufacturer[] = "ExampleSocManufacturer"; +constexpr char kPluginSocModel[] = "ExampleSocModel"; + +} // namespace + +const char* LiteRtPluginSocManufacturer() { return kPluginManufacturer; } + +LiteRtParamIndex LiteRtPluginNumSupportedSocModels( + LiteRtCompilerPlugin compiler_plugin) { + return 1; +} + +LiteRtStatus LiteRtPluginGetSupportedSocModel( + LiteRtCompilerPlugin compiler_plugin, LiteRtParamIndex soc_model_idx, + const char** soc_model_name) { + if (soc_model_idx != 0) { + return kLiteRtStatusErrorUnsupported; + } + *soc_model_name = kPluginSocModel; + return kLiteRtStatusOk; +} + +// +// Compiled Result Definition +// + +struct LiteRtCompiledResultT { + std::string byte_code; + std::vector per_op_data; +}; + +LiteRtStatus LiteRtCompiledResultGetByteCode( + LiteRtCompiledResult compiled_result, const void** byte_code, + size_t* byte_code_size) { + *byte_code = compiled_result->byte_code.data(); + *byte_code_size = compiled_result->byte_code.size(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtCompiledResultGetCallInfo( + LiteRtCompiledResult compiled_result, LiteRtParamIndex call_idx, + const void** call_info, size_t* call_info_size) { + if (call_idx >= compiled_result->per_op_data.size()) { + return kLiteRtStatusErrorIndexOOB; + } + + *call_info = compiled_result->per_op_data.at(call_idx).data(); + *call_info_size = compiled_result->per_op_data.at(call_idx).size(); + + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtCompiledResultGetNumCalls( + LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_calls) { + *num_calls = compiled_result->per_op_data.size(); + return kLiteRtStatusOk; +} + +void LiteRtCompiledResultDestroy(LiteRtCompiledResult compiled_result) { + delete compiled_result; +} + +// +// Plugin Definition +// + +// Plugins can hold state. +struct LiteRtCompilerPluginT {}; + +LiteRtStatus LiteRtPluginInit(LiteRtCompilerPlugin* compiler_plugin) { + *compiler_plugin = new LiteRtCompilerPluginT; + return kLiteRtStatusOk; +} + +void LiteRtPluginDestroy(LiteRtCompilerPlugin compiler_plugin) { + delete compiler_plugin; +} + +LiteRtStatus LiteRtPluginPartitionModel(LiteRtCompilerPlugin compiler_plugin, + LiteRtModel model, + LiteRtOpList selected_ops) { + LITERT_ASSIGN_OR_RETURN_STATUS(auto subgraph, + graph_tools::GetSubgraph(model)); + LITERT_ASSIGN_OR_RETURN_STATUS(auto ops, + graph_tools::GetSubgraphOps(subgraph)); + + for (auto op : ops) { + LiteRtOpCode op_code; + LITERT_RETURN_STATUS_IF_NOT_OK(GetOpCode(op, &op_code)); + if (op_code != kLiteRtOpCodeTflMul) { + continue; + } + LITERT_RETURN_STATUS_IF_NOT_OK(PushOp(selected_ops, op)); + } + return kLiteRtStatusOk; +} + +namespace { + +LiteRtStatus CompileSinglePartition(LiteRtParamIndex partition_index, + LiteRtSubgraph subgraph, + LiteRtCompiledResultT& result) { + LITERT_ASSIGN_OR_RETURN_STATUS(auto ops, + graph_tools::GetSubgraphOps(subgraph)); + + int num_muls_in_partition = 0; + for (auto op : ops) { + LiteRtOpCode op_code; + + LITERT_RETURN_STATUS_IF_NOT_OK(GetOpCode(op, &op_code)); + if (op_code != kLiteRtOpCodeTflMul) { + return kLiteRtStatusErrorUnsupported; + } + + ++num_muls_in_partition; + } + + { + char* byte_code_append; + (void)asprintf(&byte_code_append, + "Partition_%lu_with_%d_muls:", partition_index, + num_muls_in_partition); + result.byte_code.append(byte_code_append); + free(byte_code_append); + } + + { + char* per_op_data; + (void)asprintf(&per_op_data, "Partition_%lu", partition_index); + result.per_op_data.push_back(per_op_data); + free(per_op_data); + } + + return kLiteRtStatusOk; +} + +} // namespace + +LiteRtStatus LiteRtPluginCompile(LiteRtCompilerPlugin compiler_plugin, + const char* soc_model, + LiteRtSubgraphArray partitions, + LiteRtParamIndex num_partitions, + LiteRtCompiledResult* compiled_result) { + LiteRtCompiledResult result = new LiteRtCompiledResultT; + + for (auto i = 0; i < num_partitions; ++i) { + LITERT_RETURN_STATUS_IF_NOT_OK( + CompileSinglePartition(i, partitions[i], *result)); + } + + *compiled_result = result; + + return kLiteRtStatusOk; +} diff --git a/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_test.cc b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_test.cc new file mode 100644 index 00000000000000..7f92b56fe5e508 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/examples/example_plugin_test.cc @@ -0,0 +1,99 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include +#include "absl/log/absl_check.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/core/graph_tools.h" +#include "tensorflow/lite/experimental/litert/core/model.h" +#include "tensorflow/lite/experimental/litert/test/common.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" + +namespace { + +UniqueLiteRtCompilerPlugin GetDummyPlugin() { + LiteRtCompilerPlugin dummy_plugin; + LITERT_CHECK_STATUS_OK(LiteRtPluginInit(&dummy_plugin)); + ABSL_CHECK_NE(dummy_plugin, nullptr); + return UniqueLiteRtCompilerPlugin(dummy_plugin); +} + +TEST(TestDummyPlugin, GetConfigInfo) { + ASSERT_STREQ(LiteRtPluginSocManufacturer(), "ExampleSocManufacturer"); + + auto plugin = GetDummyPlugin(); + + ASSERT_EQ(1, LiteRtPluginNumSupportedSocModels(plugin.get())); + + const char* soc_model_name; + ASSERT_STATUS_OK( + LiteRtPluginGetSupportedSocModel(plugin.get(), 0, &soc_model_name)); + ASSERT_STREQ(soc_model_name, "ExampleSocModel"); +} + +TEST(TestCallDummyPlugin, PartitionSimpleMultiAdd) { + auto plugin = GetDummyPlugin(); + auto model = litert::testing::LoadTestFileModel("simple_multi_op.tflite"); + + LiteRtOpListT selected_op_list; + ASSERT_STATUS_OK( + LiteRtPluginPartitionModel(plugin.get(), model.get(), &selected_op_list)); + const auto selected_ops = selected_op_list.Vec(); + + ASSERT_EQ(selected_ops.size(), 2); + ASSERT_EQ(selected_ops[0]->op_code, kLiteRtOpCodeTflMul); + ASSERT_EQ(selected_ops[1]->op_code, kLiteRtOpCodeTflMul); +} + +TEST(TestCallDummyPlugin, CompileMulSubgraph) { + auto plugin = GetDummyPlugin(); + auto model = litert::testing::LoadTestFileModel("mul_simple.tflite"); + + ASSERT_RESULT_OK_ASSIGN(auto subgraph, graph_tools::GetSubgraph(model.get())); + + LiteRtCompiledResult compiled; + ASSERT_STATUS_OK(LiteRtPluginCompile(plugin.get(), /*soc_model=*/nullptr, + &subgraph, 1, &compiled)); + + const void* byte_code; + size_t byte_code_size; + + ASSERT_STATUS_OK( + LiteRtCompiledResultGetByteCode(compiled, &byte_code, &byte_code_size)); + + std::string byte_code_string(reinterpret_cast(byte_code), + byte_code_size); + ASSERT_EQ(byte_code_string, "Partition_0_with_2_muls:"); + + const void* op_data; + size_t op_data_size; + + ASSERT_STATUS_OK( + LiteRtCompiledResultGetCallInfo(compiled, 0, &op_data, &op_data_size)); + + std::string op_data_string(reinterpret_cast(op_data), + op_data_size); + ASSERT_EQ(op_data_string, "Partition_0"); + + LiteRtCompiledResultDestroy(compiled); +} + +} // namespace diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/BUILD b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/BUILD new file mode 100644 index 00000000000000..aa92bb507c49a7 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/BUILD @@ -0,0 +1,85 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], +) + +cc_library( + name = "dispatch_api", + srcs = [ + "dispatch_api.cc", + "litert_dispatch_device_context.cc", + "litert_dispatch_invocation_context.cc", + "southbound.cc", + ], + hdrs = [ + "dispatch_api.h", + "litert_dispatch_device_context.h", + "litert_dispatch_graph.h", + "litert_dispatch_invocation_context.h", + "southbound.h", + # copybara:uncomment "//third_party/odml/infra/southbound:sb_api.h", + ], + linkopts = select({ + "//tensorflow:android": ["-landroid"], + "//conditions:default": [], + }), + deps = [ + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", + "//tensorflow/lite/experimental/litert/core:utils", + "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_shared_library( + name = "dispatch_api_shared", + shared_lib_name = "libLiteRtDispatch.so", + visibility = ["//visibility:public"], + deps = [":dispatch_api"], +) + +cc_test( + name = "dispatch_api_google_tensor_test", + srcs = [ + "dispatch_api_google_tensor_test.cc", + ], + data = [ + ":dispatch_api_shared", + ], + linkopts = select({ + "//tensorflow:android": ["-landroid"], + "//conditions:default": [], + }), + deps = [ + ":dispatch_api", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", + "//tensorflow/lite/experimental/litert/core/dispatch", + "//tensorflow/lite/experimental/litert/test:common", + "//tensorflow/lite/experimental/litert/test:simple_model", + "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:absl_log", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.cc new file mode 100644 index 00000000000000..544d0fe0732a23 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.cc @@ -0,0 +1,1209 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.h" + +#include +#include +#include +#include +#include + +#if LITERT_HAS_AHWB_SUPPORT +#include +#endif + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "third_party/odml/infra/southbound/sb_api.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_event.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch_api.h" +#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.h" +#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.h" +#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.h" +#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h" + +namespace { + +constexpr const int VERSION_MAJOR = 0; +constexpr const int VERSION_MINOR = 1; +constexpr const int VERSION_PATCH = 0; + +constexpr char kDynamicInteropKey[] = "dynamic_interop_mode"; +constexpr char kEnableEarlyWakeup[] = "2"; + +// We store THR names in a global set as a workaround to b/369144429. +std::set ThrNames; + +absl::string_view ThrNodeIdStr(LiteRtDispatchNodeId node_id) { + auto str = "node_" + std::to_string(node_id); + auto iter = ThrNames.find(str); + if (iter == ThrNames.end()) { + iter = ThrNames.insert(iter, str); + } + return *iter; +} + +absl::string_view ThrEdgeIdStr(LiteRtDispatchEdgeId edge_id) { + auto str = "edge_" + std::to_string(edge_id); + auto iter = ThrNames.find(str); + if (iter == ThrNames.end()) { + iter = ThrNames.insert(iter, str); + } + return *iter; +} + +litert::google_tensor::Southbound* TheSouthbound; +char BuildId[256]; + +} // namespace + +namespace litert { +namespace google_tensor { + +// ///////////////////////////////////////////////////////////////////////////// +// Basic Execution API +// ///////////////////////////////////////////////////////////////////////////// + +const char* GetSharedLibraryDir(const LiteRtDispatchOption* options, + int num_options) { + for (auto i = 0; i < num_options; ++i) { + auto& option = options[i]; + if (!strcmp(option.name, kDispatchOptionSharedLibraryDir)) { + return option.value.str_value; + } + } + return nullptr; +} + +LiteRtStatus Initialize(const LiteRtDispatchOption* options, int num_options) { + auto* shared_library_dir = GetSharedLibraryDir(options, num_options); + std::optional shared_library_dir_opt = + shared_library_dir ? std::make_optional(std::string(shared_library_dir)) + : std::nullopt; + + if (auto southbound = + litert::google_tensor::Southbound::Create(shared_library_dir_opt); + !southbound.ok()) { + LITERT_LOG(LITERT_INFO, "Initialization failure: %s", + southbound.status().message().data()); + return kLiteRtStatusErrorRuntimeFailure; + } else { + TheSouthbound = southbound->release(); + } + + auto thr_initialize = TheSouthbound->thr_functions().thr_initialize; + if (!thr_initialize) { + LITERT_LOG(LITERT_INFO, "thr_initialize not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + if (auto status = thr_initialize(); status != kThrStatusSuccess) { + LITERT_LOG(LITERT_INFO, "thr_initialize failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + auto thr_get_vendor_api_version = + TheSouthbound->thr_functions().thr_get_vendor_api_version; + const char* sb_api_version = + thr_get_vendor_api_version ? thr_get_vendor_api_version() : "N.A."; + auto thr_get_vendor_id = TheSouthbound->thr_functions().thr_get_vendor_id; + const char* sb_vendor_id = thr_get_vendor_id ? thr_get_vendor_id() : "N.A."; + snprintf( + BuildId, sizeof(BuildId), + "GoogleTensor Dispatch API version %d.%d.%d, Darwinn API version %s, " + "vendor id: %s", + VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH, sb_api_version, + sb_vendor_id); + BuildId[sizeof(BuildId) - 1] = 0; + + return kLiteRtStatusOk; +} + +LiteRtStatus GetVendorId(const char** vendor_id) { + *vendor_id = "Google"; + return kLiteRtStatusOk; +} + +LiteRtStatus GetBuildId(const char** build_id) { + *build_id = BuildId; + return kLiteRtStatusOk; +} + +LiteRtStatus GetCapabilities(int* capabilities) { + *capabilities = kLiteRtDispatchCapabilitiesBasic | + kLiteRtDispatchCapabilitiesAsync | + kLiteRtDispatchCapabilitiesGraph; + return kLiteRtStatusOk; +} + +LiteRtStatus DeviceContextCreate(LiteRtDispatchDeviceContext* device_context) { + if (auto status_or = LiteRtDispatchDeviceContextT::Create(*TheSouthbound); + status_or.ok()) { + *device_context = status_or->release(); + return kLiteRtStatusOk; + } else { + LITERT_LOG(LITERT_ERROR, "Failed to create device context: %s", + status_or.status().message().data()); + return kLiteRtStatusErrorRuntimeFailure; + } + return kLiteRtStatusOk; +} + +LiteRtStatus DeviceContextDestroy(LiteRtDispatchDeviceContext device_context) { + delete device_context; + return kLiteRtStatusOk; +} + +LiteRtStatus GetInputRequirements( + LiteRtDispatchInvocationContext invocation_context, int input_index, + const LiteRtRankedTensorType* tensor_type, + LiteRtTensorBufferRequirements* tensor_buffer_requirements) { + if (auto requirements = + invocation_context->GetInputRequirements(input_index, *tensor_type); + requirements.ok()) { + *tensor_buffer_requirements = *requirements; + return kLiteRtStatusOk; + } else { + LITERT_LOG(LITERT_ERROR, "Failed to get tensor buffer requirements: %s", + requirements.status().message().data()); + return kLiteRtStatusErrorRuntimeFailure; + } +} + +LiteRtStatus GetOutputRequirements( + LiteRtDispatchInvocationContext invocation_context, int output_index, + const LiteRtRankedTensorType* tensor_type, + LiteRtTensorBufferRequirements* tensor_buffer_requirements) { + if (auto requirements = + invocation_context->GetOutputRequirements(output_index, *tensor_type); + requirements.ok()) { + *tensor_buffer_requirements = *requirements; + return kLiteRtStatusOk; + } else { + LITERT_LOG(LITERT_ERROR, "Failed to get tensor buffer requirements: %s", + requirements.status().message().data()); + return kLiteRtStatusErrorRuntimeFailure; + } +} + +LiteRtStatus RegisterTensorBuffer( + LiteRtDispatchDeviceContext device_context, + LiteRtTensorBuffer tensor_buffer, + LiteRtTensorBufferHandle* tensor_buffer_handle) { + LiteRtTensorBufferType tensor_buffer_type; + if (auto status = + LiteRtGetTensorBufferType(tensor_buffer, &tensor_buffer_type); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to get buffer type: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + if (tensor_buffer_type != kLiteRtTensorBufferTypeAhwb) { + LITERT_LOG(LITERT_ERROR, "Unsupported buffer type: %d", tensor_buffer_type); + return kLiteRtStatusErrorUnsupported; + } + + size_t tensor_buffer_size; + if (auto status = + LiteRtGetTensorBufferSize(tensor_buffer, &tensor_buffer_size); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to get buffer size: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + size_t tensor_buffer_offset; + if (auto status = + LiteRtGetTensorBufferOffset(tensor_buffer, &tensor_buffer_offset); + status != kLiteRtStatusOk) { + if (status == kLiteRtStatusErrorNotFound) { + tensor_buffer_offset = 0; + } else { + LITERT_LOG(LITERT_ERROR, "Failed to get buffer offset: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + } + + LiteRtRankedTensorType tensor_type; + if (auto status = + LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to get tensor buffer type: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + auto* tensor_strides = tensor_type.layout.strides; + if (tensor_strides != nullptr) { + LITERT_LOG(LITERT_ERROR, "Tensor strides are not supported"); + return kLiteRtStatusErrorRuntimeFailure; + } + + AHardwareBuffer* ahwb; +#if LITERT_HAS_AHWB_SUPPORT + if (auto status = LiteRtGetTensorBufferAhwb(tensor_buffer, &ahwb); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to get AHWB: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } +#else + LITERT_LOG(LITERT_ERROR, "AHardwareBuffer is not supported on this platform"); + return kLiteRtStatusErrorRuntimeFailure; +#endif + + ThrContext* thr_context = device_context->thr_context(); + ThrBufferHandle thr_buffer_handle; + + if (tensor_buffer_offset == 0) { + auto thr_register_buffer = + TheSouthbound->thr_functions().thr_register_buffer; + if (!thr_register_buffer) { + LITERT_LOG(LITERT_ERROR, "thr_register_buffer not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + if (auto status = thr_register_buffer( + thr_context, ThrBufferType::kThrBufferTypeAHardwareBuffer, ahwb, + tensor_buffer_size, &thr_buffer_handle); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_register_buffer failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + } else { + auto thr_register_buffer_with_offset = + TheSouthbound->thr_functions().thr_register_buffer_with_offset; + if (!thr_register_buffer_with_offset) { + LITERT_LOG(LITERT_ERROR, "thr_register_buffer_with_offset not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + if (auto status = thr_register_buffer_with_offset( + thr_context, ThrBufferType::kThrBufferTypeAHardwareBuffer, ahwb, + tensor_buffer_offset, tensor_buffer_size, &thr_buffer_handle); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_register_buffer_with_offset failed: %d", + status); + return kLiteRtStatusErrorRuntimeFailure; + } + } + + *tensor_buffer_handle = thr_buffer_handle; + return kLiteRtStatusOk; +} + +LiteRtStatus UnregisterTensorBuffer( + LiteRtDispatchDeviceContext device_context, + LiteRtTensorBufferHandle tensor_buffer_handle) { + auto thr_unregister_buffer = + TheSouthbound->thr_functions().thr_unregister_buffer; + if (!thr_unregister_buffer) { + LITERT_LOG(LITERT_ERROR, "thr_unregister_buffer not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrBufferHandle thr_buffer_handle = tensor_buffer_handle; + if (auto status = thr_unregister_buffer(device_context->thr_context(), + thr_buffer_handle); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_register_buffer failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus InvocationContextCreate( + LiteRtDispatchDeviceContext device_context, + LiteRtDispatchExecutableType exec_type, const void* exec_bytecode, + size_t exec_bytecode_size, const char* function_name, int num_inputs, + int num_outputs, LiteRtDispatchInvocationContext* invocation_context) { + LiteRtDispatchGraph graph = nullptr; + if (auto status = GraphCreate(device_context, &graph); + status != kLiteRtStatusOk) { + return status; + } + + if (auto status = + AnnotateGraph(graph, kDynamicInteropKey, kEnableEarlyWakeup); + status != kLiteRtStatusOk) { + return status; + } + + LiteRtDispatchNodeId node_id = 0; + LiteRtDispatchNodeType node_type; + switch (exec_type) { + case kLiteRtDispatchExecutableTypeDspLibrary: + node_type = kLiteRtDispatchNodeTypeDsp; + break; + case kLiteRtDispatchExecutableTypeMlModel: + node_type = kLiteRtDispatchNodeTypeNpu; + break; + default: + LITERT_LOG(LITERT_ERROR, "Unexpected executable type: %d", exec_type); + return kLiteRtStatusErrorInvalidArgument; + } + + if (auto status = AddNode(graph, node_id, node_type); + status != kLiteRtStatusOk) { + return status; + } + + LiteRtDispatchExecutableHandle exec_handle; + if (auto status = LoadExecutable(device_context, exec_type, exec_bytecode, + exec_bytecode_size, &exec_handle); + status != kLiteRtStatusOk) { + return status; + } + + if (auto status = + AssignNodeFunction(graph, node_id, exec_handle, function_name); + status != kLiteRtStatusOk) { + return status; + } + + LiteRtDispatchEdgeId next_edge_id = 0; + + for (auto input_index = 0; input_index < num_inputs; ++input_index) { + LiteRtDispatchEdgeId edge_id = next_edge_id++; + if (auto status = AddEdge(graph, edge_id); status != kLiteRtStatusOk) { + return status; + } + if (auto status = ConnectGraphInput(graph, input_index, edge_id); + status != kLiteRtStatusOk) { + return status; + } + if (auto status = ConnectNodeInput(graph, node_id, input_index, edge_id); + status != kLiteRtStatusOk) { + return status; + } + } + + for (auto output_index = 0; output_index < num_outputs; ++output_index) { + LiteRtDispatchEdgeId edge_id = next_edge_id++; + if (auto status = AddEdge(graph, edge_id); status != kLiteRtStatusOk) { + return status; + } + if (auto status = ConnectNodeOutput(graph, node_id, output_index, edge_id); + status != kLiteRtStatusOk) { + return status; + } + if (auto status = ConnectGraphOutput(graph, output_index, edge_id); + status != kLiteRtStatusOk) { + return status; + } + } + + if (auto status = InvocationContextCreateFromGraph(device_context, graph, + invocation_context); + status != kLiteRtStatusOk) { + return status; + } + + (*invocation_context)->AttachExecutable(exec_handle); + + return kLiteRtStatusOk; +} + +LiteRtStatus InvocationContextDestroy( + LiteRtDispatchInvocationContext invocation_context) { + auto thr_invocation_context_delete = + TheSouthbound->thr_functions().thr_invocation_context_delete; + if (!thr_invocation_context_delete) { + LITERT_LOG(LITERT_ERROR, "thr_invocation_context_delete not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrGraph* thr_graph = invocation_context->graph()->thr_graph(); + ThrInvocationContext* thr_icontext = + invocation_context->thr_invocation_context(); + if (auto status = thr_invocation_context_delete(thr_graph, thr_icontext); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_invocation_context_delete failed: %d", + status); + return kLiteRtStatusErrorRuntimeFailure; + } + + delete invocation_context; + + return kLiteRtStatusOk; +} + +LiteRtStatus AttachBufferHelper( + LiteRtDispatchInvocationContext invocation_context, + LiteRtDispatchEdgeId edge_id, + LiteRtTensorBufferHandle tensor_buffer_handle) { + auto thr_invocation_context_attach_buffer = + TheSouthbound->thr_functions().thr_invocation_context_attach_buffer; + if (!thr_invocation_context_attach_buffer) { + LITERT_LOG(LITERT_ERROR, "thr_invocation_context_attach_buffer not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrInvocationContext* thr_icontext = + invocation_context->thr_invocation_context(); + ThrContext* thr_context = invocation_context->device_context()->thr_context(); + auto thr_edge_id = ThrEdgeIdStr(edge_id); + ThrBufferHandle thr_buffer_handle = tensor_buffer_handle; + if (auto status = thr_invocation_context_attach_buffer( + thr_icontext, thr_context, thr_edge_id.data(), thr_buffer_handle); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_invocation_context_attach_buffer failed: %d", + status); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus AttachInput(LiteRtDispatchInvocationContext invocation_context, + int graph_input_index, + LiteRtTensorBufferHandle tensor_buffer_handle) { + if (auto status_or = + invocation_context->graph()->InputEdge(graph_input_index); + !status_or.ok()) { + LITERT_LOG(LITERT_ERROR, "Unexpected graph input index: %d", + graph_input_index); + return kLiteRtStatusErrorInvalidArgument; + } else { + auto edge_id = *status_or; + return AttachBufferHelper(invocation_context, edge_id, + tensor_buffer_handle); + } +} + +LiteRtStatus AttachOutput(LiteRtDispatchInvocationContext invocation_context, + int graph_output_index, + LiteRtTensorBufferHandle tensor_buffer_handle) { + if (auto status = invocation_context->graph()->OutputEdge(graph_output_index); + !status.ok()) { + LITERT_LOG(LITERT_ERROR, "Unexpected graph output index: %d", + graph_output_index); + return kLiteRtStatusErrorInvalidArgument; + } else { + auto edge_id = *status; + return AttachBufferHelper(invocation_context, edge_id, + tensor_buffer_handle); + } +} + +LiteRtStatus DetachTensorBufferHelper( + LiteRtDispatchInvocationContext invocation_context, + LiteRtDispatchEdgeId edge_id, + LiteRtTensorBufferHandle tensor_buffer_handle) { + auto thr_invocation_context_detach_buffer = + TheSouthbound->thr_functions().thr_invocation_context_detach_buffer; + if (!thr_invocation_context_detach_buffer) { + LITERT_LOG(LITERT_ERROR, "thr_invocation_context_detach_buffer not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrInvocationContext* thr_icontext = + invocation_context->thr_invocation_context(); + ThrContext* thr_context = invocation_context->device_context()->thr_context(); + auto thr_edge_id = ThrEdgeIdStr(edge_id); + ThrBufferHandle thr_buffer_handle = tensor_buffer_handle; + if (auto status = thr_invocation_context_detach_buffer( + thr_icontext, thr_context, thr_edge_id.data(), thr_buffer_handle); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_invocation_context_detach_buffer failed: %d", + status); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus DetachInput(LiteRtDispatchInvocationContext invocation_context, + int graph_input_index, + LiteRtTensorBufferHandle tensor_buffer_handle) { + if (auto status_or = + invocation_context->graph()->InputEdge(graph_input_index); + !status_or.ok()) { + LITERT_LOG(LITERT_ERROR, "Unexpected graph input index: %d", + graph_input_index); + return kLiteRtStatusErrorInvalidArgument; + } else { + auto edge_id = *status_or; + return DetachTensorBufferHelper(invocation_context, edge_id, + tensor_buffer_handle); + } +} + +LiteRtStatus DetachOutput(LiteRtDispatchInvocationContext invocation_context, + int graph_output_index, + LiteRtTensorBufferHandle tensor_buffer_handle) { + if (auto status = invocation_context->graph()->OutputEdge(graph_output_index); + !status.ok()) { + LITERT_LOG(LITERT_ERROR, "Unexpected graph output index: %d", + graph_output_index); + return kLiteRtStatusErrorInvalidArgument; + } else { + auto edge_id = *status; + return DetachTensorBufferHelper(invocation_context, edge_id, + tensor_buffer_handle); + } +} + +LiteRtStatus PrepareForInvoke( + LiteRtDispatchInvocationContext invocation_context, + bool create_output_sync_fence) { + auto thr_invocation_context_prepare_for_invoke = + TheSouthbound->thr_functions().thr_invocation_context_prepare_for_invoke; + if (!thr_invocation_context_prepare_for_invoke) { + LITERT_LOG(LITERT_ERROR, + "thr_invocation_context_prepare_for_invoke not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrInvocationContext* thr_icontext = + invocation_context->thr_invocation_context(); + if (auto status = thr_invocation_context_prepare_for_invoke( + thr_icontext, create_output_sync_fence); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, + "thr_invocation_context_prepare_for_invoke failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus InvokeOnce(LiteRtDispatchInvocationContext invocation_context) { + auto thr_invocation_context_invoke_once = + TheSouthbound->thr_functions().thr_invocation_context_invoke_once; + if (!thr_invocation_context_invoke_once) { + LITERT_LOG(LITERT_ERROR, "thr_invocation_context_invoke_once not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrInvocationContext* thr_icontext = + invocation_context->thr_invocation_context(); + if (auto status = thr_invocation_context_invoke_once(thr_icontext); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_invocation_context_invoke_once failed: %d", + status); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus Wait(LiteRtDispatchInvocationContext invocation_context) { + auto thr_invocation_context_wait = + TheSouthbound->thr_functions().thr_invocation_context_wait; + if (!thr_invocation_context_wait) { + LITERT_LOG(LITERT_ERROR, "thr_invocation_context_wait not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrInvocationContext* thr_icontext = + invocation_context->thr_invocation_context(); + if (auto status = thr_invocation_context_wait(thr_icontext); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_invocation_context_wait failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus Invoke(LiteRtDispatchInvocationContext invocation_context) { + if (auto status = PrepareForInvoke(invocation_context, + /*create_output_sync_fence=*/false); + status != kLiteRtStatusOk) { + return status; + } + + if (auto status = InvokeOnce(invocation_context); status != kLiteRtStatusOk) { + return status; + } + return Wait(invocation_context); +} + +// ///////////////////////////////////////////////////////////////////////////// +// Async Execution API +// ///////////////////////////////////////////////////////////////////////////// + +LiteRtStatus AttachInputEvent( + LiteRtDispatchInvocationContext invocation_context, int graph_input_index, + LiteRtEvent input_event) { + auto status_or = invocation_context->graph()->InputEdge(graph_input_index); + if (!status_or.ok()) { + LITERT_LOG(LITERT_ERROR, "Unexpected graph input index: %d", + graph_input_index); + return kLiteRtStatusErrorInvalidArgument; + } + auto edge_id = *status_or; + + auto thr_invocation_context_attach_input_buffer_sync_fence = + TheSouthbound->thr_functions() + .thr_invocation_context_attach_input_buffer_sync_fence; + if (!thr_invocation_context_attach_input_buffer_sync_fence) { + LITERT_LOG( + LITERT_ERROR, + "thr_invocation_context_attach_input_buffer_sync_fence not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + int input_fence_fd; + if (auto status = LiteRtEventGetSyncFenceFd(input_event, &input_fence_fd); + status != kLiteRtStatusOk) { + return status; + } + + ThrInvocationContext* thr_icontext = + invocation_context->thr_invocation_context(); + auto thr_edge_id = ThrEdgeIdStr(edge_id); + if (auto status = thr_invocation_context_attach_input_buffer_sync_fence( + thr_icontext, thr_edge_id.data(), input_fence_fd); + status != kThrStatusSuccess) { + LITERT_LOG( + LITERT_ERROR, + "thr_invocation_context_attach_input_buffer_sync_fence failed: %d", + status); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +namespace { + +LiteRtStatus GetOutputEvent(LiteRtDispatchInvocationContext invocation_context, + int graph_output_index, LiteRtEvent* output_event) { + auto status_or = invocation_context->graph()->OutputEdge(graph_output_index); + if (!status_or.ok()) { + LITERT_LOG(LITERT_ERROR, "Unexpected graph output index: %d", + graph_output_index); + return kLiteRtStatusErrorInvalidArgument; + } + auto edge_id = *status_or; + + auto thr_invocation_context_get_output_buffer_sync_fence = + TheSouthbound->thr_functions() + .thr_invocation_context_get_output_buffer_sync_fence; + if (!thr_invocation_context_get_output_buffer_sync_fence) { + LITERT_LOG(LITERT_ERROR, + "thr_invocation_context_get_output_buffer_sync_fence not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrInvocationContext* thr_icontext = + invocation_context->thr_invocation_context(); + auto thr_edge_id = ThrEdgeIdStr(edge_id); + int output_fence_fd; + if (auto status = thr_invocation_context_get_output_buffer_sync_fence( + thr_icontext, thr_edge_id.data(), &output_fence_fd); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, + "thr_invocation_context_get_output_buffer_sync_fence failed: %d", + status); + return kLiteRtStatusErrorRuntimeFailure; + } + + if (auto status = LiteRtEventCreateFromSyncFenceFd( + output_fence_fd, /*owns_fd=*/false, output_event); + status != kLiteRtStatusOk) { + return status; + } + + return kLiteRtStatusOk; +} + +} // namespace + +LiteRtStatus InvokeAsync(LiteRtDispatchInvocationContext invocation_context, + int num_output_events, LiteRtEvent* output_events) { + if (num_output_events != invocation_context->graph()->NumOutputs()) { + LITERT_LOG(LITERT_ERROR, "Unexpected number of output events: %d", + num_output_events); + return kLiteRtStatusErrorInvalidArgument; + } + + if (auto status = PrepareForInvoke(invocation_context, + /*create_output_sync_fence=*/true); + status != kLiteRtStatusOk) { + return status; + } + + if (auto status = InvokeOnce(invocation_context); status != kLiteRtStatusOk) { + return status; + } + + for (auto graph_output_index = 0; graph_output_index < num_output_events; + ++graph_output_index) { + if (auto status = GetOutputEvent(invocation_context, graph_output_index, + &output_events[graph_output_index]); + status != kLiteRtStatusOk) { + LITERT_LOG(LITERT_ERROR, "Failed to get event for output %d: %d", + graph_output_index, status); + return kLiteRtStatusErrorRuntimeFailure; + } + } + + return kLiteRtStatusOk; +} + +// ///////////////////////////////////////////////////////////////////////////// +// Graph Execution API +// ///////////////////////////////////////////////////////////////////////////// + +LiteRtStatus GraphCreate(LiteRtDispatchDeviceContext device_context, + LiteRtDispatchGraph* graph) { + auto thr_graph_create = TheSouthbound->thr_functions().thr_graph_create; + if (!thr_graph_create) { + LITERT_LOG(LITERT_ERROR, "thr_graph_create not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrGraph* thr_graph = thr_graph_create(device_context->thr_context()); + if (!thr_graph) { + LITERT_LOG(LITERT_ERROR, "thr_graph_create failed"); + return kLiteRtStatusErrorRuntimeFailure; + } + + *graph = new LiteRtDispatchGraphT(thr_graph, device_context); + return kLiteRtStatusOk; +} + +LiteRtStatus GraphDestroy(LiteRtDispatchGraph graph) { + auto thr_graph_delete = TheSouthbound->thr_functions().thr_graph_delete; + if (!thr_graph_delete) { + LITERT_LOG(LITERT_ERROR, "thr_graph_delete not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + graph->device_context()->remove_graph(graph->thr_graph()); + + ThrGraph* thr_graph = graph->thr_graph(); + if (auto status = thr_graph_delete(thr_graph); status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_graph_destroy failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + delete graph; + return kLiteRtStatusOk; +} + +LiteRtStatus AddNode(LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, + LiteRtDispatchNodeType node_type) { + auto thr_graph_add_sq_node = + TheSouthbound->thr_functions().thr_graph_add_sq_node; + if (!thr_graph_add_sq_node) { + LITERT_LOG(LITERT_ERROR, "thr_graph_add_sq_node not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrGraph* thr_graph = graph->thr_graph(); + auto thr_node_id = ThrNodeIdStr(node_id); + ThrNodeType thr_node_type; + switch (node_type) { + case kLiteRtDispatchNodeTypeDsp: + thr_node_type = kThrNodeTypeDsp; + break; + case kLiteRtDispatchNodeTypeNpu: + thr_node_type = kThrNodeTypeNpu; + break; + default: + LITERT_LOG(LITERT_ERROR, "Unexpected node type: %d", node_type); + return kLiteRtStatusErrorInvalidArgument; + } + + if (auto status = + thr_graph_add_sq_node(thr_graph, thr_node_id.data(), thr_node_type); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_graph_add_sq_node failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus AddEdge(LiteRtDispatchGraph graph, LiteRtDispatchEdgeId edge_id) { + auto thr_graph_add_edge = TheSouthbound->thr_functions().thr_graph_add_edge; + if (!thr_graph_add_edge) { + LITERT_LOG(LITERT_ERROR, "thr_graph_add_edge not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrGraph* thr_graph = graph->thr_graph(); + auto thr_edge_id = ThrEdgeIdStr(edge_id); + ThrEdgeType thr_edge_type = kThrEdgeNoType; + if (auto status = + thr_graph_add_edge(thr_graph, thr_edge_id.data(), thr_edge_type); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_graph_add_edge failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus ConnectNodeInput(LiteRtDispatchGraph graph, + LiteRtDispatchNodeId node_id, int input_index, + LiteRtDispatchEdgeId edge_id) { + auto thr_graph_connect_node_input = + TheSouthbound->thr_functions().thr_graph_connect_node_input; + if (!thr_graph_connect_node_input) { + LITERT_LOG(LITERT_ERROR, "thr_graph_connect_node_input not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + int next_input_index = graph->NextNodeInputIndex(node_id); + if (input_index != next_input_index) { + LITERT_LOG(LITERT_ERROR, "Unexpected input index %d, expected %d", + input_index, next_input_index); + return kLiteRtStatusErrorInvalidArgument; + } + + ThrGraph* thr_graph = graph->thr_graph(); + auto thr_node_id = ThrNodeIdStr(node_id); + auto thr_edge_id = ThrEdgeIdStr(edge_id); + if (auto status = thr_graph_connect_node_input(thr_graph, thr_node_id.data(), + thr_edge_id.data()); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_graph_set_input_edge failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + graph->AddInputEdge(input_index, edge_id); + return kLiteRtStatusOk; +} + +LiteRtStatus ConnectNodeOutput(LiteRtDispatchGraph graph, + LiteRtDispatchNodeId node_id, int output_index, + LiteRtDispatchEdgeId edge_id) { + auto thr_graph_connect_node_output = + TheSouthbound->thr_functions().thr_graph_connect_node_output; + if (!thr_graph_connect_node_output) { + LITERT_LOG(LITERT_ERROR, "thr_graph_connect_node_output not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + int next_output_index = graph->NextNodeOutputIndex(node_id); + if (output_index != next_output_index) { + LITERT_LOG(LITERT_ERROR, "Unexpected output index %d, expected %d", + output_index, next_output_index); + return kLiteRtStatusErrorInvalidArgument; + } + + ThrGraph* thr_graph = graph->thr_graph(); + auto thr_node_id = ThrNodeIdStr(node_id); + auto thr_edge_id = ThrEdgeIdStr(edge_id); + if (auto status = thr_graph_connect_node_output(thr_graph, thr_node_id.data(), + thr_edge_id.data()); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_graph_set_output_edge failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + graph->AddOutputEdge(output_index, edge_id); + return kLiteRtStatusOk; +} + +LiteRtStatus ConnectGraphInput(LiteRtDispatchGraph graph, int input_index, + LiteRtDispatchEdgeId edge_id) { + int next_input_index = graph->NextGraphInputIndex(); + if (input_index != next_input_index) { + LITERT_LOG(LITERT_ERROR, "Unexpected input index %d, expected %d", + input_index, next_input_index); + return kLiteRtStatusErrorInvalidArgument; + } + + auto thr_graph_set_input_edge = + TheSouthbound->thr_functions().thr_graph_set_input_edge; + if (!thr_graph_set_input_edge) { + LITERT_LOG(LITERT_ERROR, "thr_graph_set_input_edge not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrGraph* thr_graph = graph->thr_graph(); + auto thr_edge_id = ThrEdgeIdStr(edge_id); + if (auto status = thr_graph_set_input_edge(thr_graph, thr_edge_id.data()); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_graph_set_input_edge failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus ConnectGraphOutput(LiteRtDispatchGraph graph, int output_index, + LiteRtDispatchEdgeId edge_id) { + int next_output_index = graph->NextGraphOutputIndex(); + if (output_index != next_output_index) { + LITERT_LOG(LITERT_ERROR, "Unexpected output index %d, expected %d", + output_index, next_output_index); + return kLiteRtStatusErrorInvalidArgument; + } + + auto thr_graph_set_output_edge = + TheSouthbound->thr_functions().thr_graph_set_output_edge; + if (!thr_graph_set_output_edge) { + LITERT_LOG(LITERT_ERROR, "thr_graph_set_output_edge not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrGraph* thr_graph = graph->thr_graph(); + auto thr_edge_id = ThrEdgeIdStr(edge_id); + if (auto status = thr_graph_set_output_edge(thr_graph, thr_edge_id.data()); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_graph_set_output_edge failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus LoadExecutable(LiteRtDispatchDeviceContext device_context, + LiteRtDispatchExecutableType type, + const void* bytecode, size_t bytecode_size, + LiteRtDispatchExecutableHandle* exec_handle) { + auto thr_load_sq_container = + TheSouthbound->thr_functions().thr_load_sq_container; + if (!thr_load_sq_container) { + LITERT_LOG(LITERT_ERROR, "thr_load_sq_container not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrSqContainerType thr_type; + switch (type) { + case kLiteRtDispatchExecutableTypeDspLibrary: + thr_type = kThrSqContainerTypeFunctionLibrary; + break; + case kLiteRtDispatchExecutableTypeMlModel: + thr_type = kThrSqContainerTypeMlModel; + break; + default: + LITERT_LOG(LITERT_ERROR, "Unexpected executable type: %d", type); + return kLiteRtStatusErrorInvalidArgument; + } + + ThrContext* thr_context = device_context->thr_context(); + ThrSqContainerHandle sq_handle; + if (auto status = thr_load_sq_container(thr_context, thr_type, bytecode, + bytecode_size, &sq_handle); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_load_sq_container failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + *exec_handle = sq_handle; + return kLiteRtStatusOk; +} + +LiteRtStatus UnloadExecutable(LiteRtDispatchDeviceContext device_context, + LiteRtDispatchExecutableHandle exec_handle) { + auto thr_unload_sq_container = + TheSouthbound->thr_functions().thr_unload_sq_container; + if (!thr_unload_sq_container) { + LITERT_LOG(LITERT_ERROR, "thr_unload_sq_container not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrContext* thr_context = device_context->thr_context(); + ThrSqContainerHandle sq_handle = exec_handle; + if (auto status = thr_unload_sq_container(thr_context, sq_handle); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_unload_sq_container failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus AssignNodeFunction(LiteRtDispatchGraph graph, + LiteRtDispatchNodeId node_id, + LiteRtDispatchExecutableHandle exec_handle, + const char* function_name) { + auto thr_graph_assign_sq = TheSouthbound->thr_functions().thr_graph_assign_sq; + if (!thr_graph_assign_sq) { + LITERT_LOG(LITERT_ERROR, "thr_graph_assign_sq not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrGraph* thr_graph = graph->thr_graph(); + auto thr_node_id = ThrNodeIdStr(node_id); + ThrSqContainerHandle sq_handle = exec_handle; + if (auto status = thr_graph_assign_sq(thr_graph, thr_node_id.data(), + sq_handle, function_name); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_graph_assign_sq failed: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus AnnotateGraph(LiteRtDispatchGraph graph, const char* key, + const char* value) { + auto thr_graph_annotate_graph = + TheSouthbound->thr_functions().thr_graph_annotate_graph; + if (!thr_graph_annotate_graph) { + LITERT_LOG(LITERT_ERROR, "thr_graph_annotate_graph not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrGraph* thr_graph = graph->thr_graph(); + if (auto status = thr_graph_annotate_graph(thr_graph, key, value); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_graph_annotate_graph failed"); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus AnnotateNode(LiteRtDispatchGraph graph, + LiteRtDispatchNodeId node_id, const char* key, + const char* value) { + auto thr_graph_annotate_node = + TheSouthbound->thr_functions().thr_graph_annotate_node; + if (!thr_graph_annotate_node) { + LITERT_LOG(LITERT_ERROR, "thr_graph_annotate_node not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrGraph* thr_graph = graph->thr_graph(); + auto thr_node_id = ThrNodeIdStr(node_id); + if (auto status = + thr_graph_annotate_node(thr_graph, thr_node_id.data(), key, value); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_graph_annotate_node failed"); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus AnnotateEdge(LiteRtDispatchGraph graph, + LiteRtDispatchEdgeId edge_id, const char* key, + const char* value) { + auto thr_graph_annotate_edge = + TheSouthbound->thr_functions().thr_graph_annotate_edge; + if (!thr_graph_annotate_edge) { + LITERT_LOG(LITERT_ERROR, "thr_graph_annotate_edge not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrGraph* thr_graph = graph->thr_graph(); + auto thr_edge_id = ThrEdgeIdStr(edge_id); + if (auto status = + thr_graph_annotate_edge(thr_graph, thr_edge_id.data(), key, value); + status != kThrStatusSuccess) { + LITERT_LOG(LITERT_ERROR, "thr_graph_annotate_edge failed"); + return kLiteRtStatusErrorRuntimeFailure; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus InvocationContextCreateFromGraph( + LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph graph, + LiteRtDispatchInvocationContext* invocation_context) { + auto thr_invocation_context_get = + TheSouthbound->thr_functions().thr_invocation_context_get; + if (!thr_invocation_context_get) { + LITERT_LOG(LITERT_ERROR, "thr_invocation_context_get not found"); + return kLiteRtStatusErrorRuntimeFailure; + } + + ThrGraph* thr_graph = graph->thr_graph(); + auto thr_icontext = + thr_invocation_context_get(thr_graph, device_context->thr_context()); + if (!thr_icontext) { + LITERT_LOG(LITERT_ERROR, "thr_invocation_context_get failed"); + return kLiteRtStatusErrorRuntimeFailure; + } + + device_context->add_graph(thr_graph); + *invocation_context = + new LiteRtDispatchInvocationContextT(thr_icontext, device_context, graph); + + return kLiteRtStatusOk; +} + +} // namespace google_tensor +} // namespace litert + +// ///////////////////////////////////////////////////////////////////////////// + +namespace { + +LiteRtDispatchInterface TheInterface = { + .initialize = litert::google_tensor::Initialize, + .get_vendor_id = litert::google_tensor::GetVendorId, + .get_build_id = litert::google_tensor::GetBuildId, + .get_capabilities = litert::google_tensor::GetCapabilities, + .device_context_create = litert::google_tensor::DeviceContextCreate, + .device_context_destroy = litert::google_tensor::DeviceContextDestroy, + .get_input_requirements = litert::google_tensor::GetInputRequirements, + .get_output_requirements = litert::google_tensor::GetOutputRequirements, + .register_tensor_buffer = litert::google_tensor::RegisterTensorBuffer, + .unregister_tensor_buffer = litert::google_tensor::UnregisterTensorBuffer, + .invocation_context_create = litert::google_tensor::InvocationContextCreate, + .invocation_context_destroy = + litert::google_tensor::InvocationContextDestroy, + .attach_input = litert::google_tensor::AttachInput, + .attach_output = litert::google_tensor::AttachOutput, + .detach_input = litert::google_tensor::DetachInput, + .detach_output = litert::google_tensor::DetachOutput, + .invoke = litert::google_tensor::Invoke, +}; + +LiteRtDispatchAsyncInterface TheAsyncInterface = { + .attach_input_event = litert::google_tensor::AttachInputEvent, + .invoke_async = litert::google_tensor::InvokeAsync, +}; + +LiteRtDispatchGraphInterface TheGraphInterface = { + .graph_create = litert::google_tensor::GraphCreate, + .graph_destroy = litert::google_tensor::GraphDestroy, + .add_node = litert::google_tensor::AddNode, + .add_edge = litert::google_tensor::AddEdge, + .connect_node_input = litert::google_tensor::ConnectNodeInput, + .connect_node_output = litert::google_tensor::ConnectNodeOutput, + .connect_graph_input = litert::google_tensor::ConnectGraphInput, + .connect_graph_output = litert::google_tensor::ConnectGraphOutput, + .load_executable = litert::google_tensor::LoadExecutable, + .unload_executable = litert::google_tensor::UnloadExecutable, + .assign_node_function = litert::google_tensor::AssignNodeFunction, + .annotate_graph = litert::google_tensor::AnnotateGraph, + .annotate_node = litert::google_tensor::AnnotateNode, + .annotate_edge = litert::google_tensor::AnnotateEdge, + .invocation_context_create_from_graph = + litert::google_tensor::InvocationContextCreateFromGraph, +}; + +LiteRtDispatchApi TheApi = { + .version = {.major = VERSION_MAJOR, + .minor = VERSION_MINOR, + .patch = VERSION_PATCH}, + .interface = &TheInterface, + .async_interface = &TheAsyncInterface, + .graph_interface = &TheGraphInterface, +}; + +} // namespace + +LiteRtStatus LiteRtDispatchGetApi(LiteRtDispatchApi* api) { + *api = TheApi; + return kLiteRtStatusOk; +} diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.h b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.h new file mode 100644 index 00000000000000..00e0559c085d91 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.h @@ -0,0 +1,67 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_DISPATCH_API_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_DISPATCH_API_H_ + +#include + +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" + +namespace litert { +namespace google_tensor { + +LiteRtStatus GraphCreate(LiteRtDispatchDeviceContext device_context, + LiteRtDispatchGraph* graph); +LiteRtStatus GraphDestroy(LiteRtDispatchGraph graph); +LiteRtStatus AddNode(LiteRtDispatchGraph graph, LiteRtDispatchNodeId node_id, + LiteRtDispatchNodeType node_type); +LiteRtStatus AddEdge(LiteRtDispatchGraph graph, LiteRtDispatchEdgeId edge_id); +LiteRtStatus ConnectNodeInput(LiteRtDispatchGraph graph, + LiteRtDispatchNodeId node_id, int input_index, + LiteRtDispatchEdgeId edge_id); +LiteRtStatus ConnectNodeOutput(LiteRtDispatchGraph graph, + LiteRtDispatchNodeId node_id, int output_index, + LiteRtDispatchEdgeId edge_id); +LiteRtStatus ConnectGraphInput(LiteRtDispatchGraph graph, int input_index, + LiteRtDispatchEdgeId edge_id); +LiteRtStatus ConnectGraphOutput(LiteRtDispatchGraph graph, int output_index, + LiteRtDispatchEdgeId edge_id); +LiteRtStatus LoadExecutable(LiteRtDispatchDeviceContext device_context, + LiteRtDispatchExecutableType type, + const void* bytecode, size_t bytecode_size, + LiteRtDispatchExecutableHandle* exec_handle); +LiteRtStatus UnloadExecutable(LiteRtDispatchDeviceContext device_context, + LiteRtDispatchExecutableHandle exec_handle); +LiteRtStatus AssignNodeFunction(LiteRtDispatchGraph graph, + LiteRtDispatchNodeId node_id, + LiteRtDispatchExecutableHandle exec_handle, + const char* function_name); +LiteRtStatus AnnotateGraph(LiteRtDispatchGraph graph, const char* key, + const char* value); +LiteRtStatus AnnotateNode(LiteRtDispatchGraph graph, + LiteRtDispatchNodeId node_id, const char* key, + const char* value); +LiteRtStatus AnnotateEdge(LiteRtDispatchGraph graph, + LiteRtDispatchEdgeId edge_id, const char* key, + const char* value); +LiteRtStatus InvocationContextCreateFromGraph( + LiteRtDispatchDeviceContext device_context, LiteRtDispatchGraph graph, + LiteRtDispatchInvocationContext* invocation_context); + +} // namespace google_tensor +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_DISPATCH_API_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_google_tensor_test.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_google_tensor_test.cc new file mode 100644 index 00000000000000..536015b10f7685 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api_google_tensor_test.cc @@ -0,0 +1,277 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include +#include "absl/log/absl_log.h" +#include "absl/log/log.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" +#include "tensorflow/lite/experimental/litert/test/common.h" +#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" + +TEST(DispatchApi, GoogleTensor) { +#if !defined(__ANDROID__) + GTEST_SKIP() + << "This test is specific to Android devices with a GoogleTensor eTPU"; +#endif + + EXPECT_EQ(LiteRtDispatchInitialize(/*options=*/nullptr, /*num_options=*/0), + kLiteRtStatusOk); + + const char* vendor_id; + EXPECT_EQ(LiteRtDispatchGetVendorId(&vendor_id), kLiteRtStatusOk); + ABSL_LOG(INFO) << "vendor_id: " << vendor_id; + + const char* build_id; + EXPECT_EQ(LiteRtDispatchGetBuildId(&build_id), kLiteRtStatusOk); + ABSL_LOG(INFO) << "build_id: " << build_id; + + LiteRtDispatchApiVersion api_version; + EXPECT_EQ(LiteRtDispatchGetApiVersion(&api_version), kLiteRtStatusOk); + ABSL_LOG(INFO) << "api_version: " << api_version.major << "." + << api_version.minor << "." << api_version.patch; + + int capabilities; + EXPECT_EQ(LiteRtDispatchGetCapabilities(&capabilities), kLiteRtStatusOk); + ABSL_LOG(INFO) << "capabilities: " << capabilities; + + LiteRtDispatchDeviceContext device_context = nullptr; + EXPECT_EQ(LiteRtDispatchDeviceContextCreate(&device_context), + kLiteRtStatusOk); + ABSL_LOG(INFO) << "device_context: " << device_context; + + auto model_file_name = kGoogleTensorModelFileName; + auto model = litert::testing::LoadBinaryFile(model_file_name); + EXPECT_TRUE(model.ok()); + ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->size() + << " bytes"; + + // /////////////////////////////////////////////////////////////////////////// + // Set up an invocation context for a given model. + // /////////////////////////////////////////////////////////////////////////// + + LiteRtDispatchInvocationContext invocation_context = nullptr; + EXPECT_EQ(LiteRtDispatchInvocationContextCreate( + device_context, kLiteRtDispatchExecutableTypeMlModel, + model->data(), model->size(), /*function_name=*/nullptr, + /*num_inputs=*/2, /*num_outputs=*/1, &invocation_context), + kLiteRtStatusOk); + ABSL_LOG(INFO) << "Invocation context: " << invocation_context; + + // /////////////////////////////////////////////////////////////////////////// + // Determine tensor buffer requirements. + // /////////////////////////////////////////////////////////////////////////// + + int num_tensor_buffer_types; + LiteRtTensorBufferRequirements input_0_tensor_buffer_requirements; + EXPECT_EQ(LiteRtDispatchGetInputRequirements( + invocation_context, /*input_index=*/0, &kInput0TensorType, + &input_0_tensor_buffer_requirements), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtGetTensorBufferRequirementsNumSupportedTensorBufferTypes( + input_0_tensor_buffer_requirements, &num_tensor_buffer_types), + kLiteRtStatusOk); + EXPECT_GE(num_tensor_buffer_types, 1); + LiteRtTensorBufferType input_0_tensor_buffer_type; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + input_0_tensor_buffer_requirements, /*type_index=*/0, + &input_0_tensor_buffer_type), + kLiteRtStatusOk); + EXPECT_EQ(input_0_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); + size_t input_0_tensor_buffer_size; + EXPECT_EQ( + LiteRtGetTensorBufferRequirementsBufferSize( + input_0_tensor_buffer_requirements, &input_0_tensor_buffer_size), + kLiteRtStatusOk); + EXPECT_GE(input_0_tensor_buffer_size, sizeof(kTestInput0Tensor)); + + LiteRtTensorBufferRequirements input_1_tensor_buffer_requirements; + EXPECT_EQ(LiteRtDispatchGetInputRequirements( + invocation_context, /*input_index=*/1, &kInput1TensorType, + &input_1_tensor_buffer_requirements), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtGetTensorBufferRequirementsNumSupportedTensorBufferTypes( + input_1_tensor_buffer_requirements, &num_tensor_buffer_types), + kLiteRtStatusOk); + EXPECT_GE(num_tensor_buffer_types, 1); + LiteRtTensorBufferType input_1_tensor_buffer_type; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + input_1_tensor_buffer_requirements, /*type_index=*/0, + &input_1_tensor_buffer_type), + kLiteRtStatusOk); + EXPECT_EQ(input_1_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); + size_t input_1_tensor_buffer_size; + EXPECT_EQ( + LiteRtGetTensorBufferRequirementsBufferSize( + input_1_tensor_buffer_requirements, &input_1_tensor_buffer_size), + kLiteRtStatusOk); + EXPECT_GE(input_1_tensor_buffer_size, sizeof(kTestInput1Tensor)); + + LiteRtTensorBufferRequirements output_tensor_buffer_requirements; + EXPECT_EQ(LiteRtDispatchGetOutputRequirements( + invocation_context, /*output_index=*/0, &kOutputTensorType, + &output_tensor_buffer_requirements), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtGetTensorBufferRequirementsNumSupportedTensorBufferTypes( + output_tensor_buffer_requirements, &num_tensor_buffer_types), + kLiteRtStatusOk); + EXPECT_GE(num_tensor_buffer_types, 1); + LiteRtTensorBufferType output_tensor_buffer_type; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + output_tensor_buffer_requirements, /*type_index=*/0, + &output_tensor_buffer_type), + kLiteRtStatusOk); + EXPECT_EQ(output_tensor_buffer_type, kLiteRtTensorBufferTypeAhwb); + size_t output_tensor_buffer_size; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( + output_tensor_buffer_requirements, &output_tensor_buffer_size), + kLiteRtStatusOk); + EXPECT_GE(output_tensor_buffer_size, sizeof(kTestOutputTensor)); + + // /////////////////////////////////////////////////////////////////////////// + // Allocate tensor buffers. + // /////////////////////////////////////////////////////////////////////////// + + LiteRtTensorBuffer input_0_tensor_buffer; + EXPECT_EQ(LiteRtCreateManagedTensorBuffer( + input_0_tensor_buffer_type, &kInput0TensorType, + input_0_tensor_buffer_size, &input_0_tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBuffer input_1_tensor_buffer; + EXPECT_EQ(LiteRtCreateManagedTensorBuffer( + input_1_tensor_buffer_type, &kInput1TensorType, + input_1_tensor_buffer_size, &input_1_tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBuffer output_tensor_buffer; + EXPECT_EQ(LiteRtCreateManagedTensorBuffer( + output_tensor_buffer_type, &kOutputTensorType, + output_tensor_buffer_size, &output_tensor_buffer), + kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Register tensor buffers. + // /////////////////////////////////////////////////////////////////////////// + + LiteRtTensorBufferHandle input_1_handle; + EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( + device_context, input_1_tensor_buffer, &input_1_handle), + kLiteRtStatusOk); + + LiteRtTensorBufferHandle input_0_handle; + EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( + device_context, input_0_tensor_buffer, &input_0_handle), + kLiteRtStatusOk); + + LiteRtTensorBufferHandle output_handle; + EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( + device_context, output_tensor_buffer, &output_handle), + kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Attach tensor buffers. + // /////////////////////////////////////////////////////////////////////////// + + EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, + /*graph_input_index=*/0, input_0_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, + /*graph_input_index=*/1, input_1_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchAttachOutput(invocation_context, + /*graph_output_index=*/0, output_handle), + kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Fill the input buffers with data. + // /////////////////////////////////////////////////////////////////////////// + + { + ABSL_LOG(INFO) << "Filling inputs with data"; + void* host_mem_addr; + + ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); + + ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); + } + + // /////////////////////////////////////////////////////////////////////////// + // Execute model. + // /////////////////////////////////////////////////////////////////////////// + + ABSL_LOG(INFO) << "Invoking execution..."; + EXPECT_EQ(LiteRtDispatchInvoke(invocation_context), kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Check output for correctness. + // /////////////////////////////////////////////////////////////////////////// + + { + ABSL_LOG(INFO) << "Checking output..."; + void* host_mem_addr; + ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + auto* output = static_cast(host_mem_addr); + constexpr auto output_size = + sizeof(kTestOutputTensor) / sizeof(kTestOutputTensor[0]); + for (auto i = 0; i < output_size; ++i) { + EXPECT_NEAR(output[i], kTestOutputTensor[i], 1e-3); + } + ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk); + } + + // /////////////////////////////////////////////////////////////////////////// + // Clean up resources. + // /////////////////////////////////////////////////////////////////////////// + + EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, + /*graph_input_index=*/0, input_0_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, + /*graph_input_index=*/1, input_1_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchDetachOutput(invocation_context, + /*graph_output_index=*/0, output_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchUnregisterTensorBuffer(device_context, output_handle), + kLiteRtStatusOk); + EXPECT_EQ( + LiteRtDispatchUnregisterTensorBuffer(device_context, input_1_handle), + kLiteRtStatusOk); + EXPECT_EQ( + LiteRtDispatchUnregisterTensorBuffer(device_context, input_0_handle), + kLiteRtStatusOk); + LiteRtDestroyTensorBuffer(output_tensor_buffer); + LiteRtDestroyTensorBuffer(input_1_tensor_buffer); + LiteRtDestroyTensorBuffer(input_0_tensor_buffer); + EXPECT_EQ(LiteRtDispatchInvocationContextDestroy(invocation_context), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchDeviceContextDestroy(device_context), + kLiteRtStatusOk); +} diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.cc new file mode 100644 index 00000000000000..ffe528b30c80a7 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.cc @@ -0,0 +1,58 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" + +LiteRtDispatchDeviceContextT::~LiteRtDispatchDeviceContextT() { + if (!thr_graphs_.empty()) { + auto thr_graph_delete = southbound_.thr_functions().thr_graph_delete; + if (!thr_graph_delete) { + LITERT_LOG(LITERT_ERROR, "thr_graph_delete not found"); + } else { + for (auto* thr_graph : thr_graphs_) { + thr_graph_delete(thr_graph); + } + } + } + + if (thr_context_) { + auto thr_context_delete = southbound_.thr_functions().thr_context_delete; + if (!thr_context_delete) { + LITERT_LOG(LITERT_ERROR, "thr_context_delete not found"); + } else { + thr_context_delete(thr_context_); + } + } +} + +absl::StatusOr> +LiteRtDispatchDeviceContextT::Create( + const litert::google_tensor::Southbound& southbound) { + std::unique_ptr device_context( + new LiteRtDispatchDeviceContextT(southbound)); + + auto thr_context_create = southbound.thr_functions().thr_context_create; + if (!thr_context_create) { + return absl::InternalError("thr_context_create not found"); + } + + device_context->thr_context_ = thr_context_create(); + return device_context; +} diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.h b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.h new file mode 100644 index 00000000000000..7617b5c0e681cf --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_device_context.h @@ -0,0 +1,47 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "third_party/odml/infra/southbound/sb_api.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" +#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h" + +class LiteRtDispatchDeviceContextT { + public: + ~LiteRtDispatchDeviceContextT(); + + static absl::StatusOr> Create( + const litert::google_tensor::Southbound& southbound); + + ThrContext* thr_context() { return thr_context_; } + void add_graph(ThrGraph* graph) { thr_graphs_.insert(graph); } + void remove_graph(ThrGraph* graph) { thr_graphs_.erase(graph); } + + private: + explicit LiteRtDispatchDeviceContextT( + const litert::google_tensor::Southbound& southbound) + : southbound_(southbound) {} + + const litert::google_tensor::Southbound& southbound_; + ThrContext* thr_context_ = nullptr; + absl::flat_hash_set thr_graphs_; +}; + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.h b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.h new file mode 100644 index 00000000000000..c63e1a5a7954ab --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_graph.h @@ -0,0 +1,93 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_GRAPH_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_GRAPH_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "third_party/odml/infra/southbound/sb_api.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" + +class LiteRtDispatchGraphT { + public: + LiteRtDispatchGraphT(ThrGraph* thr_graph, + LiteRtDispatchDeviceContext device_context) + : thr_graph_(thr_graph), device_context_(device_context) {} + + ThrGraph* thr_graph() { return thr_graph_; } + + LiteRtDispatchDeviceContext device_context() { return device_context_; } + + int NextNodeInputIndex(LiteRtDispatchNodeId node_id) { + return NextNodeIoIndex(node_id, next_node_input_index_); + } + + int NextNodeOutputIndex(LiteRtDispatchNodeId node_id) { + return NextNodeIoIndex(node_id, next_node_output_index_); + } + + int NextGraphInputIndex() { return next_graph_input_index_++; } + + int NextGraphOutputIndex() { return next_graph_output_index_++; } + + void AddInputEdge(int input_index, LiteRtDispatchEdgeId edge_id) { + input_edges_[input_index] = edge_id; + } + + void AddOutputEdge(int output_index, LiteRtDispatchEdgeId edge_id) { + output_edges_[output_index] = edge_id; + } + + absl::StatusOr InputEdge(int input_index) const { + return IoEdge(input_index, input_edges_); + } + + absl::StatusOr OutputEdge(int output_index) const { + return IoEdge(output_index, output_edges_); + } + + size_t NumOutputs() const { return output_edges_.size(); } + + private: + using NextNodeIoIndexMap = std::map; + using IoIndexToEdgeIdMap = std::map; + + int NextNodeIoIndex(LiteRtDispatchNodeId node_id, NextNodeIoIndexMap& map) { + return map[node_id]++; + } + + absl::StatusOr IoEdge( + int io_index, const IoIndexToEdgeIdMap& map) const { + auto iter = map.find(io_index); + if (iter == map.end()) { + return absl::NotFoundError("Unexpected graph input/output index"); + } + return iter->second; + } + + ThrGraph* thr_graph_; + LiteRtDispatchDeviceContext device_context_; + NextNodeIoIndexMap next_node_input_index_; + NextNodeIoIndexMap next_node_output_index_; + int next_graph_input_index_ = 0; + int next_graph_output_index_ = 0; + IoIndexToEdgeIdMap input_edges_; + IoIndexToEdgeIdMap output_edges_; +}; + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_GRAPH_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.cc new file mode 100644 index 00000000000000..d3d029e6e9713c --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.cc @@ -0,0 +1,80 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" +#include "tensorflow/lite/experimental/litert/core/utils.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" + +namespace { + +constexpr const size_t kEdgeTpuPadding = 64; + +inline constexpr auto Pad(auto x, auto align) { + return ((x + align - 1) / align) * align; +} + +absl::StatusOr GetTensorBufferRequirements( + const LiteRtRankedTensorType& tensor_type) { + auto* tensor_strides = tensor_type.layout.strides; + if (tensor_strides != nullptr) { + return absl::InternalError( + "Tensor strides are not supported on GoogleTensor"); + } + + LiteRtTensorBufferType supported_tensor_buffer_types[] = { + kLiteRtTensorBufferTypeAhwb, + }; + int num_supported_tensor_buffer_types = + sizeof(supported_tensor_buffer_types) / + sizeof(supported_tensor_buffer_types[0]); + + auto buffer_size = litert::internal::GetNumPackedBytes(tensor_type); + if (!buffer_size.ok()) { + return buffer_size.status(); + } + + size_t padded_buffer_size = Pad(*buffer_size, kEdgeTpuPadding); + + LiteRtTensorBufferRequirements requirements; + if (auto status = LiteRtCreateTensorBufferRequirements( + num_supported_tensor_buffer_types, supported_tensor_buffer_types, + padded_buffer_size, &requirements); + status != kLiteRtStatusOk) { + return absl::InternalError("Not implemented"); + } + + return requirements; +} +} // namespace + +absl::StatusOr +LiteRtDispatchInvocationContextT::GetInputRequirements( + int input_index, const LiteRtRankedTensorType& tensor_type) { + return GetTensorBufferRequirements(tensor_type); +} + +absl::StatusOr +LiteRtDispatchInvocationContextT::GetOutputRequirements( + int output_index, const LiteRtRankedTensorType& tensor_type) { + return GetTensorBufferRequirements(tensor_type); +} diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.h b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.h new file mode 100644 index 00000000000000..996e4a4614bafb --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/litert_dispatch_invocation_context.h @@ -0,0 +1,66 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_ + +#include + +#include "absl/status/statusor.h" +#include "third_party/odml/infra/southbound/sb_api.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" +#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/dispatch_api.h" + +class LiteRtDispatchInvocationContextT { + public: + LiteRtDispatchInvocationContextT(ThrInvocationContext* thr_invocation_context, + LiteRtDispatchDeviceContext device_context, + LiteRtDispatchGraph graph) + : thr_invocation_context_(thr_invocation_context), + device_context_(device_context), + graph_(graph) {} + + ~LiteRtDispatchInvocationContextT() { + if (exec_handle_) { + litert::google_tensor::UnloadExecutable(device_context_, *exec_handle_); + } + } + + absl::StatusOr GetInputRequirements( + int input_index, const LiteRtRankedTensorType& tensor_type); + absl::StatusOr GetOutputRequirements( + int output_index, const LiteRtRankedTensorType& tensor_type); + + ThrInvocationContext* thr_invocation_context() { + return thr_invocation_context_; + } + + LiteRtDispatchDeviceContext device_context() { return device_context_; } + + LiteRtDispatchGraph graph() { return graph_; } + + void AttachExecutable(LiteRtDispatchExecutableHandle exec_handle) { + exec_handle_ = exec_handle; + } + + private: + ThrInvocationContext* thr_invocation_context_; + LiteRtDispatchDeviceContext device_context_; + LiteRtDispatchGraph graph_; + std::optional exec_handle_; +}; + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.cc b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.cc new file mode 100644 index 00000000000000..04ef1159cf035e --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.cc @@ -0,0 +1,149 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h" + +#include + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "third_party/odml/infra/southbound/sb_api.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" + +#define Load(H, S) \ + H = reinterpret_cast(::dlsym(dlib_handle_, #S)); \ + if (!H) { \ + LITERT_LOG(LITERT_WARNING, "Failed to load symbol %s: %s", #S, \ + ::dlerror()); \ + } + +namespace litert { +namespace google_tensor { + +namespace { +// Currently the SouthBound implementation is bundled inside the Edge TPU +// runtime shared library. +constexpr const char* kSouthBoundLibPath = "/vendor/lib64/libedgetpu_util.so"; +} // namespace + +Southbound::Southbound() : thr_functions_(new ThrFunctions) {} + +Southbound::~Southbound() { + if (dlib_handle_) { + ::dlclose(dlib_handle_); + } +} + +absl::StatusOr> Southbound::Create( + std::optional shared_library_dir) { + std::unique_ptr southbound(new Southbound); + if (auto status = southbound->LoadSymbols(shared_library_dir); !status.ok()) { + return status; + } + + return southbound; +} + +absl::Status Southbound::LoadSymbols( + std::optional shared_library_dir) { + // Always load the Southbound API library from the vendor partition. + (void)shared_library_dir; + + dlib_handle_ = ::dlopen(kSouthBoundLibPath, RTLD_NOW | RTLD_LOCAL); + if (!dlib_handle_) { + return absl::InternalError("Failed to load Southbound shared library"); + } + + // Binds all supported symbols from the shared library to the function + // pointers. + Load(thr_functions_->thr_initialize, thrInitialize); + + Load(thr_functions_->thr_get_vendor_api_version, thrGetVendorApiVersion); + Load(thr_functions_->thr_get_vendor_id, thrGetVendorId); + + Load(thr_functions_->thr_context_create, thrContextCreate); + Load(thr_functions_->thr_context_delete, thrContextDelete); + + Load(thr_functions_->thr_graph_create, thrGraphCreate); + Load(thr_functions_->thr_graph_delete, thrGraphDelete); + + Load(thr_functions_->thr_graph_add_edge, thrGraphAddEdge); + Load(thr_functions_->thr_graph_add_sq_node, thrGraphAddSqNode); + + Load(thr_functions_->thr_graph_connect_node_input, thrGraphConnectNodeInput); + Load(thr_functions_->thr_graph_connect_node_output, + thrGraphConnectNodeOutput); + + Load(thr_functions_->thr_graph_set_input_edge, thrGraphSetInputEdge); + Load(thr_functions_->thr_graph_set_output_edge, thrGraphSetOutputEdge); + + Load(thr_functions_->thr_graph_annotate_graph, thrGraphAnnotateGraph); + Load(thr_functions_->thr_graph_annotate_edge, thrGraphAnnotateEdge); + Load(thr_functions_->thr_graph_annotate_node, thrGraphAnnotateNode); + + Load(thr_functions_->thr_load_sq_container, thrLoadSqContainer); + Load(thr_functions_->thr_load_sq_container_fd, thrLoadSqContainerFd); + Load(thr_functions_->thr_load_sq_container_file, thrLoadSqContainerFile); + Load(thr_functions_->thr_unload_sq_container, thrUnloadSqContainer); + + Load(thr_functions_->thr_graph_assign_sq, thrGraphAssignSq); + Load(thr_functions_->thr_sq_query_scratch_pad, thrSqQueryScratchPad); + Load(thr_functions_->thr_sq_attach_scratch_pad_buffer, + thrSqAttachScratchPadBuffer); + + Load(thr_functions_->thr_register_buffer, thrRegisterBuffer); + Load(thr_functions_->thr_register_buffer_with_offset, + thrRegisterBufferWithOffset); + Load(thr_functions_->thr_unregister_buffer, thrUnregisterBuffer); + + Load(thr_functions_->thr_invocation_context_get, thrInvocationContextGet); + Load(thr_functions_->thr_invocation_context_delete, + thrInvocationContextDelete); + + Load(thr_functions_->thr_invocation_context_attach_buffer, + thrInvocationContextAttachBuffer); + Load(thr_functions_->thr_invocation_context_detach_buffer, + thrInvocationContextDetachBuffer); + + Load(thr_functions_->thr_invocation_context_prepare_for_invoke, + thrInvocationContextPrepareForInvoke); + Load(thr_functions_->thr_invocation_context_invoke_once, + thrInvocationContextInvokeOnce); + Load(thr_functions_->thr_invocation_context_wait, thrInvocationContextWait); + + Load(thr_functions_->thr_invocation_context_attach_input_buffer_sync_fence, + thrInvocationContextAttachInputBufferSyncFence); + Load(thr_functions_->thr_invocation_context_get_output_buffer_sync_fence, + thrInvocationContextGetOutputBufferSyncFence); + + Load(thr_functions_->thr_invocation_context_query_node_scratch_pad, + thrInvocationContextQueryNodeScratchPad); + Load(thr_functions_->thr_invocation_context_attach_scratch_pad_buffer, + thrInvocationContextAttachScratchPadBuffer); + + Load(thr_functions_->thr_vendor_set_system_attribute_str, + thrVendorSetSystemAttributeStr); + Load(thr_functions_->thr_vendor_set_system_attribute_int64, + thrVendorSetSystemAttributeInt64); + + LITERT_LOG(LITERT_INFO, "SouthBound symbols loaded"); + return {}; +} + +} // namespace google_tensor +} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h new file mode 100644 index 00000000000000..d6d5ebcf6789aa --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/google_tensor/dispatch/southbound.h @@ -0,0 +1,129 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_SOUTHBOUND_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_SOUTHBOUND_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "third_party/odml/infra/southbound/sb_api.h" + +namespace litert { +namespace google_tensor { + +class Southbound { + public: + struct ThrFunctions; + + Southbound(Southbound&) = delete; + Southbound(Southbound&&) = delete; + Southbound& operator=(const Southbound&) = delete; + Southbound& operator=(Southbound&&) = delete; + + ~Southbound(); + + static absl::StatusOr> Create( + std::optional shared_library_dir); + + const ThrFunctions& thr_functions() const { return *thr_functions_; } + + private: + Southbound(); + absl::Status LoadSymbols(std::optional shared_library_dir); + + void* dlib_handle_ = nullptr; + std::unique_ptr thr_functions_; +}; + +// A convenient struct for holding function pointers to SouthBound symbols. +// These function pointers will be loaded to the shared library on device during +// runtime. +struct Southbound::ThrFunctions { + decltype(&thrInitialize) thr_initialize = nullptr; + + decltype(&thrGetVendorApiVersion) thr_get_vendor_api_version = nullptr; + decltype(&thrGetVendorId) thr_get_vendor_id = nullptr; + + decltype(&thrContextCreate) thr_context_create = nullptr; + decltype(&thrContextDelete) thr_context_delete = nullptr; + + decltype(&thrGraphCreate) thr_graph_create = nullptr; + decltype(&thrGraphDelete) thr_graph_delete = nullptr; + + decltype(&thrGraphAddEdge) thr_graph_add_edge = nullptr; + decltype(&thrGraphAddSqNode) thr_graph_add_sq_node = nullptr; + + decltype(&thrGraphConnectNodeInput) thr_graph_connect_node_input = nullptr; + decltype(&thrGraphConnectNodeOutput) thr_graph_connect_node_output = nullptr; + + decltype(&thrGraphSetInputEdge) thr_graph_set_input_edge = nullptr; + decltype(&thrGraphSetOutputEdge) thr_graph_set_output_edge = nullptr; + + decltype(&thrGraphAnnotateGraph) thr_graph_annotate_graph = nullptr; + decltype(&thrGraphAnnotateEdge) thr_graph_annotate_edge = nullptr; + decltype(&thrGraphAnnotateNode) thr_graph_annotate_node = nullptr; + + decltype(&thrLoadSqContainer) thr_load_sq_container = nullptr; + decltype(&thrLoadSqContainerFd) thr_load_sq_container_fd = nullptr; + decltype(&thrLoadSqContainerFile) thr_load_sq_container_file = nullptr; + decltype(&thrUnloadSqContainer) thr_unload_sq_container = nullptr; + + decltype(&thrGraphAssignSq) thr_graph_assign_sq = nullptr; + decltype(&thrSqQueryScratchPad) thr_sq_query_scratch_pad = nullptr; + decltype(&thrSqAttachScratchPadBuffer) thr_sq_attach_scratch_pad_buffer = + nullptr; + + decltype(&thrRegisterBuffer) thr_register_buffer = nullptr; + decltype(&thrRegisterBufferWithOffset) thr_register_buffer_with_offset = + nullptr; + decltype(&thrUnregisterBuffer) thr_unregister_buffer = nullptr; + + decltype(&thrInvocationContextGet) thr_invocation_context_get = nullptr; + decltype(&thrInvocationContextDelete) thr_invocation_context_delete = nullptr; + + decltype(&thrInvocationContextAttachBuffer) + thr_invocation_context_attach_buffer = nullptr; + decltype(&thrInvocationContextDetachBuffer) + thr_invocation_context_detach_buffer = nullptr; + + decltype(&thrInvocationContextPrepareForInvoke) + thr_invocation_context_prepare_for_invoke = nullptr; + decltype(&thrInvocationContextInvokeOnce) thr_invocation_context_invoke_once = + nullptr; + decltype(&thrInvocationContextWait) thr_invocation_context_wait = nullptr; + + decltype(&thrInvocationContextAttachInputBufferSyncFence) + thr_invocation_context_attach_input_buffer_sync_fence = nullptr; + decltype(&thrInvocationContextGetOutputBufferSyncFence) + thr_invocation_context_get_output_buffer_sync_fence = nullptr; + + decltype(&thrInvocationContextQueryNodeScratchPad) + thr_invocation_context_query_node_scratch_pad = nullptr; + decltype(&thrInvocationContextAttachScratchPadBuffer) + thr_invocation_context_attach_scratch_pad_buffer = nullptr; + + decltype(&thrVendorSetSystemAttributeStr) + thr_vendor_set_system_attribute_str = nullptr; + decltype(&thrVendorSetSystemAttributeInt64) + thr_vendor_set_system_attribute_int64 = nullptr; +}; + +} // namespace google_tensor +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_GOOGLE_TENSOR_DISPATCH_SOUTHBOUND_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/BUILD new file mode 100644 index 00000000000000..5a7a76bdbdb999 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/BUILD @@ -0,0 +1,140 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "litert_lib", "litert_test") +load("//tensorflow/lite/experimental/litert/vendors/qualcomm:qualcomm_build_defs.bzl", "litert_lib_with_qnn") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], +) + +cc_library( + name = "common", + hdrs = ["common.h"], + deps = [ + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + ], +) + +litert_lib( + name = "qnn_log", + srcs = ["qnn_log.cc"], + hdrs = ["qnn_log.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + deps = [ + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + ], +) + +cc_library( + name = "qnn_manager_hdr", + hdrs = ["qnn_manager.h"], + deps = [ + ":common", + ":qnn_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/core:dynamic_loading", + ], +) + +litert_lib_with_qnn( + name = "qnn_manager", + srcs = [ + "qnn_manager.cc", + ], + hdrs = ["qnn_manager.h"], + include_system = True, + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + ungrte = True, + deps = [ + ":common", + ":qnn_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/core:api_internal", + "//tensorflow/lite/experimental/litert/core:dynamic_loading", + ], +) + +litert_test( + name = "qnn_manager_test", + srcs = ["qnn_manager_test.cc"], + linkstatic = True, + tags = [ + # Tests with ungrte deps do not currently work on forge. + "no-remote-exec", + "notap", + # Don't build/test in OS until qnn is available. + "nobuilder", + "no_oss", + # Sanitizer runtime doesn't work with anything that loads libQnnHtp.so. + "nosan", + ], + deps = [ + ":qnn_manager", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/test:common", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/tools:dump", + ], +) + +cc_library( + name = "context_binary_info", + srcs = ["context_binary_info.cc"], + hdrs = ["context_binary_info.h"], + deps = [ + ":qnn_manager", + ":qnn_tensor", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + ], +) + +cc_library( + name = "qnn_tensor", + srcs = ["qnn_tensor.cc"], + hdrs = ["qnn_tensor.h"], + deps = [ + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + ], +) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/common.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/common.h new file mode 100644 index 00000000000000..9fd1ddf6d009bf --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/common.h @@ -0,0 +1,99 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMMON_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMMON_H_ + +#include "third_party/qairt/latest/include/QNN/QnnCommon.h" +#include "third_party/qairt/latest/include/QNN/QnnInterface.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "third_party/qairt/latest/include/QNN/System/QnnSystemInterface.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +#define LITERT_RETURN_STATUS_IF_QNN_NOT_OK(expr) \ + if (QNN_SUCCESS != (expr)) { \ + return kLiteRtStatusErrorNotFound; \ + } + +// Pointers to functions of a dynamically loaded QNN library. +typedef QNN_INTERFACE_VER_TYPE QnnApi; + +// Pointers to functions of a dynamically loaded QNN system library. +typedef QNN_SYSTEM_INTERFACE_VER_TYPE QnnSystemApi; + +// QNN backend library should be on DT_RUNPATH (-rpath). +static const char kLibQnnHtpSo[] = "libQnnHtp.so"; + +// QNN backend library should be on DT_RUNPATH (-rpath). +static const char kLibQnnSystemSo[] = "libQnnSystem.so"; + +// Map LiteRT element type to Qnn counterpart. +inline LiteRtStatus LegalizeElementType(LiteRtElementType litert_type, + Qnn_DataType_t* qnn_type) { + switch (litert_type) { + case kLiteRtElementTypeBool: + *qnn_type = QNN_DATATYPE_BOOL_8; + break; + case kLiteRtElementTypeInt4: + *qnn_type = QNN_DATATYPE_SFIXED_POINT_4; + break; + case kLiteRtElementTypeInt8: + *qnn_type = QNN_DATATYPE_INT_8; + break; + case kLiteRtElementTypeInt16: + *qnn_type = QNN_DATATYPE_INT_16; + break; + case kLiteRtElementTypeInt32: + *qnn_type = QNN_DATATYPE_INT_32; + break; + case kLiteRtElementTypeInt64: + *qnn_type = QNN_DATATYPE_INT_64; + break; + case kLiteRtElementTypeUInt8: + *qnn_type = QNN_DATATYPE_UINT_8; + break; + case kLiteRtElementTypeUInt16: + *qnn_type = QNN_DATATYPE_UINT_16; + break; + case kLiteRtElementTypeUInt32: + *qnn_type = QNN_DATATYPE_UINT_32; + break; + case kLiteRtElementTypeUInt64: + *qnn_type = QNN_DATATYPE_UINT_64; + break; + case kLiteRtElementTypeFloat16: + *qnn_type = QNN_DATATYPE_FLOAT_16; + break; + case kLiteRtElementTypeFloat32: + *qnn_type = QNN_DATATYPE_FLOAT_32; + break; + case kLiteRtElementTypeFloat64: + *qnn_type = QNN_DATATYPE_FLOAT_64; + break; + default: + return kLiteRtStatusErrorUnsupported; + } + return kLiteRtStatusOk; +} + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMMON_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/BUILD new file mode 100644 index 00000000000000..f68fb69cf29d1e --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/BUILD @@ -0,0 +1,139 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "litert_dynamic_lib", "litert_lib", "litert_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:private"], +) + +litert_dynamic_lib( + name = "qnn_compiler_plugin", + srcs = ["qnn_compiler_plugin.cc"], + hdrs = ["//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin.h"], + export_litert_only = True, + shared_lib_name = "qnn_compiler_plugin_so", + so_name = "libLiteRtQnnCompilerPlugin.so", + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + ungrte = True, + visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], + deps = [ + ":qnn_compose_graph", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/core:api_internal", + "//tensorflow/lite/experimental/litert/core:graph_tools", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", + ], +) + +litert_test( + name = "qnn_compiler_plugin_test", + srcs = [ + "qnn_compiler_plugin_test.cc", + ], + data = ["//tensorflow/lite/experimental/litert/test:tflite_test_data"], + linkstatic = True, + tags = [ + # Tests with ungrte deps do not currently work on forge. + "no-remote-exec", + "notap", + # Don't build/test in OS until qnn is available. + "nobuilder", + "no_oss", + # Sanitizer runtime doesn't work with anything that loads libQnnHtp.so. + "nosan", + ], + use_sys_malloc = True, + deps = [ + ":qnn_compiler_plugin", # buildcleaner: keep + "@com_google_absl//absl/log:absl_check", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/core:api_internal", + "//tensorflow/lite/experimental/litert/core:graph_tools", + "//tensorflow/lite/experimental/litert/core:model", + "//tensorflow/lite/experimental/litert/test:common", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", + ], +) + +litert_lib( + name = "qnn_compose_graph", + srcs = ["qnn_compose_graph.cc"], + hdrs = ["qnn_compose_graph.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + deps = [ + ":graph_mapper", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/cc:litert_op", + "//tensorflow/lite/experimental/litert/cc:litert_tensor", + "//tensorflow/lite/experimental/litert/core:graph_tools", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations:add_op_legalization", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations:batch_matmul_op_legalization", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations:div_op_legalization", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations:legalization", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations:mul_op_legalization", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations:reshape_op_legalization", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations:rsqrt_op_legalization", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations:slice_op_legalization", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations:sub_op_legalization", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations:tanh_op_legalization", + ], +) + +litert_lib( + name = "graph_mapper", + srcs = [ + "graph_mapper.cc", + ], + hdrs = ["graph_mapper.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], + deps = [ + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/cc:litert_op", + "//tensorflow/lite/experimental/litert/cc:litert_tensor", + "//tensorflow/lite/experimental/litert/core:graph_tools", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/BUILD new file mode 100644 index 00000000000000..a1afff0bcd6877 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/BUILD @@ -0,0 +1,121 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:__subpackages__"], +) + +cc_library( + name = "qnn_tensor", + srcs = ["qnn_tensor.cc"], + hdrs = ["qnn_tensor.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + deps = [ + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/types:span", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/cc:litert_tensor", + "//tensorflow/lite/experimental/litert/core:api_internal", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", + ], +) + +cc_test( + name = "qnn_tensor_test", + srcs = ["qnn_tensor_test.cc"], + data = ["//tensorflow/lite/experimental/litert/test:tflite_test_data"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + "no_oss", + ], + deps = [ + ":qnn_tensor", + "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/core:api_internal", + "//tensorflow/lite/experimental/litert/core:graph_tools", + "//tensorflow/lite/experimental/litert/test:common", + ], +) + +cc_library( + name = "qnn_op", + srcs = ["qnn_op.cc"], + hdrs = ["qnn_op.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + deps = [ + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/cc:litert_op", + "//tensorflow/lite/experimental/litert/core:api_internal", + ], +) + +cc_test( + name = "qnn_op_test", + srcs = ["qnn_op_test.cc"], + data = ["//tensorflow/lite/experimental/litert/test:tflite_test_data"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + "no_oss", + ], + deps = [ + ":qnn_op", + "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/core:api_internal", + "//tensorflow/lite/experimental/litert/core:graph_tools", + "//tensorflow/lite/experimental/litert/test:common", + ], +) + +cc_test( + name = "op_compatibility_test", + srcs = ["op_compatibility_test.cc"], + data = ["//tensorflow/lite/experimental/litert/test:tflite_test_data"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + "no_oss", + ], + deps = [ + ":qnn_op", + "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/core:api_internal", + "//tensorflow/lite/experimental/litert/core:graph_tools", + "//tensorflow/lite/experimental/litert/test:common", + ], +) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/op_compatibility_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/op_compatibility_test.cc new file mode 100644 index 00000000000000..6b5d1309b300f2 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/op_compatibility_test.cc @@ -0,0 +1,81 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include "absl/strings/match.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/core/graph_tools.h" +#include "tensorflow/lite/experimental/litert/test/common.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" + +namespace { + +static constexpr absl::string_view kOpTpl = "simple_%s_op.tflite"; +struct OpInfo { + std::string op_name; + std::string expected_type_name; +}; +// TODOL: b/365299994 - Add "stablehlo_scatter" once muti subgraphs is +// supported. +// clang-format off +const auto kSupportedOps = testing::Values( + OpInfo{"add", "ElementWiseAdd"}, + OpInfo{"mul", "ElementWiseMultiply"}, + OpInfo{"batch_matmul", "MatMul"}, + OpInfo{"concatenation", "Concat"}, + OpInfo{"div", "ElementWiseDivide"}, + OpInfo{"fully_connected", "FullyConnected"}, + OpInfo{"reshape", "Reshape"}, + OpInfo{"rsqrt", "ElementWiseRsqrt"}, + OpInfo{"select_v2", "ElementWiseSelect"}, + OpInfo{"select", "ElementWiseSelect"}, + OpInfo{"strided_slice", "StridedSlice"}, + OpInfo{"slice", "StridedSlice"}, + OpInfo{"softmax", "Softmax"}, + OpInfo{"sub", "ElementWiseSubtract"}, + OpInfo{"tanh", "Tanh"}, + OpInfo{"transpose", "Transpose"}); +// clang-format on + +class OpCompatibilityTest : public ::testing::TestWithParam {}; + +TEST_P(OpCompatibilityTest, SupportedOpsTest) { + auto test_params = GetParam(); + std::string model_path = absl::StrFormat(kOpTpl, test_params.op_name); + auto model = litert::testing::LoadTestFileModel(model_path); + ASSERT_RESULT_OK_ASSIGN(auto subgraph, + ::graph_tools::GetSubgraph(model.get())); + ASSERT_RESULT_OK_ASSIGN(auto ops, ::graph_tools::GetSubgraphOps(subgraph)); + + Qnn_OpConfig_t qnn_op = litert::qnn::BuildDefaultOp(); + ASSERT_STATUS_OK(litert::qnn::LegalizeOp(ops[0], qnn_op)); + + EXPECT_TRUE(absl::StrContains(qnn_op.v1.name, test_params.op_name)); + EXPECT_STREQ(qnn_op.v1.packageName, "qti.aisw"); + EXPECT_STREQ(qnn_op.v1.typeName, test_params.expected_type_name.c_str()); + + EXPECT_EQ(qnn_op.v1.numOfInputs, 0); + EXPECT_EQ(qnn_op.v1.numOfOutputs, 0); + EXPECT_EQ(qnn_op.v1.numOfParams, 0); + + litert::qnn::ResetOp(qnn_op); +} + +INSTANTIATE_TEST_SUITE_P(SupportedOpsTest, OpCompatibilityTest, kSupportedOps); + +} // namespace diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.cc new file mode 100644 index 00000000000000..b8617006683a20 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.cc @@ -0,0 +1,155 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/c/litert_support.h" +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" + +// A macro dance to create a unique literal string given a prefix. +#define STRINGIFY(x) #x +#define QNN_OP_NAME(prefix) STRINGIFY(prefix##__COUNTER) + +namespace litert::qnn { + +using ::litert::LiteRtOpManager; + +namespace { + +// Maps "op-code" related information (name, packageName, typeName) from src +// to dest. +LiteRtStatus LegalizeOpType(const LiteRtOpManager& src, Qnn_OpConfig_t& dest) { + switch (src.Code()) { + case kLiteRtOpCodeTflMul: + dest.v1.name = QNN_OP_NAME(mul_); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "ElementWiseMultiply"; + break; + case kLiteRtOpCodeTflAdd: + dest.v1.name = QNN_OP_NAME("add"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "ElementWiseAdd"; + break; + case kLiteRtOpCodeTflBatchMatmul: + dest.v1.name = QNN_OP_NAME("batch_matmul"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "MatMul"; + break; + case kLiteRtOpCodeTflConcatenation: + dest.v1.name = QNN_OP_NAME("concatenation"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "Concat"; + break; + case kLiteRtOpCodeTflDiv: + dest.v1.name = QNN_OP_NAME("div"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "ElementWiseDivide"; + break; + case kLiteRtOpCodeTflFullyConnected: + dest.v1.name = QNN_OP_NAME("fully_connected"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "FullyConnected"; + break; + case kLiteRtOpCodeTflReshape: + dest.v1.name = QNN_OP_NAME("reshape"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "Reshape"; + break; + case kLiteRtOpCodeTflRsqrt: + dest.v1.name = QNN_OP_NAME("rsqrt"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "ElementWiseRsqrt"; + break; + case kLiteRtOpCodeTflSelectV2: + dest.v1.name = QNN_OP_NAME("select_v2"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "ElementWiseSelect"; + break; + case kLiteRtOpCodeTflSelect: + dest.v1.name = QNN_OP_NAME("select"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "ElementWiseSelect"; + break; + case kLiteRtOpCodeTflStridedSlice: + dest.v1.name = QNN_OP_NAME("strided_slice"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "StridedSlice"; + break; + case kLiteRtOpCodeTflSlice: + dest.v1.name = QNN_OP_NAME("slice"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "StridedSlice"; + break; + case kLiteRtOpCodeTflSoftmax: + dest.v1.name = QNN_OP_NAME("softmax"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "Softmax"; + break; + case kLiteRtOpCodeTflSub: + dest.v1.name = QNN_OP_NAME("sub"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "ElementWiseSubtract"; + break; + case kLiteRtOpCodeTflTanh: + dest.v1.name = QNN_OP_NAME("tanh"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "Tanh"; + break; + case kLiteRtOpCodeTflTranspose: + dest.v1.name = QNN_OP_NAME("transpose"); + dest.v1.packageName = "qti.aisw"; + dest.v1.typeName = "Transpose"; + break; + default: + return kLiteRtStatusErrorUnsupported; + } + return kLiteRtStatusOk; +} + +} // namespace + +Qnn_OpConfig_t BuildDefaultOp() { + Qnn_OpConfig_t op = QNN_OPCONFIG_INIT; + ResetOp(op); + return op; +} +Qnn_Param_t BuildDefaultParam() { + Qnn_Param_t param = QNN_PARAM_INIT; + ResetParam(param); + return param; +} + +void ResetOp(Qnn_OpConfig_t& op) { + op = QNN_OPCONFIG_INIT; + op.version = QNN_OPCONFIG_VERSION_1; + op.v1 = QNN_OPCONFIG_V1_INIT; +} + +void ResetParam(Qnn_Param_t& param) { param = QNN_PARAM_INIT; } +LiteRtStatus LegalizeOp(LiteRtOp src, Qnn_OpConfig_t& dest) { + ResetOp(dest); + + LiteRtOpManager::Unique src_op; + LITERT_RETURN_STATUS_IF_NOT_OK(LiteRtOpManager::MakeFromOp(src, src_op)); + + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeOpType(*src_op, dest)); + + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h new file mode 100644 index 00000000000000..20e0f27f798b98 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h @@ -0,0 +1,53 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_IR_QNN_OP_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_IR_QNN_OP_H_ + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" + +namespace litert::qnn { + +// +// Initialize QNN Op. +// + +// NOTE: Any referential data within a QNN Op +// is allocated with "new" and must be explicitly cleaned up with ResetOp. + +// Construct a "blank" QNN Op. +Qnn_OpConfig_t BuildDefaultOp(); + +// Construct a "blank" QNN Param. +Qnn_Param_t BuildDefaultParam(); + +// Reset the given tensor, deallocating anything on the heap that it points to. +void ResetOp(Qnn_OpConfig_t& op); + +// Reset the given param, deallocating anything on the heap that it points to. +void ResetParam(Qnn_Param_t& param); + +// +// Legalize LiteRt Op to Analogous QNN Construct. +// + +// Map src op onto dest. Resets dest before doing anything. This only handles +// attribute-like info. It does not set edges (in/out tensors). +LiteRtStatus LegalizeOp(LiteRtOp src, Qnn_OpConfig_t& dest); + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_IR_QNN_OP_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op_test.cc new file mode 100644 index 00000000000000..963ac244af3793 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op_test.cc @@ -0,0 +1,64 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" + +#include +#include "absl/strings/match.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/core/graph_tools.h" +#include "tensorflow/lite/experimental/litert/test/common.h" + +namespace { + +TEST(TestInitQnnOp, BuildDefaultOp) { + Qnn_OpConfig_t op = litert::qnn::BuildDefaultOp(); + ASSERT_EQ(op.version, QNN_OPCONFIG_VERSION_1); +} + +TEST(TestLegalizeOp, SimpleSupportedOp) { + auto model = litert::testing::LoadTestFileModel("one_mul.tflite"); + ASSERT_RESULT_OK_ASSIGN(auto subgraph, + ::graph_tools::GetSubgraph(model.get())); + ASSERT_RESULT_OK_ASSIGN(auto ops, ::graph_tools::GetSubgraphOps(subgraph)); + + Qnn_OpConfig_t qnn_op = litert::qnn::BuildDefaultOp(); + ASSERT_STATUS_OK(litert::qnn::LegalizeOp(ops[0], qnn_op)); + + EXPECT_TRUE(absl::StrContains(qnn_op.v1.name, "mul")); + EXPECT_STREQ(qnn_op.v1.packageName, "qti.aisw"); + EXPECT_STREQ(qnn_op.v1.typeName, "ElementWiseMultiply"); + + EXPECT_EQ(qnn_op.v1.numOfInputs, 0); + EXPECT_EQ(qnn_op.v1.numOfOutputs, 0); + EXPECT_EQ(qnn_op.v1.numOfParams, 0); + + litert::qnn::ResetOp(qnn_op); +} + +TEST(TestLegalizeOp, UnsupportedOp) { + auto model = litert::testing::LoadTestFileModel("simple_floor_mod_op.tflite"); + ASSERT_RESULT_OK_ASSIGN(auto subgraph, + ::graph_tools::GetSubgraph(model.get())); + ASSERT_RESULT_OK_ASSIGN(auto ops, ::graph_tools::GetSubgraphOps(subgraph)); + + Qnn_OpConfig_t qnn_op = litert::qnn::BuildDefaultOp(); + ASSERT_STATUS_HAS_CODE(litert::qnn::LegalizeOp(ops[0], qnn_op), + kLiteRtStatusErrorUnsupported); + + litert::qnn::ResetOp(qnn_op); +} + +} // namespace diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.cc new file mode 100644 index 00000000000000..6d63c1b6a8f002 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.cc @@ -0,0 +1,147 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" + +#include + +#include "absl/log/absl_check.h" +#include "absl/types/span.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_support.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/cc/litert_tensor.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" + +namespace litert::qnn { + +using ::litert::LiteRtTensorManager; + +namespace { + +LiteRtStatus LegalizeShapeInfo(const LiteRtTensorManager& src, + Qnn_Tensor_t& dest) { + LITERT_ENSURE_SUPPORTED(!src.HasStrides(), "Strides not yet supported"); + + dest.v2.rank = src.Rank(); + dest.v2.dimensions = new uint32_t[dest.v2.rank]; + for (int i = 0; i < dest.v2.rank; ++i) { + const auto src_dim = src.Dims()[i]; + LITERT_ENSURE(src_dim >= 1, kLiteRtStatusErrorInvalidArgument, + "Cannot pass dim < 1 to QNN Tensor."); + + dest.v2.dimensions[i] = src.Dims()[i]; + } + return kLiteRtStatusOk; +} + +void FreeTensorDims(Qnn_Tensor_t& tensor) { + if (tensor.version == QNN_TENSOR_VERSION_2 && + tensor.v2.dimensions != nullptr) { + delete[] tensor.v2.dimensions; + tensor.v2.dimensions = nullptr; + tensor.v2.rank = 0; + } +} + +} // namespace + +void SetInputTensorAttrs(Qnn_Tensor_t& tensor) { + ABSL_DCHECK(tensor.version == QNN_TENSOR_VERSION_2); + tensor.v2.type = QNN_TENSOR_TYPE_APP_WRITE; + tensor.v2.memType = QNN_TENSORMEMTYPE_RAW; + tensor.v2.clientBuf = QNN_CLIENT_BUFFER_INIT; +} + +void SetOutputTensorAttrs(Qnn_Tensor_t& tensor) { + ABSL_DCHECK(tensor.version == QNN_TENSOR_VERSION_2); + tensor.v2.type = QNN_TENSOR_TYPE_APP_READ; +} + +void ResetTensor(Qnn_Tensor_t& tensor) { + FreeTensorDims(tensor); + tensor = QNN_TENSOR_INIT; + tensor.version = QNN_TENSOR_VERSION_2; + tensor.v2 = QNN_TENSOR_V2_INIT; + tensor.v2.dataFormat = QNN_TENSOR_DATA_FORMAT_DENSE; +} + +Qnn_Tensor_t BuildDefaultTensor(uint32_t id) { + Qnn_Tensor_t tensor = QNN_TENSOR_INIT; + ResetTensor(tensor); + tensor.v2.id = id; + return tensor; +} + +Qnn_Tensor_t BuildDefaultTensor() { return BuildDefaultTensor(0); } + +Qnn_Tensor_t BuildInputTensor() { + auto tensor = BuildDefaultTensor(); + SetInputTensorAttrs(tensor); + return tensor; +} + +Qnn_ClientBuffer_t BuildDefaultClientBuffer() { + Qnn_ClientBuffer_t client_buf = QNN_CLIENT_BUFFER_INIT; + client_buf.data = nullptr; + client_buf.dataSize = 0; + return client_buf; +} + +Qnn_Tensor_t BuildOutputTensor() { + Qnn_Tensor_t tensor = BuildDefaultTensor(); + SetOutputTensorAttrs(tensor); + return tensor; +} + +uint32_t MoveToId(Qnn_Tensor_t& tensor) { + const auto id = tensor.v2.id; + ResetTensor(tensor); + tensor.v2.id = id; + return id; +} + +LiteRtStatus LegalizeTensor(LiteRtTensor src, Qnn_Tensor_t& dest) { + ResetTensor(dest); + + LiteRtTensorManager::Unique src_tensor; + LITERT_RETURN_STATUS_IF_NOT_OK( + LiteRtTensorManager::MakeFromTensor(src, src_tensor)); + + Qnn_DataType_t* qnn_data_type = &dest.v2.dataType; + LITERT_RETURN_STATUS_IF_NOT_OK( + LegalizeElementType(src_tensor->ElementType(), qnn_data_type)); + + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeShapeInfo(*src_tensor, dest)); + + const bool is_subgraph_in = src_tensor->IsSubgraphInput(); + const bool is_subgraph_out = src_tensor->IsSubgraphOutput(); + + LITERT_ENSURE(!(is_subgraph_in && is_subgraph_out), + kLiteRtStatusErrorInvalidArgument, + "Malformed tensor, cannot be both subgraph in and out."); + + if (is_subgraph_in) { + SetInputTensorAttrs(dest); + } + if (is_subgraph_out) { + SetOutputTensorAttrs(dest); + } + + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h new file mode 100644 index 00000000000000..b4d568831f9e90 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h @@ -0,0 +1,69 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_IR_QNN_TENSOR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_IR_QNN_TENSOR_H_ + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" + +namespace litert::qnn { + +// +// Initialize QNN Tensors. +// + +// NOTE: Within LiteRt land, all Qnn Tensors are treated as "v2". Any +// referential data (like dimensions : uint32_t*) within a QNN Tensor +// is allocated with "new" and must be explicitly cleaned up with ResetTensor. + +// Construct a "blank" QNN Tensor. +Qnn_Tensor_t BuildDefaultTensor(); + +// Construct a "blank" QNN Tensor with given id. +Qnn_Tensor_t BuildDefaultTensor(uint32_t id); + +// Constructa a "blank" QNN Tensor meant to be used as a graph input. +Qnn_Tensor_t BuildInputTensor(); + +// Constructa a "blank" QNN Tensor meant to be used as a graph output. +Qnn_Tensor_t BuildOutputTensor(); + +Qnn_ClientBuffer_t BuildDefaultClientBuffer(); + +// Adds attributes to given tensor making it amenable for use as graph input. +void SetInputTensorAttrs(Qnn_Tensor_t& tensor); + +// Adds attributes to given tensor making it amenable for use as graph output. +void SetOutputTensorAttrs(Qnn_Tensor_t& tensor); + +// Reset the given tensor, deallocating anything on the heap that it points to. +void ResetTensor(Qnn_Tensor_t& tensor); + +// Resets all fields other than id in the given tensor and returns the id for +// convenience. Only the id is needed to traffic QNN Tensors after they have +// been registered with the context. +uint32_t MoveToId(Qnn_Tensor_t& tensor); + +// +// Legalize LiteRt Tensors to Analogous QNN Construct. +// + +// Map src tensor onto dest. Resets dest before doing anything. +LiteRtStatus LegalizeTensor(LiteRtTensor src, Qnn_Tensor_t& dest); + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_IR_QNN_TENSOR_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor_test.cc new file mode 100644 index 00000000000000..c1ec5983ce4752 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor_test.cc @@ -0,0 +1,131 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" + +#include +#include +#include "absl/types/span.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/core/graph_tools.h" +#include "tensorflow/lite/experimental/litert/test/common.h" + +namespace { + +TEST(TestInitQnnTensor, BuildDefaultTensor) { + Qnn_Tensor_t tensor = litert::qnn::BuildDefaultTensor(); + ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2); + EXPECT_EQ(tensor.v2.dataFormat, QNN_TENSOR_DATA_FORMAT_DENSE); + EXPECT_EQ(tensor.v2.rank, 0); + EXPECT_EQ(tensor.v2.dimensions, nullptr); + EXPECT_EQ(tensor.v2.id, 0); +} + +TEST(TestInitQnnTensor, BuildDefaultTensorWithId) { + Qnn_Tensor_t tensor = litert::qnn::BuildDefaultTensor(2); + ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2); + EXPECT_EQ(tensor.v2.dataFormat, QNN_TENSOR_DATA_FORMAT_DENSE); + EXPECT_EQ(tensor.v2.rank, 0); + EXPECT_EQ(tensor.v2.dimensions, nullptr); + EXPECT_EQ(tensor.v2.id, 2); +} + +TEST(TestInitQnnTensor, BuildDefaultInputTensor) { + Qnn_Tensor_t tensor = litert::qnn::BuildInputTensor(); + ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2); + EXPECT_EQ(tensor.v2.type, QNN_TENSOR_TYPE_APP_WRITE); + EXPECT_EQ(tensor.v2.memType, QNN_TENSORMEMTYPE_RAW); + EXPECT_EQ(tensor.v2.clientBuf.dataSize, 0); +} + +TEST(TestInitQnnTensor, SetInputTensor) { + Qnn_Tensor_t tensor = litert::qnn::BuildDefaultTensor(); + litert::qnn::SetInputTensorAttrs(tensor); + ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2); + EXPECT_EQ(tensor.v2.type, QNN_TENSOR_TYPE_APP_WRITE); + EXPECT_EQ(tensor.v2.memType, QNN_TENSORMEMTYPE_RAW); + EXPECT_EQ(tensor.v2.clientBuf.dataSize, 0); +} + +TEST(TestInitQnnTensor, BuildDefaultOutputTensor) { + Qnn_Tensor_t tensor = litert::qnn::BuildOutputTensor(); + ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2); + EXPECT_EQ(tensor.v2.type, QNN_TENSOR_TYPE_APP_READ); +} + +TEST(TestInitQnnTensor, SetOutputTensor) { + Qnn_Tensor_t tensor = litert::qnn::BuildDefaultTensor(); + litert::qnn::SetOutputTensorAttrs(tensor); + ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2); + EXPECT_EQ(tensor.v2.type, QNN_TENSOR_TYPE_APP_READ); +} + +TEST(TestInitQnnTensor, MoveToId) { + Qnn_Tensor_t tensor = litert::qnn::BuildDefaultTensor(2); + + litert::qnn::SetOutputTensorAttrs(tensor); + ASSERT_EQ(tensor.version, QNN_TENSOR_VERSION_2); + EXPECT_EQ(tensor.v2.type, QNN_TENSOR_TYPE_APP_READ); + + EXPECT_EQ(litert::qnn::MoveToId(tensor), 2); + EXPECT_EQ(tensor.v2.id, 2); + EXPECT_EQ(tensor.v2.type, QNN_TENSOR_TYPE_UNDEFINED); +} + +TEST(TestLegalizeTensor, SimpleSupportedTensorSubgraphInput) { + auto model = litert::testing::LoadTestFileModel("one_mul.tflite"); + ASSERT_RESULT_OK_ASSIGN(auto subgraph, + ::graph_tools::GetSubgraph(model.get())); + ASSERT_RESULT_OK_ASSIGN(auto outputs, + ::graph_tools::GetSubgraphOutputs(subgraph)); + + auto qnn_tensor = litert::qnn::BuildDefaultTensor(); + ASSERT_STATUS_OK(litert::qnn::LegalizeTensor(outputs[0], qnn_tensor)); + + ASSERT_EQ(qnn_tensor.version, QNN_TENSOR_VERSION_2); + EXPECT_EQ(qnn_tensor.v2.dataType, QNN_DATATYPE_FLOAT_32); + EXPECT_EQ(qnn_tensor.v2.type, QNN_TENSOR_TYPE_APP_READ); + + ASSERT_EQ(qnn_tensor.v2.rank, 2); + ASSERT_NE(qnn_tensor.v2.dimensions, nullptr); + EXPECT_THAT(absl::MakeConstSpan(qnn_tensor.v2.dimensions, 2), + ::testing::ElementsAreArray({2, 2})); + + litert::qnn::ResetTensor(qnn_tensor); +} + +TEST(TestLegalizeTensor, SimpleSupportedTensor) { + auto model = litert::testing::LoadTestFileModel("simple_multi_op.tflite"); + + ASSERT_RESULT_OK_ASSIGN(auto subgraph, + ::graph_tools::GetSubgraph(model.get())); + ASSERT_RESULT_OK_ASSIGN(auto ops, ::graph_tools::GetSubgraphOps(subgraph)); + ASSERT_RESULT_OK_ASSIGN(auto op_outs, ::graph_tools::GetOpOuts(ops[1])); + + auto qnn_tensor = litert::qnn::BuildDefaultTensor(); + ASSERT_STATUS_OK(litert::qnn::LegalizeTensor(op_outs[0], qnn_tensor)); + + ASSERT_EQ(qnn_tensor.version, QNN_TENSOR_VERSION_2); + EXPECT_EQ(qnn_tensor.v2.dataType, QNN_DATATYPE_FLOAT_32); + EXPECT_EQ(qnn_tensor.v2.type, QNN_TENSOR_TYPE_UNDEFINED); + + ASSERT_EQ(qnn_tensor.v2.rank, 2); + ASSERT_NE(qnn_tensor.v2.dimensions, nullptr); + EXPECT_THAT(absl::MakeConstSpan(qnn_tensor.v2.dimensions, 2), + ::testing::ElementsAreArray({2, 2})); + + litert::qnn::ResetTensor(qnn_tensor); +} + +} // namespace diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.cc new file mode 100644 index 00000000000000..f54374cf56b368 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.cc @@ -0,0 +1,158 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" + +#include +#include + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "third_party/qairt/latest/include/QNN/QnnCommon.h" +#include "third_party/qairt/latest/include/QNN/QnnGraph.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_support.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/core/graph_tools.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" + +namespace litert::qnn { + +// Get empty configurations for graph building. +inline absl::Span GetDefaultGraphConfigs() { + static const QnnGraph_Config_t* configs[] = {nullptr}; + return absl::MakeSpan(configs); +} + +LiteRtStatus GraphMapper::AssignTensorName(Qnn_Tensor_t& qnn_tensor) { + char* name = nullptr; + const int written = asprintf(&name, "Tensor_%d", cur_tensor_num_++); + LITERT_ENSURE(written != -1 && name != nullptr, kLiteRtStatusErrorNotFound, + "Failed to make tensor name"); + qnn_tensor.v2.name = name; + return kLiteRtStatusOk; +} + +LiteRtSubgraph GraphMapper::Subgraph() { return subgraph_; } + +absl::Span GraphMapper::LiteRtSubgraphInputs() { + return litert_subgraph_inputs_; +} + +absl::Span GraphMapper::LiteRtSubgraphOutputs() { + return litert_subgraph_outputs_; +} + +absl::Span GraphMapper::LiteRtSubgraphOps() { + return litert_subgraph_ops_; +} + +absl::flat_hash_map& GraphMapper::CurrentScope() { + return current_scope_; +} + +LiteRtStatus GraphMapper::LookupInScope(LiteRtTensor litert_tensor, + Qnn_Tensor_t& qnn_tensor) { + // If we go in topological order, this should never happen. TODO: add + // "internal error" status code. + const auto qnn_id = CurrentScope().find(litert_tensor); + LITERT_ENSURE(qnn_id != CurrentScope().end(), kLiteRtStatusErrorNotFound, + "Couldn't find tensor in current_scope."); + + ResetTensor(qnn_tensor); + qnn_tensor.v2.id = qnn_id->second; + + return kLiteRtStatusOk; +} + +LiteRtStatus GraphMapper::PushToScope(LiteRtTensor litert_tensor, + Qnn_Tensor_t& qnn_tensor) { + CurrentScope()[litert_tensor] = MoveToId(qnn_tensor); + return kLiteRtStatusOk; +} + +QnnManager& GraphMapper::Qnn() { return qnn_; } + +Qnn_GraphHandle_t& GraphMapper::QnnGraph() { return qnn_graph_; } + +LiteRtStatus GraphMapper::LegalizeAndRegister(LiteRtTensor litert_tensor, + Qnn_Tensor_t& qnn_tensor) { + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeTensor(litert_tensor, qnn_tensor)); + LITERT_RETURN_STATUS_IF_NOT_OK(AssignTensorName(qnn_tensor)); + LITERT_RETURN_STATUS_IF_QNN_NOT_OK( + qnn_.Api()->tensorCreateGraphTensor(QnnGraph(), &qnn_tensor)); + + LITERT_LOG(LITERT_INFO, "Legalized and registered tensor %d", + qnn_tensor.v2.id); + + return kLiteRtStatusOk; +} + +LiteRtStatus GraphMapper::ParseLiteRtSubgraph() { + LITERT_ASSIGN_OR_RETURN_STATUS(auto inputs, + graph_tools::GetSubgraphInputs(Subgraph())); + litert_subgraph_inputs_ = + absl::MakeSpan(const_cast(inputs.data()), inputs.size()); + + LITERT_ASSIGN_OR_RETURN_STATUS(auto outputs, + graph_tools::GetSubgraphOutputs(Subgraph())); + litert_subgraph_outputs_ = + absl::MakeSpan(const_cast(outputs.data()), outputs.size()); + + LITERT_ASSIGN_OR_RETURN_STATUS(auto ops, + graph_tools::GetSubgraphOps(Subgraph())); + litert_subgraph_ops_ = + absl::MakeSpan(const_cast(ops.data()), ops.size()); + + return kLiteRtStatusOk; +} + +LiteRtStatus GraphMapper::IsLiteRtSubgraphSupported() { + LITERT_ENSURE_SUPPORTED( + LiteRtSubgraphInputs().size() < 4, + "Only subgraphs with less than 4 inputs currently supported."); + + LITERT_ENSURE_SUPPORTED(LiteRtSubgraphOutputs().size() == 1, + "Only subgraphs with 1 output currently supported."); + + LITERT_ENSURE_SUPPORTED(LiteRtSubgraphOps().size() == 1, + "Only subgraphs with 1 op currently supported."); + + return kLiteRtStatusOk; +} + +LiteRtStatus GraphMapper::InitQnnGraph(absl::string_view qnn_graph_name) { + LITERT_RETURN_STATUS_IF_QNN_NOT_OK( + qnn_.Api()->graphCreate(context_handle_, qnn_graph_name.data(), + GetDefaultGraphConfigs().data(), &QnnGraph())); + return kLiteRtStatusOk; +} + +LiteRtStatus GraphMapper::Finalize() { + LITERT_RETURN_STATUS_IF_QNN_NOT_OK( + qnn_.Api()->graphFinalize(QnnGraph(), nullptr, nullptr)); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h new file mode 100644 index 00000000000000..86b60fe6acddf1 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h @@ -0,0 +1,119 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_GRAPH_MAPPER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_GRAPH_MAPPER_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "third_party/qairt/latest/include/QNN/QnnCommon.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" + +namespace litert::qnn { + +// Algorithm class for managing "scope" when mapping litert Subgraphs +// to QNN Graphs. +class GraphMapper { + public: + GraphMapper(LiteRtSubgraph subgraph, QnnManager& qnn, + Qnn_ContextHandle_t context_handle) + : subgraph_(subgraph), qnn_(qnn), context_handle_(context_handle) {} + + // Legalize given LiteRtTensors attributes into QNN Tensor registered with + // QNN context. Result QNN Tensor is empty except for the canonical id + // assigned by QNN Api. + LiteRtStatus LegalizeAndRegister(LiteRtTensor litert_tensor, + Qnn_Tensor_t& qnn_tensor); + + // Find ID associated with evaluated litert Tensor and add it to given + // QNN Tensor. + LiteRtStatus LookupInScope(LiteRtTensor litert_tensor, + Qnn_Tensor_t& qnn_tensor); + + // Adds new mapping to scope. All fields other than ID in given QNN Tensor are + // cleared and its ID is added to "current_scope". Expects QNN Tensor has + // already been registered with context. + LiteRtStatus PushToScope(LiteRtTensor litert_tensor, + Qnn_Tensor_t& qnn_tensor); + + // NOTE: QNN Tensors must be created with a unique name. This will ensure + // uniqueness but will want to have more meaningful names in the future. + LiteRtStatus AssignTensorName(Qnn_Tensor_t& qnn_tensor); + + // QNN Sdk Accessors + QnnManager& Qnn(); + Qnn_GraphHandle_t& QnnGraph(); + + // CC Convienence Accessors + absl::Span LiteRtSubgraphOps(); + absl::Span LiteRtSubgraphInputs(); + absl::Span LiteRtSubgraphOutputs(); + + // Accessor for current scope. + // Since each QNN Tensor needs to have a unique name globally within each QNN + // context, we maintain "Current scope", which is a map of evaluated + // LiteRtTensors to their resolved QNN Tensor ID. + absl::flat_hash_map& CurrentScope(); + + // Can implementation handle given LiteRtSubgraph topology (see comment at + // bottom of file). + LiteRtStatus IsLiteRtSubgraphSupported(); + + // Parse LiteRtSubgraph entities into usable types. Call this before + // doing anything else. + LiteRtStatus ParseLiteRtSubgraph(); + + // Initialize QNN Graph with given name. Call this after parsing + // LiteRtSubgraph. + LiteRtStatus InitQnnGraph(absl::string_view qnn_graph_name); + + // Finalize QNN Graph. Call this after all ops have been mapped. + LiteRtStatus Finalize(); + + private: + absl::Span litert_subgraph_inputs_; + + absl::Span litert_subgraph_outputs_; + + absl::Span litert_subgraph_ops_; + + LiteRtSubgraph Subgraph(); + LiteRtSubgraph subgraph_; + + // Maps evaluated tensors to their resolved QNN Tensor ID. + absl::flat_hash_map current_scope_; + + // + // QNN Sdk State + // + QnnManager& qnn_; + Qnn_ContextHandle_t context_handle_; + Qnn_GraphHandle_t qnn_graph_ = nullptr; + + // + // Tensor Naming + // + + uint32_t cur_tensor_num_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_GRAPH_MAPPER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/BUILD new file mode 100644 index 00000000000000..020473eb013851 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/BUILD @@ -0,0 +1,324 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "litert_lib") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:private"], +) + +litert_lib( + name = "legalization", + hdrs = ["legalization.h"], + visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], + deps = [ + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/cc:litert_op", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + ], +) + +litert_lib( + name = "add_op_legalization", + srcs = ["add_op_legalization.cc"], + hdrs = ["add_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/cc:litert_op", + "//tensorflow/lite/experimental/litert/cc:litert_tensor", + "//tensorflow/lite/experimental/litert/core:graph_tools", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "batch_matmul_op_legalization", + srcs = ["batch_matmul_op_legalization.cc"], + hdrs = ["batch_matmul_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/cc:litert_op", + "//tensorflow/lite/experimental/litert/cc:litert_tensor", + "//tensorflow/lite/experimental/litert/core:graph_tools", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "div_op_legalization", + srcs = ["div_op_legalization.cc"], + hdrs = ["div_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/cc:litert_op", + "//tensorflow/lite/experimental/litert/cc:litert_tensor", + "//tensorflow/lite/experimental/litert/core:graph_tools", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "mul_op_legalization", + srcs = ["mul_op_legalization.cc"], + hdrs = ["mul_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/cc:litert_op", + "//tensorflow/lite/experimental/litert/cc:litert_tensor", + "//tensorflow/lite/experimental/litert/core:graph_tools", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "reshape_op_legalization", + srcs = ["reshape_op_legalization.cc"], + hdrs = ["reshape_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/cc:litert_op", + "//tensorflow/lite/experimental/litert/cc:litert_tensor", + "//tensorflow/lite/experimental/litert/core:graph_tools", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "rsqrt_op_legalization", + srcs = ["rsqrt_op_legalization.cc"], + hdrs = ["rsqrt_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/cc:litert_op", + "//tensorflow/lite/experimental/litert/cc:litert_tensor", + "//tensorflow/lite/experimental/litert/core:graph_tools", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "slice_op_legalization", + srcs = ["slice_op_legalization.cc"], + hdrs = ["slice_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/cc:litert_op", + "//tensorflow/lite/experimental/litert/cc:litert_tensor", + "//tensorflow/lite/experimental/litert/core:graph_tools", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "sub_op_legalization", + srcs = ["sub_op_legalization.cc"], + hdrs = ["sub_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/cc:litert_op", + "//tensorflow/lite/experimental/litert/cc:litert_tensor", + "//tensorflow/lite/experimental/litert/core:graph_tools", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "tanh_op_legalization", + srcs = ["tanh_op_legalization.cc"], + hdrs = ["tanh_op_legalization.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + ":util", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/cc:litert_op", + "//tensorflow/lite/experimental/litert/cc:litert_tensor", + "//tensorflow/lite/experimental/litert/core:graph_tools", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) + +litert_lib( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], + deps = [ + ":legalization", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_logging", + "//tensorflow/lite/experimental/litert/cc:litert_cc_api", + "//tensorflow/lite/experimental/litert/cc:litert_op", + "//tensorflow/lite/experimental/litert/cc:litert_tensor", + "//tensorflow/lite/experimental/litert/core:graph_tools", + "//tensorflow/lite/experimental/litert/tools:dump", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler:graph_mapper", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_op", + "//tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR:qnn_tensor", + ], +) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.cc new file mode 100644 index 00000000000000..3eb066fb433a22 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.cc @@ -0,0 +1,51 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/c/litert_support.h" +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnAddOpTypeName = "ElementWiseAdd"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kAddOpFmt = "add_%d"; + +LiteRtStatus AddOpLegalization::LegalizeOp(LiteRtOpManager& src, + Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflAdd) { + return kLiteRtStatusLegalizeNoMatch; + } + std::string op_name = absl::StrFormat(kAddOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(), + kDefaultQnnOpPackageName.data(), + kQnnAddOpTypeName.data(), dest)); + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper)); + LITERT_LOG(LITERT_INFO, "Legalized add op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.h new file mode 100644 index 00000000000000..7aded37477d65f --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_ADD_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_ADD_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class AddOpLegalization : public Legalization { + public: + AddOpLegalization() = default; + ~AddOpLegalization() = default; + using UniquePtr = std::unique_ptr; + static UniquePtr Create() { return std::make_unique(); } + + LiteRtStatus LegalizeOp(LiteRtOpManager& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. + uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_ADD_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.cc new file mode 100644 index 00000000000000..94973209c0e996 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.cc @@ -0,0 +1,52 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/c/litert_support.h" +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnBatchMatmulOpTypeName = "MatMul"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kBatchMatmulOpFmt = "batch_matmul_%d"; + +LiteRtStatus BatchMatmulOpLegalization::LegalizeOp(LiteRtOpManager& src, + Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflBatchMatmul) { + return kLiteRtStatusLegalizeNoMatch; + } + std::string op_name = absl::StrFormat(kBatchMatmulOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK( + SetOpInfo(op_name.c_str(), kDefaultQnnOpPackageName.data(), + kQnnBatchMatmulOpTypeName.data(), dest)); + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper)); + LITERT_LOG(LITERT_INFO, "Legalized batch_matmul op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.h new file mode 100644 index 00000000000000..dc376b4da38440 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.h @@ -0,0 +1,51 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_BATCH_MATMUL_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_BATCH_MATMUL_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class BatchMatmulOpLegalization : public Legalization { + public: + BatchMatmulOpLegalization() = default; + ~BatchMatmulOpLegalization() = default; + using UniquePtr = std::unique_ptr; + static UniquePtr Create() { + return std::make_unique(); + } + + LiteRtStatus LegalizeOp(LiteRtOpManager& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. + uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_BATCH_MATMUL_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.cc new file mode 100644 index 00000000000000..9434d693934ab4 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.cc @@ -0,0 +1,51 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/c/litert_support.h" +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnDivOpTypeName = "ElementWiseDivide"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kDivOpFmt = "div_%d"; + +LiteRtStatus DivOpLegalization::LegalizeOp(LiteRtOpManager& src, + Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflDiv) { + return kLiteRtStatusLegalizeNoMatch; + } + std::string op_name = absl::StrFormat(kDivOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(), + kDefaultQnnOpPackageName.data(), + kQnnDivOpTypeName.data(), dest)); + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper)); + LITERT_LOG(LITERT_INFO, "Legalized div op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.h new file mode 100644 index 00000000000000..2240d2a327a05d --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_DIV_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_DIV_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class DivOpLegalization : public Legalization { + public: + DivOpLegalization() = default; + ~DivOpLegalization() = default; + using UniquePtr = std::unique_ptr; + static UniquePtr Create() { return std::make_unique(); } + + LiteRtStatus LegalizeOp(LiteRtOpManager& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. + uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_DIV_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h new file mode 100644 index 00000000000000..000af7661e0667 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h @@ -0,0 +1,50 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LEGALIZATION_H_ + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" + +#define STRINGIFY(x) #x +#define QNN_OP_NAME(prefix) STRINGIFY(prefix##__COUNTER__) + +namespace litert::qnn { + +class Legalization { + public: + Legalization() = default; + virtual ~Legalization() = default; + + virtual LiteRtStatus LegalizeOp(LiteRtOpManager& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) = 0; + + // Sets the op name, package name, and type. + // Note: All argument strings can't be de-allocated until the op has been + // registered with the qnn api. i.e graphAddNode(). + inline LiteRtStatus SetOpInfo(const char* name, const char* op_package_name, + const char* op_type, Qnn_OpConfig_t& op) { + op.v1.name = name; + op.v1.packageName = op_package_name; + op.v1.typeName = op_type; + return kLiteRtStatusOk; + } +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.cc new file mode 100644 index 00000000000000..e3f8e08b35651a --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.cc @@ -0,0 +1,52 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/c/litert_support.h" +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnMulOpTypeName = "ElementWiseMultiply"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kMulOpFmt = "mul_%d"; + +LiteRtStatus MulOpLegalization::LegalizeOp(LiteRtOpManager& src, + Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflMul) { + return kLiteRtStatusLegalizeNoMatch; + } + std::string op_name = absl::StrFormat(kMulOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(), + kDefaultQnnOpPackageName.data(), + kQnnMulOpTypeName.data(), dest)); + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper)); + LITERT_LOG(LITERT_INFO, "Legalized mul op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.h new file mode 100644 index 00000000000000..ebd2d1f9e4a6a3 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_MUL_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_MUL_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class MulOpLegalization : public Legalization { + public: + MulOpLegalization() = default; + ~MulOpLegalization() = default; + using UniquePtr = std::unique_ptr; + static UniquePtr Create() { return std::make_unique(); } + + LiteRtStatus LegalizeOp(LiteRtOpManager& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. + uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_MUL_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.cc new file mode 100644 index 00000000000000..2be0631bc9b37c --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.cc @@ -0,0 +1,84 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/c/litert_support.h" +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/core/graph_tools.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnReshapeOpTypeName = "Reshape"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kReshapeOpFmt = "reshape_%d"; + +static constexpr int kReshapeOpInputSize = 1; +static constexpr int kReshapeOpOutputSize = 1; + +LiteRtStatus ReshapeOpLegalization::LegalizeOp(LiteRtOpManager& src, + Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflReshape) { + return kLiteRtStatusLegalizeNoMatch; + } + std::string op_name = absl::StrFormat(kReshapeOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(), + kDefaultQnnOpPackageName.data(), + kQnnReshapeOpTypeName.data(), dest)); + DumpLegalization(*src.Op()); + // Look up op input tensors in scope. + LITERT_ASSIGN_OR_RETURN_STATUS(auto op_ins, + ::graph_tools::GetOpIns(src.Op())); + LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, kReshapeOpInputSize, + QNN_TENSOR_INIT); + LITERT_RETURN_STATUS_IF_NOT_OK( + graph_mapper.LookupInScope(op_ins[0], qnn_op_ins[0])); + + // Legalize op outputs and update scope. + + LITERT_ASSIGN_OR_RETURN_STATUS(auto op_outs, + ::graph_tools::GetOpOuts(src.Op())); + LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, kReshapeOpOutputSize, + QNN_TENSOR_INIT); + LITERT_RETURN_STATUS_IF_NOT_OK( + graph_mapper.LegalizeAndRegister(op_outs[0], qnn_op_outs[0])); + LITERT_RETURN_STATUS_IF_NOT_OK( + graph_mapper.PushToScope(op_outs[0], qnn_op_outs[0])); + + dest.v1.numOfInputs = kReshapeOpInputSize; + dest.v1.inputTensors = qnn_op_ins; + + dest.v1.numOfOutputs = kReshapeOpOutputSize; + dest.v1.outputTensors = qnn_op_outs; + + LITERT_RETURN_STATUS_IF_QNN_NOT_OK( + graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest)); + + LITERT_LOG(LITERT_INFO, "Legalized reshape op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.h new file mode 100644 index 00000000000000..c980f12c8cbe34 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.h @@ -0,0 +1,51 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_RESHAPE_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_RESHAPE_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class ReshapeOpLegalization : public Legalization { + public: + ReshapeOpLegalization() = default; + ~ReshapeOpLegalization() = default; + using UniquePtr = std::unique_ptr; + static UniquePtr Create() { + return std::make_unique(); + } + + LiteRtStatus LegalizeOp(LiteRtOpManager& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. + uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_RESHAPE_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.cc new file mode 100644 index 00000000000000..56b3bd7571097b --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.cc @@ -0,0 +1,51 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/c/litert_support.h" +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnRsqrtOpTypeName = "ElementWiseRsqrt"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kRsqrtOpFmt = "rsqrt_%d"; + +LiteRtStatus RsqrtOpLegalization::LegalizeOp(LiteRtOpManager& src, + Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflRsqrt) { + return kLiteRtStatusLegalizeNoMatch; + } + std::string op_name = absl::StrFormat(kRsqrtOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(), + kDefaultQnnOpPackageName.data(), + kQnnRsqrtOpTypeName.data(), dest)); + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper)); + LITERT_LOG(LITERT_INFO, "Legalized rsqrt op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.h new file mode 100644 index 00000000000000..9c4a8eb2af562f --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_RSQRT_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_RSQRT_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class RsqrtOpLegalization : public Legalization { + public: + RsqrtOpLegalization() = default; + ~RsqrtOpLegalization() = default; + using UniquePtr = std::unique_ptr; + static UniquePtr Create() { return std::make_unique(); } + + LiteRtStatus LegalizeOp(LiteRtOpManager& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. + uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_RSQRT_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.cc new file mode 100644 index 00000000000000..784e14ef89c2d0 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.cc @@ -0,0 +1,149 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.h" + +#include +#include + +#include +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/c/litert_support.h" +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/cc/litert_tensor.h" +#include "tensorflow/lite/experimental/litert/core/graph_tools.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnSliceOpTypeName = "StridedSlice"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kSliceOpFmt = "slice_%d"; + +static constexpr int kSliceOpInputSize = 1; +static constexpr int kSliceOpOutputSize = 1; +static constexpr int kSliceOpParamSize = 1; +// QNN StridedSlice op packs "start", "end", and "stride" into a single tensor +// param "ranges". +static constexpr int kRangesParamArgSize = 3; +static constexpr int kRangesParamRank = 2; + +LiteRtStatus SliceOpLegalization::LegalizeOp(LiteRtOpManager& src, + Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflSlice) { + return kLiteRtStatusLegalizeNoMatch; + } + DumpLegalization(*src.Op()); + std::string op_name = absl::StrFormat(kSliceOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(), + kDefaultQnnOpPackageName.data(), + kQnnSliceOpTypeName.data(), dest)); + + // QNN strided slice op expects 1 input tensor. + LITERT_ASSIGN_OR_RETURN_STATUS(auto op_ins, + ::graph_tools::GetOpIns(src.Op())); + LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, kSliceOpInputSize, + QNN_TENSOR_INIT); + LITERT_RETURN_STATUS_IF_NOT_OK( + graph_mapper.LookupInScope(op_ins[0], qnn_op_ins[0])); + + // QNN strided slice op expects 1 output tensor. + LITERT_ASSIGN_OR_RETURN_STATUS(auto op_outs, + ::graph_tools::GetOpOuts(src.Op())); + LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, kSliceOpOutputSize, + QNN_TENSOR_INIT); + LITERT_RETURN_STATUS_IF_NOT_OK( + graph_mapper.LegalizeAndRegister(op_outs[0], qnn_op_outs[0])); + LITERT_RETURN_STATUS_IF_NOT_OK( + graph_mapper.PushToScope(op_outs[0], qnn_op_outs[0])); + + LiteRtTensorManager::Unique src_input_tensor; + LITERT_RETURN_STATUS_IF_NOT_OK( + LiteRtTensorManager::MakeFromTensor(op_ins[0], src_input_tensor)); + + // Prepare qnn strided slice parameters. + auto src_begin_indices = graph_tools::GetWeights(op_ins[1]).Value(); + auto src_size_indices = graph_tools::GetWeights(op_ins[2]).Value(); + + // Check if src_begin_indices and src_size_indices are weights tensors. + if (src_begin_indices.empty() || src_size_indices.empty()) { + return kLiteRtStatusErrorInvalidLegalization; + } + + LITERT_STACK_ARRAY(int32_t, range_tensor_data, + src_input_tensor->Rank() * kRangesParamArgSize, + /*init value*/ 0); + for (int i = 0; i < src_input_tensor->Rank(); ++i) { + // Copy begin, end, and stride values from src_begin_indices and + // src_size_indices to range_tensor_data. Stride is always 1. + range_tensor_data[i * kRangesParamArgSize] = src_begin_indices[i]; + range_tensor_data[i * kRangesParamArgSize + 1] = src_size_indices[i]; + range_tensor_data[i * kRangesParamArgSize + 2] = 1; + } + + Qnn_ClientBuffer_t range_tensor_client_buf = BuildDefaultClientBuffer(); + range_tensor_client_buf.data = range_tensor_data; + range_tensor_client_buf.dataSize = + src_input_tensor->Rank() * kRangesParamArgSize * sizeof(int32_t); + + // Construct the const tensor "ranges". + Qnn_Tensor_t range_tensor = BuildDefaultTensor(); + graph_mapper.AssignTensorName(range_tensor); + range_tensor.v2.dataType = QNN_DATATYPE_INT_32; + range_tensor.v2.type = QNN_TENSOR_TYPE_STATIC; + range_tensor.v2.rank = kRangesParamRank; + range_tensor.v2.dimensions = new uint32_t[kRangesParamRank]; + range_tensor.v2.dimensions[0] = src_input_tensor->Rank(); + range_tensor.v2.dimensions[1] = kRangesParamArgSize; + range_tensor.v2.memType = QNN_TENSORMEMTYPE_RAW; + range_tensor.v2.clientBuf = range_tensor_client_buf; + range_tensor.v2.isDynamicDimensions = nullptr; + LITERT_RETURN_STATUS_IF_QNN_NOT_OK( + graph_mapper.Qnn().Api()->tensorCreateGraphTensor(graph_mapper.QnnGraph(), + &range_tensor)); + + Qnn_Param_t range_param = BuildDefaultParam(); + range_param.paramType = QNN_PARAMTYPE_TENSOR; + range_param.name = "ranges"; + range_param.tensorParam = range_tensor; + + Qnn_Param_t strided_slice_params[] = {range_param}; + dest.v1.inputTensors = qnn_op_ins; + dest.v1.numOfInputs = kSliceOpInputSize; + dest.v1.outputTensors = qnn_op_outs; + dest.v1.numOfOutputs = kSliceOpOutputSize; + dest.v1.numOfParams = kSliceOpParamSize; + dest.v1.params = strided_slice_params; + + LITERT_RETURN_STATUS_IF_QNN_NOT_OK( + graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest)); + + LITERT_LOG(LITERT_INFO, "Legalized slice op", ""); + + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.h new file mode 100644 index 00000000000000..bbb983f6901038 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.h @@ -0,0 +1,47 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SLICE_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SLICE_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class SliceOpLegalization : public Legalization { + public: + SliceOpLegalization() = default; + ~SliceOpLegalization() = default; + using UniquePtr = std::unique_ptr; + static UniquePtr Create() { return std::make_unique(); } + + LiteRtStatus LegalizeOp(LiteRtOpManager& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SLICE_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.cc new file mode 100644 index 00000000000000..8c8a710d912c30 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.cc @@ -0,0 +1,51 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/c/litert_support.h" +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnSubOpTypeName = "ElementWiseSubtract"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kSubOpFmt = "sub_%d"; + +LiteRtStatus SubOpLegalization::LegalizeOp(LiteRtOpManager& src, + Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflSub) { + return kLiteRtStatusLegalizeNoMatch; + } + std::string op_name = absl::StrFormat(kSubOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(), + kDefaultQnnOpPackageName.data(), + kQnnSubOpTypeName.data(), dest)); + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper)); + LITERT_LOG(LITERT_INFO, "Legalized sub op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.h new file mode 100644 index 00000000000000..72275a3991b366 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SUB_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SUB_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class SubOpLegalization : public Legalization { + public: + SubOpLegalization() = default; + ~SubOpLegalization() = default; + using UniquePtr = std::unique_ptr; + static UniquePtr Create() { return std::make_unique(); } + + LiteRtStatus LegalizeOp(LiteRtOpManager& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. + uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_SUB_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.cc new file mode 100644 index 00000000000000..d9fffd7f00534f --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.cc @@ -0,0 +1,51 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/c/litert_support.h" +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +namespace litert::qnn { + +static constexpr absl::string_view kQnnTanhOpTypeName = "Tanh"; +static constexpr absl::string_view kDefaultQnnOpPackageName = "qti.aisw"; +static constexpr absl::string_view kTanhOpFmt = "tanh_%d"; + +LiteRtStatus TanhOpLegalization::LegalizeOp(LiteRtOpManager& src, + Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + if (src.Code() != kLiteRtOpCodeTflTanh) { + return kLiteRtStatusLegalizeNoMatch; + } + std::string op_name = absl::StrFormat(kTanhOpFmt, op_counter_++); + LITERT_RETURN_STATUS_IF_NOT_OK(SetOpInfo(op_name.c_str(), + kDefaultQnnOpPackageName.data(), + kQnnTanhOpTypeName.data(), dest)); + LITERT_RETURN_STATUS_IF_NOT_OK(LegalizeSimpleOp(src, dest, graph_mapper)); + LITERT_LOG(LITERT_INFO, "Legalized tanh op", ""); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.h new file mode 100644 index 00000000000000..f5746b23e52e5c --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.h @@ -0,0 +1,49 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_TANH_OP_LEGALIZATION_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_TANH_OP_LEGALIZATION_H_ + +#include +#include + +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" + +namespace litert::qnn { + +class TanhOpLegalization : public Legalization { + public: + TanhOpLegalization() = default; + ~TanhOpLegalization() = default; + using UniquePtr = std::unique_ptr; + static UniquePtr Create() { return std::make_unique(); } + + LiteRtStatus LegalizeOp(LiteRtOpManager& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + + private: + // Counter to ensure unique op names. + uint32_t op_counter_ = 0; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_TANH_OP_LEGALIZATION_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.cc new file mode 100644 index 00000000000000..4fa1df5b6cce30 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.cc @@ -0,0 +1,87 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h" + +#include + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/core/graph_tools.h" +#include "tensorflow/lite/experimental/litert/tools/dump.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" + +namespace litert::qnn { + +using ::litert::internal::Dump; +using ::litert::internal::DumpOptions; + +// Dump source Op details. +void DumpLegalization(LiteRtOpT& op) { + std::ostringstream dump; + Dump(op, dump); + DumpOptions(op, dump); + std::string s = dump.str(); + LITERT_LOG(LITERT_INFO, "%s", s.data()); +} + +LiteRtStatus LegalizeSimpleOp(LiteRtOpManager& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper) { + DumpLegalization(*src.Op()); + // Look up op input tensors in scope. + LITERT_ASSIGN_OR_RETURN_STATUS(auto op_ins, + ::graph_tools::GetOpIns(src.Op())); + LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_ins, op_ins.size(), QNN_TENSOR_INIT); + + Qnn_Tensor_t* cur_qnn_op_in = qnn_op_ins; + for (auto op_in : op_ins) { + LITERT_RETURN_STATUS_IF_NOT_OK( + graph_mapper.LookupInScope(op_in, *cur_qnn_op_in)); + ++cur_qnn_op_in; + } + + // Legalize op outputs and update scope. + + LITERT_ASSIGN_OR_RETURN_STATUS(auto op_outs, + ::graph_tools::GetOpOuts(src.Op())); + LITERT_STACK_ARRAY(Qnn_Tensor_t, qnn_op_outs, op_outs.size(), + QNN_TENSOR_INIT); + + Qnn_Tensor_t* cur_qnn_op_out = qnn_op_outs; + for (auto op_out : op_outs) { + LITERT_RETURN_STATUS_IF_NOT_OK( + graph_mapper.LegalizeAndRegister(op_out, *cur_qnn_op_out)); + LITERT_RETURN_STATUS_IF_NOT_OK( + graph_mapper.PushToScope(op_out, *cur_qnn_op_out)); + ++cur_qnn_op_out; + } + dest.v1.numOfInputs = op_ins.size(); + dest.v1.inputTensors = qnn_op_ins; + + dest.v1.numOfOutputs = op_outs.size(); + dest.v1.outputTensors = qnn_op_outs; + + LITERT_RETURN_STATUS_IF_QNN_NOT_OK( + graph_mapper.Qnn().Api()->graphAddNode(graph_mapper.QnnGraph(), dest)); + + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h new file mode 100644 index 00000000000000..5a142f3ac1f638 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/util.h @@ -0,0 +1,39 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_UTIL_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_UTIL_H_ + +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" + +namespace litert::qnn { + +// Use this function to legalize a LiteRtOp to a Qnn Op when: +// 1. Source input/output tensor and destination input/ouptut tensor are 1 : 1 +// mapped +// 2. Assigning params to destination OP does not depending on input tensor of +// source OP. +LiteRtStatus LegalizeSimpleOp(LiteRtOpManager& src, Qnn_OpConfig_t& dest, + GraphMapper& graph_mapper); + +// Dump source Op details. +void DumpLegalization(LiteRtOpT& op); + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_LEGALIZATIONS_UTIL_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin.cc new file mode 100644 index 00000000000000..bb8153303f38b3 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin.cc @@ -0,0 +1,218 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include +#include +#include +#include + +#include "third_party/qairt/latest/include/QNN/HTP/QnnHtpDevice.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/c/litert_support.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/core/graph_tools.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" + +using ::litert::qnn::QnnManager; + +// +// Configurations +// + +namespace { + +constexpr char kPluginManufacturer[] = "Qualcomm"; + +constexpr std::pair kPluginSocModels[] = { + {"V68", QNN_HTP_DEVICE_ARCH_V68}, + {"V69", QNN_HTP_DEVICE_ARCH_V69}, + {"V73", QNN_HTP_DEVICE_ARCH_V73}, + {"V75", QNN_HTP_DEVICE_ARCH_V75}, +}; + +constexpr auto kNumPluginSocModels = + sizeof(kPluginSocModels) / sizeof(kPluginSocModels[0]); + +std::optional FindSocModel( + absl::string_view soc_model_name) { + std::optional soc_model; + for (auto i = 0; i < kNumPluginSocModels; ++i) { + if (soc_model_name == kPluginSocModels[i].first) { + soc_model = kPluginSocModels[i].second; + break; + } + } + return soc_model; +} + +} // namespace + +const char* LiteRtPluginSocManufacturer() { return kPluginManufacturer; } + +LiteRtParamIndex LiteRtPluginNumSupportedSocModels( + LiteRtCompilerPlugin compiler_plugin) { + return kNumPluginSocModels; +} + +LiteRtStatus LiteRtPluginGetSupportedSocModel( + LiteRtCompilerPlugin compiler_plugin, LiteRtParamIndex soc_model_idx, + const char** soc_model_name) { + if (soc_model_idx < 0 || soc_model_idx >= kNumPluginSocModels) { + return kLiteRtStatusErrorInvalidArgument; + } + *soc_model_name = kPluginSocModels[soc_model_idx].first; + return kLiteRtStatusOk; +} + +// +// Compiled Result Definition +// + +struct LiteRtCompiledResultT { + std::vector context_bin; + std::vector graph_names; +}; + +LiteRtStatus LiteRtCompiledResultGetByteCode( + LiteRtCompiledResult compiled_result, const void** byte_code, + size_t* byte_code_size) { + *byte_code = compiled_result->context_bin.data(); + *byte_code_size = compiled_result->context_bin.size(); + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtCompiledResultGetCallInfo( + LiteRtCompiledResult compiled_result, LiteRtParamIndex call_idx, + const void** call_info, size_t* call_info_size) { + if (call_idx >= compiled_result->graph_names.size()) { + return kLiteRtStatusErrorIndexOOB; + } + + *call_info = compiled_result->graph_names.at(call_idx).data(); + *call_info_size = compiled_result->graph_names.at(call_idx).size(); + + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtCompiledResultGetNumCalls( + LiteRtCompiledResult compiled_result, LiteRtParamIndex* num_calls) { + *num_calls = compiled_result->graph_names.size(); + return kLiteRtStatusOk; +} + +void LiteRtCompiledResultDestroy(LiteRtCompiledResult compiled_result) { + delete compiled_result; +} + +// +// Plugin Definition +// + +// Plugins can hold state. +struct LiteRtCompilerPluginT {}; + +LiteRtStatus LiteRtPluginInit(LiteRtCompilerPlugin* compiler_plugin) { + auto* plugin = new LiteRtCompilerPluginT; + *compiler_plugin = plugin; + return kLiteRtStatusOk; +} + +void LiteRtPluginDestroy(LiteRtCompilerPlugin compiler_plugin) { + delete compiler_plugin; +} + +namespace { + +bool IsOpSupported(LiteRtOp op) { + using TyInfo = graph_tools::RankedTypeInfo; + + // NOTE: Currently we are demoing by just mapping simple f32 mul ops. + // In the limit this function withh want to leverage QNN SDK's getSuportedOps + // feature (along with our op/type mappings). + + static const TyInfo supported_op_type = {kLiteRtElementTypeFloat32, {2, 2}}; + return graph_tools::MatchOpType(op, {supported_op_type, supported_op_type}, + {supported_op_type}, kLiteRtOpCodeTflMul); +} + +} // namespace + +LiteRtStatus LiteRtPluginPartitionModel(LiteRtCompilerPlugin compiler_plugin, + LiteRtModel model, + LiteRtOpList selected_ops) { + LITERT_ASSIGN_OR_RETURN_STATUS(auto subgraph, + graph_tools::GetSubgraph(model)); + LITERT_ASSIGN_OR_RETURN_STATUS(auto ops, + graph_tools::GetSubgraphOps(subgraph)); + + for (auto op : ops) { + if (!IsOpSupported(op)) { + continue; + } + + LITERT_RETURN_STATUS_IF_NOT_OK(PushOp(selected_ops, op)); + } + + return kLiteRtStatusOk; +} + +LiteRtStatus LiteRtPluginCompile(LiteRtCompilerPlugin compiler_plugin, + const char* soc_model, + LiteRtSubgraphArray partitions, + LiteRtParamIndex num_partitions, + LiteRtCompiledResult* compiled_result) { + auto opt_soc_model = FindSocModel(soc_model); + + auto backend_configs = QnnManager::DefaultBackendConfigs(); + auto qnn_manager = QnnManager::Create( + backend_configs, /*shared_library_dir=*/std::nullopt, opt_soc_model); + if (!qnn_manager.ok()) { + LITERT_LOG(LITERT_ERROR, "%s", qnn_manager.status().message().data()); + return kLiteRtStatusErrorRuntimeFailure; + } + + auto context_configs = QnnManager::DefaultContextConfigs(); + auto context_handle = (*qnn_manager)->CreateContextHandle(context_configs); + if (!context_handle.ok()) { + LITERT_LOG(LITERT_ERROR, "%s", context_handle.status().message().data()); + return kLiteRtStatusErrorRuntimeFailure; + } + + auto result = std::make_unique(); + + // TODO: Support multiple partitions in QCC plugin compile. + LITERT_ENSURE_SUPPORTED(num_partitions, 1); + { + std::string& entry_point_name = result->graph_names.emplace_back(); + entry_point_name = "qnn_partition_0"; + LITERT_RETURN_STATUS_IF_NOT_OK(litert::qnn::ComposeGraph( + **qnn_manager, context_handle->get(), partitions[0], entry_point_name)); + } + + LITERT_RETURN_STATUS_IF_NOT_OK( + (*qnn_manager) + ->GenerateContextBinary(context_handle->get(), result->context_bin)); + + *compiled_result = result.release(); + + return kLiteRtStatusOk; +} diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin_test.cc new file mode 100644 index 00000000000000..40b65efa997394 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin_test.cc @@ -0,0 +1,153 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include +#include "absl/log/absl_check.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_op_code.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/core/graph_tools.h" +#include "tensorflow/lite/experimental/litert/core/model.h" +#include "tensorflow/lite/experimental/litert/test/common.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h" + +namespace { +static constexpr absl::string_view kOpTpl = "simple_%s_op.tflite"; +// clang-format off +const auto kSupportedOps = + testing::Values( + "add", + "div", + "mul", + "rsqrt", + "slice", + "sub", + "tanh", + "reshape", + "batch_matmul" + ); +// clang-format on + +UniqueLiteRtCompilerPlugin GetQnnPlugin() { + LiteRtCompilerPlugin qnn_plugin; + LITERT_CHECK_STATUS_OK(LiteRtPluginInit(&qnn_plugin)); + ABSL_CHECK_NE(qnn_plugin, nullptr); + return UniqueLiteRtCompilerPlugin(qnn_plugin); +} + +TEST(TestQnnPlugin, GetConfigInfo) { + EXPECT_STREQ(LiteRtPluginSocManufacturer(), "Qualcomm"); + + auto plugin = GetQnnPlugin(); + + ASSERT_GE(LiteRtPluginNumSupportedSocModels(plugin.get()), 1); + + const char* config_id; + LITERT_CHECK_STATUS_OK( + LiteRtPluginGetSupportedSocModel(plugin.get(), 0, &config_id)); + EXPECT_STREQ(config_id, "V68"); +} + +TEST(TestQnnPlugin, PartitionMulOps) { + auto plugin = GetQnnPlugin(); + auto model = litert::testing::LoadTestFileModel("one_mul.tflite"); + + LiteRtOpListT selected_op_list; + ASSERT_STATUS_OK( + LiteRtPluginPartitionModel(plugin.get(), model.get(), &selected_op_list)); + const auto selected_ops = selected_op_list.Vec(); + + ASSERT_EQ(selected_ops.size(), 1); + EXPECT_EQ(selected_ops[0]->op_code, kLiteRtOpCodeTflMul); +} + +TEST(TestQnnPlugin, CompileMulSubgraph) { + auto plugin = GetQnnPlugin(); + auto model = litert::testing::LoadTestFileModel("one_mul.tflite"); + + ASSERT_RESULT_OK_ASSIGN(auto subgraph, + ::graph_tools::GetSubgraph(model.get())); + + LiteRtCompiledResult compiled; + ASSERT_STATUS_OK( + LiteRtPluginCompile(plugin.get(), "V75", &subgraph, 1, &compiled)); + + const void* byte_code; + size_t byte_code_size; + + ASSERT_STATUS_OK( + LiteRtCompiledResultGetByteCode(compiled, &byte_code, &byte_code_size)); + + std::string byte_code_string(reinterpret_cast(byte_code), + byte_code_size); + ASSERT_FALSE(byte_code_string.empty()); + + const void* op_data; + size_t op_data_size; + + ASSERT_STATUS_OK( + LiteRtCompiledResultGetCallInfo(compiled, 0, &op_data, &op_data_size)); + + std::string op_data_string(reinterpret_cast(op_data), + op_data_size); + ASSERT_EQ("qnn_partition_0", op_data_string); + + LiteRtCompiledResultDestroy(compiled); +} + +class QnnPluginOpCompatibilityTest + : public ::testing::TestWithParam {}; + +TEST_P(QnnPluginOpCompatibilityTest, SupportedOpsTest) { + auto plugin = GetQnnPlugin(); + auto model = + litert::testing::LoadTestFileModel(absl::StrFormat(kOpTpl, GetParam())); + + ASSERT_RESULT_OK_ASSIGN(auto subgraph, + ::graph_tools::GetSubgraph(model.get())); + + LiteRtCompiledResult compiled; + ASSERT_STATUS_OK( + LiteRtPluginCompile(plugin.get(), "V75", &subgraph, 1, &compiled)); + + const void* byte_code; + size_t byte_code_size; + + ASSERT_STATUS_OK( + LiteRtCompiledResultGetByteCode(compiled, &byte_code, &byte_code_size)); + + std::string byte_code_string(reinterpret_cast(byte_code), + byte_code_size); + ASSERT_FALSE(byte_code_string.empty()); + + const void* op_data; + size_t op_data_size; + + ASSERT_STATUS_OK( + LiteRtCompiledResultGetCallInfo(compiled, 0, &op_data, &op_data_size)); + + std::string op_data_string(reinterpret_cast(op_data), + op_data_size); + ASSERT_EQ("qnn_partition_0", op_data_string); + + LiteRtCompiledResultDestroy(compiled); +} + +INSTANTIATE_TEST_SUITE_P(SupportedOpsTest, QnnPluginOpCompatibilityTest, + kSupportedOps); + +} // namespace diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.cc new file mode 100644 index 00000000000000..bd891d1c6e04b1 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.cc @@ -0,0 +1,157 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.h" + +#include +#include + +#include +#include + +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnCommon.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_support.h" +#include "tensorflow/lite/experimental/litert/cc/litert_op.h" +#include "tensorflow/lite/experimental/litert/cc/litert_support.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_op.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/IR/qnn_tensor.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/graph_mapper.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/add_op_legalization.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/batch_matmul_op_legalization.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/div_op_legalization.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/legalization.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/mul_op_legalization.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/reshape_op_legalization.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/rsqrt_op_legalization.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/slice_op_legalization.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/sub_op_legalization.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/legalizations/tanh_op_legalization.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" + +namespace litert::qnn { + +namespace { + +LiteRtStatus RegisterAllLegalizations( + std::vector>& legalizations) { + legalizations.push_back(MulOpLegalization::Create()); + legalizations.push_back(BatchMatmulOpLegalization::Create()); + legalizations.push_back(SliceOpLegalization::Create()); + legalizations.push_back(AddOpLegalization::Create()); + legalizations.push_back(DivOpLegalization::Create()); + legalizations.push_back(RsqrtOpLegalization::Create()); + legalizations.push_back(TanhOpLegalization::Create()); + legalizations.push_back(SubOpLegalization::Create()); + legalizations.push_back(ReshapeOpLegalization::Create()); + LITERT_LOG(LITERT_INFO, "Scheduling %lu legalizations", legalizations.size()); + return kLiteRtStatusOk; +} + +LiteRtStatus MapGraph(QnnManager& qnn, Qnn_ContextHandle_t context_handle, + LiteRtSubgraph subgraph, + absl::string_view qnn_graph_name) { + // Register legalizations. + std::vector> legalizations; + LITERT_RETURN_STATUS_IF_NOT_OK(RegisterAllLegalizations(legalizations)); + + GraphMapper graph_mapper(subgraph, qnn, context_handle); + LITERT_RETURN_STATUS_IF_NOT_OK(graph_mapper.ParseLiteRtSubgraph()); + LITERT_RETURN_STATUS_IF_NOT_OK(graph_mapper.IsLiteRtSubgraphSupported()); + LITERT_RETURN_STATUS_IF_NOT_OK(graph_mapper.InitQnnGraph(qnn_graph_name)); + + // + // Legalize subgraph inputs and update tensors in scope + // + + for (auto subgraph_input : graph_mapper.LiteRtSubgraphInputs()) { + Qnn_Tensor_t qnn_subgraph_input = BuildInputTensor(); + + LITERT_RETURN_STATUS_IF_NOT_OK( + graph_mapper.LegalizeAndRegister(subgraph_input, qnn_subgraph_input)); + + LITERT_RETURN_STATUS_IF_NOT_OK( + graph_mapper.PushToScope(subgraph_input, qnn_subgraph_input)); + } + // + // Topologically traverse graph, legalizing and updating tensors in scope + // + + // Use simple traversal for now. + // TODO: Drive traversal here. + for (auto op : graph_mapper.LiteRtSubgraphOps()) { + Qnn_OpConfig_t qnn_op = BuildDefaultOp(); + LiteRtOpManager::Unique op_manager; + LITERT_RETURN_STATUS_IF_NOT_OK(LiteRtOpManager::MakeFromOp(op, op_manager)); + for (auto it = legalizations.begin(); it != legalizations.end(); ++it) { + LITERT_RETURN_STATUS_IF_NOT_OK_OR_NOT_MATCHED( + (*it)->LegalizeOp(*op_manager, qnn_op, graph_mapper)); + } + } + + LITERT_RETURN_STATUS_IF_QNN_NOT_OK(graph_mapper.Finalize()); + + return kLiteRtStatusOk; +} + +} // namespace + +//===----------------------------------------------------------------------===// +// +// [WIP] LiteRT SUBGRAPH -> QNN GRAPH +// +// Core driver for IR translation. Traverses LiteRt Subgraph, iteratively +// "legalizing" (mapping) LiteRt entities to their QNN counterpart. +// +// APPROACH: +// +// Currently demoing by just handling a simple case where there is one +// partitions and the partitions is as follows: +// +// func(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) +// %0 = tfl.mul(%arg0, %arg1) +// return %0 +// +// To support the general case we will need a driver loop that either +// traverses input recursively through edges or just iterates topologically. +// Currently we just have only implemented n=1. +// +// The algorithm is pretty straightforward: +// * Store mapping between already evaluated LiteRtTensors and their +// newly constructed Qnn Tensor counterpart. +// * Look up QNN Tensors when setting QNN Op inputs. +// * Add new QNN Tensor when setting QNN Op outputs. +// +// NOTES ON QNN API: +// +// After QNN Tensors are registered in the context, they need only +// be stored as their ID. QNN Tensor and "id" : uint32_t are used +// interchangeably. +// +//===----------------------------------------------------------------------===// + +LiteRtStatus ComposeGraph(QnnManager& qnn, Qnn_ContextHandle_t context_handle, + LiteRtSubgraph subgraph, + absl::string_view qnn_graph_name) { + LITERT_RETURN_STATUS_IF_NOT_OK( + MapGraph(qnn, context_handle, subgraph, qnn_graph_name)); + return kLiteRtStatusOk; +} + +} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.h new file mode 100644 index 00000000000000..db978ba440427b --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compose_graph.h @@ -0,0 +1,33 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_QNN_COMPOSE_GRAPH_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_QNN_COMPOSE_GRAPH_H_ + +#include "absl/strings/string_view.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" + +namespace litert::qnn { + +// Composes a new QNN Graph from given LiteRt Graph. Qnn Graph is written to +// context behind "qnn". Uses given graph_name to name entry point. +LiteRtStatus ComposeGraph(QnnManager& qnn, Qnn_ContextHandle_t context_handle, + LiteRtSubgraph subgraph, + absl::string_view qnn_graph_name); + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_COMPILER_QNN_COMPOSE_GRAPH_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.cc new file mode 100644 index 00000000000000..d9d7febc8cae93 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.cc @@ -0,0 +1,179 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.h" + +#include +#include +#include + +#include "absl/log/absl_log.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "third_party/qairt/latest/include/QNN/QnnCommon.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "third_party/qairt/latest/include/QNN/System/QnnSystemContext.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.h" + +namespace litert { +namespace qnn { + +namespace { + +absl::Status InsertQnnTensors(int num_qnn_tensors, Qnn_Tensor_t* qnn_tensors, + std::vector* tensors) { + tensors->clear(); + tensors->reserve(num_qnn_tensors); + for (auto i = 0; i < num_qnn_tensors; ++i) { + auto tensor = QnnTensor::Create(qnn_tensors[i]); + if (!tensor.ok()) { + return tensor.status(); + } + tensors->push_back(std::move(*tensor)); + } + return {}; +} + +absl::Status InsertQnnGraphInfos(int num_qnn_graph_infos, + QnnSystemContext_GraphInfo_t* qnn_graph_infos, + std::vector* graphs) { + graphs->clear(); + graphs->reserve(num_qnn_graph_infos); + for (auto i = 0; i < num_qnn_graph_infos; ++i) { + auto graph = GraphInfo::Create(qnn_graph_infos[i]); + if (!graph.ok()) { + return graph.status(); + } + graphs->push_back(std::move(*graph)); + } + + return {}; +} + +} // namespace + +absl::StatusOr GraphInfo::Create( + const QnnSystemContext_GraphInfo_t& graph_info) { + GraphInfo info; + auto status = info.Init(graph_info); + if (status.ok()) { + return info; + } else { + return status; + } +} + +absl::Status GraphInfo::Init(const QnnSystemContext_GraphInfo_t& graph_info) { + if (graph_info.version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) { + const auto& graph_info_ = graph_info.graphInfoV1; + name_ = graph_info_.graphName; + if (auto status = InsertQnnTensors(graph_info_.numGraphInputs, + graph_info_.graphInputs, &inputs_); + !status.ok()) { + return status; + } + if (auto status = InsertQnnTensors(graph_info_.numGraphOutputs, + graph_info_.graphOutputs, &outputs_); + !status.ok()) { + return status; + } + + } else if (graph_info.version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_2) { + const auto& graph_info_ = graph_info.graphInfoV2; + name_ = graph_info_.graphName; + if (auto status = InsertQnnTensors(graph_info_.numGraphInputs, + graph_info_.graphInputs, &inputs_); + !status.ok()) { + return status; + } + if (auto status = InsertQnnTensors(graph_info_.numGraphOutputs, + graph_info_.graphOutputs, &outputs_); + !status.ok()) { + return status; + } + } + + return {}; +} + +absl::Status ContextBinaryInfo::Init( + const QnnSystemContext_BinaryInfo_t& binary_info) { + if (binary_info.version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) { + const auto& context_binary_info = binary_info.contextBinaryInfoV1; + if (auto status = InsertQnnTensors(context_binary_info.numContextTensors, + context_binary_info.contextTensors, + &context_tensors_); + !status.ok()) { + return status; + } + if (auto status = InsertQnnGraphInfos(context_binary_info.numGraphs, + context_binary_info.graphs, &graphs_); + !status.ok()) { + return status; + } + + } else if (binary_info.version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) { + const auto& context_binary_info = binary_info.contextBinaryInfoV1; + if (auto status = InsertQnnTensors(context_binary_info.numContextTensors, + context_binary_info.contextTensors, + &context_tensors_); + !status.ok()) { + return status; + } + if (auto status = InsertQnnGraphInfos(context_binary_info.numGraphs, + context_binary_info.graphs, &graphs_); + !status.ok()) { + return status; + } + } + + return {}; +} + +absl::StatusOr ContextBinaryInfo::Create( + QnnManager& qnn, const void* exec_bytecode_ptr, size_t exec_bytecode_size) { + auto system_context_handle = qnn.CreateSystemContextHandle(); + if (!system_context_handle.ok()) { + return system_context_handle.status(); + } + + const QnnSystemContext_BinaryInfo_t* binary_info = nullptr; + Qnn_ContextBinarySize_t binary_info_size = 0; + if (auto status = qnn.SystemApi()->systemContextGetBinaryInfo( + system_context_handle->get(), const_cast(exec_bytecode_ptr), + exec_bytecode_size, &binary_info, &binary_info_size); + status != QNN_SUCCESS) { + ABSL_LOG(ERROR) << "Failed to get context binary info: " << status; + return absl::InternalError("Failed to get context binary info"); + } + + if (!binary_info) { + ABSL_LOG(ERROR) << "Null binary info"; + return absl::InternalError("Null binary info"); + } + + ContextBinaryInfo info; + auto status = info.Init(*binary_info); + + if (status.ok()) { + return info; + } else { + return status; + } +} + +} // namespace qnn +} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.h new file mode 100644 index 00000000000000..cd3ac6af29b520 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.h @@ -0,0 +1,67 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CONTEXT_BINARY_INFO_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CONTEXT_BINARY_INFO_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnInterface.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.h" + +namespace litert { +namespace qnn { + +class GraphInfo { + public: + static absl::StatusOr Create( + const QnnSystemContext_GraphInfo_t& graph_info); + const std::string& Name() const { return name_; } + const std::vector& Inputs() const { return inputs_; } + const std::vector& Outputs() const { return outputs_; } + + private: + GraphInfo() = default; + absl::Status Init(const QnnSystemContext_GraphInfo_t& graph_info); + std::string name_; + std::vector inputs_; + std::vector outputs_; +}; + +class ContextBinaryInfo { + public: + static absl::StatusOr Create(QnnManager& qnn, + const void* exec_bytecode_ptr, + size_t exec_bytecode_size); + const std::vector& ContextTensors() const { + return context_tensors_; + } + const std::vector& Graphs() const { return graphs_; } + + private: + ContextBinaryInfo() = default; + absl::Status Init(const QnnSystemContext_BinaryInfo_t& binary_info); + std::vector context_tensors_; + std::vector graphs_; +}; + +} // namespace qnn +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_CONTEXT_BINARY_INFO_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/BUILD new file mode 100644 index 00000000000000..e85b4b2cfef2f5 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/BUILD @@ -0,0 +1,96 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], +) + +cc_library( + name = "dispatch_api", + srcs = [ + "dispatch_api.cc", + "litert_dispatch_device_context.cc", + "litert_dispatch_invocation_context.cc", + ], + hdrs = [ + "litert_dispatch_device_context.h", + "litert_dispatch_invocation_context.h", + "registry.h", + ], + linkopts = select({ + "//tensorflow:android": ["-landroid"], + "//conditions:default": [], + }), + tags = [ + # Don't build/test in OS until qnn is available. + "nobuilder", + ], + deps = [ + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", + "//tensorflow/lite/experimental/litert/core:utils", + "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:common", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:context_binary_info", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager", + ], +) + +cc_shared_library( + name = "dispatch_api_shared", + shared_lib_name = "libLiteRtDispatch.so", + visibility = ["//visibility:public"], + deps = [":dispatch_api"], +) + +cc_test( + name = "dispatch_api_qualcomm_test", + srcs = [ + "dispatch_api_qualcomm_test.cc", + ], + data = [ + ":dispatch_api_shared", + ], + linkopts = select({ + "//tensorflow:android": ["-landroid"], + "//conditions:default": [], + }), + linkstatic = 1, + tags = [ + "no-remote-exec", + "no_oss", + "notap", + ], + deps = [ + ":dispatch_api", + "//tensorflow/lite/experimental/litert/c:litert_c_api", + "//tensorflow/lite/experimental/litert/c:litert_tensor_buffer", + "//tensorflow/lite/experimental/litert/core/dispatch", + "//tensorflow/lite/experimental/litert/test:common", + "//tensorflow/lite/experimental/litert/test:simple_model", + "//tensorflow/lite/experimental/litert/vendors/c:litert_dispatch_c_api", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api.cc new file mode 100644 index 00000000000000..60f3629d80fa71 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api.cc @@ -0,0 +1,297 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "absl/log/absl_log.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_support.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch_api.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" + +namespace { + +using ::litert::qnn::QnnManager; + +static constexpr const int VERSION_MAJOR = 0; +static constexpr const int VERSION_MINOR = 1; +static constexpr const int VERSION_PATCH = 0; + +static std::unique_ptr TheQnnManager; + +QnnManager& Qnn() { return *TheQnnManager; } + +char BuildId[256]; + +// ///////////////////////////////////////////////////////////////////////////// +// Basic Execution API +// ///////////////////////////////////////////////////////////////////////////// + +const char* GetSharedLibraryDir(const LiteRtDispatchOption* options, + int num_options) { + for (auto i = 0; i < num_options; ++i) { + auto& option = options[i]; + if (!strcmp(option.name, kDispatchOptionSharedLibraryDir)) { + return option.value.str_value; + } + } + return nullptr; +} + +LiteRtStatus Initialize(const LiteRtDispatchOption* options, int num_options) { + auto* shared_library_dir = GetSharedLibraryDir(options, num_options); + std::optional shared_library_dir_opt = + shared_library_dir ? std::make_optional(std::string(shared_library_dir)) + : std::nullopt; + + auto configs = QnnManager::DefaultBackendConfigs(); + if (auto qnn_manager = QnnManager::Create(configs, shared_library_dir_opt); + !qnn_manager.ok()) { + ABSL_LOG(ERROR) << qnn_manager.status(); + return kLiteRtStatusErrorRuntimeFailure; + } else { + std::swap(TheQnnManager, *qnn_manager); + } + + Qnn_ApiVersion_t qnn_api_version; + if (auto status = Qnn().Api()->backendGetApiVersion(&qnn_api_version); + status != QNN_SUCCESS) { + ABSL_LOG(ERROR) << "Failed to get QNN API version: " << status; + return kLiteRtStatusErrorRuntimeFailure; + } + + const char* build_id; + if (auto status = Qnn().Api()->backendGetBuildId(&build_id); + status != QNN_SUCCESS) { + ABSL_LOG(ERROR) << "Failed to get QNN build ID: " << status; + return kLiteRtStatusErrorRuntimeFailure; + } + + snprintf(BuildId, sizeof(BuildId), + "Qualcomm Dispatch API version %d.%d.%d, QNN API version %d.%d.%d, " + "build id: %s", + VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH, + qnn_api_version.coreApiVersion.major, + qnn_api_version.coreApiVersion.minor, + qnn_api_version.coreApiVersion.patch, build_id); + BuildId[sizeof(BuildId) - 1] = 0; + + return kLiteRtStatusOk; +} + +LiteRtStatus GetVendorId(const char** vendor_id) { + *vendor_id = "Qualcomm"; + return kLiteRtStatusOk; +} + +LiteRtStatus GetBuildId(const char** build_id) { + *build_id = BuildId; + return kLiteRtStatusOk; +} + +LiteRtStatus GetCapabilities(int* capabilities) { + *capabilities = kLiteRtDispatchCapabilitiesBasic; + return kLiteRtStatusOk; +} + +LiteRtStatus DeviceContextCreate(LiteRtDispatchDeviceContext* device_context) { + if (auto status_or = LiteRtDispatchDeviceContextT::Create(Qnn()); + status_or.ok()) { + *device_context = status_or->release(); + return kLiteRtStatusOk; + } else { + ABSL_LOG(ERROR) << "Failed to create device context: " + << status_or.status(); + return kLiteRtStatusErrorRuntimeFailure; + } +} + +LiteRtStatus DeviceContextDestroy(LiteRtDispatchDeviceContext device_context) { + delete device_context; + return kLiteRtStatusOk; +} + +LiteRtStatus GetInputRequirements( + LiteRtDispatchInvocationContext invocation_context, int input_index, + const LiteRtRankedTensorType* tensor_type, + LiteRtTensorBufferRequirements* tensor_buffer_requirements) { + if (auto requirements = + invocation_context->GetInputRequirements(input_index, *tensor_type); + requirements.ok()) { + *tensor_buffer_requirements = *requirements; + return kLiteRtStatusOk; + } else { + ABSL_LOG(ERROR) << "Failed to get tensor buffer requirements: " + << requirements.status(); + return kLiteRtStatusErrorRuntimeFailure; + } +} + +LiteRtStatus GetOutputRequirements( + LiteRtDispatchInvocationContext invocation_context, int output_index, + const LiteRtRankedTensorType* tensor_type, + LiteRtTensorBufferRequirements* tensor_buffer_requirements) { + if (auto requirements = + invocation_context->GetOutputRequirements(output_index, *tensor_type); + requirements.ok()) { + *tensor_buffer_requirements = *requirements; + return kLiteRtStatusOk; + } else { + ABSL_LOG(ERROR) << "Failed to get tensor buffer requirements: " + << requirements.status(); + return kLiteRtStatusErrorRuntimeFailure; + } +} + +LiteRtStatus RegisterTensorBuffer( + LiteRtDispatchDeviceContext device_context, LiteRtTensorBuffer buffer, + LiteRtTensorBufferHandle* tensor_buffer_handle) { + if (auto status = device_context->RegisterTensorBuffer(buffer); + !status.ok()) { + ABSL_LOG(ERROR) << "Failed to register buffer: " << status; + return kLiteRtStatusErrorRuntimeFailure; + } else { + *tensor_buffer_handle = *status; + return kLiteRtStatusOk; + } +} + +LiteRtStatus UnregisterTensorBuffer(LiteRtDispatchDeviceContext device_context, + LiteRtTensorBufferHandle handle) { + if (auto status = device_context->UnregisterTensorBuffer(handle); + !status.ok()) { + ABSL_LOG(ERROR) << "Failed to unregister buffer: " << status; + return kLiteRtStatusErrorRuntimeFailure; + } else { + return kLiteRtStatusOk; + } +} + +LiteRtStatus InvocationContextCreate( + LiteRtDispatchDeviceContext device_context, + LiteRtDispatchExecutableType exec_type, const void* exec_bytecode_ptr, + size_t exec_bytecode_size, const char* function_name, int num_inputs, + int num_outputs, LiteRtDispatchInvocationContext* invocation_context) { + auto context = LiteRtDispatchInvocationContextT::Create( + Qnn(), *device_context, exec_bytecode_ptr, exec_bytecode_size, + function_name); + if (!context.ok()) { + ABSL_LOG(ERROR) << "Failed to create context from context binary: " + << context.status(); + return kLiteRtStatusErrorRuntimeFailure; + } + *invocation_context = context->release(); + device_context->SetInvocationContext(*invocation_context); + return kLiteRtStatusOk; +} + +LiteRtStatus InvocationContextDestroy( + LiteRtDispatchInvocationContext invocation_context) { + delete invocation_context; + return kLiteRtStatusOk; +} + +LiteRtStatus AttachInput(LiteRtDispatchInvocationContext invocation_context, + int graph_input_index, + LiteRtTensorBufferHandle tensor_buffer_handle) { + if (auto status = invocation_context->AttachInput(graph_input_index, + tensor_buffer_handle); + !status.ok()) { + ABSL_LOG(ERROR) << "Failed to attach input buffer: " << status; + return kLiteRtStatusErrorRuntimeFailure; + } + return kLiteRtStatusOk; +} + +LiteRtStatus AttachOutput(LiteRtDispatchInvocationContext invocation_context, + int graph_output_index, + LiteRtTensorBufferHandle tensor_buffer_handle) { + if (auto status = invocation_context->AttachOutput(graph_output_index, + tensor_buffer_handle); + !status.ok()) { + ABSL_LOG(ERROR) << "Failed to attach output buffer: " << status; + return kLiteRtStatusErrorRuntimeFailure; + } + return kLiteRtStatusOk; +} + +LiteRtStatus DetachInput(LiteRtDispatchInvocationContext invocation_context, + int graph_input_index, + LiteRtTensorBufferHandle tensor_buffer_handle) { + // Nothing to do here. + return kLiteRtStatusOk; +} + +LiteRtStatus DetachOutput(LiteRtDispatchInvocationContext invocation_context, + int graph_output_index, + LiteRtTensorBufferHandle tensor_buffer_handle) { + // Nothing to do here. + return kLiteRtStatusOk; +} + +LiteRtStatus Invoke(LiteRtDispatchInvocationContext invocation_context) { + if (auto status = invocation_context->Execute(); !status.ok()) { + ABSL_LOG(ERROR) << "Failed to execute invocation context: " << status; + return kLiteRtStatusErrorRuntimeFailure; + } + return kLiteRtStatusOk; +} + +// ///////////////////////////////////////////////////////////////////////////// + +LiteRtDispatchInterface TheInterface = { + /*.initialize=*/Initialize, + /*.get_vendor_id=*/GetVendorId, + /*.get_build_id=*/GetBuildId, + /*.get_capabilities=*/GetCapabilities, + /*.device_context_create=*/DeviceContextCreate, + /*.device_context_destroy=*/DeviceContextDestroy, + /*.get_input_requirements=*/GetInputRequirements, + /*.get_output_requirements=*/GetOutputRequirements, + /*.register_tensor_buffer=*/RegisterTensorBuffer, + /*.unregister_tensor_buffer=*/UnregisterTensorBuffer, + /*.invocation_context_create=*/InvocationContextCreate, + /*.invocation_context_destroy=*/InvocationContextDestroy, + /*.attach_input=*/AttachInput, + /*.attach_output=*/AttachOutput, + /*.detach_input=*/DetachInput, + /*.detach_output=*/DetachOutput, + /*.invoke=*/Invoke, +}; + +LiteRtDispatchApi TheApi = { + /*.version=*/{/*.major=*/VERSION_MAJOR, + /*.minor=*/VERSION_MINOR, + /*.patch=*/VERSION_PATCH}, + /*.interface=*/&TheInterface, + /*.async_interface=*/nullptr, + /*.graph_interface=*/nullptr, +}; + +} // namespace + +LiteRtStatus LiteRtDispatchGetApi(LiteRtDispatchApi* api) { + *api = TheApi; + return kLiteRtStatusOk; +} diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api_qualcomm_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api_qualcomm_test.cc new file mode 100644 index 00000000000000..654528ce6e1e88 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/dispatch_api_qualcomm_test.cc @@ -0,0 +1,532 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include +#include +#include "absl/log/absl_log.h" +#include "absl/log/log.h" +#include "absl/types/span.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" +#include "tensorflow/lite/experimental/litert/test/common.h" +#include "tensorflow/lite/experimental/litert/test/testdata/simple_model_test_vectors.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" + +using ::testing::Pointwise; + +TEST(Qualcomm, DispatchApiWithFastRpc) { +#if !defined(__ANDROID__) + GTEST_SKIP() + << "This test is specific to Android devices with a Qualcomm NPU"; +#endif + + EXPECT_EQ(LiteRtDispatchInitialize(/*options=*/nullptr, /*num_options=*/0), + kLiteRtStatusOk); + + const char* vendor_id; + EXPECT_EQ(LiteRtDispatchGetVendorId(&vendor_id), kLiteRtStatusOk); + ABSL_LOG(INFO) << "vendor_id: " << vendor_id; + + const char* build_id; + EXPECT_EQ(LiteRtDispatchGetBuildId(&build_id), kLiteRtStatusOk); + ABSL_LOG(INFO) << "build_id: " << build_id; + + LiteRtDispatchApiVersion api_version; + EXPECT_EQ(LiteRtDispatchGetApiVersion(&api_version), kLiteRtStatusOk); + ABSL_LOG(INFO) << "api_version: " << api_version.major << "." + << api_version.minor << "." << api_version.patch; + + int capabilities; + EXPECT_EQ(LiteRtDispatchGetCapabilities(&capabilities), kLiteRtStatusOk); + ABSL_LOG(INFO) << "capabilities: " << capabilities; + + LiteRtDispatchDeviceContext device_context = nullptr; + EXPECT_EQ(LiteRtDispatchDeviceContextCreate(&device_context), + kLiteRtStatusOk); + ABSL_LOG(INFO) << "device_context: " << device_context; + + auto model_file_name = kQualcommModelFileName; + auto model = litert::testing::LoadBinaryFile(model_file_name); + EXPECT_TRUE(model.ok()); + ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->size() + << " bytes"; + + // /////////////////////////////////////////////////////////////////////////// + // Set up an invocation context for a given model. + // /////////////////////////////////////////////////////////////////////////// + + LiteRtDispatchInvocationContext invocation_context = nullptr; + EXPECT_EQ(LiteRtDispatchInvocationContextCreate( + device_context, kLiteRtDispatchExecutableTypeMlModel, + model->data(), model->size(), /*function_name=*/"simple", + /*num_inputs=*/2, /*num_outputs=*/1, &invocation_context), + kLiteRtStatusOk); + ABSL_LOG(INFO) << "Invocation context: " << invocation_context; + + // /////////////////////////////////////////////////////////////////////////// + // Determine tensor buffer requirements. + // /////////////////////////////////////////////////////////////////////////// + + int num_tensor_buffer_types; + LiteRtTensorBufferRequirements input_0_tensor_buffer_requirements; + EXPECT_EQ(LiteRtDispatchGetInputRequirements( + invocation_context, /*input_index=*/0, &kInput0TensorType, + &input_0_tensor_buffer_requirements), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtGetTensorBufferRequirementsNumSupportedTensorBufferTypes( + input_0_tensor_buffer_requirements, &num_tensor_buffer_types), + kLiteRtStatusOk); + EXPECT_GE(num_tensor_buffer_types, 1); + LiteRtTensorBufferType input_0_tensor_buffer_type; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + input_0_tensor_buffer_requirements, /*type_index=*/0, + &input_0_tensor_buffer_type), + kLiteRtStatusOk); + EXPECT_EQ(input_0_tensor_buffer_type, kLiteRtTensorBufferTypeFastRpc); + size_t input_0_tensor_buffer_size; + EXPECT_EQ( + LiteRtGetTensorBufferRequirementsBufferSize( + input_0_tensor_buffer_requirements, &input_0_tensor_buffer_size), + kLiteRtStatusOk); + EXPECT_GE(input_0_tensor_buffer_size, sizeof(kTestInput0Tensor)); + + LiteRtTensorBufferRequirements input_1_tensor_buffer_requirements; + EXPECT_EQ(LiteRtDispatchGetInputRequirements( + invocation_context, /*input_index=*/1, &kInput1TensorType, + &input_1_tensor_buffer_requirements), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtGetTensorBufferRequirementsNumSupportedTensorBufferTypes( + input_1_tensor_buffer_requirements, &num_tensor_buffer_types), + kLiteRtStatusOk); + EXPECT_GE(num_tensor_buffer_types, 1); + LiteRtTensorBufferType input_1_tensor_buffer_type; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + input_1_tensor_buffer_requirements, /*type_index=*/0, + &input_1_tensor_buffer_type), + kLiteRtStatusOk); + EXPECT_EQ(input_1_tensor_buffer_type, kLiteRtTensorBufferTypeFastRpc); + size_t input_1_tensor_buffer_size; + EXPECT_EQ( + LiteRtGetTensorBufferRequirementsBufferSize( + input_1_tensor_buffer_requirements, &input_1_tensor_buffer_size), + kLiteRtStatusOk); + EXPECT_GE(input_1_tensor_buffer_size, sizeof(kTestInput1Tensor)); + + LiteRtTensorBufferRequirements output_tensor_buffer_requirements; + EXPECT_EQ(LiteRtDispatchGetOutputRequirements( + invocation_context, /*output_index=*/0, &kOutputTensorType, + &output_tensor_buffer_requirements), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtGetTensorBufferRequirementsNumSupportedTensorBufferTypes( + output_tensor_buffer_requirements, &num_tensor_buffer_types), + kLiteRtStatusOk); + EXPECT_GE(num_tensor_buffer_types, 1); + LiteRtTensorBufferType output_tensor_buffer_type; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + output_tensor_buffer_requirements, /*type_index=*/0, + &output_tensor_buffer_type), + kLiteRtStatusOk); + EXPECT_EQ(output_tensor_buffer_type, kLiteRtTensorBufferTypeFastRpc); + size_t output_tensor_buffer_size; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( + output_tensor_buffer_requirements, &output_tensor_buffer_size), + kLiteRtStatusOk); + EXPECT_GE(output_tensor_buffer_size, sizeof(kTestOutputTensor)); + + // /////////////////////////////////////////////////////////////////////////// + // Allocate tensor buffers. + // /////////////////////////////////////////////////////////////////////////// + + LiteRtTensorBuffer input_0_tensor_buffer; + EXPECT_EQ(LiteRtCreateManagedTensorBuffer( + input_0_tensor_buffer_type, &kInput0TensorType, + input_0_tensor_buffer_size, &input_0_tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBuffer input_1_tensor_buffer; + EXPECT_EQ(LiteRtCreateManagedTensorBuffer( + input_1_tensor_buffer_type, &kInput1TensorType, + input_1_tensor_buffer_size, &input_1_tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBuffer output_tensor_buffer; + EXPECT_EQ(LiteRtCreateManagedTensorBuffer( + output_tensor_buffer_type, &kOutputTensorType, + output_tensor_buffer_size, &output_tensor_buffer), + kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Register tensor buffers. + // /////////////////////////////////////////////////////////////////////////// + + LiteRtTensorBufferHandle input_1_handle; + EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( + device_context, input_1_tensor_buffer, &input_1_handle), + kLiteRtStatusOk); + + LiteRtTensorBufferHandle input_0_handle; + EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( + device_context, input_0_tensor_buffer, &input_0_handle), + kLiteRtStatusOk); + + LiteRtTensorBufferHandle output_handle; + EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( + device_context, output_tensor_buffer, &output_handle), + kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Attach tensor buffers. + // /////////////////////////////////////////////////////////////////////////// + + EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, + /*graph_input_index=*/0, input_0_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, + /*graph_input_index=*/1, input_1_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchAttachOutput(invocation_context, + /*graph_output_index=*/0, output_handle), + kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Fill the input buffers with data. + // /////////////////////////////////////////////////////////////////////////// + + { + ABSL_LOG(INFO) << "Filling inputs with data"; + void* host_mem_addr; + + ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); + + ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); + } + + // /////////////////////////////////////////////////////////////////////////// + // Execute model. + // /////////////////////////////////////////////////////////////////////////// + + ABSL_LOG(INFO) << "Invoking execution..."; + EXPECT_EQ(LiteRtDispatchInvoke(invocation_context), kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Check output for correctness. + // /////////////////////////////////////////////////////////////////////////// + + { + ABSL_LOG(INFO) << "Checking output..."; + void* host_mem_addr; + ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + auto output = absl::MakeSpan(static_cast(host_mem_addr), + kTestOutputSize); + for (auto i = 0; i < kTestOutputSize; ++i) { + ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; + } + EXPECT_THAT(output, Pointwise(testing::FloatNear(1e-3), kTestOutputTensor)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk); + } + + // /////////////////////////////////////////////////////////////////////////// + // Clean up resources. + // /////////////////////////////////////////////////////////////////////////// + + EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, + /*graph_input_index=*/0, input_0_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, + /*graph_input_index=*/1, input_1_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchDetachOutput(invocation_context, + /*graph_output_index=*/0, output_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchUnregisterTensorBuffer(device_context, output_handle), + kLiteRtStatusOk); + EXPECT_EQ( + LiteRtDispatchUnregisterTensorBuffer(device_context, input_1_handle), + kLiteRtStatusOk); + EXPECT_EQ( + LiteRtDispatchUnregisterTensorBuffer(device_context, input_0_handle), + kLiteRtStatusOk); + LiteRtDestroyTensorBuffer(output_tensor_buffer); + LiteRtDestroyTensorBuffer(input_1_tensor_buffer); + LiteRtDestroyTensorBuffer(input_0_tensor_buffer); + EXPECT_EQ(LiteRtDispatchInvocationContextDestroy(invocation_context), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchDeviceContextDestroy(device_context), + kLiteRtStatusOk); +} + +TEST(Qualcomm, DispatchApiWithDmaBuf) { +#if !defined(__ANDROID__) + GTEST_SKIP() + << "This test is specific to Android devices with a Qualcomm NPU"; +#endif + + EXPECT_EQ(LiteRtDispatchInitialize(/*options=*/nullptr, /*num_options=*/0), + kLiteRtStatusOk); + + const char* vendor_id; + EXPECT_EQ(LiteRtDispatchGetVendorId(&vendor_id), kLiteRtStatusOk); + ABSL_LOG(INFO) << "vendor_id: " << vendor_id; + + const char* build_id; + EXPECT_EQ(LiteRtDispatchGetBuildId(&build_id), kLiteRtStatusOk); + ABSL_LOG(INFO) << "build_id: " << build_id; + + LiteRtDispatchApiVersion api_version; + EXPECT_EQ(LiteRtDispatchGetApiVersion(&api_version), kLiteRtStatusOk); + ABSL_LOG(INFO) << "api_version: " << api_version.major << "." + << api_version.minor << "." << api_version.patch; + + int capabilities; + EXPECT_EQ(LiteRtDispatchGetCapabilities(&capabilities), kLiteRtStatusOk); + ABSL_LOG(INFO) << "capabilities: " << capabilities; + + LiteRtDispatchDeviceContext device_context = nullptr; + EXPECT_EQ(LiteRtDispatchDeviceContextCreate(&device_context), + kLiteRtStatusOk); + ABSL_LOG(INFO) << "device_context: " << device_context; + + auto model_file_name = kQualcommModelFileName; + auto model = ::litert::testing::LoadBinaryFile(model_file_name); + EXPECT_TRUE(model.ok()); + ABSL_LOG(INFO) << "Loaded model " << model_file_name << ", " << model->size() + << " bytes"; + + // /////////////////////////////////////////////////////////////////////////// + // Set up an invocation context for a given model. + // /////////////////////////////////////////////////////////////////////////// + + LiteRtDispatchInvocationContext invocation_context = nullptr; + EXPECT_EQ(LiteRtDispatchInvocationContextCreate( + device_context, kLiteRtDispatchExecutableTypeMlModel, + model->data(), model->size(), /*function_name=*/"simple", + /*num_inputs=*/2, /*num_outputs=*/1, &invocation_context), + kLiteRtStatusOk); + ABSL_LOG(INFO) << "Invocation context: " << invocation_context; + + // /////////////////////////////////////////////////////////////////////////// + // Determine tensor buffer requirements. + // /////////////////////////////////////////////////////////////////////////// + + int num_tensor_buffer_types; + LiteRtTensorBufferRequirements input_0_tensor_buffer_requirements; + EXPECT_EQ(LiteRtDispatchGetInputRequirements( + invocation_context, /*input_index=*/0, &kInput0TensorType, + &input_0_tensor_buffer_requirements), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtGetTensorBufferRequirementsNumSupportedTensorBufferTypes( + input_0_tensor_buffer_requirements, &num_tensor_buffer_types), + kLiteRtStatusOk); + EXPECT_GE(num_tensor_buffer_types, 1); + LiteRtTensorBufferType input_0_tensor_buffer_type; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + input_0_tensor_buffer_requirements, /*type_index=*/1, + &input_0_tensor_buffer_type), + kLiteRtStatusOk); + EXPECT_EQ(input_0_tensor_buffer_type, kLiteRtTensorBufferTypeDmaBuf); + size_t input_0_tensor_buffer_size; + EXPECT_EQ( + LiteRtGetTensorBufferRequirementsBufferSize( + input_0_tensor_buffer_requirements, &input_0_tensor_buffer_size), + kLiteRtStatusOk); + EXPECT_GE(input_0_tensor_buffer_size, sizeof(kTestInput0Tensor)); + + LiteRtTensorBufferRequirements input_1_tensor_buffer_requirements; + EXPECT_EQ(LiteRtDispatchGetInputRequirements( + invocation_context, /*input_index=*/1, &kInput1TensorType, + &input_1_tensor_buffer_requirements), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtGetTensorBufferRequirementsNumSupportedTensorBufferTypes( + input_1_tensor_buffer_requirements, &num_tensor_buffer_types), + kLiteRtStatusOk); + EXPECT_GE(num_tensor_buffer_types, 1); + LiteRtTensorBufferType input_1_tensor_buffer_type; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + input_1_tensor_buffer_requirements, /*type_index=*/1, + &input_1_tensor_buffer_type), + kLiteRtStatusOk); + EXPECT_EQ(input_1_tensor_buffer_type, kLiteRtTensorBufferTypeDmaBuf); + size_t input_1_tensor_buffer_size; + EXPECT_EQ( + LiteRtGetTensorBufferRequirementsBufferSize( + input_1_tensor_buffer_requirements, &input_1_tensor_buffer_size), + kLiteRtStatusOk); + EXPECT_GE(input_1_tensor_buffer_size, sizeof(kTestInput1Tensor)); + + LiteRtTensorBufferRequirements output_tensor_buffer_requirements; + EXPECT_EQ(LiteRtDispatchGetOutputRequirements( + invocation_context, /*output_index=*/0, &kOutputTensorType, + &output_tensor_buffer_requirements), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtGetTensorBufferRequirementsNumSupportedTensorBufferTypes( + output_tensor_buffer_requirements, &num_tensor_buffer_types), + kLiteRtStatusOk); + EXPECT_GE(num_tensor_buffer_types, 1); + LiteRtTensorBufferType output_tensor_buffer_type; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsSupportedTensorBufferType( + output_tensor_buffer_requirements, /*type_index=*/1, + &output_tensor_buffer_type), + kLiteRtStatusOk); + EXPECT_EQ(output_tensor_buffer_type, kLiteRtTensorBufferTypeDmaBuf); + size_t output_tensor_buffer_size; + EXPECT_EQ(LiteRtGetTensorBufferRequirementsBufferSize( + output_tensor_buffer_requirements, &output_tensor_buffer_size), + kLiteRtStatusOk); + EXPECT_GE(output_tensor_buffer_size, sizeof(kTestOutputTensor)); + + // /////////////////////////////////////////////////////////////////////////// + // Allocate tensor buffers. + // /////////////////////////////////////////////////////////////////////////// + + LiteRtTensorBuffer input_0_tensor_buffer; + EXPECT_EQ(LiteRtCreateManagedTensorBuffer( + input_0_tensor_buffer_type, &kInput0TensorType, + input_0_tensor_buffer_size, &input_0_tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBuffer input_1_tensor_buffer; + EXPECT_EQ(LiteRtCreateManagedTensorBuffer( + input_1_tensor_buffer_type, &kInput1TensorType, + input_1_tensor_buffer_size, &input_1_tensor_buffer), + kLiteRtStatusOk); + + LiteRtTensorBuffer output_tensor_buffer; + EXPECT_EQ(LiteRtCreateManagedTensorBuffer( + output_tensor_buffer_type, &kOutputTensorType, + output_tensor_buffer_size, &output_tensor_buffer), + kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Register tensor buffers. + // /////////////////////////////////////////////////////////////////////////// + + LiteRtTensorBufferHandle input_1_handle; + EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( + device_context, input_1_tensor_buffer, &input_1_handle), + kLiteRtStatusOk); + + LiteRtTensorBufferHandle input_0_handle; + EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( + device_context, input_0_tensor_buffer, &input_0_handle), + kLiteRtStatusOk); + + LiteRtTensorBufferHandle output_handle; + EXPECT_EQ(LiteRtDispatchRegisterTensorBuffer( + device_context, output_tensor_buffer, &output_handle), + kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Attach tensor buffers. + // /////////////////////////////////////////////////////////////////////////// + + EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, + /*graph_input_index=*/0, input_0_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchAttachInput(invocation_context, + /*graph_input_index=*/1, input_1_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchAttachOutput(invocation_context, + /*graph_output_index=*/0, output_handle), + kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Fill the input buffers with data. + // /////////////////////////////////////////////////////////////////////////// + + { + ABSL_LOG(INFO) << "Filling inputs with data"; + void* host_mem_addr; + + ASSERT_EQ(LiteRtLockTensorBuffer(input_0_tensor_buffer, &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTestInput0Tensor, sizeof(kTestInput0Tensor)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(input_0_tensor_buffer), kLiteRtStatusOk); + + ASSERT_EQ(LiteRtLockTensorBuffer(input_1_tensor_buffer, &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + std::memcpy(host_mem_addr, kTestInput1Tensor, sizeof(kTestInput1Tensor)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(input_1_tensor_buffer), kLiteRtStatusOk); + } + + // /////////////////////////////////////////////////////////////////////////// + // Execute model. + // /////////////////////////////////////////////////////////////////////////// + + ABSL_LOG(INFO) << "Invoking execution..."; + EXPECT_EQ(LiteRtDispatchInvoke(invocation_context), kLiteRtStatusOk); + + // /////////////////////////////////////////////////////////////////////////// + // Check output for correctness. + // /////////////////////////////////////////////////////////////////////////// + + { + ABSL_LOG(INFO) << "Checking output..."; + void* host_mem_addr; + ASSERT_EQ(LiteRtLockTensorBuffer(output_tensor_buffer, &host_mem_addr, + /*event=*/nullptr), + kLiteRtStatusOk); + auto output = absl::MakeSpan(static_cast(host_mem_addr), + kTestOutputSize); + for (auto i = 0; i < kTestOutputSize; ++i) { + ABSL_LOG(INFO) << output[i] << "\t" << kTestOutputTensor[i]; + } + EXPECT_THAT(output, Pointwise(testing::FloatNear(1e-3), kTestOutputTensor)); + ASSERT_EQ(LiteRtUnlockTensorBuffer(output_tensor_buffer), kLiteRtStatusOk); + } + + // /////////////////////////////////////////////////////////////////////////// + // Clean up resources. + // /////////////////////////////////////////////////////////////////////////// + + EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, + /*graph_input_index=*/0, input_0_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchDetachInput(invocation_context, + /*graph_input_index=*/1, input_1_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchDetachOutput(invocation_context, + /*graph_output_index=*/0, output_handle), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchUnregisterTensorBuffer(device_context, output_handle), + kLiteRtStatusOk); + EXPECT_EQ( + LiteRtDispatchUnregisterTensorBuffer(device_context, input_1_handle), + kLiteRtStatusOk); + EXPECT_EQ( + LiteRtDispatchUnregisterTensorBuffer(device_context, input_0_handle), + kLiteRtStatusOk); + LiteRtDestroyTensorBuffer(output_tensor_buffer); + LiteRtDestroyTensorBuffer(input_1_tensor_buffer); + LiteRtDestroyTensorBuffer(input_0_tensor_buffer); + EXPECT_EQ(LiteRtDispatchInvocationContextDestroy(invocation_context), + kLiteRtStatusOk); + EXPECT_EQ(LiteRtDispatchDeviceContextDestroy(device_context), + kLiteRtStatusOk); +} diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.cc new file mode 100644 index 00000000000000..22400ba1982779 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.cc @@ -0,0 +1,184 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h" + +#include +#include + +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "third_party/qairt/latest/include/QNN/HTP/QnnHtpMem.h" +#include "third_party/qairt/latest/include/QNN/QnnBackend.h" +#include "third_party/qairt/latest/include/QNN/QnnCommon.h" +#include "third_party/qairt/latest/include/QNN/QnnInterface.h" +#include "third_party/qairt/latest/include/QNN/QnnMem.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" + +using ::litert::qnn::QnnManager; + +absl::StatusOr +LiteRtDispatchDeviceContextT::Create(QnnManager& qnn) { + return Ptr(new LiteRtDispatchDeviceContextT(qnn)); +} + +absl::StatusOr +LiteRtDispatchDeviceContextT::GetTensorBuffer( + LiteRtTensorBufferHandle tensor_buffer_handle) { + auto registry_entry = tensor_buffer_registry_.Get(tensor_buffer_handle); + if (!registry_entry.ok()) { + return registry_entry.status(); + } + + return (*registry_entry)->tensor_buffer; +} + +absl::StatusOr LiteRtDispatchDeviceContextT::GetMemHandle( + LiteRtTensorBufferHandle tensor_buffer_handle, const Qnn_Tensor_t& tensor) { + auto registry_entry = tensor_buffer_registry_.Get(tensor_buffer_handle); + if (!registry_entry.ok()) { + return registry_entry.status(); + } + + if (!(*registry_entry)->qnn_mem_handle) { + auto qnn_mem_handle = + RegisterTensorBuffer((*registry_entry)->tensor_buffer, tensor); + if (!qnn_mem_handle.ok()) { + return qnn_mem_handle.status(); + } + (*registry_entry)->qnn_mem_handle = *qnn_mem_handle; + } + + return (*registry_entry)->qnn_mem_handle; +} + +absl::StatusOr +LiteRtDispatchDeviceContextT::RegisterTensorBuffer( + LiteRtTensorBuffer tensor_buffer, const Qnn_Tensor_t& tensor) { + LiteRtTensorBufferType tensor_buffer_type; + if (auto status = + LiteRtGetTensorBufferType(tensor_buffer, &tensor_buffer_type); + status != kLiteRtStatusOk) { + return absl::InternalError("Failed to get tensor buffer type"); + } + + size_t tensor_buffer_size; + if (auto status = + LiteRtGetTensorBufferSize(tensor_buffer, &tensor_buffer_size); + status != kLiteRtStatusOk) { + return absl::InternalError("Failed to get tensor buffer size"); + } + + size_t tensor_buffer_offset; + if (auto status = + LiteRtGetTensorBufferOffset(tensor_buffer, &tensor_buffer_offset); + status != kLiteRtStatusOk) { + return absl::InternalError("Failed to get tensor buffer offset"); + } + + LiteRtRankedTensorType tensor_type; + if (auto status = + LiteRtGetTensorBufferTensorType(tensor_buffer, &tensor_type); + status != kLiteRtStatusOk) { + return absl::InternalError("Failed to get tensor buffer's type"); + } + + Qnn_DataType_t tensor_data_type; + if (kLiteRtStatusOk != + LegalizeElementType(tensor_type.element_type, &tensor_data_type)) { + return absl::InternalError("Failed to legalize datatype"); + } + + uint32_t tensor_rank = tensor_type.layout.rank; + uint32_t* tensor_dimensions = reinterpret_cast( + const_cast(tensor_type.layout.dimensions)); + auto* tensor_strides = tensor_type.layout.strides; + if (tensor_strides != nullptr) { + return absl::InternalError("Tensor strides are not supported by QNN"); + } + + void* buffer_host_addr; + int buffer_fd; + (void)buffer_host_addr; + + switch (tensor_buffer_type) { + case kLiteRtTensorBufferTypeFastRpc: +#if LITERT_HAS_FASTRPC_SUPPORT + if (auto status = LiteRtGetTensorBufferFastRpcBuffer( + tensor_buffer, &buffer_host_addr, &buffer_fd); + status != kLiteRtStatusOk) { + return absl::InternalError("Failed to get FastRPC buffer"); + } +#else + return absl::InternalError("FastRPC support is missing on this platform"); +#endif // LRT_HAS_FASTRPC_SUPPORT + break; + + case kLiteRtTensorBufferTypeDmaBuf: +#if LITERT_HAS_DMABUF_SUPPORT + if (auto status = LiteRtGetTensorBufferDmaBufBuffer( + tensor_buffer, &buffer_host_addr, &buffer_fd); + status != kLiteRtStatusOk) { + return absl::InternalError("Failed to get DMA-BUF buffer"); + } +#else + return absl::InternalError("DmaBuf support is missing on this platform"); +#endif // LRT_HAS_DMABUF_SUPPORT + break; + + default: + return absl::InternalError("Unsupported tensor buffer type"); + } + + QnnMemHtp_Descriptor_t mem_htp_descriptor = {}; + mem_htp_descriptor.type = QNN_HTP_MEM_SHARED_BUFFER; + mem_htp_descriptor.size = tensor_buffer_size; + mem_htp_descriptor.sharedBufferConfig = + QnnHtpMem_SharedBufferConfig_t{buffer_fd, tensor_buffer_offset}; + + Qnn_MemDescriptor_t mem_descriptor = {}; + mem_descriptor.memShape = {tensor_rank, tensor_dimensions, nullptr}; + mem_descriptor.dataType = tensor_data_type; + mem_descriptor.memType = QNN_MEM_TYPE_CUSTOM; + mem_descriptor.customInfo = &mem_htp_descriptor; + + if (invocation_context_ == nullptr) { + return absl::InternalError("Missing invocation context"); + } + + Qnn_ContextHandle_t context_handle = invocation_context_->ContextHandle(); + + Qnn_MemHandle_t mem_handle = nullptr; + if (auto status = qnn_manager_.Api()->memRegister( + context_handle, &mem_descriptor, 1UL, &mem_handle); + status != QNN_SUCCESS) { + return absl::InternalError("Failed to register tensor buffer"); + } + + if (!mem_handle) { + return absl::InternalError("Failed to register buffer: null mem_handle"); + } + + return mem_handle; +} diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h new file mode 100644 index 00000000000000..490178e01a7683 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h @@ -0,0 +1,80 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_ + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "third_party/qairt/latest/include/QNN/QnnInterface.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/registry.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" + +class LiteRtDispatchDeviceContextT { + public: + using Ptr = std::unique_ptr; + + ~LiteRtDispatchDeviceContextT() = default; + + static absl::StatusOr Create(litert::qnn::QnnManager& qnn_manager); + + absl::StatusOr RegisterTensorBuffer( + LiteRtTensorBuffer tensor_buffer) { + return tensor_buffer_registry_.Register( + TensorBufferRegistryEntry(tensor_buffer)); + } + + absl::Status UnregisterTensorBuffer( + LiteRtTensorBufferHandle tensor_buffer_handle) { + return tensor_buffer_registry_.Unregister(tensor_buffer_handle); + } + + absl::StatusOr GetTensorBuffer( + LiteRtTensorBufferHandle tensor_buffer_handle); + + absl::StatusOr GetMemHandle( + LiteRtTensorBufferHandle tensor_buffer_handle, + const Qnn_Tensor_t& tensor); + + void SetInvocationContext( + LiteRtDispatchInvocationContextT* invocation_context) { + invocation_context_ = invocation_context; + } + + private: + struct TensorBufferRegistryEntry { + LiteRtTensorBuffer tensor_buffer; + Qnn_MemHandle_t qnn_mem_handle = nullptr; + explicit TensorBufferRegistryEntry(LiteRtTensorBuffer tensor_buffer_) + : tensor_buffer(tensor_buffer_) {} + }; + + using TensorBufferRegistry = litert::qnn::Registry; + + LiteRtDispatchDeviceContextT(litert::qnn::QnnManager& qnn_manager) + : qnn_manager_(qnn_manager) {} + + absl::StatusOr RegisterTensorBuffer( + LiteRtTensorBuffer tensor_buffer, const Qnn_Tensor_t& tensor); + + litert::qnn::QnnManager& qnn_manager_; + TensorBufferRegistry tensor_buffer_registry_; + LiteRtDispatchInvocationContextT* invocation_context_ = nullptr; +}; + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_LITERT_DISPATCH_DEVICE_CONTEXT_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.cc new file mode 100644 index 00000000000000..eeb3cc04ba5733 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.cc @@ -0,0 +1,223 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h" + +#include +#include +#include + +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnCommon.h" +#include "third_party/qairt/latest/include/QNN/QnnContext.h" +#include "third_party/qairt/latest/include/QNN/QnnInterface.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_model.h" +#include "tensorflow/lite/experimental/litert/c/litert_support.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" +#include "tensorflow/lite/experimental/litert/core/utils.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_device_context.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" + +using ::litert::qnn::QnnManager; + +LiteRtDispatchInvocationContextT::LiteRtDispatchInvocationContextT( + litert::qnn::QnnManager& qnn_manager, + const litert::qnn::ContextBinaryInfo& context_binary_info, + LiteRtDispatchDeviceContextT& device_context, + QnnManager::ContextHandle&& context_handle, + Qnn_ProfileHandle_t profile_handle, int graph_index, + Qnn_GraphHandle_t graph_handle) + : qnn_manager_(qnn_manager), + device_context_(device_context), + context_handle_(std::move(context_handle)), + profile_handle_(profile_handle), + graph_index_(graph_index), + graph_handle_(graph_handle), + inputs_(context_binary_info.Graphs()[graph_index].Inputs()), + outputs_(context_binary_info.Graphs()[graph_index].Outputs()) {} + +absl::StatusOr +LiteRtDispatchInvocationContextT::Create( + QnnManager& qnn, LiteRtDispatchDeviceContextT& device_context, + const void* exec_bytecode_ptr, size_t exec_bytecode_size, + const char* function_name) { + auto context_binary_info = litert::qnn::ContextBinaryInfo::Create( + qnn, exec_bytecode_ptr, exec_bytecode_size); + if (!context_binary_info.ok()) { + return context_binary_info.status(); + } + + int graph_index = -1; + const auto& graphs = context_binary_info->Graphs(); + for (auto i = 0; i < graphs.size(); ++i) { + const auto& graph = graphs[i]; + if (graph.Name() == absl::string_view(function_name)) { + graph_index = i; + break; + } + } + if (graph_index < 0) { + return absl::InternalError("Function name not found"); + } + + auto configs = QnnManager::DefaultContextConfigs(); + Qnn_ProfileHandle_t profile_handle = nullptr; + auto context_handle = qnn.CreateContextHandle( + configs, + absl::MakeSpan(static_cast(exec_bytecode_ptr), + exec_bytecode_size), + profile_handle); + if (!context_handle.ok()) { + return context_handle.status(); + } + + Qnn_GraphHandle_t graph_handle; + if (auto status = qnn.Api()->graphRetrieve(context_handle->get(), + function_name, &graph_handle); + status != QNN_SUCCESS) { + return absl::InternalError("Failed to retrieve graph"); + } + + return Ptr(new LiteRtDispatchInvocationContextT( + qnn, std::move(*context_binary_info), device_context, + std::move(*context_handle), profile_handle, graph_index, graph_handle)); +} + +namespace { + +absl::StatusOr GetTensorBufferRequirements( + const LiteRtRankedTensorType& tensor_type) { + auto* tensor_strides = tensor_type.layout.strides; + if (tensor_strides != nullptr) { + return absl::InternalError("Tensor strides are not supported by QNN"); + } + + static constexpr std::array + kSupportedTensorBufferTypes = { + kLiteRtTensorBufferTypeFastRpc, + kLiteRtTensorBufferTypeDmaBuf, + }; + + auto buffer_size = litert::internal::GetNumPackedBytes(tensor_type); + if (!buffer_size.ok()) { + return buffer_size.status(); + } + + LiteRtTensorBufferRequirements requirements; + if (auto status = LiteRtCreateTensorBufferRequirements( + kSupportedTensorBufferTypes.size(), + kSupportedTensorBufferTypes.data(), *buffer_size, &requirements); + status != kLiteRtStatusOk) { + return absl::InternalError("Not implemented"); + } + + return requirements; +} + +} // namespace + +absl::StatusOr +LiteRtDispatchInvocationContextT::GetInputRequirements( + int input_index, const LiteRtRankedTensorType& tensor_type) { + return GetTensorBufferRequirements(tensor_type); +} + +absl::StatusOr +LiteRtDispatchInvocationContextT::GetOutputRequirements( + int output_index, const LiteRtRankedTensorType& tensor_type) { + return GetTensorBufferRequirements(tensor_type); +} + +absl::Status LiteRtDispatchInvocationContextT::AttachInput( + int graph_input_index, LiteRtTensorBufferHandle tensor_buffer_handle) { + if (graph_input_index < 0 || graph_input_index >= inputs_.size()) { + return absl::InternalError("Invalid graph_input_index"); + } + + auto& tensor = inputs_[graph_input_index]; + return AttachBuffer(tensor.Tensor(), tensor_buffer_handle); +} + +absl::Status LiteRtDispatchInvocationContextT::AttachOutput( + int graph_output_index, LiteRtTensorBufferHandle tensor_buffer_handle) { + if (graph_output_index < 0 || graph_output_index >= outputs_.size()) { + return absl::InternalError("Invalid graph_output_index"); + } + + auto& tensor = outputs_[graph_output_index]; + return AttachBuffer(tensor.Tensor(), tensor_buffer_handle); +} + +absl::Status LiteRtDispatchInvocationContextT::AttachBuffer( + Qnn_Tensor_t& tensor, LiteRtTensorBufferHandle tensor_buffer_handle) { + auto tensor_buffer = device_context_.GetTensorBuffer(tensor_buffer_handle); + if (!tensor_buffer.ok()) { + return tensor_buffer.status(); + } + + auto mem_handle = device_context_.GetMemHandle(tensor_buffer_handle, tensor); + if (!mem_handle.ok()) { + return mem_handle.status(); + } + + if (tensor.version == QNN_TENSOR_VERSION_1) { + tensor.v1.memType = QNN_TENSORMEMTYPE_MEMHANDLE; + tensor.v1.memHandle = *mem_handle; + + } else if (tensor.version == QNN_TENSOR_VERSION_2) { + if (tensor.v2.isDynamicDimensions != nullptr) { + return absl::InternalError("Dynamic dimensions not yet supported"); + } + tensor.v2.memType = QNN_TENSORMEMTYPE_MEMHANDLE; + tensor.v2.memHandle = *mem_handle; + + } else { + return absl::InternalError("Unsupported QNN tensor version"); + } + + return {}; +} + +absl::Status LiteRtDispatchInvocationContextT::Execute() { + const size_t num_ins = inputs_.size(); + LITERT_STACK_ARRAY(Qnn_Tensor_t, inputs, num_ins, QNN_TENSOR_INIT); + for (size_t i = 0; i < num_ins; ++i) { + *(inputs + i) = inputs_.at(i).Tensor(); + } + + const size_t num_outs = outputs_.size(); + LITERT_STACK_ARRAY(Qnn_Tensor_t, outputs, num_outs, QNN_TENSOR_INIT); + for (size_t i = 0; i < num_outs; ++i) { + *(outputs + i) = outputs_.at(i).Tensor(); + } + + if (auto status = qnn_manager_.Api()->graphExecute( + graph_handle_, inputs, num_ins, outputs, num_outs, + /*profileHandle=*/nullptr, /*signalHandle=*/nullptr); + status != QNN_SUCCESS) { + return absl::InternalError("Failed to execute graph"); + } + + return {}; +} diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h new file mode 100644 index 00000000000000..787518caa352c4 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/litert_dispatch_invocation_context.h @@ -0,0 +1,83 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnInterface.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer.h" +#include "tensorflow/lite/experimental/litert/c/litert_tensor_buffer_requirements.h" +#include "tensorflow/lite/experimental/litert/vendors/c/litert_dispatch.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/context_binary_info.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/registry.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" + +class LiteRtDispatchDeviceContextT; + +class LiteRtDispatchInvocationContextT { + public: + using Ptr = std::unique_ptr; + + ~LiteRtDispatchInvocationContextT() = default; + + static absl::StatusOr Create( + litert::qnn::QnnManager& qnn_manager, + LiteRtDispatchDeviceContextT& device_context, + const void* exec_bytecode_ptr, size_t exec_bytecode_size, + const char* function_name); + + absl::StatusOr GetInputRequirements( + int input_index, const LiteRtRankedTensorType& tensor_type); + absl::StatusOr GetOutputRequirements( + int output_index, const LiteRtRankedTensorType& tensor_type); + + absl::Status AttachInput(int graph_input_index, + LiteRtTensorBufferHandle tensor_buffer_handle); + + absl::Status AttachOutput(int graph_output_index, + LiteRtTensorBufferHandle tensor_buffer_handle); + + absl::Status Execute(); + + Qnn_ContextHandle_t ContextHandle() { return context_handle_.get(); } + + private: + LiteRtDispatchInvocationContextT( + litert::qnn::QnnManager& qnn_manager, + const litert::qnn::ContextBinaryInfo& context_binary_info, + LiteRtDispatchDeviceContextT& device_context, + litert::qnn::QnnManager::ContextHandle&& context_handle, + Qnn_ProfileHandle_t profile_handle, int graph_index, + Qnn_GraphHandle_t graph_handle); + + absl::Status AttachBuffer(Qnn_Tensor_t& tensor, + LiteRtTensorBufferHandle tensor_buffer_handle); + + litert::qnn::QnnManager& qnn_manager_; + LiteRtDispatchDeviceContextT& device_context_; + litert::qnn::QnnManager::ContextHandle context_handle_; + Qnn_ProfileHandle_t profile_handle_; + int graph_index_; + Qnn_GraphHandle_t graph_handle_; + std::vector inputs_; + std::vector outputs_; +}; + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_LITERT_DISPATCH_INVOCATION_CONTEXT_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/registry.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/registry.h new file mode 100644 index 00000000000000..1cf9dca9728129 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/dispatch/registry.h @@ -0,0 +1,74 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_REGISTRY_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_REGISTRY_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" + +namespace litert { +namespace qnn { + +template +class Registry { + public: + absl::StatusOr Register(const V& value) { + // TODO: improve this linear search by keeping an index to the first unused + // element. + for (auto i = 0; i < entries_.size(); ++i) { + auto& entry = entries_[i]; + if (!entry.used) { + entry.value = value; + entry.used = true; + return static_cast(i); + } + } + // Grow the set of entries. + H handle = static_cast(entries_.size()); + entries_.emplace_back(value); + return handle; + } + + absl::Status Unregister(H handle) { + if (handle < 0 || handle >= entries_.size()) { + return absl::NotFoundError("Unexpected handle"); + } + entries_[handle].used = false; + return {}; + } + + absl::StatusOr Get(H handle) { + if (handle < 0 || handle >= entries_.size()) { + return absl::NotFoundError("Unexpected handle"); + } + return &entries_[handle].value; + } + + private: + struct Entry { + V value; + bool used; + explicit Entry(const V& v) : value(v), used(true) {} + }; + + std::vector entries_; +}; + +} // namespace qnn +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_DISPATCH_REGISTRY_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.cc new file mode 100644 index 00000000000000..a0967992192570 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.cc @@ -0,0 +1,64 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.h" + +#include +#include +#include +#include + +#include "third_party/qairt/latest/include/QNN/QnnLog.h" + +namespace litert::qnn { +namespace { + +void DefaultStdOutLogger(const char* fmt, QnnLog_Level_t level, + uint64_t timestamp, va_list argp) { + const char* levelStr = ""; + switch (level) { + case QNN_LOG_LEVEL_ERROR: + levelStr = " ERROR "; + break; + case QNN_LOG_LEVEL_WARN: + levelStr = "WARNING"; + break; + case QNN_LOG_LEVEL_INFO: + levelStr = " INFO "; + break; + case QNN_LOG_LEVEL_DEBUG: + levelStr = " DEBUG "; + break; + case QNN_LOG_LEVEL_VERBOSE: + levelStr = "VERBOSE"; + break; + case QNN_LOG_LEVEL_MAX: + levelStr = "UNKNOWN"; + break; + } + char buffer1[256]; + char buffer2[256]; + double ms = timestamp; + snprintf(buffer1, sizeof(buffer1), "%8.1fms [%-7s] ", ms, levelStr); + buffer1[sizeof(buffer1) - 1] = 0; + vsnprintf(buffer2, sizeof(buffer2), fmt, argp); + buffer2[sizeof(buffer1) - 2] = 0; + std::cout << buffer1 << buffer2; +} + +} // namespace + +QnnLog_Callback_t GetDefaultStdOutLogger() { return DefaultStdOutLogger; } + +} // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.h new file mode 100644 index 00000000000000..934a164b49f933 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.h @@ -0,0 +1,28 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_LOG_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_LOG_H_ + +#include "third_party/qairt/latest/include/QNN/QnnLog.h" + +namespace litert::qnn { + +// Gets a default logger implementation to stdout. +// This is used when initializing qnn logging. +QnnLog_Callback_t GetDefaultStdOutLogger(); + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_LOG_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.cc new file mode 100644 index 00000000000000..9a50d7636ec5f7 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.cc @@ -0,0 +1,380 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "third_party/qairt/latest/include/QNN/QnnCommon.h" +#include "third_party/qairt/latest/include/QNN/QnnInterface.h" +#include "third_party/qairt/latest/include/QNN/QnnLog.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" +#include "third_party/qairt/latest/include/QNN/System/QnnSystemCommon.h" +#include "third_party/qairt/latest/include/QNN/System/QnnSystemContext.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/c/litert_logging.h" +#include "tensorflow/lite/experimental/litert/c/litert_support.h" +#include "tensorflow/lite/experimental/litert/core/dynamic_loading.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_log.h" + +namespace litert::qnn { + +namespace { + +constexpr char kLibQnnGetProvidersSymbol[] = "QnnInterface_getProviders"; + +constexpr char kLibQnnSystemGetProvidersSymbol[] = + "QnnSystemInterface_getProviders"; + +typedef Qnn_ErrorHandle_t (*QnnInterfaceGetProvidersFn_t)( + const QnnInterface_t*** provider_list, uint32_t* num_providers); + +typedef Qnn_ErrorHandle_t (*QnnSystemInterfaceGetProvidersFn_t)( + const QnnSystemInterface_t***, uint32_t*); + +absl::Span LoadProvidersFromLib(void* lib_so) { + QnnInterfaceGetProvidersFn_t get_providers = nullptr; + LITERT_RETURN_VAL_IF_NOT_OK( + litert::ResolveLibSymbol( + lib_so, kLibQnnGetProvidersSymbol, &get_providers), + {}); + + const QnnInterface_t** interface_providers = nullptr; + uint32_t num_providers = 0; + if (QNN_SUCCESS != get_providers(&interface_providers, &num_providers)) { + LITERT_LOG(LITERT_ERROR, "%s", "Failed to get providers\n"); + return {}; + } + + return absl::MakeSpan(interface_providers, num_providers); +} + +absl::Span LoadSystemProvidersFromLib( + void* lib_so) { + QnnSystemInterfaceGetProvidersFn_t get_providers = nullptr; + LITERT_RETURN_VAL_IF_NOT_OK( + litert::ResolveLibSymbol( + lib_so, kLibQnnSystemGetProvidersSymbol, &get_providers), + {}); + + const QnnSystemInterface_t** interface_providers = nullptr; + uint32_t num_providers = 0; + if (QNN_SUCCESS != get_providers(&interface_providers, &num_providers)) { + LITERT_LOG(LITERT_ERROR, "%s", "Failed to get system providers\n"); + return {}; + } + + return absl::MakeSpan(interface_providers, num_providers); +} + +} // namespace + +QnnManager::~QnnManager() { + (void)FreeDevice(); + (void)FreeBackend(); + (void)FreeLogging(); +} + +LiteRtStatus QnnManager::LoadLib(absl::string_view path) { + LITERT_RETURN_STATUS_IF_NOT_OK(litert::OpenLib(path, &lib_so_)); + return kLiteRtStatusOk; +} + +LiteRtStatus QnnManager::LoadSystemLib(absl::string_view path) { + LITERT_RETURN_STATUS_IF_NOT_OK(litert::OpenLib(path, &lib_system_so_)); + return kLiteRtStatusOk; +} + +const QnnApi* QnnManager::Api() const { + if (interface_ == nullptr) { + return nullptr; + } + return &interface_->QNN_INTERFACE_VER_NAME; +} + +LiteRtStatus QnnManager::ResolveApi() { + if (lib_so_ == nullptr) { + LITERT_LOG(LITERT_ERROR, "%s", + "Cannot resolve functions: libQnn*.so has not been loaded.\n"); + return kLiteRtStatusErrorDynamicLoading; + } + + auto providers = LoadProvidersFromLib(lib_so_); + for (const auto& prov : providers) { + const bool major = + prov->apiVersion.coreApiVersion.major == QNN_API_VERSION_MAJOR; + + const bool minor = + prov->apiVersion.coreApiVersion.minor == QNN_API_VERSION_MINOR; + + const bool patch = + prov->apiVersion.coreApiVersion.patch == QNN_API_VERSION_PATCH; + + if (major && minor && patch) { + interface_ = prov; + break; + } + } + + if (interface_ == nullptr) { + LITERT_LOG(LITERT_ERROR, "%s", "No valid interface was provided\n"); + return kLiteRtStatusErrorDynamicLoading; + } + + return kLiteRtStatusOk; +} + +LiteRtStatus QnnManager::ResolveSystemApi() { + if (lib_so_ == nullptr) { + LITERT_LOG(LITERT_ERROR, "%s", + "Cannot resolve functions: libQnn*.so has not been loaded.\n"); + return kLiteRtStatusErrorDynamicLoading; + } + + auto system_providers = LoadSystemProvidersFromLib(lib_system_so_); + for (const auto& system_prov : system_providers) { + const bool major = + system_prov->systemApiVersion.major == QNN_SYSTEM_API_VERSION_MAJOR; + + const bool minor = + system_prov->systemApiVersion.minor == QNN_SYSTEM_API_VERSION_MINOR; + + const bool patch = + system_prov->systemApiVersion.patch == QNN_SYSTEM_API_VERSION_PATCH; + + if (major && minor && patch) { + system_interface_ = system_prov; + break; + } + } + + if (system_interface_ == nullptr) { + LITERT_LOG(LITERT_ERROR, "%s", "No valid system interface was provided\n"); + return kLiteRtStatusErrorDynamicLoading; + } + + return kLiteRtStatusOk; +} + +const QnnSystemApi* QnnManager::SystemApi() const { + if (system_interface_ == nullptr) { + return nullptr; + } + return &system_interface_->QNN_SYSTEM_INTERFACE_VER_NAME; +} + +LiteRtStatus QnnManager::FreeLogging() { + if (log_handle_ != nullptr) { + if (QNN_SUCCESS != Api()->logFree(log_handle_)) { + LITERT_LOG(LITERT_ERROR, "%s", "Failed to free logging\n"); + return kLiteRtStatusErrorNotFound; + } + } + log_handle_ = nullptr; + return kLiteRtStatusOk; +} + +LiteRtStatus QnnManager::FreeBackend() { + if (backend_handle_ != nullptr) { + if (QNN_SUCCESS != Api()->backendFree(backend_handle_)) { + LITERT_LOG(LITERT_ERROR, "%s", "Failed to free backend\n"); + return kLiteRtStatusErrorNotFound; + } + } + backend_handle_ = nullptr; + return kLiteRtStatusOk; +} + +LiteRtStatus QnnManager::FreeDevice() { + if (device_handle_ != nullptr) { + if (QNN_SUCCESS != Api()->deviceFree(device_handle_)) { + LITERT_LOG(LITERT_ERROR, "%s", "Failed to free device\n"); + return kLiteRtStatusErrorNotFound; + } + } + device_handle_ = nullptr; + return kLiteRtStatusOk; +} + +LiteRtStatus QnnManager::GenerateContextBinary( + Qnn_ContextHandle_t context_handle, std::vector& buffer) { + Qnn_ContextBinarySize_t bin_size = 0; + if (QNN_SUCCESS != Api()->contextGetBinarySize(context_handle, &bin_size)) { + LITERT_LOG(LITERT_ERROR, "%s", "Failed to get context bin size\n"); + return kLiteRtStatusErrorNotFound; + } + buffer.clear(); + buffer.resize(bin_size); + + Qnn_ContextBinarySize_t written_bin_size = 0; + if (QNN_SUCCESS != Api()->contextGetBinary(context_handle, buffer.data(), + buffer.size(), + &written_bin_size)) { + LITERT_LOG(LITERT_ERROR, "%s", "Failed to generated context binary \n"); + return kLiteRtStatusErrorNotFound; + } + + LITERT_LOG(LITERT_INFO, "Serialized a context bin of size (bytes): %lu\n", + written_bin_size); + + return kLiteRtStatusOk; +} + +LiteRtStatus QnnManager::Init(absl::Span configs, + std::optional shared_library_dir, + std::optional soc_model) { + if (shared_library_dir.has_value()) { + // We must change the variable environment used to load DSP libraries. + std::string new_adsp_library_path; + if (auto* adsp_library_path = getenv("ADSP_LIBRARY_PATH"); + adsp_library_path != nullptr) { + new_adsp_library_path = absl::StrFormat( + "%s:%s", shared_library_dir->data(), adsp_library_path); + } else { + new_adsp_library_path = shared_library_dir->data(); + } + LITERT_LOG(LITERT_INFO, "Setting ADSP_LIBRARY_PATH to %s", + new_adsp_library_path.data()); + setenv("ADSP_LIBRARY_PATH", new_adsp_library_path.data(), /*overwrite=*/1); + } + + auto lib_qnn_htp_so_path = + shared_library_dir.has_value() + ? absl::StrFormat("%s/%s", shared_library_dir->data(), kLibQnnHtpSo) + : kLibQnnHtpSo; + LITERT_RETURN_STATUS_IF_NOT_OK(LoadLib(lib_qnn_htp_so_path)); + LITERT_RETURN_STATUS_IF_NOT_OK(ResolveApi()); + + auto lib_qnn_system_so_path = + shared_library_dir.has_value() + ? absl::StrFormat("%s/%s", shared_library_dir->data(), + kLibQnnSystemSo) + : kLibQnnSystemSo; + LITERT_RETURN_STATUS_IF_NOT_OK(LoadSystemLib(lib_qnn_system_so_path)); + LITERT_RETURN_STATUS_IF_NOT_OK(ResolveSystemApi()); + + if (auto status = Api()->logCreate(GetDefaultStdOutLogger(), + QNN_LOG_LEVEL_INFO, &LogHandle()); + status != QNN_SUCCESS) { + LITERT_LOG(LITERT_ERROR, "Failed to create QNN logger: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + if (auto status = + Api()->backendCreate(LogHandle(), configs.data(), &BackendHandle()); + status != QNN_SUCCESS) { + LITERT_LOG(LITERT_ERROR, "Failed to create QNN backend: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + + if (soc_model.has_value()) { + LITERT_LOG(LITERT_INFO, + "Initializing QNN backend for device architecture %d", + *soc_model); + QnnHtpDevice_CustomConfig_t arch_custom_config = {}; + arch_custom_config.option = QNN_HTP_DEVICE_CONFIG_OPTION_ARCH; + arch_custom_config.arch.arch = *soc_model; + arch_custom_config.arch.deviceId = 0; + + QnnDevice_Config_t arch_device_config = {}; + arch_device_config.option = QNN_DEVICE_CONFIG_OPTION_CUSTOM; + arch_device_config.customConfig = &arch_custom_config; + + const QnnDevice_Config_t* device_configs[2] = { + &arch_device_config, + nullptr, + }; + + if (auto status = + Api()->deviceCreate(nullptr, device_configs, &DeviceHandle()); + status != QNN_SUCCESS) { + LITERT_LOG(LITERT_ERROR, "Failed to create QNN device: %d", status); + return kLiteRtStatusErrorRuntimeFailure; + } + } + + return kLiteRtStatusOk; +} + +absl::StatusOr +QnnManager::CreateSystemContextHandle() { + QnnSystemContext_Handle_t system_context_handle; + if (auto status = SystemApi()->systemContextCreate(&system_context_handle); + status != QNN_SUCCESS) { + LITERT_LOG(LITERT_ERROR, "Failed to create QNN system context: %d", status); + return absl::InternalError("Failed to create QNN system context"); + } + auto deleter = SystemApi()->systemContextFree; + return SystemContextHandle{system_context_handle, deleter}; +} + +absl::StatusOr QnnManager::CreateContextHandle( + absl::Span configs) { + Qnn_ContextHandle_t context_handle; + if (auto status = Api()->contextCreate(BackendHandle(), DeviceHandle(), + configs.data(), &context_handle); + status != QNN_SUCCESS) { + LITERT_LOG(LITERT_ERROR, "Failed to create QNN context: %d", status); + return absl::InternalError("Failed to create QNN context"); + } + auto deleter = Api()->contextFree; + return ContextHandle{context_handle, /*profile_handle=*/nullptr, deleter}; +} + +absl::StatusOr QnnManager::CreateContextHandle( + absl::Span configs, + absl::Span bytecode, Qnn_ProfileHandle_t profile_handle) { + Qnn_ContextHandle_t context_handle; + if (auto status = Api()->contextCreateFromBinary( + BackendHandle(), DeviceHandle(), configs.data(), bytecode.data(), + bytecode.size(), &context_handle, profile_handle); + status != QNN_SUCCESS) { + LITERT_LOG(LITERT_ERROR, "Failed to create QNN context: %d", status); + return absl::InternalError("Failed to create QNN context"); + } + auto deleter = Api()->contextFree; + return ContextHandle{context_handle, profile_handle, deleter}; +} + +absl::StatusOr QnnManager::Create( + absl::Span configs, + std::optional shared_library_dir, + std::optional soc_model) { + Ptr qnn_manager(new QnnManager); + if (qnn_manager->Init(configs, shared_library_dir, soc_model) != + kLiteRtStatusOk) { + return absl::InternalError("Failed to set up QNN manager"); + } + return qnn_manager; +} + +absl::Span QnnManager::DefaultBackendConfigs() { + static const QnnBackend_Config_t* configs[] = {nullptr}; + return absl::MakeSpan(configs); +} + +absl::Span QnnManager::DefaultContextConfigs() { + static const QnnContext_Config_t* configs[] = {nullptr}; + return absl::MakeSpan(configs); +} + +}; // namespace litert::qnn diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h new file mode 100644 index 00000000000000..9206f3edfac49b --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h @@ -0,0 +1,225 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_MANAGER_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_MANAGER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "third_party/qairt/latest/include/QNN/HTP/QnnHtpDevice.h" +#include "third_party/qairt/latest/include/QNN/QnnBackend.h" +#include "third_party/qairt/latest/include/QNN/QnnCommon.h" +#include "third_party/qairt/latest/include/QNN/QnnContext.h" +#include "third_party/qairt/latest/include/QNN/QnnInterface.h" +#include "third_party/qairt/latest/include/QNN/System/QnnSystemContext.h" +#include "third_party/qairt/latest/include/QNN/System/QnnSystemInterface.h" +#include "tensorflow/lite/experimental/litert/c/litert_common.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/common.h" + +//===----------------------------------------------------------------------===// +// +// QnnManger +// +// Syntactic sugar for various Qnn Sdk routines. +// +// Provides various utilities for linking shared libraries at runtime +// against Qnn symbols as well as convience getters and storage of handles +// (pointers). Provides simple wrappers for freeing handles and returning +// LiteRtStatus rather than Qnn ones. Additionally exposes hooks for dumping +// api and shared libarary details. +// +// Does not own any memory and will always have trivial cstor/dstor. The +// user is responsible for freeing any Qnn handles explicitly. Note, +// Qnn handles will be automatically freed when the library is unloaded +// if they have been already. +// +//===----------------------------------------------------------------------===// + +namespace litert::qnn { + +class QnnManager; + +namespace internal { + +void Dump(const QnnManager& qnn, std::ostream& out); + +} // namespace internal + +class QnnManager { + friend void internal::Dump(const QnnManager& qnn, std::ostream& out); + + public: + using Ptr = std::unique_ptr; + using SystemContextHandle = + std::unique_ptr::type, + QnnSystemContext_FreeFn_t>; + class ContextHandle; + + ~QnnManager(); + + static absl::StatusOr Create( + absl::Span configs, + std::optional shared_library_dir = std::nullopt, + std::optional soc_model = std::nullopt); + + static absl::Span DefaultBackendConfigs(); + static absl::Span DefaultContextConfigs(); + + // Get resolved function pointers for qnn sdk calls. Nullptr if functions + // have not been resolved yet. + const QnnApi* Api() const; + + // Get resolved function pointers for qnn sdk calls. Nullptr if functions + // have not been resolved yet. + const QnnSystemApi* SystemApi() const; + + // + // QNN SDK Objects. + // + + // Create system context handle. + absl::StatusOr CreateSystemContextHandle(); + + // Create a context handle for compilation. + absl::StatusOr CreateContextHandle( + absl::Span configs); + + // Create a context handle for inference, from a given bytecode. + absl::StatusOr CreateContextHandle( + absl::Span configs, + absl::Span bytecode, Qnn_ProfileHandle_t profile_handle); + + // + // Context Binary + // + + // Generates QNN context binary from current context. Writes to given + // buffer. + LiteRtStatus GenerateContextBinary(Qnn_ContextHandle_t context_handle, + std::vector& buffer); + + private: + QnnManager() = default; + + LiteRtStatus Init(absl::Span configs, + std::optional shared_library_dir, + std::optional soc_model); + + // + // Manage libQnn*.so Loading + // + + // Loads the libQnn*.so at given path. + LiteRtStatus LoadLib(absl::string_view path); + + // Loads the libQnnSystem.so at given path. + LiteRtStatus LoadSystemLib(absl::string_view path); + + // + // Resolve and Access QNN SDK Functions + // + + // Resolve all available QNN SDK functions from (already) loaded so. If + // multiple providers are found, selects the first one with a suitable + // version. Fails if none can be found. + LiteRtStatus ResolveApi(); + + // Resolve all available QNN SDK functions from (already) loaded so. If + // multiple providers are found, selects the first one with a suitable + // version. Fails if none can be found. + LiteRtStatus ResolveSystemApi(); + + // Get qnn log handle. Nullptr if logCreate has not been successfully called. + Qnn_LogHandle_t& LogHandle() { return log_handle_; } + + // Get qnn backend handle. Nullptr if backendCreate has not been successfully + // called. + Qnn_BackendHandle_t& BackendHandle() { return backend_handle_; } + + // Get qnn device handle. Nullptr if deviceCreate has not been successfully + // called. + Qnn_DeviceHandle_t& DeviceHandle() { return device_handle_; } + + // Signal QNN SDK to free any memory related to the device. Does nothing + // if deviceCreate has not been called. + LiteRtStatus FreeDevice(); + + // Signal QNN SDK to free any memory related to logging. Does nothing + // if logCreate has not been called. + LiteRtStatus FreeLogging(); + + // Signal QNN SDK to free any memory related to backend. Does nothing + // if backendCreate has not been called. + LiteRtStatus FreeBackend(); + + void* lib_so_ = nullptr; + void* lib_system_so_ = nullptr; + + const QnnInterface_t* interface_ = nullptr; + const QnnSystemInterface_t* system_interface_ = nullptr; + + Qnn_LogHandle_t log_handle_ = nullptr; + Qnn_BackendHandle_t backend_handle_ = nullptr; + Qnn_DeviceHandle_t device_handle_ = nullptr; +}; + +// Unfortunately we can't use std::unique_ptr with a deleter because +// QnnContext_FreeFn_t takes a profile handle as a second argument. +class QnnManager::ContextHandle { + public: + ContextHandle(Qnn_ContextHandle_t context_handle, Qnn_ProfileHandle_t profile, + QnnContext_FreeFn_t free_fn) + : context_handle_(context_handle), profile_(profile), free_fn_(free_fn) {} + + ~ContextHandle() { + if (context_handle_ && free_fn_) { + free_fn_(context_handle_, profile_); + } + } + + ContextHandle(ContextHandle&& other) { *this = std::move(other); } + + ContextHandle(const ContextHandle& other) = delete; + + ContextHandle& operator=(ContextHandle&& other) { + std::swap(context_handle_, other.context_handle_); + std::swap(profile_, other.profile_); + std::swap(free_fn_, other.free_fn_); + return *this; + } + + ContextHandle& operator=(const ContextHandle& other) = delete; + + Qnn_ContextHandle_t get() const noexcept { return context_handle_; } + explicit operator bool() const noexcept { return context_handle_ != nullptr; } + + private: + Qnn_ContextHandle_t context_handle_ = nullptr; + Qnn_ProfileHandle_t profile_ = nullptr; + QnnContext_FreeFn_t free_fn_ = nullptr; +}; + +} // namespace litert::qnn + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_MANAGER_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager_test.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager_test.cc new file mode 100644 index 00000000000000..e61490cb21bf2a --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager_test.cc @@ -0,0 +1,48 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" + +#include +#include +#include "tensorflow/lite/experimental/litert/test/common.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.h" + +namespace { + +using ::litert::qnn::QnnManager; +using ::litert::qnn::internal::Dump; +using ::testing::HasSubstr; + +// NOTE: This tests that all of the dynamic loading works properly and +// the QNN SDK instance can be properly initialized and destroyed. + +TEST(QnnManagerTest, SetupQnnManager) { + auto configs = QnnManager::DefaultBackendConfigs(); + auto qnn = QnnManager::Create(configs); + ASSERT_TRUE(qnn.ok()); +} + +TEST(QnnManagerTest, Dump) { + auto configs = QnnManager::DefaultBackendConfigs(); + auto qnn = QnnManager::Create(configs); + ASSERT_TRUE(qnn.ok()); + + std::ostringstream dump; + Dump(**qnn, dump); + + EXPECT_THAT(dump.str(), HasSubstr("< QnnInterface_t >")); + EXPECT_THAT(dump.str(), HasSubstr("< QnnSystemInterface_t >")); +} + +} // namespace diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.cc new file mode 100644 index 00000000000000..4f1a4c6f048a59 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.cc @@ -0,0 +1,101 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.h" + +#include +#include +#include + +#include "absl/log/absl_check.h" +#include "absl/log/absl_log.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" + +namespace litert { +namespace qnn { + +QnnTensor::QnnTensor(const QnnTensor& other) : QnnTensor(other.Tensor()) { + auto status = DeepCopy(); + // This should never fail because the input QnnTensor was already deep-copied. + if (!status.ok()) { + ABSL_LOG(ERROR) << "Failed to build QnnTensor: " << status; + ABSL_CHECK_OK(status); + } +} + +QnnTensor::QnnTensor(QnnTensor&& other) { + tensor_ = other.tensor_; + // Swap managed memory. + std::swap(name_, other.name_); + std::swap(dimensions_, other.dimensions_); + std::swap(is_dynamic_dimensions_, other.is_dynamic_dimensions_); +} + +absl::StatusOr QnnTensor::Create(const Qnn_Tensor_t& tensor) { + QnnTensor qnn_tensor(tensor); + if (auto status = qnn_tensor.DeepCopy(); !status.ok()) { + return status; + } + return qnn_tensor; +} + +absl::Status QnnTensor::DeepCopy() { + if (tensor_.version == QNN_TENSOR_VERSION_1) { + dimensions_.reserve(tensor_.v1.rank); + std::copy(tensor_.v1.dimensions, tensor_.v1.dimensions + tensor_.v1.rank, + std::back_inserter(dimensions_)); + tensor_.v1.dimensions = dimensions_.data(); + + // FIXME: Implement deep copy for quantizeParams. + if (tensor_.v1.quantizeParams.quantizationEncoding == + QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION || + tensor_.v1.quantizeParams.quantizationEncoding == + QNN_QUANTIZATION_ENCODING_VECTOR) { + return absl::InternalError("Unsupported QNN quantization"); + } + + } else if (tensor_.version == QNN_TENSOR_VERSION_2) { + dimensions_.reserve(tensor_.v2.rank); + std::copy(tensor_.v2.dimensions, tensor_.v2.dimensions + tensor_.v2.rank, + std::back_inserter(dimensions_)); + tensor_.v2.dimensions = dimensions_.data(); + + if (tensor_.v2.isDynamicDimensions) { + is_dynamic_dimensions_.reserve(tensor_.v2.rank); + std::copy(tensor_.v2.isDynamicDimensions, + tensor_.v2.isDynamicDimensions + tensor_.v2.rank, + std::back_inserter(is_dynamic_dimensions_)); + tensor_.v2.isDynamicDimensions = is_dynamic_dimensions_.data(); + } + + // FIXME: Implement deep copy for quantizeParams. + if (tensor_.v2.quantizeParams.quantizationEncoding == + QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION || + tensor_.v2.quantizeParams.quantizationEncoding == + QNN_QUANTIZATION_ENCODING_VECTOR) { + return absl::InternalError("Unsupported QNN quantization"); + } + + } else { + return absl::InternalError("Unsupported QNN tensor version"); + } + + return {}; +} + +} // namespace qnn +} // namespace litert diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.h new file mode 100644 index 00000000000000..3ba66c8c341b5d --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_tensor.h @@ -0,0 +1,59 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_TENSOR_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_TENSOR_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnInterface.h" +#include "third_party/qairt/latest/include/QNN/QnnTypes.h" + +namespace litert { +namespace qnn { + +class QnnTensor { + public: + static absl::StatusOr Create(const Qnn_Tensor_t& tensor); + + QnnTensor(const QnnTensor& other); + QnnTensor(QnnTensor&& other); + + QnnTensor& operator=(const QnnTensor&) = delete; + QnnTensor& operator=(QnnTensor&&) = delete; + + Qnn_Tensor_t& Tensor() { return tensor_; } + const Qnn_Tensor_t& Tensor() const { return tensor_; } + + size_t Rank() const { return dimensions_.size(); } + const uint32_t* Dimensions() const { return dimensions_.data(); } + + private: + explicit QnnTensor(const Qnn_Tensor_t& tensor) : tensor_(tensor) {} + absl::Status DeepCopy(); + + Qnn_Tensor_t tensor_; + std::string name_; + std::vector dimensions_; + std::vector is_dynamic_dimensions_; +}; + +} // namespace qnn +} // namespace litert + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_QNN_TENSOR_H_ diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/qualcomm_build_defs.bzl b/tensorflow/lite/experimental/litert/vendors/qualcomm/qualcomm_build_defs.bzl new file mode 100644 index 00000000000000..939849401306f4 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/qualcomm_build_defs.bzl @@ -0,0 +1,83 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Build definitions for QualComm backend.""" + +load("//tensorflow/lite/experimental/litert/build_common:litert_build_defs.bzl", "append_rule_kwargs", "litert_lib", "make_rpaths") + +_QNN_LIBCC_X86_64 = [ + # copybara:uncomment_begin(google-only) + # "//third_party/qairt/latest:lib/x86_64-linux-clang/libc++.so.1", + # "//third_party/qairt/latest:lib/x86_64-linux-clang/libc++abi.so.1", + # copybara:uncomment_end +] # @unused + +# TODO: Make rpaths dynamic with "$(location {})". +_QNN_LIB_RPATHS_X86_64 = [ + # copybara:uncomment_begin(google-only) + # "third_party/qairt/latest/lib/x86_64-linux-clang", + # copybara:uncomment_end +] + +_QNN_LIB_HTP_X86_64 = [ + # copybara:uncomment_begin(google-only) + # "//third_party/qairt/latest:lib/x86_64-linux-clang/libQnnHtp.so", + # copybara:uncomment_end +] + +_QNN_LIB_SYSTEM_X86_64 = [ + # copybara:uncomment_begin(google-only) + # "//third_party/qairt/latest:lib/x86_64-linux-clang/libQnnSystem.so", + # copybara:uncomment_end +] + +def litert_lib_with_qnn( + backend = "htp", + include_system = False, + use_custom_libcc = False, + **litert_lib_kwargs): + """Creates a litert_lib target with QualComm backend dependencies. + + Args: + backend: The backend to use. Currently only "htp" is supported. + include_system: Whether to include libQnnSystem.so. + use_custom_libcc: Whether to use a custom libcc. Not yet supported. + **litert_lib_kwargs: Keyword arguments passed to litert_lib. + """ + if backend != "htp": + fail("Only htp currently supported") + + if use_custom_libcc: + # TODO: Figure out strategy for custom libcc. + fail("Custom libcc not yet supported") + + data_x86_64 = [] + data_x86_64.extend(_QNN_LIB_HTP_X86_64) + if include_system: + data_x86_64.extend(_QNN_LIB_SYSTEM_X86_64) + data = select({ + "//tensorflow:linux_x86_64": data_x86_64, + "//conditions:default": [], + }) + + append_rule_kwargs( + litert_lib_kwargs, + data = data, + linkopts = select({ + "//tensorflow:linux_x86_64": [make_rpaths(_QNN_LIB_RPATHS_X86_64)], + "//conditions:default": [], + }), + ) + + litert_lib(**litert_lib_kwargs) diff --git a/tensorflow/compiler/mlir/lite/experimental/lrt/cc/BUILD b/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/BUILD similarity index 61% rename from tensorflow/compiler/mlir/lite/experimental/lrt/cc/BUILD rename to tensorflow/lite/experimental/litert/vendors/qualcomm/tools/BUILD index 351b5d14748d78..45df0fef3b5a21 100644 --- a/tensorflow/compiler/mlir/lite/experimental/lrt/cc/BUILD +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/BUILD @@ -14,19 +14,18 @@ package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow/compiler/mlir/lite/experimental/lrt:__subpackages__"], + default_visibility = ["//tensorflow/lite/experimental/litert:__subpackages__"], ) cc_library( - name = "lite_rt_cc_api", - hdrs = [ - "lite_rt_support.h", - ], + name = "dump", + srcs = ["dump.cc"], + hdrs = ["dump.h"], + tags = ["nobuilder"], deps = [ - "//tensorflow/compiler/mlir/lite/core:model_builder_base", - "//tensorflow/compiler/mlir/lite/experimental/lrt/c:lite_rt_c_api", - "//tensorflow/lite/c:c_api_types", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + # copybara:uncomment "//third_party/qairt/latest:qnn_lib_headers", + "//tensorflow/lite/experimental/litert/vendors/qualcomm:qnn_manager_hdr", ], ) - -exports_files(srcs = glob(["lite_rt*.h"])) diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.cc b/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.cc new file mode 100644 index 00000000000000..0e94b6b0385890 --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.cc @@ -0,0 +1,88 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.h" + +#include + +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "third_party/qairt/latest/include/QNN/QnnInterface.h" +#include "third_party/qairt/latest/include/QNN/System/QnnSystemInterface.h" +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" + +namespace litert::qnn::internal { +namespace { + +static constexpr absl::string_view kNullDumpTpl = "%s : nullptr\n"; + +void Dump(const QnnInterface_t* interface, std::ostream& out) { + static constexpr absl::string_view kQnnInterfaceHeader = "< QnnInterface_t >"; + // NOLINTBEGIN + static constexpr absl::string_view kQnnInterfaceDumpTpl = + "\ + %s\n\ + name: %s\n\ + backend_id: %u\n\ + core_api_version: %u.%u.%u\n\ + backend_api_version: %u.%u.%u\n"; + // NOLINTEND + + if (interface == nullptr) { + out << absl::StreamFormat(kNullDumpTpl, kQnnInterfaceHeader); + return; + } + + const auto core_version = interface->apiVersion.coreApiVersion; + const auto backend_version = interface->apiVersion.backendApiVersion; + + out << absl::StreamFormat(kQnnInterfaceDumpTpl, kQnnInterfaceHeader, + interface->providerName, interface->backendId, + core_version.major, core_version.minor, + core_version.patch, backend_version.major, + backend_version.minor, backend_version.patch); +} + +void Dump(const QnnSystemInterface_t* interface, std::ostream& out) { + static constexpr absl::string_view kQnnSystemInterfaceHeader = + "< QnnSystemInterface_t >"; + // NOLINTBEGIN + static constexpr absl::string_view kQnnSystemInterfaceDumpTpl = + "\ + %s\n\ + name: %s\n\ + backend_id: %u\n\ + system_api_version: %u.%u.%u\n"; + // NOLINTEND + + if (interface == nullptr) { + out << absl::StreamFormat(kNullDumpTpl, kQnnSystemInterfaceHeader); + return; + } + + const auto system_version = interface->systemApiVersion; + + out << absl::StreamFormat(kQnnSystemInterfaceDumpTpl, + kQnnSystemInterfaceHeader, interface->providerName, + interface->backendId, system_version.major, + system_version.minor, system_version.patch); +} + +} // namespace + +void Dump(const QnnManager& qnn, std::ostream& out) { + Dump(qnn.interface_, out); + Dump(qnn.system_interface_, out); +} +} // namespace litert::qnn::internal diff --git a/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.h b/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.h new file mode 100644 index 00000000000000..b64650249af0af --- /dev/null +++ b/tensorflow/lite/experimental/litert/vendors/qualcomm/tools/dump.h @@ -0,0 +1,29 @@ +// Copyright 2024 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_TOOLS_DUMP_H_ +#define TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_TOOLS_DUMP_H_ + +#include +#include + +#include "tensorflow/lite/experimental/litert/vendors/qualcomm/qnn_manager.h" + +namespace litert::qnn::internal { + +void Dump(const QnnManager& qnn, std::ostream& out = std::cerr); + +} + +#endif // TENSORFLOW_LITE_EXPERIMENTAL_LITERT_VENDORS_QUALCOMM_TOOLS_DUMP_H_ diff --git a/tensorflow/lite/experimental/shlo/quantized_tensor_element_type_test.cc b/tensorflow/lite/experimental/shlo/quantized_tensor_element_type_test.cc index 10c5a878ddd931..de2bcc3f0d079d 100644 --- a/tensorflow/lite/experimental/shlo/quantized_tensor_element_type_test.cc +++ b/tensorflow/lite/experimental/shlo/quantized_tensor_element_type_test.cc @@ -28,9 +28,7 @@ namespace shlo_ref { namespace { using testing::Each; -using testing::ElementsAre; using testing::ElementsAreArray; -using testing::Eq; using testing::FloatEq; using testing::Pointwise; diff --git a/tensorflow/lite/java/BUILD b/tensorflow/lite/java/BUILD index e4c1432e6c8dda..64eaf2f5323dda 100644 --- a/tensorflow/lite/java/BUILD +++ b/tensorflow/lite/java/BUILD @@ -133,6 +133,7 @@ filegroup( TFLITE_HEADERS = [ # TODO(b/175298345): Clean up and if possible remove c:common.h and core/c:common.h here. "//tensorflow/lite:builtin_ops.h", + "//tensorflow/lite/c:builtin_op_data.h", "//tensorflow/lite/c:c_api.h", "//tensorflow/lite/c:c_api_experimental.h", "//tensorflow/lite/c:c_api_opaque.h", diff --git a/tensorflow/lite/java/ovic/BUILD b/tensorflow/lite/java/ovic/BUILD index e36c77b7369780..a6ce1d4a07aeea 100644 --- a/tensorflow/lite/java/ovic/BUILD +++ b/tensorflow/lite/java/ovic/BUILD @@ -2,6 +2,7 @@ # OVIC Benchmarker Java API. load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load("@rules_java//java:defs.bzl", "java_binary", "java_library", "java_test") load("//tensorflow/java:build_defs.bzl", "JAVACOPTS") package( diff --git a/tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD b/tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD index 365942d6490601..894858dbd4e022 100644 --- a/tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD +++ b/tensorflow/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD @@ -2,6 +2,7 @@ # Internal helper function to test TF Lite API. load("@build_bazel_rules_android//android:rules.bzl", "android_library") +load("@rules_java//java:defs.bzl", "java_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:LICENSE"], diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD index 003cd5c5b9e968..b7daf2c331cbb5 100644 --- a/tensorflow/lite/kernels/BUILD +++ b/tensorflow/lite/kernels/BUILD @@ -14,7 +14,10 @@ package( filegroup( name = "tflite_internal_cc_3p_api_deps_src", - srcs = ["op_macros.h"], + srcs = [ + "builtin_ops_list.inc", + "op_macros.h", + ], visibility = [ "//tensorflow/lite:__pkg__", ], @@ -839,6 +842,7 @@ BUILTIN_KERNEL_DEPS = [ "//tensorflow/lite/kernels/internal:cpu_check", "//tensorflow/lite/kernels/internal:kernel_utils", "//tensorflow/lite/kernels/internal:optimized_base", + "//tensorflow/lite/kernels/internal:portable_tensor_utils", "//tensorflow/lite/kernels/internal:quantization_util", "//tensorflow/lite/kernels/internal:reference_base", "//tensorflow/lite/kernels/internal:strided_slice_logic", @@ -852,17 +856,6 @@ BUILTIN_KERNEL_DEPS = [ ":eigen_support", "//tensorflow/lite/kernels/internal:optimized_eigen", ], -}) + select({ - "//tensorflow/lite:tflite_with_xnnpack_explicit_false": [], - "//conditions:default": [ - "@pthreadpool", - ], -}) + select({ - "//tensorflow/lite:tflite_with_xnnpack_explicit_false": [], - "//tensorflow/lite:tflite_kernel_use_xnnpack_false": [], - "//conditions:default": [ - "@XNNPACK", - ], }) + select({ # This select must match the similar select in `copts` "//tensorflow:linux_ppc64le": [], @@ -896,8 +889,8 @@ cc_library( "//tensorflow/lite:array", "//tensorflow/lite:builtin_ops", "//tensorflow/lite:cc_api_stable", - "@local_tsl//tsl/lib/random:philox_random", - "@local_tsl//tsl/lib/random:random_distributions_utils", + "@local_xla//xla/tsl/lib/random:philox_random", + "@local_xla//xla/tsl/lib/random:random_distributions_utils", "//tensorflow/lite/core/c:c_api_types", # TODO(b/179298174): Move out from the experimental directory. "//tensorflow/lite/experimental/resource", @@ -1788,6 +1781,8 @@ cc_test( deps = [ ":test_main", ":test_util", + "//tensorflow/lite/core/c:common", + "//tensorflow/lite/kernels/internal:tensor_utils_no_eigen", "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest", ], diff --git a/tensorflow/lite/kernels/acceleration_test_util_internal.h b/tensorflow/lite/kernels/acceleration_test_util_internal.h index 3b5a5166ece46d..afb928dcbb016a 100644 --- a/tensorflow/lite/kernels/acceleration_test_util_internal.h +++ b/tensorflow/lite/kernels/acceleration_test_util_internal.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -41,20 +42,20 @@ class ConfigurationEntry { public: ConfigurationEntry(const std::string& test_id_rex, T test_config, bool is_denylist) - : test_id_rex_(test_id_rex), + : test_id_rex_(new RE2(test_id_rex)), test_config_(test_config), is_denylist_(is_denylist) {} - bool Matches(const std::string& test_id) { - return RE2::FullMatch(test_id, test_id_rex_); + bool Matches(const std::string& test_id) const { + return RE2::FullMatch(test_id, *test_id_rex_); } bool IsDenylistEntry() const { return is_denylist_; } const T& TestConfig() const { return test_config_; } - const std::string& TestIdRex() const { return test_id_rex_; } + const std::string& TestIdRex() const { return test_id_rex_->pattern(); } private: - std::string test_id_rex_; + std::unique_ptr test_id_rex_; T test_config_; bool is_denylist_; }; @@ -74,7 +75,7 @@ std::optional GetAccelerationTestParam(std::string test_id) { auto consumer = [&config](std::string key, std::string value_str, bool is_denylist) mutable { T value = T::ParseConfigurationLine(value_str); - config->push_back(ConfigurationEntry(key, value, is_denylist)); + config->emplace_back(key, value, is_denylist); }; ReadAccelerationConfig(T::AccelerationTestConfig(), consumer); @@ -88,9 +89,11 @@ std::optional GetAccelerationTestParam(std::string test_id) { const std::vector>* test_config = test_config_ptr.load(); - const auto test_config_iter = std::find_if( - test_config->begin(), test_config->end(), - [&test_id](ConfigurationEntry elem) { return elem.Matches(test_id); }); + const auto test_config_iter = + std::find_if(test_config->begin(), test_config->end(), + [&test_id](const ConfigurationEntry& elem) { + return elem.Matches(test_id); + }); if (test_config_iter != test_config->end() && !test_config_iter->IsDenylistEntry()) { return std::optional(test_config_iter->TestConfig()); diff --git a/tensorflow/lite/kernels/batch_matmul.cc b/tensorflow/lite/kernels/batch_matmul.cc index d1eb3130c78c04..b5f416ea68ceef 100644 --- a/tensorflow/lite/kernels/batch_matmul.cc +++ b/tensorflow/lite/kernels/batch_matmul.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/batch_matmul.h" #include +#include #include #include @@ -440,6 +441,17 @@ RuntimeShape SwapRowColumnDims(const RuntimeShape& shape) { return swapped_shape; } +TfLiteStatus VerifyPerChannelQuantization(TfLiteContext* context, + const TfLiteTensor* tensor) { + TF_LITE_ENSURE_EQ(context, tensor->quantization.type, + kTfLiteAffineQuantization); + const auto* affine_quantization = + reinterpret_cast(tensor->quantization.params); + TF_LITE_ENSURE(context, affine_quantization); + TF_LITE_ENSURE(context, affine_quantization->scale); + return affine_quantization->scale->size > 1 ? kTfLiteOk : kTfLiteError; +} + TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, OpData* data, const RuntimeShape& input_shape, const TfLiteTensor* input, @@ -481,9 +493,22 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, OpData* data, input_size, quant_data, scaling_factors_ptr, input_offset_ptr, params->asymmetric_quantize_inputs); - for (int b = 0; b < num_batches_to_quantize; ++b) { - // Incorporate scaling of the filter. - scaling_factors_ptr[b] *= filter->params.scale; + float* per_channel_scale_ptr = nullptr; + if (VerifyPerChannelQuantization(context, filter) == kTfLiteOk) { + // Per channel quantization. + const auto* affine_quantization = + reinterpret_cast( + filter->quantization.params); + TF_LITE_ENSURE_EQ( + context, affine_quantization->scale->size, + filter->dims->data[affine_quantization->quantized_dimension]); + per_channel_scale_ptr = affine_quantization->scale->data; + } else { + // Per tensor quantization. + for (int b = 0; b < num_batches_to_quantize; ++b) { + // Incorporate scaling of the filter + scaling_factors_ptr[b] *= filter->params.scale; + } } RuntimeShape output_shape = GetTensorShape(output); @@ -492,10 +517,11 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, OpData* data, output_size *= output_shape.Dims(i); } std::fill_n(GetTensorData(output), output_size, 0.0f); - reference_ops::BatchMatMul( - filter_shape, filter_data, input_shape, quant_data, scaling_factors_ptr, - input_offset_ptr, row_sums_ptr, GetTensorShape(output), - GetTensorData(output), &(data->compute_row_sums)); + reference_ops::BatchMatMul(filter_shape, filter_data, input_shape, quant_data, + scaling_factors_ptr, input_offset_ptr, + row_sums_ptr, GetTensorShape(output), + GetTensorData(output), + &(data->compute_row_sums), per_channel_scale_ptr); return kTfLiteOk; } @@ -660,10 +686,44 @@ TfLiteTensor* GetTempRhs(TfLiteContext* context, TfLiteNode* node, return nullptr; } + TfLiteIntArrayFree(transposed_rhs->dims); + transposed_rhs->dims = TfLiteIntArrayCopy(rhs->dims); + std::swap(transposed_rhs->dims->data[transposed_rhs->dims->size - 1], + transposed_rhs->dims->data[transposed_rhs->dims->size - 2]); if (rhs->type == kTfLiteInt8 || rhs->type == kTfLiteInt16) { // Get the quantization params from the RHS tensor. transposed_rhs->params.scale = rhs->params.scale; transposed_rhs->params.zero_point = rhs->params.zero_point; + if (rhs->quantization.type == kTfLiteAffineQuantization) { + transposed_rhs->quantization.type = rhs->quantization.type; + if (transposed_rhs->quantization.params) { + auto* transposed_rhs_affine_quantization = + reinterpret_cast( + transposed_rhs->quantization.params); + TfLiteIntArrayFree(transposed_rhs_affine_quantization->zero_point); + TfLiteFloatArrayFree(transposed_rhs_affine_quantization->scale); + free(transposed_rhs->quantization.params); + } + transposed_rhs->quantization.params = + malloc(sizeof(TfLiteAffineQuantization)); + const auto* rhs_affine_quantization = + reinterpret_cast(rhs->quantization.params); + auto* transposed_rhs_affine_quantization = + reinterpret_cast( + transposed_rhs->quantization.params); + int quantized_dimension = rhs_affine_quantization->quantized_dimension; + if (quantized_dimension == rhs->dims->size - 1) { + quantized_dimension = rhs->dims->size - 2; + } else if (quantized_dimension == rhs->dims->size - 2) { + quantized_dimension = rhs->dims->size - 1; + } + transposed_rhs_affine_quantization->quantized_dimension = + quantized_dimension; + transposed_rhs_affine_quantization->zero_point = + TfLiteIntArrayCopy(rhs_affine_quantization->zero_point); + transposed_rhs_affine_quantization->scale = + TfLiteFloatArrayCopy(rhs_affine_quantization->scale); + } } return transposed_rhs; } @@ -738,8 +798,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { lhs_dims_count = orig_lhs_shape.DimensionsCount(); const TfLiteTensor* rhs_tensor = rhs; bool implicit_transpose_possible = true; - if ((lhs->type == kTfLiteFloat32 && rhs->type == kTfLiteInt8) || - kernel_type == kReference || rhs->type == kTfLiteInt16) { + if (lhs->type == kTfLiteFloat32 || kernel_type == kReference || + rhs->type == kTfLiteInt16) { implicit_transpose_possible = false; } bool do_implicit_transpose = !adj_y && implicit_transpose_possible; @@ -767,18 +827,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { switch (rhs->type) { case kTfLiteFloat32: // Note we pass RHS args first, LHS args second. See note above. - if (kernel_type == kGenericOptimized) { - optimized_ops::BatchMatMul( - rhs_shape, GetTensorData(rhs_tensor), lhs_shape, - GetTensorData(lhs_tensor), GetTensorShape(output), - GetTensorData(output), - CpuBackendContext::GetFromContext(context), do_implicit_transpose); - } else { - reference_ops::BatchMatMul(rhs_shape, GetTensorData(rhs_tensor), - lhs_shape, GetTensorData(lhs_tensor), - GetTensorShape(output), - GetTensorData(output)); - } + reference_ops::BatchMatMul(rhs_shape, GetTensorData(rhs_tensor), + lhs_shape, GetTensorData(lhs_tensor), + GetTensorShape(output), + GetTensorData(output)); break; case kTfLiteInt8: case kTfLiteInt16: diff --git a/tensorflow/lite/kernels/conv.cc b/tensorflow/lite/kernels/conv.cc index befd60da71b7c8..76abb49a3250da 100644 --- a/tensorflow/lite/kernels/conv.cc +++ b/tensorflow/lite/kernels/conv.cc @@ -38,10 +38,9 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/optimized/multithreaded_conv.h" #endif #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h" #include "tensorflow/lite/kernels/internal/reference/conv.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h" -#include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -520,7 +519,7 @@ TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context, &context->tensors[node->temporaries->data[data->im2col_index]]; im2col->type = input->type; if (is_hybrid) { - im2col->type = filter->type; + im2col->type = filter->type == kTfLiteInt4 ? kTfLiteInt8 : filter->type; } im2col->allocation_type = kTfLiteArenaRw; auto im2col_status = context->ResizeTensor(context, im2col, im2col_size); @@ -696,6 +695,19 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, effective_kernel_type = kReference; } + const uint8_t* filter_data = nullptr; + std::unique_ptr unpacked_filter_data = nullptr; + if (filter->type == kTfLiteInt4) { + const size_t bytes_unpacked = filter->bytes * 2; + unpacked_filter_data = std::make_unique(bytes_unpacked); + tflite::tensor_utils::UnpackDenseInt4IntoInt8( + GetTensorData(filter), GetTensorShape(filter).FlatSize(), + unpacked_filter_data.get()); + filter_data = reinterpret_cast(unpacked_filter_data.get()); + } else { + filter_data = GetTensorData(filter); + } + ConvParams op_params; op_params.padding_type = PaddingType::kSame; op_params.padding_values.width = data->padding.width; @@ -715,10 +727,10 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, case kReference: { reference_ops::Conv( op_params, GetTensorShape(input), GetTensorData(input), - GetTensorShape(filter), GetTensorData(filter), - GetTensorShape(bias), GetTensorData(bias), - GetTensorShape(output), GetTensorData(output), - GetTensorShape(im2col), GetTensorData(im2col), + GetTensorShape(filter), filter_data, GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), + GetTensorData(output), GetTensorShape(im2col), + GetTensorData(im2col), /* cpu_backend_context = */ nullptr); break; } @@ -728,10 +740,10 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, // There is only one optimized implementation for Quantized Conv. optimized_ops::Conv( op_params, GetTensorShape(input), GetTensorData(input), - GetTensorShape(filter), GetTensorData(filter), - GetTensorShape(bias), GetTensorData(bias), - GetTensorShape(output), GetTensorData(output), - GetTensorShape(im2col), GetTensorData(im2col), + GetTensorShape(filter), filter_data, GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), + GetTensorData(output), GetTensorShape(im2col), + GetTensorData(im2col), CpuBackendContext::GetFromContext(context)); break; } @@ -770,17 +782,17 @@ void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node, effective_kernel_type = kReference; } - const int8_t* filter_data; - const size_t bytes_unpacked = filter->bytes * 2; - auto unpacked_filter_data = std::make_unique(bytes_unpacked); - + const int8_t* filter_data = nullptr; + std::unique_ptr unpacked_filter_data = nullptr; if (filter->type == kTfLiteInt4) { + const size_t bytes_unpacked = filter->bytes * 2; + unpacked_filter_data = std::make_unique(bytes_unpacked); tflite::tensor_utils::UnpackDenseInt4IntoInt8( GetTensorData(filter), GetTensorShape(filter).FlatSize(), unpacked_filter_data.get()); filter_data = unpacked_filter_data.get(); } else { - filter_data = GetTensorData(filter); + filter_data = GetTensorData(filter); } switch (effective_kernel_type) { @@ -870,24 +882,35 @@ void EvalQuantizedPerChannel16x8(TfLiteContext* context, TfLiteNode* node, filter->params.zero_point || output->params.zero_point; + const int8_t* filter_data = nullptr; + std::unique_ptr unpacked_filter_data = nullptr; + if (filter->type == kTfLiteInt4) { + const size_t bytes_unpacked = filter->bytes * 2; + unpacked_filter_data = std::make_unique(bytes_unpacked); + tflite::tensor_utils::UnpackDenseInt4IntoInt8( + GetTensorData(filter), GetTensorShape(filter).FlatSize(), + unpacked_filter_data.get()); + filter_data = unpacked_filter_data.get(); + } else { + filter_data = GetTensorData(filter); + } + if (data->quantized_bias_type == kTfLiteInt32) { if (effective_kernel_type == kReference || has_non_zero_point) { reference_integer_ops::ConvPerChannel( op_params, data->per_channel_output_multiplier.data(), data->per_channel_output_shift.data(), GetTensorShape(input), - GetTensorData(input), GetTensorShape(filter), - GetTensorData(filter), GetTensorShape(bias), - GetTensorData(bias), GetTensorShape(output), - GetTensorData(output)); + GetTensorData(input), GetTensorShape(filter), filter_data, + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output)); } else { optimized_integer_ops::ConvPerChannel( op_params, data->per_channel_output_multiplier.data(), data->per_channel_output_shift.data(), GetTensorShape(input), - GetTensorData(input), GetTensorShape(filter), - GetTensorData(filter), GetTensorShape(bias), - GetTensorData(bias), GetTensorShape(output), - GetTensorData(output), GetTensorShape(im2col), - GetTensorData(im2col), + GetTensorData(input), GetTensorShape(filter), filter_data, + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output), + GetTensorShape(im2col), GetTensorData(im2col), CpuBackendContext::GetFromContext(context)); } } else { @@ -897,10 +920,9 @@ void EvalQuantizedPerChannel16x8(TfLiteContext* context, TfLiteNode* node, reference_integer_ops::ConvPerChannel( op_params, data->per_channel_output_multiplier.data(), data->per_channel_output_shift.data(), GetTensorShape(input), - GetTensorData(input), GetTensorShape(filter), - GetTensorData(filter), GetTensorShape(bias), - GetTensorData(bias), GetTensorShape(output), - GetTensorData(output)); + GetTensorData(input), GetTensorShape(filter), filter_data, + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output)); } } @@ -1041,11 +1063,9 @@ TfLiteStatus EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node, } int8_t* im2col_ptr = nullptr; - int8_t* filter_ptr = nullptr; if (im2col != nullptr) { im2col_ptr = im2col->data.int8; } - filter_ptr = filter->data.int8; const auto* affine_quantization = reinterpret_cast(filter->quantization.params); @@ -1062,6 +1082,19 @@ TfLiteStatus EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node, effective_kernel_type = kReference; } + const int8_t* filter_data = nullptr; + std::unique_ptr unpacked_filter_data = nullptr; + if (filter->type == kTfLiteInt4) { + const size_t bytes_unpacked = filter->bytes * 2; + unpacked_filter_data = std::make_unique(bytes_unpacked); + tflite::tensor_utils::UnpackDenseInt4IntoInt8( + GetTensorData(filter), GetTensorShape(filter).FlatSize(), + unpacked_filter_data.get()); + filter_data = unpacked_filter_data.get(); + } else { + filter_data = GetTensorData(filter); + } + ConvParams op_params; op_params.padding_type = PaddingType::kSame; op_params.padding_values.width = data->padding.width; @@ -1076,7 +1109,7 @@ TfLiteStatus EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node, case kReference: reference_ops::HybridConvPerChannel( op_params, scaling_factors_ptr, GetTensorShape(input), - quantized_input_ptr_batch, GetTensorShape(filter), filter_ptr, + quantized_input_ptr_batch, GetTensorShape(filter), filter_data, GetTensorShape(bias), GetTensorData(bias), GetTensorShape(output), GetTensorData(output), GetTensorShape(im2col), im2col_ptr, affine_quantization->scale->data, @@ -1095,7 +1128,7 @@ TfLiteStatus EvalHybridPerChannel(TfLiteContext* context, TfLiteNode* node, GetTemporarySafe(context, node, data->accum_scratch_index, &scratch)); optimized_ops::HybridConvPerChannel( op_params, scaling_factors_ptr, GetTensorShape(input), - quantized_input_ptr_batch, GetTensorShape(filter), filter_ptr, + quantized_input_ptr_batch, GetTensorShape(filter), filter_data, GetTensorShape(bias), GetTensorData(bias), GetTensorShape(output), GetTensorData(output), GetTensorShape(im2col), im2col_ptr, affine_quantization->scale->data, @@ -1151,6 +1184,19 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, } } + const int8_t* filter_data = nullptr; + std::unique_ptr unpacked_filter_data = nullptr; + if (filter->type == kTfLiteInt4) { + const size_t bytes_unpacked = filter->bytes * 2; + unpacked_filter_data = std::make_unique(bytes_unpacked); + tflite::tensor_utils::UnpackDenseInt4IntoInt8( + GetTensorData(filter), GetTensorShape(filter).FlatSize(), + unpacked_filter_data.get()); + filter_data = unpacked_filter_data.get(); + } else { + filter_data = GetTensorData(filter); + } + switch (kernel_type) { case kReference: case kGenericOptimized: @@ -1170,9 +1216,9 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, if (data->groups == 1) { optimized_ops::HybridConv( op_params, scaling_factors_ptr, GetTensorShape(input), - quantized_input_ptr_batch, GetTensorShape(filter), - GetTensorData(filter), GetTensorShape(bias), - GetTensorData(bias), GetTensorShape(accum_scratch), + quantized_input_ptr_batch, GetTensorShape(filter), filter_data, + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(accum_scratch), GetTensorData(accum_scratch), GetTensorShape(output), GetTensorData(output), GetTensorShape(im2col), GetTensorData(im2col), diff --git a/tensorflow/lite/kernels/conv_test.cc b/tensorflow/lite/kernels/conv_test.cc index 6fe086793a14b5..0acd14582edd4a 100644 --- a/tensorflow/lite/kernels/conv_test.cc +++ b/tensorflow/lite/kernels/conv_test.cc @@ -1422,6 +1422,39 @@ TEST_P(ConvolutionOpTest, SimpleTestHybridInt8) { 0.16))); } +TEST_P(ConvolutionOpTest, SimpleTestHybridInt4) { + HybridConvolutionOpModel m( + GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}}, + {TensorType_INT4, {3, 2, 2, 1}, 0, 0, 4.0 / 7.0, 0}, + {TensorType_FLOAT32, {}}); + + m.SetInput({ + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }); + m.SetSignedFilter({ + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + ASSERT_EQ(m.Invoke(), kTfLiteOk); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + { + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 17, 4, 3, // second batch, left + 36, 4, 3, // second batch, right + }, + 0.45))); +} + TEST_P(ConvolutionOpTest, SimpleTestHybridInt8WithDilation) { const int stride_width = 1; const int stride_height = 1; @@ -1750,6 +1783,53 @@ TEST_P(ConvolutionOpTest, SimplePerTensorTest) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({61, 111, -115, -89})); } +TEST_P(ConvolutionOpTest, SimplePerTensorTest4bit) { + PerChannelQuantizedConvolutionOpModel m( + GetRegistration(), {TensorType_INT8, {1, 2, 3, 2}, -63.5, 64, 0.5, -1}, + {TensorType_INT4, + // [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel] + {2, 2, 2, 2}, + 0, + 0, + 0, + 0, + /*per_channel_quantization=*/true, + /*per_channel_quantization_scales=*/{1}, + /*per_channel_quantization_offsets=*/{0}, + /*channel_index=*/0}, + {TensorType_INT8, {}, -63.5, 64, 0.5, -1}, + /*stride_width=*/1, /*stride_height=*/1); + m.SetInput({ + // [1 * 2 * 3 * 2] as [batch, y, x, input_channel] + 3, 2, // batch = 0, y = 0, x = 0 + 1, -1, // batch = 0, y = 0, x = 1 + -2, -3, // batch = 0, y = 0, x = 2 + 4, 3, // batch = 0, y = 1, x = 0 + 2, -2, // batch = 0, y = 1, x = 1 + -3, -4, // batch = 0, y = 1, x = 2 + }); + m.SetFilter( + // [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel] + { + 1, 2, // out channel = 0, y = 0, x = 0 + 3, 4, // out channel = 0, y = 0, x = 1 + 3, 4, // out channel = 0, y = 1, x = 0 + 5, 6, // out channel = 0, y = 1, x = 1 + 7, 7, // out channel = 1, y = 0, x = 0 + 5, 6, // out channel = 1, y = 0, x = 1 + 3, 4, // out channel = 1, y = 1, x = 0 + 1, 2, // out channel = 1, y = 1, x = 1 + }); + m.SetBias({3, -2}); + + // Invoke and verify output. + // output has dimension [1 * 1 * 2 * 2] as [batch, y, x, output_channel] + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({31, 54, -57, -43}))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({61, 107, -115, -87})); +} + TEST_P(ConvolutionOpTest, SimplePerChannelTest) { PerChannelQuantizedConvolutionOpModel m( GetRegistration(), {TensorType_INT8, {1, 2, 3, 2}, -63.5, 64, 0.5, -1}, @@ -1854,6 +1934,63 @@ TEST_P(ConvolutionOpTest, SimplePerChannel16x8Bias32) { ElementsAreArray({15872, 32767, -29184, -23552})); } +TEST_P(ConvolutionOpTest, SimplePerChannel16x4Bias32) { + const float scale = 128.0 / 65536; + PerChannelQuantizedConvolutionOpModel m( + GetRegistration(), {TensorType_INT16, {1, 2, 3, 2}, 0, 0, scale, 0}, + {TensorType_INT4, + // [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel] + {2, 2, 2, 2}, + 0, + 0, + 0, + 0, + /*per_channel_quantization=*/true, + /*per_channel_quantization_scales=*/{1, 1}, + /*per_channel_quantization_offsets=*/{0, 0}, + /*channel_index=*/0}, + {TensorType_INT16, {}, 0, 0, scale, 0}, + /*stride_width=*/1, /*stride_height=*/1, + /*padding=*/Padding_VALID, + /*activation=*/ActivationFunctionType_NONE, + /*dilation_width_factor=*/1, + /*dilation_height_factor=*/1, + /*num_threads=*/-1, + /*filter_data=*/{}, + /*bias_type=*/TensorType_INT32); + + m.SetInput({ + // [1 * 2 * 3 * 2] as [batch, y, x, input_channel] + 3, 2, // batch = 0, y = 0, x = 0 + 1, -1, // batch = 0, y = 0, x = 1 + -2, -3, // batch = 0, y = 0, x = 2 + 4, 3, // batch = 0, y = 1, x = 0 + 2, -2, // batch = 0, y = 1, x = 1 + -3, -4, // batch = 0, y = 1, x = 2 + }); + m.SetFilter( + // [2 * 2 * 2 * 2] as [output_channel, y, x, input_channel] + { + 1, 2, // out channel = 0, y = 0, x = 0 + 3, 4, // out channel = 0, y = 0, x = 1 + 3, 4, // out channel = 0, y = 1, x = 0 + 5, 6, // out channel = 0, y = 1, x = 1 + 7, 7, // out channel = 1, y = 0, x = 0 + 5, 6, // out channel = 1, y = 0, x = 1 + 3, 4, // out channel = 1, y = 1, x = 0 + 1, 2, // out channel = 1, y = 1, x = 1 + }); + m.SetBias({3, -2}); + + // Invoke and verify output. + // output has dimension [1 * 1 * 2 * 2] as [batch, y, x, output_channel] + ASSERT_EQ(m.Invoke(), kTfLiteOk); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({31, 54, -57, -43}))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({15872, 27648, -29184, -22016})); +} + TEST_P(ConvolutionOpTest, SimplePerChannel16x8Bias64) { const float scale = 128.0 / 65536; PerChannelQuantizedConvolutionOpModel m( @@ -2027,6 +2164,50 @@ TEST_P(ConvolutionOpTest, SimpleTestHybridPerChannel) { 0.16))); } +TEST_P(ConvolutionOpTest, SimpleTestHybridPerChannelInt4) { + float scale = 4.0 / 7.0; + float scale2 = 1.0 / 7.0; + HybridPerChannelConvolutionOpModel m( + GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}}, + {TensorType_INT4, + {3, 2, 2, 2}, + 0, + 0, + 0, + 0, + /*per_channel_quantization=*/true, + /*per_channel_quantization_scales=*/{scale, scale2, scale2}, + /*per_channel_quantization_offsets=*/{0, 0, 0}, + /*channel_index=*/0}, + {TensorType_FLOAT32, {}}); + + m.SetInput({ + // First batch + 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1 + 1, 1, 1, 1, 1, 1, 1, 1, // row = 2 + // Second batch + 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1 + 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2 + }); + m.SetSignedFilter({ + 1, 1, 2, 2, 3, 3, 4, 4, // first 2x2 filter + -1, -1, 1, 1, -1, -1, 1, 1, // second 2x2 filter + -1, -1, -1, -1, 1, 1, 1, 1 // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + ASSERT_EQ(m.Invoke(), kTfLiteOk); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + { + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 16, 4, 3, // second batch, left + 36, 4, 3, // second batch, right + }, + 0.45))); +} + TEST_P(ConvolutionOpTest, SimpleTestHybridPerChannelGrouped) { float scale = 4.0 / 127.0; float scale2 = 1.0 / 127.0; diff --git a/tensorflow/lite/kernels/dynamic_update_slice.cc b/tensorflow/lite/kernels/dynamic_update_slice.cc index 61a3f3d680df24..1bbf84e7804a81 100644 --- a/tensorflow/lite/kernels/dynamic_update_slice.cc +++ b/tensorflow/lite/kernels/dynamic_update_slice.cc @@ -105,6 +105,25 @@ std::vector ClampStartIndices(int input_dims, const int64_t* indices_data, return clamped_start_indices; } +template +void update_slice(int current_dim, int max_dim, const int32_t* output_stride, + const int32_t* update_stride, const int32_t* update_shape, + const T* update, const int32_t* indices_data, T* output) { + if (current_dim == max_dim) return; + if (current_dim == max_dim - 1) { + output += indices_data[current_dim] * output_stride[current_dim]; + memcpy(output, update, update_shape[max_dim - 1] * sizeof(T)); + } else { + output += indices_data[current_dim] * output_stride[current_dim]; + for (int i = 0; i < update_shape[current_dim]; ++i) { + update_slice(current_dim + 1, max_dim, output_stride, update_stride, + update_shape, update, indices_data, output); + output += output_stride[current_dim]; + update += update_stride[current_dim]; + } + } +} + template void DynamicUpdateSlice(const TfLiteTensor* input, const TfLiteTensor* update, const int64_t* indices_data, TfLiteTensor* output) { @@ -114,6 +133,12 @@ void DynamicUpdateSlice(const TfLiteTensor* input, const TfLiteTensor* update, T* output_data = GetTensorData(output); const int input_dims = input_shape.DimensionsCount(); + // If the update is the entirety of the output, then simply copy it and + // return. + if (input_shape.FlatSize() == update_shape.FlatSize()) { + memcpy(output_data, update_data, input_shape.FlatSize() * sizeof(T)); + return; + } // Computes the effective slice indices. // The clamped indices are gauranteed to >= 0 since update is less than or // equal to the operand size for each dimension. @@ -130,18 +155,19 @@ void DynamicUpdateSlice(const TfLiteTensor* input, const TfLiteTensor* update, return; } - std::vector current_dim(input_dims, 0); - // Overwrites update to output. - do { - int flat_update_index = - TensorIndexToFlat(current_dim.data(), input_dims, update_shape); - int flat_input_index = - TensorIndexToFlat(current_dim.data(), input_dims, input_shape, - clamped_start_indices.data()); - output_data[flat_input_index] = update_data[flat_update_index]; - } while (NextIndex(input_dims, - reinterpret_cast(update_shape.DimsData()), - current_dim.data())); + std::vector output_stride(input_dims); + std::vector update_stride(input_dims); + output_stride[input_dims - 1] = 1; + update_stride[input_dims - 1] = 1; + const int32_t* input_shape_data = input_shape.DimsData(); + const int32_t* update_shape_data = update_shape.DimsData(); + for (int i = input_dims - 2; i >= 0; --i) { + output_stride[i] = output_stride[i + 1] * input_shape_data[i + 1]; + update_stride[i] = update_stride[i + 1] * update_shape_data[i + 1]; + } + update_slice(0, input_dims, output_stride.data(), update_stride.data(), + update_shape.DimsData(), update_data, + clamped_start_indices.data(), output_data); } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { diff --git a/tensorflow/lite/kernels/elementwise.cc b/tensorflow/lite/kernels/elementwise.cc index 5552f125818db4..c71fe04d3d9818 100644 --- a/tensorflow/lite/kernels/elementwise.cc +++ b/tensorflow/lite/kernels/elementwise.cc @@ -384,6 +384,12 @@ TfLiteStatus RsqrtEvalQuantizedInt16(TfLiteContext* context, TfLiteNode* node, TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); TfLiteTensor* output; TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); + const int64_t num_elements = NumElements(input); + const int16_t* in_data = GetTensorData(input); + for (int64_t i = 0; i < num_elements; ++i) { + TF_LITE_ENSURE_MSG(context, in_data[i] >= op_data->input_offset, + "Rsqrt is only defined for positive values"); + } reference_integer_ops::LookupTable( GetTensorData(input), MatchingFlatSize(GetTensorShape(input), GetTensorShape(output)), diff --git a/tensorflow/lite/kernels/elementwise_test.cc b/tensorflow/lite/kernels/elementwise_test.cc index 57b39de3e435a2..601397d69ac566 100644 --- a/tensorflow/lite/kernels/elementwise_test.cc +++ b/tensorflow/lite/kernels/elementwise_test.cc @@ -14,11 +14,17 @@ limitations under the License. ==============================================================================*/ #include +#include +#include #include +#include +#include #include #include #include +#include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h" #include "tensorflow/lite/kernels/test_util.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -252,12 +258,13 @@ TEST(ElementWise, AbsInt32) { } TEST(ElementWise, AbsInt8) { - std::vector data = {15., 46., 78., -142., -1., -17., -49., 113.}; - std::vector abs_data(data.size()); - for (int i = 0; i < abs_data.size(); i++) { - abs_data[i] = std::abs(data[i]); + const std::vector input_data = {15., 46., 78., -142., + -1., -17., -49., 113.}; + std::vector expected_output(input_data.size()); + for (int i = 0; i < expected_output.size(); i++) { + expected_output[i] = std::abs(input_data[i]); } - const auto minmax = std::minmax_element(data.begin(), data.end()); + const auto minmax = std::minmax_element(input_data.begin(), input_data.end()); const float abs_max = std::max(std::abs(*minmax.first), *minmax.second); const float kInputScale = (*minmax.second - *minmax.first) / 255.0; const float kOutputScale = abs_max / 255.0; @@ -275,19 +282,20 @@ TEST(ElementWise, AbsInt8) { {kInputScale}, {input_zero_point}}, {TensorType_INT8, {1, 8}, 0, abs_max, kOutputScale, output_zero_point}); - m.AsymmetricQuantizeAndPopulate(m.input(), data); + m.AsymmetricQuantizeAndPopulate(m.input(), input_data); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.ExtractDequantVector(m.output()), - ElementsAreArray(ArrayFloatNear(abs_data, kInputScale))); + ElementsAreArray(ArrayFloatNear(expected_output, kInputScale))); } TEST(ElementWise, AbsSameScaleInt8) { - std::vector data = {15., 46., 78., -142., -1., -17., -49., 113.}; - std::vector abs_data(data.size()); - for (int i = 0; i < abs_data.size(); i++) { - abs_data[i] = std::abs(data[i]); + const std::vector input_data = {15., 46., 78., -142., + -1., -17., -49., 113.}; + std::vector expected_output(input_data.size()); + for (int i = 0; i < expected_output.size(); i++) { + expected_output[i] = std::abs(input_data[i]); } - const auto minmax = std::minmax_element(data.begin(), data.end()); + const auto minmax = std::minmax_element(input_data.begin(), input_data.end()); const float abs_max = std::max(std::abs(*minmax.first), *minmax.second); const float kInputScale = (*minmax.second - *minmax.first) / 255.0; const int input_zero_point = 127 - *minmax.second; @@ -303,26 +311,28 @@ TEST(ElementWise, AbsSameScaleInt8) { {kInputScale}, {input_zero_point}}, {TensorType_INT8, {1, 8}, 0, abs_max, kInputScale, input_zero_point}); - m.AsymmetricQuantizeAndPopulate(m.input(), data); + m.AsymmetricQuantizeAndPopulate(m.input(), input_data); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.ExtractDequantVector(m.output()), - ElementsAreArray(ArrayFloatNear(abs_data, kInputScale))); + ElementsAreArray(ArrayFloatNear(expected_output, kInputScale))); } TEST(ElementWise, AbsInt16) { const float kQuantizedTolerance = GetQuantizationStep(-150, 150); - std::vector data = {15., 46., 78., -142., -1., -17., -49., 113.}; - std::vector abs_data(data.size()); - for (int i = 0; i < abs_data.size(); i++) { - abs_data[i] = std::abs(data[i]); + const std::vector input_data = {15., 46., 78., -142., + -1., -17., -49., 113.}; + std::vector expected_output(input_data.size()); + for (int i = 0; i < expected_output.size(); i++) { + expected_output[i] = std::abs(input_data[i]); } ElementWiseOpQuantizedModel m(BuiltinOperator_ABS, {TensorType_INT16, {1, 8}, -142, 142}, {TensorType_INT16, {1, 8}, -150, 150}); - m.QuantizeAndPopulate(m.input(), data); + m.QuantizeAndPopulate(m.input(), input_data); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.ExtractDequantVector(m.output()), - ElementsAreArray(ArrayFloatNear(abs_data, kQuantizedTolerance))); + EXPECT_THAT( + m.ExtractDequantVector(m.output()), + ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance))); } TEST(ElementWise, Sqrt) { @@ -344,10 +354,11 @@ TEST(ElementWise, Rsqrt) { } TEST(ElementWise, RsqrtInt8) { - std::vector data = {15., 46., 78., 142., 1., 17., 49., 113.}; - std::vector rsqrt_data(data.size()); - for (int i = 0; i < rsqrt_data.size(); i++) { - rsqrt_data[i] = 1.f / std::sqrt(data[i]); + const std::vector input_data = {15., 46., 78., 142., + 1., 17., 49., 113.}; + std::vector expected_output(input_data.size()); + for (int i = 0; i < expected_output.size(); i++) { + expected_output[i] = 1.f / std::sqrt(input_data[i]); } float kInputScale = 142.0 / 255.0; float kOutputScale = 1.0 / 255.0; @@ -371,17 +382,18 @@ TEST(ElementWise, RsqrtInt8) { true, {kOutputScale}, {zero_point}}); - m.QuantizeAndPopulate(m.input(), data); + m.QuantizeAndPopulate(m.input(), input_data); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.ExtractDequantVector(m.output()), - ElementsAreArray(ArrayFloatNear(rsqrt_data, kInputScale))); + ElementsAreArray(ArrayFloatNear(expected_output, kInputScale))); } TEST(ElementWise, RsqrtCloseTo0Int8) { - std::vector data = {15., 46., 78., 142., 0.1, 1., 49., 113.}; - std::vector rsqrt_data(data.size()); - for (int i = 0; i < rsqrt_data.size(); i++) { - rsqrt_data[i] = 1.f / std::sqrt(data[i]); + const std::vector input_data = {15., 46., 78., 142., + 0.1, 1., 49., 113.}; + std::vector expected_output(input_data.size()); + for (int i = 0; i < expected_output.size(); i++) { + expected_output[i] = 1.f / std::sqrt(input_data[i]); } float kInputScale = 142.0 / 255.0; float kOutputScale = 3.16 / 255.0; @@ -405,18 +417,15 @@ TEST(ElementWise, RsqrtCloseTo0Int8) { true, {kOutputScale}, {zero_point}}); - m.QuantizeAndPopulate(m.input(), data); + m.QuantizeAndPopulate(m.input(), input_data); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.ExtractDequantVector(m.output()), - ElementsAreArray(ArrayFloatNear(rsqrt_data, kInputScale))); + ElementsAreArray(ArrayFloatNear(expected_output, kInputScale))); } -TEST(ElementWise, RsqrtNanInt8) { - std::vector data = {15., 46., 78., 142., 1., 17., -49., 113.}; - std::vector rsqrt_data(data.size()); - for (int i = 0; i < rsqrt_data.size(); i++) { - rsqrt_data[i] = 1.f / std::sqrt(data[i]); - } +TEST(ElementWise, RsqrtNegativeInt8) { + const std::vector input_data = {15., 46., 78., 142., + 1., 17., -49., 113.}; float kInputScale = 142.0 / 127.0; float kOutputScale = 1.0 / 255.0; int32_t input_zero_point = 0; @@ -440,49 +449,43 @@ TEST(ElementWise, RsqrtNanInt8) { true, {kOutputScale}, {output_zero_point}}); - m.QuantizeAndPopulate(m.input(), data); + m.QuantizeAndPopulate(m.input(), input_data); EXPECT_THAT(m.Invoke(), kTfLiteError); } TEST(ElementWise, RsqrtInt16) { - const float input_min = -0.8f; - const float input_max = 0.8f; + const std::vector input_data = {1., 0.1, 4., 9.}; + std::vector expected_output(input_data.size()); + for (int i = 0; i < expected_output.size(); i++) { + expected_output[i] = 1.f / std::sqrt(input_data[i]); + } - const float output_min = -2.4f; - const float output_max = 2.4f; + const float input_min = -10.; + const float input_max = 10.; + + const float output_min = -4.; + const float output_max = 4.; const float kQuantizedTolerance = GetLUTTolerance(input_min, input_max, output_min, output_max); - ElementWiseOpQuantizedModel m(BuiltinOperator_RSQRT, - {TensorType_INT16, {1, 1, 4, 1}, -10, 10}, - {TensorType_INT16, {1, 1, 4, 1}, -10, 10}); - m.QuantizeAndPopulate(m.input(), {1, 0.1, 4, 9}); + ElementWiseOpQuantizedModel m( + BuiltinOperator_RSQRT, + {TensorType_INT16, {1, 1, 4, 1}, input_min, input_max}, + {TensorType_INT16, {1, 1, 4, 1}, output_min, output_max}); + m.QuantizeAndPopulate(m.input(), input_data); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT( m.ExtractDequantVector(m.output()), - ElementsAreArray(ArrayFloatNear({1.00009, 3.19407, 0.500198, 0.333262}, - kQuantizedTolerance))); + ElementsAreArray(ArrayFloatNear(expected_output, kQuantizedTolerance))); } -TEST(ElementWise, RsqrtNanInt16) { - const float input_min = -0.8f; - const float input_max = 0.8f; - - const float output_min = -2.4f; - const float output_max = 2.4f; - - const float kQuantizedTolerance = - GetLUTTolerance(input_min, input_max, output_min, output_max); - +TEST(ElementWise, RsqrtNegativeInt16) { ElementWiseOpQuantizedModel m(BuiltinOperator_RSQRT, {TensorType_INT16, {1, 1, 4, 1}, -10, 10}, {TensorType_INT16, {1, 1, 4, 1}, -10, 10}); m.QuantizeAndPopulate(m.input(), {-1, 0, -4, -9}); - ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.ExtractDequantVector(m.output()), - ElementsAreArray( - ArrayFloatNear({10, 9.82452, 10, 10}, kQuantizedTolerance))); + ASSERT_EQ(m.Invoke(), kTfLiteError); } TEST(ElementWise, Square) { diff --git a/tensorflow/lite/kernels/fully_connected.cc b/tensorflow/lite/kernels/fully_connected.cc index 0693594cdd8613..8bfb045bc1b477 100644 --- a/tensorflow/lite/kernels/fully_connected.cc +++ b/tensorflow/lite/kernels/fully_connected.cc @@ -444,8 +444,7 @@ TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node, TF_LITE_ENSURE(context, input->type == kTfLiteInt8 || input->type == kTfLiteInt16); TF_LITE_ENSURE(context, (filter->type == kTfLiteInt8 || - (filter->type == kTfLiteInt4 && - input->type == kTfLiteInt16))); + filter->type == kTfLiteInt4)); TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size, per_channel_quantization_size); TF_LITE_ENSURE_EQ( @@ -747,6 +746,7 @@ TfLiteStatus EvalHybridDense( int8_t* quant_data = GetTensorData(input_quantized); const int8_t* filter_data = nullptr; std::unique_ptr unpacked_filter_data = nullptr; + // Unoptimized 4-bit implementation. Ideally use EvalHybridDenseInt4 instead. if (filter->type == kTfLiteInt4) { const size_t bytes_unpacked = filter->bytes * 2; unpacked_filter_data = std::make_unique(bytes_unpacked); @@ -1115,8 +1115,8 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, namespace { template void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input, - const TfLiteTensor* filter, const TfLiteTensor* bias, - TfLiteTensor* output, + const TfLiteTensor* filter, const int8_t* filter_data, + const TfLiteTensor* bias, TfLiteTensor* output, CpuBackendContext* cpu_backend_context) { FullyConnectedParams op_params; op_params.input_offset = -input->params.zero_point; @@ -1129,20 +1129,6 @@ void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input, op_params.lhs_cacheable = IsConstantTensor(filter); op_params.rhs_cacheable = IsConstantTensor(input); - const int8_t* filter_data; - std::unique_ptr unpacked_filter_data = nullptr; - - if (filter->type == kTfLiteInt4) { - const size_t bytes_unpacked = filter->bytes * 2; - unpacked_filter_data = std::make_unique(bytes_unpacked); - tflite::tensor_utils::UnpackDenseInt4IntoInt8( - GetTensorData(filter), GetTensorShape(filter).FlatSize(), - unpacked_filter_data.get()); - filter_data = unpacked_filter_data.get(); - } else { - filter_data = GetTensorData(filter); - } - if (kernel_type == kReference) { reference_integer_ops::FullyConnected( op_params, GetTensorShape(input), GetTensorData(input), @@ -1160,8 +1146,8 @@ void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input, template void FullyConnectedInt16(const OpData* data, const TfLiteTensor* input, - const TfLiteTensor* filter, const TfLiteTensor* bias, - TfLiteTensor* output) { + const TfLiteTensor* filter, const int8_t* filter_data, + const TfLiteTensor* bias, TfLiteTensor* output) { FullyConnectedParams op_params; op_params.input_offset = -input->params.zero_point; op_params.weights_offset = -filter->params.zero_point; @@ -1171,20 +1157,6 @@ void FullyConnectedInt16(const OpData* data, const TfLiteTensor* input, op_params.quantized_activation_min = data->output_activation_min; op_params.quantized_activation_max = data->output_activation_max; - const int8_t* filter_data; - std::unique_ptr unpacked_filter_data = nullptr; - - if (filter->type == kTfLiteInt4) { - const size_t bytes_unpacked = filter->bytes * 2; - unpacked_filter_data = std::make_unique(bytes_unpacked); - tflite::tensor_utils::UnpackDenseInt4IntoInt8( - GetTensorData(filter), GetTensorShape(filter).FlatSize(), - unpacked_filter_data.get()); - filter_data = unpacked_filter_data.get(); - } else { - filter_data = GetTensorData(filter); - } - if (data->quantized_bias_type == kTfLiteInt32) { reference_integer_ops::FullyConnected( op_params, GetTensorShape(input), GetTensorData(input), @@ -1203,6 +1175,7 @@ void FullyConnectedInt16(const OpData* data, const TfLiteTensor* input, template void FullyConnectedPerChannelInt8(const OpData* data, const TfLiteTensor* input, const TfLiteTensor* filter, + const int8_t* filter_data, const TfLiteTensor* bias, TfLiteTensor* output, CpuBackendContext* cpu_backend_context) { @@ -1216,31 +1189,29 @@ void FullyConnectedPerChannelInt8(const OpData* data, const TfLiteTensor* input, op_params.quantized_activation_max = data->output_activation_max; op_params.lhs_cacheable = IsConstantTensor(filter); op_params.rhs_cacheable = IsConstantTensor(input); + if (kernel_type == kReference) { reference_integer_ops::FullyConnectedPerChannel( op_params, data->per_channel_output_multiplier.data(), data->per_channel_output_shift.data(), GetTensorShape(input), - GetTensorData(input), GetTensorShape(filter), - GetTensorData(filter), GetTensorShape(bias), - GetTensorData(bias), GetTensorShape(output), - GetTensorData(output)); + GetTensorData(input), GetTensorShape(filter), filter_data, + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output)); } else { optimized_integer_ops::FullyConnectedPerChannel( op_params, data->per_channel_output_multiplier.data(), data->per_channel_output_shift.data(), GetTensorShape(input), - GetTensorData(input), GetTensorShape(filter), - GetTensorData(filter), GetTensorShape(bias), - GetTensorData(bias), GetTensorShape(output), - GetTensorData(output), cpu_backend_context); + GetTensorData(input), GetTensorShape(filter), filter_data, + GetTensorShape(bias), GetTensorData(bias), + GetTensorShape(output), GetTensorData(output), + cpu_backend_context); } } template -void FullyConnectedPerChannelInt16(const OpData* data, - const TfLiteTensor* input, - const TfLiteTensor* filter, - const TfLiteTensor* bias, - TfLiteTensor* output) { +void FullyConnectedPerChannelInt16( + const OpData* data, const TfLiteTensor* input, const TfLiteTensor* filter, + const int8_t* filter_data, const TfLiteTensor* bias, TfLiteTensor* output) { // FullyConnectedPerChannel ops spec is that weights are symmetric. // op_params.weights_offset is not set (filter.params.zero_point is not used), // since it will be always assumed to be 0. @@ -1250,19 +1221,6 @@ void FullyConnectedPerChannelInt16(const OpData* data, op_params.quantized_activation_min = data->output_activation_min; op_params.quantized_activation_max = data->output_activation_max; - const int8_t* filter_data; - std::unique_ptr unpacked_filter_data = nullptr; - if (filter->type == kTfLiteInt4) { - const size_t bytes_unpacked = filter->bytes * 2; - unpacked_filter_data = std::make_unique(bytes_unpacked); - tflite::tensor_utils::UnpackDenseInt4IntoInt8( - GetTensorData(filter), GetTensorShape(filter).FlatSize(), - unpacked_filter_data.get()); - filter_data = unpacked_filter_data.get(); - } else { - filter_data = GetTensorData(filter); - } - if (data->quantized_bias_type == kTfLiteInt32) { reference_integer_ops::FullyConnectedPerChannel( op_params, data->per_channel_output_multiplier.data(), @@ -1416,11 +1374,23 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, return kTfLiteError; } } else { + const int8_t* filter_data; + std::unique_ptr unpacked_filter_data = nullptr; + if (filter->type == kTfLiteInt4) { + const size_t bytes_unpacked = filter->bytes * 2; + unpacked_filter_data = std::make_unique(bytes_unpacked); + tflite::tensor_utils::UnpackDenseInt4IntoInt8( + GetTensorData(filter), + GetTensorShape(filter).FlatSize(), unpacked_filter_data.get()); + filter_data = unpacked_filter_data.get(); + } else { + filter_data = GetTensorData(filter); + } is_per_channel ? FullyConnectedPerChannelInt8( - data, input, filter, bias, output, + data, input, filter, filter_data, bias, output, CpuBackendContext::GetFromContext(context)) : FullyConnectedInt8( - data, input, filter, bias, output, + data, input, filter, filter_data, bias, output, CpuBackendContext::GetFromContext(context)); } break; @@ -1431,27 +1401,41 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, bool has_non_zero_point = input->params.zero_point || filter->params.zero_point || output->params.zero_point; + + const int8_t* filter_data; + std::unique_ptr unpacked_filter_data = nullptr; + if (filter->type == kTfLiteInt4) { + const size_t bytes_unpacked = filter->bytes * 2; + unpacked_filter_data = std::make_unique(bytes_unpacked); + tflite::tensor_utils::UnpackDenseInt4IntoInt8( + GetTensorData(filter), + GetTensorShape(filter).FlatSize(), unpacked_filter_data.get()); + filter_data = unpacked_filter_data.get(); + } else { + filter_data = GetTensorData(filter); + } + if (kernel_type == kReference || has_non_zero_point || - (bias && bias->type == kTfLiteInt64) || - (filter->type == kTfLiteInt4)) { - is_per_channel ? FullyConnectedPerChannelInt16( - data, input, filter, bias, output) - : FullyConnectedInt16( - data, input, filter, bias, output); + (bias && bias->type == kTfLiteInt64)) { + is_per_channel + ? FullyConnectedPerChannelInt16( + data, input, filter, filter_data, bias, output) + : FullyConnectedInt16(data, input, filter, + filter_data, bias, output); } else { is_per_channel ? optimized_integer_ops::FullyConnectedPerChannel( op_params, data->per_channel_output_multiplier.data(), data->per_channel_output_shift.data(), GetTensorShape(input), GetTensorData(input), - GetTensorShape(filter), GetTensorData(filter), - GetTensorShape(bias), GetTensorData(bias), - GetTensorShape(output), GetTensorData(output), + GetTensorShape(filter), filter_data, GetTensorShape(bias), + GetTensorData(bias), GetTensorShape(output), + GetTensorData(output), CpuBackendContext::GetFromContext(context)) : optimized_integer_ops::FullyConnected( op_params, GetTensorShape(input), GetTensorData(input), GetTensorShape(filter), - GetTensorData(filter), GetTensorShape(bias), + filter_data, GetTensorShape(bias), GetTensorData(bias), GetTensorShape(output), GetTensorData(output), CpuBackendContext::GetFromContext(context)); diff --git a/tensorflow/lite/kernels/fully_connected_test.cc b/tensorflow/lite/kernels/fully_connected_test.cc index 10415098f3973d..ea4a04b0482220 100644 --- a/tensorflow/lite/kernels/fully_connected_test.cc +++ b/tensorflow/lite/kernels/fully_connected_test.cc @@ -30,11 +30,7 @@ limitations under the License. #include #include -#include "absl/memory/memory.h" -#include "flatbuffers/flatbuffers.h" // from @flatbuffers -#include "tensorflow/lite/core/api/op_resolver.h" #include "tensorflow/lite/core/interpreter.h" -#include "tensorflow/lite/kernels/internal/tensor_utils.h" #include "tensorflow/lite/kernels/test_util.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/string_type.h" @@ -144,7 +140,7 @@ class BaseFullyConnectedOpModel : public SingleOpModel { FullyConnectedOptionsWeightsFormat_DEFAULT, int input_size = -1, bool weights_per_channel_quantized = false, std::vector per_channel_quantization_scales = {}, - TfLiteType filter_type = kTfLiteNoType) + const TensorType& filter_type = TensorType_FLOAT32) : batches_(batches), units_(units), input_size_(input_size), @@ -162,66 +158,25 @@ class BaseFullyConnectedOpModel : public SingleOpModel { if (weights_per_channel_quantized) { std::vector per_channel_quantization_offsets( per_channel_quantization_scales.size(), 0); - if (input.type == TensorType_INT16) { - if (filter_type == kTfLiteInt4) { - weights_ = AddInput({TensorType_INT4, - {units_, input_size_}, - 0, - 0, - 0, - 0, - true, - per_channel_quantization_scales, - per_channel_quantization_offsets, - 0}); - } else { - weights_ = AddInput({TensorType_INT8, - {units_, input_size_}, - 0, - 0, - 0, - 0, - true, - per_channel_quantization_scales, - per_channel_quantization_offsets, - 0}); - } - } else { - weights_ = AddInput({input.type, - {units_, input_size_}, - 0, - 0, - 0, - 0, - true, - per_channel_quantization_scales, - per_channel_quantization_offsets, - 0}); - } + weights_ = AddInput({filter_type, + {units_, input_size_}, + 0, + 0, + 0, + 0, + true, + per_channel_quantization_scales, + per_channel_quantization_offsets, + 0}); } else { - if (input.type == TensorType_INT16) { - if (filter_type == kTfLiteInt4) { - weights_ = AddInput({TensorType_INT4, - {units_, input_size_}, - /*min=*/-7, - /*max=*/7}); - } else { - // Set min and max values that are used to calculate per-tensor scale - // and zero points. - weights_ = AddInput({TensorType_INT8, - {units_, input_size_}, - /*min=*/-63.5, - /*max=*/64}); - } - } else if (filter_type == kTfLiteInt4) { - weights_ = AddInput({TensorType_INT4, - {units_, input_size_}, - /*min=*/input.min, - /*max=*/input.max}); - } else { - weights_ = - AddInput({input.type, {units_, input_size_}, input.min, input.max}); + // per-tensor + float min = input.min; + float max = input.max; + if (filter_type == TensorType_INT4 || filter_type == TensorType_INT8) { + min = filter_type == TensorType_INT4 ? -7.f : -63.5f; + max = filter_type == TensorType_INT4 ? 7.f : 64.f; } + weights_ = AddInput({filter_type, {units_, input_size_}, min, max}); } if (bias_tensor_optional) { @@ -314,7 +269,7 @@ class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel { ActivationFunctionType activation_func = ActivationFunctionType_RELU, FullyConnectedOptionsWeightsFormat weights_format = FullyConnectedOptionsWeightsFormat_DEFAULT, - int input_size = -1, TfLiteType filter_type = kTfLiteNoType) + int input_size = -1, const TensorType& filter_type = TensorType_INT8) : BaseFullyConnectedOpModel( registration, units, batches, input, output, bias_type, keep_num_dims, bias_tensor_optional, activation_func, @@ -395,7 +350,7 @@ class PerChannelQuantizedFullyConnectedOpModel ActivationFunctionType activation_func = ActivationFunctionType_RELU, FullyConnectedOptionsWeightsFormat weights_format = FullyConnectedOptionsWeightsFormat_DEFAULT, - int input_size = -1, TfLiteType filter_type = kTfLiteNoType) + int input_size = -1, const TensorType& filter_type = TensorType_INT8) : BaseFullyConnectedOpModel( registration, units, batches, input, output, bias_type, keep_num_dims, bias_tensor_optional, activation_func, @@ -485,10 +440,6 @@ class HybridFullyConnectedOpModel : public SingleOpModel { PerChannelSymmetricQuantizeAndPopulate(weights_, f); } - void SetSignedWeights4Bit(std::initializer_list f) { - SignedSymmetricQuantizeAndPopulate4Bit(weights_, f); - } - void SetInput(const std::vector& f) { PopulateTensor(input_, f); } std::vector GetOutput() { return ExtractVector(output_); } std::vector GetOutputShape() { return GetTensorShape(output_); } @@ -685,7 +636,13 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedUint8) { QuantizedFullyConnectedOpModel m( GetRegistration(), /*units=*/3, /*batches*/ 2, /*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64}, - /*output=*/{TensorType_UINT8, {}, -127, 128}); + /*output=*/{TensorType_UINT8, {}, -127, 128}, + /*bias_type=*/TensorType_INT32, + /*keep_num_dims =*/false, /*bool bias_tensor_optional =*/false, + /*ActivationFunctionType activation_func =*/ActivationFunctionType_RELU, + /*FullyConnectedOptionsWeightsFormat weights_format =*/ + FullyConnectedOptionsWeightsFormat_DEFAULT, + /*input_size=*/-1, /*filter_type=*/TensorType_UINT8); // input_product_scale < output_scale was not true. m.SetWeights({ @@ -720,7 +677,8 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedUint8NoBias) { /*keep_num_dims =*/false, /*bool bias_tensor_optional =*/true, /*ActivationFunctionType activation_func =*/ActivationFunctionType_RELU, /*FullyConnectedOptionsWeightsFormat weights_format =*/ - FullyConnectedOptionsWeightsFormat_DEFAULT); + FullyConnectedOptionsWeightsFormat_DEFAULT, + /*input_size=*/-1, /*filter_type=*/TensorType_UINT8); // input_product_scale < output_scale was not true. m.SetWeights({ @@ -746,20 +704,21 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedUint8NoBias) { } // The expected values for this test were obtained by running the test with the -// same parameters but by setting filter type to INT8. +// same parameters but by setting filter_type == TensorType_INT8 and +// m.SetWeights. TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt4) { QuantizedFullyConnectedOpModel m( GetRegistration(), /*units=*/3, /*batches*/ 2, /*input=*/{TensorType_INT8, {2, 10}, -63.5, 64}, /*output=*/{TensorType_INT8, {}, -127, 128}, TensorType_INT32, false, false, ActivationFunctionType_RELU, - FullyConnectedOptionsWeightsFormat_DEFAULT, -1, kTfLiteInt4); + FullyConnectedOptionsWeightsFormat_DEFAULT, -1, TensorType_INT4); - // input_product_scale < output_scale was not true. + // Scale is set to 1.f by QuantizationParams() so don't exceed [-7,7] m.SetWeights4bit({ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 + 1, 2, 3, 4, 5, 6, 7, 6, 5, 4, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 6, 5, 4, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 6, 5, 4, // u = 2 }); m.SetBias({1, 2, 3}); @@ -769,9 +728,10 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt4) { }); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetDequantizedOutput(), - testing::Pointwise(testing::FloatEq(), {64, 64, 68, 82, 82, 87})); - EXPECT_THAT(m.GetOutput(), ElementsAre(63, 63, 67, 81, 81, 86)); + EXPECT_THAT( + m.GetDequantizedOutput(), + testing::Pointwise(testing::FloatEq(), {104, 105, 106, 98, 99, 100})); + EXPECT_THAT(m.GetOutput(), ElementsAre(103, 104, 105, 97, 98, 99)); } TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt8) { @@ -827,6 +787,34 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestPerChannelQuantizedInt8) { EXPECT_THAT(m.GetOutput(), ElementsAre(23, 24, 25, 57, 58, 59)); } +TEST_P(QuantizedFullyConnectedOpTest, SimpleTestPerChannelQuantizedInt4) { + PerChannelQuantizedFullyConnectedOpModel m( + GetRegistration(), /*units=*/3, /*batches*/ 2, + /*input=*/{TensorType_INT8, {2, 10}, -63.5, 64}, + /*per_channel_quantization_scales=*/{1.0, 1.0, 1.0}, + /*output=*/{TensorType_INT8, {}, -127, 128}, + /*bias_type=*/TensorType_INT32, false, false, ActivationFunctionType_RELU, + FullyConnectedOptionsWeightsFormat_DEFAULT, -1, TensorType_INT4); + + m.SetWeights4bit({ + 1, 2, 3, 4, 5, 6, 7, 6, 5, 4, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 6, 5, 4, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 6, 5, 4, // u = 2 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + ASSERT_EQ(m.Invoke(), kTfLiteOk); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({104, 105, 106, 98, 99, 100}))); + EXPECT_THAT(m.GetOutput(), ElementsAre(103, 104, 105, 97, 98, 99)); +} + TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt16NoBias) { const float scale = 128.0 / 65536; QuantizedFullyConnectedOpModel m( @@ -898,7 +886,7 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantizedInt16Bias32Weight4) { /*bias_tensor_optional=*/false, /*activation_func*/ ActivationFunctionType_RELU, /*weights_format=*/FullyConnectedOptionsWeightsFormat_DEFAULT, - /*input_size=*/-1, /*filter_type=*/kTfLiteInt4); + /*input_size=*/-1, /*filter_type=*/TensorType_INT4); m.SetWeights4bit({ 1, 2, 3, 4, 5, 6, -7, 1, 2, 3, // u = 0 @@ -960,7 +948,7 @@ TEST_P(QuantizedFullyConnectedOpTest, /*per_channel_quantization_scales=*/{1.0, 1.0, 1.0}, /*output=*/{TensorType_INT16, {}, 0, 0, scale, 0}, /*bias_type=*/TensorType_INT32, false, false, ActivationFunctionType_RELU, - FullyConnectedOptionsWeightsFormat_DEFAULT, -1, kTfLiteInt4); + FullyConnectedOptionsWeightsFormat_DEFAULT, -1, TensorType_INT4); m.SetWeights4bit({ 1, 2, 3, 4, 5, 6, -7, 1, 2, 3, // u = 0 @@ -1163,7 +1151,13 @@ TEST_P(QuantizedFullyConnectedOpTest, QuantizedFullyConnectedOpModel m( GetRegistration(), /*units=*/3, /*batches*/ 2, /*input=*/{TensorType_UINT8, {2, 10}, -127, 128}, - /*output=*/{TensorType_UINT8, {}, -63.5, 64}); + /*output=*/{TensorType_UINT8, {}, -63.5, 64}, + /*bias_type=*/TensorType_INT32, + /*keep_num_dims =*/false, /*bool bias_tensor_optional =*/false, + /*ActivationFunctionType activation_func =*/ActivationFunctionType_RELU, + /*FullyConnectedOptionsWeightsFormat weights_format =*/ + FullyConnectedOptionsWeightsFormat_DEFAULT, + /*input_size=*/-1, /*filter_type=*/TensorType_UINT8); m.SetWeights({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 @@ -1240,7 +1234,8 @@ void SimpleTestQuantizedInt16OutputCase( /*bias_type=*/TensorType_INT32, /*keep_num_dims=*/false, /*bias_tensor_optional=*/false, - /*activation_func=*/ActivationFunctionType_NONE, weights_format); + /*activation_func=*/ActivationFunctionType_NONE, weights_format, + /*input_size=*/-1, /*filter_type=*/TensorType_UINT8); std::mt19937 random_engine; // Some compilers don't support uint8_t for uniform_distribution. @@ -1405,6 +1400,34 @@ TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedInt8) { /*max_abs_err=*/1.3f))); } +TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedInt4) { + HybridFullyConnectedOpModel m( + /*units=*/3, /*batches=*/2, + /*input=*/{TensorType_FLOAT32, {2, 10}}, + /*weights=*/{TensorType_INT4, {3, 10}, 0, 0, 1.0, 0}); + + m.SetSignedWeights({ + 1, 2, 3, 4, 5, 6, 7, 6, 5, 4, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 6, 5, 4, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 6, 5, 4, // u = 2 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + ASSERT_EQ(m.Invoke(), kTfLiteOk); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + { + 104, 105, 106, // + 98, 99, 100, // + }, + /*max_abs_err=*/0.5f))); +} + TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedInt8MultiThreaded) { for (int num_threads = 1; num_threads <= 4; ++num_threads) { HybridFullyConnectedOpModel m( @@ -1503,18 +1526,28 @@ TEST(HybridAsymmetricInputFullyConnectedOpTest, SimpleTestQuantizedInt8) { /*max_abs_err=*/1.3f))); } -// The expected values for this test were obtained by running the test with the -// same weights but populated to a Int8 filter. -TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedInt4) { +TEST(HybridAsymmetricInputPerChannelWeightsFullyConnectedOpTest, + SimpleTestQuantizedPerChannelInt8) { HybridFullyConnectedOpModel m( /*units=*/3, /*batches=*/2, /*input=*/{TensorType_FLOAT32, {2, 10}}, - /*weights=*/{TensorType_INT4, {3, 10}, 0, 0, 10.0 / 7.0, 0}); // Hybrid + /*weights=*/ + {TensorType_INT8, + {3, 10}, + 0, + 0, + 0.0f, + 0, + true, + {10.0 / 127.0, 20.0 / 127.0, 30.0 / 127.0}, + {0, 0, 0}}, + {TensorType_FLOAT32}, + /*asymmetric_quantize_input*/ true); - m.SetSignedWeights4Bit({ - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2 + m.SetSignedPerChannelWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, // u = 1 + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, // u = 2 }); m.SetBias({1, 2, 3}); @@ -1527,26 +1560,26 @@ TEST(HybridFullyConnectedOpTest, SimpleTestQuantizedInt4) { EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( { - 36, 37, 38, // - 52, 53, 54, // + 24, 195, 366, // + 58, 251, 441, // }, /*max_abs_err=*/1.3f))); } TEST(HybridAsymmetricInputPerChannelWeightsFullyConnectedOpTest, - SimpleTestQuantizedPerChannelInt8) { + SimpleTestQuantizedPerChannelInt4) { HybridFullyConnectedOpModel m( /*units=*/3, /*batches=*/2, /*input=*/{TensorType_FLOAT32, {2, 10}}, /*weights=*/ - {TensorType_INT8, + {TensorType_INT4, {3, 10}, 0, 0, 0.0f, 0, true, - {10.0 / 127.0, 20.0 / 127.0, 30.0 / 127.0}, + {10.0 / 7.0, 20.0 / 7.0, 30.0 / 7.0}, {0, 0, 0}}, {TensorType_FLOAT32}, /*asymmetric_quantize_input*/ true); @@ -1567,10 +1600,10 @@ TEST(HybridAsymmetricInputPerChannelWeightsFullyConnectedOpTest, EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( { - 24, 195, 366, // - 58, 251, 441, // + 35, 188, 368, // + 53, 275, 430, // }, - /*max_abs_err=*/1.3f))); + /*max_abs_err=*/0.5f))); } TEST_P(FloatFullyConnectedOpTest, SimpleTest4DInput) { @@ -1654,7 +1687,13 @@ TEST_P(QuantizedFullyConnectedOpTest, SimpleTest4dInputQuantizedUint8) { QuantizedFullyConnectedOpModel m( GetRegistration(), /*units=*/3, /*batches=*/2, /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -63.5, 64}, - /*output=*/{TensorType_UINT8, {}, -127, 128}); + /*output=*/{TensorType_UINT8, {}, -127, 128}, + /*bias_type=*/TensorType_INT32, /*keep_num_dims =*/false, + /*bool bias_tensor_optional =*/false, + /*ActivationFunctionType activation_func =*/ActivationFunctionType_RELU, + /*FullyConnectedOptionsWeightsFormat weights_format =*/ + FullyConnectedOptionsWeightsFormat_DEFAULT, + /*input_size=*/-1, /*filter_type=*/TensorType_UINT8); // input_product_scale < output_scale was not true. m.SetWeights({ @@ -1686,7 +1725,13 @@ TEST_P(QuantizedFullyConnectedOpTest, QuantizedFullyConnectedOpModel m( GetRegistration(), /*units=*/3, /*batches=*/2, /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -127, 128}, - /*output=*/{TensorType_UINT8, {}, -63.5, 64}); + /*output=*/{TensorType_UINT8, {}, -63.5, 64}, + /*bias_type=*/TensorType_INT32, /*keep_num_dims =*/false, + /*bool bias_tensor_optional =*/false, + /*ActivationFunctionType activation_func =*/ActivationFunctionType_RELU, + /*FullyConnectedOptionsWeightsFormat weights_format =*/ + FullyConnectedOptionsWeightsFormat_DEFAULT, + /*input_size=*/-1, /*filter_type=*/TensorType_UINT8); m.SetWeights({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD index 57001a800c5287..2a1f510c131b4a 100644 --- a/tensorflow/lite/kernels/internal/BUILD +++ b/tensorflow/lite/kernels/internal/BUILD @@ -669,7 +669,9 @@ cc_library( copts = tflite_copts(), deps = [ ":common", + ":compatibility", ":types", + "//tensorflow/lite:macros", "//tensorflow/lite/core/c:common", ], ) @@ -784,6 +786,7 @@ cc_library( ":tensor", ":tensor_utils", ":types", + "//tensorflow/lite:macros", "//tensorflow/lite:string_util", "//tensorflow/lite/core/c:c_api_types", "//tensorflow/lite/core/c:common", @@ -881,6 +884,7 @@ cc_library( ":tensor", ":tensor_utils", ":types", + "//tensorflow/lite:macros", "//tensorflow/lite:string_util", "//tensorflow/lite/core/c:c_api_types", "//tensorflow/lite/core/c:common", diff --git a/tensorflow/lite/kernels/internal/optimized/batch_matmul.h b/tensorflow/lite/kernels/internal/optimized/batch_matmul.h index 726a279bfaef13..6d20210f41f626 100644 --- a/tensorflow/lite/kernels/internal/optimized/batch_matmul.h +++ b/tensorflow/lite/kernels/internal/optimized/batch_matmul.h @@ -25,98 +25,6 @@ limitations under the License. namespace tflite { namespace optimized_ops { -inline void BatchMatMul(const RuntimeShape& lhs_shape, const float* lhs_data, - const RuntimeShape& rhs_shape, const float* rhs_data, - const RuntimeShape& output_shape, float* output_data, - CpuBackendContext* context, - bool transpose_lhs = false) { - using ::tflite::cpu_backend_gemm::Gemm; - using ::tflite::cpu_backend_gemm::GemmParams; - using ::tflite::cpu_backend_gemm::MatrixParams; - const RuntimeShape extended_lhs_shape = - RuntimeShape::ExtendedShape(5, lhs_shape); - const RuntimeShape extended_rhs_shape = - RuntimeShape::ExtendedShape(5, rhs_shape); - - // Determine which dimension is the broadcast dimension. - auto broadcast_dim = [](int lhs_dim, int rhs_dim) { - if (lhs_dim == rhs_dim) return lhs_dim; - if (lhs_dim == 1) return rhs_dim; - TFLITE_DCHECK_EQ(rhs_dim, 1); - return lhs_dim; - }; - - // Compute the "extent" for iterating on this dimension. - // If we are broadcasting, then don't advance (i.e return 0). - auto extent = [](const RuntimeShape& shape, int x) { - if (shape.Dims(x) == 1) { - return 0; - } - int prod = 1; - for (int i = x + 1; i < shape.DimensionsCount(); ++i) { - prod *= shape.Dims(i); - } - return prod; - }; - - const int batch_dim0 = - broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0)); - const int batch_dim1 = - broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1)); - const int batch_dim2 = - broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2)); - - const int lhs_ext0 = extent(extended_lhs_shape, 0); - const int lhs_ext1 = extent(extended_lhs_shape, 1); - const int lhs_ext2 = extent(extended_lhs_shape, 2); - const int rhs_ext0 = extent(extended_rhs_shape, 0); - const int rhs_ext1 = extent(extended_rhs_shape, 1); - const int rhs_ext2 = extent(extended_rhs_shape, 2); - - // Set params for each matrix multiply. - const int lhs_rows = extended_lhs_shape.Dims(3); - const int rhs_cols = extended_rhs_shape.Dims(4); - const int accum_depth = extended_lhs_shape.Dims(4); - - MatrixParams lhs_params; - if (transpose_lhs) { - lhs_params.order = cpu_backend_gemm::Order::kColMajor; - } else { - lhs_params.order = cpu_backend_gemm::Order::kRowMajor; - } - lhs_params.rows = lhs_rows; - lhs_params.cols = accum_depth; - - MatrixParams rhs_params; - rhs_params.order = cpu_backend_gemm::Order::kColMajor; - rhs_params.rows = accum_depth; - rhs_params.cols = rhs_cols; - - MatrixParams dst_params; - dst_params.order = cpu_backend_gemm::Order::kColMajor; - dst_params.rows = lhs_rows; - dst_params.cols = rhs_cols; - - for (int b0 = 0; b0 < batch_dim0; ++b0) { - const float* lhs_ptr0 = lhs_data + (b0 * lhs_ext0); - const float* rhs_ptr0 = rhs_data + (b0 * rhs_ext0); - for (int b1 = 0; b1 < batch_dim1; ++b1) { - const float* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1; - const float* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1; - for (int b2 = 0; b2 < batch_dim2; ++b2) { - const float* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2; - const float* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2; - float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) + - b1 * batch_dim2 + b2) * - lhs_rows * rhs_cols; - GemmParams gemm_params; - cpu_backend_gemm::Gemm(lhs_params, lhs_ptr2, rhs_params, rhs_ptr2, - dst_params, out_ptr, gemm_params, context); - } - } - } -} - inline void BatchMatMul(const FullyConnectedParams& params, const RuntimeShape& lhs_shape, const int8_t* lhs_data, const RuntimeShape& rhs_shape, const int8_t* rhs_data, diff --git a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h index 3299f610697bbf..d340d2fec437ec 100644 --- a/tensorflow/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/optimized_ops.h @@ -7087,327 +7087,12 @@ inline void Logistic16bitPrecision(const LogisticParams& params, } } -// Transpose2D only deals with typical 2D matrix transpose ops. -// Perform transpose by transposing 4x4 blocks of the input, proceeding from -// left to right (down the rows) of the input, and then from top to bottom. -template -inline void Transpose2D(const RuntimeShape& input_shape, const T* input_data, - const RuntimeShape& output_shape, T* output_data) { - TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2); - TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2); - - const int d0 = input_shape.DimsData()[0]; - const int d1 = input_shape.DimsData()[1]; - const int kLines = 4; - const int kSkipSize = (kLines - 1) * d1; - - const T* input = input_data; - - int i = 0; - for (; i <= d0 - kLines; i += kLines) { - T* output = output_data + i; - - const T* input_ptr = input; - optimized_ops_preload_l1_keep(input_ptr); - input_ptr += d1; - optimized_ops_preload_l1_keep(input_ptr); - input_ptr += d1; - optimized_ops_preload_l1_keep(input_ptr); - input_ptr += d1; - optimized_ops_preload_l1_keep(input_ptr); - - int j = 0; - for (; j <= d1 - kLines; j += kLines) { - input_ptr = input; - const T a00 = input_ptr[0]; - const T a01 = input_ptr[1]; - const T a02 = input_ptr[2]; - const T a03 = input_ptr[3]; - input_ptr += d1; - const T a10 = input_ptr[0]; - const T a11 = input_ptr[1]; - const T a12 = input_ptr[2]; - const T a13 = input_ptr[3]; - input_ptr += d1; - const T a20 = input_ptr[0]; - const T a21 = input_ptr[1]; - const T a22 = input_ptr[2]; - const T a23 = input_ptr[3]; - input_ptr += d1; - const T a30 = input_ptr[0]; - const T a31 = input_ptr[1]; - const T a32 = input_ptr[2]; - const T a33 = input_ptr[3]; - - output[0] = a00; - output[1] = a10; - output[2] = a20; - output[3] = a30; - output += d0; - - output[0] = a01; - output[1] = a11; - output[2] = a21; - output[3] = a31; - output += d0; - - output[0] = a02; - output[1] = a12; - output[2] = a22; - output[3] = a32; - output += d0; - - output[0] = a03; - output[1] = a13; - output[2] = a23; - output[3] = a33; - output += d0; - - input += kLines; - } - if (j == d1) { - input += kSkipSize; - } else { - for (int p = 0; p < kLines; ++p) { - for (int q = 0; q < d1 - j; ++q) { - *(output + q * d0 + p) = *(input + p * d1 + q); - } - } - input += (d1 - j) + kSkipSize; - } - } - for (; i < d0; ++i) { - T* output = output_data + i; - for (int j = 0; j < d1; ++j) { - *output = *input; - output += d0; - ++input; - } - } -} - -template <> -inline void Transpose2D(const RuntimeShape& input_shape, - const int32_t* input_data, - const RuntimeShape& output_shape, - int32_t* output_data) { - TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 2); - TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 2); - - const int d0 = input_shape.DimsData()[0]; - const int d1 = input_shape.DimsData()[1]; -#ifdef USE_NEON - const int kLines = 4; - const int kSkipSize = (kLines - 1) * d1; -#endif - - const int32_t* input = input_data; - - int i = 0; -#ifdef USE_NEON - for (; i <= d0 - kLines; i += kLines) { - int32_t* output = output_data + i; - - const int32_t* input_ptr = input; - optimized_ops_preload_l1_keep(input_ptr); - input_ptr += d1; - optimized_ops_preload_l1_keep(input_ptr); - input_ptr += d1; - optimized_ops_preload_l1_keep(input_ptr); - input_ptr += d1; - optimized_ops_preload_l1_keep(input_ptr); - - int j = 0; - for (; j <= d1 - kLines; j += kLines) { - input_ptr = input; - int32x4_t a0 = vld1q_s32(input); - input_ptr += d1; - int32x4_t a1 = vld1q_s32(input_ptr); - input_ptr += d1; - int32x4_t a2 = vld1q_s32(input_ptr); - input_ptr += d1; - int32x4_t a3 = vld1q_s32(input_ptr); - - int32x4x2_t tmp1 = vuzpq_s32(a0, a2); - int32x4x2_t tmp2 = vuzpq_s32(a1, a3); - int32x4x2_t tmp3 = vtrnq_s32(tmp1.val[0], tmp2.val[0]); - int32x4x2_t tmp4 = vtrnq_s32(tmp1.val[1], tmp2.val[1]); - - vst1q_s32(output, tmp3.val[0]); - output += d0; - vst1q_s32(output, tmp4.val[0]); - output += d0; - vst1q_s32(output, tmp3.val[1]); - output += d0; - vst1q_s32(output, tmp4.val[1]); - output += d0; - input += kLines; - } - if (j == d1) { - input += kSkipSize; - } else { - for (int p = 0; p < kLines; ++p) { - for (int q = 0; q < d1 - j; ++q) { - *(output + q * d0 + p) = *(input + p * d1 + q); - } - } - input += (d1 - j) + kSkipSize; - } - } -#endif - for (; i < d0; ++i) { - int32_t* output = output_data + i; - for (int j = 0; j < d1; ++j) { - *output = *input; - output += d0; - ++input; - } - } -} - -// TODO(b/173718660): see if we can reduce the number -// of lines of code in branching without affecting latency. -template -inline void Transpose3D(const TransposeParams& params, - const RuntimeShape& input_shape, const T* input_data, - const RuntimeShape& output_shape, T* output_data) { - int s2, s3; - s2 = input_shape.Dims(1); - s3 = input_shape.Dims(2); - - int p1, p2, p3; - if (params.perm[0] == 2) { - p1 = 1; - } else if (params.perm[1] == 2) { - p2 = 1; - } else { - p3 = 1; - } - - if (params.perm[0] == 1) { - p1 = s3; - } else if (params.perm[1] == 1) { - p2 = s3; - } else { - p3 = s3; - } - - if (params.perm[0] == 0) { - p1 = s2 * s3; - } else if (params.perm[1] == 0) { - p2 = s2 * s3; - } else { - p3 = s2 * s3; - } - - int o_s[3]; - o_s[0] = input_shape.Dims(params.perm[0]); - o_s[1] = input_shape.Dims(params.perm[1]); - o_s[2] = input_shape.Dims(params.perm[2]); - - for (int i1 = 0; i1 < o_s[0]; ++i1) { - for (int i2 = 0; i2 < o_s[1]; ++i2) { - for (int i3 = 0; i3 < o_s[2]; ++i3) { - const int i = i1 * p1 + i2 * p2 + i3 * p3; - const int o = i1 * o_s[1] * o_s[2] + i2 * o_s[2] + i3; - output_data[o] = input_data[i]; - } - } - } -} - -template -void TransposeImpl(const TransposeParams& params, - const RuntimeShape& input_shape, const T* input_data, - const RuntimeShape& output_shape, T* output_data) { - const int dims_cnt = input_shape.DimensionsCount(); - - int dim0, dim1; - if (transpose_utils::IsTranspose2DApplicable(params, input_shape, &dim0, - &dim1)) { - Transpose2D(RuntimeShape({dim0, dim1}), input_data, - RuntimeShape({dim1, dim0}), output_data); - return; - } - - // TODO(b/141217325): notably Eigen is better suited for - // larger inputs whereas Transpose3D is generally - // better for smaller ones. - // - // E.g. on Nexus 5, Eigen is better for size 96^3 and up - // and Transpose3D is better for 72^3 and down. - // - // 96^3 is not mobile-friendly for certain usecases - // (e.g. model used in beam search for seq2seq) but is in others. - // Consider tradeoffs. - if (dims_cnt == 3) { - Transpose3D(params, input_shape, input_data, output_shape, output_data); - return; - } - - // Reroute to the reference version if an optimized method for the given data - // is not available. - reference_ops::Transpose(params, input_shape, input_data, output_shape, - output_data); -} - template -void Transpose(const TransposeParams& unshrinked_params, - const RuntimeShape& unshrinked_input_shape, const T* input_data, - const RuntimeShape& unshrinked_output_shape, T* output_data) { - ruy::profiler::ScopeLabel label("Transpose"); - - const int output_size = unshrinked_output_shape.DimensionsCount(); - TFLITE_DCHECK_EQ(output_size, unshrinked_params.perm_count); - - RuntimeShape shrinked_input_shape = RuntimeShape(unshrinked_input_shape); - RuntimeShape shrinked_output_shape = RuntimeShape(unshrinked_output_shape); - TransposeParams shrinked_params = unshrinked_params; - - // Reduce any dimensions that have one size. Lower transpose op usually - // performs better since memory access patterns will be improved. - transpose_utils::RemoveOneSizeDimensions( - &shrinked_input_shape, &shrinked_output_shape, &shrinked_params); - - // Handle identity cases. - // TODO(b/140779653): Add an optimization pass in the conversion process to - // remove transpose op nodes where they do nothing like the below one. - bool identical = true; - for (int i = 0; i < shrinked_params.perm_count; ++i) { - if (shrinked_params.perm[i] != i) { - identical = false; - break; - } - } - if (identical) { - memcpy(output_data, input_data, - unshrinked_input_shape.FlatSize() * sizeof(T)); - return; - } - - // Reduce dimensions by flattening. - if (shrinked_params.perm[0] == 0 && output_size >= 3) { - RuntimeShape non_flatten_input_shape; - RuntimeShape non_flatten_output_shape; - TransposeParams non_flatten_params; - const int total_size = shrinked_input_shape.FlatSize(); - const int non_flatten_size = transpose_utils::Flatten( - shrinked_input_shape, shrinked_output_shape, shrinked_params, - &non_flatten_input_shape, &non_flatten_output_shape, - &non_flatten_params); - TFLITE_DCHECK_NE(non_flatten_params.perm[0], 0); - - for (int i = 0; i < total_size; i += non_flatten_size) { - TransposeImpl(non_flatten_params, non_flatten_input_shape, - input_data + i, non_flatten_output_shape, - output_data + i); - } - return; - } - - // Call non-flattened case. - TransposeImpl(shrinked_params, shrinked_input_shape, input_data, - shrinked_output_shape, output_data); +void Transpose(const TransposeParams& params, const RuntimeShape& input_shape, + const T* input_data, const RuntimeShape& output_shape, + T* output_data) { + return reference_ops::Transpose(params, input_shape, input_data, output_shape, + output_data); } // Assume input1 & input2 have the same scale & zero point. diff --git a/tensorflow/lite/kernels/internal/reference/batch_matmul.h b/tensorflow/lite/kernels/internal/reference/batch_matmul.h index 767ad6ab0af12b..d83696219c2572 100644 --- a/tensorflow/lite/kernels/internal/reference/batch_matmul.h +++ b/tensorflow/lite/kernels/internal/reference/batch_matmul.h @@ -111,7 +111,8 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data, const float* scaling_factors, const int32_t* input_offset, int32_t* row_sums, const RuntimeShape& output_shape, float* output_data, - bool* compute_row_sums) { + bool* compute_row_sums, + const float* per_channel_scales) { const RuntimeShape extended_lhs_shape = RuntimeShape::ExtendedShape(5, lhs_shape); const RuntimeShape extended_rhs_shape = @@ -188,7 +189,11 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data, int32_t row_sum = woff_ptr2[i]; total -= row_sum * batch_offset; int idx = lhs_rows * j + i; - out_ptr[idx] += batch_scaling_factor * total; + float scale = batch_scaling_factor; + if (per_channel_scales) { + scale *= per_channel_scales[i]; + } + out_ptr[idx] += scale * total; } } } diff --git a/tensorflow/lite/kernels/internal/reference/comparisons.cc b/tensorflow/lite/kernels/internal/reference/comparisons.cc index 86b4a6af0c0f2e..36ce951ec17536 100644 --- a/tensorflow/lite/kernels/internal/reference/comparisons.cc +++ b/tensorflow/lite/kernels/internal/reference/comparisons.cc @@ -15,6 +15,10 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/reference/comparisons.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" +#include "tensorflow/lite/kernels/internal/runtime_shape.h" + namespace tflite { namespace reference_ops { diff --git a/tensorflow/lite/kernels/internal/reference/comparisons.h b/tensorflow/lite/kernels/internal/reference/comparisons.h index 366b378c825266..a9f1e42c0a6c94 100644 --- a/tensorflow/lite/kernels/internal/reference/comparisons.h +++ b/tensorflow/lite/kernels/internal/reference/comparisons.h @@ -16,7 +16,9 @@ limitations under the License. #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_ #include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/core/macros.h" #include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/runtime_shape.h" #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { diff --git a/tensorflow/lite/kernels/internal/runtime_shape.h b/tensorflow/lite/kernels/internal/runtime_shape.h index e266bb85477ad6..8982cb1732f018 100644 --- a/tensorflow/lite/kernels/internal/runtime_shape.h +++ b/tensorflow/lite/kernels/internal/runtime_shape.h @@ -19,6 +19,7 @@ limitations under the License. // LINT.IfChange #include +#include #include #include #include diff --git a/tensorflow/lite/kernels/maximum_minimum.cc b/tensorflow/lite/kernels/maximum_minimum.cc index 3a56d171474645..08e6d991d6cd2c 100644 --- a/tensorflow/lite/kernels/maximum_minimum.cc +++ b/tensorflow/lite/kernels/maximum_minimum.cc @@ -26,16 +26,6 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/types.h" #include "tensorflow/lite/kernels/kernel_util.h" -#ifdef TFLITE_KERNEL_USE_XNNPACK -#include -#include -#include - -#include "xnnpack.h" // from @XNNPACK -#include "tensorflow/lite/kernels/cpu_backend_context.h" -#include "tensorflow/lite/minimal_logging.h" -#endif // TFLITE_KERNEL_USE_XNNPACK - namespace tflite { namespace ops { namespace builtin { @@ -175,57 +165,6 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { switch (op_context.output->type) { case kTfLiteFloat32: { -#ifdef TFLITE_KERNEL_USE_XNNPACK - size_t num_input1_dims = static_cast( - GetTensorShape(op_context.input1).DimensionsCount()); - size_t num_input2_dims = static_cast( - GetTensorShape(op_context.input2).DimensionsCount()); - if (std::max(num_input1_dims, num_input2_dims) < XNN_MAX_TENSOR_DIMS) { - std::array input1_shape; - std::array input2_shape; - for (size_t i = 0; i < num_input1_dims; ++i) { - input1_shape[i] = GetTensorShape(op_context.input1).Dims(i); - } - for (size_t i = 0; i < num_input2_dims; ++i) { - input2_shape[i] = GetTensorShape(op_context.input2).Dims(i); - } - CpuBackendContext* cpu_backend_context = - CpuBackendContext::GetFromContext(context); - pthreadpool_t threadpool = - cpu_backend_context->get_xnnpack_threadpool(); - enum xnn_status status = xnn_status_invalid_parameter; - if (std::is_same::value) { - status = xnn_run_maximum_nd_f32( - num_input1_dims, input1_shape.data(), num_input2_dims, - input2_shape.data(), GetTensorData(op_context.input1), - GetTensorData(op_context.input2), - GetTensorData(op_context.output), - /*flags=*/XNN_FLAG_YIELD_WORKERS, threadpool); - if (status != xnn_status_success) { - TFLITE_LOG(TFLITE_LOG_INFO, - "Failed to run xnn_run_maximum_nd_f32. Error code: %d", - status); - TFLiteOperation(context, node, - op_context); - } - } else if (std::is_same::value) { - status = xnn_run_minimum_nd_f32( - num_input1_dims, input1_shape.data(), num_input2_dims, - input2_shape.data(), GetTensorData(op_context.input1), - GetTensorData(op_context.input2), - GetTensorData(op_context.output), - /*flags=*/XNN_FLAG_YIELD_WORKERS, threadpool); - if (status != xnn_status_success) { - TFLITE_LOG(TFLITE_LOG_INFO, - "Failed to run xnn_run_minimum_nd_f32. Error code: %d", - status); - TFLiteOperation(context, node, - op_context); - } - } - break; - } -#endif TFLiteOperation(context, node, op_context); break; } diff --git a/tensorflow/lite/kernels/pow_test.cc b/tensorflow/lite/kernels/pow_test.cc index 553159c5fdd684..0eb381f4a9c053 100644 --- a/tensorflow/lite/kernels/pow_test.cc +++ b/tensorflow/lite/kernels/pow_test.cc @@ -119,7 +119,8 @@ TEST(PowOpModel, BroadcastFloatTest) { model.PopulateTensor(model.input2(), {4}); ASSERT_EQ(model.Invoke(), kTfLiteOk); EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); - EXPECT_THAT(model.GetOutput(), ElementsAre(20736, 16, 2401, 4096)); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray(ArrayFloatNear({20736, 16, 2401, 4096}))); } template diff --git a/tensorflow/lite/kernels/random_ops.cc b/tensorflow/lite/kernels/random_ops.cc index 70665061f39cb4..28f0e3f80ccf2b 100644 --- a/tensorflow/lite/kernels/random_ops.cc +++ b/tensorflow/lite/kernels/random_ops.cc @@ -17,11 +17,11 @@ limitations under the License. #include #include +#include "xla/tsl/lib/random/philox_random.h" +#include "xla/tsl/lib/random/random_distributions_utils.h" #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" -#include "tsl/lib/random/philox_random.h" -#include "tsl/lib/random/random_distributions_utils.h" namespace tflite { namespace ops { diff --git a/tensorflow/lite/kernels/reshape.cc b/tensorflow/lite/kernels/reshape.cc index ff53ddb85be876..006af7583218c1 100644 --- a/tensorflow/lite/kernels/reshape.cc +++ b/tensorflow/lite/kernels/reshape.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/lite/array.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/kernels/internal/tensor.h" @@ -37,6 +38,7 @@ struct OpData { // This is to prevent incorrect results when mischievous users overwrite // output pointers with their own. const void* output_ptr; + bool output_shape_known = true; }; TfLiteIntArray* GetOutputShape(TfLiteContext*, TfLiteNode*); @@ -96,7 +98,9 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) { inline TfLiteIntArray* GetOutputShapeFromTensor(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* shape = GetInput(context, node, kShapeTensor); - if (shape == nullptr) return nullptr; + if (shape == nullptr) { + return nullptr; + } TfLiteIntArray* output_shape = TfLiteIntArrayCreate(shape->dims->data[0]); for (int i = 0; i < output_shape->size; ++i) { @@ -159,19 +163,27 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* shape = GetInput(context, node, kShapeTensor); if (NumInputs(node) == 1 || IsConstantOrPersistentTensor(shape)) { + op_data->output_shape_known = true; if (IsConstantOrPersistentTensor(input)) { SetTensorToPersistentRo(output); TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); op_data->output_ptr = output->data.data; memcpy(output->data.data, input->data.data, input->bytes); - return kTfLiteOk; } else { TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); } + return kTfLiteOk; } else { - SetTensorToDynamic(output); + op_data->output_shape_known = false; + // We know the output bytes size is the same as the input. Setting this + // enables tensor sharing in the ArenaPlanner. + if (output->allocation_type == kTfLiteArenaRw) { + output->bytes = input->bytes; + } + return kTfLiteOutputShapeNotKnown; } } + op_data->output_shape_known = true; return kTfLiteOk; } @@ -186,8 +198,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // There are two ways in which the 'output' can be made dynamic: it could be // a string tensor, or its shape cannot be calculated during Prepare(). In // either case, we now have all the information to calculate its shape. - if (IsDynamicTensor(output)) { - TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); + if (output->type != kTfLiteString) { + if (!op_data->output_shape_known) { + if (output->data.data != input->data.data) { + // If the otuput cannot overwrite the input, then we have to set the + // tensor to dyanmic. + SetTensorToDynamic(output); + TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); + } else { + TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); + // The output pointer was set to zero during the call to ResizeTensor. + // Since the output aliases the input, set it back. + output->data.data = input->data.data; + } + } } // Note that string tensors are always "dynamic" in the sense that their size @@ -197,6 +221,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // reshape doesn't change the data, the output tensor needs exactly as many // bytes as the input tensor. if (output->type == kTfLiteString) { + SetTensorToDynamic(output); + TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); auto bytes_required = input->bytes; TfLiteTensorRealloc(bytes_required, output); output->bytes = bytes_required; @@ -235,7 +261,8 @@ TfLiteRegistration* Register_RESHAPE() { /*version=*/0, /*registration_external=*/nullptr, /*async_kernel=*/nullptr, - kTfLiteInplaceOpInput0Shared | kTfLiteInplaceOpDataUnmodified}; + /*inplace_operator=*/kTfLiteInplaceOpInput0Shared | + kTfLiteInplaceOpDataUnmodified}; return &r; } diff --git a/tensorflow/lite/kernels/strided_slice_test.cc b/tensorflow/lite/kernels/strided_slice_test.cc index 9e63abebe736fc..e769831c8391a1 100644 --- a/tensorflow/lite/kernels/strided_slice_test.cc +++ b/tensorflow/lite/kernels/strided_slice_test.cc @@ -156,6 +156,15 @@ using DataTypes = ::testing::Types; TYPED_TEST_SUITE(StridedSliceOpTest, DataTypes); +template +auto ElementsAreTypedArray(std::vector x) { + if constexpr (std::is_floating_point_v) { + return ElementsAreArray(ArrayFloatNear(std::move(x))); + } else { + return ElementsAreArray(std::move(x)); + } +} + #if GTEST_HAS_DEATH_TEST TYPED_TEST(StridedSliceOpTest, UnsupportedInputSize) { EXPECT_DEATH(StridedSliceOpModel({2, 2, 2, 2, 2, 2}, {5}, {5}, {5}, @@ -191,7 +200,7 @@ TYPED_TEST(StridedSliceOpTest, Offset) { 0, 0, 0, 0, constant_tensors, /*offset=*/true); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1, 2, 3})); if (constant_tensors) { EXPECT_THAT(m.GetOutputTensor(0)->allocation_type, kTfLitePersistentRo); } else { @@ -212,7 +221,7 @@ TYPED_TEST(StridedSliceOpTest, OffsetArray) { {2, 2}, {1, 1}, 0, 0, 0, 0, 0, constant_tensors, /*offset=*/true); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 5, 6})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1, 2, 5, 6})); if (constant_tensors) { EXPECT_THAT(m.GetOutputTensor(0)->allocation_type, kTfLitePersistentRo); } else { @@ -229,7 +238,7 @@ TYPED_TEST(StridedSliceOpTest, OffsetConstant) { /*offset=*/true); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 5, 6})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1, 2, 5, 6})); EXPECT_THAT(m.GetOutputTensor(0)->allocation_type, kTfLiteArenaRw); } @@ -245,7 +254,7 @@ TYPED_TEST(StridedSliceOpTest, OffsetConstantStride) { /*offset=*/true); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 13, 15})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1, 3, 13, 15})); EXPECT_THAT(m.GetOutputTensor(0)->allocation_type, kTfLiteArenaRw); } @@ -261,7 +270,8 @@ TYPED_TEST(StridedSliceOpTest, OffsetConstantNegativeStride) { /*offset=*/true); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({28, 26, 16, 14})); + EXPECT_THAT(m.GetOutput(), + ElementsAreTypedArray({28, 26, 16, 14})); EXPECT_THAT(m.GetOutputTensor(0)->allocation_type, kTfLiteArenaRw); } @@ -275,7 +285,7 @@ TYPED_TEST(StridedSliceOpTest, In1D) { {1}, 0, 0, 0, 0, 0, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({2, 3})); } } @@ -289,7 +299,7 @@ TYPED_TEST(StridedSliceOpTest, In1DConst) { {1}, 0, 0, 0, 0, 0, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({2, 3})); } } @@ -308,7 +318,7 @@ TYPED_TEST(StridedSliceOpTest, In1D_Int32End) { constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({32768})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray(values)); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray(values)); } } @@ -335,7 +345,7 @@ TYPED_TEST(StridedSliceOpTest, In1D_NegativeBegin) { {3}, {1}, 0, 0, 0, 0, 0, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({2, 3})); } } @@ -349,7 +359,7 @@ TYPED_TEST(StridedSliceOpTest, In1D_OutOfRangeBegin) { {3}, {1}, 0, 0, 0, 0, 0, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1, 2, 3})); } } @@ -364,7 +374,7 @@ TYPED_TEST(StridedSliceOpTest, In1D_NegativeEnd) { constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({2})); } } @@ -378,7 +388,7 @@ TYPED_TEST(StridedSliceOpTest, In1D_OutOfRangeEnd) { {5}, {1}, 0, 0, 0, 0, 0, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({2, 3, 4})); } } @@ -392,7 +402,7 @@ TYPED_TEST(StridedSliceOpTest, In1D_BeginMask) { {1}, 1, 0, 0, 0, 0, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1, 2, 3})); } } @@ -408,7 +418,7 @@ TYPED_TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStride) { ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({3})); } } @@ -422,7 +432,7 @@ TYPED_TEST(StridedSliceOpTest, In1D_OutOfRangeBeginNegativeStride) { {-1}, 0, 0, 0, 0, 0, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({4})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({4})); } } @@ -437,7 +447,7 @@ TYPED_TEST(StridedSliceOpTest, In1D_NegativeEndNegativeStride) { constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({3, 2})); } } @@ -452,7 +462,7 @@ TYPED_TEST(StridedSliceOpTest, In1D_OutOfRangeEndNegativeStride) { constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({2, 1})); } } @@ -466,7 +476,7 @@ TYPED_TEST(StridedSliceOpTest, In1D_EndMask) { {1}, 0, 1, 0, 0, 0, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({2, 3, 4})); } } @@ -480,7 +490,7 @@ TYPED_TEST(StridedSliceOpTest, In1D_NegStride) { {-1}, 0, 0, 0, 0, 0, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({3, 2, 1})); } } @@ -494,7 +504,7 @@ TYPED_TEST(StridedSliceOpTest, In1D_EvenLenStride2) { 0, 0, 0, 0, 0, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1})); } } @@ -508,7 +518,7 @@ TYPED_TEST(StridedSliceOpTest, In1D_OddLenStride2) { {2}, 0, 0, 0, 0, 0, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1, 3})); } } @@ -523,7 +533,8 @@ TYPED_TEST(StridedSliceOpTest, In2D_Identity) { constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); + EXPECT_THAT(m.GetOutput(), + ElementsAreTypedArray({1, 2, 3, 4, 5, 6})); } } @@ -538,7 +549,7 @@ TYPED_TEST(StridedSliceOpTest, In2D) { constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({4, 5})); } } @@ -553,7 +564,7 @@ TYPED_TEST(StridedSliceOpTest, In2D_Stride2) { constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1, 3})); } } @@ -568,7 +579,7 @@ TYPED_TEST(StridedSliceOpTest, In2D_NegStride) { constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 5, 4})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({6, 5, 4})); } } @@ -583,7 +594,7 @@ TYPED_TEST(StridedSliceOpTest, In2D_BeginMask) { constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 4, 5})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1, 2, 4, 5})); } } @@ -598,7 +609,7 @@ TYPED_TEST(StridedSliceOpTest, In2D_EndMask) { constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5, 6})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({4, 5, 6})); } } TYPED_TEST(StridedSliceOpTest, In2D_NegStrideBeginMask) { @@ -612,7 +623,7 @@ TYPED_TEST(StridedSliceOpTest, In2D_NegStrideBeginMask) { constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 5, 4})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({6, 5, 4})); } } TYPED_TEST(StridedSliceOpTest, In2D_NegStrideEndMask) { @@ -626,7 +637,7 @@ TYPED_TEST(StridedSliceOpTest, In2D_NegStrideEndMask) { constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 4})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({5, 4})); } } @@ -672,7 +683,7 @@ TYPED_TEST(StridedSliceOpTest, In3D_Strided2) { {0, 0, 0}, {2, 3, 2}, {2, 2, 2}, 0, 0, 0, 0, 0, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 1})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1, 5})); } } TYPED_TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) { @@ -685,7 +696,7 @@ TYPED_TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) { {1}, 0, 0, 0, 0, 1, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_TRUE(m.GetOutputShape().empty()); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({2})); } } TYPED_TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1_NegativeSlice) { @@ -700,7 +711,7 @@ TYPED_TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1_NegativeSlice) { ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_TRUE(m.GetOutputShape().empty()); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({3})); } } TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxis3_NegativeSlice) { @@ -716,7 +727,7 @@ TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxis3_NegativeSlice) { ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_TRUE(m.GetOutputShape().empty()); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({2})); } } TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxis2_BeginEndAxis1_NegativeSlice) { @@ -732,7 +743,7 @@ TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxis2_BeginEndAxis1_NegativeSlice) { ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 1, 2, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({0, 1, 2, 3})); } } TYPED_TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) { @@ -745,7 +756,7 @@ TYPED_TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) { {1}, 1, 0, 0, 0, 1, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_TRUE(m.GetOutputShape().empty()); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1})); } } TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) { @@ -759,7 +770,7 @@ TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) { constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1, 2, 3})); } } TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxisMask2) { @@ -773,7 +784,7 @@ TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxisMask2) { constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 4})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1, 4})); } } TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxisMask3) { @@ -787,7 +798,7 @@ TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxisMask3) { constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_TRUE(m.GetOutputShape().empty()); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1})); } } TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1) { @@ -801,7 +812,8 @@ TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1) { {0, 0, 0}, {1, 3, 2}, {1, 1, 1}, 0, 0, 0, 0, 1, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); + EXPECT_THAT(m.GetOutput(), + ElementsAreTypedArray({1, 2, 3, 4, 5, 6})); } } TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis2) { @@ -815,7 +827,7 @@ TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis2) { {0, 0, 0}, {2, 1, 2}, {1, 1, 1}, 0, 0, 0, 0, 2, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 7, 8})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1, 2, 7, 8})); } } TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis3) { @@ -829,7 +841,7 @@ TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis3) { {0, 0, 0}, {1, 1, 2}, {1, 1, 1}, 0, 0, 0, 0, 3, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1, 2})); } } TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) { @@ -843,7 +855,8 @@ TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) { {0, 0, 0}, {2, 3, 1}, {1, 1, 1}, 0, 0, 0, 0, 4, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 5, 7, 9, 11})); + EXPECT_THAT(m.GetOutput(), + ElementsAreTypedArray({1, 3, 5, 7, 9, 11})); } } TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis5) { @@ -857,7 +870,7 @@ TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis5) { {0, 0, 0}, {1, 3, 1}, {1, 1, 1}, 0, 0, 0, 0, 5, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 5})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1, 3, 5})); } } TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis6) { @@ -871,7 +884,7 @@ TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis6) { {0, 0, 0}, {2, 1, 1}, {1, 1, 1}, 0, 0, 0, 0, 6, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 7})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1, 7})); } } TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis7) { @@ -881,7 +894,7 @@ TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis7) { {0, 0, 0}, {1, 1, 1}, {1, 1, 1}, 0, 0, 0, 0, 7, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_TRUE(m.GetOutputShape().empty()); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1})); } // This tests catches a very subtle bug that was fixed by cl/188403234. @@ -892,7 +905,7 @@ TYPED_TEST(StridedSliceOpTest, RunTwice) { false); ASSERT_EQ(m.Invoke(), kTfLiteOk); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 4, 5})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1, 2, 4, 5})); auto setup_inputs = [&m]() { m.template SetInput({1, 2, 3, 4, 5, 6}, @@ -905,7 +918,7 @@ TYPED_TEST(StridedSliceOpTest, RunTwice) { setup_inputs(); ASSERT_EQ(m.Invoke(), kTfLiteOk); // Prior to cl/188403234 this was {4, 5}. - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 4, 5})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1, 2, 4, 5})); } TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) { for (bool constant_tensors : {true, false}) { @@ -918,7 +931,8 @@ TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1Uint8) { {0, 0, 0}, {1, 3, 2}, {1, 1, 1}, 0, 0, 0, 0, 1, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); + EXPECT_THAT(m.GetOutput(), + ElementsAreTypedArray({1, 2, 3, 4, 5, 6})); } } TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1int8) { @@ -932,7 +946,8 @@ TYPED_TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1int8) { {0, 0, 0}, {1, 3, 2}, {1, 1, 1}, 0, 0, 0, 0, 1, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); + EXPECT_THAT(m.GetOutput(), + ElementsAreTypedArray({1, 2, 3, 4, 5, 6})); } } TYPED_TEST(StridedSliceOpTest, In5D_Identity) { @@ -948,7 +963,8 @@ TYPED_TEST(StridedSliceOpTest, In5D_Identity) { constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 2, 1, 2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 9, 10, 11, 12})); + EXPECT_THAT(m.GetOutput(), + ElementsAreTypedArray({1, 2, 3, 4, 9, 10, 11, 12})); } } TYPED_TEST(StridedSliceOpTest, In5D_IdentityShrinkAxis1) { @@ -964,7 +980,7 @@ TYPED_TEST(StridedSliceOpTest, In5D_IdentityShrinkAxis1) { constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 1, 2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1, 2, 3, 4})); } } TYPED_TEST(StridedSliceOpTest, In3D_SmallBegin) { @@ -978,7 +994,8 @@ TYPED_TEST(StridedSliceOpTest, In3D_SmallBegin) { {1}, {1}, 0, 0, 0, 0, 0, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); + EXPECT_THAT(m.GetOutput(), + ElementsAreTypedArray({1, 2, 3, 4, 5, 6})); } } TYPED_TEST(StridedSliceOpTest, In3D_SmallBeginWithhrinkAxis1) { @@ -992,7 +1009,8 @@ TYPED_TEST(StridedSliceOpTest, In3D_SmallBeginWithhrinkAxis1) { {1}, {1}, 0, 0, 0, 0, 1, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); + EXPECT_THAT(m.GetOutput(), + ElementsAreTypedArray({1, 2, 3, 4, 5, 6})); } } TYPED_TEST(StridedSliceOpTest, In3D_BackwardSmallBeginEndMask) { @@ -1082,7 +1100,7 @@ TYPED_TEST(StridedSliceOpTest, In2D_ShrinkAxis_Endmask_AtSameAxis) { ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1})); } } TYPED_TEST(StridedSliceOpTest, EllipsisMask1_NewAxisMask2) { @@ -1097,7 +1115,8 @@ TYPED_TEST(StridedSliceOpTest, EllipsisMask1_NewAxisMask2) { ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 1, 1})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 5, 7, 9, 11})); + EXPECT_THAT(m.GetOutput(), + ElementsAreTypedArray({1, 3, 5, 7, 9, 11})); } } TYPED_TEST(StridedSliceOpTest, EllipsisMask2_NewAxisMask1) { @@ -1112,7 +1131,8 @@ TYPED_TEST(StridedSliceOpTest, EllipsisMask2_NewAxisMask1) { ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 3, 1})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 5, 7, 9, 11})); + EXPECT_THAT(m.GetOutput(), + ElementsAreTypedArray({1, 3, 5, 7, 9, 11})); } } TYPED_TEST(StridedSliceOpTest, EllipsisMask2_NewAxisMask5) { @@ -1143,7 +1163,7 @@ TYPED_TEST(StridedSliceOpTest, EllipsisMask2_NewAxisMask2) { ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 5})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1, 3, 5})); } } TYPED_TEST(StridedSliceOpTest, EllipsisMask4_NewAxisMask2) { @@ -1158,7 +1178,8 @@ TYPED_TEST(StridedSliceOpTest, EllipsisMask4_NewAxisMask2) { ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 3, 2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); + EXPECT_THAT(m.GetOutput(), + ElementsAreTypedArray({1, 2, 3, 4, 5, 6})); } } TYPED_TEST(StridedSliceOpTest, EllipsisMask2) { @@ -1173,7 +1194,7 @@ TYPED_TEST(StridedSliceOpTest, EllipsisMask2) { ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 5})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1, 3, 5})); } } TYPED_TEST(StridedSliceOpTest, NewAxisMask2) { @@ -1188,7 +1209,7 @@ TYPED_TEST(StridedSliceOpTest, NewAxisMask2) { ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1, 2})); } } TYPED_TEST(StridedSliceOpTest, NewAxisMask1) { @@ -1203,7 +1224,7 @@ TYPED_TEST(StridedSliceOpTest, NewAxisMask1) { ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 1, 2})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 7, 8})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1, 2, 7, 8})); } } TYPED_TEST(StridedSliceOpTest, NoInfiniteLoop) { @@ -1229,7 +1250,7 @@ TYPED_TEST(StridedSliceOpTest, MinusThreeMinusFourMinusOne) { constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({2})); } } TYPED_TEST(StridedSliceOpTest, MinusFourMinusThreeOne) { @@ -1243,7 +1264,7 @@ TYPED_TEST(StridedSliceOpTest, MinusFourMinusThreeOne) { constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({1})); } } TYPED_TEST(StridedSliceOpTest, OneOneOne) { @@ -1268,7 +1289,7 @@ TYPED_TEST(StridedSliceOpTest, OneOneOneShrinkAxis) { {1}, 0, 0, 0, 0, 1, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), IsEmpty()); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreTypedArray({2})); } } TYPED_TEST(StridedSliceOpTest, OneOneOneShrinkAxisOOB) { @@ -1318,7 +1339,8 @@ TYPED_TEST(StridedSliceOpTest, NegEndMask) { 0, constant_tensors); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 2, 1, 6, 5, 4})); + EXPECT_THAT(m.GetOutput(), + ElementsAreTypedArray({3, 2, 1, 6, 5, 4})); } } TYPED_TEST(StridedSliceOpTest, NoopOffset) { @@ -1326,7 +1348,8 @@ TYPED_TEST(StridedSliceOpTest, NoopOffset) { {0, -1}, {2, -3}, {1, -1}, 0, 0b10, 0, 0, 0); ASSERT_EQ(m.Invoke(), kTfLiteOk); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3})); - EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 2, 1, 6, 5, 4})); + EXPECT_THAT(m.GetOutput(), + ElementsAreTypedArray({3, 2, 1, 6, 5, 4})); } } // namespace } // namespace tflite diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h index ec7d799e202512..8cc248a36ed7f9 100644 --- a/tensorflow/lite/kernels/test_util.h +++ b/tensorflow/lite/kernels/test_util.h @@ -560,17 +560,15 @@ class SingleOpModel { void SignedSymmetricQuantizeAndPopulate(int index, const std::vector& data) { - std::vector q = QuantizeTensor(index, data); - PopulateTensor(index, /*offset=*/0, q.data(), q.data() + q.size()); - } - - void SignedSymmetricQuantizeAndPopulate4Bit(int index, - const std::vector& data) { TfLiteTensor* t = interpreter_->tensor(index); - t->type = kTfLiteInt4; - std::vector q = - Quantize(data, t->params.scale, t->params.zero_point, t->type); - PopulateTensor4bit(index, /*offset=*/0, q.data(), q.data() + q.size()); + if (t->type == kTfLiteInt4) { + std::vector q = Quantize(data, t->params.scale, + t->params.zero_point, t->type); + PopulateTensor4bit(index, /*offset=*/0, q.data(), q.data() + q.size()); + } else { + std::vector q = QuantizeTensor(index, data); + PopulateTensor(index, /*offset=*/0, q.data(), q.data() + q.size()); + } } // Quantize and populate data for filter with per channel quantization. diff --git a/tensorflow/lite/kernels/transpose.cc b/tensorflow/lite/kernels/transpose.cc index 8805bf87c459e1..f0ebfb79f22f3f 100644 --- a/tensorflow/lite/kernels/transpose.cc +++ b/tensorflow/lite/kernels/transpose.cc @@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/lite/kernels/internal/reference/transpose.h" + #include #include "tensorflow/lite/core/c/common.h" -#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/types.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -25,12 +26,6 @@ namespace ops { namespace builtin { namespace transpose { -// This file has two implementations of Transpose. -enum KernelType { - kReference, - kGenericOptimized, -}; - struct TransposeContext { TransposeContext(TfLiteContext* context, TfLiteNode* node) { input = GetInput(context, node, 0); @@ -89,7 +84,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return ResizeOutputTensor(context, &op_context); } -template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TransposeContext op_context(context, node); @@ -119,11 +113,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { switch (op_context.input->type) { case kTfLiteFloat32: case kTfLiteInt32: - if (kernel_type == kGenericOptimized) { - TF_LITE_TRANSPOSE(optimized_ops, int32_t); - } else { - TF_LITE_TRANSPOSE(reference_ops, int32_t); - } + TF_LITE_TRANSPOSE(reference_ops, int32_t); break; case kTfLiteBool: if (sizeof(bool) != 1) { @@ -133,18 +123,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { [[fallthrough]]; case kTfLiteUInt8: case kTfLiteInt8: - if (kernel_type == kGenericOptimized) { - TF_LITE_TRANSPOSE(optimized_ops, int8_t); - } else { - TF_LITE_TRANSPOSE(reference_ops, int8_t); - } + TF_LITE_TRANSPOSE(reference_ops, int8_t); break; case kTfLiteInt16: - if (kernel_type == kGenericOptimized) { - TF_LITE_TRANSPOSE(optimized_ops, int16_t); - } else { - TF_LITE_TRANSPOSE(reference_ops, int16_t); - } + TF_LITE_TRANSPOSE(reference_ops, int16_t); break; case kTfLiteInt64: TF_LITE_TRANSPOSE(reference_ops, int64_t); @@ -164,19 +146,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteRegistration* Register_TRANSPOSE_REF() { static TfLiteRegistration r = {nullptr, nullptr, transpose::Prepare, - transpose::Eval}; + transpose::Eval}; return &r; } -TfLiteRegistration* Register_TRANSPOSE_GENERIC_OPTIMIZED() { - static TfLiteRegistration r = {nullptr, nullptr, transpose::Prepare, - transpose::Eval}; - return &r; -} - -TfLiteRegistration* Register_TRANSPOSE() { - return Register_TRANSPOSE_GENERIC_OPTIMIZED(); -} +TfLiteRegistration* Register_TRANSPOSE() { return Register_TRANSPOSE_REF(); } } // namespace builtin } // namespace ops diff --git a/tensorflow/lite/kernels/transpose_conv.cc b/tensorflow/lite/kernels/transpose_conv.cc index 7d7c9a410ef451..0dd89386e0b23b 100644 --- a/tensorflow/lite/kernels/transpose_conv.cc +++ b/tensorflow/lite/kernels/transpose_conv.cc @@ -477,14 +477,25 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { scaling_factors_size)); } - const auto* affine_quantization = - reinterpret_cast( - weights->quantization.params); + auto* affine_quantization = reinterpret_cast( + weights->quantization.params); TF_LITE_ENSURE(context, affine_quantization); TF_LITE_ENSURE(context, affine_quantization->scale); - TF_LITE_ENSURE_EQ( - context, affine_quantization->scale->size, - weights->dims->data[affine_quantization->quantized_dimension]); + + const int channels_out = + weights->dims->data[affine_quantization->quantized_dimension]; + if (affine_quantization->scale->size != channels_out) { + TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size, 1); + TfLiteFloatArrayFree(affine_quantization->scale); + affine_quantization->scale = TfLiteFloatArrayCreate(channels_out); + for (int i = 0; i < channels_out; ++i) { + affine_quantization->scale->data[i] = weights->params.scale; + } + } else { + TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size, + channels_out); + } + node->temporaries->data[data->input_offset_index] = data->input_offset_id; TfLiteTensor* input_offsets; TF_LITE_ENSURE_OK(context, diff --git a/tensorflow/lite/kernels/variants/BUILD b/tensorflow/lite/kernels/variants/BUILD index 531fc8bfe0f6eb..e9d462cdbaf2ae 100644 --- a/tensorflow/lite/kernels/variants/BUILD +++ b/tensorflow/lite/kernels/variants/BUILD @@ -291,13 +291,17 @@ cc_test( deps = [ ":list_ops_subgraph_test_util", ":tensor_array", + "//tensorflow/lite:array", "//tensorflow/lite:interpreter_test_util", "//tensorflow/lite:util", "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/c:common", "//tensorflow/lite/kernels:kernel_util", + "//tensorflow/lite/kernels:op_macros", "//tensorflow/lite/kernels:test_util", + "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/lite/kernels/variants/list_ops_subgraph_test.cc b/tensorflow/lite/kernels/variants/list_ops_subgraph_test.cc index ee12e45ea3871d..cb5491e9ca2c2f 100644 --- a/tensorflow/lite/kernels/variants/list_ops_subgraph_test.cc +++ b/tensorflow/lite/kernels/variants/list_ops_subgraph_test.cc @@ -18,10 +18,14 @@ limitations under the License. #include #include +#include "absl/types/span.h" +#include "tensorflow/lite/array.h" #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/interpreter_test_util.h" +#include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/kernel_util.h" +#include "tensorflow/lite/kernels/op_macros.h" #include "tensorflow/lite/kernels/test_util.h" #include "tensorflow/lite/kernels/variants/list_ops_subgraph_test_util.h" #include "tensorflow/lite/kernels/variants/tensor_array.h" diff --git a/tensorflow/lite/kernels/variants/list_ops_subgraph_test_util.cc b/tensorflow/lite/kernels/variants/list_ops_subgraph_test_util.cc index 536e1446588a78..de05c3625828aa 100644 --- a/tensorflow/lite/kernels/variants/list_ops_subgraph_test_util.cc +++ b/tensorflow/lite/kernels/variants/list_ops_subgraph_test_util.cc @@ -22,11 +22,13 @@ limitations under the License. #include #include +#include #include "absl/types/span.h" #include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/c/builtin_op_data.h" +#include "tensorflow/lite/core/subgraph.h" #include "tensorflow/lite/kernels/builtin_op_kernels.h" #include "tensorflow/lite/kernels/op_macros.h" #include "tensorflow/lite/kernels/subgraph_test_util.h" diff --git a/tensorflow/lite/kernels/variants/list_ops_util.cc b/tensorflow/lite/kernels/variants/list_ops_util.cc index 447e09e952669d..4df9bd9bae4388 100644 --- a/tensorflow/lite/kernels/variants/list_ops_util.cc +++ b/tensorflow/lite/kernels/variants/list_ops_util.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/kernels/variants/tensor_array.h" -#include "tensorflow/lite/util.h" namespace tflite { namespace variants { diff --git a/tensorflow/lite/minimal_logging.cc b/tensorflow/lite/minimal_logging.cc index bdcec47e779359..7b5e4f6245a567 100644 --- a/tensorflow/lite/minimal_logging.cc +++ b/tensorflow/lite/minimal_logging.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include "tensorflow/lite/logger.h" + namespace tflite { namespace logging_internal { diff --git a/tensorflow/lite/model_flex_test.cc b/tensorflow/lite/model_flex_test.cc index 987dcc4f234eac..c2257a6e393b83 100644 --- a/tensorflow/lite/model_flex_test.cc +++ b/tensorflow/lite/model_flex_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/core/kernels/register.h" #include "tensorflow/lite/core/model_builder.h" -#include "tensorflow/lite/testing/util.h" namespace tflite { diff --git a/tensorflow/lite/model_xnnpack_test.cc b/tensorflow/lite/model_xnnpack_test.cc index 64e8104cb9874d..740518dc05cf54 100644 --- a/tensorflow/lite/model_xnnpack_test.cc +++ b/tensorflow/lite/model_xnnpack_test.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/lite/core/kernels/register.h" #include "tensorflow/lite/core/macros.h" #include "tensorflow/lite/core/model_builder.h" -#include "tensorflow/lite/string_type.h" #include "tensorflow/lite/util.h" namespace tflite { diff --git a/tensorflow/lite/mutable_op_resolver_test.cc b/tensorflow/lite/mutable_op_resolver_test.cc index 8622579a3c8aa3..6a76f09575e0bf 100644 --- a/tensorflow/lite/mutable_op_resolver_test.cc +++ b/tensorflow/lite/mutable_op_resolver_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/schema/schema_generated.h" -#include "tensorflow/lite/testing/util.h" namespace tflite { namespace { diff --git a/tensorflow/lite/profiling/proto/BUILD b/tensorflow/lite/profiling/proto/BUILD index 907ac4df3e0e3b..4ce67e6947d0bf 100644 --- a/tensorflow/lite/profiling/proto/BUILD +++ b/tensorflow/lite/profiling/proto/BUILD @@ -42,14 +42,12 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "profiling_info_py_pb2", -# api_version = 2, # compatible_with = get_compatible_with_portable(), # deps = [":profiling_info_proto"], # ) # # py_proto_library( # name = "model_runtime_info_py_pb2", -# api_version = 2, # compatible_with = get_compatible_with_portable(), # deps = [":model_runtime_info_proto"], # ) diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD index c5a4ba27639458..d6f626a28594e2 100644 --- a/tensorflow/lite/python/BUILD +++ b/tensorflow/lite/python/BUILD @@ -490,7 +490,6 @@ pytype_strict_library( "//tensorflow/lite/tools:flatbuffer_utils", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/platform:resource_loader", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:tf_export", ], diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py index f4206cb68c932d..1fde13680a4971 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -14,13 +14,9 @@ # ============================================================================== """Converts a frozen graph into a TFLite FlatBuffer.""" -import distutils.spawn + import enum import hashlib -import os as _os -import platform as _platform -import subprocess as _subprocess -import tempfile as _tempfile from typing import Optional import warnings @@ -41,7 +37,6 @@ from tensorflow.lite.tools import flatbuffer_utils from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape -from tensorflow.python.platform import resource_loader as _resource_loader from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export as _tf_export @@ -143,18 +138,6 @@ def convert_inference_tf_type_to_tflite_type( return tflite_type -# Find the deprecated conversion binary using the resource loader if using from -# bazel, otherwise we are in a pip where console_scripts already has the tool. -if lite_constants.EXPERIMENTAL_USE_TOCO_API_DIRECTLY: - _deprecated_conversion_binary = "" -else: - _deprecated_conversion_binary = _resource_loader.get_path_to_datafile( - "../toco/python/toco_from_protos" - ) - if not _os.path.exists(_deprecated_conversion_binary): - _deprecated_conversion_binary = "toco_from_protos" - - def _try_convert_to_unicode(output): if output is None: return "" @@ -315,7 +298,6 @@ def convert( conversion_flags: _conversion_flags_pb2.ConverterFlags, input_data_str: Optional[str] = None, debug_info_str: Optional[str] = None, - enable_mlir_converter: bool = True, ): """Converts `input_data_str` to a TFLite model. @@ -327,178 +309,45 @@ def convert( it can be hlo text or proto) debug_info_str: Serialized `GraphDebugInfo` proto describing logging information. - enable_mlir_converter: Enables MLIR-based conversion. Returns: Converted model in serialized form (e.g. a TFLITE model is common). Raises: ConverterError: When conversion fails in TFLiteConverter, usually due to ops not being supported. - RuntimeError: When conversion fails, an exception is raised with the error - message embedded. """ - # Historically, deprecated conversion failures would trigger a crash, so we - # attempt to run the converter out-of-process. The current MLIR conversion - # pipeline surfaces errors instead, and can be safely run in-process. - if enable_mlir_converter or not _deprecated_conversion_binary: - try: - return wrap_converter.wrapped_convert( - model_flags.SerializeToString(), - conversion_flags.SerializeToString(), - input_data_str, - debug_info_str, - enable_mlir_converter, - ) - except Exception as e: - converter_error = ConverterError(str(e)) - - for error_data in _metrics_wrapper.retrieve_collected_errors(): - converter_error.append_error(error_data) - # Seldom we encounter the case where an unsupported - # `StatefulPartitionedCallOp` is not inlined and remains in the final - # IR. If this occurs we can set `guarantee_all_funcs_one_use` and retry. - # This makes the converter copy functions definitions called by - # multiple StatefulPartitionedCall, thus allowing them to be properly - # inlined. - if ( - error_data.error_code - == converter_error_data_pb2.ConverterErrorData.ERROR_STATEFUL_PARTITIONED_CALL_IN_FINAL_IR - and not conversion_flags.guarantee_all_funcs_one_use - ): - conversion_flags.guarantee_all_funcs_one_use = True - return convert( - model_flags, - conversion_flags, - input_data_str, - debug_info_str, - enable_mlir_converter, - ) - raise converter_error - - return _run_deprecated_conversion_binary( - model_flags.SerializeToString(), - conversion_flags.SerializeToString(), - input_data_str, - debug_info_str, - ) - -@convert_phase( - Component.CONVERT_TF_TO_TFLITE_MODEL, - SubComponent.CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER, -) -def _run_deprecated_conversion_binary( - model_flags_str, conversion_flags_str, input_data_str, debug_info_str=None -): - """Convert `input_data_str` using deprecated conversion binary. - - Args: - model_flags_str: Serialized proto describing model properties, see - `model_flags.proto`. - conversion_flags_str: Serialized proto describing TFLite converter - properties, see `compiler/mlir/lite/converter_flags.proto`. - input_data_str: Input data in serialized form (e.g. a graphdef is common) - debug_info_str: Serialized `GraphDebugInfo` proto describing logging - information. (default None) - - Returns: - Converted model in serialized form (e.g. a TFLITE model is common). - Raises: - ConverterError: When cannot find the deprecated conversion binary. - RuntimeError: When conversion fails, an exception is raised with the error - message embedded. - """ - if distutils.spawn.find_executable(_deprecated_conversion_binary) is None: - raise ConverterError("""Could not find `toco_from_protos` binary, make sure -your virtualenv bin directory or pip local bin directory is in your path. -In particular, if you have installed TensorFlow with --user, make sure you -add the install directory to your path. - -For example: -Linux: export PATH=$PATH:~/.local/bin/ -Mac: export PATH=$PATH:~/Library/Python//bin - -Alternative, use virtualenv.""") - # Windows and TemporaryFile are not that useful together, - # since you cannot have two readers/writers. So we have to - # make the temporaries and close and delete them explicitly. - conversion_filename: str = None - model_filename: str = None - input_filename: str = None - output_filename: str = None try: - # Build all input files - with ( - _tempfile.NamedTemporaryFile(delete=False) as fp_conversion, - _tempfile.NamedTemporaryFile(delete=False) as fp_model, - _tempfile.NamedTemporaryFile(delete=False) as fp_input, - _tempfile.NamedTemporaryFile(delete=False) as fp_debug, - ): - conversion_filename = fp_conversion.name - input_filename = fp_input.name - model_filename = fp_model.name - debug_filename = fp_debug.name - - fp_model.write(model_flags_str) - fp_conversion.write(conversion_flags_str) - fp_input.write(input_data_str) - debug_info_str = debug_info_str if debug_info_str else "" - # if debug_info_str contains a "string value", then the call to - # fp_debug.write(debug_info_str) will fail with the following error - # - # TypeError: a bytes-like object is required, not 'str' - # - # Some of the subtests within the "convert_test" unit-test fail - # with the error shown above. So watch out for that scenario and - # convert debug_info_str to bytes where needed - if not isinstance(debug_info_str, bytes): - fp_debug.write(debug_info_str.encode("utf-8")) - else: - fp_debug.write(debug_info_str) - - # Reserve an output file - with _tempfile.NamedTemporaryFile(delete=False) as fp: - output_filename = fp.name - - # Run - cmd = [ - _deprecated_conversion_binary, - model_filename, - conversion_filename, - input_filename, - output_filename, - "--debug_proto_file={}".format(debug_filename), - ] - cmdline = " ".join(cmd) - is_windows = _platform.system() == "Windows" - proc = _subprocess.Popen( - cmdline, - shell=True, - stdout=_subprocess.PIPE, - stderr=_subprocess.STDOUT, - close_fds=not is_windows, + return wrap_converter.wrapped_convert( + model_flags.SerializeToString(), + conversion_flags.SerializeToString(), + input_data_str, + debug_info_str, ) - stdout, stderr = proc.communicate() - exitcode = proc.returncode - if exitcode == 0: - with open(output_filename, "rb") as fp: - return fp.read() - else: - stdout = _try_convert_to_unicode(stdout) - stderr = _try_convert_to_unicode(stderr) - raise ConverterError("See console for info.\n%s\n%s\n" % (stdout, stderr)) - finally: - # Must manually cleanup files. - for filename in [ - conversion_filename, - input_filename, - model_filename, - output_filename, - ]: - try: - _os.unlink(filename) - except (OSError, TypeError): - pass + except Exception as e: + converter_error = ConverterError(str(e)) + + for error_data in _metrics_wrapper.retrieve_collected_errors(): + converter_error.append_error(error_data) + # Seldom we encounter the case where an unsupported + # `StatefulPartitionedCallOp` is not inlined and remains in the final + # IR. If this occurs we can set `guarantee_all_funcs_one_use` and retry. + # This makes the converter copy functions definitions called by + # multiple StatefulPartitionedCall, thus allowing them to be properly + # inlined. + if ( + error_data.error_code + == converter_error_data_pb2.ConverterErrorData.ERROR_STATEFUL_PARTITIONED_CALL_IN_FINAL_IR + and not conversion_flags.guarantee_all_funcs_one_use + ): + conversion_flags.guarantee_all_funcs_one_use = True + return convert( + model_flags, + conversion_flags, + input_data_str, + debug_info_str, + ) + raise converter_error def build_model_flags( @@ -909,7 +758,6 @@ def convert_graphdef_with_arrays( """ model_flags = build_model_flags(**kwargs) conversion_flags = build_conversion_flags(**kwargs) - enable_mlir_converter = kwargs.get("enable_mlir_converter", True) quantized_input_stats = kwargs.get("quantized_input_stats", None) for idx, (name, shape) in enumerate(input_arrays_with_shape): @@ -940,7 +788,6 @@ def convert_graphdef_with_arrays( conversion_flags, input_data.SerializeToString(), debug_info_str=None, - enable_mlir_converter=enable_mlir_converter, ) return data @@ -972,7 +819,6 @@ def convert_graphdef(input_data, input_tensors, output_tensors, **kwargs): conversion_flags = build_conversion_flags(**kwargs) saved_model_dir = kwargs.get("saved_model_dir", None) input_shapes = kwargs.get("input_shapes", None) - enable_mlir_converter = kwargs.get("enable_mlir_converter", True) quantized_input_stats = kwargs.get("quantized_input_stats", None) debug_info = kwargs.get("debug_info", None) @@ -1030,7 +876,6 @@ def convert_graphdef(input_data, input_tensors, output_tensors, **kwargs): conversion_flags, input_data.SerializeToString(), debug_info_str=debug_info.SerializeToString() if debug_info else None, - enable_mlir_converter=enable_mlir_converter, ) return data @@ -1047,7 +892,6 @@ def convert_saved_model(**kwargs): conversion_flags, input_data_str=None, debug_info_str=None, - enable_mlir_converter=True, ) return data @@ -1075,7 +919,6 @@ def convert_jax_hlo(input_content, input_names, is_proto_format, **kwargs): conversion_flags, input_data_str=input_content, debug_info_str=None, - enable_mlir_converter=True, ) return data @@ -1103,7 +946,6 @@ def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs): Raises: Defined in `convert`. """ - kwargs["enable_mlir_converter"] = kwargs.get("enable_mlir_converter", False) return convert_graphdef( input_data, input_tensors, output_tensors, *args, **kwargs ) diff --git a/tensorflow/lite/python/convert_test.py b/tensorflow/lite/python/convert_test.py index 91b3e810328b16..aad267ed4384ea 100644 --- a/tensorflow/lite/python/convert_test.py +++ b/tensorflow/lite/python/convert_test.py @@ -41,7 +41,6 @@ def _mock_wrapped_convert( conversion_flags_str="", unused_input_data_str="", unused_debug_info_str="", - unused_enable_mlir_converter=True, ): # Simulate the converter throwing and error when # `guarantee_all_funcs_one_use` is not set. @@ -76,32 +75,6 @@ def testBasic(self): ) self.assertTrue(tflite_model) - @mock.patch.object( - convert, - "_deprecated_conversion_binary", - new="tocos_from_proto", - ) - @mock.patch.object( - convert, - "_run_deprecated_conversion_binary", - autospec=True, - ) - def testBasicDeprecatedConversionBinary(self, mock_func): - with ops.Graph().as_default(): - in_tensor = array_ops.placeholder( - shape=[1, 16, 16, 3], dtype=dtypes.float32 - ) - out_tensor = in_tensor + in_tensor - sess = session.Session() - - convert.convert_graphdef( - sess.graph_def, - input_tensors=[in_tensor], - output_tensors=[out_tensor], - enable_mlir_converter=False, - ) - mock_func.assert_called_once() - @mock.patch.object( convert.wrap_converter, "wrapped_convert", new=_mock_wrapped_convert ) @@ -125,7 +98,6 @@ def testConversionStatefulPartitionRetry(self, mock_convert): sess.graph_def, input_tensors=[in_tensor], output_tensors=[out_tensor], - enable_mlir_converter=True, guarantee_all_funcs_one_use=False, ) self.assertTrue(str(model, encoding="utf-8"), "A model") @@ -164,7 +136,6 @@ def testGraphDefBasic(self): output_arrays=["add"], control_output_arrays=None, inference_type=dtypes.float32, - enable_mlir_converter=False, ) self.assertTrue(tflite_model) @@ -209,7 +180,6 @@ def testGraphDefQuantization(self): control_output_arrays=None, inference_type=dtypes.uint8, quantized_input_stats=[(0.0, 1.0), (0.0, 1.0)], - enable_mlir_converter=False, ) self.assertTrue(tflite_model) @@ -263,7 +233,6 @@ def testGraphDefQuantizationInvalid(self): output_arrays=["output"], control_output_arrays=None, inference_type=dtypes.uint8, - enable_mlir_converter=False, ) self.assertEqual( "The `quantized_input_stats` flag must be defined when either " diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc index 37ddd6bb416cb9..24d90fa9a7a89a 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -290,6 +290,9 @@ PyObject* InterpreterWrapper::AllocateTensors(int subgraph_index) { if (subgraph_index == kUndeterminedSubgraphIndex) { TFLITE_PY_CHECK(interpreter_->AllocateTensors()); } else { + // We don't check the return of this call. Failing is a real possiblity as + // the default XNNPack delegate may fail to apply on certain graphs. + interpreter_->ApplyLazyDelegateProviders(); TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index); TFLITE_PY_CHECK(interpreter_->subgraph(subgraph_index)->AllocateTensors()); } diff --git a/tensorflow/lite/python/kernel_tests/signal/BUILD b/tensorflow/lite/python/kernel_tests/signal/BUILD index a6128e6be32f54..87b1deb32f4b35 100644 --- a/tensorflow/lite/python/kernel_tests/signal/BUILD +++ b/tensorflow/lite/python/kernel_tests/signal/BUILD @@ -3,7 +3,6 @@ load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//tensorflow:internal"], licenses = ["notice"], ) diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 87cb2bdf8969f7..5f0d1e14632f89 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -114,8 +114,8 @@ class Optimize(enum.Enum): The default optimization strategy that enables post-training quantization. The type of post-training quantization that will be used is dependent on the other converter options supplied. Refer to the - [documentation](/lite/performance/post_training_quantization) for further - information on the types available and how to use them. + [documentation](https://ai.google.dev/edge/litert/models/post_training_quantization) + for further information on the types available and how to use them. OPTIMIZE_FOR_SIZE Deprecated. Does the same as DEFAULT. @@ -792,7 +792,6 @@ def _get_base_converter_args(self): "allow_custom_ops": self.allow_custom_ops, "debug_info": self._debug_info, "target_ops": self.target_spec.supported_ops, - "enable_mlir_converter": self.experimental_new_converter, "select_user_tf_ops": self.target_spec.experimental_select_user_tf_ops, "supported_backends": self.target_spec.experimental_supported_backends, "unfold_batchmatmul": self.unfold_batchmatmul, diff --git a/tensorflow/lite/python/lite_constants.py b/tensorflow/lite/python/lite_constants.py index 4fc63f79f8c0c5..0d33eaeef14dd1 100644 --- a/tensorflow/lite/python/lite_constants.py +++ b/tensorflow/lite/python/lite_constants.py @@ -60,12 +60,6 @@ _tf_export(v1=["lite.constants.GRAPHVIZ_DOT"]).export_constant( __name__, "GRAPHVIZ_DOT") -# Currently the default mode of operation is to shell to another python process -# to protect against crashes. However, it breaks some dependent targets because -# it forces us to depend on an external py_binary. The experimental API doesn't -# have that drawback. -EXPERIMENTAL_USE_TOCO_API_DIRECTLY = False - _allowed_symbols = [ "FLOAT", @@ -85,6 +79,5 @@ "KERAS", "JAX", "PYTORCH", - "EXPERIMENTAL_USE_TOCO_API_DIRECTLY", ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/lite/python/metrics/BUILD b/tensorflow/lite/python/metrics/BUILD index fdd6eb890fd571..14ab0aca8c4b52 100644 --- a/tensorflow/lite/python/metrics/BUILD +++ b/tensorflow/lite/python/metrics/BUILD @@ -16,7 +16,7 @@ cc_library( ), hdrs = ["wrapper/metrics_wrapper.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:private"], + visibility = ["//tensorflow/python:__pkg__"], deps = [ "//third_party/python_runtime:headers", ] + if_portable( @@ -38,6 +38,7 @@ pybind_extension( ], visibility = [ "__subpackages__", + "//tensorflow/python:__pkg__", "//tensorflow/tools/pip_package:__subpackages__", ], deps = [ @@ -127,6 +128,7 @@ py_strict_test( if_true = "metrics_portable_test.py", ), python_version = "PY3", + tags = ["notap"], # TODO(b/373657707): Remove once we debug the failure. deps = [ ":metrics", "@absl_py//absl/testing:parameterized", diff --git a/tensorflow/lite/python/metrics/metrics_nonportable_test.py b/tensorflow/lite/python/metrics/metrics_nonportable_test.py index ff12a822602c8c..f47482b658d0c6 100644 --- a/tensorflow/lite/python/metrics/metrics_nonportable_test.py +++ b/tensorflow/lite/python/metrics/metrics_nonportable_test.py @@ -154,7 +154,6 @@ def test_conversion_from_constructor_success(self): mock.call.increase_counter_converter_success(), mock.call.export_metrics(), mock.call.set_converter_param('input_format', '1'), - mock.call.set_converter_param('enable_mlir_converter', 'True'), mock.call.set_converter_param('allow_custom_ops', 'False'), mock.call.set_converter_param('api_version', '1'), ], any_order=True) # pyformat: disable @@ -275,7 +274,6 @@ def test_conversion_from_saved_model(self): mock.call.increase_counter_converter_success(), mock.call.set_converter_latency(2000), mock.call.export_metrics(), - mock.call.set_converter_param('enable_mlir_converter', 'True'), ], any_order=True) # pyformat: disable def disable_converter_counter_metrics(self, tflite_metrics): @@ -475,13 +473,11 @@ def create_graph_with_custom_add(opname='CustomAdd'): exported_error = metrics._gauge_conversion_errors.get_cell( 'CONVERT_TF_TO_TFLITE_MODEL', 'CONVERT_SAVED_MODEL', 'tf.CustomAdd', 'ERROR_NEEDS_CUSTOM_OPS').value() - self.assertContainsSubsequence( - exported_error, + self.assertIn( "'tf.CustomAdd' op is neither a custom op nor a flex op\n", + exported_error, ) - self.assertContainsSubsequence( - exported_error, 'Error code: ERROR_NEEDS_CUSTOM_OPS' - ) + self.assertIn('Error code: ERROR_NEEDS_CUSTOM_OPS', exported_error) def test_unsupported_control_flow_v1(self): filename = resource_loader.get_path_to_datafile( diff --git a/tensorflow/lite/schema/BUILD b/tensorflow/lite/schema/BUILD index fbd167c1ac5625..a3a922051fb1b9 100644 --- a/tensorflow/lite/schema/BUILD +++ b/tensorflow/lite/schema/BUILD @@ -111,9 +111,10 @@ exports_files([ # srcs = ["//tensorflow/compiler/mlir/lite/schema:schema.fbs"], # compatible_with = get_compatible_with_portable(), # ) -# copybara:uncomment_end_and_comment_begin +# copybara:uncomment_end(google-only) + cc_library( - name = "schema_fbs", + name = "schema_fbs", # copybara:comment_replace name = "schema_fbs_for_oss", hdrs = [ ":schema_generated.h", "//tensorflow/compiler/mlir/lite/schema:schema_generated.h", @@ -123,7 +124,6 @@ cc_library( "@flatbuffers//:runtime_cc", ], ) -# copybara:comment_end # Generic schema for flatbuffer converter (but with mutable makes bigger). flatbuffer_cc_library( diff --git a/tensorflow/lite/testing/matchers.h b/tensorflow/lite/testing/matchers.h index e32576c941cac2..17646ffb811eb4 100644 --- a/tensorflow/lite/testing/matchers.h +++ b/tensorflow/lite/testing/matchers.h @@ -257,6 +257,7 @@ struct SimpleConstTensor : public TfLiteTensor { std::memcpy(dims->data, shape.data(), shape.size() * sizeof(int)); data = {.data = buf.data()}; bytes = buf.size() * sizeof(T); + sparsity = nullptr; } ~SimpleConstTensor() { TfLiteIntArrayFree(dims); } }; diff --git a/tensorflow/lite/toco/BUILD b/tensorflow/lite/toco/BUILD index 19644b6678110b..2c2e5e41081a9c 100644 --- a/tensorflow/lite/toco/BUILD +++ b/tensorflow/lite/toco/BUILD @@ -519,21 +519,18 @@ tf_cc_test( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "model_flags_proto_py", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":model_flags_proto"], # ) # # py_proto_library( # name = "toco_flags_proto_py", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":toco_flags_proto"], # ) # # py_proto_library( # name = "types_proto_py", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":toco_flags_proto"], # ) diff --git a/tensorflow/lite/toco/graph_transformations/tests/fuse_binary_into_following_affine_test.cc b/tensorflow/lite/toco/graph_transformations/tests/fuse_binary_into_following_affine_test.cc index 518a6832066a3c..a66ad270a3d347 100644 --- a/tensorflow/lite/toco/graph_transformations/tests/fuse_binary_into_following_affine_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/fuse_binary_into_following_affine_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" -#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/tests/fuse_binary_into_preceding_affine_test.cc b/tensorflow/lite/toco/graph_transformations/tests/fuse_binary_into_preceding_affine_test.cc index 314521b6ab2711..35888667d4b3c9 100644 --- a/tensorflow/lite/toco/graph_transformations/tests/fuse_binary_into_preceding_affine_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/fuse_binary_into_preceding_affine_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" -#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/tests/identify_l2_normalization_test.cc b/tensorflow/lite/toco/graph_transformations/tests/identify_l2_normalization_test.cc index 4c55b7d6dcbb06..c21118f4df7e2e 100644 --- a/tensorflow/lite/toco/graph_transformations/tests/identify_l2_normalization_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/identify_l2_normalization_test.cc @@ -16,10 +16,8 @@ limitations under the License. #include #include -#include "absl/memory/memory.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" -#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/tests/identify_l2_pool_test.cc b/tensorflow/lite/toco/graph_transformations/tests/identify_l2_pool_test.cc index bfb2acf3aa6fe8..ab487b4cf3bb28 100644 --- a/tensorflow/lite/toco/graph_transformations/tests/identify_l2_pool_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/identify_l2_pool_test.cc @@ -16,10 +16,8 @@ limitations under the License. #include #include -#include "absl/memory/memory.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" -#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/tests/lstm_utils_test.cc b/tensorflow/lite/toco/graph_transformations/tests/lstm_utils_test.cc index e18b2a8a486423..ae9006af978237 100644 --- a/tensorflow/lite/toco/graph_transformations/tests/lstm_utils_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/lstm_utils_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include #include #include "tensorflow/lite/toco/model.h" -#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/tests/remove_successive_transpose_test.cc b/tensorflow/lite/toco/graph_transformations/tests/remove_successive_transpose_test.cc index 746a7579d41a57..561ca830fcb34b 100644 --- a/tensorflow/lite/toco/graph_transformations/tests/remove_successive_transpose_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/remove_successive_transpose_test.cc @@ -16,11 +16,9 @@ limitations under the License. #include #include -#include #include #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" -#include "tensorflow/lite/toco/tooling_util.h" namespace { diff --git a/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc b/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc index 4ecd9c992bb058..405c79b8d52c40 100644 --- a/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" -#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc b/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc index 136c14cad9b834..af26eef7ff6922 100644 --- a/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc @@ -19,10 +19,8 @@ limitations under the License. #include #include -#include "absl/memory/memory.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" -#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/tests/unpack_quantize_test.cc b/tensorflow/lite/toco/graph_transformations/tests/unpack_quantize_test.cc index b7302051043052..3a22849b949955 100755 --- a/tensorflow/lite/toco/graph_transformations/tests/unpack_quantize_test.cc +++ b/tensorflow/lite/toco/graph_transformations/tests/unpack_quantize_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" -#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/lite/toco/logging/BUILD b/tensorflow/lite/toco/logging/BUILD index 2b9f9205f86022..06c83facb5f977 100644 --- a/tensorflow/lite/toco/logging/BUILD +++ b/tensorflow/lite/toco/logging/BUILD @@ -107,7 +107,6 @@ py_strict_test( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "toco_conversion_log_proto_py", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":toco_conversion_log_proto"], # ) diff --git a/tensorflow/lite/toco/tflite/BUILD b/tensorflow/lite/toco/tflite/BUILD index 7377ec00d6b666..b24b1944ed3ffe 100644 --- a/tensorflow/lite/toco/tflite/BUILD +++ b/tensorflow/lite/toco/tflite/BUILD @@ -148,7 +148,7 @@ tf_cc_test( "@com_google_absl//absl/log", "@com_google_googletest//:gtest_main", "@flatbuffers", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) diff --git a/tensorflow/lite/toco/tflite/export.cc b/tensorflow/lite/toco/tflite/export.cc index 44223eac63c130..2ac4cd50310636 100644 --- a/tensorflow/lite/toco/tflite/export.cc +++ b/tensorflow/lite/toco/tflite/export.cc @@ -23,6 +23,7 @@ limitations under the License. #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "flatbuffers/string.h" // from @flatbuffers +#include "flatbuffers/vector.h" // from @flatbuffers #include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.h" #include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" #include "tensorflow/core/lib/core/errors.h" diff --git a/tensorflow/lite/toco/tflite/export_test.cc b/tensorflow/lite/toco/tflite/export_test.cc index e6b7e977f0df2e..b1133f60f5c88a 100644 --- a/tensorflow/lite/toco/tflite/export_test.cc +++ b/tensorflow/lite/toco/tflite/export_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/lite/toco/tflite/builtin_operator.h" #include "tensorflow/lite/toco/tflite/operator.h" #include "tensorflow/lite/toco/tflite/types.h" -#include "tsl/protobuf/error_codes.pb.h" namespace toco { namespace tflite { diff --git a/tensorflow/lite/toco/tflite/import_test.cc b/tensorflow/lite/toco/tflite/import_test.cc index 5fe4989663e3e5..b73c673c9199d3 100644 --- a/tensorflow/lite/toco/tflite/import_test.cc +++ b/tensorflow/lite/toco/tflite/import_test.cc @@ -19,7 +19,9 @@ limitations under the License. #include #include +#include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/vector.h" // from @flatbuffers #include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/toco/model.h" diff --git a/tensorflow/lite/toco/tflite/types_test.cc b/tensorflow/lite/toco/tflite/types_test.cc index d5ac84d9769768..5ed493c2ac066f 100644 --- a/tensorflow/lite/toco/tflite/types_test.cc +++ b/tensorflow/lite/toco/tflite/types_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/vector.h" // from @flatbuffers #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/runtime/types.h" diff --git a/tensorflow/lite/toco/toco_convert.cc b/tensorflow/lite/toco/toco_convert.cc index 5e2d3e3dea0005..f3c0e46e5786db 100644 --- a/tensorflow/lite/toco/toco_convert.cc +++ b/tensorflow/lite/toco/toco_convert.cc @@ -70,11 +70,10 @@ void ReadInputData(const ParsedTocoFlags& parsed_toco_flags, } } // namespace -tensorflow::Status Convert(const std::string& graph_def_contents, - const TocoFlags& toco_flags, - const ModelFlags& model_flags, - std::string* output_file_contents, - int64_t* arithmetic_ops_count = nullptr) { +absl::Status Convert(const std::string& graph_def_contents, + const TocoFlags& toco_flags, const ModelFlags& model_flags, + std::string* output_file_contents, + int64_t* arithmetic_ops_count = nullptr) { std::unique_ptr model = Import(toco_flags, model_flags, graph_def_contents); TF_RETURN_IF_ERROR(TransformWithStatus(toco_flags, model.get())); @@ -86,8 +85,8 @@ tensorflow::Status Convert(const std::string& graph_def_contents, return absl::OkStatus(); } -tensorflow::Status Convert(const ParsedTocoFlags& parsed_toco_flags, - const ParsedModelFlags& parsed_model_flags) { +absl::Status Convert(const ParsedTocoFlags& parsed_toco_flags, + const ParsedModelFlags& parsed_model_flags) { ModelFlags model_flags; ReadModelFlagsFromCommandLineFlags(parsed_model_flags, &model_flags); @@ -105,7 +104,7 @@ tensorflow::Status Convert(const ParsedTocoFlags& parsed_toco_flags, TF_RETURN_IF_ERROR( port::file::SetContents(parsed_toco_flags.output_file.value(), output_file_contents, port::file::Defaults())); - return tensorflow::Status(); + return absl::Status(); } } // namespace toco diff --git a/tensorflow/lite/toco/toco_convert.h b/tensorflow/lite/toco/toco_convert.h index 737b31563fb7db..e77ab87f317696 100644 --- a/tensorflow/lite/toco/toco_convert.h +++ b/tensorflow/lite/toco/toco_convert.h @@ -24,14 +24,13 @@ limitations under the License. namespace toco { -tensorflow::Status Convert(const std::string& graph_def_contents, - const TocoFlags& toco_flags, - const ModelFlags& model_flags, - std::string* output_file_contents, - int64_t* arithmetic_ops_count = nullptr); - -tensorflow::Status Convert(const ParsedTocoFlags& parsed_toco_flags, - const ParsedModelFlags& parsed_model_flags); +absl::Status Convert(const std::string& graph_def_contents, + const TocoFlags& toco_flags, const ModelFlags& model_flags, + std::string* output_file_contents, + int64_t* arithmetic_ops_count = nullptr); + +absl::Status Convert(const ParsedTocoFlags& parsed_toco_flags, + const ParsedModelFlags& parsed_model_flags); } // namespace toco #endif // TENSORFLOW_LITE_TOCO_TOCO_CONVERT_H_ diff --git a/tensorflow/lite/toco/toco_port.cc b/tensorflow/lite/toco/toco_port.cc index f2e5e29f8c266c..941bb4d4f90521 100644 --- a/tensorflow/lite/toco/toco_port.cc +++ b/tensorflow/lite/toco/toco_port.cc @@ -72,10 +72,10 @@ void CheckInitGoogleIsDone(const char* message) { namespace file { // Conversion to our wrapper Status. -tensorflow::Status ToStatus(const absl::Status& uts) { +absl::Status ToStatus(const absl::Status& uts) { if (!uts.ok()) { - return tensorflow::Status(absl::StatusCode(::util::RetrieveErrorCode(uts)), - uts.message()); + return absl::Status(absl::StatusCode(::util::RetrieveErrorCode(uts)), + uts.message()); } return absl::OkStatus(); } @@ -86,7 +86,7 @@ toco::port::file::Options ToOptions(const ::file::Options& options) { return Options(); } -tensorflow::Status Writable(const std::string& filename) { +absl::Status Writable(const std::string& filename) { File* f = nullptr; const auto status = ::file::Open(filename, "w", &f, ::file::Defaults()); if (f) { @@ -95,26 +95,24 @@ tensorflow::Status Writable(const std::string& filename) { return ToStatus(status); } -tensorflow::Status Readable(const std::string& filename, - const file::Options& options) { +absl::Status Readable(const std::string& filename, + const file::Options& options) { return ToStatus(::file::Readable(filename, ::file::Defaults())); } -tensorflow::Status Exists(const std::string& filename, - const file::Options& options) { +absl::Status Exists(const std::string& filename, const file::Options& options) { auto status = ::file::Exists(filename, ::file::Defaults()); return ToStatus(status); } -tensorflow::Status GetContents(const std::string& filename, - std::string* contents, - const file::Options& options) { +absl::Status GetContents(const std::string& filename, std::string* contents, + const file::Options& options) { return ToStatus(::file::GetContents(filename, contents, ::file::Defaults())); } -tensorflow::Status SetContents(const std::string& filename, - const std::string& contents, - const file::Options& options) { +absl::Status SetContents(const std::string& filename, + const std::string& contents, + const file::Options& options) { return ToStatus(::file::SetContents(filename, contents, ::file::Defaults())); } diff --git a/tensorflow/lite/toco/toco_port.h b/tensorflow/lite/toco/toco_port.h index cc3ca93e6e8bb3..553830bd9c5d32 100644 --- a/tensorflow/lite/toco/toco_port.h +++ b/tensorflow/lite/toco/toco_port.h @@ -68,16 +68,14 @@ inline Options Defaults() { Options o; return o; } -tensorflow::Status GetContents(const std::string& filename, - std::string* contents, const Options& options); -tensorflow::Status SetContents(const std::string& filename, - const std::string& contents, - const Options& options); +absl::Status GetContents(const std::string& filename, std::string* contents, + const Options& options); +absl::Status SetContents(const std::string& filename, + const std::string& contents, const Options& options); std::string JoinPath(const std::string& a, const std::string& b); -tensorflow::Status Writable(const std::string& filename); -tensorflow::Status Readable(const std::string& filename, - const Options& options); -tensorflow::Status Exists(const std::string& filename, const Options& options); +absl::Status Writable(const std::string& filename); +absl::Status Readable(const std::string& filename, const Options& options); +absl::Status Exists(const std::string& filename, const Options& options); } // namespace file // Copy `src` string to `dest`. User must ensure `dest` has enough space. diff --git a/tensorflow/lite/toco/toco_tooling.cc b/tensorflow/lite/toco/toco_tooling.cc index d0a28da9381917..5b38d535c7e8ac 100644 --- a/tensorflow/lite/toco/toco_tooling.cc +++ b/tensorflow/lite/toco/toco_tooling.cc @@ -242,8 +242,7 @@ std::unique_ptr Import(const TocoFlags& toco_flags, return model; } -tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags, - Model* model) { +absl::Status TransformWithStatus(const TocoFlags& toco_flags, Model* model) { const FileFormat output_format = toco_flags.output_format(); const IODataType inference_type = toco_flags.inference_type(); @@ -472,9 +471,8 @@ tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags, return absl::OkStatus(); } -tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model, - bool allow_custom_ops, - std::string* output_file_contents) { +absl::Status Export(const TocoFlags& toco_flags, const Model& model, + bool allow_custom_ops, std::string* output_file_contents) { switch (toco_flags.output_format()) { case TENSORFLOW_GRAPHDEF: ExportTensorFlowGraphDef(model, output_file_contents); @@ -508,7 +506,7 @@ tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model, LOG(FATAL) << "Unhandled output_format='" << FileFormat_Name(toco_flags.output_format()) << "'"; } - return tensorflow::Status(); + return absl::Status(); } } // namespace toco diff --git a/tensorflow/lite/toco/toco_tooling.h b/tensorflow/lite/toco/toco_tooling.h index 5577e20a53b7ad..6fe4fb064af1d4 100644 --- a/tensorflow/lite/toco/toco_tooling.h +++ b/tensorflow/lite/toco/toco_tooling.h @@ -31,8 +31,7 @@ std::unique_ptr Import(const TocoFlags& toco_flags, // Transforms a Model. The resulting Model is ready to be passed // to Export with the exact same toco_flags. -tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags, - Model* model); +absl::Status TransformWithStatus(const TocoFlags& toco_flags, Model* model); inline void Transform(const TocoFlags& toco_flags, Model* model) { auto s = TransformWithStatus(toco_flags, model); CHECK(s.ok()) << s.message(); @@ -41,9 +40,8 @@ inline void Transform(const TocoFlags& toco_flags, Model* model) { // Exports the Model, which must be of the 'lowered' form returned by // Transform, to a file of the format given by // toco_flags.output_format(). -tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model, - bool allow_custom_ops, - std::string* output_file_contents); +absl::Status Export(const TocoFlags& toco_flags, const Model& model, + bool allow_custom_ops, std::string* output_file_contents); // This if for backward-compatibility with internal tools. inline void Export(const TocoFlags& toco_flags, const Model& model, diff --git a/tensorflow/lite/toco/tooling_util.h b/tensorflow/lite/toco/tooling_util.h index e37108dcc47c03..b9419f19dbf649 100644 --- a/tensorflow/lite/toco/tooling_util.h +++ b/tensorflow/lite/toco/tooling_util.h @@ -327,7 +327,7 @@ void UseArraysExtraInfo(Model* model, bool quantize_output); // doesn't have enough range to represent the sum of elements, an error is // returned. template -tensorflow::Status NumElements(const std::vector& shape, U* num_elements) { +absl::Status NumElements(const std::vector& shape, U* num_elements) { static_assert( std::numeric_limits::max() <= std::numeric_limits::max(), "vector type exceed capabilities of NumElements"); diff --git a/tensorflow/lite/toco/tooling_util_test.cc b/tensorflow/lite/toco/tooling_util_test.cc index c92e546146d100..f0da510c69540a 100644 --- a/tensorflow/lite/toco/tooling_util_test.cc +++ b/tensorflow/lite/toco/tooling_util_test.cc @@ -105,7 +105,7 @@ static const char kLargeTensorMessage[] = "Tensor shape is too large"; TEST(NumElementsTest, Int) { int count; - tensorflow::Status status = absl::OkStatus(); + absl::Status status = absl::OkStatus(); status = NumElements(std::vector{1024, 1024, 2047}, &count); EXPECT_TRUE(status.ok()); @@ -124,7 +124,7 @@ TEST(NumElementsTest, Int) { TEST(NumElementsTest, Int32) { int32_t count; - tensorflow::Status status = absl::OkStatus(); + absl::Status status = absl::OkStatus(); status = NumElements(std::vector{1024, 1024, 2047}, &count); EXPECT_TRUE(status.ok()); @@ -139,7 +139,7 @@ TEST(NumElementsTest, Int32) { TEST(NumElementsTest, Int64) { int64_t count; - tensorflow::Status status = absl::OkStatus(); + absl::Status status = absl::OkStatus(); status = NumElements(std::vector{16777216, 16777216, 32767}, &count); EXPECT_TRUE(status.ok()); @@ -154,7 +154,7 @@ TEST(NumElementsTest, Int64) { TEST(NumElementsTest, UnsignedInt32) { uint32_t count; - tensorflow::Status status = absl::OkStatus(); + absl::Status status = absl::OkStatus(); status = NumElements(std::vector{1024, 2048, 2047}, &count); EXPECT_TRUE(status.ok()); @@ -169,7 +169,7 @@ TEST(NumElementsTest, UnsignedInt32) { TEST(NumElementsTest, UnsignedInt64) { uint64_t count; - tensorflow::Status status = absl::OkStatus(); + absl::Status status = absl::OkStatus(); status = NumElements(std::vector{16777216, 16777216, 65535}, &count); @@ -185,7 +185,7 @@ TEST(NumElementsTest, UnsignedInt64) { } TEST(NumElementsTest, Scalar) { - tensorflow::Status status = absl::OkStatus(); + absl::Status status = absl::OkStatus(); int32_t count; status = NumElements(std::vector{}, &count); diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD index b9260be0b9eac3..350e7d1a33ad6e 100644 --- a/tensorflow/lite/tools/BUILD +++ b/tensorflow/lite/tools/BUILD @@ -465,7 +465,6 @@ tflite_portable_test_suite() # # py_proto_library( # name = "op_kernel_set_py_pb2", -# api_version = 2, # deps = [":op_kernel_set_proto"], # ) # copybara:uncomment_end diff --git a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/BUILD b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/BUILD index d2d1f807604ab5..9114d2314f6327 100644 --- a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/BUILD +++ b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/BUILD @@ -37,7 +37,10 @@ cc_library( "//tensorflow/core/util:stats_calculator_portable", "//tensorflow/lite:minimal_logging", "//tensorflow/lite/acceleration/configuration:configuration_fbs", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/profiling:memory_info", + "//tensorflow/lite/tools/benchmark:benchmark_model_lib", + "//tensorflow/lite/tools/benchmark:benchmark_params", "//tensorflow/lite/tools/benchmark:benchmark_tflite_model_lib", "//tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/proto:delegate_performance_cc_proto", "@com_google_absl//absl/strings", diff --git a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/accuracy_benchmark.cc b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/accuracy_benchmark.cc index 906b109f86da66..b49c45e12ec62b 100644 --- a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/accuracy_benchmark.cc +++ b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/accuracy_benchmark.cc @@ -23,8 +23,11 @@ limitations under the License. #include #include +#include "flatbuffers/base.h" // from @flatbuffers #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "flatbuffers/vector.h" // from @flatbuffers #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" #include "tensorflow/lite/experimental/acceleration/mini_benchmark/c/c_api.h" #include "tensorflow/lite/kernels/internal/compatibility.h" diff --git a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/accuracy_benchmark.h b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/accuracy_benchmark.h index 63e85cf70028b1..c2fcb0ca5df72c 100644 --- a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/accuracy_benchmark.h +++ b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/accuracy_benchmark.h @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" namespace tflite { diff --git a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/latency_benchmark.cc b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/latency_benchmark.cc index 4f739ccc838a47..03ca95276e0700 100644 --- a/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/latency_benchmark.cc +++ b/tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/src/main/native/latency_benchmark.cc @@ -29,9 +29,12 @@ limitations under the License. #include "absl/strings/str_format.h" #include "tensorflow/core/util/stats_calculator.h" #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/logger.h" #include "tensorflow/lite/minimal_logging.h" #include "tensorflow/lite/profiling/memory_info.h" +#include "tensorflow/lite/tools/benchmark/benchmark_model.h" +#include "tensorflow/lite/tools/benchmark/benchmark_params.h" #include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h" #include "tensorflow/lite/tools/benchmark/experimental/delegate_performance/android/proto/delegate_performance.pb.h" diff --git a/tensorflow/lite/tools/cmake/modules/cpuinfo.cmake b/tensorflow/lite/tools/cmake/modules/cpuinfo.cmake index ba2a25967a575c..145933451851bf 100644 --- a/tensorflow/lite/tools/cmake/modules/cpuinfo.cmake +++ b/tensorflow/lite/tools/cmake/modules/cpuinfo.cmake @@ -23,7 +23,7 @@ OverridableFetchContent_Declare( cpuinfo GIT_REPOSITORY https://github.com/pytorch/cpuinfo # Sync with tensorflow/workspace2.bzl - GIT_TAG fa1c679da8d19e1d87f20175ae1ec10995cd3dd3 + GIT_TAG 1e83a2fdd3102f65c6f1fb602c1b320486218a99 GIT_PROGRESS TRUE SOURCE_DIR "${CMAKE_BINARY_DIR}/cpuinfo" ) diff --git a/tensorflow/lite/tools/cmake/modules/ml_dtypes.cmake b/tensorflow/lite/tools/cmake/modules/ml_dtypes.cmake index 5e61934afafae0..c18da2cc8b00fd 100644 --- a/tensorflow/lite/tools/cmake/modules/ml_dtypes.cmake +++ b/tensorflow/lite/tools/cmake/modules/ml_dtypes.cmake @@ -23,7 +23,9 @@ OverridableFetchContent_Declare( ml_dtypes GIT_REPOSITORY https://github.com/jax-ml/ml_dtypes # Sync with tensorflow/third_party/py/ml_dtypes/workspace.bzl - GIT_TAG 24084d9ed2c3d45bf83b7a9bff833aa185bf9172 + # Github link: + # https://github.com/jax-ml/ml_dtypes/commit/6f02f77c4fa624d8b467c36d1d959a9b49b07900 + GIT_TAG 6f02f77c4fa624d8b467c36d1d959a9b49b07900 # It's not currently possible to shallow clone with a GIT TAG # as cmake attempts to git checkout the commit hash after the clone # which doesn't work as it's a shallow clone hence a different commit hash. diff --git a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake index 0bf6fc80f62345..327183cb6293dc 100644 --- a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake +++ b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake @@ -23,7 +23,7 @@ OverridableFetchContent_Declare( xnnpack GIT_REPOSITORY https://github.com/google/XNNPACK # Sync with tensorflow/workspace2.bzl - GIT_TAG 6b83f69d4938da4dc9ad63c00bd13e9695659a51 + GIT_TAG 743f95f0c34b02d6d2cdb9e87da21caffe9c668f GIT_PROGRESS TRUE PREFIX "${CMAKE_BINARY_DIR}" SOURCE_DIR "${CMAKE_BINARY_DIR}/xnnpack" diff --git a/tensorflow/lite/tools/evaluation/proto/BUILD b/tensorflow/lite/tools/evaluation/proto/BUILD index 696876aae1f1ac..524cf5962d4d6f 100644 --- a/tensorflow/lite/tools/evaluation/proto/BUILD +++ b/tensorflow/lite/tools/evaluation/proto/BUILD @@ -92,7 +92,6 @@ cc_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "evaluation_stages_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":evaluation_stages_proto"], # ) diff --git a/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.cc b/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.cc index 094d76134f3bac..b26af9f983a640 100644 --- a/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.cc +++ b/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h" #include +#include #include #include #include @@ -101,12 +102,12 @@ TfLiteStatus InferenceProfilerStage::Init( for (int i = 0; i < model_info_->inputs.size(); ++i) { const TfLiteType model_input_type = model_info_->inputs[i]->type; if (model_input_type == kTfLiteUInt8 || model_input_type == kTfLiteInt8 || - model_input_type == kTfLiteInt64 || - model_input_type == kTfLiteFloat32 || + model_input_type == kTfLiteInt32 || model_input_type == kTfLiteInt64 || + model_input_type == kTfLiteBool || model_input_type == kTfLiteFloat32 || model_input_type == kTfLiteFloat16) { } else { LOG(ERROR) << "InferenceProfilerStage only supports " - "float16/float32/int8/uint8/int64 " + "float16/float32/int8/uint8/int32/int64/bool " "input types"; return kTfLiteError; } @@ -121,14 +122,18 @@ TfLiteStatus InferenceProfilerStage::Init( int8_tensors_.emplace_back(); float16_tensors_.emplace_back(); int64_tensors_.emplace_back(); + int32_tensors_.emplace_back(); + bool_tensors_.emplace_back(); } // Preprocess output metadata for calculating diffs later. for (int i = 0; i < model_info_->outputs.size(); ++i) { const TfLiteType model_output_type = model_info_->outputs[i]->type; if (model_output_type == kTfLiteUInt8 || model_output_type == kTfLiteInt8 || + model_output_type == kTfLiteInt32 || model_output_type == kTfLiteBool || model_output_type == kTfLiteFloat32) { } else { - LOG(ERROR) << "InferenceProfilerStage only supports float32/int8/uint8 " + LOG(ERROR) << "InferenceProfilerStage only supports " + "float32/int8/uint8/int32/bool " "output types"; return kTfLiteError; } @@ -160,11 +165,20 @@ TfLiteStatus InferenceProfilerStage::Run() { input_num_elements_[i], std::numeric_limits::min(), std::numeric_limits::max(), &int8_tensors_[i]); input_ptrs.push_back(int8_tensors_[i].data()); + } else if (model_input_type == kTfLiteInt32) { + GenerateRandomGaussianData( + input_num_elements_[i], std::numeric_limits::min(), + std::numeric_limits::max(), &int32_tensors_[i]); + input_ptrs.push_back(int32_tensors_[i].data()); } else if (model_input_type == kTfLiteInt64) { GenerateRandomGaussianData( input_num_elements_[i], std::numeric_limits::min(), std::numeric_limits::max(), &int64_tensors_[i]); input_ptrs.push_back(int64_tensors_[i].data()); + } else if (model_input_type == kTfLiteBool) { + GenerateRandomGaussianData(input_num_elements_[i], 0, 1, + &bool_tensors_[i]); + input_ptrs.push_back(bool_tensors_[i].data()); } else if (model_input_type == kTfLiteFloat32) { GenerateRandomGaussianData(input_num_elements_[i], -1, 1, &(float_tensors_[i])); @@ -179,7 +193,7 @@ TfLiteStatus InferenceProfilerStage::Run() { input_ptrs.push_back(float16_tensors_[i].data()); } else { LOG(ERROR) << "InferenceProfilerStage only supports " - "float16/float32/int8/uint8/int64 " + "float16/float32/int8/uint8/int32/int64/bool " "input types"; return kTfLiteError; } @@ -205,6 +219,15 @@ TfLiteStatus InferenceProfilerStage::Run() { output_diff = CalculateAverageError(static_cast(reference_ptr), static_cast(test_ptr), output_num_elements_[i]); + } else if (model_output_type == kTfLiteInt32) { + output_diff = CalculateAverageError(static_cast(reference_ptr), + static_cast(test_ptr), + output_num_elements_[i]); + } else if (model_output_type == kTfLiteBool) { + // Use int8_t* for bool tensors to use void* casting. + output_diff = CalculateAverageError(static_cast(reference_ptr), + static_cast(test_ptr), + output_num_elements_[i]); } else if (model_output_type == kTfLiteFloat32) { output_diff = CalculateAverageError(static_cast(reference_ptr), static_cast(test_ptr), diff --git a/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h b/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h index d48c836f035b44..a68049ed960b11 100644 --- a/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h +++ b/tensorflow/lite/tools/evaluation/stages/inference_profiler_stage.h @@ -66,7 +66,10 @@ class InferenceProfilerStage : public EvaluationStage { std::vector> int8_tensors_; std::vector> uint8_tensors_; std::vector> float16_tensors_; + std::vector> int32_tensors_; std::vector> int64_tensors_; + // Use uint8_t for bool tensors to use void* casting. + std::vector> bool_tensors_; }; } // namespace evaluation diff --git a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD index 3a0b2fb9eb0633..82b1567212e997 100644 --- a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD +++ b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/BUILD @@ -18,6 +18,7 @@ cc_library( "//tensorflow/lite/core/c:common", "//tensorflow/lite/tools:command_line_flags", "//tensorflow/lite/tools:logging", + "//tensorflow/lite/tools/evaluation:evaluation_delegate_provider", "//tensorflow/lite/tools/evaluation:evaluation_stage", "//tensorflow/lite/tools/evaluation:utils", "//tensorflow/lite/tools/evaluation/proto:evaluation_config_cc_proto", diff --git a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc index c2a082447866aa..1dbb26a0176d91 100644 --- a/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc +++ b/tensorflow/lite/tools/evaluation/tasks/imagenet_image_classification/run_eval.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/tools/command_line_flags.h" +#include "tensorflow/lite/tools/evaluation/evaluation_delegate_provider.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h" #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h" #include "tensorflow/lite/tools/evaluation/stages/image_classification_stage.h" diff --git a/tensorflow/lite/tools/pip_package/Dockerfile.py3 b/tensorflow/lite/tools/pip_package/Dockerfile.py3 index 63373905a63c4e..6459458633ca8a 100644 --- a/tensorflow/lite/tools/pip_package/Dockerfile.py3 +++ b/tensorflow/lite/tools/pip_package/Dockerfile.py3 @@ -45,7 +45,6 @@ RUN apt-get update && \ python$PYTHON_VERSION \ python$PYTHON_VERSION-dev \ python$PYTHON_VERSION-venv \ - python$PYTHON_VERSION-distutils \ libpython$PYTHON_VERSION-dev \ libpython$PYTHON_VERSION-dev:armhf \ libpython$PYTHON_VERSION-dev:arm64 diff --git a/tensorflow/lite/tools/versioning/BUILD b/tensorflow/lite/tools/versioning/BUILD index f173ce2c89734b..4caaa6bbc07071 100644 --- a/tensorflow/lite/tools/versioning/BUILD +++ b/tensorflow/lite/tools/versioning/BUILD @@ -108,6 +108,7 @@ cc_library( ":op_signature", "//tensorflow/lite:builtin_op_data", "//tensorflow/lite:builtin_ops", + "//tensorflow/lite:util", "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core/c:common", "//tensorflow/lite/schema:schema_fbs", diff --git a/tensorflow/lite/tools/versioning/gpu_compatibility.cc b/tensorflow/lite/tools/versioning/gpu_compatibility.cc index a71042ad7f32f0..d358fdda411f42 100644 --- a/tensorflow/lite/tools/versioning/gpu_compatibility.cc +++ b/tensorflow/lite/tools/versioning/gpu_compatibility.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/lite/tools/versioning/gpu_compatibility.h" #include +#include +#include #include #include @@ -26,6 +28,7 @@ limitations under the License. #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/tools/versioning/op_signature.h" +#include "tensorflow/lite/util.h" namespace tflite { @@ -371,15 +374,11 @@ absl::Status CheckSelectV2GpuDelegateCompatibility(const OpSignature& op_sig) { } // Only supports float inputs with non-broadcastable or scalar if/else. absl::Status error = absl::InvalidArgumentError( - "Cond must be float or bool type, if, else tensors must be float and " + "Cond must be float or bool type, if, else tensors must be " "either be same the shape as output or constant, scalar."); - if ((op_sig.inputs.at(0).type != kTfLiteBool && - op_sig.inputs.at(0).type != kTfLiteFloat16 && - op_sig.inputs.at(0).type != kTfLiteFloat32) || - (op_sig.inputs.at(1).type != kTfLiteFloat16 && - op_sig.inputs.at(1).type != kTfLiteFloat32) || - (op_sig.inputs.at(2).type != kTfLiteFloat16 && - op_sig.inputs.at(2).type != kTfLiteFloat32)) { + if (op_sig.inputs.at(0).type != kTfLiteBool && + op_sig.inputs.at(0).type != kTfLiteFloat16 && + op_sig.inputs.at(0).type != kTfLiteFloat32) { return error; } std::vector output_dims = op_sig.outputs[0].dims; @@ -545,6 +544,67 @@ absl::Status CheckGpuDelegateCompatibility(const OpSignature& op_sig, return absl::OkStatus(); } + case kTfLiteBuiltinBitcast: { + RETURN_IF_ERROR(CheckInputsOutputs(op_sig, + /*required_runtime_inputs=*/1, + /*required_outputs=*/1)); + std::vector input_dims = op_sig.inputs.at(0).dims; + std::vector output_dims = op_sig.outputs.at(0).dims; + size_t input_elem_size, output_elem_size; + TfLiteStatus status = GetSizeOfType( + /*context=*/nullptr, op_sig.inputs.at(0).type, &input_elem_size); + if (status != kTfLiteOk) { + return absl::InternalError("Could not parse input type"); + } + status = GetSizeOfType(/*context=*/nullptr, op_sig.outputs.at(0).type, + &output_elem_size); + if (status != kTfLiteOk) { + return absl::InternalError("Could not parse output type"); + } + if (input_elem_size == output_elem_size) { + if (input_dims != output_dims) { + return absl::InternalError( + "If input and output types have the same element size, they must " + "have the same shapes"); + } + } else if (input_elem_size > output_elem_size) { + if (input_dims.size() + 1 != output_dims.size()) { + return absl::InternalError( + "If input element size is greater than output element size, " + "require that input rank is one greater than output rank"); + } + for (int d = 0; d < input_dims.size(); ++d) { + if (input_dims[d] != output_dims[d]) { + return absl::InternalError("Shapes must match in all but last dim"); + } + } + if (output_dims[output_dims.size() - 1] * output_elem_size != + input_elem_size) { + return absl::InternalError( + "Last output dim must be equal to input element size divided by " + "output element size"); + } + } else { // output_elem_size > input_elem_size + if (input_dims.size() != output_dims.size() + 1) { + return absl::InternalError( + "If output element size is greater than input element size, " + "require that output rank is on greater than input rank"); + } + for (int d = 0; d < output_dims.size(); ++d) { + if (input_dims[d] != output_dims[d]) { + return absl::InternalError("Shapes must match in all but last dim"); + } + } + if (input_dims[input_dims.size() - 1] * input_elem_size != + output_elem_size) { + return absl::InternalError( + "Last input dim must be equal to output element size divided by " + "input element size"); + } + } + return absl::OkStatus(); + } + case kTfLiteBuiltinCast: RETURN_IF_ERROR(CheckInputsOutputs(op_sig, /*required_runtime_inputs=*/1, @@ -1078,6 +1138,7 @@ absl::Status CheckGpuDelegateCompatibility(const OpSignature& op_sig, case kTfLiteBuiltinFloor: case kTfLiteBuiltinGelu: case kTfLiteBuiltinLog: + case kTfLiteBuiltinLogicalNot: case kTfLiteBuiltinLogistic: // Sigmoid case kTfLiteBuiltinNeg: case kTfLiteBuiltinRsqrt: @@ -1099,13 +1160,16 @@ absl::Status CheckGpuDelegateCompatibility(const OpSignature& op_sig, case kTfLiteBuiltinGreater: case kTfLiteBuiltinGreaterEqual: case kTfLiteBuiltinLogicalAnd: + case kTfLiteBuiltinLogicalOr: case kTfLiteBuiltinLess: case kTfLiteBuiltinLessEqual: case kTfLiteBuiltinMaximum: case kTfLiteBuiltinMinimum: case kTfLiteBuiltinNotEqual: case kTfLiteBuiltinPow: + case kTfLiteBuiltinRightShift: case kTfLiteBuiltinStablehloRemainder: + case kTfLiteBuiltinStablehloShiftLeft: case kTfLiteBuiltinSquaredDifference: case kTfLiteBuiltinSub: { if (!CheckInputsConstsOutputs(op_sig, /*required_runtime_inputs=*/2, @@ -1152,6 +1216,18 @@ absl::Status CheckGpuDelegateCompatibility(const OpSignature& op_sig, "Require size(indices) = rank(operand)"); } return absl::OkStatus(); + case kTfLiteBuiltinStablehloCbrt: + if (op_sig.inputs[0].type != kTfLiteFloat16 && + op_sig.inputs[0].type != kTfLiteFloat32 && + op_sig.inputs[0].type != kTfLiteBFloat16) { + return absl::InvalidArgumentError("Only support float inputs"); + } + if (op_sig.inputs[0].type != op_sig.outputs[0].type) { + return absl::InvalidArgumentError("Input and output types must match"); + } + return CheckInputsConstsOutputs(op_sig, /*required_runtime_inputs=*/1, + /*required_const_inputs=*/0, + /*required_outputs=*/1); case kTfLiteBuiltinStablehloClamp: if ((op_sig.inputs.at(0).type != op_sig.inputs.at(1).type) || (op_sig.inputs.at(1).type != op_sig.inputs.at(2).type)) { diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index c08da1f62201fc..8344a9b4dc0a32 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -1,5 +1,4 @@ tf_staging/BUILD: -tf_staging/ci/official/wheel_test/BUILD: tf_staging/tensorflow/__init__:.py tf_staging/tensorflow/api_template.__init__:.py tf_staging/tensorflow/api_template_v1.__init__:.py diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 0ad8e84b889f18..0211c7abc83e95 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -11,7 +11,21 @@ load("//tensorflow:strict.default.bzl", "py_strict_library") # Placeholder: load py_proto_library load("//tensorflow:tensorflow.bzl", "VERSION", "cc_header_only_library", "clean_dep", "if_google", "if_oss", "if_windows", "if_xla_available", "tf_enable_mlir_bridge", "tf_python_pybind_static_deps", "tsl_async_value_deps") -load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable", "pywrap_tensorflow_macro", "tf_external_workspace_visible", "tf_monitoring_python_deps", "tf_pybind_cc_library_wrapper", "tf_python_pybind_extension") +load( + "//tensorflow:tensorflow.default.bzl", + "get_compatible_with_portable", + "pybind_extension", + "pywrap_aware_cc_import", + "pywrap_aware_filegroup", + "pywrap_aware_genrule", + "pywrap_common_library", + "pywrap_library", + "pywrap_tensorflow_macro", + "tf_external_workspace_visible", + "tf_monitoring_python_deps", + "tf_pybind_cc_library_wrapper", + "tf_python_pybind_extension", +) load( "//tensorflow/core/platform:build_config.bzl", "tf_additional_binary_deps", @@ -19,6 +33,7 @@ load( ) load( "//tensorflow/core/platform:build_config_root.bzl", + "if_pywrap", "if_static", "tf_additional_plugin_deps", "tf_additional_profiler_deps", @@ -873,7 +888,7 @@ pywrap_tensorflow_macro( # ** Targets for Windows build (start) ** # We need the following targets to expose symbols from _pywrap_tensorflow.dll -filegroup( +pywrap_aware_filegroup( name = "win_lib_files_for_exported_symbols", srcs = [ "//tensorflow/c:checkpoint_reader", # checkpoint_reader @@ -975,7 +990,7 @@ filegroup( # Filter the DEF file to reduce the number of symbols to 64K or less. # Note that we also write the name of the pyd file into DEF file so that # the dynamic libraries of custom ops can find it at runtime. -genrule( +pywrap_aware_genrule( name = "pywrap_tensorflow_filtered_def_file", srcs = select({ "//tensorflow:windows": [ @@ -1003,7 +1018,7 @@ genrule( ) # Write to a file a list of all cc_library targets that we need for exporting symbols on Windows. -genrule( +pywrap_aware_genrule( name = "pybind_symbol_target_libs_file", srcs = [":win_lib_files_for_exported_symbols"], outs = ["pybind_symbol_target_libs_file.txt"], @@ -1019,25 +1034,25 @@ genrule( ) # Get the import library of _pywrap_tensorflow_internal.pyd, platform-specific to Windows. -filegroup( +pywrap_aware_filegroup( name = "get_pywrap_tensorflow_import_lib_file", srcs = [":_pywrap_tensorflow_internal.so"], output_group = "interface_library", ) -cc_import( +pywrap_aware_cc_import( name = "_pywrap_tensorflow_internal_linux", shared_library = "//tensorflow/python:lib_pywrap_tensorflow_internal.so", visibility = tf_external_workspace_visible(visibility), ) -cc_import( +pywrap_aware_cc_import( name = "_pywrap_tensorflow_internal_macos", shared_library = "//tensorflow/python:lib_pywrap_tensorflow_internal.dylib", visibility = tf_external_workspace_visible(visibility), ) -cc_import( +pywrap_aware_cc_import( name = "_pywrap_tensorflow_internal_windows", interface_library = "//tensorflow/python:pywrap_tensorflow_import_lib_file", shared_library = "//tensorflow/python:_pywrap_tensorflow_internal.dll", @@ -1046,7 +1061,7 @@ cc_import( # Rename the import library for _pywrap_tensorflow_internal.pyd to _pywrap_tensorflow_internal.lib # (It was _pywrap_tensorflow_internal.so.if.lib). -genrule( +pywrap_aware_genrule( name = "pywrap_tensorflow_import_lib_file", srcs = [":get_pywrap_tensorflow_import_lib_file"], outs = ["_pywrap_tensorflow_internal.lib"], @@ -1059,7 +1074,7 @@ genrule( # Create a cc_import rule for the import library of _pywrap_tensorflow_internal.dll # so that custom ops' dynamic libraries can link against it. -cc_import( +pywrap_aware_cc_import( name = "pywrap_tensorflow_import_lib", interface_library = select({ "//tensorflow:windows": ":pywrap_tensorflow_import_lib_file", @@ -1097,15 +1112,19 @@ tf_python_pybind_extension( "_pywrap_mlir.pyi", ], deps = [ + "//tensorflow/compiler/tf2tensorrt:common_utils", + "//tensorflow/compiler/tf2tensorrt:trt_parameters", + "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:status", + "//tensorflow/core/platform:statusor", "//tensorflow/python/lib/core:pybind11_lib", "//tensorflow/python/lib/core:pybind11_status", "//tensorflow/python/lib/core:safe_pyobject_ptr", "//third_party/python_runtime:headers", "@com_google_absl//absl/strings", "@pybind11", - ], + ] + if_pywrap(["//tensorflow/compiler/mlir/python:mlir"]), ) py_strict_library( @@ -1194,24 +1213,26 @@ py_strict_library( tf_python_pybind_extension( name = "_pywrap_tfe", srcs = ["tfe_wrapper.cc"], - hdrs = [ - "//tensorflow/c:headers", - "//tensorflow/c:safe_ptr_hdr", - "//tensorflow/c/eager:headers", - "//tensorflow/c/eager:pywrap_required_hdrs", - "//tensorflow/c/experimental/ops:pywrap_required_hdrs", - "//tensorflow/core/common_runtime/eager:pywrap_required_hdrs", - "//tensorflow/core/distributed_runtime:pywrap_required_hdrs", - "//tensorflow/core/distributed_runtime/coordination:pywrap_required_hdrs", - "//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs", - "//tensorflow/python/eager:pywrap_required_hdrs", - "//tensorflow/python/lib/core:py_exception_registry_hdr", - "//tensorflow/python/lib/core:safe_pyobject_ptr_required_hdrs", - "//tensorflow/python/util:util_hdr", - "@local_xla//xla/tsl/distributed_runtime:pywrap_required_hdrs", - "@local_xla//xla/tsl/distributed_runtime/coordination:pywrap_required_hdrs", - "@local_xla//xla/tsl/python/lib/core:numpy_hdr", - ], + hdrs = if_pywrap( + if_false = [ + "//tensorflow/c:headers", + "//tensorflow/c:safe_ptr_hdr", + "//tensorflow/c/eager:headers", + "//tensorflow/c/eager:pywrap_required_hdrs", + "//tensorflow/c/experimental/ops:pywrap_required_hdrs", + "@local_xla//xla/tsl/distributed_runtime:pywrap_required_hdrs", + "@local_xla//xla/tsl/distributed_runtime/coordination:pywrap_required_hdrs", + "@local_xla//xla/tsl/python/lib/core:numpy_hdr", + "//tensorflow/core/common_runtime/eager:pywrap_required_hdrs", + "//tensorflow/core/distributed_runtime:pywrap_required_hdrs", + "//tensorflow/core/distributed_runtime/coordination:pywrap_required_hdrs", + "//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs", + "//tensorflow/python/eager:pywrap_required_hdrs", + "//tensorflow/python/lib/core:py_exception_registry_hdr", + "//tensorflow/python/lib/core:safe_pyobject_ptr_required_hdrs", + "//tensorflow/python/util:util_hdr", + ], + ), dynamic_deps = [":_pywrap_tensorflow_internal.so"] + select({ "//tensorflow:macos": ["//tensorflow:libtensorflow_framework.%s.dylib" % VERSION], "//conditions:default": ["//tensorflow:libtensorflow_framework.so.%s" % VERSION], @@ -1248,18 +1269,20 @@ tf_python_pybind_extension( "@com_google_absl//absl/types:optional", "@pybind11", # copybara:uncomment "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", - ] + if_static( + ] + if_pywrap( + if_true = ["//tensorflow/python/eager:pywrap_tfe_lib"], + ) + if_static( extra_deps = [ "//tensorflow/core/protobuf:eager_service_proto_cc", "//tensorflow/core/protobuf:master_proto_cc", "//tensorflow/core/protobuf:worker_proto_cc", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc", ], otherwise = [ "//tensorflow/core/protobuf:eager_service_proto_cc_headers_only", "//tensorflow/core/protobuf:master_proto_cc_headers_only", "//tensorflow/core/protobuf:worker_proto_cc_headers_only", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc_headers_only", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc_headers_only", ], ), ) @@ -1401,7 +1424,6 @@ py_strict_library( # py_proto_library( # name = "protos_all_py_pb2", # has_services = 0, -# api_version = 2, # deps = [":protos_all"], # ) # @@ -1426,3 +1448,175 @@ pytype_strict_library( "//tensorflow/python/util:tf_export", ], ) + +pywrap_library( + name = "_pywrap_tensorflow", + cc_deps_filter = [ + "@com_google_protobuf//:protobuf", + "@com_google_protobuf//:protobuf_lite", + "@zlib//:zlib", + ], + linkopts = select({ + "//tensorflow:windows": [ + "-DEFAULTLIB:ws2_32.lib", + "-DEFAULTLIB:advapi32.lib", + "-DEFAULTLIB:crypt32.lib", + "-DEFAULTLIB:Normaliz.lib", + ], + "//conditions:default": [], + }), + py_cc_deps_filter = select({ + "//tensorflow:windows": [], + "//conditions:default": [ + "@local_xla//xla/tsl/python/lib/core:ml_dtypes_lib", + "@local_xla//xla/tsl/python/lib/core:numpy", + "@local_xla//xla/backends/profiler/cpu:python_tracer_impl", + "@local_xla//xla/backends/profiler/cpu:python_tracer", + "@local_xla//xla/python/profiler/internal:python_hooks", + "//tensorflow/lite/python/interpreter_wrapper:python_error_reporter", + "//tensorflow/lite/python/interpreter_wrapper:python_utils", + "//tensorflow/lite/toco/python:toco_python_api", + "//tensorflow/python/client:tf_session_helper", + "//tensorflow/python/eager:pywrap_tfe_lib", + "//tensorflow/python/framework:op_def_util_cc", + "//tensorflow/python/framework:py_context_manager", + "//tensorflow/python/framework:python_api_info", + "//tensorflow/python/framework:python_api_parameter_converter", + "//tensorflow/python/framework:python_tensor_converter", + "//tensorflow/python/framework:python_api_dispatcher", + "//tensorflow/python/lib/core:ndarray_tensor_bridge", + "//tensorflow/python/lib/core:ndarray_tensor", + "//tensorflow/python/lib/core:py_seq_tensor", + "//tensorflow/python/lib/core:py_util", + "//tensorflow/python/lib/core:py_exception_registry", + "//tensorflow/python/lib/core:py_func_lib", + "//tensorflow/python/util:cpp_python_util", + "//tensorflow/python/util:function_parameter_canonicalizer", + "//tensorflow/python/util:stack_trace", + "//tensorflow/python/util:cpp_nest", + "//tensorflow/compiler/mlir/lite/python:converter_python_api", + "//tensorflow/lite/python/metrics:metrics_wrapper_lib", + "//tensorflow/lite/python/interpreter_wrapper:interpreter_wrapper_lib", + "//tensorflow/lite/python/interpreter_wrapper:numpy", + "//tensorflow/lite/python/optimize:calibration_wrapper_lib", + ], + }), + visibility = ["//visibility:public"], + # win_def_file = "_pywrap_tensorflow.def", + deps = [ + ":_pywrap_quantize_training", + ":_pywrap_tensorflow_cc_only", + "//tensorflow/compiler/mlir/lite/python:_pywrap_converter_api", + "//tensorflow/compiler/mlir/python/mlir_wrapper:filecheck_wrapper", + "//tensorflow/compiler/mlir/quantization/stablehlo/python:pywrap_quantization", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:pywrap_function_lib", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:pywrap_quantize_model", + "//tensorflow/compiler/mlir/stablehlo:stablehlo_extension", + "//tensorflow/compiler/mlir/tfr:tfr_wrapper", + "//tensorflow/compiler/tf2tensorrt:_pywrap_py_utils", + "//tensorflow/lite/python/analyzer_wrapper:_pywrap_analyzer_wrapper", + "//tensorflow/lite/python/interpreter_wrapper:_pywrap_tensorflow_interpreter_wrapper", + "//tensorflow/lite/python/metrics:_pywrap_tensorflow_lite_metrics_wrapper", + "//tensorflow/lite/python/optimize:_pywrap_tensorflow_lite_calibration_wrapper", + "//tensorflow/python:_pywrap_dtensor_device", + "//tensorflow/python:_pywrap_mlir", + "//tensorflow/python:_pywrap_parallel_device", + "//tensorflow/python:_pywrap_py_exception_registry", + "//tensorflow/python:_pywrap_sanitizers", + "//tensorflow/python:_pywrap_tfcompile", + "//tensorflow/python:_pywrap_tfe", + "//tensorflow/python:_pywrap_toco_api", + "//tensorflow/python:flags_pybind", + "//tensorflow/python/autograph/impl/testing:pybind_for_testing", + "//tensorflow/python/client:_pywrap_debug_events_writer", + "//tensorflow/python/client:_pywrap_device_lib", + "//tensorflow/python/client:_pywrap_events_writer", + "//tensorflow/python/client:_pywrap_tf_session", + "//tensorflow/python/data/experimental/service:_pywrap_server_lib", + "//tensorflow/python/data/experimental/service:_pywrap_snapshot_utils", + "//tensorflow/python/data/experimental/service:_pywrap_utils_exp", + "//tensorflow/python/framework:_dtypes", + "//tensorflow/python/framework:_errors_test_helper", + "//tensorflow/python/framework:_op_def_library_pybind", + "//tensorflow/python/framework:_op_def_registry", + "//tensorflow/python/framework:_op_def_util", + "//tensorflow/python/framework:_proto_comparators", + "//tensorflow/python/framework:_py_context_manager", + "//tensorflow/python/framework:_python_memory_checker_helper", + "//tensorflow/python/framework:_pywrap_python_api_dispatcher", + "//tensorflow/python/framework:_pywrap_python_api_info", + "//tensorflow/python/framework:_pywrap_python_api_parameter_converter", + "//tensorflow/python/framework:_pywrap_python_op_gen", + "//tensorflow/python/framework:_pywrap_python_tensor_converter", + "//tensorflow/python/framework:_test_metrics_util", + "//tensorflow/python/framework/experimental:_math_ops", + "//tensorflow/python/framework/experimental:_nn_ops", + "//tensorflow/python/framework/experimental:_tape", + "//tensorflow/python/framework/experimental:_unified_api", + "//tensorflow/python/grappler:_pywrap_cost_analyzer", + "//tensorflow/python/grappler:_pywrap_model_analyzer", + "//tensorflow/python/grappler:_pywrap_tf_cluster", + "//tensorflow/python/grappler:_pywrap_tf_item", + "//tensorflow/python/grappler:_pywrap_tf_optimizer", + "//tensorflow/python/lib/core:_pywrap_py_func", + "//tensorflow/python/lib/io:_pywrap_file_io", + "//tensorflow/python/lib/io:_pywrap_record_io", + "//tensorflow/python/platform:_pywrap_cpu_feature_guard", + "//tensorflow/python/platform:_pywrap_stacktrace_handler", + "//tensorflow/python/platform:_pywrap_tf2", + "//tensorflow/python/profiler/internal:_pywrap_profiler", + "//tensorflow/python/profiler/internal:_pywrap_traceme", + "//tensorflow/python/saved_model:pywrap_saved_model", + "//tensorflow/python/tpu:_pywrap_sparse_core_layout", + "//tensorflow/python/tpu:_pywrap_tpu_embedding", + "//tensorflow/python/util:_function_parameter_canonicalizer_binding_for_test", + "//tensorflow/python/util:_pywrap_checkpoint_reader", + "//tensorflow/python/util:_pywrap_determinism", + "//tensorflow/python/util:_pywrap_kernel_registry", + "//tensorflow/python/util:_pywrap_nest", + "//tensorflow/python/util:_pywrap_tensor_float_32_execution", + "//tensorflow/python/util:_pywrap_tfprof", + "//tensorflow/python/util:_pywrap_transform_graph", + "//tensorflow/python/util:_pywrap_util_port", + "//tensorflow/python/util:_pywrap_utils", + "//tensorflow/python/util:_tf_stack", + "//tensorflow/python/util:fast_module_type", + "//tensorflow/python/util:pywrap_xla_ops", + ], +) + +pybind_extension( + name = "_pywrap_tensorflow_cc_only", + srcs = [], + deps = [ + ":_protobuf_inline_symbols_enforcer", + "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", + "//tensorflow/core/distributed_runtime:server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", + "//tensorflow/core/distributed_runtime/rpc:grpc_session", + "//tensorflow/core/kernels:data_service_ops", + "//tensorflow/core/kernels:reader_ops", + "//tensorflow/distribute/experimental/rpc/kernels:rpc_ops", + "//tensorflow/dtensor/cc:tensor_layout", + "@local_xla//xla/backends/profiler/cpu:python_tracer", + ], +) + +cc_library( + name = "_protobuf_inline_symbols_enforcer", + srcs = ["protobuf_inline_symbols_enforcer.cc"], + deps = [ + "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", + "//tensorflow/core/framework:attr_value_proto_cc", + "//tensorflow/core/framework:function_proto_cc", + "//tensorflow/core/framework:graph_proto_cc", + "//tensorflow/core/protobuf:for_core_protos_cc", + "//tensorflow/dtensor/proto:layout_proto_cc", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", + ], +) + +pywrap_common_library( + name = "_pywrap_tensorflow_common", + dep = ":_pywrap_tensorflow", +) diff --git a/tensorflow/python/client/BUILD b/tensorflow/python/client/BUILD index 76c4ccad009a29..d3719e0ea17e55 100644 --- a/tensorflow/python/client/BUILD +++ b/tensorflow/python/client/BUILD @@ -2,7 +2,7 @@ load("//tensorflow:strict.default.bzl", "py_strict_library") load("//tensorflow:tensorflow.bzl", "tf_cuda_library") load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_py_strict_test", "tf_python_pybind_extension") -load("//tensorflow/core/platform:build_config_root.bzl", "if_static") +load("//tensorflow/core/platform:build_config_root.bzl", "if_pywrap", "if_static") load( "//tensorflow/tools/test:performance.bzl", "cuda_py_benchmark_test", @@ -31,21 +31,29 @@ py_strict_library( tf_python_pybind_extension( name = "_pywrap_tf_session", srcs = ["tf_session_wrapper.cc"], - hdrs = [ - "tf_session_helper.h", - "//tensorflow/c:headers", - "//tensorflow/c:safe_ptr_hdr", - "//tensorflow/c/eager:headers", - "//tensorflow/c/eager:pywrap_required_hdrs", - "//tensorflow/c/experimental/ops:pywrap_required_hdrs", - "//tensorflow/core/common_runtime/eager:pywrap_required_hdrs", - "//tensorflow/core/distributed_runtime:pywrap_required_hdrs", - "//tensorflow/core/distributed_runtime/coordination:pywrap_required_hdrs", - "//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs", - "//tensorflow/python/lib/core:safe_pyobject_ptr_required_hdrs", - "@local_xla//xla/tsl/distributed_runtime:pywrap_required_hdrs", - "@local_xla//xla/tsl/distributed_runtime/coordination:pywrap_required_hdrs", - "@local_xla//xla/tsl/python/lib/core:numpy_hdr", + hdrs = if_pywrap( + if_false = [ + "tf_session_helper.h", + "//tensorflow/c:headers", + "//tensorflow/c:safe_ptr_hdr", + "//tensorflow/c/eager:headers", + "//tensorflow/c/eager:pywrap_required_hdrs", + "//tensorflow/c/experimental/ops:pywrap_required_hdrs", + "@local_xla//xla/tsl/distributed_runtime:pywrap_required_hdrs", + "@local_xla//xla/tsl/distributed_runtime/coordination:pywrap_required_hdrs", + "@local_xla//xla/tsl/python/lib/core:numpy_hdr", + "//tensorflow/core/common_runtime/eager:pywrap_required_hdrs", + "//tensorflow/core/distributed_runtime:pywrap_required_hdrs", + "//tensorflow/core/distributed_runtime/coordination:pywrap_required_hdrs", + "//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs", + "//tensorflow/python/lib/core:safe_pyobject_ptr_required_hdrs", + ], + if_true = [], + ), + additional_exported_symbols = [ + "_TF_SetTarget", + "_TF_SetConfig", + "_TF_NewSessionOptions", ], enable_stub_generation = True, pytype_srcs = [ @@ -72,19 +80,25 @@ tf_python_pybind_extension( "@pybind11", "@pybind11_abseil//pybind11_abseil:absl_casters", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", - ] + if_static( + ] + if_pywrap([ + "//tensorflow/c:safe_ptr", + "//tensorflow/c:c_api_experimental", + "//tensorflow/python/client:tf_session_helper", + "//tensorflow/c:python_api", + "//tensorflow/core/common_runtime:core_cpu_lib", + ]) + if_static( extra_deps = [ "//tensorflow/core/protobuf:eager_service_proto_cc", "//tensorflow/core/protobuf:master_proto_cc", "//tensorflow/core/protobuf:worker_proto_cc", "//tensorflow/core:version_lib", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc", ], otherwise = [ "//tensorflow/core/protobuf:eager_service_proto_cc_headers_only", "//tensorflow/core/protobuf:master_proto_cc_headers_only", "//tensorflow/core/protobuf:worker_proto_cc_headers_only", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc_headers_only", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc_headers_only", ], ), ) diff --git a/tensorflow/python/client/events_writer_wrapper.cc b/tensorflow/python/client/events_writer_wrapper.cc index 661c845b3aac57..8d555c6a9908e9 100644 --- a/tensorflow/python/client/events_writer_wrapper.cc +++ b/tensorflow/python/client/events_writer_wrapper.cc @@ -26,7 +26,7 @@ limitations under the License. namespace py = pybind11; PYBIND11_MODULE(_pywrap_events_writer, m) { - py::class_ Status(m, "Status", py::module_local()); + py::class_ Status(m, "Status", py::module_local()); py::class_ events_writer_class(m, "EventsWriter"); events_writer_class.def(py::init()) .def("InitWithSuffix", diff --git a/tensorflow/python/client/session_ref.cc b/tensorflow/python/client/session_ref.cc index 2a1bc7cfe41b28..1cc345277f3244 100644 --- a/tensorflow/python/client/session_ref.cc +++ b/tensorflow/python/client/session_ref.cc @@ -110,27 +110,28 @@ class SessionLogger { log_file_->Close().IgnoreError(); } - Status RecordNewSession(Session* session) { + absl::Status RecordNewSession(Session* session) { ReplayOp op; NewReplaySession* req = op.mutable_new_replay_session(); req->set_session_handle(SessionToHandle(session)); return Flush(op); } - Status RecordRun(Session* session, - const std::vector >& inputs, - const std::vector& output_tensor_names, - const std::vector& target_node_names, - std::vector* outputs) { + absl::Status RecordRun(Session* session, + const std::vector >& inputs, + const std::vector& output_tensor_names, + const std::vector& target_node_names, + std::vector* outputs) { return RecordRun(session, *kEmptyRunOptions(), inputs, output_tensor_names, target_node_names, outputs, nullptr); } - Status RecordRun(Session* session, const RunOptions& run_options, - const std::vector >& inputs, - const std::vector& output_tensor_names, - const std::vector& target_node_names, - std::vector* outputs, RunMetadata* run_metadata) { + absl::Status RecordRun(Session* session, const RunOptions& run_options, + const std::vector >& inputs, + const std::vector& output_tensor_names, + const std::vector& target_node_names, + std::vector* outputs, + RunMetadata* run_metadata) { ReplayOp op; RunStepRequest* req = op.mutable_run_step(); RunStepResponse* resp = op.mutable_run_step_response(); @@ -179,13 +180,13 @@ class SessionLogger { return Flush(op); } - Status RecordCreate(Session* session, const GraphDef& graph) { + absl::Status RecordCreate(Session* session, const GraphDef& graph) { return RecordCreate(session, *kEmptyRunOptions(), graph); } // N.B. RunOptions is not stored (it has no entry in CreateRequest) - Status RecordCreate(Session* session, const RunOptions& run_options, - const GraphDef& graph) { + absl::Status RecordCreate(Session* session, const RunOptions& run_options, + const GraphDef& graph) { ReplayOp op; CreateSessionRequest* req = op.mutable_create_session(); *req->mutable_graph_def() = graph; @@ -200,13 +201,13 @@ class SessionLogger { return Flush(op); } - Status RecordExtend(Session* session, const GraphDef& graph) { + absl::Status RecordExtend(Session* session, const GraphDef& graph) { return RecordExtend(session, *kEmptyRunOptions(), graph); } // N.B. RunOptions is not stored (it has no entry in ExtendRequest) - Status RecordExtend(Session* session, const RunOptions& run_options, - const GraphDef& graph) { + absl::Status RecordExtend(Session* session, const RunOptions& run_options, + const GraphDef& graph) { ReplayOp op; ExtendSessionRequest* req = op.mutable_extend_session(); op.mutable_extend_session_response(); @@ -221,12 +222,12 @@ class SessionLogger { return Flush(op); } - Status RecordClose(Session* session) { + absl::Status RecordClose(Session* session) { return RecordClose(session, *kEmptyRunOptions()); } // N.B. RunOptions is not stored (it has no entry in CloseRequest) - Status RecordClose(Session* session, const RunOptions& run_options) { + absl::Status RecordClose(Session* session, const RunOptions& run_options) { ReplayOp op; CloseSessionRequest* req = op.mutable_close_session(); req->set_session_handle(SessionToHandle(session)); @@ -239,8 +240,8 @@ class SessionLogger { return Flush(op); } - Status RecordListDevices(Session* session, - std::vector* response) { + absl::Status RecordListDevices(Session* session, + std::vector* response) { ReplayOp op; ListDevicesRequest* req = op.mutable_list_devices(); ListDevicesResponse* resp = op.mutable_list_devices_response(); @@ -252,11 +253,11 @@ class SessionLogger { return Flush(op); } - Status RecordPRunSetup(Session* session, - const std::vector& input_names, - const std::vector& output_names, - const std::vector& target_nodes, - string* handle) { + absl::Status RecordPRunSetup(Session* session, + const std::vector& input_names, + const std::vector& output_names, + const std::vector& target_nodes, + string* handle) { ReplayOp op; PartialRunSetupRequest* req = op.mutable_partial_run_setup(); req->set_session_handle(SessionToHandle(session)); @@ -275,10 +276,10 @@ class SessionLogger { return Flush(op); } - Status RecordPRun(Session* session, const string& handle, - const std::vector >& inputs, - const std::vector& output_names, - std::vector* outputs) { + absl::Status RecordPRun(Session* session, const string& handle, + const std::vector >& inputs, + const std::vector& output_names, + std::vector* outputs) { ReplayOp op; RunStepRequest* req = op.mutable_run_step(); RunStepResponse* resp = op.mutable_run_step_response(); @@ -308,9 +309,9 @@ class SessionLogger { return Flush(op); } - Status RecordMakeCallable(Session* session, - const CallableOptions& callable_options, - Session::CallableHandle* handle) { + absl::Status RecordMakeCallable(Session* session, + const CallableOptions& callable_options, + Session::CallableHandle* handle) { ReplayOp op; MakeCallableRequest* req = op.mutable_make_callable(); req->set_session_handle(SessionToHandle(session)); @@ -324,10 +325,11 @@ class SessionLogger { return Flush(op); } - Status RecordRunCallable(Session* session, Session::CallableHandle handle, - const std::vector& feed_tensors, - std::vector* fetch_tensors, - RunMetadata* run_metadata) { + absl::Status RecordRunCallable(Session* session, + Session::CallableHandle handle, + const std::vector& feed_tensors, + std::vector* fetch_tensors, + RunMetadata* run_metadata) { ReplayOp op; RunCallableRequest* req = op.mutable_run_callable(); req->set_session_handle(SessionToHandle(session)); @@ -348,8 +350,8 @@ class SessionLogger { return Flush(op); } - Status RecordReleaseCallable(Session* session, - Session::CallableHandle handle) { + absl::Status RecordReleaseCallable(Session* session, + Session::CallableHandle handle) { ReplayOp op; ReleaseCallableRequest* req = op.mutable_release_callable(); req->set_session_handle(SessionToHandle(session)); @@ -359,7 +361,7 @@ class SessionLogger { } private: - Status Flush(const ReplayOp& op) { + absl::Status Flush(const ReplayOp& op) { mutex_lock l(log_mutex_); string buf; @@ -391,7 +393,7 @@ SessionRef::SessionRef(Session* session) : session_(session) { SessionRef::~SessionRef() = default; -Status SessionRef::CheckNotClosed() { +absl::Status SessionRef::CheckNotClosed() { mutex_lock l(run_lock_); if (session_ == nullptr) return errors::Cancelled("Session has been closed."); return absl::OkStatus(); @@ -407,75 +409,75 @@ Status SessionRef::CheckNotClosed() { } \ return logger_->Record##OpName(rc.session.get(), __VA_ARGS__); -Status SessionRef::Run(const RunOptions& run_options, - const std::vector >& inputs, - const std::vector& output_tensor_names, - const std::vector& target_node_names, - std::vector* outputs, - RunMetadata* run_metadata) { +absl::Status SessionRef::Run( + const RunOptions& run_options, + const std::vector >& inputs, + const std::vector& output_tensor_names, + const std::vector& target_node_names, std::vector* outputs, + RunMetadata* run_metadata) { LOG_AND_RUN_OPERATION(Run, run_options, inputs, output_tensor_names, target_node_names, outputs, run_metadata); } -Status SessionRef::Run(const std::vector >& inputs, - const std::vector& output_tensor_names, - const std::vector& target_node_names, - std::vector* outputs) { +absl::Status SessionRef::Run( + const std::vector >& inputs, + const std::vector& output_tensor_names, + const std::vector& target_node_names, + std::vector* outputs) { LOG_AND_RUN_OPERATION(Run, inputs, output_tensor_names, target_node_names, outputs); } -Status SessionRef::Create(const GraphDef& graph) { +absl::Status SessionRef::Create(const GraphDef& graph) { LOG_AND_RUN_OPERATION(Create, graph); } -Status SessionRef::Create(const RunOptions& run_options, - const GraphDef& graph) { +absl::Status SessionRef::Create(const RunOptions& run_options, + const GraphDef& graph) { LOG_AND_RUN_OPERATION(Create, run_options, graph); } -Status SessionRef::Extend(const RunOptions& run_options, - const GraphDef& graph) { +absl::Status SessionRef::Extend(const RunOptions& run_options, + const GraphDef& graph) { LOG_AND_RUN_OPERATION(Extend, run_options, graph); } -Status SessionRef::Extend(const GraphDef& graph) { +absl::Status SessionRef::Extend(const GraphDef& graph) { LOG_AND_RUN_OPERATION(Extend, graph); } -Status SessionRef::ListDevices(std::vector* response) { +absl::Status SessionRef::ListDevices(std::vector* response) { LOG_AND_RUN_OPERATION(ListDevices, response); } -Status SessionRef::PRunSetup(const std::vector& input_names, - const std::vector& output_names, - const std::vector& target_nodes, - string* handle) { +absl::Status SessionRef::PRunSetup(const std::vector& input_names, + const std::vector& output_names, + const std::vector& target_nodes, + string* handle) { LOG_AND_RUN_OPERATION(PRunSetup, input_names, output_names, target_nodes, handle); } -Status SessionRef::PRun(const string& handle, - const std::vector >& inputs, - const std::vector& output_names, - std::vector* outputs) { +absl::Status SessionRef::PRun( + const string& handle, const std::vector >& inputs, + const std::vector& output_names, std::vector* outputs) { LOG_AND_RUN_OPERATION(PRun, handle, inputs, output_names, outputs); } -Status SessionRef::MakeCallable(const CallableOptions& callable_options, - CallableHandle* out_handle) { +absl::Status SessionRef::MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle) { LOG_AND_RUN_OPERATION(MakeCallable, callable_options, out_handle); } -Status SessionRef::RunCallable(CallableHandle handle, - const std::vector& feed_tensors, - std::vector* fetch_tensors, - RunMetadata* run_metadata) { +absl::Status SessionRef::RunCallable(CallableHandle handle, + const std::vector& feed_tensors, + std::vector* fetch_tensors, + RunMetadata* run_metadata) { LOG_AND_RUN_OPERATION(RunCallable, handle, feed_tensors, fetch_tensors, run_metadata); } -Status SessionRef::ReleaseCallable(CallableHandle handle) { +absl::Status SessionRef::ReleaseCallable(CallableHandle handle) { { mutex_lock l(run_lock_); if (session_ == nullptr) { @@ -486,10 +488,10 @@ Status SessionRef::ReleaseCallable(CallableHandle handle) { LOG_AND_RUN_OPERATION(ReleaseCallable, handle); } -Status SessionRef::Close(const RunOptions& run_options) { +absl::Status SessionRef::Close(const RunOptions& run_options) { TF_RETURN_IF_ERROR(CheckNotClosed()); mutex_lock l(run_lock_); - Status status; + absl::Status status; if (logger_) { status = logger_->RecordClose(session_.get(), run_options); } else { @@ -502,10 +504,10 @@ Status SessionRef::Close(const RunOptions& run_options) { return status; } -Status SessionRef::Close() { +absl::Status SessionRef::Close() { TF_RETURN_IF_ERROR(CheckNotClosed()); mutex_lock l(run_lock_); - Status status; + absl::Status status; if (logger_) { status = logger_->RecordClose(session_.get()); } else { diff --git a/tensorflow/python/client/session_ref.h b/tensorflow/python/client/session_ref.h index a1d96c630c13f8..362bb3e5204512 100644 --- a/tensorflow/python/client/session_ref.h +++ b/tensorflow/python/client/session_ref.h @@ -34,45 +34,48 @@ class SessionRef : public Session { explicit SessionRef(Session* session); ~SessionRef() override; - Status Create(const GraphDef& graph) override; - Status Extend(const GraphDef& graph) override; - Status Create(const RunOptions& run_options, const GraphDef& graph) override; - Status Extend(const RunOptions& run_options, const GraphDef& graph) override; - Status Run(const std::vector >& inputs, - const std::vector& output_tensor_names, - const std::vector& target_node_names, - std::vector* outputs) override; - - Status ListDevices(std::vector* response) override; - - Status Close() override; - Status Close(const RunOptions& run_options) override; - - Status Run(const RunOptions& run_options, - const std::vector >& inputs, - const std::vector& output_tensor_names, - const std::vector& target_node_names, - std::vector* outputs, RunMetadata* run_metadata) override; - - Status PRunSetup(const std::vector& input_names, - const std::vector& output_names, - const std::vector& target_nodes, - string* handle) override; - - Status PRun(const string& handle, - const std::vector >& inputs, - const std::vector& output_names, - std::vector* outputs) override; - - Status MakeCallable(const CallableOptions& callable_options, - CallableHandle* out_handle) override; - - Status RunCallable(CallableHandle handle, - const std::vector& feed_tensors, - std::vector* fetch_tensors, - RunMetadata* run_metadata) override; - - Status ReleaseCallable(CallableHandle handle) override; + absl::Status Create(const GraphDef& graph) override; + absl::Status Extend(const GraphDef& graph) override; + absl::Status Create(const RunOptions& run_options, + const GraphDef& graph) override; + absl::Status Extend(const RunOptions& run_options, + const GraphDef& graph) override; + absl::Status Run(const std::vector >& inputs, + const std::vector& output_tensor_names, + const std::vector& target_node_names, + std::vector* outputs) override; + + absl::Status ListDevices(std::vector* response) override; + + absl::Status Close() override; + absl::Status Close(const RunOptions& run_options) override; + + absl::Status Run(const RunOptions& run_options, + const std::vector >& inputs, + const std::vector& output_tensor_names, + const std::vector& target_node_names, + std::vector* outputs, + RunMetadata* run_metadata) override; + + absl::Status PRunSetup(const std::vector& input_names, + const std::vector& output_names, + const std::vector& target_nodes, + string* handle) override; + + absl::Status PRun(const string& handle, + const std::vector >& inputs, + const std::vector& output_names, + std::vector* outputs) override; + + absl::Status MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle) override; + + absl::Status RunCallable(CallableHandle handle, + const std::vector& feed_tensors, + std::vector* fetch_tensors, + RunMetadata* run_metadata) override; + + absl::Status ReleaseCallable(CallableHandle handle) override; private: mutex run_lock_; @@ -83,7 +86,7 @@ class SessionRef : public Session { // Borrowed reference to global session logger. SessionLogger* logger_; - Status CheckNotClosed(); + absl::Status CheckNotClosed(); }; } // namespace tensorflow diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index 585d770d7c735e..18bf61c37c031c 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -88,7 +88,7 @@ void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle, PyObject* value; Py_ssize_t pos = 0; int index = 0; - Status s; + absl::Status s; while (PyDict_Next(feed_dict, &pos, &key, &value)) { char* key_string = PyBytes_AsString(key); @@ -195,7 +195,7 @@ void MakeCallableHelper(tensorflow::Session* session, return; } tensorflow::Session::CallableHandle handle; - Status s = session->MakeCallable(callable_options_proto, &handle); + absl::Status s = session->MakeCallable(callable_options_proto, &handle); if (!s.ok()) { tsl::Set_TF_Status_from_Status(out_status, s); return; @@ -221,7 +221,7 @@ void RunCallableHelper(tensorflow::Session* session, int64_t handle, PyObjectVector* out_values, TF_Buffer* run_metadata) { // Convert feed values to a vector of tensorflow::Tensor objects. std::vector input_tensors; - Status s; + absl::Status s; { feed_values = PySequence_Fast(feed_values, "feed_values must be a sequence"); @@ -369,7 +369,7 @@ void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle, DCHECK_EQ(inputs.size(), input_ndarrays.size()); DCHECK(py_outputs != nullptr); DCHECK(py_outputs->empty()); - Status s; + absl::Status s; // Convert input ndarray PyObjects to TF_Tensors. We maintain a continuous // array of TF_Tensor*s as well as scoped containers to make sure they're @@ -722,7 +722,7 @@ PyObject* TF_TryEvaluateConstant_wrapper(TF_Graph* graph, TF_Output output, Safe_TF_TensorPtr safe_result_tensor(result_tensor); PyObject* out; - Status s = TF_TensorToPyArray(std::move(safe_result_tensor), &out); + absl::Status s = TF_TensorToPyArray(std::move(safe_result_tensor), &out); tsl::Set_TF_Status_from_Status(status, s); if (!s.ok()) Py_RETURN_NONE; return PyArray_Return(reinterpret_cast(out)); diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 6f2b1be39c6ebd..7f61e00eef8e60 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 9, 19) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 10, 27) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None diff --git a/tensorflow/python/compiler/tensorrt/model_tests/model_handler.py b/tensorflow/python/compiler/tensorrt/model_tests/model_handler.py index 4e5244d6f0037f..c12bc56f44abd4 100644 --- a/tensorflow/python/compiler/tensorrt/model_tests/model_handler.py +++ b/tensorflow/python/compiler/tensorrt/model_tests/model_handler.py @@ -664,7 +664,7 @@ class ModelHandlerManagerV1(_ModelHandlerManagerBase): class ModelHandlerManagerV2(_ModelHandlerManagerBase): - """Manages a series of ModelHandlers for aggregrated testing/benchmarking in TF2.""" + """Manages a series of ModelHandlers for aggregated testing/benchmarking in TF2.""" model_handler_cls = ModelHandlerV2 trt_model_handler_cls = TrtModelHandlerV2 diff --git a/tensorflow/python/compiler/tensorrt/utils.py b/tensorflow/python/compiler/tensorrt/utils.py index a908f920b14996..7eadddae2f1517 100644 --- a/tensorflow/python/compiler/tensorrt/utils.py +++ b/tensorflow/python/compiler/tensorrt/utils.py @@ -242,7 +242,7 @@ def draw_graphdef_as_graphviz(graphdef, dot_output_filename): print(" }", file=f) - # Step 3: Alignement of the legend with the graph. + # Step 3: Alignment of the legend with the graph. print("\n edge[style=\"invisible\", dir=\"none\"];", file=f) for dtype in dtype_index.keys(): for node_name in nodes_with_no_inputs: diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD index 34c1098ccae8fc..ac6a8752f8db7c 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD @@ -196,6 +196,7 @@ tf_py_strict_test( size = "medium", srcs = ["optimization_test.py"], shard_count = 2, + tags = ["nomsan"], # Runs out of memory. deps = [ "//tensorflow/python/data/experimental/ops:batching", "//tensorflow/python/data/experimental/ops:grouping", diff --git a/tensorflow/python/data/kernel_tests/dataset_test.py b/tensorflow/python/data/kernel_tests/dataset_test.py index 13bcb2656149cd..adb11faebcda90 100644 --- a/tensorflow/python/data/kernel_tests/dataset_test.py +++ b/tensorflow/python/data/kernel_tests/dataset_test.py @@ -252,7 +252,7 @@ def testDebugString(self): dataset = dataset.filter(lambda x: x > 10) debug_string = dataset.__debug_string__() for transformation in ["Range", "Map", "Filter"]: - self.assertContainsSubsequence(debug_string, transformation) + self.assertIn(transformation, debug_string) @combinations.generate(test_base.default_test_combinations()) def testNoWarnings(self): diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index fe5c9cf394ef7e..a7960d75d0e0ca 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -1813,6 +1813,7 @@ distribute_py_strict_test( "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:test", + "//tensorflow/python/framework:config", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:indexed_slices", diff --git a/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py b/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py index 4af2925ebd8650..6e69ba14152cf2 100644 --- a/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py +++ b/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py @@ -144,7 +144,7 @@ def connect(tpu=None, """ resolver = TPUClusterResolver(tpu, zone, project) remote.connect_to_cluster(resolver) - tpu_strategy_util.initialize_tpu_system_impl(resolver) + tpu_strategy_util.initialize_tpu_system_impl(resolver, TPUClusterResolver) return resolver @staticmethod diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py index c7dac643cefb37..8899a7d2c04cd6 100644 --- a/tensorflow/python/distribute/collective_all_reduce_strategy.py +++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py @@ -19,6 +19,7 @@ import time import weakref +from xla.tsl.protobuf import coordination_config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.core.protobuf import tensorflow_server_pb2 from tensorflow.python.distribute import collective_util @@ -49,7 +50,6 @@ from tensorflow.python.trackable import base from tensorflow.python.util import deprecation from tensorflow.python.util.tf_export import tf_export -from tsl.protobuf import coordination_config_pb2 # pylint: disable=line-too-long diff --git a/tensorflow/python/distribute/parameter_server_strategy_v2.py b/tensorflow/python/distribute/parameter_server_strategy_v2.py index d6bdcd0ac13d6f..9457fe576638b2 100644 --- a/tensorflow/python/distribute/parameter_server_strategy_v2.py +++ b/tensorflow/python/distribute/parameter_server_strategy_v2.py @@ -21,6 +21,7 @@ import os import threading +from xla.tsl.protobuf import coordination_config_pb2 from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_lib @@ -50,7 +51,6 @@ from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect from tensorflow.python.util.tf_export import tf_export -from tsl.protobuf import coordination_config_pb2 ALLOWED_TASK_TYPES = ("chief", "worker", "ps") diff --git a/tensorflow/python/distribute/vars_test.py b/tensorflow/python/distribute/vars_test.py index 5dd2c5a3b1ae4a..4cf07ddd13d958 100644 --- a/tensorflow/python/distribute/vars_test.py +++ b/tensorflow/python/distribute/vars_test.py @@ -31,6 +31,7 @@ from tensorflow.python.eager import context from tensorflow.python.eager import def_function from tensorflow.python.eager import test +from tensorflow.python.framework import config from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import indexed_slices @@ -654,7 +655,7 @@ def scatter_update(v): @combinations.generate(ms_combination + tpu_combination) def testScatterOpsWithNoneAggregation(self, distribution): - + config.disable_mlir_bridge() def assert_close(v, op, delta, expect): scatter_op = getattr(v, op) diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 61e64a594a42cb..ded5af927f70b7 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -26,6 +26,7 @@ from absl import logging import numpy as np +from xla.tsl.protobuf import coordination_config_pb2 from tensorflow.core.framework import function_pb2 from tensorflow.core.framework import graph_debug_info_pb2 from tensorflow.core.protobuf import config_pb2 @@ -46,7 +47,6 @@ from tensorflow.python.util import tf_contextlib from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.tf_export import tf_export -from tsl.protobuf import coordination_config_pb2 # TODO(b/307794935): Remove after a solution is found. @@ -1425,7 +1425,8 @@ def _compute_device_options(self, device_type="GPU"): ) visible_device_list = [] virtual_devices = [] - gpu_index = -1 + # This mapping is needed to handle multiple sub types of PluggableDevices. + device_to_indices = collections.defaultdict(int) memory_growths = set() compatible_devices = ( self.list_physical_devices("GPU") @@ -1434,14 +1435,18 @@ def _compute_device_options(self, device_type="GPU"): ) support_virtual_devices = device_type == "GPU" for dev in compatible_devices: - gpu_index += 1 + device_index = device_to_indices[dev.device_type] + device_to_indices[dev.device_type] += 1 if dev not in self._visible_device_list: continue growth = self._memory_growth_map[dev] memory_growths.add(growth) - visible_device_list.append(str(gpu_index)) + if device_type == "PluggableDevice": + visible_device_list.append(dev.device_type + ":" + str(device_index)) + else: + visible_device_list.append(str(device_index)) if support_virtual_devices and self._virtual_device_map: vdevs = self._virtual_device_map.get(dev, []) diff --git a/tensorflow/python/eager/polymorphic_function/compiler_ir.py b/tensorflow/python/eager/polymorphic_function/compiler_ir.py index 5bb959fb9e9697..97f4ed586f86dd 100644 --- a/tensorflow/python/eager/polymorphic_function/compiler_ir.py +++ b/tensorflow/python/eager/polymorphic_function/compiler_ir.py @@ -117,6 +117,8 @@ def compiler_ir_generator(stage="hlo", device_name=None, platform_name=None): stage=stage, ) if stage in ( + # Ordered by IrExportStage enum order + "stablehlo_serialized", "hlo_serialized", "optimized_hlo_serialized", "optimized_hlo_proto_serialized", diff --git a/tensorflow/python/eager/polymorphic_function/compiler_ir_test.py b/tensorflow/python/eager/polymorphic_function/compiler_ir_test.py index deaa18bfddf770..ff48ff2920f714 100644 --- a/tensorflow/python/eager/polymorphic_function/compiler_ir_test.py +++ b/tensorflow/python/eager/polymorphic_function/compiler_ir_test.py @@ -48,6 +48,17 @@ def _compareTwoMethodsCompilerIROutput(self, f, args, kwargs): f' \nhlo(concrete_input):\n{hlo_1}\nhlo(tensor_spec):\n{hlo_2}\n' ) + # Check that StableHLO conversion succeeds + hlo_3 = f.experimental_get_compiler_ir(*args, **kwargs)(stage='stablehlo') + self.assertIn('stablehlo', hlo_3) + + # Check that StableHLO bytecode conversion succeeds. + # MLIR bytecode files all begin with magic `MLiR` byte, check for byte. + hlo_4 = f.experimental_get_compiler_ir(*args, **kwargs)( + stage='stablehlo_serialized' + ) + self.assertIn(b'ML\xefR', hlo_4) + def test_zero_input(self): with ops.device('device:{}:0'.format(self.device)): diff --git a/tensorflow/python/eager/polymorphic_function/polymorphic_function.py b/tensorflow/python/eager/polymorphic_function/polymorphic_function.py index e504410c8d5b15..99c8fc453e818b 100644 --- a/tensorflow/python/eager/polymorphic_function/polymorphic_function.py +++ b/tensorflow/python/eager/polymorphic_function/polymorphic_function.py @@ -1029,8 +1029,13 @@ def compiler_ir_generator( captured_inputs=concrete_fn.captured_inputs, stage=stage, ) - if stage in ("hlo_serialized", "optimized_hlo_serialized", - "optimized_hlo_proto_serialized"): + if stage in ( + # Ordered by IrExportStage enum order + "stablehlo_serialized", + "hlo_serialized", + "optimized_hlo_serialized", + "optimized_hlo_proto_serialized", + ): return res_bytes else: return res_bytes.decode("utf-8") diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index e04ad38b81d4c9..5bfa389e92a08a 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -744,7 +744,7 @@ static PyObject* EagerTensor_prefer_custom_summarizer(EagerTensor* self) { // not include a shape or dtype. static PyObject* EagerTensor_summarize_value(EagerTensor* self) { std::string summary; - tensorflow::Status status = + absl::Status status = tensorflow::unwrap(self->handle)->SummarizeValue(summary); if (MaybeRaiseExceptionFromStatus(status, nullptr)) { return nullptr; diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h index 67b9ddf2fb045e..13eb0c9b9ae950 100755 --- a/tensorflow/python/eager/pywrap_tfe.h +++ b/tensorflow/python/eager/pywrap_tfe.h @@ -112,7 +112,7 @@ int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception); // Returns 0 if 'status' is ok. Otherwise, raises an exception (using // `exception` if not nullptr, else using the class registered via // TFE_Py_RegisterExceptionClass), and returns -1. -int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status, +int MaybeRaiseExceptionFromStatus(const absl::Status& status, PyObject* exception); // Returns the string associated with the passed-in python object. diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index bb22079755d1d5..b33cb2ad729201 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -1039,7 +1039,7 @@ void RaiseFallbackException(const char* message) { // Format and return `status`' error message with the attached stack trace if // available. `status` must have an error. -std::string FormatErrorStatusStackTrace(const tensorflow::Status& status) { +std::string FormatErrorStatusStackTrace(const absl::Status& status) { tensorflow::DCheckPyGilState(); DCHECK(!status.ok()); @@ -1116,7 +1116,7 @@ int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) { } // namespace tensorflow -int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status, +int MaybeRaiseExceptionFromStatus(const absl::Status& status, PyObject* exception) { if (status.ok()) return 0; const char* msg = absl::StatusMessageAsCStr(status); @@ -1269,7 +1269,7 @@ class PyVSpace : public tensorflow::eager::VSpace(tensor)); - PyObject* result = PyEval_CallObject(num_elements_, arglist); + PyObject* result = PyObject_Call(num_elements_, arglist, nullptr); Py_DECREF(arglist); if (result == nullptr) { // The caller detects whether a python exception has been raised. @@ -1342,7 +1342,7 @@ class PyVSpace : public tensorflow::eager::VSpace& unneeded_gradients, absl::Span output_gradients, @@ -2205,7 +2205,7 @@ static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) { tensorflow::TensorShape tensor_shape; int num_dims; - tensorflow::Status status = handle->NumDims(&num_dims); + absl::Status status = handle->NumDims(&num_dims); if (status.ok()) { for (int i = 0; i < num_dims; ++i) { int64_t dim_size; @@ -2485,7 +2485,7 @@ bool TapeSetRecordForwardprop( input_info.push_back(TapeTensorFromTensor(input_seq_array[i])); } for (TFE_Py_ForwardAccumulator* accumulator : accumulator_set) { - tensorflow::Status status = accumulator->accumulator->Accumulate( + absl::Status status = accumulator->accumulator->Accumulate( op_type, input_info, output_info, input_ids, input_dtypes, forward_function, backward_function_getter, backward_function_killer); if (PyErr_Occurred()) return false; // Don't swallow Python exceptions. @@ -2519,8 +2519,8 @@ PyObject* TangentsAsPyTuple(const std::vector& input_tangents) { return py_input_tangents; } -tensorflow::Status ParseTangentOutputs( - PyObject* user_output, std::vector* output_tangents) { +absl::Status ParseTangentOutputs(PyObject* user_output, + std::vector* output_tangents) { if (user_output == Py_None) { // No connected gradients. return absl::OkStatus(); @@ -2551,11 +2551,11 @@ tensorflow::Status ParseTangentOutputs( // // `op_name`, `attrs`, `inputs`, and `results` describe the operation for which // the forward function is being called. -tensorflow::Status CallJVPFunction(PyObject* op_name, PyObject* attrs, - PyObject* inputs, PyObject* results, - const std::vector& input_tangents, - std::vector* output_tangents, - bool use_batch) { +absl::Status CallJVPFunction(PyObject* op_name, PyObject* attrs, + PyObject* inputs, PyObject* results, + const std::vector& input_tangents, + std::vector* output_tangents, + bool use_batch) { if (forward_gradient_function == nullptr) { return tensorflow::errors::Internal( "No forward gradient function registered."); @@ -2581,7 +2581,7 @@ tensorflow::Status CallJVPFunction(PyObject* op_name, PyObject* attrs, // Like CallJVPFunction, but calls a pre-bound forward function. // These are passed in from a record_gradient argument. -tensorflow::Status CallOpSpecificJVPFunction( +absl::Status CallOpSpecificJVPFunction( PyObject* op_specific_forward_function, const std::vector& input_tangents, std::vector* output_tangents) { diff --git a/tensorflow/python/flags_pybind.pyi b/tensorflow/python/flags_pybind.pyi index 7c450b682a40a8..7cd1ac17d52a78 100644 --- a/tensorflow/python/flags_pybind.pyi +++ b/tensorflow/python/flags_pybind.pyi @@ -22,6 +22,7 @@ class Flags: enable_aggressive_constant_replication: Flag enable_colocation_key_propagation_in_while_op_lowering: Flag enable_function_pruning_before_inlining: Flag + enable_graph_debug_info_caching_for_stack_frames: Flag enable_nested_function_shape_inference: Flag enable_quantized_dtypes_training: Flag enable_skip_encapsulation_for_non_tpu_graphs: Flag diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index b19955efe7f682..029b59a119a58d 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -16,9 +16,9 @@ load( "tf_gen_op_wrapper_py", "tf_kernel_library", ) -load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_py_strict_test", "tf_python_framework_friends", "tf_python_pybind_extension") +load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "stripped_cc_info", "tf_py_strict_test", "tf_python_framework_friends", "tf_python_pybind_extension") load("//tensorflow/core/platform:build_config.bzl", "pyx_library", "tf_additional_all_protos", "tf_additional_lib_deps", "tf_proto_library", "tf_protos_grappler") # @unused -load("//tensorflow/core/platform:build_config_root.bzl", "if_static", "tf_additional_xla_deps_py") +load("//tensorflow/core/platform:build_config_root.bzl", "if_pywrap", "if_static", "tf_additional_xla_deps_py") load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_strict_test") load( "//tensorflow/tools/test:performance.bzl", @@ -33,8 +33,8 @@ package( licenses = ["notice"], ) -tf_cc_shared_object( - name = "test_file_system.so", +cc_library( + name = "test_file_system", srcs = ["test_file_system.cc"], copts = if_not_windows(["-Wno-sign-compare"]), linkopts = select({ @@ -45,12 +45,32 @@ tf_cc_shared_object( "//tensorflow:windows": [], }), deps = [ - "//tensorflow/core:framework_headers_lib", - "@com_google_protobuf//:protobuf_headers", - "@eigen_archive//:eigen3", + "//tensorflow/core:lib", + "//tensorflow/core/platform:null_file_system", ], ) +stripped_cc_info( + name = "test_file_system_stripped", + deps = [":test_file_system"], +) + +tf_cc_shared_object( + name = "test_file_system.so", + srcs = if_pywrap(if_false = ["test_file_system.cc"]), + copts = if_not_windows(["-Wno-sign-compare"]), + deps = if_pywrap( + if_false = [ + "@eigen_archive//:eigen3", + "//tensorflow/core:framework_headers_lib", + ], + if_true = [ + ":test_file_system_stripped", + "//tensorflow/python:_pywrap_tensorflow_common", + ], + ) + ["@com_google_protobuf//:protobuf_headers"], +) + tf_py_strict_test( name = "file_system_test", size = "small", @@ -656,8 +676,10 @@ tf_python_pybind_extension( pytype_srcs = [ "_op_def_library_pybind.pyi", ], - deps = [ - ":op_def_util_headers", + deps = if_pywrap( + if_false = [":op_def_util_headers"], + if_true = [":op_def_util_cc"], + ) + [ "//tensorflow/core:framework_headers_lib", "//tensorflow/core:lib_headers_for_pybind", "//tensorflow/core:lib_proto_parsing", @@ -796,8 +818,10 @@ tf_python_pybind_extension( pytype_srcs = [ "_op_def_util.pyi", ], - deps = [ - ":op_def_util_headers", + deps = if_pywrap( + if_false = [":op_def_util_headers"], + if_true = [":op_def_util_cc"], + ) + [ "//tensorflow/core:lib_headers_for_pybind", "//tensorflow/core:protos_all_cc", "//tensorflow/python/lib/core:pybind11_lib", @@ -852,10 +876,6 @@ tf_python_pybind_extension( name = "_pywrap_python_api_parameter_converter", srcs = ["python_api_parameter_converter_wrapper.cc"], hdrs = [ - "op_def_util.h", - "python_api_info.h", - "python_api_parameter_converter.h", - "python_tensor_converter.h", "//tensorflow/c:headers", "//tensorflow/c/eager:pywrap_required_hdrs", "//tensorflow/c/experimental/ops:pywrap_required_hdrs", @@ -868,7 +888,14 @@ tf_python_pybind_extension( "@local_xla//xla/tsl/distributed_runtime:pywrap_required_hdrs", "@local_xla//xla/tsl/distributed_runtime/coordination:pywrap_required_hdrs", "@local_xla//xla/tsl/python/lib/core:numpy_hdr", - ], + ] + if_pywrap( + if_false = [ + "op_def_util.h", + "python_api_info.h", + "python_api_parameter_converter.h", + "python_tensor_converter.h", + ], + ), enable_stub_generation = True, pytype_srcs = [ "_pywrap_python_api_parameter_converter.pyi", @@ -890,18 +917,25 @@ tf_python_pybind_extension( "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@pybind11", - ] + if_static( + ] + if_pywrap( + if_true = [ + ":python_api_parameter_converter", + ":python_api_info", + ":op_def_util_cc", + ":python_tensor_converter", + ], + ) + if_static( extra_deps = [ "//tensorflow/core/protobuf:eager_service_proto_cc", "//tensorflow/core/protobuf:master_proto_cc", "//tensorflow/core/protobuf:worker_proto_cc", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc", ], otherwise = [ "//tensorflow/core/protobuf:eager_service_proto_cc_headers_only", "//tensorflow/core/protobuf:master_proto_cc_headers_only", "//tensorflow/core/protobuf:worker_proto_cc_headers_only", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc_headers_only", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc_headers_only", ], ), ) @@ -953,9 +987,6 @@ tf_python_pybind_extension( name = "_pywrap_python_api_info", srcs = ["python_api_info_wrapper.cc"], hdrs = [ - "op_def_util.h", - "python_api_info.h", - "python_tensor_converter.h", "//tensorflow/c:headers", "//tensorflow/c/eager:pywrap_required_hdrs", "//tensorflow/c/experimental/ops:pywrap_required_hdrs", @@ -968,7 +999,13 @@ tf_python_pybind_extension( "@local_xla//xla/tsl/distributed_runtime:pywrap_required_hdrs", "@local_xla//xla/tsl/distributed_runtime/coordination:pywrap_required_hdrs", "@local_xla//xla/tsl/python/lib/core:numpy_hdr", - ], + ] + if_pywrap( + if_false = [ + "op_def_util.h", + "python_api_info.h", + "python_tensor_converter.h", + ], + ), enable_stub_generation = True, pytype_srcs = [ "_pywrap_python_api_info.pyi", @@ -990,15 +1027,21 @@ tf_python_pybind_extension( "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@pybind11", - ] + if_static( + ] + if_pywrap( + if_true = [ + ":python_api_info", + ":op_def_util_cc", + ":python_tensor_converter", + ], + ) + if_static( extra_deps = [ - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc", "//tensorflow/core/protobuf:eager_service_proto_cc", "//tensorflow/core/protobuf:master_proto_cc", "//tensorflow/core/protobuf:worker_proto_cc", ], otherwise = [ - "@local_tsl//tsl/protobuf:coordination_service_proto_cc_headers_only", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc_headers_only", "//tensorflow/core/protobuf:eager_service_proto_cc_headers_only", "//tensorflow/core/protobuf:master_proto_cc_headers_only", "//tensorflow/core/protobuf:worker_proto_cc_headers_only", @@ -1064,7 +1107,7 @@ tf_python_pybind_extension( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@pybind11", - ], + ] + if_pywrap([":python_api_dispatcher"]), ) tf_py_strict_test( @@ -1143,13 +1186,13 @@ tf_python_pybind_extension( "@pybind11", ] + if_static( extra_deps = [ - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc", "//tensorflow/core/protobuf:eager_service_proto_cc", "//tensorflow/core/protobuf:master_proto_cc", "//tensorflow/core/protobuf:worker_proto_cc", ], otherwise = [ - "@local_tsl//tsl/protobuf:coordination_service_proto_cc_headers_only", + "@local_xla//xla/tsl/protobuf:coordination_service_proto_cc_headers_only", "//tensorflow/core/protobuf:eager_service_proto_cc_headers_only", "//tensorflow/core/protobuf:master_proto_cc_headers_only", "//tensorflow/core/protobuf:worker_proto_cc_headers_only", @@ -2153,6 +2196,7 @@ pytype_strict_library( "//tensorflow/python/util:compat", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:nest", + "//tensorflow/python/util:numpy_compat", "//tensorflow/python/util:tf_decorator_py", "//tensorflow/python/util:tf_export", "//tensorflow/python/util:tf_inspect", @@ -3313,7 +3357,6 @@ tf_python_pybind_extension( # py_proto_library( # name = "cpp_shape_inference_proto_py_pb2", # has_services = 0, -# api_version = 2, # deps = [":cpp_shape_inference_proto"], # ) # copybara:uncomment_end diff --git a/tensorflow/python/framework/experimental/BUILD b/tensorflow/python/framework/experimental/BUILD index 2d7a8f11129a7f..273cf42c4e132c 100644 --- a/tensorflow/python/framework/experimental/BUILD +++ b/tensorflow/python/framework/experimental/BUILD @@ -2,6 +2,7 @@ load("//tensorflow:strict.default.bzl", "py_strict_library") load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_python_pybind_extension") +load("//tensorflow/core/platform:build_config_root.bzl", "if_pywrap") load( "//tensorflow/tools/test:performance.bzl", "cuda_py_benchmark_test", @@ -20,6 +21,9 @@ tf_python_pybind_extension( pytype_srcs = [ "_unified_api.pyi", ], + visibility = [ + "//tensorflow/python:__pkg__", + ], deps = [ "//tensorflow/c/eager:tfe_tensorhandle_internal", "//tensorflow/core:lib", @@ -35,6 +39,9 @@ tf_python_pybind_extension( name = "_tape", srcs = ["tape.cc"], features = ["-layering_check"], + visibility = [ + "//tensorflow/python:__pkg__", + ], deps = [ "//tensorflow/c/eager:tfe_tensorhandle_internal", "//tensorflow/core:lib", @@ -43,7 +50,13 @@ tf_python_pybind_extension( "//tensorflow/python:unified_api_pywrap_required_headers", "//tensorflow/python/lib/core:pybind11_lib", "@pybind11", - ], + ] + if_pywrap( + if_true = [ + "//tensorflow/c/experimental/gradients/tape:tape_context", + "//tensorflow/c/experimental/gradients:math_grad", + "//tensorflow/c/experimental/gradients:nn_grad", + ], + ), ) tf_python_pybind_extension( @@ -53,6 +66,9 @@ tf_python_pybind_extension( pytype_srcs = [ "_math_ops.pyi", ], + visibility = [ + "//tensorflow/python:__pkg__", + ], deps = [ "//tensorflow/c/eager:tfe_tensorhandle_internal", "//tensorflow/core:framework", @@ -63,7 +79,11 @@ tf_python_pybind_extension( "//tensorflow/python/lib/core:pybind11_lib", "@com_google_absl//absl/types:span", "@pybind11", - ], + ] + if_pywrap( + if_true = [ + "//tensorflow/c/experimental/ops:math_ops", + ], + ), ) tf_python_pybind_extension( @@ -73,6 +93,9 @@ tf_python_pybind_extension( pytype_srcs = [ "_nn_ops.pyi", ], + visibility = [ + "//tensorflow/python:__pkg__", + ], deps = [ "//tensorflow/c/eager:tfe_tensorhandle_internal", "//tensorflow/core:framework", @@ -83,7 +106,11 @@ tf_python_pybind_extension( "//tensorflow/python/lib/core:pybind11_lib", "@com_google_absl//absl/types:span", "@pybind11", - ], + ] + if_pywrap( + if_true = [ + "//tensorflow/c/experimental/ops:nn_ops", + ], + ), ) py_strict_library( diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index c2b769a99d6f8e..46f981df64b6c6 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -100,6 +100,7 @@ from tensorflow.python.util import tf_inspect from tensorflow.python.util import traceback_utils from tensorflow.python.util.compat import collections_abc +from tensorflow.python.util.numpy_compat import np_where from tensorflow.python.util.protobuf import compare from tensorflow.python.util.tf_export import tf_export @@ -3248,11 +3249,11 @@ def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): np.abs(a - b) > atol + rtol * np.abs(b), np.isnan(a) != np.isnan(b)) if a.ndim: - x = a[np.where(cond)] - y = b[np.where(cond)] - msgs.append("not close where = {}".format(np.where(cond))) + x = a[np_where(cond)] + y = b[np_where(cond)] + msgs.append("not close where = {}".format(np_where(cond))) else: - # np.where is broken for scalars + # np_where is broken for scalars x, y = a, b msgs.append("not close lhs = {}".format(x)) msgs.append("not close rhs = {}".format(y)) @@ -3479,11 +3480,11 @@ def assertAllEqual(self, a, b, msg=None): # Adds more details to np.testing.assert_array_equal. diff = np.logical_not(same) if a.ndim: - x = a[np.where(diff)] - y = b[np.where(diff)] - msgs.append("not equal where = {}".format(np.where(diff))) + x = a[np_where(diff)] + y = b[np_where(diff)] + msgs.append("not equal where = {}".format(np_where(diff))) else: - # np.where is broken for scalars + # np_where is broken for scalars x, y = a, b msgs.append("not equal lhs = %r" % x) msgs.append("not equal rhs = %r" % y) @@ -3583,7 +3584,7 @@ def _format_subscripts(self, subscripts, value, limit=10, indent=2): Args: subscripts: The tensor (np.ndarray) subscripts, of the same format as - np.where()'s return value, i.e., a tuple of arrays with each array + np_where()'s return value, i.e., a tuple of arrays with each array corresponding to a dimension. E.g., (array([1, 1]), array([0, 1])). value: (np.ndarray) value of the tensor. limit: (int) The maximum number of indices to print. @@ -3639,7 +3640,7 @@ def assertAllInRange(self, "The value of %s does not have an ordered numeric type, instead it " "has type: %s" % (target, target.dtype)) - nan_subscripts = np.where(np.isnan(target)) + nan_subscripts = np_where(np.isnan(target)) if np.size(nan_subscripts): raise AssertionError( "%d of the %d element(s) are NaN. " @@ -3657,7 +3658,7 @@ def assertAllInRange(self, violations, np.greater_equal(target, upper_bound) if open_upper_bound else np.greater(target, upper_bound)) - violation_subscripts = np.where(violations) + violation_subscripts = np_where(violations) if np.size(violation_subscripts): raise AssertionError( "%d of the %d element(s) are outside the range %s. " % diff --git a/tensorflow/python/grappler/BUILD b/tensorflow/python/grappler/BUILD index 366ebfa1927674..1e1d643602b5ba 100644 --- a/tensorflow/python/grappler/BUILD +++ b/tensorflow/python/grappler/BUILD @@ -2,6 +2,7 @@ load("//tensorflow:strict.default.bzl", "py_strict_binary", "py_strict_library") load("//tensorflow:tensorflow.bzl", "if_not_windows") load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "get_compatible_with_portable", "tf_py_strict_test", "tf_pybind_cc_library_wrapper", "tf_python_pybind_extension") load("//tensorflow/core/platform:build_config.bzl", "tf_protos_grappler") +load("//tensorflow/core/platform:build_config_root.bzl", "if_pywrap") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -96,7 +97,7 @@ tf_python_pybind_extension( "//tensorflow/core:protos_all_cc", "//tensorflow/python/lib/core:pybind11_status", "@pybind11", - ], + ] + if_pywrap(["//tensorflow/python/grappler:model_analyzer_lib"]), ) py_strict_library( @@ -201,11 +202,14 @@ tf_python_pybind_extension( srcs = ["cluster_wrapper.cc"], hdrs = [ "//tensorflow/cc:pywrap_required_hdrs", - "//tensorflow/core/grappler:pywrap_required_hdrs", - "//tensorflow/core/grappler/clusters:pywrap_required_hdrs", - "//tensorflow/core/grappler/costs:pywrap_required_hdrs", - "//tensorflow/core/grappler/utils:pywrap_required_hdrs", - ], + ] + if_pywrap( + if_false = [ + "//tensorflow/core/grappler:pywrap_required_hdrs", + "//tensorflow/core/grappler/clusters:pywrap_required_hdrs", + "//tensorflow/core/grappler/costs:pywrap_required_hdrs", + "//tensorflow/core/grappler/utils:pywrap_required_hdrs", + ], + ), enable_stub_generation = True, pytype_srcs = [ "_pywrap_tf_cluster.pyi", @@ -219,7 +223,12 @@ tf_python_pybind_extension( "//tensorflow/python/lib/core:pybind11_status", "@com_google_absl//absl/types:span", "@pybind11", - ], + ] + if_pywrap( + if_true = [ + "//tensorflow/core/grappler/costs:measuring_cost_estimator", + "//tensorflow/core/grappler/clusters:single_machine", + ], + ), ) cuda_py_strict_test( @@ -265,14 +274,16 @@ py_strict_library( tf_python_pybind_extension( name = "_pywrap_tf_optimizer", srcs = ["tf_optimizer_wrapper.cc"], - hdrs = [ - "//tensorflow/cc:pywrap_required_hdrs", - "//tensorflow/core/grappler:pywrap_required_hdrs", - "//tensorflow/core/grappler/clusters:pywrap_required_hdrs", - "//tensorflow/core/grappler/costs:pywrap_required_hdrs", - "//tensorflow/core/grappler/optimizers:pywrap_required_hdrs", - "//tensorflow/core/grappler/verifiers:pywrap_required_hdrs", - ], + hdrs = if_pywrap( + if_false = [ + "//tensorflow/cc:pywrap_required_hdrs", + "//tensorflow/core/grappler:pywrap_required_hdrs", + "//tensorflow/core/grappler/clusters:pywrap_required_hdrs", + "//tensorflow/core/grappler/costs:pywrap_required_hdrs", + "//tensorflow/core/grappler/optimizers:pywrap_required_hdrs", + "//tensorflow/core/grappler/verifiers:pywrap_required_hdrs", + ], + ), enable_stub_generation = True, pytype_srcs = [ "_pywrap_tf_optimizer.pyi", @@ -293,7 +304,16 @@ tf_python_pybind_extension( "//tensorflow/python/lib/core:pybind11_status", "@pybind11", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", - ], + ] + if_pywrap( + if_true = [ + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core/grappler/clusters:utils", + "//tensorflow/core/grappler:grappler_item_builder", + "//tensorflow/core/grappler/optimizers:meta_optimizer", + "//tensorflow/core/grappler/optimizers:graph_optimizer", + "//tensorflow/core/grappler/verifiers:graph_verifier", + ], + ), ) tf_py_strict_test( diff --git a/tensorflow/python/grappler/cluster_wrapper.cc b/tensorflow/python/grappler/cluster_wrapper.cc index 05403e76bd1a40..dbf97535082413 100644 --- a/tensorflow/python/grappler/cluster_wrapper.cc +++ b/tensorflow/python/grappler/cluster_wrapper.cc @@ -49,12 +49,12 @@ limitations under the License. namespace py = pybind11; -tensorflow::Status _GetOpPerformanceDataAndRunTime( +absl::Status _GetOpPerformanceDataAndRunTime( const tensorflow::grappler::GrapplerItem& item, tensorflow::grappler::CostEstimator* cost_measure, tensorflow::OpPerformanceList* op_performance_data, tensorflow::grappler::Costs* costs) { - tensorflow::Status status = cost_measure->Initialize(item); + absl::Status status = cost_measure->Initialize(item); if (!status.ok()) return status; tensorflow::RunMetadata run_metadata; @@ -159,7 +159,7 @@ PYBIND11_MODULE(_pywrap_tf_cluster, m) { tensorflow::grappler::GrapplerItem* item) -> std::unordered_map> { if (cluster == nullptr || item == nullptr) { - tsl::MaybeRaiseRegisteredFromStatus(tensorflow::Status( + tsl::MaybeRaiseRegisteredFromStatus(absl::Status( tensorflow::errors::Internal("You need both a cluster and an " "item to get supported devices."))); } @@ -184,7 +184,7 @@ PYBIND11_MODULE(_pywrap_tf_cluster, m) { } else { // Check the kernel capabilities const tensorflow::DeviceType dev_type(type); - tensorflow::Status s = + absl::Status s = tensorflow::FindKernelDef(dev_type, node, nullptr, nullptr); if (s.ok()) { supported_device_types[node.name()].insert(type); @@ -193,7 +193,7 @@ PYBIND11_MODULE(_pywrap_tf_cluster, m) { // TODO: extends this to support outputs as well tensorflow::MemoryTypeVector inp_mtypes; tensorflow::MemoryTypeVector out_mtypes; - tensorflow::Status s = tensorflow::MemoryTypesForNode( + absl::Status s = tensorflow::MemoryTypesForNode( tensorflow::OpRegistry::Global(), dev_type, node, &inp_mtypes, &out_mtypes); if (s.ok()) { @@ -261,7 +261,7 @@ PYBIND11_MODULE(_pywrap_tf_cluster, m) { tensorflow::OpPerformanceList op_performance_data; tensorflow::grappler::Costs costs; - tensorflow::Status s = _GetOpPerformanceDataAndRunTime( + absl::Status s = _GetOpPerformanceDataAndRunTime( *item, &cost_measure, &op_performance_data, &costs); double run_time = FLT_MAX; if (s.ok()) { @@ -298,7 +298,7 @@ PYBIND11_MODULE(_pywrap_tf_cluster, m) { std::tuple>> { if (item == nullptr || cluster == nullptr) { tsl::MaybeRaiseRegisteredFromStatus( - tensorflow::Status(tensorflow::errors::Internal( + absl::Status(tensorflow::errors::Internal( "You need both a cluster and an item to determine peak " "memory usage."))); } diff --git a/tensorflow/python/grappler/cost_analyzer.cc b/tensorflow/python/grappler/cost_analyzer.cc index 709a140588f9a6..90f9b426d3756c 100644 --- a/tensorflow/python/grappler/cost_analyzer.cc +++ b/tensorflow/python/grappler/cost_analyzer.cc @@ -31,8 +31,8 @@ CostAnalyzer::CostAnalyzer(const GrapplerItem& item, Cluster* cluster, /*use_aggressive_shape_inference=*/true), suffix_(suffix) {} -Status CostAnalyzer::GenerateReport(std::ostream& os, bool per_node_report, - bool verbose) { +absl::Status CostAnalyzer::GenerateReport(std::ostream& os, + bool per_node_report, bool verbose) { GatherCosts(); PreprocessCosts(); AnalyzeCosts(); @@ -45,7 +45,7 @@ void CostAnalyzer::PredictCosts(CostEstimator* cost_estimator, TF_CHECK_OK(cost_estimator->Initialize(*item_)); RunMetadata run_metadata; Costs costs; - const Status status = + const absl::Status status = cost_estimator->PredictCosts(item_->graph, &run_metadata, &costs); if (cost_graph) { cost_graph->Swap(run_metadata.mutable_cost_graph()); diff --git a/tensorflow/python/grappler/cost_analyzer.h b/tensorflow/python/grappler/cost_analyzer.h index 0410aa3e781bdc..44e1e45265b9c5 100644 --- a/tensorflow/python/grappler/cost_analyzer.h +++ b/tensorflow/python/grappler/cost_analyzer.h @@ -51,7 +51,8 @@ class CostAnalyzer { public: explicit CostAnalyzer(const GrapplerItem& item, Cluster* cluster, const string& suffix); - Status GenerateReport(std::ostream& os, bool per_node_report, bool verbose); + absl::Status GenerateReport(std::ostream& os, bool per_node_report, + bool verbose); private: void PredictCosts(CostEstimator* cost_estimator, CostGraphDef* cost_graph, diff --git a/tensorflow/python/grappler/model_analyzer.cc b/tensorflow/python/grappler/model_analyzer.cc index f2899ea62d4d4a..202eb758a91221 100644 --- a/tensorflow/python/grappler/model_analyzer.cc +++ b/tensorflow/python/grappler/model_analyzer.cc @@ -26,8 +26,8 @@ namespace grappler { ModelAnalyzer::ModelAnalyzer(const GrapplerItem& item) : item_(item) {} -Status ModelAnalyzer::GenerateReport(bool debug, bool assume_valid_feeds, - std::ostream& os) { +absl::Status ModelAnalyzer::GenerateReport(bool debug, bool assume_valid_feeds, + std::ostream& os) { GraphProperties properties(item_); TF_RETURN_IF_ERROR(properties.InferStatically(assume_valid_feeds)); @@ -80,7 +80,8 @@ void ModelAnalyzer::PrintNodeInfo(const NodeDef* node, if (debug) { const OpRegistrationData* op_reg_data; - Status status = OpRegistry::Global()->LookUp(node->op(), &op_reg_data); + absl::Status status = + OpRegistry::Global()->LookUp(node->op(), &op_reg_data); if (!status.ok()) { os << "\tCouldn't find op registration for " << node->op() << std::endl; } else if (!op_reg_data->shape_inference_fn) { diff --git a/tensorflow/python/grappler/model_analyzer.h b/tensorflow/python/grappler/model_analyzer.h index 9764a75b29ac8c..d66ad8915c99b5 100644 --- a/tensorflow/python/grappler/model_analyzer.h +++ b/tensorflow/python/grappler/model_analyzer.h @@ -31,7 +31,8 @@ class GraphProperties; class ModelAnalyzer { public: explicit ModelAnalyzer(const GrapplerItem& item); - Status GenerateReport(bool debug, bool assume_valid_feeds, std::ostream& os); + absl::Status GenerateReport(bool debug, bool assume_valid_feeds, + std::ostream& os); private: void PrintNodeInfo(const NodeDef* node, const GraphProperties& properties, diff --git a/tensorflow/python/grappler/remapper_test.py b/tensorflow/python/grappler/remapper_test.py index a759310eb63a4c..addfdef41278b4 100644 --- a/tensorflow/python/grappler/remapper_test.py +++ b/tensorflow/python/grappler/remapper_test.py @@ -168,6 +168,15 @@ def gelu_approximate(x): def gelu_exact(x): return nn.gelu(x, approximate=False) + # Erfc-based implementation of GeluExact from: + # https://github.com/tensorflow/tensorflow/pull/76174 + def gelu_exact_erfc(x): + return ( + 0.5 + * x + * math_ops.erfc(-x * math_ops.cast(0.7071067811865476, x.dtype)) + ) + device = '/device:GPU:0' if mode == 'cuda' else '/device:CPU:0' config = [] use_fp16 = True @@ -180,6 +189,7 @@ def gelu_exact(x): use_fp16 = False if mode == 'mkl': config.append((dtypes.float32, gelu_exact, b'GeluExact')) + config.append((dtypes.float32, gelu_exact_erfc, b'GeluExact')) config.append((dtypes.float32, gelu_approximate, b'GeluApproximate')) if _pywrap_utils.IsDataTypeSupportedByOneDNNOnThisCPU(dtypes.bfloat16): config.append((dtypes.bfloat16, gelu_approximate, b'GeluApproximate')) diff --git a/tensorflow/python/keras/protobuf/BUILD b/tensorflow/python/keras/protobuf/BUILD index 73db8516220ffb..17825a4aeefe1d 100644 --- a/tensorflow/python/keras/protobuf/BUILD +++ b/tensorflow/python/keras/protobuf/BUILD @@ -31,19 +31,16 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "saved_metadata_proto_py_pb2", -# api_version = 2, # deps = [":saved_metadata_proto"], # ) # # py_proto_library( # name = "projector_config_proto_py_pb2", -# api_version = 2, # deps = [":projector_config_proto"], # ) # # py_proto_library( # name = "versions_proto_py_pb2", -# api_version = 2, # deps = [":versions_proto"], # ) # copybara:uncomment_end diff --git a/tensorflow/python/kernel_tests/array_ops/BUILD b/tensorflow/python/kernel_tests/array_ops/BUILD index 1a8d26ef226f7c..73f07f5b95b81b 100644 --- a/tensorflow/python/kernel_tests/array_ops/BUILD +++ b/tensorflow/python/kernel_tests/array_ops/BUILD @@ -430,6 +430,7 @@ cuda_py_strict_test( "//tensorflow/python/ops:variable_scope", "//tensorflow/python/ops:variables", "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/util:numpy_compat", "//third_party/py/numpy", ], ) @@ -479,6 +480,7 @@ cuda_py_strict_test( "//tensorflow/python/ops:manip_ops", "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", + "@pypi_packaging//:pkg", ], ) diff --git a/tensorflow/python/kernel_tests/array_ops/init_ops_test.py b/tensorflow/python/kernel_tests/array_ops/init_ops_test.py index 460b8f8e064e2c..9c35c61a605cc0 100644 --- a/tensorflow/python/kernel_tests/array_ops/init_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops/init_ops_test.py @@ -30,6 +30,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.util.numpy_compat import np_where # Returns true iff the two initializers produce the same tensor to @@ -714,7 +715,7 @@ def _baseNDArrayCompareToNumpy(self, axis): self.assert_close(actual, expected) def assert_close(self, actual, expected): - wrong_indices = np.where(~np.allclose(actual, expected)) + wrong_indices = np_where(~np.allclose(actual, expected)) mess = "Wrong float answer. Wrong indices: {}".format(wrong_indices) self.assertTrue(np.allclose(actual, expected), mess) diff --git a/tensorflow/python/kernel_tests/array_ops/manip_ops_test.py b/tensorflow/python/kernel_tests/array_ops/manip_ops_test.py index 35e2c3c0f86e36..382d3ebc2956bc 100644 --- a/tensorflow/python/kernel_tests/array_ops/manip_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops/manip_ops_test.py @@ -14,6 +14,7 @@ # ============================================================================== """Tests for manip_ops.""" import numpy as np +from packaging.version import Version from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -24,14 +25,8 @@ from tensorflow.python.ops import manip_ops from tensorflow.python.platform import test as test_lib -# pylint: disable=g-import-not-at-top -try: - from distutils.version import StrictVersion as Version - # numpy.roll for multiple shifts was introduced in numpy version 1.12.0 - NP_ROLL_CAN_MULTISHIFT = Version(np.version.version) >= Version("1.12.0") -except ImportError: - NP_ROLL_CAN_MULTISHIFT = False -# pylint: enable=g-import-not-at-top +# numpy.roll for multiple shifts was introduced in numpy version 1.12.0 +NP_ROLL_CAN_MULTISHIFT = Version(np.version.version) >= Version("1.12.0") class RollTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD index ea1fa6566c626e..3f6dfc1fe3e6a3 100644 --- a/tensorflow/python/kernel_tests/linalg/BUILD +++ b/tensorflow/python/kernel_tests/linalg/BUILD @@ -213,7 +213,10 @@ cuda_py_strict_test( size = "medium", srcs = ["linear_operator_block_lower_triangular_test.py"], shard_count = 8, - tags = ["optonly"], + tags = [ + "no_gpu", # Seg fault. http://b/365525243 + "optonly", + ], deps = [ "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", diff --git a/tensorflow/python/kernel_tests/proto/BUILD b/tensorflow/python/kernel_tests/proto/BUILD index 10bc4c327c5a59..129797173211ed 100644 --- a/tensorflow/python/kernel_tests/proto/BUILD +++ b/tensorflow/python/kernel_tests/proto/BUILD @@ -3,10 +3,10 @@ load("//tensorflow:strict.default.bzl", "py_strict_library") # Placeholder: load py_proto_library -load("//tensorflow:tensorflow.bzl", "tf_cc_shared_object") -load("//tensorflow:tensorflow.default.bzl", "tf_py_strict_test") +load("//tensorflow:tensorflow.bzl", "if_oss", "tf_cc_shared_object") +load("//tensorflow:tensorflow.default.bzl", "stripped_cc_info", "tf_py_strict_test") load("//tensorflow/core/platform:build_config.bzl", "tf_additional_all_protos", "tf_proto_library") -load("//tensorflow/core/platform:build_config_root.bzl", "if_static") +load("//tensorflow/core/platform:build_config_root.bzl", "if_pywrap") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -22,10 +22,7 @@ tf_py_strict_test( name = "decode_proto_op_test", size = "small", srcs = ["decode_proto_op_test.py"], - data = if_static( - [], - otherwise = [":libtestexample.so"], - ), + data = if_oss([":libtestexample.so"]), python_version = "PY3", tags = [ "no_pip", # TODO(b/78026780) @@ -43,10 +40,7 @@ tf_py_strict_test( name = "encode_proto_op_test", size = "small", srcs = ["encode_proto_op_test.py"], - data = if_static( - [], - otherwise = [":libtestexample.so"], - ), + data = if_oss([":libtestexample.so"]), python_version = "PY3", tags = [ "no_pip", # TODO(b/78026780) @@ -118,9 +112,20 @@ tf_proto_library( tf_cc_shared_object( name = "libtestexample.so", linkstatic = 1, - deps = [ - ":test_example_proto_cc", - ], + deps = if_pywrap( + if_false = [ + ":test_example_proto_cc", + ], + if_true = [ + "//tensorflow/python:_pywrap_tensorflow_common", + ":test_example_proto_cc_stripped", + ], + ), +) + +stripped_cc_info( + name = "test_example_proto_cc_stripped", + deps = [":test_example_proto_cc"], ) py_strict_library( @@ -155,7 +160,6 @@ tf_py_strict_test( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "test_example_proto_py", -# api_version = 2, # deps = [":test_example_proto"], # ) # copybara:uncomment_end diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index d1b7986c0a998e..799c3a121c5217 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -214,7 +214,7 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) { CHECK(args); // Invokes the trampoline. - PyObject* result = PyEval_CallObject(trampoline, args); + PyObject* result = PyObject_Call(trampoline, args, nullptr); Py_DECREF(args); Status s = OkStatus(); if (result == nullptr) { diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index 1c81b35e48cc5e..2b6c2e289918a9 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -617,9 +617,9 @@ tstring PyRepr(PyObject* obj) { bool IsPyDimension(PyObject* obj) { const char* tp_name = obj->ob_type->tp_name; if (strcmp(tp_name, "Dimension") != 0) return false; - bool ret = str_util::EndsWith( - PyRepr(PyType(obj)), - "tensorflow.python.framework.tensor_shape.Dimension'>"); + bool ret = + absl::EndsWith(PyRepr(PyType(obj)), + "tensorflow.python.framework.tensor_shape.Dimension'>"); return ret; } diff --git a/tensorflow/python/mlir_wrapper.cc b/tensorflow/python/mlir_wrapper.cc index 8fe71b77ef11c9..4846a089b35ddf 100644 --- a/tensorflow/python/mlir_wrapper.cc +++ b/tensorflow/python/mlir_wrapper.cc @@ -16,6 +16,7 @@ limitations under the License. #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/pytypes.h" // from @pybind11 #include "pybind11/stl.h" // from @pybind11 +#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/safe_ptr.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/compiler/mlir/python/mlir.h" diff --git a/tensorflow/python/ops/BUILD b/tensorflow/python/ops/BUILD index a6486af0837f67..ab4a5f2f5a5941 100644 --- a/tensorflow/python/ops/BUILD +++ b/tensorflow/python/ops/BUILD @@ -1093,6 +1093,7 @@ cuda_py_strict_test( name = "compiled_collective_ops_gpu_test", size = "small", srcs = ["compiled_collective_ops_gpu_test.py"], + env = {"TF_FORCE_GPU_ALLOW_GROWTH": "true"}, main = "compiled_collective_ops_gpu_test.py", python_version = "PY3", tags = [ @@ -3676,6 +3677,7 @@ cuda_py_strict_test( "//tensorflow/python/platform:test", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", + "@pypi_packaging//:pkg", ], ) diff --git a/tensorflow/python/ops/math_ops_linspace_test.py b/tensorflow/python/ops/math_ops_linspace_test.py index 9e2499cd4dc279..66cd91caaa9770 100644 --- a/tensorflow/python/ops/math_ops_linspace_test.py +++ b/tensorflow/python/ops/math_ops_linspace_test.py @@ -14,12 +14,9 @@ # ============================================================================== """Tests for tensorflow.ops.math_ops.linspace.""" -# Using distutils.version.LooseVersion was resulting in an error, so importing -# directly. -from distutils.version import LooseVersion # pylint: disable=g-importing-member - from absl.testing import parameterized import numpy as np +from packaging.version import Version # pylint: disable=g-importing-member from tensorflow.python.eager import def_function from tensorflow.python.framework import dtypes @@ -47,7 +44,7 @@ class LinspaceTest(test_util.TensorFlowTestCase, parameterized.TestCase): ]) # pylint: enable=g-complex-comprehension def testLinspaceBroadcasts(self, start_shape, stop_shape, dtype, num): - if LooseVersion(np.version.version) < LooseVersion("1.16.0"): + if Version(np.version.version) < Version("1.16.0"): self.skipTest("numpy doesn't support axes before version 1.16.0") ndims = max(len(start_shape), len(stop_shape)) diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index b44f675dff3878..a71c525c386b31 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -3748,8 +3748,13 @@ def gelu(features, approximate=False, name=None): if approximate: return gen_nn_ops.gelu(features, name=name) else: - return 0.5 * features * (1.0 + math_ops.erf( - features / math_ops.cast(1.4142135623730951, features.dtype))) + return ( + 0.5 + * features + * math_ops.erfc( + -features * math_ops.cast(0.7071067811865476, features.dtype) + ) + ) def _flatten_outer_dims(logits): diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 755a40f3bf661b..1af6a97cc20335 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -16,8 +16,13 @@ import functools import math +<<<<<<< HEAD import os import time +======= +import sys + +>>>>>>> master from absl.testing import parameterized import numpy as np @@ -1063,11 +1068,12 @@ def gelu(x, approximate=False): from scipy.stats import norm # pylint: disable=g-import-not-at-top return x * norm.cdf(x) - np.random.seed(1) # Make it reproducible. - x = np.random.randn(3, 4).astype(np.float32) + # Make sure we test for negative arguments where GeLU is difficult to + # evaluate accurately. + x = np.linspace(-12, 5, 1000).astype(np.float32) y = gelu(x) z = self.evaluate(nn_ops.gelu(x)) - self.assertAllClose(y, z) + self.assertAllClose(y, z, atol=0, rtol=2e-5) y = gelu(x, True) z = self.evaluate(nn_ops.gelu(x, True)) @@ -1750,6 +1756,27 @@ def testIncorrectSizeInput(self): "`input.shape.rank` must be 3, 4 or 5.*of rank 6."): nn_ops.max_pool_v2(x, 2, 2, "SAME") + @test_util.disable_xla("XLA catches the error and rethrows as different one") + def testIncoorectKSize(self): + with self.assertRaisesRegex( + errors.InvalidArgumentError, "Sliding window ksize must be positive." + ): + op = nn_ops.max_pool_v2( + array_ops.ones([3, 4, 4, 5]), [1, -1, -1, 1], 2, "SAME" + ) + with test_util.use_gpu(): + self.evaluate(op) + + ksize = sys.maxsize + 100 # Set to a value larger than sys.maxsize + with self.assertRaises( + OverflowError if context.executing_eagerly() else ValueError + ): + op = nn_ops.max_pool_v2( + array_ops.ones([3, 4, 4, 5]), ksize=ksize, strides=2, padding="SAME" + ) + with test_util.use_gpu(): + self.evaluate(op) + @test_util.run_all_in_graph_and_eager_modes class ConvolutionTest(test_lib.TestCase): diff --git a/tensorflow/python/ops/numpy_ops/tests/BUILD b/tensorflow/python/ops/numpy_ops/tests/BUILD index 5a69b0f25adc86..70c9b958895d71 100644 --- a/tensorflow/python/ops/numpy_ops/tests/BUILD +++ b/tensorflow/python/ops/numpy_ops/tests/BUILD @@ -222,6 +222,7 @@ py_strict_test( "//tensorflow/python/ops/numpy_ops:np_config", "//tensorflow/python/ops/numpy_ops:numpy", "//tensorflow/python/util:nest", + "//tensorflow/python/util:numpy_compat", "@absl_py//absl/testing:absltest", "@absl_py//absl/testing:parameterized", ], diff --git a/tensorflow/python/ops/numpy_ops/tests/np_test.py b/tensorflow/python/ops/numpy_ops/tests/np_test.py index 6e2499960f4849..37c869db58fc48 100644 --- a/tensorflow/python/ops/numpy_ops/tests/np_test.py +++ b/tensorflow/python/ops/numpy_ops/tests/np_test.py @@ -35,6 +35,7 @@ import tensorflow.python.ops.numpy_ops.tests.np_wrapper as tnp import tensorflow.python.ops.numpy_ops.tests.test_util as jtu from tensorflow.python.util import nest +from tensorflow.python.util.numpy_compat import np_where config.parse_flags_with_absl() @@ -683,7 +684,7 @@ def testCountNonzero(self, shape, dtype, axis): for shape in all_shapes for dtype in all_dtypes)) def testNonzero(self, shape, dtype): rng = jtu.rand_some_zero() - onp_fun = lambda x: onp.nonzero(x) # pylint: disable=unnecessary-lambda + onp_fun = lambda x: onp.nonzero(onp.atleast_1d(x)) # pylint: disable=unnecessary-lambda lnp_fun = lambda x: tnp.nonzero(x) # pylint: disable=unnecessary-lambda args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False) @@ -2338,7 +2339,7 @@ def onp_fun(*args): for shape in all_shapes for dtype in all_dtypes)) def testWhereOneArgument(self, shape, dtype): rng = jtu.rand_some_zero() - onp_fun = lambda x: onp.where(x) + onp_fun = lambda x: np_where(x) lnp_fun = lambda x: tnp.where(x) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=False) diff --git a/tensorflow/python/ops/numpy_ops/tests/test_util.py b/tensorflow/python/ops/numpy_ops/tests/test_util.py index 43a03379a5de24..fa8db8965a49a7 100644 --- a/tensorflow/python/ops/numpy_ops/tests/test_util.py +++ b/tensorflow/python/ops/numpy_ops/tests/test_util.py @@ -14,7 +14,6 @@ # ============================================================================== """NumPy test utilities.""" from contextlib import contextmanager -from distutils.util import strtobool import functools from functools import partial import re @@ -50,6 +49,14 @@ FLAGS = flags.FLAGS +# https://danielms.site/zet/2023/pythons-distutil-strtobool-replacement/ +def strtobool(value: str) -> bool: + value = value.lower() + if value in ('y', 'yes', 'on', '1', 'true', 't'): + return True + return False + + # TODO(wangpeng): Remove this flag after broken tests are fixed flags.DEFINE_bool('enable_x64', strtobool('False'), diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD index 5b66904cc4b6cd..3e47c991e0247b 100644 --- a/tensorflow/python/ops/ragged/BUILD +++ b/tensorflow/python/ops/ragged/BUILD @@ -237,6 +237,7 @@ py_strict_library( "//tensorflow/python/framework:tensor_shape", "//tensorflow/python/ops:array_ops", "//tensorflow/python/util:dispatch", + "//tensorflow/python/util:numpy_compat", "//tensorflow/python/util:tf_export", "//third_party/py/numpy", ], @@ -547,6 +548,7 @@ py_strict_library( "//tensorflow/python/util:deprecation", "//tensorflow/python/util:dispatch", "//tensorflow/python/util:nest", + "//tensorflow/python/util:numpy_compat", "//tensorflow/python/util:tf_decorator_py", "//tensorflow/python/util:tf_export", "//tensorflow/python/util:tf_inspect", diff --git a/tensorflow/python/ops/ragged/ragged_factory_ops.py b/tensorflow/python/ops/ragged/ragged_factory_ops.py index 215304c867507c..55505df533d447 100644 --- a/tensorflow/python/ops/ragged/ragged_factory_ops.py +++ b/tensorflow/python/ops/ragged/ragged_factory_ops.py @@ -26,6 +26,7 @@ from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor_value from tensorflow.python.util import dispatch +from tensorflow.python.util.numpy_compat import np_reshape from tensorflow.python.util.tf_export import tf_export @@ -151,9 +152,9 @@ def _ragged_factory(values, row_splits): def _inner_factory(pylist, dtype, shape, name=None): # pylint: disable=unused-argument if dtype is object or dtype is None: - return np.reshape(np.array(pylist, dtype=dtype), shape) + return np_reshape(np.array(pylist, dtype=dtype), shape) else: - return np.reshape(np.array(pylist).astype(dtype), shape) + return np_reshape(np.array(pylist).astype(dtype), shape) return _constant_value( _ragged_factory, _inner_factory, pylist, dtype, ragged_rank, inner_shape diff --git a/tensorflow/python/ops/ragged/ragged_tensor.py b/tensorflow/python/ops/ragged/ragged_tensor.py index a92d425a4c748e..e03a788e527cd1 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor.py +++ b/tensorflow/python/ops/ragged/ragged_tensor.py @@ -1487,7 +1487,7 @@ def merge_dims(self, outer_axis, inner_axis): tf.Tensor([1 2 3 4 5 6], shape=(6,), dtype=int32) To mimic the behavior of `np.flatten` (which flattens all dimensions), use - `rt.merge_dims(0, -1). To mimic the behavior of `tf.layers.Flatten` (which + `rt.merge_dims(0, -1)`. To mimic the behavior of `tf.layers.Flatten` (which flattens all dimensions except the outermost batch dimension), use `rt.merge_dims(1, -1)`. diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index edcdbf134a8490..a046d9b7ed8f6e 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -857,7 +857,8 @@ def read_and_set_handle(no_copy): and self._xla_sharding is not None ): sharding_string = self._xla_sharding.SerializeToString() - result = gen_xla_ops.xla_sharding(result, sharding=sharding_string) + with ops.colocate_with(result): + result = gen_xla_ops.xla_sharding(result, sharding=sharding_string) # pylint: disable=protected-access result.op._set_attr( "_XlaSharding", diff --git a/tensorflow/python/profiler/internal/BUILD b/tensorflow/python/profiler/internal/BUILD index 8e07f478198d0d..a64bfd314d5742 100644 --- a/tensorflow/python/profiler/internal/BUILD +++ b/tensorflow/python/profiler/internal/BUILD @@ -107,6 +107,7 @@ tf_python_pybind_extension( ], visibility = [ "//perftools/accelerators/xprof/xprofilez/integration_tests:__pkg__", + "//tensorflow/python:__pkg__", "//tensorflow/python/profiler:__subpackages__", "//tensorflow/tools/pip_package:__subpackages__", ], @@ -126,6 +127,7 @@ tf_python_pybind_extension( ], visibility = [ "//tensorflow/core/profiler:internal", + "//tensorflow/python:__pkg__", "//tensorflow/python/eager:__pkg__", "//tensorflow/python/profiler:__pkg__", "//tensorflow/tools/pip_package:__subpackages__", @@ -179,8 +181,8 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_absl//absl/types:variant", - "@local_tsl//tsl/profiler/utils:session_manager", "@local_xla//xla/tsl/profiler/convert:xplane_to_trace_events", "@local_xla//xla/tsl/profiler/rpc/client:capture_profile", + "@local_xla//xla/tsl/profiler/utils:session_manager", ], ) diff --git a/tensorflow/python/profiler/internal/profiler_pywrap_impl.cc b/tensorflow/python/profiler/internal/profiler_pywrap_impl.cc index 2a7af8ccac312c..f02ca2aa50c9a2 100644 --- a/tensorflow/python/profiler/internal/profiler_pywrap_impl.cc +++ b/tensorflow/python/profiler/internal/profiler_pywrap_impl.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/types/variant.h" #include "xla/tsl/profiler/convert/xplane_to_trace_events.h" #include "xla/tsl/profiler/rpc/client/capture_profile.h" +#include "xla/tsl/profiler/utils/session_manager.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" @@ -38,7 +39,6 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/profiler/rpc/client/save_profile.h" #include "tensorflow/core/profiler/rpc/profiler_server.h" -#include "tsl/profiler/utils/session_manager.h" namespace tensorflow { namespace profiler { @@ -47,7 +47,7 @@ namespace pywrap { using tsl::profiler::GetRemoteSessionManagerOptionsLocked; using tsl::profiler::ValidateHostPortPair; -tensorflow::Status Trace( +absl::Status Trace( const char* service_addr, const char* logdir, const char* worker_list, bool include_dataset_ops, int duration_ms, int num_tracing_attempts, const absl::flat_hash_map>& @@ -57,9 +57,9 @@ tensorflow::Status Trace( num_tracing_attempts, options); } -tensorflow::Status Monitor(const char* service_addr, int duration_ms, - int monitoring_level, bool display_timestamp, - tensorflow::string* result) { +absl::Status Monitor(const char* service_addr, int duration_ms, + int monitoring_level, bool display_timestamp, + tensorflow::string* result) { TF_RETURN_IF_ERROR(ValidateHostPortPair(service_addr)); { TF_RETURN_IF_ERROR(tsl::profiler::Monitor(service_addr, duration_ms, @@ -69,7 +69,7 @@ tensorflow::Status Monitor(const char* service_addr, int duration_ms, return absl::OkStatus(); } -tensorflow::Status ProfilerSessionWrapper::Start( +absl::Status ProfilerSessionWrapper::Start( const char* logdir, const absl::flat_hash_map>& options) { @@ -79,10 +79,10 @@ tensorflow::Status ProfilerSessionWrapper::Start( return session_->Status(); } -tensorflow::Status ProfilerSessionWrapper::Stop(tensorflow::string* result) { +absl::Status ProfilerSessionWrapper::Stop(tensorflow::string* result) { if (session_ != nullptr) { tensorflow::profiler::XSpace xspace; - tensorflow::Status status = session_->CollectData(&xspace); + absl::Status status = session_->CollectData(&xspace); session_.reset(); tsl::profiler::ConvertXSpaceToTraceEventsString(xspace, result); TF_RETURN_IF_ERROR(status); @@ -90,12 +90,12 @@ tensorflow::Status ProfilerSessionWrapper::Stop(tensorflow::string* result) { return absl::OkStatus(); } -tensorflow::Status ProfilerSessionWrapper::ExportToTensorBoard() { +absl::Status ProfilerSessionWrapper::ExportToTensorBoard() { if (!session_ || logdir_.empty()) { return absl::OkStatus(); } tensorflow::profiler::XSpace xspace; - tensorflow::Status status; + absl::Status status; status = session_->CollectData(&xspace); session_.reset(); status = tsl::profiler::ExportToTensorBoard(xspace, logdir_); diff --git a/tensorflow/python/profiler/internal/profiler_pywrap_impl.h b/tensorflow/python/profiler/internal/profiler_pywrap_impl.h index 11d7e2abd4ae12..e8bda8c9f012f6 100644 --- a/tensorflow/python/profiler/internal/profiler_pywrap_impl.h +++ b/tensorflow/python/profiler/internal/profiler_pywrap_impl.h @@ -27,24 +27,24 @@ namespace tensorflow { namespace profiler { namespace pywrap { -tensorflow::Status Trace( +absl::Status Trace( const char* service_addr, const char* logdir, const char* worker_list, bool include_dataset_ops, int duration_ms, int num_tracing_attempts, const absl::flat_hash_map>& options); -tensorflow::Status Monitor(const char* service_addr, int duration_ms, - int monitoring_level, bool display_timestamp, - tensorflow::string* result); +absl::Status Monitor(const char* service_addr, int duration_ms, + int monitoring_level, bool display_timestamp, + tensorflow::string* result); class ProfilerSessionWrapper { public: - tensorflow::Status Start( + absl::Status Start( const char* logdir, const absl::flat_hash_map>& options); - tensorflow::Status Stop(tensorflow::string* result); - tensorflow::Status ExportToTensorBoard(); + absl::Status Stop(tensorflow::string* result); + absl::Status ExportToTensorBoard(); private: std::unique_ptr session_; diff --git a/tensorflow/python/profiler/internal/profiler_wrapper.cc b/tensorflow/python/profiler/internal/profiler_wrapper.cc index d44fc9a39c6e6c..dfca50bdce6891 100644 --- a/tensorflow/python/profiler/internal/profiler_wrapper.cc +++ b/tensorflow/python/profiler/internal/profiler_wrapper.cc @@ -67,7 +67,7 @@ PYBIND11_MODULE(_pywrap_profiler, m) { .def("start", [](ProfilerSessionWrapper& wrapper, const char* logdir, const py::dict& options) { - tensorflow::Status status; + absl::Status status; ToolOptions tool_options = ToolOptionsFromPythonDict(options); { py::gil_scoped_release release; @@ -79,7 +79,7 @@ PYBIND11_MODULE(_pywrap_profiler, m) { .def("stop", [](ProfilerSessionWrapper& wrapper) { tensorflow::string content; - tensorflow::Status status; + absl::Status status; { py::gil_scoped_release release; status = wrapper.Stop(&content); @@ -90,7 +90,7 @@ PYBIND11_MODULE(_pywrap_profiler, m) { return py::bytes(content); }) .def("export_to_tb", [](ProfilerSessionWrapper& wrapper) { - tensorflow::Status status; + absl::Status status; { py::gil_scoped_release release; status = wrapper.ExportToTensorBoard(); @@ -112,7 +112,7 @@ PYBIND11_MODULE(_pywrap_profiler, m) { [](const char* service_addr, const char* logdir, const char* worker_list, bool include_dataset_ops, int duration_ms, int num_tracing_attempts, py::dict options) { - tensorflow::Status status; + absl::Status status; ToolOptions tool_options = ToolOptionsFromPythonDict(options); { py::gil_scoped_release release; @@ -127,7 +127,7 @@ PYBIND11_MODULE(_pywrap_profiler, m) { m.def("monitor", [](const char* service_addr, int duration_ms, int monitoring_level, bool display_timestamp) { tensorflow::string content; - tensorflow::Status status; + absl::Status status; { py::gil_scoped_release release; status = tensorflow::profiler::pywrap::Monitor( diff --git a/tensorflow/python/proto_exports.py b/tensorflow/python/proto_exports.py index c414936539df3d..34475ffb3a15f7 100644 --- a/tensorflow/python/proto_exports.py +++ b/tensorflow/python/proto_exports.py @@ -14,6 +14,7 @@ # ============================================================================== """Registers protos with tf_export that should be public.""" +from xla.tsl.protobuf import histogram_pb2 from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.framework import node_def_pb2 from tensorflow.core.framework import summary_pb2 @@ -21,7 +22,6 @@ from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.util import event_pb2 from tensorflow.python.util import tf_export -from tsl.protobuf import histogram_pb2 AttrValue = tf_export.tf_export(v1=['AttrValue'])(attr_value_pb2.AttrValue) ConfigProto = tf_export.tf_export(v1=['ConfigProto'])(config_pb2.ConfigProto) diff --git a/tensorflow/python/protobuf_inline_symbols_enforcer.cc b/tensorflow/python/protobuf_inline_symbols_enforcer.cc new file mode 100644 index 00000000000000..24beeeb70fd4f6 --- /dev/null +++ b/tensorflow/python/protobuf_inline_symbols_enforcer.cc @@ -0,0 +1,91 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/protobuf/config.pb.h" +#include "tensorflow/core/protobuf/data_service.pb.h" +#include "tensorflow/core/protobuf/device_properties.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/protobuf/service_config.pb.h" +#include "tensorflow/dtensor/proto/layout.pb.h" +#include "tsl/profiler/protobuf/xplane.pb.h" + +namespace tensorflow { +namespace python { +void protobuf_inline_symbols_enforcer() { + tensorflow::NamedDevice named_device; + named_device.mutable_properties(); + named_device.properties(); + + tensorflow::NamedDevice named_device_move(std::move(named_device)); + named_device_move.mutable_properties(); + + tensorflow::quantization::ExportedModel exported_model; + exported_model.function_aliases(); + + tensorflow::profiler::XSpace x_space; + x_space.mutable_hostnames(); + x_space.mutable_hostnames(0); + + tensorflow::dtensor::LayoutProto layout_proto; + layout_proto.GetDescriptor(); + layout_proto.GetReflection(); + layout_proto.default_instance(); + + tensorflow::dtensor::MeshProto mesh_proto; + mesh_proto.GetDescriptor(); + mesh_proto.GetReflection(); + mesh_proto.default_instance(); + + tensorflow::FunctionDef function_def; + function_def.descriptor(); + function_def.GetDescriptor(); + function_def.GetReflection(); + function_def.default_instance(); + + tensorflow::FunctionDefLibrary function_def_library; + function_def_library.descriptor(); + + tensorflow::GraphDef graph_def; + graph_def.descriptor(); + graph_def.GetDescriptor(); + graph_def.GetReflection(); + graph_def.default_instance(); + + tensorflow::MetaGraphDef meta_graph_def; + meta_graph_def.GetDescriptor(); + meta_graph_def.GetReflection(); + meta_graph_def.default_instance(); + + tensorflow::AttrValue attr_value; + attr_value.default_instance(); + + tensorflow::ConfigProto config_proto; + config_proto.default_instance(); + + tensorflow::data::experimental::DispatcherConfig dispatcher_config; + dispatcher_config.default_instance(); + + tensorflow::data::experimental::WorkerConfig worker_config; + worker_config.default_instance(); + + tensorflow::data::DataServiceMetadata data_service_metadata; +} +} // namespace python +} // namespace tensorflow diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index e909b79797f022..49e50cb380a735 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -803,6 +803,7 @@ tf_py_strict_test( "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:extension_type", + "//tensorflow/python/framework:immutable_dict", "//tensorflow/python/framework:sparse_tensor", "//tensorflow/python/framework:tensor", "//tensorflow/python/framework:tensor_shape", @@ -928,6 +929,7 @@ tf_python_pybind_extension( "pywrap_saved_model/metrics.pyi", ], visibility = [ + "//tensorflow/python:__pkg__", "//tensorflow/python/checkpoint:__subpackages__", "//tensorflow/python/tpu:__pkg__", "//tensorflow/python/training:__subpackages__", diff --git a/tensorflow/python/saved_model/nested_structure_coder.py b/tensorflow/python/saved_model/nested_structure_coder.py index bce787087a978f..a5802d0d6442e3 100644 --- a/tensorflow/python/saved_model/nested_structure_coder.py +++ b/tensorflow/python/saved_model/nested_structure_coder.py @@ -186,7 +186,7 @@ class _DictCodec: """Codec for dicts.""" def can_encode(self, pyobj): - return isinstance(pyobj, dict) + return isinstance(pyobj, collections_abc.Mapping) def do_encode(self, dict_value, encode_fn): encoded_dict = struct_pb2.StructuredValue() diff --git a/tensorflow/python/saved_model/nested_structure_coder_test.py b/tensorflow/python/saved_model/nested_structure_coder_test.py index c2b9e12d437605..45d29b72adab3e 100644 --- a/tensorflow/python/saved_model/nested_structure_coder_test.py +++ b/tensorflow/python/saved_model/nested_structure_coder_test.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import extension_type +from tensorflow.python.framework import immutable_dict from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor from tensorflow.python.framework import tensor_shape @@ -82,6 +83,20 @@ def testEncodeDecodeDict(self): self.assertIsInstance(decoded["a"], int) self.assertEqual(structure, decoded) + def testEncodeDecodeImmutableDict(self): + structure = immutable_dict.ImmutableDict(dict(a=3, b=[7, 2.5])) + self.assertTrue(nested_structure_coder.can_encode(structure)) + encoded = nested_structure_coder.encode_structure(structure) + expected = struct_pb2.StructuredValue() + expected.dict_value.fields["a"].int64_value = 3 + list_value = expected.dict_value.fields["b"].list_value + list_value.values.add().int64_value = 7 + list_value.values.add().float64_value = 2.5 + self.assertEqual(expected, encoded) + decoded = nested_structure_coder.decode_proto(encoded) + self.assertIsInstance(decoded["a"], int) + self.assertEqual(structure, decoded) + def testEncodeDecodeTensorShape(self): structure = [tensor_shape.TensorShape([1, 2, 3]), "hello"] self.assertTrue(nested_structure_coder.can_encode(structure)) diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc index 36daad0d5ef925..4b9713c35c90d3 100644 --- a/tensorflow/python/tfe_wrapper.cc +++ b/tensorflow/python/tfe_wrapper.cc @@ -459,7 +459,11 @@ static py::bytes TFE_GetCompilerIr(py::handle& ctx, std::string s_stage(stage); IrExportStage selected_stage = [&] { - if (s_stage == "hlo") { + if (s_stage == "stablehlo") { + return IrExportStage::STABLEHLO; + } else if (s_stage == "stablehlo_serialized") { + return IrExportStage::STABLEHLO_SERIALIZED; + } else if (s_stage == "hlo") { return IrExportStage::HLO; } else if (s_stage == "hlo_no_metadata") { return IrExportStage::HLO_NO_METADATA; diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD index 324fcb8389757f..b9f0628f6387e8 100644 --- a/tensorflow/python/tools/BUILD +++ b/tensorflow/python/tools/BUILD @@ -3,6 +3,7 @@ load("//tensorflow:strict.default.bzl", "py_strict_binary", "py_strict_library", "py_strict_test") load("//tensorflow:tensorflow.bzl", "if_google", "if_xla_available", "tf_cc_test") +load("//tensorflow/core/platform:build_config_root.bzl", "if_pywrap") load("//tensorflow/python/tools:tools.bzl", "saved_model_compile_aot") package( @@ -97,7 +98,9 @@ py_strict_binary( srcs = ["freeze_graph.py"], python_version = "PY3", srcs_version = "PY3", - deps = [":freeze_graph_lib"], + deps = [":freeze_graph_lib"] + if_pywrap( + if_true = ["//tensorflow/python:_pywrap_tensorflow"], + ), ) py_strict_binary( @@ -356,7 +359,9 @@ py_strict_binary( srcs = ["saved_model_cli.py"], python_version = "PY3", srcs_version = "PY3", - deps = [":saved_model_cli_lib"], + deps = [":saved_model_cli_lib"] + if_pywrap( + if_true = ["//tensorflow/python:_pywrap_tensorflow"], + ), ) py_strict_library( @@ -471,7 +476,7 @@ py_strict_binary( "//tensorflow/python/trackable:autotrackable", "@absl_py//absl:app", "@absl_py//absl/flags", - ], + ] + if_pywrap(["//tensorflow/python:_pywrap_tensorflow"]), ) # copybara:comment_begin(oss-only) @@ -497,6 +502,7 @@ genrule( name = "create_models_for_aot_compile", outs = EMITTED_AOT_SAVE_MODEL_OBJECTS, cmd = ( + "PYWRAP_TARGET='//third_party/tensorflow/python:_pywrap_tensorflow' " + "$(location :make_aot_compile_models) --out_dir $(@D)" ), tags = ["no_rocm"], diff --git a/tensorflow/python/tools/tools.bzl b/tensorflow/python/tools/tools.bzl index 0ed7102674bae6..42a95ef19dafb3 100644 --- a/tensorflow/python/tools/tools.bzl +++ b/tensorflow/python/tools/tools.bzl @@ -132,6 +132,7 @@ def saved_model_compile_aot( "{}_makefile.inc".format(name), ], cmd = ( + "PYWRAP_TARGET='//third_party/tensorflow/python:_pywrap_tensorflow' " + "$(location {}) aot_compile_cpu ".format( clean_dep("//tensorflow/python/tools:saved_model_cli"), ) + diff --git a/tensorflow/python/tpu/BUILD b/tensorflow/python/tpu/BUILD index 6e690be7feb65a..4c857a6b724a8f 100644 --- a/tensorflow/python/tpu/BUILD +++ b/tensorflow/python/tpu/BUILD @@ -7,6 +7,7 @@ load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test") # Placeholder: load py_proto_library load("//tensorflow:tensorflow.default.bzl", "tf_py_strict_test", "tf_python_pybind_extension") load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library") +load("//tensorflow/core/platform:build_config_root.bzl", "if_pywrap") load("//tensorflow/python/tpu:tpu.bzl", "internal_create_sanitizer_settings", "tpu_py_strict_test") # Do not add anymore paths here. You do not need to be in the visibility list @@ -1013,7 +1014,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "tensor_tracer_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":tensor_tracer_proto"], # ) @@ -1028,7 +1028,6 @@ tf_python_pybind_extension( "_pywrap_sparse_core_layout.pyi", ], deps = [ - "//tensorflow/core/tpu/kernels:_pywrap_sparse_core_layout_header_only", "//tensorflow/python/lib/core:pybind11_lib", "//tensorflow/python/lib/core:pybind11_status", "//tensorflow/python/lib/core:pybind11_status_headers", @@ -1036,7 +1035,14 @@ tf_python_pybind_extension( "@pybind11", "@pybind11_abseil//pybind11_abseil:status_casters", "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", - ], + ] + if_pywrap( + if_false = [ + "//tensorflow/core/tpu/kernels:_pywrap_sparse_core_layout_header_only", + ], + if_true = [ + "//tensorflow/core/tpu/kernels:sparse_core_layout", + ], + ), ) tf_python_pybind_extension( diff --git a/tensorflow/python/tpu/profiler/BUILD b/tensorflow/python/tpu/profiler/BUILD index af8e275606eb53..4fa1b7d0531655 100644 --- a/tensorflow/python/tpu/profiler/BUILD +++ b/tensorflow/python/tpu/profiler/BUILD @@ -44,6 +44,7 @@ py_strict_library( "//tensorflow/python/profiler:profiler_v2", "@absl_py//absl:app", "@absl_py//absl/flags", + "@pypi_packaging//:pkg", ], ) @@ -55,5 +56,6 @@ py_strict_binary( deps = [ ":capture_tpu_profile_lib", "@absl_py//absl/flags", + "@pypi_packaging//:pkg", ], ) diff --git a/tensorflow/python/tpu/profiler/capture_tpu_profile.py b/tensorflow/python/tpu/profiler/capture_tpu_profile.py index 4e5563d4bc2bdf..b02ac8aa7f788f 100644 --- a/tensorflow/python/tpu/profiler/capture_tpu_profile.py +++ b/tensorflow/python/tpu/profiler/capture_tpu_profile.py @@ -19,15 +19,15 @@ from absl import app from absl import flags -from distutils.version import LooseVersion +from packaging.version import Version from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver as resolver -from tensorflow.python.profiler import profiler_client -from tensorflow.python.profiler import profiler_v2 as profiler from tensorflow.python.framework import errors from tensorflow.python.framework import versions from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.profiler import profiler_client +from tensorflow.python.profiler import profiler_v2 as profiler from tensorflow.python.tpu.profiler import version as profiler_version FLAGS = flags.FLAGS @@ -139,7 +139,7 @@ def main(unused_argv=None): print('TensorFlow version %s detected' % tf_version) print('Welcome to the Cloud TPU Profiler v%s' % profiler_version.__version__) - if LooseVersion(tf_version) < LooseVersion('2.2.0'): + if Version(tf_version) < Version('2.2.0'): sys.exit('You must install tensorflow >= 2.2.0 to use this plugin.') if not FLAGS.service_addr and not FLAGS.tpu: @@ -188,7 +188,7 @@ def main(unused_argv=None): gfile.MakeDirs(FLAGS.logdir) try: - if LooseVersion(tf_version) < LooseVersion('2.3.0'): + if Version(tf_version) < Version('2.3.0'): profiler_client.trace(service_addr, os.path.expanduser(FLAGS.logdir), duration_ms, workers_list, FLAGS.num_tracing_attempts) diff --git a/tensorflow/python/tpu/tpu_embedding_v3.py b/tensorflow/python/tpu/tpu_embedding_v3.py index 1536801f77a74c..c822ee9ddae177 100644 --- a/tensorflow/python/tpu/tpu_embedding_v3.py +++ b/tensorflow/python/tpu/tpu_embedding_v3.py @@ -84,6 +84,7 @@ class SparseCoreEmbeddingConfig: max_unique_ids_per_table: Optional[Dict[str, int]] = None allow_id_dropping: bool = False initialize_tables_on_host: bool = True + enable_fast_table_initialization: bool = False class EmbeddingPipeliningContext(control_flow_ops.ControlFlowContext): @@ -812,8 +813,17 @@ def _create_variables( ) def table_initialize_fn(shape, dtype, shard_info=None): + # If enable fast table initialization, we will initialize the table + # directly on the device and use the initializer from the first table. + if self._sparse_core_embedding_config.enable_fast_table_initialization: + return stacked_tables[0].initializer( + shape=(shard_info.shape[0], stacked_tables[0].dim), + dtype=dtype, + ) + # Concat all the tables along the first axis. concat_tensors = [] + # Temporary patch, we need to initialize tables with the SC level # sharding. Note that we need to ensure that the vocab size is divisible # by the global number of SC. diff --git a/tensorflow/python/training/BUILD b/tensorflow/python/training/BUILD index 5fe6311c9902f5..23460be0233c0d 100644 --- a/tensorflow/python/training/BUILD +++ b/tensorflow/python/training/BUILD @@ -781,7 +781,6 @@ tf_proto_library( # py_proto_library( # name = "checkpoint_state_py_pb2", # testonly = 0, -# api_version = 2, # deps = [":checkpoint_state"], # ) # copybara:uncomment_end diff --git a/tensorflow/python/training/server_lib_test.py b/tensorflow/python/training/server_lib_test.py index 2478ab64e6be80..7bfddf38185b5f 100644 --- a/tensorflow/python/training/server_lib_test.py +++ b/tensorflow/python/training/server_lib_test.py @@ -18,6 +18,7 @@ import numpy as np +from xla.tsl.protobuf import rpc_options_pb2 from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import tensorflow_server_pb2 @@ -35,7 +36,6 @@ from tensorflow.python.training import input as input_ops from tensorflow.python.training import queue_runner_impl from tensorflow.python.training import server_lib -from tsl.protobuf import rpc_options_pb2 class GrpcServerTest(test.TestCase): diff --git a/tensorflow/python/types/core.py b/tensorflow/python/types/core.py index 534211fd9d29ba..627725c7765c2c 100644 --- a/tensorflow/python/types/core.py +++ b/tensorflow/python/types/core.py @@ -264,6 +264,9 @@ def experimental_get_compiler_ir(self, *args, **kwargs): Function callable with the following kwargs: - `stage` at which the compiler IR should be serialized. Allowed values are: + - `stablehlo`: StableHLO module textual assembly. + - `stablehlo_serialized`: Like stage=`stablehlo`, but the output is a + serialized MLIR bytecode. - `hlo`: HLO output after conversion from TF (https://www.tensorflow.org/xla/operation_semantics). - `hlo_serialized`: Like stage=`hlo`, but the output is a serialized diff --git a/tensorflow/python/util/BUILD b/tensorflow/python/util/BUILD index ee56bb821a2f30..9933709da15394 100644 --- a/tensorflow/python/util/BUILD +++ b/tensorflow/python/util/BUILD @@ -4,7 +4,7 @@ load("//tensorflow:pytype.default.bzl", "pytype_strict_library") load("//tensorflow:strict.default.bzl", "py_strict_library") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable", "tf_py_strict_test", "tf_python_pybind_extension") load("//tensorflow/core/platform:build_config.bzl", "tf_proto_library") # @unused -load("//tensorflow/core/platform:build_config_root.bzl", "if_static") +load("//tensorflow/core/platform:build_config_root.bzl", "if_pywrap", "if_static") visibility = [ "//engedu/ml/tf_from_scratch:__pkg__", @@ -72,10 +72,22 @@ tf_python_pybind_extension( pytype_srcs = [ "_pywrap_tfprof.pyi", ], - deps = [ - "//tensorflow/core/profiler/internal:print_model_analysis_hdr", - "@pybind11", - ], + deps = if_pywrap( + if_false = [ + "//tensorflow/core/profiler/internal:print_model_analysis_hdr", + "@pybind11", + ], + if_true = [ + "//tensorflow/core:framework", + "//tensorflow/core/framework:reader_base", + "//tensorflow/core:lib_headers_for_pybind", + "//tensorflow/core/profiler/internal:print_model_analysis", + "//third_party/python_runtime:headers", + "@com_google_absl//absl/strings", + "@eigen_archive//:eigen3", + "@pybind11", + ], + ), ) tf_python_pybind_extension( @@ -107,7 +119,7 @@ tf_python_pybind_extension( "//tensorflow/python/lib/core:pybind11_lib", "//third_party/python_runtime:headers", "@pybind11", - ], + ] + if_pywrap([":cpp_nest"]), ) cc_library( @@ -140,7 +152,7 @@ tf_python_pybind_extension( "//tensorflow/python/lib/core:pybind11_lib", "//third_party/python_runtime:headers", "@pybind11", - ], + ] + if_pywrap([":kernel_registry"]), ) tf_python_pybind_extension( @@ -257,7 +269,7 @@ tf_python_pybind_extension( "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", "@pybind11", - ], + ] + if_pywrap(["@com_google_absl//absl/strings"]), ) cc_library( @@ -527,7 +539,7 @@ cc_library( tf_python_pybind_extension( name = "_function_parameter_canonicalizer_binding_for_test", - testonly = True, + # testonly = True, srcs = ["function_parameter_canonicalizer_binding_for_test.cc"], hdrs = [ "function_parameter_canonicalizer.h", @@ -543,7 +555,12 @@ tf_python_pybind_extension( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/types:span", "@pybind11", - ], + ] + if_pywrap( + if_true = [ + "//tensorflow/compiler/tf2xla:tf2xla_opset", + "//tensorflow/python/lib/core:pybind11_lib", + ], + ), ) tf_py_strict_test( diff --git a/tensorflow/python/util/compat.py b/tensorflow/python/util/compat.py index 7a4659e0f62251..9d22d659561a69 100644 --- a/tensorflow/python/util/compat.py +++ b/tensorflow/python/util/compat.py @@ -109,6 +109,21 @@ def as_text(bytes_or_text, encoding='utf-8'): def as_str(bytes_or_text, encoding='utf-8'): + """Acts as an alias for the `as_text` function.. + + Args: + bytes_or_text: The input value to be converted. A bytes or unicode object. + encoding: Optional string. The encoding to use if bytes_or_text is a bytes + object. Defaults to 'utf-8'. + + Returns: + A unicode string. + + Raises: + TypeError: If bytes_or_text is not a bytes or unicode object. + UnicodeDecodeError: If bytes_or_text is a bytes object and cannot be + decoded using the specified encoding. + """ return as_text(bytes_or_text, encoding) tf_export('compat.as_text')(as_text) diff --git a/tensorflow/python/util/custom_nest_protocol.py b/tensorflow/python/util/custom_nest_protocol.py index 1da4e463604b5f..7a10fead7d3d43 100644 --- a/tensorflow/python/util/custom_nest_protocol.py +++ b/tensorflow/python/util/custom_nest_protocol.py @@ -96,7 +96,7 @@ def __tf_flatten__(self): - This method only needs to flatten the current level. If current object has an attribute that also need custom flattening, nest functions (such as `nest.flatten`) will utilize this method to do recursive flattening. - - Components must ba a `tuple`, not a `list` + - Components must be a `tuple`, not a `list` """ @classmethod @@ -104,7 +104,7 @@ def __tf_unflatten__(cls, metadata, components): """Create a user-defined object from (metadata, components). Args: - metadata: a custom Python objet that stands for the static config for + metadata: a custom Python object that stands for the static config for reconstructing a new object of the current class. components: a `tuple` that contains the dynamic data fields of the current class, for object reconstruction. diff --git a/tensorflow/python/util/dispatch.py b/tensorflow/python/util/dispatch.py index 2605c2a17c7695..ff1fa45ba2a64d 100644 --- a/tensorflow/python/util/dispatch.py +++ b/tensorflow/python/util/dispatch.py @@ -234,7 +234,7 @@ def get_compatible_func(op, func): op_signature = _remove_annotation(tf_inspect.signature(op)) func_signature = _remove_annotation(tf_inspect.signature(func)) - # Identitical signatures, no need to apply compatibility fixes. + # Identical signatures, no need to apply compatibility fixes. if op_signature == func_signature: return func @@ -395,7 +395,7 @@ def dispatch_for_api(api, *signatures): being overridden. In particular, parameters must have the same names, and must occur in the same order. The dispatch target may optionally elide the "name" parameter, in which case it will be wrapped with a call to - `tf.name_scope` when appropraite. + `tf.name_scope` when appropriate. Args: api: The TensorFlow API to override. @@ -797,7 +797,7 @@ def _signature_from_annotations(func): # decorators. # # _ELEMENTWISE_API_TARGETS: Dict mapping from argument type(s) to lists of -# `(api, dispatch_target)` pairs. Used to impelement +# `(api, dispatch_target)` pairs. Used to implement # `unregister_elementwise_api_handler`. _UNARY_ELEMENTWISE_APIS = [] _BINARY_ELEMENTWISE_APIS = [] @@ -1206,7 +1206,7 @@ def add_dispatch_support(target=None, iterable_parameters=None): need to be handled specially during dispatch, since just iterating over an iterable uses up its values. In the following example, we define a new API whose second argument can be an iterable value; and then override the default - implementatio of that API when the iterable contains MaskedTensors: + implementation of that API when the iterable contains MaskedTensors: >>> @add_dispatch_support(iterable_parameters=['ys']) ... def add_tensor_to_list_of_tensors(x, ys): @@ -1245,7 +1245,7 @@ def decorator(dispatch_target): @traceback_utils.filter_traceback def op_dispatch_handler(*args, **kwargs): - """Call `dispatch_target`, peforming dispatch when appropriate.""" + """Call `dispatch_target`, performing dispatch when appropriate.""" # Type-based dispatch system (dispatch v2): if api_dispatcher is not None: diff --git a/tensorflow/python/util/dispatch_test.py b/tensorflow/python/util/dispatch_test.py index 7bb8e8f8898f6a..67fbd9f0e8d34a 100644 --- a/tensorflow/python/util/dispatch_test.py +++ b/tensorflow/python/util/dispatch_test.py @@ -957,7 +957,7 @@ def silly_add(x: SillyTensor, y: SillyTensor): def silly_abs(x: SillyTensor): del x - # Note: `expeced` does not contain keys or values from SillyTensor. + # Note: `expected` does not contain keys or values from SillyTensor. targets = dispatch.type_based_dispatch_signatures_for(MaskedTensor) expected = {math_ops.add: [{"x": MaskedTensor, "y": MaskedTensor}], array_ops.concat: [{"values": MaskedTensorList}]} diff --git a/tensorflow/python/util/function_parameter_canonicalizer.cc b/tensorflow/python/util/function_parameter_canonicalizer.cc index 748210d7f5439c..662b19d5ccaf73 100644 --- a/tensorflow/python/util/function_parameter_canonicalizer.cc +++ b/tensorflow/python/util/function_parameter_canonicalizer.cc @@ -120,7 +120,7 @@ bool FunctionParameterCanonicalizer::Canonicalize( index = InternedArgNameLinearSearch(key); Py_DECREF(key); - // Stil not found, then return an error. + // Still not found, then return an error. if (TF_PREDICT_FALSE(index == interned_arg_names_.size())) { PyErr_Format(PyExc_TypeError, "Got an unexpected keyword argument '%s'", diff --git a/tensorflow/python/util/lazy_loader.py b/tensorflow/python/util/lazy_loader.py index 7d8c186677583f..220f2861288882 100644 --- a/tensorflow/python/util/lazy_loader.py +++ b/tensorflow/python/util/lazy_loader.py @@ -37,7 +37,7 @@ def __init__(self, local_name, parent_module_globals, name, warning=None): self._tfll_warning = warning # These members allows doctest correctly process this module member without - # triggering self._load(). self._load() mutates parant_module_globals and + # triggering self._load(). self._load() mutates parent_module_globals and # triggers a dict mutated during iteration error from doctest.py. # - for from_module() super().__setattr__("__module__", name.rsplit(".", 1)[0]) diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py index 0378076cba247b..03a75d1a505d4c 100644 --- a/tensorflow/python/util/nest_test.py +++ b/tensorflow/python/util/nest_test.py @@ -92,7 +92,7 @@ def __eq__(self, other): ) def __len__(self): - # Used by `nest.map_structure_up_to` and releatd functions to verify the + # Used by `nest.map_structure_up_to` and related functions to verify the # arity compatibility. return 1 diff --git a/tensorflow/python/util/nest_util.py b/tensorflow/python/util/nest_util.py index c53042f7dc11ab..54f8cf1026a0e0 100644 --- a/tensorflow/python/util/nest_util.py +++ b/tensorflow/python/util/nest_util.py @@ -553,7 +553,7 @@ def _tf_core_packed_nest_with_indices( index: Index at which to start reading from flat. is_nested_fn: Function used to test if a value should be treated as a nested structure. - sequence_fn: Function used to generate a new strcuture instance. + sequence_fn: Function used to generate a new structure instance. Returns: The tuple (new_index, child), where: diff --git a/tensorflow/python/util/numpy_compat.py b/tensorflow/python/util/numpy_compat.py index e9fa23ae637e09..ce2cfe8220e582 100644 --- a/tensorflow/python/util/numpy_compat.py +++ b/tensorflow/python/util/numpy_compat.py @@ -80,3 +80,64 @@ def np_asarray(values, dtype=None, order=None, copy=None): return np.asarray(values, dtype=dtype, order=order, copy=copy) else: return np.asarray(values, dtype=dtype, order=order) + + +def np_where(condition, x=None, y=None): + """Return elements chosen from x or y depending on condition. + + When only condition is provided, np.where(condition) is a shorthand for + np.asarray(condition).nonzero(). See + https://numpy.org/doc/stable/reference/generated/numpy.where.html. NumPy + 2.1.0rc0 disallows 0D input arrays in nonzero, so np.atleast_1d is used here + to remain compatible with NumPy 1.x. See + https://github.com/numpy/numpy/pull/26268. + + Args: + condition: Array_like, bool. Where True, yield x, otherwise yield y. + x: Array_like. Values from which to choose. x, y and condition need to be + broadcastable to some shape. + y: Array_like. Values from which to choose. x, y and condition need to be + broadcastable to some shape. + + Returns: + An array with elements from x where condition is True, and elements from y + elsewhere. Or the indices of the elements that are non-zero. + """ + if x is None and y is None: + if np.lib.NumpyVersion(np.__version__) >= '2.1.0.rc0': + return np.atleast_1d(np.asarray(condition)).nonzero() + return np.where(condition) + return np.where(condition, x, y) + + +def np_reshape(a, /, shape=None, *, newshape=None, order='C', copy=None): + """Reshapes an array without changing its data. + + NumPy 2.1.0rc1 added shape and copy arguments to numpy.reshape. See + https://github.com/numpy/numpy/pull/26292. Both newshape and shape keywords + are supported, but newshape is going to be deprecated. Use `shape` instead. + + Besides, shape cannot be None now. See + https://github.com/numpy/numpy/blob/v2.1.0rc1/numpy/_core/fromnumeric.py#L309. + Previously, np.reshape with newshape=None returned a copy. To maintain this + behavior, we now use asarray to create an ndarray. + + Args: + a: Array_like. Array to be reshaped. + shape: The new shape of the array. + newshape: The new shape of the array (deprecated). + order: {‘C’, ‘F’, ‘K’}. + copy: bool. If True, then the array data is copied. If None, a copy will + only be made if it’s required by order. For False it raises a ValueError if + a copy cannot be avoided. + + Returns: + This will be a new view object if possible; otherwise, it will be a copy. + """ + if shape is None: + shape = newshape + if np.lib.NumpyVersion(np.__version__) >= '2.1.0.rc0': + if shape is None and newshape is None: + return np.asarray(a, order=order, copy=copy) + return np.reshape(a, shape, order=order, copy=copy) + return np.reshape(a, shape, order=order) diff --git a/tensorflow/python/util/protobuf/BUILD b/tensorflow/python/util/protobuf/BUILD index ecd1aeb919483e..73005deae56e4b 100644 --- a/tensorflow/python/util/protobuf/BUILD +++ b/tensorflow/python/util/protobuf/BUILD @@ -60,7 +60,6 @@ filegroup( # name = "compare_test_py_pb2", # testonly = 1, # has_services = 0, -# api_version = 2, # deps = [":compare_test_proto"], # ) # copybara:uncomment_end diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc index 922db88bddf2c7..661ba0aed648d4 100644 --- a/tensorflow/python/util/util.cc +++ b/tensorflow/python/util/util.cc @@ -816,7 +816,7 @@ void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg, // Returns true iff there were no "internal" errors. In other words, // errors that has nothing to do with structure checking. // If an "internal" error occurred, the appropriate Python error will be -// set and the caller can propage it directly to the user. +// set and the caller can propagate it directly to the user. // // Both `error_msg` and `is_type_error` must be non-null. `error_msg` must // be empty. diff --git a/tensorflow/security/fuzzing/cc/BUILD b/tensorflow/security/fuzzing/cc/BUILD index 50a647266abe05..c0bf2a86adbe5c 100644 --- a/tensorflow/security/fuzzing/cc/BUILD +++ b/tensorflow/security/fuzzing/cc/BUILD @@ -243,6 +243,6 @@ tf_cc_fuzz_test( deps = [ "@local_tsl//tsl/platform:env", "@local_xla//xla:text_literal_reader", - "@local_xla//xla/service:hlo_parser", + "@local_xla//xla/hlo/parser:hlo_parser", ], ) diff --git a/tensorflow/security/fuzzing/cc/base64_fuzz.cc b/tensorflow/security/fuzzing/cc/base64_fuzz.cc index a964c860fb6d26..1e241ac9c0a38b 100644 --- a/tensorflow/security/fuzzing/cc/base64_fuzz.cc +++ b/tensorflow/security/fuzzing/cc/base64_fuzz.cc @@ -30,7 +30,7 @@ namespace { void FuzzTest(std::string_view input) { std::string encoded_string; std::string decoded_string; - tensorflow::Status s; + absl::Status s; s = tensorflow::Base64Encode(input, &encoded_string); assert(s.ok()); s = tensorflow::Base64Decode(encoded_string, &decoded_string); diff --git a/tensorflow/security/fuzzing/cc/end_to_end_fuzz.cc b/tensorflow/security/fuzzing/cc/end_to_end_fuzz.cc index f6f86da3e12bb0..ba9e9573b24e10 100644 --- a/tensorflow/security/fuzzing/cc/end_to_end_fuzz.cc +++ b/tensorflow/security/fuzzing/cc/end_to_end_fuzz.cc @@ -46,8 +46,8 @@ void FuzzEndToEnd( TF_CHECK_OK(tsl::WriteBinaryProto(tensorflow::Env::Default(), export_dir + kSavedModelFilenamePb, model)); - Status status = LoadSavedModel(session_options, run_options, export_dir, - {kSavedModelTagServe}, &bundle); + absl::Status status = LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle); if (!status.ok()) { return; } @@ -55,7 +55,7 @@ void FuzzEndToEnd( // Create output placeholder tensors for results std::vector outputs; std::vector output_names = {"fuzz_out:0", "fuzz_out:1"}; - tensorflow::Status status_run = + absl::Status status_run = bundle.session->Run(input_dict, output_names, {}, &outputs); } diff --git a/tensorflow/security/fuzzing/cc/fuzz_session.h b/tensorflow/security/fuzzing/cc/fuzz_session.h index 1673309025be5a..b492c6e91da2d3 100644 --- a/tensorflow/security/fuzzing/cc/fuzz_session.h +++ b/tensorflow/security/fuzzing/cc/fuzz_session.h @@ -111,7 +111,7 @@ class FuzzSession { // Initializes the FuzzSession. Not safe for multithreading. // Separate init function because the call to virtual BuildGraphDef // can't be put into the constructor. - Status InitIfNeeded() { + absl::Status InitIfNeeded() { if (initialized_) { return absl::OkStatus(); } @@ -126,7 +126,7 @@ class FuzzSession { GraphDef graph_def; TF_CHECK_OK(root.ToGraphDef(&graph_def)); - Status status = session_->Create(graph_def); + absl::Status status = session_->Create(graph_def); if (!status.ok()) { // This is FATAL, because this code is designed to fuzz an op // within a session. Failure to create the session means we @@ -147,15 +147,15 @@ class FuzzSession { } // Same as RunInputs but don't ignore status - Status RunInputsWithStatus( - const std::vector >& inputs) { + absl::Status RunInputsWithStatus( + const std::vector>& inputs) { return session_->Run(inputs, {}, {"output"}, nullptr); } // Dispatches to FuzzImpl; small amount of sugar to keep the code // of the per-op fuzzers tiny. void Fuzz(const T&... args) { - Status status = InitIfNeeded(); + absl::Status status = InitIfNeeded(); TF_CHECK_OK(status) << "Fuzzer graph initialization failed: " << status.message(); // No return value from fuzzing: Success is defined as "did not diff --git a/tensorflow/security/fuzzing/cc/ops/bincount_fuzz.cc b/tensorflow/security/fuzzing/cc/ops/bincount_fuzz.cc index 02238e70b55034..8db2c6e6b88d88 100644 --- a/tensorflow/security/fuzzing/cc/ops/bincount_fuzz.cc +++ b/tensorflow/security/fuzzing/cc/ops/bincount_fuzz.cc @@ -40,7 +40,7 @@ class FuzzBincount : public FuzzSession { Tensor size(DT_INT32, {}); size.flat()(0) = nbins; - Status s = RunInputsWithStatus( + absl::Status s = RunInputsWithStatus( {{"arr", arr}, {"size", size}, {"weights", weights}}); if (!s.ok()) { LOG(ERROR) << "Execution failed: " << s.message(); diff --git a/tensorflow/security/fuzzing/cc/ops/concat_fuzz.cc b/tensorflow/security/fuzzing/cc/ops/concat_fuzz.cc index 2bd35b1a2c7ddd..b64bd042976028 100644 --- a/tensorflow/security/fuzzing/cc/ops/concat_fuzz.cc +++ b/tensorflow/security/fuzzing/cc/ops/concat_fuzz.cc @@ -45,7 +45,7 @@ class FuzzConcat : public FuzzSession { const int32& axis) final { Tensor axis_tensor(DT_INT32, {}); axis_tensor.scalar()() = axis; - Status s = RunInputsWithStatus( + absl::Status s = RunInputsWithStatus( {{"value1", value1}, {"value2", value2}, {"axis", axis_tensor}}); if (!s.ok()) { LOG(ERROR) << "Execution failed: " << s.message(); diff --git a/tensorflow/security/fuzzing/cc/ops/identity_fuzz.cc b/tensorflow/security/fuzzing/cc/ops/identity_fuzz.cc index 2a1260f35c9686..c05dca281d5726 100644 --- a/tensorflow/security/fuzzing/cc/ops/identity_fuzz.cc +++ b/tensorflow/security/fuzzing/cc/ops/identity_fuzz.cc @@ -34,7 +34,7 @@ class FuzzIdentity : public FuzzSession { tensorflow::ops::Identity(scope.WithOpName("output"), op_node); } void FuzzImpl(const Tensor& input_tensor) final { - Status s = RunInputsWithStatus({{"input", input_tensor}}); + absl::Status s = RunInputsWithStatus({{"input", input_tensor}}); if (!s.ok()) { LOG(ERROR) << "Execution failed: " << s.message(); } diff --git a/tensorflow/security/fuzzing/cc/ops/string_ops_fuzz.cc b/tensorflow/security/fuzzing/cc/ops/string_ops_fuzz.cc index 33fcf093ea09eb..396560f8fc73ee 100644 --- a/tensorflow/security/fuzzing/cc/ops/string_ops_fuzz.cc +++ b/tensorflow/security/fuzzing/cc/ops/string_ops_fuzz.cc @@ -42,7 +42,7 @@ class FuzzStringOpsStringSplit : public FuzzSession { Tensor separator_tensor(tensorflow::DT_STRING, TensorShape({})); separator_tensor.scalar()() = separator_string; - Status s = RunInputsWithStatus( + absl::Status s = RunInputsWithStatus( {{"input", input_tensor}, {"delimiter", separator_tensor}}); if (!s.ok()) { LOG(ERROR) << "Execution failed: " << s.message(); @@ -78,7 +78,7 @@ class FuzzStringOpsStringSplitV2 Tensor separator_tensor(tensorflow::DT_STRING, TensorShape({})); separator_tensor.scalar()() = separator_string; - Status s = RunInputsWithStatus( + absl::Status s = RunInputsWithStatus( {{"input", input_tensor}, {"separator", separator_tensor}}); if (!s.ok()) { LOG(ERROR) << "Execution failed: " << s.message(); @@ -103,7 +103,7 @@ class FuzzStringOpsStringUpper : public FuzzSession { Tensor input_tensor(tensorflow::DT_STRING, TensorShape({})); input_tensor.scalar()() = input_string; - Status s = RunInputsWithStatus({{"input", input_tensor}}); + absl::Status s = RunInputsWithStatus({{"input", input_tensor}}); if (!s.ok()) { LOG(ERROR) << "Execution failed: " << s.message(); } diff --git a/tensorflow/security/fuzzing/cc/ops/string_to_number_fuzz.cc b/tensorflow/security/fuzzing/cc/ops/string_to_number_fuzz.cc index 11b3e8adfce6e6..835da4e7643c4d 100644 --- a/tensorflow/security/fuzzing/cc/ops/string_to_number_fuzz.cc +++ b/tensorflow/security/fuzzing/cc/ops/string_to_number_fuzz.cc @@ -31,7 +31,7 @@ class FuzzStringToNumber : public FuzzSession { void FuzzImpl(const std::string& input_string) final { Tensor input_tensor(tensorflow::DT_STRING, TensorShape({})); input_tensor.scalar()() = input_string; - Status s = RunInputsWithStatus({{"input", input_tensor}}); + absl::Status s = RunInputsWithStatus({{"input", input_tensor}}); if (!s.ok()) { LOG(ERROR) << "Execution failed: " << s.message(); } diff --git a/tensorflow/security/fuzzing/cc/status_fuzz.cc b/tensorflow/security/fuzzing/cc/status_fuzz.cc index ebf8406cdb1fe8..9e259fd4e8d4c9 100644 --- a/tensorflow/security/fuzzing/cc/status_fuzz.cc +++ b/tensorflow/security/fuzzing/cc/status_fuzz.cc @@ -31,7 +31,7 @@ limitations under the License. namespace { void FuzzTest(absl::StatusCode error_code, std::string_view error_message) { - tensorflow::Status s = tensorflow::Status(error_code, error_message); + absl::Status s = absl::Status(error_code, error_message); const std::string actual_message = s.ToString(); const std::size_t pos = actual_message.rfind(error_message); assert(pos != std::string::npos); // Suffix is error message diff --git a/tensorflow/security/fuzzing/cc/status_group_fuzz.cc b/tensorflow/security/fuzzing/cc/status_group_fuzz.cc index 747e4d95c125e9..a0273717367262 100644 --- a/tensorflow/security/fuzzing/cc/status_group_fuzz.cc +++ b/tensorflow/security/fuzzing/cc/status_group_fuzz.cc @@ -29,10 +29,10 @@ namespace { void FuzzTest(absl::StatusCode error_code, bool is_derived) { const std::string error_message = "ERROR"; tensorflow::StatusGroup sg; - tensorflow::Status s = tensorflow::Status(error_code, error_message); + absl::Status s = absl::Status(error_code, error_message); if (is_derived) { - tensorflow::Status derived_s = tensorflow::StatusGroup::MakeDerived(s); + absl::Status derived_s = tensorflow::StatusGroup::MakeDerived(s); sg.Update(derived_s); } else { sg.Update(s); diff --git a/tensorflow/security/fuzzing/cc/text_literal_reader_fuzz.cc b/tensorflow/security/fuzzing/cc/text_literal_reader_fuzz.cc index ba504a38a506a6..5ad8d91bb3a205 100644 --- a/tensorflow/security/fuzzing/cc/text_literal_reader_fuzz.cc +++ b/tensorflow/security/fuzzing/cc/text_literal_reader_fuzz.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "fuzztest/fuzztest.h" -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/text_literal_reader.h" #include "tsl/platform/env.h" diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 7fb0387e29a065..eb8b647368c704 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -69,6 +69,7 @@ load( "@local_xla//xla/tsl:tsl.bzl", "tsl_gpu_library", _cc_header_only_library = "cc_header_only_library", + _custom_op_cc_header_only_library = "custom_op_cc_header_only_library", _if_cuda_or_rocm = "if_cuda_or_rocm", _if_cuda_tools = "if_cuda_tools", _if_nccl = "if_nccl", @@ -79,6 +80,12 @@ load( "if_tensorrt", "if_tensorrt_exec", ) +load( + "@local_tsl//third_party/py/rules_pywrap:pywrap.bzl", + "use_pywrap_rules", + _pybind_extension = "pybind_extension", + _stripped_cc_info = "stripped_cc_info", +) # Do not sort: copybara rule changes this def register_extension_info(**kwargs): @@ -88,7 +95,8 @@ def register_extension_info(**kwargs): # not contain rc or alpha, only numbers. # Also update tensorflow/core/public/version.h # and tensorflow/tools/pip_package/setup.py -VERSION = "2.18.0" +WHEEL_VERSION = "2.19.0" +VERSION = "2.19.0" VERSION_MAJOR = VERSION.split(".")[0] two_gpu_tags = ["requires-gpu-nvidia:2", "manual", "no_pip"] @@ -109,6 +117,7 @@ def clean_dep(target): return str(Label(target)) cc_header_only_library = _cc_header_only_library +custom_op_cc_header_only_library = _custom_op_cc_header_only_library transitive_hdrs = _transitive_hdrs def if_oss(oss_value, google_value = []): @@ -141,7 +150,7 @@ def if_not_v2(a): def if_nvcc(a): return select({ - "@local_config_cuda//cuda:using_nvcc": a, + clean_dep("//tensorflow:is_cuda_nvcc"): a, "//conditions:default": [], }) @@ -177,8 +186,8 @@ def tf_android_core_proto_headers(core_proto_sources_relative): def tf_portable_full_lite_protos(full, lite): return select({ - "//tensorflow:mobile_lite_protos": lite, - "//tensorflow:mobile_full_protos": full, + clean_dep("//tensorflow:mobile_lite_protos"): lite, + clean_dep("//tensorflow:mobile_full_protos"): full, # The default should probably be lite runtime, but since most clients # seem to use the non-lite version, let's make that the default for now. "//conditions:default": full, @@ -656,9 +665,13 @@ def _rpath_user_link_flags(name): ], }) +# TODO(b/356020232): remove completely after migration is done # Bazel-generated shared objects which must be linked into TensorFlow binaries # to define symbols from //tensorflow/core:framework and //tensorflow/core:lib. def tf_binary_additional_srcs(fullversion = False): + if use_pywrap_rules(): + return [] + if fullversion: suffix = "." + VERSION else: @@ -674,7 +687,11 @@ def tf_binary_additional_srcs(fullversion = False): ], ) +# TODO(b/356020232): remove completely after migration is done def tf_binary_additional_data_deps(): + if use_pywrap_rules(): + return [] + return if_static( extra_deps = [], macos = [ @@ -689,7 +706,11 @@ def tf_binary_additional_data_deps(): ], ) +# TODO(b/356020232): remove completely after migration is done def tf_binary_pybind_deps(): + if use_pywrap_rules(): + return [] + return select({ clean_dep("//tensorflow:macos"): [ clean_dep( @@ -708,8 +729,12 @@ def tf_binary_pybind_deps(): ], }) +# TODO(b/356020232): remove completely after migration is done # Helper function for the per-OS tensorflow libraries and their version symlinks def tf_shared_library_deps(): + if use_pywrap_rules(): + return [] + return select({ clean_dep("//tensorflow:macos_with_framework_shared_object"): [ clean_dep("//tensorflow:libtensorflow.dylib"), @@ -775,6 +800,11 @@ def tf_cc_shared_object( visibility = None, **kwargs): """Configure the shared object (.so) file for TensorFlow.""" + + actual_framework_so = framework_so + if use_pywrap_rules(): + actual_framework_so = [] + if soversion != None: suffix = "." + str(soversion).split(".")[0] longsuffix = "." + str(soversion) @@ -825,13 +855,13 @@ def tf_cc_shared_object( soname = name_os_major.split("/")[-1] data_extra = [] - if framework_so != []: + if actual_framework_so != []: data_extra = tf_binary_additional_data_deps() cc_binary( exec_properties = if_google({"cpp_link.mem": "16g"}, {}), name = name_os_full, - srcs = srcs + framework_so, + srcs = srcs + actual_framework_so, deps = deps, linkshared = 1, data = data + data_extra, @@ -865,6 +895,7 @@ def tf_cc_shared_object( testonly = testonly, ) +# TODO(b/356020232): remove completely after migration is done # buildozer: disable=function-docstring-args def tf_cc_shared_library_opensource( name, @@ -885,6 +916,10 @@ def tf_cc_shared_library_opensource( win_def_file = None, visibility = None): """Configures the shared object file for TensorFlow.""" + + if use_pywrap_rules(): + return + names = _get_shared_library_name_os_version_matrix( name, per_os_targets = per_os_targets, @@ -944,6 +979,7 @@ def tf_cc_shared_library_opensource( visibility = visibility, ) +# TODO(b/356020232): remove completely after migration is done def _tf_cc_shared_library_opensource( name, additional_linker_inputs = None, @@ -960,6 +996,9 @@ def _tf_cc_shared_library_opensource( user_link_flags = None, visibility = None, win_def_file = None): + if use_pywrap_rules(): + return + cc_library_name = name + "_cclib" cc_library( name = cc_library_name, @@ -1671,8 +1710,7 @@ def tf_gpu_cc_test( linkstatic = select({ # TODO(allenl): Remove Mac static linking when Bazel 0.6 is out. clean_dep("//tensorflow:macos"): 1, - "@local_config_cuda//cuda:using_nvcc": 1, - "@local_config_cuda//cuda:using_clang": 1, + clean_dep("//tensorflow:is_cuda_enabled"): 1, "//conditions:default": 0, }), suffix = "_gpu", @@ -1699,8 +1737,7 @@ def tf_gpu_cc_test( linkstatic = select({ # TODO(allenl): Remove Mac static linking when Bazel 0.6 is out. clean_dep("//tensorflow:macos"): 1, - "@local_config_cuda//cuda:using_nvcc": 1, - "@local_config_cuda//cuda:using_clang": 1, + clean_dep("//tensorflow:is_cuda_enabled"): 1, "//conditions:default": 0, }), suffix = "_2gpu", @@ -1762,6 +1799,7 @@ def tf_cc_tests( srcs, deps, name = "", + data = [], linkstatic = 0, tags = [], size = "medium", @@ -1779,6 +1817,7 @@ def tf_cc_tests( size = size, srcs = [src], args = args, + data = data, kernels = kernels, linkopts = linkopts, linkstatic = linkstatic, @@ -1901,11 +1940,11 @@ def _cuda_copts(opts = []): """ return select({ "//conditions:default": [], - "@local_config_cuda//cuda:using_nvcc": [ + clean_dep("//tensorflow:is_cuda_nvcc"): [ "-nvcc_options=relaxed-constexpr", "-nvcc_options=ftz=true", ] + opts, - "@local_config_cuda//cuda:using_clang": [ + clean_dep("//tensorflow:is_cuda_clang"): [ "-fcuda-flush-denormals-to-zero", ] + opts, }) @@ -2058,13 +2097,17 @@ def tf_kernel_library( ) # TODO(gunan): CUDA dependency not clear here. Fix it. - tf_cc_shared_object( - name = "libtfkernel_%s.so" % name, - srcs = srcs + hdrs + textual_hdrs, - copts = copts, - tags = ["manual", "notap"], - deps = deps, - ) + # TODO(b/356020232): remove completely after migration is done + if use_pywrap_rules(): + pass + else: + tf_cc_shared_object( + name = "libtfkernel_%s.so" % name, + srcs = srcs + hdrs + textual_hdrs, + copts = copts, + tags = ["manual", "notap"], + deps = deps, + ) register_extension_info( extension = tf_kernel_library, @@ -2256,6 +2299,7 @@ check_deps = rule( }, ) +# TODO(b/356020232): cleanup use_pywrap_rules after migration is done def tf_custom_op_library( name, srcs = [], @@ -2270,14 +2314,26 @@ def tf_custom_op_library( if not gpu_deps: gpu_deps = [] - deps = deps + if_cuda_or_rocm([ + if use_pywrap_rules(): + deps = [clean_dep("//tensorflow/python:_pywrap_tensorflow_common")] + deps + else: + deps = list(deps) + + deps += if_cuda_or_rocm([ clean_dep("//tensorflow/core:stream_executor_headers_lib"), ]) + if_cuda([ "@local_config_cuda//cuda:cuda_headers", - "@local_config_cuda//cuda:cudart_static", - ]) + if_windows([ - clean_dep("//tensorflow/python:pywrap_tensorflow_import_lib"), - ]) + tf_custom_op_library_additional_deps() + "@local_config_cuda//cuda:cuda_runtime", + ]) + + if use_pywrap_rules(): + pass + else: + deps += if_windows( + [clean_dep("//tensorflow/python:pywrap_tensorflow_import_lib")], + ) + + deps += tf_custom_op_library_additional_deps() # Override EIGEN_STRONG_INLINE to inline when # --define=override_eigen_strong_inline=true to avoid long compiling time. @@ -2386,6 +2442,7 @@ _append_init_to_versionscript = rule( implementation = _append_init_to_versionscript_impl, ) +# TODO(b/356020232): remove completely after migration is done # This macro should only be used for pywrap_tensorflow_internal.so. # It was copied and refined from the original tf_py_wrap_cc_opensource rule. # buildozer: disable=function-docstring-args @@ -2401,6 +2458,15 @@ def pywrap_tensorflow_macro_opensource( version_script = None, win_def_file = None): """Builds the pywrap_tensorflow_internal shared object.""" + + if use_pywrap_rules(): + native.py_library( + name = name, + srcs = [], + deps = [], + ) + return + module_name = name.split("/")[-1] # Convert a rule name such as foo/bar/baz to foo/bar/_baz.so @@ -2531,6 +2597,8 @@ def pywrap_tensorflow_macro_opensource( # Export open source version of pywrap_tensorflow_macro under base name as well. pywrap_tensorflow_macro = pywrap_tensorflow_macro_opensource +# TODO(b/356020232): keep only the use_pywrap_rules part after migration is done +# also remove the comments below, as they will become false # This macro is for running python tests against system installed pip package # on Windows. # @@ -2547,23 +2615,45 @@ pywrap_tensorflow_macro = pywrap_tensorflow_macro_opensource # Note that this only works on Windows. See the definition of # //third_party/tensorflow/tools/pip_package:win_pip_package_marker for specific reasons. # 2. When --define=no_tensorflow_py_deps=false (by default), it's a normal py_test. -def py_test(deps = [], data = [], kernels = [], exec_properties = None, test_rule = _plain_py_test, **kwargs): +def py_test( + deps = [], + data = [], + kernels = [], + exec_properties = None, + test_rule = _plain_py_test, + env = {}, + **kwargs): if not exec_properties: exec_properties = tf_exec_properties(kwargs) - _make_tags_mutable(kwargs) - test_rule( - deps = select({ - "//conditions:default": deps, - clean_dep("//tensorflow:no_tensorflow_py_deps"): [], - }), - data = data + select({ - "//conditions:default": kernels, - clean_dep("//tensorflow:no_tensorflow_py_deps"): [], - }), - exec_properties = exec_properties, - **kwargs - ) + if use_pywrap_rules(): + test_env = { + "PYWRAP_TARGET": clean_dep(Label("//tensorflow/python:_pywrap_tensorflow")), + } + test_env.update(env) + actual_deps = deps.to_list() if hasattr(deps, "to_list") else deps + test_rule( + deps = actual_deps + [test_env["PYWRAP_TARGET"]], + exec_properties = exec_properties, + env = test_env, + data = data, + **kwargs + ) + else: + _make_tags_mutable(kwargs) + test_rule( + deps = select({ + "//conditions:default": deps, + clean_dep("//tensorflow:no_tensorflow_py_deps"): [], + }), + data = data + select({ + "//conditions:default": kernels, + clean_dep("//tensorflow:no_tensorflow_py_deps"): [], + }), + exec_properties = exec_properties, + env = env, + **kwargs + ) register_extension_info( extension = py_test, @@ -2596,11 +2686,15 @@ def pytype_library(name, pytype_deps = [], pytype_srcs = [], **kwargs): _make_tags_mutable(kwargs) _plain_py_library(name = name, **kwargs) +# TODO(b/356020232): remove completely after migration is done # Tensorflow uses rules_python 0.0.1, and in that version of rules_python, # the rules require the tags value to be a mutable list because they # modify it in-place. Later versions of rules_python don't have this # requirement. def _make_tags_mutable(kwargs): + if use_pywrap_rules(): + return + if "tags" in kwargs and kwargs["tags"] != None: # The value might be a frozen list, which looks just like # a regular list. So always make a copy. @@ -3021,6 +3115,7 @@ def pybind_library( **kwargs ) +# TODO(b/356020232): remove completely after migration is done # buildozer: disable=function-docstring-args def pybind_extension_opensource( name, @@ -3233,7 +3328,8 @@ def pybind_extension_opensource( ) # Export open source version of pybind_extension under base name as well. -pybind_extension = pybind_extension_opensource +pybind_extension = _pybind_extension if use_pywrap_rules() else pybind_extension_opensource +stripped_cc_info = _stripped_cc_info # Note: we cannot add //third_party/tf_runtime:__subpackages__ here, # because that builds all of tf_runtime's packages, and some of them @@ -3252,7 +3348,11 @@ def tsl_async_value_deps(): "@tf_runtime//third_party/llvm_derived:in_place", ] +# TODO(b/356020232): remove completely after migration is done def tf_python_pybind_static_deps(testonly = False): + if use_pywrap_rules(): + return [] + # TODO(b/146808376): Reduce the dependencies to those that are really needed. static_deps = [ "//:__subpackages__", @@ -3323,6 +3423,7 @@ def tf_python_pybind_static_deps(testonly = False): ] return if_oss(static_deps) +# TODO(b/356020232): remove completely after migration is done # buildozer: enable=function-docstring-args def tf_python_pybind_extension_opensource( name, @@ -3342,7 +3443,8 @@ def tf_python_pybind_extension_opensource( pytype_srcs = [], testonly = False, visibility = None, - win_def_file = None): + win_def_file = None, + additional_exported_symbols = None): """A wrapper macro for pybind_extension_opensource that is used in tensorflow/python/BUILD. Please do not use it anywhere else as it may behave unexpectedly. b/146445820 @@ -3374,7 +3476,7 @@ def tf_python_pybind_extension_opensource( ) # Export open source version of tf_python_pybind_extension under base name as well. -tf_python_pybind_extension = tf_python_pybind_extension_opensource +tf_python_pybind_extension = _pybind_extension if use_pywrap_rules() else tf_python_pybind_extension_opensource def tf_pybind_cc_library_wrapper_opensource(name, deps, visibility = None, **kwargs): """Wrapper for cc_library and proto dependencies used by tf_python_pybind_extension_opensource. @@ -3580,3 +3682,20 @@ def tf_python_framework_friends(): def if_cuda_tools(if_true, if_false = []): return _if_cuda_tools(if_true, if_false) + +# The config is used to determine if we need dependency on pre-built wheels. +def if_wheel_dependency(if_true, if_false = []): + return select({ + "@local_tsl//third_party/py:enable_wheel_dependency": if_true, + "//conditions:default": if_false, + }) + +# TODO(b/356020232): remove completely after migration is done +def pywrap_aware_tf_cc_shared_object(name, **kwargs): + if use_pywrap_rules(): + pass + else: + tf_cc_shared_object( + name = name, + **kwargs + ) diff --git a/tensorflow/tensorflow.default.bzl b/tensorflow/tensorflow.default.bzl index be61ba2e7b598b..9f29115a44ef91 100644 --- a/tensorflow/tensorflow.default.bzl +++ b/tensorflow/tensorflow.default.bzl @@ -1,11 +1,21 @@ """Default (OSS) build versions of TensorFlow general-purpose build extensions.""" +load( + "@local_tsl//third_party/py/rules_pywrap:pywrap.bzl", + _pywrap_aware_cc_import = "pywrap_aware_cc_import", + _pywrap_aware_filegroup = "pywrap_aware_filegroup", + _pywrap_aware_genrule = "pywrap_aware_genrule", + _pywrap_common_library = "pywrap_common_library", + _pywrap_library = "pywrap_library", + _stripped_cc_info = "stripped_cc_info", +) load( "//tensorflow:tensorflow.bzl", _ADDITIONAL_API_INDEXABLE_SETTINGS = "ADDITIONAL_API_INDEXABLE_SETTINGS", _cc_header_only_library = "cc_header_only_library", _clean_dep = "clean_dep", _cuda_py_test = "cuda_py_test", + _custom_op_cc_header_only_library = "custom_op_cc_header_only_library", _filegroup = "filegroup", _genrule = "genrule", _get_compatible_with_portable = "get_compatible_with_portable", @@ -16,6 +26,7 @@ load( _pybind_extension = "pybind_extension", _pybind_library = "pybind_library", _pytype_library = "pytype_library", + _pywrap_aware_tf_cc_shared_object = "pywrap_aware_tf_cc_shared_object", _pywrap_tensorflow_macro = "pywrap_tensorflow_macro", _replace_with_portable_tf_lib_when_required = "replace_with_portable_tf_lib_when_required", _tensorflow_opensource_extra_deps = "tensorflow_opensource_extra_deps", @@ -80,6 +91,7 @@ tf_grpc_dependencies = _tf_grpc_dependencies tf_grpc_cc_dependencies = _tf_grpc_cc_dependencies get_compatible_with_portable = _get_compatible_with_portable cc_header_only_library = _cc_header_only_library +custom_op_cc_header_only_library = _custom_op_cc_header_only_library tf_gen_op_libs = _tf_gen_op_libs tf_gen_op_wrapper_cc = _tf_gen_op_wrapper_cc tf_gen_op_wrappers_cc = _tf_gen_op_wrappers_cc @@ -91,3 +103,10 @@ internal_tfrt_deps = _internal_tfrt_deps tf_disable_ptxas_warning_flags = _tf_disable_ptxas_warning_flags replace_with_portable_tf_lib_when_required = _replace_with_portable_tf_lib_when_required tf_python_framework_friends = _tf_python_framework_friends +pywrap_aware_tf_cc_shared_object = _pywrap_aware_tf_cc_shared_object +pywrap_aware_filegroup = _pywrap_aware_filegroup +pywrap_aware_genrule = _pywrap_aware_genrule +pywrap_aware_cc_import = _pywrap_aware_cc_import +pywrap_library = _pywrap_library +pywrap_common_library = _pywrap_common_library +stripped_cc_info = _stripped_cc_info diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt index d1f11277be17f4..2128a8e80f3739 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-device-wrapper.pbtxt @@ -177,7 +177,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "build_from_config" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt index 35b7e32aecce43..404d0400e36c5c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt @@ -181,7 +181,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'inputs_shape\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "build_from_config" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt index e42c2702cec45c..8fb9c75673d610 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-multi-r-n-n-cell.pbtxt @@ -176,7 +176,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "build_from_config" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt index 4e0851595da6d7..071db3a8abf79f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-r-n-n-cell.pbtxt @@ -175,7 +175,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "build_from_config" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt index f7b61918ff1cc0..1dd1ee3ea574e5 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt @@ -177,7 +177,7 @@ tf_class { } member_method { name: "build" - argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None" } member_method { name: "build_from_config" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-sparse-core-embedding-config.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-sparse-core-embedding-config.pbtxt index 6a548287f35ce4..cd5e25b1908d72 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-sparse-core-embedding-config.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.tpu.experimental.embedding.-sparse-core-embedding-config.pbtxt @@ -10,6 +10,10 @@ tf_class { name: "disable_table_stacking" mtype: "" } + member { + name: "enable_fast_table_initialization" + mtype: "" + } member { name: "initialize_tables_on_host" mtype: "" @@ -28,6 +32,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'disable_table_stacking\', \'max_ids_per_chip_per_sample\', \'max_ids_per_table\', \'max_unique_ids_per_table\', \'allow_id_dropping\', \'initialize_tables_on_host\'], varargs=None, keywords=None, defaults=[\'False\', \'64\', \'None\', \'None\', \'False\', \'True\'], " + argspec: "args=[\'self\', \'disable_table_stacking\', \'max_ids_per_chip_per_sample\', \'max_ids_per_table\', \'max_unique_ids_per_table\', \'allow_id_dropping\', \'initialize_tables_on_host\', \'enable_fast_table_initialization\'], varargs=None, keywords=None, defaults=[\'False\', \'64\', \'None\', \'None\', \'False\', \'True\', \'False\'], " } } diff --git a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-sparse-core-embedding-config.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-sparse-core-embedding-config.pbtxt index 6a548287f35ce4..cd5e25b1908d72 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-sparse-core-embedding-config.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.tpu.experimental.embedding.-sparse-core-embedding-config.pbtxt @@ -10,6 +10,10 @@ tf_class { name: "disable_table_stacking" mtype: "" } + member { + name: "enable_fast_table_initialization" + mtype: "" + } member { name: "initialize_tables_on_host" mtype: "" @@ -28,6 +32,6 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'disable_table_stacking\', \'max_ids_per_chip_per_sample\', \'max_ids_per_table\', \'max_unique_ids_per_table\', \'allow_id_dropping\', \'initialize_tables_on_host\'], varargs=None, keywords=None, defaults=[\'False\', \'64\', \'None\', \'None\', \'False\', \'True\'], " + argspec: "args=[\'self\', \'disable_table_stacking\', \'max_ids_per_chip_per_sample\', \'max_ids_per_table\', \'max_unique_ids_per_table\', \'allow_id_dropping\', \'initialize_tables_on_host\', \'enable_fast_table_initialization\'], varargs=None, keywords=None, defaults=[\'False\', \'64\', \'None\', \'None\', \'False\', \'True\', \'False\'], " } } diff --git a/tensorflow/tools/api/lib/BUILD b/tensorflow/tools/api/lib/BUILD index 462bc6b24d31ac..25edeab1486bf5 100644 --- a/tensorflow/tools/api/lib/BUILD +++ b/tensorflow/tools/api/lib/BUILD @@ -39,7 +39,6 @@ py_strict_library( # py_proto_library( # name = "api_objects_proto_py_pb2", # has_services = 0, -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":api_objects_proto"], # ) diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc index 7edaea6c182c23..b135554bfaabba 100644 --- a/tensorflow/tools/benchmark/benchmark_model.cc +++ b/tensorflow/tools/benchmark/benchmark_model.cc @@ -61,8 +61,8 @@ namespace benchmark_model { namespace { -Status InitializeVariables(Session* session, - const std::vector& init_ops) { +absl::Status InitializeVariables(Session* session, + const std::vector& init_ops) { LOG(INFO) << "Initializing graph variables"; for (const string& init_op : init_ops) { TF_RETURN_IF_ERROR(session->Run({}, {}, {init_op}, nullptr)); @@ -128,9 +128,10 @@ void CreateTensorsFromInputInfo( } } -Status GetOutputShapes(const std::vector& inputs, - const std::set& wanted_shapes, Session* session, - std::unordered_map* node_shapes) { +absl::Status GetOutputShapes( + const std::vector& inputs, + const std::set& wanted_shapes, Session* session, + std::unordered_map* node_shapes) { std::vector > input_tensors; CreateTensorsFromInputInfo(inputs, &input_tensors); std::vector output_tensors; @@ -160,10 +161,10 @@ Status GetOutputShapes(const std::vector& inputs, return absl::OkStatus(); } -Status CalculateFlops(const GraphDef& graph, - const std::vector& inputs, - Session* session, int64_t* total_flops, - std::unordered_map* flops_by_op) { +absl::Status CalculateFlops(const GraphDef& graph, + const std::vector& inputs, + Session* session, int64_t* total_flops, + std::unordered_map* flops_by_op) { std::unordered_set floppable_ops = { "Conv2D", "MatMul", "QuantizedConv2D", "QuantizedMatMul", "DepthwiseConv2dNative"}; @@ -260,9 +261,9 @@ void SleepSeconds(double sleep_seconds) { } // namespace -Status InitializeSession(int num_threads, const string& graph, - std::unique_ptr* session, - std::unique_ptr* graph_def) { +absl::Status InitializeSession(int num_threads, const string& graph, + std::unique_ptr* session, + std::unique_ptr* graph_def) { LOG(INFO) << "Loading TensorFlow."; tensorflow::SessionOptions options; @@ -276,7 +277,7 @@ Status InitializeSession(int num_threads, const string& graph, session->reset(tensorflow::NewSession(options)); *graph_def = std::make_unique(); tensorflow::GraphDef tensorflow_graph; - Status s = ReadBinaryProto(Env::Default(), graph, graph_def->get()); + absl::Status s = ReadBinaryProto(Env::Default(), graph, graph_def->get()); if (!s.ok()) { s = ReadTextProto(Env::Default(), graph, graph_def->get()); } @@ -295,16 +296,16 @@ Status InitializeSession(int num_threads, const string& graph, return absl::OkStatus(); } -Status RunBenchmark(const std::vector& inputs, - const std::vector& outputs, - const std::vector& targets, Session* session, - StatSummarizer* stats, int64_t* inference_time_us) { +absl::Status RunBenchmark(const std::vector& inputs, + const std::vector& outputs, + const std::vector& targets, Session* session, + StatSummarizer* stats, int64_t* inference_time_us) { std::vector > input_tensors; CreateTensorsFromInputInfo(inputs, &input_tensors); std::vector output_tensors; - tensorflow::Status s; + absl::Status s; RunOptions run_options; if (stats != nullptr) { @@ -332,12 +333,14 @@ Status RunBenchmark(const std::vector& inputs, return s; } -Status TimeMultipleRuns(double sleep_seconds, int num_runs, double max_time_s, - const std::vector& inputs, - const std::vector& outputs, - const std::vector& targets, Session* session, - StatSummarizer* stats, int64_t* total_time_us, - int64_t* actual_num_runs) { +absl::Status TimeMultipleRuns(double sleep_seconds, int num_runs, + double max_time_s, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& targets, + Session* session, StatSummarizer* stats, + int64_t* total_time_us, + int64_t* actual_num_runs) { *total_time_us = 0; LOG(INFO) << "Running benchmark for max " << num_runs << " iterations, max " @@ -350,7 +353,7 @@ Status TimeMultipleRuns(double sleep_seconds, int num_runs, double max_time_s, const bool until_max_time = num_runs <= 0; for (int i = 0; until_max_time || i < num_runs; ++i) { int64_t time; - Status run_status = + absl::Status run_status = RunBenchmark(inputs, outputs, targets, session, stats, &time); stat.UpdateStat(time); (*total_time_us) += time; @@ -504,7 +507,7 @@ int Main(int argc, char** argv) { std::unique_ptr graph_def; int64_t initialization_start_us = Env::Default()->NowMicros(); - Status initialize_status = + absl::Status initialize_status = InitializeSession(num_threads, graph, &session, &graph_def); int64_t initialization_end_us = Env::Default()->NowMicros(); double initialization_time_s = @@ -515,7 +518,7 @@ int Main(int argc, char** argv) { } if (!init_ops.empty()) { - Status initialize_variables_status = + absl::Status initialize_variables_status = InitializeVariables(session.get(), init_ops); if (!initialize_variables_status.ok()) { LOG(ERROR) << "Graph variables initialization failed with " @@ -584,7 +587,7 @@ int Main(int argc, char** argv) { int64_t warmup_time_us = 0; int64_t num_warmup_runs = 0; if (warmup_runs > 0) { - Status warmup_time_status = + absl::Status warmup_time_status = TimeMultipleRuns(inter_inference_sleep_seconds, warmup_runs, -1.0, inputs, output_layers, target_layers, session.get(), nullptr, &warmup_time_us, &num_warmup_runs); @@ -599,7 +602,7 @@ int Main(int argc, char** argv) { SleepSeconds(inter_benchmark_sleep_seconds); int64_t no_stat_time_us = 0; int64_t no_stat_num_runs = 0; - Status no_stat_time_status = TimeMultipleRuns( + absl::Status no_stat_time_status = TimeMultipleRuns( inter_inference_sleep_seconds, max_num_runs, max_benchmark_time_seconds, inputs, output_layers, target_layers, session.get(), nullptr, &no_stat_time_us, &no_stat_num_runs); @@ -614,7 +617,7 @@ int Main(int argc, char** argv) { SleepSeconds(inter_benchmark_sleep_seconds); int64_t stat_time_us = 0; int64_t stat_num_runs = 0; - Status stat_time_status = TimeMultipleRuns( + absl::Status stat_time_status = TimeMultipleRuns( inter_inference_sleep_seconds, max_num_runs, max_benchmark_time_seconds, inputs, output_layers, target_layers, session.get(), stats.get(), &stat_time_us, &stat_num_runs); @@ -638,8 +641,8 @@ int Main(int argc, char** argv) { if (show_flops) { int64_t total_flops; std::unordered_map flops_by_op; - Status flop_status = CalculateFlops(*graph_def, inputs, session.get(), - &total_flops, &flops_by_op); + absl::Status flop_status = CalculateFlops(*graph_def, inputs, session.get(), + &total_flops, &flops_by_op); if (!flop_status.ok()) { LOG(ERROR) << "FLOPs calculation failed with " << flop_status; return -1; diff --git a/tensorflow/tools/benchmark/benchmark_model.h b/tensorflow/tools/benchmark/benchmark_model.h index e983ea4167d740..1c8ea8de63ef09 100644 --- a/tensorflow/tools/benchmark/benchmark_model.h +++ b/tensorflow/tools/benchmark/benchmark_model.h @@ -35,23 +35,24 @@ struct InputLayerInfo { }; // Loads a model from disk into a new session. -Status InitializeSession(int num_threads, const string& graph, - std::unique_ptr* session, - std::unique_ptr* graph_def); +absl::Status InitializeSession(int num_threads, const string& graph, + std::unique_ptr* session, + std::unique_ptr* graph_def); // Does a single run of the model that's been loaded into the given session. -Status RunBenchmark(const std::vector& inputs, - const std::vector& outputs, - const std::vector& targets, Session* session, - StatSummarizer* stats, int64_t* inference_time_us); +absl::Status RunBenchmark(const std::vector& inputs, + const std::vector& outputs, + const std::vector& targets, Session* session, + StatSummarizer* stats, int64_t* inference_time_us); // Runs the model multiple time, keeping track of timing information. -Status TimeMultipleRuns(double sleep_seconds, int num_runs, double max_time_s, - const std::vector& inputs, - const std::vector& outputs, - const std::vector& targets, Session* session, - StatSummarizer* stats, int64_t* total_time_us, - int64_t* actual_num_runs); +absl::Status TimeMultipleRuns(double sleep_seconds, int num_runs, + double max_time_s, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& targets, + Session* session, StatSummarizer* stats, + int64_t* total_time_us, int64_t* actual_num_runs); // Handles all setup and argument parsing. int Main(int argc, char** argv); diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython index 05f868d4167461..615b414540e2d0 100644 --- a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython +++ b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython @@ -37,7 +37,7 @@ RUN /install/build_and_install_python.sh "3.9.18" RUN /install/build_and_install_python.sh "3.10.13" RUN /install/build_and_install_python.sh "3.11.6" RUN /install/build_and_install_python.sh "3.12.2" -RUN /install/build_and_install_python.sh "3.13.0rc2" +RUN /install/build_and_install_python.sh "3.13.0" COPY install/install_pip_packages_by_version.sh /install/ # https://github.com/numpy/numpy/issues/22623 for `SETUPTOOLS_USE_DISTUTILS`. diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython index 09a1ebdb84972b..7160aa1582f845 100644 --- a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython +++ b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython @@ -27,6 +27,8 @@ RUN apt-get update && apt-get install -y \ libreadline-dev \ libsqlite3-dev \ patchelf \ + libcudnn9-dev-cuda-12=9.1.1.17-1 \ + libcudnn9-cuda-12=9.1.1.17-1 \ && \ rm -rf /var/lib/apt/lists/* @@ -35,7 +37,7 @@ RUN /install/build_and_install_python.sh "3.9.18" RUN /install/build_and_install_python.sh "3.10.13" RUN /install/build_and_install_python.sh "3.11.6" RUN /install/build_and_install_python.sh "3.12.0" -RUN /install/build_and_install_python.sh "3.13.0rc2" +RUN /install/build_and_install_python.sh "3.13.0" COPY install/install_pip_packages_by_version.sh /install/ # https://github.com/numpy/numpy/issues/22623 for `SETUPTOOLS_USE_DISTUTILS`. diff --git a/tensorflow/tools/ci_build/install/install_clang_18.sh b/tensorflow/tools/ci_build/install/install_clang_18.sh index 18b0cc0f906ad1..cf2a6b4abb559e 100755 --- a/tensorflow/tools/ci_build/install/install_clang_18.sh +++ b/tensorflow/tools/ci_build/install/install_clang_18.sh @@ -25,7 +25,7 @@ deb http://apt.llvm.org/focal/ llvm-toolchain-focal-18 main deb-src http://apt.llvm.org/focal/ llvm-toolchain-focal-18 main SOURCES -apt-get autoremove clang-17 -y +apt-get autoremove clang-17 -y || true # Remove clang-17 if it exists. apt-get update && apt-get install -y \ llvm-18 \ clang-18 \ diff --git a/tensorflow/tools/ci_build/update_version.py b/tensorflow/tools/ci_build/update_version.py index 3423818930e210..3265533f803751 100755 --- a/tensorflow/tools/ci_build/update_version.py +++ b/tensorflow/tools/ci_build/update_version.py @@ -234,6 +234,11 @@ def update_tensorflow_bzl(old_version, new_version): new_version.patch) replace_string_in_line('VERSION = "%s"' % old_mmp, 'VERSION = "%s"' % new_mmp, TENSORFLOW_BZL) + replace_string_in_line( + 'WHEEL_VERSION = "%s"' % old_version.string, + 'WHEEL_VERSION = "%s"' % new_version.string, + TENSORFLOW_BZL, + ) def update_m1_builds(old_version, new_version): diff --git a/tensorflow/tools/common/BUILD b/tensorflow/tools/common/BUILD index 6b994ebb7d9c78..75023a5d88e6df 100644 --- a/tensorflow/tools/common/BUILD +++ b/tensorflow/tools/common/BUILD @@ -3,6 +3,7 @@ load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test") load("//tensorflow:tensorflow.bzl", "VERSION_MAJOR") +load("//tensorflow/core/platform:build_config_root.bzl", "if_pywrap") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -23,11 +24,13 @@ py_strict_library( py_strict_test( name = "public_api_test", srcs = ["public_api_test.py"], - data = select({ - "//tensorflow:macos": ["//tensorflow:libtensorflow_framework.dylib"], - "//tensorflow:windows": [], - "//conditions:default": ["//tensorflow:libtensorflow_framework.so.%s" % VERSION_MAJOR], - }), + data = if_pywrap( + if_false = select({ + "//tensorflow:macos": ["//tensorflow:libtensorflow_framework.dylib"], + "//tensorflow:windows": [], + "//conditions:default": ["//tensorflow:libtensorflow_framework.so.%s" % VERSION_MAJOR], + }), + ), python_version = "PY3", srcs_version = "PY3", deps = [ @@ -47,11 +50,13 @@ py_strict_library( py_strict_test( name = "traverse_test", srcs = ["traverse_test.py"], - data = select({ - "//tensorflow:macos": ["//tensorflow:libtensorflow_framework.dylib"], - "//tensorflow:windows": [], - "//conditions:default": ["//tensorflow:libtensorflow_framework.so.%s" % VERSION_MAJOR], - }), + data = if_pywrap( + if_false = select({ + "//tensorflow:macos": ["//tensorflow:libtensorflow_framework.dylib"], + "//tensorflow:windows": [], + "//conditions:default": ["//tensorflow:libtensorflow_framework.so.%s" % VERSION_MAJOR], + }), + ), python_version = "PY3", srcs_version = "PY3", deps = [ diff --git a/tensorflow/tools/gcs_test/Dockerfile b/tensorflow/tools/gcs_test/Dockerfile index af80a4e473b780..9031ed8a2b4c1c 100644 --- a/tensorflow/tools/gcs_test/Dockerfile +++ b/tensorflow/tools/gcs_test/Dockerfile @@ -1,4 +1,4 @@ -FROM ubuntu:24.04@sha256:8a37d68f4f73ebf3d4efafbcf66379bf3728902a8038616808f04e34a9ab63ee +FROM ubuntu:24.04@sha256:dfc10878be8d8fc9c61cbff33166cb1d1fe44391539243703c72766894fa834a LABEL maintainer="Shanqing Cai " diff --git a/tensorflow/tools/graph_transforms/add_default_attributes_test.cc b/tensorflow/tools/graph_transforms/add_default_attributes_test.cc index a0f1d3162a5326..73efa79b378182 100644 --- a/tensorflow/tools/graph_transforms/add_default_attributes_test.cc +++ b/tensorflow/tools/graph_transforms/add_default_attributes_test.cc @@ -29,9 +29,9 @@ namespace tensorflow { namespace graph_transforms { // Declare here, so we don't need a public header. -Status AddDefaultAttributes(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); +absl::Status AddDefaultAttributes(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); class AddDefaultAttributesTest : public ::testing::Test { protected: diff --git a/tensorflow/tools/graph_transforms/backports_test.cc b/tensorflow/tools/graph_transforms/backports_test.cc index 155ec29e93687c..53d26b56b727fc 100644 --- a/tensorflow/tools/graph_transforms/backports_test.cc +++ b/tensorflow/tools/graph_transforms/backports_test.cc @@ -30,12 +30,12 @@ namespace tensorflow { namespace graph_transforms { // Declare here, so we don't need a public header. -Status BackportConcatV2Transform(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); -Status BackportTensorArrayV3Transform(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); +absl::Status BackportConcatV2Transform(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); +absl::Status BackportTensorArrayV3Transform(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); class BackportConcatV2Test : public ::testing::Test { protected: diff --git a/tensorflow/tools/graph_transforms/compare_graphs.cc b/tensorflow/tools/graph_transforms/compare_graphs.cc index d658279ec88e24..e725936bf688ae 100644 --- a/tensorflow/tools/graph_transforms/compare_graphs.cc +++ b/tensorflow/tools/graph_transforms/compare_graphs.cc @@ -46,7 +46,7 @@ int ParseFlagsAndCompareGraphs(int argc, char* argv[]) { } GraphDef a; - Status a_load_status = LoadTextOrBinaryGraphFile(argv[1], &a); + absl::Status a_load_status = LoadTextOrBinaryGraphFile(argv[1], &a); if (!a_load_status.ok()) { LOG(ERROR) << "Loading graph '" << argv[1] << "' failed with " << a_load_status.message(); @@ -54,7 +54,7 @@ int ParseFlagsAndCompareGraphs(int argc, char* argv[]) { } GraphDef b; - Status b_load_status = LoadTextOrBinaryGraphFile(argv[2], &b); + absl::Status b_load_status = LoadTextOrBinaryGraphFile(argv[2], &b); if (!b_load_status.ok()) { LOG(ERROR) << "Loading graph '" << argv[2] << "' failed with " << b_load_status.message(); diff --git a/tensorflow/tools/graph_transforms/file_utils.h b/tensorflow/tools/graph_transforms/file_utils.h index a3723f5cd38334..4185f7b4edc353 100644 --- a/tensorflow/tools/graph_transforms/file_utils.h +++ b/tensorflow/tools/graph_transforms/file_utils.h @@ -24,7 +24,8 @@ namespace graph_transforms { // First tries to load the file as a text protobuf, if that fails tries to parse // it as a binary protobuf, and returns an error if both fail. -Status LoadTextOrBinaryGraphFile(const string& file_name, GraphDef* graph_def); +absl::Status LoadTextOrBinaryGraphFile(const string& file_name, + GraphDef* graph_def); } // namespace graph_transforms } // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/flatten_atrous_test.cc b/tensorflow/tools/graph_transforms/flatten_atrous_test.cc index 3cfb7b668735e8..c6d77f43284fd6 100644 --- a/tensorflow/tools/graph_transforms/flatten_atrous_test.cc +++ b/tensorflow/tools/graph_transforms/flatten_atrous_test.cc @@ -29,9 +29,9 @@ namespace tensorflow { namespace graph_transforms { // Declare here, so we don't need a public header. -Status FlattenAtrousConv(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); +absl::Status FlattenAtrousConv(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); class FlattenAtrousConvTest : public ::testing::Test { protected: diff --git a/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc index 885fbd59b7797c..4ace0eedd9bfd2 100644 --- a/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc +++ b/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc @@ -30,9 +30,9 @@ namespace tensorflow { namespace graph_transforms { // Declare here, so we don't need a public header. -Status FoldBatchNorms(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); +absl::Status FoldBatchNorms(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); class FoldBatchNormsTest : public ::testing::Test { protected: diff --git a/tensorflow/tools/graph_transforms/fold_constants_lib.h b/tensorflow/tools/graph_transforms/fold_constants_lib.h index 0802ebb815ac71..dada5a74122b26 100644 --- a/tensorflow/tools/graph_transforms/fold_constants_lib.h +++ b/tensorflow/tools/graph_transforms/fold_constants_lib.h @@ -27,15 +27,15 @@ namespace graph_transforms { // with Const nodes, to simplify the graph. The inputs and outputs arguments are // the names of all the nodes that data is fed into, or read out of, when the // graph is actually run. -Status FoldConstants(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); +absl::Status FoldConstants(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); // Analyzes which nodes are used for the given set of inputs and outputs, and // returns a copy of the graph with any that aren't used removed. -Status RemoveUnusedNodes(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); +absl::Status RemoveUnusedNodes(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); } // namespace graph_transforms } // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/fold_constants_test.cc b/tensorflow/tools/graph_transforms/fold_constants_test.cc index 3d388cd665499f..deea4bbf0ee8bd 100644 --- a/tensorflow/tools/graph_transforms/fold_constants_test.cc +++ b/tensorflow/tools/graph_transforms/fold_constants_test.cc @@ -35,11 +35,11 @@ namespace tensorflow { namespace graph_transforms { // Declaring this here so it doesn't need to be in the public header. -Status ReplaceSendRecvs(const GraphDef& original_graph_def, - const GraphDef& rewritten_graph_def, - const std::vector& inputs, - const std::vector& outputs, - GraphDef* output_graph_def); +absl::Status ReplaceSendRecvs(const GraphDef& original_graph_def, + const GraphDef& rewritten_graph_def, + const std::vector& inputs, + const std::vector& outputs, + GraphDef* output_graph_def); class ConstantFoldingTest : public ::testing::Test { protected: diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc index 69000be1dfde7e..f7dad55698d519 100644 --- a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc +++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc @@ -31,9 +31,9 @@ namespace tensorflow { namespace graph_transforms { // Declare here, so we don't need a public header. -Status FoldOldBatchNorms(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); +absl::Status FoldOldBatchNorms(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); class FoldOldBatchNormsTest : public ::testing::Test { protected: diff --git a/tensorflow/tools/graph_transforms/freeze_requantization_ranges_test.cc b/tensorflow/tools/graph_transforms/freeze_requantization_ranges_test.cc index ab6b2ffef0ff87..36c554ffe25776 100644 --- a/tensorflow/tools/graph_transforms/freeze_requantization_ranges_test.cc +++ b/tensorflow/tools/graph_transforms/freeze_requantization_ranges_test.cc @@ -30,16 +30,16 @@ namespace tensorflow { namespace graph_transforms { // Declare here, so we don't need a public header. -Status FreezeRequantizationRanges(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); +absl::Status FreezeRequantizationRanges(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); struct MinMaxRecord { string name; float min; float max; }; -Status ExtractMinMaxRecords(const string& log_file_name, - std::vector* records); +absl::Status ExtractMinMaxRecords(const string& log_file_name, + std::vector* records); class FreezeRequantizationRangesTest : public ::testing::Test { protected: diff --git a/tensorflow/tools/graph_transforms/fuse_convolutions_test.cc b/tensorflow/tools/graph_transforms/fuse_convolutions_test.cc index b315b9caba1df9..95ef4338647171 100644 --- a/tensorflow/tools/graph_transforms/fuse_convolutions_test.cc +++ b/tensorflow/tools/graph_transforms/fuse_convolutions_test.cc @@ -29,15 +29,15 @@ namespace tensorflow { namespace graph_transforms { // Declare here, so we don't need a public header. -Status FuseResizePadAndConv(const GraphDef& input_graph_def, +absl::Status FuseResizePadAndConv(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); +absl::Status FuseResizeAndConv(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); +absl::Status FusePadAndConv(const GraphDef& input_graph_def, const TransformFuncContext& context, GraphDef* output_graph_def); -Status FuseResizeAndConv(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); -Status FusePadAndConv(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); class FuseConvolutionsTest : public ::testing::Test { protected: diff --git a/tensorflow/tools/graph_transforms/inline_partitionedcall_test.cc b/tensorflow/tools/graph_transforms/inline_partitionedcall_test.cc index 9523e71132728e..ec36d4245a747a 100644 --- a/tensorflow/tools/graph_transforms/inline_partitionedcall_test.cc +++ b/tensorflow/tools/graph_transforms/inline_partitionedcall_test.cc @@ -110,9 +110,9 @@ constexpr char kGraphDefWithPartitionedCall[] = "}\n"; // Declare here, so we don't need a public header. -Status InlinePartitionedCall(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); +absl::Status InlinePartitionedCall(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); TEST(InlinePartitionedCallTest, Inlining) { GraphDef in_graph; diff --git a/tensorflow/tools/graph_transforms/insert_logging.cc b/tensorflow/tools/graph_transforms/insert_logging.cc index 4db5a57c4b45fc..ccb96efdbd51bb 100644 --- a/tensorflow/tools/graph_transforms/insert_logging.cc +++ b/tensorflow/tools/graph_transforms/insert_logging.cc @@ -27,9 +27,9 @@ namespace tensorflow { namespace graph_transforms { // Clears the device field of all ops in the graph. -Status InsertLogging(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def) { +absl::Status InsertLogging(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { std::unordered_set ops; bool has_ops; if (context.params.count("op")) { diff --git a/tensorflow/tools/graph_transforms/insert_logging_test.cc b/tensorflow/tools/graph_transforms/insert_logging_test.cc index e1586a46e548df..1143cc9572244d 100644 --- a/tensorflow/tools/graph_transforms/insert_logging_test.cc +++ b/tensorflow/tools/graph_transforms/insert_logging_test.cc @@ -29,9 +29,9 @@ namespace tensorflow { namespace graph_transforms { // Declare here, so we don't need a public header. -Status InsertLogging(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); +absl::Status InsertLogging(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); class InsertLoggingTest : public ::testing::Test { protected: diff --git a/tensorflow/tools/graph_transforms/obfuscate_names_test.cc b/tensorflow/tools/graph_transforms/obfuscate_names_test.cc index 14df7ba74e0324..81dfbaba277f34 100644 --- a/tensorflow/tools/graph_transforms/obfuscate_names_test.cc +++ b/tensorflow/tools/graph_transforms/obfuscate_names_test.cc @@ -29,9 +29,9 @@ namespace tensorflow { namespace graph_transforms { // Declare here, so we don't need a public header. -Status ObfuscateNames(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); +absl::Status ObfuscateNames(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); class ObfuscateNamesTest : public ::testing::Test { protected: diff --git a/tensorflow/tools/graph_transforms/quantize_nodes_test.cc b/tensorflow/tools/graph_transforms/quantize_nodes_test.cc index 513db0b1f81695..98abbf9a59cd29 100644 --- a/tensorflow/tools/graph_transforms/quantize_nodes_test.cc +++ b/tensorflow/tools/graph_transforms/quantize_nodes_test.cc @@ -32,27 +32,27 @@ namespace tensorflow { namespace graph_transforms { // Declare here, so we don't need a public header. -Status QuantizeNodes(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); -Status RemoveRedundantQuantizations(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); -Status QuantizePlaceholders(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); -Status ConvertFakeQuantsToRequantize(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); -Status MergeAdjacentRequantizes(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); -Status HoistFakeQuants(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); -Status MergeDuplicateNodes(const GraphDef& input_graph_def, +absl::Status QuantizeNodes(const GraphDef& input_graph_def, const TransformFuncContext& context, GraphDef* output_graph_def); +absl::Status RemoveRedundantQuantizations(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); +absl::Status QuantizePlaceholders(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); +absl::Status ConvertFakeQuantsToRequantize(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); +absl::Status MergeAdjacentRequantizes(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); +absl::Status HoistFakeQuants(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); +absl::Status MergeDuplicateNodes(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); class QuantizeNodesTest : public ::testing::Test { protected: diff --git a/tensorflow/tools/graph_transforms/quantize_weights_test.cc b/tensorflow/tools/graph_transforms/quantize_weights_test.cc index a58ce73453dadd..ae5ae6fab8a184 100644 --- a/tensorflow/tools/graph_transforms/quantize_weights_test.cc +++ b/tensorflow/tools/graph_transforms/quantize_weights_test.cc @@ -29,9 +29,9 @@ namespace tensorflow { namespace graph_transforms { // Declare here, so we don't need a public header. -Status QuantizeWeights(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); +absl::Status QuantizeWeights(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); class QuantizeWeightsTest : public ::testing::Test { protected: diff --git a/tensorflow/tools/graph_transforms/remove_attribute_test.cc b/tensorflow/tools/graph_transforms/remove_attribute_test.cc index 77a69864b0f726..f4954201b567e7 100644 --- a/tensorflow/tools/graph_transforms/remove_attribute_test.cc +++ b/tensorflow/tools/graph_transforms/remove_attribute_test.cc @@ -29,9 +29,9 @@ namespace tensorflow { namespace graph_transforms { // Declare here, so we don't need a public header. -Status RemoveAttribute(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); +absl::Status RemoveAttribute(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); class RemoveAttributeTest : public ::testing::Test { protected: diff --git a/tensorflow/tools/graph_transforms/remove_device_test.cc b/tensorflow/tools/graph_transforms/remove_device_test.cc index 17a87cd2366877..2e28aa36da89bd 100644 --- a/tensorflow/tools/graph_transforms/remove_device_test.cc +++ b/tensorflow/tools/graph_transforms/remove_device_test.cc @@ -29,9 +29,9 @@ namespace tensorflow { namespace graph_transforms { // Declare here, so we don't need a public header. -Status RemoveDevice(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); +absl::Status RemoveDevice(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); class RemoveDeviceTest : public ::testing::Test { protected: diff --git a/tensorflow/tools/graph_transforms/remove_nodes_test.cc b/tensorflow/tools/graph_transforms/remove_nodes_test.cc index d8d85a3b47103e..01f7806d498d39 100644 --- a/tensorflow/tools/graph_transforms/remove_nodes_test.cc +++ b/tensorflow/tools/graph_transforms/remove_nodes_test.cc @@ -29,9 +29,9 @@ namespace tensorflow { namespace graph_transforms { // Declare here, so we don't need a public header. -Status RemoveNodes(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); +absl::Status RemoveNodes(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); class RemoveNodesTest : public ::testing::Test { protected: diff --git a/tensorflow/tools/graph_transforms/rename_attribute_test.cc b/tensorflow/tools/graph_transforms/rename_attribute_test.cc index 31619d82ad998a..e51048deaf176e 100644 --- a/tensorflow/tools/graph_transforms/rename_attribute_test.cc +++ b/tensorflow/tools/graph_transforms/rename_attribute_test.cc @@ -29,9 +29,9 @@ namespace tensorflow { namespace graph_transforms { // Declare here, so we don't need a public header. -Status RenameAttribute(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); +absl::Status RenameAttribute(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); class RenameAttributeTest : public ::testing::Test { protected: diff --git a/tensorflow/tools/graph_transforms/rename_node_test.cc b/tensorflow/tools/graph_transforms/rename_node_test.cc index 574272b8cca103..a18b3a626972d2 100644 --- a/tensorflow/tools/graph_transforms/rename_node_test.cc +++ b/tensorflow/tools/graph_transforms/rename_node_test.cc @@ -26,9 +26,9 @@ namespace tensorflow { namespace graph_transforms { // Declare here, so we don't need a public header. -Status RenameNode(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); +absl::Status RenameNode(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); TEST(RenameNodeTest, Rename) { GraphDef in_graph; diff --git a/tensorflow/tools/graph_transforms/rename_op_test.cc b/tensorflow/tools/graph_transforms/rename_op_test.cc index d09f2abaa9e649..dc604c0593f362 100644 --- a/tensorflow/tools/graph_transforms/rename_op_test.cc +++ b/tensorflow/tools/graph_transforms/rename_op_test.cc @@ -29,9 +29,9 @@ namespace tensorflow { namespace graph_transforms { // Declare here, so we don't need a public header. -Status RenameOp(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); +absl::Status RenameOp(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); class RenameOpTest : public ::testing::Test { protected: diff --git a/tensorflow/tools/graph_transforms/round_weights_test.cc b/tensorflow/tools/graph_transforms/round_weights_test.cc index 74700a2760ce1a..6d8731d54bc867 100644 --- a/tensorflow/tools/graph_transforms/round_weights_test.cc +++ b/tensorflow/tools/graph_transforms/round_weights_test.cc @@ -29,9 +29,9 @@ namespace tensorflow { namespace graph_transforms { // Declare here, so we don't need a public header. -Status RoundWeights(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); +absl::Status RoundWeights(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); class RoundWeightsTest : public ::testing::Test { protected: diff --git a/tensorflow/tools/graph_transforms/set_device_test.cc b/tensorflow/tools/graph_transforms/set_device_test.cc index fb64e0019d32ce..98ef4c421335ed 100644 --- a/tensorflow/tools/graph_transforms/set_device_test.cc +++ b/tensorflow/tools/graph_transforms/set_device_test.cc @@ -29,9 +29,9 @@ namespace tensorflow { namespace graph_transforms { // Declare here, so we don't need a public header. -Status SetDevice(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); +absl::Status SetDevice(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); namespace { GraphDef CreateDeviceGraph() { diff --git a/tensorflow/tools/graph_transforms/sort_by_execution_order.cc b/tensorflow/tools/graph_transforms/sort_by_execution_order.cc index 30578dc818c543..4fad49fb351b1f 100644 --- a/tensorflow/tools/graph_transforms/sort_by_execution_order.cc +++ b/tensorflow/tools/graph_transforms/sort_by_execution_order.cc @@ -28,7 +28,7 @@ namespace graph_transforms { // This is a thin wrapper with the standard TransformFunc interface to the // underlying utility function. The only difference is that we don't use the // input or output name arguments. -Status SortByExecutionOrderWithUnusedContext( +absl::Status SortByExecutionOrderWithUnusedContext( const GraphDef& input_graph_def, const TransformFuncContext& unused_context, GraphDef* output_graph_def) { return SortByExecutionOrder(input_graph_def, output_graph_def); diff --git a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc index 25b1555f014db1..a4b58b1f1ea983 100644 --- a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc +++ b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc @@ -29,10 +29,10 @@ namespace tensorflow { namespace graph_transforms { // Declarations so we don't need a public header. -Status SparsifyGather(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); -Status ReadTensorFromCheckpoint( +absl::Status SparsifyGather(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); +absl::Status ReadTensorFromCheckpoint( const string& tensor_name, const std::unique_ptr& ckpt_reader, const string& shape_and_slice, Tensor* tensor); diff --git a/tensorflow/tools/graph_transforms/strip_unused_nodes_test.cc b/tensorflow/tools/graph_transforms/strip_unused_nodes_test.cc index c0107014e2cf11..74f37064ec23eb 100644 --- a/tensorflow/tools/graph_transforms/strip_unused_nodes_test.cc +++ b/tensorflow/tools/graph_transforms/strip_unused_nodes_test.cc @@ -30,9 +30,9 @@ namespace tensorflow { namespace graph_transforms { // Declare here, so we don't need a public header. -Status StripUnusedNodes(const GraphDef& input_graph_def, - const TransformFuncContext& context, - GraphDef* output_graph_def); +absl::Status StripUnusedNodes(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); class StripUnusedNodesTest : public ::testing::Test { protected: diff --git a/tensorflow/tools/graph_transforms/transform_graph.h b/tensorflow/tools/graph_transforms/transform_graph.h index 58ec14193171c0..4082a162c08e31 100644 --- a/tensorflow/tools/graph_transforms/transform_graph.h +++ b/tensorflow/tools/graph_transforms/transform_graph.h @@ -33,16 +33,16 @@ int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main); // arguments. typedef std::vector> TransformParameters; -Status ParseTransformParameters(const string& transforms_string, - TransformParameters* params_list); +absl::Status ParseTransformParameters(const string& transforms_string, + TransformParameters* params_list); // Applies a series of transformations to the GraphDef. These transforms are // defined by modules that call REGISTER_GRAPH_TRANSFORM() to associate a // function with a name string. -Status TransformGraph(const std::vector& inputs, - const std::vector& outputs, - const TransformParameters& transform_params, - GraphDef* graph_def); +absl::Status TransformGraph(const std::vector& inputs, + const std::vector& outputs, + const TransformParameters& transform_params, + GraphDef* graph_def); } // namespace graph_transforms } // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/transform_utils.h b/tensorflow/tools/graph_transforms/transform_utils.h index 62253b44f0d805..4b724d63439faf 100644 --- a/tensorflow/tools/graph_transforms/transform_utils.h +++ b/tensorflow/tools/graph_transforms/transform_utils.h @@ -118,8 +118,8 @@ void RemoveAttributes(const GraphDef& input_graph_def, // For a lot of replacement and matching operations it's useful to have the // nodes processed in a controlled order, so this does a topological sort to // ensure that nodes always appear in the GraphDef.node list after their inputs. -Status SortByExecutionOrder(const GraphDef& input_graph_def, - GraphDef* output_graph_def); +absl::Status SortByExecutionOrder(const GraphDef& input_graph_def, + GraphDef* output_graph_def); // Finds inputs that refer to nodes that are not in the graph. void FindInvalidInputs(const GraphDef& graph_def, @@ -127,14 +127,15 @@ void FindInvalidInputs(const GraphDef& graph_def, // Returns a descriptive error status if there are problems spotted with the // graph. -Status IsGraphValid(const GraphDef& graph_def); +absl::Status IsGraphValid(const GraphDef& graph_def); // Returns input and output types for a particular NodeDef. -Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs, - DataTypeVector* outputs); +absl::Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs, + DataTypeVector* outputs); // Takes a comma-separated string of numbers and parses them into a shape. -Status TensorShapeFromString(const string& shape_string, TensorShape* result); +absl::Status TensorShapeFromString(const string& shape_string, + TensorShape* result); // This is used to spot particular subgraphs in a larger model. To use it, // create a pattern like: @@ -167,8 +168,8 @@ class GraphMatcher { // matches so that no node appears in more than one match. The NodeDef // pointers contained in the results are owned by the GraphMatcher object, and // so will be invalid after its lifetime. - Status GetOpTypeMatches(const OpTypePattern& pattern, - std::vector* matches); + absl::Status GetOpTypeMatches(const OpTypePattern& pattern, + std::vector* matches); private: bool DoesOpTypeMatch(const NodeDef& node, const OpTypePattern& pattern, @@ -195,11 +196,11 @@ struct ReplaceMatchingOpTypesOptions { // by setting allow_inconsistencies to true in the options, but then it's the // caller's responsibility to patch up any problems before passing on the graph // to others. There's more comprehensive usage documentation in the README. -Status ReplaceMatchingOpTypes( +absl::Status ReplaceMatchingOpTypes( const GraphDef& input_graph_def, const OpTypePattern& pattern, - const std::function&, - const std::set&, std::vector*)>& - node_generator, + const std::function&, + const std::set&, + std::vector*)>& node_generator, const ReplaceMatchingOpTypesOptions& options, GraphDef* output_graph_def); // Returns a list of the unique nodes found in this match. @@ -207,10 +208,10 @@ void MatchedNodesAsArray(const NodeMatch& match, std::vector* result); // Changes all input references to a particular node name. Any nodes with names // listed in nodes_to_ignore will not have their inputs rewritten. -Status RenameNodeInputs(const GraphDef& input_graph_def, - const std::map& inputs_to_rename, - const std::unordered_set& nodes_to_ignore, - GraphDef* output_graph_def); +absl::Status RenameNodeInputs(const GraphDef& input_graph_def, + const std::map& inputs_to_rename, + const std::unordered_set& nodes_to_ignore, + GraphDef* output_graph_def); // Utility function that copies all the nodes found in a match into the // new_nodes list. This is useful in replacement functions when you decide to @@ -228,38 +229,39 @@ struct TransformFuncContext { int CountParameters(const string& name) const; // Gets a single instance of a parameter, using a default if it's not present. - Status GetOneStringParameter(const string& name, const string& default_value, - string* result) const; + absl::Status GetOneStringParameter(const string& name, + const string& default_value, + string* result) const; // Gets a single occurrence of a parameter as a 32-bit integer, falling back // to a default if it isn't present and returning an error if it isn't // convertible to a number. - Status GetOneInt32Parameter(const string& name, int32_t default_value, - int32* result) const; + absl::Status GetOneInt32Parameter(const string& name, int32_t default_value, + int32* result) const; // Gets a single occurrence of a parameter as a 64-bit integer, falling back // to a default if it isn't present and returning an error if it isn't // convertible to a number. - Status GetOneInt64Parameter(const string& name, int64_t default_value, - int64_t* result) const; + absl::Status GetOneInt64Parameter(const string& name, int64_t default_value, + int64_t* result) const; // Gets a single occurrence of a parameter as a floating point number, falling // back to a default if it isn't present and returning an error if it isn't // convertible to a number. - Status GetOneFloatParameter(const string& name, float default_value, - float* result) const; + absl::Status GetOneFloatParameter(const string& name, float default_value, + float* result) const; // Gets a single occurrence of a parameter as a boolean, falling back to a // default if it isn't present and returning an error if it's not one of // "true", "1", "false", or "0". - Status GetOneBoolParameter(const string& name, bool default_value, - bool* result) const; + absl::Status GetOneBoolParameter(const string& name, bool default_value, + bool* result) const; }; // This is the function API for all graph transformations, taking an input // GraphDef and other arguments, and returning a transformed GraphDef. -typedef std::function +typedef std::function TransformFunc; // To add a new graph transform function, call the macro: diff --git a/tensorflow/tools/optimization/gpu_optimization_pass_runner_main.cc b/tensorflow/tools/optimization/gpu_optimization_pass_runner_main.cc index ad2856332dd0a0..300552914c230a 100644 --- a/tensorflow/tools/optimization/gpu_optimization_pass_runner_main.cc +++ b/tensorflow/tools/optimization/gpu_optimization_pass_runner_main.cc @@ -30,7 +30,7 @@ limitations under the License. namespace tensorflow { namespace { -Status RealMain(int argc, char** argv) { +absl::Status RealMain(int argc, char** argv) { string input_file_path; string output_file_path; string optimization_pass; diff --git a/tensorflow/tools/optimization/optimization_pass_runner.cc b/tensorflow/tools/optimization/optimization_pass_runner.cc index 8a81da83be6240..008cf9a6f50a58 100644 --- a/tensorflow/tools/optimization/optimization_pass_runner.cc +++ b/tensorflow/tools/optimization/optimization_pass_runner.cc @@ -52,11 +52,11 @@ class FakeDevice : public Device { : Device(nullptr, device_attributes) {} public: - Status Sync() override; + absl::Status Sync() override; static std::unique_ptr Make(const string& name, const string& type); }; -Status FakeDevice::Sync() { +absl::Status FakeDevice::Sync() { return errors::Unimplemented("FakeDevice::Sync()"); } @@ -68,8 +68,8 @@ std::unique_ptr FakeDevice::Make(const string& name, return std::unique_ptr(new FakeDevice(device_attributes)); } -Status FindPassWithName(absl::string_view name, - GraphOptimizationPass** result) { +absl::Status FindPassWithName(absl::string_view name, + GraphOptimizationPass** result) { *result = nullptr; // Run the optimization pass specified by the command line flag. for (const auto& groups_and_passes : @@ -93,8 +93,8 @@ Status FindPassWithName(absl::string_view name, } } // namespace -Status OptimizationPassRunner::Run(absl::string_view pass_to_run, - GraphDef input, GraphDef* result) { +absl::Status OptimizationPassRunner::Run(absl::string_view pass_to_run, + GraphDef input, GraphDef* result) { auto session_options = std::make_unique(); session_options->config.mutable_graph_options() ->mutable_optimizer_options() @@ -131,13 +131,14 @@ Status OptimizationPassRunner::Run(absl::string_view pass_to_run, return absl::OkStatus(); } -Status OptimizationPassRunner::SetJitLevel( +absl::Status OptimizationPassRunner::SetJitLevel( OptimizerOptions::GlobalJitLevel jit_level) { jit_level_ = jit_level; return absl::OkStatus(); } -Status OptimizationPassRunner::AddDevices(absl::string_view type, int count) { +absl::Status OptimizationPassRunner::AddDevices(absl::string_view type, + int count) { for (int i = 0; i < count; i++) { devices_.push_back(FakeDevice::Make( absl::StrCat("/job:localhost/replica:0/task:0/device:", type, ":", i), diff --git a/tensorflow/tools/optimization/optimization_pass_runner.h b/tensorflow/tools/optimization/optimization_pass_runner.h index 0b96ce3e5a9d47..5c81f2a13a7396 100644 --- a/tensorflow/tools/optimization/optimization_pass_runner.h +++ b/tensorflow/tools/optimization/optimization_pass_runner.h @@ -37,20 +37,21 @@ class OptimizationPassRunner { // Increasing the Jit level will cause XLA to compile parts of the tensorflow // graph that it is able to. - Status SetJitLevel(OptimizerOptions::GlobalJitLevel jit_level); + absl::Status SetJitLevel(OptimizerOptions::GlobalJitLevel jit_level); - Status Run(absl::string_view pass_to_run, GraphDef input, GraphDef* result); + absl::Status Run(absl::string_view pass_to_run, GraphDef input, + GraphDef* result); - Status AddCpus(int count) { + absl::Status AddCpus(int count) { return AddDevices(tensorflow::DEVICE_CPU, count); } - Status AddGpus(int count) { + absl::Status AddGpus(int count) { return AddDevices(tensorflow::DEVICE_GPU, count); } private: - Status AddDevices(absl::string_view type, int count); + absl::Status AddDevices(absl::string_view type, int count); OptimizerOptions::GlobalJitLevel jit_level_; std::vector> devices_; diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index c6ba8762df3b0f..25ac69f4ecbbc4 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -3,12 +3,17 @@ load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_syslibs//:build_defs.bzl", "if_not_system_lib") +load( + "@local_tsl//third_party/py:python_wheel_library.bzl", + "wheel_library", +) +load("@local_xla//xla/tsl:tsl.bzl", "if_cuda_libs") load("@local_xla//xla/tsl/mkl:build_defs.bzl", "if_enable_mkl", "if_mkl", "if_mkl_ml") -load("//tensorflow:tensorflow.bzl", "if_with_tpu_support", "transitive_hdrs") +load("//tensorflow:tensorflow.bzl", "if_wheel_dependency", "if_with_tpu_support", "transitive_hdrs") load("//tensorflow/core/platform:build_config_root.bzl", "tf_additional_license_deps") load("//tensorflow/tools/pip_package/utils:data_deps.bzl", "collect_data_files") load("//tensorflow/tools/pip_package/utils:py_deps.bzl", "transitive_py_deps") -load("//tensorflow/tools/pip_package/utils:tf_wheel.bzl", "tf_wheel") +load("//tensorflow/tools/pip_package/utils:tf_wheel.bzl", "tf_wheel", "tf_wheel_dep") package(default_visibility = ["//visibility:public"]) @@ -270,6 +275,17 @@ tf_wheel( ":licenses", "//tensorflow/core:protos_all_proto_srcs", ], + platform_name = select({ + "@platforms//os:osx": "macosx", + "@platforms//os:macos": "macosx", + "@platforms//os:windows": "win", + "@platforms//os:linux": "linux", + }), + platform_tag = select({ + "@platforms//cpu:aarch64": "arm64", + "@platforms//cpu:arm64": "arm64", + "@platforms//cpu:x86_64": "x86_64", + }), source_files = [ "MANIFEST.in", "//tensorflow/tools/pip_package:THIRD_PARTY_NOTICES.txt", @@ -292,3 +308,72 @@ tf_wheel( ":xla_cmake", ], ) + +genrule( + name = "empty_test", + outs = ["empty_test.py"], + cmd = "echo '' > $@", +) + +py_test( + name = "prebuilt_wheel_import_api_packages_test", + srcs = if_wheel_dependency( + ["import_api_packages_test.py"], + [":empty_test"], + ), + main = if_wheel_dependency("import_api_packages_test.py", "empty_test.py"), + tags = [ + "cpu", + "gpu", + "windows_excluded", + ], + deps = if_wheel_dependency(tf_wheel_dep()), +) + +py_test( + name = "import_api_packages_test", + srcs = ["import_api_packages_test.py"], + main = "import_api_packages_test.py", + tags = [ + "cpu", + "gpu", + "windows_excluded", + ], + deps = [ + ":tf_wheel_library", + ], +) + +wheel_library( + name = "tf_wheel_library", + wheel = ":wheel", + wheel_deps = if_cuda_libs([ + "@cuda_cublas//:cublas", + "@cuda_cublas//:cublasLt", + "@cuda_cudart//:cudart", + "@cuda_cudnn//:cudnn", + "@cuda_cufft//:cufft", + "@cuda_cupti//:cupti", + "@cuda_curand//:curand", + "@cuda_cusolver//:cusolver", + "@cuda_cusparse//:cusparse", + "@cuda_nccl//:nccl", + "@cuda_nvjitlink//:nvjitlink", + "@cuda_nvrtc//:nvrtc", + ]), + deps = [ + "@pypi_absl_py//:pkg", + "@pypi_astunparse//:pkg", + "@pypi_flatbuffers//:pkg", + "@pypi_gast//:pkg", + "@pypi_ml_dtypes//:pkg", + "@pypi_numpy//:pkg", + "@pypi_opt_einsum//:pkg", + "@pypi_packaging//:pkg", + "@pypi_protobuf//:pkg", + "@pypi_requests//:pkg", + "@pypi_termcolor//:pkg", + "@pypi_typing_extensions//:pkg", + "@pypi_wrapt//:pkg", + ], +) diff --git a/tensorflow/tools/pip_package/MANIFEST.in b/tensorflow/tools/pip_package/MANIFEST.in index dafc500d7f4106..bd8ada77aef5f9 100644 --- a/tensorflow/tools/pip_package/MANIFEST.in +++ b/tensorflow/tools/pip_package/MANIFEST.in @@ -12,4 +12,4 @@ recursive-include * *.csv recursive-include tensorflow * recursive-exclude tensorflow *.md recursive-exclude tensorflow/_api/ * -include tensorflow/_api/api_packages.txt \ No newline at end of file +include tensorflow/_api/v2/api_packages.txt \ No newline at end of file diff --git a/tensorflow/tools/pip_package/build_pip_package.py b/tensorflow/tools/pip_package/build_pip_package.py index d882db09a10c8a..39acf24022bb38 100644 --- a/tensorflow/tools/pip_package/build_pip_package.py +++ b/tensorflow/tools/pip_package/build_pip_package.py @@ -48,6 +48,11 @@ def parse_args() -> argparse.Namespace: help="Output file for the wheel, mandatory") parser.add_argument("--project-name", required=True, help="Project name to be passed to setup.py") + parser.add_argument( + "--platform", + required=True, + help="Platform name to be passed to setup.py", + ) parser.add_argument( "--headers", help="header files for the wheel", action="append") parser.add_argument("--srcs", help="source files for the wheel", @@ -350,14 +355,20 @@ def create_local_config_python(dst_dir: str) -> None: shutil.copytree(glob.glob(path)[0], os.path.join(dst_dir, "python_include")) -def build_wheel(dir_path: str, cwd: str, project_name: str, - collab: str = False) -> None: +def build_wheel( + dir_path: str, + cwd: str, + project_name: str, + platform: str, + collab: str = False, +) -> None: """Build the wheel in the target directory. - + Args: dir_path: directory where the wheel will be stored cwd: path to directory with wheel source files project_name: name to pass to setup.py. + platform: platform name to pass to setup.py. collab: defines if this is a collab build """ env = os.environ.copy() @@ -376,6 +387,7 @@ def build_wheel(dir_path: str, cwd: str, project_name: str, "tensorflow/tools/pip_package/setup.py", "bdist_wheel", f"--dist-dir={dir_path}", + f"--plat-name={platform}", ], check=True, cwd=cwd, @@ -390,7 +402,12 @@ def build_wheel(dir_path: str, cwd: str, project_name: str, try: prepare_wheel_srcs(args.headers, args.srcs, args.xla_aot, temp_dir_path, args.version) - build_wheel(os.path.join(os.getcwd(), args.output_name), - temp_dir_path, args.project_name, args.collab) + build_wheel( + os.path.join(os.getcwd(), args.output_name), + temp_dir_path, + args.project_name, + args.platform, + args.collab, + ) finally: temp_dir.cleanup() diff --git a/ci/official/wheel_test/test_import_api_packages.py b/tensorflow/tools/pip_package/import_api_packages_test.py similarity index 66% rename from ci/official/wheel_test/test_import_api_packages.py rename to tensorflow/tools/pip_package/import_api_packages_test.py index 1c9fb5365500b0..ca8849fef03978 100644 --- a/ci/official/wheel_test/test_import_api_packages.py +++ b/tensorflow/tools/pip_package/import_api_packages_test.py @@ -15,19 +15,33 @@ """Import API packages test. -This is a Python test that verifies whether API v2 packages can be imported -from the current build or not. - -It uses the `_api/v2/api_packages.txt` file from the local wheel file. -The `_api/v2/api_packages.txt` file is created during the process of generating -TensorFlow API v2 init files and is stored in the wheel file after the build. - -See README.md file for "how to run" instruction. +This Python test verifies whether the API v2 packages can be imported from the +current build. It utilizes the `_api/v2/api_packages.txt` list of packages from +the local wheel file specified in the `requirements_lock_.txt`. + +Packages are imported one by one in alphabetical order during runtime. + +The test doesn't identify package's order-dependent issues; for instance, +importing "tf.foo" followed by "tf.bar" won't reveal that "tf.bar" depends on +"tf.foo" being imported first. + +The `_api/v2/api_packages.txt` file is generated during the TensorFlow API v2 +init files creation process and is subsequently stored in the wheel file after +the build. It also contains a few paths that cannot be directly imported. These +paths point to attributes or sub-modules within a module's namespace, but they +don't correspond to an actual file or directory on the filesystem. The list of +such paths is stored in the packages_for_skip variable and will be skipped +during the test. """ import logging +import os import unittest -import pkg_resources + +try: + import importlib.resources as pkg_resources # pylint: disable=g-import-not-at-top +except ImportError: + import importlib_resources as pkg_resources # pylint: disable=g-import-not-at-top logging.basicConfig(level=logging.INFO) @@ -37,13 +51,14 @@ class ImportApiPackagesTest(unittest.TestCase): def setUp(self): def _get_api_packages_v2(): - api_packages_path = pkg_resources.resource_filename( - "tensorflow", "_api/v2/api_packages.txt" - ) - + api_packages_path = os.path.join("_api", "v2", "api_packages.txt") logging.info("Load api packages file: %s", api_packages_path) - with open(api_packages_path) as file: - return set(file.read().splitlines()) + return set( + pkg_resources.files("tensorflow") + .joinpath(api_packages_path) + .read_text() + .splitlines() + ) super().setUp() self.api_packages_v2 = _get_api_packages_v2() diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index 125c3cfd13a57c..556542864e3155 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -48,7 +48,7 @@ # result for pip. # Also update tensorflow/tensorflow.bzl and # tensorflow/core/public/version.h -_VERSION = '2.18.0' +_VERSION = '2.19.0' # We use the same setup.py for all tensorflow_* packages and for the nightly @@ -86,7 +86,7 @@ def standard_or_nightly(standard, nightly): 'packaging', # pylint:disable=line-too-long ( - 'protobuf>=3.20.3,<5.0.0dev,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5' + 'protobuf>=3.20.3,<6.0.0dev,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5' ), 'requests >= 2.21.0, < 3', 'setuptools', @@ -109,13 +109,11 @@ def standard_or_nightly(standard, nightly): # dependencies on the release branch is updated to the stable releases (RC # or final). For example, 'keras-nightly ~= 2.14.0.dev' will be replaced by # 'keras >= 2.14.0rc0, < 2.15' on the release branch after the branch cut. - 'tb-nightly ~= 2.18.0.a', - 'keras-nightly >= 3.2.0.dev', - # TODO(b/367877753): Update the upper bound to <2.2.0 once the compatibility - # issues with numpy 2.1.0 is fixed. - 'numpy >= 1.26.0, < 2.1.0', + 'tb-nightly ~= 2.19.0.a', + 'keras-nightly >= 3.6.0.dev', + 'numpy >= 1.26.0, < 2.2.0', 'h5py >= 3.11.0', - 'ml_dtypes >= 0.4.0, < 0.5.0', + 'ml_dtypes >= 0.4.0, < 1.0.0', ] REQUIRED_PACKAGES = [p for p in REQUIRED_PACKAGES if p is not None] @@ -161,7 +159,7 @@ def standard_or_nightly(standard, nightly): 'nvidia-curand-cu12 == 10.3.6.82', 'nvidia-cusolver-cu12 == 11.6.3.83', 'nvidia-cusparse-cu12 == 12.5.1.3', - 'nvidia-nccl-cu12 == 2.21.5', + 'nvidia-nccl-cu12 == 2.23.4', 'nvidia-nvjitlink-cu12 == 12.5.82', ] diff --git a/tensorflow/tools/pip_package/utils/tf_wheel.bzl b/tensorflow/tools/pip_package/utils/tf_wheel.bzl index 2d4ac55a8461ec..8c9d1dcc80f6d9 100644 --- a/tensorflow/tools/pip_package/utils/tf_wheel.bzl +++ b/tensorflow/tools/pip_package/utils/tf_wheel.bzl @@ -25,7 +25,40 @@ Should be set via --repo_env=WHEEL_NAME=tensorflow_cpu. load("@python_version_repo//:py_version.bzl", "WHEEL_COLLAB", "WHEEL_NAME", "OUTPUT_PATH") load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo") +<<<<<<< HEAD load("//tensorflow:tensorflow.bzl", "VERSION") +======= +load( + "@python_version_repo//:py_version.bzl", + "HERMETIC_PYTHON_VERSION", + "MACOSX_DEPLOYMENT_TARGET", + "WHEEL_COLLAB", + "WHEEL_NAME", +) +load("//tensorflow:tensorflow.bzl", "VERSION", "WHEEL_VERSION") + +def _get_wheel_platform_name(platform_name, platform_tag): + macos_platform_version = "{}_".format(MACOSX_DEPLOYMENT_TARGET.replace(".", "_")) if MACOSX_DEPLOYMENT_TARGET else "" + tag = platform_tag + if platform_tag == "x86_64" and platform_name == "win": + tag = "amd64" + if platform_tag == "arm64" and platform_name == "linux": + tag = "aarch64" + return "{platform_name}_{platform_version}{platform_tag}".format( + platform_name = platform_name, + platform_tag = tag, + platform_version = macos_platform_version, + ) + +def _get_full_wheel_name(platform_name, platform_tag): + python_version = HERMETIC_PYTHON_VERSION.replace(".", "") + return "{wheel_name}-{wheel_version}-cp{python_version}-cp{python_version}-{wheel_platform_tag}.whl".format( + wheel_name = WHEEL_NAME, + wheel_version = WHEEL_VERSION.replace("-", "."), + python_version = python_version, + wheel_platform_tag = _get_wheel_platform_name(platform_name, platform_tag), + ) +>>>>>>> master def _tf_wheel_impl(ctx): include_cuda_libs = ctx.attr.include_cuda_libs[BuildSettingInfo].value @@ -37,12 +70,33 @@ def _tf_wheel_impl(ctx): " If you absolutely need to add CUDA dependencies, provide `--@local_config_cuda//cuda:override_include_cuda_libs=true`.") executable = ctx.executable.wheel_binary +<<<<<<< HEAD output = ctx.actions.declare_directory("wheel_house") output_path = OUTPUT_PATH if OUTPUT_PATH else output.path +======= + full_wheel_name = _get_full_wheel_name( + platform_name = ctx.attr.platform_name, + platform_tag = ctx.attr.platform_tag, + ) + wheel_dir_name = "wheel_house" + output_dir = ctx.actions.declare_directory(wheel_dir_name) + output_file = ctx.actions.declare_file("{wheel_dir}/{wheel_name}".format( + wheel_dir = wheel_dir_name, + wheel_name = full_wheel_name, + )) +>>>>>>> master args = ctx.actions.args() args.add("--project-name", WHEEL_NAME) + args.add("--platform", _get_wheel_platform_name( + ctx.attr.platform_name, + ctx.attr.platform_tag, + )) args.add("--collab", str(WHEEL_COLLAB)) +<<<<<<< HEAD args.add("--output-name", output_path) +======= + args.add("--output-name", output_dir.path) +>>>>>>> master args.add("--version", VERSION) headers = ctx.files.headers[:] @@ -64,10 +118,10 @@ def _tf_wheel_impl(ctx): ctx.actions.run( arguments = [args], inputs = srcs + headers + xla_aot, - outputs = [output], + outputs = [output_dir, output_file], executable = executable, ) - return [DefaultInfo(files = depset(direct = [output]))] + return [DefaultInfo(files = depset(direct = [output_file]))] tf_wheel = rule( attrs = { @@ -81,6 +135,11 @@ tf_wheel = rule( ), "include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:include_cuda_libs")), "override_include_cuda_libs": attr.label(default = Label("@local_config_cuda//cuda:override_include_cuda_libs")), + "platform_tag": attr.string(mandatory = True), + "platform_name": attr.string(mandatory = True), }, implementation = _tf_wheel_impl, ) + +def tf_wheel_dep(): + return ["@pypi_{}//:pkg".format(WHEEL_NAME)] diff --git a/tensorflow/tools/proto_splitter/BUILD b/tensorflow/tools/proto_splitter/BUILD index b05a7c3688ff58..d4847b5f34c197 100644 --- a/tensorflow/tools/proto_splitter/BUILD +++ b/tensorflow/tools/proto_splitter/BUILD @@ -20,8 +20,6 @@ package( default_visibility = [ "__subpackages__", "//tensorflow:internal", - "//tensorflow/cc/experimental/tf2:__subpackages__", - "//tensorflow/cc/saved_model/image_format:__subpackages__", ], licenses = ["notice"], ) @@ -53,7 +51,6 @@ cc_library( # # py_proto_library( # name = "versions_proto_py_pb2", -# api_version = 2, # deps = [ # ":versions_proto", # ], @@ -61,7 +58,6 @@ cc_library( # # py_proto_library( # name = "chunk_proto_py_pb2", -# api_version = 2, # deps = [ # ":chunk_proto", # ], @@ -188,7 +184,6 @@ cc_library( name = "merge", hdrs = ["merge.h"], visibility = [ - "__subpackages__", "//tensorflow:internal", "//tensorflow/cc/experimental/tf2:__subpackages__", "//tensorflow/cc/saved_model/image_format:__subpackages__", diff --git a/tensorflow/tools/proto_splitter/testdata/BUILD b/tensorflow/tools/proto_splitter/testdata/BUILD index 5ab95c1313e7a4..11c508c8783fbb 100644 --- a/tensorflow/tools/proto_splitter/testdata/BUILD +++ b/tensorflow/tools/proto_splitter/testdata/BUILD @@ -53,7 +53,6 @@ tf_proto_library( # # py_proto_library( # name = "test_message_proto_py_pb2", -# api_version = 2, # deps = [ # ":test_message_proto", # ], diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile b/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile index 6fdf95d1e8ba94..2a950f002de0b3 100644 --- a/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile +++ b/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile @@ -1,5 +1,5 @@ ################################################################################ -FROM ubuntu:22.04@sha256:adbb90115a21969d2fe6fa7f9af4253e16d45f8d4c1e930182610c4731962658 as builder +FROM ubuntu:22.04@sha256:58b87898e82351c6cf9cf5b9f3c20257bb9e2dcf33af051e12ce532d7f94e3fe as builder ################################################################################ # Install devtoolset build dependencies diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/devel.packages.txt b/tensorflow/tools/tf_sig_build_dockerfiles/devel.packages.txt index fba0146f4299dd..6eb403d3fbce73 100644 --- a/tensorflow/tools/tf_sig_build_dockerfiles/devel.packages.txt +++ b/tensorflow/tools/tf_sig_build_dockerfiles/devel.packages.txt @@ -61,6 +61,7 @@ mlocate moreutils openjdk-21-jdk openjdk-21-jre-headless +parallel pkg-config python3-dev python3-setuptools diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/setup.python.sh b/tensorflow/tools/tf_sig_build_dockerfiles/setup.python.sh index bef0c6f32981f0..1d9bbab9d086e5 100755 --- a/tensorflow/tools/tf_sig_build_dockerfiles/setup.python.sh +++ b/tensorflow/tools/tf_sig_build_dockerfiles/setup.python.sh @@ -61,10 +61,16 @@ add-apt-repository -y 'ppa:deadsnakes/ppa' # Install Python packages for this container's version cat >pythons.txt <>>>>>> master EOF /setup.packages.sh pythons.txt diff --git a/tensorflow/tools/tfg_graph_transforms/tfg_graph_transforms_main.cc b/tensorflow/tools/tfg_graph_transforms/tfg_graph_transforms_main.cc index 465db3dbfaa3e7..10509ebb2335ca 100644 --- a/tensorflow/tools/tfg_graph_transforms/tfg_graph_transforms_main.cc +++ b/tensorflow/tools/tfg_graph_transforms/tfg_graph_transforms_main.cc @@ -132,7 +132,7 @@ void RegisterDialects(mlir::DialectRegistry& registry) { }); } -tensorflow::Status RunOptimizationPasses( +absl::Status RunOptimizationPasses( const mlir::PassPipelineCLParser& passPipeline, mlir::ModuleOp module, mlir::MLIRContext* context) { mlir::PassManager pm(context); @@ -191,12 +191,11 @@ absl::StatusOr> ImportModel( } } -tensorflow::Status ExportTFGModule(mlir::ModuleOp module_op, - DataFormat data_format, - const std::string& input_file, - const std::string& output_file, - bool experimental_image_format, - int experimental_image_format_max_size) { +absl::Status ExportTFGModule(mlir::ModuleOp module_op, DataFormat data_format, + const std::string& input_file, + const std::string& output_file, + bool experimental_image_format, + int experimental_image_format_max_size) { switch (data_format) { case DataFormat::SavedModel: { tensorflow::SavedModel original_saved_model; @@ -276,14 +275,14 @@ int main(int argc, char** argv) { // Parse the optimization pipeline configuration and run requested graph // optimizations. - tensorflow::Status pass_pipeline_status = + absl::Status pass_pipeline_status = RunOptimizationPasses(pass_pipeline, *module_ref, &context); if (!pass_pipeline_status.ok()) { LOG(QFATAL) << pass_pipeline_status << "\n"; } // Export MLIR TFG module to the resulting model proto. - tensorflow::Status export_status = ExportTFGModule( + absl::Status export_status = ExportTFGModule( *module_ref, data_format, input_file, output_file, experimental_image_format, experimental_image_format_max_proto_size); diff --git a/tensorflow/tools/tfg_graph_transforms/utils.cc b/tensorflow/tools/tfg_graph_transforms/utils.cc index a5b6f0af916518..2fe2e9476e7659 100644 --- a/tensorflow/tools/tfg_graph_transforms/utils.cc +++ b/tensorflow/tools/tfg_graph_transforms/utils.cc @@ -43,15 +43,15 @@ bool IsTextProto(const std::string& input_file) { return !extension.compare("pbtxt"); } -tensorflow::Status ReadSavedModelImageFormat( - const std::string& input_file, tensorflow::SavedModel& model_proto) { +absl::Status ReadSavedModelImageFormat(const std::string& input_file, + tensorflow::SavedModel& model_proto) { std::string saved_model_prefix(GetNameWithoutExtension(input_file)); return tensorflow::image_format::ReadSavedModel(saved_model_prefix, &model_proto); } -tensorflow::Status WriteSavedModelImageFormat( - tensorflow::SavedModel* model_proto, const std::string& output_file, - int debug_max_size) { +absl::Status WriteSavedModelImageFormat(tensorflow::SavedModel* model_proto, + const std::string& output_file, + int debug_max_size) { std::string saved_model_prefix(GetNameWithoutExtension(output_file)); if (debug_max_size > 0) { return tensorflow::image_format::WriteSavedModel( diff --git a/tensorflow/tools/tfg_graph_transforms/utils.h b/tensorflow/tools/tfg_graph_transforms/utils.h index ec4ca3c781b40d..9ea59a385ad6ee 100644 --- a/tensorflow/tools/tfg_graph_transforms/utils.h +++ b/tensorflow/tools/tfg_graph_transforms/utils.h @@ -36,8 +36,7 @@ namespace graph_transforms { // If the format of proto cannot be identified based on the file extension, // attempts to load in a binary format first and then in a text format. template -tensorflow::Status ReadModelProto(const std::string& input_file, - T& model_proto) { +absl::Status ReadModelProto(const std::string& input_file, T& model_proto) { // Proto might be either in binary or text format. tensorflow::StringPiece extension = tensorflow::io::Extension(input_file); bool binary_extenstion = !extension.compare("pb"); @@ -76,8 +75,7 @@ tensorflow::Status ReadModelProto(const std::string& input_file, bool IsTextProto(const std::string& input_file); template -tensorflow::Status SerializeProto(T model_proto, - const std::string& output_file) { +absl::Status SerializeProto(T model_proto, const std::string& output_file) { auto output_dir = tensorflow::io::Dirname(output_file); TF_RETURN_IF_ERROR(tensorflow::Env::Default()->RecursivelyCreateDir( @@ -97,11 +95,11 @@ tensorflow::Status SerializeProto(T model_proto, } // Read and write to the experimental SavedModel Image format. -tensorflow::Status ReadSavedModelImageFormat( - const std::string& input_file, tensorflow::SavedModel& model_proto); -tensorflow::Status WriteSavedModelImageFormat( - tensorflow::SavedModel* model_proto, const std::string& output_file, - int debug_max_size); +absl::Status ReadSavedModelImageFormat(const std::string& input_file, + tensorflow::SavedModel& model_proto); +absl::Status WriteSavedModelImageFormat(tensorflow::SavedModel* model_proto, + const std::string& output_file, + int debug_max_size); } // namespace graph_transforms } // namespace tfg diff --git a/tensorflow/tools/toolchains/remote_config/configs.bzl b/tensorflow/tools/toolchains/remote_config/configs.bzl index 2feee8960439e9..1182e52997fce0 100644 --- a/tensorflow/tools/toolchains/remote_config/configs.bzl +++ b/tensorflow/tools/toolchains/remote_config/configs.bzl @@ -9,7 +9,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9", - compiler = "/usr/lib/llvm-18/bin/clang", cuda_version = "12.3.2", cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", @@ -17,7 +16,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1", - compiler = "/usr/lib/llvm-18/bin/clang", cuda_version = "12.3.2", cudnn_version = "9.1.1", os = "ubuntu20.04-manylinux2014-multipython", @@ -25,8 +23,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", - compiler = "/dt9/usr/bin/gcc", - compiler_prefix = "/usr/bin", cuda_version = "12.3.2", cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", @@ -34,7 +30,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu22.04-clang_manylinux2014-cuda12.3-cudnn8.9", - compiler = "/usr/lib/llvm-18/bin/clang", cuda_version = "12.3.2", cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", @@ -42,8 +37,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu22.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", - compiler = "/dt9/usr/bin/gcc", - compiler_prefix = "/usr/bin", cuda_version = "12.3.2", cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", @@ -70,25 +63,6 @@ def initialize_rbe_configs(): "sigbuild-r2.16-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", "sigbuild-r2.16-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:40fcd1d05c672672b599d9cb3784dcf379d6aa876f043b46c6ab18237d5d4e10", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/dt9/usr/bin/gcc", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/dt9/usr/bin/gcc", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "GCC_HOST_COMPILER_PATH": "/dt9/usr/bin/gcc", - "GCC_HOST_COMPILER_PREFIX": "/usr/bin", - "HOST_CXX_COMPILER": "/dt9/usr/bin/gcc", - "HOST_C_COMPILER": "/dt9/usr/bin/gcc", - "TF_ENABLE_XLA": "1", - }, ) sigbuild_tf_configs( @@ -99,23 +73,6 @@ def initialize_rbe_configs(): "sigbuild-r2.16-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", "sigbuild-r2.16-clang-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:40fcd1d05c672672b599d9cb3784dcf379d6aa876f043b46c6ab18237d5d4e10", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/usr/lib/llvm-17/bin/clang", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/usr/lib/llvm-17/bin/clang", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "HOST_CXX_COMPILER": "/usr/lib/llvm-17/bin/clang", - "HOST_C_COMPILER": "/usr/lib/llvm-17/bin/clang", - "TF_ENABLE_XLA": "1", - }, ) sigbuild_tf_configs( @@ -126,25 +83,6 @@ def initialize_rbe_configs(): "sigbuild-r2.17-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:b6f572a897a69fa3311773f949b9aa9e81bc393e4fbe2c0d56d8afb03a6de080", "sigbuild-r2.17-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:8b856ad736147bb9c8bc9e1ec2c8e1ab17d36397905da7a5b63dadeff9310f0c", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/dt9/usr/bin/gcc", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/dt9/usr/bin/gcc", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "GCC_HOST_COMPILER_PATH": "/dt9/usr/bin/gcc", - "GCC_HOST_COMPILER_PREFIX": "/usr/bin", - "HOST_CXX_COMPILER": "/dt9/usr/bin/gcc", - "HOST_C_COMPILER": "/dt9/usr/bin/gcc", - "TF_ENABLE_XLA": "1", - }, ) sigbuild_tf_configs( @@ -155,23 +93,6 @@ def initialize_rbe_configs(): "sigbuild-r2.17-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:2d737fc9fe931507a89927eee792b1bb934215e6aaae58b1941586e3400e2645", "sigbuild-r2.17-clang-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:45ea78e79305f91cdae5a26094f80233bba54bbfbc612623381012f097035b9a", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/usr/lib/llvm-18/bin/clang", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/usr/lib/llvm-18/bin/clang", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "HOST_CXX_COMPILER": "/usr/lib/llvm-18/bin/clang", - "HOST_C_COMPILER": "/usr/lib/llvm-18/bin/clang", - "TF_ENABLE_XLA": "1", - }, ) sigbuild_tf_configs( @@ -182,21 +103,4 @@ def initialize_rbe_configs(): "sigbuild-r2.17-clang-cudnn9-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:daa5bdd802fe3def188e2200ed707c73d278f6f1930bf26c933d6ba041b0e027", "sigbuild-r2.17-clang-cudnn9-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:23e477895dd02e45df1056d4a0a9c4229dec3a20c23fb2f3fb5832ecbd0a29bc", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/usr/lib/llvm-18/bin/clang", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/usr/lib/llvm-18/bin/clang", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "HOST_CXX_COMPILER": "/usr/lib/llvm-18/bin/clang", - "HOST_C_COMPILER": "/usr/lib/llvm-18/bin/clang", - "TF_ENABLE_XLA": "1", - }, ) diff --git a/tensorflow/tools/toolchains/remote_config/containers.bzl b/tensorflow/tools/toolchains/remote_config/containers.bzl index c976ddc7dbbdd5..b22fbe0b65ad2e 100644 --- a/tensorflow/tools/toolchains/remote_config/containers.bzl +++ b/tensorflow/tools/toolchains/remote_config/containers.bzl @@ -7,12 +7,12 @@ container_digests = { # JAX manylinux2014 configs. "cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython": "sha256:45619e91f14faabddd79fe0cb1526df4c4ad92fc2e6ebdc725ea4419225429c3", "cuda12.1-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:8c266e5b0acd203aed5e8871b63f68a39d8d23f6d882e619797e58b973f7fe63", - "cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython": "sha256:7d948c3d2e3ab8867d600457b5666cc74c4206f08517791c95fc9a69b7cffefa", + "cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython": "sha256:fafe12fbe5bb02a21b9a95aa9dc3ac6d0e6276fcb7dd26bf1bb2d093b444b71a", "cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:9fefda035b4a12b24cd5bae56c7dbb9527a5fd06a41ced0a22ac86fe5ed26428", "cuda12.2-cudnn9.1-ubuntu20.04-manylinux2014-multipython": "sha256:0c78f3428cde36f041b758fc2f01d23d2f0dd72dec248f78667fb0c9d1f74cef", "cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:6f9524a2ed7f75255dc4be3a0c5e3bda581385a1c13e2fa890bc17fa62da95b2", "cuda12.3-cudnn8.9-ubuntu22.04-manylinux2014-multipython": "sha256:97b219abb22994cf0530771d536f26fe301bacd328f0485c38af3847c2ee6b14", - "cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython": "sha256:e590303ea55a0990c26db4640161120ff6bc4124152c62155d397ba22d2ca850", + "cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython": "sha256:a9acf6849a905079847074798405b18d4badc6270dc32076f9e7ac4b377e51a8", # ROCM, probably not all of them still in use "rocm-ubuntu18.04-manylinux2010-multipython": "sha256:6e953a09b145df338bcb03e9e36f99b291140c29b72d0a048fb6c5905ccad5eb", "rocm-ubuntu20.04-manylinux2014-multipython": "sha256:906faec7765fe5dd067f2b092b5d5f220c1fedde725fb42c83d031b4d6f32204", diff --git a/tensorflow/tools/toolchains/remote_config/rbe_config.bzl b/tensorflow/tools/toolchains/remote_config/rbe_config.bzl index a916c10e77d634..8a6120efbbd69d 100644 --- a/tensorflow/tools/toolchains/remote_config/rbe_config.bzl +++ b/tensorflow/tools/toolchains/remote_config/rbe_config.bzl @@ -9,34 +9,13 @@ def _container_image_uri(container_name): container = containers[container_name] return "docker://%s/%s@%s" % (container["registry"], container["repository"], container["digest"]) -def _tensorflow_rbe_config(name, compiler, os, rocm_version = None, cuda_version = None, cudnn_version = None, compiler_prefix = None): +def _tensorflow_rbe_config(name, os, rocm_version = None, cuda_version = None, cudnn_version = None): if cuda_version != None and rocm_version != None: fail("Specifying both cuda_version and rocm_version is not supported.") - env = { - "ABI_VERSION": "gcc", - "ABI_LIBC_VERSION": "glibc_2.19", - "BAZEL_COMPILER": compiler, - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CC": compiler, - "CLEAR_CACHE": "1", - "HOST_CXX_COMPILER": compiler, - "HOST_C_COMPILER": compiler, - } + env = {} if cuda_version != None: - # The cuda toolchain currently contains its own C++ toolchain definition, - # so we do not fetch local_config_cc. - env.update({ - "TF_ENABLE_XLA": "1", - "GCC_HOST_COMPILER_PATH": compiler if not compiler.endswith("clang") else "", - "GCC_HOST_COMPILER_PREFIX": compiler_prefix if compiler_prefix != None else "/usr/bin", - }) - cuda_version_in_container = ".".join(cuda_version.split(".")[:2]) cudnn_version_in_container = ".".join(cudnn_version.split(".")[:2]) container_name = "cuda%s-cudnn%s-%s" % ( @@ -49,13 +28,11 @@ def _tensorflow_rbe_config(name, compiler, os, rocm_version = None, cuda_version "container-image": container_image, "Pool": "default", } - elif rocm_version != None: # The rocm toolchain currently contains its own C++ toolchain definition, # so we do not fetch local_config_cc. env.update({ "TF_NEED_ROCM": "1", - "TF_ENABLE_XLA": "0", }) container_name = "rocm-%s" % (os) @@ -121,9 +98,8 @@ tensorflow_local_config = _tensorflow_local_config # Streamlined platform configuration for the SIG Build containers. # See //tensorflow/tools/tf_sig_build_dockerfiles -# These containers do not support ROCm and all have CUDA. We demand that the configuration -# provide all the env variables to remove hidden logic. -def sigbuild_tf_configs(name_container_map, env): +# These containers do not support ROCm and all have CUDA. +def sigbuild_tf_configs(name_container_map): for name, container in name_container_map.items(): exec_properties = { "container-image": container, diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index be83c971749341..fedad18909d154 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -150,18 +150,18 @@ def _tf_repositories(): # LINT.IfChange tf_http_archive( name = "XNNPACK", - sha256 = "f66213a4d66991b2a44400f95fcd260adf6f4f7077956cdf7fce2571d6164d5e", - strip_prefix = "XNNPACK-6b83f69d4938da4dc9ad63c00bd13e9695659a51", - urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/6b83f69d4938da4dc9ad63c00bd13e9695659a51.zip"), + sha256 = "bfedea7d94d4b7953a857868b63eda27a2d8206c79dd0b0456d4150cc43bf825", + strip_prefix = "XNNPACK-743f95f0c34b02d6d2cdb9e87da21caffe9c668f", + urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/743f95f0c34b02d6d2cdb9e87da21caffe9c668f.zip"), ) # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/xnnpack.cmake) # XNNPack dependency. tf_http_archive( name = "KleidiAI", - sha256 = "88233e427be6579560073267575f00f3b5fc370a31a43bbdd87a1810bd4bf1b6", - strip_prefix = "kleidiai-cddf991af5de49fd34949fa39690e4e906e04074", - urls = tf_mirror_urls("https://gitlab.arm.com/kleidi/kleidiai/-/archive/cddf991af5de49fd34949fa39690e4e906e04074/kleidiai-cddf991af5de49fd34949fa39690e4e906e04074.zip"), + sha256 = "6682b7a2795c711c1dd23ada552675b6514523e991043753648f2cad826f588f", + strip_prefix = "kleidiai-382b07835c43fcb0401cb4dab3c8fb85eaf187b6", + urls = tf_mirror_urls("https://gitlab.arm.com/kleidi/kleidiai/-/archive/382b07835c43fcb0401cb4dab3c8fb85eaf187b6/kleidiai-382b07835c43fcb0401cb4dab3c8fb85eaf187b6.zip"), ) tf_http_archive( @@ -180,9 +180,10 @@ def _tf_repositories(): tf_http_archive( name = "cpuinfo", - sha256 = "2bf2b62eb86e2d2eaf862d0b9683a6c467a4d69fb2f7f1dc47c799809148608f", - strip_prefix = "cpuinfo-fa1c679da8d19e1d87f20175ae1ec10995cd3dd3", - urls = tf_mirror_urls("https://github.com/pytorch/cpuinfo/archive/fa1c679da8d19e1d87f20175ae1ec10995cd3dd3.zip"), + sha256 = "ca31f17a86e4db01b5fc05efa1807ddc84c02ba4611464b67e185e8210bf096b", + strip_prefix = "cpuinfo-1e83a2fdd3102f65c6f1fb602c1b320486218a99", + patch_file = ["//third_party/cpuinfo:cpuinfo_ppc64le_support.patch"], + urls = tf_mirror_urls("https://github.com/pytorch/cpuinfo/archive/1e83a2fdd3102f65c6f1fb602c1b320486218a99.zip"), ) tf_http_archive( @@ -524,9 +525,17 @@ def _tf_repositories(): name = "nccl_archive", build_file = "//third_party:nccl/archive.BUILD", patch_file = ["//third_party/nccl:archive.patch"], - sha256 = "1923596984d85e310b5b6c52b2c72a1b93da57218f2bc5a5c7ac3d59297a3303", - strip_prefix = "nccl-2.21.5-1", - urls = tf_mirror_urls("https://github.com/nvidia/nccl/archive/v2.21.5-1.tar.gz"), + sha256 = "6b946b70a9d2d01871842cbd15ec56488d358abe9a0f3767e372fddc3e241ba7", + strip_prefix = "nccl-2.23.4-1", + urls = tf_mirror_urls("https://github.com/nvidia/nccl/archive/v2.23.4-1.tar.gz"), + ) + + tf_http_archive( + name = "nvtx_archive", + build_file = "//third_party:nvtx/BUILD", + sha256 = "e4438f921fb88a564b0b92791c1c1fdd0f388901213e6a31fdd0dc3803fb9764", + strip_prefix = "NVTX-bf31d7859ab3130cbf1ef77c33d18d0ebb8c8d08/c/include", + urls = tf_mirror_urls("https://github.com/NVIDIA/NVTX/archive/bf31d7859ab3130cbf1ef77c33d18d0ebb8c8d08.tar.gz"), ) java_import_external( @@ -616,15 +625,6 @@ def _tf_repositories(): urls = tf_mirror_urls("https://github.com/NVlabs/cub/archive/1.9.9.zip"), ) - # Note that we are currently taking NVTX headers from a NCCL release to get nvToolsExtPayload.h - tf_http_archive( - name = "nvtx_archive", - build_file = "//third_party:nvtx/BUILD", - sha256 = "1923596984d85e310b5b6c52b2c72a1b93da57218f2bc5a5c7ac3d59297a3303", - strip_prefix = "nccl-2.21.5-1/src/include/nvtx3", - urls = tf_mirror_urls("https://github.com/nvidia/nccl/archive/v2.21.5-1.tar.gz"), - ) - tf_http_archive( name = "cython", build_file = "//third_party:cython.BUILD", @@ -797,7 +797,10 @@ def _tf_repositories(): urls = tf_mirror_urls("https://github.com/pybind/pybind11_protobuf/archive/80f3440cd8fee124e077e2e47a8a17b78b451363.zip"), sha256 = "c7ab64b1ccf9a678694a89035a8c865a693e4e872803778f91f0965c2f281d78", strip_prefix = "pybind11_protobuf-80f3440cd8fee124e077e2e47a8a17b78b451363", - patch_file = ["//third_party/pybind11_protobuf:remove_license.patch"], + patch_file = [ + "//third_party/pybind11_protobuf:protobuf.patch", + "//third_party/pybind11_protobuf:remove_license.patch", + ], ) tf_http_archive( diff --git a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/ml_dtypes/BUILD b/third_party/cpuinfo/BUILD similarity index 100% rename from third_party/xla/third_party/tsl/third_party/py/non_hermetic/ml_dtypes/BUILD rename to third_party/cpuinfo/BUILD diff --git a/third_party/cpuinfo/cpuinfo_ppc64le_support.patch b/third_party/cpuinfo/cpuinfo_ppc64le_support.patch new file mode 100644 index 00000000000000..9d735af3b5f89e --- /dev/null +++ b/third_party/cpuinfo/cpuinfo_ppc64le_support.patch @@ -0,0 +1,24 @@ +diff --git a/BUILD.bazel b/BUILD.bazel +index 2c6375f..5417d7e 100644 +--- a/BUILD.bazel ++++ b/BUILD.bazel +@@ -137,6 +137,7 @@ cc_library( + ":linux_riscv32": COMMON_SRCS + RISCV_SRCS + LINUX_SRCS + LINUX_RISCV_SRCS, + ":linux_riscv64": COMMON_SRCS + RISCV_SRCS + LINUX_SRCS + LINUX_RISCV_SRCS, + ":linux_s390x": COMMON_SRCS + LINUX_SRCS, ++ ":linux_ppc64le": COMMON_SRCS + LINUX_SRCS, + ":macos_x86_64": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS, + ":macos_x86_64_legacy": COMMON_SRCS + X86_SRCS + MACH_SRCS + MACH_X86_SRCS, + ":macos_arm64": COMMON_SRCS + MACH_SRCS + MACH_ARM_SRCS, +@@ -277,6 +278,11 @@ config_setting( + values = {"cpu": "s390x"}, + ) + ++config_setting( ++ name = "linux_ppc64le", ++ values = {"cpu": "ppc"}, ++) ++ + config_setting( + name = "macos_x86_64_legacy", + values = { diff --git a/third_party/gpus/check_cuda_libs.py b/third_party/gpus/check_cuda_libs.py index b1a10a86b9aac6..a1d47efcc93a81 100644 --- a/third_party/gpus/check_cuda_libs.py +++ b/third_party/gpus/check_cuda_libs.py @@ -27,16 +27,10 @@ import os import os.path import platform +import shutil import subprocess import sys -# pylint: disable=g-import-not-at-top,g-importing-member -try: - from shutil import which -except ImportError: - from distutils.spawn import find_executable as which -# pylint: enable=g-import-not-at-top,g-importing-member - class ConfigError(Exception): pass @@ -59,7 +53,7 @@ def check_cuda_lib(path, check_soname=True): """ if not os.path.isfile(path): raise ConfigError("No library found under: " + path) - objdump = which("objdump") + objdump = shutil.which("objdump") if check_soname and objdump is not None and not _is_windows(): # Decode is necessary as in py3 the return type changed from str to bytes output = subprocess.check_output([objdump, "-p", path]).decode("utf-8") diff --git a/third_party/gpus/crosstool/BUILD.rocm.tpl b/third_party/gpus/crosstool/BUILD.rocm.tpl index 9eac59cb532506..011cdc72bf0cfc 100644 --- a/third_party/gpus/crosstool/BUILD.rocm.tpl +++ b/third_party/gpus/crosstool/BUILD.rocm.tpl @@ -82,18 +82,22 @@ cc_toolchain_config( "-fdata-sections", ], dbg_compile_flags = ["-g"], - cxx_flags = ["-std=c++14"], + cxx_flags = ["-std=c++17"], link_flags = [ "-Wl,-no-as-needed", "-Wl,-z,relro,-z,now", - "-pass-exit-codes", + ], + link_libs = [ "-lstdc++", "-lm", +<<<<<<< HEAD ] + [%{link_flags}], link_libs = [], +======= + ], +>>>>>>> master opt_link_flags = [], unfiltered_compile_flags = [ - "-fno-canonical-system-headers", "-Wno-builtin-macro-redefined", "-D__DATE__=\"redacted\"", "-D__TIMESTAMP__=\"redacted\"", diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl index a89e2a146d4921..7eb6fc5543ea89 100755 --- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl +++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl @@ -24,8 +24,10 @@ import pipes # Template values set by rocm_configure.bzl. CPU_COMPILER = ('%{cpu_compiler}') +HOST_COMPILER_PATH = ('%{host_compiler_path}') HIPCC_PATH = '%{hipcc_path}' +PREFIX_DIR = os.path.dirname(HOST_COMPILER_PATH) HIPCC_ENV = '%{hipcc_env}' HIP_RUNTIME_PATH = '%{hip_runtime_path}' HIP_RUNTIME_LIBRARY = '%{hip_runtime_library}' @@ -81,6 +83,7 @@ def GetHostCompilerOptions(argv): parser.add_argument('--sysroot', nargs=1) parser.add_argument('-g', nargs='*', action='append') parser.add_argument('-fno-canonical-system-headers', action='store_true') + parser.add_argument('-no-canonical-prefixes', action='store_true') parser.add_argument('--genco', action='store_true') args, _ = parser.parse_known_args(argv) @@ -93,7 +96,7 @@ def GetHostCompilerOptions(argv): opts += ' -iquote ' + ' -iquote '.join(sum(args.iquote, [])) if args.g: opts += ' -g' + ' -g'.join(sum(args.g, [])) - if args.fno_canonical_system_headers: + if args.fno_canonical_system_headers or args.no_canonical_prefixes: opts += ' -no-canonical-prefixes' if args.sysroot: opts += ' --sysroot ' + args.sysroot[0] diff --git a/third_party/gpus/crosstool/hipcc_cc_toolchain_config.bzl.tpl b/third_party/gpus/crosstool/hipcc_cc_toolchain_config.bzl.tpl index e0541defa34687..e5a942b66c17fc 100644 --- a/third_party/gpus/crosstool/hipcc_cc_toolchain_config.bzl.tpl +++ b/third_party/gpus/crosstool/hipcc_cc_toolchain_config.bzl.tpl @@ -1046,7 +1046,6 @@ def _impl(ctx): flag_group( flags = [ "-no-canonical-prefixes", - "-fno-canonical-system-headers", ] ), ], diff --git a/third_party/gpus/cuda/build_defs.bzl.tpl b/third_party/gpus/cuda/build_defs.bzl.tpl index 2faabefe081f4b..6c1b68ffb77bcf 100644 --- a/third_party/gpus/cuda/build_defs.bzl.tpl +++ b/third_party/gpus/cuda/build_defs.bzl.tpl @@ -149,11 +149,14 @@ def cuda_header_library( **kwargs ) -def cuda_library(copts = [], tags = [],**kwargs): +def cuda_library(copts = [], tags = [], deps = [], **kwargs): """Wrapper over cc_library which adds default CUDA options.""" native.cc_library( copts = cuda_default_copts() + copts, tags = tags + ["gpu"], + deps = deps + if_cuda_is_configured([ + "@local_config_cuda//cuda:implicit_cuda_headers_dependency", + ]), **kwargs ) diff --git a/third_party/gpus/cuda/hermetic/BUILD.tpl b/third_party/gpus/cuda/hermetic/BUILD.tpl index 5d9a9da3c967d8..58c4638dd55c3f 100644 --- a/third_party/gpus/cuda/hermetic/BUILD.tpl +++ b/third_party/gpus/cuda/hermetic/BUILD.tpl @@ -69,6 +69,16 @@ cc_library( ":nvjitlink_headers"], ) +# This target is needed by the `cuda_library` rule. We can't implicitly +# depend on `:cuda_headers` directly since the user may explicit depend +# on `:cuda_headers` and duplicated dependencies are not allowed in Bazel. +# There is also no good way to deduplicate dependencies, but an alias works +# just fine. +alias( + name = "implicit_cuda_headers_dependency", + actual = ":cuda_headers", +) + cc_library( name = "cudart_static", srcs = ["@cuda_cudart//:static"], @@ -79,6 +89,11 @@ cc_library( ], ) +alias( + name = "cuda_runtime", + actual = ":cudart_static", +) + alias( name = "cuda_driver", actual = select({ diff --git a/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl b/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl index 11b32cdbb71c56..ecc99f06455614 100644 --- a/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl +++ b/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl @@ -219,6 +219,10 @@ def _create_libcuda_symlinks( repository_ctx.symlink(nvidia_driver_path, "lib/libcuda.so.1") repository_ctx.symlink("lib/libcuda.so.1", "lib/libcuda.so") +def _create_cuda_header_symlinks(repository_ctx): + if repository_ctx.name == "cuda_nvcc": + repository_ctx.symlink("../cuda_cudart/include/cuda.h", "include/cuda.h") + def use_local_path(repository_ctx, local_path, dirs): # buildifier: disable=function-docstring-args """Creates repository using local redistribution paths.""" @@ -339,6 +343,7 @@ def _use_downloaded_cuda_redistribution(repository_ctx): repository_ctx, lib_name_to_version_dict, ) + _create_cuda_header_symlinks(repository_ctx) repository_ctx.file("version.txt", major_version) def _cuda_repo_impl(repository_ctx): diff --git a/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl b/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl index 6934b75b47852d..89516f869ad07b 100644 --- a/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl +++ b/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl @@ -58,6 +58,10 @@ CUDA_REDIST_JSON_DICT = { "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.6.0.json", "87740b01676b3d18982982ab96ec7fa1a626d03a96df070a6b0f258d01ff5fab", ], + "12.6.1": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.6.1.json", + "22ddfeb81a6f9cee4a708a2e3b4db1c36c7db0a1daa1f33f9c7f2f12a1e790de", + ], } CUDNN_REDIST_JSON_DICT = { @@ -97,20 +101,22 @@ CUDNN_REDIST_JSON_DICT = { "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.4.0.json", "6eeaafc5cc3d4bb2f283e6298e4c55d4c59d7c83c5d9fd8721a2c0e55aee4e54", ], + "9.5.0": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.5.0.json", + "3939f0533fdd0d3aa7edd1ac358d43da18e438e5d8f39c3c15bb72519bad7fb5", + ], } -# The versions are different for x86 and aarch64 architectures because only -# NCCL release versions 2.20.3 and 2.20.5 have the wheels for aarch64. CUDA_12_NCCL_WHEEL_DICT = { "x86_64-unknown-linux-gnu": { - "version": "2.21.5", - "url": "https://files.pythonhosted.org/packages/df/99/12cd266d6233f47d00daf3a72739872bdc10267d0383508b0b9c84a18bb6/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", - "sha256": "8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0", + "version": "2.23.4", + "url": "https://files.pythonhosted.org/packages/ed/1f/6482380ec8dcec4894e7503490fc536d846b0d59694acad9cf99f27d0e7d/nvidia_nccl_cu12-2.23.4-py3-none-manylinux2014_x86_64.whl", + "sha256": "b097258d9aab2fa9f686e33c6fe40ae57b27df60cedbd15d139701bb5509e0c1", }, "aarch64-unknown-linux-gnu": { - "version": "2.20.5", - "url": "https://files.pythonhosted.org/packages/c1/bb/d09dda47c881f9ff504afd6f9ca4f502ded6d8fc2f572cacc5e39da91c28/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", - "sha256": "1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01", + "version": "2.23.4", + "url": "https://files.pythonhosted.org/packages/c8/3a/0112397396dec37ffc8edd7836d48261b4d14ca60ec8ed7bc857cce1d916/nvidia_nccl_cu12-2.23.4-py3-none-manylinux2014_aarch64.whl", + "sha256": "aa946c8327e22ced28e7cef508a334673abc42064ec85f02d005ba1785ea4cec", }, } @@ -134,12 +140,14 @@ CUDA_NCCL_WHEELS = { "12.5.0": CUDA_12_NCCL_WHEEL_DICT, "12.5.1": CUDA_12_NCCL_WHEEL_DICT, "12.6.0": CUDA_12_NCCL_WHEEL_DICT, + "12.6.1": CUDA_12_NCCL_WHEEL_DICT, } REDIST_VERSIONS_TO_BUILD_TEMPLATES = { "nvidia_driver": { "repo_name": "cuda_driver", "version_to_template": { + "560": "//third_party/gpus/cuda/hermetic:cuda_driver.BUILD.tpl", "555": "//third_party/gpus/cuda/hermetic:cuda_driver.BUILD.tpl", "550": "//third_party/gpus/cuda/hermetic:cuda_driver.BUILD.tpl", "545": "//third_party/gpus/cuda/hermetic:cuda_driver.BUILD.tpl", diff --git a/third_party/gpus/find_cuda_config.py b/third_party/gpus/find_cuda_config.py index 68623bf671da71..c04dace79fe599 100644 --- a/third_party/gpus/find_cuda_config.py +++ b/third_party/gpus/find_cuda_config.py @@ -56,21 +56,15 @@ tf__library_dir: ... """ +import glob import io import os -import glob import platform import re +import shutil import subprocess import sys -# pylint: disable=g-import-not-at-top -try: - from shutil import which -except ImportError: - from distutils.spawn import find_executable as which -# pylint: enable=g-import-not-at-top - class ConfigError(Exception): pass @@ -139,7 +133,7 @@ def _get_ld_config_paths(): """Returns all directories from 'ldconfig -p'.""" if not _is_linux(): return [] - ldconfig_path = which("ldconfig") or "/sbin/ldconfig" + ldconfig_path = shutil.which("ldconfig") or "/sbin/ldconfig" output = subprocess.check_output([ldconfig_path, "-p"]) pattern = re.compile(".* => (.*)") result = set() diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index 4f54c3ef9b7626..4a834a62eae68e 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -3,7 +3,11 @@ `rocm_configure` depends on the following environment variables: * `TF_NEED_ROCM`: Whether to enable building with ROCm. - * `GCC_HOST_COMPILER_PATH`: The GCC host compiler path + * `GCC_HOST_COMPILER_PATH`: The GCC host compiler path. + * `TF_ROCM_CLANG`: Whether to use clang for C++ and HIPCC for ROCm compilation. + * `TF_SYSROOT`: The sysroot to use when compiling. + * `CLANG_COMPILER_PATH`: The clang compiler path that will be used for + host code compilation if TF_ROCM_CLANG is 1. * `ROCM_PATH`: The path to the ROCm toolkit. Default is `/opt/rocm`. * `TF_ROCM_AMDGPU_TARGETS`: The AMDGPU targets. """ @@ -40,6 +44,8 @@ load( _GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH" _GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX" +_CLANG_COMPILER_PATH = "CLANG_COMPILER_PATH" +_TF_SYSROOT = "TF_SYSROOT" _ROCM_TOOLKIT_PATH = "ROCM_PATH" _TF_ROCM_AMDGPU_TARGETS = "TF_ROCM_AMDGPU_TARGETS" _TF_ROCM_CONFIG_REPO = "TF_ROCM_CONFIG_REPO" @@ -76,9 +82,10 @@ def verify_build_defines(params): ".", ) -def find_cc(repository_ctx): +def find_cc(repository_ctx, use_rocm_clang): """Find the C++ compiler.""" +<<<<<<< HEAD if _is_clang_enabled(repository_ctx): target_cc_name = "clang" cc_path_envvar = "CLANG_COMPILER_PATH" @@ -87,6 +94,14 @@ def find_cc(repository_ctx): target_cc_name = "gcc" cc_path_envvar = _GCC_HOST_COMPILER_PATH +======= + if use_rocm_clang: + target_cc_name = "clang" + cc_path_envvar = _CLANG_COMPILER_PATH + else: + target_cc_name = "gcc" + cc_path_envvar = _GCC_HOST_COMPILER_PATH +>>>>>>> master cc_name = target_cc_name cc_name_from_env = get_host_environ(repository_ctx, cc_path_envvar) @@ -108,24 +123,26 @@ def _cxx_inc_convert(path): path = path.strip() return path -def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp): +def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp, tf_sysroot): """Compute the list of default C or C++ include directories.""" if lang_is_cpp: lang = "c++" else: lang = "c" + sysroot = [] + if tf_sysroot: + sysroot += ["--sysroot", tf_sysroot] # TODO: We pass -no-canonical-prefixes here to match the compiler flags, # but in rocm_clang CROSSTOOL file that is a `feature` and we should # handle the case when it's disabled and no flag is passed result = raw_exec(repository_ctx, [ cc, - "-no-canonical-prefixes", "-E", "-x" + lang, "-", "-v", - ]) + ] + sysroot) stderr = err_out(result) index1 = stderr.find(_INC_DIR_MARKER_BEGIN) if index1 == -1: @@ -147,14 +164,24 @@ def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp): for p in inc_dirs.split("\n") ] -def get_cxx_inc_directories(repository_ctx, cc): +def get_cxx_inc_directories(repository_ctx, cc, tf_sysroot): """Compute the list of default C and C++ include directories.""" # For some reason `clang -xc` sometimes returns include paths that are # different from the ones from `clang -xc++`. (Symlink and a dir) # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists - includes_cpp = _get_cxx_inc_directories_impl(repository_ctx, cc, True) - includes_c = _get_cxx_inc_directories_impl(repository_ctx, cc, False) + includes_cpp = _get_cxx_inc_directories_impl( + repository_ctx, + cc, + True, + tf_sysroot, + ) + includes_c = _get_cxx_inc_directories_impl( + repository_ctx, + cc, + False, + tf_sysroot, + ) includes_cpp_set = depset(includes_cpp) return includes_cpp + [ @@ -214,6 +241,22 @@ def _rocm_include_path(repository_ctx, rocm_config, bash_bin): inc_dirs.append(rocm_config.llvm_path + "/lib/clang/18/include") inc_dirs.append(rocm_config.llvm_path + "/lib/clang/19/include") rocm_toolkit_path = realpath(repository_ctx, rocm_config.rocm_toolkit_path, bash_bin) +<<<<<<< HEAD +======= + inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/8.0/include") + inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/9.0.0/include") + inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/10.0.0/include") + inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/11.0.0/include") + inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/12.0.0/include") + inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/13.0.0/include") + inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/14.0.0/include") + inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/15.0.0/include") + inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/16.0.0/include") + inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17.0.0/include") + inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17/include") + inc_dirs.append(rocm_toolkit_path + "/lib/llvm/lib/clang/17/include") + inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/18/include") +>>>>>>> master if int(rocm_config.rocm_version_number) >= 60200: inc_dirs.append(rocm_toolkit_path + "/lib/llvm/lib/clang/17/include") inc_dirs.append(rocm_toolkit_path + "/lib/llvm/lib/clang/18/include") @@ -528,7 +571,7 @@ def _create_dummy_repository(repository_ctx): "%{hipblas_lib}": _lib_name("hipblas"), "%{miopen_lib}": _lib_name("miopen"), "%{rccl_lib}": _lib_name("rccl"), - "%{hipfft_or_rocfft}": _lib_name("hipfft"), + "%{hipfft_or_rocfft}": "hipfft", "%{hipfft_or_rocfft_lib}": _lib_name("hipfft"), "%{hiprand_lib}": _lib_name("hiprand"), "%{hipsparse_lib}": _lib_name("hipsparse"), @@ -591,8 +634,18 @@ def _genrule(src_dir, genrule_name, command, outs): ")\n" ) +def _flag_enabled(repository_ctx, flag_name): + return get_host_environ(repository_ctx, flag_name) == "1" + +def _use_rocm_clang(repository_ctx): + # Returns the flag if we need to use clang for the host. + return _flag_enabled(repository_ctx, "TF_ROCM_CLANG") + +def _tf_sysroot(repository_ctx): + return get_host_environ(repository_ctx, _TF_SYSROOT, "") + def _compute_rocm_extra_copts(repository_ctx, amdgpu_targets): - amdgpu_target_flags = ["--amdgpu-target=" + + amdgpu_target_flags = ["--offload-arch=" + amdgpu_target for amdgpu_target in amdgpu_targets] return str(amdgpu_target_flags) @@ -709,6 +762,10 @@ def _create_local_rocm_repository(repository_ctx): "%{copy_rules}": "\n".join(copy_rules), "%{rocm_headers}": ('":rocm-include",\n' + rocm_components_include), } + + is_rocm_clang = _use_rocm_clang(repository_ctx) + tf_sysroot = _tf_sysroot(repository_ctx) + if rocm_libs["hipblaslt"] != None: repository_dict["%{hipblaslt_lib}"] = rocm_libs["hipblaslt"].file_name @@ -724,24 +781,36 @@ def _create_local_rocm_repository(repository_ctx): # Set up crosstool/ - cc = find_cc(repository_ctx) - - host_compiler_includes = get_cxx_inc_directories(repository_ctx, cc) + cc = find_cc(repository_ctx, is_rocm_clang) + host_compiler_includes = get_cxx_inc_directories( + repository_ctx, + cc, + tf_sysroot, + ) - host_compiler_prefix = get_host_environ(repository_ctx, _GCC_HOST_COMPILER_PREFIX, "/usr/bin") + # host_compiler_includes = get_cxx_inc_directories(repository_ctx, cc) rocm_defines = {} - + rocm_defines["%{builtin_sysroot}"] = tf_sysroot + rocm_defines["%{compiler}"] = "unknown" + if is_rocm_clang: + rocm_defines["%{compiler}"] = "clang" + host_compiler_prefix = get_host_environ(repository_ctx, _GCC_HOST_COMPILER_PREFIX, "/usr/bin") rocm_defines["%{host_compiler_prefix}"] = host_compiler_prefix + rocm_defines["%{linker_bin_path}"] = rocm_config.rocm_toolkit_path + host_compiler_prefix + rocm_defines["%{extra_no_canonical_prefixes_flags}"] = "" + rocm_defines["%{unfiltered_compile_flags}"] = "" + rocm_defines["%{rocm_hipcc_files}"] = "[]" - rocm_defines["%{linker_bin_path}"] = rocm_config.rocm_toolkit_path + "/hcc/compiler/bin" - - # For gcc, do not canonicalize system header paths; some versions of gcc - # pick the shortest possible path for system includes when creating the - # .d file - given that includes that are prefixed with "../" multiple - # time quickly grow longer than the root of the tree, this can lead to - # bazel's header check failing. - rocm_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-fno-canonical-system-headers\"" + if is_rocm_clang: + rocm_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-no-canonical-prefixes\"" + else: + # For gcc, do not canonicalize system header paths; some versions of gcc + # pick the shortest possible path for system includes when creating the + # .d file - given that includes that are prefixed with "../" multiple + # time quickly grow longer than the root of the tree, this can lead to + # bazel's header check failing. + rocm_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-fno-canonical-system-headers\"" rocm_defines["%{unfiltered_compile_flags}"] = to_list_of_strings([ "-DTENSORFLOW_USE_ROCM=1", @@ -880,6 +949,10 @@ _ENVIRONS = [ _GCC_HOST_COMPILER_PREFIX, "TF_NEED_ROCM", "TF_ROCM_CLANG", +<<<<<<< HEAD +======= + "TF_NEED_CUDA", # Needed by the `if_gpu_is_configured` macro +>>>>>>> master _ROCM_TOOLKIT_PATH, _TF_ROCM_AMDGPU_TARGETS, "CLANG_COMPILER_PATH", diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 509398da979e83..79626b687eb4e4 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1 +1,151 @@ Auto generated patch. Do not edit or delete it, even if empty. +diff -ruN --strip-trailing-cr a/llvm/lib/Linker/IRMover.cpp b/llvm/lib/Linker/IRMover.cpp +--- a/llvm/lib/Linker/IRMover.cpp ++++ b/llvm/lib/Linker/IRMover.cpp +@@ -595,11 +595,15 @@ + if (!SGV) + return nullptr; + ++ // If SGV is from dest, it was already materialized when dest was loaded. ++ if (SGV->getParent() == &DstM) ++ return nullptr; ++ + // When linking a global from other modules than source & dest, skip + // materializing it because it would be mapped later when its containing + // module is linked. Linking it now would potentially pull in many types that + // may not be mapped properly. +- if (SGV->getParent() != &DstM && SGV->getParent() != SrcM.get()) ++ if (SGV->getParent() != SrcM.get()) + return nullptr; + + Expected NewProto = linkGlobalValueProto(SGV, ForIndirectSymbol); +diff -ruN --strip-trailing-cr a/llvm/test/ThinLTO/X86/Inputs/ditemplatevalueparameter-remap.ll b/llvm/test/ThinLTO/X86/Inputs/ditemplatevalueparameter-remap.ll +--- a/llvm/test/ThinLTO/X86/Inputs/ditemplatevalueparameter-remap.ll ++++ b/llvm/test/ThinLTO/X86/Inputs/ditemplatevalueparameter-remap.ll +@@ -0,0 +1,29 @@ ++target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128" ++target triple = "x86_64-unknown-linux-gnu" ++ ++define void @_Z8thinlto1v() unnamed_addr { ++ %3 = alloca i64, align 4 ++ #dbg_declare(ptr %3, !14, !DIExpression(), !15) ++ ret void ++} ++ ++!llvm.dbg.cu = !{!0} ++!llvm.module.flags = !{!2, !3, !4, !5} ++ ++!0 = distinct !DICompileUnit(language: DW_LANG_C_plus_plus_14, file: !1, producer: "clang", isOptimized: true, runtimeVersion: 0, emissionKind: FullDebug, splitDebugInlining: false, nameTableKind: None) ++!1 = !DIFile(filename: "B.cpp", directory: ".") ++!2 = !{i32 7, !"Dwarf Version", i32 4} ++!3 = !{i32 2, !"Debug Info Version", i32 3} ++!4 = !{i32 1, !"wchar_size", i32 4} ++!5 = !{i32 8, !"PIC Level", i32 2} ++!10 = distinct !DISubprogram(name: "thinlto1", linkageName: "_Z8thinlto1v", scope: !11, file: !11, line: 8, type: !12, scopeLine: 8, flags: DIFlagPrototyped | DIFlagAllCallsDescribed, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0) ++!11 = !DIFile(filename: "b.cpp", directory: ".") ++!12 = !DISubroutineType(types: !13) ++!13 = !{null} ++!14 = !DILocalVariable(name: "a", arg: 1, scope: !10, file: !11, line: 18, type: !16) ++!15 = !DILocation(line: 18, column: 19, scope: !10) ++!16 = distinct !DICompositeType(tag: DW_TAG_structure_type, name: "S<&func1>", file: !11, line: 2, size: 8, flags: DIFlagTypePassByValue, elements: !17, templateParams: !18, identifier: "_ZTS1SIXadL_Z5func1vEEE") ++!17 = !{} ++!18 = !{!19} ++!19 = !DITemplateValueParameter(name: "Func", type: !20, value: ptr undef) ++!20 = !DIDerivedType(tag: DW_TAG_pointer_type, baseType: !12, size: 64) +diff -ruN --strip-trailing-cr a/llvm/test/ThinLTO/X86/ditemplatevalueparameter-remap.ll b/llvm/test/ThinLTO/X86/ditemplatevalueparameter-remap.ll +--- a/llvm/test/ThinLTO/X86/ditemplatevalueparameter-remap.ll ++++ b/llvm/test/ThinLTO/X86/ditemplatevalueparameter-remap.ll +@@ -0,0 +1,93 @@ ++; https://github.com/llvm/llvm-project/pull/110064 ++; This test case checks if thinLTO correctly links metadata values in a specific ++; situation. Assume we are linking module B into module A, where an extern ++; function used in A is defined in B, but the function body has a ++; DITemplateValueParameter referring to another function back in A. The ++; compiler must check this other function is actually coming from A, thus ++; already materialized and does not require remapping. The IR here is modified ++; from the following source code. ++; ++; // A.h ++; template ++; struct S { ++; void Impl() { ++; Func(); ++; } ++; }; ++; ++; void func1(); ++; ++; // A.cpp ++; #include "A.h" ++; __attribute__((weak)) void func1() {} ++; extern void thinlto1(); ++; void bar() { ++; S s; // Force instantiation of S in this compilation unit. ++; s.Impl(); ++; thinlto1(); ++; } ++; ++; // B.cpp ++; #include "A.h" ++; void thinlto1() { ++; S s; ++; } ++; ++; RUN: opt -module-summary -o %t1.bc %s ++; RUN: opt -module-summary -o %t2.bc %S/Inputs/ditemplatevalueparameter-remap.ll ++; RUN: llvm-lto2 run %t1.bc %t2.bc -o %t3 -save-temps \ ++; RUN: -r=%t1.bc,_Z5func1v,p \ ++; RUN: -r=%t1.bc,_Z3bazv,px \ ++; RUN: -r=%t1.bc,_Z8thinlto1v,x \ ++; RUN: -r=%t1.bc,_Z3barv,px \ ++; RUN: -r=%t2.bc,_Z8thinlto1v,px ++; RUN: llvm-dis %t3.1.4.opt.bc -o - | FileCheck %s ++ ++target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128" ++target triple = "x86_64-unknown-linux-gnu" ++ ++$_Z5func1v = comdat any ++ ++define linkonce_odr void @_Z5func1v() unnamed_addr !dbg !10 { ++ ret void ++} ++ ++; Dummy function to use _Z5func1v so that it is not treated as dead symbol. ++define void @_Z3bazv() { ++ tail call void @_Z5func1v() ++ ret void ++} ++ ++declare void @_Z8thinlto1v() unnamed_addr ++ ++; Check _Z8thinlto1v is inlined after thinLTO. ++; CHECK: void @_Z3barv() ++; CHECK-NOT: @_Z8thinlto1v() ++; CHECK-NEXT: ret void ++define void @_Z3barv() unnamed_addr !dbg !14 { ++ tail call void @_Z8thinlto1v(), !dbg !25 ++ ret void ++} ++ ++!llvm.dbg.cu = !{!0} ++!llvm.module.flags = !{!2, !3, !4, !5} ++ ++!0 = distinct !DICompileUnit(language: DW_LANG_C_plus_plus_14, file: !1, producer: "clang", isOptimized: true, runtimeVersion: 0, emissionKind: FullDebug, splitDebugInlining: false, nameTableKind: None) ++!1 = !DIFile(filename: "A.cpp", directory: ".") ++!2 = !{i32 7, !"Dwarf Version", i32 4} ++!3 = !{i32 2, !"Debug Info Version", i32 3} ++!4 = !{i32 1, !"wchar_size", i32 4} ++!5 = !{i32 8, !"PIC Level", i32 2} ++!10 = distinct !DISubprogram(name: "func1", linkageName: "_Z5func1v", scope: !11, file: !11, line: 6, type: !12, scopeLine: 6, flags: DIFlagPrototyped | DIFlagAllCallsDescribed, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0) ++!11 = !DIFile(filename: "a.h", directory: ".") ++!12 = !DISubroutineType(types: !13) ++!13 = !{null} ++!14 = distinct !DISubprogram(name: "bar", linkageName: "_Z3barv", scope: !11, file: !11, line: 15, type: !12, scopeLine: 15, flags: DIFlagPrototyped | DIFlagAllCallsDescribed, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0, retainedNodes: !16) ++!16 = !{!17} ++!17 = !DILocalVariable(name: "s", scope: !14, file: !11, line: 10, type: !18) ++!18 = distinct !DICompositeType(tag: DW_TAG_structure_type, name: "S<&func1>", file: !11, line: 2, size: 8, flags: DIFlagTypePassByValue, elements: !19, templateParams: !20, identifier: "_ZTS1SIXadL_Z5func1vEEE") ++!19 = !{} ++!20 = !{!21} ++!21 = !DITemplateValueParameter(name: "Func", type: !22, value: ptr @_Z5func1v) ++!22 = !DIDerivedType(tag: DW_TAG_pointer_type, baseType: !12, size: 64) ++!25 = !DILocation(line: 16, column: 5, scope: !14) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index c011aabc014eda..2d60ea057d9d8f 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "94c024adedcb53059c29d7c2d62982053b60e86a" - LLVM_SHA256 = "204cedeaab86f065ef64cb3889dd2e92ddd4a8f5d5b6bc1cb4b276694fb6a798" + LLVM_COMMIT = "33363521ca24f912cc25530f6cecbca53acce8a3" + LLVM_SHA256 = "3fd9cbd992ed880e348d81715f39138538fd6c8e9164b981551a97181a3b7b24" tf_http_archive( name = name, diff --git a/third_party/nanobind/nanobind.BUILD b/third_party/nanobind/nanobind.BUILD index 72b47585b5e5d0..814fe3595df65d 100644 --- a/third_party/nanobind/nanobind.BUILD +++ b/third_party/nanobind/nanobind.BUILD @@ -1,7 +1,21 @@ +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") + licenses(["notice"]) package(default_visibility = ["//visibility:public"]) +bool_flag( + name = "enabled_free_threading", + build_setting_default = False, +) + +config_setting( + name = "use_enabled_free_threading", + flag_values = { + ":enabled_free_threading": "True", + }, +) + cc_library( name = "nanobind", srcs = glob( @@ -11,10 +25,17 @@ cc_library( exclude = ["src/nb_combined.cpp"], ), copts = ["-fexceptions"], - defines = [ - "NB_BUILD=1", - "NB_SHARED=1", - ], + defines = select({ + ":use_enabled_free_threading": [ + "NB_FREE_THREADED=1", + "NB_BUILD=1", + "NB_SHARED=1", + ], + "//conditions:default": [ + "NB_BUILD=1", + "NB_SHARED=1", + ], + }), includes = ["include"], textual_hdrs = glob( [ diff --git a/third_party/nanobind/workspace.bzl b/third_party/nanobind/workspace.bzl index 1c692d396e9b98..aa39484e078f3b 100644 --- a/third_party/nanobind/workspace.bzl +++ b/third_party/nanobind/workspace.bzl @@ -5,8 +5,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): tf_http_archive( name = "nanobind", - strip_prefix = "nanobind-2.1.0", - sha256 = "c37c53c60ada5fe1c956e24bd4b83af669a2309bf952bd251f36a7d2fa3bacf0", - urls = tf_mirror_urls("https://github.com/wjakob/nanobind/archive/refs/tags/v2.1.0.tar.gz"), + strip_prefix = "nanobind-2.2.0", + sha256 = "bfbfc7e5759f1669e4ddb48752b1ddc5647d1430e94614d6f8626df1d508e65a", + urls = tf_mirror_urls("https://github.com/wjakob/nanobind/archive/refs/tags/v2.2.0.tar.gz"), build_file = "//third_party/nanobind:nanobind.BUILD", ) diff --git a/third_party/nccl/archive.BUILD b/third_party/nccl/archive.BUILD index 1f4b58f47e379c..bfbde6cf22eeff 100644 --- a/third_party/nccl/archive.BUILD +++ b/third_party/nccl/archive.BUILD @@ -22,9 +22,9 @@ exports_files(["LICENSE.txt"]) NCCL_MAJOR = 2 -NCCL_MINOR = 21 +NCCL_MINOR = 23 -NCCL_PATCH = 5 +NCCL_PATCH = 4 NCCL_VERSION = NCCL_MAJOR * 10000 + NCCL_MINOR * 100 + NCCL_PATCH # e.g., 21605 diff --git a/third_party/nccl/archive.patch b/third_party/nccl/archive.patch index 2b4fa56a97e759..4fc2dbb7aded8a 100644 --- a/third_party/nccl/archive.patch +++ b/third_party/nccl/archive.patch @@ -1,35 +1,16 @@ -diff --git a/src/device/all_gather.h b/src/device/all_gather.h -index 809e8ae..57eab81 100644 ---- a/src/device/all_gather.h -+++ b/src/device/all_gather.h -@@ -296,7 +296,7 @@ struct RunWorkElement(scat); -+ prims.template process(scat); - } - } - return; -@@ -314,7 +314,7 @@ struct RunWorkElement(scat); -+ prims.template process(scat); - } - return; - } diff --git a/src/device/common.cu b/src/device/common.cu.cc similarity index 100% rename from src/device/common.cu rename to src/device/common.cu.cc +diff --git a/src/device/onerank.cu b/src/device/onerank.cu.cc +similarity index 100% +rename from src/device/onerank.cu +rename to src/device/onerank.cu.cc diff --git a/src/device/common.h b/src/device/common.h -index d8581d3..09ac3b6 100644 --- a/src/device/common.h +++ b/src/device/common.h -@@ -15,7 +15,7 @@ - #define COLL_UNROLL (ncclCollUnroll()) +@@ -24,7 +24,7 @@ + #endif typedef void(*ncclDevFuncPtr_t)(); -extern __device__ ncclDevFuncPtr_t const ncclDevFuncTable[]; @@ -38,14 +19,16 @@ index d8581d3..09ac3b6 100644 struct ncclShmemGroup { ncclConnInfo *recvConns[NCCL_MAX_ARITY]; diff --git a/src/device/generate.py b/src/device/generate.py -index 43de85d..87cd677 100755 +index a0d2259..62d6014 100755 --- a/src/device/generate.py +++ b/src/device/generate.py -@@ -195,7 +195,7 @@ kernel_funcs = sorted(set(best_kernel(*fn) for fn in primary_funcs)) +@@ -194,8 +194,8 @@ kernel_funcs = sorted(set(best_kernel(*fn) for fn in primary_funcs)) + ################################################################################ - # Generate /device_table.cu +-# Generate /device_table.cu -with open(os.path.join(gensrc, "device_table.cu"), "w") as f: ++# Generate /device_table.cu.cc +with open(os.path.join(gensrc, "device_table.cu.cc"), "w") as f: out = f.write out('#include "common.h"\n') @@ -59,12 +42,11 @@ index 43de85d..87cd677 100755 index = 0 for fn in primary_funcs: sym = paste("_", "ncclDevFunc", *fn) -@@ -257,28 +257,45 @@ with open(os.path.join(gensrc, "host_table.cc"), "w") as f: +@@ -262,28 +262,43 @@ with open(os.path.join(gensrc, "host_table.cc"), "w") as f: # List of all kernel function pointers. out("extern int const ncclDevKernelCount = %d;\n" % len(kernel_funcs)) - out("extern void* const ncclDevKernelList[] = {\n") -+ index = 0 for kfn in kernel_funcs: cudart, _ = required_cuda(*kfn) @@ -88,7 +70,6 @@ index 43de85d..87cd677 100755 # Maps primary id to kernel function pointer. - out("extern void* const ncclDevKernelForFunc[] = {\n") -+ index = 0 for fn in primary_funcs: kfn = best_kernel(*fn) @@ -111,7 +92,7 @@ index 43de85d..87cd677 100755 index += 1 out("nullptr};\n") out("\n") -@@ -297,7 +314,7 @@ with open(os.path.join(gensrc, "host_table.cc"), "w") as f: +@@ -302,7 +317,7 @@ with open(os.path.join(gensrc, "host_table.cc"), "w") as f: # "coll" is reflected in the name: formally that no two funcs having different # coll's map to the same filename. def impl_filename(coll, redop, ty, algo, proto): @@ -120,7 +101,7 @@ index 43de85d..87cd677 100755 # Partition the functions and kernels to the .cu filenames. The partition is # a dictionary mapping filename to (coll, func-tuple list) -@@ -318,7 +335,7 @@ name_to_kernels = partition_by_name(kfn for kfn in kernel_funcs if kfn[0]!="Gene +@@ -323,7 +338,7 @@ name_to_kernels = partition_by_name(kfn for kfn in kernel_funcs if kfn[0]!="Gene with open(os.path.join(gensrc, "rules.mk"), "w") as f: out = f.write impl_names = sorted(name_to_funcs.keys()) @@ -129,29 +110,3 @@ index 43de85d..87cd677 100755 out("LIB_OBJS_GEN = $(patsubst %, $(OBJDIR)/genobj/%.o, {names})\n" .format(names=" ".join(names))) out("\n") -diff --git a/src/device/onerank.cu b/src/device/onerank.cu.cc -similarity index 100% -rename from src/device/onerank.cu -rename to src/device/onerank.cu.cc -diff --git a/src/device/reduce_scatter.h b/src/device/reduce_scatter.h -index d0b5249..2dacd60 100644 ---- a/src/device/reduce_scatter.h -+++ b/src/device/reduce_scatter.h -@@ -254,7 +254,7 @@ struct RunWorkElement(scat); -+ prims.template process(scat); - } - return; - } -@@ -278,7 +278,7 @@ struct RunWorkElement(scat); -+ prims.template process(scat); - } - } - return; diff --git a/third_party/nccl/hermetic/nccl_configure.bzl b/third_party/nccl/hermetic/nccl_configure.bzl index 14469acdfc5aa1..c1e49a6b9f1dd2 100644 --- a/third_party/nccl/hermetic/nccl_configure.bzl +++ b/third_party/nccl/hermetic/nccl_configure.bzl @@ -60,6 +60,15 @@ alias( visibility = ["//visibility:public"], ) +alias( + name = "nccl_headers", + actual = select({ + "@local_config_cuda//cuda:cuda_tools_and_libs": "@cuda_nccl//:headers", + "//conditions:default": "@nccl_archive//:nccl_headers", + }), + visibility = ["//visibility:public"], +) + cc_library( name = "hermetic_nccl_config", hdrs = ["nccl_config.h"], diff --git a/third_party/protobuf/protobuf.patch b/third_party/protobuf/protobuf.patch index 9d928ba175f330..ac33ccbf8c3aea 100644 --- a/third_party/protobuf/protobuf.patch +++ b/third_party/protobuf/protobuf.patch @@ -1,22 +1,46 @@ diff --git a/BUILD.bazel b/BUILD.bazel --- a/BUILD.bazel (revision 90b73ac3f0b10320315c2ca0d03a5a9b095d2f66) -+++ b/BUILD.bazel (date 1670471682469) -@@ -68,6 +68,7 @@ ++++ b/BUILD.bazel (date 1714620794503) +@@ -68,6 +68,8 @@ copts = COPTS, includes = ["src/"], linkopts = LINK_OPTS, ++ local_defines = ["PROTOBUF_USE_DLLS", "LIBPROTOBUF_EXPORTS"], + alwayslink = 1, visibility = ["//visibility:public"], ) -@@ -135,6 +136,7 @@ +@@ -135,6 +137,8 @@ copts = COPTS, includes = ["src/"], linkopts = LINK_OPTS, ++ local_defines = ["PROTOBUF_USE_DLLS", "LIBPROTOBUF_EXPORTS"], + alwayslink = 1, visibility = ["//visibility:public"], deps = [":protobuf_lite"] + select({ "//build_defs:config_msvc": [], +@@ -1074,7 +1078,8 @@ + "@com_google_protobuf//:type_proto", + "@com_google_protobuf//:wrappers_proto", + ], +- command_line = "--cpp_out=$(OUT)", ++ command_line = "--cpp_out=dllexport_decl=PROTOBUF_EXPORT:$(OUT)", ++# command_line = "--cpp_out=$(OUT)", + runtime = ":protobuf", + visibility = ["//visibility:public"], + ) +diff --git a/protobuf.bzl b/protobuf.bzl +--- a/protobuf.bzl (revision 90b73ac3f0b10320315c2ca0d03a5a9b095d2f66) ++++ b/protobuf.bzl (date 1714611573270) +@@ -127,7 +127,7 @@ + use_grpc_plugin = (ctx.attr.plugin_language == "grpc" and ctx.attr.plugin) + path_tpl = "$(realpath %s)" if in_gen_dir else "%s" + if ctx.attr.gen_cc: +- args += [("--cpp_out=" + path_tpl) % gen_dir] ++ args += [("--cpp_out=dllexport_decl=PROTOBUF_EXPORT:" + path_tpl) % gen_dir] + outs.extend(_CcOuts([src.basename], use_grpc_plugin = use_grpc_plugin)) + if ctx.attr.gen_py: + args += [("--python_out=" + path_tpl) % gen_dir] diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc index 162531226..e93ec4809 100644 --- a/python/google/protobuf/pyext/descriptor.cc diff --git a/third_party/py/BUILD b/third_party/py/BUILD index 84eba77ce1a7af..0381d65bb27514 100644 --- a/third_party/py/BUILD +++ b/third_party/py/BUILD @@ -1,3 +1,4 @@ +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") load("@python//:defs.bzl", "compile_pip_requirements") load("@python_version_repo//:py_version.bzl", "REQUIREMENTS") @@ -38,3 +39,16 @@ compile_pip_requirements( requirements_in = "requirements.in", requirements_txt = REQUIREMENTS, ) + +# Flag indicating if the target requires pre-built wheel. +bool_flag( + name = "wheel_dependency", + build_setting_default = False, +) + +config_setting( + name = "enable_wheel_dependency", + flag_values = { + ":wheel_dependency": "True", + }, +) diff --git a/third_party/py/python_repo.bzl b/third_party/py/python_repo.bzl index 13aed2b687129f..6fe63fb9c1e674 100644 --- a/third_party/py/python_repo.bzl +++ b/third_party/py/python_repo.bzl @@ -14,6 +14,7 @@ def _python_repository_impl(ctx): ctx.file("BUILD", "") wheel_name = ctx.os.environ.get("WHEEL_NAME", "tensorflow") wheel_collab = ctx.os.environ.get("WHEEL_COLLAB", False) + macos_deployment_target = ctx.os.environ.get("MACOSX_DEPLOYMENT_TARGET", "") requirements = None for i in range(0, len(ctx.attr.requirements_locks)): @@ -34,13 +35,11 @@ Please check python_init_repositories() in your WORKSPACE file. requirements_with_local_wheels = str(requirements) - local_wheels_dir = ctx.os.environ.get("LOCAL_WHEELS_DIR", "") - if ctx.attr.local_wheel_workspaces or local_wheels_dir: + if ctx.attr.local_wheel_workspaces: local_wheel_requirements = _get_injected_local_wheels( ctx, version, ctx.attr.local_wheel_workspaces, - local_wheels_dir, ) requirements_content = [ctx.read(requirements)] + local_wheel_requirements merged_requirements_content = "\n".join(requirements_content) @@ -55,6 +54,13 @@ Please check python_init_repositories() in your WORKSPACE file. merged_requirements_content, ) + use_pywrap_rules = bool( + ctx.os.environ.get("USE_PYWRAP_RULES", False), + ) + + if use_pywrap_rules: + print("!!!Using pywrap rules instead of directly creating .so objects!!!") # buildifier: disable=print + ctx.file( "py_version.bzl", """ @@ -64,12 +70,16 @@ WHEEL_NAME = "{wheel_name}" WHEEL_COLLAB = "{wheel_collab}" REQUIREMENTS = "{requirements}" REQUIREMENTS_WITH_LOCAL_WHEELS = "{requirements_with_local_wheels}" +USE_PYWRAP_RULES = {use_pywrap_rules} +MACOSX_DEPLOYMENT_TARGET = "{macos_deployment_target}" """.format( version = version, wheel_name = wheel_name, wheel_collab = wheel_collab, requirements = str(requirements), requirements_with_local_wheels = requirements_with_local_wheels, + use_pywrap_rules = use_pywrap_rules, + macos_deployment_target = macos_deployment_target, ), ) @@ -118,8 +128,7 @@ def _parse_python_version(version_str): def _get_injected_local_wheels( ctx, py_version, - local_wheel_workspaces, - local_wheels_dir): + local_wheel_workspaces): local_wheel_requirements = [] py_ver_marker = "-cp%s-" % py_version.replace(".", "") py_major_ver_marker = "-py%s-" % py_version.split(".")[0] @@ -140,18 +149,6 @@ def _get_injected_local_wheels( ctx.attr.local_wheel_inclusion_list, ctx.attr.local_wheel_exclusion_list, ) - if local_wheels_dir: - dist_folder_path = ctx.path(local_wheels_dir) - if dist_folder_path.exists: - dist_wheels = dist_folder_path.readdir() - _process_dist_wheels( - dist_wheels, - wheels, - py_ver_marker, - py_major_ver_marker, - ctx.attr.local_wheel_inclusion_list, - ctx.attr.local_wheel_exclusion_list, - ) for wheel_name, wheel_path in wheels.items(): local_wheel_requirements.append( @@ -200,6 +197,7 @@ python_repository = repository_rule( "HERMETIC_PYTHON_VERSION", "WHEEL_NAME", "WHEEL_COLLAB", + "USE_PYWRAP_RULES", ], local = True, ) diff --git a/third_party/py/rules_python.patch b/third_party/py/rules_python.patch index 7d59ac107cc952..3dbe06dd2d6d96 100644 --- a/third_party/py/rules_python.patch +++ b/third_party/py/rules_python.patch @@ -1,32 +1,28 @@ -Subject: [PATCH] Add Python 3.13.0rc2 support to rules_python ---- -Index: python/versions.bzl -<+>UTF-8 -=================================================================== diff --git a/python/versions.bzl b/python/versions.bzl ---- a/python/versions.bzl (revision 084b877c98b580839ceab2b071b02fc6768f3de6) -+++ b/python/versions.bzl (date 1726256410148) -@@ -484,6 +484,19 @@ +index fd385cd1..eb4133f1 100644 +--- a/python/versions.bzl ++++ b/python/versions.bzl +@@ -484,6 +484,19 @@ TOOL_VERSIONS = { }, "strip_prefix": "python", }, + "3.13.0": { -+ "url": "20240909/cpython-{python_version}rc2+20240909-{platform}-{build}.tar.gz", ++ "url": "20241008/cpython-{python_version}+20241008-{platform}-{build}.tar.gz", + "sha256": { -+ "aarch64-apple-darwin": "5d38ca1e6b030b004714e10813903e906c6b8f2a6361770df4512a838f4a4a9f", -+ "aarch64-unknown-linux-gnu": "85e103fc81a1fcf94a93180f6df42e39a7dc15d4b711705e133dc2ec847552e7", -+ "ppc64le-unknown-linux-gnu": "3be3d8aefae579c420fc6abf01658ae89fda8120154f989575b08085d2f8d6dc", -+ "s390x-unknown-linux-gnu": "6ec5130d62473368ecc7e55338bf1cc58607dbfe8088959cab51265b9f13c38d", -+ "x86_64-apple-darwin": "c3dcd4314324159945dc19342c73b9deb8de0f2d1709171427dd52f1a05eecca", -+ "x86_64-pc-windows-msvc": "31282f912e984d399c56925dfb69a4f3ce76226dfb4806b09f37e3b4a15e5a30", -+ "x86_64-unknown-linux-gnu": "028581cce5004c66775a3ae8b3ed65681ab4b289608dfd1aec3354d169216099", ++ "aarch64-apple-darwin": "5d3cb8d7ca4cfbbe7ae1f118f26be112ee417d982fab8c6d85cfd8ccccf70718", ++ "aarch64-unknown-linux-gnu": "c1142af8f2c85923d2ba8201a35b913bb903a5d15f052c38bbecf2f49e2342dc", ++ "ppc64le-unknown-linux-gnu": "1be64a330499fed4e1f864b97eef5445b0e4abc0559ae45df3108981800cf998", ++ "s390x-unknown-linux-gnu": "c0b1cc51426feadaa932fdd9afd9a9af789916e128e48ac8909f9a269bbbd749", ++ "x86_64-apple-darwin": "b58ca12d9ae14bbd79f9e5cf4b748211ff1953e59abeac63b0f4e8e49845669f", ++ "x86_64-pc-windows-msvc": "c7651a7a575104f47c808902b020168057f3ad80f277e54cecfaf79a9ff50e22", ++ "x86_64-unknown-linux-gnu": "455200e1a202e9d9ef4b630c04af701c0a91dcaa6462022efc76893fc762ec95", + }, + "strip_prefix": "python", + }, } # buildifier: disable=unsorted-dict-items -@@ -493,6 +506,7 @@ +@@ -493,6 +506,7 @@ MINOR_MAPPING = { "3.10": "3.10.14", "3.11": "3.11.9", "3.12": "3.12.3", diff --git a/third_party/pybind11_protobuf/protobuf.patch b/third_party/pybind11_protobuf/protobuf.patch new file mode 100644 index 00000000000000..c568f5cd6f8bd8 --- /dev/null +++ b/third_party/pybind11_protobuf/protobuf.patch @@ -0,0 +1,20 @@ +diff --git a/pybind11_protobuf/BUILD b/pybind11_protobuf/BUILD +--- a/pybind11_protobuf/BUILD (revision 80f3440cd8fee124e077e2e47a8a17b78b451363) ++++ b/pybind11_protobuf/BUILD (date 1714533560692) +@@ -53,8 +53,8 @@ + "proto_caster_impl.h", + ], + local_defines = select({ +- ":enable_pyproto_api_setting": ["PYBIND11_PROTOBUF_ENABLE_PYPROTO_API"], +- "//conditions:default": [], ++ ":enable_pyproto_api_setting": ["PROTOBUF_USE_DLLS", "PYBIND11_PROTOBUF_ENABLE_PYPROTO_API"], ++ "//conditions:default": ["PROTOBUF_USE_DLLS"], + }), + deps = [ + ":check_unknown_fields", +@@ -95,4 +95,5 @@ + "@com_google_absl//absl/synchronization", + "@com_google_protobuf//:protobuf", + ], ++ local_defines = ["PROTOBUF_USE_DLLS"], + ) \ No newline at end of file diff --git a/third_party/shardy/temporary.patch b/third_party/shardy/temporary.patch index c5aa30af88f875..e69de29bb2d1d6 100644 --- a/third_party/shardy/temporary.patch +++ b/third_party/shardy/temporary.patch @@ -1,15 +0,0 @@ -diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index cd6a8b6..c011aab 100644 ---- a/third_party/llvm/workspace.bzl -+++ b/third_party/llvm/workspace.bzl -@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") - - def repo(name): - """Imports LLVM.""" -- LLVM_COMMIT = "104f3c180644c8872eaad0b3fcf6a6b948d92a71" -- LLVM_SHA256 = "5caf03c6e40c87e7593ce50bfe53ec52a08677c221f4f611f30b3f40397505b8" -+ LLVM_COMMIT = "94c024adedcb53059c29d7c2d62982053b60e86a" -+ LLVM_SHA256 = "204cedeaab86f065ef64cb3889dd2e92ddd4a8f5d5b6bc1cb4b276694fb6a798" - - tf_http_archive( - name = name, diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index f2425a6d6c98fe..b303e313939ee5 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "a66667eefd65f73d50fab04298f477fc123b6740" - SHARDY_SHA256 = "543407a5fb203959d1189813275402dc5b8af6076203700ddea96a1dd8d981e1" + SHARDY_COMMIT = "ebd224c2199a003b2951fbeaa10daab88041762d" + SHARDY_SHA256 = "2809c6a97b99229a0279b2198bce0218629185f088047995f991c4dcfade8583" tf_http_archive( name = "shardy", diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 3e0b0e66bc8a4f..2eb32ea8c944be 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1,353 +1,83 @@ -diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir ---- stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir -+++ stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir -@@ -41,3 +41,170 @@ - %3 = stablehlo.imag %1 : (tensor<4xcomplex>) -> tensor<4xf64> - func.return %2, %3 : tensor<4xf64>, tensor<4xf64> +diff --ruN a/stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp b/stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp +--- stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp ++++ stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp +@@ -47,36 +47,36 @@ + return shapedType; } -+ -+// ----- -+ -+// CHECK-LABEL: @gather_with_batching_dims -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], -+// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x4xi32>) -> tensor<4x3x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<4x3x5x8xi32> -+func.func @gather_with_batching_dims(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> { -+ // CHECK-NO-DOWNGRADE: operand_batching_dims = [0, 2] -+ // CHECK-NO-DOWNGRADE: start_indices_batching_dims = [1, 0] -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1, 3], -+ operand_batching_dims = [0, 2], -+ start_indices_batching_dims = [1, 0], -+ start_index_map = [1, 3], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = true -+ } : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> -+ func.return %0 : tensor<4x3x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_with_batching_no_index_vector_dim -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2], -+// CHECK-SAME: start_index_map = [0, 2, 1], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<3x2x4x9xi32>, tensor<4x3x5x3xi32>) -> tensor<4x3x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<4x3x5x8xi32> -+func.func @gather_with_batching_no_index_vector_dim(%arg0: tensor<3x2x4x9xi32>, %arg1: tensor<4x3x5xi32>) -> tensor<4x3x5x8xi32> { -+ // CHECK-NO-DOWNGRADE: operand_batching_dims = [0, 2] -+ // CHECK-NO-DOWNGRADE: start_indices_batching_dims = [1, 0] -+ %0 = "stablehlo.gather"(%arg0, %arg1) <{ -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1], -+ operand_batching_dims = [0, 2], -+ start_indices_batching_dims = [1, 0], -+ start_index_map = [1], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = true -+ }> : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>) -> tensor<4x3x5x8xi32> -+ func.return %0 : tensor<4x3x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_with_batching_dim_size_zero -+// CHECK-NEXT: %[[iota:.*]] = stablehlo.iota dim = 0 : tensor<0x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota]], %arg1, dim = 3 : (tensor<0x3x5x1xi32>, tensor<0x3x5x1xi32>) -> tensor<0x3x5x2xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1], -+// CHECK-SAME: start_index_map = [0, 1], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<0x2x9xi32>, tensor<0x3x5x2xi32>) -> tensor<0x3x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<0x3x5x8xi32> -+func.func @gather_with_batching_dim_size_zero(%arg0: tensor<0x2x9xi32>, %arg1: tensor<0x3x5x1xi32>) -> tensor<0x3x5x8xi32> { -+ // CHECK-NO-DOWNGRADE: operand_batching_dims = [0] -+ // CHECK-NO-DOWNGRADE: start_indices_batching_dims = [0] -+ %0 = "stablehlo.gather"(%arg0, %arg1) <{ -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1], -+ operand_batching_dims = [0], -+ start_indices_batching_dims = [0], -+ start_index_map = [1], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = true -+ }> : (tensor<0x2x9xi32>, tensor<0x3x5x1xi32>) -> tensor<0x3x5x8xi32> -+ func.return %0 : tensor<0x3x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @scatter_with_batching_dims -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32> -+// CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: dimension_numbers = #stablehlo.scatter< -+// CHECK-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2, 3], -+// CHECK-SAME: scatter_dims_to_operand_dims = [0, 2, 1, 3], index_vector_dim = 3>, -+// CHECK-SAME: unique_indices = false}> -+// CHECK: (tensor<3x2x4x7x9xi32>, tensor<4x3x5x4xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> -+// CHECK-NEXT: return %[[scatter]] : tensor<3x2x4x7x9xi32> -+func.func @scatter_with_batching_dims(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>, %arg2: tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> { -+ // CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2] -+ // CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0] -+ %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ -+ indices_are_sorted = true, -+ scatter_dimension_numbers = #stablehlo.scatter< -+ update_window_dims = [3], -+ inserted_window_dims = [1, 3], -+ input_batching_dims = [0, 2], -+ scatter_indices_batching_dims = [1, 0], -+ scatter_dims_to_operand_dims = [1, 3], -+ index_vector_dim = 3 -+ >, -+ unique_indices = false -+ }> ({ -+ ^bb0(%arg3: tensor, %arg4: tensor): -+ stablehlo.return %arg4 : tensor -+ }) : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x2xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> -+ func.return %0 : tensor<3x2x4x7x9xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @scatter_with_batching_no_index_vector_dim -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32> -+// CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: dimension_numbers = #stablehlo.scatter< -+// CHECK-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2], -+// CHECK-SAME: scatter_dims_to_operand_dims = [0, 2, 1], index_vector_dim = 3>, -+// CHECK-SAME: unique_indices = true}> -+// CHECK: (tensor<3x2x4x9xi32>, tensor<4x3x5x3xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32> -+// CHECK-NEXT: return %[[scatter]] : tensor<3x2x4x9xi32> -+func.func @scatter_with_batching_no_index_vector_dim(%arg0: tensor<3x2x4x9xi32>, %arg1: tensor<4x3x5xi32>, %arg2: tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32> { -+ // CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2] -+ // CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0] -+ %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ -+ indices_are_sorted = true, -+ scatter_dimension_numbers = #stablehlo.scatter< -+ update_window_dims = [3], -+ inserted_window_dims = [1], -+ input_batching_dims = [0, 2], -+ scatter_indices_batching_dims = [1, 0], -+ scatter_dims_to_operand_dims = [1], -+ index_vector_dim = 3 -+ >, -+ unique_indices = true -+ }> ({ -+ ^bb0(%arg3: tensor, %arg4: tensor): -+ stablehlo.return %arg4 : tensor -+ }) : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32> -+ func.return %0 : tensor<3x2x4x9xi32> -+} -diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp ---- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp -+++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp -@@ -12,14 +12,22 @@ - - #include -+#include - #include -+#include -+#include -+#include +-std::optional materializeCastFromIllegal(OpBuilder &builder, Type type, ++Value materializeCastFromIllegal(OpBuilder &builder, Type type, + ValueRange inputs, + Location loc) { + Type fromType = getElementTypeOrSelf(inputs[0].getType()); + Type toType = getElementTypeOrSelf(type); + if ((!fromType.isSignedInteger() && !fromType.isUnsignedInteger()) || + !toType.isSignlessInteger()) +- return std::nullopt; ++ return Value(); + // Use unrealized conversion casts to do signful->signless conversions. + return builder.create(loc, type, inputs[0]) + ->getResult(0); + } - #include "llvm/ADT/APFloat.h" -+#include "llvm/ADT/STLExtras.h" -+#include "llvm/ADT/SmallVector.h" - #include "llvm/Support/ErrorHandling.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" - #include "mlir/IR/BuiltinAttributes.h" - #include "mlir/IR/BuiltinTypes.h" -+#include "mlir/IR/Diagnostics.h" - #include "mlir/IR/PatternMatch.h" -+#include "mlir/Rewrite/FrozenRewritePatternSet.h" - #include "mlir/Support/LLVM.h" - #include "mlir/Transforms/DialectConversion.h" - #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -@@ -58,6 +66,132 @@ - return targetVersion; +-std::optional materializeCastToIllegal(OpBuilder &builder, Type type, ++Value materializeCastToIllegal(OpBuilder &builder, Type type, + ValueRange inputs, Location loc) { + Type fromType = getElementTypeOrSelf(inputs[0].getType()); + Type toType = getElementTypeOrSelf(type); + if (!fromType.isSignlessInteger() || + (!toType.isSignedInteger() && !toType.isUnsignedInteger())) +- return std::nullopt; ++ return Value(); + // Use unrealized conversion casts to do signless->signful conversions. + return builder.create(loc, type, inputs[0]) + ->getResult(0); } -+SmallVector mergeSortedDims(ArrayRef dims1, -+ ArrayRef dims2) { -+ SmallVector result; -+ result.reserve(dims1.size() + dims2.size()); -+ std::merge(dims1.begin(), dims1.end(), dims2.begin(), dims2.end(), -+ std::back_inserter(result)); -+ return result; -+} -+ -+// Returns an updated indices tensor such that an `IotaOp` is prepended for each -+// dim in `indicesBatchingDims` with a `ConcatenateOp`. +-std::optional scalarToTensor(OpBuilder &builder, Type type, ++Value scalarToTensor(OpBuilder &builder, Type type, + ValueRange inputs, Location loc) { + assert(inputs.size() == 1); + if (mlir::isa(inputs.front().getType())) { +- return std::nullopt; ++ return Value(); + } + Value result = + builder +diff --ruN a/stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade_patch.mlir b/stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade_patch.mlir +--- stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade_patch.mlir ++++ stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade_patch.mlir +@@ -0,0 +1,15 @@ ++// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo --vhlo-to-version='target=1.4.1' %s | FileCheck %s ++ ++// AllToAll was in the initial StableHLO opset, but changed in v1.5.0 to have ++// tuple arguments. Ensure that serializing for 1.4.1 is valid and targets the ++// v1.4.0 opset. +// -+// If `indexVectorDim` is equal to the rank of `indices`, it is reshaped to have -+// a trailing dimension of size 1 so it can be concatenated with the `IotaOp`s. -+Value createConcatIndices(Value indices, int64_t indexVectorDim, -+ ArrayRef indicesBatchingDims, -+ PatternRewriter &rewriter) { -+ Location loc = indices.getLoc(); -+ auto indicesType = cast(indices.getType()); -+ bool indexVectorDimOnLastDim = indexVectorDim == indicesType.getRank(); -+ -+ SmallVector iotaShape(indicesType.getShape()); -+ if (indexVectorDimOnLastDim) { -+ iotaShape.push_back(1); -+ } else { -+ iotaShape[indexVectorDim] = 1; -+ } -+ auto iotaType = -+ RankedTensorType::get(iotaShape, indicesType.getElementType()); -+ -+ SmallVector indicesToConcat; -+ indicesToConcat.reserve(indicesBatchingDims.size() + 1); -+ for (int64_t batchingDim : indicesBatchingDims) { -+ indicesToConcat.push_back( -+ rewriter.create(loc, iotaType, batchingDim)); -+ } -+ if (indexVectorDimOnLastDim) { -+ indicesToConcat.push_back( -+ rewriter.create(loc, iotaType, indices)); -+ } else { -+ indicesToConcat.push_back(indices); -+ } -+ return rewriter.create(loc, indicesToConcat, indexVectorDim); ++// This will catch issues in op `isLegal` checks: ++// op.minVersion() <= target <= op.maxVersion() ++ ++// CHECK-LABEL: vhlo.func_v1 @all_to_all ++func.func public @all_to_all(%arg0: tensor<8x8x1xui16>) -> tensor<1x8x8xui16> { ++ // CHECK: vhlo.all_to_all_v1 ++ %0 = "stablehlo.all_to_all"(%arg0) <{concat_dimension = 2 : i64, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, split_count = 8 : i64, split_dimension = 0 : i64}> : (tensor<8x8x1xui16>) -> tensor<1x8x8xui16> ++ return %0 : tensor<1x8x8xui16> +} -+ -+//===----------------------------------------------------------------------===// -+// Patterns (non DRR) -+//===----------------------------------------------------------------------===// -+ -+// Converts a `GatherOp` with batching dims to a `GatherOp` without batching -+// dims, such that each batching dim becomes a collapsed slice dim with a -+// corresponding `IotaOp` concatenated to the start indices. -+class GatherWithBatchingDimsExpander : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ -+ LogicalResult matchAndRewrite(GatherOp op, -+ PatternRewriter &rewriter) const override { -+ GatherDimensionNumbersAttr dimNumbers = op.getDimensionNumbers(); -+ ArrayRef operandBatchingDims = dimNumbers.getOperandBatchingDims(); -+ if (operandBatchingDims.empty()) { -+ return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { -+ diag << "gather op has no batching dims"; -+ }); -+ } -+ -+ SmallVector newCollapsedSliceDims = mergeSortedDims( -+ operandBatchingDims, dimNumbers.getCollapsedSliceDims()); -+ SmallVector newStartIndexMap = -+ llvm::to_vector(llvm::concat( -+ operandBatchingDims, dimNumbers.getStartIndexMap())); -+ Value newIndices = createConcatIndices( -+ op.getStartIndices(), dimNumbers.getIndexVectorDim(), -+ dimNumbers.getStartIndicesBatchingDims(), rewriter); -+ rewriter.replaceOpWithNewOp( -+ op, op.getOperand(), newIndices, -+ GatherDimensionNumbersAttr::get( -+ op.getContext(), dimNumbers.getOffsetDims(), newCollapsedSliceDims, -+ /*operandBatchingDims=*/{}, /*startIndicesBatchingDims=*/{}, -+ newStartIndexMap, dimNumbers.getIndexVectorDim()), -+ op.getSliceSizes(), /*indicesAreSorted=*/false); -+ -+ return success(); -+ } -+}; -+ -+// Converts a `ScatterOp` with batching dims to a `ScatterOp` without batching -+// dims, such that each batching dim becomes an inserted window dim with a -+// corresponding `IotaOp` concatenated to the scatter indices. -+class ScatterWithBatchingDimsExpander : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ -+ LogicalResult matchAndRewrite(ScatterOp op, -+ PatternRewriter &rewriter) const override { -+ ScatterDimensionNumbersAttr dimNumbers = op.getScatterDimensionNumbers(); -+ ArrayRef inputBatchingDims = dimNumbers.getInputBatchingDims(); -+ if (inputBatchingDims.empty()) { -+ return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { -+ diag << "scatter op has no batching dims"; -+ }); -+ } -+ -+ SmallVector newInsertedWindowDims = -+ mergeSortedDims(inputBatchingDims, dimNumbers.getInsertedWindowDims()); -+ SmallVector newScatterDimsToOperandDims = -+ llvm::to_vector(llvm::concat( -+ inputBatchingDims, dimNumbers.getScatterDimsToOperandDims())); -+ Value newIndices = createConcatIndices( -+ op.getScatterIndices(), dimNumbers.getIndexVectorDim(), -+ dimNumbers.getScatterIndicesBatchingDims(), rewriter); -+ auto newScatterOp = rewriter.create( -+ op.getLoc(), op->getResultTypes(), op.getInputs(), newIndices, -+ op.getUpdates(), -+ ScatterDimensionNumbersAttr::get( -+ op.getContext(), dimNumbers.getUpdateWindowDims(), -+ newInsertedWindowDims, -+ /*inputBatchingDims=*/{}, /*scatterIndicesBatchingDims=*/{}, -+ newScatterDimsToOperandDims, dimNumbers.getIndexVectorDim()), -+ /*indicesAreSorted=*/false, op.getUniqueIndices()); -+ -+ newScatterOp.getUpdateComputation().takeBody(op.getUpdateComputation()); -+ rewriter.replaceOp(op, newScatterOp.getResults()); -+ -+ return success(); +diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stablehlo/transforms/VhloToVersion.cpp +--- stablehlo/stablehlo/transforms/VhloToVersion.cpp ++++ stablehlo/stablehlo/transforms/VhloToVersion.cpp +@@ -92,6 +92,13 @@ + << " is greater than current version " + << Version::getCurrentVersion(); + ++ // Opset changes warrant a minor version bump, so this conversion assumes ++ // patch v0 since it is written against the opset at version `X.Y.0`. ++ if (targetVersion.getPatch() != 0) { ++ targetVersion = ++ vhlo::Version(targetVersion.getMajor(), targetVersion.getMinor(), 0); + } -+}; + - //===----------------------------------------------------------------------===// - // Pass - //===----------------------------------------------------------------------===// -@@ -107,10 +241,16 @@ - void populateStablehloCreateCompatibilityExpanderPatterns( - RewritePatternSet *patterns, MLIRContext *context, - vhlo::Version targetVersion) { -+ // StableHLO GatherOp/ScatterOp with batching dims is introduced in v1.1.0. -+ if (targetVersion < vhlo::Version(1, 1, 0)) { -+ patterns -+ ->add( -+ context); -+ } - // StableHLO TanOp is introduced in v1.4.0. - if (targetVersion < vhlo::Version(1, 4, 0)) { -- patterns->add(context); -- patterns->add(context); -+ patterns->add(context); - } + return targetVersion; } diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index 97fd0b990fc1c7..62097715d4e914 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "78c753ad13ad8205cacc5fcc12418c1ac97276c7" - STABLEHLO_SHA256 = "b7fef892020eb465a6d1ed921160f5229398ba10acff36b6345171b9867ccc7c" + STABLEHLO_COMMIT = "1c0b606503aac4f8e01f5511b0a10418bf93e7a6" + STABLEHLO_SHA256 = "9ccf08c7d2c7dc0a5c314cf13e3e82faafc8c3dc2a45f4d6fa634ca8e5e97705" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 3466def95fd60d..7f85d491712101 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "07992d7c1ead60f610c17b7c1f9e50b6898adc87" - TFRT_SHA256 = "e1de8d371248d3dfc6e9ebd0e4094b57ce04d9545ae3756b5a84c33482614d5f" + TFRT_COMMIT = "8e00ae114e65160da6f5719c45a79102735789c8" + TFRT_SHA256 = "58299defb9f2dc1ed7041d22cf93255c71c71662ca6ca1f6a758f39d854c682e" tf_http_archive( name = "tf_runtime", diff --git a/third_party/triton/llvm_integration/cl680875920.patch b/third_party/triton/llvm_integration/cl680875920.patch new file mode 100644 index 00000000000000..bbc8f024c78689 --- /dev/null +++ b/third_party/triton/llvm_integration/cl680875920.patch @@ -0,0 +1,114 @@ + +--- a/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp 2024-03-19 09:23:43.000000000 -0700 ++++ b/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp 2024-10-01 02:58:18.000000000 -0700 +@@ -104,9 +104,26 @@ + this->getTypeConverter()->packFunctionResults(resultTypes))) + return nullptr; + } ++ // Add LLVMOp Bundle Attrs ++ // https://github.com/llvm/llvm-project/blob/main/flang/lib/Optimizer/CodeGen/CodeGen.cpp#L113-L131 ++ llvm::SmallVector newAttrs; ++ newAttrs.reserve(callOp->getAttrs().size() + 2); ++ ++ for (mlir::NamedAttribute attr : callOp->getAttrs()) { ++ if (attr.getName() != "operandSegmentSizes") ++ newAttrs.push_back(attr); ++ } ++ ++ newAttrs.push_back(rewriter.getNamedAttr( ++ "operandSegmentSizes", ++ rewriter.getDenseI32ArrayAttr( ++ {static_cast(promotedOperands.size()), 0}))); ++ newAttrs.push_back(rewriter.getNamedAttr( ++ "op_bundle_sizes", rewriter.getDenseI32ArrayAttr({}))); ++ + auto newCallOp = rewriter.create( + callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), +- promotedOperands, callOp->getAttrs()); ++ promotedOperands, newAttrs); + return newCallOp; + } + + +--- a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp 2024-09-25 10:13:59.000000000 -0700 ++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp 2024-09-30 23:51:44.000000000 -0700 +@@ -190,7 +190,8 @@ + auto name = StringAttr::get(callOp.getContext(), "llvm.amdgcn.rcp.f32"); + LLVM::FastmathFlagsAttr defaultFlags{}; + auto rcpOp = rewriter.create( +- loc, returnType, name, operands[1], defaultFlags); ++ loc, returnType, name, operands[1], defaultFlags, ++ ::llvm::ArrayRef<::mlir::ValueRange>{} /*op_bundle_operands*/); + + replacementOp = rewriter.create( + loc, returnType, operands[0], rcpOp->getResult(0), defaultFlags); + + +--- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp 2024-08-20 03:28:55.000000000 -0700 ++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp 2024-09-30 23:51:44.000000000 -0700 +@@ -219,7 +219,8 @@ + } + auto wmmaIntrinsic = rewriter.create( + loc, TypeRange{valC.getType()}, StringAttr::get(loc.getContext(), name), +- operands, defaultFlags); ++ operands, defaultFlags, ++ ::llvm::ArrayRef<::mlir::ValueRange>{} /*op_bundle_operands*/); + + return wmmaIntrinsic.getResult(0); + } + + +--- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp 2024-09-16 13:44:40.000000000 -0700 ++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp 2024-09-30 23:51:44.000000000 -0700 +@@ -72,7 +72,10 @@ + auto stringAttr = rewriter.getStringAttr("llvm.amdgcn.ballot"); + SmallVector operands = {cmp}; + Value asmResult = +- rewriter.create(loc, type, stringAttr, operands) ++ rewriter ++ .create( ++ loc, type, stringAttr, operands, ::mlir::LLVM::FastmathFlags{}, ++ ::llvm::ArrayRef<::mlir::ValueRange>{} /*op_bundle_operands*/) + ->getResult(0); + return asmResult; + } + + +--- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp ++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp +@@ -48,9 +48,10 @@ void createSchedGroupBarrier(PatternRewr + static_cast(groupIdValue)); + + LLVM::FastmathFlagsAttr defaultFlags{}; +- rewriter.create(loc, TypeRange{}, intrinsicName, +- ValueRange{mask, size, groupId}, +- defaultFlags); ++ rewriter.create( ++ loc, TypeRange{}, intrinsicName, ValueRange{mask, size, groupId}, ++ defaultFlags, ++ ::llvm::ArrayRef<::mlir::ValueRange>{} /*op_bundle_operands*/); + } + + // Insert intrinsic that controls the types of instructions that may be +@@ -63,8 +64,9 @@ Operation *createSchedBarrier(PatternRew + + Value mask = + LLVM::createConstantI32(loc, rewriter, static_cast(maskValue)); +- return rewriter.create(loc, TypeRange{}, intrinsicName, +- ValueRange{mask}, defaultFlags); ++ return rewriter.create( ++ loc, TypeRange{}, intrinsicName, ValueRange{mask}, defaultFlags, ++ ::llvm::ArrayRef<::mlir::ValueRange>{} /*op_bundle_operands*/); + } + + // Insert an experimental intrinsic for instruction group level parallelism. +@@ -76,7 +78,8 @@ Operation *createIglpOpt(PatternRewriter + Value iglpValue = + LLVM::createConstantI32(loc, rewriter, static_cast(value)); + return rewriter.create( +- loc, TypeRange{}, intrinsicName, ValueRange{iglpValue}, defaultFlags); ++ loc, TypeRange{}, intrinsicName, ValueRange{iglpValue}, defaultFlags, ++ ::llvm::ArrayRef<::mlir::ValueRange>{} /*op_bundle_operands*/); + } + + struct InstructionSchedHintsRewriter diff --git a/third_party/triton/llvm_integration/cl683501567.patch b/third_party/triton/llvm_integration/cl683501567.patch new file mode 100644 index 00000000000000..7395934253fc0c --- /dev/null +++ b/third_party/triton/llvm_integration/cl683501567.patch @@ -0,0 +1,13 @@ + +--- a/lib/Target/LLVMIR/LLVMDIScope.cpp 2024-09-16 13:44:40.000000000 -0700 ++++ b/lib/Target/LLVMIR/LLVMDIScope.cpp 2024-10-08 22:38:50.000000000 -0700 +@@ -104,7 +104,8 @@ + auto subprogramAttr = LLVM::DISubprogramAttr::get( + context, distinctId, compileUnitAttr, fileAttr, funcNameAttr, + funcNameAttr, fileAttr, /*line=*/line, /*scopeline=*/line, +- subprogramFlags, subroutineTypeAttr, /*retainNodes=*/{}); ++ subprogramFlags, subroutineTypeAttr, /*retainNodes=*/{}, ++ /*annotations=*/{}); + funcOp->setLoc(FusedLoc::get(context, {loc}, subprogramAttr)); + } + diff --git a/third_party/triton/llvm_integration/cl686059966.patch b/third_party/triton/llvm_integration/cl686059966.patch new file mode 100644 index 00000000000000..b5fcd3a266e313 --- /dev/null +++ b/third_party/triton/llvm_integration/cl686059966.patch @@ -0,0 +1,36 @@ + +--- a/lib/Analysis/AxisInfo.cpp 2024-10-01 12:24:54.000000000 -0700 ++++ b/lib/Analysis/AxisInfo.cpp 2024-10-15 05:20:45.000000000 -0700 +@@ -1079,8 +1079,8 @@ + + void AxisInfoAnalysis::visitForOpInductionVar( + scf::ForOp op, ArrayRef *> argLattices) { +- auto lb = getLatticeElementFor(op, op.getLowerBound())->getValue(); +- auto step = getLatticeElementFor(op, op.getStep())->getValue(); ++ auto lb = getLatticeElementFor(getProgramPointAfter(op), op.getLowerBound())->getValue(); ++ auto step = getLatticeElementFor(getProgramPointAfter(op), op.getStep())->getValue(); + + AxisInfo::DimVectorT knownContiguity(1, 1); + AxisInfo::DimVectorT knownDivisibility(1, 1); + +--- a/lib/Analysis/Utility.cpp 2024-10-02 02:26:53.000000000 -0700 ++++ b/lib/Analysis/Utility.cpp 2024-10-15 05:20:45.000000000 -0700 +@@ -826,15 +826,15 @@ + + LogicalResult initialize(Operation *top) override { + WalkResult result = top->walk([&](Operation *op) { +- if (failed(visit(op))) ++ if (failed(visit(getProgramPointAfter(op)))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return success(!result.wasInterrupted()); + } + +- LogicalResult visit(ProgramPoint point) override { +- Operation *op = point.get(); ++ LogicalResult visit(ProgramPoint* point) override { ++ Operation *op = point->getPrevOp(); + Attribute value; + if (matchPattern(op, m_Constant(&value))) { + auto *constant = getOrCreate>( diff --git a/third_party/triton/llvm_integration/cl686893691.patch b/third_party/triton/llvm_integration/cl686893691.patch new file mode 100644 index 00000000000000..b27d7abb41d8eb --- /dev/null +++ b/third_party/triton/llvm_integration/cl686893691.patch @@ -0,0 +1,80 @@ + +--- a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp 2024-10-01 05:53:49.000000000 -0700 ++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp 2024-10-17 07:36:44.000000000 -0700 +@@ -190,8 +190,7 @@ + auto name = StringAttr::get(callOp.getContext(), "llvm.amdgcn.rcp.f32"); + LLVM::FastmathFlagsAttr defaultFlags{}; + auto rcpOp = rewriter.create( +- loc, returnType, name, operands[1], defaultFlags, +- ::llvm::ArrayRef<::mlir::ValueRange>{} /*op_bundle_operands*/); ++ loc, returnType, name, operands[1], defaultFlags); + + replacementOp = rewriter.create( + loc, returnType, operands[0], rcpOp->getResult(0), defaultFlags); + +--- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp 2024-10-01 05:53:49.000000000 -0700 ++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp 2024-10-17 07:49:54.000000000 -0700 +@@ -219,8 +219,7 @@ + } + auto wmmaIntrinsic = rewriter.create( + loc, TypeRange{valC.getType()}, StringAttr::get(loc.getContext(), name), +- operands, defaultFlags, +- ::llvm::ArrayRef<::mlir::ValueRange>{} /*op_bundle_operands*/); ++ operands, defaultFlags); + + return wmmaIntrinsic.getResult(0); + } + +--- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp 2024-10-02 02:26:53.000000000 -0700 ++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp 2024-10-17 07:39:38.000000000 -0700 +@@ -48,10 +48,9 @@ + static_cast(groupIdValue)); + + LLVM::FastmathFlagsAttr defaultFlags{}; +- rewriter.create( +- loc, TypeRange{}, intrinsicName, ValueRange{mask, size, groupId}, +- defaultFlags, +- ::llvm::ArrayRef<::mlir::ValueRange>{} /*op_bundle_operands*/); ++ rewriter.create(loc, TypeRange{}, intrinsicName, ++ ValueRange{mask, size, groupId}, ++ defaultFlags); + } + + // Insert intrinsic that controls the types of instructions that may be +@@ -64,9 +63,8 @@ + + Value mask = + LLVM::createConstantI32(loc, rewriter, static_cast(maskValue)); +- return rewriter.create( +- loc, TypeRange{}, intrinsicName, ValueRange{mask}, defaultFlags, +- ::llvm::ArrayRef<::mlir::ValueRange>{} /*op_bundle_operands*/); ++ return rewriter.create(loc, TypeRange{}, intrinsicName, ++ ValueRange{mask}, defaultFlags); + } + + // Insert an experimental intrinsic for instruction group level parallelism. +@@ -78,8 +76,7 @@ + Value iglpValue = + LLVM::createConstantI32(loc, rewriter, static_cast(value)); + return rewriter.create( +- loc, TypeRange{}, intrinsicName, ValueRange{iglpValue}, defaultFlags, +- ::llvm::ArrayRef<::mlir::ValueRange>{} /*op_bundle_operands*/); ++ loc, TypeRange{}, intrinsicName, ValueRange{iglpValue}, defaultFlags); + } + + struct InstructionSchedHintsRewriter + +--- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp 2024-10-01 05:53:49.000000000 -0700 ++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp 2024-10-17 07:37:23.000000000 -0700 +@@ -72,10 +72,7 @@ + auto stringAttr = rewriter.getStringAttr("llvm.amdgcn.ballot"); + SmallVector operands = {cmp}; + Value asmResult = +- rewriter +- .create( +- loc, type, stringAttr, operands, ::mlir::LLVM::FastmathFlags{}, +- ::llvm::ArrayRef<::mlir::ValueRange>{} /*op_bundle_operands*/) ++ rewriter.create(loc, type, stringAttr, operands) + ->getResult(0); + return asmResult; + } diff --git a/third_party/triton/llvm_integration/cl689707450.patch b/third_party/triton/llvm_integration/cl689707450.patch new file mode 100644 index 00000000000000..0afc2edb10d8d2 --- /dev/null +++ b/third_party/triton/llvm_integration/cl689707450.patch @@ -0,0 +1,47 @@ + +--- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp 2024-08-05 02:40:13.000000000 -0700 ++++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp 2024-10-25 02:46:07.000000000 -0700 +@@ -56,7 +56,7 @@ + // This will create newArg, and map(origArg, newArg) + addArgumentMaterialization([&](OpBuilder &builder, + RankedTensorType tensorType, ValueRange inputs, +- Location loc) -> std::optional { ++ Location loc) -> Value { + // Allows partial TTIR to TTGIR conversion by materializing a conversion for + // remaining arguments that have been converted to a new type. + // We use this to rewrite triton_gpu.sparse_dot in a separate pass after +@@ -65,14 +65,14 @@ + inputs); + llvm_unreachable("Argument rematerialization should not happen in Triton " + "-> TritonGPU conversion"); +- return std::nullopt; ++ return Value(); + }); + + // If the origValue still has live user(s), use this to + // convert origValue to newValue + addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, + ValueRange inputs, +- Location loc) -> std::optional { ++ Location loc) -> Value { + // Allows partial TTIR to TTGIR conversion by materializing a conversion for + // remaining uses of values that have been converted to a new type. + // We use this to rewrite triton_gpu.sparse_dot in a separate pass after +@@ -81,7 +81,7 @@ + inputs); + llvm_unreachable("Source rematerialization should not happen in Triton -> " + "TritonGPU Conversion"); +- return std::nullopt; ++ return Value(); + }); + + // This will be called when (desiredType != newOperandType) +@@ -91,7 +91,7 @@ + ValueRange inputs, Location loc) { + auto cast = + builder.create(loc, tensorType, inputs); +- return std::optional(cast.getResult()); ++ return Value(cast.getResult()); + }); + } + diff --git a/third_party/triton/llvm_integration/series.bzl b/third_party/triton/llvm_integration/series.bzl index 656b9c894904d8..70fef78927d338 100644 --- a/third_party/triton/llvm_integration/series.bzl +++ b/third_party/triton/llvm_integration/series.bzl @@ -8,5 +8,10 @@ LLVM nor MLIR integrator, please do not add any patches to this list. """ llvm_patch_list = [ + "//third_party/triton:llvm_integration/cl680875920.patch", + "//third_party/triton:llvm_integration/cl683501567.patch", + "//third_party/triton:llvm_integration/cl686059966.patch", + "//third_party/triton:llvm_integration/cl686893691.patch", + "//third_party/triton:llvm_integration/cl689707450.patch", # Add new patches just above this line ] diff --git a/third_party/triton/temporary/fix_left_shift_overflow.patch b/third_party/triton/temporary/fix_left_shift_overflow.patch new file mode 100644 index 00000000000000..ca31caef4b2824 --- /dev/null +++ b/third_party/triton/temporary/fix_left_shift_overflow.patch @@ -0,0 +1,11 @@ +--- a/lib/Analysis/AxisInfo.cpp ++++ b/lib/Analysis/AxisInfo.cpp +@@ -932,7 +932,7 @@ private: + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } +- return std::max(1, lhsDivisibility / (1 << shift)); ++ return std::max(1, lhsDivisibility / (int64_t(1) << shift)); + } + + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, diff --git a/third_party/triton/temporary/further_mixed_precision_fix.patch b/third_party/triton/temporary/further_mixed_precision_fix.patch new file mode 100644 index 00000000000000..6152ab48194c09 --- /dev/null +++ b/third_party/triton/temporary/further_mixed_precision_fix.patch @@ -0,0 +1,36 @@ +This resolves the issue here b/372630230. The patch is not intended to be +submitted to Triton upstream. This is because OAI historically refused these +similar work-arounds and the proper fixes are considerably more expensive to do. +diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +--- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp ++++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +@@ -55,7 +55,8 @@ SmallVector reorderValues(const S + } + return ret; + } +- if (inBitWidth == 8 && ouBitWidth == 16) { ++ if ((inBitWidth == 8 && ouBitWidth == 16) || ++ (inBitWidth == 16 && ouBitWidth == 8)) { + SmallVector ret; + for (unsigned i = 0; i < values.size(); i += 16) { + ret.push_back(values[i + 0]); +diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir +--- a/test/Conversion/tritongpu_to_llvm.mlir ++++ b/test/Conversion/tritongpu_to_llvm.mlir +@@ -1693,3 +1693,16 @@ module attributes {"triton_gpu.num-ctas" + tt.return + } + } ++ ++// ----- ++ ++#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> ++#dot_operand = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=4}> ++module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { ++ tt.func @f16_to_f8_dot_operand(%f16_inp: tensor<32x32xf16, #dot_operand>) { ++ // CHECK-LABEL: @f16_to_f8_dot_operand ++ ++ %f8 = tt.fp_to_fp %f16_inp, rounding = rtne : tensor<32x32xf16, #dot_operand> -> tensor<32x32xf8E5M2, #dot_operand> ++ tt.return ++ } ++} diff --git a/third_party/triton/temporary/i4_to_bf16.patch b/third_party/triton/temporary/i4_to_bf16.patch new file mode 100644 index 00000000000000..6afe4ee3b7157f --- /dev/null +++ b/third_party/triton/temporary/i4_to_bf16.patch @@ -0,0 +1,129 @@ + +--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp 2024-09-25 10:13:59.000000000 -0700 ++++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp 2024-10-07 00:38:03.000000000 -0700 +@@ -264,7 +264,8 @@ + outVecWidthBits](Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) -> SmallVector { + int numElements = v.size(); +- assert(numElements == 4 || numElements == 2 && "invalid vector size"); ++ assert(numElements == 8 || numElements == 4 || ++ numElements == 2 && "invalid vector size"); + + auto ctx = rewriter.getContext(); + int inBitwidth = inType.getIntOrFloatBitWidth(); +@@ -669,6 +670,115 @@ + PatternBenefit benefit = patternBenefitDefault) + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), + computeCapability(computeCapability) {} ++ ++ LogicalResult matchAndRewrite( ++ arith::SIToFPOp op, OpAdaptor adaptor, ++ ConversionPatternRewriter &rewriter) const override { ++ if (succeeded(matchAndRewriteInt4ToBf16Conversion(op, rewriter))) { ++ return success(); ++ } ++ return Base::matchAndRewrite(op, adaptor, rewriter); ++ } ++ ++ // Matches subgraph of convert 8xi4 to 8xbf16 and rewrites it to inline PTX. ++ LogicalResult matchAndRewriteInt4ToBf16Conversion( ++ arith::SIToFPOp op, ConversionPatternRewriter &rewriter) const { ++ if (computeCapability < 90) return failure(); ++ Type inElemTy = getElementType(op.getIn()); ++ Type outElemTy = getElementType(op.getOut()); ++ if (!inElemTy.isInteger(8) || !outElemTy.isBF16()) return failure(); ++ FailureOr unpack = matchInt4Unpack(op.getIn()); ++ if (failed(unpack)) return failure(); ++ ++ Location loc = op.getLoc(); ++ Value src = rewriter.getRemappedValue(unpack.value()); ++ auto structTy = dyn_cast(src.getType()); ++ if (!structTy || structTy.getBody().size() % 4 != 0) return failure(); ++ auto isInt8 = [](Type type) { return type.isInteger(8); }; ++ if (!all_of(structTy.getBody(), isInt8)) return failure(); ++ ++ const LLVMTypeConverter *typeConverter = getTypeConverter(); ++ assert(inElemTy == typeConverter->convertType(inElemTy)); ++ assert(outElemTy == typeConverter->convertType(outElemTy)); ++ ++ const std::string S4_to_Bf16_sm90 = R"({ ++ .reg .b32 r<4>, mi, mf; ++ mov.b32 mi, 0x43404340 - 0x00080008; ++ mov.b32 mf, 0x43404340; ++ // Shift 4-bit inputs to 16-bit boundary. ++ shr.u32 r1, $4, 4; ++ shr.u32 r2, $4, 8; ++ shr.u32 r3, $4, 12; ++ // Sign-extend from 4 bits is equivalent to (x ^ 0x8) - 0x8. ++ lop3.b32 r0, $4, 0x000f000f, 0x00080008, (0xf0 & 0xcc) ^ 0xaa; ++ lop3.b32 r1, r1, 0x000f000f, 0x00080008, (0xf0 & 0xcc) ^ 0xaa; ++ lop3.b32 r2, r2, 0x000f000f, 0x00080008, (0xf0 & 0xcc) ^ 0xaa; ++ lop3.b32 r3, r3, 0x000f000f, 0x00080008, (0xf0 & 0xcc) ^ 0xaa; ++ // Interger-add magic number (minus bias from sign-extend above). ++ add.s16x2 r0, r0, mi; ++ add.s16x2 r1, r1, mi; ++ add.s16x2 r2, r2, mi; ++ add.s16x2 r3, r3, mi; ++ // Float-subtract magic number. ++ sub.bf16x2 r0, r0, mf; ++ sub.bf16x2 r1, r1, mf; ++ sub.bf16x2 r2, r2, mf; ++ sub.bf16x2 r3, r3, mf; ++ // Shuffle results into correct order. ++ prmt.b32 $0, r1, r0, 0x5410; ++ prmt.b32 $1, r3, r2, 0x5410; ++ prmt.b32 $2, r1, r0, 0x7632; ++ prmt.b32 $3, r3, r2, 0x7632; ++ })"; ++ ++ SmallVector resultVals; ++ SmallVector unpackedVals = unpackLLElements(loc, src, rewriter); ++ auto cvtFunc = makeConverterFromPtx(S4_to_Bf16_sm90, inElemTy, outElemTy); ++ for (ValueRange operands = unpackedVals; !operands.empty(); ++ operands = operands.drop_front(4)) { ++ SmallVector inVals = { ++ operands[0], operands[1], operands[2], operands[3], ++ // Repeat operands so that cvtFunc produces 8 outputs. ++ operands[0], operands[1], operands[2], operands[3]}; ++ auto outVals = cvtFunc(loc, rewriter, inVals); ++ assert(inVals.size() == outVals.size()); ++ resultVals.append(outVals.begin(), outVals.end()); ++ } ++ ++ resultVals = reorderValues(resultVals, op.getIn().getType(), op.getType()); ++ resultVals = maybeDeduplicate(op, resultVals); ++ Value view = ++ packLLElements(loc, typeConverter, resultVals, rewriter, op.getType()); ++ rewriter.replaceOp(op, view); ++ ++ return success(); ++ } ++ ++ // Returns the source if value is the result of an 2xi4 -> 2xi8 unpack ++ // sequence. ++ static FailureOr matchInt4Unpack(Value value) { ++ auto reshape = value.getDefiningOp(); ++ if (!reshape) return failure(); ++ auto join = reshape.getSrc().getDefiningOp(); ++ if (!join) return failure(); ++ auto shrHi = join.getLhs().getDefiningOp(); ++ if (!shrHi || !isConst4(shrHi.getRhs())) return failure(); ++ auto shrLo = join.getRhs().getDefiningOp(); ++ if (!shrLo || !isConst4(shrLo.getRhs())) return failure(); ++ auto shlLo = shrLo.getLhs().getDefiningOp(); ++ if (!shlLo || !isConst4(shlLo.getRhs())) return failure(); ++ if (shrHi.getLhs() != shlLo.getLhs()) return failure(); ++ return shrHi.getLhs(); ++ } ++ ++ // Returns true if the value is equal to 4. ++ static bool isConst4(Value v) { ++ auto constOp = v.getDefiningOp(); ++ if (!constOp) return false; ++ auto attr = mlir::dyn_cast(constOp.getValue()); ++ if (!attr || !attr.isSplat()) return false; ++ return attr.getSplatValue().getLimitedValue() == 4; ++ }; + + SmallVector createDestOps(arith::SIToFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, diff --git a/third_party/triton/temporary/prefetch.patch b/third_party/triton/temporary/prefetch.patch new file mode 100644 index 00000000000000..57033b1ac972b1 --- /dev/null +++ b/third_party/triton/temporary/prefetch.patch @@ -0,0 +1,27 @@ +# b/370665038 These seeem to be real bugs that should be upstreamed. +diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +--- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp ++++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +@@ -116,7 +116,7 @@ Value Prefetcher::generatePrefetch(Value + // opIdx: 0 => a, 1 => b + auto type = cast(v.getType()); + SmallVector shape{type.getShape().begin(), type.getShape().end()}; +- SmallVector offset{0, 0}; ++ SmallVector offset(shape.size(), 0); + Type elementType = type.getElementType(); + + // k => (prefetchWidth, k - prefetchWidth) +@@ -205,6 +205,13 @@ LogicalResult Prefetcher::initialize() { + if (srcType.isInteger(1)) + break; + } ++ // Propagation through ExpandDims is currently not supported. This blindly ++ // replaces the encoding with dot encoding & but ExpandDims requires a ++ // SliceEncoding. This could be rewritten to support it somehow, but I ++ // don't think it's trivial & it's currently crashing. ++ if (isa(op)) { ++ break; ++ } + rets.push_back(op->getOperand(0)); + if (auto cvt = dyn_cast(op)) { + foundConvertFromShared = true; diff --git a/third_party/triton/temporary/series.bzl b/third_party/triton/temporary/series.bzl index 4fa55269e3323c..d6b1a6a31f783f 100644 --- a/third_party/triton/temporary/series.bzl +++ b/third_party/triton/temporary/series.bzl @@ -14,5 +14,9 @@ those to this list. """ temporary_patch_list = [ + "//third_party/triton:temporary/fix_left_shift_overflow.patch", + "//third_party/triton:temporary/prefetch.patch", + "//third_party/triton:temporary/i4_to_bf16.patch", + "//third_party/triton:temporary/further_mixed_precision_fix.patch", # Add new patches just above this line ] diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index 911fe7493783cf..3d5f6f99bf73d4 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -8,8 +8,8 @@ load("//third_party/triton/xla_extensions:series.bzl", "extensions_files_patch_l def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl673813747" - TRITON_SHA256 = "3e901c1b441407b1b7ac601092f64a9141571879b00a1ff54437c8e9370a365f" + TRITON_COMMIT = "cl680473520" + TRITON_SHA256 = "c18fa65138b8c566b2f0299ebde4242f0f30c3625741a17de73ed5b1990cdabb" tf_http_archive( name = "triton", sha256 = TRITON_SHA256, diff --git a/third_party/triton/xla_extensions/sparse_dot.patch b/third_party/triton/xla_extensions/sparse_dot.patch index 8f613badb53988..add4a2273ae715 100644 --- a/third_party/triton/xla_extensions/sparse_dot.patch +++ b/third_party/triton/xla_extensions/sparse_dot.patch @@ -280,7 +280,7 @@ index d74e0a224..4e45f7c4c 100644 static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, Value insertIdx, Value extractIdx, tt::CoarseSchedule &schedule, -@@ -236,19 +240,28 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { +@@ -235,17 +239,25 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { } else { if (!isa(user)) return std::nullopt; @@ -296,10 +296,7 @@ index d74e0a224..4e45f7c4c 100644 + unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); + tempAttr = ttg::SharedEncodingAttr::get( + val.getContext(), cast(enc), -+ srcTy.getShape(), ttg::getOrder(srcTy.getEncoding()), -+ ttg::getCTALayout(srcTy.getEncoding()), -+ srcTy.getElementType().getIntOrFloatBitWidth(), -+ /*needTrans=*/false); ++ srcTy.getShape(), order, CTALayout, bitWidth, /*needTrans=*/false); + } else if (isa(enc)) { + auto srcTy = cast(val.getType()); + tempAttr = ttg::SharedEncodingAttr::get( @@ -313,15 +310,13 @@ index d74e0a224..4e45f7c4c 100644 - auto order = ttg::getOrder(srcTy.getEncoding()); - unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); - tempAttr = ttg::SharedEncodingAttr::get( -- val.getContext(), dotOpEnc, srcTy.getShape(), -- ttg::getOrder(srcTy.getEncoding()), -- ttg::getCTALayout(srcTy.getEncoding()), -- srcTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false); +- val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout, +- bitWidth, /*needTrans=*/false); + } } // Check that the shared encodings needed by the users are compatible. - if (attr != nullptr && attr != tempAttr) { -@@ -357,7 +370,7 @@ loadOpsToIndirectionLevelAndUse(scf::ForOp forOp) { + if (attr != nullptr) +@@ -352,7 +364,7 @@ loadOpsToIndirectionLevelAndUse(scf::ForOp forOp) { }; for (Operation &op : forOp.getBody()->without_terminator()) { @@ -330,7 +325,7 @@ index d74e0a224..4e45f7c4c 100644 continue; seen.clear(); dfs(&op, 0, &op); -@@ -434,7 +447,7 @@ assignMemoryLayouts(llvm::SmallVector> +@@ -429,7 +441,7 @@ assignMemoryLayouts(llvm::SmallVector> continue; } @@ -339,15 +334,6 @@ index d74e0a224..4e45f7c4c 100644 loadInfo.usedByDot = true; if (loadIsMMAv3(op)) { loadInfo.loadIsMMAV3 = true; -@@ -460,7 +473,7 @@ assignMemoryLayouts(llvm::SmallVector> - // The codegen bug is caught by an assertion, so if you think you've - // fixed it, feel free to delete this code and see if the assert still - // fails. :) -- if (!loadInfo.sharedEncoding) { -+ if (dot && !loadInfo.sharedEncoding) { - if (auto dotEnc = dyn_cast( - dot.getResult().getType().getEncoding())) { - auto loadTy = cast(op->getResultTypes()[0]); diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp index fb0e7f6fd..37795c20c 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp diff --git a/third_party/xla/.bazelrc b/third_party/xla/.bazelrc index fcbc9ff2772db4..ebef2e1e7a6c48 100644 --- a/third_party/xla/.bazelrc +++ b/third_party/xla/.bazelrc @@ -150,6 +150,8 @@ build:android_x86_64 --fat_apk_cpu=x86_64 # Build everything statically for Android since all static libs are later # bundled together into a single .so for deployment. build:android --dynamic_mode=off +# TODO(belitskiy): Remove once on Clang 20. +build:android --define=xnn_enable_avxvnniint8=false # Sets the default Apple platform to macOS. build:macos --apple_platform_type=macos @@ -212,6 +214,7 @@ build:mkl_aarch64 -c opt # Config setting to build oneDNN with Compute Library for the Arm Architecture (ACL). # with Eigen threadpool support build:mkl_aarch64_threadpool --define=build_with_mkl_aarch64=true +build:mkl_aarch64_threadpool --@compute_library//:openmp=false build:mkl_aarch64_threadpool -c opt # CUDA: This config refers to building CUDA op kernels with nvcc. @@ -240,6 +243,8 @@ build:cuda_clang --copt=-Qunused-arguments # major release. Example: sm_80 kernels can run on sm_89 GPUs but # not on sm_90 GPUs. compute_80 kernels though can also run on sm_90 GPUs. build:cuda_clang --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90" +# Permit newer CUDA versions than Clang is aware of +build:cuda_clang --copt="-Wno-unknown-cuda-version" # Set lld as the linker. build:cuda_clang --host_linkopt="-fuse-ld=lld" build:cuda_clang --host_linkopt="-lm" @@ -254,10 +259,11 @@ build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-1 build:cuda_clang_official --crosstool_top="@local_config_cuda//crosstool:toolchain" # Build with nvcc for CUDA and clang for host -build:nvcc_clang --config=cuda -build:nvcc_clang --action_env=TF_NVCC_CLANG="1" -build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc - +build:cuda_nvcc --config=cuda +build:cuda_nvcc --action_env=TF_NVCC_CLANG="1" +build:cuda_nvcc --@local_config_cuda//:cuda_compiler=nvcc +# Old config for backward compatibility +build:nvcc_clang --config=cuda_nvcc # Debug config build:dbg -c dbg @@ -327,8 +333,6 @@ build:linux --copt="-Werror=unused-result" # Add switch as an error on Linux. build:linux --copt="-Wswitch" build:linux --copt="-Werror=switch" -# Required for building with clang -build:linux --copt="-Wno-error=unused-but-set-variable" # Linux ARM64 specific options build:linux_arm64 --copt="-mtune=generic" --copt="-march=armv8-a" --copt="-O3" @@ -550,7 +554,7 @@ build:rbe_linux_cuda --config=rbe_linux_cpu build:rbe_linux_cuda --repo_env=REMOTE_GPU_TESTING=1 build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda -build:rbe_linux_cuda_nvcc --config=nvcc_clang +build:rbe_linux_cuda_nvcc --config=cuda_nvcc build:rbe_linux_cuda_nvcc --repo_env TF_NCCL_USE_STUB=1 build:rbe_win_base --config=rbe_base @@ -738,27 +742,27 @@ build:linux_libtensorflow_build --config=cuda_wheel -- //tensorflow/tools/lib_pa test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cpu_wheel_test --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cpu_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/tools/pip_package:import_api_packages_test # CUDA WHEEL test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cuda_wheel_test --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cuda_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/tools/pip_package:import_api_packages_test # ARM64 WHEEL test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:linux_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/tools/pip_package:import_api_packages_test # MACOS ARM64 WHEEL test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +test:macos_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/tools/pip_package:import_api_packages_test # MACOS X86 WHEEL test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_x86_wheel_test --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +test:macos_x86_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/tools/pip_package:import_api_packages_test # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. diff --git a/third_party/xla/.github/workflows/bazel_dependency_violations.yml b/third_party/xla/.github/workflows/bazel_dependency_violations.yml index 988a84fed8a457..d11d78f259f631 100644 --- a/third_party/xla/.github/workflows/bazel_dependency_violations.yml +++ b/third_party/xla/.github/workflows/bazel_dependency_violations.yml @@ -29,23 +29,31 @@ jobs: dependency-violations: strategy: matrix: - tag: [gpu, no_rocm] + tag: [gpu, cuda-only, rocm-only] name: no-${{ matrix.tag }}-targets-in-cpu-build runs-on: ubuntu-22.04 defaults: run: shell: bash - timeout-minutes: 3 + timeout-minutes: 6 continue-on-error: true steps: - name: "Checking out repository" uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: "Install bazelisk" - run: go install github.com/bazelbuild/bazelisk@24651ab # v1.20.0 + run: parallel --ungroup --retries 3 --delay 15 --nonall -- go install github.com/bazelbuild/bazelisk@24651ab # v1.20.0 + - name: "Run bazel build --nobuild //xla/... with retries" + run: parallel --ungroup --retries 3 --delay 15 --nonall -- bazelisk build --nobuild //xla/... - name: "Run bazel cquery ... //xla/..." run: | set -euo pipefail - OUTPUT=$(bazelisk cquery --aspects build_tools/dependencies/aspects.bzl%validate_${{ matrix.tag }}_tag //xla/... 2>&1) + TAG_WITH_UNDERSCORES="${{ matrix.tag }}" + TAG_WITH_UNDERSCORES="${TAG_WITH_UNDERSCORES/-/_}" + if ! OUTPUT=$(bazelisk cquery --aspects build_tools/dependencies/aspects.bzl%validate_${TAG_WITH_UNDERSCORES}_tag //xla/... 2>&1); then + echo "Failed to run bazel cquery. Output:" + echo "$OUTPUT" + exit 1 + fi if echo "$OUTPUT" | grep 'Violation' >/dev/null; then echo "The following dependency violations were found:" echo "$OUTPUT" | grep 'Violation' | sed -e 's/^.*\[Violation\]/ -/' @@ -62,4 +70,4 @@ jobs: exit 1 fi - echo "No dependency violations found for tag '${{ matrix.tag }}'." \ No newline at end of file + echo "No dependency violations found for tag '${{ matrix.tag }}'." diff --git a/third_party/xla/.github/workflows/bazel_query.yml b/third_party/xla/.github/workflows/bazel_query.yml index da47f7fcb99260..4eed8dec22b191 100644 --- a/third_party/xla/.github/workflows/bazel_query.yml +++ b/third_party/xla/.github/workflows/bazel_query.yml @@ -31,12 +31,14 @@ jobs: defaults: run: shell: bash - timeout-minutes: 2 + timeout-minutes: 6 steps: - name: "Checking out repository" uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: "Install bazelisk" - run: go install github.com/bazelbuild/bazelisk@24651ab # v1.20.0 + run: parallel --ungroup --retries 3 --delay 15 --nonall -- go install github.com/bazelbuild/bazelisk@24651ab # v1.20.0 + - name: "Run bazel build --nobuild //xla/... with retries" + run: parallel --ungroup --retries 3 --delay 15 --nonall -- bazelisk build --nobuild //xla/... - name: "Run bazel query //xla/..." run: bazelisk query //xla/... > /dev/null - name: "Run bazel query deps(//xla/...)" diff --git a/third_party/xla/.github/workflows/bazel_tags.yml b/third_party/xla/.github/workflows/bazel_tags.yml index 71d1d6ba45e4ee..753615bc61e63b 100644 --- a/third_party/xla/.github/workflows/bazel_tags.yml +++ b/third_party/xla/.github/workflows/bazel_tags.yml @@ -31,11 +31,13 @@ jobs: defaults: run: shell: bash - timeout-minutes: 2 + timeout-minutes: 6 steps: - name: "Checking out repository" uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: "Install bazelisk" - run: go install github.com/bazelbuild/bazelisk@24651ab # v1.20.0 + run: parallel --ungroup --retries 3 --delay 15 --nonall -- go install github.com/bazelbuild/bazelisk@24651ab # v1.20.0 + - name: "Run bazel build --nobuild //xla/... with retries" + run: parallel --ungroup --retries 3 --delay 15 --nonall -- bazelisk build --nobuild //xla/... - name: "Assert all tags are documented" run: bazelisk query //xla/... --output=build | python3 build_tools/lint/tags.py diff --git a/third_party/xla/.github/workflows/buildifier.yml b/third_party/xla/.github/workflows/buildifier.yml index be44d5e4193818..6a5b11ca49d36c 100644 --- a/third_party/xla/.github/workflows/buildifier.yml +++ b/third_party/xla/.github/workflows/buildifier.yml @@ -31,11 +31,11 @@ jobs: defaults: run: shell: bash - timeout-minutes: 1 + timeout-minutes: 6 steps: - name: "Checking out repository" uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 - name: "Install buildifier" - run: go install github.com/bazelbuild/buildtools/buildifier@433ea85 # 6.4.0 + run: parallel --ungroup --retries 3 --delay 15 --nonall -- go install github.com/bazelbuild/buildtools/buildifier@433ea85 # 6.4.0 - name: "Run buildifier" run: buildifier --lint=warn --warnings=-out-of-order-load -r xla/ diff --git a/third_party/xla/.github/workflows/check_contents.yml b/third_party/xla/.github/workflows/check_contents.yml index 0b8125f1b40f0e..6b236f2f09f8ca 100644 --- a/third_party/xla/.github/workflows/check_contents.yml +++ b/third_party/xla/.github/workflows/check_contents.yml @@ -40,7 +40,7 @@ jobs: defaults: run: shell: bash - timeout-minutes: 1 + timeout-minutes: 6 if: | github.event.sender.type == 'User' || contains(github.event.pull_request.body, 'FORCE_TEST_ACTIONS') diff --git a/third_party/xla/.github/workflows/clang_format.yml b/third_party/xla/.github/workflows/clang_format.yml index 338630d85430c7..344005dd7fb125 100644 --- a/third_party/xla/.github/workflows/clang_format.yml +++ b/third_party/xla/.github/workflows/clang_format.yml @@ -28,7 +28,7 @@ jobs: defaults: run: shell: bash - timeout-minutes: 1 + timeout-minutes: 6 if: | github.event.sender.type == 'User' || contains(github.event.pull_request.body, 'FORCE_TEST_ACTIONS') diff --git a/third_party/xla/.github/workflows/rollback_notification.yml b/third_party/xla/.github/workflows/rollback_notification.yml index 3d3658efdf5bf7..ac01b30bc82c84 100644 --- a/third_party/xla/.github/workflows/rollback_notification.yml +++ b/third_party/xla/.github/workflows/rollback_notification.yml @@ -30,7 +30,7 @@ jobs: env: GH_TOKEN: ${{ github.token }} HEAD_COMMIT_MESSAGE: ${{ github.event.head_commit.message }} - timeout-minutes: 1 + timeout-minutes: 6 steps: - name: "Checking out repository" uses: actions/checkout@e2f20e631ae6d7dd3b768f56a5d2af784dd54791 # v2.5.0 diff --git a/third_party/xla/build_tools/ci/build.py b/third_party/xla/build_tools/ci/build.py index 7ce80be519c240..2beba819342d52 100755 --- a/third_party/xla/build_tools/ci/build.py +++ b/third_party/xla/build_tools/ci/build.py @@ -55,6 +55,21 @@ ) +def retry( + args: List[str], delay_seconds: int = 15, retries: int = 3 +) -> List[str]: + # Possibly a slight abuse of `parallel` as nothing happens in parallel, just + # retries with delay if the command fails. + # pyformat:disable + return [ + "parallel", "--ungroup", + "--retries", str(retries), + "--delay", str(delay_seconds), + "--nonall", + "--", *args, + ] + + def sh(args, check=True, **kwargs): logging.info("Starting process: %s", " ".join(args)) return subprocess.run(args, check=check, **kwargs) @@ -102,9 +117,15 @@ class Build: options: Dict[str, Any] = dataclasses.field(default_factory=dict) extra_setup_commands: Tuple[List[str], ...] = () - def bazel_test_command(self) -> List[str]: + def bazel_command( + self, subcommand: str = "test", extra_options: Tuple[str, ...] = () + ) -> List[str]: """Returns a bazel test command for this build. + Args: + subcommand: The subcommand to give to bazel. `test` by default. + extra_options: Extra options. For now just used to pass in `--nobuild`. + Returns: List of command line arguments """ options = _dict_to_cli_options(self.options) @@ -117,8 +138,15 @@ def bazel_test_command(self) -> List[str]: test_env = [f"--test_env={k}={v}" for k, v in self.test_env.items()] tag_filters = [build_tag_filters, test_tag_filters] - all_options = tag_filters + configs + action_env + test_env + options - return ["bazel", "test", *all_options, "--", *self.target_patterns] + all_options = ( + tag_filters + + configs + + action_env + + test_env + + options + + list(extra_options) + ) + return ["bazel", subcommand, *all_options, "--", *self.target_patterns] def docker_run_command(self, *, command: str, **kwargs: Any) -> List[str]: assert self.image_url, "`docker run` has no meaning without an image." @@ -147,15 +175,12 @@ def commands(self) -> List[List[str]]: # pyformat:disable - if self.type_ == BuildType.CPU_ARM64: + if self.type_ == BuildType.CPU_ARM64 and using_docker: # We would need to install parallel, but `apt` hangs regularly on Kokoro # VMs due to yaqs/eng/q/4506961933928235008 cmds.append(["docker", "pull", self.image_url]) elif using_docker: - # This is a slightly odd use of parallel, we aren't doing anything besides - # retrying after 15 seconds up to 3 times if `docker pull` fails. - cmds.append(["parallel", "--ungroup", "--retries", "3", "--delay", "15", - "docker", "pull", ":::", self.image_url]) + cmds.append(retry(["docker", "pull", self.image_url])) container_name = "xla_ci" _, repo_name = self.repo.split("/") @@ -173,7 +198,28 @@ def commands(self) -> List[List[str]]: maybe_docker_exec = ( ["docker", "exec", container_name] if using_docker else [] ) - cmds.append(maybe_docker_exec + self.bazel_test_command()) + + # We really want `bazel fetch` here, but it uses `bazel query` and not + # `cquery`, which means that it fails due to config issues that aren't + # problems in practice. + + # TODO(ddunleavy): Remove the condition here. Need to get parallel on the + # MacOS VM, and slightly change TF config (likely by specifying tag_filters + # manually). + if self.type_ not in ( + BuildType.TENSORFLOW_CPU, + BuildType.TENSORFLOW_GPU, + BuildType.MACOS_CPU_X86, + ): + cmds.append( + maybe_docker_exec + + retry( + self.bazel_command( + subcommand="build", extra_options=("--nobuild",) + ) + ) + ) + cmds.append(maybe_docker_exec + self.bazel_command()) cmds.append( maybe_docker_exec + ["bazel", "analyze-profile", "profile.json.gz"] ) @@ -213,9 +259,9 @@ def nvidia_gpu_build_with_compute_capability( image_url=_DEFAULT_IMAGE, target_patterns=_XLA_DEFAULT_TARGET_PATTERNS, configs=configs, - test_tag_filters=("-no_oss", "requires-gpu-nvidia", "gpu") + test_tag_filters=("-no_oss", "requires-gpu-nvidia", "gpu", "-rocm-only") + extra_gpu_tags, - build_tag_filters=("-no_oss", "requires-gpu-nvidia", "gpu"), + build_tag_filters=("-no_oss", "requires-gpu-nvidia", "gpu", "-rocm-only"), options={ "run_under": "//tools/ci_build/gpu_build:parallel_gpu_execute", "repo_env": f"TF_CUDA_COMPUTE_CAPABILITIES={compute_capability/10}", @@ -315,10 +361,7 @@ def nvidia_gpu_build_with_compute_capability( repo="google/jax", image_url=_DEFAULT_IMAGE, configs=( - "avx_posix", - "mkl_open_source_only", - "rbe_cpu_linux_py3.12", - "tensorflow_testing_rbe_linux", + "rbe_linux_x86_64", ), target_patterns=("//tests:cpu_tests", "//tests:backend_independent_tests"), test_env=dict( @@ -326,7 +369,9 @@ def nvidia_gpu_build_with_compute_capability( JAX_SKIP_SLOW_TESTS=1, ), options=dict( - **_DEFAULT_BAZEL_OPTIONS, override_repository="xla=/github/xla" + **_DEFAULT_BAZEL_OPTIONS, + override_repository="xla=/github/xla", + repo_env="HERMETIC_PYTHON_VERSION=3.12", ), ) @@ -335,10 +380,7 @@ def nvidia_gpu_build_with_compute_capability( repo="google/jax", image_url=_DEFAULT_IMAGE, configs=( - "avx_posix", - "mkl_open_source_only", - "rbe_linux_cuda12.3_nvcc_py3.10", - "tensorflow_testing_rbe_linux", + "rbe_linux_x86_64_cuda", ), target_patterns=("//tests:gpu_tests", "//tests:backend_independent_tests"), build_tag_filters=("-multiaccelerator",), @@ -349,7 +391,9 @@ def nvidia_gpu_build_with_compute_capability( JAX_EXCLUDE_TEST_TARGETS="PmapTest.testSizeOverflow", ), options=dict( - **_DEFAULT_BAZEL_OPTIONS, override_repository="xla=/github/xla" + **_DEFAULT_BAZEL_OPTIONS, + override_repository="xla=/github/xla", + repo_env="HERMETIC_PYTHON_VERSION=3.10", ), ) diff --git a/third_party/xla/build_tools/ci/golden_commands.txt b/third_party/xla/build_tools/ci/golden_commands.txt index 17cfb64ff950a5..1a79600da98d1e 100644 --- a/third_party/xla/build_tools/ci/golden_commands.txt +++ b/third_party/xla/build_tools/ci/golden_commands.txt @@ -2,14 +2,16 @@ $KOKORO_ARTIFACTS_DIR/github/xla/.kokoro/generate_index_html.sh index.html docker pull us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build-arm64:jax-latest-multi-python docker run --detach --name=xla_ci --rm --interactive --tty --volume=./github:/github --workdir=/github/xla us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build-arm64:jax-latest-multi-python bash +docker exec xla_ci parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,-gpu,-requires-gpu-nvidia,-requires-gpu-amd,-not_run:arm --test_tag_filters=-no_oss,-gpu,-requires-gpu-nvidia,-requires-gpu-amd,-not_run:arm --config=warnings --config=rbe_cross_compile_linux_arm64_xla --config=nonccl --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --build_tests_only --nobuild -- //xla/... //build_tools/... @local_tsl//tsl/... docker exec xla_ci bazel test --build_tag_filters=-no_oss,-gpu,-requires-gpu-nvidia,-requires-gpu-amd,-not_run:arm --test_tag_filters=-no_oss,-gpu,-requires-gpu-nvidia,-requires-gpu-amd,-not_run:arm --config=warnings --config=rbe_cross_compile_linux_arm64_xla --config=nonccl --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --build_tests_only -- //xla/... //build_tools/... @local_tsl//tsl/... docker exec xla_ci bazel analyze-profile profile.json.gz docker stop xla_ci # END BuildType.CPU_ARM64 # BEGIN BuildType.CPU_X86 $KOKORO_ARTIFACTS_DIR/github/xla/.kokoro/generate_index_html.sh index.html -parallel --ungroup --retries 3 --delay 15 docker pull ::: gcr.io/tensorflow-sigs/build:latest-python3.11 +parallel --ungroup --retries 3 --delay 15 --nonall -- docker pull gcr.io/tensorflow-sigs/build:latest-python3.11 docker run --detach --name=xla_ci --rm --interactive --tty --volume=./github:/github --workdir=/github/xla gcr.io/tensorflow-sigs/build:latest-python3.11 bash +docker exec xla_ci parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,-gpu,-requires-gpu-nvidia,-requires-gpu-amd --test_tag_filters=-no_oss,-gpu,-requires-gpu-nvidia,-requires-gpu-amd --config=warnings --config=nonccl --config=rbe_linux_cpu --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --nobuild -- //xla/... //build_tools/... @local_tsl//tsl/... docker exec xla_ci bazel test --build_tag_filters=-no_oss,-gpu,-requires-gpu-nvidia,-requires-gpu-amd --test_tag_filters=-no_oss,-gpu,-requires-gpu-nvidia,-requires-gpu-amd --config=warnings --config=nonccl --config=rbe_linux_cpu --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async -- //xla/... //build_tools/... @local_tsl//tsl/... docker exec xla_ci bazel analyze-profile profile.json.gz docker stop xla_ci @@ -17,27 +19,30 @@ docker stop xla_ci # BEGIN BuildType.GPU $KOKORO_ARTIFACTS_DIR/github/xla/.kokoro/generate_index_html.sh index.html nvidia-smi -parallel --ungroup --retries 3 --delay 15 docker pull ::: gcr.io/tensorflow-sigs/build:latest-python3.11 +parallel --ungroup --retries 3 --delay 15 --nonall -- docker pull gcr.io/tensorflow-sigs/build:latest-python3.11 docker run --detach --name=xla_ci --rm --interactive --tty --volume=./github:/github --workdir=/github/xla gcr.io/tensorflow-sigs/build:latest-python3.11 bash -docker exec xla_ci bazel test --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-amd --config=warnings --config=rbe_linux_cuda_nvcc --run_under=//tools/ci_build/gpu_build:parallel_gpu_execute --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 --@cuda_driver//:enable_forward_compatibility=true --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async -- //xla/... //build_tools/... @local_tsl//tsl/... +docker exec xla_ci parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-amd --config=warnings --config=rbe_linux_cuda_nvcc --run_under=//tools/ci_build/gpu_build:parallel_gpu_execute --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 --@cuda_driver//:enable_forward_compatibility=true --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --nobuild -- //xla/... //build_tools/... @local_tsl//tsl/... +docker exec xla_ci bazel test --build_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only --test_tag_filters=-no_oss,requires-gpu-nvidia,gpu,-rocm-only,requires-gpu-sm75-only,requires-gpu-sm60,requires-gpu-sm70,-requires-gpu-sm80,-requires-gpu-sm80-only,-requires-gpu-sm90,-requires-gpu-sm90-only,-requires-gpu-amd --config=warnings --config=rbe_linux_cuda_nvcc --run_under=//tools/ci_build/gpu_build:parallel_gpu_execute --repo_env=TF_CUDA_COMPUTE_CAPABILITIES=7.5 --@cuda_driver//:enable_forward_compatibility=true --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async -- //xla/... //build_tools/... @local_tsl//tsl/... docker exec xla_ci bazel analyze-profile profile.json.gz docker stop xla_ci # END BuildType.GPU # BEGIN BuildType.JAX_CPU $KOKORO_ARTIFACTS_DIR/github/xla/.kokoro/generate_index_html.sh index.html git clone --depth=1 https://github.com/google/jax ./github/jax -parallel --ungroup --retries 3 --delay 15 docker pull ::: gcr.io/tensorflow-sigs/build:latest-python3.11 +parallel --ungroup --retries 3 --delay 15 --nonall -- docker pull gcr.io/tensorflow-sigs/build:latest-python3.11 docker run --detach --name=xla_ci --rm --interactive --tty --volume=./github:/github --workdir=/github/jax gcr.io/tensorflow-sigs/build:latest-python3.11 bash -docker exec xla_ci bazel test --build_tag_filters= --test_tag_filters= --config=avx_posix --config=mkl_open_source_only --config=rbe_cpu_linux_py3.12 --config=tensorflow_testing_rbe_linux --test_env=JAX_NUM_GENERATED_CASES=25 --test_env=JAX_SKIP_SLOW_TESTS=1 --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --override_repository=xla=/github/xla -- //tests:cpu_tests //tests:backend_independent_tests +docker exec xla_ci parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters= --test_tag_filters= --config=rbe_linux_x86_64 --test_env=JAX_NUM_GENERATED_CASES=25 --test_env=JAX_SKIP_SLOW_TESTS=1 --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --override_repository=xla=/github/xla --repo_env=HERMETIC_PYTHON_VERSION=3.12 --nobuild -- //tests:cpu_tests //tests:backend_independent_tests +docker exec xla_ci bazel test --build_tag_filters= --test_tag_filters= --config=rbe_linux_x86_64 --test_env=JAX_NUM_GENERATED_CASES=25 --test_env=JAX_SKIP_SLOW_TESTS=1 --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --override_repository=xla=/github/xla --repo_env=HERMETIC_PYTHON_VERSION=3.12 -- //tests:cpu_tests //tests:backend_independent_tests docker exec xla_ci bazel analyze-profile profile.json.gz docker stop xla_ci # END BuildType.JAX_CPU # BEGIN BuildType.JAX_GPU $KOKORO_ARTIFACTS_DIR/github/xla/.kokoro/generate_index_html.sh index.html git clone --depth=1 https://github.com/google/jax ./github/jax -parallel --ungroup --retries 3 --delay 15 docker pull ::: gcr.io/tensorflow-sigs/build:latest-python3.11 +parallel --ungroup --retries 3 --delay 15 --nonall -- docker pull gcr.io/tensorflow-sigs/build:latest-python3.11 docker run --detach --name=xla_ci --rm --interactive --tty --volume=./github:/github --workdir=/github/jax gcr.io/tensorflow-sigs/build:latest-python3.11 bash -docker exec xla_ci bazel test --build_tag_filters=-multiaccelerator --test_tag_filters=-multiaccelerator --config=avx_posix --config=mkl_open_source_only --config=rbe_linux_cuda12.3_nvcc_py3.10 --config=tensorflow_testing_rbe_linux --test_env=JAX_SKIP_SLOW_TESTS=1 --test_env=TF_CPP_MIN_LOG_LEVEL=0 --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --override_repository=xla=/github/xla -- //tests:gpu_tests //tests:backend_independent_tests +docker exec xla_ci parallel --ungroup --retries 3 --delay 15 --nonall -- bazel build --build_tag_filters=-multiaccelerator --test_tag_filters=-multiaccelerator --config=rbe_linux_x86_64_cuda --test_env=JAX_SKIP_SLOW_TESTS=1 --test_env=TF_CPP_MIN_LOG_LEVEL=0 --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --override_repository=xla=/github/xla --repo_env=HERMETIC_PYTHON_VERSION=3.10 --nobuild -- //tests:gpu_tests //tests:backend_independent_tests +docker exec xla_ci bazel test --build_tag_filters=-multiaccelerator --test_tag_filters=-multiaccelerator --config=rbe_linux_x86_64_cuda --test_env=JAX_SKIP_SLOW_TESTS=1 --test_env=TF_CPP_MIN_LOG_LEVEL=0 --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow --test_output=errors --verbose_failures --keep_going --nobuild_tests_only --profile=profile.json.gz --flaky_test_attempts=3 --jobs=150 --bes_upload_mode=fully_async --override_repository=xla=/github/xla --repo_env=HERMETIC_PYTHON_VERSION=3.10 -- //tests:gpu_tests //tests:backend_independent_tests docker exec xla_ci bazel analyze-profile profile.json.gz docker stop xla_ci # END BuildType.JAX_GPU @@ -53,7 +58,7 @@ bazel analyze-profile profile.json.gz # BEGIN BuildType.TENSORFLOW_CPU $KOKORO_ARTIFACTS_DIR/github/xla/.kokoro/generate_index_html.sh index.html git clone --depth=1 https://github.com/tensorflow/tensorflow ./github/tensorflow -parallel --ungroup --retries 3 --delay 15 docker pull ::: gcr.io/tensorflow-sigs/build:latest-python3.11 +parallel --ungroup --retries 3 --delay 15 --nonall -- docker pull gcr.io/tensorflow-sigs/build:latest-python3.11 docker run --detach --name=xla_ci --rm --interactive --tty --volume=./github:/github --workdir=/github/tensorflow gcr.io/tensorflow-sigs/build:latest-python3.11 bash docker exec xla_ci bazel test --build_tag_filters= --test_tag_filters= --config=release_cpu_linux --config=rbe_linux_cpu --config=linux_cpu_pycpp_test_filters --verbose_failures --test_output=errors --override_repository=xla=/github/xla --profile=profile.json.gz -- //tensorflow/compiler/... -//tensorflow/compiler/tf2tensorrt/... //tensorflow/python/... -//tensorflow/python/distribute/... -//tensorflow/python/compiler/tensorrt/... docker exec xla_ci bazel analyze-profile profile.json.gz @@ -62,7 +67,7 @@ docker stop xla_ci # BEGIN BuildType.TENSORFLOW_GPU $KOKORO_ARTIFACTS_DIR/github/xla/.kokoro/generate_index_html.sh index.html git clone --depth=1 https://github.com/tensorflow/tensorflow ./github/tensorflow -parallel --ungroup --retries 3 --delay 15 docker pull ::: gcr.io/tensorflow-sigs/build:latest-python3.11 +parallel --ungroup --retries 3 --delay 15 --nonall -- docker pull gcr.io/tensorflow-sigs/build:latest-python3.11 docker run --detach --name=xla_ci --rm --interactive --tty --volume=./github:/github --workdir=/github/tensorflow gcr.io/tensorflow-sigs/build:latest-python3.11 bash docker exec xla_ci bazel test --build_tag_filters=-no_oss,+gpu --test_tag_filters=-no_oss,+gpu --config=release_gpu_linux --config=rbe_linux_cuda --config=linux_cuda_pycpp_test_filters --verbose_failures --test_output=errors --override_repository=xla=/github/xla --profile=profile.json.gz -- //tensorflow/compiler/... -//tensorflow/compiler/tf2tensorrt/... //tensorflow/python/... -//tensorflow/python/distribute/... -//tensorflow/python/compiler/tensorrt/... docker exec xla_ci bazel analyze-profile profile.json.gz diff --git a/third_party/xla/build_tools/configure/configure.py b/third_party/xla/build_tools/configure/configure.py index 43e0f234d49cfd..e975a2b540f59e 100755 --- a/third_party/xla/build_tools/configure/configure.py +++ b/third_party/xla/build_tools/configure/configure.py @@ -35,6 +35,7 @@ import logging import os import pathlib +import platform import shutil import subprocess import sys @@ -187,9 +188,14 @@ class CudaCompiler(ArgparseableEnum): NVCC = enum.auto() +class RocmCompiler(ArgparseableEnum): + HIPCC = enum.auto() + + class OS(ArgparseableEnum): + """Modeled after the values returned by `platform.system()`.""" LINUX = enum.auto() - MACOS = enum.auto() + DARWIN = enum.auto() WINDOWS = enum.auto() @@ -263,6 +269,9 @@ class XLAConfigOptions: cuda_compiler: CudaCompiler using_nccl: bool + # ROCM specific + rocm_compiler: RocmCompiler + def to_bazelrc_lines( self, dpav: DiscoverablePathsAndVersions, @@ -284,6 +293,9 @@ def to_bazelrc_lines( rc = [] build_and_test_tag_filters = list(_DEFAULT_BUILD_AND_TEST_TAG_FILTERS) + if self.os == OS.DARWIN: + build_and_test_tag_filters.append("-no_mac") + # Platform independent options based on host compiler if self.host_compiler == HostCompiler.GCC: rc.append(f"build --action_env GCC_HOST_COMPILER_PATH={dpav.gcc_path}") @@ -299,6 +311,9 @@ def to_bazelrc_lines( build_and_test_tag_filters.append("-gpu") elif self.backend == Backend.CUDA: + build_and_test_tag_filters.append("-rocm-only") + build_and_test_tag_filters.append("-sycl-only") + compiler_pair = self.cuda_compiler, self.host_compiler if compiler_pair == (CudaCompiler.CLANG, HostCompiler.CLANG): @@ -307,7 +322,7 @@ def to_bazelrc_lines( f"build --action_env CLANG_CUDA_COMPILER_PATH={dpav.clang_path}" ) elif compiler_pair == (CudaCompiler.NVCC, HostCompiler.CLANG): - rc.append("build --config nvcc_clang") + rc.append("build --config cuda_nvcc") # This is demanded by cuda_configure.bzl rc.append( f"build --action_env CLANG_CUDA_COMPILER_PATH={dpav.clang_path}" @@ -325,8 +340,8 @@ def to_bazelrc_lines( f"build:cuda --repo_env HERMETIC_CUDA_VERSION={dpav.cuda_version}" ) rc.append( - "build:cuda --repo_env" - f" HERMETIC_CUDA_COMPUTE_CAPABILITIES={','.join(dpav.cuda_compute_capabilities)}" + "build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES=" + f"{','.join(dpav.cuda_compute_capabilities)}" ) if dpav.cudnn_version: rc.append( @@ -347,8 +362,23 @@ def to_bazelrc_lines( if not self.using_nccl: rc.append("build --config nonccl") elif self.backend == Backend.ROCM: - pass + build_and_test_tag_filters.append("-cuda-only") + build_and_test_tag_filters.append("-sycl-only") + + compiler_pair = self.rocm_compiler, self.host_compiler + + if compiler_pair == (RocmCompiler.HIPCC, HostCompiler.CLANG): + rc.append("build --config rocm") + # This is demanded by rocm_configure.bzl. + rc.append(f"build --action_env CLANG_COMPILER_PATH={dpav.clang_path}") + elif compiler_pair == (RocmCompiler.HIPCC, HostCompiler.GCC): + rc.append("build --config rocm") + else: + raise NotImplementedError("ROCm clang with host compiler not supported") elif self.backend == Backend.SYCL: + build_and_test_tag_filters.append("-cuda-only") + build_and_test_tag_filters.append("-rocm-only") + rc.append("build --config sycl") # Lines that are added for every backend @@ -390,7 +420,7 @@ def _parse_args(): required=True, ) parser.add_argument( - "--os", type=OS.from_str, choices=list(OS), default="linux" + "--os", type=OS.from_str, choices=list(OS), default=platform.system() ) parser.add_argument( "--host_compiler", @@ -404,6 +434,12 @@ def _parse_args(): choices=list(CudaCompiler), default="nvcc", ) + parser.add_argument( + "--rocm_compiler", + type=RocmCompiler.from_str, + choices=list(RocmCompiler), + default="hipcc", + ) parser.add_argument( "--cuda_compute_capabilities", type=comma_separated_list, @@ -479,6 +515,7 @@ def main(): python_bin_path=args.python_bin_path, compiler_options=args.compiler_options, using_nccl=args.nccl, + rocm_compiler=args.rocm_compiler, ) bazelrc_lines = config.to_bazelrc_lines( diff --git a/third_party/xla/build_tools/configure/configure_test.py b/third_party/xla/build_tools/configure/configure_test.py index 8457ff40aea3ee..8849b931a905e9 100644 --- a/third_party/xla/build_tools/configure/configure_test.py +++ b/third_party/xla/build_tools/configure/configure_test.py @@ -25,6 +25,7 @@ Backend = configure.Backend HostCompiler = configure.HostCompiler CudaCompiler = configure.CudaCompiler +RocmCompiler = configure.RocmCompiler OS = configure.OS _PYTHON_BIN_PATH = "/usr/bin/python3" @@ -98,6 +99,7 @@ def test_clang_bazelrc(self): compiler_options=list(_COMPILER_OPTIONS), cuda_compiler=CudaCompiler.NVCC, using_nccl=False, + rocm_compiler=RocmCompiler.HIPCC, ) bazelrc_lines = config.to_bazelrc_lines( @@ -119,6 +121,7 @@ def test_gcc_bazelrc(self): compiler_options=list(_COMPILER_OPTIONS), cuda_compiler=CudaCompiler.NVCC, using_nccl=False, + rocm_compiler=RocmCompiler.HIPCC, ) bazelrc_lines = config.to_bazelrc_lines( @@ -139,6 +142,7 @@ def test_cuda_clang_bazelrc(self): compiler_options=list(_COMPILER_OPTIONS), cuda_compiler=CudaCompiler.CLANG, using_nccl=False, + rocm_compiler=RocmCompiler.HIPCC, ) bazelrc_lines = config.to_bazelrc_lines( @@ -160,6 +164,7 @@ def test_default_cuda_clang_bazelrc(self): compiler_options=list(_COMPILER_OPTIONS), cuda_compiler=CudaCompiler.CLANG, using_nccl=False, + rocm_compiler=RocmCompiler.HIPCC, ) bazelrc_lines = config.to_bazelrc_lines( @@ -181,6 +186,7 @@ def test_nvcc_clang_bazelrc(self): compiler_options=list(_COMPILER_OPTIONS), cuda_compiler=CudaCompiler.NVCC, using_nccl=False, + rocm_compiler=RocmCompiler.HIPCC, ) bazelrc_lines = config.to_bazelrc_lines( @@ -202,6 +208,7 @@ def test_nvcc_gcc_bazelrc(self): compiler_options=list(_COMPILER_OPTIONS), cuda_compiler=CudaCompiler.NVCC, using_nccl=False, + rocm_compiler=RocmCompiler.HIPCC, ) bazelrc_lines = config.to_bazelrc_lines( diff --git a/third_party/xla/build_tools/configure/testdata/cuda_clang.bazelrc b/third_party/xla/build_tools/configure/testdata/cuda_clang.bazelrc index 502bc8541c1285..3f42ca9e563aa2 100644 --- a/third_party/xla/build_tools/configure/testdata/cuda_clang.bazelrc +++ b/third_party/xla/build_tools/configure/testdata/cuda_clang.bazelrc @@ -15,7 +15,7 @@ test --test_size_filters small,medium build --copt -Wno-sign-compare build --copt -Wno-error=unused-command-line-argument build --copt -Wno-gnu-offsetof-extensions -build --build_tag_filters -no_oss -build --test_tag_filters -no_oss -test --build_tag_filters -no_oss -test --test_tag_filters -no_oss +build --build_tag_filters -no_oss,-rocm-only,-sycl-only +build --test_tag_filters -no_oss,-rocm-only,-sycl-only +test --build_tag_filters -no_oss,-rocm-only,-sycl-only +test --test_tag_filters -no_oss,-rocm-only,-sycl-only diff --git a/third_party/xla/build_tools/configure/testdata/default_cuda_clang.bazelrc b/third_party/xla/build_tools/configure/testdata/default_cuda_clang.bazelrc index 4623f6f52073fa..04b79c87aed6ab 100644 --- a/third_party/xla/build_tools/configure/testdata/default_cuda_clang.bazelrc +++ b/third_party/xla/build_tools/configure/testdata/default_cuda_clang.bazelrc @@ -13,7 +13,7 @@ test --test_size_filters small,medium build --copt -Wno-sign-compare build --copt -Wno-error=unused-command-line-argument build --copt -Wno-gnu-offsetof-extensions -build --build_tag_filters -no_oss -build --test_tag_filters -no_oss -test --build_tag_filters -no_oss -test --test_tag_filters -no_oss +build --build_tag_filters -no_oss,-rocm-only,-sycl-only +build --test_tag_filters -no_oss,-rocm-only,-sycl-only +test --build_tag_filters -no_oss,-rocm-only,-sycl-only +test --test_tag_filters -no_oss,-rocm-only,-sycl-only diff --git a/third_party/xla/build_tools/configure/testdata/nvcc_clang.bazelrc b/third_party/xla/build_tools/configure/testdata/nvcc_clang.bazelrc index 8cd19224698311..59d8d15c220843 100644 --- a/third_party/xla/build_tools/configure/testdata/nvcc_clang.bazelrc +++ b/third_party/xla/build_tools/configure/testdata/nvcc_clang.bazelrc @@ -1,7 +1,7 @@ build --action_env CLANG_COMPILER_PATH=/usr/lib/llvm-18/bin/clang build --repo_env CC=/usr/lib/llvm-18/bin/clang build --repo_env BAZEL_COMPILER=/usr/lib/llvm-18/bin/clang -build --config nvcc_clang +build --config cuda_nvcc build --action_env CLANG_CUDA_COMPILER_PATH=/usr/lib/llvm-18/bin/clang build:cuda --repo_env HERMETIC_CUDA_VERSION="12.1.1" build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES=7.5 @@ -15,7 +15,7 @@ test --test_size_filters small,medium build --copt -Wno-sign-compare build --copt -Wno-error=unused-command-line-argument build --copt -Wno-gnu-offsetof-extensions -build --build_tag_filters -no_oss -build --test_tag_filters -no_oss -test --build_tag_filters -no_oss -test --test_tag_filters -no_oss +build --build_tag_filters -no_oss,-rocm-only,-sycl-only +build --test_tag_filters -no_oss,-rocm-only,-sycl-only +test --build_tag_filters -no_oss,-rocm-only,-sycl-only +test --test_tag_filters -no_oss,-rocm-only,-sycl-only diff --git a/third_party/xla/build_tools/configure/testdata/nvcc_gcc.bazelrc b/third_party/xla/build_tools/configure/testdata/nvcc_gcc.bazelrc index be90a87545368b..f4d4f72c566e7f 100644 --- a/third_party/xla/build_tools/configure/testdata/nvcc_gcc.bazelrc +++ b/third_party/xla/build_tools/configure/testdata/nvcc_gcc.bazelrc @@ -10,7 +10,7 @@ build --python_path /usr/bin/python3 test --test_env LD_LIBRARY_PATH test --test_size_filters small,medium build --copt -Wno-sign-compare -build --build_tag_filters -no_oss -build --test_tag_filters -no_oss -test --build_tag_filters -no_oss -test --test_tag_filters -no_oss +build --build_tag_filters -no_oss,-rocm-only,-sycl-only +build --test_tag_filters -no_oss,-rocm-only,-sycl-only +test --build_tag_filters -no_oss,-rocm-only,-sycl-only +test --test_tag_filters -no_oss,-rocm-only,-sycl-only diff --git a/third_party/xla/build_tools/dependencies/aspects.bzl b/third_party/xla/build_tools/dependencies/aspects.bzl index 76b531112fdee6..c5e51da52aae7b 100644 --- a/third_party/xla/build_tools/dependencies/aspects.bzl +++ b/third_party/xla/build_tools/dependencies/aspects.bzl @@ -73,10 +73,18 @@ validate_gpu_tag = aspect( attr_aspects = ["deps"], ) -def _no_rocm_tag_violation_aspect_impl(target, ctx): - return _dependency_violation_aspect_impl(target, ctx, "no_rocm") +def _cuda_only_tag_violation_aspect_impl(target, ctx): + return _dependency_violation_aspect_impl(target, ctx, "cuda-only") -validate_no_rocm_tag = aspect( - implementation = _no_rocm_tag_violation_aspect_impl, +validate_cuda_only_tag = aspect( + implementation = _cuda_only_tag_violation_aspect_impl, + attr_aspects = ["deps"], +) + +def _rocm_only_tag_violation_aspect_impl(target, ctx): + return _dependency_violation_aspect_impl(target, ctx, "rocm-only") + +validate_rocm_only_tag = aspect( + implementation = _rocm_only_tag_violation_aspect_impl, attr_aspects = ["deps"], ) diff --git a/third_party/xla/build_tools/lint/tags.py b/third_party/xla/build_tools/lint/tags.py index 2ec82cc0113e65..1a666a25ffd765 100644 --- a/third_party/xla/build_tools/lint/tags.py +++ b/third_party/xla/build_tools/lint/tags.py @@ -30,9 +30,9 @@ "local": "https://bazel.build/reference/be/common-definitions", "manual": "https://bazel.build/reference/be/common-definitions", "large": "Conventional tag for `test_suites` of large tests", + "__PYTHON_RULES_MIGRATION_DO_NOT_USE_WILL_BREAK__": "Internal bazel tag", # Various disable tags (currently recognized by OpenXLA CI) "no_oss": "Test is disabled on OpenXLA CI.", - "no_rocm": "Disabled on ROCm builds.", "no_mac": "Disabled on MacOS.", "no_windows": "Disabled on Windows.", "no_mac_arm64": "Disabled on ARM MacOS.", @@ -65,6 +65,8 @@ "requires-gpu-sm90-only": "Requires exactly sm90.", "gpu": "Catch-all tag for targets that should be built/tested on GPU CI", "cpu": "Catch-all tag for targets that should be built/tested on CPU CI.", + "cuda-only": "Targets that require the CUDA backend to be enabled.", + "rocm-only": "Targets that require the ROCm backend to be enabled.", # Below tags are generated by `xla_test`. "broken": "Test will be marked with other tags to disable in `xla_test`.", "xla_interpreter": "Uses interpreter backend.", diff --git a/third_party/xla/build_tools/rocm/run_xla.sh b/third_party/xla/build_tools/rocm/run_xla.sh index d7eee422ec01db..c7fe4b7a77e4f3 100755 --- a/third_party/xla/build_tools/rocm/run_xla.sh +++ b/third_party/xla/build_tools/rocm/run_xla.sh @@ -50,7 +50,7 @@ fi export PYTHON_BIN_PATH=`which python3` export TF_NEED_ROCM=1 export ROCM_PATH=$ROCM_INSTALL_DIR -TAGS_FILTER="gpu,requires-gpu-amd,-requires-gpu-nvidia,-no_oss,-oss_excluded,-oss_serial,-no_gpu,-no_rocm" +TAGS_FILTER="gpu,requires-gpu-amd,-requires-gpu-nvidia,-no_oss,-oss_excluded,-oss_serial,-no_gpu,-cuda-only" UNSUPPORTED_GPU_TAGS="$(echo -requires-gpu-sm{60,70,80,86,89,90}{,-only})" TAGS_FILTER="${TAGS_FILTER},${UNSUPPORTED_GPU_TAGS// /,}" @@ -67,6 +67,7 @@ bazel \ --local_test_jobs=${N_TEST_JOBS} \ --test_env=TF_TESTS_PER_GPU=$TF_TESTS_PER_GPU \ --test_env=TF_GPU_COUNT=$TF_GPU_COUNT \ + --action_env=TF_ROCM_AMDGPU_TARGETS=gfx90a \ --action_env=XLA_FLAGS=--xla_gpu_force_compilation_parallelism=16 \ --action_env=XLA_FLAGS=--xla_gpu_enable_llvm_module_compilation_parallelism=true \ --run_under=//tools/ci_build/gpu_build:parallel_gpu_execute \ diff --git a/third_party/xla/docs/_toc.yaml b/third_party/xla/docs/_toc.yaml index d8ef492e6c4cca..4e340eafc2dbc8 100644 --- a/third_party/xla/docs/_toc.yaml +++ b/third_party/xla/docs/_toc.yaml @@ -1,55 +1,68 @@ toc: - heading: XLA developer guide - title: Getting started + # These should be in alphabetical order unless otherwise noted. section: - - title: Overview - path: /xla - - title: XLA architecture - path: /xla/architecture - - title: Operation semantics - path: /xla/operation_semantics + # This is the default tab for the Getting started section. + - title: Overview + path: /xla + - title: Operation semantics + path: /xla/operation_semantics + - title: XLA architecture + path: /xla/architecture - title: Developer details + # These should be in alphabetical order unless otherwise noted. section: - - title: Broadcasting - path: /xla/broadcasting - - title: Shapes and layout - path: /xla/shapes - - title: Aliasing - path: /xla/aliasing - - title: Indexing Analysis - path: /xla/indexing - - title: Tiled layout - path: /xla/tiled_layout - - title: Writing custom calls - path: /xla/custom_call - - title: Persisted autotuning - path: /xla/persisted_autotuning - - title: Determinism - path: /xla/determinism - - title: XLA Tooling - path: /xla/tools - - title: Using LSP autocompletion - path: /xla/lsp + - title: Aliasing + path: /xla/aliasing + - title: Async HLO instructions + path: /xla/async_ops + - title: Broadcasting + path: /xla/broadcasting + - title: Copybara quirks + path: /xla/copybara + - title: Determinism + path: /xla/determinism + - title: Hermetic CUDA overview + path: /xla/hermetic_cuda + - title: Indexing Analysis + path: /xla/indexing + - title: Multi-host HLO runner + path: /xla/tools_multi_host_hlo_runner + - title: Persisted autotuning + path: /xla/persisted_autotuning + - title: Shapes and layout + path: /xla/shapes + - title: Tiled layout + path: /xla/tiled_layout + - title: Using LSP autocompletion + path: /xla/lsp + - title: Writing custom calls + path: /xla/custom_call + - title: XLA Tooling + path: /xla/tools - title: Contributing + # These should be in alphabetical order unless otherwise noted. section: - - title: Contributing - path: /xla/contributing - - title: Developer guide - path: /xla/developer_guide - - title: Build from source - path: /xla/build_from_source - - title: Develop a new backend for XLA - path: /xla/developing_new_backend - - title: Develop a new PJRT plugin - path: /xla/pjrt_integration + # This is the default tab for the Contributing section. + - title: Contributing + path: /xla/contributing + - title: Build from source + path: /xla/build_from_source + - title: Develop a new backend for XLA + path: /xla/developing_new_backend + - title: Develop a new PJRT plugin + path: /xla/pjrt_integration + - title: Developer guide + path: /xla/developer_guide - title: Using XLA in TensorFlow + # These should be in alphabetical order unless otherwise noted. section: - - title: Using XLA in TensorFlow - path: /xla/tf2xla - - title: Use tfcompile - path: /xla/tf2xla/tfcompile - - title: Autoclustering tutorial - path: /xla/tf2xla/tutorials/autoclustering_xla - - title: Use XLA with tf.function - path: /xla/tf2xla/tutorials/jit_compile - + - title: Using XLA in TensorFlow + path: /xla/tf2xla + - title: Use tfcompile + path: /xla/tf2xla/tfcompile + - title: Autoclustering tutorial + path: /xla/tf2xla/tutorials/autoclustering_xla + - title: Use XLA with tf.function + path: /xla/tf2xla/tutorials/jit_compile diff --git a/third_party/xla/docs/broadcasting.md b/third_party/xla/docs/broadcasting.md index f47a6f7a642a66..0f781f06e869fd 100644 --- a/third_party/xla/docs/broadcasting.md +++ b/third_party/xla/docs/broadcasting.md @@ -99,7 +99,7 @@ dimensions 1 and 2 of the cuboid. This type of broadcast is used in the binary ops in `XlaBuilder`, if the `broadcast_dimensions` argument is given. For example, see -[XlaBuilder::Add](https://github.com/openxla/xla/blob/main/xla/client/xla_builder.cc). +[XlaBuilder::Add](https://github.com/openxla/xla/blob/main/xla/hlo/builder/xla_builder.cc). In the XLA source code, this type of broadcasting is sometimes called "InDim" broadcasting. diff --git a/third_party/xla/docs/custom_call.md b/third_party/xla/docs/custom_call.md index 2471df68331057..1bd39c0e070405 100644 --- a/third_party/xla/docs/custom_call.md +++ b/third_party/xla/docs/custom_call.md @@ -267,8 +267,8 @@ struct Range { int64_t hi; }; -XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember("i64"), - StructMember("i64")); +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember("lo"), + StructMember("hi")); auto handler = Ffi::Bind().Attr("range").To([](Range range) -> Error{ return Error::Success(); diff --git a/third_party/xla/docs/hermetic_cuda.md b/third_party/xla/docs/hermetic_cuda.md index d51f7ddfab516b..b11dd47996598c 100644 --- a/third_party/xla/docs/hermetic_cuda.md +++ b/third_party/xla/docs/hermetic_cuda.md @@ -137,6 +137,16 @@ is specified in [third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl](https: test:cuda --@cuda_driver//:enable_forward_compatibility=true ``` + The default flag value is `false`. + + When CUDA forward compatibility mode is disabled, Bazel targets will use User + Mode and Kernel Mode Drivers pre-installed on the system. + + When CUDA forward compatibility mode is enabled, Bazel targets will use User + Mode Driver from CUDA driver redistribution downloaded into Bazel cache and + Kernel Mode Driver pre-installed on the system. It allows enabling new CUDA + Toolkit features while using older Kernel Mode Driver. + Forward compatibility mode should be enforced only when it is appropriate - see [NVIDIA documentation](https://docs.nvidia.com/deploy/cuda-compatibility/index.html#forward-compatibility-support-across-major-toolkit-versions) for the details. diff --git a/third_party/xla/docs/indexing.md b/third_party/xla/docs/indexing.md index c44f2845358ee6..f18782f169aa6a 100644 --- a/third_party/xla/docs/indexing.md +++ b/third_party/xla/docs/indexing.md @@ -316,7 +316,7 @@ d1 in [0, 29] ``` ### [Gather](https://openxla.org/xla/operation_semantics#gather) -Only the simplified gather is supported. See [gather_simplifier].(https://github.com/openxla/xla/blob/main/xla/service/gather_simplifier.h). +Only the simplified gather is supported. See [gather_simplifier].(https://github.com/openxla/xla/blob/main/xla/hlo/transforms/simplifiers/gather_simplifier.h). ```c++ operand = f32[33,76,70] parameter(0) diff --git a/third_party/xla/docs/lsp.md b/third_party/xla/docs/lsp.md index 12dd0f17247823..a29067866e662e 100644 --- a/third_party/xla/docs/lsp.md +++ b/third_party/xla/docs/lsp.md @@ -15,6 +15,6 @@ each file in a project. Use the [build_tools/lint/generate_compile_commands.py](https://github.com/openxla/xla/blob/main/build_tools/lint/generate_compile_commands.py) script. The following invocation from XLA repo root generates a -`compile_commands.json` file in place: `bash bazel aquery "mnemonic(CppCompile, +`compile_commands.json` file in place: `bazel aquery "mnemonic(CppCompile, //xla/...)" --output=jsonproto | \ python3 build_tools/lint/generate_compile_commands.py` diff --git a/third_party/xla/docs/operation_semantics.md b/third_party/xla/docs/operation_semantics.md index 55849974726628..92943521cd0343 100644 --- a/third_party/xla/docs/operation_semantics.md +++ b/third_party/xla/docs/operation_semantics.md @@ -1500,8 +1500,9 @@ if, e.g., `offset_dims.size` is `4`, `operand.rank` is `6` and `1`→`3`, `2`→`4`, `3`→`5`}. If `indices_are_sorted` is set to true then XLA can assume that `start_indices` -are sorted (in ascending `start_index_map` order) by the user. If they are not -then the semantics is implementation defined. +are sorted (in ascending order, _after_ scattering its values according to +`start_index_map`) by the user. If they are not then the semantics are +implementation defined. ### Informal Description and Examples @@ -2493,9 +2494,10 @@ always be the current value from the `output` array and the second parameter will always be the value from the `updates` array. This is important specifically for cases when the `update_computation` is _not commutative_. -If `indices_are_sorted` is set to true then XLA can assume that `start_indices` -are sorted (in ascending `start_index_map` order) by the user. If they are not -then the semantics is implementation defined. +If `indices_are_sorted` is set to true then XLA can assume that `scatter_indices` +are sorted (in ascending order, _after_ scattering its values according to +`scatter_dims_to_operand_dims`) by the user. If they are not then the semantics +are implementation defined. If `unique_indices` is set to true then XLA can assume that all elements scattered to are unique. So XLA could use non-atomic operations. If diff --git a/third_party/xla/docs/persisted_autotuning.md b/third_party/xla/docs/persisted_autotuning.md index 5d1f01ab501442..15083cf6f482fe 100644 --- a/third_party/xla/docs/persisted_autotuning.md +++ b/third_party/xla/docs/persisted_autotuning.md @@ -10,6 +10,38 @@ normally. Autotuning caches are still useful if we make a few changes: the fusions that are present in the cache will use the cache, and the other ones will be autotuned normally. +## Recommended: Cache directory + +``` +--xla_gpu_per_fusion_autotune_cache_dir=your/directory +``` + +Use and maintain a per-fusion autotune cache in the given directory. There will +be one file per distinct fusion. + +The main advantage of this approach is that you can use the same cache directory +for multiple XLA runs (of different models) and your cache will grow with each +new fusion encountered - speeding up subsequent runs. There is also basic +support for running multiple XLA instances with the same cache directory +concurrently. + +XLA will read existing results when they are needed and write new results after +they are determined. + +- The directory must exist before running XLA and it must be writable. +- Cache invalidation has to be handled by the user: + - Please use an empty directory if you want to start with an empty cache. +- XLA version checks must be done by the user: + - If you want to use separate caches for different versions of XLA, please + use different directories. + +The cache is turned off by default (when you don't provide the parameter). + +Limitation: This is not guaranteed to work well in combination with the other +caching method described below. + +## Alternative: Loading or dumping all results from a given HLO to one file + The autotuning results can be dumped/loaded using these parameters: ``` diff --git a/third_party/xla/opensource_only.files b/third_party/xla/opensource_only.files index 8472111016b09f..97db0a56bd186d 100644 --- a/third_party/xla/opensource_only.files +++ b/third_party/xla/opensource_only.files @@ -1,3 +1,4 @@ +compiler/xla/internal/package_groups.bzl: compiler/xla/mlir_hlo/WORKSPACE: compiler/xla/package_groups.bzl: compiler/xla/stream_executor/build_defs.bzl: diff --git a/third_party/xla/third_party/nanobind/nanobind.BUILD b/third_party/xla/third_party/nanobind/nanobind.BUILD index 72b47585b5e5d0..814fe3595df65d 100644 --- a/third_party/xla/third_party/nanobind/nanobind.BUILD +++ b/third_party/xla/third_party/nanobind/nanobind.BUILD @@ -1,7 +1,21 @@ +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") + licenses(["notice"]) package(default_visibility = ["//visibility:public"]) +bool_flag( + name = "enabled_free_threading", + build_setting_default = False, +) + +config_setting( + name = "use_enabled_free_threading", + flag_values = { + ":enabled_free_threading": "True", + }, +) + cc_library( name = "nanobind", srcs = glob( @@ -11,10 +25,17 @@ cc_library( exclude = ["src/nb_combined.cpp"], ), copts = ["-fexceptions"], - defines = [ - "NB_BUILD=1", - "NB_SHARED=1", - ], + defines = select({ + ":use_enabled_free_threading": [ + "NB_FREE_THREADED=1", + "NB_BUILD=1", + "NB_SHARED=1", + ], + "//conditions:default": [ + "NB_BUILD=1", + "NB_SHARED=1", + ], + }), includes = ["include"], textual_hdrs = glob( [ diff --git a/third_party/xla/third_party/nanobind/workspace.bzl b/third_party/xla/third_party/nanobind/workspace.bzl index 1c692d396e9b98..aa39484e078f3b 100644 --- a/third_party/xla/third_party/nanobind/workspace.bzl +++ b/third_party/xla/third_party/nanobind/workspace.bzl @@ -5,8 +5,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): tf_http_archive( name = "nanobind", - strip_prefix = "nanobind-2.1.0", - sha256 = "c37c53c60ada5fe1c956e24bd4b83af669a2309bf952bd251f36a7d2fa3bacf0", - urls = tf_mirror_urls("https://github.com/wjakob/nanobind/archive/refs/tags/v2.1.0.tar.gz"), + strip_prefix = "nanobind-2.2.0", + sha256 = "bfbfc7e5759f1669e4ddb48752b1ddc5647d1430e94614d6f8626df1d508e65a", + urls = tf_mirror_urls("https://github.com/wjakob/nanobind/archive/refs/tags/v2.2.0.tar.gz"), build_file = "//third_party/nanobind:nanobind.BUILD", ) diff --git a/third_party/xla/third_party/py/BUILD b/third_party/xla/third_party/py/BUILD index 84eba77ce1a7af..0381d65bb27514 100644 --- a/third_party/xla/third_party/py/BUILD +++ b/third_party/xla/third_party/py/BUILD @@ -1,3 +1,4 @@ +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") load("@python//:defs.bzl", "compile_pip_requirements") load("@python_version_repo//:py_version.bzl", "REQUIREMENTS") @@ -38,3 +39,16 @@ compile_pip_requirements( requirements_in = "requirements.in", requirements_txt = REQUIREMENTS, ) + +# Flag indicating if the target requires pre-built wheel. +bool_flag( + name = "wheel_dependency", + build_setting_default = False, +) + +config_setting( + name = "enable_wheel_dependency", + flag_values = { + ":wheel_dependency": "True", + }, +) diff --git a/third_party/xla/third_party/py/non_hermetic/ml_dtypes/LICENSE b/third_party/xla/third_party/py/non_hermetic/ml_dtypes/LICENSE deleted file mode 100644 index d645695673349e..00000000000000 --- a/third_party/xla/third_party/py/non_hermetic/ml_dtypes/LICENSE +++ /dev/null @@ -1,202 +0,0 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/third_party/xla/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.BUILD b/third_party/xla/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.BUILD deleted file mode 100644 index f386124a36dfe8..00000000000000 --- a/third_party/xla/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.BUILD +++ /dev/null @@ -1,64 +0,0 @@ -load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], -) - -exports_files(["LICENSE"]) - -cc_library( - name = "float8", - hdrs = ["include/float8.h"], - include_prefix = "ml_dtypes", - # Internal headers are all relative to . but other packages - # include these headers with the prefix. - includes = [ - ".", - "ml_dtypes", - ], - deps = ["@eigen_archive//:eigen3"], -) - -cc_library( - name = "intn", - hdrs = ["include/intn.h"], - include_prefix = "ml_dtypes", - # Internal headers are all relative to . but other packages - # include these headers with the prefix. - includes = [ - ".", - "ml_dtypes", - ], -) - -pybind_extension( - name = "_ml_dtypes_ext", - srcs = [ - "_src/common.h", - "_src/custom_float.h", - "_src/dtypes.cc", - "_src/int4_numpy.h", - "_src/numpy.cc", - "_src/numpy.h", - "_src/ufuncs.h", - ], - includes = ["ml_dtypes"], - visibility = [":__subpackages__"], - deps = [ - ":float8", - ":intn", - "@eigen_archive//:eigen3", - "@local_tsl//third_party/py/numpy:headers", - ], -) - -py_library( - name = "ml_dtypes", - srcs = [ - "__init__.py", - "_finfo.py", - "_iinfo.py", - ], - deps = [":_ml_dtypes_ext"], -) diff --git a/third_party/xla/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.tests.BUILD b/third_party/xla/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.tests.BUILD deleted file mode 100644 index c811379a19dabd..00000000000000 --- a/third_party/xla/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.tests.BUILD +++ /dev/null @@ -1,71 +0,0 @@ -package( - default_visibility = ["//visibility:public"], -) - -py_library( - name = "testing_base", - deps = [ - "//:ml_dtypes", - "@absl_py//absl/testing:absltest", - "@absl_py//absl/testing:parameterized", - "@local_tsl//third_party/py/numpy", - ], -) - -py_test( - name = "custom_float_test", - srcs = ["custom_float_test.py"], - main = "custom_float_test.py", - deps = [":testing_base"], -) - -py_test( - name = "int4_test", - srcs = ["int4_test.py"], - main = "int4_test.py", - deps = [":testing_base"], -) - -py_test( - name = "iinfo_test", - srcs = ["iinfo_test.py"], - main = "iinfo_test.py", - deps = [":testing_base"], -) - -py_test( - name = "finfo_test", - srcs = ["finfo_test.py"], - main = "finfo_test.py", - deps = [":testing_base"], -) - -py_test( - name = "metadata_test", - srcs = ["metadata_test.py"], - main = "metadata_test.py", - deps = [":testing_base"], -) - -cc_test( - name = "float8_test", - srcs = ["float8_test.cc"], - linkstatic = 1, - deps = [ - "//:float8", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@eigen_archive//:eigen3", - ], -) - -cc_test( - name = "intn_test_cc", - srcs = ["intn_test.cc"], - linkstatic = 1, - deps = [ - "//:intn", - "@com_google_googletest//:gtest_main", - "@eigen_archive//:eigen3", - ], -) diff --git a/third_party/xla/third_party/py/non_hermetic/ml_dtypes/workspace.bzl b/third_party/xla/third_party/py/non_hermetic/ml_dtypes/workspace.bzl deleted file mode 100644 index 51505bf3a1460d..00000000000000 --- a/third_party/xla/third_party/py/non_hermetic/ml_dtypes/workspace.bzl +++ /dev/null @@ -1,22 +0,0 @@ -"""Provides the repo macro to import ml_dtypes. - -ml_dtypes provides machine-learning-specific data-types like bfloat16, -float8 varieties, and int4. -""" - -load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") - -def repo(): - ML_DTYPES_COMMIT = "24084d9ed2c3d45bf83b7a9bff833aa185bf9172" - ML_DTYPES_SHA256 = "c916a3e6b3d9bdcb476f506fdbbecb6d5e9f21f82f221dfcb42b320b4e85e55a" - tf_http_archive( - name = "ml_dtypes", - build_file = "//third_party/py/ml_dtypes:ml_dtypes.BUILD", - link_files = { - "//third_party/py/ml_dtypes:ml_dtypes.tests.BUILD": "tests/BUILD.bazel", - "//third_party/py/ml_dtypes:LICENSE": "LICENSE", - }, - sha256 = ML_DTYPES_SHA256, - strip_prefix = "ml_dtypes-{commit}/ml_dtypes".format(commit = ML_DTYPES_COMMIT), - urls = tf_mirror_urls("https://github.com/jax-ml/ml_dtypes/archive/{commit}/ml_dtypes-{commit}.tar.gz".format(commit = ML_DTYPES_COMMIT)), - ) diff --git a/third_party/xla/third_party/py/non_hermetic/numpy/BUILD b/third_party/xla/third_party/py/non_hermetic/numpy/BUILD deleted file mode 100644 index c80cc5287bc469..00000000000000 --- a/third_party/xla/third_party/py/non_hermetic/numpy/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -licenses(["restricted"]) - -package(default_visibility = ["//visibility:public"]) - -py_library( - name = "numpy", - srcs = ["tf_numpy_dummy.py"], - srcs_version = "PY3", -) - -alias( - name = "headers", - actual = "@local_config_python//:numpy_headers", -) - -genrule( - name = "dummy", - outs = ["tf_numpy_dummy.py"], - cmd = "touch $@", - visibility = ["//visibility:private"], -) diff --git a/third_party/xla/third_party/py/non_hermetic/numpy/README.md b/third_party/xla/third_party/py/non_hermetic/numpy/README.md deleted file mode 100644 index 4e58b9df87b5ec..00000000000000 --- a/third_party/xla/third_party/py/non_hermetic/numpy/README.md +++ /dev/null @@ -1,4 +0,0 @@ -# numpy_ops - -The folder tf_numpy_api/ contains lists of NumPy API symbols that the -`numpy_ops` internal module in TensorFlow implements. diff --git a/third_party/xla/third_party/py/non_hermetic/numpy/tf_numpy_api/BUILD b/third_party/xla/third_party/py/non_hermetic/numpy/tf_numpy_api/BUILD deleted file mode 100644 index 070f8ab8a65352..00000000000000 --- a/third_party/xla/third_party/py/non_hermetic/numpy/tf_numpy_api/BUILD +++ /dev/null @@ -1,12 +0,0 @@ -# TensorFlow API backwards compatibility test goldens for tf.experimental.numpy. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:public"], - licenses = ["notice"], -) - -filegroup( - name = "api_golden", - srcs = glob(["*.pbtxt"]), -) diff --git a/third_party/xla/third_party/py/non_hermetic/numpy/tf_numpy_api/tensorflow.experimental.numpy.ndarray.pbtxt b/third_party/xla/third_party/py/non_hermetic/numpy/tf_numpy_api/tensorflow.experimental.numpy.ndarray.pbtxt deleted file mode 100644 index 9198264c02961f..00000000000000 --- a/third_party/xla/third_party/py/non_hermetic/numpy/tf_numpy_api/tensorflow.experimental.numpy.ndarray.pbtxt +++ /dev/null @@ -1,51 +0,0 @@ -path: "tensorflow.experimental.numpy.ndarray" -tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" - member { - name: "OVERLOADABLE_OPERATORS" - mtype: "" - } - member { - name: "dtype" - mtype: "" - } - member { - name: "name" - mtype: "" - } - member { - name: "ndim" - mtype: "" - } - member { - name: "shape" - mtype: "" - } - member_method { - name: "__init__" - } - member_method { - name: "eval" - argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } - member_method { - name: "experimental_ref" - argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "get_shape" - argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "ref" - argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "set_shape" - argspec: "args=[\'self\', \'shape\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/third_party/xla/third_party/py/non_hermetic/numpy/tf_numpy_api/tensorflow.experimental.numpy.pbtxt b/third_party/xla/third_party/py/non_hermetic/numpy/tf_numpy_api/tensorflow.experimental.numpy.pbtxt deleted file mode 100644 index 2f5490ad0c922f..00000000000000 --- a/third_party/xla/third_party/py/non_hermetic/numpy/tf_numpy_api/tensorflow.experimental.numpy.pbtxt +++ /dev/null @@ -1,919 +0,0 @@ -path: "tensorflow.experimental.numpy" -tf_module { - member { - name: "bool_" - mtype: "" - } - member { - name: "complex128" - mtype: "" - } - member { - name: "complex64" - mtype: "" - } - member { - name: "complex_" - mtype: "" - } - member { - name: "e" - mtype: "" - } - member { - name: "float16" - mtype: "" - } - member { - name: "float32" - mtype: "" - } - member { - name: "float64" - mtype: "" - } - member { - name: "float_" - mtype: "" - } - member { - name: "iinfo" - mtype: "" - } - member { - name: "inexact" - mtype: "" - } - member { - name: "inf" - mtype: "" - } - member { - name: "int16" - mtype: "" - } - member { - name: "int32" - mtype: "" - } - member { - name: "int64" - mtype: "" - } - member { - name: "int8" - mtype: "" - } - member { - name: "int_" - mtype: "" - } - member { - name: "ndarray" - mtype: "" - } - member { - name: "newaxis" - mtype: "" - } - member { - name: "object_" - mtype: "" - } - member { - name: "pi" - mtype: "" - } - member { - name: "random" - mtype: "" - } - member { - name: "string_" - mtype: "" - } - member { - name: "uint16" - mtype: "" - } - member { - name: "uint32" - mtype: "" - } - member { - name: "uint64" - mtype: "" - } - member { - name: "uint8" - mtype: "" - } - member { - name: "unicode_" - mtype: "" - } - member_method { - name: "abs" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "absolute" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "add" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "all" - argspec: "args=[\'a\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } - member_method { - name: "allclose" - argspec: "args=[\'a\', \'b\', \'rtol\', \'atol\', \'equal_nan\'], varargs=None, keywords=None, defaults=[\'1e-05\', \'1e-08\', \'False\'], " - } - member_method { - name: "amax" - argspec: "args=[\'a\', \'axis\', \'out\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " - } - member_method { - name: "amin" - argspec: "args=[\'a\', \'axis\', \'out\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " - } - member_method { - name: "angle" - argspec: "args=[\'z\', \'deg\'], varargs=None, keywords=None, defaults=[\'False\'], " - } - member_method { - name: "any" - argspec: "args=[\'a\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } - member_method { - name: "append" - argspec: "args=[\'arr\', \'values\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "arange" - argspec: "args=[\'start\', \'stop\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\'], " - } - member_method { - name: "arccos" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "arccosh" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "arcsin" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "arcsinh" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "arctan" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "arctan2" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "arctanh" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "argmax" - argspec: "args=[\'a\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "argmin" - argspec: "args=[\'a\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "argsort" - argspec: "args=[\'a\', \'axis\', \'kind\', \'order\'], varargs=None, keywords=None, defaults=[\'-1\', \'quicksort\', \'None\'], " - } - member_method { - name: "around" - argspec: "args=[\'a\', \'decimals\'], varargs=None, keywords=None, defaults=[\'0\'], " - } - member_method { - name: "array" - argspec: "args=[\'val\', \'dtype\', \'copy\', \'ndmin\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'0\'], " - } - member_method { - name: "array_equal" - argspec: "args=[\'a1\', \'a2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "asanyarray" - argspec: "args=[\'a\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "asarray" - argspec: "args=[\'a\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "ascontiguousarray" - argspec: "args=[\'a\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "atleast_1d" - argspec: "args=[], varargs=arys, keywords=None, defaults=None" - } - member_method { - name: "atleast_2d" - argspec: "args=[], varargs=arys, keywords=None, defaults=None" - } - member_method { - name: "atleast_3d" - argspec: "args=[], varargs=arys, keywords=None, defaults=None" - } - member_method { - name: "average" - argspec: "args=[\'a\', \'axis\', \'weights\', \'returned\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " - } - member_method { - name: "bitwise_and" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "bitwise_not" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "bitwise_or" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "bitwise_xor" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "broadcast_arrays" - argspec: "args=[], varargs=args, keywords=kwargs, defaults=None" - } - member_method { - name: "broadcast_to" - argspec: "args=[\'array\', \'shape\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "cbrt" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "ceil" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "clip" - argspec: "args=[\'a\', \'a_min\', \'a_max\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "compress" - argspec: "args=[\'condition\', \'a\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "concatenate" - argspec: "args=[\'arys\', \'axis\'], varargs=None, keywords=None, defaults=[\'0\'], " - } - member_method { - name: "conj" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "conjugate" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "copy" - argspec: "args=[\'a\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "cos" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "cosh" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "count_nonzero" - argspec: "args=[\'a\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "cross" - argspec: "args=[\'a\', \'b\', \'axisa\', \'axisb\', \'axisc\', \'axis\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'-1\', \'None\'], " - } - member_method { - name: "cumprod" - argspec: "args=[\'a\', \'axis\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } - member_method { - name: "cumsum" - argspec: "args=[\'a\', \'axis\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } - member_method { - name: "deg2rad" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "diag" - argspec: "args=[\'v\', \'k\'], varargs=None, keywords=None, defaults=[\'0\'], " - } - member_method { - name: "diag_indices" - argspec: "args=[\'n\', \'ndim\'], varargs=None, keywords=None, defaults=[\'2\'], " - } - member_method { - name: "diagflat" - argspec: "args=[\'v\', \'k\'], varargs=None, keywords=None, defaults=[\'0\'], " - } - member_method { - name: "diagonal" - argspec: "args=[\'a\', \'offset\', \'axis1\', \'axis2\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'1\'], " - } - member_method { - name: "diff" - argspec: "args=[\'a\', \'n\', \'axis\'], varargs=None, keywords=None, defaults=[\'1\', \'-1\'], " - } - member_method { - name: "divide" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "divmod" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "dot" - argspec: "args=[\'a\', \'b\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "dsplit" - argspec: "args=[\'ary\', \'indices_or_sections\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "dstack" - argspec: "args=[\'tup\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "einsum" - argspec: "args=[\'subscripts\'], varargs=operands, keywords=kwargs, defaults=None" - } - member_method { - name: "empty" - argspec: "args=[\'shape\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " - } - member_method { - name: "empty_like" - argspec: "args=[\'a\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "equal" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "exp" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "exp2" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "expand_dims" - argspec: "args=[\'a\', \'axis\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "experimental_enable_numpy_behavior" - argspec: "args=[\'prefer_float32\'], varargs=None, keywords=None, defaults=[\'False\'], " - } - member_method { - name: "expm1" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "eye" - argspec: "args=[\'N\', \'M\', \'k\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \"\"], " - } - member_method { - name: "fabs" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "finfo" - argspec: "args=[\'dtype\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "fix" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "flatten" - argspec: "args=[\'a\', \'order\'], varargs=None, keywords=None, defaults=[\'C\'], " - } - member_method { - name: "flip" - argspec: "args=[\'m\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "fliplr" - argspec: "args=[\'m\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "flipud" - argspec: "args=[\'m\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "float_power" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "floor" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "floor_divide" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "full" - argspec: "args=[\'shape\', \'fill_value\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "full_like" - argspec: "args=[\'a\', \'fill_value\', \'dtype\', \'order\', \'subok\', \'shape\'], varargs=None, keywords=None, defaults=[\'None\', \'K\', \'True\', \'None\'], " - } - member_method { - name: "gcd" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "geomspace" - argspec: "args=[\'start\', \'stop\', \'num\', \'endpoint\', \'dtype\', \'axis\'], varargs=None, keywords=None, defaults=[\'50\', \'True\', \'None\', \'0\'], " - } - member_method { - name: "greater" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "greater_equal" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "heaviside" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "hsplit" - argspec: "args=[\'ary\', \'indices_or_sections\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "hstack" - argspec: "args=[\'tup\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "hypot" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "identity" - argspec: "args=[\'n\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " - } - member_method { - name: "imag" - argspec: "args=[\'val\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "inner" - argspec: "args=[\'a\', \'b\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "isclose" - argspec: "args=[\'a\', \'b\', \'rtol\', \'atol\', \'equal_nan\'], varargs=None, keywords=None, defaults=[\'1e-05\', \'1e-08\', \'False\'], " - } - member_method { - name: "iscomplex" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "iscomplexobj" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "isfinite" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "isinf" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "isnan" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "isneginf" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "isposinf" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "isreal" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "isrealobj" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "isscalar" - argspec: "args=[\'num\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "issubdtype" - argspec: "args=[\'arg1\', \'arg2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "ix_" - argspec: "args=[], varargs=args, keywords=None, defaults=None" - } - member_method { - name: "kron" - argspec: "args=[\'a\', \'b\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "lcm" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "less" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "less_equal" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "linspace" - argspec: "args=[\'start\', \'stop\', \'num\', \'endpoint\', \'retstep\', \'dtype\', \'axis\'], varargs=None, keywords=None, defaults=[\'50\', \'True\', \'False\', \"\", \'0\'], " - } - member_method { - name: "log" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "log10" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "log1p" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "log2" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "logaddexp" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "logaddexp2" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "logical_and" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "logical_not" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "logical_or" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "logical_xor" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "logspace" - argspec: "args=[\'start\', \'stop\', \'num\', \'endpoint\', \'base\', \'dtype\', \'axis\'], varargs=None, keywords=None, defaults=[\'50\', \'True\', \'10.0\', \'None\', \'0\'], " - } - member_method { - name: "matmul" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "max" - argspec: "args=[\'a\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } - member_method { - name: "maximum" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "mean" - argspec: "args=[\'a\', \'axis\', \'dtype\', \'out\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " - } - member_method { - name: "meshgrid" - argspec: "args=[], varargs=xi, keywords=kwargs, defaults=None" - } - member_method { - name: "min" - argspec: "args=[\'a\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } - member_method { - name: "minimum" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "mod" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "moveaxis" - argspec: "args=[\'a\', \'source\', \'destination\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "multiply" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "nanmean" - argspec: "args=[\'a\', \'axis\', \'dtype\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " - } - member_method { - name: "nanprod" - argspec: "args=[\'a\', \'axis\', \'dtype\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " - } - member_method { - name: "nansum" - argspec: "args=[\'a\', \'axis\', \'dtype\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " - } - member_method { - name: "ndim" - argspec: "args=[\'a\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "negative" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "nextafter" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "nonzero" - argspec: "args=[\'a\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "not_equal" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "ones" - argspec: "args=[\'shape\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " - } - member_method { - name: "ones_like" - argspec: "args=[\'a\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "outer" - argspec: "args=[\'a\', \'b\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "pad" - argspec: "args=[\'array\', \'pad_width\', \'mode\'], varargs=None, keywords=kwargs, defaults=None" - } - member_method { - name: "polyval" - argspec: "args=[\'p\', \'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "positive" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "power" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "prod" - argspec: "args=[\'a\', \'axis\', \'dtype\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " - } - member_method { - name: "promote_types" - argspec: "args=[\'type1\', \'type2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "ptp" - argspec: "args=[\'a\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } - member_method { - name: "rad2deg" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "ravel" - argspec: "args=[\'a\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "real" - argspec: "args=[\'val\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "reciprocal" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "remainder" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "repeat" - argspec: "args=[\'a\', \'repeats\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "reshape" - argspec: "args=[\'a\', \'newshape\', \'order\'], varargs=None, keywords=None, defaults=[\'C\'], " - } - member_method { - name: "result_type" - argspec: "args=[], varargs=arrays_and_dtypes, keywords=None, defaults=None" - } - member_method { - name: "roll" - argspec: "args=[\'a\', \'shift\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "rot90" - argspec: "args=[\'m\', \'k\', \'axes\'], varargs=None, keywords=None, defaults=[\'1\', \'(0, 1)\'], " - } - member_method { - name: "round" - argspec: "args=[\'a\', \'decimals\'], varargs=None, keywords=None, defaults=[\'0\'], " - } - member_method { - name: "select" - argspec: "args=[\'condlist\', \'choicelist\', \'default\'], varargs=None, keywords=None, defaults=[\'0\'], " - } - member_method { - name: "shape" - argspec: "args=[\'a\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "sign" - argspec: "args=[\'x\', \'out\', \'where\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], " - } - member_method { - name: "signbit" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "sin" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "sinc" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "sinh" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "size" - argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "sort" - argspec: "args=[\'a\', \'axis\', \'kind\', \'order\'], varargs=None, keywords=None, defaults=[\'-1\', \'quicksort\', \'None\'], " - } - member_method { - name: "split" - argspec: "args=[\'ary\', \'indices_or_sections\', \'axis\'], varargs=None, keywords=None, defaults=[\'0\'], " - } - member_method { - name: "sqrt" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "square" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "squeeze" - argspec: "args=[\'a\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "stack" - argspec: "args=[\'arrays\', \'axis\'], varargs=None, keywords=None, defaults=[\'0\'], " - } - member_method { - name: "std" - argspec: "args=[\'a\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } - member_method { - name: "subtract" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "sum" - argspec: "args=[\'a\', \'axis\', \'dtype\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " - } - member_method { - name: "swapaxes" - argspec: "args=[\'a\', \'axis1\', \'axis2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "take" - argspec: "args=[\'a\', \'indices\', \'axis\', \'out\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'clip\'], " - } - member_method { - name: "take_along_axis" - argspec: "args=[\'arr\', \'indices\', \'axis\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "tan" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "tanh" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "tensordot" - argspec: "args=[\'a\', \'b\', \'axes\'], varargs=None, keywords=None, defaults=[\'2\'], " - } - member_method { - name: "tile" - argspec: "args=[\'a\', \'reps\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "trace" - argspec: "args=[\'a\', \'offset\', \'axis1\', \'axis2\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'1\', \'None\'], " - } - member_method { - name: "transpose" - argspec: "args=[\'a\', \'axes\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "tri" - argspec: "args=[\'N\', \'M\', \'k\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " - } - member_method { - name: "tril" - argspec: "args=[\'m\', \'k\'], varargs=None, keywords=None, defaults=[\'0\'], " - } - member_method { - name: "triu" - argspec: "args=[\'m\', \'k\'], varargs=None, keywords=None, defaults=[\'0\'], " - } - member_method { - name: "true_divide" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "vander" - argspec: "args=[\'x\', \'N\', \'increasing\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " - } - member_method { - name: "var" - argspec: "args=[\'a\', \'axis\', \'dtype\', \'out\', \'ddof\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'0\', \'None\'], " - } - member_method { - name: "vdot" - argspec: "args=[\'a\', \'b\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "vsplit" - argspec: "args=[\'ary\', \'indices_or_sections\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "vstack" - argspec: "args=[\'tup\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "where" - argspec: "args=[\'condition\', \'x\', \'y\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } - member_method { - name: "zeros" - argspec: "args=[\'shape\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " - } - member_method { - name: "zeros_like" - argspec: "args=[\'a\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\'], " - } -} diff --git a/third_party/xla/third_party/py/non_hermetic/numpy/tf_numpy_api/tensorflow.experimental.numpy.random.pbtxt b/third_party/xla/third_party/py/non_hermetic/numpy/tf_numpy_api/tensorflow.experimental.numpy.random.pbtxt deleted file mode 100644 index 61a4766f3f8f0f..00000000000000 --- a/third_party/xla/third_party/py/non_hermetic/numpy/tf_numpy_api/tensorflow.experimental.numpy.random.pbtxt +++ /dev/null @@ -1,35 +0,0 @@ -path: "tensorflow.experimental.numpy.random" -tf_module { - member_method { - name: "poisson" - argspec: "args=[\'lam\', \'size\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\'], " - } - member_method { - name: "rand" - argspec: "args=[], varargs=size, keywords=None, defaults=None" - } - member_method { - name: "randint" - argspec: "args=[\'low\', \'high\', \'size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"\"], " - } - member_method { - name: "randn" - argspec: "args=[], varargs=args, keywords=None, defaults=None" - } - member_method { - name: "random" - argspec: "args=[\'size\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "seed" - argspec: "args=[\'s\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "standard_normal" - argspec: "args=[\'size\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "uniform" - argspec: "args=[\'low\', \'high\', \'size\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\'], " - } -} diff --git a/third_party/xla/third_party/py/python_repo.bzl b/third_party/xla/third_party/py/python_repo.bzl index 0c58e3077712c6..1362e2fbfe6f63 100644 --- a/third_party/xla/third_party/py/python_repo.bzl +++ b/third_party/xla/third_party/py/python_repo.bzl @@ -14,7 +14,11 @@ def _python_repository_impl(ctx): ctx.file("BUILD", "") wheel_name = ctx.os.environ.get("WHEEL_NAME", "tensorflow") wheel_collab = ctx.os.environ.get("WHEEL_COLLAB", False) +<<<<<<< HEAD output_path = ctx.os.environ.get("OUTPUT_PATH", None) +======= + macos_deployment_target = ctx.os.environ.get("MACOSX_DEPLOYMENT_TARGET", "") +>>>>>>> master requirements = None for i in range(0, len(ctx.attr.requirements_locks)): @@ -35,13 +39,11 @@ Please check python_init_repositories() in your WORKSPACE file. requirements_with_local_wheels = str(requirements) - local_wheels_dir = ctx.os.environ.get("LOCAL_WHEELS_DIR", "") - if ctx.attr.local_wheel_workspaces or local_wheels_dir: + if ctx.attr.local_wheel_workspaces: local_wheel_requirements = _get_injected_local_wheels( ctx, version, ctx.attr.local_wheel_workspaces, - local_wheels_dir, ) requirements_content = [ctx.read(requirements)] + local_wheel_requirements merged_requirements_content = "\n".join(requirements_content) @@ -56,6 +58,13 @@ Please check python_init_repositories() in your WORKSPACE file. merged_requirements_content, ) + use_pywrap_rules = bool( + ctx.os.environ.get("USE_PYWRAP_RULES", False), + ) + + if use_pywrap_rules: + print("!!!Using pywrap rules instead of directly creating .so objects!!!") # buildifier: disable=print + ctx.file( "py_version.bzl", """ @@ -66,6 +75,8 @@ WHEEL_COLLAB = "{wheel_collab}" OUTPUT_PATH = "{output_path}" REQUIREMENTS = "{requirements}" REQUIREMENTS_WITH_LOCAL_WHEELS = "{requirements_with_local_wheels}" +USE_PYWRAP_RULES = {use_pywrap_rules} +MACOSX_DEPLOYMENT_TARGET = "{macos_deployment_target}" """.format( version = version, wheel_name = wheel_name, @@ -73,6 +84,8 @@ REQUIREMENTS_WITH_LOCAL_WHEELS = "{requirements_with_local_wheels}" output_path = output_path, requirements = str(requirements), requirements_with_local_wheels = requirements_with_local_wheels, + use_pywrap_rules = use_pywrap_rules, + macos_deployment_target = macos_deployment_target, ), ) @@ -121,8 +134,7 @@ def _parse_python_version(version_str): def _get_injected_local_wheels( ctx, py_version, - local_wheel_workspaces, - local_wheels_dir): + local_wheel_workspaces): local_wheel_requirements = [] py_ver_marker = "-cp%s-" % py_version.replace(".", "") py_major_ver_marker = "-py%s-" % py_version.split(".")[0] @@ -143,18 +155,6 @@ def _get_injected_local_wheels( ctx.attr.local_wheel_inclusion_list, ctx.attr.local_wheel_exclusion_list, ) - if local_wheels_dir: - dist_folder_path = ctx.path(local_wheels_dir) - if dist_folder_path.exists: - dist_wheels = dist_folder_path.readdir() - _process_dist_wheels( - dist_wheels, - wheels, - py_ver_marker, - py_major_ver_marker, - ctx.attr.local_wheel_inclusion_list, - ctx.attr.local_wheel_exclusion_list, - ) for wheel_name, wheel_path in wheels.items(): local_wheel_requirements.append( @@ -203,7 +203,11 @@ python_repository = repository_rule( "HERMETIC_PYTHON_VERSION", "WHEEL_NAME", "WHEEL_COLLAB", +<<<<<<< HEAD "OUTPUT_PATH", +======= + "USE_PYWRAP_RULES", +>>>>>>> master ], local = True, ) diff --git a/third_party/xla/third_party/py/rules_python.patch b/third_party/xla/third_party/py/rules_python.patch index 7d59ac107cc952..3dbe06dd2d6d96 100644 --- a/third_party/xla/third_party/py/rules_python.patch +++ b/third_party/xla/third_party/py/rules_python.patch @@ -1,32 +1,28 @@ -Subject: [PATCH] Add Python 3.13.0rc2 support to rules_python ---- -Index: python/versions.bzl -<+>UTF-8 -=================================================================== diff --git a/python/versions.bzl b/python/versions.bzl ---- a/python/versions.bzl (revision 084b877c98b580839ceab2b071b02fc6768f3de6) -+++ b/python/versions.bzl (date 1726256410148) -@@ -484,6 +484,19 @@ +index fd385cd1..eb4133f1 100644 +--- a/python/versions.bzl ++++ b/python/versions.bzl +@@ -484,6 +484,19 @@ TOOL_VERSIONS = { }, "strip_prefix": "python", }, + "3.13.0": { -+ "url": "20240909/cpython-{python_version}rc2+20240909-{platform}-{build}.tar.gz", ++ "url": "20241008/cpython-{python_version}+20241008-{platform}-{build}.tar.gz", + "sha256": { -+ "aarch64-apple-darwin": "5d38ca1e6b030b004714e10813903e906c6b8f2a6361770df4512a838f4a4a9f", -+ "aarch64-unknown-linux-gnu": "85e103fc81a1fcf94a93180f6df42e39a7dc15d4b711705e133dc2ec847552e7", -+ "ppc64le-unknown-linux-gnu": "3be3d8aefae579c420fc6abf01658ae89fda8120154f989575b08085d2f8d6dc", -+ "s390x-unknown-linux-gnu": "6ec5130d62473368ecc7e55338bf1cc58607dbfe8088959cab51265b9f13c38d", -+ "x86_64-apple-darwin": "c3dcd4314324159945dc19342c73b9deb8de0f2d1709171427dd52f1a05eecca", -+ "x86_64-pc-windows-msvc": "31282f912e984d399c56925dfb69a4f3ce76226dfb4806b09f37e3b4a15e5a30", -+ "x86_64-unknown-linux-gnu": "028581cce5004c66775a3ae8b3ed65681ab4b289608dfd1aec3354d169216099", ++ "aarch64-apple-darwin": "5d3cb8d7ca4cfbbe7ae1f118f26be112ee417d982fab8c6d85cfd8ccccf70718", ++ "aarch64-unknown-linux-gnu": "c1142af8f2c85923d2ba8201a35b913bb903a5d15f052c38bbecf2f49e2342dc", ++ "ppc64le-unknown-linux-gnu": "1be64a330499fed4e1f864b97eef5445b0e4abc0559ae45df3108981800cf998", ++ "s390x-unknown-linux-gnu": "c0b1cc51426feadaa932fdd9afd9a9af789916e128e48ac8909f9a269bbbd749", ++ "x86_64-apple-darwin": "b58ca12d9ae14bbd79f9e5cf4b748211ff1953e59abeac63b0f4e8e49845669f", ++ "x86_64-pc-windows-msvc": "c7651a7a575104f47c808902b020168057f3ad80f277e54cecfaf79a9ff50e22", ++ "x86_64-unknown-linux-gnu": "455200e1a202e9d9ef4b630c04af701c0a91dcaa6462022efc76893fc762ec95", + }, + "strip_prefix": "python", + }, } # buildifier: disable=unsorted-dict-items -@@ -493,6 +506,7 @@ +@@ -493,6 +506,7 @@ MINOR_MAPPING = { "3.10": "3.10.14", "3.11": "3.11.9", "3.12": "3.12.3", diff --git a/third_party/xla/third_party/shardy/temporary.patch b/third_party/xla/third_party/shardy/temporary.patch index c5aa30af88f875..e69de29bb2d1d6 100644 --- a/third_party/xla/third_party/shardy/temporary.patch +++ b/third_party/xla/third_party/shardy/temporary.patch @@ -1,15 +0,0 @@ -diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl -index cd6a8b6..c011aab 100644 ---- a/third_party/llvm/workspace.bzl -+++ b/third_party/llvm/workspace.bzl -@@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") - - def repo(name): - """Imports LLVM.""" -- LLVM_COMMIT = "104f3c180644c8872eaad0b3fcf6a6b948d92a71" -- LLVM_SHA256 = "5caf03c6e40c87e7593ce50bfe53ec52a08677c221f4f611f30b3f40397505b8" -+ LLVM_COMMIT = "94c024adedcb53059c29d7c2d62982053b60e86a" -+ LLVM_SHA256 = "204cedeaab86f065ef64cb3889dd2e92ddd4a8f5d5b6bc1cb4b276694fb6a798" - - tf_http_archive( - name = name, diff --git a/third_party/xla/third_party/shardy/workspace.bzl b/third_party/xla/third_party/shardy/workspace.bzl index f2425a6d6c98fe..b303e313939ee5 100644 --- a/third_party/xla/third_party/shardy/workspace.bzl +++ b/third_party/xla/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "a66667eefd65f73d50fab04298f477fc123b6740" - SHARDY_SHA256 = "543407a5fb203959d1189813275402dc5b8af6076203700ddea96a1dd8d981e1" + SHARDY_COMMIT = "ebd224c2199a003b2951fbeaa10daab88041762d" + SHARDY_SHA256 = "2809c6a97b99229a0279b2198bce0218629185f088047995f991c4dcfade8583" tf_http_archive( name = "shardy", diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index 3e0b0e66bc8a4f..2eb32ea8c944be 100755 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -1,353 +1,83 @@ -diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir ---- stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir -+++ stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir -@@ -41,3 +41,170 @@ - %3 = stablehlo.imag %1 : (tensor<4xcomplex>) -> tensor<4xf64> - func.return %2, %3 : tensor<4xf64>, tensor<4xf64> +diff --ruN a/stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp b/stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp +--- stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp ++++ stablehlo/stablehlo/conversions/linalg/transforms/TypeConversion.cpp +@@ -47,36 +47,36 @@ + return shapedType; } -+ -+// ----- -+ -+// CHECK-LABEL: @gather_with_batching_dims -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], -+// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x4xi32>) -> tensor<4x3x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<4x3x5x8xi32> -+func.func @gather_with_batching_dims(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> { -+ // CHECK-NO-DOWNGRADE: operand_batching_dims = [0, 2] -+ // CHECK-NO-DOWNGRADE: start_indices_batching_dims = [1, 0] -+ %0 = "stablehlo.gather"(%arg0, %arg1) { -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1, 3], -+ operand_batching_dims = [0, 2], -+ start_indices_batching_dims = [1, 0], -+ start_index_map = [1, 3], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = true -+ } : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> -+ func.return %0 : tensor<4x3x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_with_batching_no_index_vector_dim -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2], -+// CHECK-SAME: start_index_map = [0, 2, 1], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<3x2x4x9xi32>, tensor<4x3x5x3xi32>) -> tensor<4x3x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<4x3x5x8xi32> -+func.func @gather_with_batching_no_index_vector_dim(%arg0: tensor<3x2x4x9xi32>, %arg1: tensor<4x3x5xi32>) -> tensor<4x3x5x8xi32> { -+ // CHECK-NO-DOWNGRADE: operand_batching_dims = [0, 2] -+ // CHECK-NO-DOWNGRADE: start_indices_batching_dims = [1, 0] -+ %0 = "stablehlo.gather"(%arg0, %arg1) <{ -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1], -+ operand_batching_dims = [0, 2], -+ start_indices_batching_dims = [1, 0], -+ start_index_map = [1], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = true -+ }> : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>) -> tensor<4x3x5x8xi32> -+ func.return %0 : tensor<4x3x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @gather_with_batching_dim_size_zero -+// CHECK-NEXT: %[[iota:.*]] = stablehlo.iota dim = 0 : tensor<0x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota]], %arg1, dim = 3 : (tensor<0x3x5x1xi32>, tensor<0x3x5x1xi32>) -> tensor<0x3x5x2xi32> -+// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ -+// CHECK-SAME: dimension_numbers = #stablehlo.gather< -+// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1], -+// CHECK-SAME: start_index_map = [0, 1], index_vector_dim = 3>, -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: slice_sizes = array -+// CHECK-SAME: }> : (tensor<0x2x9xi32>, tensor<0x3x5x2xi32>) -> tensor<0x3x5x8xi32> -+// CHECK-NEXT: return %[[gather]] : tensor<0x3x5x8xi32> -+func.func @gather_with_batching_dim_size_zero(%arg0: tensor<0x2x9xi32>, %arg1: tensor<0x3x5x1xi32>) -> tensor<0x3x5x8xi32> { -+ // CHECK-NO-DOWNGRADE: operand_batching_dims = [0] -+ // CHECK-NO-DOWNGRADE: start_indices_batching_dims = [0] -+ %0 = "stablehlo.gather"(%arg0, %arg1) <{ -+ dimension_numbers = #stablehlo.gather< -+ offset_dims = [3], -+ collapsed_slice_dims = [1], -+ operand_batching_dims = [0], -+ start_indices_batching_dims = [0], -+ start_index_map = [1], -+ index_vector_dim = 3 -+ >, -+ slice_sizes = array, -+ indices_are_sorted = true -+ }> : (tensor<0x2x9xi32>, tensor<0x3x5x1xi32>) -> tensor<0x3x5x8xi32> -+ func.return %0 : tensor<0x3x5x8xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @scatter_with_batching_dims -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32> -+// CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: dimension_numbers = #stablehlo.scatter< -+// CHECK-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2, 3], -+// CHECK-SAME: scatter_dims_to_operand_dims = [0, 2, 1, 3], index_vector_dim = 3>, -+// CHECK-SAME: unique_indices = false}> -+// CHECK: (tensor<3x2x4x7x9xi32>, tensor<4x3x5x4xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> -+// CHECK-NEXT: return %[[scatter]] : tensor<3x2x4x7x9xi32> -+func.func @scatter_with_batching_dims(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>, %arg2: tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> { -+ // CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2] -+ // CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0] -+ %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ -+ indices_are_sorted = true, -+ scatter_dimension_numbers = #stablehlo.scatter< -+ update_window_dims = [3], -+ inserted_window_dims = [1, 3], -+ input_batching_dims = [0, 2], -+ scatter_indices_batching_dims = [1, 0], -+ scatter_dims_to_operand_dims = [1, 3], -+ index_vector_dim = 3 -+ >, -+ unique_indices = false -+ }> ({ -+ ^bb0(%arg3: tensor, %arg4: tensor): -+ stablehlo.return %arg4 : tensor -+ }) : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x2xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> -+ func.return %0 : tensor<3x2x4x7x9xi32> -+} -+ -+// ----- -+ -+// CHECK-LABEL: @scatter_with_batching_no_index_vector_dim -+// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> -+// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32> -+// CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ -+// CHECK-SAME: indices_are_sorted = false, -+// CHECK-SAME: dimension_numbers = #stablehlo.scatter< -+// CHECK-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2], -+// CHECK-SAME: scatter_dims_to_operand_dims = [0, 2, 1], index_vector_dim = 3>, -+// CHECK-SAME: unique_indices = true}> -+// CHECK: (tensor<3x2x4x9xi32>, tensor<4x3x5x3xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32> -+// CHECK-NEXT: return %[[scatter]] : tensor<3x2x4x9xi32> -+func.func @scatter_with_batching_no_index_vector_dim(%arg0: tensor<3x2x4x9xi32>, %arg1: tensor<4x3x5xi32>, %arg2: tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32> { -+ // CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2] -+ // CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0] -+ %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ -+ indices_are_sorted = true, -+ scatter_dimension_numbers = #stablehlo.scatter< -+ update_window_dims = [3], -+ inserted_window_dims = [1], -+ input_batching_dims = [0, 2], -+ scatter_indices_batching_dims = [1, 0], -+ scatter_dims_to_operand_dims = [1], -+ index_vector_dim = 3 -+ >, -+ unique_indices = true -+ }> ({ -+ ^bb0(%arg3: tensor, %arg4: tensor): -+ stablehlo.return %arg4 : tensor -+ }) : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32> -+ func.return %0 : tensor<3x2x4x9xi32> -+} -diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp ---- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp -+++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp -@@ -12,14 +12,22 @@ - - #include -+#include - #include -+#include -+#include -+#include +-std::optional materializeCastFromIllegal(OpBuilder &builder, Type type, ++Value materializeCastFromIllegal(OpBuilder &builder, Type type, + ValueRange inputs, + Location loc) { + Type fromType = getElementTypeOrSelf(inputs[0].getType()); + Type toType = getElementTypeOrSelf(type); + if ((!fromType.isSignedInteger() && !fromType.isUnsignedInteger()) || + !toType.isSignlessInteger()) +- return std::nullopt; ++ return Value(); + // Use unrealized conversion casts to do signful->signless conversions. + return builder.create(loc, type, inputs[0]) + ->getResult(0); + } - #include "llvm/ADT/APFloat.h" -+#include "llvm/ADT/STLExtras.h" -+#include "llvm/ADT/SmallVector.h" - #include "llvm/Support/ErrorHandling.h" - #include "mlir/Dialect/Func/IR/FuncOps.h" - #include "mlir/IR/BuiltinAttributes.h" - #include "mlir/IR/BuiltinTypes.h" -+#include "mlir/IR/Diagnostics.h" - #include "mlir/IR/PatternMatch.h" -+#include "mlir/Rewrite/FrozenRewritePatternSet.h" - #include "mlir/Support/LLVM.h" - #include "mlir/Transforms/DialectConversion.h" - #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -@@ -58,6 +66,132 @@ - return targetVersion; +-std::optional materializeCastToIllegal(OpBuilder &builder, Type type, ++Value materializeCastToIllegal(OpBuilder &builder, Type type, + ValueRange inputs, Location loc) { + Type fromType = getElementTypeOrSelf(inputs[0].getType()); + Type toType = getElementTypeOrSelf(type); + if (!fromType.isSignlessInteger() || + (!toType.isSignedInteger() && !toType.isUnsignedInteger())) +- return std::nullopt; ++ return Value(); + // Use unrealized conversion casts to do signless->signful conversions. + return builder.create(loc, type, inputs[0]) + ->getResult(0); } -+SmallVector mergeSortedDims(ArrayRef dims1, -+ ArrayRef dims2) { -+ SmallVector result; -+ result.reserve(dims1.size() + dims2.size()); -+ std::merge(dims1.begin(), dims1.end(), dims2.begin(), dims2.end(), -+ std::back_inserter(result)); -+ return result; -+} -+ -+// Returns an updated indices tensor such that an `IotaOp` is prepended for each -+// dim in `indicesBatchingDims` with a `ConcatenateOp`. +-std::optional scalarToTensor(OpBuilder &builder, Type type, ++Value scalarToTensor(OpBuilder &builder, Type type, + ValueRange inputs, Location loc) { + assert(inputs.size() == 1); + if (mlir::isa(inputs.front().getType())) { +- return std::nullopt; ++ return Value(); + } + Value result = + builder +diff --ruN a/stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade_patch.mlir b/stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade_patch.mlir +--- stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade_patch.mlir ++++ stablehlo/stablehlo/tests/vhlo/vhlo_to_version_downgrade_patch.mlir +@@ -0,0 +1,15 @@ ++// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo --vhlo-to-version='target=1.4.1' %s | FileCheck %s ++ ++// AllToAll was in the initial StableHLO opset, but changed in v1.5.0 to have ++// tuple arguments. Ensure that serializing for 1.4.1 is valid and targets the ++// v1.4.0 opset. +// -+// If `indexVectorDim` is equal to the rank of `indices`, it is reshaped to have -+// a trailing dimension of size 1 so it can be concatenated with the `IotaOp`s. -+Value createConcatIndices(Value indices, int64_t indexVectorDim, -+ ArrayRef indicesBatchingDims, -+ PatternRewriter &rewriter) { -+ Location loc = indices.getLoc(); -+ auto indicesType = cast(indices.getType()); -+ bool indexVectorDimOnLastDim = indexVectorDim == indicesType.getRank(); -+ -+ SmallVector iotaShape(indicesType.getShape()); -+ if (indexVectorDimOnLastDim) { -+ iotaShape.push_back(1); -+ } else { -+ iotaShape[indexVectorDim] = 1; -+ } -+ auto iotaType = -+ RankedTensorType::get(iotaShape, indicesType.getElementType()); -+ -+ SmallVector indicesToConcat; -+ indicesToConcat.reserve(indicesBatchingDims.size() + 1); -+ for (int64_t batchingDim : indicesBatchingDims) { -+ indicesToConcat.push_back( -+ rewriter.create(loc, iotaType, batchingDim)); -+ } -+ if (indexVectorDimOnLastDim) { -+ indicesToConcat.push_back( -+ rewriter.create(loc, iotaType, indices)); -+ } else { -+ indicesToConcat.push_back(indices); -+ } -+ return rewriter.create(loc, indicesToConcat, indexVectorDim); ++// This will catch issues in op `isLegal` checks: ++// op.minVersion() <= target <= op.maxVersion() ++ ++// CHECK-LABEL: vhlo.func_v1 @all_to_all ++func.func public @all_to_all(%arg0: tensor<8x8x1xui16>) -> tensor<1x8x8xui16> { ++ // CHECK: vhlo.all_to_all_v1 ++ %0 = "stablehlo.all_to_all"(%arg0) <{concat_dimension = 2 : i64, replica_groups = dense<[[0, 1, 2, 3, 4, 5, 6, 7]]> : tensor<1x8xi64>, split_count = 8 : i64, split_dimension = 0 : i64}> : (tensor<8x8x1xui16>) -> tensor<1x8x8xui16> ++ return %0 : tensor<1x8x8xui16> +} -+ -+//===----------------------------------------------------------------------===// -+// Patterns (non DRR) -+//===----------------------------------------------------------------------===// -+ -+// Converts a `GatherOp` with batching dims to a `GatherOp` without batching -+// dims, such that each batching dim becomes a collapsed slice dim with a -+// corresponding `IotaOp` concatenated to the start indices. -+class GatherWithBatchingDimsExpander : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ -+ LogicalResult matchAndRewrite(GatherOp op, -+ PatternRewriter &rewriter) const override { -+ GatherDimensionNumbersAttr dimNumbers = op.getDimensionNumbers(); -+ ArrayRef operandBatchingDims = dimNumbers.getOperandBatchingDims(); -+ if (operandBatchingDims.empty()) { -+ return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { -+ diag << "gather op has no batching dims"; -+ }); -+ } -+ -+ SmallVector newCollapsedSliceDims = mergeSortedDims( -+ operandBatchingDims, dimNumbers.getCollapsedSliceDims()); -+ SmallVector newStartIndexMap = -+ llvm::to_vector(llvm::concat( -+ operandBatchingDims, dimNumbers.getStartIndexMap())); -+ Value newIndices = createConcatIndices( -+ op.getStartIndices(), dimNumbers.getIndexVectorDim(), -+ dimNumbers.getStartIndicesBatchingDims(), rewriter); -+ rewriter.replaceOpWithNewOp( -+ op, op.getOperand(), newIndices, -+ GatherDimensionNumbersAttr::get( -+ op.getContext(), dimNumbers.getOffsetDims(), newCollapsedSliceDims, -+ /*operandBatchingDims=*/{}, /*startIndicesBatchingDims=*/{}, -+ newStartIndexMap, dimNumbers.getIndexVectorDim()), -+ op.getSliceSizes(), /*indicesAreSorted=*/false); -+ -+ return success(); -+ } -+}; -+ -+// Converts a `ScatterOp` with batching dims to a `ScatterOp` without batching -+// dims, such that each batching dim becomes an inserted window dim with a -+// corresponding `IotaOp` concatenated to the scatter indices. -+class ScatterWithBatchingDimsExpander : public OpRewritePattern { -+ using OpRewritePattern::OpRewritePattern; -+ -+ LogicalResult matchAndRewrite(ScatterOp op, -+ PatternRewriter &rewriter) const override { -+ ScatterDimensionNumbersAttr dimNumbers = op.getScatterDimensionNumbers(); -+ ArrayRef inputBatchingDims = dimNumbers.getInputBatchingDims(); -+ if (inputBatchingDims.empty()) { -+ return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { -+ diag << "scatter op has no batching dims"; -+ }); -+ } -+ -+ SmallVector newInsertedWindowDims = -+ mergeSortedDims(inputBatchingDims, dimNumbers.getInsertedWindowDims()); -+ SmallVector newScatterDimsToOperandDims = -+ llvm::to_vector(llvm::concat( -+ inputBatchingDims, dimNumbers.getScatterDimsToOperandDims())); -+ Value newIndices = createConcatIndices( -+ op.getScatterIndices(), dimNumbers.getIndexVectorDim(), -+ dimNumbers.getScatterIndicesBatchingDims(), rewriter); -+ auto newScatterOp = rewriter.create( -+ op.getLoc(), op->getResultTypes(), op.getInputs(), newIndices, -+ op.getUpdates(), -+ ScatterDimensionNumbersAttr::get( -+ op.getContext(), dimNumbers.getUpdateWindowDims(), -+ newInsertedWindowDims, -+ /*inputBatchingDims=*/{}, /*scatterIndicesBatchingDims=*/{}, -+ newScatterDimsToOperandDims, dimNumbers.getIndexVectorDim()), -+ /*indicesAreSorted=*/false, op.getUniqueIndices()); -+ -+ newScatterOp.getUpdateComputation().takeBody(op.getUpdateComputation()); -+ rewriter.replaceOp(op, newScatterOp.getResults()); -+ -+ return success(); +diff --ruN a/stablehlo/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/stablehlo/transforms/VhloToVersion.cpp +--- stablehlo/stablehlo/transforms/VhloToVersion.cpp ++++ stablehlo/stablehlo/transforms/VhloToVersion.cpp +@@ -92,6 +92,13 @@ + << " is greater than current version " + << Version::getCurrentVersion(); + ++ // Opset changes warrant a minor version bump, so this conversion assumes ++ // patch v0 since it is written against the opset at version `X.Y.0`. ++ if (targetVersion.getPatch() != 0) { ++ targetVersion = ++ vhlo::Version(targetVersion.getMajor(), targetVersion.getMinor(), 0); + } -+}; + - //===----------------------------------------------------------------------===// - // Pass - //===----------------------------------------------------------------------===// -@@ -107,10 +241,16 @@ - void populateStablehloCreateCompatibilityExpanderPatterns( - RewritePatternSet *patterns, MLIRContext *context, - vhlo::Version targetVersion) { -+ // StableHLO GatherOp/ScatterOp with batching dims is introduced in v1.1.0. -+ if (targetVersion < vhlo::Version(1, 1, 0)) { -+ patterns -+ ->add( -+ context); -+ } - // StableHLO TanOp is introduced in v1.4.0. - if (targetVersion < vhlo::Version(1, 4, 0)) { -- patterns->add(context); -- patterns->add(context); -+ patterns->add(context); - } + return targetVersion; } diff --git a/third_party/xla/third_party/stablehlo/workspace.bzl b/third_party/xla/third_party/stablehlo/workspace.bzl index 97fd0b990fc1c7..62097715d4e914 100644 --- a/third_party/xla/third_party/stablehlo/workspace.bzl +++ b/third_party/xla/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "78c753ad13ad8205cacc5fcc12418c1ac97276c7" - STABLEHLO_SHA256 = "b7fef892020eb465a6d1ed921160f5229398ba10acff36b6345171b9867ccc7c" + STABLEHLO_COMMIT = "1c0b606503aac4f8e01f5511b0a10418bf93e7a6" + STABLEHLO_SHA256 = "9ccf08c7d2c7dc0a5c314cf13e3e82faafc8c3dc2a45f4d6fa634ca8e5e97705" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/xla/third_party/triton/llvm_integration/cl680875920.patch b/third_party/xla/third_party/triton/llvm_integration/cl680875920.patch new file mode 100644 index 00000000000000..bbc8f024c78689 --- /dev/null +++ b/third_party/xla/third_party/triton/llvm_integration/cl680875920.patch @@ -0,0 +1,114 @@ + +--- a/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp 2024-03-19 09:23:43.000000000 -0700 ++++ b/lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp 2024-10-01 02:58:18.000000000 -0700 +@@ -104,9 +104,26 @@ + this->getTypeConverter()->packFunctionResults(resultTypes))) + return nullptr; + } ++ // Add LLVMOp Bundle Attrs ++ // https://github.com/llvm/llvm-project/blob/main/flang/lib/Optimizer/CodeGen/CodeGen.cpp#L113-L131 ++ llvm::SmallVector newAttrs; ++ newAttrs.reserve(callOp->getAttrs().size() + 2); ++ ++ for (mlir::NamedAttribute attr : callOp->getAttrs()) { ++ if (attr.getName() != "operandSegmentSizes") ++ newAttrs.push_back(attr); ++ } ++ ++ newAttrs.push_back(rewriter.getNamedAttr( ++ "operandSegmentSizes", ++ rewriter.getDenseI32ArrayAttr( ++ {static_cast(promotedOperands.size()), 0}))); ++ newAttrs.push_back(rewriter.getNamedAttr( ++ "op_bundle_sizes", rewriter.getDenseI32ArrayAttr({}))); ++ + auto newCallOp = rewriter.create( + callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), +- promotedOperands, callOp->getAttrs()); ++ promotedOperands, newAttrs); + return newCallOp; + } + + +--- a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp 2024-09-25 10:13:59.000000000 -0700 ++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp 2024-09-30 23:51:44.000000000 -0700 +@@ -190,7 +190,8 @@ + auto name = StringAttr::get(callOp.getContext(), "llvm.amdgcn.rcp.f32"); + LLVM::FastmathFlagsAttr defaultFlags{}; + auto rcpOp = rewriter.create( +- loc, returnType, name, operands[1], defaultFlags); ++ loc, returnType, name, operands[1], defaultFlags, ++ ::llvm::ArrayRef<::mlir::ValueRange>{} /*op_bundle_operands*/); + + replacementOp = rewriter.create( + loc, returnType, operands[0], rcpOp->getResult(0), defaultFlags); + + +--- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp 2024-08-20 03:28:55.000000000 -0700 ++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp 2024-09-30 23:51:44.000000000 -0700 +@@ -219,7 +219,8 @@ + } + auto wmmaIntrinsic = rewriter.create( + loc, TypeRange{valC.getType()}, StringAttr::get(loc.getContext(), name), +- operands, defaultFlags); ++ operands, defaultFlags, ++ ::llvm::ArrayRef<::mlir::ValueRange>{} /*op_bundle_operands*/); + + return wmmaIntrinsic.getResult(0); + } + + +--- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp 2024-09-16 13:44:40.000000000 -0700 ++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp 2024-09-30 23:51:44.000000000 -0700 +@@ -72,7 +72,10 @@ + auto stringAttr = rewriter.getStringAttr("llvm.amdgcn.ballot"); + SmallVector operands = {cmp}; + Value asmResult = +- rewriter.create(loc, type, stringAttr, operands) ++ rewriter ++ .create( ++ loc, type, stringAttr, operands, ::mlir::LLVM::FastmathFlags{}, ++ ::llvm::ArrayRef<::mlir::ValueRange>{} /*op_bundle_operands*/) + ->getResult(0); + return asmResult; + } + + +--- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp ++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp +@@ -48,9 +48,10 @@ void createSchedGroupBarrier(PatternRewr + static_cast(groupIdValue)); + + LLVM::FastmathFlagsAttr defaultFlags{}; +- rewriter.create(loc, TypeRange{}, intrinsicName, +- ValueRange{mask, size, groupId}, +- defaultFlags); ++ rewriter.create( ++ loc, TypeRange{}, intrinsicName, ValueRange{mask, size, groupId}, ++ defaultFlags, ++ ::llvm::ArrayRef<::mlir::ValueRange>{} /*op_bundle_operands*/); + } + + // Insert intrinsic that controls the types of instructions that may be +@@ -63,8 +64,9 @@ Operation *createSchedBarrier(PatternRew + + Value mask = + LLVM::createConstantI32(loc, rewriter, static_cast(maskValue)); +- return rewriter.create(loc, TypeRange{}, intrinsicName, +- ValueRange{mask}, defaultFlags); ++ return rewriter.create( ++ loc, TypeRange{}, intrinsicName, ValueRange{mask}, defaultFlags, ++ ::llvm::ArrayRef<::mlir::ValueRange>{} /*op_bundle_operands*/); + } + + // Insert an experimental intrinsic for instruction group level parallelism. +@@ -76,7 +78,8 @@ Operation *createIglpOpt(PatternRewriter + Value iglpValue = + LLVM::createConstantI32(loc, rewriter, static_cast(value)); + return rewriter.create( +- loc, TypeRange{}, intrinsicName, ValueRange{iglpValue}, defaultFlags); ++ loc, TypeRange{}, intrinsicName, ValueRange{iglpValue}, defaultFlags, ++ ::llvm::ArrayRef<::mlir::ValueRange>{} /*op_bundle_operands*/); + } + + struct InstructionSchedHintsRewriter diff --git a/third_party/xla/third_party/triton/llvm_integration/cl683501567.patch b/third_party/xla/third_party/triton/llvm_integration/cl683501567.patch new file mode 100644 index 00000000000000..7395934253fc0c --- /dev/null +++ b/third_party/xla/third_party/triton/llvm_integration/cl683501567.patch @@ -0,0 +1,13 @@ + +--- a/lib/Target/LLVMIR/LLVMDIScope.cpp 2024-09-16 13:44:40.000000000 -0700 ++++ b/lib/Target/LLVMIR/LLVMDIScope.cpp 2024-10-08 22:38:50.000000000 -0700 +@@ -104,7 +104,8 @@ + auto subprogramAttr = LLVM::DISubprogramAttr::get( + context, distinctId, compileUnitAttr, fileAttr, funcNameAttr, + funcNameAttr, fileAttr, /*line=*/line, /*scopeline=*/line, +- subprogramFlags, subroutineTypeAttr, /*retainNodes=*/{}); ++ subprogramFlags, subroutineTypeAttr, /*retainNodes=*/{}, ++ /*annotations=*/{}); + funcOp->setLoc(FusedLoc::get(context, {loc}, subprogramAttr)); + } + diff --git a/third_party/xla/third_party/triton/llvm_integration/cl686059966.patch b/third_party/xla/third_party/triton/llvm_integration/cl686059966.patch new file mode 100644 index 00000000000000..b5fcd3a266e313 --- /dev/null +++ b/third_party/xla/third_party/triton/llvm_integration/cl686059966.patch @@ -0,0 +1,36 @@ + +--- a/lib/Analysis/AxisInfo.cpp 2024-10-01 12:24:54.000000000 -0700 ++++ b/lib/Analysis/AxisInfo.cpp 2024-10-15 05:20:45.000000000 -0700 +@@ -1079,8 +1079,8 @@ + + void AxisInfoAnalysis::visitForOpInductionVar( + scf::ForOp op, ArrayRef *> argLattices) { +- auto lb = getLatticeElementFor(op, op.getLowerBound())->getValue(); +- auto step = getLatticeElementFor(op, op.getStep())->getValue(); ++ auto lb = getLatticeElementFor(getProgramPointAfter(op), op.getLowerBound())->getValue(); ++ auto step = getLatticeElementFor(getProgramPointAfter(op), op.getStep())->getValue(); + + AxisInfo::DimVectorT knownContiguity(1, 1); + AxisInfo::DimVectorT knownDivisibility(1, 1); + +--- a/lib/Analysis/Utility.cpp 2024-10-02 02:26:53.000000000 -0700 ++++ b/lib/Analysis/Utility.cpp 2024-10-15 05:20:45.000000000 -0700 +@@ -826,15 +826,15 @@ + + LogicalResult initialize(Operation *top) override { + WalkResult result = top->walk([&](Operation *op) { +- if (failed(visit(op))) ++ if (failed(visit(getProgramPointAfter(op)))) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return success(!result.wasInterrupted()); + } + +- LogicalResult visit(ProgramPoint point) override { +- Operation *op = point.get(); ++ LogicalResult visit(ProgramPoint* point) override { ++ Operation *op = point->getPrevOp(); + Attribute value; + if (matchPattern(op, m_Constant(&value))) { + auto *constant = getOrCreate>( diff --git a/third_party/xla/third_party/triton/llvm_integration/cl686893691.patch b/third_party/xla/third_party/triton/llvm_integration/cl686893691.patch new file mode 100644 index 00000000000000..b27d7abb41d8eb --- /dev/null +++ b/third_party/xla/third_party/triton/llvm_integration/cl686893691.patch @@ -0,0 +1,80 @@ + +--- a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp 2024-10-01 05:53:49.000000000 -0700 ++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp 2024-10-17 07:36:44.000000000 -0700 +@@ -190,8 +190,7 @@ + auto name = StringAttr::get(callOp.getContext(), "llvm.amdgcn.rcp.f32"); + LLVM::FastmathFlagsAttr defaultFlags{}; + auto rcpOp = rewriter.create( +- loc, returnType, name, operands[1], defaultFlags, +- ::llvm::ArrayRef<::mlir::ValueRange>{} /*op_bundle_operands*/); ++ loc, returnType, name, operands[1], defaultFlags); + + replacementOp = rewriter.create( + loc, returnType, operands[0], rcpOp->getResult(0), defaultFlags); + +--- a/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp 2024-10-01 05:53:49.000000000 -0700 ++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/WMMA.cpp 2024-10-17 07:49:54.000000000 -0700 +@@ -219,8 +219,7 @@ + } + auto wmmaIntrinsic = rewriter.create( + loc, TypeRange{valC.getType()}, StringAttr::get(loc.getContext(), name), +- operands, defaultFlags, +- ::llvm::ArrayRef<::mlir::ValueRange>{} /*op_bundle_operands*/); ++ operands, defaultFlags); + + return wmmaIntrinsic.getResult(0); + } + +--- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp 2024-10-02 02:26:53.000000000 -0700 ++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp 2024-10-17 07:39:38.000000000 -0700 +@@ -48,10 +48,9 @@ + static_cast(groupIdValue)); + + LLVM::FastmathFlagsAttr defaultFlags{}; +- rewriter.create( +- loc, TypeRange{}, intrinsicName, ValueRange{mask, size, groupId}, +- defaultFlags, +- ::llvm::ArrayRef<::mlir::ValueRange>{} /*op_bundle_operands*/); ++ rewriter.create(loc, TypeRange{}, intrinsicName, ++ ValueRange{mask, size, groupId}, ++ defaultFlags); + } + + // Insert intrinsic that controls the types of instructions that may be +@@ -64,9 +63,8 @@ + + Value mask = + LLVM::createConstantI32(loc, rewriter, static_cast(maskValue)); +- return rewriter.create( +- loc, TypeRange{}, intrinsicName, ValueRange{mask}, defaultFlags, +- ::llvm::ArrayRef<::mlir::ValueRange>{} /*op_bundle_operands*/); ++ return rewriter.create(loc, TypeRange{}, intrinsicName, ++ ValueRange{mask}, defaultFlags); + } + + // Insert an experimental intrinsic for instruction group level parallelism. +@@ -78,8 +76,7 @@ + Value iglpValue = + LLVM::createConstantI32(loc, rewriter, static_cast(value)); + return rewriter.create( +- loc, TypeRange{}, intrinsicName, ValueRange{iglpValue}, defaultFlags, +- ::llvm::ArrayRef<::mlir::ValueRange>{} /*op_bundle_operands*/); ++ loc, TypeRange{}, intrinsicName, ValueRange{iglpValue}, defaultFlags); + } + + struct InstructionSchedHintsRewriter + +--- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp 2024-10-01 05:53:49.000000000 -0700 ++++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp 2024-10-17 07:37:23.000000000 -0700 +@@ -72,10 +72,7 @@ + auto stringAttr = rewriter.getStringAttr("llvm.amdgcn.ballot"); + SmallVector operands = {cmp}; + Value asmResult = +- rewriter +- .create( +- loc, type, stringAttr, operands, ::mlir::LLVM::FastmathFlags{}, +- ::llvm::ArrayRef<::mlir::ValueRange>{} /*op_bundle_operands*/) ++ rewriter.create(loc, type, stringAttr, operands) + ->getResult(0); + return asmResult; + } diff --git a/third_party/xla/third_party/triton/llvm_integration/cl689707450.patch b/third_party/xla/third_party/triton/llvm_integration/cl689707450.patch new file mode 100644 index 00000000000000..0afc2edb10d8d2 --- /dev/null +++ b/third_party/xla/third_party/triton/llvm_integration/cl689707450.patch @@ -0,0 +1,47 @@ + +--- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp 2024-08-05 02:40:13.000000000 -0700 ++++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp 2024-10-25 02:46:07.000000000 -0700 +@@ -56,7 +56,7 @@ + // This will create newArg, and map(origArg, newArg) + addArgumentMaterialization([&](OpBuilder &builder, + RankedTensorType tensorType, ValueRange inputs, +- Location loc) -> std::optional { ++ Location loc) -> Value { + // Allows partial TTIR to TTGIR conversion by materializing a conversion for + // remaining arguments that have been converted to a new type. + // We use this to rewrite triton_gpu.sparse_dot in a separate pass after +@@ -65,14 +65,14 @@ + inputs); + llvm_unreachable("Argument rematerialization should not happen in Triton " + "-> TritonGPU conversion"); +- return std::nullopt; ++ return Value(); + }); + + // If the origValue still has live user(s), use this to + // convert origValue to newValue + addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, + ValueRange inputs, +- Location loc) -> std::optional { ++ Location loc) -> Value { + // Allows partial TTIR to TTGIR conversion by materializing a conversion for + // remaining uses of values that have been converted to a new type. + // We use this to rewrite triton_gpu.sparse_dot in a separate pass after +@@ -81,7 +81,7 @@ + inputs); + llvm_unreachable("Source rematerialization should not happen in Triton -> " + "TritonGPU Conversion"); +- return std::nullopt; ++ return Value(); + }); + + // This will be called when (desiredType != newOperandType) +@@ -91,7 +91,7 @@ + ValueRange inputs, Location loc) { + auto cast = + builder.create(loc, tensorType, inputs); +- return std::optional(cast.getResult()); ++ return Value(cast.getResult()); + }); + } + diff --git a/third_party/xla/third_party/triton/llvm_integration/series.bzl b/third_party/xla/third_party/triton/llvm_integration/series.bzl index 656b9c894904d8..70fef78927d338 100644 --- a/third_party/xla/third_party/triton/llvm_integration/series.bzl +++ b/third_party/xla/third_party/triton/llvm_integration/series.bzl @@ -8,5 +8,10 @@ LLVM nor MLIR integrator, please do not add any patches to this list. """ llvm_patch_list = [ + "//third_party/triton:llvm_integration/cl680875920.patch", + "//third_party/triton:llvm_integration/cl683501567.patch", + "//third_party/triton:llvm_integration/cl686059966.patch", + "//third_party/triton:llvm_integration/cl686893691.patch", + "//third_party/triton:llvm_integration/cl689707450.patch", # Add new patches just above this line ] diff --git a/third_party/xla/third_party/triton/temporary/fix_left_shift_overflow.patch b/third_party/xla/third_party/triton/temporary/fix_left_shift_overflow.patch new file mode 100644 index 00000000000000..ca31caef4b2824 --- /dev/null +++ b/third_party/xla/third_party/triton/temporary/fix_left_shift_overflow.patch @@ -0,0 +1,11 @@ +--- a/lib/Analysis/AxisInfo.cpp ++++ b/lib/Analysis/AxisInfo.cpp +@@ -932,7 +932,7 @@ private: + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } +- return std::max(1, lhsDivisibility / (1 << shift)); ++ return std::max(1, lhsDivisibility / (int64_t(1) << shift)); + } + + int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, diff --git a/third_party/xla/third_party/triton/temporary/further_mixed_precision_fix.patch b/third_party/xla/third_party/triton/temporary/further_mixed_precision_fix.patch new file mode 100644 index 00000000000000..6152ab48194c09 --- /dev/null +++ b/third_party/xla/third_party/triton/temporary/further_mixed_precision_fix.patch @@ -0,0 +1,36 @@ +This resolves the issue here b/372630230. The patch is not intended to be +submitted to Triton upstream. This is because OAI historically refused these +similar work-arounds and the proper fixes are considerably more expensive to do. +diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +--- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp ++++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +@@ -55,7 +55,8 @@ SmallVector reorderValues(const S + } + return ret; + } +- if (inBitWidth == 8 && ouBitWidth == 16) { ++ if ((inBitWidth == 8 && ouBitWidth == 16) || ++ (inBitWidth == 16 && ouBitWidth == 8)) { + SmallVector ret; + for (unsigned i = 0; i < values.size(); i += 16) { + ret.push_back(values[i + 0]); +diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir +--- a/test/Conversion/tritongpu_to_llvm.mlir ++++ b/test/Conversion/tritongpu_to_llvm.mlir +@@ -1693,3 +1693,16 @@ module attributes {"triton_gpu.num-ctas" + tt.return + } + } ++ ++// ----- ++ ++#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], instrShape = [16, 8]}> ++#dot_operand = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=4}> ++module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { ++ tt.func @f16_to_f8_dot_operand(%f16_inp: tensor<32x32xf16, #dot_operand>) { ++ // CHECK-LABEL: @f16_to_f8_dot_operand ++ ++ %f8 = tt.fp_to_fp %f16_inp, rounding = rtne : tensor<32x32xf16, #dot_operand> -> tensor<32x32xf8E5M2, #dot_operand> ++ tt.return ++ } ++} diff --git a/third_party/xla/third_party/triton/temporary/i4_to_bf16.patch b/third_party/xla/third_party/triton/temporary/i4_to_bf16.patch new file mode 100644 index 00000000000000..6afe4ee3b7157f --- /dev/null +++ b/third_party/xla/third_party/triton/temporary/i4_to_bf16.patch @@ -0,0 +1,129 @@ + +--- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp 2024-09-25 10:13:59.000000000 -0700 ++++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ElementwiseOpToLLVM.cpp 2024-10-07 00:38:03.000000000 -0700 +@@ -264,7 +264,8 @@ + outVecWidthBits](Location loc, ConversionPatternRewriter &rewriter, + const SmallVector &v) -> SmallVector { + int numElements = v.size(); +- assert(numElements == 4 || numElements == 2 && "invalid vector size"); ++ assert(numElements == 8 || numElements == 4 || ++ numElements == 2 && "invalid vector size"); + + auto ctx = rewriter.getContext(); + int inBitwidth = inType.getIntOrFloatBitWidth(); +@@ -669,6 +670,115 @@ + PatternBenefit benefit = patternBenefitDefault) + : ElementwiseOpConversionBase(typeConverter, axisAnalysisPass, benefit), + computeCapability(computeCapability) {} ++ ++ LogicalResult matchAndRewrite( ++ arith::SIToFPOp op, OpAdaptor adaptor, ++ ConversionPatternRewriter &rewriter) const override { ++ if (succeeded(matchAndRewriteInt4ToBf16Conversion(op, rewriter))) { ++ return success(); ++ } ++ return Base::matchAndRewrite(op, adaptor, rewriter); ++ } ++ ++ // Matches subgraph of convert 8xi4 to 8xbf16 and rewrites it to inline PTX. ++ LogicalResult matchAndRewriteInt4ToBf16Conversion( ++ arith::SIToFPOp op, ConversionPatternRewriter &rewriter) const { ++ if (computeCapability < 90) return failure(); ++ Type inElemTy = getElementType(op.getIn()); ++ Type outElemTy = getElementType(op.getOut()); ++ if (!inElemTy.isInteger(8) || !outElemTy.isBF16()) return failure(); ++ FailureOr unpack = matchInt4Unpack(op.getIn()); ++ if (failed(unpack)) return failure(); ++ ++ Location loc = op.getLoc(); ++ Value src = rewriter.getRemappedValue(unpack.value()); ++ auto structTy = dyn_cast(src.getType()); ++ if (!structTy || structTy.getBody().size() % 4 != 0) return failure(); ++ auto isInt8 = [](Type type) { return type.isInteger(8); }; ++ if (!all_of(structTy.getBody(), isInt8)) return failure(); ++ ++ const LLVMTypeConverter *typeConverter = getTypeConverter(); ++ assert(inElemTy == typeConverter->convertType(inElemTy)); ++ assert(outElemTy == typeConverter->convertType(outElemTy)); ++ ++ const std::string S4_to_Bf16_sm90 = R"({ ++ .reg .b32 r<4>, mi, mf; ++ mov.b32 mi, 0x43404340 - 0x00080008; ++ mov.b32 mf, 0x43404340; ++ // Shift 4-bit inputs to 16-bit boundary. ++ shr.u32 r1, $4, 4; ++ shr.u32 r2, $4, 8; ++ shr.u32 r3, $4, 12; ++ // Sign-extend from 4 bits is equivalent to (x ^ 0x8) - 0x8. ++ lop3.b32 r0, $4, 0x000f000f, 0x00080008, (0xf0 & 0xcc) ^ 0xaa; ++ lop3.b32 r1, r1, 0x000f000f, 0x00080008, (0xf0 & 0xcc) ^ 0xaa; ++ lop3.b32 r2, r2, 0x000f000f, 0x00080008, (0xf0 & 0xcc) ^ 0xaa; ++ lop3.b32 r3, r3, 0x000f000f, 0x00080008, (0xf0 & 0xcc) ^ 0xaa; ++ // Interger-add magic number (minus bias from sign-extend above). ++ add.s16x2 r0, r0, mi; ++ add.s16x2 r1, r1, mi; ++ add.s16x2 r2, r2, mi; ++ add.s16x2 r3, r3, mi; ++ // Float-subtract magic number. ++ sub.bf16x2 r0, r0, mf; ++ sub.bf16x2 r1, r1, mf; ++ sub.bf16x2 r2, r2, mf; ++ sub.bf16x2 r3, r3, mf; ++ // Shuffle results into correct order. ++ prmt.b32 $0, r1, r0, 0x5410; ++ prmt.b32 $1, r3, r2, 0x5410; ++ prmt.b32 $2, r1, r0, 0x7632; ++ prmt.b32 $3, r3, r2, 0x7632; ++ })"; ++ ++ SmallVector resultVals; ++ SmallVector unpackedVals = unpackLLElements(loc, src, rewriter); ++ auto cvtFunc = makeConverterFromPtx(S4_to_Bf16_sm90, inElemTy, outElemTy); ++ for (ValueRange operands = unpackedVals; !operands.empty(); ++ operands = operands.drop_front(4)) { ++ SmallVector inVals = { ++ operands[0], operands[1], operands[2], operands[3], ++ // Repeat operands so that cvtFunc produces 8 outputs. ++ operands[0], operands[1], operands[2], operands[3]}; ++ auto outVals = cvtFunc(loc, rewriter, inVals); ++ assert(inVals.size() == outVals.size()); ++ resultVals.append(outVals.begin(), outVals.end()); ++ } ++ ++ resultVals = reorderValues(resultVals, op.getIn().getType(), op.getType()); ++ resultVals = maybeDeduplicate(op, resultVals); ++ Value view = ++ packLLElements(loc, typeConverter, resultVals, rewriter, op.getType()); ++ rewriter.replaceOp(op, view); ++ ++ return success(); ++ } ++ ++ // Returns the source if value is the result of an 2xi4 -> 2xi8 unpack ++ // sequence. ++ static FailureOr matchInt4Unpack(Value value) { ++ auto reshape = value.getDefiningOp(); ++ if (!reshape) return failure(); ++ auto join = reshape.getSrc().getDefiningOp(); ++ if (!join) return failure(); ++ auto shrHi = join.getLhs().getDefiningOp(); ++ if (!shrHi || !isConst4(shrHi.getRhs())) return failure(); ++ auto shrLo = join.getRhs().getDefiningOp(); ++ if (!shrLo || !isConst4(shrLo.getRhs())) return failure(); ++ auto shlLo = shrLo.getLhs().getDefiningOp(); ++ if (!shlLo || !isConst4(shlLo.getRhs())) return failure(); ++ if (shrHi.getLhs() != shlLo.getLhs()) return failure(); ++ return shrHi.getLhs(); ++ } ++ ++ // Returns true if the value is equal to 4. ++ static bool isConst4(Value v) { ++ auto constOp = v.getDefiningOp(); ++ if (!constOp) return false; ++ auto attr = mlir::dyn_cast(constOp.getValue()); ++ if (!attr || !attr.isSplat()) return false; ++ return attr.getSplatValue().getLimitedValue() == 4; ++ }; + + SmallVector createDestOps(arith::SIToFPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter, diff --git a/third_party/xla/third_party/triton/temporary/prefetch.patch b/third_party/xla/third_party/triton/temporary/prefetch.patch new file mode 100644 index 00000000000000..57033b1ac972b1 --- /dev/null +++ b/third_party/xla/third_party/triton/temporary/prefetch.patch @@ -0,0 +1,27 @@ +# b/370665038 These seeem to be real bugs that should be upstreamed. +diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +--- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp ++++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +@@ -116,7 +116,7 @@ Value Prefetcher::generatePrefetch(Value + // opIdx: 0 => a, 1 => b + auto type = cast(v.getType()); + SmallVector shape{type.getShape().begin(), type.getShape().end()}; +- SmallVector offset{0, 0}; ++ SmallVector offset(shape.size(), 0); + Type elementType = type.getElementType(); + + // k => (prefetchWidth, k - prefetchWidth) +@@ -205,6 +205,13 @@ LogicalResult Prefetcher::initialize() { + if (srcType.isInteger(1)) + break; + } ++ // Propagation through ExpandDims is currently not supported. This blindly ++ // replaces the encoding with dot encoding & but ExpandDims requires a ++ // SliceEncoding. This could be rewritten to support it somehow, but I ++ // don't think it's trivial & it's currently crashing. ++ if (isa(op)) { ++ break; ++ } + rets.push_back(op->getOperand(0)); + if (auto cvt = dyn_cast(op)) { + foundConvertFromShared = true; diff --git a/third_party/xla/third_party/triton/temporary/series.bzl b/third_party/xla/third_party/triton/temporary/series.bzl index 4fa55269e3323c..d6b1a6a31f783f 100644 --- a/third_party/xla/third_party/triton/temporary/series.bzl +++ b/third_party/xla/third_party/triton/temporary/series.bzl @@ -14,5 +14,9 @@ those to this list. """ temporary_patch_list = [ + "//third_party/triton:temporary/fix_left_shift_overflow.patch", + "//third_party/triton:temporary/prefetch.patch", + "//third_party/triton:temporary/i4_to_bf16.patch", + "//third_party/triton:temporary/further_mixed_precision_fix.patch", # Add new patches just above this line ] diff --git a/third_party/xla/third_party/triton/workspace.bzl b/third_party/xla/third_party/triton/workspace.bzl index 911fe7493783cf..3d5f6f99bf73d4 100644 --- a/third_party/xla/third_party/triton/workspace.bzl +++ b/third_party/xla/third_party/triton/workspace.bzl @@ -8,8 +8,8 @@ load("//third_party/triton/xla_extensions:series.bzl", "extensions_files_patch_l def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl673813747" - TRITON_SHA256 = "3e901c1b441407b1b7ac601092f64a9141571879b00a1ff54437c8e9370a365f" + TRITON_COMMIT = "cl680473520" + TRITON_SHA256 = "c18fa65138b8c566b2f0299ebde4242f0f30c3625741a17de73ed5b1990cdabb" tf_http_archive( name = "triton", sha256 = TRITON_SHA256, diff --git a/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch b/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch index 8f613badb53988..add4a2273ae715 100644 --- a/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch +++ b/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch @@ -280,7 +280,7 @@ index d74e0a224..4e45f7c4c 100644 static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, Value insertIdx, Value extractIdx, tt::CoarseSchedule &schedule, -@@ -236,19 +240,28 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { +@@ -235,17 +239,25 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { } else { if (!isa(user)) return std::nullopt; @@ -296,10 +296,7 @@ index d74e0a224..4e45f7c4c 100644 + unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); + tempAttr = ttg::SharedEncodingAttr::get( + val.getContext(), cast(enc), -+ srcTy.getShape(), ttg::getOrder(srcTy.getEncoding()), -+ ttg::getCTALayout(srcTy.getEncoding()), -+ srcTy.getElementType().getIntOrFloatBitWidth(), -+ /*needTrans=*/false); ++ srcTy.getShape(), order, CTALayout, bitWidth, /*needTrans=*/false); + } else if (isa(enc)) { + auto srcTy = cast(val.getType()); + tempAttr = ttg::SharedEncodingAttr::get( @@ -313,15 +310,13 @@ index d74e0a224..4e45f7c4c 100644 - auto order = ttg::getOrder(srcTy.getEncoding()); - unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); - tempAttr = ttg::SharedEncodingAttr::get( -- val.getContext(), dotOpEnc, srcTy.getShape(), -- ttg::getOrder(srcTy.getEncoding()), -- ttg::getCTALayout(srcTy.getEncoding()), -- srcTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false); +- val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout, +- bitWidth, /*needTrans=*/false); + } } // Check that the shared encodings needed by the users are compatible. - if (attr != nullptr && attr != tempAttr) { -@@ -357,7 +370,7 @@ loadOpsToIndirectionLevelAndUse(scf::ForOp forOp) { + if (attr != nullptr) +@@ -352,7 +364,7 @@ loadOpsToIndirectionLevelAndUse(scf::ForOp forOp) { }; for (Operation &op : forOp.getBody()->without_terminator()) { @@ -330,7 +325,7 @@ index d74e0a224..4e45f7c4c 100644 continue; seen.clear(); dfs(&op, 0, &op); -@@ -434,7 +447,7 @@ assignMemoryLayouts(llvm::SmallVector> +@@ -429,7 +441,7 @@ assignMemoryLayouts(llvm::SmallVector> continue; } @@ -339,15 +334,6 @@ index d74e0a224..4e45f7c4c 100644 loadInfo.usedByDot = true; if (loadIsMMAv3(op)) { loadInfo.loadIsMMAV3 = true; -@@ -460,7 +473,7 @@ assignMemoryLayouts(llvm::SmallVector> - // The codegen bug is caught by an assertion, so if you think you've - // fixed it, feel free to delete this code and see if the assert still - // fails. :) -- if (!loadInfo.sharedEncoding) { -+ if (dot && !loadInfo.sharedEncoding) { - if (auto dotEnc = dyn_cast( - dot.getResult().getType().getEncoding())) { - auto loadTy = cast(op->getResultTypes()[0]); diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp index fb0e7f6fd..37795c20c 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp diff --git a/third_party/xla/third_party/tsl/.bazelrc b/third_party/xla/third_party/tsl/.bazelrc index fcbc9ff2772db4..ebef2e1e7a6c48 100644 --- a/third_party/xla/third_party/tsl/.bazelrc +++ b/third_party/xla/third_party/tsl/.bazelrc @@ -150,6 +150,8 @@ build:android_x86_64 --fat_apk_cpu=x86_64 # Build everything statically for Android since all static libs are later # bundled together into a single .so for deployment. build:android --dynamic_mode=off +# TODO(belitskiy): Remove once on Clang 20. +build:android --define=xnn_enable_avxvnniint8=false # Sets the default Apple platform to macOS. build:macos --apple_platform_type=macos @@ -212,6 +214,7 @@ build:mkl_aarch64 -c opt # Config setting to build oneDNN with Compute Library for the Arm Architecture (ACL). # with Eigen threadpool support build:mkl_aarch64_threadpool --define=build_with_mkl_aarch64=true +build:mkl_aarch64_threadpool --@compute_library//:openmp=false build:mkl_aarch64_threadpool -c opt # CUDA: This config refers to building CUDA op kernels with nvcc. @@ -240,6 +243,8 @@ build:cuda_clang --copt=-Qunused-arguments # major release. Example: sm_80 kernels can run on sm_89 GPUs but # not on sm_90 GPUs. compute_80 kernels though can also run on sm_90 GPUs. build:cuda_clang --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90" +# Permit newer CUDA versions than Clang is aware of +build:cuda_clang --copt="-Wno-unknown-cuda-version" # Set lld as the linker. build:cuda_clang --host_linkopt="-fuse-ld=lld" build:cuda_clang --host_linkopt="-lm" @@ -254,10 +259,11 @@ build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-1 build:cuda_clang_official --crosstool_top="@local_config_cuda//crosstool:toolchain" # Build with nvcc for CUDA and clang for host -build:nvcc_clang --config=cuda -build:nvcc_clang --action_env=TF_NVCC_CLANG="1" -build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc - +build:cuda_nvcc --config=cuda +build:cuda_nvcc --action_env=TF_NVCC_CLANG="1" +build:cuda_nvcc --@local_config_cuda//:cuda_compiler=nvcc +# Old config for backward compatibility +build:nvcc_clang --config=cuda_nvcc # Debug config build:dbg -c dbg @@ -327,8 +333,6 @@ build:linux --copt="-Werror=unused-result" # Add switch as an error on Linux. build:linux --copt="-Wswitch" build:linux --copt="-Werror=switch" -# Required for building with clang -build:linux --copt="-Wno-error=unused-but-set-variable" # Linux ARM64 specific options build:linux_arm64 --copt="-mtune=generic" --copt="-march=armv8-a" --copt="-O3" @@ -550,7 +554,7 @@ build:rbe_linux_cuda --config=rbe_linux_cpu build:rbe_linux_cuda --repo_env=REMOTE_GPU_TESTING=1 build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda -build:rbe_linux_cuda_nvcc --config=nvcc_clang +build:rbe_linux_cuda_nvcc --config=cuda_nvcc build:rbe_linux_cuda_nvcc --repo_env TF_NCCL_USE_STUB=1 build:rbe_win_base --config=rbe_base @@ -738,27 +742,27 @@ build:linux_libtensorflow_build --config=cuda_wheel -- //tensorflow/tools/lib_pa test:linux_cpu_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cpu_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cpu_wheel_test --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cpu_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cpu_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/tools/pip_package:import_api_packages_test # CUDA WHEEL test:linux_cuda_wheel_test_filters --test_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --build_tag_filters=gpu,requires-gpu,-no_gpu,-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-no_cuda11,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_cuda_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_cuda_wheel_test --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... +test:linux_cuda_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_cuda_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/tools/pip_package:import_api_packages_test # ARM64 WHEEL test:linux_arm64_wheel_test_filters --test_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --build_tag_filters=-no_oss,-no_aarch64,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only,-no_oss_py38,-no_oss_py39,-no_oss_py310 test:linux_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:linux_arm64_wheel_test --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test +test:linux_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=linux_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/core/grappler/optimizers:auto_mixed_precision_test_cpu -//tensorflow/core/grappler/optimizers:remapper_test_cpu -//tensorflow/core/kernels/image:resize_bicubic_op_test -//tensorflow/tools/pip_package:import_api_packages_test # MACOS ARM64 WHEEL test:macos_arm64_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test,-no_mac_arm64,-no_aarch64 test:macos_arm64_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_arm64_wheel_test --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +test:macos_arm64_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_arm64_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/tools/pip_package:import_api_packages_test # MACOS X86 WHEEL test:macos_x86_wheel_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-no_oss_py38,-no_oss_py39,-no_oss_py310,-nomac,-no_mac,-mac_excluded,-v1only,-gpu,-tpu,-benchmark-test test:macos_x86_wheel_test_filters --test_lang_filters=py --test_size_filters=small,medium -test:macos_x86_wheel_test --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... +test:macos_x86_wheel_test --@local_tsl//third_party/py:wheel_dependency=true --config=macos_x86_wheel_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -//tensorflow/go/... -//tensorflow/java/... -//tensorflow/compiler/aot/... -//tensorflow/tools/pip_package:import_api_packages_test # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. diff --git a/third_party/xla/third_party/tsl/opensource_only.files b/third_party/xla/third_party/tsl/opensource_only.files index b96a4bc89722c5..265b15c67ea2a0 100644 --- a/third_party/xla/third_party/tsl/opensource_only.files +++ b/third_party/xla/third_party/tsl/opensource_only.files @@ -108,6 +108,7 @@ third_party/py/python_init_repositories.bzl: third_party/py/python_init_rules.bzl: third_party/py/python_init_toolchains.bzl: third_party/py/python_repo.bzl: +third_party/py/python_wheel_library.bzl: third_party/pybind11.BUILD: third_party/pybind11_bazel/BUILD: third_party/python_runtime/BUILD: diff --git a/third_party/xla/third_party/tsl/third_party/gpus/check_cuda_libs.py b/third_party/xla/third_party/tsl/third_party/gpus/check_cuda_libs.py index b1a10a86b9aac6..a1d47efcc93a81 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/check_cuda_libs.py +++ b/third_party/xla/third_party/tsl/third_party/gpus/check_cuda_libs.py @@ -27,16 +27,10 @@ import os import os.path import platform +import shutil import subprocess import sys -# pylint: disable=g-import-not-at-top,g-importing-member -try: - from shutil import which -except ImportError: - from distutils.spawn import find_executable as which -# pylint: enable=g-import-not-at-top,g-importing-member - class ConfigError(Exception): pass @@ -59,7 +53,7 @@ def check_cuda_lib(path, check_soname=True): """ if not os.path.isfile(path): raise ConfigError("No library found under: " + path) - objdump = which("objdump") + objdump = shutil.which("objdump") if check_soname and objdump is not None and not _is_windows(): # Decode is necessary as in py3 the return type changed from str to bytes output = subprocess.check_output([objdump, "-p", path]).decode("utf-8") diff --git a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/BUILD.rocm.tpl b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/BUILD.rocm.tpl index a742cfcd208ec1..03a9dde83cfddc 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/BUILD.rocm.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/BUILD.rocm.tpl @@ -82,19 +82,18 @@ cc_toolchain_config( "-fdata-sections", ], dbg_compile_flags = ["-g"], - cxx_flags = ["-std=c++14"], + cxx_flags = ["-std=c++17"], link_flags = [ "-fuse-ld=gold", "-Wl,-no-as-needed", "-Wl,-z,relro,-z,now", - "-pass-exit-codes", + ], + link_libs = [ "-lstdc++", "-lm", ], - link_libs = [], opt_link_flags = [], unfiltered_compile_flags = [ - "-fno-canonical-system-headers", "-Wno-builtin-macro-redefined", "-D__DATE__=\"redacted\"", "-D__TIMESTAMP__=\"redacted\"", diff --git a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl index f15df1f974a77b..b17b2c9b0e23b5 100755 --- a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl @@ -24,8 +24,10 @@ import pipes # Template values set by rocm_configure.bzl. CPU_COMPILER = ('%{cpu_compiler}') +HOST_COMPILER_PATH = ('%{host_compiler_path}') HIPCC_PATH = '%{hipcc_path}' +PREFIX_DIR = os.path.dirname(HOST_COMPILER_PATH) HIPCC_ENV = '%{hipcc_env}' HIP_RUNTIME_PATH = '%{hip_runtime_path}' HIP_RUNTIME_LIBRARY = '%{hip_runtime_library}' @@ -75,6 +77,7 @@ def GetHostCompilerOptions(argv): parser.add_argument('--sysroot', nargs=1) parser.add_argument('-g', nargs='*', action='append') parser.add_argument('-fno-canonical-system-headers', action='store_true') + parser.add_argument('-no-canonical-prefixes', action='store_true') parser.add_argument('--genco', action='store_true') args, _ = parser.parse_known_args(argv) @@ -87,7 +90,7 @@ def GetHostCompilerOptions(argv): opts += ' -iquote ' + ' -iquote '.join(sum(args.iquote, [])) if args.g: opts += ' -g' + ' -g'.join(sum(args.g, [])) - if args.fno_canonical_system_headers: + if args.fno_canonical_system_headers or args.no_canonical_prefixes: opts += ' -no-canonical-prefixes' if args.sysroot: opts += ' --sysroot ' + args.sysroot[0] diff --git a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/hipcc_cc_toolchain_config.bzl.tpl b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/hipcc_cc_toolchain_config.bzl.tpl index e0541defa34687..e5a942b66c17fc 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/hipcc_cc_toolchain_config.bzl.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/hipcc_cc_toolchain_config.bzl.tpl @@ -1046,7 +1046,6 @@ def _impl(ctx): flag_group( flags = [ "-no-canonical-prefixes", - "-fno-canonical-system-headers", ] ), ], diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/build_defs.bzl.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/build_defs.bzl.tpl index 2faabefe081f4b..6c1b68ffb77bcf 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/build_defs.bzl.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/build_defs.bzl.tpl @@ -149,11 +149,14 @@ def cuda_header_library( **kwargs ) -def cuda_library(copts = [], tags = [],**kwargs): +def cuda_library(copts = [], tags = [], deps = [], **kwargs): """Wrapper over cc_library which adds default CUDA options.""" native.cc_library( copts = cuda_default_copts() + copts, tags = tags + ["gpu"], + deps = deps + if_cuda_is_configured([ + "@local_config_cuda//cuda:implicit_cuda_headers_dependency", + ]), **kwargs ) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/BUILD.tpl index 5d9a9da3c967d8..58c4638dd55c3f 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/BUILD.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/BUILD.tpl @@ -69,6 +69,16 @@ cc_library( ":nvjitlink_headers"], ) +# This target is needed by the `cuda_library` rule. We can't implicitly +# depend on `:cuda_headers` directly since the user may explicit depend +# on `:cuda_headers` and duplicated dependencies are not allowed in Bazel. +# There is also no good way to deduplicate dependencies, but an alias works +# just fine. +alias( + name = "implicit_cuda_headers_dependency", + actual = ":cuda_headers", +) + cc_library( name = "cudart_static", srcs = ["@cuda_cudart//:static"], @@ -79,6 +89,11 @@ cc_library( ], ) +alias( + name = "cuda_runtime", + actual = ":cudart_static", +) + alias( name = "cuda_driver", actual = select({ diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl index 11b32cdbb71c56..ecc99f06455614 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl @@ -219,6 +219,10 @@ def _create_libcuda_symlinks( repository_ctx.symlink(nvidia_driver_path, "lib/libcuda.so.1") repository_ctx.symlink("lib/libcuda.so.1", "lib/libcuda.so") +def _create_cuda_header_symlinks(repository_ctx): + if repository_ctx.name == "cuda_nvcc": + repository_ctx.symlink("../cuda_cudart/include/cuda.h", "include/cuda.h") + def use_local_path(repository_ctx, local_path, dirs): # buildifier: disable=function-docstring-args """Creates repository using local redistribution paths.""" @@ -339,6 +343,7 @@ def _use_downloaded_cuda_redistribution(repository_ctx): repository_ctx, lib_name_to_version_dict, ) + _create_cuda_header_symlinks(repository_ctx) repository_ctx.file("version.txt", major_version) def _cuda_repo_impl(repository_ctx): diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl index 6934b75b47852d..89516f869ad07b 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl @@ -58,6 +58,10 @@ CUDA_REDIST_JSON_DICT = { "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.6.0.json", "87740b01676b3d18982982ab96ec7fa1a626d03a96df070a6b0f258d01ff5fab", ], + "12.6.1": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.6.1.json", + "22ddfeb81a6f9cee4a708a2e3b4db1c36c7db0a1daa1f33f9c7f2f12a1e790de", + ], } CUDNN_REDIST_JSON_DICT = { @@ -97,20 +101,22 @@ CUDNN_REDIST_JSON_DICT = { "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.4.0.json", "6eeaafc5cc3d4bb2f283e6298e4c55d4c59d7c83c5d9fd8721a2c0e55aee4e54", ], + "9.5.0": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.5.0.json", + "3939f0533fdd0d3aa7edd1ac358d43da18e438e5d8f39c3c15bb72519bad7fb5", + ], } -# The versions are different for x86 and aarch64 architectures because only -# NCCL release versions 2.20.3 and 2.20.5 have the wheels for aarch64. CUDA_12_NCCL_WHEEL_DICT = { "x86_64-unknown-linux-gnu": { - "version": "2.21.5", - "url": "https://files.pythonhosted.org/packages/df/99/12cd266d6233f47d00daf3a72739872bdc10267d0383508b0b9c84a18bb6/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", - "sha256": "8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0", + "version": "2.23.4", + "url": "https://files.pythonhosted.org/packages/ed/1f/6482380ec8dcec4894e7503490fc536d846b0d59694acad9cf99f27d0e7d/nvidia_nccl_cu12-2.23.4-py3-none-manylinux2014_x86_64.whl", + "sha256": "b097258d9aab2fa9f686e33c6fe40ae57b27df60cedbd15d139701bb5509e0c1", }, "aarch64-unknown-linux-gnu": { - "version": "2.20.5", - "url": "https://files.pythonhosted.org/packages/c1/bb/d09dda47c881f9ff504afd6f9ca4f502ded6d8fc2f572cacc5e39da91c28/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", - "sha256": "1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01", + "version": "2.23.4", + "url": "https://files.pythonhosted.org/packages/c8/3a/0112397396dec37ffc8edd7836d48261b4d14ca60ec8ed7bc857cce1d916/nvidia_nccl_cu12-2.23.4-py3-none-manylinux2014_aarch64.whl", + "sha256": "aa946c8327e22ced28e7cef508a334673abc42064ec85f02d005ba1785ea4cec", }, } @@ -134,12 +140,14 @@ CUDA_NCCL_WHEELS = { "12.5.0": CUDA_12_NCCL_WHEEL_DICT, "12.5.1": CUDA_12_NCCL_WHEEL_DICT, "12.6.0": CUDA_12_NCCL_WHEEL_DICT, + "12.6.1": CUDA_12_NCCL_WHEEL_DICT, } REDIST_VERSIONS_TO_BUILD_TEMPLATES = { "nvidia_driver": { "repo_name": "cuda_driver", "version_to_template": { + "560": "//third_party/gpus/cuda/hermetic:cuda_driver.BUILD.tpl", "555": "//third_party/gpus/cuda/hermetic:cuda_driver.BUILD.tpl", "550": "//third_party/gpus/cuda/hermetic:cuda_driver.BUILD.tpl", "545": "//third_party/gpus/cuda/hermetic:cuda_driver.BUILD.tpl", diff --git a/third_party/xla/third_party/tsl/third_party/gpus/find_cuda_config.py b/third_party/xla/third_party/tsl/third_party/gpus/find_cuda_config.py index 68623bf671da71..c04dace79fe599 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/find_cuda_config.py +++ b/third_party/xla/third_party/tsl/third_party/gpus/find_cuda_config.py @@ -56,21 +56,15 @@ tf__library_dir: ... """ +import glob import io import os -import glob import platform import re +import shutil import subprocess import sys -# pylint: disable=g-import-not-at-top -try: - from shutil import which -except ImportError: - from distutils.spawn import find_executable as which -# pylint: enable=g-import-not-at-top - class ConfigError(Exception): pass @@ -139,7 +133,7 @@ def _get_ld_config_paths(): """Returns all directories from 'ldconfig -p'.""" if not _is_linux(): return [] - ldconfig_path = which("ldconfig") or "/sbin/ldconfig" + ldconfig_path = shutil.which("ldconfig") or "/sbin/ldconfig" output = subprocess.check_output([ldconfig_path, "-p"]) pattern = re.compile(".* => (.*)") result = set() diff --git a/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl b/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl index fb63d4db886c1c..e56810a4cbbf8b 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl +++ b/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl @@ -3,7 +3,11 @@ `rocm_configure` depends on the following environment variables: * `TF_NEED_ROCM`: Whether to enable building with ROCm. - * `GCC_HOST_COMPILER_PATH`: The GCC host compiler path + * `GCC_HOST_COMPILER_PATH`: The GCC host compiler path. + * `TF_ROCM_CLANG`: Whether to use clang for C++ and HIPCC for ROCm compilation. + * `TF_SYSROOT`: The sysroot to use when compiling. + * `CLANG_COMPILER_PATH`: The clang compiler path that will be used for + host code compilation if TF_ROCM_CLANG is 1. * `ROCM_PATH`: The path to the ROCm toolkit. Default is `/opt/rocm`. * `TF_ROCM_AMDGPU_TARGETS`: The AMDGPU targets. """ @@ -39,6 +43,8 @@ load( _GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH" _GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX" +_CLANG_COMPILER_PATH = "CLANG_COMPILER_PATH" +_TF_SYSROOT = "TF_SYSROOT" _ROCM_TOOLKIT_PATH = "ROCM_PATH" _TF_ROCM_AMDGPU_TARGETS = "TF_ROCM_AMDGPU_TARGETS" _TF_ROCM_CONFIG_REPO = "TF_ROCM_CONFIG_REPO" @@ -72,12 +78,15 @@ def verify_build_defines(params): ".", ) -def find_cc(repository_ctx): +def find_cc(repository_ctx, use_rocm_clang): """Find the C++ compiler.""" - # Return a dummy value for GCC detection here to avoid error - target_cc_name = "gcc" - cc_path_envvar = _GCC_HOST_COMPILER_PATH + if use_rocm_clang: + target_cc_name = "clang" + cc_path_envvar = _CLANG_COMPILER_PATH + else: + target_cc_name = "gcc" + cc_path_envvar = _GCC_HOST_COMPILER_PATH cc_name = target_cc_name cc_name_from_env = get_host_environ(repository_ctx, cc_path_envvar) @@ -99,24 +108,26 @@ def _cxx_inc_convert(path): path = path.strip() return path -def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp): +def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp, tf_sysroot): """Compute the list of default C or C++ include directories.""" if lang_is_cpp: lang = "c++" else: lang = "c" + sysroot = [] + if tf_sysroot: + sysroot += ["--sysroot", tf_sysroot] # TODO: We pass -no-canonical-prefixes here to match the compiler flags, # but in rocm_clang CROSSTOOL file that is a `feature` and we should # handle the case when it's disabled and no flag is passed result = raw_exec(repository_ctx, [ cc, - "-no-canonical-prefixes", "-E", "-x" + lang, "-", "-v", - ]) + ] + sysroot) stderr = err_out(result) index1 = stderr.find(_INC_DIR_MARKER_BEGIN) if index1 == -1: @@ -138,14 +149,24 @@ def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp): for p in inc_dirs.split("\n") ] -def get_cxx_inc_directories(repository_ctx, cc): +def get_cxx_inc_directories(repository_ctx, cc, tf_sysroot): """Compute the list of default C and C++ include directories.""" # For some reason `clang -xc` sometimes returns include paths that are # different from the ones from `clang -xc++`. (Symlink and a dir) # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists - includes_cpp = _get_cxx_inc_directories_impl(repository_ctx, cc, True) - includes_c = _get_cxx_inc_directories_impl(repository_ctx, cc, False) + includes_cpp = _get_cxx_inc_directories_impl( + repository_ctx, + cc, + True, + tf_sysroot, + ) + includes_c = _get_cxx_inc_directories_impl( + repository_ctx, + cc, + False, + tf_sysroot, + ) includes_cpp_set = depset(includes_cpp) return includes_cpp + [ @@ -207,6 +228,7 @@ def _rocm_include_path(repository_ctx, rocm_config, bash_bin): inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/16.0.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17.0.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17/include") + inc_dirs.append(rocm_toolkit_path + "/lib/llvm/lib/clang/17/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/18/include") if int(rocm_config.rocm_version_number) >= 60200: inc_dirs.append(rocm_toolkit_path + "/lib/llvm/lib/clang/18/include") @@ -476,7 +498,7 @@ def _create_dummy_repository(repository_ctx): "%{hipblas_lib}": _lib_name("hipblas"), "%{miopen_lib}": _lib_name("miopen"), "%{rccl_lib}": _lib_name("rccl"), - "%{hipfft_or_rocfft}": _lib_name("hipfft"), + "%{hipfft_or_rocfft}": "hipfft", "%{hipfft_or_rocfft_lib}": _lib_name("hipfft"), "%{hiprand_lib}": _lib_name("hiprand"), "%{hipsparse_lib}": _lib_name("hipsparse"), @@ -539,8 +561,18 @@ def _genrule(src_dir, genrule_name, command, outs): ")\n" ) +def _flag_enabled(repository_ctx, flag_name): + return get_host_environ(repository_ctx, flag_name) == "1" + +def _use_rocm_clang(repository_ctx): + # Returns the flag if we need to use clang for the host. + return _flag_enabled(repository_ctx, "TF_ROCM_CLANG") + +def _tf_sysroot(repository_ctx): + return get_host_environ(repository_ctx, _TF_SYSROOT, "") + def _compute_rocm_extra_copts(repository_ctx, amdgpu_targets): - amdgpu_target_flags = ["--amdgpu-target=" + + amdgpu_target_flags = ["--offload-arch=" + amdgpu_target for amdgpu_target in amdgpu_targets] return str(amdgpu_target_flags) @@ -674,6 +706,10 @@ def _create_local_rocm_repository(repository_ctx): hiprand_include + rocrand_include), } + + is_rocm_clang = _use_rocm_clang(repository_ctx) + tf_sysroot = _tf_sysroot(repository_ctx) + if rocm_libs["hipblaslt"] != None: repository_dict["%{hipblaslt_lib}"] = rocm_libs["hipblaslt"].file_name @@ -689,24 +725,36 @@ def _create_local_rocm_repository(repository_ctx): # Set up crosstool/ - cc = find_cc(repository_ctx) + cc = find_cc(repository_ctx, is_rocm_clang) + host_compiler_includes = get_cxx_inc_directories( + repository_ctx, + cc, + tf_sysroot, + ) - host_compiler_includes = get_cxx_inc_directories(repository_ctx, cc) - - host_compiler_prefix = get_host_environ(repository_ctx, _GCC_HOST_COMPILER_PREFIX, "/usr/bin") + # host_compiler_includes = get_cxx_inc_directories(repository_ctx, cc) rocm_defines = {} - + rocm_defines["%{builtin_sysroot}"] = tf_sysroot + rocm_defines["%{compiler}"] = "unknown" + if is_rocm_clang: + rocm_defines["%{compiler}"] = "clang" + host_compiler_prefix = get_host_environ(repository_ctx, _GCC_HOST_COMPILER_PREFIX, "/usr/bin") rocm_defines["%{host_compiler_prefix}"] = host_compiler_prefix + rocm_defines["%{linker_bin_path}"] = rocm_config.rocm_toolkit_path + host_compiler_prefix + rocm_defines["%{extra_no_canonical_prefixes_flags}"] = "" + rocm_defines["%{unfiltered_compile_flags}"] = "" + rocm_defines["%{rocm_hipcc_files}"] = "[]" - rocm_defines["%{linker_bin_path}"] = rocm_config.rocm_toolkit_path + "/hcc/compiler/bin" - - # For gcc, do not canonicalize system header paths; some versions of gcc - # pick the shortest possible path for system includes when creating the - # .d file - given that includes that are prefixed with "../" multiple - # time quickly grow longer than the root of the tree, this can lead to - # bazel's header check failing. - rocm_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-fno-canonical-system-headers\"" + if is_rocm_clang: + rocm_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-no-canonical-prefixes\"" + else: + # For gcc, do not canonicalize system header paths; some versions of gcc + # pick the shortest possible path for system includes when creating the + # .d file - given that includes that are prefixed with "../" multiple + # time quickly grow longer than the root of the tree, this can lead to + # bazel's header check failing. + rocm_defines["%{extra_no_canonical_prefixes_flags}"] = "\"-fno-canonical-system-headers\"" rocm_defines["%{unfiltered_compile_flags}"] = to_list_of_strings([ "-DTENSORFLOW_USE_ROCM=1", @@ -834,6 +882,7 @@ _ENVIRONS = [ _GCC_HOST_COMPILER_PATH, _GCC_HOST_COMPILER_PREFIX, "TF_NEED_ROCM", + "TF_ROCM_CLANG", "TF_NEED_CUDA", # Needed by the `if_gpu_is_configured` macro _ROCM_TOOLKIT_PATH, _TF_ROCM_AMDGPU_TARGETS, diff --git a/third_party/xla/third_party/tsl/third_party/nccl/archive.BUILD b/third_party/xla/third_party/tsl/third_party/nccl/archive.BUILD index 1f4b58f47e379c..bfbde6cf22eeff 100644 --- a/third_party/xla/third_party/tsl/third_party/nccl/archive.BUILD +++ b/third_party/xla/third_party/tsl/third_party/nccl/archive.BUILD @@ -22,9 +22,9 @@ exports_files(["LICENSE.txt"]) NCCL_MAJOR = 2 -NCCL_MINOR = 21 +NCCL_MINOR = 23 -NCCL_PATCH = 5 +NCCL_PATCH = 4 NCCL_VERSION = NCCL_MAJOR * 10000 + NCCL_MINOR * 100 + NCCL_PATCH # e.g., 21605 diff --git a/third_party/xla/third_party/tsl/third_party/nccl/archive.patch b/third_party/xla/third_party/tsl/third_party/nccl/archive.patch index 2b4fa56a97e759..4fc2dbb7aded8a 100644 --- a/third_party/xla/third_party/tsl/third_party/nccl/archive.patch +++ b/third_party/xla/third_party/tsl/third_party/nccl/archive.patch @@ -1,35 +1,16 @@ -diff --git a/src/device/all_gather.h b/src/device/all_gather.h -index 809e8ae..57eab81 100644 ---- a/src/device/all_gather.h -+++ b/src/device/all_gather.h -@@ -296,7 +296,7 @@ struct RunWorkElement(scat); -+ prims.template process(scat); - } - } - return; -@@ -314,7 +314,7 @@ struct RunWorkElement(scat); -+ prims.template process(scat); - } - return; - } diff --git a/src/device/common.cu b/src/device/common.cu.cc similarity index 100% rename from src/device/common.cu rename to src/device/common.cu.cc +diff --git a/src/device/onerank.cu b/src/device/onerank.cu.cc +similarity index 100% +rename from src/device/onerank.cu +rename to src/device/onerank.cu.cc diff --git a/src/device/common.h b/src/device/common.h -index d8581d3..09ac3b6 100644 --- a/src/device/common.h +++ b/src/device/common.h -@@ -15,7 +15,7 @@ - #define COLL_UNROLL (ncclCollUnroll()) +@@ -24,7 +24,7 @@ + #endif typedef void(*ncclDevFuncPtr_t)(); -extern __device__ ncclDevFuncPtr_t const ncclDevFuncTable[]; @@ -38,14 +19,16 @@ index d8581d3..09ac3b6 100644 struct ncclShmemGroup { ncclConnInfo *recvConns[NCCL_MAX_ARITY]; diff --git a/src/device/generate.py b/src/device/generate.py -index 43de85d..87cd677 100755 +index a0d2259..62d6014 100755 --- a/src/device/generate.py +++ b/src/device/generate.py -@@ -195,7 +195,7 @@ kernel_funcs = sorted(set(best_kernel(*fn) for fn in primary_funcs)) +@@ -194,8 +194,8 @@ kernel_funcs = sorted(set(best_kernel(*fn) for fn in primary_funcs)) + ################################################################################ - # Generate /device_table.cu +-# Generate /device_table.cu -with open(os.path.join(gensrc, "device_table.cu"), "w") as f: ++# Generate /device_table.cu.cc +with open(os.path.join(gensrc, "device_table.cu.cc"), "w") as f: out = f.write out('#include "common.h"\n') @@ -59,12 +42,11 @@ index 43de85d..87cd677 100755 index = 0 for fn in primary_funcs: sym = paste("_", "ncclDevFunc", *fn) -@@ -257,28 +257,45 @@ with open(os.path.join(gensrc, "host_table.cc"), "w") as f: +@@ -262,28 +262,43 @@ with open(os.path.join(gensrc, "host_table.cc"), "w") as f: # List of all kernel function pointers. out("extern int const ncclDevKernelCount = %d;\n" % len(kernel_funcs)) - out("extern void* const ncclDevKernelList[] = {\n") -+ index = 0 for kfn in kernel_funcs: cudart, _ = required_cuda(*kfn) @@ -88,7 +70,6 @@ index 43de85d..87cd677 100755 # Maps primary id to kernel function pointer. - out("extern void* const ncclDevKernelForFunc[] = {\n") -+ index = 0 for fn in primary_funcs: kfn = best_kernel(*fn) @@ -111,7 +92,7 @@ index 43de85d..87cd677 100755 index += 1 out("nullptr};\n") out("\n") -@@ -297,7 +314,7 @@ with open(os.path.join(gensrc, "host_table.cc"), "w") as f: +@@ -302,7 +317,7 @@ with open(os.path.join(gensrc, "host_table.cc"), "w") as f: # "coll" is reflected in the name: formally that no two funcs having different # coll's map to the same filename. def impl_filename(coll, redop, ty, algo, proto): @@ -120,7 +101,7 @@ index 43de85d..87cd677 100755 # Partition the functions and kernels to the .cu filenames. The partition is # a dictionary mapping filename to (coll, func-tuple list) -@@ -318,7 +335,7 @@ name_to_kernels = partition_by_name(kfn for kfn in kernel_funcs if kfn[0]!="Gene +@@ -323,7 +338,7 @@ name_to_kernels = partition_by_name(kfn for kfn in kernel_funcs if kfn[0]!="Gene with open(os.path.join(gensrc, "rules.mk"), "w") as f: out = f.write impl_names = sorted(name_to_funcs.keys()) @@ -129,29 +110,3 @@ index 43de85d..87cd677 100755 out("LIB_OBJS_GEN = $(patsubst %, $(OBJDIR)/genobj/%.o, {names})\n" .format(names=" ".join(names))) out("\n") -diff --git a/src/device/onerank.cu b/src/device/onerank.cu.cc -similarity index 100% -rename from src/device/onerank.cu -rename to src/device/onerank.cu.cc -diff --git a/src/device/reduce_scatter.h b/src/device/reduce_scatter.h -index d0b5249..2dacd60 100644 ---- a/src/device/reduce_scatter.h -+++ b/src/device/reduce_scatter.h -@@ -254,7 +254,7 @@ struct RunWorkElement(scat); -+ prims.template process(scat); - } - return; - } -@@ -278,7 +278,7 @@ struct RunWorkElement(scat); -+ prims.template process(scat); - } - } - return; diff --git a/third_party/xla/third_party/tsl/third_party/nccl/hermetic/nccl_configure.bzl b/third_party/xla/third_party/tsl/third_party/nccl/hermetic/nccl_configure.bzl index 14469acdfc5aa1..c1e49a6b9f1dd2 100644 --- a/third_party/xla/third_party/tsl/third_party/nccl/hermetic/nccl_configure.bzl +++ b/third_party/xla/third_party/tsl/third_party/nccl/hermetic/nccl_configure.bzl @@ -60,6 +60,15 @@ alias( visibility = ["//visibility:public"], ) +alias( + name = "nccl_headers", + actual = select({ + "@local_config_cuda//cuda:cuda_tools_and_libs": "@cuda_nccl//:headers", + "//conditions:default": "@nccl_archive//:nccl_headers", + }), + visibility = ["//visibility:public"], +) + cc_library( name = "hermetic_nccl_config", hdrs = ["nccl_config.h"], diff --git a/third_party/xla/third_party/tsl/third_party/protobuf/protobuf.patch b/third_party/xla/third_party/tsl/third_party/protobuf/protobuf.patch index 9d928ba175f330..ac33ccbf8c3aea 100644 --- a/third_party/xla/third_party/tsl/third_party/protobuf/protobuf.patch +++ b/third_party/xla/third_party/tsl/third_party/protobuf/protobuf.patch @@ -1,22 +1,46 @@ diff --git a/BUILD.bazel b/BUILD.bazel --- a/BUILD.bazel (revision 90b73ac3f0b10320315c2ca0d03a5a9b095d2f66) -+++ b/BUILD.bazel (date 1670471682469) -@@ -68,6 +68,7 @@ ++++ b/BUILD.bazel (date 1714620794503) +@@ -68,6 +68,8 @@ copts = COPTS, includes = ["src/"], linkopts = LINK_OPTS, ++ local_defines = ["PROTOBUF_USE_DLLS", "LIBPROTOBUF_EXPORTS"], + alwayslink = 1, visibility = ["//visibility:public"], ) -@@ -135,6 +136,7 @@ +@@ -135,6 +137,8 @@ copts = COPTS, includes = ["src/"], linkopts = LINK_OPTS, ++ local_defines = ["PROTOBUF_USE_DLLS", "LIBPROTOBUF_EXPORTS"], + alwayslink = 1, visibility = ["//visibility:public"], deps = [":protobuf_lite"] + select({ "//build_defs:config_msvc": [], +@@ -1074,7 +1078,8 @@ + "@com_google_protobuf//:type_proto", + "@com_google_protobuf//:wrappers_proto", + ], +- command_line = "--cpp_out=$(OUT)", ++ command_line = "--cpp_out=dllexport_decl=PROTOBUF_EXPORT:$(OUT)", ++# command_line = "--cpp_out=$(OUT)", + runtime = ":protobuf", + visibility = ["//visibility:public"], + ) +diff --git a/protobuf.bzl b/protobuf.bzl +--- a/protobuf.bzl (revision 90b73ac3f0b10320315c2ca0d03a5a9b095d2f66) ++++ b/protobuf.bzl (date 1714611573270) +@@ -127,7 +127,7 @@ + use_grpc_plugin = (ctx.attr.plugin_language == "grpc" and ctx.attr.plugin) + path_tpl = "$(realpath %s)" if in_gen_dir else "%s" + if ctx.attr.gen_cc: +- args += [("--cpp_out=" + path_tpl) % gen_dir] ++ args += [("--cpp_out=dllexport_decl=PROTOBUF_EXPORT:" + path_tpl) % gen_dir] + outs.extend(_CcOuts([src.basename], use_grpc_plugin = use_grpc_plugin)) + if ctx.attr.gen_py: + args += [("--python_out=" + path_tpl) % gen_dir] diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc index 162531226..e93ec4809 100644 --- a/python/google/protobuf/pyext/descriptor.cc diff --git a/third_party/xla/third_party/tsl/third_party/py/BUILD b/third_party/xla/third_party/tsl/third_party/py/BUILD index 84eba77ce1a7af..0381d65bb27514 100644 --- a/third_party/xla/third_party/tsl/third_party/py/BUILD +++ b/third_party/xla/third_party/tsl/third_party/py/BUILD @@ -1,3 +1,4 @@ +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") load("@python//:defs.bzl", "compile_pip_requirements") load("@python_version_repo//:py_version.bzl", "REQUIREMENTS") @@ -38,3 +39,16 @@ compile_pip_requirements( requirements_in = "requirements.in", requirements_txt = REQUIREMENTS, ) + +# Flag indicating if the target requires pre-built wheel. +bool_flag( + name = "wheel_dependency", + build_setting_default = False, +) + +config_setting( + name = "enable_wheel_dependency", + flag_values = { + ":wheel_dependency": "True", + }, +) diff --git a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/ml_dtypes/LICENSE b/third_party/xla/third_party/tsl/third_party/py/non_hermetic/ml_dtypes/LICENSE deleted file mode 100644 index d645695673349e..00000000000000 --- a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/ml_dtypes/LICENSE +++ /dev/null @@ -1,202 +0,0 @@ - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.BUILD b/third_party/xla/third_party/tsl/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.BUILD deleted file mode 100644 index f386124a36dfe8..00000000000000 --- a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.BUILD +++ /dev/null @@ -1,64 +0,0 @@ -load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") - -package( - default_visibility = ["//visibility:public"], - licenses = ["notice"], -) - -exports_files(["LICENSE"]) - -cc_library( - name = "float8", - hdrs = ["include/float8.h"], - include_prefix = "ml_dtypes", - # Internal headers are all relative to . but other packages - # include these headers with the prefix. - includes = [ - ".", - "ml_dtypes", - ], - deps = ["@eigen_archive//:eigen3"], -) - -cc_library( - name = "intn", - hdrs = ["include/intn.h"], - include_prefix = "ml_dtypes", - # Internal headers are all relative to . but other packages - # include these headers with the prefix. - includes = [ - ".", - "ml_dtypes", - ], -) - -pybind_extension( - name = "_ml_dtypes_ext", - srcs = [ - "_src/common.h", - "_src/custom_float.h", - "_src/dtypes.cc", - "_src/int4_numpy.h", - "_src/numpy.cc", - "_src/numpy.h", - "_src/ufuncs.h", - ], - includes = ["ml_dtypes"], - visibility = [":__subpackages__"], - deps = [ - ":float8", - ":intn", - "@eigen_archive//:eigen3", - "@local_tsl//third_party/py/numpy:headers", - ], -) - -py_library( - name = "ml_dtypes", - srcs = [ - "__init__.py", - "_finfo.py", - "_iinfo.py", - ], - deps = [":_ml_dtypes_ext"], -) diff --git a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.tests.BUILD b/third_party/xla/third_party/tsl/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.tests.BUILD deleted file mode 100644 index c811379a19dabd..00000000000000 --- a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/ml_dtypes/ml_dtypes.tests.BUILD +++ /dev/null @@ -1,71 +0,0 @@ -package( - default_visibility = ["//visibility:public"], -) - -py_library( - name = "testing_base", - deps = [ - "//:ml_dtypes", - "@absl_py//absl/testing:absltest", - "@absl_py//absl/testing:parameterized", - "@local_tsl//third_party/py/numpy", - ], -) - -py_test( - name = "custom_float_test", - srcs = ["custom_float_test.py"], - main = "custom_float_test.py", - deps = [":testing_base"], -) - -py_test( - name = "int4_test", - srcs = ["int4_test.py"], - main = "int4_test.py", - deps = [":testing_base"], -) - -py_test( - name = "iinfo_test", - srcs = ["iinfo_test.py"], - main = "iinfo_test.py", - deps = [":testing_base"], -) - -py_test( - name = "finfo_test", - srcs = ["finfo_test.py"], - main = "finfo_test.py", - deps = [":testing_base"], -) - -py_test( - name = "metadata_test", - srcs = ["metadata_test.py"], - main = "metadata_test.py", - deps = [":testing_base"], -) - -cc_test( - name = "float8_test", - srcs = ["float8_test.cc"], - linkstatic = 1, - deps = [ - "//:float8", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@eigen_archive//:eigen3", - ], -) - -cc_test( - name = "intn_test_cc", - srcs = ["intn_test.cc"], - linkstatic = 1, - deps = [ - "//:intn", - "@com_google_googletest//:gtest_main", - "@eigen_archive//:eigen3", - ], -) diff --git a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/ml_dtypes/workspace.bzl b/third_party/xla/third_party/tsl/third_party/py/non_hermetic/ml_dtypes/workspace.bzl deleted file mode 100644 index 51505bf3a1460d..00000000000000 --- a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/ml_dtypes/workspace.bzl +++ /dev/null @@ -1,22 +0,0 @@ -"""Provides the repo macro to import ml_dtypes. - -ml_dtypes provides machine-learning-specific data-types like bfloat16, -float8 varieties, and int4. -""" - -load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") - -def repo(): - ML_DTYPES_COMMIT = "24084d9ed2c3d45bf83b7a9bff833aa185bf9172" - ML_DTYPES_SHA256 = "c916a3e6b3d9bdcb476f506fdbbecb6d5e9f21f82f221dfcb42b320b4e85e55a" - tf_http_archive( - name = "ml_dtypes", - build_file = "//third_party/py/ml_dtypes:ml_dtypes.BUILD", - link_files = { - "//third_party/py/ml_dtypes:ml_dtypes.tests.BUILD": "tests/BUILD.bazel", - "//third_party/py/ml_dtypes:LICENSE": "LICENSE", - }, - sha256 = ML_DTYPES_SHA256, - strip_prefix = "ml_dtypes-{commit}/ml_dtypes".format(commit = ML_DTYPES_COMMIT), - urls = tf_mirror_urls("https://github.com/jax-ml/ml_dtypes/archive/{commit}/ml_dtypes-{commit}.tar.gz".format(commit = ML_DTYPES_COMMIT)), - ) diff --git a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/numpy/BUILD b/third_party/xla/third_party/tsl/third_party/py/non_hermetic/numpy/BUILD deleted file mode 100644 index c80cc5287bc469..00000000000000 --- a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/numpy/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -licenses(["restricted"]) - -package(default_visibility = ["//visibility:public"]) - -py_library( - name = "numpy", - srcs = ["tf_numpy_dummy.py"], - srcs_version = "PY3", -) - -alias( - name = "headers", - actual = "@local_config_python//:numpy_headers", -) - -genrule( - name = "dummy", - outs = ["tf_numpy_dummy.py"], - cmd = "touch $@", - visibility = ["//visibility:private"], -) diff --git a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/numpy/README.md b/third_party/xla/third_party/tsl/third_party/py/non_hermetic/numpy/README.md deleted file mode 100644 index 4e58b9df87b5ec..00000000000000 --- a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/numpy/README.md +++ /dev/null @@ -1,4 +0,0 @@ -# numpy_ops - -The folder tf_numpy_api/ contains lists of NumPy API symbols that the -`numpy_ops` internal module in TensorFlow implements. diff --git a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/numpy/tf_numpy_api/BUILD b/third_party/xla/third_party/tsl/third_party/py/non_hermetic/numpy/tf_numpy_api/BUILD deleted file mode 100644 index 070f8ab8a65352..00000000000000 --- a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/numpy/tf_numpy_api/BUILD +++ /dev/null @@ -1,12 +0,0 @@ -# TensorFlow API backwards compatibility test goldens for tf.experimental.numpy. - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//visibility:public"], - licenses = ["notice"], -) - -filegroup( - name = "api_golden", - srcs = glob(["*.pbtxt"]), -) diff --git a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/numpy/tf_numpy_api/tensorflow.experimental.numpy.ndarray.pbtxt b/third_party/xla/third_party/tsl/third_party/py/non_hermetic/numpy/tf_numpy_api/tensorflow.experimental.numpy.ndarray.pbtxt deleted file mode 100644 index 9198264c02961f..00000000000000 --- a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/numpy/tf_numpy_api/tensorflow.experimental.numpy.ndarray.pbtxt +++ /dev/null @@ -1,51 +0,0 @@ -path: "tensorflow.experimental.numpy.ndarray" -tf_class { - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" - member { - name: "OVERLOADABLE_OPERATORS" - mtype: "" - } - member { - name: "dtype" - mtype: "" - } - member { - name: "name" - mtype: "" - } - member { - name: "ndim" - mtype: "" - } - member { - name: "shape" - mtype: "" - } - member_method { - name: "__init__" - } - member_method { - name: "eval" - argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } - member_method { - name: "experimental_ref" - argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "get_shape" - argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "ref" - argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "set_shape" - argspec: "args=[\'self\', \'shape\'], varargs=None, keywords=None, defaults=None" - } -} diff --git a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/numpy/tf_numpy_api/tensorflow.experimental.numpy.pbtxt b/third_party/xla/third_party/tsl/third_party/py/non_hermetic/numpy/tf_numpy_api/tensorflow.experimental.numpy.pbtxt deleted file mode 100644 index 2f5490ad0c922f..00000000000000 --- a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/numpy/tf_numpy_api/tensorflow.experimental.numpy.pbtxt +++ /dev/null @@ -1,919 +0,0 @@ -path: "tensorflow.experimental.numpy" -tf_module { - member { - name: "bool_" - mtype: "" - } - member { - name: "complex128" - mtype: "" - } - member { - name: "complex64" - mtype: "" - } - member { - name: "complex_" - mtype: "" - } - member { - name: "e" - mtype: "" - } - member { - name: "float16" - mtype: "" - } - member { - name: "float32" - mtype: "" - } - member { - name: "float64" - mtype: "" - } - member { - name: "float_" - mtype: "" - } - member { - name: "iinfo" - mtype: "" - } - member { - name: "inexact" - mtype: "" - } - member { - name: "inf" - mtype: "" - } - member { - name: "int16" - mtype: "" - } - member { - name: "int32" - mtype: "" - } - member { - name: "int64" - mtype: "" - } - member { - name: "int8" - mtype: "" - } - member { - name: "int_" - mtype: "" - } - member { - name: "ndarray" - mtype: "" - } - member { - name: "newaxis" - mtype: "" - } - member { - name: "object_" - mtype: "" - } - member { - name: "pi" - mtype: "" - } - member { - name: "random" - mtype: "" - } - member { - name: "string_" - mtype: "" - } - member { - name: "uint16" - mtype: "" - } - member { - name: "uint32" - mtype: "" - } - member { - name: "uint64" - mtype: "" - } - member { - name: "uint8" - mtype: "" - } - member { - name: "unicode_" - mtype: "" - } - member_method { - name: "abs" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "absolute" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "add" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "all" - argspec: "args=[\'a\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } - member_method { - name: "allclose" - argspec: "args=[\'a\', \'b\', \'rtol\', \'atol\', \'equal_nan\'], varargs=None, keywords=None, defaults=[\'1e-05\', \'1e-08\', \'False\'], " - } - member_method { - name: "amax" - argspec: "args=[\'a\', \'axis\', \'out\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " - } - member_method { - name: "amin" - argspec: "args=[\'a\', \'axis\', \'out\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " - } - member_method { - name: "angle" - argspec: "args=[\'z\', \'deg\'], varargs=None, keywords=None, defaults=[\'False\'], " - } - member_method { - name: "any" - argspec: "args=[\'a\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } - member_method { - name: "append" - argspec: "args=[\'arr\', \'values\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "arange" - argspec: "args=[\'start\', \'stop\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\'], " - } - member_method { - name: "arccos" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "arccosh" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "arcsin" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "arcsinh" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "arctan" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "arctan2" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "arctanh" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "argmax" - argspec: "args=[\'a\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "argmin" - argspec: "args=[\'a\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "argsort" - argspec: "args=[\'a\', \'axis\', \'kind\', \'order\'], varargs=None, keywords=None, defaults=[\'-1\', \'quicksort\', \'None\'], " - } - member_method { - name: "around" - argspec: "args=[\'a\', \'decimals\'], varargs=None, keywords=None, defaults=[\'0\'], " - } - member_method { - name: "array" - argspec: "args=[\'val\', \'dtype\', \'copy\', \'ndmin\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'0\'], " - } - member_method { - name: "array_equal" - argspec: "args=[\'a1\', \'a2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "asanyarray" - argspec: "args=[\'a\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "asarray" - argspec: "args=[\'a\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "ascontiguousarray" - argspec: "args=[\'a\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "atleast_1d" - argspec: "args=[], varargs=arys, keywords=None, defaults=None" - } - member_method { - name: "atleast_2d" - argspec: "args=[], varargs=arys, keywords=None, defaults=None" - } - member_method { - name: "atleast_3d" - argspec: "args=[], varargs=arys, keywords=None, defaults=None" - } - member_method { - name: "average" - argspec: "args=[\'a\', \'axis\', \'weights\', \'returned\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " - } - member_method { - name: "bitwise_and" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "bitwise_not" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "bitwise_or" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "bitwise_xor" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "broadcast_arrays" - argspec: "args=[], varargs=args, keywords=kwargs, defaults=None" - } - member_method { - name: "broadcast_to" - argspec: "args=[\'array\', \'shape\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "cbrt" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "ceil" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "clip" - argspec: "args=[\'a\', \'a_min\', \'a_max\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "compress" - argspec: "args=[\'condition\', \'a\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "concatenate" - argspec: "args=[\'arys\', \'axis\'], varargs=None, keywords=None, defaults=[\'0\'], " - } - member_method { - name: "conj" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "conjugate" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "copy" - argspec: "args=[\'a\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "cos" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "cosh" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "count_nonzero" - argspec: "args=[\'a\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "cross" - argspec: "args=[\'a\', \'b\', \'axisa\', \'axisb\', \'axisc\', \'axis\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'-1\', \'None\'], " - } - member_method { - name: "cumprod" - argspec: "args=[\'a\', \'axis\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } - member_method { - name: "cumsum" - argspec: "args=[\'a\', \'axis\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } - member_method { - name: "deg2rad" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "diag" - argspec: "args=[\'v\', \'k\'], varargs=None, keywords=None, defaults=[\'0\'], " - } - member_method { - name: "diag_indices" - argspec: "args=[\'n\', \'ndim\'], varargs=None, keywords=None, defaults=[\'2\'], " - } - member_method { - name: "diagflat" - argspec: "args=[\'v\', \'k\'], varargs=None, keywords=None, defaults=[\'0\'], " - } - member_method { - name: "diagonal" - argspec: "args=[\'a\', \'offset\', \'axis1\', \'axis2\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'1\'], " - } - member_method { - name: "diff" - argspec: "args=[\'a\', \'n\', \'axis\'], varargs=None, keywords=None, defaults=[\'1\', \'-1\'], " - } - member_method { - name: "divide" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "divmod" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "dot" - argspec: "args=[\'a\', \'b\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "dsplit" - argspec: "args=[\'ary\', \'indices_or_sections\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "dstack" - argspec: "args=[\'tup\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "einsum" - argspec: "args=[\'subscripts\'], varargs=operands, keywords=kwargs, defaults=None" - } - member_method { - name: "empty" - argspec: "args=[\'shape\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " - } - member_method { - name: "empty_like" - argspec: "args=[\'a\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "equal" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "exp" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "exp2" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "expand_dims" - argspec: "args=[\'a\', \'axis\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "experimental_enable_numpy_behavior" - argspec: "args=[\'prefer_float32\'], varargs=None, keywords=None, defaults=[\'False\'], " - } - member_method { - name: "expm1" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "eye" - argspec: "args=[\'N\', \'M\', \'k\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \"\"], " - } - member_method { - name: "fabs" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "finfo" - argspec: "args=[\'dtype\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "fix" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "flatten" - argspec: "args=[\'a\', \'order\'], varargs=None, keywords=None, defaults=[\'C\'], " - } - member_method { - name: "flip" - argspec: "args=[\'m\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "fliplr" - argspec: "args=[\'m\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "flipud" - argspec: "args=[\'m\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "float_power" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "floor" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "floor_divide" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "full" - argspec: "args=[\'shape\', \'fill_value\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "full_like" - argspec: "args=[\'a\', \'fill_value\', \'dtype\', \'order\', \'subok\', \'shape\'], varargs=None, keywords=None, defaults=[\'None\', \'K\', \'True\', \'None\'], " - } - member_method { - name: "gcd" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "geomspace" - argspec: "args=[\'start\', \'stop\', \'num\', \'endpoint\', \'dtype\', \'axis\'], varargs=None, keywords=None, defaults=[\'50\', \'True\', \'None\', \'0\'], " - } - member_method { - name: "greater" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "greater_equal" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "heaviside" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "hsplit" - argspec: "args=[\'ary\', \'indices_or_sections\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "hstack" - argspec: "args=[\'tup\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "hypot" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "identity" - argspec: "args=[\'n\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " - } - member_method { - name: "imag" - argspec: "args=[\'val\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "inner" - argspec: "args=[\'a\', \'b\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "isclose" - argspec: "args=[\'a\', \'b\', \'rtol\', \'atol\', \'equal_nan\'], varargs=None, keywords=None, defaults=[\'1e-05\', \'1e-08\', \'False\'], " - } - member_method { - name: "iscomplex" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "iscomplexobj" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "isfinite" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "isinf" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "isnan" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "isneginf" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "isposinf" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "isreal" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "isrealobj" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "isscalar" - argspec: "args=[\'num\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "issubdtype" - argspec: "args=[\'arg1\', \'arg2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "ix_" - argspec: "args=[], varargs=args, keywords=None, defaults=None" - } - member_method { - name: "kron" - argspec: "args=[\'a\', \'b\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "lcm" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "less" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "less_equal" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "linspace" - argspec: "args=[\'start\', \'stop\', \'num\', \'endpoint\', \'retstep\', \'dtype\', \'axis\'], varargs=None, keywords=None, defaults=[\'50\', \'True\', \'False\', \"\", \'0\'], " - } - member_method { - name: "log" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "log10" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "log1p" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "log2" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "logaddexp" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "logaddexp2" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "logical_and" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "logical_not" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "logical_or" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "logical_xor" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "logspace" - argspec: "args=[\'start\', \'stop\', \'num\', \'endpoint\', \'base\', \'dtype\', \'axis\'], varargs=None, keywords=None, defaults=[\'50\', \'True\', \'10.0\', \'None\', \'0\'], " - } - member_method { - name: "matmul" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "max" - argspec: "args=[\'a\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } - member_method { - name: "maximum" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "mean" - argspec: "args=[\'a\', \'axis\', \'dtype\', \'out\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " - } - member_method { - name: "meshgrid" - argspec: "args=[], varargs=xi, keywords=kwargs, defaults=None" - } - member_method { - name: "min" - argspec: "args=[\'a\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } - member_method { - name: "minimum" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "mod" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "moveaxis" - argspec: "args=[\'a\', \'source\', \'destination\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "multiply" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "nanmean" - argspec: "args=[\'a\', \'axis\', \'dtype\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " - } - member_method { - name: "nanprod" - argspec: "args=[\'a\', \'axis\', \'dtype\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " - } - member_method { - name: "nansum" - argspec: "args=[\'a\', \'axis\', \'dtype\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " - } - member_method { - name: "ndim" - argspec: "args=[\'a\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "negative" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "nextafter" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "nonzero" - argspec: "args=[\'a\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "not_equal" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "ones" - argspec: "args=[\'shape\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " - } - member_method { - name: "ones_like" - argspec: "args=[\'a\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "outer" - argspec: "args=[\'a\', \'b\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "pad" - argspec: "args=[\'array\', \'pad_width\', \'mode\'], varargs=None, keywords=kwargs, defaults=None" - } - member_method { - name: "polyval" - argspec: "args=[\'p\', \'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "positive" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "power" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "prod" - argspec: "args=[\'a\', \'axis\', \'dtype\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " - } - member_method { - name: "promote_types" - argspec: "args=[\'type1\', \'type2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "ptp" - argspec: "args=[\'a\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } - member_method { - name: "rad2deg" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "ravel" - argspec: "args=[\'a\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "real" - argspec: "args=[\'val\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "reciprocal" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "remainder" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "repeat" - argspec: "args=[\'a\', \'repeats\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "reshape" - argspec: "args=[\'a\', \'newshape\', \'order\'], varargs=None, keywords=None, defaults=[\'C\'], " - } - member_method { - name: "result_type" - argspec: "args=[], varargs=arrays_and_dtypes, keywords=None, defaults=None" - } - member_method { - name: "roll" - argspec: "args=[\'a\', \'shift\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "rot90" - argspec: "args=[\'m\', \'k\', \'axes\'], varargs=None, keywords=None, defaults=[\'1\', \'(0, 1)\'], " - } - member_method { - name: "round" - argspec: "args=[\'a\', \'decimals\'], varargs=None, keywords=None, defaults=[\'0\'], " - } - member_method { - name: "select" - argspec: "args=[\'condlist\', \'choicelist\', \'default\'], varargs=None, keywords=None, defaults=[\'0\'], " - } - member_method { - name: "shape" - argspec: "args=[\'a\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "sign" - argspec: "args=[\'x\', \'out\', \'where\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\'], " - } - member_method { - name: "signbit" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "sin" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "sinc" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "sinh" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "size" - argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "sort" - argspec: "args=[\'a\', \'axis\', \'kind\', \'order\'], varargs=None, keywords=None, defaults=[\'-1\', \'quicksort\', \'None\'], " - } - member_method { - name: "split" - argspec: "args=[\'ary\', \'indices_or_sections\', \'axis\'], varargs=None, keywords=None, defaults=[\'0\'], " - } - member_method { - name: "sqrt" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "square" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "squeeze" - argspec: "args=[\'a\', \'axis\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "stack" - argspec: "args=[\'arrays\', \'axis\'], varargs=None, keywords=None, defaults=[\'0\'], " - } - member_method { - name: "std" - argspec: "args=[\'a\', \'axis\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } - member_method { - name: "subtract" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "sum" - argspec: "args=[\'a\', \'axis\', \'dtype\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " - } - member_method { - name: "swapaxes" - argspec: "args=[\'a\', \'axis1\', \'axis2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "take" - argspec: "args=[\'a\', \'indices\', \'axis\', \'out\', \'mode\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'clip\'], " - } - member_method { - name: "take_along_axis" - argspec: "args=[\'arr\', \'indices\', \'axis\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "tan" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "tanh" - argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "tensordot" - argspec: "args=[\'a\', \'b\', \'axes\'], varargs=None, keywords=None, defaults=[\'2\'], " - } - member_method { - name: "tile" - argspec: "args=[\'a\', \'reps\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "trace" - argspec: "args=[\'a\', \'offset\', \'axis1\', \'axis2\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'1\', \'None\'], " - } - member_method { - name: "transpose" - argspec: "args=[\'a\', \'axes\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "tri" - argspec: "args=[\'N\', \'M\', \'k\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], " - } - member_method { - name: "tril" - argspec: "args=[\'m\', \'k\'], varargs=None, keywords=None, defaults=[\'0\'], " - } - member_method { - name: "triu" - argspec: "args=[\'m\', \'k\'], varargs=None, keywords=None, defaults=[\'0\'], " - } - member_method { - name: "true_divide" - argspec: "args=[\'x1\', \'x2\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "vander" - argspec: "args=[\'x\', \'N\', \'increasing\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " - } - member_method { - name: "var" - argspec: "args=[\'a\', \'axis\', \'dtype\', \'out\', \'ddof\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'0\', \'None\'], " - } - member_method { - name: "vdot" - argspec: "args=[\'a\', \'b\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "vsplit" - argspec: "args=[\'ary\', \'indices_or_sections\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "vstack" - argspec: "args=[\'tup\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "where" - argspec: "args=[\'condition\', \'x\', \'y\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " - } - member_method { - name: "zeros" - argspec: "args=[\'shape\', \'dtype\'], varargs=None, keywords=None, defaults=[\"\"], " - } - member_method { - name: "zeros_like" - argspec: "args=[\'a\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\'], " - } -} diff --git a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/numpy/tf_numpy_api/tensorflow.experimental.numpy.random.pbtxt b/third_party/xla/third_party/tsl/third_party/py/non_hermetic/numpy/tf_numpy_api/tensorflow.experimental.numpy.random.pbtxt deleted file mode 100644 index 61a4766f3f8f0f..00000000000000 --- a/third_party/xla/third_party/tsl/third_party/py/non_hermetic/numpy/tf_numpy_api/tensorflow.experimental.numpy.random.pbtxt +++ /dev/null @@ -1,35 +0,0 @@ -path: "tensorflow.experimental.numpy.random" -tf_module { - member_method { - name: "poisson" - argspec: "args=[\'lam\', \'size\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\'], " - } - member_method { - name: "rand" - argspec: "args=[], varargs=size, keywords=None, defaults=None" - } - member_method { - name: "randint" - argspec: "args=[\'low\', \'high\', \'size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"\"], " - } - member_method { - name: "randn" - argspec: "args=[], varargs=args, keywords=None, defaults=None" - } - member_method { - name: "random" - argspec: "args=[\'size\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "seed" - argspec: "args=[\'s\'], varargs=None, keywords=None, defaults=None" - } - member_method { - name: "standard_normal" - argspec: "args=[\'size\'], varargs=None, keywords=None, defaults=[\'None\'], " - } - member_method { - name: "uniform" - argspec: "args=[\'low\', \'high\', \'size\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \'None\'], " - } -} diff --git a/third_party/xla/third_party/tsl/third_party/py/python_repo.bzl b/third_party/xla/third_party/tsl/third_party/py/python_repo.bzl index 13aed2b687129f..6fe63fb9c1e674 100644 --- a/third_party/xla/third_party/tsl/third_party/py/python_repo.bzl +++ b/third_party/xla/third_party/tsl/third_party/py/python_repo.bzl @@ -14,6 +14,7 @@ def _python_repository_impl(ctx): ctx.file("BUILD", "") wheel_name = ctx.os.environ.get("WHEEL_NAME", "tensorflow") wheel_collab = ctx.os.environ.get("WHEEL_COLLAB", False) + macos_deployment_target = ctx.os.environ.get("MACOSX_DEPLOYMENT_TARGET", "") requirements = None for i in range(0, len(ctx.attr.requirements_locks)): @@ -34,13 +35,11 @@ Please check python_init_repositories() in your WORKSPACE file. requirements_with_local_wheels = str(requirements) - local_wheels_dir = ctx.os.environ.get("LOCAL_WHEELS_DIR", "") - if ctx.attr.local_wheel_workspaces or local_wheels_dir: + if ctx.attr.local_wheel_workspaces: local_wheel_requirements = _get_injected_local_wheels( ctx, version, ctx.attr.local_wheel_workspaces, - local_wheels_dir, ) requirements_content = [ctx.read(requirements)] + local_wheel_requirements merged_requirements_content = "\n".join(requirements_content) @@ -55,6 +54,13 @@ Please check python_init_repositories() in your WORKSPACE file. merged_requirements_content, ) + use_pywrap_rules = bool( + ctx.os.environ.get("USE_PYWRAP_RULES", False), + ) + + if use_pywrap_rules: + print("!!!Using pywrap rules instead of directly creating .so objects!!!") # buildifier: disable=print + ctx.file( "py_version.bzl", """ @@ -64,12 +70,16 @@ WHEEL_NAME = "{wheel_name}" WHEEL_COLLAB = "{wheel_collab}" REQUIREMENTS = "{requirements}" REQUIREMENTS_WITH_LOCAL_WHEELS = "{requirements_with_local_wheels}" +USE_PYWRAP_RULES = {use_pywrap_rules} +MACOSX_DEPLOYMENT_TARGET = "{macos_deployment_target}" """.format( version = version, wheel_name = wheel_name, wheel_collab = wheel_collab, requirements = str(requirements), requirements_with_local_wheels = requirements_with_local_wheels, + use_pywrap_rules = use_pywrap_rules, + macos_deployment_target = macos_deployment_target, ), ) @@ -118,8 +128,7 @@ def _parse_python_version(version_str): def _get_injected_local_wheels( ctx, py_version, - local_wheel_workspaces, - local_wheels_dir): + local_wheel_workspaces): local_wheel_requirements = [] py_ver_marker = "-cp%s-" % py_version.replace(".", "") py_major_ver_marker = "-py%s-" % py_version.split(".")[0] @@ -140,18 +149,6 @@ def _get_injected_local_wheels( ctx.attr.local_wheel_inclusion_list, ctx.attr.local_wheel_exclusion_list, ) - if local_wheels_dir: - dist_folder_path = ctx.path(local_wheels_dir) - if dist_folder_path.exists: - dist_wheels = dist_folder_path.readdir() - _process_dist_wheels( - dist_wheels, - wheels, - py_ver_marker, - py_major_ver_marker, - ctx.attr.local_wheel_inclusion_list, - ctx.attr.local_wheel_exclusion_list, - ) for wheel_name, wheel_path in wheels.items(): local_wheel_requirements.append( @@ -200,6 +197,7 @@ python_repository = repository_rule( "HERMETIC_PYTHON_VERSION", "WHEEL_NAME", "WHEEL_COLLAB", + "USE_PYWRAP_RULES", ], local = True, ) diff --git a/third_party/xla/third_party/tsl/third_party/py/python_wheel_library.bzl b/third_party/xla/third_party/tsl/third_party/py/python_wheel_library.bzl new file mode 100644 index 00000000000000..01ea7554e2323a --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/py/python_wheel_library.bzl @@ -0,0 +1,60 @@ +""" Macros to unpack a wheel and use its content as a py_library. """ + +def _unpacked_wheel_impl(ctx): + output_dir = ctx.actions.declare_directory(ctx.label.name) + libs = [] + for dep in ctx.attr.deps: + linker_inputs = dep[CcInfo].linking_context.linker_inputs.to_list() + for linker_input in linker_inputs: + if linker_input.libraries and linker_input.libraries[0].dynamic_library: + lib = linker_input.libraries[0].dynamic_library + libs.append(lib) + script = """ + {zipper} x {wheel} -d {output} + for lib in {libs}; do + cp $lib {output}/tensorflow + done + """.format( + zipper = ctx.executable.zipper.path, + wheel = ctx.file.wheel.path, + output = output_dir.path, + libs = " ".join(["'%s'" % lib.path for lib in libs]), + ) + ctx.actions.run_shell( + inputs = [ctx.file.wheel] + libs, + command = script, + outputs = [output_dir], + tools = [ctx.executable.zipper], + ) + + return [ + DefaultInfo(files = depset([output_dir])), + ] + +_unpacked_wheel = rule( + implementation = _unpacked_wheel_impl, + attrs = { + "wheel": attr.label(mandatory = True, allow_single_file = True), + "zipper": attr.label( + default = Label("@bazel_tools//tools/zip:zipper"), + cfg = "exec", + executable = True, + ), + "deps": attr.label_list(providers = [CcInfo]), + }, +) + +def wheel_library(name, wheel, deps = [], wheel_deps = []): + unpacked_wheel_name = name + "_unpacked_wheel" + _unpacked_wheel( + name = unpacked_wheel_name, + wheel = wheel, + deps = wheel_deps, + ) + native.py_library( + name = name, + data = [":" + unpacked_wheel_name], + imports = [unpacked_wheel_name], + deps = deps, + visibility = ["//visibility:public"], + ) diff --git a/third_party/xla/third_party/tsl/third_party/py/rules_python.patch b/third_party/xla/third_party/tsl/third_party/py/rules_python.patch index 7d59ac107cc952..3dbe06dd2d6d96 100644 --- a/third_party/xla/third_party/tsl/third_party/py/rules_python.patch +++ b/third_party/xla/third_party/tsl/third_party/py/rules_python.patch @@ -1,32 +1,28 @@ -Subject: [PATCH] Add Python 3.13.0rc2 support to rules_python ---- -Index: python/versions.bzl -<+>UTF-8 -=================================================================== diff --git a/python/versions.bzl b/python/versions.bzl ---- a/python/versions.bzl (revision 084b877c98b580839ceab2b071b02fc6768f3de6) -+++ b/python/versions.bzl (date 1726256410148) -@@ -484,6 +484,19 @@ +index fd385cd1..eb4133f1 100644 +--- a/python/versions.bzl ++++ b/python/versions.bzl +@@ -484,6 +484,19 @@ TOOL_VERSIONS = { }, "strip_prefix": "python", }, + "3.13.0": { -+ "url": "20240909/cpython-{python_version}rc2+20240909-{platform}-{build}.tar.gz", ++ "url": "20241008/cpython-{python_version}+20241008-{platform}-{build}.tar.gz", + "sha256": { -+ "aarch64-apple-darwin": "5d38ca1e6b030b004714e10813903e906c6b8f2a6361770df4512a838f4a4a9f", -+ "aarch64-unknown-linux-gnu": "85e103fc81a1fcf94a93180f6df42e39a7dc15d4b711705e133dc2ec847552e7", -+ "ppc64le-unknown-linux-gnu": "3be3d8aefae579c420fc6abf01658ae89fda8120154f989575b08085d2f8d6dc", -+ "s390x-unknown-linux-gnu": "6ec5130d62473368ecc7e55338bf1cc58607dbfe8088959cab51265b9f13c38d", -+ "x86_64-apple-darwin": "c3dcd4314324159945dc19342c73b9deb8de0f2d1709171427dd52f1a05eecca", -+ "x86_64-pc-windows-msvc": "31282f912e984d399c56925dfb69a4f3ce76226dfb4806b09f37e3b4a15e5a30", -+ "x86_64-unknown-linux-gnu": "028581cce5004c66775a3ae8b3ed65681ab4b289608dfd1aec3354d169216099", ++ "aarch64-apple-darwin": "5d3cb8d7ca4cfbbe7ae1f118f26be112ee417d982fab8c6d85cfd8ccccf70718", ++ "aarch64-unknown-linux-gnu": "c1142af8f2c85923d2ba8201a35b913bb903a5d15f052c38bbecf2f49e2342dc", ++ "ppc64le-unknown-linux-gnu": "1be64a330499fed4e1f864b97eef5445b0e4abc0559ae45df3108981800cf998", ++ "s390x-unknown-linux-gnu": "c0b1cc51426feadaa932fdd9afd9a9af789916e128e48ac8909f9a269bbbd749", ++ "x86_64-apple-darwin": "b58ca12d9ae14bbd79f9e5cf4b748211ff1953e59abeac63b0f4e8e49845669f", ++ "x86_64-pc-windows-msvc": "c7651a7a575104f47c808902b020168057f3ad80f277e54cecfaf79a9ff50e22", ++ "x86_64-unknown-linux-gnu": "455200e1a202e9d9ef4b630c04af701c0a91dcaa6462022efc76893fc762ec95", + }, + "strip_prefix": "python", + }, } # buildifier: disable=unsorted-dict-items -@@ -493,6 +506,7 @@ +@@ -493,6 +506,7 @@ MINOR_MAPPING = { "3.10": "3.10.14", "3.11": "3.11.9", "3.12": "3.12.3", diff --git a/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/BUILD b/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/BUILD new file mode 100644 index 00000000000000..595b43626f01e4 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/BUILD @@ -0,0 +1,16 @@ +load("@bazel_skylib//:bzl_library.bzl", "bzl_library") + +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) + +exports_files(["pybind_extension.py.tpl"]) + +bzl_library( + name = "pywrap_bzl", + srcs = [ + "pywrap.bzl", + # copybara:uncomment "pywrap.google.bzl", + "pywrap.impl.bzl", + ], + # copybara:uncomment parse_tests = False, + visibility = ["//visibility:public"], +) diff --git a/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/pybind_extension.py.tpl b/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/pybind_extension.py.tpl new file mode 100644 index 00000000000000..98428b51486efd --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/pybind_extension.py.tpl @@ -0,0 +1,49 @@ +import os +import re + + +def __calc_import_path(): + module_name = os.path.basename(__file__)[:-3] + outer_module_name = "" # template_val + for var in ["PYWRAP_TARGET", "TEST_TARGET"]: + path = __find_pywrap_module_by_target_label(os.environ.get(var)) + if path: + return "%s.%s%s" % (path, outer_module_name, module_name) + + for var in ["RUNFILES_MANIFEST_FILE", "RUNFILES_DIR"]: + path = __find_pywrap_module_by_runfiles_env(os.environ.get(var)) + if path: + return "%s.%s%s" % (path, outer_module_name, module_name) + + raise RuntimeError("Could not detect original test/binary location") + + +def __find_pywrap_module_by_target_label(target_label): + if target_label: + return target_label.split("//", 1)[1].split(":")[0].replace("/", ".") + return None + + +def __find_pywrap_module_by_runfiles_env(runfiles_env_var): + pattern = re.compile( + r"bazel-out/.*/bin/(?P[\w/]*)/(?P\w+)(\.exe)?\.runfiles" + ) + if runfiles_env_var: + match = pattern.search(runfiles_env_var) + return match.group("pkg").replace("/", ".") + return None + + +def __update_globals(pywrap_m): + if hasattr(pywrap_m, '__all__'): + all_names = pywrap_m.__all__ + else: + all_names = [name for name in dir(pywrap_m) if not name.startswith('_')] + + extra_names = [] # template_val + all_names.extend(extra_names) + globals().update({name: getattr(pywrap_m, name) for name in all_names}) + + +__pywrap_m = __import__(__calc_import_path(), fromlist=["*"]) +__update_globals(__pywrap_m) diff --git a/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/pywrap.bzl b/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/pywrap.bzl new file mode 100644 index 00000000000000..e7b038f571cae2 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/pywrap.bzl @@ -0,0 +1,20 @@ +load( + "//third_party/py/rules_pywrap:pywrap.default.bzl", + _pybind_extension = "pybind_extension", + _pywrap_aware_cc_import = "pywrap_aware_cc_import", + _pywrap_aware_filegroup = "pywrap_aware_filegroup", + _pywrap_aware_genrule = "pywrap_aware_genrule", + _pywrap_common_library = "pywrap_common_library", + _pywrap_library = "pywrap_library", + _stripped_cc_info = "stripped_cc_info", + _use_pywrap_rules = "use_pywrap_rules", +) + +pybind_extension = _pybind_extension +use_pywrap_rules = _use_pywrap_rules +pywrap_library = _pywrap_library +pywrap_common_library = _pywrap_common_library +stripped_cc_info = _stripped_cc_info +pywrap_aware_filegroup = _pywrap_aware_filegroup +pywrap_aware_genrule = _pywrap_aware_genrule +pywrap_aware_cc_import = _pywrap_aware_cc_import diff --git a/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/pywrap.default.bzl b/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/pywrap.default.bzl new file mode 100644 index 00000000000000..b1514f100a4934 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/pywrap.default.bzl @@ -0,0 +1,146 @@ +# TODO(b/356020232): remove entire file and all usages after migration is done +load("@python_version_repo//:py_version.bzl", "USE_PYWRAP_RULES") +load( + "//third_party/py/rules_pywrap:pywrap.impl.bzl", + _pybind_extension = "pybind_extension", + _pywrap_common_library = "pywrap_common_library", + _pywrap_library = "pywrap_library", + _stripped_cc_info = "stripped_cc_info", +) + +def pybind_extension( + name, # original + deps, # original + srcs = [], # original + private_deps = [], # original + visibility = None, # original + win_def_file = None, # original + testonly = None, # original + compatible_with = None, # original + outer_module_name = "", # deprecate + additional_exported_symbols = [], + data = None, # original + # Garbage parameters, exist only to maingain backward compatibility for + # a while. Will be removed once migration is fully completed + + # To patch top-level deps lists in sophisticated cases + pywrap_ignored_deps_filter = ["@pybind11", "@pybind11//:pybind11"], + pywrap_private_deps_filter = [ + "@pybind11_abseil//pybind11_abseil:absl_casters", + "@pybind11_abseil//pybind11_abseil:import_status_module", + "@pybind11_abseil//pybind11_abseil:status_casters", + "@pybind11_protobuf//pybind11_protobuf:native_proto_caster", + ], + pytype_srcs = None, # alias for data + hdrs = [], # merge into sources + pytype_deps = None, # ignore? + ignore_link_in_framework = None, # ignore + dynamic_deps = [], # ignore + static_deps = [], # ignore + enable_stub_generation = None, # ignore + module_name = None, # ignore + link_in_framework = None, # ignore + additional_stubgen_deps = None, # ignore + **kwargs): + _ignore = [ + ignore_link_in_framework, + dynamic_deps, + static_deps, + enable_stub_generation, + module_name, + link_in_framework, + additional_stubgen_deps, + pytype_deps, + ] + + private_deps_filter_dict = {k: None for k in pywrap_private_deps_filter} + ignored_deps_filter_dict = {k: None for k in pywrap_ignored_deps_filter} + + actual_srcs = srcs + hdrs + + actual_data = data + if pytype_srcs: + data = pytype_srcs + + actual_deps = [] + actual_private_deps = [] + actual_default_deps = ["@pybind11//:pybind11"] + + if type(deps) == list: + for dep in deps: + if dep in ignored_deps_filter_dict: + continue + if dep in private_deps_filter_dict: + actual_private_deps.append(dep) + continue + actual_deps.append(dep) + else: + actual_deps = deps + actual_default_deps = [] + + _pybind_extension( + name = name, + deps = actual_deps, + srcs = actual_srcs, + private_deps = actual_private_deps, + visibility = visibility, + win_def_file = win_def_file, + testonly = testonly, + compatible_with = compatible_with, + outer_module_name = outer_module_name, + additional_exported_symbols = additional_exported_symbols, + data = actual_data, + default_deps = actual_default_deps, + **kwargs + ) + +def use_pywrap_rules(): + return USE_PYWRAP_RULES + +def pywrap_library(name, **kwargs): + if use_pywrap_rules(): + _pywrap_library( + name = name, + **kwargs + ) + +def pywrap_common_library(name, **kwargs): + if use_pywrap_rules(): + _pywrap_common_library( + name = name, + **kwargs + ) + +def stripped_cc_info(name, **kwargs): + if use_pywrap_rules(): + _stripped_cc_info( + name = name, + **kwargs + ) + +def pywrap_aware_filegroup(name, **kwargs): + if use_pywrap_rules(): + pass + else: + native.filegroup( + name = name, + **kwargs + ) + +def pywrap_aware_genrule(name, **kwargs): + if use_pywrap_rules(): + pass + else: + native.genrule( + name = name, + **kwargs + ) + +def pywrap_aware_cc_import(name, **kwargs): + if use_pywrap_rules(): + pass + else: + native.cc_import( + name = name, + **kwargs + ) diff --git a/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/pywrap.impl.bzl b/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/pywrap.impl.bzl new file mode 100644 index 00000000000000..f33012c3876523 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/py/rules_pywrap/pywrap.impl.bzl @@ -0,0 +1,731 @@ +load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain", "use_cpp_toolchain") + +PywrapInfo = provider( + fields = { + "cc_info": "Wrapped CcInfo", + "private_deps": "Libraries to link only to individual pywrap libraries, but not in commmon library", + "owner": "Owner's label", + "py_stub": "Pybind Python stub used to resolve cross-package references", + "outer_module_name": "Outer module name for deduping libraries with the same name", + "cc_only": "True if this PywrapInfo represents cc-only library (no PyIni_)", + }, +) + +CollectedPywrapInfo = provider( + fields = { + "pywrap_infos": "depset of PywrapInfo providers", + }, +) + +PywrapFilters = provider( + fields = { + "py_cc_linker_inputs": "", + "cc_linker_inputs": "", + "pywrap_private_linker_inputs": "", + }, +) + +def pywrap_library( + name, + deps, + py_cc_deps_filter = [], + cc_deps_filter = [], + linkopts = [], + py_cc_linkopts = [], + win_def_file = None, + py_cc_win_def_file = None, + pywrap_count = None, + extra_deps = ["@pybind11//:pybind11"], + visibility = None, + testonly = None, + compatible_with = None): + # 0) If pywrap_count is not specified, assume we pass pybind_extension, + # targets directly, so actual pywrap_count should just be equal to number + # of deps. + actual_pywrap_count = len(deps) if pywrap_count == None else pywrap_count + + # 1) Create common libraries cc-only (C API) and py-specific (parts reused + # by different pywrap libraries but dependin on Python symbols). + # The common library should link in everything except the object file with + # Python Extension's init function PyInit_. + info_collector_name = "_%s_info_collector" % name + collected_pywrap_infos( + name = info_collector_name, + deps = deps, + pywrap_count = actual_pywrap_count, + ) + + linker_input_filters_name = "_%s_linker_input_filters" % name + _linker_input_filters( + name = linker_input_filters_name, + dep = ":%s" % info_collector_name, + py_cc_deps_filter = py_cc_deps_filter, + cc_deps_filter = cc_deps_filter, + ) + + # _internal binary + common_split_name = "_%s_split" % name + _pywrap_split_library( + name = common_split_name, + mode = "cc_common", + dep = ":%s" % info_collector_name, + linker_input_filters = "%s" % linker_input_filters_name, + testonly = testonly, + compatible_with = compatible_with, + ) + + common_cc_binary_name = "%s_internal" % name + common_import_name = _construct_common_binary( + common_cc_binary_name, + [":%s" % common_split_name], + linkopts, + testonly, + compatible_with, + win_def_file, + None, + ) + + # _py_internal binary + py_common_split_name = "_%s_py_split" % name + _pywrap_split_library( + name = py_common_split_name, + mode = "py_common", + dep = ":%s" % info_collector_name, + linker_input_filters = "%s" % linker_input_filters_name, + testonly = testonly, + compatible_with = compatible_with, + ) + + common_py_cc_binary_name = "%s_py_internal" % name + common_py_import_name = _construct_common_binary( + common_py_cc_binary_name, + [ + ":%s" % py_common_split_name, + ":%s" % common_import_name, + "@pybind11//:pybind11", + ], + py_cc_linkopts, + testonly, + compatible_with, + py_cc_win_def_file, + ["PROTOBUF_USE_DLLS"], + ) + + common_deps = extra_deps + [ + ":%s" % common_import_name, + ":%s" % common_py_import_name, + ] + binaries_data = [ + ":%s" % common_cc_binary_name, + ":%s" % common_py_cc_binary_name, + ] + + # 2) Create individual super-thin pywrap libraries, which depend on the + # common one. The individual libraries must link in statically only the + # object file with Python Extension's init function PyInit_ + # + shared_objects = [] + for pywrap_index in range(0, actual_pywrap_count): + dep_name = "_%s_%s" % (name, pywrap_index) + shared_object_name = "%s_shared_object" % dep_name + win_def_name = "%s_win_def" % dep_name + pywrap_name = "%s_pywrap" % dep_name + + _pywrap_split_library( + name = pywrap_name, + mode = "pywrap", + dep = ":%s" % info_collector_name, + linker_input_filters = "%s" % linker_input_filters_name, + pywrap_index = pywrap_index, + testonly = testonly, + compatible_with = compatible_with, + ) + + _generated_win_def_file( + name = win_def_name, + dep = ":%s" % info_collector_name, + pywrap_index = pywrap_index, + testonly = testonly, + compatible_with = compatible_with, + ) + + native.cc_binary( + name = shared_object_name, + srcs = [], + deps = [":%s" % pywrap_name] + common_deps, + linkshared = True, + linkstatic = True, + win_def_file = ":%s" % win_def_name, + testonly = testonly, + compatible_with = compatible_with, + local_defines = ["PROTOBUF_USE_DLLS"], + ) + shared_objects.append(":%s" % shared_object_name) + + # 3) Construct final binaries with proper names and put them as data + # attribute in a py_library, which is the final and only public artifact of + # this macro + # + pywrap_binaries_name = "_%s_binaries" % name + _pywrap_binaries( + name = pywrap_binaries_name, + collected_pywraps = ":%s" % info_collector_name, + deps = shared_objects, + extension = select({ + "@bazel_tools//src/conditions:windows": ".pyd", + "//conditions:default": ".so", + }), + testonly = testonly, + compatible_with = compatible_with, + ) + + binaries_data.append("%s" % pywrap_binaries_name) + binaries_data.extend([shared_objects[0]]) + + native.py_library( + name = name, + srcs = [":%s" % info_collector_name], + data = binaries_data, + testonly = testonly, + compatible_with = compatible_with, + visibility = visibility, + ) + + # For debugging purposes only + native.filegroup( + name = "_%s_all_binaries" % name, + srcs = binaries_data, + testonly = testonly, + compatible_with = compatible_with, + ) + +def _construct_common_binary( + name, + deps, + linkopts, + testonly, + compatible_with, + win_def_file, + local_defines): + native.cc_binary( + name = name, + deps = deps, + linkstatic = True, + linkshared = True, + linkopts = linkopts, + testonly = testonly, + compatible_with = compatible_with, + win_def_file = win_def_file, + local_defines = local_defines, + ) + + if_lib_name = "%s_if_lib" % name + native.filegroup( + name = if_lib_name, + srcs = [":%s" % name], + output_group = "interface_library", + testonly = testonly, + compatible_with = compatible_with, + ) + + import_name = "%s_import" % name + native.cc_import( + name = import_name, + shared_library = ":%s" % name, + interface_library = ":%s" % if_lib_name, + testonly = testonly, + compatible_with = compatible_with, + ) + + return import_name + +def _pywrap_split_library_impl(ctx): + pywrap_index = ctx.attr.pywrap_index + pywrap_infos = ctx.attr.dep[CollectedPywrapInfo].pywrap_infos.to_list() + split_linker_inputs = [] + private_linker_inputs = [] + + mode = ctx.attr.mode + filters = ctx.attr.linker_input_filters[PywrapFilters] + py_cc_linker_inputs = filters.py_cc_linker_inputs + + if mode == "pywrap": + pw = pywrap_infos[pywrap_index] + + # print("%s matches %s" % (str(pw.owner), ctx.label)) + if not pw.cc_only: + li = pw.cc_info.linking_context.linker_inputs.to_list()[0] + split_linker_inputs.append(li) + private_linker_inputs = [ + depset(direct = filters.pywrap_private_linker_inputs[pywrap_index].keys()), + ] + else: + for i in range(0, len(pywrap_infos)): + pw = pywrap_infos[i] + pw_private_linker_inputs = filters.pywrap_private_linker_inputs[i] + pw_lis = pw.cc_info.linking_context.linker_inputs.to_list()[1:] + for li in pw_lis: + if li in pw_private_linker_inputs: + continue + if li in filters.py_cc_linker_inputs: + if mode == "py_common": + split_linker_inputs.append(li) + elif mode == "cc_common": + split_linker_inputs.append(li) + + dependency_libraries = _construct_dependency_libraries( + ctx, + split_linker_inputs, + ) + + linker_input = cc_common.create_linker_input( + owner = ctx.label, + libraries = depset(direct = dependency_libraries), + ) + + linking_context = cc_common.create_linking_context( + linker_inputs = depset( + direct = [linker_input], + transitive = private_linker_inputs, + ), + ) + + return [CcInfo(linking_context = linking_context)] + +_pywrap_split_library = rule( + attrs = { + "dep": attr.label( + allow_files = False, + providers = [CollectedPywrapInfo], + ), + # py_deps, meaning C++ deps which depend on Python symbols + "linker_input_filters": attr.label( + allow_files = False, + providers = [PywrapFilters], + mandatory = True, + ), + "pywrap_index": attr.int(mandatory = False, default = -1), + "mode": attr.string( + mandatory = True, + values = ["pywrap", "cc_common", "py_common"], + ), + "_cc_toolchain": attr.label( + default = "@bazel_tools//tools/cpp:current_cc_toolchain", + ), + }, + fragments = ["cpp"], + toolchains = use_cpp_toolchain(), + implementation = _pywrap_split_library_impl, +) + +def _construct_dependency_libraries(ctx, split_linker_inputs): + cc_toolchain = find_cpp_toolchain(ctx) + feature_configuration = cc_common.configure_features( + ctx = ctx, + cc_toolchain = cc_toolchain, + requested_features = ctx.features, + unsupported_features = ctx.disabled_features, + ) + dependency_libraries = [] + for split_linker_input in split_linker_inputs: + for lib in split_linker_input.libraries: + lib_copy = lib + if not lib.alwayslink: + lib_copy = cc_common.create_library_to_link( + actions = ctx.actions, + cc_toolchain = cc_toolchain, + feature_configuration = feature_configuration, + static_library = lib.static_library, + pic_static_library = lib.pic_static_library, + interface_library = lib.interface_library, + alwayslink = True, + ) + dependency_libraries.append(lib_copy) + + return dependency_libraries + +def _linker_input_filters_impl(ctx): + py_cc_linker_inputs = {} + for py_cc_dep in ctx.attr.py_cc_deps_filter: + for li in py_cc_dep[CcInfo].linking_context.linker_inputs.to_list()[:1]: + py_cc_linker_inputs[li] = li.owner + + cc_linker_inputs = {} + for cc_dep in ctx.attr.cc_deps_filter: + for li in cc_dep[CcInfo].linking_context.linker_inputs.to_list()[:1]: + cc_linker_inputs[li] = li.owner + + pywrap_infos = ctx.attr.dep[CollectedPywrapInfo].pywrap_infos.to_list() + pywrap_private_linker_inputs = [] + + for pw in pywrap_infos: + private_linker_inputs = {} + + for private_dep in pw.private_deps: + for priv_li in private_dep[CcInfo].linking_context.linker_inputs.to_list(): + if (priv_li not in py_cc_linker_inputs) and (priv_li not in cc_linker_inputs): + private_linker_inputs[priv_li] = priv_li.owner + pywrap_private_linker_inputs.append(private_linker_inputs) + + return [ + PywrapFilters( + py_cc_linker_inputs = py_cc_linker_inputs, + pywrap_private_linker_inputs = pywrap_private_linker_inputs, + ), + ] + +_linker_input_filters = rule( + attrs = { + "dep": attr.label( + allow_files = False, + providers = [CollectedPywrapInfo], + ), + "py_cc_deps_filter": attr.label_list( + allow_files = False, + providers = [CcInfo], + mandatory = False, + default = [], + ), + "cc_deps_filter": attr.label_list( + allow_files = False, + providers = [CcInfo], + mandatory = False, + default = [], + ), + }, + implementation = _linker_input_filters_impl, +) + +def pywrap_common_library(name, dep): + native.alias( + name = name, + actual = "%s_internal_import" % dep, + ) + +def pywrap_py_common_library(name, dep): + native.alias( + name = name, + actual = "%s_py_internal_import" % dep, + ) + +def _generated_win_def_file_impl(ctx): + pywrap_infos = ctx.attr.dep[CollectedPywrapInfo].pywrap_infos.to_list() + pywrap_info = pywrap_infos[ctx.attr.pywrap_index] + win_def_file_name = "%s.def" % pywrap_info.owner.name + win_def_file = ctx.actions.declare_file(win_def_file_name) + + if pywrap_info.cc_only: + command = "echo \"EXPORTS\r\n\">> {win_def_file}" + else: + command = "echo \"EXPORTS\r\n PyInit_{owner}\">> {win_def_file}" + + ctx.actions.run_shell( + inputs = [], + command = command.format( + owner = pywrap_info.owner.name, + win_def_file = win_def_file.path, + ), + outputs = [win_def_file], + ) + + return [DefaultInfo(files = depset(direct = [win_def_file]))] + +_generated_win_def_file = rule( + attrs = { + "dep": attr.label( + allow_files = False, + providers = [CollectedPywrapInfo], + ), + "pywrap_index": attr.int(mandatory = True), + }, + implementation = _generated_win_def_file_impl, +) + +def pybind_extension( + name, + deps, + srcs = [], + private_deps = [], + visibility = None, + win_def_file = None, + testonly = None, + compatible_with = None, + outer_module_name = "", + additional_exported_symbols = [], + default_deps = ["@pybind11//:pybind11"], + **kwargs): + cc_library_name = "_%s_cc_library" % name + + native.cc_library( + name = cc_library_name, + deps = deps + private_deps + default_deps, + srcs = srcs, + linkstatic = True, + alwayslink = True, + visibility = visibility, + testonly = testonly, + compatible_with = compatible_with, + local_defines = ["PROTOBUF_USE_DLLS"], + **kwargs + ) + + if not srcs: + _cc_only_pywrap_info_wrapper( + name = name, + deps = ["%s" % cc_library_name], + testonly = testonly, + compatible_with = compatible_with, + visibility = visibility, + ) + else: + _pywrap_info_wrapper( + name = name, + deps = ["%s" % cc_library_name], + private_deps = private_deps, + outer_module_name = outer_module_name, + additional_exported_symbols = additional_exported_symbols, + testonly = testonly, + compatible_with = compatible_with, + visibility = visibility, + ) + +def _pywrap_info_wrapper_impl(ctx): + #the attribute is called deps not dep to match aspect's attr_aspects + + if len(ctx.attr.deps) != 1: + fail("deps attribute must contain exactly one dependency") + + py_stub = ctx.actions.declare_file("%s.py" % ctx.attr.name) + substitutions = {} + outer_module_name = ctx.attr.outer_module_name + if outer_module_name: + val = 'outer_module_name = "%s."' % outer_module_name + substitutions['outer_module_name = "" # template_val'] = val + + additional_exported_symbols = ctx.attr.additional_exported_symbols + if additional_exported_symbols: + val = "extra_names = %s # template_val" % additional_exported_symbols + substitutions["extra_names = [] # template_val"] = val + + ctx.actions.expand_template( + template = ctx.file.py_stub_src, + output = py_stub, + substitutions = substitutions, + ) + + return [ + PyInfo(transitive_sources = depset()), + PywrapInfo( + cc_info = ctx.attr.deps[0][CcInfo], + private_deps = ctx.attr.private_deps, + owner = ctx.label, + py_stub = py_stub, + outer_module_name = outer_module_name, + cc_only = False, + ), + ] + +_pywrap_info_wrapper = rule( + attrs = { + "deps": attr.label_list(providers = [CcInfo]), + "private_deps": attr.label_list(providers = [CcInfo]), + "outer_module_name": attr.string(mandatory = False, default = ""), + "py_stub_src": attr.label( + allow_single_file = True, + default = Label("//third_party/py/rules_pywrap:pybind_extension.py.tpl"), + ), + "additional_exported_symbols": attr.string_list( + mandatory = False, + default = [], + ), + }, + implementation = _pywrap_info_wrapper_impl, +) + +def _cc_only_pywrap_info_wrapper_impl(ctx): + wrapped_dep = ctx.attr.deps[0] + return [ + PyInfo(transitive_sources = depset()), + PywrapInfo( + cc_info = wrapped_dep[CcInfo], + private_deps = [], + owner = ctx.label, + py_stub = None, + outer_module_name = None, + cc_only = True, + ), + ] + +_cc_only_pywrap_info_wrapper = rule( + attrs = { + "deps": attr.label_list(providers = [CcInfo]), + }, + implementation = _cc_only_pywrap_info_wrapper_impl, +) + +def _pywrap_info_collector_aspect_impl(target, ctx): + pywrap_infos = [] + transitive_pywrap_infos = [] + if PywrapInfo in target: + pywrap_infos.append(target[PywrapInfo]) + + if hasattr(ctx.rule.attr, "deps"): + for dep in ctx.rule.attr.deps: + if CollectedPywrapInfo in dep: + collected_pywrap_info = dep[CollectedPywrapInfo] + transitive_pywrap_infos.append(collected_pywrap_info.pywrap_infos) + + return [ + CollectedPywrapInfo( + pywrap_infos = depset( + direct = pywrap_infos, + transitive = transitive_pywrap_infos, + order = "topological", + ), + ), + ] + +_pywrap_info_collector_aspect = aspect( + attr_aspects = ["deps"], + implementation = _pywrap_info_collector_aspect_impl, +) + +def _collected_pywrap_infos_impl(ctx): + pywrap_infos = [] + for dep in ctx.attr.deps: + if CollectedPywrapInfo in dep: + pywrap_infos.append(dep[CollectedPywrapInfo].pywrap_infos) + + rv = CollectedPywrapInfo( + pywrap_infos = depset( + transitive = pywrap_infos, + order = "topological", + ), + ) + pywraps = rv.pywrap_infos.to_list() + + if ctx.attr.pywrap_count != len(pywraps): + found_pywraps = "\n ".join([str(pw.owner) for pw in pywraps]) + fail(""" + Number of actual pywrap libraries does not match expected pywrap_count. + Expected pywrap_count: {expected_pywrap_count} + Actual pywrap_count: {actual_pywra_count} + Actual pywrap libraries in the transitive closure of {label}: + {found_pywraps} + """.format( + expected_pywrap_count = ctx.attr.pywrap_count, + actual_pywra_count = len(pywraps), + label = ctx.label, + found_pywraps = found_pywraps, + )) + + py_stubs = [] + for pw in pywraps: + if pw.py_stub: + py_stubs.append(pw.py_stub) + + return [ + DefaultInfo(files = depset(direct = py_stubs)), + rv, + ] + +collected_pywrap_infos = rule( + attrs = { + "deps": attr.label_list( + aspects = [_pywrap_info_collector_aspect], + providers = [PyInfo], + ), + "pywrap_count": attr.int(mandatory = True, default = 1), + }, + implementation = _collected_pywrap_infos_impl, +) + +def _pywrap_binaries_impl(ctx): + deps = ctx.attr.deps + dep = ctx.attr.collected_pywraps + extension = ctx.attr.extension + + pywrap_infos = dep[CollectedPywrapInfo].pywrap_infos.to_list() + original_binaries = deps + + if len(pywrap_infos) != len(original_binaries): + fail() + + final_binaries = [] + original_to_final_binaries = [ + "\n\nvvv Shared objects corresondence map, target = {} vvv".format(ctx.label), + ] + for i in range(0, len(pywrap_infos)): + pywrap_info = pywrap_infos[i] + original_binary = original_binaries[i] + subfolder = "" + if pywrap_info.outer_module_name: + subfolder = pywrap_info.outer_module_name + "/" + final_binary_name = "%s%s%s" % (subfolder, pywrap_info.owner.name, extension) + final_binary = ctx.actions.declare_file(final_binary_name) + original_binary_file = original_binary.files.to_list()[0] + ctx.actions.run_shell( + inputs = [original_binary_file], + command = "cp {original} {final}".format( + original = original_binary_file.path, + final = final_binary.path, + ), + outputs = [final_binary], + ) + + original_to_final_binaries.append( + " '{original}' => '{final}'".format( + original = original_binary_file.path, + final = final_binary.path, + ), + ) + + final_binaries.append(final_binary) + + original_to_final_binaries.append( + "^^^ Shared objects corresondence map^^^\n\n", + ) + print("\n".join(original_to_final_binaries)) + + return [DefaultInfo(files = depset(direct = final_binaries))] + +_pywrap_binaries = rule( + attrs = { + "deps": attr.label_list(mandatory = True, allow_files = False), + "collected_pywraps": attr.label(mandatory = True, allow_files = False), + "extension": attr.string(default = ".so"), + }, + implementation = _pywrap_binaries_impl, +) + +def _stripped_cc_info_impl(ctx): + filtered_libraries = [] + + for dep in ctx.attr.deps: + cc_info = dep[CcInfo] + cc_linker_inputs = cc_info.linking_context.linker_inputs + linker_input = cc_linker_inputs.to_list()[0] + + for lib in linker_input.libraries: + filtered_libraries.append(lib) + + linker_input = cc_common.create_linker_input( + owner = ctx.label, + libraries = depset(direct = filtered_libraries), + ) + + linking_context = cc_common.create_linking_context( + linker_inputs = depset(direct = [linker_input]), + ) + + return [CcInfo(linking_context = linking_context)] + +stripped_cc_info = rule( + attrs = { + "deps": attr.label_list( + allow_files = False, + providers = [CcInfo], + ), + }, + implementation = _stripped_cc_info_impl, +) diff --git a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl index 492d591d208a81..83f52d9af9970a 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl +++ b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl @@ -9,7 +9,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9", - compiler = "/usr/lib/llvm-18/bin/clang", cuda_version = "12.3.2", cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", @@ -17,7 +16,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1", - compiler = "/usr/lib/llvm-18/bin/clang", cuda_version = "12.3.2", cudnn_version = "9.1.1", os = "ubuntu20.04-manylinux2014-multipython", @@ -25,8 +23,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", - compiler = "/dt9/usr/bin/gcc", - compiler_prefix = "/usr/bin", cuda_version = "12.3.2", cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", @@ -34,7 +30,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu22.04-clang_manylinux2014-cuda12.3-cudnn8.9", - compiler = "/usr/lib/llvm-18/bin/clang", cuda_version = "12.3.2", cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", @@ -42,8 +37,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu22.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", - compiler = "/dt9/usr/bin/gcc", - compiler_prefix = "/usr/bin", cuda_version = "12.3.2", cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", @@ -70,25 +63,6 @@ def initialize_rbe_configs(): "sigbuild-r2.16-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", "sigbuild-r2.16-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:40fcd1d05c672672b599d9cb3784dcf379d6aa876f043b46c6ab18237d5d4e10", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/dt9/usr/bin/gcc", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/dt9/usr/bin/gcc", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "GCC_HOST_COMPILER_PATH": "/dt9/usr/bin/gcc", - "GCC_HOST_COMPILER_PREFIX": "/usr/bin", - "HOST_CXX_COMPILER": "/dt9/usr/bin/gcc", - "HOST_C_COMPILER": "/dt9/usr/bin/gcc", - "TF_ENABLE_XLA": "1", - }, ) sigbuild_tf_configs( @@ -99,23 +73,6 @@ def initialize_rbe_configs(): "sigbuild-r2.16-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", "sigbuild-r2.16-clang-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:40fcd1d05c672672b599d9cb3784dcf379d6aa876f043b46c6ab18237d5d4e10", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/usr/lib/llvm-17/bin/clang", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/usr/lib/llvm-17/bin/clang", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "HOST_CXX_COMPILER": "/usr/lib/llvm-17/bin/clang", - "HOST_C_COMPILER": "/usr/lib/llvm-17/bin/clang", - "TF_ENABLE_XLA": "1", - }, ) sigbuild_tf_configs( @@ -126,25 +83,6 @@ def initialize_rbe_configs(): "sigbuild-r2.17-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:b6f572a897a69fa3311773f949b9aa9e81bc393e4fbe2c0d56d8afb03a6de080", "sigbuild-r2.17-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:8b856ad736147bb9c8bc9e1ec2c8e1ab17d36397905da7a5b63dadeff9310f0c", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/dt9/usr/bin/gcc", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/dt9/usr/bin/gcc", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "GCC_HOST_COMPILER_PATH": "/dt9/usr/bin/gcc", - "GCC_HOST_COMPILER_PREFIX": "/usr/bin", - "HOST_CXX_COMPILER": "/dt9/usr/bin/gcc", - "HOST_C_COMPILER": "/dt9/usr/bin/gcc", - "TF_ENABLE_XLA": "1", - }, ) sigbuild_tf_configs( @@ -155,23 +93,6 @@ def initialize_rbe_configs(): "sigbuild-r2.17-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:2d737fc9fe931507a89927eee792b1bb934215e6aaae58b1941586e3400e2645", "sigbuild-r2.17-clang-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:45ea78e79305f91cdae5a26094f80233bba54bbfbc612623381012f097035b9a", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/usr/lib/llvm-18/bin/clang", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/usr/lib/llvm-18/bin/clang", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "HOST_CXX_COMPILER": "/usr/lib/llvm-18/bin/clang", - "HOST_C_COMPILER": "/usr/lib/llvm-18/bin/clang", - "TF_ENABLE_XLA": "1", - }, ) sigbuild_tf_configs( @@ -182,21 +103,4 @@ def initialize_rbe_configs(): "sigbuild-r2.17-clang-cudnn9-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:daa5bdd802fe3def188e2200ed707c73d278f6f1930bf26c933d6ba041b0e027", "sigbuild-r2.17-clang-cudnn9-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:23e477895dd02e45df1056d4a0a9c4229dec3a20c23fb2f3fb5832ecbd0a29bc", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/usr/lib/llvm-18/bin/clang", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/usr/lib/llvm-18/bin/clang", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "HOST_CXX_COMPILER": "/usr/lib/llvm-18/bin/clang", - "HOST_C_COMPILER": "/usr/lib/llvm-18/bin/clang", - "TF_ENABLE_XLA": "1", - }, ) diff --git a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/containers.bzl b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/containers.bzl index c976ddc7dbbdd5..b22fbe0b65ad2e 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/containers.bzl +++ b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/containers.bzl @@ -7,12 +7,12 @@ container_digests = { # JAX manylinux2014 configs. "cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython": "sha256:45619e91f14faabddd79fe0cb1526df4c4ad92fc2e6ebdc725ea4419225429c3", "cuda12.1-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:8c266e5b0acd203aed5e8871b63f68a39d8d23f6d882e619797e58b973f7fe63", - "cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython": "sha256:7d948c3d2e3ab8867d600457b5666cc74c4206f08517791c95fc9a69b7cffefa", + "cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython": "sha256:fafe12fbe5bb02a21b9a95aa9dc3ac6d0e6276fcb7dd26bf1bb2d093b444b71a", "cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:9fefda035b4a12b24cd5bae56c7dbb9527a5fd06a41ced0a22ac86fe5ed26428", "cuda12.2-cudnn9.1-ubuntu20.04-manylinux2014-multipython": "sha256:0c78f3428cde36f041b758fc2f01d23d2f0dd72dec248f78667fb0c9d1f74cef", "cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:6f9524a2ed7f75255dc4be3a0c5e3bda581385a1c13e2fa890bc17fa62da95b2", "cuda12.3-cudnn8.9-ubuntu22.04-manylinux2014-multipython": "sha256:97b219abb22994cf0530771d536f26fe301bacd328f0485c38af3847c2ee6b14", - "cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython": "sha256:e590303ea55a0990c26db4640161120ff6bc4124152c62155d397ba22d2ca850", + "cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython": "sha256:a9acf6849a905079847074798405b18d4badc6270dc32076f9e7ac4b377e51a8", # ROCM, probably not all of them still in use "rocm-ubuntu18.04-manylinux2010-multipython": "sha256:6e953a09b145df338bcb03e9e36f99b291140c29b72d0a048fb6c5905ccad5eb", "rocm-ubuntu20.04-manylinux2014-multipython": "sha256:906faec7765fe5dd067f2b092b5d5f220c1fedde725fb42c83d031b4d6f32204", diff --git a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl index fe29d52ad1bef9..280b8d914283dd 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl +++ b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl @@ -9,34 +9,13 @@ def _container_image_uri(container_name): container = containers[container_name] return "docker://%s/%s@%s" % (container["registry"], container["repository"], container["digest"]) -def _tensorflow_rbe_config(name, compiler, os, rocm_version = None, cuda_version = None, cudnn_version = None, compiler_prefix = None): +def _tensorflow_rbe_config(name, os, rocm_version = None, cuda_version = None, cudnn_version = None): if cuda_version != None and rocm_version != None: fail("Specifying both cuda_version and rocm_version is not supported.") - env = { - "ABI_VERSION": "gcc", - "ABI_LIBC_VERSION": "glibc_2.19", - "BAZEL_COMPILER": compiler, - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CC": compiler, - "CLEAR_CACHE": "1", - "HOST_CXX_COMPILER": compiler, - "HOST_C_COMPILER": compiler, - } + env = {} if cuda_version != None: - # The cuda toolchain currently contains its own C++ toolchain definition, - # so we do not fetch local_config_cc. - env.update({ - "TF_ENABLE_XLA": "1", - "GCC_HOST_COMPILER_PATH": compiler if not compiler.endswith("clang") else "", - "GCC_HOST_COMPILER_PREFIX": compiler_prefix if compiler_prefix != None else "/usr/bin", - }) - cuda_version_in_container = ".".join(cuda_version.split(".")[:2]) cudnn_version_in_container = ".".join(cudnn_version.split(".")[:2]) container_name = "cuda%s-cudnn%s-%s" % ( @@ -49,13 +28,11 @@ def _tensorflow_rbe_config(name, compiler, os, rocm_version = None, cuda_version "container-image": container_image, "Pool": "default", } - elif rocm_version != None: # The rocm toolchain currently contains its own C++ toolchain definition, # so we do not fetch local_config_cc. env.update({ "TF_NEED_ROCM": "1", - "TF_ENABLE_XLA": "0", }) container_name = "rocm-%s" % (os) @@ -121,9 +98,8 @@ tensorflow_local_config = _tensorflow_local_config # Streamlined platform configuration for the SIG Build containers. # See //tensorflow/tools/tf_sig_build_dockerfiles -# These containers do not support ROCm and all have CUDA. We demand that the configuration -# provide all the env variables to remove hidden logic. -def sigbuild_tf_configs(name_container_map, env): +# These containers do not support ROCm and all have CUDA. +def sigbuild_tf_configs(name_container_map): for name, container in name_container_map.items(): exec_properties = { "container-image": container, diff --git a/third_party/xla/third_party/tsl/tsl/platform/BUILD b/third_party/xla/third_party/tsl/tsl/platform/BUILD index c32e7c3bce4e01..e4840baac4bafc 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/BUILD @@ -21,7 +21,7 @@ load( load("@local_xla//xla/tsl:tsl.default.bzl", "get_compatible_with_portable") load( "//tsl/platform:build_config.bzl", - "tf_cuda_libdevice_path_deps", + "tf_cuda_root_path_deps", "tf_error_logging_deps", "tf_fingerprint_deps", "tf_google_mobile_srcs_no_runtime", @@ -100,6 +100,17 @@ cc_library( ], ) +tsl_cc_test( + name = "cpu_info_test", + size = "small", + srcs = ["cpu_info_test.cc"], + deps = [ + ":platform_port", + ":test", + ":test_main", + ], +) + cc_library( name = "criticality", compatible_with = get_compatible_with_portable(), @@ -303,7 +314,6 @@ cc_library( ":strcat", ":stringprintf", ":types", - "//tsl/protobuf:error_codes_proto_impl_cc", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/functional:function_ref", @@ -311,6 +321,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/types:optional", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ] + tf_platform_deps("status"), ) @@ -322,8 +333,10 @@ cc_library( hdrs = ["status_to_from_proto.h"], deps = [ ":status", - "//tsl/protobuf:error_codes_proto_impl_cc", - "//tsl/protobuf:status_proto_cc", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", + "@local_xla//xla/tsl/protobuf:status_proto_cc", ] + tf_platform_deps("status"), ) @@ -336,7 +349,7 @@ cc_library( ":status", ":statusor", ":test", - "//tsl/protobuf:error_codes_proto_impl_cc", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -664,7 +677,7 @@ exports_files( "cpu_info.h", "crash_analysis.h", "criticality.h", - "cuda_libdevice_path.h", + "cuda_root_path.h", "demangle.h", "env.cc", "env.h", @@ -1186,10 +1199,10 @@ tsl_cc_test( ) cc_library( - name = "cuda_libdevice_path", + name = "cuda_root_path", compatible_with = get_compatible_with_portable(), - textual_hdrs = ["cuda_libdevice_path.h"], - deps = tf_cuda_libdevice_path_deps(), + textual_hdrs = ["cuda_root_path.h"], + deps = tf_cuda_root_path_deps(), ) cc_library( @@ -1348,11 +1361,11 @@ tsl_cc_test( ":status_to_from_proto", ":test", ":test_main", - "//tsl/protobuf:error_codes_proto_impl_cc", - "//tsl/protobuf:status_proto_cc", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:str_format", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", + "@local_xla//xla/tsl/protobuf:status_proto_cc", ], ) @@ -1382,7 +1395,7 @@ tsl_cc_test( ":statusor", ":test", ":test_main", - "//tsl/protobuf:error_codes_proto_impl_cc", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/platform/build_config.bzl b/third_party/xla/third_party/tsl/tsl/platform/build_config.bzl index bec0e8403b2488..4a22f84baf1493 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/build_config.bzl +++ b/third_party/xla/third_party/tsl/tsl/platform/build_config.bzl @@ -11,7 +11,7 @@ load( _tf_additional_rpc_deps = "tf_additional_rpc_deps", _tf_additional_tensor_coding_deps = "tf_additional_tensor_coding_deps", _tf_additional_test_deps = "tf_additional_test_deps", - _tf_cuda_libdevice_path_deps = "tf_cuda_libdevice_path_deps", + _tf_cuda_root_path_deps = "tf_cuda_root_path_deps", _tf_error_logging_deps = "tf_error_logging_deps", _tf_fingerprint_deps = "tf_fingerprint_deps", _tf_google_mobile_srcs_no_runtime = "tf_google_mobile_srcs_no_runtime", @@ -49,7 +49,7 @@ tf_additional_lib_hdrs = _tf_additional_lib_hdrs tf_additional_rpc_deps = _tf_additional_rpc_deps tf_additional_tensor_coding_deps = _tf_additional_tensor_coding_deps tf_additional_test_deps = _tf_additional_test_deps -tf_cuda_libdevice_path_deps = _tf_cuda_libdevice_path_deps +tf_cuda_root_path_deps = _tf_cuda_root_path_deps tf_error_logging_deps = _tf_error_logging_deps tf_fingerprint_deps = _tf_fingerprint_deps tf_google_mobile_srcs_no_runtime = _tf_google_mobile_srcs_no_runtime diff --git a/third_party/xla/third_party/tsl/tsl/platform/build_config_root.bzl b/third_party/xla/third_party/tsl/tsl/platform/build_config_root.bzl index 151e40d4c02e3d..fd87c70d761604 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/build_config_root.bzl +++ b/third_party/xla/third_party/tsl/tsl/platform/build_config_root.bzl @@ -8,6 +8,7 @@ load( _if_llvm_powerpc_available = "if_llvm_powerpc_available", _if_llvm_system_z_available = "if_llvm_system_z_available", _if_llvm_x86_available = "if_llvm_x86_available", + _if_pywrap = "if_pywrap", _if_static = "if_static", _if_static_and_not_mobile = "if_static_and_not_mobile", _tf_additional_grpc_deps_py = "tf_additional_grpc_deps_py", @@ -27,6 +28,7 @@ if_llvm_powerpc_available = _if_llvm_powerpc_available if_llvm_system_z_available = _if_llvm_system_z_available if_llvm_x86_available = _if_llvm_x86_available if_static = _if_static +if_pywrap = _if_pywrap if_static_and_not_mobile = _if_static_and_not_mobile tf_additional_grpc_deps_py = _tf_additional_grpc_deps_py tf_additional_license_deps = _tf_additional_license_deps diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD b/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD index 77a79f56e1ae44..d6a7f8cb6328f3 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD @@ -207,7 +207,6 @@ cc_library( copts = tsl_copts(), deps = [ ":http_request", - "//tsl/lib/gtl:map_util", "//tsl/platform:env", "//tsl/platform:errors", "//tsl/platform:macros", @@ -218,6 +217,7 @@ cc_library( "//tsl/platform:stringpiece", "//tsl/platform:types", "@curl", + "@local_xla//xla/tsl/lib/gtl:map_util", "@local_xla//xla/tsl/util:env_var", ], ) @@ -396,9 +396,9 @@ tsl_cc_test( "//tsl/platform:strcat", "//tsl/platform:test", "//tsl/platform:test_main", - "//tsl/profiler/utils:time_utils_impl", "@local_xla//xla/tsl/lib/core:status_test_util", "@local_xla//xla/tsl/profiler/backends/cpu:traceme_recorder_impl", + "@local_xla//xla/tsl/profiler/utils:time_utils_impl", ], ) @@ -437,12 +437,14 @@ tsl_cc_test( srcs = ["curl_http_request_test.cc"], deps = [ ":curl_http_request", + "//tsl/platform", "//tsl/platform:env_impl", "//tsl/platform:path", "//tsl/platform:platform_port", "//tsl/platform:test", "//tsl/platform:test_main", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.cc index 44eeab7f511fb9..fde422c2d04919 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.cc @@ -17,8 +17,8 @@ limitations under the License. #include +#include "xla/tsl/lib/gtl/map_util.h" #include "xla/tsl/util/env_var.h" -#include "tsl/lib/gtl/map_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/macros.h" #include "tsl/platform/scanner.h" diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request_test.cc index 31cde679f4978b..429006a3724bdc 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request_test.cc @@ -19,9 +19,11 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/mem.h" #include "tsl/platform/path.h" +#include "tsl/platform/platform.h" #include "tsl/platform/test.h" namespace tsl { @@ -495,9 +497,11 @@ TEST(CurlHttpRequestTest, GetRequest_CouldntResolveHost) { const auto& status = http_request.Send(); EXPECT_EQ(error::FAILED_PRECONDITION, status.code()); EXPECT_EQ( - "Error executing an HTTP request: libcurl code 6 meaning " - "'Couldn't resolve host name', error details: Could not resolve host " - "'metadata'", + absl::StrCat( + "Error executing an HTTP request: libcurl code 6 meaning ", + (kIsOpenSource ? "'Couldn't resolve host name', error details: " + : "'Could not resolve hostname', error details: "), + "Could not resolve host ", "'metadata'"), status.message()); EXPECT_EQ(0, http_request.GetResponseCode()); } diff --git a/third_party/xla/third_party/tsl/tsl/platform/cpu_info.h b/third_party/xla/third_party/tsl/tsl/platform/cpu_info.h index 68506b1d34ae8e..c8d3903ffa6f33 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cpu_info.h +++ b/third_party/xla/third_party/tsl/tsl/platform/cpu_info.h @@ -21,6 +21,7 @@ limitations under the License. // TODO(ahentz): This is not strictly required here but, for historical // reasons, many people depend on cpu_info.h in order to use kLittleEndian. #include "tsl/platform/byte_order.h" +#include "tsl/platform/platform.h" #if defined(_MSC_VER) // included so __cpuidex function is available for GETCPUID on Windows @@ -150,6 +151,24 @@ bool TestAarch64CPU(Aarch64CPU cpu); // Checks CPU registers to return hardware capabilities. bool TestCPUFeature(CPUFeature feature); +// Checks whether the current processor is x86. +constexpr bool IsX86CPU() { +#ifdef PLATFORM_IS_X86 + return true; +#else + return false; +#endif +} + +// Checks whether the current processor is aarch64. +constexpr bool IsAarch64CPU() { +#if defined(PLATFORM_IS_ARM64) && !defined(__APPLE__) && !defined(__OpenBSD__) + return true; +#else + return false; +#endif +} + // Returns CPU Vendor string (i.e. 'GenuineIntel', 'AuthenticAMD', etc.) std::string CPUVendorIDString(); diff --git a/third_party/xla/third_party/tsl/tsl/platform/cpu_info_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cpu_info_test.cc new file mode 100644 index 00000000000000..dbef5a57f47397 --- /dev/null +++ b/third_party/xla/third_party/tsl/tsl/platform/cpu_info_test.cc @@ -0,0 +1,36 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software + +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tsl/platform/cpu_info.h" + +#include "tsl/platform/test.h" + +namespace tsl { + +TEST(CPUInfo, CommonX86CPU) { + // CPUs from 1999 onwards support SSE. + if (port::TestCPUFeature(port::CPUFeature::SSE)) { + EXPECT_TRUE(port::IsX86CPU()); + } +} + +TEST(CPUInfo, Aarch64NeoverseV1CPU) { + if (port::TestAarch64CPU(port::Aarch64CPU::ARM_NEOVERSE_V1)) { + EXPECT_TRUE(port::IsAarch64CPU()); + } +} + +} // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/platform/cuda_libdevice_path.h b/third_party/xla/third_party/tsl/tsl/platform/cuda_root_path.h similarity index 90% rename from third_party/xla/third_party/tsl/tsl/platform/cuda_libdevice_path.h rename to third_party/xla/third_party/tsl/tsl/platform/cuda_root_path.h index d8c2b6d01daf43..65a9ca5a7acb0c 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cuda_libdevice_path.h +++ b/third_party/xla/third_party/tsl/tsl/platform/cuda_root_path.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PLATFORM_CUDA_LIBDEVICE_PATH_H_ -#define TENSORFLOW_TSL_PLATFORM_CUDA_LIBDEVICE_PATH_H_ +#ifndef TENSORFLOW_TSL_PLATFORM_CUDA_ROOT_PATH_H_ +#define TENSORFLOW_TSL_PLATFORM_CUDA_ROOT_PATH_H_ #include #include @@ -46,4 +46,4 @@ bool PreferPtxasFromPath(); } // namespace tsl -#endif // TENSORFLOW_TSL_PLATFORM_CUDA_LIBDEVICE_PATH_H_ +#endif // TENSORFLOW_TSL_PLATFORM_CUDA_ROOT_PATH_H_ diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/BUILD b/third_party/xla/third_party/tsl/tsl/platform/default/BUILD index 785a44cfff0e7c..bcddadc908c95f 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/default/BUILD @@ -56,9 +56,9 @@ cc_library( ) cc_library( - name = "cuda_libdevice_path", - srcs = ["cuda_libdevice_path.cc"], - hdrs = ["//tsl/platform:cuda_libdevice_path.h"], + name = "cuda_root_path", + srcs = ["cuda_root_path.cc"], + hdrs = ["//tsl/platform:cuda_root_path.h"], compatible_with = [], data = if_cuda_tools([ "@cuda_nvcc//:nvvm", @@ -157,12 +157,12 @@ cc_library( "//tsl/platform:threadpool_interface", "//tsl/platform:tracing", "//tsl/platform:types", - "//tsl/protobuf:error_codes_proto_impl_cc", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", "@eigen_archive//:eigen3", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -182,7 +182,7 @@ cc_library( "//tsl/platform:logging", "//tsl/platform:mutex", "//tsl/platform:strcat", - "//tsl/protobuf:error_codes_proto_impl_cc", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -573,6 +573,7 @@ bzl_library( name = "build_config_root_bzl", srcs = ["build_config_root.bzl"], # copybara:uncomment parse_tests = False, + deps = ["//third_party/py/rules_pywrap:pywrap_bzl"], ) # Export source files needed for mobile builds, which do not use granular targets. diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/build_config.bzl b/third_party/xla/third_party/tsl/tsl/platform/default/build_config.bzl index 726f54634c5661..b7ad02b93fc5fe 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/build_config.bzl +++ b/third_party/xla/third_party/tsl/tsl/platform/default/build_config.bzl @@ -8,6 +8,7 @@ load( "if_not_windows", "if_tsl_link_protobuf", ) +load("//third_party/py/rules_pywrap:pywrap.bzl", "use_pywrap_rules") load("//tsl/platform:build_config_root.bzl", "if_static") def well_known_proto_libs(): @@ -147,6 +148,8 @@ def _proto_py_outs(srcs, use_grpc_plugin = False): ret += [s[:-len(".proto")] + "_pb2_grpc.py" for s in srcs] return ret +# TODO(b/356020232): cleanup non use_pywrap_rules() parts and everythin relate +# to creation of header-only protobuf targets # Re-defined protocol buffer rule to allow building "header only" protocol # buffers, to avoid duplicate registrations. Also allows non-iterable cc_libs # containing select() statements. @@ -247,21 +250,29 @@ def cc_proto_library( }) impl_name = name + "_impl" - header_only_name = name + "_headers_only" - header_only_deps = tf_deps(protolib_deps, "_cc_headers_only") - if make_default_target_header_only: - native.alias( - name = name, - actual = header_only_name, - visibility = kwargs["visibility"], - ) - else: + if use_pywrap_rules(): native.alias( name = name, actual = impl_name, visibility = kwargs["visibility"], ) + else: + header_only_name = name + "_headers_only" + header_only_deps = tf_deps(protolib_deps, "_cc_headers_only") + + if make_default_target_header_only: + native.alias( + name = name, + actual = header_only_name, + visibility = kwargs["visibility"], + ) + else: + native.alias( + name = name, + actual = impl_name, + visibility = kwargs["visibility"], + ) native.cc_library( name = impl_name, @@ -272,14 +283,18 @@ def cc_proto_library( alwayslink = 1, **kwargs ) - native.cc_library( - name = header_only_name, - deps = [ - "@com_google_protobuf//:protobuf_headers", - ] + header_only_deps + if_tsl_link_protobuf([impl_name]), - hdrs = gen_hdrs, - **kwargs - ) + + if use_pywrap_rules(): + pass + else: + native.cc_library( + name = header_only_name, + deps = [ + "@com_google_protobuf//:protobuf_headers", + ] + header_only_deps + if_tsl_link_protobuf([impl_name]), + hdrs = gen_hdrs, + **kwargs + ) # Re-defined protocol buffer rule to allow setting service namespace. def cc_grpc_library( @@ -429,6 +444,8 @@ def py_proto_library( **kwargs ) +# TODO(b/356020232): cleanup non-use_pywrap_rules part and all logic reated to +# protobuf header-only targets after migration is done def tf_proto_library_cc( name, srcs = [], @@ -478,24 +495,36 @@ def tf_proto_library_cc( visibility = visibility, ) - native.alias( - name = cc_name + "_headers_only", - actual = cc_name, - testonly = testonly, - visibility = visibility, - ) + if use_pywrap_rules(): + pass + else: + native.alias( + name = cc_name + "_headers_only", + actual = cc_name, + testonly = testonly, + visibility = visibility, + ) + + native.cc_library( + name = cc_name, + deps = cc_deps + ["@com_google_protobuf//:protobuf_headers"] + if_tsl_link_protobuf([name + "_cc_impl"]), + testonly = testonly, + visibility = visibility, + ) - native.cc_library( - name = cc_name, - deps = cc_deps + ["@com_google_protobuf//:protobuf_headers"] + if_tsl_link_protobuf([name + "_cc_impl"]), - testonly = testonly, - visibility = visibility, - ) native.cc_library( name = cc_name + "_impl", deps = [s + "_impl" for s in cc_deps], ) + if use_pywrap_rules(): + native.cc_library( + name = cc_name, + deps = cc_deps + ["@com_google_protobuf//:protobuf_headers", cc_name + "_impl"], + testonly = testonly, + visibility = visibility, + ) + return cc_proto_library( @@ -726,7 +755,7 @@ def tf_lib_proto_parsing_deps(): return [ ":protos_all_cc", clean_dep("@eigen_archive//:eigen3"), - clean_dep("//tsl/protobuf:protos_all_cc"), + clean_dep("@local_xla//xla/tsl/protobuf:protos_all_cc"), ] def tf_py_clif_cc(name, visibility = None, **kwargs): @@ -760,6 +789,7 @@ def tf_protobuf_deps(): otherwise = [clean_dep("@com_google_protobuf//:protobuf_headers")], ) +# TODO(b/356020232): remove completely after migration is done # Link protobuf, unless the tsl_link_protobuf build flag is explicitly set to false. def tsl_protobuf_deps(): return if_tsl_link_protobuf([clean_dep("@com_google_protobuf//:protobuf")], [clean_dep("@com_google_protobuf//:protobuf_headers")]) @@ -778,9 +808,9 @@ def tsl_cc_test( clean_dep("@com_google_protobuf//:protobuf"), # TODO(ddunleavy) remove these and add proto deps to tests # granularly - clean_dep("//tsl/protobuf:error_codes_proto_impl_cc_impl"), - clean_dep("//tsl/protobuf:histogram_proto_cc_impl"), - clean_dep("//tsl/protobuf:status_proto_cc_impl"), + clean_dep("@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc_impl"), + clean_dep("@local_xla//xla/tsl/protobuf:histogram_proto_cc_impl"), + clean_dep("@local_xla//xla/tsl/protobuf:status_proto_cc_impl"), clean_dep("//tsl/profiler/protobuf:xplane_proto_cc_impl"), clean_dep("//tsl/profiler/protobuf:profiler_options_proto_cc_impl"), ], @@ -789,7 +819,7 @@ def tsl_cc_test( ) def tf_portable_proto_lib(): - return ["//tensorflow/core:protos_all_cc_impl", clean_dep("//tsl/protobuf:protos_all_cc_impl")] + return ["//tensorflow/core:protos_all_cc_impl", clean_dep("@local_xla//xla/tsl/protobuf:protos_all_cc_impl")] def tf_protobuf_compiler_deps(): return if_static( @@ -845,5 +875,5 @@ def tf_google_mobile_srcs_no_runtime(): def tf_google_mobile_srcs_only_runtime(): return [] -def tf_cuda_libdevice_path_deps(): - return tf_platform_deps("cuda_libdevice_path") +def tf_cuda_root_path_deps(): + return tf_platform_deps("cuda_root_path") diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/build_config_root.bzl b/third_party/xla/third_party/tsl/tsl/platform/default/build_config_root.bzl index 142641b16d2fa3..05559211a93bb2 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/build_config_root.bzl +++ b/third_party/xla/third_party/tsl/tsl/platform/default/build_config_root.bzl @@ -3,6 +3,7 @@ # be separate to avoid cyclic references. load("@local_config_remote_execution//:remote_execution.bzl", "gpu_test_tags") +load("//third_party/py/rules_pywrap:pywrap.bzl", "use_pywrap_rules") # RBE settings for tests that require a GPU. This is used in exec_properties of rules # that need GPU access. @@ -39,12 +40,16 @@ def tf_additional_license_deps(): def tf_additional_tpu_ops_deps(): return [] +# TODO(b/356020232): remove completely after migration is done # Include specific extra dependencies when building statically, or # another set of dependencies otherwise. If "macos" is provided, that # dependency list is used when using the framework_shared_object config # on MacOS platforms. If "macos" is not provided, the "otherwise" list is # used for all framework_shared_object platforms including MacOS. def if_static(extra_deps, otherwise = [], macos = []): + if use_pywrap_rules(): + return extra_deps + ret = { str(Label("@local_xla//xla/tsl:framework_shared_object")): otherwise, "//conditions:default": extra_deps, @@ -53,7 +58,11 @@ def if_static(extra_deps, otherwise = [], macos = []): ret[str(Label("@local_xla//xla/tsl:macos_with_framework_shared_object"))] = macos return select(ret) +# TODO(b/356020232): remove completely after migration is done def if_static_and_not_mobile(extra_deps, otherwise = []): + if use_pywrap_rules(): + return extra_deps + return select({ str(Label("@local_xla//xla/tsl:framework_shared_object")): otherwise, str(Label("@local_xla//xla/tsl:android")): otherwise, @@ -61,6 +70,10 @@ def if_static_and_not_mobile(extra_deps, otherwise = []): "//conditions:default": extra_deps, }) +# TODO(b/356020232): remove completely after migration is done +def if_pywrap(if_true = [], if_false = []): + return if_true if use_pywrap_rules() else if_false + def if_llvm_aarch32_available(then, otherwise = []): return select({ str(Label("@local_xla//xla/tsl:aarch32_or_cross")): then, diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/cuda_libdevice_path.cc b/third_party/xla/third_party/tsl/tsl/platform/default/cuda_root_path.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/platform/default/cuda_libdevice_path.cc rename to third_party/xla/third_party/tsl/tsl/platform/default/cuda_root_path.cc index ac0a804b4dfd42..ca6da0e5532eaa 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/cuda_libdevice_path.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/default/cuda_root_path.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/platform/cuda_libdevice_path.h" +#include "tsl/platform/cuda_root_path.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/env.cc b/third_party/xla/third_party/tsl/tsl/platform/default/env.cc index 6786be68aa4efc..e91c35454f892f 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/env.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/default/env.cc @@ -35,6 +35,7 @@ limitations under the License. #include #include +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tsl/platform/default/posix_file_system.h" #include "tsl/platform/env.h" #include "tsl/platform/load_library.h" @@ -42,7 +43,6 @@ limitations under the License. #include "tsl/platform/mutex.h" #include "tsl/platform/ram_file_system.h" #include "tsl/platform/strcat.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/posix_file_system.cc b/third_party/xla/third_party/tsl/tsl/platform/default/posix_file_system.cc index d1b2109823f35e..834bbef63ab5cc 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/posix_file_system.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/default/posix_file_system.cc @@ -29,6 +29,7 @@ limitations under the License. #include #include +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tsl/platform/default/posix_file_system.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" @@ -36,7 +37,6 @@ limitations under the License. #include "tsl/platform/logging.h" #include "tsl/platform/status.h" #include "tsl/platform/strcat.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/platform/ml_dtypes.h b/third_party/xla/third_party/tsl/tsl/platform/ml_dtypes.h index 916be8db4f6998..89a40bd891e106 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/ml_dtypes.h +++ b/third_party/xla/third_party/tsl/tsl/platform/ml_dtypes.h @@ -20,6 +20,8 @@ limitations under the License. #include "ml_dtypes/include/intn.h" // from @ml_dtypes namespace tsl { +using float8_e3m4 = ::ml_dtypes::float8_e3m4; +using float8_e4m3 = ::ml_dtypes::float8_e4m3; using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn; using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz; using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz; diff --git a/third_party/xla/third_party/tsl/tsl/platform/status.cc b/third_party/xla/third_party/tsl/tsl/platform/status.cc index 85f7290e05079b..f6d4aed1d71984 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/status.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/status.cc @@ -39,13 +39,13 @@ limitations under the License. #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tsl/platform/mutex.h" #include "tsl/platform/stack_frame.h" #include "tsl/platform/stacktrace.h" #include "tsl/platform/str_util.h" #include "tsl/platform/strcat.h" #include "tsl/platform/stringprintf.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/platform/status.h b/third_party/xla/third_party/tsl/tsl/platform/status.h index 84954ff485a48b..6fbbf53a851464 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/status.h +++ b/third_party/xla/third_party/tsl/tsl/platform/status.h @@ -32,12 +32,12 @@ limitations under the License. #include "absl/strings/cord.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tsl/platform/logging.h" #include "tsl/platform/macros.h" #include "tsl/platform/platform.h" #include "tsl/platform/stack_frame.h" #include "tsl/platform/types.h" -#include "tsl/protobuf/error_codes.pb.h" // Include appropriate platform-dependent parts of status. #if defined(PLATFORM_GOOGLE) diff --git a/third_party/xla/third_party/tsl/tsl/platform/status_matchers.cc b/third_party/xla/third_party/tsl/tsl/platform/status_matchers.cc index 77422d564f8cda..bcb04018dbc7f9 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/status_matchers.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/status_matchers.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tsl/platform/status.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tsl { namespace testing { diff --git a/third_party/xla/third_party/tsl/tsl/platform/status_matchers.h b/third_party/xla/third_party/tsl/tsl/platform/status_matchers.h index cb8eba40783093..e7e12c269d28e0 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/status_matchers.h +++ b/third_party/xla/third_party/tsl/tsl/platform/status_matchers.h @@ -19,10 +19,10 @@ limitations under the License. #include #include +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/error_codes.pb.h" // Defines the following utilities: // diff --git a/third_party/xla/third_party/tsl/tsl/platform/status_matchers_test.cc b/third_party/xla/third_party/tsl/tsl/platform/status_matchers_test.cc index ea0e0d489c24bf..3a681f6f3aed31 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/status_matchers_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/status_matchers_test.cc @@ -18,11 +18,11 @@ limitations under the License. #include #include +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tsl { namespace testing { diff --git a/third_party/xla/third_party/tsl/tsl/platform/status_test.cc b/third_party/xla/third_party/tsl/tsl/platform/status_test.cc index 6d9948fa68d99b..e716a15b96e46e 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/status_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/status_test.cc @@ -12,24 +12,27 @@ limitations under the License. ==============================================================================*/ #include "tsl/platform/status.h" +#include #include #include #include "absl/status/status.h" #include "absl/strings/cord.h" #include "absl/strings/str_format.h" +#include "xla/tsl/protobuf/error_codes.pb.h" +#include "xla/tsl/protobuf/status.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/stack_frame.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/status_to_from_proto.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/error_codes.pb.h" -#include "tsl/protobuf/status.pb.h" namespace tsl { namespace { +using ::testing::ElementsAre; using ::testing::IsEmpty; +using ::testing::Pair; using ::tsl::testing::IsOk; using ::tsl::testing::StatusIs; @@ -188,16 +191,18 @@ TEST(Status, SaveOKStatusToProto) { } TEST(Status, SaveErrorStatusToProto) { - tensorflow::StatusProto status_proto = - StatusToProto(errors::NotFound("Not found")); + tensorflow::StatusProto status_proto = StatusToProto(errors::Create( + absl::StatusCode::kNotFound, "Not found", {{"foo", "bar"}})); EXPECT_EQ(status_proto.code(), error::NOT_FOUND); EXPECT_EQ(status_proto.message(), "Not found"); + EXPECT_THAT(status_proto.payload(), ElementsAre(Pair("foo", "bar"))); } TEST(Status, SaveEmptyStatusToProto) { tensorflow::StatusProto status_proto = StatusToProto(absl::Status()); EXPECT_EQ(status_proto.code(), error::OK); EXPECT_THAT(status_proto.message(), IsEmpty()); + EXPECT_THAT(status_proto.payload(), IsEmpty()); } TEST(Status, MakeOKStatusFromProto) { @@ -210,8 +215,10 @@ TEST(Status, MakeErrorStatusFromProto) { tensorflow::StatusProto status_proto; status_proto.set_code(error::INVALID_ARGUMENT); status_proto.set_message("Invalid argument"); - EXPECT_THAT(StatusFromProto(status_proto), - StatusIs(error::INVALID_ARGUMENT, "Invalid argument")); + status_proto.mutable_payload()->insert({"foo", "bar"}); + absl::Status s = StatusFromProto(status_proto); + EXPECT_THAT(s, StatusIs(error::INVALID_ARGUMENT, "Invalid argument")); + EXPECT_EQ(s.GetPayload("foo"), "bar"); } TEST(Status, MakeStatusFromEmptyProto) { diff --git a/third_party/xla/third_party/tsl/tsl/platform/status_to_from_proto.cc b/third_party/xla/third_party/tsl/tsl/platform/status_to_from_proto.cc index 96ad290f92c71a..54e2b2ef3391ab 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/status_to_from_proto.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/status_to_from_proto.cc @@ -16,9 +16,11 @@ limitations under the License. #include +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "xla/tsl/protobuf/error_codes.pb.h" +#include "xla/tsl/protobuf/status.pb.h" #include "tsl/platform/status.h" -#include "tsl/protobuf/error_codes.pb.h" -#include "tsl/protobuf/status.pb.h" namespace tsl { @@ -32,6 +34,12 @@ tensorflow::StatusProto StatusToProto(const absl::Status& s) { if (!s.message().empty()) { status_proto.set_message(std::string(s.message())); } + + s.ForEachPayload( + [&status_proto](absl::string_view type_url, absl::Cord value) { + status_proto.mutable_payload()->insert( + {std::string(type_url), std::string(value)}); + }); return status_proto; } @@ -41,15 +49,23 @@ absl::Status StatusFromProto(const tensorflow::StatusProto& proto, if (proto.code() == tensorflow::error::OK) { return absl::OkStatus(); } - return absl::Status(static_cast(proto.code()), - proto.message(), loc); + absl::Status s(static_cast(proto.code()), proto.message(), + loc); + for (const auto& [key, payload] : proto.payload()) { + s.SetPayload(key, absl::Cord(payload)); + } + return s; } #else Status StatusFromProto(const tensorflow::StatusProto& proto) { if (proto.code() == tensorflow::error::OK) { return OkStatus(); } - return Status(static_cast(proto.code()), proto.message()); + Status s(static_cast(proto.code()), proto.message()); + for (const auto& [key, payload] : proto.payload()) { + s.SetPayload(key, absl::Cord(payload)); + } + return s; } #endif diff --git a/third_party/xla/third_party/tsl/tsl/platform/status_to_from_proto.h b/third_party/xla/third_party/tsl/tsl/platform/status_to_from_proto.h index 9891737f08159c..021e002ae4041d 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/status_to_from_proto.h +++ b/third_party/xla/third_party/tsl/tsl/platform/status_to_from_proto.h @@ -15,8 +15,8 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PLATFORM_STATUS_TO_FROM_PROTO_H_ #define TENSORFLOW_TSL_PLATFORM_STATUS_TO_FROM_PROTO_H_ +#include "xla/tsl/protobuf/status.pb.h" #include "tsl/platform/status.h" -#include "tsl/protobuf/status.pb.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/platform/windows/BUILD b/third_party/xla/third_party/tsl/tsl/platform/windows/BUILD index bd07c1b07f59ef..48737e5c6c531b 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/windows/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/windows/BUILD @@ -71,11 +71,11 @@ cc_library( "//tsl/platform:threadpool_interface", "//tsl/platform:tracing", "//tsl/platform:types", - "//tsl/protobuf:error_codes_proto_impl_cc", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/time", "@com_google_absl//absl/types:optional", "@eigen_archive//:eigen3", + "@local_xla//xla/tsl/protobuf:error_codes_proto_impl_cc", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/platform/windows/env.cc b/third_party/xla/third_party/tsl/tsl/platform/windows/env.cc index 13fb4515d5a9fd..58fb1d83afda9a 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/windows/env.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/windows/env.cc @@ -28,12 +28,12 @@ limitations under the License. #include #include +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tsl/platform/load_library.h" #include "tsl/platform/logging.h" #include "tsl/platform/ram_file_system.h" #include "tsl/platform/windows/wide_char.h" #include "tsl/platform/windows/windows_file_system.h" -#include "tsl/protobuf/error_codes.pb.h" #pragma comment(lib, "shlwapi.lib") diff --git a/third_party/xla/third_party/tsl/tsl/platform/windows/windows_file_system.cc b/third_party/xla/third_party/tsl/tsl/platform/windows/windows_file_system.cc index cd57c2c92e9bec..faa90a82530f2d 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/windows/windows_file_system.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/windows/windows_file_system.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/file_system_helper.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tsl/platform/strcat.h" #include "tsl/platform/windows/error_windows.h" #include "tsl/platform/windows/wide_char.h" -#include "tsl/protobuf/error_codes.pb.h" // TODO(mrry): Prevent this Windows.h #define from leaking out of our headers. #undef DeleteFile diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD index ee6e82ea4fa185..b68d8b55302ad2 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/BUILD @@ -1,5 +1,5 @@ load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("@local_xla//xla/tsl:tsl.bzl", "if_not_android", "internal_visibility", "nvtx_headers") +load("@local_xla//xla/tsl:tsl.bzl", "if_not_android", "if_oss", "internal_visibility", "nvtx_headers") load("@local_xla//xla/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") load( "@local_xla//xla/tsl/profiler/builds:build_config.bzl", @@ -38,7 +38,7 @@ filegroup( "scoped_memory_debug_annotation.h", "traceme.h", "traceme_encode.h", - "//tsl/profiler/utils:mobile_srcs_no_runtime", + "@local_xla//xla/tsl/profiler/utils:mobile_srcs_no_runtime", ], compatible_with = get_compatible_with_portable(), visibility = ["//visibility:public"], @@ -221,7 +221,7 @@ cc_library( "//tsl/platform:platform_port", "//tsl/platform:status", "@local_xla//xla/tsl/profiler/convert:post_process_single_host_xplane", - "//tsl/profiler/utils:time_utils", + "@local_xla//xla/tsl/profiler/utils:time_utils", ]), alwayslink = True, ) @@ -268,11 +268,11 @@ cc_library( "//tsl/platform:logging", "//tsl/platform:macros", "//tsl/platform:types", - "//tsl/profiler/utils:no_init", "@com_google_absl//absl/strings", + "@local_xla//xla/tsl/profiler/utils:no_init", ] + if_not_android([ "@local_xla//xla/tsl/profiler/backends/cpu:traceme_recorder", - "//tsl/profiler/utils:time_utils", + "@local_xla//xla/tsl/profiler/utils:time_utils", ]), ) @@ -283,6 +283,7 @@ cc_library( ["nvtx_utils_stub.cc"], ), hdrs = ["nvtx_utils.h"], + local_defines = if_oss(["NVTX_VERSION_3_1=1"]), visibility = ["//visibility:public"], deps = if_cuda_is_configured(nvtx_headers()), ) diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.cc b/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.cc index 4943fba0c1bfea..e2a80dd1e7776f 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.cc @@ -92,7 +92,11 @@ void RangePush(ProfilerDomainHandle domain, StringHandle title, attrs.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; attrs.messageType = NVTX_MESSAGE_TYPE_REGISTERED; attrs.message.registered = reinterpret_cast(title); +#ifdef NVTX_VERSION_3_1 + NVTX_PAYLOAD_EVTATTR_SET(&attrs, schema_id, payload, payload_size); +#else NVTX_PAYLOAD_EVTATTR_SET(attrs, schema_id, payload, payload_size); +#endif nvtxDomainRangePushEx(reinterpret_cast(domain), &attrs); } } // namespace detail diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_session.cc b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_session.cc index 30e718cc456c94..2932415dceae2e 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_session.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_session.cc @@ -28,12 +28,12 @@ limitations under the License. #if !defined(IS_MOBILE_PLATFORM) #include "xla/tsl/profiler/convert/post_process_single_host_xplane.h" +#include "xla/tsl/profiler/utils/time_utils.h" #include "tsl/platform/host_info.h" #include "tsl/profiler/lib/profiler_collection.h" #include "tsl/profiler/lib/profiler_factory.h" #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/lib/profiler_lock.h" -#include "tsl/profiler/utils/time_utils.h" #endif namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h index 4218d1c848a02a..ac5f0c14aba35c 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h @@ -21,14 +21,14 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/no_init.h" #include "tsl/platform/logging.h" #include "tsl/platform/macros.h" #include "tsl/profiler/lib/traceme_encode.h" // IWYU pragma: export -#include "tsl/profiler/utils/no_init.h" #if !defined(IS_MOBILE_PLATFORM) #include "xla/tsl/profiler/backends/cpu/traceme_recorder.h" -#include "tsl/profiler/utils/time_utils.h" +#include "xla/tsl/profiler/utils/time_utils.h" #endif namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/protobuf/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/protobuf/BUILD index bab55bb8764edc..401b08515a8b1e 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/protobuf/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/protobuf/BUILD @@ -71,7 +71,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "trace_events_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":trace_events_proto"], # ) @@ -111,7 +110,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "xplane_py_pb2", -# api_version = 2, # visibility = internal_visibility([":friends"]), # deps = [":xplane_proto"], # ) diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/BUILD b/third_party/xla/third_party/tsl/tsl/protobuf/BUILD deleted file mode 100644 index 6922da086f24f4..00000000000000 --- a/third_party/xla/third_party/tsl/tsl/protobuf/BUILD +++ /dev/null @@ -1,112 +0,0 @@ -# Placeholder: load py_proto_library -load( - "@local_xla//xla/tsl:tsl.bzl", - "if_google", - "internal_visibility", -) -load( - "//tsl/platform:build_config.bzl", - "tf_proto_library", -) - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = internal_visibility([ - "//tensorflow/core:__subpackages__", - "@local_xla//xla/tsl:internal", - "//tensorflow_models:__subpackages__", - ]), - features = if_google(["-parse_headers"]), - licenses = ["notice"], -) - -tf_proto_library( - name = "dnn_proto", - srcs = ["dnn.proto"], - make_default_target_header_only = True, - protodeps = if_google(["//google/protobuf:wrappers"]), - visibility = ["//visibility:public"], -) - -tf_proto_library( - name = "error_codes_proto_impl", - srcs = ["error_codes.proto"], - make_default_target_header_only = True, - protodeps = if_google(["//google/protobuf:any"]), - visibility = ["//visibility:public"], -) - -tf_proto_library( - name = "status_proto", - srcs = ["status.proto"], - make_default_target_header_only = True, - protodeps = [":error_codes_proto_impl"], - visibility = ["//visibility:public"], -) - -tf_proto_library( - name = "histogram_proto", - srcs = ["histogram.proto"], - make_default_target_header_only = True, - visibility = ["//visibility:public"], -) - -tf_proto_library( - name = "coordination_config_proto", - srcs = ["coordination_config.proto"], - make_default_target_header_only = True, - visibility = ["//visibility:public"], -) - -tf_proto_library( - name = "coordination_service_proto", - srcs = ["coordination_service.proto"], - has_services = 1, - create_grpc_library = True, - create_java_proto = False, - create_service = True, - protodeps = if_google(["//google/protobuf:any"]), - visibility = ["//visibility:public"], -) - -# copybara:uncomment_begin(google-only) -# py_proto_library( -# name = "coordination_service_py_pb2", -# api_version = 2, -# visibility = ["//visibility:public"], -# deps = [":coordination_service_proto"], -# ) -# copybara:uncomment_end - -tf_proto_library( - name = "distributed_runtime_payloads_proto", - srcs = ["distributed_runtime_payloads.proto"], - make_default_target_header_only = True, - visibility = ["//visibility:public"], -) - -tf_proto_library( - name = "rpc_options_proto", - srcs = ["rpc_options.proto"], - make_default_target_header_only = True, - visibility = ["//visibility:public"], -) - -tf_proto_library( - name = "protos_all", - create_go_proto = False, - make_default_target_header_only = True, - protodeps = [ - # TODO(tlongeri): Conceptually, these fit into protos_all but adding them currently causes - # breakages (and they are not actually used). - "@local_xla//xla/tsl/protobuf:bfc_memory_map_proto", - ":coordination_config_proto", - ":distributed_runtime_payloads_proto", - ":error_codes_proto_impl", - ":histogram_proto", - ":rpc_options_proto", - ":status_proto", - "@local_xla//xla/tsl/protobuf:test_log_proto", - ] + if_google(["//google/protobuf:any"]), - visibility = ["//visibility:public"], -) diff --git a/third_party/xla/third_party/tsl/workspace2.bzl b/third_party/xla/third_party/tsl/workspace2.bzl index 90485629a8a445..05661246820911 100644 --- a/third_party/xla/third_party/tsl/workspace2.bzl +++ b/third_party/xla/third_party/tsl/workspace2.bzl @@ -393,18 +393,17 @@ def _tf_repositories(): name = "nccl_archive", build_file = "//third_party:nccl/archive.BUILD", patch_file = ["//third_party/nccl:archive.patch"], - sha256 = "1923596984d85e310b5b6c52b2c72a1b93da57218f2bc5a5c7ac3d59297a3303", - strip_prefix = "nccl-2.21.5-1", - urls = tf_mirror_urls("https://github.com/nvidia/nccl/archive/v2.21.5-1.tar.gz"), + sha256 = "6b946b70a9d2d01871842cbd15ec56488d358abe9a0f3767e372fddc3e241ba7", + strip_prefix = "nccl-2.23.4-1", + urls = tf_mirror_urls("https://github.com/nvidia/nccl/archive/v2.23.4-1.tar.gz"), ) - # Note that we are currently taking NVTX headers from a NCCL release to get nvToolsExtPayload.h tf_http_archive( name = "nvtx_archive", build_file = "//third_party:nvtx/BUILD", - sha256 = "1923596984d85e310b5b6c52b2c72a1b93da57218f2bc5a5c7ac3d59297a3303", - strip_prefix = "nccl-2.21.5-1/src/include/nvtx3", - urls = tf_mirror_urls("https://github.com/nvidia/nccl/archive/v2.21.5-1.tar.gz"), + sha256 = "e4438f921fb88a564b0b92791c1c1fdd0f388901213e6a31fdd0dc3803fb9764", + strip_prefix = "NVTX-bf31d7859ab3130cbf1ef77c33d18d0ebb8c8d08/c/include", + urls = tf_mirror_urls("https://github.com/NVIDIA/NVTX/archive/bf31d7859ab3130cbf1ef77c33d18d0ebb8c8d08.tar.gz"), ) java_import_external( @@ -476,14 +475,6 @@ def _tf_repositories(): licenses = ["notice"], # Apache 2.0 ) - tf_http_archive( - name = "nvtx_archive", - build_file = "//third_party:nvtx.BUILD", - sha256 = "bb8d1536aad708ec807bc675e12e5838c2f84481dec4005cd7a9bbd49e326ba1", - strip_prefix = "NVTX-3.0.1/c/include", - urls = tf_mirror_urls("https://github.com/NVIDIA/NVTX/archive/v3.0.1.tar.gz"), - ) - tf_http_archive( name = "cython", build_file = "//third_party:cython.BUILD", diff --git a/third_party/xla/tools/toolchains/remote_config/configs.bzl b/third_party/xla/tools/toolchains/remote_config/configs.bzl index 492d591d208a81..83f52d9af9970a 100644 --- a/third_party/xla/tools/toolchains/remote_config/configs.bzl +++ b/third_party/xla/tools/toolchains/remote_config/configs.bzl @@ -9,7 +9,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9", - compiler = "/usr/lib/llvm-18/bin/clang", cuda_version = "12.3.2", cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", @@ -17,7 +16,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1", - compiler = "/usr/lib/llvm-18/bin/clang", cuda_version = "12.3.2", cudnn_version = "9.1.1", os = "ubuntu20.04-manylinux2014-multipython", @@ -25,8 +23,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", - compiler = "/dt9/usr/bin/gcc", - compiler_prefix = "/usr/bin", cuda_version = "12.3.2", cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", @@ -34,7 +30,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu22.04-clang_manylinux2014-cuda12.3-cudnn8.9", - compiler = "/usr/lib/llvm-18/bin/clang", cuda_version = "12.3.2", cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", @@ -42,8 +37,6 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu22.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", - compiler = "/dt9/usr/bin/gcc", - compiler_prefix = "/usr/bin", cuda_version = "12.3.2", cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", @@ -70,25 +63,6 @@ def initialize_rbe_configs(): "sigbuild-r2.16-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", "sigbuild-r2.16-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:40fcd1d05c672672b599d9cb3784dcf379d6aa876f043b46c6ab18237d5d4e10", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/dt9/usr/bin/gcc", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/dt9/usr/bin/gcc", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "GCC_HOST_COMPILER_PATH": "/dt9/usr/bin/gcc", - "GCC_HOST_COMPILER_PREFIX": "/usr/bin", - "HOST_CXX_COMPILER": "/dt9/usr/bin/gcc", - "HOST_C_COMPILER": "/dt9/usr/bin/gcc", - "TF_ENABLE_XLA": "1", - }, ) sigbuild_tf_configs( @@ -99,23 +73,6 @@ def initialize_rbe_configs(): "sigbuild-r2.16-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:842a5ba84d3658c5bf1f8a31e16284f7becc35409da0dfd71816afa3cd28d728", "sigbuild-r2.16-clang-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:40fcd1d05c672672b599d9cb3784dcf379d6aa876f043b46c6ab18237d5d4e10", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/usr/lib/llvm-17/bin/clang", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/usr/lib/llvm-17/bin/clang", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "HOST_CXX_COMPILER": "/usr/lib/llvm-17/bin/clang", - "HOST_C_COMPILER": "/usr/lib/llvm-17/bin/clang", - "TF_ENABLE_XLA": "1", - }, ) sigbuild_tf_configs( @@ -126,25 +83,6 @@ def initialize_rbe_configs(): "sigbuild-r2.17-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:b6f572a897a69fa3311773f949b9aa9e81bc393e4fbe2c0d56d8afb03a6de080", "sigbuild-r2.17-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:8b856ad736147bb9c8bc9e1ec2c8e1ab17d36397905da7a5b63dadeff9310f0c", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/dt9/usr/bin/gcc", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/dt9/usr/bin/gcc", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "GCC_HOST_COMPILER_PATH": "/dt9/usr/bin/gcc", - "GCC_HOST_COMPILER_PREFIX": "/usr/bin", - "HOST_CXX_COMPILER": "/dt9/usr/bin/gcc", - "HOST_C_COMPILER": "/dt9/usr/bin/gcc", - "TF_ENABLE_XLA": "1", - }, ) sigbuild_tf_configs( @@ -155,23 +93,6 @@ def initialize_rbe_configs(): "sigbuild-r2.17-clang-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:2d737fc9fe931507a89927eee792b1bb934215e6aaae58b1941586e3400e2645", "sigbuild-r2.17-clang-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:45ea78e79305f91cdae5a26094f80233bba54bbfbc612623381012f097035b9a", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/usr/lib/llvm-18/bin/clang", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/usr/lib/llvm-18/bin/clang", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "HOST_CXX_COMPILER": "/usr/lib/llvm-18/bin/clang", - "HOST_C_COMPILER": "/usr/lib/llvm-18/bin/clang", - "TF_ENABLE_XLA": "1", - }, ) sigbuild_tf_configs( @@ -182,21 +103,4 @@ def initialize_rbe_configs(): "sigbuild-r2.17-clang-cudnn9-python3.11": "docker://gcr.io/tensorflow-sigs/build@sha256:daa5bdd802fe3def188e2200ed707c73d278f6f1930bf26c933d6ba041b0e027", "sigbuild-r2.17-clang-cudnn9-python3.12": "docker://gcr.io/tensorflow-sigs/build@sha256:23e477895dd02e45df1056d4a0a9c4229dec3a20c23fb2f3fb5832ecbd0a29bc", }, - # Unclear why LIBC is set to 2.19 here, and yet manylinux2010 is 2.12 - # and manylinux2014 is 2.17. - env = { - "ABI_LIBC_VERSION": "glibc_2.19", - "ABI_VERSION": "gcc", - "BAZEL_COMPILER": "/usr/lib/llvm-18/bin/clang", - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC": "/usr/lib/llvm-18/bin/clang", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CLEAR_CACHE": "1", - "HOST_CXX_COMPILER": "/usr/lib/llvm-18/bin/clang", - "HOST_C_COMPILER": "/usr/lib/llvm-18/bin/clang", - "TF_ENABLE_XLA": "1", - }, ) diff --git a/third_party/xla/tools/toolchains/remote_config/containers.bzl b/third_party/xla/tools/toolchains/remote_config/containers.bzl index c976ddc7dbbdd5..b22fbe0b65ad2e 100644 --- a/third_party/xla/tools/toolchains/remote_config/containers.bzl +++ b/third_party/xla/tools/toolchains/remote_config/containers.bzl @@ -7,12 +7,12 @@ container_digests = { # JAX manylinux2014 configs. "cuda11.8-cudnn8.6-ubuntu20.04-manylinux2014-multipython": "sha256:45619e91f14faabddd79fe0cb1526df4c4ad92fc2e6ebdc725ea4419225429c3", "cuda12.1-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:8c266e5b0acd203aed5e8871b63f68a39d8d23f6d882e619797e58b973f7fe63", - "cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython": "sha256:7d948c3d2e3ab8867d600457b5666cc74c4206f08517791c95fc9a69b7cffefa", + "cuda12.1-cudnn9.1-ubuntu20.04-manylinux2014-multipython": "sha256:fafe12fbe5bb02a21b9a95aa9dc3ac6d0e6276fcb7dd26bf1bb2d093b444b71a", "cuda12.2-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:9fefda035b4a12b24cd5bae56c7dbb9527a5fd06a41ced0a22ac86fe5ed26428", "cuda12.2-cudnn9.1-ubuntu20.04-manylinux2014-multipython": "sha256:0c78f3428cde36f041b758fc2f01d23d2f0dd72dec248f78667fb0c9d1f74cef", "cuda12.3-cudnn8.9-ubuntu20.04-manylinux2014-multipython": "sha256:6f9524a2ed7f75255dc4be3a0c5e3bda581385a1c13e2fa890bc17fa62da95b2", "cuda12.3-cudnn8.9-ubuntu22.04-manylinux2014-multipython": "sha256:97b219abb22994cf0530771d536f26fe301bacd328f0485c38af3847c2ee6b14", - "cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython": "sha256:e590303ea55a0990c26db4640161120ff6bc4124152c62155d397ba22d2ca850", + "cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython": "sha256:a9acf6849a905079847074798405b18d4badc6270dc32076f9e7ac4b377e51a8", # ROCM, probably not all of them still in use "rocm-ubuntu18.04-manylinux2010-multipython": "sha256:6e953a09b145df338bcb03e9e36f99b291140c29b72d0a048fb6c5905ccad5eb", "rocm-ubuntu20.04-manylinux2014-multipython": "sha256:906faec7765fe5dd067f2b092b5d5f220c1fedde725fb42c83d031b4d6f32204", diff --git a/third_party/xla/tools/toolchains/remote_config/rbe_config.bzl b/third_party/xla/tools/toolchains/remote_config/rbe_config.bzl index fe29d52ad1bef9..280b8d914283dd 100644 --- a/third_party/xla/tools/toolchains/remote_config/rbe_config.bzl +++ b/third_party/xla/tools/toolchains/remote_config/rbe_config.bzl @@ -9,34 +9,13 @@ def _container_image_uri(container_name): container = containers[container_name] return "docker://%s/%s@%s" % (container["registry"], container["repository"], container["digest"]) -def _tensorflow_rbe_config(name, compiler, os, rocm_version = None, cuda_version = None, cudnn_version = None, compiler_prefix = None): +def _tensorflow_rbe_config(name, os, rocm_version = None, cuda_version = None, cudnn_version = None): if cuda_version != None and rocm_version != None: fail("Specifying both cuda_version and rocm_version is not supported.") - env = { - "ABI_VERSION": "gcc", - "ABI_LIBC_VERSION": "glibc_2.19", - "BAZEL_COMPILER": compiler, - "BAZEL_HOST_SYSTEM": "i686-unknown-linux-gnu", - "BAZEL_TARGET_LIBC": "glibc_2.19", - "BAZEL_TARGET_CPU": "k8", - "BAZEL_TARGET_SYSTEM": "x86_64-unknown-linux-gnu", - "CC_TOOLCHAIN_NAME": "linux_gnu_x86", - "CC": compiler, - "CLEAR_CACHE": "1", - "HOST_CXX_COMPILER": compiler, - "HOST_C_COMPILER": compiler, - } + env = {} if cuda_version != None: - # The cuda toolchain currently contains its own C++ toolchain definition, - # so we do not fetch local_config_cc. - env.update({ - "TF_ENABLE_XLA": "1", - "GCC_HOST_COMPILER_PATH": compiler if not compiler.endswith("clang") else "", - "GCC_HOST_COMPILER_PREFIX": compiler_prefix if compiler_prefix != None else "/usr/bin", - }) - cuda_version_in_container = ".".join(cuda_version.split(".")[:2]) cudnn_version_in_container = ".".join(cudnn_version.split(".")[:2]) container_name = "cuda%s-cudnn%s-%s" % ( @@ -49,13 +28,11 @@ def _tensorflow_rbe_config(name, compiler, os, rocm_version = None, cuda_version "container-image": container_image, "Pool": "default", } - elif rocm_version != None: # The rocm toolchain currently contains its own C++ toolchain definition, # so we do not fetch local_config_cc. env.update({ "TF_NEED_ROCM": "1", - "TF_ENABLE_XLA": "0", }) container_name = "rocm-%s" % (os) @@ -121,9 +98,8 @@ tensorflow_local_config = _tensorflow_local_config # Streamlined platform configuration for the SIG Build containers. # See //tensorflow/tools/tf_sig_build_dockerfiles -# These containers do not support ROCm and all have CUDA. We demand that the configuration -# provide all the env variables to remove hidden logic. -def sigbuild_tf_configs(name_container_map, env): +# These containers do not support ROCm and all have CUDA. +def sigbuild_tf_configs(name_container_map): for name, container in name_container_map.items(): exec_properties = { "container-image": container, diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index c9316d9c546119..b7f1a874dada1b 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -316,6 +316,7 @@ xla_cc_test( ":util", "@com_google_absl//absl/base", "@com_google_absl//absl/numeric:bits", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:test_main", ], @@ -336,6 +337,8 @@ cc_library( ":status_macros", ":types", ":xla_data_proto_cc", + "//xla/tsl/lib/gtl:iterator_range", + "//xla/tsl/lib/math:math_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", @@ -343,6 +346,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/numeric:bits", @@ -353,8 +357,6 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", - "@local_tsl//tsl/lib/gtl:iterator_range", - "@local_tsl//tsl/lib/math:math_util", "@local_tsl//tsl/platform:bfloat16", "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:env", @@ -373,9 +375,12 @@ xla_cc_test( ":test", ":types", ":util", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:test_main", + "@ml_dtypes//:float8", ], ) @@ -386,6 +391,7 @@ cc_library( deps = [ ":types", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:logging", ], @@ -487,7 +493,7 @@ cc_library( deps = [ ":status_macros", ":util", - "//xla/service:hlo_lexer", + "//xla/hlo/parser:hlo_lexer", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -880,7 +886,10 @@ cc_library( ":xla_data_proto_cc", "@com_google_absl//absl/base", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:protobuf", ], @@ -910,13 +919,16 @@ cc_library( ":types", ":util", ":xla_data_proto_cc", - "//xla/service:hlo_parser", + "//xla/hlo/parser:hlo_parser", + "//xla/tsl/lib/io:buffered_inputstream", + "//xla/tsl/lib/io:random_inputstream", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/io:buffered_inputstream", - "@local_tsl//tsl/lib/io:random_inputstream", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:statusor", ], ) @@ -946,6 +958,7 @@ cc_library( ":status_macros", ":types", ":xla_data_proto_cc", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:env", @@ -976,12 +989,12 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":shape_util", + "//xla/tsl/lib/gtl:iterator_range", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/gtl:iterator_range", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", @@ -1038,6 +1051,7 @@ xla_cc_test( deps = [ ":test", ":window_util", + ":xla_data_proto_cc", "@local_tsl//tsl/platform:test_main", ], ) @@ -1057,17 +1071,17 @@ cc_library( ":util", ":window_util", ":xla_data_proto_cc", - "//xla/client:padding", - "//xla/client:xla_builder", + "//xla/hlo/builder:padding", + "//xla/hlo/builder:xla_builder", "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", "//xla/service:hlo_module_config", "//xla/service:shape_inference", + "//xla/tsl/lib/math:math_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/math:math_util", "@local_tsl//tsl/platform:logging", ], ) @@ -1085,7 +1099,7 @@ xla_cc_test( ":reference_util", ":test", ":xla_data_proto_cc", - "//xla/client:padding", + "//xla/hlo/builder:padding", "//xla/tests:literal_test_util", "@local_tsl//tsl/platform:test_main", ], @@ -1099,9 +1113,11 @@ cc_library( [ ":types", "//xla/tsl/util:command_line_flags", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:logging", ], @@ -1117,6 +1133,7 @@ xla_cc_test( "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:subprocess", "@local_tsl//tsl/platform:test", ], @@ -1136,6 +1153,7 @@ cc_library( [ ":parse_flags_from_env", ":xla_proto_cc", + "//xla/service:collective_utils", "//xla/stream_executor/cuda:nvjitlink_support", "//xla/stream_executor/cuda:ptx_compiler_support", "//xla/tsl/util:command_line_flags", @@ -1160,6 +1178,7 @@ cc_library( visibility = internal_visibility([":friends"]), deps = [ "@com_google_absl//absl/base:dynamic_annotations", + "@eigen_archive//:eigen3", ], ) @@ -1170,11 +1189,17 @@ xla_cc_test( "debug_options_parsers.h", "debug_options_parsers_test.cc", ], + # TODO(https://github.com/openxla/xla/issues/17808) + tags = ["noasan"], deps = [ + ":debug_options_flags", + ":parse_flags_from_env", ":xla_proto_cc", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:test", ], ) @@ -1228,7 +1253,6 @@ tf_proto_library( xla_py_proto_library( name = "autotune_results_py_pb2", - api_version = 2, visibility = ["//visibility:public"], deps = [ ":autotune_results_proto", @@ -1248,7 +1272,7 @@ tf_proto_library( srcs = ["autotuning.proto"], make_default_target_header_only = True, protodeps = [ - "@local_tsl//tsl/protobuf:dnn_proto", + "//xla/tsl/protobuf:dnn_proto", ] + if_google([ "@com_google_protobuf//:any", "@com_google_protobuf//:duration", @@ -1301,7 +1325,6 @@ cc_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "xla_data_proto_py_pb2", -# api_version = 2, # visibility = internal_visibility([":friends"]), # deps = [":xla_data_proto"], # ) @@ -1309,7 +1332,6 @@ cc_library( # py_proto_library( # name = "xla_py_pb2", # testonly = 0, -# api_version = 2, # compatible_with = get_compatible_with_portable(), # visibility = internal_visibility([":friends"]), # deps = [":xla_proto"], @@ -1349,6 +1371,7 @@ bzl_library( deps = [ "//xla/tsl:tsl_bzl", "@bazel_skylib//lib:paths", + "@local_tsl//tsl/platform/default:cuda_build_defs_bzl", ], ) diff --git a/third_party/xla/xla/array.cc b/third_party/xla/xla/array.cc index 84fbc12125dc89..b75941fbe87e79 100644 --- a/third_party/xla/xla/array.cc +++ b/third_party/xla/xla/array.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/array.h" #include +#include #include #include "xla/types.h" diff --git a/third_party/xla/xla/array.h b/third_party/xla/xla/array.h index 03c5f3b9760c4b..1d28388c563117 100644 --- a/third_party/xla/xla/array.h +++ b/third_party/xla/xla/array.h @@ -20,8 +20,10 @@ limitations under the License. #include #include #include +#include #include #include +#include #include #include #include @@ -268,11 +270,19 @@ class Array { // Fills the array with random uniform variables in the [min_value, max_value] // range. Defined for integral types. - template ::value>> + template >> void FillRandomUniform(const T& min_value, const T& max_value, int seed = 12345) { + using RngInputType = + std::conditional_t, T, + std::conditional_t::is_signed, + int64_t, uint64_t>>; + static_assert(std::numeric_limits::digits <= + std::numeric_limits::digits); std::mt19937 g(seed); - std::uniform_int_distribution distribution(min_value, max_value); + std::uniform_int_distribution distribution( + static_cast(min_value), + static_cast(max_value)); for (int64_t i = 0; i < num_elements(); ++i) { values_[i] = static_cast(distribution(g)); } diff --git a/third_party/xla/xla/array2d.h b/third_party/xla/xla/array2d.h index 9afe514f5458c1..98942f954b4efc 100644 --- a/third_party/xla/xla/array2d.h +++ b/third_party/xla/xla/array2d.h @@ -57,6 +57,17 @@ class Array2D : public Array { : Array(values) {} Array2D(const Array2D& other) : Array(other) {} + Array2D(Array2D&& other) noexcept : Array(std::move(other)) {} + + Array2D& operator=(const Array2D& other) { + Array::operator=(other); + return *this; + } + + Array2D& operator=(Array2D&& other) noexcept { + Array::operator=(std::move(other)); + return *this; + } int64_t n1() const { return this->dim(0); } int64_t n2() const { return this->dim(1); } diff --git a/third_party/xla/xla/array2d_test.cc b/third_party/xla/xla/array2d_test.cc index 4d0fbf3732ff9a..921da30256fa3d 100644 --- a/third_party/xla/xla/array2d_test.cc +++ b/third_party/xla/xla/array2d_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/array2d.h" +#include #include #include @@ -162,6 +163,20 @@ TEST(Array2dTest, LinspaceF8E5M2) { EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 3.5); } +TEST(Array2dTest, LinspaceF8E4M3) { + auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); + + EXPECT_EQ(arr->n1(), 3); + EXPECT_EQ(arr->n2(), 2); + + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 0)), 1.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 1)), 1.5); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 0)), 2.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 1)), 2.5); + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 0)), 3.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 3.5); +} + TEST(Array2dTest, LinspaceF8E4M3Fn) { auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); @@ -190,6 +205,20 @@ TEST(Array2dTest, LinspaceF8E4M3FnNoNan) { } } +TEST(Array2dTest, LinspaceF8E3M4) { + auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); + + EXPECT_EQ(arr->n1(), 3); + EXPECT_EQ(arr->n2(), 2); + + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 0)), 1.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(0, 1)), 1.5); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 0)), 2.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(1, 1)), 2.5); + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 0)), 3.0); + EXPECT_FLOAT_EQ(static_cast((*arr)(2, 1)), 3.5); +} + TEST(Array2dTest, Stringification) { auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2); const std::string expected = R"([[1, 1.5], diff --git a/third_party/xla/xla/array3d_test.cc b/third_party/xla/xla/array3d_test.cc index 07599797fbbb0d..334d733266b41b 100644 --- a/third_party/xla/xla/array3d_test.cc +++ b/third_party/xla/xla/array3d_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/array3d.h" +#include #include #include "xla/test.h" diff --git a/third_party/xla/xla/array4d_test.cc b/third_party/xla/xla/array4d_test.cc index 9acae92eef3b89..1deb1bc81f3c7e 100644 --- a/third_party/xla/xla/array4d_test.cc +++ b/third_party/xla/xla/array4d_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/array4d.h" +#include #include #include #include diff --git a/third_party/xla/xla/array_test.cc b/third_party/xla/xla/array_test.cc index f8a8f09ba3629b..bf79aa98f40491 100644 --- a/third_party/xla/xla/array_test.cc +++ b/third_party/xla/xla/array_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/array.h" +#include #include #include #include diff --git a/third_party/xla/xla/autotuning.proto b/third_party/xla/xla/autotuning.proto index a7ffcbb57ae6ef..b3d6b8e380b4b8 100644 --- a/third_party/xla/xla/autotuning.proto +++ b/third_party/xla/xla/autotuning.proto @@ -9,7 +9,7 @@ package xla; import "google/protobuf/any.proto"; import "google/protobuf/duration.proto"; -import "tsl/protobuf/dnn.proto"; +import "xla/tsl/protobuf/dnn.proto"; message CudnnVersion { int32 major = 1; @@ -83,6 +83,10 @@ message AutotuneResult { int64 num_ctas = 7; } + message CustomKernelFusionKey { + int64 kernel_index = 1; + } + int64 scratch_bytes = 8; google.protobuf.Duration run_time = 9; @@ -93,10 +97,11 @@ message AutotuneResult { GemmKey gemm = 6; TritonGemmKey triton = 17; CudaConvPlanKey cuda_conv_plan = 15; + CustomKernelFusionKey custom_kernel_fusion = 18; stream_executor.dnn.AlgorithmProto algorithm = 16; } - // Next ID: 17 + // Next ID: 19 } message AutotuningLog { diff --git a/third_party/xla/xla/backends/cpu/codegen/ir/BUILD b/third_party/xla/xla/backends/cpu/codegen/ir/BUILD new file mode 100644 index 00000000000000..f1afa0fcc7bf37 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/ir/BUILD @@ -0,0 +1,107 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +td_library( + name = "xla_cpu_td_files", + srcs = glob(["*.td"]), + includes = ["."], + deps = [ + "@llvm-project//mlir:BuiltinDialectTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "xla_cpu_dialect_inc_gen", + strip_include_prefix = ".", + tbl_outs = [ + ( + ["-gen-dialect-decls"], + "xla_cpu_dialect.h.inc", + ), + ( + ["-gen-dialect-defs"], + "xla_cpu_dialect.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "xla_cpu_dialect.td", + deps = [":xla_cpu_td_files"], +) + +gentbl_cc_library( + name = "xla_cpu_types_inc_gen", + strip_include_prefix = ".", + tbl_outs = [ + ( + [ + "-gen-typedef-decls", + "-typedefs-dialect=xla_cpu", + ], + "xla_cpu_types.h.inc", + ), + ( + [ + "-gen-typedef-defs", + "-typedefs-dialect=xla_cpu", + ], + "xla_cpu_types.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "xla_cpu_types.td", + deps = [":xla_cpu_td_files"], +) + +gentbl_cc_library( + name = "xla_cpu_ops_inc_gen", + strip_include_prefix = ".", + tbl_outs = [ + ( + ["-gen-op-decls"], + "xla_cpu_ops.h.inc", + ), + ( + ["-gen-op-defs"], + "xla_cpu_ops.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "xla_cpu_ops.td", + deps = [":xla_cpu_td_files"], +) + +cc_library( + name = "xla_cpu", + srcs = [ + "xla_cpu_dialect.cc", + "xla_cpu_ops.cc", + "xla_cpu_types.cc", + ], + hdrs = [ + "xla_cpu_dialect.h", + "xla_cpu_ops.h", + "xla_cpu_types.h", + ], + deps = [ + ":xla_cpu_dialect_inc_gen", + ":xla_cpu_ops_inc_gen", + ":xla_cpu_types_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeOpInterface", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) diff --git a/third_party/xla/xla/backends/cpu/codegen/ir/tests/BUILD b/third_party/xla/xla/backends/cpu/codegen/ir/tests/BUILD new file mode 100644 index 00000000000000..c33df92b01cc32 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/ir/tests/BUILD @@ -0,0 +1,16 @@ +load("//xla:lit.bzl", "lit_test_suite") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +lit_test_suite( + name = "tests", + srcs = glob(["*.mlir"]), + cfg = "//xla:lit.cfg.py", + tools = [ + "//xla/backends/cpu/codegen/tools:xla_cpu_opt", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/third_party/xla/xla/backends/cpu/codegen/ir/tests/ops.mlir b/third_party/xla/xla/backends/cpu/codegen/ir/tests/ops.mlir new file mode 100644 index 00000000000000..0e7faa0a235242 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/ir/tests/ops.mlir @@ -0,0 +1,28 @@ +// RUN: xla_cpu_opt %s --split-input-file | FileCheck %s + +func.func @load(%arg0: !xla_cpu.call_frame) -> tensor<32x32xf32> { + %0 = xla_cpu.load %arg0, 0 : tensor<32x32xf32> + return %0 : tensor<32x32xf32> +} + +// CHECK-LABEL: @load( +// CHECK: %[[ARG0:.+]]: !xla_cpu.call_frame +// CHECK: ) -> tensor<32x32xf32> { +// CHECK: %[[LOAD:.+]] = xla_cpu.load %[[ARG0]], 0 : tensor<32x32xf32> +// CHECK: return %[[LOAD]] : tensor<32x32xf32> +// CHECK: } + +// ----- + +func.func @store(%arg0: !xla_cpu.call_frame, %arg1: tensor<32x32xf32>) { + xla_cpu.store %arg1 into %arg0, 0 : tensor<32x32xf32> + return +} + +// CHECK-LABEL: @store( +// CHECK: %[[ARG0:.+]]: !xla_cpu.call_frame, +// CHECK: %[[ARG1:.+]]: tensor<32x32xf32> +// CHECK: ) { +// CHECK: xla_cpu.store %[[ARG1]] into %[[ARG0]], 0 : tensor<32x32xf32> +// CHECK: return +// CHECK: } diff --git a/third_party/xla/xla/backends/cpu/codegen/ir/tests/types.mlir b/third_party/xla/xla/backends/cpu/codegen/ir/tests/types.mlir new file mode 100644 index 00000000000000..504dfa29976c33 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/ir/tests/types.mlir @@ -0,0 +1,9 @@ +// RUN: xla_cpu_opt %s | FileCheck %s + +func.func @call_frame_arg(%arg0: !xla_cpu.call_frame) { + return +} + +// CHECK-LABEL: @call_frame_arg( +// CHECK-SAME: %[[ARG0:.+]]: !xla_cpu.call_frame +// CHECK-SAME: ) diff --git a/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_dialect.cc b/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_dialect.cc new file mode 100644 index 00000000000000..c8e6c2efc27496 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_dialect.cc @@ -0,0 +1,37 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/codegen/ir/xla_cpu_dialect.h" + +#include "xla/backends/cpu/codegen/ir/xla_cpu_ops.h" // IWYU pragma: keep +#include "xla/backends/cpu/codegen/ir/xla_cpu_types.h" // IWYU pragma: keep + +// Include the auto-generated implementation file. +#include "xla/backends/cpu/codegen/ir/xla_cpu_dialect.cc.inc" + +namespace xla::cpu { + +void XlaCpuDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "xla/backends/cpu/codegen/ir/xla_cpu_ops.cc.inc" + >(); + addTypes< +#define GET_TYPEDEF_LIST +#include "xla/backends/cpu/codegen/ir/xla_cpu_types.cc.inc" + >(); +} + +} // namespace xla::cpu diff --git a/tensorflow/cc/experimental/libtf/impl/scalars_test.cc b/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_dialect.h similarity index 61% rename from tensorflow/cc/experimental/libtf/impl/scalars_test.cc rename to third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_dialect.h index 79c73f194426d5..a9a8edbc7434fd 100644 --- a/tensorflow/cc/experimental/libtf/impl/scalars_test.cc +++ b/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_dialect.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,18 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/impl/scalars.h" -#include "tensorflow/core/platform/test.h" +#ifndef XLA_BACKENDS_CPU_CODEGEN_IR_XLA_CPU_DIALECT_H_ +#define XLA_BACKENDS_CPU_CODEGEN_IR_XLA_CPU_DIALECT_H_ -namespace tf { -namespace libtf { -namespace impl { +#include "mlir/IR/Dialect.h" // IWYU pragma: keep -TEST(ScalarsTest, TestHeterogeneousAddition) { - ASSERT_EQ((Int64(1) + Float32(0.375)).get(), 1.375); -} +// Include the auto-generated header file. +#include "xla/backends/cpu/codegen/ir/xla_cpu_dialect.h.inc" -} // namespace impl -} // namespace libtf -} // namespace tf +#endif // XLA_BACKENDS_CPU_CODEGEN_IR_XLA_CPU_DIALECT_H_ diff --git a/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_dialect.td b/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_dialect.td new file mode 100644 index 00000000000000..72904d8d084806 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_dialect.td @@ -0,0 +1,33 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_CODEGEN_IR_XLA_CPU_DIALECT +#define XLA_BACKENDS_CPU_CODEGEN_IR_XLA_CPU_DIALECT + +include "mlir/IR/DialectBase.td" + +def XlaCpuDialect : Dialect { + let name = "xla_cpu"; + + let description = [{ + This dialect contains ops required for lowering HLO to LLVM for XLA:CPU + backend and runtime. + }]; + + let cppNamespace = "::xla::cpu"; + let useDefaultTypePrinterParser = 1; +} + +#endif // XLA_BACKENDS_CPU_CODEGEN_IR_XLA_CPU_DIALECT diff --git a/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_ops.cc b/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_ops.cc new file mode 100644 index 00000000000000..882c38eeee7eaf --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_ops.cc @@ -0,0 +1,21 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/codegen/ir/xla_cpu_ops.h" + +#include "mlir/IR/Builders.h" // IWYU pragma: keep + +#define GET_OP_CLASSES +#include "xla/backends/cpu/codegen/ir/xla_cpu_ops.cc.inc" diff --git a/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_ops.h b/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_ops.h new file mode 100644 index 00000000000000..5607bfabce9ede --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_ops.h @@ -0,0 +1,30 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_CODEGEN_IR_XLA_CPU_OPS_H_ +#define XLA_BACKENDS_CPU_CODEGEN_IR_XLA_CPU_OPS_H_ + +#include "mlir/Bytecode/BytecodeOpInterface.h" // IWYU pragma: keep +#include "mlir/IR/Attributes.h" // IWYU pragma: keep +#include "mlir/IR/BuiltinAttributes.h" // IWYU pragma: keep +#include "mlir/IR/BuiltinTypes.h" // IWYU pragma: keep +#include "mlir/IR/OpDefinition.h" // IWYU pragma: keep +#include "mlir/IR/OpImplementation.h" // IWYU pragma: keep +#include "xla/backends/cpu/codegen/ir/xla_cpu_types.h" // IWYU pragma: keep + +#define GET_OP_CLASSES +#include "xla/backends/cpu/codegen/ir/xla_cpu_ops.h.inc" + +#endif // XLA_BACKENDS_CPU_CODEGEN_IR_XLA_CPU_OPS_H_ diff --git a/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_ops.td b/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_ops.td new file mode 100644 index 00000000000000..7ea511f75c7d56 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_ops.td @@ -0,0 +1,78 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_CODEGEN_IR_XLA_CPU_OPS +#define XLA_BACKENDS_CPU_CODEGEN_IR_XLA_CPU_OPS + +include "mlir/IR/OpBase.td" +include "xla/backends/cpu/codegen/ir/xla_cpu_dialect.td" +include "xla/backends/cpu/codegen/ir/xla_cpu_types.td" + +class XLACPU_Op traits = []> : + Op { +} + +//===----------------------------------------------------------------------===// +// !xla_cpu.load +//===----------------------------------------------------------------------===// + +def XLACPU_LoadOp : XLACPU_Op<"load"> { + let summary = "Loads a tensor from an XLA:CPU call frame"; + + let description = [{ + Loads a tensor from an XLA:CPU call frame at the given index. + + ```mlir + %0 = xla_cpu.load %call_frame, 0 : tensor<32x32xf32> + ``` + }]; + + let arguments = (ins XLACPU_CallFrame:$call_frame, + I32Attr:$index); + + let results = (outs AnyStaticShapeTensor:$result); + + let assemblyFormat = [{ + $call_frame `,` $index attr-dict `:` type($result) + }]; +} + +//===----------------------------------------------------------------------===// +// !xla_cpu.store +//===----------------------------------------------------------------------===// + +def XLACPU_StoreOp : XLACPU_Op<"store"> { + let summary = "Stores a tensor into an XLA:CPU call frame"; + + let description = [{ + Stores a tensor into an XLA:CPU call frame at the given index. + + ```mlir + %0 = ... : tensor<32x32xf32> + xla_cpu.store %0 into %call_frame, 0 : tensor<32x32xf32> + ``` + }]; + + let arguments = (ins AnyStaticShapeTensor:$tensor, + XLACPU_CallFrame:$call_frame, + I32Attr:$index); + + let assemblyFormat = [{ + $tensor `into` $call_frame `,` $index attr-dict `:` type($tensor) + }]; + +} + +#endif // XLA_BACKENDS_CPU_CODEGEN_IR_XLA_CPU_OPS diff --git a/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_types.cc b/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_types.cc new file mode 100644 index 00000000000000..ce94cb7082ce2c --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_types.cc @@ -0,0 +1,23 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/codegen/ir/xla_cpu_types.h" + +#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep +#include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep +#include "xla/backends/cpu/codegen/ir/xla_cpu_dialect.h" // IWYU pragma: keep + +#define GET_TYPEDEF_CLASSES +#include "xla/backends/cpu/codegen/ir/xla_cpu_types.cc.inc" diff --git a/tensorflow/cc/experimental/libtf/object.cc b/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_types.h similarity index 63% rename from tensorflow/cc/experimental/libtf/object.cc rename to third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_types.h index a5f4882d532fce..f59ba42c2b73fb 100644 --- a/tensorflow/cc/experimental/libtf/object.cc +++ b/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_types.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,18 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Implementation of objects. -#include "tensorflow/cc/experimental/libtf/object.h" -#include +#ifndef XLA_BACKENDS_CPU_CODEGEN_IR_XLA_CPU_TYPES_H_ +#define XLA_BACKENDS_CPU_CODEGEN_IR_XLA_CPU_TYPES_H_ -namespace tf { -namespace libtf { +#include "mlir/IR/BuiltinTypes.h" // IWYU pragma: keep -const String& Object::ParentKey() { - static const String* key = new String("__parent__"); - return *key; -} +#define GET_TYPEDEF_CLASSES +#include "xla/backends/cpu/codegen/ir/xla_cpu_types.h.inc" -} // namespace libtf -} // namespace tf +#endif // XLA_BACKENDS_CPU_CODEGEN_IR_XLA_CPU_TYPES_H_ diff --git a/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_types.td b/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_types.td new file mode 100644 index 00000000000000..9b033099017d3e --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/ir/xla_cpu_types.td @@ -0,0 +1,36 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_CODEGEN_IR_XLA_CPU_TYPES +#define XLA_BACKENDS_CPU_CODEGEN_IR_XLA_CPU_TYPES + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypes.td" +include "xla/backends/cpu/codegen/ir/xla_cpu_dialect.td" + +class XLACPU_Type traits = []> + : TypeDef { + let mnemonic = typeMnemonic; +} + +//===----------------------------------------------------------------------===// +// !xla_cpu.call_frame +//===----------------------------------------------------------------------===// + +def XLACPU_CallFrame : XLACPU_Type<"CallFrame", "call_frame"> { + let summary = "XLA:CPU host kernel call frame"; +} + +#endif // XLA_BACKENDS_CPU_CODEGEN_IR_XLA_CPU_TYPES diff --git a/third_party/xla/xla/backends/cpu/codegen/tools/BUILD b/third_party/xla/xla/backends/cpu/codegen/tools/BUILD new file mode 100644 index 00000000000000..cfc8a5a33f6b41 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/tools/BUILD @@ -0,0 +1,22 @@ +load("//xla:xla.bzl", "xla_cc_binary") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +xla_cc_binary( + name = "xla_cpu_opt", + srcs = ["xla_cpu_opt.cc"], + visibility = ["//xla/backends/cpu/codegen:__subpackages__"], + deps = [ + "//xla/backends/cpu/codegen/ir:xla_cpu", + "//xla/backends/cpu/codegen/transforms:passes", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +) diff --git a/third_party/xla/xla/backends/cpu/codegen/tools/xla_cpu_opt.cc b/third_party/xla/xla/backends/cpu/codegen/tools/xla_cpu_opt.cc new file mode 100644 index 00000000000000..109b4d5489526f --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/tools/xla_cpu_opt.cc @@ -0,0 +1,39 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "mlir/Transforms/Passes.h" +#include "xla/backends/cpu/codegen/ir/xla_cpu_dialect.h" +#include "xla/backends/cpu/codegen/transforms/passes.h" + +int main(int argc, char** argv) { + mlir::DialectRegistry registry; + registry.insert(); + + // Register builtin MLIR passes. + mlir::func::registerAllExtensions(registry); + mlir::registerCanonicalizerPass(); + mlir::registerCSEPass(); + + // Register XLA:CPU passes. + xla::cpu::registerXlaCpuTransformsPasses(); + + return mlir::failed( + MlirOptMain(argc, argv, "XLA:CPU Pass Driver\n", registry)); +} diff --git a/third_party/xla/xla/backends/cpu/codegen/transforms/BUILD b/third_party/xla/xla/backends/cpu/codegen/transforms/BUILD new file mode 100644 index 00000000000000..07519d2c9d8dae --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/transforms/BUILD @@ -0,0 +1,64 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +gentbl_cc_library( + name = "passes_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=XlaCpuTransforms", + ], + "passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes.td", + visibility = ["//visibility:private"], + deps = ["@llvm-project//mlir:PassBaseTdFiles"], +) + +cc_library( + name = "passes", + srcs = ["lower_trivial.cc"], + hdrs = ["passes.h"], + deps = [ + ":passes_inc_gen", + ":xla_cpu_rewrite_patterns", + "//xla/backends/cpu/codegen/ir:xla_cpu", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + ], +) + +cc_library( + name = "xla_cpu_rewrite_patterns", + srcs = ["xla_cpu_rewrite_patterns.cc"], + hdrs = ["xla_cpu_rewrite_patterns.h"], + deps = [ + "//xla/backends/cpu/codegen/ir:xla_cpu", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) diff --git a/third_party/xla/xla/backends/cpu/codegen/transforms/lower_trivial.cc b/third_party/xla/xla/backends/cpu/codegen/transforms/lower_trivial.cc new file mode 100644 index 00000000000000..ed2ad8e2bf3f42 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/transforms/lower_trivial.cc @@ -0,0 +1,75 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // IWYU pragma: keep +#include "mlir/Dialect/Tensor/IR/Tensor.h" // IWYU pragma: keep +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "xla/backends/cpu/codegen/ir/xla_cpu_dialect.h" // IWYU pragma: keep +#include "xla/backends/cpu/codegen/transforms/xla_cpu_rewrite_patterns.h" + +namespace xla::cpu { + +#define GEN_PASS_DECL_LOWERTRIVIALPASS +#define GEN_PASS_DEF_LOWERTRIVIALPASS +#include "xla/backends/cpu/codegen/transforms/passes.h.inc" + +namespace { +class LowerTrivialPass : public impl::LowerTrivialPassBase { + void runOnOperation() override { + mlir::TypeConverter converter; + mlir::ConversionTarget target(getContext()); + + converter.addConversion([](mlir::Type type) { return type; }); + PopulateXlaCpuTypeConversionAndLegality(converter, target); + + mlir::RewritePatternSet patterns(&getContext()); + PopulateXlaCpuConversionPatterns(patterns); + + // Add conversion patterns for function signatures. + mlir::populateFunctionOpInterfaceTypeConversionPattern( + patterns, converter); + + // Set up basic legality constraints. + target.addLegalOp(); + target.addLegalDialect(); + target.addLegalDialect(); + + // Add dynamic legality constraints to apply conversions defined above. + target.addDynamicallyLegalOp( + [&](mlir::func::FuncOp op) { + return converter.isSignatureLegal(op.getFunctionType()); + }); + + if (mlir::failed(mlir::applyFullConversion(getOperation(), target, + std::move(patterns)))) { + signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr CreateLowerTrivialPass() { + return std::make_unique(); +} + +} // namespace xla::cpu diff --git a/tensorflow/cc/experimental/libtf/runtime/core/core.h b/third_party/xla/xla/backends/cpu/codegen/transforms/passes.h similarity index 53% rename from tensorflow/cc/experimental/libtf/runtime/core/core.h rename to third_party/xla/xla/backends/cpu/codegen/transforms/passes.h index 12ced72eccb79b..fdd0c4e58fb2fa 100644 --- a/tensorflow/cc/experimental/libtf/runtime/core/core.h +++ b/third_party/xla/xla/backends/cpu/codegen/transforms/passes.h @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,22 +12,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_RUNTIME_CORE_CORE_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_RUNTIME_CORE_CORE_H_ -#include "tensorflow/cc/experimental/libtf/runtime/runtime.h" +#ifndef XLA_BACKENDS_CPU_CODEGEN_TRANSFORMS_PASSES_H_ +#define XLA_BACKENDS_CPU_CODEGEN_TRANSFORMS_PASSES_H_ -namespace tf { -namespace libtf { -namespace runtime { -namespace core { +#include -// Instantiate a Core Runtime. -Runtime Runtime(); +#include "mlir/Pass/Pass.h" -} // namespace core -} // namespace runtime -} // namespace libtf -} // namespace tf +namespace xla::cpu { -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_RUNTIME_CORE_CORE_H_ +#define GEN_PASS_DECL +#include "xla/backends/cpu/codegen/transforms/passes.h.inc" + +std::unique_ptr CreateLowerTrivialPass(); + +#define GEN_PASS_REGISTRATION +#include "xla/backends/cpu/codegen/transforms/passes.h.inc" + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_CODEGEN_TRANSFORMS_PASSES_H_ diff --git a/tensorflow/cc/experimental/libtf/mlir/mlir_transform.h b/third_party/xla/xla/backends/cpu/codegen/transforms/passes.td similarity index 51% rename from tensorflow/cc/experimental/libtf/mlir/mlir_transform.h rename to third_party/xla/xla/backends/cpu/codegen/transforms/passes.td index bd5ec58c29fa56..f15b3e5abffded 100644 --- a/tensorflow/cc/experimental/libtf/mlir/mlir_transform.h +++ b/third_party/xla/xla/backends/cpu/codegen/transforms/passes.td @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,19 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CC_EXPERIMENTAL_LIBTF_MLIR_MLIR_TRANSFORM_H_ -#define TENSORFLOW_CC_EXPERIMENTAL_LIBTF_MLIR_MLIR_TRANSFORM_H_ -#include "tensorflow/cc/experimental/libtf/object.h" +include "mlir/Pass/PassBase.td" -namespace tf { -namespace libtf { +def LowerTrivialPass : Pass<"xla-cpu-lower-trivial", "mlir::ModuleOp"> { + let summary = "Trivial one shot lowering from tensors + xla_cpu to LLVM"; -// Returns a MLIR object with methods that can be used to load/save saved -// models, and also do transformations. -Object MLIR(); + let description = [{ + This is a trivial pass for lowering tensors and xla_cpu operations to LLVM + dialect and pointers. + }]; -} // namespace libtf -} // namespace tf + let dependentDialects = [ + "mlir::LLVM::LLVMDialect", + "mlir::func::FuncDialect", + "mlir::tensor::TensorDialect", + "xla::cpu::XlaCpuDialect", + ]; -#endif // TENSORFLOW_CC_EXPERIMENTAL_LIBTF_MLIR_MLIR_TRANSFORM_H_ + let constructor = "::xla::cpu::CreateLowerTrivialPass()"; +} \ No newline at end of file diff --git a/third_party/xla/xla/backends/cpu/codegen/transforms/tests/BUILD b/third_party/xla/xla/backends/cpu/codegen/transforms/tests/BUILD new file mode 100644 index 00000000000000..c33df92b01cc32 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/transforms/tests/BUILD @@ -0,0 +1,16 @@ +load("//xla:lit.bzl", "lit_test_suite") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +lit_test_suite( + name = "tests", + srcs = glob(["*.mlir"]), + cfg = "//xla:lit.cfg.py", + tools = [ + "//xla/backends/cpu/codegen/tools:xla_cpu_opt", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/third_party/xla/xla/backends/cpu/codegen/transforms/tests/lower_trivial.mlir b/third_party/xla/xla/backends/cpu/codegen/transforms/tests/lower_trivial.mlir new file mode 100644 index 00000000000000..a7d4a117f0f005 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/transforms/tests/lower_trivial.mlir @@ -0,0 +1,16 @@ +// RUN: xla_cpu_opt %s --xla-cpu-lower-trivial | FileCheck %s + +func.func @call_frame_arg(%arg0: !xla_cpu.call_frame) { + %0 = xla_cpu.load %arg0, 0 : tensor<32x32xf32> + return +} + +// CHECK-LABEL: @call_frame_arg( +// CHECK: %[[ARG0:.+]]: !llvm.ptr +// CHECK: ) { +// CHECK: %[[ARGS_GEP:.+]] = llvm.getelementptr %[[ARG0]][3] +// CHECK: %[[ARGS:.+]] = llvm.load %[[ARGS_GEP]] +// CHECK: %[[ARG_GEP:.+]] = llvm.getelementptr %[[ARGS]][0] +// CHECK: %[[ARG:.+]] = llvm.load %[[ARG_GEP]] +// CHECK: return +// CHECK: } diff --git a/third_party/xla/xla/backends/cpu/codegen/transforms/xla_cpu_rewrite_patterns.cc b/third_party/xla/xla/backends/cpu/codegen/transforms/xla_cpu_rewrite_patterns.cc new file mode 100644 index 00000000000000..93991881a592f3 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/transforms/xla_cpu_rewrite_patterns.cc @@ -0,0 +1,101 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/codegen/transforms/xla_cpu_rewrite_patterns.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "xla/backends/cpu/codegen/ir/xla_cpu_dialect.h" +#include "xla/backends/cpu/codegen/ir/xla_cpu_ops.h" +#include "xla/backends/cpu/codegen/ir/xla_cpu_types.h" + +namespace xla::cpu { + +void PopulateXlaCpuTypeConversionAndLegality(mlir::TypeConverter& converter, + mlir::ConversionTarget& target) { + converter.addConversion([](CallFrameType call_frame) { + return mlir::LLVM::LLVMPointerType::get(call_frame.getContext()); + }); + + target.addIllegalDialect(); +} + +namespace { +struct LowerLoadOp : public mlir::OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult matchAndRewrite( + LoadOp op, LoadOp::Adaptor adaptor, + mlir::ConversionPatternRewriter& rewriter) const override; +}; +} // namespace + +// LLVM structs corresponds to `SE_HOST_KernelCallFrame` struct that defines +// XLA:CPU host kernel ABI contract. + +static mlir::LLVM::LLVMStructType KernelDim3Type(mlir::MLIRContext* ctx) { + auto i64 = mlir::IntegerType::get(ctx, 64); + return mlir::LLVM::LLVMStructType::getNewIdentified(ctx, "kernel_dim3", + {i64, i64, i64}); +} + +static mlir::LLVM::LLVMStructType KernelArgType(mlir::MLIRContext* ctx) { + auto ptr = mlir::LLVM::LLVMPointerType::get(ctx); + auto i64 = mlir::IntegerType::get(ctx, 64); + return mlir::LLVM::LLVMStructType::getNewIdentified(ctx, "kernel_arg", + {ptr, i64}); +} + +static mlir::LLVM::LLVMStructType KernelCallFrameType(mlir::MLIRContext* ctx) { + auto ptr = mlir::LLVM::LLVMPointerType::get(ctx); + auto i64 = mlir::IntegerType::get(ctx, 64); + return mlir::LLVM::LLVMStructType::getNewIdentified(ctx, "kernel_call_frame", + {ptr, ptr, i64, ptr}); +} + +mlir::LogicalResult LowerLoadOp::matchAndRewrite( + LoadOp op, LoadOp::Adaptor adaptor, + mlir::ConversionPatternRewriter& rewriter) const { + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + auto ptr = b.getType(); + auto kernel_call_frame = KernelCallFrameType(b.getContext()); + auto kernel_arg = KernelArgType(b.getContext()); + + // Get a pointer to the first `KernelArg` struct. + auto args_gep = b.create( + ptr, kernel_call_frame, adaptor.getCallFrame(), mlir::LLVM::GEPArg(3)); + auto args_ptr = b.create(ptr, args_gep); + + // Get a pointer to the `KernelArg` at the given index. + auto arg_gep = b.create(ptr, kernel_arg, args_ptr, + mlir::LLVM::GEPArg(op.getIndex())); + auto arg_ptr = b.create(ptr, arg_gep); + + rewriter.replaceOp(op, arg_ptr); + return mlir::success(); +} + +void PopulateXlaCpuConversionPatterns(mlir::RewritePatternSet& patterns) { + patterns.add(patterns.getContext()); +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/codegen/transforms/xla_cpu_rewrite_patterns.h b/third_party/xla/xla/backends/cpu/codegen/transforms/xla_cpu_rewrite_patterns.h new file mode 100644 index 00000000000000..25e3fcd4850b4e --- /dev/null +++ b/third_party/xla/xla/backends/cpu/codegen/transforms/xla_cpu_rewrite_patterns.h @@ -0,0 +1,34 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_CODEGEN_TRANSFORMS_XLA_CPU_REWRITE_PATTERNS_H_ +#define XLA_BACKENDS_CPU_CODEGEN_TRANSFORMS_XLA_CPU_REWRITE_PATTERNS_H_ + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace xla::cpu { + +// Populates type conversion and legality constraints for lowering XLA:CPU +// types to LLVM types. +void PopulateXlaCpuTypeConversionAndLegality(mlir::TypeConverter& converter, + mlir::ConversionTarget& target); + +// Populates rewrite patterns for converting XLA:CPU ops to LLVM ops. +void PopulateXlaCpuConversionPatterns(mlir::RewritePatternSet& patterns); + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_CODEGEN_TRANSFORMS_XLA_CPU_REWRITE_PATTERNS_H_ diff --git a/third_party/xla/xla/backends/cpu/runtime/BUILD b/third_party/xla/xla/backends/cpu/runtime/BUILD index 8e60bd7b30582a..37b746abac1060 100644 --- a/third_party/xla/xla/backends/cpu/runtime/BUILD +++ b/third_party/xla/xla/backends/cpu/runtime/BUILD @@ -42,12 +42,10 @@ cc_library( "//xla:util", "//xla/service:buffer_assignment", "//xla/service:maybe_owning_device_memory", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:statusor", ], ) @@ -67,6 +65,15 @@ xla_cc_test( ], ) +cc_library( + name = "thread_pool_task_runner", + hdrs = ["thread_pool_task_runner.h"], + deps = [ + ":thunk", + "@eigen_archive//:eigen3", + ], +) + cc_library( name = "thunk", srcs = ["thunk.cc"], @@ -83,12 +90,12 @@ cc_library( "//xla/service/cpu:cpu_executable_run_options", "//xla/service/cpu:cpu_runtime", "//xla/service/cpu:in_process_collectives", - "//xla/stream_executor", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/host:host_kernel_c_api", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", @@ -159,6 +166,7 @@ xla_cc_test( deps = [ ":buffer_allocations", ":resource_use", + ":thread_pool_task_runner", ":thunk", ":thunk_executor", "//xla/runtime:buffer_use", @@ -203,22 +211,18 @@ cc_library( srcs = ["conditional_thunk.cc"], hdrs = ["conditional_thunk.h"], deps = [ - ":resource_use", ":thunk", ":thunk_executor", "//xla:util", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -235,7 +239,6 @@ xla_cc_test( "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:maybe_owning_device_memory", - "//xla/stream_executor", "//xla/tsl/concurrency:async_value", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", @@ -260,7 +263,6 @@ cc_library( "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", "//xla/service/cpu:collectives_interface", - "//xla/stream_executor", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", @@ -387,7 +389,6 @@ cc_library( "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", "//xla/service/cpu:collectives_interface", - "//xla/stream_executor", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", @@ -419,7 +420,6 @@ cc_library( "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", "//xla/service/cpu:collectives_interface", - "//xla/stream_executor", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", @@ -451,7 +451,6 @@ cc_library( "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", "//xla/service/cpu:collectives_interface", - "//xla/stream_executor", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", @@ -484,7 +483,6 @@ cc_library( "//xla/service:collective_ops_utils", "//xla/service:computation_placer", "//xla/service/cpu:collectives_interface", - "//xla/stream_executor", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", @@ -548,7 +546,7 @@ cc_library( "//xla/pjrt:transpose", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -573,9 +571,8 @@ xla_cc_test( "//xla:shape_util", "//xla/service:buffer_assignment", "//xla/service:maybe_owning_device_memory", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", - "//xla/tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -600,7 +597,7 @@ cc_library( "//xla/service:custom_call_status", "//xla/service:custom_call_status_internal", "//xla/service:custom_call_target_registry", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:dynamic_annotations", @@ -639,10 +636,9 @@ cc_library( "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/pjrt:transpose", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", "//xla/tsl/framework/contraction:eigen_contraction_kernel", "@com_google_absl//absl/algorithm:container", @@ -650,14 +646,12 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:numbers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:traceme", ], @@ -675,11 +669,9 @@ cc_library( "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service/cpu:cpu_runtime", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", - "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -702,7 +694,6 @@ xla_cc_test( "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:maybe_owning_device_memory", - "//xla/stream_executor", "//xla/tsl/concurrency:async_value", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", @@ -723,7 +714,7 @@ cc_library( "//xla/service:buffer_assignment", "//xla/service:computation_placer_hdr", "//xla/service:global_device_id", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", @@ -744,7 +735,7 @@ xla_cc_test( "//xla:executable_run_options", "//xla/service:buffer_assignment", "//xla/service:maybe_owning_device_memory", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -766,11 +757,9 @@ cc_library( "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service/cpu:cpu_runtime", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", - "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -793,7 +782,6 @@ xla_cc_test( "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:maybe_owning_device_memory", - "//xla/stream_executor", "//xla/tsl/concurrency:async_value", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", @@ -813,7 +801,8 @@ cc_library( "//xla:util", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:launch_dim", "//xla/stream_executor/host:host_kernel", "//xla/stream_executor/host:host_kernel_c_api", "//xla/tsl/concurrency:async_value", @@ -831,7 +820,6 @@ cc_library( "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:numbers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:traceme", ], @@ -846,10 +834,10 @@ xla_cc_test( ":thunk", "//xla/service:buffer_assignment", "//xla/service:maybe_owning_device_memory", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:launch_dim", "//xla/stream_executor/host:host_kernel_c_api", "//xla/tsl/concurrency:async_value", - "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -891,7 +879,7 @@ cc_library( "//xla:util", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/base:config", "@com_google_absl//absl/base:core_headers", @@ -917,10 +905,9 @@ cc_library( "//xla:util", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base:config", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", "@com_google_absl//absl/container:inlined_vector", @@ -949,15 +936,13 @@ xla_cc_test( "//xla:shape_util", "//xla/service:buffer_assignment", "//xla/service:maybe_owning_device_memory", - "//xla/stream_executor", - "//xla/stream_executor/host:host_kernel_c_api", + "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", - "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", ], ) @@ -972,15 +957,13 @@ cc_library( ":thunk_executor", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:traceme", @@ -996,14 +979,11 @@ xla_cc_test( ":thunk", ":thunk_testlib", ":while_thunk", - "//xla:shape_util", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:maybe_owning_device_memory", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/status", "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:statusor", @@ -1024,7 +1004,8 @@ cc_library( "//xla/service:buffer_assignment", "//xla/service/cpu:runtime_fft", "//xla/service/cpu:runtime_single_threaded_fft", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:stream_executor_h", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", diff --git a/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.cc index 8ce3106213b07f..8f693a1e3c5378 100644 --- a/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.cc @@ -60,7 +60,7 @@ limitations under the License. namespace xla::cpu { namespace { -using AttributesMap = ffi::CallFrameBuilder::FlatAttributesMap; +using AttributesMap = ffi::CallFrameBuilder::AttributesMap; absl::StatusOr ParseAttributes( absl::string_view backend_config) { diff --git a/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc index 8d2df6f298cbcf..3f8ac4792a1980 100644 --- a/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc @@ -17,15 +17,18 @@ limitations under the License. #include #include +#include #include #include #include #include #include #include +#include #include #include #include +#include #include "absl/algorithm/container.h" #include "absl/base/dynamic_annotations.h" @@ -86,36 +89,44 @@ static absl::Status VerifySortInputs(absl::Span inputs, absl::StatusOr> SortThunk::Create( Info info, absl::Span inputs, int64_t dimension, - bool is_stable, LessThan less_than) { + bool is_stable, LessThan less_than, + std::optional direction) { TF_RETURN_IF_ERROR(VerifySortInputs(inputs, dimension)); return absl::WrapUnique(new SortThunk(std::move(info), inputs, dimension, - is_stable, std::move(less_than))); + is_stable, std::move(less_than), + direction)); } absl::StatusOr> SortThunk::Create( Info info, absl::Span inputs, int64_t dimension, - bool is_stable, std::string comparator_name) { + bool is_stable, std::string comparator_name, + std::optional direction) { TF_RETURN_IF_ERROR(VerifySortInputs(inputs, dimension)); return absl::WrapUnique(new SortThunk(std::move(info), inputs, dimension, - is_stable, std::move(comparator_name))); + is_stable, std::move(comparator_name), + direction)); } SortThunk::SortThunk(Info info, absl::Span inputs, - int64_t dimension, bool is_stable, LessThan less_than) + int64_t dimension, bool is_stable, LessThan less_than, + std::optional direction) : Thunk(Kind::kSort, std::move(info)), inputs_(inputs.begin(), inputs.end()), dimension_(dimension), is_stable_(is_stable), + direction_(direction), less_than_(std::move(less_than)), less_than_ptr_(&*less_than_) {} SortThunk::SortThunk(Info info, absl::Span inputs, int64_t dimension, bool is_stable, - std::string comparator_name) + std::string comparator_name, + std::optional direction) : Thunk(Kind::kSort, std::move(info)), inputs_(inputs.begin(), inputs.end()), dimension_(dimension), is_stable_(is_stable), + direction_(direction), comparator_name_(std::move(comparator_name)), less_than_ptr_(nullptr) {} @@ -131,6 +142,7 @@ static constexpr size_t kMaxElementSize = 16; // Forward declare reference type defined below. template struct Ref; +struct DRef; // Value type to store values loaded from the input buffers. template @@ -145,6 +157,18 @@ struct Value { std::array value_sizes; }; +struct DValue { + DValue(const DRef& ref); // NOLINT + + const void* compared_value(size_t i) const { return value[i].data(); } + + // Use properly aligned byte array to store primitive values. + using ValueStorage = std::array; + std::vector value; + std::vector value_sizes; + size_t n; +}; + // Reference to values stored in the input buffers. template struct Ref { @@ -160,6 +184,20 @@ struct Ref { std::array ptr_sizes; }; +struct DRef { + DRef(std::vector ptr, std::vector ptr_sizes) + : ptr(ptr), ptr_sizes(ptr_sizes), n(ptr.size()) {} + + DRef& operator=(const DValue& value); + DRef& operator=(const DRef& other); + + const void* compared_value(size_t i) const { return ptr[i]; } + + std::vector ptr; + std::vector ptr_sizes; + const size_t n; +}; + template Value::Value(const Ref& ref) : value_sizes(ref.ptr_sizes) { for (size_t i = 0; i < n; ++i) { @@ -167,6 +205,15 @@ Value::Value(const Ref& ref) : value_sizes(ref.ptr_sizes) { } } +DValue::DValue(const DRef& ref) + : value_sizes(ref.ptr_sizes), n(ref.ptr.size()) { + value.reserve(n); + for (size_t i = 0; i < n; ++i) { + value.emplace_back(); + std::memcpy(value[i].data(), ref.ptr[i], ref.ptr_sizes[i]); + } +} + template Ref& Ref::operator=(const Value& value) { DCHECK(ptr_sizes == value.value_sizes); @@ -176,6 +223,14 @@ Ref& Ref::operator=(const Value& value) { return *this; } +DRef& DRef::operator=(const DValue& value) { + DCHECK(ptr_sizes == value.value_sizes); + for (size_t i = 0; i < n; ++i) { + std::memcpy(ptr[i], value.value[i].data(), value.value_sizes[i]); + } + return *this; +} + template Ref& Ref::operator=(const Ref& other) { DCHECK(ptr_sizes == other.ptr_sizes); @@ -185,6 +240,15 @@ Ref& Ref::operator=(const Ref& other) { return *this; } +DRef& DRef::operator=(const DRef& other) { + DCHECK(ptr_sizes == other.ptr_sizes); + const size_t n = other.ptr.size(); + for (size_t i = 0; i < n; ++i) { + std::memcpy(ptr[i], other.ptr[i], other.ptr_sizes[i]); + } + return *this; +} + // Swap function required by `std::sort` and `std::stable_sort` implementations. template void swap(const Ref& lhs, const Ref& rhs) { @@ -196,6 +260,17 @@ void swap(const Ref& lhs, const Ref& rhs) { } } +void swap(const DRef& lhs, const DRef& rhs) { + DCHECK(lhs.ptr_sizes == rhs.ptr_sizes); + const size_t n = lhs.ptr.size(); + for (size_t i = 0; i < n; ++i) { + std::array tmp; + std::memcpy(tmp.data(), lhs.ptr[i], lhs.ptr_sizes[i]); + std::memcpy(lhs.ptr[i], rhs.ptr[i], rhs.ptr_sizes[i]); + std::memcpy(rhs.ptr[i], tmp.data(), lhs.ptr_sizes[i]); + } +} + // An array of pointers to the input data. template struct Ptr { @@ -250,19 +325,72 @@ struct Ptr { std::array ptr_sizes; // pointers sizes in bytes }; +struct DPtr { + using difference_type = std::ptrdiff_t; + + DPtr() = default; + + DPtr(std::vector ptr, std::vector ptr_sizes) + : ptr(ptr), ptr_sizes(ptr_sizes), n(ptr.size()) {} + + DRef operator*() const { return DRef{ptr, ptr_sizes}; } + + DPtr& operator+=(difference_type diff) { + for (size_t i = 0; i < n; ++i) ptr[i] += diff * ptr_sizes[i]; + return *this; + } + + DPtr& operator-=(difference_type diff) { + for (size_t i = 0; i < n; ++i) ptr[i] -= diff * ptr_sizes[i]; + return *this; + } + + DPtr operator+(difference_type diff) const { + std::vector upd(n); + for (size_t i = 0; i < n; ++i) upd[i] = ptr[i] + diff * ptr_sizes[i]; + return DPtr{upd, ptr_sizes}; + } + + DPtr operator-(difference_type diff) const { + std::vector upd(n); + for (size_t i = 0; i < n; ++i) upd[i] = ptr[i] - diff * ptr_sizes[i]; + return DPtr{upd, ptr_sizes}; + } + + // In all comparison operators defined below we use only the ptr at index 0, + // because we know that all pointers change together and this is an + // implementation detail of sort iterator. + + difference_type operator-(const DPtr& rhs) const { + DCHECK(ptr_sizes == rhs.ptr_sizes); + return (ptr[0] - rhs.ptr[0]) / ptr_sizes[0]; + } + + bool operator==(const DPtr& rhs) const { return ptr[0] == rhs.ptr[0]; } + bool operator!=(const DPtr& rhs) const { return ptr[0] != rhs.ptr[0]; } + bool operator>(const DPtr& rhs) const { return ptr[0] > rhs.ptr[0]; } + bool operator<(const DPtr& rhs) const { return ptr[0] < rhs.ptr[0]; } + bool operator>=(const DPtr& rhs) const { return ptr[0] >= rhs.ptr[0]; } + bool operator<=(const DPtr& rhs) const { return ptr[0] <= rhs.ptr[0]; } + + std::vector ptr; // pointers into the input buffers + std::vector ptr_sizes; // pointers sizes in bytes + size_t n; +}; + // We rely on `std::sort` and `std::stable_sort` to sort the raw data. We sort // multiple input buffers together using the same comparator function, so we // need to provide a custom iterator that can access the data of all input // buffers at the same time and swap elements in them. -template +template class SortIterator { public: using iterator_category = std::random_access_iterator_tag; using difference_type = std::ptrdiff_t; - using value_type = Value; - using reference = Ref; - using pointer = Ptr; + using value_type = Value; + using reference = Ref; + using pointer = Ptr; SortIterator() = default; SortIterator(pointer ptr, difference_type stride) @@ -364,6 +492,36 @@ static SortDims GetSortDims(const Shape& shape, int64_t dimension) { num_iterations}; } +// The most efficient way to sort a single buffer is to use the builtin +// comparator functions. +template +static void Sort1DArrInplace(const SortDims& sort_dims, int64_t offset, + absl::Span data, + bool is_stable, + SortThunk::SortDirection direction) { + using NativeT = typename primitive_util::PrimitiveTypeToNative::type; + DCHECK_EQ(data.size(), 1); + + NativeT* begin = reinterpret_cast(data[0].opaque()) + offset; + + if (direction == SortThunk::SortDirection::kAscending) { + if (is_stable) { + std::stable_sort(begin, begin + sort_dims.sort_dim_size, + std::less()); + } else { + std::sort(begin, begin + sort_dims.sort_dim_size, std::less()); + } + } else { + if (is_stable) { + std::stable_sort(begin, begin + sort_dims.sort_dim_size, + std::greater()); + } else { + std::sort(begin, begin + sort_dims.sort_dim_size, + std::greater()); + } + }; +} + // Sorts `n` buffers in place. template static void SortInplace(const SortDims& sort_dims, int64_t offset, @@ -388,8 +546,40 @@ static void SortInplace(const SortDims& sort_dims, int64_t offset, return (*less_than)(data.data()); }; - SortIterator begin(Ptr(ptr, ptr_sizes), - /*stride=*/sort_dims.inner_dim_size); + SortIterator, Ref, Ptr> begin( + Ptr(ptr, ptr_sizes), + /*stride=*/sort_dims.inner_dim_size); + if (is_stable) { + std::stable_sort(begin, begin + sort_dims.sort_dim_size, compare); + } else { + std::sort(begin, begin + sort_dims.sort_dim_size, compare); + } +} + +static void DSortInplace(const SortDims& sort_dims, int64_t offset, + absl::Span data, + absl::Span shapes, bool is_stable, + SortThunk::LessThan* less_than, size_t n) { + std::vector ptr(n); + std::vector ptr_sizes(n); + + for (size_t i = 0; i < n; ++i) { + std::byte* base = reinterpret_cast(data[i].opaque()); + ptr_sizes[i] = primitive_util::ByteWidth(shapes[i].element_type()); + ptr[i] = base + offset * ptr_sizes[i]; + } + + auto compare = [&](const auto& a, const auto& b) { + std::vector data(2 * n); + for (size_t i = 0, j = 0; i < n; i += 1, j += 2) { + data[j] = a.compared_value(i); + data[j + 1] = b.compared_value(i); + } + return (*less_than)(data.data()); + }; + + SortIterator begin(DPtr(ptr, ptr_sizes), + /*stride=*/sort_dims.inner_dim_size); if (is_stable) { std::stable_sort(begin, begin + sort_dims.sort_dim_size, compare); } else { @@ -398,10 +588,10 @@ static void SortInplace(const SortDims& sort_dims, int64_t offset, } // Sorts `data` of the given `shape` along the `dimension` inplace. -static absl::Status SortInplace(absl::Span data, - absl::Span shapes, - int64_t dimension, bool is_stable, - SortThunk::LessThan* less_than) { +static absl::Status SortInplace( + absl::Span data, absl::Span shapes, + int64_t dimension, bool is_stable, SortThunk::LessThan* less_than, + std::optional direction) { // All inputs have the same dimensions and layout, so we can use the first // shape to get the sort dimensions. SortDims sort_dims = GetSortDims(shapes[0], dimension); @@ -416,12 +606,66 @@ static absl::Status SortInplace(absl::Span data, is_stable, less_than); }; - // TODO(ezhulenev): We can replace statically known number of sorted inputs - // with a dynamic value, however statically known number of inputs allows - // compiler to generate better code. Benchmark if it really matters. + auto dsort = [&](size_t num_inputs) { + DSortInplace(sort_dims, offset, data, shapes, is_stable, less_than, + num_inputs); + }; + + // Sorts array using builtin comparator functor + auto builtin_sort = [&](PrimitiveType type, + SortThunk::SortDirection direction) { + switch (type) { + case S8: + Sort1DArrInplace(sort_dims, offset, data, is_stable, direction); + break; + case S16: + Sort1DArrInplace(sort_dims, offset, data, is_stable, direction); + break; + case S32: + Sort1DArrInplace(sort_dims, offset, data, is_stable, direction); + break; + case S64: + Sort1DArrInplace(sort_dims, offset, data, is_stable, direction); + break; + case U8: + Sort1DArrInplace(sort_dims, offset, data, is_stable, direction); + break; + case U16: + Sort1DArrInplace(sort_dims, offset, data, is_stable, direction); + break; + case U32: + Sort1DArrInplace(sort_dims, offset, data, is_stable, direction); + break; + case U64: + Sort1DArrInplace(sort_dims, offset, data, is_stable, direction); + break; + case F16: + Sort1DArrInplace(sort_dims, offset, data, is_stable, direction); + break; + case F32: + Sort1DArrInplace(sort_dims, offset, data, is_stable, direction); + break; + case F64: + Sort1DArrInplace(sort_dims, offset, data, is_stable, direction); + break; + default: + sort(std::integral_constant{}); + break; + } + }; + + // use "sort" for statically known number of sorted inputs (expected to be + // faster) and "dsort" for dynamically known number of sorted inputs. + // for 100 elements stable sort is 1.5 times faster than stable dsort. + // for 100 elements unstable sort is 2.47 times faster than unstable dsort. switch (data.size()) { case 1: - sort(std::integral_constant{}); + DCHECK_EQ(shapes.size(), 1); + if (direction.has_value()) { + builtin_sort(shapes[0].element_type(), *direction); + } else { + sort(std::integral_constant{}); + } break; case 2: sort(std::integral_constant{}); @@ -468,14 +712,36 @@ static absl::Status SortInplace(absl::Span data, case 16: sort(std::integral_constant{}); break; + case 17: + sort(std::integral_constant{}); + break; + case 18: + sort(std::integral_constant{}); + break; + case 19: + sort(std::integral_constant{}); + break; + case 20: + sort(std::integral_constant{}); + break; + case 21: + sort(std::integral_constant{}); + break; + case 22: + sort(std::integral_constant{}); + break; + case 23: + sort(std::integral_constant{}); + break; + case 24: + sort(std::integral_constant{}); + break; case 25: sort(std::integral_constant{}); break; - case 29: - sort(std::integral_constant{}); - break; default: - return Internal("Unsupported number of sorted inputs: %d", data.size()); + dsort(data.size()); + break; } } @@ -533,7 +799,7 @@ tsl::AsyncValueRef SortThunk::Execute( } TF_RETURN_IF_ERROR(SortInplace(absl::MakeSpan(data), shapes, dimension_, - is_stable_, less_than)); + is_stable_, less_than, direction_)); return OkExecuteEvent(); } diff --git a/third_party/xla/xla/backends/cpu/runtime/sort_thunk.h b/third_party/xla/xla/backends/cpu/runtime/sort_thunk.h index a1c2b5eda242ee..4d4942ba0c24c4 100644 --- a/third_party/xla/xla/backends/cpu/runtime/sort_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/sort_thunk.h @@ -41,6 +41,11 @@ class SortThunk final : public Thunk { public: using LessThan = absl::AnyInvocable; + enum class SortDirection { + kAscending, + kDescending, + }; + struct Input { BufferAllocation::Slice slice; Shape shape; @@ -48,11 +53,13 @@ class SortThunk final : public Thunk { static absl::StatusOr> Create( Info info, absl::Span inputs, int64_t dimension, - bool is_stable, LessThan less_than); + bool is_stable, LessThan less_than, + std::optional direction = std::nullopt); static absl::StatusOr> Create( Info info, absl::Span inputs, int64_t dimension, - bool is_stable, std::string comparator_name); + bool is_stable, std::string comparator_name, + std::optional direction = std::nullopt); tsl::AsyncValueRef Execute(const ExecuteParams& params) final; @@ -60,14 +67,17 @@ class SortThunk final : public Thunk { private: SortThunk(Info info, absl::Span inputs, int64_t dimension, - bool is_stable, LessThan less_than); + bool is_stable, LessThan less_than, + std::optional direction); SortThunk(Info info, absl::Span inputs, int64_t dimension, - bool is_stable, std::string comparator_name); + bool is_stable, std::string comparator_name, + std::optional direction); std::vector inputs_; int64_t dimension_; bool is_stable_; + std::optional direction_; // Name of the comparator function, lazily resolved to a comparator function // pointer using Thunk::FunctionRegistry. diff --git a/third_party/xla/xla/backends/cpu/runtime/sort_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/sort_thunk_test.cc index 1f450f77548d70..6d8012d1157a7b 100644 --- a/third_party/xla/xla/backends/cpu/runtime/sort_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/sort_thunk_test.cc @@ -15,8 +15,13 @@ limitations under the License. #include "xla/backends/cpu/runtime/sort_thunk.h" +#include +#include #include #include +#include +#include +#include #include #include @@ -34,6 +39,7 @@ limitations under the License. #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" +#include "tsl/platform/test_benchmark.h" namespace xla::cpu { namespace { @@ -59,6 +65,49 @@ class LessThanComparator : public Thunk::FunctionRegistry { } }; +TEST_P(SortThunkTest, DescendingSortPlainArray) { + bool is_stable = GetParam(); + const int data_size = 10000; + + std::vector buffers; + std::vector data(data_size); + + std::default_random_engine gen; + std::uniform_real_distribution distribution(0.0, 1000.0); + + for (int i = 0; i < data_size; i++) { + data[i] = distribution(gen); + } + + const size_t size_in_bytes = data_size * sizeof(float); + buffers.emplace_back(se::DeviceMemoryBase(data.data(), size_in_bytes)); + + const BufferAllocations allocations(buffers); + const BufferAllocation alloc(0, size_in_bytes, 0); + const BufferAllocation::Slice slice0(&alloc, 0, size_in_bytes); + const Shape data_shape = ShapeUtil::MakeShape(F32, {data_size}); + + // The comparator function is not used in the plain array sort when the sort + // direction is specified and data types are supported. + auto fake_less_than = [](const void** data) { return false; }; + + // Use sort direction to activate the most efficient sorting function. + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, SortThunk::Create({"sort"}, {{slice0, data_shape}}, + /*dimension=*/0, is_stable, fake_less_than, + SortThunk::SortDirection::kDescending)); + + Thunk::ExecuteParams params; + params.buffer_allocations = &allocations; + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_FALSE(execute_event.IsError()); + + EXPECT_TRUE( + std::is_sorted(data.cbegin(), data.cend(), std::greater())); +} + TEST_P(SortThunkTest, Sort1D) { bool is_stable = GetParam(); @@ -100,6 +149,83 @@ TEST_P(SortThunkTest, Sort1D) { EXPECT_EQ(indices, expected_indices); } +TEST_P(SortThunkTest, DynamicSort1D) { + bool is_stable = GetParam(); + + // 33 empty slices + 2 slices with data = 35 slices + // This amount of slices will call the dynamic sort implementation. + constexpr int num_of_empty_slices = 33; + constexpr int total_num_of_slices = num_of_empty_slices + 2; + + // size of each of 33 data buffers + constexpr int data_size = 31; + + // values range will be [5.0, 35.0] + constexpr float starting_value = 5.0f; + + std::array data{ + 17.0f, 16.0f, 5.0f, 10.0f, 30.0f, 8.0f, 9.0f, 21.0f, + 14.0f, 32.0f, 29.0f, 28.0f, 19.0f, 12.0f, 25.0f, 22.0f, + 18.0f, 35.0f, 34.0f, 23.0f, 7.0f, 13.0f, 26.0f, 33.0f, + 15.0f, 24.0f, 20.0f, 31.0f, 6.0f, 27.0f, 11.0f}; + std::array indices; + std::iota(indices.begin(), indices.end(), 0); + + // This is a container for the rest of the buffers. + std::array empty; + + const size_t data_size_in_bytes = data.size() * sizeof(float); + const size_t ind_size_in_bytes = indices.size() * sizeof(int32_t); + const size_t empty_size_in_bytes = empty.size() * sizeof(uint32_t); + + const BufferAllocation alloc0(0, data_size_in_bytes, 0); + const BufferAllocation alloc1(1, ind_size_in_bytes, 0); + const BufferAllocation rest(2, empty_size_in_bytes, 0); + + const BufferAllocation::Slice slice0(&alloc0, 0, data_size_in_bytes); + const BufferAllocation::Slice slice1(&alloc1, 0, ind_size_in_bytes); + + const Shape data_shape = ShapeUtil::MakeShape(F32, {data_size}); + const Shape indices_shape = ShapeUtil::MakeShape(S32, {data_size}); + const Shape rest_shape = ShapeUtil::MakeShape(U32, {data_size}); + + std::vector buffers; + buffers.emplace_back(se::DeviceMemoryBase(data.data(), data_size_in_bytes)); + buffers.emplace_back(se::DeviceMemoryBase(indices.data(), ind_size_in_bytes)); + buffers.emplace_back(se::DeviceMemoryBase(empty.data(), empty_size_in_bytes)); + + BufferAllocations allocations(buffers); + + std::array inputs{ + {{slice0, data_shape}, {slice1, indices_shape}}}; + for (int i = 0; i < num_of_empty_slices; ++i) { + constexpr size_t empty_slice_in_bytes = data_size * sizeof(uint32_t); + inputs[i + 2].slice = BufferAllocation::Slice( + &rest, i * empty_slice_in_bytes, empty_slice_in_bytes); + inputs[i + 2].shape = rest_shape; + } + + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, SortThunk::Create({"sort"}, inputs, + /*dimension=*/0, is_stable, LessThan)); + + Thunk::ExecuteParams params; + params.buffer_allocations = &allocations; + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_FALSE(execute_event.IsError()); + + std::array expected_data; + std::iota(expected_data.begin(), expected_data.end(), starting_value); + const std::array expected_indices{ + 2, 28, 20, 5, 6, 3, 30, 13, 21, 8, 24, 1, 0, 16, 12, 26, + 7, 15, 19, 25, 14, 22, 29, 11, 10, 4, 27, 9, 23, 18, 17}; + + EXPECT_EQ(data, expected_data); + EXPECT_EQ(indices, expected_indices); +} + TEST_P(SortThunkTest, Sort2D) { bool is_stable = GetParam(); @@ -237,6 +363,163 @@ TEST_P(SortThunkTest, Sort2DWithLayout) { EXPECT_EQ(indices, expected_indices); } +void BM_DynamicSort1D(::testing::benchmark::State& state, bool is_stable) { + const int total_num_of_slices = state.range(0); + const int num_of_empty_slices = total_num_of_slices - 2; + + // size of each of data buffers + constexpr int data_size = 31; + + const std::array data{ + 17.0f, 16.0f, 5.0f, 10.0f, 30.0f, 8.0f, 9.0f, 21.0f, + 14.0f, 32.0f, 29.0f, 28.0f, 19.0f, 12.0f, 25.0f, 22.0f, + 18.0f, 35.0f, 34.0f, 23.0f, 7.0f, 13.0f, 26.0f, 33.0f, + 15.0f, 24.0f, 20.0f, 31.0f, 6.0f, 27.0f, 11.0f}; + std::array indices; + std::iota(indices.begin(), indices.end(), 0); + + // This is the container for the rest of the buffers. + std::vector empty(data_size * num_of_empty_slices); + + const size_t data_size_in_bytes = data.size() * sizeof(float); + const size_t ind_size_in_bytes = indices.size() * sizeof(int32_t); + const size_t empty_size_in_bytes = empty.size() * sizeof(uint32_t); + + const BufferAllocation alloc0(0, data_size_in_bytes, 0); + const BufferAllocation alloc1(1, ind_size_in_bytes, 0); + const BufferAllocation rest(2, empty_size_in_bytes, 0); + + const BufferAllocation::Slice slice0(&alloc0, 0, data_size_in_bytes); + const BufferAllocation::Slice slice1(&alloc1, 0, ind_size_in_bytes); + + const Shape data_shape = ShapeUtil::MakeShape(F32, {data_size}); + const Shape indices_shape = ShapeUtil::MakeShape(S32, {data_size}); + const Shape rest_shape = ShapeUtil::MakeShape(U32, {data_size}); + + for (auto s : state) { + // Pause timing to avoid counting the time spent in the setup. + state.PauseTiming(); + auto data_clone(data); + auto indices_clone(indices); + + std::vector buffers; + buffers.emplace_back( + se::DeviceMemoryBase(data_clone.data(), data_size_in_bytes)); + buffers.emplace_back( + se::DeviceMemoryBase(indices_clone.data(), ind_size_in_bytes)); + buffers.emplace_back( + se::DeviceMemoryBase(empty.data(), empty_size_in_bytes)); + + BufferAllocations allocations(buffers); + + std::vector inputs(total_num_of_slices); + inputs[0] = {slice0, data_shape}; + inputs[1] = {slice1, indices_shape}; + for (int i = 0; i < num_of_empty_slices; ++i) { + constexpr size_t empty_slice_in_bytes = data_size * sizeof(uint32_t); + inputs[i + 2].slice = BufferAllocation::Slice( + &rest, i * empty_slice_in_bytes, empty_slice_in_bytes); + inputs[i + 2].shape = rest_shape; + } + + Thunk::ExecuteParams params; + params.buffer_allocations = &allocations; + + state.ResumeTiming(); + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, SortThunk::Create({"sort"}, inputs, + /*dimension=*/0, is_stable, LessThan)); + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_FALSE(execute_event.IsError()); + } +} + +void BM_SortPlainArray(::testing::benchmark::State& state, bool is_stable) { + const int data_size = state.range(0); + + std::vector data(data_size); + + std::default_random_engine gen; + std::uniform_real_distribution distribution(0.0, 1000.0); + + for (int i = 0; i < data_size; i++) { + data[i] = distribution(gen); + } + + const size_t size_in_bytes = data_size * sizeof(float); + const BufferAllocation alloc(0, size_in_bytes, 0); + const BufferAllocation::Slice slice0(&alloc, 0, size_in_bytes); + const Shape data_shape = ShapeUtil::MakeShape(F32, {data_size}); + + for (auto s : state) { + state.PauseTiming(); + auto data_clone(data); + std::vector buffer; + buffer.emplace_back(se::DeviceMemoryBase(data_clone.data(), size_in_bytes)); + + const BufferAllocations allocations(buffer); + + Thunk::ExecuteParams params; + params.buffer_allocations = &allocations; + + // The comparator function is not used in the plain array sort when the sort + // direction is specified and data types are supported. + auto fake_less_than = [](const void** data) { return false; }; + + state.ResumeTiming(); + // Use sort direction to activate the most efficient sorting function. + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, + SortThunk::Create({"sort"}, {{slice0, data_shape}}, + /*dimension=*/0, is_stable, fake_less_than, + SortThunk::SortDirection::kAscending)); + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_FALSE(execute_event.IsError()); + } +} + +void BM_StableDynamicSort1D(::testing::benchmark::State& state) { + BM_DynamicSort1D(state, /*is_stable=*/true); +} + +void BM_UnstableDynamicSort1D(::testing::benchmark::State& state) { + BM_DynamicSort1D(state, /*is_stable=*/false); +} + +void BM_StableSortPlainArray(::testing::benchmark::State& state) { + BM_SortPlainArray(state, /*is_stable=*/true); +} + +void BM_UnstableSortPlainArray(::testing::benchmark::State& state) { + BM_SortPlainArray(state, /*is_stable=*/false); +} + +BENCHMARK(BM_StableDynamicSort1D) + ->MeasureProcessCPUTime() + ->Arg(35) + ->Arg(50) + ->Arg(100); + +BENCHMARK(BM_UnstableDynamicSort1D) + ->MeasureProcessCPUTime() + ->Arg(35) + ->Arg(50) + ->Arg(100); + +BENCHMARK(BM_StableSortPlainArray) + ->MeasureProcessCPUTime() + ->Arg(10000) + ->Arg(100000); + +BENCHMARK(BM_UnstableSortPlainArray) + ->MeasureProcessCPUTime() + ->Arg(10000) + ->Arg(100000); + INSTANTIATE_TEST_SUITE_P(SortThunk, SortThunkTest, testing::Bool(), testing::PrintToStringParamName()); diff --git a/third_party/xla/xla/backends/cpu/runtime/thread_pool_task_runner.h b/third_party/xla/xla/backends/cpu/runtime/thread_pool_task_runner.h new file mode 100644 index 00000000000000..95e36c293f7af5 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/thread_pool_task_runner.h @@ -0,0 +1,61 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_RUNTIME_THREAD_POOL_TASK_RUNNER_H_ +#define XLA_BACKENDS_CPU_RUNTIME_THREAD_POOL_TASK_RUNNER_H_ + +#define EIGEN_USE_THREADS + +#include +#include +#include + +#include "unsupported/Eigen/CXX11/ThreadPool" +#include "xla/backends/cpu/runtime/thunk.h" + +namespace xla::cpu { + +// An implementation of a `Thunk::TaskRunner` that uses Eigen thread pool for +// launching ThunkExecutor tasks. In XLA in practice it means that we run +// all ThunkExecutor tasks in the intra-op thread pool (owned by PjRt client). +class ThreadPoolTaskRunner : public Thunk::TaskRunner { + public: + explicit ThreadPoolTaskRunner(Eigen::ThreadPoolInterface* thread_pool) + : thread_pool_(thread_pool) {} + + void operator()(Thunk::Task task) final { + if (thread_pool_ == nullptr) { + task(); + } else { + thread_pool_->Schedule(std::move(task)); + } + } + + std::optional current_worker_id() const final { + if (thread_pool_ == nullptr) { + return {0}; + } else { + int64_t thread_id = thread_pool_->CurrentThreadId(); + return thread_id == -1 ? std::nullopt : std::make_optional(thread_id); + } + } + + private: + Eigen::ThreadPoolInterface* thread_pool_; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_RUNTIME_THREAD_POOL_TASK_RUNNER_H_ diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk.h b/third_party/xla/xla/backends/cpu/runtime/thunk.h index 794e5089db99c4..41380cefd8230f 100644 --- a/third_party/xla/xla/backends/cpu/runtime/thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/thunk.h @@ -29,7 +29,6 @@ limitations under the License. #include #include "absl/container/inlined_vector.h" -#include "absl/functional/any_invocable.h" #include "absl/status/statusor.h" #include "xla/backends/cpu/runtime/buffer_allocations.h" #include "xla/backends/cpu/runtime/resource_use.h" @@ -40,7 +39,6 @@ limitations under the License. #include "xla/service/cpu/xfeed_manager.h" #include "xla/service/global_device_id.h" #include "xla/stream_executor/host/host_kernel_c_api.h" -#include "xla/stream_executor/stream.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/chain.h" #include "xla/util.h" @@ -99,6 +97,8 @@ class Thunk { int64_t module_id; }; + using Task = std::function; + // An abstract task runner that can be used by a ThunkExecutor (including // thunk executors for nested computations in conditional or while thunks) for // running tasks corresponding to thunk execution. It can be a simple inline @@ -107,8 +107,21 @@ class Thunk { // pool with the intra-op thread pool used for compute tasks. We deliberately // do not prescribe task runner to be Eigen or any other particular thread // pool, and let users make the choice. - using Task = std::function; - using TaskRunner = absl::AnyInvocable; + class TaskRunner { + public: + virtual ~TaskRunner() = default; + + virtual void operator()(Task task) = 0; + + // Returns the current worker id if the caller happens to run on a thread + // managed by the task runner. Otherwise returns empty optional. Thunk + // executor relies on this information to do a best-effort resource + // isolation by making sure that all thunks are executed inside a task + // runner, and do not "leak" into arbitrary thread pools in the process, + // because by default we resume execution on a thread that completed thunk + // execute event AsyncValue, and it can be an external thread pool. + virtual std::optional current_worker_id() const = 0; + }; Thunk(Kind kind, Info info); diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc index 62df8d8f5b12a7..d330f2116e14d2 100644 --- a/third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc @@ -60,6 +60,7 @@ ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence, sink_.push_back(i); } } + // Erase redundant edges between nodes. int64_t num_erased_edges = RunTransitiveReductionAndUpdatePriorities(); @@ -69,7 +70,7 @@ ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence, is_sequential_ &= (absl::c_count(nodes_defs_[i].in_edges, i - 1) != 0); } - // Maybe mark execution as sequential if all thunks use small buffers. + // Prefer sequential execution if all thunks use small buffers. auto uses_small_buffers = [&](const std::unique_ptr& thunk) { return absl::c_all_of(thunk->buffer_uses(), [&](const BufferUse& use) { return use.slice().size() <= options.execute_sequential_buffer_threshold; @@ -79,6 +80,10 @@ ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence, bool small_buffers = absl::c_all_of(thunk_sequence_, uses_small_buffers); is_sequential_ |= small_buffers; + // Prefer sequential execution for small thunk sequences. + is_sequential_ |= + thunk_sequence_.size() <= options.execute_sequential_num_thunks_threshold; + VLOG(2) << absl::StreamFormat( "Constructed ThunkExecutor with %d nodes: #source_nodes=%d " "#sink_nodes=%d, #erased_edges=%d, is_sequential=%v, small_buffers=%v", @@ -140,9 +145,6 @@ ThunkExecutor::ExecuteState::ExecuteState(ThunkExecutor* executor, execute_event(tsl::MakeConstructedAsyncValueRef()), pending_sink_nodes(executor->sink().size()), abort(false) { - DCHECK(runner == nullptr || static_cast(*runner)) - << "`runner` must be nullptr or a valid TaskRunner"; - NodeStorage* node = nodes.data(); for (const NodeDef& node_def : executor->nodes_defs()) { new (node++) Node(node_def); @@ -159,8 +161,10 @@ tsl::AsyncValueRef ThunkExecutor::Execute( return thunk_sequence_[0]->Execute(params); } - // If thunk sequence dependencies form a sequential execution graph, we skip - // expensive async execution and simply run thunks one by one. + // When we choose sequential execution strategy (we rely on heuristics and + // a cost model to make the decision), we skip expensive async execution and + // simply run thunks one by one. This minimizes runtime overheads from small + // XLA programs with many cheap operations. if (is_sequential_) { return ExecuteSequential(params); } @@ -216,10 +220,19 @@ ThunkExecutor::ExecuteSequential(const Thunk::ExecuteParams& params) { if (ABSL_PREDICT_FALSE(!execute_event.IsAvailable())) { auto event = tsl::MakeConstructedAsyncValueRef(); execute_event.AndThen([this, ¶ms, it, event](absl::Status status) { + Thunk::TaskRunner* runner = params.task_runner; + if (ABSL_PREDICT_FALSE(!status.ok())) { event.SetError(std::move(status)); - } else { + } else if (!runner || runner->current_worker_id()) { + // Resume execution in the current thread if we are already running + // on a thread managed by the task runner. ResumeExecuteSequential(it + 1, params, std::move(event)); + } else { + // Resume execution in the task runner to avoid thread "leaks". + (*runner)([this, ¶ms, it, event = std::move(event)] { + ResumeExecuteSequential(it + 1, params, std::move(event)); + }); } }); return event; @@ -253,10 +266,19 @@ void ThunkExecutor::ResumeExecuteSequential( if (ABSL_PREDICT_FALSE(!execute_event.IsAvailable())) { execute_event.AndThen( [this, ¶ms, it, event = std::move(event)](absl::Status status) { + Thunk::TaskRunner* runner = params.task_runner; + if (ABSL_PREDICT_FALSE(!status.ok())) { event.SetError(std::move(status)); - } else { + } else if (!runner || runner->current_worker_id()) { + // Resume execution in the current thread if we are already + // running on a thread managed by the task runner. ResumeExecuteSequential(it + 1, params, std::move(event)); + } else { + // Resume execution in the task runner to avoid thread "leaks". + (*runner)([this, ¶ms, it, event = std::move(event)] { + ResumeExecuteSequential(it + 1, params, std::move(event)); + }); } }); return; @@ -338,12 +360,27 @@ void ThunkExecutor::Execute(ExecuteState* state, : params.session.Join()]() mutable { state->executor->ProcessOutEdges(state, execute_event, node, ready_queue); + // If ready queue is empty, it might mean that we have completed an // execution and destroyed the `state`, so we make sure we don't // touch `state` if we don't have to. - if (ABSL_PREDICT_TRUE(!ready_queue.Empty())) { + if (ABSL_PREDICT_FALSE(ready_queue.Empty())) { + return; + } + + Thunk::TaskRunner* runner = state->runner; + if (!runner || runner->current_worker_id()) { + // Resume execution in the current thread if we are already + // running on a thread managed by the task runner. state->executor->Execute(state, params, std::move(ready_queue), std::move(lock)); + } else { + // Resume execution in the task runner to avoid thread "leaks". + (*runner)([state, ¶ms, ready_queue = std::move(ready_queue), + lock = std::move(lock)] { + state->executor->Execute(state, params, std::move(ready_queue), + std::move(lock)); + }); } }); } @@ -372,7 +409,7 @@ inline ABSL_ATTRIBUTE_ALWAYS_INLINE void ThunkExecutor::SplitReadyQueue( // Execute half of the ready queue nodes in the task runner. (*state->runner)([¶ms, state, ready_queue = ready_queue.PopHalf(), - lock = std::move(task_runner_lock)]() mutable { + lock = std::move(task_runner_lock)] { state->executor->Execute(state, params, std::move(ready_queue), std::move(lock)); }); diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk_executor.h b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.h index 5ba15b0432b504..54b4a4be2ac0c6 100644 --- a/third_party/xla/xla/backends/cpu/runtime/thunk_executor.h +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.h @@ -42,11 +42,16 @@ namespace internal { // Clang does not allow defining a nested struct with member initializer, as // a workaround we define a struct in internal namespace and create an alias. struct ThunkExecutorOptions { - // If all thunks in a sequence use buffers of size less than or equal to - // `execute_sequential_buffer_threshold`, we mark execution as sequential, as - // concurrency overheads will likely dominate the overall execution time. + // If all thunks in a sequence use buffers of size less than or equal to the + // given threshold, we mark execution as sequential, as concurrency overheads + // will likely dominate the overall execution time. size_t execute_sequential_buffer_threshold = 512; + // If thunk sequence length is less than or equal to the given threshold, we + // mark execution as sequential, as concurrency overheads will likely dominate + // the overall execution time. + size_t execute_sequential_num_thunks_threshold = 8; + // Use priority ready queue to execute nodes according to their priority. By // default we use FIFO ready queue. bool use_priority_ready_queue = false; @@ -228,8 +233,8 @@ class ThunkExecutor { ReadyQueue& ready_queue, int64_t split_threshold); // Processes out edges of a completed `node` and updates `ready_queue` with - // nodes that are ready to execute. If `event` is in error state, aborts the - // execution and records the error status to forward it to the caller. + // nodes that are ready to execute. If `node_event` is in error state, aborts + // the execution and records the error status to forward it to the caller. template void ProcessOutEdges(ExecuteState* state, tsl::AsyncValuePtr node_event, diff --git a/third_party/xla/xla/backends/cpu/runtime/thunk_executor_test.cc b/third_party/xla/xla/backends/cpu/runtime/thunk_executor_test.cc index d0deefe5d4880d..511456e2adf762 100644 --- a/third_party/xla/xla/backends/cpu/runtime/thunk_executor_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_executor_test.cc @@ -16,6 +16,8 @@ limitations under the License. #include "xla/backends/cpu/runtime/thunk_executor.h" #include +#include +#include #include #include #include @@ -32,6 +34,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/backends/cpu/runtime/buffer_allocations.h" #include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thread_pool_task_runner.h" #include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" @@ -60,6 +63,48 @@ using ::testing::ElementsAre; // with a thread sanitizer and checking that there are no data races. static int64_t shared_resource; +// An adaptor from a lambda that runs tasks and a TaskRunner API. +template +class TaskRunnerAdaptor : public Thunk::TaskRunner { + public: + TaskRunnerAdaptor(Runner runner, WorkerId worker_id) + : runner_(std::move(runner)), worker_id_(std::move(worker_id)) {} + + void operator()(Thunk::Task task) final { runner_(std::move(task)); } + + std::optional current_worker_id() const final { + return worker_id_(); + } + + private: + Runner runner_; + WorkerId worker_id_; +}; + +template +auto MakeTaskRunnerFrom(Runner&& runner) { + auto no_id = []() { return std::nullopt; }; + return TaskRunnerAdaptor( + std::forward(runner), no_id); +} + +template +auto MakeTaskRunnerFrom(Runner&& runner, WorkerId&& worker_id) { + return TaskRunnerAdaptor(std::forward(runner), + std::forward(worker_id)); +} + +template +std::vector AsDeviceMemory( + absl::Span* const> data) { + std::vector buffers; + for (auto& vec : data) { + buffers.emplace_back( + se::DeviceMemoryBase(vec->data(), vec->size() * sizeof(T))); + } + return buffers; +} + // A test-only thunk for verifying thunk executor implementation: // // dst += src (for all srcs and dsts slices) @@ -80,9 +125,6 @@ class AddI32Thunk final : public Thunk { std::vector* trace = nullptr, bool use_shared_resource = false, bool inject_error = false); - static std::vector AsDeviceMemory( - absl::Span* const> data); - // Executes `dst += src` for a single src/dst pair. static absl::Status Execute(const BufferAllocations* allocations, BufferAllocation::Slice src_slice, @@ -110,16 +152,6 @@ std::unique_ptr AddI32Thunk::Create( use_shared_resource, inject_error); } -std::vector AddI32Thunk::AsDeviceMemory( - absl::Span* const> data) { - std::vector buffers; - for (auto& vec : data) { - buffers.emplace_back( - se::DeviceMemoryBase(vec->data(), vec->size() * sizeof(int32_t))); - } - return buffers; -} - AddI32Thunk::AddI32Thunk(std::string name, std::vector srcs, std::vector dsts, @@ -213,10 +245,8 @@ AddI32Thunk::ResourceUses AddI32Thunk::resource_uses() const { } static ThunkExecutor::Options OptionsForTest() { - // Override small buffers threshold to make sure that we test all execution - // paths, because in test we always use small buffers below the default - // threshold of `512`. - return ThunkExecutor::Options{/*execute_sequential_buffer_threshold=*/0}; + return ThunkExecutor::Options{/*execute_sequential_buffer_threshold=*/0, + /*execute_sequential_num_thunks_threshold=*/0}; } TEST(ThunkExecutorTest, FifoReadyQueueTest) { @@ -455,13 +485,16 @@ TEST(ThunkExecutorTest, Execute) { std::vector data(20, 1); // shared src and dst allocation - auto buffers = AddI32Thunk::AsDeviceMemory({&data}); + auto buffers = AsDeviceMemory({&data}); BufferAllocations allocations(buffers); - Thunk::TaskRunner task_runner = [&](Thunk::Task task) { - trace.push_back(""); - task(); - }; + auto task_runner = MakeTaskRunnerFrom( + [&](Thunk::Task task) { + trace.push_back(""); + task(); + }, + // Always return current worker id as 0. + [] { return 0; }); Thunk::ExecuteParams params = {nullptr, &allocations}; params.task_runner = &task_runner; @@ -479,6 +512,95 @@ TEST(ThunkExecutorTest, Execute) { 2, 2, 2, 2, 2)); // slice1 } +//===----------------------------------------------------------------------===// +// ThunkExecutor resource isolation testing +//===----------------------------------------------------------------------===// + +// No-op thunk that completes execution on a separate thread pool. We use this +// thunk to test that ThunkExecutor can jump out of a separate thread pool to +// continue execution in the intra-op thread pool. This is important for +// resource isolation as we don't want to accidentally continue with expensive +// execution on a non blocking IO callbacks thread pool. +class NoOpAsyncThunk : public Thunk { + public: + NoOpAsyncThunk(std::string name, BufferAllocation::Slice slice) + : Thunk(Kind::kKernel, Info{std::move(name)}), slice_(slice) {} + + static std::unique_ptr Create(std::string name, + BufferAllocation::Slice slice) { + return std::make_unique(std::move(name), slice); + } + + tsl::AsyncValueRef Execute(const ExecuteParams&) final { + auto ret = tsl::MakeConstructedAsyncValueRef(); + ThreadPool()->Schedule([ret] { + tsl::Env::Default()->SleepForMicroseconds(10 * 1000); + ret.SetStateConcrete(); + }); + return ret; + } + + BufferUses buffer_uses() const override { + return BufferUses{BufferUse::Write(slice_)}; + } + + private: + static tsl::thread::ThreadPool* ThreadPool() { + static auto* thread_pool = + new tsl::thread::ThreadPool(tsl::Env::Default(), "no-op-thunk", 8); + return thread_pool; + } + + BufferAllocation::Slice slice_; +}; + +TEST(ThunkExecutorTest, ExecuteOnCorrectThreadPool) { + BufferAllocation alloc(/*index=*/0, /*size=*/60, /*color=*/0); + + BufferAllocation::Slice slice0(&alloc, /*offset=*/0, /*size=*/20); + BufferAllocation::Slice slice1(&alloc, /*offset=*/20, /*size=*/20); + BufferAllocation::Slice slice2(&alloc, /*offset=*/40, /*size=*/20); + + std::array slices = {slice0, slice1, slice2}; + + ThunkSequence sequence; + for (int i = 0; i < 100; ++i) { + sequence.push_back(NoOpAsyncThunk::Create(absl::StrCat(i), slices[i % 3])); + } + + TF_ASSERT_OK_AND_ASSIGN( + ThunkExecutor executor, + ThunkExecutor::Create(std::move(sequence), OptionsForTest())); + + std::vector data(60, 1); // shared src and dst allocation + + auto buffers = AsDeviceMemory({&data}); + BufferAllocations allocations(buffers); + + // Task runner must be used only when ThunkExecutor detects that it runs on a + // wrong thread and has to jump into the task runner. + std::atomic num_tasks = 0; + auto task_runner = MakeTaskRunnerFrom([&](Thunk::Task task) { + ++num_tasks; + task(); + }); + + Thunk::ExecuteParams params = {nullptr, &allocations}; + params.task_runner = &task_runner; + params.session = + Thunk::ExecuteSession(/*max_workers=*/1, /*split_threshold=*/1000); + + auto execute_event = executor.Execute(params); + + tsl::BlockUntilReady(execute_event); + ASSERT_TRUE(execute_event.IsConcrete()); + + // We compare using GE because thread scheduling introduces small + // non-determinism and ThunkExecutor might resume after NoOpAsyncThunk already + // completes its execution event. + EXPECT_GE(num_tasks, 90); +} + //===----------------------------------------------------------------------===// // ThunkExecutor stress testing //===----------------------------------------------------------------------===// @@ -516,8 +638,8 @@ GenerateThunkSequence(size_t num_elements, size_t num_thunks, }); g->sequence.reserve(num_thunks); - g->expected_buffers = AddI32Thunk::AsDeviceMemory({&g->src, &g->expected}); - g->buffers = AddI32Thunk::AsDeviceMemory({&g->src, &g->dst}); + g->expected_buffers = AsDeviceMemory({&g->src, &g->expected}); + g->buffers = AsDeviceMemory({&g->src, &g->dst}); std::minstd_rand0 engine; @@ -585,9 +707,7 @@ class ThunkExecutorStressTest thread_pool_.emplace(tsl::Env::Default(), "thunk-executor", 8); device_.emplace(thread_pool_->AsEigenThreadPool(), thread_pool_->NumThreads()); - task_runner_.emplace([this](Thunk::Task task) { - thread_pool_->Schedule(std::move(task)); - }); + task_runner_.emplace(thread_pool_->AsEigenThreadPool()); } } @@ -606,7 +726,7 @@ class ThunkExecutorStressTest bool use_device_; std::optional thread_pool_; std::optional device_; - std::optional task_runner_; + std::optional task_runner_; }; TEST_P(ThunkExecutorStressTest, Execute) { @@ -811,10 +931,7 @@ static void BM_AsyncThunkExecutor(benchmark::State& state) { ThunkExecutor::Create(std::move(g->sequence), OptionsForTest()).value(); BufferAllocations allocations(g->buffers); - - Thunk::TaskRunner task_runner = [&](Thunk::Task task) { - thread_pool.Schedule(std::move(task)); - }; + ThreadPoolTaskRunner task_runner(thread_pool.AsEigenThreadPool()); Thunk::ExecuteParams params = {nullptr, &allocations, nullptr, &device, &task_runner}; diff --git a/third_party/xla/xla/backends/interpreter/BUILD b/third_party/xla/xla/backends/interpreter/BUILD index 16ccd763891b6f..74e5172785943d 100644 --- a/third_party/xla/xla/backends/interpreter/BUILD +++ b/third_party/xla/xla/backends/interpreter/BUILD @@ -39,22 +39,23 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/transforms:cholesky_expander", + "//xla/hlo/transforms:dynamic_index_splitter", + "//xla/hlo/transforms:eigh_expander", + "//xla/hlo/transforms:qr_expander", "//xla/service:batchnorm_expander", - "//xla/service:cholesky_expander", "//xla/service:compiler", "//xla/service:computation_placer", "//xla/service:custom_call_target_registry", "//xla/service:dynamic_dimension_inference", - "//xla/service:dynamic_index_splitter", - "//xla/service:eigh_expander", "//xla/service:executable", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_module_config", "//xla/service:layout_assignment", - "//xla/service:qr_expander", "//xla/service:topk_rewriter", "//xla/service:triangular_solve_expander", - "//xla/stream_executor", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -70,7 +71,7 @@ cc_library( name = "platform_id", srcs = ["platform_id.cc"], hdrs = ["platform_id.h"], - deps = ["//xla/stream_executor"] + if_google( + deps = ["//xla/stream_executor:platform"] + if_google( ["@com_google_protobuf//:any_cc_proto"], ["@com_google_protobuf//:protobuf_headers"], ), @@ -95,7 +96,11 @@ cc_library( "//xla/service:maybe_owning_device_memory", "//xla/service:shaped_buffer", "//xla/service:transfer_manager", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -124,7 +129,7 @@ cc_library( "//xla/service:hlo_execution_profile", "//xla/service:hlo_module_config", "//xla/service:shaped_buffer", - "//xla/stream_executor", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", @@ -144,7 +149,7 @@ cc_library( "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream_executor_h", - "//xla/stream_executor/platform", + "//xla/stream_executor/platform:initialize", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", @@ -161,12 +166,17 @@ cc_library( deps = [ "//xla:shape_util", "//xla:xla_data_proto_cc", - "//xla/stream_executor", "//xla/stream_executor:blas", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory", "//xla/stream_executor:event", "//xla/stream_executor:host_memory_allocation", + "//xla/stream_executor:kernel", "//xla/stream_executor:kernel_spec", + "//xla/stream_executor:launch_dim", "//xla/stream_executor:memory_allocation", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_common", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/host:host_stream", diff --git a/third_party/xla/xla/backends/interpreter/compiler.cc b/third_party/xla/xla/backends/interpreter/compiler.cc index 2ae30ab8951db1..1584eb4303a0ce 100644 --- a/third_party/xla/xla/backends/interpreter/compiler.cc +++ b/third_party/xla/xla/backends/interpreter/compiler.cc @@ -31,19 +31,19 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/transforms/expanders/cholesky_expander.h" +#include "xla/hlo/transforms/expanders/dynamic_index_splitter.h" +#include "xla/hlo/transforms/expanders/eigh_expander.h" +#include "xla/hlo/transforms/expanders/qr_expander.h" #include "xla/literal.h" #include "xla/service/batchnorm_expander.h" -#include "xla/service/cholesky_expander.h" #include "xla/service/compiler.h" #include "xla/service/computation_placer.h" #include "xla/service/custom_call_target_registry.h" #include "xla/service/dynamic_dimension_inference.h" -#include "xla/service/dynamic_index_splitter.h" -#include "xla/service/eigh_expander.h" #include "xla/service/executable.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/layout_assignment.h" -#include "xla/service/qr_expander.h" #include "xla/service/topk_rewriter.h" #include "xla/service/triangular_solve_expander.h" #include "xla/status_macros.h" diff --git a/third_party/xla/xla/backends/interpreter/executor.cc b/third_party/xla/xla/backends/interpreter/executor.cc index 1f1f4b9779754c..c9ae598be668ae 100644 --- a/third_party/xla/xla/backends/interpreter/executor.cc +++ b/third_party/xla/xla/backends/interpreter/executor.cc @@ -56,10 +56,6 @@ absl::Status XlaInterpreterExecutor::SynchronousMemcpy( return absl::OkStatus(); } -absl::Status XlaInterpreterExecutor::BlockHostUntilDone(Stream *stream) { - return AsExecutorStream(stream)->BlockUntilDone(); -} - absl::StatusOr> XlaInterpreterExecutor::CreateDeviceDescription(int device_ordinal) { DeviceDescription desc; diff --git a/third_party/xla/xla/backends/interpreter/executor.h b/third_party/xla/xla/backends/interpreter/executor.h index 1228b3ba890055..f421d18a6a7528 100644 --- a/third_party/xla/xla/backends/interpreter/executor.h +++ b/third_party/xla/xla/backends/interpreter/executor.h @@ -113,8 +113,6 @@ class XlaInterpreterExecutor : public StreamExecutorCommon { void DeallocateStream(Stream *stream) override {} - absl::Status BlockHostUntilDone(Stream *stream) override; - bool DeviceMemoryUsage(int64_t *free, int64_t *total) const override { return false; } diff --git a/third_party/xla/xla/backends/profiler/cpu/BUILD b/third_party/xla/xla/backends/profiler/cpu/BUILD index a67bf1fdf718e6..bd2556a12079f3 100644 --- a/third_party/xla/xla/backends/profiler/cpu/BUILD +++ b/third_party/xla/xla/backends/profiler/cpu/BUILD @@ -36,15 +36,15 @@ cc_library( "//xla/tsl/profiler/backends/cpu:host_tracer_utils", "//xla/tsl/profiler/backends/cpu:threadpool_listener", "//xla/tsl/profiler/backends/cpu:traceme_recorder", + "//xla/tsl/profiler/utils:time_utils", + "//xla/tsl/profiler/utils:xplane_schema", + "//xla/tsl/profiler/utils:xplane_utils", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/profiler/lib:profiler_collection", "@local_tsl//tsl/profiler/lib:profiler_interface", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:time_utils", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_utils", ], ) @@ -96,14 +96,14 @@ cc_library( ":metadata_utils", "//xla/service:hlo_proto_cc", "//xla/service:xla_debug_info_manager", + "//xla/tsl/profiler/utils:xplane_builder", + "//xla/tsl/profiler/utils:xplane_schema", + "//xla/tsl/profiler/utils:xplane_utils", "@com_google_absl//absl/status", "@local_tsl//tsl/profiler/lib:profiler_factory", "@local_tsl//tsl/profiler/lib:profiler_interface", "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:xplane_builder", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_utils", ], alwayslink = True, ) @@ -117,9 +117,9 @@ cc_library( deps = [ "//xla/service:hlo_proto_cc", "//xla/tsl/profiler/convert:xla_op_utils", + "//xla/tsl/profiler/utils:xplane_builder", + "//xla/tsl/profiler/utils:xplane_schema", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:xplane_builder", - "@local_tsl//tsl/profiler/utils:xplane_schema", ], ) @@ -129,6 +129,10 @@ xla_cc_test( deps = [ ":host_tracer_impl", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/profiler/utils:tf_xplane_visitor", + "//xla/tsl/profiler/utils:timespan", + "//xla/tsl/profiler/utils:xplane_schema", + "//xla/tsl/profiler/utils:xplane_visitor", "@com_google_absl//absl/types:optional", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:blocking_counter", @@ -138,9 +142,5 @@ xla_cc_test( "@local_tsl//tsl/profiler/lib:profiler_interface", "@local_tsl//tsl/profiler/lib:traceme", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:timespan", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_visitor", ], ) diff --git a/third_party/xla/xla/backends/profiler/cpu/host_tracer.cc b/third_party/xla/xla/backends/profiler/cpu/host_tracer.cc index 031301290406ab..75843df3160e52 100644 --- a/third_party/xla/xla/backends/profiler/cpu/host_tracer.cc +++ b/third_party/xla/xla/backends/profiler/cpu/host_tracer.cc @@ -24,13 +24,13 @@ limitations under the License. #include "xla/tsl/profiler/backends/cpu/host_tracer_utils.h" #include "xla/tsl/profiler/backends/cpu/threadpool_listener.h" #include "xla/tsl/profiler/backends/cpu/traceme_recorder.h" +#include "xla/tsl/profiler/utils/time_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tsl/platform/errors.h" #include "tsl/profiler/lib/profiler_collection.h" #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/time_utils.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc b/third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc index 2fca882f9910d8..ad7b241859d478 100644 --- a/third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc +++ b/third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc @@ -23,6 +23,10 @@ limitations under the License. #include #include "absl/types/optional.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/blocking_counter.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" @@ -31,10 +35,6 @@ limitations under the License. #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/lib/traceme.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/timespan.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/cpu/metadata_collector.cc b/third_party/xla/xla/backends/profiler/cpu/metadata_collector.cc index 26735490d8851c..2f75c0e6c64676 100644 --- a/third_party/xla/xla/backends/profiler/cpu/metadata_collector.cc +++ b/third_party/xla/xla/backends/profiler/cpu/metadata_collector.cc @@ -22,13 +22,13 @@ limitations under the License. #include "xla/backends/profiler/cpu/metadata_utils.h" #include "xla/service/hlo.pb.h" #include "xla/service/xla_debug_info_manager.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tsl/profiler/lib/profiler_factory.h" #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/cpu/metadata_utils.h b/third_party/xla/xla/backends/profiler/cpu/metadata_utils.h index 149e72fe259349..b30da770dda90f 100644 --- a/third_party/xla/xla/backends/profiler/cpu/metadata_utils.h +++ b/third_party/xla/xla/backends/profiler/cpu/metadata_utils.h @@ -18,9 +18,9 @@ limitations under the License. #include "xla/service/hlo.pb.h" #include "xla/tsl/profiler/convert/xla_op_utils.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/gpu/BUILD b/third_party/xla/xla/backends/profiler/gpu/BUILD index c3efcc2bd5a784..7388360ffd09de 100644 --- a/third_party/xla/xla/backends/profiler/gpu/BUILD +++ b/third_party/xla/xla/backends/profiler/gpu/BUILD @@ -11,6 +11,8 @@ load( load("//xla/tests:build_defs.bzl", "xla_test") load( "//xla/tsl:tsl.bzl", + "if_google", + "if_nvcc", "internal_visibility", "tsl_copts", "tsl_gpu_library", @@ -26,16 +28,9 @@ tsl_gpu_library( name = "device_tracer", srcs = tf_additional_device_tracer_srcs(), copts = tf_profiler_copts() + tsl_copts(), - cuda_deps = [ - ":cupti_buffer_events", - ":cupti_collector", - ":cupti_tracer", - ":cupti_wrapper", - ":rocm_collector", - ":rocm_tracer", - ], deps = [ ":cupti_utils", + "//xla/tsl/profiler/utils:time_utils", "//xla/tsl/util:env_var", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", @@ -46,8 +41,17 @@ tsl_gpu_library( "@local_tsl//tsl/profiler/lib:profiler_factory", "@local_tsl//tsl/profiler/lib:profiler_interface", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:time_utils", - ], + ] + if_cuda([ + # keep sorted + ":cupti_buffer_events", + ":cupti_collector", + ":cupti_tracer", + ":cupti_wrapper", + ]) + if_rocm([ + # keep sorted + ":rocm_collector", + ":rocm_tracer", + ]), alwayslink = 1, ) @@ -114,7 +118,7 @@ xla_test( ":cupti_wrapper", ":mock_cupti", "@com_google_absl//absl/memory", - "@local_tsl//tsl/profiler/utils:time_utils", + "//xla/tsl/profiler/utils:time_utils", ]), ) @@ -123,17 +127,14 @@ cuda_library( testonly = 1, srcs = ["cuda_test.cu.cc"], hdrs = ["cuda_test.h"], - copts = select({ - "@local_config_cuda//cuda:using_nvcc": [ - "-nvcc_options", - "ptxas-options=-v", - ], - "//conditions:default": [], - }), + copts = if_nvcc([ + "-nvcc_options", + "ptxas-options=-v", + ]), visibility = ["//visibility:public"], deps = [ "@local_config_cuda//cuda:cuda_headers", - "@local_config_cuda//cuda:cudart_static", + "@local_config_cuda//cuda:cuda_runtime", "@local_tsl//tsl/platform:test", ], ) @@ -172,6 +173,8 @@ tsl_gpu_library( ":cupti_utils", ":nvtx_utils", "//xla/tsl/profiler/backends/cpu:annotation_stack", + "//xla/tsl/profiler/utils:lock_free_queue", + "//xla/tsl/profiler/utils:per_thread", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", @@ -183,8 +186,6 @@ tsl_gpu_library( "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/profiler/utils:lock_free_queue", - "@local_tsl//tsl/profiler/utils:per_thread", ], ) @@ -218,10 +219,21 @@ tsl_gpu_library( srcs = if_rocm(["rocm_collector.cc"]), hdrs = if_rocm(["rocm_collector.h"]), copts = tf_profiler_copts() + tsl_copts(), + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), visibility = ["//visibility:public"], deps = [ "//xla/stream_executor/rocm:roctracer_wrapper", "//xla/tsl/profiler/backends/cpu:annotation_stack", + "//xla/tsl/profiler/utils:parse_annotation", + "//xla/tsl/profiler/utils:xplane_builder", + "//xla/tsl/profiler/utils:xplane_schema", + "//xla/tsl/profiler/utils:xplane_utils", "//xla/tsl/util:env_var", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", @@ -241,10 +253,6 @@ tsl_gpu_library( "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/lib:profiler_factory", "@local_tsl//tsl/profiler/lib:profiler_interface", - "@local_tsl//tsl/profiler/utils:parse_annotation", - "@local_tsl//tsl/profiler/utils:xplane_builder", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_utils", ], ) @@ -253,11 +261,19 @@ tsl_gpu_library( srcs = if_rocm(["rocm_tracer.cc"]), hdrs = if_rocm(["rocm_tracer.h"]), copts = tf_profiler_copts() + tsl_copts(), + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), visibility = ["//visibility:public"], deps = [ ":rocm_collector", "//xla/stream_executor/rocm:roctracer_wrapper", "//xla/tsl/profiler/backends/cpu:annotation_stack", + "//xla/tsl/profiler/utils:time_utils", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -271,7 +287,6 @@ tsl_gpu_library( "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/profiler/utils:time_utils", ], ) @@ -296,6 +311,12 @@ tsl_gpu_library( deps = [ ":cupti_buffer_events", ":cupti_interface", + "//xla/tsl/profiler/utils:lock_free_queue", + "//xla/tsl/profiler/utils:parse_annotation", + "//xla/tsl/profiler/utils:trace_utils", + "//xla/tsl/profiler/utils:xplane_builder", + "//xla/tsl/profiler/utils:xplane_schema", + "//xla/tsl/profiler/utils:xplane_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/hash", @@ -304,12 +325,6 @@ tsl_gpu_library( "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:lock_free_queue", - "@local_tsl//tsl/profiler/utils:parse_annotation", - "@local_tsl//tsl/profiler/utils:trace_utils", - "@local_tsl//tsl/profiler/utils:xplane_builder", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_utils", ] + if_cuda([ "//xla/tsl/cuda:cupti", "//xla/tsl/cuda", @@ -324,6 +339,8 @@ tsl_gpu_library( visibility = ["//visibility:public"], deps = [ ":cupti_interface", + "//xla/tsl/profiler/utils:buffer_pool", + "//xla/tsl/profiler/utils:lock_free_queue", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/container:flat_hash_map", @@ -334,8 +351,6 @@ tsl_gpu_library( "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/profiler/utils:buffer_pool", - "@local_tsl//tsl/profiler/utils:lock_free_queue", ] + if_cuda(["//xla/tsl/cuda:cupti"]), ) diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.h b/third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.h index f58dda54e623c1..d0a48535834024 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.h +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.h @@ -31,10 +31,10 @@ limitations under the License. #include "absl/container/node_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/buffer_pool.h" +#include "xla/tsl/profiler/utils/lock_free_queue.h" #include "tsl/platform/mutex.h" #include "tsl/platform/thread_annotations.h" -#include "tsl/profiler/utils/buffer_pool.h" -#include "tsl/profiler/utils/lock_free_queue.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_collector.cc b/third_party/xla/xla/backends/profiler/gpu/cupti_collector.cc index 6191849b0d0944..043bc4250c9681 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_collector.cc +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_collector.cc @@ -15,7 +15,9 @@ limitations under the License. #include "xla/backends/profiler/gpu/cupti_collector.h" +#include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" @@ -26,15 +28,15 @@ limitations under the License. #include "third_party/gpus/cuda/extras/CUPTI/include/cupti_activity.h" #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_occupancy.h" +#include "xla/tsl/profiler/utils/parse_annotation.h" +#include "xla/tsl/profiler/utils/trace_utils.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tsl/platform/abi.h" #include "tsl/platform/host_info.h" #include "tsl/platform/mem.h" #include "tsl/platform/mutex.h" -#include "tsl/profiler/utils/parse_annotation.h" -#include "tsl/profiler/utils/trace_utils.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace xla { namespace profiler { @@ -157,6 +159,11 @@ class PerDeviceCollector { if (kernel_name.empty()) { kernel_name = GetTraceEventTypeName(event.type); } + // For CPU events like cuGraphLaunch(), add the graph id to the name. + if (event.graph_id != 0 && event.type == CuptiTracerEventType::CudaGraph && + event.source == CuptiTracerEventSource::DriverCallback) { + absl::StrAppend(&kernel_name, " (CudaGraph:", event.graph_id, ")"); + } XEventMetadata* event_metadata = plane->GetOrCreateEventMetadata(std::move(kernel_name)); XEventBuilder xevent = line->AddEvent(*event_metadata); @@ -317,6 +324,15 @@ class PerDeviceCollector { return ret_val; } + std::optional GetDeviceName(CUdevice device) { + char device_name[512]; + if (cuDeviceGetName(device_name, sizeof(device_name), device) != + CUDA_SUCCESS) { + return std::nullopt; + } + return std::string(device_name); + } + std::string GetDeviceXLineName( int64_t stream_id, absl::flat_hash_set& event_types) { @@ -385,6 +401,13 @@ class PerDeviceCollector { CUdevice device; if (cuDeviceGet(&device, device_ordinal) != CUDA_SUCCESS) return; + std::optional device_name = GetDeviceName(device); + if (device_name.has_value()) { + device_plane->AddStatValue(*device_plane->GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kGpuDeviceName)), + *device_name); + } + auto clock_rate_in_khz = GetDeviceAttribute(device, CU_DEVICE_ATTRIBUTE_CLOCK_RATE); if (clock_rate_in_khz) { diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_error_manager_test.cc b/third_party/xla/xla/backends/profiler/gpu/cupti_error_manager_test.cc index a357d9ab41c97b..05aa020d84ab9e 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_error_manager_test.cc +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_error_manager_test.cc @@ -27,8 +27,8 @@ limitations under the License. #include "xla/backends/profiler/gpu/cupti_tracer.h" #include "xla/backends/profiler/gpu/cupti_wrapper.h" #include "xla/backends/profiler/gpu/mock_cupti.h" +#include "xla/tsl/profiler/utils/time_utils.h" #include "tsl/platform/test.h" -#include "tsl/profiler/utils/time_utils.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_tracer.cc b/third_party/xla/xla/backends/profiler/gpu/cupti_tracer.cc index 3374181c569204..3950ea7c861c89 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_tracer.cc +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_tracer.cc @@ -31,11 +31,11 @@ limitations under the License. #include "xla/backends/profiler/gpu/cupti_interface.h" #include "xla/backends/profiler/gpu/nvtx_utils.h" #include "xla/tsl/profiler/backends/cpu/annotation_stack.h" +#include "xla/tsl/profiler/utils/per_thread.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/host_info.h" #include "tsl/platform/logging.h" -#include "tsl/profiler/utils/per_thread.h" namespace xla { namespace profiler { @@ -1413,8 +1413,11 @@ absl::Status CuptiTracer::ProcessActivityBuffer(CUcontext context, collector_->GetOptions().max_activity_api_events; if (max_activity_event_count > 0 && num_activity_events_in_cached_buffer_ >= max_activity_event_count) { - LOG(WARNING) << "Already too many activity events, drop the buffer of " - << size << "bytes of event to reuse."; + LOG_EVERY_N(WARNING, 10000) + << "Already too many activity events, drop the buffer of " << size + << "bytes of event to reuse. This warning is logged once per 10000 " + "occurrences, the current count is " + << COUNTER << "."; num_activity_events_in_dropped_buffer_ += event_count_in_buffer; // buffer will be return to the pool return absl::OkStatus(); diff --git a/third_party/xla/xla/backends/profiler/gpu/device_tracer_cuda.cc b/third_party/xla/xla/backends/profiler/gpu/device_tracer_cuda.cc index a34b5134b26455..578d4ab6d3021d 100644 --- a/third_party/xla/xla/backends/profiler/gpu/device_tracer_cuda.cc +++ b/third_party/xla/xla/backends/profiler/gpu/device_tracer_cuda.cc @@ -27,6 +27,7 @@ limitations under the License. #include "xla/backends/profiler/gpu/cupti_collector.h" #include "xla/backends/profiler/gpu/cupti_tracer.h" #include "xla/backends/profiler/gpu/cupti_wrapper.h" +#include "xla/tsl/profiler/utils/time_utils.h" #include "xla/tsl/util/env_var.h" #include "tsl/platform/errors.h" #include "tsl/platform/macros.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tsl/profiler/lib/profiler_factory.h" #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/time_utils.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/gpu/device_tracer_rocm.cc b/third_party/xla/xla/backends/profiler/gpu/device_tracer_rocm.cc index aca42312a7d404..e5a91f9431b86a 100644 --- a/third_party/xla/xla/backends/profiler/gpu/device_tracer_rocm.cc +++ b/third_party/xla/xla/backends/profiler/gpu/device_tracer_rocm.cc @@ -27,6 +27,10 @@ limitations under the License. #include "xla/backends/profiler/gpu/rocm_collector.h" #include "xla/backends/profiler/gpu/rocm_tracer.h" #include "xla/tsl/profiler/backends/cpu/annotation_stack.h" +#include "xla/tsl/profiler/utils/parse_annotation.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "xla/tsl/util/env_var.h" #include "tsl/platform/abi.h" #include "tsl/platform/env_time.h" @@ -36,10 +40,6 @@ limitations under the License. #include "tsl/platform/thread_annotations.h" #include "tsl/profiler/lib/profiler_factory.h" #include "tsl/profiler/lib/profiler_interface.h" -#include "tsl/profiler/utils/parse_annotation.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc b/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc index d96cfdc8ed23bb..88371e5b09605a 100644 --- a/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc +++ b/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc @@ -25,6 +25,10 @@ limitations under the License. #include "absl/types/optional.h" #include "xla/stream_executor/rocm/roctracer_wrapper.h" #include "xla/tsl/profiler/backends/cpu/annotation_stack.h" +#include "xla/tsl/profiler/utils/parse_annotation.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "xla/tsl/util/env_var.h" #include "tsl/platform/abi.h" #include "tsl/platform/env_time.h" @@ -36,10 +40,6 @@ limitations under the License. #include "tsl/platform/types.h" #include "tsl/profiler/lib/profiler_factory.h" #include "tsl/profiler/lib/profiler_interface.h" -#include "tsl/profiler/utils/parse_annotation.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/gpu/rocm_collector.h b/third_party/xla/xla/backends/profiler/gpu/rocm_collector.h index 2c9ccd847aed1e..220fa2bb13e4a2 100644 --- a/third_party/xla/xla/backends/profiler/gpu/rocm_collector.h +++ b/third_party/xla/xla/backends/profiler/gpu/rocm_collector.h @@ -18,7 +18,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/node_hash_set.h" -#include "tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/gpu/rocm_tracer.cc b/third_party/xla/xla/backends/profiler/gpu/rocm_tracer.cc index 5a77b679112e30..fad3e39831c49a 100644 --- a/third_party/xla/xla/backends/profiler/gpu/rocm_tracer.cc +++ b/third_party/xla/xla/backends/profiler/gpu/rocm_tracer.cc @@ -19,12 +19,12 @@ limitations under the License. #include "absl/container/node_hash_map.h" #include "rocm/rocm_config.h" #include "xla/tsl/profiler/backends/cpu/annotation_stack.h" +#include "xla/tsl/profiler/utils/time_utils.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/macros.h" #include "tsl/platform/mem.h" -#include "tsl/profiler/utils/time_utils.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/plugin/BUILD b/third_party/xla/xla/backends/profiler/plugin/BUILD index b50b08779f2d91..01556a6a0828c5 100644 --- a/third_party/xla/xla/backends/profiler/plugin/BUILD +++ b/third_party/xla/xla/backends/profiler/plugin/BUILD @@ -80,6 +80,8 @@ xla_cc_test( ":plugin_tracer_impl", ":profiler_c_api_hdrs", ":profiler_error", + "//xla/tsl/profiler/utils:xplane_builder", + "//xla/tsl/profiler/utils:xplane_visitor", "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:logging", @@ -87,7 +89,5 @@ xla_cc_test( "@local_tsl//tsl/profiler/lib:profiler_interface", "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:xplane_builder", - "@local_tsl//tsl/profiler/utils:xplane_visitor", ], ) diff --git a/third_party/xla/xla/backends/profiler/plugin/plugin_tracer_impl_test.cc b/third_party/xla/xla/backends/profiler/plugin/plugin_tracer_impl_test.cc index b9dc203ddad276..a0e1cd45407a92 100644 --- a/third_party/xla/xla/backends/profiler/plugin/plugin_tracer_impl_test.cc +++ b/third_party/xla/xla/backends/profiler/plugin/plugin_tracer_impl_test.cc @@ -25,13 +25,13 @@ limitations under the License. #include "xla/backends/profiler/plugin/plugin_tracer.h" #include "xla/backends/profiler/plugin/profiler_c_api.h" #include "xla/backends/profiler/plugin/profiler_error.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/logging.h" #include "tsl/profiler/lib/profiler_factory.h" #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/backends/profiler/tpu/BUILD b/third_party/xla/xla/backends/profiler/tpu/BUILD index bf6dd25a961296..9128dd6ff9339e 100644 --- a/third_party/xla/xla/backends/profiler/tpu/BUILD +++ b/third_party/xla/xla/backends/profiler/tpu/BUILD @@ -19,6 +19,7 @@ cc_library( "//xla/stream_executor/tpu:tpu_profiler_init_fns", "//xla/stream_executor/tpu:tsl_status_helper", "//xla/tsl/c:tsl_status", + "//xla/tsl/profiler/utils:xplane_schema", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", @@ -27,7 +28,6 @@ cc_library( "@local_tsl//tsl/profiler/lib:profiler_interface", "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:xplane_schema", ], alwayslink = True, ) diff --git a/third_party/xla/xla/backends/profiler/tpu/tpu_tracer.cc b/third_party/xla/xla/backends/profiler/tpu/tpu_tracer.cc index 7488645cf40f2a..b602cdd79fe603 100644 --- a/third_party/xla/xla/backends/profiler/tpu/tpu_tracer.cc +++ b/third_party/xla/xla/backends/profiler/tpu/tpu_tracer.cc @@ -28,6 +28,7 @@ limitations under the License. #include "xla/stream_executor/tpu/tpu_ops_c_api.h" #include "xla/stream_executor/tpu/tsl_status_helper.h" #include "xla/tsl/c/tsl_status.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/types.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tsl/profiler/lib/profiler_interface.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_schema.h" #if !defined(PLATFORM_GOOGLE) #include "xla/stream_executor/tpu/tpu_profiler_init_fns.inc" diff --git a/third_party/xla/xla/client/BUILD b/third_party/xla/xla/client/BUILD index 1e19e92cbc844c..fda33ef2da0047 100644 --- a/third_party/xla/xla/client/BUILD +++ b/third_party/xla/xla/client/BUILD @@ -2,7 +2,6 @@ # XLA client libraries. load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("//xla:xla.bzl", "xla_cc_test") load("//xla/tsl:tsl.default.bzl", "filegroup") package( @@ -43,27 +42,10 @@ cc_library( cc_library( name = "padding", - srcs = ["padding.cc"], hdrs = ["padding.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder:padding instead.", deps = [ - "//xla:types", - "//xla:util", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/math:math_util", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - ], -) - -xla_cc_test( - name = "padding_test", - srcs = ["padding_test.cc"], - deps = [ - ":padding", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", + "//xla/hlo/builder:padding", ], ) @@ -72,7 +54,6 @@ cc_library( srcs = ["client.cc"], hdrs = ["client.h"], deps = [ - ":xla_computation", "//xla:execution_options_util", "//xla:literal", "//xla:shape_util", @@ -81,6 +62,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", + "//xla/hlo/builder:xla_computation", "//xla/service", "//xla/service:hlo_proto_cc", "@com_google_absl//absl/status", @@ -125,7 +107,6 @@ cc_library( deps = [ ":client", ":executable_build_options", - ":xla_computation", "//xla:debug_options_flags", "//xla:executable_run_options", "//xla:literal", @@ -133,6 +114,7 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_computation", "//xla/service:backend", "//xla/service:compiler", "//xla/service:computation_layout", @@ -143,8 +125,10 @@ cc_library( "//xla/service:shaped_buffer", "//xla/service:source_map_util", "//xla/service:stream_pool", - "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -162,15 +146,15 @@ cc_library( hdrs = ["compile_only_client.h"], deps = [ ":client", - ":xla_computation", "//xla:shape_util", "//xla:status_macros", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", + "//xla/hlo/builder:xla_computation", "//xla/service:compile_only_service", "//xla/service:compiler", "//xla/service:hlo_module_config", - "//xla/stream_executor", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -194,8 +178,9 @@ cc_library( "//xla/service:compile_only_service", "//xla/service:local_service", "//xla/service:platform_util", - "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", @@ -208,139 +193,39 @@ cc_library( cc_library( name = "sharding_builder", - srcs = ["sharding_builder.cc"], hdrs = ["sharding_builder.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder:sharding_builder instead.", deps = [ - "//xla:array", - "//xla:shape_tree", - "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "@com_google_absl//absl/log:check", + "//xla/hlo/builder:sharding_builder", ], ) cc_library( name = "xla_computation", - srcs = ["xla_computation.cc"], hdrs = ["xla_computation.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder:xla_computation instead.", visibility = ["//visibility:public"], deps = [ - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/service:hlo_proto_cc", - "@com_google_absl//absl/status:statusor", + "//xla/hlo/builder:xla_computation", ], ) cc_library( name = "value_inference", - srcs = ["value_inference.cc"], hdrs = ["value_inference.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder:value_inference instead.", visibility = ["//visibility:public"], deps = [ - ":xla_builder", - "//xla:comparison_util", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/evaluator:hlo_evaluator", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_module_config", - "//xla/service:hlo_proto_cc", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/builder:value_inference", ], ) cc_library( name = "xla_builder", - srcs = ["xla_builder.cc"], hdrs = ["xla_builder.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder:xla_builder instead.", visibility = ["//visibility:public"], deps = [ - ":padding", - ":sharding_builder", - ":xla_computation", - "//xla:array", - "//xla:array2d", - "//xla:array3d", - "//xla:array4d", - "//xla:comparison_util", - "//xla:literal", - "//xla:literal_util", - "//xla:permutation_util", - "//xla:shape_util", - "//xla:sharding_op_util", - "//xla:status_macros", - "//xla:util", - "//xla:window_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_proto_cc", - "//xla/service:shape_inference", - "//xla/tsl/lib/core:bitmap", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/functional:function_ref", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:stacktrace", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "xla_builder_test", - srcs = ["xla_builder_test.cc"], - deps = [ - ":padding", - ":sharding_builder", - ":value_inference", - ":xla_builder", - ":xla_computation", - "//xla:comparison_util", - "//xla:debug_options_flags", - "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", - "//xla/service:hlo_proto_cc", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/builder:xla_builder", ], ) diff --git a/third_party/xla/xla/client/client.cc b/third_party/xla/xla/client/client.cc index f5e174df13d98a..d6d4e8abb40fbc 100644 --- a/third_party/xla/xla/client/client.cc +++ b/third_party/xla/xla/client/client.cc @@ -23,8 +23,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/types/span.h" -#include "xla/client/xla_computation.h" #include "xla/execution_options_util.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/layout.h" #include "xla/literal.h" #include "xla/service/hlo.pb.h" diff --git a/third_party/xla/xla/client/client.h b/third_party/xla/xla/client/client.h index 120156874869a3..dfefdb615e86a3 100644 --- a/third_party/xla/xla/client/client.h +++ b/third_party/xla/xla/client/client.h @@ -24,7 +24,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/layout.h" #include "xla/literal.h" #include "xla/service/hlo.pb.h" diff --git a/third_party/xla/xla/client/compile_only_client.h b/third_party/xla/xla/client/compile_only_client.h index 2dcb9775725027..8f755691940d49 100644 --- a/third_party/xla/xla/client/compile_only_client.h +++ b/third_party/xla/xla/client/compile_only_client.h @@ -23,7 +23,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/client/client.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/service/compile_only_service.h" #include "xla/service/compiler.h" #include "xla/service/hlo_module_config.h" diff --git a/third_party/xla/xla/client/executable_build_options.cc b/third_party/xla/xla/client/executable_build_options.cc index e4843194012cee..68a7bd2dc90ea7 100644 --- a/third_party/xla/xla/client/executable_build_options.cc +++ b/third_party/xla/xla/client/executable_build_options.cc @@ -169,6 +169,8 @@ absl::StatusOr ExecutableBuildOptions::ToProto() output.set_num_partitions(num_partitions()); output.set_use_spmd_partitioning(use_spmd_partitioning()); output.set_use_auto_spmd_partitioning(use_auto_spmd_partitioning()); + output.set_exec_time_optimization_effort(exec_time_optimization_effort()); + output.set_memory_fitting_effort(memory_fitting_effort()); output.set_deduplicate_hlo(deduplicate_hlo()); if (has_device_assignment()) { device_assignment().Serialize(output.mutable_device_assignment()); @@ -221,6 +223,9 @@ absl::StatusOr ExecutableBuildOptionsFromProto( output.set_num_partitions(input.num_partitions()); output.set_use_spmd_partitioning(input.use_spmd_partitioning()); output.set_use_auto_spmd_partitioning(input.use_auto_spmd_partitioning()); + output.set_exec_time_optimization_effort( + input.exec_time_optimization_effort()); + output.set_memory_fitting_effort(input.memory_fitting_effort()); output.set_deduplicate_hlo(input.deduplicate_hlo()); if (input.has_device_assignment()) { TF_ASSIGN_OR_RETURN( @@ -274,6 +279,10 @@ ExecutionOptions CreateExecutionOptions( for (auto t : build_options.auto_spmd_partitioning_mesh_ids()) { execution_options.mutable_auto_spmd_partitioning_mesh_ids()->Add(t); } + execution_options.set_exec_time_optimization_effort( + build_options.exec_time_optimization_effort()); + execution_options.set_memory_fitting_effort( + build_options.memory_fitting_effort()); execution_options.set_deduplicate_hlo(build_options.deduplicate_hlo()); if (!build_options.allow_spmd_sharding_propagation_to_parameters().empty()) { execution_options.mutable_allow_spmd_sharding_propagation_to_parameters() diff --git a/third_party/xla/xla/client/executable_build_options.h b/third_party/xla/xla/client/executable_build_options.h index f1129d6ac5c1fe..e73d9d763102c6 100644 --- a/third_party/xla/xla/client/executable_build_options.h +++ b/third_party/xla/xla/client/executable_build_options.h @@ -124,6 +124,22 @@ class ExecutableBuildOptions { ExecutableBuildOptions& set_auto_spmd_partitioning_mesh_ids( std::vector mesh_ids); + float exec_time_optimization_effort() const { + return exec_time_optimization_effort_; + } + ExecutableBuildOptions& set_exec_time_optimization_effort( + float exec_time_optimization_effort) { + exec_time_optimization_effort_ = exec_time_optimization_effort; + return *this; + } + + float memory_fitting_effort() const { return memory_fitting_effort_; } + ExecutableBuildOptions& set_memory_fitting_effort( + float memory_fitting_effort) { + memory_fitting_effort_ = memory_fitting_effort; + return *this; + } + bool deduplicate_hlo() const { return deduplicate_hlo_; } ExecutableBuildOptions& set_deduplicate_hlo(bool deduplicate_hlo); @@ -277,6 +293,8 @@ class ExecutableBuildOptions { bool use_auto_spmd_partitioning_ = false; std::vector auto_spmd_partitioning_mesh_shape_; std::vector auto_spmd_partitioning_mesh_ids_; + float exec_time_optimization_effort_ = 0.0f; + float memory_fitting_effort_ = 0.0f; bool deduplicate_hlo_ = false; bool broadcast_replicated_params_ = false; std::optional device_assignment_; diff --git a/third_party/xla/xla/client/lib/BUILD b/third_party/xla/xla/client/lib/BUILD index 7074e8423bef58..056ca610952c9f 100644 --- a/third_party/xla/xla/client/lib/BUILD +++ b/third_party/xla/xla/client/lib/BUILD @@ -1,7 +1,7 @@ # Common computation builders for XLA. load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("//xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test") +load("//xla/tests:build_defs.bzl", "generate_backend_suites") load("//xla/tsl:tsl.bzl", "internal_visibility") load("//xla/tsl:tsl.default.bzl", "filegroup") @@ -30,553 +30,165 @@ generate_backend_suites() cc_library( name = "arithmetic", - srcs = ["arithmetic.cc"], hdrs = ["arithmetic.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:arithmetic instead.", deps = [ - ":constants", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "arithmetic_test", - srcs = ["arithmetic_test.cc"], - deps = [ - ":arithmetic", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/types:span", + "//xla/hlo/builder/lib:arithmetic", ], ) cc_library( name = "comparators", - srcs = ["comparators.cc"], hdrs = [ "comparators.h", ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:comparators instead.", deps = [ - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) - -xla_test( - name = "comparators_test", - srcs = ["comparators_test.cc"], - deps = [ - ":comparators", - ":constants", - "//xla:shape_util", - "//xla:test", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_proto_cc", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:protobuf", + "//xla/hlo/builder/lib:comparators", ], ) cc_library( name = "constants", - srcs = ["constants.cc"], hdrs = ["constants.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:constants instead.", deps = [ - "//xla:literal_util", - "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:ml_dtypes", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/builder/lib:constants", ], ) cc_library( name = "broadcast", - srcs = ["broadcast.cc"], hdrs = ["broadcast.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:broadcast instead.", deps = [ - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "constants_test", - srcs = ["constants_test.cc"], - deps = [ - ":constants", - "//xla:shape_util", - "//xla:test", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", + "//xla/hlo/builder/lib:broadcast", ], ) cc_library( name = "conv_grad_size_util", - srcs = ["conv_grad_size_util.cc"], hdrs = ["conv_grad_size_util.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:conv_grad_size_util instead.", deps = [ - "//xla/client:padding", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/builder/lib:conv_grad_size_util", ], ) cc_library( name = "dynamic_shaped_ops", - srcs = ["dynamic_shaped_ops.cc"], hdrs = ["dynamic_shaped_ops.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:dynamic_shaped_ops instead.", deps = [ - ":constants", - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:value_inference", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/builder/lib:dynamic_shaped_ops", ], ) cc_library( name = "loops", - srcs = ["loops.cc"], hdrs = ["loops.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:loops instead.", deps = [ - ":constants", - "//xla:literal_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/builder/lib:loops", ], ) cc_library( name = "math", - srcs = ["math.cc"], hdrs = [ "math.h", - "math_impl.h", ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:math instead.", deps = [ - ":arithmetic", - ":constants", - ":loops", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "math_test", - timeout = "long", - srcs = ["math_test.cc"], - backend_tags = { - # Times out. - "ghostfish_iss": ["noasan"], - }, - deps = [ - ":constants", - ":math", - "//xla:array3d", - "//xla:error_spec", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_util", - "//xla:test", - "//xla:types", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/service", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@com_google_googletest//:gtest_main", + "//xla/hlo/builder/lib:math", ], ) cc_library( name = "matrix", - srcs = ["matrix.cc"], hdrs = ["matrix.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:matrix instead.", deps = [ - ":arithmetic", - ":constants", - ":slicing", - "//xla:literal", - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "matrix_test", - srcs = ["matrix_test.cc"], - deps = [ - ":constants", - ":matrix", - ":slicing", - "//xla:array", - "//xla:array2d", - "//xla:array3d", - "//xla:array4d", - "//xla:test", - "//xla:types", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", + "//xla/hlo/builder/lib:matrix", ], ) cc_library( name = "pooling", - srcs = ["pooling.cc"], hdrs = ["pooling.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:pooling instead.", deps = [ - ":arithmetic", - ":constants", - ":conv_grad_size_util", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:padding", - "//xla/client:xla_builder", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "pooling_test", - srcs = ["pooling_test.cc"], - deps = [ - ":pooling", - "//xla:error_spec", - "//xla:shape_util", - "//xla/client:padding", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/types:span", + "//xla/hlo/builder/lib:pooling", ], ) cc_library( name = "prng", - srcs = ["prng.cc"], hdrs = ["prng.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:prng instead.", deps = [ - ":constants", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "prng_test", - srcs = ["prng_test.cc"], - deps = [ - ":constants", - ":prng", - "//xla:shape_util", - "//xla:test", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", + "//xla/hlo/builder/lib:prng", ], ) cc_library( name = "qr", - srcs = ["qr.cc"], hdrs = ["qr.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:qr instead.", deps = [ - ":constants", - ":matrix", - ":slicing", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "qr_test", - srcs = ["qr_test.cc"], - tags = ["optonly"], - deps = [ - ":matrix", - ":qr", - "//xla:array", - "//xla:array2d", - "//xla:array3d", - "//xla:error_spec", - "//xla:shape_util", - "//xla:test", - "//xla:types", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/builder/lib:qr", ], ) cc_library( name = "lu_decomposition", - srcs = ["lu_decomposition.cc"], hdrs = ["lu_decomposition.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:lu_decomposition instead.", deps = [ - "//xla:shape_util", - "//xla:status_macros", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/builder/lib:lu_decomposition", ], ) cc_library( name = "approx_topk", - srcs = ["approx_topk.cc"], hdrs = ["approx_topk.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:approx_topk instead.", deps = [ - ":approx_topk_shape", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", + "//xla/hlo/builder/lib:approx_topk", ], ) cc_library( name = "approx_topk_shape", - srcs = ["approx_topk_shape.cc"], hdrs = ["approx_topk_shape.h"], - deps = [ - "//xla:util", - "@com_google_absl//absl/status:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:approx_topk_shape instead.", + deps = ["//xla/hlo/builder/lib:approx_topk_shape"], ) cc_library( name = "slicing", - srcs = ["slicing.cc"], hdrs = ["slicing.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:slicing instead.", deps = [ - ":arithmetic", - ":constants", - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "slicing_test", - srcs = ["slicing_test.cc"], - deps = [ - ":slicing", - "//xla:array2d", - "//xla:array3d", - "//xla:error_spec", - "//xla:literal", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/builder/lib:slicing", ], ) cc_library( name = "sorting", - srcs = ["sorting.cc"], hdrs = ["sorting.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:sorting instead.", deps = [ - ":comparators", - ":constants", - ":loops", - ":slicing", - "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "sorting_test", - srcs = ["sorting_test.cc"], - deps = [ - ":sorting", - "//xla:array", - "//xla:array2d", - "//xla:error_spec", - "//xla:literal_util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/algorithm:container", + "//xla/hlo/builder/lib:sorting", ], ) cc_library( name = "quantize", hdrs = ["quantize.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:quantize instead.", deps = [ - ":constants", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@local_tsl//tsl/platform:bfloat16", - ], -) - -xla_test( - name = "quantize_test", - srcs = ["quantize_test.cc"], - # TODO(b/122119490): re-enable TAP after fixing. - tags = [ - "manual", - "notap", - ], - deps = [ - ":quantize", - "//xla:array2d", - "//xla:test", - "//xla:types", - "//xla:util", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/platform:bfloat16", + "//xla/hlo/builder/lib:quantize", ], ) @@ -593,10 +205,9 @@ cc_library( "//xla:xla_proto_cc", "//xla/client", "//xla/client:global_data", - "//xla/client:xla_builder", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", "//xla/service", - "//xla/tests:test_utils", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -606,9 +217,10 @@ cc_library( cc_library( name = "self_adjoint_eig", - srcs = ["self_adjoint_eig.cc"], hdrs = ["self_adjoint_eig.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:self_adjoint_eig instead.", deps = [ +<<<<<<< HEAD ":slicing", "//xla:shape_util", "//xla:util", @@ -648,170 +260,43 @@ xla_test( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", +======= + "//xla/hlo/builder/lib:self_adjoint_eig", +>>>>>>> master ], ) cc_library( name = "svd", - srcs = ["svd.cc"], hdrs = ["svd.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:svd instead.", deps = [ - ":arithmetic", - ":comparators", - ":constants", - ":loops", - ":math", - ":matrix", - ":slicing", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "svd_test", - srcs = ["svd_test.cc"], - real_hardware_only = True, - shard_count = 10, - tags = ["optonly"], - deps = [ - ":arithmetic", - ":constants", - ":matrix", - ":slicing", - ":svd", - "//xla:array2d", - "//xla:array3d", - "//xla:error_spec", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status:statusor", + "//xla/hlo/builder/lib:svd", ], ) cc_library( name = "tridiagonal", - srcs = ["tridiagonal.cc"], hdrs = ["tridiagonal.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:tridiagonal instead.", deps = [ - ":constants", - ":loops", - ":slicing", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "tridiagonal_test", - srcs = ["tridiagonal_test.cc"], - real_hardware_only = True, - shard_count = 10, - tags = ["optonly"], - deps = [ - ":slicing", - ":tridiagonal", - "//xla:array", - "//xla:array3d", - "//xla:literal", - "//xla:shape_util", - "//xla:test", - "//xla:util", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/builder/lib:tridiagonal", ], ) cc_library( name = "logdet", - srcs = ["logdet.cc"], - hdrs = ["logdet.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:logdet instead.", deps = [ - ":arithmetic", - ":constants", - ":matrix", - ":qr", - ":slicing", - "//xla:shape_util", - "//xla:util", - "//xla/client:xla_builder", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "logdet_test", - srcs = ["logdet_test.cc"], - tags = [ - "optonly", - ], - deps = [ - ":logdet", - "//xla:array", - "//xla:array2d", - "//xla:array3d", - "//xla:error_spec", - "//xla:literal", - "//xla:literal_util", - "//xla/client:xla_builder", - "//xla/tests:client_library_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", + "//xla/hlo/builder/lib:logdet", ], ) cc_library( name = "tuple", - srcs = ["tuple.cc"], hdrs = ["tuple.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/builder/lib:tuple instead.", deps = [ - "//xla:shape_tree", - "//xla:shape_util", - "//xla/client:xla_builder", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "tuple_test", - srcs = ["tuple_test.cc"], - deps = [ - ":tuple", - "//xla:error_spec", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_tree", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/service", - "//xla/tests:client_library_test_base", - "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/builder/lib:tuple", ], ) diff --git a/third_party/xla/xla/client/lib/approx_topk.h b/third_party/xla/xla/client/lib/approx_topk.h index ccad3dc79175fa..175a12cad0e94a 100644 --- a/third_party/xla/xla/client/lib/approx_topk.h +++ b/third_party/xla/xla/client/lib/approx_topk.h @@ -16,57 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_APPROX_TOPK_H_ #define XLA_CLIENT_LIB_APPROX_TOPK_H_ -#include "absl/types/span.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Computes approximate top-ks by aggregating top-1s in equal-sized windows. -// The number and the size of the windows are determined by the `recall_target`. -// -// operand: A sequence of multi-dimensional arrays of type T_0, ..., T_{N-1} -// init_values: N starting values for top-1 reductions -// top_k: Determines the k in top-k operation. -// reduction_dim: Determines the dimension to compute top-k. -// comparator: The comparator computation to use, which should have function -// signatore of (T_0, T_0, T_1, T_1, ..., T_{N-1}, T_{N-1}) -> bool. -// recall_target: Valid range (0, 1]. User can trade-off quality and performance -// with this knob. -// aggregate_to_topk: When true, sorts the set of approximate top-k elements and -// only keep the final k elements on TPU. This option is useful when user -// wanted to forward the approximate results to host and aggregate the results -// on CPU for better throughput. -// reduction_input_size_override: When set to a positive value, it overrides the -// size determined by operands[reduction_dim] for evaluating the recall. This -// option is useful when the given operand is only a subset of the overall -// computation in SPMD or distributed pipelines, where the true input size -// cannot be deferred by the operand shape. -// -// Returns a sequence of multidimensional arrays of type T_0, ..., T_{N-1}, -// which contains the approximate top-ks from the input operands. When -// `aggregate_to_topk` is set to true, the output size is just top_k. When -// `aggregate_to_topk` is set to false, the output size varied by the target -// recall. For target recall = 0.9, the output size is roughly 10 * top_k. For -// target recall = 0.99, the output size is roughly 100 * top_k. -// -// TODO(fchern): Support other hardware platforms. -XlaOp ApproxTopK(XlaBuilder* builder, absl::Span operands, - absl::Span init_values, int64_t top_k, - int64_t reduction_dim, const XlaComputation& comparator, - float recall_target = 0.9, bool aggregate_to_topk = true, - int64_t reduction_input_size_override = -1); - -// Fallback for platforms that haven't been optimized. -XlaOp ApproxTopKFallback(XlaBuilder* builder, absl::Span operands, - absl::Span init_values, int64_t top_k, - int64_t reduction_dim, - const XlaComputation& comparator, - float recall_target = 0.9, - bool aggregate_to_topk = true, - int64_t reduction_input_size_override = -1); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/approx_topk.h" #endif // XLA_CLIENT_LIB_APPROX_TOPK_H_ diff --git a/third_party/xla/xla/client/lib/approx_topk_shape.h b/third_party/xla/xla/client/lib/approx_topk_shape.h index ef59a604adb7f2..eef1e296f36fd3 100644 --- a/third_party/xla/xla/client/lib/approx_topk_shape.h +++ b/third_party/xla/xla/client/lib/approx_topk_shape.h @@ -16,35 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_APPROX_TOPK_SHAPE_H_ #define XLA_CLIENT_LIB_APPROX_TOPK_SHAPE_H_ -#include - -#include "absl/status/statusor.h" - -namespace xla { - -// Determine the output size of the reduction dimension. This is useful for jax -// abstract eval to determine the output size. -// -// input_size: Input size of the reduction dimension. -// rank: Rank of the input operand. -// top_k: Determines the k in top-k operation. -// recall_target: Valid range (0, 1]. User can trade-off quality and performance -// with this knob. -// aggregate_to_topk: When true, sorts the set of approximate top-k elements and -// only keep the final k elements on TPU. This option is useful when user -// wanted to forward the approximate results to host and aggregate the results -// on CPU for better throughput. -// -// Returns a pair of -// 1. Reduction output size -// 2. Reduction amount in log2 form. -// -// 2. is invalid and set to -1 when the approximate output is disabled, i.e. -// top_k = 1 or aggregate_to_topk = true. -absl::StatusOr> ApproxTopKReductionOutputSize( - int64_t input_size, int64_t rank, int64_t top_k, float recall_target, - bool aggregate_to_topk, int64_t input_size_override = -1); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/approx_topk_shape.h" #endif // XLA_CLIENT_LIB_APPROX_TOPK_SHAPE_H_ diff --git a/third_party/xla/xla/client/lib/arithmetic.h b/third_party/xla/xla/client/lib/arithmetic.h index c434ca7ecc430a..0b8e000a2f276b 100644 --- a/third_party/xla/xla/client/lib/arithmetic.h +++ b/third_party/xla/xla/client/lib/arithmetic.h @@ -16,75 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_ARITHMETIC_H_ #define XLA_CLIENT_LIB_ARITHMETIC_H_ -#include -#include -#include - -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -using XlaOpGenerator = std::function; - -// Creates a scalar computation based on a lambda and returns it. -XlaComputation CreateScalarComputation(const std::string& name, - PrimitiveType type, XlaBuilder* builder, - XlaOpGenerator generator); - -// Creates a scalar add computation and returns it. -XlaComputation CreateScalarAddComputation(PrimitiveType type, - XlaBuilder* builder); - -// Creates a scalar multiply computation and returns it. -XlaComputation CreateScalarMultiplyComputation(PrimitiveType type, - XlaBuilder* builder); - -// Creates a scalar ge computation and returns it. -XlaComputation CreateScalarGeComputation(PrimitiveType type, - XlaBuilder* builder); - -// Creates a scalar max computation and returns it. -XlaComputation CreateScalarMaxComputation(PrimitiveType type, - XlaBuilder* builder); - -// Creates a scalar min computation and returns it. -XlaComputation CreateScalarMinComputation(PrimitiveType type, - XlaBuilder* builder); - -// Creates a scalar logical AND computation and returns it. -XlaComputation CreateScalarAndComputation(PrimitiveType type, - XlaBuilder* builder); - -// Creates a scalar logical OR computation and returns it. -XlaComputation CreateScalarOrComputation(PrimitiveType type, - XlaBuilder* builder); - -// This is to be used for general purpose "identity" like reductions with zero -// for any type (ie. boolean operations for PRED and Add for real numbers). -// As an example, this operation can be used for a situation of: -// x_type = type(x) -// op = CreateScalarIdentityWithZeroComputation(x_type) -// ASSERT_TRUE(op(x, 0) == x) -// -// This functionality is used for operations that are similar to a slice, -// gather, or broadcast, but are created through a reduction. -XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type, - XlaBuilder* builder); - -// Returns whether any predicate in "predicates" is set. -// -// Note: if predicates is zero-sized, Any() vacuously returns false. -XlaOp Any(XlaOp predicates); - -// Returns the argmax of `input` along `axis`. `output_type` is the type to -// use for the output. In case of ties always prefers smaller index. -XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis); - -// Dispatch to ArgMin or ArgMax above, depending on bool. -XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/arithmetic.h" #endif // XLA_CLIENT_LIB_ARITHMETIC_H_ diff --git a/third_party/xla/xla/client/lib/broadcast.h b/third_party/xla/xla/client/lib/broadcast.h index d28b28133a7b15..deb85ae9ab8585 100644 --- a/third_party/xla/xla/client/lib/broadcast.h +++ b/third_party/xla/xla/client/lib/broadcast.h @@ -16,20 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_BROADCAST_H_ #define XLA_CLIENT_LIB_BROADCAST_H_ -#include "absl/status/statusor.h" -#include "absl/types/span.h" -#include "xla/client/xla_builder.h" -#include "xla/primitive_util.h" -#include "xla/types.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Broadcasts 'input' up to shape 'output_dims', using TensorFlow broadcasting -// rules. Supports broadcasting a dimension of size x to size x*y, i.e., tiling. -absl::StatusOr BroadcastTo(XlaOp input, - absl::Span output_dims); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/broadcast.h" #endif // XLA_CLIENT_LIB_BROADCAST_H_ diff --git a/third_party/xla/xla/client/lib/comparators.h b/third_party/xla/xla/client/lib/comparators.h index e5d3de12ca2df1..ad9b37d716d717 100644 --- a/third_party/xla/xla/client/lib/comparators.h +++ b/third_party/xla/xla/client/lib/comparators.h @@ -16,45 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_COMPARATORS_H_ #define XLA_CLIENT_LIB_COMPARATORS_H_ -#include -#include -#include - -#include "absl/types/span.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Creates a scalar less-than computation and returns it. The created -// computation has 2 * 'operand_types.size()' many parameters, where parameters -// 2 * i and 2 * i + 1 are a scalar with primitive type 'operand_types[i]'. The -// computation compares the first two parameters. For floating point types, a -// total order is created where -// -NaN < -infinity < ... < -0 < 0 < ... < infinity < NaN -XlaComputation CreateScalarLtComputation( - const std::vector& operand_types, XlaBuilder* builder); - -// Creates a scalar greater-than computation and returns it. The created -// computation has 2 * 'operand_types.size()' many parameters, where parameters -// 2 * i and 2 * i + 1 are a scalar with primitive type 'operand_types[i]'. The -// computation compares the first two parameters. For floating point types, a -// total order is created where -// NaN > infinity > ... > 0 > -0 > ... > -infinity > -NaN -XlaComputation CreateScalarGtComputation( - const std::vector& operand_types, XlaBuilder* builder); - -// Creates a scalar comparison computation and returns it. This function takes -// a vector of comparator functions to compare the operands where the function -// isn't nullopt with the specified comparator at that location. -XlaComputation CreateScalarComparisonComputation( - const std::string& name, const std::vector& operand_types, - const std::vector< - std::optional)>>& - generators, - XlaBuilder* builder); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/comparators.h" #endif // XLA_CLIENT_LIB_COMPARATORS_H_ diff --git a/third_party/xla/xla/client/lib/constants.h b/third_party/xla/xla/client/lib/constants.h index 6f25b82d077cb9..2135f481977396 100644 --- a/third_party/xla/xla/client/lib/constants.h +++ b/third_party/xla/xla/client/lib/constants.h @@ -16,125 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_CONSTANTS_H_ #define XLA_CLIENT_LIB_CONSTANTS_H_ -#include - -#include "absl/status/statusor.h" -#include "xla/client/xla_builder.h" -#include "xla/primitive_util.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/types.h" -#include "xla/util.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/ml_dtypes.h" -#include "tsl/platform/statusor.h" - -namespace xla { - -// Returns scalar 'value' as a scalar of 'type'. Unlike ConstantR0, 'type' is -// determined at C++ run-time, rather than C++ compile-time. -// If 'value' is floating point but 'type' is not, or if 'value' is complex but -// 'type' is not, an error will be returned. This is to catch accidental -// truncation; in such cases, use an explicit cast. -template -XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) { - if (std::is_floating_point::value && - !(primitive_util::IsFloatingPointType(type) || - primitive_util::IsComplexType(type))) { - return builder->ReportError(InvalidArgument( - "Invalid cast from floating point type to %s in ConstantR0WithType.", - PrimitiveType_Name(type))); - } - if (std::is_same::value && - !primitive_util::IsComplexType(type)) { - return builder->ReportError(InvalidArgument( - "Invalid cast from complex type to %s in ConstantR0WithType.", - PrimitiveType_Name(type))); - } - return primitive_util::PrimitiveTypeSwitch( - [&](auto primitive_type_constant) -> XlaOp { - if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { - using NativeT = primitive_util::NativeTypeOf; - return ConstantR0(builder, static_cast(value)); - } - return builder->ReportError( - InvalidArgument("Invalid type for ConstantR0WithType (%s).", - PrimitiveType_Name(type))); - }, - type); -} - -// Returns a scalar containing 'value' cast to the same run-time type as -// 'prototype'. -// If 'value' is floating point but 'prototype' is not, or if 'value' is complex -// 'prototype' is not, an error will be returned. -template -XlaOp ScalarLike(XlaOp prototype, T value) { - XlaBuilder* builder = prototype.builder(); - return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype)); - return ConstantR0WithType(builder, shape.element_type(), value); - }); -} - -// Returns an array or scalar containing copies of `value` cast to the same -// run-type type as `prototype` and broadcast to the same dimensions as -// `prototype`. -// -// If `prototype` is not a scalar or array, returns an error. -template -XlaOp FullLike(XlaOp prototype, T value) { - XlaBuilder* builder = prototype.builder(); - return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype)); - if (ShapeUtil::IsScalar(shape) || shape.IsArray()) { - return Broadcast(ScalarLike(prototype, value), shape.dimensions()); - } else { - return InvalidArgument( - "Prototype shape for BroadcastConstantLike must be a scalar or " - "array, but was %s", - shape.ToString()); - } - }); -} - -// Returns a scalar with value '0' of 'type'. -XlaOp Zero(XlaBuilder* builder, PrimitiveType type); - -// Returns a zero-filled tensor with shape `shape`. -XlaOp Zeros(XlaBuilder* builder, const Shape& shape); - -// Returns a zero-filled tensor with the same shape as `prototype`. -XlaOp ZerosLike(XlaOp prototype); - -// Returns a scalar with value '1' of 'type'. -XlaOp One(XlaBuilder* builder, PrimitiveType type); - -// Returns the machine epsilon for floating-point type `type`, i.e., -// the difference between 1.0 and the next representable value. -XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type); - -// Returns the minimum representable finite or infinite value for 'type'. -// Returns '-inf' for floating-point types. -XlaOp MinValue(XlaBuilder* builder, PrimitiveType type); - -// Returns the minimum representable finite value for 'type'. For a floating -// point type, this is equal to -MaxFiniteValue(). -XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type); - -// Returns the minimum positive normal value for floating-point type `type`. -XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type); - -// Returns the maximum representable finite or infinite value for 'type'. -// Returns 'inf' for floating-point types. -XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type); - -// Returns the maximum representable finite value for 'type'. -XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type); - -// Returns a nan for the given type. Only valid for real-valued fp types. -XlaOp NanValue(XlaBuilder* builder, PrimitiveType type); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/constants.h" #endif // XLA_CLIENT_LIB_CONSTANTS_H_ diff --git a/third_party/xla/xla/client/lib/conv_grad_size_util.h b/third_party/xla/xla/client/lib/conv_grad_size_util.h index ca56ada8b55f25..e991982968da9e 100644 --- a/third_party/xla/xla/client/lib/conv_grad_size_util.h +++ b/third_party/xla/xla/client/lib/conv_grad_size_util.h @@ -16,29 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_CONV_GRAD_SIZE_UTIL_H_ #define XLA_CLIENT_LIB_CONV_GRAD_SIZE_UTIL_H_ -#include "absl/status/statusor.h" -#include "xla/client/padding.h" - -namespace xla { - -// Information about a single spatial dimension for a convolution gradients and -// windowed operations. -struct SpatialDimensionOutputSizeAndPadding { - // Effective size of the operation output (potentially expanded). - int64_t output_size; - // Number of padding elements to be added before/after this dimension of - // the input when computing the input gradient. - int64_t pad_before; - int64_t pad_after; -}; - -// Verifies that the dimensions all match, and computes the size and padding of -// a spatial dimension for convolution gradient operations. -absl::StatusOr -ConvGradExtractAndVerifyDimension(int64_t input_size, int64_t filter_size, - int64_t output_size, int64_t dilation, - int64_t stride, Padding padding); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/conv_grad_size_util.h" #endif // XLA_CLIENT_LIB_CONV_GRAD_SIZE_UTIL_H_ diff --git a/third_party/xla/xla/client/lib/dynamic_shaped_ops.h b/third_party/xla/xla/client/lib/dynamic_shaped_ops.h index 31305bd90a7b58..cf62a37d6f920e 100644 --- a/third_party/xla/xla/client/lib/dynamic_shaped_ops.h +++ b/third_party/xla/xla/client/lib/dynamic_shaped_ops.h @@ -16,44 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_DYNAMIC_SHAPED_OPS_H_ #define XLA_CLIENT_LIB_DYNAMIC_SHAPED_OPS_H_ -#include "absl/status/statusor.h" -#include "absl/types/span.h" -#include "xla/client/lib/constants.h" -#include "xla/client/value_inference.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" -#include "xla/primitive_util.h" -#include "xla/types.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Similar to static shaped conditional, but allows true_computation and -// false_computation to have different dimension sizes (ranks still have to be -// the same). Fall back to static conditional if dynamism is not presented. -XlaOp DynamicConditional(XlaBuilder* builder, XlaOp predicate, - XlaOp true_operand, - const XlaComputation& true_computation, - XlaOp false_operand, - const XlaComputation& false_computation); - -// Similar to DynamicConditional, but support multiple branches. -XlaOp DynamicConditional( - XlaBuilder* builder, XlaOp branch_index, - absl::Span branch_computations, - absl::Span branch_operands); - -// Similar to SetDimensionSize, but automatically adjust the bound of output if -// a tighter one can be inferred by `value_inference`. -absl::StatusOr SetDimensionSizeWithRebound( - ValueInference* value_inference, XlaOp operand, XlaOp dimension_size, - int64_t dimension); - -// Take a `operand` tensor and a R1 tensor `size_vector` representing the sizes -// of `operand`, Call SetDimensionSize if for each dimension whose size is -// dynamic. -absl::StatusOr SetAllDimensionSizes(ValueInference* value_inference, - XlaOp operand, XlaOp size_vector); -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/dynamic_shaped_ops.h" #endif // XLA_CLIENT_LIB_DYNAMIC_SHAPED_OPS_H_ diff --git a/third_party/xla/xla/client/lib/loops.h b/third_party/xla/xla/client/lib/loops.h index 3b9855e58cc3dd..d714efeaa415f1 100644 --- a/third_party/xla/xla/client/lib/loops.h +++ b/third_party/xla/xla/client/lib/loops.h @@ -16,59 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_LOOPS_H_ #define XLA_CLIENT_LIB_LOOPS_H_ -#include -#include - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Function that builds a loop condition. Takes as input a sequence of input -// values, and returns a boolean value representing if the condition succeeds. -typedef std::function(absl::Span, - XlaBuilder*)> - WhileLoopHelperConditionFunction; - -// Function that builds a loop body. Takes as input a sequence of input values -// and returns a sequence of output values. -typedef std::function>( - absl::Span, XlaBuilder*)> - WhileLoopHelperBodyFunction; - -// Helper function for building an XLA while loop, where the values carried by -// the loop are a tuple of values, e.g., (a, b, c): -// while( -// condition: (a, b, c) -> bool, -// body: (a, b, c) -> (a, b, c) -// init: (a, b, c) -// ) -// 'name' is a descriptive name for the loop. -absl::StatusOr> WhileLoopHelper( - const WhileLoopHelperConditionFunction& condition_function, - const WhileLoopHelperBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - XlaBuilder* builder); - -// Builds an XLA loop that repeats a computation `num_iterations` times. -// -// The body function (ForEachIndexBodyFunction) takes as input a pair of -// (current iteration number, loop-carried values), and returns an updated -// vector of the loop-carried values. -typedef std::function>( - XlaOp, absl::Span, XlaBuilder*)> - ForEachIndexBodyFunction; - -absl::StatusOr> ForEachIndex( - int64_t num_iterations, PrimitiveType num_iterations_type, - const ForEachIndexBodyFunction& body_function, - absl::Span initial_values, absl::string_view name, - XlaBuilder* builder); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/loops.h" #endif // XLA_CLIENT_LIB_LOOPS_H_ diff --git a/third_party/xla/xla/client/lib/lu_decomposition.h b/third_party/xla/xla/client/lib/lu_decomposition.h index a2d26e02f4e635..752e84c9d2b12f 100644 --- a/third_party/xla/xla/client/lib/lu_decomposition.h +++ b/third_party/xla/xla/client/lib/lu_decomposition.h @@ -16,46 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_LU_DECOMPOSITION_H_ #define XLA_CLIENT_LIB_LU_DECOMPOSITION_H_ -#include "xla/client/xla_builder.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Computes the LU decomposition with partial pivoting of a batch of matrices. -// -// Given a (batched) matrix a with shape [..., m, n], computes the matrix -// decomposition A = P @ L @ U where P is a permutation matrix, L is a -// lower-triangular matrix with unit diagonal entries, and U is an -// upper-triangular matrix. -// -// L and U are returned as a single matrix [..., m, n] containing both L and U -// packed in the same array. The unit diagonal of L is not represented -// explicitly. -// -// The permutation matrix P is returned in two forms, both as `pivots`, which is -// an s32[..., min(m, n)] array that describes a sequence of row-swaps in the -// style of LAPACK's xGETRF API, and `permutation`, which is a s32[..., m] array -// which gives the permutation to apply to the rows. We return both -// representations because they are each useful for different purposes; `pivots` -// is useful for computing the sign of a determinant, whereas `permutation` can -// be used via a Gather operation to permute the rows of a matrix. -// -// This method is only implemented on TPU at the moment. -// TODO(b/168208200): the implementation only supports F32 arrays. Handle the -// complex case. -struct LuDecompositionResult { - // The LU decomposition, with both L and U packed into an array with shape - // [..., m, n]. - XlaOp lu; - // An array of shape s32[..., min(m, n)] containing the pivot rows. - XlaOp pivots; - // An array of shape s32[..., m], containing an another representation of the - // pivots as a permutation. - XlaOp permutation; -}; - -LuDecompositionResult LuDecomposition(XlaOp a); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/lu_decomposition.h" #endif // XLA_CLIENT_LIB_LU_DECOMPOSITION_H_ diff --git a/third_party/xla/xla/client/lib/math.h b/third_party/xla/xla/client/lib/math.h index 74b8a387a416de..9956776ee87d1a 100644 --- a/third_party/xla/xla/client/lib/math.h +++ b/third_party/xla/xla/client/lib/math.h @@ -16,112 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_MATH_H_ #define XLA_CLIENT_LIB_MATH_H_ -#include "xla/client/xla_builder.h" - -namespace xla { - -// Determines whether operand is +/-inf or nan. -// -// Raises an error if called on integral or complex values. -XlaOp IsPosInf(XlaOp operand); -XlaOp IsNegInf(XlaOp operand); -XlaOp IsInf(XlaOp operand); -XlaOp IsNan(XlaOp operand); - -// Determines whether operand is equal to -0. -// -// Raises an error for integral or complex values. -XlaOp IsNegZero(XlaOp operand); - -// Returns the next number after 'from' in the direction of 'to' the same way -// std::nextafter(from, to) would. -XlaOp NextAfter(XlaOp from, XlaOp to); - -// Computes the square of 'operand'. -XlaOp Square(XlaOp operand); - -// Computes the reciprocal of 'operand'. -XlaOp Reciprocal(XlaOp operand); - -// Computes an approximation of the error function complement (1 - erf(x)). -XlaOp Erfc(XlaOp x); - -// Computes an approximation of the inverse of the error function. -XlaOp ErfInv(XlaOp x); - -// Computes an approximation of the lgamma function. -XlaOp Lgamma(XlaOp input); - -// Computes an approximation of the digamma function. -XlaOp Digamma(XlaOp input); - -// Computes an approximation of the incomplete gamma function. -XlaOp Igamma(XlaOp a, XlaOp x); - -// Computes an approximation of the derivative of the incomplete gamma function -// with respect to a. -XlaOp IgammaGradA(XlaOp a, XlaOp x); - -// Computes an approximation of the derivative of a sample `x` from a `Gamma(a, -// 1)` distribution with respect to a. -XlaOp RandomGammaGrad(XlaOp a, XlaOp x); - -// Computes an approximation of the complementary incomplete gamma function. -XlaOp Igammac(XlaOp a, XlaOp x); - -// Computes the Polygamma of two arguments. -XlaOp Polygamma(XlaOp n, XlaOp x); - -// Computes the Riemann zeta function of two arguments. -XlaOp Zeta(XlaOp x, XlaOp q); - -// Rounds the given number to even when the number is equidistant between two -// integers. -XlaOp RoundToEven(XlaOp x); - -// Trigonometric functions - -// Computes the arc cosine of 'x'. -XlaOp Acos(XlaOp x); - -// Computes the arc sine of 'x'. -XlaOp Asin(XlaOp x); - -// Computes the arc tangent of 'x'. -XlaOp Atan(XlaOp x); - -// Hyperbolic trigonometric functions - -// Computes the inverse hyperbolic cosine of 'x'. -XlaOp Acosh(XlaOp x); - -// Computes the inverse hyperbolic sine of 'x'. -XlaOp Asinh(XlaOp x); - -// Computes the inverse hyperbolic tangent of 'x'. -XlaOp Atanh(XlaOp x); - -// Computes the hyperbolic cosine of 'x'. -XlaOp Cosh(XlaOp x); - -// Computes the hyperbolic sine of 'x'. -XlaOp Sinh(XlaOp x); - -// Applies a complex conjugation operation if 'a' is complex and 'conjugate' -// is true, otherwise returns its argument. -xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); - -// Computes the Modified Bessel function of the first kind of the zeroth order -// at x. -XlaOp BesselI0e(XlaOp x); - -// Computes the Modified Bessel function of the first kind of the first order -// at x. -XlaOp BesselI1e(XlaOp x); - -// Computes the Regularized Incomplete Beta function. -XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/math.h" #endif // XLA_CLIENT_LIB_MATH_H_ diff --git a/third_party/xla/xla/client/lib/matrix.h b/third_party/xla/xla/client/lib/matrix.h index df3a2e878d88a7..aaf938786fc020 100644 --- a/third_party/xla/xla/client/lib/matrix.h +++ b/third_party/xla/xla/client/lib/matrix.h @@ -16,144 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_MATRIX_H_ #define XLA_CLIENT_LIB_MATRIX_H_ -#include -#include -#include -#include - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "xla/client/xla_builder.h" -#include "xla/types.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Returns an m x n matrix with 1s on the diagonal elements, zeros everywhere -// else. -XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64_t m, - int64_t n); - -// Returns a mask where the 'diagonal'-th diagonal is true and everything else -// is false. -XlaOp GetDiagonalMask(XlaOp x, int diagonal = 0); - -// Get the diagonals of the last two dimensions. Use k>0 for diagonals above the -// main diagonal, and k<0 for diagonals below the main diagonal. -// -// If 'x' has shape [..., M, N] -// If k >= 0: then the output has shape [..., min(M, N - k)], containing the -// diagonal elements (i.e., with indices [..., i, i + k]). -// If k < 0: then the output has shape [..., min(M + k, N)], containing the -// diagonal elements (i.e., with indices [..., i - k, i]). -XlaOp GetMatrixDiagonal(XlaOp x, int k = 0); -XlaOp GetMatrixDiagonalViaGather(XlaOp x, int k = 0); - -// Places diag along the kth diagonal of target. -XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k = 0); - -// Returns a lower-triangular mask, i.e., true below and including the -// `diagonal`-th diagonal and false above that diagonal. -XlaOp TriangleMask(XlaOp x, int diagonal); - -// Get the upper or lower triangle part of the last two dimensions -XlaOp Triangle(XlaOp x, bool lower); - -// Get the upper triangle part of the last two dimensions -XlaOp UpperTriangle(XlaOp x); - -// Get the lower triangle part of the last two dimensions -XlaOp LowerTriangle(XlaOp x); - -// If x is an array of shape [..., n, n], symmetrizes the matrix by replacing -// the upper triangle with the transpose of the lower triangle (if lower is -// True, vice-versa otherwise). If the type of `x` is complex, makes the matrix -// Hermitian by taking the conjugate of the complex part and setting the -// complex diagonal to zero. -XlaOp Symmetrize(XlaOp x, bool lower); - -// Multiplies slices of two tensors in batches. - -// Multiplies all slices of `Tensor` `x` and `y` (each slice can be -// viewed as an element of a batch), and arranges the individual results -// in a single output tensor of the same batch size. -// -// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` -// and `[..., r_y, c_y]`. -// -// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: -// -// r_o = c_x if transpose_x else r_x -// c_o = r_y if transpose_y else c_y -// -// It is computed as: -// -// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -xla::XlaOp BatchDot( - xla::XlaOp x, xla::XlaOp y, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT, - std::optional preferred_element_type = std::nullopt); -xla::XlaOp BatchDot( - xla::XlaOp x, bool transpose_x, xla::XlaOp y, bool transpose_y, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT, - std::optional preferred_element_type = std::nullopt, - bool grad_x = false, bool grad_y = false); - -// Parse an einsum string into dimension numbers: -// "ab,cb->ac" -// becomes: -// {{0, 1},{2, 1},{0, 2}} -// -// Each occurrence of ellipsis ("...") occurring in the input is replaced with -// the same numeric dimensions. The number of such dimensions is inferred from -// x_rank and y_rank. For example: -// einsum_config: "...ab,...bcd->...acd" -// x_rank: 4 -// y_rank: 5 -// becomes: -// {{0, 1, 2, 3},{0, 1, 3, 4, 5},{0, 1, 2, 4, 5}} -// -// NOTE: This function is meant for testing, there is no need to call it -// directly. - -absl::StatusOr, 3>> ParseEinsumString( - absl::string_view einsum_config, int64_t x_rank, int64_t y_rank); - -// If an einsum config does not contain an -> one will be added and the output -// config will be the sorted characters with any ellipsis at the beginning. -// Returns an empty string if the einsum string already has an ->. -std::string NormalizeEinsumString(absl::string_view einsum_config); - -// Supports two operand einsum notation like "ab,cb->ac". -xla::XlaOp Einsum( - xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT, - std::optional preferred_element_type = std::nullopt, - bool grad_x = false, bool grad_y = false); -xla::XlaOp Einsum( - xla::XlaOp x, absl::string_view einsum_config, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); - -// Same as above but supporting numeric labels on dimensions. So "ab,cb->ac" -// becomes: -// x_config = {0, 1} -// y_config = {2, 1} -// output_config = {0, 2} -xla::XlaOp Einsum( - xla::XlaOp x, absl::Span x_config, xla::XlaOp y, - absl::Span y_config, absl::Span output_config, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT, - std::optional preferred_element_type = std::nullopt, - bool grad_x = false, bool grad_y = false); - -// Transposes a stack of matrices `x` by swapping the last two dimensions. -xla::XlaOp TransposeInMinorDims(xla::XlaOp x); - -// Transposes `x` in its minor dimensions if `transpose` is true, otherwise -// returns `x` unchanged. -xla::XlaOp MaybeTransposeInMinorDims(xla::XlaOp x, bool transpose); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/matrix.h" #endif // XLA_CLIENT_LIB_MATRIX_H_ diff --git a/third_party/xla/xla/client/lib/pooling.h b/third_party/xla/xla/client/lib/pooling.h index eb0a43029b359d..22f3d2f0b07b9c 100644 --- a/third_party/xla/xla/client/lib/pooling.h +++ b/third_party/xla/xla/client/lib/pooling.h @@ -16,68 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_POOLING_H_ #define XLA_CLIENT_LIB_POOLING_H_ -#include -#include - -#include "absl/container/inlined_vector.h" -#include "absl/types/span.h" -#include "xla/client/padding.h" -#include "xla/client/xla_builder.h" - -namespace xla { - -// Tensor format for reduce window operations. -class TensorFormat { - public: - TensorFormat(int batch_dimension, int feature_dimension, - absl::Span spatial_dimensions) - : batch_dimension_(batch_dimension), - feature_dimension_(feature_dimension), - spatial_dimensions_(spatial_dimensions.begin(), - spatial_dimensions.end()) {} - - int batch_dimension() const { return batch_dimension_; } - - int feature_dimension() const { return feature_dimension_; } - - int spatial_dimension(int dim) const { return spatial_dimensions_[dim]; } - - int num_spatial_dims() const { return spatial_dimensions_.size(); } - - private: - // The number of the dimension that represents the batch. - int batch_dimension_; - // The number of the dimension that represents the features. - int feature_dimension_; - // The dimension numbers for the spatial dimensions. - absl::InlinedVector spatial_dimensions_; -}; - -// Computes the max pool of 'operand'. -XlaOp MaxPool(XlaOp operand, absl::Span kernel_size, - absl::Span stride, Padding padding, - const TensorFormat& data_format); - -// Computes the average pool of 'operand'. -XlaOp AvgPool(XlaOp operand, absl::Span kernel_size, - absl::Span stride, - absl::Span> padding, - const TensorFormat& data_format, bool counts_include_padding); - -// Returns the list of low and high padding elements in each spatial dimension -// for the given 'padding' specification. -std::vector> MakeSpatialPadding( - absl::Span input_size, absl::Span kernel_size, - absl::Span stride, Padding padding, - const TensorFormat& data_format); - -// Computes the average pool gradient. -XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span gradients_size, - absl::Span kernel_size, - absl::Span stride, - absl::Span> spatial_padding, - const TensorFormat& data_format, bool counts_include_padding); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/pooling.h" #endif // XLA_CLIENT_LIB_POOLING_H_ diff --git a/third_party/xla/xla/client/lib/prng.h b/third_party/xla/xla/client/lib/prng.h index ef78a881c19460..0c9e460ba10cbb 100644 --- a/third_party/xla/xla/client/lib/prng.h +++ b/third_party/xla/xla/client/lib/prng.h @@ -16,86 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_PRNG_H_ #define XLA_CLIENT_LIB_PRNG_H_ -#include -#include -#include - -#include "absl/types/span.h" -#include "xla/client/xla_builder.h" -#include "xla/shape.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Records the bits and state generated by a random number generator. -struct RngOutput { - XlaOp value; - XlaOp state; -}; - -// A BitGenerator returns random bits and updated random bit generator state. -// -// key: is a value input to a random number generator that can affect the -// sequence of number it will generate. A random number generator constructs -// its seed using the key and the initial state. The tf2xla bridge passes the -// seed operand of a tensorflow random operation as a key to the random bit -// generator, for example. -// initial_state: initial_state is the initial state of the current random -// number generation. It could be 0 for a stateless random operation, and -// the returned state from a previous execution for a stateful random -// operation. -// shape: the shape of the random bits. -using BitGeneratorTy = std::function; - -// Implements the ThreeFry counter-based PRNG algorithm. -// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. -// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf -RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state, - const xla::Shape& shape); - -// Implements the Philox algorithm to generate random numbers in parallel. -// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. -// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf -// -// The paper presents a few variants of the Philox algorithm, we picked the -// 4x32_10 version of the algorithm for the following reasons: -// . 4x32 uses 32-bit multiplication which is fast on GPUs. -// . The authors recommend the 10-round variant, and TensorFlow also uses it. -RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state, - const Shape& shape); -// Returns a scrambled pair of (state, key) from a single key. -std::pair ScramblePhiloxKey(XlaOp key); - -// Uses the given bit generator to generate random bits and then converts the -// random bits to random numbers of uniform distribution in the given range. -// Returns the random numbers and the state of the random number generator. -// This function is for shape with floating point element types. -RngOutput UniformFloatingPointDistribution(XlaOp key, XlaOp initial_state, - BitGeneratorTy bit_generator, - XlaOp minval, XlaOp maxval, - const xla::Shape& shape); - -// Similar to UniformFloatingPointDistribution but for shape with integer -// element types. -RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state, - BitGeneratorTy bit_generator, XlaOp minval, - XlaOp maxval, const xla::Shape& shape); - -// Uses the given bit generator to generate random bits and then converts the -// random bits to random numbers of normal distribution. -// Returns the random numbers and the state of the random number generator. -RngOutput NormalFloatingPointDistribution(XlaOp key, XlaOp initial_state, - BitGeneratorTy bit_generator, - const xla::Shape& shape); - -// Concatenates scalars into a vector. -xla::XlaOp ConcatScalars(xla::XlaBuilder* builder, - absl::Span scalars); - -// Increases Philox counter (an uint128_t) by a delta (an uint64_t). -xla::XlaOp PhiloxIncreaseCounter(xla::XlaOp counter, xla::XlaOp delta); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/prng.h" #endif // XLA_CLIENT_LIB_PRNG_H_ diff --git a/third_party/xla/xla/client/lib/qr.h b/third_party/xla/xla/client/lib/qr.h index ce51ab342bb39b..743b36503b6175 100644 --- a/third_party/xla/xla/client/lib/qr.h +++ b/third_party/xla/xla/client/lib/qr.h @@ -16,37 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_QR_H_ #define XLA_CLIENT_LIB_QR_H_ -#include "xla/client/xla_builder.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Computes the QR decompositions of a batch of matrices. That is, -// given a (batched) matrix a, computes an orthonormal matrix Q and an -// upper-triangular matrix R such that a = QR. -// `a` must be a (batched) matrix of size [..., m, n]. -struct QrDecomposition { - // A matrix with the same shape as the input matrix `a`, whose upper triangle - // (inclusive of the diagonal) is the matrix R, and whose lower triangle - // (exclusive of the diagonal) contains the elementary Householder reflectors. - // This is the same output format as used by LAPACK's xGEQRF routine. - XlaOp q_and_r; - // A vector of shape [..., min(m, n)] containing the scalar factors of the - // elementary Householder reflectors. - XlaOp taus; -}; - -QrDecomposition Qr(XlaOp a); - -// Given `a` and `taus` as returned by `QRDecomposition`, compute the product of -// the elementary Householder reflectors (i.e., the matrix Q of the QR -// decomposition). The equivalent LAPACK routine is xORGQR/xUNGQR. -XlaOp ProductOfElementaryHouseholderReflectors(XlaOp a, XlaOp taus); - -// Helper that combines `Qr` and `ProductOfElementaryHouseholderReflectors` to -// compute explicit matrices `q` and `r`. -void QrExplicit(XlaOp a, bool full_matrices, XlaOp& q, XlaOp& r); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/qr.h" #endif // XLA_CLIENT_LIB_QR_H_ diff --git a/third_party/xla/xla/client/lib/quantize.h b/third_party/xla/xla/client/lib/quantize.h index f9835c42642d32..459716b36b54db 100644 --- a/third_party/xla/xla/client/lib/quantize.h +++ b/third_party/xla/xla/client/lib/quantize.h @@ -16,169 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_QUANTIZE_H_ #define XLA_CLIENT_LIB_QUANTIZE_H_ -#include -#include -#include -#include - -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" -#include "xla/types.h" -#include "xla/util.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/bfloat16.h" - -namespace xla { - -// Represents the range used for quantization -struct QuantizedRange { - QuantizedRange() = default; - QuantizedRange(float min_in, float max_in) : min(min_in), max(max_in) {} - - bool operator==(const QuantizedRange& rhs) const { - return this->min == rhs.min && this->max == rhs.max; - } - - bool operator!=(const QuantizedRange& rhs) const { return !(*this == rhs); } - - tsl::bfloat16 min = tsl::bfloat16(0.0f); - tsl::bfloat16 max = tsl::bfloat16(0.0f); -}; - -template -inline std::vector PackToUint32(absl::Span input) { - const int64_t kElementsPerPack = sizeof(uint32_t) / sizeof(T); - const int64_t input_size = input.size(); - const int64_t output_size = CeilOfRatio(input_size, kElementsPerPack); - - std::vector output_vec; - constexpr int64_t kShiftBits = sizeof(T) / sizeof(uint8_t) * CHAR_BIT; - - for (int64_t i = 0; i < output_size; i++) { - uint32_t result = 0; - for (int64_t p = 0; p < kElementsPerPack; p++) { - int64_t index = i * kElementsPerPack + p; - if (index < input_size) { - int64_t total_shift_bits = kShiftBits * (kElementsPerPack - p - 1); - result |= (input[index] << total_shift_bits); - } - } - output_vec.push_back(result); - } - - return output_vec; -} - -// Dequantize the quantized input of packed uint32_t to bfloat16. -// Only uint8_t or uint16_t is supported for the original unpacked input. -// Returns a tensor of shape [d0,..., dn * unpack_size] if -// input shape is [d0, ..., dn], where unpack_size = sizeof(unit32) / sizeof(T). -// If transpose_output is true, will return a tensor of shape -// [dn * unpack_size, dn-1, ..., d1, d0]. transpose_output is faster when -// input's rank higher than 1. The input needs to be transposed to use -// transpose_output feature. -template -inline XlaOp Dequantize(XlaOp input, const QuantizedRange& range, - absl::string_view mode_string = "MIN_COMBINED", - bool transpose_output = false) { - XlaBuilder* const builder = input.builder(); - return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { - float half_range = - !std::is_signed::value - ? 0.0f - : (static_cast(std::numeric_limits::max()) - - std::numeric_limits::min() + 1) / - 2.0f; - const int64_t unpack_size = sizeof(uint32_t) / sizeof(T); - TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(input)); - - auto element_type = shape.element_type(); - if (element_type != U32) { - return InvalidArgument( - "Only U32 is supported for input type of xla::Dequantize Op."); - } - - // Broadcast the input to [unpack_size, d0, ..., dn] if input size is - // [d0, ..., dn]. - auto broadcast_input = Broadcast(input, {unpack_size}); - - XlaOp iota_r1 = Iota(builder, U32, unpack_size); - // Highest significant bytes needs to shift more bytes than lower - // significant bytes. - XlaOp shift_bytes = - xla::ConstantR0(builder, unpack_size - 1) - iota_r1; - - const int bytes_of_type = sizeof(T) / sizeof(uint8_t); - std::vector shift_vec(unpack_size, CHAR_BIT * bytes_of_type); - XlaOp shift_bits = - shift_bytes * xla::ConstantR1(builder, shift_vec); - - // Make bit_mask for different data type T. - uint32_t bit_mask = 0x00000000; - for (int i = 0; i < bytes_of_type; i++) { - bit_mask <<= CHAR_BIT; - bit_mask |= 0x000000ff; - } - - std::vector shift_transpose_dimensions(shape.dimensions_size()); - std::iota(shift_transpose_dimensions.begin(), - shift_transpose_dimensions.end(), 0); - shift_transpose_dimensions.insert(shift_transpose_dimensions.begin(), 1, - shape.dimensions_size()); - - // Shift the input by sizeof(T) bytes and apply bit_mask to unpack. - XlaOp shifted_input = ShiftRightLogical( - broadcast_input, Transpose(Broadcast(shift_bits, shape.dimensions()), - shift_transpose_dimensions)); - XlaOp unpack_input = - And(shifted_input, xla::ConstantR0(builder, bit_mask)); - - XlaOp result; - - if (mode_string == "MIN_COMBINED") { - const tsl::bfloat16 scale_factor = - (range.max - range.min) / - (static_cast(std::numeric_limits::max() - - std::numeric_limits::min())); - // result = bfloat16(input + half_range) * scale_factor + range.min - XlaOp unpack_input_bf16 = ConvertElementType(unpack_input, BF16); - XlaOp half_range_bf16 = xla::ConstantR0( - builder, static_cast(half_range)); - XlaOp sum = unpack_input_bf16 + half_range_bf16; - - result = sum * xla::ConstantR0(builder, scale_factor) + - xla::ConstantR0(builder, range.min); - } else { - // TODO(wangtao): support other modes. - return InvalidArgument( - "Only MIN_COMBINED mode is supported in xla::Dequantize Op."); - } - - std::vector transpose_dimensions(shape.dimensions_size()); - std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 1); - std::reverse(transpose_dimensions.begin(), transpose_dimensions.end()); - transpose_dimensions.insert(transpose_dimensions.begin() + 1, 1, 0); - - // Transpose the result to be [dn, unpack_size, dn-1, ..., d1, d0]. - XlaOp transposed_result = Transpose(result, transpose_dimensions); - - // Reshape to be [dn * unpack_size, dn-1, ..., d1, d0]. - XlaOp reshaped_result = Collapse(transposed_result, {0, 1}); - - // Return the transpose result if transpose_output is true. - if (transpose_output) { - return reshaped_result; - } - - // Transpose the result to be [d0, d1, ..., dn-1, dn * unpack_size]. - std::vector result_dimensions(shape.dimensions_size()); - std::iota(result_dimensions.begin(), result_dimensions.end(), 0); - std::reverse(result_dimensions.begin(), result_dimensions.end()); - - return Transpose(reshaped_result, result_dimensions); - }); -} - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/quantize.h" #endif // XLA_CLIENT_LIB_QUANTIZE_H_ diff --git a/third_party/xla/xla/client/lib/self_adjoint_eig.h b/third_party/xla/xla/client/lib/self_adjoint_eig.h index f375f192e71f0e..ae81dbc0baf5a0 100644 --- a/third_party/xla/xla/client/lib/self_adjoint_eig.h +++ b/third_party/xla/xla/client/lib/self_adjoint_eig.h @@ -16,26 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_ #define XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_ -#include "xla/client/xla_builder.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// The eigenvalue decomposition of a symmetric matrix, the original matrix is -// recovered by v * w * v_t. -struct SelfAdjointEigResult { - // The i-th column is the normalized eigenvector corresponding to the - // eigenvalue w[i]. Will return a matrix object if a is a matrix object. - XlaOp v; - // The eigenvalues in ascending order, each repeated according to its - // multiplicity. - XlaOp w; -}; - -SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower = true, - int64_t max_iter = 15, float tol = 1e-5, - bool sort_eigenvalues = true); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/self_adjoint_eig.h" #endif // XLA_CLIENT_LIB_SELF_ADJOINT_EIG_H_ diff --git a/third_party/xla/xla/client/lib/slicing.h b/third_party/xla/xla/client/lib/slicing.h index 329f299e40a896..c2ea243ae2c937 100644 --- a/third_party/xla/xla/client/lib/slicing.h +++ b/third_party/xla/xla/client/lib/slicing.h @@ -13,71 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - -#include "absl/types/span.h" -#include "xla/client/xla_builder.h" -#include "xla/types.h" - #ifndef XLA_CLIENT_LIB_SLICING_H_ #define XLA_CLIENT_LIB_SLICING_H_ -namespace xla { - -// Updates a slice of 'x', i.e., -// x[start[0], ..., start[n]] = update -XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start); - -// Performs a slice in the minor dimensions of a tensor. -// x[..., start[0]:end[0], ..., start[n]:end[n]] -XlaOp SliceInMinorDims(XlaOp x, absl::Span start, - absl::Span end); - -// Updates a slice of 'x', where 'start' contains a list of minor dimensions: -// x[..., start[0]:..., ..., start[n]:...] = update -XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, - absl::Span start); - -// Performs a dynamic slice in the minor dimensions of a tensor. -XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, - absl::Span sizes); - -XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, - absl::Span starts); - -// Gathers values along an axis specified by dim. -// -// For a 3-D tensor the output is specified by: -// -// out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 -// out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 -// out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 -// -// If `input` is an n-dimensional tensor with size -// [X0,X1,X2,..XN] and dim = i `index` must be an n-dimensional tensor with size -// [X0,X1,...Y,Xi+1,...,X[N] where y >= 1 and `out` will have the same sizes as -// `index`. -XlaOp TorchGather(XlaOp input, XlaOp index, int64_t dim, bool sparse = true); - -// idx = index[i][j][k] -// output[idx][j][k] = combiner(input[idx][j][k], src[i][j][k]) # if dim == 0 -// output[i][idx][k] = combiner(input[i][idx][k], src[i][j][k]) # if dim == 1 -// output[i][j][idx] = combiner(input[i][j][idx], src[i][j][k]) # if dim == 2 -XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64_t dim, - const std::function& combiner); - -// Returns a new tensor which indexes the input tensor along dimension dim using -// the entries in index. -// -// The returned tensor has the same number of dimensions as the original tensor -// (input). The dimth dimension has the same size as the length of index; other -// dimensions have the same size as in the original tensor. -// -// This operation supports 0 or more major batch dimensions that act like a -// multidimensional loop over both the input and the index. -XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64_t dim, - int64_t batch_dims = 0); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/slicing.h" #endif // XLA_CLIENT_LIB_SLICING_H_ diff --git a/third_party/xla/xla/client/lib/sorting.h b/third_party/xla/xla/client/lib/sorting.h index 4af4f8caaf977e..5cb81a43c11f36 100644 --- a/third_party/xla/xla/client/lib/sorting.h +++ b/third_party/xla/xla/client/lib/sorting.h @@ -16,23 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_SORTING_H_ #define XLA_CLIENT_LIB_SORTING_H_ -#include "xla/client/xla_builder.h" -#include "xla/types.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Returns a tuple composed of the top `k` values and corresponding indices in -// `input`. Output values are in descending order, from largest to smallest. -XlaOp TopK(XlaOp input, int64_t k, - PrimitiveType index_type = PrimitiveType::S32); - -// Split sort in TopK into smaller sorts. -// Returns a tuple composed of the top `k` values and corresponding indices in -// `input`. Output values are in descending order, from largest to smallest. -XlaOp TopKWithPartitions(XlaOp input, int64_t k, int64_t num_partitions = 1, - PrimitiveType index_type = PrimitiveType::S32); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/sorting.h" #endif // XLA_CLIENT_LIB_SORTING_H_ diff --git a/third_party/xla/xla/client/lib/svd.h b/third_party/xla/xla/client/lib/svd.h index 07f361f73b3a3f..54893697c5fced 100644 --- a/third_party/xla/xla/client/lib/svd.h +++ b/third_party/xla/xla/client/lib/svd.h @@ -16,34 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_SVD_H_ #define XLA_CLIENT_LIB_SVD_H_ -#include "xla/client/xla_builder.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// The singular value decomposition of a given matrix A[..., M, N], the original -// matrix is recovered by u * diag(d) * v_t, where the first dims(A) - 2 -// dimensions are batch dimensions. -struct SVDResult { - // The columns of U are the left-singular vectors, e.g., - // U[..., :, :]_T * U[..., :, :] = I. - XlaOp u; - // Vector(s) with the singular values, within each vector sorted in descending - // order. The first dims(D) - 1 dimensions have the same size as the batch - // dimensions of A. And U[..., :, i] * D[..., i] = A[..., :, :] * V[..., :, - // i]. - XlaOp d; - // The columns of V are the right-singular vectors. e.g., - // V[..., :, :]_T * V[..., :, :] = I. - XlaOp v; -}; - -// TODO(kuny): Add a bool flag that supports SVD with economy (reduced) -// representation, which is more memory efficient, especially in the case of -// tall-skinny matrices. -SVDResult SVD(XlaOp a, int64_t max_iter = 100, float epsilon = 1e-6, - PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/svd.h" #endif // XLA_CLIENT_LIB_SVD_H_ diff --git a/third_party/xla/xla/client/lib/testing.cc b/third_party/xla/xla/client/lib/testing.cc index dfda52163ebf1f..f8fb61b79c0c84 100644 --- a/third_party/xla/xla/client/lib/testing.cc +++ b/third_party/xla/xla/client/lib/testing.cc @@ -22,15 +22,14 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "xla/client/client.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" #include "xla/execution_options_util.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/service.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/tests/test_utils.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/client/lib/testing.h b/third_party/xla/xla/client/lib/testing.h index 76e268b36ebb97..a9b566c6635b3f 100644 --- a/third_party/xla/xla/client/lib/testing.h +++ b/third_party/xla/xla/client/lib/testing.h @@ -21,7 +21,7 @@ limitations under the License. #include "xla/client/client.h" #include "xla/client/global_data.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/service/service.h" #include "xla/shape.h" #include "xla/xla.pb.h" diff --git a/third_party/xla/xla/client/lib/tridiagonal.h b/third_party/xla/xla/client/lib/tridiagonal.h index b24ef6a3d4b71b..5cc51c5e98262e 100644 --- a/third_party/xla/xla/client/lib/tridiagonal.h +++ b/third_party/xla/xla/client/lib/tridiagonal.h @@ -16,28 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_TRIDIAGONAL_H_ #define XLA_CLIENT_LIB_TRIDIAGONAL_H_ -#include "absl/status/statusor.h" -#include "xla/client/xla_builder.h" -#include "xla/xla_data.pb.h" - -namespace xla { -namespace tridiagonal { - -enum SolverAlgorithm { kThomas }; - -absl::StatusOr TridiagonalSolver(SolverAlgorithm algo, - XlaOp lower_diagonal, - XlaOp main_diagonal, - XlaOp upper_diagonal, XlaOp rhs); - -absl::StatusOr TridiagonalSolver(SolverAlgorithm algo, XlaOp diagonals, - XlaOp rhs); - -absl::StatusOr TridiagonalMatMul(XlaOp upper_diagonal, - XlaOp main_diagonal, - XlaOp lower_diagonal, XlaOp rhs); - -} // namespace tridiagonal -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/tridiagonal.h" #endif // XLA_CLIENT_LIB_TRIDIAGONAL_H_ diff --git a/third_party/xla/xla/client/lib/tuple.h b/third_party/xla/xla/client/lib/tuple.h index dd8fb3c6ec82bf..c1dc9de027a50f 100644 --- a/third_party/xla/xla/client/lib/tuple.h +++ b/third_party/xla/xla/client/lib/tuple.h @@ -16,21 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_LIB_TUPLE_H_ #define XLA_CLIENT_LIB_TUPLE_H_ -#include "absl/status/statusor.h" -#include "xla/client/xla_builder.h" -#include "xla/shape_tree.h" - -namespace xla { - -// Returns a ShapeTree where each index is a GetTupleElement instruction for -// that subshape of the tuple. The root index is the original argument. -absl::StatusOr> DisassembleTuple(XlaOp tuple); - -// Assembles a tuple from a ShapeTree that contains the leaves of the tuple. -// Non-leaf elements of the ShapeTree are ignored. DisassembleTuple and -// AssembleTuple are essentially inverse operations. -XlaOp AssembleTuple(XlaBuilder* builder, ShapeTree elements); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/lib/tuple.h" #endif // XLA_CLIENT_LIB_TUPLE_H_ diff --git a/third_party/xla/xla/client/local_client.cc b/third_party/xla/xla/client/local_client.cc index 05056ba76664f9..66ef01ae2ae16d 100644 --- a/third_party/xla/xla/client/local_client.cc +++ b/third_party/xla/xla/client/local_client.cc @@ -27,9 +27,9 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "xla/client/executable_build_options.h" -#include "xla/client/xla_computation.h" #include "xla/debug_options_flags.h" #include "xla/executable_run_options.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/literal.h" #include "xla/service/backend.h" #include "xla/service/compiler.h" diff --git a/third_party/xla/xla/client/local_client.h b/third_party/xla/xla/client/local_client.h index 07c6e6e8b11978..d21881b311f963 100644 --- a/third_party/xla/xla/client/local_client.h +++ b/third_party/xla/xla/client/local_client.h @@ -26,8 +26,8 @@ limitations under the License. #include "absl/types/span.h" #include "xla/client/client.h" #include "xla/client/executable_build_options.h" -#include "xla/client/xla_computation.h" #include "xla/executable_run_options.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/literal.h" #include "xla/service/backend.h" #include "xla/service/compiler.h" diff --git a/third_party/xla/xla/client/padding.h b/third_party/xla/xla/client/padding.h index e717183ce2d6c8..a9e928d865da0e 100644 --- a/third_party/xla/xla/client/padding.h +++ b/third_party/xla/xla/client/padding.h @@ -16,52 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_PADDING_H_ #define XLA_CLIENT_PADDING_H_ -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/types/span.h" -#include "xla/types.h" - -namespace xla { - -// Describes the padding applied for a windowed operation like -// convolution, where a window is placed inside a base area. -enum class Padding { - // Make the output have the same dimensions as the base area. For - // example, for a 3x3 base area and a 2x2 window, the output will be - // 3x3, so that requires padding the 3x3 base area to 4x4. - kSame, - - // Use no padding. For example, for a 4x4 base area and a 2x2 - // window, the output will be 3x3. - kValid, -}; - -// Validates that the slices are acceptable for determining padding -- this can -// be used to check the preconditions of MakePadding below to produce an error -// message that can be returned to the user. -absl::Status ValidatePaddingValues(absl::Span input_dimensions, - absl::Span window_dimensions, - absl::Span window_strides); - -// Returns the padding needed for the base area, given the base area dimensions, -// window dimensions, strides, and the type of padding. -// -// If v is the returned vector, then for each dimension number i, -// v[i].first is the padding to the left (i.e. in the direction of -// lower indices) and v[i].second is the padding to the right (i.e. in -// the direction of higher indices). -// -// Precondition: The number of dimensions (i.e., rank) in input_dimensions, -// window_dimensions, and strides must match, which is equal to the number -// of elements in the result vector. -std::vector> MakePadding( - absl::Span input_dimensions, - absl::Span window_dimensions, - absl::Span window_strides, Padding padding); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/padding.h" #endif // XLA_CLIENT_PADDING_H_ diff --git a/third_party/xla/xla/client/sharding_builder.h b/third_party/xla/xla/client/sharding_builder.h index eef395e0b46368..995978b165f885 100644 --- a/third_party/xla/xla/client/sharding_builder.h +++ b/third_party/xla/xla/client/sharding_builder.h @@ -16,48 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_SHARDING_BUILDER_H_ #define XLA_CLIENT_SHARDING_BUILDER_H_ -#include - -#include "xla/array.h" -#include "xla/shape.h" -#include "xla/shape_tree.h" -#include "xla/shape_util.h" -#include "xla/types.h" -#include "xla/util.h" -#include "xla/xla_data.pb.h" - -namespace xla { -namespace sharding_builder { -// A shaped array used to describe the assignment of tiles to devices. -using TileAssignment = Array; - -// Creates a replicated sharding - replicate a tensor on every device. -OpSharding Replicate(); - -// Creates a manual sharding - the partitioner will not change the shape. -OpSharding Manual(); - -// Creates a sharding that assigns a tensor to just one device. -OpSharding AssignDevice(int device); - -// Creates a tiled sharding with the given tile shape and assignment of tiles -// to devices. -// -// If tile_shape is not evenly divisible by the number of devices in -// tile_assignment, operations behave as if implicit padding had been inserted. -// The value of this padding is undefined. -OpSharding Tile(const Shape& tile_shape, const TileAssignment& tile_assignment); - -// Creates a sharding in one dimension, with the given tile shape which must -// be rank 1 and using devices [0..num_tiles). -// -// This is simply a convenience wrapper for Tile(). -OpSharding Tile1D(const Shape& tile_shape, int64_t num_tiles); - -// Creates a tuple sharding from the given ShapeTree of element shardings. -OpSharding Tuple(const ShapeTree& shardings); - -} // namespace sharding_builder -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/sharding_builder.h" #endif // XLA_CLIENT_SHARDING_BUILDER_H_ diff --git a/third_party/xla/xla/client/value_inference.h b/third_party/xla/xla/client/value_inference.h index 84c1c99f53fd4d..f717cc703b2502 100644 --- a/third_party/xla/xla/client/value_inference.h +++ b/third_party/xla/xla/client/value_inference.h @@ -15,103 +15,7 @@ limitations under the License. #ifndef XLA_CLIENT_VALUE_INFERENCE_H_ #define XLA_CLIENT_VALUE_INFERENCE_H_ -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/log/check.h" -#include "absl/status/statusor.h" -#include "absl/types/span.h" -#include "xla/client/xla_builder.h" -#include "xla/hlo/evaluator/hlo_evaluator.h" -#include "xla/hlo/ir/dfs_hlo_visitor.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/literal.h" -#include "xla/literal_util.h" -#include "xla/shape_util.h" -#include "xla/util.h" -#include "xla/xla_data.pb.h" - -namespace xla { -// OptionalLiteral is an augmented literal class which returns optional -// values for each index (the value can be either valid or invalid). The -// implementation keeps two literals, a value literal, holding both the valid -// and garabage value, and a masking literal representing if a value is valid or -// garbage. -class OptionalLiteral { - public: - explicit OptionalLiteral(Literal value, Literal mask) - : value_(std::move(value)), mask_(std::move(mask)) {} - - template - std::optional Get(absl::Span element_index, - ShapeIndex shape_index = {}) const { - if (mask_.Get(element_index, shape_index)) { - return std::nullopt; - } else { - return value_.Get(element_index, shape_index); - } - } - - // Returns true if all values in this literal slice are value. - bool AllValid() { return mask_.IsAll(0); } - - // Get value out of this slice if all values are valid. Otherwise returns - // nullopt. - std::optional GetValue() { - if (!AllValid()) { - return std::nullopt; - } - return LiteralSlice(value_); - } - - private: - Literal value_; - Literal mask_; -}; - -enum ValueInferenceMode { - // Inference the constant value itself. - kValue = 0, - // Inference upper-bound and lower-bound of the value. Bounds are inclusive. - kUpperBound, - kLowerBound, -}; - -class ValueInference { - public: - // ValueInference analyzes values in XlaOp answers following questions: - // - What's the upper-bound of each value in a tensor. - // - What's the lower-bound of each value in a tensor. - // - What's the constant value of each tensor. - // - Whether or not each value in a tensor is dynamic. - explicit ValueInference(XlaBuilder* builder) : builder_(builder) { - CHECK(builder_); - } - absl::StatusOr AnalyzeIsDynamic(XlaOp op); - // Returns an OptionalLiteral. Each individual value of the literal is - // the concrete constant value if it can be inferred, otherwise a nullopt. - absl::StatusOr AnalyzeConstant(XlaOp op, - ValueInferenceMode mode); - - // Returns underlying xla builder. - XlaBuilder* builder() { return builder_; } - - private: - // Given an op handle, returns a simplified version of the handle inside a - // int64_t Literal. If the a -1 value for the handle means invalid - // simplification and the result shouldn't be used. - absl::StatusOr SimplifyOp(int64_t handle); - - // Perform CSE on a given handle, and return an equivalent handle if seen - // before. Otherwise, returns nullopt. - absl::StatusOr> CseOpHandle(int64_t handle); - XlaBuilder* builder_; - HloEvaluator evaluator_; - // A map from instruction_hash to handle that helps perform CSE. - absl::flat_hash_map cse_map_; -}; -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/value_inference.h" #endif // XLA_CLIENT_VALUE_INFERENCE_H_ diff --git a/third_party/xla/xla/client/xla_builder.h b/third_party/xla/xla/client/xla_builder.h index dd222f1d82095b..1599160a713014 100644 --- a/third_party/xla/xla/client/xla_builder.h +++ b/third_party/xla/xla/client/xla_builder.h @@ -16,3071 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_XLA_BUILDER_H_ #define XLA_CLIENT_XLA_BUILDER_H_ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/functional/function_ref.h" -#include "absl/log/check.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "xla/array.h" -#include "xla/array2d.h" -#include "xla/array3d.h" -#include "xla/array4d.h" -#include "xla/client/padding.h" -#include "xla/client/xla_computation.h" -#include "xla/comparison_util.h" -#include "xla/hlo/ir/dynamic_parameter_binding.h" -#include "xla/hlo/ir/hlo_input_output_alias_config.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/layout.h" -#include "xla/literal.h" -#include "xla/literal_util.h" -#include "xla/service/hlo.pb.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/tsl/lib/core/bitmap.h" -#include "xla/util.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/stacktrace.h" - -namespace xla { - -class XlaBuilder; -class XlaOp; -class HloInstruction; - -namespace internal { - -struct XlaBuilderFriend { - static XlaOp BuildAddDependency(XlaBuilder* builder, XlaOp operand, - XlaOp token, const Shape& shape); - - static std::pair BuildAsyncStart( - XlaBuilder* builder, absl::Span operands, - std::string execution_thread, const XlaComputation& called_computation, - const Shape& shape); - static XlaOp BuildAsyncUpdate(XlaBuilder* builder, XlaOp operands, - const Shape& shape); - static XlaOp BuildAsyncDone(XlaBuilder* builder, XlaOp operands, - const Shape& shape); - - static XlaOp BuildAllGatherStart( - XlaBuilder* builder, XlaOp operand, int64_t all_gather_dimension, - int64_t shard_count, absl::Span replica_groups = {}, - const std::optional& channel_id = std::nullopt, - const std::optional& layout = std::nullopt, - std::optional use_global_device_ids = std::nullopt); - static XlaOp BuildAllGatherDone(XlaBuilder* builder, XlaOp operands, - const Shape& shape); - - static XlaOp BuildAllReduceStart( - XlaBuilder* builder, XlaOp operand, const XlaComputation& computation, - absl::Span replica_groups = {}, - const std::optional& channel_id = std::nullopt, - const std::optional& layout = std::nullopt, - std::optional use_global_device_ids = std::nullopt); - static XlaOp BuildAllReduceDone(XlaBuilder* builder, XlaOp operands, - const Shape& shape); - - static XlaOp BuildCollectivePermuteStart( - XlaBuilder* builder, XlaOp operand, - const std::vector>& source_target_pairs, - const std::optional& channel_id = std::nullopt); - static XlaOp BuildCollectivePermuteDone(XlaBuilder* builder, XlaOp operands, - const Shape& shape); - - static XlaOp BuildCopyStart( - XlaBuilder* builder, XlaOp operand, - std::optional cross_program_prefetch_index = std::nullopt); - static XlaOp BuildCopyDone(XlaBuilder* builder, XlaOp operand, - const Shape& shape); - - static XlaOp BuildFusion( - XlaBuilder* builder, absl::Span operands, - absl::string_view fusion_kind, const XlaComputation& fused_computation, - absl::Span>> - output_operand_aliasing = {}); - - static XlaOp BuildBitcast(XlaBuilder* builder, XlaOp operand, - const Shape& shape); - - static XlaOp BuildPartitionId(XlaBuilder* builder, const Shape& shape); - - static XlaOp BuildSend(XlaBuilder* builder, XlaOp operand, XlaOp token, - const ChannelHandle& handle, bool is_host_transfer); - static XlaOp BuildSendDone(XlaBuilder* builder, XlaOp operand, - const ChannelHandle& handle, - bool is_host_transfer); - - static XlaOp BuildRecv(XlaBuilder* builder, XlaOp token, const Shape& shape, - const ChannelHandle& handle, bool is_host_transfer); - static XlaOp BuildRecvDone(XlaBuilder* builder, XlaOp token, - const Shape& shape, const ChannelHandle& handle, - bool is_host_transfer); - - static XlaOp BuildDomain(XlaBuilder* builder, XlaOp operand, OpSharding entry, - OpSharding exit, const Shape& shape); - - static XlaOp BuildRngGetAndUpdateState(XlaBuilder* builder, int64_t delta, - const Shape& shape); - - static HloInstructionProto* GetInstruction(XlaOp op); - static HloInstructionProto* GetInstructionByHandle(XlaBuilder* builder, - int64_t handle); -}; - -} // namespace internal - -// This represents an instruction that has been enqueued using the XlaBuilder. -// This is used to pass to subsequent computations that depends upon the -// instruction as an operand. -class XlaOp { - public: - XlaOp() : handle_(-1), builder_(nullptr) { - static_assert(std::is_trivially_destructible::value, - "XlaOp should be trivially destructible"); - } - ~XlaOp() = default; - - XlaOp(const XlaOp& other) = default; - XlaOp& operator=(const XlaOp& other) = default; - - // Precondition: !IsUninitialized(). - // - // It's very common to do foo.builder()->bar(). Without this precondition, if - // foo.builder() is null, the call to bar will segfault at some point possibly - // deep in the callstack when we finally dereference `this`. The precondition - // lets us avoid this tricky-to-debug problem. - XlaBuilder* builder() const { - CHECK(builder_ != nullptr); - return builder_; - } - - // Returns true if the XlaOp represents valid, non-erroneous value. - bool valid() const { return handle_ >= 0; } - - // Returns true if the XlaOp was created by the XlaOp() constructor and - // not returned by a builder. - bool IsUninitialized() const { return builder_ == nullptr; } - - bool IsIdenticalTo(XlaOp rhs) const { - return handle_ == rhs.handle_ && builder_ == rhs.builder_; - } - - friend std::ostream& operator<<(std::ostream& out, XlaOp op) { - out << op.handle(); - return out; - } - - private: - explicit XlaOp(XlaBuilder* builder) : handle_(-1), builder_(builder) {} - XlaOp(int64_t handle, XlaBuilder* builder) - : handle_(handle), builder_(builder) {} - - int64_t handle() const { return handle_; } - - friend class XlaBuilder; - friend class ValueInference; - friend struct internal::XlaBuilderFriend; - - // < 0 means "invalid handle". - int64_t handle_; - - // Not owned. Non-null for any handle returned by XlaBuilder, even if the - // handle is invalid. - XlaBuilder* builder_; -}; - -// Arithmetic operator overloads for the XlaOp type. -XlaOp operator-(XlaOp x); -XlaOp operator+(XlaOp x, XlaOp y); -XlaOp operator-(XlaOp x, XlaOp y); -XlaOp operator*(XlaOp x, XlaOp y); -XlaOp operator/(XlaOp x, XlaOp y); -XlaOp operator%(XlaOp x, XlaOp y); - -// Bitwise operator overloads for the XlaOp type. -XlaOp operator~(XlaOp x); -XlaOp operator&(XlaOp x, XlaOp y); -XlaOp operator|(XlaOp x, XlaOp y); -XlaOp operator^(XlaOp x, XlaOp y); -XlaOp operator<<(XlaOp x, XlaOp y); -// Performs a right arithmetic shift if 'x' is a signed type, otherwise performs -// a right logical shift. -XlaOp operator>>(XlaOp x, XlaOp y); - -// We don't overload the relational operators (==, !=, <, <=, >, >=) because the -// semantics might be surprising since their result types are usually 'bool'. -// Further programmers may expect == to be a structural equality. -// We also choose not to overload any of the mutating operators (e.g., +=, -=) -// because the semantics might be misleading — XLA computations are immutable. - -// A convenient interface for building up computations. -// -// Thread-compatible. -class XlaBuilder { - public: - // computation_name: name to use for the built computation. - explicit XlaBuilder(const std::string& computation_name); - - XlaBuilder(const XlaBuilder&) = delete; - XlaBuilder& operator=(const XlaBuilder&) = delete; - - virtual ~XlaBuilder(); - - // Returns the computation name. - const std::string& name() const { return name_; } - - // Sets OpMetadata that will be added to all instructions until cleared. - // - // OpMetadata is often applied to a series of XLA HLO instructions. As a - // result, OpMetadata is set on the computation builder. All subsequent - // instructions generated via this computation builder will have the same - // OpMetadata attached until a call to ClearOpMetadata. - void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); } - - // Swaps the passed op metadata with the ones currently set. - // - // Returns the old op metadata. - OpMetadata SwapOpMetadata(OpMetadata metadata) { - OpMetadata old_metadata = std::move(metadata_); - metadata_ = std::move(metadata); - return old_metadata; - } - - // Similar to SetOpMetadata, but only set the metadata for the next op. - void SetOneShotOpMetadata(OpMetadata metadata) { - one_shot_metadata_ = std::move(metadata); - } - - // Clears the HloMetadata state. - void ClearOpMetadata() { metadata_.Clear(); } - - // Sets an OpSharding that will be attached to all instructions until cleared. - void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } - - // Sets the FrontendAttributes that will be added to all instructions until - // cleared. - // - // FrontendAttributes are often applied to a series of XLA HLO instructions. - // As a result they are set on the computation builder and all the - // instructions generated via the computation builder will have the same - // frontend attributes attached to them. - virtual void SetFrontendAttributes( - const FrontendAttributes& frontend_attributes) { - frontend_attributes_ = frontend_attributes; - } - - // Swap the passed FrontendAttributes with the ones currently set. - // - // Return the old attributes. - FrontendAttributes SwapFrontendAttributes( - const FrontendAttributes& frontend_attributes) { - FrontendAttributes old_attributes = std::move(frontend_attributes_); - frontend_attributes_ = frontend_attributes; - return old_attributes; - } - - // Returns the FrontendAttributes that will be attached to all instructions. - const FrontendAttributes& frontend_attributes() const { - return frontend_attributes_; - } - - // Clears all the frontend attributes. - void ClearFrontendAttributes() { frontend_attributes_.Clear(); } - - // Clears the sharding. Ops will be sharded according to the default placement - // policy. - void ClearSharding() { sharding_ = std::nullopt; } - - // Returns the OpSharding that will be attached to all instructions. - const std::optional& sharding() const { return sharding_; } - - // Sets the builder to a mode where it will die immediately when an error is - // encountered, rather than producing it in a deferred fashion when Build() is - // called (which is the default). - void set_die_immediately_on_error(bool enabled) { - die_immediately_on_error_ = enabled; - } - - // Default dimension numbers used for a 2D convolution. - static constexpr int64_t kConvBatchDimension = 0; - static constexpr int64_t kConvFeatureDimension = 1; - static constexpr int64_t kConvFirstSpatialDimension = 2; - static constexpr int64_t kConvSecondSpatialDimension = 3; - static constexpr int64_t kConvKernelOutputDimension = 0; - static constexpr int64_t kConvKernelInputDimension = 1; - static constexpr int64_t kConvKernelFirstSpatialDimension = 2; - static constexpr int64_t kConvKernelSecondSpatialDimension = 3; - - // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for - // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for - // the kernel operand - // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. - static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( - int num_spatial_dims = 2); - - // Returns an error if the convolution dimension numbers have conflicts. - static absl::Status Validate(const ConvolutionDimensionNumbers& dnum); - - // Returns a new XlaBuilder whose resultant Computation is used only by this - // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error - // behavior as the parent. - std::unique_ptr CreateSubBuilder( - const std::string& computation_name); - - // Builds the computation with the requested operations, or returns a non-ok - // status. Note that all ops that have been enqueued will be moved to the - // computation being returned. The root of the computation will be the last - // added operation. - // - // `remove_dynamic_dimensions` tells the builder whether to remove the - // dynamic dimensions information in all ops. - // - // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keeps the - // dynamic dimensions information when XLA backend can handle dynamic - // dimensions. - absl::StatusOr Build(bool remove_dynamic_dimensions = false); - - // Overload of Build which specifies a particular root instruction for the - // computation. - absl::StatusOr Build(XlaOp root, - bool remove_dynamic_dimensions = false); - - // Builds the computation with the requested operations, or notes an error in - // the parent XlaBuilder and returns an empty computation if building failed. - // This function is intended to be used where the returned XlaComputation is - // only used by the parent XlaBuilder and hence further operation on the - // returned XlaComputation will simply be error'ed out if an error occurred - // while building this computation. If the built computation is to be used by - // a XlaBuilder other than the parent XlaBuilder then Build() should be used - // instead. - XlaComputation BuildAndNoteError(); - - // Returns a subgraph that roots on the given root. If the root is not a - // compile-time constant (see `IsConstant`), returns an error. - // - // This will copy the needed ops/computations to the subgraph. - absl::StatusOr BuildConstantSubGraph( - XlaOp root_op, bool dynamic_dimension_is_minus_one = false); - - // Returns the first error that was encountered while building the - // computation. When an error is encountered, by default we return a vacuous - // XlaOp and inform the user of the error that occurred while - // building the computation when they make a final call to Build(). - // - // See also set_die_immediately_on_error(). - absl::Status first_error() const { return first_error_; } - - // Returns the current status of the builder, complete with the stack trace - // information. - absl::Status GetCurrentStatus() const; - - // Returns the shape of the given op. - absl::StatusOr GetShape(XlaOp op) const; - - // Returns the shape of the given op. - virtual absl::StatusOr GetShapePtr(XlaOp op) const; - - // Returns the OpSharding of the given op. If "op" has no sharding, return - // std::nullopt. - absl::StatusOr> GetOpSharding(XlaOp op) const; - - // Returns the (inferred) result for the current computation's shape. This - // assumes the root instruction is the last added instruction. - absl::StatusOr GetProgramShape() const; - - // Returns the (inferred) result for the current computation's shape using the - // given operation as the root. - absl::StatusOr GetProgramShape(XlaOp root) const; - - // Reports an error to the builder, by - // * storing it internally and capturing a backtrace if it's the first error - // (this deferred value will be produced on the call to - // Build()/GetShape()/...) - // * dying if die_immediately_on_error_ is true. - // Returns an XlaOp with an invalid handle but a valid builder. This value can - // be returned in place of a value in APIs that return an XlaOp. - XlaOp ReportError(const absl::Status& error); - - // A helper function that converts a absl::StatusOr into an XlaOp. - // If the absl::Status was an error, reports the error to builder and returns - // an invalid XlaOp handle. - XlaOp ReportErrorOrReturn(const absl::StatusOr& op); - - // A helper function that runs a function that returns a absl::StatusOr - // and returns an XlaOp. - XlaOp ReportErrorOrReturn( - absl::FunctionRef()> op_creator); - - // Returns true if 'operand' is a compile-time constant. A compile-time - // constant does not depend on any parameters, or on stateful operators such - // as `RngNormal` or `Infeed`. - // - // This tests whether a computation is a compile-time constant without - // evaluating the computation. - absl::StatusOr IsConstant(XlaOp operand) const; - - // Adds a new input/output alias. Since the input/output shape information are - // not available until the computation is built, any eventual error in the - // arguments of this API will be detected only at computation Build() time. - // - // Note: Except when 'must-alias' is true, alias is assumed to be 'may-alias' - // and only donated buffer at runtime will be aliased with output. If a buffer - // is not donated at runtime, a copy will be inserted by XLA to prevent buffer - // clobbering. - void SetUpAlias(const ShapeIndex& output_index, int64_t param_number, - const ShapeIndex& param_index, - HloInputOutputAliasConfig::AliasKind kind = - HloInputOutputAliasConfig::AliasKind::kMayAlias) { - input_output_aliases_.push_back( - {output_index, param_number, param_index, kind}); - } - - // Describes an input/output alias as inserted by the SetUpAlias() API. - struct InputOutputAlias { - // Specifies the index of the aliased buffer in the result tuple. - ShapeIndex output_index; - // Specifies the parameter containing the buffer to be aliased. - int64_t param_number; - // Specifies the index of the aliased buffer in the parameter. - ShapeIndex param_index; - // Specifies if the alias is a must alias or may alias. - HloInputOutputAliasConfig::AliasKind kind; - }; - - // Adds a new buffer donor. The donated buffer may be paired with any valid - // output. On the contrary, the buffer aliasing bonds the input output pair. - // The input can only donate the buffer to the paired output. - void AddBufferDonor(int64_t param_number, const ShapeIndex& param_index) { - buffer_donors_.insert({param_number, param_index}); - } - - // Looks up the HloInstruction and sets the frontend attribute "attribute" to - // "value". If the attribute already existed, then its value is updated. - // - // The attribute is only added to the HloInstruction, not to the builder. - absl::Status SetInstructionFrontendAttribute(XlaOp op, std::string attribute, - std::string value); - - // Looks up the HloInstruction and sets the sharding. If the sharding already - // existed, then its value is updated. - // - // The sharding is only added to the HloInstruction, not to the builder. - absl::Status SetInstructionSharding( - XlaOp op, const std::optional& sharding); - - // Returns shapes for the operands. - absl::StatusOr> GetOperandShapes( - absl::Span operands) const; - - // Converts the op to string for the ease of debugging. - std::string OpToString(XlaOp op) const; - - private: - void ToStringHelper(std::string* out, int ident, int64_t op_handle) const; - - // Build helper which takes the id of the root operation.. - absl::StatusOr Build(int64_t root_id, - bool remove_dynamic_dimensions); - - // Description for the methods below can be found in the corresponding public - // functions section in this file. - - XlaOp Parameter(int64_t parameter_number, const Shape& shape, - const std::string& name, - const std::vector& replicated_at_leaf_buffers); - XlaOp Parameter(int64_t parameter_number, const Shape& shape, - const std::string& name) { - std::vector empty_bools; - return Parameter(parameter_number, shape, name, empty_bools); - } - - virtual XlaOp ConstantLiteral(const LiteralSlice& literal); - - XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes); - - XlaOp BroadcastInDim(XlaOp operand, absl::Span out_dim_size, - absl::Span broadcast_dimensions); - - // This is an experimental API for creating the mhlo.dynamic_broadcast_in_dim - // op from the XlaBuilder. This is only intended for export to MHLO or - // StableHLO, and cannot be compiled. Only static output_dimensions are - // allowed, and broadcast_dimensions is verified. - XlaOp MhloDynamicBroadcastInDim( - XlaOp operand, XlaOp output_dimensions, - absl::Span broadcast_dimensions, - const Shape& output_shape); - - XlaOp Pad(XlaOp operand, XlaOp padding_value, - const PaddingConfig& padding_config); - XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno, - int64_t pad_lo, int64_t pad_hi); - - virtual absl::StatusOr PadInternal( - const Shape& shape, XlaOp operand, XlaOp padding_value, - const PaddingConfig& padding_config); - - XlaOp Reshape(XlaOp operand, absl::Span dimensions, - absl::Span new_sizes, - int64_t inferred_dimension = -1); - - XlaOp Reshape(XlaOp operand, absl::Span new_sizes, - int64_t inferred_dimension = -1); - - XlaOp Reshape(const Shape& shape, XlaOp operand, - int64_t inferred_dimension = -1); - - XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, - absl::Span new_size_bounds, - const std::vector& dims_are_dynamic); - - XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, - const Shape& shape); - - XlaOp Collapse(XlaOp operand, absl::Span dimensions); - - XlaOp Slice(XlaOp operand, absl::Span start_indices, - absl::Span limit_indices, - absl::Span strides); - virtual absl::StatusOr SliceInternal( - const Shape& shape, XlaOp operand, - absl::Span start_indices, - absl::Span limit_indices, - absl::Span strides); - virtual XlaOp SliceInDim(XlaOp operand, int64_t start_index, - int64_t limit_index, int64_t stride, int64_t dimno); - - XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, - absl::Span slice_sizes); - virtual absl::StatusOr DynamicSliceInternal( - const Shape& shape, XlaOp operand, absl::Span start_indices, - absl::Span slice_sizes); - - XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, - absl::Span start_indices); - virtual absl::StatusOr DynamicUpdateSliceInternal( - const Shape& shape, XlaOp operand, XlaOp update, - absl::Span start_indices); - - XlaOp ConcatInDim(absl::Span operands, int64_t dimension); - virtual absl::StatusOr ConcatInDimInternal( - const Shape& shape, absl::Span operands, int64_t dimension); - - XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false); - - XlaOp Tuple(absl::Span elements); - virtual absl::StatusOr TupleInternal(const Shape& shape, - absl::Span elements); - - XlaOp GetTupleElement(XlaOp tuple_data, int64_t index); - virtual absl::StatusOr GetTupleElementInternal(const Shape& shape, - XlaOp tuple_data, - int64_t index); - - XlaOp Dot(XlaOp lhs, XlaOp rhs, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - - XlaOp DotGeneral( - XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - - XlaOp SparseDot( - XlaOp lhs, XlaOp rhs, absl::Span sparse_meta, - absl::Span sparsity, - const DotDimensionNumbers& dimension_numbers, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - - XlaOp Conv( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - Padding padding, int64_t feature_group_count = 1, - int64_t batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - - XlaOp ConvWithGeneralPadding( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - int64_t feature_group_count = 1, int64_t batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - - XlaOp ConvWithGeneralDimensions( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count = 1, int64_t batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - - XlaOp ConvGeneral( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count = 1, int64_t batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - - XlaOp ConvGeneralDilated( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count = 1, int64_t batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt, - std::optional> window_reversal = std::nullopt); - - XlaOp DynamicConvForward( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type, - std::optional preferred_element_type = std::nullopt); - - XlaOp DynamicConvInputGrad( - XlaOp input_sizes, XlaOp lhs, XlaOp rhs, - absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type, - std::optional preferred_element_type = std::nullopt); - - XlaOp DynamicConvKernelGrad( - XlaOp activations, XlaOp gradients, - absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type, - std::optional preferred_element_type = std::nullopt); - - absl::StatusOr DynamicConvInstruction( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type, - std::optional preferred_element_type = std::nullopt); - - virtual absl::StatusOr ConvGeneralDilatedInternal( - const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, - absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config); - - XlaOp Fft(XlaOp operand, FftType fft_type, - absl::Span fft_length); - virtual absl::StatusOr FftInternal( - const Shape& shape, XlaOp operand, FftType fft_type, - absl::Span fft_length); - - virtual absl::StatusOr TriangularSolveInternal( - const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options); - - virtual absl::StatusOr CholeskyInternal(const Shape& shape, XlaOp a, - bool lower); - - XlaOp Infeed(const Shape& shape, const std::string& config = ""); - XlaOp InfeedWithToken(XlaOp token, const Shape& shape, - const std::string& config); - virtual absl::StatusOr InfeedWithTokenInternal( - const Shape& infeed_instruction_shape, XlaOp token, - const std::string& config); - - void Outfeed(XlaOp operand, const Shape& shape_with_layout, - const std::string& outfeed_config); - XlaOp OutfeedWithToken(XlaOp operand, XlaOp token, - const Shape& shape_with_layout, - const std::string& outfeed_config); - virtual absl::StatusOr OutfeedWithTokenInternal( - XlaOp operand, XlaOp token, const Shape& shape_with_layout, - const std::string& outfeed_config); - XlaOp Call(const XlaComputation& computation, - absl::Span operands); - - XlaOp CompositeCall( - const XlaComputation& computation, absl::Span operands, - const std::string& name, - std::optional attributes = std::nullopt, - std::optional version = std::nullopt); - - XlaOp CustomCall( - const std::string& call_target_name, absl::Span operands, - const Shape& shape_with_layout, const std::string& opaque, - std::optional> operand_shapes_with_layout, - bool has_side_effect, - absl::Span>> - output_operand_aliasing, - const Literal* literal, std::optional window, - std::optional dnums, - CustomCallSchedule schedule, CustomCallApiVersion api_version); - - // Internal version of CustomCall without computation that doesn't do op - // specific error handling and expects arguments to be legal. CustomCall - // method above calls this method after error handling. - virtual absl::StatusOr CustomCallInternal( - const std::string& call_target_name, absl::Span operands, - const XlaComputation* computation, const Shape& shape_with_layout, - const std::string& opaque, - std::optional> operand_shapes_with_layout, - bool has_side_effect, - absl::Span>> - output_operand_aliasing, - const Literal* literal, std::optional window, - std::optional dnums, - CustomCallSchedule schedule, CustomCallApiVersion api_version); - - // TODO(b/239474321) Remove this overload as it has simply led to code - // duplication. - XlaOp CustomCall( - const std::string& call_target_name, absl::Span operands, - const XlaComputation& computation, const Shape& shape_with_layout, - const std::string& opaque, - std::optional> operand_shapes_with_layout, - bool has_side_effect, - absl::Span>> - output_operand_aliasing, - const Literal* literal, CustomCallSchedule schedule, - CustomCallApiVersion api_version); - - XlaOp OptimizationBarrier(XlaOp operand); - - XlaOp Reduce(XlaOp operand, XlaOp init_value, - const XlaComputation& computation, - absl::Span dimensions_to_reduce); - - XlaOp Reduce(absl::Span operands, - absl::Span init_values, - const XlaComputation& computation, - absl::Span dimensions_to_reduce); - - virtual absl::StatusOr ReduceInternal( - const Shape& shape, absl::Span all_operands, - const XlaComputation& computation, - absl::Span dimensions_to_reduce); - - XlaOp ReduceAll(XlaOp operand, XlaOp init_value, - const XlaComputation& computation); - - XlaOp ReduceWindow(XlaOp operand, XlaOp init_value, - const XlaComputation& computation, - absl::Span window_dimensions, - absl::Span window_strides, Padding padding); - - XlaOp ReduceWindow(absl::Span operands, - absl::Span init_values, - const XlaComputation& computation, - absl::Span window_dimensions, - absl::Span window_strides, Padding padding); - - XlaOp ReduceWindowWithGeneralPadding( - absl::Span operands, absl::Span init_values, - const XlaComputation& computation, - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span base_dilations, - absl::Span window_dilations, - absl::Span> padding); - absl::StatusOr ReduceWindowInternal( - absl::Span operands, absl::Span init_values, - const XlaComputation& computation, - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span base_dilations, - absl::Span window_dilations, - absl::Span> padding); - virtual absl::StatusOr ReduceWindowInternal( - const Shape& shape, XlaOp operand, XlaOp init_value, - const XlaComputation& computation, Window window); - XlaOp CrossReplicaSum(XlaOp operand, - absl::Span replica_groups = {}); - - XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension, - int64_t shard_count, - absl::Span replica_groups = {}, - const std::optional& channel_id = std::nullopt, - const std::optional& layout = std::nullopt, - std::optional use_global_device_ids = std::nullopt); - - XlaOp AllReduce(XlaOp operand, const XlaComputation& computation, - absl::Span replica_groups = {}, - const std::optional& channel_id = std::nullopt, - const std::optional& shape_with_layout = std::nullopt, - std::optional use_global_device_ids = std::nullopt); - - XlaOp ReduceScatter( - XlaOp operand, const XlaComputation& computation, - int64_t scatter_dimension, int64_t shard_count, - absl::Span replica_groups = {}, - const std::optional& channel_id = std::nullopt, - const std::optional& layout = std::nullopt, - std::optional use_global_device_ids = std::nullopt); - - XlaOp AllToAll(XlaOp operand, int64_t split_dimension, - int64_t concat_dimension, int64_t split_count, - absl::Span replica_groups, - const std::optional& layout = std::nullopt, - const std::optional& channel_id = std::nullopt); - - XlaOp AllToAllTuple( - absl::Span operands, - absl::Span replica_groups, - const std::optional& layout, - const std::optional& channel_id = std::nullopt); - - XlaOp AllToAllTuple( - XlaOp operand, int64_t split_dimension, int64_t concat_dimension, - int64_t split_count, absl::Span replica_groups, - const std::optional& layout, - const std::optional& channel_id = std::nullopt); - - XlaOp CollectiveBroadcast( - XlaOp operand, absl::Span replica_groups, - const std::optional& channel_id = std::nullopt); - - XlaOp CollectivePermute( - XlaOp operand, - const std::vector>& source_target_pairs, - const std::optional& channel_id = std::nullopt); - - XlaOp ReplicaId(); - - XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select, - absl::Span window_dimensions, - absl::Span window_strides, - Padding padding, XlaOp source, XlaOp init_value, - const XlaComputation& scatter); - - XlaOp SelectAndScatterWithGeneralPadding( - XlaOp operand, const XlaComputation& select, - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span> padding, XlaOp source, - XlaOp init_value, const XlaComputation& scatter); - - absl::StatusOr SelectAndScatterInternal( - XlaOp operand, const XlaComputation& select, - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span> padding, XlaOp source, - XlaOp init_value, const XlaComputation& scatter); - - virtual XlaOp Iota(const Shape& shape, int64_t iota_dimension); - - XlaOp Iota(PrimitiveType type, int64_t size); - - XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type); - - XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type); - virtual absl::StatusOr BitcastConvertTypeInternal(const Shape& shape, - XlaOp operand); - - XlaOp StochasticConvertType(XlaOp operand, XlaOp random, - PrimitiveType new_element_type); - - XlaOp Transpose(XlaOp operand, absl::Span permutation); - virtual absl::StatusOr TransposeInternal( - const Shape& shape, XlaOp operand, absl::Span permutation); - - XlaOp Rev(XlaOp operand, absl::Span dimensions); - virtual absl::StatusOr RevInternal( - const Shape& shape, XlaOp operand, absl::Span dimensions); - - XlaOp Sort(absl::Span operands, const XlaComputation& comparator, - int64_t dimension = -1, bool is_stable = false); - virtual absl::StatusOr SortInternal(const Shape& shape, - absl::Span operands, - const XlaComputation& comparator, - int64_t dimension, bool is_stable); - - XlaOp TopK(XlaOp operand, int64_t k, bool largest); - virtual absl::StatusOr TopKInternal(const Shape& shape, XlaOp operand, - int64_t k, bool largest); - - XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max); - - XlaOp Map(absl::Span operands, const XlaComputation& computation, - absl::Span dimensions, - absl::Span static_operands = {}); - - XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape); - - XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape); - - XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state, - const Shape& shape); - // Internal variant for the op with the full result shape containing both data - // and state shape as a tuple. - virtual absl::StatusOr RngBitGeneratorInternal( - const Shape& full_result_shape, RandomAlgorithm algorithm, - XlaOp initial_state); - - XlaOp While(const XlaComputation& condition, const XlaComputation& body, - XlaOp init); - virtual absl::StatusOr WhileInternal(const Shape& shape, - const XlaComputation& condition, - const XlaComputation& body, - XlaOp init); - - XlaOp Conditional(XlaOp predicate, XlaOp true_operand, - const XlaComputation& true_computation, XlaOp false_operand, - const XlaComputation& false_computation); - - XlaOp Conditional(XlaOp branch_index, - absl::Span branch_computations, - absl::Span branch_operands); - - XlaOp ReducePrecision(XlaOp operand, int exponent_bits, int mantissa_bits); - virtual absl::StatusOr ReducePrecisionInternal(const Shape& shape, - XlaOp operand, - int exponent_bits, - int mantissa_bits); - - XlaOp Gather(XlaOp input, XlaOp start_indices, - const GatherDimensionNumbers& dimension_numbers, - absl::Span slice_sizes, - bool indices_are_sorted = false); - - virtual absl::StatusOr GatherInternal( - const Shape& shape, XlaOp input, XlaOp start_indices, - const GatherDimensionNumbers& dimension_numbers, - absl::Span slice_sizes, bool indices_are_sorted); - - XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates, - const XlaComputation& update_computation, - const ScatterDimensionNumbers& dimension_numbers, - bool indices_are_sorted = false, bool unique_indices = false); - XlaOp Scatter(absl::Span inputs, XlaOp scatter_indices, - absl::Span updates, - const XlaComputation& update_computation, - const ScatterDimensionNumbers& dimension_numbers, - bool indices_are_sorted = false, bool unique_indices = false); - - virtual absl::StatusOr ScatterInternal( - const Shape& shape, absl::Span inputs, XlaOp scatter_indices, - absl::Span updates, const XlaComputation& update_computation, - const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted, - bool unique_indices); - - void Send(XlaOp operand, const ChannelHandle& handle); - XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle); - - XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout, - const ChannelHandle& handle); - - XlaOp RecvFromHost(XlaOp token, const Shape& shape, - const ChannelHandle& handle); - - virtual XlaOp CreateToken(); - - XlaOp AfterAll(absl::Span tokens); - - XlaOp Recv(const Shape& shape, const ChannelHandle& handle); - XlaOp RecvWithToken(XlaOp token, const Shape& shape, - const ChannelHandle& handle); - - XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, - float epsilon, int64_t feature_index); - - XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp mean, - XlaOp variance, float epsilon, - int64_t feature_index); - - XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean, - XlaOp batch_var, XlaOp grad_output, float epsilon, - int64_t feature_index); - - XlaOp GetDimensionSize(XlaOp operand, int64_t dimension); - - XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension); - - virtual absl::StatusOr SetDimensionSizeInternal(const Shape& shape, - XlaOp operand, - XlaOp val, - int64_t dimension); - - XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension); - - virtual absl::StatusOr AddInstruction( - HloInstructionProto&& instr, HloOpcode opcode, - absl::Span operands); - absl::StatusOr AddInstruction(HloInstructionProto&& instr, - HloOpcode opcode) { - return AddInstruction(std::move(instr), opcode, /*operands=*/{}); - } - - void AddCalledComputation(const XlaComputation& computation, - HloInstructionProto* instr); - - absl::StatusOr LookUpInstruction(XlaOp op) const; - absl::StatusOr LookUpInstructionByHandle( - int64_t handle) const; - absl::StatusOr LookUpMutableInstruction(XlaOp op); - absl::StatusOr LookUpMutableInstructionByHandle( - int64_t handle); - - // Internal helper method that does the building for an arbitrary unary op. - virtual XlaOp UnaryOp(HloOpcode unop, XlaOp operand); - - // Internal helper method that does the building for an arbitrary binary op. - // broadcast_dimensions specifies which dimensions to use for broadcasting - // when the operation is between tensors of different ranks. The direction is - // only used if opcode is kCompare. - XlaOp BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions, - std::optional direction = std::nullopt, - std::optional type = std::nullopt); - - absl::StatusOr Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, - ComparisonDirection direction); - - // Internal helper method for binary op compare without broadcast dimensions. - virtual absl::StatusOr Compare(const Shape& shape, XlaOp lhs, - XlaOp rhs, - ComparisonDirection direction, - Comparison::Type type); - - // Internal helper method that does the building for an arbitrary binary op - // with same ranked operands that doesn't broadcast. - virtual XlaOp BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape, - XlaOp lhs, XlaOp rhs); - - // Internal helper method that does the building for an arbitrary ternary op. - XlaOp TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs); - - XlaOp RngOp(RandomDistribution distribution, - absl::Span parameters, const Shape& shape); - - virtual absl::StatusOr RngOpInternal( - RandomDistribution distribution, absl::Span parameters, - const Shape& shape); - - virtual absl::StatusOr InDimBroadcast( - const Shape& shape, XlaOp operand, - absl::Span broadcast_dimensions); - - // Internal helper method that creates a sequence of instructions that - // performs an explicit broadcast of the operand to the target shape. - // All dimensions of the operand must either be equal to the corresponding - // output shape dimension, or be exactly 1. (Such dimensions are the - // degenerate dimensions.) - absl::StatusOr AddBroadcastSequence(const Shape& output_shape, - XlaOp operand); - - // Internal helper method that broadcasts a scalar to the shape of the output. - absl::StatusOr BroadcastScalarToOutputShape(XlaOp scalar, - XlaOp output); - - // Internal helper method for creating a Reshape op with the already inferred - // shape. - virtual absl::StatusOr ReshapeInternal(const Shape& shape, - XlaOp operand, - int64_t inferred_dimension); - - // Returns the (inferred) result for the program shape using the given root. - absl::StatusOr GetProgramShape(int64_t root_id) const; - - // A visitor which checks whether an operation is a compile-time constant, - // meaning that it doesn't depend on any parameters, or on any stateful - // operation such as `RngNormal` or `Infeed`. The visitor walks the - // computation starting at a given operation and sets is_constant to false iff - // a parameter or stateful operation is encountered. - void IsConstantVisitor(int64_t op_handle, int depth, - absl::flat_hash_set* visited, - bool* is_constant) const; - - // Checks bounds for convolution parameters. - absl::Status VerifyConvolution( - const Shape& lhs_shape, const Shape& rhs_shape, - const ConvolutionDimensionNumbers& dimension_numbers) const; - - int64_t GetNextId() { return ++next_id_; } - - // Populates the module with the input/output alias information stored within - // the input_output_aliases vector. - static absl::Status PopulateInputOutputAliasAndBufferDonor( - HloModuleProto* module, const ProgramShape& program_shape, - const std::vector& input_output_aliases, - const absl::flat_hash_set& - buffer_donors); - - std::string name_; // Name to use for the built computation. - - // The next sequential ID for every instruction/computation contained within - // this computation. - int64_t next_id_ = 0; - - // The first error encountered while building the computation. - // This is OK until the first error is encountered. - absl::Status first_error_; - - // The saved stack trace from the point at which the first error occurred. - tsl::SavedStackTrace first_error_backtrace_; - - // The instructions of this computation. - // Use a deque so pointers into this are stable, for example the return - // value of LookUpInstructionByHandle(). - std::deque instructions_; - // A cache for the HloInstructionProto shapes, to avoid recreating Shape - // objects from protos and to support the GetShapePtr() API. - std::vector> instruction_shapes_; - - // Dynamic parameter configuration of this computation. - DynamicParameterBinding dynamic_parameter_binding_; - - // Holds the input/output alias information populated by the SetUpAlias() API. - std::vector input_output_aliases_; - - // Holds the buffer donor information populated by the AddBufferDonor() API. - absl::flat_hash_set buffer_donors_; - - // A map from XlaOp::Handle to the index in the instructions_ vector where the - // instruction is held. - absl::flat_hash_map handle_to_index_; - - // Track imported instructions by their computation id and the position in - // their computation's instruction list. - struct ImportedInstruction { - int64_t computation_id; - int64_t instruction_index; - }; - - absl::flat_hash_map handle_to_imported_index_; - - // The embedded computations used by this computation. Each computation was - // the entry computation of some XlaComputation, the key is the unique id of - // that XlaComputation. - std::map embedded_; - - // The unique parameter numbers. - absl::flat_hash_set parameter_numbers_; - - // The metadata to attach to each op. This is structured as a "modal"-like - // operation, in order to simplify client code (and not sprinkle this metadata - // throughout the TensorFlow op kernel implementations). - OpMetadata metadata_; - - // A temporary metadata that will only be applied to the next op created. - std::optional one_shot_metadata_; - - // Sharding for this operator. This is structured as a "model"-like operation, - // in order to simplify client code, similar to metadata_. - std::optional sharding_; - - // Mode bit that indicates whether to die when a first error is encountered. - bool die_immediately_on_error_ = false; - - XlaBuilder* parent_builder_{nullptr}; - - FrontendAttributes frontend_attributes_; - - friend XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number, - const Shape& shape, const std::string& name, - const std::vector& replicated_at_leaf_buffers); - friend XlaOp ConstantLiteral(XlaBuilder* builder, - const LiteralSlice& literal); - - friend XlaOp Broadcast(XlaOp operand, - absl::Span broadcast_sizes); - - friend XlaOp BroadcastInDim(XlaOp operand, - absl::Span out_dim_size, - absl::Span broadcast_dimensions); - - friend XlaOp MhloDynamicBroadcastInDim( - XlaOp operand, XlaOp output_dimensions, - absl::Span broadcast_dimensions, - const Shape& output_shape); - - friend XlaOp Copy(XlaOp operand); - - friend XlaOp Pad(XlaOp operand, XlaOp padding_value, - const PaddingConfig& padding_config); - - friend XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno, - int64_t pad_lo, int64_t pad_hi); - - friend XlaOp Reshape(XlaOp operand, absl::Span dimensions, - absl::Span new_sizes); - - friend XlaOp Reshape(XlaOp operand, absl::Span new_sizes); - - friend XlaOp Reshape(const Shape& shape, XlaOp operand); - - friend XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, - absl::Span new_size_bounds, - const std::vector& dims_are_dynamic); - - friend XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, - const Shape& shape); - - friend XlaOp ReshapeWithInferredDimension(XlaOp operand, - absl::Span new_sizes, - int64_t inferred_dimension); - - friend XlaOp Collapse(XlaOp operand, absl::Span dimensions); - - friend XlaOp Slice(XlaOp operand, absl::Span start_indices, - absl::Span limit_indices, - absl::Span strides); - - friend XlaOp SliceInDim(XlaOp operand, int64_t start_index, - int64_t limit_index, int64_t stride, int64_t dimno); - - friend XlaOp DynamicSlice(XlaOp operand, - absl::Span start_indices, - absl::Span slice_sizes); - - friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, - absl::Span start_indices); - - friend XlaOp ConcatInDim(XlaBuilder* builder, - absl::Span operands, int64_t dimension); - - friend XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false); - friend XlaOp Tuple(XlaBuilder* builder, absl::Span elements); - friend XlaOp GetTupleElement(XlaOp tuple_data, int64_t index); - friend XlaOp Compare(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions, - ComparisonDirection direction); - friend XlaOp Compare(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions, - ComparisonDirection direction, - Comparison::Type compare_type); - friend XlaOp Dot(XlaOp lhs, XlaOp rhs, - const PrecisionConfig* precision_config, - std::optional preferred_element_type); - friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs, - const DotDimensionNumbers& dimension_number, - const PrecisionConfig* precision_config, - std::optional preferred_element_type); - virtual absl::StatusOr DotGeneralInternal( - const Shape& shape, XlaOp lhs, XlaOp rhs, - const DotDimensionNumbers& dimension_number, - const PrecisionConfig* precision_config); - friend XlaOp SparseDot(XlaOp lhs, XlaOp rhs, - absl::Span sparse_meta, - absl::Span sparsity, - const DotDimensionNumbers& dimension_number, - const PrecisionConfig* precision_config, - std::optional preferred_element_type); - friend XlaOp Conv(XlaOp lhs, XlaOp rhs, - absl::Span window_strides, Padding padding, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, - std::optional preferred_element_type); - friend XlaOp ConvWithGeneralPadding( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, - std::optional preferred_element_type); - friend XlaOp ConvWithGeneralDimensions( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, - std::optional preferred_element_type); - friend XlaOp ConvGeneral( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, - std::optional preferred_element_type); - friend XlaOp DynamicConvForward( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type, - std::optional preferred_element_type); - friend XlaOp DynamicConvKernelGrad( - XlaOp activations, XlaOp gradients, - absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type, - std::optional preferred_element_type); - friend XlaOp DynamicConvInputGrad( - XlaOp input_sizes, XlaOp lhs, XlaOp rhs, - absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type, - std::optional preferred_element_type); - - friend XlaOp ConvKernelGrad( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, - std::optional preferred_element_type); - - friend XlaOp ConvGeneralDilated( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, - std::optional preferred_element_type, - std::optional> window_reversal); - - friend XlaOp Fft(XlaOp operand, FftType fft_type, - absl::Span fft_length); - friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, - bool unit_diagonal, - TriangularSolveOptions::Transpose transpose_a); - friend XlaOp Cholesky(XlaOp a, bool lower); - friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, - const std::string& config); - friend void Outfeed(XlaOp operand, const Shape& shape_with_layout, - const std::string& outfeed_config); - friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, - absl::Span operands); - - friend XlaOp CompositeCall(XlaBuilder* builder, - const XlaComputation& computation, - absl::Span operands, - const std::string& name, - std::optional attributes, - std::optional version); - - friend XlaOp CustomCall( - XlaBuilder* builder, const std::string& call_target_name, - absl::Span operands, const Shape& shape, - const std::string& opaque, bool has_side_effect, - absl::Span>> - output_operand_aliasing, - const Literal* literal, CustomCallSchedule schedule, - CustomCallApiVersion api_version); - friend XlaOp CustomCallWithComputation( - XlaBuilder* builder, const std::string& call_target_name, - absl::Span operands, const XlaComputation& computation, - const Shape& shape, const std::string& opaque, bool has_side_effect, - absl::Span>> - output_operand_aliasing, - const Literal* literal, CustomCallSchedule schedule, - CustomCallApiVersion api_version); - friend XlaOp CustomCallWithLayout( - XlaBuilder* builder, const std::string& call_target_name, - absl::Span operands, const Shape& shape_with_layout, - absl::Span operand_shapes_with_layout, - const std::string& opaque, bool has_side_effect, - absl::Span>> - output_operand_aliasing, - const Literal* literal, CustomCallSchedule schedule, - CustomCallApiVersion api_version); - friend XlaOp CustomCallWithConvDnums( - XlaBuilder* builder, const std::string& call_target_name, - absl::Span operands, const Shape& shape, - absl::Span operand_shapes_with_layout, - const std::string& opaque, bool has_side_effect, - absl::Span>> - output_operand_aliasing, - const Literal* literal, Window window, ConvolutionDimensionNumbers dnums, - CustomCallSchedule schedule, CustomCallApiVersion api_version); - friend XlaOp OptimizationBarrier(XlaOp operand); - friend XlaOp Complex(XlaOp real, XlaOp imag, - absl::Span broadcast_dimensions); - friend XlaOp Conj(XlaOp operand); - friend XlaOp Add(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Sub(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Mul(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Div(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Rem(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Max(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Min(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp And(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Or(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Xor(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp Not(XlaOp operand); - friend XlaOp PopulationCount(XlaOp operand); - friend XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp ShiftRightArithmetic( - XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); - friend XlaOp ShiftRightLogical( - XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); - friend XlaOp Reduce(XlaOp operand, XlaOp init_value, - const XlaComputation& computation, - absl::Span dimensions_to_reduce); - friend XlaOp Reduce(XlaBuilder* builder, absl::Span operands, - absl::Span init_values, - const XlaComputation& computation, - absl::Span dimensions_to_reduce); - friend XlaOp ReduceAll(XlaOp operand, XlaOp init_value, - const XlaComputation& computation); - friend XlaOp ReduceWindow(XlaOp operand, XlaOp init_value, - const XlaComputation& computation, - absl::Span window_dimensions, - absl::Span window_strides, - Padding padding); - friend XlaOp ReduceWindow(absl::Span operands, - absl::Span init_values, - const XlaComputation& computation, - absl::Span window_dimensions, - absl::Span window_strides, - Padding padding); - friend XlaOp ReduceWindowWithGeneralPadding( - XlaOp operand, XlaOp init_value, const XlaComputation& computation, - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span base_dilations, - absl::Span window_dilations, - absl::Span> padding); - friend XlaOp ReduceWindowWithGeneralPadding( - absl::Span operands, absl::Span init_values, - const XlaComputation& computation, - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span base_dilations, - absl::Span window_dilations, - absl::Span> padding); - - friend XlaOp CrossReplicaSum(XlaOp operand, - absl::Span replica_groups); - friend XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension, - int64_t shard_count, - absl::Span replica_groups, - const std::optional& channel_id, - const std::optional& layout, - std::optional use_global_device_ids); - friend XlaOp AllGatherTuple(absl::Span operands, - int64_t all_gather_dimension, int64_t shard_count, - absl::Span replica_groups, - const std::optional& channel_id, - const std::optional& layout, - std::optional use_global_device_ids); - friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation, - absl::Span replica_groups, - const std::optional& channel_id, - const std::optional& shape_with_layout, - std::optional use_global_device_ids); - friend XlaOp AllReduceTuple(absl::Span operand, - const XlaComputation& computation, - absl::Span replica_groups, - const std::optional& channel_id, - const std::optional& shape_with_layout, - std::optional use_global_device_ids); - friend XlaOp ReduceScatter(XlaOp operand, const XlaComputation& computation, - int64_t scatter_dimension, int64_t shard_count, - absl::Span replica_groups, - const std::optional& channel_id, - const std::optional& layout, - std::optional use_global_device_ids); - - friend XlaOp AllToAll(XlaOp operand, int64_t split_dimension, - int64_t concat_dimension, int64_t split_count, - absl::Span replica_groups, - const std::optional& layout, - const std::optional& channel_id); - friend XlaOp AllToAllTuple(absl::Span operands, - absl::Span replica_groups, - const std::optional& layout, - const std::optional& channel_id); - friend XlaOp AllToAllTuple(XlaOp operand, int64_t split_dimension, - int64_t concat_dimension, int64_t split_count, - absl::Span replica_groups, - const std::optional& layout, - const std::optional& channel_id); - friend XlaOp CollectiveBroadcast( - XlaOp operand, absl::Span replica_groups, - const std::optional& channel_id); - friend XlaOp CollectivePermute( - XlaOp operand, - const std::vector>& source_target_pairs, - const std::optional& channel_id); - friend XlaOp ReplicaId(XlaBuilder* builder); - friend XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select, - absl::Span window_dimensions, - absl::Span window_strides, - Padding padding, XlaOp source, XlaOp init_value, - const XlaComputation& scatter); - friend XlaOp SelectAndScatterWithGeneralPadding( - XlaOp operand, const XlaComputation& select, - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span> padding, XlaOp source, - XlaOp init_value, const XlaComputation& scatter); - friend XlaOp Abs(XlaOp operand); - friend XlaOp Atan2(XlaOp y, XlaOp x, - absl::Span broadcast_dimensions); - friend XlaOp Erf(XlaOp operand); - friend XlaOp Exp(XlaOp operand); - friend XlaOp Expm1(XlaOp operand); - friend XlaOp Floor(XlaOp operand); - friend XlaOp Ceil(XlaOp operand); - friend XlaOp Round(XlaOp operand); - friend XlaOp RoundNearestEven(XlaOp operand); - friend XlaOp Log(XlaOp operand); - friend XlaOp Log1p(XlaOp operand); - friend XlaOp Logistic(XlaOp operand); - friend XlaOp Sign(XlaOp operand); - friend XlaOp Clz(XlaOp operand); - friend XlaOp Cos(XlaOp operand); - friend XlaOp Sin(XlaOp operand); - friend XlaOp Tan(XlaOp operand); - friend XlaOp Tanh(XlaOp operand); - friend XlaOp Real(XlaOp operand); - friend XlaOp Imag(XlaOp operand); - friend XlaOp Sqrt(XlaOp operand); - friend XlaOp Rsqrt(XlaOp operand); - friend XlaOp Cbrt(XlaOp operand); - friend XlaOp Pow(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions); - friend XlaOp IsFinite(XlaOp operand); - friend XlaOp Iota(XlaBuilder* builder, const Shape& shape, - int64_t iota_dimension); - friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64_t size); - friend XlaOp ConvertElementType(XlaOp operand, - PrimitiveType new_element_type); - friend XlaOp BitcastConvertType(XlaOp operand, - PrimitiveType new_element_type); - friend XlaOp StochasticConvertType(XlaOp operand, XlaOp random, - PrimitiveType new_element_type); - friend XlaOp Neg(XlaOp operand); - friend XlaOp Transpose(XlaOp operand, absl::Span permutation); - friend XlaOp Rev(XlaOp operand, absl::Span dimensions); - friend XlaOp Sort(absl::Span operands, - const XlaComputation& comparator, int64_t dimension, - bool is_stable); - friend XlaOp TopK(XlaOp operand, int64_t k, bool largest); - friend XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max); - friend XlaOp Map(XlaBuilder* builder, absl::Span operands, - const XlaComputation& computation, - absl::Span dimensions, - absl::Span static_operands); - friend XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape); - friend XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape); - friend XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state, - const Shape& shape); - friend XlaOp While(const XlaComputation& condition, - const XlaComputation& body, XlaOp init); - friend XlaOp Conditional(XlaOp predicate, XlaOp true_operand, - const XlaComputation& true_computation, - XlaOp false_operand, - const XlaComputation& false_computation); - friend XlaOp Conditional( - XlaOp branch_index, - absl::Span branch_computations, - absl::Span branch_operands); - friend XlaOp ConditionalImpl( - XlaOp branch_index, - absl::Span branch_computations, - absl::Span branch_operands); - friend XlaOp ReducePrecision(XlaOp operand, int exponent_bits, - int mantissa_bits); - friend XlaOp Gather(XlaOp input, XlaOp start_indices, - const GatherDimensionNumbers& dimension_numbers, - absl::Span slice_sizes, - bool indices_are_sorted); - friend XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates, - const XlaComputation& update_computation, - const ScatterDimensionNumbers& dimension_numbers, - bool indices_are_sorted, bool unique_indices); - friend XlaOp Scatter(absl::Span inputs, XlaOp scatter_indices, - absl::Span updates, - const XlaComputation& update_computation, - const ScatterDimensionNumbers& dimension_numbers, - bool indices_are_sorted, bool unique_indices); - friend void Send(XlaOp operand, const ChannelHandle& handle); - friend XlaOp Recv(XlaBuilder* builder, const Shape& shape, - const ChannelHandle& handle); - friend XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, - float epsilon, int64_t feature_index); - friend XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, - XlaOp mean, XlaOp variance, float epsilon, - int64_t feature_index); - friend XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean, - XlaOp batch_var, XlaOp grad_output, float epsilon, - int64_t feature_index); - friend XlaOp SendWithToken(XlaOp operand, XlaOp token, - const ChannelHandle& handle); - friend XlaOp RecvWithToken(XlaOp token, const Shape& shape, - const ChannelHandle& handle); - friend XlaOp SendToHost(XlaOp operand, XlaOp token, - const Shape& shape_with_layout, - const ChannelHandle& handle); - friend XlaOp RecvFromHost(XlaOp token, const Shape& shape, - const ChannelHandle& handle); - friend XlaOp InfeedWithToken(XlaOp token, const Shape& shape, - const std::string& config); - friend XlaOp OutfeedWithToken(XlaOp operand, XlaOp token, - const Shape& shape_with_layout, - const std::string& outfeed_config); - friend XlaOp CreateToken(XlaBuilder* builder); - friend XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens); - - friend XlaOp GetDimensionSize(XlaOp operand, int64_t dimension); - friend XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension); - friend XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension); - - protected: - // Returns OK status if the given op was built using this builder. Otherwise, - // returns an error. - absl::Status CheckOpBuilder(XlaOp op) const; - - private: - XlaOp AllGatherImpl(XlaOp operand, int64_t all_gather_dimension, - int64_t shard_count, - absl::Span replica_groups, - const std::optional& channel_id, - const std::optional& layout, - std::optional use_global_device_ids, bool async); - - XlaOp AllReduceImpl(XlaOp operand, const XlaComputation& computation, - absl::Span replica_groups, - const std::optional& channel_id, - const std::optional& layout, - std::optional use_global_device_ids, bool async); - - XlaOp CollectiveBroadcastImpl(XlaOp operand, - absl::Span replica_groups, - const std::optional& channel_id); - - XlaOp CollectivePermuteImpl( - XlaOp operand, - const std::vector>& source_target_pairs, - const std::optional& channel_id, bool async); - - XlaOp ConditionalImpl( - XlaOp branch_index, - absl::Span branch_computations, - absl::Span branch_operands); - - XlaOp AllToAllArray( - XlaOp operand, int64_t split_dimension, int64_t concat_dimension, - int64_t split_count, absl::Span replica_groups, - const std::optional& channel_id = std::nullopt); - - // Creates an op with the given opcode and the output shape. - virtual absl::StatusOr AddOpWithShape( - HloOpcode opcode, const Shape& shape, absl::Span operands); - - // Here, InstructionType is either const HloInstructionProto* or non-const - // HloInstructionProto*. - template - absl::StatusOr LookUpInstructionByHandleInternal( - int64_t handle) const { - auto it = handle_to_index_.find(handle); - if (it == handle_to_index_.end()) { - // Try look for the instruction in the imported instructions. - auto imported_it = handle_to_imported_index_.find(handle); - if (imported_it != handle_to_imported_index_.end()) { - ImportedInstruction imported = imported_it->second; - return const_cast( - &embedded_.at(imported.computation_id) - .instructions(imported.instruction_index)); - } - return InvalidArgument("No XlaOp with handle %d", handle); - } - return const_cast(&instructions_.at(it->second)); - } - - // Here, InstructionType is either const HloInstructionProto* or non-const - // HloInstructionProto*. - // - // TODO(hinsu): Return const pointer within absl::StatusOr and use - // absl::implicit_cast at callsites. This requires implicit_cast support in - // absl::StatusOr similar to absl::StatusOr. - template - absl::StatusOr LookUpInstructionInternal(XlaOp op) const { - TF_RETURN_IF_ERROR(CheckOpBuilder(op)); - return LookUpInstructionByHandleInternal(op.handle()); - } - - friend struct internal::XlaBuilderFriend; - - friend class ValueInference; -}; - -// RAII-style object: sets the current sharding assignment in builder on -// construction, and sets back to the previous assignment on destruction. -class XlaScopedShardingAssignment { - public: - XlaScopedShardingAssignment(xla::XlaBuilder* builder, - std::optional sharding) - : builder_(builder), prev_sharding_(builder->sharding()) { - SetSharding(sharding); - } - - XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete; - XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) = - delete; - - ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); } - - private: - void SetSharding(const std::optional& sharding) { - if (sharding.has_value()) { - builder_->SetSharding(sharding.value()); - } else { - builder_->ClearSharding(); - } - } - - xla::XlaBuilder* const builder_; - std::optional prev_sharding_; -}; - -// RAII-style object: save the current builder's frontend attributes, and merge -// them with the new ones on construction. -// Restore the original attributes on destruction. -class XlaScopedFrontendAttributesAssignment { - public: - XlaScopedFrontendAttributesAssignment(xla::XlaBuilder* builder, - FrontendAttributes attributes) - : builder_(builder) { - saved_ = builder_->SwapFrontendAttributes(attributes); - } - - ~XlaScopedFrontendAttributesAssignment() { - builder_->SetFrontendAttributes(saved_); - } - - private: - xla::XlaBuilder* const builder_; - FrontendAttributes saved_; - - XlaScopedFrontendAttributesAssignment( - const XlaScopedFrontendAttributesAssignment&) = delete; - XlaScopedFrontendAttributesAssignment& operator=( - const XlaScopedFrontendAttributesAssignment&) = delete; -}; - -// RAII-style object: sets the current op metadata in builder on construction, -// and sets back to the previous assignment on destruction. -class XlaScopedOpMetadataAssignment { - public: - XlaScopedOpMetadataAssignment(xla::XlaBuilder* builder, OpMetadata metadata) - : builder_(builder) { - saved_ = builder_->SwapOpMetadata(metadata); - } - - ~XlaScopedOpMetadataAssignment() { builder_->SwapOpMetadata(saved_); } - - private: - xla::XlaBuilder* const builder_; - OpMetadata saved_; - - XlaScopedOpMetadataAssignment(const XlaScopedOpMetadataAssignment&) = delete; - XlaScopedOpMetadataAssignment& operator=( - const XlaScopedOpMetadataAssignment&) = delete; -}; - -// Free functions for building XlaOps. The intention is that these will -// become the public API for building XlaOps rather than calling methods on -// XlaBuilder directly. -// - -// Enqueues a "retrieve parameter value" instruction for a parameter that was -// passed to the computation. -XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number, - const Shape& shape, const std::string& name); - -// Same as above, but with leaf buffer replication annotation. -XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number, - const Shape& shape, const std::string& name, - const std::vector& replicated_at_leaf_buffers); - -// Enqueues a constant with the value of the given literal onto the -// computation. -XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal); - -// Enqueues a constant onto the computation. Methods are templated on the -// native host type (NativeT) which corresponds to a specific XLA -// PrimitiveType as given in the following table: -// -// Native Type PrimitiveType -// ----------------------------- -// bool PRED -// int32_t S32 -// int64_t S64 -// uint32_t U32 -// uint64_t U64 -// float F32 -// double F64 -// -// Note: not all primitive types defined in xla_data.proto have a -// corresponding native type yet. -template -XlaOp ConstantR0(XlaBuilder* builder, NativeT value); -template -XlaOp ConstantR1(XlaBuilder* builder, absl::Span values); -XlaOp ConstantR1(XlaBuilder* builder, const tsl::core::Bitmap& values); -template -XlaOp ConstantR2(XlaBuilder* builder, - std::initializer_list> values); -template -XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, - const Array& values, - const Layout& layout); -template -XlaOp ConstantFromArray(XlaBuilder* builder, const Array& values); -template -XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, - const Array2D& values, - const Layout& layout); -template -XlaOp ConstantR2FromArray2D(XlaBuilder* builder, - const Array2D& values); -template -XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, - const Array3D& values, - const Layout& layout); -template -XlaOp ConstantR3FromArray3D(XlaBuilder* builder, - const Array3D& values); -template -XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder, - const Array4D& values, - const Layout& layout); -template -XlaOp ConstantR4FromArray4D(XlaBuilder* builder, - const Array4D& values); - -// Enqueues a rank one constant (XlaBuilder* builder, vector) onto the -// computation. The vector has size 'length' and every element has the value -// 'value'. -template -XlaOp ConstantR1(XlaBuilder* builder, int64_t length, NativeT value); - -// Adds dimensions to an array by duplicating the data in the array. -// -// The new dimensions are inserted on the left, i.e. if -// broadcast_sizes has values {a0, ..., aN} and the operand shape -// has dimensions {b0, ..., bM} then the shape of the output has -// dimensions {a0, ..., aN, b0, ..., bM}. -// -// The new dimensions index into copies of the operand, i.e. -// -// output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] -XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes); - -// This op broadcasts the `operand` to an output with the given `shape`. -// `broadcast_dimensions` are the dimensions to be broadcasting into, i.e., the -// i'th dimension of the operand is mapped to the broadcast_dimensions[i]'th -// dimension of the output. This also requires that the i'th input dimension is -// either 1 or is the same as the output dimension it's broadcasting into. -// -// For example, say operand = {1, 2}, i.e., a 1D tensor in shape s32[2]; the -// output shape is s32[2,2]: -// - Specifying {1} as broadcast_dimension will generate output -// {{1, 2}, -// {1, 2}} -// - On the other hand, specifying {0} as broadcast_dimension -// will generate output -// {{1 , 1}, -// {2 , 2}} -XlaOp BroadcastInDim(XlaOp operand, absl::Span out_dim_size, - absl::Span broadcast_dimensions); - -// This is an experimental API for creating the mhlo.dynamic_broadcast_in_dim -// op from the XlaBuilder. This is only intended for export to MHLO or -// StableHLO, and cannot be compiled. See -// https://www.tensorflow.org/mlir/hlo_ops#mhlodynamic_broadcast_in_dim_mhlodynamicbroadcastindimop. -// for the op semantics. -XlaOp MhloDynamicBroadcastInDim(XlaOp operand, XlaOp output_dimensions, - absl::Span broadcast_dimensions, - const Shape& output_shape); - -// Copies the input operand to the output. This operation is for internal -// purpose and is only used by the compiler for optimization purposes or to -// ensure correctness. The XLA client should never have to generate this -// instruction. -// -// Copy has two potential use cases: -// -// * Create a copy of the operand with a new layout. -// -// * Create a copy of the operand in a separately allocated buffer. This is -// necessary for some backends if the operand is a parameter or constant and -// the operand is returned within a tuple. In this case, the lifetime of the -// operand buffer must be the same as the lifetime of the output result. -// However, the lifetimes of parameters and constants are managed separately -// from the lifetime of the output result. Creating a separate copy of the -// parameter or constant buffer resolves this issue. -XlaOp Copy(XlaOp operand); - -// Enqueues a pad operation onto the computation that pads the given value on -// the edges as well as between the elements of the input. padding_config -// specifies the padding amount for each dimension. -XlaOp Pad(XlaOp operand, XlaOp padding_value, - const PaddingConfig& padding_config); - -// Enqueues a pad operation in a given dimension, taking all other -// dimensions as they are. -XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno, - int64_t pad_lo, int64_t pad_hi); - -// Enqueues an operation onto the computation that flattens the operand based -// on the dimension order (major/slowest-varying to minor/fastest-varying) -// given, followed by reshaping it into the shape with the given dimension -// sizes (also major to minor). Conceptually, this is a limited form of -// "shape casting". -XlaOp Reshape(XlaOp operand, absl::Span dimensions, - absl::Span new_sizes); - -// Enqueues a dynamic reshape operation. The dynamic reshape takes additional -// XlaOps as sizes for the result dimension. The result dim i is a dynamic -// dimension dimension if dims_are_dynamic[i] is true. -XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, - absl::Span new_size_bounds, - const std::vector& dims_are_dynamic); - -// This is an experimental API for creating the mhlo.dynamic_reshape op from the -// XlaBuilder. This is only intended for export to MHLO or StableHLO, and cannot -// be compiled. -XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, const Shape& shape); - -// Enqueues an operation onto the computation that collapses the operand, -// from first to last dimension (C order), then reshapes it to the given -// dimension sizes. Conceptually, this is a limited form of "shape casting". -XlaOp Reshape(XlaOp operand, absl::Span new_sizes); - -// Enqueues a Reshape op that uses an explicit target shape. -XlaOp Reshape(const Shape& shape, XlaOp operand); - -// `inferred_dimension` represents the output dimension that's inferred by -// upper-level framework by dividing the input element count by the known -// output element count. While an inferred_dimension can be static, if there -// is a dynamic dimension in the output, it must be the inferred dimension. -XlaOp ReshapeWithInferredDimension(XlaOp operand, - absl::Span new_sizes, - int64_t inferred_dimension); - -// Wrapper for Reshape. -// Enqueues an operation to collapse the provided dimensions; e.g. an -// operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to -// {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must -// be a consecutive, in-order subsequence of the operand dimensions. -// -// Note that collapsing a single dimension does nothing: -// -// {256} collapsing {0} => {256} -// {1} collapsing {0} => {1} -// -// Collapsing multiple dimensions produces a single result dimension: -// -// {256, 2} collapsing {0,1} => {512} -// {256, 2, 3} collapsing {0,1} => {512, 3} -// -// This could potentially cause data to be moved -- it provides a more -// structured form of reshaping than an arbitrary Reshape operation. -XlaOp Collapse(XlaOp operand, absl::Span dimensions); - -// Enqueues a slice operation onto the computation that slices the operand -// from the start indices to the limit indices; e.g. -// -// x -// [ 0 1 2 3 ] -// y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] -// [ 8 9 a b ] -// -// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D -// range notation. -// The strides parameter determines the stride over the slice -XlaOp Slice(XlaOp operand, absl::Span start_indices, - absl::Span limit_indices, - absl::Span strides); - -// Enqueues a slice operation in a given dimension, taking all other -// dimensions as they are; e.g. if dimno is 1 from start_index 2 to -// limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand -// for: -// -// array[:, 2:4:1, :] -XlaOp SliceInDim(XlaOp operand, int64_t start_index, int64_t limit_index, - int64_t stride, int64_t dimno); - -// Enqueues a slice operation onto the computation that slices the 'operand' -// from dynamic start indices which are passed in 'start_indices'. -// The size of the slice in each dimension is passed in 'slice_sizes', -// which specify the end point of exclusive slice intervals in each -// dimension [start, start + size). -// The shape of each element of 'start_indices' must be scalar, with the span -// size equal to the rank of the 'operand'. All elements of 'start_indices' must -// have the same shape. -// Slice index calculations are computed modulo input dimension sizes to -// prevent dynamic start indices from generating out-of-bound array accesses. -XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, - absl::Span slice_sizes); - -// Enqueues a dynamic update slice operation onto the computation, which -// updates a slice of 'operand' with 'update' at dynamic 'start_indices'. -// The shape of 'update' determines the shape of the slice of 'operand' -// which is updated. -// The indices specified in 'start_indices' specify the offset of the slice -// of 'operand' which is updated. -// -// update = {10, 11} // calculated at runtime. -// [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] -// [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] -// [7 8 9] [7 8 9 ] -// -// The shape of each element of 'start_indices' must be scalar, with the span -// size equal to the rank of the 'operand'. All elements of 'start_indices' must -// have the same shape. -// Slice index calculations are computed modulo update dimension sizes to -// prevent dynamic start indices from generating out-of-bound array accesses. -XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, - absl::Span start_indices); - -// Enqueues a concatenate instruction onto the computation. 'operands' must -// have >= 1 entry. -XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, - int64_t dimension); - -// Enqueues a conditional-move-like select operation onto the computation; -// predicated on pred, selects between on_true and on_false. -XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false); - -// Enqueues a tuple-creation instruction onto the computation. -XlaOp Tuple(XlaBuilder* builder, absl::Span elements); - -// Enqueues a tuple-element-get instruction onto the computation. -XlaOp GetTupleElement(XlaOp tuple_data, int64_t index); - -// Enqueues an equal-to comparison instruction onto the computation. -XlaOp Eq(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); -XlaOp EqTotalOrder(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues a not-equal comparison instruction onto the computation. -XlaOp Ne(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); -XlaOp NeTotalOrder(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues a greater-or-equal comparison instruction onto the computation. -XlaOp Ge(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); -XlaOp GeTotalOrder(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues a greater-than comparison instruction onto the computation. -XlaOp Gt(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); -XlaOp GtTotalOrder(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues a less-than comparison instruction onto the computation. -XlaOp Lt(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); -XlaOp LtTotalOrder(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues a less-or-equal comparison instruction onto the computation. -XlaOp Le(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); -XlaOp LeTotalOrder(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues a comparison instruction onto the computation (optionally without -// broadcast_dimensions for consistency with others). -XlaOp Compare(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions, - ComparisonDirection direction, Comparison::Type compare_type); -XlaOp Compare(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions, - ComparisonDirection direction); -XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction); - -// Enqueues a dot instruction onto the computation. -XlaOp Dot(XlaOp lhs, XlaOp rhs, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - -// Enqueues a general dot instruction onto the computation. -XlaOp DotGeneral( - XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - -// Enqueues a sparse dot instruction onto the computation. -XlaOp SparseDot( - XlaOp lhs, XlaOp rhs, absl::Span sparse_meta, - absl::Span sparsity, - const DotDimensionNumbers& dimension_numbers, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - -// Enqueues a convolution instruction onto the computation, which uses the -// default convolution dimension numbers. -XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span window_strides, - Padding padding, int64_t feature_group_count = 1, - int64_t batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - -// Enqueues a convolution instruction onto the computation, with the caller -// provided padding configuration in the format returned by MakePadding(). -XlaOp ConvWithGeneralPadding( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - int64_t feature_group_count = 1, int64_t batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - -// Enqueues a convolution instruction onto the computation, with the caller -// provided dimension numbers configuration. -XlaOp ConvWithGeneralDimensions( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count = 1, int64_t batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - -// Enqueues a convolution instruction onto the computation, with the caller -// provided padding configuration as well as the dimension numbers. -XlaOp ConvGeneral( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count = 1, int64_t batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt); - -// Enqueues a convolution instruction onto the computation, with the caller -// provided padding configuration, dilation factors and dimension numbers. -XlaOp ConvGeneralDilated( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count = 1, int64_t batch_group_count = 1, - const PrecisionConfig* precision_config = nullptr, - std::optional preferred_element_type = std::nullopt, - std::optional> window_reversal = std::nullopt); - -XlaOp DynamicConvForward( - XlaOp lhs, XlaOp rhs, absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type, - std::optional preferred_element_type = std::nullopt); - -XlaOp DynamicConvInputGrad( - XlaOp input_sizes, XlaOp lhs, XlaOp rhs, - absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type, - std::optional preferred_element_type = std::nullopt); - -XlaOp DynamicConvKernelGrad( - XlaOp activations, XlaOp gradients, - absl::Span window_strides, - absl::Span> padding, - absl::Span lhs_dilation, - absl::Span rhs_dilation, - const ConvolutionDimensionNumbers& dimension_numbers, - int64_t feature_group_count, int64_t batch_group_count, - const PrecisionConfig* precision_config, PaddingType padding_type, - std::optional preferred_element_type = std::nullopt); - -// Enqueues an FFT instruction onto the computation, of the given type and -// with the given FFT length. -XlaOp Fft(XlaOp operand, FftType fft_type, - absl::Span fft_length); - -// Solves systems of linear equations with lower or upper triangular coefficient -// matrices by forward- or back-substitution. Broadcasting along leading -// dimensions, this routine solves for x in one of the matrix systems -// `op(a) * x = b`, or `x * op(a) = b`, -// for the variable `x` given `a` and `b`, where `op(a)` is either -// `op(a) = a`, or `op(a) = transpose(a)`, or `op(a) = conj(transpose(a))`. -// -// * `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form -// square matrices. If `lower` is true (false), then the strictly upper -// (lower) triangular part of each innermost matrix in `a` is assumed to be -// zero and is not accessed. -// * `b` is a tensor of shape `[..., M, K]` if `left_side` is true, otherwise a -// tensor of shape `[..., K, M]`. -// * `left_side` is a boolean, indicating whether to solve a system of the form -// op(a) * x = b (true) or x * op(a) = b (false). -// * `lower` is a boolean, indicating whether the argument `a` is -// lower-triangular (true) or upper-triangular (false). -// * If `unit_diagonal` is true, the diagonal elements of `a` are assumed to be -// 1 and not accessed. -// * `transpose_a` indicates which function `op` we use to transform the tensor -// `a`: the identity function, transpose(a), or conjugate(transpose(a)) -XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, - bool unit_diagonal, - TriangularSolveOptions::Transpose transpose_a); - -// Computes the Cholesky decompositions of a batch of symmetric (Hermitian) -// positive definite matrices. -// `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the -// two minor dimensions equal. -// If `lower` is true, the data from the lower triangle is used; if false, the -// upper triangle is used. The input data in the other triangle of the input -// does not affect the output. Returns the output in the same lower/upper -// triangle. The data returned in the other output triangle is arbitrary and -// implementation-defined. -// -// If `a` is not Hermitian positive definite, returns an array full of NaNs. -XlaOp Cholesky(XlaOp a, bool lower); - -// Enqueues an infeed instruction onto the computation, which writes data of -// the given shape to the infeed buffer of the device. -XlaOp Infeed(XlaBuilder* builder, const Shape& shape, - const std::string& config = ""); - -// Variant of Infeed which takes a token-shaped operand and produces a -// two-element tuple containing the data value and a token-shaped value. -// Tokens are used for ordering side-effecting operations. -// TODO(b/110532604): Replace all uses of the non-token form with this variant. -XlaOp InfeedWithToken(XlaOp token, const Shape& shape, - const std::string& config = ""); - -// Enqueues an outfeed instruction onto the computation. This instruction -// generates outgoing data transfers for the given data. -// -// shape_with_layout communicates the laid out shape that we want to outfeed -// -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error -// will occur. -void Outfeed(XlaOp operand, const Shape& shape_with_layout, - const std::string& outfeed_config); - -// Variant of Outfeed which takes a token-shaped operand and produces a -// token-shaped value. Tokens are used for ordering side-effecting operations. -// TODO(b/110532604): Replace all uses of the non-token form with this variant. -XlaOp OutfeedWithToken(XlaOp operand, XlaOp token, - const Shape& shape_with_layout, - const std::string& outfeed_config); - -// Enqueues a call instruction onto the computation. -XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, - absl::Span operands); - -// Enqueues a composite call instruction onto the computation. -XlaOp CompositeCall(XlaBuilder* builder, const XlaComputation& computation, - absl::Span operands, const std::string& name, - std::optional attributes = std::nullopt, - std::optional version = std::nullopt); - -// Enqueues a custom call instruction onto the computation. A custom call -// invokes code external to XLA. The |operands| are passed to the external code, -// and the external code is expected to produce a result of the given -// |shape|. The exact mechanism is backend-specific. For example, in the CPU -// backend, a call instruction is emitted which targets a symbol with the name -// |call_target_name|. |call_target_name| and |opaque| can arbitrary strings, -// but |call_target_name| should be short as it may be used in labels. |opaque| -// can encode arbitrarily large amounts of information. |has_side_effect| -// specifies whether the instruction can have side effects. -// |output_operand_aliasing| specifies a list of output/operand buffer pairs -// that alias each other, where the output buffer is represented as a -// ShapeIndex, and the operand buffer is represented as the operand index and -// the ShapeIndex. -XlaOp CustomCall( - XlaBuilder* builder, const std::string& call_target_name, - absl::Span operands, const Shape& shape, - const std::string& opaque = "", bool has_side_effect = false, - absl::Span>> - output_operand_aliasing = {}, - const Literal* literal = nullptr, - CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE, - CustomCallApiVersion api_version = API_VERSION_ORIGINAL); - -// Overload which constructs a custom call that applies an Xla computation. -XlaOp CustomCallWithComputation( - XlaBuilder* builder, const std::string& call_target_name, - absl::Span operands, const XlaComputation& computation, - const Shape& shape, const std::string& opaque = "", - bool has_side_effect = false, - absl::Span>> - output_operand_aliasing = {}, - const Literal* literal = nullptr, - CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE, - CustomCallApiVersion api_version = API_VERSION_ORIGINAL); - -// Overload which constructs a custom call with fixed layouts. The operands will -// have the layouts specified by |operand_shapes_with_layout| when provided to -// external code, and the external code is expected to produce a result with the -// layout specified by |shape_with_layout|. All shapes in |shape_with_layout| -// and |operand_shapes_with_layout| must have layouts. -XlaOp CustomCallWithLayout( - XlaBuilder* builder, const std::string& call_target_name, - absl::Span operands, const Shape& shape_with_layout, - absl::Span operand_shapes_with_layout, - const std::string& opaque = "", bool has_side_effect = false, - absl::Span>> - output_operand_aliasing = {}, - const Literal* literal = nullptr, - CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE, - CustomCallApiVersion api_version = API_VERSION_ORIGINAL); - -// Overload which annotates a custom call with the given Window and -// ConvolutionDimensionNumbers. Useful for custom-calls which represent -// convolutions. -// -// This sets the layout of its operands if operand_shapes_with_layout is -// nonempty, and it sets the layout of its result if `shape` has a layout. -XlaOp CustomCallWithConvDnums( - XlaBuilder* builder, const std::string& call_target_name, - absl::Span operands, const Shape& shape, - absl::Span operand_shapes_with_layout, - const std::string& opaque, bool has_side_effect, - absl::Span>> - output_operand_aliasing, - const Literal* literal, Window window, ConvolutionDimensionNumbers dnums, - CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE, - CustomCallApiVersion api_version = API_VERSION_ORIGINAL); - -// Enqueues an optimization barrier onto the computation. -XlaOp OptimizationBarrier(XlaOp operand); - -// The following methods enqueue element-wise binary arithmetic operations -// onto the computation. The shapes of the operands have to match unless one -// of the operands is a scalar, or an explicit broadcast dimension is given -// (see g3doc for more details). - -// Enqueues a complex compose instruction onto the computation. -XlaOp Complex(XlaOp real, XlaOp imag, - absl::Span broadcast_dimensions = {}); - -// Enqueues a complex conjugate instruction onto the computation. -XlaOp Conj(XlaOp operand); - -// Enqueues an add instruction onto the computation. -XlaOp Add(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues a subtract instruction onto the computation. -XlaOp Sub(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues a multiply instruction onto the computation. -XlaOp Mul(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues a divide instruction onto the computation. -XlaOp Div(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues a remainder instruction onto the computation. -XlaOp Rem(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues a max instruction onto the computation. -XlaOp Max(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues a min instruction onto the computation. -XlaOp Min(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Element-wise logical operators -XlaOp And(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Overload to call And with 3 or more operands. We need the following somewhat -// convoluted overload set to disambiguate with the overload that takes the -// `broadcast_dimensions` optional param. -inline XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3) { - return And(op1, And(op2, op3)); -} -template -XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) { - return And(op1, And(op2, And(op3, operands...))); -} - -XlaOp Or(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Overload to call Or with 3 or more operands. As with `And`, we need the -// following complicated overload set to handle the default arg in the `Or` -// overload above. -inline XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3) { - return Or(op1, Or(op2, op3)); -} -template -XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) { - return Or(op1, Or(op2, Or(op3, operands...))); -} - -XlaOp Xor(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -XlaOp Not(XlaOp operand); - -XlaOp PopulationCount(XlaOp operand); - -XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); -XlaOp ShiftRightArithmetic(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); -XlaOp ShiftRightLogical(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); -// Reduces an array among the provided dimensions, given "computation" as a -// reduction operator. -XlaOp Reduce(XlaOp operand, XlaOp init_value, const XlaComputation& computation, - absl::Span dimensions_to_reduce); - -// Reduces several arrays simultaneously among the provided dimensions, given -// "computation" as a reduction operator. -XlaOp Reduce(XlaBuilder* builder, absl::Span operands, - absl::Span init_values, - const XlaComputation& computation, - absl::Span dimensions_to_reduce); - -// Convenience wrapper around the above that reduces all the dimensions in the -// operand shape. -XlaOp ReduceAll(XlaOp operand, XlaOp init_value, - const XlaComputation& computation); - -// Enqueues a windowed reduce instruction onto the computation. -XlaOp ReduceWindow(XlaOp operand, XlaOp init_value, - const XlaComputation& computation, - absl::Span window_dimensions, - absl::Span window_strides, Padding padding); - -XlaOp ReduceWindow(absl::Span operands, - absl::Span init_values, - const XlaComputation& computation, - absl::Span window_dimensions, - absl::Span window_strides, Padding padding); - -// As ReduceWindow(), but the padding is given in the format -// returned by MakePadding(). -XlaOp ReduceWindowWithGeneralPadding( - XlaOp operand, XlaOp init_value, const XlaComputation& computation, - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span base_dilations, - absl::Span window_dilations, - absl::Span> padding); -XlaOp ReduceWindowWithGeneralPadding( - absl::Span operands, absl::Span init_values, - const XlaComputation& computation, - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span base_dilations, - absl::Span window_dilations, - absl::Span> padding); - -// Returns the sum of the operand value within each subgroup of replicas. All -// replicas supply one input to the sum and all replicas receive the resulting -// sum for each subgroup. -XlaOp CrossReplicaSum(XlaOp operand, - absl::Span replica_groups = {}); - -XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension, - int64_t shard_count, - absl::Span replica_groups = {}, - const std::optional& channel_id = std::nullopt, - const std::optional& layout = std::nullopt, - std::optional use_global_device_ids = std::nullopt); - -XlaOp AllGatherTuple( - absl::Span operands, int64_t all_gather_dimension, - int64_t shard_count, absl::Span replica_groups = {}, - const std::optional& channel_id = std::nullopt, - const std::optional& layout = std::nullopt, - std::optional use_global_device_ids = std::nullopt); - -// Enqueues an operation that do an AllReduce of the operand cross cores. Here -// AllReduce means doing a reduction on the input operand cross cores and then -// broadcasting the reduction result to those cores. The reduction function is -// defined by `computation`, which should be a commutative computation on -// scalars, e.g., add, min, or max. The way that AllReduce is applied is -// configured by: -// -// - `replica_groups`: each ReplicaGroup contains a list of replica id. If -// empty, all replicas belong to one group. Allreduce will be applied within -// subgroups. For example, we have 4 replicas, then replica_groups={{0,2},{1,3}} -// means, replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. -// -// - `channel_id`: for Allreduce nodes from different modules, if they have the -// same channel_id, they will be 'AllReduce'd. If empty, AllReduce will not be -// applied cross modules. -// -// - `shape_with_layout`: forces the layout of the AllReduce to the given -// layout. This is used to guarantee the same layout for a group of AllReduce -// ops compiled separately. -XlaOp AllReduce(XlaOp operand, const XlaComputation& computation, - absl::Span replica_groups = {}, - const std::optional& channel_id = std::nullopt, - const std::optional& shape_with_layout = std::nullopt, - std::optional use_global_device_ids = std::nullopt); - -XlaOp AllReduceTuple( - absl::Span operand, const XlaComputation& computation, - absl::Span replica_groups = {}, - const std::optional& channel_id = std::nullopt, - const std::optional& shape_with_layout = std::nullopt, - std::optional use_global_device_ids = std::nullopt); - -XlaOp ReduceScatter( - XlaOp operand, const XlaComputation& computation, int64_t scatter_dimension, - int64_t shard_count, absl::Span replica_groups = {}, - const std::optional& channel_id = std::nullopt, - const std::optional& layout = std::nullopt, - std::optional use_global_device_ids = std::nullopt); - -// Enqueues an operation that do an AllToAll of the operand cross cores. -// This involves AllToAll, followed by Reshape, Transpose, and another Reshape -// to get proper codegen. See implementation for additional details. -// -// An optional `layout` can be specified to force the layout of the instruction. -// This is used to guarantee the same layout for a group of AllToAll ops -// compiled separately. -XlaOp AllToAll(XlaOp operand, int64_t split_dimension, int64_t concat_dimension, - int64_t split_count, - absl::Span replica_groups = {}, - const std::optional& layout = std::nullopt, - const std::optional& channel_id = std::nullopt); - -XlaOp AllToAllTuple( - absl::Span operand, - absl::Span replica_groups = {}, - const std::optional& layout = std::nullopt, - const std::optional& channel_id = std::nullopt); - -XlaOp AllToAllTuple( - XlaOp operand, int64_t split_dimension, int64_t concat_dimension, - int64_t split_count, absl::Span replica_groups = {}, - const std::optional& layout = std::nullopt, - const std::optional& channel_id = std::nullopt); - -XlaOp CollectiveBroadcast( - XlaOp operand, absl::Span replica_groups, - const std::optional& channel_id = std::nullopt); - -// Enqueues an collective operation that sends and receives data cross replicas. -// -// - `source_target_pair`: a list of (source_replica_id, target_replica_id) -// pairs. For each pair, the operand is sent from source replica to target -// replica. Note that, 1) any two pairs should not have the same target replica -// id, and they should not have the same source replica id; 2) if a replica id -// is not a target in any pair, then the output on that replica is a tensor -// consists of 0(s) with the same shape as the input. -XlaOp CollectivePermute( - XlaOp operand, - const std::vector>& source_target_pairs, - const std::optional& channel_id = std::nullopt); - -// Enqueues an operation that returns the replica ID. -XlaOp ReplicaId(XlaBuilder* builder); - -// Enqueues an operation that scatters the `source` array to the selected -// indices of each window. -XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select, - absl::Span window_dimensions, - absl::Span window_strides, - Padding padding, XlaOp source, XlaOp init_value, - const XlaComputation& scatter); - -// As SelectAndScatter(), but the padding is given in the format -// returned by MakePadding(). -XlaOp SelectAndScatterWithGeneralPadding( - XlaOp operand, const XlaComputation& select, - absl::Span window_dimensions, - absl::Span window_strides, - absl::Span> padding, XlaOp source, - XlaOp init_value, const XlaComputation& scatter); - -// Enqueues an abs instruction onto the computation. -XlaOp Abs(XlaOp operand); - -// Enqueues a atan2 instruction onto the computation. -XlaOp Atan2(XlaOp y, XlaOp x, - absl::Span broadcast_dimensions = {}); - -// Enqueues an erf instruction onto the computation. -XlaOp Erf(XlaOp operand); - -// Enqueues an exp instruction onto the computation. -XlaOp Exp(XlaOp operand); - -// Enqueues an expm1 instruction onto the computation. -XlaOp Expm1(XlaOp operand); - -// Enqueues a floor instruction onto the computation. -XlaOp Floor(XlaOp operand); - -// Enqueues a ceil instruction onto the computation. -XlaOp Ceil(XlaOp operand); - -// Enqueues a round instruction onto the computation, -// with half-way cases rounding away from zero. -XlaOp Round(XlaOp operand); - -// Enqueues a round instruction onto the computation, rounding to nearest even -XlaOp RoundNearestEven(XlaOp operand); - -// Enqueues an log instruction (natural logarithm) onto the computation. -XlaOp Log(XlaOp operand); - -// Enqueues an log1p instruction (log(x+1)) onto the computation. -XlaOp Log1p(XlaOp operand); - -// Enqueues a logistic instruction onto the computation. -XlaOp Logistic(XlaOp operand); - -// Enqueues a sign instruction onto the computation. -XlaOp Sign(XlaOp operand); - -// Enqueues a count leading zeros instruction onto the computation. -XlaOp Clz(XlaOp operand); - -// Enqueues a cosine instruction onto the computation. -XlaOp Cos(XlaOp operand); - -// Enqueues a sine instruction onto the computation. -XlaOp Sin(XlaOp operand); - -// Enqueues a tan instruction onto the computation. -XlaOp Tan(XlaOp operand); - -// Enqueues a tanh instruction onto the computation. -XlaOp Tanh(XlaOp operand); - -// Enqueues a real-part instruction onto the computation. -XlaOp Real(XlaOp operand); - -// Enqueues an imaginary-part instruction onto the computation. -XlaOp Imag(XlaOp operand); - -// Enqueues a sqrt computation onto the computation. -XlaOp Sqrt(XlaOp operand); - -// Enqueues a cbrt computation onto the computation. -XlaOp Cbrt(XlaOp operand); - -// Enqueues a rsqrt computation onto the computation. -XlaOp Rsqrt(XlaOp operand); - -// Enqueues a lhs^rhs computation onto the computation. -XlaOp Pow(XlaOp lhs, XlaOp rhs, - absl::Span broadcast_dimensions = {}); - -// Enqueues an operator that tests if the operand's values are finite, i.e., not -// +/-Inf or NaN. Returns an array of booleans with the same shape where -// entries are true iff the corresponding entry was not infinite or NaN. -// -// Defined only for real-valued (i.e. not complex) floating-point types; raises -// an error for other types. -// -// See also IsInf, IsPosInf, IsNegInf, and IsNan in lib/math.h. -XlaOp IsFinite(XlaOp operand); - -// Enqueues an iota operation onto the computation. -XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64_t iota_dimension); - -// Enqueues a rank-1 iota operation onto the computation. -XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64_t size); - -// Enqueues a convert instruction onto the computation that changes the -// element type of the operand array to primitive_type. -XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type); - -// Enqueues a no-op instruction onto the computation that changes -// the element type of the operand array to primitive_type. The -// bit-widths of the source and destination element types must be -// identical. -XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type); - -// Enqueues a stochastic convert instruction onto the computation that changes -// the element type of the operand array with stochastic rounding to -// primitive_type. -XlaOp StochasticConvertType(XlaOp operand, XlaOp random, - PrimitiveType new_element_type); - -// Enqueues a negate instruction onto the computation. -XlaOp Neg(XlaOp operand); - -// Enqueues a transpose instruction onto the computation. -XlaOp Transpose(XlaOp operand, absl::Span permutation); - -// Enqueues a reverse instruction onto the computation. The order of the -// elements in the given dimensions is reversed (i.e., the element at index i -// is moved to index dimension_size - 1 - i). -XlaOp Rev(XlaOp operand, absl::Span dimensions); - -// Enqueues a sort instruction onto the computation, using 'comparator' for -// comparisons. 'comparator' needs to define a strict weak order. 'is_stable' -// determines whether the stable sorting should be used. -// If only one operand is provided: -// * If the operand is a rank-1 tensor (an array), the result is a sorted array. -// The resulting sorting order has the property that for all index positions -// i, j with i < j, either -// comparator(value[i], value[j]) = comparator(value[j], value[i]) = false or -// comparator(value[i], value[j]) = true. -// * If the operand has higher rank, the operand is sorted along the provided -// dimension. For example, for a rank-2 tensor (a matrix), a dimension value -// of 0 will independently sort every column, and a dimension value of 1 will -// independently sort each row. If no dimension number is provided, then the -// last dimension is chosen by default. For the dimension which is sorted, the -// same sorting order applies as in the rank-1 case. -// -// If more than one operand is provided: -// * All operands must be tensors with the same dimensions. The element types of -// the tensors may be different. -// * The result is a tuple that consists of the operands in sorted order (along -// the provided dimension, as above). The same permutation as implied by the -// comparison computation is applied to all operand tensors. When comparing -// two index positions, 'comparator' is called with 2 * n scalar parameters, -// where parameter 2 * i and 2 * i + 1 correspond to the value of operand i at -// two index positions. -// Default comparator computations can be found in lib/comparators.h -XlaOp Sort(absl::Span operands, const XlaComputation& comparator, - int64_t dimension = -1, bool is_stable = false); - -// Enqueues a topk instruction onto the computation. TopK returns the largest -// 'k' values and their indices along the last dimension of the 'operand' if -// `lagest=true` or the smallest `k` values if `largest=false`. -// -// * If the operand is a rank-1 tensor (an array), the result is a tuple that -// consists of: -// * a sorted array with the top 'k' elements. -// * an array containing the indices of the k elements. -// For example, if the input is [0.1, 0.3, 0.2] and k == 2, the output tuple -// is ([0.3, 0.2], [1, 2]). -// * If the operand has higher rank, the result is a tuple that consists of: -// * a tensor equivalent to one produced by sorting the operand along the last -// dimension and slicing that dimension to only the top 'k' values. The last -// dimension is sorted as in the rank-1 case. -// * a tensor containing the indices of the top 'k' values along the last -// dimension. -// For example, if the input is [0.1, 0.3, 0.2][0.5, 0.4, 0.6] and k == 1, the -// output tuple is ([0.3][0.6], [1][2]). -XlaOp TopK(XlaOp operand, int64_t k, bool largest); - -// Enqueues a clamp instruction onto the computation. -XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max); - -// Enqueues a map instruction onto the computation. -XlaOp Map(XlaBuilder* builder, absl::Span operands, - const XlaComputation& computation, - absl::Span dimensions, - absl::Span static_operands = {}); - -// Enqueues a N(mu, sigma) random number generation instruction onto the -// computation. -XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape); - -// Enqueues a U(a, b) random number generation instruction onto the -// computation. Returns values in the semi-open interval [a, b). -XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape); - -// Enqueues a B(initial_state) random bit generation instruction onto the -// computation. Returns the new key and random bits with the specified shape. -XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state, - const Shape& shape); - -// Enqueues a while node onto the computation. -XlaOp While(const XlaComputation& condition, const XlaComputation& body, - XlaOp init); - -// Enqueues a conditional node onto the computation. -XlaOp Conditional(XlaOp predicate, XlaOp true_operand, - const XlaComputation& true_computation, XlaOp false_operand, - const XlaComputation& false_computation); - -// Enqueues either a predicated (if/else) or indexed (switch/case/default) -// conditional node onto the computation. N >= 1 branch_computations and -// branch_operands are matched by index. branch_index selects the branch that -// will be executed. Out of range branch_index uses the N-1'th -// branch_computation as default. -XlaOp Conditional(XlaOp branch_index, - absl::Span branch_computations, - absl::Span branch_operands); - -// Enqueues a ReducePrecision node onto the computation. -XlaOp ReducePrecision(XlaOp operand, int exponent_bits, int mantissa_bits); - -// Enqueues a Gather node onto the computation. -XlaOp Gather(XlaOp input, XlaOp start_indices, - const GatherDimensionNumbers& dimension_numbers, - absl::Span slice_sizes, - bool indices_are_sorted = false); - -// Enqueues a Scatter node onto the computation. -XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates, - const XlaComputation& update_computation, - const ScatterDimensionNumbers& dimension_numbers, - bool indices_are_sorted = false, bool unique_indices = false); -XlaOp Scatter(absl::Span inputs, XlaOp scatter_indices, - absl::Span updates, - const XlaComputation& update_computation, - const ScatterDimensionNumbers& dimension_numbers, - bool indices_are_sorted = false, bool unique_indices = false); - -// Enqueues a Send node onto the computation for device-to-device -// communication. This operation sends the given operand to -// a Recv instruction in a different computation that shares the same channel -// handle. -void Send(XlaOp operand, const ChannelHandle& handle); - -// Variant of Send which takes a token-shaped operand and produces a -// token-shaped value. Tokens are used for ordering side-effecting operations. -// TODO(b/110532604): Replace all uses of the non-token form with this variant. -XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle); - -// Enqueues a Recv node onto the computation for device-to-device -// communication. The data comes from a Send instruction in a different -// computation that shares the same channel handle and its shape must be the -// same as the given shape. -XlaOp Recv(XlaBuilder* builder, const Shape& shape, - const ChannelHandle& handle); - -// Variant of Recv which takes a token-shaped operand and produces a two-element -// tuple containing the data value and a token-shaped value. Tokens are used -// for ordering side-effecting operations. -// TODO(b/110532604): Replace all uses of the non-token form with this variant. -XlaOp RecvWithToken(XlaOp token, const Shape& shape, - const ChannelHandle& handle); - -// Enqueues a Send node which transfers data from the device to the host. The -// 'shape_with_layout' argument defines the layout of the data transferred; its -// shape must be compatible with the shape of the operand. The operand must be -// array-shaped. -// TODO(b/111544877): Support tuple shapes. -XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout, - const ChannelHandle& handle); - -// Enqueues a Recv node which transfers data from the host to the device. The -// given shape must contain a layout and must be an array. -// TODO(b/111544877): Support tuple shapes. -XlaOp RecvFromHost(XlaOp token, const Shape& shape, - const ChannelHandle& handle); - -// Enqueues an operation (AfterAll) with no operands that produces a -// token-shaped value. Tokens are used for ordering side-effecting operations. -// This is a separate method from AfterAll to facility the removal of -// operand-less AfterAll instructions. -// TODO(b/110532604): Remove this function when all tokens are derived from a -// single token generated or passed into the entry computation. -XlaOp CreateToken(XlaBuilder* builder); - -// Enqueues an AfterAll instruction which produces a token-shaped value and -// takes a variadic number of token-shaped operands. The number of operands must -// be greater than zero. Used for joining tokens. -XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens); - -// Normalizes operand across spatial and batch dimensions for each feature. -// -// Returns a tuple (normalized, batch_mean, batch_var) where `normalized` -// is the normalized result and batch_mean and batch_var are the mean and -// variance, respectively, across batch for the operand. -XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, float epsilon, - int64_t feature_index); - -// Normalizes operand across spatial and batch dimensions for each feature. -// -// `BatchNormInference` is equivalent to calling `BatchNormTraining` without -// computing `mean` and `variance` for each batch inside the operation. It -// uses the input `mean` and `variance` instead as estimated values. The -// purpose of this op is to reduce latency in inference, hence the name -// `BatchNormInference`. -// -// The output has the same shape as `operand`, and contains the normalized -// values for each batch. -XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp mean, - XlaOp variance, float epsilon, int64_t feature_index); - -// Calculates the gradients of a batch norm op. -// -// The inputs `batch_mean` and `batch_var` represent the mean and variance -// across the batch. -// -// Returns a tuple of three elements: -// - grad_operand: Gradient with respect to input `operand` -// - grad_offset: Gradient with respect to input `offset` -// - grad_scale: Gradient with respect to input `scale` -XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean, - XlaOp batch_var, XlaOp grad_output, float epsilon, - int64_t feature_index); - -// Returns the size of the given dimension of the operand. The operand must be -// array shaped. -XlaOp GetDimensionSize(XlaOp operand, int64_t dimension); - -// Sets the size of the given dimension of the operand. The operand must be -// array shaped. The result will have the same shape as the operand, but the -// given dimension will be dynamic (if not already). -XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension); - -// Returns the same op but with dynamic dimension removed. -XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension); - -// Implementation details below this point. -// - -// Free function template implementations. - -template -XlaOp ConstantR0(XlaBuilder* builder, NativeT value) { - return ConstantLiteral(builder, LiteralUtil::CreateR0(value)); -} - -template -XlaOp ConstantR1(XlaBuilder* builder, absl::Span values) { - BorrowingLiteral literal( - reinterpret_cast(values.begin()), - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), - {static_cast(values.size())})); - return ConstantLiteral(builder, literal); -} - -template -XlaOp ConstantR1(XlaBuilder* builder, int64_t length, NativeT value) { - Literal literal(ShapeUtil::MakeShape( - primitive_util::NativeToPrimitiveType(), {length})); - literal.PopulateWithValue(value); - return ConstantLiteral(builder, literal); -} - -inline XlaOp ConstantR1(XlaBuilder* builder, const tsl::core::Bitmap& values) { - return ConstantLiteral(builder, LiteralUtil::CreateR1(values)); -} - -template -XlaOp ConstantR2(XlaBuilder* builder, - std::initializer_list> values) { - return ConstantLiteral(builder, LiteralUtil::CreateR2(values)); -} - -template -XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, - const Array& values, - const Layout& layout) { - return ConstantLiteral( - builder, LiteralUtil::CreateFromArrayWithLayout(values, layout)); -} - -template -XlaOp ConstantFromArray(XlaBuilder* builder, const Array& values) { - return ConstantLiteral(builder, - LiteralUtil::CreateFromArray(values)); -} - -template -XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, - const Array2D& values, - const Layout& layout) { - return ConstantLiteral( - builder, LiteralUtil::CreateFromArrayWithLayout(values, layout)); -} - -template -XlaOp ConstantR2FromArray2D(XlaBuilder* builder, - const Array2D& values) { - return ConstantLiteral(builder, - LiteralUtil::CreateR2FromArray2D(values)); -} - -template -XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, - const Array3D& values, - const Layout& layout) { - return ConstantLiteral( - builder, - LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); -} - -template -XlaOp ConstantR3FromArray3D(XlaBuilder* builder, - const Array3D& values) { - return ConstantFromArray(builder, values); -} - -template -XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder, - const Array4D& values, - const Layout& layout) { - return ConstantFromArrayWithLayout(builder, values, layout); -} - -template -XlaOp ConstantR4FromArray4D(XlaBuilder* builder, - const Array4D& values) { - return ConstantFromArray(builder, values); -} - -// Switches from automatic SPMD partitioning to manual partitioning. Converts a -// full-shaped tensor (to be automatically partitioned by SPMD partitioner) to a -// shard-shaped tensor to be consumed by manually partitioned ops. -absl::StatusOr ConvertSpmdFullToShardShape( - xla::XlaBuilder* builder, xla::XlaOp input, int single_dim, - const xla::OpSharding& manual_sharding, - absl::Span unspecified_dims); - -// Switches from manual partitioning to automatic SPMD partitioning. Converts a -// shard-shaped tensor (manually partitioned in SPMD-style) to a full-shaped -// tensor to be partitioned automatically by the SPMD partitioner. -absl::StatusOr ConvertSpmdShardToFullShape( - xla::XlaBuilder* builder, xla::XlaOp input, const xla::Shape& output_shape, - int single_dim, const xla::OpSharding& manual_sharding, - absl::Span unspecified_dims); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/xla_builder.h" #endif // XLA_CLIENT_XLA_BUILDER_H_ diff --git a/third_party/xla/xla/client/xla_computation.h b/third_party/xla/xla/client/xla_computation.h index 52a54aa113b178..685fcfecb0b093 100644 --- a/third_party/xla/xla/client/xla_computation.h +++ b/third_party/xla/xla/client/xla_computation.h @@ -16,58 +16,7 @@ limitations under the License. #ifndef XLA_CLIENT_XLA_COMPUTATION_H_ #define XLA_CLIENT_XLA_COMPUTATION_H_ -#include -#include -#include - -#include "absl/status/statusor.h" -#include "xla/service/hlo.pb.h" -#include "xla/shape.h" -#include "xla/status_macros.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// The computation graph that the user builds up with the XlaBuilder. -class XlaComputation { - public: - XlaComputation() : unique_id_(-1) {} - XlaComputation(HloModuleProto proto) - : unique_id_(proto.id()), proto_(std::move(proto)) {} - - ~XlaComputation() = default; - - XlaComputation(const XlaComputation&) = delete; - XlaComputation& operator=(const XlaComputation&) = delete; - - XlaComputation(XlaComputation&& from) = default; - - XlaComputation& operator=(XlaComputation&& from) = default; - - // Returns the "program shape" (parameter and return shapes) for this - // computation. - absl::StatusOr GetProgramShape() const; - - const std::string& name() const { return proto().name(); } - - const HloModuleProto& proto() const { return proto_; } - HloModuleProto* mutable_proto() { return &proto_; } - - // Requests that we snapshot the computation into a serializable protocol - // buffer form. - absl::StatusOr> Snapshot() const; - - // Returns true if this object is a null Computation. - bool IsNull() const { return unique_id_ == -1; } - - private: - XlaComputation(const int64_t unique_id) : unique_id_(unique_id) {} - friend class XlaBuilder; - - int64_t unique_id_; - HloModuleProto proto_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/builder/xla_computation.h" #endif // XLA_CLIENT_XLA_COMPUTATION_H_ diff --git a/third_party/xla/xla/cpu_function_runtime.cc b/third_party/xla/xla/cpu_function_runtime.cc index 034f21acc359a4..cbc0f71e66185c 100644 --- a/third_party/xla/xla/cpu_function_runtime.cc +++ b/third_party/xla/xla/cpu_function_runtime.cc @@ -19,7 +19,7 @@ limitations under the License. namespace xla { namespace { -// Inline memory allocation routines here, because depending on '//base' brings +// Inline memory allocation routines here, because depending on 'base' brings // in libraries which use c++ streams, which adds considerable code size on // android. void* aligned_malloc(size_t size, int minimum_alignment) { diff --git a/third_party/xla/xla/cpu_function_runtime.h b/third_party/xla/xla/cpu_function_runtime.h index 6dccf0e0facc13..214fa37ea4a77c 100644 --- a/third_party/xla/xla/cpu_function_runtime.h +++ b/third_party/xla/xla/cpu_function_runtime.h @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "Eigen/Core" + namespace xla { namespace cpu_function_runtime { @@ -179,7 +181,7 @@ class BufferInfo { inline constexpr size_t Align() { return 64; } // The minimum alignment of buffers passed to XLA:CPU. -inline constexpr size_t MinAlign() { return 16; } +inline constexpr size_t MinAlign() { return EIGEN_MAX_ALIGN_BYTES; } // When declaring variables that will be passed to an XLA instance as input via // set_arg_data(), be it a regular input or a resource variable in the graph, diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index f651949a82a14d..872a68ac7833a7 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -38,6 +38,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/debug_options_parsers.h" #include "xla/parse_flags_from_env.h" +#include "xla/service/collective_utils.h" #include "xla/stream_executor/cuda/nvjitlink_support.h" #include "xla/stream_executor/cuda/ptx_compiler_support.h" #include "xla/tsl/util/command_line_flags.h" @@ -85,8 +86,10 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { #endif opts.set_xla_cpu_use_thunk_runtime(true); opts.set_xla_cpu_parallel_codegen_split_count(32); + opts.set_xla_cpu_copy_insertion_use_region_analysis(false); opts.set_xla_cpu_enable_concurrency_optimized_scheduler(false); opts.set_xla_cpu_prefer_vector_width(256); + opts.set_xla_cpu_max_isa(""); opts.set_xla_cpu_enable_fast_math(false); // Disable forms of fast math that have caused users problems in the past. @@ -127,11 +130,13 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_allow_excess_precision(true); opts.set_xla_force_host_platform_device_count(1); - constexpr int64_t kDefaultThreshold = 30 * 1024 * 1024; - opts.set_xla_gpu_all_reduce_combine_threshold_bytes(kDefaultThreshold); - opts.set_xla_gpu_all_gather_combine_threshold_bytes(kDefaultThreshold); - opts.set_xla_gpu_reduce_scatter_combine_threshold_bytes(kDefaultThreshold); - opts.set_xla_gpu_enable_all_gather_combine_by_dim(true); + opts.set_xla_gpu_all_reduce_combine_threshold_bytes( + kDefaultAllReduceCombineThreshold); + opts.set_xla_gpu_all_gather_combine_threshold_bytes( + kDefaultAllGatherCombineThreshold); + opts.set_xla_gpu_reduce_scatter_combine_threshold_bytes( + kDefaultReduceScatterCombineThreshold); + opts.set_xla_gpu_enable_all_gather_combine_by_dim(false); opts.set_xla_gpu_enable_reduce_scatter_combine_by_dim(true); opts.set_xla_gpu_enable_approx_costly_collectives(false); @@ -144,7 +149,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_enable_dumping(true); opts.set_xla_gpu_enable_custom_fusions(false); - opts.set_xla_gpu_enable_dynamic_slice_fusion(true); + opts.set_xla_gpu_enable_dynamic_slice_fusion(false); opts.set_xla_gpu_nccl_termination_timeout_seconds(-1); opts.set_xla_gpu_enable_shared_constants(true); opts.set_xla_gpu_enable_nccl_user_buffers(false); @@ -168,11 +173,9 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_pipelined_collectives(false); opts.set_xla_gpu_enable_pipelined_all_reduce(false); opts.set_xla_gpu_enable_pipelined_all_gather(false); - opts.set_xla_gpu_enable_pipelined_reduce_scatter(false); + opts.set_xla_gpu_enable_pipelined_reduce_scatter(true); opts.set_xla_gpu_enable_pipelined_p2p(false); - opts.set_xla_gpu_run_post_layout_collective_pipeliner(false); - opts.set_xla_gpu_collective_permute_decomposer_threshold( std::numeric_limits::max()); @@ -202,7 +205,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_exhaustive_tiling_search(false); - opts.set_xla_gpu_enable_priority_fusion(true); + opts.set_xla_gpu_experimental_enable_triton_heroless_priority_fusion(false); opts.set_xla_gpu_experimental_enable_triton_softmax_priority_fusion(false); opts.set_xla_gpu_auto_spmd_partitioning_memory_budget_gb(0); @@ -219,7 +222,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_cudnn_gemm_fusion_level(0); opts.set_xla_gpu_enable_while_loop_double_buffering(false); opts.set_xla_gpu_enable_while_loop_unrolling( - DebugOptions::WHILE_LOOP_UNROLLING_NO_UNROLL); + DebugOptions::WHILE_LOOP_UNROLLING_AUTO_UNROLL); opts.set_xla_gpu_ensure_minor_dot_contraction_dims(false); opts.set_xla_gpu_filter_kernels_spilling_registers_on_autotuning(true); opts.set_xla_gpu_llvm_verification_level(0); @@ -227,8 +230,10 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_cub_radix_sort(true); opts.set_xla_gpu_enable_cudnn_layer_norm(false); opts.set_xla_gpu_threshold_for_windowed_einsum_mib(100000); + opts.set_xla_gpu_operand_bytes_threshold_for_windowed_einsum(-1); opts.set_xla_gpu_enable_triton_hopper(false); + opts.set_xla_gpu_experimental_enable_fusion_block_level_rewriter(false); opts.set_xla_gpu_enable_llvm_module_compilation_parallelism(false); opts.set_xla_gpu_enable_libnvptxcompiler( @@ -284,13 +289,19 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_cudnn_gemm_max_plans(5); - opts.set_xla_gpu_enable_triton_gemm_int4(false); - + // TODO: remove this as it is replaced by xla_gpu_pgle_accuracy_checker. opts.set_xla_gpu_enable_pgle_accuracy_checker(false); + opts.set_xla_gpu_pgle_accuracy_checker( + DebugOptions::PGLE_STRICTNESS_LEVEL_WARN); + opts.set_xla_gpu_executable_warn_stuck_timeout_seconds(10); opts.set_xla_gpu_executable_terminate_timeout_seconds(30); opts.set_xla_gpu_experimental_disable_binary_libraries(false); + opts.set_xla_experimental_ignore_channel_id(false); + opts.set_xla_gpu_dot_merger_threshold_mb(32); + opts.set_xla_enable_fast_math(false); + opts.set_xla_gpu_experimental_parallel_collective_overlap_limit(1); return opts; } @@ -376,6 +387,15 @@ void MakeDebugOptionsFlags(std::vector* flag_list, }; }; + auto uppercase_string_setter_for = + [debug_options]( + void (DebugOptions::*member_setter)(const std::string& value)) { + return [debug_options, member_setter](const std::string& value) { + (debug_options->*member_setter)(absl::AsciiStrToUpper(value)); + return true; + }; + }; + auto float_setter_for = [debug_options](void (DebugOptions::*member_setter)(float)) { return [debug_options, member_setter](float value) { @@ -618,6 +638,17 @@ void MakeDebugOptionsFlags(std::vector* flag_list, return absl::StrJoin(collective_ops, ", ", Formatter()); }; + // Custom parser for `xla_gpu_enable_while_loop_unrolling` flag. + auto setter_for_xla_gpu_enable_while_loop_unrolling = + [&debug_options](absl::string_view input) { + DebugOptions::WhileLoopUnrolling unroll_strategy; + bool parsed = DebugOptions::WhileLoopUnrolling_Parse( + absl::AsciiStrToUpper(input), &unroll_strategy); + if (!parsed) return false; + debug_options->set_xla_gpu_enable_while_loop_unrolling(unroll_strategy); + return true; + }; + // Custom parser for xla_gpu_disable_async_collectives. auto setter_for_xla_gpu_disable_async_collectives = [debug_options](const absl::string_view& input) { @@ -675,6 +706,18 @@ void MakeDebugOptionsFlags(std::vector* flag_list, return true; }; + // Custom "sub-parser" lambda for xla_gpu_pgle_accuracy_checker. + auto setter_for_xla_gpu_pgle_accuracy_checker = + [debug_options](const std::string& value) { + DebugOptions::PGLEStrictnessLevel strictness_level; + if (!DebugOptions::PGLEStrictnessLevel_Parse(value, + &strictness_level)) { + return false; + } + debug_options->set_xla_gpu_pgle_accuracy_checker(strictness_level); + return true; + }; + // Don't use an initializer list for initializing the vector; this would // create a temporary copy, and exceeds the stack space when compiling with // certain configurations. @@ -855,6 +898,12 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_cpu_parallel_codegen_split_count(), "Split LLVM module into at most this many parts before codegen to enable " "parallel compilation for the CPU backend.")); + flag_list->push_back(tsl::Flag( + "xla_cpu_copy_insertion_use_region_analysis", + bool_setter_for( + &DebugOptions::set_xla_cpu_copy_insertion_use_region_analysis), + debug_options->xla_cpu_copy_insertion_use_region_analysis(), + "Use region based analysis in copy insertion pass.")); flag_list->push_back(tsl::Flag( "xla_cpu_enable_concurrency_optimized_scheduler", bool_setter_for( @@ -867,6 +916,14 @@ void MakeDebugOptionsFlags(std::vector* flag_list, int32_setter_for(&DebugOptions::set_xla_cpu_prefer_vector_width), debug_options->xla_cpu_prefer_vector_width(), "Preferred vector with for the XLA:CPU LLVM backend.")); + flag_list->push_back(tsl::Flag( + "xla_cpu_max_isa", + uppercase_string_setter_for(&DebugOptions::set_xla_cpu_max_isa), + debug_options->xla_cpu_max_isa(), + "Maximum ISA that XLA:CPU LLVM backend will codegen, i.e., it will not " + "use newer instructions. Available values: SSE4_2, AVX, AVX2, AVX512, " + "AVX512_VNNI, AVX512_BF16, AMX, and AMX_FP16. (`AMX` will enable both " + "`AMX_BF16` and `AMX_INT8` instructions.)")); flag_list->push_back(tsl::Flag( "xla_gpu_crash_on_verification_failures", bool_setter_for( @@ -966,6 +1023,13 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_dump_hlo_as_proto(), "Dumps HLO modules as HloProtos to the directory specified by " "--xla_dump_to.")); + flag_list->push_back( + tsl::Flag("xla_gpu_experimental_dump_fdo_profiles", + bool_setter_for( + &DebugOptions::set_xla_gpu_experimental_dump_fdo_profiles), + debug_options->xla_gpu_experimental_dump_fdo_profiles(), + "Dumps FDO profiles as text to the directory specified " + "by --xla_dump_to.")); flag_list->push_back( tsl::Flag("xla_dump_hlo_as_dot", bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_dot), @@ -1158,6 +1222,16 @@ void MakeDebugOptionsFlags(std::vector* flag_list, " synchronous ones. By default, this is empty which indicates enabling" " async execution for all collectives. A sample usage is: " " --xla_gpu_disable_async_collectives=ALLREDUCE,REDUCESCATTER")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_while_loop_unrolling", + setter_for_xla_gpu_enable_while_loop_unrolling, + DebugOptions::WhileLoopUnrolling_Name( + debug_options->xla_gpu_enable_while_loop_unrolling()), + "Enables while loop unrolling features. " + "`WHILE_LOOP_UNROLLING_DOUBLE_BUFFER` unrolls the loop by factor of 2, " + "`WHILE_LOOP_UNROLLING_FULL_UNROLL` will unroll the entire loop " + "`WHILE_LOOP_UNROLLING_AUTO_UNROLL` unrolls by a factor of 2, if there is" + " any collective present within a while loop.")); flag_list->push_back(tsl::Flag( "xla_gpu_all_reduce_combine_threshold_bytes", int64_setter_for( @@ -1515,12 +1589,6 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_gpu_enable_pipelined_p2p), debug_options->xla_gpu_enable_pipelined_p2p(), "Enable pipelinling of P2P instructions.")); - flag_list->push_back(tsl::Flag( - "xla_gpu_run_post_layout_collective_pipeliner", - bool_setter_for( - &DebugOptions::set_xla_gpu_run_post_layout_collective_pipeliner), - debug_options->xla_gpu_run_post_layout_collective_pipeliner(), - "Move collective pipeliner after the post-layout optimization.")); flag_list->push_back(tsl::Flag( "xla_gpu_collective_permute_decomposer_threshold", int64_setter_for( @@ -1566,11 +1634,19 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_gpu_exhaustive_tiling_search), debug_options->xla_gpu_exhaustive_tiling_search(), "Enable (slow) search for the Triton GEMM fusion tilings.")); + flag_list->push_back(tsl::Flag("xla_gpu_enable_priority_fusion", + noop_flag_setter, true, + "[Deprecated, do not use]")); flag_list->push_back(tsl::Flag( - "xla_gpu_enable_priority_fusion", - bool_setter_for(&DebugOptions::set_xla_gpu_enable_priority_fusion), - debug_options->xla_gpu_enable_priority_fusion(), - "Enable priority queue for fusion order.")); + "xla_gpu_experimental_enable_triton_heroless_priority_fusion", + bool_setter_for( + &DebugOptions:: + set_xla_gpu_experimental_enable_triton_heroless_priority_fusion), + debug_options + ->xla_gpu_experimental_enable_triton_heroless_priority_fusion(), + "Enable heroless Triton fusions in the PriorityFusion pass. The pass " + "will try to make Triton fusions first and foremost where it is " + "possible.")); flag_list->push_back(tsl::Flag( "xla_gpu_experimental_enable_triton_softmax_priority_fusion", bool_setter_for( @@ -1747,11 +1823,30 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "Einsums that have partitioned operand(can be either LHS or RHS) that's " "larger than this threshold will be transformed to use windowed einsums." "Default is 100000")); + flag_list->push_back(tsl::Flag( + "xla_gpu_operand_bytes_threshold_for_windowed_einsum", + int64_setter_for( + &DebugOptions:: + set_xla_gpu_operand_bytes_threshold_for_windowed_einsum), + debug_options->xla_gpu_operand_bytes_threshold_for_windowed_einsum(), + "This controls whether to enable windowed einsum (collective matmul) " + "based on the sum of sizes of 2 operands if set >= 0." + "If set >= 0, xla_gpu_threshold_for_windowed_einsum_mib is ignored." + "Default is -1")); + flag_list->push_back(tsl::Flag( "xla_gpu_enable_triton_hopper", bool_setter_for(&DebugOptions::set_xla_gpu_enable_triton_hopper), debug_options->xla_gpu_enable_triton_hopper(), "Currently used to enable MMA_V3 for Hopper in Triton")); + flag_list->push_back(tsl::Flag( + "xla_gpu_experimental_enable_fusion_block_level_rewriter", + bool_setter_for( + &DebugOptions:: + set_xla_gpu_experimental_enable_fusion_block_level_rewriter), + debug_options->xla_gpu_experimental_enable_fusion_block_level_rewriter(), + "Enabling this flag will attempt to redirect every fusion possible to " + "the Triton emitter")); flag_list->push_back(tsl::Flag( "xla_gpu_enable_libnvptxcompiler", [debug_options](bool enabled) { @@ -1899,11 +1994,9 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "Limit for the number of kernel configurations (plans) to use during " "autotuning of cuDNN GEMM fusions.")); - flag_list->push_back(tsl::Flag( - "xla_gpu_enable_triton_gemm_int4", - bool_setter_for(&DebugOptions::set_xla_gpu_enable_triton_gemm_int4), - debug_options->xla_gpu_enable_triton_gemm_int4(), - "Experimental: Enable Triton gemm for int4 inputs.")); + flag_list->push_back(tsl::Flag("xla_gpu_enable_triton_gemm_int4", + noop_flag_setter, true, + "[Deprecated, do not use]")); flag_list->push_back( tsl::Flag("xla_gpu_async_dot", bool_setter_for(&DebugOptions::set_xla_gpu_async_dot), @@ -1918,12 +2011,13 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "a training. The location of the marker (if any) is determined " "by the option value of type DebugOptions::StepMarkerLocation.")); flag_list->push_back(tsl::Flag( - "xla_gpu_enable_pgle_accuracy_checker", - bool_setter_for(&DebugOptions::set_xla_gpu_enable_pgle_accuracy_checker), - debug_options->xla_gpu_enable_pgle_accuracy_checker(), - "Enables strict PGLE checking. If an FDO profile is specified and " - "latency hiding scheduler encounters missing instructions in the profile " - "compilation will halt.")); + "xla_gpu_pgle_accuracy_checker", setter_for_xla_gpu_pgle_accuracy_checker, + DebugOptions::PGLEStrictnessLevel_Name( + debug_options->xla_gpu_pgle_accuracy_checker()), + "If an FDO profile is specified and latency hiding scheduler encounters " + "missing instructions in the profile, then the compilation will halt " + "(ERROR), or a warning will be emitted (WARN), or the checker is " + "disabled (OFF)")); flag_list->push_back(tsl::Flag( "xla_gpu_executable_warn_stuck_timeout", @@ -1944,6 +2038,29 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_gpu_experimental_disable_binary_libraries(), "Disable XLA GPU passes that depend on non-open source binary " "libraries")); + flag_list->push_back(tsl::Flag( + "xla_experimental_ignore_channel_id", + bool_setter_for(&DebugOptions::set_xla_experimental_ignore_channel_id), + debug_options->xla_experimental_ignore_channel_id(), + "Experimental: ignore channel ids for collective operations.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_dot_merger_threshold_mb", + int32_setter_for(&DebugOptions::set_xla_gpu_dot_merger_threshold_mb), + debug_options->xla_gpu_dot_merger_threshold_mb(), + "Dot merger pass threshold to be set in MB.")); + flag_list->push_back( + tsl::Flag("xla_enable_fast_math", + bool_setter_for(&DebugOptions::set_xla_enable_fast_math), + debug_options->xla_enable_fast_math(), + "Enable optimizations that assume finite math, i.e., no NaN.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_experimental_parallel_collective_overlap_limit", + int32_setter_for( + &DebugOptions:: + set_xla_gpu_experimental_parallel_collective_overlap_limit), + debug_options->xla_gpu_experimental_parallel_collective_overlap_limit(), + "This controls how many in-flight collectives " + "latency hiding scheduler can schedule.")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/third_party/xla/xla/debug_options_parsers_test.cc b/third_party/xla/xla/debug_options_parsers_test.cc index 318b98f163faa2..42aacf12760203 100644 --- a/third_party/xla/xla/debug_options_parsers_test.cc +++ b/third_party/xla/xla/debug_options_parsers_test.cc @@ -21,9 +21,15 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/strings/str_cat.h" +#include "xla/debug_options_flags.h" +#include "xla/parse_flags_from_env.h" +#include "xla/tsl/util/command_line_flags.h" +#include "tsl/platform/env.h" #include "tsl/platform/test.h" namespace xla { +namespace { // Test that the xla_backend_extra_options flag is parsed correctly. TEST(DebugOptionsFlags, ParseXlaBackendExtraOptions) { @@ -37,6 +43,59 @@ TEST(DebugOptionsFlags, ParseXlaBackendExtraOptions) { EXPECT_EQ(test_map.at("ee"), "ff=gg"); } +struct UppercaseStringSetterTestSpec { + std::string user_max_isa; + std::string expected_max_isa; +}; + +class UppercaseStringSetterTest + : public ::testing::Test, + public ::testing::WithParamInterface { + public: + UppercaseStringSetterTest() + : flag_values_(DefaultDebugOptionsIgnoringFlags()) { + MakeDebugOptionsFlags(&flag_objects_, &flag_values_); + } + static std::string Name( + const ::testing::TestParamInfo& info) { + return info.param.user_max_isa; + } + DebugOptions flag_values() const { return flag_values_; } + std::vector flag_objects() { return flag_objects_; } + + private: + DebugOptions flag_values_; + std::vector flag_objects_; +}; + +TEST_P(UppercaseStringSetterTest, XlaCpuMaxIsa) { + UppercaseStringSetterTestSpec spec = GetParam(); + tsl::setenv("XLA_FLAGS", + absl::StrCat("--xla_cpu_max_isa=", spec.user_max_isa).c_str(), + /*overwrite=*/true); + + // Parse flags from the environment variable. + int* pargc; + std::vector* pargv; + ResetFlagsFromEnvForTesting("XLA_FLAGS", &pargc, &pargv); + ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", flag_objects()); + EXPECT_EQ(flag_values().xla_cpu_max_isa(), spec.expected_max_isa); +} + +std::vector GetUppercaseStringSetterTestCases() { + return std::vector({ + UppercaseStringSetterTestSpec{"sse4_2", "SSE4_2"}, + UppercaseStringSetterTestSpec{"aVx512", "AVX512"}, + UppercaseStringSetterTestSpec{"AMx_fP16", "AMX_FP16"}, + }); +} + +INSTANTIATE_TEST_SUITE_P( + UppercaseStringSetterTestInstantiation, UppercaseStringSetterTest, + ::testing::ValuesIn(GetUppercaseStringSetterTestCases()), + UppercaseStringSetterTest::Name); + +} // namespace } // namespace xla int main(int argc, char* argv[]) { diff --git a/third_party/xla/xla/examples/axpy/BUILD b/third_party/xla/xla/examples/axpy/BUILD index 9d3424498f3413..3f9daf300f66ad 100644 --- a/third_party/xla/xla/examples/axpy/BUILD +++ b/third_party/xla/xla/examples/axpy/BUILD @@ -10,26 +10,25 @@ xla_cc_test( "//xla:error_spec", "//xla:literal", "//xla:literal_util", - "//xla/client:client_library", - "//xla/client:local_client", - "//xla/pjrt:local_device_state", + "//xla/pjrt:pjrt_api", + "//xla/pjrt:pjrt_c_api_client", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_executable", - "//xla/pjrt:pjrt_stream_executor_client", - "//xla/service:cpu_plugin", - "//xla/service:platform_util", - "//xla/service:stream_pool", - "//xla/service/cpu:cpu_compiler", - "//xla/stream_executor:platform", + "//xla/pjrt/c:pjrt_c_api_cpu", + "//xla/pjrt/c:pjrt_c_api_hdrs", + "//xla/pjrt/cpu:cpu_client", "//xla/tests:literal_test_util", + "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", - "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:test", "@stablehlo//:register", ], diff --git a/third_party/xla/xla/examples/axpy/README.md b/third_party/xla/xla/examples/axpy/README.md index 39bacfb18c5659..c947b63b0d145e 100644 --- a/third_party/xla/xla/examples/axpy/README.md +++ b/third_party/xla/xla/examples/axpy/README.md @@ -1,8 +1,8 @@ # Compile a StableHLO program with XLA This tutorial and the code in this directory shows how to write a simple -StableHLO program and then compile it with XLA. The purpose is simply to -show how XLA can injest a StableHLO program and produce an executable +StableHLO program and then compile it with XLA and PJRT . The purpose is to +show how XLA can ingest StableHLO program and produce an executable that's compatible with the local device. As such, the program is very simple: $\alpha x+y$ ("axpy"). @@ -57,43 +57,12 @@ This code is in [`stablehlo_axpy.mlir`](stablehlo_axpy.mlir). Our program for this tutorial is set up as a test in [`stablehlo_compile_test.cc`](stablehlo_compile_test.cc). In this file, -you'll see that we first set up a `PjRtStreamExecutorClient` that +you'll see that we first set up a `PjrtClient` with the XLA:CPU plugin that allows us to compile our StableHLO program: ```c++ -// Setup client -LocalClient* local_client = xla::ClientLibrary::LocalClientOrDie(); - -// Retrieve the "platform" we intend to execute the computation on. The -// concept of "platform" in XLA abstracts entirely everything need to -// interact with some hardware (compiler, runtime, etc.). New HW vendor -// plugs into XLA by registering a new platform with a different string -// key. For example for an Nvidia GPU change the following to: -// PlatformUtil::GetPlatform("CUDA")); -TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, - PlatformUtil::GetPlatform("cpu")); -TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, - platform->ExecutorForDevice(0)); - -// LocalDeviceState and PjRtStreamExecutorDevice describes the state of a -// device which can do computation or transfer buffers. Could represent a GPU -// or accelerator, but we'll use the CPU for this example. -auto device_state = std::make_unique( - executor, local_client, LocalDeviceState::kSynchronous, - /*max_inflight_computations=*/32, - /*allow_event_reuse=*/false, /*use_callback_stream=*/false); -auto device = std::make_unique( - 0, std::move(device_state), "cpu"); -std::vector> devices; -devices.emplace_back(std::move(device)); - -// The PjRtStreamExecutorClient will allow us to compile and execute -// computations on the device we just configured. -auto pjrt_se_client = PjRtStreamExecutorClient( - "cpu", local_client, std::move(devices), /*process_index=*/0, - /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr, - /*should_stage_host_to_device_transfers=*/false, - /*gpu_run_options=*/nullptr); +ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + GetCApiClient(kCpuPjrtName)); ``` Then we read the StableHLO program from our MLIR file into a string: @@ -130,7 +99,7 @@ compile it to an executable: ```c++ // Use our client to compile our StableHLO program to an executable. TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, - pjrt_se_client.Compile(*program, CompileOptions{})); + client->Compile(*program, CompileOptions{})); ``` ## 3. Execute the computation diff --git a/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc b/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc index 49a99ee88a679c..8c27d7383a8e44 100644 --- a/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc +++ b/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc @@ -20,7 +20,8 @@ limitations under the License. #include #include -#include +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/DialectRegistry.h" @@ -28,89 +29,124 @@ limitations under the License. #include "mlir/IR/OwningOpRef.h" #include "mlir/Parser/Parser.h" #include "stablehlo/dialect/Register.h" -#include "xla/client/client_library.h" -#include "xla/client/local_client.h" #include "xla/error_spec.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/pjrt/local_device_state.h" +#include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/c/pjrt_c_api_cpu.h" +#include "xla/pjrt/cpu/cpu_client.h" +#include "xla/pjrt/pjrt_api.h" +#include "xla/pjrt/pjrt_c_api_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" -#include "xla/pjrt/pjrt_stream_executor_client.h" -#include "xla/service/platform_util.h" -#include "xla/service/stream_pool.h" -#include "xla/stream_executor/platform.h" #include "xla/tests/literal_test_util.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" +#include "tsl/platform/errors.h" #include "tsl/platform/path.h" -#include "tsl/platform/statusor.h" +#include "tsl/platform/status_matchers.h" #include "tsl/platform/test.h" namespace xla { namespace { -TEST(StableHloAxpyTest, LoadAndRunCpuExecutable) { - // Setup client - LocalClient* local_client = xla::ClientLibrary::LocalClientOrDie(); - - // Retrieve the "platform" we intend to execute the computation on. The - // concept of "platform" in XLA abstracts entirely everything needed to - // interact with some hardware (compiler, runtime, etc.). New HW vendor - // plugs into XLA by registering a new platform with a different string - // key. For example for an Nvidia GPU change the following to: - // PlatformUtil::GetPlatform("CUDA")); - TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, - PlatformUtil::GetPlatform("cpu")); - TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, - platform->ExecutorForDevice(/*ordinal=*/0)); - - // LocalDeviceState and PjRtStreamExecutorDevice describes the state of a - // device which can do computation or transfer buffers. This could represent a - // GPU or accelerator, but we'll use the CPU for this example. - auto device_state = std::make_unique( - executor, local_client, LocalDeviceState::kSynchronous, - /*max_inflight_computations=*/32, - /*allow_event_reuse=*/false, /*use_callback_stream=*/false); - auto device = std::make_unique( - 0, std::move(device_state), "cpu"); - std::vector> devices; - devices.emplace_back(std::move(device)); - - // The PjRtStreamExecutorClient will allow us to compile and execute - // computations on the device we just configured. - auto pjrt_se_client = - PjRtStreamExecutorClient("cpu", local_client, std::move(devices), - /*process_index=*/0, /*allocator=*/nullptr, - /*host_memory_allocator=*/nullptr, - /*should_stage_host_to_device_transfers=*/false, - /*gpu_run_options=*/nullptr); - - // Read StableHLO program to string. - std::string program_path = tsl::io::JoinPath( - tsl::testing::XlaSrcRoot(), "examples", "axpy", "stablehlo_axpy.mlir"); - std::string program_string; - - TF_ASSERT_OK(tsl::ReadFileToString(tsl::Env::Default(), program_path, - &program_string)); - - std::cerr << "Loaded StableHLO program from " << program_path << ":\n" - << program_string << std::endl; - - // Register MLIR dialects necessary to parse our program. In our case this is - // just the Func dialect and StableHLO. - mlir::DialectRegistry dialects; - dialects.insert(); - mlir::stablehlo::registerAllDialects(dialects); - - // Parse StableHLO program. - auto ctx = std::make_unique(dialects); - mlir::OwningOpRef program = - mlir::parseSourceString(program_string, ctx.get()); +using ::testing::NotNull; +constexpr absl::string_view kCpuPjrtName = "cpu"; + +std::string GetTestProgramPath() { + return tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "examples", "axpy", + "stablehlo_axpy.mlir"); +} + +class StableHloAxpyTest : public ::testing::Test { + public: + static void SetUpTestSuite() { + StableHloAxpyTest::RegisterXlaCpuPluginTestSetup(); + } + + protected: + static void RegisterXlaCpuPluginTestSetup() { + // PJRT API isn't registered yet, so make sure we have to call register. + absl::StatusOr pjrt_api = pjrt::PjrtApi(kCpuPjrtName); + EXPECT_THAT(pjrt_api, ::testing::Not(::tsl::testing::IsOk())); + + // Grab the XLA:CPU PJRT API from the plugin explicitly. + const PJRT_Api* cpu_api = GetPjrtApi(); + EXPECT_THAT(cpu_api, NotNull()); + + // Register the XLA:CPU PJRT API. + TF_EXPECT_OK(pjrt::SetPjrtApi(kCpuPjrtName, cpu_api)); + } + + absl::StatusOr> CreateStableHloProgram( + absl::string_view program_path) { + // Register MLIR dialects necessary to parse our program. In our case this + // is just the Func dialect and StableHLO. + registry_.insert(); + mlir::stablehlo::registerAllDialects(registry_); + context_.appendDialectRegistry(registry_); + context_.loadAllAvailableDialects(); + + // Read StableHLO program to string. + std::string program_string; + TF_RETURN_IF_ERROR(tsl::ReadFileToString( + tsl::Env::Default(), std::string(program_path), &program_string)); + + std::cerr << "Loaded StableHLO program from " << program_path << ":\n" + << program_string << std::endl; + + return mlir::parseSourceString(program_string, &context_); + } + + private: + mlir::DialectRegistry registry_; + mlir::MLIRContext context_; +}; // class + +TEST_F(StableHloAxpyTest, GetCPUPlugin) { + // Grab the XLA:CPU PJRT API from the plugin explicitly. + const PJRT_Api* cpu_api = GetPjrtApi(); + EXPECT_THAT(cpu_api, NotNull()); + + absl::StatusOr registered_pjrt_api = + pjrt::PjrtApi(kCpuPjrtName); + EXPECT_THAT(registered_pjrt_api, ::tsl::testing::IsOkAndHolds(cpu_api)); +} + +TEST_F(StableHloAxpyTest, UsePjrtCppWrapper) { + absl::StatusOr> client = + GetCApiClient(kCpuPjrtName); + + EXPECT_THAT(client, ::tsl::testing::IsOk()); + EXPECT_THAT(*client, NotNull()); +} + +TEST_F(StableHloAxpyTest, CompileCPUTestProgram) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + GetCApiClient(kCpuPjrtName)); + + TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef program, + CreateStableHloProgram(GetTestProgramPath())); + + // Use our client to compile our StableHLO program to an executable. + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, + client->Compile(*program, CompileOptions{})); +} + +TEST_F(StableHloAxpyTest, CompileAndExecuteCPUTestProgram) { + // TODO(masonchang): Use the C API client once it supports + // BufferFromHostLiteral. + xla::CpuClientOptions options; + options.cpu_device_count = 4; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + xla::GetTfrtCpuClient(std::move(options))); + + TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef program, + CreateStableHloProgram(GetTestProgramPath())); // Use our client to compile our StableHLO program to an executable. TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr executable, - pjrt_se_client.Compile(*program, CompileOptions{})); + client->Compile(*program, CompileOptions{})); // Create inputs to our computation. auto alpha_literal = xla::LiteralUtil::CreateR0(3.14f); @@ -123,18 +159,17 @@ TEST(StableHloAxpyTest, LoadAndRunCpuExecutable) { std::cerr << "\tx:" << x_literal << std::endl; std::cerr << "\ty:" << y_literal << std::endl; - // Get the host device. - PjRtDevice* cpu = pjrt_se_client.devices()[0]; + PjRtDevice* host_cpu = client->devices()[0]; // Transfer our literals to buffers. If we were using a GPU, these buffers // would correspond to device memory. TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr alpha, - pjrt_se_client.BufferFromHostLiteral(alpha_literal, cpu)); + client->BufferFromHostLiteral(alpha_literal, host_cpu)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr x, - pjrt_se_client.BufferFromHostLiteral(x_literal, cpu)); + client->BufferFromHostLiteral(x_literal, host_cpu)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr y, - pjrt_se_client.BufferFromHostLiteral(y_literal, cpu)); + client->BufferFromHostLiteral(y_literal, host_cpu)); // Do our computation. TF_ASSERT_OK_AND_ASSIGN( diff --git a/third_party/xla/xla/experiments/sm_bandwidth_benchmark/BUILD b/third_party/xla/xla/experiments/sm_bandwidth_benchmark/BUILD index 8583582c5690c4..4218993bd106d5 100644 --- a/third_party/xla/xla/experiments/sm_bandwidth_benchmark/BUILD +++ b/third_party/xla/xla/experiments/sm_bandwidth_benchmark/BUILD @@ -7,8 +7,8 @@ cc_library( name = "sm_bw_utils", hdrs = ["sm_bw_utils.h"], tags = [ + "cuda-only", "gpu", - "no_rocm", ], deps = [ "@local_config_cuda//cuda:cuda_headers", @@ -20,7 +20,7 @@ cuda_library( name = "sm_bw_kernels", srcs = ["sm_bw_kernels.cu.cc"], hdrs = ["sm_bw_kernels.h"], - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":sm_bw_utils", ], @@ -30,8 +30,8 @@ xla_cc_test( name = "sm_bw_test", srcs = ["sm_bw_test.cc"], tags = [ + "cuda-only", "gpu", - "no_rocm", "requires-gpu-nvidia", ], deps = [ diff --git a/third_party/xla/xla/ffi/BUILD b/third_party/xla/xla/ffi/BUILD index d4ff3dff4c8215..4a2ca973fc1888 100644 --- a/third_party/xla/xla/ffi/BUILD +++ b/third_party/xla/xla/ffi/BUILD @@ -126,9 +126,10 @@ cc_library( "//xla/ffi/api:c_api", "//xla/ffi/api:c_api_internal", "//xla/hlo/ir:hlo", - "//xla/stream_executor", "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:scratch_allocator", + "//xla/stream_executor:stream", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -157,7 +158,9 @@ cc_library( "//xla/ffi/api:c_api_internal", "//xla/hlo/ir:hlo", "//xla/service:platform_util", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream", "//xla/tsl/concurrency:async_value", "//xla/tsl/concurrency:ref_count", "@com_google_absl//absl/base:core_headers", @@ -180,6 +183,7 @@ cc_library( hdrs = ["attribute_map.h"], deps = [ ":call_frame", + "//xla:util", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -187,6 +191,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) @@ -203,13 +208,12 @@ xla_cc_test( ":ffi_api", "//xla:xla_data_proto_cc", "//xla/ffi/api:c_api", - "//xla/stream_executor", "//xla/stream_executor:device_memory", + "//xla/stream_executor:stream", "//xla/tsl/concurrency:async_value", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", @@ -228,11 +232,11 @@ cc_library( hdrs = ["type_id_registry.h"], deps = [ "//xla:util", + "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/lib/gtl:int_type", ], ) diff --git a/third_party/xla/xla/ffi/api/api.h b/third_party/xla/xla/ffi/api/api.h index 952a31eb872388..cf98210af1b717 100644 --- a/third_party/xla/xla/ffi/api/api.h +++ b/third_party/xla/xla/ffi/api/api.h @@ -133,6 +133,10 @@ inline std::ostream& operator<<(std::ostream& os, return os << "TOKEN"; case XLA_FFI_DataType_F8E5M2: return os << "F8E5M2"; + case XLA_FFI_DataType_F8E3M4: + return os << "F8E3M4"; + case XLA_FFI_DataType_F8E4M3: + return os << "F8E4M3"; case XLA_FFI_DataType_F8E4M3FN: return os << "F8E4M3FN"; case XLA_FFI_DataType_F8E4M3B11FNUZ: @@ -888,7 +892,7 @@ struct CtxDecoding; // XLA_FFI_Error* Encode(const XLA_FFI_Api* api, // XLA_FFI_ExecutionContext* ctx, // absl::Status status) {...} -// } +// }; // // Result encoding is execution stage specific, for example at instantiation // stage FFI handler can return an FFI handler state, while at execution stage @@ -907,7 +911,7 @@ struct CtxDecoding; // std::variant Encode( // const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx, // xla::ffi::Future future) {...} -// } +// }; // template struct ResultEncoding; @@ -1439,10 +1443,21 @@ class Handler : public Ffi { // handler (or a struct decoding) should be responsible for it. if (XLA_FFI_PREDICT_FALSE(kNumDictAttrs == 0 && call_frame->attrs.size != kNumAttrs)) { - return InvalidArgument( - call_frame->api, - StrCat("Wrong number of attributes: expected ", kNumAttrs, - " but got ", call_frame->attrs.size)); + std::stringstream msg; + msg << "Wrong number of attributes: expected " << kNumAttrs << " but got " + << call_frame->attrs.size; + if (call_frame->attrs.size > 0) { + msg << " with name(s): "; + for (int64_t n = 0; n < call_frame->attrs.size - 1; ++n) { + msg << std::string_view(call_frame->attrs.names[n]->ptr, + call_frame->attrs.names[n]->len) + << ", "; + } + msg << std::string_view( + call_frame->attrs.names[call_frame->attrs.size - 1]->ptr, + call_frame->attrs.names[call_frame->attrs.size - 1]->len); + } + return InvalidArgument(call_frame->api, msg.str()); } // Define index sequences to access custom call operands. @@ -1744,13 +1759,14 @@ auto DictionaryDecoder(Members... m) { // binding specification inference from a callable signature. // #define XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(T, ...) \ + namespace xla::ffi { \ template <> \ - struct ::xla::ffi::AttrsBinding { \ + struct AttrsBinding { \ using Attrs = T; \ }; \ \ template <> \ - struct ::xla::ffi::AttrDecoding { \ + struct AttrDecoding { \ using Type = T; \ static std::optional Decode(XLA_FFI_AttrType type, void* attr, \ DiagnosticEngine& diagnostic) { \ @@ -1765,13 +1781,17 @@ auto DictionaryDecoder(Members... m) { reinterpret_cast(attr), \ internal::StructMemberNames(__VA_ARGS__), diagnostic); \ } \ - } + }; \ + } /* namespace xla::ffi */ \ + static_assert(std::is_class_v<::xla::ffi::AttrsBinding>); \ + static_assert(std::is_class_v<::xla::ffi::AttrDecoding>) // Registers decoding for a user-defined enum class type. Uses enums underlying // type to decode the attribute as a scalar value and cast it to the enum type. #define XLA_FFI_REGISTER_ENUM_ATTR_DECODING(T) \ + namespace xla::ffi { \ template <> \ - struct ::xla::ffi::AttrDecoding { \ + struct AttrDecoding { \ using Type = T; \ using U = std::underlying_type_t; \ static_assert(std::is_enum::value, "Expected enum class"); \ @@ -1784,7 +1804,8 @@ auto DictionaryDecoder(Members... m) { } \ \ auto* scalar = reinterpret_cast(attr); \ - auto expected_dtype = internal::NativeTypeToCApiDataType(); \ + auto expected_dtype = \ + ::xla::ffi::internal::NativeTypeToCApiDataType(); \ if (XLA_FFI_PREDICT_FALSE(scalar->dtype != expected_dtype)) { \ return diagnostic.Emit("Wrong scalar data type: expected ") \ << expected_dtype << " but got " << scalar->dtype; \ @@ -1793,7 +1814,9 @@ auto DictionaryDecoder(Members... m) { auto underlying = *reinterpret_cast(scalar->value); \ return static_cast(underlying); \ } \ - }; + }; \ + } /* namespace xla::ffi */ \ + static_assert(std::is_class_v<::xla::ffi::AttrDecoding>) //===----------------------------------------------------------------------===// // Helper macro for registering FFI implementations diff --git a/third_party/xla/xla/ffi/api/c_api.h b/third_party/xla/xla/ffi/api/c_api.h index d5e2b11538133f..f0c4f40e78ea7a 100644 --- a/third_party/xla/xla/ffi/api/c_api.h +++ b/third_party/xla/xla/ffi/api/c_api.h @@ -195,6 +195,8 @@ typedef enum { XLA_FFI_DataType_C128 = 18, XLA_FFI_DataType_TOKEN = 17, XLA_FFI_DataType_F8E5M2 = 19, + XLA_FFI_DataType_F8E3M4 = 29, + XLA_FFI_DataType_F8E4M3 = 28, XLA_FFI_DataType_F8E4M3FN = 20, XLA_FFI_DataType_F8E4M3B11FNUZ = 23, XLA_FFI_DataType_F8E5M2FNUZ = 24, diff --git a/third_party/xla/xla/ffi/api/ffi.h b/third_party/xla/xla/ffi/api/ffi.h index b31da22175333d..19eaaf52bb37ce 100644 --- a/third_party/xla/xla/ffi/api/ffi.h +++ b/third_party/xla/xla/ffi/api/ffi.h @@ -73,10 +73,12 @@ enum class DataType : uint8_t { C128 = XLA_FFI_DataType_C128, TOKEN = XLA_FFI_DataType_TOKEN, F8E5M2 = XLA_FFI_DataType_F8E5M2, + F8E4M3 = XLA_FFI_DataType_F8E4M3, F8E4M3FN = XLA_FFI_DataType_F8E4M3FN, F8E4M3B11FNUZ = XLA_FFI_DataType_F8E4M3B11FNUZ, F8E5M2FNUZ = XLA_FFI_DataType_F8E5M2FNUZ, F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ, + F8E3M4 = XLA_FFI_DataType_F8E3M4, }; // Create aliases in ::xla::ffi namespace for all DataTypes, for consistency @@ -98,10 +100,12 @@ inline constexpr DataType C64 = DataType::C64; inline constexpr DataType C128 = DataType::C128; inline constexpr DataType TOKEN = DataType::TOKEN; inline constexpr DataType F8E5M2 = DataType::F8E5M2; +inline constexpr DataType F8E4M3 = DataType::F8E4M3; inline constexpr DataType F8E4M3FN = DataType::F8E4M3FN; inline constexpr DataType F8E4M3B11FNUZ = DataType::F8E4M3B11FNUZ; inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ; inline constexpr DataType F8E4M3FNUZ = DataType::F8E4M3FNUZ; +inline constexpr DataType F8E3M4 = DataType::F8E3M4; inline std::ostream& operator<<(std::ostream& os, const DataType dtype) { return os << static_cast(dtype); @@ -117,10 +121,12 @@ constexpr size_t ByteWidth(DataType dtype) { case DataType::S8: case DataType::U8: case DataType::F8E5M2: + case DataType::F8E4M3: case DataType::F8E4M3FN: case DataType::F8E4M3B11FNUZ: case DataType::F8E5M2FNUZ: case DataType::F8E4M3FNUZ: + case DataType::F8E3M4: return 1; case DataType::S16: case DataType::U16: @@ -471,6 +477,13 @@ class AnyBuffer { void* untyped_data() const { return buf_->data; } + template + T* typed_data() const { + assert(internal::NativeTypeToCApiDataType() == buf_->dtype && + "Template type must match the underlying buffer dtype"); + return reinterpret_cast(buf_->data); + } + private: const XLA_FFI_Buffer* buf_; }; diff --git a/third_party/xla/xla/ffi/api/ffi_test.cc b/third_party/xla/xla/ffi/api/ffi_test.cc index bea5176a560b6e..73fe75ed8247ea 100644 --- a/third_party/xla/xla/ffi/api/ffi_test.cc +++ b/third_party/xla/xla/ffi/api/ffi_test.cc @@ -130,11 +130,13 @@ TEST(FfiTest, DataTypeEnumValue) { EXPECT_EQ(encoded(PrimitiveType::TOKEN), encoded(DataType::TOKEN)); EXPECT_EQ(encoded(PrimitiveType::F8E5M2), encoded(DataType::F8E5M2)); + EXPECT_EQ(encoded(PrimitiveType::F8E4M3), encoded(DataType::F8E4M3)); EXPECT_EQ(encoded(PrimitiveType::F8E4M3FN), encoded(DataType::F8E4M3FN)); EXPECT_EQ(encoded(PrimitiveType::F8E4M3B11FNUZ), encoded(DataType::F8E4M3B11FNUZ)); EXPECT_EQ(encoded(PrimitiveType::F8E5M2FNUZ), encoded(DataType::F8E5M2FNUZ)); EXPECT_EQ(encoded(PrimitiveType::F8E4M3FNUZ), encoded(DataType::F8E4M3FNUZ)); + EXPECT_EQ(encoded(PrimitiveType::F8E3M4), encoded(DataType::F8E3M4)); } TEST(FfiTest, DataTypeByteWidth) { @@ -179,6 +181,8 @@ TEST(FfiTest, DataTypeByteWidth) { EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E5M2), ByteWidth(DataType::F8E5M2)); + EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3), + ByteWidth(DataType::F8E4M3)); EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3FN), ByteWidth(DataType::F8E4M3FN)); EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3B11FNUZ), @@ -187,6 +191,8 @@ TEST(FfiTest, DataTypeByteWidth) { ByteWidth(DataType::F8E5M2FNUZ)); EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E4M3FNUZ), ByteWidth(DataType::F8E4M3FNUZ)); + EXPECT_EQ(primitive_util::ByteWidth(PrimitiveType::F8E3M4), + ByteWidth(DataType::F8E3M4)); } TEST(FfiTest, ErrorEnumValue) { @@ -358,6 +364,8 @@ TEST(FfiTest, AnyBufferArgument) { auto handler = Ffi::Bind().Arg().To([&](auto buffer) { EXPECT_EQ(buffer.untyped_data(), storage.data()); + EXPECT_EQ(buffer.template typed_data(), + reinterpret_cast(storage.data())); EXPECT_EQ(buffer.dimensions().size(), 2); return Error::Success(); }); @@ -394,6 +402,8 @@ TEST(FfiTest, AnyBufferResult) { auto handler = Ffi::Bind().Ret().To([&](Result buffer) { EXPECT_EQ(buffer->untyped_data(), storage.data()); + EXPECT_EQ(buffer->template typed_data(), + reinterpret_cast(storage.data())); EXPECT_EQ(buffer->dimensions().size(), 2); return Error::Success(); }); @@ -449,6 +459,25 @@ TEST(FfiTest, WrongTypeBufferArgument) { HasSubstr("Wrong buffer dtype: expected F32 but got S32"))); } +TEST(FfiTest, WrongNumberOfArguments) { + CallFrameBuilder::AttributesBuilder attrs; + attrs.Insert("foo", 42); + attrs.Insert("bar", 43); + + CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0); + builder.AddAttributes(attrs.Build()); + auto call_frame = builder.Build(); + + auto handler = + Ffi::Bind().Attr("foo").To([](int foo) { return Error::Success(); }); + auto status = Call(*handler, call_frame); + + EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Wrong number of attributes"))); + EXPECT_THAT(status.message(), HasSubstr("foo")); + EXPECT_THAT(status.message(), HasSubstr("bar")); +} + TEST(FfiTest, TokenArgument) { CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); builder.AddBufferArg(se::DeviceMemoryBase(), PrimitiveType::TOKEN, @@ -822,10 +851,10 @@ TEST(FfiTest, AttrsAsDictionary) { } TEST(FfiTest, DictionaryAttr) { - CallFrameBuilder::FlatAttributesMap dict0; + CallFrameBuilder::AttributesMap dict0; dict0.try_emplace("i32", 42); - CallFrameBuilder::FlatAttributesMap dict1; + CallFrameBuilder::AttributesMap dict1; dict1.try_emplace("f32", 42.0f); CallFrameBuilder::AttributesBuilder attrs; @@ -864,7 +893,7 @@ TEST(FfiTest, DictionaryAttr) { } TEST(FfiTest, StructAttr) { - CallFrameBuilder::FlatAttributesMap dict; + CallFrameBuilder::AttributesMap dict; dict.try_emplace("i32", 42); dict.try_emplace("f32", 42.0f); @@ -977,7 +1006,7 @@ TEST(FfiTest, EnumAttr) { } TEST(FfiTest, WrongEnumAttrType) { - CallFrameBuilder::FlatAttributesMap dict; + CallFrameBuilder::AttributesMap dict; dict.try_emplace("i32", 42); CallFrameBuilder::AttributesBuilder attrs; diff --git a/third_party/xla/xla/ffi/attribute_map.cc b/third_party/xla/xla/ffi/attribute_map.cc index 33d756f6fcce2b..d774362eadf1fa 100644 --- a/third_party/xla/xla/ffi/attribute_map.cc +++ b/third_party/xla/xla/ffi/attribute_map.cc @@ -16,7 +16,10 @@ limitations under the License. #include "xla/ffi/attribute_map.h" #include +#include #include +#include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -27,122 +30,171 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "xla/ffi/call_frame.h" #include "tsl/platform/errors.h" - -using FlatAttribute = xla::ffi::CallFrameBuilder::FlatAttribute; -using FlatAttributesMap = xla::ffi::CallFrameBuilder::FlatAttributesMap; +#include "tsl/platform/statusor.h" namespace xla::ffi { -absl::StatusOr BuildAttributesMap( - mlir::DictionaryAttr dict) { - FlatAttributesMap attributes; - for (auto& kv : dict) { - std::string_view name = kv.getName().strref(); +static absl::StatusOr ConvertBoolAttr( + std::string_view name, mlir::BoolAttr boolean) { + return static_cast(boolean.getValue()); +} - auto boolean = [&](mlir::BoolAttr boolean) { - attributes[name] = static_cast(boolean.getValue()); - return absl::OkStatus(); - }; +static absl::StatusOr ConvertStringAttr( + std::string_view name, mlir::StringAttr str) { + return str.getValue().str(); +} - auto integer = [&](mlir::IntegerAttr integer) { - if (integer.getType().isUnsignedInteger()) { - switch (integer.getType().getIntOrFloatBitWidth()) { - case 8: - attributes[name] = static_cast(integer.getUInt()); - return absl::OkStatus(); - case 16: - attributes[name] = static_cast(integer.getUInt()); - return absl::OkStatus(); - case 32: - attributes[name] = static_cast(integer.getUInt()); - return absl::OkStatus(); - case 64: - attributes[name] = static_cast(integer.getUInt()); - return absl::OkStatus(); - default: - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported integer attribute bit width for attribute: ", - name)); - } - } else { - switch (integer.getType().getIntOrFloatBitWidth()) { - case 8: - attributes[name] = static_cast(integer.getInt()); - return absl::OkStatus(); - case 16: - attributes[name] = static_cast(integer.getInt()); - return absl::OkStatus(); - case 32: - attributes[name] = static_cast(integer.getInt()); - return absl::OkStatus(); - case 64: - attributes[name] = static_cast(integer.getInt()); - return absl::OkStatus(); - default: - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported integer attribute bit width for attribute: ", - name)); - } - } - }; +static absl::StatusOr ConvertIntegerAttr( + std::string_view name, mlir::IntegerAttr integer) { + if (integer.getType().isUnsignedInteger()) { + switch (integer.getType().getIntOrFloatBitWidth()) { + case 8: + return static_cast(integer.getUInt()); + case 16: + return static_cast(integer.getUInt()); + case 32: + return static_cast(integer.getUInt()); + case 64: + return static_cast(integer.getUInt()); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported integer attribute bit width for attribute: ", name)); + } + } else { + switch (integer.getType().getIntOrFloatBitWidth()) { + case 8: + return static_cast(integer.getInt()); + case 16: + return static_cast(integer.getInt()); + case 32: + return static_cast(integer.getInt()); + case 64: + return static_cast(integer.getInt()); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported integer attribute bit width for attribute: ", name)); + } + } +} + +static absl::StatusOr ConvertFloatAttr( + std::string_view name, mlir::FloatAttr fp) { + switch (fp.getType().getIntOrFloatBitWidth()) { + case 32: + return static_cast(fp.getValue().convertToFloat()); + case 64: + return static_cast(fp.getValue().convertToDouble()); + default: + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported float attribute bit width for attribute: ", name)); + } +} - auto fp = [&](mlir::FloatAttr fp) { - switch (fp.getType().getIntOrFloatBitWidth()) { +static absl::StatusOr ConvertArrayAttr( + std::string_view name, mlir::DenseArrayAttr arr) { + if (auto dense = mlir::dyn_cast(arr)) { + return dense.asArrayRef().vec(); + } else if (auto dense = mlir::dyn_cast(arr)) { + return dense.asArrayRef().vec(); + } else if (auto dense = mlir::dyn_cast(arr)) { + return dense.asArrayRef().vec(); + } else if (auto dense = mlir::dyn_cast(arr)) { + return dense.asArrayRef().vec(); + } else if (auto dense = mlir::dyn_cast(arr)) { + return dense.asArrayRef().vec(); + } else if (auto dense = mlir::dyn_cast(arr)) { + return dense.asArrayRef().vec(); + } else { + return absl::InvalidArgumentError( + absl::StrCat("Unsupported array element type for attribute: ", name)); + } +} + +template +static std::vector CopyDenseElementsToVec( + mlir::DenseIntOrFPElementsAttr arr) { + auto it = arr.getValues(); + return std::vector(it.begin(), it.end()); +} + +static absl::StatusOr ConvertDenseElementsAttr( + std::string_view name, mlir::DenseIntOrFPElementsAttr arr) { + auto type = arr.getElementType(); + if (type.isInteger()) { + if (type.isUnsignedInteger()) { + switch (type.getIntOrFloatBitWidth()) { + case 8: + return CopyDenseElementsToVec(arr); + case 16: + return CopyDenseElementsToVec(arr); case 32: - attributes[name] = static_cast(fp.getValue().convertToFloat()); - return absl::OkStatus(); + return CopyDenseElementsToVec(arr); case 64: - attributes[name] = - static_cast(fp.getValue().convertToDouble()); - return absl::OkStatus(); - default: - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported float attribute bit width for attribute: ", name)); + return CopyDenseElementsToVec(arr); } - }; - - auto arr = [&](mlir::DenseArrayAttr arr) { - if (auto dense = mlir::dyn_cast(arr)) { - attributes[name] = dense.asArrayRef().vec(); - return absl::OkStatus(); - } else if (auto dense = mlir::dyn_cast(arr)) { - attributes[name] = dense.asArrayRef().vec(); - return absl::OkStatus(); - } else if (auto dense = mlir::dyn_cast(arr)) { - attributes[name] = dense.asArrayRef().vec(); - return absl::OkStatus(); - } else if (auto dense = mlir::dyn_cast(arr)) { - attributes[name] = dense.asArrayRef().vec(); - return absl::OkStatus(); - } else if (auto dense = mlir::dyn_cast(arr)) { - attributes[name] = dense.asArrayRef().vec(); - return absl::OkStatus(); - } else if (auto dense = mlir::dyn_cast(arr)) { - attributes[name] = dense.asArrayRef().vec(); - return absl::OkStatus(); - } else { - return absl::InvalidArgumentError(absl::StrCat( - "Unsupported array element type for attribute: ", name)); + } else { + switch (type.getIntOrFloatBitWidth()) { + case 8: + return CopyDenseElementsToVec(arr); + case 16: + return CopyDenseElementsToVec(arr); + case 32: + return CopyDenseElementsToVec(arr); + case 64: + return CopyDenseElementsToVec(arr); } - }; + } + } else if (type.isIntOrFloat()) { + switch (type.getIntOrFloatBitWidth()) { + case 32: + return CopyDenseElementsToVec(arr); + case 64: + return CopyDenseElementsToVec(arr); + } + } + return absl::InvalidArgumentError( + absl::StrCat("Unsupported array element type for attribute: ", name)); +} + +static absl::StatusOr ConvertDictionaryAttr( + std::string_view name, mlir::DictionaryAttr dict) { + TF_ASSIGN_OR_RETURN(auto attrs, BuildAttributesMap(dict)); + return CallFrameBuilder::Dictionary{ + std::make_shared(std::move(attrs))}; +} - auto str = [&](mlir::StringAttr str) { - attributes[name] = str.getValue().str(); - return absl::OkStatus(); +absl::StatusOr BuildAttributesMap( + mlir::DictionaryAttr dict) { + CallFrameBuilder::AttributesMap attributes; + for (auto& kv : dict) { + std::string_view name = kv.getName().strref(); + mlir::Attribute value = kv.getValue(); + + // Wraps attribute conversion function into callable object. + auto convert_with = [&](auto converter_fn) { + return [&, fn = converter_fn](auto attr) -> absl::Status { + TF_ASSIGN_OR_RETURN(attributes[name], fn(name, attr)); + return absl::OkStatus(); + }; }; TF_RETURN_IF_ERROR( - llvm::TypeSwitch(kv.getValue()) - .Case(boolean) - .Case(integer) - .Case(fp) - .Case(arr) - .Case(str) + llvm::TypeSwitch(value) + .Case(convert_with(ConvertBoolAttr)) + .Case(convert_with(ConvertIntegerAttr)) + .Case(convert_with(ConvertFloatAttr)) + .Case(convert_with(ConvertArrayAttr)) + .Case( + convert_with(ConvertDenseElementsAttr)) + .Case(convert_with(ConvertStringAttr)) + .Case(convert_with(ConvertDictionaryAttr)) .Default([&](mlir::Attribute) { return absl::InvalidArgumentError(absl::StrCat( "Unsupported attribute type for attribute: ", name)); })); } + return attributes; } + } // namespace xla::ffi diff --git a/third_party/xla/xla/ffi/attribute_map.h b/third_party/xla/xla/ffi/attribute_map.h index cb9415ff3eb09b..43ad41888772cb 100644 --- a/third_party/xla/xla/ffi/attribute_map.h +++ b/third_party/xla/xla/ffi/attribute_map.h @@ -24,7 +24,7 @@ namespace xla::ffi { // Converts MLIR dictionary attribute attached to a custom call operation to a // custom call handler attributes that are forwarded to the FFI handler. -absl::StatusOr BuildAttributesMap( +absl::StatusOr BuildAttributesMap( mlir::DictionaryAttr dict); } // namespace xla::ffi diff --git a/third_party/xla/xla/ffi/call_frame.cc b/third_party/xla/xla/ffi/call_frame.cc index 655aa6a02f69a2..3fb2ac3c7786fa 100644 --- a/third_party/xla/xla/ffi/call_frame.cc +++ b/third_party/xla/xla/ffi/call_frame.cc @@ -65,20 +65,17 @@ CallFrameBuilder::AttributesBuilder::AttributesBuilder() = default; CallFrameBuilder::AttributesBuilder::~AttributesBuilder() = default; void CallFrameBuilder::AttributesBuilder::Insert(std::string name, - FlatAttribute attr) { - attrs_.try_emplace(std::move(name), FromFlatAttribute(std::move(attr))); + Attribute attr) { + attrs_.try_emplace(std::move(name), std::move(attr)); } void CallFrameBuilder::AttributesBuilder::Insert(std::string name, - FlatAttributesMap attrs) { - AttributesBuilder builder; - for (auto& [name, attr] : attrs) builder.Insert(name, std::move(attr)); - - auto attrs_map = std::make_unique(builder.Build()); - attrs_.try_emplace(std::move(name), Dictionary{std::move(attrs_map)}); + AttributesMap attrs) { + attrs_.try_emplace(std::move(name), + Dictionary{std::make_shared(attrs)}); } -void CallFrameBuilder::AttributesBuilder::Append(FlatAttributesMap attrs) { +void CallFrameBuilder::AttributesBuilder::Append(AttributesMap attrs) { for (auto& [name, attr] : attrs) Insert(name, std::move(attr)); } @@ -268,10 +265,12 @@ static XLA_FFI_DataType ToDataType(PrimitiveType primitive_type) { case PrimitiveType::C128: case PrimitiveType::TOKEN: case PrimitiveType::F8E5M2: + case PrimitiveType::F8E4M3: case PrimitiveType::F8E4M3FN: case PrimitiveType::F8E4M3B11FNUZ: case PrimitiveType::F8E5M2FNUZ: case PrimitiveType::F8E4M3FNUZ: + case PrimitiveType::F8E3M4: return static_cast(primitive_type); default: DCHECK(false) << "Unsupported primitive type " diff --git a/third_party/xla/xla/ffi/call_frame.h b/third_party/xla/xla/ffi/call_frame.h index 526723b3a92d80..0614bd750fd29e 100644 --- a/third_party/xla/xla/ffi/call_frame.h +++ b/third_party/xla/xla/ffi/call_frame.h @@ -81,9 +81,10 @@ class CallFrameBuilder { using AttributesMap = absl::flat_hash_map; // Dictionary is just a wrapper around AttributesMap. We need an indirection - // through `std::unique_ptr` to be able to define recursive `std::variant`. + // through `std::shared_ptr` to be able to define recursive `std::variant`. We + // use shared pointer to keep `AttributesMap` copyable. struct Dictionary { - std::unique_ptr attrs; + std::shared_ptr attrs; }; // A helper class to build call frame attributes. @@ -92,14 +93,14 @@ class CallFrameBuilder { AttributesBuilder(); ~AttributesBuilder(); + void Insert(std::string name, Attribute attr); + void Insert(std::string name, AttributesMap attrs); + void Append(AttributesMap attrs); + // This overload is only necessary to support older GCC versions. void Insert(std::string name, const char* attr) { - Insert(std::move(name), std::string(attr)); + Insert(std::move(name), Attribute{std::string(attr)}); } - void Insert(std::string name, FlatAttribute attr); - void Insert(std::string name, FlatAttributesMap attrs); - - void Append(FlatAttributesMap attrs); AttributesMap Build(); diff --git a/third_party/xla/xla/ffi/call_frame_test.cc b/third_party/xla/xla/ffi/call_frame_test.cc index 7b767bfb841af8..89d306455e6a19 100644 --- a/third_party/xla/xla/ffi/call_frame_test.cc +++ b/third_party/xla/xla/ffi/call_frame_test.cc @@ -130,14 +130,14 @@ void BM_AddBufferArg(benchmark::State& state) { void BM_AddAttributes(benchmark::State& state) { size_t num_attrs = state.range(0); - CallFrameBuilder::FlatAttributesMap flat_attrs; + CallFrameBuilder::AttributesMap attrs; for (size_t i = 0; i < num_attrs; ++i) { - flat_attrs.try_emplace(absl::StrCat("attr_", i), 42); + attrs.try_emplace(absl::StrCat("attr_", i), 42); } for (auto _ : state) { CallFrameBuilder::AttributesBuilder attrs_builder; - attrs_builder.Append(flat_attrs); + attrs_builder.Append(attrs); CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0); builder.AddAttributes(attrs_builder.Build()); diff --git a/third_party/xla/xla/ffi/ffi.h b/third_party/xla/xla/ffi/ffi.h index 17d57671e5170c..70fdc1f834b50f 100644 --- a/third_party/xla/xla/ffi/ffi.h +++ b/third_party/xla/xla/ffi/ffi.h @@ -68,6 +68,9 @@ struct ScratchAllocator {}; // binds `se::OwningScratchAllocator` struct CalledComputation {}; // binds `HloComputation*` struct IntraOpThreadPool {}; // binds `const Eigen::ThreadPoolDevice*` +template +struct PlatformStream {}; // binds a platform stream, e.g. `cudaStream_t` + //===----------------------------------------------------------------------===// // Arguments //===----------------------------------------------------------------------===// @@ -110,6 +113,13 @@ class AnyBuffer { void* untyped_data() const { return buf_->data; } + template + T* typed_data() const { + DCHECK(primitive_util::NativeToPrimitiveType() == element_type()) + << "Template type must match the underlying buffer dtype"; + return reinterpret_cast(buf_->data); + } + se::DeviceMemoryBase device_memory() const { return se::DeviceMemoryBase(untyped_data(), size_bytes()); } @@ -459,8 +469,11 @@ struct CtxDecoding { static std::optional Decode(const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx, - DiagnosticEngine&) { + DiagnosticEngine& diagnostic) { void* ptr = api->internal_api->XLA_FFI_INTERNAL_Stream_Get(ctx); + if (ABSL_PREDICT_FALSE(ptr == nullptr)) { + return diagnostic.Emit("Failed to decode stream"); + } return reinterpret_cast(ptr); } }; @@ -532,6 +545,22 @@ struct CtxDecoding { } }; +template +struct CtxDecoding> { + using Type = T; + static_assert(std::is_pointer_v, "platform stream type must be a pointer"); + + static std::optional Decode(const XLA_FFI_Api* api, + XLA_FFI_ExecutionContext* ctx, + DiagnosticEngine& diagnostic) { + if (auto stream = CtxDecoding::Decode(api, ctx, diagnostic)) { + return reinterpret_cast( + stream.value()->platform_specific_handle().stream); + } + return std::nullopt; + } +}; + //===----------------------------------------------------------------------===// // UserData //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/ffi/ffi_test.cc b/third_party/xla/xla/ffi/ffi_test.cc index 128bf997ae263f..483630c0a648ef 100644 --- a/third_party/xla/xla/ffi/ffi_test.cc +++ b/third_party/xla/xla/ffi/ffi_test.cc @@ -198,8 +198,10 @@ TEST(FfiTest, WrongNumAttrs) { auto status = Call(*handler, call_frame); - ASSERT_EQ(status.message(), - "Wrong number of attributes: expected 1 but got 2"); + EXPECT_THAT( + status, + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Wrong number of attributes: expected 1 but got 2"))); } TEST(FfiTest, BuiltinAttributes) { @@ -379,10 +381,10 @@ TEST(FfiTest, AttrsAsDictionary) { } TEST(FfiTest, DictionaryAttr) { - CallFrameBuilder::FlatAttributesMap dict0; + CallFrameBuilder::AttributesMap dict0; dict0.try_emplace("i32", 42); - CallFrameBuilder::FlatAttributesMap dict1; + CallFrameBuilder::AttributesMap dict1; dict1.try_emplace("f32", 42.0f); CallFrameBuilder::AttributesBuilder attrs; @@ -421,7 +423,7 @@ TEST(FfiTest, DictionaryAttr) { } TEST(FfiTest, StructAttr) { - CallFrameBuilder::FlatAttributesMap dict; + CallFrameBuilder::AttributesMap dict; dict.try_emplace("i32", 42); dict.try_emplace("f32", 42.0f); @@ -528,6 +530,8 @@ TEST(FfiTest, AnyBufferArgument) { auto fn = [&](AnyBuffer buffer) { EXPECT_EQ(buffer.element_type(), PrimitiveType::F32); EXPECT_EQ(buffer.untyped_data(), storage.data()); + EXPECT_EQ(buffer.typed_data(), + reinterpret_cast(storage.data())); AnyBuffer::Dimensions dimensions = buffer.dimensions(); EXPECT_EQ(dimensions.size(), 2); EXPECT_EQ(dimensions[0], 2); @@ -1071,6 +1075,21 @@ TEST(FfiTest, MetadataTraits) { EXPECT_EQ(metadata.api_version.minor_version, XLA_FFI_API_MINOR); } +// Use opaque struct to define a platform stream type just like platform +// stream for GPU backend (e.g. `CUstream_st` and `cudaStream_t`). +struct TestStreamSt; +using TestStream = TestStreamSt*; + +template <> +struct CtxBinding { + using Ctx = PlatformStream; +}; + +TEST(FfiTest, PlatformStream) { + // We only check that it compiles. + (void)Ffi::BindTo(+[](TestStream stream) { return absl::OkStatus(); }); +} + //===----------------------------------------------------------------------===// // Performance benchmarks are below. //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/ffi/type_id_registry.h b/third_party/xla/xla/ffi/type_id_registry.h index 116142b3de0f23..5672ac691e253b 100644 --- a/third_party/xla/xla/ffi/type_id_registry.h +++ b/third_party/xla/xla/ffi/type_id_registry.h @@ -20,7 +20,7 @@ limitations under the License. #include #include "absl/status/statusor.h" -#include "tsl/lib/gtl/int_type.h" +#include "xla/tsl/lib/gtl/int_type.h" namespace xla::ffi { diff --git a/third_party/xla/xla/fp_util_test.cc b/third_party/xla/xla/fp_util_test.cc index 36f0c5be9d5bde..3eb7c54f919b0a 100644 --- a/third_party/xla/xla/fp_util_test.cc +++ b/third_party/xla/xla/fp_util_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include #include "absl/base/casts.h" #include "absl/numeric/bits.h" #include "xla/bit_cast.h" @@ -111,21 +112,74 @@ INSTANTIATE_TEST_SUITE_P(DoublePrecisionInputs, FixedValueTest, 0x1.fffffffffffffp-127, 0x1.aaaaaaaaaaaaap-127)); -// Test F8E4M3 floating-point types (F8E4M3FN) +// Test F8E4M3 floating-point types (F8E4M3, F8E4M3FN) template class FP8E4M3DistanceTest : public ::testing::Test {}; -using F8E4M3Types = ::testing::Types; +using F8E4M3Types = ::testing::Types; TYPED_TEST_SUITE(FP8E4M3DistanceTest, F8E4M3Types); +TEST(FPDistanceTest, F8E3M4Distance) { + // a & b are equal + EXPECT_EQ(CalculateDistanceInFloats(tsl::float8_e3m4(8.0), + tsl::float8_e3m4(8.0)), + 0); + + // a & b have the same exponents + EXPECT_EQ(CalculateDistanceInFloats(tsl::float8_e3m4(8.0), + tsl::float8_e3m4(15.5)), + 15); + + // a & b have different exponents + EXPECT_EQ(CalculateDistanceInFloats(tsl::float8_e3m4(8.0), + tsl::float8_e3m4(6)), + 8); + + // 1 from 0 in the positive direction + EXPECT_EQ(CalculateDistanceInFloats( + std::numeric_limits::denorm_min(), + tsl::float8_e3m4(0)), + 1); + + // 1 from 0 in the negative direction + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::denorm_min(), + tsl::float8_e3m4(0)), + 1); + + // a & b have different signs + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::denorm_min(), + std::numeric_limits::denorm_min()), + 2); + + // 1 non denorm from 0 in the positive direction + EXPECT_EQ( + CalculateDistanceInFloats( + std::numeric_limits::min(), tsl::float8_e3m4(0)), + 16); + + // 1 non denorm from 0 in the negative direction + EXPECT_EQ( + CalculateDistanceInFloats( + -std::numeric_limits::min(), tsl::float8_e3m4(0)), + 16); + + // a & b have different signs + EXPECT_EQ(CalculateDistanceInFloats( + -std::numeric_limits::min(), + std::numeric_limits::min()), + 32); +} + TYPED_TEST(FP8E4M3DistanceTest, F8E4M3Distance) { // a & b are equal, distance should be 0 EXPECT_EQ( CalculateDistanceInFloats(TypeParam(8.0), TypeParam(8.0)), 0); // a & b have the same exponents - EXPECT_EQ(CalculateDistanceInFloats(TypeParam(8.0), TypeParam(13)), - 5); + EXPECT_EQ( + CalculateDistanceInFloats(TypeParam(8.0), TypeParam(15.0)), 7); // a & b have different exponents EXPECT_EQ( diff --git a/third_party/xla/xla/hlo/analysis/BUILD b/third_party/xla/xla/hlo/analysis/BUILD new file mode 100644 index 00000000000000..e640deafb2f018 --- /dev/null +++ b/third_party/xla/xla/hlo/analysis/BUILD @@ -0,0 +1,507 @@ +# Description: +# HLO analysis implementation. + +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") +load( + "//xla:xla.bzl", + "xla_cc_test", +) +load("//xla/tsl:tsl.bzl", "internal_visibility") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +cc_library( + name = "hlo_dfs_reachability", + srcs = ["hlo_dfs_reachability.cc"], + hdrs = ["hlo_dfs_reachability.h"], + deps = [ + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", + "@llvm-project//llvm:Support", + ], +) + +xla_cc_test( + name = "hlo_dfs_reachability_test", + srcs = ["hlo_dfs_reachability_test.cc"], + deps = [ + ":hlo_dfs_reachability", + "//xla:literal_util", + "//xla:shape_util", + "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:computation_placer_hdr", + "//xla/service:hlo_module_config", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:test_benchmark", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "hlo_reachability", + srcs = ["hlo_reachability.cc"], + hdrs = ["hlo_reachability.h"], + deps = [ + "//xla:types", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/types:span", + ], +) + +xla_cc_test( + name = "hlo_reachability_test", + srcs = ["hlo_reachability_test.cc"], + deps = [ + ":hlo_reachability", + "//xla:literal_util", + "//xla:shape_util", + "//xla:test", + "//xla:test_helpers", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:computation_placer", + "//xla/service:hlo_module_config", + "@com_google_absl//absl/random", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:test_benchmark", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "hlo_ordering", + srcs = ["hlo_ordering.cc"], + hdrs = ["hlo_ordering.h"], + deps = [ + ":hlo_dataflow_analysis", + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_reachability", + "//xla/service:call_graph", + "//xla/service:hlo_proto_cc", + "//xla/service:hlo_value", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + ], +) + +xla_cc_test( + name = "hlo_ordering_test", + size = "small", + srcs = ["hlo_ordering_test.cc"], + deps = [ + ":hlo_dataflow_analysis", + ":hlo_ordering", + "//xla:literal_util", + "//xla:shape_util", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:hlo_value", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "while_loop_analysis", + srcs = ["while_loop_analysis.cc"], + hdrs = ["while_loop_analysis.h"], + deps = [ + "//xla:comparison_util", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla/hlo/evaluator:hlo_evaluator", + "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_reachability", + "//xla/service:pattern_matcher", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + +xla_cc_test( + name = "while_loop_analysis_test", + srcs = ["while_loop_analysis_test.cc"], + deps = [ + ":while_loop_analysis", + "//xla:comparison_util", + "//xla:test", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "hlo_dataflow_analysis", + srcs = ["hlo_dataflow_analysis.cc"], + hdrs = ["hlo_dataflow_analysis.h"], + deps = [ + "//xla:shape_util", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:call_graph", + "//xla/service:hlo_phi_graph", + "//xla/service:hlo_value", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + ], +) + +xla_cc_test( + name = "hlo_dataflow_analysis_test", + srcs = ["hlo_dataflow_analysis_test.cc"], + deps = [ + ":hlo_dataflow_analysis", + ":hlo_ordering", + "//xla:comparison_util", + "//xla:literal_util", + "//xla:shape_util", + "//xla:test", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:flatten_call_graph", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_dce", + "//xla/service:hlo_value", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "hlo_value_semantics_analysis", + srcs = ["hlo_value_semantics_analysis.cc"], + hdrs = ["hlo_value_semantics_analysis.h"], + deps = [ + "//xla:shape_tree", + "//xla:shape_util", + "//xla:side_effect_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_value", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "hlo_value_semantics_analysis_test", + srcs = ["hlo_value_semantics_analysis_test.cc"], + deps = [ + ":hlo_value_semantics_analysis", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "hlo_replication_analysis", + srcs = ["hlo_replication_analysis.cc"], + hdrs = ["hlo_replication_analysis.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +xla_cc_test( + name = "hlo_replication_analysis_test", + srcs = ["hlo_replication_analysis_test.cc"], + deps = [ + ":hlo_replication_analysis", + "//xla:shape_util", + "//xla:types", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "hlo_liveness_analysis", + srcs = ["hlo_liveness_analysis.cc"], + hdrs = ["hlo_liveness_analysis.h"], + deps = [ + "//xla:shape_tree", + "//xla:shape_util", + "//xla:types", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:call_graph", + "//xla/service:hlo_value", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + ], +) + +xla_cc_test( + name = "hlo_liveness_analysis_test", + srcs = ["hlo_liveness_analysis_test.cc"], + deps = [ + ":hlo_liveness_analysis", + "//xla:literal", + "//xla:shape_util", + "//xla:status_macros", + "//xla:test", + "//xla:test_helpers", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "hlo_alias_analysis", + srcs = ["hlo_alias_analysis.cc"], + hdrs = ["hlo_alias_analysis.h"], + deps = [ + ":hlo_dataflow_analysis", + ":hlo_ordering", + "//xla:comparison_util", + "//xla:shape_util", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_buffer", + "//xla/service:hlo_value", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + ], +) + +xla_cc_test( + name = "hlo_alias_analysis_test", + srcs = ["hlo_alias_analysis_test.cc"], + deps = [ + ":hlo_alias_analysis", + ":hlo_ordering", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:test", + "//xla:test_helpers", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "//xla/service:flatten_call_graph", + "//xla/service:hlo_buffer", + "//xla/service:hlo_value", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "logical_buffer_analysis", + srcs = ["logical_buffer_analysis.cc"], + hdrs = ["logical_buffer_analysis.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:logical_buffer", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + ], +) + +cc_library( + name = "tuple_points_to_analysis", + srcs = ["tuple_points_to_analysis.cc"], + hdrs = ["tuple_points_to_analysis.h"], + deps = [ + ":logical_buffer_analysis", + "//xla:shape_tree", + "//xla:shape_util", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:logical_buffer", + "//xla/tsl/lib/gtl:compactptrset", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + ], +) + +xla_cc_test( + name = "tuple_points_to_analysis_test", + srcs = ["tuple_points_to_analysis_test.cc"], + deps = [ + ":tuple_points_to_analysis", + "//xla:literal_util", + "//xla:shape_util", + "//xla:test", + "//xla:test_helpers", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:logical_buffer", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "indexed_array_analysis", + srcs = ["indexed_array_analysis.cc"], + hdrs = ["indexed_array_analysis.h"], + deps = [ + "//xla:literal", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/evaluator:hlo_evaluator", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "indexed_array_analysis_test", + srcs = ["indexed_array_analysis_test.cc"], + deps = [ + ":indexed_array_analysis", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) diff --git a/third_party/xla/xla/service/hlo_alias_analysis.cc b/third_party/xla/xla/hlo/analysis/hlo_alias_analysis.cc similarity index 99% rename from third_party/xla/xla/service/hlo_alias_analysis.cc rename to third_party/xla/xla/hlo/analysis/hlo_alias_analysis.cc index c140f077ea52bf..7a489404bd6456 100644 --- a/third_party/xla/xla/service/hlo_alias_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_alias_analysis.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_alias_analysis.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" #include #include diff --git a/third_party/xla/xla/hlo/analysis/hlo_alias_analysis.h b/third_party/xla/xla/hlo/analysis/hlo_alias_analysis.h new file mode 100644 index 00000000000000..0d1462d4845305 --- /dev/null +++ b/third_party/xla/xla/hlo/analysis/hlo_alias_analysis.h @@ -0,0 +1,126 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_ANALYSIS_HLO_ALIAS_ANALYSIS_H_ +#define XLA_HLO_ANALYSIS_HLO_ALIAS_ANALYSIS_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" +#include "xla/hlo/analysis/hlo_ordering.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_buffer.h" +#include "xla/types.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Analysis which allocates HloBuffers to HloValues. +class HloAliasAnalysis { + public: + // The callgraph of the given HloModule must be flattened + // (xla::FlattenCallGraph) prior to running the analysis. + static absl::StatusOr> Run( + const HloModule* module, + const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr); + + std::string ToString() const; + + // Return the buffer containing the given value. + const HloBuffer& GetBufferContainingValue(const HloValue& value) const { + return *value_to_buffer_.at(&value); + } + HloBuffer& GetBufferContainingValue(const HloValue& value) { + return *value_to_buffer_.at(&value); + } + + // Return the HloBuffer with the given ID. + const HloBuffer& GetBuffer(HloBuffer::Id buffer_id) const { + return buffers_.at(buffer_id); + } + HloBuffer& GetBuffer(HloBuffer::Id buffer_id) { + return buffers_.at(buffer_id); + } + + // Returns the unique buffer at the given position. CHECK fails if the buffer + // set at that position does not contain exactly one buffer. + const HloBuffer& GetUniqueBufferAt(const HloInstruction* instruction, + const ShapeIndex& index = {}) const; + HloBuffer& GetUniqueBufferAt(const HloInstruction* instruction, + const ShapeIndex& index = {}); + + // Compute the set of buffers at the given instruction and index and return as + // a vector. This set is exactly the union of the buffers containing the + // HloValues at this position. + std::vector ComputeBuffersAt( + const HloInstruction* instruction, const ShapeIndex& index = {}) const; + + // Return a vector of all HloBuffers stabily sorted by HloBuffer::Id. This + // vector is lazily computed. Mutating operations on HloAliasAnalysis may + // invalidate the underlying vector requiring recomputation. + const std::vector& buffers() const { return buffers_; } + + // Returns the underlying dataflow analysis used by this alias analysis. + HloDataflowAnalysis& dataflow_analysis() const { return *dataflow_analysis_; } + + // Returns true if a buffer lives out of the module. + bool BufferLivesOut(const HloBuffer& buffer) const { + return live_out_buffers_.contains(&buffer); + } + + // Returns true if a hlo value lives out of the module. + bool ValueLivesOut(const HloValue& value) const { + return live_out_buffers_.contains(&GetBufferContainingValue(value)); + } + + std::vector LiveOutBuffers() const { + std::vector results(live_out_buffers_.begin(), + live_out_buffers_.end()); + absl::c_sort(results, HloBuffer::IdLessThan); + return results; + } + + protected: + explicit HloAliasAnalysis(const HloModule* module); + + // Verify various invariants of the alias analysis. + absl::Status Verify() const; + + const HloModule* module_; + + // A set of buffers that live out the module. + absl::flat_hash_set live_out_buffers_; + + // The underlying dataflow analysis used by this alias analysis. + std::unique_ptr dataflow_analysis_; + + // A map indicating which buffer a value is contained in. + absl::flat_hash_map value_to_buffer_; + + // A lazily constructed vector containing all HloBuffers sorted by + // HloBuffer::Id. + std::vector buffers_; +}; + +} // namespace xla + +#endif // XLA_HLO_ANALYSIS_HLO_ALIAS_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/hlo_alias_analysis_test.cc b/third_party/xla/xla/hlo/analysis/hlo_alias_analysis_test.cc similarity index 98% rename from third_party/xla/xla/service/hlo_alias_analysis_test.cc rename to third_party/xla/xla/hlo/analysis/hlo_alias_analysis_test.cc index ea687b640a58ef..00109570e14d18 100644 --- a/third_party/xla/xla/service/hlo_alias_analysis_test.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_alias_analysis_test.cc @@ -13,24 +13,32 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_alias_analysis.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" -#include #include - +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/analysis/hlo_ordering.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/utils/hlo_matchers.h" -#include "xla/literal.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/literal_util.h" #include "xla/service/flatten_call_graph.h" -#include "xla/service/hlo_ordering.h" -#include "xla/service/instruction_fusion.h" +#include "xla/service/hlo_buffer.h" +#include "xla/service/hlo_value.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { @@ -38,9 +46,9 @@ namespace { using ::testing::UnorderedElementsAre; -class HloAliasAnalysisTest : public HloTestBase { +class HloAliasAnalysisTest : public HloHardwareIndependentTestBase { protected: - HloAliasAnalysisTest() : HloTestBase() { + HloAliasAnalysisTest() : HloHardwareIndependentTestBase() { module_ = CreateNewVerifiedModule(); } diff --git a/third_party/xla/xla/service/hlo_dataflow_analysis.cc b/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis.cc similarity index 99% rename from third_party/xla/xla/service/hlo_dataflow_analysis.cc rename to third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis.cc index e83b284b107d35..2743067dd88026 100644 --- a/third_party/xla/xla/service/hlo_dataflow_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_dataflow_analysis.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include #include @@ -1895,7 +1895,11 @@ GetFusionInstructionInPlaceInputOutputPairs(const HloInstruction* instruction) { } } } - + // Skip bitcast + if (in_place_input_source != nullptr && + in_place_input_source->opcode() == HloOpcode::kBitcast) { + in_place_input_source = in_place_input_source->operand(0); + } if (in_place_input_source != nullptr && in_place_input_source->opcode() == HloOpcode::kParameter) { in_place_input_output_pairs.emplace_back( diff --git a/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis.h b/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis.h new file mode 100644 index 00000000000000..5509e137043333 --- /dev/null +++ b/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis.h @@ -0,0 +1,407 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Analysis for determining the possible set of values for all positions +// (instructions and ShapeIndexes) in the HLO module. Analysis is module-scoped +// tracking values across computation boundaries. + +#ifndef XLA_HLO_ANALYSIS_HLO_DATAFLOW_ANALYSIS_H_ +#define XLA_HLO_ANALYSIS_HLO_DATAFLOW_ANALYSIS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/call_graph.h" +#include "xla/service/hlo_phi_graph.h" +#include "xla/service/hlo_value.h" +#include "xla/shape_util.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Identifies one array input of an HloInstruction. +struct HloOperandIndex { + using MyTuple = std::tuple; + + template + friend H AbslHashValue(H h, const HloOperandIndex& hlo_operand_index) { + return H::combine(std::move(h), hlo_operand_index.ToTuple()); + } + + friend bool operator==(const HloOperandIndex& lhs, + const HloOperandIndex& rhs) { + return lhs.ToTuple() == rhs.ToTuple(); + } + + bool operator!=(const HloOperandIndex& other) const { + return !(*this == other); + } + + MyTuple ToTuple() const { + return std::make_tuple(operand_number, std::cref(operand_index)); + } + + // The operand number in which the array value appears. + int64_t operand_number; + + // The shape index within the operand in which the array value appears. + ShapeIndex operand_index; +}; + +// Analysis which identifies all HLO values and their uses in an HLO module. +class HloDataflowAnalysis { + public: + // Infrastructure for passing may-alias hints: HLO passes can populate the + // may-alias table. If an empty optional is returned, default rules are used. + // + // Must-alias rules (as defined by GetInPlaceInputOutputPairs) cannot be + // overriden using backend-specific overrides. + // + // The first parameter of the function should be the instruction, the + // second parameter should be an operand of the instruction. The third + // parameter should be the output index of the instruction. + using CanShareBuffer = std::function( + const HloInstruction* instr, const HloInstruction* operand, + const ShapeIndex& user_index)>; + + // Infrastructure for overriding whether an instruction defines a new value. + // + // The first parameter is the instruction and the second parameter is the + // output index. If an empty optional is used, default rules are used. If a + // ForwardedOperand object is returned, the value at the corresponding + // operand's index is used for the output, overriding all default logic. + struct ForwardedOperand { + int64_t operand_number; + ShapeIndex operand_index; + }; + using ForwardsValue = std::function( + const HloInstruction* instr, const ShapeIndex& index)>; + + // Runs dataflow analysis on the given module. Parameters: + // + // ssa_form : If true then new values are defined at the merge points of + // kWhile instructions. Abusing nomenclature somewhat, we call these "phi + // values". The merge is formed by the init value and loop backedge. The + // SSA form is minimal in that a new phi value is defined only if the + // merge point is reachable by multiple different values. The SSA form is + // also in loop-closed form in that no values defined inside of a loop + // (while body) is used outside of the loop. Example use of this ssa_form + // mode is to reason about live range interference of buffers. + // + // If ssa_form is false, then merge points do not define new + // values. Rather, the HloValueSet for the merge point contains the union + // of the merged HloValues. + // + // bitcast_defines_value : If true then the Bitcast HLO instruction defines + // a new HLO value in the analysis. If false then Bitcast forwards the + // value of its operand. + static absl::StatusOr> Run( + const HloModule& module, bool ssa_form = false, + bool bitcast_defines_value = false, + const CanShareBuffer& can_share_buffer = nullptr, + const ForwardsValue& forwards_value = nullptr, + absl::flat_hash_set execution_threads = {}); + + // Returns true if 'instruction' defines an HLO value at the given shape index + // of its output. + bool ValueIsDefinedAt(const HloInstruction* instruction, + const ShapeIndex& index = {}) const; + + // Returns the HloValue defined by 'instruction' at the given shape index of + // its output. + // + // Precondition: ValueIsDefinedAt is true for this instruction and index. + const HloValue& GetValueDefinedAt(const HloInstruction* instruction, + const ShapeIndex& index = {}) const; + HloValue& GetValueDefinedAt(const HloInstruction* instruction, + const ShapeIndex& index = {}); + + // Returns the InstructionValueSet for the given instruction. + const InstructionValueSet& GetInstructionValueSet( + const HloInstruction* instruction) const; + InstructionValueSet& GetInstructionValueSet( + const HloInstruction* instruction); + + // Returns all values that are contained in the output of this instruction in + // a flattened set. + HloValueSet GetFlattenedValueSet(const HloInstruction* instruction) const; + + // Returns the HloValueSet for the given instruction at the given index or the + // given position. + const HloValueSet& GetValueSet(const HloInstruction* instruction, + const ShapeIndex& index = {}) const; + const HloValueSet& GetValueSet(const HloPosition& position) const; + HloValueSet& GetValueSet(const HloPosition& position); + HloValueSet& GetValueSet(const HloInstruction* instruction, + const ShapeIndex& index = {}); + + // Returns the unique value in the HloValueSet at the given instruction and + // shape index. CHECKs if the value set does not contain a exactly one value. + const HloValue& GetUniqueValueAt(const HloInstruction* instruction, + const ShapeIndex& index = {}) const { + return GetValueSet(instruction, index).GetUniqueValue(); + } + HloValue& GetUniqueValueAt(const HloInstruction* instruction, + const ShapeIndex& index = {}) { + return GetValue(GetValueSet(instruction, index).GetUniqueValue().id()); + } + + // Returns the HloValue with the given Id. + const HloValue& GetValue(HloValue::Id value_id) const; + HloValue& GetValue(HloValue::Id value_id); + + // Returns the total number of HloValues. + int64_t value_count() const { return values_.size(); } + + // Returns a vector of all HloValues stabily sorted by HloValue::Id. + const std::vector& values() const { return values_vector_; } + + // Returns the call graph used for computing the dataflow. + const CallGraph& call_graph() const { return *call_graph_; } + + std::string ToString() const; + + // Returns true if 'user' cannot possibly use the buffer at 'index' in + // 'operand'. Returns false otherwise. + // + // 'operand' does not have to be an operand of 'user'. This can be the + // case with indirect uses. + bool DoesNotUseOperandBuffer(const HloInstruction* operand, + const ShapeIndex& index, + const HloInstruction* user) const; + + // Returns true if 'user' (at 'user_index') can share a buffer with its + // operand 'operand' (at 'operand_index'). Returns false otherwise. + // + // REQUIRES: 'operand' is an operand of 'user'. + bool CanShareOperandBufferWithUser(HloInstruction* operand, + const ShapeIndex& operand_index, + HloInstruction* user, + const ShapeIndex& user_index) const; + + const HloModule& module() const { return module_; } + + // Returns true if the operation is an in-place operation and its operand 0 + // must alias with the output. + static bool IsInPlaceOperation(HloOpcode opcode); + + // Returns true if the operation is the start/done of an asynchronous + // operation, where the buffer used/produced by the op needs to stay alive + // until the asynchronous operation completes. + static bool IsAsynchronousOperationStart(HloOpcode opcode); + static bool IsAsynchronousOperationDone(HloOpcode opcode); + + // Returns the pairs of inputs and outputs that must share the same buffer, + // according to the aliasing rules for that instruction. + // + // This function only considers array values as inputs and outputs, so + // when tuples are present it "sees through" to the array values inside. The + // HloUse describing the input parameter contains not only the operand number + // but also a shape index describing its position inside a nested tuple shape + // (if any). Similarly, the output parameter is described by a shape index + // into the nested tuple shape (if any) of the output value. + // + // For example, for this hypothetical op: + // %foo = (f32[1], (f32[2], f32[3])) + // op((f32[4], f32[5]) %arg0, f32[6] %arg1) + // + // ... the results can include any of the 3 * 3 = 9 possible pairs of + // input and output arrays. + static std::vector> + GetInPlaceInputOutputPairs(const HloInstruction* instruction); + + // Verifies various invariants of the dataflow analysis. + absl::Status Verify() const; + + private: + static bool AreTransitiveUsesElementwiseOrTuple(const HloInstruction* inst); + + HloDataflowAnalysis(const HloModule& module, bool ssa_form, + bool bitcast_defines_value, + const CanShareBuffer& can_share_buffer, + const ForwardsValue& forwards_value, + absl::flat_hash_set execution_threads); + + // 1. During value propagation (Propagate function), always create phi + // values once it see multiple inputs merging at the same point. It then + // records those phi values as well as their inputs in a phi graph. + // + // 2. Post value propagation, Dataflow analysis can then do certain + // optimization(OptimizePhiValues) on the phi graph to prune uncessary phi + // nodes. + // + // Note that this applies in SSA form, and Both of the functions are + // guaranteed to exit. + // + void OptimizePhiValues(); + + // Returns a new HloValue defined at the given instruction and shape index. + HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index, + bool is_phi); + + // Marks the HloValue with the given ID for deletion. + void MarkValueForDeletion(HloValue::Id value_id); + + // Deletes all HloValues marked for deletion. Should be called after + // propagation is complete. + void DeleteMarkedValues(); + + // Constructs and initializes the InstructionValueSets of all instructions to + // contain exactly the HloValues defined by each instruction. These values can + // then propagated throughout the HLO graph by calling Propagate. + absl::Status InitializeInstructionValueSets(); + + // Updates the value set of the given instruction based on the values flowing + // into the instruction (operands and cross-computation dataflow). + bool UpdateInstructionValueSet(HloInstruction* instruction); + + // Updates the value set for a particular instruction type. Returns whether + // the instruction value set changed. + bool UpdateBitcastValueSet(HloInstruction* bitcast); + bool UpdateCallValueSet(HloInstruction* call); + bool UpdateConditionalValueSet(HloInstruction* conditional); + bool UpdateCopyValueSet(HloInstruction* copy); + bool UpdateCustomCallValueSet(HloInstruction* custom_call); + bool UpdateDomainValueSet(HloInstruction* domain); + bool UpdateGetTupleElementValueSet(HloInstruction* gte); + bool UpdateParameterValueSet(HloInstruction* parameter); + // Async op propagation rules: + // - Operand of async-start to parameter of async wrapped computation and at + // index {0, operand_number} of async-start and async-update outputs. + // - Root of async wrapped computation to index {1} of async-start and + // async-update and index {} of async-done. + // - The contexts in indices {2+} of async-start to the same indices of + // async-update. + // + // As a result of this, the operands/outputs of async-start and async-done + // instructions share the same values as the parameters/roots of the async + // wrapped computation. + bool UpdateAsyncStartValueSet(HloInstruction* async_start); + bool UpdateAsyncUpdateValueSet(HloInstruction* async_update); + bool UpdateAsyncDoneValueSet(HloInstruction* async_done); + bool UpdateCopyStartValueSet(HloInstruction* copy_start); + bool UpdateCopyDoneValueSet(HloInstruction* copy_done); + bool UpdateOptimizationBarrierValueSet(HloInstruction* barrier); + bool UpdateRecvDoneValueSet(HloInstruction* recv_done); + bool UpdateSendValueSet(HloInstruction* send); + bool UpdateTupleValueSet(HloInstruction* tuple); + bool UpdateWhileValueSet(HloInstruction* xla_while); + bool UpdateAddDependencyValueSet(HloInstruction* add_dependency); + bool UpdateAllGatherStartValueSet(HloInstruction* all_gather_start); + bool UpdateAllGatherDoneValueSet(HloInstruction* all_gather_done); + bool UpdateAllReduceDoneValueSet(HloInstruction* all_reduce_done); + bool UpdateCollectivePermuteStartValueSet( + HloInstruction* collective_permute_start); + bool UpdateCollectivePermuteDoneValueSet( + HloInstruction* collective_permute_done); + + // Propagates the dataflow through the module. In particular, it propagates + // the HloValueSet from its defining instruction to the users of the + // instructions. + void Propagate(); + + // Returns the result of the SSA Phi function applied to the given inputs at + // the given instruction. + bool Phi(HloInstruction* instruction, + absl::Span inputs); + + // Updates the positions of the HloValues in the output of the given + // instruction. This should be called after the instruction value set of + // 'instruction' has been changed. 'prev_value_set' must point to the previous + // state of the value set prior to the change. 'prev_value_set' may be null if + // this is the first time positions are being computed. The previous state is + // necessary to efficiently remove positions which have been eliminated due to + // changes in the instructions' InstructionValueSet. + void UpdatePositionsOfValuesAt( + HloInstruction* instruction, const InstructionValueSet& new_value_set, + const InstructionValueSet* prev_value_set = nullptr); + + const HloModule& module_; + const absl::flat_hash_set execution_threads_; + const bool ssa_form_; + const bool bitcast_defines_value_; + + std::unique_ptr call_graph_; + + // The map of all HloValues in the module. We pass around pointers to the + // mapped HloValues, so the underlying container must keep them valid despite + // mutations touching other map entries. + absl::flat_hash_map> values_; + + // A map from instruction to InstructionValueSet. + absl::flat_hash_map> + value_sets_; + + // Values marked for deletion during construction. We don't delete them + // immediately because references to them may remain in ValueSets temporarily + // during propagation. After construction, these values are deleted. + std::vector value_ids_to_delete_; + + // A vector containing all HloValues sorted by HloValue::Id. + std::vector values_vector_; + + // The Id to use for the next HloValue. + HloValue::Id next_value_id_ = 0; + + // An explicit graph holding phi values and edges. + PhiGraph phi_graph_; + + // Backend specific function that decides whether an instruction can share + // a buffer with its operand. + CanShareBuffer can_share_buffer_ = nullptr; + + ForwardsValue forwards_value_ = nullptr; +}; + +// Removes layers of tuple indirection introduced via 'tuple' and +// 'get-tuple-element' instructions to more directly identify the source of the +// given HLO value (identified by the given `ShapeIndex` into the output of the +// given `HloInstruction`). +// +// e.g. for the following: +// %x = some-op(...) +// %foo = get-tuple-element(%x), index=0 +// %bar = tuple(%y, %foo) +// +// ... FollowTupleIndirection(%bar, {1}) == {%x, {0}} (output 1 of 'bar' comes +// from output 0 of %x). +// +// Note that all 'tuple' instructions are followed before all +// 'get-tuple-element' instructions are followed. This is because it is assumed +// that tupling a value and then extracting it from the tuple again will not +// occur in properly-optimized IR. +std::pair FollowTupleIndirection( + const HloInstruction* instruction, ShapeIndex operand_index); + +} // namespace xla + +#endif // XLA_HLO_ANALYSIS_HLO_DATAFLOW_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/hlo_dataflow_analysis_test.cc b/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis_test.cc similarity index 98% rename from third_party/xla/xla/service/hlo_dataflow_analysis_test.cc rename to third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis_test.cc index f7734f40fe3713..ea28bc9d7aa48f 100644 --- a/third_party/xla/xla/service/hlo_dataflow_analysis_test.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_dataflow_analysis_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_dataflow_analysis.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include #include @@ -27,20 +27,20 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "xla/comparison_util.h" +#include "xla/hlo/analysis/hlo_ordering.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/literal_util.h" #include "xla/service/flatten_call_graph.h" #include "xla/service/hlo_creation_utils.h" #include "xla/service/hlo_dce.h" -#include "xla/service/hlo_ordering.h" #include "xla/service/hlo_value.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -55,7 +55,7 @@ using ::testing::UnorderedElementsAre; // Test is parameterized on a bool which is whether the dataflow analysis is // performed with SSA form. -class HloDataflowAnalysisTest : public HloTestBase, +class HloDataflowAnalysisTest : public HloHardwareIndependentTestBase, public ::testing::WithParamInterface { protected: HloDataflowAnalysisTest() : module_(CreateNewVerifiedModule()) {} @@ -2155,7 +2155,7 @@ std::unique_ptr RunAnalysis( .value(); } -using DoesNotUseOperandBufferTest = HloTestBase; +using DoesNotUseOperandBufferTest = HloHardwareIndependentTestBase; TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) { auto builder = HloComputation::Builder(TestName()); @@ -2267,7 +2267,7 @@ TEST_F(DoesNotUseOperandBufferTest, IndirectUses) { dataflow_analysis->DoesNotUseOperandBuffer(tuple_param, {0}, fusion)); } -using CanShareOperandBufferWithUserTest = HloTestBase; +using CanShareOperandBufferWithUserTest = HloHardwareIndependentTestBase; TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { auto builder = HloComputation::Builder(TestName()); @@ -3090,7 +3090,7 @@ TEST_F(CanShareOperandBufferWithUserTest, MultipleConcatenates) { fusion, {2})); } -using GetInPlaceInputOutputPairsTest = HloTestBase; +using GetInPlaceInputOutputPairsTest = HloHardwareIndependentTestBase; TEST_F(GetInPlaceInputOutputPairsTest, DUS) { const char* kModule = R"( @@ -3440,5 +3440,39 @@ TEST_F(GetInPlaceInputOutputPairsTest, DUSOutputFusionWithCollective) { EXPECT_EQ(in_place_pairs, expected_pairs); } +TEST_F(GetInPlaceInputOutputPairsTest, DUSLoopFusionWithBitcast) { + const char* kModule = R"( + HloModule DUSLoopFusionWithBitcast + + fused_dynamic_update_slice { + param_1.133 = bf16[32,1,4096,18432]{2,3,1,0} parameter(1) + bitcast.8539.1 = bf16[32,1,18432,4096]{3,2,1,0} bitcast(param_1.133) + param_0.168 = bf16[1,4096,18432]{1,0,2} parameter(0) + bitcast.8543.1 = bf16[1,1,18432,4096]{3,2,1,0} bitcast(param_0.168) + param_2.98 = s32[] parameter(2) + constant_2153_8 = s32[] constant(0) + compare.753.6 = pred[] compare(param_2.98, constant_2153_8), direction=LT + constant_2154_12 = s32[] constant(96) + add.950.6 = s32[] add(param_2.98, constant_2154_12) + select.883.5 = s32[] select(compare.753.6, add.950.6, param_2.98) + ROOT dynamic-update-slice.178.1 = bf16[32,1,18432,4096]{3,2,1,0} dynamic-update-slice(bitcast.8539.1, bitcast.8543.1, select.883.5, constant_2153_8, constant_2153_8, /*index=5*/constant_2153_8) + } + + ENTRY entry { + p0 = bf16[1,4096,18432]{1,0,2} parameter(0) + p1 = bf16[32,1,4096,18432]{2,3,1,0} parameter(1) + p2 = s32[] parameter(2) + ROOT fusion1 = bf16[32,1,18432,4096]{3,2,1,0} fusion(p0, p1, p2), kind=kLoop, calls=fused_dynamic_update_slice + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModule)); + HloInstruction* fusion = module->entry_computation()->root_instruction(); + auto in_place_pairs = HloDataflowAnalysis::GetInPlaceInputOutputPairs(fusion); + std::vector> expected_pairs; + // p1 should be aliased with fusion1 + expected_pairs.push_back({HloOperandIndex{1, {}}, {}}); + EXPECT_EQ(in_place_pairs, expected_pairs); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/hlo/ir/hlo_dfs_reachability.cc b/third_party/xla/xla/hlo/analysis/hlo_dfs_reachability.cc similarity index 98% rename from third_party/xla/xla/hlo/ir/hlo_dfs_reachability.cc rename to third_party/xla/xla/hlo/analysis/hlo_dfs_reachability.cc index c831f31cec03f1..cf6e495424e60b 100644 --- a/third_party/xla/xla/hlo/ir/hlo_dfs_reachability.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_dfs_reachability.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/hlo/ir/hlo_dfs_reachability.h" +#include "xla/hlo/analysis/hlo_dfs_reachability.h" #include #include diff --git a/third_party/xla/xla/hlo/analysis/hlo_dfs_reachability.h b/third_party/xla/xla/hlo/analysis/hlo_dfs_reachability.h new file mode 100644 index 00000000000000..775dd8eb824d24 --- /dev/null +++ b/third_party/xla/xla/hlo/analysis/hlo_dfs_reachability.h @@ -0,0 +1,61 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_ANALYSIS_HLO_DFS_REACHABILITY_H_ +#define XLA_HLO_ANALYSIS_HLO_DFS_REACHABILITY_H_ + +#include +#include + +#include "llvm/ADT/DenseMap.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" + +namespace xla { + +// A simple DFS-based reachability analysis for HLO instructions. +// +// When the class is created, the instructions are ordered in a defs-before-uses +// topological order. +// The reachability query runs a DFS from the destination node (going up through +// operands / control predecessors), and stops when the instruction's index in +// the defs-before-uses list is before the source node. As the reachability is +// tested for nodes that are close to each other, this optimization works well, +// and the time is dominated by the post-order sort. +class HloDfsReachability { + public: + // Returns true iff the instruction was present in the computation passed to + // Build(). The calling code may want to still use the class after the + // computation is modified, if it's known that the def-before-use order is + // still preserved. + bool IsPresent(const HloInstruction* instruction) const; + // Returns true iff there is a path (with edges being users and control + // successors) from 'from' to 'to'. (i.e. path from definitions to uses; from + // producers to consumers) + bool IsReachable(const HloInstruction* from, const HloInstruction* to) const; + // Returns true iff either `a` is reachable from `b` or `b` is reachable from + // `a`. + bool IsConnected(const HloInstruction* a, const HloInstruction* b) const; + static std::unique_ptr Build( + const HloComputation* computation); + + private: + // LLVM dense map shows ~10-20% speedup compared to absl::flat_hash_map. + llvm::DenseMap instruction_to_idx_; +}; + +} // namespace xla + +#endif // XLA_HLO_ANALYSIS_HLO_DFS_REACHABILITY_H_ diff --git a/third_party/xla/xla/service/hlo_dfs_reachability_test.cc b/third_party/xla/xla/hlo/analysis/hlo_dfs_reachability_test.cc similarity index 94% rename from third_party/xla/xla/service/hlo_dfs_reachability_test.cc rename to third_party/xla/xla/hlo/analysis/hlo_dfs_reachability_test.cc index 9bc77c75f42ea1..52f8ccc249ed70 100644 --- a/third_party/xla/xla/service/hlo_dfs_reachability_test.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_dfs_reachability_test.cc @@ -13,24 +13,31 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/hlo/ir/hlo_dfs_reachability.h" +#include "xla/hlo/analysis/hlo_dfs_reachability.h" #include -#include +#include +#include #include -#include "absl/random/random.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/literal_util.h" +#include "xla/service/computation_placer.h" +#include "xla/service/hlo_module_config.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/test.h" -#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/status.h" #include "tsl/platform/test_benchmark.h" namespace xla { namespace { -class HloDfsReachabilityTest : public HloTestBase {}; +class HloDfsReachabilityTest : public HloHardwareIndependentTestBase {}; TEST_F(HloDfsReachabilityTest, NonTrivialReachability) { // Test reachability of a non-trivial computation: diff --git a/third_party/xla/xla/service/hlo_liveness_analysis.cc b/third_party/xla/xla/hlo/analysis/hlo_liveness_analysis.cc similarity index 99% rename from third_party/xla/xla/service/hlo_liveness_analysis.cc rename to third_party/xla/xla/hlo/analysis/hlo_liveness_analysis.cc index ff534fb1d4ecc1..16152cbc875051 100644 --- a/third_party/xla/xla/service/hlo_liveness_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_liveness_analysis.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_liveness_analysis.h" +#include "xla/hlo/analysis/hlo_liveness_analysis.h" #include #include diff --git a/third_party/xla/xla/hlo/analysis/hlo_liveness_analysis.h b/third_party/xla/xla/hlo/analysis/hlo_liveness_analysis.h new file mode 100644 index 00000000000000..40d991a200bf8e --- /dev/null +++ b/third_party/xla/xla/hlo/analysis/hlo_liveness_analysis.h @@ -0,0 +1,67 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_ANALYSIS_HLO_LIVENESS_ANALYSIS_H_ +#define XLA_HLO_ANALYSIS_HLO_LIVENESS_ANALYSIS_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/call_graph.h" +#include "xla/service/hlo_value.h" +#include "xla/shape_tree.h" +#include "xla/shape_util.h" + +namespace xla { + +// Analysis which identifies all live {HloInstruction, ShapeIndex} pairs in +// an HLO module. +// +// HloLivenessAnalysis marks the shape index of each live output of each +// instruction in the module, by propagating live shape index information +// from an instruction to its called computations and operands. +class HloLivenessAnalysis { + public: + // Maps from an HloInstruction to its live/dead output shape indices. + using HloIndexMap = absl::flat_hash_map>>; + + // Runs liveness analysis on 'module'. Returns HloLivenessAnalysis object + // which exports liveness for each {HloInstruction, ShapeIndex} in 'module'. + static absl::StatusOr> Run( + const HloModule& module); + + // Returns true if output of 'instruction' at 'shape_index' is live. + // Returns false otherwise. + bool IsLive(const HloInstruction* instruction, + const ShapeIndex& shape_index) const; + + private: + HloLivenessAnalysis(const HloModule& module); + + void RunAnalysis(); + + const HloModule& module_; + std::unique_ptr call_graph_; + HloIndexMap live_index_map_; +}; + +} // namespace xla + +#endif // XLA_HLO_ANALYSIS_HLO_LIVENESS_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/hlo_liveness_analysis_test.cc b/third_party/xla/xla/hlo/analysis/hlo_liveness_analysis_test.cc similarity index 99% rename from third_party/xla/xla/service/hlo_liveness_analysis_test.cc rename to third_party/xla/xla/hlo/analysis/hlo_liveness_analysis_test.cc index 4a132059878c31..436f5dedfef321 100644 --- a/third_party/xla/xla/service/hlo_liveness_analysis_test.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_liveness_analysis_test.cc @@ -13,23 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_liveness_analysis.h" +#include "xla/hlo/analysis/hlo_liveness_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/literal.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/test.h" #include "xla/test_helpers.h" -#include "xla/tests/hlo_test_base.h" #include "tsl/platform/logging.h" #include "tsl/platform/test.h" namespace xla { namespace { -class HloLivenessAnalysisTest : public HloTestBase { +class HloLivenessAnalysisTest : public HloHardwareIndependentTestBase { protected: HloLivenessAnalysisTest() {} diff --git a/third_party/xla/xla/service/hlo_ordering.cc b/third_party/xla/xla/hlo/analysis/hlo_ordering.cc similarity index 99% rename from third_party/xla/xla/service/hlo_ordering.cc rename to third_party/xla/xla/hlo/analysis/hlo_ordering.cc index 388de97291fab1..79b39a6f60de6d 100644 --- a/third_party/xla/xla/service/hlo_ordering.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_ordering.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_ordering.h" +#include "xla/hlo/analysis/hlo_ordering.h" #include #include diff --git a/third_party/xla/xla/hlo/analysis/hlo_ordering.h b/third_party/xla/xla/hlo/analysis/hlo_ordering.h new file mode 100644 index 00000000000000..644c3881fd2233 --- /dev/null +++ b/third_party/xla/xla/hlo/analysis/hlo_ordering.h @@ -0,0 +1,243 @@ +/* Copyright 2016 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_ANALYSIS_HLO_ORDERING_H_ +#define XLA_HLO_ANALYSIS_HLO_ORDERING_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_reachability.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/service/call_graph.h" +#include "xla/service/hlo.pb.h" +#include "xla/service/hlo_value.h" +#include "xla/types.h" + +namespace xla { + +// Base class for describing a partial ordering of HLO instructions. Used to +// determine live range overlap of HLO instruction output buffers. +class HloOrdering { + public: + explicit HloOrdering(const HloModule* module) + : module_(module), call_graph_(CallGraph::Build(module)) {} + virtual ~HloOrdering() = default; + + // Specify the ordering constraints between a pair of instructions a and b. + enum class ExecutionConstraint { + // Indicate a and b are the same instruction; + kIsSame, + // Indicate a runs before b starts; + kRunBeforeStart, + // Indicate a runs before b ends but after b starts, e.g., when b is a + // conditional or while loop; + kRunBeforeEnd, + // Only one of a or b runs each time their common ancestor is evaluated, + // and a is in an earlier branch than b. + kRunExclusiveBefore, + // Only one of a or b runs each time, and a is in a later branch than b. + kRunExclusiveAfter, + // Indicate a runs after b ends. + kRunAfter, + // An order cannot be detrermined as a and b do not have a common ancestor. + kUnordered, + }; + // Return the execution constraint between a and b. + HloOrdering::ExecutionConstraint GetExecutionConstraint( + const HloInstruction* a, const HloInstruction* b) const; + + // Returns true if instruction 'a' executes before instruction 'b'. This is + // not reflexive, that is, an instruction does not execute before itself. + bool ExecutesBefore(const HloInstruction* a, const HloInstruction* b) const; + + // Returns whether the value 'a' is defined before the value 'b' under the + // given ordering. + bool IsDefinedBefore(const HloValue& a, const HloValue& b) const; + + // Returns whether the given use is before the given value definition under + // the given ordering. Set use_is_always_before_def_in_same_instr to false if + // you want the analysis to always consider a use at an instruction's operand + // to be strictly before that instructions definition. The configuration needs + // to be false when result will be used to remove unnecessary copy + // instructions, due to additional buffer sharing constraints. + bool UsesBeforeValueDefinition( + absl::Span uses, const HloValue& value, + const HloDataflowAnalysis& dataflow, + bool use_is_always_before_def_in_same_instr = false) const; + // Returns whether the given values interfere. Two values interfere if they + // may both be simultaneously live. + bool MayInterfere(const HloValue& a, const HloValue& b, + const HloDataflowAnalysis& dataflow) const; + + // Returns true if the live range of the given value 'a' is strictly before + // the live range of value 'b' using the given HLO ordering. + bool LiveRangeStrictlyBefore( + const HloValue& a, const HloValue& b, const HloDataflowAnalysis& dataflow, + bool use_is_always_before_def_in_same_instr = false) const; + + // Returns the sequential instruction order for the given computation, or + // nullptr if the computation does not have a sequential ordering. + virtual const HloInstructionSequence* SequentialOrder( + const HloComputation& computation) const = 0; + + // Return the call graph of the module used to compute ordering. + const CallGraph& call_graph() const { return *call_graph_; } + + virtual std::string ToString() const = 0; + + protected: + // Returns true if instruction 'a' executes before instruction 'b'. + // Precondition: 'a' and 'b' are in the same computation. + // + // Derived classes should implement this method for determining order of + // instructions in the same computation. ExecutesBefore() analyzes the + // callgraph and uses this method to determine ordering of instructions in + // different computations. + virtual bool ExecutesBeforeInSameComputation( + const HloInstruction* a, const HloInstruction* b) const = 0; + + const HloModule* module_; + + std::unique_ptr call_graph_; +}; + +// Base class for partial orderings implemented by a map of predecessors for +// each instruction. Subclasses should fill in predecessors_. +class PredecessorHloOrdering : public HloOrdering { + public: + ~PredecessorHloOrdering() override = default; + + // Returns nullptr indicating the computation does not have a sequential + // ordering. + const HloInstructionSequence* SequentialOrder( + const HloComputation& computation) const override { + return nullptr; + } + + HloReachabilityMap& reachability_map(const HloComputation* computation) { + return *predecessors_.at(computation); + } + const HloReachabilityMap& reachability_map( + const HloComputation* computation) const { + return *predecessors_.at(computation); + } + + protected: + explicit PredecessorHloOrdering(const HloModule* module); + std::string ToStringHelper(const std::string& name) const; + + bool ExecutesBeforeInSameComputation(const HloInstruction* a, + const HloInstruction* b) const override; + + // For each computation in the module, this is the set of the instruction's + // predecessors. An instruction is an element of its own predecessor set. + // + // Subclasses should fill this in to define the desired ordering. + absl::flat_hash_map> + predecessors_; +}; + +// An HLO ordering based on data dependencies in the HLO graph. In this partial +// order, instruction A executes before instruction B only if there is a path +// from A to B in the HLO graph. For example, given the following graph: +/* + param + / \ + negate exp + \ / + add +*/ +// DependencyHloOrdering gives the following executes-before relations: +// param executes before negate, exp, and add +// negate executes before add +// exp executes before add +// add executes before nothing +// negate and exp are not ordered because the dependencies allow either to +// execute before the other (or in parallel). DependencyHloOrdering ordering +// allows maximum parallelism and enables any execution order which satisfies +// data dependencies. This requires pessimistic assumptions about buffer live +// ranges and can result in more memory used than more constrained orderings. +class DependencyHloOrdering : public PredecessorHloOrdering { + public: + explicit DependencyHloOrdering(const HloModule* module); + ~DependencyHloOrdering() override = default; + + std::string ToString() const override; +}; + +// An HLO ordering based on a total order of instructions in each computation. +// The computation total order is a sequencing of all of its instructions in +// the computation (eg, {inst0, inst1, inst2,...}) as in single-threaded +// execution. For example, given the following HLO graph: +/* + param + / \ + negate exp + \ / + add +*/ +// and the following sequence: +// +// {param, negate, exp, add} +// +// SequentialHloOrdering gives the following executes-before relations: +// param executes before negate, exp, and add +// negate executes before exp and add +// exp executes before add +// add executes before nothing +// This is more constrained than DependencyHloOrdering in this example because +// negate and exp are ordered (negate before exp). This enables param to share +// the same buffer as exp (param buffer is dead after exp). Generally, this +// ordering enables more buffer sharing (reduced memory usage) because buffer +// interference is reduced relative to DependencyHloOrdering. +class SequentialHloOrdering : public HloOrdering { + public: + explicit SequentialHloOrdering(const HloSchedule& schedule); + explicit SequentialHloOrdering(HloSchedule&& schedule); + ~SequentialHloOrdering() override = default; + + // Returns the sequential instruction order for the given computation. + const HloInstructionSequence* SequentialOrder( + const HloComputation& computation) const override; + + std::string ToString() const override; + + protected: + void Initialize(); + + bool ExecutesBeforeInSameComputation(const HloInstruction* a, + const HloInstruction* b) const override; + + const HloSchedule schedule_; + + // The position of every instruction in the HLO module in its respective + // computation sequence (a value of zero indicates the instruction is first in + // the sequence, etc). Instructions from all computations are contained in + // this map so more than one instruction may have the same position + // value. This is not a problem because ExecutesBefore also verifies + // instructions are in the same computation. + absl::flat_hash_map order_position_; +}; + +} // namespace xla + +#endif // XLA_HLO_ANALYSIS_HLO_ORDERING_H_ diff --git a/third_party/xla/xla/service/hlo_ordering_test.cc b/third_party/xla/xla/hlo/analysis/hlo_ordering_test.cc similarity index 99% rename from third_party/xla/xla/service/hlo_ordering_test.cc rename to third_party/xla/xla/hlo/analysis/hlo_ordering_test.cc index c0b1dc9c0c6bb7..488953b6ba66c1 100644 --- a/third_party/xla/xla/service/hlo_ordering_test.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_ordering_test.cc @@ -13,21 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_ordering.h" +#include "xla/hlo/analysis/hlo_ordering.h" #include #include #include #include "absl/strings/string_view.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" -#include "xla/service/hlo_dataflow_analysis.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/literal_util.h" #include "xla/service/hlo_value.h" #include "xla/shape_util.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/xla_data.pb.h" @@ -36,7 +37,7 @@ limitations under the License. namespace xla { namespace { -class HloOrderingTest : public HloTestBase {}; +class HloOrderingTest : public HloHardwareIndependentTestBase {}; TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { // Tests the ordering of instructions in different computations using the diff --git a/third_party/xla/xla/hlo/ir/hlo_reachability.cc b/third_party/xla/xla/hlo/analysis/hlo_reachability.cc similarity index 99% rename from third_party/xla/xla/hlo/ir/hlo_reachability.cc rename to third_party/xla/xla/hlo/analysis/hlo_reachability.cc index cdc1b664ef4380..7123abbb1a73ce 100644 --- a/third_party/xla/xla/hlo/ir/hlo_reachability.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_reachability.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/hlo/ir/hlo_reachability.h" +#include "xla/hlo/analysis/hlo_reachability.h" #include #include diff --git a/third_party/xla/xla/hlo/analysis/hlo_reachability.h b/third_party/xla/xla/hlo/analysis/hlo_reachability.h new file mode 100644 index 00000000000000..6c895c38a5a160 --- /dev/null +++ b/third_party/xla/xla/hlo/analysis/hlo_reachability.h @@ -0,0 +1,221 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_ANALYSIS_HLO_REACHABILITY_H_ +#define XLA_HLO_ANALYSIS_HLO_REACHABILITY_H_ + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/types.h" + +namespace xla { + +// A class for representing reachability between HloInstructions. +// +// It has an adjacency matrix and it is up to the user of the class to set the +// adjacency matrix such that it represents reachability, i.e. such that it is +// transitive. That the graph be transitive is thus not an invariant of this +// class, but it is required for the name of the class and its methods to make +// sense. +class HloReachabilityMap { + public: + using Index = size_t; + + // Sets up a graph with no edges and where the nodes correspond to the given + // instructions. + explicit HloReachabilityMap( + absl::Span instructions); + + // Computes and returns the reachability between HLO instructions in the + // computation. The returned HloReachabilityMap is constructed such that + // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a + // directed path (from producer to consumer) from 'a' to 'b'. Both data + // dependencies (operands) and control dependencies are considered for + // reachability. Trivially an instruction is reachable from itself. + static std::unique_ptr Build( + const HloComputation* computation); + + // Similar to the above Build operation except that it tries to identify + // paths between instructions that do not contain control instructions + // and multiple operands, i.e., b is_reachable a == true iff + // b = f(f(f(f(f(a), constant), constant), constant). + // Further, the only ops allowed in a path are basic math operations such + // as add, sub, mul, div. + static std::unique_ptr BuildWithRestrictions( + const HloComputation* computation, + absl::FunctionRef*)> + add_dependencies); + + // Set the reachability set of 'instruction' to the union of the reachability + // sets of 'inputs'. Upon return, IsReachable(x, instruction) where + // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true + // for some 'input' in 'inputs'. Also sets 'instruction' to be reachable from + // itself. Returns whether the reachability set of 'instruction' changed. + // + // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency + // vector in the internal graph of this HloReachabilityMap for the given + // instruction and does not transitively update any other part of the + // adjacency matrix. + bool SetReachabilityToUnion(absl::Span inputs, + const HloInstruction* instruction); + + // As above, but faster because it does not check if the reachability changed. + void FastSetReachabilityToUnion( + absl::Span inputs, + const HloInstruction* instruction); + // As above, but use Index instead if it's already looked up which is even + // faster since no hash map lookup will occur. + void FastSetReachabilityToUnion(absl::Span input_indices, + Index index); + + Index GetIndex(const HloInstruction* instruction) const { + return indices_.at(GetKey(instruction)); + } + + // Sets entry so that IsReachable(a, b) will return true + // + // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency + // matrix in the internal graph of this HloReachabilityMap to have an edge + // from a to b and does not transitively update any other part of the + // adjacency matrix. + void SetReachable(const HloInstruction* a, const HloInstruction* b) { + SetReachable(GetIndex(a), GetIndex(b)); + } + void SetReachable(Index a, Index b) { bit_sets_[b].Set(a); } + + // Updates the given reachability map after the immediate predecessor set + // (operands and control predecessors) of 'instruction' has changed. + void UpdateReachabilityThroughInstruction(const HloInstruction* instruction); + + // Returns true if "b" is reachable from "a" + // + // Note that this function only correctly answers queries about reachability + // if the set of edges that have been provided to this class are transitive. + bool IsReachable(const HloInstruction* a, const HloInstruction* b) const { + return IsReachable(GetIndex(a), GetIndex(b)); + } + bool IsReachable(Index a, Index b) const { return bit_sets_[b].Get(a); } + + // Returns true if "b" is reachable from "a" or "a" is reachable from "b" + // + // Note that this function only correctly answers queries about reachability + // if the set of edges that have been provided to this class are transitive. + bool IsConnected(const HloInstruction* a, const HloInstruction* b) const { + return IsConnected(GetIndex(a), GetIndex(b)); + } + bool IsConnected(Index a, Index b) const { + return IsReachable(a, b) || IsReachable(b, a); + } + + // Checks if an instruction is in the Reachability map. + bool IsPresent(const HloInstruction* instruction) const { + return indices_.contains(GetKey(instruction)); + } + + // Replace the instruction "original" with "replacement" in the reachability + // map. + void Replace(const HloInstruction* original, + const HloInstruction* replacement); + + private: + // A dynamically sized bit-set implementation specialized for this use case + // providing fast bitwise OR (not available in tsl::gtl::BitMap). + class BitSet { + public: + BitSet() = default; + explicit BitSet(size_t size) + : size_(size), vector_((size + kBits - 1) / kBits, 0) {} + + // Returns the bit at the given index. + bool Get(Index index) const { + DCHECK(index >= 0 && index < size_); + return vector_[index / kBits] & (1ull << (index % kBits)); + } + + // Sets the bit at the given index. + void Set(Index index) { + DCHECK(index >= 0 && index < size_); + vector_[index / kBits] |= 1ull << (index % kBits); + } + + // Sets this bit-set to union of this bit-set and `other`. + void operator|=(const BitSet& other) { + if (this == &other) return; + DCHECK(size_ == other.size_); + + // Ease the work of the auto-vectorizer. + const Word* a = vector_.data(); + const Word* b = other.vector_.data(); + Word* __restrict out = vector_.data(); + size_t num_words = vector_.size(); + for (size_t i = 0; i < num_words; ++i) { + out[i] = a[i] | b[i]; + } + } + + // Sets the bitvector to all zeros. + void SetToZero() { absl::c_fill(vector_, 0); } + + bool operator==(const BitSet& other) const { + return vector_ == other.vector_; + } + bool operator!=(const BitSet& other) const { return !(*this == other); } + + private: + using Word = uint64_t; + static constexpr size_t kBits = 64; + + size_t size_; // Number of bits in the set. + std::vector vector_; + }; + + friend class HloReachabilityMapBitSetBenchmark; + + using Key = std::pair; // module ID, instruction ID. + static Key GetKey(const HloInstruction* instruction) { + return {instruction->GetModule()->unique_id(), instruction->unique_id()}; + } + + // Helper for SetReachabilityToUnion/FastSetReachabilityToUnion. + void SetReachabilityToUnionHelper( + absl::Span inputs, Index index); + void SetReachabilityToUnionHelper(absl::Span input_indices, + Index index); + + // Map from instruction to index. The index is used for bit_set_ and the bits + // within a BitSet. + absl::flat_hash_map indices_; + + // Bit-sets holding the reachability to each instruction. The bit-set for + // instruction X includes ones for each instruction which X is reachable from. + std::vector bit_sets_; + + // A temporary used by SetReachabilityToUnion to avoid an allocation with each + // call to the method. + BitSet tmp_bit_set_; +}; + +} // namespace xla + +#endif // XLA_HLO_ANALYSIS_HLO_REACHABILITY_H_ diff --git a/third_party/xla/xla/service/hlo_reachability_test.cc b/third_party/xla/xla/hlo/analysis/hlo_reachability_test.cc similarity index 96% rename from third_party/xla/xla/service/hlo_reachability_test.cc rename to third_party/xla/xla/hlo/analysis/hlo_reachability_test.cc index bc0d2b7293b47d..3050d028ecf420 100644 --- a/third_party/xla/xla/service/hlo_reachability_test.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_reachability_test.cc @@ -13,25 +13,32 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/hlo/ir/hlo_reachability.h" +#include "xla/hlo/analysis/hlo_reachability.h" #include #include +#include #include #include "absl/random/random.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/literal_util.h" #include "xla/service/computation_placer.h" +#include "xla/service/hlo_module_config.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" -#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/status.h" #include "tsl/platform/test_benchmark.h" namespace xla { namespace { -class HloReachabilityTest : public HloTestBase {}; +class HloReachabilityTest : public HloHardwareIndependentTestBase {}; TEST_F(HloReachabilityTest, Reachability) { // Construct and test a reachability graph of the following form: diff --git a/third_party/xla/xla/service/hlo_replication_analysis.cc b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc similarity index 99% rename from third_party/xla/xla/service/hlo_replication_analysis.cc rename to third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc index bc804e76d5e9ca..95d2ca2c14c9bd 100644 --- a/third_party/xla/xla/service/hlo_replication_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_replication_analysis.h" +#include "xla/hlo/analysis/hlo_replication_analysis.h" #include #include diff --git a/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.h b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.h new file mode 100644 index 00000000000000..2818e1ff61196e --- /dev/null +++ b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis.h @@ -0,0 +1,153 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_ANALYSIS_HLO_REPLICATION_ANALYSIS_H_ +#define XLA_HLO_ANALYSIS_HLO_REPLICATION_ANALYSIS_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// An HLO pass that determines whether each instruction in the module outputs +// the same value across replicas or across partitions (depending on the value +// `cross_partition_spmd`). It propagates sources of replicated values to +// the rest of the module, where sources include cross-replica-sum, annotated +// entry parameters, and constants. +class HloReplicationAnalysis { + public: + // Runs the analysis on module and returns the result or an error. + static absl::StatusOr> Run( + const HloModule* module, bool cross_partition_spmd); + + // Same as above, but the caller can provide additional annotations: a set of + // while loops that are known to have the same iteration counts across + // replicas or partitions. + static absl::StatusOr> Run( + const HloModule* module, bool cross_partition_spmd, + const absl::flat_hash_set* + loops_known_with_same_iterations); + + // Same as above but supports finding partially replicated HLOs. + static absl::StatusOr> + RunWithPartialReplication(const HloModule* module, bool cross_partition_spmd); + + // Returns if the HLO instruction outputs the same value (i.e., replicated) at + // the given index across all replicas or partitions. + bool HloInstructionIsReplicatedAt(const HloInstruction* inst, + const ShapeIndex& index) const; + + bool HloInstructionIsReplicatedAt( + const HloInstruction* inst, const ShapeIndex& index, + absl::Span replica_groups) const; + + private: + // A data structure that represents how an HLO is replicated among a set of + // devices. Device ID could be either partition ID or replica ID. + // We represent partial replication by grouping devices that have the same + // value into the same set. + class HloReplication { + public: + static HloReplication ReplicatedOnAllDevices(); + static HloReplication UniqueOnAllDevices(); + static HloReplication PartiallyReplicated( + absl::Span> device_sets); + HloReplication(); + HloReplication(const HloReplication& other) = default; + HloReplication(HloReplication&& other) = default; + HloReplication& operator=(HloReplication&& other) = default; + HloReplication Merge(const HloReplication& other) const; + bool Equal(const HloReplication& other) const; + bool IsReplicatedOnAllDevices() const; + bool IsUniqueOnAllDevices() const; + bool IsReplicatedWithinSubgroup(absl::Span device_ids) const; + std::string ToString() const; + + private: + enum class State { + kReplicatedOnAllDevices = 0, + kUniqueOnAllDevices = 1, + kPartiallyReplicated = 2, + }; + explicit HloReplication(State state, + absl::Span device_set_root); + State state_; + // Empty if state_ is kReplicatedOnAllDevices or kUniqueOnAllDevices. + // Otherwise, its size equals to the number of devices (either partitions + // or replications). Maps each device ID to the smallest device ID in the + // set. + std::vector device_set_root_; + }; + + static HloReplication DetermineHloInstructionIsReplicated( + const HloInstruction* hlo, const ShapeIndex& index, + bool cross_partition_spmd, + const absl::flat_hash_map>& hlo_replication, + bool support_partial_replication); + + HloReplicationAnalysis(const HloModule* module, bool cross_partition_spmd, + const absl::flat_hash_set* + loops_known_with_same_iterations, + bool support_partial_replication) + : module_(module), + cross_partition_spmd_(cross_partition_spmd), + loops_known_with_same_iterations_(*loops_known_with_same_iterations), + support_partial_replication_(support_partial_replication) {} + + // Computes hlo_replication_. + absl::Status ComputeHloReplication(); + + // A helper function to recursively compute hlo_replication on a computation. + // Returns whether hlo_replication_ is changed. + bool ComputeHloReplicationOnComputation(const HloComputation* computation, + bool mark_everything_not_replicated); + + const HloModule* module_; + + // If true, run this replication analysis for replicated values across + // partitions (not across replicas) on an SPMD partitioned module. This means + // that HloInstructionIsReplicatedAt() returns true if the value is identical + // across partitions for each replica. The module-level parameter and root + // instructions may have HloSharding attributes that indicate whether values + // are identical across partitions. + // + // If false, HloReplicationAnalysis runs across replicas. + bool cross_partition_spmd_; + + // A set of while loops that are known to have the same iteration counts + // across replicas or partitions. This is provided by the caller as additional + // annotations. + const absl::flat_hash_set& + loops_known_with_same_iterations_; + + const bool support_partial_replication_; + + // A map from each analyzed HLO instruction to a shape tree that represents + // whether the instruction outputs the same value across replicas or + // partitions at each shape index. + absl::flat_hash_map> + hlo_replication_; +}; + +} // namespace xla + +#endif // XLA_HLO_ANALYSIS_HLO_REPLICATION_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/hlo_replication_analysis_test.cc b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis_test.cc similarity index 99% rename from third_party/xla/xla/service/hlo_replication_analysis_test.cc rename to third_party/xla/xla/hlo/analysis/hlo_replication_analysis_test.cc index e57e7112226072..19bc0e0e5eafc5 100644 --- a/third_party/xla/xla/service/hlo_replication_analysis_test.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_replication_analysis_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_replication_analysis.h" +#include "xla/hlo/analysis/hlo_replication_analysis.h" #include #include @@ -21,15 +21,15 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/shape_util.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" namespace xla { namespace { -class HloReplicationAnalysisTest : public HloTestBase {}; +class HloReplicationAnalysisTest : public HloHardwareIndependentTestBase {}; TEST_F(HloReplicationAnalysisTest, NoControlFlow) { const std::string module_str = R"( @@ -194,8 +194,6 @@ ENTRY entry { FindInstruction(module.get(), "subtract.2"), {})); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "add"), {})); - EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( - FindInstruction(module.get(), "add"), {})); EXPECT_TRUE(analysis->HloInstructionIsReplicatedAt( FindInstruction(module.get(), "replica-id"), {})); EXPECT_FALSE(analysis->HloInstructionIsReplicatedAt( diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc b/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis.cc similarity index 99% rename from third_party/xla/xla/service/hlo_value_semantics_analysis.cc rename to third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis.cc index 06ae99051ddcce..49b7c78fd2b9a1 100644 --- a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis.cc @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_value_semantics_analysis.h" +#include "xla/hlo/analysis/hlo_value_semantics_analysis.h" #include #include diff --git a/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis.h b/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis.h new file mode 100644 index 00000000000000..c6fa0284e7cf97 --- /dev/null +++ b/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis.h @@ -0,0 +1,443 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_ANALYSIS_HLO_VALUE_SEMANTICS_ANALYSIS_H_ +#define XLA_HLO_ANALYSIS_HLO_VALUE_SEMANTICS_ANALYSIS_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/node_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/dfs_hlo_visitor.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_value.h" +#include "xla/shape.h" +#include "xla/shape_tree.h" +#include "xla/shape_util.h" + +namespace xla { + +struct SendRecvGroup { + HloInstruction* send; + HloInstruction* recv; +}; + +class SendRecvGroupMap { + public: + explicit SendRecvGroupMap(const HloModule& hlo_module); + SendRecvGroupMap(SendRecvGroupMap&& other) = default; + SendRecvGroupMap(const SendRecvGroupMap& other) = default; + virtual ~SendRecvGroupMap() = default; + virtual absl::StatusOr GetMatchingSendOrRecv( + HloInstruction* send_or_recv) const; + + private: + absl::flat_hash_map host_transfer_rendezvous_map_; +}; + +class HloPreOrderDFS { + public: + HloPreOrderDFS() = default; + ~HloPreOrderDFS() = default; + absl::Status Run(const HloComputation& computation, + DfsHloVisitorBase* visitor); + + private: + bool IsReady(const HloInstruction* instruction) const; + std::vector stack_; + absl::flat_hash_set visited_; +}; + +using EinsumDepthMap = + absl::node_hash_map>; + +// The einsum depth is the length of the einsum dependency chain. And we +// distinguish instructions that are used by root and that are not used by +// root. +// The einsum depth of an HLO value A is defined as follows: +// for B = op(A, ...) +// 1) the root instruction has a depth of 0; +// 2) non-root instructions that have zero users have a depth of -1; +// 3) if op is a Dot or Convolution (i.e., einsum), +// depth(A, B) = depth(B) >= 0 ? depth(B) + 1 : depth(B) - 1. +// depth(A, B) means the depth of A because of B; +// 4) otherwise depth(A, B) = depth(B); +// 5) depth(A) is computed by merging all depth(A, u) where u is a user of A. +// See MergeDepth for how user depths are merged. + +class EinsumDepthAnalysis : public DfsHloVisitorWithDefault { + public: + static absl::StatusOr> Run( + const HloComputation& computation, + const SendRecvGroupMap& send_recv_group_map); + ~EinsumDepthAnalysis() override = default; + absl::Status DefaultAction(HloInstruction* instruction) override; + absl::Status HandleTuple(HloInstruction* tuple) override; + absl::Status HandleGetTupleElement( + HloInstruction* get_tuple_element) override; + absl::Status HandleDot(HloInstruction* dot) override; + absl::Status HandleConvolution(HloInstruction* convolution) override; + absl::Status HandleCall(HloInstruction* call) override; + absl::Status HandleFusion(HloInstruction* fusion) override; + absl::Status HandleWhile(HloInstruction* xla_while) override; + absl::Status HandleConditional(HloInstruction* conditional) override; + absl::Status HandleAfterAll(HloInstruction* after_all) override; + absl::Status HandleSend(HloInstruction* send) override; + absl::Status HandleRecv(HloInstruction* recv) override; + absl::Status HandleSendDone(HloInstruction* send_done) override; + absl::Status HandleRecvDone(HloInstruction* recv_done) override; + absl::Status HandleAllReduce(HloInstruction* all_reduce) override; + absl::Status HandleAsyncStart(HloInstruction* async_start) override; + absl::Status HandleAsyncDone(HloInstruction* async_done) override; + const EinsumDepthMap& GetEinsumDepthMap() const { return einsum_depth_map_; } + + private: + explicit EinsumDepthAnalysis(const SendRecvGroupMap& send_recv_group_map) + : send_recv_group_map_(&send_recv_group_map) {} + absl::Status RunInternal(const HloComputation& computation, + const std::optional>& root_depth); + ShapeTree& GetOrCreateDepthTree(const HloInstruction* instruction); + ShapeTree& GetDepthTreeOrDie(const HloInstruction* instruction); + absl::Status SetInstructionDepth(const HloInstruction* instruction, + int depth); + absl::Status SetInstructionDepth(const HloInstruction* instruction, + const ShapeTree& depth); + absl::Status SetInstructionDepthFromTupleDepth( + const HloInstruction* instruction, const ShapeTree& tuple_depth_tree, + int tuple_index); + absl::Status HandleDepthIncrementInstruction(HloInstruction* instruction); + absl::Status HandleCalledComputation( + const HloComputation& called_computation, + const ShapeTree& root_depth, + absl::Span operands); + absl::Status HandleTupleLike(HloInstruction* tuple_like); + EinsumDepthMap einsum_depth_map_; + const SendRecvGroupMap* const send_recv_group_map_; +}; + +using EinsumHeightMap = + absl::node_hash_map>; + +// Einsum height is the maximum number of einsums between this instruction and +// any leaf. + +class EinsumHeightAnalysis : public DfsHloVisitorWithDefault { + public: + static absl::StatusOr> Run( + const HloComputation& computation, + const SendRecvGroupMap& send_recv_group_map); + ~EinsumHeightAnalysis() override = default; + absl::Status DefaultAction(HloInstruction* instruction) override; + absl::Status HandleTuple(HloInstruction* tuple) override; + absl::Status HandleGetTupleElement( + HloInstruction* get_tuple_element) override; + absl::Status HandleDot(HloInstruction* dot) override; + absl::Status HandleConvolution(HloInstruction* convolution) override; + absl::Status HandleCall(HloInstruction* call) override; + absl::Status HandleFusion(HloInstruction* fusion) override; + absl::Status HandleWhile(HloInstruction* xla_while) override; + absl::Status HandleConditional(HloInstruction* conditional) override; + absl::Status HandleSend(HloInstruction* send) override; + absl::Status HandleRecv(HloInstruction* recv) override; + absl::Status HandleSendDone(HloInstruction* send_done) override; + absl::Status HandleRecvDone(HloInstruction* recv_done) override; + absl::Status HandleAllReduce(HloInstruction* all_reduce) override; + absl::Status HandleAsyncStart(HloInstruction* async_start) override; + absl::Status HandleAsyncDone(HloInstruction* async_done) override; + const EinsumHeightMap& GetEinsumHeightMap() const { + return einsum_height_map_; + } + + private: + explicit EinsumHeightAnalysis(const SendRecvGroupMap& send_recv_group_map) + : send_recv_group_map_(&send_recv_group_map) {} + absl::Status RunInternal(const HloComputation& computation, + absl::Span operands); + ShapeTree& GetOrCreateHeightTree(const HloInstruction* instruction); + ShapeTree& GetHeightTreeOrDie(const HloInstruction* instruction); + bool HasHeightFor(const HloInstruction* instruction) const; + absl::Status SetInstructionHeight(const HloInstruction* instruction, + int height); + absl::Status SetInstructionHeight(const HloInstruction* instruction, + const ShapeTree& height); + absl::Status HandleHeightIncrementInstruction(HloInstruction* instruction); + absl::Status HandleCalledComputation( + const HloComputation& computation, + absl::Span operands); + absl::Status HandleTupleLike(HloInstruction* tuple_like); + + EinsumHeightMap einsum_height_map_; + const SendRecvGroupMap* const send_recv_group_map_; +}; + +// The comment below explains where the labels could originate from. Once +// originated, those labels are then propagated throughout the HLO module. +enum class HloValueSemanticLabel { + // Values that are known or predictable at compile time, including constants, + // iota, replica-id, and partition-id. + kStatic, + // Values that are not known or can't be predicated at compile time. + kRandom, + // HLO module parameters. + kWeight, + // Output of weight-weight or weight-activation matmuls. + kActivation, + // Output of weight-activation matmuls where the weight is a dependence of + // that activation. Or output of weight-activation-gradient matmuls. + kActivationGradient, + // Output of activation-gradient-activation matmuls. + kWeightGradient, + kTupleOrToken, +}; + +std::string HloValueSemanticLabelToString(HloValueSemanticLabel label); + +class HloValueSemantics { + public: + using Id = int64_t; + HloValueSemantics(HloValueSemanticLabel label, const HloPosition& origin); + HloValueSemantics(Id id, HloValueSemanticLabel label, + const HloPosition& origin); + HloValueSemantics(const HloValueSemantics& other) = default; + HloValueSemantics(HloValueSemantics&& other) = default; + HloValueSemantics& operator=(const HloValueSemantics& other) = default; + + Id id() const { return id_; } + HloValueSemanticLabel label() const { return label_; } + const HloPosition& origin() const { return origin_; } + std::string ToString() const; + + private: + const Id id_; + const HloValueSemanticLabel label_; + const HloPosition origin_; +}; + +std::string HloValueSemanticsTreeToString( + const ShapeTree& tree); + +using HloValueSemanticsMap = + absl::node_hash_map>; +class HloValueSemanticsPropagation; + +class HloValueSemanticsAnalysis { + public: + static absl::StatusOr> Run( + const HloModule& module, + const absl::flat_hash_set& execution_threads = {}); + virtual ~HloValueSemanticsAnalysis() = default; + bool HasSemanticsFor(const HloInstruction* instruction) const; + const HloValueSemantics* GetSemantics(const HloInstruction* instruction, + const ShapeIndex& index = {}) const; + + const HloValueSemanticsMap& GetSemanticsMap() const { + return value_semantics_; + } + + const EinsumDepthMap& GetEinsumDepthMap() const { return einsum_depth_map_; } + const EinsumHeightMap& GetEinsumHeightMap() const { + return einsum_height_map_; + } + int GetDepth(const HloInstruction* instruction, + const ShapeIndex& index = {}) const; + int GetHeight(const HloInstruction* instruction, + const ShapeIndex& index = {}) const; + + const SendRecvGroupMap& GetSendRecvGroupMap() const { + return *send_recv_group_map_; + } + + absl::StatusOr GetMatchingSendOrRecv( + HloInstruction* send_or_recv) const; + + protected: + friend class HloValueSemanticsPropagation; + explicit HloValueSemanticsAnalysis( + const HloModule& module, + const absl::flat_hash_set& execution_threads); + virtual absl::Status InitializeEinsumDepth(); + virtual absl::Status InitializeEinsumHeight(); + // We match send and recv HLOs to propagate semantics from send to recv. + virtual void InitializeSendRecvGroups(); + void AnnotateWeights(); + + // Infer semantics for all instructions in the computation. Computation + // parameters are assigned the semantics of the corresponding operand. + absl::Status RunOnComputation( + const HloComputation& computation, + absl::Span operands); + // Same as the above RunOnComputation, but computation parameters have + // already been assigned with semantics. + virtual absl::Status RunOnComputation(const HloComputation& computation); + HloValueSemantics::Id NextId(); + const HloValueSemantics* NewHloValueSemantics(HloValueSemanticLabel label, + const HloPosition& origin); + const ShapeTree& GetInstructionSemantics( + const HloInstruction* instruction) const; + void DeepCopyHloValueSemantics( + ShapeTree& copy_to, + const ShapeTree& copy_from, + const ShapeIndex& source_index, const ShapeIndex& destination_index); + void DeepCopyHloValueSemantics( + const HloInstruction* target, + const ShapeTree& copy_from, + const ShapeIndex& source_index = {}); + void SetHloValueSemantics( + const HloInstruction* target, + const ShapeTree& semantics); + void DeleteHloValueSemantics( + const ShapeTree& to_delete); + void DeleteHloValueSemantics(const HloValueSemantics* to_delete); + const HloModule& module_; + const absl::flat_hash_set& execution_threads_; + HloValueSemanticsMap value_semantics_; + absl::flat_hash_map> + value_semantics_map_; + HloValueSemantics::Id next_id_; + EinsumDepthMap einsum_depth_map_; + EinsumHeightMap einsum_height_map_; + std::unique_ptr send_recv_group_map_; +}; + +class HloValueSemanticsPropagation : public DfsHloVisitorWithDefault { + public: + explicit HloValueSemanticsPropagation(HloValueSemanticsAnalysis* analysis); + absl::Status Run(const HloComputation& computation); + // Infer the output semantics from all operands of the instruction. + absl::Status DefaultAction(HloInstruction* instruction) override; + absl::Status HandleParameter(HloInstruction* parameter) override; + absl::Status HandleConstant(HloInstruction* constant) override; + absl::Status HandleIota(HloInstruction* iota) override; + absl::Status HandlePartitionId(HloInstruction* partition_id) override; + absl::Status HandleReplicaId(HloInstruction* replica_id) override; + absl::Status HandleClamp(HloInstruction* clamp) override; + absl::Status HandleTuple(HloInstruction* tuple) override; + absl::Status HandleGetTupleElement( + HloInstruction* get_tuple_element) override; + absl::Status HandleCall(HloInstruction* call) override; + absl::Status HandleFusion(HloInstruction* fusion) override; + absl::Status HandleCustomCall(HloInstruction* custom_call) override; + absl::Status HandleWhile(HloInstruction* xla_while) override; + absl::Status HandleConditional(HloInstruction* conditional) override; + absl::Status HandleSelect(HloInstruction* select) override; + absl::Status HandleConcatenate(HloInstruction* concatenate) override; + absl::Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; + absl::Status HandleDynamicUpdateSlice( + HloInstruction* dynamic_update_slice) override; + absl::Status HandleCopyStart(HloInstruction* copy_start) override; + absl::Status HandleCopyDone(HloInstruction* copy_done) override; + absl::Status HandleAllGatherStart(HloInstruction* all_gather_start) override; + absl::Status HandleAllGatherDone(HloInstruction* all_gather_done) override; + absl::Status HandleCollectivePermuteStart( + HloInstruction* collective_permute_start) override; + absl::Status HandleCollectivePermuteDone( + HloInstruction* collective_permute_done) override; + absl::Status HandleGather(HloInstruction* gather) override; + absl::Status HandleScatter(HloInstruction* scatter) override; + absl::Status HandleAfterAll(HloInstruction* after_all) override; + absl::Status HandleAllReduce(HloInstruction* all_reduce) override; + absl::Status HandleAsyncStart(HloInstruction* async_start) override; + absl::Status HandleAsyncDone(HloInstruction* async_done) override; + absl::Status HandleInfeed(HloInstruction* infeed) override; + absl::Status HandleOutfeed(HloInstruction* outfeed) override; + absl::Status HandleDomain(HloInstruction* domain) override; + absl::Status HandleOptimizationBarrier(HloInstruction* opt_barrier) override; + absl::Status HandleRngBitGenerator( + HloInstruction* rng_bit_generator) override; + absl::Status HandleSend(HloInstruction* send) override; + absl::Status HandleRecv(HloInstruction* recv) override; + absl::Status HandleSendDone(HloInstruction* send_done) override; + absl::Status HandleRecvDone(HloInstruction* recv_done) override; + + protected: + HloValueSemantics CopySemantics(const HloValueSemantics& semantics) const; + HloValueSemantics CopySemanticsWithNewOrigin( + const HloValueSemantics& semantics, HloInstruction* new_origin, + const ShapeIndex& index = {}) const; + const HloValueSemantics* AddSemantics(const HloValueSemantics& semantics); + struct EinsumAndOperandIndex { + HloInstruction* einsum; + int64_t operand_index; + }; + // Checks if the origin of `semantics` is an einsum that takes + // `origin_dependence` as an operand. + // If `recursive` is set to true, recursively checks all ancestors of the + // `semantics`' origin (including itself) for the above condition. + // Returns all such einsums and the operand index corresponding to + // `origin_dependence`. + // We use this function to find whether the output of an einsum who has an + // operand X is used in another einsum who takes X as an operand. This is + // the pattern for gradient. + // For example, consider C = einsum(A, B), dC / dB = einsum(A, C). + std::vector FindEinsumsWhereOriginDependsOnOther( + const HloValueSemantics& semantics, const HloPosition& origin_dependence, + bool recursive = false) const; + bool OriginDependsOn(const HloValueSemantics& semantics, + const HloPosition& origin_dependence, + bool recursive = false) const; + absl::StatusOr MaybeCreateGradientSemantics( + HloInstruction* gradient_candidate, + HloValueSemanticLabel fallback_label) const; + absl::StatusOr ComputeSemanticsFromStaticAndOther( + const HloValueSemantics& static_semantics, + const HloValueSemantics& other_semantics, + HloInstruction* instruction) const; + absl::StatusOr ComputeSemanticsFromRandomAndOther( + const HloValueSemantics& random_semantics, + const HloValueSemantics& other_semantics, + HloInstruction* instruction) const; + absl::StatusOr ComputeSemanticsFromWeightAndOther( + const HloValueSemantics& weight_semantics, + const HloValueSemantics& other_semantics, + HloInstruction* instruction) const; + absl::StatusOr ComputeSemanticsFromActivationAndOther( + const HloValueSemantics& activation_semantics, + const HloValueSemantics& other_semantics, + HloInstruction* instruction) const; + absl::StatusOr + ComputeSemanticsFromActivationGradientAndOther( + const HloValueSemantics& activation_gradient_semantics, + const HloValueSemantics& other_semantics, + HloInstruction* instruction) const; + absl::StatusOr ComputeSemanticsFromWeightGradientAndOther( + const HloValueSemantics& weight_gradient_semantics, + const HloValueSemantics& other_semantics, + HloInstruction* instruction) const; + absl::StatusOr MergeSemanticsForAnInstruction( + HloInstruction* instruction, + std::vector& semantics_vec) const; + absl::StatusOr ComputeSemanticsFromOperands( + HloInstruction* instruction, absl::Span operand_indices, + absl::Span operand_shape_indices = {}) const; + absl::Status HandleTupleLike(HloInstruction* tuple_like); + absl::Status HandleCollectiveOrCopyStart(HloInstruction* op_start); + absl::Status HandleCollectiveOrCopyDone(HloInstruction* op_done); + HloValueSemanticsAnalysis* analysis_; +}; + +} // namespace xla + +#endif // XLA_HLO_ANALYSIS_HLO_VALUE_SEMANTICS_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis_test.cc b/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis_test.cc similarity index 99% rename from third_party/xla/xla/service/hlo_value_semantics_analysis_test.cc rename to third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis_test.cc index e7390ce63a7852..4c66f9de7207fb 100644 --- a/third_party/xla/xla/service/hlo_value_semantics_analysis_test.cc +++ b/third_party/xla/xla/hlo/analysis/hlo_value_semantics_analysis_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_value_semantics_analysis.h" +#include "xla/hlo/analysis/hlo_value_semantics_analysis.h" #include #include @@ -23,7 +23,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "tsl/platform/statusor.h" namespace xla { @@ -155,7 +155,7 @@ ENTRY MnistTrainingLoopWithInfeed.140 { } )"; -class HloValueSemanticsAnalysisTest : public HloTestBase { +class HloValueSemanticsAnalysisTest : public HloHardwareIndependentTestBase { public: bool HasLabel(const HloValueSemanticsAnalysis& hlo_value_semantics_analysis, HloModule* module, absl::string_view instruction_name, @@ -590,7 +590,7 @@ TEST_F(HloValueSemanticsAnalysisTest, MnistTrainingLoop) { IsWeightGradient(*hlo_value_semantics_analysis, module.get(), "dot.99")); } -class EinsumDepthAnalysisTest : public HloTestBase { +class EinsumDepthAnalysisTest : public HloHardwareIndependentTestBase { public: int GetInstructionDepth(const EinsumDepthMap& depth_map, HloComputation* computation, absl::string_view name) { @@ -686,7 +686,7 @@ TEST_F(EinsumDepthAnalysisTest, HandleAfterAll) { 0); } -class EinsumHeightAnalysisTest : public HloTestBase { +class EinsumHeightAnalysisTest : public HloHardwareIndependentTestBase { public: int GetInstructionHeight(const EinsumHeightMap& height_map, HloComputation* computation, diff --git a/third_party/xla/xla/service/indexed_array_analysis.cc b/third_party/xla/xla/hlo/analysis/indexed_array_analysis.cc similarity index 99% rename from third_party/xla/xla/service/indexed_array_analysis.cc rename to third_party/xla/xla/hlo/analysis/indexed_array_analysis.cc index a9a5004b011ecc..0e4011c3213b35 100644 --- a/third_party/xla/xla/service/indexed_array_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/indexed_array_analysis.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/indexed_array_analysis.h" +#include "xla/hlo/analysis/indexed_array_analysis.h" #include #include diff --git a/third_party/xla/xla/hlo/analysis/indexed_array_analysis.h b/third_party/xla/xla/hlo/analysis/indexed_array_analysis.h new file mode 100644 index 00000000000000..83bf625a337120 --- /dev/null +++ b/third_party/xla/xla/hlo/analysis/indexed_array_analysis.h @@ -0,0 +1,395 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_ANALYSIS_INDEXED_ARRAY_ANALYSIS_H_ +#define XLA_HLO_ANALYSIS_INDEXED_ARRAY_ANALYSIS_H_ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/literal.h" +#include "xla/shape.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" + +namespace xla { + +// IndexedArrayAnalysis decides if an HLO instruction can be rewritten as a +// gather from another array. It does this by mapping HLO instructions to +// instances of IndexedArrayAnalysis::Array, which can be inspected to discover +// whether said HLO is equivalent to a gather. +class IndexedArrayAnalysis { + public: + // IndexedArrayAnalysis maps each HLO instruction to an instance of a Array. + // Array really just a sum type of the classes that inherit from it. The + // meaning of each of the subtypes is documented on the subtype declaration. + // + // Array instances are immutable once created. + class Array { + public: + enum Kind { + kUnknown, + kConstant, + kReshaped, + kScalarIndexedConstant, + kScalarIndexed + }; + + virtual Kind kind() const = 0; + virtual const Shape& shape() const = 0; + + // Does a checked downcast from `Array` to `T` which must be one of its + // subtypes. + template + T* as() { + static_assert((std::is_base_of::value), + "target type not derived from source type"); + // We skip the CHECK and hence the dynamic_cast if RTTI is disabled. +#if !defined(__GNUC__) || defined(__GXX_RTTI) + CHECK_NE(dynamic_cast(this), nullptr); +#endif // !defined(__GNUC__) || defined(__GXX_RTTI) + + return static_cast(this); + } + + virtual ~Array() = default; + + Array& operator=(const Array& other) = delete; + }; + + // Represents an HLO instruction that was not analyzable by this + // IndexedArrayAnalysis. Instances of UnknownArray just wrap an existing + // HloInstruction. + class UnknownArray : public Array { + public: + Kind kind() const override { return kUnknown; } + const Shape& shape() const override { return instruction().shape(); } + const HloInstruction& instruction() const { return instruction_; } + + private: + explicit UnknownArray(const HloInstruction* instr) : instruction_(*instr) {} + + const HloInstruction& instruction_; + + friend class IndexedArrayAnalysis; + }; + + // Represents a constant value. This constant value may be present in the HLO + // module being analyzed, or it could have been created on the fly by the + // analysis. + class ConstantArray : public Array { + public: + Kind kind() const override { return kConstant; } + const Shape& shape() const override { return literal()->shape(); } + const Literal* literal() const { return literal_; } + + private: + explicit ConstantArray(const Literal* literal) : literal_(literal) {} + const Literal* literal_; + + friend class IndexedArrayAnalysis; + }; + + // Represents an Array that is a reshape of another Array. + class ReshapedArray : public Array { + public: + Kind kind() const override { return kReshaped; } + + // The array to reshape. + Array* operand() const { return operand_; } + + // The output shape. + const Shape& shape() const override { return shape_; } + + private: + explicit ReshapedArray(Array* operand, Shape shape) + : operand_(operand), shape_(shape) {} + + Array* operand_; + const Shape shape_; + + friend class IndexedArrayAnalysis; + }; + + // --------------------------------------------------------------------------- + // Indexed Array Overview + // --------------------------------------------------------------------------- + // + // ScalarIndexedArray and ScalarIndexedConstantArray form the core of this + // analysis. ScalarIndexedConstantArray is just a specialization of + // ScalarIndexedArray so we will only discuss ScalarIndexedArray in this + // overview. + // + // A ScalarIndexedArray represents an array that can be computed by indexing + // into a "source" array using an "indices" tensor. A simple example is a + // gather operation gathering 12 rows out of a [100,100] matrix -- such an + // operation will be represented by an instance of a ScalarIndexedArray with + // the [100,100] matrix as the "source" array and the [12]-shaped indices + // array as the "indices" tensor. The ScalarIndexedArray operation itself + // will be of shape [12,100] (assuming we were gathering with axis=0). + // + // Gather operations are not the only operation that maps to + // ScalarIndexedArray instances (if that were true there would be little point + // in having a separate analysis). We can often infer ScalarIndexedArrays for + // other operations too. For instance, consider: + // + // %source = f32[100,100] constant + // %indices = s32[12] ... + // %gather = f32[12,100] ... gather from %source using %indices at axis 0 + // %dot = dot(%gather, other_constant) [canonical contracting dims] + // + // The dot operation itself is also a ScalarIndexedArray with source = + // dot(constant, other_constant) and indices = %indices. A reshape of %gather + // to [12,5,20] too is a ScalarIndexedArray with source = an appropriately + // reshaped constant and indices = %indices. + + // Represents the result of a gather operation. This gather operation may + // explicitly be present in the HLO module being analyzed, or it could have + // been created on the fly by the analysis. + // + // An instance of ScalarIndexedArray represents a array whose I'th element can + // be mapped to the J'th element of the `source` array (where I and J are + // multidimensional indices) in this way: + // + // I' = remove components at positions `output_dims` from I + // G' = remove components not at positions `output_dims` from I + // T = indices[G'] + // J = I' with T inserted at position `source_dim` + // + // For example, if source is of shape [11,13,17,19], indices is of shape + // [23,29], output_dims is [0,2] and source_dim is 2 then the output is of + // shape [23,11,29,13,19] and the output index [A,B,C,D,E] is mapped to the + // input index [B,D,indices[A,C],E]. + class ScalarIndexedArray : public Array { + public: + Kind kind() const override { return kScalarIndexed; } + const Shape& shape() const override { return shape_; } + + Array* source() const { return source_; } + Array* indices() const { return indices_; } + + // `source_dim` is the dimension in the source array that is being indexed + // over using indices from the `indices` array. See the class documentation + // and the overview for more details. + int64_t source_dim() const { return source_dim_; } + + // `output_dims` are the dimensions in the output array that are being used + // to compute an index into the `indices` array. See the class + // documentation and the overview for more details. + absl::Span output_dims() const { return output_dims_; } + + private: + explicit ScalarIndexedArray(Array* source, Array* indices, + int64_t source_dim, + std::vector output_dims, Shape shape) + : source_(source), + indices_(indices), + source_dim_(source_dim), + output_dims_(std::move(output_dims)), + shape_(std::move(shape)) {} + + Array* source_; + Array* indices_; + int64_t source_dim_; + std::vector output_dims_; + Shape shape_; + + friend class IndexedArrayAnalysis; + }; + + // A ScalarIndexedConstantArray is just a ScalarIndexedArray constrained to + // have a ConstantArray instance as the source. This is an ergonomic + // concession -- in theory it is possible to just keep ScalarIndexedArray and + // check source()->kind(). + class ScalarIndexedConstantArray : public ScalarIndexedArray { + public: + Kind kind() const override { return kScalarIndexedConstant; } + + const Literal& literal() const { + return *source()->as()->literal(); + } + + private: + explicit ScalarIndexedConstantArray(Array* source, Array* indices, + int64_t source_dim, + std::vector output_dims, + Shape shape) + : ScalarIndexedArray(source, indices, source_dim, + std::move(output_dims), std::move(shape)) { + CHECK(dynamic_cast(source)); + } + + friend class IndexedArrayAnalysis; + }; + + // Returns an Array instance for `instr`. The IndexedArrayAnalysis instance + // keeps ownership of the returned Array instance. + // + // Caching Behavior: IndexedArrayAnalysis has a cache mapping HLO + // instructions to IndexedArrayAnalysis::Array instances. This entire cache + // becomes stale and may cause the analysis to return incorrect results if any + // transitive operand (stopping at the containing computation) is modified for + // any HLO instruction on which GetArrayFor has been invoked. + // + // NB! By inspecting the implementation, you may be able to infer a stronger + // caching guarantee than what is mentioned above. Nevertheless, what is + // stated above is the contract. + absl::StatusOr GetArrayFor(const HloInstruction* instr); + + // Pretty-prints the expression rooted at `root`. + std::string ToString(Array* root, bool print_constants = false); + + private: + // Helper function that ensures that every HLO instruction that is + // transitively used by `root` has an entry in `cache_`. + absl::Status TraverseAndPopulateCache(const HloInstruction* root); + + // Creates an Array instance for `instr` under the assumption that all + // operations of `instr` are present in `cache_`. + absl::StatusOr ComputeArrayFor(const HloInstruction* instr); + + absl::StatusOr ComputeArrayForConstant(const Literal& literal); + + absl::StatusOr ComputeArrayForGather( + const Shape& shape, const GatherDimensionNumbers& dim_numbers, + absl::Span slice_sizes, Array* source, Array* indices); + + absl::StatusOr ComputeArrayForDotWithIndexedLhs( + const Shape& shape, const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs, + ConstantArray* rhs); + + absl::StatusOr ComputeArrayForDotWithIndexedRhs( + const Shape& shape, const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, ConstantArray* lhs, + ScalarIndexedConstantArray* rhs); + + absl::StatusOr ComputeArrayForDot( + const Shape& shape, const DotDimensionNumbers& dim_numbers, + const PrecisionConfig& precision_config, Array* lhs, Array* rhs); + + // This tries to fold a ScalarIndexedArray which has another + // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a + // ScalarIndexedArray as indices. If `source` happened to be a + // ScalarIndexedConstantArray this can result in an expression that is more + // canonical. + // + // As an example, consider a gather operation, G0, gathering 7 elements from + // an array "Arr" of shape [100] resulting in an array of shape [7], and a + // second gather operation, G1, which gathers 3 elements out of the result of + // G0 resulting in an array of shape [3]. Let the indices uses by G0 be I0 + // (of shape [7]) and the indices used by G1 be I1 (of shape [3]). We can + // instead rewrite G1 to gather directly from "Arr" with the three indices + // from I0 as per I1. In other words, we can rewrite: + // + // G0 = [Arr[i] for i in I0] + // G1 = [G0[i] for i in I1] + // + // into + // + // I2 = [I0[i] for i in I1] + // G1 = [Arr[i] for i in I2] + absl::StatusOr FoldGatherOfGather( + ScalarIndexedArray* source, Array* indices, int64_t source_dim, + absl::Span output_dims, Shape shape); + + // Reshapes a scalar-indexed node to remove the degenerate dimensions in its + // output. The result is always a scalar-indexed node. + absl::StatusOr ReshapeToRemoveDegenerateDims( + ScalarIndexedArray* operand); + + // Reshapes a scalar-indexed node such that the result has the degenerate + // dimensions `degenerate_dims`. The result is always a scalar-indexed node. + absl::StatusOr ReshapeToAddDegenerateDims( + ScalarIndexedArray* operand, absl::Span degenerate_dims); + + absl::StatusOr FoldReshapeOfGather( + const Shape& shape, ScalarIndexedConstantArray* operand); + absl::StatusOr FoldReshapeOfGatherNoDegenerateDims( + const Shape& shape, ScalarIndexedConstantArray* scalar_indexed); + absl::StatusOr ComputeArrayForReshape(const Shape& shape, + Array* operand); + + absl::StatusOr ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, + Array* lhs, + Array* rhs); + absl::StatusOr ComputeArrayForElementwiseUnaryOp(HloOpcode opcode, + Array* operand); + + template + T* Construct(Args&&... args) { + T* new_tensor = new T(std::forward(args)...); + owned_tensors_.push_back(std::unique_ptr(new_tensor)); + return new_tensor; + } + + ScalarIndexedArray* ConstructScalarIndexedArray( + Array* source, Array* indices, int64_t source_dim, + std::vector output_dims, Shape shape) { + if (source->kind() == Array::kConstant) { + return Construct(source, indices, source_dim, + std::move(output_dims), + std::move(shape)); + } else { + return Construct(source, indices, source_dim, + std::move(output_dims), + std::move(shape)); + } + } + + Literal* TakeOwnership(Literal literal) { + owned_literals_.push_back(std::move(literal)); + return &owned_literals_.back(); + } + + absl::StatusOr TakeOwnership( + absl::StatusOr literal_or_error) { + TF_ASSIGN_OR_RETURN(Literal literal, std::move(literal_or_error)); + owned_literals_.push_back(std::move(literal)); + return &owned_literals_.back(); + } + + std::vector> owned_tensors_; + std::vector owned_literals_; + absl::flat_hash_map cache_; +}; + +// A pass that prints all non-trivial results returned by IndexedArrayAnalysis. +// This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to +// unconditionally add to the regular HLO pass pipeline. +class IndexedArrayAnalysisPrinterPass : public HloModulePass { + public: + absl::string_view name() const override { + return "indexed-array-analysis-printer-pass"; + } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_HLO_ANALYSIS_INDEXED_ARRAY_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/indexed_array_analysis_test.cc b/third_party/xla/xla/hlo/analysis/indexed_array_analysis_test.cc similarity index 99% rename from third_party/xla/xla/service/indexed_array_analysis_test.cc rename to third_party/xla/xla/hlo/analysis/indexed_array_analysis_test.cc index 7438ac5de0bee0..35ed4aa06c4f15 100644 --- a/third_party/xla/xla/service/indexed_array_analysis_test.cc +++ b/third_party/xla/xla/hlo/analysis/indexed_array_analysis_test.cc @@ -13,18 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/indexed_array_analysis.h" +#include "xla/hlo/analysis/indexed_array_analysis.h" #include #include "absl/log/log.h" #include "absl/strings/ascii.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "tsl/platform/statusor.h" namespace xla { namespace { -class IndexedArrayAnalysisTest : public HloTestBase { +class IndexedArrayAnalysisTest : public HloHardwareIndependentTestBase { protected: void AssertArrayForRootExpressionIs(const std::string& hlo_text, const std::string& root_expression) { diff --git a/third_party/xla/xla/service/logical_buffer_analysis.cc b/third_party/xla/xla/hlo/analysis/logical_buffer_analysis.cc similarity index 99% rename from third_party/xla/xla/service/logical_buffer_analysis.cc rename to third_party/xla/xla/hlo/analysis/logical_buffer_analysis.cc index 916bc800aefc7e..5346a4112d0ef0 100644 --- a/third_party/xla/xla/service/logical_buffer_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/logical_buffer_analysis.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/logical_buffer_analysis.h" +#include "xla/hlo/analysis/logical_buffer_analysis.h" #include diff --git a/third_party/xla/xla/hlo/analysis/logical_buffer_analysis.h b/third_party/xla/xla/hlo/analysis/logical_buffer_analysis.h new file mode 100644 index 00000000000000..94cb521b8dbefe --- /dev/null +++ b/third_party/xla/xla/hlo/analysis/logical_buffer_analysis.h @@ -0,0 +1,92 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_ANALYSIS_LOGICAL_BUFFER_ANALYSIS_H_ +#define XLA_HLO_ANALYSIS_LOGICAL_BUFFER_ANALYSIS_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/logical_buffer.h" +#include "xla/shape_util.h" + +namespace xla { +// A class to create all the logical buffers defined by the HLO ops in a module. +class LogicalBufferAnalysis : public DfsHloVisitorWithDefault { + public: + // Runs points-to analysis on 'module'. + static absl::StatusOr> Run( + const HloModule* module); + + // Returns the logical buffer with the given ID. + LogicalBuffer& GetBuffer(LogicalBuffer::Id id) const; + + // Returns the logical buffer that represents the output of a given HLO + // at a given index. + LogicalBuffer& GetBuffer(HloInstruction* instruction, + const ShapeIndex& index) const; + + const std::vector>& logical_buffers() const { + return logical_buffers_; + } + size_t num_logical_buffers() const { return logical_buffers_.size(); } + + private: + explicit LogicalBufferAnalysis(const HloModule* module) : module_(module) {} + absl::Status Analyze(); + + // The module this analysis is performed on. + const HloModule* module_; + + // Create a new logical buffer and return a reference to it. The newly created + // buffer is stored in an internal vector of LogicalBuffers and can be + // accessed with GetBuffer. + void NewLogicalBuffer(HloInstruction* instruction, const ShapeIndex& index); + + absl::Status DefaultAction(HloInstruction* hlo_instruction) override; + absl::Status HandleTuple(HloInstruction* tuple) override; + absl::Status HandleGetTupleElement( + HloInstruction* get_tuple_element) override; + absl::Status HandleBitcast(HloInstruction* bitcast) override; + absl::Status HandleDomain(HloInstruction* domain) override; + absl::Status HandleCopy(HloInstruction* copy) override; + absl::Status HandleCopyStart(HloInstruction* copy_start) override; + absl::Status HandleCopyDone(HloInstruction* copy_done) override; + absl::Status HandleRecvDone(HloInstruction* recv_done) override; + absl::Status HandleSend(HloInstruction* send) override; + absl::Status HandleAddDependency(HloInstruction* add_dependency) override; + absl::Status HandleCustomCall(HloInstruction* custom_call) override; + absl::Status HandleFusion(HloInstruction* fusion) override; + + // A map from the buffer ID to the logical buffer + std::vector> logical_buffers_; + + // A map from an hlo + shape index to the logical buffer representing + // the appropriate output. + absl::flat_hash_map, + LogicalBuffer*> + output_buffers_; + // Whether to alias buffers defined by dataflow relations. This aliasing + // relation should not be recognized if copies can be inserted to break up + // the dataflow relation-induced aliasing. + const bool alias_buffer_across_dataflow_ = false; +}; + +} // namespace xla + +#endif // XLA_HLO_ANALYSIS_LOGICAL_BUFFER_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/tuple_points_to_analysis.cc b/third_party/xla/xla/hlo/analysis/tuple_points_to_analysis.cc similarity index 99% rename from third_party/xla/xla/service/tuple_points_to_analysis.cc rename to third_party/xla/xla/hlo/analysis/tuple_points_to_analysis.cc index 179942c5ce7829..961be9ef4d7587 100644 --- a/third_party/xla/xla/service/tuple_points_to_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/tuple_points_to_analysis.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/tuple_points_to_analysis.h" +#include "xla/hlo/analysis/tuple_points_to_analysis.h" #include #include @@ -29,7 +29,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/map_util.h" -#include "xla/service/hlo_dataflow_analysis.h" #include "xla/shape_util.h" #include "xla/types.h" #include "xla/util.h" diff --git a/third_party/xla/xla/hlo/analysis/tuple_points_to_analysis.h b/third_party/xla/xla/hlo/analysis/tuple_points_to_analysis.h new file mode 100644 index 00000000000000..d182cfc7231f20 --- /dev/null +++ b/third_party/xla/xla/hlo/analysis/tuple_points_to_analysis.h @@ -0,0 +1,370 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_ANALYSIS_TUPLE_POINTS_TO_ANALYSIS_H_ +#define XLA_HLO_ANALYSIS_TUPLE_POINTS_TO_ANALYSIS_H_ + +#include + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/hlo/analysis/logical_buffer_analysis.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/logical_buffer.h" +#include "xla/shape_tree.h" +#include "xla/tsl/lib/gtl/compactptrset.h" +#include "xla/types.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/status.h" + +namespace xla { + +// A class describing the source(s) of the Buffer(s) contained in the output of +// a particular HLO instruction. The structure of PointsToSet mirrors the +// structure of the instruction's shape, which may be an arbitrary tree (eg, a +// nested tuple). Each node in this tree corresponds to a single buffer in the +// instruction's output and contains the set of Buffers which might define +// the corresponding buffer. +class PointsToSet { + public: + // Construct our ShapeTree with a pointer rather than a reference to a Shape + // because this is very hot code, and copying (and then destroying) all these + // Shapes is slow. + explicit PointsToSet(const Shape* shape) : tree_(shape) {} + + // Returns true if any points-to sets for any subshape element is not a + // singleton. + bool IsAmbiguous() const; + + // Returns true if no LogicalBuffer appears in more than one points-to set of + // the shape nodes. + bool IsDistinct() const; + + // Returns the total number of different LogicalBuffers contained in this + // object. This is equal to CreateFlattenedSet().size(). + size_t size() const; + + // Creates a set containing the union of all LogicalBuffers contained in the + // PointsToSet. + using BufferSet = tsl::gtl::CompactPointerSet; + BufferSet CreateFlattenedSet() const; + + // Returns true if the given buffer is in the points-to set at the given + // index. + bool ContainsBufferAtIndex(const LogicalBuffer& buffer, + const ShapeIndex& index) const; + + // Returns true if the given buffer is in the points-to set at any index. + bool ContainsBuffer(const LogicalBuffer& buffer) const; + + // Adds the given buffer to the points-to set at the given index. This is a + // nop if the buffer already is in the set at that index. + void AddPointedToBuffer(const LogicalBuffer& buffer, const ShapeIndex& index); + + // For the subshape at the given index (where index is defined as in + // ShapeUtil::GetSubshape) this method returns the set of HLO instructions + // which may produce the tuple subshape at that index. For example, given: + // + // %tuple1 = tuple(...) + // %tuple2 = tuple(...) + // %select = select(%tuple1, %tuple2) + // %nested_tuple = tuple(%select, %tuple1) + // + // These are the values for tuple_sources() for the PointsToSet of + // %nested_tuple: + // + // tuple_sources({}) = {%nested_tuple} + // tuple_sources({0}) = {%tuple1, %tuple2} + // tuple_sources({1}) = {%tuple1} + // + // tuple_sources() at the index of an array shape (not a tuple) returns the + // empty set. The instructions in the set returned by tuple_sources + // necessarily are either Tuple instructions, constants, or parameters. + using SourceSet = tsl::gtl::CompactPointerSet; + const SourceSet& tuple_sources(const ShapeIndex& index) const; + + // Add a tuple source instruction for the given index. + void add_tuple_source(const ShapeIndex& index, HloInstruction* tuple); + + using BufferList = absl::InlinedVector; + + // Return the list of logical buffers for the subshape at index. + const BufferList& element(const ShapeIndex& index) const { + return tree_.element(index).buffers; + } + BufferList* mutable_element(const ShapeIndex& index) { + return &tree_.mutable_element(index)->buffers; + } + + // Call fn(index, buflist) for every subshape index. + template + void ForEachElement(const Fn& fn) const { + tree_.ForEachElement([&fn](const ShapeIndex& index, const Elem& elem) { + fn(index, elem.buffers); + }); + } + template + void ForEachMutableElement(const Fn& fn) { + tree_.ForEachMutableElement([&fn](const ShapeIndex& index, Elem* elem) { + fn(index, &elem->buffers); + }); + } + template + absl::Status ForEachElementWithStatus(const Fn& fn) const { + return tree_.ForEachElementWithStatus( + [&fn](const ShapeIndex& index, const Elem& elem) { + return fn(index, elem.buffers); + }); + } + + private: + struct Elem { + BufferList buffers; + SourceSet tuple_sources; + }; + ShapeTree tree_; + + // PointsToSet contains references (const LogicalBuffer*) to elements within + // TuplePointsToAnalysis, so disable copying. + PointsToSet(const PointsToSet&) = delete; + PointsToSet& operator=(const PointsToSet&) = delete; +}; + +// This class describes a particular subshape in a computation (instruction and +// shape index) and the logical buffer which may be a source of the subshape +// value. +class BufferAlias { + public: + BufferAlias(HloInstruction* instruction, const ShapeIndex& index) + : instruction_(instruction), index_(index) {} + + // Return the instruction/index of the subshape. + HloInstruction* instruction() const { return instruction_; } + const ShapeIndex& index() const { return index_; } + + bool operator==(const BufferAlias& other) const { + return instruction_ == other.instruction_ && index_ == other.index_; + } + bool operator!=(const BufferAlias& other) const { return !(*this == other); } + + std::string ToString() const; + + private: + HloInstruction* instruction_; + ShapeIndex index_; +}; + +std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias); + +// DFS visitor that performs tuple points-to analysis. This analysis determines +// the potential sources of each buffer in each instruction's output. +class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { + public: + // Runs points-to analysis on 'module'. + static absl::StatusOr> Run( + const HloModule* module); + + // Return the points-to set of an instruction. This describes the potential + // sources of each buffer in the instruction's output. + const PointsToSet& GetPointsToSet( + const HloInstruction* hlo_instruction) const; + + // Returns the logical buffer with the given ID. + const LogicalBuffer& GetBuffer(LogicalBuffer::Id id) const; + + // Returns the buffer defined at the given instruction and index. An error is + // returned if no buffer is defined at that point. + absl::StatusOr GetBufferDefinedAt( + const HloInstruction* instruction, const ShapeIndex& index) const; + + // Return a (possibly empty) vector containing all BufferAliases of the given + // logical buffer The buffer alias set is the inverse of the points-to set. + // That is, LogicalBuffer B is in the points-to set of instruction I at index + // N iff instruction I, index N is a BufferAlias of B. + using BufferAliasVector = absl::InlinedVector; + const BufferAliasVector& GetBufferAliases(const LogicalBuffer& buffer) const; + + // Returns the number of logical buffers in the module + LogicalBuffer::Id num_logical_buffers() const { + return logical_buffer_analysis_->num_logical_buffers(); + } + + // Return a the logical buffer with id "id" in the module. Iteration + // over all logical buffers is usually done with something like: + // + // for (LogicalBuffer:Id id = 0; id < points_to.num_logical_buffers(); id++){ + // const auto& buffer = points_to.logical_buffer(id); + // ... do something with buffer ... + // } + LogicalBuffer& logical_buffer(LogicalBuffer::Id id) const { + return logical_buffer_analysis_->GetBuffer(id); + } + + // Returns a vector of buffers that the instruction produces. Most + // instructions produce a single buffer (the top-level buffer), some produce + // no buffers (eg bitcast), and some produce more than one buffer (eg, + // tuple-shaped parameters). + using BufferDefinitionVector = absl::InlinedVector; + const BufferDefinitionVector& GetBuffersDefinedByInstruction( + const HloInstruction* instruction) const; + + // Returns true if the given instruction defines a buffer at the given index. + bool InstructionDefinesBufferAtIndex(const HloInstruction* instruction, + const ShapeIndex& index) const; + + // Returns an OK status if the given buffer is defined by instruction + // 'buffer.instruction()' at index 'buffer.index()' and if the given buffer + // matches the TuplePointsToAnalysis' LogicalBuffer with 'buffer.id'. Returns + // an FailedPrecondition error status otherwise. An example of a LogicalBuffer + // which is not defined is a tuple element in a Tuple instruction. In this + // case, the Tuple instruction does not define the LogicalBuffer, rather that + // index aliases one of its operands. + absl::Status VerifyBuffer(const LogicalBuffer& buffer) const; + + absl::Status DefaultAction(HloInstruction* hlo_instruction) override; + absl::Status HandleTuple(HloInstruction* tuple) override; + absl::Status HandleGetTupleElement( + HloInstruction* get_tuple_element) override; + absl::Status HandleAsyncStart(HloInstruction* async_start) override; + absl::Status HandleAsyncUpdate(HloInstruction* async_update) override; + absl::Status HandleAsyncDone(HloInstruction* async_done) override; + absl::Status HandleBitcast(HloInstruction* bitcast) override; + absl::Status HandleDomain(HloInstruction* domain) override; + absl::Status HandleCopy(HloInstruction* copy) override; + absl::Status HandleCopyStart(HloInstruction* copy_start) override; + absl::Status HandleCopyDone(HloInstruction* copy_done) override; + absl::Status HandleRecvDone(HloInstruction* recv_done) override; + absl::Status HandleSend(HloInstruction* send) override; + absl::Status HandleAddDependency(HloInstruction* add_dependency) override; + absl::Status HandleCustomCall(HloInstruction* custom_call) override; + absl::Status HandleFusion(HloInstruction* fusion) override; + absl::Status HandleOptimizationBarrier(HloInstruction* barrier) override; + + std::string ToString() const; + + // Returns true if 'user' cannot possibly use the buffer at 'index' in + // 'operand'. Returns false otherwise. + // + // REQUIRES: 'operand' is an operand of 'user'. + bool DoesNotUseOperandBuffer(const HloInstruction* operand, + const ShapeIndex& index, + const HloInstruction* user) const; + + private: + explicit TuplePointsToAnalysis( + const HloModule* module, + std::unique_ptr logical_buffer_analysis) + : module_(module), + logical_buffer_analysis_(std::move(logical_buffer_analysis)) {} + + // Perform the analysis. Should be called immediately after constructing the + // object and before calling GetPointsToSet. + absl::Status Analyze(); + + // Populates instruction-defined buffers and aliases for each instruction + // in 'instructions'. + absl::Status PopulateDefinedBuffersAndAliases( + const decltype(std::declval() + .instructions())& instructions); + + // Creates an empty PointsToSet in the points_to_ map for the given + // instruction. + PointsToSet& CreateEmptyPointsToSet(const HloInstruction* instruction); + + // Creates a PointsToSet in the points_to_ map for 'instruction' which is a + // copy of the existing PointsToSet for 'src'. + PointsToSet& CreateCopiedPointsToSet(const HloInstruction* instruction, + const HloInstruction* src); + + // Adds the buffers defined by the given instruction to the given vector. + absl::Status GatherBuffersDefinedByInstruction( + const HloInstruction* instruction, BufferDefinitionVector* buffers); + + // Print points-to set for 'instruction' to 'output'. + void InstructionToString(const HloInstruction* instruction, + std::string* output) const; + + // Information kept per instruction + struct PerInstruction { + std::unique_ptr points_to_set; + // Empirically, ~92% of instructions have 1 + // instruction_defined_buffer, and 99% have 0 or 1 + BufferDefinitionVector instruction_defined_buffers; + }; + + const PerInstruction* PerInst(const HloInstruction* inst) const { + int id = inst->unique_id(); + DCHECK_GE(id, 0); + auto iter = per_instruction_.find(id); + if (iter == per_instruction_.end()) { + LOG(FATAL) << "Expected per-instruction information to already exist"; + } else { + return iter->second.get(); + } + } + PerInstruction* PerInst(const HloInstruction* inst) { + int id = inst->unique_id(); + DCHECK_GE(id, 0); + auto iter = per_instruction_.find(id); + if (iter == per_instruction_.end()) { + return per_instruction_.emplace(id, std::make_unique()) + .first->second.get(); + } else { + return iter->second.get(); + } + } + + std::vector> + GetAllUsesOfInstructionAtIndex(HloInstruction* instruction, + const ShapeIndex& index) const; + bool HasUniqueFusedUseOfOperandAt(HloInstruction* operand, + const ShapeIndex& operand_index, + HloInstruction* fusion, + const int64_t use_operand_index) const; + + // The module this analysis is performed on. + const HloModule* module_; + + // The logical buffers for this module. + const std::unique_ptr logical_buffer_analysis_; + + // A map from instruction->unique_id() to + absl::flat_hash_map> per_instruction_; + + // A map from LogicalBuffer->id() to alias information about that logical + // buffer + std::vector logical_buffer_aliases_; + + TuplePointsToAnalysis(const TuplePointsToAnalysis&) = delete; + TuplePointsToAnalysis& operator=(const TuplePointsToAnalysis&) = delete; + // Whether to alias buffers connected by dataflow relations. This aliasing + // relation should not be recognized if copies can be inserted to break up + // the dataflow relation. + const bool alias_buffer_across_dataflow_ = false; +}; + +} // namespace xla + +#endif // XLA_HLO_ANALYSIS_TUPLE_POINTS_TO_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/tuple_points_to_analysis_test.cc b/third_party/xla/xla/hlo/analysis/tuple_points_to_analysis_test.cc similarity index 99% rename from third_party/xla/xla/service/tuple_points_to_analysis_test.cc rename to third_party/xla/xla/hlo/analysis/tuple_points_to_analysis_test.cc index 90689b8e1227b9..b340b03f0440b8 100644 --- a/third_party/xla/xla/service/tuple_points_to_analysis_test.cc +++ b/third_party/xla/xla/hlo/analysis/tuple_points_to_analysis_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/tuple_points_to_analysis.h" +#include "xla/hlo/analysis/tuple_points_to_analysis.h" #include #include @@ -28,13 +28,13 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/literal_util.h" #include "xla/service/logical_buffer.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" -#include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" @@ -46,7 +46,7 @@ namespace { using ::testing::UnorderedElementsAre; using ::testing::UnorderedElementsAreArray; -class TuplePointsToAnalysisTest : public HloTestBase { +class TuplePointsToAnalysisTest : public HloHardwareIndependentTestBase { protected: // Builds a module with the given entry computation and runs points to // analysis. @@ -745,7 +745,7 @@ ENTRY %FusionParam0TwoUsers (param0: (f32[8], f32[3])) -> f32[8] { Run(hlo_str, /*expected_num_users=*/2); } -class PointsToAnalysisTestBase : public HloTestBase { +class PointsToAnalysisTestBase : public HloHardwareIndependentTestBase { protected: void BuildModule(std::unique_ptr computation) { module_ = CreateNewVerifiedModule(); diff --git a/third_party/xla/xla/service/while_loop_analysis.cc b/third_party/xla/xla/hlo/analysis/while_loop_analysis.cc similarity index 89% rename from third_party/xla/xla/service/while_loop_analysis.cc rename to third_party/xla/xla/hlo/analysis/while_loop_analysis.cc index 9ca0da67eaed17..92903693c34e0c 100644 --- a/third_party/xla/xla/service/while_loop_analysis.cc +++ b/third_party/xla/xla/hlo/analysis/while_loop_analysis.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/while_loop_analysis.h" +#include "xla/hlo/analysis/while_loop_analysis.h" #include #include @@ -83,9 +83,7 @@ static optional GetGTEOperandIndex(const HloInstruction* instr, } if (!Match(possibly_gte_operand, - m::GetTupleElement(m::Op().Is(gte_operand))) && - !Match(possibly_gte_operand, - m::GetTupleElement(m::CustomCall(m::Op().Is(gte_operand))))) { + m::GetTupleElement(m::Op().Is(gte_operand)))) { return nullopt; } @@ -282,7 +280,7 @@ optional GetLoopInductionVarTupleIdx(const HloInstruction* while_op) { return nullopt; } - // The while_body computation should have one of the following forms: + // The while_body computation should have the form: // // Form 1: // while_body_inc = @@ -290,60 +288,15 @@ optional GetLoopInductionVarTupleIdx(const HloInstruction* while_op) { // while_body_root = tuple(..., while_body_inc, ...) // // where while_body_inc is operand N of while_body_root. - // - // Form 2: - // while_body_inc = - // op(constants, get-tuple-elem(while_body_param, N), constants) - // tuple = tuple(..., while_body_inc, ...) - // while_body_root = custom-call(tuple) - // - // where while_body_inc is operand N of the tuple, and the tuple is the - // operand of the while_body_root custom-call instruction. - // - // Form 3: - // while_body_inc = - // op(constants, get-tuple-elem(while_body_param, N), constants) - // while_body_root = custom-call(input1, ..., while_body_inc, ...) - // - // where while_body_inc is an operand of the while_body_root custom-call - // instruction, and the custom-call instruction does not have a tuple operand. auto* while_body = while_op->while_body(); auto* while_body_root = while_body->root_instruction(); - if (while_body_root->opcode() != HloOpcode::kTuple && - while_body_root->opcode() != HloOpcode::kCustomCall) { - VLOG(2) << "While body's root is not a tuple or custom-call instruction: " + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While body's root is not a tuple instruction: " << while_body_root->ToString(); return nullopt; } const HloInstruction* while_body_inc; - if (while_body_root->opcode() == HloOpcode::kTuple) { - while_body_inc = while_body_root->operand(*indvar_tuple_idx); - } else { - // Custom-call cases - if (while_body_root->operand_count() == 1 && - while_body_root->operand(0)->opcode() == HloOpcode::kTuple) { - // Custom-call case - // Single tuple operand. - auto* while_body_root_input_tuple = while_body_root->operand(0); - if (*indvar_tuple_idx >= while_body_root_input_tuple->operand_count()) { - VLOG(2) << "Cannot find the induction variable in the output root " - "custom-call " - << while_body_root->ToString(); - return std::nullopt; - } - while_body_inc = while_body_root_input_tuple->operand(*indvar_tuple_idx); - } else { - // Custom-call case - // Operand is not single tuple. - if (*indvar_tuple_idx >= while_body_root->operand_count()) { - VLOG(2) << "Cannot find the induction variable in the output root " - "custom-call " - << while_body_root->ToString(); - return std::nullopt; - } - while_body_inc = while_body_root->operand(*indvar_tuple_idx); - } - } + while_body_inc = while_body_root->operand(*indvar_tuple_idx); auto* while_body_param = while_body->parameter_instruction(0); optional while_body_indvar_tuple_idx = GetGTEOperandIndex(while_body_inc, while_body_param); @@ -415,28 +368,8 @@ optional MatchTrivialLoopTripCount(const HloInstruction* while_op, // Check that `i` goes as `i += k` in the while body where k is a natural // number. auto* while_body = while_op->while_body(); - auto* while_body_root = while_body->root_instruction(); - HloInstruction* while_body_indvar_update; - - if (while_body_root->opcode() == HloOpcode::kCustomCall) { - // We know it must be a custom-call. - if (while_body_root->operand_count() == 1 && - while_body_root->operand(0)->opcode() == HloOpcode::kTuple) { - // Custom-call case - // Single tuple operand. - auto* while_body_root_input_tuple = while_body_root->mutable_operand(0); - while_body_indvar_update = - while_body_root_input_tuple->mutable_operand(indvar_tuple_idx); - } else { - // Custom-call case - // Operand is not single tuple. - while_body_indvar_update = - while_body_root->mutable_operand(indvar_tuple_idx); - } - } else { - while_body_indvar_update = - while_body_root->mutable_operand(indvar_tuple_idx); - } + auto* while_body_indvar_update = + while_body->root_instruction()->mutable_operand(indvar_tuple_idx); auto* while_body_indvar = NonConstantOperand(while_body_indvar_update); HloInstruction* trip_count_increase_step_instr = nullptr; int64_t trip_count_step = 0; diff --git a/third_party/xla/xla/hlo/analysis/while_loop_analysis.h b/third_party/xla/xla/hlo/analysis/while_loop_analysis.h new file mode 100644 index 00000000000000..e2270489cde49e --- /dev/null +++ b/third_party/xla/xla/hlo/analysis/while_loop_analysis.h @@ -0,0 +1,62 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_ANALYSIS_WHILE_LOOP_ANALYSIS_H_ +#define XLA_HLO_ANALYSIS_WHILE_LOOP_ANALYSIS_H_ + +#include + +#include "xla/hlo/ir/hlo_instruction.h" + +namespace xla { + +// Returns the precise trip count of the loop if it's statically known, +// nullopt otherwise. +// +// max_brute_force_iters limits the number of steps that are evaluated while +// trying to brute force a loop trip count. trip counts larger than +// max_brute_force_iters may be returned if we can pattern-match the loop +// condition. +std::optional ComputeWhileLoopTripCount( + const HloInstruction *while_op, int64_t max_brute_force_iters = 128); + +// Returns an upper bound on the trip count of the loop if it's statically +// known, nullopt otherwise. +std::optional ComputeWhileLoopTripCountUpperBound( + const HloInstruction *while_op); + +// The below function identifies a subset of all possible auxiliary +// induction variables (AIV). Specifically, candidates are gtes, e.g., +// gte(param0, N) +std::vector GetAuxiliaryLoopInductionVars( + const HloInstruction *while_op); +// Returns the tuple index of the loop induction variable if there is such an +// induction variable detected. Otherwise returns nullopt. +std::optional GetLoopInductionVarTupleIdx( + const HloInstruction *while_op); + +// Checks the following conditions: +// - `i`, the induction varaiable, is initialized to a scalar constant K +// (namely, `indvar_init`), +// - the while condition does `i < N` or `i <= N` (where N is a know constant) +// - the while body does `i++`. +// If so, it's trivial to compute the loop bound as `N - k` or `N - k + 1`, +// respectively. +std::optional MatchTrivialLoopTripCount(const HloInstruction *while_op, + int64_t indvar_tuple_idx, + const Literal &indvar_init); +} // namespace xla + +#endif // XLA_HLO_ANALYSIS_WHILE_LOOP_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/while_loop_analysis_test.cc b/third_party/xla/xla/hlo/analysis/while_loop_analysis_test.cc similarity index 97% rename from third_party/xla/xla/service/while_loop_analysis_test.cc rename to third_party/xla/xla/hlo/analysis/while_loop_analysis_test.cc index d8e3c05c31b25f..13e8af536f46cd 100644 --- a/third_party/xla/xla/service/while_loop_analysis_test.cc +++ b/third_party/xla/xla/hlo/analysis/while_loop_analysis_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/while_loop_analysis.h" +#include "xla/hlo/analysis/while_loop_analysis.h" #include #include @@ -33,15 +33,15 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/test.h" -#include "xla/tests/hlo_test_base.h" #include "xla/util.h" #include "tsl/platform/statusor.h" namespace xla { namespace { -class WhileLoopAnalysisTest : public HloTestBase { +class WhileLoopAnalysisTest : public HloHardwareIndependentTestBase { protected: [[nodiscard]] absl::StatusOr MakeWhileLoopAndGetTripCount( int init, int limit, int step, ComparisonDirection dir); @@ -159,7 +159,7 @@ TEST_F(WhileLoopAnalysisTest, SimpleLoopWithCustomCallNonTuple) { )"; auto m = ParseAndReturnVerifiedModule(hlo_string).value(); HloInstruction* while_op = m->entry_computation()->root_instruction(); - EXPECT_EQ(*ComputeWhileLoopTripCountUpperBound(while_op), 5); + EXPECT_EQ(ComputeWhileLoopTripCountUpperBound(while_op), std::nullopt); } TEST_F(WhileLoopAnalysisTest, SimpleLoopWithCustomCall) { @@ -192,7 +192,7 @@ TEST_F(WhileLoopAnalysisTest, SimpleLoopWithCustomCall) { )"; auto m = ParseAndReturnVerifiedModule(hlo_string).value(); HloInstruction* while_op = m->entry_computation()->root_instruction(); - EXPECT_EQ(*ComputeWhileLoopTripCountUpperBound(while_op), 5); + EXPECT_EQ(ComputeWhileLoopTripCountUpperBound(while_op), std::nullopt); } TEST_F(WhileLoopAnalysisTest, NoUpperBound) { diff --git a/third_party/xla/xla/hlo/builder/BUILD b/third_party/xla/xla/hlo/builder/BUILD new file mode 100644 index 00000000000000..1d23ec480af36c --- /dev/null +++ b/third_party/xla/xla/hlo/builder/BUILD @@ -0,0 +1,190 @@ +# Description: +# XLA builder libraries. + +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla:xla.bzl", "xla_cc_test") +load("//xla/tsl:tsl.default.bzl", "filegroup") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +cc_library( + name = "padding", + srcs = ["padding.cc"], + hdrs = ["padding.h"], + deps = [ + "//xla:util", + "//xla/tsl/lib/math:math_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + ], +) + +xla_cc_test( + name = "padding_test", + srcs = ["padding_test.cc"], + deps = [ + ":padding", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "sharding_builder", + srcs = ["sharding_builder.cc"], + hdrs = ["sharding_builder.h"], + deps = [ + "//xla:array", + "//xla:shape_tree", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "@com_google_absl//absl/log:check", + ], +) + +cc_library( + name = "xla_computation", + srcs = ["xla_computation.cc"], + hdrs = ["xla_computation.h"], + visibility = ["//visibility:public"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "value_inference", + srcs = ["value_inference.cc"], + hdrs = ["value_inference.h"], + visibility = ["//visibility:public"], + deps = [ + ":xla_builder", + "//xla:comparison_util", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/evaluator:hlo_evaluator", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", + "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "xla_builder", + srcs = ["xla_builder.cc"], + hdrs = ["xla_builder.h"], + visibility = ["//visibility:public"], + deps = [ + ":padding", + ":sharding_builder", + ":xla_computation", + "//xla:array", + "//xla:array2d", + "//xla:array3d", + "//xla:array4d", + "//xla:comparison_util", + "//xla:literal", + "//xla:literal_util", + "//xla:permutation_util", + "//xla:shape_util", + "//xla:sharding_op_util", + "//xla:status_macros", + "//xla:util", + "//xla:window_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_proto_cc", + "//xla/service:shape_inference", + "//xla/tsl/lib/core:bitmap", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:stacktrace", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "xla_builder_test", + srcs = ["xla_builder_test.cc"], + deps = [ + ":padding", + ":sharding_builder", + ":value_inference", + ":xla_builder", + ":xla_computation", + "//xla:comparison_util", + "//xla:debug_options_flags", + "//xla:shape_util", + "//xla:test", + "//xla:test_helpers", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/service:hlo_proto_cc", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) diff --git a/third_party/xla/xla/hlo/builder/lib/BUILD b/third_party/xla/xla/hlo/builder/lib/BUILD new file mode 100644 index 00000000000000..d096bb73ba289a --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/BUILD @@ -0,0 +1,787 @@ +# Common computation builders for XLA. + +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla/tests:build_defs.bzl", "generate_backend_suites", "xla_test") +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.default.bzl", "filegroup") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility(["//xla/hlo/builder:friends"]), + licenses = ["notice"], +) + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +# Generate test_suites for all backends, named "${backend}_tests". +generate_backend_suites() + +cc_library( + name = "arithmetic", + srcs = ["arithmetic.cc"], + hdrs = ["arithmetic.h"], + deps = [ + ":constants", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "arithmetic_test", + srcs = ["arithmetic_test.cc"], + deps = [ + ":arithmetic", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "comparators", + srcs = ["comparators.cc"], + hdrs = [ + "comparators.h", + ], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +xla_test( + name = "comparators_test", + srcs = ["comparators_test.cc"], + deps = [ + ":comparators", + ":constants", + "//xla:shape_util", + "//xla:test", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_proto_cc", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "constants", + srcs = ["constants.cc"], + hdrs = ["constants.h"], + deps = [ + "//xla:literal_util", + "//xla:shape_util", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:ml_dtypes", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "broadcast", + srcs = ["broadcast.cc"], + hdrs = ["broadcast.h"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "constants_test", + srcs = ["constants_test.cc"], + deps = [ + ":constants", + "//xla:shape_util", + "//xla:test", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "conv_grad_size_util", + srcs = ["conv_grad_size_util.cc"], + hdrs = ["conv_grad_size_util.h"], + deps = [ + "//xla/hlo/builder:padding", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "dynamic_shaped_ops", + srcs = ["dynamic_shaped_ops.cc"], + hdrs = ["dynamic_shaped_ops.h"], + deps = [ + ":constants", + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:value_inference", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "loops", + srcs = ["loops.cc"], + hdrs = ["loops.h"], + deps = [ + ":constants", + "//xla:literal_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "math", + srcs = ["math.cc"], + hdrs = [ + "math.h", + "math_impl.h", + ], + deps = [ + ":arithmetic", + ":constants", + ":loops", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "math_test", + timeout = "long", + srcs = ["math_test.cc"], + backend_tags = { + # Times out. + "ghostfish_iss": ["noasan"], + }, + deps = [ + ":constants", + ":math", + "//xla:array3d", + "//xla:error_spec", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:test", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/service", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "//xla/tsl/lib/core:status_test_util", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "matrix", + srcs = ["matrix.cc"], + hdrs = ["matrix.h"], + deps = [ + ":arithmetic", + ":constants", + ":slicing", + "//xla:literal", + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "matrix_test", + srcs = ["matrix_test.cc"], + deps = [ + ":constants", + ":matrix", + ":slicing", + "//xla:array", + "//xla:array2d", + "//xla:array3d", + "//xla:array4d", + "//xla:test", + "//xla:types", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "pooling", + srcs = ["pooling.cc"], + hdrs = ["pooling.h"], + deps = [ + ":arithmetic", + ":constants", + ":conv_grad_size_util", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:padding", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "pooling_test", + srcs = ["pooling_test.cc"], + deps = [ + ":pooling", + "//xla:error_spec", + "//xla:shape_util", + "//xla/hlo/builder:padding", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "prng", + srcs = ["prng.cc"], + hdrs = ["prng.h"], + deps = [ + ":constants", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "prng_test", + srcs = ["prng_test.cc"], + deps = [ + ":constants", + ":prng", + "//xla:shape_util", + "//xla:test", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "qr", + srcs = ["qr.cc"], + hdrs = ["qr.h"], + deps = [ + ":constants", + ":matrix", + ":slicing", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "qr_test", + srcs = ["qr_test.cc"], + tags = ["optonly"], + deps = [ + ":matrix", + ":qr", + "//xla:array", + "//xla:array2d", + "//xla:array3d", + "//xla:error_spec", + "//xla:shape_util", + "//xla:test", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "lu_decomposition", + srcs = ["lu_decomposition.cc"], + hdrs = ["lu_decomposition.h"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "approx_topk", + srcs = ["approx_topk.cc"], + hdrs = ["approx_topk.h"], + deps = [ + ":approx_topk_shape", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "approx_topk_shape", + srcs = ["approx_topk_shape.cc"], + hdrs = ["approx_topk_shape.h"], + deps = [ + "//xla:util", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "slicing", + srcs = ["slicing.cc"], + hdrs = ["slicing.h"], + deps = [ + ":arithmetic", + ":constants", + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "slicing_test", + srcs = ["slicing_test.cc"], + deps = [ + ":slicing", + "//xla:array2d", + "//xla:array3d", + "//xla:error_spec", + "//xla:literal", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "sorting", + srcs = ["sorting.cc"], + hdrs = ["sorting.h"], + deps = [ + ":comparators", + ":constants", + ":loops", + ":slicing", + "//xla:shape_util", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "sorting_test", + srcs = ["sorting_test.cc"], + deps = [ + ":sorting", + "//xla:array", + "//xla:array2d", + "//xla:error_spec", + "//xla:literal_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "@com_google_absl//absl/algorithm:container", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "quantize", + hdrs = ["quantize.h"], + deps = [ + ":constants", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@local_tsl//tsl/platform:bfloat16", + ], +) + +xla_test( + name = "quantize_test", + srcs = ["quantize_test.cc"], + # TODO(b/122119490): re-enable TAP after fixing. + tags = [ + "manual", + "notap", + ], + deps = [ + ":quantize", + "//xla:array2d", + "//xla:test", + "//xla:types", + "//xla:util", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "@local_tsl//tsl/platform:bfloat16", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "self_adjoint_eig", + srcs = ["self_adjoint_eig.cc"], + hdrs = ["self_adjoint_eig.h"], + deps = [ + ":slicing", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "self_adjoint_eig_test", + srcs = ["self_adjoint_eig_test.cc"], + real_hardware_only = True, + shard_count = 5, + tags = ["optonly"], + deps = [ + ":arithmetic", + ":constants", + ":math", + ":matrix", + ":self_adjoint_eig", + "//xla:array", + "//xla:array2d", + "//xla:array3d", + "//xla:error_spec", + "//xla:shape_util", + "//xla:test", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "svd", + srcs = ["svd.cc"], + hdrs = ["svd.h"], + deps = [ + ":arithmetic", + ":comparators", + ":constants", + ":loops", + ":math", + ":matrix", + ":slicing", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "svd_test", + srcs = ["svd_test.cc"], + real_hardware_only = True, + shard_count = 10, + tags = ["optonly"], + deps = [ + ":arithmetic", + ":constants", + ":matrix", + ":slicing", + ":svd", + "//xla:array2d", + "//xla:array3d", + "//xla:error_spec", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "tridiagonal", + srcs = ["tridiagonal.cc"], + hdrs = ["tridiagonal.h"], + deps = [ + ":constants", + ":loops", + ":slicing", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "tridiagonal_test", + srcs = ["tridiagonal_test.cc"], + real_hardware_only = True, + shard_count = 10, + tags = ["optonly"], + deps = [ + ":slicing", + ":tridiagonal", + "//xla:array", + "//xla:array3d", + "//xla:literal", + "//xla:shape_util", + "//xla:test", + "//xla:util", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "@com_google_absl//absl/status", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "logdet", + srcs = ["logdet.cc"], + hdrs = ["logdet.h"], + deps = [ + ":arithmetic", + ":constants", + ":matrix", + ":qr", + ":slicing", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "logdet_test", + srcs = ["logdet_test.cc"], + tags = [ + "optonly", + ], + deps = [ + ":logdet", + "//xla:array", + "//xla:array2d", + "//xla:array3d", + "//xla:error_spec", + "//xla:literal", + "//xla:literal_util", + "//xla/hlo/builder:xla_builder", + "//xla/tests:client_library_test_base", + "//xla/tests:test_macros_header", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "tuple", + srcs = ["tuple.cc"], + hdrs = ["tuple.h"], + deps = [ + "//xla:shape_tree", + "//xla:shape_util", + "//xla/hlo/builder:xla_builder", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "tuple_test", + srcs = ["tuple_test.cc"], + deps = [ + ":tuple", + "//xla:error_spec", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_tree", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/service", + "//xla/tests:client_library_test_base", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) diff --git a/third_party/xla/xla/client/lib/approx_topk.cc b/third_party/xla/xla/hlo/builder/lib/approx_topk.cc similarity index 98% rename from third_party/xla/xla/client/lib/approx_topk.cc rename to third_party/xla/xla/hlo/builder/lib/approx_topk.cc index 7a5c7bd379cb82..16e9c090e9dd3b 100644 --- a/third_party/xla/xla/client/lib/approx_topk.cc +++ b/third_party/xla/xla/hlo/builder/lib/approx_topk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/approx_topk.h" +#include "xla/hlo/builder/lib/approx_topk.h" #include #include @@ -23,9 +23,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" -#include "xla/client/lib/approx_topk_shape.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/approx_topk_shape.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" diff --git a/third_party/xla/xla/hlo/builder/lib/approx_topk.h b/third_party/xla/xla/hlo/builder/lib/approx_topk.h new file mode 100644 index 00000000000000..f940d26967cc76 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/approx_topk.h @@ -0,0 +1,72 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_APPROX_TOPK_H_ +#define XLA_HLO_BUILDER_LIB_APPROX_TOPK_H_ + +#include "absl/types/span.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Computes approximate top-ks by aggregating top-1s in equal-sized windows. +// The number and the size of the windows are determined by the `recall_target`. +// +// operand: A sequence of multi-dimensional arrays of type T_0, ..., T_{N-1} +// init_values: N starting values for top-1 reductions +// top_k: Determines the k in top-k operation. +// reduction_dim: Determines the dimension to compute top-k. +// comparator: The comparator computation to use, which should have function +// signatore of (T_0, T_0, T_1, T_1, ..., T_{N-1}, T_{N-1}) -> bool. +// recall_target: Valid range (0, 1]. User can trade-off quality and performance +// with this knob. +// aggregate_to_topk: When true, sorts the set of approximate top-k elements and +// only keep the final k elements on TPU. This option is useful when user +// wanted to forward the approximate results to host and aggregate the results +// on CPU for better throughput. +// reduction_input_size_override: When set to a positive value, it overrides the +// size determined by operands[reduction_dim] for evaluating the recall. This +// option is useful when the given operand is only a subset of the overall +// computation in SPMD or distributed pipelines, where the true input size +// cannot be deferred by the operand shape. +// +// Returns a sequence of multidimensional arrays of type T_0, ..., T_{N-1}, +// which contains the approximate top-ks from the input operands. When +// `aggregate_to_topk` is set to true, the output size is just top_k. When +// `aggregate_to_topk` is set to false, the output size varied by the target +// recall. For target recall = 0.9, the output size is roughly 10 * top_k. For +// target recall = 0.99, the output size is roughly 100 * top_k. +// +// TODO(fchern): Support other hardware platforms. +XlaOp ApproxTopK(XlaBuilder* builder, absl::Span operands, + absl::Span init_values, int64_t top_k, + int64_t reduction_dim, const XlaComputation& comparator, + float recall_target = 0.9, bool aggregate_to_topk = true, + int64_t reduction_input_size_override = -1); + +// Fallback for platforms that haven't been optimized. +XlaOp ApproxTopKFallback(XlaBuilder* builder, absl::Span operands, + absl::Span init_values, int64_t top_k, + int64_t reduction_dim, + const XlaComputation& comparator, + float recall_target = 0.9, + bool aggregate_to_topk = true, + int64_t reduction_input_size_override = -1); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_APPROX_TOPK_H_ diff --git a/third_party/xla/xla/client/lib/approx_topk_shape.cc b/third_party/xla/xla/hlo/builder/lib/approx_topk_shape.cc similarity index 98% rename from third_party/xla/xla/client/lib/approx_topk_shape.cc rename to third_party/xla/xla/hlo/builder/lib/approx_topk_shape.cc index 374aa01830fccf..f6925f330c1267 100644 --- a/third_party/xla/xla/client/lib/approx_topk_shape.cc +++ b/third_party/xla/xla/hlo/builder/lib/approx_topk_shape.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/approx_topk_shape.h" +#include "xla/hlo/builder/lib/approx_topk_shape.h" #include #include diff --git a/third_party/xla/xla/hlo/builder/lib/approx_topk_shape.h b/third_party/xla/xla/hlo/builder/lib/approx_topk_shape.h new file mode 100644 index 00000000000000..83b2b71d1054e5 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/approx_topk_shape.h @@ -0,0 +1,50 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_APPROX_TOPK_SHAPE_H_ +#define XLA_HLO_BUILDER_LIB_APPROX_TOPK_SHAPE_H_ + +#include + +#include "absl/status/statusor.h" + +namespace xla { + +// Determine the output size of the reduction dimension. This is useful for jax +// abstract eval to determine the output size. +// +// input_size: Input size of the reduction dimension. +// rank: Rank of the input operand. +// top_k: Determines the k in top-k operation. +// recall_target: Valid range (0, 1]. User can trade-off quality and performance +// with this knob. +// aggregate_to_topk: When true, sorts the set of approximate top-k elements and +// only keep the final k elements on TPU. This option is useful when user +// wanted to forward the approximate results to host and aggregate the results +// on CPU for better throughput. +// +// Returns a pair of +// 1. Reduction output size +// 2. Reduction amount in log2 form. +// +// 2. is invalid and set to -1 when the approximate output is disabled, i.e. +// top_k = 1 or aggregate_to_topk = true. +absl::StatusOr> ApproxTopKReductionOutputSize( + int64_t input_size, int64_t rank, int64_t top_k, float recall_target, + bool aggregate_to_topk, int64_t input_size_override = -1); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_APPROX_TOPK_SHAPE_H_ diff --git a/third_party/xla/xla/client/lib/arithmetic.cc b/third_party/xla/xla/hlo/builder/lib/arithmetic.cc similarity index 97% rename from third_party/xla/xla/client/lib/arithmetic.cc rename to third_party/xla/xla/hlo/builder/lib/arithmetic.cc index e14bd9118def05..6ec14f7dd31d43 100644 --- a/third_party/xla/xla/client/lib/arithmetic.cc +++ b/third_party/xla/xla/hlo/builder/lib/arithmetic.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/arithmetic.h" +#include "xla/hlo/builder/lib/arithmetic.h" #include #include @@ -23,9 +23,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/primitive_util.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/hlo/builder/lib/arithmetic.h b/third_party/xla/xla/hlo/builder/lib/arithmetic.h new file mode 100644 index 00000000000000..fda730573f37f8 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/arithmetic.h @@ -0,0 +1,90 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_ARITHMETIC_H_ +#define XLA_HLO_BUILDER_LIB_ARITHMETIC_H_ + +#include +#include +#include + +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +using XlaOpGenerator = std::function; + +// Creates a scalar computation based on a lambda and returns it. +XlaComputation CreateScalarComputation(const std::string& name, + PrimitiveType type, XlaBuilder* builder, + XlaOpGenerator generator); + +// Creates a scalar add computation and returns it. +XlaComputation CreateScalarAddComputation(PrimitiveType type, + XlaBuilder* builder); + +// Creates a scalar multiply computation and returns it. +XlaComputation CreateScalarMultiplyComputation(PrimitiveType type, + XlaBuilder* builder); + +// Creates a scalar ge computation and returns it. +XlaComputation CreateScalarGeComputation(PrimitiveType type, + XlaBuilder* builder); + +// Creates a scalar max computation and returns it. +XlaComputation CreateScalarMaxComputation(PrimitiveType type, + XlaBuilder* builder); + +// Creates a scalar min computation and returns it. +XlaComputation CreateScalarMinComputation(PrimitiveType type, + XlaBuilder* builder); + +// Creates a scalar logical AND computation and returns it. +XlaComputation CreateScalarAndComputation(PrimitiveType type, + XlaBuilder* builder); + +// Creates a scalar logical OR computation and returns it. +XlaComputation CreateScalarOrComputation(PrimitiveType type, + XlaBuilder* builder); + +// This is to be used for general purpose "identity" like reductions with zero +// for any type (ie. boolean operations for PRED and Add for real numbers). +// As an example, this operation can be used for a situation of: +// x_type = type(x) +// op = CreateScalarIdentityWithZeroComputation(x_type) +// ASSERT_TRUE(op(x, 0) == x) +// +// This functionality is used for operations that are similar to a slice, +// gather, or broadcast, but are created through a reduction. +XlaComputation CreateScalarIdentityWithZeroComputation(PrimitiveType type, + XlaBuilder* builder); + +// Returns whether any predicate in "predicates" is set. +// +// Note: if predicates is zero-sized, Any() vacuously returns false. +XlaOp Any(XlaOp predicates); + +// Returns the argmax of `input` along `axis`. `output_type` is the type to +// use for the output. In case of ties always prefers smaller index. +XlaOp ArgMax(XlaOp input, PrimitiveType output_type, int axis); + +// Dispatch to ArgMin or ArgMax above, depending on bool. +XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_ARITHMETIC_H_ diff --git a/third_party/xla/xla/client/lib/arithmetic_test.cc b/third_party/xla/xla/hlo/builder/lib/arithmetic_test.cc similarity index 97% rename from third_party/xla/xla/client/lib/arithmetic_test.cc rename to third_party/xla/xla/hlo/builder/lib/arithmetic_test.cc index abbdf06fb8b731..3cde6bf0f4e5c3 100644 --- a/third_party/xla/xla/client/lib/arithmetic_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/arithmetic_test.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/arithmetic.h" +#include "xla/hlo/builder/lib/arithmetic.h" #include #include #include "absl/types/span.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/primitive_util.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/client/lib/broadcast.cc b/third_party/xla/xla/hlo/builder/lib/broadcast.cc similarity index 97% rename from third_party/xla/xla/client/lib/broadcast.cc rename to third_party/xla/xla/hlo/builder/lib/broadcast.cc index 8c3336ef9e312e..aaabe046cebb02 100644 --- a/third_party/xla/xla/client/lib/broadcast.cc +++ b/third_party/xla/xla/hlo/builder/lib/broadcast.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/broadcast.h" +#include "xla/hlo/builder/lib/broadcast.h" #include @@ -21,7 +21,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" diff --git a/third_party/xla/xla/hlo/builder/lib/broadcast.h b/third_party/xla/xla/hlo/builder/lib/broadcast.h new file mode 100644 index 00000000000000..86cf39f64ddc82 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/broadcast.h @@ -0,0 +1,35 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_BROADCAST_H_ +#define XLA_HLO_BUILDER_LIB_BROADCAST_H_ + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/primitive_util.h" +#include "xla/types.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Broadcasts 'input' up to shape 'output_dims', using TensorFlow broadcasting +// rules. Supports broadcasting a dimension of size x to size x*y, i.e., tiling. +absl::StatusOr BroadcastTo(XlaOp input, + absl::Span output_dims); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_BROADCAST_H_ diff --git a/third_party/xla/xla/client/lib/comparators.cc b/third_party/xla/xla/hlo/builder/lib/comparators.cc similarity index 97% rename from third_party/xla/xla/client/lib/comparators.cc rename to third_party/xla/xla/hlo/builder/lib/comparators.cc index 771d5331803d49..fec1874a0373d4 100644 --- a/third_party/xla/xla/client/lib/comparators.cc +++ b/third_party/xla/xla/hlo/builder/lib/comparators.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/comparators.h" +#include "xla/hlo/builder/lib/comparators.h" #include #include @@ -23,8 +23,8 @@ limitations under the License. #include "absl/log/check.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/shape_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/hlo/builder/lib/comparators.h b/third_party/xla/xla/hlo/builder/lib/comparators.h new file mode 100644 index 00000000000000..8dd3e47e07eb48 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/comparators.h @@ -0,0 +1,60 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_COMPARATORS_H_ +#define XLA_HLO_BUILDER_LIB_COMPARATORS_H_ + +#include +#include +#include + +#include "absl/types/span.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Creates a scalar less-than computation and returns it. The created +// computation has 2 * 'operand_types.size()' many parameters, where parameters +// 2 * i and 2 * i + 1 are a scalar with primitive type 'operand_types[i]'. The +// computation compares the first two parameters. For floating point types, a +// total order is created where +// -NaN < -infinity < ... < -0 < 0 < ... < infinity < NaN +XlaComputation CreateScalarLtComputation( + const std::vector& operand_types, XlaBuilder* builder); + +// Creates a scalar greater-than computation and returns it. The created +// computation has 2 * 'operand_types.size()' many parameters, where parameters +// 2 * i and 2 * i + 1 are a scalar with primitive type 'operand_types[i]'. The +// computation compares the first two parameters. For floating point types, a +// total order is created where +// NaN > infinity > ... > 0 > -0 > ... > -infinity > -NaN +XlaComputation CreateScalarGtComputation( + const std::vector& operand_types, XlaBuilder* builder); + +// Creates a scalar comparison computation and returns it. This function takes +// a vector of comparator functions to compare the operands where the function +// isn't nullopt with the specified comparator at that location. +XlaComputation CreateScalarComparisonComputation( + const std::string& name, const std::vector& operand_types, + const std::vector< + std::optional)>>& + generators, + XlaBuilder* builder); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_COMPARATORS_H_ diff --git a/third_party/xla/xla/client/lib/comparators_test.cc b/third_party/xla/xla/hlo/builder/lib/comparators_test.cc similarity index 98% rename from third_party/xla/xla/client/lib/comparators_test.cc rename to third_party/xla/xla/hlo/builder/lib/comparators_test.cc index acaf2f19985276..39bf073171a86b 100644 --- a/third_party/xla/xla/client/lib/comparators_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/comparators_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/comparators.h" +#include "xla/hlo/builder/lib/comparators.h" #include #include @@ -21,9 +21,9 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/strings/string_view.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/primitive_util.h" #include "xla/service/hlo.pb.h" diff --git a/third_party/xla/xla/client/lib/constants.cc b/third_party/xla/xla/hlo/builder/lib/constants.cc similarity index 98% rename from third_party/xla/xla/client/lib/constants.cc rename to third_party/xla/xla/hlo/builder/lib/constants.cc index 1e5a7fae4c9c10..acfa2fe0b66e2c 100644 --- a/third_party/xla/xla/client/lib/constants.cc +++ b/third_party/xla/xla/hlo/builder/lib/constants.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/constants.h" +#include "xla/hlo/builder/lib/constants.h" #include #include "absl/status/statusor.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "xla/primitive_util.h" #include "xla/shape.h" diff --git a/third_party/xla/xla/hlo/builder/lib/constants.h b/third_party/xla/xla/hlo/builder/lib/constants.h new file mode 100644 index 00000000000000..ce695736d1e49c --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/constants.h @@ -0,0 +1,140 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_CONSTANTS_H_ +#define XLA_HLO_BUILDER_LIB_CONSTANTS_H_ + +#include + +#include "absl/status/statusor.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/primitive_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/types.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/ml_dtypes.h" +#include "tsl/platform/statusor.h" + +namespace xla { + +// Returns scalar 'value' as a scalar of 'type'. Unlike ConstantR0, 'type' is +// determined at C++ run-time, rather than C++ compile-time. +// If 'value' is floating point but 'type' is not, or if 'value' is complex but +// 'type' is not, an error will be returned. This is to catch accidental +// truncation; in such cases, use an explicit cast. +template +XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) { + if (std::is_floating_point::value && + !(primitive_util::IsFloatingPointType(type) || + primitive_util::IsComplexType(type))) { + return builder->ReportError(InvalidArgument( + "Invalid cast from floating point type to %s in ConstantR0WithType.", + PrimitiveType_Name(type))); + } + if (std::is_same::value && + !primitive_util::IsComplexType(type)) { + return builder->ReportError(InvalidArgument( + "Invalid cast from complex type to %s in ConstantR0WithType.", + PrimitiveType_Name(type))); + } + return primitive_util::PrimitiveTypeSwitch( + [&](auto primitive_type_constant) -> XlaOp { + if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { + using NativeT = primitive_util::NativeTypeOf; + return ConstantR0(builder, static_cast(value)); + } + return builder->ReportError( + InvalidArgument("Invalid type for ConstantR0WithType (%s).", + PrimitiveType_Name(type))); + }, + type); +} + +// Returns a scalar containing 'value' cast to the same run-time type as +// 'prototype'. +// If 'value' is floating point but 'prototype' is not, or if 'value' is complex +// 'prototype' is not, an error will be returned. +template +XlaOp ScalarLike(XlaOp prototype, T value) { + XlaBuilder* builder = prototype.builder(); + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype)); + return ConstantR0WithType(builder, shape.element_type(), value); + }); +} + +// Returns an array or scalar containing copies of `value` cast to the same +// run-type type as `prototype` and broadcast to the same dimensions as +// `prototype`. +// +// If `prototype` is not a scalar or array, returns an error. +template +XlaOp FullLike(XlaOp prototype, T value) { + XlaBuilder* builder = prototype.builder(); + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype)); + if (ShapeUtil::IsScalar(shape) || shape.IsArray()) { + return Broadcast(ScalarLike(prototype, value), shape.dimensions()); + } else { + return InvalidArgument( + "Prototype shape for BroadcastConstantLike must be a scalar or " + "array, but was %s", + shape.ToString()); + } + }); +} + +// Returns a scalar with value '0' of 'type'. +XlaOp Zero(XlaBuilder* builder, PrimitiveType type); + +// Returns a zero-filled tensor with shape `shape`. +XlaOp Zeros(XlaBuilder* builder, const Shape& shape); + +// Returns a zero-filled tensor with the same shape as `prototype`. +XlaOp ZerosLike(XlaOp prototype); + +// Returns a scalar with value '1' of 'type'. +XlaOp One(XlaBuilder* builder, PrimitiveType type); + +// Returns the machine epsilon for floating-point type `type`, i.e., +// the difference between 1.0 and the next representable value. +XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type); + +// Returns the minimum representable finite or infinite value for 'type'. +// Returns '-inf' for floating-point types. +XlaOp MinValue(XlaBuilder* builder, PrimitiveType type); + +// Returns the minimum representable finite value for 'type'. For a floating +// point type, this is equal to -MaxFiniteValue(). +XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type); + +// Returns the minimum positive normal value for floating-point type `type`. +XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type); + +// Returns the maximum representable finite or infinite value for 'type'. +// Returns 'inf' for floating-point types. +XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type); + +// Returns the maximum representable finite value for 'type'. +XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type); + +// Returns a nan for the given type. Only valid for real-valued fp types. +XlaOp NanValue(XlaBuilder* builder, PrimitiveType type); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_CONSTANTS_H_ diff --git a/third_party/xla/xla/client/lib/constants_test.cc b/third_party/xla/xla/hlo/builder/lib/constants_test.cc similarity index 98% rename from third_party/xla/xla/client/lib/constants_test.cc rename to third_party/xla/xla/hlo/builder/lib/constants_test.cc index 2ae344f2e6cf9e..61aa0ae71dee5b 100644 --- a/third_party/xla/xla/client/lib/constants_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/constants_test.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/constants.h" +#include "xla/hlo/builder/lib/constants.h" #include -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" diff --git a/third_party/xla/xla/client/lib/conv_grad_size_util.cc b/third_party/xla/xla/hlo/builder/lib/conv_grad_size_util.cc similarity index 97% rename from third_party/xla/xla/client/lib/conv_grad_size_util.cc rename to third_party/xla/xla/hlo/builder/lib/conv_grad_size_util.cc index f08328c9086b4f..9bbe184a9d6140 100644 --- a/third_party/xla/xla/client/lib/conv_grad_size_util.cc +++ b/third_party/xla/xla/hlo/builder/lib/conv_grad_size_util.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/conv_grad_size_util.h" +#include "xla/hlo/builder/lib/conv_grad_size_util.h" #include #include "absl/log/log.h" #include "absl/status/statusor.h" -#include "xla/client/padding.h" +#include "xla/hlo/builder/padding.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/hlo/builder/lib/conv_grad_size_util.h b/third_party/xla/xla/hlo/builder/lib/conv_grad_size_util.h new file mode 100644 index 00000000000000..91e43d226c180b --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/conv_grad_size_util.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_CONV_GRAD_SIZE_UTIL_H_ +#define XLA_HLO_BUILDER_LIB_CONV_GRAD_SIZE_UTIL_H_ + +#include "absl/status/statusor.h" +#include "xla/hlo/builder/padding.h" + +namespace xla { + +// Information about a single spatial dimension for a convolution gradients and +// windowed operations. +struct SpatialDimensionOutputSizeAndPadding { + // Effective size of the operation output (potentially expanded). + int64_t output_size; + // Number of padding elements to be added before/after this dimension of + // the input when computing the input gradient. + int64_t pad_before; + int64_t pad_after; +}; + +// Verifies that the dimensions all match, and computes the size and padding of +// a spatial dimension for convolution gradient operations. +absl::StatusOr +ConvGradExtractAndVerifyDimension(int64_t input_size, int64_t filter_size, + int64_t output_size, int64_t dilation, + int64_t stride, Padding padding); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_CONV_GRAD_SIZE_UTIL_H_ diff --git a/third_party/xla/xla/client/lib/dynamic_shaped_ops.cc b/third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.cc similarity index 98% rename from third_party/xla/xla/client/lib/dynamic_shaped_ops.cc rename to third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.cc index c263d31badcdf5..ba82ec343ce55a 100644 --- a/third_party/xla/xla/client/lib/dynamic_shaped_ops.cc +++ b/third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/dynamic_shaped_ops.h" +#include "xla/hlo/builder/lib/dynamic_shaped_ops.h" #include #include @@ -22,10 +22,10 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/client/lib/constants.h" -#include "xla/client/value_inference.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/value_inference.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" diff --git a/third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.h b/third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.h new file mode 100644 index 00000000000000..71188b8fb80a22 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/dynamic_shaped_ops.h @@ -0,0 +1,59 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_DYNAMIC_SHAPED_OPS_H_ +#define XLA_HLO_BUILDER_LIB_DYNAMIC_SHAPED_OPS_H_ + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/value_inference.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/primitive_util.h" +#include "xla/types.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Similar to static shaped conditional, but allows true_computation and +// false_computation to have different dimension sizes (ranks still have to be +// the same). Fall back to static conditional if dynamism is not presented. +XlaOp DynamicConditional(XlaBuilder* builder, XlaOp predicate, + XlaOp true_operand, + const XlaComputation& true_computation, + XlaOp false_operand, + const XlaComputation& false_computation); + +// Similar to DynamicConditional, but support multiple branches. +XlaOp DynamicConditional( + XlaBuilder* builder, XlaOp branch_index, + absl::Span branch_computations, + absl::Span branch_operands); + +// Similar to SetDimensionSize, but automatically adjust the bound of output if +// a tighter one can be inferred by `value_inference`. +absl::StatusOr SetDimensionSizeWithRebound( + ValueInference* value_inference, XlaOp operand, XlaOp dimension_size, + int64_t dimension); + +// Take a `operand` tensor and a R1 tensor `size_vector` representing the sizes +// of `operand`, Call SetDimensionSize if for each dimension whose size is +// dynamic. +absl::StatusOr SetAllDimensionSizes(ValueInference* value_inference, + XlaOp operand, XlaOp size_vector); +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_DYNAMIC_SHAPED_OPS_H_ diff --git a/third_party/xla/xla/client/lib/generate_math_impl.py b/third_party/xla/xla/hlo/builder/lib/generate_math_impl.py similarity index 100% rename from third_party/xla/xla/client/lib/generate_math_impl.py rename to third_party/xla/xla/hlo/builder/lib/generate_math_impl.py diff --git a/third_party/xla/xla/client/lib/logdet.cc b/third_party/xla/xla/hlo/builder/lib/logdet.cc similarity index 90% rename from third_party/xla/xla/client/lib/logdet.cc rename to third_party/xla/xla/hlo/builder/lib/logdet.cc index 96063a5c72431f..cc17d0ec26ffe6 100644 --- a/third_party/xla/xla/client/lib/logdet.cc +++ b/third_party/xla/xla/hlo/builder/lib/logdet.cc @@ -13,19 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/logdet.h" +#include "xla/hlo/builder/lib/logdet.h" #include #include #include #include "absl/status/statusor.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/lib/qr.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/lib/qr.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" diff --git a/third_party/xla/xla/client/lib/logdet.h b/third_party/xla/xla/hlo/builder/lib/logdet.h similarity index 87% rename from third_party/xla/xla/client/lib/logdet.h rename to third_party/xla/xla/hlo/builder/lib/logdet.h index ee3d984fa69319..8c02d72de9940b 100644 --- a/third_party/xla/xla/client/lib/logdet.h +++ b/third_party/xla/xla/hlo/builder/lib/logdet.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_CLIENT_LIB_LOGDET_H_ -#define XLA_CLIENT_LIB_LOGDET_H_ +#ifndef XLA_HLO_BUILDER_LIB_LOGDET_H_ +#define XLA_HLO_BUILDER_LIB_LOGDET_H_ -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" namespace xla { @@ -34,4 +34,4 @@ XlaOp LogDet(XlaOp a); } // namespace xla -#endif // XLA_CLIENT_LIB_LOGDET_H_ +#endif // XLA_HLO_BUILDER_LIB_LOGDET_H_ diff --git a/third_party/xla/xla/client/lib/logdet_test.cc b/third_party/xla/xla/hlo/builder/lib/logdet_test.cc similarity index 98% rename from third_party/xla/xla/client/lib/logdet_test.cc rename to third_party/xla/xla/hlo/builder/lib/logdet_test.cc index b2600ed7f7ea23..8618aab2aa833d 100644 --- a/third_party/xla/xla/client/lib/logdet_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/logdet_test.cc @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/logdet.h" +#include "xla/hlo/builder/lib/logdet.h" #include #include "xla/array.h" #include "xla/array2d.h" #include "xla/array3d.h" -#include "xla/client/xla_builder.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/tests/client_library_test_base.h" diff --git a/third_party/xla/xla/client/lib/loops.cc b/third_party/xla/xla/hlo/builder/lib/loops.cc similarity index 97% rename from third_party/xla/xla/client/lib/loops.cc rename to third_party/xla/xla/hlo/builder/lib/loops.cc index 5785e9969dee8f..e7dbad01163d93 100644 --- a/third_party/xla/xla/client/lib/loops.cc +++ b/third_party/xla/xla/hlo/builder/lib/loops.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/loops.h" +#include "xla/hlo/builder/lib/loops.h" #include #include @@ -23,8 +23,8 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/hlo/builder/lib/loops.h b/third_party/xla/xla/hlo/builder/lib/loops.h new file mode 100644 index 00000000000000..540ab784f34684 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/loops.h @@ -0,0 +1,74 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_LOOPS_H_ +#define XLA_HLO_BUILDER_LIB_LOOPS_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Function that builds a loop condition. Takes as input a sequence of input +// values, and returns a boolean value representing if the condition succeeds. +typedef std::function(absl::Span, + XlaBuilder*)> + WhileLoopHelperConditionFunction; + +// Function that builds a loop body. Takes as input a sequence of input values +// and returns a sequence of output values. +typedef std::function>( + absl::Span, XlaBuilder*)> + WhileLoopHelperBodyFunction; + +// Helper function for building an XLA while loop, where the values carried by +// the loop are a tuple of values, e.g., (a, b, c): +// while( +// condition: (a, b, c) -> bool, +// body: (a, b, c) -> (a, b, c) +// init: (a, b, c) +// ) +// 'name' is a descriptive name for the loop. +absl::StatusOr> WhileLoopHelper( + const WhileLoopHelperConditionFunction& condition_function, + const WhileLoopHelperBodyFunction& body_function, + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder); + +// Builds an XLA loop that repeats a computation `num_iterations` times. +// +// The body function (ForEachIndexBodyFunction) takes as input a pair of +// (current iteration number, loop-carried values), and returns an updated +// vector of the loop-carried values. +typedef std::function>( + XlaOp, absl::Span, XlaBuilder*)> + ForEachIndexBodyFunction; + +absl::StatusOr> ForEachIndex( + int64_t num_iterations, PrimitiveType num_iterations_type, + const ForEachIndexBodyFunction& body_function, + absl::Span initial_values, absl::string_view name, + XlaBuilder* builder); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_LOOPS_H_ diff --git a/third_party/xla/xla/client/lib/lu_decomposition.cc b/third_party/xla/xla/hlo/builder/lib/lu_decomposition.cc similarity index 96% rename from third_party/xla/xla/client/lib/lu_decomposition.cc rename to third_party/xla/xla/hlo/builder/lib/lu_decomposition.cc index b4f00876ce36a8..78e9c00e07ca1a 100644 --- a/third_party/xla/xla/client/lib/lu_decomposition.cc +++ b/third_party/xla/xla/hlo/builder/lib/lu_decomposition.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/lu_decomposition.h" +#include "xla/hlo/builder/lib/lu_decomposition.h" #include #include #include "absl/status/statusor.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" diff --git a/third_party/xla/xla/hlo/builder/lib/lu_decomposition.h b/third_party/xla/xla/hlo/builder/lib/lu_decomposition.h new file mode 100644 index 00000000000000..d233dab04f50e2 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/lu_decomposition.h @@ -0,0 +1,61 @@ +/* Copyright 2020 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_LU_DECOMPOSITION_H_ +#define XLA_HLO_BUILDER_LIB_LU_DECOMPOSITION_H_ + +#include "xla/hlo/builder/xla_builder.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Computes the LU decomposition with partial pivoting of a batch of matrices. +// +// Given a (batched) matrix a with shape [..., m, n], computes the matrix +// decomposition A = P @ L @ U where P is a permutation matrix, L is a +// lower-triangular matrix with unit diagonal entries, and U is an +// upper-triangular matrix. +// +// L and U are returned as a single matrix [..., m, n] containing both L and U +// packed in the same array. The unit diagonal of L is not represented +// explicitly. +// +// The permutation matrix P is returned in two forms, both as `pivots`, which is +// an s32[..., min(m, n)] array that describes a sequence of row-swaps in the +// style of LAPACK's xGETRF API, and `permutation`, which is a s32[..., m] array +// which gives the permutation to apply to the rows. We return both +// representations because they are each useful for different purposes; `pivots` +// is useful for computing the sign of a determinant, whereas `permutation` can +// be used via a Gather operation to permute the rows of a matrix. +// +// This method is only implemented on TPU at the moment. +// TODO(b/168208200): the implementation only supports F32 arrays. Handle the +// complex case. +struct LuDecompositionResult { + // The LU decomposition, with both L and U packed into an array with shape + // [..., m, n]. + XlaOp lu; + // An array of shape s32[..., min(m, n)] containing the pivot rows. + XlaOp pivots; + // An array of shape s32[..., m], containing an another representation of the + // pivots as a permutation. + XlaOp permutation; +}; + +LuDecompositionResult LuDecomposition(XlaOp a); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_LU_DECOMPOSITION_H_ diff --git a/third_party/xla/xla/client/lib/math.cc b/third_party/xla/xla/hlo/builder/lib/math.cc similarity index 94% rename from third_party/xla/xla/client/lib/math.cc rename to third_party/xla/xla/hlo/builder/lib/math.cc index c3f27638bdd2e9..aabb14fdc8cda0 100644 --- a/third_party/xla/xla/client/lib/math.cc +++ b/third_party/xla/xla/hlo/builder/lib/math.cc @@ -13,11 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/math.h" +#include "xla/hlo/builder/lib/math.h" #include #include #include +#include #include #include #include @@ -28,11 +29,11 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/loops.h" -#include "xla/client/lib/math_impl.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/loops.h" +#include "xla/hlo/builder/lib/math_impl.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/primitive_util.h" #include "xla/shape.h" #include "xla/status_macros.h" @@ -122,6 +123,10 @@ static absl::Status EnsureOperandIsRealFp(absl::string_view op_name, return absl::OkStatus(); } +XlaOp PredFalse(XlaBuilder* builder, const Shape& shape) { + return Broadcast(ConstantR0(builder, false), shape.dimensions()); +} + XlaOp IsPosInf(XlaOp operand) { auto& b = *operand.builder(); return b.ReportErrorOrReturn([&]() -> absl::StatusOr { @@ -129,7 +134,9 @@ XlaOp IsPosInf(XlaOp operand) { TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand)); // Note that this is only correct for floating-point types. If we wanted it // to be correct for all types, we'd need to Gt(MaxFiniteValue). - return Eq(operand, MaxValue(&b, shape.element_type())); + return primitive_util::HasInfinity(shape.element_type()) + ? Eq(operand, MaxValue(&b, shape.element_type())) + : PredFalse(&b, shape); }); } @@ -140,7 +147,9 @@ XlaOp IsNegInf(XlaOp operand) { TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand)); // Note that this is only correct for floating-point types. If we wanted it // to be correct for all types, we'd need to Lt(MinFiniteValue). - return Eq(operand, MinValue(&b, shape.element_type())); + return primitive_util::HasInfinity(shape.element_type()) + ? Eq(operand, MinValue(&b, shape.element_type())) + : PredFalse(&b, shape); }); } @@ -175,11 +184,10 @@ XlaOp IsNegZero(XlaOp operand) { case F32: return Eq(BitcastConvertType(operand, U32), ConstantR0WithType(&b, U32, uint32_t{1} << 31)); + case F8E3M4: + case F8E4M3: case F8E5M2: case F8E4M3FN: - case F8E4M3B11FNUZ: - case F8E5M2FNUZ: - case F8E4M3FNUZ: case F16: case BF16: // Not all XLA backends handle U16 well, so we convert to F32/U32. @@ -187,6 +195,12 @@ XlaOp IsNegZero(XlaOp operand) { // backends that *do* support it. return Eq(BitcastConvertType(ConvertElementType(operand, F32), U32), ConstantR0WithType(&b, U32, uint32_t{1} << 31)); + case F8E4M3B11FNUZ: + case F8E5M2FNUZ: + case F8E4M3FNUZ: { + // FP8 types with no unsigned zero representation. + return PredFalse(&b, shape); + } default: LOG(FATAL) << "Expected real fp type."; } @@ -199,51 +213,59 @@ XlaOp Reciprocal(XlaOp operand) { return ScalarLike(operand, 1.0) / operand; } // Computes an approximation of the error function complement (1 - erf(x)). // -// Precondition: abs(x) >= 1. Otherwise, use ErfImpl. +// Precondition: abs(x) >= 1. Otherwise, use ErfcSmallImpl32. // -// This follows Cephes's f32 implementation of erfc. -static XlaOp ErfcImpl32(XlaOp x) { - // Coefficients for erfc(f32), from Cephes. - const double kMaxlog = 88.72283905206835; - // erfc(x) = exp(-x^2) P(1/x^2), 1 < x < 2 - static const std::array kErfcPCoefficient{ - +2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1, - -5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1, - +3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1, - }; - // erfc(x) = exp(-x^2) R(1/x^2), 2 <= x < kMaxlog - static const std::array kErfcRCoefficient{ - -1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0, - +2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1, - -2.820767439740514E-1, +5.641895067754075E-1, - }; - XlaOp abs_x = Abs(x); - XlaOp z = Exp(-x * x); - XlaOp q = ScalarLike(x, 1) / abs_x; - XlaOp y = q * q; - XlaOp p = Select(Lt(abs_x, ScalarLike(x, 2.0)), - EvaluatePolynomial(y, kErfcPCoefficient), - EvaluatePolynomial(y, kErfcRCoefficient)); - y = z * q * p; - XlaOp y_clamp = Select(Lt(z, ScalarLike(x, -kMaxlog)), ScalarLike(x, 0), y); - return Select(Lt(x, ScalarLike(x, 0)), ScalarLike(x, 2.0) - y_clamp, y_clamp); +// This follows Eigen's f32 implementation of erfc. +static XlaOp ErfcLargeImpl32(XlaOp x) { + // Take absolute value and clamp at x=10.06 where erfc(x) is less than the + // underflow boundary for float32. + XlaOp abs_x = Min(Abs(x), ScalarLike(x, 10.06)); + + // erfc(x) = exp(-x^2) * 1/x * P(1/x^2) / Q(1/x^2), 1 < x < 10.06. + // + // Coefficients for P and Q were generated with Rminimax command: + // ./ratapprox --function="erfc(1/sqrt(x))*exp(1/x)/sqrt(x)" + // --dom='[0.01,1]' --type=[3,4] --numF="[SG]" --denF="[SG]" --log + // --dispCoeff="dec" + static const std::array kErfcGammaCoefficient{ + 1.0208116471767425537109375e-01f, 4.2920666933059692382812500e-01f, + 3.2379078865051269531250000e-01f, 5.3971976041793823242187500e-02f}; + static const std::array kErfcDeltaCoefficient{ + 1.7251677811145782470703125e-02f, 3.9137163758277893066406250e-01f, + 1.0000000000000000000000000e+00f, 6.2173241376876831054687500e-01f, + 9.5662862062454223632812500e-02f}; + + XlaOp x2 = x * x; + XlaOp z = Exp(-x2); + XlaOp q2 = Reciprocal(x2); + XlaOp num = EvaluatePolynomial(q2, kErfcGammaCoefficient); + XlaOp denom = abs_x * EvaluatePolynomial(q2, kErfcDeltaCoefficient); + XlaOp r = num / denom; + XlaOp erfc_abs_x = z * r; + return Select(Lt(x, ScalarLike(x, 0)), ScalarLike(x, 2.0f) - erfc_abs_x, + erfc_abs_x); } -// Compute a polynomial approximation of the error function. +// Compute a polynomial approximation of the complementary error function +// for abs(x) <= 1. // -// Precondition: abs(x) <= 1. Otherwise, use ErfcImpl. +// Precondition: abs(x) <= 1. Otherwise, use ErfcLargeImpl32. // -// This follows Cephes's f32 implementation of erf. -static XlaOp ErfImpl32Cephes(XlaOp x) { - // Coefficients for by erf(f32), from Cephes. +static XlaOp ErfcSmallImpl32(XlaOp x) { + // erfc(x) = x * P(x^2) + 1, |x| <= 1 // - // erf(x) = x P(x^2), 0 < x < 1 - static const std::array kErfTCoefficient{ - +7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3, - -2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1, - +1.128379165726710E+0, - }; - return x * EvaluatePolynomial(x * x, kErfTCoefficient); + // Coefficients were generated with Rminimax command: + // ./ratapprox --function="erfc(x)-1" --dom='[-1,1]' --type=[11,0] --num="odd" + // --numF="[SG]" --denF="[SG]" --log --dispCoeff="dec" + static const std::array kErfcSmallCoefficient{ + +5.61802298761904239654541015625e-04f, + -4.91381669417023658752441406250e-03f, + +2.67075151205062866210937500000e-02f, + -1.12800106406211853027343750000e-01f, + +3.76122951507568359375000000000e-01f, + -1.12837910652160644531250000000e+00f}; + return x * EvaluatePolynomial(x * x, kErfcSmallCoefficient) + + ScalarLike(x, 1.0f); } static XlaOp ErfcImpl64(XlaOp x) { @@ -321,36 +343,12 @@ XlaOp Erfc(XlaOp x) { // Erf(c)Impl don't have enough precision when run with bf16 intermediates // (not surprising!), so upcast to f32 in this case. return DoWithUpcastToF32(x, {}, [](XlaOp x) { - return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl32(x), - ScalarLike(x, 1) - ErfImpl32Cephes(x)); + return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcLargeImpl32(x), + ErfcSmallImpl32(x)); }); }); } -// Compute a rational approximation of the error function. -static XlaOp ErfImpl32(XlaOp x) { - static const std::array kAlpha{ - 0.00022905065861350646f, 0.0034082910107109506f, 0.050955695062380861f, - 0.18520832239976145f, 1.128379143519084f}; - - static const std::array kBeta{-1.1791602954361697e-7, - 0.000023547966471313185f, - 0.0010179625278914885f, - 0.014070470171167667f, - 0.11098505178285362f, - 0.49746925110067538f, - 1.0f}; - - // We clamp x to be within [-c;c] where c = erfinv(1-2^-23), outside of - // which x should be +/-1. - constexpr float kErfInvOneMinusHalfULP = 3.7439211627767994f; - x = Clamp(ScalarLike(x, -kErfInvOneMinusHalfULP), x, - ScalarLike(x, kErfInvOneMinusHalfULP)); - auto x2 = x * x; - return (x * EvaluatePolynomial(x2, kAlpha)) / - EvaluatePolynomial(x2, kBeta); -} - namespace { // Approximation for the inverse error function from @@ -973,8 +971,8 @@ XlaOp Igamma(XlaOp a, XlaOp x) { TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Igamma", a)); PrimitiveType a_x_type = a_shape.element_type(); bool needs_upcast = false; - for (PrimitiveType type : - {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { + for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN, + F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { if (a_shape.element_type() == type) { needs_upcast = true; break; @@ -1026,8 +1024,8 @@ XlaOp IgammaGradA(XlaOp a, XlaOp x) { } TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IgammaGradA", a)); bool needs_upcast = false; - for (PrimitiveType type : - {BF16, F16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { + for (PrimitiveType type : {BF16, F16, F8E3M4, F8E4M3, F8E5M2, F8E4M3FN, + F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { if (a_shape.element_type() == type) { needs_upcast = true; break; @@ -1485,6 +1483,23 @@ XlaOp NextAfter(XlaOp from, XlaOp to) { Broadcast(ScalarLike(from_as_int, -1), shape.dimensions()), Broadcast(ScalarLike(from_as_int, 1), shape.dimensions())); auto result = Add(from_as_int, magnitude_adjustment); + + if (shape.element_type() == F8E5M2FNUZ || + shape.element_type() == F8E4M3FNUZ || + shape.element_type() == F8E4M3B11FNUZ) { + // Handle 'from' is the negative value closest to zero and 'to' is + // positive. For FNUZ dtypes, the result is +0 instead of -0 since -0 + // represents a NaN value. + const int64_t least_negative = sign_mask | 1; + auto to_is_nonnegative = Not(ConvertElementType(to_sign, PRED)); + auto predicate = + And(Eq(from_as_int, ScalarLike(from_as_int, least_negative)), + to_is_nonnegative); + auto result_if_predicate = + Broadcast(ScalarLike(from_as_int, 0), shape.dimensions()); + result = Select(predicate, result_if_predicate, result); + } + // Handle from == ±0. result = Select(from_is_zero, Select(to_is_zero, result_for_both_zero, diff --git a/third_party/xla/xla/hlo/builder/lib/math.h b/third_party/xla/xla/hlo/builder/lib/math.h new file mode 100644 index 00000000000000..6c26ec20410c64 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/math.h @@ -0,0 +1,127 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_MATH_H_ +#define XLA_HLO_BUILDER_LIB_MATH_H_ + +#include "xla/hlo/builder/xla_builder.h" + +namespace xla { + +// Determines whether operand is +/-inf or nan. +// +// Raises an error if called on integral or complex values. +XlaOp IsPosInf(XlaOp operand); +XlaOp IsNegInf(XlaOp operand); +XlaOp IsInf(XlaOp operand); +XlaOp IsNan(XlaOp operand); + +// Determines whether operand is equal to -0. +// +// Raises an error for integral or complex values. +XlaOp IsNegZero(XlaOp operand); + +// Returns the next number after 'from' in the direction of 'to' the same way +// std::nextafter(from, to) would. +XlaOp NextAfter(XlaOp from, XlaOp to); + +// Computes the square of 'operand'. +XlaOp Square(XlaOp operand); + +// Computes the reciprocal of 'operand'. +XlaOp Reciprocal(XlaOp operand); + +// Computes an approximation of the error function complement (1 - erf(x)). +XlaOp Erfc(XlaOp x); + +// Computes an approximation of the inverse of the error function. +XlaOp ErfInv(XlaOp x); + +// Computes an approximation of the lgamma function. +XlaOp Lgamma(XlaOp input); + +// Computes an approximation of the digamma function. +XlaOp Digamma(XlaOp input); + +// Computes an approximation of the incomplete gamma function. +XlaOp Igamma(XlaOp a, XlaOp x); + +// Computes an approximation of the derivative of the incomplete gamma function +// with respect to a. +XlaOp IgammaGradA(XlaOp a, XlaOp x); + +// Computes an approximation of the derivative of a sample `x` from a `Gamma(a, +// 1)` distribution with respect to a. +XlaOp RandomGammaGrad(XlaOp a, XlaOp x); + +// Computes an approximation of the complementary incomplete gamma function. +XlaOp Igammac(XlaOp a, XlaOp x); + +// Computes the Polygamma of two arguments. +XlaOp Polygamma(XlaOp n, XlaOp x); + +// Computes the Riemann zeta function of two arguments. +XlaOp Zeta(XlaOp x, XlaOp q); + +// Rounds the given number to even when the number is equidistant between two +// integers. +XlaOp RoundToEven(XlaOp x); + +// Trigonometric functions + +// Computes the arc cosine of 'x'. +XlaOp Acos(XlaOp x); + +// Computes the arc sine of 'x'. +XlaOp Asin(XlaOp x); + +// Computes the arc tangent of 'x'. +XlaOp Atan(XlaOp x); + +// Hyperbolic trigonometric functions + +// Computes the inverse hyperbolic cosine of 'x'. +XlaOp Acosh(XlaOp x); + +// Computes the inverse hyperbolic sine of 'x'. +XlaOp Asinh(XlaOp x); + +// Computes the inverse hyperbolic tangent of 'x'. +XlaOp Atanh(XlaOp x); + +// Computes the hyperbolic cosine of 'x'. +XlaOp Cosh(XlaOp x); + +// Computes the hyperbolic sine of 'x'. +XlaOp Sinh(XlaOp x); + +// Applies a complex conjugation operation if 'a' is complex and 'conjugate' +// is true, otherwise returns its argument. +xla::XlaOp MaybeConjugate(xla::XlaOp x, bool conjugate); + +// Computes the Modified Bessel function of the first kind of the zeroth order +// at x. +XlaOp BesselI0e(XlaOp x); + +// Computes the Modified Bessel function of the first kind of the first order +// at x. +XlaOp BesselI1e(XlaOp x); + +// Computes the Regularized Incomplete Beta function. +XlaOp RegularizedIncompleteBeta(XlaOp a, XlaOp b, XlaOp x); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_MATH_H_ diff --git a/third_party/xla/xla/client/lib/math_impl.h b/third_party/xla/xla/hlo/builder/lib/math_impl.h similarity index 97% rename from third_party/xla/xla/client/lib/math_impl.h rename to third_party/xla/xla/hlo/builder/lib/math_impl.h index f89851ad9366c2..262856d08c712f 100644 --- a/third_party/xla/xla/client/lib/math_impl.h +++ b/third_party/xla/xla/hlo/builder/lib/math_impl.h @@ -17,12 +17,12 @@ limitations under the License. // https://github.com/pearu/functional_algorithms // for more information. -#ifndef XLA_CLIENT_LIB_MATH_IMPL_H_ -#define XLA_CLIENT_LIB_MATH_IMPL_H_ +#ifndef XLA_HLO_BUILDER_LIB_MATH_IMPL_H_ +#define XLA_HLO_BUILDER_LIB_MATH_IMPL_H_ -#include "xla/client/lib/constants.h" -#include "xla/client/lib/math.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/xla_builder.h" namespace xla { namespace math_impl { @@ -256,4 +256,4 @@ XlaOp AsinReal(XlaOp x) { } // namespace math_impl } // namespace xla -#endif // XLA_CLIENT_LIB_MATH_IMPL_H_ +#endif // XLA_HLO_BUILDER_LIB_MATH_IMPL_H_ diff --git a/third_party/xla/xla/client/lib/math_test.cc b/third_party/xla/xla/hlo/builder/lib/math_test.cc similarity index 95% rename from third_party/xla/xla/client/lib/math_test.cc rename to third_party/xla/xla/hlo/builder/lib/math_test.cc index 0c5776f4bea333..30eaf4b503de62 100644 --- a/third_party/xla/xla/client/lib/math_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/math_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/math.h" +#include "xla/hlo/builder/lib/math.h" #include #include @@ -26,9 +26,9 @@ limitations under the License. #include #include "xla/array3d.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/primitive_util.h" @@ -83,8 +83,8 @@ class MathTypedTest : public MathTest { auto x = ConstantR1(&b, { T{0}, - T{100}, - T{-1000}, + T{8}, + T{-8}, T{std::numeric_limits::max()}, T{std::numeric_limits::lowest()}, T{std::numeric_limits::infinity()}, @@ -94,17 +94,18 @@ class MathTypedTest : public MathTest { }); Tuple(&b, {IsFinite(x), IsInf(x), IsPosInf(x), IsNegInf(x), IsNan(x)}); + bool has_inf = std::numeric_limits::has_infinity; auto expected = LiteralUtil::MakeTupleOwned( LiteralUtil::CreateR1( {true, true, true, true, true, false, false, false, false}), + LiteralUtil::CreateR1({false, false, false, false, false, has_inf, + has_inf, false, false}), LiteralUtil::CreateR1( - {false, false, false, false, false, true, true, false, false}), + {false, false, false, false, false, has_inf, false, false, false}), LiteralUtil::CreateR1( - {false, false, false, false, false, true, false, false, false}), - LiteralUtil::CreateR1( - {false, false, false, false, false, false, true, false, false}), - LiteralUtil::CreateR1( - {false, false, false, false, false, false, false, true, true})); + {false, false, false, false, false, false, has_inf, false, false}), + LiteralUtil::CreateR1({false, false, false, false, false, + !has_inf, !has_inf, true, true})); ComputeAndCompareLiteral(&b, expected, {}); } @@ -120,7 +121,7 @@ class MathTypedTest : public MathTest { ComputeAndCompareLiteral( &b, LiteralUtil::CreateR1( - {true, false, false, false, false, false, false}), + {has_negative_zero_v, false, false, false, false, false, false}), {}, error_spec_); } @@ -171,6 +172,7 @@ class MathTypedTest : public MathTest { SetFastMathDisabled(true); const T kErfInvOneMinusHalfULP = T(3.832506856900711); const T inf(std::numeric_limits::infinity()); + const T nan(std::numeric_limits::quiet_NaN()); XlaBuilder b(TestName()); auto x = AddParam(LiteralUtil::CreateR1({T{-inf}, T{inf}, T{-0}, T{0}, @@ -179,23 +181,29 @@ class MathTypedTest : public MathTest { &b); Erf(x); - std::vector expected = {T(-1), T(1), T(-0), T(0), T(-1), T(1)}; + bool has_inf = std::numeric_limits::has_infinity; + std::vector expected = { + has_inf ? T(-1) : nan, has_inf ? T(1) : nan, T(-0), T(0), T(-1), T(1)}; ComputeAndCompareR1(&b, expected, {}, error_spec_); } }; // TODO(b/123355973): Add bfloat16 to TestTypes once it's working. -using TestTypes = ::testing::Types; + float>; TYPED_TEST_CASE(MathTypedTest, TestTypes); diff --git a/third_party/xla/xla/client/lib/matrix.cc b/third_party/xla/xla/hlo/builder/lib/matrix.cc similarity index 99% rename from third_party/xla/xla/client/lib/matrix.cc rename to third_party/xla/xla/hlo/builder/lib/matrix.cc index 38a1a67efde17f..7c189b762a49f3 100644 --- a/third_party/xla/xla/client/lib/matrix.cc +++ b/third_party/xla/xla/hlo/builder/lib/matrix.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/matrix.h" +#include "xla/hlo/builder/lib/matrix.h" #include #include @@ -36,10 +36,10 @@ limitations under the License. #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/primitive_util.h" #include "xla/shape.h" diff --git a/third_party/xla/xla/hlo/builder/lib/matrix.h b/third_party/xla/xla/hlo/builder/lib/matrix.h new file mode 100644 index 00000000000000..8fdf01d438d7a1 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/matrix.h @@ -0,0 +1,159 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_MATRIX_H_ +#define XLA_HLO_BUILDER_LIB_MATRIX_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/types.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Returns an m x n matrix with 1s on the diagonal elements, zeros everywhere +// else. +XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64_t m, + int64_t n); + +// Returns a mask where the 'diagonal'-th diagonal is true and everything else +// is false. +XlaOp GetDiagonalMask(XlaOp x, int diagonal = 0); + +// Get the diagonals of the last two dimensions. Use k>0 for diagonals above the +// main diagonal, and k<0 for diagonals below the main diagonal. +// +// If 'x' has shape [..., M, N] +// If k >= 0: then the output has shape [..., min(M, N - k)], containing the +// diagonal elements (i.e., with indices [..., i, i + k]). +// If k < 0: then the output has shape [..., min(M + k, N)], containing the +// diagonal elements (i.e., with indices [..., i - k, i]). +XlaOp GetMatrixDiagonal(XlaOp x, int k = 0); +XlaOp GetMatrixDiagonalViaGather(XlaOp x, int k = 0); + +// Places diag along the kth diagonal of target. +XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k = 0); + +// Returns a lower-triangular mask, i.e., true below and including the +// `diagonal`-th diagonal and false above that diagonal. +XlaOp TriangleMask(XlaOp x, int diagonal); + +// Get the upper or lower triangle part of the last two dimensions +XlaOp Triangle(XlaOp x, bool lower); + +// Get the upper triangle part of the last two dimensions +XlaOp UpperTriangle(XlaOp x); + +// Get the lower triangle part of the last two dimensions +XlaOp LowerTriangle(XlaOp x); + +// If x is an array of shape [..., n, n], symmetrizes the matrix by replacing +// the upper triangle with the transpose of the lower triangle (if lower is +// True, vice-versa otherwise). If the type of `x` is complex, makes the matrix +// Hermitian by taking the conjugate of the complex part and setting the +// complex diagonal to zero. +XlaOp Symmetrize(XlaOp x, bool lower); + +// Multiplies slices of two tensors in batches. + +// Multiplies all slices of `Tensor` `x` and `y` (each slice can be +// viewed as an element of a batch), and arranges the individual results +// in a single output tensor of the same batch size. +// +// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` +// and `[..., r_y, c_y]`. +// +// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: +// +// r_o = c_x if transpose_x else r_x +// c_o = r_y if transpose_y else c_y +// +// It is computed as: +// +// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) +xla::XlaOp BatchDot( + xla::XlaOp x, xla::XlaOp y, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT, + std::optional preferred_element_type = std::nullopt); +xla::XlaOp BatchDot( + xla::XlaOp x, bool transpose_x, xla::XlaOp y, bool transpose_y, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT, + std::optional preferred_element_type = std::nullopt, + bool grad_x = false, bool grad_y = false); + +// Parse an einsum string into dimension numbers: +// "ab,cb->ac" +// becomes: +// {{0, 1},{2, 1},{0, 2}} +// +// Each occurrence of ellipsis ("...") occurring in the input is replaced with +// the same numeric dimensions. The number of such dimensions is inferred from +// x_rank and y_rank. For example: +// einsum_config: "...ab,...bcd->...acd" +// x_rank: 4 +// y_rank: 5 +// becomes: +// {{0, 1, 2, 3},{0, 1, 3, 4, 5},{0, 1, 2, 4, 5}} +// +// NOTE: This function is meant for testing, there is no need to call it +// directly. + +absl::StatusOr, 3>> ParseEinsumString( + absl::string_view einsum_config, int64_t x_rank, int64_t y_rank); + +// If an einsum config does not contain an -> one will be added and the output +// config will be the sorted characters with any ellipsis at the beginning. +// Returns an empty string if the einsum string already has an ->. +std::string NormalizeEinsumString(absl::string_view einsum_config); + +// Supports two operand einsum notation like "ab,cb->ac". +xla::XlaOp Einsum( + xla::XlaOp x, xla::XlaOp y, absl::string_view einsum_config, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT, + std::optional preferred_element_type = std::nullopt, + bool grad_x = false, bool grad_y = false); +xla::XlaOp Einsum( + xla::XlaOp x, absl::string_view einsum_config, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT); + +// Same as above but supporting numeric labels on dimensions. So "ab,cb->ac" +// becomes: +// x_config = {0, 1} +// y_config = {2, 1} +// output_config = {0, 2} +xla::XlaOp Einsum( + xla::XlaOp x, absl::Span x_config, xla::XlaOp y, + absl::Span y_config, absl::Span output_config, + xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::DEFAULT, + std::optional preferred_element_type = std::nullopt, + bool grad_x = false, bool grad_y = false); + +// Transposes a stack of matrices `x` by swapping the last two dimensions. +xla::XlaOp TransposeInMinorDims(xla::XlaOp x); + +// Transposes `x` in its minor dimensions if `transpose` is true, otherwise +// returns `x` unchanged. +xla::XlaOp MaybeTransposeInMinorDims(xla::XlaOp x, bool transpose); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_MATRIX_H_ diff --git a/third_party/xla/xla/client/lib/matrix_test.cc b/third_party/xla/xla/hlo/builder/lib/matrix_test.cc similarity index 98% rename from third_party/xla/xla/client/lib/matrix_test.cc rename to third_party/xla/xla/hlo/builder/lib/matrix_test.cc index caa313b4ab8923..debb6e20ae0108 100644 --- a/third_party/xla/xla/client/lib/matrix_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/matrix_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/matrix.h" +#include "xla/hlo/builder/lib/matrix.h" #include #include @@ -28,9 +28,9 @@ limitations under the License. #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/array4d.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/client/lib/pooling.cc b/third_party/xla/xla/hlo/builder/lib/pooling.cc similarity index 98% rename from third_party/xla/xla/client/lib/pooling.cc rename to third_party/xla/xla/hlo/builder/lib/pooling.cc index 5f03ad45afb0fd..81dd1a7c4c0f95 100644 --- a/third_party/xla/xla/client/lib/pooling.cc +++ b/third_party/xla/xla/hlo/builder/lib/pooling.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/pooling.h" +#include "xla/hlo/builder/lib/pooling.h" #include #include @@ -22,11 +22,11 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/conv_grad_size_util.h" -#include "xla/client/padding.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/conv_grad_size_util.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/hlo/builder/lib/pooling.h b/third_party/xla/xla/hlo/builder/lib/pooling.h new file mode 100644 index 00000000000000..15176888939c04 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/pooling.h @@ -0,0 +1,83 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_POOLING_H_ +#define XLA_HLO_BUILDER_LIB_POOLING_H_ + +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/xla_builder.h" + +namespace xla { + +// Tensor format for reduce window operations. +class TensorFormat { + public: + TensorFormat(int batch_dimension, int feature_dimension, + absl::Span spatial_dimensions) + : batch_dimension_(batch_dimension), + feature_dimension_(feature_dimension), + spatial_dimensions_(spatial_dimensions.begin(), + spatial_dimensions.end()) {} + + int batch_dimension() const { return batch_dimension_; } + + int feature_dimension() const { return feature_dimension_; } + + int spatial_dimension(int dim) const { return spatial_dimensions_[dim]; } + + int num_spatial_dims() const { return spatial_dimensions_.size(); } + + private: + // The number of the dimension that represents the batch. + int batch_dimension_; + // The number of the dimension that represents the features. + int feature_dimension_; + // The dimension numbers for the spatial dimensions. + absl::InlinedVector spatial_dimensions_; +}; + +// Computes the max pool of 'operand'. +XlaOp MaxPool(XlaOp operand, absl::Span kernel_size, + absl::Span stride, Padding padding, + const TensorFormat& data_format); + +// Computes the average pool of 'operand'. +XlaOp AvgPool(XlaOp operand, absl::Span kernel_size, + absl::Span stride, + absl::Span> padding, + const TensorFormat& data_format, bool counts_include_padding); + +// Returns the list of low and high padding elements in each spatial dimension +// for the given 'padding' specification. +std::vector> MakeSpatialPadding( + absl::Span input_size, absl::Span kernel_size, + absl::Span stride, Padding padding, + const TensorFormat& data_format); + +// Computes the average pool gradient. +XlaOp AvgPoolGrad(XlaOp out_backprop, absl::Span gradients_size, + absl::Span kernel_size, + absl::Span stride, + absl::Span> spatial_padding, + const TensorFormat& data_format, bool counts_include_padding); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_POOLING_H_ diff --git a/third_party/xla/xla/client/lib/pooling_test.cc b/third_party/xla/xla/hlo/builder/lib/pooling_test.cc similarity index 99% rename from third_party/xla/xla/client/lib/pooling_test.cc rename to third_party/xla/xla/hlo/builder/lib/pooling_test.cc index 54ef5f43b49f94..97b874d81c04ce 100644 --- a/third_party/xla/xla/client/lib/pooling_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/pooling_test.cc @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/pooling.h" +#include "xla/hlo/builder/lib/pooling.h" #include #include #include "absl/container/inlined_vector.h" #include "absl/types/span.h" -#include "xla/client/padding.h" -#include "xla/client/xla_builder.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/client/lib/prng.cc b/third_party/xla/xla/hlo/builder/lib/prng.cc similarity index 99% rename from third_party/xla/xla/client/lib/prng.cc rename to third_party/xla/xla/hlo/builder/lib/prng.cc index 370382238adf4a..7bafd7bf5b8e22 100644 --- a/third_party/xla/xla/client/lib/prng.cc +++ b/third_party/xla/xla/hlo/builder/lib/prng.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/prng.h" +#include "xla/hlo/builder/lib/prng.h" #include #include @@ -27,8 +27,8 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/primitive_util.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/hlo/builder/lib/prng.h b/third_party/xla/xla/hlo/builder/lib/prng.h new file mode 100644 index 00000000000000..89b4dd62bbcd14 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/prng.h @@ -0,0 +1,101 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_PRNG_H_ +#define XLA_HLO_BUILDER_LIB_PRNG_H_ + +#include +#include +#include + +#include "absl/types/span.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/shape.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Records the bits and state generated by a random number generator. +struct RngOutput { + XlaOp value; + XlaOp state; +}; + +// A BitGenerator returns random bits and updated random bit generator state. +// +// key: is a value input to a random number generator that can affect the +// sequence of number it will generate. A random number generator constructs +// its seed using the key and the initial state. The tf2xla bridge passes the +// seed operand of a tensorflow random operation as a key to the random bit +// generator, for example. +// initial_state: initial_state is the initial state of the current random +// number generation. It could be 0 for a stateless random operation, and +// the returned state from a previous execution for a stateful random +// operation. +// shape: the shape of the random bits. +using BitGeneratorTy = std::function; + +// Implements the ThreeFry counter-based PRNG algorithm. +// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. +// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf +RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state, + const xla::Shape& shape); + +// Implements the Philox algorithm to generate random numbers in parallel. +// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. +// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf +// +// The paper presents a few variants of the Philox algorithm, we picked the +// 4x32_10 version of the algorithm for the following reasons: +// . 4x32 uses 32-bit multiplication which is fast on GPUs. +// . The authors recommend the 10-round variant, and TensorFlow also uses it. +RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state, + const Shape& shape); +// Returns a scrambled pair of (state, key) from a single key. +std::pair ScramblePhiloxKey(XlaOp key); + +// Uses the given bit generator to generate random bits and then converts the +// random bits to random numbers of uniform distribution in the given range. +// Returns the random numbers and the state of the random number generator. +// This function is for shape with floating point element types. +RngOutput UniformFloatingPointDistribution(XlaOp key, XlaOp initial_state, + BitGeneratorTy bit_generator, + XlaOp minval, XlaOp maxval, + const xla::Shape& shape); + +// Similar to UniformFloatingPointDistribution but for shape with integer +// element types. +RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state, + BitGeneratorTy bit_generator, XlaOp minval, + XlaOp maxval, const xla::Shape& shape); + +// Uses the given bit generator to generate random bits and then converts the +// random bits to random numbers of normal distribution. +// Returns the random numbers and the state of the random number generator. +RngOutput NormalFloatingPointDistribution(XlaOp key, XlaOp initial_state, + BitGeneratorTy bit_generator, + const xla::Shape& shape); + +// Concatenates scalars into a vector. +xla::XlaOp ConcatScalars(xla::XlaBuilder* builder, + absl::Span scalars); + +// Increases Philox counter (an uint128_t) by a delta (an uint64_t). +xla::XlaOp PhiloxIncreaseCounter(xla::XlaOp counter, xla::XlaOp delta); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_PRNG_H_ diff --git a/third_party/xla/xla/client/lib/prng_test.cc b/third_party/xla/xla/hlo/builder/lib/prng_test.cc similarity index 97% rename from third_party/xla/xla/client/lib/prng_test.cc rename to third_party/xla/xla/hlo/builder/lib/prng_test.cc index 22241e9fab1da9..0e5f9772c35d26 100644 --- a/third_party/xla/xla/client/lib/prng_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/prng_test.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/prng.h" +#include "xla/hlo/builder/lib/prng.h" #include #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/primitive_util.h" #include "xla/shape.h" #include "xla/test.h" diff --git a/third_party/xla/xla/client/lib/qr.cc b/third_party/xla/xla/hlo/builder/lib/qr.cc similarity index 96% rename from third_party/xla/xla/client/lib/qr.cc rename to third_party/xla/xla/hlo/builder/lib/qr.cc index 794b2e4887f8b2..699e13b4c2e181 100644 --- a/third_party/xla/xla/client/lib/qr.cc +++ b/third_party/xla/xla/hlo/builder/lib/qr.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/qr.h" +#include "xla/hlo/builder/lib/qr.h" #include #include @@ -21,10 +21,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" diff --git a/third_party/xla/xla/hlo/builder/lib/qr.h b/third_party/xla/xla/hlo/builder/lib/qr.h new file mode 100644 index 00000000000000..6e4f3cc15fa4ec --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/qr.h @@ -0,0 +1,52 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_QR_H_ +#define XLA_HLO_BUILDER_LIB_QR_H_ + +#include "xla/hlo/builder/xla_builder.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Computes the QR decompositions of a batch of matrices. That is, +// given a (batched) matrix a, computes an orthonormal matrix Q and an +// upper-triangular matrix R such that a = QR. +// `a` must be a (batched) matrix of size [..., m, n]. +struct QrDecomposition { + // A matrix with the same shape as the input matrix `a`, whose upper triangle + // (inclusive of the diagonal) is the matrix R, and whose lower triangle + // (exclusive of the diagonal) contains the elementary Householder reflectors. + // This is the same output format as used by LAPACK's xGEQRF routine. + XlaOp q_and_r; + // A vector of shape [..., min(m, n)] containing the scalar factors of the + // elementary Householder reflectors. + XlaOp taus; +}; + +QrDecomposition Qr(XlaOp a); + +// Given `a` and `taus` as returned by `QRDecomposition`, compute the product of +// the elementary Householder reflectors (i.e., the matrix Q of the QR +// decomposition). The equivalent LAPACK routine is xORGQR/xUNGQR. +XlaOp ProductOfElementaryHouseholderReflectors(XlaOp a, XlaOp taus); + +// Helper that combines `Qr` and `ProductOfElementaryHouseholderReflectors` to +// compute explicit matrices `q` and `r`. +void QrExplicit(XlaOp a, bool full_matrices, XlaOp& q, XlaOp& r); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_QR_H_ diff --git a/third_party/xla/xla/client/lib/qr_test.cc b/third_party/xla/xla/hlo/builder/lib/qr_test.cc similarity index 98% rename from third_party/xla/xla/client/lib/qr_test.cc rename to third_party/xla/xla/hlo/builder/lib/qr_test.cc index fc9e583ab9ad12..9f8e28e53cef66 100644 --- a/third_party/xla/xla/client/lib/qr_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/qr_test.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/qr.h" +#include "xla/hlo/builder/lib/qr.h" #include "xla/array.h" #include "xla/array2d.h" #include "xla/array3d.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/xla_builder.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" diff --git a/third_party/xla/xla/hlo/builder/lib/quantize.h b/third_party/xla/xla/hlo/builder/lib/quantize.h new file mode 100644 index 00000000000000..d0126f0c021b2f --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/quantize.h @@ -0,0 +1,184 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_QUANTIZE_H_ +#define XLA_HLO_BUILDER_LIB_QUANTIZE_H_ + +#include +#include +#include +#include + +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/types.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/bfloat16.h" + +namespace xla { + +// Represents the range used for quantization +struct QuantizedRange { + QuantizedRange() = default; + QuantizedRange(float min_in, float max_in) : min(min_in), max(max_in) {} + + bool operator==(const QuantizedRange& rhs) const { + return this->min == rhs.min && this->max == rhs.max; + } + + bool operator!=(const QuantizedRange& rhs) const { return !(*this == rhs); } + + tsl::bfloat16 min = tsl::bfloat16(0.0f); + tsl::bfloat16 max = tsl::bfloat16(0.0f); +}; + +template +inline std::vector PackToUint32(absl::Span input) { + const int64_t kElementsPerPack = sizeof(uint32_t) / sizeof(T); + const int64_t input_size = input.size(); + const int64_t output_size = CeilOfRatio(input_size, kElementsPerPack); + + std::vector output_vec; + constexpr int64_t kShiftBits = sizeof(T) / sizeof(uint8_t) * CHAR_BIT; + + for (int64_t i = 0; i < output_size; i++) { + uint32_t result = 0; + for (int64_t p = 0; p < kElementsPerPack; p++) { + int64_t index = i * kElementsPerPack + p; + if (index < input_size) { + int64_t total_shift_bits = kShiftBits * (kElementsPerPack - p - 1); + result |= (input[index] << total_shift_bits); + } + } + output_vec.push_back(result); + } + + return output_vec; +} + +// Dequantize the quantized input of packed uint32_t to bfloat16. +// Only uint8_t or uint16_t is supported for the original unpacked input. +// Returns a tensor of shape [d0,..., dn * unpack_size] if +// input shape is [d0, ..., dn], where unpack_size = sizeof(unit32) / sizeof(T). +// If transpose_output is true, will return a tensor of shape +// [dn * unpack_size, dn-1, ..., d1, d0]. transpose_output is faster when +// input's rank higher than 1. The input needs to be transposed to use +// transpose_output feature. +template +inline XlaOp Dequantize(XlaOp input, const QuantizedRange& range, + absl::string_view mode_string = "MIN_COMBINED", + bool transpose_output = false) { + XlaBuilder* const builder = input.builder(); + return builder->ReportErrorOrReturn([&]() -> absl::StatusOr { + float half_range = + !std::is_signed::value + ? 0.0f + : (static_cast(std::numeric_limits::max()) - + std::numeric_limits::min() + 1) / + 2.0f; + const int64_t unpack_size = sizeof(uint32_t) / sizeof(T); + TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(input)); + + auto element_type = shape.element_type(); + if (element_type != U32) { + return InvalidArgument( + "Only U32 is supported for input type of xla::Dequantize Op."); + } + + // Broadcast the input to [unpack_size, d0, ..., dn] if input size is + // [d0, ..., dn]. + auto broadcast_input = Broadcast(input, {unpack_size}); + + XlaOp iota_r1 = Iota(builder, U32, unpack_size); + // Highest significant bytes needs to shift more bytes than lower + // significant bytes. + XlaOp shift_bytes = + xla::ConstantR0(builder, unpack_size - 1) - iota_r1; + + const int bytes_of_type = sizeof(T) / sizeof(uint8_t); + std::vector shift_vec(unpack_size, CHAR_BIT * bytes_of_type); + XlaOp shift_bits = + shift_bytes * xla::ConstantR1(builder, shift_vec); + + // Make bit_mask for different data type T. + uint32_t bit_mask = 0x00000000; + for (int i = 0; i < bytes_of_type; i++) { + bit_mask <<= CHAR_BIT; + bit_mask |= 0x000000ff; + } + + std::vector shift_transpose_dimensions(shape.dimensions_size()); + std::iota(shift_transpose_dimensions.begin(), + shift_transpose_dimensions.end(), 0); + shift_transpose_dimensions.insert(shift_transpose_dimensions.begin(), 1, + shape.dimensions_size()); + + // Shift the input by sizeof(T) bytes and apply bit_mask to unpack. + XlaOp shifted_input = ShiftRightLogical( + broadcast_input, Transpose(Broadcast(shift_bits, shape.dimensions()), + shift_transpose_dimensions)); + XlaOp unpack_input = + And(shifted_input, xla::ConstantR0(builder, bit_mask)); + + XlaOp result; + + if (mode_string == "MIN_COMBINED") { + const tsl::bfloat16 scale_factor = + (range.max - range.min) / + (static_cast(std::numeric_limits::max() - + std::numeric_limits::min())); + // result = bfloat16(input + half_range) * scale_factor + range.min + XlaOp unpack_input_bf16 = ConvertElementType(unpack_input, BF16); + XlaOp half_range_bf16 = xla::ConstantR0( + builder, static_cast(half_range)); + XlaOp sum = unpack_input_bf16 + half_range_bf16; + + result = sum * xla::ConstantR0(builder, scale_factor) + + xla::ConstantR0(builder, range.min); + } else { + // TODO(wangtao): support other modes. + return InvalidArgument( + "Only MIN_COMBINED mode is supported in xla::Dequantize Op."); + } + + std::vector transpose_dimensions(shape.dimensions_size()); + std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 1); + std::reverse(transpose_dimensions.begin(), transpose_dimensions.end()); + transpose_dimensions.insert(transpose_dimensions.begin() + 1, 1, 0); + + // Transpose the result to be [dn, unpack_size, dn-1, ..., d1, d0]. + XlaOp transposed_result = Transpose(result, transpose_dimensions); + + // Reshape to be [dn * unpack_size, dn-1, ..., d1, d0]. + XlaOp reshaped_result = Collapse(transposed_result, {0, 1}); + + // Return the transpose result if transpose_output is true. + if (transpose_output) { + return reshaped_result; + } + + // Transpose the result to be [d0, d1, ..., dn-1, dn * unpack_size]. + std::vector result_dimensions(shape.dimensions_size()); + std::iota(result_dimensions.begin(), result_dimensions.end(), 0); + std::reverse(result_dimensions.begin(), result_dimensions.end()); + + return Transpose(reshaped_result, result_dimensions); + }); +} + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_QUANTIZE_H_ diff --git a/third_party/xla/xla/client/lib/quantize_test.cc b/third_party/xla/xla/hlo/builder/lib/quantize_test.cc similarity index 99% rename from third_party/xla/xla/client/lib/quantize_test.cc rename to third_party/xla/xla/hlo/builder/lib/quantize_test.cc index 6f371404f12869..6520bb4a07fef1 100644 --- a/third_party/xla/xla/client/lib/quantize_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/quantize_test.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/quantize.h" +#include "xla/hlo/builder/lib/quantize.h" #include #include #include "xla/array2d.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/client/lib/self_adjoint_eig.cc b/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.cc similarity index 95% rename from third_party/xla/xla/client/lib/self_adjoint_eig.cc rename to third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.cc index 05ba43f9fabcef..a7f3a3c00b6933 100644 --- a/third_party/xla/xla/client/lib/self_adjoint_eig.cc +++ b/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/self_adjoint_eig.h" +#include "xla/hlo/builder/lib/self_adjoint_eig.h" #include #include @@ -21,8 +21,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/primitive_util.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.h b/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.h new file mode 100644 index 00000000000000..f0dffdc41218bf --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig.h @@ -0,0 +1,41 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_SELF_ADJOINT_EIG_H_ +#define XLA_HLO_BUILDER_LIB_SELF_ADJOINT_EIG_H_ + +#include "xla/hlo/builder/xla_builder.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// The eigenvalue decomposition of a symmetric matrix, the original matrix is +// recovered by v * w * v_t. +struct SelfAdjointEigResult { + // The i-th column is the normalized eigenvector corresponding to the + // eigenvalue w[i]. Will return a matrix object if a is a matrix object. + XlaOp v; + // The eigenvalues in ascending order, each repeated according to its + // multiplicity. + XlaOp w; +}; + +SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower = true, + int64_t max_iter = 15, float tol = 1e-5, + bool sort_eigenvalues = true); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_SELF_ADJOINT_EIG_H_ diff --git a/third_party/xla/xla/client/lib/self_adjoint_eig_test.cc b/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig_test.cc similarity index 97% rename from third_party/xla/xla/client/lib/self_adjoint_eig_test.cc rename to third_party/xla/xla/hlo/builder/lib/self_adjoint_eig_test.cc index 624c8211d874ec..510dd7b0de9c2f 100644 --- a/third_party/xla/xla/client/lib/self_adjoint_eig_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/self_adjoint_eig_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/self_adjoint_eig.h" +#include "xla/hlo/builder/lib/self_adjoint_eig.h" #include #include @@ -25,12 +25,12 @@ limitations under the License. #include "xla/array.h" #include "xla/array2d.h" #include "xla/array3d.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/math.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/xla_builder.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" diff --git a/third_party/xla/xla/client/lib/slicing.cc b/third_party/xla/xla/hlo/builder/lib/slicing.cc similarity index 98% rename from third_party/xla/xla/client/lib/slicing.cc rename to third_party/xla/xla/hlo/builder/lib/slicing.cc index 26c5ea59ff1931..42dd4c8a82d188 100644 --- a/third_party/xla/xla/client/lib/slicing.cc +++ b/third_party/xla/xla/hlo/builder/lib/slicing.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/slicing.h" +#include "xla/hlo/builder/lib/slicing.h" #include #include @@ -23,9 +23,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" diff --git a/third_party/xla/xla/hlo/builder/lib/slicing.h b/third_party/xla/xla/hlo/builder/lib/slicing.h new file mode 100644 index 00000000000000..dfb880805d2153 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/slicing.h @@ -0,0 +1,83 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/types/span.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/types.h" + +#ifndef XLA_HLO_BUILDER_LIB_SLICING_H_ +#define XLA_HLO_BUILDER_LIB_SLICING_H_ + +namespace xla { + +// Updates a slice of 'x', i.e., +// x[start[0], ..., start[n]] = update +XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span start); + +// Performs a slice in the minor dimensions of a tensor. +// x[..., start[0]:end[0], ..., start[n]:end[n]] +XlaOp SliceInMinorDims(XlaOp x, absl::Span start, + absl::Span end); + +// Updates a slice of 'x', where 'start' contains a list of minor dimensions: +// x[..., start[0]:..., ..., start[n]:...] = update +XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span start); + +// Performs a dynamic slice in the minor dimensions of a tensor. +XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span starts, + absl::Span sizes); + +XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update, + absl::Span starts); + +// Gathers values along an axis specified by dim. +// +// For a 3-D tensor the output is specified by: +// +// out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 +// out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 +// out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 +// +// If `input` is an n-dimensional tensor with size +// [X0,X1,X2,..XN] and dim = i `index` must be an n-dimensional tensor with size +// [X0,X1,...Y,Xi+1,...,X[N] where y >= 1 and `out` will have the same sizes as +// `index`. +XlaOp TorchGather(XlaOp input, XlaOp index, int64_t dim, bool sparse = true); + +// idx = index[i][j][k] +// output[idx][j][k] = combiner(input[idx][j][k], src[i][j][k]) # if dim == 0 +// output[i][idx][k] = combiner(input[i][idx][k], src[i][j][k]) # if dim == 1 +// output[i][j][idx] = combiner(input[i][j][idx], src[i][j][k]) # if dim == 2 +XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64_t dim, + const std::function& combiner); + +// Returns a new tensor which indexes the input tensor along dimension dim using +// the entries in index. +// +// The returned tensor has the same number of dimensions as the original tensor +// (input). The dimth dimension has the same size as the length of index; other +// dimensions have the same size as in the original tensor. +// +// This operation supports 0 or more major batch dimensions that act like a +// multidimensional loop over both the input and the index. +XlaOp TorchIndexSelect(XlaOp input, XlaOp index, int64_t dim, + int64_t batch_dims = 0); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_SLICING_H_ diff --git a/third_party/xla/xla/client/lib/slicing_test.cc b/third_party/xla/xla/hlo/builder/lib/slicing_test.cc similarity index 99% rename from third_party/xla/xla/client/lib/slicing_test.cc rename to third_party/xla/xla/hlo/builder/lib/slicing_test.cc index 8dfc55f521f089..72e8e1ca7026d8 100644 --- a/third_party/xla/xla/client/lib/slicing_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/slicing_test.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/slicing.h" +#include "xla/hlo/builder/lib/slicing.h" #include "xla/array2d.h" #include "xla/array3d.h" -#include "xla/client/xla_builder.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/shape_util.h" #include "xla/tests/client_library_test_base.h" diff --git a/third_party/xla/xla/client/lib/sorting.cc b/third_party/xla/xla/hlo/builder/lib/sorting.cc similarity index 97% rename from third_party/xla/xla/client/lib/sorting.cc rename to third_party/xla/xla/hlo/builder/lib/sorting.cc index 48eec5d5ff2f7c..456accc515e111 100644 --- a/third_party/xla/xla/client/lib/sorting.cc +++ b/third_party/xla/xla/hlo/builder/lib/sorting.cc @@ -13,17 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/sorting.h" +#include "xla/hlo/builder/lib/sorting.h" #include #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/client/lib/comparators.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/loops.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/comparators.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/loops.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" diff --git a/third_party/xla/xla/hlo/builder/lib/sorting.h b/third_party/xla/xla/hlo/builder/lib/sorting.h new file mode 100644 index 00000000000000..b951f26b97b043 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/sorting.h @@ -0,0 +1,38 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_SORTING_H_ +#define XLA_HLO_BUILDER_LIB_SORTING_H_ + +#include "xla/hlo/builder/xla_builder.h" +#include "xla/types.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Returns a tuple composed of the top `k` values and corresponding indices in +// `input`. Output values are in descending order, from largest to smallest. +XlaOp TopK(XlaOp input, int64_t k, + PrimitiveType index_type = PrimitiveType::S32); + +// Split sort in TopK into smaller sorts. +// Returns a tuple composed of the top `k` values and corresponding indices in +// `input`. Output values are in descending order, from largest to smallest. +XlaOp TopKWithPartitions(XlaOp input, int64_t k, int64_t num_partitions = 1, + PrimitiveType index_type = PrimitiveType::S32); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_SORTING_H_ diff --git a/third_party/xla/xla/client/lib/sorting_test.cc b/third_party/xla/xla/hlo/builder/lib/sorting_test.cc similarity index 98% rename from third_party/xla/xla/client/lib/sorting_test.cc rename to third_party/xla/xla/hlo/builder/lib/sorting_test.cc index 02eeff7ad80f22..2230eb73ecc4fb 100644 --- a/third_party/xla/xla/client/lib/sorting_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/sorting_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/sorting.h" +#include "xla/hlo/builder/lib/sorting.h" #include #include @@ -24,8 +24,8 @@ limitations under the License. #include "absl/algorithm/container.h" #include "xla/array.h" #include "xla/array2d.h" -#include "xla/client/xla_builder.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/client/lib/svd.cc b/third_party/xla/xla/hlo/builder/lib/svd.cc similarity index 98% rename from third_party/xla/xla/client/lib/svd.cc rename to third_party/xla/xla/hlo/builder/lib/svd.cc index 88afe31e2ed0c3..22e4ab8d039bdc 100644 --- a/third_party/xla/xla/client/lib/svd.cc +++ b/third_party/xla/xla/hlo/builder/lib/svd.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/svd.h" +#include "xla/hlo/builder/lib/svd.h" #include #include @@ -24,14 +24,14 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/comparators.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/loops.h" -#include "xla/client/lib/math.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/comparators.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/loops.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/hlo/builder/lib/svd.h b/third_party/xla/xla/hlo/builder/lib/svd.h new file mode 100644 index 00000000000000..42d165f766ab43 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/svd.h @@ -0,0 +1,49 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_SVD_H_ +#define XLA_HLO_BUILDER_LIB_SVD_H_ + +#include "xla/hlo/builder/xla_builder.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// The singular value decomposition of a given matrix A[..., M, N], the original +// matrix is recovered by u * diag(d) * v_t, where the first dims(A) - 2 +// dimensions are batch dimensions. +struct SVDResult { + // The columns of U are the left-singular vectors, e.g., + // U[..., :, :]_T * U[..., :, :] = I. + XlaOp u; + // Vector(s) with the singular values, within each vector sorted in descending + // order. The first dims(D) - 1 dimensions have the same size as the batch + // dimensions of A. And U[..., :, i] * D[..., i] = A[..., :, :] * V[..., :, + // i]. + XlaOp d; + // The columns of V are the right-singular vectors. e.g., + // V[..., :, :]_T * V[..., :, :] = I. + XlaOp v; +}; + +// TODO(kuny): Add a bool flag that supports SVD with economy (reduced) +// representation, which is more memory efficient, especially in the case of +// tall-skinny matrices. +SVDResult SVD(XlaOp a, int64_t max_iter = 100, float epsilon = 1e-6, + PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_SVD_H_ diff --git a/third_party/xla/xla/client/lib/svd_test.cc b/third_party/xla/xla/hlo/builder/lib/svd_test.cc similarity index 97% rename from third_party/xla/xla/client/lib/svd_test.cc rename to third_party/xla/xla/hlo/builder/lib/svd_test.cc index f1a7fc62a1c2e4..7266cde21684fe 100644 --- a/third_party/xla/xla/client/lib/svd_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/svd_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/svd.h" +#include "xla/hlo/builder/lib/svd.h" #include #include @@ -22,12 +22,12 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/array2d.h" #include "xla/array3d.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/client_library_test_base.h" diff --git a/third_party/xla/xla/client/lib/tridiagonal.cc b/third_party/xla/xla/hlo/builder/lib/tridiagonal.cc similarity index 99% rename from third_party/xla/xla/client/lib/tridiagonal.cc rename to third_party/xla/xla/hlo/builder/lib/tridiagonal.cc index 4d4a4604e5ce23..9538a742e4cfce 100644 --- a/third_party/xla/xla/client/lib/tridiagonal.cc +++ b/third_party/xla/xla/hlo/builder/lib/tridiagonal.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/tridiagonal.h" +#include "xla/hlo/builder/lib/tridiagonal.h" #include #include @@ -24,10 +24,10 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/loops.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/loops.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" diff --git a/third_party/xla/xla/hlo/builder/lib/tridiagonal.h b/third_party/xla/xla/hlo/builder/lib/tridiagonal.h new file mode 100644 index 00000000000000..d6bf56c009c2a7 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/tridiagonal.h @@ -0,0 +1,43 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_TRIDIAGONAL_H_ +#define XLA_HLO_BUILDER_LIB_TRIDIAGONAL_H_ + +#include "absl/status/statusor.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace tridiagonal { + +enum SolverAlgorithm { kThomas }; + +absl::StatusOr TridiagonalSolver(SolverAlgorithm algo, + XlaOp lower_diagonal, + XlaOp main_diagonal, + XlaOp upper_diagonal, XlaOp rhs); + +absl::StatusOr TridiagonalSolver(SolverAlgorithm algo, XlaOp diagonals, + XlaOp rhs); + +absl::StatusOr TridiagonalMatMul(XlaOp upper_diagonal, + XlaOp main_diagonal, + XlaOp lower_diagonal, XlaOp rhs); + +} // namespace tridiagonal +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_TRIDIAGONAL_H_ diff --git a/third_party/xla/xla/client/lib/tridiagonal_test.cc b/third_party/xla/xla/hlo/builder/lib/tridiagonal_test.cc similarity index 98% rename from third_party/xla/xla/client/lib/tridiagonal_test.cc rename to third_party/xla/xla/hlo/builder/lib/tridiagonal_test.cc index 280e4dd8ec17ae..5948c8840303e1 100644 --- a/third_party/xla/xla/client/lib/tridiagonal_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/tridiagonal_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/tridiagonal.h" +#include "xla/hlo/builder/lib/tridiagonal.h" #include #include @@ -22,8 +22,8 @@ limitations under the License. #include "absl/status/status.h" #include "xla/array.h" #include "xla/array3d.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/shape_util.h" #include "xla/test.h" diff --git a/third_party/xla/xla/client/lib/tuple.cc b/third_party/xla/xla/hlo/builder/lib/tuple.cc similarity index 96% rename from third_party/xla/xla/client/lib/tuple.cc rename to third_party/xla/xla/hlo/builder/lib/tuple.cc index 4cefa748bc8d04..6a0145addefbde 100644 --- a/third_party/xla/xla/client/lib/tuple.cc +++ b/third_party/xla/xla/hlo/builder/lib/tuple.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/tuple.h" +#include "xla/hlo/builder/lib/tuple.h" #include #include "absl/container/inlined_vector.h" #include "absl/status/statusor.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/hlo/builder/lib/tuple.h b/third_party/xla/xla/hlo/builder/lib/tuple.h new file mode 100644 index 00000000000000..11d7d022806aef --- /dev/null +++ b/third_party/xla/xla/hlo/builder/lib/tuple.h @@ -0,0 +1,36 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_LIB_TUPLE_H_ +#define XLA_HLO_BUILDER_LIB_TUPLE_H_ + +#include "absl/status/statusor.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/shape_tree.h" + +namespace xla { + +// Returns a ShapeTree where each index is a GetTupleElement instruction for +// that subshape of the tuple. The root index is the original argument. +absl::StatusOr> DisassembleTuple(XlaOp tuple); + +// Assembles a tuple from a ShapeTree that contains the leaves of the tuple. +// Non-leaf elements of the ShapeTree are ignored. DisassembleTuple and +// AssembleTuple are essentially inverse operations. +XlaOp AssembleTuple(XlaBuilder* builder, ShapeTree elements); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_LIB_TUPLE_H_ diff --git a/third_party/xla/xla/client/lib/tuple_test.cc b/third_party/xla/xla/hlo/builder/lib/tuple_test.cc similarity index 97% rename from third_party/xla/xla/client/lib/tuple_test.cc rename to third_party/xla/xla/hlo/builder/lib/tuple_test.cc index cb2cab8abd0bed..67f270300acce4 100644 --- a/third_party/xla/xla/client/lib/tuple_test.cc +++ b/third_party/xla/xla/hlo/builder/lib/tuple_test.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/lib/tuple.h" +#include "xla/hlo/builder/lib/tuple.h" #include #include #include -#include "xla/client/xla_builder.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/service.h" diff --git a/third_party/xla/xla/client/padding.cc b/third_party/xla/xla/hlo/builder/padding.cc similarity index 98% rename from third_party/xla/xla/client/padding.cc rename to third_party/xla/xla/hlo/builder/padding.cc index daf26d5467ac7b..b8951735619e92 100644 --- a/third_party/xla/xla/client/padding.cc +++ b/third_party/xla/xla/hlo/builder/padding.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/padding.h" +#include "xla/hlo/builder/padding.h" #include #include @@ -21,8 +21,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/types/span.h" +#include "xla/tsl/lib/math/math_util.h" #include "xla/util.h" -#include "tsl/lib/math/math_util.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" diff --git a/third_party/xla/xla/hlo/builder/padding.h b/third_party/xla/xla/hlo/builder/padding.h new file mode 100644 index 00000000000000..b0c83b7587a1ef --- /dev/null +++ b/third_party/xla/xla/hlo/builder/padding.h @@ -0,0 +1,66 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_PADDING_H_ +#define XLA_HLO_BUILDER_PADDING_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/types/span.h" + +namespace xla { + +// Describes the padding applied for a windowed operation like +// convolution, where a window is placed inside a base area. +enum class Padding { + // Make the output have the same dimensions as the base area. For + // example, for a 3x3 base area and a 2x2 window, the output will be + // 3x3, so that requires padding the 3x3 base area to 4x4. + kSame, + + // Use no padding. For example, for a 4x4 base area and a 2x2 + // window, the output will be 3x3. + kValid, +}; + +// Validates that the slices are acceptable for determining padding -- this can +// be used to check the preconditions of MakePadding below to produce an error +// message that can be returned to the user. +absl::Status ValidatePaddingValues(absl::Span input_dimensions, + absl::Span window_dimensions, + absl::Span window_strides); + +// Returns the padding needed for the base area, given the base area dimensions, +// window dimensions, strides, and the type of padding. +// +// If v is the returned vector, then for each dimension number i, +// v[i].first is the padding to the left (i.e. in the direction of +// lower indices) and v[i].second is the padding to the right (i.e. in +// the direction of higher indices). +// +// Precondition: The number of dimensions (i.e., rank) in input_dimensions, +// window_dimensions, and strides must match, which is equal to the number +// of elements in the result vector. +std::vector> MakePadding( + absl::Span input_dimensions, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_PADDING_H_ diff --git a/third_party/xla/xla/client/padding_test.cc b/third_party/xla/xla/hlo/builder/padding_test.cc similarity index 98% rename from third_party/xla/xla/client/padding_test.cc rename to third_party/xla/xla/hlo/builder/padding_test.cc index 0d183d0e16ede9..2d06a84cd3da4e 100644 --- a/third_party/xla/xla/client/padding_test.cc +++ b/third_party/xla/xla/hlo/builder/padding_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/padding.h" +#include "xla/hlo/builder/padding.h" #include diff --git a/third_party/xla/xla/client/sharding_builder.cc b/third_party/xla/xla/hlo/builder/sharding_builder.cc similarity index 98% rename from third_party/xla/xla/client/sharding_builder.cc rename to third_party/xla/xla/hlo/builder/sharding_builder.cc index 7b179b8c91ee4a..2c01cb16203c2b 100644 --- a/third_party/xla/xla/client/sharding_builder.cc +++ b/third_party/xla/xla/hlo/builder/sharding_builder.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/sharding_builder.h" +#include "xla/hlo/builder/sharding_builder.h" #include diff --git a/third_party/xla/xla/hlo/builder/sharding_builder.h b/third_party/xla/xla/hlo/builder/sharding_builder.h new file mode 100644 index 00000000000000..245ab9e2c004d7 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/sharding_builder.h @@ -0,0 +1,60 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_SHARDING_BUILDER_H_ +#define XLA_HLO_BUILDER_SHARDING_BUILDER_H_ + +#include + +#include "xla/array.h" +#include "xla/shape.h" +#include "xla/shape_tree.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace sharding_builder { +// A shaped array used to describe the assignment of tiles to devices. +using TileAssignment = Array; + +// Creates a replicated sharding - replicate a tensor on every device. +OpSharding Replicate(); + +// Creates a manual sharding - the partitioner will not change the shape. +OpSharding Manual(); + +// Creates a sharding that assigns a tensor to just one device. +OpSharding AssignDevice(int device); + +// Creates a tiled sharding with the given tile shape and assignment of tiles +// to devices. +// +// If tile_shape is not evenly divisible by the number of devices in +// tile_assignment, operations behave as if implicit padding had been inserted. +// The value of this padding is undefined. +OpSharding Tile(const Shape& tile_shape, const TileAssignment& tile_assignment); + +// Creates a sharding in one dimension, with the given tile shape which must +// be rank 1 and using devices [0..num_tiles). +// +// This is simply a convenience wrapper for Tile(). +OpSharding Tile1D(const Shape& tile_shape, int64_t num_tiles); + +// Creates a tuple sharding from the given ShapeTree of element shardings. +OpSharding Tuple(const ShapeTree& shardings); + +} // namespace sharding_builder +} // namespace xla + +#endif // XLA_HLO_BUILDER_SHARDING_BUILDER_H_ diff --git a/third_party/xla/xla/client/value_inference.cc b/third_party/xla/xla/hlo/builder/value_inference.cc similarity index 99% rename from third_party/xla/xla/client/value_inference.cc rename to third_party/xla/xla/hlo/builder/value_inference.cc index 2f0b6e20756bff..165ee203042443 100644 --- a/third_party/xla/xla/client/value_inference.cc +++ b/third_party/xla/xla/hlo/builder/value_inference.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/value_inference.h" +#include "xla/hlo/builder/value_inference.h" #include #include @@ -29,8 +29,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" -#include "xla/client/xla_builder.h" #include "xla/comparison_util.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/dfs_hlo_visitor.h" #include "xla/hlo/ir/hlo_computation.h" diff --git a/third_party/xla/xla/hlo/builder/value_inference.h b/third_party/xla/xla/hlo/builder/value_inference.h new file mode 100644 index 00000000000000..7f69a5979553dc --- /dev/null +++ b/third_party/xla/xla/hlo/builder/value_inference.h @@ -0,0 +1,117 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_HLO_BUILDER_VALUE_INFERENCE_H_ +#define XLA_HLO_BUILDER_VALUE_INFERENCE_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/evaluator/hlo_evaluator.h" +#include "xla/hlo/ir/dfs_hlo_visitor.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla { +// OptionalLiteral is an augmented literal class which returns optional +// values for each index (the value can be either valid or invalid). The +// implementation keeps two literals, a value literal, holding both the valid +// and garabage value, and a masking literal representing if a value is valid or +// garbage. +class OptionalLiteral { + public: + explicit OptionalLiteral(Literal value, Literal mask) + : value_(std::move(value)), mask_(std::move(mask)) {} + + template + std::optional Get(absl::Span element_index, + ShapeIndex shape_index = {}) const { + if (mask_.Get(element_index, shape_index)) { + return std::nullopt; + } else { + return value_.Get(element_index, shape_index); + } + } + + // Returns true if all values in this literal slice are value. + bool AllValid() { return mask_.IsAll(0); } + + // Get value out of this slice if all values are valid. Otherwise returns + // nullopt. + std::optional GetValue() { + if (!AllValid()) { + return std::nullopt; + } + return LiteralSlice(value_); + } + + private: + Literal value_; + Literal mask_; +}; + +enum ValueInferenceMode { + // Inference the constant value itself. + kValue = 0, + // Inference upper-bound and lower-bound of the value. Bounds are inclusive. + kUpperBound, + kLowerBound, +}; + +class ValueInference { + public: + // ValueInference analyzes values in XlaOp answers following questions: + // - What's the upper-bound of each value in a tensor. + // - What's the lower-bound of each value in a tensor. + // - What's the constant value of each tensor. + // - Whether or not each value in a tensor is dynamic. + explicit ValueInference(XlaBuilder* builder) : builder_(builder) { + CHECK(builder_); + } + absl::StatusOr AnalyzeIsDynamic(XlaOp op); + // Returns an OptionalLiteral. Each individual value of the literal is + // the concrete constant value if it can be inferred, otherwise a nullopt. + absl::StatusOr AnalyzeConstant(XlaOp op, + ValueInferenceMode mode); + + // Returns underlying xla builder. + XlaBuilder* builder() { return builder_; } + + private: + // Given an op handle, returns a simplified version of the handle inside a + // int64_t Literal. If the a -1 value for the handle means invalid + // simplification and the result shouldn't be used. + absl::StatusOr SimplifyOp(int64_t handle); + + // Perform CSE on a given handle, and return an equivalent handle if seen + // before. Otherwise, returns nullopt. + absl::StatusOr> CseOpHandle(int64_t handle); + XlaBuilder* builder_; + HloEvaluator evaluator_; + // A map from instruction_hash to handle that helps perform CSE. + absl::flat_hash_map cse_map_; +}; +} // namespace xla + +#endif // XLA_HLO_BUILDER_VALUE_INFERENCE_H_ diff --git a/third_party/xla/xla/client/xla_builder.cc b/third_party/xla/xla/hlo/builder/xla_builder.cc similarity index 99% rename from third_party/xla/xla/client/xla_builder.cc rename to third_party/xla/xla/hlo/builder/xla_builder.cc index 98e7dada978400..140addbcc026a4 100644 --- a/third_party/xla/xla/client/xla_builder.cc +++ b/third_party/xla/xla/hlo/builder/xla_builder.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include #include @@ -44,10 +44,10 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/array.h" -#include "xla/client/padding.h" -#include "xla/client/sharding_builder.h" -#include "xla/client/xla_computation.h" #include "xla/comparison_util.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/sharding_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" diff --git a/third_party/xla/xla/hlo/builder/xla_builder.h b/third_party/xla/xla/hlo/builder/xla_builder.h new file mode 100644 index 00000000000000..891d4fec725c69 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/xla_builder.h @@ -0,0 +1,3086 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_XLA_BUILDER_H_ +#define XLA_HLO_BUILDER_XLA_BUILDER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/functional/function_ref.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/array.h" +#include "xla/array2d.h" +#include "xla/array3d.h" +#include "xla/array4d.h" +#include "xla/comparison_util.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/ir/dynamic_parameter_binding.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/service/hlo.pb.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/lib/core/bitmap.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/stacktrace.h" + +namespace xla { + +class XlaBuilder; +class XlaOp; +class HloInstruction; + +namespace internal { + +struct XlaBuilderFriend { + static XlaOp BuildAddDependency(XlaBuilder* builder, XlaOp operand, + XlaOp token, const Shape& shape); + + static std::pair BuildAsyncStart( + XlaBuilder* builder, absl::Span operands, + std::string execution_thread, const XlaComputation& called_computation, + const Shape& shape); + static XlaOp BuildAsyncUpdate(XlaBuilder* builder, XlaOp operands, + const Shape& shape); + static XlaOp BuildAsyncDone(XlaBuilder* builder, XlaOp operands, + const Shape& shape); + + static XlaOp BuildAllGatherStart( + XlaBuilder* builder, XlaOp operand, int64_t all_gather_dimension, + int64_t shard_count, absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt, + const std::optional& layout = std::nullopt, + std::optional use_global_device_ids = std::nullopt); + static XlaOp BuildAllGatherDone(XlaBuilder* builder, XlaOp operands, + const Shape& shape); + + static XlaOp BuildAllReduceStart( + XlaBuilder* builder, XlaOp operand, const XlaComputation& computation, + absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt, + const std::optional& layout = std::nullopt, + std::optional use_global_device_ids = std::nullopt); + static XlaOp BuildAllReduceDone(XlaBuilder* builder, XlaOp operands, + const Shape& shape); + + static XlaOp BuildCollectivePermuteStart( + XlaBuilder* builder, XlaOp operand, + const std::vector>& source_target_pairs, + const std::optional& channel_id = std::nullopt); + static XlaOp BuildCollectivePermuteDone(XlaBuilder* builder, XlaOp operands, + const Shape& shape); + + static XlaOp BuildCopyStart( + XlaBuilder* builder, XlaOp operand, + std::optional cross_program_prefetch_index = std::nullopt); + static XlaOp BuildCopyDone(XlaBuilder* builder, XlaOp operand, + const Shape& shape); + + static XlaOp BuildFusion( + XlaBuilder* builder, absl::Span operands, + absl::string_view fusion_kind, const XlaComputation& fused_computation, + absl::Span>> + output_operand_aliasing = {}); + + static XlaOp BuildBitcast(XlaBuilder* builder, XlaOp operand, + const Shape& shape); + + static XlaOp BuildPartitionId(XlaBuilder* builder, const Shape& shape); + + static XlaOp BuildSend(XlaBuilder* builder, XlaOp operand, XlaOp token, + const ChannelHandle& handle, bool is_host_transfer); + static XlaOp BuildSendDone(XlaBuilder* builder, XlaOp operand, + const ChannelHandle& handle, + bool is_host_transfer); + + static XlaOp BuildRecv(XlaBuilder* builder, XlaOp token, const Shape& shape, + const ChannelHandle& handle, bool is_host_transfer); + static XlaOp BuildRecvDone(XlaBuilder* builder, XlaOp token, + const Shape& shape, const ChannelHandle& handle, + bool is_host_transfer); + + static XlaOp BuildDomain(XlaBuilder* builder, XlaOp operand, OpSharding entry, + OpSharding exit, const Shape& shape); + + static XlaOp BuildRngGetAndUpdateState(XlaBuilder* builder, int64_t delta, + const Shape& shape); + + static HloInstructionProto* GetInstruction(XlaOp op); + static HloInstructionProto* GetInstructionByHandle(XlaBuilder* builder, + int64_t handle); +}; + +} // namespace internal + +// This represents an instruction that has been enqueued using the XlaBuilder. +// This is used to pass to subsequent computations that depends upon the +// instruction as an operand. +class XlaOp { + public: + XlaOp() : handle_(-1), builder_(nullptr) { + static_assert(std::is_trivially_destructible::value, + "XlaOp should be trivially destructible"); + } + ~XlaOp() = default; + + XlaOp(const XlaOp& other) = default; + XlaOp& operator=(const XlaOp& other) = default; + + // Precondition: !IsUninitialized(). + // + // It's very common to do foo.builder()->bar(). Without this precondition, if + // foo.builder() is null, the call to bar will segfault at some point possibly + // deep in the callstack when we finally dereference `this`. The precondition + // lets us avoid this tricky-to-debug problem. + XlaBuilder* builder() const { + CHECK(builder_ != nullptr); + return builder_; + } + + // Returns true if the XlaOp represents valid, non-erroneous value. + bool valid() const { return handle_ >= 0; } + + // Returns true if the XlaOp was created by the XlaOp() constructor and + // not returned by a builder. + bool IsUninitialized() const { return builder_ == nullptr; } + + bool IsIdenticalTo(XlaOp rhs) const { + return handle_ == rhs.handle_ && builder_ == rhs.builder_; + } + + friend std::ostream& operator<<(std::ostream& out, XlaOp op) { + out << op.handle(); + return out; + } + + private: + explicit XlaOp(XlaBuilder* builder) : handle_(-1), builder_(builder) {} + XlaOp(int64_t handle, XlaBuilder* builder) + : handle_(handle), builder_(builder) {} + + int64_t handle() const { return handle_; } + + friend class XlaBuilder; + friend class ValueInference; + friend struct internal::XlaBuilderFriend; + + // < 0 means "invalid handle". + int64_t handle_; + + // Not owned. Non-null for any handle returned by XlaBuilder, even if the + // handle is invalid. + XlaBuilder* builder_; +}; + +// Arithmetic operator overloads for the XlaOp type. +XlaOp operator-(XlaOp x); +XlaOp operator+(XlaOp x, XlaOp y); +XlaOp operator-(XlaOp x, XlaOp y); +XlaOp operator*(XlaOp x, XlaOp y); +XlaOp operator/(XlaOp x, XlaOp y); +XlaOp operator%(XlaOp x, XlaOp y); + +// Bitwise operator overloads for the XlaOp type. +XlaOp operator~(XlaOp x); +XlaOp operator&(XlaOp x, XlaOp y); +XlaOp operator|(XlaOp x, XlaOp y); +XlaOp operator^(XlaOp x, XlaOp y); +XlaOp operator<<(XlaOp x, XlaOp y); +// Performs a right arithmetic shift if 'x' is a signed type, otherwise performs +// a right logical shift. +XlaOp operator>>(XlaOp x, XlaOp y); + +// We don't overload the relational operators (==, !=, <, <=, >, >=) because the +// semantics might be surprising since their result types are usually 'bool'. +// Further programmers may expect == to be a structural equality. +// We also choose not to overload any of the mutating operators (e.g., +=, -=) +// because the semantics might be misleading — XLA computations are immutable. + +// A convenient interface for building up computations. +// +// Thread-compatible. +class XlaBuilder { + public: + // computation_name: name to use for the built computation. + explicit XlaBuilder(const std::string& computation_name); + + XlaBuilder(const XlaBuilder&) = delete; + XlaBuilder& operator=(const XlaBuilder&) = delete; + + virtual ~XlaBuilder(); + + // Returns the computation name. + const std::string& name() const { return name_; } + + // Sets OpMetadata that will be added to all instructions until cleared. + // + // OpMetadata is often applied to a series of XLA HLO instructions. As a + // result, OpMetadata is set on the computation builder. All subsequent + // instructions generated via this computation builder will have the same + // OpMetadata attached until a call to ClearOpMetadata. + void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); } + + // Swaps the passed op metadata with the ones currently set. + // + // Returns the old op metadata. + OpMetadata SwapOpMetadata(OpMetadata metadata) { + OpMetadata old_metadata = std::move(metadata_); + metadata_ = std::move(metadata); + return old_metadata; + } + + // Similar to SetOpMetadata, but only set the metadata for the next op. + void SetOneShotOpMetadata(OpMetadata metadata) { + one_shot_metadata_ = std::move(metadata); + } + + // Clears the HloMetadata state. + void ClearOpMetadata() { metadata_.Clear(); } + + // Sets an OpSharding that will be attached to all instructions until cleared. + void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } + + // Sets the FrontendAttributes that will be added to all instructions until + // cleared. + // + // FrontendAttributes are often applied to a series of XLA HLO instructions. + // As a result they are set on the computation builder and all the + // instructions generated via the computation builder will have the same + // frontend attributes attached to them. + virtual void SetFrontendAttributes( + const FrontendAttributes& frontend_attributes) { + frontend_attributes_ = frontend_attributes; + } + + // Swap the passed FrontendAttributes with the ones currently set. + // + // Return the old attributes. + FrontendAttributes SwapFrontendAttributes( + const FrontendAttributes& frontend_attributes) { + FrontendAttributes old_attributes = std::move(frontend_attributes_); + frontend_attributes_ = frontend_attributes; + return old_attributes; + } + + // Returns the FrontendAttributes that will be attached to all instructions. + const FrontendAttributes& frontend_attributes() const { + return frontend_attributes_; + } + + // Clears all the frontend attributes. + void ClearFrontendAttributes() { frontend_attributes_.Clear(); } + + // Clears the sharding. Ops will be sharded according to the default placement + // policy. + void ClearSharding() { sharding_ = std::nullopt; } + + // Returns the OpSharding that will be attached to all instructions. + const std::optional& sharding() const { return sharding_; } + + // Sets the builder to a mode where it will die immediately when an error is + // encountered, rather than producing it in a deferred fashion when Build() is + // called (which is the default). + void set_die_immediately_on_error(bool enabled) { + die_immediately_on_error_ = enabled; + } + + // Default dimension numbers used for a 2D convolution. + static constexpr int64_t kConvBatchDimension = 0; + static constexpr int64_t kConvFeatureDimension = 1; + static constexpr int64_t kConvFirstSpatialDimension = 2; + static constexpr int64_t kConvSecondSpatialDimension = 3; + static constexpr int64_t kConvKernelOutputDimension = 0; + static constexpr int64_t kConvKernelInputDimension = 1; + static constexpr int64_t kConvKernelFirstSpatialDimension = 2; + static constexpr int64_t kConvKernelSecondSpatialDimension = 3; + + // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for + // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for + // the kernel operand + // {output_feature, input_feature, height, width} = {0, 1, 2, 3}. + static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers( + int num_spatial_dims = 2); + + // Returns an error if the convolution dimension numbers have conflicts. + static absl::Status Validate(const ConvolutionDimensionNumbers& dnum); + + // Returns a new XlaBuilder whose resultant Computation is used only by this + // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error + // behavior as the parent. + std::unique_ptr CreateSubBuilder( + const std::string& computation_name); + + // Builds the computation with the requested operations, or returns a non-ok + // status. Note that all ops that have been enqueued will be moved to the + // computation being returned. The root of the computation will be the last + // added operation. + // + // `remove_dynamic_dimensions` tells the builder whether to remove the + // dynamic dimensions information in all ops. + // + // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keeps the + // dynamic dimensions information when XLA backend can handle dynamic + // dimensions. + absl::StatusOr Build(bool remove_dynamic_dimensions = false); + + // Overload of Build which specifies a particular root instruction for the + // computation. + absl::StatusOr Build(XlaOp root, + bool remove_dynamic_dimensions = false); + + // Builds the computation with the requested operations, or notes an error in + // the parent XlaBuilder and returns an empty computation if building failed. + // This function is intended to be used where the returned XlaComputation is + // only used by the parent XlaBuilder and hence further operation on the + // returned XlaComputation will simply be error'ed out if an error occurred + // while building this computation. If the built computation is to be used by + // a XlaBuilder other than the parent XlaBuilder then Build() should be used + // instead. + XlaComputation BuildAndNoteError(); + + // Returns a subgraph that roots on the given root. If the root is not a + // compile-time constant (see `IsConstant`), returns an error. + // + // This will copy the needed ops/computations to the subgraph. + absl::StatusOr BuildConstantSubGraph( + XlaOp root_op, bool dynamic_dimension_is_minus_one = false); + + // Returns the first error that was encountered while building the + // computation. When an error is encountered, by default we return a vacuous + // XlaOp and inform the user of the error that occurred while + // building the computation when they make a final call to Build(). + // + // See also set_die_immediately_on_error(). + absl::Status first_error() const { return first_error_; } + + // Returns the current status of the builder, complete with the stack trace + // information. + absl::Status GetCurrentStatus() const; + + // Returns the shape of the given op. + absl::StatusOr GetShape(XlaOp op) const; + + // Returns the shape of the given op. + virtual absl::StatusOr GetShapePtr(XlaOp op) const; + + // Returns the OpSharding of the given op. If "op" has no sharding, return + // std::nullopt. + absl::StatusOr> GetOpSharding(XlaOp op) const; + + // Returns the (inferred) result for the current computation's shape. This + // assumes the root instruction is the last added instruction. + absl::StatusOr GetProgramShape() const; + + // Returns the (inferred) result for the current computation's shape using the + // given operation as the root. + absl::StatusOr GetProgramShape(XlaOp root) const; + + // Reports an error to the builder, by + // * storing it internally and capturing a backtrace if it's the first error + // (this deferred value will be produced on the call to + // Build()/GetShape()/...) + // * dying if die_immediately_on_error_ is true. + // Returns an XlaOp with an invalid handle but a valid builder. This value can + // be returned in place of a value in APIs that return an XlaOp. + XlaOp ReportError(const absl::Status& error); + + // A helper function that converts a absl::StatusOr into an XlaOp. + // If the absl::Status was an error, reports the error to builder and returns + // an invalid XlaOp handle. + XlaOp ReportErrorOrReturn(const absl::StatusOr& op); + + // A helper function that runs a function that returns a absl::StatusOr + // and returns an XlaOp. + XlaOp ReportErrorOrReturn( + absl::FunctionRef()> op_creator); + + // Returns true if 'operand' is a compile-time constant. A compile-time + // constant does not depend on any parameters, or on stateful operators such + // as `RngNormal` or `Infeed`. + // + // This tests whether a computation is a compile-time constant without + // evaluating the computation. + absl::StatusOr IsConstant(XlaOp operand) const; + + // Adds a new input/output alias. Since the input/output shape information are + // not available until the computation is built, any eventual error in the + // arguments of this API will be detected only at computation Build() time. + // + // Note: Except when 'must-alias' is true, alias is assumed to be 'may-alias' + // and only donated buffer at runtime will be aliased with output. If a buffer + // is not donated at runtime, a copy will be inserted by XLA to prevent buffer + // clobbering. + void SetUpAlias(const ShapeIndex& output_index, int64_t param_number, + const ShapeIndex& param_index, + HloInputOutputAliasConfig::AliasKind kind = + HloInputOutputAliasConfig::AliasKind::kMayAlias) { + input_output_aliases_.push_back( + {output_index, param_number, param_index, kind}); + } + + // Describes an input/output alias as inserted by the SetUpAlias() API. + struct InputOutputAlias { + // Specifies the index of the aliased buffer in the result tuple. + ShapeIndex output_index; + // Specifies the parameter containing the buffer to be aliased. + int64_t param_number; + // Specifies the index of the aliased buffer in the parameter. + ShapeIndex param_index; + // Specifies if the alias is a must alias or may alias. + HloInputOutputAliasConfig::AliasKind kind; + }; + + // Adds a new buffer donor. The donated buffer may be paired with any valid + // output. On the contrary, the buffer aliasing bonds the input output pair. + // The input can only donate the buffer to the paired output. + void AddBufferDonor(int64_t param_number, const ShapeIndex& param_index) { + buffer_donors_.insert({param_number, param_index}); + } + + // Looks up the HloInstruction and sets the frontend attribute "attribute" to + // "value". If the attribute already existed, then its value is updated. + // + // The attribute is only added to the HloInstruction, not to the builder. + absl::Status SetInstructionFrontendAttribute(XlaOp op, std::string attribute, + std::string value); + + // Looks up the HloInstruction and sets the sharding. If the sharding already + // existed, then its value is updated. + // + // The sharding is only added to the HloInstruction, not to the builder. + absl::Status SetInstructionSharding( + XlaOp op, const std::optional& sharding); + + // Returns shapes for the operands. + absl::StatusOr> GetOperandShapes( + absl::Span operands) const; + + // Converts the op to string for the ease of debugging. + std::string OpToString(XlaOp op) const; + + private: + void ToStringHelper(std::string* out, int ident, int64_t op_handle) const; + + // Build helper which takes the id of the root operation.. + absl::StatusOr Build(int64_t root_id, + bool remove_dynamic_dimensions); + + // Description for the methods below can be found in the corresponding public + // functions section in this file. + + XlaOp Parameter(int64_t parameter_number, const Shape& shape, + const std::string& name, + const std::vector& replicated_at_leaf_buffers); + XlaOp Parameter(int64_t parameter_number, const Shape& shape, + const std::string& name) { + std::vector empty_bools; + return Parameter(parameter_number, shape, name, empty_bools); + } + + virtual XlaOp ConstantLiteral(const LiteralSlice& literal); + + XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes); + + XlaOp BroadcastInDim(XlaOp operand, absl::Span out_dim_size, + absl::Span broadcast_dimensions); + + // This is an experimental API for creating the mhlo.dynamic_broadcast_in_dim + // op from the XlaBuilder. This is only intended for export to MHLO or + // StableHLO, and cannot be compiled. Only static output_dimensions are + // allowed, and broadcast_dimensions is verified. + XlaOp MhloDynamicBroadcastInDim( + XlaOp operand, XlaOp output_dimensions, + absl::Span broadcast_dimensions, + const Shape& output_shape); + + XlaOp Pad(XlaOp operand, XlaOp padding_value, + const PaddingConfig& padding_config); + XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno, + int64_t pad_lo, int64_t pad_hi); + + virtual absl::StatusOr PadInternal( + const Shape& shape, XlaOp operand, XlaOp padding_value, + const PaddingConfig& padding_config); + + XlaOp Reshape(XlaOp operand, absl::Span dimensions, + absl::Span new_sizes, + int64_t inferred_dimension = -1); + + XlaOp Reshape(XlaOp operand, absl::Span new_sizes, + int64_t inferred_dimension = -1); + + XlaOp Reshape(const Shape& shape, XlaOp operand, + int64_t inferred_dimension = -1); + + XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, + absl::Span new_size_bounds, + const std::vector& dims_are_dynamic); + + XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, + const Shape& shape); + + XlaOp Collapse(XlaOp operand, absl::Span dimensions); + + XlaOp Slice(XlaOp operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); + virtual absl::StatusOr SliceInternal( + const Shape& shape, XlaOp operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); + virtual XlaOp SliceInDim(XlaOp operand, int64_t start_index, + int64_t limit_index, int64_t stride, int64_t dimno); + + XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, + absl::Span slice_sizes); + virtual absl::StatusOr DynamicSliceInternal( + const Shape& shape, XlaOp operand, absl::Span start_indices, + absl::Span slice_sizes); + + XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, + absl::Span start_indices); + virtual absl::StatusOr DynamicUpdateSliceInternal( + const Shape& shape, XlaOp operand, XlaOp update, + absl::Span start_indices); + + XlaOp ConcatInDim(absl::Span operands, int64_t dimension); + virtual absl::StatusOr ConcatInDimInternal( + const Shape& shape, absl::Span operands, int64_t dimension); + + XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false); + + XlaOp Tuple(absl::Span elements); + virtual absl::StatusOr TupleInternal(const Shape& shape, + absl::Span elements); + + XlaOp GetTupleElement(XlaOp tuple_data, int64_t index); + virtual absl::StatusOr GetTupleElementInternal(const Shape& shape, + XlaOp tuple_data, + int64_t index); + + XlaOp Dot(XlaOp lhs, XlaOp rhs, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + + XlaOp DotGeneral( + XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + + XlaOp SparseDot( + XlaOp lhs, XlaOp rhs, absl::Span sparse_meta, + absl::Span sparsity, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + + XlaOp Conv( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + Padding padding, int64_t feature_group_count = 1, + int64_t batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + + XlaOp ConvWithGeneralPadding( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + int64_t feature_group_count = 1, int64_t batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + + XlaOp ConvWithGeneralDimensions( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count = 1, int64_t batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + + XlaOp ConvGeneral( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count = 1, int64_t batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + + XlaOp ConvGeneralDilated( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count = 1, int64_t batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt, + std::optional> window_reversal = std::nullopt); + + XlaOp DynamicConvForward( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, PaddingType padding_type, + std::optional preferred_element_type = std::nullopt); + + XlaOp DynamicConvInputGrad( + XlaOp input_sizes, XlaOp lhs, XlaOp rhs, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, PaddingType padding_type, + std::optional preferred_element_type = std::nullopt); + + XlaOp DynamicConvKernelGrad( + XlaOp activations, XlaOp gradients, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, PaddingType padding_type, + std::optional preferred_element_type = std::nullopt); + + absl::StatusOr DynamicConvInstruction( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, PaddingType padding_type, + std::optional preferred_element_type = std::nullopt); + + virtual absl::StatusOr ConvGeneralDilatedInternal( + const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config); + + XlaOp Fft(XlaOp operand, FftType fft_type, + absl::Span fft_length); + virtual absl::StatusOr FftInternal( + const Shape& shape, XlaOp operand, FftType fft_type, + absl::Span fft_length); + + virtual absl::StatusOr TriangularSolveInternal( + const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options); + + virtual absl::StatusOr CholeskyInternal(const Shape& shape, XlaOp a, + bool lower); + + XlaOp Infeed(const Shape& shape, const std::string& config = ""); + XlaOp InfeedWithToken(XlaOp token, const Shape& shape, + const std::string& config); + virtual absl::StatusOr InfeedWithTokenInternal( + const Shape& infeed_instruction_shape, XlaOp token, + const std::string& config); + + void Outfeed(XlaOp operand, const Shape& shape_with_layout, + const std::string& outfeed_config); + XlaOp OutfeedWithToken(XlaOp operand, XlaOp token, + const Shape& shape_with_layout, + const std::string& outfeed_config); + virtual absl::StatusOr OutfeedWithTokenInternal( + XlaOp operand, XlaOp token, const Shape& shape_with_layout, + const std::string& outfeed_config); + XlaOp Call(const XlaComputation& computation, + absl::Span operands); + + XlaOp CompositeCall( + const XlaComputation& computation, absl::Span operands, + const std::string& name, + std::optional attributes = std::nullopt, + std::optional version = std::nullopt); + + XlaOp CustomCall( + const std::string& call_target_name, absl::Span operands, + const Shape& shape_with_layout, const std::string& opaque, + std::optional> operand_shapes_with_layout, + bool has_side_effect, + absl::Span>> + output_operand_aliasing, + const Literal* literal, std::optional window, + std::optional dnums, + CustomCallSchedule schedule, CustomCallApiVersion api_version); + + // Internal version of CustomCall without computation that doesn't do op + // specific error handling and expects arguments to be legal. CustomCall + // method above calls this method after error handling. + virtual absl::StatusOr CustomCallInternal( + const std::string& call_target_name, absl::Span operands, + const XlaComputation* computation, const Shape& shape_with_layout, + const std::string& opaque, + std::optional> operand_shapes_with_layout, + bool has_side_effect, + absl::Span>> + output_operand_aliasing, + const Literal* literal, std::optional window, + std::optional dnums, + CustomCallSchedule schedule, CustomCallApiVersion api_version); + + // TODO(b/239474321) Remove this overload as it has simply led to code + // duplication. + XlaOp CustomCall( + const std::string& call_target_name, absl::Span operands, + const XlaComputation& computation, const Shape& shape_with_layout, + const std::string& opaque, + std::optional> operand_shapes_with_layout, + bool has_side_effect, + absl::Span>> + output_operand_aliasing, + const Literal* literal, CustomCallSchedule schedule, + CustomCallApiVersion api_version); + + XlaOp OptimizationBarrier(XlaOp operand); + + XlaOp Reduce(XlaOp operand, XlaOp init_value, + const XlaComputation& computation, + absl::Span dimensions_to_reduce); + + XlaOp Reduce(absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span dimensions_to_reduce); + + virtual absl::StatusOr ReduceInternal( + const Shape& shape, absl::Span all_operands, + const XlaComputation& computation, + absl::Span dimensions_to_reduce); + + XlaOp ReduceAll(XlaOp operand, XlaOp init_value, + const XlaComputation& computation); + + XlaOp ReduceWindow(XlaOp operand, XlaOp init_value, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding); + + XlaOp ReduceWindow(absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding); + + XlaOp ReduceWindowWithGeneralPadding( + absl::Span operands, absl::Span init_values, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, + absl::Span> padding); + absl::StatusOr ReduceWindowInternal( + absl::Span operands, absl::Span init_values, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, + absl::Span> padding); + virtual absl::StatusOr ReduceWindowInternal( + const Shape& shape, XlaOp operand, XlaOp init_value, + const XlaComputation& computation, Window window); + XlaOp CrossReplicaSum(XlaOp operand, + absl::Span replica_groups = {}); + + XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension, + int64_t shard_count, + absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt, + const std::optional& layout = std::nullopt, + std::optional use_global_device_ids = std::nullopt); + + XlaOp AllReduce(XlaOp operand, const XlaComputation& computation, + absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt, + const std::optional& shape_with_layout = std::nullopt, + std::optional use_global_device_ids = std::nullopt); + + XlaOp ReduceScatter( + XlaOp operand, const XlaComputation& computation, + int64_t scatter_dimension, int64_t shard_count, + absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt, + const std::optional& layout = std::nullopt, + std::optional use_global_device_ids = std::nullopt); + + XlaOp AllToAll(XlaOp operand, int64_t split_dimension, + int64_t concat_dimension, int64_t split_count, + absl::Span replica_groups, + const std::optional& layout = std::nullopt, + const std::optional& channel_id = std::nullopt); + + XlaOp AllToAllTuple( + absl::Span operands, + absl::Span replica_groups, + const std::optional& layout, + const std::optional& channel_id = std::nullopt); + + XlaOp AllToAllTuple( + XlaOp operand, int64_t split_dimension, int64_t concat_dimension, + int64_t split_count, absl::Span replica_groups, + const std::optional& layout, + const std::optional& channel_id = std::nullopt); + + XlaOp CollectiveBroadcast( + XlaOp operand, absl::Span replica_groups, + const std::optional& channel_id = std::nullopt); + + XlaOp CollectivePermute( + XlaOp operand, + const std::vector>& source_target_pairs, + const std::optional& channel_id = std::nullopt); + + XlaOp ReplicaId(); + + XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding, XlaOp source, XlaOp init_value, + const XlaComputation& scatter); + + XlaOp SelectAndScatterWithGeneralPadding( + XlaOp operand, const XlaComputation& select, + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, XlaOp source, + XlaOp init_value, const XlaComputation& scatter); + + absl::StatusOr SelectAndScatterInternal( + XlaOp operand, const XlaComputation& select, + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, XlaOp source, + XlaOp init_value, const XlaComputation& scatter); + + virtual XlaOp Iota(const Shape& shape, int64_t iota_dimension); + + XlaOp Iota(PrimitiveType type, int64_t size); + + XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type); + + XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type); + virtual absl::StatusOr BitcastConvertTypeInternal(const Shape& shape, + XlaOp operand); + + XlaOp StochasticConvertType(XlaOp operand, XlaOp random, + PrimitiveType new_element_type); + + XlaOp Transpose(XlaOp operand, absl::Span permutation); + virtual absl::StatusOr TransposeInternal( + const Shape& shape, XlaOp operand, absl::Span permutation); + + XlaOp Rev(XlaOp operand, absl::Span dimensions); + virtual absl::StatusOr RevInternal( + const Shape& shape, XlaOp operand, absl::Span dimensions); + + XlaOp Sort(absl::Span operands, const XlaComputation& comparator, + int64_t dimension = -1, bool is_stable = false); + virtual absl::StatusOr SortInternal(const Shape& shape, + absl::Span operands, + const XlaComputation& comparator, + int64_t dimension, bool is_stable); + + XlaOp TopK(XlaOp operand, int64_t k, bool largest); + virtual absl::StatusOr TopKInternal(const Shape& shape, XlaOp operand, + int64_t k, bool largest); + + XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max); + + XlaOp Map(absl::Span operands, const XlaComputation& computation, + absl::Span dimensions, + absl::Span static_operands = {}); + + XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape); + + XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape); + + XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state, + const Shape& shape); + // Internal variant for the op with the full result shape containing both data + // and state shape as a tuple. + virtual absl::StatusOr RngBitGeneratorInternal( + const Shape& full_result_shape, RandomAlgorithm algorithm, + XlaOp initial_state); + + XlaOp While(const XlaComputation& condition, const XlaComputation& body, + XlaOp init); + virtual absl::StatusOr WhileInternal(const Shape& shape, + const XlaComputation& condition, + const XlaComputation& body, + XlaOp init); + + XlaOp Conditional(XlaOp predicate, XlaOp true_operand, + const XlaComputation& true_computation, XlaOp false_operand, + const XlaComputation& false_computation); + + XlaOp Conditional(XlaOp branch_index, + absl::Span branch_computations, + absl::Span branch_operands); + + XlaOp ReducePrecision(XlaOp operand, int exponent_bits, int mantissa_bits); + virtual absl::StatusOr ReducePrecisionInternal(const Shape& shape, + XlaOp operand, + int exponent_bits, + int mantissa_bits); + + XlaOp Gather(XlaOp input, XlaOp start_indices, + const GatherDimensionNumbers& dimension_numbers, + absl::Span slice_sizes, + bool indices_are_sorted = false); + + virtual absl::StatusOr GatherInternal( + const Shape& shape, XlaOp input, XlaOp start_indices, + const GatherDimensionNumbers& dimension_numbers, + absl::Span slice_sizes, bool indices_are_sorted); + + XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers, + bool indices_are_sorted = false, bool unique_indices = false); + XlaOp Scatter(absl::Span inputs, XlaOp scatter_indices, + absl::Span updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers, + bool indices_are_sorted = false, bool unique_indices = false); + + virtual absl::StatusOr ScatterInternal( + const Shape& shape, absl::Span inputs, XlaOp scatter_indices, + absl::Span updates, const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted, + bool unique_indices); + + void Send(XlaOp operand, const ChannelHandle& handle); + XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle); + + XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout, + const ChannelHandle& handle); + + XlaOp RecvFromHost(XlaOp token, const Shape& shape, + const ChannelHandle& handle); + + virtual XlaOp CreateToken(); + + XlaOp AfterAll(absl::Span tokens); + + XlaOp Recv(const Shape& shape, const ChannelHandle& handle); + XlaOp RecvWithToken(XlaOp token, const Shape& shape, + const ChannelHandle& handle); + + XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, + float epsilon, int64_t feature_index); + + XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp mean, + XlaOp variance, float epsilon, + int64_t feature_index); + + XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean, + XlaOp batch_var, XlaOp grad_output, float epsilon, + int64_t feature_index); + + XlaOp GetDimensionSize(XlaOp operand, int64_t dimension); + + XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension); + + virtual absl::StatusOr SetDimensionSizeInternal(const Shape& shape, + XlaOp operand, + XlaOp val, + int64_t dimension); + + XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension); + + virtual absl::StatusOr AddInstruction( + HloInstructionProto&& instr, HloOpcode opcode, + absl::Span operands); + absl::StatusOr AddInstruction(HloInstructionProto&& instr, + HloOpcode opcode) { + return AddInstruction(std::move(instr), opcode, /*operands=*/{}); + } + + void AddCalledComputation(const XlaComputation& computation, + HloInstructionProto* instr); + + absl::StatusOr LookUpInstruction(XlaOp op) const; + absl::StatusOr LookUpInstructionByHandle( + int64_t handle) const; + absl::StatusOr LookUpMutableInstruction(XlaOp op); + absl::StatusOr LookUpMutableInstructionByHandle( + int64_t handle); + + // Internal helper method that does the building for an arbitrary unary op. + virtual XlaOp UnaryOp(HloOpcode unop, XlaOp operand); + + // Internal helper method that does the building for an arbitrary binary op. + // broadcast_dimensions specifies which dimensions to use for broadcasting + // when the operation is between tensors of different ranks. The direction is + // only used if opcode is kCompare. + XlaOp BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions, + std::optional direction = std::nullopt, + std::optional type = std::nullopt); + + absl::StatusOr Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, + ComparisonDirection direction); + + // Internal helper method for binary op compare without broadcast dimensions. + virtual absl::StatusOr Compare(const Shape& shape, XlaOp lhs, + XlaOp rhs, + ComparisonDirection direction, + Comparison::Type type); + + // Internal helper method that does the building for an arbitrary binary op + // with same ranked operands that doesn't broadcast. + virtual XlaOp BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape, + XlaOp lhs, XlaOp rhs); + + // Internal helper method that does the building for an arbitrary ternary op. + XlaOp TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs); + + XlaOp RngOp(RandomDistribution distribution, + absl::Span parameters, const Shape& shape); + + virtual absl::StatusOr RngOpInternal( + RandomDistribution distribution, absl::Span parameters, + const Shape& shape); + + virtual absl::StatusOr InDimBroadcast( + const Shape& shape, XlaOp operand, + absl::Span broadcast_dimensions); + + // Internal helper method that creates a sequence of instructions that + // performs an explicit broadcast of the operand to the target shape. + // All dimensions of the operand must either be equal to the corresponding + // output shape dimension, or be exactly 1. (Such dimensions are the + // degenerate dimensions.) + absl::StatusOr AddBroadcastSequence(const Shape& output_shape, + XlaOp operand); + + // Internal helper method that broadcasts a scalar to the shape of the output. + absl::StatusOr BroadcastScalarToOutputShape(XlaOp scalar, + XlaOp output); + + // Internal helper method for creating a Reshape op with the already inferred + // shape. + virtual absl::StatusOr ReshapeInternal(const Shape& shape, + XlaOp operand, + int64_t inferred_dimension); + + // Returns the (inferred) result for the program shape using the given root. + absl::StatusOr GetProgramShape(int64_t root_id) const; + + // A visitor which checks whether an operation is a compile-time constant, + // meaning that it doesn't depend on any parameters, or on any stateful + // operation such as `RngNormal` or `Infeed`. The visitor walks the + // computation starting at a given operation and sets is_constant to false iff + // a parameter or stateful operation is encountered. + void IsConstantVisitor(int64_t op_handle, int depth, + absl::flat_hash_set* visited, + bool* is_constant) const; + + // Checks bounds for convolution parameters. + absl::Status VerifyConvolution( + const Shape& lhs_shape, const Shape& rhs_shape, + const ConvolutionDimensionNumbers& dimension_numbers) const; + + int64_t GetNextId() { return ++next_id_; } + + // Populates the module with the input/output alias information stored within + // the input_output_aliases vector. + static absl::Status PopulateInputOutputAliasAndBufferDonor( + HloModuleProto* module, const ProgramShape& program_shape, + const std::vector& input_output_aliases, + const absl::flat_hash_set& + buffer_donors); + + std::string name_; // Name to use for the built computation. + + // The next sequential ID for every instruction/computation contained within + // this computation. + int64_t next_id_ = 0; + + // The first error encountered while building the computation. + // This is OK until the first error is encountered. + absl::Status first_error_; + + // The saved stack trace from the point at which the first error occurred. + tsl::SavedStackTrace first_error_backtrace_; + + // The instructions of this computation. + // Use a deque so pointers into this are stable, for example the return + // value of LookUpInstructionByHandle(). + std::deque instructions_; + // A cache for the HloInstructionProto shapes, to avoid recreating Shape + // objects from protos and to support the GetShapePtr() API. + std::vector> instruction_shapes_; + + // Dynamic parameter configuration of this computation. + DynamicParameterBinding dynamic_parameter_binding_; + + // Holds the input/output alias information populated by the SetUpAlias() API. + std::vector input_output_aliases_; + + // Holds the buffer donor information populated by the AddBufferDonor() API. + absl::flat_hash_set buffer_donors_; + + // A map from XlaOp::Handle to the index in the instructions_ vector where the + // instruction is held. + absl::flat_hash_map handle_to_index_; + + // Track imported instructions by their computation id and the position in + // their computation's instruction list. + struct ImportedInstruction { + int64_t computation_id; + int64_t instruction_index; + }; + + absl::flat_hash_map handle_to_imported_index_; + + // The embedded computations used by this computation. Each computation was + // the entry computation of some XlaComputation, the key is the unique id of + // that XlaComputation. + std::map embedded_; + + // The unique parameter numbers. + absl::flat_hash_set parameter_numbers_; + + // The metadata to attach to each op. This is structured as a "modal"-like + // operation, in order to simplify client code (and not sprinkle this metadata + // throughout the TensorFlow op kernel implementations). + OpMetadata metadata_; + + // A temporary metadata that will only be applied to the next op created. + std::optional one_shot_metadata_; + + // Sharding for this operator. This is structured as a "model"-like operation, + // in order to simplify client code, similar to metadata_. + std::optional sharding_; + + // Mode bit that indicates whether to die when a first error is encountered. + bool die_immediately_on_error_ = false; + + XlaBuilder* parent_builder_{nullptr}; + + FrontendAttributes frontend_attributes_; + + friend XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number, + const Shape& shape, const std::string& name, + const std::vector& replicated_at_leaf_buffers); + friend XlaOp ConstantLiteral(XlaBuilder* builder, + const LiteralSlice& literal); + + friend XlaOp Broadcast(XlaOp operand, + absl::Span broadcast_sizes); + + friend XlaOp BroadcastInDim(XlaOp operand, + absl::Span out_dim_size, + absl::Span broadcast_dimensions); + + friend XlaOp MhloDynamicBroadcastInDim( + XlaOp operand, XlaOp output_dimensions, + absl::Span broadcast_dimensions, + const Shape& output_shape); + + friend XlaOp Copy(XlaOp operand); + + friend XlaOp Pad(XlaOp operand, XlaOp padding_value, + const PaddingConfig& padding_config); + + friend XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno, + int64_t pad_lo, int64_t pad_hi); + + friend XlaOp Reshape(XlaOp operand, absl::Span dimensions, + absl::Span new_sizes); + + friend XlaOp Reshape(XlaOp operand, absl::Span new_sizes); + + friend XlaOp Reshape(const Shape& shape, XlaOp operand); + + friend XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, + absl::Span new_size_bounds, + const std::vector& dims_are_dynamic); + + friend XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, + const Shape& shape); + + friend XlaOp ReshapeWithInferredDimension(XlaOp operand, + absl::Span new_sizes, + int64_t inferred_dimension); + + friend XlaOp Collapse(XlaOp operand, absl::Span dimensions); + + friend XlaOp Slice(XlaOp operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); + + friend XlaOp SliceInDim(XlaOp operand, int64_t start_index, + int64_t limit_index, int64_t stride, int64_t dimno); + + friend XlaOp DynamicSlice(XlaOp operand, + absl::Span start_indices, + absl::Span slice_sizes); + + friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, + absl::Span start_indices); + + friend XlaOp ConcatInDim(XlaBuilder* builder, + absl::Span operands, int64_t dimension); + + friend XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false); + friend XlaOp Tuple(XlaBuilder* builder, absl::Span elements); + friend XlaOp GetTupleElement(XlaOp tuple_data, int64_t index); + friend XlaOp Compare(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions, + ComparisonDirection direction); + friend XlaOp Compare(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions, + ComparisonDirection direction, + Comparison::Type compare_type); + friend XlaOp Dot(XlaOp lhs, XlaOp rhs, + const PrecisionConfig* precision_config, + std::optional preferred_element_type); + friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs, + const DotDimensionNumbers& dimension_number, + const PrecisionConfig* precision_config, + std::optional preferred_element_type); + virtual absl::StatusOr DotGeneralInternal( + const Shape& shape, XlaOp lhs, XlaOp rhs, + const DotDimensionNumbers& dimension_number, + const PrecisionConfig* precision_config); + friend XlaOp SparseDot(XlaOp lhs, XlaOp rhs, + absl::Span sparse_meta, + absl::Span sparsity, + const DotDimensionNumbers& dimension_number, + const PrecisionConfig* precision_config, + std::optional preferred_element_type); + friend XlaOp Conv(XlaOp lhs, XlaOp rhs, + absl::Span window_strides, Padding padding, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, + std::optional preferred_element_type); + friend XlaOp ConvWithGeneralPadding( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, + std::optional preferred_element_type); + friend XlaOp ConvWithGeneralDimensions( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, + std::optional preferred_element_type); + friend XlaOp ConvGeneral( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, + std::optional preferred_element_type); + friend XlaOp DynamicConvForward( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, PaddingType padding_type, + std::optional preferred_element_type); + friend XlaOp DynamicConvKernelGrad( + XlaOp activations, XlaOp gradients, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, PaddingType padding_type, + std::optional preferred_element_type); + friend XlaOp DynamicConvInputGrad( + XlaOp input_sizes, XlaOp lhs, XlaOp rhs, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, PaddingType padding_type, + std::optional preferred_element_type); + + friend XlaOp ConvKernelGrad( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, + std::optional preferred_element_type); + + friend XlaOp ConvGeneralDilated( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, + std::optional preferred_element_type, + std::optional> window_reversal); + + friend XlaOp Fft(XlaOp operand, FftType fft_type, + absl::Span fft_length); + friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, + bool unit_diagonal, + TriangularSolveOptions::Transpose transpose_a); + friend XlaOp Cholesky(XlaOp a, bool lower); + friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, + const std::string& config); + friend void Outfeed(XlaOp operand, const Shape& shape_with_layout, + const std::string& outfeed_config); + friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, + absl::Span operands); + + friend XlaOp CompositeCall(XlaBuilder* builder, + const XlaComputation& computation, + absl::Span operands, + const std::string& name, + std::optional attributes, + std::optional version); + + friend XlaOp CustomCall( + XlaBuilder* builder, const std::string& call_target_name, + absl::Span operands, const Shape& shape, + const std::string& opaque, bool has_side_effect, + absl::Span>> + output_operand_aliasing, + const Literal* literal, CustomCallSchedule schedule, + CustomCallApiVersion api_version); + friend XlaOp CustomCallWithComputation( + XlaBuilder* builder, const std::string& call_target_name, + absl::Span operands, const XlaComputation& computation, + const Shape& shape, const std::string& opaque, bool has_side_effect, + absl::Span>> + output_operand_aliasing, + const Literal* literal, CustomCallSchedule schedule, + CustomCallApiVersion api_version); + friend XlaOp CustomCallWithLayout( + XlaBuilder* builder, const std::string& call_target_name, + absl::Span operands, const Shape& shape_with_layout, + absl::Span operand_shapes_with_layout, + const std::string& opaque, bool has_side_effect, + absl::Span>> + output_operand_aliasing, + const Literal* literal, CustomCallSchedule schedule, + CustomCallApiVersion api_version); + friend XlaOp CustomCallWithConvDnums( + XlaBuilder* builder, const std::string& call_target_name, + absl::Span operands, const Shape& shape, + absl::Span operand_shapes_with_layout, + const std::string& opaque, bool has_side_effect, + absl::Span>> + output_operand_aliasing, + const Literal* literal, Window window, ConvolutionDimensionNumbers dnums, + CustomCallSchedule schedule, CustomCallApiVersion api_version); + friend XlaOp OptimizationBarrier(XlaOp operand); + friend XlaOp Complex(XlaOp real, XlaOp imag, + absl::Span broadcast_dimensions); + friend XlaOp Conj(XlaOp operand); + friend XlaOp Add(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions); + friend XlaOp Sub(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions); + friend XlaOp Mul(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions); + friend XlaOp Div(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions); + friend XlaOp Rem(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions); + friend XlaOp Max(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions); + friend XlaOp Min(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions); + friend XlaOp And(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions); + friend XlaOp Or(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions); + friend XlaOp Xor(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions); + friend XlaOp Not(XlaOp operand); + friend XlaOp PopulationCount(XlaOp operand); + friend XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions); + friend XlaOp ShiftRightArithmetic( + XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); + friend XlaOp ShiftRightLogical( + XlaOp lhs, XlaOp rhs, absl::Span broadcast_dimensions); + friend XlaOp Reduce(XlaOp operand, XlaOp init_value, + const XlaComputation& computation, + absl::Span dimensions_to_reduce); + friend XlaOp Reduce(XlaBuilder* builder, absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span dimensions_to_reduce); + friend XlaOp ReduceAll(XlaOp operand, XlaOp init_value, + const XlaComputation& computation); + friend XlaOp ReduceWindow(XlaOp operand, XlaOp init_value, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding); + friend XlaOp ReduceWindow(absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding); + friend XlaOp ReduceWindowWithGeneralPadding( + XlaOp operand, XlaOp init_value, const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, + absl::Span> padding); + friend XlaOp ReduceWindowWithGeneralPadding( + absl::Span operands, absl::Span init_values, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, + absl::Span> padding); + + friend XlaOp CrossReplicaSum(XlaOp operand, + absl::Span replica_groups); + friend XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension, + int64_t shard_count, + absl::Span replica_groups, + const std::optional& channel_id, + const std::optional& layout, + std::optional use_global_device_ids); + friend XlaOp AllGatherTuple(absl::Span operands, + int64_t all_gather_dimension, int64_t shard_count, + absl::Span replica_groups, + const std::optional& channel_id, + const std::optional& layout, + std::optional use_global_device_ids); + friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation, + absl::Span replica_groups, + const std::optional& channel_id, + const std::optional& shape_with_layout, + std::optional use_global_device_ids); + friend XlaOp AllReduceTuple(absl::Span operand, + const XlaComputation& computation, + absl::Span replica_groups, + const std::optional& channel_id, + const std::optional& shape_with_layout, + std::optional use_global_device_ids); + friend XlaOp ReduceScatter(XlaOp operand, const XlaComputation& computation, + int64_t scatter_dimension, int64_t shard_count, + absl::Span replica_groups, + const std::optional& channel_id, + const std::optional& layout, + std::optional use_global_device_ids); + + friend XlaOp AllToAll(XlaOp operand, int64_t split_dimension, + int64_t concat_dimension, int64_t split_count, + absl::Span replica_groups, + const std::optional& layout, + const std::optional& channel_id); + friend XlaOp AllToAllTuple(absl::Span operands, + absl::Span replica_groups, + const std::optional& layout, + const std::optional& channel_id); + friend XlaOp AllToAllTuple(XlaOp operand, int64_t split_dimension, + int64_t concat_dimension, int64_t split_count, + absl::Span replica_groups, + const std::optional& layout, + const std::optional& channel_id); + friend XlaOp CollectiveBroadcast( + XlaOp operand, absl::Span replica_groups, + const std::optional& channel_id); + friend XlaOp CollectivePermute( + XlaOp operand, + const std::vector>& source_target_pairs, + const std::optional& channel_id); + friend XlaOp ReplicaId(XlaBuilder* builder); + friend XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding, XlaOp source, XlaOp init_value, + const XlaComputation& scatter); + friend XlaOp SelectAndScatterWithGeneralPadding( + XlaOp operand, const XlaComputation& select, + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, XlaOp source, + XlaOp init_value, const XlaComputation& scatter); + friend XlaOp Abs(XlaOp operand); + friend XlaOp Atan2(XlaOp y, XlaOp x, + absl::Span broadcast_dimensions); + friend XlaOp Erf(XlaOp operand); + friend XlaOp Exp(XlaOp operand); + friend XlaOp Expm1(XlaOp operand); + friend XlaOp Floor(XlaOp operand); + friend XlaOp Ceil(XlaOp operand); + friend XlaOp Round(XlaOp operand); + friend XlaOp RoundNearestEven(XlaOp operand); + friend XlaOp Log(XlaOp operand); + friend XlaOp Log1p(XlaOp operand); + friend XlaOp Logistic(XlaOp operand); + friend XlaOp Sign(XlaOp operand); + friend XlaOp Clz(XlaOp operand); + friend XlaOp Cos(XlaOp operand); + friend XlaOp Sin(XlaOp operand); + friend XlaOp Tan(XlaOp operand); + friend XlaOp Tanh(XlaOp operand); + friend XlaOp Real(XlaOp operand); + friend XlaOp Imag(XlaOp operand); + friend XlaOp Sqrt(XlaOp operand); + friend XlaOp Rsqrt(XlaOp operand); + friend XlaOp Cbrt(XlaOp operand); + friend XlaOp Pow(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions); + friend XlaOp IsFinite(XlaOp operand); + friend XlaOp Iota(XlaBuilder* builder, const Shape& shape, + int64_t iota_dimension); + friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64_t size); + friend XlaOp ConvertElementType(XlaOp operand, + PrimitiveType new_element_type); + friend XlaOp BitcastConvertType(XlaOp operand, + PrimitiveType new_element_type); + friend XlaOp StochasticConvertType(XlaOp operand, XlaOp random, + PrimitiveType new_element_type); + friend XlaOp Neg(XlaOp operand); + friend XlaOp Transpose(XlaOp operand, absl::Span permutation); + friend XlaOp Rev(XlaOp operand, absl::Span dimensions); + friend XlaOp Sort(absl::Span operands, + const XlaComputation& comparator, int64_t dimension, + bool is_stable); + friend XlaOp TopK(XlaOp operand, int64_t k, bool largest); + friend XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max); + friend XlaOp Map(XlaBuilder* builder, absl::Span operands, + const XlaComputation& computation, + absl::Span dimensions, + absl::Span static_operands); + friend XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape); + friend XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape); + friend XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state, + const Shape& shape); + friend XlaOp While(const XlaComputation& condition, + const XlaComputation& body, XlaOp init); + friend XlaOp Conditional(XlaOp predicate, XlaOp true_operand, + const XlaComputation& true_computation, + XlaOp false_operand, + const XlaComputation& false_computation); + friend XlaOp Conditional( + XlaOp branch_index, + absl::Span branch_computations, + absl::Span branch_operands); + friend XlaOp ConditionalImpl( + XlaOp branch_index, + absl::Span branch_computations, + absl::Span branch_operands); + friend XlaOp ReducePrecision(XlaOp operand, int exponent_bits, + int mantissa_bits); + friend XlaOp Gather(XlaOp input, XlaOp start_indices, + const GatherDimensionNumbers& dimension_numbers, + absl::Span slice_sizes, + bool indices_are_sorted); + friend XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers, + bool indices_are_sorted, bool unique_indices); + friend XlaOp Scatter(absl::Span inputs, XlaOp scatter_indices, + absl::Span updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers, + bool indices_are_sorted, bool unique_indices); + friend void Send(XlaOp operand, const ChannelHandle& handle); + friend XlaOp Recv(XlaBuilder* builder, const Shape& shape, + const ChannelHandle& handle); + friend XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, + float epsilon, int64_t feature_index); + friend XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, + XlaOp mean, XlaOp variance, float epsilon, + int64_t feature_index); + friend XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean, + XlaOp batch_var, XlaOp grad_output, float epsilon, + int64_t feature_index); + friend XlaOp SendWithToken(XlaOp operand, XlaOp token, + const ChannelHandle& handle); + friend XlaOp RecvWithToken(XlaOp token, const Shape& shape, + const ChannelHandle& handle); + friend XlaOp SendToHost(XlaOp operand, XlaOp token, + const Shape& shape_with_layout, + const ChannelHandle& handle); + friend XlaOp RecvFromHost(XlaOp token, const Shape& shape, + const ChannelHandle& handle); + friend XlaOp InfeedWithToken(XlaOp token, const Shape& shape, + const std::string& config); + friend XlaOp OutfeedWithToken(XlaOp operand, XlaOp token, + const Shape& shape_with_layout, + const std::string& outfeed_config); + friend XlaOp CreateToken(XlaBuilder* builder); + friend XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens); + + friend XlaOp GetDimensionSize(XlaOp operand, int64_t dimension); + friend XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension); + friend XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension); + + protected: + // Returns OK status if the given op was built using this builder. Otherwise, + // returns an error. + absl::Status CheckOpBuilder(XlaOp op) const; + + private: + XlaOp AllGatherImpl(XlaOp operand, int64_t all_gather_dimension, + int64_t shard_count, + absl::Span replica_groups, + const std::optional& channel_id, + const std::optional& layout, + std::optional use_global_device_ids, bool async); + + XlaOp AllReduceImpl(XlaOp operand, const XlaComputation& computation, + absl::Span replica_groups, + const std::optional& channel_id, + const std::optional& layout, + std::optional use_global_device_ids, bool async); + + XlaOp CollectiveBroadcastImpl(XlaOp operand, + absl::Span replica_groups, + const std::optional& channel_id); + + XlaOp CollectivePermuteImpl( + XlaOp operand, + const std::vector>& source_target_pairs, + const std::optional& channel_id, bool async); + + XlaOp ConditionalImpl( + XlaOp branch_index, + absl::Span branch_computations, + absl::Span branch_operands); + + XlaOp AllToAllArray( + XlaOp operand, int64_t split_dimension, int64_t concat_dimension, + int64_t split_count, absl::Span replica_groups, + const std::optional& channel_id = std::nullopt); + + // Creates an op with the given opcode and the output shape. + virtual absl::StatusOr AddOpWithShape( + HloOpcode opcode, const Shape& shape, absl::Span operands); + + // Here, InstructionType is either const HloInstructionProto* or non-const + // HloInstructionProto*. + template + absl::StatusOr LookUpInstructionByHandleInternal( + int64_t handle) const { + auto it = handle_to_index_.find(handle); + if (it == handle_to_index_.end()) { + // Try look for the instruction in the imported instructions. + auto imported_it = handle_to_imported_index_.find(handle); + if (imported_it != handle_to_imported_index_.end()) { + ImportedInstruction imported = imported_it->second; + return const_cast( + &embedded_.at(imported.computation_id) + .instructions(imported.instruction_index)); + } + return InvalidArgument("No XlaOp with handle %d", handle); + } + return const_cast(&instructions_.at(it->second)); + } + + // Here, InstructionType is either const HloInstructionProto* or non-const + // HloInstructionProto*. + // + // TODO(hinsu): Return const pointer within absl::StatusOr and use + // absl::implicit_cast at callsites. This requires implicit_cast support in + // absl::StatusOr similar to absl::StatusOr. + template + absl::StatusOr LookUpInstructionInternal(XlaOp op) const { + TF_RETURN_IF_ERROR(CheckOpBuilder(op)); + return LookUpInstructionByHandleInternal(op.handle()); + } + + friend struct internal::XlaBuilderFriend; + + friend class ValueInference; +}; + +// RAII-style object: sets the current sharding assignment in builder on +// construction, and sets back to the previous assignment on destruction. +class XlaScopedShardingAssignment { + public: + XlaScopedShardingAssignment(xla::XlaBuilder* builder, + std::optional sharding) + : builder_(builder), prev_sharding_(builder->sharding()) { + SetSharding(sharding); + } + + XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete; + XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) = + delete; + + ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); } + + private: + void SetSharding(const std::optional& sharding) { + if (sharding.has_value()) { + builder_->SetSharding(sharding.value()); + } else { + builder_->ClearSharding(); + } + } + + xla::XlaBuilder* const builder_; + std::optional prev_sharding_; +}; + +// RAII-style object: save the current builder's frontend attributes, and merge +// them with the new ones on construction. +// Restore the original attributes on destruction. +class XlaScopedFrontendAttributesAssignment { + public: + XlaScopedFrontendAttributesAssignment(xla::XlaBuilder* builder, + FrontendAttributes attributes) + : builder_(builder) { + saved_ = builder_->SwapFrontendAttributes(attributes); + } + + ~XlaScopedFrontendAttributesAssignment() { + builder_->SetFrontendAttributes(saved_); + } + + private: + xla::XlaBuilder* const builder_; + FrontendAttributes saved_; + + XlaScopedFrontendAttributesAssignment( + const XlaScopedFrontendAttributesAssignment&) = delete; + XlaScopedFrontendAttributesAssignment& operator=( + const XlaScopedFrontendAttributesAssignment&) = delete; +}; + +// RAII-style object: sets the current op metadata in builder on construction, +// and sets back to the previous assignment on destruction. +class XlaScopedOpMetadataAssignment { + public: + XlaScopedOpMetadataAssignment(xla::XlaBuilder* builder, OpMetadata metadata) + : builder_(builder) { + saved_ = builder_->SwapOpMetadata(metadata); + } + + ~XlaScopedOpMetadataAssignment() { builder_->SwapOpMetadata(saved_); } + + private: + xla::XlaBuilder* const builder_; + OpMetadata saved_; + + XlaScopedOpMetadataAssignment(const XlaScopedOpMetadataAssignment&) = delete; + XlaScopedOpMetadataAssignment& operator=( + const XlaScopedOpMetadataAssignment&) = delete; +}; + +// Free functions for building XlaOps. The intention is that these will +// become the public API for building XlaOps rather than calling methods on +// XlaBuilder directly. +// + +// Enqueues a "retrieve parameter value" instruction for a parameter that was +// passed to the computation. +XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number, + const Shape& shape, const std::string& name); + +// Same as above, but with leaf buffer replication annotation. +XlaOp Parameter(XlaBuilder* builder, int64_t parameter_number, + const Shape& shape, const std::string& name, + const std::vector& replicated_at_leaf_buffers); + +// Enqueues a constant with the value of the given literal onto the +// computation. +XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal); + +// Enqueues a constant onto the computation. Methods are templated on the +// native host type (NativeT) which corresponds to a specific XLA +// PrimitiveType as given in the following table: +// +// Native Type PrimitiveType +// ----------------------------- +// bool PRED +// int32_t S32 +// int64_t S64 +// uint32_t U32 +// uint64_t U64 +// float F32 +// double F64 +// +// Note: not all primitive types defined in xla_data.proto have a +// corresponding native type yet. +template +XlaOp ConstantR0(XlaBuilder* builder, NativeT value); +template +XlaOp ConstantR1(XlaBuilder* builder, absl::Span values); +XlaOp ConstantR1(XlaBuilder* builder, const tsl::core::Bitmap& values); +template +XlaOp ConstantR2(XlaBuilder* builder, + std::initializer_list> values); +template +XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, + const Array& values, + const Layout& layout); +template +XlaOp ConstantFromArray(XlaBuilder* builder, const Array& values); +template +XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, + const Array2D& values, + const Layout& layout); +template +XlaOp ConstantR2FromArray2D(XlaBuilder* builder, + const Array2D& values); +template +XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, + const Array3D& values, + const Layout& layout); +template +XlaOp ConstantR3FromArray3D(XlaBuilder* builder, + const Array3D& values); +template +XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder, + const Array4D& values, + const Layout& layout); +template +XlaOp ConstantR4FromArray4D(XlaBuilder* builder, + const Array4D& values); + +// Enqueues a rank one constant (XlaBuilder* builder, vector) onto the +// computation. The vector has size 'length' and every element has the value +// 'value'. +template +XlaOp ConstantR1(XlaBuilder* builder, int64_t length, NativeT value); + +// Adds dimensions to an array by duplicating the data in the array. +// +// The new dimensions are inserted on the left, i.e. if +// broadcast_sizes has values {a0, ..., aN} and the operand shape +// has dimensions {b0, ..., bM} then the shape of the output has +// dimensions {a0, ..., aN, b0, ..., bM}. +// +// The new dimensions index into copies of the operand, i.e. +// +// output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM] +XlaOp Broadcast(XlaOp operand, absl::Span broadcast_sizes); + +// This op broadcasts the `operand` to an output with the given `shape`. +// `broadcast_dimensions` are the dimensions to be broadcasting into, i.e., the +// i'th dimension of the operand is mapped to the broadcast_dimensions[i]'th +// dimension of the output. This also requires that the i'th input dimension is +// either 1 or is the same as the output dimension it's broadcasting into. +// +// For example, say operand = {1, 2}, i.e., a 1D tensor in shape s32[2]; the +// output shape is s32[2,2]: +// - Specifying {1} as broadcast_dimension will generate output +// {{1, 2}, +// {1, 2}} +// - On the other hand, specifying {0} as broadcast_dimension +// will generate output +// {{1 , 1}, +// {2 , 2}} +XlaOp BroadcastInDim(XlaOp operand, absl::Span out_dim_size, + absl::Span broadcast_dimensions); + +// This is an experimental API for creating the mhlo.dynamic_broadcast_in_dim +// op from the XlaBuilder. This is only intended for export to MHLO or +// StableHLO, and cannot be compiled. See +// https://www.tensorflow.org/mlir/hlo_ops#mhlodynamic_broadcast_in_dim_mhlodynamicbroadcastindimop. +// for the op semantics. +XlaOp MhloDynamicBroadcastInDim(XlaOp operand, XlaOp output_dimensions, + absl::Span broadcast_dimensions, + const Shape& output_shape); + +// Copies the input operand to the output. This operation is for internal +// purpose and is only used by the compiler for optimization purposes or to +// ensure correctness. The XLA client should never have to generate this +// instruction. +// +// Copy has two potential use cases: +// +// * Create a copy of the operand with a new layout. +// +// * Create a copy of the operand in a separately allocated buffer. This is +// necessary for some backends if the operand is a parameter or constant and +// the operand is returned within a tuple. In this case, the lifetime of the +// operand buffer must be the same as the lifetime of the output result. +// However, the lifetimes of parameters and constants are managed separately +// from the lifetime of the output result. Creating a separate copy of the +// parameter or constant buffer resolves this issue. +XlaOp Copy(XlaOp operand); + +// Enqueues a pad operation onto the computation that pads the given value on +// the edges as well as between the elements of the input. padding_config +// specifies the padding amount for each dimension. +XlaOp Pad(XlaOp operand, XlaOp padding_value, + const PaddingConfig& padding_config); + +// Enqueues a pad operation in a given dimension, taking all other +// dimensions as they are. +XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64_t dimno, + int64_t pad_lo, int64_t pad_hi); + +// Enqueues an operation onto the computation that flattens the operand based +// on the dimension order (major/slowest-varying to minor/fastest-varying) +// given, followed by reshaping it into the shape with the given dimension +// sizes (also major to minor). Conceptually, this is a limited form of +// "shape casting". +XlaOp Reshape(XlaOp operand, absl::Span dimensions, + absl::Span new_sizes); + +// Enqueues a dynamic reshape operation. The dynamic reshape takes additional +// XlaOps as sizes for the result dimension. The result dim i is a dynamic +// dimension dimension if dims_are_dynamic[i] is true. +XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, + absl::Span new_size_bounds, + const std::vector& dims_are_dynamic); + +// This is an experimental API for creating the mhlo.dynamic_reshape op from the +// XlaBuilder. This is only intended for export to MHLO or StableHLO, and cannot +// be compiled. +XlaOp MhloDynamicReshape(XlaOp operand, XlaOp output_shape, const Shape& shape); + +// Enqueues an operation onto the computation that collapses the operand, +// from first to last dimension (C order), then reshapes it to the given +// dimension sizes. Conceptually, this is a limited form of "shape casting". +XlaOp Reshape(XlaOp operand, absl::Span new_sizes); + +// Enqueues a Reshape op that uses an explicit target shape. +XlaOp Reshape(const Shape& shape, XlaOp operand); + +// `inferred_dimension` represents the output dimension that's inferred by +// upper-level framework by dividing the input element count by the known +// output element count. While an inferred_dimension can be static, if there +// is a dynamic dimension in the output, it must be the inferred dimension. +XlaOp ReshapeWithInferredDimension(XlaOp operand, + absl::Span new_sizes, + int64_t inferred_dimension); + +// Wrapper for Reshape. +// Enqueues an operation to collapse the provided dimensions; e.g. an +// operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to +// {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must +// be a consecutive, in-order subsequence of the operand dimensions. +// +// Note that collapsing a single dimension does nothing: +// +// {256} collapsing {0} => {256} +// {1} collapsing {0} => {1} +// +// Collapsing multiple dimensions produces a single result dimension: +// +// {256, 2} collapsing {0,1} => {512} +// {256, 2, 3} collapsing {0,1} => {512, 3} +// +// This could potentially cause data to be moved -- it provides a more +// structured form of reshaping than an arbitrary Reshape operation. +XlaOp Collapse(XlaOp operand, absl::Span dimensions); + +// Enqueues a slice operation onto the computation that slices the operand +// from the start indices to the limit indices; e.g. +// +// x +// [ 0 1 2 3 ] +// y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ] +// [ 8 9 a b ] +// +// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D +// range notation. +// The strides parameter determines the stride over the slice +XlaOp Slice(XlaOp operand, absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); + +// Enqueues a slice operation in a given dimension, taking all other +// dimensions as they are; e.g. if dimno is 1 from start_index 2 to +// limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand +// for: +// +// array[:, 2:4:1, :] +XlaOp SliceInDim(XlaOp operand, int64_t start_index, int64_t limit_index, + int64_t stride, int64_t dimno); + +// Enqueues a slice operation onto the computation that slices the 'operand' +// from dynamic start indices which are passed in 'start_indices'. +// The size of the slice in each dimension is passed in 'slice_sizes', +// which specify the end point of exclusive slice intervals in each +// dimension [start, start + size). +// The shape of each element of 'start_indices' must be scalar, with the span +// size equal to the rank of the 'operand'. All elements of 'start_indices' must +// have the same shape. +// Slice index calculations are computed modulo input dimension sizes to +// prevent dynamic start indices from generating out-of-bound array accesses. +XlaOp DynamicSlice(XlaOp operand, absl::Span start_indices, + absl::Span slice_sizes); + +// Enqueues a dynamic update slice operation onto the computation, which +// updates a slice of 'operand' with 'update' at dynamic 'start_indices'. +// The shape of 'update' determines the shape of the slice of 'operand' +// which is updated. +// The indices specified in 'start_indices' specify the offset of the slice +// of 'operand' which is updated. +// +// update = {10, 11} // calculated at runtime. +// [1 2 3] start = {1, 1} // calculated at runtime. [1 2 3 ] +// [4 5 6] => DynamicUpdateslice(data, update, start) => [4 10 11] +// [7 8 9] [7 8 9 ] +// +// The shape of each element of 'start_indices' must be scalar, with the span +// size equal to the rank of the 'operand'. All elements of 'start_indices' must +// have the same shape. +// Slice index calculations are computed modulo update dimension sizes to +// prevent dynamic start indices from generating out-of-bound array accesses. +XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update, + absl::Span start_indices); + +// Enqueues a concatenate instruction onto the computation. 'operands' must +// have >= 1 entry. +XlaOp ConcatInDim(XlaBuilder* builder, absl::Span operands, + int64_t dimension); + +// Enqueues a conditional-move-like select operation onto the computation; +// predicated on pred, selects between on_true and on_false. +XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false); + +// Enqueues a tuple-creation instruction onto the computation. +XlaOp Tuple(XlaBuilder* builder, absl::Span elements); + +// Enqueues a tuple-element-get instruction onto the computation. +XlaOp GetTupleElement(XlaOp tuple_data, int64_t index); + +// Enqueues an equal-to comparison instruction onto the computation. +XlaOp Eq(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); +XlaOp EqTotalOrder(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues a not-equal comparison instruction onto the computation. +XlaOp Ne(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); +XlaOp NeTotalOrder(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues a greater-or-equal comparison instruction onto the computation. +XlaOp Ge(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); +XlaOp GeTotalOrder(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues a greater-than comparison instruction onto the computation. +XlaOp Gt(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); +XlaOp GtTotalOrder(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues a less-than comparison instruction onto the computation. +XlaOp Lt(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); +XlaOp LtTotalOrder(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues a less-or-equal comparison instruction onto the computation. +XlaOp Le(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); +XlaOp LeTotalOrder(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues a comparison instruction onto the computation (optionally without +// broadcast_dimensions for consistency with others). +XlaOp Compare(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions, + ComparisonDirection direction, Comparison::Type compare_type); +XlaOp Compare(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions, + ComparisonDirection direction); +XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction); + +// Enqueues a dot instruction onto the computation. +XlaOp Dot(XlaOp lhs, XlaOp rhs, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + +// Enqueues a general dot instruction onto the computation. +XlaOp DotGeneral( + XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + +// Enqueues a sparse dot instruction onto the computation. +XlaOp SparseDot( + XlaOp lhs, XlaOp rhs, absl::Span sparse_meta, + absl::Span sparsity, + const DotDimensionNumbers& dimension_numbers, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + +// Enqueues a convolution instruction onto the computation, which uses the +// default convolution dimension numbers. +XlaOp Conv(XlaOp lhs, XlaOp rhs, absl::Span window_strides, + Padding padding, int64_t feature_group_count = 1, + int64_t batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + +// Enqueues a convolution instruction onto the computation, with the caller +// provided padding configuration in the format returned by MakePadding(). +XlaOp ConvWithGeneralPadding( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + int64_t feature_group_count = 1, int64_t batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + +// Enqueues a convolution instruction onto the computation, with the caller +// provided dimension numbers configuration. +XlaOp ConvWithGeneralDimensions( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + Padding padding, const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count = 1, int64_t batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + +// Enqueues a convolution instruction onto the computation, with the caller +// provided padding configuration as well as the dimension numbers. +XlaOp ConvGeneral( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count = 1, int64_t batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt); + +// Enqueues a convolution instruction onto the computation, with the caller +// provided padding configuration, dilation factors and dimension numbers. +XlaOp ConvGeneralDilated( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count = 1, int64_t batch_group_count = 1, + const PrecisionConfig* precision_config = nullptr, + std::optional preferred_element_type = std::nullopt, + std::optional> window_reversal = std::nullopt); + +XlaOp DynamicConvForward( + XlaOp lhs, XlaOp rhs, absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, PaddingType padding_type, + std::optional preferred_element_type = std::nullopt); + +XlaOp DynamicConvInputGrad( + XlaOp input_sizes, XlaOp lhs, XlaOp rhs, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, PaddingType padding_type, + std::optional preferred_element_type = std::nullopt); + +XlaOp DynamicConvKernelGrad( + XlaOp activations, XlaOp gradients, + absl::Span window_strides, + absl::Span> padding, + absl::Span lhs_dilation, + absl::Span rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers, + int64_t feature_group_count, int64_t batch_group_count, + const PrecisionConfig* precision_config, PaddingType padding_type, + std::optional preferred_element_type = std::nullopt); + +// Enqueues an FFT instruction onto the computation, of the given type and +// with the given FFT length. +XlaOp Fft(XlaOp operand, FftType fft_type, + absl::Span fft_length); + +// Solves systems of linear equations with lower or upper triangular coefficient +// matrices by forward- or back-substitution. Broadcasting along leading +// dimensions, this routine solves for x in one of the matrix systems +// `op(a) * x = b`, or `x * op(a) = b`, +// for the variable `x` given `a` and `b`, where `op(a)` is either +// `op(a) = a`, or `op(a) = transpose(a)`, or `op(a) = conj(transpose(a))`. +// +// * `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form +// square matrices. If `lower` is true (false), then the strictly upper +// (lower) triangular part of each innermost matrix in `a` is assumed to be +// zero and is not accessed. +// * `b` is a tensor of shape `[..., M, K]` if `left_side` is true, otherwise a +// tensor of shape `[..., K, M]`. +// * `left_side` is a boolean, indicating whether to solve a system of the form +// op(a) * x = b (true) or x * op(a) = b (false). +// * `lower` is a boolean, indicating whether the argument `a` is +// lower-triangular (true) or upper-triangular (false). +// * If `unit_diagonal` is true, the diagonal elements of `a` are assumed to be +// 1 and not accessed. +// * `transpose_a` indicates which function `op` we use to transform the tensor +// `a`: the identity function, transpose(a), or conjugate(transpose(a)) +XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, + bool unit_diagonal, + TriangularSolveOptions::Transpose transpose_a); + +// Computes the Cholesky decompositions of a batch of symmetric (Hermitian) +// positive definite matrices. +// `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the +// two minor dimensions equal. +// If `lower` is true, the data from the lower triangle is used; if false, the +// upper triangle is used. The input data in the other triangle of the input +// does not affect the output. Returns the output in the same lower/upper +// triangle. The data returned in the other output triangle is arbitrary and +// implementation-defined. +// +// If `a` is not Hermitian positive definite, returns an array full of NaNs. +XlaOp Cholesky(XlaOp a, bool lower); + +// Enqueues an infeed instruction onto the computation, which writes data of +// the given shape to the infeed buffer of the device. +XlaOp Infeed(XlaBuilder* builder, const Shape& shape, + const std::string& config = ""); + +// Variant of Infeed which takes a token-shaped operand and produces a +// two-element tuple containing the data value and a token-shaped value. +// Tokens are used for ordering side-effecting operations. +// TODO(b/110532604): Replace all uses of the non-token form with this variant. +XlaOp InfeedWithToken(XlaOp token, const Shape& shape, + const std::string& config = ""); + +// Enqueues an outfeed instruction onto the computation. This instruction +// generates outgoing data transfers for the given data. +// +// shape_with_layout communicates the laid out shape that we want to outfeed +// -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error +// will occur. +void Outfeed(XlaOp operand, const Shape& shape_with_layout, + const std::string& outfeed_config); + +// Variant of Outfeed which takes a token-shaped operand and produces a +// token-shaped value. Tokens are used for ordering side-effecting operations. +// TODO(b/110532604): Replace all uses of the non-token form with this variant. +XlaOp OutfeedWithToken(XlaOp operand, XlaOp token, + const Shape& shape_with_layout, + const std::string& outfeed_config); + +// Enqueues a call instruction onto the computation. +XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, + absl::Span operands); + +// Enqueues a composite call instruction onto the computation. +XlaOp CompositeCall(XlaBuilder* builder, const XlaComputation& computation, + absl::Span operands, const std::string& name, + std::optional attributes = std::nullopt, + std::optional version = std::nullopt); + +// Enqueues a custom call instruction onto the computation. A custom call +// invokes code external to XLA. The |operands| are passed to the external code, +// and the external code is expected to produce a result of the given +// |shape|. The exact mechanism is backend-specific. For example, in the CPU +// backend, a call instruction is emitted which targets a symbol with the name +// |call_target_name|. |call_target_name| and |opaque| can arbitrary strings, +// but |call_target_name| should be short as it may be used in labels. |opaque| +// can encode arbitrarily large amounts of information. |has_side_effect| +// specifies whether the instruction can have side effects. +// |output_operand_aliasing| specifies a list of output/operand buffer pairs +// that alias each other, where the output buffer is represented as a +// ShapeIndex, and the operand buffer is represented as the operand index and +// the ShapeIndex. +XlaOp CustomCall( + XlaBuilder* builder, const std::string& call_target_name, + absl::Span operands, const Shape& shape, + const std::string& opaque = "", bool has_side_effect = false, + absl::Span>> + output_operand_aliasing = {}, + const Literal* literal = nullptr, + CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE, + CustomCallApiVersion api_version = API_VERSION_ORIGINAL); + +// Overload which constructs a custom call that applies an Xla computation. +XlaOp CustomCallWithComputation( + XlaBuilder* builder, const std::string& call_target_name, + absl::Span operands, const XlaComputation& computation, + const Shape& shape, const std::string& opaque = "", + bool has_side_effect = false, + absl::Span>> + output_operand_aliasing = {}, + const Literal* literal = nullptr, + CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE, + CustomCallApiVersion api_version = API_VERSION_ORIGINAL); + +// Overload which constructs a custom call with fixed layouts. The operands will +// have the layouts specified by |operand_shapes_with_layout| when provided to +// external code, and the external code is expected to produce a result with the +// layout specified by |shape_with_layout|. All shapes in |shape_with_layout| +// and |operand_shapes_with_layout| must have layouts. +XlaOp CustomCallWithLayout( + XlaBuilder* builder, const std::string& call_target_name, + absl::Span operands, const Shape& shape_with_layout, + absl::Span operand_shapes_with_layout, + const std::string& opaque = "", bool has_side_effect = false, + absl::Span>> + output_operand_aliasing = {}, + const Literal* literal = nullptr, + CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE, + CustomCallApiVersion api_version = API_VERSION_ORIGINAL); + +// Overload which annotates a custom call with the given Window and +// ConvolutionDimensionNumbers. Useful for custom-calls which represent +// convolutions. +// +// This sets the layout of its operands if operand_shapes_with_layout is +// nonempty, and it sets the layout of its result if `shape` has a layout. +XlaOp CustomCallWithConvDnums( + XlaBuilder* builder, const std::string& call_target_name, + absl::Span operands, const Shape& shape, + absl::Span operand_shapes_with_layout, + const std::string& opaque, bool has_side_effect, + absl::Span>> + output_operand_aliasing, + const Literal* literal, Window window, ConvolutionDimensionNumbers dnums, + CustomCallSchedule schedule = CustomCallSchedule::SCHEDULE_NONE, + CustomCallApiVersion api_version = API_VERSION_ORIGINAL); + +// Enqueues an optimization barrier onto the computation. +XlaOp OptimizationBarrier(XlaOp operand); + +// The following methods enqueue element-wise binary arithmetic operations +// onto the computation. The shapes of the operands have to match unless one +// of the operands is a scalar, or an explicit broadcast dimension is given +// (see g3doc for more details). + +// Enqueues a complex compose instruction onto the computation. +XlaOp Complex(XlaOp real, XlaOp imag, + absl::Span broadcast_dimensions = {}); + +// Enqueues a complex conjugate instruction onto the computation. +XlaOp Conj(XlaOp operand); + +// Enqueues an add instruction onto the computation. +XlaOp Add(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues a subtract instruction onto the computation. +XlaOp Sub(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues a multiply instruction onto the computation. +XlaOp Mul(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues a divide instruction onto the computation. +XlaOp Div(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues a remainder instruction onto the computation. +XlaOp Rem(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues a max instruction onto the computation. +XlaOp Max(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues a min instruction onto the computation. +XlaOp Min(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Element-wise logical operators +XlaOp And(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Overload to call And with 3 or more operands. We need the following somewhat +// convoluted overload set to disambiguate with the overload that takes the +// `broadcast_dimensions` optional param. +inline XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3) { + return And(op1, And(op2, op3)); +} +template +XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) { + return And(op1, And(op2, And(op3, operands...))); +} + +XlaOp Or(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Overload to call Or with 3 or more operands. As with `And`, we need the +// following complicated overload set to handle the default arg in the `Or` +// overload above. +inline XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3) { + return Or(op1, Or(op2, op3)); +} +template +XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) { + return Or(op1, Or(op2, Or(op3, operands...))); +} + +XlaOp Xor(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +XlaOp Not(XlaOp operand); + +XlaOp PopulationCount(XlaOp operand); + +XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); +XlaOp ShiftRightArithmetic(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); +XlaOp ShiftRightLogical(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); +// Reduces an array among the provided dimensions, given "computation" as a +// reduction operator. +XlaOp Reduce(XlaOp operand, XlaOp init_value, const XlaComputation& computation, + absl::Span dimensions_to_reduce); + +// Reduces several arrays simultaneously among the provided dimensions, given +// "computation" as a reduction operator. +XlaOp Reduce(XlaBuilder* builder, absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span dimensions_to_reduce); + +// Convenience wrapper around the above that reduces all the dimensions in the +// operand shape. +XlaOp ReduceAll(XlaOp operand, XlaOp init_value, + const XlaComputation& computation); + +// Enqueues a windowed reduce instruction onto the computation. +XlaOp ReduceWindow(XlaOp operand, XlaOp init_value, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding); + +XlaOp ReduceWindow(absl::Span operands, + absl::Span init_values, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, Padding padding); + +// As ReduceWindow(), but the padding is given in the format +// returned by MakePadding(). +XlaOp ReduceWindowWithGeneralPadding( + XlaOp operand, XlaOp init_value, const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, + absl::Span> padding); +XlaOp ReduceWindowWithGeneralPadding( + absl::Span operands, absl::Span init_values, + const XlaComputation& computation, + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span base_dilations, + absl::Span window_dilations, + absl::Span> padding); + +// Returns the sum of the operand value within each subgroup of replicas. All +// replicas supply one input to the sum and all replicas receive the resulting +// sum for each subgroup. +XlaOp CrossReplicaSum(XlaOp operand, + absl::Span replica_groups = {}); + +XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension, + int64_t shard_count, + absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt, + const std::optional& layout = std::nullopt, + std::optional use_global_device_ids = std::nullopt); + +XlaOp AllGatherTuple( + absl::Span operands, int64_t all_gather_dimension, + int64_t shard_count, absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt, + const std::optional& layout = std::nullopt, + std::optional use_global_device_ids = std::nullopt); + +// Enqueues an operation that do an AllReduce of the operand cross cores. Here +// AllReduce means doing a reduction on the input operand cross cores and then +// broadcasting the reduction result to those cores. The reduction function is +// defined by `computation`, which should be a commutative computation on +// scalars, e.g., add, min, or max. The way that AllReduce is applied is +// configured by: +// +// - `replica_groups`: each ReplicaGroup contains a list of replica id. If +// empty, all replicas belong to one group. Allreduce will be applied within +// subgroups. For example, we have 4 replicas, then replica_groups={{0,2},{1,3}} +// means, replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. +// +// - `channel_id`: for Allreduce nodes from different modules, if they have the +// same channel_id, they will be 'AllReduce'd. If empty, AllReduce will not be +// applied cross modules. +// +// - `shape_with_layout`: forces the layout of the AllReduce to the given +// layout. This is used to guarantee the same layout for a group of AllReduce +// ops compiled separately. +XlaOp AllReduce(XlaOp operand, const XlaComputation& computation, + absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt, + const std::optional& shape_with_layout = std::nullopt, + std::optional use_global_device_ids = std::nullopt); + +XlaOp AllReduceTuple( + absl::Span operand, const XlaComputation& computation, + absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt, + const std::optional& shape_with_layout = std::nullopt, + std::optional use_global_device_ids = std::nullopt); + +XlaOp ReduceScatter( + XlaOp operand, const XlaComputation& computation, int64_t scatter_dimension, + int64_t shard_count, absl::Span replica_groups = {}, + const std::optional& channel_id = std::nullopt, + const std::optional& layout = std::nullopt, + std::optional use_global_device_ids = std::nullopt); + +// Enqueues an operation that do an AllToAll of the operand cross cores. +// This involves AllToAll, followed by Reshape, Transpose, and another Reshape +// to get proper codegen. See implementation for additional details. +// +// An optional `layout` can be specified to force the layout of the instruction. +// This is used to guarantee the same layout for a group of AllToAll ops +// compiled separately. +XlaOp AllToAll(XlaOp operand, int64_t split_dimension, int64_t concat_dimension, + int64_t split_count, + absl::Span replica_groups = {}, + const std::optional& layout = std::nullopt, + const std::optional& channel_id = std::nullopt); + +XlaOp AllToAllTuple( + absl::Span operand, + absl::Span replica_groups = {}, + const std::optional& layout = std::nullopt, + const std::optional& channel_id = std::nullopt); + +XlaOp AllToAllTuple( + XlaOp operand, int64_t split_dimension, int64_t concat_dimension, + int64_t split_count, absl::Span replica_groups = {}, + const std::optional& layout = std::nullopt, + const std::optional& channel_id = std::nullopt); + +XlaOp CollectiveBroadcast( + XlaOp operand, absl::Span replica_groups, + const std::optional& channel_id = std::nullopt); + +// Enqueues an collective operation that sends and receives data cross replicas. +// +// - `source_target_pair`: a list of (source_replica_id, target_replica_id) +// pairs. For each pair, the operand is sent from source replica to target +// replica. Note that, 1) any two pairs should not have the same target replica +// id, and they should not have the same source replica id; 2) if a replica id +// is not a target in any pair, then the output on that replica is a tensor +// consists of 0(s) with the same shape as the input. +XlaOp CollectivePermute( + XlaOp operand, + const std::vector>& source_target_pairs, + const std::optional& channel_id = std::nullopt); + +// Enqueues an operation that returns the replica ID. +XlaOp ReplicaId(XlaBuilder* builder); + +// Enqueues an operation that scatters the `source` array to the selected +// indices of each window. +XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select, + absl::Span window_dimensions, + absl::Span window_strides, + Padding padding, XlaOp source, XlaOp init_value, + const XlaComputation& scatter); + +// As SelectAndScatter(), but the padding is given in the format +// returned by MakePadding(). +XlaOp SelectAndScatterWithGeneralPadding( + XlaOp operand, const XlaComputation& select, + absl::Span window_dimensions, + absl::Span window_strides, + absl::Span> padding, XlaOp source, + XlaOp init_value, const XlaComputation& scatter); + +// Enqueues an abs instruction onto the computation. +XlaOp Abs(XlaOp operand); + +// Enqueues a atan2 instruction onto the computation. +XlaOp Atan2(XlaOp y, XlaOp x, + absl::Span broadcast_dimensions = {}); + +// Enqueues an erf instruction onto the computation. +XlaOp Erf(XlaOp operand); + +// Enqueues an exp instruction onto the computation. +XlaOp Exp(XlaOp operand); + +// Enqueues an expm1 instruction onto the computation. +XlaOp Expm1(XlaOp operand); + +// Enqueues a floor instruction onto the computation. +XlaOp Floor(XlaOp operand); + +// Enqueues a ceil instruction onto the computation. +XlaOp Ceil(XlaOp operand); + +// Enqueues a round instruction onto the computation, +// with half-way cases rounding away from zero. +XlaOp Round(XlaOp operand); + +// Enqueues a round instruction onto the computation, rounding to nearest even +XlaOp RoundNearestEven(XlaOp operand); + +// Enqueues an log instruction (natural logarithm) onto the computation. +XlaOp Log(XlaOp operand); + +// Enqueues an log1p instruction (log(x+1)) onto the computation. +XlaOp Log1p(XlaOp operand); + +// Enqueues a logistic instruction onto the computation. +XlaOp Logistic(XlaOp operand); + +// Enqueues a sign instruction onto the computation. +XlaOp Sign(XlaOp operand); + +// Enqueues a count leading zeros instruction onto the computation. +XlaOp Clz(XlaOp operand); + +// Enqueues a cosine instruction onto the computation. +XlaOp Cos(XlaOp operand); + +// Enqueues a sine instruction onto the computation. +XlaOp Sin(XlaOp operand); + +// Enqueues a tan instruction onto the computation. +XlaOp Tan(XlaOp operand); + +// Enqueues a tanh instruction onto the computation. +XlaOp Tanh(XlaOp operand); + +// Enqueues a real-part instruction onto the computation. +XlaOp Real(XlaOp operand); + +// Enqueues an imaginary-part instruction onto the computation. +XlaOp Imag(XlaOp operand); + +// Enqueues a sqrt computation onto the computation. +XlaOp Sqrt(XlaOp operand); + +// Enqueues a cbrt computation onto the computation. +XlaOp Cbrt(XlaOp operand); + +// Enqueues a rsqrt computation onto the computation. +XlaOp Rsqrt(XlaOp operand); + +// Enqueues a lhs^rhs computation onto the computation. +XlaOp Pow(XlaOp lhs, XlaOp rhs, + absl::Span broadcast_dimensions = {}); + +// Enqueues an operator that tests if the operand's values are finite, i.e., not +// +/-Inf or NaN. Returns an array of booleans with the same shape where +// entries are true iff the corresponding entry was not infinite or NaN. +// +// Defined only for real-valued (i.e. not complex) floating-point types; raises +// an error for other types. +// +// See also IsInf, IsPosInf, IsNegInf, and IsNan in lib/math.h. +XlaOp IsFinite(XlaOp operand); + +// Enqueues an iota operation onto the computation. +XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64_t iota_dimension); + +// Enqueues a rank-1 iota operation onto the computation. +XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64_t size); + +// Enqueues a convert instruction onto the computation that changes the +// element type of the operand array to primitive_type. +XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type); + +// Enqueues a no-op instruction onto the computation that changes +// the element type of the operand array to primitive_type. The +// bit-widths of the source and destination element types must be +// identical. +XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type); + +// Enqueues a stochastic convert instruction onto the computation that changes +// the element type of the operand array with stochastic rounding to +// primitive_type. +XlaOp StochasticConvertType(XlaOp operand, XlaOp random, + PrimitiveType new_element_type); + +// Enqueues a negate instruction onto the computation. +XlaOp Neg(XlaOp operand); + +// Enqueues a transpose instruction onto the computation. +XlaOp Transpose(XlaOp operand, absl::Span permutation); + +// Enqueues a reverse instruction onto the computation. The order of the +// elements in the given dimensions is reversed (i.e., the element at index i +// is moved to index dimension_size - 1 - i). +XlaOp Rev(XlaOp operand, absl::Span dimensions); + +// Enqueues a sort instruction onto the computation, using 'comparator' for +// comparisons. 'comparator' needs to define a strict weak order. 'is_stable' +// determines whether the stable sorting should be used. +// If only one operand is provided: +// * If the operand is a rank-1 tensor (an array), the result is a sorted array. +// The resulting sorting order has the property that for all index positions +// i, j with i < j, either +// comparator(value[i], value[j]) = comparator(value[j], value[i]) = false or +// comparator(value[i], value[j]) = true. +// * If the operand has higher rank, the operand is sorted along the provided +// dimension. For example, for a rank-2 tensor (a matrix), a dimension value +// of 0 will independently sort every column, and a dimension value of 1 will +// independently sort each row. If no dimension number is provided, then the +// last dimension is chosen by default. For the dimension which is sorted, the +// same sorting order applies as in the rank-1 case. +// +// If more than one operand is provided: +// * All operands must be tensors with the same dimensions. The element types of +// the tensors may be different. +// * The result is a tuple that consists of the operands in sorted order (along +// the provided dimension, as above). The same permutation as implied by the +// comparison computation is applied to all operand tensors. When comparing +// two index positions, 'comparator' is called with 2 * n scalar parameters, +// where parameter 2 * i and 2 * i + 1 correspond to the value of operand i at +// two index positions. +// Default comparator computations can be found in lib/comparators.h +XlaOp Sort(absl::Span operands, const XlaComputation& comparator, + int64_t dimension = -1, bool is_stable = false); + +// Enqueues a topk instruction onto the computation. TopK returns the largest +// 'k' values and their indices along the last dimension of the 'operand' if +// `lagest=true` or the smallest `k` values if `largest=false`. +// +// * If the operand is a rank-1 tensor (an array), the result is a tuple that +// consists of: +// * a sorted array with the top 'k' elements. +// * an array containing the indices of the k elements. +// For example, if the input is [0.1, 0.3, 0.2] and k == 2, the output tuple +// is ([0.3, 0.2], [1, 2]). +// * If the operand has higher rank, the result is a tuple that consists of: +// * a tensor equivalent to one produced by sorting the operand along the last +// dimension and slicing that dimension to only the top 'k' values. The last +// dimension is sorted as in the rank-1 case. +// * a tensor containing the indices of the top 'k' values along the last +// dimension. +// For example, if the input is [0.1, 0.3, 0.2][0.5, 0.4, 0.6] and k == 1, the +// output tuple is ([0.3][0.6], [1][2]). +XlaOp TopK(XlaOp operand, int64_t k, bool largest); + +// Enqueues a clamp instruction onto the computation. +XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max); + +// Enqueues a map instruction onto the computation. +XlaOp Map(XlaBuilder* builder, absl::Span operands, + const XlaComputation& computation, + absl::Span dimensions, + absl::Span static_operands = {}); + +// Enqueues a N(mu, sigma) random number generation instruction onto the +// computation. +XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape); + +// Enqueues a U(a, b) random number generation instruction onto the +// computation. Returns values in the semi-open interval [a, b). +XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape); + +// Enqueues a B(initial_state) random bit generation instruction onto the +// computation. Returns the new key and random bits with the specified shape. +XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state, + const Shape& shape); + +// Enqueues a while node onto the computation. +XlaOp While(const XlaComputation& condition, const XlaComputation& body, + XlaOp init); + +// Enqueues a conditional node onto the computation. +XlaOp Conditional(XlaOp predicate, XlaOp true_operand, + const XlaComputation& true_computation, XlaOp false_operand, + const XlaComputation& false_computation); + +// Enqueues either a predicated (if/else) or indexed (switch/case/default) +// conditional node onto the computation. N >= 1 branch_computations and +// branch_operands are matched by index. branch_index selects the branch that +// will be executed. Out of range branch_index uses the N-1'th +// branch_computation as default. +XlaOp Conditional(XlaOp branch_index, + absl::Span branch_computations, + absl::Span branch_operands); + +// Enqueues a ReducePrecision node onto the computation. +XlaOp ReducePrecision(XlaOp operand, int exponent_bits, int mantissa_bits); + +// Enqueues a Gather node onto the computation. +XlaOp Gather(XlaOp input, XlaOp start_indices, + const GatherDimensionNumbers& dimension_numbers, + absl::Span slice_sizes, + bool indices_are_sorted = false); + +// Enqueues a Scatter node onto the computation. +XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers, + bool indices_are_sorted = false, bool unique_indices = false); +XlaOp Scatter(absl::Span inputs, XlaOp scatter_indices, + absl::Span updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers, + bool indices_are_sorted = false, bool unique_indices = false); + +// Enqueues a Send node onto the computation for device-to-device +// communication. This operation sends the given operand to +// a Recv instruction in a different computation that shares the same channel +// handle. +void Send(XlaOp operand, const ChannelHandle& handle); + +// Variant of Send which takes a token-shaped operand and produces a +// token-shaped value. Tokens are used for ordering side-effecting operations. +// TODO(b/110532604): Replace all uses of the non-token form with this variant. +XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle); + +// Enqueues a Recv node onto the computation for device-to-device +// communication. The data comes from a Send instruction in a different +// computation that shares the same channel handle and its shape must be the +// same as the given shape. +XlaOp Recv(XlaBuilder* builder, const Shape& shape, + const ChannelHandle& handle); + +// Variant of Recv which takes a token-shaped operand and produces a two-element +// tuple containing the data value and a token-shaped value. Tokens are used +// for ordering side-effecting operations. +// TODO(b/110532604): Replace all uses of the non-token form with this variant. +XlaOp RecvWithToken(XlaOp token, const Shape& shape, + const ChannelHandle& handle); + +// Enqueues a Send node which transfers data from the device to the host. The +// 'shape_with_layout' argument defines the layout of the data transferred; its +// shape must be compatible with the shape of the operand. The operand must be +// array-shaped. +// TODO(b/111544877): Support tuple shapes. +XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout, + const ChannelHandle& handle); + +// Enqueues a Recv node which transfers data from the host to the device. The +// given shape must contain a layout and must be an array. +// TODO(b/111544877): Support tuple shapes. +XlaOp RecvFromHost(XlaOp token, const Shape& shape, + const ChannelHandle& handle); + +// Enqueues an operation (AfterAll) with no operands that produces a +// token-shaped value. Tokens are used for ordering side-effecting operations. +// This is a separate method from AfterAll to facility the removal of +// operand-less AfterAll instructions. +// TODO(b/110532604): Remove this function when all tokens are derived from a +// single token generated or passed into the entry computation. +XlaOp CreateToken(XlaBuilder* builder); + +// Enqueues an AfterAll instruction which produces a token-shaped value and +// takes a variadic number of token-shaped operands. The number of operands must +// be greater than zero. Used for joining tokens. +XlaOp AfterAll(XlaBuilder* builder, absl::Span tokens); + +// Normalizes operand across spatial and batch dimensions for each feature. +// +// Returns a tuple (normalized, batch_mean, batch_var) where `normalized` +// is the normalized result and batch_mean and batch_var are the mean and +// variance, respectively, across batch for the operand. +XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, float epsilon, + int64_t feature_index); + +// Normalizes operand across spatial and batch dimensions for each feature. +// +// `BatchNormInference` is equivalent to calling `BatchNormTraining` without +// computing `mean` and `variance` for each batch inside the operation. It +// uses the input `mean` and `variance` instead as estimated values. The +// purpose of this op is to reduce latency in inference, hence the name +// `BatchNormInference`. +// +// The output has the same shape as `operand`, and contains the normalized +// values for each batch. +XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp mean, + XlaOp variance, float epsilon, int64_t feature_index); + +// Calculates the gradients of a batch norm op. +// +// The inputs `batch_mean` and `batch_var` represent the mean and variance +// across the batch. +// +// Returns a tuple of three elements: +// - grad_operand: Gradient with respect to input `operand` +// - grad_offset: Gradient with respect to input `offset` +// - grad_scale: Gradient with respect to input `scale` +XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean, + XlaOp batch_var, XlaOp grad_output, float epsilon, + int64_t feature_index); + +// Returns the size of the given dimension of the operand. The operand must be +// array shaped. +XlaOp GetDimensionSize(XlaOp operand, int64_t dimension); + +// Sets the size of the given dimension of the operand. The operand must be +// array shaped. The result will have the same shape as the operand, but the +// given dimension will be dynamic (if not already). +XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64_t dimension); + +// Returns the same op but with dynamic dimension removed. +XlaOp RemoveDynamicDimension(XlaOp operand, int64_t dimension); + +// Implementation details below this point. +// + +// Free function template implementations. + +template +XlaOp ConstantR0(XlaBuilder* builder, NativeT value) { + return ConstantLiteral(builder, LiteralUtil::CreateR0(value)); +} + +template +XlaOp ConstantR1(XlaBuilder* builder, absl::Span values) { + BorrowingLiteral literal( + reinterpret_cast(values.begin()), + ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), + {static_cast(values.size())})); + return ConstantLiteral(builder, literal); +} + +template +XlaOp ConstantR1(XlaBuilder* builder, int64_t length, NativeT value) { + Literal literal(ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), {length})); + literal.PopulateWithValue(value); + return ConstantLiteral(builder, literal); +} + +inline XlaOp ConstantR1(XlaBuilder* builder, const tsl::core::Bitmap& values) { + return ConstantLiteral(builder, LiteralUtil::CreateR1(values)); +} + +template +XlaOp ConstantR2(XlaBuilder* builder, + std::initializer_list> values) { + return ConstantLiteral(builder, LiteralUtil::CreateR2(values)); +} + +template +XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder, + const Array& values, + const Layout& layout) { + return ConstantLiteral( + builder, LiteralUtil::CreateFromArrayWithLayout(values, layout)); +} + +template +XlaOp ConstantFromArray(XlaBuilder* builder, const Array& values) { + return ConstantLiteral(builder, + LiteralUtil::CreateFromArray(values)); +} + +template +XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder, + const Array2D& values, + const Layout& layout) { + return ConstantLiteral( + builder, LiteralUtil::CreateFromArrayWithLayout(values, layout)); +} + +template +XlaOp ConstantR2FromArray2D(XlaBuilder* builder, + const Array2D& values) { + return ConstantLiteral(builder, + LiteralUtil::CreateR2FromArray2D(values)); +} + +template +XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder, + const Array3D& values, + const Layout& layout) { + return ConstantLiteral( + builder, + LiteralUtil::CreateR3FromArray3DWithLayout(values, layout)); +} + +template +XlaOp ConstantR3FromArray3D(XlaBuilder* builder, + const Array3D& values) { + return ConstantFromArray(builder, values); +} + +template +XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder, + const Array4D& values, + const Layout& layout) { + return ConstantFromArrayWithLayout(builder, values, layout); +} + +template +XlaOp ConstantR4FromArray4D(XlaBuilder* builder, + const Array4D& values) { + return ConstantFromArray(builder, values); +} + +// Switches from automatic SPMD partitioning to manual partitioning. Converts a +// full-shaped tensor (to be automatically partitioned by SPMD partitioner) to a +// shard-shaped tensor to be consumed by manually partitioned ops. +absl::StatusOr ConvertSpmdFullToShardShape( + xla::XlaBuilder* builder, xla::XlaOp input, int single_dim, + const xla::OpSharding& manual_sharding, + absl::Span unspecified_dims); + +// Switches from manual partitioning to automatic SPMD partitioning. Converts a +// shard-shaped tensor (manually partitioned in SPMD-style) to a full-shaped +// tensor to be partitioned automatically by the SPMD partitioner. +absl::StatusOr ConvertSpmdShardToFullShape( + xla::XlaBuilder* builder, xla::XlaOp input, const xla::Shape& output_shape, + int single_dim, const xla::OpSharding& manual_sharding, + absl::Span unspecified_dims); + +} // namespace xla + +#endif // XLA_HLO_BUILDER_XLA_BUILDER_H_ diff --git a/third_party/xla/xla/client/xla_builder_test.cc b/third_party/xla/xla/hlo/builder/xla_builder_test.cc similarity index 99% rename from third_party/xla/xla/client/xla_builder_test.cc rename to third_party/xla/xla/hlo/builder/xla_builder_test.cc index 8ecf2434fc1d3f..089dba93f216ad 100644 --- a/third_party/xla/xla/client/xla_builder_test.cc +++ b/third_party/xla/xla/hlo/builder/xla_builder_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include #include @@ -34,12 +34,12 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "xla/client/padding.h" -#include "xla/client/sharding_builder.h" -#include "xla/client/value_inference.h" -#include "xla/client/xla_computation.h" #include "xla/comparison_util.h" #include "xla/debug_options_flags.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/sharding_builder.h" +#include "xla/hlo/builder/value_inference.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -47,9 +47,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/layout_util.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" @@ -1560,6 +1560,23 @@ TEST(XlaBuilderTest, CheckBufferDonor) { EXPECT_FALSE(config.ParameterIsBufferDonor(1, {})); } +TEST(XlaBuilderTest, ConstantLiteral) { + XlaBuilder b(TestName()); +#if defined(__x86_64__) && defined(_MM_DENORMALS_ZERO_ON) + int old_csr = _mm_getcsr(); + // Treat denormals as zero. This will make the small number below equal to + // 0.0, as far as the FP unit is concerned. + _mm_setcsr(old_csr | _MM_DENORMALS_ZERO_ON); +#endif + ConstantR1(&b, {0.0f, 1.401298e-45f}); +#if defined(__x86_64__) && defined(_MM_DENORMALS_ZERO_ON) + _mm_setcsr(old_csr); +#endif + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + const HloInstruction* root = GetRoot(*module); + ASSERT_THAT(root, GmockMatch(m::Constant())); +} + TEST(XlaBuilderTest, InvalidInputOutputAliasBufferDonor) { XlaBuilder b(TestName()); diff --git a/third_party/xla/xla/client/xla_computation.cc b/third_party/xla/xla/hlo/builder/xla_computation.cc similarity index 96% rename from third_party/xla/xla/client/xla_computation.cc rename to third_party/xla/xla/hlo/builder/xla_computation.cc index fc558462d1a576..1d01870f1d85c9 100644 --- a/third_party/xla/xla/client/xla_computation.cc +++ b/third_party/xla/xla/hlo/builder/xla_computation.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include #include diff --git a/third_party/xla/xla/hlo/builder/xla_computation.h b/third_party/xla/xla/hlo/builder/xla_computation.h new file mode 100644 index 00000000000000..379d386e4b7908 --- /dev/null +++ b/third_party/xla/xla/hlo/builder/xla_computation.h @@ -0,0 +1,73 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_BUILDER_XLA_COMPUTATION_H_ +#define XLA_HLO_BUILDER_XLA_COMPUTATION_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "xla/service/hlo.pb.h" +#include "xla/shape.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// The computation graph that the user builds up with the XlaBuilder. +class XlaComputation { + public: + XlaComputation() : unique_id_(-1) {} + XlaComputation(HloModuleProto proto) + : unique_id_(proto.id()), proto_(std::move(proto)) {} + + ~XlaComputation() = default; + + XlaComputation(const XlaComputation&) = delete; + XlaComputation& operator=(const XlaComputation&) = delete; + + XlaComputation(XlaComputation&& from) = default; + + XlaComputation& operator=(XlaComputation&& from) = default; + + // Returns the "program shape" (parameter and return shapes) for this + // computation. + absl::StatusOr GetProgramShape() const; + + const std::string& name() const { return proto().name(); } + + const HloModuleProto& proto() const { return proto_; } + HloModuleProto* mutable_proto() { return &proto_; } + + // Requests that we snapshot the computation into a serializable protocol + // buffer form. + absl::StatusOr> Snapshot() const; + + // Returns true if this object is a null Computation. + bool IsNull() const { return unique_id_ == -1; } + + private: + XlaComputation(const int64_t unique_id) : unique_id_(unique_id) {} + friend class XlaBuilder; + + int64_t unique_id_; + HloModuleProto proto_; +}; + +} // namespace xla + +#endif // XLA_HLO_BUILDER_XLA_COMPUTATION_H_ diff --git a/third_party/xla/xla/hlo/evaluator/BUILD b/third_party/xla/xla/hlo/evaluator/BUILD index 7093c530d41ab8..63c26e537706fb 100644 --- a/third_party/xla/xla/hlo/evaluator/BUILD +++ b/third_party/xla/xla/hlo/evaluator/BUILD @@ -52,6 +52,7 @@ cc_library( "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:tuple_points_to_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", "//xla/service:call_graph", @@ -61,7 +62,6 @@ cc_library( "//xla/service:logical_buffer", "//xla/service:pattern_matcher", "//xla/service:shape_inference", - "//xla/service:tuple_points_to_analysis", "//xla/service/cpu:runtime_single_threaded_matmul", "//xla/tsl/lib/core:bitmap", "@com_google_absl//absl/algorithm:container", @@ -107,11 +107,11 @@ xla_cc_test( "//xla:literal", "//xla:literal_util", "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/hlo/testlib:hlo_hardware_independent_test_base", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep ], ) @@ -134,18 +134,17 @@ xla_cc_test( "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", + "//xla/hlo/analysis:tuple_points_to_analysis", + "//xla/hlo/builder:xla_builder", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/transforms:hlo_element_type_converter", "//xla/service:call_graph", "//xla/service:dynamic_dimension_inference", - "//xla/service:hlo_element_type_converter", "//xla/service:hlo_module_config", "//xla/service:shape_inference", - "//xla/service:tuple_points_to_analysis", - "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", "//xla/tests:test_utils", - "//xla/tests:xla_internal_test_main", # fixdeps: keep "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:endian", "@com_google_absl//absl/container:flat_hash_set", @@ -160,5 +159,6 @@ xla_cc_test( "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep ], ) diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc index b75903563ea57c..d6a625a5ee3a32 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc @@ -52,6 +52,7 @@ limitations under the License. #include "Eigen/Core" #include "xla/array2d.h" #include "xla/comparison_util.h" +#include "xla/hlo/analysis/tuple_points_to_analysis.h" #include "xla/hlo/evaluator/hlo_evaluator_typed_visitor.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -73,7 +74,6 @@ limitations under the License. #include "xla/service/logical_buffer.h" #include "xla/service/pattern_matcher.h" #include "xla/service/shape_inference.h" -#include "xla/service/tuple_points_to_analysis.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" @@ -4769,4 +4769,82 @@ std::unique_ptr> HloEvaluator::MatmulArray2D( lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulU8); } +/* static */ std::unique_ptr> Array2DF8E5M2ToF32( + const Array2D& input) { + auto result = std::make_unique>(input.height(), input.width()); + for (int64_t rowno = 0; rowno < input.height(); ++rowno) { + for (int64_t colno = 0; colno < input.width(); ++colno) { + (*result)(rowno, colno) = static_cast(input(rowno, colno)); + } + } + return result; +} + +/* static */ std::unique_ptr> Array2DF8E4M3FNToF32( + const Array2D& input) { + auto result = std::make_unique>(input.height(), input.width()); + for (int64_t rowno = 0; rowno < input.height(); ++rowno) { + for (int64_t colno = 0; colno < input.width(); ++colno) { + (*result)(rowno, colno) = static_cast(input(rowno, colno)); + } + } + return result; +} + +/* static */ std::unique_ptr> Array2DF32ToF8E5M2( + const Array2D& input) { + auto result = std::make_unique>(input.height(), + input.width()); + for (int64_t rowno = 0; rowno < input.height(); ++rowno) { + for (int64_t colno = 0; colno < input.width(); ++colno) { + (*result)(rowno, colno) = + static_cast(input(rowno, colno)); + } + } + return result; +} + +/* static */ std::unique_ptr> Array2DF32ToF8E4M3FN( + const Array2D& input) { + auto result = std::make_unique>(input.height(), + input.width()); + for (int64_t rowno = 0; rowno < input.height(); ++rowno) { + for (int64_t colno = 0; colno < input.width(); ++colno) { + (*result)(rowno, colno) = + static_cast(input(rowno, colno)); + } + } + return result; +} + +static bool promote_f8_to_f32 = true; + +std::unique_ptr> HloEvaluator::MatmulArray2D( + const Array2D& lhs, + const Array2D& rhs) { + if (promote_f8_to_f32) { + auto lhs_float = Array2DF8E5M2ToF32(lhs); + auto rhs_float = Array2DF8E5M2ToF32(rhs); + auto result = MatmulArray2D(*lhs_float, *rhs_float); + return Array2DF32ToF8E5M2(*result); + } else { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF8E5M2); + } +} + +std::unique_ptr> HloEvaluator::MatmulArray2D( + const Array2D& lhs, + const Array2D& rhs) { + if (promote_f8_to_f32) { + auto lhs_float = Array2DF8E4M3FNToF32(lhs); + auto rhs_float = Array2DF8E4M3FNToF32(rhs); + auto result = MatmulArray2D(*lhs_float, *rhs_float); + return Array2DF32ToF8E4M3FN(*result); + } else { + return MatmulArray2DImpl( + lhs, rhs, __xla_cpu_runtime_EigenSingleThreadedMatMulF8E4M3FN); + } +} + } // namespace xla diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h index f26b633b20a3ae..cff76a0a09ff64 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h @@ -39,6 +39,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/array2d.h" +#include "xla/hlo/analysis/tuple_points_to_analysis.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -48,10 +49,10 @@ limitations under the License. #include "xla/service/call_graph.h" #include "xla/service/dynamic_dimension_inference.h" #include "xla/service/shape_inference.h" -#include "xla/service/tuple_points_to_analysis.h" #include "xla/shape_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/ml_dtypes.h" namespace xla { @@ -238,6 +239,12 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { const Array2D>& rhs); static std::unique_ptr> MatmulArray2D( const Array2D& lhs, const Array2D& rhs); + static std::unique_ptr> MatmulArray2D( + const Array2D& lhs, + const Array2D& rhs); + static std::unique_ptr> MatmulArray2D( + const Array2D& lhs, + const Array2D& rhs); static std::unique_ptr> MatmulArray2D( const Array2D& lhs, const Array2D& rhs); diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_slow_reduce_window_test.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_slow_reduce_window_test.cc index 80d3cc08b41981..5a9024d701d646 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_slow_reduce_window_test.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_slow_reduce_window_test.cc @@ -22,15 +22,15 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" namespace xla { namespace { -TEST_F(HloTestBase, SlowReduceWindow) { +TEST_F(HloHardwareIndependentTestBase, SlowReduceWindow) { constexpr absl::string_view kHloModule = R"( HloModule SlowReduceWindow %add { diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc index 56f5de4feb3237..b1136533947937 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc @@ -38,13 +38,16 @@ limitations under the License. #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/array4d.h" -#include "xla/client/xla_builder.h" #include "xla/comparison_util.h" #include "xla/debug_options_flags.h" #include "xla/error_spec.h" +#include "xla/hlo/analysis/tuple_points_to_analysis.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/transforms/simplifiers/hlo_element_type_converter.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/literal_util.h" @@ -52,14 +55,11 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/service/call_graph.h" #include "xla/service/dynamic_dimension_inference.h" -#include "xla/service/hlo_element_type_converter.h" #include "xla/service/hlo_module_config.h" #include "xla/service/shape_inference.h" -#include "xla/service/tuple_points_to_analysis.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_utils.h" #include "xla/types.h" @@ -78,7 +78,7 @@ static std::array use_bf16_params{true, false}; // Test fixture for the HloEvaluator. // // In bf16 mode, all f32 shapes are converted to bf16 before running. -class HloEvaluatorTest : public HloTestBase { +class HloEvaluatorTest : public HloHardwareIndependentTestBase { public: HloEvaluatorTest() : use_bfloat16_(false) { InitializeFftData(); } @@ -91,8 +91,9 @@ class HloEvaluatorTest : public HloTestBase { } // Evaluate function that takes in a local module instead of using m_ - // that is in HloTestBase. Once m_ in HloTestBase is - // removed, this should be the default Evaluate function. + // that is in HloHardwareIndependentTestBase. Once m_ in + // HloHardwareIndependentTestBase is removed, this should be the default + // Evaluate function. Literal EvaluateWithModule( HloModule* module, absl::Span arg_literals = {}) { if (use_bfloat16_) { @@ -2603,7 +2604,7 @@ ENTRY main { EXPECT_TRUE(LiteralTestUtil::Near(expected, result, fft_error_)); } -class HloEvaluatorPreciseReduceTest : public HloTestBase {}; +class HloEvaluatorPreciseReduceTest : public HloHardwareIndependentTestBase {}; // Tests that Reduce doesn't lose precision when adding many numbers (because // it accumulates its result in a double). @@ -5171,7 +5172,7 @@ TEST_F(HloEvaluatorTest, ParameterThroughCallSucceedsWithPrecomputation) { EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } -class PatternMatchParseWhileLoopTest : public HloTestBase {}; +class PatternMatchParseWhileLoopTest : public HloHardwareIndependentTestBase {}; TEST_F(PatternMatchParseWhileLoopTest, LoopBoundDefinedInsideOfCond) { constexpr absl::string_view kHloModule = R"( diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h index d7bc1ba49a9bf7..41cd753d987201 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h @@ -1129,20 +1129,12 @@ class HloEvaluatorTypedVisitor : public ConstDfsHloVisitorWithDefault { CHECK_EQ(dnums.lhs_batch_dimensions_size(), dnums.rhs_batch_dimensions_size()); - DimensionVector lhs_non_contracting_dims; - DimensionVector rhs_non_contracting_dims; - for (int64_t i = 0; i < lhs_rank; i++) { - if (!absl::c_linear_search(dnums.lhs_contracting_dimensions(), i) && - !absl::c_linear_search(dnums.lhs_batch_dimensions(), i)) { - lhs_non_contracting_dims.push_back(i); - } - } - for (int64_t i = 0; i < rhs_rank; i++) { - if (!absl::c_linear_search(dnums.rhs_contracting_dimensions(), i) && - !absl::c_linear_search(dnums.rhs_batch_dimensions(), i)) { - rhs_non_contracting_dims.push_back(i); - } - } + DimensionVector lhs_non_contracting_dims = + GetNonContractingDims(lhs_rank, dnums.lhs_contracting_dimensions(), + dnums.lhs_batch_dimensions()); + DimensionVector rhs_non_contracting_dims = + GetNonContractingDims(rhs_rank, dnums.rhs_contracting_dimensions(), + dnums.rhs_batch_dimensions()); DimensionVector contracting_dim_sizes; contracting_dim_sizes.reserve(dnums.lhs_contracting_dimensions_size()); @@ -1743,10 +1735,12 @@ extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc index 7c97c210aa36a5..d425d33c2feab5 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc @@ -19,8 +19,10 @@ limitations under the License. namespace xla { template class HloEvaluatorTypedVisitor; +template class HloEvaluatorTypedVisitor; template class HloEvaluatorTypedVisitor; template class HloEvaluatorTypedVisitor; template class HloEvaluatorTypedVisitor; template class HloEvaluatorTypedVisitor; +template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD index 9cf801548493a7..6a4c537069b84f 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD @@ -46,11 +46,16 @@ cc_library( "//xla:array", "//xla:shape_tree", "//xla:shape_util", + "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_alias_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/transforms:hlo_constant_splitter", + "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms:hlo_memory_scheduler", + "//xla/hlo/transforms:optimize_input_output_buffer_alias", "//xla/hlo/utils:hlo_live_range", "//xla/hlo/utils:hlo_sharding_util", "//xla/service:buffer_value", @@ -58,13 +63,9 @@ cc_library( "//xla/service:computation_layout", "//xla/service:dot_as_convolution_util", "//xla/service:dump", - "//xla/service:hlo_alias_analysis", "//xla/service:hlo_buffer", "//xla/service:hlo_cost_analysis", - "//xla/service:hlo_dce", - "//xla/service:hlo_memory_scheduler", "//xla/service:hlo_value", - "//xla/service:optimize_input_output_buffer_alias", "//xla/service:sharding_propagation", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", @@ -217,6 +218,7 @@ cc_library( "//xla/service:hlo_cost_analysis", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], ) @@ -227,7 +229,6 @@ cc_library( compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_cost_graph", - ":auto_sharding_device_mesh", ":auto_sharding_option", ":auto_sharding_strategy", ":auto_sharding_wrapper", @@ -236,6 +237,7 @@ cc_library( "//xla/service:hlo_cost_analysis", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], ) @@ -261,10 +263,12 @@ cc_library( ":auto_sharding_strategy", ":auto_sharding_util", ":profiling_result", - "//xla:array", "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/service/spmd:spmd_partitioner", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], @@ -291,6 +295,7 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:while_loop_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/ir:ptrvec", "//xla/hlo/ir:tile_assignment", @@ -298,7 +303,6 @@ cc_library( "//xla/service:call_graph", "//xla/service:computation_layout", "//xla/service:sharding_propagation", - "//xla/service:while_loop_analysis", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", @@ -330,7 +334,6 @@ xla_cc_binary( deps = [ ":auto_sharding", "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", "//xla/tools:hlo_module_loader", "@com_google_absl//absl/status", "@local_tsl//tsl/platform:platform_port", @@ -379,26 +382,28 @@ xla_cc_test( ":auto_sharding_option", ":auto_sharding_strategy", ":auto_sharding_util", + "//xla/hlo/analysis:hlo_alias_analysis", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/hlo/transforms:hlo_memory_scheduler", "//xla/hlo/utils:hlo_live_range", "//xla/hlo/utils:hlo_matchers", "//xla/service:buffer_value", - "//xla/service:hlo_alias_analysis", - "//xla/service:hlo_memory_scheduler", - "//xla/service:hlo_parser", "//xla/service:hlo_value", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", ], ) @@ -417,12 +422,13 @@ xla_cc_test( ":auto_sharding_proto_cc", ":auto_sharding_solver", # build_cleaner: keep ":auto_sharding_strategy", - "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", ] + if_google(["@com_google_ortools//ortools/linear_solver:linear_solver_scip"]), ) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index ec9dc5f0916c56..21f49fc8498e63 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -48,6 +48,7 @@ limitations under the License. #include "absl/time/clock.h" #include "absl/time/time.h" #include "absl/types/span.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_memory.h" @@ -68,27 +69,28 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/ir/hlo_sharding.h" -#include "xla/hlo/transforms/hlo_constant_splitter.h" +#include "xla/hlo/transforms/simplifiers/hlo_constant_splitter.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" +#include "xla/hlo/transforms/simplifiers/optimize_input_output_buffer_alias.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/hlo/utils/hlo_sharding_util.h" #include "xla/service/buffer_value.h" #include "xla/service/call_graph.h" #include "xla/service/computation_layout.h" #include "xla/service/dump.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" #include "xla/service/hlo_cost_analysis.h" -#include "xla/service/hlo_dce.h" -#include "xla/service/hlo_memory_scheduler.h" #include "xla/service/hlo_value.h" -#include "xla/service/optimize_input_output_buffer_alias.h" #include "xla/service/sharding_propagation.h" #include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" +#include "xla/status_macros.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace spmd { @@ -212,8 +214,8 @@ GenerateReshardingCostsAndMissingShardingsForAllOperands( const CallGraph& call_graph, InputShardings& input_shardings) { ReshardingCosts communication_resharding_costs; ReshardingCosts memory_resharding_costs; - if (input_shardings.empty() && ins->operand_count() > 0) { - input_shardings.resize(ins->operand_count()); + if (input_shardings.shardings.empty() && ins->operand_count() > 0) { + input_shardings.shardings.resize(ins->operand_count()); } for (int64_t k = 0; k < ins->operand_count(); ++k) { const HloInstruction* operand = ins->operand(k); @@ -224,14 +226,14 @@ GenerateReshardingCostsAndMissingShardingsForAllOperands( if (operand_shape.IsToken() || operand_shape.rank() == 0) { communication_resharding_costs.push_back(zeros); memory_resharding_costs.push_back(zeros); - if (!input_shardings[k].has_value()) { - input_shardings[k] = HloSharding::Replicate(); + if (!input_shardings.shardings[k].has_value()) { + input_shardings.shardings[k] = HloSharding::Replicate(); } } else { std::optional cur_input_sharding; - CHECK_EQ(input_shardings.size(), ins->operand_count()); - if (input_shardings[k].has_value()) { - cur_input_sharding = input_shardings[k]; + CHECK_EQ(input_shardings.shardings.size(), ins->operand_count()); + if (input_shardings.shardings[k].has_value()) { + cur_input_sharding = input_shardings.shardings[k]; } else { cur_input_sharding = GetInputSharding( ins, k, output_sharding, call_graph, cluster_env.NumDevices()); @@ -250,8 +252,8 @@ GenerateReshardingCostsAndMissingShardingsForAllOperands( } } CHECK(cur_input_sharding.has_value()); - if (!input_shardings[k].has_value()) { - input_shardings[k] = cur_input_sharding; + if (!input_shardings.shardings[k].has_value()) { + input_shardings.shardings[k] = cur_input_sharding; } if (ins->opcode() == HloOpcode::kGather && k == 0 && is_sharding_default_replicated) { @@ -259,7 +261,7 @@ GenerateReshardingCostsAndMissingShardingsForAllOperands( << output_sharding.ToString(); communication_resharding_costs.push_back(zeros); memory_resharding_costs.push_back(zeros); - input_shardings[k] = std::nullopt; + input_shardings.shardings[k] = std::nullopt; } else { communication_resharding_costs.push_back( CommunicationReshardingCostVector( @@ -275,8 +277,7 @@ GenerateReshardingCostsAndMissingShardingsForAllOperands( memory_resharding_costs); } -std::tuple>> +std::tuple GenerateReshardingCostsAndShardingsForAllOperands( const HloInstruction* ins, const HloSharding& output_sharding, const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, @@ -286,7 +287,7 @@ GenerateReshardingCostsAndShardingsForAllOperands( GenerateReshardingCostsAndMissingShardingsForAllOperands( ins, output_sharding, strategy_map, cluster_env, call_graph, input_shardings_optional); - for (const auto& sharding_optional : input_shardings_optional) { + for (const auto& sharding_optional : input_shardings_optional.shardings) { CHECK(sharding_optional.has_value()); } @@ -333,7 +334,7 @@ void FollowArrayOrTokenStrategyGroup( double compute_cost = 0, communication_cost = 0; double memory_cost = ByteSizeOfShapeWithSharding(shape, *output_spec); size_t num_in_nodes = strategy_group.in_nodes.size(); - InputShardings input_shardings(num_in_nodes, *output_spec); + InputShardings input_shardings{name, {num_in_nodes, *output_spec}}; ReshardingCosts communication_resharding_costs; ReshardingCosts memory_resharding_costs; for (size_t i = 0; i < strategy_group.in_nodes.size(); ++i) { @@ -345,7 +346,7 @@ void FollowArrayOrTokenStrategyGroup( } strategy_group.AddStrategy( - ShardingStrategy({name, *output_spec, compute_cost, communication_cost, + ShardingStrategy({*output_spec, compute_cost, communication_cost, memory_cost, communication_resharding_costs, memory_resharding_costs}), input_shardings); @@ -386,16 +387,16 @@ std::unique_ptr HandlePartialReduce( } // Get a list of input shardings, each corresponds to an operand. - InputShardings input_shardings; + std::string name = ToStringSimple(output_spec); + InputShardings input_shardings = {std::move(name)}; for (int64_t k = 0; k < output_size * 2; ++k) { if (k < output_size) { - input_shardings.push_back(input_spec); + input_shardings.shardings.push_back(input_spec); } else { - input_shardings.push_back(HloSharding::Replicate()); + input_shardings.shardings.push_back(HloSharding::Replicate()); } } - std::string name = ToStringSimple(output_spec); double compute_cost = 0, communication_cost = 0; double memory_cost = ByteSizeOfShapeWithSharding( ins->shape().tuple_shapes(i), output_spec); @@ -405,8 +406,8 @@ std::unique_ptr HandlePartialReduce( input_shardings); child_strategy_group->AddStrategy( - ShardingStrategy({std::move(name), std::move(output_spec), - compute_cost, communication_cost, memory_cost, + ShardingStrategy({std::move(output_spec), compute_cost, + communication_cost, memory_cost, std::move(resharding_costs.first), std::move(resharding_costs.second)}), std::move(input_shardings)); @@ -553,9 +554,9 @@ absl::StatusOr> FollowReduceStrategy( } } const ShardingStrategy strategy = ShardingStrategy( - {name, output_spec, compute_cost, communication_cost, memory_cost, + {output_spec, compute_cost, communication_cost, memory_cost, communication_resharding_costs, memory_resharding_costs}); - strategy_group->AddStrategy(strategy, {input_sharding}); + strategy_group->AddStrategy(strategy, {name, {input_sharding}}); } } else { LOG(FATAL) << "Unhandled kReduce shape: " << ins->shape().ToString(); @@ -574,8 +575,7 @@ std::vector FindReplicateStrategyIndices( return indices; } -std::tuple>> +std::tuple ReshardingCostsForTupleOperand(const HloInstruction* operand, const StrategyGroup& operand_strategy_vector) { // TODO(yuemmawang) Support instructions with more than one tuple operand. @@ -606,9 +606,10 @@ ReshardingCostsForTupleOperand(const HloInstruction* operand, communication_resharding_costs.back().at(i) = 0.0; } } - return {communication_resharding_costs, memory_resharding_costs, - std::vector>( - {HloSharding::Tuple(operand->shape(), tuple_element_shardings)})}; + return { + communication_resharding_costs, + memory_resharding_costs, + {{}, {HloSharding::Tuple(operand->shape(), tuple_element_shardings)}}}; } ReshardingCosts CreateZeroReshardingCostsForAllOperands( @@ -650,7 +651,7 @@ void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape, HloSharding output_spec = HloSharding::Replicate(); ReshardingCosts communication_resharding_costs; ReshardingCosts memory_resharding_costs; - InputShardings input_shardings; + InputShardings input_shardings = {"R"}; const int tuple_size = ins->operand(0)->shape().tuple_shapes_size(); const auto& operand_strategy_group = strategy_map.at(ins->operand(0)); @@ -677,7 +678,7 @@ void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape, const StrategyGroup& child = *operand_children[i]; const Shape& tuple_shape = ins->operand(0)->shape().tuple_shapes(i); const HloSharding& input_sharding = get_input_sharding(i); - input_shardings.push_back(input_sharding); + input_shardings.shardings.push_back(input_sharding); communication_resharding_costs.push_back( CommunicationReshardingCostVector(child, tuple_shape, input_sharding, cluster_env)); @@ -685,7 +686,7 @@ void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape, child, tuple_shape, input_sharding, cluster_env)); } const HloSharding& input_sharding = get_input_sharding(-1); - input_shardings.push_back(input_sharding); + input_shardings.shardings.push_back(input_sharding); } else { for (size_t i = 0; i < tuple_size; ++i) { const StrategyGroup& child = *operand_children[i]; @@ -698,20 +699,19 @@ void GenerateOutfeedStrategy(const HloInstruction* ins, const Shape& shape, memory_resharding_costs.push_back({}); double memory_cost = ByteSizeOfShapeWithSharding(shape, output_spec); strategy_group.AddStrategy( - ShardingStrategy({"R", HloSharding::Replicate(), replicated_penalty, 0, + ShardingStrategy({HloSharding::Replicate(), replicated_penalty, 0, memory_cost, std::move(communication_resharding_costs), std::move(memory_resharding_costs)}), input_shardings); } -double ComputeCommunicationCost( - const HloInstruction* ins, - const std::vector>& operand_shardings, - const ClusterEnvironment& cluster_env) { +double ComputeCommunicationCost(const HloInstruction* ins, + const InputShardings& operand_shardings, + const ClusterEnvironment& cluster_env) { switch (ins->opcode()) { case HloOpcode::kGather: { - if (operand_shardings[0].has_value() && - !operand_shardings[0]->IsReplicated()) { + if (operand_shardings.shardings[0].has_value() && + !operand_shardings.shardings[0]->IsReplicated()) { auto mesh_shape = cluster_env.device_mesh_.dimensions(); auto mesh_dim = std::distance( mesh_shape.begin(), @@ -761,9 +761,10 @@ void AddReplicatedStrategy( CHECK(!operand->shape().IsTuple()); const auto& operand_strategy_group = strategy_map.at(operand).get(); const auto& operand_strategies = operand_strategy_group->GetStrategies(); + InputShardings input_shardings = {"R"}; + input_shardings.shardings.resize(ins->operand_count()); std::vector possible_input_shardings( - operand_strategies.size(), - std::vector>(ins->operand_count())); + operand_strategies.size(), input_shardings); std::vector possible_communication_resharding_costs( operand_strategies.size(), ReshardingCosts(ins->operand_count())); std::vector possible_memory_resharding_costs( @@ -778,7 +779,7 @@ void AddReplicatedStrategy( CHECK_EQ(possible_input_shardings.size(), operand_strategies.size()); for (size_t j = 0; j < possible_input_shardings.size(); ++j) { const auto& operand_sharding = operand_strategies[j].output_sharding; - possible_input_shardings[j][k] = operand_sharding; + possible_input_shardings[j].shardings[k] = operand_sharding; possible_communication_resharding_costs[j][k] = CommunicationReshardingCostVector(operand_strategy_group, operand_shape, operand_sharding, @@ -789,7 +790,7 @@ void AddReplicatedStrategy( } } else { for (size_t j = 0; j < possible_input_shardings.size(); ++j) { - possible_input_shardings[j][k] = replicated_strategy; + possible_input_shardings[j].shardings[k] = replicated_strategy; possible_communication_resharding_costs[j][k] = CommunicationReshardingCostVector( operand_strategy_group, operand_shape, replicated_strategy, @@ -806,7 +807,7 @@ void AddReplicatedStrategy( ins, possible_input_shardings[j], cluster_env); strategy_group.AddStrategy( ShardingStrategy( - {"R", replicated_strategy, replicated_penalty, communication_cost, + {replicated_strategy, replicated_penalty, communication_cost, memory_cost, std::move(possible_communication_resharding_costs[j]), std::move(possible_memory_resharding_costs[j])}), @@ -815,7 +816,7 @@ void AddReplicatedStrategy( } else { ReshardingCosts communication_resharding_costs; ReshardingCosts memory_resharding_costs; - InputShardings input_shardings; + InputShardings input_shardings = {"R"}; if (ins->operand_count() > 0 && ins->operand(0)->shape().IsTuple()) { CHECK_EQ(ins->operand_count(), 1) @@ -843,12 +844,12 @@ void AddReplicatedStrategy( cluster_env)); memory_resharding_costs.push_back(MemoryReshardingCostVector( operand_strategy_group, operand_shape, output_spec, cluster_env)); - input_shardings.push_back(output_spec); + input_shardings.shardings.push_back(output_spec); } } } strategy_group.AddStrategy( - ShardingStrategy({"R", HloSharding::Replicate(), replicated_penalty, 0, + ShardingStrategy({HloSharding::Replicate(), replicated_penalty, 0, memory_cost, std::move(communication_resharding_costs), std::move(memory_resharding_costs)}), @@ -897,7 +898,7 @@ void EnumerateAll1DPartition( ReshardingCosts communication_resharding_costs; ReshardingCosts memory_resharding_costs; - InputShardings input_shardings; + InputShardings input_shardings = {name}; if (ins->opcode() == HloOpcode::kConditional) { // TODO(pratikf): Compute input_shardings for kConditional ops communication_resharding_costs = @@ -915,7 +916,7 @@ void EnumerateAll1DPartition( *strategy_map.at(ins->operand(0))); } else if (ins->opcode() == HloOpcode::kRngBitGenerator && ins->operand(0)->shape().IsArray()) { - input_shardings.push_back(HloSharding::Replicate()); + input_shardings.shardings.push_back(HloSharding::Replicate()); std::tie(communication_resharding_costs, memory_resharding_costs) = GenerateReshardingCostsAndMissingShardingsForAllOperands( ins, output_spec, strategy_map, cluster_env, call_graph, @@ -939,7 +940,7 @@ void EnumerateAll1DPartition( ins->operand(0)->shape().rank() - 1, i, j, shape, cluster_env); } strategy_group.AddStrategy( - ShardingStrategy({name, output_spec, compute_cost, communication_cost, + ShardingStrategy({output_spec, compute_cost, communication_cost, memory_cost, std::move(communication_resharding_costs), std::move(memory_resharding_costs)}), @@ -959,8 +960,7 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape, void EnumerateAllPartition( const HloInstruction* ins, const Shape& shape, const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, - const StrategyMap& strategy_map, - const InstructionBatchDimMap& batch_dim_map, bool only_allow_divisible, + const StrategyMap& strategy_map, bool only_allow_divisible, bool allow_shardings_small_dims_across_many_devices, const CallGraph& call_graph, const int64_t partition_dimensions, const std::vector& tensor_dims, StrategyGroup& strategy_group) { @@ -971,15 +971,10 @@ void EnumerateAllPartition( strategy_group); return; } - auto iter = batch_dim_map.find(GetBatchDimMapKey(ins)); - int64_t batch_dim = -1; - if (iter != batch_dim_map.end()) { - batch_dim = iter->second; - } // Fully tile the buffer to the mesh for (int64_t i = 0; i < shape.rank(); ++i) { auto tensor_it = std::find(tensor_dims.begin(), tensor_dims.end(), i); - if ((batch_dim != -1 && batch_dim != i) || tensor_it != tensor_dims.end()) { + if (tensor_it != tensor_dims.end()) { continue; } if (!allow_shardings_small_dims_across_many_devices && @@ -993,7 +988,7 @@ void EnumerateAllPartition( std::vector next_tensor_dims = tensor_dims; next_tensor_dims.push_back(i); EnumerateAllPartition( - ins, shape, device_mesh, cluster_env, strategy_map, batch_dim_map, + ins, shape, device_mesh, cluster_env, strategy_map, only_allow_divisible, allow_shardings_small_dims_across_many_devices, call_graph, partition_dimensions, next_tensor_dims, strategy_group); } @@ -1014,7 +1009,7 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape, HloSharding output_spec = Tile(shape, tensor_dims, mesh_dims, device_mesh); double compute_cost = 0, communication_cost = 0; double memory_cost = ByteSizeOfShapeWithSharding(shape, output_spec); - InputShardings input_shardings; + InputShardings input_shardings = {name}; ReshardingCosts communication_resharding_costs; ReshardingCosts memory_resharding_costs; if (ins->opcode() == HloOpcode::kConditional) { @@ -1057,7 +1052,7 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape, } strategy_group.AddStrategy( - ShardingStrategy({name, output_spec, compute_cost, communication_cost, + ShardingStrategy({output_spec, compute_cost, communication_cost, memory_cost, std::move(communication_resharding_costs), std::move(memory_resharding_costs)}), input_shardings); @@ -1107,99 +1102,13 @@ void EnumerateAll1DPartitionReshape(const HloInstruction* ins, ReshardingCosts memory_resharding_costs{MemoryReshardingCostVector( operand_strategy_group, operand_shape, *input_spec, cluster_env)}; strategy_group.AddStrategy( - ShardingStrategy({name, output_spec, compute_cost, communication_cost, + ShardingStrategy({output_spec, compute_cost, communication_cost, memory_cost, std::move(communication_resharding_costs), std::move(memory_resharding_costs)}), - {*input_spec}); - } - } -} - -void BuildStrategyAndCostForReshape(const HloInstruction* ins, - const DeviceMesh& device_mesh, - const ClusterEnvironment& cluster_env, - const StrategyMap& strategy_map, - absl::Span tensor_dims, - StrategyGroup& strategy_group); - -// Enumerate all partitions for reshape. Batch dim is always partitioned. -void EnumeratePartitionReshape( - const HloInstruction* ins, const DeviceMesh& device_mesh, - const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, - const InstructionBatchDimMap& batch_dim_map, - const bool only_allow_divisible, const int64_t partition_dimensions, - const std::vector& tensor_dims, StrategyGroup& strategy_group) { - const auto tensor_dims_size = tensor_dims.size(); - if (tensor_dims_size == partition_dimensions) { - BuildStrategyAndCostForReshape(ins, device_mesh, cluster_env, strategy_map, - tensor_dims, strategy_group); - return; - } - auto iter = batch_dim_map.find(GetBatchDimMapKey(ins)); - int64_t batch_dim = -1; - if (iter != batch_dim_map.end()) { - batch_dim = iter->second; - } - - // Split batch dim + another dim - for (int64_t i = 0; i < ins->shape().rank(); ++i) { - auto tensor_it = std::find(tensor_dims.begin(), tensor_dims.end(), i); - if ((batch_dim != -1 && batch_dim != i) || tensor_it != tensor_dims.end()) { - continue; - } - if (ins->shape().dimensions(i) < device_mesh.dim(tensor_dims_size)) { - continue; - } - if (only_allow_divisible && - !IsDivisible(ins->shape().dimensions(i), - device_mesh.dim(tensor_dims_size))) { - continue; + {name, {*input_spec}}); } - - std::vector next_tensor_dims = tensor_dims; - next_tensor_dims.push_back(i); - EnumeratePartitionReshape(ins, device_mesh, cluster_env, strategy_map, - batch_dim_map, only_allow_divisible, - partition_dimensions, next_tensor_dims, - strategy_group); - } -} - -void BuildStrategyAndCostForReshape(const HloInstruction* ins, - const DeviceMesh& device_mesh, - const ClusterEnvironment& cluster_env, - const StrategyMap& strategy_map, - absl::Span tensor_dims, - StrategyGroup& strategy_group) { - const HloInstruction* operand = ins->operand(0); - const Shape& operand_shape = operand->shape(); - const StrategyGroup& operand_strategy_group = *strategy_map.at(operand); - std::vector mesh_dims(tensor_dims.size()); - std::iota(mesh_dims.begin(), mesh_dims.end(), 0); - const HloSharding output_spec = - Tile(ins->shape(), tensor_dims, mesh_dims, device_mesh); - std::optional input_spec = hlo_sharding_util::ReshapeSharding( - ins->shape(), operand_shape, output_spec); - if (!input_spec.has_value()) { // invalid reshape - return; } - std::string name = - absl::StrFormat("S%s @ {%s}", absl::StrJoin(tensor_dims, ""), - absl::StrJoin(mesh_dims, ",")); - double compute_cost = 0, communication_cost = 0; - double memory_cost = ByteSizeOfShapeWithSharding(ins->shape(), output_spec); - - ReshardingCosts communication_resharding_costs{ - CommunicationReshardingCostVector(operand_strategy_group, operand_shape, - *input_spec, cluster_env)}; - ReshardingCosts memory_resharding_costs{MemoryReshardingCostVector( - operand_strategy_group, operand_shape, *input_spec, cluster_env)}; - strategy_group.AddStrategy( - ShardingStrategy({name, output_spec, compute_cost, communication_cost, - memory_cost, std::move(communication_resharding_costs), - std::move(memory_resharding_costs)}), - {*input_spec}); } // Return the maximum number of tiles among all strategies of an instruction. @@ -1290,40 +1199,12 @@ bool AllowTieFollowing(const HloInstruction* ins) { return true; } -// 1. Disable mixed mesh shape if the batch dim is not divisible by the -// number of devices. -// 2. Disable force_batch_dim_to_mesh_dim if the batch dim is 1. In this case, -// the batch dim analysis can be wrong because the batch dim might be dropped. -void DisableIncompatibleMixedMeshShapeAndForceBatchDim( - const InstructionBatchDimMap& batch_dim_map, - const std::vector& instructions, int num_devices, - AutoShardingOption& option) { - int64_t batch_size = INT_MAX; - for (const auto& iter : batch_dim_map) { - batch_size = std::min(batch_size, FindInstruction(instructions, iter.first) - ->shape() - .dimensions(iter.second)); - } - - if (IsDivisible(batch_size, num_devices)) { - if (option.allow_mixed_mesh_shape) { - option.allow_mixed_mesh_shape = false; - LOG(WARNING) - << "Mixed mesh shape is disabled due to indivisible batch size."; - } - } - - if (batch_size == 1) { - option.force_batch_dim_to_mesh_dim = -1; - } -} - void FillAllStrategiesForArray( const HloInstruction* ins, const Shape& shape, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, const AutoShardingOption& option, const double replicated_penalty, - const InstructionBatchDimMap& batch_dim_map, const CallGraph& call_graph, - const bool only_allow_divisible, const bool create_replicated_strategies, + const CallGraph& call_graph, const bool only_allow_divisible, + const bool create_replicated_strategies, const bool create_partially_replicated_strategies, StrategyGroup& strategy_group) { if (create_partially_replicated_strategies || cluster_env.IsDeviceMesh1D()) { @@ -1336,7 +1217,7 @@ void FillAllStrategiesForArray( // Split 2 dims if (cluster_env.IsDeviceMesh2D()) { EnumerateAllPartition(ins, shape, cluster_env.device_mesh_, cluster_env, - strategy_map, batch_dim_map, only_allow_divisible, + strategy_map, only_allow_divisible, option.allow_shardings_small_dims_across_many_devices, call_graph, /*partitions*/ 2, /*tensor_dims*/ {}, strategy_group); @@ -1344,7 +1225,7 @@ void FillAllStrategiesForArray( // Split 3 dims if (cluster_env.IsDeviceMesh3D()) { EnumerateAllPartition(ins, shape, cluster_env.device_mesh_, cluster_env, - strategy_map, batch_dim_map, only_allow_divisible, + strategy_map, only_allow_divisible, option.allow_shardings_small_dims_across_many_devices, call_graph, /*partitions*/ 3, /*tensor_dims*/ {}, strategy_group); @@ -1367,22 +1248,13 @@ void FillAllStrategiesForArray( AddReplicatedStrategy(ins, shape, cluster_env, strategy_map, replicated_penalty, {}, strategy_group); } - - // If force_batch_dim_to_mesh_dim is set, filter out invalid strategies - // and only keep the data parallel strategies. - if (option.force_batch_dim_to_mesh_dim >= 0 && - batch_dim_map.contains(GetBatchDimMapKey(ins))) { - CHECK_OK(FilterStrategy(ins, shape, cluster_env, batch_dim_map, option, - strategy_group)); - } } absl::StatusOr> CreateAllStrategiesGroup( const HloInstruction* ins, const Shape& shape, const size_t instruction_id, StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, const AutoShardingOption& option, - const double replicated_penalty, - const InstructionBatchDimMap& batch_dim_map, const CallGraph& call_graph, + const double replicated_penalty, const CallGraph& call_graph, const bool only_allow_divisible, const bool create_replicated_strategies, const bool create_partially_replicated_strategies) { std::unique_ptr strategy_group; @@ -1390,12 +1262,11 @@ absl::StatusOr> CreateAllStrategiesGroup( strategy_group = CreateTupleStrategyGroup(instruction_id); for (size_t i = 0; i < shape.tuple_shapes_size(); ++i) { auto child_strategies = - CreateAllStrategiesGroup(ins, shape.tuple_shapes(i), instruction_id, - strategy_groups, cluster_env, strategy_map, - option, replicated_penalty, batch_dim_map, - call_graph, only_allow_divisible, - create_replicated_strategies, - create_partially_replicated_strategies) + CreateAllStrategiesGroup( + ins, shape.tuple_shapes(i), instruction_id, strategy_groups, + cluster_env, strategy_map, option, replicated_penalty, call_graph, + only_allow_divisible, create_replicated_strategies, + create_partially_replicated_strategies) .value(); child_strategies->tuple_element_idx = i; strategy_group->AddChild(std::move(child_strategies)); @@ -1405,9 +1276,8 @@ absl::StatusOr> CreateAllStrategiesGroup( strategy_groups); FillAllStrategiesForArray( ins, shape, cluster_env, strategy_map, option, replicated_penalty, - batch_dim_map, call_graph, only_allow_divisible, - create_replicated_strategies, create_partially_replicated_strategies, - *strategy_group); + call_graph, only_allow_divisible, create_replicated_strategies, + create_partially_replicated_strategies, *strategy_group); } else if (shape.IsToken()) { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, strategy_groups); @@ -1501,7 +1371,7 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( std::string name = ToStringSimple(existing_sharding); ReshardingCosts communication_resharding_costs; ReshardingCosts memory_resharding_costs; - InputShardings input_shardings; + InputShardings input_shardings = {name}; if (!strategy_group.in_nodes.empty()) { HloInstruction* ins = instructions.at(strategy_group.instruction_id); for (size_t i = 0; i < strategy_group.in_nodes.size(); i++) { @@ -1534,7 +1404,7 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( CHECK(input_sharding.has_value()); - input_shardings.push_back(*input_sharding); + input_shardings.shardings.push_back(*input_sharding); communication_resharding_costs.push_back( CommunicationReshardingCostVector( *operand_strategy_group, operand_shape, *input_sharding, @@ -1552,7 +1422,7 @@ void TrimOrGenerateStrategiesBasedOnExistingSharding( } strategy_group.ClearStrategies(); strategy_group.AddStrategy( - ShardingStrategy({name, existing_sharding, 0, 0, memory_cost, + ShardingStrategy({existing_sharding, 0, 0, memory_cost, communication_resharding_costs, memory_resharding_costs}), input_shardings); @@ -1681,24 +1551,26 @@ void RemoveShardingsWhereSmallDimsShardedAcrossManyDevices( void ScaleCostsWithExecutionCounts(const int64_t execution_count, StrategyGroup& strategy_group) { - if (strategy_group.is_tuple) { - for (const auto& child : strategy_group.GetChildren()) { - ScaleCostsWithExecutionCounts(execution_count, *child); + auto scale_cost = [&execution_count](double& cost) { + if (cost < kInfinityCost - 1) { + cost *= execution_count; } - } else { - for (size_t sid = 0; sid < strategy_group.GetStrategies().size(); ++sid) { - ShardingStrategy& strategy = strategy_group.GetStrategy(sid); - strategy.compute_cost *= execution_count; - strategy.communication_cost *= execution_count; - for (auto i = 0; i < strategy.communication_resharding_costs.size(); - ++i) { - for (auto j = 0; j < strategy.communication_resharding_costs[i].size(); + }; + auto scale_for_leaf = [&](StrategyGroup& leaf_strategy_group) { + for (int sid = 0; sid < leaf_strategy_group.GetStrategies().size(); ++sid) { + ShardingStrategy& strategy = leaf_strategy_group.GetStrategy(sid); + scale_cost(strategy.compute_cost); + scale_cost(strategy.communication_cost); + for (int i = 0; i < strategy.communication_resharding_costs.size(); ++i) { + for (int j = 0; j < strategy.communication_resharding_costs[i].size(); ++j) { - strategy.communication_resharding_costs[i][j] *= execution_count; + scale_cost(strategy.communication_resharding_costs[i][j]); } } } - } + }; + + strategy_group.ForEachLeafStrategyGroup(scale_for_leaf); } std::unique_ptr CreateElementwiseOperatorStrategies( @@ -1783,7 +1655,7 @@ std::unique_ptr HandleManuallyShardedInstruction( strategy_groups); ReshardingCosts communication_resharding_costs; ReshardingCosts memory_resharding_costs; - InputShardings input_shardings; + InputShardings input_shardings = {"MANUAL"}; if (ins->operand_count() > 0 && ins->operand(0)->shape().IsTuple()) { CHECK_EQ(ins->operand_count(), 1) @@ -1805,7 +1677,7 @@ std::unique_ptr HandleManuallyShardedInstruction( } } strategy_group->AddStrategy( - ShardingStrategy({"MANUAL", HloSharding::Replicate(), 0, 0, + ShardingStrategy({HloSharding::Replicate(), 0, 0, static_cast(ShapeUtil::ByteSizeOf(shape)), std::move(communication_resharding_costs), std::move(memory_resharding_costs)}), @@ -1820,60 +1692,49 @@ std::unique_ptr CreateReshapeStrategies( const size_t instruction_id, const HloInstruction* ins, const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, const bool only_allow_divisible, const double replicated_penalty, - const InstructionBatchDimMap& batch_dim_map, const AutoShardingOption& option, StrategyGroups& strategy_groups, const CallGraph& call_graph) { - const DeviceMesh& device_mesh = cluster_env.device_mesh_; - - int mesh_nn_dims = VectorGreaterThanOneElementCount(device_mesh.dimensions()); std::unique_ptr strategy_group = CreateLeafStrategyGroup( instruction_id, ins, strategy_map, strategy_groups); - if (mesh_nn_dims < 2 || !option.allow_mixed_mesh_shape) { - const HloInstruction* operand = ins->operand(0); - - // Create follow strategies - const StrategyGroup& src_strategy_group = *strategy_map.at(operand); - CHECK(!src_strategy_group.is_tuple); - strategy_group->following = &src_strategy_group; - - for (const auto& src_strategy : src_strategy_group.GetStrategies()) { - std::optional output_spec = - hlo_sharding_util::ReshapeSharding(operand->shape(), ins->shape(), - src_strategy.output_sharding); - - if (!output_spec.has_value()) { - continue; - } - - if (!IsValidTileAssignment(*output_spec)) { - continue; - } - - if (!TileAssignmentMatchesMesh(*output_spec, device_mesh)) { - continue; - } - const std::string name = ToStringSimple(*output_spec); - double compute_cost = 0, communication_cost = 0; - double memory_cost = - ByteSizeOfShapeWithSharding(ins->shape(), output_spec); - std::vector communication_resharding_costs = - CommunicationReshardingCostVector( - src_strategy_group, operand->shape(), - src_strategy.output_sharding, cluster_env); - std::vector memory_resharding_costs = - MemoryReshardingCostVector(src_strategy_group, operand->shape(), - src_strategy.output_sharding, cluster_env); - strategy_group->AddStrategy( - ShardingStrategy({name, - *output_spec, - compute_cost, - communication_cost, - memory_cost, - {communication_resharding_costs}, - {memory_resharding_costs}}), - {src_strategy.output_sharding}); + // Create strategies from operands, but do not follow the operand. We + // anecdotally observe that following the operands causes regressions. + const HloInstruction* operand = ins->operand(0); + const StrategyGroup& operand_strategy_group = *strategy_map.at(operand); + CHECK(!operand_strategy_group.is_tuple); + + for (const ShardingStrategy& operand_strategy : + operand_strategy_group.GetStrategies()) { + std::optional output_sharding = + hlo_sharding_util::ReshapeSharding(operand->shape(), ins->shape(), + operand_strategy.output_sharding); + + if (!output_sharding.has_value() || + !IsValidTileAssignment(*output_sharding) || + !TileAssignmentMatchesMesh(*output_sharding, + cluster_env.device_mesh_)) { + continue; } + + const std::string name = ToStringSimple(*output_sharding); + double compute_cost = 0, communication_cost = 0; + double memory_cost = + ByteSizeOfShapeWithSharding(ins->shape(), output_sharding); + std::vector communication_resharding_costs = + CommunicationReshardingCostVector( + operand_strategy_group, operand->shape(), + operand_strategy.output_sharding, cluster_env); + std::vector memory_resharding_costs = MemoryReshardingCostVector( + operand_strategy_group, operand->shape(), + operand_strategy.output_sharding, cluster_env); + strategy_group->AddStrategy( + ShardingStrategy({*output_sharding, + compute_cost, + communication_cost, + memory_cost, + {communication_resharding_costs}, + {memory_resharding_costs}}), + {name, {operand_strategy.output_sharding}}); } if (strategy_group->GetStrategies().empty()) { @@ -1881,15 +1742,15 @@ std::unique_ptr CreateReshapeStrategies( VLOG(2) << "Enumerating all strategies for reshape"; FillAllStrategiesForArray( ins, ins->shape(), cluster_env, strategy_map, option, - replicated_penalty, batch_dim_map, call_graph, only_allow_divisible, - /* create_replicated_strategies */ true, - /* create_partially_replicated_strategies */ true, *strategy_group); + replicated_penalty, call_graph, only_allow_divisible, + /*create_replicated_strategies=*/true, + /*create_partially_replicated_strategies=*/true, *strategy_group); } - return strategy_group; } -AutoShardingSolverResult CallSolver( +absl::StatusOr +CreateAutoShardingSolverRequestAndCallSolver( const HloModule& hlo_module, const HloLiveRange& hlo_live_range, const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, @@ -2111,7 +1972,7 @@ AutoShardingSolverResult CallSolver( PopulateTemporalValues(cost_graph, request); - return CallORToolsSolver(request); + return FormulateAndSolveMIPFromSolverRequest(request); } void CheckHloSharding( @@ -2301,8 +2162,6 @@ absl::Status InsertReshardReshapes( // spmd partitioner generate correct code. if (inst->opcode() == HloOpcode::kDot || inst->opcode() == HloOpcode::kConvolution) { - const ShardingStrategy& stra = - GetShardingStrategy(inst, strategy_map, cost_graph, s_val); const HloInstruction* lhs = inst->operand(0); const HloInstruction* rhs = inst->operand(1); const HloSharding& lhs_sharding = lhs->sharding(); @@ -2335,7 +2194,9 @@ absl::Status InsertReshardReshapes( "Cannot generate tensor dim to mesh dim mapping"); } - if (absl::StrContains(stra.name, "allreduce") && + const InputShardings& input_shardings = + GetInputShardings(inst, strategy_map, cost_graph, s_val); + if (absl::StrContains(input_shardings.name, "allreduce") && std::any_of(lhs_con_dims.begin(), lhs_con_dims.end(), [&lhs_tensor_dim_to_mesh_dim](int64_t dim) { return lhs_tensor_dim_to_mesh_dim[dim] == -1; @@ -2347,19 +2208,20 @@ absl::Status InsertReshardReshapes( // Allow duplicated dot computation in this case to reduce // communication } else { - const InputShardings& input_shardings = - GetInputShardings(inst, strategy_map, cost_graph, s_val); - CHECK(input_shardings.size() == 2) + CHECK(input_shardings.shardings.size() == 2) << "Dot op requires both operands to have input shardings, " "but get instruction: " - << inst->ToString() << ", strategy : " << stra.ToString(); - if (input_shardings[0].has_value()) { + << inst->ToString() + << ", input shardings : " << input_shardings.ToString(); + if (input_shardings.shardings[0].has_value()) { TF_RETURN_IF_ERROR(FixMixedMeshShapeResharding( - inst, 0, *input_shardings[0], device_mesh, resharding_cache)); + inst, 0, *input_shardings.shardings[0], device_mesh, + resharding_cache)); } - if (input_shardings[1].has_value()) { + if (input_shardings.shardings[1].has_value()) { TF_RETURN_IF_ERROR(FixMixedMeshShapeResharding( - inst, 1, *input_shardings[1], device_mesh, resharding_cache)); + inst, 1, *input_shardings.shardings[1], device_mesh, + resharding_cache)); } } } @@ -2393,11 +2255,11 @@ absl::Status InsertReshardReshapes( const InputShardings& input_shardings = GetInputShardingsForTuple(inst, {static_cast(i)}, strategy_map, cost_graph, s_val); - if (input_shardings.size() > i && - input_shardings[i].has_value()) { - TF_RETURN_IF_ERROR( - FixMixedMeshShapeResharding(inst, i, *input_shardings[i], - device_mesh, resharding_cache)); + if (input_shardings.shardings.size() > i && + input_shardings.shardings[i].has_value()) { + TF_RETURN_IF_ERROR(FixMixedMeshShapeResharding( + inst, i, *input_shardings.shardings[i], device_mesh, + resharding_cache)); } } break; @@ -2407,10 +2269,11 @@ absl::Status InsertReshardReshapes( const InputShardings& input_shardings = GetInputShardingsForTuple(inst, {static_cast(i)}, strategy_map, cost_graph, s_val); - CHECK_EQ(input_shardings.size(), 1); - CHECK(input_shardings[0].has_value()); + CHECK_EQ(input_shardings.shardings.size(), 1); + CHECK(input_shardings.shardings[0].has_value()); TF_RETURN_IF_ERROR(FixMixedMeshShapeResharding( - inst, i, *input_shardings[0], device_mesh, resharding_cache)); + inst, i, *input_shardings.shardings[0], device_mesh, + resharding_cache)); } break; } @@ -2424,8 +2287,9 @@ absl::Status InsertReshardReshapes( const InputShardings& input_shardings = GetInputShardingsForTuple(inst, {static_cast(i)}, strategy_map, cost_graph, s_val); - if (!input_shardings.empty() && input_shardings[0].has_value()) { - dst_shardings[i] = *input_shardings[0]; + if (!input_shardings.shardings.empty() && + input_shardings.shardings[0].has_value()) { + dst_shardings[i] = *input_shardings.shardings[0]; } } TF_RETURN_IF_ERROR( @@ -2447,7 +2311,7 @@ absl::Status InsertReshardReshapes( } else { const InputShardings& input_shardings = GetInputShardings(inst, strategy_map, cost_graph, s_val); - if (input_shardings.empty()) { + if (input_shardings.shardings.empty()) { continue; } if (inst->opcode() == HloOpcode::kGetTupleElement) { @@ -2457,9 +2321,11 @@ absl::Status InsertReshardReshapes( } for (size_t i = 0; i < inst->operand_count(); ++i) { - if (input_shardings.size() > i && input_shardings[i].has_value()) { + if (input_shardings.shardings.size() > i && + input_shardings.shardings[i].has_value()) { TF_RETURN_IF_ERROR(FixMixedMeshShapeResharding( - inst, i, *input_shardings[i], device_mesh, resharding_cache)); + inst, i, *input_shardings.shardings[i], device_mesh, + resharding_cache)); } } } @@ -2705,32 +2571,40 @@ std::string PrintSolutionMemoryUsage(const LivenessSet& liveness_set, return str; } -void SaveShardingForInstruction( +absl::Status SaveShardingForInstruction( const HloInstruction* inst, bool save_for_copy_users, absl::flat_hash_map>& preserve_shardings) { - auto save_sharding = [&preserve_shardings](const HloInstruction* inst) { + auto save_sharding = + [&preserve_shardings](const HloInstruction* inst) -> absl::Status { if (!inst->has_sharding()) { - return; + return absl::OkStatus(); + } + if (inst->sharding().IsUnknown() && + (inst->sharding().IsShardLike() || inst->sharding().IsShardAs())) { + return absl::UnimplementedError( + "Auto-sharding currently does not support shard_as/shard_like " + "sharding annotations"); } if (!inst->sharding().IsTuple()) { preserve_shardings[inst->name()] = {inst->sharding()}; } else { preserve_shardings[inst->name()] = inst->sharding().tuple_elements(); } + return absl::OkStatus(); }; - save_sharding(inst); + TF_RETURN_IF_ERROR(save_sharding(inst)); + // Also preserve the shardings of copy users of theinstruction. if (save_for_copy_users) { for (const auto user : inst->users()) { - // Also preserve the shardings of copy ops that are the users of those - // instructions. if (user->opcode() == HloOpcode::kCopy) { - save_sharding(user); + TF_RETURN_IF_ERROR(save_sharding(user)); } } } + return absl::OkStatus(); } // Check whether the shardings that need to be preserved are preserved. @@ -3015,7 +2889,9 @@ absl::Status GenerateReduceScatter( } const ShardingStrategy& strategy = GetShardingStrategy(inst, strategy_map, cost_graph, s_val); - if (!absl::StrContains(strategy.name, "allreduce")) { + const InputShardings& input_shardings = + GetInputShardings(inst, strategy_map, cost_graph, s_val); + if (!absl::StrContains(input_shardings.name, "allreduce")) { continue; } @@ -3119,14 +2995,14 @@ absl::Status GenerateReduceScatter( if (num_replicated_parameters >= 1 && need_all_gather.size() <= 1 && replicated_set.size() >= 5) { HloSharding output_spec = - GetReduceScatterOutput(inst, strategy, cluster_env); + GetReduceScatterOutput(inst, input_shardings, strategy, cluster_env); if (IsUndefined(output_spec)) { continue; } VLOG(10) << "SET: " << output_spec.ToString(); - if (absl::StartsWith(strategy.name, "RR = RS x SR")) { + if (absl::StartsWith(input_shardings.name, "RR = RS x SR")) { // If set the sharding for this dot instruction, the SPMD // partitioner will generate bad fallback code. replicated_set.erase(inst); @@ -3235,60 +3111,9 @@ absl::Status GenerateReduceScatter( return absl::OkStatus(); } -// Filter strategies according to the option.force_batch_dim_to_mesh_dim. -// This can be used to forcibly generate data-parallel strategies. -absl::Status FilterStrategy(const HloInstruction* ins, const Shape& shape, - const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, - const AutoShardingOption& option, - StrategyGroup& strategy_group) { - int mesh_dim = option.force_batch_dim_to_mesh_dim; - int batch_dim = batch_map.at(GetBatchDimMapKey(ins)); - const DeviceMesh& device_mesh = cluster_env.device_mesh_; - - if (shape.dimensions(batch_dim) % device_mesh.dim(mesh_dim) != 0) { - return absl::InvalidArgumentError( - "The length of batch dimension is " - "not divisible by the number of devices"); - } - - std::vector> new_strategies; - const auto& strategy_input_shardings = - strategy_group.GetStrategyInputShardings(); - for (size_t iid = 0; iid < strategy_input_shardings.size(); ++iid) { - const InputShardings& input_shardings = strategy_input_shardings[iid]; - const ShardingStrategy& strategy = - strategy_group.GetStrategyForInputShardings(iid); - const HloSharding& output_sharding = strategy.output_sharding; - const std::vector tensor_dim_to_mesh_dim = - cluster_env.GetTensorDimToMeshDimWrapper(shape, output_sharding); - - if (device_mesh.dim(mesh_dim) > 1) { - // If the mesh dim is not one, the output tensor must be - // tiled along the mesh dim. - if (tensor_dim_to_mesh_dim[batch_dim] == mesh_dim) { - new_strategies.push_back({strategy, input_shardings}); - } - } else { - // If the mesh dim is one, the output tensor must be replicated - // on the mesh dim. - if (tensor_dim_to_mesh_dim[batch_dim] == -1) { - new_strategies.push_back({strategy, input_shardings}); - } - } - } - CHECK(!new_strategies.empty()) - << ins->ToString() << " does not have any valid strategies"; - strategy_group.ClearStrategies(); - for (const auto& [strategy, input_shardings] : new_strategies) { - strategy_group.AddStrategy(strategy, input_shardings); - } - - return absl::OkStatus(); -} - // Return the output sharding of the reduce-scatter variant of a given strategy. HloSharding GetReduceScatterOutput(const HloInstruction* ins, + const InputShardings& input_shardings, const ShardingStrategy& strategy, const ClusterEnvironment& cluster_env) { const DeviceMesh& device_mesh = cluster_env.device_mesh_; @@ -3298,10 +3123,10 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins, const DotDimensionNumbers& dot_dnums = ins->dot_dimension_numbers(); int64_t space_base_dim = dot_dnums.lhs_batch_dimensions_size(); - if (absl::StartsWith(strategy.name, "SR = SS x SR") || - absl::StartsWith(strategy.name, "RS = RS x SS")) { + if (absl::StartsWith(input_shardings.name, "SR = SS x SR") || + absl::StartsWith(input_shardings.name, "RS = RS x SS")) { int mesh_dim0, mesh_dim1; - std::tie(mesh_dim0, mesh_dim1) = ParseMeshDims(strategy.name); + std::tie(mesh_dim0, mesh_dim1) = ParseMeshDims(input_shardings.name); if (!IsDivisible(ins, device_mesh, {space_base_dim, space_base_dim + 1}, {mesh_dim0, mesh_dim1})) { @@ -3314,9 +3139,9 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins, return Tile(ins->shape(), {space_base_dim, space_base_dim + 1}, {mesh_dim0, mesh_dim1}, device_mesh); } - if (absl::StartsWith(strategy.name, "SbR = SbSk x SbSk")) { + if (absl::StartsWith(input_shardings.name, "SbR = SbSk x SbSk")) { int mesh_dim0, mesh_dim1; - std::tie(mesh_dim0, mesh_dim1) = ParseMeshDims(strategy.name); + std::tie(mesh_dim0, mesh_dim1) = ParseMeshDims(input_shardings.name); if (!IsDivisible(ins, device_mesh, {0, space_base_dim}, {mesh_dim0, mesh_dim1})) { @@ -3329,8 +3154,8 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins, return Tile(ins->shape(), {0, space_base_dim}, {mesh_dim0, mesh_dim1}, device_mesh); } - if (absl::StartsWith(strategy.name, "RR = RS x SR")) { - int mesh_dim = absl::StrContains(strategy.name, "{0}") ? 0 : 1; + if (absl::StartsWith(input_shardings.name, "RR = RS x SR")) { + int mesh_dim = absl::StrContains(input_shardings.name, "{0}") ? 0 : 1; if (!IsDivisible(ins, device_mesh, {space_base_dim}, {mesh_dim})) { return Undefined(); @@ -3338,7 +3163,7 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins, return Tile(ins->shape(), {space_base_dim}, {mesh_dim}, device_mesh); } - if (absl::StartsWith(strategy.name, "R = Sk x Sk")) { + if (absl::StartsWith(input_shardings.name, "R = Sk x Sk")) { int mesh_dim = 0; if (!IsDivisible(ins, device_mesh_1d, {space_base_dim}, {mesh_dim})) { @@ -3353,10 +3178,10 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins, int out_batch_dim = conv_dnums.output_batch_dimension(); int out_out_channel_dim = conv_dnums.output_feature_dimension(); - if (absl::StartsWith(strategy.name, "SR = SS x SR") || - absl::StartsWith(strategy.name, "RS = RS x SS")) { + if (absl::StartsWith(input_shardings.name, "SR = SS x SR") || + absl::StartsWith(input_shardings.name, "RS = RS x SS")) { int mesh_dim0, mesh_dim1; - std::tie(mesh_dim0, mesh_dim1) = ParseMeshDims(strategy.name); + std::tie(mesh_dim0, mesh_dim1) = ParseMeshDims(input_shardings.name); if (!IsDivisible(ins, device_mesh, {out_batch_dim, out_out_channel_dim}, {mesh_dim0, mesh_dim1})) { @@ -3366,7 +3191,7 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins, return Tile(ins->shape(), {out_batch_dim, out_out_channel_dim}, {mesh_dim0, mesh_dim1}, device_mesh); } - if (absl::StartsWith(strategy.name, "R = Sk x Sk")) { + if (absl::StartsWith(input_shardings.name, "R = Sk x Sk")) { int mesh_dim = 0; if (!IsDivisible(ins, device_mesh_1d, {out_batch_dim}, {mesh_dim})) { @@ -3380,14 +3205,14 @@ HloSharding GetReduceScatterOutput(const HloInstruction* ins, CHECK_EQ(ins->shape().rank(), 1); int mesh_dim; - if (absl::StrContains(strategy.name, "allreduce @ [0]")) { + if (absl::StrContains(input_shardings.name, "allreduce @ [0]")) { mesh_dim = 0; } else { mesh_dim = 1; } if (strategy.output_sharding.IsReplicated()) { - if (absl::StrContains(strategy.name, "1d")) { + if (absl::StrContains(input_shardings.name, "1d")) { if (!IsDivisible(ins, device_mesh_1d, {0}, {mesh_dim})) { return Undefined(); } @@ -3456,13 +3281,14 @@ bool HasReduceScatterOpportunity(const HloInstruction* inst, } // namespace spmd -std::pair>, bool> +absl::StatusOr AutoShardingImplementation::SaveAndRemoveShardingAnnotation( HloModule* module, const absl::flat_hash_set& instructions_to_shard, const absl::flat_hash_set& replicated_small_tensors, const absl::flat_hash_set& execution_threads) { - absl::flat_hash_map> preserve_shardings; + absl::flat_hash_map> + preserved_shardings; absl::flat_hash_set keep_inst; for (const HloComputation* computation : @@ -3473,16 +3299,16 @@ AutoShardingImplementation::SaveAndRemoveShardingAnnotation( inst->opcode() == HloOpcode::kRecvDone || inst->opcode() == HloOpcode::kSend || inst->opcode() == HloOpcode::kSendDone) { - spmd::SaveShardingForInstruction(inst, - /* save_for_copy_users */ false, - preserve_shardings); + TF_RETURN_IF_ERROR(spmd::SaveShardingForInstruction( + inst, + /*save_for_copy_users=*/false, preserved_shardings)); continue; } if (spmd::IsInstructionBeforeSPMDFullToShardShapeCustomCall(inst) || spmd::IsSPMDShardToFullShapeCustomCall(inst)) { - spmd::SaveShardingForInstruction(inst, - /* save_for_copy_users */ false, - preserve_shardings); + TF_RETURN_IF_ERROR(spmd::SaveShardingForInstruction( + inst, + /*save_for_copy_users=*/false, preserved_shardings)); } if (inst->has_sharding() && spmd::IsShardingMisaligned(inst->sharding(), inst->shape()) && @@ -3502,12 +3328,12 @@ AutoShardingImplementation::SaveAndRemoveShardingAnnotation( for (const HloComputation* computation : module->computations(execution_threads)) { for (const auto inst : computation->instructions()) { - spmd::SaveShardingForInstruction(inst, - /* save_for_copy_users */ true, - preserve_shardings); + TF_RETURN_IF_ERROR(spmd::SaveShardingForInstruction( + inst, + /*save_for_copy_users=*/true, preserved_shardings)); } } - return std::make_pair(preserve_shardings, /* module_is_changed */ false); + return SaveShardingAnnotationsResult{preserved_shardings, false}; } bool module_is_changed = false; @@ -3519,23 +3345,23 @@ AutoShardingImplementation::SaveAndRemoveShardingAnnotation( // they are small tensors if (replicated_small_tensors.count(ins->name())) { keep_inst.insert(ins); - spmd::SaveShardingForInstruction(ins, - /* save_for_copy_users */ false, - preserve_shardings); + TF_RETURN_IF_ERROR(spmd::SaveShardingForInstruction( + ins, + /*save_for_copy_users=*/false, preserved_shardings)); continue; } // Do not remove entry computation's parameter and root instruction's - // sharding if preserve_shardings is kKeepInputOutputShardings. + // sharding if preserved_shardings is kKeepInputOutputShardings. if (option_.preserve_shardings == AutoShardingOption::PreserveShardingsType:: kKeepInputOutputShardings && is_entry_computation && (ins->opcode() == HloOpcode::kParameter || ins->IsRoot())) { keep_inst.insert(ins); - spmd::SaveShardingForInstruction( + TF_RETURN_IF_ERROR(spmd::SaveShardingForInstruction( ins, - /* save_for_copy_users */ ins->opcode() == HloOpcode::kParameter, - preserve_shardings); + /*save_for_copy_users=*/ins->opcode() == HloOpcode::kParameter, + preserved_shardings)); continue; } @@ -3559,7 +3385,7 @@ AutoShardingImplementation::SaveAndRemoveShardingAnnotation( } } } - return std::make_pair(preserve_shardings, module_is_changed); + return SaveShardingAnnotationsResult{preserved_shardings, module_is_changed}; } absl::Status AutoShardingImplementation::CanonicalizeLayouts( @@ -3597,10 +3423,15 @@ absl::flat_hash_set ComputeInstructionsToShard( for (HloInstruction* instruction : sequence.instructions()) { if (spmd::IsSPMDFullToShardShapeCustomCall(instruction)) { for (const HloInstruction* user : instruction->users()) { - if (spmd::IsSPMDShardToFullShapeCustomCall(user)) { - continue; + if (!spmd::IsSPMDShardToFullShapeCustomCall(user)) { + queue.push(user); + } + } + } else if (spmd::IsSPMDShardToFullShapeCustomCall(instruction)) { + for (const HloInstruction* operand : instruction->operands()) { + if (!spmd::IsSPMDFullToShardShapeCustomCall(operand)) { + queue.push(operand); } - queue.push(user); } } } @@ -3618,30 +3449,27 @@ absl::flat_hash_set ComputeInstructionsToShard( instruction->called_computations()) { for (const HloInstruction* parameter : computation->parameter_instructions()) { - if (spmd::IsSPMDShardToFullShapeCustomCall(parameter) || - spmd::IsSPMDFullToShardShapeCustomCall(parameter) || - parameter == instruction || visited.contains(parameter)) { - continue; + if (!spmd::IsSPMDShardToFullShapeCustomCall(parameter) && + !spmd::IsSPMDFullToShardShapeCustomCall(parameter) && + parameter != instruction && !visited.contains(parameter)) { + queue.push(parameter); } - queue.push(parameter); } } for (const HloInstruction* user : instruction->users()) { - if (spmd::IsSPMDShardToFullShapeCustomCall(user) || - spmd::IsSPMDFullToShardShapeCustomCall(user) || - visited.contains(user)) { - continue; + if (!spmd::IsSPMDShardToFullShapeCustomCall(user) && + !spmd::IsSPMDFullToShardShapeCustomCall(user) && + !visited.contains(user)) { + queue.push(user); } - queue.push(user); } for (const HloInstruction* operand : instruction->operands()) { - if (spmd::IsSPMDShardToFullShapeCustomCall(operand) || - spmd::IsSPMDFullToShardShapeCustomCall(operand) || - operand == instruction || visited.contains(operand)) { - continue; + if (!spmd::IsSPMDShardToFullShapeCustomCall(operand) && + !spmd::IsSPMDFullToShardShapeCustomCall(operand) && + operand != instruction && !visited.contains(operand)) { + queue.push(operand); } - queue.push(operand); } } @@ -3649,12 +3477,10 @@ absl::flat_hash_set ComputeInstructionsToShard( for (HloInstruction* instruction : sequence.instructions()) { if (!visited.contains(instruction) && !spmd::IsSPMDFullToShardShapeCustomCall(instruction)) { - if (HloCollectiveInstruction::ClassOf(instruction)) { - LOG(FATAL) << "The module contains collective ops not contained within " - "a graph surrounded by SPMDFullToShardShape and " - "SPMDShardToFullShape custom calls. This case is not yet " - "supported."; - } + LOG_IF(FATAL, HloCollectiveInstruction::ClassOf(instruction)) + << "The module contains collective ops not contained within a graph " + "surrounded by SPMDFullToShardShape and SPMDShardToFullShape " + "custom calls. This case is not yet supported."; to_shard.insert(instruction); } } @@ -3689,14 +3515,14 @@ std::pair ReduceMemoryTerms( return num_terms; } -absl::StatusOr AutoShardingImplementation::RunAutoSharding( +absl::StatusOr AutoShardingImplementation::RunAutoSharding( HloModule* module, const absl::flat_hash_set& replicated_small_tensors, const absl::flat_hash_set& execution_threads, const absl::flat_hash_map& sharding_propagation_solution) { if (!option_.enable) { - return AutoShardingResult::kModuleUnchanged; + return false; } bool module_is_changed = false; @@ -3711,7 +3537,14 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( ProcessShardingInstruction( module, execution_threads, /*replace_sharding_with_copy=*/true, &unspecified_dims, /*saved_root_shardings=*/nullptr, - /*saved_parameter_shardings=*/nullptr)); + /*saved_parameter_shardings=*/nullptr, + /*instruction_to_shard_group_id=*/nullptr, + /*shard_group_id_to_shard_as_group=*/nullptr, + /*shard_group_id_to_shard_like_group=*/nullptr, + /*allow_spmd_sharding_propagation_to_parameters_vector=*/nullptr, + /*remove_unknown_shardings=*/true)); + + DumpHloModuleIfEnabled(*module, "after_spmd_calls"); if (changed) { module_is_changed = true; VLOG(3) << "CustomCalls with custom_call_target=Sharding are removed and " @@ -3768,25 +3601,18 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( const absl::flat_hash_set& instructions_to_shard = ComputeInstructionsToShard(*module, sequence); - std::pair>, bool> - preserve_shardings_result = SaveAndRemoveShardingAnnotation( - module, instructions_to_shard, replicated_small_tensors, - execution_threads); + TF_ASSIGN_OR_RETURN(SaveShardingAnnotationsResult saved_sharding_result, + SaveAndRemoveShardingAnnotation( + module, instructions_to_shard, + replicated_small_tensors, execution_threads)); absl::flat_hash_map> - preserve_shardings = std::move(preserve_shardings_result.first); - module_is_changed |= preserve_shardings_result.second; + preserve_shardings = std::move(saved_sharding_result.preserved_shardings); + module_is_changed |= saved_sharding_result.module_is_changed; absl::flat_hash_map instruction_execution_counts = spmd::ComputeInstructionExecutionCounts( module, option_.loop_iteration_count_estimate); - // ----- Analyze the batch dim ----- - spmd::InstructionBatchDimMap batch_dim_map; - // TODO(yuemmawang) Enable the batch_dim_map if it becomes helpful. This is - // supposed to make the solver faster, but it makes it much much slower for - // both 1D and 2D mesh shapes. - // batch_dim_map = spmd::BuildInstructionBatchDimMap(sequence); - // ----- Read parameters of device mesh ----- spmd::DeviceMesh original_device_mesh(option_.device_mesh_shape); original_device_mesh.SetValues(option_.device_mesh_ids); @@ -3801,6 +3627,10 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( } else { partial_mesh_shapes = {option_.device_mesh_shape}; } + // Allocate an equal portion of solver timeout to each partial mesh shape. + option_.solver_timeout_in_seconds /= partial_mesh_shapes.size(); + LOG(INFO) << "Setting solver timeout per partial mesh shape to " + << option_.solver_timeout_in_seconds << " seconds."; std::unique_ptr call_graph = CallGraph::Build(module); @@ -3871,12 +3701,6 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( << option_.memory_budget_per_device; } - if (option_.force_batch_dim_to_mesh_dim >= 0) { - spmd::DisableIncompatibleMixedMeshShapeAndForceBatchDim( - batch_dim_map, sequence.instructions(), device_mesh.num_elements(), - option_); - } - // ----- Analyze depth ----- spmd::InstructionDepthMap ins_depth_map; ins_depth_map = spmd::BuildInstructionDepthMap(sequence); @@ -3889,8 +3713,8 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( std::tie(strategy_map, strategy_groups, associative_dot_pairs), BuildStrategyAndCost(sequence, module, instructions_to_shard, instruction_execution_counts, ins_depth_map, - batch_dim_map, alias_map, cluster_env, option_, - *call_graph, hlo_cost_analysis, + alias_map, cluster_env, option_, *call_graph, + hlo_cost_analysis, option_.try_multiple_mesh_shapes)); spmd::AliasSet alias_set = spmd::BuildAliasSet(module, input_output_alias_config, strategy_map); @@ -3981,20 +3805,18 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( // ----- Call the ILP Solver ----- std::string request_name = absl::StrCat("mesh_idx_", mesh_idx); - auto solver_result = + TF_ASSIGN_OR_RETURN( + spmd::AutoShardingSolverOutput output, Solve(*module, *hlo_live_range, strategy_map, strategy_groups, cost_graph, alias_set, reduced_node_intervals, reduced_edge_intervals, reduced_node_groups, reduced_edge_groups, - option_, request_name, sharding_propagation_solution); - if (solver_result.skip_auto_sharding) { - return AutoShardingResult::kModuleUnchangedNoShardingPerformed; - } else if (!solver_result.status.ok()) { - return AutoShardingResult::kModuleUnchanged; - } - TF_ASSIGN_OR_RETURN(spmd::AutoShardingSolverOutput output, - solver_result.status); + option_, request_name, sharding_propagation_solution)); if (mesh_idx == partial_mesh_shapes.size() - 1) { this->solver_optimal_objective_value_ = output.cost; + } else { + TF_RET_CHECK(output.is_optimal) + << "The solver did not find an optimal solution for a partial mesh " + << "shape."; } XLA_VLOG_LINES(5, PrintAutoShardingSolution( @@ -4014,21 +3836,14 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( output.s_val, (mesh_idx == partial_mesh_shapes.size() - 1)); if (mesh_idx == partial_mesh_shapes.size() - 1) { - if (!spmd::SetHloShardingPostProcessing(sequence, instructions_to_shard, - preserve_shardings) - .ok()) { - return AutoShardingResult::kModuleUnchanged; - } - - if (!InsertReshardReshapes( - sequence, instructions_to_shard, strategy_map, cost_graph, - output.s_val, cluster_env, - /* crash_at_error */ !option_.try_multiple_mesh_shapes, - option_.insert_resharding_reshapes_for_non_dot_ops, - preserve_shardings) - .ok()) { - return AutoShardingResult::kModuleUnchanged; - } + TF_RETURN_IF_ERROR(spmd::SetHloShardingPostProcessing( + sequence, instructions_to_shard, preserve_shardings)); + TF_RETURN_IF_ERROR(InsertReshardReshapes( + sequence, instructions_to_shard, strategy_map, cost_graph, + output.s_val, cluster_env, + /* crash_at_error */ !option_.try_multiple_mesh_shapes, + option_.insert_resharding_reshapes_for_non_dot_ops, + preserve_shardings)); } else { spmd::RecoverShardingsFromPartialMesh(sequence, preserve_shardings); } @@ -4064,12 +3879,12 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( CHECK(instruction->has_sharding()); CHECK(!instruction->sharding().IsManual()); CHECK(instruction->operand(0)->has_sharding()); - CHECK(instruction->operand(0)->sharding().IsManual()); + CHECK(instruction->operand(0)->sharding().IsManual()) + << instruction->ToString(); } } - return module_is_changed ? AutoShardingResult::kModuleChangedShardingPerformed - : AutoShardingResult::kModuleUnchanged; + return module_is_changed; } bool ModuleIsManuallyPartitioned(const HloModule* module) { @@ -4147,14 +3962,55 @@ absl::Status MoveComputationsFromModuleToModule(HloModule* from_module, AutoSharding::AutoSharding(const AutoShardingOption& option) : option_(option) {} +absl::Time DumpModuleAndRecordPassStart(const HloModule* module) { + XLA_VLOG_LINES(6, + absl::StrCat("Before auto sharding:\n", module->ToString())); + DumpHloModuleIfEnabled(*module, "before_auto_spmd_sharding"); + + // TODO(b/348372403) Explore replacing these with a runtime check, per + // go/no-ifdefs-in-xla +#if !defined(__APPLE__) + // Streamz metrics. + metrics::RecordAutoShardingInvocations(); +#endif + return absl::Now(); +} + +void RecordPassEndAndDumpModule(absl::Time start_time, + const HloModule* module) { + absl::Time end_time = absl::Now(); + absl::Duration duration = end_time - start_time; + LOG(INFO) << "Auto Sharding took " << absl::ToInt64Seconds(duration) + << " seconds"; + // TODO(b/348372403) Explore replacing these with a runtime check, per + // go/no-ifdefs-in-xla +#if !defined(__APPLE__) + metrics::RecordAutoShardingCompilationTime( + absl::ToInt64Microseconds(duration)); +#endif + + XLA_VLOG_LINES(6, absl::StrCat("After auto sharding:\n", module->ToString())); + DumpHloModuleIfEnabled(*module, "after_auto_spmd_sharding"); +} + +std::vector FindAllIndices(std::vector vec, int64_t element) { + std::vector result; + for (int i = 0; i < vec.size(); ++i) { + if (vec[i] == element) { + result.push_back(i); + } + } + return result; +} + absl::StatusOr AutoSharding::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { if (!option_.enable) { return false; } - LOG(INFO) << "Starting the auto sharding pass"; + LOG(INFO) << "Starting the auto sharding pass"; // TODO(b/332951306): Remove this check once nested tuples are supported // everywhere if (HasUnsupportedNestedTuples(*module)) { @@ -4164,15 +4020,7 @@ absl::StatusOr AutoSharding::Run( return false; } - XLA_VLOG_LINES(6, - absl::StrCat("Before auto sharding:\n", module->ToString())); - DumpHloModuleIfEnabled(*module, "before_auto_spmd_sharding"); - - absl::Time start_time = absl::Now(); -#if !defined(__APPLE__) - // Streamz metrics. - metrics::RecordAutoShardingInvocations(); -#endif + absl::Time start_time = DumpModuleAndRecordPassStart(module); TF_RETURN_IF_ERROR(module->RemoveUnusedComputations()); TF_RETURN_IF_ERROR(option_.CheckAndSetup()); @@ -4276,18 +4124,16 @@ absl::StatusOr AutoSharding::Run( } } - size_t num_meshes = mesh_shapes.size(); - std::vector> changed( - num_meshes, AutoShardingResult::kModuleUnchanged); - + bool module_is_changed = false; VLOG(1) << "Original mesh shape " << spmd::ToString(option_.device_mesh_shape); double min_objective_value = std::numeric_limits::max(); int min_mesh_shape_index = -1; std::unique_ptr min_mesh_shape_module; - bool skip_auto_sharding = true; + std::vector mesh_shape_error_messages(mesh_shapes.size()); for (size_t i = 0; i < mesh_shapes.size(); ++i) { VLOG(1) << "Trying mesh shape " << spmd::ToString(mesh_shapes[i]); + AutoShardingOption this_option = option_; this_option.device_mesh_shape = mesh_shapes[i]; if (this_option.device_mesh_shape.size() != @@ -4296,20 +4142,44 @@ absl::StatusOr AutoSharding::Run( this_option.device_mesh_beta.clear(); TF_RETURN_IF_ERROR(this_option.CheckAndSetup()); } + // Allocate an equal portion of solver timeout to each attempted mesh shape. + this_option.solver_timeout_in_seconds /= mesh_shapes.size(); + LOG(INFO) << "Setting solver timeout per mesh shape to " + << this_option.solver_timeout_in_seconds << " seconds."; + + // Try to infer DCN axis if the HLO is multi-slice. + // TODO(b/372720563) Improve this DCN axis inference. Currently, we assume + // there is only one DCN axis, and that there is no ICI axis with the same + // size as the DCN axis. + if (option_.num_dcn_slices.has_value() && *option_.num_dcn_slices > 1) { + std::vector dcn_indices = + FindAllIndices(mesh_shapes[i], *option_.num_dcn_slices); + if (dcn_indices.empty()) { + VLOG(1) << " Mesh shape does not contain DCN axis."; + } else { + if (dcn_indices.size() > 1) { + LOG(WARNING) + << "Could not infer a unique DCN axis. Choosing one randomly."; + } + this_option.device_mesh_alpha[dcn_indices[0]] = kDcnDeviceMeshAlpha; + this_option.device_mesh_beta[dcn_indices[0]] = kDcnDeviceMeshBeta; + } + } + auto pass = std::make_unique(this_option); std::unique_ptr module_clone = CloneModule(module); - absl::StatusOr pass_result = + absl::StatusOr pass_result = pass->RunAutoSharding(module_clone.get(), replicated_small_tensors, execution_threads, sharding_propagation_solution); - - changed[i] = pass_result; - double this_mesh_objective_value = pass->GetSolverOptimalObjectiveValue(); if (!pass_result.ok()) { + mesh_shape_error_messages[i] = pass_result.status().message(); VLOG(1) << "Mesh shape " << spmd::ToString(mesh_shapes[i]) << " led to the following error: " << pass_result.status().message(); continue; } + + double this_mesh_objective_value = pass->GetSolverOptimalObjectiveValue(); VLOG(1) << "Mesh shape " << spmd::ToString(mesh_shapes[i]) << " has objective value " << this_mesh_objective_value; if (this_mesh_objective_value >= 0 && @@ -4317,65 +4187,36 @@ absl::StatusOr AutoSharding::Run( min_mesh_shape_index = i; min_mesh_shape_module = std::move(module_clone); min_objective_value = this_mesh_objective_value; - } - if (*pass_result != - AutoShardingResult::kModuleUnchangedNoShardingPerformed) { - skip_auto_sharding = false; + CHECK_OK(pass_result); + module_is_changed = *pass_result; } } - absl::StatusOr module_is_changed; - if (skip_auto_sharding) { - module_is_changed = false; // The auto-sharding solver timed out. - } else { - std::string trying_to_find = - option_.try_multiple_mesh_shapes - ? "a device mesh (and the corresponding shardings)" - : "shardings"; - CHECK_GE(min_mesh_shape_index, 0) - << "The auto-sharding pass could not find " << trying_to_find - << " that works for this input. This could be the result of a low " - "memory budget (please refer to the " - "`--xla_tpu_auto_spmd_partitioning_memory_budget_ratio` flag to set " - "a higher budget). If you think you have set a reasonably large " - "memory budget, please report this as a bug."; - - if (!changed[min_mesh_shape_index].ok()) { - module_is_changed = changed[min_mesh_shape_index].status(); - } else { - solver_optimal_objective_value_ = min_objective_value; - if (changed[min_mesh_shape_index].value() == - AutoShardingResult::kModuleChangedShardingPerformed) { - VLOG(1) << "Choosing mesh shape " - << spmd::ToString(mesh_shapes[min_mesh_shape_index]) - << " which had the minimal solver objective value of " - << min_objective_value; - chosen_mesh_shape_ = mesh_shapes[min_mesh_shape_index]; - TF_RETURN_IF_ERROR(MoveComputationsFromModuleToModule( - min_mesh_shape_module.get(), module)); - module_is_changed = true; - } else { - module_is_changed = false; - } + if (min_mesh_shape_index < 0) { + std::string error_message = + "The auto-sharding pass could not find a solution for any of the mesh " + "shapes tried. Below, we list the errors encountered for each of the " + "mesh shapes:\n"; + for (size_t i = 0; i < mesh_shapes.size(); ++i) { + LOG(INFO) << mesh_shape_error_messages[i]; + absl::StrAppend(&error_message, "Mesh shape ", + spmd::ToString(mesh_shapes[i]), ": ", + mesh_shape_error_messages[i], "\n"); } + return absl::InternalError(error_message); } - absl::Time end_time = absl::Now(); - absl::Duration duration = end_time - start_time; - LOG(INFO) << "Auto Sharding took " << absl::ToInt64Seconds(duration) - << " seconds"; -#if !defined(__APPLE__) - metrics::RecordAutoShardingCompilationTime( - absl::ToInt64Microseconds(duration)); -#endif - - XLA_VLOG_LINES(6, absl::StrCat("After auto sharding:\n", module->ToString())); - DumpHloModuleIfEnabled(*module, "after_auto_spmd_sharding"); - - if (skip_auto_sharding) { - LOG(FATAL) << "The auto-sharding solver has timed out without a solution."; + solver_optimal_objective_value_ = min_objective_value; + if (module_is_changed) { + VLOG(1) << "Choosing mesh shape " + << spmd::ToString(mesh_shapes[min_mesh_shape_index]) + << " which had the minimal solver objective value of " + << min_objective_value; + chosen_mesh_shape_ = mesh_shapes[min_mesh_shape_index]; + TF_RETURN_IF_ERROR(MoveComputationsFromModuleToModule( + min_mesh_shape_module.get(), module)); } - + RecordPassEndAndDumpModule(start_time, module); return module_is_changed; } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h index be983f15916eed..40cbc6295b0bc8 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h @@ -32,6 +32,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" @@ -44,35 +45,33 @@ limitations under the License. #include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/call_graph.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/shape.h" namespace xla { -enum class AutoShardingResult { - kModuleUnchanged, - kModuleChangedShardingPerformed, - kModuleUnchangedNoShardingPerformed -}; - class AutoShardingImplementation { public: explicit AutoShardingImplementation(const AutoShardingOption& option); ~AutoShardingImplementation() = default; - absl::StatusOr RunAutoSharding( + absl::StatusOr RunAutoSharding( HloModule* module, const absl::flat_hash_set& replicated_small_tensors, const absl::flat_hash_set& execution_threads, const absl::flat_hash_map& sharding_propagation_solution = {}); + struct SaveShardingAnnotationsResult { + absl::flat_hash_map> + preserved_shardings; + bool module_is_changed; + }; + // Returns sharding annotations that need to be preserved in a map (for // verification after auto-sharding is done), and removes any sharding - // anotations that need to be removed. - std::pair>, bool> - SaveAndRemoveShardingAnnotation( + // annotations that need to be removed. + absl::StatusOr SaveAndRemoveShardingAnnotation( HloModule* module, const absl::flat_hash_set& instructions_to_shard, const absl::flat_hash_set& replicated_small_tensors, @@ -116,8 +115,10 @@ class AutoSharding : public HloModulePass { std::vector GetChosenDeviceMeshShape() { return chosen_mesh_shape_; } - private: + protected: AutoShardingOption option_; + + private: // Stores the optimal value of the objective the solver found. double solver_optimal_objective_value_ = -1.0; // Stores the optimal mesh shape found. @@ -151,7 +152,6 @@ void RemoveDuplicatedStrategy(StrategyGroup& strategy_group); absl::Status FilterStrategy(const HloInstruction* ins, const Shape& shape, const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, StrategyGroup& strategy_group); @@ -162,7 +162,6 @@ absl::Status HandleDot(std::unique_ptr& strategy_group, const HloInstructionSequence& instruction_sequence, const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph); @@ -173,7 +172,6 @@ absl::Status HandleConv(std::unique_ptr& strategy_group, const HloInstructionSequence& instruction_sequence, const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph); @@ -211,22 +209,10 @@ bool HasReduceScatterOpportunity(const HloInstruction* inst, const ConstInstructionSet& modified); HloSharding GetReduceScatterOutput(const HloInstruction* ins, + const InputShardings& input_shardings, const ShardingStrategy& strategy, const ClusterEnvironment& cluster_env); -// The high-level "recipe" for solving an Auto Sharding problem. -AutoShardingSolverResult Solve( - const HloModule& hlo_module, const HloLiveRange& hlo_live_range, - const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, - const CostGraph& cost_graph, const AliasSet& alias_set, - const std::vector>& node_intervals, - const std::vector>& edge_intervals, - const std::vector>& node_groups, - const std::vector>& edge_groups, - const AutoShardingOption& option, absl::string_view request_prefix, - const absl::flat_hash_map& - sharding_propagation_solution = {}); - // Populates temporal distance values. void PopulateTemporalValues(const CostGraph& cost_graph, AutoShardingSolverRequest& request); @@ -250,17 +236,16 @@ void FillAllStrategiesForArray( const HloInstruction* ins, const Shape& shape, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, const AutoShardingOption& option, double replicated_penalty, - const InstructionBatchDimMap& batch_dim_map, const CallGraph& call_graph, - bool only_allow_divisible, bool create_replicated_strategies, + const CallGraph& call_graph, bool only_allow_divisible, + bool create_replicated_strategies, bool create_partially_replicated_strategies, StrategyGroup& strategy_group); absl::StatusOr> CreateAllStrategiesGroup( const HloInstruction* ins, const Shape& shape, size_t instruction_id, StrategyGroups& strategy_groups, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, const AutoShardingOption& option, - double replicated_penalty, const InstructionBatchDimMap& batch_dim_map, - const CallGraph& call_graph, bool only_allow_divisible, - bool create_replicated_strategies, + double replicated_penalty, const CallGraph& call_graph, + bool only_allow_divisible, bool create_replicated_strategies, bool create_partially_replicated_strategies); // Enumerates sharding strategies for elementwise operators by following @@ -294,7 +279,6 @@ std::unique_ptr CreateReshapeStrategies( size_t instruction_id, const HloInstruction* ins, const StrategyMap& strategy_map, const ClusterEnvironment& cluster_env, bool only_allow_divisible, double replicated_penalty, - const InstructionBatchDimMap& batch_dim_map, const AutoShardingOption& option, StrategyGroups& strategy_groups, const CallGraph& call_graph); @@ -313,8 +297,7 @@ void EnumerateAll1DPartition( void EnumerateAllPartition( const HloInstruction* ins, const Shape& shape, const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, - const StrategyMap& strategy_map, - const InstructionBatchDimMap& batch_dim_map, bool only_allow_divisible, + const StrategyMap& strategy_map, bool only_allow_divisible, bool allow_shardings_small_dims_across_many_devices, const CallGraph& call_graph, int64_t partition_dimensions, const std::vector& tensor_dims, StrategyGroup& strategy_group); @@ -368,8 +351,7 @@ BuildStrategyAndCost( const absl::flat_hash_set& instructions_to_shard, const absl::flat_hash_map& instruction_execution_counts, - const InstructionDepthMap& depth_map, - const InstructionBatchDimMap& batch_dim_map, const AliasMap& alias_map, + const InstructionDepthMap& depth_map, const AliasMap& alias_map, const ClusterEnvironment& cluster_env, AutoShardingOption& option, const CallGraph& call_graph, const HloCostAnalysis& hlo_cost_analysis, bool trying_multiple_mesh_shapes); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc index 67a67a2a27884e..5d1016830c1c63 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc @@ -123,16 +123,26 @@ CostGraph::CostGraph(const StrategyGroups& strategy_groups, node_lens_[dst_idx]); absl::flat_hash_map src_strategy_name_to_idx_map; - for (NodeStrategyIdx i = 0; i < node_lens_[src_idx]; ++i) { + const auto& src_strategy_input_shardings = + src_strategy_group.GetStrategyInputShardings(); + for (size_t iid = 0; iid < src_strategy_input_shardings.size(); ++iid) { + const InputShardings& input_shardings = src_strategy_input_shardings[iid]; + NodeStrategyIdx i = + src_strategy_group.GetStrategyIdxForInputShardings(iid); const ShardingStrategy& strategy = src_strategy_group.GetStrategy(i); if (strategy.communication_cost > 0) { - src_strategy_name_to_idx_map[strategy.name] = i; + src_strategy_name_to_idx_map[input_shardings.name] = i; } } - for (NodeStrategyIdx i = 0; i < node_lens_[dst_idx]; ++i) { + const auto& dst_strategy_input_shardings = + dst_strategy_group.GetStrategyInputShardings(); + for (size_t iid = 0; iid < dst_strategy_input_shardings.size(); ++iid) { + const InputShardings& input_shardings = dst_strategy_input_shardings[iid]; + NodeStrategyIdx i = + dst_strategy_group.GetStrategyIdxForInputShardings(iid); const ShardingStrategy& dst_strategy = dst_strategy_group.GetStrategy(i); if (dst_strategy.communication_cost > 0) { - auto it = src_strategy_name_to_idx_map.find(dst_strategy.name); + auto it = src_strategy_name_to_idx_map.find(input_shardings.name); if (it != src_strategy_name_to_idx_map.end()) { const auto& src_strategy = src_strategy_group.GetStrategy(it->second); CHECK_LE(std::abs(src_strategy.communication_cost - diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index f80958b099ff45..e161674473157e 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc @@ -70,7 +70,6 @@ class HandlerBase { const HloInstructionSequence& instruction_sequence, const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) : strategy_group_(strategy_group), strategy_map_(strategy_map), @@ -79,7 +78,6 @@ class HandlerBase { instruction_sequence_(instruction_sequence), hlo_cost_analysis_(hlo_cost_analysis), cluster_env_(cluster_env), - batch_map_(batch_map), option_(option), call_graph_(call_graph), device_mesh_(cluster_env.device_mesh_), @@ -221,7 +219,6 @@ class HandlerBase { const HloInstructionSequence& instruction_sequence_; const HloCostAnalysis& hlo_cost_analysis_; const ClusterEnvironment& cluster_env_; - const InstructionBatchDimMap& batch_map_; const AutoShardingOption& option_; const CallGraph& call_graph_; @@ -238,7 +235,6 @@ class DotHandler : public HandlerBase { const HloInstructionSequence& instruction_sequence, const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph); DotHandler( @@ -247,8 +243,7 @@ class DotHandler : public HandlerBase { const HloInstructionSequence& instruction_sequence, const HloCostAnalysis& hlo_cost_analysis, const dot_as_convolution_util::DotConvolutionDimsInfo& conv_as_dot_dims, - const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, + const ClusterEnvironment& cluster_env, const AutoShardingOption& option, const CallGraph& call_graph); ~DotHandler() override = default; @@ -293,7 +288,6 @@ class ConvHandler : public HandlerBase { const HloInstructionSequence& instruction_sequence, const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph); ~ConvHandler() override = default; @@ -341,12 +335,12 @@ void HandlerBase::AppendNewStrategy(const std::string& name, } strategy_group_->AddStrategy( - ShardingStrategy({name, output_spec, compute_cost, communication_cost, + ShardingStrategy({output_spec, compute_cost, communication_cost, static_cast(ByteSizeOfShapeWithSharding( ins_->shape(), output_spec)), communication_resharding_costs, memory_resharding_costs}), - {input_specs.begin(), input_specs.end()}); + {name, {input_specs.begin(), input_specs.end()}}); } // Given lhs and rhs dim maps, infers a sharding for the output by relying @@ -444,13 +438,13 @@ std::optional HandlerBase::GetShardingFromUser( CHECK_OK(ins_clone->ReplaceOperandWith(1, rhs_clone.get())); if (ins_->opcode() == HloOpcode::kConvolution) { xla::InferConvolutionShardingFromOperands( - ins_clone.get(), /* aggressiveness */ 10, - /* may_combine_partial_sharding */ true); + ins_clone.get(), call_graph_, + /* may_combine_partial_sharding */ true, /* is_spmd */ true); } else { xla::InferDotShardingFromOperands( - ins_clone.get(), + ins_clone.get(), call_graph_, dot_as_convolution_util::ParseDotGeneralFromDot(ins_clone.get()), - /* aggressiveness */ 10, /* may_combine_partial_sharding */ true); + /* may_combine_partial_sharding/ */ true, /* is_spmd */ true); } if (!ins_clone->has_sharding()) { return std::nullopt; @@ -473,7 +467,7 @@ void HandlerBase::SortStrategies() { [](const std::pair& s1, const std::pair& s2) { if (s1.first.memory_cost == s2.first.memory_cost) { - return s1.first.name < s2.first.name; + return s1.second.name < s2.second.name; } else { return s1.first.memory_cost < s2.first.memory_cost; } @@ -492,12 +486,11 @@ DotHandler::DotHandler(std::unique_ptr& strategy_group, const HloInstructionSequence& instruction_sequence, const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) : HandlerBase(strategy_group, strategy_map, ins, instruction_id, - instruction_sequence, hlo_cost_analysis, cluster_env, - batch_map, option, call_graph), + instruction_sequence, hlo_cost_analysis, cluster_env, option, + call_graph), is_dot_(true), space_base_dim_(ins->dot_dimension_numbers().lhs_batch_dimensions_size()), lhs_con_dims_(ins->dot_dimension_numbers().lhs_contracting_dimensions()), @@ -525,12 +518,11 @@ DotHandler::DotHandler( const HloInstructionSequence& instruction_sequence, const HloCostAnalysis& hlo_cost_analysis, const dot_as_convolution_util::DotConvolutionDimsInfo& conv_as_dot_dims, - const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, + const ClusterEnvironment& cluster_env, const AutoShardingOption& option, const CallGraph& call_graph) : HandlerBase(strategy_group, strategy_map, ins, instruction_id, - instruction_sequence, hlo_cost_analysis, cluster_env, - batch_map, option, call_graph), + instruction_sequence, hlo_cost_analysis, cluster_env, option, + call_graph), is_dot_(false), space_base_dim_(-1) { CHECK(conv_as_dot_dims.conv_spatial_dims.empty()); @@ -858,12 +850,11 @@ ConvHandler::ConvHandler(std::unique_ptr& strategy_group, const HloInstructionSequence& instruction_sequence, const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) : HandlerBase(strategy_group, strategy_map, ins, instruction_id, - instruction_sequence, hlo_cost_analysis, cluster_env, - batch_map, option, call_graph), + instruction_sequence, hlo_cost_analysis, cluster_env, option, + call_graph), conv_dnums_(ins->convolution_dimension_numbers()) { lhs_batch_dim_ = conv_dnums_.input_batch_dimension(); lhs_in_channel_dim_ = conv_dnums_.input_feature_dimension(); @@ -969,14 +960,6 @@ absl::Status ConvHandler::RegisterStrategies() { 2, /*current_mesh_dim_idx=*/0, all_mesh_dims, /*current_dim_map=*/{}); - // If force_batch_dim_to_mesh_dim is set, filter out invalid strategies - // and only keep the data parallel strategies. - if (option_.force_batch_dim_to_mesh_dim >= 0 && - batch_map_.contains(GetBatchDimMapKey(ins_))) { - TF_RETURN_IF_ERROR(FilterStrategy(ins_, ins_->shape(), cluster_env_, - batch_map_, option_, *strategy_group_)); - } - SortStrategies(); return absl::OkStatus(); } @@ -1013,7 +996,9 @@ void ConvHandler::SplitDepthwise(bool forward) { lhs_dim_map, rhs_dim_map, output_dim_map); }; std::vector all_mesh_dims(device_mesh_.num_dimensions()); - Enumerate(split_func, 2, /*current_mesh_dim_idx=*/0, all_mesh_dims, + std::iota(all_mesh_dims.begin(), all_mesh_dims.end(), 0); + Enumerate(split_func, ins_->shape().rank(), /*current_mesh_dim_idx=*/0, + all_mesh_dims, /*current_dim_map=*/{}); } @@ -1027,7 +1012,6 @@ absl::Status HandleDot(std::unique_ptr& strategy_group, const HloInstructionSequence& instruction_sequence, const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, @@ -1035,7 +1019,7 @@ absl::Status HandleDot(std::unique_ptr& strategy_group, DotHandler handler(strategy_group, strategy_map, Cast(ins), instruction_id, instruction_sequence, hlo_cost_analysis, - cluster_env, batch_map, option, call_graph); + cluster_env, option, call_graph); TF_RETURN_IF_ERROR(handler.RegisterStrategies()); return absl::OkStatus(); } @@ -1048,7 +1032,6 @@ absl::Status HandleConv(std::unique_ptr& strategy_group, const HloInstructionSequence& instruction_sequence, const HloCostAnalysis& hlo_cost_analysis, const ClusterEnvironment& cluster_env, - const InstructionBatchDimMap& batch_map, const AutoShardingOption& option, const CallGraph& call_graph) { strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, @@ -1057,16 +1040,16 @@ absl::Status HandleConv(std::unique_ptr& strategy_group, const dot_as_convolution_util::DotConvolutionDimsInfo& conv_as_dot_dims = dot_as_convolution_util::ParseConvolutionDimsInfo(ins); if (conv_as_dot_dims.conv_spatial_dims.empty()) { - DotHandler handler( - strategy_group, strategy_map, Cast(ins), - instruction_id, instruction_sequence, hlo_cost_analysis, - conv_as_dot_dims, cluster_env, batch_map, option, call_graph); + DotHandler handler(strategy_group, strategy_map, + Cast(ins), instruction_id, + instruction_sequence, hlo_cost_analysis, + conv_as_dot_dims, cluster_env, option, call_graph); TF_RETURN_IF_ERROR(handler.RegisterStrategies()); } else { ConvHandler handler(strategy_group, strategy_map, ins, instruction_id, instruction_sequence, hlo_cost_analysis, cluster_env, - batch_map, option, call_graph); + option, call_graph); TF_RETURN_IF_ERROR(handler.RegisterStrategies()); } return absl::OkStatus(); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc index e0fdd6ad71bbf8..b9226f561244ea 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_impl.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" @@ -37,7 +38,7 @@ limitations under the License. namespace xla { namespace spmd { -AutoShardingSolverResult Solve( +absl::StatusOr Solve( const HloModule& hlo_module, const HloLiveRange& hlo_live_range, const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, @@ -48,13 +49,13 @@ AutoShardingSolverResult Solve( const AutoShardingOption& option, absl::string_view request_prefix, const absl::flat_hash_map& sharding_propagation_solution) { - return CallSolver(hlo_module, hlo_live_range, strategy_map, strategy_groups, - cost_graph, alias_set, node_intervals, edge_intervals, - node_groups, edge_groups, /*s_hint*/ {}, - /*compute_iis*/ true, option.solver_timeout_in_seconds, - option, /*max_cost*/ std::nullopt, request_prefix, - sharding_propagation_solution, - /*deterministic mode*/ true); + return CreateAutoShardingSolverRequestAndCallSolver( + hlo_module, hlo_live_range, strategy_map, strategy_groups, cost_graph, + alias_set, node_intervals, edge_intervals, node_groups, edge_groups, + /*s_hint*/ {}, + /*compute_iis*/ true, option.solver_timeout_in_seconds, option, + /*max_cost*/ std::nullopt, request_prefix, sharding_propagation_solution, + /*deterministic mode*/ true); } void PopulateTemporalValues(const CostGraph& cost_graph, diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc index f8c2d41614ccb5..7a8408bee00994 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc @@ -66,8 +66,6 @@ std::string AutoShardingOption::ToString() const { lines.push_back(absl::StrCat("reduce_scatter_cost: ", reduce_scatter_cost)); } - lines.push_back(absl::StrCat("force_batch_dim_to_mesh_dim: ", - force_batch_dim_to_mesh_dim)); lines.push_back(absl::StrCat("allow_replicated_parameters: ", allow_replicated_parameters)); lines.push_back( @@ -143,6 +141,10 @@ std::string AutoShardingOption::ToString() const { lines.push_back(absl::StrCat("insert_resharding_reshapes_for_non_dot_ops: ", insert_resharding_reshapes_for_non_dot_ops)); + if (num_dcn_slices.has_value()) { + lines.push_back(absl::StrCat("num_dcn_slices: ", *num_dcn_slices)); + } + return absl::StrJoin(lines, "\n"); } @@ -166,14 +168,16 @@ absl::Status AutoShardingOption::CheckAndSetup() { if (device_mesh_alpha.empty()) { // Generates simple device_mesh_alpha based on the size of // device_mesh_shape. - device_mesh_alpha = std::vector(device_mesh_shape.size(), kDeviceMeshAlpha); + device_mesh_alpha = + std::vector(device_mesh_shape.size(), kIciDeviceMeshAlpha); VLOG(0) << "Using default values for device_mesh_alpha: " << absl::StrJoin(device_mesh_alpha, ","); } if (device_mesh_beta.empty()) { // Generates simple device_mesh_beta based on the size of // device_mesh_shape. - device_mesh_beta = std::vector(device_mesh_shape.size(), kDeviceMeshBeta); + device_mesh_beta = + std::vector(device_mesh_shape.size(), kIciDeviceMeshBeta); VLOG(0) << "Using default values for device_mesh_beta: " << absl::StrJoin(device_mesh_beta, ","); } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h index 468ab4aa8f3c79..33b6f4385ee20f 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h @@ -17,6 +17,7 @@ limitations under the License. #define XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_OPTION_H_ #include +#include #include #include @@ -24,8 +25,12 @@ limitations under the License. namespace xla { -static constexpr double kDeviceMeshAlpha = 1.0; -static constexpr double kDeviceMeshBeta = 1.0; +static constexpr double kIciDeviceMeshAlpha = 1.0; +static constexpr double kIciDeviceMeshBeta = 1.0; +// By default, assume that DCN communication is 10 times slower than ICI +// communication +static constexpr double kDcnDeviceMeshAlpha = 10.0; +static constexpr double kDcnDeviceMeshBeta = 10.0; static constexpr double kOverbudgetCoeff = 1e6; // Options for the autosharding pass @@ -84,11 +89,6 @@ struct AutoShardingOption { bool force_override_reduce_scatter_cost = false; double reduce_scatter_cost = 0; - // Forcibly split the batch dimension and map it to a mesh dimension. - // This can force the auto-sharding pass to generate the data parallel - // strategy. - int force_batch_dim_to_mesh_dim = -1; - // If true, allow replicated parameters. bool allow_replicated_parameters = true; @@ -165,7 +165,7 @@ struct AutoShardingOption { // Static estimate for iteration count of a while loop, used in the cost // model. This estimate is used when we cannot infer an upper bound on the // number of iterations in the loop (as implemented in - // third_party/tensorflow/compiler/xla/service/while_loop_analysis.h) + // third_party/tensorflow/compiler/xla/hlo/analysis/while_loop_analysis.h) int64_t loop_iteration_count_estimate = 100; // Allows the conversion of aliases to followers if their pairwise strategy @@ -207,6 +207,9 @@ struct AutoShardingOption { // ops in a principled manner. bool insert_resharding_reshapes_for_non_dot_ops = false; + // The number of slices used + std::optional num_dcn_slices = std::nullopt; + // Prints a debug string. std::string ToString() const; diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_runner.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_runner.cc index e82bf2a5e31751..fd20b8b596e10e 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_runner.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_runner.cc @@ -20,7 +20,6 @@ limitations under the License. #include "absl/status/status.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_parser.h" #include "xla/tools/hlo_module_loader.h" #include "tsl/platform/init_main.h" diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index 073b2b8ac5221b..f88f01c58b49f2 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -78,13 +78,7 @@ constexpr double kMaxCostValue = 1e18; bool AutoShardingSolverOutput::operator==( const AutoShardingSolverOutput& other) const { return s_val == other.s_val && cost == other.cost && - peak_times == other.peak_times; -} - -bool AutoShardingSolverResult::operator==( - const AutoShardingSolverResult& other) const { - return status == other.status && - skip_auto_sharding == other.skip_auto_sharding; + is_optimal == other.is_optimal && peak_times == other.peak_times; } void PrintLargestInstructions( @@ -143,7 +137,7 @@ void PrintLargestInstructions( } } -AutoShardingSolverResult SolveAndExtractSolution( +absl::StatusOr SolveAndExtractSolution( const AutoShardingSolverRequest& request, const std::vector>& s, const std::vector>& e, @@ -399,7 +393,7 @@ void AddMemoryTerms( // can be a few (usually < 10) edges in the problem with negative costs. This // is guaranteed to never produce a negative overall cost for the graph, // however. -AutoShardingSolverResult CallORToolsSolver( +absl::StatusOr FormulateAndSolveMIPFromSolverRequest( const AutoShardingSolverRequest& unscaled_request) { const absl::Time start_time = absl::Now(); const AutoShardingSolverRequest& request = ScaleRequest(unscaled_request); @@ -420,13 +414,16 @@ AutoShardingSolverResult CallORToolsSolver( // Set random_seed, interleave_search and share_binary_clauses for // determinism, mip_max_bound (to handle large costs), and num_workers for // parallelism. - solver_parameter_str = - request.deterministic_mode() - ? absl::StrCat( - "share_binary_clauses:false,random_seed:1,interleave_" - "search:true,num_workers:", - num_workers) - : absl::StrCat("num_workers:", num_workers); + solver_parameter_str = absl::StrCat("num_workers:", num_workers); + if (request.deterministic_mode()) { + absl::StrAppend( + &solver_parameter_str, + ",share_binary_clauses:false,random_seed:1,interleave_search:true"); + } + if (request.has_solver_timeout()) { + absl::StrAppend(&solver_parameter_str, ",max_deterministic_time:", + request.solver_timeout().solver_timeout_in_seconds()); + } solver->SetSolverSpecificParametersAsString(solver_parameter_str); } // Create variables @@ -565,8 +562,7 @@ AutoShardingSolverResult CallORToolsSolver( LOG(FATAL) << err_msg; } else { LOG(WARNING) << err_msg; - return AutoShardingSolverResult(absl::InternalError(err_msg), - /*skip_auto_sharding=*/false); + return absl::InternalError(err_msg); } } } @@ -745,20 +741,21 @@ AutoShardingSolverResult CallORToolsSolver( tsl::Fingerprint64(unscaled_request.SerializeAsString()); std::string request_dump_path = absl::StrCat("/tmp/solver_request_", unscaled_request.request_name(), - "_", solver_request_fprint, ".proto"); - auto write_status = file::SetBinaryProto( + "_", solver_request_fprint, ".textproto"); + auto write_status = file::SetTextProto( // Modify this file path if needed. request_dump_path, unscaled_request, file::Defaults()); - VLOG(5) << "Dumped solver request to " << request_dump_path; + LOG(INFO) << "Dumped solver request to " << request_dump_path; if (!write_status.ok()) { LOG(ERROR) << write_status.message(); } } -#endif - if (request.has_solver_timeout()) { - solver->SetTimeLimit( - absl::Seconds(request.solver_timeout().solver_timeout_in_seconds())); + // Invokes the solver request callback for any additional debugging. + bool solver_request_callback = false; + if (solver_request_callback) { + SolverRequestCallback(unscaled_request); } +#endif if (request.enable_output()) { solver->EnableOutput(); } @@ -766,27 +763,33 @@ AutoShardingSolverResult CallORToolsSolver( << "Solver parameter string: " << solver_parameter_str << "\n" << "Number of workers: " << num_workers << "\n" << "Number of threads: " << solver->GetNumThreads() << "\n" - << "Time limit: " << solver->time_limit() << "\n" + << "Time limit: " + << request.solver_timeout().solver_timeout_in_seconds() + << " seconds\n" << "Request valid: " << ValidateRequest(request).ok() << "\n" << "Aliases: " << request.aliases_size() << "\n" << "Unique nodes: " << unique_nodes << "\n" << "Unique edges: " << unique_edges << "\n" << "Total instructions: " << request.num_nodes() << "\n" << "Total edges: " << request.edges_size() << "\n" - << "Memory budget: " << request.memory_budget() / (1024 * 1024 * 1024) - << "GB\n" + << "Memory budget: " << request.memory_budget() << " (" + << request.memory_budget() / (1024 * 1024 * 1024) << "GB)\n" << "Number variables for ILP: " << solver->NumVariables() << "\n" << "Number of ILP constraints: " << solver->NumConstraints() << "\n" << "Deterministic mode: " << request.deterministic_mode() << "\n" + << "Minimize departures: " << request.minimize_departures() << "\n" << "Module name: " << request.module_name(); if (request.has_max_cost()) { VLOG(0) << "Max cost: " << request.max_cost().coeff(); } + if (request.has_max_departures()) { + VLOG(0) << "Max departures: " << request.max_departures().coeff(); + } auto result = SolveAndExtractSolution(request, s, e, overbudget_var, makespan_var, *solver); - if (result.status.ok()) { + if (result.ok()) { const AutoShardingEvaluation evaluation = - Evaluate(unscaled_request, result); + Evaluate(unscaled_request, *result); LOG(INFO) << "*** Total costs for the (unscaled) solver request ***"; LOG(INFO) << "Total Communication Cost: " << evaluation.total.communication_cost @@ -809,6 +812,7 @@ AutoShardingSolverResult CallORToolsSolver( LOG(INFO) << "Total Departures: " << evaluation.total_departures; LOG(INFO) << "Total Makespan: " << evaluation.total_makespan; LOG(INFO) << "Total Violations: " << evaluation.violation_codes.size(); + LOG(INFO) << "Maximum Total Memory: " << evaluation.max_total_memory; } const absl::Time end_time = absl::Now(); const auto duration = end_time - start_time; @@ -832,7 +836,7 @@ std::vector GetChosenNodeStrategy( return chosen_node_strategy; } -AutoShardingSolverResult SolveAndExtractSolution( +absl::StatusOr SolveAndExtractSolution( const AutoShardingSolverRequest& request, const std::vector>& s, const std::vector>& e, @@ -841,6 +845,7 @@ AutoShardingSolverResult SolveAndExtractSolution( auto status = solver.Solve(); LOG(INFO) << "Solver absl::Status: " << status; + bool is_optimal = false; if (status == operations_research::MPSolver::INFEASIBLE) { LOG(ERROR) << "MPSolver could not find any feasible solution."; #ifdef PLATFORM_GOOGLE @@ -870,19 +875,20 @@ AutoShardingSolverResult SolveAndExtractSolution( } } #endif - return AutoShardingSolverResult( - absl::InternalError("MPSolver could not find any feasible solution."), - /*skip_auto_sharding=*/false); + return absl::InternalError( + "MPSolver could not find any feasible solution."); } else if (status == operations_research::MPSolver::MODEL_INVALID) { - LOG(FATAL) << "Solver says that the input MIP is invalid. This is most " - "likely a bug and should be reported."; - return AutoShardingSolverResult(absl::InternalError("Solver timed out."), - /*skip_auto_sharding=*/false); + LOG(FATAL) << "The MIP fed to the solver is invalid. This is most likely a " + "bug and should be reported."; + return absl::InternalError("Invalid MIP."); + } else if (status == operations_research::MPSolver::NOT_SOLVED) { + LOG(WARNING) << "Solver timeout; no solution was produced"; + return absl::InternalError("Solver timed out."); } else if (status != operations_research::MPSolver::OPTIMAL) { - return AutoShardingSolverResult(absl::InternalError("Solver timed out."), - /*skip_auto_sharding=*/true); + LOG(WARNING) << "Solver timeout; moving forward with a suboptimal solution"; + } else { + is_optimal = true; } - // Fingerprint the model & solution (useful when checking for determinism). // We use TensorFlow's fingerprint library here, which differs from CP-SAT's. operations_research::MPModelProto model_proto; @@ -949,9 +955,9 @@ AutoShardingSolverResult SolveAndExtractSolution( << request.memory_budget() / (1024 * 1024 * 1024) << " GB"; } PrintLargestInstructions(chosen_node_strategy, request); - const AutoShardingSolverOutput output = {std::move(chosen_node_strategy), - solver.Objective().Value()}; - return AutoShardingSolverResult(output, /*skip_auto_sharding=*/false); + return AutoShardingSolverOutput{.s_val = std::move(chosen_node_strategy), + .cost = solver.Objective().Value(), + .is_optimal = is_optimal}; } bool CostComponents::operator==(const CostComponents& other) const { @@ -975,13 +981,13 @@ bool AutoShardingEvaluation::operator==( } AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, - const AutoShardingSolverResult& result) { + const AutoShardingSolverOutput& result) { const auto& c = request.computation_costs(); const auto& d = request.communication_costs(); const auto& r = request.resharding_costs(); const auto& v = request.value_costs(); const auto& p = request.departure_costs(); - const std::vector& s_val = result.status->s_val; + const std::vector& s_val = result.s_val; const auto e_val = [&](EdgeIdx edge_idx) { const auto& edge = request.edges(edge_idx); return s_val[edge.first()] * request.s_len(edge.second()) + @@ -1023,8 +1029,6 @@ AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, } } if (request.memory_budget() > 0) { - double total_overbudget = 0.0; - double lower_bound_overbudget = 0.0; std::vector total_memory_costs, lower_bound_memory_costs; if (request.node_intervals().empty()) { // Handles live matrices. total_memory_costs.resize(request.live_size(), 0.0); @@ -1125,8 +1129,12 @@ AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, } } } + double total_overbudget = 0.0; + double lower_bound_overbudget = 0.0; for (LivenessIdx time_idx = 0; time_idx < total_memory_costs.size(); ++time_idx) { + evaluation.max_total_memory = + std::max(evaluation.max_total_memory, total_memory_costs[time_idx]); if (request.has_overbudget_coeff()) { total_overbudget = std::max(total_overbudget, diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h index cb051f7718fd44..811f31137e02d8 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_SOLVER_H_ #define XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_SOLVER_H_ -#include #include #include "absl/container/flat_hash_set.h" @@ -32,22 +31,13 @@ namespace spmd { struct AutoShardingSolverOutput { std::vector s_val; double cost = -1.0; + bool is_optimal = true; absl::flat_hash_set peak_times; bool operator==(const AutoShardingSolverOutput& other) const; }; -struct AutoShardingSolverResult { - public: - AutoShardingSolverResult(absl::StatusOr status, - bool skip_auto_sharding) - : status(status), skip_auto_sharding(skip_auto_sharding) {} - bool operator==(const AutoShardingSolverResult& other) const; - absl::StatusOr status; - bool skip_auto_sharding; -}; - -AutoShardingSolverResult CallORToolsSolver( +absl::StatusOr FormulateAndSolveMIPFromSolverRequest( const AutoShardingSolverRequest& request); enum AutoShardingViolationCode { @@ -86,13 +76,16 @@ struct AutoShardingEvaluation { // The (raw) total makespan, i.e., not scaled by the makespan coefficient. double total_makespan = 0.0; + // The maximum total memory over all time steps. + double max_total_memory = 0.0; + bool operator==(const AutoShardingEvaluation& other) const; }; // Evaluates the given solver result w.r.t. the input request, computing various // solution quality metrics and validating the consistency of hard constraints. AutoShardingEvaluation Evaluate(const AutoShardingSolverRequest& request, - const AutoShardingSolverResult& result); + const AutoShardingSolverOutput& result); // Creates and returns a variable for makespan. operations_research::MPVariable* CreateMakespanVar( @@ -101,7 +94,7 @@ operations_research::MPVariable* CreateMakespanVar( operations_research::MPSolver& solver); double EvaluateMakespan(const AutoShardingSolverRequest& request, - const AutoShardingSolverResult& result, + const AutoShardingSolverOutput& result, AutoShardingEvaluation& evaluation); // Scale down values to reduce the range of costs & coefficients in the solver. @@ -138,6 +131,8 @@ class StrategyShaver { // Note: This does not include checks for valid variable aliasing yet. absl::Status ValidateRequest(const AutoShardingSolverRequest& request); +void SolverRequestCallback(const AutoShardingSolverRequest& request); + } // namespace spmd } // namespace xla diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc index 4be54f98a0a496..570a21268c50e9 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_impl.cc @@ -33,7 +33,7 @@ MPVariable* CreateMakespanVar(const AutoShardingSolverRequest& request, } double EvaluateMakespan(const AutoShardingSolverRequest& request, - const AutoShardingSolverResult& result, + const AutoShardingSolverOutput& result, AutoShardingEvaluation& evaluation) { return 0.0; // TODO(moffitt): Implement this. } @@ -45,5 +45,9 @@ NodeStrategies StrategyShaver::FindShavedStrategies() const { return {}; // TODO(moffitt): Implement this. } +void SolverRequestCallback(const AutoShardingSolverRequest& request) { + // TODO(mofftt): Implement this. +} + } // namespace spmd } // namespace xla diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc index 81c02acd354bd5..4ddafbee670cae 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "xla/hlo/experimental/auto_sharding/auto_sharding.pb.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "tsl/platform/platform.h" +#include "tsl/platform/statusor.h" namespace xla { namespace spmd { @@ -250,87 +251,87 @@ AutoShardingSolverRequest AutoShardingSolverRequestWithEquivalences() { return request; } -TEST(CallORToolsSolverTest, SolvesOptimally) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesOptimally) { const AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); - const AutoShardingSolverResult result = CallORToolsSolver(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } -TEST(CallORToolsSolverTest, SolvesOverbudget) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesOverbudget) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.set_memory_budget(100000); request.mutable_overbudget_coeff()->set_coeff(10.0); - const AutoShardingSolverResult result = CallORToolsSolver(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 9007650.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } -TEST(CallORToolsSolverTest, SolvesMaxDepartures) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesMaxDepartures) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.mutable_max_departures()->set_coeff(3.0); - const AutoShardingSolverResult result = CallORToolsSolver(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } -TEST(CallORToolsSolverTest, MinimizesDepartures) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, MinimizesDepartures) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.set_minimize_departures(true); - const AutoShardingSolverResult result = CallORToolsSolver(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 1, 0, 0, 1}; const double objective_value = 3.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } -TEST(CallORToolsSolverTest, AvoidsInfiniteNodeCosts) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, AvoidsInfiniteNodeCosts) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.mutable_computation_costs(0)->set_costs(0, kInfinityCost); request.mutable_computation_costs(0)->set_costs(1, kInfinityCost); request.mutable_computation_costs(0)->set_costs(2, kInfinityCost); - const AutoShardingSolverResult result = CallORToolsSolver(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {3, 0, 0, 0, 0}; const double objective_value = 10683.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } -TEST(CallORToolsSolverTest, AvoidsInfiniteEdgeCosts) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, AvoidsInfiniteEdgeCosts) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.mutable_resharding_costs(0)->set_costs(0, kInfinityCost); - const AutoShardingSolverResult result = CallORToolsSolver(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } -TEST(CallORToolsSolverTest, HandlesFollowedEdges) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesFollowedEdges) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); AutoShardingSolverRequest_Pair edge; edge.set_first(1); @@ -346,16 +347,16 @@ TEST(CallORToolsSolverTest, HandlesFollowedEdges) { 70000, 71000, 72000, 73000}}; AddCosts(request.mutable_duration_costs(), t); - const AutoShardingSolverResult result = CallORToolsSolver(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 12650.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } -TEST(CallORToolsSolverTest, HandlesCollapsedEdge) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesCollapsedEdge) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); AutoShardingSolverRequest_Pair edge; edge.set_first(2); @@ -373,52 +374,53 @@ TEST(CallORToolsSolverTest, HandlesCollapsedEdge) { 80000, 81000, 82000, 83000}}; AddCosts(request.mutable_duration_costs(), t); - const AutoShardingSolverResult result = CallORToolsSolver(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 13972.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } -TEST(CallORToolsSolverTest, UsesHint) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, UsesHint) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const auto s_hint = {1, 0, 0, 0, 0}; // Not optimal, but close. request.mutable_s_hint()->Add(s_hint.begin(), s_hint.end()); - const AutoShardingSolverResult result = CallORToolsSolver(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } -TEST(CallORToolsSolverTest, HonorsMaxCost) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, HonorsMaxCost) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.mutable_max_cost()->set_coeff(7600.0); // Best possible is 7650.0 - const AutoShardingSolverResult result = CallORToolsSolver(request); + const absl::StatusOr result = + FormulateAndSolveMIPFromSolverRequest(request); - EXPECT_TRUE(absl::IsInternal(result.status.status())); + EXPECT_TRUE(absl::IsInternal(result.status())); } -TEST(CallORToolsSolverTest, HandlesExtremelyHighMaxCost) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesExtremelyHighMaxCost) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); request.mutable_max_cost()->set_coeff(1e19); - const AutoShardingSolverResult result = CallORToolsSolver(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } -TEST(CallORToolsSolverTest, HandlesMemoryEdgeCosts) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesMemoryEdgeCosts) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const EdgeMatrix live_edges = {{}, {0}, {0, 1}, {1}, {}}; const CostMatrix memory_edge_costs = {{1000000, 1100, 1200, 1300, @@ -432,16 +434,16 @@ TEST(CallORToolsSolverTest, HandlesMemoryEdgeCosts) { AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs); request.set_enable_memory_edge_costs(true); - const AutoShardingSolverResult result = CallORToolsSolver(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } -TEST(CallORToolsSolverTest, HandlesIntervals) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, HandlesIntervals) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const std::vector> node_intervals = {{0, 4}, {0, 4}, {2, 3}, {3, 4}, {100, -1}}; @@ -460,16 +462,17 @@ TEST(CallORToolsSolverTest, HandlesIntervals) { AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs); request.set_enable_memory_edge_costs(true); - const AutoShardingSolverResult result = CallORToolsSolver(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } -TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroups) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, + HandlesReducedIntervalsAndGroups) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const std::vector> node_intervals = {{5, -1}, {5, -1}, {2, 3}, {3, 4}, {100, -1}, {0, 4}}; @@ -492,16 +495,17 @@ TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroups) { AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs); request.set_enable_memory_edge_costs(true); - const AutoShardingSolverResult result = CallORToolsSolver(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 1, 1, 0}; const double objective_value = 7872.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } -TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroupsNoMemoryEdgeCosts) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, + HandlesReducedIntervalsAndGroupsNoMemoryEdgeCosts) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const std::vector> node_intervals = {{5, -1}, {5, -1}, {2, 3}, {3, 4}, {100, -1}, {0, 4}}; @@ -511,16 +515,17 @@ TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroupsNoMemoryEdgeCosts) { AddGroups(request.mutable_node_groups(), node_groups); request.set_enable_memory_edge_costs(false); - const AutoShardingSolverResult result = CallORToolsSolver(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } -TEST(CallORToolsSolverTest, HandlesGroupsWithTinyMemoryCosts) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, + HandlesGroupsWithTinyMemoryCosts) { AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); const std::vector> node_intervals = {{5, -1}, {5, -1}, {2, 3}, {3, 4}, {100, -1}, {0, 4}}; @@ -551,26 +556,26 @@ TEST(CallORToolsSolverTest, HandlesGroupsWithTinyMemoryCosts) { request.set_enable_memory_edge_costs(true); request.set_memory_budget(4321); - const AutoShardingSolverResult result = CallORToolsSolver(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 0, 0, 0}; const double objective_value = 7650.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } -TEST(CallORToolsSolverTest, SolvesWithEquivalences) { +TEST(FormulateAndSolveMIPFromSolverRequestTest, SolvesWithEquivalences) { const AutoShardingSolverRequest request = AutoShardingSolverRequestWithEquivalences(); - const AutoShardingSolverResult result = CallORToolsSolver(request); + TF_ASSERT_OK_AND_ASSIGN(const AutoShardingSolverOutput result, + FormulateAndSolveMIPFromSolverRequest(request)); const std::vector s_val = {0, 0, 5, 5, 1}; const double objective_value = 7650.0; const AutoShardingSolverOutput expected_output = {s_val, objective_value}; - const AutoShardingSolverResult expected_result = {expected_output, false}; - EXPECT_EQ(result, expected_result); + EXPECT_EQ(result, expected_output); } TEST(AutoShardingEvaluatorTest, NoViolations) { @@ -578,9 +583,8 @@ TEST(AutoShardingEvaluatorTest, NoViolations) { const std::vector s_val = {3, 1, 2, 2, 1}; const double objective_value = 12149.0; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingSolverResult result = {output, false}; - const AutoShardingEvaluation evaluation = Evaluate(request, result); + const AutoShardingEvaluation evaluation = Evaluate(request, output); AutoShardingEvaluation expected_evaluation; expected_evaluation.total.computation_cost = 159.0; // 13+21+32+42+51 @@ -600,9 +604,8 @@ TEST(AutoShardingEvaluatorTest, EvaluatesOverbudget) { const std::vector s_val = {2 /* violates */, 1, 2, 2, 1}; const double objective_value = 11138.0; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingSolverResult result = {output, false}; - const AutoShardingEvaluation evaluation = Evaluate(request, result); + const AutoShardingEvaluation evaluation = Evaluate(request, output); AutoShardingEvaluation expected_evaluation; expected_evaluation.total.computation_cost = 158.0; // 12+21+32+42+51 @@ -628,9 +631,8 @@ TEST(AutoShardingEvaluatorTest, EvaluatesOverbudgetWithIntervals) { const std::vector s_val = {2 /* violates */, 1, 2, 2, 1}; const double objective_value = 11138.0; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingSolverResult result = {output, false}; - const AutoShardingEvaluation evaluation = Evaluate(request, result); + const AutoShardingEvaluation evaluation = Evaluate(request, output); AutoShardingEvaluation expected_evaluation; expected_evaluation.total.computation_cost = 158.0; // 12+21+32+42+51 @@ -659,9 +661,8 @@ TEST(AutoShardingEvaluatorTest, const std::vector s_val = {2 /* violates */, 1, 2, 2, 1}; const double objective_value = 11138.0; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingSolverResult result = {output, false}; - const AutoShardingEvaluation evaluation = Evaluate(request, result); + const AutoShardingEvaluation evaluation = Evaluate(request, output); AutoShardingEvaluation expected_evaluation; expected_evaluation.total.computation_cost = 158.0; // 12+21+32+42+51 @@ -681,9 +682,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesFollower) { const std::vector s_val = {3, 1, 2, 1 /* violates */, 1}; const double objective_value = 12138.0; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingSolverResult result = {output, false}; - const AutoShardingEvaluation evaluation = Evaluate(request, result); + const AutoShardingEvaluation evaluation = Evaluate(request, output); AutoShardingEvaluation expected_evaluation; expected_evaluation.violation_codes = {kFollowerViolationCode}; @@ -702,9 +702,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesAlias) { const std::vector s_val = {3, 1, 2, 2, 0 /* violates */}; const double objective_value = 12138.0; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingSolverResult result = {output, false}; - const AutoShardingEvaluation evaluation = Evaluate(request, result); + const AutoShardingEvaluation evaluation = Evaluate(request, output); AutoShardingEvaluation expected_evaluation; expected_evaluation.violation_codes = {kAliasViolationCode}; @@ -723,9 +722,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesMemory) { const std::vector s_val = {2 /* violates */, 1, 2, 2, 1}; const double objective_value = 11138.0; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingSolverResult result = {output, false}; - const AutoShardingEvaluation evaluation = Evaluate(request, result); + const AutoShardingEvaluation evaluation = Evaluate(request, output); AutoShardingEvaluation expected_evaluation; expected_evaluation.violation_codes = {kMemoryViolationCode}; @@ -747,9 +745,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesInfiniteCostForNode) { const std::vector s_val = {0 /* violates */, 1, 2, 2, 1}; const double objective_value = 1e+20; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingSolverResult result = {output, false}; - const AutoShardingEvaluation evaluation = Evaluate(request, result); + const AutoShardingEvaluation evaluation = Evaluate(request, output); AutoShardingEvaluation expected_evaluation; expected_evaluation.violation_codes = {kInfiniteCostViolationCode}; @@ -769,9 +766,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesInfiniteCostForEdge) { const std::vector s_val = {0, 1, 2, 2, 1}; const double objective_value = 1e+20; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingSolverResult result = {output, false}; - const AutoShardingEvaluation evaluation = Evaluate(request, result); + const AutoShardingEvaluation evaluation = Evaluate(request, output); AutoShardingEvaluation expected_evaluation; expected_evaluation.violation_codes = {kInfiniteCostViolationCode}; @@ -791,9 +787,8 @@ TEST(AutoShardingEvaluatorTest, ViolatesMaxDepartures) { const std::vector s_val = {3, 1, 2, 2, 1}; const double objective_value = 12149.0; const AutoShardingSolverOutput output = {s_val, objective_value}; - const AutoShardingSolverResult result = {output, false}; - const AutoShardingEvaluation evaluation = Evaluate(request, result); + const AutoShardingEvaluation evaluation = Evaluate(request, output); AutoShardingEvaluation expected_evaluation; expected_evaluation.violation_codes = {kMaxDeparturesViolationCode}; diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index 91d40860c73255..e1a81dbf33772c 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -139,38 +139,44 @@ ComputeSliceShardingAndCommunicationCostFromOperand( // the original implementation), but it should be easy to generalize if needed. void GenerateScatterShardingFromOperands( const HloScatterInstruction* scatter, const HloSharding& data_sharding, - const HloSharding& indices_sharding, const HloSharding& update_sharding, - const CallGraph& call_graph, + const HloSharding& update_sharding, const CallGraph& call_graph, absl::FunctionRef yield_sharding) { + std::vector scatter_shardings; + auto scatter_shardings_insert = [&](const HloSharding& sharding) { + const auto it = + std::find(scatter_shardings.begin(), scatter_shardings.end(), sharding); + if (it == scatter_shardings.end()) scatter_shardings.push_back(sharding); + }; CHECK_EQ(scatter->scatter_operand_count(), 1); - const HloInstruction* scatter_data = scatter->scatter_operands()[0]; - const HloInstruction* scatter_indices = scatter->scatter_indices(); - const HloInstruction* scatter_update = scatter->scatter_updates()[0]; - yield_sharding(data_sharding, indices_sharding, update_sharding, - data_sharding); + const HloSharding& indices_sharding = hlo_sharding_util:: + ScatterIndexShardingFromUpdateIndexPassthroughDimensions(update_sharding, + scatter); + scatter_shardings_insert(data_sharding); if (std::optional maybe_from_update = hlo_sharding_util::ScatterOutputShardingFromUpdate(update_sharding, *scatter)) { - yield_sharding(data_sharding, indices_sharding, update_sharding, - *maybe_from_update); + scatter_shardings_insert(*maybe_from_update); } std::optional scatter_parallel_dims = hlo_sharding_util::GetScatterParallelBatchDims(*scatter, call_graph); if (!scatter_parallel_dims) { + for (const HloSharding& sharding : scatter_shardings) { + yield_sharding(data_sharding, indices_sharding, update_sharding, + sharding); + } return; } absl::InlinedVector aligned_operand_parallel_dims = - hlo_sharding_util::IndexAlignedOperandParallelDims( - *scatter_parallel_dims); + scatter_parallel_dims->operand_parallel_dims; absl::InlinedVector update_parallel_dims = hlo_sharding_util::GetScatterParallelUpdateDims(*scatter, *scatter_parallel_dims); @@ -178,29 +184,29 @@ void GenerateScatterShardingFromOperands( aligned_operand_parallel_dims; // Infer output sharding from scatter operand sharding. const Shape& shape = scatter->shape(); - yield_sharding( - data_sharding, indices_sharding, update_sharding, + scatter_shardings_insert( hlo_sharding_util::InferGatherScatterParallelShardingFromOperandSharding( - data_sharding, scatter_data->shape(), shape, + data_sharding, shape, absl::MakeConstSpan(aligned_operand_parallel_dims), absl::MakeConstSpan(output_parallel_dims))); // Infer output sharding from scatter indices sharding. - HloSharding parallel_sharding_from_indices = + scatter_shardings_insert( hlo_sharding_util::InferGatherScatterParallelShardingFromOperandSharding( - indices_sharding, scatter_indices->shape(), shape, + indices_sharding, shape, absl::MakeConstSpan(scatter_parallel_dims->indices_parallel_dims), - absl::MakeConstSpan(output_parallel_dims)); - yield_sharding(data_sharding, indices_sharding, update_sharding, - parallel_sharding_from_indices); + absl::MakeConstSpan(output_parallel_dims))); // Infer output sharding from scatter update sharding. - yield_sharding( - data_sharding, indices_sharding, update_sharding, + scatter_shardings_insert( hlo_sharding_util::InferGatherScatterParallelShardingFromOperandSharding( - update_sharding, scatter_update->shape(), shape, - absl::MakeConstSpan(update_parallel_dims), + update_sharding, shape, absl::MakeConstSpan(update_parallel_dims), absl::MakeConstSpan(output_parallel_dims))); + + for (const HloSharding& scatter_sharding : scatter_shardings) { + yield_sharding(data_sharding, indices_sharding, update_sharding, + scatter_sharding); + } } // NOLINTBEGIN(readability/fn_size) @@ -211,8 +217,7 @@ BuildStrategyAndCost( const absl::flat_hash_set& instructions_to_shard, const absl::flat_hash_map& instruction_execution_counts, - const InstructionDepthMap& depth_map, - const InstructionBatchDimMap& batch_dim_map, const AliasMap& alias_map, + const InstructionDepthMap& depth_map, const AliasMap& alias_map, const ClusterEnvironment& cluster_env, AutoShardingOption& option, const CallGraph& call_graph, const HloCostAnalysis& hlo_cost_analysis, bool trying_multiple_mesh_shapes) { @@ -304,9 +309,8 @@ BuildStrategyAndCost( strategy_group = CreateAllStrategiesGroup( ins, ins->shape(), instruction_id, strategy_groups, cluster_env, - strategy_map, option, replicated_penalty, batch_dim_map, - call_graph, only_allow_divisible, - option.allow_replicated_parameters, + strategy_map, option, replicated_penalty, call_graph, + only_allow_divisible, option.allow_replicated_parameters, /* create_partially_replicated_strategies */ true) .value(); break; @@ -316,9 +320,8 @@ BuildStrategyAndCost( strategy_group = CreateAllStrategiesGroup( ins, ins->shape(), instruction_id, strategy_groups, cluster_env, - strategy_map, option, replicated_penalty, batch_dim_map, - call_graph, only_allow_divisible, - option.allow_replicated_parameters, + strategy_map, option, replicated_penalty, call_graph, + only_allow_divisible, option.allow_replicated_parameters, /* create_partially_replicated_strategies */ true) .value(); break; @@ -343,14 +346,14 @@ BuildStrategyAndCost( ByteSizeOfShapeWithSharding(ins->shape(), scatter_sharding); InputShardings input_shardings_optional( - {data_sharding, indices_sharding, update_sharding}); + {name, {data_sharding, indices_sharding, update_sharding}}); std::pair resharding_costs = GenerateReshardingCostsAndMissingShardingsForAllOperands( ins, scatter_sharding, strategy_map, cluster_env, call_graph, input_shardings_optional); strategy_group->AddStrategy( - ShardingStrategy({name, scatter_sharding, compute_cost, + ShardingStrategy({scatter_sharding, compute_cost, communication_cost, memory_cost, std::move(resharding_costs.first), std::move(resharding_costs.second)}), @@ -359,18 +362,15 @@ BuildStrategyAndCost( const HloScatterInstruction* scatter = Cast(ins); const HloInstruction* scatter_data = scatter->scatter_operands()[0]; - const HloInstruction* scatter_indices = scatter->scatter_indices(); const HloInstruction* scatter_update = scatter->scatter_updates()[0]; ForEachInCartesianProduct( {strategy_map.at(scatter_data)->GetStrategies(), - strategy_map.at(scatter_indices)->GetStrategies(), strategy_map.at(scatter_update)->GetStrategies()}, [&](const std::vector& operand_shardings) { GenerateScatterShardingFromOperands( scatter, operand_shardings[0].output_sharding, - operand_shardings[1].output_sharding, - operand_shardings[2].output_sharding, call_graph, + operand_shardings[1].output_sharding, call_graph, add_scatter_sharding); }); @@ -397,15 +397,14 @@ BuildStrategyAndCost( double memory_cost = ByteSizeOfShapeWithSharding(gather_shape, output_sharding); InputShardings input_shardings_optional( - {data_sharding, indices_sharding}); + {output_sharding.ToString(), {data_sharding, indices_sharding}}); std::pair resharding_costs = GenerateReshardingCostsAndMissingShardingsForAllOperands( ins, output_sharding, strategy_map, cluster_env, call_graph, input_shardings_optional); strategy_group->AddStrategy( - ShardingStrategy({std::string(output_sharding.ToString()), - output_sharding, compute_cost, + ShardingStrategy({output_sharding, compute_cost, communication_cost, memory_cost, std::move(resharding_costs.first), std::move(resharding_costs.second)}), @@ -436,8 +435,7 @@ BuildStrategyAndCost( HloSharding output_spec = indices_to_combine_spec; if (gather_parallel_dims) { auto aligned_operand_parallel_dims = - hlo_sharding_util::IndexAlignedOperandParallelDims( - *gather_parallel_dims); + gather_parallel_dims->operand_parallel_dims; auto output_parallel_dims = hlo_sharding_util::GetGatherParallelOutputDims( *ins, *gather_parallel_dims); @@ -445,37 +443,33 @@ BuildStrategyAndCost( if (hlo_sharding_util::IsSpatiallyPartitioned(data_spec)) { const HloSharding to_merge = hlo_sharding_util:: InferGatherScatterParallelShardingFromOperandSharding( - data_spec, data->shape(), gather_shape, + data_spec, gather_shape, absl::MakeConstSpan(aligned_operand_parallel_dims), absl::MakeConstSpan(output_parallel_dims)); if (std::optional improved_spec = ConstructImprovedSharding( to_merge, output_spec, gather_shape, - /* may_combine_partial_sharding */ true, - /* allow_aggressive_resharding */ false)) { + /*may_combine_partial_sharding=*/true, + /*allow_aggressive_resharding=*/false)) { output_spec = *improved_spec; add_sharding_strategy(data_spec, indices_spec, output_spec); - } else { - add_sharding_strategy(data_spec, indices_spec, to_merge); } } // Infer output sharding from scatter indices sharding. if (hlo_sharding_util::IsSpatiallyPartitioned(indices_spec)) { const HloSharding to_merge = hlo_sharding_util:: InferGatherScatterParallelShardingFromOperandSharding( - indices_spec, indices->shape(), gather_shape, + indices_spec, gather_shape, absl::MakeConstSpan( gather_parallel_dims->indices_parallel_dims), absl::MakeConstSpan(output_parallel_dims)); if (std::optional improved_spec = ConstructImprovedSharding( to_merge, output_spec, gather_shape, - /* may_combine_partial_sharding */ true, - /* allow_aggressive_resharding */ false)) { + /*may_combine_partial_sharding=*/true, + /*allow_aggressive_resharding=*/false)) { output_spec = *improved_spec; add_sharding_strategy(data_spec, indices_spec, output_spec); - } else { - add_sharding_strategy(data_spec, indices_spec, to_merge); } } } @@ -499,17 +493,15 @@ BuildStrategyAndCost( if (std::optional improved_spec = ConstructImprovedSharding( *maybe_from_data, output_spec, gather_shape, - /* may_combine_partial_sharding */ true, - /* allow_aggressive_resharding */ false)) { + /*may_combine_partial_sharding=*/true, + /*allow_aggressive_resharding=*/false)) { output_spec = *improved_spec; add_sharding_strategy(data_spec, indices_spec, output_spec); - } else { - add_sharding_strategy(data_spec, indices_spec, *maybe_from_data); } } } AddReplicatedStrategy(ins, ins->shape(), cluster_env, strategy_map, 0, - /* operands_to_consider_all_strategies_for */ {0}, + /*operands_to_consider_all_strategies_for=*/{}, *strategy_group); break; } @@ -517,8 +509,8 @@ BuildStrategyAndCost( strategy_group = CreateAllStrategiesGroup( ins, ins->shape(), instruction_id, strategy_groups, cluster_env, - strategy_map, option, replicated_penalty, batch_dim_map, - call_graph, only_allow_divisible, + strategy_map, option, replicated_penalty, call_graph, + only_allow_divisible, /* create_replicated_strategies */ true, /* create_partially_replicated_strategies */ true) .value(); @@ -527,8 +519,8 @@ BuildStrategyAndCost( case HloOpcode::kReshape: { strategy_group = CreateReshapeStrategies( instruction_id, ins, strategy_map, cluster_env, - only_allow_divisible, replicated_penalty, batch_dim_map, option, - strategy_groups, call_graph); + only_allow_divisible, replicated_penalty, option, strategy_groups, + call_graph); break; } case HloOpcode::kTranspose: @@ -566,14 +558,13 @@ BuildStrategyAndCost( MemoryReshardingCostVector(src_strategy_group, operand->shape(), input_spec, cluster_env); strategy_group->AddStrategy( - ShardingStrategy({name, - output_spec, + ShardingStrategy({output_spec, compute_cost, communication_cost, memory_cost, {communication_resharding_costs}, {memory_resharding_costs}}), - {input_spec}); + {name, {input_spec}}); } break; } @@ -686,9 +677,9 @@ BuildStrategyAndCost( if (k == follow_idx || ToString(ins->operand(k)->shape().dimensions()) == ToString(operand->shape().dimensions())) { - input_shardings.push_back(input_spec); + input_shardings.shardings.push_back(input_spec); } else { - input_shardings.push_back(std::nullopt); + input_shardings.shardings.push_back(std::nullopt); } } if (!output_spec.has_value()) { @@ -704,11 +695,10 @@ BuildStrategyAndCost( input_shardings); strategy_group->AddStrategy( - ShardingStrategy({name, *output_spec, compute_cost, - communication_cost, memory_cost, - std::move(resharding_costs.first), + ShardingStrategy({*output_spec, compute_cost, communication_cost, + memory_cost, std::move(resharding_costs.first), std::move(resharding_costs.second)}), - {input_spec}); + {name, {input_spec}}); } if (strategy_group->GetStrategies().empty()) { @@ -734,8 +724,8 @@ BuildStrategyAndCost( } else { strategy_group = CreateReshapeStrategies( instruction_id, ins, strategy_map, cluster_env, - only_allow_divisible, replicated_penalty, batch_dim_map, option, - strategy_groups, call_graph); + only_allow_divisible, replicated_penalty, option, strategy_groups, + call_graph); } break; } @@ -809,10 +799,9 @@ BuildStrategyAndCost( break; } case HloOpcode::kDot: { - TF_RETURN_IF_ERROR(HandleDot(strategy_group, strategy_groups, - strategy_map, ins, instruction_id, - sequence, hlo_cost_analysis, cluster_env, - batch_dim_map, option, call_graph)); + TF_RETURN_IF_ERROR(HandleDot( + strategy_group, strategy_groups, strategy_map, ins, instruction_id, + sequence, hlo_cost_analysis, cluster_env, option, call_graph)); if (option.allow_recompute_heavy_op) { AddReplicatedStrategy( @@ -824,10 +813,9 @@ BuildStrategyAndCost( break; } case HloOpcode::kConvolution: { - TF_RETURN_IF_ERROR(HandleConv(strategy_group, strategy_groups, - strategy_map, ins, instruction_id, - sequence, hlo_cost_analysis, cluster_env, - batch_dim_map, option, call_graph)); + TF_RETURN_IF_ERROR(HandleConv( + strategy_group, strategy_groups, strategy_map, ins, instruction_id, + sequence, hlo_cost_analysis, cluster_env, option, call_graph)); if (option.allow_recompute_heavy_op) { AddReplicatedStrategy( ins, ins->shape(), cluster_env, strategy_map, @@ -848,8 +836,8 @@ BuildStrategyAndCost( strategy_group = CreateAllStrategiesGroup( ins, ins->shape(), instruction_id, strategy_groups, cluster_env, - strategy_map, option, replicated_penalty, batch_dim_map, - call_graph, only_allow_divisible, + strategy_map, option, replicated_penalty, call_graph, + only_allow_divisible, /* create_replicated_strategies */ true, /* create_partially_replicated_strategies */ true) .value(); @@ -920,7 +908,7 @@ BuildStrategyAndCost( CreateAllStrategiesGroup( ins, ins->shape(), instruction_id, strategy_groups, cluster_env, strategy_map, option, replicated_penalty, - batch_dim_map, call_graph, only_allow_divisible, + call_graph, only_allow_divisible, /* create_replicated_strategies */ true, /* create_partially_replicated_strategies */ true) .value(); @@ -984,8 +972,8 @@ BuildStrategyAndCost( strategy_group = CreateAllStrategiesGroup( ins, ins->shape(), instruction_id, strategy_groups, cluster_env, - strategy_map, option, replicated_penalty, batch_dim_map, - call_graph, only_allow_divisible, + strategy_map, option, replicated_penalty, call_graph, + only_allow_divisible, /* create_replicated_strategies */ true, /* create_partially_replicated_strategies */ true) .value(); @@ -1032,11 +1020,6 @@ BuildStrategyAndCost( } CHECK(strategy_group != nullptr); RemoveDuplicatedStrategy(*strategy_group); - if (!option.allow_shardings_small_dims_across_many_devices) { - RemoveShardingsWhereSmallDimsShardedAcrossManyDevices( - ins->shape(), /* instruction_has_user_sharding */ ins->has_sharding(), - *strategy_group); - } if (ins->has_sharding() && ins->opcode() != HloOpcode::kOutfeed) { // Finds the sharding strategy that aligns with the given sharding spec // Do not merge nodes if this one instruction has annotations. @@ -1045,6 +1028,11 @@ BuildStrategyAndCost( cluster_env, pretrimmed_strategy_map, call_graph, option.nd_sharding_iteratively_strict_search_space, *strategy_group); } + if (!option.allow_shardings_small_dims_across_many_devices) { + RemoveShardingsWhereSmallDimsShardedAcrossManyDevices( + ins->shape(), /* instruction_has_user_sharding */ ins->has_sharding(), + *strategy_group); + } if (!strategy_group->is_tuple && strategy_group->following) { if (!LeafVectorsAreConsistent( strategy_group->GetStrategies(), @@ -1097,7 +1085,7 @@ BuildStrategyAndCost( const InputShardings& input_shardings = strategy_input_shardings[iid]; const ShardingStrategy& strategy = strategy_group->GetStrategyForInputShardings(iid); - if (strategy.name == stra_names[idx]) { + if (input_shardings.name == stra_names[idx]) { new_strategies.push_back({strategy, input_shardings}); } } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h index 04c15b20e9aa15..49212fe84ce655 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h @@ -77,12 +77,37 @@ using ReshardingCache = ConstInstructionMap>>; // Resharding costs for each operand using ReshardingCosts = std::vector>; -// Optional shardings for each operand -using InputShardings = std::vector>; + +// A named vector of optional shardings for each operand. +struct InputShardings { + std::string name; + std::vector> shardings; + + std::string ToString() const { + std::string str = absl::StrCat(name, " "); + for (const auto& s : shardings) { + if (!s.has_value()) { + absl::StrAppend(&str, "[*],"); + } else if (s->IsReplicated()) { + absl::StrAppend(&str, "[R],"); + } else { + if (s->ReplicateOnLastTileDim()) { + absl::StrAppend( + &str, "[", absl::StrJoin(s->tile_assignment().dimensions(), ", "), + "]last_tile_dim_replicate,"); + } else { + absl::StrAppend( + &str, "[", absl::StrJoin(s->tile_assignment().dimensions(), ", "), + "],"); + } + } + } + return str; + } +}; // One sharding strategy struct ShardingStrategy { - std::string name; HloSharding output_sharding; double compute_cost; double communication_cost; @@ -94,9 +119,7 @@ struct ShardingStrategy { ReshardingCosts communication_resharding_costs; ReshardingCosts memory_resharding_costs; - std::string ToString() const { - return absl::StrCat(name, ", ", output_sharding.ToString()); - } + std::string ToString() const { return output_sharding.ToString(); } std::string ToStringLong() const { std::vector communication_resharding_vector_strings; @@ -119,7 +142,7 @@ struct ShardingStrategy { "{", absl::StrJoin(memory_resharding_vector_strings, ", "), "}"); return absl::StrCat( - name, ", ", output_sharding.ToString(), ", compute_cost=", compute_cost, + output_sharding.ToString(), ", compute_cost=", compute_cost, ", communication_cost=", communication_cost, ", memory_cost=", memory_cost, ", communication_resharding_costs=", communication_resharding_cost_str, @@ -127,7 +150,7 @@ struct ShardingStrategy { } bool operator==(const ShardingStrategy& other) const { - return name == other.name && output_sharding == other.output_sharding && + return output_sharding == other.output_sharding && compute_cost == other.compute_cost && communication_cost == other.communication_cost && memory_cost == other.memory_cost && @@ -221,25 +244,8 @@ struct StrategyGroup { } if (!is_tuple) { for (const auto& input_shardings : strategy_input_shardings) { - std::string input_sharding_str = "{"; - for (const auto& s : input_shardings) { - if (!s.has_value()) { - input_sharding_str += "[*],"; - } else if (s->IsReplicated()) { - input_sharding_str += "[R],"; - } else { - if (s->ReplicateOnLastTileDim()) { - input_sharding_str += - "[" + absl::StrJoin(s->tile_assignment().dimensions(), ", ") + - "]last_tile_dim_replicate,"; - } else { - input_sharding_str += - "[" + absl::StrJoin(s->tile_assignment().dimensions(), ", ") + - "],"; - } - } - } - input_sharding_str += "}\n"; + const std::string input_sharding_str = + absl::StrCat("{", input_shardings.ToString(), "}\n"); absl::StrAppend(&str, indent, "Input Sharding ", input_sharding_str); } } @@ -313,6 +319,10 @@ struct StrategyGroup { return strategies[strategy_idx]; } + size_t GetStrategyIdxForInputShardings(size_t input_sharding_idx) const { + return input_sharding_idx_to_strategy_idx[input_sharding_idx]; + } + const InputShardings& GetInputShardings(size_t input_sharding_idx) const { return strategy_input_shardings[input_sharding_idx]; } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc index 2b0c2aec59e6f9..ad18259d411fa5 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc @@ -28,8 +28,10 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" @@ -41,14 +43,14 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/verified_hlo_module.h" +#include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/buffer_value.h" -#include "xla/service/hlo_alias_analysis.h" -#include "xla/service/hlo_memory_scheduler.h" -#include "xla/service/hlo_parser.h" #include "xla/service/hlo_value.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" @@ -63,13 +65,16 @@ using ::testing::ElementsAre; using ::testing::ElementsAreArray; using ::testing::Eq; using ::testing::FieldsAre; +using ::testing::HasSubstr; using ::testing::IsEmpty; using ::testing::IsFalse; using ::testing::IsTrue; using ::testing::Not; using ::testing::Pair; using ::testing::ResultOf; +using ::testing::SizeIs; using ::testing::UnorderedElementsAre; +using ::testing::status::StatusIs; TEST(DeviceMeshTest, IotaDeviceMesh2DStartsWith0) { DeviceMesh device_mesh({2, 4}); @@ -139,7 +144,7 @@ TEST(DeviceMeshTest, ReshapeTestWithIota) { EXPECT_EQ(device_mesh.num_elements(), 64); } -class AutoShardingTest : public HloTestBase { +class AutoShardingTest : public HloHardwareIndependentTestBase { protected: const absl::string_view kDotHloString = R"( HloModule module @@ -272,6 +277,7 @@ TEST_F(AutoShardingTest, MatmulMeshShape2DAllOptions) { option.device_mesh_ids = {0, 1, 2, 3}; option.device_mesh_alpha = {1.0, 1.0}; option.device_mesh_beta = {0.01, 1.0}; + option.allow_mixed_mesh_shape = false; RunMatMulAutoShardingWithOptions(option, 4, 2); option.enable = true; @@ -288,6 +294,7 @@ TEST_F(AutoShardingTest, MatmulMeshShape2DNoAlphaBeta) { option.enable = true; option.device_mesh_shape = {2, 2}; option.device_mesh_ids = {0, 1, 2, 3}; + option.allow_mixed_mesh_shape = false; RunMatMulAutoShardingWithOptions(option, 4, 2); option.enable = true; @@ -304,6 +311,7 @@ TEST_F(AutoShardingTest, MatmulMeshShape2DNoAlphaBetaMeshIds) { AutoShardingOption option; option.enable = true; option.device_mesh_shape = {2, 2}; + option.allow_mixed_mesh_shape = false; RunMatMulAutoShardingWithOptions(option, 4, 2); option.enable = true; @@ -322,6 +330,7 @@ TEST_F(AutoShardingTest, MatmulMeshShape2DNoMeshIds) { option.device_mesh_shape = {2, 2}; option.device_mesh_alpha = {1.0, 1.0}; option.device_mesh_beta = {0.01, 1.0}; + option.allow_mixed_mesh_shape = false; RunMatMulAutoShardingWithOptions(option, 4, 2); option.enable = true; @@ -349,6 +358,7 @@ TEST_F(AutoShardingTest, MatmulMeshShape3DAllOptions) { TEST_F(AutoShardingTest, Matmul3DMeshShape2DSharding) { AutoShardingOption option; option.enable = true; + option.allow_mixed_mesh_shape = false; option.device_mesh_shape = {1, 2, 2}; RunMatMulAutoShardingWithOptions(option, 4, 2); @@ -458,7 +468,7 @@ TEST_F(AutoShardingTest, LargeSize) { option.device_mesh_alpha = {1.0, 1.0, 1.0, 1.0}; option.device_mesh_beta = {1.0, 1.0, 1.0, 1.0}; option.memory_budget_per_device = (8192 + 8192 * 2 + 8192 * 4 / 8); - RunMatMulAutoShardingWithOptions(option, 7, 1); + RunMatMulAutoShardingWithOptions(option, 56, 1); } TEST_F(AutoShardingTest, InvalidOptions) { @@ -583,11 +593,62 @@ ENTRY %elementwise { TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get())); VLOG(10) << module->ToString(); EXPECT_TRUE(changed); - auto* instruction = FindInstruction(module.get(), "param0"); + const HloInstruction* instruction = FindInstruction(module.get(), "param0"); ASSERT_NE(instruction, nullptr); EXPECT_THAT(instruction, op::Sharding("{devices=[2,2]0,2,1,3}")); } +TEST_F(AutoShardingTest, RedundantReshardReshapeTest) { + constexpr absl::string_view kHloString = R"( +HloModule module +ENTRY %elementwise { + param0 = bf16[4,68096]{1,0} parameter(0), sharding={replicated} + param1 = s8[68096,8512]{1,0} parameter(1), sharding={devices=[1,16,8]<=[128] last_tile_dim_replicate} + convolution = bf16[4,8512]{1,0} convolution(param0, param1), dim_labels=bf_io->bf, sharding={devices=[1,128]<=[128]} + +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + AutoShardingOption option; + option.preserve_shardings = + AutoShardingOption::PreserveShardingsType::kKeepAllShardings; + option.enable = true; + option.device_mesh_shape = {16, 8}; + TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get())); + VLOG(5) << module->ToString(); + EXPECT_TRUE(changed); + std::vector instructions = + module->entry_computation()->MakeInstructionPostOrder(); + // Check that only one additional reshape op is inserted + EXPECT_THAT(instructions, SizeIs(4)); +} + +TEST_F(AutoShardingTest, ConvolutionSplitDepthwiseTest) { + constexpr absl::string_view kHloString = R"( +HloModule module +ENTRY %elementwise { + %param0 = pred[512,1,1024,512]{3,2,1,0} parameter(0) + %param1 = f32[1,1,5,5]{3,2,1,0} parameter(1) + %convolution = f32[512,1,1024,512]{3,2,1,0} convolution(pred[512,1,1024,512]{3,2,1,0} %param0, f32[1,1,5,5]{3,2,1,0} %param1), window={size=5x5 pad=2_2x2_2}, dim_labels=bf01_oi01->bf01 + ROOT %copy = f32[512,1,1024,512]{3,2,1,0} copy(%convolution) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + AutoShardingOption option; + option.enable = true; + option.device_mesh_shape = {512}; + TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get())); + VLOG(5) << module->ToString(); + EXPECT_TRUE(changed); + const HloInstruction* convolution = + FindInstruction(module.get(), "convolution"); + ASSERT_NE(convolution, nullptr); + ASSERT_TRUE(convolution->has_sharding()); + EXPECT_EQ(convolution->sharding().NumTiles(), 512); +} + TEST_F(AutoShardingTest, NDIterativeSolveTest) { constexpr absl::string_view kHloString = R"( HloModule module @@ -716,6 +777,7 @@ ENTRY %elementwise { .enable = true, .preserve_shardings = AutoShardingOption::PreserveShardingsType::kKeepAllShardings, + .allow_mixed_mesh_shape = false, .only_allow_divisible_input_output = false, .device_mesh_shape = {16, 16}, .device_mesh_alpha = {1.0, 1.0}, @@ -997,8 +1059,11 @@ ENTRY main { all-reduce.1 = bf16[128,128]{1,0} all-reduce(custom-call.2), channel_id=621, replica_groups={{0,1,2,3},{4,5,6,7},{8,9,10,11},{12,13,14,15}}, use_global_device_ids=true, to_apply=add.6.clone, frontend_attributes={from-cross-replica-sharding="true"}, backend_config={"flag_configs":[],"barrier_config":{"barrier_type":"CUSTOM","id":"9"},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_DEFAULT","device_type":"DEVICE_TYPE_INVALID","used_scoped_memory_configs":[]} add.1 = bf16[128,128]{1,0} add(bf16[128,128]{1,0} all-reduce.1, bf16[128,128]{1,0} broadcast.1) custom-call.3 = bf16[512,512]{1,0} custom-call(add.1), custom_call_target="SPMDShardToFullShape", sharding={devices=[4,1,4]<=[16]last_tile_dim_replicate} + partition-id.1 = u32[] partition-id() + broadcast.3 = u32[512,512]{1,0} broadcast(partition-id.1) + custom-call.4 = u32[512,512]{1,0} custom-call(broadcast.3), custom_call_target="SPMDShardToFullShape", sharding={devices=[4,1,4]<=[16]last_tile_dim_replicate} add.2 = bf16[512,512]{1,0} add(bf16[512,512]{1,0} custom-call.3, bf16[512,512]{1,0} broadcast.2) - ROOT copy.1 = bf16[512,512]{1,0} copy(add.2) + ROOT tuple.1 = (bf16[512,512]{1,0}, u32[512,512]{1,0}) tuple(add.2, custom-call.4) })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -1473,16 +1538,17 @@ ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { absl::flat_hash_set instructions_to_shard( module->entry_computation()->instructions().begin(), module->entry_computation()->instructions().end()); - std::pair>, bool> - saved_shardings_result = - AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( - module.get(), instructions_to_shard, - /* replicated_small_tensors */ {}, - /* execution_threads */ {}); + + TF_ASSERT_OK_AND_ASSIGN( + AutoShardingImplementation::SaveShardingAnnotationsResult + saved_shardings_result, + AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( + module.get(), instructions_to_shard, + /* replicated_small_tensors */ {}, + /* execution_threads */ {})); absl::flat_hash_map> saved_shardings = - saved_shardings_result.first; - bool changed = saved_shardings_result.second; - EXPECT_FALSE(changed); + saved_shardings_result.preserved_shardings; + EXPECT_FALSE(saved_shardings_result.module_is_changed); std::vector instructions = module->entry_computation()->MakeInstructionPostOrder(); EXPECT_THAT(instructions, @@ -1531,16 +1597,16 @@ ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { absl::flat_hash_set instructions_to_shard( module->entry_computation()->instructions().begin(), module->entry_computation()->instructions().end()); - std::pair>, bool> - saved_shardings_result = - AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( - module.get(), instructions_to_shard, - /* replicated_small_tensors */ {"dot"}, - /* execution_threads */ {}); + TF_ASSERT_OK_AND_ASSIGN( + AutoShardingImplementation::SaveShardingAnnotationsResult + saved_shardings_result, + AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( + module.get(), instructions_to_shard, + /* replicated_small_tensors */ {"dot"}, + /* execution_threads */ {})); absl::flat_hash_map> saved_shardings = - saved_shardings_result.first; - bool changed = saved_shardings_result.second; - EXPECT_FALSE(changed); + saved_shardings_result.preserved_shardings; + EXPECT_FALSE(saved_shardings_result.module_is_changed); std::vector instructions = module->entry_computation()->MakeInstructionPostOrder(); EXPECT_THAT(instructions, @@ -1586,16 +1652,17 @@ ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { absl::flat_hash_set instructions_to_shard( module->entry_computation()->instructions().begin(), module->entry_computation()->instructions().end()); - std::pair>, bool> - saved_shardings_result = - AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( - module.get(), instructions_to_shard, - /* replicated_small_tensors */ {}, - /* execution_threads */ {}); + + TF_ASSERT_OK_AND_ASSIGN( + AutoShardingImplementation::SaveShardingAnnotationsResult + saved_shardings_result, + AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( + module.get(), instructions_to_shard, + /* replicated_small_tensors */ {}, + /* execution_threads */ {})); absl::flat_hash_map> saved_shardings = - saved_shardings_result.first; - bool changed = saved_shardings_result.second; - EXPECT_TRUE(changed); + saved_shardings_result.preserved_shardings; + EXPECT_TRUE(saved_shardings_result.module_is_changed); // Dot does not have shardings anymore. const HloInstruction* dot = FindInstruction(module.get(), "dot"); @@ -1670,16 +1737,17 @@ ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { absl::flat_hash_set instructions_to_shard( module->entry_computation()->instructions().begin(), module->entry_computation()->instructions().end()); - std::pair>, bool> - saved_shardings_result = - AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( - module.get(), instructions_to_shard, - /* replicated_small_tensors */ {}, - /* execution_threads */ {}); + + TF_ASSERT_OK_AND_ASSIGN( + AutoShardingImplementation::SaveShardingAnnotationsResult + saved_shardings_result, + AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( + module.get(), instructions_to_shard, + /* replicated_small_tensors */ {}, + /* execution_threads */ {})); absl::flat_hash_map> saved_shardings = - saved_shardings_result.first; - bool changed = saved_shardings_result.second; - EXPECT_TRUE(changed); + saved_shardings_result.preserved_shardings; + EXPECT_TRUE(saved_shardings_result.module_is_changed); EXPECT_THAT(saved_shardings, IsEmpty()); std::vector instructions = module->entry_computation()->MakeInstructionPostOrder(); @@ -1708,16 +1776,16 @@ ENTRY %entry (param0: f32[4,256,64], param1: f32[4,256,32]) -> f32[64,32] { absl::flat_hash_set instructions_to_shard( module->entry_computation()->instructions().begin(), module->entry_computation()->instructions().end()); - std::pair>, bool> - saved_shardings_result = - AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( - module.get(), instructions_to_shard, - /* replicated_small_tensors */ {"dot", "copy"}, - /* execution_threads */ {}); + TF_ASSERT_OK_AND_ASSIGN( + AutoShardingImplementation::SaveShardingAnnotationsResult + saved_shardings_result, + AutoShardingImplementation(option).SaveAndRemoveShardingAnnotation( + module.get(), instructions_to_shard, + /* replicated_small_tensors */ {"dot", "copy"}, + /* execution_threads */ {})); absl::flat_hash_map> saved_shardings = - saved_shardings_result.first; - bool changed = saved_shardings_result.second; - EXPECT_TRUE(changed); + saved_shardings_result.preserved_shardings; + EXPECT_TRUE(saved_shardings_result.module_is_changed); // params have no shardings. const HloInstruction* param0 = FindInstruction(module.get(), "param0"); @@ -1986,8 +2054,9 @@ ENTRY %entry { option.device_mesh_ids = {0, 1, 2, 3, 4, 5, 6, 7}; option.device_mesh_alpha = {1.0, 1.0, 1.0}; option.device_mesh_beta = {0.01, 1.0, 1.0}; + option.memory_budget_per_device = (1000 * 128 + 8 * 128) / 8 + 8; TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get())); - VLOG(10) << module->ToString(); + VLOG(5) << module->ToString(); EXPECT_TRUE(changed); const HloInstruction* gather = FindInstruction(module.get(), "gather"); const HloInstruction* data = FindInstruction(module.get(), "data"); @@ -2597,6 +2666,36 @@ ENTRY %entry { EXPECT_THAT(reshape, op::Sharding("{devices=[1,32,1,1]<=[32]}")); } +TEST_F(AutoShardingTest, ShardingCustomCallInvalidUserSharding) { + constexpr absl::string_view kHloString = R"( +HloModule module + +ENTRY %entry { + concatenate.824 = bf16[256,4096,4,256]{3,2,1,0} parameter(0) + custom-call.5828 = bf16[256,4096,4,256]{3,2,1,0} custom-call(concatenate.824), custom_call_target="Sharding", sharding={devices=[32,1,4,1,2]<=[8,4,8]T(1,0,2) last_tile_dim_replicate} + ROOT copy = copy(custom-call.5828) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + AutoShardingOption option; + option.enable = true; + option.preserve_shardings = + AutoShardingOption::PreserveShardingsType::kKeepAllShardings; + option.device_mesh_shape = {8, 4, 8}; + option.device_mesh_alpha = {1.0, 1.0, 1.0}; + option.device_mesh_beta = {1.0, 1.0, 1.0}; + TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get())); + EXPECT_TRUE(changed); + VLOG(1) << module->ToString(); + const HloInstruction* sharding_call_copy = + FindInstruction(module.get(), "copy")->operand(0); + EXPECT_THAT( + sharding_call_copy, + op::Sharding( + "{devices=[32,1,4,1,2]<=[8,4,8]T(1,0,2) last_tile_dim_replicate}")); +} + TEST_F(AutoShardingTest, Broadcast) { constexpr absl::string_view kHloString = R"( HloModule module @@ -2807,6 +2906,123 @@ ENTRY %entry { EXPECT_THAT(slice1, op::Sharding("{replicated}")); } +TEST_F(AutoShardingTest, CrashIfAskedToRespectShardAsShardLike) { + const char* const kHloString = R"( +HloModule module +ENTRY matmul { + param1 = f32[32,64]{1,0} parameter(0) + param2 = f32[64,128]{1,0} parameter(1) + custom-call1 = f32[32,64]{1,0} custom-call(param1), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 0} + custom-call2 = f32[64,128]{1,0} custom-call(param2), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 0} + ROOT root = f32[32,128]{1,0} dot(custom-call1, custom-call2), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloString)); + AutoShardingOption option; + option.preserve_shardings = + AutoShardingOption::PreserveShardingsType::kKeepAllShardings; + option.enable = true; + option.device_mesh_shape = {4, 1}; + option.device_mesh_alpha = {1.0, 1.0}; + option.device_mesh_beta = {0.01, 1.0}; + // TODO(b/369616683) Fix the error message output in this case. + EXPECT_THAT(AutoSharding(option).Run(module.get()), + StatusIs(absl::StatusCode::kInternal, + HasSubstr("Auto-sharding currently does not support " + "shard_as/shard_like sharding annotations"))); +} + +TEST_F(AutoShardingTest, IgnoreShardAsShardLike) { + const char* const kHloString = R"( +HloModule module +ENTRY matmul { + param1 = f32[32,64]{1,0} parameter(0) + param2 = f32[64,128]{1,0} parameter(1) + custom-call1 = f32[32,64]{1,0} custom-call(param1), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 0} + custom-call2 = f32[64,128]{1,0} custom-call(param2), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 0} + ROOT root = f32[32,128]{1,0} dot(custom-call1, custom-call2), lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloString)); + AutoShardingOption option; + option.preserve_shardings = + AutoShardingOption::PreserveShardingsType::kRemoveAllShardings; + option.enable = true; + option.device_mesh_shape = {4, 1}; + option.device_mesh_alpha = {1.0, 1.0}; + option.device_mesh_beta = {0.01, 1.0}; + TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get())); + EXPECT_TRUE(changed); +} + +TEST_F(AutoShardingTest, SimpleDCNTest) { + constexpr absl::string_view kHloString = R"( +HloModule module + +%func (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) +} + +ENTRY %entry { + %param0 = f32[32,8192]{1,0} parameter(0) + %param1 = f32[] parameter(1) + %reduce = f32[32]{0} reduce(f32[32,8192]{1,0} %param0, f32[] %param1), dimensions={1}, to_apply=%func + })"; + AutoShardingOption option; + option.enable = true; + option.solve_nd_sharding_iteratively = false; + option.allow_mixed_mesh_shape = false; + option.device_mesh_shape = {8, 16}; + option.num_dcn_slices = 8; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get())); + VLOG(5) << module->ToString(); + EXPECT_TRUE(changed); + const HloInstruction* slice = FindInstruction(module.get(), "reduce"); + EXPECT_NE(slice, nullptr); + EXPECT_THAT(slice, + op::Sharding("{devices=[8,16]<=[128] last_tile_dim_replicate}")); +} + +TEST_F(AutoShardingTest, MissingAxisDCNTest) { + constexpr absl::string_view kHloString = R"( +HloModule module + +%func (x: f32[], y: f32[]) -> f32[] { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %x, f32[] %y) +} + +ENTRY %entry { + %param0 = f32[32,8192]{1,0} parameter(0) + %param1 = f32[] parameter(1) + %reduce = f32[32]{0} reduce(f32[32,8192]{1,0} %param0, f32[] %param1), dimensions={1}, to_apply=%func + })"; + AutoShardingOption option; + option.enable = true; + option.solve_nd_sharding_iteratively = false; + option.allow_mixed_mesh_shape = false; + option.device_mesh_shape = {8, 16}; + option.num_dcn_slices = 4; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get())); + VLOG(5) << module->ToString(); + EXPECT_TRUE(changed); + const HloInstruction* slice = FindInstruction(module.get(), "reduce"); + EXPECT_NE(slice, nullptr); + EXPECT_THAT(slice, + op::Sharding("{devices=[8,16]<=[128] last_tile_dim_replicate}")); +} + TEST(NormalizeTest, NormalizeHandlesNegativeCosts) { EdgeReshardingCostMatrix edge_cost(2, 2); edge_cost(0, 0).communication_cost = -100; diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index 5ee4d464ff116a..ccdbbe4c0ff137 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -41,6 +41,7 @@ limitations under the License. #include "absl/types/span.h" #include "json/json.h" #include "xla/array.h" +#include "xla/hlo/analysis/while_loop_analysis.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/hlo/ir/hlo_computation.h" @@ -55,7 +56,6 @@ limitations under the License. #include "xla/service/call_graph.h" #include "xla/service/computation_layout.h" #include "xla/service/sharding_propagation.h" -#include "xla/service/while_loop_analysis.h" #include "xla/shape.h" #include "xla/shape_layout.h" #include "xla/shape_tree.h" @@ -851,8 +851,8 @@ void RemoveDuplicatedStrategy(StrategyGroup& strategy_group) { continue; } std::string key = strategy.output_sharding.ToString(); - if (!input_shardings.empty()) { - for (const auto& sharding : input_shardings) { + if (!input_shardings.shardings.empty()) { + for (const auto& sharding : input_shardings.shardings) { key += "/" + (sharding.has_value() ? sharding->ToString() : "none"); } } @@ -1051,17 +1051,17 @@ absl::StatusOr CheckArithmeticSequence( return delta; } -bool IsValidTileAssignment(const HloSharding& spec) { - if (IsUndefined(spec)) { +bool IsValidTileAssignment(const HloSharding& sharding) { + if (IsUndefined(sharding)) { return false; } - if (spec.IsReplicated()) { + if (sharding.IsReplicated()) { return true; } // Check all tile dims - const auto& tile_assignment = spec.tile_assignment(); + const auto& tile_assignment = sharding.tile_assignment(); for (int i = 0; i < tile_assignment.num_dimensions(); i++) { if (tile_assignment.dim(i) != 1) { std::vector device_ids = @@ -1076,24 +1076,24 @@ bool IsValidTileAssignment(const HloSharding& spec) { return true; } -int64_t NumTileDimensions(const HloSharding& spec) { - if (spec.IsReplicated()) { +int64_t NumTileDimensions(const HloSharding& sharding) { + if (sharding.IsReplicated()) { return -1; } int64_t num_tile_dims = 0; - for (int i = 0; i < spec.tile_assignment().num_dimensions(); i++) { - if (spec.tile_assignment().dim(i) != 1) { + for (int i = 0; i < sharding.tile_assignment().num_dimensions(); i++) { + if (sharding.tile_assignment().dim(i) != 1) { num_tile_dims++; } } return num_tile_dims; } -bool TileAssignmentMatchesMesh(const HloSharding& spec, +bool TileAssignmentMatchesMesh(const HloSharding& sharding, const DeviceMesh& mesh) { int sharded_dims = 0; - for (int i = 0; i < spec.tile_assignment().num_dimensions(); ++i) { - if (spec.tile_assignment().dim(i) > 1) { + for (int i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) { + if (sharding.tile_assignment().dim(i) > 1) { sharded_dims++; } } @@ -1193,6 +1193,55 @@ absl::StatusOr> GetTensorDimToMeshDimNoCrash( return tensor_dim_to_device_dim; } +absl::StatusOr>> +GetTensorDimToMeshDimMixedMeshSharding(int64_t tensor_shape_rank, + const HloSharding& sharding, + const DeviceMesh& device_mesh, + bool consider_reverse_device_meshes) { + CHECK(!sharding.IsReplicated()); + // Check the compatibility of tensor_shape_rank and spec + if (tensor_shape_rank != sharding.TiledDataRank()) { + return absl::InvalidArgumentError( + "Tensor shape rank should be equal to the tiled data rank of the input " + "spec."); + } + if (!TileAssignmentMatchesMesh(sharding, device_mesh)) { + return absl::InvalidArgumentError( + "Device mesh and tile assignment need to have the same number of " + "sharded dims."); + } + + TF_ASSIGN_OR_RETURN( + std::vector axes, + GetMeshDimPermutationOrderInShardingSpec(sharding, device_mesh, + consider_reverse_device_meshes)); + + std::vector> tensor_dim_to_mesh_axis_mapping; + int mesh_axis_idx = 0; + for (int i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) { + if (sharding.tile_assignment().dim(i) == 1) { + tensor_dim_to_mesh_axis_mapping.push_back({}); + continue; + } + + absl::btree_set mesh_axes_for_this_tensor_dim; + int product = 1; + do { + if (mesh_axis_idx >= device_mesh.num_dimensions()) { + return absl::InternalError( + "Mismatched mesh shapes encountered. This can happen when the " + "sharding does not map well to the mesh shape provided"); + } + product *= device_mesh.dim(axes[mesh_axis_idx]); + mesh_axes_for_this_tensor_dim.insert(axes[mesh_axis_idx]); + mesh_axis_idx++; + } while (product < sharding.tile_assignment().dim(i)); + CHECK(!mesh_axes_for_this_tensor_dim.empty()); + tensor_dim_to_mesh_axis_mapping.push_back(mesh_axes_for_this_tensor_dim); + } + return tensor_dim_to_mesh_axis_mapping; +} + std::vector GetTensorDimToMeshDim( int64_t tensor_shape_rank, const HloSharding& spec, const DeviceMesh& device_mesh, bool consider_reverse_device_meshes) { @@ -1210,18 +1259,11 @@ absl::StatusOr ComputeIntermediateShape(const HloSharding& src_sharding, const Shape& shape, const DeviceMesh& device_mesh) { int64_t src_n_dim = NumTileDimensions(src_sharding); - - const HloSharding* sharding_1d; - - if (src_n_dim == 1) { - sharding_1d = &src_sharding; - } else { - sharding_1d = &dst_sharding; - } + const HloSharding* sharding_1d = + src_n_dim == 1 ? &src_sharding : &dst_sharding; // Find an intermediate shape std::vector inter_shape_dims; - for (size_t i = 0; i < shape.rank(); ++i) { if (sharding_1d->tile_assignment().dim(i) == 1) { inter_shape_dims.push_back(shape.dimensions(i)); @@ -1260,32 +1302,25 @@ HloInstruction* ReshardTensor(HloInstruction* tensor, hlo_sharding_util::ReshapeSharding(shape, *inter_shape, src_sharding); std::optional dst_inter_sharding = hlo_sharding_util::ReshapeSharding(shape, *inter_shape, dst_sharding); - if (!src_inter_sharding.has_value() || !dst_inter_sharding.has_value()) { - src_inter_sharding = HloSharding::Replicate(); - dst_inter_sharding = HloSharding::Replicate(); - LOG(WARNING) << "Invalid mixed mesh shape resharding."; + if (src_inter_sharding.has_value() && dst_inter_sharding.has_value()) { + HloInstruction* src_inter = computation->AddInstruction( + HloInstruction::CreateReshape(*inter_shape, tensor)); + src_inter->set_sharding(*src_inter_sharding); + + HloInstruction* dst_inter = computation->AddInstruction( + HloInstruction::CreateReshape(*inter_shape, src_inter)); + dst_inter->set_sharding(*dst_inter_sharding); + + replace_with = computation->AddInstruction( + HloInstruction::CreateReshape(shape, dst_inter)); + replace_with->set_sharding(dst_sharding); + return replace_with; } - - HloInstruction* src_inter = computation->AddInstruction( - HloInstruction::CreateReshape(*inter_shape, tensor)); - src_inter->set_sharding(*src_inter_sharding); - - HloInstruction* dst_inter = computation->AddInstruction( - HloInstruction::CreateReshape(*inter_shape, src_inter)); - dst_inter->set_sharding(*dst_inter_sharding); - - replace_with = computation->AddInstruction( - HloInstruction::CreateReshape(shape, dst_inter)); - } else { - replace_with = computation->AddInstruction( - HloInstruction::CreateReshape(shape, tensor)); } - } else { - replace_with = computation->AddInstruction( - HloInstruction::CreateReshape(shape, tensor)); } + replace_with = + computation->AddInstruction(HloInstruction::CreateReshape(shape, tensor)); replace_with->set_sharding(dst_sharding); - return replace_with; } @@ -1417,41 +1452,35 @@ absl::Status FixMixedMeshShapeResharding(HloInstruction* inst, int operand_num, // token. CHECK_EQ(operand_num, 1); operand->set_sharding(dst_sharding); - } else { - const HloSharding& src_sharding = operand->sharding(); - HloInstruction* replace_with = nullptr; - // Query cache first - std::vector>* cache_vector = - nullptr; - if (resharding_cache != nullptr) { - cache_vector = &((*resharding_cache)[operand]); - for (const std::pair& entry : - *cache_vector) { - if (entry.first == dst_sharding) { - replace_with = entry.second; - } + return absl::OkStatus(); + } + const HloSharding& src_sharding = operand->sharding(); + HloInstruction* replace_with = nullptr; + // Query cache first + std::vector>* cache_vector = nullptr; + if (resharding_cache != nullptr) { + cache_vector = &((*resharding_cache)[operand]); + for (const std::pair& entry : *cache_vector) { + if (entry.first == dst_sharding) { + replace_with = entry.second; } } + } - if (replace_with != nullptr) { - // Do nothing - } else { - replace_with = - ReshardTensor(operand, src_sharding, dst_sharding, device_mesh); - if (cache_vector != nullptr) { - cache_vector->push_back({dst_sharding, replace_with}); - } + if (replace_with == nullptr) { + replace_with = + ReshardTensor(operand, src_sharding, dst_sharding, device_mesh); + if (cache_vector != nullptr) { + cache_vector->push_back({dst_sharding, replace_with}); } + } - size_t size = ByteSizeOfShape(replace_with->shape()) / (1024 * 1024 * 1024); - if (size > 1) { - LOG(WARNING) << "Large reshape instruction inserted (operand of " - << inst->name() << ") with size " << size - << "GB: " << replace_with->ToString(); - } + size_t size = ByteSizeOfShape(replace_with->shape()) / (1024 * 1024 * 1024); + LOG_IF(WARNING, size > 1) + << "Large reshape instruction inserted (operand of " << inst->name() + << ") with size " << size << "GB: " << replace_with->ToString(); - TF_RETURN_IF_ERROR(inst->ReplaceOperandWith(operand_num, replace_with)); - } + TF_RETURN_IF_ERROR(inst->ReplaceOperandWith(operand_num, replace_with)); return absl::OkStatus(); } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h index 3b8dd44bd094e2..862aefd7755061 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h @@ -24,6 +24,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/functional/function_ref.h" @@ -411,13 +412,13 @@ std::optional PropagateReduceWindowSharding( // For every tile dimension, the device id sequence along that dimension has to // be an arithmetic sequence. // e.g., we don't allow specs like sharding={devices=[8,1] 0,4,1,5,2,7,3,8} -bool IsValidTileAssignment(const HloSharding& spec); +bool IsValidTileAssignment(const HloSharding& sharding); -// Get number of tile dimensions that are not 1. For example, for sharding spec +// Get number of tile dimensions that are not 1. For example, for sharding // {devices=[2,1,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate} -// spec.tile_assignment.num_dimensions() = [2,1,1,4]. This function returns 2. -// -1 means the tensor is replicated on the whole the mesh. -int64_t NumTileDimensions(const HloSharding& spec); +// sharding.tile_assignment.num_dimensions() = [2,1,1,4]. This function +// returns 2. -1 means the tensor is replicated on the whole the mesh. +int64_t NumTileDimensions(const HloSharding& sharding); // When fixing mixed mesh resharding (see below), compute the correct // intermediate shape in order to insert copies. @@ -470,7 +471,17 @@ absl::StatusOr CheckArithmeticSequence( // Checks if the number of sharded dimensions in the tile assignment matches the // device mesh. -bool TileAssignmentMatchesMesh(const HloSharding& spec, const DeviceMesh& mesh); +bool TileAssignmentMatchesMesh(const HloSharding& sharding, + const DeviceMesh& mesh); + +absl::StatusOr> GetMeshDimPermutationOrderInShardingSpec( + const HloSharding& spec, const Array& device_mesh, + bool consider_reverse_device_meshes); + +absl::StatusOr>> +GetTensorDimToMeshDimMixedMeshSharding( + int64_t tensor_shape_rank, const HloSharding& sharding, + const DeviceMesh& device_mesh, bool consider_reverse_device_meshes = false); // Get the mapped mesh dimension for every tensor dimension. // The returned value maps ith tensor dim to one mesh dim. -1 means the tensor diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h index 069fde4e14c580..333df715447f0b 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h @@ -25,6 +25,7 @@ limitations under the License. #include "absl/container/btree_set.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" @@ -39,9 +40,23 @@ limitations under the License. namespace xla { namespace spmd { +// The high-level "recipe" for solving an Auto Sharding problem. +absl::StatusOr Solve( + const HloModule& hlo_module, const HloLiveRange& hlo_live_range, + const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, + const CostGraph& cost_graph, const AliasSet& alias_set, + const std::vector>& node_intervals, + const std::vector>& edge_intervals, + const std::vector>& node_groups, + const std::vector>& edge_groups, + const AutoShardingOption& option, absl::string_view request_prefix, + const absl::flat_hash_map& + sharding_propagation_solution = {}); + // A wrapper around the solver that converts the given objects into a // combinatorial optimization problem & solves it. -AutoShardingSolverResult CallSolver( +absl::StatusOr +CreateAutoShardingSolverRequestAndCallSolver( const HloModule& hlo_module, const HloLiveRange& hlo_live_range, const StrategyMap& strategy_map, const StrategyGroups& strategy_groups, const CostGraph& cost_graph, const AliasSet& alias_set, diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc index 42402e39a1496f..9a68b636b79fa5 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/hlo/experimental/auto_sharding/cluster_environment.h" #include +#include #include #include #include @@ -23,9 +24,14 @@ limitations under the License. #include #include +#include "absl/container/btree_set.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" #include "absl/types/span.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h" +#include "xla/hlo/ir/hlo_sharding.h" #include "xla/service/spmd/spmd_partitioner_util.h" #include "xla/shape.h" @@ -42,12 +48,6 @@ double ClusterEnvironment::AllGatherCost(double num_bytes, int mesh_dim) const { num_bytes / 4, "float32"); } - if (auto_sharding_option_.force_batch_dim_to_mesh_dim == mesh_dim) { - // if data-parallel is forced on this dim, we only allow all-reduce - // in this dimension. - return kInfinityCost; - } - int64_t num_devices = device_mesh_.dim(mesh_dim); return (round(mesh_alpha_[mesh_dim] + mesh_beta_[mesh_dim] * (num_devices - 1) / num_devices * @@ -123,45 +123,83 @@ double ClusterEnvironment::AllToAllCost(double num_bytes, int mesh_dim) const { num_bytes / 4, "float32"); } - if (auto_sharding_option_.force_batch_dim_to_mesh_dim == mesh_dim) { - // if data-parallel is forced on this dim, we only allow all-reduce - // in this dimension. - return kInfinityCost; - } - int64_t num_devices = device_mesh_.dim(mesh_dim); return AllToAllCostUtil(num_bytes, mesh_dim, num_devices); } +template +bool IsSubset(absl::btree_set superset, absl::btree_set subset) { + for (const T& element : subset) { + if (!superset.contains(element)) { + return false; + } + } + return true; +} + // Do not consider device id changes yet. double ClusterEnvironment::ReshardingCostMixedMeshShape( - const Shape& shape, absl::Span src_tensor_dim_to_mesh_dim, - absl::Span dst_tensor_dim_to_mesh_dim) const { + const Shape& shape, const HloSharding& src_sharding, + const HloSharding& dst_sharding) const { + absl::StatusOr>> + src_tensor_dim_to_mesh_axis = GetTensorDimToMeshDimMixedMeshSharding( + shape.rank(), src_sharding, device_mesh_, + /*consider_reverse_device_meshes=*/true); + absl::StatusOr>> + dst_tensor_dim_to_mesh_axis = GetTensorDimToMeshDimMixedMeshSharding( + shape.rank(), dst_sharding, device_mesh_, + /*consider_reverse_device_meshes=*/true); + if (!src_tensor_dim_to_mesh_axis.ok() || !dst_tensor_dim_to_mesh_axis.ok()) { + return OverestimateReplicationCost(shape, src_sharding, device_mesh_); + } + int64_t num_devices = device_mesh_.num_elements(); - double resharding_costs = 0.0; + std::vector collective_mesh_axes; + // Only consider sharded dimensions, do not consider replicate_on_last_dim. for (size_t i = 0; i < shape.rank(); ++i) { - // Only consider sharded dimensions, do not consider replicate_on_last_dim. - if (src_tensor_dim_to_mesh_dim[i] == dst_tensor_dim_to_mesh_dim[i]) { + if ((*src_tensor_dim_to_mesh_axis)[i] == + (*dst_tensor_dim_to_mesh_axis)[i]) { continue; } - if (dst_tensor_dim_to_mesh_dim[i] == -1 || - src_tensor_dim_to_mesh_dim[i] == -1) { - // AllToAll cost - int64_t communication_dim; - if (dst_tensor_dim_to_mesh_dim[i] != -1) { - communication_dim = dst_tensor_dim_to_mesh_dim[i]; - } else { - communication_dim = src_tensor_dim_to_mesh_dim[i]; - } - int64_t communication_bytes = ByteSizeOfShape(shape); - resharding_costs += - AllToAllCostUtil(communication_bytes, communication_dim, num_devices); - } else { + if (IsSubset((*dst_tensor_dim_to_mesh_axis)[i], + (*src_tensor_dim_to_mesh_axis)[i])) { + // do nothing; the dst is sharded more than the src + continue; + } + if (!IsSubset((*src_tensor_dim_to_mesh_axis)[i], + (*dst_tensor_dim_to_mesh_axis)[i])) { // Do not support this sharding, assuming it is gonna be very expensive. - return kInfinityCost; + return OverestimateReplicationCost(shape, src_sharding, device_mesh_); + } + for (int64_t mesh_dim : (*src_tensor_dim_to_mesh_axis)[i]) { + if (!(*dst_tensor_dim_to_mesh_axis)[i].contains(mesh_dim)) { + collective_mesh_axes.push_back(mesh_dim); + } + } + } + + auto is_mesh_axis_used_for_dst_sharding = [&](int64_t mesh_dim) { + int end = dst_sharding.ReplicateOnLastTileDim() + ? dst_tensor_dim_to_mesh_axis->size() - 1 + : dst_tensor_dim_to_mesh_axis->size(); + for (int i = 0; i < end; ++i) { + if ((*dst_tensor_dim_to_mesh_axis)[i].contains(mesh_dim)) { + return true; + } } + return false; + }; + + double resharding_cost = 0.0; + int64_t communication_bytes = ByteSizeOfShape(shape); + for (int mesh_dim : collective_mesh_axes) { + bool used_for_dst_sharding = is_mesh_axis_used_for_dst_sharding(mesh_dim); + resharding_cost += + used_for_dst_sharding + ? AllToAllCostUtil(communication_bytes, mesh_dim, num_devices) + : AllGatherCost(communication_bytes, mesh_dim); } - return resharding_costs; + return resharding_cost; } double ClusterEnvironment::CollectivePermuteCost( @@ -194,17 +232,16 @@ double ClusterEnvironment::CollectivePermuteCost( // Overestimate the cost of replicating a tensor by decomposing the resharding // operation as an all-gather on all mesh dimensions. double ClusterEnvironment::OverestimateReplicationCost( - const Shape& shape, const HloSharding& src_spec, + const Shape& shape, const HloSharding& src_sharding, const DeviceMesh& device_mesh) const { - if (src_spec.IsTileMaximal() || src_spec.IsManual()) { - // TODO(b/238210866) Do not use kInfinityCost. - return kInfinityCost; + if (src_sharding.IsReplicated()) { + return 0; } - int64_t bytes_moved = ByteSizeOfShapeWithSharding(shape, src_spec); + int64_t bytes_moved = ByteSizeOfShapeWithSharding(shape, src_sharding); double cost = 0.0; for (size_t i = 0; i < device_mesh.num_dimensions(); ++i) { - auto this_cost = this->AllGatherCost(bytes_moved, i); - cost += this_cost; + cost += src_sharding.IsTileMaximal() ? this->AllReduceCost(bytes_moved, i) + : this->AllGatherCost(bytes_moved, i); bytes_moved *= device_mesh.dimensions()[i]; } return cost; @@ -325,8 +362,7 @@ double ClusterEnvironment::ReshardingCost(const Shape& shape, dst_tensor_dim_to_mesh_dim_or.value(); if (src_n_dim != dst_n_dim && src_n_dim != -1 && dst_n_dim != -1) { - return ReshardingCostMixedMeshShape(shape, src_tensor_dim_to_mesh_dim, - dst_tensor_dim_to_mesh_dim); + return ReshardingCostMixedMeshShape(shape, src_spec, dst_spec); } AdjustTensorMeshDimMapping(src_tensor_dim_to_mesh_dim, src_n_dim); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h index d17b026dd8ffb4..89b81133c95d0d 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h @@ -145,9 +145,9 @@ class ClusterEnvironment { double AllToAllCost(double num_bytes, int mesh_dim) const; - double ReshardingCostMixedMeshShape( - const Shape& shape, absl::Span src_tensor_dim_to_mesh_dim, - absl::Span dst_tensor_dim_to_mesh_dim) const; + double ReshardingCostMixedMeshShape(const Shape& shape, + const HloSharding& src_sharding, + const HloSharding& dst_sharding) const; double CollectivePermuteCost( double num_bytes, diff --git a/third_party/xla/xla/hlo/ir/BUILD b/third_party/xla/xla/hlo/ir/BUILD index e65c48d982d89d..39955314bca2df 100644 --- a/third_party/xla/xla/hlo/ir/BUILD +++ b/third_party/xla/xla/hlo/ir/BUILD @@ -68,6 +68,7 @@ cc_library( "//xla:literal_util", "//xla:printer", "//xla:protobuf_util", + "//xla:shape_layout", "//xla:shape_tree", "//xla:shape_util", "//xla:sort_json", @@ -77,14 +78,16 @@ cc_library( "//xla:window_util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", + "//xla/hlo/parser:hlo_lexer", "//xla/service:compilation_environments", "//xla/service:computation_layout", "//xla/service:computation_placer_hdr", - "//xla/service:hlo_lexer", "//xla/service:hlo_module_config", "//xla/service:hlo_proto_cc", "//xla/service:mapped_ptr_container_sorter", "//xla/service:name_uniquer", + "//xla/tsl/lib/gtl:iterator_range", + "//xla/tsl/lib/gtl:map_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:btree", @@ -92,6 +95,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/hash", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", @@ -102,8 +106,6 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/gtl:iterator_range", - "@local_tsl//tsl/lib/gtl:map_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:fingerprint", @@ -114,6 +116,25 @@ cc_library( ], ) +xla_cc_test( + name = "hlo_module_test", + srcs = ["hlo_module_test.cc"], + deps = [ + ":hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/service:hlo_module_config", + "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/hash", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + cc_library( name = "backend_config", srcs = ["backend_config.cc"], @@ -150,9 +171,13 @@ cc_library( hdrs = ["hlo_module_group.h"], deps = [ ":hlo", + "//xla:status_macros", + "//xla/service:hlo_module_config", "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", ], ) @@ -162,32 +187,24 @@ cc_library( hdrs = ["hlo_instruction_utils.h"], deps = [ ":hlo", + "//xla:xla_data_proto_cc", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", ], ) cc_library( name = "hlo_reachability", - srcs = ["hlo_reachability.cc"], hdrs = ["hlo_reachability.h"], - deps = [ - ":hlo", - "//xla:types", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/types:span", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/analysis:hlo_reachability instead.", + deps = ["//xla/hlo/analysis:hlo_reachability"], ) cc_library( name = "hlo_dfs_reachability", - srcs = ["hlo_dfs_reachability.cc"], hdrs = ["hlo_dfs_reachability.h"], - deps = [ - ":hlo", - "@com_google_absl//absl/algorithm:container", - "@llvm-project//llvm:Support", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/analysis:hlo_dfs_reachability instead.", + deps = ["//xla/hlo/analysis:hlo_dfs_reachability"], ) cc_library( @@ -223,7 +240,9 @@ cc_library( "//xla:util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", ], ) diff --git a/third_party/xla/xla/hlo/ir/collective_device_list.cc b/third_party/xla/xla/hlo/ir/collective_device_list.cc index efa009c72fefae..2b431770512391 100644 --- a/third_party/xla/xla/hlo/ir/collective_device_list.cc +++ b/third_party/xla/xla/hlo/ir/collective_device_list.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/service/hlo.pb.h" #include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "tsl/platform/protobuf.h" diff --git a/third_party/xla/xla/hlo/ir/collective_device_list.h b/third_party/xla/xla/hlo/ir/collective_device_list.h index 66d898e2ccaa45..06228143b5cdfe 100644 --- a/third_party/xla/xla/hlo/ir/collective_device_list.h +++ b/third_party/xla/xla/hlo/ir/collective_device_list.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/types/span.h" +#include "xla/array.h" #include "xla/hlo/ir/tile_assignment.h" #include "xla/service/hlo.pb.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/hlo/ir/dfs_hlo_visitor.cc b/third_party/xla/xla/hlo/ir/dfs_hlo_visitor.cc index b89eb721ac1170..d0f6bfb6fe1357 100644 --- a/third_party/xla/xla/hlo/ir/dfs_hlo_visitor.cc +++ b/third_party/xla/xla/hlo/ir/dfs_hlo_visitor.cc @@ -17,8 +17,9 @@ limitations under the License. #include +#include "absl/status/status.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/types.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/util.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/hlo/ir/dfs_hlo_visitor.h b/third_party/xla/xla/hlo/ir/dfs_hlo_visitor.h index d69ceb51dd97d5..cd4fa0db260b78 100644 --- a/third_party/xla/xla/hlo/ir/dfs_hlo_visitor.h +++ b/third_party/xla/xla/hlo/ir/dfs_hlo_visitor.h @@ -135,6 +135,7 @@ class DfsHloVisitorBase { virtual absl::Status HandleConvolution(HloInstructionPtr hlo) = 0; virtual absl::Status HandleOptimizationBarrier(HloInstructionPtr hlo) = 0; virtual absl::Status HandlePartitionId(HloInstructionPtr hlo) = 0; + virtual absl::Status HandleRaggedAllToAll(HloInstructionPtr hlo) = 0; virtual absl::Status HandleReduceScatter(HloInstructionPtr hlo) = 0; virtual absl::Status HandleReplicaId(HloInstructionPtr hlo) = 0; /* go/keep-sorted end */ diff --git a/third_party/xla/xla/hlo/ir/dfs_hlo_visitor_with_default.h b/third_party/xla/xla/hlo/ir/dfs_hlo_visitor_with_default.h index 37e52646ee83ff..452f9a2cdac393 100644 --- a/third_party/xla/xla/hlo/ir/dfs_hlo_visitor_with_default.h +++ b/third_party/xla/xla/hlo/ir/dfs_hlo_visitor_with_default.h @@ -120,6 +120,9 @@ class DfsHloVisitorWithDefaultBase absl::Status HandleAllToAll(HloInstructionPtr hlo) override { return DefaultAction(hlo); } + absl::Status HandleRaggedAllToAll(HloInstructionPtr hlo) override { + return DefaultAction(hlo); + } absl::Status HandleCollectiveBroadcast(HloInstructionPtr hlo) override { return DefaultAction(hlo); } diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.cc b/third_party/xla/xla/hlo/ir/hlo_computation.cc index 6ab8e712c1eee7..bfabb382eda52c 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.cc +++ b/third_party/xla/xla/hlo/ir/hlo_computation.cc @@ -34,9 +34,13 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/dfs_hlo_visitor.h" #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -46,17 +50,22 @@ limitations under the License. #include "xla/hlo/ir/ptrvec.h" #include "xla/map_util.h" #include "xla/printer.h" +#include "xla/service/hlo.pb.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/mapped_ptr_container_sorter.h" #include "xla/service/name_uniquer.h" #include "xla/shape.h" +#include "xla/shape_layout.h" +#include "xla/shape_tree.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/tsl/lib/gtl/iterator_range.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/gtl/iterator_range.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -680,6 +689,7 @@ HloComputation::ChannelDependencies HloComputation::ComputeChannelDependencies() case HloOpcode::kAllToAll: case HloOpcode::kCollectiveBroadcast: case HloOpcode::kCollectivePermute: + case HloOpcode::kRaggedAllToAll: case HloOpcode::kReduceScatter: { HloInstruction* instruction = inst.inst(); std::optional channel_id = instruction->channel_id(); @@ -1156,7 +1166,8 @@ absl::StatusOr HloComputation::CreateAsyncInstructions( HloInstruction* root = builder.AddInstruction( instruction->CloneWithNewOperands(instruction->shape(), parameters)); if (override_names) { - root->SetAndSanitizeName(absl::StrCat(instruction->name(), ".cloned")); + parent()->SetAndUniquifyInstrName( + root, absl::StrCat(instruction->name(), ".cloned")); } HloComputation* async_computation = parent_->AddEmbeddedComputation(builder.Build(root)); @@ -1171,9 +1182,10 @@ absl::StatusOr HloComputation::CreateAsyncInstructions( async_done = AddInstruction( HloInstruction::CreateAsyncDone(root->shape(), async_start)); if (override_names) { - async_start->SetAndSanitizeName( - absl::StrCat(root->name(), ".call-start")); - async_done->SetAndSanitizeName(absl::StrCat(root->name(), ".call-done")); + parent()->SetAndUniquifyInstrName( + async_start, absl::StrCat(root->name(), ".call-start")); + parent()->SetAndUniquifyInstrName( + async_done, absl::StrCat(root->name(), ".call-done")); } } async_start->set_metadata(instruction->metadata()); @@ -1430,8 +1442,8 @@ absl::StatusOr HloComputation::ReplaceInstructionWithDifferentShape( old_instruction->frontend_attributes()); } if (auto old_original_value = old_instruction->original_value()) { - // Fusions are handled separately. The original_value attribute of fused - // instructions is copied when they are added into the fused computation. + // Fusions are handled separately. The original value of fused instructions + // is copied when they are added into the fused computation. if (new_instruction->opcode() != HloOpcode::kFusion) { if (ShapeUtil::Compatible(old_instruction->shape(), new_instruction->shape())) { diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.h b/third_party/xla/xla/hlo/ir/hlo_computation.h index 3e73a68762e74f..d603535511531e 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.h +++ b/third_party/xla/xla/hlo/ir/hlo_computation.h @@ -48,9 +48,9 @@ limitations under the License. #include "xla/shape_tree.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/tsl/lib/gtl/iterator_range.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/gtl/iterator_range.h" #include "tsl/platform/errors.h" namespace xla { @@ -202,6 +202,39 @@ class HloComputation { ~HloComputation(); + enum class InstructionType : uint8_t { + kUnset, + // This computation is a fusion computation. A fusion computation ordinarily + // also has a non-null instruction. However, if a fusion instruction + // is removed during compilation, the fusion computation becomes + // unreachable, and its instruction is set to null. We still need to regard + // such computations as fusion computations for HLO scheduling purposes. + kFusion, + // This computation is a custom-call computation. + kCustomCall, + // This computation is a collective computation. + kCollective, + // This computation is a while body computation. + kWhile, + // This computation is a conditional branch computation. + kConditional, + // Last Value for range checking. + kLast = kConditional, + }; + static constexpr uintptr_t kInstructionTypeMask = 0b111; + static_assert(static_cast(InstructionType::kUnset) == 0, + "kUnset must be 0."); + + InstructionType instruction_type() const { + return static_cast(instruction_and_type_ & + kInstructionTypeMask); + } + + HloInstruction* instruction() const { + DCHECK(instruction_type() <= InstructionType::kLast); + return reinterpret_cast(instruction_and_type_ & + ~kInstructionTypeMask); + } // Add an instruction to the computation. The computation takes ownership of // the instruction. HloInstruction* AddInstruction(std::unique_ptr instruction, @@ -787,17 +820,23 @@ class HloComputation { } // Returns if this computation is a body computation of a while. + [[deprecated( + "This is broken. Use CallGraph::GetComputationCallers() instead")]] bool IsWhileBodyComputation() const { return instruction_type() == InstructionType::kWhile; } // Returns the owning while call instruction, or nullptr if this is not a // while call body computation. + [[deprecated( + "This is broken. Use CallGraph::GetComputationCallers() instead")]] HloInstruction* WhileCallInstruction() const { return instruction_type() == InstructionType::kWhile ? instruction() : nullptr; } + [[deprecated( + "This is broken. Use CallGraph::GetComputationCallers() instead")]] void SetWhileCallInstruction(HloInstruction* while_call_instruction) { CHECK(while_call_instruction != nullptr); CHECK(while_call_instruction->opcode() == HloOpcode::kWhile); @@ -805,17 +844,23 @@ class HloComputation { } // Returns if this computation is a branch computation of a conditional. + [[deprecated( + "This is broken. Use CallGraph::GetComputationCallers() instead")]] bool IsConditionalBranchComputation() const { return instruction_type() == InstructionType::kConditional; } // Returns the owning conditional call instruction, or nullptr if this is not // a conditional branch computation. + [[deprecated( + "This is broken. Use CallGraph::GetComputationCallers() instead")]] HloInstruction* ConditionalCallInstruction() const { return instruction_type() == InstructionType::kConditional ? instruction() : nullptr; } + [[deprecated( + "This is broken. Use CallGraph::GetComputationCallers() instead")]] void SetConditionalCallInstruction( HloInstruction* conditional_call_instruction) { CHECK(conditional_call_instruction != nullptr); @@ -826,6 +871,18 @@ class HloComputation { // Returns if this computation is an async computation. bool IsAsyncComputation() const { return async_start_ != nullptr; } + // Returns true if this computation only contains send/recv instructions. + bool OnlyContainsSendRecv() { + for (const HloInstruction* instruction : this->instructions()) { + if (!HloPredicateIsOp( + instruction)) { + return false; + } + } + return true; + } + // Returns the owning async instruction. It's nullptr if this is not an async // computation. HloInstruction* AsyncStart() const { return async_start_; } @@ -932,37 +989,6 @@ class HloComputation { absl::Status RemoveInstructionImpl(HloInstruction* instruction, bool ignore_safety_check); - enum class InstructionType : uint8_t { - kUnset, - // This computation is a fusion computation. A fusion computation ordinarily - // also has a non-null instruction. However, if a fusion instruction - // is removed during compilation, the fusion computation becomes - // unreachable, and its instruction is set to null. We still need to regard - // such computations as fusion computations for HLO scheduling purposes. - kFusion, - // This computation is a custom-call computation. - kCustomCall, - // This computation is a collective computation. - kCollective, - // This computation is a while body computation. - kWhile, - // This computation is a conditional branch computation. - kConditional, - }; - static constexpr uintptr_t kInstructionTypeMask = 0b111; - static_assert(static_cast(InstructionType::kUnset) == 0, - "kUnset must be 0."); - - InstructionType instruction_type() const { - return static_cast(instruction_and_type_ & - kInstructionTypeMask); - } - - HloInstruction* instruction() const { - return reinterpret_cast(instruction_and_type_ & - ~kInstructionTypeMask); - } - void SetInstruction(HloInstruction* instruction, InstructionType type); int64_t unique_id_; diff --git a/third_party/xla/xla/hlo/ir/hlo_dfs_reachability.h b/third_party/xla/xla/hlo/ir/hlo_dfs_reachability.h index 3db9a5309b4efc..446be761b96228 100644 --- a/third_party/xla/xla/hlo/ir/hlo_dfs_reachability.h +++ b/third_party/xla/xla/hlo/ir/hlo_dfs_reachability.h @@ -16,46 +16,7 @@ limitations under the License. #ifndef XLA_HLO_IR_HLO_DFS_REACHABILITY_H_ #define XLA_HLO_IR_HLO_DFS_REACHABILITY_H_ -#include -#include - -#include "llvm/ADT/DenseMap.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" - -namespace xla { - -// A simple DFS-based reachability analysis for HLO instructions. -// -// When the class is created, the instructions are ordered in a defs-before-uses -// topological order. -// The reachability query runs a DFS from the destination node (going up through -// operands / control predecessors), and stops when the instruction's index in -// the defs-before-uses list is before the source node. As the reachability is -// tested for nodes that are close to each other, this optimization works well, -// and the time is dominated by the post-order sort. -class HloDfsReachability { - public: - // Returns true iff the instruction was present in the computation passed to - // Build(). The calling code may want to still use the class after the - // computation is modified, if it's known that the def-before-use order is - // still preserved. - bool IsPresent(const HloInstruction* instruction) const; - // Returns true iff there is a path (with edges being users and control - // successors) from 'from' to 'to'. (i.e. path from definitions to uses; from - // producers to consumers) - bool IsReachable(const HloInstruction* from, const HloInstruction* to) const; - // Returns true iff either `a` is reachable from `b` or `b` is reachable from - // `a`. - bool IsConnected(const HloInstruction* a, const HloInstruction* b) const; - static std::unique_ptr Build( - const HloComputation* computation); - - private: - // LLVM dense map shows ~10-20% speedup compared to absl::flat_hash_map. - llvm::DenseMap instruction_to_idx_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/analysis/hlo_dfs_reachability.h" #endif // XLA_HLO_IR_HLO_DFS_REACHABILITY_H_ diff --git a/third_party/xla/xla/hlo/ir/hlo_input_output_alias_config.cc b/third_party/xla/xla/hlo/ir/hlo_input_output_alias_config.cc index c54af5675f4a7c..c31d7388b7b00e 100644 --- a/third_party/xla/xla/hlo/ir/hlo_input_output_alias_config.cc +++ b/third_party/xla/xla/hlo/ir/hlo_input_output_alias_config.cc @@ -22,17 +22,21 @@ limitations under the License. #include #include +#include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/layout_util.h" #include "xla/service/hlo.pb.h" #include "xla/shape.h" +#include "xla/shape_tree.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index 37d7a39d8ee0e0..4240b1f12d2d3f 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -63,23 +63,23 @@ limitations under the License. #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/hlo_sharding_metadata.h" #include "xla/hlo/ir/ptrvec.h" +#include "xla/hlo/parser/hlo_lexer.h" #include "xla/layout.h" #include "xla/literal.h" #include "xla/map_util.h" #include "xla/primitive_util.h" #include "xla/printer.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_lexer.h" #include "xla/service/mapped_ptr_container_sorter.h" #include "xla/service/name_uniquer.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/sort_json.h" #include "xla/status_macros.h" +#include "xla/tsl/lib/gtl/iterator_range.h" +#include "xla/tsl/lib/gtl/map_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/gtl/iterator_range.h" -#include "tsl/lib/gtl/map_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "tsl/platform/status.h" @@ -747,6 +747,18 @@ absl::StatusOr> HloInstruction::CreateFromProto( proto.constrain_layout(), channel_id, split_dimension); break; } + case HloOpcode::kRaggedAllToAll: { + std::optional channel_id; + if (proto.channel_id() > 0) { + channel_id = proto.channel_id(); + } + TF_RET_CHECK(all_operands().size() == 6) + << "RaggedAllToAll must have 6 operands"; + instruction = CreateRaggedAllToAll(shape, all_operands(), + CollectiveDeviceList::FromProto(proto), + channel_id); + break; + } case HloOpcode::kCollectiveBroadcast: { std::optional channel_id; if (proto.channel_id() > 0) { @@ -1660,6 +1672,24 @@ HloInstruction::CreateAllReduceStart( constrain_layout, channel_id, split_dimension); } +/* static */ std::unique_ptr +HloInstruction::CreateRaggedAllToAll(const Shape& shape, + absl::Span operands, + const CollectiveDeviceList& device_list, + const std::optional& channel_id) { + return std::make_unique( + shape, operands, device_list, channel_id); +} + +/* static */ std::unique_ptr +HloInstruction::CreateRaggedAllToAll( + const Shape& shape, absl::Span operands, + absl::Span replica_groups, + const std::optional& channel_id) { + return CreateRaggedAllToAll(shape, operands, + CollectiveDeviceList(replica_groups), channel_id); +} + /* static */ std::unique_ptr HloInstruction::CreateCollectiveBroadcast( const Shape& shape, absl::Span operands, @@ -2174,8 +2204,10 @@ HloInstruction::CreateDynamicReshape( } /* static */ std::unique_ptr HloInstruction::CreateFusion( - const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { - return std::make_unique(shape, fusion_kind, fused_root); + const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root, + absl::string_view prefix) { + return std::make_unique(shape, fusion_kind, fused_root, + prefix); } /* static */ std::unique_ptr HloInstruction::CreateFusion( @@ -2214,9 +2246,11 @@ void HloInstruction::SetupDerivedInstruction( derived_instruction->mutable_rare()->frontend_attributes.Clear(); derived_instruction->mutable_rare()->statistics_viz.Clear(); } - // If the derived instruction has the same opcode as current, - // then the backend config is also applicable. - if (opcode() == derived_instruction->opcode() && has_backend_config()) { + // If the derived instruction has the same opcode as current, then the backend + // config is also applicable (only if derived instruction doesn't have its own + // backend config which might be different from the original one). + if (opcode() == derived_instruction->opcode() && has_backend_config() && + !derived_instruction->has_backend_config()) { derived_instruction->CopyBackendConfigFrom(this); } } @@ -2759,6 +2793,20 @@ int64_t HloInstruction::operand_index(const HloInstruction* target) const { LOG(FATAL) << "target was not an operand: " << target->ToString(); } +std::vector HloInstruction::operand_indices( + const HloInstruction* target) const { + std::vector indices; + for (int64_t i = 0; i < operand_count(); ++i) { + if (target == operand(i)) { + indices.push_back(i); + } + } + if (indices.empty()) { + LOG(FATAL) << "target was not an operand: " << target->ToString(); + } + return indices; +} + HloInstruction::InstructionVector HloInstruction::unique_operands() const { InstructionVector unique; absl::flat_hash_set seen; @@ -3093,6 +3141,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kDot: case HloOpcode::kDomain: case HloOpcode::kGetDimensionSize: + case HloOpcode::kRaggedAllToAll: case HloOpcode::kSetDimensionSize: case HloOpcode::kTriangularSolve: case HloOpcode::kCholesky: @@ -3239,6 +3288,55 @@ absl::Status HloInstruction::Defuse() { return module->RemoveEmbeddedComputation(fused_computation); } +absl::StatusOr HloInstruction::UnfuseInstruction( + HloInstruction* instruction) { + CHECK_EQ(opcode(), HloOpcode::kFusion); + + std::vector new_operands; + // Gather the operands that need to be extracted from the fusion. + for (int64_t operand_num = 0; operand_num < instruction->operand_count(); + ++operand_num) { + HloInstruction* operand = instruction->mutable_operand(operand_num); + if (operand->opcode() == HloOpcode::kParameter) { + // If the operand is a parameter of the fusion, we need to extract it. + HloInstruction* extracted_operand = + mutable_operand(operand->parameter_number()); + new_operands.push_back(extracted_operand); + } else if (operand->opcode() == HloOpcode::kConstant) { + HloInstruction* cloned_constant = AddInstruction(operand->Clone()); + new_operands.push_back(cloned_constant); + } else if (operand->opcode() == HloOpcode::kBroadcast && + operand->operand(0)->opcode() == HloOpcode::kConstant) { + HloInstruction* cloned_constant = + AddInstruction(operand->operand(0)->Clone()); + new_operands.push_back(AddInstruction( + operand->CloneWithNewOperands(operand->shape(), {cloned_constant}))); + } else { + return InvalidArgument( + "Unsupported operand type for unfusing: %s. Currently only " + "parameters and constants are supported.", + operand->ToString()); + } + } + + // Clone the instruction to be unfused. + HloInstruction* unfused_instruction = AddInstruction( + instruction->CloneWithNewOperands(instruction->shape(), new_operands)); + + // Add the unfused instruction as a parameter to the fusion instruction. + HloComputation* fusion_computation = fused_instructions_computation(); + + HloInstruction* new_parameter = AddFusionOperand(unfused_instruction); + // Replace the instruction in the fusion computation with the new parameter. + TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(new_parameter)); + + // Remove the original instruction from the fusion computation. + TF_RETURN_IF_ERROR( + fusion_computation->RemoveInstructionAndUnusedOperands(instruction)); + + return unfused_instruction; +} + absl::Status HloInstruction::ReplaceUsesWith( absl::Span users, HloInstruction* new_producer) { TF_RET_CHECK( @@ -3399,18 +3497,30 @@ const PtrVec& HloInstruction::branch_computations() const { return called_computations(); } -int HloInstruction::branch_count() const { +int32_t HloInstruction::branch_count() const { CHECK(HloOpcode::kConditional == opcode_); return called_computations().size(); } -HloComputation* HloInstruction::branch_computation(int b) const { - CHECK(HloOpcode::kConditional == opcode_); +HloComputation* HloInstruction::branch_computation(int32_t b) const { + CHECK_EQ(HloOpcode::kConditional, opcode_); CHECK_GE(b, 0); CHECK_LT(b, called_computations().size()); return called_computations()[b]; } +int32_t HloInstruction::branch_index(HloComputation* computation) const { + CHECK_EQ(HloOpcode::kConditional, opcode_); + CHECK_NE(computation, nullptr); + for (int32_t idx = 0; idx < branch_count(); idx++) { + if (branch_computation(idx) == computation) { + return idx; + } + } + LOG(FATAL) << absl::StrFormat("Conditional %s does not contain branch %s", + name(), computation->name()); +} + void HloInstruction::set_branch_computation(int b, HloComputation* computation) { CHECK_EQ(HloOpcode::kConditional, opcode_); @@ -3680,8 +3790,8 @@ void HloInstruction::PrintWithCanonicalNameMap( }); PrintExtraAttributes(attr_printer, options); - if (original_value_) { - printer->Append(", original_value={"); + if (options.print_original_value() && original_value_) { + printer->Append(", origin={"); printer->Append(OriginalValueToString(*original_value())); printer->Append("}"); } @@ -4094,20 +4204,7 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_statistics_viz() = statistics_viz(); if (original_value_) { - xla::OriginalValueProto* original_value_proto = - proto.mutable_original_value(); - for (const auto& leaf : original_value_->leaves()) { - OriginalArrayProto* original_array_proto = - original_value_proto->add_leaves(); - for (const auto& index : leaf.first) { - original_array_proto->add_leaf_shape_index(index); - } - *original_array_proto->mutable_instruction_name() = - leaf.second->instruction_name; - for (const auto& index : leaf.second->shape_index) { - original_array_proto->add_shape_index(index); - } - } + *proto.mutable_original_value() = OriginalValueToProto(*original_value_); } return proto; @@ -4290,6 +4387,8 @@ absl::Status HloInstruction::Visit( return visitor->HandleAllReduceDone(this); case HloOpcode::kAllToAll: return visitor->HandleAllToAll(this); + case HloOpcode::kRaggedAllToAll: + return visitor->HandleRaggedAllToAll(this); case HloOpcode::kCollectiveBroadcast: return visitor->HandleCollectiveBroadcast(this); case HloOpcode::kCollectivePermute: diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h index 3dcf016acd6cd0..9b8d6c1e02814a 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.h +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h @@ -64,8 +64,8 @@ limitations under the License. #include "xla/service/name_uniquer.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/lib/gtl/iterator_range.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/gtl/iterator_range.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "tsl/platform/protobuf.h" @@ -96,6 +96,7 @@ class HloPrintOptions { indent_amount_(0), print_large_constants_(false), print_only_essential_constants_(false), + print_original_value_(true), print_metadata_(true), print_metadata_only_op_name_(false), print_backend_config_(true), @@ -201,6 +202,11 @@ class HloPrintOptions { return *this; } + // If true, origin will be printed. + HloPrintOptions& set_print_original_value(bool value) { + print_original_value_ = value; + return *this; + } // If true, metadata will be printed. HloPrintOptions& set_print_metadata(bool value) { print_metadata_ = value; @@ -387,6 +393,7 @@ class HloPrintOptions { PrintSubcomputationMode print_subcomputation_mode() const { return print_subcomputation_mode_; } + bool print_original_value() const { return print_original_value_; } bool print_metadata() const { return print_metadata_; } bool print_metadata_only_op_name() const { return print_metadata_only_op_name_; @@ -430,6 +437,7 @@ class HloPrintOptions { int indent_amount_; bool print_large_constants_; bool print_only_essential_constants_; + bool print_original_value_; bool print_metadata_; bool print_metadata_only_op_name_; bool print_backend_config_; @@ -1009,6 +1017,59 @@ class HloInstruction { const std::optional& channel_id, const std::optional& split_dimension = std::nullopt); + // The RaggedAllToAll instruction performs a collective all-to-all operation, + // where the input and output are ragged tensors. + // + // Ragged tensors are defined by a set of three tensors: + // *) ‘data’: the ‘data’ tensor is “ragged” along its outermost dimension, + // along which each indexed element has variable size. + // *) ‘offsets’: the ‘offsets’ tensor indexes the outermost dimension of the + // ‘data’ tensor, and represents the starting offset of each ragged element + // of the ‘data’ tensor. + // *) ‘sizes’: the ‘sizes’ tensor represents the size of each ragged element + // of the ‘data’ tensor, where the size is specified in units of + // sub-elements. A sub-element is defined as the suffix of the ‘data’ tensor + // shape obtained by removing the outermost “ragged” dimension. + // *) The ‘offsets’ and ‘sizes’ tensors must have the same size. + // + // An example ragged tensor + // data: [8,3] = + // {{a,b,c},{d,e,f},{g,h,i},{j,k,l},{m,n,o},{p,q,r},{s,t,u},{v,w,x}} + // offsets: [3] = {0, 1, 4} + // sizes: [3] = {1, 3, 4} + // + // Index 'data' at 'offsets'[0], 'sizes'[0]' + // {a,b,c} + // + // Index 'data' at 'offsets'[1], 'sizes'[1]' + // {d,e,f},{g,h,i},{j,k,l} + // + // Index 'data' at 'offsets'[2], 'sizes'[2]' + // {m,n,o},{p,q,r},{s,t,u},{v,w,x} + // + // The ragged all-to-all HLO has the following arguments: + // input: ragged input data tensor. + // input_offsets: ragged input offsets tensor. + // input_sizes: ragged input sizes tensor. + // output: ragged output data tensor. + // output_offsets: ragged output offsets tensor. + // output_sizes: ragged output sizes tensor. + // + // The '*_offsets' and '*_sizes' tensors must have the same shape. + // The output buffer is passed in as an input (and aliased in the output), + // to support incremental updates to the same buffer. + // + static std::unique_ptr CreateRaggedAllToAll( + const Shape& shape, absl::Span operands, + const CollectiveDeviceList& device_list, + const std::optional& channel_id); + + ABSL_DEPRECATED("Use CollectiveDeviceList instead of list of ReplicaGroup.") + static std::unique_ptr CreateRaggedAllToAll( + const Shape& shape, absl::Span operands, + absl::Span replica_groups, + const std::optional& channel_id); + // Creates a communication instruction that broadcasts data cross replicas. // Data is sent from to the first replica id in each group to the other ids in // the same group. If a replica id is not a in any replica group, the output @@ -1347,7 +1408,8 @@ class HloInstruction { // "fused_root". Additional instructions can be added to the fusion // instruction with the method FuseInstruction. static std::unique_ptr CreateFusion( - const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root); + const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root, + absl::string_view prefix = ""); static std::unique_ptr CreateFusion( const Shape& shape, FusionKind fusion_kind, @@ -1493,10 +1555,14 @@ class HloInstruction { // within the operand vector. InstructionVector unique_operands() const; - // Returns the index of 'target' in the operands sequence. + // Returns the first index of 'target' that occurs in the operands sequence. // Precondition: target must be an operand (or a fatal error will occur). int64_t operand_index(const HloInstruction* target) const; + // Returns all indices of 'target' that occur in the operands sequence. + // Precondition: target must be an operand (or a fatal error will occur). + std::vector operand_indices(const HloInstruction* target) const; + // Returns the number of users of this instruction. int64_t user_count() const { return users_.size(); } @@ -1681,6 +1747,13 @@ class HloInstruction { // Decomposes fusion back to individual parts. absl::Status Defuse(); + // Unfuses the given instruction from its fusion computation. If the given + // instruction is not fused, this is a no-op and returns nullptr. Returns a + // pointer to the newly unfused instruction if successful. Currently, fused + // instructions with parameter or constant operands are supported. + absl::StatusOr UnfuseInstruction( + HloInstruction* instruction); + // Replaces all uses of this instruction with the new producer. If // new_producer is a user of this instruction then new_producer remains a use // of this instruction to avoid introducing cycles into the graph. @@ -1808,8 +1881,9 @@ class HloInstruction { // // Precondition: The instruction is a Conditional instruction. const PtrVec& branch_computations() const; - int branch_count() const; - HloComputation* branch_computation(int b) const; + int32_t branch_count() const; + HloComputation* branch_computation(int32_t b) const; + int32_t branch_index(HloComputation* computation) const; // Sets a branch HloComputation for Conditional. // The setter should only be called by HloModule or HloComputation methods. // diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction_utils.cc b/third_party/xla/xla/hlo/ir/hlo_instruction_utils.cc index 52eac784f085d5..c500b0ccd079c1 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction_utils.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction_utils.cc @@ -18,8 +18,11 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/xla_data.pb.h" namespace xla { namespace hlo_instruction_utils { diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.cc b/third_party/xla/xla/hlo/ir/hlo_instructions.cc index 0f88a38a0448e2..98e2a2d9058966 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instructions.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instructions.cc @@ -59,10 +59,10 @@ limitations under the License. #include "xla/service/hlo.pb.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/lib/gtl/iterator_range.h" #include "xla/util.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/gtl/iterator_range.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "tsl/platform/protobuf.h" @@ -1211,6 +1211,40 @@ bool HloAllToAllInstruction::IdenticalSlowPathIgnoringChannelIdValues( split_dimension_ == casted_other.split_dimension(); } +HloRaggedAllToAllInstruction::HloRaggedAllToAllInstruction( + const Shape& shape, absl::Span operands, + const CollectiveDeviceList& device_list, + const std::optional& channel_id) + : HloCollectiveInstruction(HloOpcode::kRaggedAllToAll, shape, operands, + device_list, + /*constrain_layout=*/false, channel_id) {} + +HloRaggedAllToAllInstruction::HloRaggedAllToAllInstruction( + HloOpcode opcode, const Shape& shape, + absl::Span operands, + absl::Span replica_groups, + const std::optional& channel_id) + : HloRaggedAllToAllInstruction( + shape, operands, CollectiveDeviceList(replica_groups), channel_id) {} + +std::unique_ptr +HloRaggedAllToAllInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* /*context*/) const { + return std::make_unique( + shape, new_operands, device_list(), channel_id()); +} + +HloInstructionProto HloRaggedAllToAllInstruction::ToProto() const { + HloInstructionProto proto = HloCollectiveInstruction::ToProto(); + return proto; +} + +void HloRaggedAllToAllInstruction::PrintExtraAttributesImpl( + AttributePrinter& printer, const HloPrintOptions& options) const { + HloCollectiveInstruction::PrintExtraAttributesImpl(printer, options); +} + HloCollectiveBroadcastInstruction::HloCollectiveBroadcastInstruction( HloOpcode opcode, const Shape& shape, absl::Span operands, @@ -2038,7 +2072,7 @@ HloCallableInstruction::CloneAndAppendInstructionIntoCalledComputation( // instruction. Add it as an operand and add a corresponding called // computation parameter instruction. - // No need to create an original_value for an added parameter as the + // No need to create an original value for an added parameter as the // original value is saved in the corresponding argument. called_computation_parameter = AddCallOperand(operand); } @@ -2047,7 +2081,7 @@ HloCallableInstruction::CloneAndAppendInstructionIntoCalledComputation( } if (clone != instruction_to_append) { - // Copy over original_value attribute to the clone of a fused instruction. + // Copy over the original value to the clone of a fused instruction. if (auto original_value = instruction_to_append->original_value()) { clone->set_original_value(original_value); } @@ -2092,7 +2126,7 @@ HloCallableInstruction::CloneAndAppendInstructionIntoCalledComputation( HloInstruction* new_root = called_computation()->AddInstruction( HloInstruction::CreateTuple(tuple_elements)); - // No need to create an original_value for a new root with added outputs + // No need to create an original value for a new root with added outputs // as the original value is saved in the get-tuple-element instructions // that use it. called_computation()->set_root_instruction(new_root, @@ -2172,14 +2206,20 @@ void HloCallableInstruction::RecursivelySetComputationsThreadName( HloFusionInstruction::HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, - HloInstruction* fused_root) + HloInstruction* fused_root, + absl::string_view prefix) : HloCallableInstruction(HloOpcode::kFusion, shape), fusion_kind_(fusion_kind) { CHECK(fused_root != nullptr); - SetAndSanitizeName(HloOpcodeString(opcode())); + SetAndSanitizeName(absl::StrCat(prefix, HloOpcodeString(opcode()))); + set_parent(fused_root->parent()); set_metadata(fused_root->metadata()); set_frontend_attributes(fused_root->frontend_attributes()); + // This simplifies some use cases for the original value that involve fusions. + if (auto original_value = fused_root->original_value()) { + set_original_value(original_value); + } CHECK(fused_root->IsFusible()) << fused_root->ToString(); CloneAndAppendInstructionIntoCalledComputation(fused_root); } @@ -2423,7 +2463,7 @@ void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput( auto cloned_instruction = parent()->AddInstruction(fused_instruction->CloneWithNewOperands( fused_instruction->shape(), new_operands, /*suffix=*/"clone")); - // Copy over original_value attribute to the clone of a fused instruction. + // Copy over the original value to the clone of a fused instruction. // This is necessary as the clone will be cloned again when the clone is // fused in FuseInstructionIntoMultiOutput(). This can be skipped if we // improve the code to only clone once as stated in the preceding comment. diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.h b/third_party/xla/xla/hlo/ir/hlo_instructions.h index c0f03248dbf772..505dead6969d99 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instructions.h +++ b/third_party/xla/xla/hlo/ir/hlo_instructions.h @@ -44,8 +44,8 @@ limitations under the License. #include "xla/service/hlo.pb.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/lib/gtl/iterator_range.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/gtl/iterator_range.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "tsl/platform/status.h" @@ -903,6 +903,36 @@ class HloAllToAllInstruction : public HloCollectiveInstruction { std::optional split_dimension_; }; +class HloRaggedAllToAllInstruction : public HloCollectiveInstruction { + public: + explicit HloRaggedAllToAllInstruction( + const Shape& shape, absl::Span operands, + const CollectiveDeviceList& device_list, + const std::optional& channel_id); + + ABSL_DEPRECATED("Use CollectiveDeviceList instead of list of ReplicaGroup.") + explicit HloRaggedAllToAllInstruction( + HloOpcode opcode, const Shape& shape, + absl::Span operands, + absl::Span replica_groups, + const std::optional& channel_id); + + static bool ClassOf(const HloInstruction* hlo) { + return hlo->opcode() == HloOpcode::kRaggedAllToAll; + } + + protected: + void PrintExtraAttributesImpl(AttributePrinter& printer, + const HloPrintOptions& options) const override; + HloInstructionProto ToProto() const override; + + private: + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; +}; + class HloCollectiveBroadcastInstruction : public HloCollectiveInstruction { public: explicit HloCollectiveBroadcastInstruction( @@ -989,7 +1019,8 @@ inline bool HloCollectiveInstruction::ClassOf(const HloInstruction* hlo) { return HloAllReduceInstructionBase::ClassOf(hlo) || HloCollectiveBroadcastInstruction::ClassOf(hlo) || HloAllGatherInstruction::ClassOf(hlo) || - HloAllToAllInstruction::ClassOf(hlo); + HloAllToAllInstruction::ClassOf(hlo) || + HloRaggedAllToAllInstruction::ClassOf(hlo); } inline bool HloChannelInstruction::ClassOf(const HloInstruction* hlo) { @@ -1439,7 +1470,8 @@ class HloCallableInstruction : public HloInstruction { class HloFusionInstruction : public HloCallableInstruction { public: explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, - HloInstruction* fused_root); + HloInstruction* fused_root, + absl::string_view prefix = ""); explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind, absl::Span operands, diff --git a/third_party/xla/xla/hlo/ir/hlo_module.cc b/third_party/xla/xla/hlo/ir/hlo_module.cc index 0e98622801e97f..7f406e81bcb85f 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module.cc +++ b/third_party/xla/xla/hlo/ir/hlo_module.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -56,9 +57,10 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/tsl/lib/gtl/map_util.h" #include "xla/util.h" +#include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/gtl/map_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/fingerprint.h" @@ -718,6 +720,10 @@ absl::StatusOr HloModule::CreateModuleConfigFromShape( module_config.set_auto_spmd_partitioning_mesh_ids(std::vector( execution_options->auto_spmd_partitioning_mesh_ids().begin(), execution_options->auto_spmd_partitioning_mesh_ids().end())); + module_config.set_exec_time_optimization_effort( + execution_options->exec_time_optimization_effort()); + module_config.set_memory_fitting_effort( + execution_options->memory_fitting_effort()); module_config.set_deduplicate_hlo(execution_options->deduplicate_hlo()); if (!execution_options->allow_spmd_sharding_propagation_to_parameters() .empty()) { @@ -1128,17 +1134,21 @@ std::unique_ptr HloModule::Clone( } absl::Status HloModule::RemoveUnusedComputations() { - std::string suffix = "tmp"; - auto module = std::make_unique( - absl::StrCat(name_, "-", suffix), config(), - std::make_unique(*comp_envs_)); - HloCloneContext context(module.get(), suffix); - entry_computation_->Clone(suffix, &context); - std::vector to_remove; - for (auto computation : computations()) { - auto found_computation = context.FindComputation(computation); - if (found_computation == nullptr) { - to_remove.push_back(computation); + absl::flat_hash_set to_remove(computations().begin(), + computations().end()); + std::stack agenda; + agenda.push(entry_computation_); + to_remove.erase(entry_computation_); + while (!agenda.empty()) { + HloComputation* computation = agenda.top(); + agenda.pop(); + for (HloInstruction* instruction : computation->instructions()) { + for (HloComputation* called_computation : + instruction->called_computations()) { + if (to_remove.erase(called_computation) > 0) { + agenda.push(called_computation); + } + } } } for (auto computation : to_remove) { diff --git a/third_party/xla/xla/hlo/ir/hlo_module.h b/third_party/xla/xla/hlo/ir/hlo_module.h index 6dea7d5234fc8b..36967922122e02 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module.h +++ b/third_party/xla/xla/hlo/ir/hlo_module.h @@ -17,6 +17,7 @@ limitations under the License. #define XLA_HLO_IR_HLO_MODULE_H_ #include +#include #include #include #include @@ -29,8 +30,10 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/hlo/ir/dynamic_parameter_binding.h" #include "xla/hlo/ir/hlo_clone_context.h" @@ -38,16 +41,20 @@ limitations under the License. #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module_metadata.h" -#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/ir/hlo_sharding.h" #include "xla/iterator_util.h" #include "xla/printer.h" #include "xla/service/compilation_environments.h" +#include "xla/service/computation_layout.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" #include "xla/service/name_uniquer.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/tsl/lib/gtl/iterator_range.h" #include "xla/xla.pb.h" -#include "tsl/lib/gtl/iterator_range.h" #include "tsl/platform/logging.h" namespace xla { @@ -58,6 +65,8 @@ using LayoutCanonicalizationCallback = // Helper class to maintain a copy-on-write storage of an object of the // specified type. Logically Variant. +// The class's purpose is to share (shared_ptr) underlying storage (when it's +// not changed) thus reducing memory footprint. template class CopyOnWrite { public: @@ -137,6 +146,9 @@ class HloModule { // - comp_envs must not be null. HloModule(const std::string& name, HloModuleConfig config, std::unique_ptr comp_envs); + + // You can share a config from other modules by passing + // HloModule::shared_config() HloModule(const std::string& name, std::variant, std::shared_ptr> @@ -287,7 +299,8 @@ class HloModule { // with respect to HloInstruction::Identical() method. template friend H AbslHashValue(H h, const HloModule& module) { - h = H::combine(std::move(h), module.entry_computation_layout()); + if (module.config().has_entry_computation_layout()) + h = H::combine(std::move(h), module.entry_computation_layout()); // Use MakeComputationSorted() instead of MakeComputationPostOrder() // because naming may affect the order of MakeComputationPostOrder() but not // MakeComputationSorted(). @@ -816,8 +829,7 @@ class HloModule { // Compilation environments (protos that carry command line flags and // environment variables). - std::unique_ptr comp_envs_ = - std::make_unique(); + std::unique_ptr comp_envs_; // Stack frame indexes flat representation. std::optional stack_frame_index_; diff --git a/third_party/xla/xla/hlo/ir/hlo_module_group.cc b/third_party/xla/xla/hlo/ir/hlo_module_group.cc index c5623e45aea588..6d57d496a4f748 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module_group.cc +++ b/third_party/xla/xla/hlo/ir/hlo_module_group.cc @@ -22,6 +22,15 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo.pb.h" +#include "xla/service/hlo_module_config.h" +#include "xla/status_macros.h" +#include "tsl/platform/statusor.h" + namespace xla { HloModuleGroup::HloModuleGroup(std::unique_ptr module) diff --git a/third_party/xla/xla/hlo/ir/hlo_module_group.h b/third_party/xla/xla/hlo/ir/hlo_module_group.h index 753e8bc61c62d3..68c2253b0e1878 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module_group.h +++ b/third_party/xla/xla/hlo/ir/hlo_module_group.h @@ -21,10 +21,12 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo.pb.h" +#include "xla/service/hlo_module_config.h" namespace xla { diff --git a/third_party/xla/xla/hlo/ir/hlo_module_metadata.cc b/third_party/xla/xla/hlo/ir/hlo_module_metadata.cc index 79cad36763f6ed..bf19f043396a84 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module_metadata.cc +++ b/third_party/xla/xla/hlo/ir/hlo_module_metadata.cc @@ -18,10 +18,14 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "absl/functional/function_ref.h" #include "absl/log/log.h" +#include "absl/status/status.h" +#include "xla/service/hlo.pb.h" #include "xla/util.h" #include "tsl/platform/env.h" #include "tsl/platform/protobuf.h" +#include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/hlo/ir/hlo_module_metadata.h b/third_party/xla/xla/hlo/ir/hlo_module_metadata.h index 7f115f042d139f..e6b4050db526fc 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module_metadata.h +++ b/third_party/xla/xla/hlo/ir/hlo_module_metadata.h @@ -21,6 +21,9 @@ limitations under the License. #include #include +#include "absl/functional/function_ref.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "xla/service/hlo.pb.h" #include "xla/status_macros.h" #include "xla/util.h" diff --git a/third_party/xla/xla/hlo/ir/hlo_module_test.cc b/third_party/xla/xla/hlo/ir/hlo_module_test.cc new file mode 100644 index 00000000000000..f0921685cbda89 --- /dev/null +++ b/third_party/xla/xla/hlo/ir/hlo_module_test.cc @@ -0,0 +1,68 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/ir/hlo_module.h" + +#include +#include + +#include +#include "absl/hash/hash.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/service/hlo_module_config.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace { + +TEST(HloModuleTest, AbslHashValue) { + HloModule module1("temp_module", HloModuleConfig()); + HloModule module2("temp_module3", HloModuleConfig()); + EXPECT_EQ(absl::HashOf(module1), absl::HashOf(module2)); + + std::string_view hlo = R"( + HloModule m1 + ENTRY main { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT res = f32[] multiply(a, b) + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module3, + ParseAndReturnUnverifiedModule(hlo)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module4, + ParseAndReturnUnverifiedModule(hlo)); + EXPECT_EQ(absl::HashOf(*module3), absl::HashOf(*module4)); + EXPECT_NE(absl::HashOf(module1), absl::HashOf(*module4)); +} + +TEST(HloModuleTest, MutableOwnedImmutableSharedConfig) { + HloModuleConfig config1; + config1.set_device_type("first"); + config1.set_device_memory_size(7); + HloModule m1("-", config1); + HloModule m2("-", m1.shared_config(), + std::make_unique()); + EXPECT_EQ(&m1.config(), &m2.config()) + << "Shared config referres to the same object."; + m1.mutable_config().set_device_type("second"); + EXPECT_NE(&m1.config(), &m2.config()) << "Config is copied on modification."; + EXPECT_EQ(m1.config().device_type(), "second"); + EXPECT_EQ(m2.config().device_type(), "first"); + EXPECT_EQ(m1.config().device_memory_size(), m2.config().device_memory_size()); +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/hlo/ir/hlo_op_metadata.cc b/third_party/xla/xla/hlo/ir/hlo_op_metadata.cc index 5370eed46e3765..30b1d2c3cfc6a6 100644 --- a/third_party/xla/xla/hlo/ir/hlo_op_metadata.cc +++ b/third_party/xla/xla/hlo/ir/hlo_op_metadata.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/hlo/ir/hlo_opcode.cc b/third_party/xla/xla/hlo/ir/hlo_opcode.cc index dcdf9c1933c829..cc12b39873a866 100644 --- a/third_party/xla/xla/hlo/ir/hlo_opcode.cc +++ b/third_party/xla/xla/hlo/ir/hlo_opcode.cc @@ -19,6 +19,8 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/util.h" namespace xla { diff --git a/third_party/xla/xla/hlo/ir/hlo_opcode.h b/third_party/xla/xla/hlo/ir/hlo_opcode.h index 8c1cc7c00e5dcb..e2036b1725feed 100644 --- a/third_party/xla/xla/hlo/ir/hlo_opcode.h +++ b/third_party/xla/xla/hlo/ir/hlo_opcode.h @@ -127,6 +127,7 @@ namespace xla { V(kPartitionId, "partition-id", 0) \ V(kPopulationCount, "popcnt", 1) \ V(kPower, "power", 2) \ + V(kRaggedAllToAll, "ragged-all-to-all", 6) \ V(kReal, "real", 1) \ V(kRecv, "recv", 1) \ V(kRecvDone, "recv-done", 1) \ diff --git a/third_party/xla/xla/hlo/ir/hlo_original_value.cc b/third_party/xla/xla/hlo/ir/hlo_original_value.cc index 789978d74cbf39..915069a8c2a3e5 100644 --- a/third_party/xla/xla/hlo/ir/hlo_original_value.cc +++ b/third_party/xla/xla/hlo/ir/hlo_original_value.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/xla_data.pb.h" namespace xla { @@ -65,4 +66,22 @@ std::string OriginalValueToString(const OriginalValue& original_value) { return OriginalValueToStringHelper(original_value, original_value.shape(), shape_index); } + +OriginalValueProto OriginalValueToProto(const OriginalValue& original_value) { + OriginalValueProto original_value_proto; + for (const auto& leaf : original_value.leaves()) { + OriginalArrayProto* original_array_proto = + original_value_proto.add_leaves(); + for (const auto& index : leaf.first) { + original_array_proto->add_leaf_shape_index(index); + } + *original_array_proto->mutable_instruction_name() = + leaf.second->instruction_name; + for (const auto& index : leaf.second->shape_index) { + original_array_proto->add_shape_index(index); + } + } + return original_value_proto; +} + } // namespace xla diff --git a/third_party/xla/xla/hlo/ir/hlo_original_value.h b/third_party/xla/xla/hlo/ir/hlo_original_value.h index a77bc8a13460c7..eca98ef3d0bc3e 100644 --- a/third_party/xla/xla/hlo/ir/hlo_original_value.h +++ b/third_party/xla/xla/hlo/ir/hlo_original_value.h @@ -21,6 +21,7 @@ limitations under the License. #include "xla/shape_tree.h" #include "xla/shape_util.h" +#include "xla/xla_data.pb.h" namespace xla { // Stores information of original values. @@ -32,6 +33,8 @@ struct OriginalArray { using OriginalValue = ShapeTree>; std::string OriginalValueToString(const OriginalValue& original_value); + +OriginalValueProto OriginalValueToProto(const OriginalValue& original_value); } // namespace xla #endif // XLA_HLO_IR_HLO_ORIGINAL_VALUE_H_ diff --git a/third_party/xla/xla/hlo/ir/hlo_reachability.h b/third_party/xla/xla/hlo/ir/hlo_reachability.h index 157991067ae9ad..30153bf07aadc8 100644 --- a/third_party/xla/xla/hlo/ir/hlo_reachability.h +++ b/third_party/xla/xla/hlo/ir/hlo_reachability.h @@ -16,206 +16,7 @@ limitations under the License. #ifndef XLA_HLO_IR_HLO_REACHABILITY_H_ #define XLA_HLO_IR_HLO_REACHABILITY_H_ -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/container/flat_hash_map.h" -#include "absl/types/span.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/types.h" - -namespace xla { - -// A class for representing reachability between HloInstructions. -// -// It has an adjacency matrix and it is up to the user of the class to set the -// adjacency matrix such that it represents reachability, i.e. such that it is -// transitive. That the graph be transitive is thus not an invariant of this -// class, but it is required for the name of the class and its methods to make -// sense. -class HloReachabilityMap { - public: - using Index = size_t; - - // Sets up a graph with no edges and where the nodes correspond to the given - // instructions. - explicit HloReachabilityMap( - absl::Span instructions); - - // Computes and returns the reachability between HLO instructions in the - // computation. The returned HloReachabilityMap is constructed such that - // HloReachabilityMap::IsReachable(a, b) returns true iff there exists a - // directed path (from producer to consumer) from 'a' to 'b'. Both data - // dependencies (operands) and control dependencies are considered for - // reachability. Trivially an instruction is reachable from itself. - static std::unique_ptr Build( - const HloComputation* computation); - - // Similar to the above Build operation except that it tries to identify - // paths between instructions that do not contain control instructions - // and multiple operands, i.e., b is_reachable a == true iff - // b = f(f(f(f(f(a), constant), constant), constant). - // Further, the only ops allowed in a path are basic math operations such - // as add, sub, mul, div. - static std::unique_ptr BuildWithRestrictions( - const HloComputation* computation, - absl::FunctionRef*)> - add_dependencies); - - // Set the reachability set of 'instruction' to the union of the reachability - // sets of 'inputs'. Upon return, IsReachable(x, instruction) where - // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true - // for some 'input' in 'inputs'. Also sets 'instruction' to be reachable from - // itself. Returns whether the reachability set of 'instruction' changed. - // - // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency - // vector in the internal graph of this HloReachabilityMap for the given - // instruction and does not transitively update any other part of the - // adjacency matrix. - bool SetReachabilityToUnion(absl::Span inputs, - const HloInstruction* instruction); - - // As above, but faster because it does not check if the reachability changed. - void FastSetReachabilityToUnion( - absl::Span inputs, - const HloInstruction* instruction); - // As above, but use Index instead if it's already looked up which is even - // faster since no hash map lookup will occur. - void FastSetReachabilityToUnion(absl::Span input_indices, - Index index); - - Index GetIndex(const HloInstruction* instruction) const { - return indices_.at(GetKey(instruction)); - } - - // Sets entry so that IsReachable(a, b) will return true - // - // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency - // matrix in the internal graph of this HloReachabilityMap to have an edge - // from a to b and does not transitively update any other part of the - // adjacency matrix. - void SetReachable(const HloInstruction* a, const HloInstruction* b) { - SetReachable(GetIndex(a), GetIndex(b)); - } - void SetReachable(Index a, Index b) { bit_sets_[b].Set(a); } - - // Updates the given reachability map after the immediate predecessor set - // (operands and control predecessors) of 'instruction' has changed. - void UpdateReachabilityThroughInstruction(const HloInstruction* instruction); - - // Returns true if "b" is reachable from "a" - // - // Note that this function only correctly answers queries about reachability - // if the set of edges that have been provided to this class are transitive. - bool IsReachable(const HloInstruction* a, const HloInstruction* b) const { - return IsReachable(GetIndex(a), GetIndex(b)); - } - bool IsReachable(Index a, Index b) const { return bit_sets_[b].Get(a); } - - // Returns true if "b" is reachable from "a" or "a" is reachable from "b" - // - // Note that this function only correctly answers queries about reachability - // if the set of edges that have been provided to this class are transitive. - bool IsConnected(const HloInstruction* a, const HloInstruction* b) const { - return IsConnected(GetIndex(a), GetIndex(b)); - } - bool IsConnected(Index a, Index b) const { - return IsReachable(a, b) || IsReachable(b, a); - } - - // Checks if an instruction is in the Reachability map. - bool IsPresent(const HloInstruction* instruction) const { - return indices_.contains(GetKey(instruction)); - } - - // Replace the instruction "original" with "replacement" in the reachability - // map. - void Replace(const HloInstruction* original, - const HloInstruction* replacement); - - private: - // A dynamically sized bit-set implementation specialized for this use case - // providing fast bitwise OR (not available in tsl::gtl::BitMap). - class BitSet { - public: - BitSet() = default; - explicit BitSet(size_t size) - : size_(size), vector_((size + kBits - 1) / kBits, 0) {} - - // Returns the bit at the given index. - bool Get(Index index) const { - DCHECK(index >= 0 && index < size_); - return vector_[index / kBits] & (1ull << (index % kBits)); - } - - // Sets the bit at the given index. - void Set(Index index) { - DCHECK(index >= 0 && index < size_); - vector_[index / kBits] |= 1ull << (index % kBits); - } - - // Sets this bit-set to union of this bit-set and `other`. - void operator|=(const BitSet& other) { - if (this == &other) return; - DCHECK(size_ == other.size_); - - // Ease the work of the auto-vectorizer. - const Word* a = vector_.data(); - const Word* b = other.vector_.data(); - Word* __restrict out = vector_.data(); - size_t num_words = vector_.size(); - for (size_t i = 0; i < num_words; ++i) { - out[i] = a[i] | b[i]; - } - } - - // Sets the bitvector to all zeros. - void SetToZero() { absl::c_fill(vector_, 0); } - - bool operator==(const BitSet& other) const { - return vector_ == other.vector_; - } - bool operator!=(const BitSet& other) const { return !(*this == other); } - - private: - using Word = uint64_t; - static constexpr size_t kBits = 64; - - size_t size_; // Number of bits in the set. - std::vector vector_; - }; - - friend class HloReachabilityMapBitSetBenchmark; - - using Key = std::pair; // module ID, instruction ID. - static Key GetKey(const HloInstruction* instruction) { - return {instruction->GetModule()->unique_id(), instruction->unique_id()}; - } - - // Helper for SetReachabilityToUnion/FastSetReachabilityToUnion. - void SetReachabilityToUnionHelper( - absl::Span inputs, Index index); - void SetReachabilityToUnionHelper(absl::Span input_indices, - Index index); - - // Map from instruction to index. The index is used for bit_set_ and the bits - // within a BitSet. - absl::flat_hash_map indices_; - - // Bit-sets holding the reachability to each instruction. The bit-set for - // instruction X includes ones for each instruction which X is reachable from. - std::vector bit_sets_; - - // A temporary used by SetReachabilityToUnion to avoid an allocation with each - // call to the method. - BitSet tmp_bit_set_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/analysis/hlo_reachability.h" #endif // XLA_HLO_IR_HLO_REACHABILITY_H_ diff --git a/third_party/xla/xla/hlo/ir/hlo_schedule.cc b/third_party/xla/xla/hlo/ir/hlo_schedule.cc index 18747b3ae50f9e..ddd2aaf4cffef5 100644 --- a/third_party/xla/xla/hlo/ir/hlo_schedule.cc +++ b/third_party/xla/xla/hlo/ir/hlo_schedule.cc @@ -26,15 +26,23 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/map_util.h" +#include "xla/service/hlo.pb.h" #include "xla/status_macros.h" +#include "xla/tsl/lib/gtl/map_util.h" #include "xla/util.h" -#include "tsl/lib/gtl/map_util.h" +#include "tsl/platform/errors.h" namespace xla { diff --git a/third_party/xla/xla/hlo/ir/hlo_schedule.h b/third_party/xla/xla/hlo/ir/hlo_schedule.h index 4deb36d4e89c32..37cbff34856a9a 100644 --- a/third_party/xla/xla/hlo/ir/hlo_schedule.h +++ b/third_party/xla/xla/hlo/ir/hlo_schedule.h @@ -23,11 +23,15 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/hlo.pb.h" namespace xla { @@ -86,6 +90,11 @@ class HloInstructionSequence { id_sequence_.insert(id_sequence_.begin() + index, instruction->unique_id()); } + bool contains(const HloInstruction* inst) const { + return absl::c_find(instruction_sequence_, inst) != + instruction_sequence_.end(); + } + // Clears the sequence of all instructions. void clear() { instruction_sequence_.clear(); diff --git a/third_party/xla/xla/hlo/ir/hlo_sharding.cc b/third_party/xla/xla/hlo/ir/hlo_sharding.cc index 71520e196c0718..a24fdb72f54c5f 100644 --- a/third_party/xla/xla/hlo/ir/hlo_sharding.cc +++ b/third_party/xla/xla/hlo/ir/hlo_sharding.cc @@ -29,19 +29,26 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/base/optimization.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/log/check.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "xla/array.h" #include "xla/hlo/ir/hlo_op_metadata.h" #include "xla/overflow_util.h" #include "xla/printer.h" +#include "xla/shape.h" +#include "xla/shape_tree.h" +#include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/protobuf.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/third_party/xla/xla/hlo/ir/hlo_sharding.h b/third_party/xla/xla/hlo/ir/hlo_sharding.h index 5a7c49e9265899..f8740ba4d223ff 100644 --- a/third_party/xla/xla/hlo/ir/hlo_sharding.h +++ b/third_party/xla/xla/hlo/ir/hlo_sharding.h @@ -34,7 +34,10 @@ limitations under the License. #include "absl/types/span.h" #include "xla/array.h" #include "xla/hlo/ir/tile_assignment.h" // IWYU pragma: export +#include "xla/printer.h" +#include "xla/shape.h" #include "xla/shape_tree.h" +#include "xla/shape_util.h" #include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/hlo/ir/hlo_sharding_metadata.cc b/third_party/xla/xla/hlo/ir/hlo_sharding_metadata.cc index 5bc48d8ee29e51..b578408ed9dd80 100644 --- a/third_party/xla/xla/hlo/ir/hlo_sharding_metadata.cc +++ b/third_party/xla/xla/hlo/ir/hlo_sharding_metadata.cc @@ -21,9 +21,21 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_domain_metadata.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_sharding.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/hlo/ir/hlo_sharding_metadata.h b/third_party/xla/xla/hlo/ir/hlo_sharding_metadata.h index 689a1d28694479..5d963e931d96b2 100644 --- a/third_party/xla/xla/hlo/ir/hlo_sharding_metadata.h +++ b/third_party/xla/xla/hlo/ir/hlo_sharding_metadata.h @@ -21,6 +21,10 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/hash/hash.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_domain_metadata.h" #include "xla/hlo/ir/hlo_instruction.h" diff --git a/third_party/xla/xla/hlo/ir/tile_assignment.cc b/third_party/xla/xla/hlo/ir/tile_assignment.cc index 7620ab685e65ad..c3518bbe0ed152 100644 --- a/third_party/xla/xla/hlo/ir/tile_assignment.cc +++ b/third_party/xla/xla/hlo/ir/tile_assignment.cc @@ -22,10 +22,13 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" +#include "absl/functional/function_ref.h" #include "absl/log/check.h" #include "absl/types/span.h" #include "xla/array.h" +#include "xla/printer.h" #include "xla/util.h" namespace xla { diff --git a/third_party/xla/xla/hlo/ir/tile_assignment.h b/third_party/xla/xla/hlo/ir/tile_assignment.h index 4d9caed9b0acb1..7adc5ab50b2d70 100644 --- a/third_party/xla/xla/hlo/ir/tile_assignment.h +++ b/third_party/xla/xla/hlo/ir/tile_assignment.h @@ -23,6 +23,9 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/functional/function_ref.h" +#include "absl/log/check.h" +#include "absl/status/status.h" #include "absl/types/span.h" #include "xla/array.h" #include "xla/printer.h" diff --git a/third_party/xla/xla/hlo/parser/BUILD b/third_party/xla/xla/hlo/parser/BUILD new file mode 100644 index 00000000000000..6bc86b609219bb --- /dev/null +++ b/third_party/xla/xla/hlo/parser/BUILD @@ -0,0 +1,115 @@ +# Description: +# XLA parser implementation. + +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") +load( + "//xla:xla.bzl", + "xla_cc_test", +) +load("//xla/tsl:tsl.bzl", "internal_visibility") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +cc_library( + name = "hlo_parser", + srcs = ["hlo_parser.cc"], + hdrs = ["hlo_parser.h"], + deps = [ + ":hlo_lexer", + "//xla:array", + "//xla:comparison_util", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_layout", + "//xla:shape_util", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/ir:tile_assignment", + "//xla/service:computation_layout", + "//xla/service:hlo_module_config", + "//xla/service:hlo_proto_cc", + "//xla/service:name_uniquer", + "//xla/service:shape_inference", + "//xla/tsl/lib/gtl:map_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + ], +) + +xla_cc_test( + name = "hlo_parser_test", + size = "small", + srcs = ["hlo_parser_test.cc"], + deps = [ + ":hlo_lexer", + ":hlo_parser", + "//xla:array", + "//xla:shape_util", + "//xla:window_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/service:hlo_module_config", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "hlo_lexer", + srcs = ["hlo_lexer.cc"], + hdrs = [ + "hlo_lexer.h", + ], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "@com_google_absl//absl/base", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:numbers", + "@local_tsl//tsl/platform:regexp", + ], +) diff --git a/third_party/xla/xla/service/hlo_lexer.cc b/third_party/xla/xla/hlo/parser/hlo_lexer.cc similarity index 98% rename from third_party/xla/xla/service/hlo_lexer.cc rename to third_party/xla/xla/hlo/parser/hlo_lexer.cc index 18ca3fcb775277..4c294b8567e327 100644 --- a/third_party/xla/xla/service/hlo_lexer.cc +++ b/third_party/xla/xla/hlo/parser/hlo_lexer.cc @@ -13,19 +13,24 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_lexer.h" +#include "xla/hlo/parser/hlo_lexer.h" +#include #include -#include #include #include #include #include "absl/base/casts.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/ascii.h" #include "absl/strings/escaping.h" +#include "absl/strings/match.h" #include "absl/strings/numbers.h" +#include "absl/strings/string_view.h" +#include "xla/primitive_util.h" #include "xla/util.h" #include "tsl/platform/numbers.h" diff --git a/third_party/xla/xla/hlo/parser/hlo_lexer.h b/third_party/xla/xla/hlo/parser/hlo_lexer.h new file mode 100644 index 00000000000000..f787392b39b37e --- /dev/null +++ b/third_party/xla/xla/hlo/parser/hlo_lexer.h @@ -0,0 +1,218 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_PARSER_HLO_LEXER_H_ +#define XLA_HLO_PARSER_HLO_LEXER_H_ + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/regexp.h" + +namespace xla { + +// Defines different kinds of tokens used by the HLO lexer. +// +// You shouldn't need to use this directly unless you're using HloLexer +// directly, and you probably don't need to do that. Use hlo_parser instead. +enum class TokKind { + // Markers + kEof, + kError, + + // Tokens with no info. + kEqual, // = + kComma, // , + kColon, // : + kAsterisk, // * + kQuestionMark, // ? + kOctothorp, // # + kPlus, // + + kTilde, // ~ + kLsquare, + kRsquare, // [ ] + kLbrace, + kRbrace, // { } + kLparen, + kRparen, // ( ) + kDots, // ... + + kArrow, // -> + kLeq, // <= + + // Keywords + kw_HloModule, + kw_ENTRY, + kw_ROOT, + kw_true, + kw_false, + kw_maximal, + kw_replicated, + kw_manual, + kw_last_tile_dim_replicate, + kw_shard_as, + kw_shard_like, + kw_unknown, + kw_inf, + + kNegInf, // -inf + + // Typed tokens. + kPrimitiveType, // F32, PRED, etc. + kName, // %foo + kAttributeName, // dimensions= + kDimLabels, // [0-9bf?]{2,}_[0-9io?]{2,}->[0-9bf?]{2,} + kDxD, // [0-9]+(x[0-9]+)+ + kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* + kSparsityDesc, // ([LR]\.[0-9]+@[0-9]+:[0-9]+_?)+ + kIdent, // other identifiers + kString, // "abcd\"\n" + kInt, // 42 + kDecimal, // 4.2 +}; + +std::string TokKindToString(TokKind kind); + +// Lexer for the HloModule::ToString() format text. +// +// This class is meant to be used by hlo_parser.cc. You shouldn't need to use +// it directly. +class HloLexer { + public: + explicit HloLexer(absl::string_view buf) : buf_(buf) { + current_ptr_ = buf_.data(); + } + + TokKind Lex() { return token_state_.current_kind = LexToken(); } + + TokKind GetKind() const { return token_state_.current_kind; } + std::string GetStrVal() const { + switch (GetKind()) { + case TokKind::kName: + case TokKind::kAttributeName: + case TokKind::kDimLabels: + case TokKind::kDxD: + case TokKind::kPad: + case TokKind::kSparsityDesc: + case TokKind::kString: + case TokKind::kIdent: + return token_state_.str_val; + default: + LOG(FATAL) << "This token does not have string value"; + } + } + int64_t GetInt64Val() const { + CHECK(GetKind() == TokKind::kInt) << TokKindToString(GetKind()); + return token_state_.int64_val; + } + double GetDecimalVal() const { + CHECK(GetKind() == TokKind::kDecimal); + return token_state_.decimal_val; + } + PrimitiveType GetPrimitiveTypeVal() const { + CHECK(GetKind() == TokKind::kPrimitiveType); + return token_state_.primitive_type_val; + } + + typedef const char* LocTy; + + // Returns the location of the current token. + LocTy GetLoc() const { return token_state_.token_start; } + + // Returns the line and column of a location in the buffer. + std::pair GetLineAndColumn(LocTy location) const; + + // Returns the whole line given the location. + absl::string_view GetLine(LocTy loc) const; + + // Looks ahead one token and returns it. Lexer state is unchanged. + TokKind LookAhead(); + + // Lexes a string delimited by matching curly braces. Curlies contained + // inside double quotes don't count. + // + // Requires that you've already lexed the open curly brace. + // + // The returned string value includes the outer curlies. + // + // Returns TokKind::kString on success. + TokKind LexJsonDict(); + + private: + // Returns the current character. If it's neither the end of input buffer nor + // an invalid character, moves the pointer forward. + int GetNextChar(); + + // Returns the current character. + int PeekCurrentChar() const; + + // Creates string_view with the given begin and end. Exits if the begin > end, + // or it's out of the range of the current buffer. + absl::string_view StringViewFromPointers(const char* begin, + const char* end) const; + + // Returns true if the given ptr is dereferenceable within the range of the + // current buffer. + bool CanDereference(const char* ptr) const; + + TokKind LexToken(); + + TokKind LexIdentifier(); + TokKind LexPercent(); + TokKind LexShape(); + TokKind LexConstant(); + TokKind LexNumberOrPattern(); + TokKind LexString(); + + std::optional LexNanPayload(absl::string_view& consumable); + + absl::string_view buf_; + const char* current_ptr_; + + // Information about the current token. + struct TokenState { + const char* token_start = nullptr; + TokKind current_kind; + std::string str_val; + int64_t int64_val; + double decimal_val; + PrimitiveType primitive_type_val; + }; + TokenState token_state_; + + struct LineNoCacheTy { + const char* last_query; + unsigned line_no_of_query; + }; + // This caches the line number of the previous query. + mutable LineNoCacheTy line_no_cache_{nullptr, 0}; +}; + +// Does this string start with "{", end with "}", and contain valid-ish JSON +// in-between? If so, hlo_parser can parse e.g. backend_config={blah: "blah"} +// instead of the much uglier backend_config="{blah: \"blah\"}". +// +// (Technically we're not checking for fully-valid JSON, just something we can +// find the end of reasonably.) +bool LexesAsJsonDict(absl::string_view str); + +} // namespace xla + +#endif // XLA_HLO_PARSER_HLO_LEXER_H_ diff --git a/third_party/xla/xla/service/hlo_parser.cc b/third_party/xla/xla/hlo/parser/hlo_parser.cc similarity index 99% rename from third_party/xla/xla/service/hlo_parser.cc rename to third_party/xla/xla/hlo/parser/hlo_parser.cc index 9ec2d0f1510301..21caec4653eb47 100644 --- a/third_party/xla/xla/service/hlo_parser.cc +++ b/third_party/xla/xla/hlo/parser/hlo_parser.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" #include #include @@ -60,6 +60,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/hlo_sharding_metadata.h" #include "xla/hlo/ir/tile_assignment.h" +#include "xla/hlo/parser/hlo_lexer.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/literal.h" @@ -67,17 +68,16 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/service/computation_layout.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_lexer.h" #include "xla/service/hlo_module_config.h" #include "xla/service/name_uniquer.h" #include "xla/service/shape_inference.h" #include "xla/shape.h" #include "xla/shape_layout.h" #include "xla/shape_util.h" +#include "xla/tsl/lib/gtl/map_util.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/gtl/map_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" @@ -215,6 +215,7 @@ bool CanInferShape(HloOpcode code) { case HloOpcode::kDynamicReshape: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kRaggedAllToAll: case HloOpcode::kRecv: case HloOpcode::kRecvDone: case HloOpcode::kReduceScatter: @@ -249,10 +250,8 @@ class HloParserImpl : public HloParser { using BoolList = absl::InlinedVector; explicit HloParserImpl(absl::string_view str, - bool set_to_default_entry_computation_layout = true) - : lexer_(str), - set_to_default_entry_computation_layout_( - set_to_default_entry_computation_layout) {} + const HloParserOptions& options = HloParserOptions()) + : lexer_(str), options_(options) {} // Runs the parser and constructs the resulting HLO in the given (empty) // HloModule. Returns the error status in case an error occurred. @@ -549,7 +548,7 @@ class HloParserImpl : public HloParser { bool ParseJsonDict(std::string* result); bool ParseDimensionSizes(std::vector* dimension_sizes, std::vector* dynamic_dimensions); - bool ParseShape(Shape* result, bool set_to_default_layout = true); + bool ParseShape(Shape* result); bool ParseLayout(Layout* layout); bool ParseLayoutIntAttribute(int64_t* attr_value, absl::string_view attr_description); @@ -673,7 +672,7 @@ class HloParserImpl : public HloParser { // Used to generate names for anonymous instructions. NameUniquer name_uniquer_{/*separator=*/"."}; - const bool set_to_default_entry_computation_layout_; + const HloParserOptions options_; }; bool SplitToInt64s(absl::string_view s, char delim, std::vector* out) { @@ -917,7 +916,7 @@ bool HloParserImpl::ParseComputationLayout( } while (lexer_.GetKind() != TokKind::kRparen) { Shape param; - if (!ParseShape(¶m, set_to_default_entry_computation_layout_)) { + if (!ParseShape(¶m)) { return false; } computation_layout->add_parameter_layout(ShapeLayout(param)); @@ -937,7 +936,7 @@ bool HloParserImpl::ParseComputationLayout( return false; } Shape result; - if (!ParseShape(&result, set_to_default_entry_computation_layout_)) { + if (!ParseShape(&result)) { return false; } *computation_layout->mutable_result_layout() = ShapeLayout(result); @@ -1392,8 +1391,8 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, &predecessors}; optional> original_value; - attrs["original_value"] = {/*required=*/false, AttrTy::kOriginalValue, - &original_value}; + attrs["origin"] = {/*required=*/false, AttrTy::kOriginalValue, + &original_value}; optional metadata; attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata}; @@ -1787,6 +1786,23 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT constrain_layout ? *constrain_layout : false, channel_id, split_dimension)); } + case HloOpcode::kRaggedAllToAll: { + CollectiveDeviceList device_list; + attrs["replica_groups"] = {/*required=*/false, + AttrTy::kCollectiveDeviceList, &device_list}; + optional channel_id; + attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id}; + optional> dimensions; + attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List, + &dimensions}; + if ((!preset_operands && !ParseOperands(&operands, builder)) || + !ParseAttributes(attrs, allow_attributes, shape) || + (dimensions && dimensions->size() != 1)) { + return nullptr; + } + return builder->AddInstruction(HloInstruction::CreateRaggedAllToAll( + *shape, operands, device_list, channel_id)); + } case HloOpcode::kCollectiveBroadcast: { CollectiveDeviceList device_list; attrs["replica_groups"] = {/*required=*/true, @@ -6099,7 +6115,7 @@ bool HloParserImpl::ParseLayout(Layout* layout) { // tuple_elements // ::= /*empty*/ // ::= shape (',' shape)* -bool HloParserImpl::ParseShape(Shape* result, bool set_to_default_layout) { +bool HloParserImpl::ParseShape(Shape* result) { if (EatIfPresent(TokKind::kLparen)) { // Tuple std::vector shapes; if (lexer_.GetKind() == TokKind::kRparen) { @@ -6108,7 +6124,7 @@ bool HloParserImpl::ParseShape(Shape* result, bool set_to_default_layout) { // shape (',' shape)* do { shapes.emplace_back(); - if (!ParseShape(&shapes.back(), set_to_default_layout)) { + if (!ParseShape(&shapes.back())) { return false; } } while (EatIfPresent(TokKind::kComma)); @@ -6134,7 +6150,7 @@ bool HloParserImpl::ParseShape(Shape* result, bool set_to_default_layout) { result->add_dimensions(dimension_sizes[i]); result->set_dynamic_dimension(i, dynamic_dimensions[i]); } - if (set_to_default_layout || ShapeUtil::IsScalar(*result)) { + if (options_.fill_missing_layouts() || ShapeUtil::IsScalar(*result)) { LayoutUtil::SetToDefaultLayout(result); } // We need to lookahead to see if a following open brace is the start of a @@ -6990,19 +7006,13 @@ bool HloParserImpl::ParseSingleInstruction(HloModule* module) { absl::StatusOr> ParseAndReturnUnverifiedModule( absl::string_view str, const HloModuleConfig& config, - bool set_to_default_entry_computation_layout) { + const HloParserOptions& options) { auto module = std::make_unique(/*name=*/"_", config); - HloParserImpl parser(str, set_to_default_entry_computation_layout); + HloParserImpl parser(str, options); TF_RETURN_IF_ERROR(parser.Run(module.get())); return std::move(module); } -absl::StatusOr> ParseAndReturnUnverifiedModule( - absl::string_view str, bool set_to_default_entry_computation_layout) { - return ParseAndReturnUnverifiedModule( - str, HloModuleConfig(), set_to_default_entry_computation_layout); -} - absl::StatusOr ParseSharding(absl::string_view str) { HloParserImpl parser(str); return parser.ParseShardingOnly(); diff --git a/third_party/xla/xla/hlo/parser/hlo_parser.h b/third_party/xla/xla/hlo/parser/hlo_parser.h new file mode 100644 index 00000000000000..302bc829f9bd92 --- /dev/null +++ b/third_party/xla/xla/hlo/parser/hlo_parser.h @@ -0,0 +1,116 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_PARSER_HLO_PARSER_H_ +#define XLA_HLO_PARSER_HLO_PARSER_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/parser/hlo_lexer.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +class HloParserOptions { + public: + // When a shape layout is not set (e.g. in the entry computation layout or + // instruction layout), set the layout to be the default (e.g. {3,2,1,0}). + HloParserOptions& set_fill_missing_layouts(bool value) { + fill_missing_layouts_ = value; + return *this; + } + + bool fill_missing_layouts() const { return fill_missing_layouts_; } + + private: + bool fill_missing_layouts_ = true; +}; + +// Given a string in the HloModule::ToString() format, parses the string and +// creates a HloModule with the given config. +// Note: Tests derived from HloHardwareIndependentTestBase should use +// ParseAndReturnVerifiedModule() instead! +absl::StatusOr> ParseAndReturnUnverifiedModule( + absl::string_view str, const HloModuleConfig& config = HloModuleConfig(), + const HloParserOptions& options = HloParserOptions()); + +// Parses sharding from str. str is supposed to contain the body of the +// sharding, i.e. just the rhs of the "sharding={...}" attribute string, e.g., +// "{replicated}". +absl::StatusOr ParseSharding(absl::string_view str); + +// Parses frontend attributes from str. str is supposed to contain the body of +// the frontend attributes , i.e. just the rhs of the +// "frontend_attributes={...}" attribute string, e.g., +// "{attr_a=a,attr_b=b}". +absl::StatusOr ParseFrontendAttributes( + absl::string_view str); + +// Parses statistics viz from str. str is supposed to contain the body of the +// statistics visualization, i.e. just the rhs of the "statistics={...}" +// attribute string, e.g., "{visualizing_index=1,nan_percent=50}". +absl::StatusOr ParseStatisticsViz(absl::string_view str); + +// Parses parameter replication from str. str is supposed to contain the body of +// the parameter replication, i.e. just the rhs of the +// "parameter_replication={...}" attribute string, e.g., "{true, false}". +absl::StatusOr> ParseParameterReplication( + absl::string_view str); + +// Parses the result of window_util::ToString(const Window&). +absl::StatusOr ParseWindow(absl::string_view str); + +// Parses the result of ConvolutionDimensionNumbersToString(), e.g. +// "b0f_0io->b0f". +absl::StatusOr ParseConvolutionDimensionNumbers( + absl::string_view str); + +// Parses the result of PaddingConfigToString(), e.g. "0_0x1_1". +absl::StatusOr ParsePaddingConfig(absl::string_view str); + +// Parses and returns a Shape::ToString-format string. +absl::StatusOr ParseShape(absl::string_view str); + +// Parses and returns a Layout::ToString-format string. +absl::StatusOr ParseLayout(absl::string_view str); + +// Parses and returns a std::vector from str. str is supposed to +// contain a list of the replica groups, i.e. just the rhs of the +// "replica_groups={...}" attribute string, e.g., "{{0,1}, {2,3}}". +absl::StatusOr> ParseReplicaGroupsOnly( + absl::string_view str); + +class HloParser { + public: + // Runs the parser and constructs the resulting HLO in the given (empty) + // HloModule. Returns the error status in case an error occurred. + virtual absl::Status Run(HloModule* module) = 0; + virtual ~HloParser() {} + + private: + static std::unique_ptr CreateHloParserForTests( + absl::string_view str); + friend class VerifiedHloModule; +}; + +} // namespace xla + +#endif // XLA_HLO_PARSER_HLO_PARSER_H_ diff --git a/third_party/xla/xla/service/hlo_parser_test.cc b/third_party/xla/xla/hlo/parser/hlo_parser_test.cc similarity index 96% rename from third_party/xla/xla/service/hlo_parser_test.cc rename to third_party/xla/xla/hlo/parser/hlo_parser_test.cc index d6783d516807fb..34311e3c95e766 100644 --- a/third_party/xla/xla/service/hlo_parser_test.cc +++ b/third_party/xla/xla/hlo/parser/hlo_parser_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" #include #include @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include #include #include "absl/log/log.h" #include "absl/status/status.h" @@ -38,15 +39,15 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/parser/hlo_lexer.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/layout.h" #include "xla/layout_util.h" -#include "xla/service/hlo_lexer.h" #include "xla/service/hlo_module_config.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" @@ -1530,11 +1531,11 @@ ENTRY %test (p: f32[100]) -> u32[100] { R"(HloModule test, entry_computation_layout={(f32[], f32[3]{0}, f32[2,3]{1,0})->((f32[], f32[3]{0}), f32[2,3]{1,0})} ENTRY %test (v1: f32[], v2: f32[3], v3: f32[2,3]) -> ((f32[], f32[3]), f32[2,3]) { - %v1 = f32[] parameter(0), original_value={{"v1"}} - %v2 = f32[3]{0} parameter(1), original_value={{"v2"}} - %tuple = (f32[], f32[3]{0}) tuple(f32[] %v1, f32[3]{0} %v2), original_value={({"v1"}, {"v2"})} - %v3 = f32[2,3]{1,0} parameter(2), original_value={{"v3"}} - ROOT %nested_tuple = ((f32[], f32[3]{0}), f32[2,3]{1,0}) tuple((f32[], f32[3]{0}) %tuple, f32[2,3]{1,0} %v3), original_value={(({"v1"}, {"v2"}), {"v3"})} + %v1 = f32[] parameter(0), origin={{"v1"}} + %v2 = f32[3]{0} parameter(1), origin={{"v2"}} + %tuple = (f32[], f32[3]{0}) tuple(f32[] %v1, f32[3]{0} %v2), origin={({"v1"}, {"v2"})} + %v3 = f32[2,3]{1,0} parameter(2), origin={{"v3"}} + ROOT %nested_tuple = ((f32[], f32[3]{0}), f32[2,3]{1,0}) tuple((f32[], f32[3]{0}) %tuple, f32[2,3]{1,0} %v3), origin={(({"v1"}, {"v2"}), {"v3"})} } )" @@ -2183,6 +2184,59 @@ ENTRY AllToAllWithSubgroupsIotaList { )", /*replica_count=*/40 }, +// ragged-all-to-all +{ +"RaggedAllToAllWithReplicaGroups", +R"(HloModule RaggedAllToAll, entry_computation_layout={(bf16[1024,256]{1,0}, bf16[1024,256]{1,0}, s32[8]{0}, s32[8]{0}, s32[8]{0}, /*index=5*/s32[8]{0})->bf16[1024,256]{1,0}}, replica_count=8 + +ENTRY AllToAll { + input = bf16[1024,256]{1,0} parameter(0) + output = bf16[1024,256]{1,0} parameter(1) + input_offsets = s32[8]{0} parameter(2) + input_sizes = s32[8]{0} parameter(3) + output_offsets = s32[8]{0} parameter(4) + output_sizes = s32[8]{0} parameter(5) + ROOT ra2a = bf16[1024,256]{1,0} ragged-all-to-all(input, output, input_offsets, input_sizes, output_offsets, output_sizes), replica_groups={{0,1,2,3,4,5,6,7}} +} + +)", +/*replica_count=*/8 +}, +// ragged-all-to-all +{ +"RaggedAllToAllWithCollectiveDeviceList", +R"(HloModule RaggedAllToAll, entry_computation_layout={(bf16[1024,256]{1,0}, bf16[1024,256]{1,0}, s32[8]{0}, s32[8]{0}, s32[8]{0}, /*index=5*/s32[8]{0})->bf16[1024,256]{1,0}}, replica_count=8 + +ENTRY AllToAll { + input = bf16[1024,256]{1,0} parameter(0) + output = bf16[1024,256]{1,0} parameter(1) + input_offsets = s32[8]{0} parameter(2) + input_sizes = s32[8]{0} parameter(3) + output_offsets = s32[8]{0} parameter(4) + output_sizes = s32[8]{0} parameter(5) + ROOT ra2a = bf16[1024,256]{1,0} ragged-all-to-all(input, output, input_offsets, input_sizes, output_offsets, output_sizes), replica_groups=[2,4]<=[4,2]T(1,0) +} + +)", +/*replica_count=*/8 +}, +// ragged-all-to-all +{ +"RaggedAllToAll", +R"(HloModule RaggedAllToAll, entry_computation_layout={(bf16[1024,256]{1,0}, bf16[1024,256]{1,0}, s32[8]{0}, s32[8]{0}, s32[8]{0}, /*index=5*/s32[8]{0})->bf16[1024,256]{1,0}}, replica_count=8 + +ENTRY AllToAll { + input = bf16[1024,256]{1,0} parameter(0) + output = bf16[1024,256]{1,0} parameter(1) + input_offsets = s32[8]{0} parameter(2) + input_sizes = s32[8]{0} parameter(3) + output_offsets = s32[8]{0} parameter(4) + output_sizes = s32[8]{0} parameter(5) + ROOT ra2a = bf16[1024,256]{1,0} ragged-all-to-all(input, output, input_offsets, input_sizes, output_offsets, output_sizes), replica_groups={} +} + +)" +}, // collective-broadcast { "CollectiveBroadcast", @@ -3434,12 +3488,135 @@ ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { absl::StatusOr> module = ParseAndReturnUnverifiedModule( - original, /*set_to_default_entry_computation_layout=*/false); + original, {}, HloParserOptions().set_fill_missing_layouts(false)); TF_ASSERT_OK(module.status()); // Do not set the default layout. EXPECT_FALSE(module.value()->entry_computation_layout().AnyLayoutSet()); } +TEST_F(HloParserTest, DoNotSetEntryComputationLayoutIfSet) { + const std::string original = R"( +HloModule layout_defined, entry_computation_layout={(f32[8,16,256]{1,2,0}) -> f32[8,16]} + +add_F32.v3 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { + input = f32[8,16,256]{0,1,2} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 +})"; + + absl::StatusOr> module = + ParseAndReturnUnverifiedModule( + original, {}, HloParserOptions().set_fill_missing_layouts(true)); + TF_ASSERT_OK(module.status()); + EXPECT_THAT(module.value() + ->entry_computation_layout() + .parameter_layout(0) + .layout() + .minor_to_major(), + ElementsAre(1, 2, 0)); +} + +TEST_F(HloParserTest, SetEntryComputationLayoutIfNotSet) { + const std::string original = R"( +HloModule layout_defined, entry_computation_layout={(f32[8,16,256]) -> f32[8,16]} + +add_F32.v3 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { + input = f32[8,16,256]{0,1,2} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 +})"; + + absl::StatusOr> module = + ParseAndReturnUnverifiedModule( + original, {}, HloParserOptions().set_fill_missing_layouts(true)); + TF_ASSERT_OK(module.status()); + EXPECT_THAT(module.value() + ->entry_computation_layout() + .parameter_layout(0) + .layout() + .minor_to_major(), + ElementsAre(2, 1, 0)); +} + +TEST_F(HloParserTest, DoNotFallBackToDefaultLayoutIfDisabled) { + const std::string original = R"( +HloModule t + +ENTRY main { + p0 = f16[16,32,48,64]{3,2,1,0} parameter(0) + p1 = f16[80,64,48,32]{3,2,1,0} parameter(1) + ROOT dot = f16[64,32,16,80] dot(p0, p1), lhs_contracting_dims={2}, rhs_contracting_dims={2}, lhs_batch_dims={3,1}, rhs_batch_dims={1,3} +})"; + + absl::StatusOr> module = + ParseAndReturnUnverifiedModule( + original, {}, HloParserOptions().set_fill_missing_layouts(false)); + TF_ASSERT_OK(module.status()); + EXPECT_FALSE(module.value() + ->entry_computation() + ->root_instruction() + ->shape() + .has_layout()); +} + +TEST_F(HloParserTest, FallBackToDefaultLayoutIfEnabled) { + const std::string original = R"( +HloModule t + +ENTRY main { + p0 = f16[16,32,48,64]{3,2,1,0} parameter(0) + p1 = f16[80,64,48,32]{3,2,1,0} parameter(1) + ROOT dot = f16[64,32,16,80] dot(p0, p1), lhs_contracting_dims={2}, rhs_contracting_dims={2}, lhs_batch_dims={3,1}, rhs_batch_dims={1,3} +})"; + + absl::StatusOr> module = + ParseAndReturnUnverifiedModule( + original, {}, HloParserOptions().set_fill_missing_layouts(true)); + TF_ASSERT_OK(module.status()); + EXPECT_THAT(module.value() + ->entry_computation() + ->root_instruction() + ->shape() + .layout() + .minor_to_major(), + ElementsAre(3, 2, 1, 0)); +} + +TEST_F(HloParserTest, FallBackToDefaultLayoutIfAlreadySet) { + const std::string original = R"( +HloModule t + +ENTRY main { + p0 = f16[16,32,48,64]{3,2,1,0} parameter(0) + p1 = f16[80,64,48,32]{3,2,1,0} parameter(1) + ROOT dot = f16[64,32,16,80]{1,2,0,3} dot(p0, p1), lhs_contracting_dims={2}, rhs_contracting_dims={2}, lhs_batch_dims={3,1}, rhs_batch_dims={1,3} +})"; + + absl::StatusOr> module = + ParseAndReturnUnverifiedModule( + original, {}, HloParserOptions().set_fill_missing_layouts(true)); + TF_ASSERT_OK(module.status()); + EXPECT_THAT(module.value() + ->entry_computation() + ->root_instruction() + ->shape() + .layout() + .minor_to_major(), + ElementsAre(1, 2, 0, 3)); +} + TEST_F(HloParserTest, NoEntry) { const std::string original = R"(HloModule no_entry: c1 { @@ -5537,8 +5714,8 @@ TEST_F(HloParserTest, OriginalValueWithoutShape) { const std::string hlo_string = R"(HloModule test ENTRY %test { - %a = f32[2,10]{1,0} parameter(0), original_value={{"a"}} - ROOT %v = abs(%a), original_value={{"v"}} + %a = f32[2,10]{1,0} parameter(0), origin={{"a"}} + ROOT %v = abs(%a), origin={{"v"}} } diff --git a/third_party/xla/xla/hlo/pass/BUILD b/third_party/xla/xla/hlo/pass/BUILD index 81905a744d8bb5..6fb06db5a3443b 100644 --- a/third_party/xla/xla/hlo/pass/BUILD +++ b/third_party/xla/xla/hlo/pass/BUILD @@ -31,11 +31,14 @@ cc_library( deps = [ "//xla:status_macros", "//xla:types", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:statusor", ], ) @@ -80,9 +83,8 @@ xla_cc_test( "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", - "//xla/service:hlo_parser", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", @@ -91,5 +93,6 @@ xla_cc_test( "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", ], ) diff --git a/third_party/xla/xla/hlo/pass/hlo_pass_interface.h b/third_party/xla/xla/hlo/pass/hlo_pass_interface.h index d09f65e271451e..f9a52d2fa1aa1b 100644 --- a/third_party/xla/xla/hlo/pass/hlo_pass_interface.h +++ b/third_party/xla/xla/hlo/pass/hlo_pass_interface.h @@ -17,12 +17,16 @@ limitations under the License. #define XLA_HLO_PASS_HLO_PASS_INTERFACE_H_ #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" #include "xla/status_macros.h" #include "xla/types.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -66,23 +70,11 @@ class HloPassInterface { // Run the pass on the given HLO module with specified execution_threads. // Empty execution_threads list means all execution_threads are included. - // Returns whether it modified the module. Note that due to C++ inheritance - // hides overloaded function, Run(HloModule* module) is not a member function - // of a subclass unless it's explicitly brought to the subclass besides - // implementing the virtual version, for instance, - // - // class MyNewPass : public HloModulePass { - // public: - // MyNewPass(); - // absl::string_view name() const override { return "my-new-pass"; } - // - // using HloPassInterface::Run; - // absl::StatusOr Run( - // HloModule* module, - // const absl::flat_hash_set& execution_threads) - // override; - // }; + // Returns whether it modified the module. // + // Note: C++ hides non-explicitly declared overloaded functions. + // You can make all overloaded variants available in the child class by + // adding `using HloPassInterface::Run;` to the child class declaration. absl::StatusOr Run(HloModule* module) { return Run(module, /*execution_threads=*/{}); } @@ -115,23 +107,8 @@ class HloPassInterface { // Ideally, the module group variant would be named "Run" as well, but C++ // does not handle overloaded virtual methods well. // - // Note that due to C++ inheritance hides overloaded function, - // RunOnModuleGroup(HloModuleGroup* module_group) is not a member function of - // a subclass unless it's explicitly brought to the subclass besides - // implementing the virtual version, for instance, - // - // class MyNewPass : public HloModuleGroupPass { - // public: - // MyNewPass(); - // absl::string_view name() const override { return "my-new-pass"; } - // - // using HloPassInterface::RunOnModuleGroup; - // absl::StatusOr RunOnModuleGroup( - // HloModuleGroup* module_group, - // const absl::flat_hash_set& execution_threads) - // override; - // }; - // + // See the caveat about C++ hiding overloaded functions in the Run function + // above. absl::StatusOr RunOnModuleGroup(HloModuleGroup* module_group) { return RunOnModuleGroup(module_group, /*execution_threads=*/{}); } @@ -139,7 +116,7 @@ class HloPassInterface { HloModuleGroup* module_group, const absl::flat_hash_set& execution_threads) = 0; - virtual bool IsPassPipeline() { return false; } + virtual bool IsPassPipeline() const { return false; } }; // Base class for passes which are module-scoped. diff --git a/third_party/xla/xla/hlo/pass/hlo_pass_pipeline.h b/third_party/xla/xla/hlo/pass/hlo_pass_pipeline.h index 9510787d0294c9..37a131b899c67a 100644 --- a/third_party/xla/xla/hlo/pass/hlo_pass_pipeline.h +++ b/third_party/xla/xla/hlo/pass/hlo_pass_pipeline.h @@ -87,7 +87,7 @@ class HloPassPipeline : public HloPassInterface { HloModuleGroup* module_group, const absl::flat_hash_set& execution_threads) override; - bool IsPassPipeline() override { return true; } + bool IsPassPipeline() const override { return true; } // Return size of passes_. int PassesSize() { return passes_.size(); } diff --git a/third_party/xla/xla/hlo/pass/hlo_pass_pipeline_test.cc b/third_party/xla/xla/hlo/pass/hlo_pass_pipeline_test.cc index 5ef86e33ef9461..97ce61299bb68d 100644 --- a/third_party/xla/xla/hlo/pass/hlo_pass_pipeline_test.cc +++ b/third_party/xla/xla/hlo/pass/hlo_pass_pipeline_test.cc @@ -32,10 +32,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/service/hlo_parser.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/test_helpers.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "tsl/platform/statusor.h" @@ -47,7 +47,7 @@ using ::testing::ElementsAre; using ::testing::SizeIs; using ::testing::StrEq; -class HloPassPipelineTest : public HloTestBase { +class HloPassPipelineTest : public HloHardwareIndependentTestBase { protected: absl::StatusOr ParseModuleGroup( absl::Span hlo_strings) { diff --git a/third_party/xla/xla/hlo/testlib/BUILD b/third_party/xla/xla/hlo/testlib/BUILD new file mode 100644 index 00000000000000..0bc189042cdabb --- /dev/null +++ b/third_party/xla/xla/hlo/testlib/BUILD @@ -0,0 +1,114 @@ +# Description: +# Base testing infrastructure for XLA. + +load( + "@local_config_rocm//rocm:build_defs.bzl", + "if_rocm_is_configured", +) +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") +load( + "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", +) +load("//xla:package_groups.bzl", "xla_tests_package_groups") +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.default.bzl", "filegroup") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([":friends"]), + licenses = ["notice"], +) + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +xla_tests_package_groups() + +cc_library( + name = "verified_hlo_module", + testonly = True, + srcs = ["verified_hlo_module.cc"], + hdrs = ["verified_hlo_module.h"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/service:hlo_module_config", + "//xla/service:hlo_verifier", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "hlo_hardware_independent_test_base", + testonly = True, + srcs = ["hlo_hardware_independent_test_base.cc"], + hdrs = ["hlo_hardware_independent_test_base.h"], + deps = [ + ":filecheck", + ":verified_hlo_module", + "//xla:debug_options_flags", + "//xla:shape_layout", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_module_group", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", + "//xla/service:computation_layout", + "//xla/service:hlo_module_config", + "//xla/service:hlo_verifier", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "filecheck", + testonly = True, + srcs = ["filecheck.cc"], + hdrs = ["filecheck.h"], + data = [ + "@llvm-project//llvm:FileCheck", + ], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), + deps = [ + "//xla:types", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:resource_loader", + "@local_tsl//tsl/platform:subprocess", + ], +) diff --git a/third_party/xla/xla/tests/filecheck.cc b/third_party/xla/xla/hlo/testlib/filecheck.cc similarity index 99% rename from third_party/xla/xla/tests/filecheck.cc rename to third_party/xla/xla/hlo/testlib/filecheck.cc index 5ca6138d9f234c..78f0f0a90cc834 100644 --- a/third_party/xla/xla/tests/filecheck.cc +++ b/third_party/xla/xla/hlo/testlib/filecheck.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/tests/filecheck.h" +#include "xla/hlo/testlib/filecheck.h" #include diff --git a/third_party/xla/xla/hlo/testlib/filecheck.h b/third_party/xla/xla/hlo/testlib/filecheck.h new file mode 100644 index 00000000000000..3ea8de22f60fe8 --- /dev/null +++ b/third_party/xla/xla/hlo/testlib/filecheck.h @@ -0,0 +1,41 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TESTLIB_FILECHECK_H_ +#define XLA_HLO_TESTLIB_FILECHECK_H_ + +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/types.h" + +namespace xla { + +// Runs FileCheck with the given pattern over given input string. Provided that +// FileCheck can execute, returns true if and only if FileCheck succeeded in +// matching the input. +absl::StatusOr RunFileCheck(const std::string& input, + absl::string_view pattern); + +// Runs FileCheck with the given pattern file over given input string. Provided +// that FileCheck can execute, returns true if and only if FileCheck succeeded +// in matching the input. +absl::StatusOr RunFileCheckWithPatternFile( + const std::string& input, const std::string& pattern_file); + +} // namespace xla + +#endif // XLA_HLO_TESTLIB_FILECHECK_H_ diff --git a/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.cc b/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.cc new file mode 100644 index 00000000000000..bbe1ecea736a3e --- /dev/null +++ b/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.cc @@ -0,0 +1,351 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/debug_options_flags.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module_group.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/hlo/testlib/filecheck.h" +#include "xla/hlo/testlib/verified_hlo_module.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/hlo_verifier.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { + +HloHardwareIndependentTestBase::HloHardwareIndependentTestBase( + bool verifier_layout_sensitive, bool allow_mixed_precision_in_hlo_verifier, + HloPredicate instruction_can_change_layout_func) + : verifier_layout_sensitive_(verifier_layout_sensitive), + allow_mixed_precision_in_hlo_verifier_( + allow_mixed_precision_in_hlo_verifier), + instruction_can_change_layout_func_(instruction_can_change_layout_func) { + hlo_verifier_ = std::make_unique( + /*layout_sensitive=*/verifier_layout_sensitive, + /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier, + instruction_can_change_layout_func); +} + +std::unique_ptr +HloHardwareIndependentTestBase::CreateNewUnverifiedModule( + const std::string& name) const { + return std::make_unique(name, GetModuleConfigForTest()); +} + +std::unique_ptr +HloHardwareIndependentTestBase::CreateNewVerifiedModule( + const std::string& name, int64_t replica_count) const { + return std::make_unique( + name, GetModuleConfigForTest(replica_count), verifier_layout_sensitive_, + allow_mixed_precision_in_hlo_verifier_, ShapeUtil::ByteSizeOfElements, + instruction_can_change_layout_func_); +} + +absl::StatusOr> +HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule( + absl::string_view hlo_text, int64_t replica_count, + int64_t num_partitions) const { + return ParseAndReturnVerifiedModule( + hlo_text, GetModuleConfigForTest(replica_count, num_partitions)); +} + +absl::Status HloHardwareIndependentTestBase:: + UpdateEntryComputationLayoutToMatchProgramLayout(HloModule* module) { + for (auto* const computation : module->computations({})) { + if (computation->IsEntryComputation()) { + for (int64_t i = 0; i < computation->num_parameters(); ++i) { + const Shape& param_shape = + computation->parameter_instruction(i)->shape(); + TF_RETURN_IF_ERROR(computation->parent() + ->mutable_entry_computation_layout() + ->mutable_parameter_layout(i) + ->CopyLayoutFromShape(param_shape)); + } + + TF_RETURN_IF_ERROR( + computation->parent() + ->mutable_entry_computation_layout() + ->mutable_result_layout() + ->CopyLayoutFromShape(computation->root_instruction()->shape())); + } + } + return absl::OkStatus(); +} + +absl::StatusOr> +HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule( + absl::string_view hlo_text, const HloModuleConfig& config) const { + auto module = std::make_unique( + TestName(), config, verifier_layout_sensitive_, + allow_mixed_precision_in_hlo_verifier_, ShapeUtil::ByteSizeOfElements, + instruction_can_change_layout_func_); + TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text)); + return std::move(module); +} + +/* static */ +absl::StatusOr HloHardwareIndependentTestBase::RunHloPass( + HloPassInterface* hlo_pass, HloModule* module) { + const std::string module_str_before_run = + module->ToProto().ShortDebugString(); + const auto status_or = hlo_pass->Run(module); + if (status_or.status().ok()) { + const std::string module_str_after_run = + module->ToProto().ShortDebugString(); + const bool passChangedHlo = status_or.value(); + if (passChangedHlo) { + // Check that the proto actually changed. + EXPECT_NE(module_str_after_run, module_str_before_run); + } else { + // Check that the proto remains same. + EXPECT_EQ(module_str_after_run, module_str_before_run); + } + } + return status_or; +} + +/* static */ +absl::StatusOr HloHardwareIndependentTestBase::RunHloPass( + HloPassInterface&& hlo_pass, HloModuleGroup* module_group) { + const std::string module_group_str_before_run = + module_group->ToProto().ShortDebugString(); + const auto status_or = hlo_pass.RunOnModuleGroup(module_group); + if (status_or.status().ok()) { + const std::string module_group_str_after_run = + module_group->ToProto().ShortDebugString(); + const bool passChangedHlo = status_or.value(); + if (passChangedHlo) { + // Check that the proto actually changed. + EXPECT_NE(module_group_str_after_run, module_group_str_before_run); + } else { + // Check that the proto remains same. + EXPECT_EQ(module_group_str_after_run, module_group_str_before_run); + } + } + return status_or; +} + +/* static */ +PrecisionConfig HloHardwareIndependentTestBase::DefaultPrecisionConfig( + int operands) { + PrecisionConfig precision_config; + precision_config.mutable_operand_precision()->Resize( + operands, PrecisionConfig::DEFAULT); + return precision_config; +} + +void HloHardwareIndependentTestBase::SetAotFastMathDebugOptions( + DebugOptions* options) { + options->set_xla_cpu_enable_fast_math(true); + options->set_xla_gpu_enable_fast_min_max(true); + options->set_xla_cpu_enable_fast_min_max(true); + options->set_xla_cpu_fast_math_honor_nans(false); + options->set_xla_cpu_fast_math_honor_infs(false); + options->set_xla_cpu_fast_math_honor_functions(false); + options->set_xla_cpu_fast_math_honor_division(false); +} + +DebugOptions HloHardwareIndependentTestBase::GetDebugOptionsForTest() const { + auto debug_options = GetDebugOptionsFromFlags(); + // TODO(b/38354253): Change tests to use Parameters instead of Constants. + debug_options.add_xla_disable_hlo_passes("constant_folding"); + debug_options.set_xla_hlo_evaluator_use_fast_path(true); + return debug_options; +} + +void HloHardwareIndependentTestBase::RunAndFilecheckHloRewrite( + absl::string_view hlo, HloPassInterface&& hlo_pass, + std::optional expected, + std::function after_pass_checks, + const HloModuleConfig* config) const { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + config ? ParseAndReturnVerifiedModule(hlo, *config) + : ParseAndReturnVerifiedModule(hlo)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&hlo_pass, module.get())); + EXPECT_EQ(changed, expected.has_value()) << module->ToString(); + if (changed) { + TF_ASSERT_OK_AND_ASSIGN( + bool filecheck_matches, + RunFileCheck( + module->ToString(HloPrintOptions{}.set_print_operand_shape(false)), + *expected)); + EXPECT_TRUE(filecheck_matches); + if (after_pass_checks) { + after_pass_checks(module.get()); + } + } +} + +void HloHardwareIndependentTestBase::RunAndFilecheckHloModuleGroupRewrite( + absl::Span hlo_module_strs, + HloPassInterface&& hlo_pass, + std::optional> expected) const { + std::vector> modules; + for (absl::string_view hlo : hlo_module_strs) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + modules.push_back(std::move(module)); + } + HloModuleGroup module_group("test_input_module_group", std::move(modules)); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunHloPass(std::move(hlo_pass), &module_group)); + EXPECT_EQ(changed, expected.has_value()) << module_group.ToString(); + + if (!changed) { + return; + } + + EXPECT_THAT(module_group.modules(), + ::testing::SizeIs(expected.value().size())); + int index = 0; + for (auto expected_str : expected.value()) { + TF_ASSERT_OK_AND_ASSIGN( + bool filecheck_matches, + RunFileCheck(module_group.module(index).ToString( + HloPrintOptions{}.set_print_operand_shape(false)), + expected_str)); + EXPECT_TRUE(filecheck_matches); + index++; + } +} + +absl::StatusOr> +HloHardwareIndependentTestBase::RunAndCheckHloRewrite( + absl::string_view hlo_template, HloPassInterface&& hlo_pass, + bool expect_change, FixedMapping params) const { + std::string hlo_string = absl::StrReplaceAll(hlo_template, params); + SCOPED_TRACE("Input HLO: " + hlo_string); + VLOG(7) << "Input HLO: " << hlo_string; + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSIGN_OR_RETURN(bool changed, RunHloPass(hlo_pass, module.get())); + VLOG(7) << "Output HLO: " + << module->ToString(HloPrintOptions::ShortParsable()); + EXPECT_EQ(changed, expect_change); + return module; +} + +std::vector HloHardwareIndependentTestBase::CompareInputs( + const HloModule& module_0, const HloModule& module_1) { + const auto params_0 = module_0.entry_computation()->parameter_instructions(); + const auto params_1 = module_1.entry_computation()->parameter_instructions(); + std::vector mismatches; + int64_t min = std::min(params_0.size(), params_1.size()); + int64_t max = std::max(params_0.size(), params_1.size()); + for (int64_t i = 0; i < min; ++i) { + const HloModuleConfig& module_config_0 = module_0.config(); + const Shape& param_shape_0 = + (module_config_0.has_entry_computation_layout() && + module_config_0.entry_computation_layout() + .parameter_layout(i) + .shape() + .is_static()) + ? module_config_0.entry_computation_layout() + .parameter_layout(i) + .shape() + : params_0[i]->shape(); + + const HloModuleConfig& module_config_1 = module_1.config(); + const Shape& param_shape_1 = + (module_config_1.has_entry_computation_layout() && + module_config_1.entry_computation_layout() + .parameter_layout(i) + .shape() + .is_static()) + ? module_config_1.entry_computation_layout() + .parameter_layout(i) + .shape() + : params_1[i]->shape(); + + if (!Shape::Equal().IgnoreTilesInLayout()(param_shape_0, param_shape_1)) { + mismatches.push_back(i); + } + } + for (int64_t i = min; i < max; i++) { + mismatches.push_back(i); + } + return mismatches; +} + +/* static */ +HloComputation* HloHardwareIndependentTestBase::FindComputation( + HloModule* module, absl::string_view name) { + return hlo_query::FindComputation(module, name); +} + +/* static */ +HloInstruction* HloHardwareIndependentTestBase::FindInstruction( + HloModule* module, absl::string_view name) { + for (const HloComputation* computation : module->computations()) { + if (HloInstruction* instruction = + hlo_query::FindInstruction(computation, name)) { + return instruction; + } + } + return nullptr; +} + +/* static */ +HloInstruction* HloHardwareIndependentTestBase::FindInstruction( + HloModule* module, HloOpcode opcode) { + for (const HloComputation* computation : module->computations()) { + if (HloInstruction* instruction = + hlo_query::FindInstruction(computation, opcode)) { + return instruction; + } + } + return nullptr; +} + +/* static */ +std::vector HloHardwareIndependentTestBase::FindInstructions( + HloModule* module, HloOpcode opcode) { + std::vector instructions; + for (const HloComputation* c : module->computations()) { + absl::c_copy_if(c->instructions(), std::back_inserter(instructions), + [&](HloInstruction* i) { return i->opcode() == opcode; }); + } + return instructions; +} + +} // namespace xla diff --git a/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.h b/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.h new file mode 100644 index 00000000000000..2a7f1f488b54e8 --- /dev/null +++ b/third_party/xla/xla/hlo/testlib/hlo_hardware_independent_test_base.h @@ -0,0 +1,265 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TESTLIB_HLO_HARDWARE_INDEPENDENT_TEST_BASE_H_ +#define XLA_HLO_TESTLIB_HLO_HARDWARE_INDEPENDENT_TEST_BASE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_module_group.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/hlo/testlib/verified_hlo_module.h" +#include "xla/layout.h" +#include "xla/service/computation_layout.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/hlo_verifier.h" +#include "xla/shape_layout.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/test.h" + +namespace xla { + +// A base class for tests which build and manipulate HLO without running it. +// +class HloHardwareIndependentTestBase : public ::testing::Test { + public: + static PrecisionConfig DefaultPrecisionConfig(int operands); + + protected: + explicit HloHardwareIndependentTestBase( + bool verifier_layout_sensitive = false, + bool allow_mixed_precision_in_hlo_verifier = true, + HloPredicate instruction_can_change_layout_func = {}); + + // Creates a new HLO module for a test. The module created will have + // TestName() for its name; it will also automatically populate its debug + // options from command-line flags. If you want a fresh HloModule object and + // then add HloComputations to it, it's recommended to use this method in your + // tests. + // + // This returns a vanilla HloModule that doesn't run the HLO verifier on + // destruction. + ABSL_DEPRECATED("Use CreateNewVerifiedModule instead.") + std::unique_ptr CreateNewUnverifiedModule( + const std::string& name = TestName()) const; + + // Like CreateNewUnverifiedModule, except the HloModule returned here runs the + // HLO verifier on destruction. + std::unique_ptr CreateNewVerifiedModule( + const std::string& name = TestName(), int64_t replica_count = 1) const; + + // Parses the given string and returns module as a VerifiedHloModule. + absl::StatusOr> + ParseAndReturnVerifiedModule(absl::string_view hlo_text, + int64_t replica_count = 1, + int64_t num_partitions = 1) const; + absl::StatusOr> + ParseAndReturnVerifiedModule(absl::string_view hlo_text, + const HloModuleConfig& config) const; + + // Runs the hlo_pass with the provided module and returns the result. This + // function also verifies that the module remains unchanged when hlo_pass + // returns false as the absl::StatusOr value. + // + // These three overloads all do the same thing. The && overload lets you do + // `RunHloPass(MyPass(), module)` all in one line. The reason for the + // overload that takes a pointer is that, at one point in the past, non-const + // lvalue references were banned in Google code. + static absl::StatusOr RunHloPass(HloPassInterface* hlo_pass, + HloModule* module); + static absl::StatusOr RunHloPass(HloPassInterface& hlo_pass, + HloModule* module) { + return RunHloPass(&hlo_pass, module); + } + static absl::StatusOr RunHloPass(HloPassInterface&& hlo_pass, + HloModule* module) { + return RunHloPass(&hlo_pass, module); + } + + // Runs the hlo_pass with the provided module group and returns the result. + // This method runs the input HLO module group pass for a `HloModuleGroup` and + // it also verifies the module group remains unchanged when hlo_pass returns + // false as the absl::StatusOr value. + static absl::StatusOr RunHloPass(HloPassInterface&& hlo_pass, + HloModuleGroup* module_group); + + // Sets most fath math options to be enabled to model the fast math flags + // generally used for CPU:AOT compilation. + static void SetAotFastMathDebugOptions(DebugOptions* options); + + // Runs pass `hlo_pass` on input HLO module `hlo` with optional config, and + // FileChecks the result against `expected`. + // + // If the rewrite has changed the module, also runs `additional_checks` on the + // result. + void RunAndFilecheckHloRewrite( + absl::string_view hlo, HloPassInterface&& hlo_pass, + std::optional expected, + std::function after_pass_checks = nullptr, + const HloModuleConfig* config = nullptr) const; + + // Runs pass `hlo_pass` on a group of input HLO modules `hlo_module_strs`, + // and FileChecks the result against `expected`. + void RunAndFilecheckHloModuleGroupRewrite( + absl::Span hlo_module_strs, + HloPassInterface&& hlo_pass, + std::optional> expected) const; + + using FixedMapping = + std::initializer_list>; + + // Creates an HLO module from a template and an optional replacement map and + // runs the given hlo_pass on the module. Validates whether the pass has + // changed the module or not based on expect_change flag. Returns unique_ptr + // to the HLO module for further inspection. + absl::StatusOr> RunAndCheckHloRewrite( + absl::string_view hlo_template, HloPassInterface&& hlo_pass, + bool expect_change = true, FixedMapping params = {}) const; + + // Populates debug options from command-line flags and adjusts the options for + // testing. It is recommended to use this when you need to pass in + // DebugOptions, e.g. when creating a module from a string or a file. + // + // This function is virtual so tests can specify an alternative set of debug + // options (e.g. disabling additional passes). + virtual DebugOptions GetDebugOptionsForTest() const; + + // Gets an HloModuleConfig with options appropriate for tests. + HloModuleConfig GetModuleConfigForTest(int64_t replica_count = 1, + int64_t num_partitions = 1) const { + HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + config.set_replica_count(replica_count); + config.set_num_partitions(num_partitions); + return config; + } + + // Convenience method to force the layout of a given parameter in a module. + // The layout of parameter number 'param_no' in the 'module' is set to + // 'layout'. + static void ForceParameterLayout(HloModule* module, int64_t param_no, + const Layout& layout) { + ASSERT_LT(param_no, + module->mutable_entry_computation_layout()->parameter_count()); + module->mutable_entry_computation_layout() + ->mutable_parameter_layout(param_no) + ->ResetLayout(layout); + } + + // Convenience method to force the layout of the computation result in a + // module. The result layout of 'module' is set to 'layout'. + static void ForceResultLayout(HloModule* module, const Layout& layout) { + module->mutable_entry_computation_layout() + ->mutable_result_layout() + ->ResetLayout(layout); + } + + static void ForceResultLayout(HloModule* module, const Layout& layout, + ShapeIndexView shape_index) { + module->mutable_entry_computation_layout() + ->mutable_result_layout() + ->ResetLayout(layout, shape_index); + } + + // Convenience method to clear the layout of the computation result in + // 'module'. + static void ForceClearResultLayout(HloModule* module) { + module->mutable_entry_computation_layout() + ->mutable_result_layout() + ->Clear(); + } + + // Gets the computation/instruction from the given module with the given name. + // Note that it is encouraged to use these functions directly via the + // hlo_query.h header instead since they are independent from any test-time + // variables or contexts. + + // This is useful for tests which create HLOs from a string and then want to + // inspect a particular computation or instruction. + static HloComputation* FindComputation(HloModule* module, + absl::string_view name); + static HloInstruction* FindInstruction(HloModule* module, + absl::string_view name); + // Gets the instruction from the given module with the given opcode. + static HloInstruction* FindInstruction(HloModule* module, HloOpcode opcode); + // Gets all the instructions from the given module with the given opcode. + static std::vector FindInstructions(HloModule* module, + HloOpcode opcode); + + bool verifier_layout_sensitive() const { return verifier_layout_sensitive_; } + void set_verifier_layout_sensitive(bool verifier_layout_sensitive) { + verifier_layout_sensitive_ = verifier_layout_sensitive; + } + HloPredicate instruction_can_change_layout_func() const { + return instruction_can_change_layout_func_; + } + void set_instruction_can_change_layout_func( + HloPredicate instruction_can_change_layout_func) { + instruction_can_change_layout_func_ = + std::move(instruction_can_change_layout_func); + } + // Return an HLO verifier constructed for the test backend. + HloVerifier& verifier() const { return *hlo_verifier_; } + void set_hlo_verifier(std::unique_ptr hlo_verifier) { + hlo_verifier_ = std::move(hlo_verifier); + } + bool allow_mixed_precision_in_hlo_verifier() const { + return allow_mixed_precision_in_hlo_verifier_; + } + + static std::string TestName() { + return ::testing::UnitTest::GetInstance()->current_test_info()->name(); + } + + // Updates the entry computation layout to match the program shape. Useful + // when tiling assignment has been run to update the latter and we want those + // changes propagated into the former. + static absl::Status UpdateEntryComputationLayoutToMatchProgramLayout( + HloModule* module); + + // Compares the inputs shapes of two modules and returns the list of parameter + // indices that mismatch. The mismatch could be either in shape or datatype. + // If there is no mismatch, an empty vector is returned. + [[nodiscard]] std::vector CompareInputs(const HloModule& module_0, + const HloModule& module_1); + + private: + bool verifier_layout_sensitive_; + bool allow_mixed_precision_in_hlo_verifier_; + HloPredicate instruction_can_change_layout_func_; + std::unique_ptr hlo_verifier_; +}; + +} // namespace xla + +#endif // XLA_HLO_TESTLIB_HLO_HARDWARE_INDEPENDENT_TEST_BASE_H_ diff --git a/third_party/xla/xla/tests/verified_hlo_module.cc b/third_party/xla/xla/hlo/testlib/verified_hlo_module.cc similarity index 95% rename from third_party/xla/xla/tests/verified_hlo_module.cc rename to third_party/xla/xla/hlo/testlib/verified_hlo_module.cc index 9c9f1f2a21bd05..044bc2f5ca40bc 100644 --- a/third_party/xla/xla/tests/verified_hlo_module.cc +++ b/third_party/xla/xla/hlo/testlib/verified_hlo_module.cc @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/tests/verified_hlo_module.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/status_macros.h" #include "xla/util.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/hlo/testlib/verified_hlo_module.h b/third_party/xla/xla/hlo/testlib/verified_hlo_module.h new file mode 100644 index 00000000000000..6c8f03a1c01df3 --- /dev/null +++ b/third_party/xla/xla/hlo/testlib/verified_hlo_module.h @@ -0,0 +1,66 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_HLO_TESTLIB_VERIFIED_HLO_MODULE_H_ +#define XLA_HLO_TESTLIB_VERIFIED_HLO_MODULE_H_ + +#include + +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/hlo_verifier.h" +#include "xla/shape.h" +#include "xla/types.h" +#include "tsl/platform/status.h" + +namespace xla { + +// An HLO module derived class which verifies itself on destruction. This class +// is intended to be used in unit tests. Any verification errors are raised via +// ADD_FAILURE. +class VerifiedHloModule : public HloModule { + public: + VerifiedHloModule(const std::string& name, const HloModuleConfig& config, + bool verifier_layout_sensitive, + bool allow_mixed_precision_in_hlo_verifier, + std::function shape_size_function, + HloPredicate instruction_can_change_layout_func = {}) + : HloModule(name, config), + verifier_(verifier_layout_sensitive, + allow_mixed_precision_in_hlo_verifier, + instruction_can_change_layout_func, shape_size_function) {} + + ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); } + + // Given a string in the HloModule::ToString() format, parses the string and + // builds the VerifiedHloModule in place. Before calling this method, the + // module must be empty (no computations). Finally verifies the module using + // HloVerifier and returns the status. + absl::Status ParseHloStringAndVerifyModule(absl::string_view str); + + // Verifies the module and flags any error with ADD_FAILURE. 'message' is + // included in the failure message. + void VerifyOrAddFailure(absl::string_view message); + + // Verifies the module using HloVerifier and returns the status. + absl::Status Verify(); + + private: + HloVerifier verifier_; +}; + +} // namespace xla + +#endif // XLA_HLO_TESTLIB_VERIFIED_HLO_MODULE_H_ diff --git a/third_party/xla/xla/hlo/tools/BUILD b/third_party/xla/xla/hlo/tools/BUILD new file mode 100644 index 00000000000000..eb35cbc1a0ab19 --- /dev/null +++ b/third_party/xla/xla/hlo/tools/BUILD @@ -0,0 +1,155 @@ +# Tools and utilities that aid in XLA development and usage. + +load("@bazel_skylib//rules:build_test.bzl", "build_test") +load( + "//xla:xla.bzl", + "xla_cc_binary", +) +load("//xla/tsl:tsl.default.bzl", "filegroup") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//xla:internal"], + licenses = ["notice"], +) + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), + visibility = ["//xla:internal"], +) + +build_test( + name = "hex_floats_to_packed_literal_build_test", + targets = [ + ":hex_floats_to_packed_literal", + ], +) + +xla_cc_binary( + name = "hex_floats_to_packed_literal", + srcs = ["hex_floats_to_packed_literal.cc"], + deps = [ + "//xla/tsl/lib/io:buffered_inputstream", + "//xla/tsl/lib/io:random_inputstream", + "//xla/tsl/util:command_line_flags", + "@com_google_absl//absl/base", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:status", + ], +) + +build_test( + name = "show_literal_build_test", + targets = [ + ":show_literal", + ], +) + +xla_cc_binary( + name = "show_literal", + srcs = ["show_literal.cc"], + deps = [ + "//xla:literal", + "//xla:types", + "//xla:xla_data_proto_cc", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:status", + ], +) + +build_test( + name = "convert_computation_build_test", + targets = [ + ":convert_computation", + ], +) + +xla_cc_binary( + name = "convert_computation", + srcs = ["convert_computation.cc"], + deps = [ + "//xla/service:hlo_proto_cc", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:status", + ], +) + +build_test( + name = "hlo_module_metadata_processor_build_test", + targets = [ + ":hlo_module_metadata_processor", + ], +) + +xla_cc_binary( + name = "hlo_module_metadata_processor", + srcs = ["hlo_module_metadata_processor.cc"], + deps = [ + "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:status", + ], +) + +build_test( + name = "show_text_literal_build_test", + targets = [ + ":show_text_literal", + ], +) + +xla_cc_binary( + name = "show_text_literal", + srcs = ["show_text_literal.cc"], + deps = [ + "//xla:literal", + "//xla:text_literal_reader", + "//xla:types", + "//xla:xla_data_proto_cc", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:protobuf", + ], +) + +build_test( + name = "hlo_proto_to_json_build_test", + targets = [ + ":hlo_proto_to_json", + ], +) + +xla_cc_binary( + name = "hlo_proto_to_json", + srcs = ["hlo_proto_to_json.cc"], + deps = [ + "//xla:util", + "//xla/service:hlo_proto_cc", + "//xla/tsl/util:command_line_flags", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:status", + ], +) diff --git a/third_party/xla/xla/tools/convert_computation.cc b/third_party/xla/xla/hlo/tools/convert_computation.cc similarity index 100% rename from third_party/xla/xla/tools/convert_computation.cc rename to third_party/xla/xla/hlo/tools/convert_computation.cc diff --git a/third_party/xla/xla/tools/hex_floats_to_packed_literal.cc b/third_party/xla/xla/hlo/tools/hex_floats_to_packed_literal.cc similarity index 96% rename from third_party/xla/xla/tools/hex_floats_to_packed_literal.cc rename to third_party/xla/xla/hlo/tools/hex_floats_to_packed_literal.cc index 6388a8fb84d71c..659e4cde814b5d 100644 --- a/third_party/xla/xla/tools/hex_floats_to_packed_literal.cc +++ b/third_party/xla/xla/hlo/tools/hex_floats_to_packed_literal.cc @@ -20,9 +20,9 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/strings/string_view.h" +#include "xla/tsl/lib/io/buffered_inputstream.h" +#include "xla/tsl/lib/io/random_inputstream.h" #include "xla/tsl/util/command_line_flags.h" -#include "tsl/lib/io/buffered_inputstream.h" -#include "tsl/lib/io/random_inputstream.h" #include "tsl/platform/env.h" #include "tsl/platform/file_system.h" #include "tsl/platform/init_main.h" diff --git a/third_party/xla/xla/tools/hlo_module_metadata_processor.cc b/third_party/xla/xla/hlo/tools/hlo_module_metadata_processor.cc similarity index 100% rename from third_party/xla/xla/tools/hlo_module_metadata_processor.cc rename to third_party/xla/xla/hlo/tools/hlo_module_metadata_processor.cc diff --git a/third_party/xla/xla/tools/hlo_proto_to_json.cc b/third_party/xla/xla/hlo/tools/hlo_proto_to_json.cc similarity index 100% rename from third_party/xla/xla/tools/hlo_proto_to_json.cc rename to third_party/xla/xla/hlo/tools/hlo_proto_to_json.cc diff --git a/third_party/xla/xla/tools/show_literal.cc b/third_party/xla/xla/hlo/tools/show_literal.cc similarity index 100% rename from third_party/xla/xla/tools/show_literal.cc rename to third_party/xla/xla/hlo/tools/show_literal.cc diff --git a/third_party/xla/xla/tools/show_text_literal.cc b/third_party/xla/xla/hlo/tools/show_text_literal.cc similarity index 100% rename from third_party/xla/xla/tools/show_text_literal.cc rename to third_party/xla/xla/hlo/tools/show_text_literal.cc diff --git a/third_party/xla/xla/hlo/transforms/BUILD b/third_party/xla/xla/hlo/transforms/BUILD index 3dd713127f9435..a412d1bd3db8a6 100644 --- a/third_party/xla/xla/hlo/transforms/BUILD +++ b/third_party/xla/xla/hlo/transforms/BUILD @@ -2,7 +2,12 @@ # Implementation of XLA’s HLO transformations. load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") +load( + "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", +) load("//xla:xla.bzl", "xla_cc_test") +load("//xla/tsl:tsl.bzl", "tsl_copts") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -19,8 +24,8 @@ package_group( cc_library( name = "hlo_constant_splitter", - srcs = ["hlo_constant_splitter.cc"], - hdrs = ["hlo_constant_splitter.h"], + srcs = ["simplifiers/hlo_constant_splitter.cc"], + hdrs = ["simplifiers/hlo_constant_splitter.h"], deps = [ "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", @@ -39,17 +44,2757 @@ cc_library( xla_cc_test( name = "hlo_constant_splitter_test", - srcs = ["hlo_constant_splitter_test.cc"], + srcs = ["simplifiers/hlo_constant_splitter_test.cc"], deps = [ ":hlo_constant_splitter", + ":hlo_dce", + "//xla:test", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "async_collective_creator", + srcs = ["collectives/async_collective_creator.cc"], + hdrs = ["collectives/async_collective_creator.h"], + deps = [ + "//xla:frontend_attributes", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:shape_inference", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "async_collective_creator_test", + srcs = ["collectives/async_collective_creator_test.cc"], + deps = [ + ":async_collective_creator", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "all_reduce_folder", + srcs = ["simplifiers/all_reduce_folder.cc"], + hdrs = ["simplifiers/all_reduce_folder.h"], + deps = [ + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", + "//xla/service:all_reduce_key", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "all_reduce_folder_test", + srcs = ["simplifiers/all_reduce_folder_test.cc"], + deps = [ + ":all_reduce_folder", + "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "broadcast_canonicalizer", + srcs = ["simplifiers/broadcast_canonicalizer.cc"], + hdrs = ["simplifiers/broadcast_canonicalizer.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "broadcast_canonicalizer_test", + srcs = ["simplifiers/broadcast_canonicalizer_test.cc"], + deps = [ + ":broadcast_canonicalizer", + "//xla:test", + "//xla:test_helpers", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep + ], +) + +cc_library( + name = "bfloat16_conversion_folding", + srcs = ["simplifiers/bfloat16_conversion_folding.cc"], + hdrs = ["simplifiers/bfloat16_conversion_folding.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_dataflow_analysis", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:float_support", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + ], +) + +xla_cc_test( + name = "bfloat16_conversion_folding_test", + srcs = ["simplifiers/bfloat16_conversion_folding_test.cc"], + deps = [ + ":bfloat16_conversion_folding", + "//xla:shape_util", + "//xla:test", + "//xla:test_helpers", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:float_support", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "float_normalization", + srcs = ["simplifiers/float_normalization.cc"], + hdrs = ["simplifiers/float_normalization.h"], + deps = [ + ":hlo_dce", + ":tuple_simplifier", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:call_graph", + "//xla/service:float_support", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "float_normalization_test", + srcs = ["simplifiers/float_normalization_test.cc"], + deps = [ + ":float_normalization", + "//xla:shape_util", + "//xla:test", + "//xla:test_helpers", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:float_support", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_verifier", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "bfloat16_propagation", + srcs = ["bfloat16_propagation.cc"], + hdrs = ["bfloat16_propagation.h"], + deps = [ + ":hlo_dce", + ":tuple_simplifier", + "//xla:literal", + "//xla:shape_tree", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_dataflow_analysis", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:float_support", + "//xla/service:hlo_value", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "bfloat16_propagation_test", + srcs = ["bfloat16_propagation_test.cc"], + deps = [ + ":bfloat16_propagation", + "//xla:comparison_util", + "//xla:literal_util", + "//xla:shape_util", + "//xla:test", + "//xla:test_helpers", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:float_support", + "//xla/service:hlo_verifier", + "//xla/tests:literal_test_util", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep + ], +) + +cc_library( + name = "convert_async_collectives_to_sync", + srcs = ["collectives/convert_async_collectives_to_sync.cc"], + hdrs = ["collectives/convert_async_collectives_to_sync.h"], + deps = [ + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "convert_async_collectives_to_sync_test", + srcs = ["collectives/convert_async_collectives_to_sync_test.cc"], + deps = [ + ":convert_async_collectives_to_sync", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "collective_quantizer", + srcs = ["collectives/collective_quantizer.cc"], + hdrs = ["collectives/collective_quantizer.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/analysis:hlo_replication_analysis", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:pattern_matcher", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + +xla_cc_test( + name = "collective_quantizer_test", + srcs = ["collectives/collective_quantizer_test.cc"], + deps = [ + ":collective_quantizer", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "//xla/service:hlo_verifier", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "flatten_call_graph", + srcs = ["simplifiers/flatten_call_graph.cc"], + hdrs = ["simplifiers/flatten_call_graph.h"], + deps = [ + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", + "//xla/service:call_graph", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + ], +) + +cc_library( + name = "hlo_computation_deduplicator", + srcs = ["simplifiers/hlo_computation_deduplicator.cc"], + hdrs = ["simplifiers/hlo_computation_deduplicator.h"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:logging", + ], +) + +xla_cc_test( + name = "hlo_computation_deduplicator_test", + size = "small", + srcs = ["simplifiers/hlo_computation_deduplicator_test.cc"], + deps = [ + ":hlo_computation_deduplicator", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:test", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "//xla/tsl/lib/core:status_test_util", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +xla_cc_test( + name = "flatten_call_graph_test", + srcs = ["simplifiers/flatten_call_graph_test.cc"], + deps = [ + ":flatten_call_graph", + "//xla:comparison_util", + "//xla:literal_util", + "//xla:shape_util", "//xla:test", "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:call_graph", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "hlo_memory_scheduler", + srcs = ["simplifiers/hlo_memory_scheduler.cc"], + hdrs = ["simplifiers/hlo_memory_scheduler.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla/hlo/analysis:hlo_alias_analysis", + "//xla/hlo/analysis:tuple_points_to_analysis", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:buffer_value", + "//xla/service:logical_buffer", + "//xla/service/heap_simulator", + "//xla/tsl/lib/gtl:map_util", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:numbers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:scoped_annotation", + ], +) + +xla_cc_test( + name = "hlo_memory_scheduler_test", + srcs = ["simplifiers/hlo_memory_scheduler_test.cc"], + deps = [ + ":hlo_dce", + ":hlo_memory_scheduler", + "//xla:literal_util", + "//xla:shape_util", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_alias_analysis", + "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", - "//xla/service:hlo_dce", - "//xla/service:hlo_parser", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:buffer_value", + "//xla/service:hlo_value", + "//xla/service:logical_buffer", + "//xla/service/heap_simulator", "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "op_expander_pass", + srcs = ["expanders/op_expander_pass.cc"], + hdrs = ["expanders/op_expander_pass.h"], + deps = [ + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "optimization_barrier_expander", + srcs = ["expanders/optimization_barrier_expander.cc"], + hdrs = ["expanders/optimization_barrier_expander.h"], + deps = [ + ":op_expander_pass", + ], +) + +cc_library( + name = "comparison_expander", + srcs = ["expanders/comparison_expander.cc"], + hdrs = ["expanders/comparison_expander.h"], + deps = [ + ":op_expander_pass", + "//xla:comparison_util", + "//xla:literal_util", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "cholesky_expander", + srcs = ["expanders/cholesky_expander.cc"], + hdrs = ["expanders/cholesky_expander.h"], + deps = [ + ":op_expander_pass", + "//xla:literal", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder/lib:arithmetic", + "//xla/hlo/builder/lib:constants", + "//xla/hlo/builder/lib:loops", + "//xla/hlo/builder/lib:math", + "//xla/hlo/builder/lib:matrix", + "//xla/hlo/builder/lib:slicing", + "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:errors", + ], +) + +cc_library( + name = "qr_expander", + srcs = ["expanders/qr_expander.cc"], + hdrs = ["expanders/qr_expander.h"], + deps = [ + ":op_expander_pass", + "//xla:literal", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder/lib:arithmetic", + "//xla/hlo/builder/lib:constants", + "//xla/hlo/builder/lib:loops", + "//xla/hlo/builder/lib:math", + "//xla/hlo/builder/lib:matrix", + "//xla/hlo/builder/lib:qr", + "//xla/hlo/builder/lib:slicing", + "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:errors", + ], +) + +cc_library( + name = "real_imag_expander", + srcs = ["expanders/real_imag_expander.cc"], + hdrs = ["expanders/real_imag_expander.h"], + deps = [ + ":op_expander_pass", + "//xla:literal_util", + ], +) + +xla_cc_test( + name = "real_imag_expander_test", + size = "small", + srcs = ["expanders/real_imag_expander_test.cc"], + deps = [ + ":real_imag_expander", + "//xla:literal", + "//xla:shape_util", + "//xla:test", + "//xla:types", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "//xla/service:hlo_creation_utils", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "eigh_expander", + srcs = ["expanders/eigh_expander.cc"], + hdrs = ["expanders/eigh_expander.h"], + deps = [ + ":op_expander_pass", + "//xla:literal_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder/lib:arithmetic", + "//xla/hlo/builder/lib:comparators", + "//xla/hlo/builder/lib:constants", + "//xla/hlo/builder/lib:loops", + "//xla/hlo/builder/lib:math", + "//xla/hlo/builder/lib:matrix", + "//xla/hlo/builder/lib:slicing", + "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:errors", + ], +) + +cc_library( + name = "convolution_4d_expander", + srcs = ["expanders/convolution_4d_expander.cc"], + hdrs = ["expanders/convolution_4d_expander.h"], + deps = [ + ":op_expander_pass", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +xla_cc_test( + name = "convolution_4d_expander_test", + srcs = ["expanders/convolution_4d_expander_test.cc"], + deps = [ + "convolution_4d_expander", + "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "convolution_pred_expander", + srcs = ["expanders/convolution_pred_expander.cc"], + hdrs = ["expanders/convolution_pred_expander.h"], + deps = [ + ":op_expander_pass", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "//xla/service:pattern_matcher", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +xla_cc_test( + name = "convolution_pred_expander_test", + srcs = ["expanders/convolution_pred_expander_test.cc"], + deps = [ + ":convolution_pred_expander", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "algebraic_simplifier", + srcs = ["simplifiers/algebraic_simplifier.cc"], + hdrs = ["simplifiers/algebraic_simplifier.h"], + copts = tsl_copts(), + deps = [ + "//xla:comparison_util", + "//xla:literal", + "//xla:literal_util", + "//xla:permutation_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:window_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/evaluator:hlo_evaluator", + "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_instruction_utils", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_sharding_util", + "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_module_config", + "//xla/service:host_memory_offload_annotations_hdr", + "//xla/service:pattern_matcher", + "//xla/service:shape_inference", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "tree_reduction_rewriter", + srcs = ["simplifiers/tree_reduction_rewriter.cc"], + hdrs = ["simplifiers/tree_reduction_rewriter.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:padding", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:shape_inference", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + ], +) + +xla_cc_test( + name = "algebraic_simplifier_test", + srcs = ["simplifiers/algebraic_simplifier_test.cc"], + deps = [ + ":algebraic_simplifier", + ":hlo_constant_folding", + "//xla:comparison_util", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:test", + "//xla:util", + "//xla:window_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:hlo_creation_utils", + "//xla/service:host_memory_offload_annotations_hdr", + "//xla/service:layout_assignment", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service:shape_inference", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep + ], +) + +cc_library( + name = "simplify_fp_conversions", + srcs = ["simplifiers/simplify_fp_conversions.cc"], + hdrs = ["simplifiers/simplify_fp_conversions.h"], + deps = [ + "//xla:comparison_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "simplify_fp_conversions_test", + srcs = ["simplifiers/simplify_fp_conversions_test.cc"], + deps = [ + ":simplify_fp_conversions", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep + ], +) + +cc_library( + name = "logistic_expander", + srcs = ["expanders/logistic_expander.cc"], + hdrs = ["expanders/logistic_expander.h"], + deps = [ + ":op_expander_pass", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:logging", + ], +) + +xla_cc_test( + name = "logistic_expander_test", + srcs = ["expanders/logistic_expander_test.cc"], + deps = [ + ":logistic_expander", + "//xla:test", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:dynamic_padder", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep + ], +) + +cc_library( + name = "collectives_schedule_linearizer", + srcs = ["collectives/collectives_schedule_linearizer.cc"], + hdrs = ["collectives/collectives_schedule_linearizer.h"], + deps = [ + "//xla:util", + "//xla/hlo/analysis:hlo_reachability", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "collectives_schedule_linearizer_test", + srcs = ["collectives/collectives_schedule_linearizer_test.cc"], + deps = [ + ":collectives_schedule_linearizer", + "//xla:test", + "//xla:test_helpers", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:pattern_matcher", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "all_gather_broadcast_reorder", + srcs = ["collectives/all_gather_broadcast_reorder.cc"], + hdrs = ["collectives/all_gather_broadcast_reorder.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + ], +) + +cc_library( + name = "bitcast_dtypes_expander", + srcs = ["expanders/bitcast_dtypes_expander.cc"], + hdrs = ["expanders/bitcast_dtypes_expander.h"], + deps = [ + ":op_expander_pass", + "//xla:literal_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/builder/lib:arithmetic", + "//xla/hlo/builder/lib:broadcast", + "//xla/hlo/builder/lib:constants", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_module_config", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "bitcast_dtypes_expander_test", + srcs = ["expanders/bitcast_dtypes_expander_test.cc"], + deps = [ + ":bitcast_dtypes_expander", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:filecheck", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +xla_cc_test( + name = "all_gather_broadcast_reorder_test", + srcs = ["collectives/all_gather_broadcast_reorder_test.cc"], + deps = [ + ":all_gather_broadcast_reorder", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "all_gather_combiner", + srcs = ["collectives/all_gather_combiner.cc"], + hdrs = ["collectives/all_gather_combiner.h"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", + "//xla/hlo/utils:hlo_sharding_util", + "//xla/service:collective_combiner_utils", + "//xla/service:hlo_domain_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "all_gather_combiner_test", + srcs = ["collectives/all_gather_combiner_test.cc"], + deps = [ + ":all_gather_combiner", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "all_reduce_combiner", + srcs = ["collectives/all_reduce_combiner.cc"], + hdrs = ["collectives/all_reduce_combiner.h"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", + "//xla/hlo/utils:hlo_sharding_util", + "//xla/service:all_reduce_key", + "//xla/service:collective_combiner_utils", + "//xla/service:hlo_domain_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "all_reduce_combiner_test", + srcs = ["collectives/all_reduce_combiner_test.cc"], + deps = [ + ":all_reduce_combiner", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "all_reduce_contiguous", + srcs = ["collectives/all_reduce_contiguous.cc"], + hdrs = ["collectives/all_reduce_contiguous.h"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "all_reduce_contiguous_test", + srcs = ["collectives/all_reduce_contiguous_test.cc"], + deps = [ + ":all_reduce_contiguous", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "batch_dot_simplification", + srcs = ["simplifiers/batch_dot_simplification.cc"], + hdrs = ["simplifiers/batch_dot_simplification.h"], + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "batch_dot_simplification_test", + srcs = ["simplifiers/batch_dot_simplification_test.cc"], + deps = [ + ":batch_dot_simplification", + "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep + ], +) + +cc_library( + name = "convolution_group_converter", + srcs = ["simplifiers/convolution_group_converter.cc"], + hdrs = ["simplifiers/convolution_group_converter.h"], + deps = [ + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + ], +) + +xla_cc_test( + name = "convolution_group_converter_test", + size = "small", + srcs = ["simplifiers/convolution_group_converter_test.cc"], + deps = [ + ":convolution_group_converter", + "//xla:test", + "//xla:types", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep + ], +) + +cc_library( + name = "while_loop_trip_count_annotator", + srcs = ["while_loop_trip_count_annotator.cc"], + hdrs = ["while_loop_trip_count_annotator.h"], + deps = [ + "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:while_loop_analysis", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "while_loop_trip_count_annotator_test", + srcs = ["while_loop_trip_count_annotator_test.cc"], + deps = [ + ":while_loop_trip_count_annotator", + "//xla:test", + "//xla:xla_data_proto_cc", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep + ], +) + +cc_library( + name = "defuser", + srcs = ["defuser.cc"], + hdrs = ["defuser.h"], + deps = [ + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:call_graph", + "@com_google_absl//absl/container:flat_hash_map", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + ], +) + +xla_cc_test( + name = "defuser_test", + srcs = ["defuser_test.cc"], + deps = [ + ":defuser", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:test_main", + ], +) + +xla_cc_test( + name = "despecializer_test", + srcs = ["despecializer_test.cc"], + deps = [ + ":despecializer", + "//xla:literal", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "dot_decomposer", + srcs = ["expanders/dot_decomposer.cc"], + hdrs = ["expanders/dot_decomposer.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:shape_inference", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "dot_decomposer_test", + srcs = ["expanders/dot_decomposer_test.cc"], + deps = [ + ":dot_decomposer", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep + ], +) + +cc_library( + name = "dot_dimension_merger", + srcs = ["simplifiers/dot_dimension_merger.cc"], + hdrs = ["simplifiers/dot_dimension_merger.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "dot_dimension_merger_test", + srcs = ["simplifiers/dot_dimension_merger_test.cc"], + deps = [ + ":dot_dimension_merger", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "dot_merger", + srcs = ["simplifiers/dot_merger.cc"], + hdrs = ["simplifiers/dot_merger.h"], + deps = [ + "//xla:protobuf_util", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:shape_inference", + "//xla/service/graphcycles", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "dot_merger_test", + srcs = ["simplifiers/dot_merger_test.cc"], + deps = [ + ":algebraic_simplifier", + ":dot_merger", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep + ], +) + +cc_library( + name = "convert_mover", + srcs = ["simplifiers/convert_mover.cc"], + hdrs = ["simplifiers/convert_mover.h"], + deps = [ + "//xla:literal", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "convert_mover_test", + srcs = ["simplifiers/convert_mover_test.cc"], + deps = [ + ":convert_mover", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep + ], +) + +cc_library( + name = "tuple_simplifier", + srcs = ["simplifiers/tuple_simplifier.cc"], + hdrs = ["simplifiers/tuple_simplifier.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "tuple_simplifier_test", + srcs = ["simplifiers/tuple_simplifier_test.cc"], + deps = [ + ":tuple_simplifier", + "//xla:shape_util", + "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "reshape_mover", + srcs = ["simplifiers/reshape_mover.cc"], + hdrs = ["simplifiers/reshape_mover.h"], + deps = [ + "//xla:permutation_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla/hlo/pass:hlo_pass", + "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/algorithm:container", + "@local_tsl//tsl/platform:errors", + ], +) + +cc_library( + name = "reshape_decomposer", + srcs = ["expanders/reshape_decomposer.cc"], + hdrs = ["expanders/reshape_decomposer.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/status", + ], +) + +cc_library( + name = "reduce_decomposer", + srcs = ["expanders/reduce_decomposer.cc"], + hdrs = ["expanders/reduce_decomposer.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/status", + ], +) + +xla_cc_test( + name = "reduce_decomposer_test", + srcs = ["expanders/reduce_decomposer_test.cc"], + deps = [ + ":reduce_decomposer", + "//xla:test", + "//xla:test_helpers", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@local_tsl//tsl/platform:test_main", + ], +) + +xla_cc_test( + name = "reshape_decomposer_test", + srcs = ["expanders/reshape_decomposer_test.cc"], + deps = [ + ":reshape_decomposer", + "//xla:test", + "//xla:test_helpers", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "dynamic_dimension_simplifier", + srcs = ["simplifiers/dynamic_dimension_simplifier.cc"], + hdrs = ["simplifiers/dynamic_dimension_simplifier.h"], + deps = [ + "//xla:status_macros", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + ], +) + +xla_cc_test( + name = "dynamic_dimension_simplifier_test", + srcs = ["simplifiers/dynamic_dimension_simplifier_test.cc"], + deps = [ + ":dynamic_dimension_simplifier", + "//xla:literal", + "//xla:shape_util", + "//xla:test", + "//xla:types", + "//xla:window_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:hlo_creation_utils", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service:shape_inference", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:test_main", # fixdeps: keep + ], +) + +xla_cc_test( + name = "reshape_mover_test", + srcs = ["simplifiers/reshape_mover_test.cc"], + deps = [ + ":algebraic_simplifier", + ":reshape_mover", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:hlo_verifier", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "memory_space_propagation", + srcs = ["memory_space_propagation.cc"], + hdrs = ["memory_space_propagation.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/analysis:hlo_dataflow_analysis", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + ], +) + +xla_cc_test( + name = "memory_space_propagation_test", + srcs = ["memory_space_propagation_test.cc"], + deps = [ + ":memory_space_propagation", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "hlo_dce", + srcs = ["simplifiers/hlo_dce.cc"], + hdrs = ["simplifiers/hlo_dce.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "hlo_rematerialization", + srcs = ["simplifiers/hlo_rematerialization.cc"], + hdrs = ["simplifiers/hlo_rematerialization.h"], + deps = [ + ":hlo_dce", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla/hlo/analysis:hlo_dataflow_analysis", + "//xla/hlo/analysis:tuple_points_to_analysis", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", + "//xla/service:call_graph", + "//xla/service:hlo_cost_analysis", + "//xla/service:logical_buffer", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:numbers", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "hlo_rematerialization_test_utils", + testonly = 1, + hdrs = ["simplifiers/hlo_rematerialization_test_utils.h"], + deps = [ + "//xla:literal_util", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@local_tsl//tsl/platform:test_main", + ], +) + +xla_cc_test( + name = "hlo_rematerialization_test_utils_test", + srcs = ["simplifiers/hlo_rematerialization_test_utils_test.cc"], + deps = [ + ":hlo_rematerialization_test_utils", + "//xla/hlo/ir:hlo", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:test_main", + ], +) + +xla_cc_test( + name = "hlo_rematerialization_test", + srcs = ["simplifiers/hlo_rematerialization_test.cc"], + deps = [ + ":hlo_memory_scheduler", + ":hlo_rematerialization", + ":hlo_rematerialization_test_utils", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_matchers", + "//xla/service:hlo_cost_analysis", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +xla_cc_test( + name = "hlo_dce_test", + srcs = ["simplifiers/hlo_dce_test.cc"], + deps = [ + ":hlo_dce", + "//xla:literal_util", + "//xla:shape_util", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tests:literal_test_util", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "hlo_constant_folding", + srcs = ["simplifiers/hlo_constant_folding.cc"], + hdrs = ["simplifiers/hlo_constant_folding.h"], + deps = [ + "//xla:literal", + "//xla:shape_util", + "//xla/hlo/evaluator:hlo_evaluator", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:slow_operation_alarm", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "hlo_constant_folding_test", + srcs = ["simplifiers/hlo_constant_folding_test.cc"], + deps = [ + ":hlo_constant_folding", + "//xla:literal", + "//xla:literal_util", + "//xla:permutation_util", + "//xla:shape_util", + "//xla:test", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "hlo_element_type_converter", + srcs = ["simplifiers/hlo_element_type_converter.cc"], + hdrs = ["simplifiers/hlo_element_type_converter.h"], + deps = [ + "//xla:literal", + "//xla:shape_util", + "//xla:types", + "//xla/hlo/evaluator:hlo_evaluator", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "hlo_element_type_converter_test", + srcs = ["simplifiers/hlo_element_type_converter_test.cc"], + deps = [ + ":hlo_element_type_converter", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "conditional_canonicalizer", + srcs = ["simplifiers/conditional_canonicalizer.cc"], + hdrs = ["simplifiers/conditional_canonicalizer.h"], + deps = [ + "//xla:status_macros", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + ], +) + +xla_cc_test( + name = "conditional_canonicalizer_test", + srcs = ["simplifiers/conditional_canonicalizer_test.cc"], + deps = [ + ":conditional_canonicalizer", + "//xla:shape_util", + "//xla:types", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "//xla/tests:literal_test_util", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "zero_sized_hlo_elimination", + srcs = ["simplifiers/zero_sized_hlo_elimination.cc"], + hdrs = ["simplifiers/zero_sized_hlo_elimination.h"], + deps = [ + "//xla:literal", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "zero_sized_hlo_elimination_test", + srcs = ["simplifiers/zero_sized_hlo_elimination_test.cc"], + deps = [ + ":zero_sized_hlo_elimination", + "//xla:literal_util", + "//xla:shape_util", + "//xla:test", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "sort_simplifier", + srcs = ["simplifiers/sort_simplifier.cc"], + hdrs = ["simplifiers/sort_simplifier.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "sort_simplifier_test", + srcs = ["simplifiers/sort_simplifier_test.cc"], + deps = [ + ":sort_simplifier", + "//xla:test", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "stable_sort_expander", + srcs = ["expanders/stable_sort_expander.cc"], + hdrs = ["expanders/stable_sort_expander.h"], + deps = [ + ":op_expander_pass", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + ], +) + +xla_cc_test( + name = "stable_sort_expander_test", + srcs = ["expanders/stable_sort_expander_test.cc"], + deps = [ + ":algebraic_simplifier", + ":stable_sort_expander", + "//xla:test", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "root_instruction_sinker", + srcs = ["simplifiers/root_instruction_sinker.cc"], + hdrs = ["simplifiers/root_instruction_sinker.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:tuple_util", + ], +) + +xla_cc_test( + name = "root_instruction_sinker_test", + srcs = ["simplifiers/root_instruction_sinker_test.cc"], + deps = [ + ":root_instruction_sinker", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "convert_memory_placement_to_internal_annotations", + srcs = ["convert_memory_placement_to_internal_annotations.cc"], + hdrs = ["convert_memory_placement_to_internal_annotations.h"], + deps = [ + "//xla:side_effect_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:host_memory_offload_annotations_hdr", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "convert_memory_placement_to_internal_annotations_test", + srcs = ["convert_memory_placement_to_internal_annotations_test.cc"], + deps = [ + ":convert_memory_placement_to_internal_annotations", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/service:host_memory_offload_annotations_hdr", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "host_memory_transfer_asyncifier", + srcs = ["simplifiers/host_memory_transfer_asyncifier.cc"], + hdrs = ["simplifiers/host_memory_transfer_asyncifier.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "host_memory_transfer_asyncifier_test", + srcs = ["simplifiers/host_memory_transfer_asyncifier_test.cc"], + deps = [ + ":host_memory_transfer_asyncifier", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "host_offload_legalize", + srcs = ["host_offload_legalize.cc"], + hdrs = ["host_offload_legalize.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/hlo/analysis:hlo_alias_analysis", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:call_graph", + "//xla/service:hlo_value", + "//xla/service:host_memory_offload_annotations_hdr", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "host_offload_legalize_test", + srcs = ["host_offload_legalize_test.cc"], + shard_count = 12, + deps = [ + ":host_offload_legalize", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:host_memory_offload_annotations_hdr", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "host_offloader", + srcs = ["host_offloader.cc"], + hdrs = ["host_offloader.h"], + deps = [ + "//xla:literal_util", + "//xla:shape_tree", + "//xla:shape_util", + "//xla:side_effect_util", + "//xla:status_macros", + "//xla:util", + "//xla/hlo/analysis:hlo_alias_analysis", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:call_graph", + "//xla/service:hlo_buffer", + "//xla/service:hlo_cse", + "//xla/service:hlo_value", + "//xla/service:host_memory_offload_annotations_hdr", + "//xla/service:host_offload_utils", + "//xla/service:pattern_matcher", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "host_offloader_test", + srcs = ["host_offloader_test.cc"], + shard_count = 12, + deps = [ + ":host_offload_legalize", + ":host_offloader", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/service:hlo_verifier", + "//xla/service:host_memory_offload_annotations_hdr", + "//xla/service:host_offload_utils", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "host_offloading_prepare", + srcs = ["host_offloading_prepare.cc"], + hdrs = ["host_offloading_prepare.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:call_graph", + "//xla/service:host_memory_offload_annotations_hdr", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "host_offloading_prepare_test", + srcs = ["host_offloading_prepare_test.cc"], + deps = [ + ":host_offloading_prepare", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:host_memory_offload_annotations_hdr", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "fusion_constant_sinking", + srcs = ["simplifiers/fusion_constant_sinking.cc"], + hdrs = ["simplifiers/fusion_constant_sinking.h"], + deps = [ + ":hlo_dce", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "fusion_constant_sinking_test", + srcs = ["simplifiers/fusion_constant_sinking_test.cc"], + deps = [ + ":fusion_constant_sinking", + "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "despecializer", + srcs = ["despecializer.cc"], + hdrs = ["despecializer.h"], + deps = [ + ":defuser", + ":float_normalization", + ":hlo_memory_scheduler", + ":sub_byte_normalization", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/pass:hlo_pass_pipeline", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "optimize_input_output_buffer_alias", + srcs = ["simplifiers/optimize_input_output_buffer_alias.cc"], + hdrs = ["simplifiers/optimize_input_output_buffer_alias.h"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "optimize_input_output_buffer_alias_test", + srcs = ["simplifiers/optimize_input_output_buffer_alias_test.cc"], + deps = [ + ":optimize_input_output_buffer_alias", + "//xla:shape_util", + "//xla:test", + "//xla:test_helpers", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "ar_crs_combiner", + srcs = ["simplifiers/ar_crs_combiner.cc"], + hdrs = ["simplifiers/ar_crs_combiner.h"], + deps = [ + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_replication_analysis", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", + "//xla/service:call_graph", + "//xla/service:pattern_matcher", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "dynamic_index_splitter", + srcs = ["expanders/dynamic_index_splitter.cc"], + hdrs = ["expanders/dynamic_index_splitter.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +xla_cc_test( + name = "dynamic_index_splitter_test", + srcs = ["expanders/dynamic_index_splitter_test.cc"], + deps = [ + ":dynamic_index_splitter", + "//xla:test", + "//xla:test_helpers", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@local_tsl//tsl/platform:test_main", + ], +) + +xla_cc_test( + name = "ar_crs_combiner_test", + srcs = ["simplifiers/ar_crs_combiner_test.cc"], + deps = [ + ":ar_crs_combiner", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "slice_sinker", + srcs = ["simplifiers/slice_sinker.cc"], + hdrs = ["simplifiers/slice_sinker.h"], + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "slice_sinker_test", + srcs = ["simplifiers/slice_sinker_test.cc"], + deps = [ + ":hlo_dce", + ":slice_sinker", + "//xla:literal_util", + "//xla:shape_util", + "//xla:types", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tsl/lib/core:status_test_util", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "rng_expander", + srcs = ["expanders/rng_expander.cc"], + hdrs = ["expanders/rng_expander.h"], + deps = [ + ":op_expander_pass", + "//xla:literal_util", + "//xla:shape_util", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder/lib:prng", + "//xla/service:hlo_creation_utils", + ], +) + +cc_library( + name = "rng_bit_generator_expander", + srcs = ["expanders/rng_bit_generator_expander.cc"], + hdrs = ["expanders/rng_bit_generator_expander.h"], + deps = [ + ":op_expander_pass", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder/lib:prng", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "collective_transformation_reorderer", + srcs = ["collectives/collective_transformation_reorderer.cc"], + hdrs = ["collectives/collective_transformation_reorderer.h"], + deps = [ + ":hlo_dce", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "collective_transformation_reorderer_test", + srcs = ["collectives/collective_transformation_reorderer_test.cc"], + deps = [ + ":collective_transformation_reorderer", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "//xla/service:hlo_verifier", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "operand_upcaster", + srcs = ["operand_upcaster.cc"], + hdrs = ["operand_upcaster.h"], + deps = [ + ":op_expander_pass", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "//xla/service:shape_inference", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "operand_upcaster_test", + srcs = ["operand_upcaster_test.cc"], + deps = [ + ":operand_upcaster", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "result_caster", + srcs = ["simplifiers/result_caster.cc"], + hdrs = ["simplifiers/result_caster.h"], + deps = [ + ":op_expander_pass", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:shape_inference", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + +xla_cc_test( + name = "result_caster_test", + srcs = ["simplifiers/result_caster_test.cc"], + deps = [ + ":result_caster", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "convert_operand_folding", + srcs = ["simplifiers/convert_operand_folder.cc"], + hdrs = ["simplifiers/convert_operand_folder.h"], + deps = [ + ":op_expander_pass", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "convert_operand_folding_test", + srcs = ["simplifiers/convert_operand_folder_test.cc"], + deps = [ + ":convert_operand_folding", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "instruction_hoister", + srcs = ["simplifiers/instruction_hoister.cc"], + hdrs = ["simplifiers/instruction_hoister.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:status", + ], +) + +cc_library( + name = "gather_simplifier", + srcs = ["simplifiers/gather_simplifier.cc"], + hdrs = ["simplifiers/gather_simplifier.h"], + deps = [ + ":op_expander_pass", + "//xla:literal_util", + "//xla:permutation_util", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:gather_scatter_utils", + "//xla/service:hlo_creation_utils", + "@com_google_absl//absl/algorithm:container", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "reduce_window_rewriter", + srcs = ["simplifiers/reduce_window_rewriter.cc"], + hdrs = ["simplifiers/reduce_window_rewriter.h"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:window_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "reduce_window_rewriter_test", + srcs = ["simplifiers/reduce_window_rewriter_test.cc"], + deps = [ + ":reduce_window_rewriter", + "//xla:test", + "//xla:xla_data_proto_cc", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "stochastic_convert_decomposer", + srcs = ["expanders/stochastic_convert_decomposer.cc"], + hdrs = ["expanders/stochastic_convert_decomposer.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:hlo_creation_utils", + "//xla/service:shape_inference", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "stochastic_convert_decomposer_test", + srcs = ["expanders/stochastic_convert_decomposer_test.cc"], + deps = [ + ":stochastic_convert_decomposer", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "sub_byte_normalization", + srcs = ["simplifiers/sub_byte_normalization.cc"], + hdrs = ["simplifiers/sub_byte_normalization.h"], + deps = [ + "//xla:shape_layout", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + ], +) + +cc_library( + name = "sharding_format_picker", + testonly = True, + srcs = ["sharding_format_picker.cc"], + hdrs = ["sharding_format_picker.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/hlo/ir:tile_assignment", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +xla_cc_test( + name = "gather_simplifier_test", + srcs = ["simplifiers/gather_simplifier_test.cc"], + deps = [ + ":gather_simplifier", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "add_original_value", + srcs = ["add_original_value.cc"], + hdrs = ["add_original_value.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + +xla_cc_test( + name = "add_original_value_test", + srcs = ["add_original_value_test.cc"], + deps = [ + ":add_original_value", + "//xla:shape_util", + "//xla:window_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "infeed_token_propagation", + srcs = ["collectives/infeed_token_propagation.cc"], + hdrs = ["collectives/infeed_token_propagation.h"], + deps = [ + ":hlo_dce", + ":tuple_simplifier", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:call_graph", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "infeed_token_propagation_test", + srcs = ["collectives/infeed_token_propagation_test.cc"], + deps = [ + ":infeed_token_propagation", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/utils:hlo_matchers", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/hlo/transforms/README.md b/third_party/xla/xla/hlo/transforms/README.md new file mode 100644 index 00000000000000..32590587e00ab6 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/README.md @@ -0,0 +1 @@ +This folder consolidates hardware independent HLO transformation passes. \ No newline at end of file diff --git a/third_party/xla/xla/service/add_original_value.cc b/third_party/xla/xla/hlo/transforms/add_original_value.cc similarity index 97% rename from third_party/xla/xla/service/add_original_value.cc rename to third_party/xla/xla/hlo/transforms/add_original_value.cc index 37cab3c7cad81a..e26b16689808e0 100644 --- a/third_party/xla/xla/service/add_original_value.cc +++ b/third_party/xla/xla/hlo/transforms/add_original_value.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/add_original_value.h" +#include "xla/hlo/transforms/add_original_value.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/add_original_value.h b/third_party/xla/xla/hlo/transforms/add_original_value.h new file mode 100644 index 00000000000000..f253b8ad42ec09 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/add_original_value.h @@ -0,0 +1,41 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_ADD_ORIGINAL_VALUE_H_ +#define XLA_HLO_TRANSFORMS_ADD_ORIGINAL_VALUE_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// This pass adds to each op in the HLO graph the original_value attribute, +// which is used for HLO value tracking. See go/hlo-value-tracking for more +// details. +class AddOriginalValue : public HloModulePass { + public: + absl::string_view name() const override { return "add-original-value"; } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_ADD_ORIGINAL_VALUE_H_ diff --git a/third_party/xla/xla/service/add_original_value_test.cc b/third_party/xla/xla/hlo/transforms/add_original_value_test.cc similarity index 80% rename from third_party/xla/xla/service/add_original_value_test.cc rename to third_party/xla/xla/hlo/transforms/add_original_value_test.cc index f69ba94cba440e..5063e21103f5ec 100644 --- a/third_party/xla/xla/service/add_original_value_test.cc +++ b/third_party/xla/xla/hlo/transforms/add_original_value_test.cc @@ -13,20 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/add_original_value.h" +#include "xla/hlo/transforms/add_original_value.h" #include #include #include "absl/strings/string_view.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { namespace { -using AddOriginalValueTest = HloTestBase; +using AddOriginalValueTest = HloHardwareIndependentTestBase; using ::absl::string_view; @@ -68,11 +68,11 @@ ENTRY test (v1: f32[], v2: f32[3], v3: f32[2,3]) -> ((f32[], f32[3]{0}), f32[2,3 )"; RunAndFilecheckHloRewrite(hlo_string, AddOriginalValue(), R"( -CHECK: %[[V1:.*]] = f32[] parameter(0), original_value={{[{]}}{"[[V1]]"} -CHECK: %[[V2:.*]] = f32[3]{0} parameter(1), original_value={{[{]}}{"[[V2]]"} -CHECK: %[[TUPLE:.*]] = (f32[], f32[3]{0}) tuple(%[[V1]], %[[V2]]), original_value={({"[[V1]]"}, {"[[V2]]"})} -CHECK: %[[V3:.*]] = f32[2,3]{1,0} parameter(2), original_value={{[{]}}{"[[V3]]"} -CHECK: ((f32[], f32[3]{0}), f32[2,3]{1,0}) tuple(%[[TUPLE]], %[[V3]]), original_value={(({"v1"}, {"v2"}), {"v3"})} +CHECK: %[[V1:.*]] = f32[] parameter(0), origin={{[{]}}{"[[V1]]"} +CHECK: %[[V2:.*]] = f32[3]{0} parameter(1), origin={{[{]}}{"[[V2]]"} +CHECK: %[[TUPLE:.*]] = (f32[], f32[3]{0}) tuple(%[[V1]], %[[V2]]), origin={({"[[V1]]"}, {"[[V2]]"})} +CHECK: %[[V3:.*]] = f32[2,3]{1,0} parameter(2), origin={{[{]}}{"[[V3]]"} +CHECK: ((f32[], f32[3]{0}), f32[2,3]{1,0}) tuple(%[[TUPLE]], %[[V3]]), origin={(({"v1"}, {"v2"}), {"v3"})} )"); } @@ -90,10 +90,10 @@ ENTRY test { )"; RunAndFilecheckHloRewrite(hlo_string, AddOriginalValue(), R"( -CHECK: %[[CONSTANT1:.*]] = f32[3]{0} constant({1, 2, 3}), original_value={{[{]}}{"[[CONSTANT1]]"} -CHECK: %[[CONSTANT2:.*]] = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } }), original_value={{[{]}}{"[[CONSTANT2]]"} -CHECK: %[[TUPLE:.*]] = (f32[3]{0}, s32[2,3]{1,0}) tuple(%[[CONSTANT1]], %[[CONSTANT2]]), original_value={({"[[CONSTANT1]]"}, {"[[CONSTANT2]]"})} -CHECK: s32[2,3]{1,0} get-tuple-element(%[[TUPLE]]), index=1, original_value={{[{]}}{"[[CONSTANT2]]"} +CHECK: %[[CONSTANT1:.*]] = f32[3]{0} constant({1, 2, 3}), origin={{[{]}}{"[[CONSTANT1]]"} +CHECK: %[[CONSTANT2:.*]] = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } }), origin={{[{]}}{"[[CONSTANT2]]"} +CHECK: %[[TUPLE:.*]] = (f32[3]{0}, s32[2,3]{1,0}) tuple(%[[CONSTANT1]], %[[CONSTANT2]]), origin={({"[[CONSTANT1]]"}, {"[[CONSTANT2]]"})} +CHECK: s32[2,3]{1,0} get-tuple-element(%[[TUPLE]]), index=1, origin={{[{]}}{"[[CONSTANT2]]"} )"); } @@ -109,8 +109,8 @@ ENTRY test { )"; RunAndFilecheckHloRewrite(hlo_string, AddOriginalValue(), R"( -CHECK: %[[PARAM:.*]] = (f32[], s32[]) parameter(0), original_value={({"p" {0}{{[}]}}, {"p" {1}})} -CHECK: s32[] get-tuple-element(%[[PARAM]]), index=1, original_value={{[{]}}{"[[PARAM]]" {1} +CHECK: %[[PARAM:.*]] = (f32[], s32[]) parameter(0), origin={({"p" {0}{{[}]}}, {"p" {1}})} +CHECK: s32[] get-tuple-element(%[[PARAM]]), index=1, origin={{[{]}}{"[[PARAM]]" {1} )"); } diff --git a/third_party/xla/xla/service/bfloat16_propagation.cc b/third_party/xla/xla/hlo/transforms/bfloat16_propagation.cc similarity index 98% rename from third_party/xla/xla/service/bfloat16_propagation.cc rename to third_party/xla/xla/hlo/transforms/bfloat16_propagation.cc index bf3dfedf4a0cad..27607d16ada289 100644 --- a/third_party/xla/xla/service/bfloat16_propagation.cc +++ b/third_party/xla/xla/hlo/transforms/bfloat16_propagation.cc @@ -13,23 +13,32 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/bfloat16_propagation.h" +#include "xla/hlo/transforms/bfloat16_propagation.h" #include "absl/algorithm/container.h" #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/literal.h" #include "xla/map_util.h" -#include "xla/service/hlo_dce.h" -#include "xla/service/tuple_simplifier.h" +#include "xla/service/float_support.h" +#include "xla/service/hlo_value.h" +#include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/bfloat16_propagation.h b/third_party/xla/xla/hlo/transforms/bfloat16_propagation.h new file mode 100644 index 00000000000000..005c68ada53037 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/bfloat16_propagation.h @@ -0,0 +1,238 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_BFLOAT16_PROPAGATION_H_ +#define XLA_HLO_TRANSFORMS_BFLOAT16_PROPAGATION_H_ + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/service/float_support.h" +#include "xla/service/hlo_value.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// HLO pass which reduces the precision of some HLO instructions to BF16 +// according to the backend-specific FloatSupport rule provided by the +// caller. +// +// This pass can be used to reduce instruction precision without affecting the +// numerical accuracy of the module, i.e., the final output of the module would +// be bitwise identical to that without this pass; this is possible if the +// backend already reduces precision to BF16 on some HLO instructions. +// +// This pass will not modify the signature of a computation, unless it is a +// fusion computation or its only caller is a while. +// +// !!! WARNING !!! This pass can introduce mixed precision in individual HLOs, +// which has two issues: +// +// 1) It does not guarantee to respect the passed-in FloatSupport +// specification in terms of mixed precision, so the backend may not support an +// HLO that has mixed precision produced by this pass. To address this issue, +// run FloatNormalization with the same FloatSupport after this pass. +// +// 2) In general, mixed precision may break the assumptions of some other HLO +// passes even if the specific backend supports the individual HLOs. Such +// assumptions include that there are no HLOs using mixed precision, or that the +// precision of an HLO's output is determined by its inputs. It should be used +// at the end of the HLO optimization pipeline but before +// BFloat16ConversionFolding. If other passes are needed after this pass, run +// BFloat16MixedPrecisionRemoval first to undo some of the changes made by this +// pass. +class BFloat16Propagation : public HloModulePass { + public: + explicit BFloat16Propagation(const FloatSupport* bfloat16_support); + + ~BFloat16Propagation() override = default; + + absl::string_view name() const override { return "bfloat16-propagation"; } + + // Runs the pass on the given module. Returns whether the module was changed + // (precision reductions were added). + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + // Returns whether we should avoid changing the precision of inst regardless + // of the producers and users. + virtual bool ShouldKeepPrecisionUnchanged(const HloInstruction* inst); + + // Determines whether we should consider changing the precision of the given + // instruction in the forward pass. + virtual bool InstructionIsCandidateForBF16Output(HloInstruction* hlo); + + protected: + const FloatSupport* bfloat16_support_; + + private: + // *************************** + // Function called and state produced by the forward analysis pass (from + // parameters to root) that determines the candidate HLOs to use BF16 outputs. + + // The set of instructions to consider using bfloat16, computed in the forward + // pass. + absl::flat_hash_set consider_using_bfloat16_; + + // *************************** + // Functions called and state produced by the backward pass (from root to + // parameters) that finds opportunities to use BF16. + + // Determines the precision for the given instruction in the + // opportunity-finding pass. + void DetermineInstructionPrecision(HloInstruction* hlo, bool skip_parameters); + + // Special handling in the opportunity-finding pass for fusion computations. + // + // Precondition: hlo->opcode() == kFusion + void DetermineFusionComputationPrecision(HloInstruction* fusion); + + // Reverts changes to BF16 that will not propagate outside a fusion + // computation. This avoids BF16 casts overhead inside a fusion which won't + // save memory bandwidth. + // + // Precondition: hlo->opcode() == kFusion + void RevertIfFusionInternalBF16Changes(HloInstruction* fusion); + + // Special handling in the opportunity-finding pass for while computations. + // + // Precondition: hlo->opcode() == kWhile + void DetermineWhileComputationsPrecision(HloInstruction* while_hlo); + + // Special handling in the opportunity-finding pass for conditional branches. + // + // Precondition: hlo->opcode() == kConditional + void DetermineConditionalComputationsPrecision(HloInstruction* cond); + + // The set of HloInstructions that have been visited in the + // opportunity-finding pass. + absl::flat_hash_set + instructions_visited_in_backward_pass_; + + // The set of HloComputations that have been visited in the + // opportunity-finding pass. + absl::flat_hash_set + computations_visited_in_backward_pass_; + + // *************************** + // Functions called by the final inconsistency resolving pass. + + // Adjusts the output shapes of HloInstructions such that if two + // HloInstructions have aliasing buffers in their outputs, they must have the + // same precision. + void ResolveInconsistencyOfAliasingBuffers( + HloModule* module, + const absl::flat_hash_set& execution_threads); + + // Resolves inconsistency of aliasing buffers for the given computation, and + // recursively runs on a while instruction's condition and body until a fixed + // point is reached. + bool ResolveInconsistencyOfAliasingBuffersHelper( + HloComputation* computation, + absl::flat_hash_set* visited_computations); + + // Makes the parameters of called computations match how they are called by + // the given HLO. + void AdjustCalledComputationParameters(HloInstruction* hlo); + + // Makes the root instructions of called computations match how they are used + // by the given HLO. + void AdjustCalledComputationRoot(HloInstruction* hlo); + + // *************************** + // Functions called after changes in changes_to_bf16_ are applied. + + // Resolves inconsistencies introduced by this pass for fusions with + // tuple-type output. + absl::Status ResolveInconsistentFusions( + HloModule* module, + const absl::flat_hash_set& execution_threads); + + // Converts the literals in kConstant HLOs which have their types changed to + // BF16 by this pass. + absl::Status ResolveConvertedConstants( + HloModule* module, + const absl::flat_hash_set& execution_threads); + + // Skips no-op conversions (same source and target shapes) that can be + // produced this pass, i.e., replaces them in their uses with their operands. + absl::Status SkipNoopConversions( + HloModule* module, + const absl::flat_hash_set& execution_threads); + + // *************************** + // Functions called and state used by two or more passes. + + // Returns whether all uses of the given HloInstruction can consume BF16 + // input. + bool AllUsersConsumeBF16(const HloInstruction& hlo, + const ShapeIndex& index) const; + + // The output element type of the HLO at the given shape index after changes + // in changes_to_bf16_ are applied. + PrimitiveType OutputTypeAfterChange(HloInstruction* hlo, + const ShapeIndex& index) const; + + // The element type of the HLO value after changes in changes_to_bf16_ are + // applied. + PrimitiveType ValueTypeAfterChange(const HloValue* value) const; + + // If target_type == BF16, adds the HLO at the given index to + // changes_to_bf16_; otherwise, target_type must be F32 and this function + // removes the HLO at the given index from changes_to_bf16_ if it was earlier + // added. + void AddToOrRemoveFromBF16ChangeSet(HloInstruction* hlo, + const ShapeIndex& index, + PrimitiveType target_type); + + // The set of F32 HLO values that must be kept in F32. + absl::flat_hash_set values_that_must_be_kept_as_f32_; + + // Mapping from each HloComputation to the number of callers to it in the + // module. Populated at the beginning of this pass. + absl::flat_hash_map caller_counts_; + + // We first store the potential F32-to-BF16 changes to changes_to_bf16_, which + // are subject to further adjustment, then finally applied to the HLOs. This + // avoids setting changed_ to true but all changes are reverted during + // adjustment. + // + // For each HloInstruction, changes_to_bf16_ stores the affected buffers in + // the output as a map from in-place pointers to subshapes to shape indices. + absl::flat_hash_map> + changes_to_bf16_; + + // Whether the last processed HLO module has been changed by this pass. + bool changed_ = false; + + std::unique_ptr dataflow_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_BFLOAT16_PROPAGATION_H_ diff --git a/third_party/xla/xla/service/bfloat16_propagation_test.cc b/third_party/xla/xla/hlo/transforms/bfloat16_propagation_test.cc similarity index 99% rename from third_party/xla/xla/service/bfloat16_propagation_test.cc rename to third_party/xla/xla/hlo/transforms/bfloat16_propagation_test.cc index 53e76916ba98c9..ff99c9215cbd1f 100644 --- a/third_party/xla/xla/service/bfloat16_propagation_test.cc +++ b/third_party/xla/xla/hlo/transforms/bfloat16_propagation_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/bfloat16_propagation.h" +#include "xla/hlo/transforms/bfloat16_propagation.h" #include #include @@ -27,6 +27,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/literal_util.h" #include "xla/service/float_support.h" #include "xla/service/hlo_verifier.h" @@ -34,7 +35,6 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -68,11 +68,12 @@ class TestBFloat16Support : public FloatSupport { } }; -class BFloat16PropagationTest : public HloTestBase { +class BFloat16PropagationTest : public HloHardwareIndependentTestBase { protected: BFloat16PropagationTest() - : HloTestBase(/*verifier_layout_sensitive=*/false, - /*allow_mixed_precision_in_hlo_verifier=*/true) {} + : HloHardwareIndependentTestBase( + /*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} // Runs the propagation pass on the given module, and returns whether the // module is changed after this pass. diff --git a/third_party/xla/xla/service/all_gather_broadcast_reorder.cc b/third_party/xla/xla/hlo/transforms/collectives/all_gather_broadcast_reorder.cc similarity index 99% rename from third_party/xla/xla/service/all_gather_broadcast_reorder.cc rename to third_party/xla/xla/hlo/transforms/collectives/all_gather_broadcast_reorder.cc index 287a9e8c7df680..51b18f4f11c8ca 100644 --- a/third_party/xla/xla/service/all_gather_broadcast_reorder.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/all_gather_broadcast_reorder.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/all_gather_broadcast_reorder.h" +#include "xla/hlo/transforms/collectives/all_gather_broadcast_reorder.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/collectives/all_gather_broadcast_reorder.h b/third_party/xla/xla/hlo/transforms/collectives/all_gather_broadcast_reorder.h new file mode 100644 index 00000000000000..0a1a8d5e27c2d7 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/collectives/all_gather_broadcast_reorder.h @@ -0,0 +1,43 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_COLLECTIVES_ALL_GATHER_BROADCAST_REORDER_H_ +#define XLA_HLO_TRANSFORMS_COLLECTIVES_ALL_GATHER_BROADCAST_REORDER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// A pass that reorders all-gather(broadcast(x)) -> broadcast(all-gather(x)). +// The intent is to reduce the size of all-gather when possible by doing an +// all-gather on the (smaller) pre-broadcasted data and then applying the +// broadcast. +class AllGatherBroadcastReorder : public HloModulePass { + public: + absl::string_view name() const override { return "all-gather-bcast-reorder"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_COLLECTIVES_ALL_GATHER_BROADCAST_REORDER_H_ diff --git a/third_party/xla/xla/service/all_gather_broadcast_reorder_test.cc b/third_party/xla/xla/hlo/transforms/collectives/all_gather_broadcast_reorder_test.cc similarity index 92% rename from third_party/xla/xla/service/all_gather_broadcast_reorder_test.cc rename to third_party/xla/xla/hlo/transforms/collectives/all_gather_broadcast_reorder_test.cc index 0c7eb62232d13a..6e4a14b3a1c9d0 100644 --- a/third_party/xla/xla/service/all_gather_broadcast_reorder_test.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/all_gather_broadcast_reorder_test.cc @@ -13,10 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/all_gather_broadcast_reorder.h" +#include "xla/hlo/transforms/collectives/all_gather_broadcast_reorder.h" +#include +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" namespace xla { @@ -24,7 +27,7 @@ namespace { namespace m = xla::testing::opcode_matchers; -class AllGatherBroadcastReorderTest : public HloTestBase { +class AllGatherBroadcastReorderTest : public HloHardwareIndependentTestBase { public: enum class PassOutput { NoChange, NonUniformAGPattern, UniformAGPattern }; void RunPass(absl::string_view hlo_module, PassOutput expected_output) { diff --git a/third_party/xla/xla/service/all_gather_combiner.cc b/third_party/xla/xla/hlo/transforms/collectives/all_gather_combiner.cc similarity index 98% rename from third_party/xla/xla/service/all_gather_combiner.cc rename to third_party/xla/xla/hlo/transforms/collectives/all_gather_combiner.cc index efd7a803f04a0c..92469d13d3f251 100644 --- a/third_party/xla/xla/service/all_gather_combiner.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/all_gather_combiner.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/all_gather_combiner.h" +#include "xla/hlo/transforms/collectives/all_gather_combiner.h" #include #include @@ -30,6 +30,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/functional/function_ref.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -47,6 +48,7 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/third_party/xla/xla/hlo/transforms/collectives/all_gather_combiner.h b/third_party/xla/xla/hlo/transforms/collectives/all_gather_combiner.h new file mode 100644 index 00000000000000..d4e9bc25321939 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/collectives/all_gather_combiner.h @@ -0,0 +1,94 @@ +/* Copyright 2020 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_COLLECTIVES_ALL_GATHER_COMBINER_H_ +#define XLA_HLO_TRANSFORMS_COLLECTIVES_ALL_GATHER_COMBINER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/functional/function_ref.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/service/hlo_domain_map.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Combines small non-dependent AllGather ops into larger combined +// AllGather ops. A typical AllGather implementation has a minimum +// latency-induced time for a AllGather op so a single combined op can be +// more efficient than many small ones. +class AllGatherCombiner : public HloModulePass { + public: + AllGatherCombiner(int64_t combine_threshold_in_bytes, + int64_t combine_threshold_count, bool combine_by_dim, + bool combine_different_dtypes = true); + + absl::string_view name() const override { return "all-gather-combiner"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + // The group key encapsulates all of the properties which must match for it to + // be possible to combine the instructions. + // The field of the key corresponds to the following: + // 1. all_gather_dimension + // 2. domain_metadata_id + // 3. channel_id + // 4. use_global_device_ids + // 5. data_type + // 6. replica_groups + // 7. extra arguments in string format. + using GroupKey = + std::tuple, int64_t, bool, bool, PrimitiveType, + std::vector>, std::string>; + + static std::string& GetGroupKeyExtraArgs(GroupKey& key); + + // Returns a key that will be equal for instructions that might be combined, + // or different if not. + static std::optional CombineKey( + const HloInstruction* instruction, const HloDomainMap& domain_map, + bool combine_by_dim, bool combine_different_dtypes = true); + + protected: + absl::StatusOr RunWithKeyCombiner( + HloModule* module, + const absl::flat_hash_set& execution_threads, + absl::FunctionRef( + const HloInstruction*, const HloDomainMap&, bool, bool)> + combine_key); + + protected: + // Combine all gather ops up to this threshold. + int64_t combine_threshold_in_bytes_; + + // Combine all gather ops up to this threshold (number of operands). + int64_t combine_threshold_count_; + + // Combine only all-gather ops with the same gather dimension. + bool combine_by_dim_; + + // Combine all-gather ops with different dtypes. + bool combine_different_dtypes_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_COLLECTIVES_ALL_GATHER_COMBINER_H_ diff --git a/third_party/xla/xla/service/all_gather_combiner_test.cc b/third_party/xla/xla/hlo/transforms/collectives/all_gather_combiner_test.cc similarity index 98% rename from third_party/xla/xla/service/all_gather_combiner_test.cc rename to third_party/xla/xla/hlo/transforms/collectives/all_gather_combiner_test.cc index 97a966b8815036..49a1c9ae613c02 100644 --- a/third_party/xla/xla/service/all_gather_combiner_test.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/all_gather_combiner_test.cc @@ -13,21 +13,24 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/all_gather_combiner.h" +#include "xla/hlo/transforms/collectives/all_gather_combiner.h" #include #include #include +#include +#include #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -55,7 +58,7 @@ int64_t AllGatherCount(const HloModule& module) { return FindAllGathers(module).size(); } -using AllGatherCombinerTest = HloTestBase; +using AllGatherCombinerTest = HloHardwareIndependentTestBase; // Tests combination of several AllGather instructions. TEST_F(AllGatherCombinerTest, CombineAllGathers) { diff --git a/third_party/xla/xla/service/all_reduce_combiner.cc b/third_party/xla/xla/hlo/transforms/collectives/all_reduce_combiner.cc similarity index 94% rename from third_party/xla/xla/service/all_reduce_combiner.cc rename to third_party/xla/xla/hlo/transforms/collectives/all_reduce_combiner.cc index a581b15d420dca..70ec486622c480 100644 --- a/third_party/xla/xla/service/all_reduce_combiner.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/all_reduce_combiner.cc @@ -13,18 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/all_reduce_combiner.h" +#include "xla/hlo/transforms/collectives/all_reduce_combiner.h" -#include -#include +#include #include #include #include #include #include -#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -34,10 +37,12 @@ limitations under the License. #include "xla/service/all_reduce_key.h" #include "xla/service/collective_combiner_utils.h" #include "xla/service/hlo_domain_map.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/third_party/xla/xla/hlo/transforms/collectives/all_reduce_combiner.h b/third_party/xla/xla/hlo/transforms/collectives/all_reduce_combiner.h new file mode 100644 index 00000000000000..5562624debea5b --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/collectives/all_reduce_combiner.h @@ -0,0 +1,56 @@ +/* Copyright 2020 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_COLLECTIVES_ALL_REDUCE_COMBINER_H_ +#define XLA_HLO_TRANSFORMS_COLLECTIVES_ALL_REDUCE_COMBINER_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Combines small non-dependent AllReduce ops into larger combined +// AllReduce ops. A typical AllReduce implementation has a minimum +// latency-induced time for a AllReduce op so a single combined op can be +// more efficient than many small ones. +class AllReduceCombiner : public HloModulePass { + public: + AllReduceCombiner(int64_t combine_threshold_in_bytes, + int64_t combine_threshold_count); + + absl::string_view name() const override { return "all-reduce-combiner"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + // Combine all reduce ops up to this threshold. + int64_t combine_threshold_in_bytes_; + + // Combine all reduce ops up to this threshold (number of operands). + int64_t combine_threshold_count_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_COLLECTIVES_ALL_REDUCE_COMBINER_H_ diff --git a/third_party/xla/xla/service/all_reduce_combiner_test.cc b/third_party/xla/xla/hlo/transforms/collectives/all_reduce_combiner_test.cc similarity index 97% rename from third_party/xla/xla/service/all_reduce_combiner_test.cc rename to third_party/xla/xla/hlo/transforms/collectives/all_reduce_combiner_test.cc index 188d7a99251bb0..6478459c60c6ce 100644 --- a/third_party/xla/xla/service/all_reduce_combiner_test.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/all_reduce_combiner_test.cc @@ -13,21 +13,27 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/all_reduce_combiner.h" +#include "xla/hlo/transforms/collectives/all_reduce_combiner.h" #include #include +#include +#include +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "xla/hlo/ir/collective_device_list.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/literal.h" #include "xla/literal_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -92,7 +98,7 @@ HloComputation* MakeReduction(const HloOpcode type, HloModule* module) { return reduction; } -using AllReduceCombinerTest = HloTestBase; +using AllReduceCombinerTest = HloHardwareIndependentTestBase; // Tests combination of several AllReduce instructions. TEST_F(AllReduceCombinerTest, CombineAllReduces) { diff --git a/third_party/xla/xla/service/all_reduce_contiguous.cc b/third_party/xla/xla/hlo/transforms/collectives/all_reduce_contiguous.cc similarity index 93% rename from third_party/xla/xla/service/all_reduce_contiguous.cc rename to third_party/xla/xla/hlo/transforms/collectives/all_reduce_contiguous.cc index fa76de45facd59..6106ade729451c 100644 --- a/third_party/xla/xla/service/all_reduce_contiguous.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/all_reduce_contiguous.cc @@ -13,18 +13,25 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/all_reduce_contiguous.h" +#include "xla/hlo/transforms/collectives/all_reduce_contiguous.h" #include +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" namespace xla { namespace { diff --git a/third_party/xla/xla/hlo/transforms/collectives/all_reduce_contiguous.h b/third_party/xla/xla/hlo/transforms/collectives/all_reduce_contiguous.h new file mode 100644 index 00000000000000..5262c94366cce0 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/collectives/all_reduce_contiguous.h @@ -0,0 +1,41 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_COLLECTIVES_ALL_REDUCE_CONTIGUOUS_H_ +#define XLA_HLO_TRANSFORMS_COLLECTIVES_ALL_REDUCE_CONTIGUOUS_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// Concatenates all-reduce operands together, so the all-reduce is performed +// over a single, contiguous buffer. +class AllReduceContiguous : public HloModulePass { + public: + absl::string_view name() const override { return "all-reduce-contiguous"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_COLLECTIVES_ALL_REDUCE_CONTIGUOUS_H_ diff --git a/third_party/xla/xla/service/all_reduce_contiguous_test.cc b/third_party/xla/xla/hlo/transforms/collectives/all_reduce_contiguous_test.cc similarity index 87% rename from third_party/xla/xla/service/all_reduce_contiguous_test.cc rename to third_party/xla/xla/hlo/transforms/collectives/all_reduce_contiguous_test.cc index ccd1effdbc6c30..2f114c578714c6 100644 --- a/third_party/xla/xla/service/all_reduce_contiguous_test.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/all_reduce_contiguous_test.cc @@ -13,16 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/all_reduce_contiguous.h" +#include "xla/hlo/transforms/collectives/all_reduce_contiguous.h" #include +#include +#include +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_utils.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -30,7 +33,7 @@ namespace { using ::testing::AllOf; namespace op = xla::testing::opcode_matchers; -using AllReduceContiguousTest = HloTestBase; +using AllReduceContiguousTest = HloHardwareIndependentTestBase; TEST_F(AllReduceContiguousTest, Simple) { const absl::string_view hlo_string = R"( diff --git a/third_party/xla/xla/service/async_collective_creator.cc b/third_party/xla/xla/hlo/transforms/collectives/async_collective_creator.cc similarity index 95% rename from third_party/xla/xla/service/async_collective_creator.cc rename to third_party/xla/xla/hlo/transforms/collectives/async_collective_creator.cc index 71c7eb820b61b1..16051a7e51a968 100644 --- a/third_party/xla/xla/service/async_collective_creator.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/async_collective_creator.cc @@ -13,14 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/async_collective_creator.h" +#include "xla/hlo/transforms/collectives/async_collective_creator.h" #include #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/frontend_attributes.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -33,6 +38,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/util.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -156,7 +162,9 @@ std::vector AsyncCollectiveCreator::MatchCollectives( GetShapeSize(instruction->shape()) >= config_.all_reduce_min_threshold_in_bytes) || (op == HloOpcode::kAllGather && - config_.convert_all_gather(instruction)) || + config_.convert_all_gather(instruction) && + GetShapeSize(instruction->shape()) >= + config_.all_gather_min_threshold_in_bytes) || (op == HloOpcode::kCollectiveBroadcast && config_.convert_collective_broadcast(instruction)) || (op == HloOpcode::kCollectivePermute && diff --git a/third_party/xla/xla/hlo/transforms/collectives/async_collective_creator.h b/third_party/xla/xla/hlo/transforms/collectives/async_collective_creator.h new file mode 100644 index 00000000000000..af76954c4379ef --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/collectives/async_collective_creator.h @@ -0,0 +1,78 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_COLLECTIVES_ASYNC_COLLECTIVE_CREATOR_H_ +#define XLA_HLO_TRANSFORMS_COLLECTIVES_ASYNC_COLLECTIVE_CREATOR_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/shape.h" +#include "xla/util.h" + +namespace xla { + +// Transforms each all-reduce instruction to a pair of all-reduce-start and +// all-reduce-done. +class AsyncCollectiveCreator : public HloModulePass { + public: + // Function to query the shape of the "context" for collectives that use + // HLO async-start/async-done. + using ContextShapeQuery = + std::function(const HloInstruction *)>; + struct CollectiveCreatorConfig { + HloPredicate convert_all_reduce = HloPredicateFalse; + HloPredicate convert_all_gather = HloPredicateFalse; + HloPredicate convert_collective_broadcast = HloPredicateFalse; + HloPredicate convert_collective_permute = HloPredicateFalse; + HloPredicate convert_all_to_all = HloPredicateFalse; + HloPredicate convert_reduce_scatter = HloPredicateFalse; + ContextShapeQuery get_context_shapes = [](const HloInstruction *) { + return std::vector{}; + }; + int64_t all_reduce_min_threshold_in_bytes = 0; + int64_t all_gather_min_threshold_in_bytes = 0; + }; + explicit AsyncCollectiveCreator(CollectiveCreatorConfig creator_config) + : config_(std::move(creator_config)) {} + absl::string_view name() const override { return "async-collective-creator"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule *module, + const absl::flat_hash_set &execution_threads) override; + + std::vector MatchCollectives(HloComputation *computation); + absl::StatusOr ReplaceCollectives( + HloComputation *computation, + std::vector &supported_collectives); + const CollectiveCreatorConfig *config() const { return &config_; } + + private: + CollectiveCreatorConfig config_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_COLLECTIVES_ASYNC_COLLECTIVE_CREATOR_H_ diff --git a/third_party/xla/xla/service/async_collective_creator_test.cc b/third_party/xla/xla/hlo/transforms/collectives/async_collective_creator_test.cc similarity index 95% rename from third_party/xla/xla/service/async_collective_creator_test.cc rename to third_party/xla/xla/hlo/transforms/collectives/async_collective_creator_test.cc index ad783ca23d0770..80538900c69ab1 100644 --- a/third_party/xla/xla/service/async_collective_creator_test.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/async_collective_creator_test.cc @@ -13,18 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/async_collective_creator.h" +#include "xla/hlo/transforms/collectives/async_collective_creator.h" +#include #include +#include +#include +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -34,7 +39,7 @@ namespace m = ::xla::match; using ::testing::NotNull; using ::testing::SizeIs; -using AsyncAllReduceCreatorTest = HloTestBase; +using AsyncAllReduceCreatorTest = HloHardwareIndependentTestBase; TEST_F(AsyncAllReduceCreatorTest, SplitsSingleAllReduce) { constexpr absl::string_view hlo_string = R"( @@ -270,9 +275,9 @@ TEST_F(AsyncAllReduceCreatorTest, ControlPredecessor) { constexpr absl::string_view hlo_string = R"( HloModule test ENTRY entry { - p0 = f32[1] parameter(0) - ag = f32[8] all-gather(p0), dimensions={0}, replica_groups={{0,1,2,3,4,5,6,7}}, control-predecessors={p0} - p1 = f32[1] parameter(1), control-predecessors={ag} + p0 = f32[128] parameter(0) + ag = f32[1024] all-gather(p0), dimensions={0}, replica_groups={{0,1,2,3,4,5,6,7}}, control-predecessors={p0} + p1 = f32[128] parameter(1), control-predecessors={ag} ROOT sum = add(ag, ag) } )"; @@ -281,6 +286,7 @@ TEST_F(AsyncAllReduceCreatorTest, ControlPredecessor) { ParseAndReturnVerifiedModule(hlo_string)); AsyncCollectiveCreator::CollectiveCreatorConfig config; config.convert_all_gather = HloPredicateTrue; + config.all_gather_min_threshold_in_bytes = 4096; TF_ASSERT_OK( RunHloPass(AsyncCollectiveCreator(config), hlo_module.get()).status()); SCOPED_TRACE(hlo_module->ToString()); diff --git a/third_party/xla/xla/service/collective_quantizer.cc b/third_party/xla/xla/hlo/transforms/collectives/collective_quantizer.cc similarity index 99% rename from third_party/xla/xla/service/collective_quantizer.cc rename to third_party/xla/xla/hlo/transforms/collectives/collective_quantizer.cc index bf4b9e57b4dff4..edb20f859f834d 100644 --- a/third_party/xla/xla/service/collective_quantizer.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/collective_quantizer.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/collective_quantizer.h" +#include "xla/hlo/transforms/collectives/collective_quantizer.h" -#include "xla/service/hlo_replication_analysis.h" +#include "xla/hlo/analysis/hlo_replication_analysis.h" #include "xla/service/pattern_matcher.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/hlo/transforms/collectives/collective_quantizer.h b/third_party/xla/xla/hlo/transforms/collectives/collective_quantizer.h new file mode 100644 index 00000000000000..be5722307decae --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/collectives/collective_quantizer.h @@ -0,0 +1,59 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_COLLECTIVES_COLLECTIVE_QUANTIZER_H_ +#define XLA_HLO_TRANSFORMS_COLLECTIVES_COLLECTIVE_QUANTIZER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// Reduces the amount of data transferred in all-gather, all-to-all, +// collective-broadcast and collective-permute ops by exchanging the collectives +// with subsequent quantizations or type conversions to a narrower type as well +// as preceding dequantizations or type conversions to a wider type. When +// present, unary ops such as bitcasts, copies, reshapes and slices between +// collective and quantization/dequantiation/type conversion are shifted, i.e. +// transforms +// +// collective --> unary --> quantization/type conversion +// +// into +// +// quantization/type conversion --> collective --> unary +// +// and +// +// dequantization/type conversion --> unary --> collective +// +// into +// +// unary --> collective --> dequantization/type conversion. +class CollectiveQuantizer : public HloModulePass { + public: + absl::string_view name() const override { return "collective-quantizer"; } + + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_COLLECTIVES_COLLECTIVE_QUANTIZER_H_ diff --git a/third_party/xla/xla/service/collective_quantizer_test.cc b/third_party/xla/xla/hlo/transforms/collectives/collective_quantizer_test.cc similarity index 99% rename from third_party/xla/xla/service/collective_quantizer_test.cc rename to third_party/xla/xla/hlo/transforms/collectives/collective_quantizer_test.cc index e6448e844db25f..d6c9ce416e4a2c 100644 --- a/third_party/xla/xla/service/collective_quantizer_test.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/collective_quantizer_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/collective_quantizer.h" +#include "xla/hlo/transforms/collectives/collective_quantizer.h" #include @@ -23,9 +23,9 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/hlo_verifier.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" @@ -34,7 +34,7 @@ namespace { namespace op = xla::testing::opcode_matchers; -class CollectiveQuantizerTest : public HloTestBase { +class CollectiveQuantizerTest : public HloHardwareIndependentTestBase { public: absl::StatusOr RunCollectiveQuantizer(HloModule* module) { CollectiveQuantizer collective_quantizer; diff --git a/third_party/xla/xla/service/collective_transformation_reorderer.cc b/third_party/xla/xla/hlo/transforms/collectives/collective_transformation_reorderer.cc similarity index 98% rename from third_party/xla/xla/service/collective_transformation_reorderer.cc rename to third_party/xla/xla/hlo/transforms/collectives/collective_transformation_reorderer.cc index 51712282019c92..1fb091556690f0 100644 --- a/third_party/xla/xla/service/collective_transformation_reorderer.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/collective_transformation_reorderer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/collective_transformation_reorderer.h" +#include "xla/hlo/transforms/collectives/collective_transformation_reorderer.h" #include #include @@ -26,7 +26,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/collectives/collective_transformation_reorderer.h b/third_party/xla/xla/hlo/transforms/collectives/collective_transformation_reorderer.h new file mode 100644 index 00000000000000..9b8071d517d635 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/collectives/collective_transformation_reorderer.h @@ -0,0 +1,76 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_COLLECTIVES_COLLECTIVE_TRANSFORMATION_REORDERER_H_ +#define XLA_HLO_TRANSFORMS_COLLECTIVES_COLLECTIVE_TRANSFORMATION_REORDERER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// Transforms +// -- all-gather + reshape into reshape + all-gather and +// -- reshape + all-reduce into all-reduce + reshape. +// Both transformations require that there are no other users affected, i.e., +// reshape user count should be 1. +// all-gather transformation requires the reshape to only change the shape of +// the all-gather shards, i.e., not reshaping across the all-gather dimension. +// all-reduce transformation requires all-reduce to be not layout constrained. + +// all-gather + reshape example: + +// input = [C_0, C_1, ..., C_i, ..., C_{n-1}, C_n] ... +// all-gather = [C_0, C_1, ..., P*C_i, ... C_{n-1}, C_n] all-gather(input) +// reshape = [D_0, D_1, ..., P*D_j, ..., D_{m-1}, D_m] reshape(all-gather) + +// can be transformed to: + +// input = [C_0, C_1, ..., C_i, ..., C_{n-1}, C_n] ... +// reshape = [D_0, D_1, ..., D_j, ..., D_{m-1}, D_m] reshape(input) +// all-gather = [D_0, D_1, ..., P*D_j, ... D_{m-1}, D_m] all-gather(input) + +// if and only if C_0 * C_1 * ... * C_{i-1} = D_0 * D_1 * ... * D_{j-1} +// and C_{i+1} * ... * C_{n-1} * C_n = D_{j+1} * ... * D_{m-1} * D_{m}. + +class CollectiveTransformationReorder : public HloModulePass { + public: + CollectiveTransformationReorder() = default; + ~CollectiveTransformationReorder() override = default; + absl::string_view name() const override { + return "collective-transformation-reorderer"; + } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + absl::StatusOr ReorderAllGatherTransformations( + HloModule* module, + const absl::flat_hash_set& execution_threads); + absl::StatusOr ReorderAllReduceTransformations( + HloModule* module, + const absl::flat_hash_set& execution_threads); +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_COLLECTIVES_COLLECTIVE_TRANSFORMATION_REORDERER_H_ diff --git a/third_party/xla/xla/service/collective_transformation_reorderer_test.cc b/third_party/xla/xla/hlo/transforms/collectives/collective_transformation_reorderer_test.cc similarity index 98% rename from third_party/xla/xla/service/collective_transformation_reorderer_test.cc rename to third_party/xla/xla/hlo/transforms/collectives/collective_transformation_reorderer_test.cc index 73f185e1caf73f..d0e22e6ce2734c 100644 --- a/third_party/xla/xla/service/collective_transformation_reorderer_test.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/collective_transformation_reorderer_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/collective_transformation_reorderer.h" +#include "xla/hlo/transforms/collectives/collective_transformation_reorderer.h" #include @@ -23,9 +23,9 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/hlo_verifier.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" @@ -34,7 +34,8 @@ namespace { namespace op = xla::testing::opcode_matchers; -class CollectiveTransformationReordererTest : public HloTestBase { +class CollectiveTransformationReordererTest + : public HloHardwareIndependentTestBase { public: absl::StatusOr RunCollectiveTransformationReorderer(HloModule* module) { CollectiveTransformationReorder reorderer; diff --git a/third_party/xla/xla/service/collectives_schedule_linearizer.cc b/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer.cc similarity index 96% rename from third_party/xla/xla/service/collectives_schedule_linearizer.cc rename to third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer.cc index a367831a1d0fec..175e7850f7a091 100644 --- a/third_party/xla/xla/service/collectives_schedule_linearizer.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/collectives_schedule_linearizer.h" +#include "xla/hlo/transforms/collectives/collectives_schedule_linearizer.h" #include #include @@ -24,11 +24,11 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "xla/hlo/analysis/hlo_reachability.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/ir/hlo_reachability.h" #include "tsl/platform/errors.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer.h b/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer.h new file mode 100644 index 00000000000000..432294ac343e95 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer.h @@ -0,0 +1,52 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_COLLECTIVES_COLLECTIVES_SCHEDULE_LINEARIZER_H_ +#define XLA_HLO_TRANSFORMS_COLLECTIVES_COLLECTIVES_SCHEDULE_LINEARIZER_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/util.h" + +namespace xla { + +// Enforces a total order on all collectives present in the module, based on the +// order given to the instructions. +// +// Does not insert inter-computation dependencies, only linearizes the order +// within each computation. +class CollectivesScheduleLinearizer : public HloModulePass { + public: + explicit CollectivesScheduleLinearizer(HloModulePredicate is_enabled = {}) + : is_enabled_(is_enabled) {} + + absl::string_view name() const override { + return "collectives-schedule-linearizer"; + } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + HloModulePredicate is_enabled_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_COLLECTIVES_COLLECTIVES_SCHEDULE_LINEARIZER_H_ diff --git a/third_party/xla/xla/service/collectives_schedule_linearizer_test.cc b/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer_test.cc similarity index 96% rename from third_party/xla/xla/service/collectives_schedule_linearizer_test.cc rename to third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer_test.cc index eeb9b8b936e55d..68f313687e5eab 100644 --- a/third_party/xla/xla/service/collectives_schedule_linearizer_test.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/collectives_schedule_linearizer_test.cc @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/collectives_schedule_linearizer.h" +#include "xla/hlo/transforms/collectives/collectives_schedule_linearizer.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/service/pattern_matcher.h" #include "xla/test.h" #include "xla/test_helpers.h" -#include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" namespace xla { @@ -37,7 +37,8 @@ int64_t CountControlEdges(const HloComputation& computation) { return count; } -class CollectivesScheduleLinearizerTest : public HloTestBase { +class CollectivesScheduleLinearizerTest + : public HloHardwareIndependentTestBase { protected: void InsertCollectivesSchedule(HloModule* module) { CollectivesScheduleLinearizer collectives_schedule_linearizer; diff --git a/third_party/xla/xla/service/convert_async_collectives_to_sync.cc b/third_party/xla/xla/hlo/transforms/collectives/convert_async_collectives_to_sync.cc similarity index 99% rename from third_party/xla/xla/service/convert_async_collectives_to_sync.cc rename to third_party/xla/xla/hlo/transforms/collectives/convert_async_collectives_to_sync.cc index 09365e158bf34b..9b4e11d01b9e0e 100644 --- a/third_party/xla/xla/service/convert_async_collectives_to_sync.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/convert_async_collectives_to_sync.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/convert_async_collectives_to_sync.h" +#include "xla/hlo/transforms/collectives/convert_async_collectives_to_sync.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/collectives/convert_async_collectives_to_sync.h b/third_party/xla/xla/hlo/transforms/collectives/convert_async_collectives_to_sync.h new file mode 100644 index 00000000000000..2c0ccaad6569cf --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/collectives/convert_async_collectives_to_sync.h @@ -0,0 +1,72 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_COLLECTIVES_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_ +#define XLA_HLO_TRANSFORMS_COLLECTIVES_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/util.h" + +namespace xla { + +// Convert asynchronous collectives to synchronous (after HLO scheduling) if +// there are no compute operations overlapping with them. + +class ConvertAsyncCollectivesToSync : public HloModulePass { + public: + explicit ConvertAsyncCollectivesToSync(HloPredicate is_nop = {}) + : is_nop_(is_nop) {} + absl::string_view name() const override { + return "convert-async-collectives-to-sync"; + } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + virtual absl::Status ConvertAsyncInstructionsToSync( + HloComputation* computation, + absl::Span> async_pairs) + const { + return ReplaceAsyncInstructionsWithSync(computation, async_pairs); + } + + // Helper utility to replace a list of pairs of async-start/done ops in a + // computation with their synchronous variants and update the schedule. + static absl::Status ReplaceAsyncInstructionsWithSync( + HloComputation* computation, + absl::Span> + async_pairs); + + static constexpr char kAsyncCollectiveNameAttributeName[] = + "async_collective_name"; + + private: + absl::StatusOr RunOnComputation(HloComputation* computation); + HloPredicate is_nop_; +}; +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_COLLECTIVES_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_ diff --git a/third_party/xla/xla/service/convert_async_collectives_to_sync_test.cc b/third_party/xla/xla/hlo/transforms/collectives/convert_async_collectives_to_sync_test.cc similarity index 98% rename from third_party/xla/xla/service/convert_async_collectives_to_sync_test.cc rename to third_party/xla/xla/hlo/transforms/collectives/convert_async_collectives_to_sync_test.cc index 5fdc9119865384..4d21c33f0d44e0 100644 --- a/third_party/xla/xla/service/convert_async_collectives_to_sync_test.cc +++ b/third_party/xla/xla/hlo/transforms/collectives/convert_async_collectives_to_sync_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/convert_async_collectives_to_sync.h" +#include "xla/hlo/transforms/collectives/convert_async_collectives_to_sync.h" #include @@ -25,8 +25,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "tsl/platform/statusor.h" @@ -40,7 +40,8 @@ namespace m = xla::testing::opcode_matchers; // Note: The pass only processes modules that are already scheduled. If the test // does not work as epxected, make sure to check if "is_scheduled=true" is added // to the HLO module string. -class ConvertAsyncCollectivesToSyncTest : public HloTestBase { +class ConvertAsyncCollectivesToSyncTest + : public HloHardwareIndependentTestBase { public: absl::Status RunPass(HloModule *module, bool expect_change, HloPredicate is_nop = {}) { diff --git a/third_party/xla/xla/hlo/transforms/collectives/infeed_token_propagation.cc b/third_party/xla/xla/hlo/transforms/collectives/infeed_token_propagation.cc new file mode 100644 index 00000000000000..3de31a8315ba50 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/collectives/infeed_token_propagation.cc @@ -0,0 +1,442 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/transforms/collectives/infeed_token_propagation.h" + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" +#include "xla/service/call_graph.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { +bool IsDanglingInfeed(HloInstruction* infeed) { + CHECK(infeed->opcode() == HloOpcode::kInfeed); + if (infeed->has_sharding()) { + // TODO: b/368327832 - Skip handling sharding until it is removed. + return false; + } + + // Check for dangling input token. + if (const HloInstruction* after_all = infeed->operand(0); + after_all->opcode() != HloOpcode::kAfterAll || + after_all->operand_count() != 0) { + return false; + } + + // Check for dangling output token. + for (const HloInstruction* user : infeed->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement && + user->tuple_index() == 1) { + return false; + } + } + + return true; +} + +bool IsDanglingOutfeed(HloInstruction* outfeed) { + CHECK(outfeed->opcode() == HloOpcode::kOutfeed); + if (outfeed->has_sharding()) { + // TODO: b/368327832 - Skip handling sharding until it is removed. + return false; + } + + // Check for dangling input token. + if (const HloInstruction* after_all = outfeed->operand(1); + after_all->opcode() != HloOpcode::kAfterAll || + after_all->operand_count() != 0) { + return false; + } + + // Check for dangling output token. + if (outfeed->user_count() != 0) { + return false; + } + + return true; +} + +HloInstruction* ReconstructTuple(HloInstruction* tuple) { + CHECK(tuple->shape().IsTuple()); + HloComputation* computation = tuple->parent(); + + std::vector gtes; + gtes.resize(tuple->shape().tuple_shapes_size()); + for (int64_t idx = 0; idx < gtes.size(); ++idx) { + gtes[idx] = computation->AddInstruction( + HloInstruction::CreateGetTupleElement(tuple, idx)); + } + + return computation->AddInstruction(HloInstruction::CreateTuple(gtes)); +} + +absl::StatusOr InsertTokenIntoTuple(HloInstruction* tuple, + bool add_token_operand) { + CHECK(tuple->shape().IsTuple()); + HloComputation* computation = tuple->parent(); + + // Recreate the original tuple, we'll need to pass this to all the users. + // Trying to use tuple->ReplaceAllUsesWith(original_tuple) cause a cycle. + std::vector original_users = tuple->users(); + HloInstruction* original_tuple = ReconstructTuple(tuple); + for (HloInstruction* original_user : original_users) { + for (int64_t idx : original_user->operand_indices(tuple)) { + TF_RETURN_IF_ERROR( + original_user->ReplaceOperandWith(idx, original_tuple)); + } + } + + // Append the token to the parameter tuple. + *tuple->mutable_shape()->add_tuple_shapes() = ShapeUtil::MakeTokenShape(); + if (add_token_operand) { + tuple->AppendOperand( + computation->AddInstruction(HloInstruction::CreateToken())); + } + + HloInstruction* input_token_gte = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + tuple, tuple->shape().tuple_shapes_size() - 1)); + return input_token_gte; +} +} // namespace + +absl::Status CanonicalizeConditionalInstruction(HloInstruction* conditional) { + CHECK_EQ(conditional->opcode(), HloOpcode::kConditional); + + for (HloComputation* branch : conditional->branch_computations()) { + // Tuplify the branch parameter if needed. + HloInstruction* parameter = branch->parameter_instruction(0); + if (!parameter->shape().IsTuple()) { + *parameter->mutable_shape() = + ShapeUtil::MakeTupleShape({parameter->shape()}); + HloInstruction* original = branch->AddInstruction( + HloInstruction::CreateGetTupleElement(parameter, 0)); + TF_RETURN_IF_ERROR(parameter->ReplaceAllUsesWithDifferentShape(original)); + } + + // Tuplify the branch tuple if needed. + int64_t branch_operand_idx = conditional->branch_index(branch) + 1; + HloInstruction* branch_tuple = + conditional->mutable_operand(branch_operand_idx); + if (!branch_tuple->shape().IsTuple()) { + branch_tuple = conditional->parent()->AddInstruction( + HloInstruction::CreateTuple({branch_tuple})); + TF_RETURN_IF_ERROR(conditional->ReplaceOperandWithDifferentShape( + branch_operand_idx, branch_tuple)); + } + + // Explicitly disjoin computation parameters from branch inputs, so we can + // insert tokens into the input tuple. + if (branch_tuple->opcode() == HloOpcode::kParameter) { + branch_tuple = ReconstructTuple(branch_tuple); + TF_RETURN_IF_ERROR( + conditional->ReplaceOperandWith(branch_operand_idx, branch_tuple)); + } + + // Explicitly make the root of the branch a tuple. + HloInstruction* root = branch->root_instruction(); + if (root->opcode() != HloOpcode::kTuple) { + root = ReconstructTuple(root); + branch->set_root_instruction(root); + } + } + + // ConditionalCanonicalizer should have already turned the conditional output + // to be a tuple. + CHECK(conditional->shape().IsTuple()); + + // Explicitly disjoin the conditional from being a computation root, so that + // we can insert tokens into, while preserving the original computation shape. + if (conditional->IsRoot()) { + HloInstruction* new_root = ReconstructTuple(conditional); + conditional->parent()->set_root_instruction(new_root); + } + + return absl::OkStatus(); +} + +absl::Status CanonicalizeWhileInstruction(HloInstruction* loop) { + CHECK_EQ(loop->opcode(), HloOpcode::kWhile); + HloComputation* body = loop->while_body(); + HloComputation* cond = loop->while_condition(); + + // Tuplify the body parameter if needed. + HloInstruction* body_parameter = body->parameter_instruction(0); + if (!body_parameter->shape().IsTuple()) { + *body_parameter->mutable_shape() = + ShapeUtil::MakeTupleShape({body_parameter->shape()}); + HloInstruction* original = body->AddInstruction( + HloInstruction::CreateGetTupleElement(body_parameter, 0)); + TF_RETURN_IF_ERROR( + body_parameter->ReplaceAllUsesWithDifferentShape(original)); + } + + // Tuplify the body root if needed. + HloInstruction* root = body->root_instruction(); + if (!root->shape().IsTuple()) { + root = body->AddInstruction(HloInstruction::CreateTuple({root})); + body->set_root_instruction(root, /*accept_different_shape=*/true); + } + + // Tuplify the condition parameter if needed. + HloInstruction* cond_parameter = cond->parameter_instruction(0); + if (!cond_parameter->shape().IsTuple()) { + *cond_parameter->mutable_shape() = + ShapeUtil::MakeTupleShape({cond_parameter->shape()}); + HloInstruction* original = cond->AddInstruction( + HloInstruction::CreateGetTupleElement(cond_parameter, 0)); + TF_RETURN_IF_ERROR( + cond_parameter->ReplaceAllUsesWithDifferentShape(original)); + } + + // Tuplify the while instruction if needed. + if (!loop->shape().IsTuple()) { + *loop->mutable_shape() = ShapeUtil::MakeTupleShape({loop->shape()}); + HloInstruction* original = loop->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement(loop, 0)); + TF_RETURN_IF_ERROR(loop->ReplaceAllUsesWithDifferentShape(original)); + } + + // Tuplify the while tuple if needed. + HloInstruction* loop_tuple = loop->mutable_operand(0); + if (!loop_tuple->shape().IsTuple()) { + loop_tuple = loop->parent()->AddInstruction( + HloInstruction::CreateTuple({loop_tuple})); + TF_RETURN_IF_ERROR(loop->ReplaceOperandWithDifferentShape(0, loop_tuple)); + } + + // Explicitly disjoin computation parameters from loop inputs, so we can + // insert tokens into the input tuple. + if (loop_tuple->opcode() == HloOpcode::kParameter) { + loop_tuple = ReconstructTuple(loop_tuple); + TF_RETURN_IF_ERROR(loop->ReplaceOperandWith(0, loop_tuple)); + } + + // Explicitly make the root of the body a tuple. + if (root->opcode() != HloOpcode::kTuple) { + root = ReconstructTuple(root); + body->set_root_instruction(root); + } + + // Explicitly disjoin the loop from being a computation root, so that + // we can insert tokens into, while preserving the original computation shape. + if (loop->IsRoot()) { + HloInstruction* new_root = ReconstructTuple(loop); + loop->parent()->set_root_instruction(new_root); + } + + return absl::OkStatus(); +} + +absl::Status InfeedTokenPropagation::PropagateTokenThroughConditionalBranch() { + // Conditional branches can diverge in inputs, but must converge on outputs. + + HloComputation* comp = dangling_instruction_->parent(); + dangling_instruction_ = call_graph_->GetComputationCallers(comp)[0]; + CHECK_EQ(dangling_instruction_->opcode(), HloOpcode::kConditional); + + // Insert the output token into each branch. + for (HloComputation* branch : dangling_instruction_->branch_computations()) { + HloInstruction* root = branch->root_instruction(); + if (branch == comp) { + TF_RETURN_IF_ERROR( + InsertTokenIntoTuple(root, /*add_token_operand=*/false).status()); + root->AppendOperand(output_token_); + } else { + TF_RETURN_IF_ERROR( + InsertTokenIntoTuple(root, /*add_token_operand=*/true).status()); + } + } + + // Insert the input token into the branch parameter. + HloInstruction* parameter = comp->parameter_instruction(0); + TF_ASSIGN_OR_RETURN( + HloInstruction * input_token_gte, + InsertTokenIntoTuple(parameter, /*add_token_operand=*/false)); + TF_RETURN_IF_ERROR(input_token_->ReplaceAllUsesWith(input_token_gte)); + + // Insert the input token into the branch tuple. + int64_t branch_operand_idx = dangling_instruction_->branch_index(comp) + 1; + HloInstruction* branch_tuple = + dangling_instruction_->mutable_operand(branch_operand_idx); + TF_ASSIGN_OR_RETURN( + HloInstruction * next_input_token_gte, + InsertTokenIntoTuple(branch_tuple, /*add_token_operand=*/true)); + TF_RETURN_IF_ERROR(dangling_instruction_->ReplaceOperandWithDifferentShape( + branch_operand_idx, branch_tuple)); + input_token_ = + branch_tuple->mutable_operand(next_input_token_gte->tuple_index()); + + // Insert the output token into conditional instruction. + TF_ASSIGN_OR_RETURN( + output_token_, + InsertTokenIntoTuple(dangling_instruction_, /*add_token_operand=*/false)); + + return absl::OkStatus(); +} + +absl::Status InfeedTokenPropagation::PropagateTokenThroughWhileBody() { + // While loops need to converge on input and output. + + HloComputation* comp = dangling_instruction_->parent(); + dangling_instruction_ = call_graph_->GetComputationCallers(comp)[0]; + CHECK_EQ(dangling_instruction_->opcode(), HloOpcode::kWhile); + + // Insert the output token into the body root. + HloInstruction* root = comp->root_instruction(); + TF_RETURN_IF_ERROR( + InsertTokenIntoTuple(root, /*add_token_operand=*/false).status()); + root->AppendOperand(output_token_); + + // Insert the input token into the body parameter. + HloInstruction* body_parameter = comp->parameter_instruction(0); + TF_ASSIGN_OR_RETURN( + HloInstruction * input_token_gte, + InsertTokenIntoTuple(body_parameter, /*add_token_operand=*/false)); + TF_RETURN_IF_ERROR(input_token_->ReplaceAllUsesWith(input_token_gte)); + + // Insert the input token into the condition parameter. + HloComputation* cond = dangling_instruction_->while_condition(); + HloInstruction* cond_parameter = cond->parameter_instruction(0); + TF_RETURN_IF_ERROR( + InsertTokenIntoTuple(cond_parameter, /*add_token_operand=*/false) + .status()); + + // Insert the input token into the while tuple. + HloInstruction* while_tuple = dangling_instruction_->mutable_operand(0); + TF_ASSIGN_OR_RETURN( + input_token_, + InsertTokenIntoTuple(while_tuple, /*add_token_operand=*/true)); + TF_RETURN_IF_ERROR( + dangling_instruction_->ReplaceOperandWithDifferentShape(0, while_tuple)); + + // Insert the input token into the while instruction. + TF_ASSIGN_OR_RETURN( + output_token_, + InsertTokenIntoTuple(dangling_instruction_, /*add_token_operand=*/false)); + + return absl::OkStatus(); +} + +absl::Status InfeedTokenPropagation::PropagateToken() { + HloComputation* comp = dangling_instruction_->parent(); + if (comp->IsEntryComputation()) { + return absl::OkStatus(); + } + VLOG(2) << "Propagating tokens for: " << dangling_instruction_->name(); + + HloInstruction* caller = call_graph_->GetComputationCallers(comp)[0]; + // TODO: b/368327832 - Skip handling sharding until it is removed. + if (caller->has_sharding()) { + return absl::OkStatus(); + } + if (caller->opcode() == HloOpcode::kConditional) { + TF_RETURN_IF_ERROR(CanonicalizeConditionalInstruction(caller)); + TF_RETURN_IF_ERROR(PropagateTokenThroughConditionalBranch()); + } else if (caller->opcode() == HloOpcode::kWhile && + comp == caller->while_body()) { + TF_RETURN_IF_ERROR(CanonicalizeWhileInstruction(caller)); + TF_RETURN_IF_ERROR(PropagateTokenThroughWhileBody()); + } else { + // We only expect to encounter computations behind while and conditional + // instructions. In the case of it being behind a while condition, there is + // no way to propagate the output token, as the root only returns a + // predicate. All other computations that could possibly contain infeed + // or outfeed ops should have already been inlined. + VLOG(2) << "Unhandled computation: " << comp->name(); + return absl::OkStatus(); + } + + return PropagateToken(); +} + +absl::StatusOr InfeedTokenPropagation::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + VLOG(5) << "Before InfeedTokenPropagation:"; + XLA_VLOG_LINES(5, module->ToString()); + + std::vector dangling_infeeds; + std::vector dangling_outfeeds; + for (HloComputation* computation : + module->MakeNonfusionComputations(execution_threads)) { + if (!computation->IsEntryComputation()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kInfeed && + IsDanglingInfeed(instruction)) { + VLOG(1) << "Found dangling infeed: " << instruction->ToString(); + dangling_infeeds.push_back(instruction); + } else if (instruction->opcode() == HloOpcode::kOutfeed && + IsDanglingOutfeed(instruction)) { + VLOG(1) << "Found dangling outfeed: " << instruction->ToString(); + dangling_outfeeds.push_back(instruction); + } + } + } + } + bool changed = !dangling_infeeds.empty() || !dangling_outfeeds.empty(); + + if (changed) { + call_graph_ = CallGraph::Build(module); + if (!call_graph_->IsFlattened()) { + return FailedPrecondition( + "Call graph must be flattened before infeed token propagation."); + } + } + + for (HloInstruction* dangling_infeed : dangling_infeeds) { + dangling_instruction_ = dangling_infeed; + input_token_ = dangling_infeed->mutable_operand(0); + output_token_ = dangling_infeed->AddInstruction( + HloInstruction::CreateGetTupleElement(dangling_infeed, 1)); + TF_RETURN_IF_ERROR(PropagateToken()); + } + for (HloInstruction* dangling_outfeed : dangling_outfeeds) { + dangling_instruction_ = dangling_outfeed; + input_token_ = dangling_outfeed->mutable_operand(1); + output_token_ = dangling_outfeed; + TF_RETURN_IF_ERROR(PropagateToken()); + } + + if (changed) { + TF_RETURN_IF_ERROR( + TupleSimplifier().Run(module, execution_threads).status()); + TF_RETURN_IF_ERROR(HloDCE().Run(module, execution_threads).status()); + } + + VLOG(5) << "After InfeedTokenPropagation:"; + XLA_VLOG_LINES(5, module->ToString()); + return changed; +} +} // namespace xla diff --git a/third_party/xla/xla/hlo/transforms/collectives/infeed_token_propagation.h b/third_party/xla/xla/hlo/transforms/collectives/infeed_token_propagation.h new file mode 100644 index 00000000000000..f1e3080b7a07e7 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/collectives/infeed_token_propagation.h @@ -0,0 +1,59 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_COLLECTIVES_INFEED_TOKEN_PROPAGATION_H_ +#define XLA_HLO_TRANSFORMS_COLLECTIVES_INFEED_TOKEN_PROPAGATION_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/service/call_graph.h" + +namespace xla { +// Finds dangling infeed/outfeed tokens inside nested computations and bubbles +// them up through callers until they reach the entry computation. This is +// needed to prepare these computations to be inlined, otherwise the previous +// computation boundaries won't be there to stop infeeds/outfeeds from being +// reordered during scheduling. +// +// This pass assumes the HLO graph is flattened. +class InfeedTokenPropagation : public HloModulePass { + public: + std::string_view name() const override { return "infeed-token-propagation"; } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + absl::Status PropagateToken(); + absl::Status PropagateTokenThroughWhileBody(); + absl::Status PropagateTokenThroughConditionalBranch(); + + std::unique_ptr call_graph_; + HloInstruction* dangling_instruction_ = nullptr; + HloInstruction* input_token_ = nullptr; + HloInstruction* output_token_ = nullptr; +}; +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_COLLECTIVES_INFEED_TOKEN_PROPAGATION_H_ diff --git a/third_party/xla/xla/hlo/transforms/collectives/infeed_token_propagation_test.cc b/third_party/xla/xla/hlo/transforms/collectives/infeed_token_propagation_test.cc new file mode 100644 index 00000000000000..2be79575afe8b2 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/collectives/infeed_token_propagation_test.cc @@ -0,0 +1,653 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/transforms/collectives/infeed_token_propagation.h" + +#include +#include + +#include +#include +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/utils/hlo_matchers.h" +#include "tsl/platform/statusor.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +class InfeedTokenPropagationTest : public HloHardwareIndependentTestBase { + protected: + InfeedTokenPropagationTest() = default; +}; + +TEST_F(InfeedTokenPropagationTest, EntryComputationInfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +ENTRY main { + token.0 = after-all() + infeed.0 = (s32[], token[]) infeed(token.0) + ROOT gte.0 = get-tuple-element(infeed.0), index=0 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(InfeedTokenPropagationTest, EntryComputationOutfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +ENTRY main { + arg.0 = s32[] parameter(0) + tuple.0 = tuple(arg.0) + token.0 = after-all() + outfeed.0 = token[] outfeed(tuple.0, token.0), outfeed_shape=(s32[]) + ROOT tuple.1 = tuple() +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(InfeedTokenPropagationTest, ConditionalInfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +true_comp { + arg.0 = () parameter(0) + token.0 = after-all() + infeed.0 = (s32[], token[]) infeed(token.0) + ROOT tuple.0 = tuple() +} + +false_comp { + arg.0 = () parameter(0) + ROOT tuple.0 = tuple() +} + +ENTRY main { + pred.0 = pred[] constant(true) + true_tuple.0 = tuple() + false_tuple.0 = tuple() + ROOT cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The infeed output token should have propagated through the conditional. + HloInstruction* cond = FindInstruction(module.get(), "cond.0"); + EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); + + // The infeed input token should have propagated through the true tuple. + HloInstruction* true_tuple = FindInstruction(module.get(), "true_tuple.0"); + EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(true_tuple->shape().tuple_shapes()[0].IsToken()); + + // The infeed input token should not have propagated through the false tuple. + HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); + EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); + + // The infeed output token should have propagated through the true + // computation's root. + HloComputation* true_comp = FindComputation(module.get(), "true_comp"); + EXPECT_THAT(true_comp->root_instruction(), + op::Tuple(op::GetTupleElement(op::Infeed(), 1))); + + // The infeed output token should have propagated to the false computation's + // root. + HloComputation* false_comp = FindComputation(module.get(), "false_comp"); + EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); +} + +TEST_F(InfeedTokenPropagationTest, ConditionalOutfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +true_comp { + arg.0 = (s32[]) parameter(0) + token.0 = after-all() + outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=(s32[]) + ROOT tuple.0 = tuple() +} + +false_comp { + arg.0 = () parameter(0) + ROOT tuple.0 = tuple() +} + +ENTRY main { + arg.0 = s32[] parameter(0) + pred.0 = pred[] constant(true) + true_tuple.0 = tuple(arg.0) + false_tuple.0 = tuple() + ROOT cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The outfeed output token should have propagated through the conditional. + HloInstruction* cond = FindInstruction(module.get(), "cond.0"); + EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); + + // The outfeed input token should have propagated through the true tuple. + HloInstruction* true_tuple = FindInstruction(module.get(), "true_tuple.0"); + EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(true_tuple->shape().tuple_shapes()[1].IsToken()); + + // The outfeed input token should not have propagated through the false tuple. + HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); + EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); + + // The outfeed output token should have propagated through the true + // computation's root. + HloComputation* true_comp = FindComputation(module.get(), "true_comp"); + EXPECT_THAT(true_comp->root_instruction(), op::Tuple(op::Outfeed())); + + // The outfeed output token should have propagated to the false computation's + // root. + HloComputation* false_comp = FindComputation(module.get(), "false_comp"); + EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); +} + +TEST_F(InfeedTokenPropagationTest, ConditionalDuplicateOperand) { + constexpr std::string_view hlo = R"( +HloModule main + +true_comp { + arg.0 = () parameter(0) + token.0 = after-all() + infeed.0 = (s32[], token[]) infeed(token.0) + ROOT tuple.0 = tuple() +} + +false_comp { + arg.0 = () parameter(0) + ROOT tuple.0 = tuple() +} + +ENTRY main { + pred.0 = pred[] constant(true) + tuple.0 = tuple() + ROOT cond.0 = () conditional(pred.0, tuple.0, tuple.0), true_computation=true_comp, false_computation=false_comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The infeed output token should have propagated through the conditional. + HloInstruction* cond = FindInstruction(module.get(), "cond.0"); + EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); + + // The infeed input token should have propagated through the true tuple. + const HloInstruction* true_tuple = cond->operand(1); + EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(true_tuple->shape().tuple_shapes()[0].IsToken()); + + // The infeed input token should not have propagated through the false tuple. + const HloInstruction* false_tuple = cond->operand(2); + EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); + + // The infeed output token should have propagated through the true + // computation's root. + HloComputation* true_comp = FindComputation(module.get(), "true_comp"); + EXPECT_THAT(true_comp->root_instruction(), + op::Tuple(op::GetTupleElement(op::Infeed(), 1))); + + // The infeed output token should have propagated to the false computation's + // root. + HloComputation* false_comp = FindComputation(module.get(), "false_comp"); + EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); +} + +TEST_F(InfeedTokenPropagationTest, NonTupleConditional) { + constexpr std::string_view hlo = R"( +HloModule main + +true_comp { + arg.0 = s32[] parameter(0) + outfeed_tuple.0 = tuple(arg.0) + token.0 = after-all() + outfeed.0 = token[] outfeed(outfeed_tuple.0, token.0), outfeed_shape=(s32[]) + ROOT tuple.0 = tuple() +} + +false_comp { + arg.0 = () parameter(0) + ROOT tuple.0 = tuple() +} + +ENTRY main { + arg.0 = s32[] parameter(0) + pred.0 = pred[] constant(true) + false_tuple.0 = tuple() + ROOT cond.0 = () conditional(pred.0, arg.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The outfeed output token should have propagated through the conditional. + HloInstruction* cond = FindInstruction(module.get(), "cond.0"); + EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); + + // The outfeed input token should have propagated through the true tuple. + HloInstruction* true_tuple = cond->mutable_operand(1); + EXPECT_TRUE(true_tuple->shape().IsTuple()); + EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(true_tuple->shape().tuple_shapes()[1].IsToken()); + + // The outfeed input token should not have propagated through the false tuple. + HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); + EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); + + // The outfeed output token should have propagated through the true + // computation's root. + HloComputation* true_comp = FindComputation(module.get(), "true_comp"); + EXPECT_THAT(true_comp->root_instruction(), op::Tuple(op::Outfeed())); + + // The outfeed output token should have propagated to the false computation's + // root. + HloComputation* false_comp = FindComputation(module.get(), "false_comp"); + EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); +} + +TEST_F(InfeedTokenPropagationTest, DisjointConditionalOutfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +true_comp { + ROOT arg.0 = () parameter(0) + one.0 = s32[] constant(1) + outfeed_tuple.0 = tuple(one.0) + token.0 = after-all() + outfeed.0 = token[] outfeed(outfeed_tuple.0, token.0), outfeed_shape=(s32[]) +} + +false_comp { + arg.0 = () parameter(0) + ROOT tuple.0 = tuple() +} + +ENTRY main { + pred.0 = pred[] constant(true) + true_tuple.0 = tuple() + false_tuple.0 = tuple() + ROOT cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The outfeed output token should have propagated through the conditional. + HloInstruction* cond = FindInstruction(module.get(), "cond.0"); + EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); + + // The outfeed input token should have propagated through the true tuple. + HloInstruction* true_tuple = FindInstruction(module.get(), "true_tuple.0"); + EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(true_tuple->shape().tuple_shapes()[0].IsToken()); + + // The outfeed input token should not have propagated through the false tuple. + HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); + EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); + + // The outfeed output token should have propagated through the true + // computation's root. + HloComputation* true_comp = FindComputation(module.get(), "true_comp"); + EXPECT_THAT(true_comp->root_instruction(), op::Tuple(op::Outfeed())); + + // The outfeed output token should have propagated to the false computation's + // root. + HloComputation* false_comp = FindComputation(module.get(), "false_comp"); + EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); +} + +TEST_F(InfeedTokenPropagationTest, WhileInfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +comp { + arg.0 = () parameter(0) + token.0 = after-all() + infeed.0 = (s32[], token[]) infeed(token.0) + ROOT tuple.0 = tuple() +} + +cond { + arg.0 = () parameter(0) + ROOT true.0 = pred[] constant(true) +} + +ENTRY main { + while_tuple.0 = tuple() + ROOT while.0 = () while(while_tuple.0), condition=cond, body=comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The infeed output token should have propagated through the loop. + HloInstruction* loop = FindInstruction(module.get(), "while.0"); + EXPECT_EQ(loop->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(loop->shape().tuple_shapes()[0].IsToken()); + + // The infeed input token should have propagated through the loop tuple. + HloInstruction* loop_tuple = FindInstruction(module.get(), "while_tuple.0"); + EXPECT_EQ(loop_tuple->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[0].IsToken()); + + // The infeed output token should have propagated through the while body root. + HloComputation* body_comp = FindComputation(module.get(), "comp"); + EXPECT_THAT(body_comp->root_instruction(), + op::Tuple(op::GetTupleElement(op::Infeed(), 1))); + + // The infeed input token should have propagated through the body parameter. + HloInstruction* body_param = body_comp->parameter_instruction(0); + EXPECT_EQ(body_param->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(body_param->shape().tuple_shapes()[0].IsToken()); + + // The infeed input token should have propagated through the condition + // parameter. + HloComputation* cond_comp = FindComputation(module.get(), "cond"); + HloInstruction* cond_param = cond_comp->parameter_instruction(0); + EXPECT_EQ(cond_param->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond_param->shape().tuple_shapes()[0].IsToken()); +} + +TEST_F(InfeedTokenPropagationTest, WhileOutfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +comp { + arg.0 = (s32[]) parameter(0) + token.0 = after-all() + outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=(s32[]) + gte.0 = get-tuple-element(arg.0), index=0 + ROOT tuple.0 = tuple(gte.0) +} + +cond { + arg.0 = (s32[]) parameter(0) + ROOT true.0 = pred[] constant(true) +} + +ENTRY main { + arg.0 = s32[] parameter(0) + while_tuple.0 = tuple(arg.0) + ROOT while.0 = (s32[]) while(while_tuple.0), condition=cond, body=comp +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The outfeed output token should have propagated through the loop. + HloInstruction* loop = FindInstruction(module.get(), "while.0"); + EXPECT_EQ(loop->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(loop->shape().tuple_shapes()[1].IsToken()); + + // The outfeed input token should have propagated through the loop tuple. + HloInstruction* loop_tuple = FindInstruction(module.get(), "while_tuple.0"); + EXPECT_EQ(loop_tuple->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[1].IsToken()); + + // The outfeed output token should have propagated through the while body + // root. + HloComputation* body_comp = FindComputation(module.get(), "comp"); + EXPECT_THAT(body_comp->root_instruction(), + op::Tuple(op::GetTupleElement(), op::Outfeed())); + + // The outfeed output token should have propagated through the body parameter. + HloInstruction* body_param = body_comp->parameter_instruction(0); + EXPECT_EQ(body_param->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(body_param->shape().tuple_shapes()[1].IsToken()); + + // The outfeed output token should have propagated through the condition + // parameter. + HloComputation* cond_comp = FindComputation(module.get(), "cond"); + HloInstruction* cond_param = cond_comp->parameter_instruction(0); + EXPECT_EQ(cond_param->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(cond_param->shape().tuple_shapes()[1].IsToken()); +} + +TEST_F(InfeedTokenPropagationTest, DisjointWhileOutfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +comp { + ROOT arg.0 = () parameter(0) + one.0 = s32[] constant(1) + outfeed_tuple.0 = tuple(one.0) + token.0 = after-all() + outfeed.0 = token[] outfeed(outfeed_tuple.0, token.0), outfeed_shape=(s32[]) +} + +cond { + arg.0 = () parameter(0) + ROOT true.0 = pred[] constant(true) +} + +ENTRY main { + while_tuple.0 = tuple() + ROOT while.0 = () while(while_tuple.0), condition=cond, body=comp +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The outfeed output token should have propagated through the loop. + HloInstruction* loop = FindInstruction(module.get(), "while.0"); + EXPECT_EQ(loop->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(loop->shape().tuple_shapes()[0].IsToken()); + + // The outfeed input token should have propagated through the loop tuple. + HloInstruction* loop_tuple = FindInstruction(module.get(), "while_tuple.0"); + EXPECT_EQ(loop_tuple->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[0].IsToken()); + + // The outfeed output token should have propagated through the while body + // root. + HloComputation* body_comp = FindComputation(module.get(), "comp"); + EXPECT_THAT(body_comp->root_instruction(), op::Tuple(op::Outfeed())); + + // The outfeed output token should have propagated through the body parameter. + HloInstruction* body_param = body_comp->parameter_instruction(0); + EXPECT_EQ(body_param->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(body_param->shape().tuple_shapes()[0].IsToken()); + + // The outfeed output token should have propagated through the condition + // parameter. + HloComputation* cond_comp = FindComputation(module.get(), "cond"); + HloInstruction* cond_param = cond_comp->parameter_instruction(0); + EXPECT_EQ(cond_param->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond_param->shape().tuple_shapes()[0].IsToken()); +} + +TEST_F(InfeedTokenPropagationTest, NonTupleWhile) { + constexpr std::string_view hlo = R"( +HloModule main + +comp { + ROOT arg.0 = s32[] parameter(0) + tuple.0 = tuple(arg.0) + token.0 = after-all() + outfeed.0 = token[] outfeed(tuple.0, token.0), outfeed_shape=(s32[]) +} + +cond { + arg.0 = s32[] parameter(0) + ROOT true.0 = pred[] constant(true) +} + +ENTRY main { + arg.0 = s32[] parameter(0) + ROOT while.0 = s32[] while(arg.0), condition=cond, body=comp +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The outfeed output token should have propagated through the loop. + HloInstruction* loop = FindInstruction(module.get(), "while.0"); + EXPECT_TRUE(loop->shape().IsTuple()); + EXPECT_EQ(loop->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(loop->shape().tuple_shapes()[1].IsToken()); + + // The outfeed input token should have propagated through the loop tuple. + EXPECT_THAT(loop->operand(0), op::Tuple(op::Parameter(), op::AfterAll())); + + // The outfeed output token should have propagated through the while body + // root. + HloComputation* body_comp = FindComputation(module.get(), "comp"); + EXPECT_THAT(body_comp->root_instruction(), + op::Tuple(op::GetTupleElement(), op::Outfeed())); + + // The outfeed output token should have propagated through the body parameter. + HloInstruction* body_param = body_comp->parameter_instruction(0); + EXPECT_EQ(body_param->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(body_param->shape().tuple_shapes()[1].IsToken()); + + // The outfeed output token should have propagated through the condition + // parameter. + HloComputation* cond_comp = FindComputation(module.get(), "cond"); + HloInstruction* cond_param = cond_comp->parameter_instruction(0); + EXPECT_EQ(cond_param->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(cond_param->shape().tuple_shapes()[1].IsToken()); +} + +TEST_F(InfeedTokenPropagationTest, NestedInfeedOutfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +true_comp { + arg.0 = (s32[]) parameter(0) + token.0 = after-all() + outfeed.0 = token[] outfeed(arg.0, token.0), outfeed_shape=(s32[]) + ROOT tuple.0 = tuple() +} + +false_comp { + arg.0 = () parameter(0) + ROOT tuple.0 = tuple() +} + +comp { + arg.0 = () parameter(0) + token.0 = after-all() + infeed.0 = (s32[], token[]) infeed(token.0) + gte.0 = get-tuple-element(infeed.0), index=0 + pred.0 = pred[] constant(true) + true_tuple.0 = tuple(gte.0) + false_tuple.0 = tuple() + ROOT cond.0 = () conditional(pred.0, true_tuple.0, false_tuple.0), true_computation=true_comp, false_computation=false_comp +} + +cond { + arg.0 = () parameter(0) + ROOT true.0 = pred[] constant(true) +} + +ENTRY main { + while_tuple.0 = tuple() + ROOT while.0 = () while(while_tuple.0), condition=cond, body=comp +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + InfeedTokenPropagation itp; + TF_ASSERT_OK_AND_ASSIGN(bool changed, itp.Run(module.get())); + EXPECT_TRUE(changed); + + // The infeed and outfeed output tokens should have propagated through the + // loop. + HloInstruction* loop = FindInstruction(module.get(), "while.0"); + EXPECT_EQ(loop->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(loop->shape().tuple_shapes()[0].IsToken()); + EXPECT_TRUE(loop->shape().tuple_shapes()[1].IsToken()); + + // The infeed and outfeed input tokens should have propagated through the loop + // tuple. + HloInstruction* loop_tuple = FindInstruction(module.get(), "while_tuple.0"); + EXPECT_EQ(loop_tuple->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[0].IsToken()); + EXPECT_TRUE(loop_tuple->shape().tuple_shapes()[1].IsToken()); + + // The infeed and outfeed output tokens should have propagated through the + // while body root. + HloComputation* body_comp = FindComputation(module.get(), "comp"); + EXPECT_THAT(body_comp->root_instruction(), + op::Tuple(op::GetTupleElement(op::Infeed(), 1), + op::GetTupleElement(op::Conditional(), 0))); + + // The outfeed output token should have propagated through the conditional. + HloInstruction* cond = FindInstruction(module.get(), "cond.0"); + EXPECT_EQ(cond->shape().tuple_shapes_size(), 1); + EXPECT_TRUE(cond->shape().tuple_shapes()[0].IsToken()); + + // The outfeed input token should have propagated through the true tuple. + HloInstruction* true_tuple = FindInstruction(module.get(), "true_tuple.0"); + EXPECT_EQ(true_tuple->shape().tuple_shapes_size(), 2); + EXPECT_TRUE(true_tuple->shape().tuple_shapes()[1].IsToken()); + + // The outfeed input token should not have propagated through the false tuple. + HloInstruction* false_tuple = FindInstruction(module.get(), "false_tuple.0"); + EXPECT_EQ(false_tuple->shape().tuple_shapes_size(), 0); + + // The outfeed output token should have propagated through the true + // computation's root. + HloComputation* true_comp = FindComputation(module.get(), "true_comp"); + EXPECT_THAT(true_comp->root_instruction(), op::Tuple(op::Outfeed())); + + // The outfeed output token should have propagated to the false computation's + // root. + HloComputation* false_comp = FindComputation(module.get(), "false_comp"); + EXPECT_THAT(false_comp->root_instruction(), op::Tuple(op::AfterAll())); +} +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.cc b/third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations.cc similarity index 98% rename from third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.cc rename to third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations.cc index 88315441dc19e3..6846a186c7e691 100644 --- a/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.cc +++ b/third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations.cc @@ -13,7 +13,7 @@ limitations under the License. ==============================================================================*/ -#include "xla/service/convert_memory_placement_to_internal_annotations.h" +#include "xla/hlo/transforms/convert_memory_placement_to_internal_annotations.h" #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" diff --git a/third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations.h b/third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations.h new file mode 100644 index 00000000000000..ab2bc8359c0f76 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations.h @@ -0,0 +1,46 @@ +/* Copyright 2024 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_CONVERT_MEMORY_PLACEMENT_TO_INTERNAL_ANNOTATIONS_H_ +#define XLA_HLO_TRANSFORMS_CONVERT_MEMORY_PLACEMENT_TO_INTERNAL_ANNOTATIONS_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +class ConvertMemoryPlacementToInternalAnnotations : public HloModulePass { + public: + ConvertMemoryPlacementToInternalAnnotations() = default; + + absl::string_view name() const override { + return "convert-memory-placement-to-internal-annotations"; + } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_CONVERT_MEMORY_PLACEMENT_TO_INTERNAL_ANNOTATIONS_H_ diff --git a/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations_test.cc b/third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations_test.cc similarity index 99% rename from third_party/xla/xla/service/convert_memory_placement_to_internal_annotations_test.cc rename to third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations_test.cc index 5fad129fe2ee95..28f6e909f48dcd 100644 --- a/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations_test.cc +++ b/third_party/xla/xla/hlo/transforms/convert_memory_placement_to_internal_annotations_test.cc @@ -13,7 +13,7 @@ limitations under the License. ==============================================================================*/ -#include "xla/service/convert_memory_placement_to_internal_annotations.h" +#include "xla/hlo/transforms/convert_memory_placement_to_internal_annotations.h" #include #include @@ -24,9 +24,9 @@ #include #include "absl/status/statusor.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/service/host_memory_offload_annotations.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -34,7 +34,8 @@ namespace xla { namespace { -class ConvertMemoryPlacementToInternalAnnotationsTest : public HloTestBase { +class ConvertMemoryPlacementToInternalAnnotationsTest + : public HloHardwareIndependentTestBase { public: ConvertMemoryPlacementToInternalAnnotationsTest() = default; }; diff --git a/third_party/xla/xla/service/defuser.cc b/third_party/xla/xla/hlo/transforms/defuser.cc similarity index 98% rename from third_party/xla/xla/service/defuser.cc rename to third_party/xla/xla/hlo/transforms/defuser.cc index 2f9e0bb07366ec..225328ecec25c5 100644 --- a/third_party/xla/xla/service/defuser.cc +++ b/third_party/xla/xla/hlo/transforms/defuser.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/defuser.h" +#include "xla/hlo/transforms/defuser.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/defuser.h b/third_party/xla/xla/hlo/transforms/defuser.h new file mode 100644 index 00000000000000..16c459e77524a2 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/defuser.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_DEFUSER_H_ +#define XLA_HLO_TRANSFORMS_DEFUSER_H_ + +#include + +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// A pass which replaces all fusion instructions with the equivalent un-fused +// instructions. +class Defuser : public HloModulePass { + public: + Defuser() {} + ~Defuser() override {} + absl::string_view name() const override { return "defuser"; } + + // Run defusion on the given module. Returns whether the module was + // changed. + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_DEFUSER_H_ diff --git a/third_party/xla/xla/service/defuser_test.cc b/third_party/xla/xla/hlo/transforms/defuser_test.cc similarity index 95% rename from third_party/xla/xla/service/defuser_test.cc rename to third_party/xla/xla/hlo/transforms/defuser_test.cc index ad70f7998c66a4..8152873c657420 100644 --- a/third_party/xla/xla/service/defuser_test.cc +++ b/third_party/xla/xla/hlo/transforms/defuser_test.cc @@ -13,19 +13,27 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/defuser.h" - +#include "xla/hlo/transforms/defuser.h" + +#include +#include +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/tests/hlo_test_base.h" namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -class DefuserTest : public HloTestBase { +class DefuserTest : public HloHardwareIndependentTestBase { protected: // Returns the number of fusion instructions in the module. int FusionCount(const HloModule* m) { diff --git a/third_party/xla/xla/service/despecializer.cc b/third_party/xla/xla/hlo/transforms/despecializer.cc similarity index 96% rename from third_party/xla/xla/service/despecializer.cc rename to third_party/xla/xla/hlo/transforms/despecializer.cc index 6bf6f98275bba3..25a1727b900070 100644 --- a/third_party/xla/xla/service/despecializer.cc +++ b/third_party/xla/xla/hlo/transforms/despecializer.cc @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/despecializer.h" +#include "xla/hlo/transforms/despecializer.h" #include #include #include -#include "xla/service/defuser.h" -#include "xla/service/float_normalization.h" -#include "xla/service/hlo_memory_scheduler.h" -#include "xla/service/sub_byte_normalization.h" +#include "xla/hlo/transforms/defuser.h" +#include "xla/hlo/transforms/simplifiers/float_normalization.h" +#include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" +#include "xla/hlo/transforms/simplifiers/sub_byte_normalization.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/despecializer.h b/third_party/xla/xla/hlo/transforms/despecializer.h new file mode 100644 index 00000000000000..633f46cfbeba10 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/despecializer.h @@ -0,0 +1,102 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_DESPECIALIZER_H_ +#define XLA_HLO_TRANSFORMS_DESPECIALIZER_H_ + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/hlo/pass/hlo_pass_pipeline.h" + +namespace xla { + +// Creates an HloPassPipeline containing multiple HloPasses that can +// despecialize an optimized HloModule. This is useful to run an HloModule +// optimized for one specific platform on a different platform (undoing platform +// specific passes) with matching numerics for comparison. +// +// Current despecialization passes are HloDescheduler, ControlDepRemover, +// Defuser and BFloat16MixedPrecisionRemoval. +class Despecializer : public HloModulePass { + public: + Despecializer(); + void AddReduceWindowToReduceBroadcastDeconstruct(); + void AddAssumeGatherIndicesInBoundRewriteToCopy(); + absl::string_view name() const override { return "despecializer"; } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + HloPassPipeline pipeline_; +}; + +class AssumeGatherIndicesInBoundRewriteToCopy : public HloModulePass { + public: + AssumeGatherIndicesInBoundRewriteToCopy() = default; + absl::string_view name() const override { + return "AssumeGatherIndicesInBoundRewriteToCopy"; + } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +class DeconstructReduceWindowToReduceBroadcast : public HloModulePass { + public: + DeconstructReduceWindowToReduceBroadcast() = default; + absl::string_view name() const override { + return "ReduceWindowToReduceAndBroadcast"; + } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +// Pass which strips control dependencies from all instructions in the module. +class ControlDepRemover : public HloModulePass { + public: + ControlDepRemover() = default; + absl::string_view name() const override { return "control-dep-remover"; } + + using HloPassInterface::Run; + absl::StatusOr Run(HloModule* module, + const absl::flat_hash_set& + execution_threads) override { + bool changed = false; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + changed |= !instruction->control_predecessors().empty(); + TF_RETURN_IF_ERROR(instruction->DropAllControlDeps()); + } + } + return changed; + } +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_DESPECIALIZER_H_ diff --git a/third_party/xla/xla/service/despecializer_test.cc b/third_party/xla/xla/hlo/transforms/despecializer_test.cc similarity index 98% rename from third_party/xla/xla/service/despecializer_test.cc rename to third_party/xla/xla/hlo/transforms/despecializer_test.cc index 6ba16f6b8f32e9..8a360059163fe1 100644 --- a/third_party/xla/xla/service/despecializer_test.cc +++ b/third_party/xla/xla/hlo/transforms/despecializer_test.cc @@ -13,21 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/despecializer.h" +#include "xla/hlo/transforms/despecializer.h" #include #include #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal.h" #include "xla/shape_util.h" -#include "xla/tests/hlo_test_base.h" namespace xla { namespace { -class DespecializerTest : public HloTestBase { +class DespecializerTest : public HloHardwareIndependentTestBase { protected: Despecializer despecializer_; }; diff --git a/third_party/xla/xla/service/bitcast_dtypes_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander.cc similarity index 83% rename from third_party/xla/xla/service/bitcast_dtypes_expander.cc rename to third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander.cc index f4cc6809599cdd..9918e34c352386 100644 --- a/third_party/xla/xla/service/bitcast_dtypes_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander.cc @@ -13,25 +13,26 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/bitcast_dtypes_expander.h" - -#include "absl/algorithm/container.h" -#include "absl/strings/str_join.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/broadcast.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" -#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" -#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/transforms/expanders/bitcast_dtypes_expander.h" + +#include "absl/strings/str_format.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/broadcast.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/literal_util.h" +#include "xla/primitive_util.h" +#include "xla/service/hlo_creation_utils.h" +#include "xla/service/hlo_module_config.h" +#include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/status_macros.h" -#include "xla/types.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -104,14 +105,8 @@ absl::StatusOr BitcastDtypesExpander::ExpandInstruction( BitcastConvertType(input, to_shape.element_type()); TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, b.Build()); - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - xla_computation.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( - xla_computation.proto(), config)); - HloCloneContext context(module); - computation = - module->DeepCloneComputation(new_module->entry_computation(), &context); + TF_ASSIGN_OR_RETURN( + computation, XlaComputationToHloComputation(xla_computation, module)); } return instruction->parent()->AddInstruction(HloInstruction::CreateCall( diff --git a/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander.h b/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander.h new file mode 100644 index 00000000000000..da8c63ffb0ee3d --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander.h @@ -0,0 +1,48 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" + +#ifndef XLA_HLO_TRANSFORMS_EXPANDERS_BITCAST_DTYPES_EXPANDER_H_ +#define XLA_HLO_TRANSFORMS_EXPANDERS_BITCAST_DTYPES_EXPANDER_H_ + +namespace xla { + +// A pass which expands bitcast-convert between differently sized dtypes to a +// reduction. +class BitcastDtypesExpander : public OpExpanderPass { + public: + absl::string_view name() const override { return "bitcast_dtypes_expander"; } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + absl::StatusOr ExpandInstruction( + HloInstruction* instruction) override; + + private: + absl::flat_hash_map computation_cache_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_EXPANDERS_BITCAST_DTYPES_EXPANDER_H_ diff --git a/third_party/xla/xla/service/bitcast_dtypes_expander_test.cc b/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander_test.cc similarity index 96% rename from third_party/xla/xla/service/bitcast_dtypes_expander_test.cc rename to third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander_test.cc index a5dc3b882446cc..2b5efab5c6897b 100644 --- a/third_party/xla/xla/service/bitcast_dtypes_expander_test.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/bitcast_dtypes_expander_test.cc @@ -13,18 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/bitcast_dtypes_expander.h" +#include "xla/hlo/transforms/expanders/bitcast_dtypes_expander.h" -#include "xla/hlo/utils/hlo_matchers.h" -#include "xla/tests/filecheck.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tsl/lib/core/status_test_util.h" +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/filecheck.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "tsl/platform/statusor.h" namespace xla { namespace { -class BitcastDtypesExpanderTest : public HloTestBase {}; +class BitcastDtypesExpanderTest : public HloHardwareIndependentTestBase {}; TEST_F(BitcastDtypesExpanderTest, S32toS8) { absl::string_view hlo_string = R"( diff --git a/third_party/xla/xla/service/cholesky_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/cholesky_expander.cc similarity index 92% rename from third_party/xla/xla/service/cholesky_expander.cc rename to third_party/xla/xla/hlo/transforms/expanders/cholesky_expander.cc index d70e0211103fff..2bdb4c18036da9 100644 --- a/third_party/xla/xla/service/cholesky_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/cholesky_expander.cc @@ -13,21 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cholesky_expander.h" +#include "xla/hlo/transforms/expanders/cholesky_expander.h" #include #include #include "absl/status/statusor.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/loops.h" -#include "xla/client/lib/math.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/loops.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/primitive_util.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/util.h" @@ -248,15 +249,8 @@ absl::StatusOr CholeskyExpander::ExpandInstruction( MaybeTransposeInMinorDims(l, !options.lower()); TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); - - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - xla_computation.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( - xla_computation.proto(), config)); - HloCloneContext context(module); - computation = - module->DeepCloneComputation(new_module->entry_computation(), &context); + TF_ASSIGN_OR_RETURN( + computation, XlaComputationToHloComputation(xla_computation, module)); } return instruction->parent()->AddInstruction(HloInstruction::CreateCall( diff --git a/third_party/xla/xla/hlo/transforms/expanders/cholesky_expander.h b/third_party/xla/xla/hlo/transforms/expanders/cholesky_expander.h new file mode 100644 index 00000000000000..3ee4a26ad2ee2f --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/expanders/cholesky_expander.h @@ -0,0 +1,48 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_EXPANDERS_CHOLESKY_EXPANDER_H_ +#define XLA_HLO_TRANSFORMS_EXPANDERS_CHOLESKY_EXPANDER_H_ + +#include "absl/container/flat_hash_map.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" + +namespace xla { + +class CholeskyExpander : public OpExpanderPass { + public: + absl::string_view name() const override { return "cholesky_expander"; } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + absl::StatusOr ExpandInstruction( + HloInstruction* instruction) override; + + virtual absl::StatusOr> CholeskyUnblocked( + XlaOp a, PrecisionConfig::Precision precision); + + private: + XlaOp BuildCholesky(XlaOp a, int64_t block_size, + PrecisionConfig::Precision precision); + + // Mapping from op signatures to existing computations. + absl::flat_hash_map computation_cache_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_EXPANDERS_CHOLESKY_EXPANDER_H_ diff --git a/third_party/xla/xla/service/comparison_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/comparison_expander.cc similarity index 99% rename from third_party/xla/xla/service/comparison_expander.cc rename to third_party/xla/xla/hlo/transforms/expanders/comparison_expander.cc index 4a7ff3d5a44628..0f09ecced1ebaf 100644 --- a/third_party/xla/xla/service/comparison_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/comparison_expander.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/comparison_expander.h" +#include "xla/hlo/transforms/expanders/comparison_expander.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/expanders/comparison_expander.h b/third_party/xla/xla/hlo/transforms/expanders/comparison_expander.h new file mode 100644 index 00000000000000..ed812c41728263 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/expanders/comparison_expander.h @@ -0,0 +1,58 @@ +/* Copyright 2020 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_EXPANDERS_COMPARISON_EXPANDER_H_ +#define XLA_HLO_TRANSFORMS_EXPANDERS_COMPARISON_EXPANDER_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" +#include "xla/primitive_util.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// A pass which performs expansion of the comparison operator to support total +// order comparison of floating point numbers. +class ComparisonExpander : public OpExpanderPass { + public: + explicit ComparisonExpander( + absl::Span> + expand_via_upcast = {}) + : expand_via_upcast_(expand_via_upcast.begin(), expand_via_upcast.end()) { + } + ~ComparisonExpander() override = default; + absl::string_view name() const override { return "comparison-expander"; } + + private: + // Returns `true` if `instruction` should be expanded by this pass. + bool InstructionMatchesPattern(HloInstruction* instruction) override; + // Returns a replacement for `instruction`, or nullptr if no replacement is + // needed (e.g. only the to_apply subcomputation of the instruction was + // modified). + absl::StatusOr ExpandInstruction( + HloInstruction* instruction) override; + + std::vector> expand_via_upcast_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_EXPANDERS_COMPARISON_EXPANDER_H_ diff --git a/third_party/xla/xla/service/convolution_4d_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/convolution_4d_expander.cc similarity index 99% rename from third_party/xla/xla/service/convolution_4d_expander.cc rename to third_party/xla/xla/hlo/transforms/expanders/convolution_4d_expander.cc index a8c9d29796a93f..a6c25114a4ce19 100644 --- a/third_party/xla/xla/service/convolution_4d_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/convolution_4d_expander.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/convolution_4d_expander.h" +#include "xla/hlo/transforms/expanders/convolution_4d_expander.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/expanders/convolution_4d_expander.h b/third_party/xla/xla/hlo/transforms/expanders/convolution_4d_expander.h new file mode 100644 index 00000000000000..e9804d8942a577 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/expanders/convolution_4d_expander.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_EXPANDERS_CONVOLUTION_4D_EXPANDER_H_ +#define XLA_HLO_TRANSFORMS_EXPANDERS_CONVOLUTION_4D_EXPANDER_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" + +namespace xla { + +class Convolution4DExpander : public OpExpanderPass { + public: + absl::string_view name() const override { return "convolution_4d_expander"; } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + absl::StatusOr ExpandInstruction( + HloInstruction* instruction) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_EXPANDERS_CONVOLUTION_4D_EXPANDER_H_ diff --git a/third_party/xla/xla/service/convolution_4d_expander_test.cc b/third_party/xla/xla/hlo/transforms/expanders/convolution_4d_expander_test.cc similarity index 97% rename from third_party/xla/xla/service/convolution_4d_expander_test.cc rename to third_party/xla/xla/hlo/transforms/expanders/convolution_4d_expander_test.cc index cfe5eece076698..39d7e3ebb9a9c1 100644 --- a/third_party/xla/xla/service/convolution_4d_expander_test.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/convolution_4d_expander_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/convolution_4d_expander.h" +#include "xla/hlo/transforms/expanders/convolution_4d_expander.h" #include #include @@ -21,14 +21,14 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/test.h" -#include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" namespace xla { namespace { -using Convolution4DExpanderTest = HloTestBase; +using Convolution4DExpanderTest = HloHardwareIndependentTestBase; TEST_F(Convolution4DExpanderTest, ConvertTo2DConvolution) { std::string hlo_string = R"(HloModule convolution_4d_fp32 diff --git a/third_party/xla/xla/service/convolution_pred_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/convolution_pred_expander.cc similarity index 97% rename from third_party/xla/xla/service/convolution_pred_expander.cc rename to third_party/xla/xla/hlo/transforms/expanders/convolution_pred_expander.cc index 99d666d0b4c1d6..74dab17666f236 100644 --- a/third_party/xla/xla/service/convolution_pred_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/convolution_pred_expander.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/convolution_pred_expander.h" +#include "xla/hlo/transforms/expanders/convolution_pred_expander.h" #include diff --git a/third_party/xla/xla/hlo/transforms/expanders/convolution_pred_expander.h b/third_party/xla/xla/hlo/transforms/expanders/convolution_pred_expander.h new file mode 100644 index 00000000000000..28750c72bee26c --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/expanders/convolution_pred_expander.h @@ -0,0 +1,44 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_EXPANDERS_CONVOLUTION_PRED_EXPANDER_H_ +#define XLA_HLO_TRANSFORMS_EXPANDERS_CONVOLUTION_PRED_EXPANDER_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" + +namespace xla { + +// A pass that rewrites boolean convolutions to floating point and converts the +// result back to boolean. This is necessary, as the convolutions on GPUs are +// implemented using custom call to cuDNN, which only supports FP and S8 inputs. +class ConvolutionPredExpander : public OpExpanderPass { + public: + absl::string_view name() const override { + return "convolution-pred-expander"; + } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + absl::StatusOr ExpandInstruction( + HloInstruction* instruction) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_EXPANDERS_CONVOLUTION_PRED_EXPANDER_H_ diff --git a/third_party/xla/xla/service/convolution_pred_expander_test.cc b/third_party/xla/xla/hlo/transforms/expanders/convolution_pred_expander_test.cc similarity index 92% rename from third_party/xla/xla/service/convolution_pred_expander_test.cc rename to third_party/xla/xla/hlo/transforms/expanders/convolution_pred_expander_test.cc index dd7d3124ffe9f1..e7aab8622b75f1 100644 --- a/third_party/xla/xla/service/convolution_pred_expander_test.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/convolution_pred_expander_test.cc @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/convolution_pred_expander.h" +#include "xla/hlo/transforms/expanders/convolution_pred_expander.h" #include #include #include #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -31,7 +31,7 @@ namespace { namespace m = match; -using ConvolutionPredExpanderTest = HloTestBase; +using ConvolutionPredExpanderTest = HloHardwareIndependentTestBase; TEST_F(ConvolutionPredExpanderTest, Match) { std::string hlo_string = R"(HloModule convolution_pred diff --git a/third_party/xla/xla/service/dot_decomposer.cc b/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer.cc similarity index 99% rename from third_party/xla/xla/service/dot_decomposer.cc rename to third_party/xla/xla/hlo/transforms/expanders/dot_decomposer.cc index 531c8899da0ffe..1df1743532438b 100644 --- a/third_party/xla/xla/service/dot_decomposer.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/dot_decomposer.h" +#include "xla/hlo/transforms/expanders/dot_decomposer.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer.h b/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer.h new file mode 100644 index 00000000000000..b399970fe5cffa --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_EXPANDERS_DOT_DECOMPOSER_H_ +#define XLA_HLO_TRANSFORMS_EXPANDERS_DOT_DECOMPOSER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// DotDecomposer is a pass which converts dots into a canonical form where +// non-contracting and contracting dimensions are reshaped together and batch +// dimensions are the most major dimensions. +class DotDecomposer : public HloModulePass { + public: + absl::string_view name() const override { return "dot_decomposer"; } + + // Run DotDecomposer pass on computations in 'module'. + // Returns whether the 'module' was changed. + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_EXPANDERS_DOT_DECOMPOSER_H_ diff --git a/third_party/xla/xla/service/dot_decomposer_test.cc b/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer_test.cc similarity index 97% rename from third_party/xla/xla/service/dot_decomposer_test.cc rename to third_party/xla/xla/hlo/transforms/expanders/dot_decomposer_test.cc index 70cd99d44d6bf3..ad8e6d874fd80d 100644 --- a/third_party/xla/xla/service/dot_decomposer_test.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/dot_decomposer_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/dot_decomposer.h" +#include "xla/hlo/transforms/expanders/dot_decomposer.h" #include @@ -23,10 +23,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -36,7 +36,7 @@ namespace { namespace m = ::xla::match; namespace op = ::xla::testing::opcode_matchers; -using DotDecomposerTest = HloTestBase; +using DotDecomposerTest = HloHardwareIndependentTestBase; TEST_F(DotDecomposerTest, CanonicalizeMultipleNonContractingDims) { absl::string_view module_string = R"( diff --git a/third_party/xla/xla/service/dynamic_index_splitter.cc b/third_party/xla/xla/hlo/transforms/expanders/dynamic_index_splitter.cc similarity index 98% rename from third_party/xla/xla/service/dynamic_index_splitter.cc rename to third_party/xla/xla/hlo/transforms/expanders/dynamic_index_splitter.cc index cf4e21c997e979..bf4ecc61bf6361 100644 --- a/third_party/xla/xla/service/dynamic_index_splitter.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/dynamic_index_splitter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/dynamic_index_splitter.h" +#include "xla/hlo/transforms/expanders/dynamic_index_splitter.h" #include diff --git a/third_party/xla/xla/hlo/transforms/expanders/dynamic_index_splitter.h b/third_party/xla/xla/hlo/transforms/expanders/dynamic_index_splitter.h new file mode 100644 index 00000000000000..26f68155ac71e6 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/expanders/dynamic_index_splitter.h @@ -0,0 +1,40 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_EXPANDERS_DYNAMIC_INDEX_SPLITTER_H_ +#define XLA_HLO_TRANSFORMS_EXPANDERS_DYNAMIC_INDEX_SPLITTER_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// Convert R1 index operands to DynamicSlice and DynamicUpdateSlice ops into +// separate scalars. +class DynamicIndexSplitter : public HloModulePass { + public: + DynamicIndexSplitter() = default; + absl::string_view name() const override { return "dynamic-index-splitter"; } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_EXPANDERS_DYNAMIC_INDEX_SPLITTER_H_ diff --git a/third_party/xla/xla/service/dynamic_index_splitter_test.cc b/third_party/xla/xla/hlo/transforms/expanders/dynamic_index_splitter_test.cc similarity index 96% rename from third_party/xla/xla/service/dynamic_index_splitter_test.cc rename to third_party/xla/xla/hlo/transforms/expanders/dynamic_index_splitter_test.cc index e3daa8a6ded591..b0699e5a07b6fc 100644 --- a/third_party/xla/xla/service/dynamic_index_splitter_test.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/dynamic_index_splitter_test.cc @@ -13,21 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/dynamic_index_splitter.h" +#include "xla/hlo/transforms/expanders/dynamic_index_splitter.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/test.h" #include "xla/test_helpers.h" -#include "xla/tests/hlo_test_base.h" namespace xla { namespace { namespace op = xla::testing::opcode_matchers; -class DynamicIndexSplitterTest : public HloTestBase {}; +class DynamicIndexSplitterTest : public HloHardwareIndependentTestBase {}; TEST_F(DynamicIndexSplitterTest, DynamicSlice) { const char* const kDynamicSlice = R"( diff --git a/third_party/xla/xla/service/eigh_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/eigh_expander.cc similarity index 96% rename from third_party/xla/xla/service/eigh_expander.cc rename to third_party/xla/xla/hlo/transforms/expanders/eigh_expander.cc index e95b268c1f3d8b..d7900a19fdbce0 100644 --- a/third_party/xla/xla/service/eigh_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/eigh_expander.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/eigh_expander.h" +#include "xla/hlo/transforms/expanders/eigh_expander.h" #include #include @@ -24,16 +24,17 @@ limitations under the License. #include #include "absl/status/statusor.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/comparators.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/loops.h" -#include "xla/client/lib/math.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/comparators.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/loops.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "xla/primitive_util.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/util.h" @@ -582,15 +583,8 @@ absl::StatusOr EighExpander::ExpandInstruction( } XlaOp result = BuildEigh(a, lower, max_iter, tol, sort_eigenvalues); TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build(result)); - - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - xla_computation.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( - xla_computation.proto(), config)); - HloCloneContext context(module); - computation = - module->DeepCloneComputation(new_module->entry_computation(), &context); + TF_ASSIGN_OR_RETURN( + computation, XlaComputationToHloComputation(xla_computation, module)); } return instruction->parent()->AddInstruction(HloInstruction::CreateCall( diff --git a/third_party/xla/xla/hlo/transforms/expanders/eigh_expander.h b/third_party/xla/xla/hlo/transforms/expanders/eigh_expander.h new file mode 100644 index 00000000000000..54cbee776d9c99 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/expanders/eigh_expander.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_EXPANDERS_EIGH_EXPANDER_H_ +#define XLA_HLO_TRANSFORMS_EXPANDERS_EIGH_EXPANDER_H_ + +#include "absl/container/flat_hash_map.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" + +namespace xla { + +class EighExpander : public OpExpanderPass { + public: + absl::string_view name() const override { return "eigh_expander"; } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + absl::StatusOr ExpandInstruction( + HloInstruction* instruction) override; + + virtual XlaOp BuildEigh(XlaOp a, bool lower, int64_t max_iter, float tol, + bool sort_eigenvalues); + + absl::Status SortByEigenvalues(XlaOp& v, XlaOp& w); + + private: + // Mapping from op signatures to existing computations. + absl::flat_hash_map computation_cache_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_EXPANDERS_EIGH_EXPANDER_H_ diff --git a/third_party/xla/xla/service/logistic_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/logistic_expander.cc similarity index 96% rename from third_party/xla/xla/service/logistic_expander.cc rename to third_party/xla/xla/hlo/transforms/expanders/logistic_expander.cc index 8423f93b0514e0..416d29ed6ef8fc 100644 --- a/third_party/xla/xla/service/logistic_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/logistic_expander.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/logistic_expander.h" +#include "xla/hlo/transforms/expanders/logistic_expander.h" #include diff --git a/third_party/xla/xla/hlo/transforms/expanders/logistic_expander.h b/third_party/xla/xla/hlo/transforms/expanders/logistic_expander.h new file mode 100644 index 00000000000000..fbfb1db901f051 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/expanders/logistic_expander.h @@ -0,0 +1,49 @@ +/* Copyright 2020 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_EXPANDERS_LOGISTIC_EXPANDER_H_ +#define XLA_HLO_TRANSFORMS_EXPANDERS_LOGISTIC_EXPANDER_H_ + +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" + +namespace xla { + +// A pass which performs expansion of the logistic function. +class LogisticExpander : public OpExpanderPass { + public: + LogisticExpander() = default; + ~LogisticExpander() override = default; + absl::string_view name() const override { return "logistic-expander"; } + + private: + // Returns `true` if `instruction` should be expanded by this pass. + bool InstructionMatchesPattern(HloInstruction* instruction) override; + // Returns a replacement for `instruction`, or nullptr if no replacement is + // needed (e.g. only the to_apply subcomputation of the instruction was + // modified). + absl::StatusOr ExpandInstruction( + HloInstruction* instruction) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_EXPANDERS_LOGISTIC_EXPANDER_H_ diff --git a/third_party/xla/xla/service/logistic_expander_test.cc b/third_party/xla/xla/hlo/transforms/expanders/logistic_expander_test.cc similarity index 91% rename from third_party/xla/xla/service/logistic_expander_test.cc rename to third_party/xla/xla/hlo/transforms/expanders/logistic_expander_test.cc index 57fce48fc8a325..fb5598524006f6 100644 --- a/third_party/xla/xla/service/logistic_expander_test.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/logistic_expander_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/logistic_expander.h" +#include "xla/hlo/transforms/expanders/logistic_expander.h" #include #include @@ -21,12 +21,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/service/dynamic_padder.h" -#include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/test.h" -#include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -35,8 +35,7 @@ namespace { namespace m = match; -class LogisticExpanderTest : public HloTestBase {}; - +class LogisticExpanderTest : public HloHardwareIndependentTestBase {}; // option is enabled. TEST_F(LogisticExpanderTest, ExpandWith) { diff --git a/third_party/xla/xla/service/op_expander_pass.cc b/third_party/xla/xla/hlo/transforms/expanders/op_expander_pass.cc similarity index 97% rename from third_party/xla/xla/service/op_expander_pass.cc rename to third_party/xla/xla/hlo/transforms/expanders/op_expander_pass.cc index 318211dce1f08a..25ba442542d2c8 100644 --- a/third_party/xla/xla/service/op_expander_pass.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/op_expander_pass.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/op_expander_pass.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/expanders/op_expander_pass.h b/third_party/xla/xla/hlo/transforms/expanders/op_expander_pass.h new file mode 100644 index 00000000000000..798c6a4ed46c06 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/expanders/op_expander_pass.h @@ -0,0 +1,63 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_EXPANDERS_OP_EXPANDER_PASS_H_ +#define XLA_HLO_TRANSFORMS_EXPANDERS_OP_EXPANDER_PASS_H_ + +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// This pass is an abstract superclass for passes that replace operations that +// match a pattern. It is intended to be subclassed, not used directly. +// +// This pass is useful for legalizing HLO instructions that a particular backend +// does not support into other HLO instructions. +class OpExpanderPass : public HloModulePass { + public: + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + // extra_filter: Optional extra filtering criteria for matching instructions, + // used in conjunction with InstructionMatchesPattern. + // preserve_sharding and relay_control_dependency: If we preserve sharding and + // relay control dependency when replacing the matched instructions. + explicit OpExpanderPass(HloPredicate extra_filter = nullptr, + bool preserve_sharding = false, + bool relay_control_dependency = false) + : extra_filter_(std::move(extra_filter)), + preserve_sharding_(preserve_sharding), + relay_control_dependency_(relay_control_dependency) {} + + protected: + // Returns `true` if `instruction` should be expanded by this pass. + virtual bool InstructionMatchesPattern(HloInstruction* instruction) = 0; + + // Returns a replacement for `instruction`, or nullptr if no replacement is + // needed (e.g. only the to_apply subcomputation of the instruction was + // modified). + virtual absl::StatusOr ExpandInstruction( + HloInstruction* instruction) = 0; + + HloPredicate extra_filter_; + const bool preserve_sharding_; + const bool relay_control_dependency_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_EXPANDERS_OP_EXPANDER_PASS_H_ diff --git a/third_party/xla/xla/service/optimization_barrier_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/optimization_barrier_expander.cc similarity index 96% rename from third_party/xla/xla/service/optimization_barrier_expander.cc rename to third_party/xla/xla/hlo/transforms/expanders/optimization_barrier_expander.cc index 877fcb1670236e..12908f26c8fbd8 100644 --- a/third_party/xla/xla/service/optimization_barrier_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/optimization_barrier_expander.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/optimization_barrier_expander.h" +#include "xla/hlo/transforms/expanders/optimization_barrier_expander.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/expanders/optimization_barrier_expander.h b/third_party/xla/xla/hlo/transforms/expanders/optimization_barrier_expander.h new file mode 100644 index 00000000000000..f6904ec0ff1b7e --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/expanders/optimization_barrier_expander.h @@ -0,0 +1,38 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_EXPANDERS_OPTIMIZATION_BARRIER_EXPANDER_H_ +#define XLA_HLO_TRANSFORMS_EXPANDERS_OPTIMIZATION_BARRIER_EXPANDER_H_ + +#include "xla/hlo/transforms/expanders/op_expander_pass.h" + +namespace xla { + +// This pass removes the opt-barrier operation which is functionally a no-op. +class OptimizationBarrierExpander : public HloModulePass { + public: + OptimizationBarrierExpander() = default; + + absl::string_view name() const override { return "cse_barrier_expander"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_EXPANDERS_OPTIMIZATION_BARRIER_EXPANDER_H_ diff --git a/third_party/xla/xla/service/qr_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/qr_expander.cc similarity index 96% rename from third_party/xla/xla/service/qr_expander.cc rename to third_party/xla/xla/hlo/transforms/expanders/qr_expander.cc index 4f79769d7c6bf8..1627a6be5e683b 100644 --- a/third_party/xla/xla/service/qr_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/qr_expander.cc @@ -13,22 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/qr_expander.h" +#include "xla/hlo/transforms/expanders/qr_expander.h" #include #include #include "absl/status/statusor.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/loops.h" -#include "xla/client/lib/math.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/lib/qr.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/loops.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/lib/qr.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/primitive_util.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/util.h" @@ -551,15 +552,8 @@ absl::StatusOr QrExpander::ExpandInstruction( } TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build(result)); - - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - xla_computation.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( - xla_computation.proto(), config)); - HloCloneContext context(module); - computation = - module->DeepCloneComputation(new_module->entry_computation(), &context); + TF_ASSIGN_OR_RETURN( + computation, XlaComputationToHloComputation(xla_computation, module)); } return instruction->parent()->AddInstruction(HloInstruction::CreateCall( diff --git a/third_party/xla/xla/hlo/transforms/expanders/qr_expander.h b/third_party/xla/xla/hlo/transforms/expanders/qr_expander.h new file mode 100644 index 00000000000000..8d7c4a8e90786b --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/expanders/qr_expander.h @@ -0,0 +1,57 @@ +/* Copyright 2020 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_EXPANDERS_QR_EXPANDER_H_ +#define XLA_HLO_TRANSFORMS_EXPANDERS_QR_EXPANDER_H_ + +#include "absl/container/flat_hash_map.h" +#include "xla/hlo/builder/lib/qr.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" + +namespace xla { + +class QrExpander : public OpExpanderPass { + public: + absl::string_view name() const override { return "qr_expander"; } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + absl::StatusOr ExpandInstruction( + HloInstruction* instruction) override; + + virtual absl::StatusOr QrBlock( + XlaOp a, PrecisionConfig::Precision precision); + + virtual absl::StatusOr CompactWYRepresentation( + PrimitiveType type, absl::Span batch_dims, XlaOp vs, + XlaOp taus, int64_t m, int64_t n, PrecisionConfig::Precision precision); + + private: + absl::StatusOr BuildQrDecomposition( + XlaOp a, int64_t block_size, PrecisionConfig::Precision precision); + + absl::StatusOr ProductOfElementaryHouseholderReflectors( + XlaOp a, XlaOp taus, int64_t block_size, + PrecisionConfig::Precision precision); + + // Mapping from op signatures to existing computations. + absl::flat_hash_map computation_cache_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_EXPANDERS_QR_EXPANDER_H_ diff --git a/third_party/xla/xla/service/real_imag_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/real_imag_expander.cc similarity index 96% rename from third_party/xla/xla/service/real_imag_expander.cc rename to third_party/xla/xla/hlo/transforms/expanders/real_imag_expander.cc index 6f50e250ec1140..33735a16f25e8b 100644 --- a/third_party/xla/xla/service/real_imag_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/real_imag_expander.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/real_imag_expander.h" +#include "xla/hlo/transforms/expanders/real_imag_expander.h" #include "xla/literal_util.h" diff --git a/third_party/xla/xla/hlo/transforms/expanders/real_imag_expander.h b/third_party/xla/xla/hlo/transforms/expanders/real_imag_expander.h new file mode 100644 index 00000000000000..52b50455744b27 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/expanders/real_imag_expander.h @@ -0,0 +1,37 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_EXPANDERS_REAL_IMAG_EXPANDER_H_ +#define XLA_HLO_TRANSFORMS_EXPANDERS_REAL_IMAG_EXPANDER_H_ + +#include "xla/hlo/transforms/expanders/op_expander_pass.h" + +namespace xla { + +// Expands real/image instructions with non-complex inputs. +class RealImagExpander : public OpExpanderPass { + public: + absl::string_view name() const override { return "real_imag_expander"; } + + protected: + bool InstructionMatchesPattern(HloInstruction* inst) override; + + absl::StatusOr ExpandInstruction( + HloInstruction* inst) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_EXPANDERS_REAL_IMAG_EXPANDER_H_ diff --git a/third_party/xla/xla/service/real_imag_expander_test.cc b/third_party/xla/xla/hlo/transforms/expanders/real_imag_expander_test.cc similarity index 95% rename from third_party/xla/xla/service/real_imag_expander_test.cc rename to third_party/xla/xla/hlo/transforms/expanders/real_imag_expander_test.cc index 429042745427f0..7f0042a5169db1 100644 --- a/third_party/xla/xla/service/real_imag_expander_test.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/real_imag_expander_test.cc @@ -13,13 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/real_imag_expander.h" +#include "xla/hlo/transforms/expanders/real_imag_expander.h" #include #include #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal.h" #include "xla/service/hlo_creation_utils.h" @@ -27,7 +28,6 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape_util.h" #include "xla/test.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" @@ -36,7 +36,7 @@ namespace { namespace m = match; -class RealImagExpanderTest : public HloTestBase {}; +class RealImagExpanderTest : public HloHardwareIndependentTestBase {}; TEST_F(RealImagExpanderTest, RealWithNonComplexInput) { const char* kModuleStr = R"( diff --git a/third_party/xla/xla/service/reduce_decomposer.cc b/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer.cc similarity index 98% rename from third_party/xla/xla/service/reduce_decomposer.cc rename to third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer.cc index 907fd824b750d4..de795a8f74989a 100644 --- a/third_party/xla/xla/service/reduce_decomposer.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer.cc @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/reduce_decomposer.h" +#include "xla/hlo/transforms/expanders/reduce_decomposer.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer.h b/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer.h new file mode 100644 index 00000000000000..46c2e7ddf6e429 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer.h @@ -0,0 +1,80 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_EXPANDERS_REDUCE_DECOMPOSER_H_ +#define XLA_HLO_TRANSFORMS_EXPANDERS_REDUCE_DECOMPOSER_H_ + +#include + +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// For each reduction R(I), ensures the postcondition: +// +// !custom_layout_allowed(R) +// => +// layout(R) == layout(I) # modulo removed dimensions +// +// To achieve that, decomposes layout-mutating reductions which do not satisfy +// `custom_layout_allowed` into a reduction and a copy. +// +// For a singular reduction: +// +// -> reduce -> +// +// Gets turned into: +// +// -> reduce -> copy -> +// +// For a variadic recuction, the layout assignment guarantees that the layout +// is the same for all outputs. This pass will transpose the variadic reduction +// inputs which have different physical layout to the first operand. +// +// A{L} \ +// B{L} -> reduce{L'} -> +// C{L} / +// +// Get turned into: +// +// A{L} \ / GTE(1) -> copy{L'} \ +// B{L} -> reduce{E(L)} --- GTE(2) -> copy{L'} - Tuple{L'} +// C{L} / \ GTE(3) -> copy{L'} / +// +// Where E(L) is expected layout of a reduction (original layout with reduce +// dimensions dropped). +// +// PRECONDITION: +// In variadic reduction, all outputs have the same layout +// (enforced by layout assignment). +class ReduceDecomposer : public HloModulePass { + public: + explicit ReduceDecomposer(HloPredicate custom_layout_allowed = nullptr) + : custom_layout_allowed_(custom_layout_allowed) {} + + absl::string_view name() const override { return "reduce-decomposer"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + HloPredicate custom_layout_allowed_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_EXPANDERS_REDUCE_DECOMPOSER_H_ diff --git a/third_party/xla/xla/service/reduce_decomposer_test.cc b/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer_test.cc similarity index 96% rename from third_party/xla/xla/service/reduce_decomposer_test.cc rename to third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer_test.cc index 54d290ec9e4418..997ea50e51b565 100644 --- a/third_party/xla/xla/service/reduce_decomposer_test.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/reduce_decomposer_test.cc @@ -12,22 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/reduce_decomposer.h" +#include "xla/hlo/transforms/expanders/reduce_decomposer.h" #include #include #include -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/test.h" #include "xla/test_helpers.h" -#include "xla/tests/filecheck.h" -#include "xla/tests/hlo_test_base.h" namespace xla { namespace { -class ReduceDecomposerTest : public HloTestBase {}; +class ReduceDecomposerTest : public HloHardwareIndependentTestBase {}; TEST_F(ReduceDecomposerTest, ReducePerformsTransposition) { // Reshape is already a bitcast, nothing should be changed. diff --git a/third_party/xla/xla/service/reshape_decomposer.cc b/third_party/xla/xla/hlo/transforms/expanders/reshape_decomposer.cc similarity index 98% rename from third_party/xla/xla/service/reshape_decomposer.cc rename to third_party/xla/xla/hlo/transforms/expanders/reshape_decomposer.cc index 60ccb36adc8a68..ac0b058426a67e 100644 --- a/third_party/xla/xla/service/reshape_decomposer.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/reshape_decomposer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/reshape_decomposer.h" +#include "xla/hlo/transforms/expanders/reshape_decomposer.h" #include "absl/status/status.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" diff --git a/third_party/xla/xla/hlo/transforms/expanders/reshape_decomposer.h b/third_party/xla/xla/hlo/transforms/expanders/reshape_decomposer.h new file mode 100644 index 00000000000000..1efa0cbf2c7ef2 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/expanders/reshape_decomposer.h @@ -0,0 +1,40 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_EXPANDERS_RESHAPE_DECOMPOSER_H_ +#define XLA_HLO_TRANSFORMS_EXPANDERS_RESHAPE_DECOMPOSER_H_ + +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// Decomposes a reshape which does not satisfy the ReshapeIsBitcast precondition +// into a bitcast and a copy (physical transposition). Tries to create only one +// transposition, but when it's not possible, creates two. +// +// Postcondition: All reshapes are turned into bitcasts. +class ReshapeDecomposer : public HloModulePass { + public: + absl::string_view name() const override { return "reshape-decomposer"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_EXPANDERS_RESHAPE_DECOMPOSER_H_ diff --git a/third_party/xla/xla/service/reshape_decomposer_test.cc b/third_party/xla/xla/hlo/transforms/expanders/reshape_decomposer_test.cc similarity index 93% rename from third_party/xla/xla/service/reshape_decomposer_test.cc rename to third_party/xla/xla/hlo/transforms/expanders/reshape_decomposer_test.cc index 94d135e8b15b0d..87cf748818069e 100644 --- a/third_party/xla/xla/service/reshape_decomposer_test.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/reshape_decomposer_test.cc @@ -12,21 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/reshape_decomposer.h" +#include "xla/hlo/transforms/expanders/reshape_decomposer.h" #include #include -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/test.h" #include "xla/test_helpers.h" -#include "xla/tests/filecheck.h" -#include "xla/tests/hlo_test_base.h" namespace xla { namespace { -class ReshapeDecomposerTest : public HloTestBase { +class ReshapeDecomposerTest : public HloHardwareIndependentTestBase { public: // Runs reshape decomposer, if `expected` is present, checks it with FileCheck // on the output, otherwise checks that the module has not changed. diff --git a/third_party/xla/xla/service/rng_bit_generator_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/rng_bit_generator_expander.cc similarity index 87% rename from third_party/xla/xla/service/rng_bit_generator_expander.cc rename to third_party/xla/xla/hlo/transforms/expanders/rng_bit_generator_expander.cc index 0d78762f47b964..045784c50b50c9 100644 --- a/third_party/xla/xla/service/rng_bit_generator_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/rng_bit_generator_expander.cc @@ -13,16 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/rng_bit_generator_expander.h" +#include "xla/hlo/transforms/expanders/rng_bit_generator_expander.h" #include "absl/status/statusor.h" -#include "xla/client/lib/prng.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/prng.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" @@ -86,15 +87,8 @@ RngBitGeneratorExpander::GetGeneratorComputation(const Shape& data_shape, ConcatInDim(&builder, {Reshape(key_op, {1}), output.state}, 0); Tuple(&builder, {final_state, output.value}); TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); - - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - xla_computation.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( - xla_computation.proto(), config)); - HloCloneContext context(module); - HloComputation* new_computation = - module->DeepCloneComputation(new_module->entry_computation(), &context); + TF_ASSIGN_OR_RETURN(HloComputation * new_computation, + XlaComputationToHloComputation(xla_computation, module)); computation_cache_.emplace(cache_key, new_computation); return new_computation; } diff --git a/third_party/xla/xla/hlo/transforms/expanders/rng_bit_generator_expander.h b/third_party/xla/xla/hlo/transforms/expanders/rng_bit_generator_expander.h new file mode 100644 index 00000000000000..15df45060052b5 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/expanders/rng_bit_generator_expander.h @@ -0,0 +1,72 @@ +/* Copyright 2020 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_EXPANDERS_RNG_BIT_GENERATOR_EXPANDER_H_ +#define XLA_HLO_TRANSFORMS_EXPANDERS_RNG_BIT_GENERATOR_EXPANDER_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" +#include "xla/shape_util.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +class RngBitGeneratorExpander : public OpExpanderPass { + public: + explicit RngBitGeneratorExpander(RandomAlgorithm default_algorithm) + : default_algorithm_(default_algorithm) { + CHECK_NE(default_algorithm_, RandomAlgorithm::RNG_DEFAULT); + } + + absl::string_view name() const override { + return "rng-bit-generator-expander"; + } + + protected: + struct RngGeneratorKey { + Shape data_shape; + Shape state_shape; + RandomAlgorithm algorithm; + HloModule* module; + + template + friend H AbslHashValue(H h, const RngGeneratorKey& c) { + return H::combine(std::move(h), c.state_shape, c.data_shape, c.algorithm, + c.module); + } + + bool operator==(const RngGeneratorKey& o) const { + return data_shape == o.data_shape && state_shape == o.state_shape && + algorithm == o.algorithm && module == o.module; + } + }; + + bool InstructionMatchesPattern(HloInstruction* instruction) override; + absl::StatusOr ExpandInstruction( + HloInstruction* hlo) override; + absl::StatusOr GetGeneratorComputation( + const Shape& data_shape, const Shape& state_shape, + RandomAlgorithm algorithm, HloModule* module); + + const RandomAlgorithm default_algorithm_; + absl::flat_hash_map computation_cache_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_EXPANDERS_RNG_BIT_GENERATOR_EXPANDER_H_ diff --git a/third_party/xla/xla/service/rng_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/rng_expander.cc similarity index 91% rename from third_party/xla/xla/service/rng_expander.cc rename to third_party/xla/xla/hlo/transforms/expanders/rng_expander.cc index cbc5a1d4549db9..2667440674887a 100644 --- a/third_party/xla/xla/service/rng_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/rng_expander.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/rng_expander.h" +#include "xla/hlo/transforms/expanders/rng_expander.h" #include -#include "xla/client/lib/prng.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/prng.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "xla/primitive_util.h" #include "xla/service/hlo_creation_utils.h" @@ -111,16 +111,7 @@ absl::StatusOr GetComputationForRng(HloInstruction* rng) { } TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); - - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - xla_computation.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( - xla_computation.proto(), config)); - HloModule* module = rng->GetModule(); - HloCloneContext context(module); - return module->DeepCloneComputation(new_module->entry_computation(), - &context); + return XlaComputationToHloComputation(xla_computation, rng->GetModule()); } } // namespace diff --git a/third_party/xla/xla/hlo/transforms/expanders/rng_expander.h b/third_party/xla/xla/hlo/transforms/expanders/rng_expander.h new file mode 100644 index 00000000000000..e6c52cf1143a44 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/expanders/rng_expander.h @@ -0,0 +1,43 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_EXPANDERS_RNG_EXPANDER_H_ +#define XLA_HLO_TRANSFORMS_EXPANDERS_RNG_EXPANDER_H_ + +#include "xla/hlo/transforms/expanders/op_expander_pass.h" + +namespace xla { + +class RngExpander : public OpExpanderPass { + public: + absl::string_view name() const override { return "rng-expander"; } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + absl::StatusOr ExpandInstruction( + HloInstruction* rng) override; + + private: + // Cache RNG computations based on the distribution, output shape and shapes + // of the first and second operand. + absl::flat_hash_map, + HloComputation*> + expanded_rng_instructions_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_EXPANDERS_RNG_EXPANDER_H_ diff --git a/third_party/xla/xla/service/stable_sort_expander.cc b/third_party/xla/xla/hlo/transforms/expanders/stable_sort_expander.cc similarity index 99% rename from third_party/xla/xla/service/stable_sort_expander.cc rename to third_party/xla/xla/hlo/transforms/expanders/stable_sort_expander.cc index ca87dce4df65a7..3df7d03a2b0024 100644 --- a/third_party/xla/xla/service/stable_sort_expander.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/stable_sort_expander.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/stable_sort_expander.h" +#include "xla/hlo/transforms/expanders/stable_sort_expander.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/expanders/stable_sort_expander.h b/third_party/xla/xla/hlo/transforms/expanders/stable_sort_expander.h new file mode 100644 index 00000000000000..210eaeb1a17b74 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/expanders/stable_sort_expander.h @@ -0,0 +1,51 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_EXPANDERS_STABLE_SORT_EXPANDER_H_ +#define XLA_HLO_TRANSFORMS_EXPANDERS_STABLE_SORT_EXPANDER_H_ + +#include + +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" + +namespace xla { + +// HLO pass which expands Sort ops that have the is_stable field set to true +// into equivalent Sort ops which guarantee stable sorting without relying on +// the is_stable field. +class StableSortExpander : public OpExpanderPass { + public: + absl::string_view name() const override { return "stable-sort-expander"; } + + // Returns the index of the sort operand that is an iota op with an iota + // dimension which is the same as the dimension to sort. Also it should have + // an integral type that is large enough for the number of elements in the + // sort dimension. For now, we only allow S32, because we expect to find a S32 + // iota operand for all Sort ops which are created by TopK. + // + // If no operand of the input sort matches the conditions above, returns -1. + static int64_t IotaOperandIndexForStableSort(const HloSortInstruction& sort); + + private: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + absl::StatusOr ExpandInstruction( + HloInstruction* instruction) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_EXPANDERS_STABLE_SORT_EXPANDER_H_ diff --git a/third_party/xla/xla/service/stable_sort_expander_test.cc b/third_party/xla/xla/hlo/transforms/expanders/stable_sort_expander_test.cc similarity index 97% rename from third_party/xla/xla/service/stable_sort_expander_test.cc rename to third_party/xla/xla/hlo/transforms/expanders/stable_sort_expander_test.cc index 83ba193ede5aef..a3b40831a24f5e 100644 --- a/third_party/xla/xla/service/stable_sort_expander_test.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/stable_sort_expander_test.cc @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/stable_sort_expander.h" +#include "xla/hlo/transforms/expanders/stable_sort_expander.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/service/algebraic_simplifier.h" -#include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/test.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" namespace xla { @@ -29,7 +29,7 @@ namespace { namespace m = match; -using StableSortExpanderTest = HloTestBase; +using StableSortExpanderTest = HloHardwareIndependentTestBase; // Checks whether 'a' and 'b' are roots of equivalent computations, except that // parameters 2 * i and 2 * i + 1 are switched. diff --git a/third_party/xla/xla/service/stochastic_convert_decomposer.cc b/third_party/xla/xla/hlo/transforms/expanders/stochastic_convert_decomposer.cc similarity index 98% rename from third_party/xla/xla/service/stochastic_convert_decomposer.cc rename to third_party/xla/xla/hlo/transforms/expanders/stochastic_convert_decomposer.cc index d7746a2fb838bb..1fb054159d7848 100644 --- a/third_party/xla/xla/service/stochastic_convert_decomposer.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/stochastic_convert_decomposer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/stochastic_convert_decomposer.h" +#include "xla/hlo/transforms/expanders/stochastic_convert_decomposer.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/expanders/stochastic_convert_decomposer.h b/third_party/xla/xla/hlo/transforms/expanders/stochastic_convert_decomposer.h new file mode 100644 index 00000000000000..835a55be249c7c --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/expanders/stochastic_convert_decomposer.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_EXPANDERS_STOCHASTIC_CONVERT_DECOMPOSER_H_ +#define XLA_HLO_TRANSFORMS_EXPANDERS_STOCHASTIC_CONVERT_DECOMPOSER_H_ + +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// StochasticConvertDecomposer is a pass which replaces unsupported +// stochastic-convert with multiple hlos. +class StochasticConvertDecomposer : public HloModulePass { + public: + absl::string_view name() const override { + return "stochastic_convert_decomposer"; + } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_EXPANDERS_STOCHASTIC_CONVERT_DECOMPOSER_H_ diff --git a/third_party/xla/xla/service/stochastic_convert_decomposer_test.cc b/third_party/xla/xla/hlo/transforms/expanders/stochastic_convert_decomposer_test.cc similarity index 94% rename from third_party/xla/xla/service/stochastic_convert_decomposer_test.cc rename to third_party/xla/xla/hlo/transforms/expanders/stochastic_convert_decomposer_test.cc index 48ac6e61ce8e9c..8ebc1b448e09a2 100644 --- a/third_party/xla/xla/service/stochastic_convert_decomposer_test.cc +++ b/third_party/xla/xla/hlo/transforms/expanders/stochastic_convert_decomposer_test.cc @@ -13,22 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/stochastic_convert_decomposer.h" +#include "xla/hlo/transforms/expanders/stochastic_convert_decomposer.h" #include #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/service/hlo_parser.h" -#include "xla/tests/hlo_test_base.h" namespace xla { namespace { namespace op = xla::testing::opcode_matchers; -using StochasticConvertDecomposerTest = HloTestBase; +using StochasticConvertDecomposerTest = HloHardwareIndependentTestBase; using ::testing::HasSubstr; TEST_F(StochasticConvertDecomposerTest, DecomposeStochasticConvertF32ToS32) { diff --git a/third_party/xla/xla/service/host_offload_legalize.cc b/third_party/xla/xla/hlo/transforms/host_offload_legalize.cc similarity index 96% rename from third_party/xla/xla/service/host_offload_legalize.cc rename to third_party/xla/xla/hlo/transforms/host_offload_legalize.cc index e4f9a3a227cd44..296fa1a9b2e93e 100644 --- a/third_party/xla/xla/service/host_offload_legalize.cc +++ b/third_party/xla/xla/hlo/transforms/host_offload_legalize.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/host_offload_legalize.h" +#include "xla/hlo/transforms/host_offload_legalize.h" #include #include @@ -245,7 +245,8 @@ absl::StatusOr WalkUpMemoryOffload( // instruction at a time, but returns multiple instructions for each conforming // user. absl::StatusOr> WalkDownMemoryOffload( - const InstructionAndIndex& current_value, const CallGraph& call_graph) { + const InstructionAndIndex& current_value, const CallGraph& call_graph, + bool for_move_copy_phase) { // TODO(maggioni): Verify that set of instructions supported in chain by // legalization is in sync with host_offloader. VLOG(6) << "Getting users of: \"" << current_value.instruction->ToString() @@ -348,8 +349,23 @@ absl::StatusOr> WalkDownMemoryOffload( results.emplace_back(user, current_value.index); break; } + case HloOpcode::kAsyncStart: { + if (user->async_execution_thread() == HloInstruction::kHostThread) { + // For move copy phase, we need to handle the copy even though we + // never move the tensor to device yet. For now just throw an error. + CHECK(!for_move_copy_phase) + << "Transpose copy going into host call is not supported yet."; + + // For first phase to collect copies to move, it's ok to ignore this + // path since we don't see copies along the path yet and it's ok to + // pass host tensor to the async host call. + break; + } + [[fallthrough]]; + } default: { - return absl::InvalidArgumentError("Unrecognized user opcode"); + return absl::InvalidArgumentError( + absl::StrFormat("Unrecognized user name: %s", user->name())); } } } @@ -423,11 +439,12 @@ absl::Status MoveCopy( current_instruction_and_shapes.instruction_and_index; stack.pop_back(); VLOG(5) << "Current top of stack: " - << current_instruction_and_index.instruction->ToString() << " " - << current_instruction_and_index.index; + << current_instruction_and_index.instruction->ToString() + << ", index: " << current_instruction_and_index.index; // Get the users of the current instruction. absl::StatusOr> current_value_down = - WalkDownMemoryOffload(current_instruction_and_index, *call_graph); + WalkDownMemoryOffload(current_instruction_and_index, *call_graph, + /*for_move_copy_phase=*/true); if (!current_value_down.ok()) { VLOG(5) << "WalkDownMemoryOffload failed: " << current_value_down.status(); @@ -677,7 +694,8 @@ absl::StatusOr ProcessAnnotationForCopyMovement( std::vector stack = {current_value}; while (!stack.empty()) { VLOG(5) << "Current value before down: " - << stack.back().instruction->ToString(); + << stack.back().instruction->ToString() << " " + << stack.back().index; if (absl::c_linear_search(kUsersOpcodes, stack.back().instruction->opcode()) || stack.back().instruction->IsCustomCall( @@ -737,7 +755,8 @@ absl::StatusOr ProcessAnnotationForCopyMovement( continue; } absl::StatusOr> current_value_down = - WalkDownMemoryOffload(stack.back(), *call_graph); + WalkDownMemoryOffload(stack.back(), *call_graph, + /*for_move_copy_phase=*/false); if (!current_value_down.ok()) { VLOG(5) << "Current value down failed: " << current_value_down.status(); break; @@ -758,6 +777,10 @@ absl::StatusOr ProcessAnnotationForCopyMovement( } } + if (copies_to_move.empty()) { + return false; + } + // Process all copies one at a time from the last to the first and push it to // its specific user. for (auto it = copies_to_move.rbegin(); it != copies_to_move.rend(); ++it) { diff --git a/third_party/xla/xla/hlo/transforms/host_offload_legalize.h b/third_party/xla/xla/hlo/transforms/host_offload_legalize.h new file mode 100644 index 00000000000000..a5d85fa40a8a5c --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/host_offload_legalize.h @@ -0,0 +1,62 @@ +/* Copyright 2024 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#ifndef XLA_HLO_TRANSFORMS_HOST_OFFLOAD_LEGALIZE_H_ +#define XLA_HLO_TRANSFORMS_HOST_OFFLOAD_LEGALIZE_H_ + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +class HloCostAnalysis; + +// This pass legalizes the graph for the "host memory offloading" pass to +// correctly identified buffers that are meant to be move on the host. Any +// legalization that could block that is welcome into this pass. +class HostOffloadLegalize : public HloModulePass { + public: + explicit HostOffloadLegalize(int64_t host_memory_space_color, + bool after_layout) + : kHostMemorySpaceColor(host_memory_space_color), + after_layout_(after_layout) {} + ~HostOffloadLegalize() override = default; + + absl::string_view name() const override { return "host-offload-legalize"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + const int64_t kHostMemorySpaceColor; + const bool after_layout_; + + // For any memory offloaded to the host, return the instruction which is the + // start of such and offload. These will either be "MoveToHost" annotations or + // entry computation parameters. + std::vector FindStartingInstructionsOfHostMemoryOffload( + HloModule* module, + const absl::flat_hash_set& execution_threads) const; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_HOST_OFFLOAD_LEGALIZE_H_ diff --git a/third_party/xla/xla/service/host_offload_legalize_test.cc b/third_party/xla/xla/hlo/transforms/host_offload_legalize_test.cc similarity index 94% rename from third_party/xla/xla/service/host_offload_legalize_test.cc rename to third_party/xla/xla/hlo/transforms/host_offload_legalize_test.cc index 55c36a5310f6db..bb5151bc243ec5 100644 --- a/third_party/xla/xla/service/host_offload_legalize_test.cc +++ b/third_party/xla/xla/hlo/transforms/host_offload_legalize_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/host_offload_legalize.h" +#include "xla/hlo/transforms/host_offload_legalize.h" #include #include @@ -27,12 +27,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/service/host_memory_offload_annotations.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "tsl/platform/statusor.h" @@ -40,7 +40,7 @@ limitations under the License. namespace xla { namespace { -class HostOffloadLegalizeTest : public HloTestBase { +class HostOffloadLegalizeTest : public HloHardwareIndependentTestBase { protected: static constexpr int64_t kHostMemorySpaceColor{5}; @@ -74,6 +74,36 @@ class HostOffloadLegalizeTest : public HloTestBase { } }; +TEST_F(HostOffloadLegalizeTest, TestWithAsyncCall) { + const std::string& hlo_string = R"( +HloModule jit_update, entry_computation_layout={(f32[20,3,256,133]{2,3,1,0:T(8,128)S(5)})->(f32[20,3,256,133]{2,1,0,3:T(4,128)}, f32[4096]{0:T(1024)})} + +%async_computation { + %param_0 = f32[20,3,256,133] parameter(0) + ROOT %offloaded-custom-call = f32[4096] custom-call(%param_0), custom_call_target="HostExecute" +}, execution_thread="host" + +ENTRY main { + %param.246 = f32[20,3,256,133] parameter(0) + %async-start = ((f32[20,3,256,133]), f32[4096], u32[]) async-start(%param.246), async_execution_thread="host", calls=%async_computation + %async-done = f32[4096] custom-call-done(%async-start) + copy.16744 = f32[20,3,256,133]{2,1,0,3:T(4,128)} copy(param.246) + custom-call.7832 = f32[20,3,256,133]{2,1,0,3:T(4,128)} custom-call(copy.16744), custom_call_target="MoveToDevice" + ROOT tuple.16745 = (f32[20,3,256,133]{2,1,0,3:T(4,128)}, f32[4096]{0:T(1024)}) tuple(custom-call.7832, %async-done) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloadLegalize(module.get())); + EXPECT_TRUE(changed); + HloInstruction* custom_call = + FindInstruction(module.get(), "custom-call.7832"); + ASSERT_NE(custom_call, nullptr); + EXPECT_EQ(custom_call->users()[0]->opcode(), HloOpcode::kCopy); + XLA_VLOG_LINES(1, module->ToString()); +} + TEST_F(HostOffloadLegalizeTest, NoCopyWithOptBarrierMoreElaborate) { const std::string& hlo_string = R"( HloModule jit_f, entry_computation_layout={(f32[16,256]{0,1})->f32[16,256]{1,0}} diff --git a/third_party/xla/xla/service/host_offloader.cc b/third_party/xla/xla/hlo/transforms/host_offloader.cc similarity index 96% rename from third_party/xla/xla/service/host_offloader.cc rename to third_party/xla/xla/hlo/transforms/host_offloader.cc index 5cc3cdc13b24af..f24f85d5f13c64 100644 --- a/third_party/xla/xla/service/host_offloader.cc +++ b/third_party/xla/xla/hlo/transforms/host_offloader.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/host_offloader.h" +#include "xla/hlo/transforms/host_offloader.h" #include #include @@ -36,6 +36,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -43,7 +44,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal_util.h" #include "xla/service/call_graph.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" #include "xla/service/hlo_cse.h" #include "xla/service/hlo_value.h" @@ -53,6 +53,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" +#include "xla/side_effect_util.h" #include "xla/status_macros.h" #include "xla/util.h" #include "tsl/platform/errors.h" @@ -91,6 +92,14 @@ bool SetBuffersToMemorySpaceColor( return changed; } +void SetHostComputeFrontendAttribute(HloInstruction& host_instruction) { + FrontendAttributes frontend_attributes = + host_instruction.frontend_attributes(); + frontend_attributes.mutable_map()->insert( + {kXlaComputeTypeAttr, kXlaComputeTypeHost}); + host_instruction.set_frontend_attributes(frontend_attributes); +} + } // namespace bool HostOffloader::InstructionIsAllowedBetweenMoveToHostAndDus( @@ -201,6 +210,14 @@ absl::StatusOr HostOffloader::WalkDownHostMemoryOffloadPaths( // memory. slices_to_dynamify.insert(instruction); continue; + } else if (instruction->opcode() == HloOpcode::kAllGather || + instruction->opcode() == HloOpcode::kAllReduce) { + LOG(WARNING) << absl::StreamFormat( + "Found an instruction (\"%s\") which does device compute in host " + "memory space. Converting into host compute. This is likely to have " + "a very high overhead.", + instruction->name()); + SetHostComputeFrontendAttribute(*instruction); } else { // Found an instruction which is invalid during host memory offloading. return absl::InvalidArgumentError( @@ -1007,6 +1024,31 @@ absl::StatusOr HostOffloader::HandleRedundantCopiesBackToHost( return UpdateMemorySpaceForHostOffloadedOutputs(call_start, host_instrs_tree); } +absl::StatusOr HostOffloader::ProcessNextMoveToHostInstr( + HloComputation* computation) { + for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { + if (instruction->IsCustomCall( + host_memory_offload_annotations::kMoveToHostCustomCallTarget)) { + TF_ASSIGN_OR_RETURN(bool removed_move_to_host, + HandleMoveToHostCustomCall(instruction)); + if (removed_move_to_host) { + return true; + } + } + + if (instruction->has_called_computations()) { + for (HloComputation* called_comp : instruction->called_computations()) { + TF_ASSIGN_OR_RETURN(bool removed_move_to_host, + ProcessNextMoveToHostInstr(called_comp)); + if (removed_move_to_host) { + return true; + } + } + } + } + return false; +} + absl::StatusOr HostOffloader::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { @@ -1038,26 +1080,10 @@ absl::StatusOr HostOffloader::Run( // Iterate over the computations in the order that they are executed. This // ensures we process "MoveToHost" instructions that are at the beginning of // a host memory offload instruction chain. - std::vector post_order_computations = - module->MakeComputationPostOrder(execution_threads); - for (auto it = post_order_computations.rbegin(); - it != post_order_computations.rend(); ++it) { - HloComputation* computation = *it; - for (HloInstruction* instruction : - computation->MakeInstructionPostOrder()) { - if (instruction->IsCustomCall( - host_memory_offload_annotations::kMoveToHostCustomCallTarget)) { - TF_ASSIGN_OR_RETURN(changed_in_loop, - HandleMoveToHostCustomCall(instruction)); - if (changed_in_loop) { - changed = true; - break; - } - } - } - if (changed_in_loop) { - break; - } + TF_ASSIGN_OR_RETURN(changed_in_loop, ProcessNextMoveToHostInstr( + module->entry_computation())); + if (changed_in_loop) { + changed = true; } } while (changed_in_loop); diff --git a/third_party/xla/xla/hlo/transforms/host_offloader.h b/third_party/xla/xla/hlo/transforms/host_offloader.h new file mode 100644 index 00000000000000..a4d7a755c8302e --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/host_offloader.h @@ -0,0 +1,171 @@ +/* Copyright 2024 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#ifndef XLA_HLO_TRANSFORMS_HOST_OFFLOADER_H_ +#define XLA_HLO_TRANSFORMS_HOST_OFFLOADER_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/service/hlo_buffer.h" +#include "xla/service/host_offload_utils.h" + +namespace xla { + +class HloCostAnalysis; + +// This pass does "host memory offloading". If a tensor is annotated to be moved +// to or from the host, this pass will remove the annotations and update each +// tensor's layout with host memory spaces and insert copies if necessary. This +// pass checks to make sure that no compute is done on the tensors annotated for +// host memory offload; if there is compute, it is considered a user error and +// an error will be returned. +// The pass will "walk down" the Hlo graph starting from either MoveToHost +// custom calls or from parameters with host memory space in their layout. All +// tensors along each path have their memory space set as host memory space. If +// a MoveToHost custom call is paired with a DynamicUpdateSlice, the +// DynamicUpdateSlice will write into host memory space. Otherwise, a copy from +// device to host will be inserted. +// +// If an output of a host offloaded computation is only used on host, the memory +// space of the usages are updated to reflect it and no copies to and from host +// are performed. Any MoveToHost instructions for outputs used only on host, are +// removed. +// TODO(b/347101407): A better approach could be to remove redundant copies in a +// generalized fashion. Should also be moved out of Host Offloader. +// +// All MoveToHost and MoveToDevice custom calls are removed by the end of this +// pass. +class HostOffloader : public HloModulePass { + public: + explicit HostOffloader(int64_t host_memory_space_color) + : kHostMemorySpaceColor(host_memory_space_color) {} + ~HostOffloader() override = default; + + absl::string_view name() const override { return "host-offloader"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + // Process the next "MoveToHost" instruction that resides at the beginning of + // a host memory offload instruction chain. This ensures that redundant + // "MoveToHost" (those already residing inside of a host memory offload + // instruction chain) are ignored. + absl::StatusOr ProcessNextMoveToHostInstr(HloComputation* computation); + + const int64_t kHostMemorySpaceColor; + absl::flat_hash_set + already_visited_move_to_host_custom_calls_; + absl::flat_hash_set dynamic_update_slices_already_allocated_; + absl::flat_hash_set validated_slices_; + absl::flat_hash_map copies_created_after_; + absl::flat_hash_set move_to_device_custom_calls_to_remove_; + absl::flat_hash_set + already_inserted_copy_before_; + + // Sometimes previous transformations turn a DynamicSlice into a Slice. Since + // we're doing a DMA between the host and device, we need to turn the Slice + // back into a DynamicSlice. + absl::StatusOr DynamifySlice(HloInstruction* slice); + + // Returns true if the instruction is allowed to be in the + // middle of a path between a MoveToHost custom-call annotation and a + // DynamicUpdateSlice. Ideally the custom-call should be immediately followed + // by the DynamicUpdateSlice, but this is not always the case. + bool InstructionIsAllowedBetweenMoveToHostAndDus( + const HloInstruction* instruction) const; + + // Returns true if the instruction is allowed to be in the + // middle of a path between a DynamicSlice and a MoveToDevice custom-call + // annotation. Ideally the DynamicSlice should be immediately followed by the + // custom-call, but this is not always the case. + bool InstructionIsAllowedBetweenDsAndMoveToDevice( + const HloInstruction* instruction) const; + + // Walks down the graph and does "host memory offloading" starting from every + // host memory parameter in the entry computation. + absl::StatusOr HandleInputStreaming(HloComputation* entry_computation); + + // Walks down the graph and does "host memory offloading" starting from every + // MoveToHost custom call. + absl::StatusOr HandleMoveToHostCustomCall( + HloInstruction* custom_call_instruction); + + // Since we always walk the graph from the top down, this function only needs + // to remove these lingering custom calls. This function should only be called + // once all host memory offloading is done because multiple paths might lead + // to the same MoveToDevice custom call. Removing it too early will confuse + // subsequent walkings of the graph. + absl::StatusOr HandleMoveToDeviceCustomCall( + HloInstruction* custom_call_instruction); + + // DynamicUpdateSlices which write into host memory must have their + // destination buffer allocated on the host. This function creates the + // allocation and updates all positions to have host memory space. + absl::Status CreateAllocateBufferForDynamicUpdateSlice( + HloInstruction* dynamic_update_slice); + + // Returns an error if something unallowed exists between the + // Slice/DynamicSlice and the MoveToDevice custom call. + absl::Status ValidateSliceLeadsToMoveToDeviceCustomCall( + HloInstruction* slice); + + // Common function for doing the actual walking of the graph. Host memory + // spaces are set and copies are inserted in here. + absl::StatusOr WalkDownHostMemoryOffloadPaths( + const host_offload_utils::InstructionAndShapeIndex& + starting_instruction_and_index, + bool insert_copy_before); + + // Given a custom call, this returns the first instruction and shape index to + // start the host memory offload path from for each use of the custom call. + absl::StatusOr> + GetStartingInstructions(HloInstruction* custom_call_instruction); + + // When a MoveToHost custom call is not paired with a DynamicUpdateSlice, a + // copy from device to host must be inserted. + absl::StatusOr InsertCopyBetween( + const host_offload_utils::InstructionAndShapeIndex& + before_instruction_and_index, + const host_offload_utils::InstructionAndShapeIndex& + after_instruction_and_index); + + // This is a fix for scheduling. Add copies to inputs of dynamic-update-slice + // if the inserted value is directly a parameter of a computation. This is to + // avoid cases in while loop where parameter/output aliasing can stop + // scheduling because control-dependencies are added. + absl::StatusOr ApplySchedulingFix( + HloModule* module, + const absl::flat_hash_set& execution_threads); + + // Starting from the outputs of the host offloaded computation, track all + // their usages. For the outputs that are ONLY used on host, remove redundant + // copies to and from host, as well as update the memory space. + absl::StatusOr HandleRedundantCopiesBackToHost( + const HloModule* module, HloInstruction* instruction); +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_HOST_OFFLOADER_H_ diff --git a/third_party/xla/xla/service/host_offloader_test.cc b/third_party/xla/xla/hlo/transforms/host_offloader_test.cc similarity index 97% rename from third_party/xla/xla/service/host_offloader_test.cc rename to third_party/xla/xla/hlo/transforms/host_offloader_test.cc index 13185e379c3aa4..0306f5974bfd0e 100644 --- a/third_party/xla/xla/service/host_offloader_test.cc +++ b/third_party/xla/xla/hlo/transforms/host_offloader_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/host_offloader.h" +#include "xla/hlo/transforms/host_offloader.h" #include #include @@ -31,16 +31,16 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/verified_hlo_module.h" +#include "xla/hlo/transforms/host_offload_legalize.h" #include "xla/layout.h" #include "xla/service/hlo_verifier.h" #include "xla/service/host_memory_offload_annotations.h" -#include "xla/service/host_offload_legalize.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "tsl/platform/statusor.h" @@ -50,7 +50,7 @@ namespace m = ::xla::match; namespace xla { namespace { -class HostOffloaderTest : public HloTestBase { +class HostOffloaderTest : public HloHardwareIndependentTestBase { protected: absl::StatusOr RunHostOffloader(HloModule* module, bool after_layout = false) { @@ -4114,6 +4114,82 @@ TEST_F(HostOffloaderTest, RemoveRedundantCopiesBackToHostOutputIsNonTuple) { TestShapeHasMemorySpace(ShapeUtil::GetSubshape(output_tuple->shape(), {1}), Layout::kHostMemorySpace); } + +// Test to ensure that redundant "MoveToHost" instructions do not produce +// redundant copy to host instructions after running the host offloader pass. +TEST_F(HostOffloaderTest, AvoidRedundantCopiesToHost) { + const absl::string_view hlo_string = R"( + HloModule AvoidRedundantCopiesToHost, entry_computation_layout={(bf16[65536,1024]{1,0:T(8,128)(2,1)})->bf16[65536,1024]{1,0:T(8,128)(2,1)S(5)}}, num_partitions=8 + + body { + param.1 = (s32[]{:T(128)}, bf16[65536,1024]{1,0:T(8,128)(2,1)}, bf16[65536,1024]{1,0:T(8,128)(2,1)}) parameter(0) + get-tuple-element.3 = s32[]{:T(128)} get-tuple-element(param.1), index=0 + constant.22 = s32[]{:T(128)} constant(1) + add.1 = s32[]{:T(128)} add(get-tuple-element.3, constant.22) + get-tuple-element.4 = bf16[65536,1024]{1,0:T(8,128)(2,1)} get-tuple-element(param.1), index=1 + get-tuple-element.5 = bf16[65536,1024]{1,0:T(8,128)(2,1)} get-tuple-element(param.1), index=2 + constant.23 = s32[]{:T(128)} constant(8) + multiply.1 = s32[]{:T(128)} multiply(get-tuple-element.3, constant.23) + constant.24 = s32[]{:T(128)} constant(0) + compare.3 = pred[]{:T(512)} compare(multiply.1, constant.24), direction=LT + constant.25 = s32[]{:T(128)} constant(65536) + add.2 = s32[]{:T(128)} add(multiply.1, constant.25) + select.1 = s32[]{:T(128)} select(compare.3, add.2, multiply.1) + dynamic-slice.1 = bf16[8,1024]{1,0:T(8,128)(2,1)} dynamic-slice(get-tuple-element.5, select.1, constant.24), dynamic_slice_sizes={8,1024} + custom-call.4 = bf16[8,1024]{1,0:T(8,128)(2,1)} custom-call(dynamic-slice.1), custom_call_target="MoveToHost" + dynamic-update-slice.0 = bf16[65536,1024]{1,0:T(8,128)(2,1)} dynamic-update-slice(get-tuple-element.4, custom-call.4, select.1, constant.24) + ROOT tuple.1 = (s32[]{:T(128)}, bf16[65536,1024]{1,0:T(8,128)(2,1)}, bf16[65536,1024]{1,0:T(8,128)(2,1)}) tuple(add.1, dynamic-update-slice.0, get-tuple-element.5) + } + + or_comp { + Arg_0.27 = pred[]{:T(512)} parameter(0) + Arg_1.28 = pred[]{:T(512)} parameter(1) + ROOT or.29 = pred[]{:T(512)} or(Arg_0.27, Arg_1.28) + } + + condition { + param = (s32[]{:T(128)}, bf16[65536,1024]{1,0:T(8,128)(2,1)}, bf16[65536,1024]{1,0:T(8,128)(2,1)}) parameter(0) + get-tuple-element.1 = s32[]{:T(128)} get-tuple-element(param), index=0 + constant.15 = s32[]{:T(128)} constant(8) + multiply.0 = s32[]{:T(128)} multiply(get-tuple-element.1, constant.15) + constant.16 = s32[]{:T(128)} constant(65536) + compare.0 = pred[]{:T(512)} compare(multiply.0, constant.16), direction=LT + get-tuple-element.2 = bf16[65536,1024]{1,0:T(8,128)(2,1)} get-tuple-element(param), index=2 + constant.17 = s32[]{:T(128)} constant(0) + compare.1 = pred[]{:T(512)} compare(multiply.0, constant.17), direction=LT + add.0 = s32[]{:T(128)} add(multiply.0, constant.16) + select.0 = s32[]{:T(128)} select(compare.1, add.0, multiply.0) + dynamic-slice.0 = bf16[8,1024]{1,0:T(8,128)(2,1)} dynamic-slice(get-tuple-element.2, select.0, constant.17), dynamic_slice_sizes={8,1024} + constant.20 = bf16[]{:T(256)} constant(0) + broadcast.3 = bf16[8,1024]{1,0:T(8,128)(2,1)} broadcast(constant.20), dimensions={} + compare.2 = pred[8,1024]{1,0:T(8,128)(4,1)} compare(dynamic-slice.0, broadcast.3), direction=GT + constant.21 = pred[]{:T(512)} constant(false) + reduce.0 = pred[]{:T(512)} reduce(compare.2, constant.21), dimensions={0,1}, to_apply=or_comp + ROOT and.0 = pred[]{:T(512)} and(compare.0, reduce.0) + } + + ENTRY main { + constant.28 = s32[]{:T(128)} constant(0) + constant.29 = bf16[]{:T(256)} constant(0) + broadcast.4 = bf16[65536,1024]{1,0:T(8,128)(2,1)} broadcast(constant.29), dimensions={} + param.2 = bf16[65536,1024]{1,0:T(8,128)(2,1)} parameter(0), sharding={devices=[8,1]<=[4,2]T(1,0)} + tuple.2 = (s32[]{:T(128)}, bf16[65536,1024]{1,0:T(8,128)(2,1)}, bf16[65536,1024]{1,0:T(8,128)(2,1)}) tuple(constant.28, broadcast.4, param.2) + while.1 = (s32[]{:T(128)}, bf16[65536,1024]{1,0:T(8,128)(2,1)}, bf16[65536,1024]{1,0:T(8,128)(2,1)}) while(tuple.2), condition=condition, body=body + get-tuple-element.9 = bf16[65536,1024]{1,0:T(8,128)(2,1)} get-tuple-element(while.1), index=1 + ROOT custom-call.5 = bf16[65536,1024]{1,0:T(8,128)(2,1)} custom-call(get-tuple-element.9), custom_call_target="MoveToHost" + })"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + EXPECT_TRUE(changed); + VLOG(1) << module->ToString(); + + for (HloInstruction* instr : module->entry_computation()->instructions()) { + ASSERT_NE(instr->opcode(), HloOpcode::kCopy); + } +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/host_offloading_prepare.cc b/third_party/xla/xla/hlo/transforms/host_offloading_prepare.cc similarity index 98% rename from third_party/xla/xla/service/host_offloading_prepare.cc rename to third_party/xla/xla/hlo/transforms/host_offloading_prepare.cc index fe3478e8b1d916..41a3d431dd8a76 100644 --- a/third_party/xla/xla/service/host_offloading_prepare.cc +++ b/third_party/xla/xla/hlo/transforms/host_offloading_prepare.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/host_offloading_prepare.h" +#include "xla/hlo/transforms/host_offloading_prepare.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/host_offloading_prepare.h b/third_party/xla/xla/hlo/transforms/host_offloading_prepare.h new file mode 100644 index 00000000000000..d45336e6111a05 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/host_offloading_prepare.h @@ -0,0 +1,91 @@ +/* Copyright 2024 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_HOST_OFFLOADING_PREPARE_H_ +#define XLA_HLO_TRANSFORMS_HOST_OFFLOADING_PREPARE_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// This is a collection of rewrites that prepares an HLO module for host +// offloading. These rewrites can be placed in a different parts of +// the overall compilation pipeline to prepare HLO module for host offloading +// for the given backend. +class HostOffloadingPrepare : public HloModulePass { + public: + enum class Rewrite { + // This rewrite removes `MoveToHost` custom calls that feed directly into + // the a host computation. + // + // In the HLO, it will look like HBM is directly fed into the host + // computation. The runtime will, once the async-call-start is executed, + // allocate a buffer on the host and copy the HBM buffer into it. This has + // the benefit that the device will never be blocking directly on the + // tranfser, since that's clumped together with the computation. + kElideMoveToHost, + + // Currently host compute offloading does not support tiled layouts, and + // because of that layouts on the call instruction arguments might be + // different from the layouts in the called computation body. + // + // Host offloading handles layout mismatches at run time by delinearizing + // arguments and linearizing results on the fly. + // + // To keep HLO module valid we rewrite calls to host offloaded computations + // into custom calls with the only purpose to suppress verification error. + // Host offloading compiler later does its own verification to check that + // arguments are compatible with parameters in the offloaded computation and + // knows how to handle mismatched layouts. + kConvertToCustomCall, + }; + + static std::string RewriteName(Rewrite rewrite) { + switch (rewrite) { + case Rewrite::kElideMoveToHost: + return "elide-move-to-host"; + case Rewrite::kConvertToCustomCall: + return "convert-to-custom-call"; + } + } + + explicit HostOffloadingPrepare(Rewrite rewrite) + : rewrite_(rewrite), + pass_name_(absl::StrCat("host-offloading-prepare", "-", + RewriteName(rewrite_))) {} + + absl::string_view name() const override { return pass_name_; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + Rewrite rewrite_; + std::string pass_name_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_HOST_OFFLOADING_PREPARE_H_ diff --git a/third_party/xla/xla/service/host_offloading_prepare_test.cc b/third_party/xla/xla/hlo/transforms/host_offloading_prepare_test.cc similarity index 98% rename from third_party/xla/xla/service/host_offloading_prepare_test.cc rename to third_party/xla/xla/hlo/transforms/host_offloading_prepare_test.cc index 92d5490cfb2d15..a19724ca2286cf 100644 --- a/third_party/xla/xla/service/host_offloading_prepare_test.cc +++ b/third_party/xla/xla/hlo/transforms/host_offloading_prepare_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/host_offloading_prepare.h" +#include "xla/hlo/transforms/host_offloading_prepare.h" #include #include @@ -23,8 +23,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/service/host_memory_offload_annotations.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" @@ -33,7 +33,7 @@ namespace { using Rewrite = HostOffloadingPrepare::Rewrite; -class HostOffloadingPrepareTest : public HloTestBase { +class HostOffloadingPrepareTest : public HloHardwareIndependentTestBase { protected: absl::StatusOr RunRewrite(HloModule* module, Rewrite rewrite) { TF_EXPECT_OK(verifier().Run(module).status()); diff --git a/third_party/xla/xla/service/memory_space_propagation.cc b/third_party/xla/xla/hlo/transforms/memory_space_propagation.cc similarity index 98% rename from third_party/xla/xla/service/memory_space_propagation.cc rename to third_party/xla/xla/hlo/transforms/memory_space_propagation.cc index d2bd0a3f834dde..d0704df0e88af9 100644 --- a/third_party/xla/xla/service/memory_space_propagation.cc +++ b/third_party/xla/xla/hlo/transforms/memory_space_propagation.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/memory_space_propagation.h" +#include "xla/hlo/transforms/memory_space_propagation.h" #include diff --git a/third_party/xla/xla/hlo/transforms/memory_space_propagation.h b/third_party/xla/xla/hlo/transforms/memory_space_propagation.h new file mode 100644 index 00000000000000..bb0da70bf1a7fc --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/memory_space_propagation.h @@ -0,0 +1,48 @@ +/* Copyright 2020 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_MEMORY_SPACE_PROPAGATION_H_ +#define XLA_HLO_TRANSFORMS_MEMORY_SPACE_PROPAGATION_H_ + +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// This is a legalization pass that propagates the memory space in the layout to +// the fusion computations. +class MemorySpacePropagation : public HloModulePass { + public: + ~MemorySpacePropagation() override = default; + absl::string_view name() const override { return "memory-space-propagation"; } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + // Given the shape index (operand or output) and its corresponding instruction + // in the fused computation (parameter or root), propagates the memory space + // in the callee side. Returns true if the module is modified. + bool Propagate(ShapeIndexView index, const HloInstruction* callee_instruction, + int64_t memory_space) const; + + std::unique_ptr dataflow_analysis_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_MEMORY_SPACE_PROPAGATION_H_ diff --git a/third_party/xla/xla/service/memory_space_propagation_test.cc b/third_party/xla/xla/hlo/transforms/memory_space_propagation_test.cc similarity index 98% rename from third_party/xla/xla/service/memory_space_propagation_test.cc rename to third_party/xla/xla/hlo/transforms/memory_space_propagation_test.cc index 98ae47c8b164f2..15cd6c4cd4cbff 100644 --- a/third_party/xla/xla/service/memory_space_propagation_test.cc +++ b/third_party/xla/xla/hlo/transforms/memory_space_propagation_test.cc @@ -13,19 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/memory_space_propagation.h" +#include "xla/hlo/transforms/memory_space_propagation.h" -#include "xla/service/hlo_parser.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { -class MemorySpacePropagationTest : public HloTestBase { +class MemorySpacePropagationTest : public HloHardwareIndependentTestBase { public: MemorySpacePropagationTest() - : HloTestBase(), + : HloHardwareIndependentTestBase(), verifier_(/*layout_sensitive=*/false, /*allow_mixed_precision*/ false) { } diff --git a/third_party/xla/xla/service/operand_upcaster.cc b/third_party/xla/xla/hlo/transforms/operand_upcaster.cc similarity index 99% rename from third_party/xla/xla/service/operand_upcaster.cc rename to third_party/xla/xla/hlo/transforms/operand_upcaster.cc index 81c30c88fc2fc9..ed6b4d41ff443a 100644 --- a/third_party/xla/xla/service/operand_upcaster.cc +++ b/third_party/xla/xla/hlo/transforms/operand_upcaster.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/operand_upcaster.h" +#include "xla/hlo/transforms/operand_upcaster.h" #include diff --git a/third_party/xla/xla/hlo/transforms/operand_upcaster.h b/third_party/xla/xla/hlo/transforms/operand_upcaster.h new file mode 100644 index 00000000000000..3b0b6fd2e7699f --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/operand_upcaster.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_OPERAND_UPCASTER_H_ +#define XLA_HLO_TRANSFORMS_OPERAND_UPCASTER_H_ + +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" +#include "xla/util.h" + +namespace xla { + +// Inserts Convert to operands of instructions that allows result accumulation +// as wider integral types. +class OperandUpcaster : public OpExpanderPass { + public: + explicit OperandUpcaster(HloPredicate extra_filter = nullptr) + : OpExpanderPass(std::move(extra_filter)) {} + + absl::string_view name() const override { return "operand_upcaster"; } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + absl::StatusOr ExpandInstruction( + HloInstruction* instruction) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_OPERAND_UPCASTER_H_ diff --git a/third_party/xla/xla/service/operand_upcaster_test.cc b/third_party/xla/xla/hlo/transforms/operand_upcaster_test.cc similarity index 97% rename from third_party/xla/xla/service/operand_upcaster_test.cc rename to third_party/xla/xla/hlo/transforms/operand_upcaster_test.cc index 37a8b0657c8942..8a143b365af618 100644 --- a/third_party/xla/xla/service/operand_upcaster_test.cc +++ b/third_party/xla/xla/hlo/transforms/operand_upcaster_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/operand_upcaster.h" +#include "xla/hlo/transforms/operand_upcaster.h" #include #include @@ -21,9 +21,9 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/primitive_util.h" -#include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" namespace xla { @@ -32,7 +32,7 @@ namespace { namespace op = ::xla::testing::opcode_matchers; class OperandUpcasterTest - : public HloTestBase, + : public HloHardwareIndependentTestBase, public ::testing::WithParamInterface< std::tuple> {}; diff --git a/third_party/xla/xla/service/sharding_format_picker.cc b/third_party/xla/xla/hlo/transforms/sharding_format_picker.cc similarity index 99% rename from third_party/xla/xla/service/sharding_format_picker.cc rename to third_party/xla/xla/hlo/transforms/sharding_format_picker.cc index a493e4cc80805b..90b192400f1f1e 100644 --- a/third_party/xla/xla/service/sharding_format_picker.cc +++ b/third_party/xla/xla/hlo/transforms/sharding_format_picker.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/sharding_format_picker.h" +#include "xla/hlo/transforms/sharding_format_picker.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/sharding_format_picker.h b/third_party/xla/xla/hlo/transforms/sharding_format_picker.h new file mode 100644 index 00000000000000..a6cbeb9420a4c8 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/sharding_format_picker.h @@ -0,0 +1,49 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SHARDING_FORMAT_PICKER_H_ +#define XLA_HLO_TRANSFORMS_SHARDING_FORMAT_PICKER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// Test-only pass to transform the HloSharding format of all the instructions in +// a module to the selected format. +class ShardingFormatPicker : public HloModulePass { + public: + enum class ShardingType { + kV1, // Converts all HloSharding to V1 format. + kBestEffortV2, // Best effort to convert all HloSharding to V2 format. + }; + explicit ShardingFormatPicker(ShardingType sharding_type) + : sharding_type_(sharding_type) {} + absl::string_view name() const override { return "sharding-format-picker"; } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + const ShardingType sharding_type_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SHARDING_FORMAT_PICKER_H_ diff --git a/third_party/xla/xla/service/algebraic_simplifier.cc b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc similarity index 98% rename from third_party/xla/xla/service/algebraic_simplifier.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc index f54864220e2e69..a8be8c222868ba 100644 --- a/third_party/xla/xla/service/algebraic_simplifier.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/algebraic_simplifier.h" +#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" #include #include @@ -530,7 +530,8 @@ bool AlgebraicSimplifierVisitor::IsNonNegative( return hlo->operand(0) == hlo->operand(1); } case HloOpcode::kAbs: - case HloOpcode::kExp: { + case HloOpcode::kExp: + case HloOpcode::kIota: { return true; } case HloOpcode::kBroadcast: { @@ -1052,10 +1053,28 @@ absl::Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { if (d == lhs_dnums.index_vector_dim()) { continue; } + // Skip the dimensions that are in the update window before we subtract 1 + // from `update_dim` for the next iteration. while ( absl::c_linear_search(lhs_dnums.update_window_dims(), update_dim)) { --update_dim; } + if (absl::c_linear_search(lhs_dnums.scatter_indices_batching_dims(), d)) { + // Corresponding batch dimensions in updates, scatter_indices and inputs + // have the same sizes. So we can't concatenate a batch dim in updates + // and scatter_indices without changing inputs. Instead, we ensure the + // two scatter instructions have the same batch dimensions to support + // the transformation. + if (lhs_scatter_index->shape().dimensions(d) != + rhs_scatter_index->shape().dimensions(d)) { + // This shouldn't be reachable as we currently only combine two + // scatter instructions feeding into the same add straightforwardly, + // which should have the same result shapes. + return absl::OkStatus(); + } + update_dim--; + continue; + } if (lhs_scatter_index->shape().dimensions(d) == rhs_scatter_index->shape().dimensions(d)) { first_index_dim = d; @@ -1093,7 +1112,11 @@ absl::Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { absl::c_equal(lhs_dnums.inserted_window_dims(), rhs_dnums.inserted_window_dims()) && absl::c_equal(lhs_dnums.update_window_dims(), - rhs_dnums.update_window_dims()); + rhs_dnums.update_window_dims()) && + absl::c_equal(lhs_dnums.scatter_indices_batching_dims(), + rhs_dnums.scatter_indices_batching_dims()) && + absl::c_equal(lhs_dnums.input_batching_dims(), + rhs_dnums.input_batching_dims()); const bool index_concat_is_safe = !lhs->unique_indices() && !rhs->unique_indices() && !DynCast(lhs)->indices_are_sorted() && @@ -1290,12 +1313,17 @@ absl::Status AlgebraicSimplifierVisitor::HandleBitcast( // operand. if (bitcast->opcode() == HloOpcode::kBitcast && bitcast->operand(0)->opcode() == HloOpcode::kBroadcast) { - // DeduceTransposeDimensionsForBitcast() checks whether the bitcast is a - // transpose and returns the dimensions attribute if it is. - auto dimensions = ShapeUtil::DeduceTransposeDimensionsForBitcast( - bitcast->operand(0)->shape(), bitcast->shape()); - if (dimensions.has_value()) { - return SimplifyTransposeOfBroadcast(bitcast, dimensions.value()); + // Make sure the bitcast and the broadcast have the same tiling. + bool enable_broadcast = bitcast->operand(0)->shape().layout().tiles() == + bitcast->shape().layout().tiles(); + if (enable_broadcast) { + // DeduceTransposeDimensionsForBitcast() checks whether the bitcast is a + // transpose and returns the dimensions attribute if it is. + auto dimensions = ShapeUtil::DeduceTransposeDimensionsForBitcast( + bitcast->operand(0)->shape(), bitcast->shape()); + if (dimensions.has_value()) { + return SimplifyTransposeOfBroadcast(bitcast, dimensions.value()); + } } } @@ -1406,10 +1434,6 @@ std::optional AlgebraicSimplifierVisitor::ReshapeLayoutDimensions( } auto bit_dims = original_map[op_dim]; for (int64_t bitcast_dim : bit_dims) { - if (result_shape.dimensions(bitcast_dim) == 1) { - // Postpone all degenerated dimensions (those with size 1) to the end. - continue; - } VLOG(3) << "Add new reshaped dimension:" << bitcast_dim << "\n"; if (bitcast_pos < 0 || (*reshaped_dimensions)[bitcast_pos] != bitcast_dim) { @@ -1420,6 +1444,10 @@ std::optional AlgebraicSimplifierVisitor::ReshapeLayoutDimensions( VLOG(3) << "bitcast pos is over incremented:" << bitcast_pos << "\n"; return std::nullopt; } + if (result_shape.dimensions(bitcast_dim) == 1) { + // Postpone all degenerated dimensions (those with size 1) to the end. + continue; + } (*reshaped_dimensions)[bitcast_pos] = bitcast_dim; } auto op_dims = result_map[bitcast_dim]; @@ -2072,6 +2100,13 @@ absl::Status AlgebraicSimplifierVisitor::HandleConstant( absl::Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { HloInstruction *lhs, *rhs; CHECK(Match(sub, m::Subtract(m::Op(&lhs), m::Op(&rhs)))); + // A - A => 0 + if (options_.enable_fast_math() || + ShapeUtil::ElementIsIntegral(sub->shape())) { + if (lhs == rhs) { + return ReplaceInstruction(sub, MakeScalarLike(sub, 0)); + } + } // A - 0 => A VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString(); if (IsAll(rhs, 0) && ReplaceInstructionIfCompatible(sub, lhs)) { @@ -4207,7 +4242,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleGather(HloInstruction* gather) { PaddingConfig pad_config; for (int64_t i = 0; i != gather->shape().rank(); ++i) { auto dimension = pad_config.add_dimensions(); - if (reshape_dims_to_padded_dims.contains( + if (gather_operand_passthrough_output_to_operand_dims.contains(i) && + reshape_dims_to_padded_dims.contains( gather_operand_passthrough_output_to_operand_dims[i])) { int64_t padded_dim = reshape_dims_to_padded_dims [gather_operand_passthrough_output_to_operand_dims[i]]; @@ -4297,6 +4333,11 @@ absl::Status AlgebraicSimplifierVisitor::HandleMaximum( HloInstruction *lhs, *rhs; CHECK(Match(maximum, m::Maximum(m::Op(&lhs), m::Op(&rhs)))); + // max(x, x) -> x + if (lhs == rhs) { + return ReplaceInstruction(maximum, lhs); + } + // max(x, -inf) -> x PrimitiveType ty = maximum->shape().element_type(); if (primitive_util::IsIntegralType(ty) || @@ -4397,6 +4438,11 @@ absl::Status AlgebraicSimplifierVisitor::HandleMinimum( HloInstruction *lhs, *rhs; CHECK(Match(minimum, m::Minimum(m::Op(&lhs), m::Op(&rhs)))); + // min(x, x) -> x + if (lhs == rhs) { + return ReplaceInstruction(minimum, lhs); + } + // min(x, inf) -> x PrimitiveType ty = minimum->shape().element_type(); if (primitive_util::IsIntegralType(ty) || @@ -4487,6 +4533,51 @@ absl::Status AlgebraicSimplifierVisitor::HandleClamp(HloInstruction* clamp) { return absl::OkStatus(); } +absl::Status AlgebraicSimplifierVisitor::TryToReorderConvAddMultiply( + HloInstruction* multiply) { + if (!options_.enable_conv_add_multiply_reorder()) return absl::OkStatus(); + HloInstruction *input, *filter, *bias, *constant, *convolution, *broadcast, + *add; + // We conservatively only consider the case where the multiplier is a + // broadcast of a 1D constant to the output feature dimension and the filter + // is a constant so that they can be constant-folded. + if (!Match(multiply, + m::MultiplyAnyOrder( + m::AddAnyOrder(&add, + m::Convolution(&convolution, m::Op(&input), + m::Constant(&filter)) + .WithOneUser(), + m::Op(&bias).WithOneUser()), + m::Broadcast(&broadcast, m::Constant(&constant).WithShape( + m::Shape().WithRank(1))) + .WithOneUser()))) { + return absl::OkStatus(); + } + const ConvolutionDimensionNumbers& dnums = + convolution->convolution_dimension_numbers(); + if (broadcast->dimensions().size() != 1 || + broadcast->dimensions()[0] != dnums.output_feature_dimension()) { + return absl::OkStatus(); + } + + HloInstruction* bcast_to_filter_dim = + multiply->AddInstruction(HloInstruction::CreateBroadcast( + filter->shape(), constant, + {dnums.kernel_output_feature_dimension()})); + HloInstruction* filter_multiply = + multiply->AddInstruction(HloInstruction::CreateBinary( + filter->shape(), HloOpcode::kMultiply, filter, bcast_to_filter_dim)); + HloInstruction* new_conv = + multiply->AddInstruction(convolution->CloneWithNewOperands( + convolution->shape(), {input, filter_multiply})); + HloInstruction* bias_multiply = + multiply->AddInstruction(HloInstruction::CreateBinary( + bias->shape(), HloOpcode::kMultiply, bias, broadcast)); + std::unique_ptr new_add = + add->CloneWithNewOperands(add->shape(), {new_conv, bias_multiply}); + return ReplaceWithNewInstruction(multiply, std::move(new_add)); +} + absl::Status AlgebraicSimplifierVisitor::HandleMultiply( HloInstruction* multiply) { HloInstruction *lhs, *rhs; @@ -4693,7 +4784,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleMultiply( MakeScalarLike(lhs, 1), lhs)); } - return absl::OkStatus(); + return TryToReorderConvAddMultiply(multiply); } absl::Status AlgebraicSimplifierVisitor::HandleNegate(HloInstruction* negate) { @@ -5047,7 +5138,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleBroadcast( if (options_.is_layout_sensitive()) { return absl::OkStatus(); } - if (ShapeUtil::HasDegenerateDimensions(operand->shape())) { + if (options_.enable_broadcast_degenerate_dimension() && + ShapeUtil::HasDegenerateDimensions(operand->shape())) { auto new_operand = operand->AddInstruction(HloInstruction::CreateReshape( ShapeUtil::DropDegenerateDimensions(operand->shape()), operand)); std::vector new_dims; @@ -5095,16 +5187,16 @@ absl::Status AlgebraicSimplifierVisitor::HandleCompare( } if (compare->comparison_direction() == ComparisonDirection::kLt && - lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) { + IsNonNegative(lhs, options_) && IsAll(rhs, 0)) { return ReplaceInstruction(compare, MakeScalarLike(compare, false)); } else if (compare->comparison_direction() == ComparisonDirection::kGt && - IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) { + IsAll(lhs, 0) && IsNonNegative(rhs, options_)) { return ReplaceInstruction(compare, MakeScalarLike(compare, false)); } else if (compare->comparison_direction() == ComparisonDirection::kGe && - lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) { + IsNonNegative(lhs, options_) && IsAll(rhs, 0)) { return ReplaceInstruction(compare, MakeScalarLike(compare, true)); } else if (compare->comparison_direction() == ComparisonDirection::kLe && - IsAll(lhs, 0) && rhs->opcode() == HloOpcode::kIota) { + IsAll(lhs, 0) && IsNonNegative(rhs, options_)) { return ReplaceInstruction(compare, MakeScalarLike(compare, true)); } if (lhs == rhs && @@ -5288,6 +5380,16 @@ absl::Status AlgebraicSimplifierVisitor::HandleCustomCall( return absl::OkStatus(); } +absl::Status AlgebraicSimplifierVisitor::HandleExp( + HloInstruction* exponential) { + // Exp(0) => 1 + if (Match(exponential, m::Exp(m::ConstantScalar(0))) || + Match(exponential, m::Exp(m::Broadcast(m::ConstantScalar(0))))) { + return ReplaceInstruction(exponential, MakeScalarLike(exponential, 1.0)); + } + return absl::OkStatus(); +} + // Complex(Real(c), Imag(c)) -> c absl::Status AlgebraicSimplifierVisitor::HandleComplex( HloInstruction* complex) { diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.h b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.h new file mode 100644 index 00000000000000..a04f698590d105 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier.h @@ -0,0 +1,750 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_ALGEBRAIC_SIMPLIFIER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_ALGEBRAIC_SIMPLIFIER_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/literal.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +class AlgebraicSimplifierOptions { + public: + // Platform dependent callback to determine if a reshape `from_shape` to + // `to_shape` is a bitcast. + using ReshapeIsBitcastCallback = + std::function; + // Platform dependent callback to determine if a set of reverse dimensions is + // lowerable + using ConvIsLowerableCallback = std::function; + + explicit AlgebraicSimplifierOptions( + ReshapeIsBitcastCallback reshape_is_bitcast_callback = {}, + ConvIsLowerableCallback conv_is_lowerable_callback = {}) + : reshape_is_bitcast_callback_(std::move(reshape_is_bitcast_callback)), + conv_is_lowerable_callback_(std::move(conv_is_lowerable_callback)) {} + + // Use the platform specific callback if set. It is not sensible to return + // true here if the options are not layout sensitive. + bool ReshapeIsBitcast(const Shape& from_shape, const Shape& to_shape) const { + if (!is_layout_sensitive_) { + return false; + } + if (!reshape_is_bitcast_callback_) { + return ShapeUtil::ReshapeIsBitcast(from_shape, to_shape); + } + return reshape_is_bitcast_callback_(from_shape, to_shape); + } + + // Use the platform specific callback if set. Otherwise, return true. + bool ConvIsLowerable(HloInstruction* reverse_dims) const { + if (!conv_is_lowerable_callback_) { + return true; + } + return conv_is_lowerable_callback_(reverse_dims); + } + + void set_conv_is_lowerable_callback( + ConvIsLowerableCallback conv_is_lowerable_callback) { + conv_is_lowerable_callback_ = std::move(conv_is_lowerable_callback); + } + + // If is_layout_sensitive is true, then the simplifier preserves layout during + // transformation. Otherwise, layout is ignored. + void set_is_layout_sensitive(bool is_layout_sensitive) { + is_layout_sensitive_ = is_layout_sensitive; + } + + bool is_layout_sensitive() const { return is_layout_sensitive_; } + + void set_use_associative_reordering(bool use_associative_reordering) { + use_associative_reordering_ = use_associative_reordering; + } + + bool use_associative_reordering() const { + return use_associative_reordering_; + } + + void set_associative_reordering_threshold( + double associative_reordering_threshold) { + associative_reordering_threshold_ = associative_reordering_threshold; + } + + double associative_reordering_threshold() const { + return associative_reordering_threshold_; + } + + void set_use_convert_constant_folding(bool use_convert_constant_folding) { + use_convert_constant_folding_ = use_convert_constant_folding; + } + + bool use_convert_constant_folding() const { + return use_convert_constant_folding_; + } + + void set_raise_slice_and_reduce_through_dot( + bool raise_slice_and_reduce_through_dot) { + raise_slice_and_reduce_through_dot_ = raise_slice_and_reduce_through_dot; + } + + bool raise_slice_and_reduce_through_dot() const { + return raise_slice_and_reduce_through_dot_; + } + + void set_raise_slice_and_reduce_through_dot_threshold( + double raise_slice_and_reduce_through_dot_threshold) { + raise_slice_and_reduce_through_dot_threshold_ = + raise_slice_and_reduce_through_dot_threshold; + } + + double raise_slice_and_reduce_through_dot_threshold() const { + return raise_slice_and_reduce_through_dot_threshold_; + } + + // Enable dot simplification on platforms where it is profitable. + void set_enable_dot_strength_reduction(bool enable_dot_strength_reduction) { + enable_dot_strength_reduction_ = enable_dot_strength_reduction; + } + + bool enable_dot_strength_reduction() const { + return enable_dot_strength_reduction_; + } + + // Enable dot->multiple rewrite for dot as an outer-product + void set_enable_dot_to_multiply_rewrite(bool enable_dot_to_multiply_rewrite) { + enable_dot_to_multiply_rewrite_ = enable_dot_to_multiply_rewrite; + } + + bool enable_dot_to_multiply_rewrite() const { + return enable_dot_to_multiply_rewrite_; + } + + void set_enable_move_dot_param_to_rhs(bool enable_move_dot_param_to_rhs) { + enable_move_dot_param_to_rhs_ = enable_move_dot_param_to_rhs; + } + + bool enable_move_dot_param_to_rhs() const { + return enable_move_dot_param_to_rhs_; + } + + // This platform will not run the DotDecomposer to canonicalize dots. + void set_supports_non_canonical_dots(bool supports_non_canonical_dots) { + supports_non_canonical_dots_ = supports_non_canonical_dots; + } + bool supports_non_canonical_dots() const { + return supports_non_canonical_dots_; + } + + // Enable convolution simplification on platforms where it is profitable. + void set_enable_conv_simplification(bool enable_conv_simplification) { + enable_conv_simplification_ = enable_conv_simplification; + } + bool enable_conv_simplification() const { + return enable_conv_simplification_; + } + + // Enable convolution operand swapping on platforms where it is supported. + void set_enable_conv_operand_swap(bool enable_conv_operand_swap) { + enable_conv_operand_swap_ = enable_conv_operand_swap; + } + bool enable_conv_operand_swap() const { return enable_conv_operand_swap_; } + + // Enable rewrite of convolution + add + multiply -> multiply + convolution + + // add. + void set_enable_conv_add_multiply_reorder( + bool enable_conv_add_multiply_reorder) { + enable_conv_add_multiply_reorder_ = enable_conv_add_multiply_reorder; + } + + bool enable_conv_add_multiply_reorder() const { + return enable_conv_add_multiply_reorder_; + } + + // Move constant scalar multiply to one operand or output of convolutions with + // the smallest tensor size, to reduce the number of scalar multiply. + void set_enable_scalar_multiply_reduction( + bool enable_scalar_multiply_reduction) { + enable_scalar_multiply_reduction_ = enable_scalar_multiply_reduction; + } + + bool enable_scalar_multiply_reduction() const { + return enable_scalar_multiply_reduction_; + } + + // Also the algebraic simplifer to treat floating point values like real + // numbers. + void set_enable_floats_are_real(bool enable_floats_are_real) { + enable_floats_are_real_ = enable_floats_are_real; + } + + bool enable_floats_are_real() const { return enable_floats_are_real_; } + + // If enable_window_reduce_replacement is true, the kReduceWindow instruction + // can be optimized by replacement with simpler operations. + void set_enable_window_reduce_to_reduce_replacement( + bool enable_window_reduce_to_reduce_replacement) { + enable_window_reduce_to_reduce_replacement_ = + enable_window_reduce_to_reduce_replacement; + } + + bool enable_window_reduce_to_reduce_replacement() const { + return enable_window_reduce_to_reduce_replacement_; + } + + // Sets the size of a gather operand that can be unrolled into many selects. + void set_very_small_gather_size(int64_t size) { + very_small_gather_size_ = size; + } + + int64_t very_small_gather_size() const { return very_small_gather_size_; } + + void set_cudnn_batchnorm_forward_training_metadata(const std::string& c) { + metadata_.cudnn_batchnorm_forward_training_metadata = c; + } + + const std::string& get_cudnn_batchnorm_forward_training_metadata() const { + return metadata_.cudnn_batchnorm_forward_training_metadata; + } + + void set_enable_reduce_of_reshape(bool enable_reduce_of_reshape) { + enable_reduce_of_reshape_ = enable_reduce_of_reshape; + } + + bool enable_reduce_of_reshape() const { return enable_reduce_of_reshape_; } + + void set_enable_negative_padding_replacement( + bool enable_negative_padding_replacement) { + enable_negative_padding_replacement_ = enable_negative_padding_replacement; + } + + bool enable_negative_padding_replacement() const { + return enable_negative_padding_replacement_; + } + + void set_enable_sink_broadcast(bool enable_sink_broadcast) { + enable_sink_broadcast_ = enable_sink_broadcast; + } + + bool enable_sink_broadcast() const { return enable_sink_broadcast_; } + + // If true, always simplify reduce(transpose(x)) and reduce(reshape(x)), even + // if the transpose/reshape has multiple users. This can be beneficial + // on platforms where the extra transpose/reshape isn't as expensive as + // the optimization benefits brought about by simplifying the graph. + bool unconditionally_simplify_reduce_of_transpose_or_reshape() const { + return unconditionally_simplify_reduce_of_transpose_or_reshape_; + } + void set_unconditionally_simplify_reduce_of_transpose_or_reshape(bool val) { + unconditionally_simplify_reduce_of_transpose_or_reshape_ = val; + } + + // If true, min(x, NaN) = NaN. If false, min(x, NaN) = x. + // + // TODO(b/209827141): Remove this and make minmax_propagate_nan + // unconditionally true. + bool minmax_propagate_nan() const { return minmax_propagate_nan_; } + void set_minmax_propagate_nan(bool val) { minmax_propagate_nan_ = val; } + + // When true, always replaces Reduce(concat({a,b,...})) with + // map(reduce(a),map(reduce(b),...,)). If false, only does the replacement if + // the shapes of a,b,... have the same dimensions. + bool enable_unconditional_reduce_of_concat_replacement() const { + return enable_unconditional_reduce_of_concat_replacement_; + } + void set_enable_unconditional_reduce_of_concat_replacement( + bool enable_unconditional_reduce_of_concat_replacement) { + enable_unconditional_reduce_of_concat_replacement_ = + enable_unconditional_reduce_of_concat_replacement; + } + + // Indicates whether running on CPU + bool executing_on_cpu() const { return executing_on_cpu_; } + void set_executing_on_cpu(bool executing_on_cpu) { + executing_on_cpu_ = executing_on_cpu; + } + + // Option to disable conversion of dynamic-slice to slice. + void set_disable_dynamic_slice_to_slice_conversion(bool disable) { + disable_dynamic_slice_to_slice_conversion_ = disable; + } + bool disable_dynamic_slice_to_slice_conversion() const { + return disable_dynamic_slice_to_slice_conversion_; + } + + // Option to set finite math. + void set_enable_fast_math(bool enable_fast_math) { + enable_fast_math_ = enable_fast_math; + } + bool enable_fast_math() const { return enable_fast_math_; } + + void set_enable_broadcast_degenerate_dimension( + bool enable_broadcast_degenerate_dimension) { + enable_broadcast_degenerate_dimension_ = + enable_broadcast_degenerate_dimension; + } + bool enable_broadcast_degenerate_dimension() const { + return enable_broadcast_degenerate_dimension_; + } + + private: + // Metadata struct can be used to store any metadata information encapsulated + // with the AlgebraicSimplifierOptions that can be later used in an + // AlgebraicSimplifier pass. For example, + // cudnn_batchnorm_forward_training_metadata can be used to store the name of + // a custom call. If the custom call is + // __cudnn$batchNormalizationForwardTraining, the output with index 2 is + // guaranteed to be positive. This property has been used to recursively + // determine if the operand of an instruction is always positive. + struct Metadata { + std::string cudnn_batchnorm_forward_training_metadata{""}; + Metadata() {} + }; + ReshapeIsBitcastCallback reshape_is_bitcast_callback_; + ConvIsLowerableCallback conv_is_lowerable_callback_; + bool is_layout_sensitive_{false}; + bool enable_dot_strength_reduction_{true}; + bool supports_non_canonical_dots_{true}; + bool enable_dot_to_multiply_rewrite_{true}; + bool enable_move_dot_param_to_rhs_{false}; + bool enable_conv_simplification_{true}; + bool enable_conv_operand_swap_{true}; + bool enable_conv_add_multiply_reorder_{false}; + bool enable_scalar_multiply_reduction_{false}; + bool enable_floats_are_real_{false}; + bool enable_window_reduce_to_reduce_replacement_{true}; + bool enable_reduce_of_reshape_{true}; + bool enable_negative_padding_replacement_{true}; + bool enable_sink_broadcast_{true}; + bool unconditionally_simplify_reduce_of_transpose_or_reshape_{false}; + int64_t very_small_gather_size_{4}; + bool minmax_propagate_nan_{true}; + bool enable_unconditional_reduce_of_concat_replacement_{true}; + bool executing_on_cpu_{false}; + bool use_associative_reordering_{false}; + double associative_reordering_threshold_{2.0}; + bool raise_slice_and_reduce_through_dot_{false}; + double raise_slice_and_reduce_through_dot_threshold_{2.0}; + bool use_convert_constant_folding_{false}; + bool disable_dynamic_slice_to_slice_conversion_{false}; + bool enable_fast_math_{false}; + bool enable_broadcast_degenerate_dimension_{true}; + Metadata metadata_; +}; + +// A pass which performs algebraic simplifications. +class AlgebraicSimplifier : public HloModulePass { + public: + // If is_layout_sensitive is true, then the simplifier preserves layout during + // transformation. Otherwise, layout is ignored. + explicit AlgebraicSimplifier(const AlgebraicSimplifierOptions& options) + : options_(options) {} + ~AlgebraicSimplifier() override = default; + absl::string_view name() const override { return "algsimp"; } + + // Run algebraic simplification on the given computation. Returns whether the + // computation was changed. + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + // Create constant from literal with tiles and element size updated in the + // constant's layout. + std::unique_ptr CreateConstantWithLayoutUpdated( + Literal literal) { + auto constant = HloInstruction::CreateConstant(std::move(literal)); + UpdateLayout(constant->mutable_shape()); + return constant; + } + + protected: + AlgebraicSimplifierOptions options_; +}; + +// AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain +// algebraic expressions to simplified forms. Note: This only supports +// simplifications that simply look at the operands of an instruction. For the +// more general case a worklist based approach would be needed. +class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { + public: + explicit AlgebraicSimplifierVisitor(const AlgebraicSimplifierOptions& options, + AlgebraicSimplifier* simplifier) + : options_(options), simplifier_(simplifier) {} + + absl::Status HandleAbs(HloInstruction* abs) override; + + absl::Status HandleAdd(HloInstruction* add) override; + + absl::Status HandleAllToAll(HloInstruction* all_to_all) override; + + absl::Status HandleAnd(HloInstruction* logical_and) override; + + absl::Status HandleBitcast(HloInstruction* bitcast) override; + + absl::Status HandleBitcastConvert(HloInstruction* bitcast) override; + + absl::Status HandleBroadcast(HloInstruction* broadcast) override; + + absl::Status HandleCompare(HloInstruction* compare) override; + + absl::Status HandleConcatenate(HloInstruction* concatenate) override; + + absl::Status HandleConstant(HloInstruction* constant) override; + + absl::Status HandleCopy(HloInstruction* copy) override; + + absl::Status HandleConvert(HloInstruction* convert) override; + + absl::Status HandleComplex(HloInstruction* complex) override; + + absl::Status HandleCustomCall(HloInstruction* custom_call) override; + + absl::Status HandleExp(HloInstruction* exp) override; + + absl::Status HandleReal(HloInstruction* real) override; + + absl::Status HandleImag(HloInstruction* imag) override; + + absl::Status HandleIota(HloInstruction* instruction) override; + + absl::Status HandleConvolution(HloInstruction* convolution) override; + + absl::Status HandleDivide(HloInstruction* divide) override; + + absl::Status HandleDot(HloInstruction* dot) override; + + absl::Status HandleGather(HloInstruction* gather) override; + + absl::Status HandleGetTupleElement( + HloInstruction* get_tuple_element) override; + + absl::Status HandleLog(HloInstruction* log) override; + + absl::Status HandleMaximum(HloInstruction* maximum) override; + + absl::Status HandleMinimum(HloInstruction* minimum) override; + + absl::Status HandleClamp(HloInstruction* clamp) override; + + absl::Status HandleMultiply(HloInstruction* multiply) override; + + absl::Status HandleNegate(HloInstruction* negate) override; + + absl::Status HandleNot(HloInstruction* logical_not) override; + + absl::Status HandleOptimizationBarrier(HloInstruction* barrier) override; + + absl::Status HandleOr(HloInstruction* logical_or) override; + + absl::Status HandlePad(HloInstruction* pad) override; + + absl::Status HandlePower(HloInstruction* power) override; + + absl::Status HandleRemainder(HloInstruction* remainder) override; + + absl::Status HandleReshape(HloInstruction* reshape) override; + + absl::Status HandleReduce(HloInstruction* hlo) override; + + absl::Status HandleReduceWindow(HloInstruction* hlo) override; + + absl::Status HandleReverse(HloInstruction* reverse) override; + + absl::Status HandleRsqrt(HloInstruction* rsqrt) override; + + absl::Status HandleSlice(HloInstruction* slice) override; + + absl::Status HandleSqrt(HloInstruction* sqrt) override; + + absl::Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; + + absl::Status HandleDynamicUpdateSlice( + HloInstruction* dynamic_update_slice) override; + absl::Status HandleScatter(HloInstruction* hlo) override; + + absl::Status HandleSelect(HloInstruction* select) override; + + absl::Status HandleSort(HloInstruction* sort) override; + + absl::Status HandleTranspose(HloInstruction* transpose) override; + + absl::Status HandleSubtract(HloInstruction* sub) override; + + absl::Status HandleMap(HloInstruction* map) override; + + // Runs the visitor on a computation. + bool Run(HloComputation* computation, + const AlgebraicSimplifierOptions& options, + AlgebraicSimplifier* simplifier); + + // Compute a function that maps from bitcasted dimensions to the resulting + // ones. Returns the function as a vector if successful; std::optional + // otherwise. + static std::optional>> ComputeBitcastDimMap( + const Shape& bitcast_shape, const Shape& operand_shape); + // Invert the directions of the given bitcast dimension map. + static std::vector> InvertBitcastDimMap( + const Shape& original_shape, const Shape& bitcast_shape, + const std::vector>& original_map); + + // Checks if the output of a given instruction is guaranteed to be + // non-negative. e.g. abs + static bool IsNonNegative(const HloInstruction* hlo, + const AlgebraicSimplifierOptions& options); + + // Check if the opcode of a given instruction is a non-decreasing function + // asymptotically satisfying |f(x)| <= |x| + static bool IsNondecreasingSublinear(const HloInstruction* hlo); + + // Modify the layout dimensions of result_shape, so that it becomes the + // re-shaped result of applying bitcast to the original_shape, by using + // dim_map to re-shape layout dimensions of original_shape. Returns the + // result_shape with modified layout if the conversion succeeds; Returns + // std::nullopt if fails. + static std::optional ReshapeLayoutDimensions( + const Shape& original_shape, const Shape& result_shape, + const std::vector>& original_map, + const std::vector>& result_map); + + // Allow backend constraints on tiling etc. to invalidate optimizations. + virtual bool IsValidLayout(const Shape& shape) { return true; } + // Allow backend targets to determine whether a layout is inefficient. + virtual bool ShouldStrengthReduceDotToReduce(const HloInstruction* hlo) { + return true; + } + + protected: + // The backend-specific options selected for the algebraic simplifier. + const AlgebraicSimplifierOptions& options_; + + private: + // Removes degenerate dimension from dot. + absl::StatusOr RemoveDegenerateDimensionFromDot(HloDotInstruction* dot); + + // Moves the transpose to the broadcast if possible. Can also be called with a + // bitcast transpose. + absl::Status SimplifyTransposeOfBroadcast( + HloInstruction* transpose, absl::Span dimensions); + + // Converts to primitive type if the input hlo is not that type, otherwise + // returns the original hlo. + HloInstruction* AsType(HloInstruction* hlo, + const PrimitiveType element_type) { + if (hlo->shape().element_type() == element_type) { + return hlo; + } + Shape changed_shape = + ShapeUtil::ChangeElementType(hlo->shape(), element_type); + simplifier_->UpdateLayout(&changed_shape); + return computation_->AddInstruction( + HloInstruction::CreateConvert(changed_shape, hlo)); + } + + // Transposes a dot operand such that the batch dimensions are the most major, + // and the contracting dimensions are most minor. + absl::StatusOr + NormalizeDotOperandToBatchMajorAndContractingMinor( + HloInstruction* dot_operand, absl::Span batch_dimensions, + absl::Span contracting_dimensions); + + // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)) (or + // transpose(dot(a,b)) if only the batch dims are transposed). + // + // Requires the dot has been canonicalized by DotDecomposer into + // + // LHS [batch dims..., non-contracting dim, contracting dim] + // RHS [batch dims..., contracting dim, non-contracting dim]. + absl::StatusOr RemoveTransposesFromDotOperands(HloDotInstruction* dot); + + // Swap the operands of dots, if one operand is "parameter-like" (i.e. a + // parameter, or a pointwise transformation of a parameter), so the + // "parameter-like" operand (e.g. a weight tensor) is placed on the RHS. + absl::StatusOr MoveDotParamToRhs(HloDotInstruction* dot); + + // Helper method to perform and add reduction on a list of dimensions. + HloInstruction* AddReduce(HloInstruction* hlo, absl::Span dims, + PrimitiveType type); + + // Move scalar multiply to the smallest side of convolution to + // reduce multiply computations. + absl::Status ScalarMultiplyReduction(HloInstruction* dot); + + // Convenience method for replacing an instruction with a bitcast. If operand + // is not null, then the bitcast will use the specified operand instead of the + // operand of the instruction. + void ReplaceWithBitcast(HloInstruction* instruction, + HloInstruction* operand = nullptr); + + // Change copy(bitcast...(copy)) into copy(bitcast) or bitcast(copy) so that + // the replicated copies are combined when allowed by layout/tiling assignment + // constraints. + bool SwapCopyBitcastCopy(HloInstruction* root_copy); + + // Replace old instruction with new instruction if old and new instructions + // are compatible (have the same shape and replacement preserves sharding). + // Updates uses and root instruction. Returns whether a replacement was made. + bool ReplaceInstructionIfCompatible(HloInstruction* old_instruction, + HloInstruction* new_instruction); + // Similar to above but tuplizes `new_instructions` if there are more than 1 + // instructions. + bool ReplaceInstructionIfCompatible( + HloInstruction* old_instruction, + absl::Span new_instructions); + + // Returns whether the shape of the output of the given instructions are the + // same for the purposes of simplification. If options_.is_layout_sensitive() + // is true, then this tests shape equality including layout + // (ShapeUtil::Equal). If options_.is_layout_sensitive() is false, then the + // tests shape compatibility (ShapeUtil::Compatible). + bool SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const; + + // Same as above but takes shape arguments directly. + bool SameShape(const Shape& lhs, const Shape& rhs) const; + + // A Broadcast that feeds an element-wise operation with a unique non-scalar + // operand can sink to after the operation. + absl::StatusOr TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( + HloInstruction* broadcast); + + absl::StatusOr OptimizeDotOfConcat(HloInstruction* dot); + absl::StatusOr OptimizeDotOfConcatHelper( + HloInstruction* dot, HloInstruction* lhs, int64_t lhs_contracting_dim, + HloInstruction* rhs, int64_t rhs_contracting_dim, bool swapped); + + absl::StatusOr OptimizeDotOfGather(HloInstruction* dot); + + absl::StatusOr OptimizeDotOfReorderContractingDims( + HloInstruction* dot); + + absl::StatusOr AssociativeReorderDotOperator( + HloDotInstruction* dot); + + HloComputation* GetOrCreateScalarAddComputation(PrimitiveType type) { + HloComputation*& scalar_add_computation = scalar_add_computations_[type]; + if (scalar_add_computation) { + return scalar_add_computation; + } + + HloComputation::Builder b("scalar_add_computation"); + Shape shape = ShapeUtil::MakeShape(type, {}); + simplifier_->UpdateLayout(&shape); + auto scalar_lhs = b.AddInstruction( + HloInstruction::CreateParameter(0, shape, "scalar_lhs")); + auto scalar_rhs = b.AddInstruction( + HloInstruction::CreateParameter(1, shape, "scalar_rhs")); + auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs)); + scalar_add_computation = + computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); + return scalar_add_computation; + } + + // Tries to fold a kPad in the input or filter into the convolution + // instruction's window. + virtual absl::StatusOr FoldConvInputPad(HloInstruction* convolution); + absl::StatusOr FoldConvFilterPad(HloInstruction* convolution); + + // Tries to swap convolution operands if they would result in a more efficient + // convolution. + absl::StatusOr SwapConvOperands(HloInstruction* convolution); + + // Checks if the given convolution is in BF16 and is oneDNN rewritable, if not + // then it promotes the data type of the convolution to F32 + absl::StatusOr IsOneDnnRewritableBF16Conv(HloInstruction** convolution); + + // Tries to use a kDot in place of the given convolution. + absl::StatusOr SimplifyConvToDot(HloInstruction* convolution); + // Tries to use a multiplication in place of the given convolution. + absl::StatusOr SimplifyConvToMultiply(HloInstruction* convolution); + + // Tries to reorder mul(add(conv(input, filter), bias), multiplier) -> + // add(conv(input, mul(filter, multiplier)), mul(bias, multiplier)). It only + // does that when the multiplier is a 1D constant of the size equal to the + // convolution output feature dimension. + absl::Status TryToReorderConvAddMultiply(HloInstruction* multiply); + + // Tries to simplify a slice where the result of the slice is a scalar. + absl::StatusOr TrySimplifyScalarSlice(HloInstruction* slice); + + // Tries to convert slice(reshape(X)) into reshape(slice(X)) + absl::StatusOr TryToReorderSliceAndReshape(HloInstruction* slice); + + // Tries to convert slice(reverse(X)) into reverse(slice(X)) + absl::StatusOr TryToReorderSliceAndReverse(HloInstruction* slice); + + // Tries to simplify `(and (< a N) (< a K))` in cases where `N <= K` into + // `(< a N)`. This is crucial for being able to figure out the loop trip + // count. + // + // Assumes that the input is conjunction. + absl::StatusOr TrySimplifyTautologicalCompare( + HloInstruction* conjunction); + + // Tries to simlplify (bitcast-convert (concat (bitcast-convert A) ...)) where + // the types of inner and outer bitcast-convert cancel out. + absl::StatusOr TrySimplifyTautologicalBitcastConvert( + HloInstruction* bitcast); + + // Tries to remove surrounding converts around a binary op where the op has a + // more precise type than its inputs and output. + // + // convert(bin_op(convert(data1), + // convert(data2))) + // where TS is a smaller point type than TL (ex, TS=fp16, TL=fp32) + // -> + // bin_op(data1, data2) + absl::Status TryRemoveUpcastAndDowncastSurroundingBinaryOp( + HloInstruction* convert_instruction); + + // Useful when we want to use the same visitor over multiple computations. + void ResetState(HloComputation* computation); + + // Current HloComputation instance the AlgebraicSimplifierVisitor is + // traversing. + HloComputation* computation_; + + // Cached computation for adding two scalars of a given type. + absl::flat_hash_map scalar_add_computations_; + + AlgebraicSimplifier* simplifier_ = nullptr; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_ALGEBRAIC_SIMPLIFIER_H_ diff --git a/third_party/xla/xla/service/algebraic_simplifier_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc similarity index 96% rename from third_party/xla/xla/service/algebraic_simplifier_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc index 1af8721066b8bf..09e99a4ffd1527 100644 --- a/third_party/xla/xla/service/algebraic_simplifier_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/algebraic_simplifier_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/algebraic_simplifier.h" +#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" #include #include @@ -42,13 +42,14 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/pass/hlo_pass_fix.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/primitive_util.h" #include "xla/service/hlo_creation_utils.h" -#include "xla/service/hlo_parser.h" #include "xla/service/host_memory_offload_annotations.h" #include "xla/service/layout_assignment.h" #include "xla/service/pattern_matcher.h" @@ -57,7 +58,6 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/window_util.h" @@ -71,12 +71,13 @@ namespace { using ::testing::ElementsAre; namespace m = match; -class AlgebraicSimplifierTest : public HloTestBase { +class AlgebraicSimplifierTest : public HloHardwareIndependentTestBase { public: AlgebraicSimplifierTest() - : HloTestBase(/*verifier_layout_sensitive=*/true, - /*allow_mixed_precision_in_hlo_verifier=*/true, - LayoutAssignment::InstructionCanChangeLayout) {} + : HloHardwareIndependentTestBase( + /*verifier_layout_sensitive=*/true, + /*allow_mixed_precision_in_hlo_verifier=*/true, + LayoutAssignment::InstructionCanChangeLayout) {} protected: AlgebraicSimplifierOptions default_options_; @@ -399,6 +400,81 @@ TEST_F(AlgebraicSimplifierTest, MultiplyBroadcastReassoc) { m::Parameter(1), m::Constant()))))); } +// Mul(Add(Conv(input, filter), bias), Broadcast(constant)) => Conv(input, +// Mul(filter, Broadcast(constant))), Mul(bias, Broadcast(constant))) +TEST_F(AlgebraicSimplifierTest, ReorderConvAddMul) { + const char* kModuleStr = R"( + HloModule m + test { + input = f32[5,4,4,1] parameter(0) + filter = f32[2,2,1,2] constant({{{{1.1, 1.2}}, {{2.1, 2.2}}}, + {{{3.1, 3.2}}, {{4.1, 4.2}}}}) + conv = f32[5,3,3,2] convolution(input, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f + bias = f32[5,3,3,2] parameter(1) + add = f32[5,3,3,2] add(conv, bias) + constant = f32[2] constant({1.0, 1.1}) + bcast = f32[5,3,3,2] broadcast(constant), dimensions={3} + ROOT multiply = f32[5,3,3,2] multiply(add, bcast) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + AlgebraicSimplifierOptions opts = default_options_; + opts.set_enable_conv_add_multiply_reorder(true); + ASSERT_TRUE(AlgebraicSimplifier(opts).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::AddAnyOrder( + m::Convolution( + m::Parameter(0), + m::Multiply(m::Constant(), m::Broadcast(m::Constant()))), + m::Multiply(m::Parameter(1), m::Broadcast(m::Constant()))))); +} + +TEST_F(AlgebraicSimplifierTest, DoNotReorderConvAddMulWhenDisabled) { + const char* kModuleStr = R"( + HloModule m + test { + input = f32[5,4,4,1] parameter(0) + filter = f32[2,2,1,2] constant({{{{1.1, 1.2}}, {{2.1, 2.2}}}, + {{{3.1, 3.2}}, {{4.1, 4.2}}}}) + conv = f32[5,3,3,2] convolution(input, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f + bias = f32[5,3,3,2] parameter(1) + add = f32[5,3,3,2] add(conv, bias) + constant = f32[2] constant({1.0, 1.1}) + bcast = f32[5,3,3,2] broadcast(constant), dimensions={3} + ROOT multiply = f32[5,3,3,2] multiply(add, bcast) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + AlgebraicSimplifierOptions opts = default_options_; + opts.set_enable_conv_add_multiply_reorder(false); + EXPECT_FALSE(AlgebraicSimplifier(opts).Run(m.get()).value()); +} + +TEST_F(AlgebraicSimplifierTest, + DoNotReorderConvAddMulWithUnmatchingOutputFeatureDimension) { + const char* kModuleStr = R"( + HloModule m + test { + input = f32[5,3,3,1] parameter(0) + filter = f32[2,2,1,2] constant({{{{1.1, 1.2}}, {{2.1, 2.2}}}, + {{{3.1, 3.2}}, {{4.1, 4.2}}}}) + conv = f32[5,2,2,2] convolution(input, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f + bias = f32[5,2,2,2] parameter(1) + add = f32[5,2,2,2] add(conv, bias) + constant = f32[2] constant({1.0, 1.1}) + bcast = f32[5,2,2,2] broadcast(constant), dimensions={2} + ROOT multiply = f32[5,2,2,2] multiply(add, bcast) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + AlgebraicSimplifierOptions opts = default_options_; + opts.set_enable_conv_add_multiply_reorder(true); + EXPECT_FALSE(AlgebraicSimplifier(opts).Run(m.get()).value()); +} + // A*C + B*C => (A+B)*C if C is a broadcast of a floating-point power of 2. TEST_F(AlgebraicSimplifierTest, FactorFpAdditionWithBroadcast) { const char* kModuleStr = R"( @@ -1343,6 +1419,41 @@ TEST_F(AlgebraicSimplifierTest, AddReassociateMergeBroadcastedConstants) { m::ConstantScalar(2.0)))))); } +TEST_F(AlgebraicSimplifierTest, ReplaceSubtractOfEqualOperandsWithZero) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + ROOT sub = f32[] subtract(p0, p0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + AlgebraicSimplifierOptions options; + options.set_enable_fast_math(true); + AlgebraicSimplifier simplifier(options); + ASSERT_TRUE(simplifier.Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::ConstantScalar(0.0))); +} + +TEST_F(AlgebraicSimplifierTest, + ReplaceSubtractOfEqualOperandsWithBroadcastZero) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[512,20] parameter(0) + ROOT sub = f32[512,20] subtract(p0, p0) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + AlgebraicSimplifierOptions options; + options.set_enable_fast_math(true); + AlgebraicSimplifier simplifier(options); + ASSERT_TRUE(simplifier.Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Broadcast())); +} + TEST_F(AlgebraicSimplifierTest, SubAddReassociateMergeConstants) { const char* kModuleStr = R"( HloModule m @@ -1362,6 +1473,23 @@ TEST_F(AlgebraicSimplifierTest, SubAddReassociateMergeConstants) { m::Parameter(0)))); } +TEST_F(AlgebraicSimplifierTest, ExpOfZero) { + const char* m = R"( + HloModule m + ENTRY main{ + %constant = bf16[] constant(0) + %broadcast = bf16[6,512]{1,0} broadcast(bf16[] %constant), dimensions={} + ROOT exponential.11278 = bf16[6,512] exponential(%broadcast) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(m)); + HloPassFix simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Broadcast(m::ConstantScalar(1.0)))); +} + TEST_F(AlgebraicSimplifierTest, SubAddReassociateMergeBroadcastedConstants) { const char* kModuleStr = R"( HloModule m @@ -5776,12 +5904,12 @@ struct ConvTestOptions { b.AddInstruction(HloInstruction::CreateConvolve( inferred_shape, input, kernel, feature_group_count, /*batch_group_count=*/1, window, dnums, - HloTestBase::DefaultPrecisionConfig(2))); + HloHardwareIndependentTestBase::DefaultPrecisionConfig(2))); return b.Build(); } }; -class ConvTestBase : public HloTestBase { +class ConvTestBase : public HloHardwareIndependentTestBase { public: std::unique_ptr Simplify(ConvTestOptions options) { auto module = CreateNewVerifiedModule(); @@ -6827,8 +6955,8 @@ TEST_F(AlgebraicSimplifierTest, TransposeOfNonCanonicalBatchDotCantSimplify) { TEST_F(AlgebraicSimplifierTest, DynamicSliceOfTranspose) { // This test is without layouts so we have to set the verifier to be layout // insensitive. - verifier_layout_sensitive_ = false; - instruction_can_change_layout_func_ = {}; + set_verifier_layout_sensitive(false); + set_instruction_can_change_layout_func({}); const char* hlo_string = R"( HloModule module @@ -8323,11 +8451,11 @@ TEST_F(AlgebraicSimplifierTest, GatherOfPad) { HloModule module ENTRY %entry { - reshape.17992 = f32[25165824,32]{1,0} parameter(0) - constant.31700 = f32[] constant(0) - pad.921 = f32[25165824,128]{1,0} pad(reshape.17992, constant.31700), padding=0_0x0_96 - reshape.40561 = s32[20447232,1]{1,0} parameter(1) - gather.100277 = f32[20447232,128]{1,0} gather(pad.921, reshape.40561), + par.0 = f32[25165824,32]{1,0} parameter(0) + constant.0 = f32[] constant(0) + pad = f32[25165824,128]{1,0} pad(par.0, constant.0), padding=0_0x0_96 + start_indices = s32[20447232,1]{1,0} parameter(1) + gather = f32[20447232,128]{1,0} gather(pad, start_indices), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,128} })"; @@ -8339,22 +8467,26 @@ ENTRY %entry { EXPECT_TRUE(simplifier.Run(module.get()).value()); VLOG(2) << "After rewrite \n" << module->ToString(); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, - GmockMatch(m::Pad(m::Gather(m::Parameter(0), m::Parameter(1)), - m::ConstantScalar(0)))); + const HloInstruction* gather_instr; + EXPECT_THAT(root, GmockMatch(m::Pad(m::Gather(&gather_instr, m::Parameter(0), + m::Parameter(1)), + m::ConstantScalar(0)))); + EXPECT_THAT(Cast(gather_instr)->gather_slice_sizes(), + ElementsAre(1, 32)); } -TEST_F(AlgebraicSimplifierTest, GatherOfPad2) { +TEST_F(AlgebraicSimplifierTest, GatherOfPadWithBatchDims) { const char* hlo_string = R"( HloModule module ENTRY %entry { - iota.3 = s32[4,1]{1,0} iota(), iota_dimension=0 - constant.36 = s32[] constant(0) - pad = s32[4,2]{1,0} pad(iota.3, constant.36), padding=0_0x0_1 - reshape.300 = s32[3,40,1]{2,1,0} parameter(0) - gather.363 = s32[3,40,2]{2,1,0} gather(pad, reshape.300), - offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, + iota = s32[4,1]{1,0} iota(), iota_dimension=0 + constant.0 = s32[] constant(0) + pad = s32[4,2]{1,0} pad(iota, constant.0), padding=0_0x0_1 + start_indices = s32[4,40,1]{2,1,0} parameter(0) + gather = s32[4,40,2]{2,1,0} gather(pad, start_indices), + offset_dims={2}, collapsed_slice_dims={}, start_index_map={0}, + operand_batching_dims={0}, start_indices_batching_dims={0}, index_vector_dim=2, slice_sizes={1,2} })"; TF_ASSERT_OK_AND_ASSIGN(auto module, @@ -8365,8 +8497,16 @@ ENTRY %entry { EXPECT_TRUE(simplifier.Run(module.get()).value()); VLOG(2) << "After rewrite \n" << module->ToString(); auto root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, GmockMatch(m::Pad(m::Gather(m::Iota(), m::Parameter(0)), - m::ConstantScalar(0)))); + const HloInstruction* gather_instr; + EXPECT_THAT(root, GmockMatch(m::Pad( + m::Gather(&gather_instr, m::Iota(), m::Parameter(0)), + m::ConstantScalar(0)))); + auto gather = Cast(gather_instr); + EXPECT_THAT(gather->gather_slice_sizes(), ElementsAre(1, 1)); + EXPECT_THAT(gather->gather_dimension_numbers().operand_batching_dims(), + ElementsAre(0)); + EXPECT_THAT(gather->gather_dimension_numbers().start_indices_batching_dims(), + ElementsAre(0)); } TEST_F(AlgebraicSimplifierTest, GatherOfReshapeOfPad) { @@ -8448,6 +8588,37 @@ ENTRY %entry { m::Parameter(1)))); } +TEST_F(AlgebraicSimplifierTest, GatherOfReshapeOfPad4) { + const char* hlo_string = R"( +HloModule module + +ENTRY %entry { + dot.165 = bf16[2048,8192]{1,0} parameter(0) + constant.16 = bf16[] constant(0) + reshape.60 = s32[16,1]{1,0} parameter(1) + pad.6 = bf16[4096,8192]{1,0} pad( + bf16[2048,8192]{1,0} %dot.165, bf16[] %constant.16), padding=0_2048x0_0 + reshape.170 = bf16[4096,16,512]{2,1,0} reshape(bf16[4096,8192]{1,0} %pad.6) + gather.175 = bf16[4096,16,512]{2,1,0} gather( + bf16[4096,16,512]{2,1,0} %reshape.170, s32[16,1]{1,0} %reshape.60), + offset_dims={0,2}, collapsed_slice_dims={1}, start_index_map={1}, + index_vector_dim=1, slice_sizes={4096,1,512} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).value()); + VLOG(0) << "After rewrite \n" << module->ToString(); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Pad(m::Gather(m::Reshape(), m::Parameter(1)), + m::ConstantScalar(0)))); + EXPECT_EQ(root->padding_config().dimensions(0).edge_padding_high(), 2048); + EXPECT_EQ(root->padding_config().dimensions(1).edge_padding_high(), 0); + EXPECT_EQ(root->padding_config().dimensions(2).edge_padding_high(), 0); +} + TEST_F(AlgebraicSimplifierTest, TupleReduceReshape) { const char* hlo_string = R"( HloModule module @@ -8544,8 +8715,8 @@ TEST_F(AlgebraicSimplifierTest, ZeroSizedReshapeWithoutLayout) { TEST_F(AlgebraicSimplifierTest, DividedByConstantInstructionWithoutLayout) { // This test is without layouts so we have to set the verifier to be layout // insensitive. - verifier_layout_sensitive_ = false; - instruction_can_change_layout_func_ = {}; + set_verifier_layout_sensitive(false); + set_instruction_can_change_layout_func({}); Shape shape = ShapeUtil::MakeShape(F32, {}); shape.clear_layout(); @@ -8985,6 +9156,21 @@ TEST_F(AlgebraicSimplifierTest, CompareIota) { GmockMatch(m::Broadcast(m::ConstantScalar(false)))); } +TEST_F(AlgebraicSimplifierTest, CompareAbsLtZeroBecomesFalse) { + // |x| < 0 -> false + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(R"( +m { + p = s32[5] parameter(0) + a = s32[5] abs(p) + z = s32[] constant(0) + b = s32[5] broadcast(z) + ROOT r = pred[5] compare(a, b), direction=LT +})")); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Broadcast(m::ConstantScalar(false)))); +} + TEST_F(AlgebraicSimplifierTest, CompareLtZero) { const char* kModuleStr = R"( HloModule m @@ -10026,6 +10212,110 @@ TEST_F(AlgebraicSimplifierTest, ScatterAddCombinedWeirdDnums2) { m::Concatenate(m::Parameter(3), m::Parameter(4))))); } +TEST_F(AlgebraicSimplifierTest, ScatterAddCombinedWithBatchDim) { + const char* hlo_string = R"( + HloModule m + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a, b) + } + test { + z = f32[] constant(0) + init = f32[3,100,4] broadcast(z), dimensions={} + index0 = s32[3,1,4,5] parameter(0) + index1 = s32[3,1,2,5] parameter(1) + update0 = f32[3,4,4,5] parameter(2) + update1 = f32[3,2,4,5] parameter(3) + scatter.0 = f32[3,100,4] scatter(init, index0, update0), + to_apply=apply, + update_window_dims={2}, + inserted_window_dims={1}, + scatter_dims_to_operand_dims={1}, + index_vector_dim=1, + input_batching_dims={0}, + scatter_indices_batching_dims={0} + scatter.1 = f32[3,100,4] scatter(init, index1, update1), + to_apply=apply, + update_window_dims={2}, + inserted_window_dims={1}, + scatter_dims_to_operand_dims={1}, + index_vector_dim=1, + input_batching_dims={0}, + scatter_indices_batching_dims={0} + ROOT add.1 = f32[3,100,4] add(scatter.0, scatter.1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + // Combine Scatters + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + // Simplify Add + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + const HloInstruction* concat1; + const HloInstruction* concat2; + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Scatter( + m::Broadcast(), + m::Concatenate(&concat1, m::Parameter(0), m::Parameter(1)), + m::Concatenate(&concat2, m::Parameter(2), m::Parameter(3))))); + EXPECT_EQ(Cast(concat1)->concatenate_dimension(), + 2); + EXPECT_EQ(Cast(concat2)->concatenate_dimension(), + 1); +} + +TEST_F(AlgebraicSimplifierTest, ScatterAddCombinedWithBatchDim2) { + const char* hlo_string = R"( + HloModule m + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a, b) + } + test { + z = f32[] constant(0) + init = f32[100,3,4] broadcast(z), dimensions={} + index0 = s32[4,3,5,1] parameter(0) + index1 = s32[2,3,5,1] parameter(1) + update0 = f32[4,3,4,5] parameter(2) + update1 = f32[2,3,4,5] parameter(3) + scatter.0 = f32[100,3,4] scatter(init, index0, update0), + to_apply=apply, + update_window_dims={2}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=3, + input_batching_dims={1}, + scatter_indices_batching_dims={1} + scatter.1 = f32[100,3,4] scatter(init, index1, update1), + to_apply=apply, + update_window_dims={2}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=3, + input_batching_dims={1}, + scatter_indices_batching_dims={1} + ROOT add.1 = f32[100,3,4] add(scatter.0, scatter.1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + // Combine Scatters + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + // Simplify Add + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + const HloInstruction* concat1; + const HloInstruction* concat2; + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Scatter( + m::Broadcast(), + m::Concatenate(&concat1, m::Parameter(0), m::Parameter(1)), + m::Concatenate(&concat2, m::Parameter(2), m::Parameter(3))))); + EXPECT_EQ(Cast(concat1)->concatenate_dimension(), + 0); + EXPECT_EQ(Cast(concat2)->concatenate_dimension(), + 0); +} + TEST_F(AlgebraicSimplifierTest, ScalarScatter) { const char* hlo_string = R"( HloModule m @@ -10561,6 +10851,18 @@ TEST_F(AlgebraicSimplifierTest, AbsEliminationSelMaxBcast) { m::Broadcast(m::ConstantScalar()))))); } +TEST_F(AlgebraicSimplifierTest, AbsEliminationIota) { + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(R"( + e { + i = s32[3,2] iota(), iota_dimension=0 + ROOT a = s32[3,2] abs(i) + } + )")); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Iota())); +} + TEST_F(AlgebraicSimplifierTest, SimplifyRedundantBitcastConvert) { const char* kModuleStr = R"( HloModule m @@ -10723,6 +11025,28 @@ TEST_F(AlgebraicSimplifierTest, CopyBitcastCopy) { GmockMatch(m::Bitcast(m::Copy(m::Parameter())))); } +TEST_F(AlgebraicSimplifierTest, CopyBitcastCopyDimSize1) { + const char* kModuleStr = R"( + HloModule m + + ENTRY test { + param.8 = f32[9, 1, 12]{2,1,0} parameter(0) + transpose.1 = f32[1,12,9]{1,0,2} transpose(param.8), dimensions={1,2,0} + copy.4 = f32[1,12,9]{2,1,0} copy(transpose.1) + bitcast.15 = f32[1,108]{1,0} bitcast(copy.4) + copy.1 = f32[1,108]{0,1} copy(bitcast.15) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + AlgebraicSimplifierOptions options; + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); + ASSERT_TRUE(simplifier.Run(m.get()).value()); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch(m::Bitcast(m::Bitcast(m::Copy(m::Bitcast(m::Parameter())))))); +} + TEST_F(AlgebraicSimplifierTest, CopyBitcastCopy2) { const char* kModuleStr = R"( HloModule m @@ -11779,8 +12103,8 @@ TEST_F(AlgebraicSimplifierTest, PreserveSharding) { // Move parameter from the LHS of a dot to the RHS. TEST_F(AlgebraicSimplifierTest, SwapDotOperands) { - verifier_layout_sensitive_ = false; - instruction_can_change_layout_func_ = {}; + set_verifier_layout_sensitive(false); + set_instruction_can_change_layout_func({}); const std::string hlo_string = R"( HloModule main @@ -12048,5 +12372,106 @@ TEST_F(AlgebraicSimplifierTest, GmockMatch(m::Slice(m::Parameter(0)))); } +// Bitcast of broadcast is not simplified if the layouts are different. +// TransposeBitcastOfBroadcast is a simplified example. +TEST_F(AlgebraicSimplifierTest, BitcastBroadcastDifferentLayout) { + const char* hlo_string = R"( + HloModule module + + ENTRY f { + %operand = f32[200001]{0:T(1024)} parameter(0) + %broadcast.91 = f32[200001,128]{1,0:T(8,128)} broadcast(f32[200001]{0:T(1024)} %operand), dimensions={0} + %bitcast.8 = f32[200001,128]{1,0:T(8)L(1024)} bitcast(f32[200001,128]{1,0:T(8,128)} %broadcast.91) + ROOT %add = f32[200001,128]{1,0:T(8)L(1024)} add(f32[200001,128]{1,0:T(8)L(1024)} %bitcast.8, f32[200001,128]{1,0:T(8)L(1024)} %bitcast.8) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + options.set_is_layout_sensitive(true); + AlgebraicSimplifier simplifier(options); + EXPECT_FALSE(simplifier.Run(module.get()).value()); +} + +TEST_F(AlgebraicSimplifierTest, TrivialMin) { + const char* kModuleStr = R"( + HloModule m + test { + a = f32[4,4] parameter(0) + ROOT %min = f32[4,4] minimum(%a, %a) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Parameter(0))); +} + +TEST_F(AlgebraicSimplifierTest, TrivialMax) { + const char* kModuleStr = R"( + HloModule m + test { + a = f32[4,4] parameter(0) + ROOT %min = f32[4,4] maximum(%a, %a) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Parameter(0))); +} + +TEST_F(AlgebraicSimplifierTest, PathologicalComplexity) { + // Without replacing min(x,x)->x, the algorithmic recursion complexity is + // O(2^n). + const char* kModuleStr = R"( + HloModule m + test { + a = s32[4,4] parameter(0) + b = s32[4,4] parameter(1) + %cmp0 = pred[4,4] compare(a, b), direction=GE + %c1 = f32[] constant(1) + %ones = f32[4,4] broadcast(f32[] %c1) + %c0 = f32[] constant(0) + %zeros = f32[4,4] broadcast(f32[] %c0) + %min = f32[4,4] minimum(%ones, %zeros) + %min0 = f32[4,4] minimum(%min, %min) + %min1 = f32[4,4] minimum(%min0, %min0) + %min2 = f32[4,4] minimum(%min1, %min1) + %min3 = f32[4,4] minimum(%min2, %min2) + %min4 = f32[4,4] minimum(%min3, %min3) + %min5 = f32[4,4] minimum(%min4, %min4) + %min6 = f32[4,4] minimum(%min5, %min5) + %min7 = f32[4,4] minimum(%min6, %min6) + %min8 = f32[4,4] minimum(%min7, %min7) + %min9 = f32[4,4] minimum(%min8, %min8) + %min10 = f32[4,4] minimum(%min9, %min9) + %min11 = f32[4,4] minimum(%min10, %min10) + %min12 = f32[4,4] minimum(%min11, %min11) + %min13 = f32[4,4] minimum(%min12, %min12) + %min14 = f32[4,4] minimum(%min13, %min13) + %min15 = f32[4,4] minimum(%min14, %min14) + %min16 = f32[4,4] minimum(%min15, %min15) + %min17 = f32[4,4] minimum(%min16, %min16) + %min18 = f32[4,4] minimum(%min17, %min17) + %min19 = f32[4,4] minimum(%min18, %min18) + %min20 = f32[4,4] minimum(%min19, %min19) + %min21 = f32[4,4] minimum(%min20, %min20) + %min22 = f32[4,4] minimum(%min21, %min21) + %min23 = f32[4,4] minimum(%min22, %min22) + %min24 = f32[4,4] minimum(%min23, %min23) + %min25 = f32[4,4] minimum(%min24, %min24) + %min26 = f32[4,4] minimum(%min25, %min25) + %min27 = f32[4,4] minimum(%min26, %min26) + %min28 = f32[4,4] minimum(%min27, %min27) + %min29 = f32[4,4] minimum(%min28, %min28) + ROOT %cmp1 = pred[4,4] compare(%min29, %zeros), direction=LT + })"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::Broadcast(m::Constant()))); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/all_reduce_folder.cc b/third_party/xla/xla/hlo/transforms/simplifiers/all_reduce_folder.cc similarity index 98% rename from third_party/xla/xla/service/all_reduce_folder.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/all_reduce_folder.cc index d616cc411844f5..49ba41a4cedcdd 100644 --- a/third_party/xla/xla/service/all_reduce_folder.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/all_reduce_folder.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/all_reduce_folder.h" +#include "xla/hlo/transforms/simplifiers/all_reduce_folder.h" #include #include @@ -35,6 +35,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/all_reduce_key.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/all_reduce_folder.h b/third_party/xla/xla/hlo/transforms/simplifiers/all_reduce_folder.h new file mode 100644 index 00000000000000..ed43d9be54838b --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/all_reduce_folder.h @@ -0,0 +1,50 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_ALL_REDUCE_FOLDER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_ALL_REDUCE_FOLDER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// A pass that folds an all-reduce feeding into another all-reduce by expanding +// the replica groups. As an example: +// +// ar0 = all-reduce(x) replica_groups={{0,1},{2,3},{4,5},{6,7}} +// ar1 = all-reduce(all-reduce0) replica_groups={{0,2},{1,3},{4,6},{5,7}} +// +// Can be combined into a single all-reduce: +// +// ar1 = all-reduce(x) replica_groups={{0,1,2,3},{4,5,6,7}} +// + +class AllReduceFolder : public HloModulePass { + public: + absl::string_view name() const override { return "all-reduce-folder"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_ALL_REDUCE_FOLDER_H_ diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/all_reduce_folder_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/all_reduce_folder_test.cc new file mode 100644 index 00000000000000..808b6e20f77704 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/all_reduce_folder_test.cc @@ -0,0 +1,207 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/transforms/simplifiers/all_reduce_folder.h" + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/utils/hlo_matchers.h" +#include "xla/test.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +namespace matcher = xla::testing::opcode_matchers; +using ::testing::HasSubstr; + +class AllReduceFolderTest : public HloHardwareIndependentTestBase {}; + +const char *k2AllReduce = R"( + HloModule m + + sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) + } + + ENTRY main { + p0 = f32[8] parameter(0) + ar0 = f32[8] all-reduce(p0), replica_groups=$group_0, to_apply=sum + ROOT ar1 = f32[8] all-reduce(ar0), replica_groups=$group_1, to_apply=sum + } + )"; + +size_t AllReduceCount(HloModule *module) { + return absl::c_count_if(module->entry_computation()->instructions(), + HloPredicateIsOp); +} + +void ExpectOneAllReduce(HloModule *module, + absl::string_view target_replica_groups) { + EXPECT_EQ(AllReduceCount(module), 1); + HloInstruction *root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, matcher::AllReduce(matcher::Parameter(0))); + EXPECT_THAT(root->ToString(), HasSubstr(target_replica_groups)); +} + +TEST_F(AllReduceFolderTest, Simple) { + TF_ASSERT_OK_AND_ASSIGN( + auto module, RunAndCheckHloRewrite(k2AllReduce, AllReduceFolder(), true, + {{"$group_0", "{{0,1},{2,3}}"}, + {"$group_1", "{{0,2},{1,3}}"}})); + ExpectOneAllReduce(module.get(), "replica_groups={{0,1,2,3}}"); +} + +// Same as Simple, but groups for the 2 all-reduce's are swapped. +TEST_F(AllReduceFolderTest, SimpleSwap) { + TF_ASSERT_OK_AND_ASSIGN( + auto module, RunAndCheckHloRewrite(k2AllReduce, AllReduceFolder(), true, + {{"$group_1", "{{0,1},{2,3}}"}, + {"$group_0", "{{0,2},{1,3}}"}})); + ExpectOneAllReduce(module.get(), "replica_groups={{0,1,2,3}}"); +} + +TEST_F(AllReduceFolderTest, BothEmptyReplicaGroups_NotTransformed) { + TF_ASSERT_OK(RunAndCheckHloRewrite(k2AllReduce, AllReduceFolder(), false, + {{"$group_0", "{}"}, {"$group_1", "{}"}})); +} + +TEST_F(AllReduceFolderTest, EmptyReplicaGroups_NotTransformed) { + TF_ASSERT_OK(RunAndCheckHloRewrite( + k2AllReduce, AllReduceFolder(), false, + {{"$group_0", "{}"}, {"$group_1", "{{0,2},{1,3}}"}})); +} + +TEST_F(AllReduceFolderTest, MismatchOtherProperties0_NotTransformed) { + absl::string_view hlo_string = R"( + HloModule m + + sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) + } + + ENTRY main { + p0 = f32[8] parameter(0) + ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, channel_id=1, to_apply=sum + ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3}}, to_apply=sum + } + )"; + TF_ASSERT_OK(RunAndCheckHloRewrite(hlo_string, AllReduceFolder(), false)); +} + +TEST_F(AllReduceFolderTest, MismatchOtherProperties1_NotTransformed) { + absl::string_view hlo_string = R"( + HloModule m + + sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) + } + + mul { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT mul = f32[] multiply(a, b) + } + + ENTRY main { + p0 = f32[8] parameter(0) + ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, to_apply=sum + ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3}}, to_apply=mul + } + )"; + TF_ASSERT_OK(RunAndCheckHloRewrite(hlo_string, AllReduceFolder(), false)); +} + +TEST_F(AllReduceFolderTest, NotFoldable_NotTransformed) { + absl::string_view hlo_string = R"( + HloModule m + + sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) + } + + ENTRY main { + p0 = f32[8] parameter(0) + ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, to_apply=sum + ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,1},{2,3}}, to_apply=sum + } + )"; + TF_ASSERT_OK(RunAndCheckHloRewrite(hlo_string, AllReduceFolder(), false)); +} + +TEST_F(AllReduceFolderTest, Foldable0) { + absl::string_view hlo_string = R"( + HloModule m + + sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) + } + + ENTRY main { + p0 = f32[8] parameter(0) + ar0 = f32[8] all-reduce(p0), replica_groups={{0,4},{1,5},{2,3},{6,7}}, to_apply=sum + ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,5},{4,1},{2,7},{3,6}}, to_apply=sum + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + RunAndCheckHloRewrite(hlo_string, AllReduceFolder())); + ExpectOneAllReduce(module.get(), "replica_groups={{0,1,4,5},{2,3,6,7}}"); +} + +// Verify that a chain of foldable all-reduce's folds in a single pass +// invocation. +TEST_F(AllReduceFolderTest, FoldableChain) { + absl::string_view hlo_string = R"( + HloModule m + + sum { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT add.2 = f32[] add(a, b) + } + + ENTRY main { + p0 = f32[8] parameter(0) + ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=sum + ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3},{4,6},{5,7}}, to_apply=sum + ROOT ar2 = f32[8] all-reduce(ar1), replica_groups={{0,4},{1,5},{2,6},{3,7}}, to_apply=sum + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + RunAndCheckHloRewrite(hlo_string, AllReduceFolder())); + ExpectOneAllReduce(module.get(), "replica_groups={{0,1,2,3,4,5,6,7}}"); +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/service/ar_crs_combiner.cc b/third_party/xla/xla/hlo/transforms/simplifiers/ar_crs_combiner.cc similarity index 99% rename from third_party/xla/xla/service/ar_crs_combiner.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/ar_crs_combiner.cc index a75acbc2b38498..b759bdb6d70422 100644 --- a/third_party/xla/xla/service/ar_crs_combiner.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/ar_crs_combiner.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/ar_crs_combiner.h" +#include "xla/hlo/transforms/simplifiers/ar_crs_combiner.h" #include #include @@ -28,6 +28,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/analysis/hlo_replication_analysis.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -37,11 +38,11 @@ limitations under the License. #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/call_graph.h" -#include "xla/service/hlo_replication_analysis.h" #include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/ar_crs_combiner.h b/third_party/xla/xla/hlo/transforms/simplifiers/ar_crs_combiner.h new file mode 100644 index 00000000000000..3b5ffc22bf83c8 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/ar_crs_combiner.h @@ -0,0 +1,200 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_AR_CRS_COMBINER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_AR_CRS_COMBINER_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/service/call_graph.h" + +namespace xla { + +// When the HLO graph contains a cross-module AllReduce (N separate AllReduce +// ops that share the same channel_id for MPMD partitioning, or 1 AllReduce op +// for SPMD partitioning), followed by some simple linear operations, followed +// by a cross-replica AllReduce (also known as cross-replica sum, or CRS), we +// can combine the CMAR and the CRAR, to use an efficient AllReduce +// implementation that fully utilizes the interconnect bandwidth. +// +// Such sequences appear in spatially partitioned models (either MPMD or SPMD). +// This pass must run right after spatial partitioning, when the code is still +// in a single HLO module. +// +// The steps are: +// 1) Find CMARs followed by simple ops followed by CRARs. +// 2) Group CMARs by channel_id. They must all be rewritten. For SPMD +// partitioning, there will only be a single CMAR for each channel_id. +// 3) Prove that the CMAR patterns in each core produce the same result. +// 4) Eliminate the CMAR, and if it feeds an addition/subtraction, divide the +// other operand by the number of spatial partitions. +// 5) Turn the CRAR into an all-core AllReduce. +// +// The pass also handles the case where multiple CMARs lead to the same CRAR, +// and eliminates all CMARs. This graph: +// +// Y +// | +// X CMAR_2 Z +// | \ / +// CMAR_1 + +// \ / +// + +// | +// CRAR +// +// gets rewritten to: +// +// Z num_partitions +// \ / +// Y div +// \ / +// X + +// \ / +// + +// | +// all-core AR +// +class ArCrsCombiner : public HloModulePass { + public: + ArCrsCombiner(int num_spatial_partitions, bool spmd_partition) + : num_spatial_partitions_(num_spatial_partitions), + spmd_partition_(spmd_partition) {} + absl::string_view name() const override { return "ar-crs-combiner"; } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + // Helper method to allow testing of InstructionsComputeSameValue. + static bool TestInstructionsComputeSameValue(HloInstruction* i1, + HloInstruction* i2); + + private: + // We used this struct because multiple ARs could be paired with the same CRS. + // In this case, we want to select the AR that is furthest from the CRS, + // because it makes it easier to eliminate all ARs during RewriteGraph. + struct ArCrsPair { + HloInstruction* ar; + HloInstruction* crs; + // The length of the path from AR to CRS in the HLO graph. + int64_t distance; + + ArCrsPair(HloInstruction* all_reduce, HloInstruction* cross_replica_sum, + int64_t dist) + : ar(all_reduce), crs(cross_replica_sum), distance(dist) {} + + std::string ToString() { + std::string result; + absl::StrAppend(&result, "("); + HloInstruction* instruction = ar; + while (instruction != crs) { + absl::StrAppend(&result, instruction->name(), ","); + instruction = instruction->users()[0]; + } + absl::StrAppend(&result, instruction->name(), + ")[id:", *(ar->channel_id()), ",dist:", distance, "]"); + return result; + } + }; + + std::optional MatchesArCrsPattern( + HloInstruction* instruction); + + // If the passed instruction is a while parameter, and the while body is only + // called by a single while instruction, return the while instruction. + std::optional WhileFromBodyParameter( + HloInstruction* instruction); + + // If the passed instruction is a parameter in one of the branch computations, + // and the branch body is only called by a single instruction, return the + // conditional instruction. + std::optional ConditionalFromBodyParameter( + HloInstruction* instruction); + + // Returns a vector of tuple instructions. + // If all instructions that flow to "instruction" are tuples, return them. + // Otherwise, return std::nullopt. Returns an empty vector if the instruction + // is already in the visited set. + std::optional> GetAllTuples( + HloInstruction* instruction, + absl::flat_hash_set* visited); + + // Checks whether two different elements in the same tuple compute the same + // value. + bool TupleElementsComputeSameValue( + HloInstruction* tuple_shaped_instruction, int64_t i1, int64_t i2, + absl::flat_hash_map* visited_pairs); + + // Returns whether the instructions i1 and i2 can be shown to evaluate to the + // same value. Handling WHILE requires recursion, which may cause us to visit + // the same instruction again. To avoid infinite loops, we pass a cache of + // visited instruction pairs. + bool InstructionsComputeSameValue( + HloInstruction* i1, HloInstruction* i2, + absl::flat_hash_map* visited_pairs); + + // Populates all_reduce_map_. + void GroupAllReducesById(HloModule* module); + + // Looks at each AllReduce group in all_reduce_map_, and keeps only the + // groups for which it's safe to move the AllReduce later in the HLO graph. + absl::Status KeepProvablyEqualInstructionGroupsMPMD(); + + // Same as above, but runs on SPMD partitioned module instead of MPMD. + absl::Status KeepProvablyEqualInstructionGroupsSPMD(HloModule* module); + + // Performs the graph rewrite that eliminates the early AllReduce and turns + // the later CRS into an AllReduce. + absl::StatusOr RewriteGraph(); + + int num_spatial_partitions_; + + // Run this combiner pass assuming the input module is an SPMD partitioned + // module (as opposed to MPMD partitioned). + // + // The main difference between the two w.r.t. this pass is that there would be + // N all-reduce ops for each channel in MPMD mode, whereas there is only 1 + // for each channel in SPMD mode. Also we use HloReplicationAnalysis for HLO + // equivalence check in SPMD mode. + bool spmd_partition_; + + // Map from all-reduce ids to the AR/CRS pairs. + absl::flat_hash_map> all_reduce_map_; + + // Map from a CRS instruction to the all-reduce ID of the AR paired with the + // CRS. Sometimes, several ARs in the code could be paired with the same CRS. + // We use this map to pick a single AR/CRS path to rewrite. + absl::flat_hash_map crs_reserved_map_; + + std::unique_ptr call_graph_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_AR_CRS_COMBINER_H_ diff --git a/third_party/xla/xla/service/ar_crs_combiner_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/ar_crs_combiner_test.cc similarity index 99% rename from third_party/xla/xla/service/ar_crs_combiner_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/ar_crs_combiner_test.cc index e18d81d20fa93e..e435089deaf071 100644 --- a/third_party/xla/xla/service/ar_crs_combiner_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/ar_crs_combiner_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/ar_crs_combiner.h" +#include "xla/hlo/transforms/simplifiers/ar_crs_combiner.h" #include #include @@ -24,8 +24,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" namespace xla { @@ -33,7 +34,7 @@ namespace { namespace op = xla::testing::opcode_matchers; -class ArCrsCombinerTest : public HloTestBase {}; +class ArCrsCombinerTest : public HloHardwareIndependentTestBase {}; TEST_F(ArCrsCombinerTest, SameValueTestBasecase) { const char* module_str = R"( diff --git a/third_party/xla/xla/service/batch_dot_simplification.cc b/third_party/xla/xla/hlo/transforms/simplifiers/batch_dot_simplification.cc similarity index 91% rename from third_party/xla/xla/service/batch_dot_simplification.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/batch_dot_simplification.cc index 3f22acf1930249..52fc16518a0302 100644 --- a/third_party/xla/xla/service/batch_dot_simplification.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/batch_dot_simplification.cc @@ -13,13 +13,24 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/batch_dot_simplification.h" +#include "xla/hlo/transforms/simplifiers/batch_dot_simplification.h" #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo_creation_utils.h" +#include "xla/shape.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { absl::StatusOr diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/batch_dot_simplification.h b/third_party/xla/xla/hlo/transforms/simplifiers/batch_dot_simplification.h new file mode 100644 index 00000000000000..4d82376d61de08 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/batch_dot_simplification.h @@ -0,0 +1,46 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_BATCH_DOT_SIMPLIFICATION_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_BATCH_DOT_SIMPLIFICATION_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { +// Simplifies batch dot operations. +// +// Normally these would live in the algebraic simplifier, but we want to run +// this to fixpoint (this pass reaches fixed point in one execution) before we +// run the DotDecomposer. +class BatchDotSimplification : public HloModulePass { + public: + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + absl::string_view name() const override { return "batch-dot-simplification"; } + + private: + absl::StatusOr ElideDegenerateBatchDimensionFromBatchDot( + HloInstruction* batch_dot); +}; +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_BATCH_DOT_SIMPLIFICATION_H_ diff --git a/third_party/xla/xla/service/batch_dot_simplification_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/batch_dot_simplification_test.cc similarity index 95% rename from third_party/xla/xla/service/batch_dot_simplification_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/batch_dot_simplification_test.cc index fd60e8f2a3ade3..ef2c10ac0b36dd 100644 --- a/third_party/xla/xla/service/batch_dot_simplification_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/batch_dot_simplification_test.cc @@ -13,18 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/batch_dot_simplification.h" +#include "xla/hlo/transforms/simplifiers/batch_dot_simplification.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/test.h" -#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { namespace op = xla::testing::opcode_matchers; -class BatchDotSimplificationTest : public HloTestBase {}; +class BatchDotSimplificationTest : public HloHardwareIndependentTestBase {}; TEST_F(BatchDotSimplificationTest, ElideSingleDegenerateBatchDotDim_VectorVector) { diff --git a/third_party/xla/xla/service/bfloat16_conversion_folding.cc b/third_party/xla/xla/hlo/transforms/simplifiers/bfloat16_conversion_folding.cc similarity index 98% rename from third_party/xla/xla/service/bfloat16_conversion_folding.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/bfloat16_conversion_folding.cc index 4785b03f07c947..bd6b036fca8436 100644 --- a/third_party/xla/xla/service/bfloat16_conversion_folding.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/bfloat16_conversion_folding.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/bfloat16_conversion_folding.h" +#include "xla/hlo/transforms/simplifiers/bfloat16_conversion_folding.h" #include #include @@ -21,12 +21,12 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/float_support.h" -#include "xla/service/hlo_dataflow_analysis.h" #include "xla/shape_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/bfloat16_conversion_folding.h b/third_party/xla/xla/hlo/transforms/simplifiers/bfloat16_conversion_folding.h new file mode 100644 index 00000000000000..b21e512d60bc83 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/bfloat16_conversion_folding.h @@ -0,0 +1,62 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_BFLOAT16_CONVERSION_FOLDING_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_BFLOAT16_CONVERSION_FOLDING_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/service/float_support.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// A pass which folds F32 <-> BF16 conversions to their operands or users, when +// it is supported by the backend. +// +// This pass follows the passed-in backend-specific BF16 support rules, but can +// introduce mixed precision in individual HLOs which breaks the assumption of +// some other HLO passes. So it should be used at the end of the HLO +// optimization pipeline followed by a DCE pass. If other passes are needed +// after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the +// changed made by this pass. +class BFloat16ConversionFolding : public HloModulePass { + public: + explicit BFloat16ConversionFolding(const FloatSupport* bfloat16_support) + : bfloat16_support_(bfloat16_support) { + DCHECK(bfloat16_support->LowPrecisionType() == BF16); + } + + ~BFloat16ConversionFolding() override = default; + absl::string_view name() const override { return "bfloat16-fold"; } + + // Run BF16 conversion folding on the given computation. Returns whether the + // computation was changed. + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + const FloatSupport* bfloat16_support_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_BFLOAT16_CONVERSION_FOLDING_H_ diff --git a/third_party/xla/xla/service/bfloat16_conversion_folding_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/bfloat16_conversion_folding_test.cc similarity index 97% rename from third_party/xla/xla/service/bfloat16_conversion_folding_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/bfloat16_conversion_folding_test.cc index 09646c524ad004..39754b1d0b7f22 100644 --- a/third_party/xla/xla/service/bfloat16_conversion_folding_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/bfloat16_conversion_folding_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/bfloat16_conversion_folding.h" +#include "xla/hlo/transforms/simplifiers/bfloat16_conversion_folding.h" #include #include @@ -24,12 +24,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/service/float_support.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" -#include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" namespace xla { @@ -72,11 +72,12 @@ class TestBFloat16Support : public FloatSupport { } }; -class BFloat16ConversionFoldingTest : public HloTestBase { +class BFloat16ConversionFoldingTest : public HloHardwareIndependentTestBase { protected: BFloat16ConversionFoldingTest() - : HloTestBase(/*verifier_layout_sensitive=*/false, - /*allow_mixed_precision_in_hlo_verifier=*/true) {} + : HloHardwareIndependentTestBase( + /*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} bool FoldConversions(HloModule* module) { TestBFloat16Support bfloat16_support_; diff --git a/third_party/xla/xla/service/broadcast_canonicalizer.cc b/third_party/xla/xla/hlo/transforms/simplifiers/broadcast_canonicalizer.cc similarity index 97% rename from third_party/xla/xla/service/broadcast_canonicalizer.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/broadcast_canonicalizer.cc index 4938f087bf711f..2fcac8f27f53a3 100644 --- a/third_party/xla/xla/service/broadcast_canonicalizer.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/broadcast_canonicalizer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/broadcast_canonicalizer.h" +#include "xla/hlo/transforms/simplifiers/broadcast_canonicalizer.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/broadcast_canonicalizer.h b/third_party/xla/xla/hlo/transforms/simplifiers/broadcast_canonicalizer.h new file mode 100644 index 00000000000000..08cc85b814933b --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/broadcast_canonicalizer.h @@ -0,0 +1,40 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_BROADCAST_CANONICALIZER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_BROADCAST_CANONICALIZER_H_ + +#include + +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// This transform ensures that dimensions in all broadcast operations are +// sorted. +class BroadcastCanonicalizer : public HloModulePass { + public: + explicit BroadcastCanonicalizer(); + + absl::string_view name() const override { return "broadcast_canonicalizer"; } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_BROADCAST_CANONICALIZER_H_ diff --git a/third_party/xla/xla/service/broadcast_canonicalizer_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/broadcast_canonicalizer_test.cc similarity index 91% rename from third_party/xla/xla/service/broadcast_canonicalizer_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/broadcast_canonicalizer_test.cc index 812c964ebc9345..a227b1324d28f3 100644 --- a/third_party/xla/xla/service/broadcast_canonicalizer_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/broadcast_canonicalizer_test.cc @@ -13,21 +13,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/broadcast_canonicalizer.h" +#include "xla/hlo/transforms/simplifiers/broadcast_canonicalizer.h" #include #include #include +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/test.h" #include "xla/test_helpers.h" -#include "xla/tests/filecheck.h" -#include "xla/tests/hlo_test_base.h" namespace xla { namespace { -class BroadcastCanonicalizerTest : public HloTestBase {}; +class BroadcastCanonicalizerTest : public HloHardwareIndependentTestBase {}; TEST_F(BroadcastCanonicalizerTest, ReshapeBroadcast) { const char* hlo = R"( diff --git a/third_party/xla/xla/service/conditional_canonicalizer.cc b/third_party/xla/xla/hlo/transforms/simplifiers/conditional_canonicalizer.cc similarity index 97% rename from third_party/xla/xla/service/conditional_canonicalizer.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/conditional_canonicalizer.cc index 7f13bb67ad7fc4..a511652e8eb5b5 100644 --- a/third_party/xla/xla/service/conditional_canonicalizer.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/conditional_canonicalizer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/conditional_canonicalizer.h" +#include "xla/hlo/transforms/simplifiers/conditional_canonicalizer.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/conditional_canonicalizer.h b/third_party/xla/xla/hlo/transforms/simplifiers/conditional_canonicalizer.h new file mode 100644 index 00000000000000..36de42ce428f51 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/conditional_canonicalizer.h @@ -0,0 +1,41 @@ +/* Copyright 2020 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_CONDITIONAL_CANONICALIZER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_CONDITIONAL_CANONICALIZER_H_ + +#include + +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// Canonicalize output of conditionals, make non-tuple outputs into tuple with +// single element output. After this pass, all conditional instructions have +// tuple outputs. +class ConditionalCanonicalizer : public HloModulePass { + public: + absl::string_view name() const override { + return "conditional-canonicalizer"; + } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_CONDITIONAL_CANONICALIZER_H_ diff --git a/third_party/xla/xla/service/conditional_canonicalizer_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/conditional_canonicalizer_test.cc similarity index 88% rename from third_party/xla/xla/service/conditional_canonicalizer_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/conditional_canonicalizer_test.cc index beba61a5a67832..54d9b3c163809f 100644 --- a/third_party/xla/xla/service/conditional_canonicalizer_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/conditional_canonicalizer_test.cc @@ -13,18 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/conditional_canonicalizer.h" +#include "xla/hlo/transforms/simplifiers/conditional_canonicalizer.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/service/hlo_parser.h" #include "xla/shape_util.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" -#include "xla/tests/test_utils.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/util.h" @@ -34,7 +33,7 @@ namespace { namespace op = xla::testing::opcode_matchers; -class ConditionalCanonicalizerTest : public HloTestBase { +class ConditionalCanonicalizerTest : public HloHardwareIndependentTestBase { protected: ConditionalCanonicalizerTest() {} }; diff --git a/third_party/xla/xla/service/convert_mover.cc b/third_party/xla/xla/hlo/transforms/simplifiers/convert_mover.cc similarity index 99% rename from third_party/xla/xla/service/convert_mover.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/convert_mover.cc index 0baafbfd4fbb1c..adc4c6d62e7c76 100644 --- a/third_party/xla/xla/service/convert_mover.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/convert_mover.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/convert_mover.h" +#include "xla/hlo/transforms/simplifiers/convert_mover.h" #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/convert_mover.h b/third_party/xla/xla/hlo/transforms/simplifiers/convert_mover.h new file mode 100644 index 00000000000000..43732be420ea21 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/convert_mover.h @@ -0,0 +1,54 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_CONVERT_MOVER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_CONVERT_MOVER_H_ + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// Moves narrowing conversions up the graph and widening conversions down the +// graph, when we can do so with no effect on numerics. Motivations: +// +// - It's preferable to spend more of our time in lower precision and less of +// our time in higher precision. +// +// - Moving these converts exposes optimization opportunities. For example, in +// reshape(convert-big-to-small(reshape(convert-small-to-big(x)))), we can +// commute one of the converts with one of the reshapes. This leaves us with +// convert(convert(reshape(reshape))), which can probably be simplified +// further by algsimp. +class ConvertMover : public HloModulePass { + public: + ConvertMover() = default; + + absl::string_view name() const override { return "convert-mover"; } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_CONVERT_MOVER_H_ diff --git a/third_party/xla/xla/service/convert_mover_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/convert_mover_test.cc similarity index 96% rename from third_party/xla/xla/service/convert_mover_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/convert_mover_test.cc index d5cb059f3aced3..1cbe5c3415a881 100644 --- a/third_party/xla/xla/service/convert_mover_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/convert_mover_test.cc @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/convert_mover.h" +#include "xla/hlo/transforms/simplifiers/convert_mover.h" #include #include #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -30,11 +30,12 @@ namespace { namespace m = ::xla::match; -class ConvertMoverTest : public HloTestBase { +class ConvertMoverTest : public HloHardwareIndependentTestBase { public: ConvertMoverTest() - : HloTestBase(/*verifier_layout_sensitive=*/false, - /*allow_mixed_precision_in_hlo_verifier=*/false) {} + : HloHardwareIndependentTestBase( + /*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/false) {} }; template diff --git a/third_party/xla/xla/service/convert_operand_folding.cc b/third_party/xla/xla/hlo/transforms/simplifiers/convert_operand_folder.cc similarity index 98% rename from third_party/xla/xla/service/convert_operand_folding.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/convert_operand_folder.cc index 10336583d06a88..bee24f707f3b5c 100644 --- a/third_party/xla/xla/service/convert_operand_folding.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/convert_operand_folder.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/convert_operand_folding.h" +#include "xla/hlo/transforms/simplifiers/convert_operand_folder.h" #include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_instruction.h" diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/convert_operand_folder.h b/third_party/xla/xla/hlo/transforms/simplifiers/convert_operand_folder.h new file mode 100644 index 00000000000000..b1010c37fa75b2 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/convert_operand_folder.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_CONVERT_OPERAND_FOLDER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_CONVERT_OPERAND_FOLDER_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" + +namespace xla { + +// Folds Convert operands to wider types into instructions that supports wider +// result accumulation than the shape inference type. +// +// e.g. s32 hlo(s32 convert(s8), s32 convert(s8)) -> s32 hlo(s8, s8) +class ConvertOperandFolding : public OpExpanderPass { + public: + absl::string_view name() const override { return "convert_operand_folding"; } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + absl::StatusOr ExpandInstruction( + HloInstruction* instruction) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_CONVERT_OPERAND_FOLDER_H_ diff --git a/third_party/xla/xla/service/convert_operand_folding_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/convert_operand_folder_test.cc similarity index 97% rename from third_party/xla/xla/service/convert_operand_folding_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/convert_operand_folder_test.cc index 69343e49754dec..063b2d4130555f 100644 --- a/third_party/xla/xla/service/convert_operand_folding_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/convert_operand_folder_test.cc @@ -13,21 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/convert_operand_folding.h" +#include "xla/hlo/transforms/simplifiers/convert_operand_folder.h" #include #include #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" namespace xla { namespace { namespace op = ::xla::testing::opcode_matchers; -using ConvertOperandFoldingTest = HloTestBase; +using ConvertOperandFoldingTest = HloHardwareIndependentTestBase; TEST_F(ConvertOperandFoldingTest, IntegralUpcastConvertFolded) { absl::string_view module_string = R"( diff --git a/third_party/xla/xla/service/convolution_group_converter.cc b/third_party/xla/xla/hlo/transforms/simplifiers/convolution_group_converter.cc similarity index 99% rename from third_party/xla/xla/service/convolution_group_converter.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/convolution_group_converter.cc index a3be4a0dddf438..d8f9fcd50d7e09 100644 --- a/third_party/xla/xla/service/convolution_group_converter.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/convolution_group_converter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/convolution_group_converter.h" +#include "xla/hlo/transforms/simplifiers/convolution_group_converter.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/convolution_group_converter.h b/third_party/xla/xla/hlo/transforms/simplifiers/convolution_group_converter.h new file mode 100644 index 00000000000000..16d39ad72382d5 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/convolution_group_converter.h @@ -0,0 +1,69 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_CONVOLUTION_GROUP_CONVERTER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_CONVOLUTION_GROUP_CONVERTER_H_ + +#include + +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/status_macros.h" + +namespace xla { + +// A pass which rewrites convolutions with feature_group_count > 1 into +// convolutions with feature_group_count = 1. +class ConvolutionGroupConverter : public HloModulePass { + public: + ConvolutionGroupConverter(std::function should_expand, + std::function is_cost_viable, + bool convert_batch_groups_only, + bool filter_expansion = true) + : should_expand_(should_expand), + is_cost_viable_(is_cost_viable), + convert_batch_groups_only_(convert_batch_groups_only), + filter_expansion_(filter_expansion) {} + + absl::string_view name() const override { + return "convolution-group-converter"; + } + + // Run convolution rewriting on the given computation. Returns whether the + // computation was changed. + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + // Predicate that determines whether this pass should rewrite a given + // convolution. + std::function should_expand_; + + // Lambda containing cost model that decides whether to expand + // batch_group_count. + std::function is_cost_viable_; + + // Decides whether to convert batch groups or feature groups. + bool convert_batch_groups_only_; + + // Tells whether filter expansion is required. + bool filter_expansion_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_CONVOLUTION_GROUP_CONVERTER_H_ diff --git a/third_party/xla/xla/service/convolution_group_converter_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/convolution_group_converter_test.cc similarity index 97% rename from third_party/xla/xla/service/convolution_group_converter_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/convolution_group_converter_test.cc index 16f7dcbd49acf9..53be45e88f1201 100644 --- a/third_party/xla/xla/service/convolution_group_converter_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/convolution_group_converter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/convolution_group_converter.h" +#include "xla/hlo/transforms/simplifiers/convolution_group_converter.h" #include #include @@ -21,15 +21,15 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/test.h" -#include "xla/tests/hlo_test_base.h" #include "xla/types.h" namespace xla { namespace { -using ConvolutionGroupConverterTest = HloTestBase; +using ConvolutionGroupConverterTest = HloHardwareIndependentTestBase; namespace op = testing::opcode_matchers; TEST_F(ConvolutionGroupConverterTest, diff --git a/third_party/xla/xla/service/dot_dimension_merger.cc b/third_party/xla/xla/hlo/transforms/simplifiers/dot_dimension_merger.cc similarity index 99% rename from third_party/xla/xla/service/dot_dimension_merger.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/dot_dimension_merger.cc index a8b881187dcaa7..c05ef071f2e592 100644 --- a/third_party/xla/xla/service/dot_dimension_merger.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/dot_dimension_merger.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/dot_dimension_merger.h" +#include "xla/hlo/transforms/simplifiers/dot_dimension_merger.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/dot_dimension_merger.h b/third_party/xla/xla/hlo/transforms/simplifiers/dot_dimension_merger.h new file mode 100644 index 00000000000000..52bf94d24154a6 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/dot_dimension_merger.h @@ -0,0 +1,42 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_DOT_DIMENSION_MERGER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_DOT_DIMENSION_MERGER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// Merge consecutive batch dimensions of a dot() by inserting reshapes. +class DotDimensionMerger : public HloModulePass { + public: + absl::string_view name() const override { return "dot_dimension_merger"; } + + // Run the pass on computations in 'module'. + // Return whether the 'module' was changed. + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_DOT_DIMENSION_MERGER_H_ diff --git a/third_party/xla/xla/service/dot_dimension_merger_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/dot_dimension_merger_test.cc similarity index 95% rename from third_party/xla/xla/service/dot_dimension_merger_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/dot_dimension_merger_test.cc index a41b904f5259a9..1894c0c0708188 100644 --- a/third_party/xla/xla/service/dot_dimension_merger_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/dot_dimension_merger_test.cc @@ -13,20 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/dot_dimension_merger.h" +#include "xla/hlo/transforms/simplifiers/dot_dimension_merger.h" #include #include #include -#include "xla/service/hlo_parser.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "tsl/platform/statusor.h" namespace xla { namespace { -using DotDimensionMergerTest = HloTestBase; +using DotDimensionMergerTest = HloHardwareIndependentTestBase; TEST_F(DotDimensionMergerTest, MergeConsecutiveBatchDimensions) { const std::string kHloText = R"( diff --git a/third_party/xla/xla/service/dot_merger.cc b/third_party/xla/xla/hlo/transforms/simplifiers/dot_merger.cc similarity index 97% rename from third_party/xla/xla/service/dot_merger.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/dot_merger.cc index f0da83fd52d206..812c46112c62a1 100644 --- a/third_party/xla/xla/service/dot_merger.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/dot_merger.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/dot_merger.h" +#include "xla/hlo/transforms/simplifiers/dot_merger.h" #include #include @@ -278,8 +278,10 @@ absl::StatusOr TryMergeSameOperand(HloInstruction* a, return new_dot; } -absl::StatusOr MergeDots(HloComputation* comp, - int64_t max_size_to_merge) { +absl::StatusOr MergeDots(HloComputation* comp, int64_t max_size_to_merge, + std::function + can_merge) { auto is_merge_candidate = [&](HloInstruction* instr) { int64_t bytes = ShapeUtil::ByteSizeOfElements(instr->shape()); for (const HloInstruction* operand : instr->operands()) { @@ -395,7 +397,7 @@ absl::StatusOr MergeDots(HloComputation* comp, (!is_merge_candidate(a) && !is_merge_candidate(b)) || // Perform reachability checks last since they can be expensive. graph.IsReachableNonConst(a_id, b_id) || - graph.IsReachableNonConst(b_id, a_id)) { + graph.IsReachableNonConst(b_id, a_id) || !can_merge(a, b)) { continue; } @@ -437,7 +439,7 @@ absl::StatusOr DotMerger::Run( for (HloComputation* comp : module->MakeNonfusionComputations(execution_threads)) { TF_ASSIGN_OR_RETURN(bool changed_computation, - MergeDots(comp, max_size_to_merge_)); + MergeDots(comp, max_size_to_merge_, can_merge_)); changed |= changed_computation; } return changed; diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/dot_merger.h b/third_party/xla/xla/hlo/transforms/simplifiers/dot_merger.h new file mode 100644 index 00000000000000..bde8d3f9df6140 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/dot_merger.h @@ -0,0 +1,81 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_DOT_MERGER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_DOT_MERGER_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// Merges dots that share an operand. Transforms +// +// x = dot(a, b) +// y = dot(a, c) +// +// into +// +// z = dot(a, concat(b, c)) +// x = slice(z) +// y = slice(z). +// +// This requires that x and y are independent -- that is, x does not +// transitively depend on y, and y does not transitively depend on x. +// +// This is a good transformation if the merged dot runs faster than the original +// dots. On the other hand, merging the dots results in a single result buffer +// z whose live range is the union of x and y's live ranges, so can lead to +// increased memory pressure. You probably only want to do this optimization on +// "small" dots which cannot saturate your device when run alone. +// +// We thus allow backends to set a max size above which an op will not be +// merged. The input+output bytes of at least one dot must be below the +// threshold otherwise we won't merge. (We don't require that both dots be +// below the threshold because backends likely want to allow merging a "small" +// dot into a "large" dot while preventing two large dots from being merged.) +// +// Will skip gemms with more than one non-contracting dimension in the dot +// operands to be concatenated. +class DotMerger : public HloModulePass { + public: + explicit DotMerger( + int64_t max_size_to_merge, + std::function + can_merge = [](const HloInstruction* dot_a, + const HloInstruction* dot_b) -> bool { return true; }) + : max_size_to_merge_(max_size_to_merge), can_merge_(can_merge) {} + + absl::string_view name() const override { return "dot-merger"; } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + int64_t max_size_to_merge_; + // Predicate function for backend-specific compatibility check. + std::function + can_merge_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_DOT_MERGER_H_ diff --git a/third_party/xla/xla/service/dot_merger_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/dot_merger_test.cc similarity index 95% rename from third_party/xla/xla/service/dot_merger_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/dot_merger_test.cc index 786970e7904f96..e352f74636beca 100644 --- a/third_party/xla/xla/service/dot_merger_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/dot_merger_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/dot_merger.h" +#include "xla/hlo/transforms/simplifiers/dot_merger.h" #include #include @@ -23,12 +23,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/algebraic_simplifier.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" @@ -37,11 +37,12 @@ namespace { namespace m = ::xla::match; -class DotMergerTest : public HloTestBase { +class DotMergerTest : public HloHardwareIndependentTestBase { public: DotMergerTest() - : HloTestBase(/*verifier_layout_sensitive=*/false, - /*allow_mixed_precision_in_hlo_verifier=*/false) {} + : HloHardwareIndependentTestBase( + /*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/false) {} }; TEST_F(DotMergerTest, MergeRHS) { @@ -813,5 +814,30 @@ TEST_F(DotMergerTest, MergeSparseDotsDifferentMetadata) { EXPECT_FALSE(changed); } +TEST_F(DotMergerTest, NoMergeWithFalseCompatibility) { + absl::string_view module_string = R"( + HloModule module + + ENTRY main { + lhs0 = f32[2,4,100,200] parameter(0) + lhs1 = f32[2,4,300,200] parameter(1) + rhs = f32[2,4,200, 50] parameter(2) + dot0 = f32[2,4,100, 50] dot(lhs0, rhs), lhs_batch_dims={0,1}, rhs_batch_dims={0,1}, + lhs_contracting_dims={3}, rhs_contracting_dims={2} + dot1 = f32[2,4,300, 50] dot(lhs1, rhs), lhs_batch_dims={0,1}, rhs_batch_dims={0,1}, + lhs_contracting_dims={3}, rhs_contracting_dims={2} + ROOT tuple = (f32[2,4,100,50], f32[2,4,300,50]) tuple(dot0, dot1) + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_string)); + std::function + can_merge = [&](const HloInstruction* dot_a, + const HloInstruction* dot_b) -> bool { return false; }; + DotMerger pass(/*max_size_to_merge=*/std::numeric_limits::max(), + can_merge); + TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); + EXPECT_FALSE(changed); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/dynamic_dimension_simplifier.cc b/third_party/xla/xla/hlo/transforms/simplifiers/dynamic_dimension_simplifier.cc similarity index 98% rename from third_party/xla/xla/service/dynamic_dimension_simplifier.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/dynamic_dimension_simplifier.cc index 003a499d28ae4c..c220ae1df608a1 100644 --- a/third_party/xla/xla/service/dynamic_dimension_simplifier.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/dynamic_dimension_simplifier.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/dynamic_dimension_simplifier.h" +#include "xla/hlo/transforms/simplifiers/dynamic_dimension_simplifier.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/dynamic_dimension_simplifier.h b/third_party/xla/xla/hlo/transforms/simplifiers/dynamic_dimension_simplifier.h new file mode 100644 index 00000000000000..171e9c02f59b93 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/dynamic_dimension_simplifier.h @@ -0,0 +1,40 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_DYNAMIC_DIMENSION_SIMPLIFIER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_DYNAMIC_DIMENSION_SIMPLIFIER_H_ + +#include + +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// This pass simplifies operations on dynamic dimension sizes so that it can be +// easily analyzed by later passes. +class DynamicDimensionSimplifier : public HloModulePass { + public: + absl::string_view name() const override { + return "dynamic-dimension-simplifier"; + } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_DYNAMIC_DIMENSION_SIMPLIFIER_H_ diff --git a/third_party/xla/xla/service/dynamic_dimension_simplifier_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/dynamic_dimension_simplifier_test.cc similarity index 96% rename from third_party/xla/xla/service/dynamic_dimension_simplifier_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/dynamic_dimension_simplifier_test.cc index 3b71fc4bf128ad..6112503842fdec 100644 --- a/third_party/xla/xla/service/dynamic_dimension_simplifier_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/dynamic_dimension_simplifier_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/dynamic_dimension_simplifier.h" +#include "xla/hlo/transforms/simplifiers/dynamic_dimension_simplifier.h" #include #include @@ -25,17 +25,17 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/service/hlo_creation_utils.h" -#include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/service/shape_inference.h" #include "xla/shape_util.h" #include "xla/test.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/window_util.h" @@ -46,7 +46,7 @@ namespace { namespace m = match; -class DynamicDimensionSimplifierTest : public HloTestBase {}; +class DynamicDimensionSimplifierTest : public HloHardwareIndependentTestBase {}; TEST_F(DynamicDimensionSimplifierTest, ForwardConcat) { const char* kModuleStr = R"( diff --git a/third_party/xla/xla/service/flatten_call_graph.cc b/third_party/xla/xla/hlo/transforms/simplifiers/flatten_call_graph.cc similarity index 99% rename from third_party/xla/xla/service/flatten_call_graph.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/flatten_call_graph.cc index c40e91905eb192..68a017b16050d6 100644 --- a/third_party/xla/xla/service/flatten_call_graph.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/flatten_call_graph.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/flatten_call_graph.h" +#include "xla/hlo/transforms/simplifiers/flatten_call_graph.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/flatten_call_graph.h b/third_party/xla/xla/hlo/transforms/simplifiers/flatten_call_graph.h new file mode 100644 index 00000000000000..6d35a483e6e4aa --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/flatten_call_graph.h @@ -0,0 +1,43 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Flatten the call graph for an HLO module into a tree. + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_FLATTEN_CALL_GRAPH_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_FLATTEN_CALL_GRAPH_H_ + +#include "absl/status/statusor.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// Flattening associates each call site with a unique computation (for +// sequential calling contexts) This simplifies buffer assignment and +// points-to analysis (see b/36865746 for details). +class FlattenCallGraph : public HloModulePass { + public: + absl::string_view name() const override { return "flatten-call-graph"; } + + // Duplicates computations called from multiple call- or while-nodes to + // flatten the call graph. + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_FLATTEN_CALL_GRAPH_H_ diff --git a/third_party/xla/xla/service/flatten_call_graph_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/flatten_call_graph_test.cc similarity index 98% rename from third_party/xla/xla/service/flatten_call_graph_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/flatten_call_graph_test.cc index 70cba29463b7a2..a5653857fe3ff4 100644 --- a/third_party/xla/xla/service/flatten_call_graph_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/flatten_call_graph_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/flatten_call_graph.h" +#include "xla/hlo/transforms/simplifiers/flatten_call_graph.h" #include #include @@ -23,12 +23,12 @@ limitations under the License. #include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/literal_util.h" #include "xla/service/call_graph.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" -#include "xla/tests/hlo_test_base.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -36,7 +36,7 @@ limitations under the License. namespace xla { namespace { -class FlattenCallGraphTest : public HloTestBase { +class FlattenCallGraphTest : public HloHardwareIndependentTestBase { protected: // Build and return a trivial computation taking and returning a scalar. std::unique_ptr MakeScalarComputation() { diff --git a/third_party/xla/xla/service/float_normalization.cc b/third_party/xla/xla/hlo/transforms/simplifiers/float_normalization.cc similarity index 98% rename from third_party/xla/xla/service/float_normalization.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/float_normalization.cc index bb6ff30a335caa..b6d8a532054502 100644 --- a/third_party/xla/xla/service/float_normalization.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/float_normalization.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/float_normalization.h" +#include "xla/hlo/transforms/simplifiers/float_normalization.h" #include #include @@ -30,11 +30,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/primitive_util.h" #include "xla/service/call_graph.h" #include "xla/service/float_support.h" -#include "xla/service/hlo_dce.h" -#include "xla/service/tuple_simplifier.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" @@ -199,6 +199,10 @@ absl::Status FloatNormalizationVisitor::ChangeOutputTypeThenInsertConvertBack( } bool is_root = computation->root_instruction() == hlo; + bool allow_excess_precision = computation->parent() + ->config() + .debug_options() + .xla_allow_excess_precision(); // If we are rewriting the root instruction of the entry computation, we need // to save and restore original input output alias config. @@ -252,7 +256,7 @@ absl::Status FloatNormalizationVisitor::ChangeOutputTypeThenInsertConvertBack( // Tuple [fp32, bf16] // So we should keep the 'Convert' and replace it after all of the other // users has been replaced. - if (user->opcode() == HloOpcode::kConvert && + if (allow_excess_precision && user->opcode() == HloOpcode::kConvert && user->shape().element_type() == to && to == HighPrecisionType() && from == LowPrecisionType()) { conversions_to_simplify.emplace_back(user); diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/float_normalization.h b/third_party/xla/xla/hlo/transforms/simplifiers/float_normalization.h new file mode 100644 index 00000000000000..2f3807cb5815b3 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/float_normalization.h @@ -0,0 +1,105 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_FLOAT_NORMALIZATION_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_FLOAT_NORMALIZATION_H_ + +#include + +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/service/float_support.h" + +namespace xla { + +// A pass which adds type conversions (e.g. F32 <-> BF16) for HLO instructions +// that do not support low-precision input/output or mixed precision, according +// to the passed-in backend-specific FloatSupport instance. +class FloatNormalization : public HloModulePass { + public: + explicit FloatNormalization(const FloatSupport* float_support) + : float_support_(float_support), + name_("float-normalization-" + + primitive_util::LowercasePrimitiveTypeName( + float_support_->LowPrecisionType())) {} + + ~FloatNormalization() override = default; + absl::string_view name() const override { return name_; } + + // Run float normalization on the given computation. Returns whether the + // computation was changed. + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + const FloatSupport* float_support_; + std::string name_; +}; + +// A pass that unconditionally removes the mixed F32/BF16 uses in HLO +// instructions (excluding convert) by adding F32 <-> BF16 conversions. Unlike +// FloatNormalization, this pass does not use a backend-specific +// FloatSupport, and does not change HLOs that have BF16 data if they do not +// use mixed precision; it removes mixed precision even if the backend supports +// it. This pass is used to make the HLO module valid for other HLO passes which +// do not support mixed precision. Currently, this pass is only used by the +// Despecializer, not by our normal compilation flow on TPU. +class BFloat16MixedPrecisionRemoval : public HloModulePass { + public: + BFloat16MixedPrecisionRemoval() = default; + + ~BFloat16MixedPrecisionRemoval() override = default; + + absl::string_view name() const override { + return "bf16-mixed-precision-removal"; + } + + // Run mixed precision removal on the given computation. Returns whether the + // computation was changed. + using HloPassInterface::Run; + absl::StatusOr Run(HloModule* module, + const absl::flat_hash_set& + execution_threads) override { + FloatNormalization normalization(&no_mixed_precision_support_); + return normalization.Run(module, execution_threads); + } + + private: + class BFloat16SupportForMixedPrecisionRemoval : public FloatSupport { + public: + BFloat16SupportForMixedPrecisionRemoval() : FloatSupport(BF16) {} + + ~BFloat16SupportForMixedPrecisionRemoval() override = default; + + bool SupportsLowPrecisionOperand(const HloInstruction& hlo, + int64_t operand_index) const override { + return true; + } + + bool SupportsLowPrecisionOutput(const HloInstruction& hlo) const override { + return true; + } + + bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { + return false; + } + } no_mixed_precision_support_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_FLOAT_NORMALIZATION_H_ diff --git a/third_party/xla/xla/service/float_normalization_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/float_normalization_test.cc similarity index 93% rename from third_party/xla/xla/service/float_normalization_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/float_normalization_test.cc index a140d2e933af9a..86ec889abc6527 100644 --- a/third_party/xla/xla/service/float_normalization_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/float_normalization_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/float_normalization.h" +#include "xla/hlo/transforms/simplifiers/float_normalization.h" #include #include @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" @@ -28,18 +29,22 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/service/float_support.h" #include "xla/service/hlo_creation_utils.h" #include "xla/service/hlo_verifier.h" +#include "xla/service/pattern_matcher.h" +#include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" -#include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" namespace xla { +namespace { +namespace m = match; class TestFloatSupport : public FloatSupport { public: @@ -118,11 +123,12 @@ class TestFloatNoComputeSupport : public FloatSupport { } }; -class FloatNormalizationTest : public HloTestBase { +class FloatNormalizationTest : public HloHardwareIndependentTestBase { protected: FloatNormalizationTest() - : HloTestBase(/*verifier_layout_sensitive=*/false, - /*allow_mixed_precision_in_hlo_verifier=*/true) {} + : HloHardwareIndependentTestBase( + /*verifier_layout_sensitive=*/false, + /*allow_mixed_precision_in_hlo_verifier=*/true) {} bool Normalize(HloModule* module, PrimitiveType low_precision_type = BF16, PrimitiveType high_precision_type = F32) { @@ -144,7 +150,7 @@ class FloatNormalizationF8Test public ::testing::WithParamInterface {}; INSTANTIATE_TEST_SUITE_P(FloatNormalizationF8Suite, FloatNormalizationF8Test, - ::testing::Values(F8E5M2)); + ::testing::Values(F8E3M4, F8E4M3, F8E5M2)); TEST_F(FloatNormalizationTest, NoopIfSupported) { auto builder = HloComputation::Builder(TestName()); @@ -793,4 +799,48 @@ TEST_F(FloatNormalizationTest, KeepEntryInputOutputAlias) { HloInputOutputAliasConfig::AliasKind::kMustAlias); } +TEST_F(FloatNormalizationTest, AllowExcessPrecisionTrue) { + const std::string hlo_text = R"( + HloModule dot_with_convert + + ENTRY main { + Arg_0 = bf16[2,2]{1,0} parameter(0) + dot = bf16[2,2]{1,0} dot(Arg_0, Arg_0), lhs_contracting_dims={1}, rhs_contracting_dims={0} + convert = f32[2,2]{1,0} convert(dot) + ROOT tuple = (f32[2,2]{1,0}, bf16[2,2]{1,0}) tuple(convert, dot) + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + module->mutable_config() + .mutable_debug_options() + .set_xla_allow_excess_precision(true); + + EXPECT_TRUE(Normalize(module.get(), BF16)); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Tuple(m::Dot(), m::Convert()))); +} + +TEST_F(FloatNormalizationTest, AllowExcessPrecisionFalse) { + const std::string hlo_text = R"( + HloModule dot_with_convert + + ENTRY main { + Arg_0 = bf16[2,2]{1,0} parameter(0) + dot = bf16[2,2]{1,0} dot(Arg_0, Arg_0), lhs_contracting_dims={1}, rhs_contracting_dims={0} + convert = f32[2,2]{1,0} convert(dot) + ROOT tuple = (f32[2,2]{1,0}, bf16[2,2]{1,0}) tuple(convert, dot) + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + module->mutable_config() + .mutable_debug_options() + .set_xla_allow_excess_precision(false); + + EXPECT_TRUE(Normalize(module.get(), BF16)); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Tuple(m::Convert(m::Convert(m::Dot())), m::Convert()))); +} + +} // namespace } // namespace xla diff --git a/third_party/xla/xla/service/fusion_constant_sinking.cc b/third_party/xla/xla/hlo/transforms/simplifiers/fusion_constant_sinking.cc similarity index 96% rename from third_party/xla/xla/service/fusion_constant_sinking.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/fusion_constant_sinking.cc index 9d9e210496d2e3..c15b4926c92c26 100644 --- a/third_party/xla/xla/service/fusion_constant_sinking.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/fusion_constant_sinking.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/fusion_constant_sinking.h" +#include "xla/hlo/transforms/simplifiers/fusion_constant_sinking.h" #include #include @@ -26,7 +26,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/shape_util.h" #include "xla/util.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/fusion_constant_sinking.h b/third_party/xla/xla/hlo/transforms/simplifiers/fusion_constant_sinking.h new file mode 100644 index 00000000000000..6e1c7d9813d239 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/fusion_constant_sinking.h @@ -0,0 +1,39 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_FUSION_CONSTANT_SINKING_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_FUSION_CONSTANT_SINKING_H_ + +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// A pass which sinks constants into fusion computations. +class FusionConstantSinking : public HloModulePass { + public: + absl::string_view name() const override { return "fusion_constant_sinking"; } + + // Run fusion constant sinking operations on the given module. Returns whether + // the module was changed (constant expressions folded). + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_FUSION_CONSTANT_SINKING_H_ diff --git a/third_party/xla/xla/service/fusion_constant_sinking_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/fusion_constant_sinking_test.cc similarity index 97% rename from third_party/xla/xla/service/fusion_constant_sinking_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/fusion_constant_sinking_test.cc index d822f03bd46b9f..4fca970c832950 100644 --- a/third_party/xla/xla/service/fusion_constant_sinking_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/fusion_constant_sinking_test.cc @@ -13,23 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/fusion_constant_sinking.h" +#include "xla/hlo/transforms/simplifiers/fusion_constant_sinking.h" #include #include #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/test.h" -#include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" namespace xla { namespace { -using FusionConstantSinkingTest = HloTestBase; +using FusionConstantSinkingTest = HloHardwareIndependentTestBase; TEST_F(FusionConstantSinkingTest, SinkConstant) { std::string hlo_string = R"( diff --git a/third_party/xla/xla/service/gather_simplifier.cc b/third_party/xla/xla/hlo/transforms/simplifiers/gather_simplifier.cc similarity index 98% rename from third_party/xla/xla/service/gather_simplifier.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/gather_simplifier.cc index 354d26b4026a68..64b7aa8efba0de 100644 --- a/third_party/xla/xla/service/gather_simplifier.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/gather_simplifier.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gather_simplifier.h" +#include "xla/hlo/transforms/simplifiers/gather_simplifier.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/gather_simplifier.h b/third_party/xla/xla/hlo/transforms/simplifiers/gather_simplifier.h new file mode 100644 index 00000000000000..da7b17c847158b --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/gather_simplifier.h @@ -0,0 +1,51 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_GATHER_SIMPLIFIER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_GATHER_SIMPLIFIER_H_ + +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" + +namespace xla { + +// This pass rewrites gather operations into a combination of transposes, +// reshapes and a simpler gather. +// +// The output gather's attributes will have the following characteristics: +// - start_indices is a two-dimensional tensor +// - index_vector_dim is 1 +// - start_index_map is [0, 1, ...] +// - collapsed_slice_dims is [] +// - offset_dims is [1, 2, ...] +// +// The purpose of this pass is to check whether this transformation has any +// performance implications. +class GatherSimplifier : public OpExpanderPass { + public: + absl::string_view name() const override { return "gather_simplifier"; } + + static bool IsSimplifiedGather(const HloGatherInstruction* gather); + + protected: + bool InstructionMatchesPattern(HloInstruction* inst) override; + + absl::StatusOr ExpandInstruction( + HloInstruction* inst) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_GATHER_SIMPLIFIER_H_ diff --git a/third_party/xla/xla/service/gather_simplifier_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/gather_simplifier_test.cc similarity index 96% rename from third_party/xla/xla/service/gather_simplifier_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/gather_simplifier_test.cc index 61b8bc716e120b..22f4e2e9205793 100644 --- a/third_party/xla/xla/service/gather_simplifier_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/gather_simplifier_test.cc @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gather_simplifier.h" +#include "xla/hlo/transforms/simplifiers/gather_simplifier.h" #include -#include "xla/tests/hlo_test_base.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" namespace xla { namespace { -class GatherSimplifierTest : public HloTestBase {}; +class GatherSimplifierTest : public HloHardwareIndependentTestBase {}; TEST_F(GatherSimplifierTest, TransformsStartIndices) { // Verifies that GatherSimplifier diff --git a/third_party/xla/xla/service/hlo_computation_deduplicator.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_computation_deduplicator.cc similarity index 98% rename from third_party/xla/xla/service/hlo_computation_deduplicator.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/hlo_computation_deduplicator.cc index 01fbfa2109f750..c3747f209750d2 100644 --- a/third_party/xla/xla/service/hlo_computation_deduplicator.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_computation_deduplicator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_computation_deduplicator.h" +#include "xla/hlo/transforms/simplifiers/hlo_computation_deduplicator.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_computation_deduplicator.h b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_computation_deduplicator.h new file mode 100644 index 00000000000000..7d04caede058dd --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_computation_deduplicator.h @@ -0,0 +1,51 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_HLO_COMPUTATION_DEDUPLICATOR_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_HLO_COMPUTATION_DEDUPLICATOR_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// Deduplicate computations inside a `HloModule`: If two computations are +// identical then keep the first one (in postorder terms) and remove the rest. +class HloComputationDeduplicator : public HloModulePass { + public: + // Setting mark_fusion_duplications to true will only process fusions in the + // HLO. The comparator in this pass will mark duplicate fusions which is + // needed for groupings in analysis (e.g. Xprof). Currently, the pass + // doesn't change the HLO if the flag is set to true. + explicit HloComputationDeduplicator(bool mark_fusion_duplications = false) + : mark_fusion_duplications_(mark_fusion_duplications) {} + absl::string_view name() const override { return "computation-deduplicator"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + bool ContainsLargeConstants(HloComputation* comp); + bool mark_fusion_duplications_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_HLO_COMPUTATION_DEDUPLICATOR_H_ diff --git a/third_party/xla/xla/service/hlo_computation_deduplicator_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_computation_deduplicator_test.cc similarity index 99% rename from third_party/xla/xla/service/hlo_computation_deduplicator_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/hlo_computation_deduplicator_test.cc index 115d047ab57ea7..85b8e1c9619589 100644 --- a/third_party/xla/xla/service/hlo_computation_deduplicator_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_computation_deduplicator_test.cc @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_computation_deduplicator.h" +#include "xla/hlo/transforms/simplifiers/hlo_computation_deduplicator.h" #include #include @@ -26,17 +26,17 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/literal_util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" -#include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" namespace xla { namespace { -class HloComputationDeduplicatorTest : public HloTestBase { +class HloComputationDeduplicatorTest : public HloHardwareIndependentTestBase { protected: std::vector RunDeduplicatePass(const std::string_view text, bool expect_true) { diff --git a/third_party/xla/xla/service/hlo_constant_folding.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_folding.cc similarity index 99% rename from third_party/xla/xla/service/hlo_constant_folding.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_folding.cc index cb2852fbb2e619..c8a78e3ba854fc 100644 --- a/third_party/xla/xla/service/hlo_constant_folding.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_folding.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_constant_folding.h" +#include "xla/hlo/transforms/simplifiers/hlo_constant_folding.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_folding.h b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_folding.h new file mode 100644 index 00000000000000..d3f7f704267a84 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_folding.h @@ -0,0 +1,51 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_HLO_CONSTANT_FOLDING_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_HLO_CONSTANT_FOLDING_H_ + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// A pass which performs constant folding in order to avoid unnecessary +// computation on constants. +class HloConstantFolding : public HloModulePass { + public: + absl::string_view name() const override { return "constant_folding"; } + + // Run constant folding operations on the given module. Returns whether the + // module was changed (constant expressions folded). + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + // Number of slow constant-folds we've encountered. Used for firing + // SlowOperationAlarms. + static std::atomic slow_op_counter_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_HLO_CONSTANT_FOLDING_H_ diff --git a/third_party/xla/xla/service/hlo_constant_folding_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_folding_test.cc similarity index 98% rename from third_party/xla/xla/service/hlo_constant_folding_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_folding_test.cc index 255012f734a5f1..5adaca9b3abdee 100644 --- a/third_party/xla/xla/service/hlo_constant_folding_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_folding_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_constant_folding.h" +#include "xla/hlo/transforms/simplifiers/hlo_constant_folding.h" #include #include @@ -24,19 +24,19 @@ limitations under the License. #include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/permutation_util.h" #include "xla/primitive_util.h" -#include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" -#include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -45,7 +45,7 @@ namespace { namespace op = xla::testing::opcode_matchers; namespace m = xla::match; -using HloConstantFoldingTest = HloTestBase; +using HloConstantFoldingTest = HloHardwareIndependentTestBase; TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { HloComputation::Builder builder(TestName()); diff --git a/third_party/xla/xla/hlo/transforms/hlo_constant_splitter.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_splitter.cc similarity index 99% rename from third_party/xla/xla/hlo/transforms/hlo_constant_splitter.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_splitter.cc index 1318acd55e2c25..d804e24cc985ed 100644 --- a/third_party/xla/xla/hlo/transforms/hlo_constant_splitter.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_splitter.cc @@ -12,7 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/hlo/transforms/hlo_constant_splitter.h" +#include "xla/hlo/transforms/simplifiers/hlo_constant_splitter.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/hlo_constant_splitter.h b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_splitter.h similarity index 91% rename from third_party/xla/xla/hlo/transforms/hlo_constant_splitter.h rename to third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_splitter.h index bea1fafe33b7fc..2f58909fb8a7d4 100644 --- a/third_party/xla/xla/hlo/transforms/hlo_constant_splitter.h +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_splitter.h @@ -12,8 +12,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_HLO_TRANSFORMS_HLO_CONSTANT_SPLITTER_H_ -#define XLA_HLO_TRANSFORMS_HLO_CONSTANT_SPLITTER_H_ +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_HLO_CONSTANT_SPLITTER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_HLO_CONSTANT_SPLITTER_H_ #include "absl/container/flat_hash_set.h" #include "absl/functional/function_ref.h" @@ -57,4 +57,4 @@ class HloConstantSplitter : public HloModulePass { } // namespace xla -#endif // XLA_HLO_TRANSFORMS_HLO_CONSTANT_SPLITTER_H_ +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_HLO_CONSTANT_SPLITTER_H_ diff --git a/third_party/xla/xla/hlo/transforms/hlo_constant_splitter_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_splitter_test.cc similarity index 93% rename from third_party/xla/xla/hlo/transforms/hlo_constant_splitter_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_splitter_test.cc index c7ebf8459502e8..6a9dc33350c5fd 100644 --- a/third_party/xla/xla/hlo/transforms/hlo_constant_splitter_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_constant_splitter_test.cc @@ -12,17 +12,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/hlo/transforms/hlo_constant_splitter.h" +#include "xla/hlo/transforms/simplifiers/hlo_constant_splitter.h" #include #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/hlo_dce.h" -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/test.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "tsl/platform/statusor.h" @@ -30,7 +30,7 @@ limitations under the License. namespace xla { namespace { -using HloConstantSplitterTest = HloTestBase; +using HloConstantSplitterTest = HloHardwareIndependentTestBase; TEST_F(HloConstantSplitterTest, SplitConstants) { const char* module_str = R"( @@ -123,7 +123,8 @@ TEST_F(HloConstantSplitterTest, PreservingConstantsWithZeroUsers) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(module_str)); HloConstantSplitter pass = HloConstantSplitter(); - const auto status_or = HloTestBase::RunHloPass(&pass, module.get()); + const auto status_or = + HloHardwareIndependentTestBase::RunHloPass(&pass, module.get()); TF_ASSERT_OK(status_or.status()); // Verify that the changed flag returned is correct. EXPECT_FALSE(status_or.value()); @@ -154,7 +155,8 @@ TEST_F(HloConstantSplitterTest, SplittingExpressionsWithBroadcast) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(module_str)); HloConstantSplitter pass = HloConstantSplitter(/*split_expressions=*/true); - const auto status_or = HloTestBase::RunHloPass(&pass, module.get()); + const auto status_or = + HloHardwareIndependentTestBase::RunHloPass(&pass, module.get()); TF_ASSERT_OK(status_or.status()); // Verify that the changed flag returned is correct. EXPECT_TRUE(status_or.value()); @@ -184,7 +186,8 @@ TEST_F(HloConstantSplitterTest, SplittingExpressionsWithSlice) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(module_str)); HloConstantSplitter pass = HloConstantSplitter(/*split_expressions=*/true); - const auto status_or = HloTestBase::RunHloPass(&pass, module.get()); + const auto status_or = + HloHardwareIndependentTestBase::RunHloPass(&pass, module.get()); TF_ASSERT_OK(status_or.status()); // Verify that the changed flag returned is correct. EXPECT_TRUE(status_or.value()); @@ -223,8 +226,9 @@ TEST_F(HloConstantSplitterTest, NoSplittingSideEffectExpressions) { HloConstantSplitter pass = HloConstantSplitter(/*split_expressions=*/true); const int64_t count_before = module->entry_computation()->instruction_count(); - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloTestBase::RunHloPass(&pass, module.get())); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + HloHardwareIndependentTestBase::RunHloPass(&pass, module.get())); HloDCE dce; TF_ASSERT_OK(dce.Run(module.get()).status()); const int64_t count_after_dce = @@ -280,8 +284,9 @@ TEST_F(HloConstantSplitterTest, InstructionsWithOneUser) { ParseAndReturnUnverifiedModule(module_str)); HloConstantSplitter pass = HloConstantSplitter(/*split_expressions=*/true); // Verify that the module is not changed as splitting on rng is prevented. - TF_ASSERT_OK_AND_ASSIGN(bool changed, - HloTestBase::RunHloPass(&pass, module.get())); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + HloHardwareIndependentTestBase::RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); int64_t broadcast_count_before_dce = 0, broadcast_count_after_dce = 0; diff --git a/third_party/xla/xla/service/hlo_dce.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce.cc similarity index 67% rename from third_party/xla/xla/service/hlo_dce.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce.cc index 5617190d36b854..89df0f3b9f57d9 100644 --- a/third_party/xla/xla/service/hlo_dce.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce.cc @@ -13,12 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include #include #include #include +#include #include #include @@ -67,104 +68,112 @@ bool IsRemovableWhile(HloInstruction* instruction, return true; } -} // namespace +// Returns true if it found and removed unused outputs. +absl::StatusOr RemoveMultiOutputFusionsUnusedOutputs( + HloComputation* computation) { + HloInstruction* fusion_instruction = computation->FusionInstruction(); + if (!fusion_instruction) { + return false; + } -/*static*/ absl::StatusOr HloDCE::RunOnComputation( - HloComputation* computation, bool remove_cross_partition_collective_ops) { - bool changed = false; - // Cleanup unused tuple elements in multi-output fusion roots. We do this - // first, because it may create dead roots which we can clean up next. - if (auto* fusion_instruction = computation->FusionInstruction(); - fusion_instruction != nullptr && - computation->root_instruction()->opcode() == HloOpcode::kTuple && - !computation->root_instruction()->has_sharding() && - fusion_instruction->output_operand_aliasing().empty() && - !fusion_instruction->HasControlDependencies() && - !fusion_instruction->IsCustomFusion()) { - // The order of the used outputs is relevant for the algorithm below. - std::set used_tuple_elements; - // We only support this cleanup if all users of the fusion instruction are - // GetTupleElement ops, and there is at least one user of - // 'fusion_instruction'. - bool supported = fusion_instruction->user_count() > 0; - for (HloInstruction* gte : fusion_instruction->users()) { - if (gte->opcode() != HloOpcode::kGetTupleElement) { - supported = false; - break; - } - used_tuple_elements.insert(gte->tuple_index()); - } + if (computation->root_instruction()->opcode() != HloOpcode::kTuple || + computation->root_instruction()->has_sharding() || + !fusion_instruction->output_operand_aliasing().empty() || + fusion_instruction->HasControlDependencies() || + fusion_instruction->IsCustomFusion()) { + return false; + } + + // The order of the used outputs is relevant for the algorithm below. + std::set used_tuple_elements; + + // We only support this cleanup if all users of the fusion instruction are + // GetTupleElement ops, and there is at least one user of + // 'fusion_instruction'. + if (fusion_instruction->users().empty()) { + return false; + } - // If all outputs are used, nothing to clean up. - if (used_tuple_elements.size() == - computation->root_instruction()->operand_count()) { - supported = false; + for (HloInstruction* gte : fusion_instruction->users()) { + if (gte->opcode() != HloOpcode::kGetTupleElement) { + return false; } + used_tuple_elements.insert(gte->tuple_index()); + } - if (supported) { - std::vector tuple_shapes; - tuple_shapes.reserve(used_tuple_elements.size()); - for (int64_t tuple_index : used_tuple_elements) { - tuple_shapes.push_back( - fusion_instruction->shape().tuple_shapes(tuple_index)); - } - Shape new_shape = tuple_shapes.size() == 1 - ? tuple_shapes[0] - : ShapeUtil::MakeTupleShape(tuple_shapes); - *fusion_instruction->mutable_shape() = std::move(new_shape); - - // Update the users of the old fusion instruction. - if (tuple_shapes.size() > 1) { - for (HloInstruction* gte : fusion_instruction->users()) { - auto it = - std::lower_bound(used_tuple_elements.begin(), - used_tuple_elements.end(), gte->tuple_index()); - int64_t new_tuple_index = - std::distance(used_tuple_elements.begin(), it); - gte->set_tuple_index(new_tuple_index); - } - } else { - // Since we iterate over users while removing them .. make a local copy - // first. - std::vector users(fusion_instruction->users()); - for (HloInstruction* gte : users) { - // Replace and change control successors to be dependent on the fusion - // instruction itself. - TF_ASSIGN_OR_RETURN(bool replaced, - gte->parent()->ReplaceInstruction( - gte, fusion_instruction, - /*preserve_sharding=*/true, - /*relay_control_dependency=*/true)); - if (replaced) { - changed |= replaced; - } - } - } + // If all outputs are used, nothing to clean up. + if (used_tuple_elements.size() == + computation->root_instruction()->operand_count()) { + return false; + } - // Update the root of the fusion computation. - if (tuple_shapes.size() > 1) { - std::vector new_operands; - new_operands.reserve(used_tuple_elements.size()); - for (int64_t tuple_index : used_tuple_elements) { - new_operands.push_back( - computation->root_instruction()->mutable_operand(tuple_index)); - } - auto new_tuple = computation->AddInstruction( - HloInstruction::CreateTuple(new_operands)); - TF_RETURN_IF_ERROR(computation->ReplaceInstructionWithDifferentShape( - computation->root_instruction(), new_tuple)); - } else { - TF_RETURN_IF_ERROR( - computation->root_instruction()->ReplaceAllUsesWithDifferentShape( - computation->root_instruction()->mutable_operand( - *used_tuple_elements.begin()))); - } + std::vector tuple_shapes; + tuple_shapes.reserve(used_tuple_elements.size()); + for (int64_t tuple_index : used_tuple_elements) { + tuple_shapes.push_back( + fusion_instruction->shape().tuple_shapes(tuple_index)); + } + Shape new_shape = tuple_shapes.size() == 1 + ? tuple_shapes[0] + : ShapeUtil::MakeTupleShape(tuple_shapes); + *fusion_instruction->mutable_shape() = std::move(new_shape); + + // Update the users of the old fusion instruction. + if (tuple_shapes.size() > 1) { + for (HloInstruction* gte : fusion_instruction->users()) { + auto it = used_tuple_elements.lower_bound(gte->tuple_index()); + int64_t new_tuple_index = std::distance(used_tuple_elements.begin(), it); + gte->set_tuple_index(new_tuple_index); + } + } else { + // Since we iterate over users while removing them .. make a local copy + // first. + std::vector users(fusion_instruction->users()); + for (HloInstruction* gte : users) { + // Replace and change control successors to be dependent on the fusion + // instruction itself. + TF_ASSIGN_OR_RETURN(std::ignore, gte->parent()->ReplaceInstruction( + gte, fusion_instruction, + /*preserve_sharding=*/true, + /*relay_control_dependency=*/true)); + } + } + + // Update the root of the fusion computation. + if (tuple_shapes.size() > 1) { + std::vector new_operands; + new_operands.reserve(used_tuple_elements.size()); + for (int64_t tuple_index : used_tuple_elements) { + new_operands.push_back( + computation->root_instruction()->mutable_operand(tuple_index)); } + auto new_tuple = + computation->AddInstruction(HloInstruction::CreateTuple(new_operands)); + TF_RETURN_IF_ERROR(computation->ReplaceInstructionWithDifferentShape( + computation->root_instruction(), new_tuple)); + } else { + TF_RETURN_IF_ERROR( + computation->root_instruction()->ReplaceAllUsesWithDifferentShape( + computation->root_instruction()->mutable_operand( + *used_tuple_elements.begin()))); } - // Remove any dead roots and their dead transitive operands. Collect them - // into a separate list first to avoid problems with iterating through the - // computation's instruction while simultaneously removing instructions. + // We always updated the fusion if we got here. + return true; +} + +} // namespace + +/*static*/ absl::StatusOr HloDCE::RunOnComputation( + HloComputation* computation, bool remove_cross_partition_collective_ops) { + // We do this first, because it may create dead roots which we can clean up + // next. + TF_ASSIGN_OR_RETURN(bool changed, + RemoveMultiOutputFusionsUnusedOutputs(computation)); + + // Remove any dead roots and their dead transitive operands. Collect + // them into a separate list first to avoid problems with iterating through + // the computation's instruction while simultaneously removing instructions. std::vector dead_roots; for (auto* instruction : computation->instructions()) { auto maybe_collective_op = DynCast(instruction); diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce.h b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce.h new file mode 100644 index 00000000000000..7fe1ceb1909662 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce.h @@ -0,0 +1,78 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_HLO_DCE_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_HLO_DCE_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// HLO pass which removes dead instructions from each computation in the module +// and removes dead computations from the module. +// +// An instruction is dead if it is not reachable from the root. A computation is +// dead if it is not the entry computation of the module and it is not reachable +// from the entry computation. +// +// This pass does not remove dead parameter instructions, as parameter +// instructions cannot be deleted. +class HloDCE : public HloModulePass { + public: + HloDCE() : remove_cross_partition_collective_ops_(false) {} + explicit HloDCE(bool remove_cross_partition_collective_ops) + : remove_cross_partition_collective_ops_( + remove_cross_partition_collective_ops) {} + ~HloDCE() override {} + absl::string_view name() const override { return "dce"; } + + // Run DCE on a computation. + static absl::StatusOr RunOnComputation( + HloComputation* computation, bool remove_cross_partition_collective_ops); + + // Run the pass on the given module. Returns whether the module was changed + // (instructions were removed). + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + // Finds all computations that are not called by any instruction and removes + // them from the module. Returns whether any dead code was removed. + absl::StatusOr RecursivelyRemoveDeadComputations(HloModule* module); + + // Given a dead computation, decrements the ref count of all its called + // computations and checks if any of the subcomputations become dead after the + // removal. Returns whether all dead computations were successfully removed + // from the module. + absl::Status RecursivelyRemoveDeadComputation( + HloModule* module, HloComputation* computation, + absl::flat_hash_map& live_call_counts); + + bool remove_cross_partition_collective_ops_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_HLO_DCE_H_ diff --git a/third_party/xla/xla/service/hlo_dce_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce_test.cc similarity index 99% rename from third_party/xla/xla/service/hlo_dce_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce_test.cc index 765af9fecbeb8d..22a385a32f9b97 100644 --- a/third_party/xla/xla/service/hlo_dce_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_dce_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include #include @@ -28,14 +28,13 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/layout_util.h" #include "xla/literal_util.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape_util.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" -#include "xla/tests/test_utils.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/xla_data.pb.h" @@ -45,7 +44,7 @@ namespace { namespace m = ::xla::match; -class HloDceTest : public HloTestBase { +class HloDceTest : public HloHardwareIndependentTestBase { protected: HloDceTest() {} diff --git a/third_party/xla/xla/service/hlo_element_type_converter.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_element_type_converter.cc similarity index 99% rename from third_party/xla/xla/service/hlo_element_type_converter.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/hlo_element_type_converter.cc index 05721ae9b2a2ce..7639055d7b4fe0 100644 --- a/third_party/xla/xla/service/hlo_element_type_converter.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_element_type_converter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_element_type_converter.h" +#include "xla/hlo/transforms/simplifiers/hlo_element_type_converter.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_element_type_converter.h b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_element_type_converter.h new file mode 100644 index 00000000000000..382c09b685e245 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_element_type_converter.h @@ -0,0 +1,50 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_HLO_ELEMENT_TYPE_CONVERTER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_HLO_ELEMENT_TYPE_CONVERTER_H_ + +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// A pass that eliminates certain element types as the input or output of ops by +// inserting Convert ops. This allows a backend to support an element type while +// only actually implementing the Convert op for that element type. This is +// generally not the fastest approach, but it works. +class HloElementTypeConverter : public HloModulePass { + public: + // eliminate_type is the type to eliminate as the input or output of ops, + // using Convert ops to replace it with replace_with_type. + HloElementTypeConverter(PrimitiveType eliminate_type, + PrimitiveType replace_with_type); + + absl::string_view name() const override { return "element_type_converter"; } + + // Returns the pass on the module and returns whether the module was modified. + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + PrimitiveType eliminate_type_; + PrimitiveType replace_with_type_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_HLO_ELEMENT_TYPE_CONVERTER_H_ diff --git a/third_party/xla/xla/service/hlo_element_type_converter_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_element_type_converter_test.cc similarity index 97% rename from third_party/xla/xla/service/hlo_element_type_converter_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/hlo_element_type_converter_test.cc index d7bfc8d0a09612..4adf215fdc5791 100644 --- a/third_party/xla/xla/service/hlo_element_type_converter_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_element_type_converter_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_element_type_converter.h" +#include "xla/hlo/transforms/simplifiers/hlo_element_type_converter.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/tests/hlo_test_base.h" namespace xla { namespace { @@ -29,7 +29,7 @@ using ::testing::Eq; using ::testing::Not; using ::testing::ResultOf; -using HloElementTypeConverterTest = HloTestBase; +using HloElementTypeConverterTest = HloHardwareIndependentTestBase; TEST_F(HloElementTypeConverterTest, CustomCallsNotConverted) { const std::string& hlo_string = R"( diff --git a/third_party/xla/xla/service/hlo_memory_scheduler.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_memory_scheduler.cc similarity index 99% rename from third_party/xla/xla/service/hlo_memory_scheduler.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/hlo_memory_scheduler.cc index 83e40723895289..3035c0b408390b 100644 --- a/third_party/xla/xla/service/hlo_memory_scheduler.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_memory_scheduler.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_memory_scheduler.h" +#include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" #include #include @@ -31,6 +31,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/analysis/tuple_points_to_analysis.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -38,9 +40,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_schedule.h" #include "xla/service/buffer_value.h" #include "xla/service/heap_simulator/heap_simulator.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/logical_buffer.h" -#include "xla/service/tuple_points_to_analysis.h" #include "xla/shape_util.h" #include "xla/util.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h new file mode 100644 index 00000000000000..231030d2ad26d9 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h @@ -0,0 +1,191 @@ +/* Copyright 2016 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_HLO_MEMORY_SCHEDULER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_HLO_MEMORY_SCHEDULER_H_ + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/analysis/tuple_points_to_analysis.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/service/logical_buffer.h" + +namespace xla { + +// Postprocessor of the HloInstructionSequence. This is an opt-in postprocessing +// function to MemorySchedulerAlgorithm to enforce certain hlo schedule +// constraints desired for custom-calls. +using MemorySchedulerPostprocessor = + std::function; + +// A memory scheduler computes an execution sequence for the HLO instructions in +// 'computation' that minimizes peak memory (or finds a balance between memory +// and available concurrency), given a points-to analysis result that describes +// buffer aliasing, together with a target-specific size function that maps a +// tensor's logical size to its padded size. peak_memory (may be nullptr) is set +// to the peak memory of the resulting schedule according to the HeapSimulator. +// +// TODO(yunxing): Cleanup usage of TuplePointsToAnalysis. +using MemorySchedulerAlgorithm = + std::function( + HloComputation*, const TuplePointsToAnalysis&, const HloAliasAnalysis&, + const LogicalBuffer::SizeFunction&, + const MemorySchedulerPostprocessor&, + /*peak_memory*/ int64_t*)>; + +// Scheduler for the entire module. +using ModuleSchedulerAlgorithm = std::function( + const HloModule*, const TuplePointsToAnalysis&, const HloAliasAnalysis&, + const LogicalBuffer::SizeFunction&, + const absl::flat_hash_set& execution_threads, + /*peak_memory*/ int64_t*)>; + +// Lift a computation scheduler into a module scheduler by calling the +// computation scheduler on all computations in a module. +ModuleSchedulerAlgorithm ComputationSchedulerToModuleScheduler( + const MemorySchedulerAlgorithm&, const MemorySchedulerPostprocessor& = {}); + +// List scheduler +absl::StatusOr ListMemoryScheduler( + HloComputation* computation, + const TuplePointsToAnalysis& points_to_analysis, + const HloAliasAnalysis& alias_analysis, + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); + +// DFS-order scheduler +absl::StatusOr DFSMemoryScheduler( + HloComputation* computation, + const TuplePointsToAnalysis& points_to_analysis, + const HloAliasAnalysis& alias_analysis, + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); + +// BFS-order scheduler +// +// BFS-order scheduler is a simple memory scheduler that schedules instructions +// in a breadth-first order, which maximizes the available concurrency at the +// cost of increased memory usage (HLO operations that do not have buffer +// conflicts can be executed in parallel). +// +// This is the most trivial scheduling optimized for maximum concurrency. In +// practice it is only useful for CPU backend where memory is cheap and we have +// a lot of available compute cores, and cheap concurrency primitives. +absl::StatusOr BFSMemoryScheduler( + HloComputation* computation, + const TuplePointsToAnalysis& points_to_analysis, + const HloAliasAnalysis& alias_analysis, + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); + +// Naive Post Order scheduler +absl::StatusOr PostOrderMemoryScheduler( + HloComputation* computation, + const TuplePointsToAnalysis& points_to_analysis, + const HloAliasAnalysis& alias_analysis, + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); + +// The default scheduling algorithm. Runs the list scheduler, the DFS scheduler, +// and the post-order scheduler and chooses whichever returns a lower min- +// memory, not accounting for fragmentation. peak_memory (may be nullptr) is set +// to the peak memory of the resulting schedule according to the HeapSimulator. +absl::StatusOr DefaultMemoryScheduler( + HloComputation* computation, + const TuplePointsToAnalysis& points_to_analysis, + const HloAliasAnalysis& alias_analysis, + const LogicalBuffer::SizeFunction& size_function, + const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); + +absl::StatusOr DefaultModuleScheduler( + const HloModule* module, const TuplePointsToAnalysis& points_to_analysis, + const HloAliasAnalysis& alias_analysis, + const LogicalBuffer::SizeFunction& size_function, + const absl::flat_hash_set& execution_threads, + int64_t* peak_memory); + +// Returns an HloSchedule which seeks to minimize the memory required for the +// module. size_function is the function returning the number of bytes required +// for a LogicalBuffer. peak_memory (if not nullptr) is set to the largest peak +// memory (according to the HeapSimulator) of all computations in the module. +absl::StatusOr ScheduleModule( + const HloModule* module, const LogicalBuffer::SizeFunction& size_function, + const ModuleSchedulerAlgorithm& algorithm = {}, + const absl::flat_hash_set& execution_threads = {}, + int64_t* peak_memory = nullptr); + +// A pass which schedules the HLO instructions in a module. The HloModule's +// schedule field is set to the resulting HloSchedule using +// HloModule::set_schedule. +class HloMemoryScheduler : public HloModulePass { + public: + // size_function is the function returning the number of bytes required for a + // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not + // specified, then DefaultMemoryScheduler is used. + explicit HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function, + const ModuleSchedulerAlgorithm& algorithm = {}); + + ~HloMemoryScheduler() override = default; + + absl::string_view name() const override { return "hlo-memory-scheduler"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + LogicalBuffer::SizeFunction size_function_; + + ModuleSchedulerAlgorithm algorithm_; +}; + +// A pass which produces a naive, but correct schedule. The schedule is produced +// using a DFS traversal of the graph with no attempt to minimize memory use. +class HloTrivialScheduler : public HloModulePass { + public: + absl::string_view name() const override { return "hlo-trivial-scheduler"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +// A trivial pass which clears the schedule currently set on the +// HloModule. After this pass runs HloModule::has_schedule will return false. +class HloDescheduler : public HloModulePass { + public: + HloDescheduler() = default; + ~HloDescheduler() override = default; + absl::string_view name() const override { return "hlo-descheduler"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_HLO_MEMORY_SCHEDULER_H_ diff --git a/third_party/xla/xla/service/hlo_memory_scheduler_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_memory_scheduler_test.cc similarity index 98% rename from third_party/xla/xla/service/hlo_memory_scheduler_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/hlo_memory_scheduler_test.cc index 62a13d14097887..74fcfb4d08106a 100644 --- a/third_party/xla/xla/service/hlo_memory_scheduler_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_memory_scheduler_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_memory_scheduler.h" +#include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" #include #include @@ -26,21 +26,21 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/types/span.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/analysis/hlo_ordering.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/literal_util.h" #include "xla/service/buffer_value.h" #include "xla/service/heap_simulator/heap_simulator.h" -#include "xla/service/hlo_alias_analysis.h" -#include "xla/service/hlo_ordering.h" #include "xla/service/hlo_value.h" #include "xla/service/logical_buffer.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" @@ -50,7 +50,7 @@ limitations under the License. namespace xla { namespace { -class HloSchedulingTest : public HloTestBase {}; +class HloSchedulingTest : public HloHardwareIndependentTestBase {}; int64_t PeakMemoryUseOfEntryComputation( HloModule* module, LogicalBuffer::SizeFunction size_function) { diff --git a/third_party/xla/xla/service/hlo_rematerialization.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc similarity index 98% rename from third_party/xla/xla/service/hlo_rematerialization.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc index fed3b52c29ecbb..7ad655f9a888a1 100644 --- a/third_party/xla/xla/service/hlo_rematerialization.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc @@ -13,9 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_rematerialization.h" +#include "xla/hlo/transforms/simplifiers/hlo_rematerialization.h" #include +#include #include #include #include @@ -39,6 +40,10 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" +#include "xla/hlo/analysis/tuple_points_to_analysis.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_computation.h" @@ -47,17 +52,20 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/hlo/utils/hlo_query.h" +#include "xla/layout_util.h" #include "xla/map_util.h" +#include "xla/service/call_graph.h" #include "xla/service/hlo_cost_analysis.h" -#include "xla/service/hlo_dataflow_analysis.h" -#include "xla/service/hlo_dce.h" #include "xla/service/logical_buffer.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/util.h" #include "tsl/platform/errors.h" +#include "tsl/platform/numbers.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -66,8 +74,7 @@ namespace { using ::tsl::strings::HumanReadableNumBytes; // Potential optimizations: -// . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue -// of candidates. +// . Avoid N^2 behavior by keeping a priority queue of candidates. // . Cache IsRematerializable in Item? Only correct if control // predecessors and successors don't change. @@ -660,7 +667,7 @@ class MemoryUsageTracker { const HloRematerialization::Options& options() const { return options_; } - // Check invariants of the data structure. This is expensive to call. + // Checks invariants of the data structure. This is expensive to call. bool Check() const; std::string ToString() const; @@ -710,21 +717,21 @@ class MemoryUsageTracker { } }; - // Adjust our tracked memory usage as a result of this new item coming into + // Adjusts our tracked memory usage as a result of this new item coming into // scope. void CountAllocatedMemory(Item* item); - // Adjust our tracked memory usage as a result of this item going out of + // Adjusts our tracked memory usage as a result of this item going out of // scope. absl::Status CountFreedMemory(Item* item); - // Buffers have users and users have buffers used, this function resolves + // Buffers have users and users have buffers used. This function resolves // outstanding issues in that bidirectional dependency. void ReplaceUsesInUsersOfBuffer(Buffer& buffer, BufferId old_id) const; - // Get the compact shape of given hlo instruction. An internal cache is used + // Gets the compact shape of given hlo instruction. An internal cache is used // to avoid computing the shape multiple times. - absl::StatusOr GetCompactShape(const HloInstruction* hlo); + absl::StatusOr GetCompactShape(const HloInstruction* hlo); // Creates a Buffer representing the given logical buffer. The buffer is added // to buffers_ and a reference is returned. @@ -739,7 +746,7 @@ class MemoryUsageTracker { std::move(users), live_out, has_indirect_uses); } - // Create a new buffer representing a rematerialization of given buffer for + // Creates a new buffer representing a rematerialization of given buffer for // the given uses. Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item, UsesList&& rematerialized_uses) { @@ -755,10 +762,10 @@ class MemoryUsageTracker { /*has_indirect_uses=*/false); } - // Return number of bytes allocated for the buffer with the given id. Buffers - // allocated by the calling computation (eg, parameter and output buffers) are - // considered to have zero bytes because the memory is accounted for in a - // different computation. + // Returns the number of bytes allocated for the buffer with the given id. + // Buffers allocated by the calling computation (eg, parameter and output + // buffers) are considered to have zero bytes because the memory is accounted + // for in a different computation. int64_t AllocatedSize(BufferId buffer_id) const { const Buffer& buffer = buffers_.at(buffer_id); HloInstruction* inst = buffer.defining_instruction->instruction; @@ -776,8 +783,8 @@ class MemoryUsageTracker { } } - // Returns true if BeginInstruction and EndInstruction has been called for the - // given instruction. + // Returns whether BeginInstruction and EndInstruction have been called for + // the given instruction. bool IsFinished(Item* item) const { return item->placed && item != in_progress_item_; } @@ -815,7 +822,7 @@ class MemoryUsageTracker { return false; } - // Create a new buffer, add it to buffers_, and return a reference. + // Creates a new buffer, adds it to buffers_, and returns a reference. Buffer& NewBuffer(Item* defining_instruction, const Shape& shape, const ShapeIndex& index, UsesList&& uses, bool live_out, bool has_indirect_uses) { @@ -1506,17 +1513,16 @@ std::string MemoryUsageTracker::ToString() const { return output; } -absl::StatusOr MemoryUsageTracker::GetCompactShape( +absl::StatusOr MemoryUsageTracker::GetCompactShape( const HloInstruction* hlo) { auto it = compact_shape_.find(hlo); if (it != compact_shape_.end()) { - return it->second; + return &it->second; } const Shape& original_shape = hlo->shape(); TF_ASSIGN_OR_RETURN(Shape min_shape, options_.compact_shape_function(original_shape)); - compact_shape_[hlo] = min_shape; - return min_shape; + return &compact_shape_.emplace(hlo, min_shape).first->second; } bool MemoryUsageTracker::Check() const { @@ -1660,9 +1666,10 @@ std::optional MemoryUsageTracker::GetCostOfCompression( return {}; } - Shape compact_shape = GetCompactShape(candidate_item->instruction).value(); + const Shape* compact_shape = + GetCompactShape(candidate_item->instruction).value(); const int64_t memory_reduced = - MemoryReducedIfCompressed(candidate_item, compact_shape); + MemoryReducedIfCompressed(candidate_item, *compact_shape); // Since the compressed and uncompressed buffers need to be alive // while performing the compression/uncompression, only perform // the compression if the sum of the two sizes is less than the @@ -1670,7 +1677,7 @@ std::optional MemoryUsageTracker::GetCostOfCompression( const int64_t size = options_.hlo_cost_analysis.GetShapeSize( candidate_item->instruction->shape()); const int64_t reduced_size = - options_.hlo_cost_analysis.GetShapeSize(compact_shape); + options_.hlo_cost_analysis.GetShapeSize(*compact_shape); // TODO(victorstone): I don't think this size check is right. if (memory_reduced > 0 && size + reduced_size < peak_memory_bytes) { return memory_limit_bytes / memory_reduced; @@ -1893,7 +1900,7 @@ MemoryUsageTracker::PickRematerializationCandidates( continue; } - // First, calculate the cost of compression rematerialziation for this + // First, calculate the cost of compression rematerialization for this // instruction. if (options_.remat_mode_config.compress && block.size() == 1) { auto cost = @@ -1907,7 +1914,7 @@ MemoryUsageTracker::PickRematerializationCandidates( // computed inside GetCostOfCompression, should we get it from there? Or // is it ok to recompute? best_strategy.compact_shape = - GetCompactShape(block[0]->instruction).value(); + *GetCompactShape(block[0]->instruction).value(); best_items = block; best_cost = *cost; } @@ -1989,6 +1996,8 @@ UsesList MemoryUsageTracker::GetItemUses(Item* item) const { return combined_users; } +// Performs the rematerialization of all items in `best_items` and returns the +// number of net instructions added. absl::StatusOr RematerializeInstructions( MemoryUsageTracker* memory_tracker, std::vector* best_items, absl::flat_hash_set* remat_move_instructions, @@ -2174,6 +2183,8 @@ absl::StatusOr RematerializeInstructions( return net_instructions_added; } +// Performs rematerialization of `best_item` via the compression strategy. +// Returns the net number of instructions added. absl::StatusOr CompressInstruction(MemoryUsageTracker* memory_tracker, Item* best_item, const Shape& compact_shape, @@ -2224,9 +2235,12 @@ absl::StatusOr CompressInstruction(MemoryUsageTracker* memory_tracker, instruction_list->InsertBeforeInstructions(uncompressed_item, place_before); instruction_list->InsertAfterInstructions(compressed_item, {best_item}); + // Net two instructions added. return 2; } +// Performs rematerialization of `best_item` via the host offload strategy. +// Returns the net number of instructions added. absl::StatusOr OffloadInstruction(MemoryUsageTracker* memory_tracker, Item* best_item, InstructionList* instruction_list) { @@ -2486,6 +2500,7 @@ absl::StatusOr OffloadInstruction(MemoryUsageTracker* memory_tracker, best_item, copy_start_to_host_item, copy_done_to_host_item, copy_start_to_device_item, copy_done_to_device_item)); + // Net four instructions added. return 4; } diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.h b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.h new file mode 100644 index 00000000000000..69b5f27055341f --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.h @@ -0,0 +1,250 @@ +/* Copyright 2017 The OpenXLA Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_HLO_REMATERIALIZATION_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_HLO_REMATERIALIZATION_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/analysis/tuple_points_to_analysis.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/service/call_graph.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/shape.h" + +namespace xla { + +// HLO pass which rematerializes instructions to reduce peak memory use, where +// memory use is defined as the total size of all live HLO instruction +// values. Parameters and constants are included in memory use estimates. +// +// CSE will undo the effects of this optimization and should not be run after +// this pass. In general, this pass should be run very late, immediately before +// code generation. +class HloRematerialization : public HloModulePass { + public: + using ShapeSizeFunction = std::function; + + using CompactShapeFunction = + std::function(const Shape&)>; + + // Helper struct that communicates the before / after sizes for the + // rematerialization process. + struct RematerializationSizes { + int64_t before_bytes = -1; + int64_t after_bytes = -1; + }; + + // Mode in which the rematerialization algorithm should be run. + struct RematerializationModeConfig { + RematerializationModeConfig(bool recompute, bool compress, + bool host_offload) + : recompute(recompute), + compress(compress), + host_offload(host_offload) {} + bool recompute; // Enables the kRecompute RematStrategy. + bool compress; // Enables the kCompress RematStrategy. + bool host_offload; // Enables the kHostOffload RematStrategy. + }; + + // This is a struct containing configuration options that are specific to the + // Host Memory Offload strategy. + struct HostMemoryOffloadConfig { + explicit HostMemoryOffloadConfig(int64_t host_memory_space, + float bandwidth_to_host_bytes_per_second, + float bandwidth_from_host_bytes_per_second) + : host_memory_space(host_memory_space), + bandwidth_to_host_bytes_per_second( + bandwidth_to_host_bytes_per_second), + bandwidth_from_host_bytes_per_second( + bandwidth_from_host_bytes_per_second) {} + + // The host memory space, which is used during the host offload strategy. + int64_t host_memory_space; + + float bandwidth_to_host_bytes_per_second; + + float bandwidth_from_host_bytes_per_second; + }; + + static Shape DefaultCompactShapeFunction(const Shape& shape) { return shape; } + + struct Options { + explicit Options(HloCostAnalysis& hlo_cost_analysis, + const RematerializationModeConfig& remat_mode_config, + int64_t memory_limit_bytes, int block_size_limit, + int block_rematerialization_factor, int64_t min_remat_size, + CompactShapeFunction compact_shape_function, + std::optional + host_memory_offload_config = std::nullopt, + absl::flat_hash_map + async_computation_parallelism = {}) + : hlo_cost_analysis(hlo_cost_analysis), + remat_mode_config(remat_mode_config), + memory_limit_bytes(memory_limit_bytes), + block_size_limit(block_size_limit), + block_rematerialization_factor(block_rematerialization_factor), + min_remat_size(min_remat_size), + compact_shape_function(compact_shape_function == nullptr + ? DefaultCompactShapeFunction + : std::move(compact_shape_function)), + host_memory_offload_config(host_memory_offload_config), + async_computation_parallelism(async_computation_parallelism) {} + + // The cost model used for decisions during rematerialization for host + // memory offload. It is also used for getting Shape size. + HloCostAnalysis& hlo_cost_analysis; + + // Holds the rematerialization strategy configuration to be used by the + // pass. + RematerializationModeConfig remat_mode_config; + + // Function which computes the size of the top-level buffer of a shape. + const ShapeSizeFunction size_function; + + // The threshold number of bytes to reduce memory use to via + // rematerialization. Size of aliased outputs should be subtracted + // from this. + int64_t memory_limit_bytes; + + // Maximum number of consecutive instructions to consider for + // rematerialization. + int block_size_limit; + + // Controls the amount of effort spent trying to find large blocks for + // rematerialization. Larger values leads to longer compilation times in + // return for potentially reduced memory consumption. + int block_rematerialization_factor; + + // The minimum size, in bytes, of a tensor to be considered for + // rematerialization. All tensors smaller than this size will be skipped + // over. + int64_t min_remat_size; + + // Converts a shape into compact form, returns the same shape if a shape is + // already considered compact. + CompactShapeFunction compact_shape_function; + + std::optional host_memory_offload_config; + + // Collection of async entry computations and their number of parallel + // invocations. + absl::flat_hash_map async_computation_parallelism; + }; + + explicit HloRematerialization(Options options, RematerializationSizes& sizes) + : options_(std::move(options)), sizes_(sizes) {} + + ~HloRematerialization() override = default; + + absl::string_view name() const override { return "rematerialization"; } + + // Get the next available channel id and increment count. + int64_t NextChannelId() { return next_channel_id_++; } + + // Get the peak memory for the computation. + int64_t ComputationPeakMemory(const HloComputation* computation) const { + return computation_peak_memory_.at(computation); + } + + // Runs rematerialization on the given module. Returns whether the module was + // changed. Requires that the module has a schedule set + // (HloModule::has_schedule() is true) before running. Returns whether any + // instructions were rematerialized. If memory use is already below the limit + // specified in the constructor then no instructions are rematerialized and + // false is returned. + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + protected: + // Rematerializes instructions within the given computation. 'schedule' + // constrains the order in which the computation's instructions will be + // emitted in the backend. Rematerialized instructions will be added to the + // HLO computation and inserted into 'schedule'. + virtual absl::StatusOr RematerializeComputation( + HloComputation* computation, HloSchedule* schedule, + int64_t memory_limit_bytes, int64_t min_remat_size, + const absl::flat_hash_set& execution_threads); + + // Computes and returns the peak memory used by the given computation. The + // peak memory is the maximum total size of all live HLO instruction values at + // any program point. 'order' is the order in which the HLO instructions will + // be emitted which is used to determine lifespans of HLO values. + absl::StatusOr ComputePeakMemory( + const HloComputation* computation, const HloInstructionSequence& order, + const absl::flat_hash_set& execution_threads) const; + + // Returns the peak memory usage of the called computations for the given + // instruction. Zero is returned if the instruction calls no computations. + absl::StatusOr CalledComputationsMemoryUsage( + const HloInstruction* instruction, + const absl::flat_hash_set& execution_threads) const; + + const Options options_; + + // Reference to data structure which records the peak memory usage of the HLO + // module before/after rematerialization. + RematerializationSizes& sizes_; + + // Call graph of the hlo_module. + std::unique_ptr call_graph_; + + // The peak memory usage of each computation. The map contains only those + // computations called from sequential context (CallContext::kSequential). + // These values are updated as rematerialization occurs. + absl::flat_hash_map computation_peak_memory_; + + std::unique_ptr points_to_analysis_; + + // Set of computations which have had rematerialization + // applied. Rematerialization is only applied once per computation. + absl::flat_hash_set rematerialized_computations_; + + // Count of the total instructions rematerialized. + int64_t instructions_rematerialized_ = 0; + + // Count of the net instructions added to the HLO module by + // rematerialization. This can be different than instructions_rematerialized_ + // because some rematerializations are effectively moves in the HLO + // schedule. In these cases, the rematerialization instruction replaces all + // uses of the original instruction and the original instruction is + // dead. Hence, no net instructions were added. + int64_t net_instructions_added_ = 0; + + // Size of the largest block that has been rematerialized. This is actually an + // upper bound (within a factor of 2) on the block size. + int max_rematerialized_block_size_ = 0; + + // Tracking available channel id numbers to use to apply to rematerialized + // channel instructions + int64_t next_channel_id_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_HLO_REMATERIALIZATION_H_ diff --git a/third_party/xla/xla/service/hlo_rematerialization_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization_test.cc similarity index 99% rename from third_party/xla/xla/service/hlo_rematerialization_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization_test.cc index c3a945345b3101..65beab9a155bb2 100644 --- a/third_party/xla/xla/service/hlo_rematerialization_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_rematerialization.h" +#include "xla/hlo/transforms/simplifiers/hlo_rematerialization.h" #include #include @@ -28,11 +28,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" +#include "xla/hlo/transforms/simplifiers/hlo_rematerialization_test_utils.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/layout.h" #include "xla/service/hlo_cost_analysis.h" -#include "xla/service/hlo_memory_scheduler.h" -#include "xla/service/hlo_rematerialization_test_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/lib/core/status_test_util.h" diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization_test_utils.h b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization_test_utils.h new file mode 100644 index 00000000000000..a0d1e67848cec8 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization_test_utils.h @@ -0,0 +1,150 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Class to create computations for testing rematerialization methods. + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_HLO_REMATERIALIZATION_TEST_UTILS_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_HLO_REMATERIALIZATION_TEST_UTILS_H_ + +#include +#include +#include + +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/literal_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +class RematerializationTestBase : public HloHardwareIndependentTestBase { + protected: + // Creates and returns a computation which can benefit from + // rematerialization. The computation looks like: + // + // F32[1] %param = {...} + // F32[] %reshape = reshape(F32[], param) + // F32[1024] %bcast = broadcast(%param) + // F32[1024] %negate = negate(%bcast) + // F32[2048] %concat_1 = concat({%negate, %negate}) + // F32[1] %slice_1 = slice(%concat_1, {0:1}) + // F32[1025] %concat_2 = concat({%bcast, %slice_1}) + // F32[1] %slice_2 = slice(%concat_2, {0:1}); + // + // The instruction %bcast can be rematerialized before its use at %concat_2 + // to reduce peak memory usage. This avoids %bcast and %concat_1 being + // simultaneously live. Peak memory use is about 16KB before rematerialization + // (during execution of %concat_1) and about 12KB after rematerializing %bcast + // for its use in %concat_2. + std::unique_ptr MakeRematerializableComputation( + const std::string& suffix = "") { + auto builder = HloComputation::Builder(TestName() + suffix); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vec1_shape_, "param")); + auto reshape = builder.AddInstruction( + HloInstruction::CreateReshape(scalar_shape_, param)); + auto bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {})); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, bcast)); + auto concat_1 = builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(xla::F32, {2048}), {negate, negate}, + /*dimension=*/0)); + auto slice_1 = builder.AddInstruction(HloInstruction::CreateSlice( + vec1_shape_, concat_1, /*start_indices=*/{0}, + /*limit_indices=*/{1}, + /*strides=*/{1})); + auto concat_2 = builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, slice_1}, + /*dimension=*/0)); + // Add a final slice to make the parameter shape match the output shape + // which is necessary to use this computation in a while. + builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat_2, + /*start_indices=*/{0}, + /*limit_indices=*/{1}, + /*strides=*/{1})); + return builder.Build(); + } + + // Creates and returns a computation which includes a while and can benefit + // from rematerialization. The computation looks like: + // + // F32[] %param = {...} + // F32[1024] %bcast = broadcast(%param) + // F32[1] %slice_1 = slice(%bcast, {0:1}) + // F32[1] %while = while(%slice_1, while_body, while_cond) + // F32[1025] %concat = concat({%bcast, %while}) + // F32[1] %slice_2 = slice(%concat, {0:1}); + // + // The instruction %bcast can be rematerialized before its use at %concat to + // reduce peak memory usage. This avoids %bcast being live during execution of + // the while. Peak memory use is maximum of 8K and 4K plus the memory use of + // the while subcomputations. + std::unique_ptr MakeRematerializableWhileComputation( + HloComputation* while_cond, HloComputation* while_body, + const std::string& suffix = "") { + auto builder = HloComputation::Builder(TestName() + suffix); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vec1_shape_, "param")); + auto reshape = builder.AddInstruction( + HloInstruction::CreateReshape(scalar_shape_, param)); + auto bcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {})); + auto slice_1 = builder.AddInstruction( + HloInstruction::CreateSlice(vec1_shape_, bcast, /*start_indices=*/{0}, + /*limit_indices=*/{1}, + /*strides=*/{1})); + auto while_inst = builder.AddInstruction(HloInstruction::CreateWhile( + vec1_shape_, while_cond, while_body, slice_1)); + auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, while_inst}, + /*dimension=*/0)); + builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat, + /*start_indices=*/{0}, + /*limit_indices=*/{1}, + /*strides=*/{1})); + return builder.Build(); + } + + // Create and return a trivial computation appropriate for use as a while + // condition. + std::unique_ptr MakeConditionComputation() { + auto builder = HloComputation::Builder(TestName() + ".cond"); + builder.AddInstruction( + HloInstruction::CreateParameter(0, vec1_shape_, "param")); + builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + return builder.Build(); + } + + // Return the byte size of the top-level buffer of the given shape. + static int64_t ByteSizeOf(const Shape& shape) { + return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); + } + + protected: + // Various shapes used in the canned computations. + const Shape scalar_shape_ = ShapeUtil::MakeShape(xla::F32, {}); + const Shape vec1_shape_ = ShapeUtil::MakeShape(xla::F32, {1}); + const Shape vec1024_shape_ = ShapeUtil::MakeShape(xla::F32, {1024}); +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_HLO_REMATERIALIZATION_TEST_UTILS_H_ diff --git a/third_party/xla/xla/service/hlo_rematerialization_test_utils_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization_test_utils_test.cc similarity index 97% rename from third_party/xla/xla/service/hlo_rematerialization_test_utils_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization_test_utils_test.cc index 803a0704fde839..7d7920ef724674 100644 --- a/third_party/xla/xla/service/hlo_rematerialization_test_utils_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization_test_utils_test.cc @@ -12,11 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/hlo_rematerialization_test_utils.h" +#include "xla/hlo/transforms/simplifiers/hlo_rematerialization_test_utils.h" #include #include +#include +#include #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" diff --git a/third_party/xla/xla/service/host_memory_transfer_asyncifier.cc b/third_party/xla/xla/hlo/transforms/simplifiers/host_memory_transfer_asyncifier.cc similarity index 99% rename from third_party/xla/xla/service/host_memory_transfer_asyncifier.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/host_memory_transfer_asyncifier.cc index 24817d168534e4..278d9f8e704424 100644 --- a/third_party/xla/xla/service/host_memory_transfer_asyncifier.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/host_memory_transfer_asyncifier.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/host_memory_transfer_asyncifier.h" +#include "xla/hlo/transforms/simplifiers/host_memory_transfer_asyncifier.h" #include diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/host_memory_transfer_asyncifier.h b/third_party/xla/xla/hlo/transforms/simplifiers/host_memory_transfer_asyncifier.h new file mode 100644 index 00000000000000..632faeb9e81730 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/host_memory_transfer_asyncifier.h @@ -0,0 +1,58 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_HOST_MEMORY_TRANSFER_ASYNCIFIER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_HOST_MEMORY_TRANSFER_ASYNCIFIER_H_ + +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +/* +This pass finds copies between the host memory and device memory and converts +them into the async ops. This includes, but is not limited to: + - device to host DynamicUpdateSlice + - host to device DynamicSlice +* The examples below are not yet supported * + - host to device DynamicUpdateSlice + - device to host DynamicSlice + - host to device Copy + - device to host Copy +*/ +class HostMemoryTransferAsyncifier : public HloModulePass { + public: + explicit HostMemoryTransferAsyncifier(int64_t host_memory_space_color) + : kHostMemorySpaceColor(host_memory_space_color) {} + ~HostMemoryTransferAsyncifier() override = default; + + absl::string_view name() const override { + return "host-memory-transfer-asyncifier"; + } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + const int64_t kHostMemorySpaceColor; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_HOST_MEMORY_TRANSFER_ASYNCIFIER_H_ diff --git a/third_party/xla/xla/service/host_memory_transfer_asyncifier_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/host_memory_transfer_asyncifier_test.cc similarity index 98% rename from third_party/xla/xla/service/host_memory_transfer_asyncifier_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/host_memory_transfer_asyncifier_test.cc index fd85488a2239ec..0b7c96c090ee74 100644 --- a/third_party/xla/xla/service/host_memory_transfer_asyncifier_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/host_memory_transfer_asyncifier_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/host_memory_transfer_asyncifier.h" +#include "xla/hlo/transforms/simplifiers/host_memory_transfer_asyncifier.h" #include #include @@ -26,9 +26,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "tsl/platform/statusor.h" @@ -38,7 +38,7 @@ namespace { namespace m = ::xla::match; -class HostMemoryTransferAsyncifierTest : public HloTestBase { +class HostMemoryTransferAsyncifierTest : public HloHardwareIndependentTestBase { protected: absl::StatusOr RunAsyncifier(absl::string_view hlo_string) { TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_string)); diff --git a/third_party/xla/xla/service/instruction_hoister.cc b/third_party/xla/xla/hlo/transforms/simplifiers/instruction_hoister.cc similarity index 98% rename from third_party/xla/xla/service/instruction_hoister.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/instruction_hoister.cc index d706a873429b55..aa8edadb330178 100644 --- a/third_party/xla/xla/service/instruction_hoister.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/instruction_hoister.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/instruction_hoister.h" +#include "xla/hlo/transforms/simplifiers/instruction_hoister.h" #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/instruction_hoister.h b/third_party/xla/xla/hlo/transforms/simplifiers/instruction_hoister.h new file mode 100644 index 00000000000000..64cad2f68da73b --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/instruction_hoister.h @@ -0,0 +1,50 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_INSTRUCTION_HOISTER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_INSTRUCTION_HOISTER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// HLO pass that hoists parameters and constants to increase opportunities for +// prefetching. +class InstructionHoister : public HloModulePass { + public: + explicit InstructionHoister(bool hoist_parameters = true, + bool host_constants = true) + : hoist_parameters_(hoist_parameters), host_constants_(host_constants) {} + + ~InstructionHoister() override = default; + + absl::string_view name() const override { return "instruction-hoister"; } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + bool hoist_parameters_; + bool host_constants_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_INSTRUCTION_HOISTER_H_ diff --git a/third_party/xla/xla/service/optimize_input_output_buffer_alias.cc b/third_party/xla/xla/hlo/transforms/simplifiers/optimize_input_output_buffer_alias.cc similarity index 98% rename from third_party/xla/xla/service/optimize_input_output_buffer_alias.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/optimize_input_output_buffer_alias.cc index 8900e5610cf35e..85bb11c409b373 100644 --- a/third_party/xla/xla/service/optimize_input_output_buffer_alias.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/optimize_input_output_buffer_alias.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/optimize_input_output_buffer_alias.h" +#include "xla/hlo/transforms/simplifiers/optimize_input_output_buffer_alias.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/optimize_input_output_buffer_alias.h b/third_party/xla/xla/hlo/transforms/simplifiers/optimize_input_output_buffer_alias.h new file mode 100644 index 00000000000000..d33512c0dda2bc --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/optimize_input_output_buffer_alias.h @@ -0,0 +1,89 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_OPTIMIZE_INPUT_OUTPUT_BUFFER_ALIAS_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_OPTIMIZE_INPUT_OUTPUT_BUFFER_ALIAS_H_ + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/shape.h" +#include "xla/shape_util.h" + +namespace xla { + +// This pass finds input and output buffers that can be aliased, and writes the +// alias config into the HloModule. +// +// The input and the output buffers can be in any shape, and each output buffer +// can alias with an input buffer with the same shape. Each input buffer may +// only alias with a single output buffer. For example, for the following +// parameter and the output buffers, +// +// Parameters : { P1(f32[3]), P2(s32[3]), P3(f32[3,12]), P4(f32[16,12]), ... } +// Outputs : { O1(s32[3]), O2(f32[3]), O3(f32[16,12]), ... } +// +// one potential aliasing would be (O1, P2), (O2, P1), (O3, P4), .. +class OptimizeInputOutputBufferAlias : public HloModulePass { + public: + OptimizeInputOutputBufferAlias() = default; + explicit OptimizeInputOutputBufferAlias( + bool registered_buffer_donor_only, + std::function shape_size_fn = + [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); }) + : registered_buffer_donor_only_(registered_buffer_donor_only), + shape_size_fn_(shape_size_fn) {} + ~OptimizeInputOutputBufferAlias() override = default; + + absl::string_view name() const override { + return "optimize_input_output_buffer_alias"; + } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + friend class OptimizeInputOutputBufferAliasTest; + + // If true, we only consider the registered buffer donor in + // HloBufferDonorConfig, ignoring unregistered input parameters. If false, we + // treat all input parameters as buffer donors. + bool registered_buffer_donor_only_ = false; + + // Match buffer donors and donees and save the matched paired in the + // alias_config. The availability of buffer donors is controlled by the flag + // registered_buffer_donor_only_. + absl::StatusOr Build(absl::Span input_shapes, + const Shape& output_shape, + HloInputOutputAliasConfig* alias_config, + HloBufferDonorConfig* buffer_donor_config); + + std::function shape_size_fn_ = [](const Shape& shape) { + return ShapeUtil::ByteSizeOf(shape); + }; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_OPTIMIZE_INPUT_OUTPUT_BUFFER_ALIAS_H_ diff --git a/third_party/xla/xla/service/optimize_input_output_buffer_alias_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/optimize_input_output_buffer_alias_test.cc similarity index 97% rename from third_party/xla/xla/service/optimize_input_output_buffer_alias_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/optimize_input_output_buffer_alias_test.cc index 1c8e7fd0dc049e..ec8954b58d7866 100644 --- a/third_party/xla/xla/service/optimize_input_output_buffer_alias_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/optimize_input_output_buffer_alias_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/optimize_input_output_buffer_alias.h" +#include "xla/hlo/transforms/simplifiers/optimize_input_output_buffer_alias.h" #include #include @@ -23,19 +23,19 @@ limitations under the License. #include #include "absl/status/status.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_utils.h" #include "tsl/platform/test.h" namespace xla { // Tests that UserBufferAlias properly maps input and output buffer indices of // various shapes for aliasing. -class OptimizeInputOutputBufferAliasTest : public HloTestBase { +class OptimizeInputOutputBufferAliasTest + : public HloHardwareIndependentTestBase { protected: OptimizeInputOutputBufferAliasTest() { r1f32_ = ShapeUtil::MakeShape(F32, {4}); diff --git a/third_party/xla/xla/service/reduce_window_rewriter.cc b/third_party/xla/xla/hlo/transforms/simplifiers/reduce_window_rewriter.cc similarity index 99% rename from third_party/xla/xla/service/reduce_window_rewriter.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/reduce_window_rewriter.cc index 241b30a9eb68d7..ddae0b21286be7 100644 --- a/third_party/xla/xla/service/reduce_window_rewriter.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/reduce_window_rewriter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/reduce_window_rewriter.h" +#include "xla/hlo/transforms/simplifiers/reduce_window_rewriter.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/reduce_window_rewriter.h b/third_party/xla/xla/hlo/transforms/simplifiers/reduce_window_rewriter.h new file mode 100644 index 00000000000000..b68308fdd33d13 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/reduce_window_rewriter.h @@ -0,0 +1,72 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_REDUCE_WINDOW_REWRITER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_REDUCE_WINDOW_REWRITER_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// Rewrite ReduceWindow to be more performant in cases it is written in a +// quadratic way: +// +// 1) Work around unimplemented cases in the implementation of ReduceWindow. +// +// This rewrites all R1 ReduceWindow nodes. We reshape the operand to an +// R2, perform the operation, and reshape back to R1. The reshapes correspond to +// a bitcast if the tensor length is less than or equal to a passed parameter. +// The motivation for this is to avoid use of overly large reductions and the +// complexities and restrictions therein. +// +// 2) Rewrite ReduceWindow ops that represent a CumSum/CumProd into a +// tree-reduction (see details in the implementation). +// Note that this may itself generate R1 ReduceWindow ops, which means this pass +// needs to be run to a fixed point. +class ReduceWindowRewriter : public HloModulePass { + public: + // `base_length` is a size of a reduce-window we are comfortable with + // executing. + explicit ReduceWindowRewriter(int64_t base_length) + : base_length_(base_length) {} + + absl::string_view name() const override { return "reduce-window-rewriter"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + absl::Status ReplaceReduceWindowWithReshape( + HloReduceWindowInstruction* reduce_window); + + absl::StatusOr TryOptimizeCumSumOrProd( + HloReduceWindowInstruction* reduce_window); + + int64_t base_length_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_REDUCE_WINDOW_REWRITER_H_ diff --git a/third_party/xla/xla/service/reduce_window_rewriter_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/reduce_window_rewriter_test.cc similarity index 97% rename from third_party/xla/xla/service/reduce_window_rewriter_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/reduce_window_rewriter_test.cc index b40314f6e4da4b..5e142bd509c80f 100644 --- a/third_party/xla/xla/service/reduce_window_rewriter_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/reduce_window_rewriter_test.cc @@ -13,20 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/reduce_window_rewriter.h" +#include "xla/hlo/transforms/simplifiers/reduce_window_rewriter.h" #include #include #include "absl/strings/string_view.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/test.h" -#include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" namespace xla { namespace { -class ReduceWindowRewriterTest : public HloTestBase { +class ReduceWindowRewriterTest : public HloHardwareIndependentTestBase { public: void CheckReduceWindowRewrite(absl::string_view hlo, std::optional expected) { diff --git a/third_party/xla/xla/service/reshape_mover.cc b/third_party/xla/xla/hlo/transforms/simplifiers/reshape_mover.cc similarity index 99% rename from third_party/xla/xla/service/reshape_mover.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/reshape_mover.cc index 8040a6eb544f22..cb2fcdb3225d9e 100644 --- a/third_party/xla/xla/service/reshape_mover.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/reshape_mover.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/reshape_mover.h" +#include "xla/hlo/transforms/simplifiers/reshape_mover.h" #include #include diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/reshape_mover.h b/third_party/xla/xla/hlo/transforms/simplifiers/reshape_mover.h new file mode 100644 index 00000000000000..21d67fd8ed5a31 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/reshape_mover.h @@ -0,0 +1,75 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_RESHAPE_MOVER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_RESHAPE_MOVER_H_ + +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// This pass sinks kReshape and kTranspose operations (known as "rearrange" ops) +// down through elementwise ops: +// +// op(rearrange(x), rearrange(y)) => rearrange(op(x, y)). +// +// We also handle the case where one of the operands is not itself a rearrange +// op but can be trivially rearranged. For example: +// +// op(rearrange(x), broadcast(scalar_y)) => +// rearrange(x, broadcast'(scalar_y)). +// +// This pass should be run to a fixed point. It also expects algsimp to be run +// after each iteration. + +struct ReshapeMoverOptions { + // On some platforms, it's cheap to do `reshape(broadcast(f32[n] x))`. The + // reshape and broadcast can always be fused, and the index calculations are + // not expensive. In such cases it can be beneficial for us to create these + // reshapes eagerly, allowing us to get rid of more expensive ones. + bool reshape_of_1d_broadcast_is_cheap = false; +}; + +class ReshapeMover : public HloModulePass { + public: + explicit ReshapeMover( + const ReshapeMoverOptions& options = ReshapeMoverOptions{}) + : options_(options) {} + + absl::string_view name() const override { return "reshape-mover"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + absl::StatusOr TryReshapeMoveOnCandidates( + HloInstructionSet* candidates); + absl::StatusOr SinkRearrangeOperands(HloInstruction* instruction); + absl::StatusOr ApplyInverseRearrange( + const HloInstruction* rearrange, HloInstruction* operand); + bool IsReshapeMoveCandidate(HloInstruction* instruction); + const HloInstruction* FirstNontrivialRearrange( + absl::Span instrs); + bool CanTriviallyRearrange(const HloInstruction* instr, + const HloInstruction* rearrange); + + ReshapeMoverOptions options_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_RESHAPE_MOVER_H_ diff --git a/third_party/xla/xla/service/reshape_mover_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/reshape_mover_test.cc similarity index 98% rename from third_party/xla/xla/service/reshape_mover_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/reshape_mover_test.cc index 6aca742e0190ea..c72ae7ce392220 100644 --- a/third_party/xla/xla/service/reshape_mover_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/reshape_mover_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/reshape_mover.h" +#include "xla/hlo/transforms/simplifiers/reshape_mover.h" #include #include @@ -22,11 +22,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/pass/hlo_pass_fix.h" -#include "xla/service/algebraic_simplifier.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" #include "xla/service/hlo_verifier.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" namespace xla { @@ -34,7 +34,7 @@ namespace { namespace m = xla::match; -class ReshapeMoverTest : public HloTestBase { +class ReshapeMoverTest : public HloHardwareIndependentTestBase { protected: // ReshapeMover relies on algsimp for cleanup. absl::Status RunPass(HloModule* module, bool change_expected, diff --git a/third_party/xla/xla/service/result_caster.cc b/third_party/xla/xla/hlo/transforms/simplifiers/result_caster.cc similarity index 98% rename from third_party/xla/xla/service/result_caster.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/result_caster.cc index ea41b9a94b0fa2..ee78e6efd5f8e8 100644 --- a/third_party/xla/xla/service/result_caster.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/result_caster.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/result_caster.h" +#include "xla/hlo/transforms/simplifiers/result_caster.h" #include diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/result_caster.h b/third_party/xla/xla/hlo/transforms/simplifiers/result_caster.h new file mode 100644 index 00000000000000..01d1980a5d0044 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/result_caster.h @@ -0,0 +1,51 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_RESULT_CASTER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_RESULT_CASTER_H_ + +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" +#include "xla/util.h" + +namespace xla { + +// Inserts Convert to result of instructions to the preferred element type +// specified by the instructions when direct accumulation of that type isn't +// supported by the backend. This pass is run in combination with +// OperandUpcaster. If the inferred accumulation type has less precision, +// OperandUpcaster will convert the operands to the higher precision type if +// necessary. +class ResultCaster : public OpExpanderPass { + public: + explicit ResultCaster(HloPredicate extra_filter = nullptr) + : OpExpanderPass(std::move(extra_filter)) {} + + absl::string_view name() const override { return "result_caster"; } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + absl::StatusOr ExpandInstruction( + HloInstruction* instruction) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_RESULT_CASTER_H_ diff --git a/third_party/xla/xla/service/result_caster_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/result_caster_test.cc similarity index 96% rename from third_party/xla/xla/service/result_caster_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/result_caster_test.cc index a2f072a5fc3cd6..58c550147ab214 100644 --- a/third_party/xla/xla/service/result_caster_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/result_caster_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/result_caster.h" +#include "xla/hlo/transforms/simplifiers/result_caster.h" #include #include @@ -21,9 +21,9 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/primitive_util.h" -#include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" namespace xla { @@ -32,7 +32,7 @@ namespace { namespace op = ::xla::testing::opcode_matchers; class ResultCasterTest - : public HloTestBase, + : public HloHardwareIndependentTestBase, public ::testing::WithParamInterface< std::tuple> {}; diff --git a/third_party/xla/xla/service/root_instruction_sinker.cc b/third_party/xla/xla/hlo/transforms/simplifiers/root_instruction_sinker.cc similarity index 97% rename from third_party/xla/xla/service/root_instruction_sinker.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/root_instruction_sinker.cc index 007e9914499b95..9fcbab4222c26d 100644 --- a/third_party/xla/xla/service/root_instruction_sinker.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/root_instruction_sinker.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/root_instruction_sinker.h" +#include "xla/hlo/transforms/simplifiers/root_instruction_sinker.h" #include "xla/service/tuple_util.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/root_instruction_sinker.h b/third_party/xla/xla/hlo/transforms/simplifiers/root_instruction_sinker.h new file mode 100644 index 00000000000000..0e692943da18d4 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/root_instruction_sinker.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_ROOT_INSTRUCTION_SINKER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_ROOT_INSTRUCTION_SINKER_H_ + +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// Given a scheduled HLO module, this pass sinks the ROOT of the instruction to +// the bottom of the non-fusion computations. To avoid dependency violations of +// moving the ROOT instruction, it creates a new ROOT instruction that looks +// like the following: +// - For tuple ROOT type: +// new_root = tuple(gte(old_root), gte(old_root), ...) +// - For non-tuple ROOT type: +// new_root = bitcast(old_root) +class RootInstructionSinker : public HloModulePass { + public: + ~RootInstructionSinker() override = default; + absl::string_view name() const override { return "root-instruction-sinker"; } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_ROOT_INSTRUCTION_SINKER_H_ diff --git a/third_party/xla/xla/service/root_instruction_sinker_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/root_instruction_sinker_test.cc similarity index 97% rename from third_party/xla/xla/service/root_instruction_sinker_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/root_instruction_sinker_test.cc index 1be67c96c61edc..bbc472061f8032 100644 --- a/third_party/xla/xla/service/root_instruction_sinker_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/root_instruction_sinker_test.cc @@ -13,17 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/root_instruction_sinker.h" +#include "xla/hlo/transforms/simplifiers/root_instruction_sinker.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/tests/hlo_test_base.h" namespace xla { namespace { namespace op = xla::testing::opcode_matchers; -using RootInstructionSinkerTest = HloTestBase; +using RootInstructionSinkerTest = HloHardwareIndependentTestBase; TEST_F(RootInstructionSinkerTest, TupleNoChange) { // ROOTS are already sunk, no change performed to the module. diff --git a/third_party/xla/xla/service/simplify_fp_conversions.cc b/third_party/xla/xla/hlo/transforms/simplifiers/simplify_fp_conversions.cc similarity index 97% rename from third_party/xla/xla/service/simplify_fp_conversions.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/simplify_fp_conversions.cc index a30bdcfebea413..43c7171fafd09f 100644 --- a/third_party/xla/xla/service/simplify_fp_conversions.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/simplify_fp_conversions.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/simplify_fp_conversions.h" +#include "xla/hlo/transforms/simplifiers/simplify_fp_conversions.h" #include diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/simplify_fp_conversions.h b/third_party/xla/xla/hlo/transforms/simplifiers/simplify_fp_conversions.h new file mode 100644 index 00000000000000..a2266f34721f66 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/simplify_fp_conversions.h @@ -0,0 +1,47 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_SIMPLIFY_FP_CONVERSIONS_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_SIMPLIFY_FP_CONVERSIONS_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// Simplifies chains of floating-point conversions. +// +// The algebraic simplifier will remove convert pairs of the form `X -> Y -> X`, +// only when they are a no-op, e.g. `bf16 -> f32 -> bf16` or +// `f32 -> bf16 -> f32`. Note that the latter optimization might lead to +// increased precision. +class SimplifyFPConversions : public HloModulePass { + public: + explicit SimplifyFPConversions() = default; + + absl::string_view name() const override { return "simplify-fp-conversions"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_SIMPLIFY_FP_CONVERSIONS_H_ diff --git a/third_party/xla/xla/service/simplify_fp_conversions_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/simplify_fp_conversions_test.cc similarity index 91% rename from third_party/xla/xla/service/simplify_fp_conversions_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/simplify_fp_conversions_test.cc index ad85bb873eb654..17e97fdf49d90f 100644 --- a/third_party/xla/xla/service/simplify_fp_conversions_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/simplify_fp_conversions_test.cc @@ -13,18 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/simplify_fp_conversions.h" +#include "xla/hlo/transforms/simplifiers/simplify_fp_conversions.h" #include +#include +#include #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/tests/hlo_test_base.h" #include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -34,7 +36,7 @@ namespace op = xla::testing::opcode_matchers; using ::testing::AllOf; using ::tsl::testing::IsOkAndHolds; -using SimplifyFPConversionsTest = HloTestBase; +using SimplifyFPConversionsTest = HloHardwareIndependentTestBase; TEST_F(SimplifyFPConversionsTest, DoesNotChangeSingleConvert) { const absl::string_view kModuleStr = R"( diff --git a/third_party/xla/xla/service/slice_sinker.cc b/third_party/xla/xla/hlo/transforms/simplifiers/slice_sinker.cc similarity index 95% rename from third_party/xla/xla/service/slice_sinker.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/slice_sinker.cc index 5f6e75c486b33b..e1304eddc9ab57 100644 --- a/third_party/xla/xla/service/slice_sinker.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/slice_sinker.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/slice_sinker.h" +#include "xla/hlo/transforms/simplifiers/slice_sinker.h" #include #include @@ -21,8 +21,21 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" namespace xla { diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/slice_sinker.h b/third_party/xla/xla/hlo/transforms/simplifiers/slice_sinker.h new file mode 100644 index 00000000000000..66a58ed24f0a80 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/slice_sinker.h @@ -0,0 +1,41 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_SLICE_SINKER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_SLICE_SINKER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// An HLO pass that sinks slice operations used by a group of elementwise +// operations and merges the group of elementwise operations. +class SliceSinker : public HloModulePass { + public: + absl::string_view name() const override { return "slice-sinker"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_SLICE_SINKER_H_ diff --git a/third_party/xla/xla/service/slice_sinker_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/slice_sinker_test.cc similarity index 98% rename from third_party/xla/xla/service/slice_sinker_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/slice_sinker_test.cc index 413710bd6a225b..1a983aa5335f4a 100644 --- a/third_party/xla/xla/service/slice_sinker_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/slice_sinker_test.cc @@ -13,25 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/slice_sinker.h" +#include "xla/hlo/transforms/simplifiers/slice_sinker.h" #include #include +#include +#include #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/layout_util.h" -#include "xla/literal_util.h" -#include "xla/service/hlo_dce.h" -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/shape_util.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tsl/lib/core/status_test_util.h" -#include "xla/types.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -39,7 +34,7 @@ namespace { namespace m = match; using ::testing::ElementsAre; -class SliceSinkerTest : public HloTestBase {}; +class SliceSinkerTest : public HloHardwareIndependentTestBase {}; TEST_F(SliceSinkerTest, TernaryOperation) { const char* kModuleStr = R"( diff --git a/third_party/xla/xla/service/sort_simplifier.cc b/third_party/xla/xla/hlo/transforms/simplifiers/sort_simplifier.cc similarity index 93% rename from third_party/xla/xla/service/sort_simplifier.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/sort_simplifier.cc index bc9b373424b37c..16335f0d1d21af 100644 --- a/third_party/xla/xla/service/sort_simplifier.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/sort_simplifier.cc @@ -13,16 +13,27 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/sort_simplifier.h" +#include "xla/hlo/transforms/simplifiers/sort_simplifier.h" #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/sort_simplifier.h b/third_party/xla/xla/hlo/transforms/simplifiers/sort_simplifier.h new file mode 100644 index 00000000000000..44779f1e10c452 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/sort_simplifier.h @@ -0,0 +1,40 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_SORT_SIMPLIFIER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_SORT_SIMPLIFIER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// HLO pass which removes unused operands from sort, where an unused operand is +// defined as an operand at some index 'x' at which the output is not used. +class SortSimplifier : public HloModulePass { + public: + absl::string_view name() const override { return "simplify-sorts"; } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_SORT_SIMPLIFIER_H_ diff --git a/third_party/xla/xla/service/sort_simplifier_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/sort_simplifier_test.cc similarity index 95% rename from third_party/xla/xla/service/sort_simplifier_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/sort_simplifier_test.cc index 678ce7c37eb905..f9cfd9f7f99150 100644 --- a/third_party/xla/xla/service/sort_simplifier_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/sort_simplifier_test.cc @@ -13,22 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/sort_simplifier.h" +#include "xla/hlo/transforms/simplifiers/sort_simplifier.h" -#include "xla/hlo/utils/hlo_matchers.h" -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/test.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { namespace m = match; -using SortSimplifierTest = HloTestBase; +using SortSimplifierTest = HloHardwareIndependentTestBase; TEST_F(SortSimplifierTest, RemoveUnusedSortOperandArrayResult) { const char* hlo_string = R"( diff --git a/third_party/xla/xla/service/sub_byte_normalization.cc b/third_party/xla/xla/hlo/transforms/simplifiers/sub_byte_normalization.cc similarity index 98% rename from third_party/xla/xla/service/sub_byte_normalization.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/sub_byte_normalization.cc index fcedf9510ad77a..9e8521bbf83623 100644 --- a/third_party/xla/xla/service/sub_byte_normalization.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/sub_byte_normalization.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/sub_byte_normalization.h" +#include "xla/hlo/transforms/simplifiers/sub_byte_normalization.h" #include diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/sub_byte_normalization.h b/third_party/xla/xla/hlo/transforms/simplifiers/sub_byte_normalization.h new file mode 100644 index 00000000000000..b8918d63c6ae21 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/sub_byte_normalization.h @@ -0,0 +1,66 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_SUB_BYTE_NORMALIZATION_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_SUB_BYTE_NORMALIZATION_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// A pass that can modify the sub-byte element_size_in_bits annotation on +// layouts. Depending on the constructor argument, it either removes the +// element_size_in_bits annotation for platforms that don't support packed +// types, or it sets element_size_in_bits to N for N-bit values. +class SubByteNormalization : public HloModulePass { + public: + enum Mode { + // Remove element_size_in_bits on all layouts. Useful for platforms which + // do not support packed types. + REMOVE_ELEMENT_SIZE, + // Set element_size_in_bits to bitwidth(type) for layouts of types < 8 bits + // (S4, U4, etc.), and to 0 for all other layouts. Useful for platforms + // which support packed types. + SET_ELEMENT_SIZE, + }; + + explicit SubByteNormalization(Mode mode) : mode_(mode) {} + + ~SubByteNormalization() override = default; + + absl::string_view name() const override { + switch (mode_) { + case REMOVE_ELEMENT_SIZE: + return "sub-byte-size-removal"; + case SET_ELEMENT_SIZE: + return "sub-byte-size-setter"; + } + } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + Mode mode_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_SUB_BYTE_NORMALIZATION_H_ diff --git a/third_party/xla/xla/service/tree_reduction_rewriter.cc b/third_party/xla/xla/hlo/transforms/simplifiers/tree_reduction_rewriter.cc similarity index 97% rename from third_party/xla/xla/service/tree_reduction_rewriter.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/tree_reduction_rewriter.cc index 45a87836f786aa..15c930d52e1fb1 100644 --- a/third_party/xla/xla/service/tree_reduction_rewriter.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/tree_reduction_rewriter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/tree_reduction_rewriter.h" +#include "xla/hlo/transforms/simplifiers/tree_reduction_rewriter.h" #include #include @@ -23,7 +23,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/status/statusor.h" #include "absl/strings/str_join.h" -#include "xla/client/padding.h" +#include "xla/hlo/builder/padding.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/tree_reduction_rewriter.h b/third_party/xla/xla/hlo/transforms/simplifiers/tree_reduction_rewriter.h new file mode 100644 index 00000000000000..c332642ba82005 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/tree_reduction_rewriter.h @@ -0,0 +1,61 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_TREE_REDUCTION_REWRITER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_TREE_REDUCTION_REWRITER_H_ + +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// Increase precision for the reduction operation by applying the reduce-window +// first. +// +// E.g. suppose we want to reduce f32[1024] to a scalar. This pass first applies +// a reduce-window (with kSame padding) of size `reduce_window_size`, and then +// reduces the resulting array f32[32]. The rewrite is not applied if any of the +// reduced dimensions is smaller than the `reduce_window_size`. +// +// Applying this pass until a fixed point performs a variant of pairwise +// summation (https://en.wikipedia.org/wiki/Pairwise_summation), which is +// guaranteed to have an asymptotically smaller error bound provided that +// intermediate roundoff errors are random and have random sign. +// +// If this pass lowers the performance too much, the window size can always be +// increased to a larger value. +class TreeReductionRewriter : public HloModulePass { + public: + explicit TreeReductionRewriter(int64_t reduce_window_size = 32) + : reduce_window_size_(reduce_window_size) {} + ~TreeReductionRewriter() override = default; + absl::string_view name() const override { return "tree_reduction_rewriter"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + int64_t reduce_window_size_; +}; + +} // end namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_TREE_REDUCTION_REWRITER_H_ diff --git a/third_party/xla/xla/service/tuple_simplifier.cc b/third_party/xla/xla/hlo/transforms/simplifiers/tuple_simplifier.cc similarity index 98% rename from third_party/xla/xla/service/tuple_simplifier.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/tuple_simplifier.cc index 3557b076df0ef6..efac46a4f14775 100644 --- a/third_party/xla/xla/service/tuple_simplifier.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/tuple_simplifier.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/tuple_simplifier.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/tuple_simplifier.h b/third_party/xla/xla/hlo/transforms/simplifiers/tuple_simplifier.h new file mode 100644 index 00000000000000..ba3113ef8d38b4 --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/tuple_simplifier.h @@ -0,0 +1,69 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_TUPLE_SIMPLIFIER_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_TUPLE_SIMPLIFIER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// A pass which simplifies patterns of Tuple and GetTupleElement instructions in +// the module. +class TupleSimplifier : public HloModulePass { + public: + TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {} + explicit TupleSimplifier(bool exclude_entry_computation); + ~TupleSimplifier() override {} + absl::string_view name() const override { return "tuple-simplifier"; } + + // Runs tuple simplification on the given module. Returns whether the module + // was changed. + using HloPassInterface::Run; + using HloPassInterface::RunOnModuleGroup; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + // When set, this pipeline stage will perform optimization of all computations + // apart from the module's entry computation. This is used by Graphcore's + // backend. + bool exclude_entry_computation_; + + // Collapse the following structure into just 'Tuple-shaped Op', iff the + // sequence of GTE ops is order-preserving: + // + // Tuple-shaped Op + // | + // +-----+-----+ + // | | | + // GTE GTE GTE + // | | | + // +-----+-----+ + // | + // Tuple + // + absl::StatusOr RemoveWholeTuple(HloInstruction* tuple); +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_TUPLE_SIMPLIFIER_H_ diff --git a/third_party/xla/xla/service/tuple_simplifier_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/tuple_simplifier_test.cc similarity index 99% rename from third_party/xla/xla/service/tuple_simplifier_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/tuple_simplifier_test.cc index 33305afd7e0f71..17da39d7472699 100644 --- a/third_party/xla/xla/service/tuple_simplifier_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/tuple_simplifier_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/tuple_simplifier.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include @@ -21,11 +21,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" -#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" @@ -34,7 +34,7 @@ namespace { namespace op = xla::testing::opcode_matchers; -class TupleSimplifierTest : public HloTestBase { +class TupleSimplifierTest : public HloHardwareIndependentTestBase { protected: void Run(HloModule* module, bool change_expected) { auto changed_status = RunHloPass(TupleSimplifier(), module); diff --git a/third_party/xla/xla/service/zero_sized_hlo_elimination.cc b/third_party/xla/xla/hlo/transforms/simplifiers/zero_sized_hlo_elimination.cc similarity index 96% rename from third_party/xla/xla/service/zero_sized_hlo_elimination.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/zero_sized_hlo_elimination.cc index ce784875dff930..2ea480fadcd846 100644 --- a/third_party/xla/xla/service/zero_sized_hlo_elimination.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/zero_sized_hlo_elimination.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/zero_sized_hlo_elimination.h" +#include "xla/hlo/transforms/simplifiers/zero_sized_hlo_elimination.h" #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" diff --git a/third_party/xla/xla/hlo/transforms/simplifiers/zero_sized_hlo_elimination.h b/third_party/xla/xla/hlo/transforms/simplifiers/zero_sized_hlo_elimination.h new file mode 100644 index 00000000000000..7823b7e4b4e71a --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/simplifiers/zero_sized_hlo_elimination.h @@ -0,0 +1,38 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_SIMPLIFIERS_ZERO_SIZED_HLO_ELIMINATION_H_ +#define XLA_HLO_TRANSFORMS_SIMPLIFIERS_ZERO_SIZED_HLO_ELIMINATION_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +// HLO pass that replaces zero sized Hlos with a zero sized constant literal. +namespace xla { +class ZeroSizedHloElimination : public HloModulePass { + public: + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + absl::string_view name() const override { + return "zero_sized_hlo_elimination"; + } +}; +} // namespace xla +#endif // XLA_HLO_TRANSFORMS_SIMPLIFIERS_ZERO_SIZED_HLO_ELIMINATION_H_ diff --git a/third_party/xla/xla/service/zero_sized_hlo_elimination_test.cc b/third_party/xla/xla/hlo/transforms/simplifiers/zero_sized_hlo_elimination_test.cc similarity index 93% rename from third_party/xla/xla/service/zero_sized_hlo_elimination_test.cc rename to third_party/xla/xla/hlo/transforms/simplifiers/zero_sized_hlo_elimination_test.cc index 9da305fb978cdb..bbb087bade170d 100644 --- a/third_party/xla/xla/service/zero_sized_hlo_elimination_test.cc +++ b/third_party/xla/xla/hlo/transforms/simplifiers/zero_sized_hlo_elimination_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/zero_sized_hlo_elimination.h" +#include "xla/hlo/transforms/simplifiers/zero_sized_hlo_elimination.h" #include #include @@ -21,21 +21,21 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/literal_util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" -#include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" namespace xla { namespace { -class ZeroSizedHloEliminationTest : public HloTestBase { +class ZeroSizedHloEliminationTest : public HloHardwareIndependentTestBase { protected: ZeroSizedHloEliminationTest() - : HloTestBase(), + : HloHardwareIndependentTestBase(), builder_("zero_sized_computation"), zero_sized_param_( builder_.AddInstruction(HloInstruction::CreateParameter( diff --git a/third_party/xla/xla/service/while_loop_trip_count_annotator.cc b/third_party/xla/xla/hlo/transforms/while_loop_trip_count_annotator.cc similarity index 93% rename from third_party/xla/xla/service/while_loop_trip_count_annotator.cc rename to third_party/xla/xla/hlo/transforms/while_loop_trip_count_annotator.cc index 037a6151fa37f3..db59fa2df83c8d 100644 --- a/third_party/xla/xla/service/while_loop_trip_count_annotator.cc +++ b/third_party/xla/xla/hlo/transforms/while_loop_trip_count_annotator.cc @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/while_loop_trip_count_annotator.h" +#include "xla/hlo/transforms/while_loop_trip_count_annotator.h" #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/analysis/while_loop_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/while_loop_analysis.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/hlo/transforms/while_loop_trip_count_annotator.h b/third_party/xla/xla/hlo/transforms/while_loop_trip_count_annotator.h new file mode 100644 index 00000000000000..0cab15b749d9ad --- /dev/null +++ b/third_party/xla/xla/hlo/transforms/while_loop_trip_count_annotator.h @@ -0,0 +1,53 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSFORMS_WHILE_LOOP_TRIP_COUNT_ANNOTATOR_H_ +#define XLA_HLO_TRANSFORMS_WHILE_LOOP_TRIP_COUNT_ANNOTATOR_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// Pass that annotates `while` loops with known trip counts. +// +// The annotation is stored as a backend-config on the while loop node. +// +// This pass should run after all passes that might semantically modify a while +// loop, e.g. by unrolling it. Otherwise, a loop could end up with a +// backend-config that doesn't match its true trip-count. +// +// This pass does some pattern-matching on loop bodies and conditions, so it +// should run after most HLO simplifications and before fusion and layout +// assignment, which make pattern matching much more difficult by e.g. +// introducing `copy` nodes. +class WhileLoopTripCountAnnotator : public HloModulePass { + public: + ~WhileLoopTripCountAnnotator() override {} + absl::string_view name() const override { + return "while-loop-trip-count-annotator"; + } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSFORMS_WHILE_LOOP_TRIP_COUNT_ANNOTATOR_H_ diff --git a/third_party/xla/xla/service/while_loop_trip_count_annotator_test.cc b/third_party/xla/xla/hlo/transforms/while_loop_trip_count_annotator_test.cc similarity index 97% rename from third_party/xla/xla/service/while_loop_trip_count_annotator_test.cc rename to third_party/xla/xla/hlo/transforms/while_loop_trip_count_annotator_test.cc index 1b12f3178f4b09..942408086452d4 100644 --- a/third_party/xla/xla/service/while_loop_trip_count_annotator_test.cc +++ b/third_party/xla/xla/hlo/transforms/while_loop_trip_count_annotator_test.cc @@ -13,17 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/while_loop_trip_count_annotator.h" +#include "xla/hlo/transforms/while_loop_trip_count_annotator.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/test.h" -#include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" namespace xla { namespace { -class TripCountAnnotatorTest : public HloTestBase {}; +class TripCountAnnotatorTest : public HloHardwareIndependentTestBase {}; TEST_F(TripCountAnnotatorTest, KnownSmallTripCount) { const char* kModuleStr = R"( diff --git a/third_party/xla/xla/hlo/translate/BUILD b/third_party/xla/xla/hlo/translate/BUILD new file mode 100644 index 00000000000000..01b896690baf5f --- /dev/null +++ b/third_party/xla/xla/hlo/translate/BUILD @@ -0,0 +1,83 @@ +load("@bazel_skylib//rules:build_test.bzl", "build_test") +load("//xla:xla.bzl", "xla_cc_binary") +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//learning/brain/mlir:tensorflow_friends", + "//learning/brain/mlir:xla_friends", + ]), + licenses = ["notice"], +) + +cc_library( + name = "portable_api", + srcs = ["portable_api.cc"], + hdrs = ["portable_api.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//xla/hlo/ir:hlo", + "//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", + "//xla/mlir_hlo:hlo_dialect_registration", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeWriter", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:statusor", + "@stablehlo//:register", + ], +) + +build_test( + name = "xla-translate_build_test", + targets = [ + ":xla-translate", + ], +) + +xla_cc_binary( + name = "xla-translate", + testonly = True, + srcs = ["xla_translate_main.cc"], + deps = [ + "//xla/hlo/translate/hlo_to_mhlo:translate_registration", + "//xla/hlo/translate/mhlo_to_hlo:translate_registration", + "//xla/hlo/translate/stablehlo_to_hlo:translate_registration", + "//xla/service/cpu:cpu_compiler", + "//xla/service/cpu:cpu_transfer_manager", + "//xla/stream_executor/host:host_platform", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TranslateLib", + "@local_tsl//tsl/platform:platform_port", + ], +) + +build_test( + name = "xla-translate-opt_build_test", + targets = [ + ":xla-translate-opt", + ], +) + +xla_cc_binary( + name = "xla-translate-opt", + testonly = True, + srcs = ["xla_translate_opt_main.cc"], + deps = [ + "//xla/mlir/framework/ir:xla_framework", + "//xla/mlir/framework/transforms:passes", + "//xla/mlir_hlo:hlo_dialect_registration", + "//xla/service:cpu_plugin", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:MlirOptLib", + "@local_tsl//tsl/platform:platform_port", + "@stablehlo//:register", + ], +) diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD new file mode 100644 index 00000000000000..a9a7481aad2a97 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/BUILD @@ -0,0 +1,274 @@ +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla:xla.bzl", "xla_cc_test") +load("//xla/tsl:tsl.bzl", "internal_visibility") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//learning/brain/mlir:tensorflow_friends", + "//learning/brain/mlir:xla_friends", + ]), + licenses = ["notice"], +) + +cc_library( + name = "attribute_importer", + srcs = ["attribute_importer.cc"], + hdrs = ["attribute_importer.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/mlir_hlo", + "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "async_importer", + srcs = ["async_importer.cc"], + hdrs = ["async_importer.h"], + deps = [ + ":attribute_importer", + ":hlo_utils", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:errors", + ], +) + +cc_library( + name = "custom_call_importer", + srcs = ["custom_call_importer.cc"], + hdrs = ["custom_call_importer.h"], + deps = [ + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "stack_location_utils", + srcs = ["stack_location_utils.cc"], + hdrs = ["stack_location_utils.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service:hlo_proto_cc", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "hlo_function_importer", + srcs = ["hlo_function_importer.cc"], + hdrs = ["hlo_function_importer.h"], + deps = [ + ":async_importer", + ":attribute_importer", + ":custom_call_importer", + ":hlo_utils", + ":location_importer", + "//xla:comparison_util", + "//xla:literal", + "//xla:protobuf_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:AsmParser", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:SparseTensorDialect", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "hlo_module_importer", + srcs = [ + "hlo_module_importer.cc", + ], + hdrs = [ + "hlo_module_importer.h", + ], + deps = [ + ":hlo_function_importer", + ":module_attributes_importer", + "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "hlo_to_mlir_hlo", + srcs = ["hlo_to_mlir_hlo.cc"], + hdrs = ["hlo_to_mlir_hlo.h"], + deps = [ + ":hlo_module_importer", + "//xla:status_macros", + "//xla/mlir/utils:error_util", + "//xla/mlir_hlo:mhlo_passes", + "//xla/service/llvm_ir:llvm_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "hlo_utils", + srcs = ["hlo_utils.cc"], + hdrs = ["hlo_utils.h"], + includes = ["include"], + deps = [ + "//xla:literal", + "//xla:shape_util", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/mlir/utils:type_util", + "//xla/mlir_hlo", + "@com_google_absl//absl/status:statusor", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:SparseTensorDialect", + "@llvm-project//mlir:SparseTensorEnums", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "hlo_utils_test", + srcs = ["hlo_utils_test.cc"], + deps = [ + ":hlo_utils", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:test", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/mlir_hlo", + "//xla/tsl/lib/core:status_test_util", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "location_importer", + srcs = ["location_importer.cc"], + hdrs = ["location_importer.h"], + deps = [ + "stack_location_utils", + "//xla/hlo/ir:hlo", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "module_attributes_importer", + srcs = ["module_attributes_importer.cc"], + hdrs = ["module_attributes_importer.h"], + deps = [ + ":hlo_function_importer", + ":hlo_utils", + "//xla:shape_layout", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "//xla/service:computation_layout", + "//xla/service:hlo_module_config", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "translate", + srcs = ["translate.cc"], + hdrs = ["translate.h"], + deps = [ + ":hlo_to_mlir_hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/mlir_hlo", + "//xla/mlir_hlo:mhlo_passes", + "//xla/service:hlo_proto_cc", + "//xla/service/llvm_ir:llvm_util", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@local_tsl//tsl/platform:protobuf", + ], +) + +cc_library( + name = "translate_registration", + testonly = True, + srcs = ["translate_registration.cc"], + deps = [ + ":translate", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:TranslateLib", + ], + alwayslink = 1, +) diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.cc similarity index 98% rename from third_party/xla/xla/translate/hlo_to_mhlo/async_importer.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.cc index 57bc78a0ead971..c71cdabc7b2acb 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/hlo_to_mhlo/async_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/async_importer.h" #include #include @@ -34,10 +34,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/translate/hlo_to_mhlo/attribute_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/translate/hlo_to_mhlo/attribute_importer.h" -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" namespace xla { diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.h similarity index 95% rename from third_party/xla/xla/translate/hlo_to_mhlo/async_importer.h rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.h index efdd487c21f03d..906f9235f28498 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.h +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/async_importer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_TRANSLATE_HLO_TO_MHLO_ASYNC_IMPORTER_H_ -#define XLA_TRANSLATE_HLO_TO_MHLO_ASYNC_IMPORTER_H_ +#ifndef XLA_HLO_TRANSLATE_HLO_TO_MHLO_ASYNC_IMPORTER_H_ +#define XLA_HLO_TRANSLATE_HLO_TO_MHLO_ASYNC_IMPORTER_H_ #include @@ -85,4 +85,4 @@ absl::StatusOr ImportAsyncOpDone( } // namespace xla -#endif // XLA_TRANSLATE_HLO_TO_MHLO_ASYNC_IMPORTER_H_ +#endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_ASYNC_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.cc similarity index 99% rename from third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.cc index 7e3ea9b3d9e282..9ce0d2faf6b1b8 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/hlo_to_mhlo/attribute_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/attribute_importer.h" #include @@ -35,6 +35,7 @@ limitations under the License. #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/service/hlo.pb.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.h new file mode 100644 index 00000000000000..6a54b864e38f0d --- /dev/null +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/attribute_importer.h @@ -0,0 +1,104 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_HLO_TO_MHLO_ATTRIBUTE_IMPORTER_H_ +#define XLA_HLO_TRANSLATE_HLO_TO_MHLO_ATTRIBUTE_IMPORTER_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/service/hlo.pb.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Converts an XLA PrecisionConfig to the corresponding MLIR attribute. +mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config, + mlir::Builder* builder); + +// Converts the gather dimensions to attributes. +mlir::mhlo::GatherDimensionNumbersAttr ConvertGatherDimensionNumbers( + const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder); + +// Converts the scatter dimensions to attributes. +mlir::mhlo::ScatterDimensionNumbersAttr ConvertScatterDimensionNumbers( + const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder); + +// Converts the dot algorithm to attributes. +mlir::mhlo::DotAlgorithmAttr ConvertDotAlgorithm( + PrecisionConfig::Algorithm algorithm, mlir::Builder* builder); + +// Converts the dot dimensions to attributes. +mlir::mhlo::DotDimensionNumbersAttr ConvertDotDimensionNumbers( + const DotDimensionNumbers& dnums, mlir::Builder* builder); + +// Converts the conv dimensions to attributes. +mlir::mhlo::ConvDimensionNumbersAttr ConvertConvDimensionNumbers( + const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder); + +// Converts the output operand aliasing to attributes. +mlir::ArrayAttr ConvertOutputOperandAliasing( + const std::vector>>& aliaInfo, + mlir::Builder* builder); + +// Converts the sparsity descriptor to attributes. +absl::StatusOr ConvertSparsityDescriptor( + xla::SparsityDescriptor sparsity_descriptor, mlir::Builder* builder); + +absl::StatusOr ConvertFftType(FftType type); +absl::StatusOr ConvertTranspose( + TriangularSolveOptions_Transpose transpose); + +absl::StatusOr ConvertCustomCallApiVersion( + xla::CustomCallApiVersion api_version); + +mlir::NamedAttribute ConvertChannelHandle(const ChannelHandle& channel, + mlir::Builder* builder); +mlir::NamedAttribute ConvertChannelHandle(std::optional channel_id, + mlir::Builder* builder); + +mlir::NamedAttribute ConvertReplicaGroups( + absl::Span replica_groups, mlir::Builder* builder); + +mlir::NamedAttribute ConvertSourceTargetPairs( + const std::vector>& source_target_pairs, + mlir::Builder* builder); + +mlir::NamedAttribute ConvertUseGlobalDeviceIds(mlir::Builder* builder); + +// Extracts layouts from shapes and converts it into layout attributes (array of +// rank-1 index tensors). Returns an error if any of the shapes is a tuple. +absl::StatusOr ExtractLayoutsFromShapes( + const absl::Span shapes_with_layouts, mlir::Builder* builder); + +// Extracts the layouts of each element from a tuple shape and returns them as +// an array of rank-1 index tensors. Returns an error in presence of nested +// tuple shapes. +absl::StatusOr ExtractLayoutsFromTuple(const xla::Shape shape, + mlir::Builder* builder); + +} // namespace xla + +#endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_ATTRIBUTE_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/custom_call_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/custom_call_importer.cc similarity index 98% rename from third_party/xla/xla/translate/hlo_to_mhlo/custom_call_importer.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/custom_call_importer.cc index 24f69b8ce5c595..988375ab5cd4b4 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/custom_call_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/custom_call_importer.cc @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/hlo_to_mhlo/custom_call_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/custom_call_importer.h" #include #include -#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/match.h" #include "llvm/ADT/STLExtras.h" #include "mlir/AsmParser/AsmParser.h" -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/custom_call_importer.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/custom_call_importer.h similarity index 89% rename from third_party/xla/xla/translate/hlo_to_mhlo/custom_call_importer.h rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/custom_call_importer.h index 8ccf85c77b5f89..92424e85dd356f 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/custom_call_importer.h +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/custom_call_importer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_TRANSLATE_HLO_TO_MHLO_CUSTOM_CALL_IMPORTER_H_ -#define XLA_TRANSLATE_HLO_TO_MHLO_CUSTOM_CALL_IMPORTER_H_ +#ifndef XLA_HLO_TRANSLATE_HLO_TO_MHLO_CUSTOM_CALL_IMPORTER_H_ +#define XLA_HLO_TRANSLATE_HLO_TO_MHLO_CUSTOM_CALL_IMPORTER_H_ #include "absl/status/statusor.h" #include "mlir/IR/Builders.h" @@ -42,4 +42,4 @@ bool IsOpEncodedCustomCall(const HloCustomCallInstruction* instruction); } // namespace xla -#endif // XLA_TRANSLATE_HLO_TO_MHLO_CUSTOM_CALL_IMPORTER_H_ +#endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_CUSTOM_CALL_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc similarity index 98% rename from third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc index 392794277900cc..4072633c91218a 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/hlo_to_mhlo/hlo_function_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h" #include #include @@ -63,6 +63,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/hlo_sharding_metadata.h" +#include "xla/hlo/translate/hlo_to_mhlo/async_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/attribute_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/custom_call_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" +#include "xla/hlo/translate/hlo_to_mhlo/location_importer.h" #include "xla/layout.h" #include "xla/literal.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" @@ -70,11 +75,6 @@ limitations under the License. #include "xla/service/hlo.pb.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/translate/hlo_to_mhlo/async_importer.h" -#include "xla/translate/hlo_to_mhlo/attribute_importer.h" -#include "xla/translate/hlo_to_mhlo/custom_call_importer.h" -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" -#include "xla/translate/hlo_to_mhlo/location_importer.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" @@ -879,6 +879,20 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( new_operation->setAttr(attr.getName(), attr.getValue()); } } + // Shardy currently requires roundtripping passes after HW specific passes + // which introduce kCall with backend_config for host communication. If + // we get to a point where compiler flow for sharding propagation doesn't + // require roundtrip this can likely be removed. + const std::string& raw_backend_config = + instruction->raw_backend_config_string(); + if (!raw_backend_config.empty()) { + llvm::SmallVector frontend_attributes; + frontend_attributes.push_back(builder_->getNamedAttr( + "backend_config", builder_->getStringAttr(raw_backend_config))); + new_operation->setAttr( + kFrontendAttributesAttr, + builder_->getDictionaryAttr(frontend_attributes)); + } return new_operation; } case HloOpcode::kCollectiveBroadcast: { diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h new file mode 100644 index 00000000000000..c65c41e5bd9269 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h @@ -0,0 +1,258 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_FUNCTION_IMPORTER_H_ +#define XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_FUNCTION_IMPORTER_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "xla/comparison_util.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/service/hlo.pb.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +class HloModule; +class HloComputation; +class HloInstruction; +class Shape; + +// HLO bounded dynamic shapes can be converted to either MLIR dynamic shapes +// (which lose the bound information) or casted to static shape using the +// bounds. +enum class DynamicShapeHandlingMode { kDynamic, kConvertToStatic }; + +// Helper class for importing HloComputations. +class HloFunctionImporter { + public: + // Imports the given computation as a function in the given symbol table and + // returns the FuncOp. This also imports any computations referred by + // instructions in this computation. + static absl::StatusOr ImportAsFunc( + const HloComputation& computation, mlir::SymbolTable& symbol_table, + std::unordered_map* + function_map, + mlir::Builder* builder, bool is_main, + bool flatten_computation_args_result = false); + + // Imports the given hlo computation to the specified region. + // + // Flattens the tuple-typed region argument(s) and return value(s). + static absl::Status ImportAsRegion( + const HloComputation& computation, mlir::SymbolTable& symbol_table, + mlir::Region* region, mlir::Builder* builder, + bool flatten_computation_args_result = false); + + // Imports the given computation to the given place specified by `builder`. + // `arguments` contains values for all parameters. + static absl::StatusOr ImportInstructions( + const HloComputation& computation, + const llvm::SmallVectorImpl& arguments, + mlir::SymbolTable& symbol_table, mlir::OpBuilder* builder, + bool flatten_computation_args_result = false); + + static absl::StatusOr ImportInstruction( + const HloInstruction* instr, + const llvm::SmallVectorImpl& operands, + mlir::SymbolTable& symbol_table, mlir::OpBuilder* builder, + bool flatten_computation_args_result = false, + DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic); + + static void SetLayoutForMlir(mlir::Operation* op, const Shape& shape, + llvm::StringRef attr_name); + + // For mlir::IfOp or mlir::CaseOp, replace the uses of their region's block + // arguments with 'implicit_operands'. Here | implicit_operands | == sum of + // the number of arguments in all the regions in IfOp or CaseOp. + void ReplaceBlockArgumentsWithImplicitOperands( + mlir::Operation* op, llvm::ArrayRef implicit_operands); + + // FlattenTupleType flattens the types in (nested) tuple-type 'type' and + // stores them in 'flattened_types'. + static void FlattenTupleType( + mlir::Type type, llvm::SmallVectorImpl& flattened_types); + + // FlattenTupleValue flattens the values in (nested) tuple-typed 'value' and + // stores them in 'flattened_values'. + static void FlattenTupleValue( + mlir::OpBuilder* func_builder, mlir::Location loc, mlir::Value value, + llvm::SmallVectorImpl& flattened_values); + + // FlattenTupleValues flattens the values in (nested) tuple-typed 'values' and + // returns the flattened values. + static llvm::SmallVector FlattenTupleValues( + mlir::OpBuilder* func_builder, mlir::Location loc, + mlir::ValueRange values, std::optional reserve_size = std::nullopt); + + private: + HloFunctionImporter(mlir::SymbolTable& symbol_table, + std::unordered_map* function_map, + mlir::Builder* builder, + bool flatten_computation_args_result) + : context_(symbol_table.getOp()->getContext()), + symbol_table_(symbol_table), + builder_(builder), + function_map_(function_map), + flatten_computation_args_result_(flatten_computation_args_result) { + context_->loadDialect(); + context_->loadDialect(); + context_->loadDialect(); + context_->loadDialect(); + } + + // Imports the given computation as a new function, if it hasn't been already + // imported. + absl::StatusOr ImportAsFunc( + const HloComputation& computation, bool is_main); + + // Imports the given computation in the specified region. + absl::Status ImportAsRegion(const HloComputation& computation, + mlir::Region* region); + + // Imports instructions from the given computation in the specified block. + // Assumes that the block already has correct arguments populated. + absl::Status ImportInstructions(const HloComputation& computation, + mlir::Block* block); + absl::StatusOr ImportInstructionsImpl( + const HloComputation& computation, + const llvm::SmallVectorImpl& arguments, + mlir::OpBuilder* builder); + + // Imports an instruction. + absl::StatusOr ImportInstructionWithLayout( + const HloInstruction* instruction, + const llvm::SmallVectorImpl& operands, + mlir::OpBuilder* func_builder, + DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic); + + absl::StatusOr ImportInstructionImpl( + const HloInstruction* instruction, + const llvm::SmallVectorImpl& operands, + mlir::OpBuilder* func_builder, + DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic); + + // Gets the MLIR operand values from an HLO Instruction. + absl::StatusOr> GetOperands( + const HloInstruction* instruction); + + // Converts xla Tensor type to the corresponding MLIR type. + absl::StatusOr ConvertTensorType(const Shape& shape); + + // Converts an XLA shape/layout to the corresponding MLIR layout, in + // flattened_attr, while flattening the tuple layout. + absl::Status ConvertShapeToMlirLayout( + const Shape& shape, + llvm::SmallVectorImpl& flattened_attr); + + // Returns the output type of an HloInstruction. + absl::StatusOr GetReturnType(const HloInstruction* instruction); + + // Takes a list of HloInstructions and generates the list of types used for + // input, bypassing tuples to subsets. + absl::Status GetMlirTypes( + absl::Span instructions, + llvm::SmallVectorImpl* types); + + // Returns the Mlir Value for the corresponding HloInstruction. + absl::StatusOr GetMlirValue(const HloInstruction* instruction); + + // TODO(b/179166199): Move attribute converters to attribute_importer. + // Converts an XLA ComparisonDirection to the corresponding MLIR attribute. + mlir::NamedAttribute ConvertComparisonDirection( + ComparisonDirection direction); + + // Converts an XLA Comparison::Type to the corresponding MLIR attribute. + mlir::NamedAttribute ConvertComparisonType(Comparison::Type type); + + // Converts an XLA CustomCallSchedule to the corresponding MLIR attribute. + mlir::NamedAttribute ConvertCustomCallSchedule(CustomCallSchedule schedule); + + // Converts the dimensions of an HLO instruction into an MLIR attribute. + mlir::DenseIntElementsAttr ConvertDimensions( + absl::Span op_dimensions); + + // Converts Array ref to an DenseIntElementsAttr. + mlir::DenseIntElementsAttr Convert(llvm::ArrayRef elements); + + // Converts Array ref of bools to a DenseIntElementsAttr of I1 type. + mlir::DenseIntElementsAttr Convert(llvm::ArrayRef elements); + + // Converts Array ref to padding attribute. Input is a flattened list of + // padding low and padding high for each of the spatial dimensions. + mlir::NamedAttribute ConvertPadding(llvm::ArrayRef padding); + + mlir::MLIRContext* context_; + + // SymbolTable to which new functions should be inserted. + mlir::SymbolTable& symbol_table_; + + mlir::Builder* builder_; + + // Mapping from HloComputation to the created MLIR function. + std::unordered_map* function_map_; + + // Mapping from HloInstructions to the associative MLIR values. + std::unordered_map instruction_value_map_; + + bool flatten_computation_args_result_; +}; + +// Returns a StringAttr that carries a prettyprinted representation of the +// given HLO C++ input_output_alias_config. +// Always succeeds and returns a non-empty attribute. +mlir::Attribute ConvertInputOutputAlias(const HloInputOutputAliasConfig& alias, + mlir::Builder* builder); + +// Returns a StringAttr that carries a prettyprinted representation of the +// given HLO C++ sharding. +// Always succeeds and returns a non-empty attribute. +mlir::Attribute ConvertSharding(const HloSharding& sharding, + mlir::Builder* builder); + +// Returns a StringAttr that carries a prettyprinted representation of the +// given HLO proto sharding. +// Will fail and return an empty attribute if the proto sharding cannot be +// converted to the C++ sharding. +mlir::Attribute ConvertSharding(const OpSharding& sharding, + mlir::Builder* builder); + +} // namespace xla + +#endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_FUNCTION_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.cc similarity index 90% rename from third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.cc index 8ad9d3844438e7..95d40af6ae70f8 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/hlo_to_mhlo/hlo_module_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.h" #include @@ -21,15 +21,15 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Quant/IR/Quant.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/translate/hlo_to_mhlo/hlo_function_importer.h" -#include "xla/translate/hlo_to_mhlo/module_attributes_importer.h" #include "xla/xla.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -46,7 +46,7 @@ HloModuleImporter::HloModuleImporter(mlir::ModuleOp module, module.getContext()->loadDialect(); module.getContext()->loadDialect(); module.getContext()->loadDialect(); - module.getContext()->loadDialect(); + module.getContext()->loadDialect(); } absl::Status HloModuleImporter::Import(const HloModule& hlo_module) { @@ -83,7 +83,8 @@ absl::Status HloModuleImporter::Import(const HloModule& hlo_module) { flatten_computation_args_result_) .status()); - ImportEntryComputationLayoutAndTiles(hlo_module, module, builder_); + ImportEntryComputationLayoutAndTiles( + hlo_module, module, flatten_computation_args_result_, builder_); return absl::OkStatus(); } diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.h new file mode 100644 index 00000000000000..8937f673035a23 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.h @@ -0,0 +1,64 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_MODULE_IMPORTER_H_ +#define XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_MODULE_IMPORTER_H_ + +#include + +#include "absl/status/status.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/SymbolTable.h" +#include "xla/xla_data.pb.h" + +namespace xla { +class HloModule; +class HloModuleProto; +class HloComputation; +class HloInstruction; +class Shape; + +// Importer that takes an HloModule and imports it as an MLIR module in the XLA +// dialect. HloModuleImporter does not take ownership. +class HloModuleImporter { + public: + explicit HloModuleImporter(mlir::ModuleOp module, + bool import_all_computation = false, + bool flatten_computation_args_result = false); + + // Import the HloModule into the MLIR Module. + absl::Status Import(const xla::HloModule& module); + + // Import the HloModuleProto into the MLIR Module. + absl::Status Import(const xla::HloModuleProto& module); + + private: + bool import_all_computation_; + bool flatten_computation_args_result_; + mlir::SymbolTable symbol_table_; + mlir::Builder builder_; + + // Map for tracking which MLIR function map to which HLO Computation. This + // tracks functions as they are imported and provides a quick lookup for + // functions invoked by control flow related operations (e.g. while, call). + std::unordered_map + function_map_; +}; + +} // namespace xla + +#endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_MODULE_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc similarity index 76% rename from third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc index d6dafe01300c82..c6024887796abd 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc @@ -13,17 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" +#include + +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.h" #include "xla/mlir/utils/error_util.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/llvm_ir/llvm_util.h" -#include "xla/status_macros.h" -#include "xla/translate/hlo_to_mhlo/hlo_module_importer.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -69,4 +75,20 @@ absl::StatusOr> ConvertHloToMlirHlo( return module; } +absl::StatusOr> ConvertHloToStablehlo( + mlir::MLIRContext& ctx, const xla::HloModule* hlo_module) { + TF_ASSIGN_OR_RETURN( + mlir::OwningOpRef module, + ConvertHloToMlirHlo(ctx, hlo_module, /*import_all_computations=*/true, + /*flatten_computation_args_result=*/true)); + + mlir::BaseScopedDiagnosticHandler diag_handler(&ctx); + mlir::PassManager pm(&ctx); + pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + if (failed(pm.run(*module))) { + return diag_handler.ConsumeStatus(); + } + return std::move(module); +} + } // namespace xla diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h new file mode 100644 index 00000000000000..a540df92b1935e --- /dev/null +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h @@ -0,0 +1,75 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_TO_MLIR_HLO_H_ +#define XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_TO_MLIR_HLO_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" + +namespace mlir { +class ModuleOp; +} // namespace mlir + +namespace xla { +class HloModule; +class HloModuleProto; + +// Converts an HLO module proto to a MLIR module in HLO dialect. +// +// If `import_all_computation` is set to true, imports all computations +// irrespective if transitively called from entry computation. +// +// If `flatten_computation_args_result` is set to true, flattens all tuple +// arguments and result of every computation when importing them as func ops. +absl::StatusOr> ConvertHloToMlirHlo( + mlir::MLIRContext& ctx, xla::HloModuleProto const* hlo_module, + bool import_all_computations = false, + bool flatten_computation_args_result = false); + +absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module, + xla::HloModuleProto const* hlo_module, + bool import_all_computations = false, + bool flatten_computation_args_result = false); + +// Converts an HLO module to a MLIR module in HLO dialect. +// +// If `import_all_computation` is set to true, imports all computations +// irrespective if transitively called from entry computation. +// +// If `flatten_computation_args_result` is set to true, flattens all tuple +// arguments and result of every computation when importing them as func ops. +absl::StatusOr> ConvertHloToMlirHlo( + mlir::MLIRContext& ctx, const xla::HloModule* hlo_module, + bool import_all_computations = false, + bool flatten_computation_args_result = false); + +absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module, + const xla::HloModule* hlo_module, + bool import_all_computations = false, + bool flatten_computation_args_result = false); + +// Entrypoint for HLO to StableHLO conversion. +absl::StatusOr> ConvertHloToStablehlo( + mlir::MLIRContext& ctx, const xla::HloModule* hlo_module); + +} // namespace xla + +#endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_TO_MLIR_HLO_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc similarity index 99% rename from third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc index e6004cfe5291d6..564440ac00edcb 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.cc @@ -15,7 +15,7 @@ limitations under the License. // This file defines helpers useful when creating or manipulating lhlo/hlo. -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" #include #include diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h new file mode 100644 index 00000000000000..5c116fd08c9705 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils.h @@ -0,0 +1,250 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines helpers useful when creating or manipulating lhlo/hlo. + +#ifndef XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_UTILS_H_ +#define XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_UTILS_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/SparseTensor/IR/Enums.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/literal.h" +#include "xla/mlir/utils/type_util.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" + +namespace xla { + +absl::StatusOr CreateDenseElementsAttrFromLiteral( + const LiteralBase& literal, mlir::Builder builder); + +// Creates an DenseIntElementsAttr using the elements of the vector and the +// optional shape. +mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector( + const llvm::ArrayRef vector, mlir::Builder builder, + llvm::ArrayRef shape = {}); + +// Converts the given XLA shape for tensors to the template MLIR type. +template +static absl::StatusOr ConvertTensorShapeToType(const Shape& xla_ty, + mlir::Builder builder) { + auto element_type_or = + ConvertPrimitiveTypeToMlirType(xla_ty.element_type(), builder); + if (!element_type_or.ok()) return element_type_or.status(); + + bool is_bounded_dynamic = false; + int64_t rank = xla_ty.rank(); + llvm::SmallVector shape(rank, mlir::ShapedType::kDynamic); + llvm::SmallVector bounds(rank, mlir::ShapedType::kDynamic); + for (int64_t dim = 0; dim < rank; ++dim) { + int64_t dim_size = xla_ty.dimensions(dim); + if (xla_ty.is_dynamic_dimension(dim)) { + if (!xla_ty.is_unbounded_dynamic_dimension(dim)) { + bounds[dim] = dim_size; + is_bounded_dynamic = true; + } + } else { + shape[dim] = dim_size; + } + } + using mlir::mhlo::TypeExtensionsAttr; + mlir::Attribute encoding; + if (is_bounded_dynamic) { + encoding = TypeExtensionsAttr::get(builder.getContext(), bounds); + } + + using mlir::sparse_tensor::SparseTensorEncodingAttr; + // TODO(b/238903065): We don't yet support bounded dynamism shapes and + // sparsity at the same time, as we can currently only have one `encoding` on + // a RankedTensorType, and we don't currently have a meet of + // SparseTensorEncodingAttr and TypeExtensionsAttr (which holds bounds). + // + // For example, we wouldn't be able to represent the xla type + // `f32[4,<=4]{1,0:D(D,C)}`. + if (xla_ty.has_layout()) { + auto layout = xla_ty.layout(); + if (LayoutUtil::IsSparse(layout)) { + if (is_bounded_dynamic) + return Unimplemented( + "MHLO doesn't support bounded dynamic shapes for sparse tensors"); + llvm::SmallVector lts; + for (size_t i = 0, e = layout.dim_level_types_size(); i < e; ++i) { + auto dlt = layout.dim_level_type(i); + bool ordered = + i < layout.dim_ordered_size() ? layout.dim_ordered(i) : true; + bool unique = + i < layout.dim_unique_size() ? layout.dim_unique(i) : true; + switch (dlt) { + case DimLevelType::DIM_DENSE: + lts.push_back(*mlir::sparse_tensor::buildLevelType( + mlir::sparse_tensor::LevelFormat::Dense, ordered, unique)); + break; + case DimLevelType::DIM_COMPRESSED: + lts.push_back(*mlir::sparse_tensor::buildLevelType( + mlir::sparse_tensor::LevelFormat::Compressed, ordered, unique)); + break; + case DimLevelType::DIM_SINGLETON: + lts.push_back(*mlir::sparse_tensor::buildLevelType( + mlir::sparse_tensor::LevelFormat::Singleton, ordered, unique)); + break; + case DimLevelType::DIM_LOOSE_COMPRESSED: + lts.push_back(*mlir::sparse_tensor::buildLevelType( + mlir::sparse_tensor::LevelFormat::LooseCompressed, ordered, + unique)); + break; + default: + return InvalidArgument("Unknown DimLevelType from HLO"); + } + } + auto ordering = layout.minor_to_major(); + llvm::SmallVector major_to_minor = {ordering.rbegin(), + ordering.rend()}; + auto id_map = mlir::AffineMap::getPermutationMap(major_to_minor, + builder.getContext()); + // TODO(atondwal): support sizes other than 32 when XLA does + encoding = SparseTensorEncodingAttr::get( + builder.getContext(), lts, id_map, mlir::AffineMap(), 32, 32); + } + } + return TypeT::get(shape, element_type_or.value(), encoding); +} + +absl::StatusOr ConvertTensorShapeToMemRefType( + const Shape& shape, mlir::Builder builder); + +template <> +inline absl::StatusOr ConvertTensorShapeToType( + const Shape& shape, mlir::Builder builder) { + if (shape.is_dynamic()) { + return FailedPrecondition( // NOLINT + "MemRefType don't support dynamic shapes"); + } + return ConvertTensorShapeToMemRefType(shape, builder); +} + +// Converts the given XLA shape to the template MLIR type. +template +static absl::StatusOr ConvertShapeToType(const Shape& shape, + mlir::Builder builder) { + if (shape.IsTuple()) { + llvm::SmallVector contents; + contents.reserve(shape.tuple_shapes_size()); + for (const auto& subtype : shape.tuple_shapes()) { + TF_ASSIGN_OR_RETURN(auto mlir_subtype, + ConvertShapeToType(subtype, builder)); + contents.push_back(mlir_subtype); + } + return builder.getTupleType(contents); + } + if (shape.IsToken()) { + return mlir::mhlo::TokenType::get(builder.getContext()); + } + return ConvertTensorShapeToType(shape, builder); +} + +// CreateTupleValue creates a root TupleOp of (nested) tuple-type 'type' using +// the non-tuple-typed values in 'flatten_values'. +// +// e.g., Given 'flatten_values': [V1, V2, V3] &'type': tuple>, +// The function returns %t2 such that: +// %t1 = mhlo.tuple(V2,V3) : (T2,T3) -> tuple +// %t2 = mhlo.tuple(V1,%t1): (T1,tuple) -> tuple> +// +// Note: 1. FlattenTupleValue and CreateTupleValue is a pair of functions to +// resp. flatten and create tuples in the exact same order. +// 2. `flatten_values`, initially storing the flattened values, will be +// mutated to a 0-length array by the end of function invocation. +mlir::Value CreateTupleValue(mlir::OpBuilder* func_builder, mlir::Location loc, + mlir::ValueRange& flatten_values, mlir::Type type); + +// Create a TupleOp using the results of 'op' if 'type' is a mlir::TupleType. +// Otherwise, return 'op'. +mlir::Operation* CreateTupleFromOpResults(mlir::OpBuilder* func_builder, + mlir::Location loc, + mlir::Operation* op, mlir::Type type); + +mlir::TypeRange Untuple(const mlir::Type& type); + +static std::pair GetLayoutAttribute( + mlir::Builder& b, const Shape& shape, + std::optional maybe_layout = std::nullopt) { + if (shape.IsTuple()) { + llvm::SmallVector element_attrs; + llvm::SmallVector tile_attrs; + for (const auto& tuple_shape : shape.tuple_shapes()) { + // TODO here we do not dissect the layout of a tuple into sublayouts. + // Presently ShapeLayout cannot represent an explicit layout for a tuple + // type so this should never occur. However, if this function were to + // be used in another context where this assumption were to be lifted. + // users should be aware of this limitation which will use the default + // layout for tuple subshapes. + std::pair inner = + tuple_shape.has_layout() + ? GetLayoutAttribute(b, tuple_shape, tuple_shape.layout()) + : GetLayoutAttribute(b, tuple_shape); + element_attrs.push_back(inner.first); + tile_attrs.push_back(inner.second); + } + return std::make_pair((mlir::Attribute)b.getArrayAttr(element_attrs), + b.getArrayAttr(tile_attrs)); + } + + Layout layout = maybe_layout.value_or( + shape.has_layout() ? shape.layout() + : LayoutUtil::GetDefaultLayoutForShape(shape)); + + llvm::SmallVector vec_of_tiles; + for (const Tile& tile : layout.tiles()) { + llvm::SmallVector tile_vec = {tile.dimensions().begin(), + tile.dimensions().end()}; + vec_of_tiles.push_back(b.getIndexTensorAttr(tile_vec)); + } + llvm::SmallVector layout_vec = {layout.minor_to_major().begin(), + layout.minor_to_major().end()}; + return std::make_pair(b.getIndexTensorAttr(layout_vec), + b.getArrayAttr(vec_of_tiles)); +} + +static bool HasCustomLayout(const Shape& shape) { + if (shape.IsTuple()) { + return llvm::any_of(shape.tuple_shapes(), HasCustomLayout); + } + return shape.has_layout() && !shape.layout().minor_to_major().empty() && + shape.layout() != LayoutUtil::GetDefaultLayoutForShape(shape); +} + +} // namespace xla + +#endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_HLO_UTILS_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils_test.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils_test.cc similarity index 91% rename from third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils_test.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils_test.cc index b16e5870e99d79..28edbc3aab4c7d 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils_test.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/hlo_utils_test.cc @@ -13,22 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" #include #include #include #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Support/DebugStringHelper.h" -#include "xla/literal.h" -#include "xla/literal_util.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/shape_util.h" #include "xla/test.h" -#include "xla/tsl/lib/core/status_test_util.h" -#include "xla/types.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/location_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/location_importer.cc similarity index 89% rename from third_party/xla/xla/translate/hlo_to_mhlo/location_importer.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/location_importer.cc index b39d971141240a..3d2f9744d0fee7 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/location_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/location_importer.cc @@ -13,13 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/hlo_to_mhlo/location_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/location_importer.h" #include #include "mlir/IR/Builders.h" #include "mlir/IR/Location.h" -#include "xla/translate/hlo_to_mhlo/stack_location_utils.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Support/LLVM.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/translate/hlo_to_mhlo/stack_location_utils.h" namespace mlir { namespace mhlo { diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/location_importer.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/location_importer.h similarity index 82% rename from third_party/xla/xla/translate/hlo_to_mhlo/location_importer.h rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/location_importer.h index 23307e7fe135b7..0137fa446b024a 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/location_importer.h +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/location_importer.h @@ -13,10 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_TRANSLATE_HLO_TO_MHLO_LOCATION_IMPORTER_H_ -#define XLA_TRANSLATE_HLO_TO_MHLO_LOCATION_IMPORTER_H_ +#ifndef XLA_HLO_TRANSLATE_HLO_TO_MHLO_LOCATION_IMPORTER_H_ +#define XLA_HLO_TRANSLATE_HLO_TO_MHLO_LOCATION_IMPORTER_H_ #include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" #include "xla/hlo/ir/hlo_instruction.h" namespace mlir { @@ -30,4 +31,4 @@ mlir::Location GenerateInstructionLocation( } // namespace mhlo } // namespace mlir -#endif // XLA_TRANSLATE_HLO_TO_MHLO_LOCATION_IMPORTER_H_ +#endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_LOCATION_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/module_attributes_importer.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.cc similarity index 72% rename from third_party/xla/xla/translate/hlo_to_mhlo/module_attributes_importer.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.cc index a60499f1b9bdf0..75beb87d33224b 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/module_attributes_importer.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/hlo_to_mhlo/module_attributes_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.h" #include #include @@ -30,6 +30,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/layout.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/computation_layout.h" @@ -37,8 +39,6 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_layout.h" #include "xla/shape_util.h" -#include "xla/translate/hlo_to_mhlo/hlo_function_importer.h" -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -64,24 +64,23 @@ constexpr char kSpmdParametersShardings[] = "mhlo.spmd_parameters_shardings"; constexpr char kUseAutoSpmdPartitioning[] = "mhlo.use_auto_spmd_partitioning"; mlir::ArrayAttr ConvertCrossProgramPrefetches( - const absl::Span prefetches, - const xla::HloComputation& entryComputation, mlir::Builder* builder, + const absl::Span prefetches, + const HloComputation& entryComputation, mlir::Builder* builder, bool flatten_computation_args_result) { llvm::SmallVector shapes; shapes.reserve(prefetches.size()); if (flatten_computation_args_result) { - llvm::SmallVector> + llvm::SmallVector> original_param_index_to_flattened_arg_index; int64_t arg_index = 0; - for (xla::HloInstruction* param_instruction : + for (HloInstruction* param_instruction : entryComputation.parameter_instructions()) { auto& param_map = original_param_index_to_flattened_arg_index.emplace_back(); - xla::ShapeUtil::ForEachLeafShape( - param_instruction->shape(), - [&](const xla::Shape&, const xla::ShapeIndex& index) { - param_map[index] = arg_index++; - }); + ShapeUtil::ForEachLeafShape(param_instruction->shape(), + [&](const Shape&, const ShapeIndex& index) { + param_map[index] = arg_index++; + }); } for (const auto& [parameter, index, alt_memory_offset] : prefetches) shapes.push_back(mlir::mhlo::CrossProgramPrefetchAttr::get( @@ -100,15 +99,35 @@ mlir::ArrayAttr ConvertCrossProgramPrefetches( } void ImportEntryComputationParameterLayoutAndTiles( - const xla::HloModule& hlo_module, mlir::ModuleOp module, - const ComputationLayout& computation_layout, mlir::Builder builder) { + const HloModule& hlo_module, mlir::ModuleOp module, + const ComputationLayout& computation_layout, + bool flatten_computation_args_result, mlir::Builder builder) { llvm::SmallVector parameter_layouts; llvm::SmallVector parameter_tiles; - for (auto& layout : computation_layout.parameter_layouts()) { - if (layout.shape().IsTuple()) { + if (flatten_computation_args_result) { + for (auto& parameter_layout : computation_layout.parameter_layouts()) { + xla::ShapeUtil::ForEachLeafShape( + parameter_layout.shape(), + [&](const xla::Shape& subshape, const xla::ShapeIndex& index) { + std::pair layout_attrs = + GetLayoutAttribute(builder, subshape); + parameter_layouts.push_back(layout_attrs.first); + parameter_tiles.push_back(layout_attrs.second); + }); + } + module->setAttr(kEntryComputationParameterLayouts, + builder.getArrayAttr({parameter_layouts})); + module->setAttr(kEntryComputationParameterTiles, + builder.getArrayAttr({parameter_tiles})); + return; + } + + for (auto& parameter_layout : computation_layout.parameter_layouts()) { + if (parameter_layout.shape().IsTuple()) { llvm::SmallVector tuple_element_parameter_layouts; llvm::SmallVector tuple_element_parameter_tiles; - for (auto& tuple_element_shape : layout.shape().tuple_shapes()) { + for (auto& tuple_element_shape : + parameter_layout.shape().tuple_shapes()) { std::pair layout_attrs = GetLayoutAttribute(builder, tuple_element_shape); tuple_element_parameter_layouts.push_back(layout_attrs.first); @@ -120,7 +139,7 @@ void ImportEntryComputationParameterLayoutAndTiles( builder.getArrayAttr({tuple_element_parameter_tiles})); } else { std::pair layout_attrs = - GetLayoutAttribute(builder, layout.shape()); + GetLayoutAttribute(builder, parameter_layout.shape()); parameter_layouts.push_back(layout_attrs.first); parameter_tiles.push_back(layout_attrs.second); } @@ -132,11 +151,24 @@ void ImportEntryComputationParameterLayoutAndTiles( } void ImportEntryComputationResultLayoutAndTiles( - const xla::HloModule& hlo_module, mlir::ModuleOp module, - const ComputationLayout& computation_layout, mlir::Builder builder) { + const HloModule& hlo_module, mlir::ModuleOp module, + const ComputationLayout& computation_layout, + bool flatten_computation_args_result, mlir::Builder builder) { + llvm::SmallVector result_layouts; + llvm::SmallVector result_tiles; + if (flatten_computation_args_result) { + xla::ShapeUtil::ForEachLeafShape( + computation_layout.result_layout().shape(), + [&](const xla::Shape& subshape, const xla::ShapeIndex& index) { + std::pair layout_attrs = + GetLayoutAttribute(builder, subshape); + result_layouts.push_back(layout_attrs.first); + result_tiles.push_back(layout_attrs.second); + }); + return; + } + if (computation_layout.result_layout().shape().IsTuple()) { - llvm::SmallVector result_layouts; - llvm::SmallVector result_tiles; for (auto& tuple_element_layout : computation_layout.result_layout().shape().tuple_shapes()) { std::pair layout_attrs = @@ -149,20 +181,21 @@ void ImportEntryComputationResultLayoutAndTiles( builder.getArrayAttr({builder.getArrayAttr(result_layouts)})); module->setAttr(kEntryComputationResultTiles, builder.getArrayAttr({builder.getArrayAttr(result_tiles)})); - } else { - std::pair layout_attrs = - GetLayoutAttribute(builder, computation_layout.result_layout().shape(), - computation_layout.result_layout().layout()); - module->setAttr(kEntryComputationResultLayout, - builder.getArrayAttr({layout_attrs.first})); - module->setAttr(kEntryComputationResultTiles, - builder.getArrayAttr({layout_attrs.second})); + return; } + + std::pair layout_attrs = + GetLayoutAttribute(builder, computation_layout.result_layout().shape(), + computation_layout.result_layout().layout()); + module->setAttr(kEntryComputationResultLayout, + builder.getArrayAttr({layout_attrs.first})); + module->setAttr(kEntryComputationResultTiles, + builder.getArrayAttr({layout_attrs.second})); } } // namespace -void ImportCrossProgramPrefetches(const xla::HloModule& hlo_module, +void ImportCrossProgramPrefetches(const HloModule& hlo_module, mlir::ModuleOp module, bool flatten_computation_args_result, mlir::Builder builder) { @@ -173,8 +206,9 @@ void ImportCrossProgramPrefetches(const xla::HloModule& hlo_module, flatten_computation_args_result)); } -void ImportEntryComputationLayoutAndTiles(const xla::HloModule& hlo_module, +void ImportEntryComputationLayoutAndTiles(const HloModule& hlo_module, mlir::ModuleOp module, + bool flatten_computation_args_result, mlir::Builder builder) { const auto& computation_layout = hlo_module.entry_computation_layout(); if (!computation_layout.LayoutIsSet()) return; @@ -186,16 +220,18 @@ void ImportEntryComputationLayoutAndTiles(const xla::HloModule& hlo_module, [](const ShapeLayout& shape) { return HasCustomLayout(shape.shape()); })) { - ImportEntryComputationParameterLayoutAndTiles(hlo_module, module, - computation_layout, builder); + ImportEntryComputationParameterLayoutAndTiles( + hlo_module, module, computation_layout, flatten_computation_args_result, + builder); } if (HasCustomLayout(computation_layout.result_layout().shape())) { - ImportEntryComputationResultLayoutAndTiles(hlo_module, module, - computation_layout, builder); + ImportEntryComputationResultLayoutAndTiles( + hlo_module, module, computation_layout, flatten_computation_args_result, + builder); } } -void ImportFrontendAttributes(const xla::HloModule& hlo_module, +void ImportFrontendAttributes(const HloModule& hlo_module, mlir::ModuleOp module, mlir::Builder builder) { if (!hlo_module.frontend_attributes().map().empty()) { llvm::SmallVector frontend_attributes; @@ -208,6 +244,19 @@ void ImportFrontendAttributes(const xla::HloModule& hlo_module, } } +void ImportInputOutputAlias(const xla::HloModule& hlo_module, + mlir::ModuleOp module, mlir::Builder builder) { + module->setAttr(kInputOutputAlias, + ConvertInputOutputAlias( + hlo_module.input_output_alias_config(), &builder)); +} + +void ImportIsDynamic(const xla::HloModule& hlo_module, mlir::ModuleOp module, + mlir::Builder builder) { + module->setAttr(kIsDynamic, mlir::BoolAttr::get(builder.getContext(), + hlo_module.is_dynamic())); +} + void ImportNumPartitions(const xla::HloModule& hlo_module, mlir::ModuleOp module, mlir::Builder builder) { const auto& config = hlo_module.config(); @@ -217,7 +266,7 @@ void ImportNumPartitions(const xla::HloModule& hlo_module, } } -void ImportNumReplicas(const xla::HloModule& hlo_module, mlir::ModuleOp module, +void ImportNumReplicas(const HloModule& hlo_module, mlir::ModuleOp module, mlir::Builder builder) { const auto& config = hlo_module.config(); if (config.replica_count() != 1) { @@ -226,19 +275,6 @@ void ImportNumReplicas(const xla::HloModule& hlo_module, mlir::ModuleOp module, } } -void ImportInputOutputAlias(const xla::HloModule& hlo_module, - mlir::ModuleOp module, mlir::Builder builder) { - module->setAttr(kInputOutputAlias, - ConvertInputOutputAlias( - hlo_module.input_output_alias_config(), &builder)); -} - -void ImportIsDynamic(const xla::HloModule& hlo_module, mlir::ModuleOp module, - mlir::Builder builder) { - module->setAttr(kIsDynamic, mlir::BoolAttr::get(builder.getContext(), - hlo_module.is_dynamic())); -} - void ImportSpmdOutputSharding(const xla::HloModule& hlo_module, mlir::ModuleOp module, mlir::Builder builder) { if (hlo_module.has_spmd_output_sharding()) @@ -247,7 +283,7 @@ void ImportSpmdOutputSharding(const xla::HloModule& hlo_module, ConvertSharding(hlo_module.spmd_output_sharding(), &builder)); } -void ImportSpmdParametersShardings(const xla::HloModule& hlo_module, +void ImportSpmdParametersShardings(const HloModule& hlo_module, mlir::ModuleOp module, bool flatten_computation_args_result, mlir::Builder builder) { @@ -266,7 +302,7 @@ void ImportSpmdParametersShardings(const xla::HloModule& hlo_module, } } -void ImportUseAutoSpmdPartitioning(const xla::HloModule& hlo_module, +void ImportUseAutoSpmdPartitioning(const HloModule& hlo_module, mlir::ModuleOp module, mlir::Builder builder) { module->setAttr(kUseAutoSpmdPartitioning, diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/module_attributes_importer.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.h similarity index 90% rename from third_party/xla/xla/translate/hlo_to_mhlo/module_attributes_importer.h rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.h index bd4580e5d315a5..7be3cd20274180 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/module_attributes_importer.h +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/module_attributes_importer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_TRANSLATE_HLO_TO_MHLO_MODULE_ATTRIBUTES_IMPORTER_H_ -#define XLA_TRANSLATE_HLO_TO_MHLO_MODULE_ATTRIBUTES_IMPORTER_H_ +#ifndef XLA_HLO_TRANSLATE_HLO_TO_MHLO_MODULE_ATTRIBUTES_IMPORTER_H_ +#define XLA_HLO_TRANSLATE_HLO_TO_MHLO_MODULE_ATTRIBUTES_IMPORTER_H_ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinOps.h" @@ -32,23 +32,24 @@ void ImportCrossProgramPrefetches(const HloModule& hlo_module, void ImportEntryComputationLayoutAndTiles(const HloModule& hlo_module, mlir::ModuleOp module, + bool flatten_computation_args_result, mlir::Builder builder); void ImportFrontendAttributes(const HloModule& hlo_module, mlir::ModuleOp module, mlir::Builder builder); -void ImportNumPartitions(const HloModule& hlo_module, mlir::ModuleOp module, - mlir::Builder builder); - -void ImportNumReplicas(const HloModule& hlo_module, mlir::ModuleOp module, - mlir::Builder builder); - void ImportInputOutputAlias(const HloModule& hlo_module, mlir::ModuleOp module, mlir::Builder builder); void ImportIsDynamic(const HloModule& hlo_module, mlir::ModuleOp module, mlir::Builder builder); +void ImportNumPartitions(const HloModule& hlo_module, mlir::ModuleOp module, + mlir::Builder builder); + +void ImportNumReplicas(const HloModule& hlo_module, mlir::ModuleOp module, + mlir::Builder builder); + void ImportSpmdOutputSharding(const HloModule& hlo_module, mlir::ModuleOp module, mlir::Builder builder); @@ -63,4 +64,4 @@ void ImportUseAutoSpmdPartitioning(const HloModule& hlo_module, } // namespace xla -#endif // XLA_TRANSLATE_HLO_TO_MHLO_MODULE_ATTRIBUTES_IMPORTER_H_ +#endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_MODULE_ATTRIBUTES_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/stack_location_utils.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/stack_location_utils.cc similarity index 91% rename from third_party/xla/xla/translate/hlo_to_mhlo/stack_location_utils.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/stack_location_utils.cc index d09b8ff1d56f18..7cae446ee8ab84 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/stack_location_utils.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/stack_location_utils.cc @@ -13,10 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/hlo_to_mhlo/stack_location_utils.h" +#include "xla/hlo/translate/hlo_to_mhlo/stack_location_utils.h" #include +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/Support/LLVM.h" #include "xla/hlo/ir/hlo_module.h" namespace mlir { diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/stack_location_utils.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/stack_location_utils.h similarity index 85% rename from third_party/xla/xla/translate/hlo_to_mhlo/stack_location_utils.h rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/stack_location_utils.h index 09df9d91d148fb..f5210d558d9152 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/stack_location_utils.h +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/stack_location_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_TRANSLATE_HLO_TO_MHLO_STACK_LOCATION_UTILS_H_ -#define XLA_TRANSLATE_HLO_TO_MHLO_STACK_LOCATION_UTILS_H_ +#ifndef XLA_HLO_TRANSLATE_HLO_TO_MHLO_STACK_LOCATION_UTILS_H_ +#define XLA_HLO_TRANSLATE_HLO_TO_MHLO_STACK_LOCATION_UTILS_H_ #include "mlir/IR/Builders.h" #include "mlir/IR/Location.h" @@ -31,4 +31,4 @@ mlir::Location GetLocationFromFrameIndex(int frame_id, mlir::Builder &builder, } // namespace mhlo } // namespace mlir -#endif // XLA_TRANSLATE_HLO_TO_MHLO_STACK_LOCATION_UTILS_H_ +#endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_STACK_LOCATION_UTILS_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/BUILD b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/BUILD similarity index 92% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/BUILD rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/BUILD index 8e0d7d63df3551..6b190f8a1cd9e4 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/tests/BUILD +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/BUILD @@ -11,6 +11,7 @@ lit_test_suite( [ "attributes.hlo", "bool_compare.hlo", + "call.hlo", "case_conditional.hlo", "composite_call.hlo", "custom_call.hlo", @@ -29,6 +30,7 @@ lit_test_suite( "module_config.hlo", "simple.hlo", "spmd_module_sharding.hlo", + "stablehlo.hlo", "stacktrace_to_location.hlo", "types.hlo", "while.hlo", @@ -39,7 +41,7 @@ lit_test_suite( ), cfg = "//xla:lit.cfg.py", tools = [ - "//xla/translate:xla-translate", + "//xla/hlo/translate:xla-translate", "@llvm-project//llvm:FileCheck", "@llvm-project//llvm:not", ], diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/attributes.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/attributes.hlo similarity index 90% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/attributes.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/attributes.hlo index 0e927b07c5ab29..a76befcd04bc13 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/tests/attributes.hlo +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/attributes.hlo @@ -7,7 +7,7 @@ HloModule dot_algorithm_f8_f8_f32, entry_computation_layout={(f32[2,2,2]{2,1,0}, ENTRY %main.4 (Arg_0.1: f32[2,2,2], Arg_1.2: f32[2,2,2]) -> f32[2,2,2] { %Arg_0.1 = f32[2,2,2] parameter(0) %Arg_1.2 = f32[2,2,2] parameter(1) - ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_any_f8_any_f8_f32, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/attributes.mlir:1 offset " source_line=7} + ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_any_f8_any_f8_f32, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir:1 offset " source_line=7} } // ----- @@ -20,7 +20,7 @@ HloModule dot_algorithm_f8_f8_f32_fast_accum, entry_computation_layout={(f32[2,2 ENTRY %main.4 (Arg_0.1: f32[2,2,2], Arg_1.2: f32[2,2,2]) -> f32[2,2,2] { %Arg_0.1 = f32[2,2,2] parameter(0) %Arg_1.2 = f32[2,2,2] parameter(1) - ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_any_f8_any_f8_f32_fast_accum, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/attributes.mlir:23 offset " source_line=7} + ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_any_f8_any_f8_f32_fast_accum, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir:23 offset " source_line=7} } // ----- @@ -32,7 +32,7 @@ HloModule dot_algorithm_f16_f16_f16, entry_computation_layout={(f32[2,2,2]{2,1,0 ENTRY %main.4 (Arg_0.1: f32[2,2,2], Arg_1.2: f32[2,2,2]) -> f32[2,2,2] { %Arg_0.1 = f32[2,2,2] parameter(0) %Arg_1.2 = f32[2,2,2] parameter(1) - ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_f16_f16_f16, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/attributes.mlir:45 offset " source_line=7} + ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_f16_f16_f16, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir:45 offset " source_line=7} } // ----- @@ -44,7 +44,7 @@ HloModule dot_algorithm_f16_f16_f32, entry_computation_layout={(f32[2,2,2]{2,1,0 ENTRY %main.4 (Arg_0.1: f32[2,2,2], Arg_1.2: f32[2,2,2]) -> f32[2,2,2] { %Arg_0.1 = f32[2,2,2] parameter(0) %Arg_1.2 = f32[2,2,2] parameter(1) - ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_f16_f16_f32, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/attributes.mlir:67 offset " source_line=7} + ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_f16_f16_f32, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir:67 offset " source_line=7} } // ----- @@ -56,7 +56,7 @@ HloModule dot_algorithm_bf16_bf16_bf16, entry_computation_layout={(bf16[2,2,2]{2 ENTRY %main.4 (Arg_0.1: bf16[2,2,2], Arg_1.2: bf16[2,2,2]) -> bf16[2,2,2] { %Arg_0.1 = bf16[2,2,2] parameter(0) %Arg_1.2 = bf16[2,2,2] parameter(1) - ROOT %dot.3 = bf16[2,2,2] dot(bf16[2,2,2] %Arg_0.1, bf16[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_bf16_bf16_bf16, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/attributes.mlir:89 offset " source_line=7} + ROOT %dot.3 = bf16[2,2,2] dot(bf16[2,2,2] %Arg_0.1, bf16[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_bf16_bf16_bf16, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir:89 offset " source_line=7} } // ----- @@ -68,7 +68,7 @@ HloModule dot_algorithm_bf16_bf16_f32, entry_computation_layout={(bf16[2,2,2]{2, ENTRY %main.4 (Arg_0.1: bf16[2,2,2], Arg_1.2: bf16[2,2,2]) -> bf16[2,2,2] { %Arg_0.1 = bf16[2,2,2] parameter(0) %Arg_1.2 = bf16[2,2,2] parameter(1) - ROOT %dot.3 = bf16[2,2,2] dot(bf16[2,2,2] %Arg_0.1, bf16[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_bf16_bf16_f32, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/attributes.mlir:111 offset " source_line=7} + ROOT %dot.3 = bf16[2,2,2] dot(bf16[2,2,2] %Arg_0.1, bf16[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_bf16_bf16_f32, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir:111 offset " source_line=7} } // ----- @@ -80,7 +80,7 @@ HloModule dot_algorithm_bf16_bf16_f32_x3, entry_computation_layout={(bf16[2,2,2] ENTRY %main.4 (Arg_0.1: bf16[2,2,2], Arg_1.2: bf16[2,2,2]) -> bf16[2,2,2] { %Arg_0.1 = bf16[2,2,2] parameter(0) %Arg_1.2 = bf16[2,2,2] parameter(1) - ROOT %dot.3 = bf16[2,2,2] dot(bf16[2,2,2] %Arg_0.1, bf16[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_bf16_bf16_f32_x3, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/attributes.mlir:133 offset " source_line=7} + ROOT %dot.3 = bf16[2,2,2] dot(bf16[2,2,2] %Arg_0.1, bf16[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_bf16_bf16_f32_x3, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir:133 offset " source_line=7} } // ----- @@ -92,7 +92,7 @@ HloModule dot_algorithm_bf16_bf16_f32_x6, entry_computation_layout={(bf16[2,2,2] ENTRY %main.4 (Arg_0.1: bf16[2,2,2], Arg_1.2: bf16[2,2,2]) -> bf16[2,2,2] { %Arg_0.1 = bf16[2,2,2] parameter(0) %Arg_1.2 = bf16[2,2,2] parameter(1) - ROOT %dot.3 = bf16[2,2,2] dot(bf16[2,2,2] %Arg_0.1, bf16[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_bf16_bf16_f32_x6, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/attributes.mlir:155 offset " source_line=7} + ROOT %dot.3 = bf16[2,2,2] dot(bf16[2,2,2] %Arg_0.1, bf16[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_bf16_bf16_f32_x6, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir:155 offset " source_line=7} } // ----- @@ -104,7 +104,7 @@ HloModule dot_algorithm_tf32_tf32_f32, entry_computation_layout={(f32[2,2,2]{2,1 ENTRY %main.4 (Arg_0.1: f32[2,2,2], Arg_1.2: f32[2,2,2]) -> f32[2,2,2] { %Arg_0.1 = f32[2,2,2] parameter(0) %Arg_1.2 = f32[2,2,2] parameter(1) - ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_tf32_tf32_f32, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/attributes.mlir:177 offset " source_line=7} + ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_tf32_tf32_f32, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir:177 offset " source_line=7} } // ----- @@ -116,7 +116,7 @@ HloModule dot_algorithm_tf32_tf32_f32_x3, entry_computation_layout={(f32[2,2,2]{ ENTRY %main.4 (Arg_0.1: f32[2,2,2], Arg_1.2: f32[2,2,2]) -> f32[2,2,2] { %Arg_0.1 = f32[2,2,2] parameter(0) %Arg_1.2 = f32[2,2,2] parameter(1) - ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_tf32_tf32_f32_x3, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/attributes.mlir:199 offset " source_line=7} + ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_tf32_tf32_f32_x3, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir:199 offset " source_line=7} } // ----- @@ -128,7 +128,7 @@ HloModule dot_algorithm_f32_f32_f32, entry_computation_layout={(f32[2,2,2]{2,1,0 ENTRY %main.4 (Arg_0.1: f32[2,2,2], Arg_1.2: f32[2,2,2]) -> f32[2,2,2] { %Arg_0.1 = f32[2,2,2] parameter(0) %Arg_1.2 = f32[2,2,2] parameter(1) - ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_f32_f32_f32, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/attributes.mlir:221 offset " source_line=7} + ROOT %dot.3 = f32[2,2,2] dot(f32[2,2,2] %Arg_0.1, f32[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_f32_f32_f32, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir:221 offset " source_line=7} } // ----- @@ -140,5 +140,5 @@ HloModule dot_algorithm_f64_f64_f64, entry_computation_layout={(f64[2,2,2]{2,1,0 ENTRY %main.4 (Arg_0.1: f64[2,2,2], Arg_1.2: f64[2,2,2]) -> f64[2,2,2] { %Arg_0.1 = f64[2,2,2] parameter(0) %Arg_1.2 = f64[2,2,2] parameter(1) - ROOT %dot.3 = f64[2,2,2] dot(f64[2,2,2] %Arg_0.1, f64[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_f64_f64_f64, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/attributes.mlir:243 offset " source_line=7} + ROOT %dot.3 = f64[2,2,2] dot(f64[2,2,2] %Arg_0.1, f64[2,2,2] %Arg_1.2), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, algorithm=dot_f64_f64_f64, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir:243 offset " source_line=7} } diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/bool_compare.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/bool_compare.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/bool_compare.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/bool_compare.hlo diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/call.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/call.hlo new file mode 100644 index 00000000000000..89c92854ad3db8 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/call.hlo @@ -0,0 +1,15 @@ +// RUN: xla-translate -hlo-text-to-mlir-hlo -hlo-import-all-computations %s -o - | FileCheck %s + +HloModule CallWithBackendConfig + +%g.2 (Arg_0.3: s32[8,2]) -> s32[8,2] { + %Arg_0.3 = s32[8,2]{1,0} parameter(0) + ROOT %multiply.4 = s32[8,2]{1,0} multiply(s32[8,2]{1,0} %Arg_0.3, s32[8,2]{1,0} %Arg_0.3) +} + +ENTRY %main.9 (Arg_0.1: s32[8,2]) -> s32[8,2] { + %Arg_0.1 = s32[8,2]{1,0} parameter(0) + // CHECK: call @g.2(%arg0) {mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}} : (tensor<8x2xi32>) -> tensor<8x2xi32> + %call.5 = s32[8,2]{1,0} call(s32[8,2]{1,0} %Arg_0.1), to_apply=%g.2, backend_config={"flag_configs":[],"scoped_memory_configs":[],"device_type":"DEVICE_TYPE_HOST","used_scoped_memory_configs":[]} + ROOT %custom-call = s32[8,2]{1,0} custom-call(s32[8,2]{1,0} %call.5), custom_call_target="MoveToHost" +} diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/case_conditional.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/case_conditional.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/case_conditional.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/case_conditional.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/composite_call.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/composite_call.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/composite_call.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/composite_call.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/custom_call.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/custom_call.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/custom_call.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/custom_call.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/dynamic_param.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/dynamic_param.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/dynamic_param.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/dynamic_param.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/frontend_attributes.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/frontend_attributes.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/frontend_attributes.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/frontend_attributes.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/fully_connected_reference_model.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/fully_connected_reference_model.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/fully_connected_reference_model.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/fully_connected_reference_model.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/fusion.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/fusion.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/fusion.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/fusion.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/if_conditional.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/if_conditional.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/if_conditional.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/if_conditional.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo similarity index 98% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo index 68ad73882e02a7..3a1e7ceabb160f 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlo +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import.hlo @@ -410,11 +410,17 @@ add { // CHECK: %[[VAL_9:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3B11FNUZ> %constant.9 = f8e4m3b11fnuz[4] constant({1, 2, 3, 4}) - // CHECK: %[[VAL_9:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3FNUZ> + // CHECK: %[[VAL_10:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3FNUZ> %constant.10 = f8e4m3fnuz[4] constant({1, 2, 3, 4}) - // CHECK: %[[VAL_9:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E5M2FNUZ> + // CHECK: %[[VAL_11:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E5M2FNUZ> %constant.11 = f8e5m2fnuz[4] constant({1, 2, 3, 4}) + + // CHECK: %[[VAL_12:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3> + %constant.12 = f8e4m3[4] constant({1, 2, 3, 4}) + + // CHECK: %[[VAL_13:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E3M4> + %constant.13 = f8e3m4[4] constant({1, 2, 3, 4}) } // TODO(b/129422361) Potentially update when copy, reshape, and conv have actual @@ -524,7 +530,19 @@ add { %convert.11 = f8e5m2fnuz[4] convert(f32[4] %convert.10) // CHECK-NEXT: %9 = mhlo.convert %8 : (tensor<4xf8E5M2FNUZ>) -> tensor<4xf32> - ROOT %convert.12 = f32[4] convert(f8e5m2fnuz[4] %convert.11) + %convert.12 = f32[4] convert(f8e5m2fnuz[4] %convert.11) + + // CHECK-NEXT: %10 = mhlo.convert %9 : (tensor<4xf32>) -> tensor<4xf8E4M3> + %convert.13 = f8e4m3[4] convert(f32[4] %convert.12) + + // CHECK-NEXT: %11 = mhlo.convert %10 : (tensor<4xf8E4M3>) -> tensor<4xf32> + %convert.14 = f32[4] convert(f8e4m3[4] %convert.13) + + // CHECK-NEXT: %12 = mhlo.convert %11 : (tensor<4xf32>) -> tensor<4xf8E3M4> + %convert.15 = f8e3m4[4] convert(f32[4] %convert.14) + + // CHECK-NEXT: %13 = mhlo.convert %12 : (tensor<4xf8E3M4>) -> tensor<4xf32> + ROOT %convert.16 = f32[4] convert(f8e3m4[4] %convert.15) } // CHECK-LABEL: func private @test_stochastic_convert(%arg0: tensor<4x3xf32>, %arg1: tensor<4x3xui32>) -> tensor<4x3xi8> diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_async.hlo similarity index 78% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_async.hlo index 4e9633014b332b..5aa09777f30022 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async.hlo +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_async.hlo @@ -41,8 +41,8 @@ HloModule main, entry_computation_layout={(f32[128,32]{1,0})->f32[128,128]{1,0}} ENTRY %async_all_gather_test (Arg_0.1: f32[128,32]) -> f32[128,128] { %Arg_0.1 = f32[128,32] parameter(0) - %all-gather-start.2 = f32[128,128] all-gather-start(f32[128,32] %Arg_0.1), channel_id=1, replica_groups={{0,2,4,6},{1,3,5,7}}, constrain_layout=true, dimensions={1}, use_global_device_ids=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:1 offset " source_line=16} - ROOT %all-gather-done.3 = f32[128,128] all-gather-done(f32[128,128] %all-gather-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:1 offset " source_line=17} + %all-gather-start.2 = f32[128,128] all-gather-start(f32[128,32] %Arg_0.1), channel_id=1, replica_groups={{0,2,4,6},{1,3,5,7}}, constrain_layout=true, dimensions={1}, use_global_device_ids=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:1 offset " source_line=16} + ROOT %all-gather-done.3 = f32[128,128] all-gather-done(f32[128,128] %all-gather-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:1 offset " source_line=17} } // ----- @@ -52,7 +52,7 @@ HloModule main, entry_computation_layout={(f32[10]{0})->f32[10]{0}} %region_1.2 (Arg_0.3: f32[], Arg_1.4: f32[]) -> f32[] { %Arg_0.3 = f32[] parameter(0) %Arg_1.4 = f32[] parameter(1) - ROOT %maximum.5 = f32[] maximum(f32[] %Arg_0.3, f32[] %Arg_1.4), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=7} + ROOT %maximum.5 = f32[] maximum(f32[] %Arg_0.3, f32[] %Arg_1.4), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=7} } // CHECK-LABEL: func.func private @all_reduce_ @@ -63,8 +63,8 @@ HloModule main, entry_computation_layout={(f32[10]{0})->f32[10]{0}} // CHECK: mhlo.async_done ENTRY %async_all_reduce_test (Arg_0.1: f32[10]) -> f32[10] { %Arg_0.1 = f32[10] parameter(0) - %all-reduce-start.6 = f32[10] all-reduce-start(f32[10] %Arg_0.1), channel_id=5, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=%region_1.2, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=22} - ROOT %all-reduce-done.7 = f32[10] all-reduce-done(f32[10] %all-reduce-start.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=23} + %all-reduce-start.6 = f32[10] all-reduce-start(f32[10] %Arg_0.1), channel_id=5, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=%region_1.2, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=22} + ROOT %all-reduce-done.7 = f32[10] all-reduce-done(f32[10] %all-reduce-start.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=23} } // ----- @@ -79,8 +79,8 @@ HloModule main, entry_computation_layout={(f32[128,32]{1,0})->f32[128,32]{1,0}} // CHECK: mhlo.async_done ENTRY %async_collective_permute_test (Arg_0.1: f32[128,32]) -> f32[128,32] { %Arg_0.1 = f32[128,32] parameter(0) - %collective-permute-start.2 = f32[128,32] collective-permute-start(f32[128,32] %Arg_0.1), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:109 offset " source_line=13} - ROOT %collective-permute-done.3 = f32[128,32] collective-permute-done(f32[128,32] %collective-permute-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:109 offset " source_line=14} + %collective-permute-start.2 = f32[128,32] collective-permute-start(f32[128,32] %Arg_0.1), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:109 offset " source_line=13} + ROOT %collective-permute-done.3 = f32[128,32] collective-permute-done(f32[128,32] %collective-permute-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:109 offset " source_line=14} } // ----- @@ -89,8 +89,8 @@ HloModule main, entry_computation_layout={(f32[128,32]{1,0})->f32[128,32]{1,0}} ENTRY %async_copy_test (Arg_0.1: f32[128,32]) -> f32[128,32] { %Arg_0.1 = f32[128,32] parameter(0) - %copy-start.2 = (f32[128,32], f32[128,32], u32[]) copy-start(f32[128,32] %Arg_0.1), cross_program_prefetch_index=0, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:133 offset " source_line=10} - ROOT %copy-done.3 = f32[128,32] copy-done((f32[128,32], f32[128,32], u32[]) %copy-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:133 offset " source_line=11} + %copy-start.2 = (f32[128,32], f32[128,32], u32[]) copy-start(f32[128,32] %Arg_0.1), cross_program_prefetch_index=0, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:133 offset " source_line=10} + ROOT %copy-done.3 = f32[128,32] copy-done((f32[128,32], f32[128,32], u32[]) %copy-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:133 offset " source_line=11} } // ----- @@ -99,10 +99,10 @@ HloModule main, entry_computation_layout={(token[])->(s32[3,4]{1,0}, token[])} ENTRY %async_recv_test_tuple (Arg_0.1: token[]) -> (s32[3,4], token[]) { %Arg_0.1 = token[] parameter(0) - %recv.2 = (s32[3,4], u32[], token[]) recv(token[] %Arg_0.1), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}, {maximal device=0}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=16} - %recv-done.3 = (s32[3,4], token[]) recv-done((s32[3,4], u32[], token[]) %recv.2), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17} - %get-tuple-element.4 = s32[3,4] get-tuple-element((s32[3,4], token[]) %recv-done.3), index=0, sharding={maximal device=0}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17} - %get-tuple-element.5 = token[] get-tuple-element((s32[3,4], token[]) %recv-done.3), index=1, sharding={maximal device=0}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17} + %recv.2 = (s32[3,4], u32[], token[]) recv(token[] %Arg_0.1), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}, {maximal device=0}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=16} + %recv-done.3 = (s32[3,4], token[]) recv-done((s32[3,4], u32[], token[]) %recv.2), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17} + %get-tuple-element.4 = s32[3,4] get-tuple-element((s32[3,4], token[]) %recv-done.3), index=0, sharding={maximal device=0}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17} + %get-tuple-element.5 = token[] get-tuple-element((s32[3,4], token[]) %recv-done.3), index=1, sharding={maximal device=0}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17} ROOT %tuple.6 = (s32[3,4], token[]) tuple(s32[3,4] %get-tuple-element.4, token[] %get-tuple-element.5) } @@ -113,8 +113,8 @@ HloModule main, entry_computation_layout={(s32[3,4]{1,0}, token[])->token[]} ENTRY %async_send_test (Arg_0.1: s32[3,4], Arg_1.2: token[]) -> token[] { %Arg_0.1 = s32[3,4] parameter(0) %Arg_1.2 = token[] parameter(1) - %send.3 = (s32[3,4], u32[], token[]) send(s32[3,4] %Arg_0.1, token[] %Arg_1.2), channel_id=5, is_host_transfer=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:213 offset " source_line=16} - ROOT %send-done.4 = token[] send-done((s32[3,4], u32[], token[]) %send.3), channel_id=5, is_host_transfer=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:213 offset " source_line=17} + %send.3 = (s32[3,4], u32[], token[]) send(s32[3,4] %Arg_0.1, token[] %Arg_1.2), channel_id=5, is_host_transfer=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:213 offset " source_line=16} + ROOT %send-done.4 = token[] send-done((s32[3,4], u32[], token[]) %send.3), channel_id=5, is_host_transfer=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:213 offset " source_line=17} } @@ -124,18 +124,18 @@ ENTRY %async_send_test (Arg_0.1: s32[3,4], Arg_1.2: token[]) -> token[] { // ENTRY %async_custom_call_test2 (Arg_0.1: f32[10]) -> (f32[20]) { // %Arg_0.1 = f32[10] parameter(0) -// %async-start.5 = ((f32[10]), f32[20], s32[]) custom-call-start(f32[10] %Arg_0.1), async_execution_thread="thread", custom_call_target="bar", metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=21} -// %async-update.6 = ((f32[10]), f32[20], s32[]) custom-call-update(((f32[10]), f32[20], s32[]) %async-start.5), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=22} -// ROOT %async-done.7 = (f32[20]) custom-call-done(((f32[10]), f32[20], s32[]) %async-update.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=23} +// %async-start.5 = ((f32[10]), f32[20], s32[]) custom-call-start(f32[10] %Arg_0.1), async_execution_thread="thread", custom_call_target="bar", metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=21} +// %async-update.6 = ((f32[10]), f32[20], s32[]) custom-call-update(((f32[10]), f32[20], s32[]) %async-start.5), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=22} +// ROOT %async-done.7 = (f32[20]) custom-call-done(((f32[10]), f32[20], s32[]) %async-update.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=23} // } // HloModule main, entry_computation_layout={(f32[10]{0})->(f32[20]{0})} // ENTRY %async_custom_call_test (Arg_0.1: f32[10]) -> (f32[20]) { // %Arg_0.1 = f32[10] parameter(0) -// %async-start.5 = ((f32[10]), f32[20], s32[]) custom-call-start(f32[10] %Arg_0.1), async_execution_thread="thread", custom_call_target="foo", metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=16} -// %async-update.6 = ((f32[10]), f32[20], s32[]) custom-call-update(((f32[10]), f32[20], s32[]) %async-start.5), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=18} -// ROOT %async-done.7 = (f32[20]) custom-call-done(((f32[10]), f32[20], s32[]) %async-update.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=20} +// %async-start.5 = ((f32[10]), f32[20], s32[]) custom-call-start(f32[10] %Arg_0.1), async_execution_thread="thread", custom_call_target="foo", metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=16} +// %async-update.6 = ((f32[10]), f32[20], s32[]) custom-call-update(((f32[10]), f32[20], s32[]) %async-start.5), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=18} +// ROOT %async-done.7 = (f32[20]) custom-call-done(((f32[10]), f32[20], s32[]) %async-update.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=20} // } @@ -146,17 +146,17 @@ ENTRY %async_send_test (Arg_0.1: s32[3,4], Arg_1.2: token[]) -> token[] { // HloModule main, entry_computation_layout={(token[])->token[]} // ENTRY %async_send_test_empty (Arg_0.1: token[]) -> token[] { -// %tuple.2 = () tuple(), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=15} +// %tuple.2 = () tuple(), metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=15} // %Arg_0.1 = token[] parameter(0) -// %send.3 = ((), u32[], token[]) send(() %tuple.2, token[] %Arg_0.1), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=15} -// ROOT %send-done.4 = token[] send-done(((), u32[], token[]) %send.3), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=16} +// %send.3 = ((), u32[], token[]) send(() %tuple.2, token[] %Arg_0.1), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=15} +// ROOT %send-done.4 = token[] send-done(((), u32[], token[]) %send.3), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=16} // } // HloModule main, entry_computation_layout={(token[])->((), token[])} // ENTRY %async_recv_test (Arg_0.1: token[]) -> ((), token[]) { // %Arg_0.1 = token[] parameter(0) -// %recv.2 = ((), u32[], token[]) recv(token[] %Arg_0.1), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:153 offset " source_line=17} -// ROOT %recv-done.3 = ((), token[]) recv-done(((), u32[], token[]) %recv.2), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:153 offset " source_line=18} +// %recv.2 = ((), u32[], token[]) recv(token[] %Arg_0.1), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:153 offset " source_line=17} +// ROOT %recv-done.3 = ((), token[]) recv-done(((), u32[], token[]) %recv.2), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir:153 offset " source_line=18} // } diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async2.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_async2.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async2.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_async2.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_entry_computation_layout.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_entry_computation_layout.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/import_entry_computation_layout.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/import_entry_computation_layout.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/layouts_and_names.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/layouts_and_names.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/layouts_and_names.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/layouts_and_names.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/location.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/location.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/location.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/location.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/module_attributes.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/module_attributes.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/module_config.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/module_config.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/module_config.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/module_config.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/simple.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/simple.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/simple.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/simple.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/spmd_module_sharding.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/spmd_module_sharding.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/spmd_module_sharding.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/spmd_module_sharding.hlo diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/stablehlo.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/stablehlo.hlo new file mode 100644 index 00000000000000..3de10f7527960b --- /dev/null +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/stablehlo.hlo @@ -0,0 +1,21 @@ +// RUN: xla-translate -hlo-text-to-stablehlo %s -o - | FileCheck %s + +// CHECK: module @foobar +HloModule foobar + +// CHECK-LABEL: func @main(%arg0: tensor) -> tensor { +ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { + ROOT %Arg_0.1 = f32[] parameter(0) +} + +// CHECK-LABEL: func private @test_simple +%test_simple (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[] { + %Arg_0.1 = f32[4]{0} parameter(0) + %Arg_1.2 = f32[4]{0} parameter(1) + + // CHECK-NEXT: stablehlo.add %arg0, %arg1 : tensor<4xf32> + %add.42 = f32[4]{0} add(f32[4]{0} %Arg_0.1, f32[4]{0} %Arg_1.2) + + // CHECK-NEXT: stablehlo.dot %0, %arg1, precision = [DEFAULT, DEFAULT] : (tensor<4xf32>, tensor<4xf32>) -> tensor + ROOT %dot.4 = f32[] dot(f32[4]{0} %add.42, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} +} \ No newline at end of file diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/stacktrace_to_location.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/stacktrace_to_location.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/stacktrace_to_location.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/stacktrace_to_location.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/types.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/types.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/types.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/types.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/while.hlo b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/while.hlo similarity index 100% rename from third_party/xla/xla/translate/hlo_to_mhlo/tests/while.hlo rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/tests/while.hlo diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/translate.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.cc similarity index 80% rename from third_party/xla/xla/translate/hlo_to_mhlo/translate.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.cc index 362e0e19ad8795..8fef24fead917d 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/translate.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.cc @@ -12,19 +12,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/hlo_to_mhlo/translate.h" +#include "xla/hlo/translate/hlo_to_mhlo/translate.h" +#include + +#include "absl/log/log.h" #include "absl/status/status.h" #include "llvm/Support/LogicalResult.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Location.h" #include "mlir/Pass/PassManager.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_parser.h" #include "xla/service/llvm_ir/llvm_util.h" -#include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "tsl/platform/protobuf.h" namespace xla { @@ -97,10 +99,10 @@ mlir::OwningOpRef HloTextToMlirHloTranslateFunction( } mlir::OwningOpRef HloToStablehloTranslateFunction( - llvm::StringRef input, mlir::MLIRContext* context, - bool import_all_computations, bool flatten_computation_args_result) { + llvm::StringRef input, mlir::MLIRContext* context) { auto module = xla::HloToMlirHloTranslateFunction( - input, context, import_all_computations, flatten_computation_args_result); + input, context, /*import_all_computations=*/true, + /*flatten_computation_args_result=*/true); mlir::PassManager pm(module->getContext()); pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); if (failed(pm.run(*module))) { @@ -112,18 +114,23 @@ mlir::OwningOpRef HloToStablehloTranslateFunction( } mlir::OwningOpRef HloTextToStablehloTranslateFunction( - llvm::StringRef input, mlir::MLIRContext* context, - bool import_all_computations, bool flatten_computation_args_result) { - auto module = xla::HloTextToMlirHloTranslateFunction( - input, context, import_all_computations, flatten_computation_args_result); - mlir::PassManager pm(module->getContext()); - pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); - if (failed(pm.run(*module))) { - module->emitError("Failed to legalize to StableHLO"); + llvm::StringRef input, mlir::MLIRContext* context) { + std::string content(input.data(), input.size()); + + auto hlo_module_error = ParseAndReturnUnverifiedModule(content); + if (!hlo_module_error.ok()) { + LOG(ERROR) << "HLO Module loading failed: " << hlo_module_error.status(); return nullptr; } - return module; + auto stablehlo_module = + ConvertHloToStablehlo(*context, hlo_module_error.value().get()); + if (!stablehlo_module.ok()) { + LOG(ERROR) << "HLO Module import failed: " << stablehlo_module.status(); + return nullptr; + } + + return std::move(stablehlo_module.value()); } } // namespace xla diff --git a/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.h b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.h new file mode 100644 index 00000000000000..69b06bf9a60813 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate.h @@ -0,0 +1,74 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_HLO_TO_MHLO_TRANSLATE_H_ +#define XLA_HLO_TRANSLATE_HLO_TO_MHLO_TRANSLATE_H_ + +namespace llvm { +class StringRef; +} // namespace llvm + +namespace mlir { +class MLIRContext; +class ModuleOp; +template +class OwningOpRef; +} // namespace mlir + +namespace xla { + +// Converts a HloModuleProto stored in the file with the given `input_filename` +// into a MHLO module. Creates MLIR entities into the given MLIR `context`. +// +// If `import_all_computation` is set to true, imports all computations +// irrespective if transitively called from entry computation. +// +// If `flatten_computation_args_result` is set to true, flattens all tuple +// arguments and result of every computation when importing them as func ops. +mlir::OwningOpRef HloToMlirHloTranslateFunction( + llvm::StringRef input, mlir::MLIRContext* context, + bool import_all_computations = false, + bool flatten_computation_args_result = false); + +// Converts a HloModule stored in text form for a file with the given +// `input_filename` into a MHLO module. Creates MLIR entities into the given +// MLIR `context`. +// +// If `import_all_computation` is set to true, imports all computations +// irrespective if transitively called from entry computation. +// +// If `flatten_computation_args_result` is set to true, flattens all tuple +// arguments and result of every computation when importing them as func ops. +mlir::OwningOpRef HloTextToMlirHloTranslateFunction( + llvm::StringRef input, mlir::MLIRContext* context, + bool import_all_computations = false, + bool flatten_computation_args_result = false); + +// Converts a HloModuleProto stored in the file with the given `input_filename` +// into a StableHLO module. Creates MLIR entities into the given MLIR `context`. +// +mlir::OwningOpRef HloToStablehloTranslateFunction( + llvm::StringRef input, mlir::MLIRContext* context); + +// Converts a HloModule stored in text form for a file with the given +// `input_filename` into a StableHLO module. Creates MLIR entities into the +// given MLIR `context`. +// +mlir::OwningOpRef HloTextToStablehloTranslateFunction( + llvm::StringRef input, mlir::MLIRContext* context); + +} // namespace xla + +#endif // XLA_HLO_TRANSLATE_HLO_TO_MHLO_TRANSLATE_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/translate_registration.cc b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate_registration.cc similarity index 74% rename from third_party/xla/xla/translate/hlo_to_mhlo/translate_registration.cc rename to third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate_registration.cc index b3d8f2f97b414b..c1ef2a37675c83 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/translate_registration.cc +++ b/third_party/xla/xla/hlo/translate/hlo_to_mhlo/translate_registration.cc @@ -13,9 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "llvm/Support/CommandLine.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Location.h" #include "mlir/Tools/mlir-translate/Translation.h" -#include "xla/translate/hlo_to_mhlo/translate.h" +#include "xla/hlo/translate/hlo_to_mhlo/translate.h" namespace { // NOLINTNEXTLINE @@ -43,14 +46,24 @@ static mlir::OwningOpRef HloTextToMlirHloTranslate( static mlir::OwningOpRef HloToStablehloTranslate( llvm::StringRef input, mlir::MLIRContext* context) { - return xla::HloToStablehloTranslateFunction( - input, context, import_all_computations, flatten_computation_args_result); + if (!flatten_computation_args_result.getValue() || + !import_all_computations.getValue()) { + mlir::emitWarning(mlir::UnknownLoc::get(context), + "HLO => StableHLO requires flattened_args and " + "import_all_computations to be set to true."); + } + return xla::HloToStablehloTranslateFunction(input, context); } static mlir::OwningOpRef HloTextToStablehloTranslate( llvm::StringRef input, mlir::MLIRContext* context) { - return xla::HloTextToStablehloTranslateFunction( - input, context, import_all_computations, flatten_computation_args_result); + if (!flatten_computation_args_result.getValue() || + !import_all_computations.getValue()) { + mlir::emitWarning(mlir::UnknownLoc::get(context), + "HLO => StableHLO requires flattened_args and " + "import_all_computations to be set to true."); + } + return xla::HloTextToStablehloTranslateFunction(input, context); } static mlir::TranslateToMLIRRegistration HloToMlirHloTranslateRegistration( diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD new file mode 100644 index 00000000000000..9aacdd79aa524f --- /dev/null +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/BUILD @@ -0,0 +1,323 @@ +load("@bazel_skylib//rules:build_test.bzl", "build_test") +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_binary", "cc_library") +load("//xla:xla.bzl", "xla_cc_test") +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//learning/brain/mlir:tensorflow_friends", + "//learning/brain/mlir:xla_friends", + ]), + licenses = ["notice"], +) + +cc_library( + name = "attribute_exporter", + srcs = ["attribute_exporter.cc"], + hdrs = ["attribute_exporter.h"], + deps = [ + "//xla:shape_util", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/mlir_hlo", + "//xla/service:hlo_proto_cc", + "//xla/stream_executor:dnn", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@stablehlo//:base", + ], +) + +cc_library( + name = "layout_util", + srcs = ["layout_util.cc"], + hdrs = ["layout_util.h"], + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "literal_exporter", + srcs = ["literal_exporter.cc"], + hdrs = ["literal_exporter.h"], + deps = [ + ":type_to_shape", + "//xla:array", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:errors", + ], +) + +cc_library( + name = "location_exporter", + srcs = ["location_exporter.cc"], + hdrs = ["location_exporter.h"], + deps = [ + ":stack_frame_index_builder", + "//xla:xla_data_proto_cc", + "@com_google_absl//absl/log", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "module_attributes_exporter", + srcs = ["module_attributes_exporter.cc"], + hdrs = ["module_attributes_exporter.h"], + deps = [ + "//xla:xla_data_proto_cc", + "//xla/service:hlo_module_config", + "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "stack_frame_index_builder", + srcs = ["stack_frame_index_builder.cc"], + hdrs = ["stack_frame_index_builder.h"], + deps = [ + "//xla/service:hlo_proto_cc", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "mlir_hlo_to_hlo", + srcs = [ + "mlir_hlo_to_hlo.cc", + "operator_writers.inc", + ], + hdrs = ["mlir_hlo_to_hlo.h"], + deps = [ + ":attribute_exporter", + ":layout_util", + ":literal_exporter", + ":location_exporter", + ":module_attributes_exporter", + ":operator_writer_inc", + ":stack_frame_index_builder", + ":type_to_shape", + "//xla:array", + "//xla:comparison_util", + "//xla:debug_options_flags", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:xla_data_proto_cc", + "//xla/client:xla_builder", + "//xla/client:xla_computation", + "//xla/client/lib:approx_topk", + "//xla/client/lib:approx_topk_shape", + "//xla/client/lib:matrix", + "//xla/client/lib:quantize", + "//xla/client/lib:slicing", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/builder/lib:approx_topk", + "//xla/hlo/builder/lib:approx_topk_shape", + "//xla/hlo/builder/lib:matrix", + "//xla/hlo/builder/lib:slicing", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/mlir/utils:error_util", + "//xla/mlir/utils:type_util", + "//xla/mlir_hlo", + "//xla/mlir_hlo:mhlo_passes", + "//xla/service:computation_layout", + "//xla/service:hlo_module_config", + "//xla/service:hlo_proto_cc", + "//xla/service/gpu:backend_configs_cc", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:ml_dtypes", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:types", + "@stablehlo//:base", + "@stablehlo//:stablehlo_ops", + ], +) + +build_test( + name = "operator_writer_gen_build_test", + targets = [ + ":operator_writer_gen", + ], +) + +cc_binary( + name = "operator_writer_gen", + srcs = ["operator_writer_gen.cc"], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TableGen", + "@llvm-project//mlir:TableGen", + ], +) + +gentbl_cc_library( + name = "operator_writer_inc", + compatible_with = get_compatible_with_portable(), + tbl_outs = [([], "operator_writers.inc")], + tblgen = ":operator_writer_gen", + td_file = "//xla/mlir_hlo:mhlo/IR/hlo_ops.td", + deps = [ + "//xla/mlir_hlo:hlo_ops_td_files", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + ], +) + +xla_cc_test( + name = "mlir_hlo_to_hlo_test", + srcs = ["mlir_hlo_to_hlo_test.cc"], + deps = [ + ":mlir_hlo_to_hlo", + "//xla/mlir/utils:error_util", + "//xla/tsl/lib/core:status_test_util", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:ShapeDialect", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + "@stablehlo//:register", + ], +) + +cc_library( + name = "translate", + srcs = ["translate.cc"], + hdrs = ["translate.h"], + deps = [ + ":mlir_hlo_to_hlo", + ":type_to_shape", + "//xla:debug_options_flags", + "//xla:shape_util", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo:hlo_dialect_registration", + "//xla/service:hlo_module_config", + "//xla/service:hlo_proto_cc", + "//xla/service:hlo_proto_util", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "translate_registration", + testonly = True, + srcs = [ + "translate_registration.cc", + "translate_registration.h", + ], + deps = [ + ":translate", + "//xla/mlir_hlo:hlo_dialect_registration", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TranslateLib", + ], + alwayslink = 1, +) + +cc_library( + name = "type_to_shape", + srcs = ["type_to_shape.cc"], + hdrs = ["type_to_shape.h"], + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/mlir/utils:type_util", + "//xla/mlir_hlo", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:SparseTensorDialect", + "@llvm-project//mlir:SparseTensorEnums", + "@llvm-project//mlir:Support", + "@stablehlo//:stablehlo_ops", + ], +) + +xla_cc_test( + name = "type_to_shape_test", + srcs = ["type_to_shape_test.cc"], + deps = [ + ":type_to_shape", + "//xla:shape_util", + "//xla:test", + "//xla:xla_data_proto_cc", + "//xla/hlo/translate/hlo_to_mhlo:hlo_utils", + "//xla/mlir_hlo", + "@com_google_absl//absl/status:statusor", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:test_main", + ], +) diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/attribute_exporter.cc similarity index 99% rename from third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.cc rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/attribute_exporter.cc index ae8537283d9148..0246563c455a16 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/attribute_exporter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/mhlo_to_hlo/attribute_exporter.h" +#include "xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h" #include @@ -25,9 +25,9 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "stablehlo/dialect/Base.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_parser.h" #include "xla/shape_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h new file mode 100644 index 00000000000000..bc8344ce11b01d --- /dev/null +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h @@ -0,0 +1,75 @@ +/* Copyright 2020 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_MHLO_TO_HLO_ATTRIBUTE_EXPORTER_H_ +#define XLA_HLO_TRANSLATE_MHLO_TO_HLO_ATTRIBUTE_EXPORTER_H_ + +#include + +#include "absl/status/statusor.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Support/LLVM.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/service/hlo.pb.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/dnn.h" +#include "xla/types.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Converts the conv dimensions attribute to XLA HLO. +ConvolutionDimensionNumbers ConvertConvDimensionNumbers( + mlir::mhlo::ConvDimensionNumbersAttr input); + +// Converts the dot algorithm attribute to XLA HLO. +absl::StatusOr ConvertDotAlgorithm( + mlir::mhlo::DotAlgorithmAttr attr); + +absl::StatusOr> ConvertReplicaGroups( + mlir::DenseIntElementsAttr input); + +// Convert a (N, 2) dense attribute to a list of tuples. This is the way padding +// and source-target pairs are defined in HLO. +absl::StatusOr>> ConvertNx2Attribute( + std::optional optional_attr); + +absl::StatusOr ConvertTranspose( + llvm::StringRef transpose_string); + +absl::StatusOr ConvertCustomCallSchedule( + mlir::mhlo::CustomCallSchedule schedule); + +absl::StatusOr ConvertCustomCallApiVersion( + mlir::mhlo::CustomCallApiVersion api_version); + +absl::StatusOr< + std::vector>>> +ConvertOutputOperandAliasing(mlir::ArrayAttr aliasArrayAttr); + +// Returns an OpSharding that represents the result of parsing the given string: +// first, as serialized protobuf, and then as prettyprinted representation. +// Will fail if both attempts at parsing failed. +std::optional ConvertSharding(mlir::StringRef sharding); + +std::optional ConvertInputOutputAlias( + llvm::ArrayRef aliasing); + +} // namespace xla +#endif // XLA_HLO_TRANSLATE_MHLO_TO_HLO_ATTRIBUTE_EXPORTER_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/layout_util.cc similarity index 98% rename from third_party/xla/xla/translate/mhlo_to_hlo/layout_util.cc rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/layout_util.cc index a07bba4004b59f..252ce543813719 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/layout_util.cc @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/mhlo_to_hlo/layout_util.h" +#include "xla/hlo/translate/mhlo_to_hlo/layout_util.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/layout_util.h b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/layout_util.h new file mode 100644 index 00000000000000..ab00814a102df5 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/layout_util.h @@ -0,0 +1,85 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utilities for working with XLA layout and shapes. + +#ifndef XLA_HLO_TRANSLATE_MHLO_TO_HLO_LAYOUT_UTIL_H_ +#define XLA_HLO_TRANSLATE_MHLO_TO_HLO_LAYOUT_UTIL_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/shape.h" +#include "xla/xla_data.pb.h" + +namespace mlir { + +// XLA Layout preferences. Currently, when it comes to TPU, there are two +// primary layout choices for any XLA arguments (parameter or resource): (1) +// CompactChunkPadded and (2) Linear. CompactChunkPadded is the native TPU +// layout while Linear is native host (CPU) layout. +// This enum allows the caller of XLA to propagate layout preference to the XLA +// compiler. +// kNoPreference: the generic layout where the XLA compiler has the freedom +// to assign any layout. +// kTpuPreferCompactChunkPaddedLayout: use native TPU layout on TPU. +// kTpuPreferLinearLayout: use native CPU layout on TPU. The compiler may +// insert transformation TPU kernels. +// As the layout of any argument will change from a native host layout to a +// native TPU layout either on host or on device, XLA compiler and TPU runtime +// must be in coordination to transform the parameters in a consistent way. +enum class XlaLayoutPreference { + kNoPreference = 0, + kTpuPreferCompactChunkPaddedLayout = 1, + kTpuPreferLinearLayout = 2 +}; + +// The following defines the layout preference of an xla tensor. +// The return value of LayoutPreferenceFn can be used in +// ShapeRepresentationFn. +typedef std::function( + const xla::Shape& shape)> + LayoutPreferenceFn; + +typedef std::function( + const xla::Shape& shape, bool fast_mem, + XlaLayoutPreference layout_preference)> + ShapeRepresentationFn; + +// Return a LayoutPreferenceFn that always uses kNoPreference layout. +LayoutPreferenceFn UseNoPreferenceLayoutFn(); + +// Rewrites the layout of xla_shape if there is tiled sharding. +absl::Status RewriteLayoutWithShardedShape( + const std::optional& sharding, bool use_fast_memory, + const LayoutPreferenceFn& layout_preference_fn, + const ShapeRepresentationFn& shape_representation_fn, + xla::Shape* xla_shape); + +// Adds reshapes to fix the layout of an output, if a shape_representation_fn or +// sharding is present. +absl::StatusOr ReshapeWithCorrectRepresentationAndSharding( + xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, + const LayoutPreferenceFn& layout_preference_fn, + const ShapeRepresentationFn& shape_representation_fn, + std::optional sharding, bool fast_mem); + +} // namespace mlir + +#endif // XLA_HLO_TRANSLATE_MHLO_TO_HLO_LAYOUT_UTIL_H_ diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc new file mode 100644 index 00000000000000..821f1487cf88c1 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/literal_exporter.cc @@ -0,0 +1,90 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/translate/mhlo_to_hlo/literal_exporter.h" + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "llvm/ADT/APInt.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Support/LLVM.h" +#include "xla/array.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" +#include "xla/layout.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/primitive_util.h" +#include "xla/shape.h" + +namespace mlir { +namespace mhlo { + +template +xla::Array ArrayFromDenseElementsAttr(mlir::DenseElementsAttr dense_attr) { + constexpr xla::PrimitiveType type = + xla::primitive_util::NativeToPrimitiveType(); + xla::Shape shape = xla::TypeToShape(dense_attr.getType()); + xla::Array array(shape.dimensions()); + if constexpr (!xla::primitive_util::IsSubByteNonPredType(type)) { + array.SetValues(dense_attr.getValues()); + } else { + // The only way to get subbyte integers from getValues() is to get them as + // APInts. + auto values = dense_attr.getValues(); + for (int i = 0; i < values.size(); i++) { + if constexpr (xla::primitive_util::IsUnsignedIntegralType(type)) { + array.data()[i] = T{values[i].getZExtValue()}; + } else { + static_assert(xla::primitive_util::IsSignedIntegralType(type)); + array.data()[i] = T{values[i].getSExtValue()}; + } + } + } + return array; +} + +absl::StatusOr CreateLiteralFromAttribute(mlir::ElementsAttr attr, + xla::Layout layout) { + auto dense_attr = mlir::dyn_cast(attr); + if (!dense_attr) + return absl::UnimplementedError("Only dense elements attr are supported"); + + xla::Shape shape = xla::TypeToShape(dense_attr.getType()); + + return xla::primitive_util::PrimitiveTypeSwitch>( + [&](auto primitive_type_constant) -> absl::StatusOr { + if constexpr (xla::primitive_util::IsArrayType( + primitive_type_constant)) { + using cpp_type = + xla::primitive_util::NativeTypeOf; + xla::Array source_data = + ArrayFromDenseElementsAttr(dense_attr); + if (layout.minor_to_major().empty()) { + return xla::LiteralUtil::CreateFromArray(source_data); + } + return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, + layout); + } + return absl::InternalError(absl::StrCat( // NOLINT + "Unsupported type: ", + xla::PrimitiveType_Name(shape.element_type()))); + }, + shape.element_type()); +} + +} // namespace mhlo +} // namespace mlir diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/literal_exporter.h b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/literal_exporter.h new file mode 100644 index 00000000000000..f5cb3c74a2819a --- /dev/null +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/literal_exporter.h @@ -0,0 +1,33 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_MHLO_TO_HLO_LITERAL_EXPORTER_H_ +#define XLA_HLO_TRANSLATE_MHLO_TO_HLO_LITERAL_EXPORTER_H_ + +#include "absl/status/statusor.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "xla/layout.h" +#include "xla/literal.h" + +namespace mlir { +namespace mhlo { + +absl::StatusOr CreateLiteralFromAttribute(mlir::ElementsAttr attr, + xla::Layout layout); + +} // namespace mhlo +} // namespace mlir + +#endif // XLA_HLO_TRANSLATE_MHLO_TO_HLO_LITERAL_EXPORTER_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/location_exporter.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/location_exporter.cc similarity index 97% rename from third_party/xla/xla/translate/mhlo_to_hlo/location_exporter.cc rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/location_exporter.cc index bb274f6bb99a51..c7f80898a21321 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/location_exporter.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/location_exporter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/mhlo_to_hlo/location_exporter.h" +#include "xla/hlo/translate/mhlo_to_hlo/location_exporter.h" #include @@ -26,7 +26,7 @@ limitations under the License. #include "mlir/IR/Operation.h" #include "mlir/IR/Visitors.h" #include "mlir/Support/LLVM.h" -#include "xla/translate/mhlo_to_hlo/stack_frame_index_builder.h" +#include "xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.h" #include "xla/xla_data.pb.h" namespace mlir { diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/location_exporter.h b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/location_exporter.h new file mode 100644 index 00000000000000..70ab1d6395076a --- /dev/null +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/location_exporter.h @@ -0,0 +1,44 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_MHLO_TO_HLO_LOCATION_EXPORTER_H_ +#define XLA_HLO_TRANSLATE_MHLO_TO_HLO_LOCATION_EXPORTER_H_ + +#include + +#include "mlir/IR/Location.h" +#include "mlir/IR/Operation.h" +#include "xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.h" +#include "xla/xla_data.pb.h" + +namespace mlir { +namespace mhlo { + +// Returns a OpMetadata proto based on the location of the op. If the location +// is unknown, an empty proto is returned. `op_name` are populated with the op +// location (converted). FileLineColLoc locations are populated by taking the +// file name and line number, and populating `source_file` and `source_line` +// respectively. +xla::OpMetadata CreateOpMetadataFromLocation( + Operation* op, StackFrameIndexBuilder* frame_index_builder); + +// Returns a name that can be used for debugging purposes, e.g., naming +// variable names in generated IR or producing logging output. +std::string GetDebugNameFromLocation(Location location); + +} // namespace mhlo +} // namespace mlir + +#endif // XLA_HLO_TRANSLATE_MHLO_TO_HLO_LOCATION_EXPORTER_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc similarity index 97% rename from third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index db21cfb6095d18..0807964addc5e1 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" +#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include #include @@ -64,18 +64,26 @@ limitations under the License. #include "mlir/Transforms/RegionUtils.h" #include "stablehlo/dialect/Base.h" #include "xla/array.h" -#include "xla/client/lib/approx_topk.h" -#include "xla/client/lib/approx_topk_shape.h" -#include "xla/client/lib/matrix.h" // IWYU pragma: keep -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" #include "xla/comparison_util.h" #include "xla/debug_options_flags.h" +#include "xla/hlo/builder/lib/approx_topk.h" +#include "xla/hlo/builder/lib/approx_topk_shape.h" +#include "xla/hlo/builder/lib/matrix.h" // IWYU pragma: keep +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/dynamic_parameter_binding.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h" +#include "xla/hlo/translate/mhlo_to_hlo/layout_util.h" +#include "xla/hlo/translate/mhlo_to_hlo/literal_exporter.h" +#include "xla/hlo/translate/mhlo_to_hlo/location_exporter.h" +#include "xla/hlo/translate/mhlo_to_hlo/module_attributes_exporter.h" +#include "xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/literal.h" @@ -88,15 +96,8 @@ limitations under the License. #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/translate/mhlo_to_hlo/attribute_exporter.h" -#include "xla/translate/mhlo_to_hlo/layout_util.h" -#include "xla/translate/mhlo_to_hlo/location_exporter.h" -#include "xla/translate/mhlo_to_hlo/module_attributes_exporter.h" -#include "xla/translate/mhlo_to_hlo/stack_frame_index_builder.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -214,56 +215,6 @@ bool IsBoundedOrStatic(mlir::Type ty) { return true; } -template -xla::Array ArrayFromDenseElementsAttr(mlir::DenseElementsAttr dense_attr) { - constexpr xla::PrimitiveType type = - xla::primitive_util::NativeToPrimitiveType(); - xla::Shape shape = xla::TypeToShape(dense_attr.getType()); - xla::Array array(shape.dimensions()); - if constexpr (!xla::primitive_util::IsSubByteNonPredType(type)) { - array.SetValues(dense_attr.getValues()); - } else { - // The only way to get subbyte integers from getValues() is to get them as - // APInts. - auto values = dense_attr.getValues(); - for (int i = 0; i < values.size(); i++) { - if constexpr (xla::primitive_util::IsUnsignedIntegralType(type)) { - array.data()[i] = T{values[i].getZExtValue()}; - } else { - static_assert(xla::primitive_util::IsSignedIntegralType(type)); - array.data()[i] = T{values[i].getSExtValue()}; - } - } - } - return array; -} - -absl::StatusOr CreateArrayLiteralFromAttr(mlir::ElementsAttr attr, - xla::Layout layout) { - auto dense_attr = mlir::dyn_cast(attr); - if (!dense_attr) - return tsl::errors::Unimplemented("Only dense elements attr are supported"); - - xla::Shape shape = xla::TypeToShape(dense_attr.getType()); - - return xla::primitive_util::PrimitiveTypeSwitch>( - [&](auto primitive_type_constant) -> absl::StatusOr { - if constexpr (xla::primitive_util::IsArrayType( - primitive_type_constant)) { - using cpp_type = - xla::primitive_util::NativeTypeOf; - xla::Array source_data = - ArrayFromDenseElementsAttr(dense_attr); - return xla::LiteralUtil::CreateFromArrayWithLayout(source_data, - layout); - } - return tsl::errors::Internal(absl::StrCat( // NOLINT - "Unsupported type: ", - xla::PrimitiveType_Name(shape.element_type()))); - }, - shape.element_type()); -} - // Convert APInt into an int. // TODO(hpucha): This should be consolidated into a general place. static int ConvertAPInt(llvm::APInt i) { return i.getSExtValue(); } @@ -648,23 +599,31 @@ static std::optional CreateOpShardingFromAttribute( // Returns a FrontendAttributes proto from the "frontend_attributes" attribute // of the op. An empty FrontendAttributes proto is returned if an op does not // have frontend attributes. -void ConstructFrontendAttributesFromAttribute( - const mlir::DictionaryAttr& frontend_attributes_dict, - xla::FrontendAttributes& frontend_attributes) { - for (const auto& attr : frontend_attributes_dict) +void CreateFrontendAttributes(mlir::ArrayRef named_attrs, + xla::FrontendAttributes& frontend_attributes) { + for (const auto& attr : named_attrs) if (auto value_str_attr = mlir::dyn_cast(attr.getValue())) frontend_attributes.mutable_map()->insert( {attr.getName().str(), value_str_attr.getValue().str()}); } +// Returns a FrontendAttributes proto from the "frontend_attributes" attribute +// of the op. An empty FrontendAttributes proto is returned if an op does not +// have frontend attributes. +void CreateFrontendAttributes( + const mlir::DictionaryAttr& frontend_attributes_dict, + xla::FrontendAttributes& frontend_attributes) { + CreateFrontendAttributes(frontend_attributes_dict.getValue(), + frontend_attributes); +} + static xla::FrontendAttributes CreateXlaFrontendAttributesFromOp( mlir::Operation* op) { xla::FrontendAttributes frontend_attributes; auto frontend_attributes_dict = op->getAttrOfType(kMhloFrontendAttributes); if (!frontend_attributes_dict) return frontend_attributes; - ConstructFrontendAttributesFromAttribute(frontend_attributes_dict, - frontend_attributes); + CreateFrontendAttributes(frontend_attributes_dict, frontend_attributes); return frontend_attributes; } @@ -676,7 +635,7 @@ static void ExtractFrontendAttributesFromFunction( if (auto fe_attr = function.getArgAttrOfType( i, kMhloFrontendAttributes)) { xla::FrontendAttributes frontend_attributes; - ConstructFrontendAttributesFromAttribute(fe_attr, frontend_attributes); + CreateFrontendAttributes(fe_attr, frontend_attributes); (*fe_attrs)[i] = frontend_attributes; } } @@ -1651,6 +1610,9 @@ LogicalResult ExportXlaOp(DotGeneralOp op, OpLoweringContext ctx) { if (!algorithm.ok()) { return op.emitError(algorithm.status().ToString()); } + if (precision_config == nullptr) { + precision_config = std::make_unique(); + } precision_config->set_algorithm(algorithm.value()); } auto xlaOp = xla::DotGeneral( @@ -2257,7 +2219,7 @@ LogicalResult ExportXlaOp(CustomCallOp op, OpLoweringContext ctx) { const xla::Literal* literal_ptr = nullptr; auto literal_attr = op->getAttrOfType(kMhloLiteral); if (literal_attr) { - literal = CreateArrayLiteralFromAttr(literal_attr, {}); + literal = mhlo::CreateLiteralFromAttribute(literal_attr, {}); if (!literal.ok()) return failure(); literal_ptr = &*literal; } @@ -3020,7 +2982,7 @@ LogicalResult ExportXlaOp(MinimumBroadcastShapesOp op, OpLoweringContext ctx) { } // namespace mhlo } // namespace mlir -#include "xla/translate/mhlo_to_hlo/operator_writers.inc" +#include "xla/hlo/translate/mhlo_to_hlo/operator_writers.inc" namespace mlir { namespace { @@ -3301,7 +3263,8 @@ LogicalResult ConvertToHloModule::LowerConstant( mlir::FailureOr shape_or = ExtractXlaShape(inst); if (failed(shape_or)) return failure(); - auto literal_or = CreateArrayLiteralFromAttr(const_attr, shape_or->layout()); + auto literal_or = + mhlo::CreateLiteralFromAttribute(const_attr, shape_or->layout()); if (!literal_or.ok()) return inst->emitError(literal_or.status().ToString()); xla::XlaScopedShardingAssignment scoped_sharding( @@ -3524,6 +3487,8 @@ LogicalResult ConvertToHloModule::Lower( LogicalResult ConvertToHloModule::LowerFunctionCall( mlir::func::CallOp call_op, xla::XlaBuilder* builder, ConvertToHloModule::ValueLoweringMap* value_lowering) { + xla::XlaScopedShardingAssignment scoped_sharding( + builder, CreateOpShardingFromAttribute(call_op)); auto& value_map = *value_lowering; mlir::func::FuncOp callee = module_.lookupSymbol(call_op.getCallee()); @@ -3541,10 +3506,23 @@ LogicalResult ConvertToHloModule::LowerFunctionCall( // callees, but eventually before lowering call graph is "flattened" to // make that true. This is done before lowering because buffer assignment // needs this invariant. + + // Remove the backend_config from the frontend attributes. xla::FrontendAttributes fe_attrs = CreateXlaFrontendAttributesFromOp(call_op); + std::string backend_config = ""; + auto fe_attrs_map = fe_attrs.mutable_map(); + if (fe_attrs_map->contains(kBackendConfig)) { + backend_config = fe_attrs_map->at(kBackendConfig); + fe_attrs_map->erase(kBackendConfig); + } xla::XlaScopedFrontendAttributesAssignment assignment(builder, fe_attrs); xla::XlaOp call_result = xla::Call(builder, lowered_computation_[callee], operands); + xla::HloInstructionProto* call_instruction = + xla::internal::XlaBuilderFriend::GetInstruction(call_result); + // `call_op` with `backend_config` can appear when round-tripping a program + // that has already run some XLA host communication passes. + call_instruction->set_backend_config(backend_config); // Use GetTupleElement for multiple outputs unsigned num_results = call_op.getNumResults(); if (num_results > 1) { @@ -3619,10 +3597,9 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::func::FuncOp f) { // means no replication. This avoids the need for unrelated tests to handle // this field. if (!any_arg_replicated) entry_args_same_across_replicas.clear(); - - ExtractShardingsFromFunction(f, &arg_shardings, &ret_shardings); ExtractFrontendAttributesFromFunction(f, &arg_fe_attrs); } + ExtractShardingsFromFunction(f, &arg_shardings, &ret_shardings); if (failed(LowerBasicBlockAsFunction(&f.front(), &builder, entry_function, false, entry_args_same_across_replicas, arg_shardings, ret_shardings, @@ -3967,8 +3944,8 @@ absl::Status ConvertMlirHloToHlo(mlir::ModuleOp module, } if (auto frontend_attributes = module->getAttrOfType(kMhloFrontendAttributes)) { - ConstructFrontendAttributesFromAttribute( - frontend_attributes, *hlo_module.mutable_frontend_attributes()); + CreateFrontendAttributes(frontend_attributes, + *hlo_module.mutable_frontend_attributes()); } if (auto use_auto_spmd_partitioning = module->getAttrOfType(kMhloUseAutoSpmdPartitioning)) { diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h new file mode 100644 index 00000000000000..45a0344a2618df --- /dev/null +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h @@ -0,0 +1,96 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_MHLO_TO_HLO_MLIR_HLO_TO_HLO_H_ +#define XLA_HLO_TRANSLATE_MHLO_TO_HLO_MLIR_HLO_TO_HLO_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "llvm/ADT/ArrayRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/BuiltinOps.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/translate/mhlo_to_hlo/layout_util.h" +#include "xla/service/hlo.pb.h" +#include "xla/service/hlo_module_config.h" + +namespace mlir { + +struct MlirToHloConversionOptions { + // Best-effort propagation of the layouts. These layouts serve as performance + // hints to the backend. + // + // Note that non-array shapes are not carrying layouts, and users have to + // figure out the proper layouts of them through context. This is one of the + // reasons why the attribute-based solution is temporary. + // + // TODO(timshen): Investigate the necessity of having layouts in MHLO. + bool propagate_layouts = false; + + // Propagate the source and result layouts from mhlo bitcast op into the + // backend config for the bitcast. This is required for XLA:GPU backend to + // use elemental IR emitters for fused bitcasts without propagating layouts. + bool propagate_bitcast_layouts_to_backend_config = false; + + LayoutPreferenceFn layout_preference_fn; + ShapeRepresentationFn shape_representation_fn; + + // If use_tuple_args is set, then the entry computations's arguments are + // converted to a tuple and passed as a single parameter. + bool use_tuple_args = false; + + // If return tuple is true, then the entry function's return values + // are converted to a tuple even when there is only a single return value. + // Multiple return values are always converted to a tuple and returned as a + // single value. + bool return_tuple = true; +}; + +// Prefer `ConvertMlirHloToHloModule` over this method when possible, as it +// preserves more information and abstracts away the proto. This method is +// preserved for legacy reasons. +// TODO (b/345806521): Migrate callsites to ConvertMlirHloToHloModule, +// and delete this method. +// +// Converts a MLIR module in HLO dialect into a HloModuleProto. +// +absl::Status ConvertMlirHloToHlo(mlir::ModuleOp module, + ::xla::HloProto* hlo_proto, + bool use_tuple_args, bool return_tuple, + MlirToHloConversionOptions options = {}); + +// Converts a MLIR module in HLO dialect into a HloModule with HloModuleConfig. +// This method preserves config data stored in MHLO module attributes. +// +// See `MlirToHloConversionOptions` for details on conversion flags. +absl::StatusOr> ConvertMlirHloToHloModule( + mlir::ModuleOp module, MlirToHloConversionOptions options = {}); + +// Transforms a Block into HLO, where the HLO is represented as calls into an +// XlaBuilder. Callee functions are allowed in the Block's ancestor ModuleOp. +// xla_params are inputs to block. returns are the returned XlaOps. +absl::Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder, + llvm::ArrayRef xla_params, + std::vector& returns, + MlirToHloConversionOptions options = {}); + +} // namespace mlir + +#endif // XLA_HLO_TRANSLATE_MHLO_TO_HLO_MLIR_HLO_TO_HLO_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo_test.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo_test.cc similarity index 97% rename from third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo_test.cc rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo_test.cc index 10dd9cec91f529..ad96da29cd4cfc 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo_test.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" +#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/module_attributes_exporter.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/module_attributes_exporter.cc similarity index 99% rename from third_party/xla/xla/translate/mhlo_to_hlo/module_attributes_exporter.cc rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/module_attributes_exporter.cc index 7022a110572d37..afdae4739d79df 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/module_attributes_exporter.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/module_attributes_exporter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/mhlo_to_hlo/module_attributes_exporter.h" +#include "xla/hlo/translate/mhlo_to_hlo/module_attributes_exporter.h" #include #include diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/module_attributes_exporter.h b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/module_attributes_exporter.h similarity index 89% rename from third_party/xla/xla/translate/mhlo_to_hlo/module_attributes_exporter.h rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/module_attributes_exporter.h index ccff0f957e6406..2081f24aee8f97 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/module_attributes_exporter.h +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/module_attributes_exporter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_TRANSLATE_MHLO_TO_HLO_MODULE_ATTRIBUTES_EXPORTER_H_ -#define XLA_TRANSLATE_MHLO_TO_HLO_MODULE_ATTRIBUTES_EXPORTER_H_ +#ifndef XLA_HLO_TRANSLATE_MHLO_TO_HLO_MODULE_ATTRIBUTES_EXPORTER_H_ +#define XLA_HLO_TRANSLATE_MHLO_TO_HLO_MODULE_ATTRIBUTES_EXPORTER_H_ #include "absl/status/status.h" #include "mlir/IR/BuiltinAttributes.h" @@ -48,4 +48,4 @@ absl::Status ExportModuleEntryComputationResultTiles( } // namespace mhlo } // namespace mlir -#endif // XLA_TRANSLATE_MHLO_TO_HLO_MODULE_ATTRIBUTES_EXPORTER_H_ +#endif // XLA_HLO_TRANSLATE_MHLO_TO_HLO_MODULE_ATTRIBUTES_EXPORTER_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/operator_writer_gen.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/operator_writer_gen.cc similarity index 98% rename from third_party/xla/xla/translate/mhlo_to_hlo/operator_writer_gen.cc rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/operator_writer_gen.cc index 0d0a537272f73c..cd6fa0ca171a4b 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/operator_writer_gen.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/operator_writer_gen.cc @@ -140,7 +140,7 @@ static void BuildOperator(const Operator& op, raw_ostream& os) { // The function below has a non-constant reference as that is required by LLVM's // TableGenMain. // NOLINTNEXTLINE -static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) { +static bool OperatorWritersMain(raw_ostream& os, const RecordKeeper& records) { emitSourceFileHeader("MLIR XLA Builders", os); // Emit all the helper functions. diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/stack_frame_index_builder.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.cc similarity index 98% rename from third_party/xla/xla/translate/mhlo_to_hlo/stack_frame_index_builder.cc rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.cc index 790f606d6457fe..dc96c4192938c3 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/stack_frame_index_builder.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/mhlo_to_hlo/stack_frame_index_builder.h" +#include "xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.h" #include #include diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/stack_frame_index_builder.h b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.h similarity index 88% rename from third_party/xla/xla/translate/mhlo_to_hlo/stack_frame_index_builder.h rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.h index db584a3ff58d6a..b8bed27e2ab091 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/stack_frame_index_builder.h +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/stack_frame_index_builder.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_TRANSLATE_MHLO_TO_HLO_STACK_FRAME_INDEX_BUILDER_H_ -#define XLA_TRANSLATE_MHLO_TO_HLO_STACK_FRAME_INDEX_BUILDER_H_ +#ifndef XLA_HLO_TRANSLATE_MHLO_TO_HLO_STACK_FRAME_INDEX_BUILDER_H_ +#define XLA_HLO_TRANSLATE_MHLO_TO_HLO_STACK_FRAME_INDEX_BUILDER_H_ #include #include @@ -53,4 +53,4 @@ class StackFrameIndexBuilder { }; } // namespace mlir -#endif // XLA_TRANSLATE_MHLO_TO_HLO_STACK_FRAME_INDEX_BUILDER_H_ +#endif // XLA_HLO_TRANSLATE_MHLO_TO_HLO_STACK_FRAME_INDEX_BUILDER_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/BUILD b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/BUILD similarity index 94% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/BUILD rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/BUILD index 70abbacdf0394b..0ca5d928c976d8 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/BUILD +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/BUILD @@ -12,6 +12,7 @@ lit_test_suite( [ "add.mlir", "attributes.mlir", + "call.mlir", "case.mlir", "composite.mlir", "dynamic.mlir", @@ -23,6 +24,7 @@ lit_test_suite( "export_large_constants.mlir", "export_replicas.mlir", "frontend_attributes.mlir", + "function.mlir", "fusion.mlir", "if.mlir", "input_output_aliasing.mlir", @@ -49,7 +51,7 @@ lit_test_suite( cfg = "//xla:lit.cfg.py", data = [":test_utilities"], tools = [ - "//xla/translate:xla-translate", + "//xla/hlo/translate:xla-translate", "@llvm-project//llvm:FileCheck", "@llvm-project//llvm:not", ], diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/add.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/add.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/add.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/add.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/attributes.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/attributes.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/attributes.mlir diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/call.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/call.mlir new file mode 100644 index 00000000000000..4bf093139a8595 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/call.mlir @@ -0,0 +1,35 @@ +// RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s + +module @call_with_backend_config { + func.func @main(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { + // CHECK: ENTRY %main.{{[0-9]+}} ([[ARG0:Arg_0.[0-9]+]]: s32[8,2]) -> s32[8,2] { + // CHECK-NEXT: %[[ARG0]] = s32[8,2] parameter(0) + // CHECK-NEXT: s32[8,2] call(s32[8,2] %[[ARG0]]), to_apply=%g.{{[0-9.]+}}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"device_type":"DEVICE_TYPE_HOST","used_scoped_memory_configs":[]} + %0 = call @g.2(%arg0) {mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}} : (tensor<8x2xi32>) -> tensor<8x2xi32> + %1 = mhlo.custom_call @MoveToHost(%0) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> + return %1 : tensor<8x2xi32> + } + + func.func private @g.2(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { + %0 = mhlo.multiply %arg0, %arg0 : tensor<8x2xi32> + return %0 : tensor<8x2xi32> + } +} + +// ----- + +module @call_with_sharding { + func.func @main(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { + // CHECK: ENTRY %main.{{[0-9]+}} ([[ARG0:Arg_0.[0-9]+]]: s32[8,2]) -> s32[8,2] { + // CHECK-NEXT: %[[ARG0]] = s32[8,2] parameter(0) + // CHECK-NEXT: s32[8,2] call(s32[8,2] %[[ARG0]]), to_apply=%g.{{[0-9.]+}}, sharding={devices=[2,2]<=[4]} + %0 = call @g.2(%arg0) {mhlo.sharding = "{devices=[2,2]<=[4]}"} : (tensor<8x2xi32>) -> tensor<8x2xi32> + %1 = mhlo.custom_call @MoveToHost(%0) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> + return %1 : tensor<8x2xi32> + } + + func.func private @g.2(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { + %0 = mhlo.multiply %arg0, %arg0 : tensor<8x2xi32> + return %0 : tensor<8x2xi32> + } +} diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/case.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/case.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/case.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/case.mlir diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/composite.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/composite.mlir new file mode 100644 index 00000000000000..0147971a96c32c --- /dev/null +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/composite.mlir @@ -0,0 +1,190 @@ +// RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s + +module @composite { + // CHECK: HloModule composite, entry_computation_layout={()->f32[]} + // CHECK: %[[ADD:add.[0-9]+]] ([[ARG0:Arg_0.[0-9]+]]: f32[]) -> f32[] { + // CHECK: %[[ARG0]] = f32[] parameter(0) + // CHECK: %[[CONSTANT:constant.[0-9]+]] = f32[] constant(2) + // CHECK: ROOT %add.{{[0-9]+}} = f32[] add(f32[] %[[ARG0]], f32[] %[[CONSTANT]]) + // CHECK: } + // CHECK: ENTRY %main.{{[0-9]+}} () -> f32[] { + // CHECK: %[[CONSTANT:constant.[0-9]+]] = f32[] constant(42) + // CHECK: ROOT %call.{{[0-9]+}} = f32[] call(f32[] %[[CONSTANT]]), to_apply=%[[ADD]], is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="1"} + // CHECK: } + func.func @main() -> tensor { + %0 = mhlo.constant dense<4.200000e+01> : tensor + %1 = mhlo.composite "foo.bar" %0 { + composite_attributes = { + n = 1 : i32, + tensor = dense<1> : tensor + }, + decomposition = @add, + version = 1 : i32 + } : (tensor) -> tensor + return %1 : tensor + } + func.func @add(%arg0: tensor) -> tensor { + %0 = mhlo.constant dense<2.000000e+00> : tensor + %1 = mhlo.add %arg0, %0 : tensor + return %1 : tensor + } +} + +// ----- + +// zero-output composite +module @composite { + //CHECK: HloModule composite, entry_computation_layout={()->()} + //CHECK: %[[RETURN:return.[0-9]+]] ([[ARG:Arg_0.[0-9]+]]: f32[]) -> () { + //CHECK: %[[ARG]] = f32[] parameter(0) + //CHECK: ROOT %tuple.{{[0-9]+}} = () tuple() + //CHECK: } + //CHECK: ENTRY %main.{{[0-9]+}} () -> () { + //CHECK: %[[CONSTANT:constant.[0-9]+]] = f32[] constant(42) + //CHECK: %call.5 = () call(f32[] %[[CONSTANT]]), to_apply=%[[RETURN]], is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="1"} + //CHECK: ROOT %tuple.{{[0-9]+}} = () tuple() + //CHECK: } + func.func @main() -> () { + %0 = mhlo.constant dense<4.200000e+01> : tensor + "mhlo.composite"(%0) { + name = "foo.bar", + composite_attributes = { + n = 1 : i32, + tensor = dense<1> : tensor + }, + decomposition = @return, + version = 1 : i32 + } : (tensor) -> () + return + } + func.func @return(%arg0: tensor) -> () { + return + } +} + +// ----- + +// multi-output composite +module @composite { + //CHECK: HloModule composite, entry_computation_layout={()->(f32[], f32[])} + //CHECK: %[[ADD:add.[0-9]+]] ([[ARG:Arg_0.[0-9]+]]: f32[]) -> (f32[], f32[]) { + //CHECK: %[[ARG]] = f32[] parameter(0) + //CHECK: %[[CONSTANT:constant.[0-9]+]] = f32[] constant(2) + //CHECK: %[[ADDOP:add.[0-9]+]] = f32[] add(f32[] %[[ARG]], f32[] %[[CONSTANT]]) + //CHECK: ROOT %tuple.{{[0-9]+}} = (f32[], f32[]) tuple(f32[] %[[ADDOP]], f32[] %[[ADDOP]]) + //CHECK: } + //CHECK: ENTRY %main.{{[0-9]+}} () -> (f32[], f32[]) { + //CHECK: %[[CONSTANT:constant.[0-9]+]] = f32[] constant(42) + //CHECK: %[[CALL:call.[0-9]+]] = (f32[], f32[]) call(f32[] %[[CONSTANT]]), to_apply=%[[ADD]], is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="1"} + //CHECK: %[[GTE0:get-tuple-element.[0-9]+]] = f32[] get-tuple-element((f32[], f32[]) %[[CALL]]), index=0 + //CHECK: %[[GTE1:get-tuple-element.[0-9]+]] = f32[] get-tuple-element((f32[], f32[]) %[[CALL]]), index=1 + //CHECK: ROOT %tuple.{{[0-9]+}} = (f32[], f32[]) tuple(f32[] %[[GTE0]], f32[] %[[GTE1]]) + //CHECK: } + func.func @main() -> (tensor, tensor) { + %0 = mhlo.constant dense<4.200000e+01> : tensor + %result:2 = "mhlo.composite"(%0) { + name = "foo.bar", + composite_attributes = { + n = 1 : i32, + tensor = dense<1> : tensor + }, + decomposition = @add, + version = 1 : i32 + } : (tensor) -> (tensor, tensor) + return %result#0, %result#1 : tensor, tensor + } + func.func @add(%arg0: tensor) -> (tensor, tensor) { + %0 = mhlo.constant dense<2.000000e+00> : tensor + %1 = mhlo.add %arg0, %0 : tensor + return %1, %1 : tensor, tensor + } +} + +// ----- + +// optional composite attributes +module @composite { + // CHECK: HloModule composite, entry_computation_layout={()->f32[]} + // CHECK: %[[ADD:add.[0-9]+]] ([[ARG:Arg_0.[0-9]+]]: f32[]) -> f32[] { + // CHECK: %[[ARG]] = f32[] parameter(0) + // CHECK: %[[CONSTANT:constant.[0-9]+]] = f32[] constant(2) + // CHECK: ROOT %add.{{[0-9]+}} = f32[] add(f32[] %[[ARG]], f32[] %[[CONSTANT]]) + // CHECK: } + // CHECK: ENTRY %main.{{[0-9]+}} () -> f32[] { + // CHECK: %[[CONSTANT:constant.[0-9]+]] = f32[] constant(42) + // CHECK: ROOT %call.{{[0-9]+}} = f32[] call(f32[] %[[CONSTANT]]), to_apply=%[[ADD]], is_composite=true, frontend_attributes={composite.attributes={},composite.name="foo.bar",composite.version="1"} + // CHECK: } + func.func @main() -> tensor { + %0 = mhlo.constant dense<4.200000e+01> : tensor + %1 = mhlo.composite "foo.bar" %0 { + decomposition = @add, + version = 1 : i32 + } : (tensor) -> tensor + return %1 : tensor + } + func.func @add(%arg0: tensor) -> tensor { + %0 = mhlo.constant dense<2.000000e+00> : tensor + %1 = mhlo.add %arg0, %0 : tensor + return %1 : tensor + } +} + +// ----- + +// optional composite version +module @composite { + // CHECK: HloModule composite, entry_computation_layout={()->f32[]} + // CHECK: %[[ADD:add.[0-9]+]] ([[ARG:Arg_0.[0-9]+]]: f32[]) -> f32[] { + // CHECK: %[[ARG]] = f32[] parameter(0) + // CHECK: %[[CONSTANT:constant.[0-9]+]] = f32[] constant(2) + // CHECK: ROOT %add.{{[0-9]+}} = f32[] add(f32[] %[[ARG]], f32[] %[[CONSTANT]]) + // CHECK: } + // CHECK: ENTRY %main.{{[0-9]+}} () -> f32[] { + // CHECK: %[[CONSTANT:constant.[0-9]+]] = f32[] constant(42) + // CHECK: ROOT %call.{{[0-9]+}} = f32[] call(f32[] %[[CONSTANT]]), to_apply=%[[ADD]], is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="0"} + // CHECK: } + func.func @main() -> tensor { + %0 = mhlo.constant dense<4.200000e+01> : tensor + %1 = mhlo.composite "foo.bar" %0 { + composite_attributes = { + n = 1 : i32, + tensor = dense<1> : tensor + }, + decomposition = @add + } : (tensor) -> tensor + return %1 : tensor + } + func.func @add(%arg0: tensor) -> tensor { + %0 = mhlo.constant dense<2.000000e+00> : tensor + %1 = mhlo.add %arg0, %0 : tensor + return %1 : tensor + } +} + +// ----- + +// optional composite attributes and version +module @composite { + // CHECK: HloModule composite, entry_computation_layout={()->f32[]} + // CHECK: %[[ADD:add.[0-9]+]] ([[ARG:Arg_0.[0-9]+]]: f32[]) -> f32[] { + // CHECK: %[[ARG]] = f32[] parameter(0) + // CHECK: %[[CONSTANT:constant.[0-9]+]] = f32[] constant(2) + // CHECK: ROOT %add.{{[0-9]+}} = f32[] add(f32[] %[[ARG]], f32[] %[[CONSTANT]]) + // CHECK: } + // CHECK: ENTRY %main.{{[0-9]+}} () -> f32[] { + // CHECK: %[[CONSTANT:constant.[0-9]+]] = f32[] constant(42) + // CHECK: ROOT %call.{{[0-9]+}} = f32[] call(f32[] %[[CONSTANT]]), to_apply=%[[ADD]], is_composite=true, frontend_attributes={composite.attributes={},composite.name="foo.bar",composite.version="0"} + // CHECK: } + func.func @main() -> tensor { + %0 = mhlo.constant dense<4.200000e+01> : tensor + %1 = mhlo.composite "foo.bar" %0 { + decomposition = @add + } : (tensor) -> tensor + return %1 : tensor + } + func.func @add(%arg0: tensor) -> tensor { + %0 = mhlo.constant dense<2.000000e+00> : tensor + %1 = mhlo.add %arg0, %0 : tensor + return %1 : tensor + } +} diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/dynamic.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/dynamic.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/dynamic.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/dynamic.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export-with-layouts.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export-with-layouts.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/export-with-layouts.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export-with-layouts.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir similarity index 99% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir index b4e7a128a5d1ed..17b686cc2f5ebe 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export.mlir @@ -600,6 +600,12 @@ func.func @main() { // CHECK: f8e5m2fnuz[4] constant({1, 2, 3, 4}) %cst_15 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E5M2FNUZ> + // CHECK: f8e4m3[4] constant({1, 2, 3, 4}) + %cst_16 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E4M3> + + // CHECK: f8e3m4[4] constant({1, 2, 3, 4}) + %cst_17 = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf8E3M4> + func.return } @@ -729,7 +735,11 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { %5 = "mhlo.convert"(%4) : (tensor<2xf8E4M3FNUZ>) -> tensor<2xf32> %6 = "mhlo.convert"(%5) : (tensor<2xf32>) -> tensor<2xf8E5M2FNUZ> %7 = "mhlo.convert"(%6) : (tensor<2xf8E5M2FNUZ>) -> tensor<2xf32> - func.return %7 : tensor<2xf32> + %8 = "mhlo.convert"(%7) : (tensor<2xf32>) -> tensor<2xf8E4M3> + %9 = "mhlo.convert"(%8) : (tensor<2xf8E4M3>) -> tensor<2xf32> + %10 = "mhlo.convert"(%9) : (tensor<2xf32>) -> tensor<2xf8E3M4> + %11 = "mhlo.convert"(%10) : (tensor<2xf8E3M4>) -> tensor<2xf32> + func.return %11 : tensor<2xf32> } // CHECK: ENTRY @@ -741,7 +751,11 @@ func.func @main(%arg0: tensor<2xf32>) -> tensor<2xf32> { // CHECK: %[[E4M3FNUZ_VAL:.*]] = f8e4m3fnuz[2] convert(f32[2] %[[F32_VAL2]]) // CHECK: %[[F32_VAL3:.*]] = f32[2] convert(f8e4m3fnuz[2] %[[E4M3FNUZ_VAL]]) // CHECK: %[[E5M2FNUZ_VAL:.*]] = f8e5m2fnuz[2] convert(f32[2] %[[F32_VAL3]]) -// CHECK: ROOT %[[RESULT:.*]] = f32[2] convert(f8e5m2fnuz[2] %[[E5M2FNUZ_VAL]]) +// CHECK: %[[F32_VAL4:.*]] = f32[2] convert(f8e5m2fnuz[2] %[[E5M2FNUZ_VAL]]) +// CHECK: %[[E4M3_VAL:.*]] = f8e4m3[2] convert(f32[2] %[[F32_VAL4]]) +// CHECK: %[[F32_VAL5:.*]] = f32[2] convert(f8e4m3[2] %[[E4M3_VAL]]) +// CHECK: %[[E3M4_VAL:.*]] = f8e3m4[2] convert(f32[2] %[[F32_VAL5]]) +// CHECK: ROOT %[[F32_VAL6:.*]] = f32[2] convert(f8e3m4[2] %[[E3M4_VAL]]) // ----- @@ -2442,7 +2456,7 @@ func.func @main(%input0: tensor<16x16xf32>, %input1: tensor<16x16xi32>) { } // CHECK: %[[SORT_CMP:.*]] ([[ARG0:.*]]: f32[], [[ARG1:.*]]: f32[], {{.*}}: s32[], {{.*}}: s32[]) -> pred[] { -// CHECK: ROOT %compare.8 = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GT +// CHECK: ROOT %compare.{{[0-9+]}} = pred[] compare(f32[] %[[ARG0]], f32[] %[[ARG1]]), direction=GT // CHECK: [[SORT:%.+]] = (f32[16,16], s32[16,16]) sort(f32[16,16] %Arg_0.1, s32[16,16] %Arg_1.2), dimensions={1}, is_stable=true, to_apply=%[[SORT_CMP]] // CHECK: [[GET0:%.+]] = f32[16,16] get-tuple-element((f32[16,16], s32[16,16]) [[SORT]]), index=0 diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export_and_check_layouts.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export_and_check_layouts.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/export_and_check_layouts.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export_and_check_layouts.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export_async.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/export_async.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export_async.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export_entry_computation_layout.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export_entry_computation_layout.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/export_entry_computation_layout.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export_entry_computation_layout.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export_large_constants.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export_large_constants.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/export_large_constants.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export_large_constants.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export_replicas.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export_replicas.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/export_replicas.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/export_replicas.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/frontend_attributes.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/frontend_attributes.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/frontend_attributes.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/frontend_attributes.mlir diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/function.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/function.mlir new file mode 100644 index 00000000000000..29ac479e024ef3 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/function.mlir @@ -0,0 +1,18 @@ +// RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s + +module @non_entry_function_shardings { + func.func @main(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { + %0 = call @called_computation(%arg0) : (tensor<8x2xi32>) -> tensor<8x2xi32> + return %0 : tensor<8x2xi32> + } + + // CHECK: %called_computation.{{[0-9]+}} (Arg_0.{{[0-9]+}}: s32[8,2]) -> s32[8,2] { + // CHECK-NEXT: %[[ARG:.*]] = s32[8,2] parameter(0), sharding={devices=[2,2]<=[4]} + // CHECK-NEXT: %[[MULT:.*]] = s32[8,2] multiply(s32[8,2] %[[ARG]], s32[8,2] %[[ARG]]) + // CHECK-NEXT: %[[TUPLE:.*]] = (s32[8,2]) tuple(s32[8,2] %[[MULT]]) + // CHECK-NEXT: ROOT %get-tuple-element.{{[0-9]+}} = s32[8,2] get-tuple-element((s32[8,2]) %[[TUPLE]]), index=0, sharding={devices=[2,2]<=[4]} + func.func private @called_computation(%arg0: tensor<8x2xi32> {mhlo.sharding = "{devices=[2,2]<=[4]}"}) -> (tensor<8x2xi32> {mhlo.sharding = "{devices=[2,2]<=[4]}"}) { + %0 = mhlo.multiply %arg0, %arg0 : tensor<8x2xi32> + return %0 : tensor<8x2xi32> + } +} diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/fusion.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/fusion.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/fusion.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/fusion.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/if.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/if.mlir similarity index 99% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/if.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/if.mlir index 6c4a3faf9ae012..4c414422124736 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/if.mlir +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/if.mlir @@ -279,7 +279,7 @@ func.func @main(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> te // CHECK-DAG: %[[CST2:.+]] = f32[] constant(10) // CHECK-DAG: %[[TUPLE2:.+]] = () tuple() // CHECK: %[[COND2:.+]] = f32[] conditional(pred[] %{{.+}}, f32[] %[[CST2]], () %[[TUPLE2]]), true_computation=[[R0]], false_computation=[[R1]] -// CHECK: ROOT %tuple.18 = (f32[], f32[]) tuple +// CHECK: ROOT %tuple.{{[0-9]+}} = (f32[], f32[]) tuple // CHECK-NEXT: } // CHECK: [[R3:%.+]] ([[A3_TUPLE:.+]]: (f32[], f32[])) -> (f32[], f32[]) { diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/input_output_aliasing.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/input_output_aliasing.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/input_output_aliasing.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/input_output_aliasing.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/int4.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/int4.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/int4.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/int4.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/layouts_and_names.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/layouts_and_names.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/layouts_and_names.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/layouts_and_names.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/location_to_op_metadata.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/location_to_op_metadata.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/location_to_op_metadata.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/location_to_op_metadata.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/location_to_stacktrace.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/location_to_stacktrace.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/location_to_stacktrace.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/location_to_stacktrace.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/missing_main.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/missing_main.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/missing_main.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/missing_main.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/module_attributes.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/module_attributes.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/module_config.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/module_config.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/module_config.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/module_config.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/multiple_return_tuple.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/multiple_return_tuple.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/multiple_return_tuple.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/multiple_return_tuple.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/opaque_elements_attr.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/opaque_elements_attr.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/opaque_elements_attr.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/opaque_elements_attr.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/rng_get_and_update_state.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/rng_get_and_update_state.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/rng_get_and_update_state.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/rng_get_and_update_state.mlir diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/sharding.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/sharding.mlir new file mode 100644 index 00000000000000..b7255055f4b372 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/sharding.mlir @@ -0,0 +1,362 @@ +// RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s + +// CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: f32[], Arg_1.2: f32[4]) -> f32[4,4] +func.func public @main(%arg0: tensor {mhlo.sharding = ""}, %arg1: tensor<4xf32> {mhlo.sharding = "\08\03\1A\01\02\22\02\00\01"}) -> (tensor<4x4xf32> {mhlo.sharding = "\08\03\1A\02\02\01\22\02\00\01"}) { + // CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1), sharding={devices=[2]0,1} + // CHECK-NEXT: %Arg_0.1 = f32[] parameter(0), sharding={replicated} + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4xf32> + %1 = mhlo.multiply %arg1, %0 : tensor<4xf32> + %2 = "mhlo.broadcast_in_dim"(%1) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<4xf32>) -> tensor<4x4xf32> + // CHECK: ROOT {{.*}}, sharding={devices=[2,1]0,1} + func.return %2 : tensor<4x4xf32> +} + +// ----- + +// CHECK-LABEL: ENTRY %main.{{.*}} ({{[^,]*}}: f32[5,8,128]) -> f32[5,8,128] +func.func @main(%arg0: tensor<5x8x128xf32> {mhlo.sharding = "\08\03\1A\03\01\02\01\22\02\00\01"}) -> (tensor<5x8x128xf32> {mhlo.sharding = "\08\03\1A\03\01\02\01\22\02\00\01"}) { + // CHECK-NEXT: %Arg_0.1 = f32[5,8,128] parameter(0), sharding={devices=[1,2,1]0,1} + // CHECK-NEXT: %custom-call.2 = f32[5,8,128] custom-call(f32[5,8,128] %Arg_0.1), custom_call_target="Sharding", sharding={devices=[1,2,1]0,1} + // CHECK-NEXT: %tuple.3 = (f32[5,8,128]) tuple(f32[5,8,128] %custom-call.2) + // CHECK-NEXT: ROOT %get-tuple-element.4 = f32[5,8,128] get-tuple-element((f32[5,8,128]) %tuple.3), index=0 + // CHECK-SAME: sharding={devices=[1,2,1]0,1} + %0 = "mhlo.custom_call"(%arg0) {call_target_name = "Sharding", + mhlo.sharding = "\08\03\1A\03\01\02\01\22\02\00\01" + } : (tensor<5x8x128xf32>) -> tensor<5x8x128xf32> + func.return %0 : tensor<5x8x128xf32> +} + +// ----- + +// CHECK-LABEL: ENTRY %main.{{.*}} ({{[^,]*}}: f32[4,4]) -> (f32[4,4], f32[4,4]) +func.func @main(%arg0: tensor<4x4xf32>) -> (tensor<4x4xf32> {mhlo.sharding = "\08\03\1A\03\02\01\02\22\04\00\01\02\03B\01\00"}, tensor<4x4xf32>) { + // CHECK-NEXT: %Arg_0.1 = f32[4,4] parameter(0) + // CHECK-NEXT: [[RESHAPE_0:%.*]] = f32[4,4] reshape(f32[4,4] %Arg_0.1), sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} + // CHECK-NEXT: [[RESHAPE_1:%.*]] = f32[4,4] reshape(f32[4,4] %Arg_0.1) + // CHECK-NOT: sharding + // CHECK-NEXT: ROOT {{%.*}} = (f32[4,4], f32[4,4]) tuple(f32[4,4] [[RESHAPE_0]], f32[4,4] [[RESHAPE_1]]) + // CHECK-SAME: sharding={{\{}}{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}, {replicated}} + return %arg0, %arg0 : tensor<4x4xf32>, tensor<4x4xf32> +} + +// ----- + +// CHECK-LABEL: ENTRY %main.{{.*}} () -> f32[4] +func.func @main() -> (tensor<4xf32>) { + // CHECK-NEXT: %constant.1 = f32[] constant(3.1415925) + // CHECK-NEXT: %broadcast.2 = f32[4] broadcast(f32[] %constant.1), dimensions={}, sharding={devices=[2]0,1} + // CHECK-NEXT: ROOT %add.3 = f32[4] add(f32[4] %broadcast.2, f32[4] %broadcast.2) + %0 = mhlo.constant {mhlo.sharding = "{devices=[2]0,1}"} dense<3.1415926> : tensor<4xf32> + %1 = mhlo.add %0, %0 : tensor<4xf32> + return %1 : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: ENTRY %main.{{.*}} () -> f32[12,24,36] +func.func @main() -> (tensor<12x24x36xf32>) { + // CHECK-NEXT: %constant.1 = f32[] constant(3.1415925) + // CHECK-NEXT: %broadcast.2 = f32[12,24,36] broadcast(f32[] %constant.1), dimensions={}, sharding={devices=[1,2,1]0,1} + // CHECK-NEXT: ROOT %add.3 = f32[12,24,36] add(f32[12,24,36] %broadcast.2, f32[12,24,36] %broadcast.2) + %0 = mhlo.constant {mhlo.sharding = "{devices=[1,2,1]0,1}"} dense<3.1415926> : tensor<12x24x36xf32> + %1 = mhlo.add %0, %0 : tensor<12x24x36xf32> + return %1 : tensor<12x24x36xf32> +} + +// ----- + +// CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: u64[2]) -> (u64[2], u32[512,4]) +func.func @main(%arg0: tensor<2xui64>) -> (tensor<2xui64> {mhlo.sharding = "{devices=[2,16]<=[32] last_tile_dim_replicate}"}, tensor<512x4xui32> {mhlo.sharding = "{devices=[4,8]<=[32]}"}) { + // CHECK-NEXT: %Arg_0.1 = u64[2] parameter(0) + // CHECK-NEXT: %rng-bit-generator.2 = (u64[2], u32[512,4]) rng-bit-generator(u64[2] %Arg_0.1), algorithm=rng_default, sharding={{\{}}{replicated}, {devices=[8,4]<=[32]}} + // CHECK-NEXT: %get-tuple-element.3 = u64[2] get-tuple-element((u64[2], u32[512,4]) %rng-bit-generator.2), index=0, sharding={replicated} + // CHECK-NEXT: %add.5 = u64[2] add(u64[2] %get-tuple-element.3, u64[2] %get-tuple-element.3) + // CHECK-NEXT: %reshape.6 = u64[2] reshape(u64[2] %add.5) + // CHECK-NEXT: %get-tuple-element.4 = u32[512,4] get-tuple-element((u64[2], u32[512,4]) %rng-bit-generator.2), index=1, sharding={devices=[8,4]<=[32]} + // CHECK-NEXT: %reshape.7 = u32[512,4] reshape(u32[512,4] %get-tuple-element.4) + // CHECK-NEXT: ROOT %tuple.8 = (u64[2], u32[512,4]) tuple(u64[2] %reshape.6, u32[512,4] %reshape.7), sharding={{\{}}{devices=[2,16]<=[32] last_tile_dim_replicate}, {devices=[4,8]<=[32]}} + %output_state, %output = "mhlo.rng_bit_generator"(%arg0) <{rng_algorithm = #mhlo.rng_algorithm}> {mhlo.sharding = "{{replicated}, {devices=[8,4]<=[32]}}"} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<512x4xui32>) + %0 = mhlo.add %output_state, %output_state : tensor<2xui64> + return %0, %output : tensor<2xui64>, tensor<512x4xui32> +} + +// ----- + +// CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: u64[2]) -> (u64[2], u32[512,4]) +func.func @main(%arg0: tensor<2xui64>) -> (tensor<2xui64>, tensor<512x4xui32>) { + // CHECK-NEXT: %Arg_0.1 = u64[2] parameter(0) + // CHECK-NEXT: %rng-bit-generator.2 = (u64[2], u32[512,4]) rng-bit-generator(u64[2] %Arg_0.1), algorithm=rng_default, sharding={{\{}}{replicated}, {replicated}} + // CHECK-NEXT: %get-tuple-element.3 = u64[2] get-tuple-element((u64[2], u32[512,4]) %rng-bit-generator.2), index=0, sharding={replicated} + // CHECK-NEXT: %add.5 = u64[2] add(u64[2] %get-tuple-element.3, u64[2] %get-tuple-element.3) + // CHECK-NEXT: %get-tuple-element.4 = u32[512,4] get-tuple-element((u64[2], u32[512,4]) %rng-bit-generator.2), index=1, sharding={replicated} + // CHECK-NEXT: ROOT %tuple.6 = (u64[2], u32[512,4]) tuple(u64[2] %add.5, u32[512,4] %get-tuple-element.4) + %output_state, %output = "mhlo.rng_bit_generator"(%arg0) <{rng_algorithm = #mhlo.rng_algorithm}> {mhlo.sharding = "{replicated}"} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<512x4xui32>) + %0 = mhlo.add %output_state, %output_state : tensor<2xui64> + return %0, %output : tensor<2xui64>, tensor<512x4xui32> +} + +// ----- + +// CHECK-LABEL: HloModule main + +// CHECK: %[[BODY:region_0.[0-9]+]] ([[ARG:Arg_.[0-9]+]]: s32[]) -> s32[] { +// CHECK-NEXT: %[[ARG]] = s32[] parameter(0), sharding={replicated} +// CHECK-NEXT: %[[ADD:add.[0-9]+]] = s32[] add(s32[] %[[ARG]], s32[] %[[ARG]]) +// CHECK-NEXT: %[[TUPLE:tuple.[0-9]+]] = (s32[]) tuple(s32[] %[[ADD]]) +// CHECK-NEXT: ROOT %get-tuple-element.{{[0-9]+}} = s32[] get-tuple-element((s32[]) %[[TUPLE]]), index=0, sharding={replicated} + +// CHECK: %[[COND:region_1.[0-9]+]] ([[ARG:Arg_.[0-9]+]]: s32[]) -> pred[] { +// CHECK-NEXT: %[[ARG]] = s32[] parameter(0), sharding={replicated} +// CHECK-NEXT: ROOT %compare.{{[0-9]+}} = pred[] compare(s32[] %[[ARG]], s32[] %[[ARG]]), direction=LT + +// CHECK: ENTRY %main.{{[0-9]+}} ([[ARG:Arg_0.[0-9]+]]: s32[]) -> s32[] { +// CHECK-NEXT: %[[ARG]] = s32[] parameter(0) +// CHECK-NEXT: ROOT %while.10 = s32[] while(s32[] %[[ARG]]), condition=%[[COND]], body=%[[BODY]], sharding={replicated} + +func.func @main(%arg0: tensor) -> tensor { + %0 = mhlo.while(%iterArg = %arg0) : tensor attributes {mhlo.sharding = "{replicated}"} + cond { + %1 = mhlo.compare LT, %iterArg, %iterArg : (tensor, tensor) -> tensor + mhlo.return %1 : tensor + } do { + %1 = mhlo.add %iterArg, %iterArg : tensor + mhlo.return %1 : tensor + } + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: HloModule main + +// CHECK: %[[BODY:region_0.[0-9]+]] ([[ARG_TUPLE:arg_tuple.[0-9]+]]: (s32[], f32[4], f32[4])) -> (s32[], f32[4], f32[4]) { +// CHECK-NEXT: %[[ARG_TUPLE]] = (s32[], f32[4], f32[4]) parameter(0) +// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {devices=[4]<=[4]}} +// CHECK-NEXT: %[[GTE0:get-tuple-element.[0-9]+]] = s32[] get-tuple-element((s32[], f32[4], f32[4]) %[[ARG_TUPLE]]), index=0, sharding={replicated} +// CHECK-NEXT: %[[GTE1:get-tuple-element.[0-9]+]] = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %[[ARG_TUPLE]]), index=1, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} +// CHECK-NEXT: %[[GTE2:get-tuple-element.[0-9]+]] = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %[[ARG_TUPLE]]), index=2, sharding={devices=[4]<=[4]} +// CHECK-NEXT: %[[ADD:add.[0-9]+]] = f32[4] add(f32[4] %[[GTE1]], f32[4] %[[GTE2]]) +// CHECK-NEXT: ROOT %tuple.{{[0-9]+}} = (s32[], f32[4], f32[4]) tuple(s32[] %[[GTE0]], f32[4] %[[ADD]], f32[4] %[[GTE2]]) +// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {devices=[4]<=[4]}} + +// CHECK: %[[COND:region_1.[0-9]+]] ([[ARG_TUPLE:arg_tuple.[0-9]+]]: (s32[], f32[4], f32[4])) -> pred[] { +// CHECK-NEXT: %[[ARG_TUPLE]] = (s32[], f32[4], f32[4]) parameter(0) +// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {devices=[4]<=[4]}} +// CHECK-NEXT: %[[GTE15:get-tuple-element.[0-9]+]] = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %[[ARG_TUPLE]]), index=1, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} +// CHECK-NEXT: %[[GTE16:get-tuple-element.[0-9]+]] = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %[[ARG_TUPLE]]), index=2, sharding={devices=[4]<=[4]} +// CHECK-NEXT: %[[GTE14:get-tuple-element.[0-9]+]] = s32[] get-tuple-element((s32[], f32[4], f32[4]) %[[ARG_TUPLE]]), index=0, sharding={replicated} +// CHECK-NEXT: ROOT %compare.{{[0-9]+}} = pred[] compare(s32[] %[[GTE14]], s32[] %[[GTE14]]), direction=LT + +// CHECK: ENTRY %main.{{[0-9]+}} ([[ARG0:Arg_0.[0-9]+]]: s32[], [[ARG1:Arg_1.[0-9]+]]: f32[4], [[ARG2:Arg_2.[0-9]+]]: f32[4]) -> (f32[4], f32[4]) { +// CHECK-NEXT: %[[ARG0]] = s32[] parameter(0) +// CHECK-NEXT: %[[ARG1]] = f32[4] parameter(1) +// CHECK-NEXT: %[[ARG2]] = f32[4] parameter(2) +// CHECK-NEXT: %[[TUPLE:tuple.[0-9]+]] = (s32[], f32[4], f32[4]) tuple(s32[] %[[ARG0]], f32[4] %[[ARG1]], f32[4] %[[ARG2]]) +// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {devices=[4]<=[4]}} +// CHECK-NEXT: %[[WHILE:while.[0-9]+]] = (s32[], f32[4], f32[4]) while((s32[], f32[4], f32[4]) %[[TUPLE]]), condition=%[[COND]], body=%[[BODY]] +// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {devices=[4]<=[4]}} +// CHECK-NEXT: %[[GTE19:get-tuple-element.[0-9]+]] = s32[] get-tuple-element((s32[], f32[4], f32[4]) %[[WHILE]]), index=0, sharding={replicated} +// CHECK-NEXT: %[[GTE20:get-tuple-element.[0-9]+]] = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %[[WHILE]]), index=1, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} +// CHECK-NEXT: %[[GTE21:get-tuple-element.[0-9]+]] = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %[[WHILE]]), index=2, sharding={devices=[4]<=[4]} +// CHECK-NEXT: ROOT %tuple.{{[0-9]+}} = (f32[4], f32[4]) tuple(f32[4] %[[GTE20]], f32[4] %[[GTE21]]) + +func.func @main(%arg0: tensor, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { + %0:3 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %arg1, %iterArg_1 = %arg2) : tensor, tensor<4xf32>, tensor<4xf32> + attributes {mhlo.sharding = "{{replicated},{devices=[2,2]<=[4] last_tile_dim_replicate},{devices=[4]<=[4]}}"} + cond { + %1 = mhlo.compare LT, %iterArg, %iterArg : (tensor, tensor) -> tensor + mhlo.return %1 : tensor + } do { + %1 = mhlo.add %iterArg_0, %iterArg_1 : tensor<4xf32> + mhlo.return %iterArg, %1, %iterArg_1 : tensor, tensor<4xf32>, tensor<4xf32> + } + func.return %0#1, %0#2 : tensor<4xf32>, tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: HloModule main + +// CHECK: %[[BODY:region_0.[0-9]+]] ([[ARG_TUPLE:arg_tuple.*]]: (s32[], f32[4], f32[4])) -> (s32[], f32[4], f32[4]) { +// CHECK-NEXT: %[[ARG_TUPLE]] = (s32[], f32[4], f32[4]) parameter(0) +// CHECK-SAME: sharding={{\{}}{manual}, {manual}, {manual}} +// CHECK-NEXT: %[[GTE7:get-tuple-element.*]] = s32[] get-tuple-element((s32[], f32[4], f32[4]) %[[ARG_TUPLE]]), index=0, sharding={manual} +// CHECK-NEXT: %[[GTE8:get-tuple-element.*]] = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %[[ARG_TUPLE]]), index=1, sharding={manual} +// CHECK-NEXT: %[[GTE9:get-tuple-element.*]] = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %[[ARG_TUPLE]]), index=2, sharding={manual} +// CHECK-NEXT: %[[ADD:add.*]] = f32[4] add(f32[4] %[[GTE8]], f32[4] %[[GTE9]]) +// CHECK-NEXT: ROOT %tuple.{{.*}} = (s32[], f32[4], f32[4]) tuple(s32[] %[[GTE7]], f32[4] %[[ADD]], f32[4] %[[GTE9]]) +// CHECK-SAME: sharding={{\{}}{manual}, {manual}, {manual}} + +// CHECK: %[[COND:region_1.[0-9]+]] ([[ARG_TUPLE:arg_tuple.*]]: (s32[], f32[4], f32[4])) -> pred[] { +// CHECK-NEXT: %[[ARG_TUPLE]] = (s32[], f32[4], f32[4]) parameter(0) +// CHECK-SAME: sharding={{\{}}{manual}, {manual}, {manual}} +// CHECK-NEXT: %[[GTE15:get-tuple-element.*]] = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %[[ARG_TUPLE]]), index=1, sharding={manual} +// CHECK-NEXT: %[[GTE16:get-tuple-element.*]] = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %[[ARG_TUPLE]]), index=2, sharding={manual} +// CHECK-NEXT: %[[GTE14:get-tuple-element.*]] = s32[] get-tuple-element((s32[], f32[4], f32[4]) %[[ARG_TUPLE]]), index=0, sharding={manual} +// CHECK-NEXT: ROOT %compare.{{.*}} = pred[] compare(s32[] %[[GTE14]], s32[] %[[GTE14]]), direction=LT + +// CHECK: ENTRY %main.{{.*}} ([[ARG0:Arg_0.*]]: s32[], [[ARG1:Arg_1.*]]: f32[4], [[ARG2:Arg_2.*]]: f32[4]) -> (f32[4], f32[4]) { +// CHECK-NEXT: %[[ARG0]] = s32[] parameter(0) +// CHECK-NEXT: %[[ARG1]] = f32[4] parameter(1) +// CHECK-NEXT: %[[ARG2]] = f32[4] parameter(2) +// CHECK-NEXT: %[[TUPLE:tuple.*]] = (s32[], f32[4], f32[4]) tuple(s32[] %[[ARG0]], f32[4] %[[ARG1]], f32[4] %[[ARG2]]) +// CHECK-SAME: sharding={{\{}}{manual}, {manual}, {manual}} +// CHECK-NEXT: %[[WHILE:while.*]] = (s32[], f32[4], f32[4]) while((s32[], f32[4], f32[4]) %[[TUPLE]]), condition=%[[COND]], body=%[[BODY]] +// CHECK-SAME: sharding={{\{}}{manual}, {manual}, {manual}} +// CHECK-NEXT: %[[GTE19:get-tuple-element.*]] = s32[] get-tuple-element((s32[], f32[4], f32[4]) %[[WHILE]]), index=0, sharding={manual} +// CHECK-NEXT: %[[GTE20:get-tuple-element.*]] = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %[[WHILE]]), index=1, sharding={manual} +// CHECK-NEXT: %[[GTE21:get-tuple-element.*]] = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %[[WHILE]]), index=2, sharding={manual} +// CHECK-NEXT: ROOT %tuple.{{.*}} = (f32[4], f32[4]) tuple(f32[4] %[[GTE20]], f32[4] %[[GTE21]]) + +func.func @main(%arg0: tensor, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { + %0:3 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %arg1, %iterArg_1 = %arg2) : tensor, tensor<4xf32>, tensor<4xf32> + attributes {mhlo.sharding = "{manual}"} + cond { + %1 = mhlo.compare LT, %iterArg, %iterArg : (tensor, tensor) -> tensor + mhlo.return %1 : tensor + } do { + %1 = mhlo.add %iterArg_0, %iterArg_1 : tensor<4xf32> + mhlo.return %iterArg, %1, %iterArg_1 : tensor, tensor<4xf32>, tensor<4xf32> + } + func.return %0#1, %0#2 : tensor<4xf32>, tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: HloModule main + +// CHECK: %[[BRANCH0:region_0.*]] ([[ARG_TUPLE:arg_tuple.*]]: (f32[4], f32[4])) -> (f32[4], f32[4]) { +// CHECK-NEXT: %[[ARG_TUPLE]] = (f32[4], f32[4]) parameter(0), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}} +// CHECK-NEXT: %[[GTE10:get-tuple-element.*]] = f32[4] get-tuple-element((f32[4], f32[4]) %[[ARG_TUPLE]]), index=0, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} +// CHECK-NEXT: %[[GTE11:get-tuple-element.*]] = f32[4] get-tuple-element((f32[4], f32[4]) %[[ARG_TUPLE]]), index=1 +// CHECK-NEXT: ROOT %tuple.{{.*}} = (f32[4], f32[4]) tuple(f32[4] %[[GTE10]], f32[4] %[[GTE11]]), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} + +// CHECK: %[[BRANCH1:region_1.*]] ([[ARG_TUPLE:arg_tuple.*]]: (f32[4], f32[4])) -> (f32[4], f32[4]) { +// CHECK-NEXT: %[[ARG_TUPLE]] = (f32[4], f32[4]) parameter(0), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} +// CHECK-NEXT: %[[GTE15:get-tuple-element.*]] = f32[4] get-tuple-element((f32[4], f32[4]) %[[ARG_TUPLE]]), index=0, sharding={replicated} +// CHECK-NEXT: %[[GTE16:get-tuple-element.*]] = f32[4] get-tuple-element((f32[4], f32[4]) %[[ARG_TUPLE]]), index=1, sharding={devices=[4]<=[4]} +// CHECK-NEXT: ROOT %tuple.{{.*}} = (f32[4], f32[4]) tuple(f32[4] %[[GTE15]], f32[4] %[[GTE16]]), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} + +// CHECK: ENTRY %main.{{.*}} ([[ARG0:Arg_0.*]]: s32[], [[ARG1:Arg_1.*]]: f32[4], [[ARG2:Arg_2.*]]: f32[4], [[ARG3:Arg_3.*]]: f32[4], [[ARG4:Arg_4.*]]: f32[4]) -> (f32[4], f32[4]) { +// CHECK-NEXT: %[[ARG0]] = s32[] parameter(0) +// CHECK-NEXT: %[[ARG1]] = f32[4] parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate} +// CHECK-NEXT: %[[ARG2]] = f32[4] parameter(2) +// CHECK-NEXT: %[[TUPLE6:tuple.*]] = (f32[4], f32[4]) tuple(f32[4] %[[ARG1]], f32[4] %[[ARG2]]), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}} +// CHECK-NEXT: %[[ARG3]] = f32[4] parameter(3), sharding={replicated} +// CHECK-NEXT: %[[ARG4]] = f32[4] parameter(4), sharding={devices=[4]<=[4]} +// CHECK-NEXT: %[[TUPLE7:tuple.*]] = (f32[4], f32[4]) tuple(f32[4] %[[ARG3]], f32[4] %[[ARG4]]), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} +// CHECK-NEXT: %[[COND:conditional.*]] = (f32[4], f32[4]) conditional(s32[] %[[ARG0]], (f32[4], f32[4]) %[[TUPLE6]], (f32[4], f32[4]) %[[TUPLE7]]), branch_computations={%[[BRANCH0]], %[[BRANCH1]]}, +// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[4]<=[4]}} +// CHECK-NEXT: %[[GTE19:get-tuple-element.*]] = f32[4] get-tuple-element((f32[4], f32[4]) %[[COND]]), index=0, sharding={replicated} +// CHECK-NEXT: %[[GTE20:get-tuple-element.*]] = f32[4] get-tuple-element((f32[4], f32[4]) %[[COND]]), index=1, sharding={devices=[4]<=[4]} +// CHECK-NEXT: ROOT %tuple.{{.*}} = (f32[4], f32[4]) tuple(f32[4] %[[GTE19]], f32[4] %[[GTE20]]) + +func.func @main(%arg0: tensor, + %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"}, + %arg2: tensor<4xf32>, + %arg3: tensor<4xf32> {mhlo.sharding = "{replicated}"}, + %arg4: tensor<4xf32> {mhlo.sharding = "{devices=[4]<=[4]}"}) -> (tensor<4xf32>, tensor<4xf32>) { + %0:2 = "mhlo.case"(%arg0) ( { + mhlo.return %arg1, %arg2 : tensor<4xf32>, tensor<4xf32> + }, { + mhlo.return %arg3, %arg4 : tensor<4xf32>, tensor<4xf32> + }) {mhlo.sharding = "{{replicated},{devices=[4]<=[4]}}"} : (tensor) -> (tensor<4xf32>, tensor<4xf32>) + func.return %0#0, %0#1 : tensor<4xf32>, tensor<4xf32> +} + + +// ----- + +// CHECK-LABEL: HloModule main + +// CHECK: %[[BRANCH0:region_0.*]] ([[ARG:Arg_.*]]: f32[4]) -> f32[4] { +// CHECK-NEXT: ROOT %[[ARG]] = f32[4] parameter(0) + +// CHECK: %[[BRANCH1:region_1.*]] ([[ARG:Arg_.*]]: f32[4]) -> f32[4] { +// CHECK-NEXT: ROOT %[[ARG]] = f32[4] parameter(0) + +// CHECK: ENTRY %main.{{.*}} ([[ARG0:Arg_0.*]]: s32[], [[ARG1:Arg_1.*]]: f32[4], [[ARG2:Arg_2.*]]: f32[4]) -> f32[4] { +// CHECK-NEXT: %[[ARG0]] = s32[] parameter(0) +// CHECK-NEXT: %[[ARG1]] = f32[4] parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate} +// CHECK-NEXT: %[[ARG2]] = f32[4] parameter(2) +// CHECK-NEXT: ROOT %conditional.{{.*}} = f32[4] conditional(s32[] %[[ARG0]], f32[4] %[[ARG1]], f32[4] %[[ARG2]]), branch_computations={%[[BRANCH0]], %[[BRANCH1]]} +func.func @main(%arg0: tensor, + %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"}, + %arg2: tensor<4xf32>) -> tensor<4xf32> { + %0 = "mhlo.case"(%arg0) ( { + mhlo.return %arg1 : tensor<4xf32> + }, { + mhlo.return %arg2 : tensor<4xf32> + }) : (tensor) -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: HloModule main + +// CHECK: %[[BRANCH0:region_0.*]] ([[ARG_TUPLE:arg_tuple.*]]: (f32[4], f32[4])) -> (f32[4], f32[4]) { +// CHECK-NEXT: %[[ARG_TUPLE]] = (f32[4], f32[4]) parameter(0), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}} +// CHECK-NEXT: %[[GTE10:get-tuple-element.*]] = f32[4] get-tuple-element((f32[4], f32[4]) %[[ARG_TUPLE]]), index=0, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} +// CHECK-NEXT: %[[GTE11:get-tuple-element.*]] = f32[4] get-tuple-element((f32[4], f32[4]) %[[ARG_TUPLE]]), index=1 +// CHECK-NEXT: ROOT %tuple.{{.*}} = (f32[4], f32[4]) tuple(f32[4] %[[GTE10]], f32[4] %[[GTE11]]), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} + +// CHECK: %[[BRANCH1:region_1.*]] ([[ARG_TUPLE:arg_tuple.*]]: (f32[4], f32[4])) -> (f32[4], f32[4]) { +// CHECK-NEXT: %[[ARG_TUPLE]] = (f32[4], f32[4]) parameter(0), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} +// CHECK-NEXT: %[[GTE15:get-tuple-element.*]] = f32[4] get-tuple-element((f32[4], f32[4]) %[[ARG_TUPLE]]), index=0, sharding={replicated} +// CHECK-NEXT: %[[GTE16:get-tuple-element.*]] = f32[4] get-tuple-element((f32[4], f32[4]) %[[ARG_TUPLE]]), index=1, sharding={devices=[4]<=[4]} +// CHECK-NEXT: ROOT %tuple.{{.*}} = (f32[4], f32[4]) tuple(f32[4] %[[GTE15]], f32[4] %[[GTE16]]), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} + +// CHECK: ENTRY %main.{{.*}} ([[ARG0:Arg_0.*]]: pred[], [[ARG1:Arg_1.*]]: f32[4], [[ARG2:Arg_2.*]]: f32[4], [[ARG3:Arg_3.*]]: f32[4], [[ARG4:Arg_4.*]]: f32[4]) -> (f32[4], f32[4]) { +// CHECK-NEXT: %[[ARG0]] = pred[] parameter(0) +// CHECK-NEXT: %[[ARG1]] = f32[4] parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate} +// CHECK-NEXT: %[[ARG2]] = f32[4] parameter(2) +// CHECK-NEXT: %[[TUPLE6:tuple.*]] = (f32[4], f32[4]) tuple(f32[4] %[[ARG1]], f32[4] %[[ARG2]]), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}} +// CHECK-NEXT: %[[ARG3]] = f32[4] parameter(3), sharding={replicated} +// CHECK-NEXT: %[[ARG4]] = f32[4] parameter(4), sharding={devices=[4]<=[4]} +// CHECK-NEXT: %[[TUPLE7:tuple.*]] = (f32[4], f32[4]) tuple(f32[4] %[[ARG3]], f32[4] %[[ARG4]]), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} +// CHECK-NEXT: %conditional.18 = (f32[4], f32[4]) conditional(pred[] %[[ARG0]], (f32[4], f32[4]) %[[TUPLE6]], (f32[4], f32[4]) %[[TUPLE7]]), true_computation=%[[BRANCH0]], false_computation=%[[BRANCH1]], +// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[4]<=[4]}} +// CHECK-NEXT: %[[GTE19:get-tuple-element.*]] = f32[4] get-tuple-element((f32[4], f32[4]) %conditional.18), index=0, sharding={replicated} +// CHECK-NEXT: %[[GTE20:get-tuple-element.*]] = f32[4] get-tuple-element((f32[4], f32[4]) %conditional.18), index=1, sharding={devices=[4]<=[4]} +// CHECK-NEXT: ROOT %tuple.{{.*}} = (f32[4], f32[4]) tuple(f32[4] %[[GTE19]], f32[4] %[[GTE20]]) + +func.func @main(%arg0: tensor, + %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"}, + %arg2: tensor<4xf32>, + %arg3: tensor<4xf32> {mhlo.sharding = "{replicated}"}, + %arg4: tensor<4xf32> {mhlo.sharding = "{devices=[4]<=[4]}"}) -> (tensor<4xf32>, tensor<4xf32>) { + %0:2 = "mhlo.if"(%arg0) ( { + mhlo.return %arg1, %arg2 : tensor<4xf32>, tensor<4xf32> + }, { + mhlo.return %arg3, %arg4 : tensor<4xf32>, tensor<4xf32> + }) {mhlo.sharding = "{{replicated},{devices=[4]<=[4]}}"} : (tensor) -> (tensor<4xf32>, tensor<4xf32>) + func.return %0#0, %0#1 : tensor<4xf32>, tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: HloModule main + +// CHECK: %[[TRUE:region_0.*]] ([[ARG:Arg_.*]]: f32[4]) -> f32[4] { +// CHECK-NEXT: ROOT %[[ARG]] = f32[4] parameter(0) + +// CHECK: %[[FALSE:region_1.*]] ([[ARG:Arg_.*]]: f32[4]) -> f32[4] { +// CHECK-NEXT: ROOT %[[ARG]] = f32[4] parameter(0) + +// CHECK: ENTRY %main.{{.*}} ([[ARG0:Arg_0.*]]: pred[], [[ARG1:Arg_1.*]]: f32[4], [[ARG2:Arg_2.*]]: f32[4]) -> f32[4] { +// CHECK-NEXT: %[[ARG0]] = pred[] parameter(0) +// CHECK-NEXT: %[[ARG1]] = f32[4] parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate} +// CHECK-NEXT: %[[ARG2]] = f32[4] parameter(2) +// CHECK-NEXT: ROOT %conditional.{{.*}} = f32[4] conditional(pred[] %[[ARG0]], f32[4] %[[ARG1]], f32[4] %[[ARG2]]), true_computation=%[[TRUE]], false_computation=%[[FALSE]] + +func.func @main(%arg0: tensor, + %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"}, + %arg2: tensor<4xf32>) -> tensor<4xf32> { + %0 = "mhlo.if"(%arg0) ( { + mhlo.return %arg1 : tensor<4xf32> + }, { + mhlo.return %arg2 : tensor<4xf32> + }) : (tensor) -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/simple.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/simple.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/simple.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/simple.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/unsupported_type.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/unsupported_type.mlir similarity index 100% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/unsupported_type.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/unsupported_type.mlir diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/while.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/while.mlir similarity index 96% rename from third_party/xla/xla/translate/mhlo_to_hlo/tests/while.mlir rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/while.mlir index fddcb2bc61d7e4..e99a57fe7f9225 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/while.mlir +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/while.mlir @@ -6,10 +6,10 @@ module { %0 = "mhlo.while"(%arg0) ({ // CHECK: [[R0:%.+]] ([[A0:.+]]: s64[]) -> s64[] { // CHECK: %[[A0]] = s64[] parameter(0) - // CHECK: ROOT %add.4 = s64[] add(s64[] %[[A0]], s64[] %[[A0]]) + // CHECK: ROOT %add.{{.*}} = s64[] add(s64[] %[[A0]], s64[] %[[A0]]) // CHECK: [[R1:%.+]] ([[A0:.+]]: s64[]) -> pred[] { // CHECK: %[[A0]] = s64[] parameter(0) - // CHECK: ROOT %compare.7 = pred[] compare(s64[] %[[A0]], s64[] %[[A0]]), direction=LT + // CHECK: ROOT %compare.{{.*}} = pred[] compare(s64[] %[[A0]], s64[] %[[A0]]), direction=LT ^bb0(%arg1: tensor): %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor "mhlo.return"(%1) : (tensor) -> () @@ -19,9 +19,9 @@ module { "mhlo.return"(%1) : (tensor) -> () }) : (tensor) -> tensor - // CHECK: ENTRY %main.9 ([[A0:.+]]: s64[]) -> s64[] { + // CHECK: ENTRY %main.{{.*}} ([[A0:.+]]: s64[]) -> s64[] { // CHECK: %[[A0]] = s64[] parameter(0) - // CHECK: ROOT %while.8 = s64[] while(s64[] %[[A0]]), condition=[[R1]], body=[[R0]] + // CHECK: ROOT %while.{{.*}} = s64[] while(s64[] %[[A0]]), condition=[[R1]], body=[[R0]] func.return %0 : tensor } } @@ -103,7 +103,7 @@ func.func @main(%arg0: tensor) -> tensor { // CHECK-NEXT: %[[GTE_1:.*]] = f32[3] get-tuple-element((s32[1], s32[2], f32[1], f32[3]) %[[TUPLE_0]]), index=3 // CHECK-NEXT: %[[GTE_2:.*]] = s32[1] get-tuple-element((s32[1], s32[2], f32[1], f32[3]) %[[TUPLE_0]]), index=0 // CHECK-NEXT: %[[CST_0:.*]] = s32[] constant(0) -// CHECK-NEXT: %[[RED_0:.*]] = s32[] reduce(s32[1] %[[GTE_2]], s32[] %constant.32), dimensions={0}, to_apply= +// CHECK-NEXT: %[[RED_0:.*]] = s32[] reduce(s32[1] %[[GTE_2]], s32[] %[[CST_0]]), dimensions={0}, to_apply= // CHECK-NEXT: %[[GTE_3:.*]] = s32[2] get-tuple-element((s32[1], s32[2], f32[1], f32[3]) %[[TUPLE_0]]), index=1 // CHECK-NEXT: %[[RED_1:.*]] = s32[] reduce(s32[2] %[[GTE_3]], s32[] %[[CST_0]]), dimensions={0}, to_apply= // CHECK-NEXT: ROOT %[[CMP:.*]] = pred[] compare(s32[] %[[RED_0]], s32[] %[[RED_1]]), direction=LT @@ -111,8 +111,8 @@ func.func @main(%arg0: tensor) -> tensor { // CHECK: ENTRY // CHECK-NEXT: %[[CST_0:.*]] = s32[1] constant({0}) -// CHECK-NEXT: %[[CST_1:.*]].3 = s32[] constant(100) -// CHECK-NEXT: %[[BDCAST_0:.*]] = s32[2] broadcast(s32[] %constant.3), dimensions={} +// CHECK-NEXT: %[[CST_1:.*]] = s32[] constant(100) +// CHECK-NEXT: %[[BDCAST_0:.*]] = s32[2] broadcast(s32[] %[[CST_1]]), dimensions={} // CHECK-NEXT: %[[CST_2:.*]] = f32[1] constant({1}) // CHECK-NEXT: %[[ARG_0:.*]] = f32[3] parameter(0) // CHECK-NEXT: %[[TUPLE:.*]] = (s32[1], s32[2], f32[1], f32[3]) tuple(s32[1] %[[CST_0]], s32[2] %[[BDCAST_0]], f32[1] %[[CST_2]], f32[3] %[[ARG_0]]) @@ -285,7 +285,7 @@ func.func @main(%arg0: tensor<3x3xf32>) -> tensor<3x3xf32> { // CHECK-NEXT: %[[ARG_TUPLE:.*]] = (s32[], s32[]) parameter(0) // CHECK-NEXT: %[[GTE_0:.*]] = s32[] get-tuple-element((s32[], s32[]) %[[ARG_TUPLE]]), index=0 // CHECK-NEXT: %[[GTE_1:.*]] = s32[] get-tuple-element((s32[], s32[]) %[[ARG_TUPLE]]), index=1 -// CHECK-NEXT: ROOT %compare.17 = pred[] compare(s32[] %[[GTE_0]], s32[] %[[GTE_1]]), direction=LT +// CHECK-NEXT: ROOT %compare.{{.*}} = pred[] compare(s32[] %[[GTE_0]], s32[] %[[GTE_1]]), direction=LT // CHECK: } // CHECK: ENTRY diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/while_free_vars.mlir b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/while_free_vars.mlir new file mode 100644 index 00000000000000..b1ebb843cf8321 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/while_free_vars.mlir @@ -0,0 +1,89 @@ +// RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text %s -o - | FileCheck %s + +// This test verifies that the correct shardings are added when a while loop +// has free variables. + +// CHECK-LABEL: HloModule main + +// CHECK: %[[BODY:region_0.*]] ([[ARG_TUPLE:arg_tuple.*]]: (s32[], f32[4], s32[], s32[], f32[4])) -> (s32[], f32[4], s32[], s32[], f32[4]) { +// CHECK-NEXT: %[[ARG_TUPLE]] = (s32[], f32[4], s32[], s32[], f32[4]) parameter(0) +// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}} +// CHECK-DAG: %[[GTE12:get-tuple-element.*]] = s32[] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %[[ARG_TUPLE]]), index=3 +// CHECK-DAG: %[[GTE13:get-tuple-element.*]] = f32[4] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %[[ARG_TUPLE]]), index=4, sharding={devices=[4]<=[4]} +// CHECK-DAG: %[[ADD14:add.*]] = s32[] add(s32[] %get-tuple-element.{{.*}}, s32[] %[[GTE12]]) +// CHECK-DAG: %[[ADD15:add.*]] = f32[4] add(f32[4] %get-tuple-element.{{.*}}, f32[4] %[[GTE13]]) +// CHECK: ROOT %tuple.{{.*}} = (s32[], f32[4], s32[], s32[], f32[4]) tuple(s32[] %[[ADD14]], f32[4] %[[ADD15]], s32[] %get-tuple-element.{{.*}}, s32[] %[[GTE12]], f32[4] %[[GTE13]]) +// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}} + +// CHECK: %[[COND:region_1.*]] ([[ARG_TUPLE:arg_tuple.*]]: (s32[], f32[4], s32[], s32[], f32[4])) -> pred[] { +// CHECK-NEXT: %[[ARG_TUPLE]] = (s32[], f32[4], s32[], s32[], f32[4]) parameter(0) +// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}} +// CHECK: %[[GTE21:get-tuple-element.*]] = s32[] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %[[ARG_TUPLE]]), index=2 +// CHECK-NEXT: ROOT %compare.{{.*}} = pred[] compare(s32[] %get-tuple-element.{{.*}}, s32[] %[[GTE21]]), direction=LT + +// CHECK: ENTRY %main.{{.*}} ([[ARG0:Arg_0.*]]: s32[], [[ARG1:Arg_1.*]]: f32[4], [[ARG2:Arg_2.*]]: f32[4]) -> f32[4] { +// CHECK-NEXT: %[[ARG0]] = s32[] parameter(0) +// CHECK-NEXT: %[[ARG1]] = f32[4] parameter(1) +// CHECK-NEXT: %[[CONSTANT4:constant.*]] = s32[] constant(0) +// CHECK-NEXT: %[[CONSTANT5:constant.*]] = s32[] constant(1) +// CHECK-NEXT: %[[ARG2]] = f32[4] parameter(2) +// CHECK-NEXT: %[[TUPLE:tuple.*]] = (s32[], f32[4], s32[], s32[], f32[4]) tuple(s32[] %[[ARG0]], f32[4] %[[ARG1]], s32[] %[[CONSTANT4]], s32[] %[[CONSTANT5]], f32[4] %[[ARG2]]) +// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}} +// CHECK-NEXT: %[[WHILE:while.25]] = (s32[], f32[4], s32[], s32[], f32[4]) while((s32[], f32[4], s32[], s32[], f32[4]) %[[TUPLE]]), condition=%[[COND]], body=%[[BODY]] +// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}} +// CHECK-NEXT: %[[GTE26:get-tuple-element.*]] = s32[] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %[[WHILE]]), index=0, sharding={replicated} +// CHECK-NEXT: ROOT %[[GTE27:get-tuple-element.*]] = f32[4] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %[[WHILE]]), index=1, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} + +func.func @main(%arg0: tensor, %arg1: tensor<4xf32>, %arg2: tensor<4xf32> {mhlo.sharding = "{devices=[4]<=[4]}"}) -> tensor<4xf32> { + %0 = mhlo.constant dense<0> : tensor + %1 = mhlo.constant dense<1> : tensor + %2:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %arg1) : tensor, tensor<4xf32> + attributes {mhlo.sharding = "{{replicated},{devices=[2,2]<=[4] last_tile_dim_replicate}}"} + cond { + %3 = mhlo.compare LT, %iterArg, %0 : (tensor, tensor) -> tensor + mhlo.return %3 : tensor + } do { + %3 = mhlo.add %iterArg, %1 : tensor + %4 = mhlo.add %iterArg_0, %arg2 : tensor<4xf32> + mhlo.return %3, %4: tensor, tensor<4xf32> + } + func.return %2#1 : tensor<4xf32> +} + +// ----- + +// This test verifies that a value captured multiple times is only lifted once +// and all its uses are replaced. Also verifies that no sharding is added to +// region parameters or root when the while doesn't have a sharding. + +// CHECK-LABEL: HloModule main + +// CHECK: %[[BODY:region_0.*]] ([[ARG_TUPLE:arg_tuple.*]]: (s32[], f32[4], s32[])) -> (s32[], f32[4], s32[]) { +// CHECK-NEXT: %[[ARG_TUPLE]] = (s32[], f32[4], s32[]) parameter(0) +// CHECK: %[[GTE:get-tuple-element.*]] = s32[] get-tuple-element((s32[], f32[4], s32[]) %[[ARG_TUPLE]]), index=2 +// CHECK: %[[ADD:add.*]] = s32[] add(s32[] %get-tuple-element.{{.*}}, s32[] %[[GTE]]) +// CHECK: ROOT %tuple.{{.*}} = (s32[], f32[4], s32[]) tuple(s32[] %[[ADD]], f32[4] %get-tuple-element.{{.*}}, s32[] %[[GTE]]) + +// CHECK: %[[COND:region_1.*]] ([[ARG_TUPLE:arg_tuple.*]]: (s32[], f32[4], s32[])) -> pred[] { +// CHECK-NEXT: %[[ARG_TUPLE]] = (s32[], f32[4], s32[]) parameter(0) +// CHECK: %[[GTE:get-tuple-element..*]] = s32[] get-tuple-element((s32[], f32[4], s32[]) %[[ARG_TUPLE]]), index=2 +// CHECK: ROOT %compare.{{.*}} = pred[] compare(s32[] %get-tuple-element.{{.*}}, s32[] %[[GTE]]), direction=LT + +// CHECK: ENTRY %main.{{.*}} ([[ARG0:Arg_0.*]]: s32[], [[ARG1:Arg_1.*]]: f32[4], [[ARG2:Arg_2.*]]: s32[]) -> f32[4] { +// CHECK-NEXT: %[[ARG0]] = s32[] parameter(0) +// CHECK-NEXT: %[[ARG1]] = f32[4] parameter(1) +// CHECK-NEXT: %[[ARG2]] = s32[] parameter(2) +// CHECK-NEXT: %[[TUPLE:tuple.*]] = (s32[], f32[4], s32[]) tuple(s32[] %[[ARG0]], f32[4] %[[ARG1]], s32[] %[[ARG2]]) +// CHECK-NEXT: %while.{{.*}} = (s32[], f32[4], s32[]) while((s32[], f32[4], s32[]) %[[TUPLE]]), condition=%[[COND]], body=%[[BODY]] + +func.func @main(%arg0: tensor, %arg1: tensor<4xf32>, %arg2: tensor) -> tensor<4xf32> { + %2:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %arg1) : tensor, tensor<4xf32> + cond { + %3 = mhlo.compare LT, %iterArg, %arg2 : (tensor, tensor) -> tensor + mhlo.return %3 : tensor + } do { + %3 = mhlo.add %iterArg, %arg2 : tensor + mhlo.return %3, %iterArg_0: tensor, tensor<4xf32> + } + func.return %2#1 : tensor<4xf32> +} diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/translate.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate.cc similarity index 96% rename from third_party/xla/xla/translate/mhlo_to_hlo/translate.cc rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate.cc index 7c07582a46c794..0ac14b220f9531 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/translate.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate.cc @@ -12,7 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/mhlo_to_hlo/translate.h" +#include "xla/hlo/translate/mhlo_to_hlo/translate.h" #include #include @@ -36,20 +36,20 @@ limitations under the License. #include "mlir/Parser/Parser.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" #include "xla/debug_options_flags.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_proto_util.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate.h b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate.h new file mode 100644 index 00000000000000..064db33984b864 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate.h @@ -0,0 +1,50 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_MHLO_TO_HLO_TRANSLATE_H_ +#define XLA_HLO_TRANSLATE_MHLO_TO_HLO_TRANSLATE_H_ + +#include +#include + +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/raw_os_ostream.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Support/LogicalResult.h" + +namespace xla { + +mlir::LogicalResult MlirHloToHloTranslateFunction(mlir::ModuleOp module, + llvm::raw_ostream& output, + bool emit_return_tuple, + bool emit_use_tuple_arg); + +mlir::LogicalResult MlirHloToHloTextTranslateFunction( + mlir::ModuleOp module, llvm::raw_ostream& output, bool emit_return_tuple, + bool emit_use_tuple_arg, bool print_layouts, bool print_large_constants, + bool print_sugar, bool via_builder, bool with_layouts); + +// Translate the MHLO program in in-memory file 'buffer' to a HLO program +// written in a file represented with handle 'output_stream'; +mlir::LogicalResult MlirHloToHloTextMain( + std::unique_ptr buffer, + llvm::raw_ostream& output_stream, bool emit_return_tuple, + bool emit_use_tuple_arg, bool print_layouts, bool print_large_constants, + bool print_sugar, bool via_builder, bool with_layouts); + +} // namespace xla + +#endif // XLA_HLO_TRANSLATE_MHLO_TO_HLO_TRANSLATE_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/translate_registration.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate_registration.cc similarity index 95% rename from third_party/xla/xla/translate/mhlo_to_hlo/translate_registration.cc rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate_registration.cc index ec5954af59a25d..ed0be0bee7345e 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/translate_registration.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate_registration.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/mhlo_to_hlo/translate_registration.h" +#include "xla/hlo/translate/mhlo_to_hlo/translate_registration.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -22,8 +22,8 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Tools/mlir-translate/Translation.h" +#include "xla/hlo/translate/mhlo_to_hlo/translate.h" #include "xla/mlir_hlo/mhlo/IR/register.h" -#include "xla/translate/mhlo_to_hlo/translate.h" static mlir::LogicalResult MlirHloToHloTranslate(mlir::ModuleOp module, llvm::raw_ostream& output) { diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/translate_registration.h b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate_registration.h similarity index 91% rename from third_party/xla/xla/translate/mhlo_to_hlo/translate_registration.h rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate_registration.h index 6cd3e8e4fdd898..42c480a15fbc4d 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/translate_registration.h +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/translate_registration.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_TRANSLATE_MHLO_TO_HLO_TRANSLATE_REGISTRATION_H_ -#define XLA_TRANSLATE_MHLO_TO_HLO_TRANSLATE_REGISTRATION_H_ +#ifndef XLA_HLO_TRANSLATE_MHLO_TO_HLO_TRANSLATE_REGISTRATION_H_ +#define XLA_HLO_TRANSLATE_MHLO_TO_HLO_TRANSLATE_REGISTRATION_H_ #include "llvm/Support/CommandLine.h" @@ -60,4 +60,4 @@ llvm::cl::opt via_builder( "via-builder", llvm::cl::desc("Translate MHLO->XLA HLO via XLA Builder"), llvm::cl::init(false)); -#endif // XLA_TRANSLATE_MHLO_TO_HLO_TRANSLATE_REGISTRATION_H_ +#endif // XLA_HLO_TRANSLATE_MHLO_TO_HLO_TRANSLATE_REGISTRATION_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc similarity index 99% rename from third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.cc rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc index 82e79e7ff63197..89a3ab09b9f51e 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include #include @@ -53,7 +53,6 @@ using xla::PrimitiveType; namespace xla { - std::optional> ConvertDimLevelType( mlir::sparse_tensor::LevelType lt) { auto f = mlir::sparse_tensor::getLevelFormat(lt); diff --git a/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.h b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.h new file mode 100644 index 00000000000000..eb641ce44e3440 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape.h @@ -0,0 +1,31 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_MHLO_TO_HLO_TYPE_TO_SHAPE_H_ +#define XLA_HLO_TRANSLATE_MHLO_TO_HLO_TYPE_TO_SHAPE_H_ + +#include "llvm/ADT/STLExtras.h" +#include "mlir/IR/Types.h" +#include "xla/shape.h" +#include "xla/xla_data.pb.h" + +namespace xla { + +// Returns a XLA Shape equivalent of a MLIR Type, else returns empty shape. +Shape TypeToShape(mlir::Type type); + +} // namespace xla + +#endif // XLA_HLO_TRANSLATE_MHLO_TO_HLO_TYPE_TO_SHAPE_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape_test.cc b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape_test.cc similarity index 98% rename from third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape_test.cc rename to third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape_test.cc index 9d09c79eeaa507..464a6f21f9dfb6 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape_test.cc +++ b/third_party/xla/xla/hlo/translate/mhlo_to_hlo/type_to_shape_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include #include @@ -24,11 +24,11 @@ limitations under the License. #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/xla_data.pb.h" #include "tsl/platform/protobuf.h" diff --git a/third_party/xla/xla/hlo/translate/portable_api.cc b/third_party/xla/xla/hlo/translate/portable_api.cc new file mode 100644 index 00000000000000..9546279cff849f --- /dev/null +++ b/third_party/xla/xla/hlo/translate/portable_api.cc @@ -0,0 +1,71 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/translate/portable_api.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Support/LLVM.h" +#include "stablehlo/dialect/Register.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" +#include "xla/mlir_hlo/mhlo/IR/register.h" +#include "tsl/platform/statusor.h" + +namespace xla { + +std::string PrintModule(mlir::ModuleOp module) { + std::string s; + llvm::raw_string_ostream os(s); + mlir::OpPrintingFlags flags; + flags.enableDebugInfo(); + module->print(os, flags); + return s; +} + +void LoadHloDialects(mlir::MLIRContext& context) { + mlir::DialectRegistry registry; + mlir::stablehlo::registerAllDialects(registry); + mlir::mhlo::registerAllMhloDialects(registry); + context.appendDialectRegistry(registry); +} + +absl::StatusOr SerializeUsingBytecode(mlir::ModuleOp module) { + std::string bytecode; + llvm::raw_string_ostream os(bytecode); + mlir::BytecodeWriterConfig config; + if (mlir::failed(mlir::writeBytecodeToFile(module, os, config))) { + return absl::InvalidArgumentError("mlir::writeBytecodeToFile failed"); + } + return bytecode; +} + +absl::StatusOr ConvertHloToStablehlo( + xla::HloModule const& hlo_module, bool emit_bytecode) { + mlir::MLIRContext context; + LoadHloDialects(context); + TF_ASSIGN_OR_RETURN(auto module, ConvertHloToStablehlo(context, &hlo_module)); + if (emit_bytecode) return SerializeUsingBytecode(*module); + return PrintModule(*module); +} + +} // namespace xla diff --git a/third_party/xla/xla/hlo/translate/portable_api.h b/third_party/xla/xla/hlo/translate/portable_api.h new file mode 100644 index 00000000000000..3069c4e92d0be6 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/portable_api.h @@ -0,0 +1,35 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_PORTABLE_API_H_ +#define XLA_HLO_TRANSLATE_PORTABLE_API_H_ + +#include + +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_module.h" + +// This file is a portable version of the HLO API. +// Is offers a string API passthrough for MLIR datatypes and is intended +// to offer a safe means of using StableHLO opaquely in non-MLIR code. + +namespace xla { + +absl::StatusOr ConvertHloToStablehlo( + xla::HloModule const& hlo_module, bool emit_bytecode = false); + +} + +#endif // XLA_HLO_TRANSLATE_PORTABLE_API_H_ diff --git a/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/BUILD b/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/BUILD new file mode 100644 index 00000000000000..c06de76069abf0 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/BUILD @@ -0,0 +1,46 @@ +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla/tsl:tsl.bzl", "internal_visibility") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//learning/brain/mlir:tensorflow_friends", + "//learning/brain/mlir:xla_friends", + ]), + licenses = ["notice"], +) + +cc_library( + name = "translate", + srcs = ["translate.cc"], + hdrs = ["translate.h"], + deps = [ + "//xla/hlo/translate/mhlo_to_hlo:translate", + "//xla/mlir_hlo:mhlo_passes", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@stablehlo//:register", + ], +) + +cc_library( + name = "translate_registration", + testonly = True, + srcs = ["translate_registration.cc"], + deps = [ + ":translate", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TranslateLib", + "@stablehlo//:register", + ], + alwayslink = 1, +) diff --git a/third_party/xla/xla/translate/stablehlo_to_hlo/tests/BUILD b/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/tests/BUILD similarity index 93% rename from third_party/xla/xla/translate/stablehlo_to_hlo/tests/BUILD rename to third_party/xla/xla/hlo/translate/stablehlo_to_hlo/tests/BUILD index c68f2ea17d8423..a9535fa3e35b4a 100644 --- a/third_party/xla/xla/translate/stablehlo_to_hlo/tests/BUILD +++ b/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/tests/BUILD @@ -19,7 +19,7 @@ lit_test_suite( cfg = "//xla:lit.cfg.py", data = [":test_utilities"], tools = [ - "//xla/translate:xla-translate", + "//xla/hlo/translate:xla-translate", "@llvm-project//llvm:FileCheck", "@llvm-project//llvm:not", ], diff --git a/third_party/xla/xla/translate/stablehlo_to_hlo/tests/simple.mlir b/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/tests/simple.mlir similarity index 100% rename from third_party/xla/xla/translate/stablehlo_to_hlo/tests/simple.mlir rename to third_party/xla/xla/hlo/translate/stablehlo_to_hlo/tests/simple.mlir diff --git a/third_party/xla/xla/translate/stablehlo_to_hlo/translate.cc b/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/translate.cc similarity index 96% rename from third_party/xla/xla/translate/stablehlo_to_hlo/translate.cc rename to third_party/xla/xla/hlo/translate/stablehlo_to_hlo/translate.cc index 5cf7b516f4229e..ee50467832e872 100644 --- a/third_party/xla/xla/translate/stablehlo_to_hlo/translate.cc +++ b/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/translate.cc @@ -12,7 +12,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/translate/stablehlo_to_hlo/translate.h" +#include "xla/hlo/translate/stablehlo_to_hlo/translate.h" #include #include @@ -29,8 +29,8 @@ limitations under the License. #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" #include "stablehlo/dialect/Register.h" +#include "xla/hlo/translate/mhlo_to_hlo/translate.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" -#include "xla/translate/mhlo_to_hlo/translate.h" namespace xla { diff --git a/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/translate.h b/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/translate.h new file mode 100644 index 00000000000000..c3f0a86cb88340 --- /dev/null +++ b/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/translate.h @@ -0,0 +1,50 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_TRANSLATE_STABLEHLO_TO_HLO_TRANSLATE_H_ +#define XLA_HLO_TRANSLATE_STABLEHLO_TO_HLO_TRANSLATE_H_ + +#include +#include + +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/raw_os_ostream.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Support/LogicalResult.h" + +namespace xla { + +mlir::LogicalResult StablehloToHloTranslateFunction(mlir::ModuleOp module, + llvm::raw_ostream& output, + bool emit_return_tuple, + bool emit_use_tuple_arg); + +mlir::LogicalResult StablehloToHloTextTranslateFunction( + mlir::ModuleOp module, llvm::raw_ostream& output, bool emit_return_tuple, + bool emit_use_tuple_arg, bool print_layouts, bool print_large_constants, + bool print_sugar, bool via_builder, bool with_layouts); + +// Translate the StableHLO program in in-memory file 'buffer' to a HLO program +// written in a file represented with handle 'output_stream'; +mlir::LogicalResult StablehloToHloTextMain( + std::unique_ptr buffer, + llvm::raw_ostream& output_stream, bool emit_return_tuple, + bool emit_use_tuple_arg, bool print_layouts, bool print_large_constants, + bool print_sugar, bool via_builder, bool with_layouts); + +} // namespace xla + +#endif // XLA_HLO_TRANSLATE_STABLEHLO_TO_HLO_TRANSLATE_H_ diff --git a/third_party/xla/xla/translate/stablehlo_to_hlo/translate_registration.cc b/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/translate_registration.cc similarity index 95% rename from third_party/xla/xla/translate/stablehlo_to_hlo/translate_registration.cc rename to third_party/xla/xla/hlo/translate/stablehlo_to_hlo/translate_registration.cc index 38e827dac3475a..23258c1e212d6e 100644 --- a/third_party/xla/xla/translate/stablehlo_to_hlo/translate_registration.cc +++ b/third_party/xla/xla/hlo/translate/stablehlo_to_hlo/translate_registration.cc @@ -23,10 +23,10 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "mlir/Tools/mlir-translate/Translation.h" #include "stablehlo/dialect/Register.h" -#include "xla/translate/stablehlo_to_hlo/translate.h" +#include "xla/hlo/translate/stablehlo_to_hlo/translate.h" // The following symbols are defined in -// tensorflow/compiler/xla/translate/mhlo_to_hlo/translate_registration.h +// tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo/translate_registration.h extern llvm::cl::opt emit_use_tuple_arg; extern llvm::cl::opt emit_return_tuple; extern llvm::cl::opt with_layouts; diff --git a/third_party/xla/xla/translate/xla_translate_main.cc b/third_party/xla/xla/hlo/translate/xla_translate_main.cc similarity index 100% rename from third_party/xla/xla/translate/xla_translate_main.cc rename to third_party/xla/xla/hlo/translate/xla_translate_main.cc diff --git a/third_party/xla/xla/translate/xla_translate_opt_main.cc b/third_party/xla/xla/hlo/translate/xla_translate_opt_main.cc similarity index 100% rename from third_party/xla/xla/translate/xla_translate_opt_main.cc rename to third_party/xla/xla/hlo/translate/xla_translate_opt_main.cc diff --git a/third_party/xla/xla/hlo/utils/BUILD b/third_party/xla/xla/hlo/utils/BUILD index 814f54fa613550..9d11def6d749a2 100644 --- a/third_party/xla/xla/hlo/utils/BUILD +++ b/third_party/xla/xla/hlo/utils/BUILD @@ -26,11 +26,11 @@ cc_library( hdrs = ["hlo_live_range.h"], deps = [ "//xla:shape_util", + "//xla/hlo/analysis:hlo_alias_analysis", + "//xla/hlo/analysis:hlo_dataflow_analysis", + "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", - "//xla/service:hlo_alias_analysis", "//xla/service:hlo_buffer", - "//xla/service:hlo_dataflow_analysis", - "//xla/service:hlo_ordering", "//xla/service:hlo_value", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -46,17 +46,18 @@ xla_cc_test( srcs = ["hlo_live_range_test.cc"], deps = [ ":hlo_live_range", - "//xla:literal", - "//xla:status_macros", + "//xla:comparison_util", + "//xla:literal_util", + "//xla:shape_util", + "//xla/hlo/analysis:hlo_alias_analysis", "//xla/hlo/ir:hlo", - "//xla/service:hlo_alias_analysis", - "//xla/service:hlo_ordering", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/service:hlo_value", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", ], ) @@ -69,7 +70,7 @@ cc_library( "//xla:test", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", + "//xla/hlo/parser:hlo_parser", "@com_google_absl//absl/strings", ], ) @@ -82,9 +83,9 @@ xla_cc_test( "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:test_main", ], ) @@ -136,14 +137,14 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/ir:tile_assignment", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/service:dot_as_convolution_util", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/log", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", ], ) @@ -154,11 +155,11 @@ cc_library( deps = [ "//xla:literal", "//xla:shape_util", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/service:pattern_matcher", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], ) @@ -170,11 +171,10 @@ xla_cc_test( ], deps = [ ":hlo_query", + "//xla:util", "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", diff --git a/third_party/xla/xla/hlo/utils/hlo_live_range.cc b/third_party/xla/xla/hlo/utils/hlo_live_range.cc index 093e3d8cfdb4a8..61ce4c68dfe486 100644 --- a/third_party/xla/xla/hlo/utils/hlo_live_range.cc +++ b/third_party/xla/xla/hlo/utils/hlo_live_range.cc @@ -29,11 +29,11 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/ir/dfs_hlo_visitor.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" #include "xla/service/hlo_value.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/hlo/utils/hlo_live_range.h b/third_party/xla/xla/hlo/utils/hlo_live_range.h index 87c6ec1ed5f0e2..da6e93b992e08a 100644 --- a/third_party/xla/xla/hlo/utils/hlo_live_range.h +++ b/third_party/xla/xla/hlo/utils/hlo_live_range.h @@ -21,10 +21,10 @@ the License. #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/ir/dfs_hlo_visitor.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_schedule.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_value.h" namespace xla { diff --git a/third_party/xla/xla/hlo/utils/hlo_live_range_test.cc b/third_party/xla/xla/hlo/utils/hlo_live_range_test.cc index 64e4ab5ee37d62..5dc63e4434f042 100644 --- a/third_party/xla/xla/hlo/utils/hlo_live_range_test.cc +++ b/third_party/xla/xla/hlo/utils/hlo_live_range_test.cc @@ -14,20 +14,27 @@ limitations under the License. ==============================================================================*/ #include "xla/hlo/utils/hlo_live_range.h" +#include #include #include #include #include #include +#include #include "absl/container/flat_hash_map.h" +#include "xla/comparison_util.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" -#include "xla/service/hlo_alias_analysis.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/literal_util.h" #include "xla/service/hlo_value.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" @@ -35,7 +42,7 @@ namespace xla { namespace { using TimeBound = HloLiveRange::TimeBound; -class HloLiveRangeTest : public HloTestBase { +class HloLiveRangeTest : public HloHardwareIndependentTestBase { protected: HloLiveRangeTest() : module_(CreateNewVerifiedModule()) {} ~HloLiveRangeTest() override {} diff --git a/third_party/xla/xla/hlo/utils/hlo_matchers.h b/third_party/xla/xla/hlo/utils/hlo_matchers.h index 17f3294156bcab..2c00ddb7b3edfb 100644 --- a/third_party/xla/xla/hlo/utils/hlo_matchers.h +++ b/third_party/xla/xla/hlo/utils/hlo_matchers.h @@ -21,7 +21,7 @@ limitations under the License. #include #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/test.h" #include "xla/xla_data.pb.h" @@ -323,6 +323,7 @@ HLO_MATCHER(Outfeed); HLO_MATCHER(Pad); HLO_MATCHER(PartitionId); HLO_MATCHER(Power); +HLO_MATCHER(RaggedAllToAll); HLO_MATCHER(Recv); HLO_MATCHER(RecvDone); HLO_MATCHER(Reduce); diff --git a/third_party/xla/xla/hlo/utils/hlo_matchers_test.cc b/third_party/xla/xla/hlo/utils/hlo_matchers_test.cc index 67f4412702cdea..3a9261db4fe110 100644 --- a/third_party/xla/xla/hlo/utils/hlo_matchers_test.cc +++ b/third_party/xla/xla/hlo/utils/hlo_matchers_test.cc @@ -21,9 +21,9 @@ limitations under the License. #include #include +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/literal_util.h" #include "xla/shape_util.h" -#include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" namespace op = xla::testing::opcode_matchers; @@ -34,7 +34,7 @@ using ::testing::HasSubstr; namespace xla { namespace { -using HloMatchersTest = HloTestBase; +using HloMatchersTest = HloHardwareIndependentTestBase; std::string DescribeHloMatcher( const ::testing::Matcher& m) { diff --git a/third_party/xla/xla/hlo/utils/hlo_query.cc b/third_party/xla/xla/hlo/utils/hlo_query.cc index 147f54822aef97..90b6ddfd4d2b2a 100644 --- a/third_party/xla/xla/hlo/utils/hlo_query.cc +++ b/third_party/xla/xla/hlo/utils/hlo_query.cc @@ -34,7 +34,8 @@ namespace hlo_query { bool IsCollectiveCommunicationOp(HloOpcode op) { return op == HloOpcode::kAllReduce || op == HloOpcode::kAllGather || - op == HloOpcode::kAllToAll || op == HloOpcode::kCollectivePermute || + op == HloOpcode::kAllToAll || op == HloOpcode::kRaggedAllToAll || + op == HloOpcode::kCollectivePermute || op == HloOpcode::kCollectiveBroadcast || op == HloOpcode::kReduceScatter || op == HloOpcode::kAllReduceStart || op == HloOpcode::kAllGatherStart || @@ -280,36 +281,21 @@ HloComputation* FindComputation(HloModule* module, absl::string_view name) { return *it; } -std::pair FindFirstInstruction( - const HloComputation* computation, absl::string_view name) { - int current_index = 0; - for (auto* instruction : computation->instructions()) { - if (instruction->name() == name) { - return {instruction, current_index}; - break; - } - current_index++; +HloInstruction* FindInstruction(const HloComputation* computation, + absl::string_view name) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->name() == name) return instruction; } - return {nullptr, -1}; + return nullptr; } -std::pair FindFirstInstruction( - const HloComputation* computation, HloOpcode opcode) { - int current_index = 0; +HloInstruction* FindInstruction(const HloComputation* computation, + HloOpcode opcode) { for (auto* instruction : computation->instructions()) { - if (instruction->opcode() == opcode) { - return {instruction, current_index}; - break; - } - current_index++; + if (instruction->opcode() == opcode) return instruction; } - return {nullptr, -1}; + return nullptr; } -bool IsBeforeInComputation(const HloComputation* computation, - absl::string_view inst1, absl::string_view inst2) { - return FindFirstInstruction(computation, inst1).second < - FindFirstInstruction(computation, inst2).second; -} } // namespace hlo_query } // namespace xla diff --git a/third_party/xla/xla/hlo/utils/hlo_query.h b/third_party/xla/xla/hlo/utils/hlo_query.h index ec5c0b25804d10..f219594024dc7e 100644 --- a/third_party/xla/xla/hlo/utils/hlo_query.h +++ b/third_party/xla/xla/hlo/utils/hlo_query.h @@ -25,6 +25,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/util.h" namespace xla { @@ -82,7 +83,8 @@ bool IsBroadcastOfParameter(const HloInstruction& instr); HloInstruction* GetFirstInstructionWithOpcode(const HloComputation& computation, HloOpcode opcode); -// Applies `fn` to a collection of instruction for a given `computation`. +// Applies `fn` to a collection of instruction with `opcode` for a given +// `computation`. template void ForEachInstructionWithOpcode(HloComputation& computation, HloOpcode opcode, Fn&& fn) { @@ -93,7 +95,8 @@ void ForEachInstructionWithOpcode(HloComputation& computation, HloOpcode opcode, } } -// Applies `fn` to a collection of instruction for a given `module`. +// Applies `fn` to a collection of instruction with `opcode` for a given +// `module`. template void ForEachInstructionWithOpcode(HloModule& module, HloOpcode opcode, Fn&& fn) { @@ -102,6 +105,27 @@ void ForEachInstructionWithOpcode(HloModule& module, HloOpcode opcode, } } +// Applies `fn` to a collection of instruction satisfying `pred` for a given +// `computation`. +template +void ForEachInstructionWithPred(HloComputation& computation, HloPredicate pred, + Fn&& fn) { + for (HloInstruction* instr : computation.instructions()) { + if (pred(instr)) { + fn(instr); + } + } +} + +// Applies `fn` to a collection of instruction satisfying `pred` for a given +// `module`. +template +void ForEachInstructionWithPred(HloModule& module, HloPredicate pred, Fn&& fn) { + for (HloComputation* computation : module.computations()) { + ForEachInstructionWithPred(*computation, pred, fn); + } +} + // Determines whether the given computation contains an instruction with one of // the given opcodes. Checks both comp's instructions and the instructions of // any computations nested within it. @@ -156,23 +180,17 @@ HloInstruction* GetUniqueGteInstruction(const HloInstruction* operand, // Gets the computation from the given module with the given name. HloComputation* FindComputation(HloModule* module, absl::string_view name); -// Gets the first instruction and its index from the given computation with the -// given instruction name. The function returns {nullptr, -1} if the instruction -// cannot be found. -std::pair FindFirstInstruction( - const HloComputation* computation, absl::string_view name); -// Gets the first instruction and its index from the given computation with the -// given instruction opcode. The function returns {nullptr, -1} if the -// instruction cannot be found. -std::pair FindFirstInstruction( - const HloComputation* computation, HloOpcode opcode); - -// Check that one instruction comes before another one for a given computation. -// The function returns true if the first instruction comes before the second -// one, and false otherwise. This is useful for partial checks on the -// transformed IR without going through a full file check. -bool IsBeforeInComputation(const HloComputation* computation, - absl::string_view inst1, absl::string_view inst2); + +// Gets the instruction from the given computation with the given instruction +// name. Returns nullptr if no such instruction can be found. +HloInstruction* FindInstruction(const HloComputation* computation, + absl::string_view name); + +// Gets any instruction from the given computation with the given opcode. +// Returns nullptr if no such instruction can be found. +HloInstruction* FindInstruction(const HloComputation* computation, + HloOpcode opcode); + } // namespace hlo_query } // namespace xla diff --git a/third_party/xla/xla/hlo/utils/hlo_query_test.cc b/third_party/xla/xla/hlo/utils/hlo_query_test.cc index 1f715ad6815284..20d3108bae4dc0 100644 --- a/third_party/xla/xla/hlo/utils/hlo_query_test.cc +++ b/third_party/xla/xla/hlo/utils/hlo_query_test.cc @@ -16,25 +16,23 @@ limitations under the License. #include "xla/hlo/utils/hlo_query.h" #include -#include #include -#include "absl/log/check.h" -#include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/hlo_parser.h" -#include "xla/tests/hlo_test_base.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/util.h" #include "tsl/platform/statusor.h" namespace xla { namespace { -using HloQueryTest = HloTestBase; +using HloQueryTest = HloHardwareIndependentTestBase; template int CountInstructions(Hlo& module, HloOpcode opcode) { @@ -44,6 +42,14 @@ int CountInstructions(Hlo& module, HloOpcode opcode) { return counter; } +template +int CountInstructions(Hlo& module, HloPredicate pred) { + int counter = 0; + hlo_query::ForEachInstructionWithPred(module, pred, + [&counter](auto& instr) { counter++; }); + return counter; +} + constexpr absl::string_view kConstantAdditionHloString = R"( HloModule test ENTRY main { @@ -82,6 +88,11 @@ ENTRY main { EXPECT_EQ(CountInstructions(*module, HloOpcode::kAdd), 2); EXPECT_EQ(CountInstructions(*module, HloOpcode::kSubtract), 1); EXPECT_EQ(CountInstructions(*module, HloOpcode::kMultiply), 3); + EXPECT_EQ(CountInstructions(*module, HloPredicateIsOp), 2); + EXPECT_EQ(CountInstructions(*module, HloPredicateIsOp), + 1); + EXPECT_EQ(CountInstructions(*module, HloPredicateIsOp), + 3); } TEST_F(HloQueryTest, @@ -120,6 +131,14 @@ ENTRY main { EXPECT_EQ(CountInstructions(*computation, HloOpcode::kAdd), 2); EXPECT_EQ(CountInstructions(*computation, HloOpcode::kSubtract), 1); EXPECT_EQ(CountInstructions(*computation, HloOpcode::kMultiply), 3); + EXPECT_EQ(CountInstructions(*computation, HloPredicateIsOp), + 2); + EXPECT_EQ( + CountInstructions(*computation, HloPredicateIsOp), + 1); + EXPECT_EQ( + CountInstructions(*computation, HloPredicateIsOp), + 3); } TEST_F(HloQueryTest, GetUniqueGteTest) { @@ -157,31 +176,21 @@ TEST_F(HloQueryTest, FindInstructionUsingNameTest) { std::unique_ptr module, ParseAndReturnUnverifiedModule(kConstantAdditionHloString)); const HloComputation* main = hlo_query::FindComputation(module.get(), "main"); - EXPECT_NE(hlo_query::FindFirstInstruction(main, "zero").first, nullptr); - EXPECT_NE(hlo_query::FindFirstInstruction(main, "five").first, nullptr); - EXPECT_NE(hlo_query::FindFirstInstruction(main, "out").first, nullptr); - EXPECT_EQ(hlo_query::FindFirstInstruction(main, "foo").first, nullptr); -} - -std::pair FindFirst(const HloComputation* main, - absl::string_view opcode) { - return hlo_query::FindFirstInstruction(main, - StringToHloOpcode(opcode).value()); + EXPECT_NE(hlo_query::FindInstruction(main, "zero"), nullptr); + EXPECT_NE(hlo_query::FindInstruction(main, "five"), nullptr); + EXPECT_NE(hlo_query::FindInstruction(main, "out"), nullptr); + EXPECT_EQ(hlo_query::FindInstruction(main, "foo"), nullptr); } -// Assures that the string and opcode versions of FindFirstInstruction return +// Assures that the string and opcode versions of FindInstruction return // the same result -void FindFirstInstructionsAndExpectEqual(const HloComputation* main, - absl::string_view name, - absl::string_view opcode_str) { +void FindInstructionsAndExpectEqual(const HloComputation* main, + absl::string_view name, HloOpcode opcode) { SCOPED_TRACE(absl::StrCat("Comparing finding by name: ", name, - " and opcode: ", opcode_str)); - auto withString = hlo_query::FindFirstInstruction(main, name); - auto withOpCode = FindFirst(main, opcode_str); - EXPECT_EQ(withString.first, withOpCode.first); - EXPECT_EQ(withString.second, withOpCode.second); - if (withString.first != nullptr) - EXPECT_EQ(withString.first->ToString(), withOpCode.first->ToString()); + " and opcode: ", opcode)); + HloInstruction* by_name = hlo_query::FindInstruction(main, name); + HloInstruction* by_opcode = hlo_query::FindInstruction(main, opcode); + EXPECT_EQ(by_name, by_opcode); } TEST_F(HloQueryTest, FindInstructionUsingOpcodeTest) { @@ -189,9 +198,9 @@ TEST_F(HloQueryTest, FindInstructionUsingOpcodeTest) { std::unique_ptr module, ParseAndReturnUnverifiedModule(kConstantAdditionHloString)); const HloComputation* main = hlo_query::FindComputation(module.get(), "main"); - EXPECT_NE(FindFirst(main, "add").first, nullptr); - EXPECT_NE(FindFirst(main, "constant").first, nullptr); - EXPECT_EQ(FindFirst(main, "select").first, nullptr); + EXPECT_NE(hlo_query::FindInstruction(main, HloOpcode::kConstant), nullptr); + EXPECT_NE(hlo_query::FindInstruction(main, HloOpcode::kAdd), nullptr); + EXPECT_EQ(hlo_query::FindInstruction(main, HloOpcode::kSelect), nullptr); } TEST_F(HloQueryTest, FindInstructionUsingOpcodeAndNameEqualTest) { @@ -199,10 +208,10 @@ TEST_F(HloQueryTest, FindInstructionUsingOpcodeAndNameEqualTest) { std::unique_ptr module, ParseAndReturnUnverifiedModule(kConstantAdditionHloString)); const HloComputation* main = hlo_query::FindComputation(module.get(), "main"); - FindFirstInstructionsAndExpectEqual(main, "zero", "constant"); - FindFirstInstructionsAndExpectEqual(main, "out", "add"); + FindInstructionsAndExpectEqual(main, "zero", HloOpcode::kConstant); + FindInstructionsAndExpectEqual(main, "out", HloOpcode::kAdd); // both are not found - FindFirstInstructionsAndExpectEqual(main, "dummy", "select"); + FindInstructionsAndExpectEqual(main, "dummy", HloOpcode::kSelect); } TEST_F(HloQueryTest, FindInstructionDoesNotExistTest) { @@ -211,21 +220,10 @@ TEST_F(HloQueryTest, FindInstructionDoesNotExistTest) { ParseAndReturnUnverifiedModule(kConstantAdditionHloString)); const HloComputation* main = hlo_query::FindComputation(module.get(), "main"); EXPECT_NE(main, nullptr); - auto find_beef = hlo_query::FindFirstInstruction(main, "deadbeef"); - auto find_nothing = hlo_query::FindFirstInstruction(main, ""); - EXPECT_EQ(find_beef.first, nullptr); - EXPECT_EQ(find_beef.second, -1); - EXPECT_EQ(find_nothing.first, nullptr); - EXPECT_EQ(find_nothing.second, -1); -} - -TEST_F(HloQueryTest, IsBeforeInComputationTest) { - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr module, - ParseAndReturnUnverifiedModule(kConstantAdditionHloString)); - const HloComputation* main = hlo_query::FindComputation(module.get(), "main"); - EXPECT_TRUE(hlo_query::IsBeforeInComputation(main, "zero", "five")); - EXPECT_TRUE(hlo_query::IsBeforeInComputation(main, "five", "out")); + auto find_beef = hlo_query::FindInstruction(main, "deadbeef"); + auto find_nothing = hlo_query::FindInstruction(main, ""); + EXPECT_EQ(find_beef, nullptr); + EXPECT_EQ(find_nothing, nullptr); } TEST_F(HloQueryTest, NextChannelIdForModuleWithoutChannelIdTest) { @@ -253,8 +251,10 @@ TEST_F(HloQueryTest, NextChannelIdTwoIdsTest) { HloModule test ENTRY test_computation { p = u32[] partition-id() - l = u32[] collective-permute(p), channel_id=8, source_target_pairs={{0,1},{1,2}} - r = u32[] collective-permute(p), channel_id=9, source_target_pairs={{2,3},{3,0}} + l = u32[] collective-permute(p), channel_id=8, + source_target_pairs={{0,1},{1,2}} + r = u32[] collective-permute(p), channel_id=9, + source_target_pairs={{2,3},{3,0}} ROOT res = u32[] add(l,r) } )"; diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc index c3cd98219899ea..f72ec2bbfc0c93 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc @@ -704,9 +704,9 @@ HloSharding TransposeSharding(const HloSharding& sharding, std::optional ReshapeSharding(const Shape& source_shape, const Shape& target_shape, - const HloSharding& sharding) { - if (sharding.IsTileMaximal() || sharding.IsManual()) { - return sharding; + const HloSharding& source_sharding) { + if (source_sharding.IsTileMaximal() || source_sharding.IsManual()) { + return source_sharding; } // In case of a tiled sharding, the reshaped sharding will be valid if the @@ -732,10 +732,24 @@ std::optional ReshapeSharding(const Shape& source_shape, DimensionVector target_dims_stack(target_shape.dimensions().rbegin(), target_shape.dimensions().rend()); DimensionVector sharding_tile_dims_stack( - sharding.tile_assignment().dimensions().begin(), - sharding.tile_assignment().dimensions().begin() + source_shape.rank()); + source_sharding.tile_assignment().dimensions().begin(), + source_sharding.tile_assignment().dimensions().begin() + + source_shape.rank()); std::reverse(sharding_tile_dims_stack.begin(), sharding_tile_dims_stack.end()); + int64_t source_dims_index = -1; + std::vector dims_to_replicate; + + auto source_dims_push = [&](int64_t shape_size, int64_t partitions) { + source_dims_stack.push_back(shape_size); + sharding_tile_dims_stack.push_back(partitions); + source_dims_index--; + }; + auto source_dims_pop = [&]() { + source_dims_stack.pop_back(); + sharding_tile_dims_stack.pop_back(); + source_dims_index++; + }; bool inplace_add_sharding_dim = false; auto append_sharding_dim = [&](int64_t size) { @@ -753,22 +767,20 @@ std::optional ReshapeSharding(const Shape& source_shape, break; } - int64_t source_dim_product = 1; + int64_t source_dims_product = 1; while (!sharding_tile_dims_stack.empty() && sharding_tile_dims_stack.back() == 1) { - sharding_tile_dims_stack.pop_back(); - source_dim_product *= source_dims_stack.back(); - source_dims_stack.pop_back(); + source_dims_product *= source_dims_stack.back(); + source_dims_pop(); } while (!target_dims_stack.empty() && target_dims_stack.back() > 1 && - source_dim_product % target_dims_stack.back() == 0) { - source_dim_product /= target_dims_stack.back(); + source_dims_product % target_dims_stack.back() == 0) { + source_dims_product /= target_dims_stack.back(); target_dims_stack.pop_back(); append_sharding_dim(1); } - if (source_dim_product != 1) { - source_dims_stack.push_back(source_dim_product); - sharding_tile_dims_stack.push_back(1); + if (source_dims_product != 1) { + source_dims_push(source_dims_product, 1); } if (target_dims_stack.empty()) { @@ -781,9 +793,8 @@ std::optional ReshapeSharding(const Shape& source_shape, int64_t s_partitions = 1; if (!source_dims_stack.empty()) { s_size = source_dims_stack.back(); - source_dims_stack.pop_back(); s_partitions = sharding_tile_dims_stack.back(); - sharding_tile_dims_stack.pop_back(); + source_dims_pop(); } if (s_size == t_size) { @@ -793,19 +804,20 @@ std::optional ReshapeSharding(const Shape& source_shape, t_size % s_partitions == 0) { // If s_partitions evenly divides both s_size and t_size, we can add this // sharding dim and work on shard sized shapes in the next iteration. - source_dims_stack.push_back(s_size / s_partitions); + source_dims_push(s_size / s_partitions, 1); target_dims_stack.push_back(t_size / s_partitions); - sharding_tile_dims_stack.push_back(1); append_sharding_dim(s_partitions); inplace_add_sharding_dim = true; } else if (t_size == 1) { // Trivial dimension added. append_sharding_dim(1); - source_dims_stack.push_back(s_size); - sharding_tile_dims_stack.push_back(s_partitions); + source_dims_push(s_size, s_partitions); } else if (s_size == 1) { // Trivial dimension removed. target_dims_stack.push_back(t_size); + if (s_partitions > 1) { + dims_to_replicate.push_back(source_dims_index); + } } else if (s_size > t_size) { // Dimension split. if (s_size % s_partitions != 0) { @@ -819,13 +831,11 @@ std::optional ReshapeSharding(const Shape& source_shape, if (t_size % s_partitions == 0) { append_sharding_dim(s_partitions); // We have part of the s_size unprocessed, so put it back to stack. - source_dims_stack.push_back(s_size / t_size); - sharding_tile_dims_stack.push_back(1); + source_dims_push(s_size / t_size, 1); } else if (s_partitions % t_size == 0) { append_sharding_dim(t_size); // We have part of the s_size unprocessed, so put it back to stack. - source_dims_stack.push_back(s_size / t_size); - sharding_tile_dims_stack.push_back(s_partitions / t_size); + source_dims_push(s_size / t_size, s_partitions / t_size); } else { append_sharding_dim(std::gcd(t_size, s_partitions)); break; @@ -860,6 +870,16 @@ std::optional ReshapeSharding(const Shape& source_shape, while (target_tile_assignment_dimensions.size() < target_shape.rank()) { target_tile_assignment_dimensions.push_back(1); } + + // If there is a source dimension satisfying (1) size is 1, (2) partition > 1, + // and (3) there is no corresponding target dimension, we replicate the source + // sharding along this dimension since the source sharding cannot be + // propagated along this dimension. + const HloSharding sharding = !dims_to_replicate.empty() + ? PartiallyReplicateTiledShardingOnDims( + source_sharding, dims_to_replicate) + : source_sharding; + for (int64_t i = sharding.TiledDataRank(); i < sharding.tile_assignment().num_dimensions(); ++i) { target_tile_assignment_dimensions.push_back( @@ -1058,6 +1078,80 @@ bool ContainsTileSharding(const HloModule& module) { return false; } +template +std::vector argsort(absl::Span data) { + std::vector indices(data.size()); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), + [&data](int64_t i, int64_t j) { return data[i] < data[j]; }); + return indices; +} + +// Given a `source_sharding`, preserve the tiles along the `source_dims` and +// replicate the rest. The `target_dims` are used to determine the order of the +// dimensions in the resulting sharding. If `source_dims` and `target_dims` are +// in the different order (i.e., different argsort results), we need to +// transpose the tile assignment. +// +// Given the following input, +// * source_sharding = {devices=[2,3,5,7,11]<=[2310]} +// * source_dims = [2, 4, 1] +// * target_dims = [2, 1, 3] +// * target_shape_rank = 5 +// The result shoule be {devices=[1,11,5,3,1,14]<=[2,3,5,7,11]T(4,2,1,0,3) +// last_tile_dim_replicate}. +HloSharding PropagateShardingAlongDimsAndReplicateOthers( + const HloSharding& source_sharding, absl::Span source_dims, + absl::Span target_dims, int64_t target_shape_rank) { + CHECK_EQ(source_dims.size(), target_dims.size()); + if (source_sharding.IsTileMaximal() || source_sharding.IsManual()) { + return source_sharding; + } + + HloSharding replicate_other_dims = + PartiallyReplicateTiledShardingOnAllDimsExcept(source_sharding, + source_dims); + if (replicate_other_dims.IsTileMaximal()) { + return replicate_other_dims; + } + + std::vector argsort_source_dims = argsort(source_dims); + std::vector argsort_target_dims = argsort(target_dims); + if (argsort_source_dims != argsort_target_dims) { + std::vector perm( + replicate_other_dims.tile_assignment().num_dimensions(), -1); + for (int64_t i = 0; i < source_dims.size(); ++i) { + perm[source_dims[argsort_target_dims[i]]] = i; + } + int64_t i = source_dims.size(); + for (int64_t& perm_element : perm) { + if (perm_element == -1) { + perm_element = i++; + } + } + replicate_other_dims = TransposeSharding(replicate_other_dims, perm); + } + + std::vector target_tile_dims(target_shape_rank, 1); + for (int i = 0; i < source_dims.size(); ++i) { + target_tile_dims[target_dims[i]] = + source_sharding.tile_assignment().dim(source_dims[i]); + } + for (int64_t i = replicate_other_dims.TiledDataRank(); + i < replicate_other_dims.tile_assignment().num_dimensions(); ++i) { + target_tile_dims.push_back(replicate_other_dims.tile_assignment().dim(i)); + } + + auto target_tile_assignment = + replicate_other_dims.tile_assignment().Reshape(target_tile_dims); + return replicate_other_dims.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(target_tile_assignment, + replicate_other_dims.metadata()) + : HloSharding::Subgroup(target_tile_assignment, + replicate_other_dims.subgroup_types(), + replicate_other_dims.metadata()); +} + HloSharding GatherOutputShardingFromIndexIndexPassthroughDimensions( const HloSharding& index_sharding, const HloInstruction* hlo) { CHECK(hlo->opcode() == HloOpcode::kGather); @@ -1369,22 +1463,24 @@ namespace { absl::InlinedVector GetGatherScatterOperandPassthroughOperandDims( const Shape& operand_shape, absl::Span collapsed_or_inserted_dims, + absl::Span operand_batching_dims, absl::Span index_map, absl::Span offset_or_window_dims, absl::Span slice_size) { absl::InlinedVector passthrough_dims; - int64_t collapsed = 0; - for (int64_t i = 0; i != operand_shape.rank(); ++i) { - if (absl::c_linear_search(collapsed_or_inserted_dims, i)) { - collapsed++; + int64_t collapsed_or_batching = 0; + for (int64_t i = 0; i < operand_shape.rank(); ++i) { + if (absl::c_linear_search(collapsed_or_inserted_dims, i) || + absl::c_linear_search(operand_batching_dims, i)) { + collapsed_or_batching++; continue; } if (slice_size[i] != operand_shape.dimensions(i)) { continue; } - int64_t offset_dim = offset_or_window_dims[i - collapsed]; - if (i - collapsed > 0 && - offset_dim < offset_or_window_dims[i - collapsed - 1]) { + if (i - collapsed_or_batching > 0 && + offset_or_window_dims[i - collapsed_or_batching] < + offset_or_window_dims[i - collapsed_or_batching - 1]) { // Output offsets are transposed, we do not support this case. continue; } @@ -1397,22 +1493,25 @@ absl::InlinedVector GetGatherScatterOperandPassthroughOutputOrUpdateDims( const int64_t output_or_update_rank, const Shape& operand_shape, absl::Span collapsed_or_inserted_dims, + absl::Span operand_batching_dims, absl::Span index_map, absl::Span offset_or_window_dims, absl::Span slice_size) { auto operand_passthrough_dims = GetGatherScatterOperandPassthroughOperandDims( - operand_shape, collapsed_or_inserted_dims, index_map, - offset_or_window_dims, slice_size); + operand_shape, collapsed_or_inserted_dims, operand_batching_dims, + index_map, offset_or_window_dims, slice_size); absl::InlinedVector passthrough_dims; - int64_t collapsed = 0; + int64_t collapsed_or_batching = 0; for (int64_t i = 0; i < operand_shape.rank(); ++i) { - if (absl::c_linear_search(collapsed_or_inserted_dims, i)) { - collapsed++; + if (absl::c_linear_search(collapsed_or_inserted_dims, i) || + absl::c_linear_search(operand_batching_dims, i)) { + collapsed_or_batching++; + continue; } if (!absl::c_linear_search(operand_passthrough_dims, i)) { continue; } - int64_t offset_dim = offset_or_window_dims[i - collapsed]; + int64_t offset_dim = offset_or_window_dims[i - collapsed_or_batching]; passthrough_dims.push_back(offset_dim); } return passthrough_dims; @@ -1426,6 +1525,7 @@ std::optional PassthroughOperandToGatherOutputOrScatterUpdate( const Shape& operand_shape, const HloSharding& operand_sharding, const int64_t output_or_update_rank, absl::Span collapsed_or_inserted_dims, + absl::Span operand_batching_dims, absl::Span index_map, absl::Span offset_or_window_dims, absl::Span slice_size, const int64_t index_vector_dim) { @@ -1433,18 +1533,20 @@ std::optional PassthroughOperandToGatherOutputOrScatterUpdate( return std::nullopt; } auto operand_passthrough_dims = GetGatherScatterOperandPassthroughOperandDims( - operand_shape, collapsed_or_inserted_dims, index_map, - offset_or_window_dims, slice_size); + operand_shape, collapsed_or_inserted_dims, operand_batching_dims, + index_map, offset_or_window_dims, slice_size); DimensionVector passthrough_tile(output_or_update_rank, 1); - int64_t collapsed = 0; + int64_t collapsed_or_batching = 0; for (int64_t i = 0; i < operand_shape.rank(); ++i) { - if (absl::c_linear_search(collapsed_or_inserted_dims, i)) { - collapsed++; + if (absl::c_linear_search(collapsed_or_inserted_dims, i) || + absl::c_linear_search(operand_batching_dims, i)) { + collapsed_or_batching++; + continue; } if (!absl::c_linear_search(operand_passthrough_dims, i)) { continue; } - int64_t offset_dim = offset_or_window_dims[i - collapsed]; + int64_t offset_dim = offset_or_window_dims[i - collapsed_or_batching]; passthrough_tile[offset_dim] = operand_sharding.tile_assignment().dim(i); } HloSharding replicate_non_passthrough_dims = @@ -1475,6 +1577,7 @@ std::optional PassthroughOperandToGatherOutputOrScatterUpdate( std::optional PassthroughGatherOutputOrScatterUpdateToOperand( const Shape& operand_shape, const HloSharding& output_or_update_sharding, absl::Span collapsed_or_inserted_dims, + absl::Span operand_batching_dims, absl::Span index_map, absl::Span offset_or_window_dims, absl::Span slice_size) { @@ -1483,20 +1586,22 @@ std::optional PassthroughGatherOutputOrScatterUpdateToOperand( return output_or_update_sharding; } auto operand_passthrough_dims = GetGatherScatterOperandPassthroughOperandDims( - operand_shape, collapsed_or_inserted_dims, index_map, - offset_or_window_dims, slice_size); + operand_shape, collapsed_or_inserted_dims, operand_batching_dims, + index_map, offset_or_window_dims, slice_size); DimensionVector passthrough_tile(operand_shape.rank(), 1); - int64_t collapsed = 0; + int64_t collapsed_or_batching = 0; // Relevant dims have shardings passed to the operand. DimensionVector relevant_output_or_update_dims; for (int64_t i = 0; i < operand_shape.rank(); ++i) { - if (absl::c_linear_search(collapsed_or_inserted_dims, i)) { - collapsed++; + if (absl::c_linear_search(collapsed_or_inserted_dims, i) || + absl::c_linear_search(operand_batching_dims, i)) { + collapsed_or_batching++; + continue; } if (!absl::c_linear_search(operand_passthrough_dims, i)) { continue; } - int64_t offset_dim = offset_or_window_dims[i - collapsed]; + int64_t offset_dim = offset_or_window_dims[i - collapsed_or_batching]; passthrough_tile[i] = output_or_update_sharding.tile_assignment().dim(offset_dim); relevant_output_or_update_dims.push_back(offset_dim); @@ -1528,71 +1633,37 @@ std::optional GatherOperandShardingFromOutputParallelDimensions( if (output_sharding.IsTileMaximal() || output_sharding.IsManual()) { return output_sharding; } - auto parallel_dims = GetGatherParallelBatchDims(gather, call_graph); - if (parallel_dims) { - auto output_parallel_dims = - GetGatherParallelOutputDims(gather, *parallel_dims); - auto output_aligned_operand_parallel_dims = - IndexAlignedOperandParallelDims(*parallel_dims); - const Shape gather_shape = gather.shape(); - CHECK_EQ(output_parallel_dims.size(), - output_aligned_operand_parallel_dims.size()); - DimensionVector operand_tile_assignment(gather.operand(0)->shape().rank(), - 1); - DimensionVector relevant_output_dims; - for (int i = 0, parallel_idx = 0; i < gather_shape.rank(); ++i) { - if (parallel_idx >= output_parallel_dims.size() || - output_parallel_dims[parallel_idx] != i) { - continue; - } - const int64_t operand_dim = - output_aligned_operand_parallel_dims[parallel_idx++]; - operand_tile_assignment[operand_dim] = - output_sharding.tile_assignment().dim(i); - relevant_output_dims.push_back(i); - } - HloSharding relevant_output_sharding = - PartiallyReplicateTiledShardingOnAllDimsExcept(output_sharding, - relevant_output_dims); - if (relevant_output_sharding.IsTileMaximal()) { - return std::move(relevant_output_sharding); - } - - for (int64_t i = relevant_output_sharding.TiledDataRank(); - i < relevant_output_sharding.tile_assignment().num_dimensions(); ++i) { - operand_tile_assignment.push_back( - relevant_output_sharding.tile_assignment().dim(i)); - } - auto tile_assignment = relevant_output_sharding.tile_assignment().Reshape( - operand_tile_assignment); - return relevant_output_sharding.ReplicateOnLastTileDim() - ? HloSharding::PartialTile(tile_assignment, - output_sharding.metadata()) - : HloSharding::Subgroup( - tile_assignment, relevant_output_sharding.subgroup_types(), - output_sharding.metadata()); + + GatherScatterParallelDims parallel_dims; + + const GatherDimensionNumbers& dnums = gather.gather_dimension_numbers(); + if (!dnums.operand_batching_dims().empty()) { + parallel_dims.operand_parallel_dims.assign( + dnums.operand_batching_dims().begin(), + dnums.operand_batching_dims().end()); + parallel_dims.indices_parallel_dims.assign( + dnums.start_indices_batching_dims().begin(), + dnums.start_indices_batching_dims().end()); + } + if (std::optional implicit_parallel_dims = + GetGatherParallelBatchDims(gather, call_graph)) { + parallel_dims.operand_parallel_dims.insert( + parallel_dims.operand_parallel_dims.end(), + implicit_parallel_dims->operand_parallel_dims.begin(), + implicit_parallel_dims->operand_parallel_dims.end()); + parallel_dims.indices_parallel_dims.insert( + parallel_dims.indices_parallel_dims.end(), + implicit_parallel_dims->indices_parallel_dims.begin(), + implicit_parallel_dims->indices_parallel_dims.end()); } - return std::nullopt; -} -// Reorders `to_align` based on the order of how `target_permuted` is reordered -// from `target`, expecting the container size to be small. -absl::InlinedVector AlignSmallContainers( - absl::Span to_align, absl::Span target, - absl::Span target_permuted) { - CHECK(absl::c_is_permutation(target_permuted, target)); - CHECK_EQ(to_align.size(), target.size()); - absl::InlinedVector to_align_permuted(to_align.size()); - for (auto i = 0; i < target.size(); ++i) { - // This is small so just look linearly. - for (auto j = 0; j < target_permuted.size(); ++j) { - if (target_permuted[j] == target[i]) { - to_align_permuted[j] = to_align[i]; - break; - } - } + if (parallel_dims.operand_parallel_dims.empty()) { + return std::nullopt; } - return to_align_permuted; + + return PropagateShardingAlongDimsAndReplicateOthers( + output_sharding, GetGatherParallelOutputDims(gather, parallel_dims), + parallel_dims.operand_parallel_dims, gather.operand(0)->shape().rank()); } } // namespace @@ -1609,27 +1680,17 @@ GatherOutputShardingFromOperandOperandPassthroughDimensions( const Shape& operand_shape, const HloSharding& operand_sharding, const HloInstruction& hlo, absl::Span slice_sizes) { const auto& dnums = hlo.gather_dimension_numbers(); - std::vector collapsed_slice_dims( - dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end()); - std::vector start_index_map(dnums.start_index_map().begin(), - dnums.start_index_map().end()); - std::vector offset_dims(dnums.offset_dims().begin(), - dnums.offset_dims().end()); return PassthroughOperandToGatherOutputOrScatterUpdate( - operand_shape, operand_sharding, hlo.shape().rank(), collapsed_slice_dims, - start_index_map, offset_dims, slice_sizes, dnums.index_vector_dim()); + operand_shape, operand_sharding, hlo.shape().rank(), + dnums.collapsed_slice_dims(), dnums.operand_batching_dims(), + dnums.start_index_map(), dnums.offset_dims(), slice_sizes, + dnums.index_vector_dim()); } std::optional GatherOperandShardingFromOutput( const HloSharding& output_sharding, const HloInstruction& hlo, const CallGraph& call_graph) { const auto& dnums = hlo.gather_dimension_numbers(); - std::vector collapsed_slice_dims( - dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end()); - std::vector start_index_map(dnums.start_index_map().begin(), - dnums.start_index_map().end()); - std::vector offset_dims(dnums.offset_dims().begin(), - dnums.offset_dims().end()); // Prioritize parallel sharding first as this is how it is in // spmd_partitioner. std::optional parallel_sharding = @@ -1637,8 +1698,10 @@ std::optional GatherOperandShardingFromOutput( call_graph); std::optional passthrough_sharding = PassthroughGatherOutputOrScatterUpdateToOperand( - hlo.operand(0)->shape(), output_sharding, collapsed_slice_dims, - start_index_map, offset_dims, hlo.gather_slice_sizes()); + hlo.operand(0)->shape(), output_sharding, + dnums.collapsed_slice_dims(), dnums.operand_batching_dims(), + dnums.start_index_map(), dnums.offset_dims(), + hlo.gather_slice_sizes()); // Try to merge the two shardings or return the one that is present if only // one of the two is. if (!passthrough_sharding) { @@ -1664,7 +1727,8 @@ std::vector GetScatterSliceSize(const Shape& operand_shape, std::vector slice_size(operand_shape.rank(), 1); int64_t num_update_window_dims = 0; for (int64_t i = 0; i < operand_shape.rank(); ++i) { - if (absl::c_linear_search(dnums.inserted_window_dims(), i)) { + if (absl::c_linear_search(dnums.inserted_window_dims(), i) || + absl::c_linear_search(dnums.input_batching_dims(), i)) { continue; } slice_size[i] = update_shape.dimensions( @@ -1677,19 +1741,13 @@ std::vector GetScatterSliceSize(const Shape& operand_shape, std::optional ScatterOutputShardingFromUpdate( const HloSharding& update_sharding, const HloScatterInstruction& scatter) { const auto& dnums = scatter.scatter_dimension_numbers(); - std::vector inserted_window_dims( - dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end()); - std::vector scatter_dims_to_operand_dims( - dnums.scatter_dims_to_operand_dims().begin(), - dnums.scatter_dims_to_operand_dims().end()); - std::vector update_window_dims(dnums.update_window_dims().begin(), - dnums.update_window_dims().end()); std::vector slice_size = GetScatterSliceSize(scatter.scatter_operands()[0]->shape(), scatter.scatter_updates()[0]->shape(), dnums); return PassthroughGatherOutputOrScatterUpdateToOperand( scatter.scatter_operands()[0]->shape(), update_sharding, - inserted_window_dims, scatter_dims_to_operand_dims, update_window_dims, + dnums.inserted_window_dims(), dnums.input_batching_dims(), + dnums.scatter_dims_to_operand_dims(), dnums.update_window_dims(), slice_size); } @@ -1744,18 +1802,12 @@ ScatterUpdateShardingFromOutputOperandPassthroughDimensions( const HloScatterInstruction* scatter = DynCast(&hlo); CHECK(scatter); const auto& dnums = scatter->scatter_dimension_numbers(); - std::vector inserted_window_dims( - dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end()); - std::vector scatter_dims_to_operand_dims( - dnums.scatter_dims_to_operand_dims().begin(), - dnums.scatter_dims_to_operand_dims().end()); - std::vector update_window_dims(dnums.update_window_dims().begin(), - dnums.update_window_dims().end()); return PassthroughOperandToGatherOutputOrScatterUpdate( output_shape, output_sharding, - scatter->scatter_updates()[0]->shape().rank(), inserted_window_dims, - scatter_dims_to_operand_dims, update_window_dims, slice_sizes, - dnums.index_vector_dim()); + scatter->scatter_updates()[0]->shape().rank(), + dnums.inserted_window_dims(), dnums.input_batching_dims(), + dnums.scatter_dims_to_operand_dims(), dnums.update_window_dims(), + slice_sizes, dnums.index_vector_dim()); } std::optional ScatterUpdateShardingFromOutputParallelDimensions( @@ -1764,58 +1816,37 @@ std::optional ScatterUpdateShardingFromOutputParallelDimensions( if (output_sharding.IsTileMaximal() || output_sharding.IsManual()) { return output_sharding; } - auto parallel_dims = GetScatterParallelBatchDims(scatter, call_graph); - if (parallel_dims) { - auto update_parallel_dims = - GetScatterParallelUpdateDims(scatter, *parallel_dims); - auto index_aligned_operand_parallel_dims = - IndexAlignedOperandParallelDims(*parallel_dims); - auto operand_parallel_dims_sorted = index_aligned_operand_parallel_dims; - absl::c_sort(operand_parallel_dims_sorted); - auto operand_aligned_update_parallel_dims = AlignSmallContainers( - update_parallel_dims, index_aligned_operand_parallel_dims, - operand_parallel_dims_sorted); - const Shape scatter_shape = scatter.shape().IsTuple() - ? scatter.shape().tuple_shapes()[0] - : scatter.shape(); - CHECK_EQ(update_parallel_dims.size(), - index_aligned_operand_parallel_dims.size()); - DimensionVector update_tile_assignment( - scatter.scatter_updates()[0]->shape().rank(), 1); - DimensionVector relevant_output_dims; - for (int i = 0, parallel_idx = 0; i < scatter_shape.rank(); ++i) { - if (parallel_idx >= operand_parallel_dims_sorted.size() || - operand_parallel_dims_sorted[parallel_idx] != i) { - continue; - } - const int64_t update_dim = - operand_aligned_update_parallel_dims[parallel_idx++]; - update_tile_assignment[update_dim] = - output_sharding.tile_assignment().dim(i); - relevant_output_dims.push_back(i); - } - HloSharding relevant_output_sharding = - PartiallyReplicateTiledShardingOnAllDimsExcept(output_sharding, - relevant_output_dims); - if (relevant_output_sharding.IsTileMaximal()) { - return std::move(relevant_output_sharding); - } - - for (int64_t i = relevant_output_sharding.TiledDataRank(); - i < relevant_output_sharding.tile_assignment().num_dimensions(); ++i) { - update_tile_assignment.push_back( - relevant_output_sharding.tile_assignment().dim(i)); - } - auto tile_assignment = relevant_output_sharding.tile_assignment().Reshape( - update_tile_assignment); - return relevant_output_sharding.ReplicateOnLastTileDim() - ? HloSharding::PartialTile(tile_assignment, - output_sharding.metadata()) - : HloSharding::Subgroup( - tile_assignment, relevant_output_sharding.subgroup_types(), - output_sharding.metadata()); + + GatherScatterParallelDims parallel_dims; + + const ScatterDimensionNumbers& dnums = scatter.scatter_dimension_numbers(); + if (!dnums.input_batching_dims().empty()) { + parallel_dims.operand_parallel_dims.assign( + dnums.input_batching_dims().begin(), dnums.input_batching_dims().end()); + parallel_dims.indices_parallel_dims.assign( + dnums.scatter_indices_batching_dims().begin(), + dnums.scatter_indices_batching_dims().end()); + } + if (std::optional implicit_parallel_dims = + GetScatterParallelBatchDims(scatter, call_graph)) { + parallel_dims.operand_parallel_dims.insert( + parallel_dims.operand_parallel_dims.end(), + implicit_parallel_dims->operand_parallel_dims.begin(), + implicit_parallel_dims->operand_parallel_dims.end()); + parallel_dims.indices_parallel_dims.insert( + parallel_dims.indices_parallel_dims.end(), + implicit_parallel_dims->indices_parallel_dims.begin(), + implicit_parallel_dims->indices_parallel_dims.end()); + } + + if (parallel_dims.operand_parallel_dims.empty()) { + return std::nullopt; } - return std::nullopt; + + return PropagateShardingAlongDimsAndReplicateOthers( + output_sharding, parallel_dims.operand_parallel_dims, + GetScatterParallelUpdateDims(scatter, parallel_dims), + scatter.scatter_updates()[0]->shape().rank()); } HloSharding GatherOutputOrScatterUpdateShardingFromIndicesParallelDimensions( @@ -2247,8 +2278,7 @@ std::optional GetGatherScatterBatchParallelDims( // %indices = concatenate(..., %iota.1, ...) // ... = gather(..., %indices) // is common for tf.reverse_sequence and would match this case. - const int num_indices = index_map.size(); - std::vector index_parallel_in_dim(num_indices, -1); + std::vector index_parallel_in_dim(index_map.size(), -1); // looks through any copies to find the concatenate. auto findConcatenate = [&](const HloInstruction* indices) { @@ -2320,8 +2350,8 @@ std::optional GetGatherScatterBatchParallelDims( } } if (!indices_parallel_dims.empty()) { - return GatherScatterParallelDims{ - indices_parallel_dims, operand_parallel_dims, index_parallel_in_dim}; + return GatherScatterParallelDims{indices_parallel_dims, + operand_parallel_dims}; } return std::nullopt; } @@ -2357,10 +2387,9 @@ std::optional GetScatterParallelBatchDims( static absl::InlinedVector GetGatherOutputOrScatterUpdateParallelDims( - const Shape& shape, const GatherScatterParallelDims& parallel_dim, + const Shape& shape, absl::Span indices_parallel_dims, int64_t index_vector_dim, absl::Span offset_or_window_dims) { absl::InlinedVector output_parallel_dims; - auto indices_parallel_dims = parallel_dim.indices_parallel_dims; for (int64_t indices_parallel_dim : indices_parallel_dims) { for (int i = 0, idx_dim = 0; i < shape.dimensions_size(); ++i) { if (absl::c_linear_search(offset_or_window_dims, i)) { @@ -2374,6 +2403,7 @@ GetGatherOutputOrScatterUpdateParallelDims( ++idx_dim; } } + CHECK_EQ(output_parallel_dims.size(), indices_parallel_dims.size()); return output_parallel_dims; } @@ -2385,7 +2415,8 @@ absl::InlinedVector GetGatherParallelOutputDims( int64_t index_vector_dim = dnums.index_vector_dim(); const auto& offset_dims = dnums.offset_dims(); return GetGatherOutputOrScatterUpdateParallelDims( - output_shape, parallel_dim, index_vector_dim, offset_dims); + output_shape, parallel_dim.indices_parallel_dims, index_vector_dim, + offset_dims); } absl::InlinedVector GetScatterParallelUpdateDims( @@ -2397,53 +2428,38 @@ absl::InlinedVector GetScatterParallelUpdateDims( int64_t index_vector_dim = dnums.index_vector_dim(); const auto& window_dims = dnums.update_window_dims(); return GetGatherOutputOrScatterUpdateParallelDims( - update_shape, parallel_dim, index_vector_dim, window_dims); + update_shape, parallel_dim.indices_parallel_dims, index_vector_dim, + window_dims); } absl::InlinedVector GetGatherOperandPassthroughOperandDims( const Shape& operand_shape, const HloInstruction& hlo, absl::Span slice_sizes) { const auto& dnums = hlo.gather_dimension_numbers(); - std::vector collapsed_slice_dims( - dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end()); - std::vector start_index_map(dnums.start_index_map().begin(), - dnums.start_index_map().end()); - std::vector offset_dims(dnums.offset_dims().begin(), - dnums.offset_dims().end()); return GetGatherScatterOperandPassthroughOperandDims( - operand_shape, collapsed_slice_dims, start_index_map, offset_dims, - slice_sizes); + operand_shape, dnums.collapsed_slice_dims(), + dnums.operand_batching_dims(), dnums.start_index_map(), + dnums.offset_dims(), slice_sizes); } absl::InlinedVector GetScatterOperandPassthroughOperandDims( const Shape& operand_shape, const HloSharding& operand_sharding, const HloInstruction& hlo, absl::Span slice_sizes) { const auto& dnums = hlo.scatter_dimension_numbers(); - std::vector inserted_window_dims( - dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end()); - std::vector scatter_dims_to_operand_dims( - dnums.scatter_dims_to_operand_dims().begin(), - dnums.scatter_dims_to_operand_dims().end()); - std::vector update_window_dims(dnums.update_window_dims().begin(), - dnums.update_window_dims().end()); return GetGatherScatterOperandPassthroughOperandDims( - operand_shape, inserted_window_dims, scatter_dims_to_operand_dims, - update_window_dims, slice_sizes); + operand_shape, dnums.inserted_window_dims(), dnums.input_batching_dims(), + dnums.scatter_dims_to_operand_dims(), dnums.update_window_dims(), + slice_sizes); } absl::InlinedVector GetGatherOperandPassthroughOutputDims( const Shape& output_shape, const Shape& operand_shape, const HloInstruction& hlo, absl::Span slice_sizes) { const auto& dnums = hlo.gather_dimension_numbers(); - std::vector collapsed_slice_dims( - dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end()); - std::vector start_index_map(dnums.start_index_map().begin(), - dnums.start_index_map().end()); - std::vector offset_dims(dnums.offset_dims().begin(), - dnums.offset_dims().end()); return GetGatherScatterOperandPassthroughOutputOrUpdateDims( - output_shape.rank(), operand_shape, collapsed_slice_dims, start_index_map, - offset_dims, slice_sizes); + output_shape.rank(), operand_shape, dnums.collapsed_slice_dims(), + dnums.operand_batching_dims(), dnums.start_index_map(), + dnums.offset_dims(), slice_sizes); } absl::InlinedVector GetScatterOperandPassthroughUpdateDims( @@ -2451,16 +2467,10 @@ absl::InlinedVector GetScatterOperandPassthroughUpdateDims( const HloSharding& operand_sharding, const HloInstruction& hlo, absl::Span slice_sizes) { const auto& dnums = hlo.scatter_dimension_numbers(); - std::vector inserted_window_dims( - dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end()); - std::vector scatter_dims_to_operand_dims( - dnums.scatter_dims_to_operand_dims().begin(), - dnums.scatter_dims_to_operand_dims().end()); - std::vector update_window_dims(dnums.update_window_dims().begin(), - dnums.update_window_dims().end()); return GetGatherScatterOperandPassthroughOutputOrUpdateDims( - update_shape.rank(), operand_shape, inserted_window_dims, - scatter_dims_to_operand_dims, update_window_dims, slice_sizes); + update_shape.rank(), operand_shape, dnums.inserted_window_dims(), + dnums.input_batching_dims(), dnums.scatter_dims_to_operand_dims(), + dnums.update_window_dims(), slice_sizes); } absl::InlinedVector GetGatherScatterIndexPassthroughIndexDims( @@ -2489,69 +2499,12 @@ GetGatherScatterIndexPassthroughOutputOrUpdateDims( } HloSharding InferGatherScatterParallelShardingFromOperandSharding( - const HloSharding& operand_sharding, const Shape& operand_shape, - const Shape& shape, + const HloSharding& operand_sharding, const Shape& shape, absl::Span output_aligned_operand_parallel_dims, absl::Span output_parallel_dims) { - if (operand_sharding.IsTileMaximal()) { - return operand_sharding; - } - std::vector output_tile_dims(shape.rank(), 1); - std::vector operand_non_parallel_dims; - operand_non_parallel_dims.reserve(operand_shape.rank()); - // Detect non parallel dimensions in the operand. - for (int i = 0; i < operand_shape.rank(); ++i) { - if (!absl::c_linear_search(output_aligned_operand_parallel_dims, i)) { - operand_non_parallel_dims.push_back(i); - } - } - // Collect tile dimensions in the operand. The order of the parallel - // dimensions in output_aligned_operand_parallel_dims is the same as that of - // the output - for (int i = 0; i < output_aligned_operand_parallel_dims.size(); ++i) { - const int64_t operand_idx = output_aligned_operand_parallel_dims[i]; - const int64_t output_idx = output_parallel_dims[i]; - output_tile_dims[output_idx] = - operand_sharding.tile_assignment().dim(operand_idx); - } - HloSharding replicate_non_parallel_dims = - PartiallyReplicateTiledShardingOnDims(operand_sharding, - operand_non_parallel_dims); - if (replicate_non_parallel_dims.IsTileMaximal()) { - return replicate_non_parallel_dims; - } - for (int64_t i = replicate_non_parallel_dims.TiledDataRank(); - i < replicate_non_parallel_dims.tile_assignment().num_dimensions(); - ++i) { - output_tile_dims.push_back( - replicate_non_parallel_dims.tile_assignment().dim(i)); - } - auto output_tile_assignment = - replicate_non_parallel_dims.tile_assignment().Reshape(output_tile_dims); - return replicate_non_parallel_dims.ReplicateOnLastTileDim() - ? HloSharding::PartialTile(output_tile_assignment, - replicate_non_parallel_dims.metadata()) - : HloSharding::Subgroup( - output_tile_assignment, - replicate_non_parallel_dims.subgroup_types(), - replicate_non_parallel_dims.metadata()); -} - -absl::InlinedVector IndexAlignedOperandParallelDims( - const GatherScatterParallelDims& parallel_dims) { - CHECK_EQ(parallel_dims.indices_parallel_dims.size(), - parallel_dims.operand_parallel_dims.size()); - std::vector index_parallel_in_dim = - parallel_dims.index_parallel_in_dim; - // Remove all -1s in `index_parallel_in_dim`. - index_parallel_in_dim.erase(std::remove(index_parallel_in_dim.begin(), - index_parallel_in_dim.end(), -1), - index_parallel_in_dim.end()); - // Populate the operand parallel dimensions based on the order of the index - // batch dims (which is the same order as the output). - return AlignSmallContainers(parallel_dims.operand_parallel_dims, - index_parallel_in_dim, - parallel_dims.indices_parallel_dims); + return PropagateShardingAlongDimsAndReplicateOthers( + operand_sharding, output_aligned_operand_parallel_dims, + output_parallel_dims, shape.rank()); } std::string GroupedSharding::ToString() const { diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.h b/third_party/xla/xla/hlo/utils/hlo_sharding_util.h index fc440f38f3215d..3233fa16245494 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.h +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.h @@ -44,7 +44,6 @@ namespace hlo_sharding_util { struct GatherScatterParallelDims { absl::InlinedVector indices_parallel_dims; absl::InlinedVector operand_parallel_dims; - std::vector index_parallel_in_dim; }; // Determines if the first operand 'potential_subsharding' is a subsharding of @@ -122,7 +121,7 @@ HloSharding TransposeSharding(const HloSharding& sharding, // maximal sharding returns the original sharding. std::optional ReshapeSharding(const Shape& source_shape, const Shape& target_shape, - const HloSharding& sharding); + const HloSharding& source_sharding); // Propagates sharding through reshape. It tries to find partial matches on // subsets of dimensions that could satisfy ReshapeSharding() constraints, then @@ -359,17 +358,10 @@ GetGatherScatterIndexPassthroughOutputOrUpdateDims( // Infer output sharding on index parallel dimensions for gather/scatter from // gather operand/indices or scatter operands/indices/updates. HloSharding InferGatherScatterParallelShardingFromOperandSharding( - const HloSharding& operand_sharding, const Shape& operand_shape, - const Shape& shape, + const HloSharding& operand_sharding, const Shape& shape, absl::Span output_aligned_operand_parallel_dims, absl::Span output_parallel_dims); -// Returns the parallel dimensions of the data operand of a gather/scatter with -// the order of the parallel dimensions matching that of the parallel dimensions -// of the indices. -absl::InlinedVector IndexAlignedOperandParallelDims( - const GatherScatterParallelDims& parallel_dims); - // Represents grouping devices in a tiled sharding along certain dimensions. // Elements in group dimensions define different device groups, and the sharding // represents the in-group sharding. diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc index fcbc4a4cd4bbdf..75c81215cb6916 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc @@ -29,11 +29,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/tile_assignment.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/service/dot_as_convolution_util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" -#include "xla/tests/hlo_test_base.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -132,6 +132,65 @@ TEST(HloShardingUtilTest, TransposeShardingWithCollapsedDimsSubgroupManual) { output); } +TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned1) { + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 16}); + Shape output_shape = ShapeUtil::MakeShape(F32, {2, 16}); + HloSharding input_sharding = HloSharding::IotaTile({3, 2, 2}); + HloSharding output_sharding = + HloSharding::PartialTile(TileAssignment({2, 2, 3}, {3, 2, 2}, {1, 2, 0})); + std::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned2) { + Shape input_shape = ShapeUtil::MakeShape(F32, {2, 1, 16}); + Shape output_shape = ShapeUtil::MakeShape(F32, {2, 16}); + HloSharding input_sharding = HloSharding::IotaTile({2, 3, 2}); + HloSharding output_sharding = + HloSharding::PartialTile(TileAssignment({2, 2, 3}, {2, 3, 2}, {0, 2, 1})); + std::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned3) { + Shape input_shape = ShapeUtil::MakeShape(F32, {2, 1, 16}); + Shape output_shape = ShapeUtil::MakeShape(F32, {32}); + HloSharding input_sharding = HloSharding::IotaTile({2, 3, 2}); + HloSharding output_sharding = + HloSharding::PartialTile(TileAssignment({4, 3}, {2, 3, 2}, {0, 2, 1})); + std::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned4) { + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 32}); + Shape output_shape = ShapeUtil::MakeShape(F32, {2, 16}); + HloSharding input_sharding = HloSharding::IotaTile({3, 4}); + HloSharding output_sharding = + HloSharding::PartialTile(TileAssignment({2, 2, 3}, {3, 4}, {1, 0})); + std::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + +TEST(HloShardingUtilTest, ReshapeShardingDimensionSizeOnePartitioned5) { + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 32}); + Shape output_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 16}); + HloSharding input_sharding = HloSharding::IotaTile({2, 3, 4}); + HloSharding output_sharding = HloSharding::IotaTile({2, 3, 2, 2}); + std::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + TEST(HloShardingUtilTest, ReshapeShardingMaximal) { Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3, 5}); Shape output_shape = ShapeUtil::MakeShape(F32, {3, 5, 2}); @@ -1027,7 +1086,7 @@ TEST(HloShardingUtilTest, UntileShape) { ShapeUtil::MakeTupleShape({expected_shape_0, expected_shape_1})); } -using HloShardingUtilTestWithHlo = HloTestBase; +using HloShardingUtilTestWithHlo = HloHardwareIndependentTestBase; TEST_F(HloShardingUtilTestWithHlo, InferDotOperandShardingTest1) { absl::string_view hlo_string = R"( diff --git a/third_party/xla/xla/internal/BUILD b/third_party/xla/xla/internal/BUILD new file mode 100644 index 00000000000000..910c821b378351 --- /dev/null +++ b/third_party/xla/xla/internal/BUILD @@ -0,0 +1,10 @@ +load("//xla/internal:package_groups.bzl", "xla_internal_packages") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:private"], + licenses = ["notice"], +) + +# Required to load the package groups. +xla_internal_packages() diff --git a/third_party/xla/xla/internal/README b/third_party/xla/xla/internal/README new file mode 100644 index 00000000000000..1e325ba14d8e17 --- /dev/null +++ b/third_party/xla/xla/internal/README @@ -0,0 +1,5 @@ +These files are internal to XLA and should not be exposed externally. +Integrators should use the PJRT API. See https://openxla.org/xla/pjrt_integration + +Even tools within XLA should NOT use these files directly and instead use the +PJRT API to ensure consistent usage. diff --git a/third_party/xla/xla/internal/package_groups.bzl b/third_party/xla/xla/internal/package_groups.bzl new file mode 100644 index 00000000000000..32848df74ff71a --- /dev/null +++ b/third_party/xla/xla/internal/package_groups.bzl @@ -0,0 +1,7 @@ +"""Package groups for XLA internal.""" + +def xla_internal_packages(name = "xla_internal_packages"): + native.package_group( + name = "hwi_internal", + packages = ["//..."], + ) diff --git a/third_party/xla/xla/iterator_util.h b/third_party/xla/xla/iterator_util.h index 80001e4a9b2996..2457348e2d3f8e 100644 --- a/third_party/xla/xla/iterator_util.h +++ b/third_party/xla/xla/iterator_util.h @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "tsl/lib/gtl/iterator_range.h" +#include "xla/tsl/lib/gtl/iterator_range.h" namespace xla { diff --git a/third_party/xla/xla/layout_util.cc b/third_party/xla/xla/layout_util.cc index c5c6db9392def8..5c94af35f41908 100644 --- a/third_party/xla/xla/layout_util.cc +++ b/third_party/xla/xla/layout_util.cc @@ -623,8 +623,8 @@ absl::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { return CopyLayoutInternal(src, dst); } -/* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs, - const Shape& rhs) { +/* static */ bool LayoutUtil::LayoutsInShapesEqual( + const Shape& lhs, const Shape& rhs, std::optional equal) { if (lhs.IsTuple()) { if (!rhs.IsTuple() || ShapeUtil::TupleElementCount(lhs) != ShapeUtil::TupleElementCount(rhs)) { @@ -647,6 +647,11 @@ absl::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) { if (!lhs.has_layout() || !rhs.has_layout()) { return false; } + + if (equal.has_value()) { + return equal.value()(lhs.layout(), rhs.layout()); + } + return LayoutUtil::Equal(lhs.layout(), rhs.layout()); } // Layouts of non-array and non-tuple shapes is ignored. diff --git a/third_party/xla/xla/layout_util.h b/third_party/xla/xla/layout_util.h index 8192c2bbb7a052..f49c25b3f0079b 100644 --- a/third_party/xla/xla/layout_util.h +++ b/third_party/xla/xla/layout_util.h @@ -250,7 +250,9 @@ class LayoutUtil { // lhs and rhs need not be compatible to have the same layout but the two // shapes must have the same tuple structure (if any) and arrays must have the // same rank. Element type is ignored. - static bool LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs); + static bool LayoutsInShapesEqual( + const Shape& lhs, const Shape& rhs, + std::optional equal = std::nullopt); // Returns whether the given dimensions are consecutive in the given layout, // not necessarily in the order given. diff --git a/third_party/xla/xla/lit.bzl b/third_party/xla/xla/lit.bzl index 5837c54ad81eab..5ac1cde98f1d8c 100644 --- a/third_party/xla/xla/lit.bzl +++ b/third_party/xla/xla/lit.bzl @@ -1,6 +1,7 @@ """Helper rules for writing LIT tests.""" load("@bazel_skylib//lib:paths.bzl", "paths") +load("@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") load("//xla/tsl:tsl.bzl", "if_cuda_tools", "if_google", "if_oss") def enforce_glob(files, **kwargs): @@ -209,7 +210,11 @@ def lit_test( srcs = tools, bin_dir = bin_dir, lib_dir = lib_dir, - deps = ["//xla/stream_executor/cuda:all_runtime"], + deps = if_cuda_is_configured( + [ + "//xla/stream_executor/cuda:all_runtime", + ], + ), visibility = ["//visibility:private"], **kwargs ) diff --git a/third_party/xla/xla/literal.cc b/third_party/xla/xla/literal.cc index c1026718435087..ed0716f8ec50e8 100644 --- a/third_party/xla/xla/literal.cc +++ b/third_party/xla/xla/literal.cc @@ -91,9 +91,10 @@ bool LiteralProtoHasValues(const LiteralProto& proto) { !proto.s16s().empty() || proto.s32s_size() || proto.s64s_size() || !proto.u2s().empty() || !proto.u4s().empty() || !proto.u8s().empty() || !proto.u16s().empty() || proto.u32s_size() || proto.u64s_size() || - !proto.f8e5m2s().empty() || !proto.f8e4m3fns().empty() || - !proto.f8e4m3b11fnuzs().empty() || !proto.f8e5m2fnuzs().empty() || - !proto.f8e4m3fnuzs().empty() || !proto.f16s().empty() || + !proto.f8e5m2s().empty() || !proto.f8e4m3s().empty() || + !proto.f8e4m3fns().empty() || !proto.f8e4m3b11fnuzs().empty() || + !proto.f8e5m2fnuzs().empty() || !proto.f8e4m3fnuzs().empty() || + !proto.f8e3m4s().empty() || !proto.f16s().empty() || !proto.bf16s().empty() || proto.f32s_size() || proto.f64s_size() || proto.c64s_size() || proto.c128s_size() || proto.preds_size() || proto.tuple_literals_size(); @@ -1684,7 +1685,15 @@ void ConvertBetweenNativeTypes(absl::Span src_data, return std::numeric_limits::lowest(); } } - return static_cast(src); + // TODO(b/370786669): Once ml_dtypes is updated to include + // https://github.com/jax-ml/ml_dtypes/pull/205, do not special-case e3m4 by + // casting to half first. + if constexpr (sizeof(src) == 1 && + std::is_same_v) { + return static_cast(static_cast(src)); + } else { + return static_cast(src); + } }; NativeDestT* dest_data = static_cast(dst_base); @@ -1950,7 +1959,7 @@ template static bool AllElementsEqualValue(absl::Span data, NativeT value) { for (int64_t i = 0; i < data.size(); ++i) { - if (!EqualIncludingNan(data[i], value)) { + if (memcmp(&data[i], &value, sizeof value)) { return false; } } @@ -2258,6 +2267,11 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { reinterpret_cast(data().data()), size_bytes_dense()); break; + case F8E4M3: + *proto->mutable_f8e4m3s() = std::string( + reinterpret_cast(data().data()), + size_bytes_dense()); + break; case F8E4M3FN: *proto->mutable_f8e4m3fns() = std::string( reinterpret_cast(data().data()), @@ -2278,6 +2292,11 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { reinterpret_cast(data().data()), size_bytes_dense()); break; + case F8E3M4: + *proto->mutable_f8e3m4s() = std::string( + reinterpret_cast(data().data()), + size_bytes_dense()); + break; case F16: *proto->mutable_f16s() = std::string(reinterpret_cast(data().data()), @@ -2436,6 +2455,13 @@ absl::Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { memcpy(untyped_data(), s.data(), s.size()); break; } + case F8E4M3: { + const std::string& s(proto.f8e4m3s()); + TF_RET_CHECK(data().size() * sizeof(tsl::float8_e4m3) == + s.size()); + memcpy(untyped_data(), s.data(), s.size()); + break; + } case F8E4M3FN: { const std::string& s(proto.f8e4m3fns()); TF_RET_CHECK(data().size() * @@ -2468,6 +2494,13 @@ absl::Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { memcpy(untyped_data(), s.data(), s.size()); break; } + case F8E3M4: { + const std::string& s(proto.f8e3m4s()); + TF_RET_CHECK(data().size() * sizeof(tsl::float8_e3m4) == + s.size()); + memcpy(untyped_data(), s.data(), s.size()); + break; + } case F16: { const std::string& s(proto.f16s()); TF_RET_CHECK(data().size() * sizeof(half) == s.size()); diff --git a/third_party/xla/xla/literal_comparison.cc b/third_party/xla/xla/literal_comparison.cc index fa3a7cda9824cd..c97629594122bb 100644 --- a/third_party/xla/xla/literal_comparison.cc +++ b/third_party/xla/xla/literal_comparison.cc @@ -354,8 +354,16 @@ class NearComparator { return primitive_util::FloatingPointTypeSwitch( [&](const auto kType) -> int { using NarrowNativeT = primitive_util::NativeTypeOf; - return CalculateDistanceInFloats(NarrowNativeT(expected), - NarrowNativeT(actual)); + // TODO(b/370786669): Once ml_dtypes is updated to include + // https://github.com/jax-ml/ml_dtypes/pull/205, do not special-case + // e3m4 by casting to half first. + if constexpr (std::is_same_v) { + return CalculateDistanceInFloats(NarrowNativeT(half(expected)), + NarrowNativeT(half(actual))); + } else { + return CalculateDistanceInFloats(NarrowNativeT(expected), + NarrowNativeT(actual)); + } }, error_.low_precision_fp_error_spec.type); } diff --git a/third_party/xla/xla/literal_comparison_test.cc b/third_party/xla/xla/literal_comparison_test.cc index 37b7c31f267104..7713aceaaa3bc5 100644 --- a/third_party/xla/xla/literal_comparison_test.cc +++ b/third_party/xla/xla/literal_comparison_test.cc @@ -29,14 +29,15 @@ namespace { template class LiteralComparisonTest : public ::testing::Test {}; -using TestedTypes = ::testing::Types; +using TestedTypes = + ::testing::Types; TYPED_TEST_SUITE(LiteralComparisonTest, TestedTypes); TYPED_TEST(LiteralComparisonTest, CompareNear_Equal) { auto actual = LiteralUtil::CreateR0(TypeParam(8.0)); auto expected = LiteralUtil::CreateR0(TypeParam(8.0)); - TF_EXPECT_OK(literal_comparison::Near(actual, expected, ErrorSpec(0.0, 0.0), + TF_EXPECT_OK(literal_comparison::Near(expected, actual, ErrorSpec(0.0, 0.0), /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); } @@ -44,15 +45,19 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_Equal) { TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_1ulp) { PrimitiveType type = primitive_util::NativeToPrimitiveType(); auto actual = LiteralUtil::CreateR0(TypeParam(8.0)); - float expV = type == F8E5M2 ? 10.0 : 9.0; + float expV = 9.0; // F8E4M3* + if (type == F8E5M2) + expV = 10.0; + else if (type == F8E3M4) + expV = 8.5; auto expected = LiteralUtil::CreateR0(TypeParam{expV}); auto error_spec = ErrorSpec(0.0, 0.0); - EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec, + EXPECT_IS_NOT_OK(literal_comparison::Near(expected, actual, error_spec, /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); error_spec.low_precision_fp_error_spec.type = type; error_spec.low_precision_fp_error_spec.within_n_values = 1; - EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec, + EXPECT_IS_OK(literal_comparison::Near(expected, actual, error_spec, /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); } @@ -60,17 +65,21 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_1ulp) { TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_4ulps) { PrimitiveType type = primitive_util::NativeToPrimitiveType(); auto actual = LiteralUtil::CreateR0(TypeParam(8.0)); - float expV = type == F8E5M2 ? 14.0 : 12.0; + float expV = 12.0; // F8E4M3* + if (type == F8E5M2) + expV = 14.0; + else if (type == F8E3M4) + expV = 10.0; auto expected = LiteralUtil::CreateR0(TypeParam{expV}); auto error_spec = ErrorSpec(0.0, 0.0); error_spec.low_precision_fp_error_spec.type = type; error_spec.low_precision_fp_error_spec.within_n_values = 1; - EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec, + EXPECT_IS_NOT_OK(literal_comparison::Near(expected, actual, error_spec, /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); error_spec.low_precision_fp_error_spec.type = type; error_spec.low_precision_fp_error_spec.within_n_values = 4; - EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec, + EXPECT_IS_OK(literal_comparison::Near(expected, actual, error_spec, /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); } @@ -78,17 +87,21 @@ TYPED_TEST(LiteralComparisonTest, CompareNear_NotEqual_4ulps) { TYPED_TEST(LiteralComparisonTest, FloatUsingCompareNear_NotEqual_4ulps) { PrimitiveType type = primitive_util::NativeToPrimitiveType(); auto actual = LiteralUtil::CreateR0(8.0); - float expV = type == F8E5M2 ? 13.0 : 12.1; + float expV = 12.1; // F8E4M3* + if (type == F8E5M2) + expV = 13.0; + else if (type == F8E3M4) + expV = 10.125; auto expected = LiteralUtil::CreateR0(expV); auto error_spec = ErrorSpec(0.0, 0.0); error_spec.low_precision_fp_error_spec.type = type; error_spec.low_precision_fp_error_spec.within_n_values = 1; - EXPECT_IS_NOT_OK(literal_comparison::Near(actual, expected, error_spec, + EXPECT_IS_NOT_OK(literal_comparison::Near(expected, actual, error_spec, /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); error_spec.low_precision_fp_error_spec.type = type; error_spec.low_precision_fp_error_spec.within_n_values = 4; - EXPECT_IS_OK(literal_comparison::Near(actual, expected, error_spec, + EXPECT_IS_OK(literal_comparison::Near(expected, actual, error_spec, /*detailed_message=*/false, /*miscompare_callback=*/nullptr)); } diff --git a/third_party/xla/xla/literal_test.cc b/third_party/xla/xla/literal_test.cc index 3ce2f675b1a5a1..65aa09040668fb 100644 --- a/third_party/xla/xla/literal_test.cc +++ b/third_party/xla/xla/literal_test.cc @@ -121,6 +121,17 @@ class LiteralUtilTest : public ::testing::Test { Literal literal_r4_2x2x3x3_dim0minor_; }; +template +class LiteralUtilFloatTest : public LiteralUtilTest {}; + +using FloatTypes = + ::testing::Types; + +TYPED_TEST_SUITE(LiteralUtilFloatTest, FloatTypes); + TEST_F(LiteralUtilTest, LiteralScalarToString) { auto true_lit = LiteralUtil::CreateR0(true); EXPECT_EQ("pred[] true", true_lit.ToString()); @@ -174,8 +185,12 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { EXPECT_EQ("f8e5m2[] 3", f8e5m2_lit_truncated.ToString()); auto f8e4m3_lit = + LiteralUtil::CreateR0(tsl::float8_e4m3(0.5)); + EXPECT_EQ("f8e4m3[] 0.5", f8e4m3_lit.ToString()); + + auto f8e4m3fn_lit = LiteralUtil::CreateR0(tsl::float8_e4m3fn(0.5)); - EXPECT_EQ("f8e4m3fn[] 0.5", f8e4m3_lit.ToString()); + EXPECT_EQ("f8e4m3fn[] 0.5", f8e4m3fn_lit.ToString()); auto f8e4m3b11fnuz_lit = LiteralUtil::CreateR0( tsl::float8_e4m3b11fnuz(0.5)); @@ -188,6 +203,10 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto f8e5m2fnuz_lit = LiteralUtil::CreateR0(tsl::float8_e5m2fnuz(0.5)); EXPECT_EQ("f8e5m2fnuz[] 0.5", f8e5m2fnuz_lit.ToString()); + + auto f8e3m4_lit = + LiteralUtil::CreateR0(tsl::float8_e3m4(0.5)); + EXPECT_EQ("f8e3m4[] 0.5", f8e3m4_lit.ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { @@ -640,20 +659,24 @@ TEST_F(LiteralUtilTest, IsAll) { bfloat16 b90(9.00f); EXPECT_TRUE(LiteralUtil::CreateR2({{b91}, {b90}}).IsAll(9.0)); - tsl::float8_e5m2 q16(8); - EXPECT_TRUE(LiteralUtil::CreateR1({q16}).IsAll(8)); + tsl::float8_e5m2 p16(8); + EXPECT_TRUE(LiteralUtil::CreateR1({p16}).IsAll(8)); // 9 rounds to 8 in E5M2 but is not equal to 8, so this should be false - EXPECT_FALSE(LiteralUtil::CreateR1({q16}).IsAll(9)); + EXPECT_FALSE(LiteralUtil::CreateR1({p16}).IsAll(9)); + + tsl::float8_e4m3 q16(9); // Exactly representable in e4m3 + EXPECT_FALSE(LiteralUtil::CreateR1({q16}).IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR1({q16}).IsAll(9)); - tsl::float8_e4m3fn r16(9); // Exactly representable in e4m3 + tsl::float8_e4m3fn r16(9); // Exactly representable in e4m3fn EXPECT_FALSE(LiteralUtil::CreateR1({r16}).IsAll(8)); EXPECT_TRUE(LiteralUtil::CreateR1({r16}).IsAll(9)); - tsl::float8_e4m3b11fnuz s16(9); // Exactly representable in e4m3 + tsl::float8_e4m3b11fnuz s16(9); // Exactly representable in e4m3b11fnuz EXPECT_FALSE(LiteralUtil::CreateR1({s16}).IsAll(8)); EXPECT_TRUE(LiteralUtil::CreateR1({s16}).IsAll(9)); - tsl::float8_e4m3fnuz t16(9); // Exactly representable in e4m3 + tsl::float8_e4m3fnuz t16(9); // Exactly representable in e4m3fnuz EXPECT_FALSE(LiteralUtil::CreateR1({t16}).IsAll(8)); EXPECT_TRUE(LiteralUtil::CreateR1({t16}).IsAll(9)); @@ -662,6 +685,10 @@ TEST_F(LiteralUtilTest, IsAll) { // 9 rounds to 8 in E5M2 but is not equal to 8, so this should be false EXPECT_FALSE(LiteralUtil::CreateR1({u16}).IsAll(9)); + tsl::float8_e3m4 v16(9); // Exactly representable in e3m4 + EXPECT_FALSE(LiteralUtil::CreateR1({v16}).IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR1({v16}).IsAll(9)); + complex64 c8_9 = {8, 9}; EXPECT_FALSE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}).IsAll(8)); @@ -734,6 +761,24 @@ TEST_F(LiteralUtilTest, IsAllFirst) { complex64 c7_9 = {7, 9}; EXPECT_TRUE(LiteralUtil::CreateR2({{c8_9}, {c8_9}}).IsAllFirst()); EXPECT_FALSE(LiteralUtil::CreateR2({{c7_9}, {c8_9}}).IsAllFirst()); + +#if defined(__x86_64__) && defined(_MM_DENORMALS_ZERO_ON) + int old_csr = _mm_getcsr(); + // Treat denormals as zero. This will make the small numbers below equal to + // 0.0, as far as the FP unit is concerned. + _mm_setcsr(old_csr | _MM_DENORMALS_ZERO_ON); +#endif + bool eq0 = LiteralUtil::CreateR1({0.0, 1.401298e-45}).IsAllFirst(); + bool eq1 = LiteralUtil::CreateR1({0.0, 2.802597e-45}).IsAllFirst(); + bool eq2 = + LiteralUtil::CreateR1({4.203895e-45, 7.006492e-45}).IsAllFirst(); +#if defined(__x86_64__) && defined(_MM_DENORMALS_ZERO_ON) + _mm_setcsr(old_csr); +#endif + + EXPECT_FALSE(eq0); + EXPECT_FALSE(eq1); + EXPECT_FALSE(eq2); } TEST_F(LiteralUtilTest, CountEqualInt) { @@ -812,7 +857,14 @@ template class LiteralUtilTestTemplated : public ::testing::Test {}; using TestedTypes = ::testing::Types; -TYPED_TEST_SUITE(LiteralUtilTestTemplated, TestedTypes); +class TestNamer { + public: + template + static std::string GetName(int) { + return ::testing::internal::GetTypeName(); + } +}; +TYPED_TEST_SUITE(LiteralUtilTestTemplated, TestedTypes, TestNamer); TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { // Make a non-integer for floating point types. @@ -1133,34 +1185,30 @@ TEST_F(LiteralUtilTest, PopulateR2C64) { EXPECT_EQ(output, expected); } -TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) { - Literal output(ShapeUtil::MakeShape(BF16, {})); - bfloat16 h(0.25f); - output.PopulateWithValue(h); - auto expected = LiteralUtil::CreateR0(h); +TYPED_TEST(LiteralUtilFloatTest, PopulateWithValueR0Float) { + Literal output(ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), {})); + TypeParam h(0.25f); + output.PopulateWithValue(h); + auto expected = LiteralUtil::CreateR0(h); EXPECT_EQ(output, expected); } -TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) { - Literal output(ShapeUtil::MakeShape(BF16, {3})); - bfloat16 h(0.5f); - output.PopulateWithValue(h); - auto expected = LiteralUtil::CreateR1({h, h, h}); +TYPED_TEST(LiteralUtilFloatTest, PopulateWithValueR1Float) { + Literal output(ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), {3})); + TypeParam h(0.5f); + output.PopulateWithValue(h); + auto expected = LiteralUtil::CreateR1({h, h, h}); EXPECT_EQ(output, expected); } -TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) { - Literal output(ShapeUtil::MakeShape(BF16, {2, 2})); - bfloat16 h(2.0f); - output.PopulateWithValue(h); - auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); - EXPECT_EQ(output, expected); -} - -TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { - Literal output(ShapeUtil::MakeShape(F32, {})); - output.PopulateWithValue(2.5f); - auto expected = LiteralUtil::CreateR0(2.5f); +TYPED_TEST(LiteralUtilFloatTest, PopulateWithValueR2Float) { + Literal output(ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), {2, 2})); + TypeParam h(2.0f); + output.PopulateWithValue(h); + auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); EXPECT_EQ(output, expected); } @@ -1194,70 +1242,6 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2C128) { EXPECT_EQ(output, expected); } -TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { - Literal output(ShapeUtil::MakeShape(F16, {})); - half h(0.25f); - output.PopulateWithValue(h); - auto expected = LiteralUtil::CreateR0(h); - EXPECT_EQ(output, expected); -} - -TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { - Literal output(ShapeUtil::MakeShape(F16, {3})); - half h(0.5f); - output.PopulateWithValue(h); - auto expected = LiteralUtil::CreateR1({h, h, h}); - EXPECT_EQ(output, expected); -} - -TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { - Literal output(ShapeUtil::MakeShape(F16, {2, 2})); - half h(2.0f); - output.PopulateWithValue(h); - auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); - EXPECT_EQ(output, expected); -} - -TEST_F(LiteralUtilTest, PopulateWithValueR0F8e5m2) { - Literal output(ShapeUtil::MakeShape(F8E5M2, {})); - tsl::float8_e5m2 x(0.25f); - output.PopulateWithValue(x); - auto expected = LiteralUtil::CreateR0(x); - EXPECT_EQ(output, expected); -} - -TEST_F(LiteralUtilTest, PopulateWithValueR1F8e4m3) { - Literal output(ShapeUtil::MakeShape(F8E4M3FN, {3})); - tsl::float8_e4m3fn x(0.5f); - output.PopulateWithValue(x); - auto expected = LiteralUtil::CreateR1({x, x, x}); - EXPECT_EQ(output, expected); -} - -TEST_F(LiteralUtilTest, PopulateWithValueR1F8e4m3b11) { - Literal output(ShapeUtil::MakeShape(F8E4M3B11FNUZ, {3})); - tsl::float8_e4m3b11fnuz x(0.5f); - output.PopulateWithValue(x); - auto expected = LiteralUtil::CreateR1({x, x, x}); - EXPECT_EQ(output, expected); -} - -TEST_F(LiteralUtilTest, PopulateWithValueR1F8e4m3fnuz) { - Literal output(ShapeUtil::MakeShape(F8E4M3FNUZ, {3})); - tsl::float8_e4m3fnuz x(0.5f); - output.PopulateWithValue(x); - auto expected = LiteralUtil::CreateR1({x, x, x}); - EXPECT_EQ(output, expected); -} - -TEST_F(LiteralUtilTest, PopulateWithValueR1F8e5m2fnuz) { - Literal output(ShapeUtil::MakeShape(F8E5M2FNUZ, {3})); - tsl::float8_e5m2fnuz x(0.5f); - output.PopulateWithValue(x); - auto expected = LiteralUtil::CreateR1({x, x, x}); - EXPECT_EQ(output, expected); -} - TEST_F(LiteralUtilTest, ReplicateR2U32) { auto input = LiteralUtil::CreateR2( {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); @@ -1738,92 +1722,70 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { EXPECT_EQ(c128.Convert(S32).status().code(), tsl::error::UNIMPLEMENTED); } -TEST_F(LiteralUtilTest, ConvertIfTypesMatchF8) { - auto s8 = LiteralUtil::CreateR2WithLayout({{0, 1}, {2, 3}}, - layout_r2_dim0major_); - auto f32 = LiteralUtil::CreateR2WithLayout({{0., 1.}, {2., 3.}}, - layout_r2_dim0major_); - auto c128 = LiteralUtil::CreateR2WithLayout({{0., 1.}, {2., 3.}}, - layout_r2_dim0major_); - using e5 = tsl::float8_e5m2; - auto f8e5m2 = LiteralUtil::CreateR2WithLayout( - {{e5{0.}, e5{1.}}, {e5{2.}, e5{3.}}}, layout_r2_dim0major_); - using e4 = tsl::float8_e4m3fn; - auto f8e4m3 = LiteralUtil::CreateR2WithLayout( - {{e4{0.}, e4{1.}}, {e4{2.}, e4{3.}}}, layout_r2_dim0major_); - using b11 = tsl::float8_e4m3b11fnuz; - auto f8e4m3b11 = LiteralUtil::CreateR2WithLayout( - {{b11{0.}, b11{1.}}, {b11{2.}, b11{3.}}}, layout_r2_dim0major_); - using e5f = tsl::float8_e5m2fnuz; - auto f8e5m2fnuz = LiteralUtil::CreateR2WithLayout( - {{e5f{0.}, e5f{1.}}, {e5f{2.}, e5f{3.}}}, layout_r2_dim0major_); - using e4f = tsl::float8_e4m3fnuz; - auto f8e4m3fnuz = LiteralUtil::CreateR2WithLayout( - {{e4f{0.}, e4f{1.}}, {e4f{2.}, e4f{3.}}}, layout_r2_dim0major_); - Literal conv; - - conv = s8.Convert(F8E5M2).value(); - EXPECT_EQ(conv, f8e5m2); - - conv = f32.Convert(F8E5M2).value(); - EXPECT_EQ(conv, f8e5m2); - - conv = f8e4m3.Convert(F8E5M2).value(); - EXPECT_EQ(conv, f8e5m2); - - conv = s8.Convert(F8E4M3FN).value(); - EXPECT_EQ(conv, f8e4m3); - - conv = f32.Convert(F8E4M3FN).value(); - EXPECT_EQ(conv, f8e4m3); - - conv = f8e5m2.Convert(F8E4M3FN).value(); - EXPECT_EQ(conv, f8e4m3); +TYPED_TEST(LiteralUtilFloatTest, ConvertIfTypesMatchF8) { + constexpr auto ptype = primitive_util::NativeToPrimitiveType(); + if (!primitive_util::IsF8Type(ptype)) { + GTEST_SKIP() << "Skipping test for non F8 types"; + } + auto s8 = LiteralUtil::CreateR2WithLayout( + {{0, 1}, {2, 3}}, LiteralUtilTest::layout_r2_dim0major_); + auto bf16 = LiteralUtil::CreateR2WithLayout( + {{bfloat16(0.), bfloat16(1.)}, {bfloat16(2.), bfloat16(3.)}}, + LiteralUtilTest::layout_r2_dim0major_); + auto f32 = LiteralUtil::CreateR2WithLayout( + {{0., 1.}, {2., 3.}}, LiteralUtilTest::layout_r2_dim0major_); + auto c128 = LiteralUtil::CreateR2WithLayout( + {{0., 1.}, {2., 3.}}, LiteralUtilTest::layout_r2_dim0major_); + // Let's also use a couple of popular F8 types as sources for conversion + using f8e5m2_t = tsl::float8_e5m2; + auto f8e5m2 = LiteralUtil::CreateR2WithLayout( + {{f8e5m2_t{0.}, f8e5m2_t{1.}}, {f8e5m2_t{2.}, f8e5m2_t{3.}}}, + LiteralUtilTest::layout_r2_dim0major_); + using e4m3fn_t = tsl::float8_e4m3fn; + auto f8e4m3fn = LiteralUtil::CreateR2WithLayout( + {{e4m3fn_t{0.}, e4m3fn_t{1.}}, {e4m3fn_t{2.}, e4m3fn_t{3.}}}, + LiteralUtilTest::layout_r2_dim0major_); + + auto f8 = LiteralUtil::CreateR2WithLayout( + {{TypeParam{0.}, TypeParam{1.}}, {TypeParam{2.}, TypeParam{3.}}}, + LiteralUtilTest::layout_r2_dim0major_); - conv = f8e5m2.Convert(S8).value(); - EXPECT_EQ(conv, s8); + Literal conv; - conv = f8e5m2.Convert(F32).value(); - EXPECT_EQ(conv, f32); + // Convert to f8 + conv = s8.Convert(ptype).value(); + EXPECT_EQ(conv, f8); - conv = f8e5m2.Convert(C128).value(); - EXPECT_EQ(conv, c128); + conv = bf16.Convert(ptype).value(); + EXPECT_EQ(conv, f8); - conv = f8e4m3.Convert(S8).value(); - EXPECT_EQ(conv, s8); + conv = f32.Convert(ptype).value(); + EXPECT_EQ(conv, f8); - conv = f8e4m3.Convert(F32).value(); - EXPECT_EQ(conv, f32); + conv = f8e5m2.Convert(ptype).value(); + EXPECT_EQ(conv, f8); - conv = f8e4m3.Convert(C128).value(); - EXPECT_EQ(conv, c128); + conv = f8e4m3fn.Convert(ptype).value(); + EXPECT_EQ(conv, f8); - conv = f8e4m3b11.Convert(S8).value(); + // Convert from f8 + conv = f8.Convert(S8).value(); EXPECT_EQ(conv, s8); - conv = f8e4m3b11.Convert(F32).value(); - EXPECT_EQ(conv, f32); + conv = f8.Convert(BF16).value(); + EXPECT_EQ(conv, bf16); - conv = f8e4m3b11.Convert(C128).value(); - EXPECT_EQ(conv, c128); - - conv = f8e5m2fnuz.Convert(S8).value(); - EXPECT_EQ(conv, s8); - - conv = f8e5m2fnuz.Convert(F32).value(); + conv = f8.Convert(F32).value(); EXPECT_EQ(conv, f32); - conv = f8e5m2fnuz.Convert(C128).value(); + conv = f8.Convert(C128).value(); EXPECT_EQ(conv, c128); - conv = f8e4m3fnuz.Convert(S8).value(); - EXPECT_EQ(conv, s8); - - conv = f8e4m3fnuz.Convert(F32).value(); - EXPECT_EQ(conv, f32); + conv = f8.Convert(F8E5M2).value(); + EXPECT_EQ(conv, f8e5m2); - conv = f8e4m3fnuz.Convert(C128).value(); - EXPECT_EQ(conv, c128); + conv = f8.Convert(F8E4M3FN).value(); + EXPECT_EQ(conv, f8e4m3fn); } TEST_F(LiteralUtilTest, BitcastConvert) { @@ -2255,9 +2217,12 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { using e5 = tsl::float8_e5m2; auto vector_f8e5m2 = LiteralUtil::CreateR1({e5{10.0}, e5{20.0}, e5{-32.0}}); - using e4 = tsl::float8_e4m3fn; + using e4 = tsl::float8_e4m3; auto vector_f8e4m3 = LiteralUtil::CreateR1({e4{10.0}, e4{20.0}, e4{-32.0}}); + using e4fn = tsl::float8_e4m3fn; + auto vector_f8e4m3fn = + LiteralUtil::CreateR1({e4fn{10.0}, e4fn{20.0}, e4fn{-32.0}}); using b11 = tsl::float8_e4m3b11fnuz; auto vector_f8e4m3b11 = LiteralUtil::CreateR1({b11{10.0}, b11{20.0}, b11{-30.0}}); @@ -2267,6 +2232,8 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { using e4f = tsl::float8_e4m3fnuz; auto vector_f8e4m3fnuz = LiteralUtil::CreateR1({e4f{10.0}, e4f{20.0}, e4f{-30.0}}); + using e3 = tsl::float8_e3m4; + auto vector_f8e3m4 = LiteralUtil::CreateR1({e3{2.5}, e3{5.0}, e3{-8.0}}); auto matrix_pred = LiteralUtil::CreateR2({{true, false, true}, {false, false, true}}); auto vector_s4 = LiteralUtil::CreateR1({s4{-1}, s4{3}, s4{7}}); @@ -2289,9 +2256,11 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { EXPECT_EQ(vector_bfloat16, to_from_proto(vector_bfloat16)); EXPECT_EQ(vector_f8e5m2, to_from_proto(vector_f8e5m2)); EXPECT_EQ(vector_f8e4m3, to_from_proto(vector_f8e4m3)); + EXPECT_EQ(vector_f8e4m3fn, to_from_proto(vector_f8e4m3fn)); EXPECT_EQ(vector_f8e4m3b11, to_from_proto(vector_f8e4m3b11)); EXPECT_EQ(vector_f8e5m2fnuz, to_from_proto(vector_f8e5m2fnuz)); EXPECT_EQ(vector_f8e4m3fnuz, to_from_proto(vector_f8e4m3fnuz)); + EXPECT_EQ(vector_f8e3m4, to_from_proto(vector_f8e3m4)); EXPECT_EQ(matrix_pred, to_from_proto(matrix_pred)); EXPECT_EQ(vector_s4, to_from_proto(vector_s4)); EXPECT_EQ(vector_u4, to_from_proto(vector_u4)); @@ -2576,6 +2545,18 @@ TEST_F(LiteralUtilTest, IsEqualAt) { tsl::float8_e4m3fnuz{val_double}); EXPECT_TRUE(c6.IsEqualAt({}, val_double)); EXPECT_TRUE(c6.IsEqualAt({}, val_integral)); + Literal c8 = + LiteralUtil::CreateR0(tsl::float8_e4m3{val_double}); + EXPECT_TRUE(c8.IsEqualAt({}, val_double)); + EXPECT_TRUE(c8.IsEqualAt({}, val_integral)); + Literal c9 = + LiteralUtil::CreateR0(tsl::float8_e4m3fn{val_double}); + EXPECT_TRUE(c9.IsEqualAt({}, val_double)); + EXPECT_TRUE(c9.IsEqualAt({}, val_integral)); + Literal c10 = + LiteralUtil::CreateR0(tsl::float8_e3m4{val_double}); + EXPECT_TRUE(c10.IsEqualAt({}, val_double)); + EXPECT_TRUE(c10.IsEqualAt({}, val_integral)); } TEST_F(LiteralUtilTest, CreateFromShapeWithUnknownLeafArrays) { @@ -2901,10 +2882,10 @@ class LiteralSerializationTest : public ::testing::Test, static std::vector GenerateSimpleParams() { std::vector params; for (PrimitiveType element_type : - {PRED, S4, U4, S8, U8, S16, - U16, S32, U32, S64, U64, F16, - F32, F64, BF16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, - F8E5M2FNUZ, F8E4M3FNUZ, C64, C128}) { + {PRED, S4, U4, S8, U8, S16, + U16, S32, U32, S64, U64, F16, + F32, F64, BF16, F8E5M2, F8E4M3, F8E4M3FN, + F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ, F8E3M4, C64, C128}) { for (const DimensionVector& dimensions : { DimensionVector{}, DimensionVector{0}, diff --git a/third_party/xla/xla/literal_util.cc b/third_party/xla/xla/literal_util.cc index 745194cdc24b39..9b5507327789a8 100644 --- a/third_party/xla/xla/literal_util.cc +++ b/third_party/xla/xla/literal_util.cc @@ -221,6 +221,184 @@ void SetScalarAtIndexImpl(MutableLiteralBase& literal, literal.Set(multi_index, scalar.Get({})); } +template +void PopulateWithIntNext(Literal* literal) { + using BitRepT = UnsignedIntegerTypeForSizeType; + // Duplicates may be generated if we don't have enough bits. + // Skip bfloat16 and float32 subnormals. + const FloatT kFirstValue = + std::is_same_v || sizeof(FloatT) >= sizeof(float) + ? std::numeric_limits::min() + : std::numeric_limits::denorm_min(); + // `current` keeps track of the next value we need to populate. + auto current = literal->data().begin(); + auto end = literal->data().end(); + // `sign` keeps track of the sign of the next value. + bool sign = false; + while (current != end) { + // We start populating values at zero and increase magnitude from there. + *current = sign ? static_cast(-0.0f) : static_cast(0.0f); + current++; + // The next value is either the smallest denormal or normal. + auto value = sign ? -kFirstValue : kFirstValue; + // Fill the array with values of increasing magnitude until we hit a + // non-finite value. + while (current != end && Eigen::numext::isfinite(value)) { + // Populate the value. + *current = value; + // Generate the next value by lexicographically increasing the bit + // representation. + const BitRepT next_value = Eigen::numext::bit_cast(value) + 1; + value = Eigen::numext::bit_cast(next_value); + current++; + } + // We ran out of finite values, flip the sign and begin again. + sign = !sign; + } +} + +template +void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) { + PopulateWithIntNext(literal); + std::shuffle(literal->data().begin(), literal->data().end(), + *engine); +} + +// Populates a floating point literal with random floating points sampled from a +// uniform-log distribution spanning approximately the entire range of the +// representable floating point. +template +void PopulateWithRandomFullRangeFloatingPointData(Literal* literal, + std::minstd_rand0* engine) { + constexpr float kSpecialValueProbability = 1e-6; + constexpr float kSpecialValues[] = {+0.F, + -0.F, + 1.F, + -1.F, + std::numeric_limits::infinity(), + -std::numeric_limits::infinity()}; + constexpr int kNumSpecialValues = sizeof(kSpecialValues) / sizeof(float); + std::uniform_real_distribution special_value_gen(0, 1); + + // Generates floating points with a log-uniform distribution. This causes the + // exponent of the floating point to have a uniform distribution. + const int min_exp = std::numeric_limits::min_exponent; + const int max_exp = std::numeric_limits::max_exponent; + std::uniform_real_distribution generator(min_exp - 1, max_exp - 1); + + for (FloatT& value : literal->data()) { + // Each special value has a kSpecialValueProbability chance to be generated + // instead of sampling using the normal distributions. + if (special_value_gen(*engine) < + kSpecialValueProbability * kNumSpecialValues) { + value = + static_cast(kSpecialValues[(*engine)() % kNumSpecialValues]); + } else { + float sign = ((*engine)() % 2 == 0) ? 1 : -1; + value = static_cast(pow(2, generator(*engine)) * sign); + } + } +} + +template +void PopulateWithRandomFloatingPointData(Literal* literal, + std::minstd_rand0* engine) { + std::uniform_real_distribution generator(-0.1f, 0.2f); + for (FloatT& value : literal->data()) { + value = static_cast(generator(*engine)); + } +} + +template +void PopulateWithFloatingPointData( + Literal* literal, std::minstd_rand0* engine, bool no_duplicates, + bool use_large_range, std::optional max_bits_of_precision) { + using ComputeT = + std::conditional_t; + CHECK(engine != nullptr); + CHECK_EQ(literal->shape().element_type(), + primitive_util::NativeToPrimitiveType()); + if (max_bits_of_precision.has_value()) { + CHECK(!use_large_range) << "Cannot set both use_large_range and " + "max_bits_of_precision for floating points."; + CHECK(!no_duplicates) << "Cannot set both no_duplicates and " + "max_bits_of_precision for floating points."; + std::uniform_int_distribution generator( + -(1 << *max_bits_of_precision), 1 << *max_bits_of_precision); + for (FloatT& value : literal->data()) { + int64_t temp = generator(*engine); + // We want to generate floating point numbers to a fixed precision, while + // keeping them between -1 and 1. This preserves their bits of precision + // while keeping the numbers small. + value = static_cast(temp * pow(2, -ceil(log2(abs(temp))))); + } + } else if (no_duplicates) { + PopulateWithNoDuplicateData(literal, engine); + } else if (use_large_range) { + PopulateWithRandomFullRangeFloatingPointData(literal, engine); + } else { + PopulateWithRandomFloatingPointData(literal, engine); + } +} + +template +void PopulateWithComplexData(Literal* result, std::minstd_rand0* engine, + bool no_duplicates, bool use_large_range) { + using InnerFloatT = typename ComplexT::value_type; + CHECK(engine != nullptr); + CHECK_EQ(result->shape().element_type(), + primitive_util::NativeToPrimitiveType()); + Shape floating_point_shape = ShapeUtil::ChangeElementType( + result->shape(), primitive_util::NativeToPrimitiveType()); + Literal real_lit(floating_point_shape); + Literal imaginary_lit(floating_point_shape); + + PopulateWithFloatingPointData( + &real_lit, engine, no_duplicates, use_large_range, + /*max_bits_of_precision=*/std::nullopt); + PopulateWithFloatingPointData( + &imaginary_lit, engine, no_duplicates, use_large_range, + /*max_bits_of_precision=*/std::nullopt); + + absl::Span real_data = real_lit.data(); + absl::Span imaginary_data = + imaginary_lit.data(); + absl::Span result_data = result->data(); + for (int i = 0; i < real_lit.data().size(); i++) { + result_data[i] = ComplexT(real_data[i], imaginary_data[i]); + } +} + +// uniform_int_distribution is not defined for 8-bit integers. +// Use 'short' for those types. +template +using RngT = std::conditional_t< + sizeof(IntT) < sizeof(uint16_t), + std::conditional_t::is_signed, int16_t, uint16_t>, + IntT>; +template +void PopulateWithRandomIntegralDataWithBounds(Literal* literal, + std::minstd_rand0* engine, + bool no_duplicates, IntT min, + IntT max) { + CHECK(engine != nullptr); + CHECK_EQ(literal->shape().element_type(), + primitive_util::NativeToPrimitiveType()); + if (no_duplicates && + ShapeUtil::ElementsIn(literal->shape()) < static_cast(max)) { + std::iota(literal->data().begin(), literal->data().end(), + static_cast(0)); + std::shuffle(literal->data().begin(), literal->data().end(), + *engine); + } else { + std::uniform_int_distribution> generator( + static_cast>(min), static_cast>(max)); + for (IntT& value : literal->data()) { + value = static_cast(generator(*engine)); + } + } +} + } // namespace /* static */ Literal LiteralUtil::CreateFromDimensions( @@ -229,6 +407,11 @@ void SetScalarAtIndexImpl(MutableLiteralBase& literal, ShapeUtil::MakeShape(primitive_type, dimensions)); } +/* static */ Literal LiteralUtil::ConvertS8ToF32( + const LiteralSlice& s8_literal) { + return ConvertType(s8_literal); +} + /* static */ Literal LiteralUtil::ConvertBF16ToF32( const LiteralSlice& bf16_literal) { return ConvertType(bf16_literal); @@ -249,6 +432,16 @@ void SetScalarAtIndexImpl(MutableLiteralBase& literal, return ConvertType(f32_literal); } +/* static */ Literal LiteralUtil::ConvertF32ToF8E5M2( + const LiteralSlice& f32_literal) { + return ConvertType(f32_literal); +} + +/* static */ Literal LiteralUtil::ConvertF32ToF8E4M3FN( + const LiteralSlice& f32_literal) { + return ConvertType(f32_literal); +} + /* static */ Literal LiteralUtil::ConvertF32ToBF16( const LiteralSlice& f32_literal) { return ConvertType(f32_literal); @@ -483,4 +676,103 @@ void SetScalarAtIndexImpl(MutableLiteralBase& literal, return l.GetFirstInteger(); } +absl::StatusOr MakeFakeLiteral(const Shape& shape, bool pseudo_random, + bool use_large_range) { + auto engine = pseudo_random ? std::make_unique() : nullptr; + return MakeFakeLiteral(shape, engine.get(), /*limit=*/std::nullopt, + /*is_sorted=*/false, + /*no_duplicates=*/false, use_large_range, + /*max_bits_of_precision=*/std::nullopt); +} + +absl::StatusOr MakeFakeLiteral( + const Shape& shape, std::minstd_rand0* engine, + std::optional> limit, bool is_sorted, + bool no_duplicates, bool use_large_range, + std::optional max_bits_of_precision) { + if (shape.IsTuple()) { + std::vector elements; + const auto& shape_tuple_shapes = shape.tuple_shapes(); + elements.reserve(shape_tuple_shapes.size()); + for (const Shape& element_shape : shape_tuple_shapes) { + TF_ASSIGN_OR_RETURN( + Literal element, + MakeFakeLiteral(element_shape, engine, limit, is_sorted, + no_duplicates, use_large_range, + max_bits_of_precision)); + elements.push_back(std::move(element)); + } + return LiteralUtil::MakeTupleOwned(std::move(elements)); + } + if (engine == nullptr) { + return Literal::CreateFromShape(shape); + } + // Clear tiles/element size in shape's layout before using it for creating + // literal. + Shape new_shape = shape; + new_shape.mutable_layout()->clear_tiles(); + new_shape.mutable_layout()->set_tail_padding_alignment_in_elements(1); + new_shape.mutable_layout()->set_element_size_in_bits(0); + Literal literal(new_shape); + + TF_RETURN_IF_ERROR(primitive_util::PrimitiveTypeSwitch( + [&](auto primitive_type_constant) -> absl::Status { + if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { + using NativeT = primitive_util::NativeTypeOf; + if constexpr (primitive_util::IsFloatingPointType( + primitive_type_constant)) { + PopulateWithFloatingPointData( + &literal, engine, no_duplicates, use_large_range, + max_bits_of_precision); + return absl::OkStatus(); + } + if constexpr (primitive_type_constant == PRED) { + std::uniform_int_distribution generator(0, 1); + TF_CHECK_OK(literal.Populate( + [&](absl::Span /*indices*/) { + return generator(*engine); + })); + return absl::OkStatus(); + } + if constexpr (primitive_util::IsIntegralType( + primitive_type_constant)) { + NativeT max = std::numeric_limits::max(); + NativeT min = std::numeric_limits::lowest(); + if (limit.has_value()) { + max = static_cast(limit->second); + min = static_cast(limit->first); + } + if (max_bits_of_precision.has_value()) { + max = std::min(max, + static_cast(1 << *max_bits_of_precision)); + if (primitive_util::IsSignedIntegralType( + primitive_type_constant)) { + min = std::max( + min, static_cast(-(1 << *max_bits_of_precision))); + } + } + PopulateWithRandomIntegralDataWithBounds( + &literal, engine, /*no_duplicate*/ no_duplicates, min, max); + if (is_sorted) { + std::sort(literal.data().begin(), + literal.data().end()); + } + return absl::OkStatus(); + } + if constexpr (primitive_util::IsComplexType( + primitive_type_constant)) { + PopulateWithComplexData(&literal, engine, no_duplicates, + use_large_range); + return absl::OkStatus(); + } + } + return Unimplemented( + "Unsupported type for fake random literal generation with bounds: " + "%s", + ShapeUtil::HumanString(shape)); + }, + shape.element_type())); + return std::move(literal); +} + } // namespace xla diff --git a/third_party/xla/xla/literal_util.h b/third_party/xla/xla/literal_util.h index a19ed6fb1e529e..01af0cea5499b8 100644 --- a/third_party/xla/xla/literal_util.h +++ b/third_party/xla/xla/literal_util.h @@ -239,10 +239,13 @@ class LiteralUtil { // If the given literal's data type is , converts it to a // literal; otherwise, returns a copy of it. If the literal is a tuple, // recursively converts its elements. + static Literal ConvertS8ToF32(const LiteralSlice& s8_literal); static Literal ConvertBF16ToF32(const LiteralSlice& bf16_literal); static Literal ConvertBF16ToF64(const LiteralSlice& bf16_literal); static Literal ConvertF32ToF8E4M3FNUZ(const LiteralSlice& f32_literal); static Literal ConvertF32ToF8E5M2FNUZ(const LiteralSlice& f32_literal); + static Literal ConvertF32ToF8E5M2(const LiteralSlice& f32_literal); + static Literal ConvertF32ToF8E4M3FN(const LiteralSlice& f32_literal); static Literal ConvertF32ToBF16(const LiteralSlice& f32_literal); static Literal ConvertF32ToS8(const LiteralSlice& f32_literal); static Literal ConvertF32ToF64(const LiteralSlice& f32_literal); @@ -607,6 +610,32 @@ template return CreateRandomLiteral(shape, &engine, mean, stddev); } +// Generates fake data in a literal of the given shape, or returns an error +// status if the element type is currently unhandled for fake data +// generation. See below for documentation of pseudo_random and use_large_range. +absl::StatusOr MakeFakeLiteral(const Shape& shape, + bool pseudo_random = true, + bool use_large_range = false); + +// Similar to MakeFakeLiteral above but takes a random number generator engine +// to enable reusing the engine across randomly generated literals. 'limit' is a +// optional pair that contains the min and the max values to be sample for +// integers (integer format only). 'is_sorted' sorts the sample data for +// integers (integer format only). 'no_duplicates' indicates that there should +// be no duplicate values in each generated array. This is uniqueness is +// best-effort only. Some types (half and bfloat16) are not supported and +// uniqueness cannot be guaranteed if the number of elements exceeds the number +// of different values supported by the type. (floating point format only) +// 'use_large_range' indicates the sampled data is from the full range of the +// floating point format. (floating point format only) +// 'max_bits_of_precision' sets the data to have the given number of bits or +// less (integer or floating point formats only). +absl::StatusOr MakeFakeLiteral( + const Shape& shape, std::minstd_rand0* engine, + std::optional> limit, bool is_sorted, + bool no_duplicates, bool use_large_range, + std::optional max_bits_of_precision); + } // namespace xla #endif // XLA_LITERAL_UTIL_H_ diff --git a/third_party/xla/xla/mlir/framework/ir/xla_framework_ops.td b/third_party/xla/xla/mlir/framework/ir/xla_framework_ops.td index f8015781bed1a8..6a72799112f780 100644 --- a/third_party/xla/xla/mlir/framework/ir/xla_framework_ops.td +++ b/third_party/xla/xla/mlir/framework/ir/xla_framework_ops.td @@ -27,7 +27,7 @@ def XLAFramework_Dialect : Dialect { let summary = "Types and operations for xla_framework dialect"; let description = [{ This dialect contains operations and types that correspond to XLA compiled C - functions in tensorflow/compiler/xla/service/cpu/ir_function.cc. + functions in xla/service/cpu/ir_function.cc. }]; let cppNamespace = "::mlir::xla_framework"; diff --git a/third_party/xla/xla/mlir/framework/tests/BUILD b/third_party/xla/xla/mlir/framework/tests/BUILD index e0311ea4ac362d..f7278a1241eb99 100644 --- a/third_party/xla/xla/mlir/framework/tests/BUILD +++ b/third_party/xla/xla/mlir/framework/tests/BUILD @@ -17,7 +17,7 @@ lit_test_suite( ), cfg = "//xla:lit.cfg.py", tools = [ - "//xla/translate:xla-translate-opt", + "//xla/hlo/translate:xla-translate-opt", "@llvm-project//llvm:FileCheck", ], ) diff --git a/third_party/xla/xla/mlir/tools/mlir_replay/public/BUILD b/third_party/xla/xla/mlir/tools/mlir_replay/public/BUILD index 666fde17ba224b..165d8032b07b73 100644 --- a/third_party/xla/xla/mlir/tools/mlir_replay/public/BUILD +++ b/third_party/xla/xla/mlir/tools/mlir_replay/public/BUILD @@ -39,6 +39,7 @@ cc_library( "//xla/mlir/tools/mlir_interpreter/framework", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", diff --git a/third_party/xla/xla/mlir/tools/mlir_replay/public/compiler_trace_instrumentation.cc b/third_party/xla/xla/mlir/tools/mlir_replay/public/compiler_trace_instrumentation.cc index cf92289ef12188..5a32a6fcbe5e92 100644 --- a/third_party/xla/xla/mlir/tools/mlir_replay/public/compiler_trace_instrumentation.cc +++ b/third_party/xla/xla/mlir/tools/mlir_replay/public/compiler_trace_instrumentation.cc @@ -17,12 +17,12 @@ limitations under the License. #include -#include "absl/strings/str_format.h" +#include "llvm/Support/Casting.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/Pass.h" #include "xla/service/llvm_ir/llvm_util.h" -#include "tsl/platform/env.h" #include "tsl/platform/logging.h" -#include "tsl/platform/path.h" namespace mlir { namespace interpreter { diff --git a/third_party/xla/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc b/third_party/xla/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc index 81173bfd18456a..b141ccaa88ded7 100644 --- a/third_party/xla/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc +++ b/third_party/xla/xla/mlir/tools/mlir_replay/public/execution_trace_utils.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" @@ -43,6 +44,7 @@ limitations under the License. #include "xla/mlir/tools/mlir_interpreter/framework/tensor_or_memref.h" #include "xla/mlir/tools/mlir_replay/public/execution_trace.pb.h" #include "xla/primitive_util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" namespace mlir { diff --git a/third_party/xla/xla/mlir/tools/mlir_replay/public/execution_trace_utils.h b/third_party/xla/xla/mlir/tools/mlir_replay/public/execution_trace_utils.h index 2148a982c0f5a4..c272900bd6a2b6 100644 --- a/third_party/xla/xla/mlir/tools/mlir_replay/public/execution_trace_utils.h +++ b/third_party/xla/xla/mlir/tools/mlir_replay/public/execution_trace_utils.h @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/IR/Attributes.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Region.h" +#include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" #include "xla/literal.h" #include "xla/mlir/tools/mlir_interpreter/framework/interpreter.h" diff --git a/third_party/xla/xla/mlir/utils/type_util.cc b/third_party/xla/xla/mlir/utils/type_util.cc index 59b19c34611412..2581390a1e13d7 100644 --- a/third_party/xla/xla/mlir/utils/type_util.cc +++ b/third_party/xla/xla/mlir/utils/type_util.cc @@ -34,6 +34,8 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( return b.getI1Type(); case xla::PrimitiveType::F8E5M2: return b.getFloat8E5M2Type(); + case xla::PrimitiveType::F8E4M3: + return b.getFloat8E4M3Type(); case xla::PrimitiveType::F8E4M3FN: return b.getFloat8E4M3FNType(); case xla::PrimitiveType::F8E4M3B11FNUZ: @@ -42,6 +44,8 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( return b.getFloat8E5M2FNUZType(); case xla::PrimitiveType::F8E4M3FNUZ: return b.getFloat8E4M3FNUZType(); + case xla::PrimitiveType::F8E3M4: + return b.getFloat8E3M4Type(); case xla::PrimitiveType::F16: return b.getF16Type(); case xla::PrimitiveType::BF16: @@ -76,6 +80,8 @@ absl::StatusOr ConvertPrimitiveTypeToMlirType( xla::PrimitiveType ConvertMlirTypeToPrimitiveType(mlir::Type type) { if (type.isFloat8E5M2()) { return xla::PrimitiveType::F8E5M2; + } else if (type.isFloat8E4M3()) { + return xla::PrimitiveType::F8E4M3; } else if (type.isFloat8E4M3FN()) { return xla::PrimitiveType::F8E4M3FN; } else if (type.isFloat8E4M3B11FNUZ()) { @@ -84,6 +90,8 @@ xla::PrimitiveType ConvertMlirTypeToPrimitiveType(mlir::Type type) { return xla::PrimitiveType::F8E4M3FNUZ; } else if (type.isFloat8E5M2FNUZ()) { return xla::PrimitiveType::F8E5M2FNUZ; + } else if (type.isFloat8E3M4()) { + return xla::PrimitiveType::F8E3M4; } else if (type.isBF16()) { return xla::PrimitiveType::BF16; } else if (type.isF16()) { diff --git a/third_party/xla/xla/mlir/utils/type_util_test.cc b/third_party/xla/xla/mlir/utils/type_util_test.cc index 6c19098574dec5..a8043ab0b5f140 100644 --- a/third_party/xla/xla/mlir/utils/type_util_test.cc +++ b/third_party/xla/xla/mlir/utils/type_util_test.cc @@ -102,6 +102,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(std::vector( {{PRED, [](mlir::Builder b) { return b.getI1Type(); }}, {F8E5M2, [](mlir::Builder b) { return b.getFloat8E5M2Type(); }}, + {F8E4M3, [](mlir::Builder b) { return b.getFloat8E4M3Type(); }}, {F8E4M3FN, [](mlir::Builder b) { return b.getFloat8E4M3FNType(); }}, {F8E4M3B11FNUZ, [](mlir::Builder b) { return b.getFloat8E4M3B11FNUZType(); }}, @@ -109,6 +110,7 @@ INSTANTIATE_TEST_SUITE_P( [](mlir::Builder b) { return b.getFloat8E5M2FNUZType(); }}, {F8E4M3FNUZ, [](mlir::Builder b) { return b.getFloat8E4M3FNUZType(); }}, + {F8E3M4, [](mlir::Builder b) { return b.getFloat8E3M4Type(); }}, {F16, [](mlir::Builder b) { return b.getF16Type(); }}, {BF16, [](mlir::Builder b) { return b.getBF16Type(); }}, {F32, [](mlir::Builder b) { return b.getF32Type(); }}, diff --git a/third_party/xla/xla/mlir_hlo/BUILD b/third_party/xla/xla/mlir_hlo/BUILD index 11764a0594c9f4..b3f1106472299b 100644 --- a/third_party/xla/xla/mlir_hlo/BUILD +++ b/third_party/xla/xla/mlir_hlo/BUILD @@ -497,6 +497,7 @@ cc_library( strip_include_prefix = ".", deps = [ ":mlir_hlo", + ":transformation_helpers", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ComplexDialect", @@ -507,6 +508,18 @@ cc_library( ], ) +cc_library( + name = "transformation_helpers", + hdrs = ["mhlo/transforms/transformation_helpers.h"], + strip_include_prefix = ".", + deps = [ + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ArithUtils", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "map_chlo_to_hlo_op", hdrs = ["mhlo/transforms/map_chlo_to_hlo_op.h"], @@ -803,6 +816,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncToLLVM", "@llvm-project//mlir:FuncTransforms", + "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMCommonConversion", diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc index dbd006df877c8e..ff087387bff590 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc @@ -58,7 +58,6 @@ limitations under the License. #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Quant/QuantTypes.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.h b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.h index 543c47c27a992f..843e3919e330d3 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.h +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.h @@ -19,7 +19,6 @@ limitations under the License. #define MLIR_HLO_MHLO_IR_HLO_OPS_H #include "llvm/ADT/StringRef.h" -#include "mlir/Dialect/Quant/QuantTypes.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_arithmetic/hlo_legalize_to_arithmetic.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_arithmetic/hlo_legalize_to_arithmetic.cc index a3a9f6664f9544..9c03b387174730 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_arithmetic/hlo_legalize_to_arithmetic.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_arithmetic/hlo_legalize_to_arithmetic.cc @@ -144,7 +144,8 @@ struct ScalarHloToArithmeticPattern : public OpConversionPattern { rewriter.create(loc, operand, ValueRange())); } Value scalarResult = mhlo::MhloOpToStdScalarOp::mapOp( - op, resultTy->getElementType(), operands, &rewriter); + op, resultTy->getElementType(), operands, /*attributes=*/std::nullopt, + &rewriter); if (!scalarResult) return failure(); rewriter.replaceOpWithNewOp(op, *resultTy, scalarResult); diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_shape_computations/legalize_shape_computations.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_shape_computations/legalize_shape_computations.cc index 935465d2dabaa6..b4fb6ed509994e 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_shape_computations/legalize_shape_computations.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_shape_computations/legalize_shape_computations.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -103,7 +104,8 @@ class MhloElementwiseConverter : public OpRewritePattern { } Value scalarOp = mhlo::MhloOpToStdScalarOp::mapOp( - op, resultTy.getElementType(), extracts, &rewriter); + op, resultTy.getElementType(), extracts, /*attributes=*/std::nullopt, + &rewriter); operands.push_back(scalarOp); } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc index a12af2369bf82d..9b2ae10e951181 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc @@ -612,7 +612,7 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern { auto rhs = rewriter.create(loc, mhloOp.rhs()); Value opResult = mhlo::MhloOpToStdScalarOp::mapOp( mhloOp, argType.getElementType(), llvm::ArrayRef{lhs, rhs}, - &rewriter); + /*attributes=*/std::nullopt, &rewriter); rewriter.create(loc, opResult, mhloOp.out()); rewriter.eraseOp(mhloOp); return success(); @@ -1512,7 +1512,7 @@ class IotaConverter : public OpConversionPattern { indexOp); castOp = mhlo::MhloOpToStdScalarOp::mapConvertOpToStdScalarOp( nestedLoc, targetElementType, resultElementType, castOp.getType(), - {castOp}, &nestedBuilder); + {castOp}, /*attributes=*/std::nullopt, &nestedBuilder); nestedBuilder.create(nestedLoc, castOp); }, linalg::getPrunedAttributeList(iotaOp)); @@ -1548,7 +1548,8 @@ class IotaToMapConverter : public OpConversionPattern { nestedLoc, nestedBuilder.getI64Type(), index); Value result = mhlo::MhloOpToStdScalarOp::mapConvertOpToStdScalarOp( nestedLoc, targetElementType, resultTy.getElementType(), - index.getType(), {ValueRange{index}}, &nestedBuilder); + index.getType(), {ValueRange{index}}, /*attributes=*/std::nullopt, + &nestedBuilder); nestedBuilder.create(nestedLoc, ValueRange{result}); }, linalg::getPrunedAttributeList(iotaOp)); @@ -4369,7 +4370,8 @@ class PointwiseToLinalgMapConverter : public OpConversionPattern { [&](OpBuilder& b, Location loc, ValueRange args) { Value innerResult = mhlo::MhloOpToStdScalarOp::mapOp( op, getElementTypeOrSelf(emptyTensor), - interleaveScalarAndBlockArgs(scalarInputs, args), &b); + interleaveScalarAndBlockArgs(scalarInputs, args), + /*attributes=*/std::nullopt, &b); b.create(loc, innerResult); }, linalg::getPrunedAttributeList(op)); diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h b/third_party/xla/xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h index c3b7f7103841f1..7cd367297ccd62 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h @@ -16,12 +16,14 @@ limitations under the License. #ifndef MLIR_HLO_MHLO_TRANSFORMS_MAP_MHLO_TO_SCALAR_OP_H #define MLIR_HLO_MHLO_TRANSFORMS_MAP_MHLO_TO_SCALAR_OP_H +#include #include #include #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "mhlo/IR/hlo_ops.h" +#include "mhlo/transforms/transformation_helpers.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Math/IR/Math.h" @@ -194,7 +196,7 @@ template struct MapMhloOpToScalarOpImpl { Value operator()(Location /*loc*/, ArrayRef /*ResultTypes*/, ArrayRef /*argTypes*/, ValueRange /*args*/, - OpBuilder* /*b*/) { + ArrayRef /*attributes*/, OpBuilder* /*b*/) { return nullptr; } }; @@ -202,32 +204,34 @@ struct MapMhloOpToScalarOpImpl { template struct MapMhloOpToScalarOpImpl { Value operator()(Location loc, ArrayRef resultTypes, - ArrayRef /*argTypes*/, ValueRange args, OpBuilder* b) { - return b->template create(loc, resultTypes, args, - std::nullopt); + ArrayRef /*argTypes*/, ValueRange args, + ArrayRef attributes, OpBuilder* b) { + return b->template create(loc, resultTypes, args, attributes); } }; template struct MapMhloOpToScalarOpImpl { Value operator()(Location loc, ArrayRef resultTypes, - ArrayRef argTypes, ValueRange args, OpBuilder* b) { + ArrayRef argTypes, ValueRange args, + ArrayRef attributes, OpBuilder* b) { Type elementType = getElementTypeOrSelf(argTypes.front()); if (SupportedType{}(elementType)) { return b->template create(loc, resultTypes, args, - std::nullopt); + attributes); } return MapMhloOpToScalarOpImpl{}(loc, resultTypes, argTypes, args, - b); + attributes, b); } }; template struct MapMhloOpToScalarOpImpl { Value operator()(Location loc, ArrayRef resultTypes, - ArrayRef argTypes, ValueRange args, OpBuilder* b) { + ArrayRef argTypes, ValueRange args, + ArrayRef attributes, OpBuilder* b) { return MapMhloOpToScalarOpImpl{}(loc, resultTypes, argTypes, args, - b); + attributes, b); } }; @@ -273,6 +277,7 @@ template inline Value mapMhloOpToStdScalarOp(Location loc, ArrayRef resultTypes, ArrayRef argTypes, typename MhloOpTy::Adaptor adaptor, + ArrayRef attributes, OpBuilder* b) { using ScalarIOpOrVoid = typename MapableIf::type; using ScalarUOpOrVoid = typename MapableIf::type; @@ -281,24 +286,23 @@ inline Value mapMhloOpToStdScalarOp(Location loc, ArrayRef resultTypes, return MapMhloOpToScalarOpImpl{}(loc, resultTypes, argTypes, - adaptor.getOperands(), b); + ScalarCOpOrVoid>{}( + loc, resultTypes, argTypes, adaptor.getOperands(), attributes, b); } template <> -inline Value mapMhloOpToStdScalarOp(Location loc, - ArrayRef resultTypes, - ArrayRef argTypes, - mhlo::AbsOp::Adaptor adaptor, - OpBuilder* b) { +inline Value mapMhloOpToStdScalarOp( + Location loc, ArrayRef resultTypes, ArrayRef argTypes, + mhlo::AbsOp::Adaptor adaptor, ArrayRef attributes, + OpBuilder* b) { Type elementType = getElementTypeOrSelf(argTypes.front()); if (mlir::isa(elementType)) { return MapMhloOpToScalarOpImpl{}( - loc, resultTypes, argTypes, adaptor.getOperands(), b); + loc, resultTypes, argTypes, adaptor.getOperands(), attributes, b); } if (mlir::isa(elementType)) { return MapMhloOpToScalarOpImpl{}( - loc, resultTypes, argTypes, adaptor.getOperands(), b); + loc, resultTypes, argTypes, adaptor.getOperands(), attributes, b); } if (elementType.isSignlessInteger() || elementType.isSignedInteger()) { // lmhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x)) @@ -405,7 +409,8 @@ inline Value cmpComplex(Location loc, Value lhs, Value rhs, template <> inline Value mapMhloOpToStdScalarOp( Location loc, ArrayRef /*resultTypes*/, ArrayRef argTypes, - mhlo::CompareOp::Adaptor adaptor, OpBuilder* b) { + mhlo::CompareOp::Adaptor adaptor, ArrayRef /*attributes*/, + OpBuilder* b) { ComparisonDirection comparisonDirection = adaptor.getComparisonDirection(); const auto& lhs = adaptor.getLhs(); const auto& rhs = adaptor.getRhs(); @@ -467,143 +472,62 @@ inline Value mapMhloOpToStdScalarOp( return nullptr; } -template <> -inline Value mapMhloOpToStdScalarOp( - Location loc, ArrayRef /*resultTypes*/, ArrayRef argTypes, - mhlo::ReducePrecisionOp::Adaptor adaptor, OpBuilder* builder) { - using llvm::APInt; - mlir::ImplicitLocOpBuilder b(loc, *builder); - - // Integer and float types for casting and constant generation. - auto floatType = - mlir::cast(getElementTypeOrSelf(argTypes.front())); - int64_t nbits = floatType.getWidth(); - auto intType = mlir::IntegerType::get(loc.getContext(), nbits); - - Value xAsInt = b.create(intType, adaptor.getOperand()); - - // SignificandWidth includes the implicit extra bit. - auto srcMantissaBits = floatType.getFPMantissaWidth() - 1; - int srcExponentBits = nbits - 1 - srcMantissaBits; - - // Clear the sign bit, it does not participate in rounding and we will restore - // it later. - APInt signBitMask(nbits, 1); - signBitMask <<= nbits - 1; - - APInt expBitsMask(nbits, 1); - expBitsMask = ((expBitsMask << srcExponentBits) - 1) << srcMantissaBits; - - auto createConstant = [&](const APInt& v) { - return b.create(v.getZExtValue(), intType) - .getResult(); - }; - - Value xAbsBits = - b.create(xAsInt, createConstant(~signBitMask)); - Value xIsNan = b.create(arith::CmpIPredicate::ugt, xAbsBits, - createConstant(expBitsMask)); - - int destMantissaBits = adaptor.getMantissaBits(); - if (destMantissaBits < static_cast(srcMantissaBits)) { - // Last remaining mantissa bit. - APInt lastMantissaBitMask(nbits, 1); - lastMantissaBitMask <<= srcMantissaBits - destMantissaBits; - - // Compute rounding bias for round-to-nearest with ties to even. This is - // equal to a base value of 0111... plus one bit if the last remaining - // mantissa bit is 1. - APInt baseRoundingBias = lastMantissaBitMask.lshr(1) - 1; - - Value mantissaDiff = b.create( - srcMantissaBits - destMantissaBits, intType); - Value highestMantissaMaskVal = createConstant(lastMantissaBitMask); - Value baseRoundingBiasVal = createConstant(baseRoundingBias); - Value xLastMantissaBit = b.create( - b.create(xAsInt, highestMantissaMaskVal), mantissaDiff); - Value xRoundingBias = - b.create(xLastMantissaBit, baseRoundingBiasVal); - - // Add rounding bias, and mask out truncated bits. Note that the case - // where adding the rounding bias overflows into the exponent bits is - // correct; the non-masked mantissa bits will all be zero, and the - // exponent will be incremented by one. - APInt truncationMask = ~(lastMantissaBitMask - 1); - Value xRounded = b.create(xAsInt, xRoundingBias); - xAsInt = b.create(xRounded, createConstant(truncationMask)); +static bool HasDefaultMantissaBits(Type type, uint32_t mantissa_bits) { + if (auto float_ty = mlir::dyn_cast(type)) { + return float_ty.getFPMantissaWidth() == mantissa_bits; } + return false; +} - int destExponentBits = adaptor.getExponentBits(); - if (destExponentBits < srcExponentBits) { - // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most- - // significant bit -- is equal to 1.0f for all exponent sizes. Adding - // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit- - // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest' - // exponent (corresponding to 0.0f). - // - // Thus, the f32 exponent corresponding to the highest non-infinite - // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 - // exponent corresponding to the lowest exponent for a bit size of n is - // (2^7-1) - 2^(n-1)-1. - // - // Note that we have already checked that exponents_bits >= 1. - APInt exponentBias(nbits, 1); - exponentBias = (exponentBias << (srcExponentBits - 1)) - 1; - - APInt reducedExponentBias(nbits, 1); - reducedExponentBias = (reducedExponentBias << (destExponentBits - 1)) - 1; - - APInt reducedMaxExponent = exponentBias + reducedExponentBias; - APInt reducedMinExponent = exponentBias - reducedExponentBias; - - // Do we overflow or underflow? - Value xExponent = - b.create(xAsInt, createConstant(expBitsMask)); - Value xOverflows = b.create( - arith::CmpIPredicate::ugt, xExponent, - createConstant(reducedMaxExponent << srcMantissaBits)); - Value xUnderflows = b.create( - arith::CmpIPredicate::ule, xExponent, - createConstant(reducedMinExponent << srcMantissaBits)); - - // Compute appropriately-signed values of zero and infinity. - Value xSignedZero = - b.create(xAsInt, createConstant(signBitMask)); - Value xSignedInf = - b.create(xSignedZero, createConstant(expBitsMask)); - - // Force to zero or infinity if overflow or underflow. (Note that this - // truncates all denormal values to zero, rather than rounding them.) - xAsInt = b.create(xOverflows, xSignedInf, xAsInt); - xAsInt = b.create(xUnderflows, xSignedZero, xAsInt); +static bool HasDefaultExponentBits(Type type, uint32_t exponent_bits) { + if (auto float_ty = mlir::dyn_cast(type)) { + return float_ty.getWidth() - float_ty.getFPMantissaWidth() - 1 == + exponent_bits; } + return false; +} - Value result = b.create(floatType, xAsInt); - return b.create(xIsNan, adaptor.getOperand(), result); +template <> +inline Value mapMhloOpToStdScalarOp( + Location loc, ArrayRef resultTypes, ArrayRef /*argTypes*/, + mhlo::ReducePrecisionOp::Adaptor adaptor, + ArrayRef /*attributes*/, OpBuilder* builder) { + // TODO(b/373787166): This should actually be a folder, but JAX is adding + // no-op ReducePrecision ops to workaround an issue with some simplifications + // allowed with the xla_allow_excess_precision flag. We would already fold + // these ops away before they reach HLO. Folding them away at emission time + // keeps the workaround intact. + if (HasDefaultExponentBits(resultTypes[0], adaptor.getExponentBits()) && + HasDefaultMantissaBits(resultTypes[0], adaptor.getMantissaBits())) { + return adaptor.getOperand(); + } + return reducePrecision(loc, adaptor.getOperand(), + adaptor.getExponentBits(), + adaptor.getMantissaBits(), builder); } template <> inline Value mapMhloOpToStdScalarOp( Location /*loc*/, ArrayRef /*ResultTypes*/, ArrayRef /*argTypes*/, mhlo::CopyOp::Adaptor adaptor, - OpBuilder* /*b*/) { + ArrayRef /*attributes*/, OpBuilder* /*b*/) { return adaptor.getOperands().front(); } template <> inline Value mapMhloOpToStdScalarOp( Location loc, ArrayRef resultTypes, ArrayRef argTypes, - mhlo::ComplexOp::Adaptor adaptor, OpBuilder* b) { + mhlo::ComplexOp::Adaptor adaptor, ArrayRef attributes, + OpBuilder* b) { return MapMhloOpToScalarOpImpl{}( - loc, resultTypes, argTypes, adaptor.getOperands(), b); + loc, resultTypes, argTypes, adaptor.getOperands(), attributes, b); } template <> -inline Value mapMhloOpToStdScalarOp(Location loc, - ArrayRef resultTypes, - ArrayRef argTypes, - mhlo::MaxOp::Adaptor adaptor, - OpBuilder* b) { +inline Value mapMhloOpToStdScalarOp( + Location loc, ArrayRef resultTypes, ArrayRef argTypes, + mhlo::MaxOp::Adaptor adaptor, ArrayRef attributes, + OpBuilder* b) { ValueRange operands = adaptor.getOperands(); Value lhs = operands.front(); Type complexTy = lhs.getType(); @@ -612,7 +536,7 @@ inline Value mapMhloOpToStdScalarOp(Location loc, return MapMhloOpToScalarOpImpl{}( - loc, resultTypes, argTypes, adaptor.getOperands(), b); + loc, resultTypes, argTypes, adaptor.getOperands(), attributes, b); assert(resultTypes.size() == 1 && "MaxOp should return a single result"); assert(operands.size() == 2 && "MaxOp should take exactly two arguments"); @@ -625,11 +549,10 @@ inline Value mapMhloOpToStdScalarOp(Location loc, } template <> -inline Value mapMhloOpToStdScalarOp(Location loc, - ArrayRef resultTypes, - ArrayRef argTypes, - mhlo::MinOp::Adaptor adaptor, - OpBuilder* b) { +inline Value mapMhloOpToStdScalarOp( + Location loc, ArrayRef resultTypes, ArrayRef argTypes, + mhlo::MinOp::Adaptor adaptor, ArrayRef attributes, + OpBuilder* b) { ValueRange operands = adaptor.getOperands(); Value lhs = operands.front(); Type complexTy = lhs.getType(); @@ -638,7 +561,7 @@ inline Value mapMhloOpToStdScalarOp(Location loc, return MapMhloOpToScalarOpImpl{}( - loc, resultTypes, argTypes, adaptor.getOperands(), b); + loc, resultTypes, argTypes, adaptor.getOperands(), attributes, b); assert(resultTypes.size() == 1 && "MinOp should return a single result"); assert(operands.size() == 2 && "MinOp should take exactly two arguments"); @@ -651,28 +574,26 @@ inline Value mapMhloOpToStdScalarOp(Location loc, } template <> -inline Value mapMhloOpToStdScalarOp(Location loc, - ArrayRef resultTypes, - ArrayRef argTypes, - mhlo::RealOp::Adaptor adaptor, - OpBuilder* b) { +inline Value mapMhloOpToStdScalarOp( + Location loc, ArrayRef resultTypes, ArrayRef argTypes, + mhlo::RealOp::Adaptor adaptor, ArrayRef attributes, + OpBuilder* b) { if (!mlir::isa(adaptor.getOperand().getType())) return adaptor.getOperand(); - return MapMhloOpToScalarOpImpl{}(loc, resultTypes, argTypes, - adaptor.getOperands(), b); + return MapMhloOpToScalarOpImpl{}( + loc, resultTypes, argTypes, adaptor.getOperands(), attributes, b); } template <> -inline Value mapMhloOpToStdScalarOp(Location loc, - ArrayRef resultTypes, - ArrayRef argTypes, - mhlo::ImagOp::Adaptor adaptor, - OpBuilder* b) { +inline Value mapMhloOpToStdScalarOp( + Location loc, ArrayRef resultTypes, ArrayRef argTypes, + mhlo::ImagOp::Adaptor adaptor, ArrayRef attributes, + OpBuilder* b) { if (!mlir::isa(adaptor.getOperand().getType())) return b->create( loc, b->getZeroAttr(adaptor.getOperand().getType())); - return MapMhloOpToScalarOpImpl{}(loc, resultTypes, argTypes, - adaptor.getOperands(), b); + return MapMhloOpToScalarOpImpl{}( + loc, resultTypes, argTypes, adaptor.getOperands(), attributes, b); } // 'target_types' is the unconverted type (signed or unsigned if integer), @@ -680,6 +601,7 @@ inline Value mapMhloOpToStdScalarOp(Location loc, inline Value mapConvertOpToStdScalarOp(Location loc, ArrayRef targetTypes, ArrayRef resultTypes, ArrayRef argTypes, ValueRange args, + ArrayRef attributes, OpBuilder* b) { assert(targetTypes.size() == 1 && "ConvertOp should return a single result"); assert(resultTypes.size() == 1 && "ConvertOp should return a single result"); @@ -695,12 +617,10 @@ inline Value mapConvertOpToStdScalarOp(Location loc, ArrayRef targetTypes, if (IsUnsignedIntegerType{}(sourceType) && mlir::arith::UIToFPOp::areCastCompatible(convertedSourceType, targetType)) { - return b->create(loc, resultTypes, args, - std::nullopt); + return b->create(loc, resultTypes, args, attributes); } if (mlir::arith::SIToFPOp::areCastCompatible(sourceType, targetType)) { - return b->create(loc, resultTypes, args, - std::nullopt); + return b->create(loc, resultTypes, args, attributes); } if (mlir::isa(sourceType) && mlir::isa(targetType)) { if (sourceType == targetType) { @@ -719,9 +639,9 @@ inline Value mapConvertOpToStdScalarOp(Location loc, ArrayRef targetTypes, if (sourceType.getIntOrFloatBitWidth() > dst.getWidth()) { return b->create(loc, resultTypes, src, - std::nullopt); + attributes); } - return b->create(loc, resultTypes, src, std::nullopt); + return b->create(loc, resultTypes, src, attributes); } if (targetType.isInteger(/*width=*/1)) { // When casting to bool, we need to compare whether the value is equal to @@ -745,16 +665,16 @@ inline Value mapConvertOpToStdScalarOp(Location loc, ArrayRef targetTypes, auto res = mlir::cast(targetType); if (src.getWidth() > res.getWidth()) { return b->create(loc, resultTypes, args, - std::nullopt); + attributes); } if (src.getWidth() < res.getWidth()) { // Special case boolean values, so they get casted to `1` instead of `-1`. if (IsUnsignedIntegerType{}(src)) { return b->create(loc, resultTypes, args, - std::nullopt); + attributes); } return b->create(loc, resultTypes, args, - std::nullopt); + attributes); } // No conversion is needed for the same width integers return args.front(); @@ -762,13 +682,11 @@ inline Value mapConvertOpToStdScalarOp(Location loc, ArrayRef targetTypes, if (targetType.isUnsignedInteger() && mlir::arith::FPToUIOp::areCastCompatible(convertedSourceType, targetType)) { - return b->create(loc, resultTypes, args, - std::nullopt); + return b->create(loc, resultTypes, args, attributes); } if (mlir::arith::FPToSIOp::areCastCompatible(convertedSourceType, targetType)) { - return b->create(loc, resultTypes, args, - std::nullopt); + return b->create(loc, resultTypes, args, attributes); } if (mlir::isa(targetType)) { Type targetElementType = @@ -786,19 +704,20 @@ inline Value mapConvertOpToStdScalarOp(Location loc, ArrayRef targetTypes, "elements of complex numbers should not be complex"); Value sourceReal = b->create(loc, sourceElementType, args.front()); - targetReal = - mapConvertOpToStdScalarOp(loc, targetElementType, targetElementType, - sourceElementType, sourceReal, b); + targetReal = mapConvertOpToStdScalarOp( + loc, targetElementType, targetElementType, sourceElementType, + sourceReal, attributes, b); Value sourceImag = b->create(loc, sourceElementType, args.front()); - targetImag = - mapConvertOpToStdScalarOp(loc, targetElementType, targetElementType, - sourceElementType, sourceImag, b); + targetImag = mapConvertOpToStdScalarOp( + loc, targetElementType, targetElementType, sourceElementType, + sourceImag, attributes, b); } else { // We are converting from real (float, integer, etc.) type, convert the // real part and set the imaginary part to 0. - targetReal = mapConvertOpToStdScalarOp( - loc, targetElementType, targetElementType, argTypes, args, b); + targetReal = + mapConvertOpToStdScalarOp(loc, targetElementType, targetElementType, + argTypes, args, attributes, b); targetImag = b->create( loc, b->getFloatAttr(targetElementType, 0.0)); } @@ -812,7 +731,8 @@ inline Value mapConvertOpToStdScalarOp(Location loc, ArrayRef targetTypes, Value sourceReal = b->create(loc, sourceElementType, args.front()); return mapConvertOpToStdScalarOp(loc, targetTypes, resultTypes, - sourceElementType, sourceReal, b); + sourceElementType, sourceReal, attributes, + b); } return nullptr; } @@ -823,7 +743,8 @@ inline Value mapConvertOpToStdScalarOp(Location loc, ArrayRef targetTypes, template <> inline Value mapMhloOpToStdScalarOp( Location loc, ArrayRef resultTypes, ArrayRef argTypes, - mhlo::BitcastConvertOp::Adaptor adaptor, OpBuilder* b) { + mhlo::BitcastConvertOp::Adaptor adaptor, + ArrayRef attributes, OpBuilder* b) { Type argType = getElementTypeOrSelf(argTypes.front()); Type resultType = getElementTypeOrSelf(resultTypes.front()); @@ -831,15 +752,14 @@ inline Value mapMhloOpToStdScalarOp( return nullptr; return b->create(loc, resultTypes, - adaptor.getOperands()); + adaptor.getOperands(), attributes); } template <> -inline Value mapMhloOpToStdScalarOp(Location loc, - ArrayRef resultTypes, - ArrayRef argTypes, - mhlo::DotOp::Adaptor adaptor, - OpBuilder* b) { +inline Value mapMhloOpToStdScalarOp( + Location loc, ArrayRef resultTypes, ArrayRef argTypes, + mhlo::DotOp::Adaptor adaptor, ArrayRef attributes, + OpBuilder* b) { // Dot Op converter from lhlo to affine only accepts float and integer types. const auto& lhs = adaptor.getOperands()[0]; const auto& rhs = adaptor.getOperands()[1]; @@ -848,16 +768,16 @@ inline Value mapMhloOpToStdScalarOp(Location loc, if (mlir::isa(elementType)) { Value floatMul = MapMhloOpToScalarOpImpl{}( - loc, resultTypes, argTypes, {lhs, rhs}, b); + loc, resultTypes, argTypes, {lhs, rhs}, attributes, b); return MapMhloOpToScalarOpImpl{}( - loc, resultTypes, argTypes, {floatMul, result}, b); + loc, resultTypes, argTypes, {floatMul, result}, attributes, b); } if (mlir::isa(elementType)) { Value intMul = MapMhloOpToScalarOpImpl{}( - loc, resultTypes, argTypes, {lhs, rhs}, b); + loc, resultTypes, argTypes, {lhs, rhs}, attributes, b); return MapMhloOpToScalarOpImpl{}( - loc, resultTypes, argTypes, {intMul, result}, b); + loc, resultTypes, argTypes, {intMul, result}, attributes, b); } return nullptr; } @@ -865,7 +785,8 @@ inline Value mapMhloOpToStdScalarOp(Location loc, template <> inline Value mapMhloOpToStdScalarOp( Location loc, ArrayRef /*ResultTypes*/, ArrayRef /*argTypes*/, - mhlo::IsFiniteOp::Adaptor adaptor, OpBuilder* b) { + mhlo::IsFiniteOp::Adaptor adaptor, ArrayRef /*attributes*/, + OpBuilder* b) { if (mlir::isa(adaptor.getX().getType())) { auto posInf = APFloat::getInf( mlir::cast(adaptor.getX().getType()).getFloatSemantics()); @@ -929,16 +850,17 @@ inline Value mhloAlwaysPropagateNaN(Value v, ValueRange args, Location loc, } template <> -inline Value mapMhloOpToStdScalarOp(Location loc, - ArrayRef resultTypes, - ArrayRef argTypes, - mhlo::ClampOp::Adaptor op, - OpBuilder* b) { +inline Value mapMhloOpToStdScalarOp( + Location loc, ArrayRef resultTypes, ArrayRef argTypes, + mhlo::ClampOp::Adaptor op, ArrayRef attributes, + OpBuilder* b) { // clamp(lb, x, ub) = min(max(lb, x), ub) Value maxLbX = mapMhloOpToStdScalarOp( - loc, resultTypes, argTypes, ValueRange{op.getMin(), op.getOperand()}, b); - return mapMhloOpToStdScalarOp( - loc, resultTypes, argTypes, ValueRange{maxLbX, op.getMax()}, b); + loc, resultTypes, argTypes, ValueRange{op.getMin(), op.getOperand()}, + attributes, b); + return mapMhloOpToStdScalarOp(loc, resultTypes, argTypes, + ValueRange{maxLbX, op.getMax()}, + attributes, b); } template @@ -980,16 +902,15 @@ inline Value makeSafeIntDiv(ImplicitLocOpBuilder& lb, Type originalType, } template <> -inline Value mapMhloOpToStdScalarOp(Location loc, - ArrayRef resultTypes, - ArrayRef argTypes, - mhlo::DivOp::Adaptor adaptor, - OpBuilder* b) { +inline Value mapMhloOpToStdScalarOp( + Location loc, ArrayRef resultTypes, ArrayRef argTypes, + mhlo::DivOp::Adaptor adaptor, ArrayRef attributes, + OpBuilder* b) { Type originalType = getElementTypeOrSelf(argTypes.front()); if (mlir::isa(originalType)) { return MapMhloOpToScalarOpImpl{}(loc, resultTypes, argTypes, - adaptor.getOperands(), b); + complex::DivOp>{}( + loc, resultTypes, argTypes, adaptor.getOperands(), attributes, b); } // Integer division overflow behavior: @@ -1012,15 +933,14 @@ inline Value mapMhloOpToStdScalarOp(Location loc, } template <> -inline Value mapMhloOpToStdScalarOp(Location loc, - ArrayRef resultTypes, - ArrayRef argTypes, - mhlo::RemOp::Adaptor adaptor, - OpBuilder* b) { +inline Value mapMhloOpToStdScalarOp( + Location loc, ArrayRef resultTypes, ArrayRef argTypes, + mhlo::RemOp::Adaptor adaptor, ArrayRef attributes, + OpBuilder* b) { Type originalType = getElementTypeOrSelf(argTypes.front()); if (mlir::isa(originalType)) { return MapMhloOpToScalarOpImpl{}( - loc, resultTypes, argTypes, adaptor.getOperands(), b); + loc, resultTypes, argTypes, adaptor.getOperands(), attributes, b); } // Integer remainder overflow behavior: @@ -1037,16 +957,15 @@ inline Value mapMhloOpToStdScalarOp(Location loc, } template <> -inline Value mapMhloOpToStdScalarOp(Location loc, - ArrayRef resultTypes, - ArrayRef argTypes, - mhlo::NegOp::Adaptor adaptor, - OpBuilder* b) { +inline Value mapMhloOpToStdScalarOp( + Location loc, ArrayRef resultTypes, ArrayRef argTypes, + mhlo::NegOp::Adaptor adaptor, ArrayRef attributes, + OpBuilder* b) { Type elementType = getElementTypeOrSelf(adaptor.getOperand().getType()); if (mlir::isa(elementType)) { return MapMhloOpToScalarOpImpl{}( - loc, resultTypes, argTypes, adaptor.getOperands(), b); + loc, resultTypes, argTypes, adaptor.getOperands(), attributes, b); } if (mlir::isa(elementType)) { // lmhlo.neg(x, result) -> result = sub(0, x) @@ -1059,11 +978,10 @@ inline Value mapMhloOpToStdScalarOp(Location loc, } template <> -inline Value mapMhloOpToStdScalarOp(Location loc, - ArrayRef /*ResultTypes*/, - ArrayRef /*argTypes*/, - mhlo::NotOp::Adaptor adaptor, - OpBuilder* b) { +inline Value mapMhloOpToStdScalarOp( + Location loc, ArrayRef /*ResultTypes*/, ArrayRef /*argTypes*/, + mhlo::NotOp::Adaptor adaptor, ArrayRef /*attributes*/, + OpBuilder* b) { Type elementType = getElementTypeOrSelf(adaptor.getOperand().getType()); if (auto integerType = mlir::dyn_cast(elementType)) { // lmhlo.not(x) -> x ^ -1 @@ -1079,12 +997,13 @@ inline Value mapMhloOpToStdScalarOp(Location loc, template <> inline Value mapMhloOpToStdScalarOp( Location loc, ArrayRef resultTypes, ArrayRef /*argTypes*/, - mhlo::LogisticOp::Adaptor adaptor, OpBuilder* b) { + mhlo::LogisticOp::Adaptor adaptor, ArrayRef attributes, + OpBuilder* b) { // 1.0 / (1.0 + exp(-x)) Value negX = mapMhloOpToStdScalarOp( - loc, resultTypes, resultTypes, {adaptor.getOperand()}, b); - Value expNegX = mapMhloOpToStdScalarOp(loc, resultTypes, - resultTypes, {{negX}}, b); + loc, resultTypes, resultTypes, {adaptor.getOperand()}, attributes, b); + Value expNegX = mapMhloOpToStdScalarOp( + loc, resultTypes, resultTypes, {{negX}}, attributes, b); Type type = getElementTypeOrSelf(resultTypes[0]); Value oneFloat = @@ -1093,27 +1012,27 @@ inline Value mapMhloOpToStdScalarOp( : getConstantOrSplat(b, loc, resultTypes[0], FloatAttr::get(type, 1.0f)); Value one = mapConvertOpToStdScalarOp(loc, resultTypes, resultTypes, - {oneFloat.getType()}, {{oneFloat}}, b); + {oneFloat.getType()}, {{oneFloat}}, + attributes, b); Value oneAddExprNegX = mapMhloOpToStdScalarOp( - loc, resultTypes, resultTypes, {{expNegX, one}}, b); - return mapMhloOpToStdScalarOp(loc, resultTypes, resultTypes, - {{one, oneAddExprNegX}}, b); + loc, resultTypes, resultTypes, {{expNegX, one}}, attributes, b); + return mapMhloOpToStdScalarOp( + loc, resultTypes, resultTypes, {{one, oneAddExprNegX}}, attributes, b); } template <> -inline Value mapMhloOpToStdScalarOp(Location loc, - ArrayRef resultTypes, - ArrayRef argTypes, - mhlo::PowOp::Adaptor adaptor, - OpBuilder* b) { +inline Value mapMhloOpToStdScalarOp( + Location loc, ArrayRef resultTypes, ArrayRef argTypes, + mhlo::PowOp::Adaptor adaptor, ArrayRef attributes, + OpBuilder* b) { auto lb = ImplicitLocOpBuilder(loc, *b); // TODO: b/315868720 Consider alternate lowerings of mhlo::PowOp with integer // operands. Floating point can use std::powf auto resultType = getElementTypeOrSelf(resultTypes.front()); if (mlir::isa(resultType)) { return MapMhloOpToScalarOpImpl{}(loc, resultTypes, argTypes, - adaptor.getOperands(), b); + complex::PowOp>{}( + loc, resultTypes, argTypes, adaptor.getOperands(), attributes, b); } // Exponentiation by squaring: @@ -1185,17 +1104,17 @@ inline Value mapMhloOpToStdScalarOp(Location loc, template <> inline Value mapMhloOpToStdScalarOp( Location loc, ArrayRef resultTypes, ArrayRef argTypes, - mhlo::SelectOp::Adaptor adaptor, OpBuilder* b) { + mhlo::SelectOp::Adaptor adaptor, ArrayRef attributes, + OpBuilder* b) { return MapMhloOpToScalarOpImpl<::mlir::arith::SelectOp>{}( - loc, resultTypes, argTypes, adaptor.getOperands(), b); + loc, resultTypes, argTypes, adaptor.getOperands(), attributes, b); } template <> -inline Value mapMhloOpToStdScalarOp(Location loc, - ArrayRef resultTypes, - ArrayRef /*argTypes*/, - mhlo::SignOp::Adaptor adaptor, - OpBuilder* b) { +inline Value mapMhloOpToStdScalarOp( + Location loc, ArrayRef resultTypes, ArrayRef /*argTypes*/, + mhlo::SignOp::Adaptor adaptor, ArrayRef /*attributes*/, + OpBuilder* b) { Value operand = adaptor.getOperand(); Type elementType = getElementTypeOrSelf(operand.getType()); if (auto floatType = mlir::dyn_cast(elementType)) { @@ -1252,7 +1171,8 @@ inline Value selectShiftedOrSaturated(ImplicitLocOpBuilder& lb, Value rhs, template <> inline Value mapMhloOpToStdScalarOp( Location loc, ArrayRef /*ResultTypes*/, ArrayRef /*argTypes*/, - mhlo::ShiftLeftOp::Adaptor adaptor, OpBuilder* b) { + mhlo::ShiftLeftOp::Adaptor adaptor, ArrayRef /*attributes*/, + OpBuilder* b) { ImplicitLocOpBuilder lb(loc, *b); Value lhs = adaptor.getLhs(); Value rhs = adaptor.getRhs(); @@ -1268,7 +1188,8 @@ inline Value mapMhloOpToStdScalarOp( template <> inline Value mapMhloOpToStdScalarOp( Location loc, ArrayRef /*ResultTypes*/, ArrayRef /*argTypes*/, - mhlo::ShiftRightLogicalOp::Adaptor adaptor, OpBuilder* b) { + mhlo::ShiftRightLogicalOp::Adaptor adaptor, + ArrayRef /*attributes*/, OpBuilder* b) { ImplicitLocOpBuilder lb(loc, *b); Value lhs = adaptor.getLhs(); Value rhs = adaptor.getRhs(); @@ -1284,7 +1205,8 @@ inline Value mapMhloOpToStdScalarOp( template <> inline Value mapMhloOpToStdScalarOp( Location loc, ArrayRef /*ResultTypes*/, ArrayRef /*argTypes*/, - mhlo::ShiftRightArithmeticOp::Adaptor adaptor, OpBuilder* b) { + mhlo::ShiftRightArithmeticOp::Adaptor adaptor, + ArrayRef /*attributes*/, OpBuilder* b) { ImplicitLocOpBuilder lb(loc, *b); Value lhs = adaptor.getLhs(); Value rhs = adaptor.getRhs(); @@ -1308,9 +1230,9 @@ struct MhloOpToStdScalarOp { // Converts mhlo 'op' to linalg and arith ops. template static Value mapOp(MhloOpTy op, ArrayRef resultTypes, ValueRange args, - OpBuilder* b) { + ArrayRef attributes, OpBuilder* b) { auto argTypes = llvm::to_vector(op->getOperandTypes()); - return mapOpWithArgTypes(op, resultTypes, argTypes, args, b); + return mapOpWithArgTypes(op, resultTypes, argTypes, args, attributes, b); } // Converts mhlo 'op' to linalg and arith ops. The types of 'args' may already @@ -1318,38 +1240,40 @@ struct MhloOpToStdScalarOp { template static Value mapOpWithArgTypes(MhloOpTy op, ArrayRef resultTypes, ArrayRef argTypes, ValueRange args, + ArrayRef attributes, OpBuilder* b) { static_assert(!std::is_same::value); typename MhloOpTy::Adaptor adaptor(args, op->getAttrDictionary(), op->getPropertiesStorage(), op->getRegions()); return mapOpOfType(op.getLoc(), resultTypes, argTypes, adaptor, - b); + attributes, b); } // Overload for mhlo::ConvertOp. static Value mapOpWithArgTypes(mhlo::ConvertOp op, ArrayRef resultTypes, ArrayRef argTypes, ValueRange args, + ArrayRef attributes, OpBuilder* b) { - return impl::mapConvertOpToStdScalarOp(op.getLoc(), op.getType(), - resultTypes, argTypes, args, b); + return impl::mapConvertOpToStdScalarOp( + op.getLoc(), op.getType(), resultTypes, argTypes, args, attributes, b); } // Converts mhlo 'op' to linalg and arith ops. template static Value mapOpOfType(Location loc, ArrayRef resultTypes, ArrayRef argTypes, - typename MhloOpTy::Adaptor adaptor, OpBuilder* b) { + typename MhloOpTy::Adaptor adaptor, + ArrayRef attributes, OpBuilder* b) { return impl::mapMhloOpToStdScalarOp(loc, resultTypes, argTypes, - adaptor, b); + adaptor, attributes, b); } - static Value mapConvertOpToStdScalarOp(Location loc, - ArrayRef targetTypes, - ArrayRef resultTypes, - ArrayRef argTypes, - ValueRange args, OpBuilder* b) { + static Value mapConvertOpToStdScalarOp( + Location loc, ArrayRef targetTypes, ArrayRef resultTypes, + ArrayRef argTypes, ValueRange args, + ArrayRef attributes, OpBuilder* b) { return impl::mapConvertOpToStdScalarOp(loc, targetTypes, resultTypes, - argTypes, args, b); + argTypes, args, attributes, b); } }; diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/transformation_helpers.h b/third_party/xla/xla/mlir_hlo/mhlo/transforms/transformation_helpers.h new file mode 100644 index 00000000000000..9a85afab2570c1 --- /dev/null +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/transformation_helpers.h @@ -0,0 +1,161 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_MLIR_HLO_MHLO_TRANSFORMS_TRANSFORMATION_HELPERS_H_ +#define XLA_MLIR_HLO_MHLO_TRANSFORMS_TRANSFORMATION_HELPERS_H_ + +#include +#include +#include + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" + +namespace mlir::mhlo { + +// Returns the input value with a reduced precision as specified by the target +// exponent and mantissa bits. This function will preserve the input shape on +// the output - i.e. it works with both scalars and tensors. +// +// The templated bitcast type allows this function to work with different kinds +// of bitcats, e.g. `arith.bitcast` or `triton.bitcast`. +template +Value reducePrecision(Location loc, Value input, int destExponentBits, + int destMantissaBits, OpBuilder* builder) { + using llvm::APInt; + mlir::ImplicitLocOpBuilder b(loc, *builder); + + // Integer and float types for casting and constant generation. + auto floatType = mlir::cast(getElementTypeOrSelf(input.getType())); + int64_t nbits = floatType.getWidth(); + auto intScalarType = mlir::IntegerType::get(loc.getContext(), nbits); + + Type intType = intScalarType; + std::optional> shape; + if (auto shapedType = dyn_cast(input.getType())) { + shape = shapedType.getShape().vec(); + intType = shapedType.clone(intScalarType); + } + + Value xAsInt = b.create(intType, input); + + // SignificandWidth includes the implicit extra bit. + auto srcMantissaBits = floatType.getFPMantissaWidth() - 1; + int srcExponentBits = nbits - 1 - srcMantissaBits; + + // Clear the sign bit, it does not participate in rounding and we will restore + // it later. + APInt signBitMask(nbits, 1); + signBitMask <<= nbits - 1; + + APInt expBitsMask(nbits, 1); + expBitsMask = ((expBitsMask << srcExponentBits) - 1) << srcMantissaBits; + + auto createConstant = [&](const APInt& v) { + return createScalarOrSplatConstant(b, loc, intType, v); + }; + + Value xAbsBits = + b.create(xAsInt, createConstant(~signBitMask)); + Value xIsNan = b.create(arith::CmpIPredicate::ugt, xAbsBits, + createConstant(expBitsMask)); + + if (destMantissaBits < static_cast(srcMantissaBits)) { + // Last remaining mantissa bit. + APInt lastMantissaBitMask(nbits, 1); + lastMantissaBitMask <<= srcMantissaBits - destMantissaBits; + + // Compute rounding bias for round-to-nearest with ties to even. This is + // equal to a base value of 0111... plus one bit if the last remaining + // mantissa bit is 1. + APInt baseRoundingBias = lastMantissaBitMask.lshr(1) - 1; + + Value mantissaDiff = + createConstant(APInt(nbits, srcMantissaBits - destMantissaBits)); + + Value highestMantissaMaskVal = createConstant(lastMantissaBitMask); + Value baseRoundingBiasVal = createConstant(baseRoundingBias); + Value xLastMantissaBit = b.create( + b.create(xAsInt, highestMantissaMaskVal), mantissaDiff); + Value xRoundingBias = + b.create(xLastMantissaBit, baseRoundingBiasVal); + + // Add rounding bias, and mask out truncated bits. Note that the case + // where adding the rounding bias overflows into the exponent bits is + // correct; the non-masked mantissa bits will all be zero, and the + // exponent will be incremented by one. + APInt truncationMask = ~(lastMantissaBitMask - 1); + Value xRounded = b.create(xAsInt, xRoundingBias); + xAsInt = b.create(xRounded, createConstant(truncationMask)); + } + + if (destExponentBits < srcExponentBits) { + // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most- + // significant bit -- is equal to 1.0f for all exponent sizes. Adding + // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit- + // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest' + // exponent (corresponding to 0.0f). + // + // Thus, the f32 exponent corresponding to the highest non-infinite + // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 + // exponent corresponding to the lowest exponent for a bit size of n is + // (2^7-1) - 2^(n-1)-1. + // + // Note that we have already checked that exponents_bits >= 1. + APInt exponentBias(nbits, 1); + exponentBias = (exponentBias << (srcExponentBits - 1)) - 1; + + APInt reducedExponentBias(nbits, 1); + reducedExponentBias = (reducedExponentBias << (destExponentBits - 1)) - 1; + + APInt reducedMaxExponent = exponentBias + reducedExponentBias; + APInt reducedMinExponent = exponentBias - reducedExponentBias; + + // Do we overflow or underflow? + Value xExponent = + b.create(xAsInt, createConstant(expBitsMask)); + Value xOverflows = b.create( + arith::CmpIPredicate::ugt, xExponent, + createConstant(reducedMaxExponent << srcMantissaBits)); + Value xUnderflows = b.create( + arith::CmpIPredicate::ule, xExponent, + createConstant(reducedMinExponent << srcMantissaBits)); + + // Compute appropriately-signed values of zero and infinity. + Value xSignedZero = + b.create(xAsInt, createConstant(signBitMask)); + Value xSignedInf = + b.create(xSignedZero, createConstant(expBitsMask)); + + // Force to zero or infinity if overflow or underflow. (Note that this + // truncates all denormal values to zero, rather than rounding them.) + xAsInt = b.create(xOverflows, xSignedInf, xAsInt); + xAsInt = b.create(xUnderflows, xSignedZero, xAsInt); + } + + Value result = b.create(input.getType(), xAsInt); + return b.create(xIsNan, input, result); +} +} // namespace mlir::mhlo + +#endif // XLA_MLIR_HLO_MHLO_TRANSFORMS_TRANSFORMATION_HELPERS_H_ diff --git a/third_party/xla/xla/mlir_hlo/mhlo/utils/legalize_to_linalg_utils.h b/third_party/xla/xla/mlir_hlo/mhlo/utils/legalize_to_linalg_utils.h index fa379f0410c2d6..ce6db083e7941e 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/utils/legalize_to_linalg_utils.h +++ b/third_party/xla/xla/mlir_hlo/mhlo/utils/legalize_to_linalg_utils.h @@ -166,7 +166,8 @@ class PointwiseToLinalgConverter : public OpConversionPattern { auto argvec = llvm::to_vector<2>(args.take_front(inputs.size())); auto semiring = preSparsify(op, argvec, innerResultTy, &rewriter); Value innerResult = mhlo::MhloOpToStdScalarOp::mapOp( - op, innerResultTy, argvec, &rewriter); + op, innerResultTy, argvec, /*attributes=*/std::nullopt, + &rewriter); if (innerResult == nullptr) { failed = true; } else { diff --git a/third_party/xla/xla/mlir_hlo/mhlo/utils/type_conversion.cc b/third_party/xla/xla/mlir_hlo/mhlo/utils/type_conversion.cc index 4ff1bf56fde53f..27aa4efc2ea6f0 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/utils/type_conversion.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/utils/type_conversion.cc @@ -49,36 +49,36 @@ Type convertShapedType(ShapedType shapedType) { return shapedType; } -std::optional materializeCastFromIllegal(OpBuilder& builder, Type type, +Value materializeCastFromIllegal(OpBuilder& builder, Type type, ValueRange inputs, Location loc) { Type fromType = getElementTypeOrSelf(inputs[0].getType()); Type toType = getElementTypeOrSelf(type); if ((!fromType.isSignedInteger() && !fromType.isUnsignedInteger()) || !toType.isSignlessInteger()) - return std::nullopt; + return Value(); // Use unrealized conversion casts to do signful->signless conversions. return builder.create(loc, type, inputs[0]) ->getResult(0); } -std::optional materializeCastToIllegal(OpBuilder& builder, Type type, +Value materializeCastToIllegal(OpBuilder& builder, Type type, ValueRange inputs, Location loc) { Type fromType = getElementTypeOrSelf(inputs[0].getType()); Type toType = getElementTypeOrSelf(type); if (!fromType.isSignlessInteger() || (!toType.isSignedInteger() && !toType.isUnsignedInteger())) - return std::nullopt; + return Value(); // Use unrealized conversion casts to do signless->signful conversions. return builder.create(loc, type, inputs[0]) ->getResult(0); } -std::optional scalarToTensor(OpBuilder& builder, Type type, +Value scalarToTensor(OpBuilder& builder, Type type, ValueRange inputs, Location loc) { assert(inputs.size() == 1); if (mlir::isa(inputs.front().getType())) { - return std::nullopt; + return Value(); } Value result = builder diff --git a/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/passes.h b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/passes.h index 5c54390de66f4f..e2b77594141f05 100644 --- a/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/passes.h +++ b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/passes.h @@ -17,8 +17,10 @@ limitations under the License. #define STABLEHLO_EXT_TRANSFORMS_PASSES_H #include +#include #include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassOptions.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { diff --git a/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_compatibility_expander.cpp b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_compatibility_expander.cpp new file mode 100644 index 00000000000000..96dc5b8c645b9e --- /dev/null +++ b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_compatibility_expander.cpp @@ -0,0 +1,35 @@ +/* Copyright 2024 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "mlir/Pass/Pass.h" +#include "stablehlo/transforms/Passes.h" +#include "stablehlo_ext/transforms/passes.h" + +namespace mlir { +namespace stablehlo_ext { + +// TODO(b/369406385): remove this method (and file) once issue is resolved. + +std::unique_ptr<::mlir::Pass> createStablehloCompatibilityExpanderPass( + std::string targetVersionOption) { + return mlir::stablehlo::createStablehloCompatibilityExpanderPass( + {std::move(targetVersionOption)}); +} + +} // namespace stablehlo_ext +} // namespace mlir diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir index ddd8348641cb30..b47a159ff6d654 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir @@ -1668,8 +1668,8 @@ func.func @zeta_f16(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: %[[TMP_33:.*]] = mhlo.add %[[TMP_30]], %[[TMP_3]] // CHECK: %[[TMP_34:.*]] = mhlo.power %[[TMP_33]], %[[TMP_4]] // CHECK: %[[TMP_35:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_36:.*]] = mhlo.multiply %[[TMP_34]], %[[TMP_33]] - // CHECK: %[[TMP_37:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_35]] + // CHECK-DAG: %[[TMP_36:.*]] = mhlo.multiply %[[TMP_34]], %[[TMP_33]] + // CHECK-DAG: %[[TMP_37:.*]] = mhlo.subtract %[[TMP_0]], %[[TMP_35]] // CHECK: %[[TMP_38:.*]] = mhlo.divide %[[TMP_36]], %[[TMP_37]] // CHECK: %[[TMP_39:.*]] = mhlo.multiply %[[TMP_33]], %[[TMP_33]] // CHECK: %[[TMP_40:.*]] = mhlo.divide %[[TMP_3]], %[[TMP_39]] @@ -1943,8 +1943,8 @@ func.func @polygamma_f32(%lhs : tensor, %rhs : tensor) -> tensor // CHECK: %[[TMP_121:.*]] = mhlo.add %[[TMP_118]], %[[TMP_91]] // CHECK: %[[TMP_122:.*]] = mhlo.power %[[TMP_121]], %[[TMP_92]] // CHECK: %[[TMP_123:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_124:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_121]] - // CHECK: %[[TMP_125:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_123]] + // CHECK-DAG: %[[TMP_124:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_121]] + // CHECK-DAG: %[[TMP_125:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_123]] // CHECK: %[[TMP_126:.*]] = mhlo.divide %[[TMP_124]], %[[TMP_125]] // CHECK: %[[TMP_127:.*]] = mhlo.multiply %[[TMP_121]], %[[TMP_121]] // CHECK: %[[TMP_128:.*]] = mhlo.divide %[[TMP_91]], %[[TMP_127]] @@ -2330,8 +2330,8 @@ func.func @polygamma_f64(%lhs : tensor, %rhs : tensor) -> tensor // CHECK: %[[TMP_121:.*]] = mhlo.add %[[TMP_118]], %[[TMP_91]] // CHECK: %[[TMP_122:.*]] = mhlo.power %[[TMP_121]], %[[TMP_92]] // CHECK: %[[TMP_123:.*]] = mhlo.constant dense<1.000000e+00> - // CHECK: %[[TMP_124:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_121]] - // CHECK: %[[TMP_125:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_123]] + // CHECK-DAG: %[[TMP_124:.*]] = mhlo.multiply %[[TMP_122]], %[[TMP_121]] + // CHECK-DAG: %[[TMP_125:.*]] = mhlo.subtract %[[TMP_5]], %[[TMP_123]] // CHECK: %[[TMP_126:.*]] = mhlo.divide %[[TMP_124]], %[[TMP_125]] // CHECK: %[[TMP_127:.*]] = mhlo.multiply %[[TMP_121]], %[[TMP_121]] // CHECK: %[[TMP_128:.*]] = mhlo.divide %[[TMP_91]], %[[TMP_127]] @@ -2754,13 +2754,13 @@ func.func @atanh_complex_f32(%arg : tensor>) -> tensor // CHECK-NEXT: %[[ABS_IMAG:.*]] = mhlo.abs %[[IMAG]] // CHECK-NEXT: %[[CMP4:.*]] = mhlo.compare LT, %[[ABS_IMAG]], %[[SQUARE]] // CHECK-NEXT: %[[AND:.*]] = mhlo.and %[[CMP3]], %[[CMP4]] - // CHECK-NEXT: %[[SUB0:.*]] = mhlo.subtract %[[ONE]], %[[ABS_REAL]] - // CHECK-NEXT: %[[SQUARE1:.*]] = mhlo.multiply %[[SUB0]], %[[SUB0]] - // CHECK-NEXT: %[[SQUARE2:.*]] = mhlo.multiply %[[IMAG]], %[[IMAG]] - // CHECK-NEXT: %[[ADD0:.*]] = mhlo.add %[[SQUARE1]], %[[SQUARE2]] - // CHECK-NEXT: %[[DIV0:.*]] = mhlo.divide %[[ABS_REAL]], %[[ADD0]] - // CHECK-NEXT: %[[MULT0:.*]] = mhlo.multiply %[[ABS_IMAG]], %[[SELECT2]] - // CHECK-NEXT: %[[CMP5:.*]] = mhlo.compare LT, %[[MULT0]], %[[ABS_REAL]] + // CHECK-NEXT: %[[SUB0:.*]] = mhlo.subtract %[[ONE]], %[[ABS_REAL]] + // CHECK-NEXT: %[[SQUARE1:.*]] = mhlo.multiply %[[SUB0]], %[[SUB0]] + // CHECK-NEXT: %[[SQUARE2:.*]] = mhlo.multiply %[[IMAG]], %[[IMAG]] + // CHECK-NEXT: %[[ADD0:.*]] = mhlo.add %[[SQUARE1]], %[[SQUARE2]] + // CHECK-NEXT: %[[DIV0:.*]] = mhlo.divide %[[ABS_REAL]], %[[ADD0]] + // CHECK-NEXT: %[[MULT0:.*]] = mhlo.multiply %[[ABS_IMAG]], %[[SELECT2]] + // CHECK-NEXT: %[[CMP5:.*]] = mhlo.compare LT, %[[MULT0]], %[[ABS_REAL]] // CHECK-NEXT: %[[DIV1:.*]] = mhlo.divide %[[ONE]], %[[ABS_REAL]] // CHECK-NEXT: %[[ISINF_REAL:.*]] = mhlo.constant dense<0x7F800000> // CHECK-NEXT: %[[ISNINF_REAL:.*]] = mhlo.compare EQ, %[[REAL]], %[[ISINF_REAL]] diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir index b34126213193df..b8cd07d2e5bcce 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/canonicalize.mlir @@ -1,290 +1,114 @@ // RUN: mlir-hlo-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s -// CHECK-LABEL: add_fold -func.func @add_fold() -> tensor<4xi64> { - %0 = mhlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64> - %1 = mhlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64> - // CHECK: mhlo.constant dense<[6, 8, 10, 12]> - %2 = "mhlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) - func.return %2 : tensor<4xi64> -} - -// CHECK-LABEL: add_scalar_fold -func.func @add_scalar_fold() -> tensor<4xi64> { - %0 = mhlo.constant dense<1> : tensor<4xi64> - %1 = mhlo.constant dense<5> : tensor<4xi64> - // CHECK: mhlo.constant dense<6> - %2 = "mhlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) - func.return %2 : tensor<4xi64> -} - -// CHECK-LABEL: add_fold_float -func.func @add_fold_float() -> tensor<4xf64> { - %0 = mhlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64> - %1 = mhlo.constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf64> - // CHECK: mhlo.constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]> - %2 = "mhlo.add"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) - func.return %2 : tensor<4xf64> -} - -// CHECK-LABEL: add_zero_int_fold -func.func @add_zero_int_fold(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> { - %0 = mhlo.constant dense<0> : tensor<2x2xi64> - %1 = "mhlo.add"(%arg0, %0) : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64> - // CHECK: return %arg0 : tensor<2x2xi64> - func.return %1 : tensor<2x2xi64> -} - -// CHECK-LABEL: add_zero_float_flod -func.func @add_zero_float_flod(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = mhlo.constant dense<0.0> : tensor<2x2xf32> - %1 = "mhlo.add"(%0, %arg0) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK: return %arg0 : tensor<2x2xf32> - func.return %1 : tensor<2x2xf32> -} - -// CHECK-LABEL: sub_scalar_fold -func.func @sub_scalar_fold() -> tensor<4xi64> { - %0 = mhlo.constant dense<5> : tensor<4xi64> - %1 = mhlo.constant dense<1> : tensor<4xi64> - // CHECK: mhlo.constant dense<4> - %2 = "mhlo.subtract"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) - func.return %2 : tensor<4xi64> -} - -// CHECK-LABEL: multiply_scalar_fold -func.func @multiply_scalar_fold() -> tensor<4xi64> { - %0 = mhlo.constant dense<5> : tensor<4xi64> - %1 = mhlo.constant dense<3> : tensor<4xi64> - // CHECK: mhlo.constant dense<15> - %2 = "mhlo.multiply"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) - func.return %2 : tensor<4xi64> -} - -// CHECK-LABEL: mul_one_int_fold -func.func @mul_one_int_fold(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> { - %0 = mhlo.constant dense<1> : tensor<2x2xi64> - %1 = "mhlo.multiply"(%arg0, %0) : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64> - // CHECK: return %arg0 : tensor<2x2xi64> - func.return %1 : tensor<2x2xi64> -} - -// CHECK-LABEL: mul_one_int8_fold -func.func @mul_one_int8_fold(%arg0: tensor<2x2xi8>) -> tensor<2x2xi8> { - %0 = mhlo.constant dense<1> : tensor<2x2xi8> - %1 = "mhlo.multiply"(%arg0, %0) : (tensor<2x2xi8>, tensor<2x2xi8>) -> tensor<2x2xi8> - // CHECK: return %arg0 : tensor<2x2xi8> - func.return %1 : tensor<2x2xi8> -} - -// CHECK-LABEL: mul_one_float_flod -func.func @mul_one_float_flod(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - %0 = mhlo.constant dense<1.0> : tensor<2x2xf32> - %1 = "mhlo.multiply"(%0, %arg0) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // CHECK: return %arg0 : tensor<2x2xf32> - func.return %1 : tensor<2x2xf32> -} - -// CHECK-LABEL: mul_one_fp16_flod -func.func @mul_one_fp16_flod(%arg0: tensor<2x2xf16>) -> tensor<2x2xf16> { - %0 = mhlo.constant dense<1.0> : tensor<2x2xf16> - %1 = "mhlo.multiply"(%0, %arg0) : (tensor<2x2xf16>, tensor<2x2xf16>) -> tensor<2x2xf16> - // CHECK: return %arg0 : tensor<2x2xf16> - func.return %1 : tensor<2x2xf16> -} - -// CHECK-LABEL: mul_one_bf16_flod -func.func @mul_one_bf16_flod(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> { - %0 = mhlo.constant dense<1.0> : tensor<2x2xbf16> - %1 = "mhlo.multiply"(%0, %arg0) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16> - // CHECK: return %arg0 : tensor<2x2xbf16> - func.return %1 : tensor<2x2xbf16> -} - - -// CHECK-LABEL: divide_scalar_fold -func.func @divide_scalar_fold() -> tensor<4xi64> { - %0 = mhlo.constant dense<7> : tensor<4xi64> - %1 = mhlo.constant dense<5> : tensor<4xi64> - // CHECK: mhlo.constant dense<1> - %2 = "mhlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) - func.return %2 : tensor<4xi64> -} - -// CHECK-LABEL: divide_scalar_fold_by_zero -func.func @divide_scalar_fold_by_zero() -> tensor<4xi64> { - %0 = mhlo.constant dense<7> : tensor<4xi64> - %1 = mhlo.constant dense<0> : tensor<4xi64> - // CHECK: mhlo.divide - %2 = "mhlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) - func.return %2 : tensor<4xi64> -} - -// CHECK-LABEL: divide_fold_int -func.func @divide_fold_int() -> tensor<4xi32> { - %0 = mhlo.constant dense<[1, -2, 3, 4]> : tensor<4xi32> - %1 = mhlo.constant dense<[-1, -2, -3, 2]> : tensor<4xi32> - // CHECK: %[[RESULT:.+]] = mhlo.constant dense<[-1, 1, -1, 2]> - %2 = "mhlo.divide"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>) - // CHECK: return %[[RESULT]] - func.return %2 : tensor<4xi32> +//////// +// BroadcastOp(deprecated) + +// CHECK-LABEL: func @broadcast_identity +func.func @broadcast_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { + // CHECK: return %arg0 + %0 = "mhlo.broadcast"(%arg0) <{broadcast_sizes = dense<[]> : tensor<0xi64>}> : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32> + func.return %0 : tensor<2x3x4xf32> } -// CHECK-LABEL: divide_fold_unsigned -func.func @divide_fold_unsigned() -> tensor<4xui32> { - %0 = mhlo.constant dense<[1, -2, 3, 4]> : tensor<4xi32> - %1 = "mhlo.convert"(%0) : (tensor<4xi32>) -> tensor<4xui32> - %2 = mhlo.constant dense<[-1, -2, -3, 2]> : tensor<4xi32> - %3 = "mhlo.convert"(%2) : (tensor<4xi32>) -> tensor<4xui32> - // CHECK: %[[RESULT:.+]] = mhlo.constant dense<[0, 1, 0, 2]> - %4 = "mhlo.divide"(%1, %3) : (tensor<4xui32>, tensor<4xui32>) -> (tensor<4xui32>) - // CHECK: return %[[RESULT]] - func.return %4 : tensor<4xui32> -} - -// CHECK-LABEL: divide_fold_float -func.func @divide_fold_float() -> tensor<4xf64> { - %0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64> - %1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64> - // CHECK: mhlo.constant dense<[1.000000e+00, 2.200000e+01, 2.500000e+00, 2.500000e-01]> - %2 = "mhlo.divide"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) - func.return %2 : tensor<4xf64> -} - -// CHECK-LABEL: divide_fold_by_zero -func.func @divide_fold_by_zero() -> tensor<4xi64> { - %0 = mhlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64> - %1 = mhlo.constant dense<[1, 2, 3, 0]> : tensor<4xi64> - // CHECK: mhlo.divide - %2 = "mhlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) - func.return %2 : tensor<4xi64> -} - -// CHECK-LABEL: remainder_scalar_fold_by_zero -func.func @remainder_scalar_fold_by_zero() -> tensor<4xi64> { - %0 = mhlo.constant dense<7> : tensor<4xi64> - %1 = mhlo.constant dense<0> : tensor<4xi64> - // CHECK: mhlo.remainder - %2 = "mhlo.remainder"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) - func.return %2 : tensor<4xi64> -} - -// CHECK-LABEL: remainder_fold_int -func.func @remainder_fold_int() -> tensor<4xi32> { - %0 = mhlo.constant dense<[5, 66, 5, -1]> : tensor<4xi32> - %1 = mhlo.constant dense<[3, 5, 1, -2]> : tensor<4xi32> - // CHECK: %[[RESULT:.+]] = mhlo.constant dense<[2, 1, 0, -1]> - %2 = "mhlo.remainder"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>) - // CHECK: return %[[RESULT]] - func.return %2 : tensor<4xi32> +// CHECK-LABEL: func @broadcast_dynamic_shape_identity +func.func @broadcast_dynamic_shape_identity(%arg0: tensor) -> tensor { + // CHECK: return %arg0 + %0 = "mhlo.broadcast"(%arg0) <{broadcast_sizes = dense<[]> : tensor<0xi64>}> : (tensor) -> tensor + func.return %0 : tensor } -// CHECK-LABEL: remainder_fold_float -func.func @remainder_fold_float() -> tensor<8xf32> { - %0 = mhlo.constant dense<[-2.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0]> : tensor<8xf32> - %1 = mhlo.constant dense<[10.0, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0]> : tensor<8xf32> - // CHECK{LITERAL}: mhlo.constant dense<[-2.500000e+00, 2.500000e-01, -0.000000e+00, 0.000000e+00, 1.000000e+00, 1.000000e+00, -1.000000e+00, -0.000000e+00]> - %2 = "mhlo.remainder"(%0, %1) : (tensor<8xf32>, tensor<8xf32>) -> (tensor<8xf32>) - func.return %2 : tensor<8xf32> +// CHECK-LABEL: func @broadcast_dynamic_shape_not_identity +func.func @broadcast_dynamic_shape_not_identity(%arg0: tensor) -> tensor<20x?x?x?xf32> { + // CHECK: mhlo.broadcast + %0 = "mhlo.broadcast"(%arg0) <{broadcast_sizes = dense<[20]> : tensor<1xi64>}> : (tensor) -> tensor<20x?x?x?xf32> + func.return %0 : tensor<20x?x?x?xf32> } -// CHECK-LABEL: round_fold -func.func @round_fold() -> tensor<4xf32> { - %0 = mhlo.constant dense<[-1.5, -0.1, 1.1, 2.5]> : tensor<4xf32> - %1 = "mhlo.round_nearest_afz"(%0) : (tensor<4xf32>) -> tensor<4xf32> - func.return %1 : tensor<4xf32> - // CHECK: mhlo.constant dense<[-2.000000e+00, -0.000000e+00, 1.000000e+00, 3.000000e+00]> +//////// +// BroadcastInDimOp + +// CHECK-LABEL: func @broadcast_consecutive +func.func @broadcast_consecutive(%arg0: tensor<2x3xf32>) -> tensor<2x3x4x5xf32> { + // CHECK: mhlo.broadcast_in_dim + // CHECK-SAME: broadcast_dimensions = dense<[0, 1]> + // CHECK-NEXT: return + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<2x3xf32>) -> tensor<2x3x4xf32> + %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}> : (tensor<2x3x4xf32>) -> tensor<2x3x4x5xf32> + func.return %1 : tensor<2x3x4x5xf32> } -// CHECK-LABEL: round_nearest_even_fold -func.func @round_nearest_even_fold() -> tensor<4xf32> { - %0 = mhlo.constant dense<[-1.5, -0.1, 1.1, 2.5]> : tensor<4xf32> - %1 = "mhlo.round_nearest_even"(%0) : (tensor<4xf32>) -> tensor<4xf32> - func.return %1 : tensor<4xf32> - // CHECK: mhlo.constant dense<[-2.000000e+00, -0.000000e+00, 1.000000e+00, 2.000000e+00]> -} - -// CHECK-LABEL: max_scalar_fold -func.func @max_scalar_fold() -> tensor<4xi64> { - %0 = mhlo.constant dense<7> : tensor<4xi64> - %1 = mhlo.constant dense<-5> : tensor<4xi64> - // CHECK: %[[RESULT:.+]] = mhlo.constant dense<7> - %2 = "mhlo.maximum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) - // CHECK: return %[[RESULT]] - func.return %2 : tensor<4xi64> -} - -// CHECK-LABEL: max_scalar_fold_unsigned -func.func @max_scalar_fold_unsigned() -> tensor<4xui32> { - %0 = mhlo.constant dense<7> : tensor<4xui32> - %1 = mhlo.constant dense<-5> : tensor<4xi32> - %2 = "mhlo.convert"(%1) : (tensor<4xi32>) -> tensor<4xui32> - // CHECK: %[[RESULT:.+]] = mhlo.constant dense<4294967291> - %3 = "mhlo.maximum"(%0, %2) : (tensor<4xui32>, tensor<4xui32>) -> (tensor<4xui32>) - // CHECK: return %[[RESULT]] - func.return %3 : tensor<4xui32> -} - -// CHECK-LABEL: max_fold_float -func.func @max_fold_float() -> tensor<6xf32> { - %0 = mhlo.constant dense<[5.0, 66.0, 0xFFFFFFFF, -2.0, 0xFFFFFFFF, 1.0]> : tensor<6xf32> - %1 = mhlo.constant dense<[5.0, 3.0, 2.0, 0xFFFFFFFF, 0xFFFFFFFF, 4.0]> : tensor<6xf32> - // CHECK{LITERAL}: mhlo.constant dense<[5.000000e+00, 6.600000e+01, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 4.000000e+00] - %2 = "mhlo.maximum"(%0, %1) : (tensor<6xf32>, tensor<6xf32>) -> (tensor<6xf32>) - func.return %2 : tensor<6xf32> -} - -// CHECK-LABEL: min_scalar_fold -func.func @min_scalar_fold() -> tensor<4xi64> { - %0 = mhlo.constant dense<7> : tensor<4xi64> - %1 = mhlo.constant dense<-5> : tensor<4xi64> - // CHECK: mhlo.constant dense<-5> - %2 = "mhlo.minimum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) - func.return %2 : tensor<4xi64> -} - -// CHECK-LABEL: min_fold_float -func.func @min_fold_float() -> tensor<6xf32> { - %0 = mhlo.constant dense<[5.0, 66.0, 0xFFFFFFFF, -2.0, 0xFFFFFFFF, 1.0]> : tensor<6xf32> - %1 = mhlo.constant dense<[5.0, 3.0, 2.0, 0xFFFFFFFF, 0xFFFFFFFF, 4.0]> : tensor<6xf32> - // CHECK{LITERAL}: mhlo.constant dense<[5.000000e+00, 3.000000e+00, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 1.000000e+00] - %2 = "mhlo.minimum"(%0, %1) : (tensor<6xf32>, tensor<6xf32>) -> (tensor<6xf32>) - func.return %2 : tensor<6xf32> -} - -// CHECK-LABEL: clamp_scalar_fold -func.func @clamp_scalar_fold() -> tensor<5xi64> { - %0 = mhlo.constant dense<149> : tensor - %1 = mhlo.constant dense<[-1, 100, 200, 0, 149]> : tensor<5xi64> - %2 = mhlo.constant dense<0> : tensor - // CHECK{LITERAL}: mhlo.constant dense<[0, 100, 149, 0, 149]> - // CHECK-NOT: mhlo.clamp - %3 = mhlo.clamp %2, %1, %0 : (tensor, tensor<5xi64>, tensor) -> tensor<5xi64> - return %3 : tensor<5xi64> -} - -// CHECK-LABEL: clamp_fold -func.func @clamp_fold() -> tensor<5xi64> { - %0 = mhlo.constant dense<[149, 101, -1, 30, 50]> : tensor<5xi64> - %1 = mhlo.constant dense<[-1, 100, 200, 0, 149]> : tensor<5xi64> - %2 = mhlo.constant dense<[0, 10, -10, 10, -100]> : tensor<5xi64> - // CHECK{LITERAL}: mhlo.constant dense<[0, 100, -1, 10, 50]> - // CHECK-NOT: mhlo.clamp - %3 = mhlo.clamp %2, %1, %0 : (tensor<5xi64>, tensor<5xi64>, tensor<5xi64>) -> tensor<5xi64> - return %3 : tensor<5xi64> -} - -// CHECK-LABEL: clamp_fold_float -func.func @clamp_fold_float() -> tensor<6xf32> { - %0 = mhlo.constant dense<[5.0, 66.0, 0xFFFFFFFF, -2.0, 0xFFFFFFFF, 6.0]> : tensor<6xf32> - %1 = mhlo.constant dense<[5.0, 3.0, 2.0, 0xFFFFFFFF, 0xFFFFFFFF, 4.0]> : tensor<6xf32> - %2 = mhlo.constant dense<[5.0, 1.0, 1.0, 0xFFFFFFFF, 0xFFFFFFFF, 5.0]> : tensor<6xf32> - // CHECK{LITERAL}: mhlo.constant dense<[5.000000e+00, 3.000000e+00, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 5.000000e+00] - // CHECK-NOT: mhlo.clamp - %3 = mhlo.clamp %2, %1, %0 : (tensor<6xf32>, tensor<6xf32>, tensor<6xf32>) -> tensor<6xf32> - return %3 : tensor<6xf32> +// CHECK-LABEL: func @broadcast_in_dim_equivalent_reshape +func.func @broadcast_in_dim_equivalent_reshape(%arg0: tensor<2x3x4xf32>) -> tensor<1x2x3x4xf32> { + // CHECK: mhlo.reshape + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>}> : (tensor<2x3x4xf32>) -> tensor<1x2x3x4xf32> + func.return %0 : tensor<1x2x3x4xf32> +} + +// CHECK-LABEL: func @broadcast_in_dim_not_identity_because_it_actually_broadcasts +func.func @broadcast_in_dim_not_identity_because_it_actually_broadcasts(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> { + // CHECK: mhlo.broadcast_in_dim + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x2xf32>) -> tensor<2x2xf32> + func.return %0 : tensor<2x2xf32> +} + +// CHECK-LABEL: func @broadcast_in_dim_equivalent_transpose +func.func @broadcast_in_dim_equivalent_transpose(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: mhlo.transpose + // CHECK-SAME: permutation = dense<[1, 0]> + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>}> : (tensor<2x2xf32>) -> tensor<2x2xf32> + func.return %0 : tensor<2x2xf32> +} + +// CHECK-LABEL: @identity_broadcast_reshape +func.func @identity_broadcast_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> { + %0 = "mhlo.broadcast"(%arg0) { + broadcast_sizes = dense<[1]> : tensor<1xi64>} : (tensor<128xf32>) -> tensor<1x128xf32> + %1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32> + func.return %1 : tensor<128xf32> + // CHECK: return %arg0 : tensor<128xf32> +} + +// CHECK-LABEL: @identity_broadcast_in_dim_reshape +func.func @identity_broadcast_in_dim_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> { + %0 = "mhlo.broadcast_in_dim"(%arg0) { + broadcast_dimensions = dense<[1]> : tensor<1xi64> } : (tensor<128xf32>) -> tensor<1x128xf32> + %1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32> + func.return %1 : tensor<128xf32> + // CHECK: return %arg0 : tensor<128xf32> +} + +// CHECK-LABEL: @eliminate_identity_convert +func.func @eliminate_identity_convert(%arg : tensor) -> tensor { + // CHECK-NOT: mhlo.convert + %0 = "mhlo.convert"(%arg) : (tensor) -> tensor + // CHECK: return %arg0 : tensor + func.return %0 : tensor +} + +//////// +// ComplexOp + +// CHECK-LABEL: @complex_expand +func.func @complex_expand_fold(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { + %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex>) + %1 = mhlo.real %0 : (tensor<4xcomplex>) -> (tensor<4xf32>) + %2 = "mhlo.imag"(%0) : (tensor<4xcomplex>) -> (tensor<4xf32>) + // CHECK: return %arg0, %arg1 + func.return %1, %2 : tensor<4xf32>, tensor<4xf32> +} + +// CHECK-LABEL: @complex_collapse +func.func @complex_collapse_fold(%arg0: tensor<4xcomplex>) -> tensor<4xcomplex> { + %0 = mhlo.real %arg0 : (tensor<4xcomplex>) -> (tensor<4xf32>) + %1 = "mhlo.imag"(%arg0) : (tensor<4xcomplex>) -> (tensor<4xf32>) + %2 = "mhlo.complex"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> + // CHECK: return %arg0 + func.return %2 : tensor<4xcomplex> } +//////// +// ConcatenateOp + // CHECK-LABEL: concatenate_noop func.func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> { // CHECK-SAME: [[ARG:%.+]]: tensor<4xi32> @@ -294,16 +118,6 @@ func.func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> { func.return %0 : tensor<4xi32> } -// CHECK-LABEL: concatenate_noop_typecast -func.func @concatenate_noop_typecast(%arg0: tensor) -> tensor<4xi32> { - // CHECK-SAME: [[ARG:%.+]]: tensor - // CHECK-NEXT: [[RES:%.+]] = tensor.cast [[ARG]] : tensor to tensor<4xi32> - %0 = "mhlo.concatenate"(%arg0) <{ dimension = 0 : i64 }> : (tensor) -> tensor<4xi32> - - // CHECK: return [[RES]] - func.return %0 : tensor<4xi32> -} - // CHECK-LABEL: concatenate_remove_operand func.func @concatenate_remove_operand(%arg0: tensor<4xi32>, %arg1: tensor<0xi32>) -> tensor<4xi32> { // CHECK-SAME: [[ARG0:%.+]]: tensor<4xi32> @@ -397,437 +211,104 @@ func.func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> { func.return %2 : tensor<2x2xi32> } -// CHECK-LABEL: constant_like_constant -func.func @constant_like_constant(%arg0: tensor<3x4xi32>) -> tensor<3x4xf32> { - // CHECK: chlo.constant dense<3.200000e+00> - %0 = "chlo.constant_like"(%arg0) { value = 3.2 : f32 } : (tensor<3x4xi32>) -> tensor<3x4xf32> - func.return %0 : tensor<3x4xf32> -} - -// CHECK-LABEL: constant_like_constant_dynamic -func.func @constant_like_constant_dynamic(%arg0: tensor) -> tensor { - // CHECK: chlo.constant_like - %0 = "chlo.constant_like"(%arg0) { value = 3.2 : f32 } : (tensor) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: dynamic_update_slice_fold_length_0 -func.func @dynamic_update_slice_fold_length_0(%arg0: tensor<3x4xi64>, %arg1: tensor<3x0xi64>) -> tensor<3x4xi64> { - // CHECK: return %arg0 - %0 = mhlo.constant dense<0> : tensor - %1 = "mhlo.dynamic_update_slice"(%arg0, %arg1, %0, %0) : (tensor<3x4xi64>, tensor<3x0xi64>, tensor, tensor) -> tensor<3x4xi64> - func.return %1 : tensor<3x4xi64> -} +//////// +// CopyOp -// CHECK-LABEL: dynamic_update_slice_identity_update -func.func @dynamic_update_slice_identity_update(%arg0: tensor<3x4xi64>, %arg1: tensor<3x4xi64>) -> tensor<3x4xi64> { - // CHECK: return %arg1 - %0 = mhlo.constant dense<0> : tensor - %1 = "mhlo.dynamic_update_slice"(%arg0, %arg1, %0, %0) : (tensor<3x4xi64>, tensor<3x4xi64>, tensor, tensor) -> tensor<3x4xi64> - func.return %1 : tensor<3x4xi64> +// CHECK-LABEL: func @fold_copy +// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] +func.func @fold_copy(%arg : tensor<1x4xf32>) -> tensor<1x4xf32> { + // CHECK: return [[ARG]] + %0 = "mhlo.copy"(%arg) : (tensor<1x4xf32>) -> tensor<1x4xf32> + func.return %0 : tensor<1x4xf32> } -// CHECK-LABEL: dynamic_update_slice_fold_fail_dynamic_shapes -func.func @dynamic_update_slice_fold_fail_dynamic_shapes(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = mhlo.constant dense<0> : tensor - %1 = "mhlo.dynamic_update_slice"(%arg0, %arg1, %0, %0) : (tensor, tensor, tensor, tensor) -> tensor - func.return %1 : tensor - // CHECK: %[[CST:.*]] = mhlo.constant dense<0> : tensor - // CHECK: %[[VAL:.*]] = mhlo.dynamic_update_slice %arg0, %arg1, %[[CST]], %[[CST]] : (tensor, tensor, tensor, tensor) -> tensor - // CHECK: return %[[VAL]] : tensor -} +//////// +// DynamicBroadcastInDimOp -// CHECK-LABEL: dynamic_slice_variable_start -func.func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { - // CHECK: "mhlo.dynamic_slice" - %1 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> - func.return %1 : tensor<1x4xi32> +// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic +func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %arg1: tensor<2xi64>) -> tensor<5x4xf32> { + // CHECK: %[[RESULT:.+]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<4xf32>) -> tensor<5x4xf32> + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) <{ broadcast_dimensions = dense<1> : tensor<1xi64> }> : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32> + // CHECK: return %[[RESULT]] : tensor<5x4xf32> + func.return %0 : tensor<5x4xf32> } -// CHECK-LABEL: dynamic_slice_constant_start -func.func @dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { - // CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0) - // CHECK-DAG-SAME: limit_indices = dense<3> : tensor<1xi64> - // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64> - // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} - // CHECK: return %[[RESULT]] : tensor<2xi32> - %0 = mhlo.constant dense<1> : tensor - %1 = "mhlo.dynamic_slice"(%arg0, %0) <{slice_sizes = dense<2> : tensor<1xi64>}> : (tensor<4xi32>, tensor) -> tensor<2xi32> - func.return %1 : tensor<2xi32> +// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_shape +func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_shape(%arg0: tensor) -> tensor<4x32xi32> { + %0 = mhlo.constant dense<[4, 32]> : tensor<2xi32> + // CHECK: %[[RESULT:.+]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4x32xi32> + %1 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<2xi32>) -> tensor + %2 = "mhlo.dynamic_reshape"(%1, %0) : (tensor, tensor<2xi32>) -> tensor<4x32xi32> + // CHECK: return %[[RESULT]] : tensor<4x32xi32> + func.return %2 : tensor<4x32xi32> } -// CHECK-LABEL: dynamic_slice_constant_start_dynamic_shape -func.func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { - // CHECK: mhlo.dynamic_slice - // CHECK-NOT: mhlo.slice - %0 = mhlo.constant dense<1> : tensor - %1 = mhlo.constant dense<0> : tensor - %2 = "mhlo.dynamic_slice"(%arg0, %0, %1) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor, tensor, tensor) -> tensor<1x4xi32> - func.return %2 : tensor<1x4xi32> +// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_index_shape +func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_index_shape(%arg0: tensor) -> tensor<4x32xf32> { + %0 = shape.const_shape [4, 32] : tensor<2xindex> + // CHECK: %[[RESULT:.+]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4x32xf32> + %1 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<2xindex>) -> tensor + %2 = "mhlo.dynamic_reshape"(%1, %0) : (tensor, tensor<2xindex>) -> tensor<4x32xf32> + // CHECK: return %[[RESULT]] : tensor<4x32xf32> + func.return %2 : tensor<4x32xf32> } -// CHECK-LABEL: dynamic_slice_constant_start_upper_bound -func.func @dynamic_slice_constant_start_upper_bound(%arg0: tensor<8x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { - // CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0) - // CHECK-SAME: limit_indices = dense<[8, 4]> : tensor<2xi64> - // CHECK-SAME: start_indices = dense<[7, 0]> : tensor<2xi64> - // CHECK-SAME: strides = dense<1> : tensor<2xi64> - // CHECK: return %[[RESULT]] : tensor<1x4xi32> - %0 = mhlo.constant dense<10> : tensor - %1 = mhlo.constant dense<0> : tensor - %2 = "mhlo.dynamic_slice"(%arg0, %0, %1) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor<8x4xi32>, tensor, tensor) -> tensor<1x4xi32> - func.return %2 : tensor<1x4xi32> +// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_requires_cast +func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_requires_cast(%arg0: tensor) -> tensor { + %0 = shape.const_shape [4, 32] : tensor<2xindex> + // CHECK: %[[BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4x32xf32> + %1 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<2xindex>) -> tensor + // CHECK: %[[RESULT:.*]] = tensor.cast %[[BCAST]] : tensor<4x32xf32> to tensor + // CHECK: return %[[RESULT]] : tensor + func.return %1 : tensor } -// CHECK-LABEL: dynamic_slice_constant_start_lower_bound -func.func @dynamic_slice_constant_start_lower_bound(%arg0: tensor<8x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { - // CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0) - // CHECK-SAME: limit_indices = dense<[1, 4]> : tensor<2xi64> - // CHECK-SAME: start_indices = dense<0> : tensor<2xi64> - // CHECK-SAME: strides = dense<1> : tensor<2xi64> - // CHECK: return %[[RESULT]] : tensor<1x4xi32> - %0 = mhlo.constant dense<-1> : tensor - %1 = mhlo.constant dense<0> : tensor - %2 = "mhlo.dynamic_slice"(%arg0, %0, %1) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor<8x4xi32>, tensor, tensor) -> tensor<1x4xi32> - func.return %2 : tensor<1x4xi32> +// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_almost_not_actually_dynamic +func.func @dynamic_broadcast_in_dim_op_almost_not_actually_dynamic(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor<5x4xf32> { + // CHECK: %[[RESULT:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor, tensor<2xi64>) -> tensor<5x4xf32> + %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) <{ broadcast_dimensions = dense<1> : tensor<1xi64> }> : (tensor, tensor<2xi64>) -> tensor<5x4xf32> + // CHECK: return %[[RESULT]] : tensor<5x4xf32> + func.return %0 : tensor<5x4xf32> } -// CHECK-LABEL: slice_2D_noop -// CHECK-SAME: [[ARG:%.+]]: tensor<2x2xi64> -func.func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> { - %0 = "mhlo.slice"(%arg0) <{ limit_indices = dense<[2, 2]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<2x2xi64>) -> (tensor<2x2xi64>) - - // CHECK-NEXT: return [[ARG]] - func.return %0 : tensor<2x2xi64> +// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_1 +func.func @dynamic_broadcast_in_dim_to_same_shape_1(%arg0: tensor) -> tensor { + // CHECK-SAME: %[[ARG:.*]]: tensor + %0 = shape.shape_of %arg0 : tensor -> tensor<1xindex> + %2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %0) <{ broadcast_dimensions = dense<0> : tensor<1xi64> }> : (tensor, tensor<1xindex>) -> tensor + // CHECK: return %[[ARG]] : tensor + func.return %2 : tensor } -// CHECK-LABEL: slice_1D_fold -func.func @slice_1D_fold() -> tensor<2xi64> { - %0 = mhlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64> - // CHECK: mhlo.constant dense<[7, 9]> - %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<4xi64>) -> (tensor<2xi64>) - func.return %1 : tensor<2xi64> +// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_2 +func.func @dynamic_broadcast_in_dim_to_same_shape_2(%arg0: tensor) -> tensor { + // CHECK-SAME: %[[ARG:.*]]: tensor + %0 = shape.shape_of %arg0 : tensor -> !shape.shape + %1 = shape.to_extent_tensor %0 : !shape.shape -> tensor<1xindex> + %2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %1) <{ broadcast_dimensions = dense<0> : tensor<1xi64> }> : (tensor, tensor<1xindex>) -> tensor + // CHECK: return %[[ARG]] : tensor + func.return %2 : tensor } -// CHECK-LABEL: slice_1D_fp -func.func @slice_1D_fp() -> tensor<2xf32> { - %0 = mhlo.constant dense<[5.0, 7.0, 9.0, 10.0]> : tensor<4xf32> - // CHECK: mhlo.constant dense<[7.000000e+00, 9.000000e+00]> - %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<4xf32>) -> (tensor<2xf32>) - func.return %1 : tensor<2xf32> +// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_3 +func.func @dynamic_broadcast_in_dim_to_same_shape_3(%arg0: tensor) -> tensor { + // CHECK-SAME: %[[ARG:.*]]: tensor + %0 = shape.shape_of %arg0 : tensor -> tensor + %1 = tensor.cast %0 : tensor to tensor<1xindex> + %2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %1) <{ broadcast_dimensions = dense<0> : tensor<1xi64> }> : (tensor, tensor<1xindex>) -> tensor + // CHECK: return %[[ARG]] : tensor + func.return %2 : tensor } -// CHECK-LABEL: slice_1D_strided_fold -func.func @slice_1D_strided_fold() -> tensor<2xi64> { - %0 = mhlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64> - // CHECK: mhlo.constant dense<[7, 10]> - %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[4]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}> : (tensor<4xi64>) -> (tensor<2xi64>) - func.return %1 : tensor<2xi64> -} - -// CHECK-LABEL: slice_2D_fold -func.func @slice_2D_fold() -> tensor<2x2xi64> { - %0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64> - // CHECK-NEXT: mhlo.constant dense<[ - // CHECK-SAME: [6, 7], - // CHECK-SAME: [10, 11] - // CHECK-SAME: ]> - %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[3, 4]> : tensor<2xi64>, start_indices = dense<[1, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x4xi64>) -> (tensor<2x2xi64>) - func.return %1 : tensor<2x2xi64> -} - -// CHECK-LABEL: slice_2D_fold_horizontal -func.func @slice_2D_fold_horizontal() -> tensor<1x4xi64> { - %0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64> - // CHECK-NEXT: mhlo.constant dense<[ - // CHECK-SAME: [0, 1, 2, 3] - // CHECK-SAME: ]> - %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x4xi64>) -> (tensor<1x4xi64>) - func.return %1 : tensor<1x4xi64> -} - -// CHECK-LABEL: slice_2D_fold_vertical -func.func @slice_2D_fold_vertical() -> tensor<4x1xi64> { - %0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64> - // CHECK-NEXT: mhlo.constant dense<[ - // CHECK-SAME: [2], [6], [10], [14] - // CHECK-SAME: ]> - %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x4xi64>) -> (tensor<4x1xi64>) - func.return %1 : tensor<4x1xi64> -} - -// CHECK-LABEL: slice_zero_elements -func.func @slice_zero_elements() -> tensor<0xi64> { - %0 = mhlo.constant dense<> : tensor<0xi64> - // CHECK: %[[CONST:.*]] = mhlo.constant dense<> : tensor<0xi64> - %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[0]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<0xi64>) -> (tensor<0xi64>) - // CHECK: return %[[CONST]] : tensor<0xi64> - func.return %1 : tensor<0xi64> -} - -// CHECK-LABEL: slice_unknown_shape -func.func @slice_unknown_shape(%arg0: tensor) -> tensor { - // CHECK: "mhlo.slice"(%arg0) <{limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor) -> tensor - %0 = "mhlo.slice"(%arg0) <{limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: slice_concat_fold_first -func.func @slice_concat_fold_first(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> { - %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> - %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<2x5xf32>) -> (tensor<1x5xf32>) - // CHECK: return %arg0 - func.return %1 : tensor<1x5xf32> -} - -// CHECK-LABEL: slice_concat_fold_second -func.func @slice_concat_fold_second(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> { - %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> - %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<2x5xf32>) -> (tensor<1x5xf32>) - // CHECK: return %arg1 - func.return %1 : tensor<1x5xf32> -} - -// CHECK-LABEL: slice_concat_fold_second_with_slice -func.func @slice_concat_fold_second_with_slice(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x4xf32> { - %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> - // CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) <{limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<1x5xf32>) -> tensor<1x4xf32> - %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<2x5xf32>) -> (tensor<1x4xf32>) - - // CHECK: return [[SLICE]] - func.return %1 : tensor<1x4xf32> -} - -// CHECK-LABEL: slice_concat_fold_middle -func.func @slice_concat_fold_middle(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<1x5xf32> { - %0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) <{ dimension = 0 : i64 }> : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32> - // CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) <{limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> - %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x5xf32>) -> (tensor<1x5xf32>) - - // CHECK: return [[SLICE]] - func.return %1 : tensor<1x5xf32> -} - -// CHECK-LABEL: slice_concat_fold_two -func.func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<2x5xf32> { - // CHECK: [[CONCAT:%.+]] = "mhlo.concatenate"(%arg1, %arg2) <{dimension = 0 : i64}> - %0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) <{ dimension = 0 : i64 }> : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32> - - // CHECK: [[SLICE:%.+]] = "mhlo.slice"([[CONCAT]]) <{limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> - %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[4, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x5xf32>) -> (tensor<2x5xf32>) - - // CHECK: return [[SLICE]] - func.return %1 : tensor<2x5xf32> -} - -// CHECK-LABEL: slice_concat_empty -func.func @slice_concat_empty(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<1x5xf32> { - %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> - %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<2x5xf32>) -> (tensor<0x5xf32>) - %2 = "mhlo.concatenate"(%1, %arg2) <{ dimension = 0 : i64 }> : (tensor<0x5xf32>, tensor<1x5xf32>) -> tensor<1x5xf32> - - // CHECK: return %arg2 - func.return %2 : tensor<1x5xf32> -} - -// CHECK-LABEL: func @broadcast_identity -func.func @broadcast_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { - // CHECK: return %arg0 - %0 = "mhlo.broadcast"(%arg0) <{broadcast_sizes = dense<[]> : tensor<0xi64>}> : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32> - func.return %0 : tensor<2x3x4xf32> -} - -// CHECK-LABEL: func @broadcast_dynamic_shape_identity -func.func @broadcast_dynamic_shape_identity(%arg0: tensor) -> tensor { - // CHECK: return %arg0 - %0 = "mhlo.broadcast"(%arg0) <{broadcast_sizes = dense<[]> : tensor<0xi64>}> : (tensor) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: func @broadcast_dynamic_shape_not_identity -func.func @broadcast_dynamic_shape_not_identity(%arg0: tensor) -> tensor<20x?x?x?xf32> { - // CHECK: mhlo.broadcast - %0 = "mhlo.broadcast"(%arg0) <{broadcast_sizes = dense<[20]> : tensor<1xi64>}> : (tensor) -> tensor<20x?x?x?xf32> - func.return %0 : tensor<20x?x?x?xf32> -} - -// CHECK-LABEL: func @broadcast_constant_fold_0d -func.func @broadcast_constant_fold_0d() -> tensor<1x64x224x224xf32> { - %cst = mhlo.constant dense<0.000000e+00> : tensor - %b = "mhlo.broadcast"(%cst) <{broadcast_sizes = dense<[1, 64, 224, 224]> : tensor<4xi64>}> : (tensor) -> tensor<1x64x224x224xf32> - func.return %b : tensor<1x64x224x224xf32> -} -// CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<0.000000e+00> : tensor<1x64x224x224xf32> -// CHECK-NEXT: return %[[CST]] : tensor<1x64x224x224xf32> - -// CHECK-LABEL: func @broadcast_constant_fold -func.func @broadcast_constant_fold() -> tensor<1x64x4x4xf32> { - %cst = mhlo.constant dense<0.000000e+00> : tensor<4x4xf32> - %b = "mhlo.broadcast"(%cst) <{broadcast_sizes = dense<[1, 64]> : tensor<2xi64>}> : (tensor<4x4xf32>) -> tensor<1x64x4x4xf32> - func.return %b : tensor<1x64x4x4xf32> -} -// CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<0.000000e+00> : tensor<1x64x4x4xf32> -// CHECK-NEXT: return %[[CST]] : tensor<1x64x4x4xf32> - -// CHECK-LABEL: func @broadcast_constant_fold_not_splat -func.func @broadcast_constant_fold_not_splat() -> tensor<1x64x2xf32> { - // CHECK: mhlo.constant - %cst = mhlo.constant dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32> - // CHECK: mhlo.broadcast - %b = "mhlo.broadcast"(%cst) <{broadcast_sizes = dense<[1, 64]> : tensor<2xi64>}> : (tensor<2xf32>) -> tensor<1x64x2xf32> - func.return %b : tensor<1x64x2xf32> -} - -// CHECK-LABEL: func @broadcast_constant_fold_complex -func.func @broadcast_constant_fold_complex() -> tensor<1x64x224x224xcomplex> { - %cst = mhlo.constant dense<(0.000000e+00,1.000000e+00)> : tensor> - %b = "mhlo.broadcast"(%cst) <{broadcast_sizes = dense<[1, 64, 224, 224]> : tensor<4xi64>}> : (tensor>) -> tensor<1x64x224x224xcomplex> - func.return %b : tensor<1x64x224x224xcomplex> -} -// CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<(0.000000e+00,1.000000e+00)> : tensor<1x64x224x224xcomplex> -// CHECK-NEXT: return %[[CST]] : tensor<1x64x224x224xcomplex> - -// CHECK-LABEL: func @broadcast_constant_fold_quantized_skipped -func.func @broadcast_constant_fold_quantized_skipped() -> tensor<1x64x224x224x!quant.uniform> { - %cst = mhlo.constant() {value = dense<2> : tensor} : () -> tensor> - %b = "mhlo.broadcast"(%cst) <{broadcast_sizes = dense<[1, 64, 224, 224]> : tensor<4xi64>}> : (tensor>) -> tensor<1x64x224x224x!quant.uniform> - func.return %b : tensor<1x64x224x224x!quant.uniform> -} -// CHECK-NEXT: %[[CST:.*]] = mhlo.constant() <{value = dense<2> : tensor}> : () -> tensor> -// CHECK-NEXT: %[[RES:.*]] = "mhlo.broadcast"(%[[CST:.*]]) <{broadcast_sizes = dense<[1, 64, 224, 224]> : tensor<4xi64>}> : (tensor>) -> tensor<1x64x224x224x!quant.uniform> -// CHECK-NEXT: return %[[RES:.*]] : tensor<1x64x224x224x!quant.uniform> - -// CHECK-LABEL: func @broadcast_in_dim_identity -func.func @broadcast_in_dim_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { - // CHECK: return %arg0 - %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}> : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32> - func.return %0 : tensor<2x3x4xf32> -} - -// CHECK-LABEL: func @broadcast_in_dim_equivalent_reshape -func.func @broadcast_in_dim_equivalent_reshape(%arg0: tensor<2x3x4xf32>) -> tensor<1x2x3x4xf32> { - // CHECK: mhlo.reshape - %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>}> : (tensor<2x3x4xf32>) -> tensor<1x2x3x4xf32> - func.return %0 : tensor<1x2x3x4xf32> -} - -// CHECK-LABEL: func @broadcast_in_dim_not_identity_because_it_actually_broadcasts -func.func @broadcast_in_dim_not_identity_because_it_actually_broadcasts(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> { - // CHECK: mhlo.broadcast_in_dim - %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x2xf32>) -> tensor<2x2xf32> - func.return %0 : tensor<2x2xf32> -} - -// CHECK-LABEL: func @broadcast_in_dim_equivalent_transpose -func.func @broadcast_in_dim_equivalent_transpose(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { - // CHECK: mhlo.transpose - // CHECK-SAME: permutation = dense<[1, 0]> - %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>}> : (tensor<2x2xf32>) -> tensor<2x2xf32> - func.return %0 : tensor<2x2xf32> -} - -// CHECK-LABEL: func @broadcast_in_dim_constant_fold_quantized_skipped -func.func @broadcast_in_dim_constant_fold_quantized_skipped(%arg0: tensor<1x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> { - %b = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> - func.return %b : tensor<2x2x!quant.uniform> -} -// CHECK-NEXT: %[[RES:.*]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> -// CHECK-NEXT: return %[[RES:.*]] : tensor<2x2x!quant.uniform> - -// CHECK-LABEL: func @broadcast_consecutive -func.func @broadcast_consecutive(%arg0: tensor<2x3xf32>) -> tensor<2x3x4x5xf32> { - // CHECK: mhlo.broadcast_in_dim - // CHECK-SAME: broadcast_dimensions = dense<[0, 1]> - // CHECK-NEXT: return - %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<2x3xf32>) -> tensor<2x3x4xf32> - %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}> : (tensor<2x3x4xf32>) -> tensor<2x3x4x5xf32> - func.return %1 : tensor<2x3x4x5xf32> -} - -// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic -func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %arg1: tensor<2xi64>) -> tensor<5x4xf32> { - // CHECK: %[[RESULT:.+]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<4xf32>) -> tensor<5x4xf32> - %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) <{ broadcast_dimensions = dense<1> : tensor<1xi64> }> : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32> - // CHECK: return %[[RESULT]] : tensor<5x4xf32> - func.return %0 : tensor<5x4xf32> -} - -// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_shape -func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_shape(%arg0: tensor) -> tensor<4x32xi32> { - %0 = mhlo.constant dense<[4, 32]> : tensor<2xi32> - // CHECK: %[[RESULT:.+]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4x32xi32> - %1 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<2xi32>) -> tensor - %2 = "mhlo.dynamic_reshape"(%1, %0) : (tensor, tensor<2xi32>) -> tensor<4x32xi32> - // CHECK: return %[[RESULT]] : tensor<4x32xi32> - func.return %2 : tensor<4x32xi32> -} - -// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_index_shape -func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_index_shape(%arg0: tensor) -> tensor<4x32xf32> { - %0 = shape.const_shape [4, 32] : tensor<2xindex> - // CHECK: %[[RESULT:.+]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4x32xf32> - %1 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<2xindex>) -> tensor - %2 = "mhlo.dynamic_reshape"(%1, %0) : (tensor, tensor<2xindex>) -> tensor<4x32xf32> - // CHECK: return %[[RESULT]] : tensor<4x32xf32> - func.return %2 : tensor<4x32xf32> -} - -// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_requires_cast -func.func @dynamic_broadcast_in_dim_op_not_actually_dynamic_constant_requires_cast(%arg0: tensor) -> tensor { - %0 = shape.const_shape [4, 32] : tensor<2xindex> - // CHECK: %[[BCAST:.+]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4x32xf32> - %1 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor, tensor<2xindex>) -> tensor - // CHECK: %[[RESULT:.*]] = tensor.cast %[[BCAST]] : tensor<4x32xf32> to tensor - // CHECK: return %[[RESULT]] : tensor - func.return %1 : tensor -} - -// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_almost_not_actually_dynamic -func.func @dynamic_broadcast_in_dim_op_almost_not_actually_dynamic(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor<5x4xf32> { - // CHECK: %[[RESULT:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor, tensor<2xi64>) -> tensor<5x4xf32> - %0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) <{ broadcast_dimensions = dense<1> : tensor<1xi64> }> : (tensor, tensor<2xi64>) -> tensor<5x4xf32> - // CHECK: return %[[RESULT]] : tensor<5x4xf32> - func.return %0 : tensor<5x4xf32> -} - -// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_1 -func.func @dynamic_broadcast_in_dim_to_same_shape_1(%arg0: tensor) -> tensor { - // CHECK-SAME: %[[ARG:.*]]: tensor - %0 = shape.shape_of %arg0 : tensor -> tensor<1xindex> - %2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %0) <{ broadcast_dimensions = dense<0> : tensor<1xi64> }> : (tensor, tensor<1xindex>) -> tensor - // CHECK: return %[[ARG]] : tensor - func.return %2 : tensor -} - -// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_2 -func.func @dynamic_broadcast_in_dim_to_same_shape_2(%arg0: tensor) -> tensor { - // CHECK-SAME: %[[ARG:.*]]: tensor - %0 = shape.shape_of %arg0 : tensor -> !shape.shape - %1 = shape.to_extent_tensor %0 : !shape.shape -> tensor<1xindex> - %2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %1) <{ broadcast_dimensions = dense<0> : tensor<1xi64> }> : (tensor, tensor<1xindex>) -> tensor - // CHECK: return %[[ARG]] : tensor - func.return %2 : tensor -} - -// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_3 -func.func @dynamic_broadcast_in_dim_to_same_shape_3(%arg0: tensor) -> tensor { - // CHECK-SAME: %[[ARG:.*]]: tensor - %0 = shape.shape_of %arg0 : tensor -> tensor - %1 = tensor.cast %0 : tensor to tensor<1xindex> - %2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %1) <{ broadcast_dimensions = dense<0> : tensor<1xi64> }> : (tensor, tensor<1xindex>) -> tensor - // CHECK: return %[[ARG]] : tensor - func.return %2 : tensor -} - -// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_4 -func.func @dynamic_broadcast_in_dim_to_same_shape_4(%arg0: tensor) -> tensor { - // CHECK-SAME: %[[ARG:.*]]: tensor - %0 = shape.shape_of %arg0 : tensor -> !shape.shape - %1 = shape.to_extent_tensor %0 : !shape.shape -> tensor - %2 = tensor.cast %1 : tensor to tensor<1xindex> - %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2) <{ broadcast_dimensions = dense<0> : tensor<1xi64> }> : (tensor, tensor<1xindex>) -> tensor - // CHECK: return %[[ARG]] : tensor - func.return %3 : tensor +// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape_4 +func.func @dynamic_broadcast_in_dim_to_same_shape_4(%arg0: tensor) -> tensor { + // CHECK-SAME: %[[ARG:.*]]: tensor + %0 = shape.shape_of %arg0 : tensor -> !shape.shape + %1 = shape.to_extent_tensor %0 : !shape.shape -> tensor + %2 = tensor.cast %1 : tensor to tensor<1xindex> + %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2) <{ broadcast_dimensions = dense<0> : tensor<1xi64> }> : (tensor, tensor<1xindex>) -> tensor + // CHECK: return %[[ARG]] : tensor + func.return %3 : tensor } // CHECK-LABEL: func @dynamic_broadcast_in_dim_all_dims_non_expanding @@ -842,52 +323,50 @@ func.func @dynamic_broadcast_in_dim_all_dims_non_expanding(%arg0: tensor, func.return %1 : tensor } -// CHECK-LABEL: func @broadcast_in_dim_constant_fold_0d -func.func @broadcast_in_dim_constant_fold_0d() -> tensor<1x64x224x224xf32> { - %cst = mhlo.constant dense<0.000000e+00> : tensor - %b = "mhlo.broadcast_in_dim"(%cst) <{broadcast_dimensions = dense<[]> : tensor<0xi64>}> : (tensor) -> tensor<1x64x224x224xf32> - func.return %b : tensor<1x64x224x224xf32> -} -// CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<0.000000e+00> : tensor<1x64x224x224xf32> -// CHECK-NEXT: return %[[CST]] : tensor<1x64x224x224xf32> - -// CHECK-LABEL: func @broadcast_in_dim_constant_fold -func.func @broadcast_in_dim_constant_fold() -> tensor<1x64x4x4xf32> { - %cst = mhlo.constant dense<0.000000e+00> : tensor<4x4xf32> - %b = "mhlo.broadcast_in_dim"(%cst) <{broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>}> : (tensor<4x4xf32>) -> tensor<1x64x4x4xf32> - func.return %b : tensor<1x64x4x4xf32> +// CHECK-LABEL: @broadcast_of_reshape +func.func @broadcast_of_reshape(%arg: tensor, + %shape: tensor<2xindex>) -> tensor { + // CHECK: [[RESHAPE:%.*]] = mhlo.dynamic_reshape + // CHECK: return [[RESHAPE]] + %0 = "mhlo.dynamic_reshape"(%arg, %shape) : (tensor, tensor<2xindex>) -> tensor + %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %shape) { broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> } : (tensor, tensor<2xindex>) -> tensor + func.return %1 : tensor } -// CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<0.000000e+00> : tensor<1x64x4x4xf32> -// CHECK-NEXT: return %[[CST]] : tensor<1x64x4x4xf32> -// CHECK-LABEL: func @broadcast_in_dim_constant_fold_complex -func.func @broadcast_in_dim_constant_fold_complex() -> tensor<1x64x224x224xcomplex> { - %cst = mhlo.constant dense<(0.000000e+00,1.000000e+00)> : tensor> - %b = "mhlo.broadcast_in_dim"(%cst) <{broadcast_dimensions = dense<[]> : tensor<0xi64>}> : (tensor>) -> tensor<1x64x224x224xcomplex> - func.return %b : tensor<1x64x224x224xcomplex> +// CHECK-LABEL: @permutation_broadcast_of_reshape +func.func @permutation_broadcast_of_reshape(%arg: tensor, + %shape: tensor<2xindex>) -> tensor { + // CHECK: mhlo.dynamic_reshape + // CHECK: mhlo.dynamic_broadcast_in_dim + %0 = "mhlo.dynamic_reshape"(%arg, %shape) : (tensor, tensor<2xindex>) -> tensor + %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %shape) { broadcast_dimensions = dense<[1, 0]> : tensor<2xi64> } : (tensor, tensor<2xindex>) -> tensor + func.return %1 : tensor } -// CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<(0.000000e+00,1.000000e+00)> : tensor<1x64x224x224xcomplex> -// CHECK-NEXT: return %[[CST]] : tensor<1x64x224x224xcomplex> +//////// +// DynamicGatherOp -// CHECK-LABEL: @complex_expand_fold -func.func @complex_expand_fold(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { - %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex>) - %1 = mhlo.real %0 : (tensor<4xcomplex>) -> (tensor<4xf32>) - %2 = "mhlo.imag"(%0) : (tensor<4xcomplex>) -> (tensor<4xf32>) - // CHECK: return %arg0, %arg1 - func.return %1, %2 : tensor<4xf32>, tensor<4xf32> +// CHECK-LABEL: @simplify_dynamic_gather_i64 +func.func @simplify_dynamic_gather_i64(%arg0: tensor<375682x256xf16>, %arg1: tensor<16x64xi64>) -> tensor<16x64x256xf16> { + %0 = "arith.constant"() {value = dense<[1, 256]> : tensor<2xi64>} : () -> tensor<2xi64> + %1 = "mhlo.dynamic_gather"(%arg0, %arg1, %0) {dimension_numbers = #mhlo.gather, indices_are_sorted = false} : (tensor<375682x256xf16>, tensor<16x64xi64>, tensor<2xi64>) -> tensor<16x64x256xf16> + // CHECK: %[[RET:.+]] = "mhlo.gather"(%arg0, %arg1) <{dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[1, 256]> : tensor<2xi64>}> : (tensor<375682x256xf16>, tensor<16x64xi64>) -> tensor<16x64x256xf16> + // CHECK: return %[[RET]] + return %1 : tensor<16x64x256xf16> } -// CHECK-LABEL: @complex_collapse_fold -func.func @complex_collapse_fold(%arg0: tensor<4xcomplex>) -> tensor<4xcomplex> { - %0 = mhlo.real %arg0 : (tensor<4xcomplex>) -> (tensor<4xf32>) - %1 = "mhlo.imag"(%arg0) : (tensor<4xcomplex>) -> (tensor<4xf32>) - %2 = "mhlo.complex"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> - // CHECK: return %arg0 - func.return %2 : tensor<4xcomplex> +// CHECK-LABEL: @simplify_dynamic_gather_i32 +func.func @simplify_dynamic_gather_i32(%arg0: tensor<375682x256xf16>, %arg1: tensor<16x64xi64>) -> tensor<16x64x256xf16> { + %0 = "arith.constant"() {value = dense<[1, 256]> : tensor<2xi32>} : () -> tensor<2xi32> + %1 = "mhlo.dynamic_gather"(%arg0, %arg1, %0) {dimension_numbers = #mhlo.gather, indices_are_sorted = false} : (tensor<375682x256xf16>, tensor<16x64xi64>, tensor<2xi32>) -> tensor<16x64x256xf16> + // CHECK: %[[RET:.+]] = "mhlo.gather"(%arg0, %arg1) <{dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[1, 256]> : tensor<2xi64>}> : (tensor<375682x256xf16>, tensor<16x64xi64>) -> tensor<16x64x256xf16> + // CHECK: return %[[RET]] + return %1 : tensor<16x64x256xf16> } +//////// +// DynamicIotaOp + // CHECK-LABEL: @dynamic_iota_is_static func.func @dynamic_iota_is_static(%arg0 : tensor<1xindex>) -> tensor<4xi32> { // CHECK: [[RESULT:%.*]] = "mhlo.iota" @@ -929,57 +408,8 @@ func.func @dynamic_iota_constant(%arg0 : tensor<2xindex>) -> tensor<1x?xi32> { func.return %0 : tensor<1x?xi32> } -// CHECK-LABEL: @iota_constant -func.func @iota_constant() -> tensor<1xi32> { - // CHECK: [[CONST:%.+]] = mhlo.constant dense<0> : tensor<1xi32> - %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<1xi32> - - // CHECK: return [[CONST]] : tensor<1xi32> - func.return %0 : tensor<1xi32> -} - -// CHECK-LABEL: @iota_constant_multi -func.func @iota_constant_multi() -> tensor<1x4xi32> { - // CHECK: [[CONST:%.+]] = mhlo.constant dense<0> : tensor<1x4xi32> - %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<1x4xi32> - - // CHECK: return [[CONST]] : tensor<1x4xi32> - func.return %0 : tensor<1x4xi32> -} - -// CHECK-LABEL: @iota_not_lowered_to_constant -func.func @iota_not_lowered_to_constant() -> tensor<4xi32> { - // CHECK: [[RESULT:%.*]] = "mhlo.iota" - // CHECK: return [[RESULT]] - %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<4xi32> - func.return %0 : tensor<4xi32> -} - -// CHECK-LABEL: @iota_broadcast -func.func @iota_broadcast() -> tensor<5x4xi32> { - // CHECK: [[IOTA:%.+]] = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<5xi32> - // CHECK: [[RESULT:%.+]] = "mhlo.broadcast_in_dim"([[IOTA]]) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<5xi32>) -> tensor<5x4xi32> - %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<5x4xi32> - - func.return %0 : tensor<5x4xi32> -} - -// CHECK-LABEL: @iota_broadcast -func.func @iota_broadcast_second() -> tensor<5x4xi32> { - // CHECK: [[IOTA:%.+]] = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<4xi32> - // CHECK: [[RESULT:%.+]] = "mhlo.broadcast_in_dim"([[IOTA]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<4xi32>) -> tensor<5x4xi32> - %0 = "mhlo.iota"() <{iota_dimension = 1 : i64}> : () -> tensor<5x4xi32> - - func.return %0 : tensor<5x4xi32> -} - -// CHECK-LABEL: func @fold_copy -// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] -func.func @fold_copy(%arg : tensor<1x4xf32>) -> tensor<1x4xf32> { - // CHECK: return [[ARG]] - %0 = "mhlo.copy"(%arg) : (tensor<1x4xf32>) -> tensor<1x4xf32> - func.return %0 : tensor<1x4xf32> -} +//////// +// DynamicReshapeOp // CHECK-LABEL: func @dynamic_reshape_not_actually_dynamic func.func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<2xindex>) -> tensor<4x1xf32> { @@ -1010,495 +440,25 @@ func.func @dynamic_reshape_rank_1_to_rank_1(%arg0: tensor>, %3 = tensor.from_elements %2 : tensor<1xindex> %4 = "mhlo.dynamic_reshape"(%0, %3) : (tensor, tensor<1xindex>) -> tensor - func.return %4 : tensor -} - -// CHECK-LABEL: func @dynamic_reshape_of_dynamic_reshape -// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]] -// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]] -func.func @dynamic_reshape_of_dynamic_reshape(%arg0: tensor, %shape: tensor) -> tensor { - // CHECK: return [[ARG0]] - %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor) -> tensor - %1 = shape.shape_of %0 : tensor -> tensor - %2 = shape.num_elements %1 : tensor -> index - %3 = tensor.from_elements %2 : tensor<1xindex> - %4 = "mhlo.dynamic_reshape"(%0, %3) : (tensor, tensor<1xindex>) -> tensor - func.return %4 : tensor -} - -// CHECK-LABEL: do_not_dce_while_with_outfeed -func.func @do_not_dce_while_with_outfeed(%arg0: tensor) -> tensor { - // CHECK: mhlo.while - %0 = "mhlo.while"(%arg0) ({ - ^bb0(%arg1: tensor): - %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - "mhlo.return"(%1) : (tensor) -> () - }, { - ^bb0(%arg1: tensor): - %1 = "mhlo.create_token"() : () -> !mhlo.token - // Side-effecting op outfeed present inside while. - %2 = "mhlo.outfeed"(%arg1, %1) {outfeed_config = ""} : (tensor, !mhlo.token) -> !mhlo.token - "mhlo.return"(%arg1) : (tensor) -> () - }) : (tensor) -> tensor - - func.return %arg0 : tensor -} - -// CHECK-LABEL: dce_while_without_side_effect -func.func @dce_while_without_side_effect(%arg0: tensor) -> tensor { - // CHECK-NOT: mhlo.while - %0 = "mhlo.while"(%arg0) ({ - ^bb0(%arg1: tensor): - %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - "mhlo.return"(%1) : (tensor) -> () - }, { - ^bb0(%arg1: tensor): - %1 = "mhlo.create_token"() : () -> !mhlo.token - "mhlo.return"(%arg1) : (tensor) -> () - }) : (tensor) -> tensor - - func.return %arg0 : tensor -} - -// CHECK-LABEL: fold_sign_posi -func.func @fold_sign_posi() -> tensor { - // CHECK: %0 = mhlo.constant dense<1> : tensor - %0 = mhlo.constant dense<2> : tensor - %1 = "mhlo.sign"(%0) : (tensor) -> tensor - func.return %1 : tensor -} - -// CHECK-LABEL: fold_sign_negi -func.func @fold_sign_negi() -> tensor { - // CHECK: %0 = mhlo.constant dense<-1> : tensor - %0 = mhlo.constant dense<-2> : tensor - %1 = "mhlo.sign"(%0) : (tensor) -> tensor - func.return %1 : tensor -} - -// CHECK-LABEL: fold_sign_posf -func.func @fold_sign_posf() -> tensor { - // CHECK: %0 = mhlo.constant dense<1.000000e+00> : tensor - %0 = mhlo.constant dense<2.000000e+00> : tensor - %1 = "mhlo.sign"(%0) : (tensor) -> tensor - func.return %1 : tensor -} - -// CHECK-LABEL: fold_sign_negf -func.func @fold_sign_negf() -> tensor { - // CHECK: %0 = mhlo.constant dense<-1.000000e+00> : tensor - %0 = mhlo.constant dense<-2.000000e+00> : tensor - %1 = "mhlo.sign"(%0) : (tensor) -> tensor - func.return %1 : tensor -} - -// CHECK-LABEL: fold_sign_negzf -func.func @fold_sign_negzf() -> tensor { - // CHECK: %0 = mhlo.constant dense<-0.000000e+00> : tensor - %0 = mhlo.constant dense<-0.000000e+00> : tensor - %1 = "mhlo.sign"(%0) : (tensor) -> tensor - func.return %1 : tensor -} - -// CHECK-LABEL: fold_compare_same_eq -func.func @fold_compare_same_eq(%arg0: tensor) -> tensor { - // CHECK: %0 = mhlo.constant dense : tensor - %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: fold_compare_same_le -func.func @fold_compare_same_le(%arg0: tensor) -> tensor { - // CHECK: %0 = mhlo.constant dense : tensor - %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: fold_compare_same_ge -func.func @fold_compare_same_ge(%arg0: tensor) -> tensor { - // CHECK: %0 = mhlo.constant dense : tensor - %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %0 : tensor -} -// CHECK-LABEL: fold_compare_same_ne -func.func @fold_compare_same_ne(%arg0: tensor) -> tensor { - // CHECK: %0 = mhlo.constant dense : tensor - %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: fold_compare_same_lt -func.func @fold_compare_same_lt(%arg0: tensor) -> tensor { - // CHECK: %0 = mhlo.constant dense : tensor - %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: fold_compare_same_gt -func.func @fold_compare_same_gt(%arg0: tensor) -> tensor { - // CHECK: %0 = mhlo.constant dense : tensor - %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// Address NaN != NaN. -// CHECK-LABEL: dont_fold_compare_same_eq_float -func.func @dont_fold_compare_same_eq_float(%arg0: tensor) -> tensor { - // CHECK: %0 = mhlo.compare EQ, %arg0, %arg0 : (tensor, tensor) -> tensor - %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %0 : tensor -} - -// Address NaN != NaN for complex types. -// CHECK-LABEL: dont_fold_compare_same_eq_complex -func.func @dont_fold_compare_same_eq_complex(%arg0: tensor>) -> tensor { - // CHECK: %0 = mhlo.compare EQ, %arg0, %arg0 : (tensor>, tensor>) -> tensor - %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = #mhlo} : (tensor>, tensor>) -> tensor - func.return %0 : tensor -} - - -// CHECK-LABEL: fold_compare_false_eq -func.func @fold_compare_false_eq() -> tensor { - %0 = mhlo.constant dense<0> : tensor - %1 = mhlo.constant dense<1> : tensor - // CHECK: %0 = mhlo.constant dense : tensor - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} -// CHECK-LABEL: fold_compare_true_eq -func.func @fold_compare_true_eq() -> tensor { - %0 = mhlo.constant dense<1> : tensor - %1 = mhlo.constant dense<1> : tensor - // CHECK: %0 = mhlo.constant dense : tensor - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: fold_compare_bools_true_eq -func.func @fold_compare_bools_true_eq(%arg : tensor) -> tensor { - %1 = mhlo.constant dense : tensor - // CHECK: return %arg - %2 = "mhlo.compare"(%arg, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: compare_i1_as_unsigned -func.func @compare_i1_as_unsigned(%arg : tensor) -> tensor { - %true = mhlo.constant dense : tensor - %false = mhlo.constant dense : tensor - // CHECK: %[[FALSE:.*]] = mhlo.constant dense - // CHECK: return %[[FALSE]] - %2 = "mhlo.compare"(%true, %false) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: fold_compare_false_eq_float -func.func @fold_compare_false_eq_float() -> tensor { - %0 = mhlo.constant dense<0.> : tensor - %1 = mhlo.constant dense<1.> : tensor - // CHECK: %0 = mhlo.constant dense : tensor - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: fold_compare_true_eq_float -func.func @fold_compare_true_eq_float() -> tensor { - %0 = mhlo.constant dense<1.> : tensor - %1 = mhlo.constant dense<1.> : tensor - // CHECK: %0 = mhlo.constant dense : tensor - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: fold_compare_false_ne -func.func @fold_compare_false_ne() -> tensor { - %0 = mhlo.constant dense<1> : tensor - %1 = mhlo.constant dense<1> : tensor - // CHECK: %0 = mhlo.constant dense : tensor - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: fold_compare_true_ne -func.func @fold_compare_true_ne() -> tensor { - %0 = mhlo.constant dense<1> : tensor - %1 = mhlo.constant dense<0> : tensor - // CHECK: %0 = mhlo.constant dense : tensor - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: fold_compare_bools_false_ne -func.func @fold_compare_bools_false_ne(%arg : tensor) -> tensor { - %1 = mhlo.constant dense : tensor - // CHECK: return %arg - %2 = "mhlo.compare"(%arg, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: fold_compare_false_ne_float -func.func @fold_compare_false_ne_float() -> tensor { - %0 = mhlo.constant dense<1.> : tensor - %1 = mhlo.constant dense<1.> : tensor - // CHECK: %0 = mhlo.constant dense : tensor - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: fold_compare_true_ne_float -func.func @fold_compare_true_ne_float() -> tensor { - %0 = mhlo.constant dense<0.> : tensor - %1 = mhlo.constant dense<1.> : tensor - // CHECK: %0 = mhlo.constant dense : tensor - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: fold_compare_false_lt -func.func @fold_compare_false_lt() -> tensor { - %0 = mhlo.constant dense<1> : tensor - %1 = mhlo.constant dense<1> : tensor - // CHECK: %0 = mhlo.constant dense : tensor - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: fold_compare_true_lt -func.func @fold_compare_true_lt() -> tensor { - %0 = mhlo.constant dense<0> : tensor - %1 = mhlo.constant dense<1> : tensor - // CHECK: %0 = mhlo.constant dense : tensor - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: fold_compare_false_lt_float -func.func @fold_compare_false_lt_float() -> tensor { - %0 = mhlo.constant dense<1.> : tensor - %1 = mhlo.constant dense<1.> : tensor - // CHECK: %0 = mhlo.constant dense : tensor - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: fold_compare_true_lt_float -func.func @fold_compare_true_lt_float() -> tensor { - %0 = mhlo.constant dense<0.> : tensor - %1 = mhlo.constant dense<1.> : tensor - // CHECK: %0 = mhlo.constant dense : tensor - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: fold_compare_false_le -func.func @fold_compare_false_le() -> tensor { - %0 = mhlo.constant dense<1> : tensor - %1 = mhlo.constant dense<0> : tensor - // CHECK: %0 = mhlo.constant dense : tensor - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: fold_compare_true_le -func.func @fold_compare_true_le() -> tensor { - %0 = mhlo.constant dense<1> : tensor - %1 = mhlo.constant dense<1> : tensor - // CHECK: %0 = mhlo.constant dense : tensor - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: fold_compare_false_le_float -func.func @fold_compare_false_le_float() -> tensor { - %0 = mhlo.constant dense<1.> : tensor - %1 = mhlo.constant dense<0.> : tensor - // CHECK: %0 = mhlo.constant dense : tensor - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: fold_compare_true_le_float -func.func @fold_compare_true_le_float() -> tensor { - %0 = mhlo.constant dense<1.> : tensor - %1 = mhlo.constant dense<1.> : tensor - // CHECK: %0 = mhlo.constant dense : tensor - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: fold_compare_false_gt -func.func @fold_compare_false_gt() -> tensor { - %0 = mhlo.constant dense<0> : tensor - %1 = mhlo.constant dense<0> : tensor - // CHECK: %0 = mhlo.constant dense : tensor - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: fold_compare_true_gt -func.func @fold_compare_true_gt() -> tensor { - %0 = mhlo.constant dense<1> : tensor - %1 = mhlo.constant dense<0> : tensor - // CHECK: %0 = mhlo.constant dense : tensor - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: fold_compare_false_gt_float -func.func @fold_compare_false_gt_float() -> tensor { - %0 = mhlo.constant dense<0.> : tensor - %1 = mhlo.constant dense<0.> : tensor - // CHECK: %0 = mhlo.constant dense : tensor - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: fold_compare_true_gt_float -func.func @fold_compare_true_gt_float() -> tensor { - %0 = mhlo.constant dense<1.> : tensor - %1 = mhlo.constant dense<0.> : tensor - // CHECK: %0 = mhlo.constant dense : tensor - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: fold_compare_false_ge -func.func @fold_compare_false_ge() -> tensor { - %0 = mhlo.constant dense<0> : tensor - %1 = mhlo.constant dense<1> : tensor - // CHECK: %0 = mhlo.constant dense : tensor - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: fold_compare_true_ge -func.func @fold_compare_true_ge() -> tensor { - %0 = mhlo.constant dense<0> : tensor - %1 = mhlo.constant dense<0> : tensor - // CHECK: %0 = mhlo.constant dense : tensor - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: fold_compare_false_ge_float -func.func @fold_compare_false_ge_float() -> tensor { - %0 = mhlo.constant dense<0.> : tensor - %1 = mhlo.constant dense<1.> : tensor - // CHECK: %0 = mhlo.constant dense : tensor - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: fold_compare_true_ge_float -func.func @fold_compare_true_ge_float() -> tensor { - %0 = mhlo.constant dense<0.> : tensor - %1 = mhlo.constant dense<0.> : tensor - // CHECK: %0 = mhlo.constant dense : tensor - %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - func.return %2 : tensor -} - -// CHECK-LABEL: unpack_repack_same_tuple -// CHECK-SAME: ([[ARG0:%.*]]: tuple, !mhlo.token, tensor>) -func.func @unpack_repack_same_tuple(%arg0: tuple, !mhlo.token, tensor>) -> tuple, !mhlo.token, tensor> { - %0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, !mhlo.token, tensor>) -> tensor - %1 = "mhlo.get_tuple_element"(%arg0) {index = 1 : i32} : (tuple, !mhlo.token, tensor>) -> !mhlo.token - %2 = "mhlo.get_tuple_element"(%arg0) {index = 2 : i32} : (tuple, !mhlo.token, tensor>) -> tensor - %3 = "mhlo.tuple"(%0, %1, %2) : (tensor, !mhlo.token, tensor) -> tuple, !mhlo.token, tensor> - - // CHECK: return [[ARG0]] - func.return %3 : tuple, !mhlo.token, tensor> -} - -// CHECK-LABEL: unpack_repack_same_tuple_single_element -// CHECK-SAME: ([[ARG0:%.*]]: tuple>) -func.func @unpack_repack_same_tuple_single_element(%arg0: tuple>) -> tuple> { - %0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple>) -> tensor - %3 = "mhlo.tuple"(%0) : (tensor) -> tuple> - - // CHECK: return [[ARG0]] - func.return %3 : tuple> -} - -// CHECK-LABEL: func @fold_get_dimension_size -func.func @fold_get_dimension_size(%I: tensor<1x128x512xf32>) -> tensor { - %size = "mhlo.get_dimension_size"(%I) <{dimension = 2 : i64}> : (tensor<1x128x512xf32>) -> tensor - func.return %size : tensor - // CHECK-NEXT: %[[C:.*]] = mhlo.constant dense<512> : tensor - // CHECK-NEXT: return %[[C]] -} - -// CHECK-LABEL: func @fold_get_dimension_size_fail -func.func @fold_get_dimension_size_fail(%I: tensor<1x128x?xf32>) -> tensor { - // CHECK: "mhlo.get_dimension_size" - %size = "mhlo.get_dimension_size"(%I) <{dimension = 2 : i64}> : (tensor<1x128x?xf32>) -> tensor - func.return %size : tensor -} - -// CHECK-LABEL: func @fold_set_dimension_size -// CHECK-SAME: (%[[I:.*]]: tensor<1x128x512xf32>) -func.func @fold_set_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32> { - %dim = mhlo.constant dense<512> : tensor - %result = "mhlo.set_dimension_size"(%I, %dim) {dimension = 2 : i64} : (tensor<1x128x512xf32>, tensor) -> tensor<1x128x512xf32> - func.return %result : tensor<1x128x512xf32> - - // CHECK-NEXT: return %[[I]] -} - -// CHECK-LABEL: func @fold_select_same -func.func @fold_select_same(%arg0 : tensor, %arg1 : tensor) -> tensor { - %1 = "mhlo.select"(%arg1, %arg0, %arg0) : (tensor, tensor, tensor) -> tensor - // CHECK: return %arg0 - func.return %1 : tensor -} - -// CHECK-LABEL: func @fold_select_first -func.func @fold_select_first(%arg0 : tensor, %arg1 : tensor) -> tensor { - %0 = mhlo.constant dense<1> : tensor - %1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor, tensor, tensor) -> tensor - // CHECK: return %arg0 - func.return %1 : tensor -} - -// CHECK-LABEL: func @fold_select_second -func.func @fold_select_second(%arg0 : tensor, %arg1 : tensor) -> tensor { - %0 = mhlo.constant dense<0> : tensor - %1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor, tensor, tensor) -> tensor - // CHECK: return %arg1 - func.return %1 : tensor -} - -// CHECK-LABEL: func @fold_select_vector -func.func @fold_select_vector(%arg0 : tensor<4xf32>, %arg1 : tensor<4xf32>) -> tensor<4xf32> { - %0 = mhlo.constant dense<1> : tensor<4xi1> - %1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - // CHECK: return %arg0 - func.return %1 : tensor<4xf32> -} - -// CHECK-LABEL: func @simplify_not_as_select_pred( -func.func @simplify_not_as_select_pred(%arg0 : tensor<4xi1>, %arg1 : tensor<4xf32>, %arg2 : tensor<4xf32>) -> tensor<4xf32> { - %0 = "mhlo.not"(%arg0) : (tensor<4xi1>) -> tensor<4xi1> - %1 = "mhlo.select"(%0, %arg1, %arg2) : (tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %1 : tensor<4xf32> - - // CHECK: %[[R:.*]] = mhlo.select %arg0, %arg2, %arg1 - // CHECK: return %[[R]] -} - -// CHECK-LABEL: func @simplify_broadcasted_not_as_select_pred( -func.func @simplify_broadcasted_not_as_select_pred(%arg0 : tensor<1xi1>, %arg1 : tensor<4xf32>, %arg2 : tensor<4xf32>) -> tensor<4xf32> { - %0 = "mhlo.not"(%arg0) : (tensor<1xi1>) -> tensor<1xi1> - %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<[0]> : tensor<1xi64> }> : (tensor<1xi1>) -> tensor<4xi1> - %2 = "mhlo.select"(%1, %arg1, %arg2) : (tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - func.return %2 : tensor<4xf32> + func.return %4 : tensor +} - // CHECK: %[[B:.*]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<1xi1>) -> tensor<4xi1> - // CHECK: %[[R:.*]] = mhlo.select %[[B]], %arg2, %arg1 - // CHECK: return %[[R]] +// CHECK-LABEL: func @dynamic_reshape_of_dynamic_reshape +// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]] +// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]] +func.func @dynamic_reshape_of_dynamic_reshape(%arg0: tensor, %shape: tensor) -> tensor { + // CHECK: return [[ARG0]] + %0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor, tensor) -> tensor + %1 = shape.shape_of %0 : tensor -> tensor + %2 = shape.num_elements %1 : tensor -> index + %3 = tensor.from_elements %2 : tensor<1xindex> + %4 = "mhlo.dynamic_reshape"(%0, %3) : (tensor, tensor<1xindex>) -> tensor + func.return %4 : tensor } +//////// +// GatherOp + // CHECK-LABEL: gather_to_slice func.func @gather_to_slice(%arg0: tensor<5x6x7xf32>) -> tensor<3x6x5xf32> { %0 = arith.constant dense<[1, 2]> : tensor<2xi32> @@ -1583,791 +543,385 @@ func.func @gather_to_slice_indices_clamp_lowerbound(%arg0 : tensor<4x2xui32>) -> // CHECK: return %[[V1]] : tensor<2xui32> } -// CHECK-LABEL: func @fold_and_same -func.func @fold_and_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> { - %0 = "mhlo.and"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - // CHECK: return %arg0 - func.return %0 : tensor<4xi32> -} +//////// +// IotaOp -// CHECK-LABEL: func @fold_and_ones -func.func @fold_and_ones(%arg0 : tensor<4xi32>) -> tensor<4xi32> { - %0 = mhlo.constant dense<-1> : tensor<4xi32> - %1 = "mhlo.and"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - // CHECK: return %arg0 - func.return %1 : tensor<4xi32> -} +// CHECK-LABEL: @iota_constant +func.func @iota_constant() -> tensor<1xi32> { + // CHECK: [[CONST:%.+]] = mhlo.constant dense<0> : tensor<1xi32> + %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<1xi32> -// CHECK-LABEL: func @fold_and_zeros -func.func @fold_and_zeros(%arg0 : tensor<4xi32>) -> tensor<4xi32> { - %0 = mhlo.constant dense<0> : tensor<4xi32> - %1 = "mhlo.and"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - // CHECK: return %0 - func.return %1 : tensor<4xi32> + // CHECK: return [[CONST]] : tensor<1xi32> + func.return %0 : tensor<1xi32> } -// CHECK-LABEL: func @fold_and_constant -func.func @fold_and_constant(%arg0 : tensor<4xi32>) -> tensor<4xi32> { - %0 = mhlo.constant dense<7> : tensor<4xi32> - // CHECK: mhlo.and - %1 = "mhlo.and"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - func.return %1 : tensor<4xi32> -} +// CHECK-LABEL: @iota_constant_multi +func.func @iota_constant_multi() -> tensor<1x4xi32> { + // CHECK: [[CONST:%.+]] = mhlo.constant dense<0> : tensor<1x4xi32> + %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<1x4xi32> -// CHECK-LABEL: func @fold_and_constants -func.func @fold_and_constants() -> tensor<4xi32> { - %0 = mhlo.constant dense<[0, 1, 6, 3]> : tensor<4xi32> - %1 = mhlo.constant dense<[7, 3, 7, 2]> : tensor<4xi32> - %2 = "mhlo.and"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - // CHECK: %0 = mhlo.constant dense<[0, 1, 6, 2]> : tensor<4xi32> - // CHECK: return %0 - func.return %2 : tensor<4xi32> + // CHECK: return [[CONST]] : tensor<1x4xi32> + func.return %0 : tensor<1x4xi32> } -// CHECK-LABEL: func @fold_or_same -func.func @fold_or_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> { - %0 = "mhlo.or"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - // CHECK: return %arg0 +// CHECK-LABEL: @iota_not_lowered_to_constant +func.func @iota_not_lowered_to_constant() -> tensor<4xi32> { + // CHECK: [[RESULT:%.*]] = "mhlo.iota" + // CHECK: return [[RESULT]] + %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<4xi32> func.return %0 : tensor<4xi32> } -// CHECK-LABEL: func @fold_or_ones -func.func @fold_or_ones(%arg0 : tensor<4xi32>) -> tensor<4xi32> { - %0 = mhlo.constant dense<-1> : tensor<4xi32> - %1 = "mhlo.or"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - // CHECK: return %0 - func.return %1 : tensor<4xi32> +// CHECK-LABEL: @iota_broadcast +func.func @iota_broadcast() -> tensor<5x4xi32> { + // CHECK: [[IOTA:%.+]] = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<5xi32> + // CHECK: [[RESULT:%.+]] = "mhlo.broadcast_in_dim"([[IOTA]]) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<5xi32>) -> tensor<5x4xi32> + %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<5x4xi32> + + func.return %0 : tensor<5x4xi32> } -// CHECK-LABEL: func @fold_or_zeros -func.func @fold_or_zeros(%arg0 : tensor<4xi32>) -> tensor<4xi32> { - %0 = mhlo.constant dense<0> : tensor<4xi32> - %1 = "mhlo.or"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - // CHECK: return %arg0 - func.return %1 : tensor<4xi32> +// CHECK-LABEL: @iota_broadcast +func.func @iota_broadcast_second() -> tensor<5x4xi32> { + // CHECK: [[IOTA:%.+]] = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<4xi32> + // CHECK: [[RESULT:%.+]] = "mhlo.broadcast_in_dim"([[IOTA]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<4xi32>) -> tensor<5x4xi32> + %0 = "mhlo.iota"() <{iota_dimension = 1 : i64}> : () -> tensor<5x4xi32> + + func.return %0 : tensor<5x4xi32> } -// CHECK-LABEL: func @fold_or_constant -func.func @fold_or_constant(%arg0 : tensor<4xi32>) -> tensor<4xi32> { - %0 = mhlo.constant dense<7> : tensor<4xi32> - // CHECK: mhlo.or - %1 = "mhlo.or"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - func.return %1 : tensor<4xi32> +//////// +// PadOp + +// CHECK-LABEL: @pad_zero_length +func.func @pad_zero_length(%arg0: tensor<5x0xf32>, %arg1: tensor) -> tensor<7x2xf32> { + // CHECK: %[[RES:.+]] = "mhlo.broadcast_in_dim"(%arg1) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<7x2xf32> + %0 = "mhlo.pad"(%arg0, %arg1) { + edge_padding_low = dense<1> : tensor<2xi64>, + edge_padding_high = dense<1> : tensor<2xi64>, + interior_padding = dense<0> : tensor<2xi64> + } : (tensor<5x0xf32>, tensor) -> tensor<7x2xf32> + // CHECK: return %[[RES]] + func.return %0 : tensor<7x2xf32> } -// CHECK-LABEL: func @fold_or_zeros_right -func.func @fold_or_zeros_right(%arg0 : tensor<4xi32>) -> tensor<4xi32> { - %0 = mhlo.constant dense<0> : tensor<4xi32> - %1 = "mhlo.or"(%arg0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - // CHECK: return %arg0 - func.return %1 : tensor<4xi32> +//////// +// RealDynamicSliceOp + +// CHECK-LABEL: @simplify_real_dynamic_slice_to_slice +func.func @simplify_real_dynamic_slice_to_slice(%arg0: tensor) -> tensor<1x4xf32> { + %0 = mhlo.constant dense<[0, 0]> : tensor<2xi32> + %1 = mhlo.constant dense<[1, 4]> : tensor<2xi32> + %2 = mhlo.constant dense<[1, 1]> : tensor<2xi32> + %3 = mhlo.real_dynamic_slice %arg0, %0, %1, %2 : (tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x4xf32> + // CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0) + // CHECK-DAG-SAME: start_indices = dense<[0, 0]> : tensor<2xi64> + // CHECK-DAG-SAME: limit_indices = dense<[1, 4]> : tensor<2xi64> + // CHECK-DAG-SAME: strides = dense<[1, 1]> : tensor<2xi64>} + // CHECK: return %[[RESULT]] : tensor<1x4xf32> + return %3 : tensor<1x4xf32> } -// CHECK-LABEL: func @fold_or_zeros_constants -func.func @fold_or_zeros_constants() -> tensor<4xi32> { - %0 = mhlo.constant dense<[0, 1, 6, 3]> : tensor<4xi32> - %1 = mhlo.constant dense<[7, 3, 7, 2]> : tensor<4xi32> - %2 = "mhlo.or"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - // CHECK: %0 = mhlo.constant dense<[7, 3, 7, 3]> : tensor<4xi32> - // CHECK: return %0 - func.return %2 : tensor<4xi32> +// CHECK-LABEL: @simplify_real_dynamic_slice_to_dynamic_slice +func.func @simplify_real_dynamic_slice_to_dynamic_slice(%arg0: tensor, %arg1: tensor<2xi32>) -> tensor<1x4xf32> { + %0 = mhlo.constant dense<[1, 4]> : tensor<2xi32> + %1 = mhlo.add %arg1, %0 : tensor<2xi32> + %2 = mhlo.constant dense<[1, 1]> : tensor<2xi32> + %3 = mhlo.real_dynamic_slice %arg0, %arg1, %1, %2 : (tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x4xf32> + return %3 : tensor<1x4xf32> + // CHECK: [[START_INDEX_0_1D:%.*]] = "mhlo.slice"(%arg1) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> + // CHECK-NEXT: [[START_INDEX_0_0D:%.*]] = mhlo.reshape [[START_INDEX_0_1D]] : (tensor<1xi32>) -> tensor + // CHECK-NEXT: [[START_INDEX_1_1D:%.*]] = "mhlo.slice"(%arg1) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> + // CHECK-NEXT: [[START_INDEX_1_0D:%.*]] = mhlo.reshape [[START_INDEX_1_1D]] : (tensor<1xi32>) -> tensor + // CHECK-NEXT: [[RESULT:%.*]] = "mhlo.dynamic_slice"(%arg0, [[START_INDEX_0_0D]], [[START_INDEX_1_0D]]) <{ + // CHECK-SAME: slice_sizes = dense<[1, 4]> : tensor<2xi64> + // CHECK-SAME: }> : (tensor, tensor, tensor) -> tensor<1x4xf32> + // CHECK-NEXT: return [[RESULT]] : tensor<1x4xf32> } -// CHECK-LABEL: func @fold_xor_same -func.func @fold_xor_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> { - %0 = "mhlo.xor"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - // CHECK: %0 = mhlo.constant dense<0> : tensor<4xi32> - // CHECK: return %0 - func.return %0 : tensor<4xi32> +//////// +// ReshapeOp + +// CHECK-LABEL: @reshape_of_same_shape_op_result +func.func @reshape_of_same_shape_op_result(%arg: tensor, + %shape: tensor<2xindex>) -> tensor { + // CHECK: mhlo.dynamic_reshape + // CHECK-NEXT: mhlo.abs + // CHECK-NOT: mhlo.dynamic_reshape + %0 = "mhlo.dynamic_reshape"(%arg, %shape) : (tensor, tensor<2xindex>) -> tensor + %1 = "mhlo.abs"(%0) : (tensor) -> tensor + %2 = "mhlo.dynamic_reshape"(%1, %shape) : (tensor, tensor<2xindex>) -> tensor + func.return %2 : tensor } -// CHECK-LABEL: func @fold_xor_same_dynamic -func.func @fold_xor_same_dynamic(%arg0 : tensor) -> tensor { - %0 = "mhlo.xor"(%arg0, %arg0) : (tensor, tensor) -> tensor - // CHECK: mhlo.xor - func.return %0 : tensor +// CHECK-LABEL: @eliminate_redundant_reshape +func.func @eliminate_redundant_reshape(%arg : tensor<1x32xi16>) -> tensor<1x32xi16> { + %0 = "mhlo.reshape"(%arg) : (tensor<1x32xi16>) -> tensor<2x16xi16> + %1 = "mhlo.reshape"(%0) : (tensor<2x16xi16>) -> tensor<1x32xi16> + // CHECK: return %arg0 : tensor<1x32xi16> + func.return %1 : tensor<1x32xi16> } -// CHECK-LABEL: func @fold_xor_ones_left -func.func @fold_xor_ones_left(%arg0 : tensor<4xi32>) -> tensor<4xi32> { - %0 = mhlo.constant dense<-1> : tensor<4xi32> - // CHECK: mhlo.xor - %1 = "mhlo.xor"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - func.return %1 : tensor<4xi32> +// CHECK-LABEL: @eliminate_identity_reshape +func.func @eliminate_identity_reshape(%arg : tensor<1x32xi16>) -> tensor<1x32xi16> { + // CHECK-NOT: mhlo.reshape + %0 = "mhlo.reshape"(%arg) : (tensor<1x32xi16>) -> tensor<1x32xi16> + // CHECK: return %arg0 : tensor<1x32xi16> + func.return %0 : tensor<1x32xi16> } -// CHECK-LABEL: func @fold_xor_ones_right -func.func @fold_xor_ones_right(%arg0 : tensor<4xi32>) -> tensor<4xi32> { - %0 = mhlo.constant dense<-1> : tensor<4xi32> - // CHECK: mhlo.xor - %1 = "mhlo.xor"(%arg0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - func.return %1 : tensor<4xi32> +//////// +// SelectOp + +// CHECK-LABEL: func @simplify_not_as_select_pred( +func.func @simplify_not_as_select_pred(%arg0 : tensor<4xi1>, %arg1 : tensor<4xf32>, %arg2 : tensor<4xf32>) -> tensor<4xf32> { + %0 = "mhlo.not"(%arg0) : (tensor<4xi1>) -> tensor<4xi1> + %1 = "mhlo.select"(%0, %arg1, %arg2) : (tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + func.return %1 : tensor<4xf32> + + // CHECK: %[[R:.*]] = mhlo.select %arg0, %arg2, %arg1 + // CHECK: return %[[R]] } -// CHECK-LABEL: func @fold_xor_zeros_left -func.func @fold_xor_zeros_left(%arg0 : tensor<4xi32>) -> tensor<4xi32> { - %0 = mhlo.constant dense<0> : tensor<4xi32> - %1 = "mhlo.xor"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - // CHECK: return %arg0 - func.return %1 : tensor<4xi32> +// CHECK-LABEL: func @simplify_broadcasted_not_as_select_pred( +func.func @simplify_broadcasted_not_as_select_pred(%arg0 : tensor<1xi1>, %arg1 : tensor<4xf32>, %arg2 : tensor<4xf32>) -> tensor<4xf32> { + %0 = "mhlo.not"(%arg0) : (tensor<1xi1>) -> tensor<1xi1> + %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<[0]> : tensor<1xi64> }> : (tensor<1xi1>) -> tensor<4xi1> + %2 = "mhlo.select"(%1, %arg1, %arg2) : (tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + func.return %2 : tensor<4xf32> + + // CHECK: %[[B:.*]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<1xi1>) -> tensor<4xi1> + // CHECK: %[[R:.*]] = mhlo.select %[[B]], %arg2, %arg1 + // CHECK: return %[[R]] } -// CHECK-LABEL: func @fold_xor_zeros_right -func.func @fold_xor_zeros_right(%arg0 : tensor<4xi32>) -> tensor<4xi32> { - %0 = mhlo.constant dense<0> : tensor<4xi32> - %1 = "mhlo.xor"(%arg0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> +//////// +// SliceOp + +// CHECK-LABEL: dynamic_update_slice_fold_length_0 +func.func @dynamic_update_slice_fold_length_0(%arg0: tensor<3x4xi64>, %arg1: tensor<3x0xi64>) -> tensor<3x4xi64> { // CHECK: return %arg0 - func.return %1 : tensor<4xi32> + %0 = mhlo.constant dense<0> : tensor + %1 = "mhlo.dynamic_update_slice"(%arg0, %arg1, %0, %0) : (tensor<3x4xi64>, tensor<3x0xi64>, tensor, tensor) -> tensor<3x4xi64> + func.return %1 : tensor<3x4xi64> } -// CHECK-LABEL: func @fold_xor_zeros_constants -func.func @fold_xor_zeros_constants() -> tensor<4xi32> { - %0 = mhlo.constant dense<[0, 1, 6, 3]> : tensor<4xi32> - %1 = mhlo.constant dense<[7, 3, 7, 2]> : tensor<4xi32> - %2 = "mhlo.xor"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> - // CHECK: %0 = mhlo.constant dense<[7, 2, 1, 1]> : tensor<4xi32> - // CHECK: return %0 - func.return %2 : tensor<4xi32> +// CHECK-LABEL: dynamic_update_slice_identity_update +func.func @dynamic_update_slice_identity_update(%arg0: tensor<3x4xi64>, %arg1: tensor<3x4xi64>) -> tensor<3x4xi64> { + // CHECK: return %arg1 + %0 = mhlo.constant dense<0> : tensor + %1 = "mhlo.dynamic_update_slice"(%arg0, %arg1, %0, %0) : (tensor<3x4xi64>, tensor<3x4xi64>, tensor, tensor) -> tensor<3x4xi64> + func.return %1 : tensor<3x4xi64> } -// CHECK-LABEL: func @fold_negate_int -func.func @fold_negate_int() -> tensor<4xi32> { - %0 = mhlo.constant dense<[0, 1, 6, -3]> : tensor<4xi32> - // CHECK: mhlo.constant dense<[0, -1, -6, 3]> - %1 = "mhlo.negate"(%0) : (tensor<4xi32>) -> tensor<4xi32> - func.return %1 : tensor<4xi32> +// CHECK-LABEL: dynamic_update_slice_fold_fail_dynamic_shapes +func.func @dynamic_update_slice_fold_fail_dynamic_shapes(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = mhlo.constant dense<0> : tensor + %1 = "mhlo.dynamic_update_slice"(%arg0, %arg1, %0, %0) : (tensor, tensor, tensor, tensor) -> tensor + func.return %1 : tensor + // CHECK: %[[CST:.*]] = mhlo.constant dense<0> : tensor + // CHECK: %[[VAL:.*]] = mhlo.dynamic_update_slice %arg0, %arg1, %[[CST]], %[[CST]] : (tensor, tensor, tensor, tensor) -> tensor + // CHECK: return %[[VAL]] : tensor } -// CHECK-LABEL: func @fold_negate_float -func.func @fold_negate_float() -> tensor<4xf32> { - %0 = mhlo.constant dense<[0., 1., 6., -3.]> : tensor<4xf32> - // CHECK: mhlo.constant dense<[-0.000000e+00, -1.000000e+00, -6.000000e+00, 3.000000e+00]> - %1 = "mhlo.negate"(%0) : (tensor<4xf32>) -> tensor<4xf32> - func.return %1 : tensor<4xf32> +// CHECK-LABEL: dynamic_slice_variable_start +func.func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor, %arg2: tensor) -> tensor<1x4xi32> { + // CHECK: "mhlo.dynamic_slice" + %1 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor<3x4xi32>, tensor, tensor) -> tensor<1x4xi32> + func.return %1 : tensor<1x4xi32> } -// CHECK-LABEL func @fold_not() -func.func @fold_not() -> tensor<2x2xi1> { - %0 = mhlo.constant dense<[[true, false], [true, false]]> : tensor<2x2xi1> - // CHECK{LITERAL}: mhlo.constant dense<[[false, true], [false, true]]> : tensor<2x2xi1> - %1 = "mhlo.not"(%0) : (tensor<2x2xi1>) -> tensor<2x2xi1> - func.return %1 : tensor<2x2xi1> +// CHECK-LABEL: dynamic_slice_constant_start +func.func @dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> { + // CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0) + // CHECK-DAG-SAME: limit_indices = dense<3> : tensor<1xi64> + // CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64> + // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} + // CHECK: return %[[RESULT]] : tensor<2xi32> + %0 = mhlo.constant dense<1> : tensor + %1 = "mhlo.dynamic_slice"(%arg0, %0) <{slice_sizes = dense<2> : tensor<1xi64>}> : (tensor<4xi32>, tensor) -> tensor<2xi32> + func.return %1 : tensor<2xi32> } -// CHECK-LABEL func @fold_not_i32() -func.func @fold_not_i32() -> tensor<2x2xi32> { - %0 = mhlo.constant dense<[[42, -12], [1, 0]]> : tensor<2x2xi32> - // CHECK{LITERAL}: mhlo.constant dense<[[-43, 11], [-2, -1]]> : tensor<2x2xi32> - %1 = "mhlo.not"(%0) : (tensor<2x2xi32>) -> tensor<2x2xi32> - func.return %1 : tensor<2x2xi32> +// CHECK-LABEL: dynamic_slice_constant_start_dynamic_shape +func.func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { + // CHECK: mhlo.dynamic_slice + // CHECK-NOT: mhlo.slice + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.constant dense<0> : tensor + %2 = "mhlo.dynamic_slice"(%arg0, %0, %1) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor, tensor, tensor) -> tensor<1x4xi32> + func.return %2 : tensor<1x4xi32> } -// CHECK-LABEL: func @fold_sqrt_f16_constants -func.func @fold_sqrt_f16_constants() -> tensor<4xf16> { - %0 = mhlo.constant dense<[1.0, 4.0, 9.0, 16.0]> : tensor<4xf16> - %1 = "mhlo.sqrt"(%0) : (tensor<4xf16>) -> tensor<4xf16> - // CHECK: mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf16> - // CHECK-NOT: mhlo.sqrt - func.return %1 : tensor<4xf16> +// CHECK-LABEL: dynamic_slice_constant_start_upper_bound +func.func @dynamic_slice_constant_start_upper_bound(%arg0: tensor<8x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { + // CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0) + // CHECK-SAME: limit_indices = dense<[8, 4]> : tensor<2xi64> + // CHECK-SAME: start_indices = dense<[7, 0]> : tensor<2xi64> + // CHECK-SAME: strides = dense<1> : tensor<2xi64> + // CHECK: return %[[RESULT]] : tensor<1x4xi32> + %0 = mhlo.constant dense<10> : tensor + %1 = mhlo.constant dense<0> : tensor + %2 = "mhlo.dynamic_slice"(%arg0, %0, %1) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor<8x4xi32>, tensor, tensor) -> tensor<1x4xi32> + func.return %2 : tensor<1x4xi32> } -// CHECK-LABEL: func @fold_sqrt_bf16_constants -func.func @fold_sqrt_bf16_constants() -> tensor<4xbf16> { - %0 = mhlo.constant dense<[1.0, 4.0, 9.0, 16.0]> : tensor<4xbf16> - %1 = "mhlo.sqrt"(%0) : (tensor<4xbf16>) -> tensor<4xbf16> - // CHECK: mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16> - // CHECK-NOT: mhlo.sqrt - func.return %1 : tensor<4xbf16> +// CHECK-LABEL: dynamic_slice_constant_start_lower_bound +func.func @dynamic_slice_constant_start_lower_bound(%arg0: tensor<8x4xi32>, %arg1: tensor<2xi64>) -> tensor<1x4xi32> { + // CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0) + // CHECK-SAME: limit_indices = dense<[1, 4]> : tensor<2xi64> + // CHECK-SAME: start_indices = dense<0> : tensor<2xi64> + // CHECK-SAME: strides = dense<1> : tensor<2xi64> + // CHECK: return %[[RESULT]] : tensor<1x4xi32> + %0 = mhlo.constant dense<-1> : tensor + %1 = mhlo.constant dense<0> : tensor + %2 = "mhlo.dynamic_slice"(%arg0, %0, %1) <{slice_sizes = dense<[1, 4]> : tensor<2xi64>}> : (tensor<8x4xi32>, tensor, tensor) -> tensor<1x4xi32> + func.return %2 : tensor<1x4xi32> } -// CHECK-LABEL: func @fold_sqrt_f32_constants -func.func @fold_sqrt_f32_constants() -> tensor<4xf32> { - %0 = mhlo.constant dense<1.0> : tensor<4xf32> - %1 = "mhlo.sqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: mhlo.constant dense<1.000000e+00> : tensor<4xf32> - // CHECK-NOT: mhlo.sqrt - func.return %1 : tensor<4xf32> -} +// CHECK-LABEL: slice_2D_noop +// CHECK-SAME: [[ARG:%.+]]: tensor<2x2xi64> +func.func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> { + %0 = "mhlo.slice"(%arg0) <{ limit_indices = dense<[2, 2]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<2x2xi64>) -> (tensor<2x2xi64>) -// CHECK-LABEL: func @fold_sqrt_f64_constants -func.func @fold_sqrt_f64_constants() -> tensor<4xf64> { - %0 = mhlo.constant dense<[1.0, 4.0, 9.0, 16.0]> : tensor<4xf64> - %1 = "mhlo.sqrt"(%0) : (tensor<4xf64>) -> tensor<4xf64> - // CHECK: mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf64> - // CHECK-NOT: mhlo.sqrt - func.return %1 : tensor<4xf64> + // CHECK-NEXT: return [[ARG]] + func.return %0 : tensor<2x2xi64> } -// CHECK-LABEL: func @fold_rsqrt_f16_constants -func.func @fold_rsqrt_f16_constants() -> tensor<4xf16> { - %0 = mhlo.constant dense<[1.0, 4.0, 16.0, 64.0]> : tensor<4xf16> - %1 = "mhlo.rsqrt"(%0) : (tensor<4xf16>) -> tensor<4xf16> - // CHECK: mhlo.constant dense<[1.000000e+00, 5.000000e-01, 2.500000e-01, 1.250000e-01]> : tensor<4xf16> - // CHECK-NOT: mhlo.rsqrt - func.return %1 : tensor<4xf16> -} +// CHECK-LABEL: slice_concat_empty +func.func @slice_concat_empty(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<1x5xf32> { + %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<2x5xf32>) -> (tensor<0x5xf32>) + %2 = "mhlo.concatenate"(%1, %arg2) <{ dimension = 0 : i64 }> : (tensor<0x5xf32>, tensor<1x5xf32>) -> tensor<1x5xf32> -// CHECK-LABEL: func @fold_rsqrt_bf16_constants -func.func @fold_rsqrt_bf16_constants() -> tensor<4xbf16> { - %0 = mhlo.constant dense<[1.0, 4.0, 16.0, 64.0]> : tensor<4xbf16> - %1 = "mhlo.rsqrt"(%0) : (tensor<4xbf16>) -> tensor<4xbf16> - // CHECK: mhlo.constant dense<[1.000000e+00, 5.000000e-01, 2.500000e-01, 1.250000e-01]> : tensor<4xbf16> - // CHECK-NOT: mhlo.rsqrt - func.return %1 : tensor<4xbf16> + // CHECK: return %arg2 + func.return %2 : tensor<1x5xf32> } -// CHECK-LABEL: func @fold_rsqrt_f32_constants -func.func @fold_rsqrt_f32_constants() -> tensor<4xf32> { - %0 = mhlo.constant dense<[1.0, 4.0, 16.0, 64.0]> : tensor<4xf32> - %1 = "mhlo.rsqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: mhlo.constant dense<[1.000000e+00, 5.000000e-01, 2.500000e-01, 1.250000e-01]> : tensor<4xf32> - // CHECK-NOT: mhlo.rsqrt - func.return %1 : tensor<4xf32> -} +//////// +// SortOp -// CHECK-LABEL: func @fold_rsqrt_f64_constants -func.func @fold_rsqrt_f64_constants() -> tensor<4xf64> { - %0 = mhlo.constant dense<[1.0, 4.0, 16.0, 64.0]> : tensor<4xf64> - %1 = "mhlo.rsqrt"(%0) : (tensor<4xf64>) -> tensor<4xf64> - // CHECK: mhlo.constant dense<[1.000000e+00, 5.000000e-01, 2.500000e-01, 1.250000e-01]> : tensor<4xf64> - // CHECK-NOT: mhlo.rsqrt - func.return %1 : tensor<4xf64> +// CHECK-LABEL: @sort_drop_second_arg +func.func @sort_drop_second_arg(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi32> { + // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] + // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] + // CHECK: %[[RES:.+]] = "mhlo.sort"(%[[ARG0]]) <{dimension = 0 : i64, is_stable = false}> ({ + // CHECK: ^bb0(%[[ARG2:.+]]: tensor, %[[ARG3:.+]]: tensor) + // CHECK: %[[CMP:.+]] = mhlo.compare GT, %[[ARG2]], %[[ARG3]] : (tensor, tensor) -> tensor + // CHECK: mhlo.return %[[CMP]] : tensor + // CHECK: }) : (tensor<3xi32>) -> tensor<3xi32> + // CHECK: return %[[RES]] : tensor<3xi32> + %0:2 = "mhlo.sort"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor): + %1 = "mhlo.compare"(%arg2, %arg3) { + comparison_direction = #mhlo + } : (tensor, tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () + }) {dimension = 0 : i64, + is_stable = false + } : (tensor<3xi32>, tensor<3xi32>) -> (tensor<3xi32>, tensor<3xi32>) + func.return %0#0 : tensor<3xi32> } -// CHECK-LABEL: func @not_fold_sqrt_neg_constants -func.func @not_fold_sqrt_neg_constants() -> tensor<4xf32> { - %0 = mhlo.constant dense<-1.0> : tensor<4xf32> - %1 = "mhlo.sqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: mhlo.constant dense<-1.000000e+00> : tensor<4xf32> - // CHECK: mhlo.sqrt - func.return %1 : tensor<4xf32> -} -// CHECK-LABEL: func @not_fold_rsqrt_neg_constants -func.func @not_fold_rsqrt_neg_constants() -> tensor<4xf32> { - %0 = mhlo.constant dense<-1.0> : tensor<4xf32> - %1 = "mhlo.rsqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: mhlo.constant dense<-1.000000e+00> : tensor<4xf32> - // CHECK: mhlo.rsqrt - func.return %1 : tensor<4xf32> +// CHECK-LABEL: @sort_no_dim_provided +func.func @sort_no_dim_provided(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> { + // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] + // CHECK: %[[RES:.+]] = "mhlo.sort"(%[[ARG0]]) + // CHECK: dimension = 1 : i64 + // CHECK: return %[[RES]] : tensor<3x5xi32> + %0 = "mhlo.sort"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "mhlo.compare"(%arg1, %arg2) { + comparison_direction = #mhlo + } : (tensor, tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () + }) {dimension = -1 : i64, + is_stable = false + } : (tensor<3x5xi32>) -> tensor<3x5xi32> + func.return %0 : tensor<3x5xi32> } -// CHECK-LABEL: func @fold_sqrt_const_zero -func.func @fold_sqrt_const_zero() -> tensor<4xf32> { - %0 = mhlo.constant dense<0.0> : tensor<4xf32> - %1 = "mhlo.sqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: mhlo.constant dense<0.000000e+00> : tensor<4xf32> - // CHECK-NOT: mhlo.sqrt - func.return %1 : tensor<4xf32> +//////// +// TupleOp + +// CHECK-LABEL: unpack_repack_same_tuple +// CHECK-SAME: ([[ARG0:%.*]]: tuple, !mhlo.token, tensor>) +func.func @unpack_repack_same_tuple(%arg0: tuple, !mhlo.token, tensor>) -> tuple, !mhlo.token, tensor> { + %0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, !mhlo.token, tensor>) -> tensor + %1 = "mhlo.get_tuple_element"(%arg0) {index = 1 : i32} : (tuple, !mhlo.token, tensor>) -> !mhlo.token + %2 = "mhlo.get_tuple_element"(%arg0) {index = 2 : i32} : (tuple, !mhlo.token, tensor>) -> tensor + %3 = "mhlo.tuple"(%0, %1, %2) : (tensor, !mhlo.token, tensor) -> tuple, !mhlo.token, tensor> + + // CHECK: return [[ARG0]] + func.return %3 : tuple, !mhlo.token, tensor> } -// CHECK-LABEL: func @not_fold_rsqrt_const_zero -func.func @not_fold_rsqrt_const_zero() -> tensor<4xf32> { - %0 = mhlo.constant dense<0.0> : tensor<4xf32> - %1 = "mhlo.rsqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: mhlo.constant dense<0.000000e+00> : tensor<4xf32> - // CHECK: mhlo.rsqrt - func.return %1 : tensor<4xf32> -} +// CHECK-LABEL: unpack_repack_same_tuple_single_element +// CHECK-SAME: ([[ARG0:%.*]]: tuple>) +func.func @unpack_repack_same_tuple_single_element(%arg0: tuple>) -> tuple> { + %0 = "mhlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple>) -> tensor + %3 = "mhlo.tuple"(%0) : (tensor) -> tuple> -// CHECK-LABEL: func @fold_abs -func.func @fold_abs() -> tensor<4xf32> { - %0 = mhlo.constant dense<-1.0> : tensor<4xf32> - %1 = "mhlo.abs"(%0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: mhlo.constant dense<1.000000e+00> : tensor<4xf32> - // CHECK-NOT: mhlo.abs - func.return %1 : tensor<4xf32> + // CHECK: return [[ARG0]] + func.return %3 : tuple> } -// CHECK-LABEL: func @fold_sine -func.func @fold_sine() -> tensor<4xf32> { - %0 = mhlo.constant dense<2.0> : tensor<4xf32> - %1 = "mhlo.sine"(%0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: mhlo.constant dense<0.909297406> : tensor<4xf32> - // CHECK-NOT: mhlo.sine - func.return %1 : tensor<4xf32> -} +//////// +// WhileOp DCE -// CHECK-LABEL: func @fold_cosine -func.func @fold_cosine() -> tensor<4xf32> { - %0 = mhlo.constant dense<2.0> : tensor<4xf32> - %1 = "mhlo.cosine"(%0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: mhlo.constant dense<-0.416146845> : tensor<4xf32> - // CHECK-NOT: mhlo.cosine - func.return %1 : tensor<4xf32> -} +// CHECK-LABEL: do_not_dce_while_with_outfeed +func.func @do_not_dce_while_with_outfeed(%arg0: tensor) -> tensor { + // CHECK: mhlo.while + %0 = "mhlo.while"(%arg0) ({ + ^bb0(%arg1: tensor): + %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg1: tensor): + %1 = "mhlo.create_token"() : () -> !mhlo.token + // Side-effecting op outfeed present inside while. + %2 = "mhlo.outfeed"(%arg1, %1) {outfeed_config = ""} : (tensor, !mhlo.token) -> !mhlo.token + "mhlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor -// CHECK-LABEL: func @fold_tanh -func.func @fold_tanh() -> tensor<4xf32> { - %0 = mhlo.constant dense<2.0> : tensor<4xf32> - %1 = "mhlo.tanh"(%0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: mhlo.constant dense<0.964027583> : tensor<4xf32> - // CHECK-NOT: mhlo.tanh - func.return %1 : tensor<4xf32> + func.return %arg0 : tensor } -// CHECK-LABEL: func @fold_exponential -func.func @fold_exponential() -> tensor<4xf32> { - %0 = mhlo.constant dense<2.0> : tensor<4xf32> - %1 = "mhlo.exponential"(%0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: mhlo.constant dense<7.3890562> : tensor<4xf32> - // CHECK-NOT: mhlo.exponential - func.return %1 : tensor<4xf32> -} +// CHECK-LABEL: dce_while_without_side_effect +func.func @dce_while_without_side_effect(%arg0: tensor) -> tensor { + // CHECK-NOT: mhlo.while + %0 = "mhlo.while"(%arg0) ({ + ^bb0(%arg1: tensor): + %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + "mhlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg1: tensor): + %1 = "mhlo.create_token"() : () -> !mhlo.token + "mhlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor -// CHECK-LABEL: func @fold_logistic -func.func @fold_logistic() -> tensor<4xf32> { - %0 = mhlo.constant dense<2.0> : tensor<4xf32> - %1 = "mhlo.logistic"(%0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: mhlo.constant dense<0.880797088> : tensor<4xf32> - // CHECK-NOT: mhlo.logistic - func.return %1 : tensor<4xf32> + func.return %arg0 : tensor } -// CHECK-LABEL: func @fold_log -func.func @fold_log() -> tensor<4xf32> { - %0 = mhlo.constant dense<2.0> : tensor<4xf32> - %1 = "mhlo.log"(%0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: mhlo.constant dense<0.693147182> : tensor<4xf32> - // CHECK-NOT: mhlo.log - func.return %1 : tensor<4xf32> -} +//////// +// Tensor/Shape canonicalize -// CHECK-LABEL: func @not_fold_log_neg_constants -func.func @not_fold_log_neg_constants() -> tensor<4xf32> { - %0 = mhlo.constant dense<-1.0> : tensor<4xf32> - %1 = "mhlo.log"(%0) : (tensor<4xf32>) -> tensor<4xf32> - // CHECK: mhlo.constant dense<-1.000000e+00> : tensor<4xf32> - // CHECK: mhlo.log - func.return %1 : tensor<4xf32> -} +// CHECK-LABEL: concatenate_noop_typecast +func.func @concatenate_noop_typecast(%arg0: tensor) -> tensor<4xi32> { + // CHECK-SAME: [[ARG:%.+]]: tensor + // CHECK-NEXT: [[RES:%.+]] = tensor.cast [[ARG]] : tensor to tensor<4xi32> + %0 = "mhlo.concatenate"(%arg0) <{ dimension = 0 : i64 }> : (tensor) -> tensor<4xi32> -// CHECK-LABEL: func @fold_if_true( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] -// CHECK-SAME: ) -func.func @fold_if_true(%arg0 : tensor, %arg1 : tensor) -> tensor { - // CHECK-NOT: mhlo.if - // CHECK: return %[[ARG0]] - %true = mhlo.constant dense : tensor - %0 = "mhlo.if"(%true) ({ - "mhlo.return"(%arg0) : (tensor) -> () - }, { - "mhlo.return"(%arg1) : (tensor) -> () - }) : (tensor) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: func @fold_if_false( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] -// CHECK-SAME: ) -func.func @fold_if_false(%arg0 : tensor, %arg1 : tensor) -> tensor { - // CHECK-NOT: mhlo.if - // CHECK: return %[[ARG1]] - %false = mhlo.constant dense : tensor - %0 = "mhlo.if"(%false) ({ - "mhlo.return"(%arg0) : (tensor) -> () - }, { - "mhlo.return"(%arg1) : (tensor) -> () - }) : (tensor) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: func @fold_case( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]] -// CHECK-SAME: ) -func.func @fold_case(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { - // CHECK-NOT: mhlo.case - // CHECK: return %[[ARG1]] - %c1 = mhlo.constant dense<1> : tensor - %0 = "mhlo.case"(%c1) ({ - "mhlo.return"(%arg0) : (tensor) -> () - }, { - "mhlo.return"(%arg1) : (tensor) -> () - }, { - "mhlo.return"(%arg2) : (tensor) -> () - }) : (tensor) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: func @fold_case_negative_index( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]] -// CHECK-SAME: ) -func.func @fold_case_negative_index(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { - // CHECK-NOT: mhlo.case - // CHECK: return %[[ARG2]] - %m1000 = mhlo.constant dense<-1000> : tensor - %0 = "mhlo.case"(%m1000) ({ - "mhlo.return"(%arg0) : (tensor) -> () - }, { - "mhlo.return"(%arg1) : (tensor) -> () - }, { - "mhlo.return"(%arg2) : (tensor) -> () - }) : (tensor) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: func @fold_case_oob_index( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]] -// CHECK-SAME: ) -func.func @fold_case_oob_index(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { - // CHECK-NOT: mhlo.case - // CHECK: return %[[ARG2]] - %c1000 = mhlo.constant dense<1000> : tensor - %0 = "mhlo.case"(%c1000) ({ - "mhlo.return"(%arg0) : (tensor) -> () - }, { - "mhlo.return"(%arg1) : (tensor) -> () - }, { - "mhlo.return"(%arg2) : (tensor) -> () - }) : (tensor) -> tensor - func.return %0 : tensor -} - -// CHECK-LABEL: @tensor_flow_scatter_v1_update -func.func @tensor_flow_scatter_v1_update() -> tensor<3x3xi32> { - %0 = arith.constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> - %1 = arith.constant dense<[0, 2]> : tensor<2xi32> - %2 = arith.constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> - %3 = "mhlo.scatter"(%0, %1, %2) ({ - ^bb0(%arg0: tensor, %arg1: tensor): - "mhlo.return"(%arg1) : (tensor) -> () - }) {indices_are_sorted = false, - scatter_dimension_numbers = #mhlo.scatter< - update_window_dims = [1], - inserted_window_dims = [0], - scatter_dims_to_operand_dims = [0], - index_vector_dim = 1, - >, - unique_indices = false - } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> - func.return %3 : tensor<3x3xi32> - // CHECK: mhlo.constant dense<[ - // CHECK-SAME: [10, 20, 30], [4, 5, 6], [70, 80, 90] - // CHECK-SAME: ]> : tensor<3x3xi32> -} - -// CHECK-LABEL: @tensor_flow_scatter_v2_update -func.func @tensor_flow_scatter_v2_update() -> tensor<3x3xi32> { - %0 = arith.constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> - %1 = arith.constant dense<[0, 2]> : tensor<2xi32> - %2 = arith.constant dense<[[10, 30], [40, 60], [70, 90]]> : tensor<3x2xi32> - %3 = "mhlo.scatter"(%0, %1, %2) ({ - ^bb0(%arg0: tensor, %arg1: tensor): - "mhlo.return"(%arg1) : (tensor) -> () - }) {indices_are_sorted = false, - scatter_dimension_numbers = #mhlo.scatter< - update_window_dims = [0], - inserted_window_dims = [1], - scatter_dims_to_operand_dims = [1], - index_vector_dim = 1, - >, - unique_indices = false - } : (tensor<3x3xi32>, tensor<2xi32>, tensor<3x2xi32>) -> tensor<3x3xi32> - func.return %3 : tensor<3x3xi32> - // CHECK: mhlo.constant dense<[ - // CHECK-SAME: [10, 2, 30], [40, 5, 60], [70, 8, 90] - // CHECK-SAME: ]> : tensor<3x3xi32> -} - -// CHECK-LABEL: @tensor_flow_scatter_add -func.func @tensor_flow_scatter_add() -> tensor<3x3xi32> { - %0 = arith.constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> - %1 = arith.constant dense<[0, 2]> : tensor<2xi32> - %2 = arith.constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> - %3 = "mhlo.scatter"(%0, %1, %2) ({ - ^bb0(%arg0: tensor, %arg1: tensor): - %4 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> (tensor) - "mhlo.return"(%4) : (tensor) -> () - }) {indices_are_sorted = false, - scatter_dimension_numbers = #mhlo.scatter< - update_window_dims = [1], - inserted_window_dims = [0], - scatter_dims_to_operand_dims = [0], - index_vector_dim = 1, - >, - unique_indices = false - } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> - func.return %3 : tensor<3x3xi32> - // CHECK: mhlo.constant dense<[ - // CHECK-SAME: [11, 22, 33], [4, 5, 6], [77, 88, 99] - // CHECK-SAME: ]> : tensor<3x3xi32> -} - -// CHECK-LABEL: @tensor_flow_scatter_repeated -func.func @tensor_flow_scatter_repeated() -> tensor<3x3xi32> { - %0 = arith.constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> - %1 = arith.constant dense<[1, 1]> : tensor<2xi32> - %2 = arith.constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> - %3 = "mhlo.scatter"(%0, %1, %2) ({ - ^bb0(%arg0: tensor, %arg1: tensor): - %4 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> (tensor) - "mhlo.return"(%4) : (tensor) -> () - }) {indices_are_sorted = false, - scatter_dimension_numbers = #mhlo.scatter< - update_window_dims = [1], - inserted_window_dims = [0], - scatter_dims_to_operand_dims = [0], - index_vector_dim = 1, - >, - unique_indices = false - } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> - func.return %3 : tensor<3x3xi32> - // CHECK: mhlo.constant dense<[ - // CHECK-SAME: [1, 2, 3], [84, 105, 126], [7, 8, 9] - // CHECK-SAME: ]> : tensor<3x3xi32> -} - -// CHECK-LABEL: @tensor_flow_scatter_multiple_batch -func.func @tensor_flow_scatter_multiple_batch() -> tensor<3x3xi32> { - %0 = arith.constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> - %1 = arith.constant dense<[[0, 2], [2, 1]]> : tensor<2x2xi32> - %2 = arith.constant dense<[[[10, 30], [40, 60], [70, 90]], [[5, 5], [5, 5], [5, 5]]]> : tensor<2x3x2xi32> - %3 = "mhlo.scatter"(%0, %1, %2) ({ - ^bb0(%arg0: tensor, %arg1: tensor): - %4 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> (tensor) - "mhlo.return"(%4) : (tensor) -> () - }) {indices_are_sorted = false, - scatter_dimension_numbers = #mhlo.scatter< - update_window_dims = [1], - inserted_window_dims = [1], - scatter_dims_to_operand_dims = [1], - index_vector_dim = 2, - >, - unique_indices = false - } : (tensor<3x3xi32>, tensor<2x2xi32>, tensor<2x3x2xi32>) -> tensor<3x3xi32> - func.return %3 : tensor<3x3xi32> - // CHECK: mhlo.constant dense<[ - // CHECK-SAME: [11, 7, 38], [44, 10, 71], [77, 13, 104] - // CHECK-SAME: ]> : tensor<3x3xi32> -} - -// CHECK-LABEL: @tensor_flow_scatter_nd -func.func @tensor_flow_scatter_nd() -> tensor<3x3x2xi32> { - %0 = arith.constant dense<[[[-1, 1], [-2, 2], [-3, 3]], [[-4, 4], [-5, 5], [-6, 6]], [[-7, 7], [-8, 8], [-9, 9]]]> : tensor<3x3x2xi32> - %1 = arith.constant dense<[[0, 0], [1, 0]]> : tensor<2x2xi32> - %2 = arith.constant dense<[[-10, 10], [-40, 40]]> : tensor<2x2xi32> - %3 = "mhlo.scatter"(%0, %1, %2) ({ - ^bb0(%arg0: tensor, %arg1: tensor): - "mhlo.return"(%arg1) : (tensor) -> () - }) {indices_are_sorted = false, - scatter_dimension_numbers = #mhlo.scatter< - update_window_dims = [1], - inserted_window_dims = [0, 1], - scatter_dims_to_operand_dims = [0, 1], - index_vector_dim = 1, - >, - unique_indices = false - } : (tensor<3x3x2xi32>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<3x3x2xi32> - func.return %3 : tensor<3x3x2xi32> - // CHECK: mhlo.constant dense<[ - // CHECK-SAME: [-10, 10], [-2, 2], [-3, 3] - // CHECK-SAME: [-40, 40], [-5, 5], [-6, 6] - // CHECK-SAME: [-7, 7], [-8, 8], [-9, 9] - // CHECK-SAME: ]> : tensor<3x3x2xi32> -} - -// CHECK-LABEL: @tensor_flow_scatter_nd_index_vector -func.func @tensor_flow_scatter_nd_index_vector() -> tensor<3x3x2xi32> { - %0 = arith.constant dense<[[[-1, 1], [-2, 2], [-3, 3]], [[-4, 4], [-5, 5], [-6, 6]], [[-7, 7], [-8, 8], [-9, 9]]]> : tensor<3x3x2xi32> - %1 = arith.constant dense<[[0, 0], [1, 0]]> : tensor<2x2xi32> - %2 = arith.constant dense<[[-10, 10], [-20, 20]]> : tensor<2x2xi32> - %3 = "mhlo.scatter"(%0, %1, %2) ({ - ^bb0(%arg0: tensor, %arg1: tensor): - "mhlo.return"(%arg1) : (tensor) -> () - }) {indices_are_sorted = false, - scatter_dimension_numbers = #mhlo.scatter< - update_window_dims = [1], - inserted_window_dims = [0, 1], - scatter_dims_to_operand_dims = [0, 1], - index_vector_dim = 0, - >, - unique_indices = false - } : (tensor<3x3x2xi32>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<3x3x2xi32> - func.return %3 : tensor<3x3x2xi32> - // CHECK: mhlo.constant dense<[ - // CHECK-SAME: [-20, 20], [-10, 10], [-3, 3] - // CHECK-SAME: [-4, 4], [-5, 5], [-6, 6] - // CHECK-SAME: [-7, 7], [-8, 8], [-9, 9] - // CHECK-SAME: ]> : tensor<3x3x2xi32> -} - -// CHECK-LABEL: @scatter_batch_dus -func.func @scatter_batch_dus() -> tensor<3x3xi32> { - %0 = arith.constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> - %1 = arith.constant dense<[[2, 1], [1, 1]]> : tensor<2x2xi32> - %2 = arith.constant dense<[[[10]], [[20]]]> : tensor<2x1x1xi32> - %3 = "mhlo.scatter"(%0, %1, %2) ({ - ^bb0(%arg0: tensor, %arg1: tensor): - "mhlo.return"(%arg1) : (tensor) -> () - }) {indices_are_sorted = false, - scatter_dimension_numbers = #mhlo.scatter< - update_window_dims = [1, 2], - scatter_dims_to_operand_dims = [0, 1], - index_vector_dim = 0, - >, - unique_indices = false - } : (tensor<3x3xi32>, tensor<2x2xi32>, tensor<2x1x1xi32>) -> tensor<3x3xi32> - func.return %3 : tensor<3x3xi32> - // CHECK: mhlo.constant dense<[ - // CHECK-SAME: [1, 2, 3], [4, 20, 6], [7, 10, 9] - // CHECK-SAME: ]> : tensor<3x3xi32> -} - -// CHECK-LABEL: @scatter_no_update_window_dim -func.func @scatter_no_update_window_dim() -> tensor<3xi32> { - %0 = arith.constant dense<[0, 1, 2]> : tensor<3xi32> - %1 = arith.constant dense<[[[0], [1]], [[2], [1]]]> : tensor<2x2x1xi32> - %2 = arith.constant dense<[[10, 20], [30, 40]]> : tensor<2x2xi32> - %3 = "mhlo.scatter"(%0, %1, %2) ({ - ^bb0(%arg0: tensor, %arg1: tensor): - %4 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> (tensor) - "mhlo.return"(%4) : (tensor) -> () - }) {indices_are_sorted = false, - scatter_dimension_numbers = #mhlo.scatter< - inserted_window_dims = [0], - scatter_dims_to_operand_dims = [0], - index_vector_dim = 2, - >, - unique_indices = false - } : (tensor<3xi32>, tensor<2x2x1xi32>, tensor<2x2xi32>) -> tensor<3xi32> - func.return %3 : tensor<3xi32> - // CHECK: mhlo.constant dense<[10, 61, 32]> : tensor<3xi32> -} - -// CHECK-LABEL: @scatter_negative_index -func.func @scatter_negative_index() -> tensor<3x3xi32> { - %0 = arith.constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> - %1 = arith.constant dense<[0, -1]> : tensor<2xi32> - %2 = arith.constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> - %3 = "mhlo.scatter"(%0, %1, %2) ({ - ^bb0(%arg0: tensor, %arg1: tensor): - "mhlo.return"(%arg1) : (tensor) -> () - }) {indices_are_sorted = false, - scatter_dimension_numbers = #mhlo.scatter< - update_window_dims = [1], - inserted_window_dims = [0], - scatter_dims_to_operand_dims = [0], - index_vector_dim = 1, - >, - unique_indices = false - } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> - func.return %3 : tensor<3x3xi32> - // CHECK: constant dense<{{\[}}[1, 2, 3], [4, 5, 6], [7, 8, 9]{{\]}}> : tensor<3x3xi32> - // CHECK: "mhlo.scatter" -} - -// CHECK-LABEL: @scatter_out_of_bound -func.func @scatter_out_of_bound() -> tensor<3x3xi32> { - %0 = arith.constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> - %1 = arith.constant dense<[1, 5]> : tensor<2xi32> - %2 = arith.constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> - %3 = "mhlo.scatter"(%0, %1, %2) ({ - ^bb0(%arg0: tensor, %arg1: tensor): - "mhlo.return"(%arg1) : (tensor) -> () - }) {indices_are_sorted = false, - scatter_dimension_numbers = #mhlo.scatter< - update_window_dims = [1], - inserted_window_dims = [0], - scatter_dims_to_operand_dims = [0], - index_vector_dim = 1, - >, - unique_indices = false - } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> - func.return %3 : tensor<3x3xi32> - // CHECK: constant dense<{{\[}}[1, 2, 3], [4, 5, 6], [7, 8, 9]{{\]}}> : tensor<3x3xi32> - // CHECK: "mhlo.scatter" -} - -// CHECK-LABEL: @scatter_complex -func.func public @scatter_complex() -> tensor<1xcomplex> { - %0 = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor> - %1 = mhlo.constant dense<0> : tensor<1xi32> - %2 = mhlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<1xcomplex> - %3 = "mhlo.scatter"(%2, %1, %0) ({ - ^bb0(%arg0: tensor>, %arg1: tensor>): - "mhlo.return"(%arg1) : (tensor>) -> () - }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter, unique_indices = true} : (tensor<1xcomplex>, tensor<1xi32>, tensor>) -> tensor<1xcomplex> - func.return %3 : tensor<1xcomplex> -} -// CHECK: "mhlo.scatter" - - -// CHECK-LABEL: @pad_identity_fold -func.func @pad_identity_fold(%arg0: tensor<5x7xf32>) -> tensor<5x7xf32> { - %0 = arith.constant dense<0.0> : tensor - %1 = "mhlo.pad"(%arg0, %0) { - edge_padding_low = dense<0> : tensor<2xi64>, - edge_padding_high = dense<0> : tensor<2xi64>, - interior_padding = dense<0> : tensor<2xi64> - } : (tensor<5x7xf32>, tensor) -> tensor<5x7xf32> - func.return %1 : tensor<5x7xf32> - // CHECK: return %arg0 : tensor<5x7xf32> -} - -// CHECK-LABEL: @pad_fold -func.func @pad_fold() -> tensor<4x5xi32> { - %0 = arith.constant dense<[[2, 3], [4, 5]]> : tensor<2x2xi32> - %1 = arith.constant dense<1> : tensor - %3 = "mhlo.pad"(%0, %1) { - edge_padding_low = dense<[1, 0]> : tensor<2xi64>, - edge_padding_high = dense<[1, 2]> : tensor<2xi64>, - interior_padding = dense<[0, 1]> : tensor<2xi64> - } : (tensor<2x2xi32>, tensor) -> tensor<4x5xi32> - func.return %3 : tensor<4x5xi32> - // CHECK: constant dense<[ - // CHECK-SAME: [1, 1, 1, 1, 1], [2, 1, 3, 1, 1], [4, 1, 5, 1, 1], [1, 1, 1, 1, 1] - // CHECK-SAME: ]> : tensor<4x5xi32> -} - -// CHECK-LABEL: @pad_negative_fold -func.func @pad_negative_fold() -> tensor<4x4xi32> { - %0 = arith.constant dense<[[2, 3], [4, 5]]> : tensor<2x2xi32> - %1 = arith.constant dense<1> : tensor - %3 = "mhlo.pad"(%0, %1) { - edge_padding_low = dense<[1, -1]> : tensor<2xi64>, - edge_padding_high = dense<[1, 2]> : tensor<2xi64>, - interior_padding = dense<[0, 1]> : tensor<2xi64> - } : (tensor<2x2xi32>, tensor) -> tensor<4x4xi32> - func.return %3 : tensor<4x4xi32> - // CHECK: "mhlo.pad" -} - -// CHECK-LABEL: @pad_fold_zero_elements -func.func @pad_fold_zero_elements() -> tensor<3xi32> { - %0 = mhlo.constant dense<> : tensor<0xi32> - %1 = mhlo.constant dense<7> : tensor - %2 = "mhlo.pad"(%0, %1) <{edge_padding_high = dense<3> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>}> : (tensor<0xi32>, tensor) -> tensor<3xi32> - func.return %2 : tensor<3xi32> - // CHECK: mhlo.constant dense<7> : tensor<3xi32> -} - -// CHECK-LABEL: @pad_float_fold -func.func @pad_float_fold() -> tensor<2xf32> { - %0 = mhlo.constant dense<2.000000e+00> : tensor<1xf32> - %1 = mhlo.constant dense<1.000000e+00> : tensor - %2 = "mhlo.pad"(%0, %1) <{edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>}> : (tensor<1xf32>, tensor) -> tensor<2xf32> - return %2 : tensor<2xf32> - // CHECK: mhlo.constant dense<[2.000000e+00, 1.000000e+00]> : tensor<2xf32> + // CHECK: return [[RES]] + func.return %0 : tensor<4xi32> } -// CHECK-LABEL: @pad_zero_length -func.func @pad_zero_length(%arg0: tensor<5x0xf32>, %arg1: tensor) -> tensor<7x2xf32> { - // CHECK: %[[RES:.+]] = "mhlo.broadcast_in_dim"(%arg1) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<7x2xf32> - %0 = "mhlo.pad"(%arg0, %arg1) { - edge_padding_low = dense<1> : tensor<2xi64>, - edge_padding_high = dense<1> : tensor<2xi64>, - interior_padding = dense<0> : tensor<2xi64> - } : (tensor<5x0xf32>, tensor) -> tensor<7x2xf32> - // CHECK: return %[[RES]] - func.return %0 : tensor<7x2xf32> +// CHECK-LABEL: slice_unknown_shape +func.func @slice_unknown_shape(%arg0: tensor) -> tensor { + // CHECK: "mhlo.slice"(%arg0) <{limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor) -> tensor + %0 = "mhlo.slice"(%arg0) <{limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<0> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor) -> tensor + func.return %0 : tensor } // CHECK-LABEL: @pad_zero_length_dyn @@ -2392,24 +946,6 @@ func.func @pad_zero_length_dyn(%arg0: tensor, %arg1: tensor) -> te func.return %0 : tensor } -// CHECK-LABEL: @dynamic_pad_identity_fold -func.func @dynamic_pad_identity_fold(%arg0: tensor<5x7xf32>) -> tensor<11x15xf32> { - %0 = arith.constant dense<0.0> : tensor - %1 = arith.constant dense<1> : tensor<2xi32> - %2 = arith.constant dense<1> : tensor<2xi32> - %3 = arith.constant dense<1> : tensor<2xi32> - // CHECK: %[[CST:.+]] = arith.constant dense<0.000000e+00> : tensor - // CHECK: %[[PAD:.+]] = "mhlo.pad"(%arg0, %[[CST]]) - // CHECK-SAME: edge_padding_high = dense<1> : tensor<2xi64> - // CHECK-SAME: edge_padding_low = dense<1> : tensor<2xi64> - // CHECK-SAME: interior_padding = dense<1> : tensor<2xi64>} - // CHECK-SAME: (tensor<5x7xf32>, tensor) -> tensor<11x15xf32> - %4 = "mhlo.dynamic_pad"(%arg0, %0, %1, %2, %3) { - } : (tensor<5x7xf32>, tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<11x15xf32> - // return %[[PAD]] - func.return %4 : tensor<11x15xf32> -} - // CHECK-LABEL: @dynamic_pad_length_dyn func.func @dynamic_pad_length_dyn( %arg0: tensor, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>, @@ -2442,295 +978,3 @@ func.func @dynamic_pad_length_dyn( // CHECK: return %[[BROAD]] func.return %1 : tensor } - -// CHECK-LABEL: @pad_complex_fold -func.func @pad_complex_fold() -> tensor<2xcomplex> { - %0 = mhlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor<1xcomplex> - %1 = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor> - %2 = "mhlo.pad"(%0, %1) <{edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>}> : (tensor<1xcomplex>, tensor>) -> tensor<2xcomplex> - return %2 : tensor<2xcomplex> - // CHECK: mhlo.constant dense<[(2.000000e+00,0.000000e+00), (1.000000e+00,0.000000e+00)]> : tensor<2xcomplex> -} - -// CHECK-LABEL: @identity_broadcast_reshape -func.func @identity_broadcast_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> { - %0 = "mhlo.broadcast"(%arg0) { - broadcast_sizes = dense<[1]> : tensor<1xi64>} : (tensor<128xf32>) -> tensor<1x128xf32> - %1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32> - func.return %1 : tensor<128xf32> - // CHECK: return %arg0 : tensor<128xf32> -} - -// CHECK-LABEL: @identity_broadcast_in_dim_reshape -func.func @identity_broadcast_in_dim_reshape(%arg0: tensor<128xf32>) -> tensor<128xf32> { - %0 = "mhlo.broadcast_in_dim"(%arg0) { - broadcast_dimensions = dense<[1]> : tensor<1xi64> } : (tensor<128xf32>) -> tensor<1x128xf32> - %1 = "mhlo.reshape"(%0) : (tensor<1x128xf32>) -> tensor<128xf32> - func.return %1 : tensor<128xf32> - // CHECK: return %arg0 : tensor<128xf32> -} - -// CHECK-LABEL: @eliminate_identity_convert -func.func @eliminate_identity_convert(%arg : tensor) -> tensor { - // CHECK-NOT: mhlo.convert - %0 = "mhlo.convert"(%arg) : (tensor) -> tensor - // CHECK: return %arg0 : tensor - func.return %0 : tensor -} - -func.func @fold_fptosi() -> tensor { - %0 = mhlo.constant dense<65535.000000e+00> : tensor - // CHECK: mhlo.constant dense<32767> : tensor - %1 = "mhlo.convert"(%0) : (tensor) -> tensor - func.return %1 : tensor -} - -func.func @fold_fptosi_rounding() -> tensor { - %0 = mhlo.constant dense<-1.5> : tensor - // CHECK: mhlo.constant dense<-1> : tensor - %1 = "mhlo.convert"(%0) : (tensor) -> tensor - func.return %1 : tensor -} - -func.func @fold_fptoui() -> tensor { - %0 = mhlo.constant dense<-1.000000e+00> : tensor - // CHECK: mhlo.constant dense<0> : tensor - %1 = "mhlo.convert"(%0) : (tensor) -> tensor - func.return %1 : tensor -} - -func.func @fold_sitofp() -> tensor { - %0 = mhlo.constant dense<-1> : tensor - // CHECK: mhlo.constant dense<-1.000000e+00> : tensor - %1 = "mhlo.convert"(%0) : (tensor) -> tensor - func.return %1 : tensor -} - -func.func @fold_uitofp() -> tensor { - %0 = mhlo.constant dense<65535> : tensor - // CHECK: mhlo.constant dense<6.553500e+04> : tensor - %1 = "mhlo.convert"(%0) : (tensor) -> tensor - func.return %1 : tensor -} - -func.func @fold_uitoui() -> tensor { - %0 = mhlo.constant dense<65535> : tensor - // CHECK: mhlo.constant dense<65535> : tensor - %1 = "mhlo.convert"(%0) : (tensor) -> tensor - func.return %1 : tensor -} - -func.func @fold_uitosi() -> tensor { - %0 = mhlo.constant dense<65535> : tensor - // CHECK: mhlo.constant dense<65535> : tensor - %1 = "mhlo.convert"(%0) : (tensor) -> tensor - func.return %1 : tensor -} - -func.func @fold_sitoui() -> tensor { - %0 = mhlo.constant dense<-1> : tensor - // CHECK: mhlo.constant dense<4294967295> : tensor - %1 = "mhlo.convert"(%0) : (tensor) -> tensor - func.return %1 : tensor -} - -func.func @fold_sitosi() -> tensor { - %0 = mhlo.constant dense<-1> : tensor - // CHECK: mhlo.constant dense<-1> : tensor - %1 = "mhlo.convert"(%0) : (tensor) -> tensor - func.return %1 : tensor -} - -func.func @fold_predtosi() -> tensor { - %0 = mhlo.constant dense : tensor - // CHECK: mhlo.constant dense<0> : tensor - %1 = "mhlo.convert"(%0) : (tensor) -> tensor - func.return %1 : tensor -} - -func.func @not_fold_itouq() -> tensor> { - // CHECK: mhlo.constant dense<1> : tensor - %0 = mhlo.constant dense<1> : tensor - %1 = "mhlo.convert"(%0) : (tensor) -> tensor> - func.return %1 : tensor> -} - -// CHECK-LABEL: @eliminate_redundant_reshape -func.func @eliminate_redundant_reshape(%arg : tensor<1x32xi16>) -> tensor<1x32xi16> { - %0 = "mhlo.reshape"(%arg) : (tensor<1x32xi16>) -> tensor<2x16xi16> - %1 = "mhlo.reshape"(%0) : (tensor<2x16xi16>) -> tensor<1x32xi16> - // CHECK: return %arg0 : tensor<1x32xi16> - func.return %1 : tensor<1x32xi16> -} - -// CHECK-LABEL: @eliminate_identity_reshape -func.func @eliminate_identity_reshape(%arg : tensor<1x32xi16>) -> tensor<1x32xi16> { - // CHECK-NOT: mhlo.reshape - %0 = "mhlo.reshape"(%arg) : (tensor<1x32xi16>) -> tensor<1x32xi16> - // CHECK: return %arg0 : tensor<1x32xi16> - func.return %0 : tensor<1x32xi16> -} - -// CHECK-LABEL: @broadcast_of_reshape -func.func @broadcast_of_reshape(%arg: tensor, - %shape: tensor<2xindex>) -> tensor { - %0 = "mhlo.dynamic_reshape"(%arg, %shape) - : (tensor, tensor<2xindex>) -> tensor - %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %shape) { - broadcast_dimensions = dense<[0, 1]> : tensor<2xi64> - } : (tensor, tensor<2xindex>) -> tensor - func.return %1 : tensor -} -// CHECK: [[RESHAPE:%.*]] = mhlo.dynamic_reshape -// CHECK: return [[RESHAPE]] - -// CHECK-LABEL: @permutation_broadcast_of_reshape -func.func @permutation_broadcast_of_reshape(%arg: tensor, - %shape: tensor<2xindex>) -> tensor { - %0 = "mhlo.dynamic_reshape"(%arg, %shape) - : (tensor, tensor<2xindex>) -> tensor - %1 = "mhlo.dynamic_broadcast_in_dim"(%0, %shape) { - broadcast_dimensions = dense<[1, 0]> : tensor<2xi64> - } : (tensor, tensor<2xindex>) -> tensor - func.return %1 : tensor -} -// CHECK: mhlo.dynamic_reshape -// CHECK: mhlo.dynamic_broadcast_in_dim - -// CHECK-LABEL: @reshape_of_same_shape_op_result -func.func @reshape_of_same_shape_op_result(%arg: tensor, - %shape: tensor<2xindex>) -> tensor { - %0 = "mhlo.dynamic_reshape"(%arg, %shape) - : (tensor, tensor<2xindex>) -> tensor - %1 = "mhlo.abs"(%0) : (tensor) -> tensor - %2 = "mhlo.dynamic_reshape"(%1, %shape) - : (tensor, tensor<2xindex>) -> tensor - func.return %2 : tensor -} -// CHECK: mhlo.dynamic_reshape -// CHECK-NEXT: mhlo.abs -// CHECK-NOT: mhlo.dynamic_reshape - -// CHECK-LABEL: @map_op_fold -func.func @map_op_fold(%arg: tensor, %arg1: tensor) -> tensor { - %0 = "mhlo.map"(%arg, %arg1) ({ - ^bb0(%a: tensor, %b: tensor): - "mhlo.return"(%b) : (tensor) -> () - }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor, tensor) -> tensor - func.return %0 : tensor -} -// CHECK: return %arg1 : tensor - -func.func @sort_drop_second_arg(%arg0: tensor<3xi32>, %arg1: tensor<3xi32>) -> tensor<3xi32> { - %0:2 = "mhlo.sort"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor): - %1 = "mhlo.compare"(%arg2, %arg3) { - comparison_direction = #mhlo - } : (tensor, tensor) -> tensor - "mhlo.return"(%1) : (tensor) -> () - }) {dimension = 0 : i64, - is_stable = false - } : (tensor<3xi32>, tensor<3xi32>) -> (tensor<3xi32>, tensor<3xi32>) - func.return %0#0 : tensor<3xi32> -} -// CHECK-LABEL: @sort_drop_second_arg -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] -// CHECK: %[[RES:.+]] = "mhlo.sort"(%[[ARG0]]) <{dimension = 0 : i64, is_stable = false}> ({ -// CHECK: ^bb0(%[[ARG2:.+]]: tensor, %[[ARG3:.+]]: tensor) -// CHECK: %[[CMP:.+]] = mhlo.compare GT, %[[ARG2]], %[[ARG3]] : (tensor, tensor) -> tensor -// CHECK: mhlo.return %[[CMP]] : tensor -// CHECK: }) : (tensor<3xi32>) -> tensor<3xi32> -// CHECK: return %[[RES]] : tensor<3xi32> - -func.func @sort_no_dim_provided(%arg0: tensor<3x5xi32>) -> tensor<3x5xi32> { - %0 = "mhlo.sort"(%arg0) ({ - ^bb0(%arg1: tensor, %arg2: tensor): - %1 = "mhlo.compare"(%arg1, %arg2) { - comparison_direction = #mhlo - } : (tensor, tensor) -> tensor - "mhlo.return"(%1) : (tensor) -> () - }) {dimension = -1 : i64, - is_stable = false - } : (tensor<3x5xi32>) -> tensor<3x5xi32> - func.return %0 : tensor<3x5xi32> -} -// CHECK-LABEL: @sort_no_dim_provided -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] -// CHECK: %[[RES:.+]] = "mhlo.sort"(%[[ARG0]]) -// CHECK: dimension = 1 : i64 -// CHECK: return %[[RES]] : tensor<3x5xi32> - -// CHECK-LABEL: @reshape_splat_of_bools -func.func public @reshape_splat_of_bools() -> tensor<2x1xi1> { - // CHECK: mhlo.constant dense : tensor<2x1xi1> - %0 = mhlo.constant dense : tensor<2xi1> - %1 = "mhlo.reshape"(%0) : (tensor<2xi1>) -> tensor<2x1xi1> - return %1 : tensor<2x1xi1> -} - -// CHECK-LABEL: @simplify_dynamic_gather_i64 -func.func @simplify_dynamic_gather_i64(%arg0: tensor<375682x256xf16>, %arg1: tensor<16x64xi64>) -> tensor<16x64x256xf16> { - %0 = "arith.constant"() {value = dense<[1, 256]> : tensor<2xi64>} : () -> tensor<2xi64> - %1 = "mhlo.dynamic_gather"(%arg0, %arg1, %0) {dimension_numbers = #mhlo.gather, indices_are_sorted = false} : (tensor<375682x256xf16>, tensor<16x64xi64>, tensor<2xi64>) -> tensor<16x64x256xf16> - // CHECK: %[[RET:.+]] = "mhlo.gather"(%arg0, %arg1) <{dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[1, 256]> : tensor<2xi64>}> : (tensor<375682x256xf16>, tensor<16x64xi64>) -> tensor<16x64x256xf16> - // CHECK: return %[[RET]] - return %1 : tensor<16x64x256xf16> -} - -// CHECK-LABEL: @simplify_dynamic_gather_i32 -func.func @simplify_dynamic_gather_i32(%arg0: tensor<375682x256xf16>, %arg1: tensor<16x64xi64>) -> tensor<16x64x256xf16> { - %0 = "arith.constant"() {value = dense<[1, 256]> : tensor<2xi32>} : () -> tensor<2xi32> - %1 = "mhlo.dynamic_gather"(%arg0, %arg1, %0) {dimension_numbers = #mhlo.gather, indices_are_sorted = false} : (tensor<375682x256xf16>, tensor<16x64xi64>, tensor<2xi32>) -> tensor<16x64x256xf16> - // CHECK: %[[RET:.+]] = "mhlo.gather"(%arg0, %arg1) <{dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[1, 256]> : tensor<2xi64>}> : (tensor<375682x256xf16>, tensor<16x64xi64>) -> tensor<16x64x256xf16> - // CHECK: return %[[RET]] - return %1 : tensor<16x64x256xf16> -} - -// CHECK-LABEL: @fold_reduce_window -func.func @fold_reduce_window(%arg0: tensor<1x1x20xf32>) -> tensor<1x1x20xf32> { - %cst_0 = mhlo.constant dense<0.000000e+00> : tensor - %r = "mhlo.reduce_window"(%arg0, %cst_0) ({ - ^bb0(%arg1: tensor, %arg2: tensor): - %s = mhlo.add %arg1, %arg2 : tensor - mhlo.return %s : tensor - }) { - padding = dense<0> : tensor<3x2xi64>, - window_dimensions = dense<1> : tensor<3xi64>, - window_strides = dense<1> : tensor<3xi64> - } : (tensor<1x1x20xf32>, tensor) -> tensor<1x1x20xf32> - func.return %r : tensor<1x1x20xf32> - - // CHECK: return %arg0 : tensor<1x1x20xf32> -} - -// CHECK-LABEL: @simplify_real_dynamic_slice_to_slice -func.func @simplify_real_dynamic_slice_to_slice(%arg0: tensor) -> tensor<1x4xf32> { - %0 = mhlo.constant dense<[0, 0]> : tensor<2xi32> - %1 = mhlo.constant dense<[1, 4]> : tensor<2xi32> - %2 = mhlo.constant dense<[1, 1]> : tensor<2xi32> - %3 = mhlo.real_dynamic_slice %arg0, %0, %1, %2 : (tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x4xf32> - // CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0) - // CHECK-DAG-SAME: start_indices = dense<[0, 0]> : tensor<2xi64> - // CHECK-DAG-SAME: limit_indices = dense<[1, 4]> : tensor<2xi64> - // CHECK-DAG-SAME: strides = dense<[1, 1]> : tensor<2xi64>} - // CHECK: return %[[RESULT]] : tensor<1x4xf32> - return %3 : tensor<1x4xf32> -} - -// CHECK-LABEL: @simplify_real_dynamic_slice_to_dynamic_slice -func.func @simplify_real_dynamic_slice_to_dynamic_slice(%arg0: tensor, %arg1: tensor<2xi32>) -> tensor<1x4xf32> { - %0 = mhlo.constant dense<[1, 4]> : tensor<2xi32> - %1 = mhlo.add %arg1, %0 : tensor<2xi32> - %2 = mhlo.constant dense<[1, 1]> : tensor<2xi32> - %3 = mhlo.real_dynamic_slice %arg0, %arg1, %1, %2 : (tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x4xf32> - return %3 : tensor<1x4xf32> - // CHECK: [[START_INDEX_0_1D:%.*]] = "mhlo.slice"(%arg1) <{limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> - // CHECK-NEXT: [[START_INDEX_0_0D:%.*]] = mhlo.reshape [[START_INDEX_0_1D]] : (tensor<1xi32>) -> tensor - // CHECK-NEXT: [[START_INDEX_1_1D:%.*]] = "mhlo.slice"(%arg1) <{limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<2xi32>) -> tensor<1xi32> - // CHECK-NEXT: [[START_INDEX_1_0D:%.*]] = mhlo.reshape [[START_INDEX_1_1D]] : (tensor<1xi32>) -> tensor - // CHECK-NEXT: [[RESULT:%.*]] = "mhlo.dynamic_slice"(%arg0, [[START_INDEX_0_0D]], [[START_INDEX_1_0D]]) <{ - // CHECK-SAME: slice_sizes = dense<[1, 4]> : tensor<2xi64> - // CHECK-SAME: }> : (tensor, tensor, tensor) -> tensor<1x4xf32> - // CHECK-NEXT: return [[RESULT]] : tensor<1x4xf32> -} diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/chlo_canonicalize.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/chlo_canonicalize.mlir new file mode 100644 index 00000000000000..2db8477e062c94 --- /dev/null +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/chlo_canonicalize.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-hlo-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s + +// CHECK-LABEL: constant_like_constant +func.func @constant_like_constant(%arg0: tensor<3x4xi32>) -> tensor<3x4xf32> { + // CHECK: chlo.constant dense<3.200000e+00> + %0 = "chlo.constant_like"(%arg0) { value = 3.2 : f32 } : (tensor<3x4xi32>) -> tensor<3x4xf32> + func.return %0 : tensor<3x4xf32> +} + +// CHECK-LABEL: constant_like_constant_dynamic +func.func @constant_like_constant_dynamic(%arg0: tensor) -> tensor { + // CHECK: chlo.constant_like + %0 = "chlo.constant_like"(%arg0) { value = 3.2 : f32 } : (tensor) -> tensor + func.return %0 : tensor +} diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/fold.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/fold.mlir new file mode 100644 index 00000000000000..4f9cf74a9c67f9 --- /dev/null +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/canonicalize/fold.mlir @@ -0,0 +1,1906 @@ +// RUN: mlir-hlo-opt %s -pass-pipeline='builtin.module(func.func(canonicalize))' | FileCheck %s + +//////// +// AbsOp + +// CHECK-LABEL: func @fold_abs +func.func @fold_abs() -> tensor<4xf32> { + %0 = mhlo.constant dense<-1.0> : tensor<4xf32> + %1 = "mhlo.abs"(%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.constant dense<1.000000e+00> : tensor<4xf32> + // CHECK-NOT: mhlo.abs + func.return %1 : tensor<4xf32> +} + +//////// +// AddOp + +// CHECK-LABEL: add_fold +func.func @add_fold() -> tensor<4xi64> { + %0 = mhlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64> + %1 = mhlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64> + // CHECK: mhlo.constant dense<[6, 8, 10, 12]> + %2 = "mhlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + func.return %2 : tensor<4xi64> +} + +// CHECK-LABEL: add_scalar_fold +func.func @add_scalar_fold() -> tensor<4xi64> { + %0 = mhlo.constant dense<1> : tensor<4xi64> + %1 = mhlo.constant dense<5> : tensor<4xi64> + // CHECK: mhlo.constant dense<6> + %2 = "mhlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + func.return %2 : tensor<4xi64> +} + +// CHECK-LABEL: add_fold_float +func.func @add_fold_float() -> tensor<4xf64> { + %0 = mhlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64> + %1 = mhlo.constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf64> + // CHECK: mhlo.constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]> + %2 = "mhlo.add"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) + func.return %2 : tensor<4xf64> +} + +// CHECK-LABEL: add_zero_int_fold +func.func @add_zero_int_fold(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> { + %0 = mhlo.constant dense<0> : tensor<2x2xi64> + %1 = "mhlo.add"(%arg0, %0) : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64> + // CHECK: return %arg0 : tensor<2x2xi64> + func.return %1 : tensor<2x2xi64> +} + +// CHECK-LABEL: add_zero_float_flod +func.func @add_zero_float_flod(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = mhlo.constant dense<0.0> : tensor<2x2xf32> + %1 = "mhlo.add"(%0, %arg0) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK: return %arg0 : tensor<2x2xf32> + func.return %1 : tensor<2x2xf32> +} + +//////// +// AndOp + +// CHECK-LABEL: func @fold_and_same +func.func @fold_and_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = "mhlo.and"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %arg0 + func.return %0 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_and_ones +func.func @fold_and_ones(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<-1> : tensor<4xi32> + %1 = "mhlo.and"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %arg0 + func.return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_and_zeros +func.func @fold_and_zeros(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<0> : tensor<4xi32> + %1 = "mhlo.and"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %0 + func.return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_and_constant +func.func @fold_and_constant(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<7> : tensor<4xi32> + // CHECK: mhlo.and + %1 = "mhlo.and"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + func.return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_and_constants +func.func @fold_and_constants() -> tensor<4xi32> { + %0 = mhlo.constant dense<[0, 1, 6, 3]> : tensor<4xi32> + %1 = mhlo.constant dense<[7, 3, 7, 2]> : tensor<4xi32> + %2 = "mhlo.and"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: %0 = mhlo.constant dense<[0, 1, 6, 2]> : tensor<4xi32> + // CHECK: return %0 + func.return %2 : tensor<4xi32> +} + +//////// +// BroadcastOp + +// CHECK-LABEL: func @broadcast_constant_fold_0d +func.func @broadcast_constant_fold_0d() -> tensor<1x64x224x224xf32> { + %cst = mhlo.constant dense<0.000000e+00> : tensor + %b = "mhlo.broadcast"(%cst) <{broadcast_sizes = dense<[1, 64, 224, 224]> : tensor<4xi64>}> : (tensor) -> tensor<1x64x224x224xf32> + func.return %b : tensor<1x64x224x224xf32> +} +// CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<0.000000e+00> : tensor<1x64x224x224xf32> +// CHECK-NEXT: return %[[CST]] : tensor<1x64x224x224xf32> + +// CHECK-LABEL: func @broadcast_constant_fold +func.func @broadcast_constant_fold() -> tensor<1x64x4x4xf32> { + %cst = mhlo.constant dense<0.000000e+00> : tensor<4x4xf32> + %b = "mhlo.broadcast"(%cst) <{broadcast_sizes = dense<[1, 64]> : tensor<2xi64>}> : (tensor<4x4xf32>) -> tensor<1x64x4x4xf32> + func.return %b : tensor<1x64x4x4xf32> +} +// CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<0.000000e+00> : tensor<1x64x4x4xf32> +// CHECK-NEXT: return %[[CST]] : tensor<1x64x4x4xf32> + +// CHECK-LABEL: func @broadcast_constant_fold_not_splat +func.func @broadcast_constant_fold_not_splat() -> tensor<1x64x2xf32> { + // CHECK: mhlo.constant + %cst = mhlo.constant dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32> + // CHECK: mhlo.broadcast + %b = "mhlo.broadcast"(%cst) <{broadcast_sizes = dense<[1, 64]> : tensor<2xi64>}> : (tensor<2xf32>) -> tensor<1x64x2xf32> + func.return %b : tensor<1x64x2xf32> +} + +// CHECK-LABEL: func @broadcast_constant_fold_complex +func.func @broadcast_constant_fold_complex() -> tensor<1x64x224x224xcomplex> { + %cst = mhlo.constant dense<(0.000000e+00,1.000000e+00)> : tensor> + %b = "mhlo.broadcast"(%cst) <{broadcast_sizes = dense<[1, 64, 224, 224]> : tensor<4xi64>}> : (tensor>) -> tensor<1x64x224x224xcomplex> + func.return %b : tensor<1x64x224x224xcomplex> +} +// CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<(0.000000e+00,1.000000e+00)> : tensor<1x64x224x224xcomplex> +// CHECK-NEXT: return %[[CST]] : tensor<1x64x224x224xcomplex> + +// CHECK-LABEL: func @broadcast_constant_fold_quantized_skipped +func.func @broadcast_constant_fold_quantized_skipped() -> tensor<1x64x224x224x!quant.uniform> { + %cst = mhlo.constant() {value = dense<2> : tensor} : () -> tensor> + %b = "mhlo.broadcast"(%cst) <{broadcast_sizes = dense<[1, 64, 224, 224]> : tensor<4xi64>}> : (tensor>) -> tensor<1x64x224x224x!quant.uniform> + func.return %b : tensor<1x64x224x224x!quant.uniform> +} +// CHECK-NEXT: %[[CST:.*]] = mhlo.constant() <{value = dense<2> : tensor}> : () -> tensor> +// CHECK-NEXT: %[[RES:.*]] = "mhlo.broadcast"(%[[CST:.*]]) <{broadcast_sizes = dense<[1, 64, 224, 224]> : tensor<4xi64>}> : (tensor>) -> tensor<1x64x224x224x!quant.uniform> +// CHECK-NEXT: return %[[RES:.*]] : tensor<1x64x224x224x!quant.uniform> + +// CHECK-LABEL: func @broadcast_in_dim_identity +func.func @broadcast_in_dim_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> { + // CHECK: return %arg0 + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}> : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32> + func.return %0 : tensor<2x3x4xf32> +} + +//////// +// BroadcastInDimOp + +// CHECK-LABEL: func @broadcast_in_dim_constant_fold_0d +func.func @broadcast_in_dim_constant_fold_0d() -> tensor<1x64x224x224xf32> { + // CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<0.000000e+00> : tensor<1x64x224x224xf32> + // CHECK-NEXT: return %[[CST]] : tensor<1x64x224x224xf32> + %cst = mhlo.constant dense<0.000000e+00> : tensor + %b = "mhlo.broadcast_in_dim"(%cst) <{broadcast_dimensions = dense<[]> : tensor<0xi64>}> : (tensor) -> tensor<1x64x224x224xf32> + func.return %b : tensor<1x64x224x224xf32> +} + +// CHECK-LABEL: func @broadcast_in_dim_constant_fold +func.func @broadcast_in_dim_constant_fold() -> tensor<1x64x4x4xf32> { + // CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<0.000000e+00> : tensor<1x64x4x4xf32> + // CHECK-NEXT: return %[[CST]] : tensor<1x64x4x4xf32> + %cst = mhlo.constant dense<0.000000e+00> : tensor<4x4xf32> + %b = "mhlo.broadcast_in_dim"(%cst) <{broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>}> : (tensor<4x4xf32>) -> tensor<1x64x4x4xf32> + func.return %b : tensor<1x64x4x4xf32> +} + +// CHECK-LABEL: func @broadcast_in_dim_constant_fold_complex +func.func @broadcast_in_dim_constant_fold_complex() -> tensor<1x64x224x224xcomplex> { + // CHECK-NEXT: %[[CST:.*]] = mhlo.constant dense<(0.000000e+00,1.000000e+00)> : tensor<1x64x224x224xcomplex> + // CHECK-NEXT: return %[[CST]] : tensor<1x64x224x224xcomplex> + %cst = mhlo.constant dense<(0.000000e+00,1.000000e+00)> : tensor> + %b = "mhlo.broadcast_in_dim"(%cst) <{broadcast_dimensions = dense<[]> : tensor<0xi64>}> : (tensor>) -> tensor<1x64x224x224xcomplex> + func.return %b : tensor<1x64x224x224xcomplex> +} + +// CHECK-LABEL: func @broadcast_in_dim_constant_fold_quantized_skipped +func.func @broadcast_in_dim_constant_fold_quantized_skipped(%arg0: tensor<1x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> { + // CHECK-NEXT: %[[RES:.*]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> + // CHECK-NEXT: return %[[RES:.*]] : tensor<2x2x!quant.uniform> + %b = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}> : (tensor<1x2x!quant.uniform>) -> tensor<2x2x!quant.uniform> + func.return %b : tensor<2x2x!quant.uniform> +} + +//////// +// CaseOp + +// CHECK-LABEL: func @fold_case( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]] +// CHECK-SAME: ) +func.func @fold_case(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + // CHECK-NOT: mhlo.case + // CHECK: return %[[ARG1]] + %c1 = mhlo.constant dense<1> : tensor + %0 = "mhlo.case"(%c1) ({ + "mhlo.return"(%arg0) : (tensor) -> () + }, { + "mhlo.return"(%arg1) : (tensor) -> () + }, { + "mhlo.return"(%arg2) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: func @fold_case_negative_index( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]] +// CHECK-SAME: ) +func.func @fold_case_negative_index(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + // CHECK-NOT: mhlo.case + // CHECK: return %[[ARG2]] + %m1000 = mhlo.constant dense<-1000> : tensor + %0 = "mhlo.case"(%m1000) ({ + "mhlo.return"(%arg0) : (tensor) -> () + }, { + "mhlo.return"(%arg1) : (tensor) -> () + }, { + "mhlo.return"(%arg2) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: func @fold_case_oob_index( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]] +// CHECK-SAME: ) +func.func @fold_case_oob_index(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { + // CHECK-NOT: mhlo.case + // CHECK: return %[[ARG2]] + %c1000 = mhlo.constant dense<1000> : tensor + %0 = "mhlo.case"(%c1000) ({ + "mhlo.return"(%arg0) : (tensor) -> () + }, { + "mhlo.return"(%arg1) : (tensor) -> () + }, { + "mhlo.return"(%arg2) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0 : tensor +} + +//////// +// ClampOp + +// CHECK-LABEL: clamp_scalar_fold +func.func @clamp_scalar_fold() -> tensor<5xi64> { + %0 = mhlo.constant dense<149> : tensor + %1 = mhlo.constant dense<[-1, 100, 200, 0, 149]> : tensor<5xi64> + %2 = mhlo.constant dense<0> : tensor + // CHECK{LITERAL}: mhlo.constant dense<[0, 100, 149, 0, 149]> + // CHECK-NOT: mhlo.clamp + %3 = mhlo.clamp %2, %1, %0 : (tensor, tensor<5xi64>, tensor) -> tensor<5xi64> + return %3 : tensor<5xi64> +} + +// CHECK-LABEL: clamp_fold +func.func @clamp_fold() -> tensor<5xi64> { + %0 = mhlo.constant dense<[149, 101, -1, 30, 50]> : tensor<5xi64> + %1 = mhlo.constant dense<[-1, 100, 200, 0, 149]> : tensor<5xi64> + %2 = mhlo.constant dense<[0, 10, -10, 10, -100]> : tensor<5xi64> + // CHECK{LITERAL}: mhlo.constant dense<[0, 100, -1, 10, 50]> + // CHECK-NOT: mhlo.clamp + %3 = mhlo.clamp %2, %1, %0 : (tensor<5xi64>, tensor<5xi64>, tensor<5xi64>) -> tensor<5xi64> + return %3 : tensor<5xi64> +} + +// CHECK-LABEL: clamp_fold_float +func.func @clamp_fold_float() -> tensor<6xf32> { + %0 = mhlo.constant dense<[5.0, 66.0, 0xFFFFFFFF, -2.0, 0xFFFFFFFF, 6.0]> : tensor<6xf32> + %1 = mhlo.constant dense<[5.0, 3.0, 2.0, 0xFFFFFFFF, 0xFFFFFFFF, 4.0]> : tensor<6xf32> + %2 = mhlo.constant dense<[5.0, 1.0, 1.0, 0xFFFFFFFF, 0xFFFFFFFF, 5.0]> : tensor<6xf32> + // CHECK{LITERAL}: mhlo.constant dense<[5.000000e+00, 3.000000e+00, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 5.000000e+00] + // CHECK-NOT: mhlo.clamp + %3 = mhlo.clamp %2, %1, %0 : (tensor<6xf32>, tensor<6xf32>, tensor<6xf32>) -> tensor<6xf32> + return %3 : tensor<6xf32> +} + +//////// +// CompareOp + +// CHECK-LABEL: fold_sign_posi +func.func @fold_sign_posi() -> tensor { + // CHECK: %0 = mhlo.constant dense<1> : tensor + %0 = mhlo.constant dense<2> : tensor + %1 = "mhlo.sign"(%0) : (tensor) -> tensor + func.return %1 : tensor +} + +// CHECK-LABEL: fold_sign_negi +func.func @fold_sign_negi() -> tensor { + // CHECK: %0 = mhlo.constant dense<-1> : tensor + %0 = mhlo.constant dense<-2> : tensor + %1 = "mhlo.sign"(%0) : (tensor) -> tensor + func.return %1 : tensor +} + +// CHECK-LABEL: fold_sign_posf +func.func @fold_sign_posf() -> tensor { + // CHECK: %0 = mhlo.constant dense<1.000000e+00> : tensor + %0 = mhlo.constant dense<2.000000e+00> : tensor + %1 = "mhlo.sign"(%0) : (tensor) -> tensor + func.return %1 : tensor +} + +// CHECK-LABEL: fold_sign_negf +func.func @fold_sign_negf() -> tensor { + // CHECK: %0 = mhlo.constant dense<-1.000000e+00> : tensor + %0 = mhlo.constant dense<-2.000000e+00> : tensor + %1 = "mhlo.sign"(%0) : (tensor) -> tensor + func.return %1 : tensor +} + +// CHECK-LABEL: fold_sign_negzf +func.func @fold_sign_negzf() -> tensor { + // CHECK: %0 = mhlo.constant dense<-0.000000e+00> : tensor + %0 = mhlo.constant dense<-0.000000e+00> : tensor + %1 = "mhlo.sign"(%0) : (tensor) -> tensor + func.return %1 : tensor +} + +// CHECK-LABEL: fold_compare_same_eq +func.func @fold_compare_same_eq(%arg0: tensor) -> tensor { + // CHECK: %0 = mhlo.constant dense : tensor + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: fold_compare_same_le +func.func @fold_compare_same_le(%arg0: tensor) -> tensor { + // CHECK: %0 = mhlo.constant dense : tensor + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: fold_compare_same_ge +func.func @fold_compare_same_ge(%arg0: tensor) -> tensor { + // CHECK: %0 = mhlo.constant dense : tensor + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %0 : tensor +} +// CHECK-LABEL: fold_compare_same_ne +func.func @fold_compare_same_ne(%arg0: tensor) -> tensor { + // CHECK: %0 = mhlo.constant dense : tensor + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: fold_compare_same_lt +func.func @fold_compare_same_lt(%arg0: tensor) -> tensor { + // CHECK: %0 = mhlo.constant dense : tensor + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: fold_compare_same_gt +func.func @fold_compare_same_gt(%arg0: tensor) -> tensor { + // CHECK: %0 = mhlo.constant dense : tensor + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// Address NaN != NaN. +// CHECK-LABEL: dont_fold_compare_same_eq_float +func.func @dont_fold_compare_same_eq_float(%arg0: tensor) -> tensor { + // CHECK: %0 = mhlo.compare EQ, %arg0, %arg0 : (tensor, tensor) -> tensor + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// Address NaN != NaN for complex types. +// CHECK-LABEL: dont_fold_compare_same_eq_complex +func.func @dont_fold_compare_same_eq_complex(%arg0: tensor>) -> tensor { + // CHECK: %0 = mhlo.compare EQ, %arg0, %arg0 : (tensor>, tensor>) -> tensor + %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = #mhlo} : (tensor>, tensor>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: fold_compare_false_eq +func.func @fold_compare_false_eq() -> tensor { + %0 = mhlo.constant dense<0> : tensor + %1 = mhlo.constant dense<1> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} +// CHECK-LABEL: fold_compare_true_eq +func.func @fold_compare_true_eq() -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.constant dense<1> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: fold_compare_bools_true_eq +func.func @fold_compare_bools_true_eq(%arg : tensor) -> tensor { + %1 = mhlo.constant dense : tensor + // CHECK: return %arg + %2 = "mhlo.compare"(%arg, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: compare_i1_as_unsigned +func.func @compare_i1_as_unsigned(%arg : tensor) -> tensor { + %true = mhlo.constant dense : tensor + %false = mhlo.constant dense : tensor + // CHECK: %[[FALSE:.*]] = mhlo.constant dense + // CHECK: return %[[FALSE]] + %2 = "mhlo.compare"(%true, %false) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_eq_float +func.func @fold_compare_false_eq_float() -> tensor { + %0 = mhlo.constant dense<0.> : tensor + %1 = mhlo.constant dense<1.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_eq_float +func.func @fold_compare_true_eq_float() -> tensor { + %0 = mhlo.constant dense<1.> : tensor + %1 = mhlo.constant dense<1.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_ne +func.func @fold_compare_false_ne() -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.constant dense<1> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_ne +func.func @fold_compare_true_ne() -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.constant dense<0> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: fold_compare_bools_false_ne +func.func @fold_compare_bools_false_ne(%arg : tensor) -> tensor { + %1 = mhlo.constant dense : tensor + // CHECK: return %arg + %2 = "mhlo.compare"(%arg, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_ne_float +func.func @fold_compare_false_ne_float() -> tensor { + %0 = mhlo.constant dense<1.> : tensor + %1 = mhlo.constant dense<1.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_ne_float +func.func @fold_compare_true_ne_float() -> tensor { + %0 = mhlo.constant dense<0.> : tensor + %1 = mhlo.constant dense<1.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_lt +func.func @fold_compare_false_lt() -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.constant dense<1> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_lt +func.func @fold_compare_true_lt() -> tensor { + %0 = mhlo.constant dense<0> : tensor + %1 = mhlo.constant dense<1> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_lt_float +func.func @fold_compare_false_lt_float() -> tensor { + %0 = mhlo.constant dense<1.> : tensor + %1 = mhlo.constant dense<1.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_lt_float +func.func @fold_compare_true_lt_float() -> tensor { + %0 = mhlo.constant dense<0.> : tensor + %1 = mhlo.constant dense<1.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_le +func.func @fold_compare_false_le() -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.constant dense<0> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_le +func.func @fold_compare_true_le() -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.constant dense<1> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_le_float +func.func @fold_compare_false_le_float() -> tensor { + %0 = mhlo.constant dense<1.> : tensor + %1 = mhlo.constant dense<0.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_le_float +func.func @fold_compare_true_le_float() -> tensor { + %0 = mhlo.constant dense<1.> : tensor + %1 = mhlo.constant dense<1.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_gt +func.func @fold_compare_false_gt() -> tensor { + %0 = mhlo.constant dense<0> : tensor + %1 = mhlo.constant dense<0> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_gt +func.func @fold_compare_true_gt() -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = mhlo.constant dense<0> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_gt_float +func.func @fold_compare_false_gt_float() -> tensor { + %0 = mhlo.constant dense<0.> : tensor + %1 = mhlo.constant dense<0.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_gt_float +func.func @fold_compare_true_gt_float() -> tensor { + %0 = mhlo.constant dense<1.> : tensor + %1 = mhlo.constant dense<0.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_ge +func.func @fold_compare_false_ge() -> tensor { + %0 = mhlo.constant dense<0> : tensor + %1 = mhlo.constant dense<1> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_ge +func.func @fold_compare_true_ge() -> tensor { + %0 = mhlo.constant dense<0> : tensor + %1 = mhlo.constant dense<0> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: fold_compare_false_ge_float +func.func @fold_compare_false_ge_float() -> tensor { + %0 = mhlo.constant dense<0.> : tensor + %1 = mhlo.constant dense<1.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +// CHECK-LABEL: fold_compare_true_ge_float +func.func @fold_compare_true_ge_float() -> tensor { + %0 = mhlo.constant dense<0.> : tensor + %1 = mhlo.constant dense<0.> : tensor + // CHECK: %0 = mhlo.constant dense : tensor + %2 = "mhlo.compare"(%0, %1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor + func.return %2 : tensor +} + +//////// +// ConvertOp + +func.func @fold_fptosi() -> tensor { + %0 = mhlo.constant dense<65535.000000e+00> : tensor + // CHECK: mhlo.constant dense<32767> : tensor + %1 = "mhlo.convert"(%0) : (tensor) -> tensor + func.return %1 : tensor +} + +func.func @fold_fptosi_rounding() -> tensor { + %0 = mhlo.constant dense<-1.5> : tensor + // CHECK: mhlo.constant dense<-1> : tensor + %1 = "mhlo.convert"(%0) : (tensor) -> tensor + func.return %1 : tensor +} + +func.func @fold_fptoui() -> tensor { + %0 = mhlo.constant dense<-1.000000e+00> : tensor + // CHECK: mhlo.constant dense<0> : tensor + %1 = "mhlo.convert"(%0) : (tensor) -> tensor + func.return %1 : tensor +} + +func.func @fold_sitofp() -> tensor { + %0 = mhlo.constant dense<-1> : tensor + // CHECK: mhlo.constant dense<-1.000000e+00> : tensor + %1 = "mhlo.convert"(%0) : (tensor) -> tensor + func.return %1 : tensor +} + +func.func @fold_uitofp() -> tensor { + %0 = mhlo.constant dense<65535> : tensor + // CHECK: mhlo.constant dense<6.553500e+04> : tensor + %1 = "mhlo.convert"(%0) : (tensor) -> tensor + func.return %1 : tensor +} + +func.func @fold_uitoui() -> tensor { + %0 = mhlo.constant dense<65535> : tensor + // CHECK: mhlo.constant dense<65535> : tensor + %1 = "mhlo.convert"(%0) : (tensor) -> tensor + func.return %1 : tensor +} + +func.func @fold_uitosi() -> tensor { + %0 = mhlo.constant dense<65535> : tensor + // CHECK: mhlo.constant dense<65535> : tensor + %1 = "mhlo.convert"(%0) : (tensor) -> tensor + func.return %1 : tensor +} + +func.func @fold_sitoui() -> tensor { + %0 = mhlo.constant dense<-1> : tensor + // CHECK: mhlo.constant dense<4294967295> : tensor + %1 = "mhlo.convert"(%0) : (tensor) -> tensor + func.return %1 : tensor +} + +func.func @fold_sitosi() -> tensor { + %0 = mhlo.constant dense<-1> : tensor + // CHECK: mhlo.constant dense<-1> : tensor + %1 = "mhlo.convert"(%0) : (tensor) -> tensor + func.return %1 : tensor +} + +func.func @fold_predtosi() -> tensor { + %0 = mhlo.constant dense : tensor + // CHECK: mhlo.constant dense<0> : tensor + %1 = "mhlo.convert"(%0) : (tensor) -> tensor + func.return %1 : tensor +} + +func.func @not_fold_itouq() -> tensor> { + // CHECK: mhlo.constant dense<1> : tensor + %0 = mhlo.constant dense<1> : tensor + %1 = "mhlo.convert"(%0) : (tensor) -> tensor> + func.return %1 : tensor> +} + +//////// +// CosineOp + +// CHECK-LABEL: func @fold_cosine +func.func @fold_cosine() -> tensor<4xf32> { + %0 = mhlo.constant dense<2.0> : tensor<4xf32> + %1 = "mhlo.cosine"(%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.constant dense<-0.416146845> : tensor<4xf32> + // CHECK-NOT: mhlo.cosine + func.return %1 : tensor<4xf32> +} + +//////// +// DivideOp + +// CHECK-LABEL: divide_scalar_fold +func.func @divide_scalar_fold() -> tensor<4xi64> { + %0 = mhlo.constant dense<7> : tensor<4xi64> + %1 = mhlo.constant dense<5> : tensor<4xi64> + // CHECK: mhlo.constant dense<1> + %2 = "mhlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + func.return %2 : tensor<4xi64> +} + +// CHECK-LABEL: divide_scalar_fold_by_zero +func.func @divide_scalar_fold_by_zero() -> tensor<4xi64> { + %0 = mhlo.constant dense<7> : tensor<4xi64> + %1 = mhlo.constant dense<0> : tensor<4xi64> + // CHECK: mhlo.divide + %2 = "mhlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + func.return %2 : tensor<4xi64> +} + +// CHECK-LABEL: divide_fold_int +func.func @divide_fold_int() -> tensor<4xi32> { + %0 = mhlo.constant dense<[1, -2, 3, 4]> : tensor<4xi32> + %1 = mhlo.constant dense<[-1, -2, -3, 2]> : tensor<4xi32> + // CHECK: %[[RESULT:.+]] = mhlo.constant dense<[-1, 1, -1, 2]> + %2 = "mhlo.divide"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>) + // CHECK: return %[[RESULT]] + func.return %2 : tensor<4xi32> +} + +// CHECK-LABEL: divide_fold_unsigned +func.func @divide_fold_unsigned() -> tensor<4xui32> { + %0 = mhlo.constant dense<[1, -2, 3, 4]> : tensor<4xi32> + %1 = "mhlo.convert"(%0) : (tensor<4xi32>) -> tensor<4xui32> + %2 = mhlo.constant dense<[-1, -2, -3, 2]> : tensor<4xi32> + %3 = "mhlo.convert"(%2) : (tensor<4xi32>) -> tensor<4xui32> + // CHECK: %[[RESULT:.+]] = mhlo.constant dense<[0, 1, 0, 2]> + %4 = "mhlo.divide"(%1, %3) : (tensor<4xui32>, tensor<4xui32>) -> (tensor<4xui32>) + // CHECK: return %[[RESULT]] + func.return %4 : tensor<4xui32> +} + +// CHECK-LABEL: divide_fold_float +func.func @divide_fold_float() -> tensor<4xf64> { + %0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64> + %1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64> + // CHECK: mhlo.constant dense<[1.000000e+00, 2.200000e+01, 2.500000e+00, 2.500000e-01]> + %2 = "mhlo.divide"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>) + func.return %2 : tensor<4xf64> +} + +// CHECK-LABEL: divide_fold_by_zero +func.func @divide_fold_by_zero() -> tensor<4xi64> { + %0 = mhlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64> + %1 = mhlo.constant dense<[1, 2, 3, 0]> : tensor<4xi64> + // CHECK: mhlo.divide + %2 = "mhlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + func.return %2 : tensor<4xi64> +} + +//////// +// DynamicPadOp + +// CHECK-LABEL: @dynamic_pad_identity_fold +func.func @dynamic_pad_identity_fold(%arg0: tensor<5x7xf32>) -> tensor<11x15xf32> { + %0 = arith.constant dense<0.0> : tensor + %1 = arith.constant dense<1> : tensor<2xi32> + %2 = arith.constant dense<1> : tensor<2xi32> + %3 = arith.constant dense<1> : tensor<2xi32> + // CHECK: %[[CST:.+]] = arith.constant dense<0.000000e+00> : tensor + // CHECK: %[[PAD:.+]] = "mhlo.pad"(%arg0, %[[CST]]) + // CHECK-SAME: edge_padding_high = dense<1> : tensor<2xi64> + // CHECK-SAME: edge_padding_low = dense<1> : tensor<2xi64> + // CHECK-SAME: interior_padding = dense<1> : tensor<2xi64>} + // CHECK-SAME: (tensor<5x7xf32>, tensor) -> tensor<11x15xf32> + %4 = "mhlo.dynamic_pad"(%arg0, %0, %1, %2, %3) { + } : (tensor<5x7xf32>, tensor, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<11x15xf32> + // return %[[PAD]] + func.return %4 : tensor<11x15xf32> +} + +//////// +// ExponentialOp + +// CHECK-LABEL: func @fold_exponential +func.func @fold_exponential() -> tensor<4xf32> { + %0 = mhlo.constant dense<2.0> : tensor<4xf32> + %1 = "mhlo.exponential"(%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.constant dense<7.3890562> : tensor<4xf32> + // CHECK-NOT: mhlo.exponential + func.return %1 : tensor<4xf32> +} + +//////// +// GetDimensionSizeOp / SetDimensionSizeOp + +// CHECK-LABEL: func @fold_get_dimension_size +func.func @fold_get_dimension_size(%I: tensor<1x128x512xf32>) -> tensor { + %size = "mhlo.get_dimension_size"(%I) <{dimension = 2 : i64}> : (tensor<1x128x512xf32>) -> tensor + func.return %size : tensor + // CHECK-NEXT: %[[C:.*]] = mhlo.constant dense<512> : tensor + // CHECK-NEXT: return %[[C]] +} + +// CHECK-LABEL: func @fold_get_dimension_size_fail +func.func @fold_get_dimension_size_fail(%I: tensor<1x128x?xf32>) -> tensor { + // CHECK: "mhlo.get_dimension_size" + %size = "mhlo.get_dimension_size"(%I) <{dimension = 2 : i64}> : (tensor<1x128x?xf32>) -> tensor + func.return %size : tensor +} + +// CHECK-LABEL: func @fold_set_dimension_size +// CHECK-SAME: (%[[I:.*]]: tensor<1x128x512xf32>) +func.func @fold_set_dimension_size(%I: tensor<1x128x512xf32>) -> tensor<1x128x512xf32> { + %dim = mhlo.constant dense<512> : tensor + %result = "mhlo.set_dimension_size"(%I, %dim) {dimension = 2 : i64} : (tensor<1x128x512xf32>, tensor) -> tensor<1x128x512xf32> + func.return %result : tensor<1x128x512xf32> + + // CHECK-NEXT: return %[[I]] +} + +//////// +// IfOp + +// CHECK-LABEL: func @fold_if_true( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] +// CHECK-SAME: ) +func.func @fold_if_true(%arg0 : tensor, %arg1 : tensor) -> tensor { + // CHECK-NOT: mhlo.if + // CHECK: return %[[ARG0]] + %true = mhlo.constant dense : tensor + %0 = "mhlo.if"(%true) ({ + "mhlo.return"(%arg0) : (tensor) -> () + }, { + "mhlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: func @fold_if_false( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]] +// CHECK-SAME: ) +func.func @fold_if_false(%arg0 : tensor, %arg1 : tensor) -> tensor { + // CHECK-NOT: mhlo.if + // CHECK: return %[[ARG1]] + %false = mhlo.constant dense : tensor + %0 = "mhlo.if"(%false) ({ + "mhlo.return"(%arg0) : (tensor) -> () + }, { + "mhlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0 : tensor +} + +//////// +// LogOp + +// CHECK-LABEL: func @fold_log +func.func @fold_log() -> tensor<4xf32> { + %0 = mhlo.constant dense<2.0> : tensor<4xf32> + %1 = "mhlo.log"(%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.constant dense<0.693147182> : tensor<4xf32> + // CHECK-NOT: mhlo.log + func.return %1 : tensor<4xf32> +} + +//////// +// LogisticOp + +// CHECK-LABEL: func @fold_logistic +func.func @fold_logistic() -> tensor<4xf32> { + %0 = mhlo.constant dense<2.0> : tensor<4xf32> + %1 = "mhlo.logistic"(%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.constant dense<0.880797088> : tensor<4xf32> + // CHECK-NOT: mhlo.logistic + func.return %1 : tensor<4xf32> +} + +//////// +// MaxOp + +// CHECK-LABEL: max_scalar_fold +func.func @max_scalar_fold() -> tensor<4xi64> { + %0 = mhlo.constant dense<7> : tensor<4xi64> + %1 = mhlo.constant dense<-5> : tensor<4xi64> + // CHECK: %[[RESULT:.+]] = mhlo.constant dense<7> + %2 = "mhlo.maximum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + // CHECK: return %[[RESULT]] + func.return %2 : tensor<4xi64> +} + +// CHECK-LABEL: max_scalar_fold_unsigned +func.func @max_scalar_fold_unsigned() -> tensor<4xui32> { + %0 = mhlo.constant dense<7> : tensor<4xui32> + %1 = mhlo.constant dense<-5> : tensor<4xi32> + %2 = "mhlo.convert"(%1) : (tensor<4xi32>) -> tensor<4xui32> + // CHECK: %[[RESULT:.+]] = mhlo.constant dense<4294967291> + %3 = "mhlo.maximum"(%0, %2) : (tensor<4xui32>, tensor<4xui32>) -> (tensor<4xui32>) + // CHECK: return %[[RESULT]] + func.return %3 : tensor<4xui32> +} + +// CHECK-LABEL: max_fold_float +func.func @max_fold_float() -> tensor<6xf32> { + %0 = mhlo.constant dense<[5.0, 66.0, 0xFFFFFFFF, -2.0, 0xFFFFFFFF, 1.0]> : tensor<6xf32> + %1 = mhlo.constant dense<[5.0, 3.0, 2.0, 0xFFFFFFFF, 0xFFFFFFFF, 4.0]> : tensor<6xf32> + // CHECK{LITERAL}: mhlo.constant dense<[5.000000e+00, 6.600000e+01, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 4.000000e+00] + %2 = "mhlo.maximum"(%0, %1) : (tensor<6xf32>, tensor<6xf32>) -> (tensor<6xf32>) + func.return %2 : tensor<6xf32> +} + +//////// +// MinOp + +// CHECK-LABEL: min_scalar_fold +func.func @min_scalar_fold() -> tensor<4xi64> { + %0 = mhlo.constant dense<7> : tensor<4xi64> + %1 = mhlo.constant dense<-5> : tensor<4xi64> + // CHECK: mhlo.constant dense<-5> + %2 = "mhlo.minimum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + func.return %2 : tensor<4xi64> +} + +// CHECK-LABEL: min_fold_float +func.func @min_fold_float() -> tensor<6xf32> { + %0 = mhlo.constant dense<[5.0, 66.0, 0xFFFFFFFF, -2.0, 0xFFFFFFFF, 1.0]> : tensor<6xf32> + %1 = mhlo.constant dense<[5.0, 3.0, 2.0, 0xFFFFFFFF, 0xFFFFFFFF, 4.0]> : tensor<6xf32> + // CHECK{LITERAL}: mhlo.constant dense<[5.000000e+00, 3.000000e+00, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 1.000000e+00] + %2 = "mhlo.minimum"(%0, %1) : (tensor<6xf32>, tensor<6xf32>) -> (tensor<6xf32>) + func.return %2 : tensor<6xf32> +} + +//////// +// MapOp + +// CHECK-LABEL: @map_op_fold +func.func @map_op_fold(%arg: tensor, %arg1: tensor) -> tensor { + %0 = "mhlo.map"(%arg, %arg1) ({ + ^bb0(%a: tensor, %b: tensor): + "mhlo.return"(%b) : (tensor) -> () + }) {dimensions = dense<[0]> : tensor<1xi64>} : (tensor, tensor) -> tensor + // CHECK: return %arg1 : tensor + func.return %0 : tensor +} + +//////// +// MultiplyOp + +// CHECK-LABEL: multiply_scalar_fold +func.func @multiply_scalar_fold() -> tensor<4xi64> { + %0 = mhlo.constant dense<5> : tensor<4xi64> + %1 = mhlo.constant dense<3> : tensor<4xi64> + // CHECK: mhlo.constant dense<15> + %2 = "mhlo.multiply"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + func.return %2 : tensor<4xi64> +} + +// CHECK-LABEL: mul_one_int_fold +func.func @mul_one_int_fold(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> { + %0 = mhlo.constant dense<1> : tensor<2x2xi64> + %1 = "mhlo.multiply"(%arg0, %0) : (tensor<2x2xi64>, tensor<2x2xi64>) -> tensor<2x2xi64> + // CHECK: return %arg0 : tensor<2x2xi64> + func.return %1 : tensor<2x2xi64> +} + +// CHECK-LABEL: mul_one_int8_fold +func.func @mul_one_int8_fold(%arg0: tensor<2x2xi8>) -> tensor<2x2xi8> { + %0 = mhlo.constant dense<1> : tensor<2x2xi8> + %1 = "mhlo.multiply"(%arg0, %0) : (tensor<2x2xi8>, tensor<2x2xi8>) -> tensor<2x2xi8> + // CHECK: return %arg0 : tensor<2x2xi8> + func.return %1 : tensor<2x2xi8> +} + +// CHECK-LABEL: mul_one_float_flod +func.func @mul_one_float_flod(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = mhlo.constant dense<1.0> : tensor<2x2xf32> + %1 = "mhlo.multiply"(%0, %arg0) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK: return %arg0 : tensor<2x2xf32> + func.return %1 : tensor<2x2xf32> +} + +// CHECK-LABEL: mul_one_fp16_flod +func.func @mul_one_fp16_flod(%arg0: tensor<2x2xf16>) -> tensor<2x2xf16> { + %0 = mhlo.constant dense<1.0> : tensor<2x2xf16> + %1 = "mhlo.multiply"(%0, %arg0) : (tensor<2x2xf16>, tensor<2x2xf16>) -> tensor<2x2xf16> + // CHECK: return %arg0 : tensor<2x2xf16> + func.return %1 : tensor<2x2xf16> +} + +// CHECK-LABEL: mul_one_bf16_flod +func.func @mul_one_bf16_flod(%arg0: tensor<2x2xbf16>) -> tensor<2x2xbf16> { + %0 = mhlo.constant dense<1.0> : tensor<2x2xbf16> + %1 = "mhlo.multiply"(%0, %arg0) : (tensor<2x2xbf16>, tensor<2x2xbf16>) -> tensor<2x2xbf16> + // CHECK: return %arg0 : tensor<2x2xbf16> + func.return %1 : tensor<2x2xbf16> +} + +//////// +// NegateOp + +// CHECK-LABEL: func @fold_negate_int +func.func @fold_negate_int() -> tensor<4xi32> { + %0 = mhlo.constant dense<[0, 1, 6, -3]> : tensor<4xi32> + // CHECK: mhlo.constant dense<[0, -1, -6, 3]> + %1 = "mhlo.negate"(%0) : (tensor<4xi32>) -> tensor<4xi32> + func.return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_negate_float +func.func @fold_negate_float() -> tensor<4xf32> { + %0 = mhlo.constant dense<[0., 1., 6., -3.]> : tensor<4xf32> + // CHECK: mhlo.constant dense<[-0.000000e+00, -1.000000e+00, -6.000000e+00, 3.000000e+00]> + %1 = "mhlo.negate"(%0) : (tensor<4xf32>) -> tensor<4xf32> + func.return %1 : tensor<4xf32> +} + +//////// +// NotOp + +// CHECK-LABEL func @fold_not() +func.func @fold_not() -> tensor<2x2xi1> { + %0 = mhlo.constant dense<[[true, false], [true, false]]> : tensor<2x2xi1> + // CHECK{LITERAL}: mhlo.constant dense<[[false, true], [false, true]]> : tensor<2x2xi1> + %1 = "mhlo.not"(%0) : (tensor<2x2xi1>) -> tensor<2x2xi1> + func.return %1 : tensor<2x2xi1> +} + +// CHECK-LABEL func @fold_not_i32() +func.func @fold_not_i32() -> tensor<2x2xi32> { + %0 = mhlo.constant dense<[[42, -12], [1, 0]]> : tensor<2x2xi32> + // CHECK{LITERAL}: mhlo.constant dense<[[-43, 11], [-2, -1]]> : tensor<2x2xi32> + %1 = "mhlo.not"(%0) : (tensor<2x2xi32>) -> tensor<2x2xi32> + func.return %1 : tensor<2x2xi32> +} + +// CHECK-LABEL: func @not_fold_log_neg_constants +func.func @not_fold_log_neg_constants() -> tensor<4xf32> { + %0 = mhlo.constant dense<-1.0> : tensor<4xf32> + %1 = "mhlo.log"(%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.constant dense<-1.000000e+00> : tensor<4xf32> + // CHECK: mhlo.log + func.return %1 : tensor<4xf32> +} + +//////// +// OrOp + +// CHECK-LABEL: func @fold_or_same +func.func @fold_or_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = "mhlo.or"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %arg0 + func.return %0 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_or_ones +func.func @fold_or_ones(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<-1> : tensor<4xi32> + %1 = "mhlo.or"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %0 + func.return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_or_zeros +func.func @fold_or_zeros(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<0> : tensor<4xi32> + %1 = "mhlo.or"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %arg0 + func.return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_or_constant +func.func @fold_or_constant(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<7> : tensor<4xi32> + // CHECK: mhlo.or + %1 = "mhlo.or"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + func.return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_or_zeros_right +func.func @fold_or_zeros_right(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<0> : tensor<4xi32> + %1 = "mhlo.or"(%arg0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %arg0 + func.return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_or_zeros_constants +func.func @fold_or_zeros_constants() -> tensor<4xi32> { + %0 = mhlo.constant dense<[0, 1, 6, 3]> : tensor<4xi32> + %1 = mhlo.constant dense<[7, 3, 7, 2]> : tensor<4xi32> + %2 = "mhlo.or"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: %0 = mhlo.constant dense<[7, 3, 7, 3]> : tensor<4xi32> + // CHECK: return %0 + func.return %2 : tensor<4xi32> +} + +//////// +// PadOp + +// CHECK-LABEL: @pad_complex_fold +func.func @pad_complex_fold() -> tensor<2xcomplex> { + %0 = mhlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor<1xcomplex> + %1 = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor> + %2 = "mhlo.pad"(%0, %1) <{edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>}> : (tensor<1xcomplex>, tensor>) -> tensor<2xcomplex> + return %2 : tensor<2xcomplex> + // CHECK: mhlo.constant dense<[(2.000000e+00,0.000000e+00), (1.000000e+00,0.000000e+00)]> : tensor<2xcomplex> +} + +// CHECK-LABEL: @pad_identity_fold +func.func @pad_identity_fold(%arg0: tensor<5x7xf32>) -> tensor<5x7xf32> { + %0 = arith.constant dense<0.0> : tensor + %1 = "mhlo.pad"(%arg0, %0) { + edge_padding_low = dense<0> : tensor<2xi64>, + edge_padding_high = dense<0> : tensor<2xi64>, + interior_padding = dense<0> : tensor<2xi64> + } : (tensor<5x7xf32>, tensor) -> tensor<5x7xf32> + func.return %1 : tensor<5x7xf32> + // CHECK: return %arg0 : tensor<5x7xf32> +} + +// CHECK-LABEL: @pad_fold +func.func @pad_fold() -> tensor<4x5xi32> { + %0 = arith.constant dense<[[2, 3], [4, 5]]> : tensor<2x2xi32> + %1 = arith.constant dense<1> : tensor + %3 = "mhlo.pad"(%0, %1) { + edge_padding_low = dense<[1, 0]> : tensor<2xi64>, + edge_padding_high = dense<[1, 2]> : tensor<2xi64>, + interior_padding = dense<[0, 1]> : tensor<2xi64> + } : (tensor<2x2xi32>, tensor) -> tensor<4x5xi32> + func.return %3 : tensor<4x5xi32> + // CHECK: constant dense<[ + // CHECK-SAME: [1, 1, 1, 1, 1], [2, 1, 3, 1, 1], [4, 1, 5, 1, 1], [1, 1, 1, 1, 1] + // CHECK-SAME: ]> : tensor<4x5xi32> +} + +// CHECK-LABEL: @pad_negative_fold +func.func @pad_negative_fold() -> tensor<4x4xi32> { + %0 = arith.constant dense<[[2, 3], [4, 5]]> : tensor<2x2xi32> + %1 = arith.constant dense<1> : tensor + %3 = "mhlo.pad"(%0, %1) { + edge_padding_low = dense<[1, -1]> : tensor<2xi64>, + edge_padding_high = dense<[1, 2]> : tensor<2xi64>, + interior_padding = dense<[0, 1]> : tensor<2xi64> + } : (tensor<2x2xi32>, tensor) -> tensor<4x4xi32> + func.return %3 : tensor<4x4xi32> + // CHECK: "mhlo.pad" +} + +// CHECK-LABEL: @pad_fold_zero_elements +func.func @pad_fold_zero_elements() -> tensor<3xi32> { + %0 = mhlo.constant dense<> : tensor<0xi32> + %1 = mhlo.constant dense<7> : tensor + %2 = "mhlo.pad"(%0, %1) <{edge_padding_high = dense<3> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>}> : (tensor<0xi32>, tensor) -> tensor<3xi32> + func.return %2 : tensor<3xi32> + // CHECK: mhlo.constant dense<7> : tensor<3xi32> +} + +// CHECK-LABEL: @pad_float_fold +func.func @pad_float_fold() -> tensor<2xf32> { + %0 = mhlo.constant dense<2.000000e+00> : tensor<1xf32> + %1 = mhlo.constant dense<1.000000e+00> : tensor + %2 = "mhlo.pad"(%0, %1) <{edge_padding_high = dense<1> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>}> : (tensor<1xf32>, tensor) -> tensor<2xf32> + return %2 : tensor<2xf32> + // CHECK: mhlo.constant dense<[2.000000e+00, 1.000000e+00]> : tensor<2xf32> +} + +//////// +// ReduceWindowOp + +// CHECK-LABEL: @fold_reduce_window +func.func @fold_reduce_window(%arg0: tensor<1x1x20xf32>) -> tensor<1x1x20xf32> { + %cst_0 = mhlo.constant dense<0.000000e+00> : tensor + %r = "mhlo.reduce_window"(%arg0, %cst_0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %s = mhlo.add %arg1, %arg2 : tensor + mhlo.return %s : tensor + }) { + padding = dense<0> : tensor<3x2xi64>, + window_dimensions = dense<1> : tensor<3xi64>, + window_strides = dense<1> : tensor<3xi64> + } : (tensor<1x1x20xf32>, tensor) -> tensor<1x1x20xf32> + func.return %r : tensor<1x1x20xf32> + + // CHECK: return %arg0 : tensor<1x1x20xf32> +} + +//////// +// RemainderOp + +// CHECK-LABEL: remainder_scalar_fold_by_zero +func.func @remainder_scalar_fold_by_zero() -> tensor<4xi64> { + %0 = mhlo.constant dense<7> : tensor<4xi64> + %1 = mhlo.constant dense<0> : tensor<4xi64> + // CHECK: mhlo.remainder + %2 = "mhlo.remainder"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + func.return %2 : tensor<4xi64> +} + +// CHECK-LABEL: remainder_fold_int +func.func @remainder_fold_int() -> tensor<4xi32> { + %0 = mhlo.constant dense<[5, 66, 5, -1]> : tensor<4xi32> + %1 = mhlo.constant dense<[3, 5, 1, -2]> : tensor<4xi32> + // CHECK: %[[RESULT:.+]] = mhlo.constant dense<[2, 1, 0, -1]> + %2 = "mhlo.remainder"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> (tensor<4xi32>) + // CHECK: return %[[RESULT]] + func.return %2 : tensor<4xi32> +} + +// CHECK-LABEL: remainder_fold_float +func.func @remainder_fold_float() -> tensor<8xf32> { + %0 = mhlo.constant dense<[-2.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0]> : tensor<8xf32> + %1 = mhlo.constant dense<[10.0, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0]> : tensor<8xf32> + // CHECK{LITERAL}: mhlo.constant dense<[-2.500000e+00, 2.500000e-01, -0.000000e+00, 0.000000e+00, 1.000000e+00, 1.000000e+00, -1.000000e+00, -0.000000e+00]> + %2 = "mhlo.remainder"(%0, %1) : (tensor<8xf32>, tensor<8xf32>) -> (tensor<8xf32>) + func.return %2 : tensor<8xf32> +} + +//////// +// ReshapeOp + +// CHECK-LABEL: @reshape_splat_of_bools +func.func public @reshape_splat_of_bools() -> tensor<2x1xi1> { + // CHECK: mhlo.constant dense : tensor<2x1xi1> + %0 = mhlo.constant dense : tensor<2xi1> + %1 = "mhlo.reshape"(%0) : (tensor<2xi1>) -> tensor<2x1xi1> + return %1 : tensor<2x1xi1> +} + +//////// +// RoundNearestOps + +// CHECK-LABEL: round_fold +func.func @round_fold() -> tensor<4xf32> { + %0 = mhlo.constant dense<[-1.5, -0.1, 1.1, 2.5]> : tensor<4xf32> + %1 = "mhlo.round_nearest_afz"(%0) : (tensor<4xf32>) -> tensor<4xf32> + func.return %1 : tensor<4xf32> + // CHECK: mhlo.constant dense<[-2.000000e+00, -0.000000e+00, 1.000000e+00, 3.000000e+00]> +} + +// CHECK-LABEL: round_nearest_even_fold +func.func @round_nearest_even_fold() -> tensor<4xf32> { + %0 = mhlo.constant dense<[-1.5, -0.1, 1.1, 2.5]> : tensor<4xf32> + %1 = "mhlo.round_nearest_even"(%0) : (tensor<4xf32>) -> tensor<4xf32> + func.return %1 : tensor<4xf32> + // CHECK: mhlo.constant dense<[-2.000000e+00, -0.000000e+00, 1.000000e+00, 2.000000e+00]> +} + +//////// +// RsqrtOp + +// CHECK-LABEL: func @fold_rsqrt_f16_constants +func.func @fold_rsqrt_f16_constants() -> tensor<4xf16> { + %0 = mhlo.constant dense<[1.0, 4.0, 16.0, 64.0]> : tensor<4xf16> + %1 = "mhlo.rsqrt"(%0) : (tensor<4xf16>) -> tensor<4xf16> + // CHECK: mhlo.constant dense<[1.000000e+00, 5.000000e-01, 2.500000e-01, 1.250000e-01]> : tensor<4xf16> + // CHECK-NOT: mhlo.rsqrt + func.return %1 : tensor<4xf16> +} + +// CHECK-LABEL: func @fold_rsqrt_bf16_constants +func.func @fold_rsqrt_bf16_constants() -> tensor<4xbf16> { + %0 = mhlo.constant dense<[1.0, 4.0, 16.0, 64.0]> : tensor<4xbf16> + %1 = "mhlo.rsqrt"(%0) : (tensor<4xbf16>) -> tensor<4xbf16> + // CHECK: mhlo.constant dense<[1.000000e+00, 5.000000e-01, 2.500000e-01, 1.250000e-01]> : tensor<4xbf16> + // CHECK-NOT: mhlo.rsqrt + func.return %1 : tensor<4xbf16> +} + +// CHECK-LABEL: func @fold_rsqrt_f32_constants +func.func @fold_rsqrt_f32_constants() -> tensor<4xf32> { + %0 = mhlo.constant dense<[1.0, 4.0, 16.0, 64.0]> : tensor<4xf32> + %1 = "mhlo.rsqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.constant dense<[1.000000e+00, 5.000000e-01, 2.500000e-01, 1.250000e-01]> : tensor<4xf32> + // CHECK-NOT: mhlo.rsqrt + func.return %1 : tensor<4xf32> +} + +// CHECK-LABEL: func @fold_rsqrt_f64_constants +func.func @fold_rsqrt_f64_constants() -> tensor<4xf64> { + %0 = mhlo.constant dense<[1.0, 4.0, 16.0, 64.0]> : tensor<4xf64> + %1 = "mhlo.rsqrt"(%0) : (tensor<4xf64>) -> tensor<4xf64> + // CHECK: mhlo.constant dense<[1.000000e+00, 5.000000e-01, 2.500000e-01, 1.250000e-01]> : tensor<4xf64> + // CHECK-NOT: mhlo.rsqrt + func.return %1 : tensor<4xf64> +} + +// CHECK-LABEL: func @not_fold_rsqrt_neg_constants +func.func @not_fold_rsqrt_neg_constants() -> tensor<4xf32> { + %0 = mhlo.constant dense<-1.0> : tensor<4xf32> + %1 = "mhlo.rsqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.constant dense<-1.000000e+00> : tensor<4xf32> + // CHECK: mhlo.rsqrt + func.return %1 : tensor<4xf32> +} + +// CHECK-LABEL: func @not_fold_rsqrt_const_zero +func.func @not_fold_rsqrt_const_zero() -> tensor<4xf32> { + %0 = mhlo.constant dense<0.0> : tensor<4xf32> + %1 = "mhlo.rsqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.constant dense<0.000000e+00> : tensor<4xf32> + // CHECK: mhlo.rsqrt + func.return %1 : tensor<4xf32> +} + +//////// +// ScatterOp + +// CHECK-LABEL: @tensor_flow_scatter_v1_update +func.func @tensor_flow_scatter_v1_update() -> tensor<3x3xi32> { + %0 = arith.constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = arith.constant dense<[0, 2]> : tensor<2xi32> + %2 = arith.constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = #mhlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0], + scatter_dims_to_operand_dims = [0], + index_vector_dim = 1, + >, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> + func.return %3 : tensor<3x3xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [10, 20, 30], [4, 5, 6], [70, 80, 90] + // CHECK-SAME: ]> : tensor<3x3xi32> +} + +// CHECK-LABEL: @tensor_flow_scatter_v2_update +func.func @tensor_flow_scatter_v2_update() -> tensor<3x3xi32> { + %0 = arith.constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = arith.constant dense<[0, 2]> : tensor<2xi32> + %2 = arith.constant dense<[[10, 30], [40, 60], [70, 90]]> : tensor<3x2xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = #mhlo.scatter< + update_window_dims = [0], + inserted_window_dims = [1], + scatter_dims_to_operand_dims = [1], + index_vector_dim = 1, + >, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2xi32>, tensor<3x2xi32>) -> tensor<3x3xi32> + func.return %3 : tensor<3x3xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [10, 2, 30], [40, 5, 60], [70, 8, 90] + // CHECK-SAME: ]> : tensor<3x3xi32> +} + +// CHECK-LABEL: @tensor_flow_scatter_add +func.func @tensor_flow_scatter_add() -> tensor<3x3xi32> { + %0 = arith.constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = arith.constant dense<[0, 2]> : tensor<2xi32> + %2 = arith.constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %4 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> (tensor) + "mhlo.return"(%4) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = #mhlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0], + scatter_dims_to_operand_dims = [0], + index_vector_dim = 1, + >, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> + func.return %3 : tensor<3x3xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [11, 22, 33], [4, 5, 6], [77, 88, 99] + // CHECK-SAME: ]> : tensor<3x3xi32> +} + +// CHECK-LABEL: @tensor_flow_scatter_repeated +func.func @tensor_flow_scatter_repeated() -> tensor<3x3xi32> { + %0 = arith.constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = arith.constant dense<[1, 1]> : tensor<2xi32> + %2 = arith.constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %4 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> (tensor) + "mhlo.return"(%4) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = #mhlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0], + scatter_dims_to_operand_dims = [0], + index_vector_dim = 1, + >, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> + func.return %3 : tensor<3x3xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [1, 2, 3], [84, 105, 126], [7, 8, 9] + // CHECK-SAME: ]> : tensor<3x3xi32> +} + +// CHECK-LABEL: @tensor_flow_scatter_multiple_batch +func.func @tensor_flow_scatter_multiple_batch() -> tensor<3x3xi32> { + %0 = arith.constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = arith.constant dense<[[0, 2], [2, 1]]> : tensor<2x2xi32> + %2 = arith.constant dense<[[[10, 30], [40, 60], [70, 90]], [[5, 5], [5, 5], [5, 5]]]> : tensor<2x3x2xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %4 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> (tensor) + "mhlo.return"(%4) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = #mhlo.scatter< + update_window_dims = [1], + inserted_window_dims = [1], + scatter_dims_to_operand_dims = [1], + index_vector_dim = 2, + >, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2x2xi32>, tensor<2x3x2xi32>) -> tensor<3x3xi32> + func.return %3 : tensor<3x3xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [11, 7, 38], [44, 10, 71], [77, 13, 104] + // CHECK-SAME: ]> : tensor<3x3xi32> +} + +// CHECK-LABEL: @tensor_flow_scatter_nd +func.func @tensor_flow_scatter_nd() -> tensor<3x3x2xi32> { + %0 = arith.constant dense<[[[-1, 1], [-2, 2], [-3, 3]], [[-4, 4], [-5, 5], [-6, 6]], [[-7, 7], [-8, 8], [-9, 9]]]> : tensor<3x3x2xi32> + %1 = arith.constant dense<[[0, 0], [1, 0]]> : tensor<2x2xi32> + %2 = arith.constant dense<[[-10, 10], [-40, 40]]> : tensor<2x2xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = #mhlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1, + >, + unique_indices = false + } : (tensor<3x3x2xi32>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<3x3x2xi32> + func.return %3 : tensor<3x3x2xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [-10, 10], [-2, 2], [-3, 3] + // CHECK-SAME: [-40, 40], [-5, 5], [-6, 6] + // CHECK-SAME: [-7, 7], [-8, 8], [-9, 9] + // CHECK-SAME: ]> : tensor<3x3x2xi32> +} + +// CHECK-LABEL: @tensor_flow_scatter_nd_index_vector +func.func @tensor_flow_scatter_nd_index_vector() -> tensor<3x3x2xi32> { + %0 = arith.constant dense<[[[-1, 1], [-2, 2], [-3, 3]], [[-4, 4], [-5, 5], [-6, 6]], [[-7, 7], [-8, 8], [-9, 9]]]> : tensor<3x3x2xi32> + %1 = arith.constant dense<[[0, 0], [1, 0]]> : tensor<2x2xi32> + %2 = arith.constant dense<[[-10, 10], [-20, 20]]> : tensor<2x2xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = #mhlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 0, + >, + unique_indices = false + } : (tensor<3x3x2xi32>, tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<3x3x2xi32> + func.return %3 : tensor<3x3x2xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [-20, 20], [-10, 10], [-3, 3] + // CHECK-SAME: [-4, 4], [-5, 5], [-6, 6] + // CHECK-SAME: [-7, 7], [-8, 8], [-9, 9] + // CHECK-SAME: ]> : tensor<3x3x2xi32> +} + +// CHECK-LABEL: @scatter_batch_dus +func.func @scatter_batch_dus() -> tensor<3x3xi32> { + %0 = arith.constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = arith.constant dense<[[2, 1], [1, 1]]> : tensor<2x2xi32> + %2 = arith.constant dense<[[[10]], [[20]]]> : tensor<2x1x1xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = #mhlo.scatter< + update_window_dims = [1, 2], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 0, + >, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2x2xi32>, tensor<2x1x1xi32>) -> tensor<3x3xi32> + func.return %3 : tensor<3x3xi32> + // CHECK: mhlo.constant dense<[ + // CHECK-SAME: [1, 2, 3], [4, 20, 6], [7, 10, 9] + // CHECK-SAME: ]> : tensor<3x3xi32> +} + +// CHECK-LABEL: @scatter_no_update_window_dim +func.func @scatter_no_update_window_dim() -> tensor<3xi32> { + %0 = arith.constant dense<[0, 1, 2]> : tensor<3xi32> + %1 = arith.constant dense<[[[0], [1]], [[2], [1]]]> : tensor<2x2x1xi32> + %2 = arith.constant dense<[[10, 20], [30, 40]]> : tensor<2x2xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %4 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> (tensor) + "mhlo.return"(%4) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = #mhlo.scatter< + inserted_window_dims = [0], + scatter_dims_to_operand_dims = [0], + index_vector_dim = 2, + >, + unique_indices = false + } : (tensor<3xi32>, tensor<2x2x1xi32>, tensor<2x2xi32>) -> tensor<3xi32> + func.return %3 : tensor<3xi32> + // CHECK: mhlo.constant dense<[10, 61, 32]> : tensor<3xi32> +} + +// CHECK-LABEL: @scatter_negative_index +func.func @scatter_negative_index() -> tensor<3x3xi32> { + %0 = arith.constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = arith.constant dense<[0, -1]> : tensor<2xi32> + %2 = arith.constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> + %3 = "mhlo.scatter"(%0, %1, %2) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = #mhlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0], + scatter_dims_to_operand_dims = [0], + index_vector_dim = 1, + >, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> + func.return %3 : tensor<3x3xi32> + // CHECK: constant dense<{{\[}}[1, 2, 3], [4, 5, 6], [7, 8, 9]{{\]}}> : tensor<3x3xi32> + // CHECK: "mhlo.scatter" +} + +// CHECK-LABEL: @scatter_out_of_bound +func.func @scatter_out_of_bound() -> tensor<3x3xi32> { + %0 = arith.constant dense<[[1, 2, 3], [4, 5, 6], [7, 8, 9]]> : tensor<3x3xi32> + %1 = arith.constant dense<[1, 5]> : tensor<2xi32> + %2 = arith.constant dense<[[10, 20, 30], [70, 80, 90]]> : tensor<2x3xi32> + // CHECK: constant dense<{{\[}}[1, 2, 3], [4, 5, 6], [7, 8, 9]{{\]}}> : tensor<3x3xi32> + // CHECK: "mhlo.scatter" + %3 = "mhlo.scatter"(%0, %1, %2) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + "mhlo.return"(%arg1) : (tensor) -> () + }) {indices_are_sorted = false, + scatter_dimension_numbers = #mhlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0], + scatter_dims_to_operand_dims = [0], + index_vector_dim = 1, + >, + unique_indices = false + } : (tensor<3x3xi32>, tensor<2xi32>, tensor<2x3xi32>) -> tensor<3x3xi32> + func.return %3 : tensor<3x3xi32> +} + +// CHECK-LABEL: @scatter_complex +func.func public @scatter_complex() -> tensor<1xcomplex> { + %0 = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor> + %1 = mhlo.constant dense<0> : tensor<1xi32> + %2 = mhlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<1xcomplex> + // CHECK: "mhlo.scatter" + %3 = "mhlo.scatter"(%2, %1, %0) ({ + ^bb0(%arg0: tensor>, %arg1: tensor>): + "mhlo.return"(%arg1) : (tensor>) -> () + }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter, unique_indices = true} : (tensor<1xcomplex>, tensor<1xi32>, tensor>) -> tensor<1xcomplex> + func.return %3 : tensor<1xcomplex> +} + +//////// +// SelectOp + +// CHECK-LABEL: func @fold_select_same +func.func @fold_select_same(%arg0 : tensor, %arg1 : tensor) -> tensor { + %1 = "mhlo.select"(%arg1, %arg0, %arg0) : (tensor, tensor, tensor) -> tensor + // CHECK: return %arg0 + func.return %1 : tensor +} + +// CHECK-LABEL: func @fold_select_first +func.func @fold_select_first(%arg0 : tensor, %arg1 : tensor) -> tensor { + %0 = mhlo.constant dense<1> : tensor + %1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor, tensor, tensor) -> tensor + // CHECK: return %arg0 + func.return %1 : tensor +} + +// CHECK-LABEL: func @fold_select_second +func.func @fold_select_second(%arg0 : tensor, %arg1 : tensor) -> tensor { + %0 = mhlo.constant dense<0> : tensor + %1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor, tensor, tensor) -> tensor + // CHECK: return %arg1 + func.return %1 : tensor +} + +// CHECK-LABEL: func @fold_select_vector +func.func @fold_select_vector(%arg0 : tensor<4xf32>, %arg1 : tensor<4xf32>) -> tensor<4xf32> { + %0 = mhlo.constant dense<1> : tensor<4xi1> + %1 = "mhlo.select"(%0, %arg0, %arg1) : (tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK: return %arg0 + func.return %1 : tensor<4xf32> +} + +//////// +// SineOp + +// CHECK-LABEL: func @fold_sine +func.func @fold_sine() -> tensor<4xf32> { + %0 = mhlo.constant dense<2.0> : tensor<4xf32> + %1 = "mhlo.sine"(%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.constant dense<0.909297406> : tensor<4xf32> + // CHECK-NOT: mhlo.sine + func.return %1 : tensor<4xf32> +} + +//////// +// SliceOp + +// CHECK-LABEL: slice_1D_fold +func.func @slice_1D_fold() -> tensor<2xi64> { + %0 = mhlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64> + // CHECK: mhlo.constant dense<[7, 9]> + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<4xi64>) -> (tensor<2xi64>) + func.return %1 : tensor<2xi64> +} + +// CHECK-LABEL: slice_1D_fp +func.func @slice_1D_fp() -> tensor<2xf32> { + %0 = mhlo.constant dense<[5.0, 7.0, 9.0, 10.0]> : tensor<4xf32> + // CHECK: mhlo.constant dense<[7.000000e+00, 9.000000e+00]> + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<4xf32>) -> (tensor<2xf32>) + func.return %1 : tensor<2xf32> +} + +// CHECK-LABEL: slice_1D_strided_fold +func.func @slice_1D_strided_fold() -> tensor<2xi64> { + %0 = mhlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64> + // CHECK: mhlo.constant dense<[7, 10]> + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[4]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>}> : (tensor<4xi64>) -> (tensor<2xi64>) + func.return %1 : tensor<2xi64> +} + +// CHECK-LABEL: slice_2D_fold +func.func @slice_2D_fold() -> tensor<2x2xi64> { + %0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64> + // CHECK-NEXT: mhlo.constant dense<[ + // CHECK-SAME: [6, 7], + // CHECK-SAME: [10, 11] + // CHECK-SAME: ]> + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[3, 4]> : tensor<2xi64>, start_indices = dense<[1, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x4xi64>) -> (tensor<2x2xi64>) + func.return %1 : tensor<2x2xi64> +} + +// CHECK-LABEL: slice_2D_fold_horizontal +func.func @slice_2D_fold_horizontal() -> tensor<1x4xi64> { + %0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64> + // CHECK-NEXT: mhlo.constant dense<[ + // CHECK-SAME: [0, 1, 2, 3] + // CHECK-SAME: ]> + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x4xi64>) -> (tensor<1x4xi64>) + func.return %1 : tensor<1x4xi64> +} + +// CHECK-LABEL: slice_2D_fold_vertical +func.func @slice_2D_fold_vertical() -> tensor<4x1xi64> { + %0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64> + // CHECK-NEXT: mhlo.constant dense<[ + // CHECK-SAME: [2], [6], [10], [14] + // CHECK-SAME: ]> + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x4xi64>) -> (tensor<4x1xi64>) + func.return %1 : tensor<4x1xi64> +} + +// CHECK-LABEL: slice_zero_elements +func.func @slice_zero_elements() -> tensor<0xi64> { + %0 = mhlo.constant dense<> : tensor<0xi64> + // CHECK: %[[CONST:.*]] = mhlo.constant dense<> : tensor<0xi64> + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[0]> : tensor<1xi64>, start_indices = dense<[0]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}> : (tensor<0xi64>) -> (tensor<0xi64>) + // CHECK: return %[[CONST]] : tensor<0xi64> + func.return %1 : tensor<0xi64> +} + +// CHECK-LABEL: slice_concat_fold_first +func.func @slice_concat_fold_first(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> { + %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<2x5xf32>) -> (tensor<1x5xf32>) + // CHECK: return %arg0 + func.return %1 : tensor<1x5xf32> +} + +// CHECK-LABEL: slice_concat_fold_second +func.func @slice_concat_fold_second(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> { + %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<2x5xf32>) -> (tensor<1x5xf32>) + // CHECK: return %arg1 + func.return %1 : tensor<1x5xf32> +} + +// CHECK-LABEL: slice_concat_fold_second_with_slice +func.func @slice_concat_fold_second_with_slice(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x4xf32> { + %0 = "mhlo.concatenate"(%arg0, %arg1) <{ dimension = 0 : i64 }> : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32> + // CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) <{limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<1x5xf32>) -> tensor<1x4xf32> + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<2x5xf32>) -> (tensor<1x4xf32>) + + // CHECK: return [[SLICE]] + func.return %1 : tensor<1x4xf32> +} + +// CHECK-LABEL: slice_concat_fold_middle +func.func @slice_concat_fold_middle(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<1x5xf32> { + %0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) <{ dimension = 0 : i64 }> : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32> + // CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) <{limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x5xf32>) -> (tensor<1x5xf32>) + + // CHECK: return [[SLICE]] + func.return %1 : tensor<1x5xf32> +} + +// CHECK-LABEL: slice_concat_fold_two +func.func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<2x5xf32> { + // CHECK: [[CONCAT:%.+]] = "mhlo.concatenate"(%arg1, %arg2) <{dimension = 0 : i64}> + %0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) <{ dimension = 0 : i64 }> : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32> + + // CHECK: [[SLICE:%.+]] = "mhlo.slice"([[CONCAT]]) <{limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> + %1 = "mhlo.slice"(%0) <{ limit_indices = dense<[4, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<4x5xf32>) -> (tensor<2x5xf32>) + + // CHECK: return [[SLICE]] + func.return %1 : tensor<2x5xf32> +} + +//////// +// Subtract + +// CHECK-LABEL: sub_scalar_fold +func.func @sub_scalar_fold() -> tensor<4xi64> { + %0 = mhlo.constant dense<5> : tensor<4xi64> + %1 = mhlo.constant dense<1> : tensor<4xi64> + // CHECK: mhlo.constant dense<4> + %2 = "mhlo.subtract"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>) + func.return %2 : tensor<4xi64> +} + +//////// +// SqrtOp + +// CHECK-LABEL: func @fold_sqrt_f16_constants +func.func @fold_sqrt_f16_constants() -> tensor<4xf16> { + %0 = mhlo.constant dense<[1.0, 4.0, 9.0, 16.0]> : tensor<4xf16> + %1 = "mhlo.sqrt"(%0) : (tensor<4xf16>) -> tensor<4xf16> + // CHECK: mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf16> + // CHECK-NOT: mhlo.sqrt + func.return %1 : tensor<4xf16> +} + +// CHECK-LABEL: func @fold_sqrt_bf16_constants +func.func @fold_sqrt_bf16_constants() -> tensor<4xbf16> { + %0 = mhlo.constant dense<[1.0, 4.0, 9.0, 16.0]> : tensor<4xbf16> + %1 = "mhlo.sqrt"(%0) : (tensor<4xbf16>) -> tensor<4xbf16> + // CHECK: mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16> + // CHECK-NOT: mhlo.sqrt + func.return %1 : tensor<4xbf16> +} + +// CHECK-LABEL: func @fold_sqrt_f32_constants +func.func @fold_sqrt_f32_constants() -> tensor<4xf32> { + %0 = mhlo.constant dense<1.0> : tensor<4xf32> + %1 = "mhlo.sqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.constant dense<1.000000e+00> : tensor<4xf32> + // CHECK-NOT: mhlo.sqrt + func.return %1 : tensor<4xf32> +} + +// CHECK-LABEL: func @fold_sqrt_f64_constants +func.func @fold_sqrt_f64_constants() -> tensor<4xf64> { + %0 = mhlo.constant dense<[1.0, 4.0, 9.0, 16.0]> : tensor<4xf64> + %1 = "mhlo.sqrt"(%0) : (tensor<4xf64>) -> tensor<4xf64> + // CHECK: mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xf64> + // CHECK-NOT: mhlo.sqrt + func.return %1 : tensor<4xf64> +} + +// CHECK-LABEL: func @fold_sqrt_const_zero +func.func @fold_sqrt_const_zero() -> tensor<4xf32> { + %0 = mhlo.constant dense<0.0> : tensor<4xf32> + %1 = "mhlo.sqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.constant dense<0.000000e+00> : tensor<4xf32> + // CHECK-NOT: mhlo.sqrt + func.return %1 : tensor<4xf32> +} + +// CHECK-LABEL: func @not_fold_sqrt_neg_constants +func.func @not_fold_sqrt_neg_constants() -> tensor<4xf32> { + %0 = mhlo.constant dense<-1.0> : tensor<4xf32> + %1 = "mhlo.sqrt"(%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.constant dense<-1.000000e+00> : tensor<4xf32> + // CHECK: mhlo.sqrt + func.return %1 : tensor<4xf32> +} + +//////// +// TanhOp + +// CHECK-LABEL: func @fold_tanh +func.func @fold_tanh() -> tensor<4xf32> { + %0 = mhlo.constant dense<2.0> : tensor<4xf32> + %1 = "mhlo.tanh"(%0) : (tensor<4xf32>) -> tensor<4xf32> + // CHECK: mhlo.constant dense<0.964027583> : tensor<4xf32> + // CHECK-NOT: mhlo.tanh + func.return %1 : tensor<4xf32> +} + +//////// +// XorOp + +// CHECK-LABEL: func @fold_xor_same +func.func @fold_xor_same(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = "mhlo.xor"(%arg0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: %0 = mhlo.constant dense<0> : tensor<4xi32> + // CHECK: return %0 + func.return %0 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_xor_same_dynamic +func.func @fold_xor_same_dynamic(%arg0 : tensor) -> tensor { + %0 = "mhlo.xor"(%arg0, %arg0) : (tensor, tensor) -> tensor + // CHECK: mhlo.xor + func.return %0 : tensor +} + +// CHECK-LABEL: func @fold_xor_ones_left +func.func @fold_xor_ones_left(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<-1> : tensor<4xi32> + // CHECK: mhlo.xor + %1 = "mhlo.xor"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + func.return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_xor_ones_right +func.func @fold_xor_ones_right(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<-1> : tensor<4xi32> + // CHECK: mhlo.xor + %1 = "mhlo.xor"(%arg0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + func.return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_xor_zeros_left +func.func @fold_xor_zeros_left(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<0> : tensor<4xi32> + %1 = "mhlo.xor"(%0, %arg0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %arg0 + func.return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_xor_zeros_right +func.func @fold_xor_zeros_right(%arg0 : tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.constant dense<0> : tensor<4xi32> + %1 = "mhlo.xor"(%arg0, %0) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: return %arg0 + func.return %1 : tensor<4xi32> +} + +// CHECK-LABEL: func @fold_xor_zeros_constants +func.func @fold_xor_zeros_constants() -> tensor<4xi32> { + %0 = mhlo.constant dense<[0, 1, 6, 3]> : tensor<4xi32> + %1 = mhlo.constant dense<[7, 3, 7, 2]> : tensor<4xi32> + %2 = "mhlo.xor"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + // CHECK: %0 = mhlo.constant dense<[7, 2, 1, 1]> : tensor<4xi32> + // CHECK: return %0 + func.return %2 : tensor<4xi32> +} \ No newline at end of file diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir index 3475bcb73cb54c..6e91799784d0bd 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-linalg.mlir @@ -5749,6 +5749,16 @@ func.func @reduce_precision(%arg0: tensor<1x2x3x4xf32>) // ----- +// CHECK-LABEL: func @reduce_precision_noop( +// CHECK-SAME: %[[ARG:.*]]: tensor<1x2xf32> +// CHECK: return %[[ARG]] +func.func @reduce_precision_noop(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> { + %0 = "mhlo.reduce_precision"(%arg0) {exponent_bits=8:i32, mantissa_bits=23:i32} : (tensor<1x2xf32>) -> tensor<1x2xf32> + return %0 : tensor<1x2xf32> +} + +// ----- + // The following pattern only tests the general structure of the code and the // affine maps as it is better tested by tests executing the result, and as it // includes many ops which could lead to a high load of refactoring. diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index 8baa3e0d3298df..59618001c2d7cc 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -1805,6 +1805,20 @@ func.func @type_ui64(%arg0: tensor, %arg1: tensor) -> tensor { func.return %0 : tensor } +// CHECK-LABEL: "type_f8E3M4" +func.func @type_f8E3M4(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "stablehlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor + %0 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3" +func.func @type_f8E4M3(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "stablehlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor + %0 = "mhlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + // CHECK-LABEL: "type_f8E4M3FN" func.func @type_f8E4M3FN(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: "stablehlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir index ab36c5123f80d5..24d1220df4bd47 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir @@ -6832,6 +6832,20 @@ func.func @invalid_dimension_attr(%arg0: tensor) -> tensor { + %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +func.func @f8e4m3(%arg0: tensor) -> tensor { + %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + func.func @f8e4m3fn(%arg0: tensor) -> tensor { %0 = "mhlo.convert"(%arg0) : (tensor) -> tensor func.return %0 : tensor diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir index 0f2e1b108a710f..66c388b9ed373e 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir @@ -1787,6 +1787,20 @@ func.func @type_ui64(%arg0: tensor, %arg1: tensor) -> tensor { func.return %0 : tensor } +// CHECK-LABEL: "type_f8E3M4" +func.func @type_f8E3M4(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "mhlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3" +func.func @type_f8E4M3(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "mhlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + // CHECK-LABEL: "type_f8E4M3FN" func.func @type_f8E4M3FN(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: "mhlo.add"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) : (tensor, tensor) -> tensor diff --git a/third_party/xla/xla/mlir_hlo/transforms/bufferize_pass.cc b/third_party/xla/xla/mlir_hlo/transforms/bufferize_pass.cc index 1e810cff21a555..5ce72beb332b9a 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/bufferize_pass.cc +++ b/third_party/xla/xla/mlir_hlo/transforms/bufferize_pass.cc @@ -66,6 +66,7 @@ limitations under the License. #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Visitors.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -235,8 +236,9 @@ struct OneShotBufferizePass opts.allowReturnAllocsFromLoops = true; opts.bufferizeFunctionBoundaries = true; opts.functionArgTypeConverterFn = - [=](TensorType tensorType, Attribute memorySpace, func::FuncOp funcOp, - const bufferization::BufferizationOptions& options) { + [=](TensorType tensorType, Attribute memorySpace, + FunctionOpInterface funcOp, + const bufferization::BufferizationOptions& /*options*/) { // Functions created by fusion outlining should have fully dynamic // layout. All other functions (for now only "main") gets static // layout. diff --git a/third_party/xla/xla/packed_literal_reader.cc b/third_party/xla/xla/packed_literal_reader.cc index 03cf165176e51b..e707778a556694 100644 --- a/third_party/xla/xla/packed_literal_reader.cc +++ b/third_party/xla/xla/packed_literal_reader.cc @@ -21,15 +21,18 @@ limitations under the License. #include #include "absl/base/casts.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/layout.h" #include "xla/layout_util.h" #include "xla/literal.h" +#include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/status_macros.h" -#include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/file_system.h" #include "tsl/platform/logging.h" -#include "tsl/platform/protobuf.h" namespace xla { diff --git a/third_party/xla/xla/packed_literal_reader.h b/third_party/xla/xla/packed_literal_reader.h index 1b9b14a0c93c8d..4a5d5fdb9bdc99 100644 --- a/third_party/xla/xla/packed_literal_reader.h +++ b/third_party/xla/xla/packed_literal_reader.h @@ -19,10 +19,13 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "xla/layout.h" #include "xla/literal.h" +#include "xla/shape.h" #include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/platform/env.h" +#include "tsl/platform/file_system.h" namespace xla { diff --git a/third_party/xla/xla/parse_flags_from_env.cc b/third_party/xla/xla/parse_flags_from_env.cc index 4da2c9a0de2613..7e5e7e686b04b1 100644 --- a/third_party/xla/xla/parse_flags_from_env.cc +++ b/third_party/xla/xla/parse_flags_from_env.cc @@ -27,10 +27,12 @@ limitations under the License. #include #include +#include "absl/base/const_init.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/ascii.h" -#include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/parse_flags_from_env_test.cc b/third_party/xla/xla/parse_flags_from_env_test.cc index 0f68b41b7510be..c5a5e8006f1cde 100644 --- a/third_party/xla/xla/parse_flags_from_env_test.cc +++ b/third_party/xla/xla/parse_flags_from_env_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/env.h" #include "tsl/platform/logging.h" +#include "tsl/platform/macros.h" #include "tsl/platform/subprocess.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/permutation_util.cc b/third_party/xla/xla/permutation_util.cc index 040f210c0cc921..66f2fb134bbd26 100644 --- a/third_party/xla/xla/permutation_util.cc +++ b/third_party/xla/xla/permutation_util.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/types/span.h" namespace xla { diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index a242a71adc4747..41f5318eab749e 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -48,8 +48,9 @@ cc_library( hdrs = ["event_pool.h"], deps = [ "//xla:types", - "//xla/stream_executor", "//xla/stream_executor:event", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", @@ -150,15 +151,18 @@ cc_library( ":worker_thread", "//xla:util", "//xla/client:local_client", - "//xla/stream_executor", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", + "//xla/tsl/protobuf:error_codes_proto_impl_cc", + "//xla/tsl/util:env_var", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:traceme", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -187,9 +191,9 @@ xla_cc_test( "//xla/pjrt/c:pjrt_c_api_hdrs", "//xla/pjrt/c:pjrt_c_api_wrapper_impl", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/protobuf:error_codes_proto_impl_cc", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -210,8 +214,9 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", + "//xla/pjrt/distributed:key_value_store_interface", "//xla/service:computation_placer_hdr", "//xla/service:hlo_cost_analysis", "//xla/tsl/framework:allocator", @@ -240,12 +245,14 @@ cc_library( hdrs = ["pjrt_client_test.h"], deps = [ ":pjrt_client", + ":pjrt_compiler", + "//xla:cpu_function_runtime", "//xla:shape_util", "//xla:test", "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/service:hlo_parser", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/parser:hlo_parser", "//xla/tests:literal_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", @@ -328,7 +335,7 @@ cc_library( ":metrics", ":pjrt_device_description", ":pjrt_executable", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -348,7 +355,7 @@ xla_cc_test( ":pjrt_client", ":pjrt_compiler", ":pjrt_device_description", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", "//xla/tsl/lib/monitoring:cell_reader", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", @@ -364,7 +371,7 @@ cc_library( hdrs = ["pjrt_common.h"], visibility = internal_visibility([":friends"]), deps = [ - "@local_tsl//tsl/lib/gtl:int_type", + "//xla/tsl/lib/gtl:int_type", ], ) @@ -380,7 +387,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/client:executable_build_options", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", "//xla/service:computation_placer", "//xla/service:hlo_proto_cc", @@ -406,7 +413,7 @@ cc_library( visibility = ["//xla:friends"], deps = [ "//xla:shape_util", - "//xla/service:hlo_parser", + "//xla/hlo/parser:hlo_parser", "@com_google_absl//absl/hash", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", @@ -423,7 +430,7 @@ cc_library( visibility = ["//xla:friends"], deps = [ "//xla:shape_util", - "//xla/service:hlo_parser", + "//xla/hlo/parser:hlo_parser", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -501,7 +508,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/client:executable_build_options", "//xla/client:local_client", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", "//xla/pjrt/distributed:protocol_proto_cc", "//xla/service:compiler", @@ -514,7 +521,9 @@ cc_library( "//xla/service:shaped_buffer", "//xla/service:transfer_manager", "//xla/service/gpu:gpu_executable_run_options", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream", "//xla/tsl/framework:allocator", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -560,7 +569,7 @@ xla_cc_test( "//xla:test", "//xla:xla_data_proto_cc", "//xla/client:client_library", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "//xla/service:cpu_plugin", "//xla/service:platform_util", "//xla/tsl/concurrency:async_value", @@ -595,13 +604,14 @@ cc_library( deps = [ "//xla:debug_options_flags", "//xla:util", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "//xla/mlir/utils:error_util", "//xla/mlir_hlo:hlo_dialect_registration", "//xla/mlir_hlo:mhlo_passes", "//xla/service/spmd/shardy:constants", "//xla/service/spmd/shardy:utils", - "//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", + "//xla/service/spmd/shardy/sdy_round_trip:pipelines", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -618,11 +628,13 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", "@local_tsl//tsl/platform:statusor", + "@shardy//shardy/dialect/sdy/ir:dialect", "@shardy//shardy/dialect/sdy/ir:register", "@stablehlo//:chlo_ops", "@stablehlo//:register", "@stablehlo//:stablehlo_ops", "@stablehlo//:stablehlo_passes", + "@stablehlo//:stablehlo_portable_api", "@stablehlo//:stablehlo_serialization", "@stablehlo//:version", ], @@ -639,6 +651,7 @@ xla_cc_test( "@com_google_googletest//:gtest_main", "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:statusor", + "@stablehlo//:stablehlo_portable_api", ], ) @@ -732,6 +745,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -755,6 +769,7 @@ xla_cc_test( "//xla:shape_util", "//xla:test", "//xla:util", + "//xla/tsl/protobuf:error_codes_proto_impl_cc", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/numeric:int128", "@eigen_archive//:eigen3", @@ -762,7 +777,6 @@ xla_cc_test( "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -786,8 +800,9 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", + "//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "//xla/mlir_hlo:mhlo_passes", "//xla/pjrt/c:pjrt_c_api_hdrs", "//xla/pjrt/c:pjrt_c_api_helpers", @@ -799,7 +814,6 @@ cc_library( "//xla/service:hlo_cost_analysis", "//xla/service:hlo_module_config", "//xla/service:hlo_proto_cc", - "//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "//xla/tsl/framework:allocator", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", @@ -810,6 +824,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", @@ -831,21 +846,27 @@ xla_cc_test( "nomsan", ], deps = [ + ":mlir_to_hlo", ":pjrt_api", ":pjrt_c_api_client", ":pjrt_client", ":pjrt_compiler", ":pjrt_executable", + "//xla:cpu_function_runtime", "//xla:literal_util", "//xla:shape_util", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "//xla/pjrt/c:pjrt_c_api_cpu_internal", + "//xla/pjrt/c:pjrt_c_api_hdrs", "//xla/tests:literal_test_util", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", + "@stablehlo//:version", ], ) @@ -863,7 +884,7 @@ cc_library( "//xla:literal", "//xla:shape_util", "//xla:util", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", "//xla/service:computation_placer_hdr", "//xla/service:hlo_cost_analysis", @@ -888,8 +909,8 @@ xla_cc_test( deps = [ ":tf_pjrt_client", "//xla:literal_util", + "//xla/hlo/parser:hlo_parser", "//xla/pjrt/cpu:cpu_client", - "//xla/service:hlo_parser", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:test", @@ -942,7 +963,6 @@ tf_proto_library( # copybara:uncomment_begin(google-only) # py_proto_library( # name = "compile_options_py_pb2", -# api_version = 2, # visibility = ["//visibility:public"], # deps = [":compile_options_proto"], # ) @@ -986,7 +1006,7 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = [":friends"], + visibility = internal_visibility([":friends"]), deps = [ ":exceptions", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/pjrt/c/BUILD b/third_party/xla/xla/pjrt/c/BUILD index 0556ee9f61887a..189307891aab2f 100644 --- a/third_party/xla/xla/pjrt/c/BUILD +++ b/third_party/xla/xla/pjrt/c/BUILD @@ -17,9 +17,13 @@ load( "//xla/tsl:tsl.bzl", "if_google", "if_macos", + "internal_visibility", ) -# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility(["//xla:internal"]), +) cc_library( name = "pjrt_c_api_hdrs", @@ -129,7 +133,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", "//xla/pjrt:compile_options_proto_cc", "//xla/pjrt:mlir_to_hlo", @@ -151,6 +155,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", @@ -181,6 +186,7 @@ cc_library( "//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_future", "//xla/pjrt/distributed:key_value_store_interface", + "//xla/tsl/protobuf:error_codes_proto_impl_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", @@ -195,6 +201,7 @@ cc_library( "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:connected_traceme", "@local_tsl//tsl/profiler/lib:context_types_hdrs", + "@stablehlo//:version", ], ) @@ -345,6 +352,7 @@ cc_library( ":pjrt_c_api_helpers", "//xla:shape_util", "//xla/client:executable_build_options", + "//xla/pjrt:compile_options_proto_cc", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_future", @@ -419,6 +427,7 @@ xla_cc_test( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", + "@stablehlo//:version", ], ) @@ -450,11 +459,11 @@ cc_library( "//xla:xla_proto_cc", "//xla/client:executable_build_options", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/pjrt:compile_options_proto_cc", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_future", "//xla/service:computation_placer_hdr", - "//xla/service:hlo_parser", "//xla/service:hlo_proto_cc", "//xla/tests:literal_test_util", "@com_google_absl//absl/base:core_headers", diff --git a/third_party/xla/xla/pjrt/c/CHANGELOG.md b/third_party/xla/xla/pjrt/c/CHANGELOG.md index 477045dfb93adb..d373b22b73e581 100644 --- a/third_party/xla/xla/pjrt/c/CHANGELOG.md +++ b/third_party/xla/xla/pjrt/c/CHANGELOG.md @@ -1,5 +1,14 @@ # PJRT C API changelog +## 0.56 + +* Added `overridden_serialized_compile_options` and + `overridden_serialized_compile_options_size` fields to + `PJRT_Executable_DeserializeAndLoad_Args`. + +## 0.55 +* Added types F8E4M3 and F8E3M4. + ## 0.54 * Deprecated PJRT_Buffer_GetMemoryLayout. diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api.h b/third_party/xla/xla/pjrt/c/pjrt_c_api.h index 1d5b44c60201c5..9afa95362ce800 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api.h @@ -79,7 +79,7 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_Extension_Base, next); // Changes include: // * Adding a new field to the PJRT_Api or argument structs // * Renaming a method or argument (doesn't affect ABI) -#define PJRT_API_MINOR 54 +#define PJRT_API_MINOR 56 // The plugin should set the major_version and minor_version of // PJRT_Api.pjrt_api_version to be the `PJRT_API_MAJOR` and `PJRT_API_MINOR` in @@ -644,6 +644,10 @@ typedef enum { // 2-bit integer types PJRT_Buffer_Type_S2, PJRT_Buffer_Type_U2, + + // More truncated 8 bit floating-point formats. + PJRT_Buffer_Type_F8E4M3, + PJRT_Buffer_Type_F8E3M4, } PJRT_Buffer_Type; typedef enum { @@ -1573,6 +1577,11 @@ struct PJRT_Executable_DeserializeAndLoad_Args { const char* serialized_executable; size_t serialized_executable_size; PJRT_LoadedExecutable* loaded_executable; // out + // Serialized CompileOptionsProto or null (to use the options + // from the serialized executable). + // (https://github.com/openxla/xla/blob/main/xla/pjrt/compile_options.proto) + const char* overridden_serialized_compile_options; + size_t overridden_serialized_compile_options_size; }; PJRT_DEFINE_STRUCT_TRAITS(PJRT_Executable_DeserializeAndLoad_Args, loaded_executable); diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_cpu_internal.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_cpu_internal.cc index 2fa0891685a95e..43b2d283ddf564 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_cpu_internal.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_cpu_internal.cc @@ -39,7 +39,7 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { xla::CpuClientOptions options; options.cpu_device_count = 4; PJRT_ASSIGN_OR_RETURN(std::unique_ptr client, - xla::GetTfrtCpuClient(options)); + xla::GetTfrtCpuClient(std::move(options))); args->client = pjrt::CreateWrapperClient(std::move(client)); return nullptr; } @@ -64,7 +64,8 @@ const PJRT_Api* GetCpuPjrtApi() { pjrt::cpu_plugin::PJRT_ExecuteContext_Create, pjrt::cpu_plugin::PJRT_CpuDeviceTopology_Create, pjrt::PJRT_Plugin_Initialize_NoOp, - reinterpret_cast(&layouts_extension)); + reinterpret_cast(&layouts_extension), + pjrt::PJRT_Plugin_Attributes_Xla); return &pjrt_api; } diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_cpu_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_cpu_test.cc index ff32e0cb0e2120..1d10256f143913 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_cpu_test.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_cpu_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "xla/pjrt/c/pjrt_c_api_cpu.h" #include "xla/pjrt/c/pjrt_c_api_test.h" -#include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h" namespace pjrt { namespace { diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_extension.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_extension.h index 3ecdaeafb32749..28b17e5434f2ea 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_extension.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_extension.h @@ -24,17 +24,19 @@ limitations under the License. extern "C" { #endif -#define PJRT_API_GPU_EXTENSION_VERSION 1 +#define PJRT_API_GPU_EXTENSION_VERSION 2 struct PJRT_Gpu_Register_Custom_Call_Args { size_t struct_size; const char* function_name; size_t function_name_size; int api_version; // 0 for an untyped call, 1 -- for typed - void* custom_call_function; + void* handler_instantiate; + void* handler_prepare; + void* handler_initialize; + void* handler_execute; }; -PJRT_DEFINE_STRUCT_TRAITS(PJRT_Gpu_Register_Custom_Call_Args, - custom_call_function); +PJRT_DEFINE_STRUCT_TRAITS(PJRT_Gpu_Register_Custom_Call_Args, handler_execute); // Registers a custom call. typedef PJRT_Error* PJRT_Gpu_Register_Custom_Call( diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc index 2d593290087719..722471140c2818 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_internal.cc @@ -83,6 +83,7 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { {"node_id", PJRT_NamedValue_Type::PJRT_NamedValue_kInt64}, {"num_nodes", PJRT_NamedValue_Type::PJRT_NamedValue_kInt64}, {"enable_mock_nccl", PJRT_NamedValue_Type::PJRT_NamedValue_kBool}, + {"mock_gpu_topology", PJRT_NamedValue_Type::PJRT_NamedValue_kString}, }); PJRT_RETURN_IF_ERROR( ValidateCreateOptions(create_options, kExpectedOptionNameAndTypes)); @@ -141,6 +142,11 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { it != create_options.end()) { enable_mock_nccl = std::get(it->second); } + std::optional mock_gpu_topology; + if (auto it = create_options.find("mock_gpu_topology"); + it != create_options.end()) { + mock_gpu_topology = std::get(it->second); + } xla::GpuClientOptions options; options.allocator_config = allocator_config; @@ -152,6 +158,7 @@ PJRT_Error* PJRT_Client_Create(PJRT_Client_Create_Args* args) { pjrt::ToCppKeyValueStore(args->kv_get_callback, args->kv_get_user_arg, args->kv_put_callback, args->kv_put_user_arg); options.enable_mock_nccl = enable_mock_nccl; + options.mock_gpu_topology = mock_gpu_topology; PJRT_ASSIGN_OR_RETURN(std::unique_ptr client, xla::GetStreamExecutorGpuClient(options)); args->client = pjrt::CreateWrapperClient(std::move(client)); @@ -167,19 +174,37 @@ PJRT_Error* PJRT_ExecuteContext_Create(PJRT_ExecuteContext_Create_Args* args) { return nullptr; } -PJRT_Error* PJRT_GpuDeviceTopology_Create( - PJRT_TopologyDescription_Create_Args* args) { - PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( - "PJRT_TopologyDescription_Create_Args", - PJRT_TopologyDescription_Create_Args_STRUCT_SIZE, args->struct_size)); +namespace { + +struct TargetConfigAndDevices { + stream_executor::GpuTargetConfigProto target_config_proto; + std::vector device_ids; +}; - PJRT_ASSIGN_OR_RETURN(xla::LocalClient * xla_client, - xla::GetGpuXlaClient(/*platform_name=*/std::nullopt, - /*allowed_devices=*/std::nullopt)); +// Parses the 'target_config' entry in 'options'. The option is +// parsed as GpuTargetConfigProto. If there is no 'target_config' in +// 'options', the function falls back to creating a local client, +// returning the local client's target config. +absl::StatusOr GetTargetConfigFromOptions( + const absl::flat_hash_map& options) { + if (auto target_config_it = options.find("target_config"); + target_config_it != options.end()) { + std::string target_config_proto_string = + std::get(target_config_it->second); + stream_executor::GpuTargetConfigProto target_config_proto; + if (!tsl::protobuf::TextFormat::ParseFromString(target_config_proto_string, + &target_config_proto)) { + return absl::FailedPreconditionError( + "Failed to parse GpuTargetConfigProto " + "from the 'target_config' parameter."); + } + return {{target_config_proto, {}}}; + } + TF_ASSIGN_OR_RETURN(xla::LocalClient * xla_client, + xla::GetGpuXlaClient(/*platform_name=*/std::nullopt, + /*allowed_devices=*/std::nullopt)); stream_executor::StreamExecutor* executor = xla_client->backend().default_stream_executor(); - const stream_executor::DeviceDescription& description = - executor->GetDeviceDescription(); std::vector device_ids; device_ids.reserve(xla_client->backend().stream_executors().size()); for (stream_executor::StreamExecutor* executor : @@ -187,13 +212,16 @@ PJRT_Error* PJRT_GpuDeviceTopology_Create( device_ids.push_back(executor->device_ordinal()); } auto gpu_target_config = xla::Compiler::TargetConfig(executor); - // TODO(b/341334898): Create a single-host GPU topology. Will be updated for - // multi-host support in the future. - auto gpu_topology = std::make_shared( - device_ids, description.name(), - /*num_slices=*/1, - /*num_hosts_per_slice=*/1, - /*num_devices_per_host=*/device_ids.size()); + return {{gpu_target_config.ToProto(), device_ids}}; +} + +} // namespace + +PJRT_Error* PJRT_GpuDeviceTopology_Create( + PJRT_TopologyDescription_Create_Args* args) { + PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( + "PJRT_TopologyDescription_Create_Args", + PJRT_TopologyDescription_Create_Args_STRUCT_SIZE, args->struct_size)); // Determine the platform ID and name based on the platform. xla::PjRtPlatformId platform_id = @@ -203,12 +231,48 @@ PJRT_Error* PJRT_GpuDeviceTopology_Create( (std::string(PJRT_GPU_PLUGIN_PLATFORM_NAME) == "ROCM") ? xla::RocmName() : xla::CudaName(); + absl::flat_hash_map create_options = + pjrt::ConvertFromPjRtNamedValueList(args->create_options, + args->num_options); + + PJRT_ASSIGN_OR_RETURN(TargetConfigAndDevices target_config_and_devices, + GetTargetConfigFromOptions(create_options)); + + std::vector& device_ids = target_config_and_devices.device_ids; + stream_executor::GpuTargetConfigProto& target_config_proto = + target_config_and_devices.target_config_proto; + xla::TopologySizes sizes{1, 1, static_cast(device_ids.size())}; + + if (auto topology_it = create_options.find("topology"); + topology_it != create_options.end()) { + std::string topology_string = std::get(topology_it->second); + PJRT_ASSIGN_OR_RETURN(sizes, + xla::TopologySizes::FromString(topology_string)); + } + + if (sizes.GetDeviceCount() == 0) { + // If the user did not specify the topology and we did not + // get any devices from the client, then error out because + // we do not know how many devices the topology should have. + return new PJRT_Error{absl::FailedPreconditionError( + "Cannot create topology without an explicit topology shape or without " + "a client")}; + } + + if (sizes.GetDeviceCount() != device_ids.size()) { + device_ids.resize(sizes.GetDeviceCount()); + absl::c_iota(device_ids, sizes.GetDeviceCount()); + } + + auto gpu_topology = std::make_shared( + device_ids, target_config_proto.device_description_str(), + sizes.num_slices, sizes.num_hosts_per_slice, sizes.num_devices_per_host); + auto pjrt_topology = std::make_unique( platform_id, platform_name, std::move(gpu_topology), absl::flat_hash_map{ - {"target_config", - gpu_target_config.ToProto().SerializeAsString()}}); + {"target_config", target_config_proto.SerializeAsString()}}); args->topology = CreateWrapperDeviceTopology(std::move(pjrt_topology)); return nullptr; } @@ -293,14 +357,17 @@ PJRT_Error* PJRT_Gpu_Register_Custom_Call( switch (args->api_version) { case 0: xla::CustomCallTargetRegistry::Global()->Register( - function_name, args->custom_call_function, - PJRT_GPU_PLUGIN_PLATFORM_NAME); + function_name, args->handler_execute, PJRT_GPU_PLUGIN_PLATFORM_NAME); return nullptr; case 1: xla::ffi::Ffi::RegisterStaticHandler( xla::ffi::GetXlaFfiApi(), function_name, PJRT_GPU_PLUGIN_PLATFORM_NAME, - reinterpret_cast(args->custom_call_function)); + XLA_FFI_Handler_Bundle{ + reinterpret_cast(args->handler_instantiate), + reinterpret_cast(args->handler_prepare), + reinterpret_cast(args->handler_initialize), + reinterpret_cast(args->handler_execute)}); return nullptr; default: return new PJRT_Error{absl::UnimplementedError( diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc index c288315db8b96c..613241ec92f60a 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc @@ -373,11 +373,9 @@ TEST(PjrtCApiPlatformNameTest, AvailablePlatformName) { PJRT_Error* platform_name_error = api->PJRT_Client_PlatformName(&platform_name_args); EXPECT_EQ(platform_name_error, nullptr); -#if TENSORFLOW_USE_ROCM - EXPECT_EQ(platform_name_args.platform_name, expected_platform_name_for_rocm); -#else - EXPECT_EQ(platform_name_args.platform_name, expected_platform_name_for_cuda); -#endif + EXPECT_THAT(platform_name_args.platform_name, + testing::AnyOf(expected_platform_name_for_cuda, + expected_platform_name_for_rocm)); PJRT_Client_Destroy_Args destroy_args; destroy_args.struct_size = PJRT_Client_Destroy_Args_STRUCT_SIZE; @@ -427,6 +425,8 @@ TEST(PJRTGpuDeviceTopologyTest, CreateGpuTopology) { args.struct_size = PJRT_TopologyDescription_Create_Args_STRUCT_SIZE; args.extension_start = nullptr; args.topology = nullptr; + args.num_options = 0; + args.create_options = nullptr; PJRT_Error* error = pjrt_api->PJRT_TopologyDescription_Create(&args); EXPECT_EQ(error, nullptr) << error->status.message(); @@ -435,13 +435,116 @@ TEST(PJRTGpuDeviceTopologyTest, CreateGpuTopology) { reinterpret_cast(args.topology); ASSERT_NE(pjrt_topology, nullptr); -#ifdef TENSORFLOW_USE_ROCM - EXPECT_EQ(pjrt_topology->topology->platform_id(), xla::RocmId()); - EXPECT_EQ(pjrt_topology->topology->platform_name(), xla::RocmName()); -#else - EXPECT_EQ(pjrt_topology->topology->platform_id(), xla::CudaId()); - EXPECT_EQ(pjrt_topology->topology->platform_name(), xla::CudaName()); -#endif + EXPECT_TRUE((pjrt_topology->topology->platform_id() == xla::CudaId() && + pjrt_topology->topology->platform_name() == xla::CudaName()) || + (pjrt_topology->topology->platform_id() == xla::RocmId() && + pjrt_topology->topology->platform_name() == xla::RocmName())); + + PJRT_TopologyDescription_Destroy_Args destroy_args; + destroy_args.struct_size = PJRT_TopologyDescription_Destroy_Args_STRUCT_SIZE; + destroy_args.extension_start = nullptr; + destroy_args.topology = const_cast(pjrt_topology); + PJRT_Error* destroy_error = + pjrt_api->PJRT_TopologyDescription_Destroy(&destroy_args); + EXPECT_EQ(destroy_error, nullptr) << destroy_error->status.message(); +} + +constexpr char const* kTargetConfigString = R"(gpu_device_info { + threads_per_block_limit: 1024 + threads_per_warp: 32 + shared_memory_per_block: 49152 + shared_memory_per_core: 98304 + threads_per_core_limit: 2048 + core_count: 80 + fpus_per_core: 64 + block_dim_limit_x: 2147483647 + block_dim_limit_y: 65535 + block_dim_limit_z: 65535 + memory_bandwidth: 898048000000 + l2_cache_size: 6291456 + clock_rate_ghz: 1.53 + device_memory_size: 34072559616 + shared_memory_per_block_optin: 98304 + cuda_compute_capability { + major: 7 + } + registers_per_core_limit: 65536 + registers_per_block_limit: 65536 +} +platform_name: "CUDA" +dnn_version_info { + major: 9 + minor: 3 +} +device_description_str: "Tesla V100-SXM2-32GB" +)"; + +TEST(PJRTGpuDeviceTopologyTest, CreateExplicitGpuTopologyAndTargetConfig) { + auto pjrt_api = gpu_plugin::GetGpuPjrtApi(); + + absl::flat_hash_map options = { + {"topology", static_cast("16 x 2 x 4")}, + {"target_config", static_cast(kTargetConfigString)}}; + TF_ASSERT_OK_AND_ASSIGN(std::vector c_options, + ::pjrt::ConvertToPjRtNamedValueList(options)); + + PJRT_TopologyDescription_Create_Args args; + args.struct_size = PJRT_TopologyDescription_Create_Args_STRUCT_SIZE; + args.extension_start = nullptr; + args.topology = nullptr; + args.num_options = c_options.size(); + args.create_options = c_options.data(); + + PJRT_Error* error = pjrt_api->PJRT_TopologyDescription_Create(&args); + EXPECT_EQ(error, nullptr) << error->status.message(); + + auto pjrt_topology = + reinterpret_cast(args.topology); + ASSERT_NE(pjrt_topology, nullptr); + + EXPECT_TRUE((pjrt_topology->topology->platform_id() == xla::CudaId() && + pjrt_topology->topology->platform_name() == xla::CudaName()) || + (pjrt_topology->topology->platform_id() == xla::RocmId() && + pjrt_topology->topology->platform_name() == xla::RocmName())); + + EXPECT_EQ(pjrt_topology->topology->ProcessCount().value(), 16 * 2); + EXPECT_EQ(pjrt_topology->topology->DeviceDescriptions().size(), 16 * 2 * 4); + EXPECT_EQ(pjrt_topology->topology->DeviceDescriptions()[0]->device_kind(), + "Tesla V100-SXM2-32GB"); + + PJRT_TopologyDescription_Destroy_Args destroy_args; + destroy_args.struct_size = PJRT_TopologyDescription_Destroy_Args_STRUCT_SIZE; + destroy_args.extension_start = nullptr; + destroy_args.topology = const_cast(pjrt_topology); + PJRT_Error* destroy_error = + pjrt_api->PJRT_TopologyDescription_Destroy(&destroy_args); + EXPECT_EQ(destroy_error, nullptr) << destroy_error->status.message(); +} + +TEST(PJRTGpuDeviceTopologyTest, CreateExplicitGpuTopology) { + auto pjrt_api = gpu_plugin::GetGpuPjrtApi(); + + absl::flat_hash_map options = { + {"topology", static_cast("16 x 2 x 4")}}; + TF_ASSERT_OK_AND_ASSIGN(std::vector c_options, + ::pjrt::ConvertToPjRtNamedValueList(options)); + + PJRT_TopologyDescription_Create_Args args; + args.struct_size = PJRT_TopologyDescription_Create_Args_STRUCT_SIZE; + args.extension_start = nullptr; + args.topology = nullptr; + args.num_options = c_options.size(); + args.create_options = c_options.data(); + + PJRT_Error* error = pjrt_api->PJRT_TopologyDescription_Create(&args); + EXPECT_EQ(error, nullptr) << error->status.message(); + + auto pjrt_topology = + reinterpret_cast(args.topology); + ASSERT_NE(pjrt_topology, nullptr); + + EXPECT_EQ(pjrt_topology->topology->ProcessCount().value(), 16 * 2); + EXPECT_EQ(pjrt_topology->topology->DeviceDescriptions().size(), 16 * 2 * 4); PJRT_TopologyDescription_Destroy_Args destroy_args; destroy_args.struct_size = PJRT_TopologyDescription_Destroy_Args_STRUCT_SIZE; @@ -461,7 +564,10 @@ TEST(PjrtCApiGpuExtensionTest, CustomCallUntyped) { args.function_name = function_name.c_str(); args.function_name_size = function_name.size(); args.api_version = 0; - args.custom_call_function = reinterpret_cast(&TestCustomCallV2); + args.handler_instantiate = nullptr; + args.handler_prepare = nullptr; + args.handler_initialize = nullptr; + args.handler_execute = reinterpret_cast(&TestCustomCallV2); auto api = GetPjrtApi(); const PJRT_Extension_Base* next = reinterpret_cast(api->extension_start); @@ -491,7 +597,10 @@ TEST(PjrtCApiGpuExtensionTest, CustomCallTyped) { args.function_name = function_name.c_str(); args.function_name_size = function_name.size(); args.api_version = 1; - args.custom_call_function = reinterpret_cast(kNoop); + args.handler_instantiate = nullptr; + args.handler_prepare = nullptr; + args.handler_initialize = nullptr; + args.handler_execute = reinterpret_cast(kNoop); auto api = GetPjrtApi(); const PJRT_Extension_Base* next = reinterpret_cast(api->extension_start); diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc index b9508cf24950b4..a0596784085cb9 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -34,6 +34,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/time/time.h" #include "absl/types/span.h" +#include "stablehlo/dialect/Version.h" #include "xla/layout.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_layouts_extension.h" @@ -45,6 +46,7 @@ limitations under the License. #include "xla/pjrt/pjrt_future.h" #include "xla/primitive_util.h" #include "xla/shape_util.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" @@ -295,6 +297,8 @@ PJRT_Buffer_Type ConvertToPjRtBufferType(xla::PrimitiveType type) { return PJRT_Buffer_Type::PJRT_Buffer_Type_F64; case xla::PrimitiveType::F8E5M2: return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2; + case xla::PrimitiveType::F8E4M3: + return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3; case xla::PrimitiveType::F8E4M3FN: return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3FN; case xla::PrimitiveType::F8E4M3B11FNUZ: @@ -303,6 +307,8 @@ PJRT_Buffer_Type ConvertToPjRtBufferType(xla::PrimitiveType type) { return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2FNUZ; case xla::PrimitiveType::F8E4M3FNUZ: return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3FNUZ; + case xla::PrimitiveType::F8E3M4: + return PJRT_Buffer_Type::PJRT_Buffer_Type_F8E3M4; case xla::PrimitiveType::C64: return PJRT_Buffer_Type::PJRT_Buffer_Type_C64; case xla::PrimitiveType::C128: @@ -358,6 +364,8 @@ xla::PrimitiveType ConvertFromPjRtBufferType(PJRT_Buffer_Type type) { return xla::PrimitiveType::C128; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E5M2: return xla::PrimitiveType::F8E5M2; + case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3: + return xla::PrimitiveType::F8E4M3; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3FN: return xla::PrimitiveType::F8E4M3FN; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3B11FNUZ: @@ -366,6 +374,8 @@ xla::PrimitiveType ConvertFromPjRtBufferType(PJRT_Buffer_Type type) { return xla::PrimitiveType::F8E5M2FNUZ; case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E4M3FNUZ: return xla::PrimitiveType::F8E4M3FNUZ; + case PJRT_Buffer_Type::PJRT_Buffer_Type_F8E3M4: + return xla::PrimitiveType::F8E3M4; case PJRT_Buffer_Type::PJRT_Buffer_Type_INVALID: CHECK(false) << "Buffer type is not supported in C API layer."; } @@ -594,19 +604,46 @@ absl::Status ValidateCreateOptions( return absl::OkStatus(); } -const std::vector& GetXlaPluginCAttributes() { - constexpr absl::string_view kXlaVersion = "xla_version"; +static PJRT_NamedValue XlaVersion(absl::string_view name) { PJRT_NamedValue c_value; c_value.struct_size = PJRT_NamedValue_STRUCT_SIZE; c_value.extension_start = nullptr; - c_value.name = kXlaVersion.data(); - c_value.name_size = kXlaVersion.size(); + c_value.name = name.data(); + c_value.name_size = name.size(); c_value.type = PJRT_NamedValue_Type::PJRT_NamedValue_kInt64; // TODO(b/327203806): figure out where to keep the xla_version. c_value.int64_value = 2; c_value.value_size = 1; - static const std::vector* c_values = - new std::vector({c_value}); + return c_value; +} + +template +static PJRT_NamedValue StableHloVersion(absl::string_view name, + mlir::vhlo::Version version) { + PJRT_NamedValue c_value; + c_value.struct_size = PJRT_NamedValue_STRUCT_SIZE; + c_value.extension_start = nullptr; + c_value.name = name.data(); + c_value.name_size = name.size(); + c_value.type = PJRT_NamedValue_Type::PJRT_NamedValue_kInt64List; + static int64_t triple[3] = {version.getMajor(), version.getMinor(), + version.getPatch()}; + c_value.int64_array_value = triple; + c_value.value_size = 3; + return c_value; +} + +const std::vector& GetXlaPluginCAttributes() { + static const std::vector* c_values = new std::vector< + PJRT_NamedValue>({ + XlaVersion("xla_version"), + // TODO: (b/375454646) Uncomment once frameworks have bugfix: + // https://github.com/openxla/xla/commit/2f99455cdf99e844ddad17de9f4714997023d243 + // StableHloVersion<0>("stablehlo_current_version", + // mlir::vhlo::Version::getCurrentVersion()), + StableHloVersion<1>("stablehlo_minimum_version", + mlir::vhlo::Version::getMinimumVersion()), + }); return *c_values; } diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.h index 5f7746db24a497..759569123456ee 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.h @@ -26,6 +26,7 @@ limitations under the License. #include "absl/base/attributes.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/layout.h" diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc index 8d0a51a48bc840..801cd2a3e65e80 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -25,17 +26,14 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" #include "absl/time/time.h" +#include "stablehlo/dialect/Version.h" #include "xla/layout.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h" #include "xla/pjrt/distributed/in_memory_key_value_store.h" -#include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "tsl/platform/status.h" -#include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" namespace pjrt { @@ -207,5 +205,30 @@ TEST(PjRtCApiHelperTest, ConvertFromCLayoutToLayoutNoTile) { EXPECT_EQ(layout.ToString(), "{1,0}"); } +TEST(PjRtCApiHelperTest, GetXlaPluginCAttributes) { + auto result = GetXlaPluginCAttributes(); + std::unordered_map map; + for (PJRT_NamedValue &nv : result) { + auto [_, did_not_exist_yet] = map.insert({nv.name, &nv}); + EXPECT_TRUE(did_not_exist_yet); + } + EXPECT_TRUE(map.find("xla_version") != map.end()); + // TODO: (b/375454646) Uncomment once frameworks have bugfix: + // https://github.com/openxla/xla/commit/2f99455cdf99e844ddad17de9f4714997023d243 + // + // PJRT_NamedValue *current = map["stablehlo_current_version"]; + // mlir::vhlo::Version current_version = + // mlir::vhlo::Version::getCurrentVersion(); + // EXPECT_TRUE(current->int64_array_value[0] == current_version.getMajor()); + // EXPECT_TRUE(current->int64_array_value[1] == current_version.getMinor()); + // EXPECT_TRUE(current->int64_array_value[2] == current_version.getPatch()); + PJRT_NamedValue *minimum = map["stablehlo_minimum_version"]; + mlir::vhlo::Version minimum_version = + mlir::vhlo::Version::getMinimumVersion(); + EXPECT_TRUE(minimum->int64_array_value[0] == minimum_version.getMajor()); + EXPECT_TRUE(minimum->int64_array_value[1] == minimum_version.getMinor()); + EXPECT_TRUE(minimum->int64_array_value[2] == minimum_version.getPatch()); +} + } // namespace } // namespace pjrt diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_test.cc index c341df9e82f324..1781c4100744dd 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_test.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -36,6 +37,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/client/executable_build_options.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/pjrt/c/pjrt_c_api.h" @@ -46,7 +48,6 @@ limitations under the License. #include "xla/pjrt/pjrt_future.h" #include "xla/service/computation_placer.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_parser.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/literal_test_util.h" @@ -312,8 +313,6 @@ TEST_F(PjrtCApiTest, LookupDeviceOutOfRangeId) { ASSERT_EQ(status, expected); } -static constexpr std::string_view kExecutableName = "operation"; - void destroy_executable(PJRT_LoadedExecutable* executable, const PJRT_Api* api) { PJRT_LoadedExecutable_Destroy_Args args{ @@ -484,6 +483,24 @@ TEST_F(PjrtCApiTest, CompileInvalidProgramFormat) { ::pjrt::MakeErrorDeleter(api_)(error); } +TEST_F(PjrtCApiTest, PluginAttributes) { + PJRT_Plugin_Attributes_Args args; + args.struct_size = PJRT_Plugin_Attributes_Args_STRUCT_SIZE; + args.extension_start = nullptr; + PJRT_Error* error = api_->PJRT_Plugin_Attributes(&args); + ASSERT_EQ(error, nullptr); + std::set names; + for (int i = 0; i < args.num_attributes; i++) { + auto [_, did_not_exist_yet] = names.insert(args.attributes[i].name); + EXPECT_TRUE(did_not_exist_yet); + } + EXPECT_TRUE(names.find("xla_version") != names.end()); + // TODO: (b/375454646) Uncomment once frameworks have bugfix: + // https://github.com/openxla/xla/commit/2f99455cdf99e844ddad17de9f4714997023d243 + // EXPECT_TRUE(names.find("stablehlo_current_version") != names.end()); + EXPECT_TRUE(names.find("stablehlo_minimum_version") != names.end()); +} + // --------------------------------- Devices ----------------------------------- TEST_F(PjrtCApiTest, DeviceId) { diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_test_base.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_test_base.cc index b2a42edf5a31e0..9602813c573c52 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_test_base.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_test_base.cc @@ -28,6 +28,7 @@ limitations under the License. #include "xla/client/executable_build_options.h" #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_helpers.h" +#include "xla/pjrt/compile_options.pb.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index 54b8dbb6514350..f6f3b5f7e99d11 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -41,7 +41,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/layout.h" #include "xla/literal.h" @@ -1570,9 +1570,20 @@ PJRT_Error* PJRT_Executable_DeserializeAndLoad( absl::string_view serialized(args->serialized_executable, args->serialized_executable_size); + std::optional overriden_options; + + if (args->overridden_serialized_compile_options && + args->overridden_serialized_compile_options_size > 0) { + PJRT_ASSIGN_OR_RETURN( + overriden_options, + ParseCompileOptions(absl::string_view( + args->overridden_serialized_compile_options, + args->overridden_serialized_compile_options_size))); + } + PJRT_ASSIGN_OR_RETURN(std::unique_ptr executable, args->client->client->DeserializeExecutable( - serialized, /*options=*/std::nullopt)); + serialized, overriden_options)); args->loaded_executable = new PJRT_LoadedExecutable(std::move(executable), args->client); diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h index f3d375cd7f9d9a..cf2f2cc026b27f 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.h @@ -26,6 +26,7 @@ limitations under the License. #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "xla/pjrt/c/pjrt_c_api.h" diff --git a/third_party/xla/xla/pjrt/compile_options.proto b/third_party/xla/xla/pjrt/compile_options.proto index bd23ca73f6244c..56b0f58e10757c 100644 --- a/third_party/xla/xla/pjrt/compile_options.proto +++ b/third_party/xla/xla/pjrt/compile_options.proto @@ -7,7 +7,7 @@ import "xla/xla.proto"; import "xla/xla_data.proto"; // A serialization of xla::ExecutableBuildOptions. -// Next id: 20. +// Next id: 22. message ExecutableBuildOptionsProto { // If set, this is the device to build the computation for. Valid // device_ordinal values are: 0 to # of devices - 1. These values are @@ -45,6 +45,24 @@ message ExecutableBuildOptionsProto { // Whether to automatically generate XLA shardings for SPMD partitioner. bool use_auto_spmd_partitioning = 7; + // The amount of effort to spend on optimizing for minimizing program + // execution time, as a value in [-1.0, +1.0]. The baseline is 0.0, which + // strongly prioritizes execution time at the cost of longer compile times, + // suitable for production workloads. A value of -0.5 would be appropriate for + // research use cases that prefer faster compilations to iterate more quickly. + // Positive values, on the other hand, might enable costly optimizations that + // are off by default. + float exec_time_optimization_effort = 20; + + // The amount of effort to spend on making the program fit in memory (where + // "fit in memory" here has a backend-dependent meaning), as a value in + // [-1.0,+1.0]. The baseline is 0.0, which expends significant effort on + // attempting to make the program fit. A value of -1.0 would be appropriate + // for use cases that wish to spend minimal effort here and fail as quickly as + // possible instead. Positive values, on the other hand, might enable costly + // algorithms to reduce memory usage that are off by default. + float memory_fitting_effort = 21; + // Whether HLOs should be deduplicated. bool deduplicate_hlo = 8; diff --git a/third_party/xla/xla/pjrt/cpu/BUILD b/third_party/xla/xla/pjrt/cpu/BUILD index 81b92d8963bd68..23c88d504bd1c3 100644 --- a/third_party/xla/xla/pjrt/cpu/BUILD +++ b/third_party/xla/xla/pjrt/cpu/BUILD @@ -140,6 +140,7 @@ cc_library( ":cpu_topology", ":tracked_tfrt_cpu_device_buffer", "//xla:array", + "//xla:cpu_function_runtime", "//xla:debug_options_flags", "//xla:executable_run_options", "//xla:literal", @@ -149,10 +150,11 @@ cc_library( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/backends/cpu/runtime:buffer_allocations", + "//xla/backends/cpu/runtime:thread_pool_task_runner", "//xla/backends/cpu/runtime:thunk", "//xla/backends/cpu/runtime:thunk_executor", "//xla/client:executable_build_options", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", "//xla/pjrt:compile_options_proto_cc", "//xla/pjrt:host_memory_spaces", @@ -187,11 +189,12 @@ cc_library( "//xla/service/cpu:cpu_runtime", "//xla/service/cpu:cpu_xfeed", "//xla/service/cpu:simple_orc_jit", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", "//xla/tsl/concurrency:ref_count", "//xla/tsl/lib/strings:proto_serialization", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", "@com_google_absl//absl/container:flat_hash_map", @@ -232,16 +235,15 @@ xla_cc_test( "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/client:xla_computation", "//xla/ffi", "//xla/ffi:ffi_api", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/parser:hlo_parser", "//xla/pjrt:host_memory_spaces", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_executable", - "//xla/service:hlo_parser", "//xla/service:hlo_proto_cc", "//xla/tests:literal_test_util", - "//xla/tests:test_utils", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.cc b/third_party/xla/xla/pjrt/cpu/cpu_client.cc index def5ea3aaab7b6..aa15c09e59be1e 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client.cc +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.cc @@ -31,6 +31,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/base/casts.h" #include "absl/base/dynamic_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" @@ -48,12 +49,14 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" #include "xla/array.h" #include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/thread_pool_task_runner.h" #include "xla/backends/cpu/runtime/thunk.h" #include "xla/backends/cpu/runtime/thunk_executor.h" #include "xla/client/executable_build_options.h" -#include "xla/client/xla_computation.h" +#include "xla/cpu_function_runtime.h" #include "xla/debug_options_flags.h" #include "xla/executable_run_options.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -380,7 +383,7 @@ static int CpuDeviceCount() { } absl::StatusOr> GetTfrtCpuClient( - const CpuClientOptions& options) { + CpuClientOptions options) { // Need at least CpuDeviceCount threads to launch one collective. int cpu_device_count = options.cpu_device_count.value_or(CpuDeviceCount()); size_t num_threads = std::max(DefaultThreadPoolSize(), cpu_device_count); @@ -395,7 +398,8 @@ absl::StatusOr> GetTfrtCpuClient( return std::unique_ptr(std::make_unique( options.process_id, std::move(devices), std::move(options.collectives), - num_threads, options.asynchronous)); + num_threads, options.asynchronous, + std::move(options.customize_hlo_module_config))); } // An upper bound on the number of threads to use for intra-op parallelism. It @@ -416,7 +420,8 @@ static tsl::ThreadOptions GetThreadOptions() { TfrtCpuClient::TfrtCpuClient( int process_index, std::vector> devices, std::shared_ptr collectives, size_t num_threads, - bool asynchronous) + bool asynchronous, + std::function customize_hlo_module_config) : process_index_(process_index), owned_devices_(std::move(devices)), computation_placer_(std::make_unique()), @@ -438,7 +443,8 @@ TfrtCpuClient::TfrtCpuClient( topology_(TfrtCpuTopologyDescription::Create( platform_id(), platform_name(), platform_version(), owned_devices_, cpu::DetectMachineAttributes())), - asynchronous_(asynchronous) { + asynchronous_(asynchronous), + customize_hlo_module_config_(std::move(customize_hlo_module_config)) { for (const std::unique_ptr& device : owned_devices_) { devices_.push_back(device.get()); CHECK( @@ -705,7 +711,8 @@ static absl::StatusOr> JitCompile( const absl::Span argument_layouts, const ExecutableBuildOptions& build_options, const ExecutionOptions& execution_options, - const xla::Compiler::CompileOptions& compile_options, int num_threads) { + const xla::Compiler::CompileOptions& compile_options, int num_threads, + std::function customize_hlo_module_config) { TF_ASSIGN_OR_RETURN(ProgramShape program_shape, computation.GetProgramShape()); // Unoptimized HloModuleConfig. @@ -715,6 +722,11 @@ static absl::StatusOr> JitCompile( execution_options.num_replicas(), num_threads, /*aot_options=*/nullptr)); + // Apply the user-provided callback to customize the HloModuleConfig. + if (customize_hlo_module_config) { + customize_hlo_module_config(*hlo_module_config); + } + // Unoptimized HloModule. const xla::HloModuleProto& hlo_module_proto = computation.proto(); TF_ASSIGN_OR_RETURN( @@ -823,7 +835,8 @@ absl::StatusOr> TfrtCpuClient::Compile( std::unique_ptr cpu_executable, JitCompile(computation, argument_layout_pointers, build_options, execution_options, compile_options, - eigen_intraop_device()->getPool()->NumThreads())); + eigen_intraop_device()->getPool()->NumThreads(), + customize_hlo_module_config_)); auto cpu_executable_ptr = tensorflow::down_cast(cpu_executable.get()); @@ -867,6 +880,11 @@ absl::StatusOr> TfrtCpuClient::Compile( return Compile(xla_computation, options); } +static bool IsAlignedData(void* ptr) { + return (absl::bit_cast(ptr) & + (cpu_function_runtime::MinAlign() - 1)) == 0; +} + absl::StatusOr> TfrtCpuClient::CreateViewOfDeviceBuffer( void* device_ptr, const Shape& shape, PjRtDevice* device, @@ -877,6 +895,13 @@ TfrtCpuClient::CreateViewOfDeviceBuffer( "TfrtCpuClient::CreateViewOfDeviceBuffer does not support `stream` " "argument."); } + if (!IsAlignedData(device_ptr)) { + return InvalidArgument( + "Can't create a view of buffer with unaligned data, ptr: %#x is not " + "aligned to %d bytes. ", + reinterpret_cast(device_ptr), + cpu_function_runtime::MinAlign()); + } absl::InlinedVector, 4> buffers; size_t byte_size = ShapeUtil::ByteSizeOf(shape); auto non_owning_buffer = @@ -1601,11 +1626,8 @@ absl::StatusOr TfrtCpuExecutable::ExecuteHelper( cpu::Thunk::CustomCallExecuteParams custom_call_execute_params, cpu::Thunk::CustomCallExecuteParams::Create(&run_options)); - cpu::Thunk::TaskRunner task_runner = - [&run_options](cpu::Thunk::Task task) { - run_options.intra_op_thread_pool()->getPool()->Schedule( - std::move(task)); - }; + cpu::ThreadPoolTaskRunner task_runner( + run_options.intra_op_thread_pool()->getPool()); cpu::Thunk::ExecuteParams execute_params = { &cpu_executable->function_registry(), @@ -1740,11 +1762,8 @@ absl::StatusOr TfrtCpuExecutable::ExecuteHelper( custom_call_params = cpu::Thunk::CustomCallExecuteParams::Create(&run_options); - cpu::Thunk::TaskRunner task_runner = - [&run_options](cpu::Thunk::Task task) { - run_options.intra_op_thread_pool()->getPool()->Schedule( - std::move(task)); - }; + cpu::ThreadPoolTaskRunner task_runner( + run_options.intra_op_thread_pool()->getPool()); if (collective_params.ok()) { cpu::Thunk::ExecuteParams execute_params = { diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.h b/third_party/xla/xla/pjrt/cpu/cpu_client.h index dfc193a273b0bd..bcb15ac8c3377c 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client.h +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.h @@ -37,8 +37,8 @@ limitations under the License. #include "absl/types/span.h" #include "unsupported/Eigen/CXX11/Tensor" #include "mlir/IR/BuiltinOps.h" -#include "xla/client/xla_computation.h" #include "xla/executable_run_options.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/layout.h" #include "xla/literal.h" @@ -60,6 +60,7 @@ limitations under the License. #include "xla/service/executable.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_cost_analysis.h" +#include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/util.h" @@ -255,10 +256,11 @@ class TfrtCpuDevice final : public PjRtDevice { class TfrtCpuClient final : public PjRtClient { public: - TfrtCpuClient(int process_index, - std::vector> devices, - std::shared_ptr collectives, - size_t num_threads, bool asynchronous); + TfrtCpuClient( + int process_index, std::vector> devices, + std::shared_ptr collectives, + size_t num_threads, bool asynchronous, + std::function customize_hlo_module_config); ~TfrtCpuClient() override; int process_index() const override { return process_index_; } @@ -289,9 +291,7 @@ class TfrtCpuClient final : public PjRtClient { absl::string_view platform_name() const override { return CpuName(); } - absl::string_view platform_version() const override { return ""; } - - PjRtRuntimeType runtime_type() const override { return kTfrt; } + absl::string_view platform_version() const override { return CpuName(); } absl::StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const override; @@ -479,6 +479,9 @@ class TfrtCpuClient final : public PjRtClient { // this client. Only applies to non-parallel computations. bool asynchronous_; + // A callback to customize the HloModuleConfig for each compiled module. + std::function customize_hlo_module_config_; + // Used to prevent too much parallelism: we will not enqueue next non-parallel // computation until last one is done within each user thread. // TODO(yueshengys): Consider moving the enqueuing/ordering logic to JAX via @@ -709,16 +712,21 @@ struct CpuClientOptions { // Distributed collectives implementation. Optional. If not provided, an // in-process collectives implementation will be used. std::shared_ptr collectives; + + // If defined this function will be called on the HloModuleConfig before + // compilation, and allows users to set custom flags. + std::function customize_hlo_module_config; }; + absl::StatusOr> GetTfrtCpuClient( - const CpuClientOptions& options); + CpuClientOptions options); // Deprecated. Use the overload that takes 'options' instead. inline absl::StatusOr> GetTfrtCpuClient( bool asynchronous) { CpuClientOptions options; options.asynchronous = asynchronous; - return GetTfrtCpuClient(options); + return GetTfrtCpuClient(std::move(options)); } // Deprecated. Use the overload that takes 'options' instead. @@ -730,7 +738,7 @@ inline absl::StatusOr> GetTfrtCpuClient( options.cpu_device_count = cpu_device_count; options.max_inflight_computations_per_device = max_inflight_computations_per_device; - return GetTfrtCpuClient(options); + return GetTfrtCpuClient(std::move(options)); } } // namespace xla diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client_test.cc b/third_party/xla/xla/pjrt/cpu/cpu_client_test.cc index d0b634f17b5527..21f33067a714bf 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client_test.cc +++ b/third_party/xla/xla/pjrt/cpu/cpu_client_test.cc @@ -15,10 +15,6 @@ limitations under the License. #include "xla/pjrt/cpu/cpu_client.h" -#include "xla/service/hlo.pb.h" -#include "xla/types.h" -#include "xla/xla_data.pb.h" - #ifndef _WIN32 #include #endif @@ -29,6 +25,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -36,21 +33,23 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/synchronization/notification.h" -#include "xla/client/xla_computation.h" #include "xla/ffi/ffi.h" #include "xla/ffi/ffi_api.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/pjrt/host_memory_spaces.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" -#include "xla/service/hlo_parser.h" +#include "xla/service/hlo.pb.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/literal_test_util.h" -#include "xla/tests/test_utils.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/types.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/file_system.h" @@ -155,7 +154,8 @@ TEST(TfrtCpuClientTest, HloSnapshot) { CpuClientOptions cpu_options; cpu_options.cpu_device_count = 1; - TF_ASSERT_OK_AND_ASSIGN(auto client, GetTfrtCpuClient(cpu_options)); + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetTfrtCpuClient(std::move(cpu_options))); TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseAndReturnUnverifiedModule(kProgram, {})); diff --git a/third_party/xla/xla/pjrt/cpu/pjrt_client_test_cpu.cc b/third_party/xla/xla/pjrt/cpu/pjrt_client_test_cpu.cc index f0aea4ec2f3264..3b59265772dd5c 100644 --- a/third_party/xla/xla/pjrt/cpu/pjrt_client_test_cpu.cc +++ b/third_party/xla/xla/pjrt/cpu/pjrt_client_test_cpu.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/pjrt/pjrt_client_test.h" #include "xla/pjrt/cpu/cpu_client.h" +#include "xla/pjrt/pjrt_client_test.h" namespace xla { namespace { @@ -23,7 +23,7 @@ namespace { const bool kUnused = (RegisterTestClientFactory([]() { CpuClientOptions options; options.cpu_device_count = 4; - return GetTfrtCpuClient(options); + return GetTfrtCpuClient(std::move(options)); }), true); diff --git a/third_party/xla/xla/pjrt/distributed/BUILD b/third_party/xla/xla/pjrt/distributed/BUILD index 481ddfa0cabc29..dc4933c984ddb4 100644 --- a/third_party/xla/xla/pjrt/distributed/BUILD +++ b/third_party/xla/xla/pjrt/distributed/BUILD @@ -28,6 +28,7 @@ cc_library( "//xla/tsl/distributed_runtime/coordination:coordination_service_impl", "//xla/tsl/distributed_runtime/rpc:async_service_interface", "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_service_impl", + "//xla/tsl/protobuf:coordination_config_proto_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", @@ -37,7 +38,6 @@ cc_library( "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:random", - "@local_tsl//tsl/protobuf:coordination_config_proto_cc", ] + tsl_grpc_cc_dependencies(), ) @@ -73,6 +73,8 @@ cc_library( "//xla/tsl/distributed_runtime/coordination:coordination_client", "//xla/tsl/distributed_runtime/coordination:coordination_service_agent", "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_client", + "//xla/tsl/protobuf:coordination_config_proto_cc", + "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -81,8 +83,6 @@ cc_library( "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:coordination_config_proto_cc", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ] + tsl_grpc_cc_dependencies(), ) diff --git a/third_party/xla/xla/pjrt/distributed/client.cc b/third_party/xla/xla/pjrt/distributed/client.cc index 0f4f7fff9d809d..a0dfcbe8bf6f10 100644 --- a/third_party/xla/xla/pjrt/distributed/client.cc +++ b/third_party/xla/xla/pjrt/distributed/client.cc @@ -36,9 +36,9 @@ limitations under the License. #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" #include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/coordination_config.pb.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace xla { @@ -95,10 +95,7 @@ DistributedRuntimeCoordinationServiceClient:: config.set_poll_for_error_from_service_at_startup( options.poll_for_error_from_service_at_startup); auto error_fn = [timeout_fn = options.missed_heartbeat_callback]( - const absl::Status& status) { - LOG(ERROR) << "Coordination service agent in error status: " << status; - timeout_fn(status, /*coordinator_reported_failure=*/true); - }; + const absl::Status& status) { timeout_fn(status); }; std::unique_ptr leader_client; leader_client.reset(tsl::NewGrpcCoordinationClient(channel)); diff --git a/third_party/xla/xla/pjrt/distributed/client.h b/third_party/xla/xla/pjrt/distributed/client.h index 2387fe6dd452f5..0654522bb78818 100644 --- a/third_party/xla/xla/pjrt/distributed/client.h +++ b/third_party/xla/xla/pjrt/distributed/client.h @@ -75,38 +75,25 @@ class DistributedRuntimeClient { // is reported by the coordinator, or we have not heard from the coordinator // recently. `coordinator_reported_failure` is true in the former case. // Exposed so tests can override this behavior to something non-fatal. - std::function - missed_heartbeat_callback = - [](absl::Status status, bool coordinator_reported_failure) { - if (coordinator_reported_failure) { - LOG(QFATAL) - << "Terminating process because the coordinator detected " - "missing heartbeats. This most likely indicates that " - "another task died; see the other task logs for more " - "details. Disable Python buffering, i.e. `python -u`, " - "to be sure to see all the previous output. " - "absl::Status: " - << status; - } else { - LOG(QFATAL) - << "Terminating process because of missing heartbeat " - "response from the coordinator. This most likely " - "indicates that the coordinator task died; see the " - "coordinator's task logs for more details. " - "Disable Python buffering, i.e. `python -u`, to be " - "sure to see all the previous output. absl::Status: " - << status; - } - }; + std::function missed_heartbeat_callback = + [](const absl::Status& status) { + LOG(QFATAL) << "Terminating process because the JAX distributed " + "service detected fatal errors. This most likely " + "indicates that another task died; see the other task " + "logs for more details. Disable Python buffering, " + "i.e. `python -u`, to be sure to see all the " + "previous output. " + "absl::Status: " + << status; + }; // For testing. Should the client explicitly Shutdown() on destruction? bool shutdown_on_destruction = true; // Whether the client should send a request to wait for error from the // coordination service at the startup. - // TODO(b/355706798): Enable this by default once we confirm this works for - // all cases and eventually remove this option. - bool poll_for_error_from_service_at_startup = false; + // TODO(b/355706798): eventually remove this option. + bool poll_for_error_from_service_at_startup = true; }; virtual ~DistributedRuntimeClient() = default; diff --git a/third_party/xla/xla/pjrt/distributed/client_server_test.cc b/third_party/xla/xla/pjrt/distributed/client_server_test.cc index 1b55bab10e1caa..4d76c608bfadfd 100644 --- a/third_party/xla/xla/pjrt/distributed/client_server_test.cc +++ b/third_party/xla/xla/pjrt/distributed/client_server_test.cc @@ -32,13 +32,13 @@ limitations under the License. #include "absl/time/time.h" #include "absl/types/span.h" #include "grpcpp/channel.h" -#include "grpcpp/create_channel.h" #include "grpcpp/security/credentials.h" #include "grpcpp/security/server_credentials.h" #include "grpcpp/server.h" #include "grpcpp/server_builder.h" #include "grpcpp/support/channel_arguments.h" #include "xla/pjrt/distributed/client.h" +#include "xla/pjrt/distributed/distributed.h" #include "xla/pjrt/distributed/protocol.pb.h" #include "xla/pjrt/distributed/service.h" #include "xla/pjrt/distributed/topology_util.h" @@ -59,7 +59,7 @@ using ::testing::Pair; using ::testing::UnorderedElementsAre; constexpr absl::Duration kHeartbeatInterval = absl::Milliseconds(500); -constexpr int kMaxMissingHeartbeats = 3; +constexpr int kMaxMissingHeartbeats = 5; constexpr absl::Duration kBarrierTimeout = absl::Milliseconds(200); class ClientServerTest : public testing::Test { @@ -72,51 +72,36 @@ class ClientServerTest : public testing::Test { client_options.heartbeat_interval = kHeartbeatInterval; client_options.max_missing_heartbeats = kMaxMissingHeartbeats; if (channel == nullptr) { - channel = server_->InProcessChannel(::grpc::ChannelArguments()); + channel = coord_service_->server()->InProcessChannel( + ::grpc::ChannelArguments()); } return GetDistributedRuntimeClient(channel, client_options); } void StartService(int num_nodes, - CoordinationServiceImpl::Options service_options = {}, - absl::string_view service_address = "") { - ::grpc::ServerBuilder builder; + CoordinationServiceImpl::Options service_options = {}) { + int port = tsl::testing::PickUnusedPortOrDie(); + service_address_ = absl::StrCat("[::]:", port); + service_options.num_nodes = num_nodes; // Set a small heartbeat interval for quicker tests. service_options.heartbeat_interval = kHeartbeatInterval; service_options.max_missing_heartbeats = kMaxMissingHeartbeats; - // Add a listening port if address is specified. - if (!service_address.empty()) { - auto credentials = ::grpc::InsecureServerCredentials(); - builder.AddListeningPort(std::string(service_address), credentials); - } - // Set up and register service on the gRPC server. - coord_service_ = - std::make_unique(service_options, &builder); - server_ = builder.BuildAndStart(); - coord_service_->StartRpcThread(); - } - - // Shut down the server. - void Stop() { - // Avoid shutting down the server twice if the test has already called - // Stop() earlier. - if (stop_is_already_called_) { - return; - } - server_->Shutdown(); - stop_is_already_called_ = true; + coord_service_ = DistributedRuntimeService::Get( + service_address_, ::grpc::InsecureServerCredentials(), + service_options) + .value(); } - void TearDown() override { Stop(); } + std::string service_address() { return service_address_; } - std::unique_ptr<::grpc::Server> server_; + void StopService() { coord_service_ = nullptr; } private: - std::unique_ptr coord_service_; - bool stop_is_already_called_ = false; + std::unique_ptr coord_service_; + std::string service_address_ = ""; }; TEST_F(ClientServerTest, ConnectAndShutdownAreBarriers) { @@ -379,7 +364,8 @@ TEST_F(ClientServerTest, ZeroInitTimeoutShouldStillWaitForOtherTasks) { } } -TEST_F(ClientServerTest, ClientsTerminateShutdownIfAnyClientGoesAway) { +TEST_F(ClientServerTest, + ClientsTerminateShutdownIfAnyClientGoesAway_WithoutErrorPolling) { int num_nodes = 3; StartService(num_nodes); @@ -387,8 +373,7 @@ TEST_F(ClientServerTest, ClientsTerminateShutdownIfAnyClientGoesAway) { DistributedRuntimeClient::Options client_options; client_options.shutdown_on_destruction = node_id != 0; client_options.poll_for_error_from_service_at_startup = false; - client_options.missed_heartbeat_callback = - [&](absl::Status status, bool coordinator_initiated) {}; + client_options.missed_heartbeat_callback = [&](absl::Status status) {}; auto client = GetClient(node_id, client_options); TF_RETURN_IF_ERROR(client->Connect()); @@ -425,17 +410,14 @@ TEST_F(ClientServerTest, ClientsTerminateShutdownIfAnyClientGoesAway) { } } -TEST_F(ClientServerTest, - ClientsTerminateShutdownIfAnyClientGoesAway_WithErrorPolling) { +TEST_F(ClientServerTest, ClientsTerminateShutdownIfAnyClientGoesAway) { int num_nodes = 3; StartService(num_nodes); auto thread_fn = [&](int node_id) -> absl::Status { DistributedRuntimeClient::Options client_options; client_options.shutdown_on_destruction = node_id != 0; - client_options.missed_heartbeat_callback = - [&](absl::Status status, bool coordinator_initiated) {}; - client_options.poll_for_error_from_service_at_startup = true; + client_options.missed_heartbeat_callback = [&](absl::Status status) {}; auto client = GetClient(node_id, client_options); TF_RETURN_IF_ERROR(client->Connect()); @@ -466,16 +448,14 @@ TEST_F(ClientServerTest, } } -TEST_F(ClientServerTest, ClientsShutdownSuccessfully_WithErrorPolling) { +TEST_F(ClientServerTest, ClientsShutdownSuccessfully) { int num_nodes = 3; StartService(num_nodes); auto thread_fn = [&](int node_id) -> absl::Status { DistributedRuntimeClient::Options client_options; client_options.shutdown_on_destruction = true; - client_options.missed_heartbeat_callback = - [&](absl::Status status, bool coordinator_initiated) {}; - client_options.poll_for_error_from_service_at_startup = true; + client_options.missed_heartbeat_callback = [&](absl::Status status) {}; auto client = GetClient(node_id, client_options); TF_RETURN_IF_ERROR(client->Connect()); @@ -497,8 +477,7 @@ TEST_F(ClientServerTest, ClientsShutdownSuccessfully_WithErrorPolling) { } } -TEST_F(ClientServerTest, - MissedHeartbeatCallbackIsExecutedIfAnyClientGoesAway_WithErrorPolling) { +TEST_F(ClientServerTest, MissedHeartbeatCallbackIsExecutedIfAnyClientGoesAway) { int num_nodes = 3; StartService(num_nodes); @@ -506,11 +485,9 @@ TEST_F(ClientServerTest, DistributedRuntimeClient::Options client_options; client_options.shutdown_on_destruction = (node_id != 0); absl::Notification shutdown; - client_options.missed_heartbeat_callback = [&](absl::Status status, - bool coordinator_initiated) { + client_options.missed_heartbeat_callback = [&](absl::Status status) { shutdown.Notify(); }; - client_options.poll_for_error_from_service_at_startup = true; auto client = GetClient(node_id, client_options); TF_RETURN_IF_ERROR(client->Connect()); @@ -535,7 +512,8 @@ TEST_F(ClientServerTest, } } -TEST_F(ClientServerTest, ClientsReceiveMissedHeartbeatIfAnyClientGoesAway) { +TEST_F(ClientServerTest, + ClientsReceiveMissedHeartbeatIfAnyClientGoesAway_WithoutErrorPolling) { int num_nodes = 3; StartService(num_nodes); @@ -543,10 +521,10 @@ TEST_F(ClientServerTest, ClientsReceiveMissedHeartbeatIfAnyClientGoesAway) { DistributedRuntimeClient::Options client_options; client_options.shutdown_on_destruction = (node_id != 0); absl::Notification shutdown; - client_options.missed_heartbeat_callback = [&](absl::Status status, - bool coordinator_initiated) { + client_options.missed_heartbeat_callback = [&](absl::Status status) { shutdown.Notify(); }; + client_options.poll_for_error_from_service_at_startup = false; auto client = GetClient(node_id, client_options); TF_RETURN_IF_ERROR(client->Connect()); @@ -572,13 +550,13 @@ TEST_F(ClientServerTest, ClientsReceiveMissedHeartbeatIfAnyClientGoesAway) { } TEST_F(ClientServerTest, ClientsTerminateIfServiceGoesAway) { +#if defined(ADDRESS_SANITIZER) + GTEST_SKIP() + << "This test is known to produce memory leaks due to ungraceful " + "termination of the RPC server despite having pending connections."; +#endif int num_nodes = 3; - // We use a socket connection for this test case because the in-process API - // does not react well to the server being told to shutdown while there are - // active clients. - int port = tsl::testing::PickUnusedPortOrDie(); - StartService(num_nodes, - /*service_options=*/{}, absl::StrCat("[::]:", port)); + StartService(num_nodes); absl::Barrier barrier(num_nodes + 1); @@ -587,14 +565,11 @@ TEST_F(ClientServerTest, ClientsTerminateIfServiceGoesAway) { client_options.rpc_timeout = absl::Seconds(1); client_options.shutdown_timeout = absl::Seconds(10); absl::Notification shutdown; - client_options.missed_heartbeat_callback = [&](absl::Status status, - bool coordinator_initiated) { + client_options.missed_heartbeat_callback = [&](absl::Status status) { shutdown.Notify(); }; - std::shared_ptr<::grpc::ChannelCredentials> creds = - ::grpc::InsecureChannelCredentials(); - std::shared_ptr<::grpc::Channel> channel = - ::grpc::CreateChannel(absl::StrCat("dns:///localhost:", port), creds); + auto channel = GetDistributedRuntimeClientChannel( + service_address(), ::grpc::InsecureChannelCredentials()); auto client = GetClient(node_id, client_options, channel); TF_RETURN_IF_ERROR(client->Connect()); @@ -614,7 +589,7 @@ TEST_F(ClientServerTest, ClientsTerminateIfServiceGoesAway) { thread_pool.Schedule([&, i]() { statuses[i] = thread_fn(i); }); } barrier.Block(); - Stop(); + StopService(); } for (int i = 0; i < num_nodes; ++i) { EXPECT_EQ(statuses[i].code(), tsl::error::FAILED_PRECONDITION); @@ -668,10 +643,9 @@ TEST_F(ClientServerTest, ConnectEventuallyTimesOutIfAClientDoesNotShowUp) { client_options.init_timeout = timeout; client_options.rpc_timeout = timeout; // Overwrite the default error callback which invokes LOG(QFATAL). - client_options.missed_heartbeat_callback = - [](absl::Status status, bool coordinator_reported_failure) { - LOG(ERROR) << "Distributed client has missing heartbeats: " << status; - }; + client_options.missed_heartbeat_callback = [](absl::Status status) { + LOG(ERROR) << "Distributed client has missing heartbeats: " << status; + }; auto client = GetClient(node_id, client_options); TF_RETURN_IF_ERROR(client->Connect()); @@ -998,5 +972,19 @@ TEST_F(ClientServerTest, KeyValueDelete_Directory) { EXPECT_THAT(kvs.value(), IsEmpty()); } +TEST_F(ClientServerTest, UseCompression) { + StartService(/*num_nodes=*/1); + + // Sanity check that the client can connect with compression enabled. + auto channel = GetDistributedRuntimeClientChannel( + service_address(), ::grpc::InsecureChannelCredentials(), + /*use_compression=*/true); + auto client = GetClient(/*node_id=*/0, {}, channel); + + TF_ASSERT_OK(client->Connect()); + TF_ASSERT_OK(client->KeyValueSet("foo/bar/1", "1")); + TF_ASSERT_OK(client->Shutdown()); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/pjrt/distributed/distributed.cc b/third_party/xla/xla/pjrt/distributed/distributed.cc index 69f9f2e249b402..cdbca7b64ed7ff 100644 --- a/third_party/xla/xla/pjrt/distributed/distributed.cc +++ b/third_party/xla/xla/pjrt/distributed/distributed.cc @@ -38,10 +38,22 @@ GetDistributedRuntimeService(std::string address, } std::shared_ptr GetDistributedRuntimeClient( - std::string address, const DistributedRuntimeClient::Options& options) { - std::shared_ptr channel = grpc::CreateChannel( - address, tsl::GetClientCredentials(kVerifySecureCredentials)); + std::string address, const DistributedRuntimeClient::Options& options, + bool use_compression) { + auto channel = GetDistributedRuntimeClientChannel( + address, tsl::GetClientCredentials(kVerifySecureCredentials), + use_compression); return GetDistributedRuntimeClient(channel, options); } +std::shared_ptr<::grpc::Channel> GetDistributedRuntimeClientChannel( + std::string address, std::shared_ptr<::grpc::ChannelCredentials> creds, + bool use_compression) { + grpc::ChannelArguments args; + if (use_compression) { + args.SetCompressionAlgorithm(GRPC_COMPRESS_GZIP); + } + return ::grpc::CreateCustomChannel(address, creds, args); +} + } // namespace xla diff --git a/third_party/xla/xla/pjrt/distributed/distributed.h b/third_party/xla/xla/pjrt/distributed/distributed.h index 8145ddaa5c699e..e0e6a319873c90 100644 --- a/third_party/xla/xla/pjrt/distributed/distributed.h +++ b/third_party/xla/xla/pjrt/distributed/distributed.h @@ -40,7 +40,13 @@ GetDistributedRuntimeService(std::string address, // Builds a distributed runtime client, connecting to a service at `address`, // where address is a gRPC-style address such as `dns:///localhost:1234`. std::shared_ptr GetDistributedRuntimeClient( - std::string address, const DistributedRuntimeClient::Options& options); + std::string address, const DistributedRuntimeClient::Options& options, + bool use_compression = false); + +// Builds the gRPC channel used by the runtime client. Exposed for testing. +std::shared_ptr<::grpc::Channel> GetDistributedRuntimeClientChannel( + std::string address, std::shared_ptr<::grpc::ChannelCredentials> creds, + bool use_compression = false); } // namespace xla diff --git a/third_party/xla/xla/pjrt/distributed/service.cc b/third_party/xla/xla/pjrt/distributed/service.cc index 6a8a77a5fca534..51729532b63709 100644 --- a/third_party/xla/xla/pjrt/distributed/service.cc +++ b/third_party/xla/xla/pjrt/distributed/service.cc @@ -24,10 +24,10 @@ limitations under the License. #include "xla/tsl/distributed_runtime/coordination/coordination_service.h" #include "xla/tsl/distributed_runtime/rpc/async_service_interface.h" #include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" #include "xla/util.h" #include "tsl/platform/env.h" #include "tsl/platform/threadpool.h" -#include "tsl/protobuf/coordination_config.pb.h" namespace { diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index 61a0cdf608e88b..5b7a2967deeb09 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -24,13 +24,19 @@ cc_library( "//xla/client:client_library", "//xla/client:local_client", "//xla/service:platform_util", - "//xla/stream_executor", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/integrations:device_mem_allocator", + "//xla/tsl/framework:allocator", "//xla/tsl/framework:bfc_allocator", "//xla/tsl/framework:device_id_impl", "//xla/tsl/util:env_var", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", ], ) @@ -52,7 +58,7 @@ cc_library( "//xla:xla_proto_cc", "//xla/client:client_library", "//xla/client:local_client", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", "//xla/pjrt:compile_options_proto_cc", "//xla/pjrt:event_pool", "//xla/pjrt:host_memory_spaces", @@ -80,11 +86,12 @@ cc_library( "//xla/service:shaped_buffer", "//xla/service:transfer_manager", "//xla/service/gpu:gpu_executable_run_options", - "//xla/stream_executor", "//xla/stream_executor:device_description", "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:platform", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/integrations:device_mem_allocator", "//xla/stream_executor/integrations:tf_allocator_adapter", "//xla/tsl/framework:allocator", @@ -152,22 +159,25 @@ xla_cc_test( "//xla:shape_util", "//xla:status_macros", "//xla:test", + "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/client:xla_computation", "//xla/ffi", "//xla/ffi:ffi_api", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/parser:hlo_parser", "//xla/pjrt:host_memory_spaces", "//xla/pjrt:mlir_to_hlo", "//xla/pjrt:pjrt_client", + "//xla/pjrt:pjrt_compiler", "//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_future", "//xla/pjrt:pjrt_stream_executor_client", "//xla/pjrt/distributed:in_memory_key_value_store", "//xla/service:gpu_plugin", - "//xla/service:hlo_parser", "//xla/service:platform_util", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:stream", "//xla/tests:literal_test_util", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log:check", @@ -177,8 +187,11 @@ xla_cc_test( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", @@ -249,7 +262,7 @@ cc_library( ":se_gpu_pjrt_client", "//xla:status_macros", "//xla/client:local_client", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", "//xla/pjrt:mlir_to_hlo", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_compiler", @@ -265,7 +278,7 @@ cc_library( "//xla/service:local_service", "//xla/service:local_service_utils", "//xla/service/gpu:executable_proto_cc", - "//xla/stream_executor/platform", + "//xla/stream_executor/platform:initialize", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:casts", @@ -311,11 +324,11 @@ xla_test( ":se_gpu_pjrt_compiler", "//xla:test", "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/parser:hlo_parser", "//xla/mlir_hlo", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_compiler", - "//xla/service:hlo_parser", - "//xla/stream_executor/cuda:cublas_plugin", "//xla/tests:literal_test_util", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", @@ -340,12 +353,13 @@ xla_test( "//xla:literal", "//xla:literal_util", "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/parser:hlo_parser", "//xla/mlir_hlo", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_compiler", "//xla/pjrt:pjrt_executable", "//xla/service:compiler", - "//xla/service:hlo_parser", "//xla/tests:literal_test_util", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/pjrt/gpu/gpu_helpers.cc b/third_party/xla/xla/pjrt/gpu/gpu_helpers.cc index 7501f99192a2b9..c89f333d209e28 100644 --- a/third_party/xla/xla/pjrt/gpu/gpu_helpers.cc +++ b/third_party/xla/xla/pjrt/gpu/gpu_helpers.cc @@ -15,22 +15,37 @@ limitations under the License. #include "xla/pjrt/gpu/gpu_helpers.h" +#include +#include #include #include #include #include #include +#include #include #include +#include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/types/span.h" #include "xla/client/client_library.h" +#include "xla/client/local_client.h" #include "xla/service/platform_util.h" #include "xla/stream_executor/integrations/device_host_allocator.h" #include "xla/stream_executor/integrations/device_mem_allocator.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/framework/allocator.h" +#include "xla/tsl/framework/bfc_allocator.h" #include "xla/tsl/framework/device_id.h" #include "xla/tsl/util/env_var.h" #include "xla/util.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -215,4 +230,25 @@ std::unique_ptr GetGpuHostAllocator( /*name=*/"xla_gpu_host_bfc", opts); } +int TopologySizes::GetDeviceCount() { + return num_slices * num_hosts_per_slice * num_devices_per_host; +} + +// static +absl::StatusOr TopologySizes::FromString( + std::string_view topology_string) { + TopologySizes sizes; + std::vector topology_components = + absl::StrSplit(topology_string, 'x'); + if (topology_components.size() != 3 || + !absl::SimpleAtoi(topology_components[0], &sizes.num_slices) || + !absl::SimpleAtoi(topology_components[1], &sizes.num_hosts_per_slice) || + !absl::SimpleAtoi(topology_components[2], &sizes.num_devices_per_host)) { + return absl::InternalError( + "topology must be of shape " + "\"xx\""); + } + return sizes; +} + } // namespace xla diff --git a/third_party/xla/xla/pjrt/gpu/gpu_helpers.h b/third_party/xla/xla/pjrt/gpu/gpu_helpers.h index c64f2decbe4aaf..a13e611688c643 100644 --- a/third_party/xla/xla/pjrt/gpu/gpu_helpers.h +++ b/third_party/xla/xla/pjrt/gpu/gpu_helpers.h @@ -89,6 +89,21 @@ absl::StatusOr> CreateCollectiveBFCAllocator( se::StreamExecutor* executor, double memory_fraction, size_t collective_memory_size); +// Represents topology of devices. +struct TopologySizes { + int num_slices = 0; + int num_hosts_per_slice = 0; + int num_devices_per_host = 0; + + // Returns number of devices in the topology. + int GetDeviceCount(); + // Parses the topology description of the form + // " x x " + // and returns the parsed components on success. + static absl::StatusOr FromString( + std::string_view topology_string); +}; + } // namespace xla #endif // XLA_PJRT_GPU_GPU_HELPERS_H_ diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index adb63a3132027b..cd59543000bbfc 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -46,7 +46,7 @@ limitations under the License. #include "absl/time/time.h" #include "absl/types/span.h" #include "xla/client/local_client.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/literal.h" @@ -123,10 +123,10 @@ class AsyncHostToDeviceTransferManager public: static absl::StatusOr> Create(absl::Span shape_specs, - std::optional> device_layouts, + std::optional>> device_layouts, PjRtStreamExecutorDevice* device, PjRtStreamExecutorClient* client, PjRtMemorySpace* memory_space) { - if (device_layouts != std::nullopt && + if (device_layouts.has_value() && device_layouts->size() != shape_specs.size()) { return InvalidArgument( "Number of layouts %d does not match the number of shapes %d", @@ -153,14 +153,14 @@ class AsyncHostToDeviceTransferManager std::make_shared(client->thread_pool())); Shape& device_shape = device_shapes.emplace_back( ShapeUtil::MakeShape(shape_spec.element_type, shape_spec.dims)); - if (device_layouts == std::nullopt) { + if (device_layouts.has_value() && (*device_layouts)[i].has_value()) { + *device_shape.mutable_layout() = *(*device_layouts)[i]; + } else { TF_ASSIGN_OR_RETURN(device_shape, client->client() ->backend() .transfer_manager() ->ChooseCompactLayoutForShape(device_shape)); - } else { - *device_shape.mutable_layout() = (*device_layouts)[i]; } LocalDeviceState* local_device = device->local_device_state(); se::Stream* h2d_stream = local_device->host_to_device_stream(); @@ -509,6 +509,7 @@ StreamExecutorGpuClient::StreamExecutorGpuClient( tsl::Fingerprint64(platform_name), platform_name, std::move(gpu_topology))), kv_store_(std::move(kv_store)) { + const int basePinnedId = device_count(); for (auto* device : addressable_devices()) { // Use the device id to construct a globally unique memory space id. We do // not promise that memory space ids and device ids are the same. @@ -518,8 +519,8 @@ StreamExecutorGpuClient::StreamExecutorGpuClient( tensorflow::down_cast(device)->AttachMemorySpace( memory_space.get()); owned_memory_spaces_.push_back(std::move(memory_space)); - const size_t basePinnedId = devices.size(); - auto pinned = std::make_unique(basePinnedId, device); + auto pinned = + std::make_unique(basePinnedId + id, device); tensorflow::down_cast(device)->AttachMemorySpace( pinned.get()); owned_memory_spaces_.push_back(std::move(pinned)); @@ -554,7 +555,7 @@ absl::string_view StreamExecutorGpuClient::platform_version() const { absl::StatusOr> StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice( absl::Span shape_specs, - std::optional> device_layouts, + std::optional>> device_layouts, PjRtDevice* device) { auto* stream_executor_device = tensorflow::down_cast(device); @@ -580,7 +581,7 @@ StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice( absl::StatusOr> StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice( absl::Span shape_specs, - std::optional> device_layouts, + std::optional>> device_layouts, PjRtMemorySpace* memory_space) { CHECK_EQ(memory_space->devices().size(), 1); PjRtDevice* device = memory_space->devices()[0]; @@ -1038,6 +1039,7 @@ absl::StatusOr BuildDistributedDevices( int node_id, int num_nodes, gpu::GpuExecutableRunOptions* gpu_executable_run_options, std::shared_ptr kv_store, bool enable_mock_nccl, + std::optional mock_gpu_topology, absl::Duration get_local_topology_timeout, absl::Duration get_global_topology_timeout) { std::vector> devices; @@ -1069,12 +1071,37 @@ absl::StatusOr BuildDistributedDevices( GlobalTopologyProto global_topology; if (enable_mock_nccl) { + TopologySizes sizes; + if (mock_gpu_topology.has_value()) { + TF_ASSIGN_OR_RETURN(sizes, TopologySizes::FromString(*mock_gpu_topology)); + } else { + // If there is no topology spec, we assume that each node is a slice, + // there is one process (host) on each slice and each host + // has all the local devices. + sizes.num_slices = num_nodes; + sizes.num_hosts_per_slice = 1; + sizes.num_devices_per_host = local_topology.devices().size(); + } + + if (sizes.num_devices_per_host != local_topology.devices().size()) { + return absl::InternalError( + "The number of devices per host in 'mock_gpu_topology' " + "must be the same as the number of devices in the local topology"); + } + + if (sizes.num_slices * sizes.num_hosts_per_slice != num_nodes) { + return absl::InternalError( + "The number of hosts in 'mock_gpu_topology' " + "must be the same as 'num_nodes'"); + } + std::vector local_topologies(num_nodes, local_topology); - for (int i = 0; i < num_nodes; ++i) { - local_topologies[i].set_node_id(i); - // Set a distinct boot_id for each local topology to change slice_index - // for each node. - local_topologies[i].set_boot_id(absl::StrCat(i)); + for (int i = 0; i < sizes.num_slices; ++i) { + for (int j = 0; j < sizes.num_hosts_per_slice; j++) { + int node_id = i * sizes.num_hosts_per_slice + j; + local_topologies[node_id].set_node_id(node_id); + local_topologies[node_id].set_boot_id(absl::StrCat(i)); + } } global_topology = BuildGlobalTopology(absl::MakeSpan(local_topologies), /*assign_global_device_ids=*/true); @@ -1254,10 +1281,10 @@ absl::StatusOr> GetStreamExecutorGpuClient( TF_RET_CHECK(options.num_nodes == 1 || kv_store != nullptr); TF_ASSIGN_OR_RETURN( DeviceTopologyPair device_topology_pair, - BuildDistributedDevices(pjrt_platform_name, - std::move(local_device_states), options.node_id, - options.num_nodes, gpu_run_options.get(), - kv_store, options.enable_mock_nccl)); + BuildDistributedDevices( + pjrt_platform_name, std::move(local_device_states), options.node_id, + options.num_nodes, gpu_run_options.get(), kv_store, + options.enable_mock_nccl, options.mock_gpu_topology)); auto gpu_topology = std::shared_ptr( GpuTopology::FromProto(device_topology_pair.second)); diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h index 2f65f45462b2e1..ccc6e5dde17027 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h @@ -33,7 +33,7 @@ limitations under the License. #include "absl/time/time.h" #include "absl/types/span.h" #include "xla/client/local_client.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/layout.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/gpu/gpu_helpers.h" @@ -201,6 +201,11 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient { std::shared_ptr kv_store, std::shared_ptr gpu_topology); + std::optional> key_value_store() + const override { + return kv_store_; + } + absl::StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const override; @@ -208,7 +213,7 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient { absl::StatusOr> CreateBuffersForAsyncHostToDevice( absl::Span shape_specs, - std::optional> device_layouts, + std::optional>> device_layouts, PjRtDevice* device) override; absl::StatusOr> @@ -218,7 +223,7 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient { absl::StatusOr> CreateBuffersForAsyncHostToDevice( absl::Span shape_specs, - std::optional> device_layouts, + std::optional>> device_layouts, PjRtMemorySpace* memory_space) override; absl::StatusOr> @@ -273,6 +278,7 @@ absl::StatusOr BuildDistributedDevices( int node_id, int num_nodes, gpu::GpuExecutableRunOptions* gpu_executable_run_options, std::shared_ptr kv_store, bool enable_mock_nccl, + std::optional mock_gpu_topology = std::nullopt, absl::Duration get_local_topology_timeout = absl::Minutes(2), absl::Duration get_global_topology_timeout = absl::Minutes(5)); @@ -293,6 +299,8 @@ struct GpuClientOptions { std::shared_ptr kv_store = nullptr; bool enable_mock_nccl = false; + + std::optional mock_gpu_topology; }; absl::StatusOr> GetStreamExecutorGpuClient( diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index b306194ed4eefd..f65c32bee9f7aa 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -36,9 +36,11 @@ limitations under the License. #include "absl/time/clock.h" #include "absl/time/time.h" #include "absl/types/span.h" -#include "xla/client/xla_computation.h" +#include "mlir/IR/MLIRContext.h" #include "xla/ffi/ffi.h" #include "xla/ffi/ffi_api.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/layout.h" #include "xla/literal.h" #include "xla/literal_util.h" @@ -47,10 +49,10 @@ limitations under the License. #include "xla/pjrt/host_memory_spaces.h" #include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" #include "xla/pjrt/pjrt_stream_executor_client.h" -#include "xla/service/hlo_parser.h" #include "xla/service/platform_util.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -60,10 +62,13 @@ limitations under the License. #include "xla/test.h" #include "xla/tests/literal_test_util.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/casts.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" +#include "tsl/platform/protobuf.h" #include "tsl/platform/status.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" @@ -156,6 +161,22 @@ TEST(StreamExecutorGpuClientTest, MemorySpace) { } } +TEST(StreamExecutorGpuClientTest, MemorySpacesUniqueIds) { + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); + ASSERT_GE(client->devices().size(), 1); + + absl::flat_hash_map memories; + for (auto* device : client->devices()) { + for (auto* memory_space : device->memory_spaces()) { + std::string debug_string(memory_space->DebugString()); + auto [it, inserted] = memories.insert({memory_space->id(), debug_string}); + EXPECT_TRUE(inserted) << "Duplicate ids for memory spaces '" << it->second + << "' and '" << debug_string << "'"; + } + } +} + TEST(StreamExecutorGpuClientTest, PropagateError) { TF_ASSERT_OK_AND_ASSIGN(auto client, GetStreamExecutorGpuClient(GpuClientOptions())); @@ -191,6 +212,51 @@ ENTRY %Add.6 (a.1: f32[], b.2: f32[]) -> (f32[], f32[]) { EXPECT_EQ(result[0][0]->GetReadyFuture().Await(), input_error); } +// TODO(b/372735047): Fix and reenable. +TEST(StreamExecutorGpuClientTest, DISABLED_DonateWithControlDependency) { + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); + auto shape = xla::ShapeUtil::MakeScalarShape(xla::F32); + absl::Status input_error = absl::InvalidArgumentError("input error"); + TF_ASSERT_OK_AND_ASSIGN( + auto buffer, + client->CreateErrorBuffer( + input_error, shape, + *client->addressable_devices()[0]->default_memory_space())); + + static constexpr char const* kAddProgram = + R"( +HloModule Add.6, entry_computation_layout={(f32[], f32[])->(f32[], f32[])} + +ENTRY %Add.6 (a.1: f32[], b.2: f32[]) -> (f32[], f32[]) { + %a.1 = f32[] parameter(0) + %b.2 = f32[] parameter(1) + %add.3 = f32[] add(f32[] %a.1, f32[] %b.2) + %add.4 = f32[] add(f32[] %add.3, f32[] %add.3) + ROOT %tuple.5 = (f32[], f32[]) tuple(f32[] %add.3, f32[] %add.4) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto executable, + CompileExecutable(kAddProgram, *client)); + + TF_ASSERT_OK_AND_ASSIGN( + auto result, + executable->Execute({{buffer.get(), buffer.get()}}, /*options=*/{})); + + ASSERT_EQ(result.size(), 1); + ASSERT_EQ(result[0].size(), 1); + + TF_ASSERT_OK_AND_ASSIGN( + auto another_buffer, + client->CreateErrorBuffer( + input_error, shape, + *client->addressable_devices()[0]->default_memory_space())); + TF_ASSERT_OK_AND_ASSIGN(another_buffer, + another_buffer->DonateWithControlDependency( + result[0][0]->GetReadyFuture())); + EXPECT_EQ(another_buffer->GetReadyFuture().Await(), input_error); +} + TEST(StreamExecutorGpuClientTest, SendRecvChunked) { TF_ASSERT_OK_AND_ASSIGN(auto client, GetStreamExecutorGpuClient(GpuClientOptions())); @@ -422,12 +488,12 @@ TEST(StreamExecutorGpuClientTest, ToLiteralAsyncWithNonCompactLayout) { spec.element_type = src_literal.shape().element_type(); spec.dims = DimensionVector(src_literal.shape().dimensions().begin(), src_literal.shape().dimensions().end()); + std::vector> device_layouts = { + std::make_optional(transposed_shape.layout())}; TF_ASSERT_OK_AND_ASSIGN( auto transfer_manager, client->CreateBuffersForAsyncHostToDevice( - {spec}, - std::make_optional>( - {transposed_shape.layout()}), + {spec}, device_layouts, client->addressable_devices()[0]->memory_spaces()[0])); auto buffer = transfer_manager->RetrieveBuffer(0); @@ -674,6 +740,52 @@ TEST(StreamExecutorGpuClientTest, FromHostAsyncPinnedHostChunked) { EXPECT_THAT(lit->data(), ElementsAreArray(data)); } +TEST(StreamExecutorGpuClientTest, DeleteBufferThenFulfillBufferNoDeadLock) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + GetStreamExecutorGpuClient(GpuClientOptions())); + ASSERT_THAT(client->addressable_devices(), SizeIs(Gt(0))); + TF_ASSERT_OK_AND_ASSIGN( + PjRtMemorySpace * memspace, + client->addressable_devices()[0]->memory_space_by_kind( + PinnedHostMemorySpace::kKind)); + std::vector data{1, 3, 5, 7, 11, 13, 17, 19}; + Shape shape = ShapeUtil::MakeShape(F32, {static_cast(data.size())}); + std::vector> + txms; + for (int i = 0; i < 10000; ++i) { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr txm, + client->CreateBuffersForAsyncHostToDevice({shape}, memspace)); + std::unique_ptr buf = txm->RetrieveBuffer(0); + ASSERT_THAT(buf->GetReadyFuture().IsReady(), Eq(false)); + txms.push_back(std::move(txm)); + // Delete the buffer + } + + // At this point, we have 10000 buffers pending deallocation. + + absl::string_view raw_view(reinterpret_cast(data.data()), + data.size() * sizeof(data[0])); + for (auto& txm : txms) { + int offset = 0; + while (true) { + int end = offset + 3; // unaligned chunk size + if (end > raw_view.size()) { + end = raw_view.size(); + } + int sz = end - offset; + bool reaches_end = end == raw_view.size(); + TF_ASSERT_OK(txm->TransferRawDataToSubBuffer( + /*buffer_index=*/0, raw_view.data() + offset, offset, sz, reaches_end, + /*on_done=*/[]() {})); + if (reaches_end) { + break; + } + offset = end; + } + } +} + TEST(StreamExecutorGpuClientTest, CopyRawToHostFullBuffer) { TF_ASSERT_OK_AND_ASSIGN(auto client, GetStreamExecutorGpuClient(GpuClientOptions())); @@ -946,6 +1058,94 @@ TEST(StreamExecutorGpuClientTest, MockNcclClientTest) { } } +TEST(StreamExecutorGpuClientTest, MockNcclClientWithGpuTopologyTest) { + GpuClientOptions options; + options.enable_mock_nccl = true; + options.num_nodes = 8; + options.mock_gpu_topology = "2x4x2"; + TF_ASSERT_OK_AND_ASSIGN(auto client, GetStreamExecutorGpuClient(options)); + + auto devices_per_host = client->addressable_device_count(); + EXPECT_EQ(devices_per_host, 2) << "This test requires 2 local GPUs."; + + TF_ASSERT_OK_AND_ASSIGN(const xla::PjRtTopologyDescription* topology, + client->GetTopologyDescription()); + const StreamExecutorGpuTopologyDescription& gpu_topology = + tensorflow::down_cast( + *topology); + + EXPECT_EQ(gpu_topology.gpu_topology().num_slices(), 2); + EXPECT_EQ(gpu_topology.gpu_topology().num_hosts_per_slice(), 4); + EXPECT_EQ(gpu_topology.gpu_topology().num_devices_per_host(), 2); +} + +constexpr char kMlirDistributedSum[] = R"( +module @jit_f attributes {mhlo.num_partitions = 8 : i32, + mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<8xi32> { + mhlo.layout_mode = "default", + mhlo.sharding = "{devices=[8]0,1,2,3,4,5,6,7}"}) -> (tensor { + jax.result_info = "", + mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<0> : tensor + %0 = stablehlo.reduce(%arg0 init: %c) + applies stablehlo.add across dimensions = [0] + : (tensor<8xi32>, tensor) -> tensor + return %0 : tensor + } +})"; + +TEST(StreamExecutorGpuClientTest, MockNcclClientWithGpuTopologyExecuteTest) { + GpuClientOptions client_options; + client_options.enable_mock_nccl = true; + client_options.num_nodes = 4; + client_options.mock_gpu_topology = "2x2x2"; + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(client_options)); + + auto devices_per_host = client->addressable_device_count(); + EXPECT_EQ(devices_per_host, 2) << "This test requires 2 local GPUs."; + + mlir::MLIRContext context; + TF_ASSERT_OK_AND_ASSIGN( + auto module, xla::ParseMlirModuleString(kMlirDistributedSum, context)); + + xla::CompileOptions options; + options.executable_build_options.set_num_partitions(8) + .set_use_spmd_partitioning(true) + .set_allow_spmd_sharding_propagation_to_output({true}); + TF_ASSERT_OK_AND_ASSIGN(auto executable, client->Compile(*module, options)); + + Shape shape = ShapeUtil::MakeShapeWithDenseLayout(S32, {1}, {0}); + std::vector> inputs; + std::vector> input_ptrs; + for (int i = 0; i < devices_per_host; i++) { + auto device = client->addressable_devices()[i]; + std::vector data{i}; + TF_ASSERT_OK_AND_ASSIGN( + auto input, + client->BufferFromHostBuffer( + data.data(), shape.element_type(), shape.dimensions(), + /*byte_strides=*/std::nullopt, + PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, + /*on_done_with_host_buffer=*/nullptr, device)); + input_ptrs.push_back({input.get()}); + inputs.push_back(std::move(input)); + } + + // Test that running the program does not crash/hang. + TF_ASSERT_OK( + executable->Execute(absl::MakeSpan(input_ptrs), ExecuteOptions())); +} + +TEST(StreamExecutorGpuClientTest, MockNcclClientWithGpuTopologyMismatchTest) { + GpuClientOptions options; + options.enable_mock_nccl = true; + options.num_nodes = 16; + options.mock_gpu_topology = "2x4"; + EXPECT_FALSE(GetStreamExecutorGpuClient(options).ok()); +} + TEST(StreamExecutorGpuClientTest, BufferFromHostBufferPinnedMemory) { TF_ASSERT_OK_AND_ASSIGN(auto client, GetStreamExecutorGpuClient(GpuClientOptions())); @@ -1543,5 +1743,15 @@ TEST(StreamExecutorGpuClientTest, nullptr); } +TEST(StreamExecutorGpuClientTest, GetDefaultLayout) { + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); + auto shape = ShapeUtil::MakeShape(S4, {2, 2}); + TF_ASSERT_OK_AND_ASSIGN( + auto layout, + client->GetDefaultLayout(shape.element_type(), shape.dimensions())); + EXPECT_EQ(layout.element_size_in_bits(), 4); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc index ea9541ce8a03b1..7b80ec95663ece 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc @@ -19,7 +19,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc index 84570559295040..b4522692d42d68 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_aot_test.cc @@ -29,7 +29,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Parser/Parser.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" @@ -44,7 +44,7 @@ limitations under the License. #elif TENSORFLOW_USE_ROCM #include "xla/service/gpu/amdgpu_compiler.h" #endif -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/tests/literal_test_util.h" #include "tsl/platform/casts.h" #include "tsl/platform/protobuf.h" diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc index 1b630342bf1336..de668949e039d3 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler_test.cc @@ -22,13 +22,13 @@ limitations under the License. #include "absl/status/status.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Parser/Parser.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/pjrt/gpu/gpu_topology.h" #include "xla/pjrt/gpu/se_gpu_pjrt_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" -#include "xla/service/hlo_parser.h" #include "xla/test.h" #include "xla/tests/literal_test_util.h" #include "tsl/platform/status_matchers.h" diff --git a/third_party/xla/xla/pjrt/layout_mode.cc b/third_party/xla/xla/pjrt/layout_mode.cc index 84758c6ef7e293..08877724052ac6 100644 --- a/third_party/xla/xla/pjrt/layout_mode.cc +++ b/third_party/xla/xla/pjrt/layout_mode.cc @@ -22,8 +22,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/layout.h" -#include "xla/service/hlo_parser.h" namespace xla { diff --git a/third_party/xla/xla/pjrt/local_device_state.cc b/third_party/xla/xla/pjrt/local_device_state.cc index 62eba2b6238098..51b4257bfff965 100644 --- a/third_party/xla/xla/pjrt/local_device_state.cc +++ b/third_party/xla/xla/pjrt/local_device_state.cc @@ -27,11 +27,13 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "xla/stream_executor/stream.h" +#include "xla/tsl/protobuf/error_codes.pb.h" +#include "xla/tsl/util/env_var.h" #include "xla/util.h" #include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" -#include "tsl/protobuf/error_codes.pb.h" namespace xla { @@ -51,6 +53,20 @@ LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor, prng_seed_generator_(prng_seed_device_()), prng_seed_distribution_(std::numeric_limits::min(), std::numeric_limits::max()) { + // Setting XLA_PJRT_GPU_ALLOW_DELETE_BEFORE_FULFILL to false will: + // 1. disallow the host to schedule `create buffer -> use -> delete -> + // fulfill`, which is a use case unit tested in + // StreamExecutorGpuClientTest.DeleteBufferThenFulfillBufferNoDeadLock. + // 2. potentially reduce spikes in HBM usage because the host will wait for + // buffer fulfillment to be scheduled before destructing it. + absl::Status status = + tsl::ReadBoolFromEnvVar("XLA_PJRT_GPU_ALLOW_DELETE_BEFORE_FULFILL", true, + &allow_delete_before_fulfill_); + if (!status.ok()) { + LOG(ERROR) << "Failed to read XLA_PJRT_GPU_ALLOW_DELETE_BEFORE_FULFILL: " + << status; + } + local_hardware_id_ = executor_->device_ordinal(); local_device_id_ = device_ordinal != -1 ? device_ordinal : executor_->device_ordinal(); @@ -69,7 +85,7 @@ LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor, stream = executor->CreateStream().value(); } if (stream) { - stream->set_name(name); + stream->SetName(name); } return stream; }; @@ -103,6 +119,8 @@ LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor, std::make_unique(tsl::Env::Default(), "py_xla_execute"); callback_thread_ = std::make_unique(tsl::Env::Default(), "py_xla_callback"); + cleanup_thread_ = + std::make_unique(tsl::Env::Default(), "py_xla_cleanup"); } LocalDeviceState::~LocalDeviceState() { @@ -154,7 +172,8 @@ absl::Status LocalDeviceState::ThenExecuteCallback( auto callback_stream = callback_stream_map_->find(stream); if (callback_stream == callback_stream_map_->end()) { TF_ASSIGN_OR_RETURN(auto new_stream, executor_->CreateStream()); - new_stream->set_name(absl::StrFormat("Callback for %s", stream->name())); + new_stream->SetName( + absl::StrFormat("Callback for %s", stream->GetName())); callback_stream = callback_stream_map_->insert({stream, std::move(new_stream)}).first; } @@ -243,7 +262,7 @@ std::unique_ptr LocalDeviceState::BorrowStreamFromPool() { // The stream pool is empty, create a new stream. auto stream = compute_stream_->parent()->CreateStream().value(); - stream->set_name("Pool stream"); + stream->SetName("Pool stream"); return stream; } diff --git a/third_party/xla/xla/pjrt/local_device_state.h b/third_party/xla/xla/pjrt/local_device_state.h index 1ce1f1ea7d5401..a7ed7addd84499 100644 --- a/third_party/xla/xla/pjrt/local_device_state.h +++ b/third_party/xla/xla/pjrt/local_device_state.h @@ -170,6 +170,8 @@ class LocalDeviceState { WorkerThread* execute_thread() const { return execute_thread_.get(); } + WorkerThread* cleanup_thread() const { return cleanup_thread_.get(); } + // Enqueues a host callback on 'stream'. `stream` may, but need not, wait for // `callback` to complete. It is safe to call runtime methods from the // callback. @@ -199,6 +201,12 @@ class LocalDeviceState { // Returns a fresh, PRNG-generated random seed for an XLA computation. int GetNewPrngSeed(); + // Whether to allow deleting a buffer before the operation fulfilling the + // buffer is scheduled by the host. + bool allow_delete_before_fulfill() const { + return allow_delete_before_fulfill_; + } + private: absl::Status SynchronizeAllActivity(); @@ -255,6 +263,12 @@ class LocalDeviceState { // semaphore during calls to Execute but release it from a callback and if // they are the same thread we might deadlock. std::unique_ptr callback_thread_; + + // One thread dedicated to cleaning up buffers. Scheduled work on this thread + // may wait for other threads to schedule writes to buffers. + std::unique_ptr cleanup_thread_; + + bool allow_delete_before_fulfill_ = true; }; } // namespace xla diff --git a/third_party/xla/xla/pjrt/mlir_to_hlo.cc b/third_party/xla/xla/pjrt/mlir_to_hlo.cc index 90e904a533c98c..830e10f4502093 100644 --- a/third_party/xla/xla/pjrt/mlir_to_hlo.cc +++ b/third_party/xla/xla/pjrt/mlir_to_hlo.cc @@ -42,6 +42,8 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" #include "mlir/IR/Visitors.h" @@ -50,7 +52,9 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/Passes.h" +#include "shardy/dialect/sdy/ir/dialect.h" #include "shardy/dialect/sdy/ir/register.h" +#include "stablehlo/api/PortableApi.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/Register.h" #include "stablehlo/dialect/Serialization.h" @@ -58,12 +62,13 @@ limitations under the License. #include "stablehlo/dialect/Version.h" #include "stablehlo/transforms/Passes.h" #include "xla/debug_options_flags.h" +#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/mlir/utils/error_util.h" #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/spmd/shardy/constants.h" +#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h" #include "xla/service/spmd/shardy/utils.h" -#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/util.h" #include "tsl/platform/statusor.h" @@ -136,27 +141,18 @@ absl::StatusOr> ParseMlirModuleString( mlir::OwningOpRef module = mlir::parseSourceString( llvm::StringRef(mlir_module_str.data(), mlir_module_str.size()), - // IR may be invalid because some fields may be using DenseElements - // instead of DenseArray. We rectify that below and verify after. - mlir::ParserConfig{&context, /*verifyAfterParse=*/false}); + mlir::ParserConfig{&context}); if (!module) { + mlir::emitError(mlir::UnknownLoc::get(&context)) + << "Failed to parse using StableHLO v" + << mlir::vhlo::Version::getCurrentVersion() << ", " + << "this could indicate forward incompatibility, >12w old " + "unsupported plugin, or a portable artifact that needs to be " + "further downgraded."; return diagnostic_handler.ConsumeStatus(); } - // In - // https://github.com/google/jax/commit/184e3a88004680dbf34328b05c5fc0d869cc4a93, - // fields on some ops were changed to use Dense{Bool,I64}ArrayAttr instead of - // I64DenseElementsAttr (DenseIntElementsAttr). Some clients still expect - // dense elements, not dense arrays, so when serializing we always convert the - // arrays to elements. The elements need to be converted back to arrays when - // deserializing. - // TODO: b/320507168 - Remove the conversion code, and verifyAfterParse. TF_RETURN_IF_ERROR(UpgradeVersionedStablehlo(*module)); - if (failed(module->verifyInvariants())) { - VLOG(1) << "MLIR verification failed."; - module->dump(); - return diagnostic_handler.ConsumeStatus(); - } return std::move(module); } @@ -170,6 +166,21 @@ absl::Status ParseMlirModuleStringAndConvertToXlaComputation( return_tuple, /*use_shardy=*/false); } +absl::Status ExportShardyForHloRoundTrip(mlir::ModuleOp module) { + mlir::MLIRContext* context = module.getContext(); + mlir::PassManager pm(context); + xla::sdy::addSdyRoundTripExportPipeline(pm); + mlir::BaseScopedDiagnosticHandler diagnostic_handler(context); + if (!mlir::succeeded(pm.run(module))) { + const absl::Status status = diagnostic_handler.ConsumeStatus(); + return absl::InvalidArgumentError( + absl::StrCat("Shardy export failed;\n\nDetailed " + "error from MLIR: ", + status.message())); + } + return absl::OkStatus(); +} + absl::StatusOr SerializeUsingNativeBytecode( mlir::ModuleOp module) { std::string bytecode; @@ -194,20 +205,32 @@ absl::StatusOr SerializeUsingNativeBytecode( } absl::StatusOr SerializeUsingVersionedStablehlo( - mlir::ModuleOp mlir_module, absl::string_view target, bool inplace) { + mlir::ModuleOp mlir_module, absl::string_view requested_target, + bool inplace) { mlir::MLIRContext* context = mlir_module->getContext(); mlir::BaseScopedDiagnosticHandler diagnostic_handler(context); + // Usually the plugin is older than the framework, but occasionally a plugin's + // nightly build will use the latest public release of a framework. Serialize + // using the framework's version in these cases. + auto target = mlir::stablehlo::getSmallerVersion( + requested_target, mlir::stablehlo::getCurrentVersion()); + if (mlir::failed(target)) { + return absl::InvalidArgumentError( + "Invalid StableHLO target version requested."); + } + // Legalize CHLO -> [StableHLO+Shape] -> StableHLO // Preserve higher-level ops with XLA support. To be replaced by composites. mlir::PassManager pm(context); + xla::sdy::addSdyRoundTripExportPipeline(pm); pm.addNestedPass( mlir::mhlo::createChloLegalizeToHighLevelMhloPass()); pm.addNestedPass( mlir::stablehlo::createChloLegalizeToStablehloPass()); pm.addNestedPass( - mlir::stablehlo::createStablehloCreateCompatibilityExpanderPass( - {std::string(target)})); + mlir::stablehlo::createStablehloCompatibilityExpanderPass( + {target.value()})); pm.addNestedPass( mlir::stablehlo::createChloLegalizeToStablehloPass()); pm.addNestedPass( @@ -232,8 +255,8 @@ absl::StatusOr SerializeUsingVersionedStablehlo( // Serialize portable artifact std::string buffer; llvm::raw_string_ostream os(buffer); - if (failed(mlir::stablehlo::serializePortableArtifact(mlir_module, target, - os))) { + if (mlir::failed(mlir::stablehlo::serializePortableArtifact( + mlir_module, target.value(), os))) { const absl::Status status = diagnostic_handler.ConsumeStatus(); return absl::InvalidArgumentError(absl::StrCat( "Failed to serialize StableHLO;\n\nDetailed error from MLIR: ", @@ -251,7 +274,14 @@ absl::Status UpgradeVersionedStablehlo(mlir::ModuleOp mlir_module) { return absl::OkStatus(); } -std::string GetDefaultStablehloVersion() { +std::string GetDefaultStablehloVersion(std::optional plugin_version) { + // TODO: (b/370803410) Use WEEK_12 in PJRT, some plugins were not up to date, + // so temporarily using 1.0.0 to allow them time for a new release. + // PJRT v54 released Jun 10, so most plugins should use WEEK_12 by default. + if (plugin_version.has_value() && plugin_version.value() < 54) { + return "0.19.0"; + } + // This version must be >=12w old. return mlir::vhlo::Version::fromCompatibilityRequirement( mlir::vhlo::Version::CompatibilityRequirement::WEEK_12) @@ -259,22 +289,22 @@ std::string GetDefaultStablehloVersion() { } absl::StatusOr Serialize(mlir::ModuleOp module, - std::optional /*plugin_version*/, absl::string_view target, bool inplace) { // Current PJRT users expect 12 weeks forward compat, VHLO provides this // compat. // TODO (b/344930098): Allow VHLO interop and remove the all_stablehlo check - bool all_stablehlo = true; + bool all_stablehlo_or_shardy = true; module->walk([&](mlir::Operation* op) { if (!llvm::isa(op) && !llvm::isa(op->getDialect())) { - all_stablehlo = false; + mlir::chlo::ChloDialect, mlir::sdy::SdyDialect>( + op->getDialect())) { + all_stablehlo_or_shardy = false; return mlir::WalkResult::interrupt(); } return mlir::WalkResult::advance(); }); - if (!all_stablehlo) { + if (!all_stablehlo_or_shardy) { return SerializeUsingNativeBytecode(module); } return SerializeUsingVersionedStablehlo(module, target, inplace); diff --git a/third_party/xla/xla/pjrt/mlir_to_hlo.h b/third_party/xla/xla/pjrt/mlir_to_hlo.h index 15c2818ef9ef2a..2413851c386fd3 100644 --- a/third_party/xla/xla/pjrt/mlir_to_hlo.h +++ b/third_party/xla/xla/pjrt/mlir_to_hlo.h @@ -19,7 +19,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "mlir/IR/BuiltinOps.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" namespace xla { @@ -40,9 +40,15 @@ absl::Status ParseMlirModuleStringAndConvertToXlaComputation( absl::string_view mlir_module_str, XlaComputation& xla_computation, bool use_tuple_args, bool return_tuple); +// Export an MHLO + Shardy module into a pure MHLO module, to prepare for a +// round trip to HLO, such that the Shardy ops and attributes are preserved when +// going back to MLIR for Shardy propagation. +absl::Status ExportShardyForHloRoundTrip(mlir::ModuleOp module); + // Returns a version of StableHLO ~12w old, for forward compatibility with PJRT // plugins on a quarterly update cycle. -std::string GetDefaultStablehloVersion(); +std::string GetDefaultStablehloVersion( + std::optional plugin_version = std::nullopt); // Serialize using MLIR Bytecode Format which does not guarantee forward or // backward compatiblity of the dialects used. If passing StableHLO with forward @@ -52,22 +58,23 @@ std::string GetDefaultStablehloVersion(); // For plugin_version < 41, returns `SerializeUsingNativeBytecode`. // For plugin_version >= 41, returns `SerializeUsingVersionedStablehlo`. absl::StatusOr Serialize(mlir::ModuleOp mlir_module, - std::optional plugin_version, absl::string_view target, bool inplace = false); // Serializes an MLIR module to a portable artifact with forward and backward // compatibility. Supports modules using StableHLO/MHLO/CHLO/Func dialects. -// Target parameter is a StableHLO version string ("0.9.0") which can be used -// for forward compatibility to specify the target downgrade version. -// Most commonly should use: +// The `requested_target` parameter is a StableHLO version string ("0.9.0") +// which can be used for forward compatibility to specify the target downgrade +// version. Most commonly should use: // `mlir::stablehlo::getCurrentVersion()` for backward compat but not forward. // `mlir::stablehlo::getMinimumVersion()` for maximum forward compatibility. -// Ideally should be the `mlir::stablehlo::getCurrentVersion()` of the plugin. -// If program contains dialects that aren't supposed in StableHLO portable -// artifacts, use SerializeUsingNativeBytecode. +// In PJRT, the `requested_target` should be the current version of the PJRT +// plugin. Serialize will use `min(framework_version, plugin_version)` to +// serialize. If program contains dialects that aren't supported in StableHLO +// portable artifacts, use SerializeUsingNativeBytecode. absl::StatusOr SerializeUsingVersionedStablehlo( - mlir::ModuleOp mlir_module, absl::string_view target, bool inplace = false); + mlir::ModuleOp mlir_module, absl::string_view requested_target, + bool inplace = false); // Given a module that might be a portable artifact, deserialize and upgrade it // back to StableHLO. diff --git a/third_party/xla/xla/pjrt/mlir_to_hlo_test.cc b/third_party/xla/xla/pjrt/mlir_to_hlo_test.cc index 25568b51029d80..4e7b2610f4bcbe 100644 --- a/third_party/xla/xla/pjrt/mlir_to_hlo_test.cc +++ b/third_party/xla/xla/pjrt/mlir_to_hlo_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" +#include "stablehlo/api/PortableApi.h" #include "xla/test.h" #include "tsl/platform/statusor.h" @@ -50,12 +51,31 @@ TEST(MlirToHloTest, StablehloTest) { mlir::MLIRContext context; TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef module, ParseMlirModuleString(kProgram, context)); - TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, 47, "1.0.0")); + TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, "1.0.0")); // StableHLO uses VHLO for PJRT serialization. EXPECT_THAT(blob, IsVhloArtifact("1.0.0")); } +TEST(MlirToHloTest, StablehloPluginNewerThanFramework) { + constexpr char kProgram[] = + R"( + func.func @add(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> { + %cst = stablehlo.constant dense<1.0> : tensor<1x2xf32> + %0 = stablehlo.add %arg0, %cst : tensor<1x2xf32> + return %0 : tensor<1x2xf32> + } + )"; + mlir::MLIRContext context; + TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef module, + ParseMlirModuleString(kProgram, context)); + + // Request version v100.99.88, newer than the framework version. + // Serialize uses frameworks version when plugin requests a newer version. + TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, "100.99.98")); + EXPECT_THAT(blob, IsVhloArtifact(mlir::stablehlo::getCurrentVersion())); +} + TEST(MlirToHloTest, ChloTest) { constexpr char kProgram[] = R"( @@ -68,7 +88,7 @@ TEST(MlirToHloTest, ChloTest) { mlir::MLIRContext context; TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef module, ParseMlirModuleString(kProgram, context)); - TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, 47, "1.0.0")); + TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, "1.0.0")); // CHLO decomposes to StableHLO, so uses VHLO serialization. EXPECT_THAT(blob, IsVhloArtifact("1.0.0")); @@ -85,7 +105,7 @@ TEST(MlirToHloTest, ChloTanOpTest) { mlir::MLIRContext context; TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef module, ParseMlirModuleString(kProgram, context)); - TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, 47, "1.0.0")); + TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, "1.0.0")); // CHLO decomposes to StableHLO, so uses VHLO serialization. EXPECT_THAT(blob, IsVhloArtifact("1.0.0")); @@ -103,11 +123,56 @@ TEST(MlirToHloTest, MhloTest) { mlir::MLIRContext context; TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef module, ParseMlirModuleString(kProgram, context)); - TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, 47, "1.0.0")); + TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, "1.0.0")); // MHLO and other dialects use native MLIR bytecode, not VHLO. EXPECT_THAT(blob, Not(IsVhloArtifact("1.0.0"))); } +TEST(MlirToHloTest, InvalidBytecodeTest) { + // MLIR bytecode format has full compatibility. + // Program using StableHLO v2.0.0 with op vhlo.constant_v99. + // TODO: Once this file is exposed via the StableHLO repo, replace this + // bytecode string with a read of the StableHLO file. + unsigned char invalid_future_vhlo_mlirbc[] = { + 0x4d, 0x4c, 0xef, 0x52, 0x0d, 0x53, 0x74, 0x61, 0x62, 0x6c, 0x65, 0x48, + 0x4c, 0x4f, 0x5f, 0x76, 0x32, 0x2e, 0x30, 0x2e, 0x30, 0x00, 0x01, 0x19, + 0x05, 0x01, 0x05, 0x09, 0x01, 0x03, 0x0b, 0x03, 0x07, 0x0f, 0x13, 0x17, + 0x03, 0x2b, 0x15, 0x07, 0x01, 0x0b, 0x0b, 0x13, 0x13, 0x13, 0x13, 0x03, + 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x1f, 0x03, 0x07, 0x0f, 0x13, 0x07, 0x02, + 0x53, 0x05, 0x0d, 0x17, 0x01, 0x03, 0x03, 0x17, 0x01, 0x05, 0x07, 0x17, + 0x01, 0x07, 0x15, 0x17, 0x01, 0x09, 0x0b, 0x03, 0x01, 0x23, 0x03, 0x1d, + 0x0f, 0x1d, 0x11, 0x1f, 0x01, 0x09, 0x00, 0x00, 0x80, 0x3f, 0x29, 0x01, + 0x05, 0x11, 0x01, 0x03, 0x01, 0x09, 0x04, 0x41, 0x05, 0x01, 0x50, 0x03, + 0x01, 0x07, 0x04, 0x31, 0x03, 0x01, 0x05, 0x03, 0x50, 0x05, 0x03, 0x07, + 0x04, 0x1d, 0x03, 0x03, 0x09, 0x05, 0x42, 0x07, 0x05, 0x03, 0x01, 0x07, + 0x04, 0x09, 0x03, 0x01, 0x06, 0x03, 0x01, 0x05, 0x01, 0x00, 0xad, 0x13, + 0x0f, 0x0b, 0x1b, 0x15, 0x1b, 0x11, 0x0f, 0x0b, 0x11, 0x62, 0x75, 0x69, + 0x6c, 0x74, 0x69, 0x6e, 0x00, 0x76, 0x68, 0x6c, 0x6f, 0x00, 0x6d, 0x6f, + 0x64, 0x75, 0x6c, 0x65, 0x00, 0x66, 0x75, 0x6e, 0x63, 0x5f, 0x76, 0x31, + 0x00, 0x63, 0x6f, 0x6e, 0x73, 0x74, 0x61, 0x6e, 0x74, 0x5f, 0x76, 0x39, + 0x39, 0x00, 0x72, 0x65, 0x74, 0x75, 0x72, 0x6e, 0x5f, 0x76, 0x31, 0x00, + 0x2f, 0x74, 0x6d, 0x70, 0x2f, 0x74, 0x32, 0x2e, 0x6d, 0x6c, 0x69, 0x72, + 0x00, 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, + 0x00, 0x08, 0x19, 0x07, 0x05, 0x01, 0x01, 0x0b, 0x0b, 0x0d, 0x0b, 0x0f, + 0x11, 0x03, 0x13}; + unsigned int invalid_future_vhlo_mlirbc_len = 243; + + std::string buffer(reinterpret_cast(invalid_future_vhlo_mlirbc), + invalid_future_vhlo_mlirbc_len); + + mlir::MLIRContext context; + auto status = ParseMlirModuleString(buffer, context); + ASSERT_FALSE(status.ok()); + // Check that the error message contains: + // - The name of the op that is not supported (vhlo.constant_v99) + // - The version that the StableHLO portable artifact was emit for (v2.0.0) + // - The current version of StableHLO (v1.X.Y) + EXPECT_THAT(status.status().message(), HasSubstr("vhlo.constant_v99")); + EXPECT_THAT(status.status().message(), HasSubstr("StableHLO_v2.0.0")); + EXPECT_THAT(status.status().message(), + HasSubstr(mlir::stablehlo::getCurrentVersion())); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/pjrt/pjrt_api_test.cc b/third_party/xla/xla/pjrt/pjrt_api_test.cc index 8ee9e49451a99e..ed7579b143a959 100644 --- a/third_party/xla/xla/pjrt/pjrt_api_test.cc +++ b/third_party/xla/xla/pjrt/pjrt_api_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tsl/platform/status_matchers.h" -#include "tsl/protobuf/error_codes.pb.h" namespace { using ::testing::HasSubstr; diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc index 7319ea3942145c..3f08eba17e8567 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.cc @@ -34,6 +34,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" @@ -41,9 +42,11 @@ limitations under the License. #include "mlir/IR/OwningOpRef.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/layout.h" +#include "xla/layout_util.h" #include "xla/literal.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/pjrt/c/pjrt_c_api.h" @@ -64,10 +67,8 @@ limitations under the License. #include "xla/pjrt/pjrt_layout.h" #include "xla/service/computation_placer.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/tsl/framework/allocator.h" #include "xla/util.h" #include "xla/xla.pb.h" @@ -392,10 +393,20 @@ absl::StatusOr> PjRtCApiClient::Compile( absl::StatusOr> PjRtCApiClient::Compile( mlir::ModuleOp module, CompileOptions options) { if (!pjrt_c_api()) llvm::report_fatal_error("pjrt_c_api is null"); - TF_ASSIGN_OR_RETURN( - std::string serialized, - xla::Serialize(module, plugin_attributes()->pjrt_c_api_minor_version, - xla::GetDefaultStablehloVersion())); + + auto attributes = plugin_attributes()->attributes; + std::string version_string; + auto version = attributes.find("stablehlo_current_version"); + if (version != attributes.end()) { + std::vector v = std::get>(version->second); + version_string = absl::StrFormat("%d.%d.%d", v[0], v[1], v[2]); + } else { + version_string = xla::GetDefaultStablehloVersion( + plugin_attributes()->pjrt_c_api_minor_version); + } + + TF_ASSIGN_OR_RETURN(std::string serialized, + xla::Serialize(module, version_string)); std::string format(pjrt::kMlirFormat); return InitializeArgsAndCompile(this, c_api_, c_client_.get(), options, serialized, format); @@ -411,6 +422,17 @@ PjRtCApiClient::DeserializeExecutable(absl::string_view serialized, des_args.client = c_client_.get(); des_args.serialized_executable = serialized.data(); des_args.serialized_executable_size = serialized.length(); + des_args.overridden_serialized_compile_options = nullptr; + des_args.overridden_serialized_compile_options_size = 0; + + std::string options_str; + if (options) { + TF_ASSIGN_OR_RETURN(const CompileOptionsProto options_proto, + options->ToProto()); + options_str = options_proto.SerializeAsString(); + des_args.overridden_serialized_compile_options = options_str.c_str(); + des_args.overridden_serialized_compile_options_size = options_str.size(); + } const PJRT_Api* api = pjrt_c_api(); @@ -642,8 +664,7 @@ absl::StatusOr PjRtCApiClient::GetDefaultLayout( pjrt::FindExtension( c_api, PJRT_Extension_Type::PJRT_Extension_Type_Layouts); if (extension == nullptr) { - return absl::UnimplementedError( - "Layouts extension not implemented in this PJRT plugin."); + return LayoutUtil::MakeDescendingLayout(dims.size()); } PJRT_Layouts_PJRT_Client_GetDefaultLayout_Args args; args.struct_size = PJRT_Layouts_PJRT_Client_GetDefaultLayout_Args_STRUCT_SIZE; @@ -1796,39 +1817,36 @@ std::unique_ptr PjRtCApiBuffer::layout() const { pjrt::FindExtension( c_api, PJRT_Extension_Type::PJRT_Extension_Type_Layouts); if (extension == nullptr) { - // TODO(b/343274728): implement some generic layouts behavior for - // plugins that don't support it. - LOG(WARNING) << "PJRT_Layouts_Extension is not found when " - "PjRtCApiBuffer::layout is called."; - return nullptr; + layout_.emplace(LayoutUtil::MakeDescendingLayout(dimensions().size())); + } else { + std::unique_ptr + layout = pjrt::GetMemoryLayout(c_api, buffer_.get()); + + // TODO(b/343274093): returns a PjRtLayout that wraps a C API layout + // directly instead of de/serializing into an xla::Layout. + PJRT_Layouts_MemoryLayout_Serialize_Args serialize_args; + serialize_args.struct_size = + PJRT_Layouts_MemoryLayout_Serialize_Args_STRUCT_SIZE; + serialize_args.extension_start = nullptr; + serialize_args.layout = layout.get(); + pjrt::LogFatalIfPjrtError( + extension->PJRT_Layouts_MemoryLayout_Serialize(&serialize_args), + c_api); + + // Clean up `PJRT_Layouts_SerializedLayout`. + absl::Cleanup cleanup = [&serialize_args] { + serialize_args.serialized_layout_deleter( + serialize_args.serialized_layout); + }; + + std::string serialized_layout(serialize_args.serialized_bytes, + serialize_args.serialized_bytes_size); + absl::StatusOr pjrt_xla_layout = + PjRtXlaLayout::Deserialize(serialized_layout); + TF_CHECK_OK(pjrt_xla_layout.status()); + layout_.emplace(*pjrt_xla_layout); } - std::unique_ptr - layout = pjrt::GetMemoryLayout(c_api, buffer_.get()); - - // TODO(b/343274093): returns a PjRtLayout that wraps a C API layout - // directly instead of de/serializing into an xla::Layout. - PJRT_Layouts_MemoryLayout_Serialize_Args serialize_args; - serialize_args.struct_size = - PJRT_Layouts_MemoryLayout_Serialize_Args_STRUCT_SIZE; - serialize_args.extension_start = nullptr; - serialize_args.layout = layout.get(); - pjrt::LogFatalIfPjrtError( - extension->PJRT_Layouts_MemoryLayout_Serialize(&serialize_args), - c_api); - - // Clean up `PJRT_Layouts_SerializedLayout`. - absl::Cleanup cleanup = [&serialize_args] { - serialize_args.serialized_layout_deleter( - serialize_args.serialized_layout); - }; - - std::string serialized_layout(serialize_args.serialized_bytes, - serialize_args.serialized_bytes_size); - absl::StatusOr pjrt_xla_layout = - PjRtXlaLayout::Deserialize(serialized_layout); - TF_CHECK_OK(pjrt_xla_layout.status()); - layout_.emplace(*pjrt_xla_layout); } } return std::make_unique(*layout_); @@ -2314,9 +2332,9 @@ absl::StatusOr> PjRtCApiCompiler::Compile( if (client) { plugin_version = client->plugin_attributes()->pjrt_c_api_minor_version; } - TF_ASSIGN_OR_RETURN(std::string serialized, - xla::Serialize(module, plugin_version, - xla::GetDefaultStablehloVersion())); + TF_ASSIGN_OR_RETURN( + std::string serialized, + xla::Serialize(module, xla::GetDefaultStablehloVersion(plugin_version))); std::string format(pjrt::kMlirFormat); return InitializeArgsAndCompileAot(c_api_, client, options, topology, serialized, format); diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client.h b/third_party/xla/xla/pjrt/pjrt_c_api_client.h index 5973462c06cbe5..49da2f8af791d4 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client.h +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client.h @@ -37,7 +37,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "mlir/IR/BuiltinOps.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/layout.h" #include "xla/literal.h" @@ -282,11 +282,6 @@ class PjRtCApiClient : public PjRtClient { std::optional plugin_attributes() const override; - // TODO(b/244756954): Rethink this function altogether - PjRtRuntimeType runtime_type() const override { - return PjRtRuntimeType::kTfrt; - } - absl::StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const override; diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client_test.cc b/third_party/xla/xla/pjrt/pjrt_c_api_client_test.cc index 4dbdd5d03af4cb..67074e10623e47 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client_test.cc +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "xla/pjrt/pjrt_c_api_client.h" +#include #include #include #include @@ -24,9 +25,17 @@ limitations under the License. #include #include #include "absl/status/status.h" -#include "xla/client/xla_builder.h" +#include "absl/strings/str_format.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "stablehlo/dialect/Version.h" +#include "xla/cpu_function_runtime.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" +#include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_cpu_internal.h" +#include "xla/pjrt/mlir_to_hlo.h" #include "xla/pjrt/pjrt_api.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_compiler.h" @@ -136,7 +145,8 @@ TEST(PjRtClientTest, CreateViewAndCopyToDeviceAsyncExternalCpuOnly) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, GetCApiClient("cpu")); ASSERT_GT(client->addressable_devices().size(), 1); - std::vector data(4, 0); + alignas(cpu_function_runtime::MinAlign()) std::array data; + data.fill(0); auto* data_ptr = data.data(); Shape shape = ShapeUtil::MakeShape(S32, {4}); @@ -144,7 +154,9 @@ TEST(PjRtClientTest, CreateViewAndCopyToDeviceAsyncExternalCpuOnly) { auto buffer, client->CreateViewOfDeviceBuffer( data_ptr, shape, client->addressable_devices()[0], - /*on_delete_callback=*/[data = std::move(data)]() mutable {})); + /*on_delete_callback=*/[data = std::move(data)]() mutable { + (void)data; + })); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr result, @@ -158,5 +170,70 @@ TEST(PjRtClientTest, CreateViewAndCopyToDeviceAsyncExternalCpuOnly) { *literal)); } +// TODO: (b/375454646) Eanble once frameworks have bugfix: +// https://github.com/openxla/xla/commit/2f99455cdf99e844ddad17de9f4714997023d243 +TEST(PjRtClientTest, DISABLED_CompileUsesStableHloVersion) { + SetUpCpuPjRtApi(); + TF_ASSERT_OK_AND_ASSIGN(const PJRT_Api* c_api, pjrt::PjrtApi("cpu")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + GetCApiClient("cpu")); + static auto PJRT_Client_Compile_Orig = c_api->PJRT_Client_Compile; + constexpr char kProgram[] = "func.func @main() {return}"; + mlir::MLIRContext context; + TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef module, + ParseMlirModuleString(kProgram, context)); + const_cast(c_api)->PJRT_Client_Compile = + [](PJRT_Client_Compile_Args* args) -> PJRT_Error* { + mlir::vhlo::Version version = mlir::vhlo::Version::getCurrentVersion(); + std::string version_string = absl::StrFormat( + "%d.%d.%d", version.getMajor(), version.getMinor(), version.getPatch()); + // MLIR doesn't have any functionality for retrieving the producer of + // bytecode files, so just scan the raw string. + EXPECT_TRUE(llvm::StringRef(args->program->code, args->program->code_size) + .contains(version_string)); + return PJRT_Client_Compile_Orig(args); + }; + std::unique_ptr executable = + client->Compile(*module, CompileOptions()).value(); + const_cast(c_api)->PJRT_Client_Compile = PJRT_Client_Compile_Orig; +} + +TEST(PjRtClientTest, DeserializeExecutableWithDifferentDeviceAssignment) { + SetUpCpuPjRtApi(); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr client, + GetCApiClient("cpu")); + ASSERT_GT(client->addressable_devices().size(), 1); + + XlaBuilder builder("Identity"); + Shape shape = ShapeUtil::MakeShape(S32, {2, 3}); + auto input = Parameter(&builder, 0, shape, "input"); + auto computation = builder.Build(input).value(); + + auto compile_options_for_device = [](int id) -> xla::CompileOptions { + xla::DeviceAssignment device_assignment(1, 1); + device_assignment(0, 0) = id; + xla::CompileOptions options; + options.executable_build_options.set_device_assignment(device_assignment); + return options; + }; + + // Compile the executable for device 0 and serialize it. + std::unique_ptr executable = + client->Compile(computation, compile_options_for_device(0)).value(); + TF_ASSERT_OK_AND_ASSIGN(std::string serialized_executable, + executable->SerializeExecutable()); + + // Deserialize the executable for device 1. + TF_ASSERT_OK_AND_ASSIGN( + auto deserialized_executable, + client->DeserializeExecutable(serialized_executable, + compile_options_for_device(1))); + + // Check that the executable's compile options were overridden + // with device id 1. + EXPECT_EQ( + deserialized_executable->addressable_devices()[0]->global_device_id(), 1); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/pjrt/pjrt_client.h b/third_party/xla/xla/pjrt/pjrt_client.h index c278332dd74f8f..ddcfe417e0c6db 100644 --- a/third_party/xla/xla/pjrt/pjrt_client.h +++ b/third_party/xla/xla/pjrt/pjrt_client.h @@ -41,9 +41,10 @@ limitations under the License. #include "absl/synchronization/notification.h" #include "absl/types/span.h" #include "mlir/IR/BuiltinOps.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/layout.h" #include "xla/literal.h" +#include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/pjrt/pjrt_common.h" #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_device_description.h" @@ -65,16 +66,6 @@ limitations under the License. namespace xla { -enum PjRtRuntimeType { kStreamExecutor, kTfrt }; -inline constexpr absl::string_view PjRtRuntimeTypeString(PjRtRuntimeType type) { - switch (type) { - case kStreamExecutor: - return "stream_executor"; - case kTfrt: - return "tfrt"; - } -} - class PjRtClient; class PjRtDevice; @@ -529,12 +520,16 @@ class PjRtClient { // Lookup any PjRtDevice for a given PjRtDevice::id(). virtual absl::StatusOr LookupDevice( - PjRtGlobalDeviceId global_device_id) const = 0; + PjRtGlobalDeviceId global_device_id) const { + return Unimplemented("LookupDevice is not supported."); + } // Return an addressable PjRtDevice for a given // PjRtDevice::local_device_id(). virtual absl::StatusOr LookupAddressableDevice( - PjRtLocalDeviceId local_device_id) const = 0; + PjRtLocalDeviceId local_device_id) const { + return Unimplemented("LookupAddressableDevice is not supported."); + } // Return all memory spaces owned by the client. // The memory spaces are in no particular order. @@ -550,21 +545,24 @@ class PjRtClient { // (e.g. the CUDA version on GPU or libtpu version on Cloud TPU). virtual absl::string_view platform_version() const = 0; + // Returns the key value store used by the client. + virtual std::optional> + key_value_store() const { + return std::nullopt; + } + // Returns information about the underlying PJRT C API plugin if such a plugin // is being used, otherwise returns nullopt. virtual std::optional plugin_attributes() const { return std::nullopt; } - // TODO(b/244756954): Rethink this function altogether - // Returns an enum that identifies the type of runtime being used under this - // client. - virtual PjRtRuntimeType runtime_type() const = 0; - // Return a device-specific default device assignment, e.g., GPU and TPU may // be different. virtual absl::StatusOr GetDefaultDeviceAssignment( - int num_replicas, int num_partitions) const = 0; + int num_replicas, int num_partitions) const { + return Unimplemented("GetDefaultDeviceAssignment is not supported."); + } // Returns a device-specific default device assignment for multi-slice system. // If num_replicas_per_slice is not defined (nullopt) then we assume that @@ -585,19 +583,27 @@ class PjRtClient { // user-specified or compiler-chosen layouts are requested via the // "mhlo.layout_mode" attribute. virtual absl::StatusOr GetDefaultLayout( - PrimitiveType element_type, absl::Span dims) = 0; + PrimitiveType element_type, absl::Span dims) { + return Unimplemented("GetDefaultLayout is not supported."); + } // Returns a backend-specific HLO cost analysis visitor. virtual absl::StatusOr> GetHloCostAnalysis() - const = 0; + const { + return Unimplemented("GetHloCostAnalysis is not supported."); + } // Compile `computation` with given `options`. virtual absl::StatusOr> Compile( - const XlaComputation& computation, CompileOptions options) = 0; + const XlaComputation& computation, CompileOptions options) { + return Unimplemented("Compile with options is not supported."); + } // Variant of `Compile` that accepts an MLIR module. virtual absl::StatusOr> Compile( - mlir::ModuleOp module, CompileOptions options) = 0; + mlir::ModuleOp module, CompileOptions options) { + return Unimplemented("Compile with MLIR Module is not supported."); + } // Deserializes a serialized executable as produced by // PjRtExecutable::SerializeExecutable(). `serialized` must have been @@ -606,9 +612,14 @@ class PjRtClient { // Pending completion of b/237720161, `options` is a mandatory argument in // most implementations of this interface. They _are_ optional for // implementations related to the PJRT C API. + // + // If `options` are provided, then they override the compile options + // from the serialized executable (`serialized`). virtual absl::StatusOr> DeserializeExecutable(absl::string_view serialized, - std::optional options) = 0; + std::optional options) { + return Unimplemented("Deserialize is not supported."); + } // LoadSerializedExecutable takes the serialized output of PjRtExecutable. The // returned executable is loaded by this client. The same checks are made as @@ -634,7 +645,9 @@ class PjRtClient { // Creates a buffer on the device without initializing or copying any data. virtual absl::StatusOr> CreateUninitializedBuffer( - const Shape& shape, PjRtDevice* device) = 0; + const Shape& shape, PjRtDevice* device) { + return Unimplemented("CreateUnitializedBuffer is not supported."); + } // Creates buffer in the given memory space that carries an error future // without allocating memory. @@ -759,12 +772,16 @@ class PjRtClient { }; // Returns a manager for async transfers into a set of buffers with on-host - // shapes defined by 'shape_specs' and optional `device_layouts`. The - // `device_layout` is used when non-compact layouts are preferred. + // shapes defined by 'shape_specs' and optional `device_layouts`. + // + // If the desired layout of one or more buffers is not specified in + // `device_layouts`, then those buffers will use the default device layout. If + // `device_layouts` itself is not specified, then all buffers will use the + // default device layout. virtual absl::StatusOr> CreateBuffersForAsyncHostToDevice( absl::Span shape_specs, - std::optional> device_layouts, + std::optional>> device_layouts, PjRtDevice* device) { return absl::UnimplementedError(absl::StrCat( "CreateBuffersForAsyncHostToDevice with ShapeSpec and Layout is " @@ -776,7 +793,7 @@ class PjRtClient { virtual absl::StatusOr> CreateBuffersForAsyncHostToDevice( absl::Span shape_specs, - std::optional> device_layouts, + std::optional>> device_layouts, PjRtMemorySpace* memory_space) { return absl::UnimplementedError(absl::StrCat( "CreateBuffersForAsyncHostToDevice with ShapeSpec and Layout is " @@ -788,12 +805,19 @@ class PjRtClient { // shapes 'shapes'. virtual absl::StatusOr> CreateBuffersForAsyncHostToDevice(absl::Span shapes, - PjRtDevice* device) = 0; + PjRtDevice* device) { + return Unimplemented( + "CreateBuffersForAsyncHostToDevice with on host is not implemented."); + } // Variant of CreateBuffersForAsyncHostToDevice with PjRtMemorySpace. virtual absl::StatusOr> CreateBuffersForAsyncHostToDevice(absl::Span shapes, - PjRtMemorySpace* memory_space) = 0; + PjRtMemorySpace* memory_space) { + return Unimplemented( + "CreateBuffersForAsyncHostToDevice with PjRtMemorySpace is not " + "implemented."); + } // Creates a shapeless buffer on the device that can be partitioned into // multiple PjRtBuffer. This class is an Arena version of @@ -886,7 +910,9 @@ class PjRtClient { std::optional> byte_strides, HostBufferSemantics host_buffer_semantics, absl::AnyInvocable on_done_with_host_buffer, - PjRtDevice* device) = 0; + PjRtDevice* device) { + return Unimplemented("BufferFromHostBuffer is not implemented."); + } // Variant of BufferFromHostBuffer that takes an optional device layout. It is // used when non-compact layout is preferred. @@ -922,7 +948,9 @@ class PjRtClient { // the caller should, for example, wait for GetReadyFuture().Await() // completes on the return value before letting literal go out of scope. virtual absl::StatusOr> BufferFromHostLiteral( - const LiteralSlice& literal, PjRtDevice* device) = 0; + const LiteralSlice& literal, PjRtDevice* device) { + return Unimplemented("BufferFromHostLiteral is not implemented."); + } virtual absl::StatusOr> BufferFromHostLiteral( const LiteralSlice& literal, PjRtDevice* device, @@ -974,7 +1002,9 @@ class PjRtClient { virtual absl::StatusOr> CreateViewOfDeviceBuffer( void* device_ptr, const Shape& shape, PjRtDevice* device, std::function on_delete_callback, - std::optional stream = std::nullopt) = 0; + std::optional stream = std::nullopt) { + return Unimplemented("CreateViewOfDeviceBuffer is not implemented."); + } // Returns platform-dependent address for the given buffer that is often but // not guaranteed to be the physical/device address. @@ -1002,7 +1032,9 @@ class PjRtClient { virtual absl::StatusOr>> MakeCrossHostReceiveBuffers(absl::Span shapes, PjRtDevice* device, - PjRtCrossHostRecvNotifier notifier) = 0; + PjRtCrossHostRecvNotifier notifier) { + return Unimplemented("MakeCrossHostReceiveBuffers is not implemented."); + } // Asynchronously makes a vector of PjRtBuffers that can be used to receive // cross host transfers, as in MakeCrossHostReceiveBuffers above, however @@ -1036,15 +1068,24 @@ class PjRtClient { virtual absl::StatusOr>> MakeCrossHostReceiveBuffersForGather( absl::Span shapes, std::vector gather_details, - PjRtDevice* device, PjRtCrossHostRecvNotifier notifier) = 0; + PjRtDevice* device, PjRtCrossHostRecvNotifier notifier) { + return Unimplemented( + "MakeCrossHostReceiveBuffersForGather is not implemented."); + } // Create ChannelHandles for XLA send/recv. - virtual absl::StatusOr CreateChannelHandle() = 0; - virtual absl::StatusOr CreateDeviceToHostChannelHandle() = 0; + virtual absl::StatusOr CreateChannelHandle() { + return Unimplemented("CreateChannelHandle is not implemented."); + } + virtual absl::StatusOr CreateDeviceToHostChannelHandle() { + return Unimplemented("CreateDeviceToHostChannelHandle is not implemented."); + } // TODO(zhangqiaorjc): Experimental API to be removed. // Defragment device memory. - virtual absl::Status Defragment() = 0; + virtual absl::Status Defragment() { + return Unimplemented("Defragment is not implemented."); + } // If false, this client does not support send/recv host callbacks, and // callers should not set the `send_callbacks` and `recv_callbacks` arguments diff --git a/third_party/xla/xla/pjrt/pjrt_client_test.cc b/third_party/xla/xla/pjrt/pjrt_client_test.cc index cdaadf57295ca5..64e3552ded666a 100644 --- a/third_party/xla/xla/pjrt/pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/pjrt_client_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/pjrt/pjrt_client_test.h" +#include #include #include #include @@ -25,10 +26,12 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/synchronization/blocking_counter.h" #include "absl/types/span.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/cpu_function_runtime.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/pjrt/pjrt_client.h" -#include "xla/service/hlo_parser.h" +#include "xla/pjrt/pjrt_compiler.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" @@ -447,9 +450,12 @@ TEST(PjRtClientTest, CopyToDeviceAsyncExternalCpuOnly) { ASSERT_GT(client->addressable_devices().size(), 1); // Skip non-CPU platforms. - if (client->platform_id() != CpuId()) return; + if (client->platform_id() != CpuId()) { + GTEST_SKIP() << "This test is for CPU only."; + } - std::vector data(4, 0); + alignas(cpu_function_runtime::MinAlign()) std::array data; + data.fill(0); auto* data_ptr = data.data(); Shape shape = ShapeUtil::MakeShape(S32, {4}); TF_ASSERT_OK_AND_ASSIGN( @@ -457,8 +463,7 @@ TEST(PjRtClientTest, CopyToDeviceAsyncExternalCpuOnly) { client->CreateViewOfDeviceBuffer( data_ptr, shape, client->addressable_devices()[0], /*on_delete_callback=*/[data = std::move(data)]() mutable { - data.clear(); - data.shrink_to_fit(); + (void)data; })); auto* device_1 = client->addressable_devices()[1]; @@ -486,6 +491,36 @@ TEST(PjRtClientTest, CopyToDeviceAsyncExternalCpuOnly) { } } +TEST(PjRtClientTest, CreateViewOfUnalignedBufferReturnsErrorCpuOnly) { + TF_ASSERT_OK_AND_ASSIGN(auto client, GetClient()); + ASSERT_GT(client->addressable_devices().size(), 1); + + // Skip non-CPU platforms. + if (client->platform_id() != CpuId()) { + GTEST_SKIP() << "This test is for CPU only."; + } + + alignas(cpu_function_runtime::MinAlign()) std::array data; + auto* data_ptr = data.data(); + + // Pointer to the second element is always unaligned, because it's shifted by + // 4 bytes (size of int32_t) from the original pointer. + auto* unaligned_ptr = data_ptr + 1; + + // Shape with a size smaller than the original data vector, because the + // 'unaligned_ptr' points to the second element. + Shape shape = ShapeUtil::MakeShape(S32, {4}); + + // Attempt to create a view of the unaligned buffer. Expect an error. + auto result = client->CreateViewOfDeviceBuffer( + unaligned_ptr, shape, client->addressable_devices()[0], + /*on_delete_callback=*/std::function()); + + ASSERT_FALSE(result.ok()); + EXPECT_THAT(result.status().message(), + ::testing::HasSubstr("unaligned data")); +} + absl::StatusOr> MakeFloatBuffer( PjRtClient* client, const std::vector& data, absl::Span dimensions) { diff --git a/third_party/xla/xla/pjrt/pjrt_common.h b/third_party/xla/xla/pjrt/pjrt_common.h index 042d28acd12a05..8d11cdae79b3c9 100644 --- a/third_party/xla/xla/pjrt/pjrt_common.h +++ b/third_party/xla/xla/pjrt/pjrt_common.h @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tsl/lib/gtl/int_type.h" +#include "xla/tsl/lib/gtl/int_type.h" namespace xla { diff --git a/third_party/xla/xla/pjrt/pjrt_compiler.h b/third_party/xla/xla/pjrt/pjrt_compiler.h index b60bc95b378b8b..3e5a158391383b 100644 --- a/third_party/xla/xla/pjrt/pjrt_compiler.h +++ b/third_party/xla/xla/pjrt/pjrt_compiler.h @@ -26,7 +26,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "mlir/IR/BuiltinOps.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/pjrt/pjrt_device_description.h" #include "xla/pjrt/pjrt_executable.h" #include "tsl/platform/fingerprint.h" diff --git a/third_party/xla/xla/pjrt/pjrt_compiler_test.cc b/third_party/xla/xla/pjrt/pjrt_compiler_test.cc index 454e44f07b8fa0..7fdf22efbc1426 100644 --- a/third_party/xla/xla/pjrt/pjrt_compiler_test.cc +++ b/third_party/xla/xla/pjrt/pjrt_compiler_test.cc @@ -26,7 +26,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/pjrt/metrics.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_device_description.h" diff --git a/third_party/xla/xla/pjrt/pjrt_layout.h b/third_party/xla/xla/pjrt/pjrt_layout.h index e9c01b36c1fed6..eea9b861690860 100644 --- a/third_party/xla/xla/pjrt/pjrt_layout.h +++ b/third_party/xla/xla/pjrt/pjrt_layout.h @@ -24,8 +24,8 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/layout.h" -#include "xla/service/hlo_parser.h" #include "tsl/platform/casts.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index c9e1c61cd56a3b..672d76adb125bb 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -97,8 +97,8 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" #include "xla/client/executable_build_options.h" #include "xla/client/local_client.h" -#include "xla/client/xla_computation.h" #include "xla/executable_run_options.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/layout.h" #include "xla/literal.h" #include "xla/pjrt/distributed/protocol.pb.h" @@ -425,6 +425,27 @@ absl::Status AddDestinationBufferSynchronization( return absl::OkStatus(); } +// We wait for events that the compute stream didn't already wait for. Based on +// our heuristics, for usage events, this rare case should only occur when a +// buffer was copied to a device and then never used there. In that case we get +// a new stream and use it to hold onto a reference to the buffer until the +// events are complete. +void MaybeWaitForEventOnStream(BufferSequencingEvent* event, + LocalDeviceState* local_device_state, + se::Stream*& stream) { + if (!event->IsPredeterminedErrorOrDefinedOn( + local_device_state->compute_stream()) && + !event->IsComplete()) { + if (stream == nullptr) { + stream = local_device_state->GetFixedSizePoolUsageStream(); + } + VLOG(2) << "Waiting for event: " << event + << "; is_predetermined_error: " << event->IsPredeterminedError() + << "; on stream: " << stream; + event->WaitForEventOnStream(stream); + } +} + } // namespace absl::StatusOr> @@ -1492,6 +1513,24 @@ PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) { if (local_device_state->allocation_model() == LocalDeviceState::kComputeSynchronized) { se::Stream* block_stream = nullptr; + // If an event is not defined yet, we wait for it to be defined in a new + // thread in the thread pool. + // This allows the host to schedule: + // create buffer -> use -> delete -> fulfill + absl::InlinedVector, 5> + events_to_wait_for_in_a_different_thread; + auto maybe_wait_for_event_on_block_stream_or_add_to_events_to_wait = + [&events_to_wait_for_in_a_different_thread, local_device_state, + &block_stream](const std::shared_ptr& event) { + if (local_device_state->allow_delete_before_fulfill() && + !event->IsDefined()) { + // Wait for the event to be defined in a different thread. + events_to_wait_for_in_a_different_thread.push_back(event); + } else { + MaybeWaitForEventOnStream(event.get(), local_device_state, + block_stream); + } + }; for (const auto& stream_and_event : events) { VLOG(2) << "Checking whether need to wait for stream_and_event: stream: " @@ -1501,25 +1540,11 @@ PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) { << "; is_predetermined_error: " << stream_and_event.event->IsPredeterminedError(); // We only need to do something for events that didn't already acquire a - // reference to the buffer, and also which the compute stream didn't - // already wait for. Based on our heuristics this rare case should only - // occur when a buffer was copied to a device and then never used there. - // In that case we get a new stream and use it to hold onto a reference - // to the buffer until the events are complete. - if (!stream_and_event.reference_held && - !stream_and_event.event->IsPredeterminedErrorOrDefinedOn( - local_device_state->compute_stream()) && - !stream_and_event.event->IsComplete()) { - if (block_stream == nullptr) { - block_stream = local_device_state->GetFixedSizePoolUsageStream(); - } - VLOG(2) << "Waiting for stream_and_event: stream: " - << stream_and_event.stream - << "; event: " << stream_and_event.event.get() - << "; reference_held: " << stream_and_event.reference_held - << "; is_predetermined_error: " - << stream_and_event.event->IsPredeterminedError(); - stream_and_event.event->WaitForEventOnStream(block_stream); + // reference to the buffer and for other situations described in the + // comment of MaybeWaitForEventOnStream() + if (!stream_and_event.reference_held) { + maybe_wait_for_event_on_block_stream_or_add_to_events_to_wait( + stream_and_event.event); } } for (const auto& definition_event : device_buffer->definition_events()) { @@ -1527,31 +1552,34 @@ PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) { << definition_event.get() << "; is_predetermined_error: " << definition_event->IsPredeterminedError(); // Here we wait for the definition events to complete on block_stream as - // well, if they are not on the compute stream and not also recorded as - // usage events. - // - // Since it's possible that definition_event.SetSequencingEvent() - // is called on a different host thread than this host thread, when in - // future more conditions are added to this check, we should be careful - // about whether we put them before the IsPredeterminedErrorOrDefinedOn - // check or after it. For example, we shouldn't add an IsDefined() check - // before the IsPredeterminedErrorOrDefinedOn() check here because that - // could potentially cause a shortcut where we don't wait for - // definition_event.SetSequencingEvent() on the other thread and - // eventually cause memory corruption. - if (!definition_event->IsPredeterminedErrorOrDefinedOn( - local_device_state->compute_stream()) && - !definition_event->IsComplete()) { - if (block_stream == nullptr) { - block_stream = local_device_state->GetFixedSizePoolUsageStream(); - } - VLOG(2) << "Waiting for definition_event: " << definition_event.get() - << "; is_predetermined_error: " - << definition_event->IsPredeterminedError(); - definition_event->WaitForEventOnStream(block_stream); - } + // well, in case they are not also usage events. + maybe_wait_for_event_on_block_stream_or_add_to_events_to_wait( + definition_event); } - if (block_stream != nullptr) { + if (!events_to_wait_for_in_a_different_thread.empty()) { + VLOG(1) << "Going to wait for " + << events_to_wait_for_in_a_different_thread.size() + << " events in a different thread."; + // We always use the cleanup_thread instead of using the + // client->thread_pool() here to avoid exhausting the client thread + // pool. + local_device_state->cleanup_thread()->Schedule( + [events_to_wait_for_in_a_different_thread = + std::move(events_to_wait_for_in_a_different_thread), + local_device_state, device_buffer, block_stream]() mutable { + for (const auto& event : + events_to_wait_for_in_a_different_thread) { + MaybeWaitForEventOnStream(event.get(), local_device_state, + block_stream); + } + if (block_stream != nullptr) { + TF_CHECK_OK(local_device_state->ThenExecuteCallback( + block_stream, [device_buffer]() { + // Drops device_buffer shared pointer. + })); + } + }); + } else if (block_stream != nullptr) { TF_RETURN_IF_ERROR(local_device_state->ThenExecuteCallback( block_stream, [device_buffer]() { // Drops device_buffer shared pointer. @@ -2202,7 +2230,7 @@ absl::Status CheckCompatibleShapes(bool strict_shape_checking, } // Makes a tuple from the arguments to an execution. -absl::StatusOr MakeTupleHelper( +absl::StatusOr> MakeTupleHelper( PjRtStreamExecutorClient* client, LocalDeviceState* local_device, bool strict_shape_checking, const Shape& tupled_parameter_shape, absl::Span py_buffers, @@ -2268,7 +2296,8 @@ absl::StatusOr MakeTupleHelper( auto transfer_event = std::make_shared(client->thread_pool()); transfer_event->SetSequencingEvent(std::move(event_or).value(), stream); - return TupleHandle({std::move(execution_input), std::move(transfer_event)}); + return std::make_unique( + TupleHandle({std::move(execution_input), std::move(transfer_event)})); } // Converts a ScopedShapedBuffer returned from an execution into a @@ -2437,7 +2466,7 @@ PjRtStreamExecutorLoadedExecutable::MakeExecutionInputsAndWaitForEvents( client_->client()->backend().transfer_manager(); // Lift tuple_handle outside the conditional so that the event it returns is // not destroyed until after the loop below that waits on events. - std::optional tuple_handle; + std::unique_ptr tuple_handle; if (parameter_is_tupled_arguments_ && !options.arguments_are_tupled) { TF_ASSIGN_OR_RETURN( tuple_handle, diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h index eb800367fda308..ae20faf4a73036 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h @@ -43,8 +43,8 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" #include "xla/client/executable_build_options.h" #include "xla/client/local_client.h" -#include "xla/client/xla_computation.h" #include "xla/executable_run_options.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/layout.h" #include "xla/literal.h" @@ -290,7 +290,6 @@ class PjRtStreamExecutorClient : public PjRtClient { PjRtPlatformId platform_id() const override { return platform_id_; } absl::string_view platform_name() const override { return platform_name_; } absl::string_view platform_version() const override { return ""; } - PjRtRuntimeType runtime_type() const override { return kStreamExecutor; } // Most platforms expect device-to-device transfers to be enqueued on the // source d2d stream, but some platforms use the destination d2d stream. This diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client_test.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client_test.cc index 2fa381df57290a..a25125ceb9a2c6 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client_test.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "absl/functional/any_invocable.h" #include "absl/synchronization/mutex.h" #include "xla/client/client_library.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/literal_comparison.h" #include "xla/literal_util.h" diff --git a/third_party/xla/xla/pjrt/tf_pjrt_client.h b/third_party/xla/xla/pjrt/tf_pjrt_client.h index ac457ea392bc02..8933a2482c8683 100644 --- a/third_party/xla/xla/pjrt/tf_pjrt_client.h +++ b/third_party/xla/xla/pjrt/tf_pjrt_client.h @@ -34,7 +34,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/layout.h" #include "xla/literal.h" @@ -247,9 +247,6 @@ class TfPjRtClient : public PjRtClient { absl::string_view platform_version() const override { return wrapped_->platform_version(); } - PjRtRuntimeType runtime_type() const override { - return wrapped_->runtime_type(); - } absl::StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const override { return wrapped_->GetDefaultDeviceAssignment(num_replicas, num_partitions); diff --git a/third_party/xla/xla/pjrt/tf_pjrt_client_test.cc b/third_party/xla/xla/pjrt/tf_pjrt_client_test.cc index 2e946459fbebde..16c9f3aa183f58 100644 --- a/third_party/xla/xla/pjrt/tf_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/tf_pjrt_client_test.cc @@ -20,9 +20,9 @@ limitations under the License. #include #include +#include "xla/hlo/parser/hlo_parser.h" #include "xla/literal_util.h" #include "xla/pjrt/cpu/cpu_client.h" -#include "xla/service/hlo_parser.h" #include "tsl/platform/env.h" #include "tsl/platform/file_system.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/pjrt/tracked_device_buffer.cc b/third_party/xla/xla/pjrt/tracked_device_buffer.cc index 2fe88dcacdca59..ca551c4aa81b83 100644 --- a/third_party/xla/xla/pjrt/tracked_device_buffer.cc +++ b/third_party/xla/xla/pjrt/tracked_device_buffer.cc @@ -25,6 +25,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/synchronization/mutex.h" @@ -139,7 +140,6 @@ bool BufferSequencingEvent::IsComplete() { void BufferSequencingEvent::ExecuteOrAddToFutureTasks( const std::string& task_name, std::function task) { - absl::MutexLock lock(&mu_); tsl::profiler::TraceMeProducer producer( "BufferSequencingEvent::ExecuteOrAddToFutureTasks", tsl::profiler::ContextType::kPjRt); @@ -150,19 +150,38 @@ void BufferSequencingEvent::ExecuteOrAddToFutureTasks( context_id); task(); }; - if (defined_status_.IsConcrete()) { - thread_pool_->Schedule(std::move(wrapped_task)); - return; + { + absl::MutexLock lock(&mu_); + if (!defined_status_.IsConcrete()) { + on_ready_tasks_callback_[task_name] = std::move(wrapped_task); + return; + } + // Release the lock to avoid deadlock, in the case where the + // thread_pool_->Schedule() executes wrapped_task inline. + // This is rare but could happen. The callbacks could potentially try to + // acquire the mutex of this BufferSequencingEvent. } - on_ready_tasks_callback_[task_name] = std::move(wrapped_task); + thread_pool_->Schedule(std::move(wrapped_task)); } void BufferSequencingEvent::ExecuteFutureTasks() { - absl::MutexLock lock(&mu_); - for (auto& [task_name, task_callback] : on_ready_tasks_callback_) { - thread_pool_->Schedule(std::move(task_callback)); + absl::flat_hash_map> + on_ready_tasks_callback; + { + absl::MutexLock lock(&mu_); + on_ready_tasks_callback = std::move(on_ready_tasks_callback_); + // Release the lock to avoid deadlock, in the case where the + // thread_pool_->Schedule() executes call_all_task_callbacks inline. + // This is rare but could happen. The callbacks could potentially try to + // acquire the mutex of this BufferSequencingEvent. } - on_ready_tasks_callback_.clear(); + auto call_all_task_callbacks = [on_ready_tasks_callback = + std::move(on_ready_tasks_callback)]() { + for (auto& [task_name, task_callback] : on_ready_tasks_callback) { + task_callback(); + } + }; + thread_pool_->Schedule(std::move(call_all_task_callbacks)); } /* static */ std::shared_ptr diff --git a/third_party/xla/xla/pjrt/transpose.cc b/third_party/xla/xla/pjrt/transpose.cc index 0921357d869e22..e915a77bc5177f 100644 --- a/third_party/xla/xla/pjrt/transpose.cc +++ b/third_party/xla/xla/pjrt/transpose.cc @@ -76,6 +76,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -87,6 +88,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/base/optimization.h" #include "absl/container/inlined_vector.h" +#include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" @@ -464,7 +466,8 @@ static_assert(sizeof(uint128) == 16, "uint128 should be 16 bytes in size"); void TransposePlan::Execute( const void* a, void* b, - const std::function)>& schedule_work) const { + std::optional)>> + schedule_work) const { if (num_elems_ == 0) { return; } @@ -508,7 +511,7 @@ void TransposePlan::Execute( absl::BlockingCounter counter(nodes_.size() - 1); for (size_t i = 1; i < nodes_.size(); ++i) { absl::Span nodes = nodes_[i]; - schedule_work([&, nodes]() { + (*schedule_work)([&, nodes]() { execute_by_type(nodes); counter.DecrementCount(); }); diff --git a/third_party/xla/xla/pjrt/transpose.h b/third_party/xla/xla/pjrt/transpose.h index 469e4419c53431..714db857b7da2d 100644 --- a/third_party/xla/xla/pjrt/transpose.h +++ b/third_party/xla/xla/pjrt/transpose.h @@ -30,11 +30,13 @@ limitations under the License. #include #include #include +#include #include #include #include #include "absl/container/inlined_vector.h" +#include "absl/functional/function_ref.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "absl/types/variant.h" @@ -110,8 +112,8 @@ class TransposePlan { // Currently there are no alignment requirements on either `a` or `b`. However // performance may be better if either or both are aligned. void Execute(const void* a, void* b, - const std::function)>& - schedule_work = {}) const; + std::optional)>> + schedule_work = std::nullopt) const; // Returns a human-readable description of the plan. std::string ToString() const; diff --git a/third_party/xla/xla/pjrt/transpose_kernels.h b/third_party/xla/xla/pjrt/transpose_kernels.h index 18b79bdae2e3f4..cba611d67cc30b 100644 --- a/third_party/xla/xla/pjrt/transpose_kernels.h +++ b/third_party/xla/xla/pjrt/transpose_kernels.h @@ -24,8 +24,6 @@ limitations under the License. #include "xla/compiler_macros.h" -namespace xla { - #ifdef XLA_HAS_SSE2 #include // IWYU pragma: keep #endif @@ -38,6 +36,8 @@ namespace xla { #define XLA_HAS_VEC128 #endif // defined(XLA_HAS_SSE2) || defined(XLA_HAS_ARM_NEON) +namespace xla { + // The transpose microkernels use a general approach of zipping elements from // different rows together. We start zipping together elements of size 1, size 2 // and so-on until we have achieved our transpose. As we increase the number of diff --git a/third_party/xla/xla/pjrt/transpose_test.cc b/third_party/xla/xla/pjrt/transpose_test.cc index 50472b1fa611ed..7d7ed774c0ce9f 100644 --- a/third_party/xla/xla/pjrt/transpose_test.cc +++ b/third_party/xla/xla/pjrt/transpose_test.cc @@ -33,11 +33,11 @@ limitations under the License. #include "xla/permutation_util.h" #include "xla/shape_util.h" #include "xla/test.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "xla/util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test_benchmark.h" #include "tsl/platform/threadpool.h" -#include "tsl/protobuf/error_codes.pb.h" namespace xla { diff --git a/third_party/xla/xla/pjrt/utils.cc b/third_party/xla/xla/pjrt/utils.cc index 2f5bc314919c37..b50fa31b648db4 100644 --- a/third_party/xla/xla/pjrt/utils.cc +++ b/third_party/xla/xla/pjrt/utils.cc @@ -41,7 +41,7 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Support/LLVM.h" #include "xla/client/executable_build_options.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -384,58 +384,6 @@ absl::StatusOr> GetOutputMemoryKinds( // Make sure to choose delimiter that will never show up in Layout strings. static const char* kDelimiter = ";"; -static std::string GetFrontendAttr(absl::Span layout_modes) { - return absl::StrJoin(layout_modes, kDelimiter, - [](std::string* out, const LayoutMode& mode) { - absl::StrAppend(out, mode.ToString()); - }); -} - -absl::Status AddLayoutModesToFrontendAttrs(mlir::ModuleOp module, - XlaComputation& xla_computation) { - TF_ASSIGN_OR_RETURN(std::vector arg_layout_modes, - GetArgLayoutModes(module)); - TF_ASSIGN_OR_RETURN(std::vector out_layout_modes, - GetOutputLayoutModes(module)); - - // Type is string->string proto map. Using auto here to deal with different - // build environments. - auto& frontend_attrs = *xla_computation.mutable_proto() - ->mutable_frontend_attributes() - ->mutable_map(); - frontend_attrs["arg_layout_modes"] = GetFrontendAttr(arg_layout_modes); - frontend_attrs["out_layout_modes"] = GetFrontendAttr(out_layout_modes); - return absl::OkStatus(); -} - -static std::string GetFrontendAttrForMemorySpace( - const std::vector& memory_spaces) { - return absl::StrJoin( - memory_spaces, kDelimiter, - [](std::string* out, const MemorySpaceColor memory_kind) { - absl::StrAppend(out, memory_kind); - }); -} - -absl::Status AddMemoryKindsToFrontendAttrs(mlir::ModuleOp module, - XlaComputation& xla_computation) { - TF_ASSIGN_OR_RETURN(std::vector arg_memory_spaces, - GetArgMemoryKinds(module)); - TF_ASSIGN_OR_RETURN(std::vector out_memory_spaces, - GetOutputMemoryKinds(module)); - - // Type is string->string proto map. Using auto here to deal with different - // build environments. - auto& frontend_attrs = *xla_computation.mutable_proto() - ->mutable_frontend_attributes() - ->mutable_map(); - frontend_attrs["arg_memory_spaces"] = - GetFrontendAttrForMemorySpace(arg_memory_spaces); - frontend_attrs["out_memory_spaces"] = - GetFrontendAttrForMemorySpace(out_memory_spaces); - return absl::OkStatus(); -} - static absl::StatusOr> GetLayoutModesFromFrontendAttr( absl::string_view attr) { // SkipEmpty() needed to avoid returning the empty string when attr is empty. diff --git a/third_party/xla/xla/pjrt/utils.h b/third_party/xla/xla/pjrt/utils.h index 95d2f8aec3e6c3..3470bd164d72a7 100644 --- a/third_party/xla/xla/pjrt/utils.h +++ b/third_party/xla/xla/pjrt/utils.h @@ -30,7 +30,7 @@ limitations under the License. #include "absl/types/span.h" #include "mlir/IR/BuiltinOps.h" #include "xla/client/executable_build_options.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/pjrt/layout_mode.h" #include "xla/service/computation_placer.h" @@ -70,18 +70,6 @@ absl::StatusOr> GetArgMemoryKinds( absl::StatusOr> GetOutputMemoryKinds( mlir::ModuleOp module); -// Populates the frontend attributes "arg_layout_mode" and "out_layout_mode" in -// xla_computation based on `module`. This function must be called before the -// LayoutMode getters below work correctly on `computation`. -absl::Status AddLayoutModesToFrontendAttrs(mlir::ModuleOp module, - XlaComputation& xla_computation); - -// Populates the frontend attributes "arg_memory_kinds" and "out_memory_kinds" -// in xla_computation based on `module`. This function must be called before the -// LayoutMode getters below work correctly on `computation`. -absl::Status AddMemoryKindsToFrontendAttrs(mlir::ModuleOp module, - XlaComputation& xla_computation); - // Returns the LayoutMode for each argument of the computations. Checks for the // "arg_layout_mode" frontend attribute, and if not present, assumes // LayoutMode::Mode::kDefault. diff --git a/third_party/xla/xla/primitive_util.h b/third_party/xla/xla/primitive_util.h index 8fbeedbff94dad..de5ee4fde11d7b 100644 --- a/third_party/xla/xla/primitive_util.h +++ b/third_party/xla/xla/primitive_util.h @@ -180,6 +180,11 @@ constexpr PrimitiveType NativeToPrimitiveType() { return F8E5M2; } +template <> +constexpr PrimitiveType NativeToPrimitiveType() { + return F8E4M3; +} + template <> constexpr PrimitiveType NativeToPrimitiveType() { return F8E4M3FN; @@ -200,6 +205,11 @@ constexpr PrimitiveType NativeToPrimitiveType() { return F8E4M3FNUZ; } +template <> +constexpr PrimitiveType NativeToPrimitiveType() { + return F8E3M4; +} + // Complex template <> constexpr PrimitiveType NativeToPrimitiveType() { @@ -309,6 +319,11 @@ struct PrimitiveTypeToNative { using type = tsl::float8_e5m2; }; +template <> +struct PrimitiveTypeToNative { + using type = tsl::float8_e4m3; +}; + template <> struct PrimitiveTypeToNative { using type = tsl::float8_e4m3fn; @@ -329,6 +344,11 @@ struct PrimitiveTypeToNative { using type = tsl::float8_e4m3fnuz; }; +template <> +struct PrimitiveTypeToNative { + using type = tsl::float8_e3m4; +}; + // Complex template <> struct PrimitiveTypeToNative { @@ -362,8 +382,9 @@ inline constexpr bool IsArrayType(PrimitiveType primitive_type) { } constexpr bool IsF8Type(PrimitiveType type) { - return type == F8E5M2 || type == F8E4M3FN || type == F8E4M3B11FNUZ || - type == F8E5M2FNUZ || type == F8E4M3FNUZ; + return type == F8E5M2 || type == F8E4M3 || type == F8E4M3FN || + type == F8E4M3B11FNUZ || type == F8E5M2FNUZ || type == F8E4M3FNUZ || + type == F8E3M4; } constexpr bool IsFloatingPointType(PrimitiveType type) { @@ -428,6 +449,12 @@ template constexpr R FloatingPointTypeSwitch(F&& f, PrimitiveType type) { if (ABSL_PREDICT_TRUE(IsFloatingPointType(type))) { switch (type) { + case F8E3M4: + return std::forward(f)( + PrimitiveTypeConstant()); + case F8E4M3: + return std::forward(f)( + PrimitiveTypeConstant()); case F8E4M3FN: return std::forward(f)( PrimitiveTypeConstant()); diff --git a/third_party/xla/xla/primitive_util_test.cc b/third_party/xla/xla/primitive_util_test.cc index e8c9dc77087062..850203f17379a4 100644 --- a/third_party/xla/xla/primitive_util_test.cc +++ b/third_party/xla/xla/primitive_util_test.cc @@ -76,10 +76,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[PRED][BF16] = true; expecteds[PRED][C128] = true; expecteds[PRED][F8E5M2] = true; + expecteds[PRED][F8E4M3] = true; expecteds[PRED][F8E4M3FN] = true; expecteds[PRED][F8E4M3B11FNUZ] = true; expecteds[PRED][F8E5M2FNUZ] = true; expecteds[PRED][F8E4M3FNUZ] = true; + expecteds[PRED][F8E3M4] = true; expecteds[S2][PRED] = false; expecteds[S2][S2] = true; expecteds[S2][S4] = true; @@ -100,10 +102,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S2][BF16] = true; expecteds[S2][C128] = true; expecteds[S2][F8E5M2] = true; + expecteds[S2][F8E4M3] = true; expecteds[S2][F8E4M3FN] = true; expecteds[S2][F8E4M3B11FNUZ] = true; expecteds[S2][F8E5M2FNUZ] = true; expecteds[S2][F8E4M3FNUZ] = true; + expecteds[S2][F8E3M4] = true; expecteds[S4][PRED] = false; expecteds[S4][S2] = false; expecteds[S4][S4] = true; @@ -124,10 +128,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S4][BF16] = true; expecteds[S4][C128] = true; expecteds[S4][F8E5M2] = true; + expecteds[S4][F8E4M3] = true; expecteds[S4][F8E4M3FN] = true; expecteds[S4][F8E4M3B11FNUZ] = true; expecteds[S4][F8E5M2FNUZ] = true; expecteds[S4][F8E4M3FNUZ] = true; + expecteds[S4][F8E3M4] = true; expecteds[S8][PRED] = false; expecteds[S8][S2] = false; expecteds[S8][S4] = false; @@ -148,10 +154,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S8][BF16] = true; expecteds[S8][C128] = true; expecteds[S8][F8E5M2] = false; + expecteds[S8][F8E4M3] = false; expecteds[S8][F8E4M3FN] = false; expecteds[S8][F8E4M3B11FNUZ] = false; expecteds[S8][F8E5M2FNUZ] = false; expecteds[S8][F8E4M3FNUZ] = false; + expecteds[S8][F8E3M4] = false; expecteds[S16][PRED] = false; expecteds[S16][S2] = false; expecteds[S16][S4] = false; @@ -172,10 +180,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S16][BF16] = false; expecteds[S16][C128] = true; expecteds[S16][F8E5M2] = false; + expecteds[S16][F8E4M3] = false; expecteds[S16][F8E4M3FN] = false; expecteds[S16][F8E4M3B11FNUZ] = false; expecteds[S16][F8E5M2FNUZ] = false; expecteds[S16][F8E4M3FNUZ] = false; + expecteds[S16][F8E3M4] = false; expecteds[S32][PRED] = false; expecteds[S32][S2] = false; expecteds[S32][S4] = false; @@ -196,10 +206,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S32][BF16] = false; expecteds[S32][C128] = true; expecteds[S32][F8E5M2] = false; + expecteds[S32][F8E4M3] = false; expecteds[S32][F8E4M3FN] = false; expecteds[S32][F8E4M3B11FNUZ] = false; expecteds[S32][F8E5M2FNUZ] = false; expecteds[S32][F8E4M3FNUZ] = false; + expecteds[S32][F8E3M4] = false; expecteds[S64][PRED] = false; expecteds[S64][S2] = false; expecteds[S64][S4] = false; @@ -220,10 +232,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[S64][BF16] = false; expecteds[S64][C128] = false; expecteds[S64][F8E5M2] = false; + expecteds[S64][F8E4M3] = false; expecteds[S64][F8E4M3FN] = false; expecteds[S64][F8E4M3B11FNUZ] = false; expecteds[S64][F8E5M2FNUZ] = false; expecteds[S64][F8E4M3FNUZ] = false; + expecteds[S64][F8E3M4] = false; expecteds[U2][PRED] = false; expecteds[U2][S2] = false; expecteds[U2][S4] = true; @@ -246,10 +260,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U2][BF16] = true; expecteds[U2][C128] = true; expecteds[U2][F8E5M2] = true; + expecteds[U2][F8E4M3] = true; expecteds[U2][F8E4M3FN] = true; expecteds[U2][F8E4M3B11FNUZ] = true; expecteds[U2][F8E5M2FNUZ] = true; expecteds[U2][F8E4M3FNUZ] = true; + expecteds[U2][F8E3M4] = true; expecteds[U4][PRED] = false; expecteds[U4][S2] = false; expecteds[U4][S4] = false; @@ -272,10 +288,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U4][BF16] = true; expecteds[U4][C128] = true; expecteds[U4][F8E5M2] = false; + expecteds[U4][F8E4M3] = true; expecteds[U4][F8E4M3FN] = true; expecteds[U4][F8E4M3B11FNUZ] = true; expecteds[U4][F8E5M2FNUZ] = false; expecteds[U4][F8E4M3FNUZ] = true; + expecteds[U4][F8E3M4] = true; expecteds[U8][PRED] = false; expecteds[U8][S2] = false; expecteds[U8][S4] = false; @@ -298,10 +316,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U8][BF16] = true; expecteds[U8][C128] = true; expecteds[U8][F8E5M2] = false; + expecteds[U8][F8E4M3] = false; expecteds[U8][F8E4M3FN] = false; expecteds[U8][F8E4M3B11FNUZ] = false; expecteds[U8][F8E5M2FNUZ] = false; expecteds[U8][F8E4M3FNUZ] = false; + expecteds[U8][F8E3M4] = false; expecteds[U16][PRED] = false; expecteds[U16][S2] = false; expecteds[U16][S4] = false; @@ -322,10 +342,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U16][BF16] = false; expecteds[U16][C128] = true; expecteds[U16][F8E5M2] = false; + expecteds[U16][F8E4M3] = false; expecteds[U16][F8E4M3FN] = false; expecteds[U16][F8E4M3B11FNUZ] = false; expecteds[U16][F8E5M2FNUZ] = false; expecteds[U16][F8E4M3FNUZ] = false; + expecteds[U16][F8E3M4] = false; expecteds[U32][PRED] = false; expecteds[U32][S2] = false; expecteds[U32][S4] = false; @@ -346,10 +368,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U32][BF16] = false; expecteds[U32][C128] = true; expecteds[U32][F8E5M2] = false; + expecteds[U32][F8E4M3] = false; expecteds[U32][F8E4M3FN] = false; expecteds[U32][F8E4M3B11FNUZ] = false; expecteds[U32][F8E5M2FNUZ] = false; expecteds[U32][F8E4M3FNUZ] = false; + expecteds[U32][F8E3M4] = false; expecteds[U64][PRED] = false; expecteds[U64][S2] = false; expecteds[U64][S4] = false; @@ -370,10 +394,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[U64][BF16] = false; expecteds[U64][C128] = false; expecteds[U64][F8E5M2] = false; + expecteds[U64][F8E4M3] = false; expecteds[U64][F8E4M3FN] = false; expecteds[U64][F8E4M3B11FNUZ] = false; expecteds[U64][F8E5M2FNUZ] = false; expecteds[U64][F8E4M3FNUZ] = false; + expecteds[U64][F8E3M4] = false; expecteds[F16][PRED] = false; expecteds[F16][S2] = false; expecteds[F16][S4] = false; @@ -394,10 +420,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F16][BF16] = false; expecteds[F16][C128] = true; expecteds[F16][F8E5M2] = false; + expecteds[F16][F8E4M3] = false; expecteds[F16][F8E4M3FN] = false; expecteds[F16][F8E4M3B11FNUZ] = false; expecteds[F16][F8E5M2FNUZ] = false; expecteds[F16][F8E4M3FNUZ] = false; + expecteds[F16][F8E3M4] = false; expecteds[F32][PRED] = false; expecteds[F32][S2] = false; expecteds[F32][S4] = false; @@ -418,10 +446,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F32][BF16] = false; expecteds[F32][C128] = true; expecteds[F32][F8E5M2] = false; + expecteds[F32][F8E4M3] = false; expecteds[F32][F8E4M3FN] = false; expecteds[F32][F8E4M3B11FNUZ] = false; expecteds[F32][F8E5M2FNUZ] = false; expecteds[F32][F8E4M3FNUZ] = false; + expecteds[F32][F8E3M4] = false; expecteds[F64][PRED] = false; expecteds[F64][S2] = false; expecteds[F64][S4] = false; @@ -442,10 +472,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F64][BF16] = false; expecteds[F64][C128] = true; expecteds[F64][F8E5M2] = false; + expecteds[F64][F8E4M3] = false; expecteds[F64][F8E4M3FN] = false; expecteds[F64][F8E4M3B11FNUZ] = false; expecteds[F64][F8E5M2FNUZ] = false; expecteds[F64][F8E4M3FNUZ] = false; + expecteds[F64][F8E3M4] = false; expecteds[C64][PRED] = false; expecteds[C64][S2] = false; expecteds[C64][S4] = false; @@ -466,10 +498,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C64][BF16] = false; expecteds[C64][C128] = true; expecteds[C64][F8E5M2] = false; + expecteds[C64][F8E4M3] = false; expecteds[C64][F8E4M3FN] = false; expecteds[C64][F8E4M3B11FNUZ] = false; expecteds[C64][F8E5M2FNUZ] = false; expecteds[C64][F8E4M3FNUZ] = false; + expecteds[C64][F8E3M4] = false; expecteds[BF16][PRED] = false; expecteds[BF16][S2] = false; expecteds[BF16][S4] = false; @@ -490,10 +524,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[BF16][BF16] = true; expecteds[BF16][C128] = true; expecteds[BF16][F8E5M2] = false; + expecteds[BF16][F8E4M3] = false; expecteds[BF16][F8E4M3FN] = false; expecteds[BF16][F8E4M3B11FNUZ] = false; expecteds[BF16][F8E5M2FNUZ] = false; expecteds[BF16][F8E4M3FNUZ] = false; + expecteds[BF16][F8E3M4] = false; expecteds[C128][PRED] = false; expecteds[C128][S2] = false; expecteds[C128][S4] = false; @@ -514,10 +550,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[C128][BF16] = false; expecteds[C128][C128] = true; expecteds[C128][F8E5M2] = false; + expecteds[C128][F8E4M3] = false; expecteds[C128][F8E4M3FN] = false; expecteds[C128][F8E4M3B11FNUZ] = false; expecteds[C128][F8E5M2FNUZ] = false; expecteds[C128][F8E4M3FNUZ] = false; + expecteds[C128][F8E3M4] = false; expecteds[F8E5M2][PRED] = false; expecteds[F8E5M2][S2] = false; expecteds[F8E5M2][S4] = false; @@ -538,10 +576,38 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2][BF16] = true; expecteds[F8E5M2][C128] = true; expecteds[F8E5M2][F8E5M2] = true; + expecteds[F8E5M2][F8E4M3] = false; expecteds[F8E5M2][F8E4M3FN] = false; expecteds[F8E5M2][F8E4M3B11FNUZ] = false; expecteds[F8E5M2][F8E5M2FNUZ] = false; expecteds[F8E5M2][F8E4M3FNUZ] = false; + expecteds[F8E5M2][F8E3M4] = false; + expecteds[F8E4M3][PRED] = false; + expecteds[F8E4M3][S2] = false; + expecteds[F8E4M3][S4] = false; + expecteds[F8E4M3][S8] = false; + expecteds[F8E4M3][S16] = false; + expecteds[F8E4M3][S32] = false; + expecteds[F8E4M3][S64] = false; + expecteds[F8E4M3][U2] = false; + expecteds[F8E4M3][U4] = false; + expecteds[F8E4M3][U8] = false; + expecteds[F8E4M3][U16] = false; + expecteds[F8E4M3][U32] = false; + expecteds[F8E4M3][U64] = false; + expecteds[F8E4M3][F16] = true; + expecteds[F8E4M3][F32] = true; + expecteds[F8E4M3][F64] = true; + expecteds[F8E4M3][C64] = true; + expecteds[F8E4M3][BF16] = true; + expecteds[F8E4M3][C128] = true; + expecteds[F8E4M3][F8E5M2] = false; + expecteds[F8E4M3][F8E5M2FNUZ] = false; + expecteds[F8E4M3][F8E4M3] = true; + expecteds[F8E4M3][F8E4M3FN] = false; + expecteds[F8E4M3][F8E4M3FNUZ] = false; + expecteds[F8E4M3][F8E4M3B11FNUZ] = false; + expecteds[F8E4M3][F8E3M4] = false; expecteds[F8E4M3FN][PRED] = false; expecteds[F8E4M3FN][S2] = false; expecteds[F8E4M3FN][S4] = false; @@ -562,8 +628,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FN][BF16] = true; expecteds[F8E4M3FN][C128] = true; expecteds[F8E4M3FN][F8E5M2] = false; + expecteds[F8E4M3FN][F8E5M2FNUZ] = false; + expecteds[F8E4M3FN][F8E4M3] = false; expecteds[F8E4M3FN][F8E4M3FN] = true; + expecteds[F8E4M3FN][F8E4M3FNUZ] = false; expecteds[F8E4M3FN][F8E4M3B11FNUZ] = false; + expecteds[F8E4M3FN][F8E3M4] = false; expecteds[F8E4M3B11FNUZ][PRED] = false; expecteds[F8E4M3B11FNUZ][S2] = false; expecteds[F8E4M3B11FNUZ][S4] = false; @@ -584,12 +654,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3B11FNUZ][BF16] = true; expecteds[F8E4M3B11FNUZ][C128] = true; expecteds[F8E4M3B11FNUZ][F8E5M2] = false; + expecteds[F8E4M3B11FNUZ][F8E4M3] = false; expecteds[F8E4M3B11FNUZ][F8E4M3FN] = false; expecteds[F8E4M3B11FNUZ][F8E4M3B11FNUZ] = true; expecteds[F8E4M3B11FNUZ][F8E4M3FNUZ] = false; expecteds[F8E4M3B11FNUZ][F8E5M2FNUZ] = false; - expecteds[F8E4M3FN][F8E5M2FNUZ] = false; - expecteds[F8E4M3FN][F8E4M3FNUZ] = false; + expecteds[F8E4M3B11FNUZ][F8E3M4] = false; expecteds[F8E5M2FNUZ][PRED] = false; expecteds[F8E5M2FNUZ][S2] = false; expecteds[F8E5M2FNUZ][S4] = false; @@ -610,10 +680,12 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E5M2FNUZ][BF16] = true; expecteds[F8E5M2FNUZ][C128] = true; expecteds[F8E5M2FNUZ][F8E5M2] = false; + expecteds[F8E5M2FNUZ][F8E4M3] = false; expecteds[F8E5M2FNUZ][F8E4M3FN] = false; expecteds[F8E5M2FNUZ][F8E4M3B11FNUZ] = false; expecteds[F8E5M2FNUZ][F8E5M2FNUZ] = true; expecteds[F8E5M2FNUZ][F8E4M3FNUZ] = false; + expecteds[F8E5M2FNUZ][F8E3M4] = false; expecteds[F8E4M3FNUZ][PRED] = false; expecteds[F8E4M3FNUZ][S2] = false; expecteds[F8E4M3FNUZ][S4] = false; @@ -634,10 +706,38 @@ TEST(PrimitiveUtilTest, CastPreservesValues) { expecteds[F8E4M3FNUZ][BF16] = true; expecteds[F8E4M3FNUZ][C128] = true; expecteds[F8E4M3FNUZ][F8E5M2] = false; + expecteds[F8E4M3FNUZ][F8E4M3] = false; expecteds[F8E4M3FNUZ][F8E4M3FN] = false; expecteds[F8E4M3FNUZ][F8E4M3B11FNUZ] = false; expecteds[F8E4M3FNUZ][F8E5M2FNUZ] = false; expecteds[F8E4M3FNUZ][F8E4M3FNUZ] = true; + expecteds[F8E4M3FNUZ][F8E3M4] = false; + expecteds[F8E3M4][PRED] = false; + expecteds[F8E3M4][S2] = false; + expecteds[F8E3M4][S4] = false; + expecteds[F8E3M4][S8] = false; + expecteds[F8E3M4][S16] = false; + expecteds[F8E3M4][S32] = false; + expecteds[F8E3M4][S64] = false; + expecteds[F8E3M4][U2] = false; + expecteds[F8E3M4][U4] = false; + expecteds[F8E3M4][U8] = false; + expecteds[F8E3M4][U16] = false; + expecteds[F8E3M4][U32] = false; + expecteds[F8E3M4][U64] = false; + expecteds[F8E3M4][F16] = true; + expecteds[F8E3M4][F32] = true; + expecteds[F8E3M4][F64] = true; + expecteds[F8E3M4][C64] = true; + expecteds[F8E3M4][BF16] = true; + expecteds[F8E3M4][C128] = true; + expecteds[F8E3M4][F8E5M2] = false; + expecteds[F8E3M4][F8E5M2FNUZ] = false; + expecteds[F8E3M4][F8E4M3] = false; + expecteds[F8E3M4][F8E4M3FN] = false; + expecteds[F8E3M4][F8E4M3FNUZ] = false; + expecteds[F8E3M4][F8E4M3B11FNUZ] = false; + expecteds[F8E3M4][F8E3M4] = true; for (int from_type_int = PrimitiveType_MIN; from_type_int < PrimitiveType_ARRAYSIZE; ++from_type_int) { diff --git a/third_party/xla/xla/protobuf_util.h b/third_party/xla/xla/protobuf_util.h index 81f795287c17d8..b763d7ddaeff1c 100644 --- a/third_party/xla/xla/protobuf_util.h +++ b/third_party/xla/xla/protobuf_util.h @@ -56,11 +56,6 @@ class ProtobufHashWrapper { } }; -// Registers a function that may either expand a dirpath or forward the original -// dirpath along as-is. -void RegisterDirectoryExpander( - const std::function& expander); - } // namespace protobuf_util } // namespace xla diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index 011900dbe17b98..09eb19959dd4b6 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -92,6 +92,7 @@ py_strict_library( "@absl_py//absl/logging", "@absl_py//absl/testing:absltest", "@absl_py//absl/testing:parameterized", + "@ml_dtypes", ] + if_google(["//third_party/py/numpy"]), ) @@ -114,6 +115,7 @@ py_strict_test( "@absl_py//absl/logging", "@absl_py//absl/testing:absltest", "@absl_py//absl/testing:parameterized", + "@ml_dtypes", ] + if_google(["//third_party/py/numpy"]) + xla_py_test_deps(), ) @@ -150,6 +152,7 @@ py_strict_test( "@absl_py//absl/logging", "@absl_py//absl/testing:absltest", "@absl_py//absl/testing:parameterized", + "@ml_dtypes", ] + if_google( [ ":xla_gpu_extension", @@ -207,7 +210,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/pjrt:exceptions", "//xla/python/ifrt", - "//xla/python/pjrt_ifrt", + "//xla/python/pjrt_ifrt:pjrt_dtype", "//xla/tsl/python/lib/core:numpy", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", @@ -332,6 +335,7 @@ cc_library( deps = [ ":aggregate_profile", ":callback", + ":guard_lib", ":nb_absl_span", ":nb_class_ptr", ":nb_helpers", @@ -340,7 +344,6 @@ cc_library( ":py_host_callback_proto_cc", ":python_ref_manager", ":traceback", - ":transfer_guard_lib", ":types", ":util", ":xplane_to_profile_instructions", @@ -375,9 +378,9 @@ cc_library( "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/client/lib:arithmetic", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/builder/lib:arithmetic", "//xla/hlo/ir:hlo", "//xla/pjrt:exceptions", "//xla/pjrt:host_callback", @@ -390,7 +393,6 @@ cc_library( "//xla/pjrt:pjrt_executable", "//xla/pjrt:pjrt_future", "//xla/pjrt:pjrt_layout", - "//xla/pjrt:pjrt_stream_executor_client", "//xla/pjrt:status_casters", "//xla/pjrt:transpose", "//xla/pjrt/distributed", @@ -402,6 +404,7 @@ cc_library( "//xla/python/ifrt/hlo:hlo_program", "//xla/python/pjrt_ifrt", "//xla/python/pjrt_ifrt:pjrt_attribute_map_util", + "//xla/python/pjrt_ifrt:pjrt_dtype", "//xla/python/pjrt_ifrt:xla_host_callback_proto_cc", "//xla/python/pjrt_ifrt:xla_ifrt", "//xla/service:computation_placer_hdr", @@ -538,6 +541,7 @@ cc_library( "//xla/python/ifrt", "//xla/python/pjrt_ifrt", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -561,7 +565,7 @@ cc_library( "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = [":friends"], # For the functions to access C++ flags/thread-local variables + visibility = ["//visibility:private"], # For the functions to access C++ flags/thread-local variables deps = [ ":nb_absl_inlined_vector", ":nb_absl_span", @@ -609,7 +613,7 @@ cc_library( deps = [ "//xla:debug_options_flags", "//xla:util", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass_pipeline", "//xla/pjrt:mlir_to_hlo", @@ -676,54 +680,21 @@ cc_library( "@com_google_absl//absl/types:span", "@nanobind", "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/client/lib:approx_topk", - "//xla/client/lib:approx_topk_shape", - "//xla/client/lib:comparators", - "//xla/client/lib:lu_decomposition", - "//xla/client/lib:math", - "//xla/client/lib:qr", - "//xla/client/lib:self_adjoint_eig", - "//xla/client/lib:sorting", - "//xla/client/lib:svd", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/builder/lib:approx_topk", + "//xla/hlo/builder/lib:approx_topk_shape", + "//xla/hlo/builder/lib:comparators", + "//xla/hlo/builder/lib:lu_decomposition", + "//xla/hlo/builder/lib:math", + "//xla/hlo/builder/lib:qr", + "//xla/hlo/builder/lib:self_adjoint_eig", + "//xla/hlo/builder/lib:sorting", + "//xla/hlo/builder/lib:svd", "//xla/pjrt:status_casters", ], ) -cc_library( - name = "outfeed_receiver", - srcs = ["outfeed_receiver.cc"], - hdrs = ["outfeed_receiver.h"], - deps = [ - "//xla:literal", - "//xla:shape_util", - "//xla:util", - "//xla/client:executable_build_options", - "//xla/client:sharding_builder", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/pjrt:pjrt_client", - "//xla/pjrt:pjrt_executable", - "//xla/python/pjrt_ifrt", - "//xla/service:computation_placer_hdr", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:casts", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/profiler/lib:traceme", - ], -) - cc_library( name = "pjit", srcs = ["pjit.cc"], @@ -736,6 +707,7 @@ cc_library( features = ["-use_header_modules"], visibility = ["//visibility:private"], deps = [ + ":guard_lib", ":jax_jit", ":nb_helpers", ":nb_numpy", @@ -743,7 +715,6 @@ cc_library( ":python_ref_manager", ":pytree", ":traceback", - ":transfer_guard_lib", # placeholder for index annotation deps "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", @@ -820,59 +791,6 @@ cc_library( ], ) -xla_cc_test( - name = "outfeed_receiver_test_cpu", - size = "small", - srcs = ["outfeed_receiver_test.cc"], - deps = [ - ":outfeed_receiver", - "//xla:test", - "//xla/client:client_library", - "//xla/client:executable_build_options", - "//xla/client:xla_builder", - "//xla/pjrt:pjrt_client", - "//xla/pjrt:pjrt_stream_executor_client", - "//xla/pjrt/cpu:cpu_client", - "//xla/service:platform_util", - "@com_google_absl//absl/status", - "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "outfeed_receiver_py", - srcs = ["outfeed_receiver_py.cc"], - hdrs = ["outfeed_receiver_py.h"], - compatible_with = [], - copts = [ - "-fexceptions", - "-fno-strict-aliasing", - ], - features = ["-use_header_modules"], - deps = [ - ":nb_class_ptr", - ":outfeed_receiver", - ":py_client", - ":types", - # placeholder for index annotation deps - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/synchronization", - "@llvm-project//llvm:Support", - "@nanobind", - "//xla:literal", - "//xla/client:executable_build_options", - "//xla/client:xla_builder", - "//xla/pjrt:status_casters", - "//xla/python/ifrt", - "//xla/python/pjrt_ifrt", - "@local_tsl//tsl/platform:logging", - ], -) - py_strict_test( name = "pytree_test", srcs = ["pytree_test.py"], @@ -954,7 +872,8 @@ cc_library( "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_ops", "@stablehlo//:stablehlo_serialization", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", "//xla/mlir/utils:error_util", "//xla/mlir_hlo", "//xla/mlir_hlo:all_passes", @@ -962,7 +881,6 @@ cc_library( "//xla/pjrt:status_casters", "//xla/service/llvm_ir:llvm_util", "//xla/service/spmd/shardy/sdy_round_trip:pipelines", - "//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", "//xla/tsl/framework/mlir:status_scoped_diagnostic_handler", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", @@ -1027,11 +945,8 @@ cc_library( deps = [ ":aggregate_profile", ":profiler_utils", - ":py_client", - ":types", ":xplane_to_profile_instructions", # placeholder for index annotation deps - "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@nanobind", @@ -1043,6 +958,7 @@ cc_library( "//xla/pjrt:status_casters", "//xla/pjrt/c:pjrt_c_api_hdrs", "//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs", + "//xla/python/profiler:profile_data_lib", "//xla/tsl/profiler/rpc:profiler_server_impl", "//xla/tsl/profiler/rpc/client:capture_profile", "//xla/tsl/profiler/rpc/client:profiler_client_impl", @@ -1056,19 +972,21 @@ cc_library( ) cc_library( - name = "transfer_guard_lib", - srcs = ["transfer_guard_lib.cc"], - hdrs = ["transfer_guard_lib.h"], + name = "guard_lib", + srcs = ["guard_lib.cc"], + hdrs = ["guard_lib.h"], compatible_with = [], copts = [ "-fexceptions", "-fno-strict-aliasing", ], features = ["-use_header_modules"], - visibility = [":friends"], + visibility = ["//visibility:private"], deps = [ # placeholder for index annotation deps + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@nanobind", "//xla:util", @@ -1106,11 +1024,10 @@ cc_library( features = ["-use_header_modules"], visibility = ["//visibility:private"], deps = [ - ":nb_helpers", # placeholder for index annotation deps "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", - "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/hash", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@nanobind", @@ -1153,28 +1070,28 @@ cc_library( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/client:executable_build_options", - "//xla/client:xla_builder", - "//xla/client:xla_computation", "//xla/ffi", "//xla/ffi:ffi_api", "//xla/ffi/api:c_api", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/transforms:flatten_call_graph", + "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms:tuple_simplifier", "//xla/pjrt:exceptions", "//xla/pjrt:pjrt_executable", "//xla/pjrt:status_casters", "//xla/service:call_inliner", "//xla/service:computation_placer", "//xla/service:custom_call_target_registry", - "//xla/service:flatten_call_graph", - "//xla/service:hlo_dce", "//xla/service:hlo_graph_dumper", "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", "//xla/service:hlo_proto_cc", "//xla/service:name_uniquer", - "//xla/service:tuple_simplifier", "//xla/tsl/lib/strings:proto_serialization", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", @@ -1286,6 +1203,7 @@ tsl_pybind_extension( deps = [ ":custom_call_sharding", ":dlpack", + ":guard_lib", ":jax_jit", ":logging", ":mlir", @@ -1293,7 +1211,6 @@ tsl_pybind_extension( ":nb_absl_span", ":nb_class_ptr", ":ops", - ":outfeed_receiver_py", ":pjit", ":pmap_lib", ":pprof_profile_builder", @@ -1303,7 +1220,6 @@ tsl_pybind_extension( ":pytree", ":refine_polymorphic_shapes", ":traceback", - ":transfer_guard_lib", ":types", ":util", ":weakref_lru_cache", @@ -1393,6 +1309,11 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:hlo_proto_cc", "//xla/tsl/profiler/convert:xla_op_utils", + "//xla/tsl/profiler/utils:file_system_utils", + "//xla/tsl/profiler/utils:tf_xplane_visitor", + "//xla/tsl/profiler/utils:xplane_schema", + "//xla/tsl/profiler/utils:xplane_utils", + "//xla/tsl/profiler/utils:xplane_visitor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -1401,11 +1322,6 @@ cc_library( "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:file_system_utils", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_utils", - "@local_tsl//tsl/profiler/utils:xplane_visitor", ], ) @@ -1418,13 +1334,13 @@ xla_cc_test( "//xla/tests:verified_hlo_module", "//xla/tsl/profiler/convert:xla_op_utils", "//xla/tsl/profiler/rpc/client:save_profile", + "//xla/tsl/profiler/utils:file_system_utils", + "//xla/tsl/profiler/utils:xplane_builder", + "//xla/tsl/profiler/utils:xplane_schema", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", "@local_tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc_impl", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:file_system_utils", - "@local_tsl//tsl/profiler/utils:xplane_builder", - "@local_tsl//tsl/profiler/utils:xplane_schema", ], ) diff --git a/third_party/xla/xla/python/custom_partition_callback.cc b/third_party/xla/xla/python/custom_partition_callback.cc index 21485c96dbe90b..df49dfc1e37bc4 100644 --- a/third_party/xla/xla/python/custom_partition_callback.cc +++ b/third_party/xla/xla/python/custom_partition_callback.cc @@ -31,8 +31,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "xla/client/xla_computation.h" #include "xla/debug_options_flags.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_computation.h" diff --git a/third_party/xla/xla/python/dlpack.cc b/third_party/xla/xla/python/dlpack.cc index 105293aad9470b..a4bf30dbfb73bc 100644 --- a/third_party/xla/xla/python/dlpack.cc +++ b/third_party/xla/xla/python/dlpack.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" @@ -295,6 +296,72 @@ absl::StatusOr DeviceForDLDevice(const PjRtClient* cpu_client, } } +absl::Status VerifyDType(const DLTensor& dl_tensor) { + if (dl_tensor.dtype.bits % 8 != 0) { + return InvalidArgument( + "Unsupported DLPack tensor dtype: bits should be a multiple of 8, got " + "%d", + dl_tensor.dtype.bits); + } + + if (dl_tensor.dtype.lanes != 1) { + return InvalidArgument( + "Unsupported DLPack tensor dtype: lanes should be equal to 1, got %d", + dl_tensor.dtype.lanes); + } + + return absl::OkStatus(); +} + +absl::StatusOr> GetByteStrides(const DLTensor& dl_tensor) { + TF_RETURN_IF_ERROR(VerifyDType(dl_tensor)); + + // Convert element strides from the number of elements to the number of bytes. + std::vector strides; + strides.reserve(dl_tensor.ndim); + for (int i = 0; i < dl_tensor.ndim; ++i) { + strides.push_back(dl_tensor.strides[i] * dl_tensor.dtype.bits / 8); + } + return strides; +} + +absl::StatusOr> MakePjrtBuffer( + PjRtDevice& device, ::DLManagedTensor* dlmt, const Shape& shape, + PrimitiveType element_type, absl::Span dimensions, + std::optional stream = std::nullopt) { + std::function on_delete_callback; + if (dlmt->deleter) { + on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); }; + } + + // First try to create a view. + void* data = + static_cast(dlmt->dl_tensor.data) + dlmt->dl_tensor.byte_offset; + auto result = device.client()->CreateViewOfDeviceBuffer( + data, shape, &device, on_delete_callback, stream); + + // If that fails with invalid argument, it's possibly because of the incorrect + // alignment. If we're on CPU, we can create a copy of buffer. + if (result.status().code() == absl::StatusCode::kInvalidArgument && + dlmt->dl_tensor.device.device_type == kDLCPU) { + LOG(WARNING) << "DLPack buffer is not aligned (data at: " << data + << "). Creating a copy."; + + // Convert tensor strides (expressed in number of elements) to byte strides. + std::optional> byte_strides; + if (dlmt->dl_tensor.strides) { + TF_ASSIGN_OR_RETURN(byte_strides, GetByteStrides(dlmt->dl_tensor)); + } + + // Create a copy. + result = device.client()->BufferFromHostBuffer( + data, element_type, dimensions, byte_strides, + PjRtClient::HostBufferSemantics::kMutableZeroCopy, on_delete_callback, + &device); + } + return result; +} + } // namespace absl::StatusOr BufferToDLPackManagedTensor( @@ -444,11 +511,11 @@ absl::StatusOr DLPackManagedTensorToBuffer( if (dlmt->deleter) { on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); }; } - TF_ASSIGN_OR_RETURN(auto pjrt_buffer, - device->client()->CreateViewOfDeviceBuffer( - static_cast(dlmt->dl_tensor.data) + - dlmt->dl_tensor.byte_offset, - shape, device, on_delete_callback)); + + TF_ASSIGN_OR_RETURN( + auto pjrt_buffer, + MakePjrtBuffer(*device, dlmt, shape, element_type, dimensions)); + // We have taken ownership of the array inside the capsule; make sure the // capsule it cannot be used again. PyCapsule_SetName(tensor.ptr(), "used_dltensor"); @@ -515,16 +582,10 @@ absl::StatusOr DLPackManagedTensorToBuffer( Shape shape = ShapeUtil::MakeShapeWithDenseLayout(element_type, dimensions, minor_to_major); - std::function on_delete_callback; - if (dlmt->deleter) { - on_delete_callback = [dlmt]() { dlmt->deleter(dlmt); }; - } - TF_ASSIGN_OR_RETURN( - auto pjrt_buffer, - device->pjrt_device()->client()->CreateViewOfDeviceBuffer( - static_cast(dlmt->dl_tensor.data) + - dlmt->dl_tensor.byte_offset, - shape, device->pjrt_device(), on_delete_callback, stream)); + TF_ASSIGN_OR_RETURN(auto pjrt_buffer, + MakePjrtBuffer(*device->pjrt_device(), dlmt, shape, + element_type, dimensions, stream)); + // We have taken ownership of the array inside the capsule; make sure the // capsule it cannot be used again. PyCapsule_SetName(tensor.ptr(), "used_dltensor"); diff --git a/third_party/xla/xla/python/gpu_support.cc b/third_party/xla/xla/python/gpu_support.cc index 541627e0014122..060329a033498a 100644 --- a/third_party/xla/xla/python/gpu_support.cc +++ b/third_party/xla/xla/python/gpu_support.cc @@ -64,7 +64,9 @@ void RegisterGpuClientAndDefineGpuAllocatorConfig(nanobind::module_& m_nb) { int node_id, int num_nodes, std::optional> allowed_devices, std::optional platform_name, - std::optional mock = false) -> nb_class_ptr { + std::optional mock = false, + std::optional mock_gpu_topology = + "") -> nb_class_ptr { std::unique_ptr ifrt_client; { nb::gil_scoped_release gil_release; @@ -81,6 +83,7 @@ void RegisterGpuClientAndDefineGpuAllocatorConfig(nanobind::module_& m_nb) { options.platform_name = platform_name; options.kv_store = kv_store; options.enable_mock_nccl = mock.value_or(false); + options.mock_gpu_topology = mock_gpu_topology; std::unique_ptr pjrt_client = xla::ValueOrThrow(GetStreamExecutorGpuClient(options)); ifrt_client = ifrt::PjRtClient::Create(std::move(pjrt_client)); @@ -93,7 +96,8 @@ void RegisterGpuClientAndDefineGpuAllocatorConfig(nanobind::module_& m_nb) { nb::arg("num_nodes") = 1, nb::arg("allowed_devices").none() = std::nullopt, nb::arg("platform_name").none() = std::nullopt, - nb::arg("mock").none() = std::nullopt); + nb::arg("mock").none() = std::nullopt, + nb::arg("mock_gpu_topology").none() = std::nullopt); } } // namespace xla diff --git a/third_party/xla/xla/python/transfer_guard_lib.cc b/third_party/xla/xla/python/guard_lib.cc similarity index 74% rename from third_party/xla/xla/python/transfer_guard_lib.cc rename to third_party/xla/xla/python/guard_lib.cc index 9cf93c88ed5ca2..6dec6e4e2490a3 100644 --- a/third_party/xla/xla/python/transfer_guard_lib.cc +++ b/third_party/xla/xla/python/guard_lib.cc @@ -1,4 +1,4 @@ -/* Copyright 2022 The OpenXLA Authors. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,15 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This files implements the configuration management for transfer guards. -// C++ backends responsible for enforcing transfer guard levels. +// This files implements the configuration management for different types of +// guards. +// C++ backends are responsible for enforcing transfer guard levels. -#include "xla/python/transfer_guard_lib.h" +#include "xla/python/guard_lib.h" #include #include +#include "absl/base/attributes.h" #include "absl/functional/function_ref.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "nanobind/nanobind.h" #include "nanobind/stl/optional.h" // IWYU pragma: keep @@ -34,13 +37,17 @@ namespace nb = ::nanobind; namespace { // Protected by the GIL. -TransferGuardState& global_state = *new TransferGuardState(); +GuardState& global_state = *new GuardState(); -ABSL_CONST_INIT thread_local TransferGuardState thread_local_state; +ABSL_CONST_INIT thread_local GuardState thread_local_state; // The default transfer guard level. constexpr TransferGuardLevel kDefaultGuardLevel = TransferGuardLevel::kAllow; +// The default garbage collection guard level. +constexpr GarbageCollectionGuardLevel kDefaultGarbageCollectionGuardLevel = + GarbageCollectionGuardLevel::kAllow; + // Returns the transfer guard action for a transfer. TransferGuardAction GetTransferGuardAction(TransferGuardLevel guard_level, bool explicit_transfer) { @@ -144,33 +151,45 @@ absl::Status ApplyTransferGuardToDeviceToHost( return absl::OkStatus(); } -void BuildTransferGuardSubmodule(nb::module_& m) { - nb::module_ tglib = m.def_submodule("transfer_guard_lib", - "Jax transfer guard support library"); +GarbageCollectionGuardLevel GetGarbageCollectArrayGuard() { + return thread_local_state.garbage_collect_array.value_or( + global_state.garbage_collect_array.value_or( + kDefaultGarbageCollectionGuardLevel)); +} + +void BuildGuardSubmodule(nb::module_& m) { + nb::module_ glib = + m.def_submodule("guard_lib", "Jax support library for guards"); - nb::enum_ tglevel(tglib, "TransferGuardLevel"); + nb::enum_ tglevel(glib, "TransferGuardLevel"); tglevel.value("ALLOW", TransferGuardLevel::kAllow); tglevel.value("LOG", TransferGuardLevel::kLog); tglevel.value("DISALLOW", TransferGuardLevel::kDisallow); tglevel.value("LOG_EXPLICIT", TransferGuardLevel::kLogExplicit); tglevel.value("DISALLOW_EXPLICIT", TransferGuardLevel::kDisallowExplicit); - nb::class_ tgstate(tglib, "TransferGuardState"); - tgstate.def_rw("host_to_device", &TransferGuardState::host_to_device, + nb::enum_ gcglevel( + glib, "GarbageCollectionGuardLevel"); + gcglevel.value("ALLOW", GarbageCollectionGuardLevel::kAllow); + gcglevel.value("LOG", GarbageCollectionGuardLevel::kLog); + gcglevel.value("FATAL", GarbageCollectionGuardLevel::kFatal); + + nb::class_ tgstate(glib, "GuardState"); + tgstate.def_rw("host_to_device", &GuardState::host_to_device, + nb::arg().none()); + tgstate.def_rw("device_to_device", &GuardState::device_to_device, nb::arg().none()); - tgstate.def_rw("device_to_device", &TransferGuardState::device_to_device, + tgstate.def_rw("device_to_host", &GuardState::device_to_host, nb::arg().none()); - tgstate.def_rw("device_to_host", &TransferGuardState::device_to_host, + tgstate.def_rw("explicit_device_put", &GuardState::explicit_device_put); + tgstate.def_rw("explicit_device_get", &GuardState::explicit_device_get); + tgstate.def_rw("garbage_collect_array", &GuardState::garbage_collect_array, nb::arg().none()); - tgstate.def_rw("explicit_device_put", - &TransferGuardState::explicit_device_put); - tgstate.def_rw("explicit_device_get", - &TransferGuardState::explicit_device_get); - tglib.def( + glib.def( "global_state", [&]() { return &global_state; }, nb::rv_policy::reference); - tglib.def( + glib.def( "thread_local_state", [&]() { return &thread_local_state; }, nb::rv_policy::reference); } diff --git a/third_party/xla/xla/python/transfer_guard_lib.h b/third_party/xla/xla/python/guard_lib.h similarity index 79% rename from third_party/xla/xla/python/transfer_guard_lib.h rename to third_party/xla/xla/python/guard_lib.h index 1e5f6ffca3753d..1ff668ff30ee91 100644 --- a/third_party/xla/xla/python/transfer_guard_lib.h +++ b/third_party/xla/xla/python/guard_lib.h @@ -1,4 +1,4 @@ -/* Copyright 2022 The OpenXLA Authors. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_PYTHON_TRANSFER_GUARD_LIB_H_ -#define XLA_PYTHON_TRANSFER_GUARD_LIB_H_ +#ifndef XLA_PYTHON_GUARD_LIB_H_ +#define XLA_PYTHON_GUARD_LIB_H_ #include #include @@ -45,7 +45,17 @@ enum class TransferGuardLevel { kDisallowExplicit, }; -// Flags for transfer guard levels are controlled by: +// Garbage collection guard level chose by the user code. +enum class GarbageCollectionGuardLevel { + // Silently allow the object to be garbage collected. + kAllow, + // Log and allow the object to be garbage collected. + kLog, + // Fatal crash on object garbage collection. + kFatal, +}; + +// Flags for guard levels are controlled by: // - a global flag value, // e.g., associated to --jax_transfer_guard_device_to_host // which defaults to TransferGuardLevel::kAllow. @@ -54,12 +64,14 @@ enum class TransferGuardLevel { // implement context managers that locally override the global state. // // Explicit device_put/device_get contexts are tracked by context managers. -struct TransferGuardState { +struct GuardState { std::optional host_to_device; std::optional device_to_device; std::optional device_to_host; bool explicit_device_put = false; bool explicit_device_get = false; + + std::optional garbage_collect_array; }; // Resulting action for a transfer given the transfer guard level and the @@ -91,9 +103,13 @@ absl::Status ApplyTransferGuardToDeviceToDevice( absl::Status ApplyTransferGuardToDeviceToHost( absl::FunctionRef formatter); +// Returns the garbage collection guard level for "jax.Array" objects. +// REQUIRES: Python GIL. +GarbageCollectionGuardLevel GetGarbageCollectArrayGuard(); + // The function to call in `xla.cc` to add the bindings for this module. -void BuildTransferGuardSubmodule(nanobind::module_& m); +void BuildGuardSubmodule(nanobind::module_& m); } // namespace jax -#endif // XLA_PYTHON_TRANSFER_GUARD_LIB_H_ +#endif // XLA_PYTHON_GUARD_LIB_H_ diff --git a/third_party/xla/xla/python/ifrt/BUILD b/third_party/xla/xla/python/ifrt/BUILD index 6531018329ad3c..20bffc62af72e2 100644 --- a/third_party/xla/xla/python/ifrt/BUILD +++ b/third_party/xla/xla/python/ifrt/BUILD @@ -87,6 +87,7 @@ cc_library( ":attribute_map", ":device_proto_cc", ":dtype_proto_cc", + ":execute_options_proto_cc", ":remap_plan_proto_cc", ":serdes", ":shape_proto_cc", @@ -105,8 +106,11 @@ cc_library( "//xla/python/ifrt/ir:sharding_param", "//xla/service:computation_placer_hdr", "//xla/tsl/concurrency:ref_count", + "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/functional:function_ref", @@ -120,7 +124,6 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", - "@local_tsl//tsl/lib/gtl:int_type", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", @@ -134,6 +137,7 @@ xla_cc_test( deps = [ ":ifrt", ":mock", + "//xla/tsl/concurrency:ref_count", "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", ], @@ -164,6 +168,11 @@ xla_cc_test( ], ) +tf_proto_library( + name = "execute_options_proto", + srcs = ["execute_options.proto"], +) + xla_cc_test( name = "future_test", size = "small", @@ -336,6 +345,7 @@ cc_library( deps = [ ":ifrt", ":test_util", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], alwayslink = True, @@ -517,15 +527,15 @@ tf_proto_library( ) xla_cc_test( - name = "device_test", + name = "device_list_test", size = "small", - srcs = ["device_test.cc"], + srcs = ["device_list_test.cc"], deps = [ ":device_proto_cc", ":device_test_util", ":ifrt", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:platform_port", @@ -650,10 +660,10 @@ xla_cc_test( ":serdes", ":serdes_proto_cc", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/protobuf:error_codes_proto_impl_cc", + "//xla/tsl/protobuf:status_proto_cc", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", - "@local_tsl//tsl/protobuf:status_proto_cc", ], ) @@ -713,6 +723,7 @@ xla_cc_test( ":ifrt", ":program_serdes", ":serdes", + ":serdes_proto_cc", "//xla/tsl/concurrency:ref_count", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/functional:bind_front", diff --git a/third_party/xla/xla/python/ifrt/array.cc b/third_party/xla/xla/python/ifrt/array.cc index aa44b469272109..fd61b4aecdf55c 100644 --- a/third_party/xla/xla/python/ifrt/array.cc +++ b/third_party/xla/xla/python/ifrt/array.cc @@ -18,6 +18,9 @@ limitations under the License. #include #include +#include "absl/types/span.h" +#include "xla/tsl/concurrency/ref_count.h" + namespace xla { namespace ifrt { diff --git a/third_party/xla/xla/python/ifrt/array.h b/third_party/xla/xla/python/ifrt/array.h index e080cceeec31c0..a83ff3d0e6b693 100644 --- a/third_party/xla/xla/python/ifrt/array.h +++ b/third_party/xla/xla/python/ifrt/array.h @@ -19,12 +19,11 @@ limitations under the License. #include #include #include -#include -#include #include #include "absl/base/attributes.h" -#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" #include "llvm/Support/ExtensibleRTTI.h" #include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/dtype.h" @@ -39,8 +38,6 @@ namespace ifrt { class Client; -using Layout = ::xla::PjRtLayout; - // Semantics for operations that may copy or move sharded buffers in an array. enum class ArrayCopySemantics : int { // Always creates new buffers to construct an output array. Mutation of the @@ -81,8 +78,15 @@ class Array : public llvm::RTTIExtends { // Breaks an array up into per-device arrays. This is the elimination // counterpart of `Client::AssembleArrayFromSingleDeviceArrays()`. + // TODO(hyeontaek): Replace this API with the version that takes + // `SingleDeviceShardSemantics`. virtual absl::StatusOr>> DisassembleIntoSingleDeviceArrays(ArrayCopySemantics semantics) = 0; + virtual absl::StatusOr>> + DisassembleIntoSingleDeviceArrays( + ArrayCopySemantics array_copy_semantics, + SingleDeviceShardSemantics single_device_shard_semantics) = 0; + // Returns a shard of an Array which is fully replicated. This is an // optimization so that instead of disassembling into all the shards when // the Array is fully replicated, we can just get 1 shard out and create an diff --git a/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc b/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc index 4e079592672dfe..b8ef7caed58dec 100644 --- a/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc +++ b/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc @@ -330,7 +330,8 @@ TEST(ArrayImplTest, MakeArrayFromHostBufferReplicated) { TF_ASSERT_OK_AND_ASSIGN(auto single_device_arrays, array->DisassembleIntoSingleDeviceArrays( - ArrayCopySemantics::kAlwaysCopy)); + ArrayCopySemantics::kAlwaysCopy, + SingleDeviceShardSemantics::kAddressableShards)); ASSERT_EQ(single_device_arrays.size(), devices.size()); for (int i = 0; i < single_device_arrays.size(); ++i) { EXPECT_THAT(single_device_arrays[i]->sharding().devices()->devices(), @@ -383,7 +384,8 @@ TEST(ArrayImplTest, AssembleArray) { auto assembled_array, client->AssembleArrayFromSingleDeviceArrays( assembled_shape, assembled_sharding, absl::MakeSpan(arrays), - ArrayCopySemantics::kAlwaysCopy)); + ArrayCopySemantics::kAlwaysCopy, + SingleDeviceShardSemantics::kAddressableShards)); EXPECT_EQ(assembled_array->dtype(), dtype); EXPECT_EQ(assembled_array->shape(), assembled_shape); @@ -438,11 +440,14 @@ TEST(ArrayImplTest, AssembleAndDisassembleArray) { auto assembled_array, client->AssembleArrayFromSingleDeviceArrays( assembled_shape, assembled_sharding, absl::MakeSpan(arrays), - ArrayCopySemantics::kAlwaysCopy)); + ArrayCopySemantics::kAlwaysCopy, + SingleDeviceShardSemantics::kAddressableShards)); - TF_ASSERT_OK_AND_ASSIGN(auto single_device_arrays, - assembled_array->DisassembleIntoSingleDeviceArrays( - ArrayCopySemantics::kAlwaysCopy)); + TF_ASSERT_OK_AND_ASSIGN( + auto single_device_arrays, + assembled_array->DisassembleIntoSingleDeviceArrays( + ArrayCopySemantics::kAlwaysCopy, + SingleDeviceShardSemantics::kAddressableShards)); ASSERT_THAT(single_device_arrays, SizeIs(2)); EXPECT_EQ(single_device_arrays[0]->dtype(), array0->dtype()); @@ -479,7 +484,8 @@ TEST(ArrayImplTest, AssembleAndDisassembleSingleDeviceArray) { TF_ASSERT_OK_AND_ASSIGN(auto assembled_array, client->AssembleArrayFromSingleDeviceArrays( shape, sharding, absl::MakeSpan(arrays), - ArrayCopySemantics::kAlwaysCopy)); + ArrayCopySemantics::kAlwaysCopy, + SingleDeviceShardSemantics::kAddressableShards)); ASSERT_EQ(assembled_array->dtype(), array->dtype()); ASSERT_EQ(assembled_array->shape(), array->shape()); @@ -488,7 +494,8 @@ TEST(ArrayImplTest, AssembleAndDisassembleSingleDeviceArray) { TF_ASSERT_OK_AND_ASSIGN(auto single_device_arrays, assembled_array->DisassembleIntoSingleDeviceArrays( - ArrayCopySemantics::kAlwaysCopy)); + ArrayCopySemantics::kAlwaysCopy, + SingleDeviceShardSemantics::kAddressableShards)); ASSERT_THAT(single_device_arrays, SizeIs(1)); ASSERT_EQ(single_device_arrays[0]->dtype(), array->dtype()); @@ -557,18 +564,22 @@ TEST(ArrayImplTest, CopyToDifferentDevice) { std::vector shapes(shards.size(), shape); std::shared_ptr sharding = ConcreteSharding::Create(devices, MemoryKind(), shape, shapes); - TF_ASSERT_OK_AND_ASSIGN(arrays.emplace_back(), - client->AssembleArrayFromSingleDeviceArrays( - shape, sharding, absl::MakeSpan(shards), - ArrayCopySemantics::kAlwaysCopy)); + TF_ASSERT_OK_AND_ASSIGN( + arrays.emplace_back(), + client->AssembleArrayFromSingleDeviceArrays( + shape, sharding, absl::MakeSpan(shards), + ArrayCopySemantics::kAlwaysCopy, + SingleDeviceShardSemantics::kAddressableShards)); } { std::shared_ptr sharding = ConcreteEvenSharding::Create(devices, MemoryKind(), shape, shape); - TF_ASSERT_OK_AND_ASSIGN(arrays.emplace_back(), - client->AssembleArrayFromSingleDeviceArrays( - shape, sharding, absl::MakeSpan(shards), - ArrayCopySemantics::kAlwaysCopy)); + TF_ASSERT_OK_AND_ASSIGN( + arrays.emplace_back(), + client->AssembleArrayFromSingleDeviceArrays( + shape, sharding, absl::MakeSpan(shards), + ArrayCopySemantics::kAlwaysCopy, + SingleDeviceShardSemantics::kAddressableShards)); } BasicDeviceList::Devices new_devices; @@ -589,9 +600,10 @@ TEST(ArrayImplTest, CopyToDifferentDevice) { BasicDeviceList::Create(new_devices), MemoryKind())); EXPECT_EQ(new_arrays[i]->sharding(), *expected_sharding); - TF_ASSERT_OK_AND_ASSIGN(auto shards, - arrays[i]->DisassembleIntoSingleDeviceArrays( - ArrayCopySemantics::kAlwaysCopy)); + TF_ASSERT_OK_AND_ASSIGN( + auto shards, arrays[i]->DisassembleIntoSingleDeviceArrays( + ArrayCopySemantics::kAlwaysCopy, + SingleDeviceShardSemantics::kAddressableShards)); for (const auto& shard : shards) { std::vector out_data(6); auto future = shard->CopyToHostBuffer(out_data.data(), diff --git a/third_party/xla/xla/python/ifrt/array_spec.cc b/third_party/xla/xla/python/ifrt/array_spec.cc index 828bca230ab5de..b8b8d5b1f872dd 100644 --- a/third_party/xla/xla/python/ifrt/array_spec.cc +++ b/third_party/xla/xla/python/ifrt/array_spec.cc @@ -21,7 +21,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "xla/python/ifrt/array_spec.pb.h" -#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/dtype.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" diff --git a/third_party/xla/xla/python/ifrt/array_spec.h b/third_party/xla/xla/python/ifrt/array_spec.h index b5e2ba3395f87f..d6497ac15b1e31 100644 --- a/third_party/xla/xla/python/ifrt/array_spec.h +++ b/third_party/xla/xla/python/ifrt/array_spec.h @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/python/ifrt/array_spec.pb.h" #include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/dtype.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" @@ -47,7 +48,13 @@ struct ArraySpec { // Returns a `ArraySpecProto` representation. absl::StatusOr ToProto() const; + // TODO(hyeontaek): Remove this method in favor of AbslStringify. std::string DebugString() const; + + template + friend void AbslStringify(Sink& sink, const ArraySpec& array_spec) { + sink.Append(array_spec.DebugString()); + } }; } // namespace ifrt diff --git a/third_party/xla/xla/python/ifrt/array_test.cc b/third_party/xla/xla/python/ifrt/array_test.cc index ec94659eb7fe56..5c639b4469f29e 100644 --- a/third_party/xla/xla/python/ifrt/array_test.cc +++ b/third_party/xla/xla/python/ifrt/array_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include -#include "llvm/Support/ExtensibleRTTI.h" #include "xla/python/ifrt/mock.h" +#include "xla/tsl/concurrency/ref_count.h" namespace xla { namespace ifrt { diff --git a/third_party/xla/xla/python/ifrt/attribute_map.h b/third_party/xla/xla/python/ifrt/attribute_map.h index 630933d9fce6bb..69a567635eb19d 100644 --- a/third_party/xla/xla/python/ifrt/attribute_map.h +++ b/third_party/xla/xla/python/ifrt/attribute_map.h @@ -93,6 +93,11 @@ class AttributeMap { std::string DebugString(size_t max_string_length = 64, size_t max_int64_list_size = 16) const; + template + friend void AbslStringify(Sink& sink, const AttributeMap& attribute_map) { + sink.Append(attribute_map.DebugString()); + } + private: Map map_; }; diff --git a/third_party/xla/xla/python/ifrt/client.h b/third_party/xla/xla/python/ifrt/client.h index 5b48bedd713a6a..b0292dc6f6328e 100644 --- a/third_party/xla/xla/python/ifrt/client.h +++ b/third_party/xla/xla/python/ifrt/client.h @@ -113,11 +113,19 @@ class Client : public llvm::RTTIExtends { std::function on_done_with_host_buffer) = 0; // Builds a larger array out of individual per-device shards. + // TODO(hyeontaek): Replace this API with the version that takes + // `SingleDeviceShardSemantics`. virtual absl::StatusOr> AssembleArrayFromSingleDeviceArrays( Shape shape, std::shared_ptr sharding, absl::Span> arrays, ArrayCopySemantics semantics) = 0; + virtual absl::StatusOr> + AssembleArrayFromSingleDeviceArrays( + Shape shape, std::shared_ptr sharding, + absl::Span> arrays, + ArrayCopySemantics array_copy_semantics, + SingleDeviceShardSemantics single_device_shard_semantics) = 0; // Copies the arrays to a new set of devices. // @@ -177,13 +185,14 @@ class Client : public llvm::RTTIExtends { virtual absl::StatusOr> MakeTuple( absl::Span> values) = 0; + // Identifies the IFRT implementation. Most C++ users should use LLVM RTTI to + // determine the runtime type. This is a string exposed to users mostly for + // informational reasons. + virtual absl::string_view runtime_type() const = 0; + // The following APIs are taken from `xla::PjRtClient` for fast prototyping. // Most of the APIs will be factored out as a `Platform`/`Topology` in the // future to facilitate topology discovery and ahead-of-time compilation. - - // TODO(hyeontaek): Remove runtime_type() in favor of LLVM RTTI. - virtual absl::string_view runtime_type() const = 0; - // TODO(hyeontaek): Factor them out to a `Platform`/`Topology` class. virtual absl::string_view platform_name() const = 0; virtual absl::string_view platform_version() const = 0; @@ -205,6 +214,11 @@ class Client : public llvm::RTTIExtends { virtual absl::Span addressable_devices() const = 0; virtual int process_index() const = 0; + // Returns all devices. The result includes primary devices that are included + // in `devices()` as well as any other devices that are associated with + // the primary devices. + virtual absl::Span GetAllDevices() const = 0; + // TODO(hyeontaek): Consider removing this API. This API is potentially not // being used by JAX or will be replaced with explicit device assignment. virtual absl::StatusOr GetDefaultDeviceAssignment( @@ -225,8 +239,9 @@ class Client : public llvm::RTTIExtends { // single-shard dimensions `dims`. // TODO(hyeontaek): Change the API to take `Shape` and `Sharding` instead of // single-shard dimensions and device. - virtual absl::StatusOr> GetDefaultLayoutForDevice( - DType dtype, absl::Span dims, Device* device) const = 0; + virtual absl::StatusOr> + GetDefaultLayoutForDevice(DType dtype, absl::Span dims, + Device* device) const = 0; static char ID; // NOLINT }; diff --git a/third_party/xla/xla/python/ifrt/client_impl_test_lib.cc b/third_party/xla/xla/python/ifrt/client_impl_test_lib.cc index 6a0f7e2cd7e27c..38edc163d0204d 100644 --- a/third_party/xla/xla/python/ifrt/client_impl_test_lib.cc +++ b/third_party/xla/xla/python/ifrt/client_impl_test_lib.cc @@ -13,8 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" #include "xla/python/ifrt/test_util.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { @@ -54,6 +55,18 @@ TEST(ClientImplTest, Devices) { EXPECT_GE(client->process_index(), 0); } +TEST(ClientImplTest, GetAllDevices) { + TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); + + EXPECT_GE(client->GetAllDevices().size(), client->device_count()); + + for (Device* device : client->GetAllDevices()) { + TF_ASSERT_OK_AND_ASSIGN(auto* looked_up_device, + client->LookupDevice(device->Id())); + EXPECT_EQ(device, looked_up_device); + } +} + TEST(ClientImplTest, DefaultCompiler) { TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); EXPECT_THAT(client->GetDefaultCompiler(), NotNull()); diff --git a/third_party/xla/xla/python/ifrt/custom_call_program_serdes_test.cc b/third_party/xla/xla/python/ifrt/custom_call_program_serdes_test.cc index 4ea2a4f2ae8935..d8e6ba91b39c64 100644 --- a/third_party/xla/xla/python/ifrt/custom_call_program_serdes_test.cc +++ b/third_party/xla/xla/python/ifrt/custom_call_program_serdes_test.cc @@ -31,6 +31,7 @@ limitations under the License. #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/program_serdes.h" #include "xla/python/ifrt/serdes.h" +#include "xla/python/ifrt/serdes.pb.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" #include "xla/tsl/concurrency/ref_count.h" diff --git a/third_party/xla/xla/python/ifrt/device.h b/third_party/xla/xla/python/ifrt/device.h index 3da8fe57c5a772..c20b40008d1941 100644 --- a/third_party/xla/xla/python/ifrt/device.h +++ b/third_party/xla/xla/python/ifrt/device.h @@ -24,7 +24,7 @@ limitations under the License. #include "llvm/Support/ExtensibleRTTI.h" #include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/device.pb.h" -#include "tsl/lib/gtl/int_type.h" +#include "xla/tsl/lib/gtl/int_type.h" namespace xla { namespace ifrt { @@ -68,6 +68,8 @@ class Device : public llvm::RTTIExtends { // Debug string suitable for logging when errors occur. Should be verbose // enough to describe the current device unambiguously. + // + // TODO(hyeontaek): Remove this method in favor of AbslStringify. virtual absl::string_view DebugString() const = 0; // Returns the default memory space attached to this device. @@ -87,6 +89,20 @@ class Device : public llvm::RTTIExtends { // process_index as the client. virtual int ProcessIndex() const = 0; + template + friend void AbslStringify(Sink& sink, const Device& device) { + sink.Append(device.DebugString()); + } + + template + friend void AbslStringify(Sink& sink, const Device* device) { + if (device == nullptr) { + sink.Append(""); + } else { + sink.Append(device->DebugString()); + } + } + static char ID; // NOLINT }; diff --git a/third_party/xla/xla/python/ifrt/device_list.cc b/third_party/xla/xla/python/ifrt/device_list.cc index fdd58588cf6516..35b37b5ec1a1dd 100644 --- a/third_party/xla/xla/python/ifrt/device_list.cc +++ b/third_party/xla/xla/python/ifrt/device_list.cc @@ -17,16 +17,17 @@ limitations under the License. #include #include -#include #include #include #include +#include "absl/base/call_once.h" #include "absl/base/optimization.h" #include "absl/hash/hash.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device.pb.h" #include "xla/tsl/concurrency/ref_count.h" @@ -64,12 +65,32 @@ tsl::RCReference BasicDeviceList::Create(Devices devices) { return tsl::MakeRef(std::move(devices)); } -BasicDeviceList::BasicDeviceList(Devices devices) : hash_(kUnsetHash) { - if (devices.size() <= kInlineDeviceSize) { - state_ = State{std::move(devices)}; - } else { - state_ = std::make_shared(State{std::move(devices)}); - } +BasicDeviceList::BasicDeviceList(Devices devices) + : devices_(std::move(devices)), hash_(kUnsetHash) {} + +DeviceList* BasicDeviceList::AddressableDeviceList() const { + absl::call_once(addressable_device_list_cache_.once_flag, [this] { + Devices addressable_devices; + for (Device* device : devices_) { + if (device->IsAddressable()) { + addressable_devices.push_back(device); + } + } + const bool already_fully_addressable = + addressable_devices.size() == devices_.size(); + if (already_fully_addressable) { + // `device_list_holder` is intentionally unset. We skip storing a + // reference-counted copy in the holder to avoid creating a self cycle. + addressable_device_list_cache_.device_list = + const_cast(this); + } else { + addressable_device_list_cache_.device_list_holder = + BasicDeviceList::Create(std::move(addressable_devices)); + addressable_device_list_cache_.device_list = + addressable_device_list_cache_.device_list_holder.get(); + } + }); + return addressable_device_list_cache_.device_list; } uint64_t BasicDeviceList::hash() const { @@ -86,7 +107,7 @@ uint64_t BasicDeviceList::hash() const { std::string BasicDeviceList::ToString() const { return absl::StrCat("BasicDeviceList([", - absl::StrJoin(state().devices, ",", + absl::StrJoin(devices_, ",", [](std::string* out, Device* device) { absl::StrAppend(out, device->DebugString()); diff --git a/third_party/xla/xla/python/ifrt/device_list.h b/third_party/xla/xla/python/ifrt/device_list.h index b34522f9b75686..b10dad716e76eb 100644 --- a/third_party/xla/xla/python/ifrt/device_list.h +++ b/third_party/xla/xla/python/ifrt/device_list.h @@ -18,12 +18,10 @@ limitations under the License. #include #include -#include #include -#include -#include #include +#include "absl/base/call_once.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" #include "absl/log/check.h" @@ -72,6 +70,16 @@ class DeviceList : public tsl::ReferenceCounted, // Returns a list of `Devices*` represented by this `DeviceList`. virtual absl::Span devices() const = 0; + // Returns a `DeviceList*` containing only addressable devices from this + // `DeviceList`. It returns itself if all devices are addressable. It points + // to a heap-allocated object; the pointer is valid at least until this + // `DeviceList` is destroyed, and it can be persisted beyond this + // `DeviceList`'s lifetime by using `tsl::FormRef()`. + virtual DeviceList* AddressableDeviceList() const = 0; + + // Returns true if all devices are addressable. + bool IsFullyAddressable() const { return AddressableDeviceList() == this; } + virtual bool operator==(const DeviceList& other) const = 0; bool operator!=(const DeviceList& other) const { return !(*this == other); } @@ -80,6 +88,16 @@ class DeviceList : public tsl::ReferenceCounted, sink.Append(device_list.ToString()); } + template + friend void AbslStringify(Sink& sink, + const tsl::RCReference& device_list) { + if (device_list == nullptr) { + sink.Append(""); + } else { + sink.Append(device_list->ToString()); + } + } + // Returns the hash of devices. This hash is stable only within the process. virtual uint64_t hash() const = 0; @@ -124,22 +142,20 @@ class BasicDeviceList : public llvm::RTTIExtends { // Returns a `DeviceListProto` representation. DeviceListProto ToProto() const; - absl::Span devices() const override { return state().devices; } + absl::Span devices() const override { return devices_; } + + DeviceList* AddressableDeviceList() const override; bool operator==(const DeviceList& other) const override { + if (this == &other) { + return true; + } const auto* other_basic_device_list = llvm::dyn_cast(&other); if (other_basic_device_list == nullptr) { return false; } - const std::shared_ptr* lhs = - std::get_if>(&state_); - const std::shared_ptr* rhs = - std::get_if>(&other_basic_device_list->state_); - if (lhs != nullptr && rhs != nullptr && lhs->get() == rhs->get()) { - return true; - } - return devices() == other.devices(); + return devices_ == other_basic_device_list->devices_; } uint64_t hash() const override; @@ -152,40 +168,17 @@ class BasicDeviceList : public llvm::RTTIExtends { template friend tsl::RCReference tsl::MakeRef(Args&&... args); - // Internal state that may be shared across `DeviceList` instances. - struct State { - Devices devices; - }; - - State& state() { - return std::visit( - [](auto& state) -> State& { - using T = std::decay_t; - if constexpr (std::is_same_v) { - return state; - } else if constexpr (std::is_same_v>) { - return *state; - } - }, - state_); - } - - const State& state() const { - return std::visit( - [](auto& state) -> const State& { - using T = std::decay_t; - if constexpr (std::is_same_v) { - return state; - } else if constexpr (std::is_same_v>) { - return *state; - } - }, - state_); - } - std::string ToString() const override; - std::variant> state_; + Devices devices_; + + // Addressable device list is dynamically computed and cached. + struct AddressableDeviceListCache { + absl::once_flag once_flag; + DeviceList* device_list = nullptr; + tsl::RCReference device_list_holder; + }; + mutable AddressableDeviceListCache addressable_device_list_cache_; // Cached hash. 0 indicates the hash needs to be computed and cached. // May be written multiple times with the same non-zero value. diff --git a/third_party/xla/xla/python/ifrt/device_test.cc b/third_party/xla/xla/python/ifrt/device_list_test.cc similarity index 53% rename from third_party/xla/xla/python/ifrt/device_test.cc rename to third_party/xla/xla/python/ifrt/device_list_test.cc index 713b9ca3ce5ec1..381e00dd416d2d 100644 --- a/third_party/xla/xla/python/ifrt/device_test.cc +++ b/third_party/xla/xla/python/ifrt/device_list_test.cc @@ -13,18 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/device_list.h" #include #include +#include #include #include +#include #include #include "absl/status/statusor.h" -#include "absl/synchronization/blocking_counter.h" +#include "absl/types/span.h" +#include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device.pb.h" -#include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/device_test_util.h" #include "tsl/platform/cpu_info.h" #include "tsl/platform/env.h" @@ -35,6 +37,8 @@ namespace xla { namespace ifrt { namespace { +using ::testing::ElementsAreArray; + class DeviceListTest : public test_util::DeviceTest {}; TEST_P(DeviceListTest, ToFromProto) { @@ -48,23 +52,70 @@ TEST_P(DeviceListTest, ToFromProto) { EXPECT_EQ(*device_list_copy, *device_list); } +TEST_P(DeviceListTest, AddressableDevices) { + auto device_list = GetDevices({0, 1}); + std::vector addressable_devices; + for (Device* device : device_list->devices()) { + if (device->IsAddressable()) { + addressable_devices.push_back(device); + } + } + EXPECT_THAT(device_list->AddressableDeviceList()->devices(), + ElementsAreArray(addressable_devices)); +} + +TEST_P(DeviceListTest, AddressableDevicesFromConcurrentCalls) { + auto device_list = GetDevices({0, 1}); + + const int num_threads = 16; + auto thread_pool = std::make_unique( + tsl::Env::Default(), tsl::ThreadOptions(), "test_pool", + std::min(num_threads, tsl::port::MaxParallelism())); + std::vector addressable_device_lists(num_threads); + for (int i = 0; i < num_threads; ++i) { + thread_pool->Schedule([&, i]() { + addressable_device_lists[i] = device_list->AddressableDeviceList(); + // Touch a device in the list so that tsan can verify access to the + // content of the addressable device list. + addressable_device_lists[i]->devices().front()->Id(); + }); + } + + thread_pool.reset(); + for (int i = 0; i < num_threads; ++i) { + EXPECT_EQ(*addressable_device_lists[i], + *device_list->AddressableDeviceList()); + } +} + +TEST_P(DeviceListTest, IsFullyAddressable) { + auto device_list = GetDevices({0, 1}); + int num_addressable_devices = 0; + for (Device* device : device_list->devices()) { + if (device->IsAddressable()) { + ++num_addressable_devices; + } + } + if (num_addressable_devices == device_list->size()) { + EXPECT_TRUE(device_list->IsFullyAddressable()); + } else { + EXPECT_FALSE(device_list->IsFullyAddressable()); + } +} + TEST_P(DeviceListTest, IdenticalHashFromConcurrentCalls) { auto device_list = GetDevices({0, 1}); const int num_threads = 16; - absl::BlockingCounter counter(num_threads); - tsl::thread::ThreadPool thread_pool( + auto thread_pool = std::make_unique( tsl::Env::Default(), tsl::ThreadOptions(), "test_pool", std::min(num_threads, tsl::port::MaxParallelism())); std::vector hashes(num_threads); for (int i = 0; i < num_threads; ++i) { - thread_pool.Schedule([&, i]() { - hashes[i] = device_list->hash(); - counter.DecrementCount(); - }); + thread_pool->Schedule([&, i]() { hashes[i] = device_list->hash(); }); } - counter.Wait(); + thread_pool.reset(); for (int i = 0; i < num_threads; ++i) { EXPECT_EQ(hashes[i], device_list->hash()); } @@ -89,10 +140,12 @@ TEST_P(DeviceListTest, EqualityTest) { EXPECT_NE(*device_list1, *device_list6); } -INSTANTIATE_TEST_SUITE_P(NumDevices, DeviceListTest, - testing::Values(test_util::DeviceTestParam{ - /*num_devices=*/2, - /*num_addressable_devices=*/2})); +INSTANTIATE_TEST_SUITE_P( + NumDevices, DeviceListTest, + testing::Values(test_util::DeviceTestParam{/*num_devices=*/2, + /*num_addressable_devices=*/1}, + test_util::DeviceTestParam{/*num_devices=*/2, + /*num_addressable_devices=*/2})); } // namespace } // namespace ifrt diff --git a/third_party/xla/xla/python/ifrt/device_test_util.cc b/third_party/xla/xla/python/ifrt/device_test_util.cc index c77a3400839562..6e4d6f49c65cbb 100644 --- a/third_party/xla/xla/python/ifrt/device_test_util.cc +++ b/third_party/xla/xla/python/ifrt/device_test_util.cc @@ -100,8 +100,6 @@ std::shared_ptr MakeDeviceTestClient(int num_devices, ON_CALL(*device, client).WillByDefault(ReturnPointee(&state->client)); ON_CALL(*device, Id).WillByDefault(Return(DeviceId(i + 10))); ON_CALL(*device, IsAddressable).WillByDefault(Return(addressable)); - ON_CALL(*device, DebugString) - .WillByDefault(Return(absl::StrCat("device(", i + 10, ")"))); ON_CALL(*device, DefaultMemory).WillByDefault(Return(state->memories[i])); // device_memories will be filled in at the end of the loop. ON_CALL(*device, Memories) diff --git a/third_party/xla/xla/python/ifrt/dtype.cc b/third_party/xla/xla/python/ifrt/dtype.cc index 1de5702b6cc8df..17e2cfa281d251 100644 --- a/third_party/xla/xla/python/ifrt/dtype.cc +++ b/third_party/xla/xla/python/ifrt/dtype.cc @@ -37,6 +37,8 @@ std::optional DType::byte_size() const { case kPred: case kS8: case kU8: + case kF8E3M4: + case kF8E4M3: // The following types are https://arxiv.org/abs/2209.05433 case kF8E4M3FN: case kF8E4M3B11FNUZ: @@ -78,6 +80,8 @@ std::optional DType::bit_size() const { case kPred: case kS8: case kU8: + case kF8E3M4: + case kF8E4M3: // The following types are https://arxiv.org/abs/2209.05433 case kF8E4M3FN: case kF8E4M3B11FNUZ: @@ -133,6 +137,9 @@ absl::StatusOr DType::FromProto(const DTypeProto& dtype_proto) { CASE(BF16); CASE(C64); CASE(C128); + // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + // CASE(F8E3M4); + // CASE(F8E4M3); CASE(F8E4M3FN); CASE(F8E4M3B11FNUZ); CASE(F8E4M3FNUZ); @@ -175,6 +182,9 @@ DTypeProto DType::ToProto() const { CASE(BF16); CASE(C64); CASE(C128); + // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + // CASE(F8E3M4); + // CASE(F8E4M3); CASE(F8E4M3FN); CASE(F8E4M3B11FNUZ); CASE(F8E4M3FNUZ); diff --git a/third_party/xla/xla/python/ifrt/dtype.h b/third_party/xla/xla/python/ifrt/dtype.h index 06a92b67f863c8..28097a9cefa930 100644 --- a/third_party/xla/xla/python/ifrt/dtype.h +++ b/third_party/xla/xla/python/ifrt/dtype.h @@ -78,13 +78,15 @@ class DType { // dtype will have empty dimensions. kToken = 17, + kF8E3M4 = 29, + kF8E4M3 = 28, kF8E4M3FN = 20, kF8E4M3B11FNUZ = 23, kF8E4M3FNUZ = 25, kF8E5M2 = 19, kF8E5M2FNUZ = 24, - // Next = 26 + // Next = 30 // Variable-length string represented as raw bytes, as in `bytes` in Python, // i.e., no encoding enforcement. String is not support in XLA. DType.Kind @@ -125,8 +127,14 @@ class DType { // Returns a `DTypeProto` representation. DTypeProto ToProto() const; + // TODO(hyeontaek): Remove this method in favor of AbslStringify. std::string DebugString() const; + template + friend void AbslStringify(Sink& sink, const DType& dtype) { + sink.Append(dtype.DebugString()); + } + private: Kind kind_; }; diff --git a/third_party/xla/xla/python/ifrt/dtype.proto b/third_party/xla/xla/python/ifrt/dtype.proto index eadfd42a3550cd..37976833e7e8c7 100644 --- a/third_party/xla/xla/python/ifrt/dtype.proto +++ b/third_party/xla/xla/python/ifrt/dtype.proto @@ -60,6 +60,8 @@ message DTypeProto { // dtype will have empty dimensions. KIND_TOKEN = 17; + KIND_F8E3M4 = 29; + KIND_F8E4M3 = 28; KIND_F8E4M3FN = 20; KIND_F8E4M3B11FNUZ = 23; KIND_F8E4M3FNUZ = 25; diff --git a/third_party/xla/xla/python/ifrt/dtype_test.cc b/third_party/xla/xla/python/ifrt/dtype_test.cc index 5ac531dabcb9ce..57fec6702d277d 100644 --- a/third_party/xla/xla/python/ifrt/dtype_test.cc +++ b/third_party/xla/xla/python/ifrt/dtype_test.cc @@ -49,6 +49,8 @@ TEST(DTypeTest, ByteSize) { {DType::kPred, 1}, {DType::kS8, 1}, {DType::kU8, 1}, + {DType::kF8E3M4, 1}, + {DType::kF8E4M3, 1}, {DType::kF8E4M3FN, 1}, {DType::kF8E4M3B11FNUZ, 1}, {DType::kF8E4M3FNUZ, 1}, @@ -85,6 +87,8 @@ TEST(DTypeTest, BitSize) { {DType::kPred, 8}, {DType::kS8, 8}, {DType::kU8, 8}, + {DType::kF8E3M4, 8}, + {DType::kF8E4M3, 8}, {DType::kF8E4M3FN, 8}, {DType::kF8E4M3B11FNUZ, 8}, {DType::kF8E4M3FNUZ, 8}, diff --git a/third_party/xla/xla/python/ifrt/executable.cc b/third_party/xla/xla/python/ifrt/executable.cc index 77cabe7f6a9389..07145d017f1395 100644 --- a/third_party/xla/xla/python/ifrt/executable.cc +++ b/third_party/xla/xla/python/ifrt/executable.cc @@ -15,11 +15,38 @@ limitations under the License. #include "xla/python/ifrt/executable.h" +#include "absl/status/statusor.h" +#include "xla/python/ifrt/execute_options.pb.h" + namespace xla { namespace ifrt { char Executable::ID = 0; char LoadedExecutable::ID = 0; +absl::StatusOr ExecuteOptions::ToProto() const { + ExecuteOptionsProto proto; + + proto.set_launch_id(launch_id); + proto.mutable_non_donatable_input_indices()->Add( + non_donatable_input_indices.begin(), non_donatable_input_indices.end()); + proto.set_fill_status(fill_status); + + return proto; +} + +absl::StatusOr ExecuteOptions::FromProto( + const xla::ifrt::ExecuteOptionsProto& proto) { + ExecuteOptions options; + + options.launch_id = proto.launch_id(); + options.non_donatable_input_indices.insert( + proto.non_donatable_input_indices().begin(), + proto.non_donatable_input_indices().end()); + options.fill_status = proto.fill_status(); + + return options; +} + } // namespace ifrt } // namespace xla diff --git a/third_party/xla/xla/python/ifrt/executable.h b/third_party/xla/xla/python/ifrt/executable.h index 08fa0de003ddae..6a292cc9a3044d 100644 --- a/third_party/xla/xla/python/ifrt/executable.h +++ b/third_party/xla/xla/python/ifrt/executable.h @@ -22,18 +22,22 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/Support/ExtensibleRTTI.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/pjrt/pjrt_executable.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/attribute_map.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" +#include "xla/python/ifrt/execute_options.pb.h" #include "xla/python/ifrt/future.h" #include "xla/tsl/concurrency/ref_count.h" +#include "xla/xla_data.pb.h" namespace xla { namespace ifrt { @@ -74,10 +78,10 @@ class Executable : public llvm::RTTIExtends { // Returns a list of output `OpSharding`. virtual std::optional> GetOutputShardings() const = 0; // Returns a list of parameter layouts. - virtual absl::StatusOr>> + virtual absl::StatusOr>> GetParameterLayouts() const = 0; // Returns a list of output/result layouts. - virtual absl::StatusOr>> + virtual absl::StatusOr>> GetOutputLayouts() const = 0; // Returns an `HloModule` (optimized) per partition. virtual absl::StatusOr>> @@ -103,6 +107,32 @@ class Executable : public llvm::RTTIExtends { static char ID; // NOLINT }; +struct ExecuteOptions { + // If non-zero, identifies this execution as part of a potentially + // multi-device launch. This can be used to detect scheduling errors, e.g. if + // multi-host programs are launched in different orders on different hosts, + // the launch IDs may be used by the runtime to detect the mismatch. + int32_t launch_id = 0; + + // A set of indices denoting the input arrays that should not be donated. An + // input array may be non-donable, for example, if it is referenced more than + // once. Since such runtime information is not available at compile time, the + // compiler might mark the input as `may-alias`, which could lead IFRT to + // donate the input array when it should not. By defining this set of indices, + // a higher-level IFRT caller can instruct IFRT client not to donate specific + // input arrays. + absl::flat_hash_set non_donatable_input_indices; + + // If true, populate `ExecuteResult::status`. Otherwise, the status is left as + // an invalid future. + bool fill_status = false; + + absl::StatusOr ToProto() const; + + static absl::StatusOr FromProto( + const ExecuteOptionsProto& proto); +}; + // Wraps a computation that has been fully compiled and loaded for execution. class LoadedExecutable : public llvm::RTTIExtends { @@ -153,10 +183,10 @@ class LoadedExecutable // Returns a list of output OpSharding. virtual std::optional> GetOutputShardings() const = 0; // Returns a list of parameter layouts. - virtual absl::StatusOr>> + virtual absl::StatusOr>> GetParameterLayouts() const = 0; // Returns a list of output/result layouts. - virtual absl::StatusOr>> + virtual absl::StatusOr>> GetOutputLayouts() const = 0; // Return an HloModule (optimized) per partition. virtual absl::StatusOr>> @@ -175,12 +205,12 @@ class LoadedExecutable // `LoadedExecutable` methods. - // Short-term alias. - using ExecuteOptions = ::xla::ExecuteOptions; + using ExecuteOptions = xla::ifrt::ExecuteOptions; // Result from an execution. struct ExecuteResult { - // Resulting status of the execution. + // Resulting status of the execution. Filled only if + // `ExecuteOptions::fill_status` is true. Future<> status; // Output arrays. std::vector> outputs; diff --git a/tensorflow/cc/experimental/libtf/impl/none.cc b/third_party/xla/xla/python/ifrt/execute_options.proto similarity index 65% rename from tensorflow/cc/experimental/libtf/impl/none.cc rename to third_party/xla/xla/python/ifrt/execute_options.proto index 8f16b1ed4ab760..5503c0b6bb4573 100644 --- a/tensorflow/cc/experimental/libtf/impl/none.cc +++ b/third_party/xla/xla/python/ifrt/execute_options.proto @@ -1,4 +1,4 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,17 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/impl/none.h" -namespace tf { -namespace libtf { -namespace impl { +syntax = "proto3"; -None& None::GetInstance() { - static None* none_inst = new None(); - return *none_inst; -} +package xla.ifrt; + +message ExecuteOptionsProto { + bool untuple_result = 2; + int32 launch_id = 3; + repeated int32 non_donatable_input_indices = 7; + bool fill_status = 9; -} // namespace impl -} // namespace libtf -} // namespace tf + reserved 1, 4 to 6, 8; +} diff --git a/third_party/xla/xla/python/ifrt/future_test.cc b/third_party/xla/xla/python/ifrt/future_test.cc index 808d9a4981494a..10fc62f3e2eef2 100644 --- a/third_party/xla/xla/python/ifrt/future_test.cc +++ b/third_party/xla/xla/python/ifrt/future_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include #include #include "absl/status/status.h" #include "absl/types/span.h" diff --git a/third_party/xla/xla/python/ifrt/hlo/hlo_program.h b/third_party/xla/xla/python/ifrt/hlo/hlo_program.h index b084c987b9931f..37802019e4bb7b 100644 --- a/third_party/xla/xla/python/ifrt/hlo/hlo_program.h +++ b/third_party/xla/xla/python/ifrt/hlo/hlo_program.h @@ -17,9 +17,7 @@ limitations under the License. #define XLA_PYTHON_IFRT_HLO_HLO_PROGRAM_H_ #include -#include #include -#include #include "llvm/Support/ExtensibleRTTI.h" #include "mlir/IR/BuiltinOps.h" diff --git a/third_party/xla/xla/python/ifrt/index.h b/third_party/xla/xla/python/ifrt/index.h index 3c5bb4cc9c2965..d64321a8fef4f7 100644 --- a/third_party/xla/xla/python/ifrt/index.h +++ b/third_party/xla/xla/python/ifrt/index.h @@ -89,8 +89,14 @@ class Index { return *this = *this * multiplier; } + // TODO(hyeontaek): Remove this method in favor of AbslStringify. std::string DebugString() const; + template + friend void AbslStringify(Sink& sink, const Index& index) { + sink.Append(index.DebugString()); + } + private: Elements elements_; }; diff --git a/third_party/xla/xla/python/ifrt/index_domain.h b/third_party/xla/xla/python/ifrt/index_domain.h index 4359bd0a8012ae..506a077fb1c700 100644 --- a/third_party/xla/xla/python/ifrt/index_domain.h +++ b/third_party/xla/xla/python/ifrt/index_domain.h @@ -72,8 +72,15 @@ class IndexDomain { origin_ -= offset; return *this; } + + // TODO(hyeontaek): Remove this method in favor of AbslStringify. std::string DebugString() const; + template + friend void AbslStringify(Sink& sink, const IndexDomain& index_domain) { + sink.Append(index_domain.DebugString()); + } + private: Index origin_; Shape shape_; diff --git a/third_party/xla/xla/python/ifrt/ir/BUILD b/third_party/xla/xla/python/ifrt/ir/BUILD index 01233d18276583..15ce939709aff1 100644 --- a/third_party/xla/xla/python/ifrt/ir/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/BUILD @@ -1,4 +1,5 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("//xla:xla.bzl", "xla_cc_test") load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable") package( @@ -170,16 +171,73 @@ cc_library( ) cc_library( - name = "compiler", - srcs = ["compiler.cc"], - hdrs = ["compiler.h"], + name = "ifrt_ir_program", + srcs = ["ifrt_ir_program.cc"], + hdrs = ["ifrt_ir_program.h"], compatible_with = get_compatible_with_portable(), visibility = ["//xla/python/ifrt:friends"], deps = [ "//xla/python/ifrt", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", ], ) + +cc_library( + name = "ifrt_ir_program_serdes", + srcs = ["ifrt_ir_program_serdes.cc"], + compatible_with = get_compatible_with_portable(), + visibility = ["//xla/python/ifrt:friends"], + deps = [ + ":ifrt_ir_program", + "//xla/mlir/utils:error_util", + "//xla/python/ifrt:serdes", + "//xla/python/ifrt/support:module_parsing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeWriter", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:statusor", + ], + alwayslink = True, +) + +xla_cc_test( + name = "ifrt_ir_program_serdes_test", + srcs = ["ifrt_ir_program_serdes_test.cc"], + deps = [ + ":ifrt_ir_program", + ":ifrt_ir_program_serdes", + "//xla/python/ifrt:serdes", + "//xla/python/ifrt/support:module_parsing", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "atom_program_compiler", + hdrs = ["atom_program_compiler.h"], + compatible_with = get_compatible_with_portable(), + visibility = ["//xla/python/ifrt:friends"], + deps = [ + ":ir", + "//xla/pjrt:pjrt_executable", + "//xla/python/ifrt", + "//xla/python/ifrt/hlo:hlo_program", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", + ], +) diff --git a/third_party/xla/xla/python/ifrt/ir/atom_program_compiler.h b/third_party/xla/xla/python/ifrt/ir/atom_program_compiler.h new file mode 100644 index 00000000000000..ffec3a3dc753d4 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/atom_program_compiler.h @@ -0,0 +1,63 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_IFRT_IR_ATOM_PROGRAM_COMPILER_H_ +#define XLA_PYTHON_IFRT_IR_ATOM_PROGRAM_COMPILER_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/hlo/hlo_program.h" +#include "xla/python/ifrt/ir/ifrt_dialect.h" +#include "xla/python/ifrt/shape.h" + +namespace xla { +namespace ifrt { + +// Loaded executable and unique name for a compiled atom program. +struct AtomProgramCompileResult { + std::string name; + std::shared_ptr executable; +}; + +using AtomExecutableMap = + absl::flat_hash_map>; + +class AtomProgramCompiler { + public: + virtual ~AtomProgramCompiler() = default; + + // Delegates the compilation of an atom XLA program. + // `options` uses logical device id in the main mlir module. + virtual absl::StatusOr CompileXla( + std::unique_ptr computation, xla::CompileOptions options) = 0; + + // Delegates the compilation of an MPMD reshard program. + virtual absl::StatusOr CompileMpmdReshard( + std::vector dtypes, std::vector shapes, + std::vector in_array_types, + std::vector out_array_types) = 0; +}; + +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_IR_ATOM_PROGRAM_COMPILER_H_ diff --git a/third_party/xla/xla/python/ifrt/ir/constants.h b/third_party/xla/xla/python/ifrt/ir/constants.h index 26f8a7e999dd52..2a9b2ba4683c18 100644 --- a/third_party/xla/xla/python/ifrt/ir/constants.h +++ b/third_party/xla/xla/python/ifrt/ir/constants.h @@ -53,6 +53,16 @@ inline constexpr llvm::StringLiteral kIfrtEntryFunctionAttrName = inline constexpr llvm::StringLiteral kCalleeMainFuncName = "main"; +// Name of StringAttr used to store the HloSharding. +inline constexpr llvm::StringLiteral kHloShardingAttrName = "mhlo.sharding"; + +inline constexpr llvm::StringLiteral kIfrtModuleTypeAttrName = + "ifrt.module_type"; + +inline constexpr llvm::StringLiteral kIfrtModuleTypeXla = "xla"; +inline constexpr llvm::StringLiteral kIfrtModuleTypeMpmdReshard = + "mpmd_reshard"; + } // namespace ifrt } // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/compiler.cc b/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.cc similarity index 83% rename from third_party/xla/xla/python/ifrt/ir/compiler.cc rename to third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.cc index 8922d23ec30f22..12f2b07fc9ac67 100644 --- a/third_party/xla/xla/python/ifrt/ir/compiler.cc +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The OpenXLA Authors. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,10 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/python/ifrt/ir/compiler.h" +#include "xla/python/ifrt/ir/ifrt_ir_program.h" #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "llvm/Support/Casting.h" +#include "xla/python/ifrt/compiler.h" + namespace xla { namespace ifrt { diff --git a/third_party/xla/xla/python/ifrt/ir/compiler.h b/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.h similarity index 74% rename from third_party/xla/xla/python/ifrt/ir/compiler.h rename to third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.h index 15db4a974121ec..c8f5e6cde1ca1d 100644 --- a/third_party/xla/xla/python/ifrt/ir/compiler.h +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The OpenXLA Authors. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_PYTHON_IFRT_IR_COMPILER_H_ -#define XLA_PYTHON_IFRT_IR_COMPILER_H_ +#ifndef XLA_PYTHON_IFRT_IR_IFRT_IR_PROGRAM_H_ +#define XLA_PYTHON_IFRT_IR_IFRT_IR_PROGRAM_H_ #include #include @@ -25,6 +25,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "llvm/Support/ExtensibleRTTI.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/executable.h" @@ -37,10 +39,19 @@ struct IfrtIRProgram : llvm::RTTIExtends { IfrtIRProgram() = default; explicit IfrtIRProgram(mlir::ModuleOp mlir_module) : mlir_module(std::move(mlir_module)) {} + IfrtIRProgram(std::unique_ptr context, + mlir::OwningOpRef module) + : mlir_module(*module), + mlir_context(std::move(context)), + owning_mlir_module(std::move(module)) {} mlir::ModuleOp mlir_module; static char ID; // NOLINT + + private: + std::unique_ptr mlir_context; + mlir::OwningOpRef owning_mlir_module; }; // CompileOptions for an IFRT IR program. @@ -58,13 +69,13 @@ struct IfrtIRCompileOptions loaded_exec_binding(std::move(loaded_exec_binding)), compile_options_overrides(std::move(compile_options_overrides)) {} - // Map from logical device ids in MLIR module to runtime device ids obtained - // from IFRT client. + // Mapping from logical device ids in IFRT IR MLIR module to runtime device + // ids obtained from IFRT client. std::vector device_assignments; - // Map from `getSymName()` of declared LoadedExecutableOp in the `mlir_module` - // to pre-compiled LoadedExecutable instance. The LoadedExecutables must - // outlive the LoadedExecutable to be compiled. + // Map from symbol names of LoadedExecutableOp in the IFRT IR MLIR module + // to pre-compiled `LoadedExecutable` instance. The `LoadedExecutable`s must + // outlive the `LoadedExecutable` of the IFRT IR program. absl::flat_hash_map> loaded_exec_binding; @@ -85,4 +96,4 @@ absl::StatusOr> GetIfrtIRCompileOptions( } // namespace ifrt } // namespace xla -#endif // XLA_PYTHON_IFRT_IR_COMPILER_H_ +#endif // XLA_PYTHON_IFRT_IR_IFRT_IR_PROGRAM_H_ diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program_serdes.cc b/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program_serdes.cc new file mode 100644 index 00000000000000..e666dbf6275353 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program_serdes.cc @@ -0,0 +1,92 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ExtensibleRTTI.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Support/LLVM.h" +#include "xla/mlir/utils/error_util.h" +#include "xla/python/ifrt/ir/ifrt_ir_program.h" +#include "xla/python/ifrt/serdes.h" +#include "xla/python/ifrt/support/module_parsing.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { + +namespace { + +class IfrtIRProgramSerDes + : public llvm::RTTIExtends { + public: + absl::string_view type_name() const override { + return "xla::ifrt::IfrtIRProgram"; + } + + absl::StatusOr Serialize(Serializable& serializable) override { + const auto& program = llvm::cast(serializable); + if (program.mlir_module == nullptr) { + return absl::InvalidArgumentError("Unable to serialize null MLIR module"); + } + std::string serialized; + llvm::raw_string_ostream out(serialized); + mlir::BytecodeWriterConfig config; + mlir::BaseScopedDiagnosticHandler diagnostic_handler( + program.mlir_module->getContext()); + if (mlir::failed( + mlir::writeBytecodeToFile(program.mlir_module, out, config))) { + return absl::InvalidArgumentError( + absl::StrFormat("Failed to serialize IFRT IR module string: %s", + diagnostic_handler.ConsumeStatus().message())); + } + return serialized; + } + + absl::StatusOr> Deserialize( + const std::string& serialized, + std::unique_ptr) override { + auto context = std::make_unique(); + TF_ASSIGN_OR_RETURN(auto module, + support::ParseMlirModuleString(serialized, *context)); + return std::make_unique(std::move(context), + std::move(module)); + } + + static char ID; // NOLINT +}; + +char IfrtIRProgramSerDes::ID = 0; // NOLINT + +// clang-format off +bool register_ifrt_ir_program_serdes = ([]() { + RegisterSerDes(std::make_unique()); +}(), true); +// clang-format on + +} // namespace + +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program_serdes_test.cc b/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program_serdes_test.cc new file mode 100644 index 00000000000000..019f3599d73ed2 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_ir_program_serdes_test.cc @@ -0,0 +1,126 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/OwningOpRef.h" +#include "xla/python/ifrt/ir/ifrt_ir_program.h" +#include "xla/python/ifrt/serdes.h" +#include "xla/python/ifrt/support/module_parsing.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { +namespace { + +using ::testing::HasSubstr; +using ::testing::Not; +using ::tsl::testing::StatusIs; + +std::string PrintModule(mlir::ModuleOp module) { + std::string module_str; + llvm::raw_string_ostream os(module_str); + module->print(os, mlir::OpPrintingFlags().enableDebugInfo()); + return module_str; +} + +TEST(IfrtIRProgramSerDesTest, RoundTrip) { + static constexpr absl::string_view kMlirModuleStr = R"( +!array = !ifrt.array, #ifrt.sharding_param<1 to [0] on 1>, [0]> +module { + func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Call @add_one::@main(%arg0) on devices [0] + : (!array) -> !array + return %0 : !array + } + + module @add_one { + func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { + %0 = mhlo.constant dense<1> : tensor<2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2xi32> + return %1 : tensor<2xi32> + } + } +} + )"; + + Serialized serialized; + auto context = std::make_unique(); + TF_ASSERT_OK_AND_ASSIGN( + mlir::OwningOpRef module, + support::ParseMlirModuleString(kMlirModuleStr, *context)); + auto initial_program = + std::make_unique(std::move(context), std::move(module)); + TF_ASSERT_OK_AND_ASSIGN(serialized, Serialize(*initial_program)); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr deserialized_program, + Deserialize(serialized, /*options=*/nullptr)); + + EXPECT_EQ(PrintModule(initial_program->mlir_module), + PrintModule(deserialized_program->mlir_module)); +} + +TEST(IfrtIRProgramSerDesTest, DeserializationError) { + static constexpr absl::string_view kMlirModuleStr = R"( +!array = !ifrt.array, #ifrt.sharding_param<1 to [0] on 1>, [0]> +module { + func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Call @add_one::@main(%arg0) on devices [0] + : (!array) -> !array + return %0 : !array + } + + module @add_one { + func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { + %0 = mhlo.constant dense<1> : tensor<2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2xi32> + return %1 : tensor<2xi32> + } + } +} + )"; + Serialized serialized; + { + auto context = std::make_unique(); + TF_ASSERT_OK_AND_ASSIGN( + mlir::OwningOpRef module, + support::ParseMlirModuleString(kMlirModuleStr, *context)); + auto program = + std::make_unique(std::move(context), std::move(module)); + TF_ASSERT_OK_AND_ASSIGN(serialized, Serialize(*program)); + } + + serialized.set_data("invalid data"); + + EXPECT_THAT(Deserialize(serialized, /*options=*/nullptr), + StatusIs(Not(absl::StatusCode::kOk), + HasSubstr("Failed to parse IFRT IR module string"))); +} + +} // namespace +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc b/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc index 080f0faf76e725..cc37e5010e0840 100644 --- a/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc @@ -248,12 +248,24 @@ mlir::LogicalResult VerifyIoAlias(mlir::Operation* op, IoAlias io_alias, return mlir::success(); } -mlir::LogicalResult VerifyIoAliases(mlir::Operation* op, - mlir::ArrayAttr io_aliases, - llvm::ArrayRef inputs, - llvm::ArrayRef outputs) { - llvm::SmallSet aliased_inputs; +mlir::LogicalResult VerifyIoAliasesAndDonations( + mlir::Operation* op, mlir::ArrayAttr io_aliases, + llvm::ArrayRef donated_input_indices, + llvm::ArrayRef inputs, + llvm::ArrayRef outputs) { + llvm::SmallSet aliased_or_donated_inputs; llvm::SmallSet aliased_outputs; + for (const int32_t donated_input_index : donated_input_indices) { + if (donated_input_index < 0 || donated_input_index >= inputs.size()) { + return op->emitOpError() + << "can't donate input #" << donated_input_index + << " as only having " << inputs.size() << " inputs"; + } + if (!aliased_or_donated_inputs.insert(donated_input_index).second) { + return op->emitOpError() << "can't donate input #" << donated_input_index + << " more than once"; + } + } for (const auto& raw_io_alias : io_aliases.getAsRange()) { llvm::ArrayRef io_alias_as_array = raw_io_alias.asArrayRef(); @@ -263,9 +275,9 @@ mlir::LogicalResult VerifyIoAliases(mlir::Operation* op, inputs, outputs))) { return mlir::failure(); } - if (!aliased_inputs.insert(aliased_input).second) { - return op->emitOpError() - << "can't alias input #" << aliased_input << " more than once"; + if (!aliased_or_donated_inputs.insert(aliased_input).second) { + return op->emitOpError() << "can't alias or donate input #" + << aliased_input << " more than once"; } if (!aliased_outputs.insert(aliased_output).second) { return op->emitOpError() @@ -618,8 +630,9 @@ mlir::LogicalResult CallOp::verify() { if (mlir::failed(VerifyDevicePlacement(*this, getDevices(), input_arrays, output_arrays)) || - mlir::failed(VerifyIoAliases(*this, getIoAliases(), input_arrays, - output_arrays))) { + mlir::failed(VerifyIoAliasesAndDonations(*this, getIoAliases(), + getDonatedInputIndices(), + input_arrays, output_arrays))) { return mlir::failure(); } return mlir::success(); @@ -680,7 +693,9 @@ mlir::LogicalResult CallLoadedExecutableOp::verify() { output_arrays.push_back(mlir::cast(output.getType())); } - return VerifyIoAliases(*this, getIoAliases(), input_arrays, output_arrays); + return VerifyIoAliasesAndDonations(*this, getIoAliases(), + getDonatedInputIndices(), input_arrays, + output_arrays); } mlir::LogicalResult LoadedExecutableOp::verify() { diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_ops.td b/third_party/xla/xla/python/ifrt/ir/ifrt_ops.td index 937cdf96ca6e30..1f4bde2c710152 100644 --- a/third_party/xla/xla/python/ifrt/ir/ifrt_ops.td +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_ops.td @@ -182,9 +182,11 @@ def Ifrt_CallOp : Ifrt_Op<"Call", a subset of these devices. `io_aliases` represents pairs of inputs and outputs, where the input buffer - may be donated and used as the output buffer. The aliased pair must have the - same Ifrt_ArrayType. It's up to IFRT implementations whether to respect this - hint or not. + may be aliased and used as the output buffer. The aliased pair must have the + same byte size. It's up to IFRT implementations whether to respect this + hint or not. Alternatively, if the index of an input is In + `donated_input_indices` then the input buffer might be donated to the + callee if an output with the same byte size is found. }]; let arguments = (ins @@ -192,7 +194,8 @@ def Ifrt_CallOp : Ifrt_Op<"Call", Variadic:$control_inputs, SymbolRefAttr:$callee, Ifrt_DevicesAttr:$devices, - DefaultValuedAttr:$io_aliases); + DefaultValuedAttr:$io_aliases, + DefaultValuedAttr:$donated_input_indices); let results = (outs Variadic:$outputs, Ifrt_ControlType:$control_output); @@ -220,16 +223,19 @@ def Ifrt_CallLoadedExecutableOp : Ifrt_Op<"CallLoadedExecutable", be placed on a subset of these devices. `io_aliases` represents pairs of inputs and outputs, where the input buffer - may be donated and used as the output buffer. The aliased pair must have the - same Ifrt_ArrayType. It's up to IFRT implementations whether to respect this - hint or not. + may be aliased and used as the output buffer. The aliased pair must have the + same byte size. It's up to IFRT implementations whether to respect this + hint or not. Alternatively, if the index of an input is In + `donated_input_indices` then the input buffer might be donated to the + callee if an output with the same byte size is found. }]; let arguments = (ins Variadic:$inputs, Variadic:$control_inputs, SymbolRefAttr:$callee, - DefaultValuedAttr:$io_aliases); + DefaultValuedAttr:$io_aliases, + DefaultValuedAttr:$donated_input_indices); let results = (outs Variadic:$outputs, Ifrt_ControlType:$control_output); diff --git a/third_party/xla/xla/python/ifrt/ir/tests/BUILD b/third_party/xla/xla/python/ifrt/ir/tests/BUILD index 01ca1bff5c92e8..0b7f676f76ca69 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/tests/BUILD @@ -10,9 +10,16 @@ lit_test_suite( name = "all_tests", srcs = enforce_glob( [ + "ifrt_compile_atom_program.mlir", + "ifrt_compile_and_propagate_shardings.mlir", "ifrt_duplicated_callee_elimination.mlir", + "ifrt_lower_mpmd_reshard_to_call.mlir", + "ifrt_lower_sharding_to_xla.mlir", "ifrt_merge_reshards.mlir", "ifrt_outline_atom_program_to_module.mlir", + "ifrt_populate_atom_program_metadata.mlir", + "ifrt_remove_ifrt_attrs.mlir", + "ifrt_reshard_to_copy_arrays.mlir", "ifrt_verify_donation.mlir", "ifrt_verify_sharding_specified.mlir", "spmd_expansion.mlir", @@ -41,12 +48,25 @@ lit_test_suite( xla_cc_binary( name = "ifrt-opt", + testonly = True, srcs = ["ifrt-opt.cc"], deps = [ "//xla/mlir_hlo:hlo_dialect_registration", + "//xla/pjrt:pjrt_executable", + "//xla/python/ifrt", + "//xla/python/ifrt:mock", + "//xla/python/ifrt/hlo:hlo_program", "//xla/python/ifrt/ir", - "//xla/python/ifrt/ir/transforms:built_in_spmd_expansions", + "//xla/python/ifrt/ir:atom_program_compiler", "//xla/python/ifrt/ir/transforms:passes", + "//xla/python/ifrt/support:module_parsing", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_googletest//:gtest", "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", "@llvm-project//mlir:MlirOptLib", @@ -62,19 +82,16 @@ cc_library( deps = [ "//xla:status_macros", "//xla/mlir/utils:error_util", - "//xla/mlir_hlo:hlo_dialect_registration", "//xla/python/ifrt", "//xla/python/ifrt:test_util", - "//xla/python/ifrt/ir", "//xla/python/ifrt/ir:sharding_param", - "//xla/python/ifrt/ir/transforms:built_in_spmd_expansions", + "//xla/python/ifrt/support:module_parsing", "//xla/tsl/concurrency:ref_count", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@local_tsl//tsl/platform:statusor", @@ -93,7 +110,7 @@ cc_library( "//xla/python/ifrt", "//xla/python/ifrt:test_util", "//xla/python/ifrt/hlo:hlo_program", - "//xla/python/ifrt/ir:compiler", + "//xla/python/ifrt/ir:ifrt_ir_program", "//xla/python/ifrt/ir:sharding_param", "//xla/python/pjrt_ifrt:xla_ifrt", "//xla/service:computation_placer_hdr", diff --git a/third_party/xla/xla/python/ifrt/ir/tests/executable_impl_test_base.cc b/third_party/xla/xla/python/ifrt/ir/tests/executable_impl_test_base.cc index 804f143bd59efe..954680c87d8fe4 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/executable_impl_test_base.cc +++ b/third_party/xla/xla/python/ifrt/ir/tests/executable_impl_test_base.cc @@ -25,23 +25,19 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/DialectRegistry.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" -#include "mlir/InitAllDialects.h" #include "mlir/Parser/Parser.h" #include "xla/mlir/utils/error_util.h" -#include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" #include "xla/python/ifrt/dtype.h" -#include "xla/python/ifrt/ir/ifrt_dialect.h" #include "xla/python/ifrt/ir/sharding_param.h" -#include "xla/python/ifrt/ir/transforms/built_in_spmd_expansions.h" #include "xla/python/ifrt/memory.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" +#include "xla/python/ifrt/support/module_parsing.h" #include "xla/python/ifrt/test_util.h" #include "xla/status_macros.h" #include "xla/tsl/concurrency/ref_count.h" @@ -53,13 +49,7 @@ namespace test_util { IfrtIrExecutableImplTestBase::IfrtIrExecutableImplTestBase() { mlir::registerMLIRContextCLOptions(); - - mlir::DialectRegistry registry; - mlir::registerAllDialects(registry); - mlir::mhlo::registerAllMhloDialects(registry); - registry.insert(); - xla::ifrt::AttachBuiltInSpmdExpansions(registry); - mlir_context_.appendDialectRegistry(registry); + xla::ifrt::support::RegisterMlirDialects(mlir_context_); } void IfrtIrExecutableImplTestBase::SetUp() { diff --git a/third_party/xla/xla/python/ifrt/ir/tests/executable_impl_test_lib.cc b/third_party/xla/xla/python/ifrt/ir/tests/executable_impl_test_lib.cc index cef4bdf654c52d..4557f969f76b39 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/executable_impl_test_lib.cc +++ b/third_party/xla/xla/python/ifrt/ir/tests/executable_impl_test_lib.cc @@ -29,7 +29,7 @@ limitations under the License. #include "xla/python/ifrt/dtype.h" #include "xla/python/ifrt/executable.h" #include "xla/python/ifrt/hlo/hlo_program.h" -#include "xla/python/ifrt/ir/compiler.h" +#include "xla/python/ifrt/ir/ifrt_ir_program.h" #include "xla/python/ifrt/ir/sharding_param.h" #include "xla/python/ifrt/ir/tests/executable_impl_test_base.h" #include "xla/python/ifrt/shape.h" @@ -87,9 +87,11 @@ module { xla::ifrt::DType(xla::ifrt::DType::kS32), xla::ifrt::ShardingParam({2, 1}, {{0}, {2}}), devices)); + ExecuteOptions options; + options.fill_status = true; TF_ASSERT_OK_AND_ASSIGN( LoadedExecutable::ExecuteResult result, - loaded_exec->Execute(absl::MakeSpan(&input, 1), /*options=*/{}, + loaded_exec->Execute(absl::MakeSpan(&input, 1), options, /*devices=*/std::nullopt)); TF_ASSERT_OK(result.status.Await()); @@ -127,9 +129,11 @@ module { xla::ifrt::ShardingParam({1}, {{0}, {1}}), BasicDeviceList::Create({devices->devices()[0]}))); + ExecuteOptions options; + options.fill_status = true; TF_ASSERT_OK_AND_ASSIGN( LoadedExecutable::ExecuteResult result, - loaded_exec->Execute(absl::MakeSpan(&input, 1), /*options=*/{}, + loaded_exec->Execute(absl::MakeSpan(&input, 1), options, /*devices=*/std::nullopt)); TF_ASSERT_OK(result.status.Await()); @@ -168,9 +172,11 @@ module { xla::ifrt::ShardingParam({1}, {{0}, {1}}), BasicDeviceList::Create({devices->devices()[0]}))); + ExecuteOptions options; + options.fill_status = true; TF_ASSERT_OK_AND_ASSIGN( LoadedExecutable::ExecuteResult result, - loaded_exec->Execute(absl::MakeSpan(&input, 1), /*options=*/{}, + loaded_exec->Execute(absl::MakeSpan(&input, 1), options, /*devices=*/std::nullopt)); TF_ASSERT_OK(result.status.Await()); @@ -205,9 +211,11 @@ module { std::make_unique(*mlir_module), std::make_unique(GetDeviceIds(devices)))); - TF_ASSERT_OK_AND_ASSIGN(LoadedExecutable::ExecuteResult result, - loaded_exec->Execute(/*args=*/{}, /*options=*/{}, - /*devices=*/std::nullopt)); + ExecuteOptions options; + options.fill_status = true; + TF_ASSERT_OK_AND_ASSIGN( + LoadedExecutable::ExecuteResult result, + loaded_exec->Execute(/*args=*/{}, options, /*devices=*/std::nullopt)); TF_ASSERT_OK(result.status.Await()); ASSERT_EQ(result.outputs.size(), 1); @@ -250,9 +258,11 @@ module { xla::ifrt::DType(xla::ifrt::DType::kS32), xla::ifrt::ShardingParam({2, 1}, {{0}, {2}}), devices)); + ExecuteOptions options; + options.fill_status = true; TF_ASSERT_OK_AND_ASSIGN( LoadedExecutable::ExecuteResult result, - loaded_exec->Execute(absl::MakeSpan(&input, 1), /*options=*/{}, + loaded_exec->Execute(absl::MakeSpan(&input, 1), options, /*devices=*/std::nullopt)); TF_ASSERT_OK(result.status.Await()); @@ -295,9 +305,11 @@ module { xla::ifrt::DType(xla::ifrt::DType::kS32), xla::ifrt::ShardingParam({2, 1}, {{0}, {2}}), devices)); + ExecuteOptions options; + options.fill_status = true; TF_ASSERT_OK_AND_ASSIGN( LoadedExecutable::ExecuteResult result, - loaded_exec->Execute(absl::MakeSpan(&input, 1), /*options=*/{}, + loaded_exec->Execute(absl::MakeSpan(&input, 1), options, /*devices=*/std::nullopt)); TF_ASSERT_OK(result.status.Await()); ASSERT_EQ(result.outputs.size(), 1); @@ -376,9 +388,11 @@ module { xla::ifrt::DType(xla::ifrt::DType::kS32), xla::ifrt::ShardingParam({2, 1}, {{0}, {2}}), devices)); + ExecuteOptions execute_options; + execute_options.fill_status = true; TF_ASSERT_OK_AND_ASSIGN( LoadedExecutable::ExecuteResult result, - loaded_exec->Execute(absl::MakeSpan(&input, 1), /*options=*/{}, + loaded_exec->Execute(absl::MakeSpan(&input, 1), execute_options, /*devices=*/std::nullopt)); TF_ASSERT_OK(result.status.Await()); diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt-opt.cc b/third_party/xla/xla/python/ifrt/ir/tests/ifrt-opt.cc index 78c88c3c779195..593e9737c5812d 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/ifrt-opt.cc +++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt-opt.cc @@ -13,23 +13,125 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include +#include + +#include +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" #include "mlir/IR/DialectRegistry.h" -#include "mlir/InitAllDialects.h" +#include "mlir/InitAllPasses.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "xla/mlir_hlo/mhlo/IR/register.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/hlo/hlo_program.h" +#include "xla/python/ifrt/ir/atom_program_compiler.h" #include "xla/python/ifrt/ir/ifrt_dialect.h" -#include "xla/python/ifrt/ir/transforms/built_in_spmd_expansions.h" #include "xla/python/ifrt/ir/transforms/passes.h" +#include "xla/python/ifrt/mock.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/ifrt/support/module_parsing.h" + +namespace xla { +namespace ifrt { +namespace { + +static constexpr int kMaxTestMethods = 1000; + +class TestChildExecutableCompiler : public AtomProgramCompiler { + public: + TestChildExecutableCompiler() { methods_.reserve(kMaxTestMethods); } + + absl::StatusOr CompileXla( + std::unique_ptr hlo_program, + xla::CompileOptions options) override ABSL_LOCKS_EXCLUDED(mu_) { + absl::MutexLock lock(&mu_); + methods_.push_back(absl::StrCat("fake_method_", methods_.size())); + CHECK_LT(methods_.size(), kMaxTestMethods) + << "push_back() might have caused reallocation, which might have " + "invalidated some method string_views."; + auto mock_executable = + std::make_unique>(); + int num_parameters_to_propagate = + options.executable_build_options + .allow_spmd_sharding_propagation_to_parameters() + .size(); + if (num_parameters_to_propagate > 0) { + xla::OpSharding op_sharding; + op_sharding.set_type(xla::OpSharding::REPLICATED); + std::vector parameter_shardings( + num_parameters_to_propagate, op_sharding); + ON_CALL(*mock_executable, GetParameterShardings()) + .WillByDefault(testing::Return(std::move(parameter_shardings))); + } + int num_outputs_to_propagate = + options.executable_build_options + .allow_spmd_sharding_propagation_to_output() + .size(); + if (num_outputs_to_propagate > 0) { + // Always infer output shardings to be replicated for the lit tests. + xla::OpSharding op_sharding; + op_sharding.set_type(xla::OpSharding::REPLICATED); + std::vector output_shardings(num_outputs_to_propagate, + op_sharding); + ON_CALL(*mock_executable, GetOutputShardings()) + .WillByDefault(testing::Return(std::move(output_shardings))); + } + return AtomProgramCompileResult{ + /*name=*/absl::StrCat("fake_component__", methods_.back()), + /*executable=*/std::move(mock_executable)}; + } + + absl::StatusOr CompileMpmdReshard( + std::vector dtypes, std::vector shapes, + std::vector in_array_types, + std::vector out_array_types) override + ABSL_LOCKS_EXCLUDED(mu_) { + absl::MutexLock lock(&mu_); + methods_.push_back(absl::StrCat("fake_method_", methods_.size())); + CHECK_LT(methods_.size(), kMaxTestMethods) + << "push_back() might have caused reallocation, which might have " + "invalidated some method string_views."; + auto mock_executable = + std::make_unique>(); + return AtomProgramCompileResult{ + /*name=*/absl::StrCat("fake_mpmd_reshard_component__", methods_.back()), + /*executable=*/std::make_unique()}; + } + + private: + absl::Mutex mu_; + std::vector methods_ ABSL_GUARDED_BY(mu_); +}; + +} // namespace +} // namespace ifrt +} // namespace xla int main(int argc, char** argv) { - mlir::DialectRegistry registry; - mlir::registerAllDialects(registry); - mlir::mhlo::registerAllMhloDialects(registry); - registry.insert(); - xla::ifrt::registerIfrtIrPasses(); + std::shared_ptr compiler = + std::make_shared(); + auto compile_options = std::make_shared>>(); + std::shared_ptr atom_executable_map = + std::make_shared(); + std::shared_ptr bound_executable_map = + std::make_shared(); - xla::ifrt::AttachBuiltInSpmdExpansions(registry); + mlir::registerAllPasses(); + xla::ifrt::RegisterIfrtPassesAndPipelines( + compiler, compile_options, atom_executable_map, bound_executable_map); + mlir::DialectRegistry registry; + xla::ifrt::support::InitializeMlirDialectRegistry(registry); return mlir::asMainReturnCode( - mlir::MlirOptMain(argc, argv, "IFRT dialect driver\n", registry)); + mlir::MlirOptMain(argc, argv, "IFRT IR dialect driver\n", registry)); } diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_compile_and_propagate_shardings.mlir b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_compile_and_propagate_shardings.mlir new file mode 100644 index 00000000000000..4021496168cb8c --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_compile_and_propagate_shardings.mlir @@ -0,0 +1,392 @@ +// RUN: ifrt-opt %s -ifrt-compile-and-propagate-shardings -split-input-file | FileCheck %s + +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> +!array_unspecified = !ifrt.array, + #ifrt.sharding_unspecified, [0, 1]> +// CHECK-LABEL: @propagate_to_next_call_op +module @propagate_to_next_call_op { + func.func @main( + %arg0: !array) -> !array_unspecified attributes {ifrt.function} { + // CHECK: %[[OUT_0:.+]], %{{.+}} = ifrt.CallLoadedExecutable @[[CALLEE_0:.+]](%arg0) + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + %0, %ctrl_0 = ifrt.Call @add_one_0::@main(%arg0) on devices [0, 1] + {ifrt.module_type = "xla"} : (!array) -> !array_unspecified + // CHECK: %[[OUT_1:.+]], %{{.+}} = ifrt.CallLoadedExecutable @[[CALLEE_1:.+]](%[[OUT_0]]) + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + %1, %ctrl_1 = ifrt.Call @add_one_1::@main(%0) on devices [0, 1] + {ifrt.module_type = "xla"} : (!array_unspecified) -> !array_unspecified + return %1 : !array_unspecified + } + + // CHECK: ifrt.LoadedExecutable @[[CALLEE_0]] + // CHECK-SAME: on devices [0, 1] + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + module @add_one_0 attributes {sym_visibility = "private"} { + func.func @main(%arg0: tensor<2x2xi32> { + mhlo.sharding = "{devices=[2,1]<=[2]}"}) -> (tensor<2x2xi32>) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + } + + // CHECK: ifrt.LoadedExecutable @[[CALLEE_1]] + // CHECK-SAME: on devices [0, 1] + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + module @add_one_1 attributes {sym_visibility = "private"} { + func.func @main(%arg0: tensor<2x2xi32>) -> (tensor<2x2xi32>) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + } +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> +!array_unspecified = !ifrt.array, + #ifrt.sharding_unspecified, [0, 1]> +// CHECK-LABEL: @verify_only_one_module_is_compiled +module @verify_only_one_module_is_compiled { + func.func @main(%arg0: !array) -> (!array_unspecified, !array_unspecified) + attributes {ifrt.function} { + // CHECK: %[[OUT_0:.+]], %{{.+}} = ifrt.CallLoadedExecutable @[[CALLEE_0:.+]](%arg0) + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + %0, %ctrl_0 = ifrt.Call @add_one_0::@main(%arg0) on devices [0, 1] + {ifrt.module_type = "xla"} : (!array) -> !array_unspecified + // CHECK: %[[OUT_1:.+]], %{{.+}} = ifrt.CallLoadedExecutable @[[CALLEE_0:.+]](%arg0) + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + %1, %ctrl_1 = ifrt.Call @add_one_0::@main(%arg0) on devices [0, 1] + : (!array) -> !array_unspecified + return %0, %1 : !array_unspecified, !array_unspecified + } + + // CHECK: ifrt.LoadedExecutable @[[CALLEE_0]] + // CHECK-SAME: on devices [0, 1] + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + module @add_one_0 attributes {sym_visibility = "private"} { + func.func @main(%arg0: tensor<2x2xi32> { + mhlo.sharding = "{devices=[2,1]<=[2]}"}) -> (tensor<2x2xi32>) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + } +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> +!array_unspecified = !ifrt.array, + #ifrt.sharding_unspecified, [0, 1]> +// CHECK-LABEL: @propagate_to_reshard +module @propagate_to_reshard { + func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { + // CHECK: %[[OUT:.+]], %{{.+}} = ifrt.CallLoadedExecutable @[[CALLEE:.+]](%arg0) + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + %0, %ctrl_0 = ifrt.Call @add_one::@main(%arg0) on devices [0, 1] + {ifrt.module_type = "xla"} : (!array) -> !array_unspecified + // CHECK: %[[OUT_RESHARD:.+]], %{{.+}} = ifrt.Reshard(%[[OUT]]) + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + %1, %ctrl_1 = ifrt.Reshard(%0) : (!array_unspecified) -> !array + return %1 : !array + } + + // CHECK: ifrt.LoadedExecutable @[[CALLEE]] + // CHECK-SAME: on devices [0, 1] + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + module @add_one attributes {sym_visibility = "private"} { + func.func @main(%arg0: tensor<2x2xi32> { + mhlo.sharding = "{devices=[2,1]<=[2]}"}) -> (tensor<2x2xi32>) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + } +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> +!array_unspecified = !ifrt.array, + #ifrt.sharding_unspecified, [0, 1]> +// CHECK-LABEL: @propagate_to_two_call_op +module @propagate_to_two_call_op { + func.func @main(%arg0: !array) -> (!array_unspecified, !array) + attributes {ifrt.function} { + // CHECK: %[[OUT_0:.+]], %{{.+}} = ifrt.CallLoadedExecutable @[[CALLEE_0:.+]](%arg0) + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + %0, %ctrl_0 = ifrt.Call @add_one_0::@main(%arg0) on devices [0, 1] + {ifrt.module_type = "xla"} : (!array) -> !array_unspecified + // CHECK: %[[OUT_1:.+]], %{{.+}} = ifrt.CallLoadedExecutable @[[CALLEE_1:.+]](%[[OUT_0]]) + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + %1, %ctrl_1 = ifrt.Call @add_one_1::@main(%0) on devices [0, 1] + {ifrt.module_type = "xla"} : (!array_unspecified) -> !array_unspecified + // CHECK: %[[OUT_2:.+]], %{{.+}} = ifrt.CallLoadedExecutable @[[CALLEE_2:.+]](%[[OUT_0]]) + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + %2, %ctrl_2 = ifrt.Call @add_one_2::@main(%0) on devices [0, 1] + {ifrt.module_type = "xla"} : (!array_unspecified) -> !array + return %1, %2 : !array_unspecified, !array + } + + // CHECK: ifrt.LoadedExecutable @[[CALLEE_0]] + // CHECK-SAME: on devices [0, 1] + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + module @add_one_0 attributes {sym_visibility = "private"} { + func.func @main(%arg0: tensor<2x2xi32> { + mhlo.sharding = "{devices=[2,1]<=[2]}"}) -> (tensor<2x2xi32>) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + } + + // CHECK: ifrt.LoadedExecutable @[[CALLEE_1]] + // CHECK-SAME: on devices [0, 1] + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + module @add_one_1 attributes {sym_visibility = "private"} { + func.func @main(%arg0: tensor<2x2xi32>) -> (tensor<2x2xi32>) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + } + + // CHECK: ifrt.LoadedExecutable @[[CALLEE_2]] + // CHECK-SAME: on devices [0, 1] + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + module @add_one_2 attributes {sym_visibility = "private"} { + func.func @main(%arg0: tensor<2x2xi32>) -> (tensor<2x2xi32> { + mhlo.sharding = "{devices=[2,1]<=[2]}"}) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + } +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> +!array_unspecified = !ifrt.array, + #ifrt.sharding_unspecified, [0, 1]> +// CHECK-LABEL: @propagate_from_two_call_op +module @propagate_from_two_call_op { + func.func @main(%arg0: !array) -> (!array_unspecified, !array) + attributes {ifrt.function} { + // CHECK: %[[OUT_0:.+]], %{{.+}} = ifrt.CallLoadedExecutable @[[CALLEE_0:.+]](%arg0) + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + %0, %ctrl_0 = ifrt.Call @add_one_0::@main(%arg0) on devices [0, 1] + {ifrt.module_type = "xla"} : (!array) -> !array_unspecified + // CHECK: %[[OUT_1:.+]], %{{.+}} = ifrt.CallLoadedExecutable @[[CALLEE_1:.+]](%[[OUT_0]]) + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + %1, %ctrl_1 = ifrt.Call @add_one_1::@main(%0) on devices [0, 1] + {ifrt.module_type = "xla"} : (!array_unspecified) -> !array_unspecified + // CHECK: %[[OUT_2:.+]], %{{.+}} = ifrt.CallLoadedExecutable @[[CALLEE_2:.+]](%[[OUT_0]], %[[OUT_1]]) + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + %2, %ctrl_2 = ifrt.Call @add_args::@main(%0, %1) on devices [0, 1] + {ifrt.module_type = "xla"} : (!array_unspecified, !array_unspecified) -> !array + return %1, %2 : !array_unspecified, !array + } + + // CHECK: ifrt.LoadedExecutable @[[CALLEE_0]] + // CHECK-SAME: on devices [0, 1] + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + module @add_one_0 attributes {sym_visibility = "private"} { + func.func @main(%arg0: tensor<2x2xi32> { + mhlo.sharding = "{devices=[2,1]<=[2]}"}) -> (tensor<2x2xi32>) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + } + + // CHECK: ifrt.LoadedExecutable @[[CALLEE_1]] + // CHECK-SAME: on devices [0, 1] + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + module @add_one_1 attributes {sym_visibility = "private"} { + func.func @main(%arg0: tensor<2x2xi32>) -> (tensor<2x2xi32>) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + } + + // CHECK: ifrt.LoadedExecutable @[[CALLEE_2]] + // CHECK-SAME: on devices [0, 1] + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + module @add_args attributes {sym_visibility = "private"} { + func.func @main(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) + -> (tensor<2x2xi32> {mhlo.sharding = "{devices=[2,1]<=[2]}"}) { + %0 = mhlo.add %arg0, %arg1 : tensor<2x2xi32> + return %0 : tensor<2x2xi32> + } + } +} + +// ----- + +!array_unspecified = !ifrt.array, + #ifrt.sharding_unspecified, [0, 1]> +// CHECK-LABEL: @propagate_to_inputs +module @propagate_to_inputs { + func.func @main(%arg0: !array_unspecified) + -> !array_unspecified attributes {ifrt.function} { + // CHECK: %[[OUT:.+]], %{{.+}} = ifrt.CallLoadedExecutable @[[CALLEE:.+]](%arg0) + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + %0, %ctrl_0 = ifrt.Call @add_one_0::@main(%arg0) on devices [0, 1] + {ifrt.module_type = "xla"} : (!array_unspecified) -> !array_unspecified + return %0 : !array_unspecified + } + + // CHECK: ifrt.LoadedExecutable @[[CALLEE]] + // CHECK-SAME: on devices [0, 1] + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + module @add_one_0 attributes {sym_visibility = "private"} { + func.func @main(%arg0: tensor<2x2xi32>) -> (tensor<2x2xi32>) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + } + +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> +!array_unspecified = !ifrt.array, + #ifrt.sharding_unspecified, [0, 1]> +// CHECK-LABEL: @propagate_from_reshard +module @propagate_from_reshard { + func.func @main(%arg0: !array_unspecified) + -> (!array, !array_unspecified) attributes {ifrt.function} { + // CHECK: %[[OUT_RESHARD:.+]], %{{.+}} = ifrt.Reshard(%arg0) + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + %0, %ctrl_0 = ifrt.Reshard(%arg0) : (!array_unspecified) -> !array + // CHECK: %[[OUT:.+]], %{{.+}} = ifrt.CallLoadedExecutable @[[CALLEE:.+]](%arg0) + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + %1, %ctrl_1 = ifrt.Call @add_one_0::@main(%arg0) on devices [0, 1] + {ifrt.module_type = "xla"} : (!array_unspecified) -> !array_unspecified + return %0, %1 : !array, !array_unspecified + } + + // CHECK: ifrt.LoadedExecutable @[[CALLEE]] + // CHECK-SAME: on devices [0, 1] + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + module @add_one_0 attributes {sym_visibility = "private"} { + func.func @main(%arg0: tensor<2x2xi32>) -> (tensor<2x2xi32>) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + } +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> +!array_unspecified = !ifrt.array, + #ifrt.sharding_unspecified, [0, 1]> +// CHECK-LABEL: @propagate_from_copy_arrays +module @propagate_from_copy_arrays { + func.func @main(%arg0: !array_unspecified) + -> (!array, !array_unspecified) attributes {ifrt.function} { + // CHECK: %[[OUT_COPY:.+]], %{{.+}} = ifrt.CopyArrays(%arg0) + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + %0, %ctrl_0 = ifrt.CopyArrays(%arg0) : (!array_unspecified) -> !array + // CHECK: %[[OUT:.+]], %{{.+}} = ifrt.CallLoadedExecutable @[[CALLEE:.+]](%arg0) + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + %1, %ctrl_1 = ifrt.Call @add_one_0::@main(%arg0) on devices [0, 1] + {ifrt.module_type = "xla"} : (!array_unspecified) -> !array_unspecified + return %0, %1 : !array, !array_unspecified + } + + // CHECK: ifrt.LoadedExecutable @[[CALLEE]] + // CHECK-SAME: on devices [0, 1] + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + module @add_one_0 attributes {sym_visibility = "private"} { + func.func @main(%arg0: tensor<2x2xi32>) -> (tensor<2x2xi32>) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + } +} + +// ----- + +!array0 = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 2>, [2, 3]> +!array_unspecified = !ifrt.array, + #ifrt.sharding_unspecified, [0, 1]> +// CHECK-LABEL: @propagate_to_copy_arrays +module @propagate_to_copy_arrays { + func.func @main(%arg0: !array0) -> !array1 attributes {ifrt.function} { + // CHECK: %[[OUT:.+]], %{{.+}} = ifrt.CallLoadedExecutable @[[CALLEE:.+]](%arg0) + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + %0, %ctrl_0 = ifrt.Call @add_one::@main(%arg0) on devices [0, 1] + {ifrt.module_type = "xla"} : (!array0) -> !array_unspecified + // CHECK: %[[OUT_COPY:.+]], %{{.+}} = ifrt.CopyArrays(%[[OUT]]) + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [2, 3]> + %1, %ctrl_1 = ifrt.CopyArrays(%0) : (!array_unspecified) -> !array1 + return %1 : !array1 + } + + // CHECK: ifrt.LoadedExecutable @[[CALLEE]] + // CHECK-SAME: on devices [0, 1] + // CHECK-SAME: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]> + module @add_one attributes {sym_visibility = "private"} { + func.func @main(%arg0: tensor<2x2xi32> {mhlo.sharding = "{replicated}"}) + -> (tensor<2x2xi32>) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + } +} diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_compile_atom_program.mlir b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_compile_atom_program.mlir new file mode 100644 index 00000000000000..b99e0f9a43b79e --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_compile_atom_program.mlir @@ -0,0 +1,27 @@ +// RUN: ifrt-opt %s -ifrt-compile-atom-program -split-input-file | FileCheck %s + +// CHECK-LABEL: @call_hlo +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +module @call_hlo { + func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { + // CHECK: ifrt.CallLoadedExecutable @fake_component__fake_method + %0, %ctrl_0 = ifrt.Call @add_one::@main(%arg0) on devices [0,1] + {ifrt.module_type = "xla"} : (!array) -> !array + return %0 : !array + } + + // CHECK: ifrt.LoadedExecutable @fake_component__fake_method + // CHECK-SAME: on devices [0, 1] + // CHECK: (!ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>) + // CHECK-SAME: -> !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + module @add_one attributes {sym_visibility = "private"} { + func.func private @main( + %arg0: tensor<2x2xi32> {mhlo.sharding = "{devices=[2,1]<=[2]}"}) + -> (tensor<2x2xi32> {mhlo.sharding = "{devices=[2,1]<=[2]}"}) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + } +} diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_lower_mpmd_reshard_to_call.mlir b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_lower_mpmd_reshard_to_call.mlir new file mode 100644 index 00000000000000..fe3331a9fa306e --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_lower_mpmd_reshard_to_call.mlir @@ -0,0 +1,227 @@ +// RUN: ifrt-opt %s -ifrt-lower-mpmd-reshard-to-call -split-input-file -verify-diagnostics | FileCheck %s + +!array0 = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [2]> +// CHECK-LABEL: @reshard_without_donation +module @reshard_without_donation { + func.func public @main(%arg0: !array0) -> (!array1) + attributes {ifrt.function} { + // CHECK: ifrt.Call @reshard_4784300543980450571::@main(%arg0) on devices [0, 1, 2] {ifrt.module_type = "mpmd_reshard"} + %0, %ctrl_0 = ifrt.Reshard(%arg0) : (!array0) -> !array1 + return %0 : !array1 + } + + // CHECK: module @reshard_4784300543980450571 + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 3 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func @main( + // CHECK: %arg0: !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + // CHECK: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 1>, [2]> +} + +// ----- + +!array0 = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [2]> +// CHECK-LABEL: @reshard_with_donation +module @reshard_with_donation { + func.func public @main(%arg0: !array0) -> (!array1) + attributes {ifrt.function} { + // CHECK: ifrt.Call @reshard_4784300543980450571::@main(%arg0) on devices [0, 1, 2] + // CHECK-SAME: { + // CHECK-DAG: ifrt.module_type = "mpmd_reshard" + // CHECK-DAG: donated_input_indices = array + // CHECK-SAME: } + %0, %ctrl_0 = ifrt.Reshard(%arg0) {donated=true} : (!array0) -> !array1 + return %0 : !array1 + } + + // CHECK: module @reshard_4784300543980450571 + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 3 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func @main( + // CHECK: %arg0: !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + // CHECK: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 1>, [2]> +} + +// ----- + +!array0 = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [2, 3]> +// ifrt.Reshard does not need to be converted to a MPMD reshard ifrt.Call +// because the reshard is a 1:1 buffer copy between devices. +module @reshard_is_not_converted_to_call { + func.func public @main(%arg0: !array0) -> (!array1) + attributes {ifrt.function} { + // expected-error@+1 {{'ifrt.Reshard' op does not reshard any arrays. Use CopyArraysOp instead}} + %0, %ctrl_0 = ifrt.Reshard(%arg0) : (!array0) -> !array1 + return %0 : !array1 + } +} + +// ----- + +!array0 = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [2]> +// CHECK-LABEL: @reshard_after_call_to_module +module @reshard_after_call_to_module { + func.func public @main(%arg0: !array0) -> (!array1) + attributes {ifrt.function} { + // CHECK: %[[OUT_1:.*]], %[[CTRL_OUT:.*]] = ifrt.Call @add_one + %0, %ctrl_0 = ifrt.Call @add_one(%arg0) on devices [0, 1] + : (!array0) -> !array0 + // CHECK: %[[OUT_2:.*]], %{{.+}} = ifrt.Call @reshard_4784300543980450571::@main(%[[OUT_1]]) after %[[CTRL_OUT]] + // CHECK: {ifrt.module_type = "mpmd_reshard"} + %1, %ctrl_1 = ifrt.Reshard(%0) after %ctrl_0 : (!array0) -> !array1 + return %1 : !array1 + } + + func.func private @add_one(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + + // CHECK: module @reshard_4784300543980450571 + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 3 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func @main( + // CHECK: %arg0: !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + // CHECK: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 1>, [2]> +} + +// ----- + +!array0 = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [2]> +// CHECK-LABEL: @reshard_before_call_to_module +module @reshard_before_call_to_module { + func.func public @main(%arg0: !array0) -> (!array1) + attributes {ifrt.function} { + // CHECK: ifrt.Call @reshard_4784300543980450571::@main(%arg0) on devices [0, 1, 2] {ifrt.module_type = "mpmd_reshard"} + %0, %ctrl_0 = ifrt.Reshard(%arg0) : (!array0) -> !array1 + // CHECK: %[[OUT:.*]], %[[CTRL_OUT:.*]] = ifrt.Call @add_one + %1, %ctrl_1 = ifrt.Call @add_one(%0) on devices [2] + : (!array1) -> !array1 + return %1 : !array1 + } + + func.func private @add_one(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + + // CHECK: module @reshard_4784300543980450571 + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 3 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func @main( + // CHECK: %arg0: !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + // CHECK: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 1>, [2]> +} + +// ----- + +!array0 = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [2]> +// CHECK-LABEL: @two_identical_reshards_single_module +module @two_identical_reshards_single_module { + func.func public @main(%arg0: !array0, %arg1: !array0) -> (!array1, !array1) + attributes {ifrt.function} { + // CHECK: ifrt.Call @reshard_4784300543980450571::@main(%arg0) on devices [0, 1, 2] {ifrt.module_type = "mpmd_reshard"} + %0, %ctrl_0 = ifrt.Reshard(%arg0) : (!array0) -> !array1 + // CHECK: ifrt.Call @reshard_4784300543980450571::@main(%arg1) on devices [0, 1, 2] {ifrt.module_type = "mpmd_reshard"} + %1, %ctrl_1 = ifrt.Reshard(%arg1) : (!array0) -> !array1 + return %0, %1 : !array1, !array1 + } + + // CHECK: module @reshard_4784300543980450571 + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 3 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func @main( + // CHECK: %arg0: !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + // CHECK: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 1>, [2]> +} + +// ----- + +!array0 = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [2]> +// CHECK-LABEL: @two_reshards_two_modules +module @two_reshards_two_modules { + func.func public @main(%arg0: !array0) -> (!array0) + attributes {ifrt.function} { + // CHECK: %[[OUT:.+]], %{{.+}} = ifrt.Call @reshard_4784300543980450571::@main(%arg0) on devices [0, 1, 2] {ifrt.module_type = "mpmd_reshard"} + %0, %ctrl_0 = ifrt.Reshard(%arg0) : (!array0) -> !array1 + // CHECK: ifrt.Call @reshard_17322361279023763284::@main(%[[OUT]]) on devices [0, 1, 2] {ifrt.module_type = "mpmd_reshard"} + %1, %ctrl_1 = ifrt.Reshard(%0) : (!array1) -> !array0 + return %1 : !array0 + } + + // CHECK: module @reshard_4784300543980450571 + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 3 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func @main( + // CHECK: %arg0: !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> + // CHECK: -> !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 1>, [2]> + + // CHECK: module @reshard_17322361279023763284 + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 3 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func @main( + // CHECK: %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 1>, [2]> + // CHECK: -> !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> +} + +// ----- + +!array0 = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [0]> +!array1 = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> +// Tests if the module for the MPMD reshard has unique devices. +// CHECK-LABEL: @check_reshard_module_has_unique_devices +module @check_reshard_module_has_unique_devices { + func.func @main(%arg0: !array0) -> !array1 attributes {ifrt.function} { + // CHECK: ifrt.Call @reshard_6746659470058475136::@main(%arg0) on devices [0, 1] {ifrt.module_type = "mpmd_reshard"} + %0, %ctrl_0 = ifrt.Reshard(%arg0) : (!array0) -> !array1 + return %0 : !array1 + } + + // CHECK: module @reshard_6746659470058475136 + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 2 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func @main( + // CHECK: %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 1>, [0]> + // CHECK: -> !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> +} diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_lower_sharding_to_xla.mlir b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_lower_sharding_to_xla.mlir new file mode 100644 index 00000000000000..8ae74218938344 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_lower_sharding_to_xla.mlir @@ -0,0 +1,131 @@ +// RUN: ifrt-opt %s -ifrt-lower-sharding-to-xla -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: @arg_sharding +module @arg_sharding attributes {ifrt.num_devices = 2} { + // CHECK: %arg0: tensor<2x2xi32> + // CHECK-SAME: { + // CHECK-DAG: mhlo.sharding = "{devices=[2,1]<=[2]}" + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-SAME: } + // CHECK: %arg1: tensor<2x2xi32> + // CHECK-SAME: { + // CHECK-DAG: mhlo.sharding = "{replicated}" + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<1x1 to [0] on 2> + // CHECK-SAME: } + func.func @main( + %arg0: tensor<2x2xi32> { + ifrt.sharding=#ifrt.sharding_param<2x1 to [0] on 2>}, + %arg1: tensor<2x2xi32> { + ifrt.sharding=#ifrt.sharding_param<1x1 to [0] on 2>}) { + return + } +} + +// ----- + +// CHECK-LABEL: @arg_unspecified_sharding +module @arg_unspecified_sharding attributes {ifrt.num_devices = 2} { + // CHECK: %arg0: tensor<2x2xi32> + // CHECK-SAME: { + // CHECK-DAG: mhlo.sharding = "{devices=[2,1]<=[2]}" + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-SAME: } + // CHECK: %arg1: tensor<2x2xi32> {ifrt.sharding = #ifrt.sharding_unspecified}) + func.func @main( + %arg0: tensor<2x2xi32> { + ifrt.sharding=#ifrt.sharding_param<2x1 to [0] on 2>}, + %arg1: tensor<2x2xi32> {ifrt.sharding=#ifrt.sharding_unspecified}) { + return + } +} + +// ----- + +// CHECK-LABEL: @result_sharding +module @result_sharding attributes {ifrt.num_devices = 2} { + // CHECK: -> (tensor<2x2xi32> + // CHECK-SAME: { + // CHECK-DAG: mhlo.sharding = "{devices=[2,1]<=[2]}" + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-SAME: } + func.func @main() + -> (tensor<2x2xi32> { + ifrt.sharding=#ifrt.sharding_param<2x1 to [0] on 2>}) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + return %0 : tensor<2x2xi32> + } +} + +// ----- + +// CHECK-LABEL: @result_unspecified_sharding +module @result_unspecified_sharding attributes {ifrt.num_devices = 2} { + // CHECK: -> (tensor<2x2xi32> + // CHECK-SAME: { + // CHECK-DAG: mhlo.sharding = "{devices=[2,1]<=[2]}" + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-SAME: } + // CHECK: tensor<2x2xi32> {ifrt.sharding = #ifrt.sharding_unspecified}) + func.func @main() + -> (tensor<2x2xi32> { + ifrt.sharding=#ifrt.sharding_param<2x1 to [0] on 2>}, + tensor<2x2xi32> {ifrt.sharding=#ifrt.sharding_unspecified}) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + return %0, %0 : tensor<2x2xi32>, tensor<2x2xi32> + } +} + + +// ----- + +module @arg_missing_sharding attributes {ifrt.num_devices = 2} { + // expected-error @+1 {{'func.func' op can't find `ifrt.sharding` attribute of input #0 to set `mhlo.sharding` attribute}} + func.func @main(%arg0: tensor<2x2xi32>) { + return + } +} + +// ----- + +module @result_missing_sharding attributes {ifrt.num_devices = 2} { + // expected-error @+1 {{'func.func' op can't find `ifrt.sharding` attribute of output #0 to set `mhlo.sharding` attribute}} + func.func @main() -> (tensor<2x2xi32>) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + return %0 : tensor<2x2xi32> + } +} + +// ----- + +// expected-error @+1 {{'builtin.module' op module `module_missing_devices` must have `ifrt.num_devices` attribute}} +module @module_missing_devices { + func.func @main() -> (tensor<2x2xi32> + {ifrt.sharding=#ifrt.sharding_param<2x1 to [0] on 2>, + ifrt.devices=#ifrt}) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + return %0 : tensor<2x2xi32> + } +} + +// ----- + +// expected-error @+2 {{'func.func' op can't lower sharding of input #0. Sharding: #ifrt.sharding_param<1x1 to [0] on 1> uses 1 devices while computation uses 2 devices}} +module @arg_w_different_num_devices attributes {ifrt.num_devices = 2} { + func.func @main( + %arg0: tensor<2x2xi32> { + ifrt.sharding=#ifrt.sharding_param<1x1 to [0] on 1>}) { + return + } +} + +// ----- + +// expected-error @+2 {{'func.func' op can't lower sharding of output #0. Sharding: #ifrt.sharding_param<2x1 to [0] on 2> uses 2 devices while computation uses 4 devices}} +module @res_w_different_num_devices attributes {ifrt.num_devices = 4} { + func.func @main() + -> (tensor<2x2xi32> { + ifrt.sharding=#ifrt.sharding_param<2x1 to [0] on 2>}) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + return %0 : tensor<2x2xi32> + } +} diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_populate_atom_program_metadata.mlir b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_populate_atom_program_metadata.mlir new file mode 100644 index 00000000000000..18256e356a1acd --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_populate_atom_program_metadata.mlir @@ -0,0 +1,281 @@ +// RUN: ifrt-opt %s -ifrt-populate-atom-program-metadata -ifrt-duplicated-callee-elimination -symbol-dce -split-input-file | FileCheck %s + +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +// CHECK-LABEL: @populate_arg_sharding +module @populate_arg_sharding { + func.func @main(%arg0: !array) attributes {ifrt.function} { + // CHECK: ifrt.Call @[[CALLEE:.+]]::@main(%arg0) + %ctrl_0 = ifrt.Call @callee::@main(%arg0) on devices [0,1] : (!array) -> () + return + } + + // CHECK: module @[[CALLEE]] + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 2 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func private @main + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-NOT: ifrt + module @callee attributes {sym_visibility = "private"} { + func.func private @main(%arg0: tensor<2x2xi32>) { + return + } + } +} + +// ----- + +// CHECK-LABEL: @populate_result_sharding +module @populate_result_sharding { + func.func @main() attributes {ifrt.function} { + // CHECK: ifrt.Call @[[CALLEE:.+]]::@main() + %0, %ctrl_0 = ifrt.Call @callee::@main() on devices [0,1] + : () -> (!ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [1,0]>) + return + } + + // CHECK: module @[[CALLEE]] + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 2 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func private @main + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-NOT: ifrt + module @callee attributes {sym_visibility = "private"} { + func.func private @main() -> tensor<2x2xi32> { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + return %0 : tensor<2x2xi32> + } + } +} + +// ----- + +// Verifies that a single module is populated with metadata if the input and +// output types are the same. +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +// CHECK-LABEL: @calls_outlined_to_single_module +module @calls_outlined_to_single_module { + func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { + // CHECK: %[[OUT_0:.+]], %{{.+}} = ifrt.Call @[[CALLEE:.+]]::@main(%arg0) + %0, %ctrl_0 = ifrt.Call @add_one::@main(%arg0) on devices [0,1] + : (!array) -> !array + // CHECK: %[[OUT_1:.+]], %[[CTRL_1:.+]] = ifrt.Call @[[CALLEE]]::@main(%[[OUT_0]]) + %1, %ctrl_1 = ifrt.Call @add_one::@main(%0) on devices [0,1] + : (!array) -> !array + // CHECK: ifrt.Call @[[CALLEE]]::@main(%[[OUT_1]]) after %[[CTRL_1]] + %2, %ctrl_2 = ifrt.Call @add_one::@main(%1) after %ctrl_1 on devices [0,1] + : (!array) -> !array + return %1 : !array + } + + // CHECK: module @[[CALLEE]] + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 2 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func private @main + // CHECK-SAME: %arg0: tensor<2x2xi32> + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-SAME: -> (tensor<2x2xi32> + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-NOT: ifrt + module @add_one attributes {sym_visibility = "private"} { + func.func private @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + } +} + +// ----- + +// Verifies that a single module is populated with metadata even if the +// devices are different. +!array0 = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [2,3]> +// CHECK-LABEL: @calls_on_different_devices_outlined_to_single_module +module @calls_on_different_devices_outlined_to_single_module { + func.func @main(%arg0: !array0) -> !array1 attributes {ifrt.function} { + // CHECK: %[[OUT_0:.+]], %{{.+}} = ifrt.Call @[[CALLEE:.+]]::@main(%arg0) + %0, %ctrl_0 = ifrt.Call @add_one::@main(%arg0) on devices [0,1] + : (!array0) -> !array0 + // CHECK: %[[OUT_1:.+]], %{{.+}} = ifrt.CopyArrays(%[[OUT_0]]) + %1, %ctrl_1 = ifrt.CopyArrays(%0) : (!array0) -> (!array1) + // CHECK: %[[OUT_2:.+]], %[[CTRL_2:.+]] = ifrt.Call @[[CALLEE]]::@main(%[[OUT_1]]) + %2, %ctrl_2 = ifrt.Call @add_one::@main(%1) on devices [2,3] + : (!array1) -> !array1 + // CHECK: ifrt.Call @[[CALLEE]]::@main(%[[OUT_2]]) after %[[CTRL_2]] + %3, %ctrl_3 = ifrt.Call @add_one::@main(%2) after %ctrl_2 on devices [2,3] + : (!array1) -> !array1 + return %3 : !array1 + } + + // CHECK: module @[[CALLEE]] + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 2 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func private @main + // CHECK-SAME: %arg0: tensor<2x2xi32> + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-SAME: -> (tensor<2x2xi32> + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-NOT: ifrt + module @add_one attributes {sym_visibility = "private"} { + func.func private @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + } +} + +// ----- + +// CHECK-LABEL: @call_twice_with_different_sharding +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +!array_unspecified = !ifrt.array, + #ifrt.sharding_unspecified, [0,1]> +module @call_twice_with_different_sharding { + func.func @main(%arg0: !array) -> !array_unspecified + attributes {ifrt.function} { + // CHECK: %[[OUTPUT:.+]], %{{.+}} = ifrt.Call @[[CALLEE_0:.+]]::@main(%arg0) + %0, %ctrl_0 = ifrt.Call @add_one::@main(%arg0) on devices [0,1] + : (!array) -> !array + // CHECK: ifrt.Call @[[CALLEE_1:.+]]::@main(%[[OUTPUT]]) + %1, %ctrl_1 = ifrt.Call @add_one::@main(%0) on devices [0,1] + : (!array) -> !array_unspecified + return %1 : !array_unspecified + } + + // CHECK: module @[[CALLEE_0]] + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 2 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func private @main(%arg0: tensor<2x2xi32> + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-SAME: -> (tensor<2x2xi32> + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + + // CHECK: module @[[CALLEE_1]] + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 2 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func private @main(%arg0: tensor<2x2xi32> + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-SAME: -> (tensor<2x2xi32> + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_unspecified + // CHECK-NOT: ifrt + module @add_one attributes {sym_visibility = "private"} { + func.func private @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + } +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +// CHECK-LABEL: @populate_io_alias_and_donation +module @populate_io_alias_and_donation { + func.func @main(%arg0: !array, %arg1: !array) attributes {ifrt.function} { + // CHECK: ifrt.Call @[[CALLEE_0:.+]]::@main(%arg0, %arg1) + %0, %ctrl_0 = ifrt.Call @callee::@main(%arg0, %arg1) on devices [0,1] + {io_aliases=[array], donated_input_indices=array} + : (!array, !array) -> !array + // Verify that the module is cloned if io_aliases differ. + // CHECK: ifrt.Call @[[CALLEE_1:.+]]::@main(%arg0, %arg1) + %1, %ctrl_1 = ifrt.Call @callee::@main(%arg0, %arg1) on devices [0,1] + : (!array, !array) -> !array + return + } + + // CHECK: module @[[CALLEE_0]] + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 2 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func private @main(%arg0: tensor<2x2xi32> + // CHECK-SAME: { + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-DAG: tf.aliasing_output = 0 : i32 + // CHECK-SAME: } + // CHECK: %arg1: tensor<2x2xi32> + // CHECK-SAME: { + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-DAG: jax.buffer_donor = true + // CHECK-SAME: } + + // CHECK: module @[[CALLEE_1]] + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 2 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func private @main(%arg0: tensor<2x2xi32> + // CHECK-SAME: { + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-SAME: } + module @callee attributes {sym_visibility = "private"} { + func.func private @main(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) + -> tensor<2x2xi32> { + return %arg0: tensor<2x2xi32> + } + } +} + +// ----- + +!shared_array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +// CHECK-LABEL: @output_of_call_donated +module @output_of_call_donated { + func.func @main(%arg0: !shared_array) -> !shared_array + attributes {ifrt.function} { + // CHECK: %[[OUT_0:.+]], %{{.+}} = ifrt.Call @[[CALLEE_0:.+]]::@main(%arg0) on devices [0, 1] : + %0, %ctrl_0 = ifrt.Call @add_one::@main(%arg0) on devices [0,1] + : (!shared_array) -> !shared_array + // CHECK: %[[OUT_1:.+]], %{{.+}} = ifrt.Call @[[CALLEE_1:.+]]::@main(%[[OUT_0]]) on devices [0, 1] {io_aliases = [array]} : + %1, %ctrl_1 = ifrt.Call @add_one::@main(%0) on devices [0,1] + {io_aliases=[array]} : (!shared_array) -> !shared_array + return %1 : !shared_array + } + + // CHECK: module @[[CALLEE_0]] + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 2 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func private @main + // CHECK-SAME: %arg0: tensor<2x2xi32> + + // CHECK: module @[[CALLEE_1]] + // CHECK-SAME: attributes { + // CHECK-DAG: ifrt.num_devices = 2 + // CHECK-DAG: sym_visibility = "private" + // CHECK-SAME: } + // CHECK: func.func private @main + // CHECK-SAME: %arg0: tensor<2x2xi32> + // CHECK-SAME: tf.aliasing_output = 0 : i32 + module @add_one attributes {sym_visibility = "private"} { + func.func private @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + } +} diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_remove_ifrt_attrs.mlir b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_remove_ifrt_attrs.mlir new file mode 100644 index 00000000000000..5ea0c863c21b7d --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_remove_ifrt_attrs.mlir @@ -0,0 +1,14 @@ +// RUN: ifrt-opt %s -ifrt-remove-ifrt-attrs | FileCheck %s + +// CHECK-LABEL: @ifrt_attributes_are_removed +// CHECK-NOT: ifrt +module @ifrt_attributes_are_removed attributes {ifrt.num_devices = 2} { + func.func @main( + %arg0: tensor<2x2xi32> { + ifrt.sharding = #ifrt.sharding_param<1x1 to [0] on 1>}) + -> (tensor<2x2xi32> { + ifrt.sharding = #ifrt.sharding_param<1x1 to [0] on 1>}) + attributes {ifrt.devices = #ifrt} { + return %arg0 : tensor<2x2xi32> + } +} diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_reshard_to_copy_arrays.mlir b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_reshard_to_copy_arrays.mlir new file mode 100644 index 00000000000000..cf8231b67050a4 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_reshard_to_copy_arrays.mlir @@ -0,0 +1,106 @@ +// RUN: ifrt-opt %s -ifrt-reshard-to-copy-arrays -verify-diagnostics -split-input-file | FileCheck %s + +!array0 = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [0,1]> +!array1 = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [2,3]> +// CHECK-LABEL: @reshard_to_copy_arrays +module @reshard_to_copy_arrays { + func.func @main(%arg0: !array0) -> !array1 attributes {ifrt.function} { + // CHECK: %[[COPIED:.+]], %{{.+}} = ifrt.CopyArrays(%arg0) + %0, %ctrl_0 = ifrt.Reshard(%arg0) : (!array0) -> !array1 + // CHECK: return %[[COPIED]] + return %0 : !array1 + } +} + +// ----- + +!array0 = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [0,1]> +!array1 = !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [2,3]> +// CHECK-LABEL: @reshard_not_converted +module @reshard_not_converted { + func.func @main(%arg0: !array0) -> !array1 attributes {ifrt.function} { + // CHECK: %[[RESHARDED:.+]], %{{.+}} = ifrt.Reshard(%arg0) + %0, %ctrl_0 = ifrt.Reshard(%arg0) : (!array0) -> !array1 + // CHECK: return %[[RESHARDED]] + return %0 : !array1 + } +} + +// ----- + +!array0 = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [0,1]> +!array1 = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [2,3]> +!array2 = !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [2,3]> +// CHECK-LABEL: @extract_copy_from_reshard +module @extract_copy_from_reshard { + func.func @main(%arg0: !array0, %arg1: !array1) -> (!array1, !array2) + attributes {ifrt.function} { + // CHECK: %[[RESHARDED:.+]], %{{.+}} = ifrt.Reshard(%arg1) {donated = true} + // CHECK: %[[COPIED:.+]], %{{.+}} = ifrt.CopyArrays(%arg0) {donated = true} + %0, %1, %ctrl_0 = ifrt.Reshard(%arg0, %arg1) {donated = true} + : (!array0, !array1) -> (!array1, !array2) + // CHECK: return %[[COPIED]], %[[RESHARDED]] + return %0, %1: !array1, !array2 + } +} + +// ----- + +// Verifies that an ifrt.CopyArrays is introduced for each set of devices. +!array0 = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [0,1]> +!array1 = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [2,3]> +!array2 = !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [0,1]> +// CHECK-LABEL: @extract_copy_per_device_set +module @extract_copy_per_device_set { + func.func @main(%arg0: !array0, %arg1: !array1, %arg2: !array1) + -> (!array1, !array2, !array0) attributes {ifrt.function} { + // CHECK: %[[RESHARDED:.+]], %{{.+}} = ifrt.Reshard(%arg1) + // CHECK-DAG: %[[COPIED_1:.+]], %{{.+}} = ifrt.CopyArrays(%arg0) + // CHECK-DAG: %[[COPIED_2:.+]], %{{.+}} = ifrt.CopyArrays(%arg2) + %0, %1, %2, %ctrl_0 = ifrt.Reshard(%arg0, %arg1, %arg2) + : (!array0, !array1, !array1) -> (!array1, !array2, !array0) + // CHECK: return %[[COPIED_1]], %[[RESHARDED]], %[[COPIED_2]] + return %0, %1, %2: !array1, !array2, !array0 + } +} + +// ----- + +// Verifies that the control inputs are passed to the CopyArrays. +!array0 = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [0,1]> +!array1 = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [2,3]> +!array2 = !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [2,3]> +// CHECK-LABEL: @control_inputs_added_to_copy_arrays +module @control_inputs_added_to_copy_arrays { + func.func @main(%arg0: !array0, %arg1: !array1) -> (!array1, !array2) + attributes {ifrt.function} { + // CHECK: %[[OUT:.+]], %[[CTRL:.+]] = ifrt.Call @add_one(%arg0) + %0, %ctrl_0 = ifrt.Call @add_one(%arg0) on devices [0,1] + : (!array0) -> !array0 + // CHECK: %[[RESHARDED:.+]], %{{.+}} = ifrt.Reshard(%arg1) after %[[CTRL:.+]] + // CHECK: %[[COPIED:.+]], %{{.+}} = ifrt.CopyArrays(%[[OUT:.+]]) after %[[CTRL:.+]] + %1, %2, %ctrl_1 = ifrt.Reshard(%0, %arg1) after %ctrl_0 + : (!array0, !array1) -> (!array1, !array2) + // CHECK: return %[[COPIED]], %[[RESHARDED]] + return %1, %2: !array1, !array2 + } + + func.func private @add_one(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } +} diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir index 8c70318c03598c..92bed2748c2188 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir @@ -41,7 +41,7 @@ module @donate_to_reshard_duplicated_arg { // ----- !array = !ifrt.array, #ifrt.sharding_param<2 to [0] on 2>, [0, 1]> -module @donate_to_two_calls_error { +module @alias_to_two_calls_error { func.func @main(%arg0: !array {ifrt.donated}) -> (!array, !array) attributes {ifrt.function} { %0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1] @@ -59,13 +59,49 @@ module @donate_to_two_calls_error { // ----- +!array = !ifrt.array, #ifrt.sharding_param<2 to [0] on 2>, [0, 1]> +module @donate_to_two_calls_error { + func.func @main(%arg0: !array {ifrt.donated}) -> (!array, !array) + attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1] + {donated_input_indices=array} : (!array) -> !array + // expected-error @+1 {{'ifrt.Call' op input #0 of @identity was already donated}} + %1, %ctrl_1 = ifrt.Call @identity(%arg0) on devices [0,1] + {donated_input_indices=array} : (!array) -> !array + return %0, %1 : !array, !array + } + + func.func private @identity(%arg0: tensor<2xi32>) -> tensor<2xi32> { + return %arg0 : tensor<2xi32> + } +} + +// ----- + +!array = !ifrt.array, #ifrt.sharding_param<2 to [0] on 2>, [0, 1]> +module @arg_donated_to_call_not_donated_to_program { + func.func @main(%arg0: !array) -> (!array) + attributes {ifrt.function} { + // expected-error @+1 {{'ifrt.Call' op input #0 has not been donated to the program.}} + %0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1] + {donated_input_indices=array} : (!array) -> !array + return %0 : !array + } + + func.func private @identity(%arg0: tensor<2xi32>) -> tensor<2xi32> { + return %arg0 : tensor<2xi32> + } +} + +// ----- + !array0 = !ifrt.array, #ifrt.sharding_param<2 to [0] on 2>, [0, 1]> !array1 = !ifrt.array, #ifrt.sharding_param<2 to [0] on 2>, [2, 3]> module @program_arg_not_donated_error { func.func @main(%arg0: !array0) -> (!array1) attributes {ifrt.function} { - // expected-error @+1 {{'ifrt.Reshard' op input has not been donated to the program.}} + // expected-error @+1 {{'ifrt.Reshard' op input #0 has not been donated to the program.}} %0, %ctrl_0 = ifrt.Reshard(%arg0) {donated=true} : (!array0) -> !array1 return %0 : !array1 } @@ -167,7 +203,7 @@ module @donate_to_two_copy_arrays_error { module @program_arg_not_donated_to_remap_error { func.func @main(%arg0: !array {ifrt.donated}, %arg1: !array) -> (!array) attributes {ifrt.function} { - // expected-error @+1 {{'ifrt.RemapArrays' op input has not been donated to the program.}} + // expected-error @+1 {{'ifrt.RemapArrays' op input #1 has not been donated to the program.}} %0 = ifrt.RemapArrays(%arg0, %arg1) mappings=[#ifrt.array_mapping<0, 0, [#ifrt.mapping<[0:1:1] to [0:1:1]>]>, #ifrt.array_mapping<1, 0, [#ifrt.mapping<[0:1:1] to [1:2:1]>]>] diff --git a/third_party/xla/xla/python/ifrt/ir/tests/verify_call.mlir b/third_party/xla/xla/python/ifrt/ir/tests/verify_call.mlir index e512b260600e73..202724e44496a0 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/verify_call.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/verify_call.mlir @@ -293,7 +293,7 @@ func.func @io_aliases_should_only_alias_input_once( %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]>) attributes {ifrt.function} { - // expected-error@+1 {{'ifrt.Call' op can't alias input #0 more than once}} + // expected-error@+1 {{'ifrt.Call' op can't alias or donate input #0 more than once}} %0, %1, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1] {io_aliases=[array, array]} : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, @@ -429,4 +429,57 @@ func.func @call_local_view_should_have_valid_shape( func.func @callee(%arg0: tensor<4x4xi32>) -> tensor<4x4xi32> { return %arg0 : tensor<4x4xi32> +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> +func.func @donate_an_arg_and_alias_another(%arg0: !array, %arg1: !array) + attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Call @callee(%arg0, %arg1) on devices [0,1] + {donated_input_indices=array, io_aliases=[array]} + : (!array, !array) -> !array + return +} + +func.func @callee(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) + -> tensor<2x2xi32> { + return %arg0 : tensor<2x2xi32> +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> +func.func @should_only_donate_once(%arg0: !array, %arg1: !array) + attributes {ifrt.function} { + // expected-error@+1 {{'ifrt.Call' op can't donate input #0 more than once}} + %0, %ctrl_0 = ifrt.Call @callee(%arg0, %arg1) on devices [0,1] + {donated_input_indices=array} + : (!array, !array) -> !array + return +} + +func.func @callee(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) + -> tensor<2x2xi32> { + return %arg0 : tensor<2x2xi32> +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> +func.func @should_not_both_donate_and_alias_the_same_arg( + %arg0: !array, %arg1: !array) attributes {ifrt.function} { + // expected-error@+1 {{'ifrt.Call' op can't alias or donate input #0 more than once}} + %0, %ctrl_0 = ifrt.Call @callee(%arg0, %arg1) on devices [0,1] + {donated_input_indices=array, io_aliases=[array]} + : (!array, !array) -> !array + return +} + +func.func @callee(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) + -> tensor<2x2xi32> { + return %arg0 : tensor<2x2xi32> } \ No newline at end of file diff --git a/third_party/xla/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir b/third_party/xla/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir index 14485f4c86a4e0..e41add06877c60 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir @@ -145,7 +145,7 @@ func.func @io_aliases_should_only_alias_input_once( %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]>) attributes {ifrt.function} { - // expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't alias input #0 more than once}} + // expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't alias or donate input #0 more than once}} %0, %1, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0) {io_aliases=[array, array]} : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, @@ -230,3 +230,47 @@ ifrt.LoadedExecutable @callee on devices [0,1] [0,1]>) -> !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> + + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> +func.func @donate_one_arg_and_alias_another_arg(%arg0: !array, %arg1: !array) + attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0, %arg1) + {donated_input_indices=array, io_aliases=[array]} + : (!array, !array) -> !array + return +} + +ifrt.LoadedExecutable @callee on devices [0,1] : (!array, !array) -> !array + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> +func.func @should_only_donate_once(%arg0: !array, %arg1: !array) + attributes {ifrt.function} { + // expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't donate input #0 more than once}} + %0, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0, %arg1) + {donated_input_indices=array} : (!array, !array) -> !array + return +} + +ifrt.LoadedExecutable @callee on devices [0,1] : (!array, !array) -> !array + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> +func.func @should_not_both_donate_and_alias_the_same_arg( + %arg0: !array, %arg1: !array) attributes {ifrt.function} { + // expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't alias or donate input #0 more than once}} + %0, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0, %arg1) + {donated_input_indices=array, io_aliases=[array]} + : (!array, !array) -> !array + return +} + +ifrt.LoadedExecutable @callee on devices [0,1] : (!array, !array) -> !array diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/BUILD b/third_party/xla/xla/python/ifrt/ir/transforms/BUILD index 620362de4c1b50..c3695a4e2531f3 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/transforms/BUILD @@ -29,29 +29,69 @@ gentbl_cc_library( cc_library( name = "passes", srcs = [ + "ifrt_compile_and_propagate_shardings_pass.cc", + "ifrt_compile_atom_program_pass.cc", "ifrt_duplicated_callee_elimination_pass.cc", + "ifrt_lower_mpmd_reshard_to_call_pass.cc", + "ifrt_lower_sharding_to_xla_pass.cc", "ifrt_merge_reshards_pass.cc", "ifrt_outline_atom_program_to_module_pass.cc", + "ifrt_populate_atom_program_metadata_pass.cc", + "ifrt_remove_ifrt_attrs_pass.cc", + "ifrt_reshard_to_copy_arrays_pass.cc", + "ifrt_verify_bound_external_loaded_executable_pass.cc", "ifrt_verify_donation_pass.cc", "ifrt_verify_sharding_specified_pass.cc", + "multi_threaded_atom_program_compiler.cc", + "passes.cc", "spmd_expandable_interface_verification_pass.cc", "spmd_expansion_pass.cc", ], - hdrs = ["passes.h"], + hdrs = [ + "multi_threaded_atom_program_compiler.h", + "passes.h", + ], compatible_with = get_compatible_with_portable(), deps = [ ":passes_inc_gen", ":utils", + "//xla:status_macros", + "//xla:xla_data_proto_cc", + "//xla/client:executable_build_options", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "//xla/pjrt:pjrt_executable", + "//xla/python/ifrt", + "//xla/python/ifrt/hlo:hlo_program", "//xla/python/ifrt/ir", + "//xla/python/ifrt/ir:atom_program_compiler", + "//xla/python/ifrt/ir:sharding_param", + "//xla/python/ifrt/support:sharding_conversions", + "//xla/python/pjrt_ifrt:xla_ifrt", + "//xla/service:compilation_environments", + "//xla/service:computation_placer_hdr", + "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:fingerprint", + "@local_tsl//tsl/platform:statusor", + "@stablehlo//:stablehlo_ops", ], ) @@ -73,7 +113,13 @@ cc_library( hdrs = ["utils.h"], compatible_with = get_compatible_with_portable(), deps = [ + "//xla/mlir/utils:type_util", + "//xla/python/ifrt", + "//xla/python/ifrt/ir", + "//xla/python/pjrt_ifrt:pjrt_dtype", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_compile_and_propagate_shardings_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_compile_and_propagate_shardings_pass.cc new file mode 100644 index 00000000000000..23a98bec10b7cc --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_compile_and_propagate_shardings_pass.cc @@ -0,0 +1,647 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Support/TypeID.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/ir/atom_program_compiler.h" +#include "xla/python/ifrt/ir/constants.h" +#include "xla/python/ifrt/ir/ifrt_dialect.h" +#include "xla/python/ifrt/ir/ifrt_ops.h" +#include "xla/python/ifrt/ir/sharding_param.h" +#include "xla/python/ifrt/ir/transforms/multi_threaded_atom_program_compiler.h" +#include "xla/python/ifrt/ir/transforms/utils.h" +#include "xla/python/ifrt/support/sharding_conversions.h" +#include "xla/service/hlo.pb.h" + +namespace xla { +namespace ifrt { + +namespace { + +class IfrtCompileAndPropagateShardingsPass + : public mlir::PassWrapper> { + public: + explicit IfrtCompileAndPropagateShardingsPass( + std::shared_ptr compiler, + std::shared_ptr< + absl::flat_hash_map>> + compile_options_overrides, + std::shared_ptr atom_executable_map) + : atom_program_compiler_(std::move(compiler), + std::move(compile_options_overrides), true), + atom_executable_map_(std::move(atom_executable_map)) {} + + llvm::StringRef getArgument() const override { + return "ifrt-compile-and-propagate-shardings"; + } + + llvm::StringRef getDescription() const override { + return "Compiles atom programs, propagates shardings to infer unspecified " + " shardings, and lowers CallOp to CallLoadedExecutableOp"; + } + + void getDependentDialects(::mlir::DialectRegistry& registry) const override { + registry.insert(); + registry.insert(); + } + + void runOnOperation() override; + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + IfrtCompileAndPropagateShardingsPass); + + private: + mlir::LogicalResult WalkCallOp( + CallOp call_op, mlir::SymbolTableCollection& symbol_table, + mlir::OpBuilder& builder, + llvm::DenseMap>& + module_to_compiled); + + // Propagates sharding to/from a ReshardOp. + // + // Cases: + // 1) All shardings are specified => no=op. + // 2) Output sharding is not specified => error. + // 3) Input sharding not specified, but output sharding is specified => the + // output sharding is propagated to the input. This case is only possible + // if the input is an argument of the main func, and thus the input sharding + // is propagated to the other ops using this argument. For example, in the + // following case: + // ``` + // func.func @main(%arg0: !array_no_sharding) -> () { + // %0, %ctrl_0 = ifrt.Reshard(%arg0) + // : (!array_no_sharding) -> !array_with_sharding + // %1, %ctrl_1 = ifrt.Call @program::@main(%arg0) on devices [0, 1] + // : (!array_no_sharding) -> !array_no_sharding + // } + // ``` + // The Reshard's output sharding will be propagated to the `@program` + // input. Thus, the order in which the ops appear in the IR, has significant + // effect on the shardings inferred. + mlir::LogicalResult WalkReshardOp(ReshardOp reshard_op); + + // See the documentation of `WalkReshardOp`. + mlir::LogicalResult WalkCopyArraysOp(CopyArraysOp copy_arrays_op); + + mlir::LogicalResult WalkCopyArraysOrReshardOp(mlir::Operation* op, + mlir::OperandRange inputs, + mlir::ResultRange outputs); + + // Replaces all the unspecified input shardings attributes in the atom + // program's main func. + mlir::LogicalResult LowerInputShardingToXla(CallOp call_op, + mlir::func::FuncOp callee, + mlir::OpBuilder& builder); + + // Returns a vector of `ShardingParam` for inputs. + // + // If an input sharding is unspecified in the IFRT IR op, then the sharding + // is fetched from the compiled executable. Otherwise, the sharding present in + // the IFRT IR op is returned. + mlir::FailureOr> GetInputShardingParams( + CallOp call_op, const AtomProgramCompileResult& compile_result); + + // Returns a vector of `ShardingParam` for outputs. + // + // If an output sharding is unspecified in the IFRT IR op, then the sharding + // is fetched from the compiled executable. Otherwise, the sharding present in + // the IFRT IR op is returned. + mlir::FailureOr> GetOutputShardingParams( + CallOp call_op, const AtomProgramCompileResult& compile_result); + + // The method does the following: + // 1) Populates the unspecified sharding attributes in the atom program's main + // func. + // 2) Replaces the sharding in op's input types with unspecified sharding. + // 3) Replaces op's outputs with unspecified shardings. + // 4) Replaces all the usage of replaced outputs. + mlir::FailureOr PropagateShardings( + CallOp call_op, mlir::func::FuncOp callee, + llvm::ArrayRef input_shardings, + llvm::ArrayRef output_shardings, mlir::OpBuilder& builder); + + // Generates a LoadedExecutableOp. + // Returns the symbol of the generated LoadedExecutableOp. + mlir::FailureOr GenerateLoadedExecutableOp( + mlir::ModuleOp module_op, absl::string_view symbol_name, CallOp call_op, + mlir::OpBuilder& builder); + + // Replaces the CallOp with a CallLoadedExecutableOp. + void ReplaceCallOpWithCallLoadedOp(CallOp call_op, + mlir::SymbolRefAttr loaded_exec_op_callee, + mlir::OpBuilder& builder); + + MultiThreadedAtomProgramCompiler atom_program_compiler_; + + // Map from symbol name of LoadedExecutableOp to LoadedExecutable. + std::shared_ptr atom_executable_map_; +}; + +mlir::LogicalResult IfrtCompileAndPropagateShardingsPass::WalkCallOp( + CallOp call_op, mlir::SymbolTableCollection& symbol_table, + mlir::OpBuilder& builder, + llvm::DenseMap>& + module_to_compiled) { + mlir::func::FuncOp callee = call_op.getCalleeOp(symbol_table); + auto callee_module = llvm::dyn_cast(callee->getParentOp()); + if (callee.getSymName() != kCalleeMainFuncName || callee_module == nullptr) { + return call_op.emitOpError() + << "requires callee outlined as `" << kCalleeMainFuncName + << "` function in a ModuleOp. Actual callee name: " + << callee.getSymName() + << ". Actual callee parent: " << callee->getParentOp()->getName(); + } + + if (auto compiled_it = module_to_compiled.find(callee_module); + compiled_it == module_to_compiled.end()) { + // Only dispatch for compilation if the module has not been dispatched yet. + if (mlir::failed(LowerInputShardingToXla(call_op, callee, builder))) + return mlir::failure(); + + absl::StatusOr compile_future = + atom_program_compiler_.CompileModule(call_op, callee_module); + if (!compile_future.ok()) { + return call_op.emitOpError() + << "failed to dispatch compilation for atom executable: " + << compile_future.status().message(); + } + auto compile_result = compile_future->Await(); + if (!compile_result.ok()) { + return call_op.emitOpError() << "failed to compile to atom executable: " + << compile_result.status().message(); + } + + // Get the input and output shardings from the compiled executable. Only + // the unspecified shardings are fetched from the executable, all the other + // shardings remain the same. + auto input_shardings = GetInputShardingParams(call_op, *compile_result); + if (mlir::failed(input_shardings)) return mlir::failure(); + auto output_shardings = GetOutputShardingParams(call_op, *compile_result); + if (mlir::failed(output_shardings)) return mlir::failure(); + // Change the CallOp signature and propagate shardings. + auto new_call_op = PropagateShardings(call_op, callee, *input_shardings, + *output_shardings, builder); + if (mlir::failed(new_call_op)) return mlir::failure(); + + // Create a LoadedExecutableOp for the atom program. + auto symbol_ref = GenerateLoadedExecutableOp( + callee_module, compile_result->name, *new_call_op, builder); + if (mlir::failed(symbol_ref)) return mlir::failure(); + + // Save the atom program executable to extend its lifetime. + CHECK(atom_executable_map_ + ->try_emplace(compile_result->name, compile_result->executable) + .second); + CHECK(module_to_compiled + .try_emplace(callee_module, + std::make_pair(*compile_result, *symbol_ref)) + .second); + ReplaceCallOpWithCallLoadedOp(*new_call_op, *symbol_ref, builder); + } else { + // The module has been compiled already. Get the unspecified input and + // output shardings from the compiled executable to set the shardings in the + // CallOp and to propagate them. + auto input_shardings = + GetInputShardingParams(call_op, compiled_it->second.first); + if (mlir::failed(input_shardings)) return mlir::failure(); + auto output_shardings = + GetOutputShardingParams(call_op, compiled_it->second.first); + if (mlir::failed(output_shardings)) return mlir::failure(); + auto new_call_op = PropagateShardings(call_op, callee, *input_shardings, + *output_shardings, builder); + if (mlir::failed(new_call_op)) return mlir::failure(); + ReplaceCallOpWithCallLoadedOp(*new_call_op, compiled_it->second.second, + builder); + } + return mlir::success(); +} + +mlir::LogicalResult IfrtCompileAndPropagateShardingsPass::WalkReshardOp( + ReshardOp reshard_op) { + return WalkCopyArraysOrReshardOp(reshard_op, reshard_op.getInputs(), + reshard_op.getOutputs()); +} + +mlir::LogicalResult IfrtCompileAndPropagateShardingsPass::WalkCopyArraysOp( + CopyArraysOp copy_arrays_op) { + return WalkCopyArraysOrReshardOp(copy_arrays_op, copy_arrays_op.getInputs(), + copy_arrays_op.getOutputs()); +} + +mlir::LogicalResult +IfrtCompileAndPropagateShardingsPass::WalkCopyArraysOrReshardOp( + mlir::Operation* op, mlir::OperandRange inputs, mlir::ResultRange outputs) { + for (auto [idx, pair] : llvm::enumerate(llvm::zip(inputs, outputs))) { + auto in_array_type = mlir::cast(std::get<0>(pair).getType()); + if (in_array_type == nullptr) { + op->emitOpError() << "requires all inputs to be `IfrtArrayType`. Input #" + << idx << ": " << std::get<0>(pair).getType(); + return mlir::failure(); + } + auto out_array_type = + mlir::cast(std::get<1>(pair).getType()); + if (out_array_type == nullptr) { + op->emitOpError() + << "requires all outputs to be `IfrtArrayType`. Output #" << idx + << ": " << std::get<1>(pair).getType(); + return mlir::failure(); + } + if (mlir::isa( + in_array_type.getShardingAttr())) { + if (llvm::isa( + out_array_type.getShardingAttr())) { + return op->emitOpError() + << "requires output #" << idx << " to have sharding specified."; + } else { + std::get<0>(pair).setType(out_array_type); + } + } + } + return mlir::success(); +} + +void IfrtCompileAndPropagateShardingsPass::runOnOperation() { + mlir::SymbolTableCollection symbol_table; + mlir::OpBuilder builder(&getContext()); + mlir::ModuleOp module_op = getOperation(); + llvm::DenseMap> + module_to_compiled; + auto compile_result = + module_op.walk([&](mlir::Operation* op) -> mlir::WalkResult { + if (mlir::isa(op)) { + if (mlir::failed(WalkCallOp(mlir::cast(op), symbol_table, + builder, module_to_compiled))) + return mlir::WalkResult::interrupt(); + } else if (mlir::isa(op)) { + if (mlir::failed(WalkReshardOp(mlir::cast(op)))) { + return mlir::WalkResult::interrupt(); + } + } else if (mlir::isa(op)) { + if (mlir::failed(WalkCopyArraysOp(mlir::cast(op)))) { + return mlir::WalkResult::interrupt(); + } + } + return mlir::WalkResult::advance(); + }); + + if (compile_result.wasInterrupted()) { + signalPassFailure(); + return; + } + + // Update the main function's result types. + mlir::func::FuncOp main_func = GetMainFunction(module_op); + UpdateFunctionType(main_func); +} + +mlir::LogicalResult +IfrtCompileAndPropagateShardingsPass::LowerInputShardingToXla( + CallOp call_op, mlir::func::FuncOp callee, mlir::OpBuilder& builder) { + CHECK_EQ(call_op.getInputs().size(), callee.getNumArguments()); + for (int idx = 0; idx < callee.getNumArguments(); ++idx) { + const auto hlo_sharding_attr = + callee.getArgAttrOfType(idx, kHloShardingAttrName); + if (hlo_sharding_attr == nullptr) { + // The input sharding is not set, see if it has been inferred and can + // be lowered from IFRT Sharding to HloSharding. + auto array_type = mlir::dyn_cast_or_null( + call_op.getInputs()[idx].getType()); + auto sharding_param_attr = mlir::dyn_cast_or_null( + array_type.getShardingAttr()); + if (sharding_param_attr != nullptr) { + // Set the newly inferred sharding on the callee's argument. + auto hlo_sharding = xla::ifrt::support::ToHloSharding( + sharding_param_attr.getSharding()); + callee.setArgAttr( + idx, kHloShardingAttrName, + builder.getStringAttr(hlo_sharding.value().ToString())); + } + } + } + return mlir::success(); +} + +mlir::FailureOr> +IfrtCompileAndPropagateShardingsPass::GetInputShardingParams( + CallOp call_op, const AtomProgramCompileResult& compile_result) { + std::optional> in_shardings = std::nullopt; + llvm::SmallVector in_sharding_params; + in_sharding_params.reserve(call_op.getInputs().size()); + for (const auto& [idx, input] : llvm::enumerate(call_op.getInputs())) { + const auto in_array_type = + mlir::dyn_cast_or_null(input.getType()); + // Get sharding from the executable if it is unspecified in the op. + if (llvm::isa( + in_array_type.getShardingAttr())) { + if (!in_shardings.has_value()) { + in_shardings = compile_result.executable->GetParameterShardings(); + if (!in_shardings.has_value()) { + return call_op.emitError() + << "executable does not have input shardings"; + } + if (in_shardings->size() != call_op.getOutputs().size()) { + return call_op.emitError() + << "mismatch between number of input shardings of executable " + "and op's return: " + << in_shardings->size() << " vs. " + << call_op.getOutputs().size(); + } + } + auto hlo_sharding = xla::HloSharding::FromProto(in_shardings->at(idx)); + if (!hlo_sharding.ok()) { + return call_op.emitError() + << "failed to convert sharding `OpSharding` to `HloSharding`: " + << hlo_sharding.status().message(); + } + auto sharding_param = xla::ifrt::support::ToShardingParam( + hlo_sharding.value(), in_array_type.getShape().getRank(), + in_array_type.getDevices().size()); + if (!sharding_param.ok()) { + return call_op.emitError() << "failed to convert sharding " + "`HloSharding` to `ShardingParam`: " + << sharding_param.status().message(); + } + in_sharding_params.push_back(std::move(sharding_param.value())); + } else { + auto sharding_param_attr = mlir::dyn_cast_or_null( + in_array_type.getShardingAttr()); + if (sharding_param_attr == nullptr) { + return call_op.emitError() << "input #" << idx + << " has sharding attribute that is not of " + "type `IfrtShardingParamAttr`."; + } + in_sharding_params.push_back(sharding_param_attr.getSharding()); + } + } + return in_sharding_params; +} + +mlir::FailureOr> +IfrtCompileAndPropagateShardingsPass::GetOutputShardingParams( + CallOp call_op, const AtomProgramCompileResult& compile_result) { + std::optional> out_shardings = std::nullopt; + llvm::SmallVector out_sharding_params; + out_sharding_params.reserve(call_op.getOutputs().size()); + for (const auto& [idx, output] : llvm::enumerate(call_op.getOutputs())) { + const auto out_array_type = + mlir::dyn_cast_or_null(output.getType()); + // Get sharding from the executable if it is unspecified in the op. + if (llvm::isa( + out_array_type.getShardingAttr())) { + if (!out_shardings.has_value()) { + out_shardings = compile_result.executable->GetOutputShardings(); + if (!out_shardings.has_value()) { + return call_op.emitError() + << "executable does not have output shardings"; + } + if (out_shardings->size() != call_op.getOutputs().size()) { + return call_op.emitError() + << "mismatch between number of output shardings of executable " + "and op's return: " + << out_shardings->size() << " vs. " + << call_op.getOutputs().size(); + } + } + auto hlo_sharding = xla::HloSharding::FromProto(out_shardings->at(idx)); + if (!hlo_sharding.ok()) { + return call_op.emitError() + << "failed to convert sharding `OpSharding` to `HloSharding`: " + << hlo_sharding.status().message(); + } + auto sharding_param = xla::ifrt::support::ToShardingParam( + hlo_sharding.value(), out_array_type.getShape().getRank(), + out_array_type.getDevices().size()); + if (!sharding_param.ok()) { + return call_op.emitError() << "failed to convert sharding " + "`HloSharding` to `ShardingParam`: " + << sharding_param.status().message(); + } + out_sharding_params.push_back(std::move(sharding_param.value())); + } else { + auto sharding_param_attr = mlir::dyn_cast_or_null( + out_array_type.getShardingAttr()); + if (sharding_param_attr == nullptr) { + return call_op.emitError() << "output #" << idx + << " has sharding attribute that is not of " + "type `IfrtShardingParamAttr`."; + } + out_sharding_params.push_back(sharding_param_attr.getSharding()); + } + } + return out_sharding_params; +} + +mlir::FailureOr +IfrtCompileAndPropagateShardingsPass::PropagateShardings( + CallOp call_op, mlir::func::FuncOp callee, + llvm::ArrayRef input_shardings, + llvm::ArrayRef output_shardings, mlir::OpBuilder& builder) { + CHECK_EQ(call_op.getOutputs().size(), callee.getNumResults()); + CHECK_EQ(input_shardings.size(), callee.getNumArguments()); + CHECK_EQ(output_shardings.size(), callee.getNumResults()); + + // Compute arg types. An unspecified sharding is replaced with its + // corresponding input sharding. + for (const auto& [idx, input] : llvm::enumerate(call_op.getInputs())) { + const auto array_type = + mlir::dyn_cast_or_null(input.getType()); + const auto unspecified_sharding_attr = + callee.getArgAttrOfType(idx, kHloShardingAttrName); + if (unspecified_sharding_attr == nullptr) { + auto hlo_sharding = + xla::ifrt::support::ToHloSharding(input_shardings[idx]); + if (!hlo_sharding.ok()) { + return call_op.emitOpError() + << "can't lower sharding of input #" << idx + << ". Sharding: " << input_shardings[idx].DebugString() << ". " + << hlo_sharding.status().message(); + } + callee.setArgAttr(idx, kHloShardingAttrName, + builder.getStringAttr(hlo_sharding.value().ToString())); + auto new_array_type = builder.getType( + array_type.getShape(), input_shardings[idx], array_type.getDevices()); + input.setType(new_array_type); + } + } + + // Compute result types. An unspecified sharding is replaced with its + // corresponding output sharding. + llvm::SmallVector new_call_op_result_types; + new_call_op_result_types.reserve(call_op.getOutputs().size()); + // The op must be replaced if at least an output sharding needs to be changed. + bool replace_call_op = false; + for (const auto& [idx, output] : llvm::enumerate(call_op.getOutputs())) { + if (callee.getResultAttrOfType( + idx, kHloShardingAttrName) == nullptr) { + auto hlo_sharding = + xla::ifrt::support::ToHloSharding(output_shardings[idx]); + if (!hlo_sharding.ok()) { + return call_op.emitOpError() + << "can't lower sharding of output #" << idx + << ". Sharding: " << output_shardings[idx].DebugString() << ". " + << hlo_sharding.status().message(); + } + callee.setResultAttr( + idx, kHloShardingAttrName, + builder.getStringAttr(hlo_sharding.value().ToString())); + } + + const auto array_type = + mlir::dyn_cast_or_null(output.getType()); + // If the CallOp has an output with an unspecified sharding, then the + // CallOp must be replaced with a new CallOp that has the propagated + // shardings. + if (mlir::isa(array_type.getShardingAttr())) { + replace_call_op = true; + auto new_array_type = builder.getType( + array_type.getShape(), output_shardings[idx], + array_type.getDevices()); + new_call_op_result_types.push_back(new_array_type); + } else { + new_call_op_result_types.push_back(output.getType()); + } + } + + if (replace_call_op) { + builder.setInsertionPointAfter(call_op); + auto new_call_op = builder.create( + call_op.getLoc(), /*outputs=*/new_call_op_result_types, + /*control_output=*/builder.getType(), + /*inputs=*/call_op.getInputs(), + /*control_inputs=*/call_op.getControlInputs(), + /*callee=*/call_op.getCallee(), + /*devices=*/call_op.getDevices(), + /*io_aliases=*/call_op.getIoAliases(), + /*donated_input_indices=*/call_op.getDonatedInputIndices()); + new_call_op->setDiscardableAttrs(call_op->getDiscardableAttrDictionary()); + for (auto [i, result] : llvm::enumerate(call_op.getOutputs())) { + result.replaceAllUsesWith(new_call_op.getOutputs()[i]); + } + call_op.getControlOutput().replaceAllUsesWith( + new_call_op.getControlOutput()); + call_op.erase(); + return new_call_op; + } else { + return call_op; + } +} + +mlir::FailureOr +IfrtCompileAndPropagateShardingsPass::GenerateLoadedExecutableOp( + mlir::ModuleOp module_op, absl::string_view symbol_name, CallOp call_op, + mlir::OpBuilder& builder) { + llvm::SmallVector input_types; + for (const mlir::Value input : call_op.getInputs()) { + input_types.push_back(input.getType()); + } + llvm::SmallVector output_types; + for (const mlir::Value output : call_op.getOutputs()) { + output_types.push_back(output.getType()); + } + builder.setInsertionPointAfter(module_op); + builder.create( + module_op.getLoc(), symbol_name, + builder.getFunctionType(input_types, output_types), + call_op.getDevicesAttr()); + return mlir::SymbolRefAttr::get(&getContext(), symbol_name); +} + +void IfrtCompileAndPropagateShardingsPass::ReplaceCallOpWithCallLoadedOp( + CallOp call_op, mlir::SymbolRefAttr loaded_exec_op_callee, + mlir::OpBuilder& builder) { + builder.setInsertionPointAfter(call_op); + auto call_loaded_op = builder.create( + call_op.getLoc(), call_op.getResultTypes(), call_op.getInputs(), + call_op.getControlInputs(), loaded_exec_op_callee, call_op.getIoAliases(), + call_op.getDonatedInputIndices()); + call_op.replaceAllUsesWith(call_loaded_op.getResults()); + call_op.erase(); +} + +} // namespace + +std::unique_ptr> +CreateIfrtCompileAndPropagateShardingsPass( + std::shared_ptr compiler, + std::shared_ptr< + absl::flat_hash_map>> + compile_options_overrides, + std::shared_ptr atom_executable_map) { + CHECK(compiler != nullptr); + return std::make_unique( + std::move(compiler), std::move(compile_options_overrides), + std::move(atom_executable_map)); +} + +void RegisterIfrtCompileAndPropagateShardingsPass( + std::shared_ptr compiler, + std::shared_ptr< + absl::flat_hash_map>> + compile_options_overrides, + std::shared_ptr atom_executable_map) { + mlir::registerPass( + [compiler = std::move(compiler), + compile_options_overrides = std::move(compile_options_overrides), + atom_executable_map = + std::move(atom_executable_map)]() -> std::unique_ptr { + return CreateIfrtCompileAndPropagateShardingsPass( + std::move(compiler), std::move(compile_options_overrides), + std::move(atom_executable_map)); + }); +} + +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc new file mode 100644 index 00000000000000..843058ba5b9a9f --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_compile_atom_program_pass.cc @@ -0,0 +1,278 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/ir/atom_program_compiler.h" +#include "xla/python/ifrt/ir/constants.h" +#include "xla/python/ifrt/ir/ifrt_ops.h" +#include "xla/python/ifrt/ir/transforms/multi_threaded_atom_program_compiler.h" +#include "xla/python/ifrt/ir/transforms/utils.h" +#include "xla/service/hlo.pb.h" + +namespace xla { +namespace ifrt { + +namespace { + +class IfrtCompileAtomProgramPass + : public mlir::PassWrapper> { + public: + explicit IfrtCompileAtomProgramPass( + std::shared_ptr compiler, + std::shared_ptr< + absl::flat_hash_map>> + compile_options_overrides, + std::shared_ptr atom_executable_map) + : atom_program_compiler_(std::move(compiler), + std::move(compile_options_overrides), false), + atom_executable_map_(std::move(atom_executable_map)) {} + + llvm::StringRef getArgument() const override { + return "ifrt-compile-atom-program"; + } + + llvm::StringRef getDescription() const override { + return "Compiles atom programs and lower CallOp to CallLoadedExecutableOp"; + } + + void getDependentDialects(::mlir::DialectRegistry& registry) const override { + registry.insert(); + registry.insert(); + } + + void runOnOperation() override; + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IfrtCompileAtomProgramPass); + + private: + // Generates a LoadedExecutableOp. + // Returns the symbol of the generated LoadedExecutableOp. + absl::StatusOr GenerateLoadedExecutableOp( + mlir::ModuleOp module_op, absl::string_view symbol_name, CallOp call_op, + mlir::OpBuilder& builder); + + MultiThreadedAtomProgramCompiler atom_program_compiler_; + + // Map from symbol name of LoadedExecutableOp to LoadedExecutable. + std::shared_ptr atom_executable_map_; +}; + +void IfrtCompileAtomProgramPass::runOnOperation() { + mlir::SymbolTableCollection symbol_table; + mlir::OpBuilder builder(&getContext()); + // Map from the hash of the CallOp to the compile future. + llvm::DenseMap call_to_compile_futures; + mlir::ModuleOp module_op = getOperation(); + // Walk and dispatch the compilations in parallel. + auto compile_result = + module_op.walk([&](CallOp call_op) -> mlir::WalkResult { + // Do not dispatch the atom program for compilation it has already been + // dispatched. + if (!call_to_compile_futures.contains(call_op)) { + mlir::func::FuncOp callee = call_op.getCalleeOp(symbol_table); + auto callee_module = + llvm::dyn_cast(callee->getParentOp()); + if (callee.getSymName() != kCalleeMainFuncName || + callee_module == nullptr) { + return call_op.emitOpError() + << "requires callee outlined as `" << kCalleeMainFuncName + << "` function in a ModuleOp. Actual callee name: " + << callee.getSymName() << ". Actual callee parent: " + << callee->getParentOp()->getName(); + } + absl::StatusOr compile_future = + atom_program_compiler_.CompileModule(call_op, callee_module); + if (!compile_future.ok()) { + return call_op.emitOpError() + << "failed to dispatch compilation for atom executable: " + << compile_future.status().ToString(); + } + // Clone the CallOp because it will be modified later, but we want + // to keep the original to be able to access the future. + call_to_compile_futures[call_op.clone()] = *std::move(compile_future); + } + return mlir::WalkResult::advance(); + }); + + bool pass_failed = false; + if (compile_result.wasInterrupted()) { + pass_failed = true; + } else { + // Map from the hash of the CallOp to the symbol ref of the + // LoadedExecutableOp. + llvm::DenseMap + call_op_to_loaded_exec_op_ref; + // Walk, wait on compilations, and generate LoadedExecutableOps. + auto result = + module_op.walk([&](CallOp call_op) -> mlir::WalkResult { + mlir::SymbolRefAttr loaded_exec_op_ref; + if (auto loaded_exec_op_ref_it = + call_op_to_loaded_exec_op_ref.find(call_op); + loaded_exec_op_ref_it != call_op_to_loaded_exec_op_ref.end()) { + // Reuse the symbol ref to the LoadedExecutableOp if we've already + // created an op for the CallOp. + loaded_exec_op_ref = loaded_exec_op_ref_it->second; + } else { + auto compile_result = call_to_compile_futures[call_op].Await(); + if (!compile_result.ok()) { + return call_op.emitOpError() + << "failed to compile to atom executable: " + << compile_result.status().ToString(); + } + auto callee_module = llvm::dyn_cast( + call_op.getCalleeOp(symbol_table)->getParentOp()); + absl::StatusOr symbol_ref = + GenerateLoadedExecutableOp(callee_module, compile_result->name, + call_op, builder); + if (!symbol_ref.ok()) { + return call_op.emitOpError() + << "failed to generate loaded executable op: " + << symbol_ref.status().ToString(); + } + loaded_exec_op_ref = *symbol_ref; + // Clone the CallOp because it will be modified next, but we want to + // keep the original to get the symbol ref for equal CallOps. + call_op_to_loaded_exec_op_ref[call_op.clone()] = loaded_exec_op_ref; + // Save the atom program executable to extend its lifetime. + CHECK(atom_executable_map_ + ->try_emplace(compile_result->name, + std::move(compile_result->executable)) + .second); + } + + // Generate CallLoadedExecutableOp. + builder.setInsertionPointAfter(call_op); + auto new_call = builder.create( + call_op.getLoc(), call_op.getResultTypes(), call_op.getInputs(), + call_op.getControlInputs(), loaded_exec_op_ref, + call_op.getIoAliases(), call_op.getDonatedInputIndices()); + new_call->setDiscardableAttrs( + call_op->getDiscardableAttrDictionary()); + call_op.replaceAllUsesWith(new_call.getResults()); + call_op.erase(); + return mlir::WalkResult::advance(); + }); + if (result.wasInterrupted()) { + pass_failed = true; + } + // Erase the CallOp clones that we're used as keys of the map. + for (auto& [call_op, loaded_exec_op_ref] : call_op_to_loaded_exec_op_ref) { + call_op.erase(); + } + } + + if (pass_failed) { + // Wait on all compile futures to ensure that they do not access + // this->compiler_ after the pass has been destructed. We don't care if + // the compilations succeed at this point because the pass has failed + // anyways. + for (auto& [call_op, future] : call_to_compile_futures) { + (void)future.Await(); + } + signalPassFailure(); + } + // Erase the CallOp clones that we're used as keys of the map. + for (auto& [call_op, future] : call_to_compile_futures) { + call_op.erase(); + } +} + +absl::StatusOr +IfrtCompileAtomProgramPass::GenerateLoadedExecutableOp( + mlir::ModuleOp module_op, absl::string_view symbol_name, CallOp call_op, + mlir::OpBuilder& builder) { + // Generate LoadedExecutableOp. + llvm::SmallVector input_types; + for (const mlir::Value input : call_op.getInputs()) { + input_types.push_back(input.getType()); + } + llvm::SmallVector output_types; + for (const mlir::Value output : call_op.getOutputs()) { + output_types.push_back(output.getType()); + } + builder.setInsertionPointAfter(module_op); + builder.create( + module_op.getLoc(), symbol_name, + builder.getFunctionType(input_types, output_types), + call_op.getDevicesAttr()); + return mlir::SymbolRefAttr::get(&getContext(), symbol_name); +} + +} // namespace + +std::unique_ptr> +CreateIfrtCompileAtomProgramPass( + std::shared_ptr compiler, + std::shared_ptr< + absl::flat_hash_map>> + compile_options_overrides, + std::shared_ptr atom_executable_map) { + CHECK(compiler != nullptr); + return std::make_unique( + std::move(compiler), std::move(compile_options_overrides), + std::move(atom_executable_map)); +} + +void RegisterIfrtCompileAtomProgramPass( + std::shared_ptr compiler, + std::shared_ptr< + absl::flat_hash_map>> + compile_options_overrides, + std::shared_ptr atom_executable_map) { + mlir::registerPass( + [compiler = std::move(compiler), + compile_options_overrides = std::move(compile_options_overrides), + atom_executable_map = + std::move(atom_executable_map)]() -> std::unique_ptr { + return CreateIfrtCompileAtomProgramPass( + std::move(compiler), std::move(compile_options_overrides), + std::move(atom_executable_map)); + }); +} + +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_lower_mpmd_reshard_to_call_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_lower_mpmd_reshard_to_call_pass.cc new file mode 100644 index 00000000000000..5fcc8a59e65ad6 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_lower_mpmd_reshard_to_call_pass.cc @@ -0,0 +1,199 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/btree_set.h" +#include "absl/strings/str_cat.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "xla/python/ifrt/ir/constants.h" +#include "xla/python/ifrt/ir/ifrt_dialect.h" +#include "xla/python/ifrt/ir/ifrt_ops.h" +#include "xla/python/ifrt/ir/transforms/utils.h" +#include "tsl/platform/fingerprint.h" + +namespace xla { +namespace ifrt { + +namespace { + +#define GEN_PASS_DEF_IFRTLOWERMPMDRESHARDTOCALLPASS +#include "xla/python/ifrt/ir/transforms/passes.h.inc" + +// Returns a fingerprint of the input and output types of a ReshardOp. +uint64_t ReshardFingerprint(ReshardOp reshard_op) { + std::string s; + llvm::raw_string_ostream os(s); + for (const auto& input : reshard_op.getInputs()) { + os << input.getType(); + } + for (const auto& output : reshard_op.getOutputs()) { + os << output.getType(); + } + // Whether the input is donated does not need to be included in the + // fingerprint because that does not affect the computations generated. + return tsl::Fingerprint64(os.str()); +} + +class IfrtLowerMpmdReshardToCallPass + : public impl::IfrtLowerMpmdReshardToCallPassBase< + IfrtLowerMpmdReshardToCallPass> { + public: + void runOnOperation() override { + mlir::ModuleOp module_op = getOperation(); + mlir::SymbolTable symbol_table(module_op); + mlir::OpBuilder builder(&getContext()); + mlir::func::FuncOp main_func = GetMainFunction(module_op); + auto result = main_func->walk([&](ReshardOp reshard_op) { + // Uniquify the devices the MPMD reshard executes on. It is unclear what + // the devices order should be, but it is fine to return them sorted + // because they are used only used for debugging purposes. + absl::btree_set device_set; + bool does_reshard = false; + for (const auto& [idx, pair] : llvm::enumerate( + llvm::zip(reshard_op.getInputs(), reshard_op.getOutputs()))) { + auto in_array_type = + mlir::cast(std::get<0>(pair).getType()); + if (in_array_type == nullptr) { + reshard_op.emitOpError() + << "requires all inputs to be `IfrtArrayType`. Input #" << idx + << ": " << std::get<0>(pair).getType(); + return mlir::WalkResult::interrupt(); + } + auto out_array_type = + mlir::cast(std::get<1>(pair).getType()); + if (out_array_type == nullptr) { + reshard_op.emitOpError() + << "requires all outputs to be `IfrtArrayType`. Output #" << idx + << ": " << std::get<1>(pair).getType(); + return mlir::WalkResult::interrupt(); + } + if (IsReshard(in_array_type, out_array_type)) { + does_reshard = true; + } + device_set.insert(in_array_type.getDevices().begin(), + in_array_type.getDevices().end()); + device_set.insert(out_array_type.getDevices().begin(), + out_array_type.getDevices().end()); + } + + if (!does_reshard) { + reshard_op.emitOpError() + << "does not reshard any arrays. Use CopyArraysOp instead"; + return mlir::WalkResult::interrupt(); + } + + std::vector devices(device_set.begin(), device_set.end()); + std::string module_sym_name = + absl::StrCat("reshard_", ReshardFingerprint(reshard_op)); + + auto reshard_module_op = mlir::dyn_cast_or_null( + module_op.lookupSymbol(module_sym_name)); + mlir::func::FuncOp reshard_func = nullptr; + if (reshard_module_op == nullptr) { + // Create a module corresponding to the reshard op. + builder.setInsertionPointToEnd(module_op.getBody()); + reshard_module_op = builder.create( + mlir::UnknownLoc::get(builder.getContext()), module_sym_name); + reshard_module_op.setVisibility(mlir::SymbolTable::Visibility::Private); + reshard_module_op->setAttr(kIfrtNumDevicesAttrName, + builder.getI32IntegerAttr(devices.size())); + + // Create the main func in the reshard module, and add the ReshardOp + // to it. + mlir::OpBuilder reshard_builder(reshard_module_op.getBodyRegion()); + reshard_func = reshard_builder.create( + reshard_module_op->getLoc(), kCalleeMainFuncName, + mlir::FunctionType::get(reshard_builder.getContext(), + reshard_op.getInputs().getTypes(), + reshard_op.getOutputs().getTypes())); + reshard_func.setVisibility(mlir::SymbolTable::Visibility::Public); + reshard_func->setAttr(kIfrtReshardFunctionAttrName, + builder.getUnitAttr()); + mlir::Block* entryBlock = reshard_func.addEntryBlock(); + reshard_builder.setInsertionPointToEnd(entryBlock); + auto inner_reshard_op = reshard_builder.create( + reshard_op.getLoc(), /*outputs=*/reshard_op.getOutputs().getTypes(), + /*control_output=*/reshard_op.getControlOutput().getType(), + /*inputs=*/reshard_func.getArguments(), + /*donated=*/reshard_op.getDonated(), + /*control_inputs=*/mlir::ValueRange()); + reshard_builder.create( + reshard_func.getLoc(), inner_reshard_op.getOutputs()); + } + + // Replace the ReshardOp with a CallOp. + builder.setInsertionPoint(reshard_op); + mlir::SymbolRefAttr reshard_func_symbol = mlir::SymbolRefAttr::get( + reshard_module_op.getSymNameAttr(), + mlir::SymbolRefAttr::get(GetMainFunction(reshard_module_op))); + llvm::SmallVector donated_input_indices; + if (reshard_op.getDonated()) { + donated_input_indices.resize(reshard_op.getInputs().size()); + std::iota(donated_input_indices.begin(), donated_input_indices.end(), + 0); + } + auto call_op = builder.create( + reshard_op.getLoc(), /*outputs=*/reshard_op.getOutputs().getTypes(), + /*control_output=*/reshard_op.getControlOutput().getType(), + /*inputs=*/reshard_op.getInputs(), + /*control_inputs=*/reshard_op.getControlInputs(), + /*callee=*/reshard_func_symbol, + /*devices=*/devices, + /*io_aliases=*/builder.getArrayAttr({}), + /*donated_input_indices=*/ + builder.getDenseI32ArrayAttr(donated_input_indices)); + call_op->setAttr(kIfrtModuleTypeAttrName, + builder.getStringAttr(kIfrtModuleTypeMpmdReshard)); + reshard_op.replaceAllUsesWith(call_op.getResults()); + reshard_op.erase(); + return mlir::WalkResult::advance(); + }); + if (result.wasInterrupted()) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> +CreateIfrtLowerMpmdReshardToCallPass() { + return std::make_unique(); +} + +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_lower_sharding_to_xla_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_lower_sharding_to_xla_pass.cc new file mode 100644 index 00000000000000..f4afb854aab142 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_lower_sharding_to_xla_pass.cc @@ -0,0 +1,193 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/python/ifrt/ir/constants.h" +#include "xla/python/ifrt/ir/ifrt_dialect.h" +#include "xla/python/ifrt/ir/ifrt_interfaces.h" +#include "xla/python/ifrt/ir/transforms/utils.h" +#include "xla/python/ifrt/support/sharding_conversions.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace ifrt { + +namespace { + +#define GEN_PASS_DEF_IFRTLOWERSHARDINGTOXLAPASS +#include "xla/python/ifrt/ir/transforms/passes.h.inc" + +// Pass that does the following: +// 1) transforms kIfrtShardingAttrName attribute on the main FuncOp inputs and +// outputs to HloSharding. +// 2) sets FuncOps input/outputs kHloShardingAttrName attribute to the +// corresponding computed HloSharding. +class IfrtLowerShardingToXlaPass + : public impl::IfrtLowerShardingToXlaPassBase { + public: + void runOnOperation() override; +}; + +void IfrtLowerShardingToXlaPass::runOnOperation() { + mlir::OpBuilder builder(&getContext()); + mlir::ModuleOp module_op = getOperation(); + const auto num_devices_attr = + module_op->getAttrOfType(kIfrtNumDevicesAttrName); + if (num_devices_attr == nullptr) { + module_op.emitOpError() + << "module `" << module_op.getSymName().value_or("unknown").str() + << "` must have `" << kIfrtNumDevicesAttrName.str() << "` attribute"; + signalPassFailure(); + return; + } + int num_devices = num_devices_attr.getInt(); + mlir::func::FuncOp func_op = GetMainFunction(module_op); + auto local_view_attr = + module_op->getAttrOfType(kIfrtLocalViewAttrName); + // Lower input shardings. + for (int i = 0; i < func_op.getNumArguments(); ++i) { + const auto sharding_attr = + func_op.getArgAttrOfType( + i, kIfrtShardingAttrName); + if (sharding_attr == nullptr) { + // The op has already been visited, and the IFRT attributes have been + // removed. Verify that kHloShardingAttrName has been set. + if (func_op.getArgAttr(i, kHloShardingAttrName) == nullptr) { + func_op.emitOpError() << "can't find `" << kIfrtShardingAttrName + << "` attribute of input #" << i << " to set `" + << kHloShardingAttrName << "` attribute"; + signalPassFailure(); + return; + } + continue; + } else if (llvm::isa(sharding_attr)) { + // Sharding is not specified so we cannot lower to kHloShardingAttrName. + continue; + } + + // Verify that the input is sharded over the same number of devices as + // the computation. + auto attr_num_devices = sharding_attr.NumDevices(); + if (attr_num_devices != num_devices) { + func_op.emitOpError() + << "can't lower sharding of input #" << i + << ". Sharding: " << sharding_attr << " uses " << attr_num_devices + << " devices while computation uses " << num_devices << " devices"; + signalPassFailure(); + return; + } + + if (local_view_attr != nullptr) { + // The arguments to the function are already sharded, so we do not + // need to shard them again. + func_op.setArgAttr( + i, kHloShardingAttrName, + builder.getStringAttr(xla::HloSharding::Replicate().ToString())); + } else { + const auto sharding_param_attr = + func_op.getArgAttrOfType( + i, kIfrtShardingAttrName); + auto hlo_sharding = + xla::ifrt::support::ToHloSharding(sharding_param_attr.getSharding()); + if (!hlo_sharding.ok()) { + func_op.emitOpError() << "can't lower sharding of input #" << i + << ". Sharding: " << sharding_param_attr << ". " + << hlo_sharding.status().message(); + signalPassFailure(); + return; + } + func_op.setArgAttr( + i, kHloShardingAttrName, + builder.getStringAttr(hlo_sharding.value().ToString())); + } + } + + // Lower output shardings. + for (int i = 0; i < func_op.getNumResults(); ++i) { + const auto sharding_attr = + func_op.getResultAttrOfType( + i, kIfrtShardingAttrName); + if (sharding_attr == nullptr) { + // The op has already been visited, and the IFRT attributes have been + // removed. Verify that kHloShardingAttrName has been set. + if (func_op.getResultAttr(i, kHloShardingAttrName) == nullptr) { + func_op.emitOpError() << "can't find `" << kIfrtShardingAttrName + << "` attribute of output #" << i << " to set `" + << kHloShardingAttrName << "` attribute"; + signalPassFailure(); + return; + } + continue; + } else if (llvm::isa(sharding_attr)) { + // Sharding is not specified so we cannot lower to kHloShardingAttrName. + continue; + } + + // Verify that the output is sharded over the same number of devices as + // the computation. + auto attr_num_devices = sharding_attr.NumDevices(); + if (attr_num_devices != num_devices) { + func_op.emitOpError() + << "can't lower sharding of output #" << i + << ". Sharding: " << sharding_attr << " uses " << attr_num_devices + << " devices while computation uses " << num_devices << " devices"; + signalPassFailure(); + return; + } + + if (local_view_attr != nullptr) { + // The results of the function are already sharded, so we do not need + // to shard them again. + func_op.setResultAttr( + i, kHloShardingAttrName, + builder.getStringAttr(xla::HloSharding::Replicate().ToString())); + } else { + const auto sharding_param_attr = + func_op.getResultAttrOfType( + i, kIfrtShardingAttrName); + auto hlo_sharding = + xla::ifrt::support::ToHloSharding(sharding_param_attr.getSharding()); + if (!hlo_sharding.ok()) { + func_op.emitOpError() << "can't lower sharding of output #" << i + << ". Sharding: " << sharding_param_attr << ". " + << hlo_sharding.status().message(); + signalPassFailure(); + return; + } + func_op.setResultAttr( + i, kHloShardingAttrName, + builder.getStringAttr(hlo_sharding.value().ToString())); + } + } +} + +} // namespace + +std::unique_ptr> +CreateIfrtLowerShardingToXlaPass() { + return std::make_unique(); +} + +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_populate_atom_program_metadata_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_populate_atom_program_metadata_pass.cc new file mode 100644 index 00000000000000..7915f3cf240f99 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_populate_atom_program_metadata_pass.cc @@ -0,0 +1,212 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Iterators.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "xla/python/ifrt/ir/constants.h" +#include "xla/python/ifrt/ir/ifrt_dialect.h" +#include "xla/python/ifrt/ir/ifrt_ops.h" +#include "xla/python/ifrt/ir/transforms/utils.h" + +namespace xla { +namespace ifrt { + +namespace { + +#define GEN_PASS_DEF_IFRTPOPULATEATOMPROGRAMMETADATAPASS +#include "xla/python/ifrt/ir/transforms/passes.h.inc" + +// Populates the metadata on the atom program ModuleOp and `main` FuncOp. +mlir::LogicalResult PopulateMetadata(CallOp call_op, mlir::ModuleOp module_op, + mlir::func::FuncOp callee_op, + mlir::OpBuilder& builder) { + module_op->setAttr(kIfrtNumDevicesAttrName, + builder.getI32IntegerAttr(call_op.getDevices().size())); + // Copy `ifrt.local_view` attribute if it exists. + if (call_op->hasAttrOfType(kIfrtLocalViewAttrName)) { + module_op->setAttr(kIfrtLocalViewAttrName, + call_op->getAttr(kIfrtLocalViewAttrName)); + } + + // Attach sharding to inputs. + for (const auto& [i, input] : llvm::enumerate(call_op.getInputs())) { + const auto array = mlir::dyn_cast_or_null(input.getType()); + if (array == nullptr) { + return call_op->emitOpError() + << "requires all inputs to be IfrtArrayType. Input #" << i << ": " + << input.getType(); + } + // It is faster to get all the attributes and add the new ones than + // setting the new attributes one-by-one. This is because the logic that + // sets an attribute converts the attr dict to a NamedAttrList, and then + // linearly searches for the attr. + llvm::SmallVector arg_attrs; + auto arg_attr_dict = callee_op.getArgAttrDict(i); + if (arg_attr_dict != nullptr) { + arg_attrs.append(arg_attr_dict.begin(), arg_attr_dict.end()); + } + arg_attrs.push_back( + builder.getNamedAttr(kIfrtShardingAttrName, array.getShardingAttr())); + callee_op.setArgAttrs(i, arg_attrs); + } + + // Attach sharding to outputs. + for (const auto& [i, output] : llvm::enumerate(call_op.getOutputs())) { + const auto array = mlir::dyn_cast_or_null(output.getType()); + if (array == nullptr) { + return call_op->emitOpError() + << "requires all outputs to be IfrtArrayType. Input #" << i << ": " + << output.getType(); + } + llvm::SmallVector res_attrs; + auto res_attr_dict = callee_op.getResultAttrDict(i); + if (res_attr_dict != nullptr) { + res_attrs.append(res_attr_dict.begin(), res_attr_dict.end()); + } + res_attrs.push_back( + builder.getNamedAttr(kIfrtShardingAttrName, array.getShardingAttr())); + callee_op.setResultAttrs(i, res_attrs); + } + + // Alias inputs. + for (const auto& raw_io_alias : + call_op.getIoAliases().getAsRange()) { + llvm::ArrayRef io_alias_as_array = raw_io_alias.asArrayRef(); + callee_op.setArgAttr(io_alias_as_array[0], "tf.aliasing_output", + builder.getI32IntegerAttr(io_alias_as_array[1])); + } + for (const auto idx : call_op.getDonatedInputIndices()) { + callee_op.setArgAttr(idx, "jax.buffer_donor", builder.getBoolAttr(true)); + } + return mlir::success(); +} + +class IfrtPopulateAtomProgramMetadataPass + : public impl::IfrtPopulateAtomProgramMetadataPassBase< + IfrtPopulateAtomProgramMetadataPass> { + public: + void runOnOperation() override; +}; + +void IfrtPopulateAtomProgramMetadataPass::runOnOperation() { + mlir::MLIRContext& context = getContext(); + mlir::SymbolTableCollection symbol_table; + mlir::OpBuilder builder(&context); + mlir::func::FuncOp main_func = GetMainFunction(getOperation()); + + // Construct a map from callee `SymbolRefAttr` to the unique `CallOps` + // using it. This map is used to decide if an atom program module must be + // cloned before populating its metadata (i.e., used more than once). + llvm::DenseMap> + callee_call_count; + for (CallOp call_op : main_func.getOps()) { + callee_call_count[call_op.getCallee()].insert(call_op); + } + + llvm::DenseMap visited_call_ops; + // Walk the CallOps in reverse order to ensure that the first CallOp using a + // callee uses the original callee. Otherwise, the walk would modify the name + // of the default callee. + auto result = main_func.walk([&](CallOp call_op) + -> mlir::WalkResult { + mlir::func::FuncOp callee = call_op.getCalleeOp(symbol_table); + if (callee == nullptr) { + return call_op->emitOpError() + << "can't find callee `" << call_op.getCalleeAttr() << "`"; + } + auto callee_module = llvm::dyn_cast(callee->getParentOp()); + if (callee.getSymName() != kCalleeMainFuncName || + callee_module == nullptr) { + return call_op.emitOpError() + << "requires callee outlined as `" << kCalleeMainFuncName + << "` function in a ModuleOp. Actual callee name: " + << callee.getSymName() + << ". Actual callee parent: " << callee->getParentOp()->getName(); + } + + if (auto call_op_it = visited_call_ops.find(call_op); + call_op_it != visited_call_ops.end()) { + call_op.setCalleeAttr(call_op_it->second); + } else { + callee_call_count[call_op.getCallee()].erase(call_op); + if (!callee_call_count[call_op.getCallee()].empty()) { + // Only clone the callee if it is used more than once. + mlir::ModuleOp cloned_module = callee_module.clone(); + mlir::func::FuncOp cloned_callee = GetMainFunction(cloned_module); + // Insert new cloned atom program module in the SymbolTable. + symbol_table + .getSymbolTable( + callee_module->getParentWithTrait()) + .insert(cloned_module); + mlir::SymbolRefAttr callee_attr = mlir::SymbolRefAttr::get( + cloned_module.getSymNameAttr(), + mlir::SymbolRefAttr::get(cloned_callee.getSymNameAttr())); + auto populate_result = + PopulateMetadata(call_op, cloned_module, cloned_callee, builder); + if (mlir::failed(populate_result)) { + return populate_result; + } + // Clone the CallOp because it will be modified next. + visited_call_ops[call_op.clone()] = callee_attr; + call_op.setCalleeAttr(callee_attr); + } else { + auto populate_result = PopulateMetadata( + call_op, callee_module, GetMainFunction(callee_module), builder); + if (mlir::failed(populate_result)) { + return populate_result; + } + visited_call_ops[call_op.clone()] = call_op.getCalleeAttr(); + } + } + return mlir::WalkResult::advance(); + }); + + if (result.wasInterrupted()) { + signalPassFailure(); + } + + // Erase the cloned CallOp because they were used only as keys of the map. + for (auto& [call_op, unused] : visited_call_ops) { + call_op.erase(); + } +} + +} // namespace + +std::unique_ptr> +CreateIfrtPopulateAtomProgramMetadataPass() { + return std::make_unique(); +} + +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_remove_ifrt_attrs_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_remove_ifrt_attrs_pass.cc new file mode 100644 index 00000000000000..ce3a6c41b27fb0 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_remove_ifrt_attrs_pass.cc @@ -0,0 +1,79 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "xla/python/ifrt/ir/constants.h" + +namespace xla { +namespace ifrt { + +namespace { + +#define GEN_PASS_DEF_IFRTREMOVEIFRTATTRSPASS +#include "xla/python/ifrt/ir/transforms/passes.h.inc" + +class IfrtRemoveIfrtAttrsPass + : public impl::IfrtRemoveIfrtAttrsPassBase { + public: + void runOnOperation() override; +}; + +void IfrtRemoveIfrtAttrsPass::runOnOperation() { + mlir::ModuleOp module_op = getOperation(); + module_op->removeAttr(kIfrtNumDevicesAttrName); + module_op->removeAttr(kIfrtLocalViewAttrName); + module_op.walk([&](mlir::func::FuncOp func_op) { + // Remove from function attributes. + for (auto attribute_name : {kIfrtDevicesAttrName, kIfrtShardingAttrName}) { + func_op->removeAttr(attribute_name); + } + + // Remove from argument attributes. + for (int i = 0; i < func_op.getNumArguments(); ++i) { + mlir::NamedAttrList arg_attrs = func_op.getArgAttrDict(i); + for (auto attribute_name : + {kIfrtDevicesAttrName, kIfrtShardingAttrName}) { + arg_attrs.erase(attribute_name); + } + func_op.setArgAttrs(i, arg_attrs); + } + // Remove from result attributes. + for (int i = 0; i < func_op.getNumResults(); ++i) { + mlir::NamedAttrList res_attrs = func_op.getResultAttrDict(i); + for (auto attribute_name : + {kIfrtDevicesAttrName, kIfrtShardingAttrName}) { + res_attrs.erase(attribute_name); + } + func_op.setResultAttrs(i, res_attrs); + } + }); +} +} // namespace + +std::unique_ptr> +CreateIfrtRemoveIfrtAttrsPass() { + return std::make_unique(); +} + +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_reshard_to_copy_arrays_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_reshard_to_copy_arrays_pass.cc new file mode 100644 index 00000000000000..2ffaf6f8a63f8e --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_reshard_to_copy_arrays_pass.cc @@ -0,0 +1,187 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "xla/python/ifrt/ir/ifrt_dialect.h" +#include "xla/python/ifrt/ir/ifrt_ops.h" +#include "xla/python/ifrt/ir/transforms/utils.h" + +namespace xla { +namespace ifrt { + +namespace { + +#define GEN_PASS_DEF_IFRTRESHARDTOCOPYARRAYSPASS +#include "xla/python/ifrt/ir/transforms/passes.h.inc" + +class ReshardToCopyArraysOpPattern + : public mlir::OpRewritePattern { + public: + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + xla::ifrt::ReshardOp op, mlir::PatternRewriter& rewriter) const override { + // Map from devices attribute to indices of the input arrays that are just + // copied to those devices. + llvm::DenseMap> + copy_indices; + // Indices of the input arrays that are resharded. + llvm::SmallVector reshard_indices; + for (const auto& [idx, pair] : + llvm::enumerate(llvm::zip(op.getInputs(), op.getOutputs()))) { + auto in_array_type = + mlir::cast(std::get<0>(pair).getType()); + if (in_array_type == nullptr) { + op.emitOpError() << "requires all inputs to be `IfrtArrayType`. Input #" + << idx << ": " << std::get<0>(pair).getType(); + return mlir::failure(); + } + auto out_array_type = + mlir::cast(std::get<1>(pair).getType()); + if (out_array_type == nullptr) { + op.emitOpError() + << "requires all outputs to be `IfrtArrayType`. Output #" << idx + << ": " << std::get<1>(pair).getType(); + return mlir::failure(); + } + if (!IsReshard(in_array_type, out_array_type)) { + copy_indices[out_array_type.getDevicesAttr()].push_back(idx); + } else { + reshard_indices.push_back(idx); + } + } + + if (reshard_indices.size() == op.getInputs().size()) { + // All arrays are resharded. No need to modify the ifrt.Reshard op. + return mlir::failure(); + } + + if (!op.getControlOutput().getUses().empty()) { + // If the control output dependency of the ifrt.Reshard op is used then it + // is unclear what to do with the newly added ifrt.CopyArrays ops. The + // conservative approach would be to add these as control dependencies to + // all the ops that have a control dependency on the ifrt.Reshard op. + // However, we could also add them just to the ops that have a control + // dependency on the ifrt.Reshard op and use the same devices. For now, + // we will just throw an error as the ifrt.Reshard control dependencies + // are not used at the moment. + op.emitOpError() << " cannot extract `ifrt.CopyArrays` from " + "`ifrt.Reshard` with control dependency output"; + return mlir::failure(); + } + + llvm::SmallVector outputs; + outputs.resize(op.getOutputs().size()); + // If an ifrt.Reshard is still left, then we replace the usage of the + // current ifrt.Reshard op's control output with its control output. + // Otherwise, we replace it with the control output of the last + // ifrt.CopyArrays op. + mlir::Value control_output; + + // Replace the ifrt.Reshard with a pruned version that only takes the arrays + // that are resharded. + llvm::SmallVector reshard_input_values; + llvm::SmallVector reshard_output_types; + for (int idx : reshard_indices) { + outputs[idx] = op.getOutputs()[idx]; + reshard_input_values.push_back(op.getInputs()[idx]); + reshard_output_types.push_back(op.getOutputs()[idx].getType()); + } + if (!reshard_input_values.empty()) { + auto reshard_op = rewriter.create( + op.getLoc(), + /*outputs=*/reshard_output_types, + /*control_output=*/op.getControlOutput().getType(), + /*inputs=*/reshard_input_values, + /*donated=*/op.getDonated(), + /*control_inputs=*/op.getControlInputs()); + for (const auto& [idx, output] : + llvm::zip(reshard_indices, reshard_op.getOutputs())) { + outputs[idx] = output; + } + control_output = reshard_op.getControlOutput(); + } + + // Add an ifrt.CopyArrays op for each set of arrays that are copied to a + // set of devices. The new ifrt.CopyArrays ops will inherit *all* the input + // control dependencies of the ifrt.Reshard op. They could receive a subset + // of the control dependencies (e.g., dependencies generated by ops running + // use the same devices as the ones the arrays are coppied to), but that is + // not supported yet. + for (const auto& [devices_attr, indices] : copy_indices) { + llvm::SmallVector copy_input_values; + llvm::SmallVector copy_output_types; + for (int idx : indices) { + copy_input_values.push_back(op.getInputs()[idx]); + copy_output_types.push_back(op.getOutputs()[idx].getType()); + } + auto copy_arrays_op = rewriter.create( + op.getLoc(), + /*outputs=*/copy_output_types, + /*control_output=*/op.getControlOutput().getType(), + /*inputs=*/copy_input_values, + /*donated=*/op.getDonated(), + /*control_inputs=*/op.getControlInputs()); + for (const auto& [idx, output] : + llvm::zip(indices, copy_arrays_op.getOutputs())) { + outputs[idx] = output; + } + if (reshard_input_values.empty()) { + control_output = copy_arrays_op.getControlOutput(); + } + } + outputs.push_back(control_output); + rewriter.replaceOp(op, outputs); + return mlir::success(); + } +}; + +class IfrtReshardToCopyArraysPass + : public impl::IfrtReshardToCopyArraysPassBase< + IfrtReshardToCopyArraysPass> { + public: + void runOnOperation() override { + mlir::RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + mlir::ModuleOp module_op = getOperation(); + if (mlir::failed(mlir::applyPatternsAndFoldGreedily(module_op, + std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> +CreateIfrtReshardToCopyArraysPass() { + return std::make_unique(); +} + +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_bound_external_loaded_executable_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_bound_external_loaded_executable_pass.cc new file mode 100644 index 00000000000000..d91b40aace0c67 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_bound_external_loaded_executable_pass.cc @@ -0,0 +1,207 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/python/ifrt/ir/atom_program_compiler.h" +#include "xla/python/ifrt/ir/ifrt_dialect.h" +#include "xla/python/ifrt/ir/ifrt_ops.h" +#include "xla/python/ifrt/support/sharding_conversions.h" +#include "xla/service/hlo.pb.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace ifrt { + +namespace { + +class IfrtVerifyBoundExternalLoadedExecutablePass + : public mlir::PassWrapper> { + public: + explicit IfrtVerifyBoundExternalLoadedExecutablePass( + std::shared_ptr bound_executable_map) + : bound_executable_map_(std::move(bound_executable_map)) {} + + llvm::StringRef getArgument() const override { + return "ifrt-verify-bound-external-loaded-executable"; + } + + llvm::StringRef getDescription() const override { + return "Verifies that the bound external LoadedExecutables have number of " + "inputs/outputs, shape, and sharding as the corresponding externally" + " bound LoadedExecutable"; + } + + void runOnOperation() override; + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + IfrtVerifyBoundExternalLoadedExecutablePass); + + private: + absl::Status VerifyShardingsEqual( + llvm::ArrayRef types, + const std::vector& shardings, + std::string_view sharding_type); + + // Map from symbol name of LoadedExecutableOp to externally bound + // LoadedExecutable. + std::shared_ptr bound_executable_map_; +}; + +absl::Status IfrtVerifyBoundExternalLoadedExecutablePass::VerifyShardingsEqual( + llvm::ArrayRef types, + const std::vector& shardings, + std::string_view sharding_type) { + for (const auto& it : llvm::enumerate(llvm::zip(types, shardings))) { + const auto& [param_type, sharding] = it.value(); + TF_ASSIGN_OR_RETURN(auto hlo_sharding, + xla::HloSharding::FromProto(sharding)); + auto array_type = llvm::dyn_cast(param_type); + CHECK(array_type); + auto array_sharding = + llvm::dyn_cast(array_type.getShardingAttr()); + CHECK(array_sharding); + TF_ASSIGN_OR_RETURN( + const xla::HloSharding hlo_type_sharding, + xla::ifrt::support::ToHloSharding(array_sharding.getSharding())); + if (hlo_sharding != hlo_type_sharding) { + return absl::InvalidArgumentError(absl::StrCat( + "expects an executable with ", sharding_type, " #", it.index(), + " sharding ", hlo_sharding.ToString(/*include_metadata=*/false), + ", but was bound to an executable with sharding ", + hlo_type_sharding.ToString(/*include_metadata=*/false))); + } + } + return absl::OkStatus(); +} + +void IfrtVerifyBoundExternalLoadedExecutablePass::runOnOperation() { + mlir::ModuleOp module_op = getOperation(); + // Walk and dispatch the compilations in parallel. + auto result = module_op.walk([&](LoadedExecutableOp loaded_exec_op) + -> mlir::WalkResult { + const auto exec_it = + bound_executable_map_->find(loaded_exec_op.getSymName()); + if (exec_it != bound_executable_map_->end()) { + if (loaded_exec_op.getDevices().size() != + exec_it->second->num_devices()) { + return loaded_exec_op.emitOpError() + << "expects an executable with " + << loaded_exec_op.getDevices().size() + << " devices, but was bound to an executable with " + << exec_it->second->num_devices() << " devices"; + } + + auto func_type = loaded_exec_op.getFunctionType(); + if (!exec_it->second->GetParameterShardings().has_value()) { + return loaded_exec_op.emitOpError() + << "cannot be bound to an executable without parameter " + "shardings"; + } + if (!exec_it->second->GetOutputShardings().has_value()) { + return loaded_exec_op.emitOpError() + << "cannot be bound to an executable without output shardings"; + } + if (func_type.getNumInputs() != + exec_it->second->GetParameterShardings()->size()) { + return loaded_exec_op.emitOpError() + << "expects an executable with " << func_type.getNumInputs() + << " inputs, but was bound to an executable with " + << exec_it->second->GetParameterShardings()->size() << " inputs"; + } + if (func_type.getNumResults() != + exec_it->second->GetOutputShardings()->size()) { + return loaded_exec_op.emitOpError() + << "expects an executable with " << func_type.getNumResults() + << " results, but was bound to an executable with " + << exec_it->second->GetOutputShardings()->size() << " results"; + } + // Verify that the input and output shardings of the LoadedExecutableOp + // are the same as the shardings of the bound executable. + if (!exec_it->second->GetParameterShardings().has_value()) { + return loaded_exec_op.emitOpError() + << "cannot be bound to an executable without parameter " + "shardings"; + } + if (!exec_it->second->GetOutputShardings().has_value()) { + return loaded_exec_op.emitOpError() + << "cannot be bound to an executable without output " + "shardings"; + } + auto sharding_equal_status = VerifyShardingsEqual( + func_type.getInputs(), *exec_it->second->GetParameterShardings(), + "input"); + if (!sharding_equal_status.ok()) { + return loaded_exec_op.emitOpError() << sharding_equal_status.message(); + } + sharding_equal_status = VerifyShardingsEqual( + func_type.getResults(), *exec_it->second->GetOutputShardings(), + "output"); + if (!sharding_equal_status.ok()) { + return loaded_exec_op.emitOpError() << sharding_equal_status.message(); + } + } + return mlir::WalkResult::advance(); + }); + + if (result.wasInterrupted()) { + signalPassFailure(); + } +} + +} // namespace + +std::unique_ptr> +CreateIfrtVerifyBoundExternalLoadedExecutablePass( + std::shared_ptr bound_executable_map) { + return std::make_unique( + std::move(bound_executable_map)); +} + +void RegisterIfrtVerifyBoundExternalLoadedExecutablePass( + std::shared_ptr bound_executable_map) { + mlir::registerPass( + [bound_executable_map = + std::move(bound_executable_map)]() -> std::unique_ptr { + return CreateIfrtVerifyBoundExternalLoadedExecutablePass( + std::move(bound_executable_map)); + }); +} + +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc index 7e3492147e1665..7a4fcfdf16be96 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc @@ -38,17 +38,100 @@ namespace { #include "xla/python/ifrt/ir/transforms/passes.h.inc" // Verifies that if the value is an input to the IR, then it has been donated. -mlir::LogicalResult VerifyIfInputAndDonated(mlir::Operation* op, +mlir::LogicalResult VerifyIfInputAndDonated(mlir::Operation* op, int idx, mlir::Value arg) { auto block_arg = mlir::dyn_cast(arg); mlir::func::FuncOp func_op = block_arg ? mlir::dyn_cast( block_arg.getOwner()->getParentOp()) : nullptr; - if (func_op && - func_op.getArgAttr(block_arg.getArgNumber(), - xla::ifrt::kIfrtDonatedArgAttrName) == nullptr) { - return op->emitOpError() << "input has not been donated to the program."; + if (func_op && func_op.getArgAttr(block_arg.getArgNumber(), + kIfrtDonatedArgAttrName) == nullptr) { + return op->emitOpError() + << "input #" << idx << " has not been donated to the program."; + } + return mlir::success(); +} + +template +mlir::LogicalResult verifyCallOpAliasesAndDonations( + T op, llvm::DenseMap& donated_value_to_op) { + llvm::DenseSet donated_input_idxs; + // Verify if a donated input is an argument of the main func, then it has + // also been donated by the user. + for (const auto idx : op.getDonatedInputIndices()) { + donated_input_idxs.insert(idx); + auto donated_value = op.getInputs()[idx]; + auto donated_it = donated_value_to_op.try_emplace(donated_value, op); + if (!donated_it.second) { + op.emitOpError() << "input #" << idx << " of " << op.getCalleeAttr() + << " was already donated or aliased to the op at " + << donated_it.first->second->getLoc(); + return mlir::failure(); + } + if (mlir::failed(VerifyIfInputAndDonated(op, idx, donated_value))) { + return mlir::failure(); + } + } + + for (const auto& io_alias : + op.getIoAliases().template getAsRange()) { + mlir::ArrayRef io_alias_as_array = io_alias.asArrayRef(); + donated_input_idxs.insert(io_alias_as_array[0]); + auto aliased_value = op.getInputs()[io_alias_as_array[0]]; + auto donated_it = donated_value_to_op.try_emplace(aliased_value, op); + if (!donated_it.second) { + op.emitOpError() << "input #" << io_alias_as_array[0] << " of " + << op.getCalleeAttr() + << " was already donated or aliased to the op at " + << donated_it.first->second->getLoc(); + return mlir::failure(); + } + if (mlir::failed( + VerifyIfInputAndDonated(op, io_alias_as_array[0], aliased_value))) { + return mlir::failure(); + } + } + + // Verify non-donated inputs after donated inputs have been + // added to also catch instances such as + // `ifrt.Call(%arg0 {ifrt.donated}, %arg0})`. + for (const auto [idx, input] : llvm::enumerate(op.getInputs())) { + if (!donated_input_idxs.contains(idx)) { + auto donated_it = donated_value_to_op.find(input); + if (donated_it != donated_value_to_op.end()) { + op.emitOpError() << "input #" << idx << " of " << op.getCalleeAttr() + << " was already donated to the op at " + << donated_it->second->getLoc(); + return mlir::failure(); + } + } + } + return mlir::success(); +} + +template +mlir::LogicalResult verifyCopyRemapAndReshardOpsDonation( + T op, llvm::DenseMap& donated_value_to_op) { + // Verify that no inputs have already been donated. + for (const auto [idx, input] : llvm::enumerate(op.getInputs())) { + auto donated_it = donated_value_to_op.find(input); + if (donated_it != donated_value_to_op.end()) { + op.emitOpError() << "input #" << idx << " of op at " << op.getLoc() + << " was already donated to the op at " + << donated_it->second->getLoc(); + return mlir::failure(); + } + } + if (op.getDonated()) { + // Add the donated inputs to the map and verify that all the + // donated inputs are also donated to the main func. + for (const auto [idx, input] : llvm::enumerate(op.getInputs())) { + donated_value_to_op.try_emplace(input, op); + if (mlir::failed(VerifyIfInputAndDonated(op, idx, input))) { + return mlir::failure(); + } + } } return mlir::success(); } @@ -74,72 +157,12 @@ void IfrtVerifyDonationPass::runOnOperation() { -> mlir::WalkResult { auto result = llvm::TypeSwitch(op) - .Case( - [&](auto& op) { - llvm::DenseSet donated_input_idxs; - for (const auto& io_alias : - op.getIoAliases() - .template getAsRange()) { - mlir::ArrayRef io_alias_as_array = - io_alias.asArrayRef(); - donated_input_idxs.insert(io_alias_as_array[0]); - auto donated_value = op.getInputs()[io_alias_as_array[0]]; - auto donated_it = - donated_value_to_op.try_emplace(donated_value, op); - if (!donated_it.second) { - op.emitOpError() << "input #" << io_alias_as_array[0] - << " of " << op.getCalleeAttr() - << " was already donated to the op at " - << donated_it.first->second->getLoc(); - return mlir::failure(); - } - if (mlir::failed( - VerifyIfInputAndDonated(op, donated_value))) { - return mlir::failure(); - } - } - // Verify non-donated inputs after donated inputs have been - // added to also catch instances such as - // `ifrt.Call(%arg0 {ifrt.donated}, %arg0})`. - for (const auto [idx, input] : - llvm::enumerate(op.getInputs())) { - if (!donated_input_idxs.contains(idx)) { - auto donated_it = donated_value_to_op.find(input); - if (donated_it != donated_value_to_op.end()) { - op.emitOpError() - << "input #" << idx << " of " << op.getCalleeAttr() - << " was already donated to the op at " - << donated_it->second->getLoc(); - return mlir::failure(); - } - } - } - return mlir::success(); - }) - .Case([&](auto& op) { - // Verify that no inputs have already been donated. - for (const auto [idx, input] : llvm::enumerate(op.getInputs())) { - auto donated_it = donated_value_to_op.find(input); - if (donated_it != donated_value_to_op.end()) { - op.emitOpError() - << "input #" << idx << " of op at " << op.getLoc() - << " was already donated to the op at " - << donated_it->second->getLoc(); - return mlir::failure(); - } - } - if (op.getDonated()) { - // Add the donated inputs to the map and verify that all the - // donated inputs are also donated to the main func. - for (const auto input : op.getInputs()) { - donated_value_to_op.try_emplace(input, op); - if (mlir::failed(VerifyIfInputAndDonated(op, input))) { - return mlir::failure(); - } - } - } - return mlir::success(); + .Case([&](auto& op) { + return verifyCallOpAliasesAndDonations(op, donated_value_to_op); + }) + .Case([&](auto& op) { + return verifyCopyRemapAndReshardOpsDonation(op, + donated_value_to_op); }) .Case([&](mlir::func::ReturnOp return_op) { for (const auto& [idx, result] : diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/multi_threaded_atom_program_compiler.cc b/third_party/xla/xla/python/ifrt/ir/transforms/multi_threaded_atom_program_compiler.cc new file mode 100644 index 00000000000000..de3d43f30e918d --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/transforms/multi_threaded_atom_program_compiler.cc @@ -0,0 +1,271 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/ifrt/ir/transforms/multi_threaded_atom_program_compiler.h" + +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/Types.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/DebugStringHelper.h" +#include "mlir/Support/LLVM.h" +#include "xla/client/executable_build_options.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/hlo/hlo_program.h" +#include "xla/python/ifrt/ir/atom_program_compiler.h" +#include "xla/python/ifrt/ir/constants.h" +#include "xla/python/ifrt/ir/ifrt_dialect.h" +#include "xla/python/ifrt/ir/ifrt_ops.h" +#include "xla/python/ifrt/ir/transforms/utils.h" +#include "xla/python/ifrt/shape.h" +#include "xla/python/pjrt_ifrt/xla_compiler.h" +#include "xla/service/compilation_environments.h" +#include "xla/service/computation_placer.h" +#include "xla/service/hlo.pb.h" +#include "xla/status_macros.h" +#include "tsl/platform/env.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/threadpool.h" + +namespace xla { +namespace ifrt { + +namespace { + +// Lazily initialized shared thread pool. +tsl::thread::ThreadPool* thread_pool() { + static tsl::thread::ThreadPool* thread_pool = []() { + constexpr int kMaxParallelism = 32; + return new tsl::thread::ThreadPool(tsl::Env::Default(), + tsl::ThreadOptions(), + "CompileAtomPrograms", kMaxParallelism); + }(); + return thread_pool; +} + +void ScheduleWork(tsl::thread::ThreadPool* pool, + absl::AnyInvocable callee) { + // ThreadPool expects std::function that must be copyable, but we can avoid + // this by using AnyInvocable. + pool->Schedule([ptr = new absl::AnyInvocable(std::move(callee))]() { + (*ptr)(); + delete ptr; + }); +} + +// Construct a bool vector with a True entry for each input sharding that must +// be inferred. +llvm::SmallVector GetInputShardingPropagation( + mlir::func::FuncOp func_op) { + llvm::SmallVector sharding_propagation_to_input; + sharding_propagation_to_input.reserve(func_op.getNumArguments()); + for (int idx = 0; idx < func_op.getNumArguments(); ++idx) { + const auto hlo_sharding_attr = + func_op.getArgAttrOfType(idx, kHloShardingAttrName); + if (hlo_sharding_attr == nullptr) { + sharding_propagation_to_input.push_back(true); + } else { + sharding_propagation_to_input.push_back(false); + } + } + return sharding_propagation_to_input; +} + +// Construct a bool vector with a True entry for each output sharding that must +// be inferred. +llvm::SmallVector GetOutputShardingPropagation( + mlir::func::FuncOp func_op) { + llvm::SmallVector sharding_propagation_to_output; + sharding_propagation_to_output.reserve(func_op.getNumResults()); + for (int idx = 0; idx < func_op.getNumResults(); ++idx) { + const auto hlo_sharding_attr = + func_op.getResultAttrOfType(idx, + kHloShardingAttrName); + if (hlo_sharding_attr == nullptr) { + sharding_propagation_to_output.push_back(true); + } else { + sharding_propagation_to_output.push_back(false); + } + } + return sharding_propagation_to_output; +} + +} // namespace + +absl::StatusOr MultiThreadedAtomProgramCompiler::CompileModule( + CallOp call_op, mlir::ModuleOp module_op) { + auto module_type = + call_op->getAttrOfType(kIfrtModuleTypeAttrName); + if (module_type == kIfrtModuleTypeXla) { + return CompileXla(call_op, module_op, thread_pool()); + } else if (module_type == kIfrtModuleTypeMpmdReshard) { + return CompileMpmdReshard(module_op); + } else if (module_type == nullptr) { + return absl::InvalidArgumentError(absl::StrCat( + "CallOp requires `", kIfrtModuleTypeAttrName.str(), "` to be set")); + } else { + return absl::InvalidArgumentError( + absl::StrCat("No compiler for module type: ", module_type.str())); + } +} + +absl::StatusOr +MultiThreadedAtomProgramCompiler::GetXlaCompileOptions( + CallOp call_op, mlir::ModuleOp module_op) { + xla::CompileOptions compile_options; + + // If the CallOp has a compile options key, then try to use the provided + // compile options. + auto compile_options_key = + call_op->getAttrOfType(kIfrtCompileOptionsKey); + bool has_compile_options = false; + if (compile_options_overrides_ != nullptr && compile_options_key != nullptr) { + if (auto compile_options_override = + compile_options_overrides_->find(compile_options_key.str()); + compile_options_override != compile_options_overrides_->end()) { + if (auto xla_options = llvm::dyn_cast( + compile_options_override->second.get())) { + compile_options = xla_options->compile_options; + has_compile_options = true; + } else { + return absl::InvalidArgumentError(absl::StrCat( + "The `", kIfrtCompileOptionsKey.str(), "` compile options key `", + compile_options_key.str(), + "` has an entry that is not of type `XlaCompileOptions`, but the " + "atom program is an XLA program.")); + } + } + } + + if (!has_compile_options) { + auto& exec_build_options = compile_options.executable_build_options; + // Executable build options are constructed using logical ids, which are + // later converted into real Device ids by using the logical ids as + // indices into the device list given at compilation invocation time. + llvm::ArrayRef logical_device_ids = call_op.getDevices(); + if (call_op->hasAttrOfType(kIfrtLocalViewAttrName)) { + exec_build_options.set_num_replicas(logical_device_ids.size()); + exec_build_options.set_num_partitions(1); + xla::DeviceAssignment device_assignment(logical_device_ids.size(), 1); + for (const auto [i, device_id] : llvm::enumerate(logical_device_ids)) { + device_assignment(i, 0) = device_id; + } + exec_build_options.set_device_assignment(device_assignment); + } else { + exec_build_options.set_num_replicas(1); + exec_build_options.set_num_partitions(logical_device_ids.size()); + xla::DeviceAssignment device_assignment(1, logical_device_ids.size()); + for (const auto [i, device_id] : llvm::enumerate(logical_device_ids)) { + device_assignment(0, i) = device_id; + } + exec_build_options.set_device_assignment(device_assignment); + exec_build_options.set_use_spmd_partitioning(true); + if (enable_sharding_propagation_) { + mlir::func::FuncOp main_func = GetMainFunction(module_op); + exec_build_options.set_allow_spmd_sharding_propagation_to_parameters( + GetInputShardingPropagation(main_func)); + exec_build_options.set_allow_spmd_sharding_propagation_to_output( + GetOutputShardingPropagation(main_func)); + } + } + } + + return compile_options; +} + +absl::StatusOr MultiThreadedAtomProgramCompiler::CompileXla( + CallOp call_op, mlir::ModuleOp module_op, + tsl::thread::ThreadPool* thread_pool) { + TF_ASSIGN_OR_RETURN(xla::CompileOptions compile_options, + GetXlaCompileOptions(call_op, module_op)); + + // We must clone the module in order ensure the module string representation + // is maintained. This is because MLIR printing takes different paths + // depending on if a ModuleOp has a parent or not. + + auto hlo_program = std::make_unique( + /*context=*/nullptr, // Shares the same long-living context. + mlir::OwningOpRef(module_op.clone())); + Promise promise = CompileFuture::CreatePromise(); + CompileFuture future(promise); + ScheduleWork( + thread_pool, [this, hlo_program = std::move(hlo_program), + compile_options = std::move(compile_options), + promise = std::move(promise)]() mutable { + promise.Set(compiler_->CompileXla(std::move(hlo_program), + std::move(compile_options))); + }); + return future; +} + +absl::StatusOr +MultiThreadedAtomProgramCompiler::CompileMpmdReshard(mlir::ModuleOp module_op) { + auto main_func = + module_op.lookupSymbol(kCalleeMainFuncName); + TF_RET_CHECK(main_func) << "requires module to have" + << kCalleeMainFuncName.str() << " function"; + std::vector dtypes; + std::vector shapes; + std::vector in_arrays_types; + std::vector out_arrays_types; + dtypes.reserve(main_func.getArgumentTypes().size()); + shapes.reserve(main_func.getArgumentTypes().size()); + in_arrays_types.reserve(main_func.getArgumentTypes().size()); + out_arrays_types.reserve(main_func.getResultTypes().size()); + for (const mlir::Type arg_type : main_func.getArgumentTypes()) { + auto array_type = mlir::dyn_cast(arg_type); + TF_RET_CHECK(array_type != nullptr) + << "Unsupported argument type `" << mlir::debugString(arg_type) << "`"; + TF_ASSIGN_OR_RETURN(DType dtype, + ToIfrtDType(array_type.getShape().getElementType())); + dtypes.push_back(std::move(dtype)); + shapes.push_back(Shape(array_type.getShape().getShape())); + in_arrays_types.push_back(array_type); + } + for (const mlir::Type result_type : main_func.getResultTypes()) { + auto array_type = mlir::dyn_cast(result_type); + TF_RET_CHECK(array_type != nullptr) + << "Unsupported return type `" << mlir::debugString(result_type) << "`"; + out_arrays_types.push_back(array_type); + } + auto promise = CompileFuture::CreatePromise(); + CompileFuture future(promise); + // No need to dispatch from a different thread because MpmdReshard uses its + // own thread pool already. + auto compile_result = compiler_->CompileMpmdReshard( + std::move(dtypes), std::move(shapes), in_arrays_types, out_arrays_types); + promise.Set(std::move(compile_result)); + return future; +} + +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/multi_threaded_atom_program_compiler.h b/third_party/xla/xla/python/ifrt/ir/transforms/multi_threaded_atom_program_compiler.h new file mode 100644 index 00000000000000..c004b522dcb020 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/transforms/multi_threaded_atom_program_compiler.h @@ -0,0 +1,88 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_IFRT_IR_TRANSFORMS_MULTI_THREADED_ATOM_PROGRAM_COMPILER_H_ +#define XLA_PYTHON_IFRT_IR_TRANSFORMS_MULTI_THREADED_ATOM_PROGRAM_COMPILER_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "mlir/IR/BuiltinOps.h" +#include "xla/pjrt/pjrt_executable.h" +#include "xla/python/ifrt/compiler.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt/ir/atom_program_compiler.h" +#include "xla/python/ifrt/ir/ifrt_ops.h" +#include "tsl/platform/threadpool.h" + +namespace xla { +namespace ifrt { + +using CompileFuture = Future; + +// Wrapper around `AtomProgramCompiler` that offers multi-threaded dispatch +// of atom program compilations. +class MultiThreadedAtomProgramCompiler { + public: + explicit MultiThreadedAtomProgramCompiler( + std::shared_ptr compiler, + std::shared_ptr< + absl::flat_hash_map>> + compile_options_overrides, + bool enable_sharding_propagation) + : compiler_(std::move(compiler)), + compile_options_overrides_(std::move(compile_options_overrides)), + enable_sharding_propagation_{enable_sharding_propagation} {} + + // Dispatches compilation of an atom program module. + // Depending on the type of module, a MLIR pipeline might be executed before + // the compilation is dispatched. + absl::StatusOr CompileModule(CallOp, mlir::ModuleOp module_op); + + private: + // Compiles an atom XLA program. + // Returns a future of a AtomProgramCompileResult for the compiled module. + // + // Note that the method runs `ifrt-compile-xla-preprocessing-pipeline` + // before dispatching compilation. + absl::StatusOr CompileXla( + CallOp call_op, mlir::ModuleOp module_op, + tsl::thread::ThreadPool* thread_pool); + + // Returns a future of a AtomProgramCompileResult for the MPMD reshard module. + absl::StatusOr CompileMpmdReshard(mlir::ModuleOp module_op); + + // Gets the XLA compile options for the given atom program module. + absl::StatusOr GetXlaCompileOptions( + CallOp call_op, mlir::ModuleOp module_op); + + std::shared_ptr compiler_; + + std::shared_ptr< + absl::flat_hash_map>> + compile_options_overrides_; + + // Whether to allow sharding propagation from inputs to outputs that do not + // have sharding specified (i.e., their mhlo.sharding attribute is not set). + bool enable_sharding_propagation_; +}; + +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_IR_TRANSFORMS_MULTI_THREADED_ATOM_PROGRAM_COMPILER_H_ diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/passes.cc b/third_party/xla/xla/python/ifrt/ir/transforms/passes.cc new file mode 100644 index 00000000000000..4ab08ee45ec61c --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/transforms/passes.cc @@ -0,0 +1,86 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/ifrt/ir/transforms/passes.h" + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Transforms/Passes.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/ir/atom_program_compiler.h" + +namespace xla { +namespace ifrt { + +void CreateIfrtToOutlinedAtomProgramsPipeline( + mlir::OpPassManager& pm, + const IfrtToOutlinedAtomProgramsPipelineOptions& options) { + // Passes that verify the correctness of the module. + pm.addPass(CreateSpmdExpandableInterfaceVerificationPass( + {{mlir::mhlo::MhloDialect::getDialectNamespace().str(), + mlir::stablehlo::StablehloDialect::getDialectNamespace().str()}})); + pm.addNestedPass(CreateIfrtVerifyDonationPass()); + + // Passes that outline atom programs to modules and set their metadata. + pm.addPass(CreateIfrtOutlineAtomProgramToModulePass()); + pm.addPass(CreateIfrtPopulateAtomProgramMetadataPass()); + pm.addPass(CreateIfrtDuplicatedCalleeEliminationPass()); + pm.addPass(mlir::createSymbolDCEPass()); + + if (!options.propagate_shardings) { + pm.addPass(CreateIfrtVerifyShardingSpecifiedPass()); + // We can split ifrt.Reshard to ifrt.CopyArrays because all the shardings + // are specified. + pm.addPass(CreateIfrtReshardToCopyArraysPass()); + } +} + +void CreateIfrtCompileXlaPreprocessingPipeline(mlir::OpPassManager& pm) { + pm.addPass(CreateIfrtLowerShardingToXlaPass()); + pm.addPass(CreateIfrtRemoveIfrtAttrsPass()); +} + +void RegisterIfrtPassesAndPipelines( + std::shared_ptr compiler, + std::shared_ptr< + absl::flat_hash_map>> + compile_options_overrides, + std::shared_ptr atom_executable_map, + std::shared_ptr bound_executable_map) { + registerIfrtIrPasses(); + RegisterIfrtCompileAtomProgramPass(compiler, compile_options_overrides, + atom_executable_map); + RegisterIfrtCompileAndPropagateShardingsPass( + compiler, compile_options_overrides, atom_executable_map); + RegisterIfrtVerifyBoundExternalLoadedExecutablePass(bound_executable_map); + mlir::PassPipelineRegistration( + "ifrt-to-outlined-atom-programs-pipeline", + "Runs passes that do not require compilation-time information", + CreateIfrtToOutlinedAtomProgramsPipeline); + mlir::PassPipelineRegistration<>( + "ifrt-compile-xla-preprocessing-pipeline", + "Run passes to lower an IFRT XLA program for XLA compilation", + CreateIfrtCompileXlaPreprocessingPipeline); +} + +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/passes.h b/third_party/xla/xla/python/ifrt/ir/transforms/passes.h index da7ec1ab599795..d3cac2a0835839 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/passes.h +++ b/third_party/xla/xla/python/ifrt/ir/transforms/passes.h @@ -17,10 +17,16 @@ limitations under the License. #define XLA_PYTHON_IFRT_IR_TRANSFORMS_PASSES_H_ #include +#include +#include "absl/container/flat_hash_map.h" +#include "llvm/Support/CommandLine.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassOptions.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/ir/atom_program_compiler.h" namespace xla { namespace ifrt { @@ -49,10 +55,114 @@ CreateIfrtVerifyDonationPass(); std::unique_ptr> CreateIfrtVerifyShardingSpecifiedPass(); +std::unique_ptr> +CreateIfrtPopulateAtomProgramMetadataPass(); + +std::unique_ptr> +CreateIfrtReshardToCopyArraysPass(); + +std::unique_ptr> +CreateIfrtLowerShardingToXlaPass(); + +std::unique_ptr> +CreateIfrtRemoveIfrtAttrsPass(); + +std::unique_ptr> +CreateIfrtLowerMpmdReshardToCallPass(); + +std::unique_ptr> +CreateIfrtVerifyBoundExternalLoadedExecutablePass( + std::shared_ptr bound_executable_map); + +// Compiles every atom program ModuleOp into LoadedExecutableOp, and +// lowers every CallOp to CallLoadedExecutableOp. +// +// This pass is not declared in td file as it doesn't have a default +// constructor. It uses an outside AtomProgramCompiler to delegate the +// compilation of atom programs. +// +// For example, the following code +// ``` +// %0, %ctrl_0 = ifrt.Call @callee::@main(%arg0) on devices [0, 1] +// +// module @callee attributes { +// func.func @main() {} +// } +// ``` +// +// will be replaced by +// ``` +// %0, %ctrl_0 = ifrt.CallLoadedExecutable @component__method(%arg0) +// +// ifrt.LoadedExecutable @component__method on devices [0, 1] +// ``` +// } +std::unique_ptr> +CreateIfrtCompileAtomProgramPass( + std::shared_ptr compiler, + std::shared_ptr< + absl::flat_hash_map>> + compile_options, + std::shared_ptr atom_executable_map); + +std::unique_ptr> +CreateIfrtCompileAndPropagateShardingsPass( + std::shared_ptr compiler, + std::shared_ptr< + absl::flat_hash_map>> + compile_options, + std::shared_ptr atom_executable_map); + // Generated definitions. This should be placed after all Pass creations. #define GEN_PASS_REGISTRATION #include "xla/python/ifrt/ir/transforms/passes.h.inc" // IWYU pragma: export +// Registers IfrtCompileAtomProgramPass to ifrt-opt. +void RegisterIfrtCompileAtomProgramPass( + std::shared_ptr compiler, + std::shared_ptr< + absl::flat_hash_map>> + compile_options_overrides, + std::shared_ptr atom_executable_map); + +// Registers IfrtCompileAndPropagateShardingsPass to ifrt-opt. +void RegisterIfrtCompileAndPropagateShardingsPass( + std::shared_ptr compiler, + std::shared_ptr< + absl::flat_hash_map>> + compile_options_overrides, + std::shared_ptr atom_executable_map); + +// Registers IfrtVerifyBoundExternalLoadedExecutablePass to ifrt-opt. +void RegisterIfrtVerifyBoundExternalLoadedExecutablePass( + std::shared_ptr bound_executable_map); + +struct IfrtToOutlinedAtomProgramsPipelineOptions + : mlir::PassPipelineOptions { + Option propagate_shardings{ + *this, "propagate_shardings", + llvm::cl::desc("Whether to propagate shardings from executables for " + "unspecified shardings.")}; +}; + +// Creates pipeline of all the IFRT IR passes that do not require +// compilation-time information (e.g., device assignments). +void CreateIfrtToOutlinedAtomProgramsPipeline( + mlir::OpPassManager& pm, + const IfrtToOutlinedAtomProgramsPipelineOptions& options); + +// Creates pipeline to lower an IFRT XLA program to be ready for compilation. +void CreateIfrtCompileXlaPreprocessingPipeline(mlir::OpPassManager& pm); + +// Registers passes and pipelines to ifrt-opt. +void RegisterIfrtPassesAndPipelines( + std::shared_ptr compiler, + std::shared_ptr< + absl::flat_hash_map>> + compile_options_overrides, + std::shared_ptr atom_executable_map, + std::shared_ptr bound_executable_map); + } // namespace ifrt } // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/passes.td b/third_party/xla/xla/python/ifrt/ir/transforms/passes.td index 10215b72653e0c..054e983a6499ab 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/passes.td +++ b/third_party/xla/xla/python/ifrt/ir/transforms/passes.td @@ -203,5 +203,222 @@ Verify that each `!ifrt.array` has sharding attribute that is not of type let constructor = "CreateIfrtVerifyShardingSpecifiedPass()"; } +def IfrtPopulateAtomProgramMetadataPass : + Pass<"ifrt-populate-atom-program-metadata", "mlir::ModuleOp"> { + let summary = "Populate metadata from call site to atom functions"; + let description = [{ +For every CallOp, this pass + 1. clones the callee's parent ModuleOp + 2. adds `ifrt.num_devices` attribute to the callee's parent ModuleOp + 2. attaches shardings and devices to the inputs and outputs of the callee's + main FuncOp + 3. attaches `tf.aliasing_output` attr to the callee main FuncOp's inputs + according to `io_aliases` + 4. attaches `jax.buffer_donor` attr to the callee main FuncOp's inputs + according to `donated_input_indices` + +For CallOps with the same callee, a different clone will be created for each +CallOp, even if the populated metadata are the same. User may want to run +`ifrt-duplicated-callee-elimination` pass to dedup the clones. + +For example, the following code + +```mlir +%0, %ctrl_0 = ifrt.Call @callee::@main(%arg0) on devices [0, 1] + {io_aliases=[array]} + : (!ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [0,1]>, + !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> + +module @callee { + func.func private @main(%arg0: tensor<2x2xi32>, %arg1: tensor<4x4xi32>) + -> tensor<4x4xi32> {} +} +``` + +will be replaced by + +```mlir +%0, %ctrl_0 = ifrt.Call @new_callee::@main(%arg0) on devices [0, 1] + : (!ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> + +module @new_callee attributes {ifrt.num_devices = 2} { + func.func private @new_callee( + %arg0: tensor<2x2xi32> { + ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2>, + ifrt.devices = #ifrt}, + %arg1: tensor<4x4xi32> { + ifrt.sharding = #ifrt.sharding_param<1x2 to [0] on 2>, + ifrt.devices = #ifrt + tf.aliasing_output = 0 : i32}) + -> (tensor<4x4xi32> { + ifrt.sharding = #ifrt.sharding_param<1x2 to [0] on 2>, + ifrt.devices = #ifrt}) + {} +} +``` + }]; + + let constructor = "CreateIfrtPopulateAtomProgramMetadataPass()"; +} + +def IfrtReshardToCopyArraysPass : + Pass<"ifrt-reshard-to-copy-arrays", "mlir::ModuleOp"> { + let summary = "Replaces `ifrt.Reshard` with `ifrt.CopyArrays`"; + let description = [{ +Replaces each `ifrt.Reshard` op with an `ifrt.Reshard` op with inputs only +the arrays that are being resharded, and several `ifrt.CopyArrays` ops to copy +the arrays that are not being resharded. An `ifrt.CopyArrays` op is added for +unique output `ifrt.DevicesAttr`. + +For example, the following code +```mlir +!array0 = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [0,1]> +!array1 = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [2,3]> +!array2 = !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, + [2,3]> +func.func @main(%arg0: !array0, %arg1: !array1) -> (!array1, !array2) + attributes {ifrt.function} { + %0, %1, %ctrl_0 = ifrt.Reshard(%arg0, %arg1) + : (!array0, !array1) -> (!array1, !array2) + return %0, %1: !array1, !array2 +} +``` + +will be replaced by: + +```mlir +func.func @main(%arg0: !array0, %arg1: !array1) -> (!array1, !array2) + attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Reshard(%arg1) : (!array1) -> !array2 + %1, %ctrl_1 = ifrt.CopyArrays(%arg0) : (!array0) -> !array1 + return %0, %1: !array1, !array2 +} +``` + }]; + + let constructor = "CreateIfrtReshardToCopyArraysPass()"; +} + +def IfrtLowerShardingToXlaPass : + Pass<"ifrt-lower-sharding-to-xla", "mlir::ModuleOp"> { + let summary = "Converts IFRT sharding to HLO sharding for xla modules"; + let description = [{ +Replaces `ifrt.sharding` attr by `mhlo.sharding` attr. + +For example, the following code + +```mlir +module attributes { + func.func @main( + %arg0: tensor<2x2xi32> { + ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2>}) + -> (tensor<2x2xi32> { + ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2>}) {} +} +``` + +will be replaced by + +```mlir +module attributes { + func.func @main( + %arg0: tensor<2x2xi32> {mhlo.sharding = "{devices=[2,1]<=[2]}"}) + -> (tensor<2x2xi32> {mhlo.sharding = "{devices=[2,1]<=[2]}"}) {} +} +``` + }]; + + let constructor = "CreateIfrtLowerShardingToXlaPass()"; +} + +def IfrtRemoveIfrtAttrsPass : + Pass<"ifrt-remove-ifrt-attrs", "mlir::ModuleOp"> { + let summary = "Remove IFRT attrs from modules and functions of atom programs"; + let description = [{ +Remove IFRT-related attributes from module and functions' arguments and +results. + +This is necessary for the compilation of TfSpmd program. + +For example, the following code + +```mlir +module { + func.func @main( + %arg0: tensor<2x2xi32> { + ifrt.sharding = #ifrt.sharding_param<1x1 to [0] on 1>}) + -> (tensor<2x2xi32> { + ifrt.sharding = #ifrt.sharding_param<1x1 to [0] on 1>}) {...} +} +``` + +will be replaced by + +```mlir +module { + func.func @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> {...} +} +``` + }]; + + let constructor = "CreateIfrtRemoveIfrtAttrsPass()"; +} + +def IfrtLowerMpmdReshardToCallPass : + Pass<"ifrt-lower-mpmd-reshard-to-call", "mlir::ModuleOp"> { + let summary = "Lowers MPMD `ifrt.Reshard` to `ifrt.Call`"; + let description = [{ +Replaces each MPMD `ifrt.Reshard` with an `ifrt.Call` to a newly created +module. The module has a main function with a `ifrt.reshard_function` attribute, +and that comprises of the `ifrt.Reshard` ops part of the MPMD reshard. + +For example, the following code +```mlir +!array0 = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [2]> +func.func @main(%arg0: !array0) -> (!array1) { + %0, %ctrl = ifrt.Reshard(%arg0) : (!array0) -> !array1 + return %0: !array1 +} +``` + +will be replaced by + +```mlir +!array0 = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [2]> +func.func @main(%arg0: !array0) -> (!array1) { + %0, %ctrl = ifrt.Call(%arg) {ifrt.module_type = "mpmd_reshard"} + : (!array0) -> !array1 + return %0: !array1 +} + +module @reshard attributes { + ifrt.num_devices = 3, + sym_visibility = "private"} { + func.func @main(%arg0: !array0) -> !array1 + attributes {ifrt.reshard_function} { + %0, %ctrl = ifrt.Reshard(%arg0) : (!array0) -> !array1 + return %0 : !array1 + } +} +``` + }]; + + let constructor = "CreateIfrtLowerMpmdReshardToCallPass()"; +} #endif // XLA_PYTHON_IFRT_IR_TRANSFORMS_PASSES_TD_ diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/utils.cc b/third_party/xla/xla/python/ifrt/ir/transforms/utils.cc index b1cb219e5e49fe..2a27edce6f20bb 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/utils.cc +++ b/third_party/xla/xla/python/ifrt/ir/transforms/utils.cc @@ -16,13 +16,66 @@ limitations under the License. #include "xla/python/ifrt/ir/transforms/utils.h" #include "absl/log/check.h" +#include "llvm/ADT/Hashing.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/Support/LLVM.h" +#include "xla/mlir/utils/type_util.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/ir/ifrt_dialect.h" +#include "xla/python/ifrt/ir/ifrt_ops.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" namespace xla { namespace ifrt { +unsigned IfrtCallOpInfo::getHashValue(CallOp call_op) { + llvm::hash_code hash = {}; + // Use `getInputs()/getOutputs()` instead of `getOperands()/getResults()` to + // ensure that the control dependencies are not included in the hash. + for (auto input_type : call_op.getInputs().getTypes()) { + hash = llvm::hash_combine(hash, input_type); + } + for (auto output_type : call_op.getOutputs().getTypes()) { + hash = llvm::hash_combine(hash, output_type); + } + for (mlir::NamedAttribute attr : call_op->getAttrs()) { + // Exclude `operandSegmentSizes` because its value changes depending on + // how many control dependencies a CallOp has. + if (attr.getName() == "operandSegmentSizes") { + continue; + } + hash = llvm::hash_combine(hash, attr); + } + return hash; +} + +bool IfrtCallOpInfo::isEqual(CallOp lhs, CallOp rhs) { + if (lhs == rhs) { + return true; + } + if (lhs == getEmptyKey() || lhs == getTombstoneKey() || + rhs == getEmptyKey() || rhs == getTombstoneKey()) { + return false; + } + // Verify that the input and output types are the same. + if (lhs.getInputs().getTypes() != rhs.getInputs().getTypes()) { + return false; + } + if (lhs.getOutputs().getTypes() != rhs.getOutputs().getTypes()) { + return false; + } + mlir::NamedAttrList lattrs = lhs->getAttrDictionary(); + mlir::NamedAttrList rattrs = rhs->getAttrDictionary(); + lattrs.erase("operandSegmentSizes"); + rattrs.erase("operandSegmentSizes"); + // Verify that the attributes are the same. + return lattrs == rattrs; +} + mlir::func::FuncOp GetMainFunction(mlir::ModuleOp module) { mlir::func::FuncOp func = mlir::dyn_cast_or_null(module.lookupSymbol("main")); @@ -30,5 +83,25 @@ mlir::func::FuncOp GetMainFunction(mlir::ModuleOp module) { return func; } +bool IsReshard(IfrtArrayType from, IfrtArrayType to) { + if (from.getShape() == to.getShape() && + from.getShardingAttr() == to.getShardingAttr() && + from.getDevices().size() == to.getDevices().size()) { + return false; + } + return true; +} + +void UpdateFunctionType(mlir::func::FuncOp func_op) { + func_op.setType(mlir::FunctionType::get( + func_op.getContext(), func_op.getBody().getArgumentTypes(), + func_op.getBody().front().getTerminator()->getOperandTypes())); +} + +absl::StatusOr ToIfrtDType(mlir::Type type) { + xla::PrimitiveType primitive_type = xla::ConvertMlirTypeToPrimitiveType(type); + return ToDType(primitive_type); +} + } // namespace ifrt } // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/utils.h b/third_party/xla/xla/python/ifrt/ir/transforms/utils.h index 81528e97f418ae..a10feb281f645c 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/utils.h +++ b/third_party/xla/xla/python/ifrt/ir/transforms/utils.h @@ -16,16 +16,38 @@ limitations under the License. #ifndef XLA_PYTHON_IFRT_IR_TRANSFORMS_UTILS_H_ #define XLA_PYTHON_IFRT_IR_TRANSFORMS_UTILS_H_ +#include "absl/status/statusor.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Types.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/python/ifrt/ir/ifrt_dialect.h" +#include "xla/python/ifrt/ir/ifrt_ops.h" namespace xla { namespace ifrt { +// Used for comparing CallOps without including control dependencies. +struct IfrtCallOpInfo : llvm::DenseMapInfo { + static unsigned getHashValue(xla::ifrt::CallOp call_op); + static bool isEqual(xla::ifrt::CallOp lhs, xla::ifrt::CallOp rhs); +}; + // Retrieves the function named "main" from the given module, if it exists, and // fails otherwise. mlir::func::FuncOp GetMainFunction(mlir::ModuleOp module); +// Returns true if transferring between from and to array requires a reshard. +bool IsReshard(xla::ifrt::IfrtArrayType from, xla::ifrt::IfrtArrayType to); + +// Updates the FunctionType of the given `func_op` to match the block arguments +// types and return operands types in its region. +void UpdateFunctionType(mlir::func::FuncOp func_op); + +// Converts a mlir::Type to a ifrt DType. +absl::StatusOr ToIfrtDType(mlir::Type type); + } // namespace ifrt } // namespace xla diff --git a/third_party/xla/xla/python/ifrt/memory.h b/third_party/xla/xla/python/ifrt/memory.h index 309d49705381e3..599bcc277f25a5 100644 --- a/third_party/xla/xla/python/ifrt/memory.h +++ b/third_party/xla/xla/python/ifrt/memory.h @@ -110,11 +110,27 @@ class Memory : public llvm::RTTIExtends { // Debug string suitable for logging when errors occur. Should be verbose // enough to describe the current device unambiguously. + // + // TODO(hyeontaek): Remove this method in favor of AbslStringify. virtual absl::string_view DebugString() const = 0; // The devices to which this memory space is attached. virtual absl::Span Devices() const = 0; + template + friend void AbslStringify(Sink& sink, const Memory& memory) { + sink.Append(memory.DebugString()); + } + + template + friend void AbslStringify(Sink& sink, const Memory* memory) { + if (memory == nullptr) { + sink.Append(""); + } else { + sink.Append(memory->DebugString()); + } + } + static char ID; // NOLINT }; diff --git a/third_party/xla/xla/python/ifrt/mock.cc b/third_party/xla/xla/python/ifrt/mock.cc index 9bd006d48753de..948ca6adab7261 100644 --- a/third_party/xla/xla/python/ifrt/mock.cc +++ b/third_party/xla/xla/python/ifrt/mock.cc @@ -50,6 +50,10 @@ char MockHostCallback::ID = 0; char MockLoadedHostCallback::ID = 0; char MockSharding::ID = 0; +namespace { +using ::testing::_; +} + // LINT.IfChange(MockArrayDelegation) MockArray::MockArray(tsl::RCReference delegated) : delegated_(std::move(delegated)) { @@ -62,9 +66,6 @@ MockArray::MockArray(tsl::RCReference delegated) ON_CALL(*this, IsDeleted).WillByDefault([this]() { return delegated_->IsDeleted(); }); - ON_CALL(*this, DebugString).WillByDefault([this]() { - return delegated_->DebugString(); - }); ON_CALL(*this, dtype).WillByDefault([this]() { return delegated_->dtype(); }); ON_CALL(*this, shape).WillByDefault([this]() -> const Shape& { return delegated_->shape(); @@ -79,10 +80,17 @@ MockArray::MockArray(tsl::RCReference delegated) .WillByDefault([this]() -> absl::StatusOr> { return delegated_->layout(); }); - ON_CALL(*this, DisassembleIntoSingleDeviceArrays) + ON_CALL(*this, DisassembleIntoSingleDeviceArrays(_)) .WillByDefault([this](ArrayCopySemantics semantics) { return delegated_->DisassembleIntoSingleDeviceArrays(semantics); }); + ON_CALL(*this, DisassembleIntoSingleDeviceArrays(_, _)) + .WillByDefault( + [this](ArrayCopySemantics array_copy_semantics, + SingleDeviceShardSemantics single_device_shard_semantics) { + return delegated_->DisassembleIntoSingleDeviceArrays( + array_copy_semantics, single_device_shard_semantics); + }); ON_CALL(*this, FullyReplicatedShard) .WillByDefault([this](ArrayCopySemantics semantics) { return delegated_->FullyReplicatedShard(semantics); @@ -111,7 +119,7 @@ MockClient::MockClient(std::unique_ptr delegated) data, dtype, std::move(shape), byte_strides, std::move(sharding), semantics, std::move(on_done_with_host_buffer)); }); - ON_CALL(*this, AssembleArrayFromSingleDeviceArrays) + ON_CALL(*this, AssembleArrayFromSingleDeviceArrays(_, _, _, _)) .WillByDefault([this](Shape shape, std::shared_ptr sharding, absl::Span> arrays, @@ -119,6 +127,16 @@ MockClient::MockClient(std::unique_ptr delegated) return delegated_->AssembleArrayFromSingleDeviceArrays( std::move(shape), std::move(sharding), arrays, semantics); }); + ON_CALL(*this, AssembleArrayFromSingleDeviceArrays(_, _, _, _, _)) + .WillByDefault( + [this](Shape shape, std::shared_ptr sharding, + absl::Span> arrays, + ArrayCopySemantics array_copy_semantics, + SingleDeviceShardSemantics single_device_shard_semantics) { + return delegated_->AssembleArrayFromSingleDeviceArrays( + std::move(shape), std::move(sharding), arrays, + array_copy_semantics, single_device_shard_semantics); + }); ON_CALL(*this, CopyArrays) .WillByDefault([this](absl::Span> arrays, std::optional> devices, @@ -172,6 +190,9 @@ MockClient::MockClient(std::unique_ptr delegated) ON_CALL(*this, process_index).WillByDefault([this]() { return delegated_->process_index(); }); + ON_CALL(*this, GetAllDevices).WillByDefault([this]() { + return delegated_->GetAllDevices(); + }); ON_CALL(*this, GetDefaultDeviceAssignment) .WillByDefault([this](int num_replicas, int num_partitions) { return delegated_->GetDefaultDeviceAssignment(num_replicas, @@ -214,12 +235,6 @@ MockDevice::MockDevice(Device* delegated) : delegated_(delegated) { return delegated_->ProcessIndex(); }); ON_CALL(*this, Kind).WillByDefault([this]() { return delegated_->Kind(); }); - ON_CALL(*this, DebugString).WillByDefault([this]() { - return delegated_->DebugString(); - }); - ON_CALL(*this, ToString).WillByDefault([this]() { - return delegated_->ToString(); - }); ON_CALL(*this, Attributes).WillByDefault([this]() -> const AttributeMap& { return delegated_->Attributes(); }); diff --git a/third_party/xla/xla/python/ifrt/mock.h b/third_party/xla/xla/python/ifrt/mock.h index 125c95e58a3fb5..08dc438dae9802 100644 --- a/third_party/xla/xla/python/ifrt/mock.h +++ b/third_party/xla/xla/python/ifrt/mock.h @@ -69,7 +69,6 @@ class MockArray : public llvm::RTTIExtends { MOCK_METHOD(Future<>, GetReadyFuture, (), (const, final)); MOCK_METHOD(Future<>, Delete, (), (final)); MOCK_METHOD(bool, IsDeleted, (), (const, final)); - MOCK_METHOD(std::string, DebugString, (), (const, final)); MOCK_METHOD(DType, dtype, (), (const, final)); MOCK_METHOD(const Shape&, shape, (), (const, final)); @@ -81,6 +80,11 @@ class MockArray : public llvm::RTTIExtends { MOCK_METHOD(absl::StatusOr>>, DisassembleIntoSingleDeviceArrays, (ArrayCopySemantics semantics), (final)); + MOCK_METHOD(absl::StatusOr>>, + DisassembleIntoSingleDeviceArrays, + (ArrayCopySemantics array_copy_semantics, + SingleDeviceShardSemantics single_device_shard_semantics), + (final)); MOCK_METHOD(absl::StatusOr>, FullyReplicatedShard, (ArrayCopySemantics semantics), (final)); MOCK_METHOD(Future<>, CopyToHostBuffer, @@ -92,6 +96,8 @@ class MockArray : public llvm::RTTIExtends { tsl::RCReference delegated() const { return delegated_; } + std::string DebugString() const final { return "MockArray"; } + static char ID; // NOLINT private: @@ -119,6 +125,13 @@ class MockClient : public llvm::RTTIExtends { absl::Span> arrays, ArrayCopySemantics semantics), (final)); + MOCK_METHOD(absl::StatusOr>, + AssembleArrayFromSingleDeviceArrays, + (Shape shape, std::shared_ptr sharding, + absl::Span> arrays, + ArrayCopySemantics array_copy_semantics, + SingleDeviceShardSemantics single_device_shard_semantics), + (final)); MOCK_METHOD(absl::StatusOr>>, CopyArrays, (absl::Span> arrays, std::optional> devices, @@ -145,6 +158,8 @@ class MockClient : public llvm::RTTIExtends { MOCK_METHOD(absl::Span, addressable_devices, (), (const, final)); MOCK_METHOD(int, process_index, (), (const, final)); + MOCK_METHOD(absl::Span, GetAllDevices, (), + (const, final)); MOCK_METHOD(absl::StatusOr, GetDefaultDeviceAssignment, (int num_replicas, int num_partitions), (const, final)); MOCK_METHOD(absl::StatusOr, LookupDevice, (DeviceId device_id), @@ -204,8 +219,6 @@ class MockDevice : public Device { MOCK_METHOD(int, ProcessIndex, (), (const, final)); MOCK_METHOD(DeviceId, Id, (), (const, final)); MOCK_METHOD(absl::string_view, Kind, (), (const, final)); - MOCK_METHOD(absl::string_view, DebugString, (), (const, final)); - MOCK_METHOD(absl::string_view, ToString, (), (const, final)); MOCK_METHOD((const AttributeMap&), Attributes, (), (const, final)); MOCK_METHOD(absl::StatusOr, DefaultMemory, (), (const, final)); MOCK_METHOD(absl::Span, Memories, (), (const, final)); @@ -213,6 +226,9 @@ class MockDevice : public Device { Device* delegated() const { return delegated_; } + absl::string_view DebugString() const final { return "MockDevice"; } + absl::string_view ToString() const final { return "MockDevice"; } + private: Device* const delegated_ = nullptr; }; @@ -224,8 +240,9 @@ class MockMemory : public Memory { MOCK_METHOD(MemoryId, Id, (), (const, final)); MOCK_METHOD(absl::Span, Devices, (), (const, final)); MOCK_METHOD(const MemoryKind&, Kind, (), (const, final)); - MOCK_METHOD(absl::string_view, DebugString, (), (const, final)); MOCK_METHOD(absl::string_view, ToString, (), (const, final)); + + absl::string_view DebugString() const final { return "MockMemory"; } }; // executable.h @@ -244,9 +261,9 @@ class MockExecutable : public llvm::RTTIExtends { (const, final)); MOCK_METHOD(std::optional>, GetOutputShardings, (), (const, final)); - MOCK_METHOD(absl::StatusOr>>, + MOCK_METHOD(absl::StatusOr>>, GetParameterLayouts, (), (const, final)); - MOCK_METHOD(absl::StatusOr>>, + MOCK_METHOD(absl::StatusOr>>, GetOutputLayouts, (), (const, final)); MOCK_METHOD(absl::StatusOr>>, GetHloModules, (), (const, final)); @@ -273,9 +290,9 @@ class MockLoadedExecutable (const, final)); MOCK_METHOD(std::optional>, GetOutputShardings, (), (const, final)); - MOCK_METHOD(absl::StatusOr>>, + MOCK_METHOD(absl::StatusOr>>, GetParameterLayouts, (), (const, final)); - MOCK_METHOD(absl::StatusOr>>, + MOCK_METHOD(absl::StatusOr>>, GetOutputLayouts, (), (const, final)); MOCK_METHOD(absl::StatusOr>>, GetOutputMemoryKinds, (), (const, final)); @@ -333,12 +350,28 @@ class MockSharding : public llvm::RTTIExtends { (absl::StatusOr< std::vector>>>), Disassemble, (const Shape& shape), (const, final)); + MOCK_METHOD( + (absl::StatusOr< + std::vector>>>), + Disassemble, + (const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics), + (const, final)); MOCK_METHOD((absl::StatusOr>>>), Disassemble, (const DynamicShape& dynamic_shape), (const final)); + MOCK_METHOD((absl::StatusOr>>>), + Disassemble, + (const DynamicShape& dynamic_shape, + SingleDeviceShardSemantics single_device_shard_semantics), + (const final)); MOCK_METHOD(absl::StatusOr>, IndexDomains, (const Shape& shape), (const, final)); - MOCK_METHOD(std::string, DebugString, (), (const, final)); + MOCK_METHOD(absl::StatusOr>, IndexDomains, + (const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics), + (const, final)); MOCK_METHOD(absl::StatusOr, GetShardShape, (const Shape& shape), (const, final)); MOCK_METHOD(bool, HasSamePartitioning, (const Sharding& other), @@ -348,6 +381,8 @@ class MockSharding : public llvm::RTTIExtends { std::optional memory_kind), (const final)); + std::string DebugString() const final { return "MockSharding"; } + static char ID; // NOLINT }; diff --git a/third_party/xla/xla/python/ifrt/plugin_program_serdes_test.cc b/third_party/xla/xla/python/ifrt/plugin_program_serdes_test.cc index 4edfae40571cae..9e8ee895723310 100644 --- a/third_party/xla/xla/python/ifrt/plugin_program_serdes_test.cc +++ b/third_party/xla/xla/python/ifrt/plugin_program_serdes_test.cc @@ -19,9 +19,9 @@ #include "xla/python/ifrt/serdes.h" #include "xla/python/ifrt/serdes.pb.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/error_codes.pb.h" +#include "xla/tsl/protobuf/status.pb.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/error_codes.pb.h" -#include "tsl/protobuf/status.pb.h" namespace xla { namespace ifrt { diff --git a/third_party/xla/xla/python/ifrt/shape.h b/third_party/xla/xla/python/ifrt/shape.h index 8726f4dd14a30c..e617aaeef52243 100644 --- a/third_party/xla/xla/python/ifrt/shape.h +++ b/third_party/xla/xla/python/ifrt/shape.h @@ -67,8 +67,14 @@ class Shape { // Total number of elements in this shape. int64_t num_elements() const; + // TODO(hyeontaek): Remove this method in favor of AbslStringify. std::string DebugString() const; + template + friend void AbslStringify(Sink& sink, const Shape& shape) { + sink.Append(shape.DebugString()); + } + private: Dimensions dims_; }; diff --git a/third_party/xla/xla/python/ifrt/sharding.cc b/third_party/xla/xla/python/ifrt/sharding.cc index dab5dc32ef3ca5..383d3a317b8075 100644 --- a/third_party/xla/xla/python/ifrt/sharding.cc +++ b/third_party/xla/xla/python/ifrt/sharding.cc @@ -237,26 +237,63 @@ SingleDeviceSharding::WithDeviceAssignment( absl::StatusOr>>> SingleDeviceSharding::Disassemble(const Shape& shape) const { DCHECK(this); - return std::vector>>{ - {shape, SingleDeviceSharding::Create(devices_->devices().front(), - memory_kind_)}}; + return Disassemble(shape, SingleDeviceShardSemantics::kAllShards); +} + +absl::StatusOr>>> +SingleDeviceSharding::Disassemble( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const { + DCHECK(this); + std::vector>> result; + if (single_device_shard_semantics == SingleDeviceShardSemantics::kAllShards || + devices_->devices().front()->IsAddressable()) { + result.reserve(1); + result.push_back({shape, SingleDeviceSharding::Create( + devices_->devices().front(), memory_kind_)}); + } + return result; } absl::StatusOr< std::vector>>> SingleDeviceSharding::Disassemble(const DynamicShape& dynamic_shape) const { DCHECK(this); - return std::vector>>{ - {dynamic_shape, SingleDeviceSharding::Create(devices_->devices().front(), - memory_kind_)}}; + return Disassemble(dynamic_shape, SingleDeviceShardSemantics::kAllShards); +} +absl::StatusOr< + std::vector>>> +SingleDeviceSharding::Disassemble( + const DynamicShape& dynamic_shape, + SingleDeviceShardSemantics single_device_shard_semantics) const { + DCHECK(this); + std::vector>> result; + if (single_device_shard_semantics == SingleDeviceShardSemantics::kAllShards || + devices_->devices().front()->IsAddressable()) { + result.reserve(1); + result.push_back( + {dynamic_shape, SingleDeviceSharding::Create( + devices_->devices().front(), memory_kind_)}); + } + return result; } absl::StatusOr> SingleDeviceSharding::IndexDomains( const Shape& shape) const { DCHECK(this); + return IndexDomains(shape, SingleDeviceShardSemantics::kAllShards); +} + +absl::StatusOr> SingleDeviceSharding::IndexDomains( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const { + DCHECK(this); std::vector result; - result.reserve(1); - result.push_back(IndexDomain(shape)); + if (single_device_shard_semantics == SingleDeviceShardSemantics::kAllShards || + devices_->devices().front()->IsAddressable()) { + result.reserve(1); + result.push_back(IndexDomain(shape)); + } return result; } @@ -308,6 +345,14 @@ absl::StatusOr> OpaqueSharding::WithDeviceAssignment( absl::StatusOr>>> OpaqueSharding::Disassemble(const Shape& shape) const { DCHECK(this); + return Disassemble(shape, SingleDeviceShardSemantics::kAllShards); +} + +absl::StatusOr>>> +OpaqueSharding::Disassemble( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const { + DCHECK(this); return InvalidArgument( "OpaqueSharding does not have shard shape information"); } @@ -316,6 +361,15 @@ absl::StatusOr< std::vector>>> OpaqueSharding::Disassemble(const DynamicShape& dynamic_shape) const { DCHECK(this); + return Disassemble(dynamic_shape, SingleDeviceShardSemantics::kAllShards); +} + +absl::StatusOr< + std::vector>>> +OpaqueSharding::Disassemble( + const DynamicShape& dynamic_shape, + SingleDeviceShardSemantics single_device_shard_semantics) const { + DCHECK(this); return InvalidArgument( "OpaqueSharding does not have shard shape information"); } @@ -323,6 +377,13 @@ OpaqueSharding::Disassemble(const DynamicShape& dynamic_shape) const { absl::StatusOr> OpaqueSharding::IndexDomains( const Shape& shape) const { DCHECK(this); + return IndexDomains(shape, SingleDeviceShardSemantics::kAllShards); +} + +absl::StatusOr> OpaqueSharding::IndexDomains( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const { + DCHECK(this); return InvalidArgument( "OpaqueSharding does not have index domain information"); } @@ -413,6 +474,14 @@ ConcreteSharding::WithDeviceAssignment( absl::StatusOr>>> ConcreteSharding::Disassemble(const Shape& shape) const { DCHECK(this); + return Disassemble(shape, SingleDeviceShardSemantics::kAllShards); +} + +absl::StatusOr>>> +ConcreteSharding::Disassemble( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const { + DCHECK(this); if (!has_static_shape()) { return InvalidArgument( "ConcreteSharding holds dynamic shape, but was asked " @@ -428,11 +497,19 @@ ConcreteSharding::Disassemble(const Shape& shape) const { std::vector>> result; const std::vector& shard_shapes = std::get>(shard_shapes_); + if (single_device_shard_semantics == SingleDeviceShardSemantics::kAllShards) { + result.reserve(devices_->size()); + } else { + result.reserve(devices_->AddressableDeviceList()->size()); + } const absl::Span devices = devices_->devices(); - result.reserve(devices.size()); for (int i = 0; i < devices.size(); ++i) { - result.push_back({shard_shapes[i], - SingleDeviceSharding::Create(devices[i], memory_kind_)}); + if (single_device_shard_semantics == + SingleDeviceShardSemantics::kAllShards || + devices[i]->IsAddressable()) { + result.push_back({shard_shapes[i], SingleDeviceSharding::Create( + devices[i], memory_kind_)}); + } } return result; } @@ -441,6 +518,15 @@ absl::StatusOr< std::vector>>> ConcreteSharding::Disassemble(const DynamicShape& dynamic_shape) const { DCHECK(this); + return Disassemble(dynamic_shape, SingleDeviceShardSemantics::kAllShards); +} + +absl::StatusOr< + std::vector>>> +ConcreteSharding::Disassemble( + const DynamicShape& dynamic_shape, + SingleDeviceShardSemantics single_device_shard_semantics) const { + DCHECK(this); if (!has_dynamic_shape()) { return InvalidArgument( "ConcreteSharding holds static shape, but was asked " @@ -458,10 +544,19 @@ ConcreteSharding::Disassemble(const DynamicShape& dynamic_shape) const { const std::vector& shard_dynamic_shapes = std::get>(shard_shapes_); const absl::Span devices = devices_->devices(); - result.reserve(devices.size()); + if (single_device_shard_semantics == SingleDeviceShardSemantics::kAllShards) { + result.reserve(devices_->size()); + } else { + result.reserve(devices_->AddressableDeviceList()->size()); + } for (int i = 0; i < devices.size(); ++i) { - result.push_back({shard_dynamic_shapes[i], - SingleDeviceSharding::Create(devices[i], memory_kind_)}); + if (single_device_shard_semantics == + SingleDeviceShardSemantics::kAllShards || + devices[i]->IsAddressable()) { + result.push_back( + {shard_dynamic_shapes[i], + SingleDeviceSharding::Create(devices[i], memory_kind_)}); + } } return result; } @@ -469,6 +564,13 @@ ConcreteSharding::Disassemble(const DynamicShape& dynamic_shape) const { absl::StatusOr> ConcreteSharding::IndexDomains( const Shape& shape) const { DCHECK(this); + return IndexDomains(shape, SingleDeviceShardSemantics::kAllShards); +} + +absl::StatusOr> ConcreteSharding::IndexDomains( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const { + DCHECK(this); return InvalidArgument( "ConcreteSharding does not have index domain information"); } @@ -552,6 +654,14 @@ ConcreteEvenSharding::WithDeviceAssignment( absl::StatusOr>>> ConcreteEvenSharding::Disassemble(const Shape& shape) const { DCHECK(this); + return Disassemble(shape, SingleDeviceShardSemantics::kAllShards); +} + +absl::StatusOr>>> +ConcreteEvenSharding::Disassemble( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const { + DCHECK(this); if (shape != shape_) { return InvalidArgument( "ConcreteEvenSharding can only disassemble shape %s, but was asked " @@ -560,10 +670,18 @@ ConcreteEvenSharding::Disassemble(const Shape& shape) const { } std::vector>> result; const absl::Span devices = devices_->devices(); - result.reserve(devices.size()); + if (single_device_shard_semantics == SingleDeviceShardSemantics::kAllShards) { + result.reserve(devices_->size()); + } else { + result.reserve(devices_->AddressableDeviceList()->size()); + } for (int i = 0; i < devices.size(); ++i) { - result.push_back( - {shard_shape_, SingleDeviceSharding::Create(devices[i], memory_kind_)}); + if (single_device_shard_semantics == + SingleDeviceShardSemantics::kAllShards || + devices[i]->IsAddressable()) { + result.push_back({shard_shape_, SingleDeviceSharding::Create( + devices[i], memory_kind_)}); + } } return result; } @@ -571,6 +689,16 @@ ConcreteEvenSharding::Disassemble(const Shape& shape) const { absl::StatusOr< std::vector>>> ConcreteEvenSharding::Disassemble(const DynamicShape& dynamic_shape) const { + DCHECK(this); + return Disassemble(dynamic_shape, SingleDeviceShardSemantics::kAllShards); +} + +absl::StatusOr< + std::vector>>> +ConcreteEvenSharding::Disassemble( + const DynamicShape& dynamic_shape, + SingleDeviceShardSemantics single_device_shard_semantics) const { + DCHECK(this); return InvalidArgument( "ConcreteEvenSharding can only disassemble static shape, but was asked " "to disassemble dynamic shape %s", @@ -580,6 +708,12 @@ ConcreteEvenSharding::Disassemble(const DynamicShape& dynamic_shape) const { absl::StatusOr> ConcreteEvenSharding::IndexDomains( const Shape& shape) const { DCHECK(this); + return IndexDomains(shape, SingleDeviceShardSemantics::kAllShards); +} +absl::StatusOr> ConcreteEvenSharding::IndexDomains( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const { + DCHECK(this); return InvalidArgument( "ConcreteEvenSharding does not have index domain information"); } @@ -622,12 +756,29 @@ ShardingParamSharding::ShardingParamSharding( absl::StatusOr>>> ShardingParamSharding::Disassemble(const Shape& shape) const { DCHECK(this); + return Disassemble(shape, SingleDeviceShardSemantics::kAllShards); +} + +absl::StatusOr>>> +ShardingParamSharding::Disassemble( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const { + DCHECK(this); TF_ASSIGN_OR_RETURN(Shape local_shape, GetShardShape(shape)); std::vector>> result; + if (single_device_shard_semantics == SingleDeviceShardSemantics::kAllShards) { + result.reserve(devices_->size()); + } else { + result.reserve(devices_->AddressableDeviceList()->size()); + } for (Device* device : devices_->devices()) { - result.push_back( - {local_shape, SingleDeviceSharding::Create(device, memory_kind_)}); + if (single_device_shard_semantics == + SingleDeviceShardSemantics::kAllShards || + device->IsAddressable()) { + result.push_back( + {local_shape, SingleDeviceSharding::Create(device, memory_kind_)}); + } } return result; @@ -684,6 +835,16 @@ ShardingParamSharding::WithDeviceAssignment( absl::StatusOr< std::vector>>> ShardingParamSharding::Disassemble(const DynamicShape& dynamic_shape) const { + DCHECK(this); + return Disassemble(dynamic_shape, SingleDeviceShardSemantics::kAllShards); +} + +absl::StatusOr< + std::vector>>> +ShardingParamSharding::Disassemble( + const DynamicShape& dynamic_shape, + SingleDeviceShardSemantics single_device_shard_semantics) const { + DCHECK(this); return InvalidArgument( "ShardingParamSharding can only disassemble static shape, but was asked " "to disassemble dynamic shape %s", @@ -693,6 +854,13 @@ ShardingParamSharding::Disassemble(const DynamicShape& dynamic_shape) const { absl::StatusOr> ShardingParamSharding::IndexDomains( const Shape& shape) const { DCHECK(this); + return IndexDomains(shape, SingleDeviceShardSemantics::kAllShards); +} + +absl::StatusOr> ShardingParamSharding::IndexDomains( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const { + DCHECK(this); // Calculate the origins of tiles, ignoring device assignments. TF_ASSIGN_OR_RETURN(Shape local_shape, GetShardShape(shape)); @@ -718,12 +886,22 @@ absl::StatusOr> ShardingParamSharding::IndexDomains( DCHECK_EQ(device_to_index.size() % origins.size(), 0); int replication = device_to_index.size() / origins.size(); + DCHECK_EQ(device_to_index.size(), devices_->size()); std::vector result; - result.reserve(device_to_index.size()); + if (single_device_shard_semantics == SingleDeviceShardSemantics::kAllShards) { + result.reserve(devices_->size()); + } else { + result.reserve(devices_->AddressableDeviceList()->size()); + } + const absl::Span devices = devices_->devices(); for (int i = 0; i < device_to_index.size(); ++i) { - int index = device_to_index[i]; - DCHECK_NE(index, kInvalidIndex); - result.push_back(IndexDomain(origins[index / replication], local_shape)); + if (single_device_shard_semantics == + SingleDeviceShardSemantics::kAllShards || + devices[i]->IsAddressable()) { + int index = device_to_index[i]; + DCHECK_NE(index, kInvalidIndex); + result.push_back(IndexDomain(origins[index / replication], local_shape)); + } } return result; } diff --git a/third_party/xla/xla/python/ifrt/sharding.h b/third_party/xla/xla/python/ifrt/sharding.h index 2e77dbbfa190b2..58e1b0bf052b3e 100644 --- a/third_party/xla/xla/python/ifrt/sharding.h +++ b/third_party/xla/xla/python/ifrt/sharding.h @@ -43,6 +43,31 @@ namespace ifrt { struct DeserializeShardingOptions; +// Semantics for operations that take or return single-device shards of arrays +// or shardings. +enum class SingleDeviceShardSemantics : int { + // Processes only the single-device shards on addresable devices. + // + // * Assembly takes single-device arrays/shards for every addressable shard of + // an assembled array/sharding. + // + // * Disassembly returns single-device arrays/shards for every addressable + // shard of an assembled array/sharding. + kAddressableShards = 0, + + // Processes single-device shards on all devices. + // + // * Assembly takes single-device arrays/shards for every + // addressable/non-addressable shard of an assembled array/sharding. + // + // * Disassembly returns single-device arrays/shards for every + // addressable/non-addressable shard of an assembled array/sharding. + // + // Runtimes that cannot express single-device arrays on a non-addressable + // device does not support this semantics no array operations. + kAllShards, +}; + // Abstract sharding type. // // TODO(hyeontaek): There is an indication that we may prefer to split logical @@ -93,22 +118,42 @@ class Sharding : public llvm::RTTIExtends { // Breaks a shape up into per-device shapes and shardings. See // Array::DisassembleIntoSingleDeviceArrays(). It may return an error if // disassembly is unsupported. + // TODO(hyeontaek): Replace this API with the version that takes + // `SingleDeviceShardSemantics`. virtual absl::StatusOr< std::vector>>> Disassemble(const Shape& shape) const = 0; + virtual absl::StatusOr< + std::vector>>> + Disassemble( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const = 0; // Variant of `Disassemble` that takes a dynamic shape. + // TODO(hyeontaek): Replace this API with the version that takes + // `SingleDeviceShardSemantics`. virtual absl::StatusOr< std::vector>>> Disassemble(const DynamicShape& dynamic_shape) const = 0; + virtual absl::StatusOr< + std::vector>>> + Disassemble( + const DynamicShape& dynamic_shape, + SingleDeviceShardSemantics single_device_shard_semantics) const = 0; // Maps each shard to an `IndexDomain` over `shape`. The result is a list of // `index_domain_i` such that `array[index_domain_i] = disassembled_array_i`. // Note that multiple shards may map onto equal `IndexDomain`. For instance, a // fully replicated sharding would return a vector of `[IndexDomain(shape)] * - // devices().size()`. + // devices().size()` if `single_device_shard_semantics == + // SingleDeviceShardSemantics::kAllShards`. + // TODO(hyeontaek): Replace this API with the version that takes + // `SingleDeviceShardSemantics`. virtual absl::StatusOr> IndexDomains( const Shape& shape) const = 0; + virtual absl::StatusOr> IndexDomains( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const = 0; // Deserializes `ShardingProto` into `Sharding`. // Note that `Sharding` serialization uses `SerDes` to handle an open set of @@ -122,8 +167,24 @@ class Sharding : public llvm::RTTIExtends { // `Sharding` subclasses. See `serdes.h`. absl::StatusOr ToProto() const; + // TODO(hyeontaek): Remove this method in favor of AbslStringify. virtual std::string DebugString() const = 0; + template + friend void AbslStringify(Sink& sink, const Sharding& sharding) { + sink.Append(sharding.DebugString()); + } + + template + friend void AbslStringify(Sink& sink, + const std::shared_ptr& sharding) { + if (sharding == nullptr) { + sink.Append(""); + } else { + sink.Append(sharding->DebugString()); + } + } + static char ID; // NOLINT protected: @@ -166,13 +227,25 @@ class SingleDeviceSharding final absl::StatusOr>>> Disassemble(const Shape& shape) const override; + absl::StatusOr>>> + Disassemble( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const override; absl::StatusOr< std::vector>>> Disassemble(const DynamicShape& dynamic_shape) const override; + absl::StatusOr< + std::vector>>> + Disassemble( + const DynamicShape& dynamic_shape, + SingleDeviceShardSemantics single_device_shard_semantics) const override; absl::StatusOr> IndexDomains( const Shape& shape) const override; + absl::StatusOr> IndexDomains( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const override; std::string DebugString() const override; @@ -208,13 +281,25 @@ class OpaqueSharding : public llvm::RTTIExtends { absl::StatusOr>>> Disassemble(const Shape& shape) const override; + absl::StatusOr>>> + Disassemble( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const override; absl::StatusOr< std::vector>>> Disassemble(const DynamicShape& dynamic_shape) const override; + absl::StatusOr< + std::vector>>> + Disassemble( + const DynamicShape& dynamic_shape, + SingleDeviceShardSemantics single_device_shard_semantics) const override; absl::StatusOr> IndexDomains( const Shape& shape) const override; + absl::StatusOr> IndexDomains( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const override; std::string DebugString() const override; @@ -295,12 +380,25 @@ class ConcreteSharding : public llvm::RTTIExtends { absl::StatusOr>>> Disassemble(const Shape& shape) const override; + absl::StatusOr>>> + Disassemble( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const override; + absl::StatusOr< std::vector>>> Disassemble(const DynamicShape& dynamic_shape) const override; + absl::StatusOr< + std::vector>>> + Disassemble( + const DynamicShape& dynamic_shape, + SingleDeviceShardSemantics single_device_shard_semantics) const override; absl::StatusOr> IndexDomains( const Shape& shape) const override; + absl::StatusOr> IndexDomains( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const override; std::string DebugString() const override; @@ -355,12 +453,25 @@ class ConcreteEvenSharding absl::StatusOr>>> Disassemble(const Shape& shape) const override; + absl::StatusOr>>> + Disassemble( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const override; + absl::StatusOr< std::vector>>> Disassemble(const DynamicShape& dynamic_shape) const override; + absl::StatusOr< + std::vector>>> + Disassemble( + const DynamicShape& dynamic_shape, + SingleDeviceShardSemantics single_device_shard_semantics) const override; absl::StatusOr> IndexDomains( const Shape& shape) const override; + absl::StatusOr> IndexDomains( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const override; std::string DebugString() const override; @@ -396,12 +507,25 @@ class ShardingParamSharding absl::StatusOr>>> Disassemble(const Shape& shape) const override; + absl::StatusOr>>> + Disassemble( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const override; + absl::StatusOr< std::vector>>> Disassemble(const DynamicShape& dynamic_shape) const override; + absl::StatusOr< + std::vector>>> + Disassemble( + const DynamicShape& dynamic_shape, + SingleDeviceShardSemantics single_device_shard_semantics) const override; absl::StatusOr> IndexDomains( const Shape& shape) const override; + absl::StatusOr> IndexDomains( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const override; std::string DebugString() const override; diff --git a/third_party/xla/xla/python/ifrt/sharding_test.cc b/third_party/xla/xla/python/ifrt/sharding_test.cc index 1fc353c8793ff4..23c4e015672b1e 100644 --- a/third_party/xla/xla/python/ifrt/sharding_test.cc +++ b/third_party/xla/xla/python/ifrt/sharding_test.cc @@ -115,8 +115,23 @@ TEST_P(SingleDeviceShardingTest, IndexDomains) { device_list->devices().front(), MemoryKind()); Shape shape({10, 20}); - TF_ASSERT_OK_AND_ASSIGN(auto index_domains, sharding->IndexDomains(shape)); - EXPECT_THAT(index_domains, ElementsAre(IndexDomain(shape))); + { + TF_ASSERT_OK_AND_ASSIGN(auto index_domains, sharding->IndexDomains(shape)); + EXPECT_THAT(index_domains, ElementsAre(IndexDomain(shape))); + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto index_domains, + sharding->IndexDomains(shape, SingleDeviceShardSemantics::kAllShards)); + EXPECT_THAT(index_domains, ElementsAre(IndexDomain(shape))); + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto index_domains, + sharding->IndexDomains(shape, + SingleDeviceShardSemantics::kAddressableShards)); + EXPECT_THAT(index_domains, ElementsAre(IndexDomain(shape))); + } } TEST_P(SingleDeviceShardingTest, Disassemble) { @@ -126,25 +141,66 @@ TEST_P(SingleDeviceShardingTest, Disassemble) { { // Disassemble static shape. Shape shape({10, 20}); - TF_ASSERT_OK_AND_ASSIGN(auto disassembled, sharding->Disassemble(shape)); - - ASSERT_THAT(disassembled, SizeIs(1)); - const auto& [result_shape, result_sharding] = disassembled[0]; - EXPECT_EQ(shape, result_shape); - EXPECT_EQ(*result_sharding, *sharding); + { + TF_ASSERT_OK_AND_ASSIGN(auto disassembled, sharding->Disassemble(shape)); + ASSERT_THAT(disassembled, SizeIs(1)); + const auto& [result_shape, result_sharding] = disassembled[0]; + EXPECT_EQ(shape, result_shape); + EXPECT_EQ(*result_sharding, *sharding); + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + sharding->Disassemble(shape, SingleDeviceShardSemantics::kAllShards)); + ASSERT_THAT(disassembled, SizeIs(1)); + const auto& [result_shape, result_sharding] = disassembled[0]; + EXPECT_EQ(shape, result_shape); + EXPECT_EQ(*result_sharding, *sharding); + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + sharding->Disassemble( + shape, SingleDeviceShardSemantics::kAddressableShards)); + ASSERT_THAT(disassembled, SizeIs(1)); + const auto& [result_shape, result_sharding] = disassembled[0]; + EXPECT_EQ(shape, result_shape); + EXPECT_EQ(*result_sharding, *sharding); + } } { // Disassemble dynamic shape. TF_ASSERT_OK_AND_ASSIGN( DynamicShape dynamic_shape, DynamicShape::Create(Shape({10, 20}), BoundedDynamicShapeTag({true, true}))); - TF_ASSERT_OK_AND_ASSIGN(auto disassembled, - sharding->Disassemble(dynamic_shape)); - - ASSERT_THAT(disassembled, SizeIs(1)); - const auto& [result_shape, result_sharding] = disassembled[0]; - EXPECT_EQ(dynamic_shape, result_shape); - EXPECT_EQ(*result_sharding, *sharding); + { + TF_ASSERT_OK_AND_ASSIGN(auto disassembled, + sharding->Disassemble(dynamic_shape)); + ASSERT_THAT(disassembled, SizeIs(1)); + const auto& [result_shape, result_sharding] = disassembled[0]; + EXPECT_EQ(dynamic_shape, result_shape); + EXPECT_EQ(*result_sharding, *sharding); + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + sharding->Disassemble(dynamic_shape, + SingleDeviceShardSemantics::kAllShards)); + ASSERT_THAT(disassembled, SizeIs(1)); + const auto& [result_shape, result_sharding] = disassembled[0]; + EXPECT_EQ(dynamic_shape, result_shape); + EXPECT_EQ(*result_sharding, *sharding); + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + sharding->Disassemble( + dynamic_shape, SingleDeviceShardSemantics::kAddressableShards)); + ASSERT_THAT(disassembled, SizeIs(1)); + const auto& [result_shape, result_sharding] = disassembled[0]; + EXPECT_EQ(dynamic_shape, result_shape); + EXPECT_EQ(*result_sharding, *sharding); + } } } @@ -375,51 +431,128 @@ TEST_P(ConcreteShardingTest, WithDeviceAssignment) { } TEST_P(ConcreteShardingTest, Disassemble) { - auto device_list = GetDevices({0, 1}); + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); std::vector shard_shapes; shard_shapes.reserve(2); - shard_shapes.push_back(Shape({10})); - shard_shapes.push_back(Shape({20})); + shard_shapes.push_back(Shape({3})); + shard_shapes.push_back(Shape({7})); + shard_shapes.push_back(Shape({3})); + shard_shapes.push_back(Shape({7})); + shard_shapes.push_back(Shape({3})); + shard_shapes.push_back(Shape({7})); std::shared_ptr sharding = ConcreteSharding::Create( device_list, MemoryKind(), Shape({30}), shard_shapes); - TF_ASSERT_OK_AND_ASSIGN(auto disassembled, - sharding->Disassemble(Shape({30}))); - ASSERT_THAT(disassembled, SizeIs(2)); - for (int i = 0; i < 2; ++i) { - const auto& [shape, sharding] = disassembled[i]; - EXPECT_EQ(shape, shard_shapes[i]); - EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( - device_list->devices()[i], MemoryKind())); + { + TF_ASSERT_OK_AND_ASSIGN(auto disassembled, + sharding->Disassemble(Shape({30}))); + ASSERT_THAT(disassembled, SizeIs(6)); + for (int i = 0; i < 6; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, shard_shapes[i]); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + sharding->Disassemble(Shape({30}), + SingleDeviceShardSemantics::kAllShards)); + ASSERT_THAT(disassembled, SizeIs(6)); + for (int i = 0; i < 6; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, shard_shapes[i]); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + sharding->Disassemble(Shape({30}), + SingleDeviceShardSemantics::kAddressableShards)); + // The first 4 devices are addressable. + ASSERT_THAT(disassembled, SizeIs(4)); + for (int i = 0; i < 4; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, shard_shapes[i]); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } } } TEST_P(ConcreteShardingTest, DisassembleDynamicShape) { - auto device_list = GetDevices({0, 1}); + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); TF_ASSERT_OK_AND_ASSIGN( DynamicShape dynamic_shape, - DynamicShape::Create(Shape({10}), BoundedDynamicShapeTag({true}))); + DynamicShape::Create(Shape({30}), BoundedDynamicShapeTag({true}))); TF_ASSERT_OK_AND_ASSIGN( - DynamicShape shard_dynamic_shape1, + DynamicShape shard_dynamic_shape0, DynamicShape::Create(Shape({3}), BoundedDynamicShapeTag({true}))); + TF_ASSERT_OK_AND_ASSIGN( + DynamicShape shard_dynamic_shape1, + DynamicShape::Create(Shape({7}), BoundedDynamicShapeTag({true}))); TF_ASSERT_OK_AND_ASSIGN( DynamicShape shard_dynamic_shape2, + DynamicShape::Create(Shape({3}), BoundedDynamicShapeTag({true}))); + TF_ASSERT_OK_AND_ASSIGN( + DynamicShape shard_dynamic_shape3, + DynamicShape::Create(Shape({7}), BoundedDynamicShapeTag({true}))); + TF_ASSERT_OK_AND_ASSIGN( + DynamicShape shard_dynamic_shape4, + DynamicShape::Create(Shape({3}), BoundedDynamicShapeTag({true}))); + TF_ASSERT_OK_AND_ASSIGN( + DynamicShape shard_dynamic_shape5, DynamicShape::Create(Shape({7}), BoundedDynamicShapeTag({true}))); std::vector shard_dynamic_shapes{ - std::move(shard_dynamic_shape1), std::move(shard_dynamic_shape2)}; + std::move(shard_dynamic_shape0), std::move(shard_dynamic_shape1), + std::move(shard_dynamic_shape2), std::move(shard_dynamic_shape3), + std::move(shard_dynamic_shape4), std::move(shard_dynamic_shape5), + }; auto sharding = ConcreteSharding::Create(device_list, MemoryKind(), dynamic_shape, shard_dynamic_shapes); - EXPECT_THAT(sharding->Disassemble(Shape({10})), + EXPECT_THAT(sharding->Disassemble(Shape({30})), StatusIs(tsl::error::INVALID_ARGUMENT, HasSubstr("ConcreteSharding holds dynamic shape"))); - TF_ASSERT_OK_AND_ASSIGN(auto disassembled, - sharding->Disassemble(DynamicShape(dynamic_shape))); - ASSERT_THAT(disassembled, SizeIs(2)); - for (int i = 0; i < disassembled.size(); ++i) { - const auto& [dynamic_shape, sharding] = disassembled[i]; - EXPECT_EQ(dynamic_shape, shard_dynamic_shapes[i]); - EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( - device_list->devices()[i], MemoryKind())); + { + TF_ASSERT_OK_AND_ASSIGN(auto disassembled, + sharding->Disassemble(DynamicShape(dynamic_shape))); + ASSERT_THAT(disassembled, SizeIs(6)); + for (int i = 0; i < 6; ++i) { + const auto& [dynamic_shape, sharding] = disassembled[i]; + EXPECT_EQ(dynamic_shape, shard_dynamic_shapes[i]); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + sharding->Disassemble(DynamicShape(dynamic_shape), + SingleDeviceShardSemantics::kAllShards)); + ASSERT_THAT(disassembled, SizeIs(6)); + for (int i = 0; i < 6; ++i) { + const auto& [dynamic_shape, sharding] = disassembled[i]; + EXPECT_EQ(dynamic_shape, shard_dynamic_shapes[i]); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + sharding->Disassemble(DynamicShape(dynamic_shape), + SingleDeviceShardSemantics::kAddressableShards)); + // The first 4 devices are addressable. + ASSERT_THAT(disassembled, SizeIs(4)); + for (int i = 0; i < 4; ++i) { + const auto& [dynamic_shape, sharding] = disassembled[i]; + EXPECT_EQ(dynamic_shape, shard_dynamic_shapes[i]); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } } } @@ -574,19 +707,48 @@ TEST_P(ConcreteEvenShardingTest, WithDeviceAssignment) { } TEST_P(ConcreteEvenShardingTest, Disassemble) { - auto device_list = GetDevices({0, 1}); + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); std::shared_ptr sharding = ConcreteEvenSharding::Create(device_list, MemoryKind(), Shape({30}), - Shape({15}), /*is_fully_replicated=*/false); + Shape({5}), /*is_fully_replicated=*/false); - TF_ASSERT_OK_AND_ASSIGN(auto disassembled, - sharding->Disassemble(Shape({30}))); - ASSERT_THAT(disassembled, SizeIs(2)); - for (int i = 0; i < 2; ++i) { - const auto& [shape, sharding] = disassembled[i]; - EXPECT_EQ(shape, Shape({15})); - EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( - device_list->devices()[i], MemoryKind())); + { + TF_ASSERT_OK_AND_ASSIGN(auto disassembled, + sharding->Disassemble(Shape({30}))); + ASSERT_THAT(disassembled, SizeIs(6)); + for (int i = 0; i < 6; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({5})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + sharding->Disassemble(Shape({30}), + SingleDeviceShardSemantics::kAllShards)); + ASSERT_THAT(disassembled, SizeIs(6)); + for (int i = 0; i < 6; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({5})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + sharding->Disassemble(Shape({30}), + SingleDeviceShardSemantics::kAddressableShards)); + // The first 4 devices are addressable. + ASSERT_THAT(disassembled, SizeIs(4)); + for (int i = 0; i < 4; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({5})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } } } @@ -606,7 +768,7 @@ TEST_P(ConcreteEvenShardingTest, IndexDomainsFails) { std::vector shard_shapes; std::shared_ptr sharding = ConcreteEvenSharding::Create(device_list, MemoryKind(), Shape({30}), - Shape({15}), /*is_fully_replicated=*/false); + Shape({5}), /*is_fully_replicated=*/false); EXPECT_THAT( sharding->IndexDomains(Shape({30})), @@ -768,14 +930,43 @@ TEST_P(ShardingParamShardingTest, Disassemble) { std::shared_ptr param_sharding, ShardingParamSharding::Create(param, device_list, MemoryKind())); - TF_ASSERT_OK_AND_ASSIGN(auto disassembled, - param_sharding->Disassemble(Shape({6, 6}))); - ASSERT_THAT(disassembled, SizeIs(6)); - for (int i = 0; i < 6; ++i) { - const auto& [shape, sharding] = disassembled[i]; - EXPECT_EQ(shape, Shape({3, 2})); - EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( - device_list->devices()[i], MemoryKind())); + { + TF_ASSERT_OK_AND_ASSIGN(auto disassembled, + param_sharding->Disassemble(Shape({6, 6}))); + ASSERT_THAT(disassembled, SizeIs(6)); + for (int i = 0; i < 6; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({3, 2})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + param_sharding->Disassemble(Shape({6, 6}), + SingleDeviceShardSemantics::kAllShards)); + ASSERT_THAT(disassembled, SizeIs(6)); + for (int i = 0; i < 6; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({3, 2})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + param_sharding->Disassemble( + Shape({6, 6}), SingleDeviceShardSemantics::kAddressableShards)); + // The first 4 devices are addressable. + ASSERT_THAT(disassembled, SizeIs(4)); + for (int i = 0; i < 4; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({3, 2})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } } } @@ -816,15 +1007,42 @@ TEST_P(ShardingParamShardingTest, IndexDomain) { std::shared_ptr param_sharding, ShardingParamSharding::Create(param, device_list, MemoryKind())); - TF_ASSERT_OK_AND_ASSIGN(auto index_domains, - param_sharding->IndexDomains(Shape({6, 6}))); - EXPECT_THAT(index_domains, - ElementsAre(IndexDomain(Index({0, 0}), Shape({3, 2})), - IndexDomain(Index({0, 2}), Shape({3, 2})), - IndexDomain(Index({0, 4}), Shape({3, 2})), - IndexDomain(Index({3, 0}), Shape({3, 2})), - IndexDomain(Index({3, 2}), Shape({3, 2})), - IndexDomain(Index({3, 4}), Shape({3, 2})))); + { + TF_ASSERT_OK_AND_ASSIGN(auto index_domains, + param_sharding->IndexDomains(Shape({6, 6}))); + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(Index({0, 0}), Shape({3, 2})), + IndexDomain(Index({0, 2}), Shape({3, 2})), + IndexDomain(Index({0, 4}), Shape({3, 2})), + IndexDomain(Index({3, 0}), Shape({3, 2})), + IndexDomain(Index({3, 2}), Shape({3, 2})), + IndexDomain(Index({3, 4}), Shape({3, 2})))); + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto index_domains, + param_sharding->IndexDomains(Shape({6, 6}), + SingleDeviceShardSemantics::kAllShards)); + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(Index({0, 0}), Shape({3, 2})), + IndexDomain(Index({0, 2}), Shape({3, 2})), + IndexDomain(Index({0, 4}), Shape({3, 2})), + IndexDomain(Index({3, 0}), Shape({3, 2})), + IndexDomain(Index({3, 2}), Shape({3, 2})), + IndexDomain(Index({3, 4}), Shape({3, 2})))); + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto index_domains, + param_sharding->IndexDomains( + Shape({6, 6}), SingleDeviceShardSemantics::kAddressableShards)); + // The first 4 devices are addressable. + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(Index({0, 0}), Shape({3, 2})), + IndexDomain(Index({0, 2}), Shape({3, 2})), + IndexDomain(Index({0, 4}), Shape({3, 2})), + IndexDomain(Index({3, 0}), Shape({3, 2})))); + } } TEST_P(ShardingParamShardingTest, IndexDomainWithPermutation) { @@ -835,15 +1053,42 @@ TEST_P(ShardingParamShardingTest, IndexDomainWithPermutation) { std::shared_ptr param_sharding, ShardingParamSharding::Create(param, device_list, MemoryKind())); - TF_ASSERT_OK_AND_ASSIGN(auto index_domains, - param_sharding->IndexDomains(Shape({6, 6}))); - EXPECT_THAT(index_domains, - ElementsAre(IndexDomain(Index({0, 0}), Shape({3, 2})), - IndexDomain(Index({0, 4}), Shape({3, 2})), - IndexDomain(Index({3, 2}), Shape({3, 2})), - IndexDomain(Index({0, 2}), Shape({3, 2})), - IndexDomain(Index({3, 0}), Shape({3, 2})), - IndexDomain(Index({3, 4}), Shape({3, 2})))); + { + TF_ASSERT_OK_AND_ASSIGN(auto index_domains, + param_sharding->IndexDomains(Shape({6, 6}))); + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(Index({0, 0}), Shape({3, 2})), + IndexDomain(Index({0, 4}), Shape({3, 2})), + IndexDomain(Index({3, 2}), Shape({3, 2})), + IndexDomain(Index({0, 2}), Shape({3, 2})), + IndexDomain(Index({3, 0}), Shape({3, 2})), + IndexDomain(Index({3, 4}), Shape({3, 2})))); + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto index_domains, + param_sharding->IndexDomains(Shape({6, 6}), + SingleDeviceShardSemantics::kAllShards)); + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(Index({0, 0}), Shape({3, 2})), + IndexDomain(Index({0, 4}), Shape({3, 2})), + IndexDomain(Index({3, 2}), Shape({3, 2})), + IndexDomain(Index({0, 2}), Shape({3, 2})), + IndexDomain(Index({3, 0}), Shape({3, 2})), + IndexDomain(Index({3, 4}), Shape({3, 2})))); + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto index_domains, + param_sharding->IndexDomains( + Shape({6, 6}), SingleDeviceShardSemantics::kAddressableShards)); + // The first 4 devices are addressable. + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(Index({0, 0}), Shape({3, 2})), + IndexDomain(Index({0, 4}), Shape({3, 2})), + IndexDomain(Index({3, 2}), Shape({3, 2})), + IndexDomain(Index({0, 2}), Shape({3, 2})))); + } } TEST_P(ShardingParamShardingTest, IndexDomainWithReplication) { @@ -854,33 +1099,60 @@ TEST_P(ShardingParamShardingTest, IndexDomainWithReplication) { std::shared_ptr param_sharding, ShardingParamSharding::Create(param, device_list, MemoryKind())); - TF_ASSERT_OK_AND_ASSIGN(auto index_domains, - param_sharding->IndexDomains(Shape({6, 6}))); - EXPECT_THAT(index_domains, - ElementsAre(IndexDomain(Index({0, 0}), Shape({3, 6})), - IndexDomain(Index({0, 0}), Shape({3, 6})), - IndexDomain(Index({0, 0}), Shape({3, 6})), - IndexDomain(Index({3, 0}), Shape({3, 6})), - IndexDomain(Index({3, 0}), Shape({3, 6})), - IndexDomain(Index({3, 0}), Shape({3, 6})))); + { + TF_ASSERT_OK_AND_ASSIGN(auto index_domains, + param_sharding->IndexDomains(Shape({6, 6}))); + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(Index({0, 0}), Shape({3, 6})), + IndexDomain(Index({0, 0}), Shape({3, 6})), + IndexDomain(Index({0, 0}), Shape({3, 6})), + IndexDomain(Index({3, 0}), Shape({3, 6})), + IndexDomain(Index({3, 0}), Shape({3, 6})), + IndexDomain(Index({3, 0}), Shape({3, 6})))); + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto index_domains, + param_sharding->IndexDomains(Shape({6, 6}), + SingleDeviceShardSemantics::kAllShards)); + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(Index({0, 0}), Shape({3, 6})), + IndexDomain(Index({0, 0}), Shape({3, 6})), + IndexDomain(Index({0, 0}), Shape({3, 6})), + IndexDomain(Index({3, 0}), Shape({3, 6})), + IndexDomain(Index({3, 0}), Shape({3, 6})), + IndexDomain(Index({3, 0}), Shape({3, 6})))); + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto index_domains, + param_sharding->IndexDomains( + Shape({6, 6}), SingleDeviceShardSemantics::kAddressableShards)); + // The first 4 devices are addressable. + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(Index({0, 0}), Shape({3, 6})), + IndexDomain(Index({0, 0}), Shape({3, 6})), + IndexDomain(Index({0, 0}), Shape({3, 6})), + IndexDomain(Index({3, 0}), Shape({3, 6})))); + } } INSTANTIATE_TEST_SUITE_P(NumDevices, SingleDeviceShardingTest, testing::Values(test_util::DeviceTestParam{ /*num_devices=*/6, - /*num_addressable_devices=*/6})); + /*num_addressable_devices=*/4})); INSTANTIATE_TEST_SUITE_P(NumDevices, OpaqueShardingTest, testing::Values(test_util::DeviceTestParam{ /*num_devices=*/6, - /*num_addressable_devices=*/6})); + /*num_addressable_devices=*/4})); INSTANTIATE_TEST_SUITE_P(NumDevices, ConcreteShardingTest, testing::Values(test_util::DeviceTestParam{ /*num_devices=*/6, - /*num_addressable_devices=*/6})); + /*num_addressable_devices=*/4})); INSTANTIATE_TEST_SUITE_P(NumDevices, ConcreteEvenShardingTest, testing::Values(test_util::DeviceTestParam{ /*num_devices=*/6, - /*num_addressable_devices=*/6})); + /*num_addressable_devices=*/4})); INSTANTIATE_TEST_SUITE_P(NumDevices, ShardingParamShardingTest, testing::Values(test_util::DeviceTestParam{ /*num_devices=*/6, diff --git a/third_party/xla/xla/python/ifrt/support/BUILD b/third_party/xla/xla/python/ifrt/support/BUILD index 1c287ac13ad2ce..42610fcef66195 100644 --- a/third_party/xla/xla/python/ifrt/support/BUILD +++ b/third_party/xla/xla/python/ifrt/support/BUILD @@ -1,14 +1,40 @@ load("//xla:xla.bzl", "xla_cc_test") +load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], licenses = ["notice"], ) +cc_library( + name = "module_parsing", + srcs = ["module_parsing.cc"], + hdrs = ["module_parsing.h"], + compatible_with = get_compatible_with_portable(), + visibility = ["//xla/python/ifrt:friends"], + deps = [ + "//xla/mlir/utils:error_util", + "//xla/mlir_hlo:hlo_dialect_registration", + "//xla/python/ifrt/ir", + "//xla/python/ifrt/ir/transforms:built_in_spmd_expansions", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@shardy//shardy/dialect/sdy/ir:register", + "@stablehlo//:register", + ], +) + cc_library( name = "sharding_conversions", srcs = ["sharding_conversions.cc"], hdrs = ["sharding_conversions.h"], + compatible_with = get_compatible_with_portable(), visibility = ["//xla/python/ifrt:friends"], deps = [ "//xla:xla_data_proto_cc", diff --git a/third_party/xla/xla/python/ifrt/support/module_parsing.cc b/third_party/xla/xla/python/ifrt/support/module_parsing.cc new file mode 100644 index 00000000000000..7a0a8fa7006f3b --- /dev/null +++ b/third_party/xla/xla/python/ifrt/support/module_parsing.cc @@ -0,0 +1,71 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/ifrt/support/module_parsing.h" + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/InitAllDialects.h" +#include "mlir/Parser/Parser.h" +#include "shardy/dialect/sdy/ir/register.h" +#include "stablehlo/dialect/Register.h" +#include "xla/mlir/utils/error_util.h" +#include "xla/mlir_hlo/mhlo/IR/register.h" +#include "xla/python/ifrt/ir/ifrt_dialect.h" +#include "xla/python/ifrt/ir/transforms/built_in_spmd_expansions.h" + +namespace xla { +namespace ifrt { +namespace support { + +void InitializeMlirDialectRegistry(mlir::DialectRegistry& registry) { + registry.insert(); + mlir::registerAllDialects(registry); + mlir::func::registerAllExtensions(registry); + mlir::mhlo::registerAllMhloDialects(registry); + mlir::sdy::registerAllDialects(registry); + mlir::stablehlo::registerAllDialects(registry); + xla::ifrt::AttachBuiltInSpmdExpansions(registry); +} + +void RegisterMlirDialects(mlir::MLIRContext& context) { + mlir::DialectRegistry registry; + InitializeMlirDialectRegistry(registry); + context.appendDialectRegistry(registry); +} + +absl::StatusOr> ParseMlirModuleString( + absl::string_view mlir_module_str, mlir::MLIRContext& context) { + RegisterMlirDialects(context); + mlir::BaseScopedDiagnosticHandler diagnostic_handler(&context); + mlir::OwningOpRef module = + mlir::parseSourceString(mlir_module_str, &context); + if (!module) { + return absl::InvalidArgumentError( + absl::StrFormat("Failed to parse IFRT IR module string: %s", + diagnostic_handler.ConsumeStatus().message())); + } + return module; +} + +} // namespace support +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt/support/module_parsing.h b/third_party/xla/xla/python/ifrt/support/module_parsing.h new file mode 100644 index 00000000000000..f9394cca98bc17 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/support/module_parsing.h @@ -0,0 +1,44 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_IFRT_SUPPORT_MODULE_PARSING_H_ +#define XLA_PYTHON_IFRT_SUPPORT_MODULE_PARSING_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" + +namespace xla { +namespace ifrt { +namespace support { + +// Initializes the given MLIR dialect registry with dialects that are required +// by IFRT IR passes. +void InitializeMlirDialectRegistry(mlir::DialectRegistry& registry); + +// Registers all dialects required by IFRT IR modules. +void RegisterMlirDialects(mlir::MLIRContext& context); + +// Converts an IFRT IR module string to an mlir::Module. +absl::StatusOr> ParseMlirModuleString( + absl::string_view mlir_module_str, mlir::MLIRContext& context); + +} // namespace support +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_SUPPORT_MODULE_PARSING_H_ diff --git a/third_party/xla/xla/python/ifrt/support/sharding_conversions_test.cc b/third_party/xla/xla/python/ifrt/support/sharding_conversions_test.cc index 0942c5658d2e1d..d52a8b44f9dfcd 100644 --- a/third_party/xla/xla/python/ifrt/support/sharding_conversions_test.cc +++ b/third_party/xla/xla/python/ifrt/support/sharding_conversions_test.cc @@ -77,6 +77,7 @@ std::shared_ptr MakeTestClient(int num_devices) { for (int i = 0; i < num_devices; ++i) { auto device = std::make_unique(); ON_CALL(*device, Id).WillByDefault(Return(DeviceId(i))); + ON_CALL(*device, IsAddressable).WillByDefault(Return(true)); state->devices.push_back(device.get()); state->device_map.insert({DeviceId(i), std::move(device)}); } diff --git a/third_party/xla/xla/python/ifrt_proxy/client/BUILD b/third_party/xla/xla/python/ifrt_proxy/client/BUILD index 8e947a4e68beac..04c008c035e3b7 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/client/BUILD @@ -64,6 +64,7 @@ ifrt_proxy_cc_test( "//xla/python/ifrt_proxy/common:grpc_ifrt_service_cc_grpc_proto", "//xla/python/ifrt_proxy/common:grpc_ifrt_service_proto_cc", "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:test_utils", "@com_github_grpc_grpc//:gpr", "@com_github_grpc_grpc//:grpc++", "@com_google_absl//absl/base:core_headers", @@ -96,21 +97,50 @@ cc_library( ":host_buffer", "//xla/python/ifrt", "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:test_utils", + "//xla/python/ifrt_proxy/common:types", + "//xla/tsl/profiler/utils:xplane_schema", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:random", "@local_tsl//tsl/platform:status_to_from_proto", "@local_tsl//tsl/profiler/lib:traceme", "@local_tsl//tsl/profiler/lib:traceme_encode", - "@local_tsl//tsl/profiler/utils:xplane_schema", ] + if_google(["@com_google_absl//absl/types:source_location"]), ) +ifrt_proxy_cc_test( + name = "rpc_helper_test", + srcs = ["rpc_helper_test.cc"], + deps = [ + ":client_session", + ":mock_client_session", + ":rpc_helper", + ":version", + "//xla/python/ifrt", + "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", + "//xla/python/ifrt_proxy/common:test_utils", + "//xla/python/ifrt_proxy/common:types", + "//xla/python/ifrt_proxy/common:types_proto_cc", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:test", + ], +) + cc_library( name = "client", srcs = ["client.cc"], @@ -239,6 +269,8 @@ ifrt_proxy_cc_test( "//xla/python/ifrt_proxy/common:types_proto_cc", "//xla/tsl/concurrency:ref_count", "@com_google_absl//absl/status", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:protobuf", @@ -401,6 +433,7 @@ cc_library( "//xla/python/ifrt", "//xla/python/ifrt_proxy/common:grpc_ifrt_service_cc_grpc_proto", "//xla/python/ifrt_proxy/common:grpc_ifrt_service_proto_cc", + "//xla/tsl/protobuf:status_proto_cc", "@com_github_grpc_grpc//:grpc++", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -409,7 +442,6 @@ cc_library( "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:unbounded_work_queue", - "@local_tsl//tsl/protobuf:status_proto_cc", ], ) diff --git a/third_party/xla/xla/python/ifrt_proxy/client/array.cc b/third_party/xla/xla/python/ifrt_proxy/client/array.cc index b34d45299f1be4..c1f9b10dfd98fb 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/array.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/array.cc @@ -100,8 +100,13 @@ Array::MakeArrayFromHostBuffer( } void Array::Destruct(RpcHelper* rpc_helper, ArrayHandle handle) { + if (rpc_helper->version().protocol_version() >= 5) { + rpc_helper->Batch(RpcHelper::kDestructArray, handle); + return; + } + auto req = std::make_unique(); - req->set_array_handle(handle.handle); + req->set_array_handle_deprecated(handle.handle); rpc_helper->DestructArray(std::move(req)) .OnReady( [](absl::StatusOr> response) { @@ -126,8 +131,13 @@ Future<> Array::GetReadyFuture() const { } Future<> Array::Delete() { + if (rpc_helper_->version().protocol_version() >= 5) { + rpc_helper_->Batch(RpcHelper::kDeleteArray, handle_); + return Future<>(absl::OkStatus()); + } + auto req = std::make_unique(); - req->set_array_handle(handle_.handle); + req->set_array_handle_deprecated(handle_.handle); absl::StatusOr> response = rpc_helper_->DeleteArray(std::move(req)).Await(); @@ -165,12 +175,22 @@ Array::AssembleArrayFromSingleDeviceArrays( xla::ifrt::Client* client, std::shared_ptr rpc_helper, Shape shape, std::shared_ptr sharding, absl::Span> arrays, - ArrayCopySemantics semantics) { + ArrayCopySemantics array_copy_semantics, + SingleDeviceShardSemantics single_device_shard_semantics) { + if (single_device_shard_semantics == + SingleDeviceShardSemantics::kAddressableShards && + rpc_helper->version().protocol_version() < 8) { + return absl::UnimplementedError( + "SingleDeviceShardSemantics::kAdressableShards is not supported in " + "ifrt-proxy version < 8"); + } auto req = std::make_unique(); TF_RET_CHECK(!arrays.empty()); *req->mutable_shape() = shape.ToProto(); TF_ASSIGN_OR_RETURN(*req->mutable_sharding(), sharding->ToProto()); - req->set_copy_semantics(ToArrayCopySemanticsProto(semantics)); + req->set_copy_semantics(ToArrayCopySemanticsProto(array_copy_semantics)); + req->set_single_device_shard_semantics( + ToSingleDeviceShardSemanticsProto(single_device_shard_semantics)); for (const tsl::RCReference& rcref : arrays) { Array* array = llvm::dyn_cast(rcref.get()); if (array == nullptr) { @@ -234,9 +254,26 @@ Array::RemapArrays(xla::ifrt::Client* client, absl::StatusOr>> Array::DisassembleIntoSingleDeviceArrays(ArrayCopySemantics semantics) { + return DisassembleIntoSingleDeviceArrays( + semantics, SingleDeviceShardSemantics::kAllShards); +} + +absl::StatusOr>> +Array::DisassembleIntoSingleDeviceArrays( + ArrayCopySemantics array_copy_semantics, + SingleDeviceShardSemantics single_device_shard_semantics) { + if (single_device_shard_semantics == + SingleDeviceShardSemantics::kAddressableShards && + rpc_helper_->version().protocol_version() < 8) { + return absl::UnimplementedError( + "SingleDeviceShardSemantics::kAdressableShards is not supported in " + "version < 8"); + } auto req = std::make_unique(); req->set_array_handle(handle_.handle); - req->set_copy_semantics(ToArrayCopySemanticsProto(semantics)); + req->set_copy_semantics(ToArrayCopySemanticsProto(array_copy_semantics)); + req->set_single_device_shard_semantics( + ToSingleDeviceShardSemanticsProto(single_device_shard_semantics)); TF_ASSIGN_OR_RETURN( std::shared_ptr response, diff --git a/third_party/xla/xla/python/ifrt_proxy/client/array.h b/third_party/xla/xla/python/ifrt_proxy/client/array.h index 123bb32695a2a9..6a083dcfb017a0 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/array.h +++ b/third_party/xla/xla/python/ifrt_proxy/client/array.h @@ -71,7 +71,8 @@ class Array final : public llvm::RTTIExtends { xla::ifrt::Client* client, std::shared_ptr rpc_helper, Shape shape, std::shared_ptr sharding, absl::Span> arrays, - ArrayCopySemantics semantics); + ArrayCopySemantics array_copy_semantics, + SingleDeviceShardSemantics single_device_shard_semantics); // `Array::RemapArrays()` implements `Client::RemapArrays()`. // TODO(b/261226026): Implement logic directly in client.cc. @@ -118,6 +119,10 @@ class Array final : public llvm::RTTIExtends { absl::StatusOr>> DisassembleIntoSingleDeviceArrays(ArrayCopySemantics semantics) override; + absl::StatusOr>> + DisassembleIntoSingleDeviceArrays( + ArrayCopySemantics array_copy_semantics, + SingleDeviceShardSemantics single_device_shard_semantics) override; absl::StatusOr> FullyReplicatedShard( xla::ifrt::ArrayCopySemantics semantics) override; diff --git a/third_party/xla/xla/python/ifrt_proxy/client/array_test.cc b/third_party/xla/xla/python/ifrt_proxy/client/array_test.cc index f069dd959f662e..140c74b0311138 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/array_test.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/array_test.cc @@ -20,6 +20,8 @@ #include #include #include "absl/status/status.h" +#include "absl/synchronization/notification.h" +#include "absl/time/time.h" #include "absl/types/span.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/dtype.h" @@ -60,7 +62,7 @@ namespace { IfrtProxyVersion Version() { IfrtProxyVersion version; - version.set_protocol_version(kClientMinVersion); + version.set_protocol_version(kClientMaxVersion); return version; } @@ -88,17 +90,26 @@ class ArrayTest : public ::testing::Test { // TODO(b/315809436): Test needs rewrite because protobuf matchers are not OSS #if defined(PLATFORM_GOOGLE) TEST_F(ArrayTest, Destruction) { - IfrtResponse response; + // Destruction may not happen immediately because of batching at the + // client-side. This test waits until destruction happens. + absl::Notification destructed; EXPECT_CALL( *session_, Enqueue(Pointee(Partially(EquivToProto(R"pb(destruct_array_request { array_handle: 1234 })pb"))))) - .WillOnce(MockClientSessionReturnResponse(response)); + .WillOnce([&](std::unique_ptr request) + -> Future { + destructed.Notify(); + auto result = std::make_shared(); + return Future(result); + }); MockClient client; tsl::MakeRef(&client, rpc_helper_, DType(DType::Kind::kBF16), Shape({}), /*sharding=*/nullptr, ArrayHandle{1234}); + + ASSERT_TRUE(destructed.WaitForNotificationWithTimeout(absl::Seconds(10))); } #endif diff --git a/third_party/xla/xla/python/ifrt_proxy/client/client.cc b/third_party/xla/xla/python/ifrt_proxy/client/client.cc index db3a6026ac885d..0abd6176dcd8ad 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/client.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/client.cc @@ -67,6 +67,20 @@ absl::StatusOr> Client::Create( absl::flat_hash_set addressable_device_ids( init_response.addressable_device_ids().begin(), init_response.addressable_device_ids().end()); + absl::flat_hash_set primary_device_ids; + if (rpc_helper->version().protocol_version() < 7) { + // Legacy implementation for servers do not support Client::GetAllDevices() + // and thus do not provide device_ids(). Assume that it contains all device + // ids from devices(). + primary_device_ids.reserve(init_response.all_devices().size()); + for (const auto& d : init_response.all_devices()) { + primary_device_ids.insert(d.id()); + } + } else { + primary_device_ids.reserve(init_response.primary_device_ids().size()); + primary_device_ids.insert(init_response.primary_device_ids().begin(), + init_response.primary_device_ids().end()); + } absl::flat_hash_map> memories; for (const auto& m : init_response.memories()) { @@ -77,10 +91,11 @@ absl::StatusOr> Client::Create( } absl::flat_hash_map> devices; - std::vector device_ptrs; + std::vector primary_device_ptrs; std::vector addressable_device_ptrs; + std::vector all_device_ptrs; - for (const auto& d : init_response.devices()) { + for (const auto& d : init_response.all_devices()) { absl::flat_hash_map pjrt_device_attributes; if (rpc_helper->version().protocol_version() <= 3) { @@ -99,13 +114,17 @@ absl::StatusOr> Client::Create( d.device_kind(), d.debug_string(), d.to_string(), std::move(pjrt_device_attributes)); bool is_addressable = addressable_device_ids.contains(d.id()); + bool is_primary = primary_device_ids.contains(d.id()); auto device = std::make_unique(std::move(desc), d.local_device_id(), d.local_hardware_id(), is_addressable); - device_ptrs.push_back(device.get()); - if (is_addressable) { - addressable_device_ptrs.push_back(device.get()); + all_device_ptrs.push_back(device.get()); + if (is_primary) { + primary_device_ptrs.push_back(device.get()); + if (is_addressable) { + addressable_device_ptrs.push_back(device.get()); + } } if (d.has_default_memory_id()) { @@ -150,9 +169,10 @@ absl::StatusOr> Client::Create( std::move(rpc_helper), init_response.session_id(), init_response.platform_name(), init_response.platform_version(), init_response.platform_id(), init_response.process_index(), runtime_type, - std::move(devices), device_ptrs, std::move(addressable_device_ptrs), + std::move(devices), std::move(primary_device_ptrs), + std::move(addressable_device_ptrs), all_device_ptrs, std::move(memories))); - for (ifrt::Device* device : device_ptrs) { + for (ifrt::Device* device : all_device_ptrs) { tensorflow::down_cast(device)->client_ = client.get(); } return client; @@ -163,8 +183,9 @@ Client::Client(std::shared_ptr rpc_helper, uint64_t session_id, uint64_t platform_id, uint64_t process_index, std::string runtime_type, absl::flat_hash_map> devices, - std::vector device_ptrs, + std::vector primary_device_ptrs, std::vector addressable_device_ptrs, + std::vector all_device_ptrs, absl::flat_hash_map> memories) : rpc_helper_(rpc_helper), platform_name_(std::move(platform_name)), @@ -175,8 +196,9 @@ Client::Client(std::shared_ptr rpc_helper, uint64_t session_id, // TODO(b/309059940): Forward the backend attributes to the client. attributes_(AttributeMap::Map()), devices_(std::move(devices)), - device_ptrs_(device_ptrs), + primary_device_ptrs_(primary_device_ptrs), addressable_device_ptrs_(std::move(addressable_device_ptrs)), + all_device_ptrs_(all_device_ptrs), memories_(std::move(memories)), default_compiler_(this, rpc_helper) {} @@ -210,7 +232,19 @@ Client::AssembleArrayFromSingleDeviceArrays( absl::Span> arrays, ArrayCopySemantics semantics) { return Array::AssembleArrayFromSingleDeviceArrays( - this, rpc_helper_, std::move(shape), sharding, arrays, semantics); + this, rpc_helper_, std::move(shape), sharding, arrays, semantics, + SingleDeviceShardSemantics::kAllShards); +} + +absl::StatusOr> +Client::AssembleArrayFromSingleDeviceArrays( + Shape shape, std::shared_ptr sharding, + absl::Span> arrays, + ArrayCopySemantics array_copy_semantics, + SingleDeviceShardSemantics single_device_shard_semantics) { + return Array::AssembleArrayFromSingleDeviceArrays( + this, rpc_helper_, std::move(shape), sharding, arrays, + array_copy_semantics, single_device_shard_semantics); } absl::StatusOr>> @@ -302,6 +336,10 @@ xla::ifrt::Future<> Client::GetReadyFuture( return JoinFutures(futures); } +absl::Span Client::GetAllDevices() const { + return all_device_ptrs_; +} + absl::StatusOr Client::GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const { auto req = std::make_unique(); diff --git a/third_party/xla/xla/python/ifrt_proxy/client/client.h b/third_party/xla/xla/python/ifrt_proxy/client/client.h index fd700441eb63ec..3732b5ddd832d7 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/client.h +++ b/third_party/xla/xla/python/ifrt_proxy/client/client.h @@ -77,6 +77,12 @@ class Client final : public llvm::RTTIExtends { Shape shape, std::shared_ptr sharding, absl::Span> arrays, ArrayCopySemantics semantics) override; + absl::StatusOr> + AssembleArrayFromSingleDeviceArrays( + Shape shape, std::shared_ptr sharding, + absl::Span> arrays, + ArrayCopySemantics array_copy_semantics, + SingleDeviceShardSemantics single_device_shard_semantics) override; absl::StatusOr>> CopyArrays( absl::Span> arrays, @@ -110,12 +116,13 @@ class Client final : public llvm::RTTIExtends { return addressable_devices().size(); } absl::Span devices() const override { - return device_ptrs_; + return primary_device_ptrs_; } absl::Span addressable_devices() const override { return addressable_device_ptrs_; } int process_index() const override { return process_index_; } + absl::Span GetAllDevices() const override; absl::StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const override; absl::StatusOr LookupDevice( @@ -148,8 +155,9 @@ class Client final : public llvm::RTTIExtends { std::string platform_name, std::string platform_version, uint64_t platform_id, uint64_t process_index, std::string runtime_type, absl::flat_hash_map> devices, - std::vector device_ptrs, + std::vector primary_device_ptrs, std::vector addressable_device_ptrs, + std::vector all_device_ptrs, absl::flat_hash_map> memories); // rpc_helper_ will be referenced by various IFRT objects whose lifetime is @@ -166,8 +174,9 @@ class Client final : public llvm::RTTIExtends { const AttributeMap attributes_; const absl::flat_hash_map> devices_; - const std::vector device_ptrs_; + const std::vector primary_device_ptrs_; const std::vector addressable_device_ptrs_; + const std::vector all_device_ptrs_; const absl::flat_hash_map> memories_; diff --git a/third_party/xla/xla/python/ifrt_proxy/client/client_test.cc b/third_party/xla/xla/python/ifrt_proxy/client/client_test.cc index 3f1dbb45c7dea6..03dd43f3c93cfe 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/client_test.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/client_test.cc @@ -83,7 +83,7 @@ class ClientTest : public ::testing::TestWithParam { platform_id: 42 process_index: 1 runtime_type: "ifrt-service" - devices { + all_devices { id: 0 local_hardware_id: 1234 device_kind: "mock" @@ -94,7 +94,7 @@ class ClientTest : public ::testing::TestWithParam { value { string_value: "device0" } } } - devices { + all_devices { id: 1 local_hardware_id: 1234 device_kind: "mock" @@ -120,6 +120,55 @@ class ClientTest : public ::testing::TestWithParam { } )pb", &response)); + } else if (Version().protocol_version() < 7) { + ASSERT_TRUE(tsl::protobuf::TextFormat::ParseFromString( + R"pb( + platform_name: "ifrt-service" + platform_version: "n/a" + platform_id: 42 + process_index: 1 + runtime_type: "ifrt-service" + all_devices { + id: 0 + local_hardware_id: 1234 + device_kind: "mock" + default_memory_id: 0 + memory_ids: [ 0 ] + attributes { + attributes { + key: "name" + value { string_value: "device0" } + } + } + } + all_devices { + id: 1 + local_hardware_id: 1234 + device_kind: "mock" + default_memory_id: 1 + memory_ids: [ 1 ] + attributes { + attributes { + key: "name" + value { string_value: "device1" } + } + } + } + addressable_device_ids: 1 + memories { + id: 0 + memory_space_kind: "mock" + kind_id: 0 + device_ids: [ 0 ] + } + memories { + id: 1 + memory_space_kind: "mock" + kind_id: 1 + device_ids: [ 1 ] + } + )pb", + &response)); } else { ASSERT_TRUE(tsl::protobuf::TextFormat::ParseFromString( R"pb( @@ -128,7 +177,7 @@ class ClientTest : public ::testing::TestWithParam { platform_id: 42 process_index: 1 runtime_type: "ifrt-service" - devices { + all_devices { id: 0 local_hardware_id: 1234 device_kind: "mock" @@ -141,7 +190,7 @@ class ClientTest : public ::testing::TestWithParam { } } } - devices { + all_devices { id: 1 local_hardware_id: 1234 device_kind: "mock" @@ -154,6 +203,7 @@ class ClientTest : public ::testing::TestWithParam { } } } + primary_device_ids: [ 0, 1 ] addressable_device_ids: 1 memories { id: 0 diff --git a/third_party/xla/xla/python/ifrt_proxy/client/executable.cc b/third_party/xla/xla/python/ifrt_proxy/client/executable.cc index 37c7d4795e0509..f284a5bbd75af1 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/executable.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/executable.cc @@ -355,12 +355,12 @@ std::optional> LoadedExecutable::GetOutputShardings() return (*info)->output_shardings; } -absl::StatusOr>> +absl::StatusOr>> LoadedExecutable::GetParameterLayouts() const { TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await()); TF_RETURN_IF_ERROR(info->parameter_layouts.status()); - std::vector> result; + std::vector> result; result.reserve(info->parameter_layouts->size()); for (const xla::Layout& layout : *info->parameter_layouts) { result.push_back(std::make_unique(layout)); @@ -368,12 +368,12 @@ LoadedExecutable::GetParameterLayouts() const { return result; } -absl::StatusOr>> +absl::StatusOr>> LoadedExecutable::GetOutputLayouts() const { TF_ASSIGN_OR_RETURN(auto info, metadata_future_.Await()); TF_RETURN_IF_ERROR(info->output_layouts.status()); - std::vector> result; + std::vector> result; result.reserve(info->output_layouts->size()); for (const xla::Layout& layout : *info->output_layouts) { result.push_back(std::make_unique(layout)); @@ -432,7 +432,12 @@ LoadedExecutable::Execute( // Populate the execution status future. `CheckFuture` deletes the server-side // futures after its completion. - result.status = rpc_helper_->CheckFuture(response->status_handle()); + // + // Starting version 6, the server populates the status future only if it was + // explicitly requested via `options.fill_status`. + if (rpc_helper_->version().protocol_version() < 6 || options.fill_status) { + result.status = rpc_helper_->CheckFuture(response->status_handle()); + } // Create output arrays. The cleanup logic ensures that all handles are // properly cleaned up on early return. diff --git a/third_party/xla/xla/python/ifrt_proxy/client/executable.h b/third_party/xla/xla/python/ifrt_proxy/client/executable.h index 2df7d17a8ffaae..9afc875cf9bf11 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/executable.h +++ b/third_party/xla/xla/python/ifrt_proxy/client/executable.h @@ -75,10 +75,10 @@ class LoadedExecutable final std::optional> GetParameterShardings() const override; std::optional> GetOutputShardings() const override; - absl::StatusOr>> GetParameterLayouts() - const override; - absl::StatusOr>> GetOutputLayouts() - const override; + absl::StatusOr>> + GetParameterLayouts() const override; + absl::StatusOr>> + GetOutputLayouts() const override; absl::StatusOr>> GetOutputMemoryKinds() const override; absl::StatusOr>> GetHloModules() diff --git a/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session.cc b/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session.cc index d673a5f7561829..b555bc62d29398 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session.cc @@ -40,7 +40,6 @@ #include "grpcpp/support/channel_arguments.h" #include "xla/pjrt/distributed/util.h" #include "xla/python/ifrt/future.h" -#include "xla/python/ifrt_proxy/client/client_session.h" #include "xla/python/ifrt_proxy/common/grpc_credentials.h" #include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" #include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" @@ -54,8 +53,6 @@ namespace xla { namespace ifrt { namespace proxy { -using OpId = int64_t; - // Logically equivalent to a map, but thread-safe and // with various convenience functions. class GrpcClientSession::ResponseCallbackTable { @@ -146,9 +143,9 @@ Future> GrpcClientSession::Enqueue( absl::Status GrpcClientSession::Enqueue(std::unique_ptr req, ResponseCallback callback) { - const OpId op_id = req->request_metadata().op_id(); - absl::MutexLock l(&writer_mu_); + const OpId op_id = writer_next_op_id_++; + if (writes_stopped_) { return absl::FailedPreconditionError( "GrpcClientSession: writes no longer allowed."); @@ -156,6 +153,9 @@ absl::Status GrpcClientSession::Enqueue(std::unique_ptr req, TF_RETURN_IF_ERROR(response_callbacks_->Add(op_id, std::move(callback))); + CHECK_EQ(req->mutable_request_metadata()->op_id(), 0); + req->mutable_request_metadata()->set_op_id(op_id); + if (!stream_->Write(*req)) { CHECK(response_callbacks_->Pop(op_id).has_value()); return absl::UnknownError("GrpcClientSession: writing to stream failed."); diff --git a/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session.h b/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session.h index 9e80b9ad850858..3187098bb6dd0a 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session.h +++ b/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session.h @@ -17,6 +17,7 @@ #ifndef XLA_PYTHON_IFRT_PROXY_CLIENT_GRPC_CLIENT_SESSION_H_ #define XLA_PYTHON_IFRT_PROXY_CLIENT_GRPC_CLIENT_SESSION_H_ +#include #include #include @@ -116,6 +117,9 @@ class GrpcClientSession : public ClientSession { // only one thread is allowed to write to the gRPC stream at a time. absl::Mutex writer_mu_; + using OpId = uint64_t; + OpId writer_next_op_id_ ABSL_GUARDED_BY(writer_mu_) = 1; + // Ensures logic inside `Finish()` is internally called only once. absl::once_flag finish_once_; diff --git a/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session_test.cc b/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session_test.cc index 882f7b271841d0..61039d9e93b116 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session_test.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/grpc_client_session_test.cc @@ -15,7 +15,6 @@ #include "xla/python/ifrt_proxy/client/grpc_client_session.h" #include -#include #include #include #include @@ -49,6 +48,7 @@ #include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" #include "xla/python/ifrt_proxy/common/grpc_ifrt_service.pb.h" #include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/test_utils.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status_matchers.h" @@ -64,9 +64,6 @@ namespace { using ::testing::Not; using ::tsl::testing::IsOk; -constexpr int kOp1 = 1; -constexpr int kOp2 = 2; - // Sufficient time for all processing (that are not explicitly waiting for // further input) to have finished. constexpr absl::Duration kSufficientTime = absl::Seconds(5); @@ -79,49 +76,8 @@ GrpcIfrtSessionMetadata Metadata() { absl::Status TestError() { return absl::UnknownError("test error"); } -// A thread-safe queue of `absl::Status` values. -class Queue { - public: - void Push(absl::Status t) { - absl::MutexLock l(&mu_); - queue_.push_back(std::move(t)); - } - - std::optional PopOrTimeout( - absl::Duration timeout = kSufficientTime) { - absl::MutexLock l(&mu_); - auto cond = [this]() ABSL_SHARED_LOCKS_REQUIRED(mu_) -> bool { - return !queue_.empty(); - }; - mu_.AwaitWithTimeout(absl::Condition(&cond), timeout); - if (queue_.empty()) { - return std::nullopt; - } - absl::Status result = std::move(queue_.front()); - queue_.pop_front(); - return result; - } - - absl::Status Pop(absl::Duration timeout = kSufficientTime) { - auto result = PopOrTimeout(timeout); - CHECK(result.has_value()) << "Timeout!"; - return *result; - } - - void PopAllDuringDestruction() { - absl::MutexLock l(&mu_); - allow_non_empty_destruction_ = true; - } - - ~Queue() { - absl::MutexLock l(&mu_); - if (!allow_non_empty_destruction_) CHECK(queue_.empty()) << " " << this; - } - - private: - absl::Mutex mu_; - std::deque queue_ ABSL_GUARDED_BY(mu_); - bool allow_non_empty_destruction_ ABSL_GUARDED_BY(mu_) = false; +struct Queue : public TestQueue { + Queue() : TestQueue(kSufficientTime) {} }; // Checks that the input is a list of zero-or-more OK statuses followed by @@ -252,7 +208,7 @@ class ClientAndServer { client_finished_notification_.Notify(); }); - client_finished_q_.PopAllDuringDestruction(); + client_finished_q_.AllowNonEmptyDestruction(/*allow=*/true); } void StopServer() { @@ -273,12 +229,11 @@ class ClientAndServer { Queue* client_finished_q() { return &client_finished_q_; } - absl::StatusOr SendSimpleRequest(int op_id) { + absl::StatusOr SendSimpleRequest() { owned_queues_.push_back(std::make_unique()); Queue* q = owned_queues_.back().get(); auto req = std::make_unique(); - req->mutable_request_metadata()->set_op_id(op_id); TF_RETURN_IF_ERROR(client_session_->Enqueue( std::move(req), [q](absl::StatusOr resp) { q->Push(resp.status()); @@ -300,7 +255,7 @@ class ClientAndServer { TEST(GrpcClientSessionTest, HappyCaseOneRequestWithServerTermination) { ClientAndServer cs; - TF_ASSERT_OK_AND_ASSIGN(Queue * response_q, cs.SendSimpleRequest(kOp1)); + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q, cs.SendSimpleRequest()); EXPECT_THAT(response_q->Pop(), IsOk()); @@ -313,8 +268,8 @@ TEST(GrpcClientSessionTest, HappyCaseOneRequestWithServerTermination) { TEST(GrpcClientSessionTest, HappyCaseTwoRequestsWithClientFinish) { ClientAndServer cs; - TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest(kOp1)); - TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_2, cs.SendSimpleRequest(kOp2)); + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest()); + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_2, cs.SendSimpleRequest()); EXPECT_THAT(response_q_1->Pop(), IsOk()); EXPECT_THAT(response_q_2->Pop(), IsOk()); @@ -329,10 +284,10 @@ TEST(GrpcClientSessionTest, ServerFinishesDuringFirstRead) { ClientAndServer cs( /*on_req_received=*/[](auto, auto) { return kStopSession; }); - TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest(kOp1)); + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest()); EXPECT_THAT(response_q_1->Pop(), Not(IsOk())); - absl::StatusOr response_q_2 = cs.SendSimpleRequest(kOp2); + absl::StatusOr response_q_2 = cs.SendSimpleRequest(); EXPECT_THAT(response_q_2.status(), Not(IsOk())); EXPECT_THAT(cs.client_finished_q()->Pop(), Not(IsOk())); @@ -342,8 +297,8 @@ TEST(GrpcClientSessionTest, ServerFinishesDuringConstruction) { ClientAndServer cs(/*on_req_received=*/nullptr, /*on_session_start=*/[]() { return kStopSession; }); - absl::StatusOr response_q_1 = cs.SendSimpleRequest(kOp1); - absl::StatusOr response_q_2 = cs.SendSimpleRequest(kOp2); + absl::StatusOr response_q_1 = cs.SendSimpleRequest(); + absl::StatusOr response_q_2 = cs.SendSimpleRequest(); ExpectHeadAndTail({response_q_1, response_q_2}); if (response_q_1.ok()) EXPECT_THAT(response_q_1.value()->Pop(), Not(IsOk())); @@ -361,10 +316,10 @@ TEST(GrpcClientSessionTest, ClientFinishesAfterServerConsumesFirstRequest) { }); session_ptr.store(cs.client_session()); - TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest(kOp1)); + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest()); EXPECT_THAT(response_q_1->Pop(), Not(IsOk())); - absl::StatusOr response_q_2 = cs.SendSimpleRequest(kOp2); + absl::StatusOr response_q_2 = cs.SendSimpleRequest(); EXPECT_THAT(response_q_2.status(), Not(IsOk())); EXPECT_THAT(cs.client_finished_q()->Pop(), Not(IsOk())); @@ -384,8 +339,8 @@ TEST(GrpcClientSessionTest, ClientFinishesAfterServerWritesFirstResponse) { }); session_ptr.store(cs.client_session()); - TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest(kOp1)); - absl::StatusOr response_q_2 = cs.SendSimpleRequest(kOp2); + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest()); + absl::StatusOr response_q_2 = cs.SendSimpleRequest(); // The client may or may not terminate before the first response arrives. response_q_1->Pop().IgnoreError(); @@ -413,8 +368,8 @@ TEST(GrpcClientSessionTest, ClientFinishesDuringServerConstruction) { session_ptr.store(cs.client_session()); init_done.Notify(); - absl::StatusOr response_q_1 = cs.SendSimpleRequest(kOp1); - absl::StatusOr response_q_2 = cs.SendSimpleRequest(kOp2); + absl::StatusOr response_q_1 = cs.SendSimpleRequest(); + absl::StatusOr response_q_2 = cs.SendSimpleRequest(); if (response_q_1.ok()) { EXPECT_THAT(response_q_1.value()->Pop(), Not(IsOk())); @@ -431,19 +386,19 @@ TEST(GrpcClientSessionTest, ClientFinishesDuringServerConstruction) { TEST(GrpcClientSessionTest, MethodsAfterFinishReturnError) { ClientAndServer cs; - TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest(kOp1)); + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q_1, cs.SendSimpleRequest()); cs.client_session()->Finish(TestError()); - EXPECT_THAT(cs.SendSimpleRequest(kOp2), Not(IsOk())); + EXPECT_THAT(cs.SendSimpleRequest(), Not(IsOk())); - response_q_1->PopAllDuringDestruction(); + response_q_1->AllowNonEmptyDestruction(/*allow=*/true); } TEST(GrpcClientSessionTest, ReceivingBadIfrtResponseDoesNotCrash) { ClientAndServer cs( /*on_req_received=*/[](const IfrtRequest& r, ServerStream* s) mutable { IfrtResponse resp; - resp.mutable_response_metadata()->set_op_id(kOp2); + resp.mutable_response_metadata()->set_op_id(2000); s->Write(resp); resp.mutable_response_metadata()->set_op_id( r.request_metadata().op_id()); @@ -451,7 +406,7 @@ TEST(GrpcClientSessionTest, ReceivingBadIfrtResponseDoesNotCrash) { return kContinueSession; }); - TF_ASSERT_OK_AND_ASSIGN(Queue * response_q, cs.SendSimpleRequest(kOp1)); + TF_ASSERT_OK_AND_ASSIGN(Queue * response_q, cs.SendSimpleRequest()); EXPECT_THAT(response_q->Pop(), IsOk()); } diff --git a/third_party/xla/xla/python/ifrt_proxy/client/grpc_host_buffer.cc b/third_party/xla/xla/python/ifrt_proxy/client/grpc_host_buffer.cc index 442381ef0abd50..b80f84b0593779 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/grpc_host_buffer.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/grpc_host_buffer.cc @@ -33,9 +33,9 @@ #include "xla/python/ifrt/future.h" #include "xla/python/ifrt_proxy/common/grpc_ifrt_service.grpc.pb.h" #include "xla/python/ifrt_proxy/common/grpc_ifrt_service.pb.h" +#include "xla/tsl/protobuf/status.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/unbounded_work_queue.h" -#include "tsl/protobuf/status.pb.h" namespace xla { namespace ifrt { diff --git a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc index ff116a759e31a6..19998ffd34619f 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc @@ -14,10 +14,16 @@ #include "xla/python/ifrt_proxy/client/rpc_helper.h" +#include #include #include +#include +#include #include +#include +#include "absl/base/thread_annotations.h" +#include "absl/functional/bind_front.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -25,51 +31,281 @@ #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" #include "xla/python/ifrt/future.h" #include "xla/python/ifrt_proxy/client/client_session.h" #include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/test_utils.h" +#include "xla/python/ifrt_proxy/common/types.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "tsl/platform/env.h" #include "tsl/platform/random.h" #include "tsl/platform/status_to_from_proto.h" +#include "tsl/platform/threadpool.h" #include "tsl/profiler/lib/traceme.h" #include "tsl/profiler/lib/traceme_encode.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace xla { namespace ifrt { namespace proxy { +namespace { + using ::tsl::profiler::XFlow; +constexpr absl::Duration kPeriodicFlushInterval = absl::Microseconds(50); + +// XFlowHelper makes it easier to create trace spans with a flow between them. +// Typical usage: +// +// XFlowHelper flow("my_request"); +// ... +// +// auto response_handler = [flow](ResponseMsg msg) { +// flow.InstantActivity(); +// LOG(INFO) << "Received response: " << msg; +// } +// +// { +// auto request_span = flow.Span(); +// auto request_protobuf = CreateRequestProtobuf(); +// transport.Send(request_protobuf, response_handler); +// } +// +// +class XFlowHelper { + public: + explicit XFlowHelper(absl::string_view name) + : xflow_id_(tsl::random::New64() >> 8 /*XFlow IDs are 56 bits*/), + name_(name) {} + + typedef enum { kSend, kRecv, kRecvSend } Direction; + + template + tsl::profiler::TraceMe Span() const { + return tsl::profiler::TraceMe([xflow_id = xflow_id_, name = name_] { + return Encode(xflow_id, name); + }); + } + + template + void InstantActivity() const { + return tsl::profiler::TraceMe::InstantActivity( + [xflow_id = xflow_id_, name = name_] { + return Encode(xflow_id, name); + }); + } + + private: + template + static std::string Encode(uint64_t xflow_id, absl::string_view name) { + static constexpr absl::string_view flow_dir_str = + D == kSend ? "send" : (D == kRecv ? "recv" : "recv_send"); + const XFlow flow(xflow_id, D == kRecvSend ? XFlow::kFlowInOut + : (D == kRecv ? XFlow::kFlowIn + : XFlow::kFlowOut)); + return tsl::profiler::TraceMeEncode( + name, {{"dir", flow_dir_str}, {"flow", flow.ToStatValue()}}); + }; + + const uint64_t xflow_id_; + const absl::string_view name_; +}; + +// Thread-safe data structure for holding batched operations. +class BatchedOps { + public: + using BatchOperation = RpcHelper::BatchOperation; + + void Add(BatchOperation op, ArrayHandle handle) { + absl::MutexLock l(&mu_); + batched_[op].push_back(handle); + } + + struct IfrtRequests { + std::unique_ptr delete_req; + std::unique_ptr destruct_req; + }; + + IfrtRequests Consume() { + IfrtRequests result; + absl::MutexLock l(&mu_); + if (!batched_[BatchOperation::kDeleteArray].empty()) { + result.delete_req = std::make_unique(); + for (const auto& arr_handle : batched_[BatchOperation::kDeleteArray]) { + result.delete_req->mutable_delete_array_request()->add_array_handle( + arr_handle.handle); + } + batched_[BatchOperation::kDeleteArray].clear(); + } + if (!batched_[BatchOperation::kDestructArray].empty()) { + result.destruct_req = std::make_unique(); + for (const auto& arr_handle : batched_[BatchOperation::kDestructArray]) { + result.destruct_req->mutable_destruct_array_request()->add_array_handle( + arr_handle.handle); + } + batched_[BatchOperation::kDestructArray].clear(); + } + return result; + } + + private: + absl::Mutex mu_; + std::array, BatchOperation::kSentinelDoNotUse> + batched_ ABSL_GUARDED_BY(mu_); +}; + +} // namespace + +// Batches any requested operations and flushes them periodically in the +// background, and allows sending other requested operations immediately. +// Immediate operations are guaranteed to be sent after all previously enqueued +// batched operations. +class RpcHelper::Batcher { + public: + explicit Batcher(std::shared_ptr session) + : session_(std::move(session)) { + thread_pool_.emplace(tsl::Env::Default(), "IfrtProxyRpcHelperBatcher", + /*num_threads=*/1); + thread_pool_->Schedule(absl::bind_front(&Batcher::PeriodicFlusher, this)); + } + + // Sends the given request immediately after sending any batched operations + // that have been previously enqueued. + Future Immediate( + std::unique_ptr request) { + absl::MutexLock l(&mu_); + if (finished_) { + LOG(WARNING) << "After RpcHelper::Finish(): " << request->DebugString(); + return Future( + absl::FailedPreconditionError("RpcHelper::Finish() already called.")); + } + Flush(); + return session_->Enqueue(std::move(request)); + } + + // Enqueues an operation to be sent later. Guaranteed to not be blocked by the + // underlying transport. + void Batch(BatchOperation op, ArrayHandle handle) { + batched_.Add(op, handle); + } + + // Asks the underlying transport to terminate. + void Finish(absl::Status s) { + { + absl::MutexLock l(&mu_); + finished_ = true; + auto remaining = batched_.Consume(); + if (remaining.delete_req != nullptr) { + LOG(WARNING) << "RpcHelper::Batch: Finish() called while there are " + "still batched delete operations"; + } + if (remaining.destruct_req != nullptr) { + LOG(WARNING) << "RpcHelper::Batch: Finish() called while there are " + "still batched destruct operations"; + } + } + thread_pool_.reset(); + session_->Finish(s); + } + + private: + void PeriodicFlusher() { + while (true) { + absl::SleepFor(kPeriodicFlushInterval); + absl::MutexLock l(&mu_); + if (finished_) { + return; + } + { + bool periodic_flush_paused = false; + TestHookCall(TestHookName::kRpcBatcherPausePeriodicFlush, + &periodic_flush_paused); + if (periodic_flush_paused) { + continue; + } + } + tsl::profiler::TraceMe traceme("proxy_periodic_flush"); + Flush(); + } + } + + // Sends all enqueued batched operations. + void Flush() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { + auto reqs = batched_.Consume(); + if (reqs.delete_req != nullptr) { + XFlowHelper x_flow_helper("batch_delete"); + auto traceme = x_flow_helper.Span(); + session_->Enqueue(std::move(reqs.delete_req)) + .OnReady( + absl::bind_front(HandleBatchResponse, session_, x_flow_helper)); + } + if (reqs.destruct_req != nullptr) { + XFlowHelper x_flow_helper("batch_destruct"); + auto traceme = x_flow_helper.Span(); + session_->Enqueue(std::move(reqs.destruct_req)) + .OnReady( + absl::bind_front(HandleBatchResponse, session_, x_flow_helper)); + } + } + + // Handles a response from the server of a previous batched operation; + // bad responses are logged but otherwise ignored. The method is static since + // it can be called in the background after RpcHelper::Batcher is destroyed. + static void HandleBatchResponse( + std::shared_ptr session, XFlowHelper x_flow_helper, + absl::StatusOr> r) { + if (!r.ok()) { + x_flow_helper.InstantActivity(); + LOG(WARNING) << "Batched response from ifrt proxy server: " << r.status(); + return; + } + if (r.value()->has_delete_array_response()) { + auto traceme = x_flow_helper.Span(); + auto ifrt_req = std::make_unique(); + ifrt_req->mutable_check_future_request()->set_future_handle( + r.value()->delete_array_response().deletion_future_handle()); + session->Enqueue(std::move(ifrt_req)) + .OnReady( + absl::bind_front(HandleBatchResponse, session, x_flow_helper)); + } else if (r.value()->has_destruct_array_response() || + r.value()->has_check_future_response()) { + x_flow_helper.InstantActivity(); + } else { + LOG(ERROR) << "Unrecognized response from server for batched request: " + << (*r)->DebugString(); + } + } + + const std::shared_ptr session_; + + BatchedOps batched_; + + absl::Mutex mu_; + bool finished_ ABSL_GUARDED_BY(mu_) = false; + std::optional thread_pool_; +}; + // DoRpc is a templated function that implements the logic of all RPC-wrapping // functions of `RpcHelper`, such as `RpcHelper::MakeArrayFromHostBuffer()`. template -Future> DoRpc(ClientSession* session, - RequestMetadata metadata, +Future> DoRpc(RpcHelper::Batcher* batcher, void (IfrtRequest::*set_req)(Req*), Resp* (IfrtResponse::*get_resp)(), bool (IfrtResponse::*has_resp)() const, std::unique_ptr req, - absl::string_view profiling_send_name, - absl::string_view profiling_recv_name) { + absl::string_view profiling_name) { auto ifrt_req = std::make_unique(); - *ifrt_req->mutable_request_metadata() = metadata; (ifrt_req.get()->*set_req)(req.release()); - const uint64_t xflow_id = tsl::random::New64() >> 8; // XFlow IDs are 56 bits - tsl::profiler::TraceMe traceme([xflow_id, profiling_send_name]() { - const XFlow flow(xflow_id, XFlow::FlowDirection::kFlowOut); - return tsl::profiler::TraceMeEncode(profiling_send_name, - {{"flow", flow.ToStatValue()}}); - }); + XFlowHelper x_flow_helper(profiling_name); + auto traceme = x_flow_helper.Span(); auto promise = Future>::CreatePromise(); - auto on_ready = [promise, has_resp, get_resp, xflow_id, profiling_recv_name]( + auto on_ready = [promise, has_resp, get_resp, x_flow_helper]( absl::StatusOr> r) mutable { - tsl::profiler::TraceMe traceme([xflow_id, profiling_recv_name]() { - const XFlow flow(xflow_id, XFlow::FlowDirection::kFlowIn); - return tsl::profiler::TraceMeEncode(profiling_recv_name, - {{"flow", flow.ToStatValue()}}); - }); + auto traceme = x_flow_helper.Span(); if (!r.ok()) { LOG_EVERY_N_SEC(ERROR, 10) << "Connection to IFRT proxy server was terminated: " << r.status(); @@ -118,41 +354,18 @@ Future> DoRpc(ClientSession* session, std::make_shared(*std::move((response.get()->*get_resp)()))); } }; - session->Enqueue(std::move(ifrt_req)).OnReady(on_ready); + batcher->Immediate(std::move(ifrt_req)).OnReady(on_ready); return Future>(promise); } -RequestMetadata RpcHelper::ManufactureRequestMetadata() { - RequestMetadata result; - { - absl::MutexLock l(&mu_); - result.set_op_id(next_op_id_++); - } - int prev_op_id = result.op_id() - 1; - if (prev_op_id != 0) { - // TODO(b/266635130): Depend only on necessary prior operations. - result.add_dependencies(prev_op_id); - } - // TODO(b/282757875): Add a ClearOps RPC for old dependencies. - return result; -} - -void RpcHelper::Disconnect() { - session_->Finish(absl::CancelledError("Disconnected by client")); -} - -// TODO(b/266635130): Remove this preprocessor macro. Preprocessor macros -// go against the style guide, but are convenient as we are introducing more -// RPCs and are making changes to the exact signature of the DoRpc function. -#define RPC(METHOD, PROPERTY) \ - RpcHelper::ResponseFuture RpcHelper::METHOD( \ - std::unique_ptr req) { \ - return DoRpc(session_.get(), ManufactureRequestMetadata(), \ - &IfrtRequest::set_allocated_##PROPERTY##_request, \ - &IfrtResponse::mutable_##PROPERTY##_response, \ - &IfrtResponse::has_##PROPERTY##_response, std::move(req), \ - "" #PROPERTY "_send", "" #PROPERTY "_recv"); \ +#define RPC(METHOD, PROPERTY) \ + RpcHelper::ResponseFuture RpcHelper::METHOD( \ + std::unique_ptr req) { \ + return DoRpc( \ + batcher_.get(), &IfrtRequest::set_allocated_##PROPERTY##_request, \ + &IfrtResponse::mutable_##PROPERTY##_response, \ + &IfrtResponse::has_##PROPERTY##_response, std::move(req), #PROPERTY); \ } RPC(Init, init); @@ -193,6 +406,21 @@ Future<> RpcHelper::CheckFuture(uint64_t handle) { return Future<>(std::move(promise)); } +RpcHelper::RpcHelper(IfrtProxyVersion version, + std::shared_ptr session) + : batcher_(std::make_unique(std::move(session))), + version_(std::move(version)) {} + +RpcHelper::~RpcHelper() { Disconnect(); } + +void RpcHelper::Batch(BatchOperation op, ArrayHandle handle) { + return batcher_->Batch(op, handle); +} + +void RpcHelper::Disconnect() { + batcher_->Finish(absl::CancelledError("Disconnected by client")); +} + } // namespace proxy } // namespace ifrt } // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h index 3ed2a3eeb58d2b..fc88c22756502d 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h +++ b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.h @@ -23,13 +23,12 @@ #include "absl/base/thread_annotations.h" #include "absl/log/check.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "xla/python/ifrt/future.h" #include "xla/python/ifrt_proxy/client/client_session.h" #include "xla/python/ifrt_proxy/client/host_buffer.h" #include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/types.h" namespace xla { namespace ifrt { @@ -43,14 +42,13 @@ namespace proxy { // specify the necessary dependency. class RpcHelper { public: - RpcHelper(IfrtProxyVersion version, std::shared_ptr session) - : version_(std::move(version)), session_(std::move(session)) {} + RpcHelper(IfrtProxyVersion version, std::shared_ptr session); void Disconnect(); RpcHelper(const RpcHelper&) = delete; RpcHelper& operator=(const RpcHelper&) = delete; - ~RpcHelper() { Disconnect(); } + ~RpcHelper(); // IFRT Proxy version negotiated between the client and the server. const IfrtProxyVersion& version() const { return version_; } @@ -71,6 +69,15 @@ class RpcHelper { template using ResponseFuture = Future>; + class Batcher; + enum BatchOperation { kDeleteArray, kDestructArray, kSentinelDoNotUse }; + + // Adds the given operation to an impending batch of operations and returns + // immediately. The batch of operation is sent later (as a single logical + // RPC). The RPC is guaranteed to be sent before any unbatched RPCs resulting + // from the wrapper functions below. + void Batch(BatchOperation op, ArrayHandle handle); + // Wrapper function for various logical RPCs defined in ifrt_service.proto. // Whenever the RPC finishes, `on_done` will be called with the result or the // return status. `on_done` can be called with various locks held and should @@ -137,10 +144,9 @@ class RpcHelper { Future<> CheckFuture(uint64_t handle); private: - RequestMetadata ManufactureRequestMetadata() ABSL_LOCKS_EXCLUDED(mu_); + const std::unique_ptr batcher_; const IfrtProxyVersion version_; - const std::shared_ptr session_; std::shared_ptr host_buffer_store_; absl::Mutex mu_; diff --git a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper_test.cc b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper_test.cc new file mode 100644 index 00000000000000..36adbabd3fafb4 --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper_test.cc @@ -0,0 +1,155 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xla/python/ifrt_proxy/client/rpc_helper.h" + +#include +#include + +#include +#include +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" +#include "xla/python/ifrt/future.h" +#include "xla/python/ifrt_proxy/client/client_session.h" +#include "xla/python/ifrt_proxy/client/mock_client_session.h" +#include "xla/python/ifrt_proxy/client/version.h" +#include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "xla/python/ifrt_proxy/common/test_utils.h" +#include "xla/python/ifrt_proxy/common/types.h" +#include "xla/python/ifrt_proxy/common/types.pb.h" +#include "tsl/platform/test.h" + +using ::testing::_; +using ::testing::UnorderedElementsAre; + +namespace xla { +namespace ifrt { +namespace proxy { +namespace { + +constexpr absl::Duration kMaxFlushTimeout = absl::Seconds(10); + +void PausePeriodicFlushes() { + // We want to (a) return 'paused=true' whenever the flusher thread tries to + // find out whether flushing has been paused, and (b) wait for any ongoing + // background flushes to complete. To achieve (b), we wait until the flusher + // thread asks for the value of `paused` at least once. + struct AtomicBool { + absl::Mutex mu; + bool b = false; + }; + + auto called_at_least_once = std::make_shared(); + auto periodic_flusher_pause_hook = [called_at_least_once](bool* paused) { + *paused = true; + absl::MutexLock l(&called_at_least_once->mu); + called_at_least_once->b = true; + }; + TestHookSet(TestHookName::kRpcBatcherPausePeriodicFlush, + std::move(periodic_flusher_pause_hook)); + + absl::MutexLock l(&called_at_least_once->mu); + CHECK(called_at_least_once->mu.AwaitWithTimeout( + absl::Condition(&called_at_least_once->b), kMaxFlushTimeout)); +} + +void ResumePeriodicFlushes() { + TestHookClear(TestHookName::kRpcBatcherPausePeriodicFlush); +} + +class RpcHelperTest : public ::testing::Test { + public: + RpcHelperTest() : requests_(kMaxFlushTimeout) { + session_ = std::make_shared(); + IfrtProxyVersion version; + version.set_protocol_version(kClientMaxVersion); + rpc_helper_ = std::make_shared(version, session_); + EXPECT_CALL(*session_, Finish(_)).Times(1); + ON_CALL(*session_, Enqueue) + .WillByDefault([this](std::unique_ptr req) { + requests_.Push(std::move(req)); + return Future( + absl::InternalError("Fake error response")); + }); + } + + std::shared_ptr session_; + std::shared_ptr rpc_helper_; + TestQueue> requests_; +}; + +TEST_F(RpcHelperTest, BatchedPeriodicFlush) { + PausePeriodicFlushes(); + rpc_helper_->Batch(RpcHelper::kDestructArray, ArrayHandle{1}); + rpc_helper_->Batch(RpcHelper::kDeleteArray, ArrayHandle{2}); + rpc_helper_->Batch(RpcHelper::kDestructArray, ArrayHandle{3}); + rpc_helper_->Batch(RpcHelper::kDeleteArray, ArrayHandle{4}); + rpc_helper_->Batch(RpcHelper::kDestructArray, ArrayHandle{9}); + rpc_helper_->Batch(RpcHelper::kDeleteArray, ArrayHandle{8}); + rpc_helper_->Batch(RpcHelper::kDestructArray, ArrayHandle{7}); + rpc_helper_->Batch(RpcHelper::kDeleteArray, ArrayHandle{6}); + ResumePeriodicFlushes(); + + auto delete_req = requests_.Pop(); + auto destruct_req = requests_.Pop(); + + if (destruct_req->has_delete_array_request()) { + destruct_req.swap(delete_req); + } + + EXPECT_THAT(destruct_req->destruct_array_request().array_handle(), + UnorderedElementsAre(1, 3, 9, 7)); + EXPECT_THAT(delete_req->delete_array_request().array_handle(), + UnorderedElementsAre(2, 4, 8, 6)); +} + +TEST_F(RpcHelperTest, BatchedNoPeriodicFlush) { + PausePeriodicFlushes(); + rpc_helper_->Batch(RpcHelper::kDestructArray, ArrayHandle{1}); + rpc_helper_->Batch(RpcHelper::kDeleteArray, ArrayHandle{2}); + rpc_helper_->Batch(RpcHelper::kDestructArray, ArrayHandle{3}); + rpc_helper_->Batch(RpcHelper::kDeleteArray, ArrayHandle{4}); + rpc_helper_->Batch(RpcHelper::kDestructArray, ArrayHandle{9}); + rpc_helper_->Batch(RpcHelper::kDeleteArray, ArrayHandle{8}); + rpc_helper_->Batch(RpcHelper::kDestructArray, ArrayHandle{7}); + rpc_helper_->Batch(RpcHelper::kDeleteArray, ArrayHandle{6}); + + // Send some non-batched request, which should flush all the batched requests. + { + auto dummy_request = std::make_unique(); + dummy_request->set_future_handle(1); + rpc_helper_->CheckFuture(std::move(dummy_request)); + requests_.AllowNonEmptyDestruction(/*allow=*/true); + } + + auto delete_req = requests_.Pop(); + auto destruct_req = requests_.Pop(); + + if (destruct_req->has_delete_array_request()) { + destruct_req.swap(delete_req); + } + + EXPECT_THAT(destruct_req->destruct_array_request().array_handle(), + UnorderedElementsAre(1, 3, 9, 7)); + EXPECT_THAT(delete_req->delete_array_request().array_handle(), + UnorderedElementsAre(2, 4, 8, 6)); +} + +} // namespace +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/client/version.h b/third_party/xla/xla/python/ifrt_proxy/client/version.h index 13c753ee9c5d61..bd83ceee35bab7 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/version.h +++ b/third_party/xla/xla/python/ifrt_proxy/client/version.h @@ -24,7 +24,7 @@ namespace proxy { // LINT.IfChange // TODO(b/296144873): Document the version upgrade policy. inline constexpr int kClientMinVersion = 3; -inline constexpr int kClientMaxVersion = 4; +inline constexpr int kClientMaxVersion = 8; // LINT.ThenChange(//tensorflow/compiler/xla/python/ifrt_proxy/common/VERSION.md) } // namespace proxy diff --git a/third_party/xla/xla/python/ifrt_proxy/common/BUILD b/third_party/xla/xla/python/ifrt_proxy/common/BUILD index 134a8505419f8d..d8c5feefe6d439 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/common/BUILD @@ -72,11 +72,12 @@ tf_proto_library( "//xla/pjrt:execute_options_proto", "//xla/python/ifrt:attribute_map_proto", "//xla/python/ifrt:dtype_proto", + "//xla/python/ifrt:execute_options_proto", "//xla/python/ifrt:remap_plan_proto", "//xla/python/ifrt:serdes_proto", "//xla/python/ifrt:shape_proto", "//xla/python/ifrt:sharding_proto", - "@local_tsl//tsl/protobuf:status_proto", + "//xla/tsl/protobuf:status_proto", ], ) @@ -180,6 +181,20 @@ cc_library( ], ) +cc_library( + name = "test_utils", + srcs = ["test_utils.cc"], + hdrs = ["test_utils.h"], + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/debugging:leak_check", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + ], +) + # copybara:uncomment_begin # bzl_library( # name = "ifrt_proxy_bzl", @@ -187,4 +202,15 @@ cc_library( # parse_tests = False, # visibility = ["//visibility:private"], # ) +# +# bzl_library( +# name = "ifrt_proxy_google_bzl", +# srcs = ["ifrt_proxy.google.bzl"], +# parse_tests = False, +# visibility = ["//visibility:private"], +# deps = [ +# "//devtools/build_cleaner/skylark:build_defs_lib", +# "//xla:xla_bzl", +# ], +# ) # copybara:uncomment_end diff --git a/third_party/xla/xla/python/ifrt_proxy/common/VERSION.md b/third_party/xla/xla/python/ifrt_proxy/common/VERSION.md index 4166a27daf9ca7..2cb29563cf2bb5 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/VERSION.md +++ b/third_party/xla/xla/python/ifrt_proxy/common/VERSION.md @@ -24,3 +24,26 @@ * Changes: * Changed the serialization of client and device attributes to use `xla.ifrt.AttributeMapProto` instead of `map`. +## Version 5 + +* Added date: 2024-09-20. +* Changes: + * Batch array deletions and destruction on client before sending to server. + +## Version 6 + +* Added date: 2024-09-30. +* Changes: + * Added `ExecuteOptions::fill_status`. + +## Version 7 + +* Added date: 2024-10-01. +* Changes: + * Added support for `Client::GetAllDevices()`. + +## Version 8 + +* Added date: 2024-10-11. +* Changes: + * Added support for `SingleDeviceShardSemantics` in Array assembly and disassembly operations. diff --git a/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto b/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto index 278f9156f4eb89..a0812a521f9e37 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto +++ b/third_party/xla/xla/python/ifrt_proxy/common/ifrt_service.proto @@ -17,16 +17,16 @@ syntax = "proto3"; package xla.ifrt.proxy; import "google/protobuf/any.proto"; -import "xla/pjrt/execute_options.proto"; import "xla/python/ifrt/attribute_map.proto"; import "xla/python/ifrt/dtype.proto"; +import "xla/python/ifrt/execute_options.proto"; import "xla/python/ifrt/remap_plan.proto"; import "xla/python/ifrt/serdes.proto"; import "xla/python/ifrt/shape.proto"; import "xla/python/ifrt/sharding.proto"; import "xla/python/ifrt_proxy/common/types.proto"; +import "xla/tsl/protobuf/status.proto"; import "xla/xla_data.proto"; -import "tsl/protobuf/status.proto"; option cc_enable_arenas = true; @@ -141,27 +141,10 @@ message RequestMetadata { // resync after transient connectivity failures. fixed64 op_id = 1; - // List of one or more prior ops this current op is "dependent" - // upon. Currently this allows the client to define the order in which the - // server starts the execution of requests. Future versions may add other - // types of dependencies. For instance, a separate list of dependencies that - // must *complete* executing before the current one can start to execute. - // - // An op_id that has not yet been seen by the server is treated as an error - // that fails the op. - repeated fixed64 dependencies = 2; - - // UserContext is a basic provenance mechanism that allows the server-side - // actions and artifacts (say, allocating a buffer) to be associated with the - // corresponding client-side context that triggered those actions. - // - // The optional UserContextId is generated by the client and are used as an - // opaque label by the server and the run-time systems behind it. - // TODO(b/282757875): Add a pointer to Usercontext bugs/design doc. - fixed64 user_context_id = 3; - - // Additional implementation-specific payloads. + // Implementation-specific payloads. repeated google.protobuf.Any payloads = 4; + + reserved 2, 3; } // Metadata of an IFRT Response. @@ -224,9 +207,12 @@ message InitResponse { AttributeMapProto attributes = 10; // New in Version 4. } - repeated Device devices = 6; // == ifrt::Client::devices() + repeated Device all_devices = 6; // == ifrt::Client::GetAllDevices() + repeated int32 primary_device_ids = + 10; // == [device.id for device in ifrt::Client::devices()] repeated int32 addressable_device_ids = - 7; // == ifrt::Client::addressable_devices() + 7; // == [device.id for device in ifrt::Client::GetAllDevices() if + // device.IsAddressable()] message Memory { int32 id = 1; @@ -283,6 +269,7 @@ message AssembleArrayFromSingleDeviceArraysRequest { ShardingProto sharding = 2; repeated fixed64 single_device_array_handles = 3; proto.ArrayCopySemantics copy_semantics = 4; + optional proto.SingleDeviceShardSemantics single_device_shard_semantics = 5; } message AssembleArrayFromSingleDeviceArraysResponse { fixed64 array_handle = 1; @@ -313,6 +300,7 @@ message CopyToHostBufferResponse {} message DisassembleIntoSingleDeviceArraysRequest { fixed64 array_handle = 1; proto.ArrayCopySemantics copy_semantics = 2; + optional proto.SingleDeviceShardSemantics single_device_shard_semantics = 3; } message DisassembleIntoSingleDeviceArraysResponse { repeated fixed64 single_device_array_handles = 1; @@ -348,7 +336,10 @@ message FullyReplicatedShardResponse { // Deletes the given Array. Response contains the handle for a Future that // becomes ready when the deletion completes. message DeleteArrayRequest { - fixed64 array_handle = 1; + // TODO(b/296144873): Remove after compatibility window. + optional fixed64 array_handle_deprecated = 1 [deprecated = true]; + + repeated fixed64 array_handle = 2; } message DeleteArrayResponse { fixed64 deletion_future_handle = 1; @@ -362,7 +353,10 @@ message IsArrayDeletedResponse { } message DestructArrayRequest { - fixed64 array_handle = 1; + // TODO(b/296144873): Remove after compatibility window. + optional fixed64 array_handle_deprecated = 1 [deprecated = true]; + + repeated fixed64 array_handle = 2; } message DestructArrayResponse {} @@ -439,7 +433,7 @@ message LoadedExecutableMetadataResponse { message LoadedExecutableExecuteRequest { fixed64 loaded_executable_handle = 1; repeated fixed64 args_handles = 2; - xla.ExecuteOptionsProto execute_options = 3; + xla.ifrt.ExecuteOptionsProto execute_options = 3; repeated int32 device_ids = 4; } message LoadedExecutableExecuteResponse { diff --git a/third_party/xla/xla/python/ifrt_proxy/common/test_utils.cc b/third_party/xla/xla/python/ifrt_proxy/common/test_utils.cc new file mode 100644 index 00000000000000..eed9fcea24e76e --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/common/test_utils.cc @@ -0,0 +1,83 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/ifrt_proxy/common/test_utils.h" + +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/debugging/leak_check.h" +#include "absl/synchronization/mutex.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +namespace { + +class Overrides { + public: + void Set(TestHookName h, std::function fn) { + absl::MutexLock l(&mu_); + overrides_[h] = std::move(fn); + } + + void Clear(TestHookName h) { + absl::MutexLock l(&mu_); + overrides_.erase(h); + } + + void Call(TestHookName h, bool* param1) { + absl::MutexLock l(&mu_); + const auto it = overrides_.find(h); + if (it != overrides_.end()) { + it->second(param1); + } + } + + private: + absl::Mutex mu_; + absl::flat_hash_map> overrides_ + ABSL_GUARDED_BY(mu_); +}; + +Overrides* overrides() { + // Declaring a global absl::NoDestructor is easier, but as of Sep + // 2024, NoDestructor<> was not yet available in the version of absl linked + // into TSL. + static Overrides* result = []() { + auto* result = new Overrides; + absl::IgnoreLeak(result); + return result; + }(); + return result; +} + +}; // namespace + +void TestHookSet(TestHookName h, std::function fn) { + overrides()->Set(h, std::move(fn)); +} +void TestHookClear(TestHookName h) { overrides()->Clear(h); } + +void TestHookCall(TestHookName h, bool* param1) { + overrides()->Call(h, param1); +} + +} // namespace proxy +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt_proxy/common/test_utils.h b/third_party/xla/xla/python/ifrt_proxy/common/test_utils.h new file mode 100644 index 00000000000000..002394fc6074fb --- /dev/null +++ b/third_party/xla/xla/python/ifrt_proxy/common/test_utils.h @@ -0,0 +1,114 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_IFRT_PROXY_COMMON_TEST_UTILS_H_ +#define XLA_PYTHON_IFRT_PROXY_COMMON_TEST_UTILS_H_ + +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/log/check.h" +#include "absl/synchronization/mutex.h" +#include "absl/time/time.h" + +namespace xla { +namespace ifrt { +namespace proxy { + +// TestQueue implements a thread-safe queue that manages values of type T. +template +class TestQueue { + public: + explicit TestQueue(absl::Duration pop_timeout) + : pop_timeout_(std::move(pop_timeout)) {} + + // Pushes `t` into the queue. + void Push(T t) { + absl::MutexLock l(&mu_); + queue_.push_back(std::move(t)); + } + + // Pops the first element in the queue if a element is already available or + // appears within `pop_timeout` (because `Push` is called). Otherwise returns + // std::nullopt. + std::optional PopOrTimeout() { + absl::MutexLock l(&mu_); + auto cond = [this]() ABSL_SHARED_LOCKS_REQUIRED(mu_) -> bool { + return !queue_.empty(); + }; + mu_.AwaitWithTimeout(absl::Condition(&cond), pop_timeout_); + if (queue_.empty()) { + return std::nullopt; + } + T result = std::move(queue_.front()); + queue_.pop_front(); + return result; + } + + // Pops the first element in the queue if a element is already available or + // appears within `pop_timeout`, and fails otherwise. + T Pop() { + std::optional result = PopOrTimeout(); + CHECK(result.has_value()) << "Timeout!"; + return std::move(*result); + } + + // Sets whether the queue is allowed to be destructed while it contains + // unpopped elements. + void AllowNonEmptyDestruction(bool allow) { + absl::MutexLock l(&mu_); + allow_non_empty_destruction_ = allow; + } + + // Checks that the queue is either empty, or `AllowNonEmptyDestruction(true)` + // has been called. + ~TestQueue() { + absl::MutexLock l(&mu_); + if (!allow_non_empty_destruction_) CHECK(queue_.empty()) << " " << this; + } + + private: + const absl::Duration pop_timeout_; + + absl::Mutex mu_; + std::deque queue_ ABSL_GUARDED_BY(mu_); + bool allow_non_empty_destruction_ ABSL_GUARDED_BY(mu_) = false; +}; + +// TestHook provides a lightweight mechanism to modify the behavior of +// production code from tests. +// TODO(b/266635130): Extend for more hook types (as of Sep 2023, only allows +// `void(bool*)`) and make more lightweight. +enum class TestHookName { + kRpcBatcherPausePeriodicFlush, +}; + +// Allows test code to override the default noop behavior for hook `h`. +void TestHookSet(TestHookName h, std::function fn); + +// Resets hook `h` to the default noop behavior. +void TestHookClear(TestHookName h); + +// Calls hook `h` if it has been overridden by test setup; noop otherwise. +void TestHookCall(TestHookName h, bool* param1); + +} // namespace proxy +} // namespace ifrt +} // namespace xla + +#endif // XLA_PYTHON_IFRT_PROXY_COMMON_TEST_UTILS_H_ diff --git a/third_party/xla/xla/python/ifrt_proxy/common/types.cc b/third_party/xla/xla/python/ifrt_proxy/common/types.cc index db981531c24c27..9c2b4dc6a57927 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/types.cc +++ b/third_party/xla/xla/python/ifrt_proxy/common/types.cc @@ -26,6 +26,7 @@ #include "absl/types/span.h" #include "xla/pjrt/pjrt_common.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/sharding.h" #include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" #include "xla/python/ifrt_proxy/common/types.pb.h" @@ -97,6 +98,30 @@ absl::StatusOr FromArrayCopySemanticsProto( } } +proto::SingleDeviceShardSemantics ToSingleDeviceShardSemanticsProto( + SingleDeviceShardSemantics s) { + switch (s) { + case SingleDeviceShardSemantics::kAddressableShards: + return proto::SINGLE_DEVICE_SHARD_SEMANTICS_ADDRESSABLE_SHARDS; + case SingleDeviceShardSemantics::kAllShards: + return proto::SINGLE_DEVICE_SHARD_SEMANTICS_ALL_SHARDS; + } +} + +absl::StatusOr FromSingleDeviceShardSemanticsProto( + proto::SingleDeviceShardSemantics s) { + switch (s) { + case proto::SINGLE_DEVICE_SHARD_SEMANTICS_ADDRESSABLE_SHARDS: + return SingleDeviceShardSemantics::kAddressableShards; + case proto::SINGLE_DEVICE_SHARD_SEMANTICS_ALL_SHARDS: + return SingleDeviceShardSemantics::kAllShards; + default: + return absl::InvalidArgumentError( + absl::StrCat("Unhandled proto-enum value ", s, ":", + proto::SingleDeviceShardSemantics_Name(s))); + } +} + std::vector FromByteStridesProto(const proto::ByteStrides& strides) { std::vector result; result.reserve(strides.strides_size()); diff --git a/third_party/xla/xla/python/ifrt_proxy/common/types.h b/third_party/xla/xla/python/ifrt_proxy/common/types.h index 0c517e2da054a5..06f6771b54f9d1 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/types.h +++ b/third_party/xla/xla/python/ifrt_proxy/common/types.h @@ -24,6 +24,7 @@ #include "absl/types/span.h" #include "xla/pjrt/pjrt_common.h" #include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/sharding.h" #include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" #include "xla/python/ifrt_proxy/common/types.pb.h" @@ -44,6 +45,11 @@ absl::StatusOr FromArrayCopySemanticsProto( proto::ArrayCopySemantics s); proto::ArrayCopySemantics ToArrayCopySemanticsProto(ArrayCopySemantics s); +absl::StatusOr FromSingleDeviceShardSemanticsProto( + proto::SingleDeviceShardSemantics s); +proto::SingleDeviceShardSemantics ToSingleDeviceShardSemanticsProto( + SingleDeviceShardSemantics s); + absl::StatusOr FromVariantProto( const proto::Variant& variant_proto); absl::StatusOr ToVariantProto(const xla::PjRtValueType& value); diff --git a/third_party/xla/xla/python/ifrt_proxy/common/types.proto b/third_party/xla/xla/python/ifrt_proxy/common/types.proto index 49c3c7e1304570..0585109aff4737 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/types.proto +++ b/third_party/xla/xla/python/ifrt_proxy/common/types.proto @@ -41,6 +41,12 @@ enum ArrayCopySemantics { ARRAY_COPY_SEMANTICS_DONATE_INPUT = 3; } +enum SingleDeviceShardSemantics { + SINGLE_DEVICE_SHARD_SEMANTICS_UNSPECIFIED = 0; + SINGLE_DEVICE_SHARD_SEMANTICS_ADDRESSABLE_SHARDS = 1; + SINGLE_DEVICE_SHARD_SEMANTICS_ALL_SHARDS = 2; +} + message ByteStrides { repeated int64 strides = 1; } diff --git a/third_party/xla/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc b/third_party/xla/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc index ae8c86855662b8..9fda694e648727 100644 --- a/third_party/xla/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc +++ b/third_party/xla/xla/python/ifrt_proxy/integration_tests/mock_array_test.cc @@ -124,7 +124,8 @@ class MockArrayTest : public testing::Test { CpuClientOptions options; options.asynchronous = true; options.cpu_device_count = 2; - TF_ASSIGN_OR_RETURN(auto tfrt_cpu_client, xla::GetTfrtCpuClient(options)); + TF_ASSIGN_OR_RETURN(auto tfrt_cpu_client, + xla::GetTfrtCpuClient(std::move(options))); auto mock_backend = std::make_unique( /*delegate=*/xla::ifrt::PjRtClient::Create(std::move(tfrt_cpu_client))); @@ -196,49 +197,6 @@ TEST_F(MockArrayTest, ReadyFuturePropagatesError) { StatusIs(kInternal)); } -TEST_F(MockArrayTest, DeletionFutureWaitsUntilDeleted) { - TF_ASSERT_OK_AND_ASSIGN(ArrayPair arr, NewArray()); - - tsl::thread::ThreadPool threads(tsl::Env::Default(), "t", /*num_threads=*/1); - absl::Notification wait_ready; - - EXPECT_CALL(*arr.backend_array, Delete).WillOnce([&] { - // TODO(b/266635130): Write a version of this testcase where the Delete() - // call of the MockArray blocks on `wait_ready`, instead of the Future it - // returns being blocked on `wait_ready`. That version of the testcase does - // not currently work since both the client and the server synchronously - // block until the MockArray's Delete() returns. - auto promise = Future<>::CreatePromise(); - threads.Schedule([&, promise]() mutable { - wait_ready.WaitForNotification(); - promise.Set(arr.backend_array->delegated()->Delete().Await()); - }); - return Future<>(promise); - }); - - EXPECT_FALSE(arr.proxy_client_array->IsDeleted()); - auto deleted_future = arr.proxy_client_array->Delete(); - - absl::SleepFor(kSomeTime); - EXPECT_FALSE(deleted_future.IsReady()); - EXPECT_FALSE(arr.proxy_client_array->IsDeleted()); - - wait_ready.Notify(); - EXPECT_THAT(deleted_future.Await(), IsOk()); - EXPECT_TRUE(arr.proxy_client_array->IsDeleted()); -} - -TEST_F(MockArrayTest, DeletionPropagatesError) { - TF_ASSERT_OK_AND_ASSIGN(ArrayPair arr, NewArray()); - - EXPECT_CALL(*arr.backend_array, Delete).WillOnce([&] { - return Future<>(absl::InternalError("testing")); - }); - - EXPECT_FALSE(arr.proxy_client_array->IsDeleted()); - EXPECT_THAT(arr.proxy_client_array->Delete().Await(), StatusIs(kInternal)); -} - TEST_F(MockArrayTest, CopyToHostFutureWaitsUntilCopied) { TF_ASSERT_OK_AND_ASSIGN(ArrayPair arr, NewArray()); diff --git a/third_party/xla/xla/python/ifrt_proxy/server/BUILD b/third_party/xla/xla/python/ifrt_proxy/server/BUILD index 8484fee04b7e10..80e1976128d6fb 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/server/BUILD @@ -178,12 +178,13 @@ ifrt_proxy_cc_test( "//xla/python/ifrt:serdes", "//xla/python/ifrt:sharding_serdes", "//xla/python/ifrt_proxy/common:ifrt_service_proto_cc", - "//xla/python/ifrt_proxy/common:types", "//xla/python/ifrt_proxy/common:types_proto_cc", "//xla/python/pjrt_ifrt:xla_ifrt", "//xla/service:computation_placer_hdr", "//xla/tsl/concurrency:ref_count", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/protobuf:error_codes_proto_impl_cc", + "//xla/tsl/protobuf:status_proto_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", @@ -203,8 +204,6 @@ ifrt_proxy_cc_test( "@local_tsl//tsl/platform:status_to_from_proto", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", - "@local_tsl//tsl/protobuf:status_proto_cc", ], ) diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc index 040c0444cbd20f..ff38fad9337bfd 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend.cc @@ -35,6 +35,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" @@ -206,6 +207,8 @@ Future IfrtBackend::Process( return Future( HandleGetDefaultDeviceAssignmentRequest(std::move(request))); default: + LOG(ERROR) << "Got unimplemented request type: " + << request->DebugString(); return Future(absl::UnimplementedError(absl::StrCat( "Got unimplemented request type: ", request->request_case()))); } @@ -263,8 +266,14 @@ absl::StatusOr IfrtBackend::HandleInit( init_resp->set_runtime_type(AsProtoStringData(client_->runtime_type())); init_resp->set_process_index(client_->process_index()); - for (auto* device : client_->devices()) { - InitResponse::Device* d = init_resp->add_devices(); + absl::Span all_devices; + if (version_.protocol_version() < 7) { + all_devices = client_->devices(); + } else { + all_devices = client_->GetAllDevices(); + } + for (auto* device : all_devices) { + InitResponse::Device* d = init_resp->add_all_devices(); d->set_id(device->Id().value()); d->set_device_kind(AsProtoStringData(device->Kind())); if (auto default_memory = device->DefaultMemory(); default_memory.ok()) { @@ -286,13 +295,17 @@ absl::StatusOr IfrtBackend::HandleInit( } else { *d->mutable_attributes() = device->Attributes().ToProto(); } + + if (device->IsAddressable()) { + init_resp->add_addressable_device_ids(device->Id().value()); + } } - for (auto* addressable_device : client_->addressable_devices()) { - init_resp->add_addressable_device_ids(addressable_device->Id().value()); + for (auto* device : client_->devices()) { + init_resp->add_primary_device_ids(device->Id().value()); } absl::flat_hash_map memories; - for (auto* device : client_->devices()) { + for (auto* device : all_devices) { for (xla::ifrt::Memory* memory : device->Memories()) { const auto [it, inserted] = memories.insert({memory->Id().value(), memory}); @@ -469,12 +482,23 @@ IfrtBackend::HandleAssembleArrayFromSingleDeviceArraysRequest( auto sharding, Sharding::FromProto( absl::bind_front(&Client::LookupDevice, client_.get()), assemble_request.sharding())); - TF_ASSIGN_OR_RETURN(auto semantics, FromArrayCopySemanticsProto( - assemble_request.copy_semantics())); + TF_ASSIGN_OR_RETURN( + auto array_copy_semantics, + FromArrayCopySemanticsProto(assemble_request.copy_semantics())); + SingleDeviceShardSemantics single_device_shard_semantics; + if (version_.protocol_version() < 8) { + single_device_shard_semantics = SingleDeviceShardSemantics::kAllShards; + } else { + TF_ASSIGN_OR_RETURN(single_device_shard_semantics, + FromSingleDeviceShardSemanticsProto( + assemble_request.single_device_shard_semantics())); + } - TF_ASSIGN_OR_RETURN(auto array, client_->AssembleArrayFromSingleDeviceArrays( - std::move(shape), std::move(sharding), - absl::MakeSpan(arrays), semantics)); + TF_ASSIGN_OR_RETURN( + auto array, + client_->AssembleArrayFromSingleDeviceArrays( + std::move(shape), std::move(sharding), absl::MakeSpan(arrays), + array_copy_semantics, single_device_shard_semantics)); auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); @@ -605,15 +629,24 @@ Future IfrtBackend::HandleCopyToHostBufferRequest( absl::StatusOr IfrtBackend::HandleDisassembleIntoSingleDeviceArraysRequest( std::unique_ptr request) { - TF_ASSIGN_OR_RETURN( - auto array, - GetArray(request->disassemble_into_single_device_arrays_request() - .array_handle())); + const auto& disassemble_request = + request->disassemble_into_single_device_arrays_request(); + TF_ASSIGN_OR_RETURN(auto array, GetArray(disassemble_request.array_handle())); + SingleDeviceShardSemantics single_device_shard_semantics; + if (version_.protocol_version() < 8) { + single_device_shard_semantics = SingleDeviceShardSemantics::kAllShards; + } else { + TF_ASSIGN_OR_RETURN( + single_device_shard_semantics, + FromSingleDeviceShardSemanticsProto( + disassemble_request.single_device_shard_semantics())); + } // TODO(b/282757875): Consider other ArrayCopySemantics. TF_ASSIGN_OR_RETURN(auto single_device_arrays, array->DisassembleIntoSingleDeviceArrays( - xla::ifrt::ArrayCopySemantics::kAlwaysCopy)); + xla::ifrt::ArrayCopySemantics::kAlwaysCopy, + single_device_shard_semantics)); // Set up an IfrtResponse with pre-allocated space for the right number of // single device array handles. @@ -764,14 +797,32 @@ IfrtBackend::HandleFullyReplicatedShardRequest( absl::StatusOr IfrtBackend::HandleDeleteArrayRequest(std::unique_ptr request) { - TF_ASSIGN_OR_RETURN(auto array, - GetArray(request->delete_array_request().array_handle())); + std::vector bad_handles; + std::vector> deletion_futures; + + auto delete_handle = [&](uint64_t handle) { + auto array = GetArray(handle); + if (array.ok()) { + deletion_futures.push_back(array.value()->Delete()); + } else { + deletion_futures.push_back(Future<>(array.status())); + } + }; + + if (request->delete_array_request().has_array_handle_deprecated()) { + // TODO(b/296144873): After removing array_handle_deprecated(), move + // delete_handle's definition to the single place it is used. + delete_handle(request->delete_array_request().array_handle_deprecated()); + } + + for (auto array_handle : request->delete_array_request().array_handle()) { + delete_handle(array_handle); + } - auto deletion_future = array->Delete(); uint64_t future_handle = handle_generator_.New(); { absl::MutexLock lock(&futures_mutex_); - futures_.insert({future_handle, std::move(deletion_future)}); + futures_.insert({future_handle, JoinFutures(deletion_futures)}); } auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); @@ -793,16 +844,30 @@ IfrtBackend::HandleIsArrayDeletedRequest(std::unique_ptr request) { absl::StatusOr IfrtBackend::HandleDestructArrayRequest(std::unique_ptr request) { + std::vector bad_handles; { absl::MutexLock lock(&arrays_mutex_); - bool deleted = - arrays_.erase(request->destruct_array_request().array_handle()); - if (!deleted) { - return absl::NotFoundError( - absl::StrCat("Unknown array handle: ", - request->destruct_array_request().array_handle())); + for (const uint64_t array_handle : + request->destruct_array_request().array_handle()) { + if (!arrays_.erase(array_handle)) { + bad_handles.push_back(array_handle); + } + } + + if (request->destruct_array_request().has_array_handle_deprecated()) { + const uint64_t array_handle = + request->destruct_array_request().array_handle_deprecated(); + if (!arrays_.erase(array_handle)) { + bad_handles.push_back(array_handle); + } } } + + if (!bad_handles.empty()) { + return absl::NotFoundError(absl::StrCat("Unknown array handle(s): ", + absl::StrJoin(bad_handles, ","))); + } + auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id()); // Currently DestructArrayResponse is an empty message, but proxy clients may @@ -1023,6 +1088,12 @@ IfrtBackend::HandleLoadedExecutableExecuteRequest( TF_ASSIGN_OR_RETURN(auto execute_options, xla::ifrt::LoadedExecutable::ExecuteOptions::FromProto( execute.execute_options())); + // Force the old behavior where `fill_status` was implicitly true before + // protocol version 6. Can be cleaned up once version 6 is outside the + // compatibility window. + if (version_.protocol_version() < 6) { + execute_options.fill_status = true; + } std::optional> devices; if (!execute.device_ids().empty()) { @@ -1047,7 +1118,10 @@ IfrtBackend::HandleLoadedExecutableExecuteRequest( // `CheckFuture` exactly once to check for its status and erase it. In future, // we may introduce separate mechanisms to remove futures from `futures_` // without checking its status for situations where futures are not used. - { + // + // Starting protocol version 6, the client tells the server whether the status + // future needs to be populated or not. + if (version_.protocol_version() < 6 || execute_options.fill_status) { absl::MutexLock lock(&futures_mutex_); execute_response->set_status_handle(handle_generator_.New()); futures_.insert( diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc index 987708935f3da2..6758f938ca87ef 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc @@ -70,6 +70,8 @@ #include "xla/test.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/error_codes.pb.h" +#include "xla/tsl/protobuf/status.pb.h" #include "xla/xla_data.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" @@ -78,8 +80,6 @@ #include "tsl/platform/status_to_from_proto.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/error_codes.pb.h" -#include "tsl/protobuf/status.pb.h" namespace xla { namespace ifrt { @@ -245,11 +245,14 @@ class IfrtBackendHandlerTest : public IfrtBackendTest { for (int i = 0; i < 2; ++i) { auto mock_device = std::make_unique(); ON_CALL(*mock_device, Id()).WillByDefault(Return(DeviceId(i))); + ON_CALL(*mock_device, IsAddressable()).WillByDefault(Return(true)); raw_device_ptrs.push_back(mock_device.get()); mock_devices_.push_back(std::move(mock_device)); } ON_CALL(*mock_client, devices()).WillByDefault(Return(raw_device_ptrs)); + ON_CALL(*mock_client, GetAllDevices()) + .WillByDefault(Return(raw_device_ptrs)); ON_CALL(*mock_client, LookupDevice(_)) .WillByDefault( Invoke([this](DeviceId id) -> absl::StatusOr { @@ -345,6 +348,9 @@ class IfrtBackendHandlerTest : public IfrtBackendTest { } absl::Status CheckFuture(uint64_t handle) { + if (handle == 0) { + return absl::InternalError("Test error, future handle is 0"); + } auto request = NewIfrtRequest(NewOpId()); request->mutable_check_future_request()->set_future_handle(handle); TF_ASSIGN_OR_RETURN(std::shared_ptr response, @@ -431,7 +437,7 @@ TEST_P(IfrtBackendHandlerTest, Init) { platform_id: 42 process_index: 1 runtime_type: "ifrt-service" - devices { + all_devices { id: 0 device_kind: "mock" default_memory_id: 0 @@ -441,7 +447,7 @@ TEST_P(IfrtBackendHandlerTest, Init) { value { string_value: "device0" } } } - devices { + all_devices { id: 1 device_kind: "mock" default_memory_id: 1 @@ -463,6 +469,53 @@ TEST_P(IfrtBackendHandlerTest, Init) { } } )pb")))))); + } else if (Version().protocol_version() < 7) { + EXPECT_THAT(CallBackend(std::move(request)), + IsOkAndHolds(Pointee( + Partially(IgnoringRepeatedFieldOrdering(EquivToProto(R"pb( + init_response { + session_id: 12345 + platform_name: "ifrt_backend" + platform_version: "n/a" + platform_id: 42 + process_index: 1 + runtime_type: "ifrt-service" + all_devices { + id: 0 + device_kind: "mock" + default_memory_id: 0 + memory_ids: [ 0 ] + attributes { + attributes { + key: "name" + value { string_value: "device0" } + } + } + } + all_devices { + id: 1 + device_kind: "mock" + default_memory_id: 1 + memory_ids: [ 1 ] + attributes { + attributes { + key: "name" + value { string_value: "device1" } + } + } + } + memories { + id: 0 + memory_space_kind: "mock" + device_ids: [ 0 ] + } + memories { + id: 1 + memory_space_kind: "mock" + device_ids: [ 1 ] + } + } + )pb")))))); } else { EXPECT_THAT(CallBackend(std::move(request)), IsOkAndHolds(Pointee( @@ -474,7 +527,7 @@ TEST_P(IfrtBackendHandlerTest, Init) { platform_id: 42 process_index: 1 runtime_type: "ifrt-service" - devices { + all_devices { id: 0 device_kind: "mock" default_memory_id: 0 @@ -486,7 +539,7 @@ TEST_P(IfrtBackendHandlerTest, Init) { } } } - devices { + all_devices { id: 1 device_kind: "mock" default_memory_id: 1 @@ -498,6 +551,7 @@ TEST_P(IfrtBackendHandlerTest, Init) { } } } + primary_device_ids: [ 0, 1 ] memories { id: 0 memory_space_kind: "mock" @@ -527,7 +581,7 @@ TEST_P(IfrtBackendHandlerTest, DisassembleIntoSingleDeviceArraysSucceeds) { single_device_arrays.push_back(tsl::MakeRef()); tsl::RCReference source_mock_array = tsl::MakeRef(); - EXPECT_CALL(*source_mock_array, DisassembleIntoSingleDeviceArrays(_)) + EXPECT_CALL(*source_mock_array, DisassembleIntoSingleDeviceArrays(_, _)) .WillOnce(Return(std::move(single_device_arrays))); // Inject the mock_array. @@ -536,8 +590,15 @@ TEST_P(IfrtBackendHandlerTest, DisassembleIntoSingleDeviceArraysSucceeds) { // Disassemble. auto disassemble_request = NewIfrtRequest(NewOpId()); - disassemble_request->mutable_disassemble_into_single_device_arrays_request() - ->set_array_handle(array_handle); + auto* disassemble_into_single_device_arrays = + disassemble_request + ->mutable_disassemble_into_single_device_arrays_request(); + disassemble_into_single_device_arrays->set_array_handle(array_handle); + if (Version().protocol_version() >= 8) { + disassemble_into_single_device_arrays->set_single_device_shard_semantics( + proto::SingleDeviceShardSemantics:: + SINGLE_DEVICE_SHARD_SEMANTICS_ALL_SHARDS); + } TF_ASSERT_OK_AND_ASSIGN(auto disassemble_response, CallBackend(std::move(disassemble_request))); @@ -596,13 +657,25 @@ TEST_P(IfrtBackendHandlerTest, MakeArrayFromHostBufferSuccess) { TEST_P(IfrtBackendHandlerTest, AssembleArrayFromSingleDeviceArrays) { auto ifrt_request = NewIfrtRequest(NewOpId()); { - ASSERT_TRUE(TextFormat::ParseFromString( - R"pb( - shape { dims: [ 2, 2 ] } - copy_semantics: ARRAY_COPY_SEMANTICS_ALWAYS_COPY - )pb", - ifrt_request - ->mutable_assemble_array_from_single_device_arrays_request())); + if (Version().protocol_version() < 8) { + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb( + shape { dims: [ 2, 2 ] } + copy_semantics: ARRAY_COPY_SEMANTICS_ALWAYS_COPY + )pb", + ifrt_request + ->mutable_assemble_array_from_single_device_arrays_request())); + } else { + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb( + shape { dims: [ 2, 2 ] } + copy_semantics: ARRAY_COPY_SEMANTICS_ALWAYS_COPY + single_device_shard_semantics: + SINGLE_DEVICE_SHARD_SEMANTICS_ALL_SHARDS + )pb", + ifrt_request + ->mutable_assemble_array_from_single_device_arrays_request())); + } TF_ASSERT_OK_AND_ASSIGN(auto* device, mock_client_->LookupDevice(DeviceId(1))); TF_ASSERT_OK_AND_ASSIGN( @@ -618,17 +691,26 @@ TEST_P(IfrtBackendHandlerTest, AssembleArrayFromSingleDeviceArrays) { single_device_arrays.push_back(array); TF_ASSERT_OK_AND_ASSIGN(uint64_t array_handle, MakeTestArray(array)); - ifrt_request->mutable_assemble_array_from_single_device_arrays_request() - ->add_single_device_array_handles(array_handle); + auto* assemble_array_from_single_device_arrays = + ifrt_request + ->mutable_assemble_array_from_single_device_arrays_request(); + assemble_array_from_single_device_arrays->add_single_device_array_handles( + array_handle); + if (Version().protocol_version() >= 8) { + assemble_array_from_single_device_arrays + ->set_single_device_shard_semantics( + proto::SingleDeviceShardSemantics:: + SINGLE_DEVICE_SHARD_SEMANTICS_ALL_SHARDS); + } } tsl::RCReference result = tsl::MakeRef(); const Shape expected_shape({2, 2}); - EXPECT_CALL(*mock_client_, - AssembleArrayFromSingleDeviceArrays( - expected_shape, _, ElementsAreArray(single_device_arrays), _)) + EXPECT_CALL(*mock_client_, AssembleArrayFromSingleDeviceArrays( + expected_shape, _, + ElementsAreArray(single_device_arrays), _, _)) .WillOnce(Return(std::move(result))); TF_ASSERT_OK_AND_ASSIGN(auto response, CallBackend(std::move(ifrt_request))); @@ -691,7 +773,7 @@ TEST_P(IfrtBackendHandlerTest, "messages - 1234"; tsl::RCReference source_mock_array = tsl::MakeRef(); - EXPECT_CALL(*source_mock_array, DisassembleIntoSingleDeviceArrays(_)) + EXPECT_CALL(*source_mock_array, DisassembleIntoSingleDeviceArrays(_, _)) .WillOnce(Return(absl::UnknownError(kDisassembleErrorMessage))); // Set up the mock client to return the source_mock_array when the test tries @@ -701,8 +783,15 @@ TEST_P(IfrtBackendHandlerTest, // Disassembly must fail with the error we injected. auto disassemble_request = NewIfrtRequest(NewOpId()); - disassemble_request->mutable_disassemble_into_single_device_arrays_request() - ->set_array_handle(array_handle); + auto* disassemble_into_single_device_arrays = + disassemble_request + ->mutable_disassemble_into_single_device_arrays_request(); + disassemble_into_single_device_arrays->set_array_handle(array_handle); + if (Version().protocol_version() >= 8) { + disassemble_into_single_device_arrays->set_single_device_shard_semantics( + proto::SingleDeviceShardSemantics:: + SINGLE_DEVICE_SHARD_SEMANTICS_ALL_SHARDS); + } ASSERT_THAT( CallBackend(std::move(disassemble_request)), StatusIs(absl::StatusCode::kUnknown, StrEq(kDisassembleErrorMessage))); @@ -913,26 +1002,46 @@ TEST_P(IfrtBackendHandlerTest, } TEST_P(IfrtBackendHandlerTest, DeleteArraySuccess) { - tsl::RCReference mock_array = - tsl::MakeRef(); - EXPECT_CALL(*mock_array, Delete()) + auto mock_array1 = tsl::MakeRef(); + EXPECT_CALL(*mock_array1, Delete()) .WillOnce(Return(Future<>(absl::OkStatus()))); - TF_ASSERT_OK_AND_ASSIGN(auto array_handle, - MakeTestArray(std::move(mock_array))); + auto mock_array2 = tsl::MakeRef(); + EXPECT_CALL(*mock_array2, Delete()) + .WillOnce(Return(Future<>(absl::OkStatus()))); + + TF_ASSERT_OK_AND_ASSIGN(auto array_handle1, + MakeTestArray(std::move(mock_array1))); + TF_ASSERT_OK_AND_ASSIGN(auto array_handle2, + MakeTestArray(std::move(mock_array2))); uint64_t op_id = NewOpId(); auto ifrt_request = NewIfrtRequest(op_id); - ifrt_request->mutable_delete_array_request()->set_array_handle(array_handle); + ifrt_request->mutable_delete_array_request()->add_array_handle(array_handle1); + ifrt_request->mutable_delete_array_request()->add_array_handle(array_handle2); TF_ASSERT_OK_AND_ASSIGN(auto resp, CallBackend(std::move(ifrt_request))); EXPECT_THAT(tsl::StatusFromProto(resp->response_metadata().status()), IsOk()); - EXPECT_NE(resp->delete_array_response().deletion_future_handle(), 0); + TF_EXPECT_OK( + CheckFuture(resp->delete_array_response().deletion_future_handle())); } -TEST_P(IfrtBackendHandlerTest, DeleteArrayFailsWithNonExistentArrayHandle) { +TEST_P(IfrtBackendHandlerTest, + DeleteArrayReturnsFutureWithNonExistentArrayHandle) { + // Create one existing array. + auto mock_array1 = tsl::MakeRef(); + EXPECT_CALL(*mock_array1, Delete()) + .WillOnce(Return(Future<>(absl::OkStatus()))); + TF_ASSERT_OK_AND_ASSIGN(auto real_handle, + MakeTestArray(std::move(mock_array1))); + + constexpr int kBadHandle = 400; auto ifrt_request = NewIfrtRequest(NewOpId()); - ifrt_request->mutable_delete_array_request()->set_array_handle(0); - EXPECT_THAT(CallBackend(std::move(ifrt_request)), - StatusIs(absl::StatusCode::kNotFound)); + ifrt_request->mutable_delete_array_request()->add_array_handle(real_handle); + ifrt_request->mutable_delete_array_request()->add_array_handle(kBadHandle); + TF_ASSERT_OK_AND_ASSIGN(auto resp, CallBackend(std::move(ifrt_request))); + + EXPECT_THAT( + CheckFuture(resp->delete_array_response().deletion_future_handle()), + StatusIs(absl::StatusCode::kNotFound)); } TEST_P(IfrtBackendHandlerTest, @@ -968,14 +1077,20 @@ TEST_P(IfrtBackendHandlerTest, IsDeleteFailsForNonExistentArrays) { } TEST_P(IfrtBackendHandlerTest, DestructArrayTest) { - tsl::RCReference mock_array = + tsl::RCReference mock_array1 = tsl::MakeRef(); - TF_ASSERT_OK_AND_ASSIGN(auto array_handle, - MakeTestArray(std::move(mock_array))); + TF_ASSERT_OK_AND_ASSIGN(auto array_handle1, + MakeTestArray(std::move(mock_array1))); + tsl::RCReference mock_array2 = + tsl::MakeRef(); + TF_ASSERT_OK_AND_ASSIGN(auto array_handle2, + MakeTestArray(std::move(mock_array2))); auto ifrt_request = NewIfrtRequest(NewOpId()); - ifrt_request->mutable_destruct_array_request()->set_array_handle( - array_handle); + ifrt_request->mutable_destruct_array_request()->add_array_handle( + array_handle1); + ifrt_request->mutable_destruct_array_request()->add_array_handle( + array_handle2); TF_ASSERT_OK_AND_ASSIGN(auto ifrt_resp, CallBackend(std::move(ifrt_request))); EXPECT_TRUE(ifrt_resp->has_destruct_array_response()); @@ -983,8 +1098,8 @@ TEST_P(IfrtBackendHandlerTest, DestructArrayTest) { // handle no longer exists on the server, (2) DestructArray fails for // non-existent arrays and (3) DestructArray is not idempotent. ifrt_request = NewIfrtRequest(NewOpId()); - ifrt_request->mutable_destruct_array_request()->set_array_handle( - array_handle); + ifrt_request->mutable_destruct_array_request()->add_array_handle( + array_handle1); EXPECT_THAT(CallBackend(std::move(ifrt_request)), StatusIs(absl::StatusCode::kNotFound)); } @@ -1063,7 +1178,7 @@ TEST_P(IfrtBackendHandlerTest, LoadedExecutableMetadata) { EXPECT_CALL(*executable, GetOutputShardings()) .WillOnce(Return(std::vector{op_sharding1})); - std::vector> parameter_layouts; + std::vector> parameter_layouts; parameter_layouts.push_back(std::make_unique( xla::LayoutUtil::MakeDescendingLayout(/*rank=*/1))); parameter_layouts.push_back(std::make_unique( @@ -1071,7 +1186,7 @@ TEST_P(IfrtBackendHandlerTest, LoadedExecutableMetadata) { EXPECT_CALL(*executable, GetParameterLayouts()) .WillOnce(Return(std::move(parameter_layouts))); - std::vector> output_layouts; + std::vector> output_layouts; output_layouts.push_back(std::make_unique( xla::LayoutUtil::MakeDescendingLayout(/*rank=*/2))); EXPECT_CALL(*executable, GetOutputLayouts()) @@ -1196,9 +1311,10 @@ TEST_P(IfrtBackendHandlerTest, LoadedExecutableExecute) { execute_request->add_args_handles(arg_handle); } execute_request->set_loaded_executable_handle(handle); - TF_ASSERT_OK_AND_ASSIGN( - *execute_request->mutable_execute_options(), - xla::ifrt::LoadedExecutable::ExecuteOptions().ToProto()); + xla::ifrt::LoadedExecutable::ExecuteOptions execute_options; + execute_options.fill_status = true; + TF_ASSERT_OK_AND_ASSIGN(*execute_request->mutable_execute_options(), + execute_options.ToProto()); TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr response, CallBackend(std::move(request))); diff --git a/third_party/xla/xla/python/ifrt_proxy/server/version.h b/third_party/xla/xla/python/ifrt_proxy/server/version.h index 686fe78993bfd2..c1235ed525fbd3 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/version.h +++ b/third_party/xla/xla/python/ifrt_proxy/server/version.h @@ -26,7 +26,7 @@ namespace proxy { // LINT.IfChange // TODO(b/296144873): Document the version upgrade policy. inline constexpr int kServerMinVersion = 1; -inline constexpr int kServerMaxVersion = 4; +inline constexpr int kServerMaxVersion = 8; // LINT.ThenChange(//tensorflow/compiler/xla/python/ifrt_proxy/common/VERSION.md) // Returns a version that both the client and the server support, or an error if diff --git a/third_party/xla/xla/python/jax_jit.h b/third_party/xla/xla/python/jax_jit.h index fa0fc2b78e89a0..79552702765061 100644 --- a/third_party/xla/xla/python/jax_jit.h +++ b/third_party/xla/xla/python/jax_jit.h @@ -134,7 +134,7 @@ H AbslHashValue(H h, const ArgumentSignature& s) { const auto& static_arg = s.static_args[i]; Py_hash_t hash; try { - hash = xla::nb_hash(static_arg); + hash = nanobind::hash(static_arg); } catch (const nanobind::python_error& e) { if (!e.matches(PyExc_TypeError)) throw; throw std::invalid_argument(absl::StrCat( diff --git a/third_party/xla/xla/python/mlir.cc b/third_party/xla/xla/python/mlir.cc index e4aab79317e9db..36e19d2e7f94a8 100644 --- a/third_party/xla/xla/python/mlir.cc +++ b/third_party/xla/xla/python/mlir.cc @@ -24,7 +24,6 @@ limitations under the License. #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" #include "mlir/Dialect/Func/Extensions/AllExtensions.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" @@ -36,11 +35,10 @@ limitations under the License. #include "nanobind/nanobind.h" #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep -#include "shardy/dialect/sdy/ir/dialect.h" -#include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/Serialization.h" #include "stablehlo/dialect/StablehloOps.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "xla/mlir/utils/error_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" @@ -48,9 +46,6 @@ limitations under the License. #include "xla/pjrt/status_casters.h" #include "xla/python/refine_polymorphic_shapes.h" #include "xla/service/llvm_ir/llvm_util.h" -#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h" -#include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" -#include "xla/tsl/framework/mlir/status_scoped_diagnostic_handler.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" @@ -60,34 +55,6 @@ namespace nb = nanobind; namespace xla { namespace { -absl::StatusOr> ParseModule( - mlir::MLIRContext* context, std::string_view str) { - mlir::OwningOpRef module; - context->loadDialect(); - context->loadDialect(); - context->loadDialect(); - context->loadDialect(); - context->loadDialect(); - context->loadDialect(); - - mlir::DialectRegistry registry; - mlir::func::registerAllExtensions(registry); - context->appendDialectRegistry(registry); - - mlir::BaseScopedDiagnosticHandler diagnostic_handler(context); - module = mlir::parseSourceString( - llvm::StringRef(str.data(), str.size()), context); - if (!module) { - return diagnostic_handler.ConsumeStatus(); - } - if (failed(module->verifyInvariants())) { - VLOG(1) << "MLIR verification failed."; - module->dump(); - return diagnostic_handler.ConsumeStatus(); - } - return module; -} - std::string PrintModule(mlir::ModuleOp module) { std::string s; llvm::raw_string_ostream os(s); @@ -146,14 +113,10 @@ absl::StatusOr PyMlirModuleToXlaComputation( std::string_view mlir_module, bool use_tuple_args, bool return_tuple) { mlir::MLIRContext context; TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, - ParseModule(&context, mlir_module)); + ParseMlirModuleString(mlir_module, context)); XlaComputation computation; - mlir::PassManager pm(&context); - // SDY dialect may be part of the module which XLA doesn't know about. Export - // it. - xla::sdy::addSdyRoundTripExportPipeline(pm); - TF_RETURN_IF_ERROR(tsl::StatusScopedDiagnosticHandler(&context).consumeStatus( - pm.run(*module))); + // SDY dialect may be part of the module which XLA doesn't know about. + TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module)); TF_RETURN_IF_ERROR(MlirToXlaComputation(*module, computation, use_tuple_args, return_tuple, /*use_shardy=*/false)); @@ -165,13 +128,13 @@ absl::StatusOr PyMhloToStablehlo(std::string_view mlir_module) { if (VLOG_IS_ON(3)) context.disableMultithreading(); // JAX can be customized in a way that involves operations from custom // dialects showing up in JAX IR. - // `ParseModule` won't know about these dialects, but that's fine since we - // just want to convert MHLO ops to StableHLO ops here and leave everything - // else unchanged. + // `ParseMlirModuleString` won't know about these dialects, but that's fine + // since we just want to convert MHLO ops to StableHLO ops here and leave + // everything else unchanged. // In order to achieve that, we're allowing unregistered dialects here. context.allowUnregisteredDialects(true); TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, - ParseModule(&context, mlir_module)); + ParseMlirModuleString(mlir_module, context)); mlir::PassManager pm(&context); if (VLOG_IS_ON(3)) EnablePrintBeforeAndAfter(pm); pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); @@ -192,8 +155,8 @@ absl::StatusOr PyStablehloToMhlo(const nb::bytes& mlir_module) { context.allowUnregisteredDialects(true); TF_ASSIGN_OR_RETURN( mlir::OwningOpRef module, - ParseModule(&context, - std::string_view(mlir_module.c_str(), mlir_module.size()))); + ParseMlirModuleString( + std::string_view(mlir_module.c_str(), mlir_module.size()), context)); mlir::PassManager pm(&context); if (VLOG_IS_ON(3)) EnablePrintBeforeAndAfter(pm); pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); @@ -212,7 +175,7 @@ absl::StatusOr PySerializePortableArtifact( mlir::MLIRContext context; if (VLOG_IS_ON(3)) context.disableMultithreading(); TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, - ParseModule(&context, mlir_module)); + ParseMlirModuleString(mlir_module, context)); // Serialize portable artifact TF_ASSIGN_OR_RETURN( diff --git a/third_party/xla/xla/python/nb_helpers.cc b/third_party/xla/xla/python/nb_helpers.cc index 80e3a6ee6d11f6..6a241ca79cf6ff 100644 --- a/third_party/xla/xla/python/nb_helpers.cc +++ b/third_party/xla/xla/python/nb_helpers.cc @@ -23,14 +23,6 @@ namespace nb = nanobind; namespace xla { -Py_hash_t nb_hash(nb::handle o) { - Py_hash_t h = PyObject_Hash(o.ptr()); - if (h == -1) { - throw nb::python_error(); - } - return h; -} - bool nb_isinstance(nanobind::handle inst, nanobind::handle cls) { int ret = PyObject_IsInstance(inst.ptr(), cls.ptr()); if (ret == -1) { diff --git a/third_party/xla/xla/python/nb_helpers.h b/third_party/xla/xla/python/nb_helpers.h index 845adb8692cf5e..c8d69acaa7bdf5 100644 --- a/third_party/xla/xla/python/nb_helpers.h +++ b/third_party/xla/xla/python/nb_helpers.h @@ -23,10 +23,6 @@ limitations under the License. namespace xla { -// Calls Python hash() on an object. -// TODO(phawkins): consider upstreaming this to nanobind. -Py_hash_t nb_hash(nanobind::handle o); - // Calls Python isinstance(inst, cls). // TODO(phawkins): consider upstreaming this to nanobind. bool nb_isinstance(nanobind::handle inst, nanobind::handle cls); diff --git a/third_party/xla/xla/python/ops.cc b/third_party/xla/xla/python/ops.cc index 87f5333645a702..d199df47d8285c 100644 --- a/third_party/xla/xla/python/ops.cc +++ b/third_party/xla/xla/python/ops.cc @@ -31,17 +31,17 @@ limitations under the License. #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/tuple.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "xla/client/lib/approx_topk.h" -#include "xla/client/lib/approx_topk_shape.h" -#include "xla/client/lib/comparators.h" -#include "xla/client/lib/lu_decomposition.h" -#include "xla/client/lib/math.h" -#include "xla/client/lib/qr.h" -#include "xla/client/lib/self_adjoint_eig.h" -#include "xla/client/lib/sorting.h" -#include "xla/client/lib/svd.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/approx_topk.h" +#include "xla/hlo/builder/lib/approx_topk_shape.h" +#include "xla/hlo/builder/lib/comparators.h" +#include "xla/hlo/builder/lib/lu_decomposition.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/lib/qr.h" +#include "xla/hlo/builder/lib/self_adjoint_eig.h" +#include "xla/hlo/builder/lib/sorting.h" +#include "xla/hlo/builder/lib/svd.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/pjrt/status_casters.h" #include "xla/python/nb_absl_span.h" // IWYU pragma: keep #include "xla/python/nb_helpers.h" diff --git a/third_party/xla/xla/python/outfeed_receiver.cc b/third_party/xla/xla/python/outfeed_receiver.cc deleted file mode 100644 index 539d1f2df6c308..00000000000000 --- a/third_party/xla/xla/python/outfeed_receiver.cc +++ /dev/null @@ -1,535 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/python/outfeed_receiver.h" - -#include - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/base/thread_annotations.h" -#include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_format.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "xla/client/executable_build_options.h" -#include "xla/client/sharding_builder.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" -#include "xla/layout_util.h" -#include "xla/literal.h" -#include "xla/pjrt/pjrt_client.h" -#include "xla/pjrt/pjrt_executable.h" -#include "xla/python/pjrt_ifrt/pjrt_client.h" -#include "xla/python/pjrt_ifrt/pjrt_device.h" -#include "xla/service/computation_placer.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/util.h" -#include "tsl/platform/casts.h" -#include "tsl/platform/env.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/threadpool.h" -#include "tsl/profiler/lib/traceme.h" - -// Implementation notes: -// -// Startup: -// ------- -// -// The startup is initiated by a call from Python to StartOutfeedReceiver. For -// each local device there is one thread for listening for outfeeds from the -// device, one queue of received outfeeds, and one thread for invoking the -// Python callbacks. -// -// Framing protocol -// ---------------- -// -// The outfeed mechanism has a single channel and the receiver must know -// exactly the shape and number of outfeed operations issued by the compiled -// code. This makes it hard to use outfeed in conditionals and loops and -// especially when outfeeding different-shaped data. -// -// To address this, when we compile the code we capture the shape of the -// data being outfed, and we generate a consumer ID (uint32_t) that is unique -// across the lifetime of the program to: the Python callable to callback to, -// the shape of the arguments, the keyword arguments to pass to the callable. -// Each outfeed payload is preceeded by a header (of shape u32[2]) with a -// special first value and the consumer ID. We maintain a registry of shapes -// by consumer ID. When receiving we lookup the shape by consumer ID, and then -// we read the payload. -// -// Back pressure: -// -------------- -// -// We maintain a sum of the bytes from all the data waiting in the callback -// queues. The listening threads will wait for the sum to drop below a -// configurable threshold, default 256Mb. While the listening thread is waiting, -// on CPU and GPU the next outfeed operation from the device will block. On -// TPU there is a buffer, but eventually the TPU will also block. -// -// Shutdown: -// --------- -// -// The shutdown is initiated automatically when the last reference to the -// outfeed receiver object is dropped, and the Python garbage collector invokes -// the destructor. -// -// The shutdown sequence is implemented as follows: -// * we enqueue on all devices a computation that outfeeds a special header -// with customer ID kOutfeedCidShutdown. -// * when each listening threads gets the shutdown header, it decrements -// a counter of listening threads, and it -// enqueues a special shutdown callback. -// * when each callback thread gets the shutdown callback marker, it terminates. -// * the shutdown code waits until all threads terminate. -// -// Since we currently keep the shape registry in the OutfeedReceiver, it is -// not safe to replace the OutfeedReceiver instance during the lifetime of -// the JAX program, or else previously cached jitted computations may refer -// to previously cached shapes. This can be solved, but for now we disallow -// replacing the OutfeedReceiver, and do not provide a Shutdown API to the -// Python program. - -namespace xla { - -// The header contains: -// 0. kOutfeedHeaderStart -// 1. consumer id -int constexpr kOutfeedHeaderWords = 2; -uint32_t constexpr kOutfeedHeaderStart = 271828; -// Special consumer IDs, without outfeed payload. -uint32_t constexpr kOutfeedCidShutdown = 0; - -// Encapsulates data received from a device outfeed. -class OutfeedData { - public: - OutfeedData(ifrt::PjRtDevice* device, uint32_t consumer_id, Shape shape) - : device_(device), - consumer_id_(consumer_id), - shape_(shape), - literal_(nullptr), - literal_size_bytes_(0) {} - - ifrt::PjRtDevice* device() { return device_; } - uint32_t consumer_id() const { return consumer_id_; } - Shape shape() const { return shape_; } - std::unique_ptr literal() { - CHECK(literal_); - return std::move(literal_); - } - - void SetLiteral(std::unique_ptr literal); - - ssize_t literal_size_bytes() const { return literal_size_bytes_; } - - std::string DebugString() const; - - private: - ifrt::PjRtDevice* device_; - uint32_t consumer_id_; - Shape shape_; - std::unique_ptr literal_; - ssize_t literal_size_bytes_; -}; - -void OutfeedData::SetLiteral(std::unique_ptr literal) { - literal_ = std::move(literal); - shape_ = literal_->shape(); - int total_size_bytes = 0; - ShapeUtil::ForEachSubshape( - shape_, [&](const Shape& literal_subshape, const ShapeIndex& index) { - if (!literal_subshape.IsTuple()) { - total_size_bytes += ShapeUtil::ByteSizeOf(literal_subshape, 8); - } - }); - literal_size_bytes_ = total_size_bytes; -} - -std::string OutfeedData::DebugString() const { - return absl::StrFormat("dev=%s; cons=%d; shape=%s", device_->DebugString(), - consumer_id_, shape_.ToString()); -} - -class OutfeedReceiverImpl { - public: - OutfeedReceiverImpl( - OutfeedReceiver::Callback callback, - absl::Span clients, - ssize_t max_callback_queue_size_bytes, - const std::optional& executable_build_options); - - OutfeedReceiverImpl(const OutfeedReceiverImpl&) = delete; - OutfeedReceiverImpl& operator=(const OutfeedReceiverImpl&) = delete; - - // Blocks until all data has been received from devices and all data - // in the queue has been passed to Python. - ~OutfeedReceiverImpl(); - - void Start(); - - absl::StatusOr AddOutfeedToBuilder(XlaBuilder* builder, XlaOp token, - uint32_t consumer_id, - std::vector arrays, - uint32_t device_idx); - - private: - bool CallbackQueueHasSpace() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - return callback_queue_size_bytes_ < max_callback_queue_size_bytes_; - } - - bool ShutdownDone() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - return (num_working_callback_threads_ == 0 && num_listening_threads_ == 0); - } - - void CallbackThreadLoop(int device_idx); - void DeviceListenerThreadLoop(int device_idx); - - // Enqueues to a device an outfeed operation with a shutdown consumer ID. - absl::Status SendShutdownOutfeedHeader(int device_idx); - - // Receives a raw Literal from a device outfeed. - absl::StatusOr> ReceiveRawFromOutfeed( - ifrt::PjRtDevice* device, const Shape& shape); - - // Enqueues received data in the callbaback queue. - void EnqueueReceivedData(uint32_t device_idx, - std::unique_ptr received) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Shuts down the threads. See implementation notes at top of file. - // It is not safe to restart an OutfeedReceiver after shutting down one. - void Shutdown(); - - OutfeedReceiver::Callback callback_; - // The devices on which we are listening. - std::vector devices_; - // Maximum bytes capacity of the ensemble of callback queues. - uint64_t max_callback_queue_size_bytes_; - std::optional executable_build_options_; - - absl::Mutex mu_; - // Registered shapes by consumer id. - // The shape registry must be alive as long as the program exists. - // Right now we tell the user to never restart after Shutdown. - absl::flat_hash_map shape_registry_ ABSL_GUARDED_BY(mu_); - // How many bytes of Literal are in the ensemble of callback queues. - uint64_t callback_queue_size_bytes_ ABSL_GUARDED_BY(mu_); - // Threads listening. - int num_listening_threads_ ABSL_GUARDED_BY(mu_); - bool shutdown_started_ ABSL_GUARDED_BY(mu_); - - // How many callback threads are still working. Used for shutdown. - int num_working_callback_threads_ ABSL_GUARDED_BY(mu_); - - std::vector>> callback_queues_ - ABSL_GUARDED_BY(mu_); - // The threadpool must come last to ensure the queue exists - // when the pool destructor is called. - std::unique_ptr threads_; -}; - -OutfeedReceiverImpl::OutfeedReceiverImpl( - OutfeedReceiver::Callback callback, - absl::Span clients, - ssize_t max_callback_queue_size_bytes, - const std::optional& executable_build_options) - : executable_build_options_(executable_build_options) { - callback_ = callback; - max_callback_queue_size_bytes_ = max_callback_queue_size_bytes; - for (const auto& client : clients) { - for (auto device : client->addressable_devices()) { - devices_.push_back(tensorflow::down_cast(device)); - } - } - CHECK_GT(devices_.size(), 0); - callback_queues_ = - std::vector>>(devices_.size()); - - callback_queue_size_bytes_ = 0; - num_listening_threads_ = 0; - num_working_callback_threads_ = 0; - shutdown_started_ = false; -} - -void OutfeedReceiverImpl::Start() { - { - absl::MutexLock lock(&mu_); - CHECK(!shutdown_started_); - } - - int num_threads = 2 * devices_.size(); - threads_ = std::make_unique( - tsl::Env::Default(), "outfeed_receiver", num_threads); - for (int device_idx = 0; device_idx < devices_.size(); ++device_idx) { - threads_->Schedule( - [this, device_idx]() { DeviceListenerThreadLoop(device_idx); }); - threads_->Schedule( - [this, device_idx]() { CallbackThreadLoop(device_idx); }); - } -} - -void OutfeedReceiverImpl::Shutdown() { - VLOG(2) << "Shutdown start"; - { - absl::MutexLock lock(&mu_); - CHECK(!shutdown_started_); - shutdown_started_ = true; - } - for (int device_idx = 0; device_idx < devices_.size(); ++device_idx) { - TF_CHECK_OK(SendShutdownOutfeedHeader(device_idx)); - } - VLOG(2) << "Shutdown waiting for listening and callback threads to stop"; - absl::MutexLock lock(&mu_); - mu_.Await(absl::Condition(this, &OutfeedReceiverImpl::ShutdownDone)); - VLOG(2) << "Shutdown done"; -} - -OutfeedReceiverImpl::~OutfeedReceiverImpl() { - VLOG(2) << "~OutfeedReceiverImpl"; - Shutdown(); -} - -void OutfeedReceiverImpl::DeviceListenerThreadLoop(int device_idx) { - { - absl::MutexLock lock(&mu_); - ++num_listening_threads_; - } - ifrt::PjRtDevice* device = devices_[device_idx]; - while (true) { - Shape header_shape = ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords}); - std::unique_ptr header = - ReceiveRawFromOutfeed(device, header_shape).value(); - absl::Span header_data = header->data(); - CHECK_EQ(header_data.size(), kOutfeedHeaderWords); - CHECK_EQ(header_data[0], kOutfeedHeaderStart); - uint32_t consumer_id = header_data[1]; - Shape shape; - { - absl::MutexLock lock(&mu_); - auto registered_shape = shape_registry_.find(consumer_id); - if (registered_shape == shape_registry_.end()) { - LOG(FATAL) - << "[" << device->DebugString() - << "] Cannot find registered shape for consumer ID " << consumer_id - << ". Perhaps the code was compiled with a different instance " - << "of OutfeedReceiver."; - } - shape = registered_shape->second; - } - auto received = std::make_unique(device, consumer_id, shape); - VLOG(2) << "Listener received header " << received->DebugString(); - if (consumer_id == kOutfeedCidShutdown) { - VLOG(2) << "[" << device->DebugString() - << "] Listener received shutdown header"; - absl::MutexLock lock(&mu_); - --num_listening_threads_; - VLOG(2) << "[" << device->DebugString() << "] Enqueue shutdown callback"; - EnqueueReceivedData(device_idx, std::move(received)); - return; - } - std::unique_ptr data = - ReceiveRawFromOutfeed(device, shape).value(); - received->SetLiteral(std::move(data)); - absl::MutexLock lock(&mu_); - EnqueueReceivedData(device_idx, std::move(received)); - } -} - -void OutfeedReceiverImpl::EnqueueReceivedData( - uint32_t device_idx, std::unique_ptr received) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) { - mu_.Await(absl::Condition(this, &OutfeedReceiverImpl::CallbackQueueHasSpace)); - ssize_t literal_size_bytes = received->literal_size_bytes(); - callback_queue_size_bytes_ += literal_size_bytes; - VLOG(2) << "Listener enqueues data " << received->DebugString() << " of size " - << literal_size_bytes << " bytes; " - << (1 + callback_queues_[device_idx].size()) - << " callbacks in queue of total size " << callback_queue_size_bytes_ - << " bytes.\n"; - callback_queues_[device_idx].push(std::move(received)); -} - -absl::StatusOr> -OutfeedReceiverImpl::ReceiveRawFromOutfeed(ifrt::PjRtDevice* device, - const Shape& shape) { - auto literal = std::make_unique(shape); - TF_RETURN_IF_ERROR( - device->client()->TransferFromOutfeed(device, literal.get())); - return literal; -} - -void OutfeedReceiverImpl::CallbackThreadLoop(int device_idx) { - const ifrt::PjRtDevice* device = devices_[device_idx]; - { - absl::MutexLock lock(&mu_); - num_working_callback_threads_++; - } - while (true) { - std::unique_ptr received; - { - absl::MutexLock lock(&mu_); - mu_.Await(absl::Condition( - +[](std::queue>* queue) { - return !queue->empty(); - }, - &callback_queues_[device_idx])); - received = std::move(callback_queues_[device_idx].front()); - callback_queues_[device_idx].pop(); - callback_queue_size_bytes_ -= received->literal_size_bytes(); - VLOG(2) << "[" << device->DebugString() << "] Dequeued callback for " - << received->DebugString() << "; " - << callback_queues_[device_idx].size() - << " callbacks in queue of total size " - << callback_queue_size_bytes_ << " bytes.\n"; - } - if (received->consumer_id() == kOutfeedCidShutdown) { - VLOG(2) << "[" << device->DebugString() - << "] Callback loop received shutdown signal"; - { - absl::MutexLock lock(&mu_); - CHECK(callback_queues_[device_idx].empty()); - --num_working_callback_threads_; - } - VLOG(2) << "[" << device->DebugString() << "] Callback loop done"; - return; - } - { - tsl::profiler::TraceMe traceme("OutfeedReceiver::Callback"); - callback_(received->device(), received->consumer_id(), - received->literal()); - } - } -} - -absl::Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) { - const ifrt::PjRtDevice* device = devices_[device_idx]; - constexpr int consumer_id = kOutfeedCidShutdown; - VLOG(2) << "[" << device->DebugString() - << "] SendSpecialHeader cons=" << consumer_id; - XlaBuilder builder( - absl::StrFormat("special_outfeed_header_%d_%d", consumer_id, device_idx)); - - // XLA Next doesn't support returning tokens from computations, so we use - // add-dependency to return a constant while ensuring the side-effect is still - // executed. - XlaOp cst_operand = xla::ConstantR0(&builder, 0); - XlaOp outfeed = - AddOutfeedToBuilder(&builder, CreateToken(&builder), consumer_id, {}, 0) - .value(); - XlaOp add_dep = xla::internal::XlaBuilderFriend::BuildAddDependency( - &builder, cst_operand, outfeed, ShapeUtil::MakeScalarShape(S32)); - XlaComputation computation = builder.Build(add_dep).value(); - - CompileOptions compile_options; - if (executable_build_options_) { - compile_options.executable_build_options = *executable_build_options_; - } - compile_options.executable_build_options.set_num_replicas(1); - compile_options.executable_build_options.set_num_partitions(1); - DeviceAssignment device_assignment(1, 1); - device_assignment(0, 0) = device->Id().value(); - compile_options.executable_build_options.set_device_assignment( - device_assignment); - - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - devices_[device_idx]->client()->pjrt_client()->Compile( - computation, std::move(compile_options))); - ExecuteOptions execute_options; - TF_ASSIGN_OR_RETURN( - std::vector>> output_buffers, - executable->Execute({{}}, execute_options)); - return absl::OkStatus(); -} - -absl::StatusOr OutfeedReceiverImpl::AddOutfeedToBuilder( - XlaBuilder* builder, XlaOp token, uint32_t consumer_id, - std::vector arrays, uint32_t device_idx) { - XlaOp data = Tuple(builder, std::move(arrays)); - Shape shape_with_layout = builder->GetShape(data).value(); - ShapeUtil::ForEachMutableSubshape( - &shape_with_layout, [](Shape* subshape, const ShapeIndex&) { - if (!subshape->has_layout()) { - LayoutUtil::SetToDefaultLayout(subshape); - } - }); - VLOG(2) << "RegisterShape cons=" << consumer_id - << "; shape=" << shape_with_layout.ToString(); - { - absl::MutexLock lock(&mu_); - auto found = shape_registry_.find(consumer_id); - if (found != shape_registry_.end()) { - if (!ShapeUtil::Equal(shape_with_layout, found->second)) { - return InvalidArgument( - "Shape %s does not match previous shape %s used " - "for consumer id %d", - shape_with_layout.DebugString(), found->second.DebugString(), - consumer_id); - } - } else { - shape_registry_.insert({consumer_id, shape_with_layout}); - } - } - - std::vector header{kOutfeedHeaderStart, consumer_id}; - XlaOp header_op = ConstantR1(builder, header); - // We assign the outfeed to the device specified by device_idx (first device - // by default). This must match the sharding for the paired infeed. - builder->SetSharding(sharding_builder::AssignDevice(device_idx)); - token = OutfeedWithToken( - header_op, token, ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords}), ""); - if (consumer_id != kOutfeedCidShutdown) { - token = OutfeedWithToken(data, token, shape_with_layout, ""); - } - builder->ClearSharding(); - return token; -} - -OutfeedReceiver::OutfeedReceiver( - Callback callback, absl::Span clients, - ssize_t max_callback_queue_size_bytes, - const std::optional& executable_build_options) { - p_impl_ = std::make_unique(callback, clients, - max_callback_queue_size_bytes, - executable_build_options); -} - -OutfeedReceiver::~OutfeedReceiver() = default; - -void OutfeedReceiver::Start() { p_impl_->Start(); } - -absl::StatusOr OutfeedReceiver::AddOutfeedToBuilder( - XlaBuilder* builder, XlaOp token, uint32_t consumer_id, - std::vector arrays, uint32_t device_idx) { - if (consumer_id == kOutfeedCidShutdown) { - return InvalidArgument("Consumer ID cannot be a reserved value: %d", - consumer_id); - } - return p_impl_->AddOutfeedToBuilder(builder, token, consumer_id, arrays, - device_idx); -} - -} // namespace xla diff --git a/third_party/xla/xla/python/outfeed_receiver.h b/third_party/xla/xla/python/outfeed_receiver.h deleted file mode 100644 index a9f47280c56ee6..00000000000000 --- a/third_party/xla/xla/python/outfeed_receiver.h +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_PYTHON_OUTFEED_RECEIVER_H_ -#define XLA_PYTHON_OUTFEED_RECEIVER_H_ - -#include -#include -#include -#include -#include - -#include "absl/status/statusor.h" -#include "absl/types/span.h" -#include "xla/client/executable_build_options.h" -#include "xla/client/xla_builder.h" -#include "xla/literal.h" -#include "xla/python/pjrt_ifrt/pjrt_client.h" -#include "xla/python/pjrt_ifrt/pjrt_device.h" - -namespace xla { - -class OutfeedReceiverImpl; - -// Implements a multithreaded receiver of outfeeds from devices. -class OutfeedReceiver { - public: - // A callback takes: device, consumer id, received. - using Callback = std::function)>; - - // Constructs the receiver for the given clients and callback function. - // - // Args: - // callback: a function to be called when an outfeed is ready for - // processing. - // clients: the clients for whose devices to listen. - // max_callback_queue_size_bytes: the maximum number of bytes for all - // received outfeeds queued to be processed. When this limit is reached - // we pause receiving outfeeds from devices. - OutfeedReceiver( - Callback callback, absl::Span clients, - ssize_t max_callback_queue_size_bytes, - const std::optional& executable_build_options); - - OutfeedReceiver(const OutfeedReceiver&) = delete; - OutfeedReceiver& operator=(const OutfeedReceiver&) = delete; - - // Blocks until all data has been received from devices and all data - // in the queue has been passed to Python. - ~OutfeedReceiver(); - - // Starts the listener threads and the callback thread. - void Start(); - - // Adds to the computation builder the outfeed of the arrays. - // Has the side-effect of registering the sent shape for the consumer_id. - // Returns error status if the outfeed shape is different than the - // previously used shape for the same consumer_id or the consumer id is - // invalid. - absl::StatusOr AddOutfeedToBuilder(XlaBuilder* builder, XlaOp token, - uint32_t consumer_id, - std::vector arrays, - uint32_t device_idx); - - private: - std::unique_ptr p_impl_; -}; - -} // namespace xla - -#endif // XLA_PYTHON_OUTFEED_RECEIVER_H_ diff --git a/third_party/xla/xla/python/outfeed_receiver_py.cc b/third_party/xla/xla/python/outfeed_receiver_py.cc deleted file mode 100644 index e5ce3f9e6bf9df..00000000000000 --- a/third_party/xla/xla/python/outfeed_receiver_py.cc +++ /dev/null @@ -1,215 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/python/outfeed_receiver_py.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/base/thread_annotations.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/synchronization/mutex.h" -#include "llvm/Support/Casting.h" -#include "nanobind/nanobind.h" -#include "nanobind/stl/function.h" // IWYU pragma: keep -#include "nanobind/stl/optional.h" // IWYU pragma: keep -#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep -#include "nanobind/stl/vector.h" // IWYU pragma: keep -#include "xla/client/executable_build_options.h" -#include "xla/client/xla_builder.h" -#include "xla/literal.h" -#include "xla/pjrt/status_casters.h" -#include "xla/python/ifrt/device.h" -#include "xla/python/nb_class_ptr.h" -#include "xla/python/outfeed_receiver.h" -#include "xla/python/pjrt_ifrt/pjrt_client.h" -#include "xla/python/py_client.h" -#include "xla/python/types.h" -#include "tsl/platform/logging.h" - -namespace xla { - -namespace nb = nanobind; - -namespace { - -// A wrapper for OutfeedReceiver for use from Python, useful for ensuring -// that the GIL is released before destroying the OutfeedReceiver. -class OutfeedReceiverForPython { - public: - // A callback to Python takes: consumer id, received literal. - using CallbackToPython = - std::function, uint32_t, nb::object)>; - - static absl::StatusOr> Create( - CallbackToPython callback_python, - std::vector> clients, - ssize_t max_callback_queue_size_bytes, - const std::optional& executable_build_options) { - std::vector client_ptrs; - client_ptrs.reserve(clients.size()); - for (const auto& client : clients) { - ifrt::PjRtClient* client_ptr = - llvm::dyn_cast(client->ifrt_client()); - if (!client_ptr) { - return absl::InvalidArgumentError( - "Outfeed receiver only implemented for PJRT clients."); - } - client_ptrs.push_back(client_ptr); - } - - return std::make_unique( - std::move(callback_python), std::move(clients), client_ptrs, - max_callback_queue_size_bytes, executable_build_options); - } - - OutfeedReceiverForPython( - CallbackToPython callback_python, - std::vector> clients, - std::vector client_ptrs, - ssize_t max_callback_queue_size_bytes, - const std::optional& executable_build_options) - : callback_python_(std::move(callback_python)), - clients_(std::move(clients)) { - OutfeedReceiver::Callback callback = - [this](ifrt::Device* device, uint32_t consumer_id, - std::shared_ptr literal) { - this->Callback(device, consumer_id, std::move(literal)); - }; - outfeed_receiver_ = std::make_unique( - callback, client_ptrs, max_callback_queue_size_bytes, - executable_build_options); - } - OutfeedReceiverForPython(const OutfeedReceiverForPython&) = delete; - OutfeedReceiverForPython& operator=(const OutfeedReceiverForPython&) = delete; - - ~OutfeedReceiverForPython() { - // This destructor is called from the Python GC. Release it for the duration - // of the destruction, including the destruction of the OutfeedReceiver, - // when we may actually have to wait for threads to end. During this time - // we do not callback to Python (sometimes we get an exception - // "std::runtime_error: scoped_acquire::dec_ref(): thread state must - // be current!""). - { - absl::MutexLock lock(&mu_); - outfeed_receiver_shutting_down_ = true; - } - nb::gil_scoped_release gil_release; - outfeed_receiver_ = nullptr; // Shutdown the outfeed receiver. - } - - void Start() { outfeed_receiver_->Start(); } - - absl::StatusOr AddOutfeed(XlaBuilder* builder, XlaOp token, - uint32_t consumer_id, - std::vector arrays, - uint32_t device_idx) { - return outfeed_receiver_->AddOutfeedToBuilder(builder, token, consumer_id, - arrays, device_idx); - } - - void Callback(ifrt::Device* device, uint32_t consumer_id, - std::shared_ptr literal) { - { - absl::MutexLock lock(&mu_); - if (outfeed_receiver_shutting_down_) { - VLOG(2) << "Ignoring unsafe callback to Python during shutdown"; - return; - } - } - // We expect the number of clients to be small, so an O(n) search is fine. - auto it = absl::c_find_if( - clients_, [device](const nb_class_ptr& client) { - return client->ifrt_client() == device->client(); - }); - CHECK(it != clients_.end()); - PyClient* client = it->get(); - nb::gil_scoped_acquire gil_acquire; // Need GIL also for LiteralToPython - nb::object literal_python = LiteralToPython(std::move(literal)).value(); - // The callback_ should handle all exceptions in user-code. If we get - // an exception here, it is a bug in the callback and we should stop. - callback_python_(client->GetPyDevice(device), consumer_id, - std::move(literal_python)); - } - - private: - CallbackToPython callback_python_; - absl::Mutex mu_; - bool outfeed_receiver_shutting_down_ ABSL_GUARDED_BY(mu_) = false; - std::vector> clients_; - std::unique_ptr outfeed_receiver_; -}; - -} // namespace - -void BuildOutfeedReceiverSubmodule(nb::module_& m) { - nb::module_ outfeed_receiver = - m.def_submodule("outfeed_receiver", "Outfeed receiver"); - outfeed_receiver.def( - "start", - [](OutfeedReceiverForPython::CallbackToPython callback_to_python, - nb::sequence clients, ssize_t max_callback_queue_size_bytes, - std::optional executable_build_options) - -> std::unique_ptr { - auto server = xla::ValueOrThrow(OutfeedReceiverForPython::Create( - std::move(callback_to_python), - SequenceToVector>(clients), - max_callback_queue_size_bytes, executable_build_options)); - nb::gil_scoped_release gil_release; - server->Start(); - return server; - }, - nb::arg("callback_to_python"), nb::arg("backends"), - nb::arg("max_queue_size_bytes") = 256 * 1024 * 1024, - nb::arg("executable_build_options").none() = nb::none(), - R"(Starts a multithreaded outfeed receiver. - - There is one thread for each of the specified devices. When Python - drops the last reference to the returned object, the receiver is shut - down. The destructor will block until all data is received from - devices. - - Args: - * callback_to_python: a Python callback to call, with - and the data received. - * backends: the list of backends to listen on. - * max_queue_size_bytes: an optional integer to bound the maximum size - of arrays in the callback queue. When this limit is reached the - device listener pauses. - )"); - - nb::class_ outfeed_receiver_class( - outfeed_receiver, "OutfeedReceiverForPython"); - - outfeed_receiver_class.def( - "add_outfeed", - xla::ValueOrThrowWrapper(&OutfeedReceiverForPython::AddOutfeed), - nb::arg("builder"), nb::arg("token"), nb::arg("consumer_id"), - nb::arg("arrays"), nb::arg("device_idx"), - R"(Adds an outfeed into the given computation builder. - - Has the side-effect of registering the sent shape along with the consumer - ID. Returns error if the outfeed shape is not compatible with previously - used shape for the same consumer ID.)", - nb::call_guard()); -} - -} // namespace xla diff --git a/third_party/xla/xla/python/outfeed_receiver_test.cc b/third_party/xla/xla/python/outfeed_receiver_test.cc deleted file mode 100644 index ffdeb25659b191..00000000000000 --- a/third_party/xla/xla/python/outfeed_receiver_test.cc +++ /dev/null @@ -1,344 +0,0 @@ -/* Copyright 2020 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/python/outfeed_receiver.h" - -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/synchronization/mutex.h" -#include "xla/client/client_library.h" -#include "xla/client/executable_build_options.h" -#include "xla/client/xla_builder.h" -#include "xla/pjrt/cpu/cpu_client.h" -#include "xla/pjrt/pjrt_client.h" -#include "xla/pjrt/pjrt_stream_executor_client.h" -#include "xla/service/platform_util.h" -#include "xla/test.h" - -namespace xla { - -namespace { - -absl::Status CompileAndExecute(XlaBuilder* builder, XlaOp root, int device_id, - PjRtClient* client) { - XlaComputation computation = builder->Build(root).value(); - - CompileOptions compile_options; - compile_options.executable_build_options.set_num_replicas(1); - compile_options.executable_build_options.set_num_partitions(1); - DeviceAssignment device_assignment(1, 1); - device_assignment(0, 0) = device_id; - compile_options.executable_build_options.set_device_assignment( - device_assignment); - - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - client->Compile(computation, std::move(compile_options))); - ExecuteOptions execute_options; - TF_ASSIGN_OR_RETURN( - std::vector>> output_buffers, - executable->Execute({{}}, execute_options)); - return absl::OkStatus(); -} - -// Accumulates the received data. -class Accumulator { - public: - struct Data { - uint32_t consumer_id; - std::shared_ptr data; - }; - - void Receive(uint32_t consumer_id, std::shared_ptr data) { - absl::MutexLock lock(&mutex_); - received_.push_back(Data{consumer_id, data}); - } - - std::vector received() { - absl::MutexLock lock(&mutex_); - return received_; - } - - private: - absl::Mutex mutex_; - std::vector received_ ABSL_GUARDED_BY(mutex_); -}; - -// TODO(necula): update this test for the TFRT CPU client, which current does -// not support non-local devices. -// absl::StatusOr> GetCpuClientWithNonLocalDevice() -// { -// TF_ASSIGN_OR_RETURN(se::Platform * platform, -// PlatformUtil::GetPlatform("Host")); -// if (platform->VisibleDeviceCount() <= 0) { -// return FailedPrecondition("CPU platform has no visible devices."); -// } -// LocalClientOptions options; -// options.set_platform(platform); -// TF_ASSIGN_OR_RETURN(LocalClient * client, -// ClientLibrary::GetOrCreateLocalClient(options)); - -// se::StreamExecutorConfig config(0); -// TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, -// platform->GetExecutor(config)); -// auto device_state = std::make_unique( -// executor, client, LocalDeviceState::kSynchronous, -// /*max_inflight_computations=*/32, -// /*allow_event_reuse=*/false, /*use_callback_stream=*/false); - -// std::vector> devices; -// devices.push_back(std::make_unique(0, std::move(device_state))); -// devices.push_back(std::make_unique(1, nullptr)); - -// return -// std::unique_ptr(std::make_unique( -// CpuName(), client, std::move(devices), /*process_index=*/0, -// /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr, -// /*should_stage_host_to_device_transfers=*/false, -// /*gpu_run_options=*/nullptr)); -// } - -TEST(OutfeedReceiverTest, ReceiveOutfeedSimple) { - TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr cpu_client, - GetTfrtCpuClient(CpuClientOptions())); - auto ifrt_cpu_client = ifrt::PjRtClient::Create(cpu_client); - std::vector clients{ifrt_cpu_client.get()}; - - auto receiver = std::make_unique(); - OutfeedReceiver::Callback callback = - [&receiver](xla::ifrt::PjRtDevice* device, uint32_t consumer_id, - std::shared_ptr data) { - receiver->Receive(consumer_id, data); - }; - auto outfeed_receiver = - std::make_shared(callback, clients, 128, std::nullopt); - outfeed_receiver->Start(); - - XlaBuilder builder("execute_test_outfeed"); - constexpr int consumer_id0 = 5; - const Shape shape0 = ShapeUtil::MakeShape(U32, {16}); - XlaOp data = Iota(&builder, shape0, 0); - XlaOp send = outfeed_receiver - ->AddOutfeedToBuilder(&builder, CreateToken(&builder), - consumer_id0, {data}, 0) - .value(); - EXPECT_TRUE(CompileAndExecute(&builder, send, 0, cpu_client.get()).ok()); - - // Shutdown the receiver, to force it to wait to deliver the callbacks. - outfeed_receiver = nullptr; - std::vector received = receiver->received(); - EXPECT_EQ(1, received.size()); - EXPECT_EQ(consumer_id0, received[0].consumer_id); - EXPECT_EQ(ShapeUtil::MakeTupleShape({shape0}), received[0].data->shape()); -} - -TEST(OutfeedReceiverTest, ReceiveOutfeedTwoComputations) { - TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr cpu_client, - GetTfrtCpuClient(CpuClientOptions())); - auto ifrt_cpu_client = ifrt::PjRtClient::Create(cpu_client); - std::vector clients{ifrt_cpu_client.get()}; - - auto receiver = std::make_unique(); - OutfeedReceiver::Callback callback = - [&receiver](xla::ifrt::PjRtDevice* device, uint32_t consumer_id, - std::shared_ptr data) { - receiver->Receive(consumer_id, data); - }; - auto outfeed_receiver = - std::make_shared(callback, clients, 128, std::nullopt); - outfeed_receiver->Start(); - - XlaBuilder builder0("execute_test_outfeed_0"); - constexpr int consumer_id0 = 5; - const Shape shape0 = ShapeUtil::MakeShape(U32, {16}); - XlaOp data0 = Iota(&builder0, shape0, 0); - XlaOp send0 = outfeed_receiver - ->AddOutfeedToBuilder(&builder0, CreateToken(&builder0), - consumer_id0, {data0}, 0) - .value(); - EXPECT_TRUE(CompileAndExecute(&builder0, send0, 0, cpu_client.get()).ok()); - - XlaBuilder builder1("execute_test_outfeed_1"); - constexpr int consumer_id1 = 6; - const Shape shape1 = ShapeUtil::MakeShape(U32, {128}); - XlaOp data1 = Iota(&builder1, shape1, 0); - XlaOp send1 = outfeed_receiver - ->AddOutfeedToBuilder(&builder1, CreateToken(&builder1), - consumer_id1, {data1}, 0) - .value(); - EXPECT_TRUE(CompileAndExecute(&builder1, send1, 0, cpu_client.get()).ok()); - - // Shutdown the receiver, to force it to wait to deliver the callbacks. - outfeed_receiver = nullptr; - std::vector received = receiver->received(); - EXPECT_EQ(2, received.size()); - EXPECT_EQ(consumer_id0, received[0].consumer_id); - EXPECT_EQ(ShapeUtil::MakeTupleShape({shape0}), received[0].data->shape()); - EXPECT_EQ(consumer_id1, received[1].consumer_id); - EXPECT_EQ(ShapeUtil::MakeTupleShape({shape1}), received[1].data->shape()); -} - -TEST(OutfeedReceiverTest, ReceiveOutfeedTwoOutfeed) { - TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr cpu_client, - GetTfrtCpuClient(CpuClientOptions())); - auto ifrt_cpu_client = ifrt::PjRtClient::Create(cpu_client); - std::vector clients{ifrt_cpu_client.get()}; - - auto receiver = std::make_unique(); - OutfeedReceiver::Callback callback = - [&receiver](xla::ifrt::PjRtDevice* device, uint32_t consumer_id, - std::shared_ptr data) { - receiver->Receive(consumer_id, data); - }; - auto outfeed_receiver = - std::make_shared(callback, clients, 128, std::nullopt); - outfeed_receiver->Start(); - - XlaBuilder builder("execute_test_outfeed"); - constexpr int consumer_id0 = 5; - const Shape shape0 = ShapeUtil::MakeShape(U32, {16}); - XlaOp data0 = Iota(&builder, shape0, 0); - XlaOp send0 = outfeed_receiver - ->AddOutfeedToBuilder(&builder, CreateToken(&builder), - consumer_id0, {data0}, 0) - .value(); - - constexpr int consumer_id1 = 6; - const Shape shape1 = ShapeUtil::MakeShape(U32, {128}); - XlaOp data1 = Iota(&builder, shape1, 0); - XlaOp send1 = - outfeed_receiver - ->AddOutfeedToBuilder(&builder, send0, consumer_id1, {data1}, 0) - .value(); - EXPECT_TRUE(CompileAndExecute(&builder, send1, 0, cpu_client.get()).ok()); - - // Shutdown the receiver, to force it to wait to deliver the callbacks. - outfeed_receiver = nullptr; - std::vector received = receiver->received(); - EXPECT_EQ(2, received.size()); - EXPECT_EQ(consumer_id0, received[0].consumer_id); - EXPECT_EQ(ShapeUtil::MakeTupleShape({shape0}), received[0].data->shape()); - EXPECT_EQ(consumer_id1, received[1].consumer_id); - EXPECT_EQ(ShapeUtil::MakeTupleShape({shape1}), received[1].data->shape()); -} - -TEST(OutfeedReceiverTest, DifferentShapeForConsumerIdError) { - TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr cpu_client, - GetTfrtCpuClient(CpuClientOptions())); - auto ifrt_cpu_client = ifrt::PjRtClient::Create(cpu_client); - std::vector clients{ifrt_cpu_client.get()}; - - auto receiver = std::make_unique(); - OutfeedReceiver::Callback callback = - [&receiver](xla::ifrt::PjRtDevice* device, uint32_t consumer_id, - std::shared_ptr data) { - receiver->Receive(consumer_id, data); - }; - auto outfeed_receiver = - std::make_shared(callback, clients, 128, std::nullopt); - outfeed_receiver->Start(); - - XlaBuilder builder("execute_test_outfeed"); - constexpr int consumer_id0 = 5; - const Shape shape0 = ShapeUtil::MakeShape(U32, {16}); - XlaOp data0 = Iota(&builder, shape0, 0); - XlaOp send0 = outfeed_receiver - ->AddOutfeedToBuilder(&builder, CreateToken(&builder), - consumer_id0, {data0}, 0) - .value(); - - const Shape shape1 = ShapeUtil::MakeShape(U32, {128}); - XlaOp data1 = Iota(&builder, shape1, 0); - // A different shape for the same consumer ID. - absl::StatusOr send1 = outfeed_receiver->AddOutfeedToBuilder( - &builder, send0, consumer_id0, {data1}, 0); - EXPECT_FALSE(send1.ok()); - EXPECT_THAT( - send1.status().ToString(), - testing::ContainsRegex( -#if defined(PLATFORM_WINDOWS) - "does not match previous shape \\w*/*\\w* *\\n?element_type")); -#else - "does not match previous shape (go/\\w+[ " - "]+\\n)?element_type")); -#endif -} - -TEST(OutfeedReceiverTest, InvalidConsumerIdError) { - TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr cpu_client, - GetTfrtCpuClient(CpuClientOptions())); - auto ifrt_cpu_client = ifrt::PjRtClient::Create(cpu_client); - std::vector clients{ifrt_cpu_client.get()}; - - auto receiver = std::make_unique(); - OutfeedReceiver::Callback callback = - [&receiver](xla::ifrt::PjRtDevice* device, uint32_t consumer_id, - std::shared_ptr data) { - receiver->Receive(consumer_id, data); - }; - auto outfeed_receiver = - std::make_shared(callback, clients, 128, std::nullopt); - outfeed_receiver->Start(); - - XlaBuilder builder("execute_test_outfeed"); - const Shape shape0 = ShapeUtil::MakeShape(U32, {16}); - XlaOp data0 = Iota(&builder, shape0, 0); - absl::StatusOr send0 = outfeed_receiver->AddOutfeedToBuilder( - &builder, CreateToken(&builder), 0, {data0}, 0); - - EXPECT_FALSE(send0.ok()); - EXPECT_THAT(send0.status().ToString(), - testing::HasSubstr("Consumer ID cannot be a reserved value")); -} - -// TEST(OutfeedReceiverTest, NonLocalDevicesIgnored) { -// TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr cpu_client, -// GetCpuClientWithNonLocalDevice()); -// auto ifrt_cpu_client = ifrt::PjRtClient::Create(cpu_client); -// std::vector clients{ifrt_cpu_client.get()}; - -// auto receiver = std::make_unique(); -// OutfeedReceiver::Callback callback = -// [&receiver](xla::ifrt::PjRtDevice* device, uint32_t consumer_id, -// std::shared_ptr data) { -// receiver->Receive(consumer_id, data); -// }; -// auto outfeed_receiver = -// std::make_shared(callback, clients, 128); -// outfeed_receiver->Start(); - -// XlaBuilder builder("execute_test_outfeed"); -// constexpr int consumer_id0 = 5; -// const Shape shape0 = ShapeUtil::MakeShape(U32, {16}); -// XlaOp data = Iota(&builder, shape0, 0); -// XlaOp send = outfeed_receiver -// ->AddOutfeedToBuilder(&builder, CreateToken(&builder), -// consumer_id0, {data}) -// .value(); -// EXPECT_TRUE(CompileAndExecute(&builder, send, 0, cpu_client.get()).ok()); - -// // Shutdown the receiver, to force it to wait to deliver the callbacks. -// outfeed_receiver = nullptr; -// std::vector received = receiver->received(); -// EXPECT_EQ(1, received.size()); -// EXPECT_EQ(consumer_id0, received[0].consumer_id); -// EXPECT_EQ(ShapeUtil::MakeTupleShape({shape0}), received[0].data->shape()); -// } - -} // namespace - -} // namespace xla diff --git a/third_party/xla/xla/python/pjit.cc b/third_party/xla/xla/python/pjit.cc index c14556d1b58651..528b544504668e 100644 --- a/third_party/xla/xla/python/pjit.cc +++ b/third_party/xla/xla/python/pjit.cc @@ -55,6 +55,7 @@ limitations under the License. #include "xla/pjrt/exceptions.h" #include "xla/pjrt/lru_cache.h" #include "xla/pjrt/pjrt_layout.h" +#include "xla/python/guard_lib.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" @@ -70,7 +71,6 @@ limitations under the License. #include "xla/python/pytree.h" #include "xla/python/sharding.h" #include "xla/python/traceback.h" -#include "xla/python/transfer_guard_lib.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/util.h" #include "tsl/platform/errors.h" @@ -135,7 +135,10 @@ class PjitFunctionCache { int Size() const { return lru_list_.Size(); } int Capacity() const { return lru_list_.Capacity(); } - void Clear() { lru_list_.Clear(); } + void Clear() { + lru_list_.Clear(); + functions_.clear(); + } private: struct Key { @@ -164,7 +167,7 @@ class PjitFunctionCache { h = H::combine(std::move(h), key.function.ptr()); Py_hash_t hash; try { - hash = xla::nb_hash(key.global_cache_key); + hash = nb::hash(key.global_cache_key); } catch (const nanobind::python_error& e) { if (!e.matches(PyExc_TypeError)) throw; throw std::invalid_argument(absl::StrCat( @@ -347,6 +350,7 @@ class PjitFunctionStore { for (auto* function : compiled_functions_) { function->ClearCache(); } + compiled_functions_.clear(); } private: diff --git a/third_party/xla/xla/python/pjrt_ifrt/BUILD b/third_party/xla/xla/python/pjrt_ifrt/BUILD index a87d033d727d04..676fa32b1c581a 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/BUILD +++ b/third_party/xla/xla/python/pjrt_ifrt/BUILD @@ -215,6 +215,7 @@ cc_library( deps = [ ":basic_string_array", ":pjrt_attribute_map_util", + ":pjrt_dtype", ":xla_ifrt", "//xla:literal", "//xla:shape_util", @@ -222,6 +223,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/translate/mhlo_to_hlo:type_to_shape", "//xla/pjrt:host_callback", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_common", @@ -238,7 +240,6 @@ cc_library( "//xla/python/ifrt:attribute_map", "//xla/python/ifrt/hlo:hlo_program", "//xla/service:hlo_proto_cc", - "//xla/translate/mhlo_to_hlo:type_to_shape", "//xla/tsl/concurrency:ref_count", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -300,12 +301,26 @@ xla_cc_test( ], ) +cc_library( + name = "pjrt_dtype", + srcs = ["pjrt_dtype.cc"], + hdrs = ["pjrt_dtype.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/python/ifrt", + "@com_google_absl//absl/status:statusor", + ], +) + cc_library( name = "basic_string_array", srcs = ["basic_string_array.cc"], hdrs = ["basic_string_array.h"], compatible_with = get_compatible_with_portable(), deps = [ + "//xla:util", "//xla:xla_data_proto_cc", "//xla/pjrt:pjrt_layout", "//xla/python/ifrt", @@ -317,8 +332,8 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", @@ -342,6 +357,7 @@ xla_cc_test( "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", diff --git a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.cc b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.cc index 8c492ea29015ed..d3b9fd1be984f5 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.cc @@ -28,7 +28,6 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/pjrt/pjrt_layout.h" @@ -40,6 +39,7 @@ limitations under the License. #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" #include "xla/tsl/concurrency/ref_count.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -47,7 +47,7 @@ limitations under the License. // DisassembleIntoSingleDeviceArrays, Reshard, FullyReplicatedShard, // CopyToHostBuffer and AssembleFromSingleDeviceArrays share a common pattern // that waits for the source array(s) buffers to become ready and then copies -// the data into a new array's buffer backing store. Factor out the common +// the data into a new array's buffer. Factor out the common // pattern into a helper function. namespace xla { @@ -103,7 +103,7 @@ absl::StatusOr> BasicStringArray::Create( auto ready_future = Future<>(ready_promise); // Buffers when the become ready must be consistent with the sharding. For - // instance, Buffers.size() (the number of per-shard spans of string_views) + // instance, Buffers.size() (the number of per-shard spans of absl::Cords) // and the devices in the sharding that was used to create an array must // match. If they do not, the array's ready future and buffers future should // become ready with an appropriate error status. @@ -188,6 +188,23 @@ absl::StatusOr>> BasicStringArray::DisassembleIntoSingleDeviceArrays( ArrayCopySemantics semantics) { DCHECK(this); + return DisassembleIntoSingleDeviceArrays( + semantics, SingleDeviceShardSemantics::kAllShards); +} + +absl::StatusOr>> +BasicStringArray::DisassembleIntoSingleDeviceArrays( + ArrayCopySemantics semantics, + SingleDeviceShardSemantics single_device_shard_semantics) { + DCHECK(this); + if (single_device_shard_semantics == SingleDeviceShardSemantics::kAllShards && + !sharding_->devices()->IsFullyAddressable()) { + return InvalidArgument( + "All shards are requested but the sharding has non-addressable " + "devices: %v", + *sharding_->devices()); + } + absl::MutexLock lock(&mu_); if (is_deleted_) { return absl::FailedPreconditionError("Array has already been deleted"); @@ -198,13 +215,13 @@ BasicStringArray::DisassembleIntoSingleDeviceArrays( // For each single device array we are going to pre-make: // (1) a Promise-Future pair for passing the buffers, // - // (2) a Per-shard buffer backing store and the corresponding - // on-done-with-buffer callback. + // (2) a Per-shard data store and the corresponding on-done-with-buffer + // callback. // // (3) shape and sharding by disassembing the source array's sharding. // // The Futures, the on-done-with-host-buffer callbacks, shapes and shardings - // are used to make the arrays. The promises and the buffer backing stores + // are used to make the arrays. The promises and the per-shard stores // are passed onto the OnReady callback that populates them when the buffers // of the source array become ready. std::vector> buffer_promises; @@ -212,21 +229,18 @@ BasicStringArray::DisassembleIntoSingleDeviceArrays( std::vector> buffer_futures; buffer_futures.reserve(num_shards); - struct PerShardBufferBackingStore { // Data (strings) for a single shard. - void CopyFrom(absl::Span input_buffer) { + struct PerShardStringStore { // Data (strings) for a single shard. + void CopyFrom(absl::Span input_buffer) { strings.reserve(input_buffer.size()); - string_views.reserve(input_buffer.size()); - for (absl::string_view buf : input_buffer) { - strings.push_back(std::string(buf.data(), buf.size())); - string_views.push_back(strings.back()); + for (const auto& input_string : input_buffer) { + strings.push_back(input_string); } } - std::vector strings; - std::vector string_views; + std::vector strings; }; - std::vector> - per_shard_buffer_backing_stores; - per_shard_buffer_backing_stores.reserve(num_shards); + + std::vector> per_shard_strings; + per_shard_strings.reserve(num_shards); std::vector on_done_with_buffer_callbacks; on_done_with_buffer_callbacks.reserve(num_shards); @@ -234,30 +248,29 @@ BasicStringArray::DisassembleIntoSingleDeviceArrays( buffer_promises.push_back(Future::CreatePromise()); buffer_futures.push_back(Future(buffer_promises.back())); - auto backing_store = std::make_shared(); - per_shard_buffer_backing_stores.push_back(backing_store); + auto current_shard_strings = std::make_shared(); + per_shard_strings.push_back(current_shard_strings); on_done_with_buffer_callbacks.push_back( - [backing_store = std::move(backing_store)]() {}); + [data = std::move(current_shard_strings)]() {}); } - // Copy each of the per-shard data into the its per-shard buffer backing - // store, make a Buffers object and set the corresponding promise. + // When the buffers become ready, copy each of the per-shard data into the + // buffer of the corresponding single-device array. buffers_.OnReady([buffer_promises = std::move(buffer_promises), - per_shard_buffer_backing_stores = - std::move(per_shard_buffer_backing_stores)]( + per_shard_data = std::move(per_shard_strings)]( absl::StatusOr buffers) mutable { if (!buffers.ok()) { for (auto& promise : buffer_promises) { promise.Set(buffers.status()); } - per_shard_buffer_backing_stores.clear(); + per_shard_data.clear(); return; } auto num_shards = buffers->size(); for (int i = 0; i < num_shards; ++i) { - per_shard_buffer_backing_stores[i]->CopyFrom((*buffers)[i]); + per_shard_data[i]->CopyFrom((*buffers)[i]); Buffers buffers; - buffers.push_back(per_shard_buffer_backing_stores[i]->string_views); + buffers.push_back(absl::MakeConstSpan(per_shard_data[i]->strings)); buffer_promises[i].Set(std::move(buffers)); } }); @@ -285,7 +298,37 @@ Future<> BasicStringArray::CopyToHostBuffer( void* data, std::optional> byte_strides, ArrayCopySemantics semantics) { DCHECK(this); - return Future<>(absl::UnimplementedError("Not implemented")); + absl::MutexLock lock(&mu_); + if (is_deleted_) { + return Future<>( + absl::FailedPreconditionError("Array has already been deleted")); + } + + if (sharding_->devices()->size() != 1) { + return Future<>(absl::InvalidArgumentError(absl::StrFormat( + "CopyToHostBuffer only supports single device string arrays. This " + "array has been sharded over %d devices.", + sharding_->devices()->size()))); + } + + auto copy_completion_promise = Future<>::CreatePromise(); + auto copy_completion_future = Future<>(copy_completion_promise); + + buffers_.OnReady( + [copy_completion_promise = std::move(copy_completion_promise), + host_buffer = static_cast(data)]( + absl::StatusOr input_buffers) mutable { + if (!input_buffers.ok()) { + copy_completion_promise.Set(input_buffers.status()); + return; + } + const absl::Span& input_buffer = (*input_buffers)[0]; + for (int i = 0; i < input_buffer.size(); ++i) { + host_buffer[i] = input_buffer[i]; + } + copy_completion_promise.Set(absl::OkStatus()); + }); + return copy_completion_future; } absl::StatusOr> BasicStringArray::Copy( @@ -307,29 +350,24 @@ absl::StatusOr> BasicStringArray::Copy( sharding_->devices()->size())); } - struct BufferBackingStore { - void AddShardData(absl::Span input_buffer) { + struct StringStore { + void AddShardData(absl::Span input_buffer) { auto& shard_strings = strings.emplace_back(); shard_strings.reserve(input_buffer.size()); - auto& shard_string_views = string_views.emplace_back(); - shard_string_views.reserve(input_buffer.size()); - - for (absl::string_view buf : input_buffer) { - shard_strings.push_back(std::string(buf.data(), buf.size())); - shard_string_views.push_back(shard_strings.back()); + for (const auto& input_string : input_buffer) { + shard_strings.push_back(input_string); } } - std::vector> strings; - std::vector> string_views; + std::vector> strings; }; - auto backing_store = std::make_shared(); - auto on_done_with_buffer = [backing_store]() {}; + auto string_store = std::make_shared(); + auto on_done_with_buffer = [string_store]() {}; auto buffers_promise = Future::CreatePromise(); auto buffers_future = Future(buffers_promise); - auto copier = [backing_store = std::move(backing_store), + auto copier = [string_store = std::move(string_store), buffers_promise = std::move(buffers_promise)]( absl::StatusOr input_buffers) mutable { if (!input_buffers.ok()) { @@ -339,8 +377,8 @@ absl::StatusOr> BasicStringArray::Copy( Buffers buffers; buffers.reserve(input_buffers->size()); for (auto& input_buffer : *input_buffers) { - backing_store->AddShardData(input_buffer); - buffers.push_back(backing_store->string_views.back()); + string_store->AddShardData(input_buffer); + buffers.push_back(string_store->strings.back()); } buffers_promise.Set(std::move(buffers)); }; @@ -366,25 +404,22 @@ absl::StatusOr> BasicStringArray::FullyReplicatedShard( if (!sharding_->IsFullyReplicated()) { return absl::FailedPreconditionError("This array is not fully replicated"); } - struct BufferBackingStore { // Data (strings) for a single shard. - void CopyFrom(absl::Span input_buffer) { + struct StringStore { // Data (strings) for a single shard. + void CopyFrom(absl::Span input_buffer) { strings.reserve(input_buffer.size()); - string_views.reserve(input_buffer.size()); - for (absl::string_view buf : input_buffer) { - strings.push_back(std::string(buf.data(), buf.size())); - string_views.push_back(strings.back()); + for (const auto& input_strings : input_buffer) { + strings.push_back(input_strings); } } - std::vector strings; - std::vector string_views; + std::vector strings; }; - auto backing_store = std::make_shared(); - auto on_done_with_buffer = [backing_store]() {}; + auto string_store = std::make_shared(); + auto on_done_with_buffer = [string_store]() {}; auto buffers_promise = Future::CreatePromise(); auto buffers_future = Future(buffers_promise); - auto copier = [backing_store = std::move(backing_store), + auto copier = [string_store = std::move(string_store), buffers_promise = std::move(buffers_promise)]( absl::StatusOr input_buffers) mutable { if (!input_buffers.ok()) { @@ -396,10 +431,10 @@ absl::StatusOr> BasicStringArray::FullyReplicatedShard( // were run when the source array's buffers became ready would have // ensured that the input_buffers have at least one shard's worth of data. auto& input_buffer = (*input_buffers)[0]; - backing_store->CopyFrom(input_buffer); + string_store->CopyFrom(input_buffer); Buffers buffers; - buffers.push_back(backing_store->string_views); + buffers.push_back(string_store->strings); buffers_promise.Set(std::move(buffers)); }; buffers_.OnReady(std::move(copier)); diff --git a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.h b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.h index d10950276b292c..a430cfa73fdd26 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.h +++ b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array.h @@ -28,7 +28,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/hash/hash.h" #include "absl/log/check.h" -#include "absl/strings/string_view.h" +#include "absl/strings/cord.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "llvm/Support/ExtensibleRTTI.h" @@ -71,7 +71,7 @@ class BasicStringArray final : public llvm::RTTIExtends { public: // Must be in dense major to minor order. - using Buffer = absl::Span; + using Buffer = absl::Span; // One Buffer per shard. static constexpr int kBuffersInlineSize = 1; @@ -82,7 +82,7 @@ class BasicStringArray final using OnDoneWithBuffer = std::function; // General array construction. The `buffers` and their elements - // (absl::string_views) must live until the `on_done_with_buffer` is called. + // (absl::Cords) must live until the `on_done_with_buffer` is called. // The number and order of buffers must match the number and order of devices // in `sharding`. static absl::StatusOr> Create( @@ -125,6 +125,10 @@ class BasicStringArray final absl::StatusOr>> DisassembleIntoSingleDeviceArrays(ArrayCopySemantics semantics) override; + absl::StatusOr>> + DisassembleIntoSingleDeviceArrays( + ArrayCopySemantics array_copy_semantics, + SingleDeviceShardSemantics single_device_shard_semantics) override; ABSL_MUST_USE_RESULT Future<> CopyToHostBuffer( diff --git a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array_test.cc b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array_test.cc index 37c37ce9ace8d3..c402f0a38ecdb2 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array_test.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/strings/cord.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/synchronization/notification.h" @@ -54,6 +55,8 @@ namespace xla { namespace ifrt { namespace { +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; using ::testing::HasSubstr; using ::tsl::testing::StatusIs; @@ -84,21 +87,15 @@ std::pair MakeBuffersAndOnDoneWithBuffer( absl::Span input_strings) { BasicStringArray::Buffers buffers; - auto string_holder = std::make_shared>(); - string_holder->reserve(input_strings.size()); - auto string_view_holder = std::make_shared>(); - string_view_holder->reserve(input_strings.size()); - for (const auto str : input_strings) { - string_holder->push_back(std::string(str)); + auto strings = std::make_shared>(); + strings->reserve(input_strings.size()); + for (const auto input_str : input_strings) { + strings->push_back(absl::Cord(input_str)); } - for (const auto& str : *string_holder) { - string_view_holder->push_back(absl::string_view(str)); - } - buffers.push_back(*string_view_holder); + buffers.push_back(*strings); BasicStringArray::OnDoneWithBuffer on_done_with_buffer = - [string_holder = std::move(string_holder), - string_view_holder = std::move(string_view_holder)]() {}; + [strings = std::move(strings)]() {}; return std::make_pair(std::move(buffers), std::move(on_done_with_buffer)); } @@ -175,7 +172,7 @@ TEST(BasicStringArrayLayoutTest, Equality) { TEST(BasicStringArrayTest, CreateSuccess) { TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); BasicStringArray::Buffers buffers; - buffers.push_back({"abc", "def"}); + buffers.push_back({absl::Cord("abc"), absl::Cord("def")}); // This test implicitly tests that the on_done_with_buffer can be a nullptr, // and that the destruction of the BasicStringArray object completes @@ -197,7 +194,7 @@ TEST(BasicStringArrayTest, Destruction) { TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); BasicStringArray::Buffers buffers; - buffers.push_back({"abc", "def"}); + buffers.push_back({absl::Cord("abc"), absl::Cord("def")}); absl::Notification on_done_with_buffer_called; BasicStringArray::OnDoneWithBuffer on_done_with_buffer = @@ -228,10 +225,10 @@ TEST(BasicStringArrayTest, InvalidBuffersAreHandledCorrectly) { ASSERT_GE(devices.size(), 1); // Make a BasicStringArray::Buffer with two shards. - auto shard0_data = std::make_shared>(); - shard0_data->push_back("abc"); - auto shard1_data = std::make_shared>(); - shard1_data->push_back("def"); + auto shard0_data = std::make_shared>(); + shard0_data->push_back(absl::Cord("abc")); + auto shard1_data = std::make_shared>(); + shard1_data->push_back(absl::Cord("def")); BasicStringArray::Buffers buffers; buffers.push_back(*shard0_data); buffers.push_back(*shard1_data); @@ -260,7 +257,7 @@ TEST(BasicStringArrayTest, InvalidBuffersAreHandledCorrectly) { TEST(BasicStringArrayTest, Delete) { TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); BasicStringArray::Buffers buffers; - buffers.push_back({"abc", "def"}); + buffers.push_back({absl::Cord("abc"), absl::Cord("def")}); absl::Notification on_done_with_buffer_called; BasicStringArray::OnDoneWithBuffer on_done_with_buffer = [&on_done_with_buffer_called]() { on_done_with_buffer_called.Notify(); }; @@ -294,7 +291,7 @@ TEST(GetReadyFutureTest, SuccessCase) { // Make the buffers future ready asynchronously. BasicStringArray::Buffers buffers; - buffers.push_back({"abc", "def"}); + buffers.push_back({absl::Cord("abc"), absl::Cord("def")}); tsl::Env::Default()->SchedClosure([&]() { promise.Set(buffers); }); TF_EXPECT_OK(ready_future.Await()); } @@ -326,11 +323,11 @@ TEST(MakeArrayFromHostBufferTest, SuccessCase) { std::shared_ptr sharding = SingleDeviceSharding::Create(device, MemoryKind()); - auto string_views = std::make_shared>(); - string_views->push_back("abc"); - string_views->push_back("def"); - const void* data = string_views->data(); - auto on_done_with_host_buffer = [string_views = std::move(string_views)]() {}; + auto strings = std::make_shared>(); + strings->push_back(absl::Cord("abc")); + strings->push_back(absl::Cord("def")); + const void* data = strings->data(); + auto on_done_with_host_buffer = [strings = std::move(strings)]() {}; TF_ASSERT_OK(client->MakeArrayFromHostBuffer( data, DType(DType::kString), shape, @@ -345,11 +342,11 @@ TEST(MakeArrayFromHostBufferTest, FailureCases) { Device* device = client->addressable_devices().at(0); std::shared_ptr single_device_sharding = SingleDeviceSharding::Create(device, MemoryKind()); - auto string_views = std::make_shared>(); - string_views->push_back("abc"); - string_views->push_back("def"); - const void* data = string_views->data(); - auto on_done_with_host_buffer = [string_views = std::move(string_views)]() {}; + auto strings = std::make_shared>(); + strings->push_back(absl::Cord("abc")); + strings->push_back(absl::Cord("def")); + const void* data = strings->data(); + auto on_done_with_host_buffer = [strings = std::move(strings)]() {}; // MakeArrayFromHostBuffer should check and fail if `byte_strides` in not // nullopt. @@ -394,16 +391,16 @@ TEST(MakeArrayFromHostBufferTest, FailureCases) { absl::StatusOr> MakeSingleDeviceStringTestArray( absl::Span contents, Client* client, Device* const device) { - Shape shape({1}); + Shape shape(absl::MakeConstSpan({static_cast(contents.size())})); std::shared_ptr sharding = SingleDeviceSharding::Create(device, MemoryKind()); - auto string_views = std::make_shared>(); + auto strings = std::make_shared>(); for (const auto& content : contents) { - string_views->push_back(content); + strings->push_back(absl::Cord(content)); } - const void* data = string_views->data(); - auto on_done_with_host_buffer = [string_views = std::move(string_views)]() {}; + const void* data = strings->data(); + auto on_done_with_host_buffer = [strings = std::move(strings)]() {}; return client->MakeArrayFromHostBuffer( data, DType(DType::kString), shape, @@ -478,7 +475,7 @@ TEST(AssembleArrayFromSingleDeviceArraysTest, for (int i = 0; i < buffers.size(); ++i) { SCOPED_TRACE(absl::StrCat("buffer #", i)); auto buffer = buffers[i]; - EXPECT_THAT(buffer, testing::ElementsAre(per_shard_contents[i])); + EXPECT_THAT(buffer, ElementsAre(per_shard_contents[i])); } } @@ -570,9 +567,9 @@ TEST(AssembleArrayFromSingleDeviceArraysTest, auto buffers_future = basic_string_array->buffers(); TF_ASSERT_OK_AND_ASSIGN(auto buffers, buffers_future.Await()); - EXPECT_EQ(buffers.size(), 2); - EXPECT_THAT(buffers[0], testing::ElementsAre("abc")); - EXPECT_THAT(buffers[1], testing::ElementsAre("def")); + ASSERT_EQ(buffers.size(), 2); + EXPECT_THAT(buffers[0], ElementsAre("abc")); + EXPECT_THAT(buffers[1], ElementsAre("def")); } TEST(AssembleArrayFromSingleDeviceArraysTest, @@ -649,8 +646,8 @@ TEST(DisassembleArrayIntoSingleDeviceArrays, TF_ASSERT_OK_AND_ASSIGN(auto new_buffers, basic_string_array->buffers().Await()); - ASSERT_EQ(buffers.size(), 1); - EXPECT_THAT(buffers[0], testing::ElementsAre("abc")); + ASSERT_EQ(new_buffers.size(), 1); + EXPECT_THAT(new_buffers[0], ElementsAre("abc")); } TEST(DisassembleArrayIntoSingleDeviceArrays, ShardedArrayDisassembleSuccess) { @@ -673,7 +670,7 @@ TEST(DisassembleArrayIntoSingleDeviceArrays, ShardedArrayDisassembleSuccess) { llvm::dyn_cast(disassembled_arrays[i].get()); TF_ASSERT_OK_AND_ASSIGN(auto buffer, basic_string_array->buffers().Await()); ASSERT_EQ(buffer.size(), 1); - EXPECT_THAT(buffer[0], testing::ElementsAre(per_shard_contents[i])); + EXPECT_THAT(buffer[0], ElementsAre(per_shard_contents[i])); } } @@ -719,7 +716,7 @@ TEST(CopyTest, SuccessSingleDeviceShardedArray) { TF_ASSERT_OK_AND_ASSIGN(auto new_buffers, new_basic_string_array->buffers().Await()); ASSERT_EQ(new_buffers.size(), 1); - EXPECT_THAT(new_buffers[0], testing::ElementsAre("abc")); + EXPECT_THAT(new_buffers[0], ElementsAre("abc")); } TEST(CopyTest, SuccessMultiDeviceShardedArray) { @@ -745,8 +742,8 @@ TEST(CopyTest, SuccessMultiDeviceShardedArray) { TF_ASSERT_OK_AND_ASSIGN(auto new_buffers, new_basic_string_array->buffers().Await()); ASSERT_EQ(new_buffers.size(), 2); - EXPECT_THAT(new_buffers[0], testing::ElementsAre("shard 0")); - EXPECT_THAT(new_buffers[1], testing::ElementsAre("shard 1")); + EXPECT_THAT(new_buffers[0], ElementsAre("shard 0")); + EXPECT_THAT(new_buffers[1], ElementsAre("shard 1")); } TEST(CopyTest, FailsAfterDeletion) { @@ -819,7 +816,7 @@ TEST(CopyTest, NonReadySourceArraySuccessfullyBecomesReadyAfterCopy) { TF_ASSERT_OK_AND_ASSIGN(auto new_buffers, basic_string_array->buffers().Await()); ASSERT_EQ(new_buffers.size(), 1); - EXPECT_THAT(new_buffers[0], testing::ElementsAre("abc")); + EXPECT_THAT(new_buffers[0], ElementsAre("abc")); // Make sure to wait for the Closure to complete its work and set both // promises before returning from the test. The consequent destruction of the @@ -833,7 +830,6 @@ TEST(CopyTest, NonReadySourceArrayFailsToBecomeReadyAfterCopy) { ASSERT_GE(devices.size(), 2); auto buf_and_on_done_with_buffer = MakeBuffersAndOnDoneWithBuffer({"abc"}); - auto buffers = buf_and_on_done_with_buffer.first; auto on_done_with_buffer = buf_and_on_done_with_buffer.second; TF_ASSERT_OK_AND_ASSIGN( auto ret, CreateNonReadyTestArray(client.get(), devices[0], @@ -886,7 +882,7 @@ TEST(FullyReplicatedShardTest, SuccessSingleDeviceShardedArray) { TF_ASSERT_OK_AND_ASSIGN(auto replicated_buffers, replicated_basic_string_array->buffers().Await()); ASSERT_EQ(replicated_buffers.size(), 1); - EXPECT_THAT(replicated_buffers[0], testing::ElementsAre(kContents)); + EXPECT_THAT(replicated_buffers[0], ElementsAre(kContents)); } TEST(FullyReplicatedShardTest, SuccessMultiDeviceShardedArray) { @@ -908,7 +904,7 @@ TEST(FullyReplicatedShardTest, SuccessMultiDeviceShardedArray) { TF_ASSERT_OK_AND_ASSIGN(auto replicated_buffers, replicated_basic_string_array->buffers().Await()); ASSERT_EQ(replicated_buffers.size(), 1); - EXPECT_THAT(replicated_buffers[0], testing::ElementsAre(kReplicatedContents)); + EXPECT_THAT(replicated_buffers[0], ElementsAre(kReplicatedContents)); } TEST(FullyReplicatedShardTest, FailsWithNonFullyReplicatedArrays) { @@ -977,6 +973,127 @@ TEST(LayoutTest, FailsAfterDeletion) { EXPECT_THAT(array->layout(), StatusIs(absl::StatusCode::kFailedPrecondition)); } +///////////////////////////////////////////////////////////////////////////// +// +// Tests related to CopyToHostBuffer +// + +TEST(CopyToHostBufferTest, Success) { + TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); + auto devices = client->addressable_devices(); + ASSERT_GE(devices.size(), 1); + std::vector input_data = {"abc", "def"}; + TF_ASSERT_OK_AND_ASSIGN( + auto array, + MakeSingleDeviceStringTestArray(input_data, client.get(), devices[0])); + + auto data_read = std::make_unique>(input_data.size()); + TF_ASSERT_OK(array + ->CopyToHostBuffer(data_read->data(), + /*byte_strides=*/std::nullopt, + ArrayCopySemantics::kAlwaysCopy) + .Await()); + EXPECT_THAT(*data_read, ElementsAreArray(input_data)); +} + +TEST(CopyToHostBufferTest, FailsAfterDeletion) { + TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); + auto devices = client->addressable_devices(); + ASSERT_GE(devices.size(), 1); + std::vector input_data = {"abc", "def"}; + TF_ASSERT_OK_AND_ASSIGN( + auto array, + MakeSingleDeviceStringTestArray(input_data, client.get(), devices[0])); + + TF_ASSERT_OK(array->Delete().Await()); + + auto data_read = std::make_unique>(input_data.size()); + EXPECT_THAT(array + ->CopyToHostBuffer(data_read->data(), + /*byte_strides=*/std::nullopt, + ArrayCopySemantics::kAlwaysCopy) + .Await(), + StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +TEST(CopyToHostBufferTest, FailsWithMultiDeviceShardedArray) { + TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); + auto devices = client->addressable_devices(); + ASSERT_GE(devices.size(), 2); + std::vector per_shard_data = {"shard-0", "shard-1"}; + TF_ASSERT_OK_AND_ASSIGN( + auto array, MakeShardedStringTestArray(client.get(), per_shard_data, + /*is_fully_replicated=*/false)); + + auto data_read = + std::make_unique>(per_shard_data.size()); + EXPECT_THAT(array + ->CopyToHostBuffer(data_read->data(), + /*byte_strides=*/std::nullopt, + ArrayCopySemantics::kAlwaysCopy) + .Await(), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(CopytoHostBufferTest, + WorksWithNonReadySourceArrayThatSuccessfullyBecomesReadyAfterCreation) { + TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); + auto devices = client->addressable_devices(); + ASSERT_GE(devices.size(), 1); + auto buf_and_on_done_with_buffer = MakeBuffersAndOnDoneWithBuffer({"abc"}); + auto buffers = buf_and_on_done_with_buffer.first; + auto on_done_with_buffer = buf_and_on_done_with_buffer.second; + TF_ASSERT_OK_AND_ASSIGN( + auto ret, CreateNonReadyTestArray(client.get(), devices[0], + std::move(on_done_with_buffer))); + auto array = ret.first; + auto promise = std::move(ret.second); + + auto data_read = std::make_unique>(1); + auto copy_completion_future = + array->CopyToHostBuffer(data_read->data(), /*byte_strides=*/std::nullopt, + ArrayCopySemantics::kAlwaysCopy); + + absl::Notification done_readying_single_device_arrays; + tsl::Env::Default()->SchedClosure(([&]() mutable { + promise.Set(std::move(buffers)); + done_readying_single_device_arrays.Notify(); + })); + + done_readying_single_device_arrays.WaitForNotification(); + + TF_ASSERT_OK(copy_completion_future.Await()); + EXPECT_THAT(*data_read, ElementsAre("abc")); +} + +TEST(CopytoHostBufferTest, + WorksWithNonReadySourceArrayThatFailsToBecomeReadyAfterCreation) { + TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); + auto devices = client->addressable_devices(); + ASSERT_GE(devices.size(), 1); + TF_ASSERT_OK_AND_ASSIGN( + auto ret, CreateNonReadyTestArray(client.get(), devices[0], + /*on_done_with_buffer=*/[]() {})); + auto array = ret.first; + auto promise = std::move(ret.second); + + auto data_read = std::make_unique>(1); + auto copy_completion_future = + array->CopyToHostBuffer(data_read->data(), /*byte_strides=*/std::nullopt, + ArrayCopySemantics::kAlwaysCopy); + + absl::Notification done_readying_single_device_arrays; + tsl::Env::Default()->SchedClosure(([&]() mutable { + promise.Set(absl::InternalError("injected from the test")); + done_readying_single_device_arrays.Notify(); + })); + + done_readying_single_device_arrays.WaitForNotification(); + + EXPECT_THAT(copy_completion_future.Await(), + StatusIs(absl::StatusCode::kInternal)); +} + } // namespace } // namespace ifrt } // namespace xla diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc index 4b0949165e75fe..d8fb6650449628 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc @@ -45,6 +45,7 @@ limitations under the License. #include "xla/python/ifrt/sharding.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/pjrt_ifrt/pjrt_memory.h" #include "xla/shape_util.h" #include "xla/status_macros.h" @@ -130,82 +131,6 @@ absl::StatusOr GetMemoryKindFromPjRtBuffers( char PjRtCompatibleArray::ID = 0; char PjRtArray::ID = 0; -absl::StatusOr ToPrimitiveType(DType dtype) { - switch (dtype.kind()) { -#define CASE(DT, PT) \ - case DT: \ - static_assert(PT == \ - static_cast(static_cast(DT))); \ - return PT - CASE(DType::kInvalid, xla::PrimitiveType::PRIMITIVE_TYPE_INVALID); - CASE(DType::kPred, xla::PrimitiveType::PRED); - CASE(DType::kS2, xla::PrimitiveType::S2); - CASE(DType::kS4, xla::PrimitiveType::S4); - CASE(DType::kS8, xla::PrimitiveType::S8); - CASE(DType::kS16, xla::PrimitiveType::S16); - CASE(DType::kS32, xla::PrimitiveType::S32); - CASE(DType::kS64, xla::PrimitiveType::S64); - CASE(DType::kU2, xla::PrimitiveType::U2); - CASE(DType::kU4, xla::PrimitiveType::U4); - CASE(DType::kU8, xla::PrimitiveType::U8); - CASE(DType::kU16, xla::PrimitiveType::U16); - CASE(DType::kU32, xla::PrimitiveType::U32); - CASE(DType::kU64, xla::PrimitiveType::U64); - CASE(DType::kF8E4M3FN, xla::PrimitiveType::F8E4M3FN); - CASE(DType::kF8E4M3B11FNUZ, xla::PrimitiveType::F8E4M3B11FNUZ); - CASE(DType::kF8E4M3FNUZ, xla::PrimitiveType::F8E4M3FNUZ); - CASE(DType::kF8E5M2, xla::PrimitiveType::F8E5M2); - CASE(DType::kF8E5M2FNUZ, xla::PrimitiveType::F8E5M2FNUZ); - CASE(DType::kF16, xla::PrimitiveType::F16); - CASE(DType::kF32, xla::PrimitiveType::F32); - CASE(DType::kBF16, xla::PrimitiveType::BF16); - CASE(DType::kF64, xla::PrimitiveType::F64); - CASE(DType::kC64, xla::PrimitiveType::C64); - CASE(DType::kC128, xla::PrimitiveType::C128); - CASE(DType::kToken, xla::PrimitiveType::TOKEN); -#undef CASE - case DType::kString: - return InvalidArgument("Not supported as XLA PrimitiveType: %d", - static_cast(dtype.kind())); - } - return InvalidArgument("Invalid DType: %d", static_cast(dtype.kind())); -} - -absl::StatusOr ToDType(xla::PrimitiveType primitive_type) { - switch (primitive_type) { - case xla::PrimitiveType::PRIMITIVE_TYPE_INVALID: - case xla::PrimitiveType::PRED: - case xla::PrimitiveType::S2: - case xla::PrimitiveType::S4: - case xla::PrimitiveType::S8: - case xla::PrimitiveType::S16: - case xla::PrimitiveType::S32: - case xla::PrimitiveType::S64: - case xla::PrimitiveType::U2: - case xla::PrimitiveType::U4: - case xla::PrimitiveType::U8: - case xla::PrimitiveType::U16: - case xla::PrimitiveType::U32: - case xla::PrimitiveType::U64: - case xla::PrimitiveType::F8E4M3FN: - case xla::PrimitiveType::F8E4M3B11FNUZ: - case xla::PrimitiveType::F8E4M3FNUZ: - case xla::PrimitiveType::F8E5M2: - case xla::PrimitiveType::F8E5M2FNUZ: - case xla::PrimitiveType::F16: - case xla::PrimitiveType::F32: - case xla::PrimitiveType::BF16: - case xla::PrimitiveType::F64: - case xla::PrimitiveType::C64: - case xla::PrimitiveType::C128: - case xla::PrimitiveType::TOKEN: - return DType(static_cast(static_cast(primitive_type))); - default: - return InvalidArgument("Invalid XLA PrimitiveType: %d", - static_cast(primitive_type)); - } -} - MemoryKind MakeMemoryKindFromPjRtBuffer(PjRtBuffer* pjrt_buffer) { if (pjrt_buffer->memory_space() == nullptr) { return MemoryKind(); @@ -344,13 +269,31 @@ PjRtArray::PjRtArray(PjRtCompatibleClient* client, DType dtype, absl::StatusOr>> PjRtArray::DisassembleIntoSingleDeviceArrays(ArrayCopySemantics semantics) { DCHECK(this); + return DisassembleIntoSingleDeviceArrays( + semantics, SingleDeviceShardSemantics::kAllShards); +} + +absl::StatusOr>> +PjRtArray::DisassembleIntoSingleDeviceArrays( + ArrayCopySemantics semantics, + SingleDeviceShardSemantics single_device_shard_semantics) { + DCHECK(this); + if (single_device_shard_semantics == SingleDeviceShardSemantics::kAllShards && + !sharding_->devices()->IsFullyAddressable()) { + return InvalidArgument( + "All shards are requested but the sharding has non-addressable " + "devices: %v", + *sharding_->devices()); + } std::vector> result; - result.reserve(sharding_->devices()->size()); + result.reserve(sharding_->devices()->AddressableDeviceList()->size()); TF_RETURN_IF_ERROR(std::visit( [&](const auto& this_shape) { - TF_ASSIGN_OR_RETURN(auto shape_and_shardings, - sharding_->Disassemble(this_shape)); - for (int i = 0; i < sharding_->devices()->size(); ++i) { + TF_ASSIGN_OR_RETURN( + auto shape_and_shardings, + sharding_->Disassemble( + this_shape, SingleDeviceShardSemantics::kAddressableShards)); + for (int i = 0; i < shape_and_shardings.size(); ++i) { PjRtBuffers buffers; buffers.reserve(1); buffers.push_back(GetPjRtBuffer(semantics, i)); diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h index f91de6531c234e..d14747fea550ea 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.h @@ -44,12 +44,6 @@ limitations under the License. namespace xla { namespace ifrt { -// Converts IFRT `DType` into `xla::PrimitiveType`. -absl::StatusOr ToPrimitiveType(DType dtype); - -// Converts `xla::PrimitiveType` into IFRT `DType`. -absl::StatusOr ToDType(xla::PrimitiveType primitive_type); - // Creates IFRT `MemoryKind` from an XLA `PjRtBuffer`. MemoryKind MakeMemoryKindFromPjRtBuffer(PjRtBuffer* pjrt_buffer); @@ -161,6 +155,10 @@ class PjRtArray final absl::StatusOr>> DisassembleIntoSingleDeviceArrays(ArrayCopySemantics semantics) override; + absl::StatusOr>> + DisassembleIntoSingleDeviceArrays( + ArrayCopySemantics array_copy_semantics, + SingleDeviceShardSemantics single_device_shard_semantics) override; ABSL_MUST_USE_RESULT Future<> CopyToHostBuffer( diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc index b6aed371cca9dc..a9603fe5098ddd 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.cc @@ -69,6 +69,7 @@ limitations under the License. #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" #include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/pjrt_ifrt/pjrt_memory.h" #include "xla/python/pjrt_ifrt/pjrt_remap.h" #include "xla/python/pjrt_ifrt/pjrt_topology.h" @@ -184,21 +185,17 @@ absl::StatusOr> MakeStringArrayFromHostBuffer( TF_RETURN_IF_ERROR(param_validation); auto num_elements = shape.num_elements(); - auto strings = std::make_shared>(); + auto strings = std::make_shared>(); strings->reserve(num_elements); - auto string_views = std::make_shared>(); - string_views->reserve(num_elements); - auto element = static_cast(data); + auto element = static_cast(data); for (int i = 0; i < num_elements; ++i, ++element) { - strings->push_back(std::string(*element)); - string_views->push_back(absl::string_view(strings->back())); + strings->push_back(*element); } std::move(on_done_with_host_buffer)(); BasicStringArray::Buffers buffers; - buffers.push_back(*string_views); - auto buffer_releaser = [strings = std::move(strings), - string_views = std::move(string_views)]() {}; + buffers.push_back(*strings); + auto buffer_releaser = [strings = std::move(strings)]() {}; return BasicStringArray::Create( client, std::move(shape), std::move(sharding), @@ -209,33 +206,35 @@ absl::StatusOr> MakeStringArrayFromHostBuffer( absl::StatusOr> AssembleStringArrayFromSingleDeviceStringArrays( Shape shape, std::shared_ptr sharding, - absl::Span> arrays, ArrayCopySemantics semantics) { + absl::Span> arrays, + ArrayCopySemantics array_copy_semantics, + SingleDeviceShardSemantics single_device_shard_semantics) { + if (single_device_shard_semantics == SingleDeviceShardSemantics::kAllShards && + !sharding->devices()->IsFullyAddressable()) { + return InvalidArgument( + "All shards are requested but the sharding has non-addressable " + "devices: %v", + *sharding->devices()); + } // BufferBackingState contains the per-shard vectors of the strings and // string_views underlying a BasicString::Buffer. Not thread safe. struct BufferBackingStore { explicit BufferBackingStore(int num_shards) - : per_shard_strings(num_shards), per_shard_string_views(num_shards) {} + : per_shard_strings(num_shards) {} void clear() { per_shard_strings.clear(); - per_shard_string_views.clear(); } - void CopyBuffer(absl::Span strbuf, int shard_index, + + void CopyBuffer(absl::Span strbuf, int shard_index, BasicStringArray::Buffers* buffers) { auto& strings = per_shard_strings[shard_index]; strings.reserve(strbuf.size()); - auto& views = per_shard_string_views[shard_index]; - views.reserve(strbuf.size()); - for (int i = 0; i < strbuf.size(); ++i) { - strings.push_back(std::string(strbuf[i].data(), strbuf[i].size())); + strings.push_back(strbuf[i]); } - for (const auto& str : strings) { - views.push_back(str); - } - (*buffers)[shard_index] = absl::MakeConstSpan(views); + (*buffers)[shard_index] = absl::MakeConstSpan(strings); } - std::vector> per_shard_strings; - std::vector> per_shard_string_views; + std::vector> per_shard_strings; }; auto buffer_backing_store = std::make_shared(sharding->devices()->size()); @@ -631,6 +630,18 @@ PjRtClient::AssembleArrayFromSingleDeviceArrays( Shape shape, std::shared_ptr sharding, absl::Span> arrays, ArrayCopySemantics semantics) { DCHECK(this); + return AssembleArrayFromSingleDeviceArrays( + std::move(shape), std::move(sharding), arrays, semantics, + SingleDeviceShardSemantics::kAllShards); +} + +absl::StatusOr> +PjRtClient::AssembleArrayFromSingleDeviceArrays( + Shape shape, std::shared_ptr sharding, + absl::Span> arrays, + ArrayCopySemantics array_copy_semantics, + SingleDeviceShardSemantics single_device_shard_semantics) { + DCHECK(this); if (llvm::isa(sharding.get())) { // Assemble with SingleDeviceSharding is No-op. if (arrays.size() != 1) { @@ -649,15 +660,23 @@ PjRtClient::AssembleArrayFromSingleDeviceArrays( "supported: sharding=%s", sharding->DebugString()); } - if (sharding->devices()->size() != arrays.size()) { + if (single_device_shard_semantics == SingleDeviceShardSemantics::kAllShards && + !sharding->devices()->IsFullyAddressable()) { + return InvalidArgument( + "All shards are requested but the sharding has non-addressable " + "devices: %v", + *sharding->devices()); + } + if (sharding->devices()->AddressableDeviceList()->size() != arrays.size()) { return InvalidArgument( - "Number of output shards must match the number of single-shard " - "arrays: %d vs. %d", - sharding->devices()->size(), arrays.size()); + "Number of addressable output shards must match the number of " + "single-shard arrays: %d vs. %d", + sharding->devices()->AddressableDeviceList()->size(), arrays.size()); } if (arrays[0]->dtype().kind() == DType::kString) { - return AssembleStringArrayFromSingleDeviceStringArrays(shape, sharding, - arrays, semantics); + return AssembleStringArrayFromSingleDeviceStringArrays( + shape, sharding, arrays, array_copy_semantics, + single_device_shard_semantics); } PjRtArray::PjRtBuffers buffers; buffers.reserve(arrays.size()); @@ -681,7 +700,7 @@ PjRtClient::AssembleArrayFromSingleDeviceArrays( "sharding=%s", i, array->sharding().DebugString()); } - switch (semantics) { + switch (array_copy_semantics) { case ArrayCopySemantics::kAlwaysCopy: // TODO(hyeontaek): kAlwaysCopy should clone the buffer, but the PjRt // API does not have efficient buffer cloning on the same device. diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.h b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.h index 23900c049f344c..91cbae4155b438 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.h +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_client.h @@ -135,11 +135,12 @@ class PjRtClient final ~PjRtClient() override; // For making Arrays with `dtype` as kString: - // (1) the `data` argument should point to an array of `absl::string_view` + // (1) the `data` argument should point to an array of `absl::Cord` // in major-to-minor order, // (2) `byte_strides` are not supported, and non-`nullopt` values cause this // function to fail. // (3) only the `kImmutableDuringCall` semantics is supported currently. + // Fails for other values of `HostBufferSemantics`. absl::StatusOr> MakeArrayFromHostBuffer( const void* data, DType dtype, Shape shape, std::optional> byte_strides, @@ -151,6 +152,11 @@ class PjRtClient final Shape shape, std::shared_ptr sharding, absl::Span> arrays, ArrayCopySemantics semantics) override; + absl::StatusOr> AssembleArrayFromSingleDeviceArrays( + Shape shape, std::shared_ptr sharding, + absl::Span> arrays, + ArrayCopySemantics array_copy_semantics, + SingleDeviceShardSemantics single_device_shard_semantics) override; absl::StatusOr>> CopyArrays( absl::Span> arrays, @@ -169,10 +175,7 @@ class PjRtClient final absl::StatusOr> MakeTuple( absl::Span> values) override; - absl::string_view runtime_type() const override { - DCHECK(this); - return PjRtRuntimeTypeString(pjrt_client_->runtime_type()); - } + absl::string_view runtime_type() const override { return "pjrt_ifrt"; } absl::string_view platform_name() const override { DCHECK(this); @@ -206,6 +209,12 @@ class PjRtClient final return addressable_devices_; } int process_index() const override { return pjrt_client_->process_index(); } + + absl::Span GetAllDevices() const override { + DCHECK(this); + return devices_; + } + absl::StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const override { DCHECK(this); diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_dtype.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_dtype.cc new file mode 100644 index 00000000000000..10a293778bd467 --- /dev/null +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_dtype.cc @@ -0,0 +1,107 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" + +#include "absl/status/statusor.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace ifrt { + +absl::StatusOr ToPrimitiveType(DType dtype) { + switch (dtype.kind()) { +#define CASE(DT, PT) \ + case DT: \ + static_assert(PT == \ + static_cast(static_cast(DT))); \ + return PT + CASE(DType::kInvalid, xla::PrimitiveType::PRIMITIVE_TYPE_INVALID); + CASE(DType::kPred, xla::PrimitiveType::PRED); + CASE(DType::kS2, xla::PrimitiveType::S2); + CASE(DType::kS4, xla::PrimitiveType::S4); + CASE(DType::kS8, xla::PrimitiveType::S8); + CASE(DType::kS16, xla::PrimitiveType::S16); + CASE(DType::kS32, xla::PrimitiveType::S32); + CASE(DType::kS64, xla::PrimitiveType::S64); + CASE(DType::kU2, xla::PrimitiveType::U2); + CASE(DType::kU4, xla::PrimitiveType::U4); + CASE(DType::kU8, xla::PrimitiveType::U8); + CASE(DType::kU16, xla::PrimitiveType::U16); + CASE(DType::kU32, xla::PrimitiveType::U32); + CASE(DType::kU64, xla::PrimitiveType::U64); + CASE(DType::kF8E3M4, xla::PrimitiveType::F8E3M4); + CASE(DType::kF8E4M3, xla::PrimitiveType::F8E4M3); + CASE(DType::kF8E4M3FN, xla::PrimitiveType::F8E4M3FN); + CASE(DType::kF8E4M3B11FNUZ, xla::PrimitiveType::F8E4M3B11FNUZ); + CASE(DType::kF8E4M3FNUZ, xla::PrimitiveType::F8E4M3FNUZ); + CASE(DType::kF8E5M2, xla::PrimitiveType::F8E5M2); + CASE(DType::kF8E5M2FNUZ, xla::PrimitiveType::F8E5M2FNUZ); + CASE(DType::kF16, xla::PrimitiveType::F16); + CASE(DType::kF32, xla::PrimitiveType::F32); + CASE(DType::kBF16, xla::PrimitiveType::BF16); + CASE(DType::kF64, xla::PrimitiveType::F64); + CASE(DType::kC64, xla::PrimitiveType::C64); + CASE(DType::kC128, xla::PrimitiveType::C128); + CASE(DType::kToken, xla::PrimitiveType::TOKEN); +#undef CASE + case DType::kString: + return InvalidArgument("Not supported as XLA PrimitiveType: %d", + static_cast(dtype.kind())); + } + return InvalidArgument("Invalid DType: %d", static_cast(dtype.kind())); +} + +absl::StatusOr ToDType(xla::PrimitiveType primitive_type) { + switch (primitive_type) { + case xla::PrimitiveType::PRIMITIVE_TYPE_INVALID: + case xla::PrimitiveType::PRED: + case xla::PrimitiveType::S2: + case xla::PrimitiveType::S4: + case xla::PrimitiveType::S8: + case xla::PrimitiveType::S16: + case xla::PrimitiveType::S32: + case xla::PrimitiveType::S64: + case xla::PrimitiveType::U2: + case xla::PrimitiveType::U4: + case xla::PrimitiveType::U8: + case xla::PrimitiveType::U16: + case xla::PrimitiveType::U32: + case xla::PrimitiveType::U64: + case xla::PrimitiveType::F8E3M4: + case xla::PrimitiveType::F8E4M3: + case xla::PrimitiveType::F8E4M3FN: + case xla::PrimitiveType::F8E4M3B11FNUZ: + case xla::PrimitiveType::F8E4M3FNUZ: + case xla::PrimitiveType::F8E5M2: + case xla::PrimitiveType::F8E5M2FNUZ: + case xla::PrimitiveType::F16: + case xla::PrimitiveType::F32: + case xla::PrimitiveType::BF16: + case xla::PrimitiveType::F64: + case xla::PrimitiveType::C64: + case xla::PrimitiveType::C128: + case xla::PrimitiveType::TOKEN: + return DType(static_cast(static_cast(primitive_type))); + default: + return InvalidArgument("Invalid XLA PrimitiveType: %d", + static_cast(primitive_type)); + } +} + +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_dtype.h similarity index 57% rename from third_party/xla/xla/service/spmd/shardy/shardy_call_inliner.cc rename to third_party/xla/xla/python/pjrt_ifrt/pjrt_dtype.h index 9f863e23a6715d..f0ace0292e82dd 100644 --- a/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_dtype.h @@ -13,19 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/spmd/shardy/shardy_call_inliner.h" +#ifndef XLA_PYTHON_PJRT_IFRT_PJRT_DTYPE_H_ +#define XLA_PYTHON_PJRT_IFRT_PJRT_DTYPE_H_ -#include "absl/strings/match.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/call_inliner.h" +#include "absl/status/statusor.h" +#include "xla/python/ifrt/dtype.h" +#include "xla/xla_data.pb.h" namespace xla { +namespace ifrt { -bool ShardyCallInliner::IsInlineableCallOp(HloInstruction* instruction) const { - return CallInliner::IsInlineableCallOp(instruction) && - !instruction->has_backend_config() && - !(instruction->GetModule()->config().use_shardy_partitioner() && - absl::StrContains(instruction->to_apply()->name(), "shmap_body")); -} +// Converts IFRT `DType` into `xla::PrimitiveType`. +absl::StatusOr ToPrimitiveType(DType dtype); +// Converts `xla::PrimitiveType` into IFRT `DType`. +absl::StatusOr ToDType(xla::PrimitiveType primitive_type); + +} // namespace ifrt } // namespace xla + +#endif // XLA_PYTHON_PJRT_IFRT_PJRT_DTYPE_H_ diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc index c19fdece16f028..60f0e6bba78b0c 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/pjrt/host_callback.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" @@ -50,6 +51,7 @@ limitations under the License. #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/pjrt_ifrt/pjrt_host_callback.h" #include "xla/python/pjrt_ifrt/pjrt_memory.h" #include "xla/python/pjrt_ifrt/xla_compiler.h" @@ -57,7 +59,6 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -541,7 +542,11 @@ PjRtLoadedExecutable::Execute( const bool returned_future_supported = pjrt_loaded_executable_->IsReturnedFutureSupported(); - auto opts = options; + xla::ExecuteOptions opts; + opts.untuple_result = true; + opts.launch_id = options.launch_id; + opts.use_major_to_minor_data_layout_for_callbacks = true; + opts.non_donatable_input_indices = options.non_donatable_input_indices; if (!all_loaded_host_callbacks_->empty() && !returned_future_supported) { return Internal( @@ -564,9 +569,7 @@ PjRtLoadedExecutable::Execute( contexts.push_back(CreateHostCallbackStateAndAppendSendRecvCallbacks( host_send_recv_callback->host_callback(), /*host_memory_for_device_manager=*/nullptr, send_callbacks, - recv_callbacks, - /*use_major_to_minor_data_layout_for_callbacks=*/ - options.use_major_to_minor_data_layout_for_callbacks)); + recv_callbacks, opts.use_major_to_minor_data_layout_for_callbacks)); } } opts.send_callbacks = host_callback_states->send_callbacks; @@ -575,7 +578,7 @@ PjRtLoadedExecutable::Execute( // Execute the computation. std::vector>> pjrt_outputs; - ExecuteResult result; + xla::ifrt::Future<> status; if (portable_execution) { std::optional> returned_pjrt_future; TF_RET_CHECK(portable_execution_device->IsAddressable()); @@ -588,9 +591,9 @@ PjRtLoadedExecutable::Execute( pjrt_outputs.push_back(std::move(single_device_pjrt_results)); if (returned_future_supported) { - result.status = *std::move(returned_pjrt_future); + status = *std::move(returned_pjrt_future); } else { - result.status = Future<>(absl::OkStatus()); + status = Future<>(absl::OkStatus()); } } else { std::optional>> returned_pjrt_futures; @@ -603,9 +606,9 @@ PjRtLoadedExecutable::Execute( returned_pjrt_futures)); if (returned_future_supported) { - result.status = JoinFutures(absl::MakeSpan(*returned_pjrt_futures)); + status = JoinFutures(absl::MakeSpan(*returned_pjrt_futures)); } else { - result.status = Future<>(absl::OkStatus()); + status = Future<>(absl::OkStatus()); } } @@ -613,10 +616,11 @@ PjRtLoadedExecutable::Execute( // For host callbacks to work, returned futures must be supported so that we // can use the futures to extend the lifetime of the host callbacks until // the execution finishes. - result.status.OnReady( - [all_loaded_host_callbacks = all_loaded_host_callbacks_, - host_callback_states = std::move(host_callback_states)]( - absl::Status) mutable { all_loaded_host_callbacks.reset(); }); + status.OnReady([all_loaded_host_callbacks = all_loaded_host_callbacks_, + host_callback_states = + std::move(host_callback_states)](absl::Status) mutable { + all_loaded_host_callbacks.reset(); + }); } // Convert 2-level PjRtBuffer vectors into an Array vector. @@ -682,6 +686,11 @@ PjRtLoadedExecutable::Execute( output_shapes_[i], std::move(sharding), std::move(buffers))); } + + ExecuteResult result; + if (options.fill_status) { + result.status = status; + } result.outputs = std::move(outputs); return result; } diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.h b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.h index 6ebd1f9e903481..ce83ee0da24de1 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.h +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.h @@ -116,14 +116,14 @@ class PjRtExecutable final return pjrt_executable_->GetOutputShardings(); } - absl::StatusOr>> GetParameterLayouts() - const override { + absl::StatusOr>> + GetParameterLayouts() const override { DCHECK(this); return pjrt_executable_->GetParameterLayouts(); } - absl::StatusOr>> GetOutputLayouts() - const override { + absl::StatusOr>> + GetOutputLayouts() const override { DCHECK(this); return pjrt_executable_->GetOutputLayouts(); } @@ -242,14 +242,14 @@ class PjRtLoadedExecutable final return pjrt_loaded_executable_->GetOutputShardings(); } - absl::StatusOr>> GetParameterLayouts() - const override { + absl::StatusOr>> + GetParameterLayouts() const override { DCHECK(this); return pjrt_loaded_executable_->GetParameterLayouts(); } - absl::StatusOr>> GetOutputLayouts() - const override { + absl::StatusOr>> + GetOutputLayouts() const override { DCHECK(this); return pjrt_loaded_executable_->GetOutputLayouts(); } diff --git a/third_party/xla/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc b/third_party/xla/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc index 356c2f9e5d2f3a..4fb1ca36e33a50 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/tfrt_cpu_client_test_lib.cc @@ -24,16 +24,17 @@ namespace xla { namespace ifrt { namespace { -const bool kUnused = (test_util::RegisterClientFactory( - []() -> absl::StatusOr> { - CpuClientOptions options; - options.cpu_device_count = 4; - TF_ASSIGN_OR_RETURN(auto pjrt_client, - xla::GetTfrtCpuClient(options)); - return std::shared_ptr( - PjRtClient::Create(std::move(pjrt_client))); - }), - true); +const bool kUnused = + (test_util::RegisterClientFactory( + []() -> absl::StatusOr> { + CpuClientOptions options; + options.cpu_device_count = 4; + TF_ASSIGN_OR_RETURN(auto pjrt_client, + xla::GetTfrtCpuClient(std::move(options))); + return std::shared_ptr( + PjRtClient::Create(std::move(pjrt_client))); + }), + true); } // namespace } // namespace ifrt diff --git a/third_party/xla/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc b/third_party/xla/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc index 576a46168dea42..0921f766063ae9 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc @@ -133,6 +133,7 @@ TEST(LoadedExecutableImplTest, CompileAndExecute) { /*on_done_with_host_buffer=*/{})); ExecuteOptions execute_options; + execute_options.fill_status = true; TF_ASSERT_OK_AND_ASSIGN( LoadedExecutable::ExecuteResult result, loaded_executable->Execute(absl::MakeSpan(&array, 1), execute_options, @@ -177,6 +178,7 @@ TEST(LoadedExecutableImplTest, CompileAndExecutePortable) { /*on_done_with_host_buffer=*/{})); ExecuteOptions execute_options; + execute_options.fill_status = true; TF_ASSERT_OK_AND_ASSIGN(LoadedExecutable::ExecuteResult result, loaded_executable->Execute( absl::MakeSpan(&array, 1), execute_options, @@ -195,6 +197,51 @@ TEST(LoadedExecutableImplTest, CompileAndExecutePortable) { EXPECT_THAT(out_data, ElementsAreArray(expected_out_data)); } +TEST(LoadedExecutableImplTest, DoNotFillStatus) { + TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); + Compiler* compiler = client->GetDefaultCompiler(); + + std::vector devices = {client->addressable_devices().at(0)}; + TF_ASSERT_OK_AND_ASSIGN( + auto loaded_executable, + CompileOnDevices(client.get(), compiler, module_add_one, devices, + /*replicated=*/false)); + + DType dtype(DType::kF32); + Shape shape({2, 3}); + std::vector data(6); + std::iota(data.begin(), data.end(), 0); + Device* device = client->addressable_devices().at(0); + std::shared_ptr sharding = + SingleDeviceSharding::Create(device, MemoryKind()); + + TF_ASSERT_OK_AND_ASSIGN( + auto array, client->MakeArrayFromHostBuffer( + data.data(), dtype, shape, + /*byte_strides=*/std::nullopt, sharding, + Client::HostBufferSemantics::kImmutableOnlyDuringCall, + /*on_done_with_host_buffer=*/{})); + + ExecuteOptions execute_options; + execute_options.fill_status = false; + TF_ASSERT_OK_AND_ASSIGN( + LoadedExecutable::ExecuteResult result, + loaded_executable->Execute(absl::MakeSpan(&array, 1), execute_options, + /*devices=*/std::nullopt)); + EXPECT_FALSE(result.status.IsValid()); + EXPECT_THAT(result.outputs, SizeIs(1)); + + std::vector out_data(6); + auto future = result.outputs[0]->CopyToHostBuffer( + out_data.data(), /*byte_strides=*/std::nullopt, + ArrayCopySemantics::kAlwaysCopy); + TF_ASSERT_OK(future.Await()); + + std::vector expected_out_data(6); + std::iota(expected_out_data.begin(), expected_out_data.end(), 1); + EXPECT_THAT(out_data, ElementsAreArray(expected_out_data)); +} + TEST(LoadedExecutableImplTest, Delete) { TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient()); Compiler* compiler = client->GetDefaultCompiler(); diff --git a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc index 15c6da57ebddb1..38d071a3ec6161 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc @@ -70,7 +70,8 @@ bool NextIndex(Index::Elements* index, absl::Span limit) { // Note that this is O(N^2) where N is the number of devices (shards). std::vector IndexDomainsSlowPath( const xla::HloSharding& hlo_sharding, - const tsl::RCReference& devices, const Shape& shape) { + const tsl::RCReference& devices, const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) { // Only shape dimensions are used. auto xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout( xla::PrimitiveType::S32, shape.dims()); @@ -85,14 +86,20 @@ std::vector IndexDomainsSlowPath( Index::Elements origin(shape.dims().size()); Shape::Dimensions shard_shape(shape.dims().size()); - for (int device_idx = 0; device_idx < devices->size(); ++device_idx) { - auto tile_offset = hlo_sharding.TileOffsetForDevice(xla_shape, device_idx); - auto tile_limit = hlo_sharding.TileLimitForDevice(xla_shape, device_idx); - for (int i = 0; i < shape.dims().size(); ++i) { - origin[i] = tile_offset[i]; - shard_shape[i] = tile_limit[i] - tile_offset[i]; + const absl::Span device_ptrs = devices->devices(); + for (int device_idx = 0; device_idx < device_ptrs.size(); ++device_idx) { + if (single_device_shard_semantics == + SingleDeviceShardSemantics::kAllShards || + device_ptrs[device_idx]->IsAddressable()) { + auto tile_offset = + hlo_sharding.TileOffsetForDevice(xla_shape, device_idx); + auto tile_limit = hlo_sharding.TileLimitForDevice(xla_shape, device_idx); + for (int i = 0; i < shape.dims().size(); ++i) { + origin[i] = tile_offset[i]; + shard_shape[i] = tile_limit[i] - tile_offset[i]; + } + result.push_back(IndexDomain(Index(origin), Shape(shard_shape))); } - result.push_back(IndexDomain(Index(origin), Shape(shard_shape))); } return result; } @@ -179,9 +186,17 @@ absl::StatusOr> HloSharding::WithDeviceAssignment( return Create(devices.value_or(devices_), memory_kind.value_or(memory_kind_), xla_hlo_sharding_); } - absl::StatusOr>>> HloSharding::Disassemble(const Shape& shape) const { + DCHECK(this); + return Disassemble(shape, SingleDeviceShardSemantics::kAllShards); +} + +absl::StatusOr>>> +HloSharding::Disassemble( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const { + DCHECK(this); bool is_even_sharding = false; if (xla_hlo_sharding_.IsReplicated() || xla_hlo_sharding_.IsTileMaximal()) { is_even_sharding = true; @@ -212,12 +227,21 @@ HloSharding::Disassemble(const Shape& shape) const { // Fast path for even sharding. TF_ASSIGN_OR_RETURN(xla::ifrt::Shape shard_shape, GetShardShape(shape)); std::vector>> result; - result.reserve(devices_->size()); + if (single_device_shard_semantics == + SingleDeviceShardSemantics::kAllShards) { + result.reserve(devices_->size()); + } else { + result.reserve(devices_->AddressableDeviceList()->size()); + } for (int i = 0; i < devices_->size(); ++i) { - result.push_back({ - shard_shape, - SingleDeviceSharding::Create(devices[i], memory_kind_), - }); + if (single_device_shard_semantics == + SingleDeviceShardSemantics::kAllShards || + devices[i]->IsAddressable()) { + result.push_back({ + shard_shape, + SingleDeviceSharding::Create(devices[i], memory_kind_), + }); + } } return result; } else { @@ -226,12 +250,21 @@ HloSharding::Disassemble(const Shape& shape) const { IndexDomains(shape)); CHECK_EQ(index_domains.size(), devices_->size()); std::vector>> result; - result.reserve(index_domains.size()); + if (single_device_shard_semantics == + SingleDeviceShardSemantics::kAllShards) { + result.reserve(devices_->size()); + } else { + result.reserve(devices_->AddressableDeviceList()->size()); + } for (int i = 0; i < index_domains.size(); ++i) { - result.push_back({ - index_domains[i].shape(), - SingleDeviceSharding::Create(devices[i], memory_kind_), - }); + if (single_device_shard_semantics == + SingleDeviceShardSemantics::kAllShards || + devices[i]->IsAddressable()) { + result.push_back({ + index_domains[i].shape(), + SingleDeviceSharding::Create(devices[i], memory_kind_), + }); + } } return result; } @@ -240,6 +273,16 @@ HloSharding::Disassemble(const Shape& shape) const { absl::StatusOr< std::vector>>> HloSharding::Disassemble(const DynamicShape& dynamic_shape) const { + DCHECK(this); + return Disassemble(dynamic_shape, SingleDeviceShardSemantics::kAllShards); +} + +absl::StatusOr< + std::vector>>> +HloSharding::Disassemble( + const DynamicShape& dynamic_shape, + SingleDeviceShardSemantics single_device_shard_semantics) const { + DCHECK(this); return InvalidArgument( "HloSharding can only disassemble static shape, but was asked " "to disassemble dynamic shape %s", @@ -248,6 +291,13 @@ HloSharding::Disassemble(const DynamicShape& dynamic_shape) const { absl::StatusOr> HloSharding::IndexDomains( const Shape& shape) const { + DCHECK(this); + return IndexDomains(shape, SingleDeviceShardSemantics::kAllShards); +} + +absl::StatusOr> HloSharding::IndexDomains( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const { std::vector result; const int num_devices = devices_->size(); @@ -258,16 +308,24 @@ absl::StatusOr> HloSharding::IndexDomains( if (xla_hlo_sharding_.IsReplicated() || xla_hlo_sharding_.IsTileMaximal()) { // Fast path for a fully replicated or maximal sharding. IndexDomain element(shape); - result.resize(/*count=*/num_devices, /*value=*/element); + if (single_device_shard_semantics == + SingleDeviceShardSemantics::kAllShards) { + result.resize(/*count=*/num_devices, /*value=*/element); + } else { + result.resize(/*count=*/devices_->AddressableDeviceList()->size(), + /*value=*/element); + } return result; } if (!xla_hlo_sharding_.IsTiled()) { - return IndexDomainsSlowPath(xla_hlo_sharding_, devices_, shape); + return IndexDomainsSlowPath(xla_hlo_sharding_, devices_, shape, + single_device_shard_semantics); } for (const xla::OpSharding::Type subgroup_type : xla_hlo_sharding_.subgroup_types()) { if (subgroup_type != xla::OpSharding::REPLICATED) { - return IndexDomainsSlowPath(xla_hlo_sharding_, devices_, shape); + return IndexDomainsSlowPath(xla_hlo_sharding_, devices_, shape, + single_device_shard_semantics); } } if (xla_hlo_sharding_.tile_assignment().num_elements() != num_devices) { @@ -338,16 +396,25 @@ absl::StatusOr> HloSharding::IndexDomains( } } while (NextIndex(&unique_tile_index, tile_assignment_dims)); - result.reserve(num_devices); + if (single_device_shard_semantics == SingleDeviceShardSemantics::kAllShards) { + result.reserve(num_devices); + } else { + result.reserve(devices_->AddressableDeviceList()->size()); + } + const absl::Span devices = devices_->devices(); for (int device_idx = 0; device_idx < num_devices; ++device_idx) { - Shape::Dimensions actual_tile_shape; - actual_tile_shape.reserve(tile_shape_dims.size()); - for (int i = 0; i < tile_shape_dims.size(); ++i) { - actual_tile_shape.push_back(std::min( - tile_shape_dims[i], shape.dims()[i] - origins[device_idx][i])); + if (single_device_shard_semantics == + SingleDeviceShardSemantics::kAllShards || + devices[device_idx]->IsAddressable()) { + Shape::Dimensions actual_tile_shape; + actual_tile_shape.reserve(tile_shape_dims.size()); + for (int i = 0; i < tile_shape_dims.size(); ++i) { + actual_tile_shape.push_back(std::min( + tile_shape_dims[i], shape.dims()[i] - origins[device_idx][i])); + } + result.push_back(IndexDomain(Index(origins[device_idx]), + Shape(std::move(actual_tile_shape)))); } - result.push_back(IndexDomain(Index(origins[device_idx]), - Shape(std::move(actual_tile_shape)))); } return result; } @@ -358,9 +425,11 @@ std::string HloSharding::DebugString() const { } std::vector TEST_HloShardingIndexDomainsSlowPath( - const HloSharding& hlo_sharding, const Shape& shape) { + const HloSharding& hlo_sharding, const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) { return IndexDomainsSlowPath(hlo_sharding.xla_hlo_sharding(), - hlo_sharding.devices(), shape); + hlo_sharding.devices(), shape, + single_device_shard_semantics); } } // namespace ifrt diff --git a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.h b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.h index 1f4e3a3b869a84..44645fdf6f56d6 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.h +++ b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.h @@ -74,12 +74,25 @@ class HloSharding final absl::StatusOr>>> Disassemble(const Shape& shape) const override; + absl::StatusOr>>> + Disassemble( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const override; + absl::StatusOr< std::vector>>> Disassemble(const DynamicShape& dynamic_shape) const override; + absl::StatusOr< + std::vector>>> + Disassemble( + const DynamicShape& dynamic_shape, + SingleDeviceShardSemantics single_device_shard_semantics) const override; absl::StatusOr> IndexDomains( const Shape& shape) const override; + absl::StatusOr> IndexDomains( + const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics) const override; std::string DebugString() const override; @@ -95,7 +108,8 @@ class HloSharding final // Test only: returns `HloSharding::IndexDomains()`, using `xla::HloSharding` // APIs internally. std::vector TEST_HloShardingIndexDomainsSlowPath( - const HloSharding& sharding, const Shape& shape); + const HloSharding& sharding, const Shape& shape, + SingleDeviceShardSemantics single_device_shard_semantics); } // namespace ifrt } // namespace xla diff --git a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding_test.cc b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding_test.cc index c74c7e6168a199..0f64d5c2cf3702 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding_test.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding_test.cc @@ -170,116 +170,302 @@ TEST_P(HloShardingTest, WithDeviceAssignment) { } TEST_P(HloShardingTest, IndexDomainsWithReplication) { - auto device_list = GetDevices({0, 1}); + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); // Fully replicated. auto xla_hlo_sharding = xla::HloSharding::Replicate(); std::shared_ptr sharding = HloSharding::Create(device_list, MemoryKind(), xla_hlo_sharding); Shape shape({10, 20}); - TF_ASSERT_OK_AND_ASSIGN(auto index_domains, sharding->IndexDomains(shape)); - - EXPECT_THAT(index_domains, - ElementsAre(IndexDomain(shape), IndexDomain(shape))); - EXPECT_THAT( - index_domains, - ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath(*sharding, shape))); + { + TF_ASSERT_OK_AND_ASSIGN(auto index_domains, sharding->IndexDomains(shape)); + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(shape), IndexDomain(shape), + IndexDomain(shape), IndexDomain(shape), + IndexDomain(shape), IndexDomain(shape))); + EXPECT_THAT(index_domains, + ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath( + *sharding, shape, SingleDeviceShardSemantics::kAllShards))); + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto index_domains, + sharding->IndexDomains(shape, SingleDeviceShardSemantics::kAllShards)); + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(shape), IndexDomain(shape), + IndexDomain(shape), IndexDomain(shape), + IndexDomain(shape), IndexDomain(shape))); + EXPECT_THAT(index_domains, + ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath( + *sharding, shape, SingleDeviceShardSemantics::kAllShards))); + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto index_domains, + sharding->IndexDomains(shape, + SingleDeviceShardSemantics::kAddressableShards)); + // The first 4 devices are addressable. + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(shape), IndexDomain(shape), + IndexDomain(shape), IndexDomain(shape))); + EXPECT_THAT( + index_domains, + ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath( + *sharding, shape, SingleDeviceShardSemantics::kAddressableShards))); + } } TEST_P(HloShardingTest, DisassembleWithReplication) { - auto device_list = GetDevices({0, 1}); + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); // Fully replicated. auto xla_hlo_sharding = xla::HloSharding::Replicate(); std::shared_ptr sharding = HloSharding::Create(device_list, MemoryKind(), xla_hlo_sharding); Shape shape({10, 20}); - TF_ASSERT_OK_AND_ASSIGN(auto disassembled, sharding->Disassemble(shape)); - - ASSERT_THAT(disassembled, SizeIs(2)); - for (int i = 0; i < 2; ++i) { - const auto& [shape, sharding] = disassembled[i]; - EXPECT_EQ(shape, Shape({10, 20})); - EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( - device_list->devices()[i], MemoryKind())); + { + TF_ASSERT_OK_AND_ASSIGN(auto disassembled, sharding->Disassemble(shape)); + ASSERT_THAT(disassembled, SizeIs(6)); + for (int i = 0; i < 6; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({10, 20})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + sharding->Disassemble(shape, SingleDeviceShardSemantics::kAllShards)); + ASSERT_THAT(disassembled, SizeIs(6)); + for (int i = 0; i < 6; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({10, 20})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + sharding->Disassemble(shape, + SingleDeviceShardSemantics::kAddressableShards)); + // The first 4 devices are addressable. + ASSERT_THAT(disassembled, SizeIs(4)); + for (int i = 0; i < 4; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({10, 20})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } } } TEST_P(HloShardingTest, IndexDomainsWithTile) { - auto device_list = GetDevices({0, 1}); - // 2-way sharded along axis 0, 1-way sharded along axis 1. - auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment({2, 1})); + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); + // 6-way sharded along axis 0, 1-way sharded along axis 1. + auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment({6, 1})); std::shared_ptr sharding = HloSharding::Create(device_list, MemoryKind(), xla_hlo_sharding); - Shape shape({10, 20}); - TF_ASSERT_OK_AND_ASSIGN(auto index_domains, sharding->IndexDomains(shape)); - - EXPECT_THAT(index_domains, - ElementsAre(IndexDomain(Index({0, 0}), Shape({5, 20})), - IndexDomain(Index({5, 0}), Shape({5, 20})))); - EXPECT_THAT( - index_domains, - ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath(*sharding, shape))); + Shape shape({12, 20}); + { + TF_ASSERT_OK_AND_ASSIGN(auto index_domains, sharding->IndexDomains(shape)); + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(Index({0, 0}), Shape({2, 20})), + IndexDomain(Index({2, 0}), Shape({2, 20})), + IndexDomain(Index({4, 0}), Shape({2, 20})), + IndexDomain(Index({6, 0}), Shape({2, 20})), + IndexDomain(Index({8, 0}), Shape({2, 20})), + IndexDomain(Index({10, 0}), Shape({2, 20})))); + EXPECT_THAT(index_domains, + ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath( + *sharding, shape, SingleDeviceShardSemantics::kAllShards))); + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto index_domains, + sharding->IndexDomains(shape, SingleDeviceShardSemantics::kAllShards)); + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(Index({0, 0}), Shape({2, 20})), + IndexDomain(Index({2, 0}), Shape({2, 20})), + IndexDomain(Index({4, 0}), Shape({2, 20})), + IndexDomain(Index({6, 0}), Shape({2, 20})), + IndexDomain(Index({8, 0}), Shape({2, 20})), + IndexDomain(Index({10, 0}), Shape({2, 20})))); + EXPECT_THAT(index_domains, + ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath( + *sharding, shape, SingleDeviceShardSemantics::kAllShards))); + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto index_domains, + sharding->IndexDomains(shape, + SingleDeviceShardSemantics::kAddressableShards)); + // The first 4 devices are addressable. + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(Index({0, 0}), Shape({2, 20})), + IndexDomain(Index({2, 0}), Shape({2, 20})), + IndexDomain(Index({4, 0}), Shape({2, 20})), + IndexDomain(Index({6, 0}), Shape({2, 20})))); + EXPECT_THAT( + index_domains, + ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath( + *sharding, shape, SingleDeviceShardSemantics::kAddressableShards))); + } } TEST_P(HloShardingTest, DisassembleWithTile) { - auto device_list = GetDevices({0, 1}); - // 2-way sharded along axis 0, 1-way sharded along axis 1. - auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment({2, 1})); + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); + // 6-way sharded along axis 0, 1-way sharded along axis 1. + auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment({6, 1})); std::shared_ptr sharding = HloSharding::Create(device_list, MemoryKind(), xla_hlo_sharding); - Shape shape({10, 20}); - TF_ASSERT_OK_AND_ASSIGN(auto disassembled, sharding->Disassemble(shape)); - - ASSERT_THAT(disassembled, SizeIs(2)); - for (int i = 0; i < 2; ++i) { - const auto& [shape, sharding] = disassembled[i]; - EXPECT_EQ(shape, Shape({5, 20})); - EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( - device_list->devices()[i], MemoryKind())); + Shape shape({12, 20}); + { + TF_ASSERT_OK_AND_ASSIGN(auto disassembled, sharding->Disassemble(shape)); + ASSERT_THAT(disassembled, SizeIs(6)); + for (int i = 0; i < 6; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({2, 20})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + sharding->Disassemble(shape, SingleDeviceShardSemantics::kAllShards)); + ASSERT_THAT(disassembled, SizeIs(6)); + for (int i = 0; i < 6; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({2, 20})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + sharding->Disassemble(shape, + SingleDeviceShardSemantics::kAddressableShards)); + // The first 4 devices are addressable. + ASSERT_THAT(disassembled, SizeIs(4)); + for (int i = 0; i < 4; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({2, 20})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } } } TEST_P(HloShardingTest, IndexDomainsWithUnevenTile) { - auto device_list = GetDevices({0, 1}); - // 2-way sharded along axis 0, 1-way sharded along axis 1. - auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment({2, 1})); + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); + // 6-way sharded along axis 0, 1-way sharded along axis 1. + auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment({6, 1})); std::shared_ptr sharding = HloSharding::Create(device_list, MemoryKind(), xla_hlo_sharding); Shape shape({11, 20}); - TF_ASSERT_OK_AND_ASSIGN(auto index_domains, sharding->IndexDomains(shape)); - - EXPECT_THAT(index_domains, - ElementsAre(IndexDomain(Index({0, 0}), Shape({6, 20})), - IndexDomain(Index({6, 0}), Shape({5, 20})))); - EXPECT_THAT( - index_domains, - ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath(*sharding, shape))); + { + TF_ASSERT_OK_AND_ASSIGN(auto index_domains, sharding->IndexDomains(shape)); + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(Index({0, 0}), Shape({2, 20})), + IndexDomain(Index({2, 0}), Shape({2, 20})), + IndexDomain(Index({4, 0}), Shape({2, 20})), + IndexDomain(Index({6, 0}), Shape({2, 20})), + IndexDomain(Index({8, 0}), Shape({2, 20})), + IndexDomain(Index({10, 0}), Shape({1, 20})))); + EXPECT_THAT(index_domains, + ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath( + *sharding, shape, SingleDeviceShardSemantics::kAllShards))); + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto index_domains, + sharding->IndexDomains(shape, SingleDeviceShardSemantics::kAllShards)); + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(Index({0, 0}), Shape({2, 20})), + IndexDomain(Index({2, 0}), Shape({2, 20})), + IndexDomain(Index({4, 0}), Shape({2, 20})), + IndexDomain(Index({6, 0}), Shape({2, 20})), + IndexDomain(Index({8, 0}), Shape({2, 20})), + IndexDomain(Index({10, 0}), Shape({1, 20})))); + EXPECT_THAT(index_domains, + ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath( + *sharding, shape, SingleDeviceShardSemantics::kAllShards))); + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto index_domains, + sharding->IndexDomains(shape, + SingleDeviceShardSemantics::kAddressableShards)); + // The first 4 devices are addressable. + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(Index({0, 0}), Shape({2, 20})), + IndexDomain(Index({2, 0}), Shape({2, 20})), + IndexDomain(Index({4, 0}), Shape({2, 20})), + IndexDomain(Index({6, 0}), Shape({2, 20})))); + EXPECT_THAT( + index_domains, + ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath( + *sharding, shape, SingleDeviceShardSemantics::kAddressableShards))); + } } TEST_P(HloShardingTest, DisassembleWithUnevenTile) { - auto device_list = GetDevices({0, 1}); - // 2-way sharded along axis 0, 1-way sharded along axis 1. - auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment({2, 1})); + auto device_list = GetDevices({0, 1, 2, 3, 4, 5}); + // 6-way sharded along axis 0, 1-way sharded along axis 1. + auto xla_hlo_sharding = xla::HloSharding::Tile(xla::TileAssignment({6, 1})); std::shared_ptr sharding = HloSharding::Create(device_list, MemoryKind(), xla_hlo_sharding); Shape shape({11, 20}); - TF_ASSERT_OK_AND_ASSIGN(auto disassembled, sharding->Disassemble(shape)); - - ASSERT_THAT(disassembled, SizeIs(2)); - for (int i = 0; i < 2; ++i) { - const auto& [shape, sharding] = disassembled[i]; - if (i == 0) { - EXPECT_EQ(shape, Shape({6, 20})); - } else { - EXPECT_EQ(shape, Shape({5, 20})); + { + TF_ASSERT_OK_AND_ASSIGN(auto disassembled, sharding->Disassemble(shape)); + ASSERT_THAT(disassembled, SizeIs(6)); + for (int i = 0; i < 6; ++i) { + const auto& [shape, sharding] = disassembled[i]; + if (i < 5) { + EXPECT_EQ(shape, Shape({2, 20})); + } else { + EXPECT_EQ(shape, Shape({1, 20})); + } + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + sharding->Disassemble(shape, SingleDeviceShardSemantics::kAllShards)); + ASSERT_THAT(disassembled, SizeIs(6)); + for (int i = 0; i < 6; ++i) { + const auto& [shape, sharding] = disassembled[i]; + if (i < 5) { + EXPECT_EQ(shape, Shape({2, 20})); + } else { + EXPECT_EQ(shape, Shape({1, 20})); + } + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + sharding->Disassemble(shape, + SingleDeviceShardSemantics::kAddressableShards)); + // The first 4 devices are addressable. + ASSERT_THAT(disassembled, SizeIs(4)); + for (int i = 0; i < 4; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({2, 20})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); } - EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( - device_list->devices()[i], MemoryKind())); } } @@ -293,18 +479,50 @@ TEST_P(HloShardingTest, IndexDomainsWithPartialTile) { HloSharding::Create(device_list, MemoryKind(), xla_hlo_sharding); Shape shape({10, 20}); - TF_ASSERT_OK_AND_ASSIGN(auto index_domains, sharding->IndexDomains(shape)); - - EXPECT_THAT(index_domains, - ElementsAre(IndexDomain(Index({0, 0}), Shape({5, 20})), - IndexDomain(Index({0, 0}), Shape({5, 20})), - IndexDomain(Index({0, 0}), Shape({5, 20})), - IndexDomain(Index({5, 0}), Shape({5, 20})), - IndexDomain(Index({5, 0}), Shape({5, 20})), - IndexDomain(Index({5, 0}), Shape({5, 20})))); - EXPECT_THAT( - index_domains, - ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath(*sharding, shape))); + { + TF_ASSERT_OK_AND_ASSIGN(auto index_domains, sharding->IndexDomains(shape)); + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({5, 0}), Shape({5, 20})), + IndexDomain(Index({5, 0}), Shape({5, 20})), + IndexDomain(Index({5, 0}), Shape({5, 20})))); + EXPECT_THAT(index_domains, + ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath( + *sharding, shape, SingleDeviceShardSemantics::kAllShards))); + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto index_domains, + sharding->IndexDomains(shape, SingleDeviceShardSemantics::kAllShards)); + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({5, 0}), Shape({5, 20})), + IndexDomain(Index({5, 0}), Shape({5, 20})), + IndexDomain(Index({5, 0}), Shape({5, 20})))); + EXPECT_THAT(index_domains, + ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath( + *sharding, shape, SingleDeviceShardSemantics::kAllShards))); + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto index_domains, + sharding->IndexDomains(shape, + SingleDeviceShardSemantics::kAddressableShards)); + // The first 4 devices are addressable. + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({5, 0}), Shape({5, 20})))); + EXPECT_THAT( + index_domains, + ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath( + *sharding, shape, SingleDeviceShardSemantics::kAddressableShards))); + } } TEST_P(HloShardingTest, DisassembleWithPartialTile) { @@ -317,14 +535,41 @@ TEST_P(HloShardingTest, DisassembleWithPartialTile) { HloSharding::Create(device_list, MemoryKind(), xla_hlo_sharding); Shape shape({10, 20}); - TF_ASSERT_OK_AND_ASSIGN(auto disassembled, sharding->Disassemble(shape)); - - ASSERT_THAT(disassembled, SizeIs(6)); - for (int i = 0; i < 6; ++i) { - const auto& [shape, sharding] = disassembled[i]; - EXPECT_EQ(shape, Shape({5, 20})); - EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( - device_list->devices()[i], MemoryKind())); + { + TF_ASSERT_OK_AND_ASSIGN(auto disassembled, sharding->Disassemble(shape)); + ASSERT_THAT(disassembled, SizeIs(6)); + for (int i = 0; i < 6; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({5, 20})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + sharding->Disassemble(shape, SingleDeviceShardSemantics::kAllShards)); + ASSERT_THAT(disassembled, SizeIs(6)); + for (int i = 0; i < 6; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({5, 20})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + sharding->Disassemble(shape, + SingleDeviceShardSemantics::kAddressableShards)); + // The first 4 devices are addressable. + ASSERT_THAT(disassembled, SizeIs(4)); + for (int i = 0; i < 4; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({5, 20})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } } } @@ -338,18 +583,50 @@ TEST_P(HloShardingTest, IndexDomainsWithSubgroupReplicated) { HloSharding::Create(device_list, MemoryKind(), xla_hlo_sharding); Shape shape({10, 20}); - TF_ASSERT_OK_AND_ASSIGN(auto index_domains, sharding->IndexDomains(shape)); - - EXPECT_THAT(index_domains, - ElementsAre(IndexDomain(Index({0, 0}), Shape({5, 20})), - IndexDomain(Index({0, 0}), Shape({5, 20})), - IndexDomain(Index({0, 0}), Shape({5, 20})), - IndexDomain(Index({5, 0}), Shape({5, 20})), - IndexDomain(Index({5, 0}), Shape({5, 20})), - IndexDomain(Index({5, 0}), Shape({5, 20})))); - EXPECT_THAT( - index_domains, - ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath(*sharding, shape))); + { + TF_ASSERT_OK_AND_ASSIGN(auto index_domains, sharding->IndexDomains(shape)); + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({5, 0}), Shape({5, 20})), + IndexDomain(Index({5, 0}), Shape({5, 20})), + IndexDomain(Index({5, 0}), Shape({5, 20})))); + EXPECT_THAT(index_domains, + ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath( + *sharding, shape, SingleDeviceShardSemantics::kAllShards))); + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto index_domains, + sharding->IndexDomains(shape, SingleDeviceShardSemantics::kAllShards)); + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({5, 0}), Shape({5, 20})), + IndexDomain(Index({5, 0}), Shape({5, 20})), + IndexDomain(Index({5, 0}), Shape({5, 20})))); + EXPECT_THAT(index_domains, + ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath( + *sharding, shape, SingleDeviceShardSemantics::kAllShards))); + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto index_domains, + sharding->IndexDomains(shape, + SingleDeviceShardSemantics::kAddressableShards)); + // The first 4 devices are addressable. + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({5, 0}), Shape({5, 20})))); + EXPECT_THAT( + index_domains, + ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath( + *sharding, shape, SingleDeviceShardSemantics::kAddressableShards))); + } } TEST_P(HloShardingTest, DisassembleWithSubgroupReplicated) { @@ -362,14 +639,41 @@ TEST_P(HloShardingTest, DisassembleWithSubgroupReplicated) { HloSharding::Create(device_list, MemoryKind(), xla_hlo_sharding); Shape shape({10, 20}); - TF_ASSERT_OK_AND_ASSIGN(auto disassembled, sharding->Disassemble(shape)); - - ASSERT_THAT(disassembled, SizeIs(6)); - for (int i = 0; i < 6; ++i) { - const auto& [shape, sharding] = disassembled[i]; - EXPECT_EQ(shape, Shape({5, 20})); - EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( - device_list->devices()[i], MemoryKind())); + { + TF_ASSERT_OK_AND_ASSIGN(auto disassembled, sharding->Disassemble(shape)); + ASSERT_THAT(disassembled, SizeIs(6)); + for (int i = 0; i < 6; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({5, 20})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + sharding->Disassemble(shape, SingleDeviceShardSemantics::kAllShards)); + ASSERT_THAT(disassembled, SizeIs(6)); + for (int i = 0; i < 6; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({5, 20})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + sharding->Disassemble(shape, + SingleDeviceShardSemantics::kAddressableShards)); + // The first 4 devices are addressable. + ASSERT_THAT(disassembled, SizeIs(4)); + for (int i = 0; i < 4; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({5, 20})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } } } @@ -383,18 +687,50 @@ TEST_P(HloShardingTest, IndexDomainsWithSubgroupMaximalSlowPath) { HloSharding::Create(device_list, MemoryKind(), xla_hlo_sharding); Shape shape({10, 20}); - TF_ASSERT_OK_AND_ASSIGN(auto index_domains, sharding->IndexDomains(shape)); - - EXPECT_THAT(index_domains, - ElementsAre(IndexDomain(Index({0, 0}), Shape({5, 20})), - IndexDomain(Index({0, 0}), Shape({5, 20})), - IndexDomain(Index({0, 0}), Shape({5, 20})), - IndexDomain(Index({5, 0}), Shape({5, 20})), - IndexDomain(Index({5, 0}), Shape({5, 20})), - IndexDomain(Index({5, 0}), Shape({5, 20})))); - EXPECT_THAT( - index_domains, - ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath(*sharding, shape))); + { + TF_ASSERT_OK_AND_ASSIGN(auto index_domains, sharding->IndexDomains(shape)); + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({5, 0}), Shape({5, 20})), + IndexDomain(Index({5, 0}), Shape({5, 20})), + IndexDomain(Index({5, 0}), Shape({5, 20})))); + EXPECT_THAT(index_domains, + ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath( + *sharding, shape, SingleDeviceShardSemantics::kAllShards))); + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto index_domains, + sharding->IndexDomains(shape, SingleDeviceShardSemantics::kAllShards)); + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({5, 0}), Shape({5, 20})), + IndexDomain(Index({5, 0}), Shape({5, 20})), + IndexDomain(Index({5, 0}), Shape({5, 20})))); + EXPECT_THAT(index_domains, + ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath( + *sharding, shape, SingleDeviceShardSemantics::kAllShards))); + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto index_domains, + sharding->IndexDomains(shape, + SingleDeviceShardSemantics::kAddressableShards)); + // The first 4 devices are addressable. + EXPECT_THAT(index_domains, + ElementsAre(IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({0, 0}), Shape({5, 20})), + IndexDomain(Index({5, 0}), Shape({5, 20})))); + EXPECT_THAT( + index_domains, + ElementsAreArray(TEST_HloShardingIndexDomainsSlowPath( + *sharding, shape, SingleDeviceShardSemantics::kAddressableShards))); + } } TEST_P(HloShardingTest, DisassembleWithSubgroupMaximalSlowPath) { @@ -407,14 +743,40 @@ TEST_P(HloShardingTest, DisassembleWithSubgroupMaximalSlowPath) { HloSharding::Create(device_list, MemoryKind(), xla_hlo_sharding); Shape shape({10, 20}); - TF_ASSERT_OK_AND_ASSIGN(auto disassembled, sharding->Disassemble(shape)); - - ASSERT_THAT(disassembled, SizeIs(6)); - for (int i = 0; i < 6; ++i) { - const auto& [shape, sharding] = disassembled[i]; - EXPECT_EQ(shape, Shape({5, 20})); - EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( - device_list->devices()[i], MemoryKind())); + { + TF_ASSERT_OK_AND_ASSIGN(auto disassembled, sharding->Disassemble(shape)); + ASSERT_THAT(disassembled, SizeIs(6)); + for (int i = 0; i < 6; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({5, 20})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + sharding->Disassemble(shape, SingleDeviceShardSemantics::kAllShards)); + ASSERT_THAT(disassembled, SizeIs(6)); + for (int i = 0; i < 6; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({5, 20})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + sharding->Disassemble(shape, + SingleDeviceShardSemantics::kAddressableShards)); + ASSERT_THAT(disassembled, SizeIs(4)); + for (int i = 0; i < 4; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({5, 20})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } } } @@ -438,14 +800,41 @@ TEST_P(HloShardingTest, DisassembleWithManual) { HloSharding::Create(device_list, MemoryKind(), xla_hlo_sharding); Shape shape({10, 20}); - TF_ASSERT_OK_AND_ASSIGN(auto disassembled, sharding->Disassemble(shape)); - - ASSERT_THAT(disassembled, SizeIs(6)); - for (int i = 0; i < 6; ++i) { - const auto& [shape, sharding] = disassembled[i]; - EXPECT_EQ(shape, Shape({10, 20})); - EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( - device_list->devices()[i], MemoryKind())); + { + TF_ASSERT_OK_AND_ASSIGN(auto disassembled, sharding->Disassemble(shape)); + ASSERT_THAT(disassembled, SizeIs(6)); + for (int i = 0; i < 6; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({10, 20})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + sharding->Disassemble(shape, SingleDeviceShardSemantics::kAllShards)); + ASSERT_THAT(disassembled, SizeIs(6)); + for (int i = 0; i < 6; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({10, 20})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } + } + { + TF_ASSERT_OK_AND_ASSIGN( + auto disassembled, + sharding->Disassemble(shape, + SingleDeviceShardSemantics::kAddressableShards)); + // The first 4 devices are addressable. + ASSERT_THAT(disassembled, SizeIs(4)); + for (int i = 0; i < 4; ++i) { + const auto& [shape, sharding] = disassembled[i]; + EXPECT_EQ(shape, Shape({10, 20})); + EXPECT_EQ(*sharding, *SingleDeviceSharding::Create( + device_list->devices()[i], MemoryKind())); + } } } @@ -495,7 +884,8 @@ TEST_P(HloShardingTest, DisassembleFailsWithDynamicShape) { INSTANTIATE_TEST_SUITE_P(NumDevices, HloShardingTest, testing::Values(test_util::DeviceTestParam{ - .num_devices = 6, .num_addressable_devices = 4})); + /*num_devices=*/6, + /*num_addressable_devices=*/4})); } // namespace } // namespace ifrt diff --git a/third_party/xla/xla/python/profiler.cc b/third_party/xla/xla/python/profiler.cc index e822d5153cc75e..f26ef1b0464387 100644 --- a/third_party/xla/xla/python/profiler.cc +++ b/third_party/xla/xla/python/profiler.cc @@ -37,6 +37,7 @@ limitations under the License. #include "xla/pjrt/exceptions.h" #include "xla/pjrt/status_casters.h" #include "xla/python/aggregate_profile.h" +#include "xla/python/profiler/profile_data.h" #include "xla/python/profiler_utils.h" #include "xla/python/xplane_to_profile_instructions.h" #include "xla/tsl/profiler/rpc/client/capture_profile.h" @@ -198,6 +199,14 @@ void BuildProfilerSubmodule(nb::module_& m) { std::string xspace_str = xspace.SerializeAsString(); return nb::bytes(xspace_str.data(), xspace_str.size()); }) + .def("stop_and_get_profile_data", + [](ProfilerSessionWrapper* sess) + -> tensorflow::profiler::python::ProfileData { + std::shared_ptr xspace; + // Disables the ProfilerSession + xla::ThrowIfError(sess->session->CollectData(xspace.get())); + return tensorflow::profiler::python::ProfileData(xspace); + }) .def("export", [](ProfilerSessionWrapper* sess, nb::bytes xspace, const std::string& tensorboard_dir) -> void { diff --git a/third_party/xla/xla/python/profiler/BUILD b/third_party/xla/xla/python/profiler/BUILD new file mode 100644 index 00000000000000..85366a7e86cb47 --- /dev/null +++ b/third_party/xla/xla/python/profiler/BUILD @@ -0,0 +1,43 @@ +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla/tsl:tsl.default.bzl", "tsl_pybind_extension") + +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) + +cc_library( + name = "profile_data_lib", + srcs = ["profile_data.cc"], + hdrs = ["profile_data.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = [ + "//perftools/accelerators/xprof/api/python:__pkg__", + "//xla/python:__pkg__", + ], + deps = [ + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", + "@nanobind", + ], + alwayslink = 1, +) + +tsl_pybind_extension( + name = "profile_data", + srcs = ["py_profile_data.cc"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + pytype_srcs = ["profile_data.pyi"], + visibility = ["//visibility:public"], + deps = [ + ":profile_data_lib", + "@nanobind", + ], +) diff --git a/third_party/xla/xla/python/profiler/internal/BUILD b/third_party/xla/xla/python/profiler/internal/BUILD index aeebb16d8ca5c0..6d926ea97802de 100644 --- a/third_party/xla/xla/python/profiler/internal/BUILD +++ b/third_party/xla/xla/python/profiler/internal/BUILD @@ -21,6 +21,10 @@ cc_library( "//tensorflow/python/profiler/internal:__subpackages__", ]), deps = [ + "//xla/tsl/profiler/utils:time_utils", + "//xla/tsl/profiler/utils:xplane_builder", + "//xla/tsl/profiler/utils:xplane_schema", + "//xla/tsl/profiler/utils:xplane_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/memory", @@ -30,10 +34,6 @@ cc_library( "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:time_utils", - "@local_tsl//tsl/profiler/utils:xplane_builder", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_utils", "@pybind11", ], alwayslink = True, diff --git a/third_party/xla/xla/python/profiler/internal/python_hooks.cc b/third_party/xla/xla/python/profiler/internal/python_hooks.cc index ce66d686a982e9..0da1fe5e0124b5 100644 --- a/third_party/xla/xla/python/profiler/internal/python_hooks.cc +++ b/third_party/xla/xla/python/profiler/internal/python_hooks.cc @@ -21,13 +21,13 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" +#include "xla/tsl/profiler/utils/time_utils.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/time_utils.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace xla { namespace profiler { diff --git a/third_party/xla/xla/python/profiler/profile_data.cc b/third_party/xla/xla/python/profiler/profile_data.cc new file mode 100644 index 00000000000000..91c9e6c84cf39b --- /dev/null +++ b/third_party/xla/xla/python/profiler/profile_data.cc @@ -0,0 +1,243 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/python/profiler/profile_data.h" + +#include +#include // IWYU pragma: keep. For automatic conversion of std::string to Python string. + +#include +#include +#include +#include + +#include "tsl/platform/env.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/status.h" + +namespace tensorflow::profiler::python { + +namespace nb = nanobind; +using tensorflow::profiler::XEvent; +using tensorflow::profiler::XLine; +using tensorflow::profiler::XPlane; +using tensorflow::profiler::XSpace; +using tensorflow::profiler::XStat; + +// Converts a XStat object to a Python tuple. For compatibility reasons, we +// always return of the same sizes. +nb::tuple stats_to_tuple(const XStat& stat, const XPlane* plane) { + if (plane->stat_metadata().contains(stat.metadata_id())) { + const std::string& name = + plane->stat_metadata().at(stat.metadata_id()).name(); + switch (stat.value_case()) { + case XStat::kDoubleValue: + return nb::make_tuple(name, nb::cast(stat.double_value())); + break; + case XStat::kUint64Value: + return nb::make_tuple(name, nb::cast(stat.uint64_value())); + break; + case XStat::kInt64Value: + return nb::make_tuple(name, nb::cast(stat.int64_value())); + break; + case XStat::kStrValue: + return nb::make_tuple(name, stat.str_value()); + break; + case XStat::kBytesValue: + return nb::make_tuple(name, stat.bytes_value()); + break; + case XStat::kRefValue: + if (plane->stat_metadata().contains(stat.ref_value())) { + return nb::make_tuple( + name, plane->stat_metadata().at(stat.ref_value()).name()); + } else { + return nb::make_tuple(name, ""); + } + break; + default: + LOG(ERROR) << "Unsupported stat value type: " << stat.value_case(); + break; + } + } + return nb::make_tuple(nb::none(), nb::none()); +} + +ProfileEvent::ProfileEvent(const XEvent* event, int64_t line_timestamp_ns, + const XPlane* plane, + std::shared_ptr xspace) + : event_(event), + plane_(plane), + line_timestamp_ns_(line_timestamp_ns), + xspace_(xspace) { + CHECK_NOTNULL(event_); + CHECK_NOTNULL(plane_); + CHECK_NOTNULL(xspace_); +} + +double ProfileEvent::start_ns() const { + return event_->offset_ps() / 1000 + line_timestamp_ns_; +} + +double ProfileEvent::duration_ns() const { + return event_->duration_ps() / 1000; +} + +double ProfileEvent::end_ns() const { return start_ns() + duration_ns(); } + +std::string ProfileEvent::name() const { + if (plane_->event_metadata().contains(event_->metadata_id())) { + return plane_->event_metadata().at(event_->metadata_id()).name(); + } + return ""; +} + +VisitorIterator ProfileEvent::stats_begin() { + return VisitorIterator( + &event_->stats(), + [this](const XStat& stat) { return stats_to_tuple(stat, plane_); }); +} +VisitorIterator ProfileEvent::stats_end() { + return VisitorIterator( + &event_->stats(), + [this](const XStat& stat) { return stats_to_tuple(stat, plane_); }, + event_->stats().size()); +} + +ProfileLine::ProfileLine(const XLine* line, const XPlane* plane, + std::shared_ptr xspace) + : line_(line), plane_(plane), xspace_(xspace) { + CHECK_NOTNULL(line_); + CHECK_NOTNULL(plane_); + CHECK_NOTNULL(xspace_); +} + +const std::string& ProfileLine::name() const { return line_->name(); } + +VisitorIterator ProfileLine::events_begin() { + return VisitorIterator( + &line_->events(), [this](const XEvent& event) { + return ProfileEvent(&event, line_->timestamp_ns(), plane_, xspace_); + }); +} + +VisitorIterator ProfileLine::events_end() { + return VisitorIterator( + &line_->events(), + [this](const XEvent& event) { + return ProfileEvent(&event, line_->timestamp_ns(), plane_, xspace_); + }, + line_->events().size()); +} + +ProfilePlane::ProfilePlane(const XPlane* plane, + std::shared_ptr xspace) + : plane_(plane), xspace_(xspace) { + CHECK_NOTNULL(plane_); + CHECK_NOTNULL(xspace_); +} + +const std::string& ProfilePlane::name() const { return plane_->name(); } + +VisitorIterator ProfilePlane::lines_begin() { + return VisitorIterator( + &plane_->lines(), [this](const XLine& line) { + return ProfileLine(&line, plane_, xspace_); + }); +} +VisitorIterator ProfilePlane::lines_end() { + return VisitorIterator( + &plane_->lines(), + [this](const XLine& line) { return ProfileLine(&line, plane_, xspace_); }, + plane_->lines().size()); +} + +VisitorIterator ProfilePlane::stats_begin() { + return VisitorIterator( + &plane_->stats(), + [this](const XStat& stat) { return stats_to_tuple(stat, plane_); }); +} + +VisitorIterator ProfilePlane::stats_end() { + return VisitorIterator( + &plane_->stats(), + [this](const XStat& stat) { return stats_to_tuple(stat, plane_); }, + plane_->stats().size()); +} + +/*static*/ ProfileData ProfileData::from_serialized_xspace( + const nb::bytes& serialized_xspace) { + return ProfileData(serialized_xspace); +} + +/*static*/ ProfileData ProfileData::from_file( + const std::string& proto_file_path) { + std::string serialized_xspace; + TF_CHECK_OK(tsl::ReadFileToString(tsl::Env::Default(), proto_file_path, + &serialized_xspace)); + return ProfileData(serialized_xspace.c_str(), serialized_xspace.size()); +} + +/*static*/ ProfileData ProfileData::from_raw_cpp_ptr(nb::capsule capsule) { + auto raw_ptr = static_cast(capsule.data()); + auto proto_ptr = std::shared_ptr(raw_ptr); + + return ProfileData(proto_ptr); +} + +ProfileData::ProfileData(const char* serialized_xspace_ptr, + size_t serialized_xspace_size) { + CHECK_NOTNULL(serialized_xspace_ptr); + + if (!xspace_) { + xspace_ = std::make_shared(); + } + CHECK(xspace_->ParseFromArray(serialized_xspace_ptr, serialized_xspace_size)); +} + +/*explicit*/ ProfileData::ProfileData(std::shared_ptr xspace_ptr) { + xspace_ = xspace_ptr; +} + +/*explicit*/ ProfileData::ProfileData(const nb::bytes& serialized_xspace) { + if (!xspace_) { + xspace_ = std::make_shared(); + } + CHECK(xspace_->ParseFromArray(serialized_xspace.data(), + serialized_xspace.size())); +} + +VisitorIterator ProfileData::planes_begin() { + return VisitorIterator( + &xspace_->planes(), + [this](const XPlane& plane) { return ProfilePlane(&plane, xspace_); }); +} + +VisitorIterator ProfileData::planes_end() { + return VisitorIterator( + &xspace_->planes(), + [this](const XPlane& plane) { return ProfilePlane(&plane, xspace_); }, + xspace_->planes().size()); +} + +ProfilePlane* ProfileData::find_plane_with_name(const std::string& name) const { + for (const auto& plane : xspace_->planes()) { + if (plane.name() == name) { + return new ProfilePlane(&plane, xspace_); + } + } + return nullptr; +} + +} // namespace tensorflow::profiler::python diff --git a/third_party/xla/xla/python/profiler/profile_data.h b/third_party/xla/xla/python/profiler/profile_data.h new file mode 100644 index 00000000000000..980c76375ae802 --- /dev/null +++ b/third_party/xla/xla/python/profiler/profile_data.h @@ -0,0 +1,190 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_PYTHON_PROFILER_PROFILE_DATA_H_ +#define XLA_PYTHON_PROFILER_PROFILE_DATA_H_ + +#include + +#include +#include +#include +#include + +#include "tsl/platform/logging.h" +#include "tsl/platform/protobuf.h" +#include "tsl/profiler/protobuf/xplane.pb.h" + +namespace tensorflow::profiler::python { + +namespace nb = nanobind; + +// A simple iterator that converts a proto repeated field to a Python iterable +// with a customized conversion function. +template +class VisitorIterator + : public std::iterator { + public: + VisitorIterator( + const tsl::protobuf::RepeatedPtrField* values, + const std::function& make_visitor, + int pos = 0) + : values_(values), make_visitor_(make_visitor), pos_(pos) { + CHECK_NOTNULL(values_); + CHECK_GE(pos_, 0); + CHECK_LE(pos_, values_->size()); + } + + // Prefix increment operator. + VisitorIterator& operator++() { + ++pos_; + return *this; + } + + // Postfix increment operator. + VisitorIterator operator++(int) { + VisitorIterator tmp(*this); + operator++(); + return tmp; + } + + bool operator==(const VisitorIterator& rhs) const { + return pos_ == rhs.pos_ && values_ == rhs.values_; + } + + bool operator!=(const VisitorIterator& rhs) const { + return pos_ != rhs.pos_ || values_ != rhs.values_; + } + + OutputType operator*() { return make_visitor_((*values_)[pos_]); } + + private: + const tsl::protobuf::RepeatedPtrField* values_; + const std::function make_visitor_; + int pos_ = 0; +}; + +class ProfileEvent { + public: + ProfileEvent() = delete; + + ProfileEvent(const tensorflow::profiler::XEvent* event, + int64_t line_timestamp_ns, + const tensorflow::profiler::XPlane* plane, + std::shared_ptr xspace); + + double start_ns() const; + + double duration_ns() const; + + double end_ns() const; + + std::string name() const; + + VisitorIterator stats_begin(); + VisitorIterator stats_end(); + + private: + const XEvent* event_; + const XPlane* plane_; + const int64_t line_timestamp_ns_; + // The actual XSpace protobuf we are wrapping around. A shared ptr is used so + // the different levels of visitors (ProfileData, ProfilePlane, + // ProfileLine, etc.) don't depend on the lifetime of others. + const std::shared_ptr xspace_; +}; + +class ProfileLine { + public: + ProfileLine() = delete; + + ProfileLine(const tensorflow::profiler::XLine* line, + const tensorflow::profiler::XPlane* plane, + std::shared_ptr xspace); + + const std::string& name() const; + + VisitorIterator events_begin(); + VisitorIterator events_end(); + + private: + const XLine* line_; + const XPlane* plane_; + // The actual XSpace protobuf we are wrapping around. A shared ptr is used so + // the different levels of visitors (ProfileData, ProfilePlane, + // ProfileLine, etc.) don't depend on the lifetime of others. + const std::shared_ptr xspace_; +}; + +class ProfilePlane { + public: + ProfilePlane() = delete; + + ProfilePlane(const tensorflow::profiler::XPlane* plane, + std::shared_ptr xspace); + + const std::string& name() const; + + VisitorIterator lines_begin(); + VisitorIterator lines_end(); + + VisitorIterator stats_begin(); + + VisitorIterator stats_end(); + + private: + const XPlane* plane_; + // The actual XSpace protobuf we are wrapping around. A shared ptr is used so + // the different levels of visitors (ProfileData, ProfilePlane, + // ProfileLine, etc.) don't depend on the lifetime of others. + const std::shared_ptr xspace_; +}; + +class ProfileData { + public: + static ProfileData from_serialized_xspace(const nb::bytes& serialized_xspace); + + static ProfileData from_file(const std::string& proto_file_path); + + static ProfileData from_raw_cpp_ptr(nb::capsule capsule); + + ProfileData() = delete; + + ProfileData(const char* serialized_xspace_ptr, size_t serialized_xspace_size); + + explicit ProfileData(std::shared_ptr xspace_ptr); + + explicit ProfileData(const nb::bytes& serialized_xspace); + + VisitorIterator planes_begin(); + + VisitorIterator planes_end(); + + ProfilePlane* find_plane_with_name(const std::string& name) const; + + private: + // The actual XSpace protobuf we are wrapping around. A shared ptr is used so + // the different levels of visitors (ProfileData, ProfilePlane, + // ProfileLine, etc.) don't depend on the lifetime of others. + std::shared_ptr xspace_; +}; + +ProfileData from_serialized_xspace(const std::string& serialized_xspace); + +ProfileData from_file(const std::string& proto_file_path); + +} // namespace tensorflow::profiler::python + +#endif // XLA_PYTHON_PROFILER_PROFILE_DATA_H_ diff --git a/third_party/xla/xla/python/profiler/profile_data.pyi b/third_party/xla/xla/python/profiler/profile_data.pyi new file mode 100644 index 00000000000000..9a84529639a477 --- /dev/null +++ b/third_party/xla/xla/python/profiler/profile_data.pyi @@ -0,0 +1,121 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for visiting program execution data.""" + +from typing import Any, Iterator, Mapping, Optional, Sequence, Tuple + + +class ProfileData: + """Program execution data captured by jax.profiler functions.""" + + def __init__(self, serialized_xspace: bytes): + ... + + @classmethod + def from_file(cls, path: str) -> 'ProfileData': + """Creates a ProfileData from a serialized XSpace proto file.""" + ... + + @classmethod + def from_serialized_xspace(cls, serialized_xspace: bytes) -> 'ProfileData': + """Creates a ProfileData from a serialized XSpace proto.""" + ... + + @classmethod + def from_raw_cpp_ptr(cls, raw_proto_ptr: object) -> 'ProfileData': + """Creates a ProfileData from a raw C++ pointer enclosed in a capsule to a XSpace proto.""" + ... + + @property + def planes(self) -> Iterator['ProfilePlane']: + ... + + def find_plane_with_name(self, name: str) -> Optional['ProfilePlane']: + """Finds the plane with the given name.""" + ... + + +class ProfilePlane: + """Wraps XPlane protobuf and provides accessors to its contents.""" + + @property + def name(self) -> str: + """Name of the plane.""" + ... + + @property + def lines(self) -> Iterator['ProfileLine']: + """Lines in the plane.""" + ... + + @property + def stats(self) -> Iterator[Tuple[str, Any]]: + """Stats in the plane. + + Returns + An iterator of (name, value) tuples, note that for metadata ids that + are not found, the returned tuple will be (None, None). The caller should + check the tuple value before using it. + """ + ... + + +class ProfileLine: + """Wraps XLine protobuf and provides accessors to its contents.""" + + @property + def name(self) -> str: + """Name of the line.""" + ... + + @property + def events(self) -> Iterator['ProfileEvent']: + """Events in the line.""" + ... + + +class ProfileEvent: + """Wraps XEvent protobuf and provides accessors to its contents.""" + + @property + def start_ns(self) -> float: + """Start time of the event in nanoseconds.""" + ... + + @property + def duration_ns(self) -> float: + """Duration of the event in nanoseconds.""" + ... + + @property + def end_ns(self) -> float: + """End time of the event in nanoseconds.""" + ... + + @property + def name(self) -> str: + """Name of the event.""" + ... + + @property + def stats(self) -> Iterator[Tuple[str, Any]]: + """Stats of the event. + + Returns + An iterator of (name, value) tuples, note that for metadata ids that + are not found, the returned tuple will be (None, None). The caller should + check the tuple value before using it. + """ + ... diff --git a/third_party/xla/xla/python/profiler/py_profile_data.cc b/third_party/xla/xla/python/profiler/py_profile_data.cc new file mode 100644 index 00000000000000..b8bacfe9500ee6 --- /dev/null +++ b/third_party/xla/xla/python/profiler/py_profile_data.cc @@ -0,0 +1,86 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include // For automatic conversion of std::iterator to Python iterable. +#include // For automatic conversion of std::string to Python string. + +#include "xla/python/profiler/profile_data.h" + +namespace { + +namespace nb = nanobind; +// NOLINTBEGIN(build/namespaces) +using namespace nb::literals; +using namespace tensorflow::profiler::python; +// NOLINTEND(build/namespaces) + +NB_MODULE(profile_data, m) { + nb::class_(m, "ProfileEvent") + .def_prop_ro("start_ns", &ProfileEvent::start_ns) + .def_prop_ro("duration_ns", &ProfileEvent::duration_ns) + .def_prop_ro("end_ns", &ProfileEvent::end_ns) + .def_prop_ro("name", &ProfileEvent::name) + .def_prop_ro( + "stats", + [](ProfileEvent&& e) { + return nb::make_iterator(nb::type(), "event_stats", + e.stats_begin(), e.stats_end()); + }, + nb::keep_alive<0, 1>()); + nb::class_(m, "ProfileLine") + .def_prop_ro("name", &ProfileLine::name) + .def_prop_ro( + "events", + [](ProfileLine&& l) { + return nb::make_iterator(nb::type(), "events", + l.events_begin(), l.events_end()); + }, + nb::keep_alive<0, 1>()); + nb::class_(m, "ProfilePlane") + .def_prop_ro("name", &ProfilePlane::name) + .def_prop_ro( + "lines", + [](ProfilePlane&& p) { + return nb::make_iterator(nb::type(), "lines", + p.lines_begin(), p.lines_end()); + }, + nb::keep_alive<0, 1>()) + .def_prop_ro( + "stats", + [](ProfilePlane&& p) { + return nb::make_iterator(nb::type(), "plane_stats", + p.stats_begin(), p.stats_end()); + }, + nb::keep_alive<0, 1>()); + nb::class_(m, "ProfileData") + .def_static("from_raw_cpp_ptr", &ProfileData::from_raw_cpp_ptr, + "capsule"_a) + .def_static("from_file", &ProfileData::from_file, "proto_file_path"_a, + "Creates a ProfileData from a serialized XSpace proto file.") + .def_static("from_serialized_xspace", + &ProfileData::from_serialized_xspace, "serialized_xspace"_a) + .def(nb::init()) + .def("find_plane_with_name", &ProfileData::find_plane_with_name, "name"_a, + nb::keep_alive<0, 1>()) + .def_prop_ro( + "planes", + [](ProfileData&& s) { + return nb::make_iterator(nb::type(), "planes", + s.planes_begin(), s.planes_end()); + }, + nb::keep_alive<0, 1>()); +} + +} // namespace diff --git a/third_party/xla/xla/python/py_array.cc b/third_party/xla/xla/python/py_array.cc index 8a622f031d1adf..387c14f65d9da8 100644 --- a/third_party/xla/xla/python/py_array.cc +++ b/third_party/xla/xla/python/py_array.cc @@ -60,6 +60,7 @@ limitations under the License. #include "xla/pjrt/pjrt_layout.h" #include "xla/pjrt/status_casters.h" #include "xla/primitive_util.h" +#include "xla/python/guard_lib.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/device_list.h" @@ -75,13 +76,13 @@ limitations under the License. #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_device.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/py_client.h" #include "xla/python/py_device.h" #include "xla/python/py_values.h" #include "xla/python/python_ref_manager.h" #include "xla/python/sharding.h" #include "xla/python/traceback.h" -#include "xla/python/transfer_guard_lib.h" #include "xla/python/types.h" #include "xla/python/util.h" #include "xla/shape.h" @@ -309,6 +310,35 @@ extern "C" int PyArray_tp_traverse(PyObject* self, visitproc visit, void* arg) { // dynamic_attr: Allow the GC to clear the dictionary. extern "C" int PyArray_tp_clear(PyObject* self) { + switch (auto guard_level = jax::GetGarbageCollectArrayGuard(); guard_level) { + case jax::GarbageCollectionGuardLevel::kAllow: + break; + case jax::GarbageCollectionGuardLevel::kLog: + case jax::GarbageCollectionGuardLevel::kFatal: { + auto* obj = reinterpret_cast(self); + std::string traceback_str; + if (obj->initialized) { + auto traceback = GetPyArrayStorageFromObject(obj)->traceback; + if (traceback.has_value()) { + traceback_str = traceback.value()->ToString(); + } + } + auto error_msg = absl::StrCat( + "`jax.Array` was deleted by the Python garbage collector " + "instead of reference counting. Break the reference cycle " + "that delays the deletion of this `jax.Array` to avoid hogging " + "memory. Traceback: \n", + traceback_str.empty() ? "not available" : traceback_str); + if (guard_level == jax::GarbageCollectionGuardLevel::kFatal) { + Py_FatalError(error_msg.c_str()); + } else { + PyErr_SetString(PyExc_RuntimeError, error_msg.c_str()); + PyErr_Print(); + PyErr_Clear(); + } + break; + } + } #if PY_VERSION_HEX < 0x030C0000 PyObject*& dict = *_PyObject_GetDictPtr(self); Py_CLEAR(dict); diff --git a/third_party/xla/xla/python/py_client.cc b/third_party/xla/xla/python/py_client.cc index d108e9d9c1e47d..b7dace3445d070 100644 --- a/third_party/xla/xla/python/py_client.cc +++ b/third_party/xla/xla/python/py_client.cc @@ -57,9 +57,9 @@ limitations under the License. #include "xla/pjrt/pjrt_compiler.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_layout.h" -#include "xla/pjrt/pjrt_stream_executor_client.h" #include "xla/pjrt/status_casters.h" #include "xla/python/callback.h" +#include "xla/python/guard_lib.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/compiler.h" #include "xla/python/ifrt/device.h" @@ -84,17 +84,12 @@ limitations under the License. #include "xla/python/py_values.h" #include "xla/python/python_ref_manager.h" #include "xla/python/traceback.h" -#include "xla/python/transfer_guard_lib.h" #include "xla/python/types.h" #include "xla/service/custom_call_target_registry.h" #include "xla/service/platform_util.h" // IWYU pragma: keep -#include "xla/service/spmd/shardy/constants.h" -#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h" -#include "xla/service/spmd/shardy/utils.h" #include "xla/shape.h" #include "xla/status_macros.h" #include "xla/tsl/concurrency/ref_count.h" -#include "xla/tsl/framework/mlir/status_scoped_diagnostic_handler.h" #include "xla/util.h" #include "tsl/platform/casts.h" #include "tsl/platform/errors.h" @@ -179,6 +174,15 @@ std::vector> PyClient::LocalDevices() { return devices; } +std::vector> PyClient::GetAllDevices() { + std::vector> devices; + devices.reserve(ifrt_client_->GetAllDevices().size()); + for (ifrt::Device* device : ifrt_client_->GetAllDevices()) { + devices.push_back(GetPyDevice(device)); + } + return devices; +} + absl::StatusOr> PyClient::DeviceFromLocalHardwareId( int local_hardware_id) { TF_ASSIGN_OR_RETURN(ifrt::Device * device, @@ -199,86 +203,93 @@ nb::list PyClient::LiveExecutables() { absl::Status PyClient::Defragment() { CHECK(PyGILState_Check()); - auto runtime_type = ifrt_client_->runtime_type(); - if (runtime_type == PjRtRuntimeTypeString(PjRtRuntimeType::kTfrt)) { + if (!llvm::isa(ifrt_client_.get())) { + return absl::UnimplementedError( + "Defragmentation is not supported on this runtime."); + } + ifrt::PlatformId platform_id = ifrt_client_->platform_id(); + bool is_gpu_client = platform_id == CudaId() || platform_id == RocmId() || + platform_id == SyclId(); + + if (!is_gpu_client) { return pjrt_client()->Defragment(); - } else if (runtime_type == - PjRtRuntimeTypeString(PjRtRuntimeType::kStreamExecutor)) { - struct TmpBuffer { - // Non-empty for buffers found in a PyArray_Storage. Multiple Arrays - // can reference the same PjRtBuffer. - std::vector*> pjrt_buffer_ptrs; - // TODO(skyewm): maybe use py_buffer's HostValue - std::shared_ptr host_copy; - }; - - // Synchronously copy all buffers to host - absl::flat_hash_map pjrt_buf_to_tmp_buffer; - - for (PyArray_Storage* array = arrays_; array; array = array->next) { - // TODO(hyeontaek): Support non-PjRt Arrays. - // TODO(hyeontaek): Re-construct ifrt::Array with new PjRtBuffer so that - // std::shared_ptr does not need to be updated in-place. - if (array->ifrt_array == nullptr) { + } + + struct TmpBuffer { + // Non-empty for buffers found in a PyArray_Storage. Multiple Arrays + // can reference the same PjRtBuffer. + std::vector*> pjrt_buffer_ptrs; + // TODO(skyewm): maybe use py_buffer's HostValue + std::shared_ptr host_copy; + }; + + // Synchronously copy all buffers to host + absl::flat_hash_map pjrt_buf_to_tmp_buffer; + + for (PyArray_Storage* array = arrays_; array; array = array->next) { + // TODO(hyeontaek): Support non-PjRt Arrays. + // TODO(hyeontaek): Re-construct ifrt::Array with new PjRtBuffer so that + // std::shared_ptr does not need to be updated in-place. + if (array->ifrt_array == nullptr) { + continue; + } + auto* arr = llvm::dyn_cast_or_null( + array->ifrt_array.get()); + if (arr == nullptr) { + throw XlaRuntimeError( + "This operation is implemented for a PjRt-compatible backend " + "only."); + } + TF_ASSIGN_OR_RETURN(absl::Span> pjrt_buffers, + arr->mutable_pjrt_buffers()); + for (int i = 0; i < pjrt_buffers.size(); ++i) { + std::shared_ptr& pjrt_buf_ptr = pjrt_buffers[i]; + if (pjrt_buf_ptr->IsDeleted()) { continue; } - auto* arr = llvm::dyn_cast_or_null( - array->ifrt_array.get()); - if (arr == nullptr) { - throw XlaRuntimeError( - "This operation is implemented for a PjRt-compatible backend " - "only."); - } - TF_ASSIGN_OR_RETURN(absl::Span> pjrt_buffers, - arr->mutable_pjrt_buffers()); - for (int i = 0; i < pjrt_buffers.size(); ++i) { - std::shared_ptr& pjrt_buf_ptr = pjrt_buffers[i]; - if (pjrt_buf_ptr->IsDeleted()) { - continue; - } - auto [iter, inserted] = - pjrt_buf_to_tmp_buffer.insert({pjrt_buf_ptr.get(), TmpBuffer()}); - if (inserted) { - TF_ASSIGN_OR_RETURN(iter->second.host_copy, - pjrt_buf_ptr->ToLiteralSync()); - } - iter->second.pjrt_buffer_ptrs.push_back(&pjrt_buf_ptr); + auto [iter, inserted] = + pjrt_buf_to_tmp_buffer.insert({pjrt_buf_ptr.get(), TmpBuffer()}); + if (inserted) { + TF_ASSIGN_OR_RETURN(iter->second.host_copy, + pjrt_buf_ptr->ToLiteralSync()); } + iter->second.pjrt_buffer_ptrs.push_back(&pjrt_buf_ptr); } + } - // All buffers successfully copied to host, delete on-device copies. - // - // Use blocking delete operation to ensure all memory is actually cleared - // before we start rewriting buffers. - // - // Die instead of returning a bad status because program presumably can't - // continue if we fail to reconstitute device buffers. - for (const auto& it : pjrt_buf_to_tmp_buffer) { - PjRtBuffer* pjrt_buf = it.first; - TF_CHECK_OK(tensorflow::down_cast(pjrt_buf) - ->Release(/*wait_for_operations_to_complete=*/true) - .status()); - } + // All buffers successfully copied to host, delete on-device copies. + // + // Use blocking delete operation to ensure all memory is actually cleared + // before we start rewriting buffers. + // + // Die instead of returning a bad status because program presumably can't + // continue if we fail to reconstitute device buffers. + for (const auto& it : pjrt_buf_to_tmp_buffer) { + PjRtBuffer* pjrt_buf = it.first; + TF_CHECK_OK(pjrt_buf + ->ReleaseDeviceMemoryOwnership( + /*wait_for_operations_to_complete=*/true) + .status()); + } - // Copy host copies back to device and update PyArrays in-place. - for (auto& it : pjrt_buf_to_tmp_buffer) { - PjRtBuffer* pjrt_buf = it.first; - TmpBuffer& tmp_buffer = it.second; - std::unique_ptr new_copy = - pjrt_client() - ->BufferFromHostLiteral(*tmp_buffer.host_copy, pjrt_buf->device()) - .value(); - TF_CHECK_OK(new_copy->BlockHostUntilReady()); - - std::shared_ptr new_pjrt_buf_ptr(new_copy.release()); - for (std::shared_ptr* pjrt_buffer_ptr : - tmp_buffer.pjrt_buffer_ptrs) { - *pjrt_buffer_ptr = new_pjrt_buf_ptr; - } + // Copy host copies back to device and update PyArrays in-place. + for (auto& it : pjrt_buf_to_tmp_buffer) { + PjRtBuffer* pjrt_buf = it.first; + TmpBuffer& tmp_buffer = it.second; + std::unique_ptr new_copy = + pjrt_client() + ->BufferFromHostLiteral(*tmp_buffer.host_copy, pjrt_buf->device()) + .value(); + TF_CHECK_OK(new_copy->GetReadyFuture().Await()); + + std::shared_ptr new_pjrt_buf_ptr(new_copy.release()); + for (std::shared_ptr* pjrt_buffer_ptr : + tmp_buffer.pjrt_buffer_ptrs) { + *pjrt_buffer_ptr = new_pjrt_buf_ptr; } - - // TODO(skyewm): delete executables? } + + // TODO(skyewm): delete executables? return absl::OkStatus(); } @@ -419,6 +430,11 @@ PyClient::CompileIfrtProgram( *stats->bytes_limit); } } + + if (pjrt_compatible_client->pjrt_client()->key_value_store().has_value()) { + options.executable_build_options.set_key_value_store( + *pjrt_compatible_client->pjrt_client()->key_value_store()); + } } std::unique_ptr ifrt_loaded_executable; @@ -444,20 +460,9 @@ PyClient::CompileIfrtProgram( TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, ParseMlirModuleString(mlir_module, context)); if (options.executable_build_options.use_shardy_partitioner()) { - mlir::PassManager pm(&context); - // Since Shardy is inside the middle of the XLA pipeline, after converting - // down to HLO, we need to run the Shardy export pipeline to preserve the - // SDY ops and sharding attributes for when we come back from HLO to MLIR - // when Shardy propagation is run. - xla::sdy::addSdyRoundTripExportPipeline(pm); - // TODO(bartchr): remove setting `kPythonIntegrationComplete` in follow-up - // now that both JAX and PartIR are integrated with Shardy. - xla::sdy::addFrontendAttribute(*module, - xla::sdy::kPythonIntegrationComplete, - mlir::StringAttr::get(&context, "t")); - TF_RETURN_IF_ERROR( - tsl::StatusScopedDiagnosticHandler(&context).consumeStatus( - pm.run(*module))); + // Since Shardy is located in the middle of the XLA pipeline, we need to + // export it before going to HLO while preserving Shardy ops and attrs. + TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module)); } return CompileIfrtProgram( client, std::make_unique(module.get()), @@ -693,6 +698,9 @@ PyType_Slot PyClient::slots_[] = { .def("local_device_count", &PyClient::addressable_device_count) .def("devices", &PyClient::Devices) .def("local_devices", &PyClient::LocalDevices) + // TODO(hyeontaek): Remove this method once we have a unified API for + // enumerating devices with different criteria. + .def("_get_all_devices", &PyClient::GetAllDevices) .def("device_from_local_hardware_id", xla::ValueOrThrowWrapper(&PyClient::DeviceFromLocalHardwareId)) .def("live_executables", &PyClient::LiveExecutables) diff --git a/third_party/xla/xla/python/py_client.h b/third_party/xla/xla/python/py_client.h index 374b7f6d2e530c..826300165b67cf 100644 --- a/third_party/xla/xla/python/py_client.h +++ b/third_party/xla/xla/python/py_client.h @@ -32,7 +32,7 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/Support/Casting.h" #include "nanobind/nanobind.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" @@ -126,6 +126,11 @@ class PyClient { std::vector> Devices(); std::vector> LocalDevices(); + // Returns all devices in the client. Private API; only use this method for + // implementing backend._get_all_devices(). + // TODO(hyeontaek): Remove this method once we have a unified API for + // enumerating devices with different criteria. + std::vector> GetAllDevices(); absl::StatusOr> DeviceFromLocalHardwareId( int local_hardware_id); diff --git a/third_party/xla/xla/python/py_compile_only_client.cc b/third_party/xla/xla/python/py_compile_only_client.cc index cccb29b1ecd8d3..44990c4c953d8b 100644 --- a/third_party/xla/xla/python/py_compile_only_client.cc +++ b/third_party/xla/xla/python/py_compile_only_client.cc @@ -65,6 +65,7 @@ limitations under the License. #include "xla/python/nb_class_ptr.h" #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_attribute_map_util.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/pjrt_ifrt/pjrt_executable.h" #include "xla/python/pjrt_ifrt/pjrt_topology.h" #include "xla/python/pjrt_ifrt/xla_compiler.h" @@ -245,6 +246,16 @@ class CompileOnlyIfRtClient final "AssembleArrayFromSingleDeviceArrays not available with compile-only " "client."); } + absl::StatusOr> + AssembleArrayFromSingleDeviceArrays( + ifrt::Shape shape, std::shared_ptr sharding, + absl::Span> arrays, + ifrt::ArrayCopySemantics array_copy_semantics, + ifrt::SingleDeviceShardSemantics single_device_shard_semantics) override { + return Unimplemented( + "AssembleArrayFromSingleDeviceArrays not available with compile-only " + "client."); + } absl::StatusOr>> CopyArrays( absl::Span> arrays, @@ -294,6 +305,9 @@ class CompileOnlyIfRtClient final return {}; } int process_index() const override { return 0; } + absl::Span GetAllDevices() const override { + return devices_; + } absl::StatusOr GetDefaultDeviceAssignment( int num_replicas, int num_partitions) const override { return Unimplemented( diff --git a/third_party/xla/xla/python/py_device_list.cc b/third_party/xla/xla/python/py_device_list.cc index 22d701a5fb361f..a0ea40ce1efb81 100644 --- a/third_party/xla/xla/python/py_device_list.cc +++ b/third_party/xla/xla/python/py_device_list.cc @@ -104,7 +104,7 @@ int64_t PyDeviceList::Hash() { hash_ = absl::HashOf(std::get<0>(device_list_)); break; case 1: - hash_ = xla::nb_hash(std::get<1>(device_list_)); + hash_ = nb::hash(std::get<1>(device_list_)); break; default: throw nb::value_error("Unrecognized DeviceList type"); diff --git a/third_party/xla/xla/python/py_executable.cc b/third_party/xla/xla/python/py_executable.cc index b7395ad7793050..0bdff1204ac2f8 100644 --- a/third_party/xla/xla/python/py_executable.cc +++ b/third_party/xla/xla/python/py_executable.cc @@ -93,13 +93,11 @@ PyLoadedExecutable::PyLoadedExecutable( if (next_) { next_->prev_ = this; } - options_.untuple_result = true; if (fingerprint_) { options_.launch_id = tsl::Fingerprint32(*fingerprint_); VLOG(1) << "Fingerprint for executable " << ifrt_loaded_executable_->name() << ": " << *fingerprint_; } - options_.use_major_to_minor_data_layout_for_callbacks = true; } PyLoadedExecutable::~PyLoadedExecutable() { @@ -203,10 +201,9 @@ void PopulateExecuteShardedResults( template > absl::StatusOr ExecuteShardedOnLocalDevicesInternal( - const ExecuteOptions& options, const nb_class_ptr& client, + const ifrt::ExecuteOptions& options, const nb_class_ptr& client, ifrt::LoadedExecutable* ifrt_loaded_executable, absl::Span args, - std::optional>>& returned_futures, - bool attach_status_to_results) { + std::optional>>& returned_futures) { std::vector> output_arrays; std::unique_ptr> returned_future; int num_computations = ifrt_loaded_executable->addressable_devices().size(); @@ -232,13 +229,13 @@ absl::StatusOr ExecuteShardedOnLocalDevicesInternal( absl::MakeSpan(arg_arrays), options, /*devices=*/std::nullopt)); output_arrays = std::move(result.outputs); - // attach_status_to_results is only supposed to be true when the computation - // has tokens. - if (attach_status_to_results) { + // options.fill_status is only supposed to be true when the computation has + // tokens. + if (options.fill_status) { result_status = result.status; - } - if (returned_futures.has_value()) { - returned_futures->resize(num_computations, std::move(result.status)); + if (returned_futures.has_value()) { + returned_futures->resize(num_computations, std::move(result.status)); + } } } @@ -366,39 +363,43 @@ std::vector PyExecuteResults::ConsumeWithHandlers( absl::StatusOr>> PyLoadedExecutable::ExecuteShardedOnLocalDevices( absl::Span args) { + xla::ifrt::ExecuteOptions options = options_; + options.fill_status = false; std::optional>> returned_futures; - TF_ASSIGN_OR_RETURN( - auto outputs_and_tokens, - ExecuteShardedOnLocalDevicesInternal( - options_, client_, ifrt_loaded_executable_.get(), args, - returned_futures, /*attach_status_to_results=*/false)); + TF_ASSIGN_OR_RETURN(auto outputs_and_tokens, + ExecuteShardedOnLocalDevicesInternal( + options, client_, ifrt_loaded_executable_.get(), args, + returned_futures)); return outputs_and_tokens.DisassembleIntoSingleDeviceArrays(); } absl::StatusOr>, PyShardedToken>> PyLoadedExecutable::ExecuteShardedOnLocalDevicesWithTokens( absl::Span args) { + xla::ifrt::ExecuteOptions options = options_; + options.fill_status = true; std::optional>> returned_futures; returned_futures.emplace(); - TF_ASSIGN_OR_RETURN( - auto outputs_and_tokens, - ExecuteShardedOnLocalDevicesInternal( - options_, client_, ifrt_loaded_executable_.get(), args, - returned_futures, /*attach_status_to_results=*/true)); + TF_ASSIGN_OR_RETURN(auto outputs_and_tokens, + ExecuteShardedOnLocalDevicesInternal( + options, client_, ifrt_loaded_executable_.get(), args, + returned_futures)); return std::make_pair(outputs_and_tokens.DisassembleIntoSingleDeviceArrays(), outputs_and_tokens.ConsumeToken()); } absl::StatusOr PyLoadedExecutable::ExecuteSharded( std::vector args, bool with_tokens) { + xla::ifrt::ExecuteOptions options = options_; + options.fill_status = with_tokens; std::optional>> returned_futures; if (with_tokens) { returned_futures.emplace(); } absl::Span span_args = args; - return ExecuteShardedOnLocalDevicesInternal( - options_, client_, ifrt_loaded_executable_.get(), span_args, - returned_futures, /*attach_status_to_results=*/with_tokens); + return ExecuteShardedOnLocalDevicesInternal(options, client_, + ifrt_loaded_executable_.get(), + span_args, returned_futures); } absl::StatusOr>> diff --git a/third_party/xla/xla/python/py_executable.h b/third_party/xla/xla/python/py_executable.h index ed34ce99ef1a89..e032ee7b4acdda 100644 --- a/third_party/xla/xla/python/py_executable.h +++ b/third_party/xla/xla/python/py_executable.h @@ -227,7 +227,7 @@ class PyLoadedExecutable { return exec->shared_ptr_pjrt_loaded_executable(); } - const ExecuteOptions& options() const { return options_; } + const ifrt::ExecuteOptions& options() const { return options_; } const std::optional& fingerprint() const { return fingerprint_; } // Keep `obj` alive as long as PyLoadedExecutable. @@ -246,7 +246,7 @@ class PyLoadedExecutable { std::optional fingerprint_; // The options to pass to `executable_.Execute`. - ExecuteOptions options_; + ifrt::ExecuteOptions options_; // Python objects to keep alive as requested by user. std::vector keepalives_; diff --git a/third_party/xla/xla/python/py_values.cc b/third_party/xla/xla/python/py_values.cc index e5d37d0ebdc838..9a9c63a922e90d 100644 --- a/third_party/xla/xla/python/py_values.cc +++ b/third_party/xla/xla/python/py_values.cc @@ -46,7 +46,7 @@ limitations under the License. #include "xla/python/ifrt/sharding.h" #include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" -#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/python/py_array.h" #include "xla/python/python_ref_manager.h" #include "xla/python/sharding.h" @@ -185,6 +185,12 @@ absl::StatusOr HandleNumpyScalar( } else if (std::is_same()) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); type = BF16; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E3M4; + } else if (std::is_same()) { + PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); + type = F8E4M3; } else if (std::is_same()) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); type = F8E4M3FN; @@ -394,6 +400,14 @@ absl::StatusOr DevicePut(nb::handle arg, (*p)[dtypes.np_uint16.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_uint32.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_uint64.ptr()] = HandleNumpyScalar; + if (dtypes.np_float8_e3m4.has_value()) { + (*p)[dtypes.np_float8_e3m4->ptr()] = + HandleNumpyScalar; + } + if (dtypes.np_float8_e4m3.has_value()) { + (*p)[dtypes.np_float8_e4m3->ptr()] = + HandleNumpyScalar; + } (*p)[dtypes.np_float8_e4m3fn.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = @@ -583,6 +597,9 @@ absl::StatusOr PyArgSignatureOfValue(nb::handle arg, (*p)[dtypes.np_uint16.ptr()] = numpy_array_handler; (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler; (*p)[dtypes.np_uint64.ptr()] = np_uint64_handler; + // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + // (*p)[dtypes.np_float8_e3m4.ptr()] = numpy_array_handler; + // (*p)[dtypes.np_float8_e4m3.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3fn.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = numpy_array_handler; (*p)[dtypes.np_float8_e5m2.ptr()] = numpy_array_handler; diff --git a/third_party/xla/xla/python/pytree.cc b/third_party/xla/xla/python/pytree.cc index 65bfb3fe5305e4..e5662f9f5da674 100644 --- a/third_party/xla/xla/python/pytree.cc +++ b/third_party/xla/xla/python/pytree.cc @@ -595,15 +595,14 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const { case PyTreeKind::kNone: if (!object.is_none()) { - PythonDeprecationWarning( - /*stacklevel=*/3, - "In a future release of JAX, flatten-up-to will no longer " - "consider None to be a tree-prefix of non-None values, got: " - "%s.\n\n" - "To preserve the current behavior, you can usually write:\n" + throw std::invalid_argument(absl::StrFormat( + "Expected None, got %s.\n\n" + "In previous releases of JAX, flatten-up-to used to " + "consider None to be a tree-prefix of non-None values. To obtain " + "the previous behavior, you can usually write:\n" " jax.tree.map(lambda x, y: None if x is None else f(x, y), a, " "b, is_leaf=lambda x: x is None)", - nb::cast(nb::repr(object))); + nb::cast(nb::repr(object)))); } break; diff --git a/third_party/xla/xla/python/sharding.cc b/third_party/xla/xla/python/sharding.cc index cdad90eb430794..a5678221c42088 100644 --- a/third_party/xla/xla/python/sharding.cc +++ b/third_party/xla/xla/python/sharding.cc @@ -128,7 +128,7 @@ size_t ShardingHash(nb::handle sharding) { return absl::Hash()(single_device_sharding->device().ptr()); } - return xla::nb_hash(sharding); + return nb::hash(sharding); } bool ShardingEqual(nb::handle a, nb::handle b) { @@ -148,7 +148,9 @@ bool ShardingEqual(nb::handle a, nb::handle b) { a_named_sharding->memory_kind().equal( b_named_sharding->memory_kind()) && a_named_sharding->manual_axes().equal( - b_named_sharding->manual_axes()); + b_named_sharding->manual_axes()) && + a_named_sharding->logical_device_ids().equal( + b_named_sharding->logical_device_ids()); } if (a_type.is(GSPMDSharding::type())) { @@ -175,7 +177,8 @@ bool ShardingEqual(nb::handle a, nb::handle b) { NamedSharding::NamedSharding(nb::object mesh, nb::object spec, nb::object memory_kind, nb::object parsed_pspec, - nb::object manual_axes) + nb::object manual_axes, + nb::object logical_device_ids) : Sharding(/*num_devices=*/[&mesh]() { return nb::cast(mesh.attr("size")); }()), @@ -183,7 +186,8 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec, spec_(std::move(spec)), memory_kind_(std::move(memory_kind)), parsed_pspec_(std::move(parsed_pspec)), - manual_axes_(std::move(manual_axes)) { + manual_axes_(std::move(manual_axes)), + logical_device_ids_(std::move(logical_device_ids)) { nb::object idl = nb::object(mesh_.attr("_internal_device_list")); if (idl.is_none()) { internal_device_list_ = std::nullopt; @@ -261,16 +265,18 @@ void RegisterSharding(nb::module_& m) { nb::class_(m, "Sharding").def(nb::init<>()); nb::class_(m, "NamedSharding", nb::dynamic_attr()) - .def(nb::init(), nb::arg("mesh"), nb::arg("spec").none(), nb::arg("memory_kind").none() = nb::none(), nb::arg("_parsed_pspec").none() = nb::none(), - nb::arg("_manual_axes") = nb::steal(PyFrozenSet_New(nullptr))) + nb::arg("_manual_axes") = nb::steal(PyFrozenSet_New(nullptr)), + nb::arg("_logical_device_ids").none() = nb::none()) .def_prop_ro("mesh", &NamedSharding::mesh) .def_prop_ro("spec", &NamedSharding::spec) .def_prop_ro("_memory_kind", &NamedSharding::memory_kind) .def_prop_ro("_manual_axes", &NamedSharding::manual_axes) + .def_prop_ro("_logical_device_ids", &NamedSharding::logical_device_ids) .def_prop_rw("_parsed_pspec", &NamedSharding::parsed_pspec, &NamedSharding::set_parsed_pspec) .def_prop_ro("_internal_device_list", [](const NamedSharding& s) { diff --git a/third_party/xla/xla/python/sharding.h b/third_party/xla/xla/python/sharding.h index 847938478b2e85..5b41ae04110689 100644 --- a/third_party/xla/xla/python/sharding.h +++ b/third_party/xla/xla/python/sharding.h @@ -71,13 +71,17 @@ class NamedSharding : public Sharding { public: NamedSharding(nanobind::object mesh, nanobind::object spec, nanobind::object memory_kind, nanobind::object parsed_pspec, - nanobind::object manual_axes); + nanobind::object manual_axes, + nanobind::object logical_device_ids); const nanobind::object& mesh() const { return mesh_; } const nanobind::object& spec() const { return spec_; } const nanobind::object& memory_kind() const { return memory_kind_; } const nanobind::object& parsed_pspec() const { return parsed_pspec_; } const nanobind::object& manual_axes() const { return manual_axes_; } + const nanobind::object& logical_device_ids() const { + return logical_device_ids_; + } void set_parsed_pspec(nanobind::object parsed_pspec) { parsed_pspec_ = std::move(parsed_pspec); } @@ -102,6 +106,7 @@ class NamedSharding : public Sharding { nanobind::object memory_kind_; nanobind::object parsed_pspec_; nanobind::object manual_axes_; + nanobind::object logical_device_ids_; std::optional> internal_device_list_; }; diff --git a/third_party/xla/xla/python/types.cc b/third_party/xla/xla/python/types.cc index f3b8db6fdae018..125f96a75fdf25 100644 --- a/third_party/xla/xla/python/types.cc +++ b/third_party/xla/xla/python/types.cc @@ -41,7 +41,7 @@ limitations under the License. #include "xla/python/ifrt/dtype.h" #include "xla/python/nb_helpers.h" #include "xla/python/nb_numpy.h" -#include "xla/python/pjrt_ifrt/pjrt_array.h" +#include "xla/python/pjrt_ifrt/pjrt_dtype.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" @@ -59,6 +59,8 @@ namespace { struct CustomDtypes { nb_dtype bfloat16; + std::optional float8_e3m4; + std::optional float8_e4m3; nb_dtype float8_e4m3fn; nb_dtype float8_e4m3b11fnuz; nb_dtype float8_e4m3fnuz; @@ -75,6 +77,12 @@ const CustomDtypes& GetCustomDtypes() { nb::module_ ml_dtypes = nb::module_::import_("ml_dtypes"); auto* dtypes = new CustomDtypes; dtypes->bfloat16 = nb_dtype::from_args(ml_dtypes.attr("bfloat16")); + if (nb::hasattr(ml_dtypes, "float8_e3m4")) { + dtypes->float8_e3m4 = nb_dtype::from_args(ml_dtypes.attr("float8_e3m4")); + } + if (nb::hasattr(ml_dtypes, "float8_e4m3")) { + dtypes->float8_e4m3 = nb_dtype::from_args(ml_dtypes.attr("float8_e4m3")); + } dtypes->float8_e4m3fn = nb_dtype::from_args(ml_dtypes.attr("float8_e4m3fn")); dtypes->float8_e5m2 = nb_dtype::from_args(ml_dtypes.attr("float8_e5m2")); @@ -133,13 +141,19 @@ absl::StatusOr DtypeToPrimitiveType(const nb_dtype& np_type) { } }; struct DtypeHash { - ssize_t operator()(const nb_dtype& key) const { return nb_hash(key); } + ssize_t operator()(const nb_dtype& key) const { return nb::hash(key); } }; static auto* custom_dtype_map = []() { const CustomDtypes& custom_dtypes = GetCustomDtypes(); auto* map = new absl::flat_hash_map(); map->emplace(custom_dtypes.bfloat16, BF16); + if (custom_dtypes.float8_e3m4.has_value()) { + map->emplace(*custom_dtypes.float8_e3m4, F8E3M4); + } + if (custom_dtypes.float8_e4m3.has_value()) { + map->emplace(*custom_dtypes.float8_e4m3, F8E4M3); + } map->emplace(custom_dtypes.float8_e4m3fn, F8E4M3FN); map->emplace(custom_dtypes.float8_e4m3b11fnuz, F8E4M3B11FNUZ); map->emplace(custom_dtypes.float8_e4m3fnuz, F8E4M3FNUZ); @@ -204,6 +218,16 @@ absl::StatusOr PrimitiveTypeToNbDtype(PrimitiveType type) { return to_nb_dtype(NPY_UINT32); case U64: return to_nb_dtype(NPY_UINT64); + case F8E3M4: + if (custom_dtypes.float8_e3m4.has_value()) { + return *custom_dtypes.float8_e3m4; + } + break; + case F8E4M3: + if (custom_dtypes.float8_e4m3.has_value()) { + return *custom_dtypes.float8_e4m3; + } + break; case F8E4M3FN: return custom_dtypes.float8_e4m3fn; case F8E4M3B11FNUZ: @@ -284,6 +308,16 @@ absl::StatusOr IfrtDtypeToNbDtype(ifrt::DType dtype) { return to_nb_dtype(NPY_COMPLEX64); case ifrt::DType::kC128: return to_nb_dtype(NPY_COMPLEX128); + case ifrt::DType::kF8E3M4: + if (custom_dtypes.float8_e3m4.has_value()) { + return *custom_dtypes.float8_e3m4; + } + break; + case ifrt::DType::kF8E4M3: + if (custom_dtypes.float8_e4m3.has_value()) { + return *custom_dtypes.float8_e4m3; + } + break; case ifrt::DType::kF8E4M3FN: return custom_dtypes.float8_e4m3fn; case ifrt::DType::kF8E4M3B11FNUZ: @@ -347,6 +381,12 @@ const NumpyScalarTypes& GetNumpyScalarTypes() { dtypes->np_uint32 = nb::object(numpy.attr("uint32")); dtypes->np_uint64 = nb::object(numpy.attr("uint64")); dtypes->np_bfloat16 = nb::object(ml_dtypes.attr("bfloat16")); + if (nb::hasattr(ml_dtypes, "float8_e3m4")) { + dtypes->np_float8_e3m4 = nb::object(ml_dtypes.attr("float8_e3m4")); + } + if (nb::hasattr(ml_dtypes, "float8_e4m3")) { + dtypes->np_float8_e4m3 = nb::object(ml_dtypes.attr("float8_e4m3")); + } dtypes->np_float8_e4m3fn = nb::object(ml_dtypes.attr("float8_e4m3fn")); dtypes->np_float8_e4m3b11fnuz = nb::object(ml_dtypes.attr("float8_e4m3b11fnuz")); diff --git a/third_party/xla/xla/python/types.h b/third_party/xla/xla/python/types.h index ed7ca847b1a7f7..fece926edd3017 100644 --- a/third_party/xla/xla/python/types.h +++ b/third_party/xla/xla/python/types.h @@ -79,6 +79,9 @@ struct NumpyScalarTypes { nanobind::object np_uint32; nanobind::object np_uint64; nanobind::object np_bfloat16; + // Remove std::optional once the minimum ml_dtypes in JAX is >= 0.5.0. + std::optional np_float8_e3m4; + std::optional np_float8_e4m3; nanobind::object np_float8_e4m3fn; nanobind::object np_float8_e4m3b11fnuz; nanobind::object np_float8_e4m3fnuz; @@ -128,7 +131,6 @@ nanobind::tuple SpanToNbTuple(absl::Span xs) { // references to the objects. nanobind::tuple MutableSpanToNbTuple(absl::Span xs); - template std::vector IterableToVector(const nanobind::iterable& iterable) { std::vector output; diff --git a/third_party/xla/xla/python/weakref_lru_cache.cc b/third_party/xla/xla/python/weakref_lru_cache.cc index 1767e1dabb9cb1..ade03a916864f1 100644 --- a/third_party/xla/xla/python/weakref_lru_cache.cc +++ b/third_party/xla/xla/python/weakref_lru_cache.cc @@ -19,14 +19,16 @@ limitations under the License. #include #include #include +#include #include #include // NOLINT +#include #include #include #include "absl/base/thread_annotations.h" #include "absl/cleanup/cleanup.h" -#include "absl/container/node_hash_map.h" +#include "absl/hash/hash.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" @@ -35,7 +37,6 @@ limitations under the License. #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "xla/pjrt/lru_cache.h" -#include "xla/python/nb_helpers.h" namespace nb = nanobind; @@ -44,63 +45,92 @@ namespace { // Minimal wrapper to expose a nb::dict_iterator's value as something // hashable with Abseil. -class HashablePyDictValue { - protected: - using Iter = nb::detail::dict_iterator; +class HashablePyDictEntry { + public: + explicit HashablePyDictEntry(std::pair entry) + : entry_(entry) {} template - friend H AbslHashValue(H h, const HashablePyDictValue& value) { - auto kv = *value.iter_; - return H::combine(std::move(h), xla::nb_hash(kv.first), - xla::nb_hash(kv.second)); + friend H AbslHashValue(H h, const HashablePyDictEntry& v) { + return H::combine(std::move(h), nb::hash(v.entry_.first), + nb::hash(v.entry_.second)); } - explicit HashablePyDictValue(const Iter& iter) : iter_(iter) {} - - Iter iter_; + std::pair entry_; }; // Similarly, a minimalist adaptor around the nb::detail::dict_iterator // itself. Note that the iterator "is" also a Value. Does not meet the full // standard iterator requirements, only enough to support H::combine_unordered. -class HashablePyDictIter : protected HashablePyDictValue { +class HashablePyDictIter { public: using iterator_category = std::input_iterator_tag; - explicit HashablePyDictIter(const Iter& iter) : HashablePyDictValue(iter) {} + explicit HashablePyDictIter(nb::detail::dict_iterator& iter) : iter_(iter) {} // Minimal set of iterator operations. - const HashablePyDictValue& operator*() const { return *this; } + HashablePyDictEntry operator*() const { return HashablePyDictEntry(*iter_); } bool operator!=(const HashablePyDictIter& rhs) const { return iter_ != rhs.iter_; } void operator++() { ++iter_; } + + private: + nb::detail::dict_iterator& iter_; +}; + +struct HashableKey { + nb::object context; + nb::args args; + nb::kwargs kwargs; + + template + friend H AbslHashValue(H h, const HashableKey& key) { + // Note: Despite the fact this is an ABSL hash function, it's safe to call + // functions that may throw exceptions such as nb::hash(), because it is + // used by an LRUCache, which uses a std::unordered_map, which is + // exception-safe. + h = H::combine(std::move(h), nb::hash(key.context), nb::hash(key.args)); + nb::detail::dict_iterator begin = key.kwargs.begin(); + nb::detail::dict_iterator end = key.kwargs.end(); + h = H::combine_unordered(std::move(h), HashablePyDictIter(begin), + HashablePyDictIter(end)); + h = H::combine(std::move(h), key.kwargs.size()); + return h; + } }; } // namespace class WeakrefLRUCache : public std::enable_shared_from_this { public: - struct Key { - nb::object context; - nb::args args; - nb::kwargs kwargs; + class Key { + public: + Key(nb::object context, nb::args args, nb::kwargs kwargs) + : context_(std::move(context)), + args_(std::move(args)), + kwargs_(std::move(kwargs)), + cached_hash_(absl::HashOf(HashableKey{context_, args_, kwargs_})) {} bool operator==(const Key& other) const { - return context.equal(other.context) && args.equal(other.args) && - kwargs.equal(other.kwargs); + return context_.equal(other.context_) && args_.equal(other.args_) && + kwargs_.equal(other.kwargs_); } template friend H AbslHashValue(H h, const Key& key) { - h = H::combine(std::move(h), xla::nb_hash(key.context), - xla::nb_hash(key.args)); - h = H::combine_unordered(std::move(h), - HashablePyDictIter(key.kwargs.begin()), - HashablePyDictIter(key.kwargs.end())); - h = H::combine(std::move(h), key.kwargs.size()); - return h; + return H::combine(std::move(h), key.cached_hash_); } + + nb::object context() const { return context_; } + nb::args args() const { return args_; } + nb::kwargs kwargs() const { return kwargs_; } + + private: + nb::object context_; + nb::args args_; + nb::kwargs kwargs_; + size_t cached_hash_; }; struct CacheEntry { @@ -117,82 +147,76 @@ class WeakrefLRUCache : public std::enable_shared_from_this { int64_t currsize; }; - struct UnboundWeakrefCacheEntry { - nb::handle object; - WeakrefLRUCache* cache; + struct WeakrefCacheKey { + nb::weakref ref; size_t cached_hash; }; - struct WeakrefCacheEntry { - nb::weakref weakref; - size_t cached_hash; + using Cache = xla::LRUCache>; + + struct WeakrefCacheValue { + std::shared_ptr cache; }; struct WeakrefKeyHash { - using is_transparent = void; - - size_t operator()(const UnboundWeakrefCacheEntry& v) const { - return v.cached_hash; - } - size_t operator()(const WeakrefCacheEntry& v) const { - return v.cached_hash; - } + size_t operator()(const WeakrefCacheKey& v) const { return v.cached_hash; } }; struct WeakrefKeyEq { - using is_transparent = void; - bool operator()(const WeakrefCacheEntry& lhs, - const WeakrefCacheEntry& rhs) const { - return lhs.weakref.equal(rhs.weakref); - } - bool operator()(const WeakrefCacheEntry& lhs, - const UnboundWeakrefCacheEntry& rhs) const { - PyObject* obj = PyWeakref_GET_OBJECT(lhs.weakref.ptr()); - if (obj == Py_None) { - return false; - } - return nb::borrow(obj).equal(rhs.object); + bool operator()(const WeakrefCacheKey& lhs, + const WeakrefCacheKey& rhs) const { + return lhs.ref.equal(rhs.ref); } }; - using Cache = xla::LRUCache>; WeakrefLRUCache(nb::callable cache_context_fn, nb::callable fn, int64_t maxsize) : cache_context_fn_(cache_context_fn), fn_(fn), lru_list_(maxsize) {} - std::shared_ptr GetCache(const UnboundWeakrefCacheEntry& key) { - auto it = entries_.find(key); - if (it != entries_.end()) { - return (it->second); + std::shared_ptr GetCache(WeakrefCacheKey key) { + WeakrefCacheValue& value = entries_[key]; + if (!value.cache) { + value.cache = std::make_shared(&lru_list_); } - nb::weakref weakref( - key.object, - nb::cpp_function([this_weak = weak_from_this(), - cached_hash = key.cached_hash](nb::handle weakref) { + return value.cache; + } + + nb::object Call(nb::object weakref_key, nb::args args, + nb::kwargs kwargs) ABSL_NO_THREAD_SAFETY_ANALYSIS { + nb::object context = cache_context_fn_(); + + // We precompute all of the hash values needed by the various maps rather + // than computing them during the std::unordered_map insertions. At the very + // least, MSVC's std::unordered_map has undefined behavior if the hash + // function throws an exception + // (https://learn.microsoft.com/en-us/cpp/standard-library/unordered-map-class?view=msvc-170#emplace). + Key key(context, args, kwargs); + size_t wrcache_hash = static_cast(nb::hash(weakref_key)); + + // No hash computations after this point. + + auto weakref_gc_callback = nb::cpp_function( + [this_weak = weak_from_this(), wrcache_hash](nb::handle weakref) { auto cache = this_weak.lock(); if (cache == nullptr) { return; } + // The object the reference referred to is now in the process of being + // destroyed, so we cannot refer to its contents. Python weakref + // objects compare based on identity if the object they refer to is + // gone, so the hash lookup will work fine. auto it = cache->entries_.find( - WeakrefCacheEntry{nb::borrow(weakref), cached_hash}); + WeakrefCacheKey{nb::borrow(weakref), wrcache_hash}); if (it == cache->entries_.end()) { return; } // Create temp-var to avoid re-entrant erase. auto tmp = std::move(it->second); cache->entries_.erase(it); - })); - return (entries_ - .emplace(WeakrefCacheEntry{std::move(weakref), key.cached_hash}, - std::make_shared(&lru_list_)) - .first->second); - } - - nb::object Call(nb::object weakref_key, nb::args args, - nb::kwargs kwargs) ABSL_NO_THREAD_SAFETY_ANALYSIS { - nb::object context = cache_context_fn_(); - std::shared_ptr cache_ptr = GetCache(UnboundWeakrefCacheEntry{ - weakref_key, this, static_cast(xla::nb_hash(weakref_key))}); + }); + nb::weakref weakref = nb::weakref(weakref_key, weakref_gc_callback); + WeakrefCacheKey wrcache_key{weakref, wrcache_hash}; + std::shared_ptr cache_ptr = GetCache(wrcache_key); Cache& cache = *cache_ptr; ++total_queries_; @@ -212,7 +236,6 @@ class WeakrefLRUCache : public std::enable_shared_from_this { // released if that happens. absl::Cleanup unlock = [this]() ABSL_UNLOCK_FUNCTION(mu_) { mu_.Unlock(); }; - Key key{context, args, kwargs}; entry = cache.GetOrCreateIfAbsent(key, [&inserted](const Key& key) { inserted = true; return std::make_shared(); @@ -248,11 +271,11 @@ class WeakrefLRUCache : public std::enable_shared_from_this { std::vector GetKeys() { std::vector results; mu_.Lock(); - for (const auto& wr_key : entries_) { - for (const auto& rest : *wr_key.second) { + for (const auto& wr_entry : entries_) { + for (const auto& rest : *wr_entry.second.cache) { nb::tuple result = - nb::make_tuple(wr_key.first.weakref, rest.first.context, - rest.first.args, rest.first.kwargs); + nb::make_tuple(*wr_entry.first.ref, rest.first.context(), + rest.first.args(), rest.first.kwargs()); results.push_back(std::move(result)); } } @@ -270,8 +293,9 @@ class WeakrefLRUCache : public std::enable_shared_from_this { void Clear() { total_queries_ = misses_ = 0; std::vector> deferred_deletes; + deferred_deletes.reserve(entries_.size()); for (auto& entry : entries_) { - deferred_deletes.push_back(std::move(entry.second)); + deferred_deletes.push_back(std::move(entry.second.cache)); } entries_.clear(); deferred_deletes.clear(); @@ -280,8 +304,8 @@ class WeakrefLRUCache : public std::enable_shared_from_this { nb::callable cache_context_fn_; nb::callable fn_; Cache::LRUList lru_list_; - absl::node_hash_map, WeakrefKeyHash, - WeakrefKeyEq> + std::unordered_map entries_; int64_t misses_ = 0; int64_t total_queries_ = 0; diff --git a/third_party/xla/xla/python/weakref_lru_cache_test.py b/third_party/xla/xla/python/weakref_lru_cache_test.py index ad5f07bee0bf72..92aa783d6b52a6 100644 --- a/third_party/xla/xla/python/weakref_lru_cache_test.py +++ b/third_party/xla/xla/python/weakref_lru_cache_test.py @@ -15,6 +15,7 @@ import threading import time +import weakref from absl.testing import absltest @@ -111,6 +112,19 @@ class WRKey: cache(wrkey, "arg2") self.assertLen(cache.cache_keys(), 2) + def testNonWeakreferenceableKey(self): + class NonWRKey: + __slots__ = () + + non_wr_key = NonWRKey() + with self.assertRaises(TypeError): + weakref.ref(non_wr_key) + + cache = xla_client.weakref_lru_cache(lambda: None, lambda x: 2048) + for _ in range(100): + with self.assertRaises(TypeError): + cache(non_wr_key) + def testCrashingKey(self): class WRKey: pass @@ -146,6 +160,29 @@ class WRKey: "WeakrefLRUCache(hits=5, misses=10, maxsize=2048, currsize=10)", ) + def testGCKeys(self): + class WRKey: + + def __init__(self, x): + self.x = x + + def __eq__(self, other): + return self.x == other.x + + def __hash__(self): + return hash(self.x) + + cache = xla_client.weakref_lru_cache(lambda: None, lambda x, y: y, 2048) + keys = [WRKey(i) for i in range(10)] + for i in range(10): + cache(keys[i], i) + + # Delete some keys, to exercise the weakref callback behavior. + del keys[::2] + + for key in keys: + cache(key, 7) + if __name__ == "__main__": absltest.main() diff --git a/third_party/xla/xla/python/xla.cc b/third_party/xla/xla/python/xla.cc index 987b06b9f8dd4a..3605259af1f148 100644 --- a/third_party/xla/xla/python/xla.cc +++ b/third_party/xla/xla/python/xla.cc @@ -90,6 +90,7 @@ limitations under the License. #include "xla/pjrt/pjrt_layout.h" #include "xla/python/custom_call_sharding.h" #include "xla/python/dlpack.h" +#include "xla/python/guard_lib.h" #include "xla/python/jax_jit.h" #include "xla/python/logging.h" // IWYU pragma: keep #include "xla/python/mlir.h" @@ -97,7 +98,6 @@ limitations under the License. #include "xla/python/nb_absl_span.h" // IWYU pragma: keep #include "xla/python/nb_class_ptr.h" #include "xla/python/ops.h" -#include "xla/python/outfeed_receiver_py.h" #include "xla/python/pjit.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_executable.h" @@ -115,7 +115,6 @@ limitations under the License. #include "xla/python/pytree.h" #include "xla/python/sharding.h" #include "xla/python/traceback.h" -#include "xla/python/transfer_guard_lib.h" #include "xla/python/weakref_lru_cache.h" #include "xla/python/xla_compiler.h" #include "xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h" @@ -182,6 +181,10 @@ NB_MODULE(xla_extension, m_nb) { // Exceptions nb::exception xla_runtime_error(m_nb, "XlaRuntimeError", PyExc_RuntimeError); + xla_runtime_error.attr("__doc__") = nb::str( + "Runtime errors thrown by the JAX runtime. While the JAX runtime may " + "raise other exceptions as well, most exceptions thrown by the runtime " + "are instances of this class."); // Types nb::enum_(m_nb, "PrimitiveType", nb::is_arithmetic()) @@ -198,6 +201,9 @@ NB_MODULE(xla_extension, m_nb) { .value("U32", U32) .value("U64", U64) .value("F16", F16) + // TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + // .value("F8E3M4", F8E3M4) + // .value("F8E4M3", F8E4M3) .value("F8E4M3FN", F8E4M3FN) .value("F8E4M3B11FNUZ", F8E4M3B11FNUZ) .value("F8E4M3FNUZ", F8E4M3FNUZ) @@ -334,7 +340,7 @@ NB_MODULE(xla_extension, m_nb) { options.collectives = std::move(collectives); options.process_id = node_id; std::unique_ptr client = - xla::ValueOrThrow(GetTfrtCpuClient(options)); + xla::ValueOrThrow(GetTfrtCpuClient(std::move(options))); ifrt::PjRtClient::CreateOptions ifrt_options; ifrt_options.pjrt_client = std::shared_ptr(std::move(client)); @@ -584,12 +590,11 @@ NB_MODULE(xla_extension, m_nb) { BuildIfrtProgramsSubmodule(m_nb); BuildProfilerSubmodule(m_nb); BuildOpsSubmodule(m_nb); - BuildOutfeedReceiverSubmodule(m_nb); BuildPytreeSubmodule(m_nb); + jax::BuildGuardSubmodule(m_nb); jax::BuildJaxjitSubmodule(m_nb); jax::BuildPmapSubmodule(m_nb); jax::BuildPjitSubmodule(m_nb); - jax::BuildTransferGuardSubmodule(m_nb); BuildTracebackSubmodule(m_nb); BuildMlirSubmodule(m_nb); BuildCustomCallShardingPybindAPI(m_nb); @@ -772,11 +777,12 @@ NB_MODULE(xla_extension, m_nb) { std::optional init_timeout, std::optional shutdown_timeout, std::optional heartbeat_interval, std::optional max_missing_heartbeats, - std::optional> + std::optional> missed_heartbeat_callback, - std::optional shutdown_on_destruction) + std::optional shutdown_on_destruction, + std::optional use_compression) -> std::shared_ptr { + bool compression = use_compression.value_or(false); DistributedRuntimeClient::Options options; options.node_id = node_id; if (rpc_timeout.has_value()) { @@ -801,7 +807,7 @@ NB_MODULE(xla_extension, m_nb) { if (shutdown_on_destruction.has_value()) { options.shutdown_on_destruction = *shutdown_on_destruction; } - return GetDistributedRuntimeClient(address, options); + return GetDistributedRuntimeClient(address, options, compression); }, nb::arg("address"), nb::arg("node_id"), nb::arg("rpc_timeout").none() = std::nullopt, @@ -810,7 +816,8 @@ NB_MODULE(xla_extension, m_nb) { nb::arg("heartbeat_interval").none() = std::nullopt, nb::arg("max_missing_heartbeats").none() = std::nullopt, nb::arg("missed_heartbeat_callback").none() = std::nullopt, - nb::arg("shutdown_on_destruction").none() = std::nullopt); + nb::arg("shutdown_on_destruction").none() = std::nullopt, + nb::arg("use_compression").none() = std::nullopt); m_nb.def("collect_garbage", []() { GlobalPyRefManager()->CollectGarbage(); }); diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py index 89332de94b0b82..58da33cc93b928 100644 --- a/third_party/xla/xla/python/xla_client.py +++ b/third_party/xla/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 286 +_version = 293 # Version number for MLIR:Python components. mlir_api_version = 57 @@ -89,6 +89,7 @@ def make_gpu_client( platform_name=None, allowed_devices=None, mock=False, + mock_gpu_topology=None, ): """Returns a GPU client. BFC allocator is used by default.""" options = generate_pjrt_gpu_plugin_options() @@ -120,6 +121,7 @@ def make_gpu_client( platform_name=platform_name, allowed_devices=allowed_devices, mock=mock, + mock_gpu_topology=mock_gpu_topology, ) @@ -274,6 +276,9 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): PrimitiveType = _xla.PrimitiveType bfloat16 = ml_dtypes.bfloat16 +# TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. +# float8_e3m4 = ml_dtypes.float8_e3m4 +# float8_e4m3 = ml_dtypes.float8_e4m3 float8_e4m3fn = ml_dtypes.float8_e4m3fn float8_e4m3b11fnuz = ml_dtypes.float8_e4m3b11fnuz float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz @@ -292,6 +297,9 @@ def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): PrimitiveType.U16: np.dtype('uint16'), PrimitiveType.U32: np.dtype('uint32'), PrimitiveType.U64: np.dtype('uint64'), + # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + # PrimitiveType.F8E3M4: np.dtype(float8_e3m4), + # PrimitiveType.F8E4M3: np.dtype(float8_e4m3), PrimitiveType.F8E4M3FN: np.dtype(float8_e4m3fn), PrimitiveType.F8E4M3B11FNUZ: np.dtype(float8_e4m3b11fnuz), PrimitiveType.F8E5M2: np.dtype(float8_e5m2), @@ -469,40 +477,6 @@ def computation_count(): # There are different implementations of Executable for different backends. -def execute_with_python_values(executable, arguments, backend): - """Execute on one replica with Python values as arguments and output.""" - - def put(arg): - return backend.buffer_from_pyval(arg, device=executable.local_devices()[0]) - - arguments = [put(arg) for arg in arguments] - outputs = executable.execute(arguments) - return [np.asarray(x) for x in outputs] - - -def execute_with_python_values_replicated(executable, arguments, backend): - """Execute on many replicas with Python values as arguments and output. - - Args: - executable: the program to run. - arguments: a list of lists of Python values indexed by `[replica][arg_num]` - to pass as inputs. - backend: the backend we are targeting. - - Returns: - A list of python values, one per replica. - """ - devices = executable.local_devices() - - # pylint: disable=g-complex-comprehension - def copy_to_devices(pyvals): - return [backend.buffer_from_pyval(v, d) for v, d in zip(pyvals, devices)] - - inputs = [copy_to_devices(pyvals) for pyvals in zip(*arguments)] - outputs = executable.execute_sharded_on_local_devices(inputs) - return [[np.asarray(x) for x in xs] for xs in zip(*outputs)] - - class PaddingType(enum.Enum): VALID = 1 SAME = 2 diff --git a/third_party/xla/xla/python/xla_client.pyi b/third_party/xla/xla/python/xla_client.pyi index 8731080c99b52a..f12b8fb3807712 100644 --- a/third_party/xla/xla/python/xla_client.pyi +++ b/third_party/xla/xla/python/xla_client.pyi @@ -59,6 +59,9 @@ _version: int mlir_api_version: int bfloat16: type[numpy.generic] +# TODO: Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. +# float8_e3m4: type[numpy.generic] +# float8_e4m3: type[numpy.generic] float8_e4m3fn: type[numpy.generic] float8_e4m3b11fnuz: type[numpy.generic] float8_e4m3fnuz: type[numpy.generic] @@ -71,13 +74,6 @@ _NameValueMapping = Mapping[str, Union[str, int, list[int], float, bool]] def dtype_to_etype(dtype: numpy.dtype) -> PrimitiveType: ... -def execute_with_python_values(executable: LoadedExecutable, arguments: Sequence[Any], - backend: Client) -> Sequence[numpy.ndarray]: ... - -def execute_with_python_values_replicated( - executable: LoadedExecutable, arguments: Sequence[Sequence[Any]], - backend: Client) -> Sequence[Sequence[numpy.ndarray]]: ... - def shape_from_pyval(pyval: Any, layout: Sequence[int] | None = None) -> Any: ... def heap_profile(client: Client) -> bytes: @@ -101,6 +97,7 @@ def make_gpu_client( platform_name: str | None = ..., allowed_devices: set[int] | None = ..., mock: bool | None = ..., + mock_gpu_topology: str | None = ..., ) -> Client: ... diff --git a/third_party/xla/xla/python/xla_client_test.py b/third_party/xla/xla/python/xla_client_test.py index d4406155eadd42..295f37332fd257 100644 --- a/third_party/xla/xla/python/xla_client_test.py +++ b/third_party/xla/xla/python/xla_client_test.py @@ -27,6 +27,7 @@ from absl import logging from absl.testing import absltest from absl.testing import parameterized +import ml_dtypes import numpy as np from xla.python import xla_client @@ -54,6 +55,9 @@ xla_client._xla.jax_jit.global_state().enable_memories = False bfloat16 = xla_client.bfloat16 +# TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. +# float8_e3m4 = xla_client.float8_e3m4 +# float8_e4m3 = xla_client.float8_e4m3 float8_e4m3fn = xla_client.float8_e4m3fn float8_e4m3fnuz = xla_client.float8_e4m3fnuz float8_e4m3b11fnuz = xla_client.float8_e4m3b11fnuz @@ -64,6 +68,17 @@ xla_client._xla.mlir.xla_computation_to_mlir_module) +def execute_with_python_values(executable, arguments, backend): # pylint: disable=invalid-name + """Execute on one replica with Python values as arguments and output.""" + + def put(arg): # pylint: disable=invalid-name + return backend.buffer_from_pyval(arg, device=executable.local_devices()[0]) + + arguments = [put(arg) for arg in arguments] + outputs = executable.execute(arguments) + return [np.asarray(x) for x in outputs] + + # pylint: disable=invalid-name def jax_array_convert_to_array(self, dtype=None, copy=None): del copy @@ -100,10 +115,19 @@ def jax_array_copy_to_host_async(self): _CUSTOM_CALLS_REGISTERED = False +# XLA' alignment is 16 bytes at the moment, but it should match what Eigen +# supports, and that can go up to 128 bytes on hardware with HVX. +_XLA_CPU_MAX_ALIGNMENT = 128 + + +# Minimum possible alignment for XLA. +_XLA_CPU_MIN_ALIGNMENT = 16 + + # Return a copy of `x` with the given alignment. Does nothing if `x` is already # aligned. We do this manually, because numpy doesn't support custom alignment # value. -def _Aligned(x, alignment=128): +def _Aligned(x, alignment=_XLA_CPU_MAX_ALIGNMENT): if (x.ctypes.data % alignment) == 0: return x @@ -122,6 +146,31 @@ def _Aligned(x, alignment=128): return result +# Return an unaligned copy of `x`. The result buffer's memory address is +# guaranteed to not be aligned to `alignment`. This function is useful for +# testing failiures. +def _Unaligned(x, alignment=_XLA_CPU_MIN_ALIGNMENT): + if (x.ctypes.data % alignment) != 0: + return x + + # Create temporary buffer with extra space. + assert (x.itemsize % alignment) != 0 + offset = 1 + buf = np.empty(x.size + offset, dtype=x.dtype) + + if (buf.ctypes.data % alignment) != 0: + # If the temporary buffer is already unaligned, return it. + result = buf + else: + # Otherwise, create a view of the temporary buffer with an offset. + result = buf[offset : offset + x.size].reshape(x.shape) + assert (result.ctypes.data % alignment) != 0 + + # Copy the data to the result buffer and return it. + np.copyto(result, x) + return result + + def TestFactory(xla_backend, cloud_tpu=False, tfrt_tpu=False, @@ -138,7 +187,10 @@ def TestFactory(xla_backend, # TODO(zhangqiaorjc): test fp8 types when XLA support is complete. # standard_dtypes is only used for BufferProtocolTest so we only test fp8 # round trip tests. - standard_dtypes += [float8_e4m3b11fnuz, float8_e4m3fn, float8_e5m2] + fp8_dtypes = [float8_e4m3b11fnuz, float8_e4m3fn, float8_e5m2] + standard_dtypes += fp8_dtypes + # TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0. + # standard_dtypes += [float8_e3m4, float8_e4m3] dlpack_dtypes = int_dtypes + float_dtypes + [np.bool_] + complex_dtypes class ComputationTest(parameterized.TestCase): @@ -164,7 +216,7 @@ def _NewComputation(self, name=None): def _Execute(self, c, arguments): compiled_c = self.backend.compile( xla_computation_to_mlir_module(c.build())) - return xla_client.execute_with_python_values( + return execute_with_python_values( compiled_c, arguments, backend=self.backend) def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected): @@ -596,7 +648,7 @@ def testExecuteFromProto(self): # Load and execute the proto c = xla_client.XlaComputation(serialized_proto) m = xla_computation_to_mlir_module(c) - ans, = xla_client.execute_with_python_values( + ans, = execute_with_python_values( self.backend.compile(m), (), backend=self.backend) np.testing.assert_equal(ans, np.int32(3)) @@ -1245,7 +1297,7 @@ def testConvertElementType(self, src_dtype, dst_dtype): ops.ConvertElementType( ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype)) - result = xla_client.execute_with_python_values( + result = execute_with_python_values( self.backend.compile(xla_computation_to_mlir_module(c.build())), (), backend=self.backend) self.assertLen(result, 1) @@ -1275,7 +1327,7 @@ def testBitcastConvertType(self, src_dtype, dst_dtype): ops.BitcastConvertType( ops.Constant(c, x), xla_client.dtype_to_etype(dst_dtype)) - result = xla_client.execute_with_python_values( + result = execute_with_python_values( self.backend.compile(xla_computation_to_mlir_module(c.build())), (), backend=self.backend) self.assertLen(result, 1) @@ -1859,7 +1911,7 @@ def testTuple(self): ops.Constant(c, NumpyArrayF32([1.0, 2.0])), ops.Constant(c, NumpyArrayBool([True, False, False, True])) ]) - result = xla_client.execute_with_python_values( + result = execute_with_python_values( self.backend.compile(xla_computation_to_mlir_module(c.build())), (), backend=self.backend) self.assertLen(result, 3) @@ -1899,7 +1951,7 @@ def testRngNormal(self): ops.Constant(c, NumpyArrayF32(1.)), shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, shape)) - result = xla_client.execute_with_python_values( + result = execute_with_python_values( self.backend.compile(xla_computation_to_mlir_module(c.build())), (), backend=self.backend) # since the result is random, we just check shape and uniqueness @@ -1916,7 +1968,7 @@ def testRngUniformF32(self): ops.Constant(c, NumpyArrayF32(hi)), shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.F32, shape)) - result = xla_client.execute_with_python_values( + result = execute_with_python_values( self.backend.compile(xla_computation_to_mlir_module(c.build())), (), backend=self.backend) # since the result is random, we just check shape, uniqueness, and range @@ -1935,7 +1987,7 @@ def testRngUniformS32(self): ops.Constant(c, NumpyArrayS32(hi)), shape=xla_client.Shape.array_shape(xla_client.PrimitiveType.S32, shape)) - result = xla_client.execute_with_python_values( + result = execute_with_python_values( self.backend.compile(xla_computation_to_mlir_module(c.build())), (), backend=self.backend) # since the result is random, we just check shape, integrality, and range @@ -1965,7 +2017,7 @@ def testSortKeyVal(self): values = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32) c = self._NewComputation() ops.Sort(c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=0) - result = xla_client.execute_with_python_values( + result = execute_with_python_values( self.backend.compile(xla_computation_to_mlir_module(c.build())), (), backend=self.backend) self.assertLen(result, 2) @@ -1988,7 +2040,7 @@ def testSortCustomComparator(self): c, (ops.Constant(c, keys), ops.Constant(c, values)), dimension=1, comparator=comparator) - result = xla_client.execute_with_python_values( + result = execute_with_python_values( self.backend.compile(xla_computation_to_mlir_module(c.build())), (), backend=self.backend) self.assertLen(result, 2) @@ -2175,15 +2227,36 @@ def testFft(self): c, expected=[np.fft.irfftn(a, axes=(1, 2, 3))], rtol=2e-4 ) - def testNextAfter(self): - c = self._NewComputation() - ops.NextAfter( - ops.Constant(c, np.array([1, 2], dtype=np.float32)), - ops.Constant(c, np.array([2, 1], dtype=np.float32))) + @parameterized.named_parameters({ + "testcase_name": "_{}".format(dtype.__name__), + "dtype": dtype, + } for dtype in float_dtypes + fp8_dtypes) + def testNextAfter(self, dtype): + if dtype == np.float64 and self.backend.platform == "tpu": + self.skipTest("TPU doesn't support float64") + if dtype == bfloat16 and self.backend.platform == "tpu": + self.skipTest("b/371119032: Test fails on TPUs with bfloat16") + finfo = ml_dtypes.finfo(dtype) + eps = finfo.eps + c = self._NewComputation() + # Each row is (value, direction, expected), where + # 'nextafter(value, direction)' should be 'expected'. + data = np.array( + [ + [1, 2, 1 + finfo.eps], + [2, 1, 2 - eps], + [-0., 1, finfo.smallest_subnormal], + [0., -1, -finfo.smallest_subnormal], + [-finfo.smallest_subnormal, 1, -0.], + [finfo.smallest_subnormal, 1, 2 * finfo.smallest_subnormal], + [finfo.smallest_subnormal, -1, 0], + ], + dtype=dtype, + ) + + ops.NextAfter(ops.Constant(c, data[:, 0]), ops.Constant(c, data[:, 1])) out, = self._Execute(c, ()) - eps = np.finfo(np.float32).eps - np.testing.assert_equal( - np.array([eps + 1, 2 - eps], dtype=np.float32), out) + np.testing.assert_equal(out, data[:, 2]) @parameterized.named_parameters({ "testcase_name": "_{}".format(dtype.__name__), @@ -2578,7 +2651,7 @@ def testInfeedS32Values(self): device.transfer_to_infeed(item) for item in to_infeed: - result, = xla_client.execute_with_python_values( + result, = execute_with_python_values( compiled_c, (), backend=self.backend) self.assertEqual(result, item) @@ -2597,7 +2670,7 @@ def testInfeedTuple(self): device = self.backend.local_devices()[0] device.transfer_to_infeed(to_infeed) - result = xla_client.execute_with_python_values( + result = execute_with_python_values( compiled_c, (), backend=self.backend) self.assertLen(result, 2) np.testing.assert_equal(result[0], to_infeed[0]) @@ -2653,6 +2726,17 @@ def testScatter(self): class DeviceTest(ComputationTest): + def testDevices(self): + self.assertNotEmpty(self.backend.devices()) + + def testLocalDevices(self): + self.assertNotEmpty(self.backend.local_devices()) + + def testGetAllDevices(self): + # TODO(hyeontaek): Remove this method once we have a unified API for + # enumerating devices with different criteria. + self.assertNotEmpty(self.backend._get_all_devices()) # pylint: disable=protected-access + def testPlatform(self): for device in self.backend.local_devices(): self.assertEqual(device.platform, self.backend.platform) @@ -2741,7 +2825,7 @@ def testInvokeWithWrongElementType(self): c.clear_op_metadata() def TestFun(): - return xla_client.execute_with_python_values( + return execute_with_python_values( self.backend.compile(xla_computation_to_mlir_module(c.build())), [self.f32_scalar_2], self.backend) @@ -2763,7 +2847,7 @@ def testComputationRootDifferentFromLastOp(self): arg = NumpyArrayF32(1.0) compiled_c = self.backend.compile( xla_computation_to_mlir_module(c.build(result))) - ans, = xla_client.execute_with_python_values( + ans, = execute_with_python_values( compiled_c, [arg], backend=self.backend) np.testing.assert_allclose(ans, 4.14) @@ -2787,7 +2871,7 @@ def testSetSharding(self): arg = NumpyArrayF32(1.0) compiled_c = self.backend.compile( xla_computation_to_mlir_module(c.build(result))) - ans, = xla_client.execute_with_python_values( + ans, = execute_with_python_values( compiled_c, [arg], backend=self.backend) np.testing.assert_allclose(ans, 4.14) @@ -2831,18 +2915,53 @@ def tearDown(self): del self.cpu_backend del self.gpu_backend + @classmethod + def _GetStreamFromDevice(cls, device): + try: + return device.get_stream_for_external_ready_events() + except xla_client.XlaRuntimeError as err: # type: ignore + if "UNIMPLEMENTED" in str(err): + return None + else: + raise + + def _DLPackManagedTensorToBuffer( + self, tensor, use_legacy_api, backend=None + ): + if use_legacy_api: + return xla_client._xla.dlpack_managed_tensor_to_buffer( + tensor, self.cpu_backend, self.gpu_backend + ) + else: + if not backend: + backend = self.backend + device = backend.local_devices()[0] + stream = DLPackTest._GetStreamFromDevice(device) + return xla_client._xla.dlpack_managed_tensor_to_buffer( + tensor, device, stream + ) + # pylint: disable=g-complex-comprehension # pyformat: disable - @parameterized.named_parameters({ - "testcase_name": "{}_gpu={}".format( - FormatShapeAndDtype(shape, dtype), gpu), - "dtype": dtype, - "shape": shape, - "gpu": gpu - } for dtype in dlpack_dtypes for shape in testcase_shapes - for gpu in [False, True]) + @parameterized.named_parameters( + { + "testcase_name": "{}_gpu={}{}".format( + FormatShapeAndDtype(shape, dtype), + gpu, + "_legacy" if use_legacy_api else "", + ), + "dtype": dtype, + "shape": shape, + "gpu": gpu, + "use_legacy_api": use_legacy_api, + } + for dtype in dlpack_dtypes + for shape in testcase_shapes + for gpu in [False, True] + for use_legacy_api in [False, True] + ) # pyformat: enable - def testRoundTrip(self, dtype, shape, gpu): + def testRoundTrip(self, dtype, shape, gpu, use_legacy_api): if gpu and self.gpu_backend is None: raise unittest.SkipTest("Test not running with GPU support") backend = self.gpu_backend if gpu else self.cpu_backend @@ -2854,41 +2973,130 @@ def testRoundTrip(self, dtype, shape, gpu): dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer) del buffer # Free "buffer" to make sure dlt retains ownership. self.assertEqual(type(dlt).__name__, "PyCapsule") - y = xla_client._xla.dlpack_managed_tensor_to_buffer( - dlt, self.cpu_backend, self.gpu_backend) + y = self._DLPackManagedTensorToBuffer(dlt, use_legacy_api, backend) np.testing.assert_array_equal( x.astype(np.uint8) if dtype == np.bool_ else x, np.asarray(y)) - def testTensorsCanBeConsumedOnceOnly(self): + @parameterized.named_parameters( + { + "testcase_name": "{}".format("_legacy" if use_legacy_api else ""), + "use_legacy_api": use_legacy_api, + } + for use_legacy_api in [False, True] + ) + def testTensorsCanBeConsumedOnceOnly(self, use_legacy_api): x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) buffer = self.backend.buffer_from_pyval(x) dlt = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer) def ConsumeDLPackTensor(): - _ = xla_client._xla.dlpack_managed_tensor_to_buffer( - dlt, self.cpu_backend, self.gpu_backend - ) + _ = self._DLPackManagedTensorToBuffer(dlt, use_legacy_api) ConsumeDLPackTensor() self.assertRaisesRegex( RuntimeError, ".*a DLPack tensor may be consumed at most once.*", ConsumeDLPackTensor) - def testNonOwnedDlpackCanBeViewedTwice(self): + @parameterized.named_parameters( + { + "testcase_name": "{}".format("_legacy" if use_legacy_api else ""), + "use_legacy_api": use_legacy_api, + } + for use_legacy_api in [False, True] + ) + def testNonOwnedDlpackCanBeViewedTwice(self, use_legacy_api): x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) buffer = self.backend.buffer_from_pyval(x) d1 = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer) d2 = xla_client._xla.buffer_to_dlpack_managed_tensor(buffer) - y = xla_client._xla.dlpack_managed_tensor_to_buffer( - d1, self.cpu_backend, self.gpu_backend) - z = xla_client._xla.dlpack_managed_tensor_to_buffer( - d2, self.cpu_backend, self.gpu_backend) + y = self._DLPackManagedTensorToBuffer(d1, use_legacy_api) + z = self._DLPackManagedTensorToBuffer(d2, use_legacy_api) del d1, d2 np.testing.assert_array_equal(x, np.asarray(buffer)) np.testing.assert_array_equal(x, np.asarray(y)) np.testing.assert_array_equal(x, np.asarray(z)) + @parameterized.parameters(False, True) + def testZeroCopyOnAlignedDlpackTensor(self, use_legacy_api): + # Using CPU only, since this test is about CPU memory alignment. + if self.backend.platform != "cpu": + self.skipTest("Test requires CPU") + + # Create a numpy array that is aligned to XLA requirements. + x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) + x = _Aligned(x) + + # Convert it to a DLPack tensor, and then to an XLA buffer. + dlpack_tensor = x.__dlpack__() + buffer = self._DLPackManagedTensorToBuffer(dlpack_tensor, use_legacy_api) + y = np.array(buffer, copy=False) + + # The input was sufficiently aligned, so input and output should alias. + x_ptr = x.__array_interface__["data"][0] + y_ptr = y.__array_interface__["data"][0] + self.assertEqual( + x_ptr, + y_ptr, + msg=f"Buffers are not aliased ({hex(x_ptr)} != {hex(y_ptr)}).", + ) + + @parameterized.named_parameters( + { + "testcase_name": "{}{}".format( + "_legacy" if use_legacy_api else "", + "_transpose" if transpose else "", + ), + "use_legacy_api": use_legacy_api, + "transpose": transpose, + } + for use_legacy_api in [False, True] + for transpose in [False, True] + ) + def testReturnCopyOnUnalignedDlpackTensor(self, use_legacy_api, transpose): + # Using CPU only, since this test is about CPU memory alignment. + if self.backend.platform != "cpu": + self.skipTest("Test requires CPU") + + if transpose and use_legacy_api: + self.skipTest("Non-default layout is not supported in legacy API") + + # Create a numpy array that is not aligned to XLA requirements. XLA's + # alignment requirements differ for different hardware, so we use the + # smallest possible value. If we make sure the buffer is not aligned to + # this value (16 bytes), then it is also not aligned to its multiples (32, + # 64 etc.) + x = np.array(np.random.rand(3, 4, 5, 6), dtype=np.float32) + x = _Unaligned(x, alignment=_XLA_CPU_MIN_ALIGNMENT) + + # Transpose the array to test non-default layout with trivial striding. + if transpose: + x = x.transpose((0, 2, 1, 3)) + + # Convert it to a DLPack tensor, and then to an XLA buffer. + dlpack_tensor = x.__dlpack__() + buffer = self._DLPackManagedTensorToBuffer(dlpack_tensor, use_legacy_api) + y = np.array(buffer, copy=False) + + # The input was not sufficiently aligned, so input and output should not + # alias (output should be a copy of input, and it should be aligned). + x_ptr = x.__array_interface__["data"][0] + y_ptr = y.__array_interface__["data"][0] + self.assertNotEqual( + x_ptr, + y_ptr, + msg=( + f"Buffers aliased, but should not be ({hex(x_ptr)} ==" + f" {hex(y_ptr)})" + ), + ) + self.assertEqual( + y_ptr % _XLA_CPU_MIN_ALIGNMENT, + 0, + msg="Output buffer not aligned: {hex(y_ptr)}", + ) + np.testing.assert_array_equal(y, x) + tests.append(DLPackTest) class BufferProtocolTest(parameterized.TestCase): @@ -2909,10 +3117,7 @@ def setUp(self): def testRoundTrip(self, dtype, shape): x = np.array(np.random.rand(*shape) * 100, dtype=dtype) - # XLA' alignment is 16 bytes at the moment, but it should match what Eigen - # supports, and that can go up to 128 bytes on hardware with HVX. Align - # the input buffer to 128 bytes to be safe. - x = _Aligned(x, alignment=128) + x = _Aligned(x) x_ptr = x.__array_interface__["data"][0] buffer = self.backend.buffer_from_pyval( x, host_buffer_semantics=xla_client.HostBufferSemantics.ZERO_COPY) @@ -3079,7 +3284,7 @@ def testPlatformVersion(self): version = self.backend.platform_version logging.info("platform_version:\n%s", version) if self.backend.platform == "cpu": - self.assertEqual(version, "") + self.assertEqual(version, "cpu") elif self.backend.platform in ("gpu", "cuda", "rocm"): # Following is false if not built with --config=cuda if version != "": @@ -3128,7 +3333,7 @@ def testHloProgramViaIfrtProgram(self): ) compiled_c = self.backend.compile_ifrt_program(program, options) - results = xla_client.execute_with_python_values( + results = execute_with_python_values( compiled_c, arguments=(), backend=self.backend ) @@ -3154,10 +3359,8 @@ def testExecutableSerialization(self): serialized = self.backend.serialize_executable(executable) deserialized = self.backend.deserialize_executable(serialized, options) - expected, = xla_client.execute_with_python_values(executable, (), - self.backend) - actual, = xla_client.execute_with_python_values(deserialized, (), - self.backend) + expected, = execute_with_python_values(executable, (), self.backend) + actual, = execute_with_python_values(deserialized, (), self.backend) self.assertTrue(np.all(actual == expected)) def testCompileOptionsSerialization(self): @@ -3169,17 +3372,12 @@ def testCompileOptionsSerialization(self): options.compile_portable_executable = True executable_build_options.num_replicas = 3 executable_build_options.num_partitions = 2 - executable_build_options.debug_options.xla_cpu_enable_fast_math = True - executable_build_options.debug_options.xla_test_all_input_layouts = True - executable_build_options.debug_options.xla_gpu_kernel_cache_file = ( - "/foo/bar" - ) - executable_build_options.debug_options.xla_gpu_enable_llvm_module_compilation_parallelism = ( - True - ) - executable_build_options.debug_options.xla_gpu_per_fusion_autotune_cache_dir = ( - "/bar/foo/" - ) + deb_opt = executable_build_options.debug_options + deb_opt.xla_cpu_enable_fast_math = True + deb_opt.xla_test_all_input_layouts = True + deb_opt.xla_gpu_kernel_cache_file = "/foo/bar" + deb_opt.xla_gpu_enable_llvm_module_compilation_parallelism = True + deb_opt.xla_gpu_per_fusion_autotune_cache_dir = "/bar/foo/" b = options.SerializeAsString() restored = xla_client.CompileOptions.ParseFromString(b) @@ -3352,8 +3550,9 @@ def testExecuteShardedOnLocalDevicesWithTokens(self): options.num_replicas = num_replicas compiled_c = self.backend.compile( xla_computation_to_mlir_module(c.build()), compile_options=options) - results, sharded_token = compiled_c.execute_sharded_on_local_devices_with_tokens( - []) + results, sharded_token = ( + compiled_c.execute_sharded_on_local_devices_with_tokens([]) + ) sharded_token.block_until_ready() self.assertLen(results, 1) self.assertLen(results[0], 1) diff --git a/third_party/xla/xla/python/xla_compiler.cc b/third_party/xla/xla/python/xla_compiler.cc index e61585b61cccdc..28c8f956f69ccf 100644 --- a/third_party/xla/xla/python/xla_compiler.cc +++ b/third_party/xla/xla/python/xla_compiler.cc @@ -45,17 +45,21 @@ limitations under the License. #include "nanobind/stl/vector.h" // IWYU pragma: keep #include "xla/array.h" #include "xla/client/executable_build_options.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" #include "xla/debug_options_flags.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/ffi.h" #include "xla/ffi/ffi_api.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/hlo/transforms/simplifiers/flatten_call_graph.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/literal.h" @@ -70,14 +74,10 @@ limitations under the License. #include "xla/service/call_inliner.h" #include "xla/service/computation_placer.h" #include "xla/service/custom_call_target_registry.h" -#include "xla/service/flatten_call_graph.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_dce.h" #include "xla/service/hlo_graph_dumper.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" #include "xla/service/name_uniquer.h" -#include "xla/service/tuple_simplifier.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/lib/strings/proto_serialization.h" @@ -417,6 +417,32 @@ void DefRepeatedEnumProperty(nb::class_& cls, const char* name, }); } +template +Array NDArrayToArray(nb::ndarray ndarray) { + std::vector shapes; + shapes.reserve(ndarray.ndim()); + for (int i = 0; i < ndarray.ndim(); ++i) { + shapes.push_back(ndarray.shape(i)); + } + xla::Array array(shapes); + array.Each([&](absl::Span indices, int64_t* val) { + int64_t offset = indices.back(); + int64_t multiplier = 1; + for (int i = ndarray.ndim() - 1; i > 0; --i) { + multiplier *= ndarray.shape(i); + offset += indices[i - 1] * multiplier; + } + *val = *(ndarray.data() + offset); + }); + return array; +} + +absl::StatusOr SubgroupWithTileAssignmentHelper( + nb::ndarray tile_assignment, + absl::Span subgroup_types) { + return HloSharding::Subgroup(NDArrayToArray(tile_assignment), subgroup_types); +} + } // namespace void BuildXlaCompilerSubmodule(nb::module_& m) { @@ -1385,6 +1411,11 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { .def_static("manual", [] { return HloSharding::Manual(); }) .def_static("replicate", [] { return HloSharding::Replicate(); }) .def_static("unknown", [] { return HloSharding::Unknown(); }) + .def_static( + "subgroup_with_device_ordering", + xla::ValueOrThrowWrapper(SubgroupWithTileAssignmentHelper), + nb::arg("tile_assignment"), + nb::arg("subgroup_types") = absl::Span()) .def("__eq__", [](const xla::HloSharding& a, const xla::HloSharding& b) { return a == b; }) .def("__hash__", diff --git a/third_party/xla/xla/python/xla_extension/__init__.pyi b/third_party/xla/xla/python/xla_extension/__init__.pyi index b5ae4c6431ca66..a59ba6a5f6fba7 100644 --- a/third_party/xla/xla/python/xla_extension/__init__.pyi +++ b/third_party/xla/xla/python/xla_extension/__init__.pyi @@ -37,12 +37,12 @@ from typing import ( import numpy as np +from . import guard_lib from . import ifrt_programs from . import ifrt_proxy from . import jax_jit from . import mlir from . import ops -from . import outfeed_receiver from . import pmap_lib from . import profiler from . import pytree @@ -73,6 +73,8 @@ class PrimitiveType(enum.IntEnum): U16: PrimitiveType U32: PrimitiveType U64: PrimitiveType + F8E3M4: PrimitiveType + F8E4M3: PrimitiveType F8E4M3FN: PrimitiveType F8E4M3B11FNUZ: PrimitiveType F8E4M3FNUZ: PrimitiveType @@ -408,6 +410,10 @@ class HloSharding: def manual() -> HloSharding: ... @staticmethod def unknown() -> HloSharding: ... + @staticmethod + def subgroup_with_device_ordering( + tile_assignment: np.ndarray, + subgroup_types: Sequence[OpSharding.Type]) -> HloSharding: ... def __eq__(self, other: HloSharding) -> bool: ... def __hash__(self) -> int: ... def __repr__(self) -> str: ... @@ -497,6 +503,7 @@ class Client: def local_device_count(self) -> int: ... def devices(self) -> List[Device]: ... def local_devices(self) -> List[Device]: ... + def _get_all_devices(self) -> List[Device]: ... def device_from_local_hardware_id(self, int) -> Device: ... def live_buffers(self) -> List[Any]: ... def live_arrays(self) -> List[ArrayImpl]: ... @@ -821,6 +828,7 @@ def get_distributed_runtime_client( max_missing_heartbeats: Optional[int] = ..., missed_heartbeat_callback: Optional[Any] = ..., shutdown_on_destruction: Optional[bool] = ..., + use_compression: Optional[bool] = ..., ) -> DistributedRuntimeClient: ... class PreemptionSyncManager: @@ -877,6 +885,7 @@ class NamedSharding(Sharding): memory_kind: Optional[str] = None, _parsed_pspec: Any = None, _manual_axes: frozenset[Any] = frozenset(), + _logical_device_ids: tuple[int, ...] | None = None, ): ... mesh: Any spec: Any @@ -884,6 +893,7 @@ class NamedSharding(Sharding): _parsed_pspec: Any _internal_device_list: DeviceList _manual_axes: frozenset[Any] + _logical_device_ids: tuple[int, ...] | None class SingleDeviceSharding(Sharding): def __init__(self, device: Device, *, memory_kind: Optional[str] = None): ... diff --git a/third_party/xla/xla/python/xla_extension/guard_lib.pyi b/third_party/xla/xla/python/xla_extension/guard_lib.pyi new file mode 100644 index 00000000000000..b4d2817d457115 --- /dev/null +++ b/third_party/xla/xla/python/xla_extension/guard_lib.pyi @@ -0,0 +1,46 @@ +# Copyright 2024 The OpenXLA Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any, List, Optional + +class TransferGuardLevel: + ALLOW: Any + LOG: Any + DISALLOW: Any + LOG_EXPLICIT: Any + DISALLOW_EXPLICIT: Any + +class GarbageCollectionGuardLevel: + ALLOW: Any + LOG: Any + FATAL: Any + +class GuardState: + host_to_device: Optional[TransferGuardLevel] + device_to_device: Optional[TransferGuardLevel] + device_to_host: Optional[TransferGuardLevel] + + explicit_device_put: bool + explicit_device_get: bool + + garbage_collect_array: Optional[GarbageCollectionGuardLevel] + +def global_state() -> GuardState: ... +def thread_local_state() -> GuardState: ... + +class _TestingScopedLogSink: + def __enter__(self) -> _TestingScopedLogSink: ... + def __exit__(self, *args, **kwargs) -> None: ... + def logs(self) -> List[str]: ... diff --git a/third_party/xla/xla/python/xla_extension/outfeed_receiver.pyi b/third_party/xla/xla/python/xla_extension/outfeed_receiver.pyi deleted file mode 100644 index b0850355de65a3..00000000000000 --- a/third_party/xla/xla/python/xla_extension/outfeed_receiver.pyi +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2021 The OpenXLA Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import enum -from typing import Any, Optional, Sequence - -from xla.python import xla_extension - -Client = xla_extension.Client -XlaBuilder = xla_extension.XlaBuilder -XlaOp = xla_extension.XlaOp - -_CallbackToPython = Any - - -def start( - callback_to_python: _CallbackToPython, - backends: Sequence[Client], - max_queue_size_bytes: int = ..., - compile_options: Optional[xla_extension.ExecutableBuildOptions] = ..., -) -> OutfeedReceiverForPython: - ... - - -class OutfeedReceiverForPython: - - def add_outfeed( - builder: XlaBuilder, - token: XlaOp, - consumer_id: int, - arrays: Sequence[XlaOp], - device_idx: int, - ) -> XlaOp: - ... diff --git a/third_party/xla/xla/python/xplane_to_profile_instructions.cc b/third_party/xla/xla/python/xplane_to_profile_instructions.cc index a446cc810b0196..b0db73556e367b 100644 --- a/third_party/xla/xla/python/xplane_to_profile_instructions.cc +++ b/third_party/xla/xla/python/xplane_to_profile_instructions.cc @@ -30,15 +30,15 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo.pb.h" #include "xla/tsl/profiler/convert/xla_op_utils.h" +#include "xla/tsl/profiler/utils/file_system_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "xla/xla.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/file_system_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace xla { namespace { diff --git a/third_party/xla/xla/python/xplane_to_profile_instructions_test.cc b/third_party/xla/xla/python/xplane_to_profile_instructions_test.cc index c0a93291a3ef13..ee77891fb6b61c 100644 --- a/third_party/xla/xla/python/xplane_to_profile_instructions_test.cc +++ b/third_party/xla/xla/python/xplane_to_profile_instructions_test.cc @@ -23,12 +23,12 @@ limitations under the License. #include "xla/tests/verified_hlo_module.h" #include "xla/tsl/profiler/convert/xla_op_utils.h" #include "xla/tsl/profiler/rpc/client/save_profile.h" +#include "xla/tsl/profiler/utils/file_system_utils.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tsl/platform/test.h" #include "tsl/profiler/protobuf/profiled_instructions.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/file_system_utils.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace xla { namespace { diff --git a/third_party/xla/xla/reference_util.cc b/third_party/xla/xla/reference_util.cc index d7461cadcfe4a4..307c45a4fe7ffa 100644 --- a/third_party/xla/xla/reference_util.cc +++ b/third_party/xla/xla/reference_util.cc @@ -29,8 +29,8 @@ limitations under the License. #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/array4d.h" -#include "xla/client/padding.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/literal.h" @@ -38,9 +38,9 @@ limitations under the License. #include "xla/service/hlo_module_config.h" #include "xla/service/shape_inference.h" #include "xla/shape.h" +#include "xla/tsl/lib/math/math_util.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/math/math_util.h" #include "tsl/platform/logging.h" namespace xla { diff --git a/third_party/xla/xla/reference_util.h b/third_party/xla/xla/reference_util.h index a086fdbf8cef76..079d23c2a9d07c 100644 --- a/third_party/xla/xla/reference_util.h +++ b/third_party/xla/xla/reference_util.h @@ -29,7 +29,7 @@ limitations under the License. #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/array4d.h" -#include "xla/client/padding.h" +#include "xla/hlo/builder/padding.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/reference_util_test.cc b/third_party/xla/xla/reference_util_test.cc index c27e5414525553..7b70fe9e0cf7ce 100644 --- a/third_party/xla/xla/reference_util_test.cc +++ b/third_party/xla/xla/reference_util_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/array4d.h" -#include "xla/client/padding.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/padding.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/test.h" diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 095fd1a1f21fc3..1bfbe9bca51b36 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -87,7 +87,6 @@ tf_proto_library( xla_py_proto_library( name = "metrics_pb2", - api_version = 2, deps = [":metrics_proto"], ) @@ -113,34 +112,9 @@ cc_library( cc_library( name = "async_collective_creator", - srcs = ["async_collective_creator.cc"], hdrs = ["async_collective_creator.h"], - deps = [ - ":shape_inference", - "//xla:frontend_attributes", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log", - "@local_tsl//tsl/platform:errors", - ], -) - -xla_cc_test( - name = "async_collective_creator_test", - srcs = ["async_collective_creator_test.cc"], - deps = [ - ":async_collective_creator", - ":pattern_matcher", - ":pattern_matcher_gmock", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:async_collective_creator instead.", + deps = ["//xla/hlo/transforms:async_collective_creator"], ) cc_library( @@ -149,6 +123,7 @@ cc_library( hdrs = ["all_reduce_key.h"], deps = [ ":hlo_domain_map", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "@com_google_absl//absl/log", ], @@ -158,7 +133,18 @@ cc_library( name = "all_reduce_promotion", srcs = ["all_reduce_promotion.cc"], hdrs = ["all_reduce_promotion.h"], - deps = [":change_op_data_type"], + deps = [ + ":change_op_data_type", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], ) xla_cc_test( @@ -168,7 +154,12 @@ xla_cc_test( ":all_reduce_promotion", ":pattern_matcher", ":pattern_matcher_gmock", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_main", ], ) @@ -183,6 +174,7 @@ cc_library( ":pattern_matcher", "//xla:literal", "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", @@ -204,6 +196,7 @@ xla_cc_test( ":pattern_matcher", ":pattern_matcher_gmock", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", @@ -218,39 +211,9 @@ xla_cc_test( cc_library( name = "all_reduce_folder", - srcs = ["all_reduce_folder.cc"], hdrs = ["all_reduce_folder.h"], - deps = [ - ":all_reduce_key", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/utils:hlo_query", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - ], -) - -xla_cc_test( - name = "all_reduce_folder_test", - srcs = ["all_reduce_folder_test.cc"], - deps = [ - ":all_reduce_folder", - "//xla:test", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:all_reduce_folder instead.", + deps = ["//xla/hlo/transforms:all_reduce_folder"], ) cc_library( @@ -265,163 +228,30 @@ cc_library( cc_library( name = "broadcast_canonicalizer", - srcs = ["broadcast_canonicalizer.cc"], hdrs = ["broadcast_canonicalizer.h"], - deps = [ - ":hlo_creation_utils", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "broadcast_canonicalizer_test", - srcs = ["broadcast_canonicalizer_test.cc"], - deps = [ - ":broadcast_canonicalizer", - "//xla:test", - "//xla:test_helpers", - "//xla/tests:filecheck", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:broadcast_canonicalizer instead.", + deps = ["//xla/hlo/transforms:broadcast_canonicalizer"], ) cc_library( name = "bfloat16_conversion_folding", - srcs = ["bfloat16_conversion_folding.cc"], hdrs = ["bfloat16_conversion_folding.h"], - deps = [ - ":float_support", - ":hlo_dataflow_analysis", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - ], -) - -xla_cc_test( - name = "bfloat16_conversion_folding_test", - srcs = ["bfloat16_conversion_folding_test.cc"], - deps = [ - ":bfloat16_conversion_folding", - ":float_support", - "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:bfloat16_conversion_folding instead.", + deps = ["//xla/hlo/transforms:bfloat16_conversion_folding"], ) cc_library( name = "float_normalization", - srcs = ["float_normalization.cc"], hdrs = ["float_normalization.h"], - deps = [ - ":call_graph", - ":float_support", - ":hlo_dce", - ":tuple_simplifier", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "float_normalization_test", - srcs = ["float_normalization_test.cc"], - deps = [ - ":float_normalization", - ":float_support", - ":hlo_creation_utils", - ":hlo_verifier", - "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:float_normalization instead.", + deps = ["//xla/hlo/transforms:float_normalization"], ) cc_library( name = "bfloat16_propagation", - srcs = ["bfloat16_propagation.cc"], hdrs = ["bfloat16_propagation.h"], - deps = [ - ":float_support", - ":hlo_dataflow_analysis", - ":hlo_dce", - ":tuple_simplifier", - "//xla:literal", - "//xla:shape_tree", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/cleanup", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:logging", - ], -) - -xla_cc_test( - name = "bfloat16_propagation_test", - srcs = ["bfloat16_propagation_test.cc"], - deps = [ - ":bfloat16_propagation", - ":float_support", - ":hlo_verifier", - "//xla:comparison_util", - "//xla:literal_util", - "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:literal_test_util", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:bfloat16_propagation instead.", + deps = ["//xla/hlo/transforms:bfloat16_propagation"], ) cc_library( @@ -448,8 +278,8 @@ xla_cc_test( deps = [ ":collective_ops_utils", ":collective_permute_decomposer", - ":hlo_parser", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/utils:hlo_matchers", "//xla/hlo/utils:hlo_query", "//xla/service/gpu:backend_configs_cc", @@ -482,25 +312,9 @@ xla_cc_test( cc_library( name = "convert_async_collectives_to_sync", - srcs = ["convert_async_collectives_to_sync.cc"], hdrs = ["convert_async_collectives_to_sync.h"], - deps = [ - "//xla:status_macros", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/utils:hlo_query", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:convert_async_collectives_to_sync instead.", + deps = ["//xla/hlo/transforms:convert_async_collectives_to_sync"], ) cc_library( @@ -519,32 +333,14 @@ xla_cc_test( srcs = ["value_range_test.cc"], deps = [ ":hlo_module_config", - ":hlo_parser", ":value_range", + "//xla/hlo/parser:hlo_parser", "//xla/tests:hlo_test_base", "@com_google_absl//absl/container:flat_hash_map", "@com_google_googletest//:gtest_main", ], ) -xla_cc_test( - name = "convert_async_collectives_to_sync_test", - srcs = ["convert_async_collectives_to_sync_test.cc"], - deps = [ - ":convert_async_collectives_to_sync", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - ], -) - cc_library( name = "collective_pipeliner", srcs = ["collective_pipeliner.cc"], @@ -553,10 +349,6 @@ cc_library( ":call_graph", ":collective_ops_utils", ":constant_value", - ":hlo_dce", - ":hlo_parser", - ":hlo_pass", - ":tuple_points_to_analysis", ":value_range", "//xla:comparison_util", "//xla:literal", @@ -564,10 +356,13 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:tuple_points_to_analysis", "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_instruction_utils", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/transforms:hlo_dce", "//xla/hlo/utils:hlo_query", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -591,10 +386,13 @@ xla_cc_test( srcs = ["collective_pipeliner_test.cc"], deps = [ ":collective_pipeliner", - ":hlo_parser", + ":hlo_module_config", + ":hlo_verifier", ":host_memory_offload_annotations_hdr", + "//xla:test_helpers", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass_pipeline", "//xla/hlo/utils:hlo_matchers", "//xla/tests:filecheck", @@ -611,36 +409,9 @@ xla_cc_test( cc_library( name = "collective_quantizer", - srcs = ["collective_quantizer.cc"], hdrs = ["collective_quantizer.h"], - deps = [ - ":hlo_replication_analysis", - ":pattern_matcher", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - ], -) - -xla_cc_test( - name = "collective_quantizer_test", - srcs = ["collective_quantizer_test.cc"], - deps = [ - ":collective_quantizer", - ":hlo_verifier", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:collective_quantizer instead.", + deps = ["//xla/hlo/transforms:collective_quantizer"], ) cc_library( @@ -654,6 +425,8 @@ cc_library( "//xla:util", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/tsl/lib/io:zlib_compression_options", + "//xla/tsl/lib/io:zlib_outputbuffer", "//xla/tsl/lib/strings:proto_serialization", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -669,8 +442,6 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", - "@local_tsl//tsl/lib/io:zlib_compression_options", - "@local_tsl//tsl/lib/io:zlib_outputbuffer", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", @@ -686,8 +457,8 @@ xla_cc_test( deps = [ ":dump", ":hlo_module_config", - ":hlo_parser", "//xla:xla_proto_cc", + "//xla/hlo/parser:hlo_parser", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", @@ -730,14 +501,14 @@ xla_cc_test( name = "shape_inference_test", srcs = ["shape_inference_test.cc"], deps = [ - ":hlo_parser", ":shape_inference", "//xla:shape_util", "//xla:test", "//xla:test_helpers", "//xla:xla_data_proto_cc", - "//xla/client:padding", + "//xla/hlo/builder:padding", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/tests:xla_internal_test_main", # fixdeps: keep "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -745,6 +516,7 @@ xla_cc_test( "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", ], ) @@ -808,14 +580,15 @@ xla_cc_test( "sharding_propagation_test.cc", ], deps = [ - ":hlo_dce", - ":hlo_parser", ":sharding_propagation", "//xla:protobuf_util", + "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/transforms:hlo_constant_splitter", + "//xla/hlo/transforms:hlo_dce", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -840,6 +613,8 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service/spmd/shardy:constants", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", @@ -853,14 +628,15 @@ xla_cc_test( "sharding_remover_test.cc", ], deps = [ - ":hlo_parser", ":sharding_remover", "//xla:status_macros", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", ], ) @@ -875,6 +651,7 @@ cc_library( deps = [ ":shape_inference", "//xla:status_macros", + "//xla:util", "//xla/hlo/ir:hlo", ], ) @@ -901,10 +678,10 @@ xla_test( "gpu", ], deps = [ - ":hlo_parser", "//xla:execution_options_util", "//xla:status_macros", "//xla:test", + "//xla/hlo/parser:hlo_parser", "//xla/tests:client_library_test_base", "//xla/tests:hlo_test_base", "//xla/tests:test_macros_header", @@ -931,7 +708,6 @@ cc_library( name = "pattern_matcher", hdrs = ["pattern_matcher.h"], deps = [ - ":hlo_parser", "//xla:comparison_util", "//xla:literal", "//xla:shape_util", @@ -939,6 +715,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/ir:ptrvec", + "//xla/hlo/parser:hlo_parser", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", @@ -953,7 +730,6 @@ xla_cc_test( name = "pattern_matcher_test", srcs = ["pattern_matcher_test.cc"], deps = [ - ":hlo_parser", ":pattern_matcher", "//xla:comparison_util", "//xla:literal_util", @@ -961,6 +737,7 @@ xla_cc_test( "//xla:test", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", @@ -1020,39 +797,9 @@ xla_cc_test( ) xla_cc_test( - name = "hlo_dfs_reachability_test", - srcs = ["hlo_dfs_reachability_test.cc"], - deps = [ - "//xla:test", - "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_dfs_reachability", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/random", - "@local_tsl//tsl/platform:test_benchmark", - ], -) - -xla_cc_test( - name = "hlo_reachability_test", - srcs = ["hlo_reachability_test.cc"], - deps = [ - ":computation_placer", - "//xla:test", - "//xla:test_helpers", - "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_reachability", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/random", - "@local_tsl//tsl/platform:test_benchmark", - ], -) - -xla_cc_test( - name = "hlo_instruction_test", - srcs = ["hlo_instruction_test.cc"], - tags = ["not_run:arm"], + name = "hlo_instruction_test", + srcs = ["hlo_instruction_test.cc"], + tags = ["not_run:arm"], deps = [ ":pattern_matcher", ":pattern_matcher_gmock", @@ -1081,12 +828,12 @@ xla_cc_test( name = "hlo_sharding_test", srcs = ["hlo_sharding_test.cc"], deps = [ - ":hlo_parser", "//xla:protobuf_util", "//xla:shape_util", "//xla:test", "//xla:test_helpers", "//xla:xla_data_proto_cc", + "//xla/hlo/parser:hlo_parser", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/hash", @@ -1141,21 +888,9 @@ xla_cc_test( cc_library( name = "flatten_call_graph", - srcs = ["flatten_call_graph.cc"], hdrs = ["flatten_call_graph.h"], - deps = [ - ":call_graph", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/utils:hlo_query", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:flatten_call_graph instead.", + deps = ["//xla/hlo/transforms:flatten_call_graph"], ) cc_library( @@ -1164,18 +899,21 @@ cc_library( hdrs = ["call_inliner.h"], deps = [ ":call_graph", - ":hlo_dce", ":hlo_domain_isolator", "//xla:status_macros", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/transforms:hlo_dce", + "//xla/service/spmd/shardy:constants", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", @@ -1188,13 +926,12 @@ xla_cc_test( srcs = ["call_inliner_test.cc"], deps = [ ":call_inliner", - ":hlo_parser", - "//xla:literal", "//xla:literal_util", "//xla:shape_util", "//xla:test", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -1207,64 +944,9 @@ xla_cc_test( cc_library( name = "hlo_computation_deduplicator", - srcs = ["hlo_computation_deduplicator.cc"], hdrs = ["hlo_computation_deduplicator.h"], - deps = [ - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:logging", - ], -) - -xla_cc_test( - name = "hlo_computation_deduplicator_test", - size = "small", - srcs = ["hlo_computation_deduplicator_test.cc"], - deps = [ - ":hlo_computation_deduplicator", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_util", - "//xla:test", - "//xla:types", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "flatten_call_graph_test", - srcs = ["flatten_call_graph_test.cc"], - deps = [ - ":call_graph", - ":flatten_call_graph", - "//xla:comparison_util", - "//xla:literal_util", - "//xla:shape_util", - "//xla:test", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:hlo_computation_deduplicator instead.", + deps = ["//xla/hlo/transforms:hlo_computation_deduplicator"], ) cc_library( @@ -1274,11 +956,12 @@ cc_library( deps = [ ":compiler", "//xla:debug_options_flags", - "//xla:status_macros", "//xla:types", "//xla:util", - "//xla/stream_executor", + "//xla/stream_executor:device_description", + "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/cuda:cuda_platform_id", "//xla/stream_executor/host:host_platform_id", "//xla/stream_executor/rocm:rocm_platform_id", @@ -1302,18 +985,23 @@ cc_library( ":stream_pool", ":transfer_manager", "//xla:util", - "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", "//xla/stream_executor/host:host_platform_id", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:statusor", ], ) @@ -1335,41 +1023,41 @@ cc_library( ":dynamic_padder", ":executable", ":execution_tracker", - ":hlo_cost_analysis", ":hlo_execution_profile", ":hlo_module_config", ":hlo_module_util", + ":hlo_proto_cc", ":hlo_proto_util", - ":platform_util", - ":source_map_util", + ":shaped_buffer", ":stream_pool", ":transfer_manager", "//xla:debug_options_flags", "//xla:executable_run_options", - "//xla:execution_options_util", "//xla:literal", - "//xla:shape_layout", "//xla:shape_util", "//xla:status_macros", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", - "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:scoped_annotation", ], alwayslink = 1, @@ -1393,9 +1081,10 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/client:executable_build_options", - "//xla/client:xla_computation", - "//xla/stream_executor", + "//xla/hlo/builder:xla_computation", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -1420,7 +1109,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/client:executable_build_options", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", @@ -1436,7 +1125,6 @@ cc_library( hdrs = ["latency_hiding_scheduler.h"], deps = [ ":dump", - ":hlo_alias_analysis", ":hlo_buffer", ":hlo_cost_analysis", ":hlo_value", @@ -1444,8 +1132,9 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla:xla_proto_cc", + "//xla/hlo/analysis:hlo_alias_analysis", + "//xla/hlo/analysis:hlo_reachability", "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_reachability", "//xla/hlo/pass:hlo_pass", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -1467,12 +1156,12 @@ xla_cc_test( name = "latency_hiding_scheduler_test", srcs = ["latency_hiding_scheduler_test.cc"], deps = [ - ":async_collective_creator", ":hlo_cost_analysis", ":latency_hiding_scheduler", "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/hlo/transforms:async_collective_creator", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/algorithm:container", @@ -1494,8 +1183,8 @@ cc_library( deps = [ ":collective_ops_utils", "//xla:util", + "//xla/hlo/analysis:hlo_reachability", "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_reachability", "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", "@com_google_absl//absl/container:flat_hash_map", @@ -1513,10 +1202,10 @@ xla_cc_test( name = "p2p_schedule_preparation_test", srcs = ["p2p_schedule_preparation_test.cc"], deps = [ - ":hlo_parser", ":p2p_schedule_preparation", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/tests:hlo_test_base", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", @@ -1581,7 +1270,7 @@ cc_library( "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/stream_executor", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:logging", @@ -1594,7 +1283,6 @@ cc_library( ":service", "//xla/service/cpu:cpu_compiler", "//xla/service/cpu:cpu_transfer_manager", - "//xla/stream_executor", "//xla/stream_executor/host:host_platform", ], ) @@ -1603,7 +1291,6 @@ cc_library( name = "gpu_plugin_impl", compatible_with = get_compatible_with_portable(), deps = [ - "//xla/stream_executor", ] + if_gpu_is_configured([ ":service", "//xla/service/gpu:gpu_compiler", @@ -1638,7 +1325,6 @@ cc_library( "//xla/backends/interpreter:compiler", "//xla/backends/interpreter:interpreter_transfer_manager", "//xla/backends/interpreter:platform", - "//xla/stream_executor", ], ) @@ -1650,18 +1336,19 @@ cc_library( deps = [ "//xla:shape_tree", "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:util", "//xla:xla_data_proto_cc", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", ], ) @@ -1672,12 +1359,19 @@ xla_cc_test( ":cpu_plugin", ":platform_util", ":shaped_buffer", + "//xla:shape_tree", "//xla:shape_util", "//xla:test", + "//xla:xla_data_proto_cc", + "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_benchmark", ], ) @@ -1708,15 +1402,17 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/stream_executor", "//xla/stream_executor:device_description", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream_executor_h", "//xla/tsl/lib/strings:proto_serialization", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", "@local_tsl//tsl/platform:env", @@ -1727,10 +1423,28 @@ cc_library( ] + internal_hlo_deps(), ) +xla_cc_test( + name = "executable_test", + srcs = ["executable_test.cc"], + deps = [ + ":executable", + ":hlo_execution_profile", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + cc_library( name = "compiler", srcs = ["compiler.cc"], hdrs = ["compiler.h"], + visibility = internal_visibility(["//xla/internal:hwi_internal"]), deps = [ ":buffer_assignment", ":buffer_value", @@ -1742,12 +1456,11 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", "//xla/pjrt/distributed:key_value_store_interface", - "//xla/stream_executor", "//xla/stream_executor:dnn", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:protobuf", @@ -1764,8 +1477,8 @@ xla_test( deps = [ ":compiler", "//xla:autotune_results_proto_cc", - "//xla/stream_executor", "//xla/stream_executor:device_description_proto_cc", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:gpu_init", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", # fixdeps: keep @@ -1811,8 +1524,10 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/stream_executor", "//xla/stream_executor:device_memory", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", @@ -1841,6 +1556,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -1861,7 +1577,7 @@ cc_library( "//xla:executable_run_options", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/stream_executor", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:logging", ], @@ -1913,19 +1629,17 @@ cc_library( deps = [ ":buffer_assignment_proto_cc", ":buffer_value", - ":buffer_value_containers", ":call_graph", - ":hlo_alias_analysis", ":hlo_buffer", - ":hlo_dataflow_analysis", - ":hlo_ordering", ":hlo_proto_cc", ":hlo_value", ":logical_buffer", "//xla:shape_util", "//xla:status_macros", - "//xla:types", "//xla:util", + "//xla/hlo/analysis:hlo_alias_analysis", + "//xla/hlo/analysis:hlo_dataflow_analysis", + "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_live_range", "//xla/service/heap_simulator", @@ -1934,6 +1648,7 @@ cc_library( "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", @@ -1944,6 +1659,7 @@ cc_library( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:numbers", + "@local_tsl//tsl/platform:statusor", ], ) @@ -1956,74 +1672,47 @@ xla_cc_test( ":call_graph", ":copy_insertion", ":cpu_plugin", - ":flatten_call_graph", - ":hlo_alias_analysis", - ":hlo_dce", - ":hlo_memory_scheduler", - ":hlo_ordering", - ":hlo_parser", + ":hlo_buffer", ":hlo_proto_cc", ":hlo_proto_util", + ":hlo_value", + ":logical_buffer", + "//xla:comparison_util", "//xla:literal", + "//xla:literal_util", "//xla:shape_util", "//xla:test", "//xla:test_helpers", "//xla:types", "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_alias_analysis", + "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/transforms:flatten_call_graph", + "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms:hlo_memory_scheduler", + "//xla/service/memory_space_assignment", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", ], ) cc_library( name = "hlo_ordering", - srcs = ["hlo_ordering.cc"], hdrs = ["hlo_ordering.h"], - deps = [ - ":call_graph", - ":hlo_dataflow_analysis", - ":hlo_proto_cc", - ":hlo_value", - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_reachability", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - ], -) - -xla_cc_test( - name = "hlo_ordering_test", - size = "small", - srcs = ["hlo_ordering_test.cc"], - deps = [ - ":hlo_dataflow_analysis", - ":hlo_ordering", - ":hlo_value", - "//xla:shape_util", - "//xla:types", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/analysis:hlo_ordering instead.", + deps = ["//xla/hlo/analysis:hlo_ordering"], ) xla_cc_test( @@ -2048,11 +1737,11 @@ cc_library( srcs = ["hlo_module_group_metadata.cc"], hdrs = ["hlo_module_group_metadata.h"], deps = [ - ":hlo_alias_analysis", - ":tuple_points_to_analysis", "//xla:shape_util", "//xla:status_macros", "//xla:util", + "//xla/hlo/analysis:hlo_alias_analysis", + "//xla/hlo/analysis:tuple_points_to_analysis", "//xla/hlo/ir:hlo", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -2087,8 +1776,8 @@ cc_library( "//xla:status_macros", "//xla:types", "//xla:util", + "//xla/hlo/analysis:hlo_reachability", "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_reachability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:function_ref", @@ -2106,14 +1795,14 @@ xla_cc_test( name = "hlo_schedule_test", srcs = ["hlo_schedule_test.cc"], deps = [ - ":hlo_dce", - ":hlo_memory_scheduler", - ":hlo_ordering", "//xla:shape_util", "//xla:test_helpers", "//xla:types", "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", + "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms:hlo_memory_scheduler", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", @@ -2128,12 +1817,10 @@ xla_cc_test( name = "hlo_input_output_alias_config_test", srcs = ["hlo_input_output_alias_config_test.cc"], deps = [ - ":hlo_dce", - ":hlo_memory_scheduler", - ":hlo_ordering", "//xla:shape_util", "//xla:test_helpers", "//xla:types", + "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -2146,65 +1833,10 @@ xla_cc_test( cc_library( name = "hlo_memory_scheduler", - srcs = ["hlo_memory_scheduler.cc"], hdrs = ["hlo_memory_scheduler.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:hlo_memory_scheduler instead.", local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - deps = [ - ":buffer_value", - ":hlo_alias_analysis", - ":logical_buffer", - ":tuple_points_to_analysis", - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/service/heap_simulator", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/lib/gtl:map_util", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:numbers", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/profiler/lib:scoped_annotation", - ], -) - -xla_cc_test( - name = "hlo_memory_scheduler_test", - srcs = ["hlo_memory_scheduler_test.cc"], - deps = [ - ":buffer_value", - ":hlo_alias_analysis", - ":hlo_dce", - ":hlo_memory_scheduler", - ":hlo_ordering", - ":hlo_value", - ":logical_buffer", - "//xla:literal_util", - "//xla:shape_util", - "//xla:types", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service/heap_simulator", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - ], + deps = ["//xla/hlo/transforms:hlo_memory_scheduler"], ) cc_library( @@ -2222,15 +1854,15 @@ cc_library( hdrs = ["instruction_fusion.h"], deps = [ ":fusion_queue", - ":hlo_dataflow_analysis", ":hlo_graph_dumper", ":hlo_module_config", ":pattern_matcher", "//xla:debug_options_flags", "//xla:shape_util", "//xla:util", + "//xla/hlo/analysis:hlo_dataflow_analysis", + "//xla/hlo/analysis:hlo_reachability", "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_reachability", "//xla/hlo/pass:hlo_pass", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -2252,11 +1884,11 @@ xla_cc_test( name = "instruction_fusion_test", srcs = ["instruction_fusion_test.cc"], deps = [ - ":hlo_parser", ":instruction_fusion", "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -2269,14 +1901,14 @@ cc_library( srcs = ["multi_output_fusion.cc"], hdrs = ["multi_output_fusion.h"], deps = [ - ":hlo_dataflow_analysis", - ":hlo_dce", "//xla:debug_options_flags", "//xla:shape_util", "//xla:util", + "//xla/hlo/analysis:hlo_dataflow_analysis", + "//xla/hlo/analysis:hlo_reachability", "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_reachability", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/transforms:hlo_dce", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", @@ -2299,9 +1931,9 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/client/lib:comparators", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/builder/lib:comparators", "//xla/hlo/ir:hlo", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", @@ -2333,9 +1965,9 @@ xla_cc_test( srcs = ["fusion_node_indexing_evaluation_test.cc"], deps = [ ":fusion_node_indexing_evaluation", - ":hlo_parser", ":instruction_fusion", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/container:flat_hash_map", @@ -2387,62 +2019,60 @@ cc_library( cc_library( name = "op_expander_pass", - srcs = ["op_expander_pass.cc"], hdrs = ["op_expander_pass.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:op_expander_pass instead.", + deps = ["//xla/hlo/transforms:op_expander_pass"], +) + +cc_library( + name = "gather_expander", + srcs = ["gather_expander.cc"], + hdrs = ["gather_expander.h"], deps = [ - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "gather_expander", - srcs = ["gather_expander.cc"], - hdrs = ["gather_expander.h"], - deps = [ + ":gather_scatter_utils", ":hlo_creation_utils", - ":op_expander_pass", ":while_util", "//xla:literal_util", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/hlo/transforms:op_expander_pass", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", ], ) cc_library( name = "optimization_barrier_expander", - srcs = ["optimization_barrier_expander.cc"], hdrs = ["optimization_barrier_expander.h"], - deps = [ - ":op_expander_pass", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:optimization_barrier_expander instead.", + deps = ["//xla/hlo/transforms:optimization_barrier_expander"], ) cc_library( name = "comparison_expander", - srcs = ["comparison_expander.cc"], hdrs = ["comparison_expander.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:comparison_expander instead.", + deps = ["//xla/hlo/transforms:comparison_expander"], +) + +cc_library( + name = "scatter_utils", + srcs = ["scatter_utils.cc"], + hdrs = ["scatter_utils.h"], deps = [ - ":op_expander_pass", - "//xla:comparison_util", - "//xla:literal_util", + ":call_inliner", + ":hlo_creation_utils", "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) @@ -2451,14 +2081,40 @@ cc_library( srcs = ["scatter_expander.cc"], hdrs = ["scatter_expander.h"], deps = [ - ":call_inliner", + ":gather_scatter_utils", ":hlo_creation_utils", - ":op_expander_pass", + ":scatter_utils", ":while_util", "//xla:literal_util", + "//xla:shape_util", "//xla/hlo/ir:hlo", + "//xla/hlo/transforms:op_expander_pass", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "scatter_determinism_expander", + srcs = ["scatter_determinism_expander.cc"], + hdrs = ["scatter_determinism_expander.h"], + deps = [ + ":hlo_creation_utils", + ":scatter_utils", + "//xla:array", + "//xla:array2d", + "//xla:comparison_util", + "//xla:literal_util", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/transforms:op_expander_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", ], ) @@ -2472,10 +2128,29 @@ xla_cc_test( "//xla:test", "//xla:types", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:filecheck", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "scatter_determinism_expander_test", + srcs = ["scatter_determinism_expander_test.cc"], + backends = [ + "cpu", + "gpu", + ], + deps = [ + ":scatter_determinism_expander", + "//xla:literal", + "//xla:test", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@local_tsl//tsl/platform:statusor", ], ) @@ -2484,17 +2159,18 @@ cc_library( srcs = ["triangular_solve_expander.cc"], hdrs = ["triangular_solve_expander.h"], deps = [ + ":hlo_creation_utils", ":hlo_module_config", - ":op_expander_pass", "//xla:shape_util", "//xla:util", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/client/lib:constants", - "//xla/client/lib:math", - "//xla/client/lib:matrix", - "//xla/client/lib:slicing", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/builder/lib:constants", + "//xla/hlo/builder/lib:math", + "//xla/hlo/builder/lib:matrix", + "//xla/hlo/builder/lib:slicing", "//xla/hlo/ir:hlo", + "//xla/hlo/transforms:op_expander_pass", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", @@ -2523,167 +2199,43 @@ xla_cc_test( cc_library( name = "cholesky_expander", - srcs = ["cholesky_expander.cc"], hdrs = ["cholesky_expander.h"], - deps = [ - ":op_expander_pass", - "//xla:literal", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla/client:xla_builder", - "//xla/client/lib:arithmetic", - "//xla/client/lib:constants", - "//xla/client/lib:loops", - "//xla/client/lib:math", - "//xla/client/lib:matrix", - "//xla/client/lib:slicing", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:errors", - ], + deps = ["//xla/hlo/transforms:cholesky_expander"], ) cc_library( name = "qr_expander", - srcs = ["qr_expander.cc"], hdrs = ["qr_expander.h"], - deps = [ - ":op_expander_pass", - "//xla:literal", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla/client:xla_builder", - "//xla/client/lib:arithmetic", - "//xla/client/lib:constants", - "//xla/client/lib:loops", - "//xla/client/lib:math", - "//xla/client/lib:matrix", - "//xla/client/lib:qr", - "//xla/client/lib:slicing", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:errors", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:qr_expander instead.", + deps = ["//xla/hlo/transforms:qr_expander"], ) cc_library( name = "real_imag_expander", - srcs = ["real_imag_expander.cc"], hdrs = ["real_imag_expander.h"], - deps = [ - ":op_expander_pass", - "//xla:literal_util", - ], -) - -xla_cc_test( - name = "real_imag_expander_test", - size = "small", - srcs = ["real_imag_expander_test.cc"], - deps = [ - ":hlo_creation_utils", - ":pattern_matcher", - ":pattern_matcher_gmock", - ":real_imag_expander", - "//xla:literal", - "//xla:shape_util", - "//xla:test", - "//xla:types", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:real_imag_expander instead.", + deps = ["//xla/hlo/transforms:real_imag_expander"], ) cc_library( name = "eigh_expander", - srcs = ["eigh_expander.cc"], hdrs = ["eigh_expander.h"], - deps = [ - ":op_expander_pass", - "//xla:literal_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla/client:xla_builder", - "//xla/client/lib:arithmetic", - "//xla/client/lib:comparators", - "//xla/client/lib:constants", - "//xla/client/lib:loops", - "//xla/client/lib:math", - "//xla/client/lib:matrix", - "//xla/client/lib:slicing", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:errors", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:eigh_expander instead.", + deps = ["//xla/hlo/transforms:eigh_expander"], ) cc_library( name = "convolution_4d_expander", - srcs = ["convolution_4d_expander.cc"], hdrs = ["convolution_4d_expander.h"], - deps = [ - ":op_expander_pass", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], -) - -xla_cc_test( - name = "convolution_4d_expander_test", - srcs = ["convolution_4d_expander_test.cc"], - deps = [ - "convolution_4d_expander", - "//xla:test", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:convolution_4d_expander instead.", + deps = ["//xla/hlo/transforms:convolution_4d_expander"], ) cc_library( name = "convolution_pred_expander", - srcs = ["convolution_pred_expander.cc"], hdrs = ["convolution_pred_expander.h"], - deps = [ - ":hlo_creation_utils", - ":op_expander_pass", - ":pattern_matcher", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], -) - -xla_cc_test( - name = "convolution_pred_expander_test", - srcs = ["convolution_pred_expander_test.cc"], - deps = [ - ":convolution_pred_expander", - ":pattern_matcher", - ":pattern_matcher_gmock", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:convolution_pred_expander instead.", + deps = ["//xla/hlo/transforms:convolution_pred_expander"], ) xla_test( @@ -2696,336 +2248,145 @@ xla_test( ], deps = [ ":batchnorm_expander", - ":hlo_parser", + "//xla:error_spec", "//xla:literal", "//xla:shape_util", "//xla:test", "//xla:types", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@local_tsl//tsl/platform:statusor", ], ) cc_library( name = "algebraic_simplifier", - srcs = ["algebraic_simplifier.cc"], hdrs = ["algebraic_simplifier.h"], copts = tsl_copts(), + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:algebraic_simplifier instead.", + deps = ["//xla/hlo/transforms:algebraic_simplifier"], +) + +cc_library( + name = "tree_reduction_rewriter", + hdrs = ["tree_reduction_rewriter.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:tree_reduction_rewriter instead.", + deps = ["//xla/hlo/transforms:tree_reduction_rewriter"], +) + +xla_test( + name = "algebraic_simplifier_overflow_test", + srcs = ["algebraic_simplifier_overflow_test.cc"], + deps = [ + "//xla:error_spec", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "simplify_fp_conversions", + hdrs = ["simplify_fp_conversions.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:simplify_fp_conversions instead.", + deps = ["//xla/hlo/transforms:simplify_fp_conversions"], +) + +cc_library( + name = "logistic_expander", + hdrs = ["logistic_expander.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:logistic_expander instead.", + deps = ["//xla/hlo/transforms:logistic_expander"], +) + +cc_library( + name = "collectives_schedule_linearizer", + hdrs = ["collectives_schedule_linearizer.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:collectives_schedule_linearizer instead.", + deps = ["//xla/hlo/transforms:collectives_schedule_linearizer"], +) + +cc_library( + name = "collective_combiner_utils", + hdrs = ["collective_combiner_utils.h"], deps = [ - ":hlo_cost_analysis", - ":hlo_creation_utils", - ":hlo_module_config", - ":host_memory_offload_annotations_hdr", - ":pattern_matcher", - ":shape_inference", - "//xla:comparison_util", - "//xla:literal", - "//xla:literal_util", - "//xla:permutation_util", "//xla:shape_util", "//xla:status_macros", - "//xla:util", - "//xla:window_util", "//xla:xla_data_proto_cc", - "//xla/hlo/evaluator:hlo_evaluator", + "//xla/hlo/analysis:hlo_reachability", "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_instruction_utils", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/utils:hlo_sharding_util", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", ], ) cc_library( - name = "tree_reduction_rewriter", - srcs = ["tree_reduction_rewriter.cc"], - hdrs = ["tree_reduction_rewriter.h"], + name = "collective_decomposer_utils", + srcs = ["collective_decomposer_utils.cc"], + hdrs = ["collective_decomposer_utils.h"], deps = [ - ":shape_inference", + ":collective_ops_utils", + ":hlo_module_config", + "//xla:literal_util", "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:padding", + "//xla:status_macros", "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", - ], -) - -xla_cc_test( - name = "algebraic_simplifier_test", - srcs = ["algebraic_simplifier_test.cc"], - deps = [ - ":algebraic_simplifier", - ":hlo_creation_utils", - ":hlo_parser", - ":host_memory_offload_annotations_hdr", - ":layout_assignment", - ":pattern_matcher", - ":pattern_matcher_gmock", - ":shape_inference", - "//xla:comparison_util", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_util", - "//xla:test", - "//xla:util", - "//xla:window_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - ], -) - -xla_test( - name = "algebraic_simplifier_overflow_test", - srcs = ["algebraic_simplifier_overflow_test.cc"], - deps = [ - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - ], -) - -cc_library( - name = "simplify_fp_conversions", - srcs = ["simplify_fp_conversions.cc"], - hdrs = ["simplify_fp_conversions.h"], - deps = [ - "//xla:comparison_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "simplify_fp_conversions_test", - srcs = ["simplify_fp_conversions_test.cc"], - deps = [ - ":simplify_fp_conversions", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:status_matchers", - ], -) - -cc_library( - name = "logistic_expander", - srcs = ["logistic_expander.cc"], - hdrs = ["logistic_expander.h"], - deps = [ - ":hlo_creation_utils", - ":op_expander_pass", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:logging", - ], -) - -xla_cc_test( - name = "logistic_expander_test", - srcs = ["logistic_expander_test.cc"], - deps = [ - ":dynamic_padder", - ":hlo_parser", - ":logistic_expander", - ":pattern_matcher", - ":pattern_matcher_gmock", - "//xla:test", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@local_tsl//tsl/platform:statusor", ], ) cc_library( - name = "collectives_schedule_linearizer", - srcs = ["collectives_schedule_linearizer.cc"], - hdrs = ["collectives_schedule_linearizer.h"], - deps = [ - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_reachability", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - ], -) - -xla_cc_test( - name = "collectives_schedule_linearizer_test", - srcs = ["collectives_schedule_linearizer_test.cc"], - deps = [ - ":collectives_schedule_linearizer", - ":pattern_matcher", - "//xla:test", - "//xla:test_helpers", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - ], + name = "all_gather_broadcast_reorder", + hdrs = ["all_gather_broadcast_reorder.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:all_gather_broadcast_reorder instead.", + deps = ["//xla/hlo/transforms:all_gather_broadcast_reorder"], ) cc_library( - name = "collective_combiner_utils", - hdrs = ["collective_combiner_utils.h"], - deps = [ - "//xla:shape_util", - "//xla:status_macros", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_reachability", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/functional:function_ref", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - ], + name = "bitcast_dtypes_expander", + hdrs = ["bitcast_dtypes_expander.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:bitcast_dtypes_expander instead.", + deps = ["//xla/hlo/transforms:bitcast_dtypes_expander"], ) cc_library( - name = "collective_decomposer_utils", - srcs = ["collective_decomposer_utils.cc"], - hdrs = ["collective_decomposer_utils.h"], - deps = [ - ":collective_ops_utils", - ":hlo_module_config", - "//xla:literal_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla/hlo/ir:hlo", - ], + name = "all_gather_combiner", + hdrs = ["all_gather_combiner.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:all_gather_combiner instead.", + deps = ["//xla/hlo/transforms:all_gather_combiner"], ) cc_library( - name = "all_gather_broadcast_reorder", - srcs = ["all_gather_broadcast_reorder.cc"], - hdrs = ["all_gather_broadcast_reorder.h"], - deps = [ - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/utils:hlo_query", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - ], + name = "all_reduce_combiner", + hdrs = ["all_reduce_combiner.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:all_reduce_combiner instead.", + deps = ["//xla/hlo/transforms:all_reduce_combiner"], ) cc_library( - name = "bitcast_dtypes_expander", - srcs = ["bitcast_dtypes_expander.cc"], - hdrs = ["bitcast_dtypes_expander.h"], - deps = [ - ":op_expander_pass", - "//xla:literal_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla/client:xla_builder", - "//xla/client/lib:arithmetic", - "//xla/client/lib:broadcast", - "//xla/client/lib:constants", - "//xla/hlo/ir:hlo", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:logging", - ], -) - -xla_cc_test( - name = "bitcast_dtypes_expander_test", - srcs = ["bitcast_dtypes_expander_test.cc"], - deps = [ - ":bitcast_dtypes_expander", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:filecheck", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "all_gather_broadcast_reorder_test", - srcs = ["all_gather_broadcast_reorder_test.cc"], - deps = [ - ":all_gather_broadcast_reorder", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/platform:statusor", - ], + name = "all_reduce_contiguous", + hdrs = ["all_reduce_contiguous.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:all_reduce_contiguous instead.", + deps = ["//xla/hlo/transforms:all_reduce_contiguous"], ) cc_library( - name = "all_gather_combiner", - srcs = ["all_gather_combiner.cc"], - hdrs = ["all_gather_combiner.h"], + name = "reduce_scatter_combiner", + srcs = ["reduce_scatter_combiner.cc"], + hdrs = ["reduce_scatter_combiner.h"], deps = [ + ":all_reduce_key", ":collective_combiner_utils", + ":collective_ops_utils", ":hlo_domain_map", "//xla:shape_util", "//xla:status_macros", @@ -3033,934 +2394,575 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", - "//xla/hlo/utils:hlo_sharding_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) xla_cc_test( - name = "all_gather_combiner_test", - srcs = ["all_gather_combiner_test.cc"], + name = "reduce_scatter_combiner_test", + srcs = ["reduce_scatter_combiner_test.cc"], deps = [ - ":all_gather_combiner", - "//xla:xla_data_proto_cc", + ":reduce_scatter_combiner", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", "@local_tsl//tsl/platform:test_main", ], ) cc_library( - name = "all_reduce_combiner", - srcs = ["all_reduce_combiner.cc"], - hdrs = ["all_reduce_combiner.h"], + name = "all_reduce_simplifier", + srcs = ["all_reduce_simplifier.cc"], + hdrs = ["all_reduce_simplifier.h"], deps = [ - ":all_reduce_key", - ":collective_combiner_utils", - ":hlo_domain_map", - "//xla:array2d", + ":collective_ops_utils", + ":hlo_module_config", + "//xla:literal_util", "//xla:shape_util", - "//xla:status_macros", "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_replication_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/utils:hlo_query", - "//xla/hlo/utils:hlo_sharding_util", - "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) xla_cc_test( - name = "all_reduce_combiner_test", - srcs = ["all_reduce_combiner_test.cc"], + name = "all_reduce_simplifier_test", + srcs = ["all_reduce_simplifier_test.cc"], deps = [ - ":all_reduce_combiner", - "//xla:literal", - "//xla:literal_util", + ":all_reduce_simplifier", + ":hlo_module_config", + ":pattern_matcher", + ":pattern_matcher_gmock", "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "all_reduce_contiguous", - srcs = ["all_reduce_contiguous.cc"], - hdrs = ["all_reduce_contiguous.h"], - deps = [ - "//xla:shape_util", - "//xla:status_macros", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/utils:hlo_query", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - ], -) - -xla_cc_test( - name = "all_reduce_contiguous_test", - srcs = ["all_reduce_contiguous_test.cc"], - deps = [ - ":all_reduce_contiguous", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:test_utils", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "reduce_scatter_combiner", - srcs = ["reduce_scatter_combiner.cc"], - hdrs = ["reduce_scatter_combiner.h"], - deps = [ - ":all_reduce_key", - ":collective_combiner_utils", - ":collective_ops_utils", - ":hlo_domain_map", - ":shape_inference", - "//xla:shape_util", - "//xla:status_macros", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_reachability", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/utils:hlo_query", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - ], -) - -xla_cc_test( - name = "reduce_scatter_combiner_test", - srcs = ["reduce_scatter_combiner_test.cc"], - deps = [ - ":reduce_scatter_combiner", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "all_reduce_simplifier", - srcs = ["all_reduce_simplifier.cc"], - hdrs = ["all_reduce_simplifier.h"], - deps = [ - ":collective_ops_utils", - ":hlo_module_config", - ":hlo_replication_analysis", - "//xla:literal_util", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "all_reduce_simplifier_test", - srcs = ["all_reduce_simplifier_test.cc"], - deps = [ - ":all_reduce_simplifier", - ":hlo_module_config", - ":hlo_parser", - ":pattern_matcher", - ":pattern_matcher_gmock", - "//xla:shape_util", - "//xla:test", - "//xla:types", - "//xla:window_util", + "//xla:test", + "//xla:types", + "//xla:window_util", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/lib/core:status_test_util", - ], -) - -cc_library( - name = "reduce_scatter_decomposer", - srcs = ["reduce_scatter_decomposer.cc"], - hdrs = ["reduce_scatter_decomposer.h"], - deps = [ - ":collective_decomposer_utils", - ":collective_ops_utils", - ":hlo_module_config", - "//xla:literal_util", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/utils:hlo_query", - "@com_google_absl//absl/status:statusor", - ], -) - -xla_cc_test( - name = "reduce_scatter_decomposer_test", - srcs = ["reduce_scatter_decomposer_test.cc"], - deps = [ - ":collective_ops_utils", - ":reduce_scatter_decomposer", - "//xla:literal_util", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "reduce_scatter_reassociate", - srcs = ["reduce_scatter_reassociate.cc"], - hdrs = ["reduce_scatter_reassociate.h"], - deps = [ - ":all_reduce_key", - ":collective_ops_utils", - ":hlo_domain_map", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/utils:hlo_query", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:errors", - ], -) - -xla_cc_test( - name = "reduce_scatter_reassociate_test", - srcs = ["reduce_scatter_reassociate_test.cc"], - deps = [ - ":reduce_scatter_reassociate", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - ], -) - -cc_library( - name = "batch_dot_simplification", - srcs = ["batch_dot_simplification.cc"], - hdrs = ["batch_dot_simplification.h"], - deps = [ - ":hlo_creation_utils", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - ], -) - -xla_cc_test( - name = "batch_dot_simplification_test", - srcs = ["batch_dot_simplification_test.cc"], - deps = [ - ":batch_dot_simplification", - "//xla:test", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - ], -) - -xla_cc_test( - name = "gather_expander_test", - srcs = ["gather_expander_test.cc"], - deps = [ - ":gather_expander", - "//xla:test", - "//xla/hlo/utils:hlo_query", - "//xla/tests:hlo_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - ], -) - -cc_library( - name = "conditional_simplifier", - srcs = ["conditional_simplifier.cc"], - hdrs = ["conditional_simplifier.h"], - deps = [ - ":call_graph", - ":call_inliner", - "//xla:literal", - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - ], -) - -xla_cc_test( - name = "conditional_simplifier_test", - srcs = ["conditional_simplifier_test.cc"], - deps = [ - ":conditional_simplifier", - "//xla:literal_util", - "//xla:shape_util", - "//xla:test", - "//xla:types", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:status", - ], -) - -cc_library( - name = "conditional_code_motion", - srcs = ["conditional_code_motion.cc"], - hdrs = ["conditional_code_motion.h"], - deps = [ - ":hlo_cse", - ":hlo_dce", - ":hlo_verifier", - ":tuple_simplifier", - "//xla:debug_options_flags", - "//xla:literal", - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/pass:hlo_pass_pipeline", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - ], -) - -xla_cc_test( - name = "conditional_code_motion_test", - srcs = ["conditional_code_motion_test.cc"], - deps = [ - ":conditional_code_motion", - "//xla:literal_util", - "//xla:shape_util", - "//xla:test", - "//xla:types", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:status", - ], -) - -cc_library( - name = "convolution_group_converter", - srcs = ["convolution_group_converter.cc"], - hdrs = ["convolution_group_converter.h"], - deps = [ - ":hlo_creation_utils", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - ], -) - -xla_cc_test( - name = "convolution_group_converter_test", - size = "small", - srcs = ["convolution_group_converter_test.cc"], - deps = [ - ":convolution_group_converter", - "//xla:test", - "//xla:types", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - ], -) - -cc_library( - name = "space_to_batch_converter", - srcs = ["space_to_batch_converter.cc"], - hdrs = ["space_to_batch_converter.h"], - deps = [ - ":hlo_creation_utils", - ":pattern_matcher", - ":shape_inference", - "//xla:debug_options_flags", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/tsl/lib/core:bitmap", - "@com_google_absl//absl/algorithm", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - ], -) - -xla_cc_test( - name = "space_to_batch_converter_test", - size = "small", - srcs = ["space_to_batch_converter_test.cc"], - deps = [ - ":space_to_batch_converter", - "//xla:test", - "//xla:types", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - ], -) - -cc_library( - name = "scan_loop_accumulator_input_unification", - srcs = ["scan_loop_accumulator_input_unification.cc"], - hdrs = ["scan_loop_accumulator_input_unification.h"], - deps = [ - ":call_graph", - ":hlo_alias_analysis", - ":hlo_dataflow_analysis", - ":pattern_matcher", - ":tuple_simplifier", - ":while_loop_simplifier", - ":while_loop_unroller", - "//xla:literal_util", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "scan_loop_accumulator_input_unification_test", - srcs = ["scan_loop_accumulator_input_unification_test.cc"], - deps = [ - ":copy_insertion", - ":scan_loop_accumulator_input_unification", - "//xla:literal", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:literal_test_util", - "//xla/tests:test_utils", - "//xla/tests:verified_hlo_module", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/log", - "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", ], ) cc_library( - name = "hlo_unstacker", - srcs = ["hlo_unstacker.cc"], - hdrs = ["hlo_unstacker.h"], + name = "reduce_scatter_decomposer", + srcs = ["reduce_scatter_decomposer.cc"], + hdrs = ["reduce_scatter_decomposer.h"], deps = [ - ":algebraic_simplifier", - ":hlo_creation_utils", - ":pattern_matcher", - ":tuple_util", - ":while_loop_unroller", + ":collective_decomposer_utils", + ":collective_ops_utils", + ":hlo_module_config", + "//xla:literal_util", "//xla:shape_util", - "//xla:util", - "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", + "//xla/hlo/utils:hlo_query", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", ], ) xla_cc_test( - name = "hlo_unstacker_test", - srcs = ["hlo_unstacker_test.cc"], - tags = if_google(["requires-net:external"]), + name = "reduce_scatter_decomposer_test", + srcs = ["reduce_scatter_decomposer_test.cc"], deps = [ - ":hlo_unstacker", + ":collective_ops_utils", + ":reduce_scatter_decomposer", + "//xla:literal_util", "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", - "@com_google_googletest//:gtest_main", + "//xla/tests:xla_internal_test_main", # fixdeps: keep "@local_tsl//tsl/platform:statusor", ], ) cc_library( - name = "while_loop_unroller", - srcs = ["while_loop_unroller.cc"], - hdrs = ["while_loop_unroller.h"], + name = "reduce_scatter_reassociate", + srcs = ["reduce_scatter_reassociate.cc"], + hdrs = ["reduce_scatter_reassociate.h"], deps = [ - ":call_inliner", + ":all_reduce_key", ":collective_ops_utils", - ":flatten_call_graph", - ":hlo_alias_analysis", - ":hlo_buffer", - ":hlo_creation_utils", - ":hlo_cse", - ":hlo_value", - ":pattern_matcher", - ":tuple_simplifier", - ":while_loop_analysis", - ":while_loop_constant_sinking", - "//xla:comparison_util", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/evaluator:hlo_evaluator", + ":hlo_domain_map", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", - "@com_google_absl//absl/algorithm", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", ], ) xla_cc_test( - name = "while_loop_unroller_test", - srcs = ["while_loop_unroller_test.cc"], + name = "reduce_scatter_reassociate_test", + srcs = ["reduce_scatter_reassociate_test.cc"], deps = [ - ":while_loop_unroller", - "//xla:literal", + ":reduce_scatter_reassociate", "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", - "//xla/tests:literal_test_util", - "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", ], ) cc_library( - name = "while_loop_analysis", - srcs = ["while_loop_analysis.cc"], - hdrs = ["while_loop_analysis.h"], - deps = [ - ":pattern_matcher", - "//xla:comparison_util", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_util", - "//xla/hlo/evaluator:hlo_evaluator", - "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_reachability", - "@com_google_absl//absl/base", - "@com_google_absl//absl/container:flat_hash_map", - ], + name = "batch_dot_simplification", + hdrs = ["batch_dot_simplification.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:batch_dot_simplification instead.", + deps = ["//xla/hlo/transforms:batch_dot_simplification"], ) xla_cc_test( - name = "while_loop_analysis_test", - srcs = ["while_loop_analysis_test.cc"], + name = "gather_expander_test", + srcs = ["gather_expander_test.cc"], deps = [ - ":while_loop_analysis", - "//xla:comparison_util", + ":gather_expander", "//xla:test", - "//xla:util", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:filecheck", + "//xla/hlo/utils:hlo_query", "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", + "//xla/tests:test_macros_header", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/strings:string_view", ], ) cc_library( - name = "while_loop_simplifier", - srcs = ["while_loop_simplifier.cc"], - hdrs = ["while_loop_simplifier.h"], + name = "conditional_simplifier", + srcs = ["conditional_simplifier.cc"], + hdrs = ["conditional_simplifier.h"], deps = [ + ":call_graph", ":call_inliner", - ":hlo_creation_utils", - ":hlo_dce", - ":pattern_matcher", - ":while_loop_analysis", - "//xla:comparison_util", - "//xla:literal_util", + "//xla:literal", "//xla:shape_util", "//xla:status_macros", - "//xla:union_find", + "//xla:types", "//xla:util", - "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/hlo/utils:hlo_query", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", ], ) xla_cc_test( - name = "while_loop_simplifier_test", - srcs = ["while_loop_simplifier_test.cc"], + name = "conditional_simplifier_test", + srcs = ["conditional_simplifier_test.cc"], deps = [ - ":hlo_dce", - ":hlo_parser", - ":tuple_simplifier", - ":while_loop_simplifier", + ":conditional_simplifier", "//xla:literal_util", "//xla:shape_util", "//xla:test", + "//xla:types", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:status", ], ) cc_library( - name = "while_loop_trip_count_annotator", - srcs = ["while_loop_trip_count_annotator.cc"], - hdrs = ["while_loop_trip_count_annotator.h"], + name = "conditional_code_motion", + srcs = ["conditional_code_motion.cc"], + hdrs = ["conditional_code_motion.h"], deps = [ - ":while_loop_analysis", - "//xla:xla_data_proto_cc", + ":hlo_cse", + ":hlo_verifier", + "//xla:debug_options_flags", + "//xla:literal", + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", + "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms:tuple_simplifier", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", ], ) xla_cc_test( - name = "while_loop_trip_count_annotator_test", - srcs = ["while_loop_trip_count_annotator_test.cc"], + name = "conditional_code_motion_test", + srcs = ["conditional_code_motion_test.cc"], deps = [ - ":while_loop_trip_count_annotator", + ":conditional_code_motion", + "//xla:literal_util", + "//xla:shape_util", "//xla:test", + "//xla:types", "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@local_tsl//tsl/platform:statusor", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status", ], ) cc_library( - name = "defuser", - srcs = ["defuser.cc"], - hdrs = ["defuser.h"], + name = "convolution_group_converter", + hdrs = ["convolution_group_converter.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:convolution_group_converter instead.", + deps = ["//xla/hlo/transforms:convolution_group_converter"], +) + +cc_library( + name = "space_to_batch_converter", + srcs = ["space_to_batch_converter.cc"], + hdrs = ["space_to_batch_converter.h"], deps = [ - ":call_graph", + ":hlo_creation_utils", + ":pattern_matcher", + ":shape_inference", + "//xla:comparison_util", + "//xla:debug_options_flags", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", "//xla:status_macros", "//xla:types", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/tsl/lib/core:bitmap", + "@com_google_absl//absl/algorithm", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", ], ) xla_cc_test( - name = "defuser_test", - srcs = ["defuser_test.cc"], - deps = [ - ":defuser", - "//xla:literal", - "//xla:shape_util", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - ], -) - -xla_cc_test( - name = "despecializer_test", - srcs = ["despecializer_test.cc"], + name = "space_to_batch_converter_test", + size = "small", + srcs = ["space_to_batch_converter_test.cc"], deps = [ - ":despecializer", - "//xla:literal", - "//xla:shape_util", + ":space_to_batch_converter", + "//xla:test", + "//xla:types", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest_main", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@local_tsl//tsl/platform:statusor", ], ) cc_library( - name = "dot_decomposer", - srcs = ["dot_decomposer.cc"], - hdrs = ["dot_decomposer.h"], + name = "scan_loop_accumulator_input_unification", + srcs = ["scan_loop_accumulator_input_unification.cc"], + hdrs = ["scan_loop_accumulator_input_unification.h"], deps = [ - ":shape_inference", + ":call_graph", + ":pattern_matcher", + ":while_loop_simplifier", + ":while_loop_unroller", + "//xla:literal_util", "//xla:shape_util", + "//xla:util", + "//xla/hlo/analysis:hlo_alias_analysis", + "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", + "//xla/hlo/transforms:tuple_simplifier", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", + "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", ], ) xla_cc_test( - name = "dot_decomposer_test", - srcs = ["dot_decomposer_test.cc"], + name = "scan_loop_accumulator_input_unification_test", + srcs = ["scan_loop_accumulator_input_unification_test.cc"], deps = [ - ":dot_decomposer", - ":pattern_matcher", - ":pattern_matcher_gmock", + ":copy_insertion", + ":scan_loop_accumulator_input_unification", + "//xla:literal", "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@com_google_absl//absl/strings:string_view", + "//xla/tests:literal_test_util", + "//xla/tests:test_utils", + "//xla/tests:verified_hlo_module", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/log", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", ], ) cc_library( - name = "dot_dimension_merger", - srcs = ["dot_dimension_merger.cc"], - hdrs = ["dot_dimension_merger.h"], + name = "hlo_unstacker", + srcs = ["hlo_unstacker.cc"], + hdrs = ["hlo_unstacker.h"], deps = [ ":hlo_creation_utils", + ":pattern_matcher", + ":tuple_util", + ":while_loop_unroller", "//xla:shape_util", "//xla:util", - "//xla:xla_data_proto_cc", + "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/transforms:algebraic_simplifier", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], ) xla_cc_test( - name = "dot_dimension_merger_test", - srcs = ["dot_dimension_merger_test.cc"], + name = "hlo_unstacker_test", + srcs = ["hlo_unstacker_test.cc"], + tags = if_google(["requires-net:external"]), deps = [ - ":dot_dimension_merger", - ":hlo_parser", + ":hlo_unstacker", + "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", ], ) cc_library( - name = "dot_merger", - srcs = ["dot_merger.cc"], - hdrs = ["dot_merger.h"], + name = "while_loop_unroller", + srcs = ["while_loop_unroller.cc"], + hdrs = ["while_loop_unroller.h"], deps = [ - ":shape_inference", - "//xla:protobuf_util", + ":call_inliner", + ":collective_ops_utils", + ":hlo_buffer", + ":hlo_creation_utils", + ":hlo_cse", + ":hlo_value", + ":pattern_matcher", + ":while_loop_constant_sinking", + "//xla:comparison_util", + "//xla:literal", + "//xla:literal_util", "//xla:shape_util", + "//xla:status_macros", "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_alias_analysis", + "//xla/hlo/analysis:while_loop_analysis", + "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/service/graphcycles", + "//xla/hlo/transforms:flatten_call_graph", + "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/utils:hlo_query", + "@com_google_absl//absl/algorithm", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], ) xla_cc_test( - name = "dot_merger_test", - srcs = ["dot_merger_test.cc"], + name = "while_loop_unroller_test", + srcs = ["while_loop_unroller_test.cc"], deps = [ - ":algebraic_simplifier", - ":dot_merger", - ":pattern_matcher", - ":pattern_matcher_gmock", - "//xla:shape_util", + ":while_loop_unroller", + "//xla:literal", "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/strings:string_view", + "//xla/tests:literal_test_util", + "//xla/tests:verified_hlo_module", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", ], ) cc_library( - name = "convert_mover", - srcs = ["convert_mover.cc"], - hdrs = ["convert_mover.h"], + name = "while_loop_analysis", + hdrs = ["while_loop_analysis.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/analysis:while_loop_analysis instead.", + deps = ["//xla/hlo/analysis:while_loop_analysis"], +) + +cc_library( + name = "while_loop_simplifier", + srcs = ["while_loop_simplifier.cc"], + hdrs = ["while_loop_simplifier.h"], deps = [ + ":call_inliner", ":hlo_creation_utils", - "//xla:literal", + ":pattern_matcher", + "//xla:comparison_util", + "//xla:literal_util", "//xla:shape_util", + "//xla:status_macros", + "//xla:union_find", + "//xla:util", "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:while_loop_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/utils:hlo_query", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], ) xla_cc_test( - name = "convert_mover_test", - srcs = ["convert_mover_test.cc"], + name = "while_loop_simplifier_test", + srcs = ["while_loop_simplifier_test.cc"], deps = [ - ":convert_mover", - ":pattern_matcher", - ":pattern_matcher_gmock", + ":while_loop_simplifier", + "//xla:literal_util", + "//xla:shape_util", + "//xla:test", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", ], ) +cc_library( + name = "while_loop_trip_count_annotator", + hdrs = ["while_loop_trip_count_annotator.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:while_loop_trip_count_annotator instead.", + deps = ["//xla/hlo/transforms:while_loop_trip_count_annotator"], +) + +cc_library( + name = "defuser", + hdrs = ["defuser.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:defuser instead.", + deps = ["//xla/hlo/transforms:defuser"], +) + +cc_library( + name = "dot_decomposer", + hdrs = ["dot_decomposer.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:dot_decomposer instead.", + deps = ["//xla/hlo/transforms:dot_decomposer"], +) + +cc_library( + name = "dot_dimension_merger", + hdrs = ["dot_dimension_merger.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:dot_dimension_merger instead.", + deps = ["//xla/hlo/transforms:dot_dimension_merger"], +) + +cc_library( + name = "dot_merger", + hdrs = ["dot_merger.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:dot_merger instead.", + deps = ["//xla/hlo/transforms:dot_merger"], +) + +cc_library( + name = "convert_mover", + hdrs = ["convert_mover.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:convert_mover instead.", + deps = ["//xla/hlo/transforms:convert_mover"], +) + cc_library( name = "all_to_all_decomposer", srcs = ["all_to_all_decomposer.cc"], hdrs = ["all_to_all_decomposer.h"], deps = [ - ":op_expander_pass", "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/hlo/transforms:op_expander_pass", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", ], ) @@ -3974,6 +2976,7 @@ cc_library( "//xla:literal_util", "//xla:shape_util", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "@com_google_absl//absl/container:flat_hash_set", @@ -3991,8 +2994,8 @@ xla_cc_test( srcs = ["all_gather_decomposer_test.cc"], deps = [ ":all_gather_decomposer", - ":hlo_parser", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep @@ -4003,103 +3006,30 @@ xla_cc_test( cc_library( name = "tuple_simplifier", - srcs = ["tuple_simplifier.cc"], hdrs = ["tuple_simplifier.h"], - deps = [ - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "tuple_simplifier_test", - srcs = ["tuple_simplifier_test.cc"], - deps = [ - ":tuple_simplifier", - "//xla:shape_util", - "//xla:test", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:tuple_simplifier instead.", + deps = ["//xla/hlo/transforms:tuple_simplifier"], ) cc_library( name = "reshape_mover", - srcs = ["reshape_mover.cc"], hdrs = ["reshape_mover.h"], - deps = [ - ":hlo_creation_utils", - "//xla:permutation_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@local_tsl//tsl/platform:errors", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:reshape_mover instead.", + deps = ["//xla/hlo/transforms:reshape_mover"], ) cc_library( name = "reshape_decomposer", - srcs = ["reshape_decomposer.cc"], hdrs = ["reshape_decomposer.h"], - deps = [ - ":hlo_creation_utils", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/status", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:reshape_decomposer instead.", + deps = ["//xla/hlo/transforms:reshape_decomposer"], ) cc_library( name = "reduce_decomposer", - srcs = ["reduce_decomposer.cc"], hdrs = ["reduce_decomposer.h"], - deps = [ - ":hlo_creation_utils", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/status", - ], -) - -xla_cc_test( - name = "reduce_decomposer_test", - srcs = ["reduce_decomposer_test.cc"], - deps = [ - ":hlo_parser", - ":reduce_decomposer", - "//xla:test", - "//xla:test_helpers", - "//xla/tests:filecheck", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - ], -) - -xla_cc_test( - name = "reshape_decomposer_test", - srcs = ["reshape_decomposer_test.cc"], - deps = [ - ":hlo_parser", - ":reshape_decomposer", - "//xla:test", - "//xla:test_helpers", - "//xla/tests:filecheck", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:reduce_decomposer instead.", + deps = ["//xla/hlo/transforms:reduce_decomposer"], ) cc_library( @@ -4125,7 +3055,6 @@ cc_library( ":call_inliner", ":dynamic_window_utils", ":hlo_creation_utils", - ":hlo_dataflow_analysis", ":hlo_value", ":tuple_util", ":while_util", @@ -4138,6 +3067,7 @@ cc_library( "//xla:util", "//xla:window_util", "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -4157,40 +3087,11 @@ cc_library( ], ) -cc_library( - name = "dynamic_dimension_simplifier", - srcs = ["dynamic_dimension_simplifier.cc"], - hdrs = ["dynamic_dimension_simplifier.h"], - deps = [ - "//xla:status_macros", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - ], -) - -xla_cc_test( - name = "dynamic_dimension_simplifier_test", - srcs = ["dynamic_dimension_simplifier_test.cc"], - deps = [ - ":dynamic_dimension_simplifier", - ":hlo_creation_utils", - ":hlo_parser", - ":pattern_matcher", - ":pattern_matcher_gmock", - ":shape_inference", - "//xla:literal", - "//xla:shape_util", - "//xla:test", - "//xla:types", - "//xla:window_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/strings", - ], +cc_library( + name = "dynamic_dimension_simplifier", + hdrs = ["dynamic_dimension_simplifier.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:dynamic_dimension_simplifier instead.", + deps = ["//xla/hlo/transforms:dynamic_dimension_simplifier"], ) cc_library( @@ -4202,7 +3103,6 @@ cc_library( ":dynamic_dimension_inference", ":dynamic_window_utils", ":hlo_creation_utils", - ":hlo_dce", ":pattern_matcher", ":shape_inference", ":tuple_util", @@ -4213,9 +3113,10 @@ cc_library( "//xla:util", "//xla:window_util", "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/transforms:hlo_dce", "//xla/tsl/lib/monitoring:gauge", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -4237,15 +3138,10 @@ xla_test( srcs = ["dynamic_padder_test.cc"], tags = ["test_xla_cpu_thunks"], deps = [ - ":algebraic_simplifier", ":dynamic_dimension_inference", - ":dynamic_dimension_simplifier", ":dynamic_padder", - ":hlo_dce", - ":hlo_parser", ":pattern_matcher", ":pattern_matcher_gmock", - ":tuple_simplifier", "//xla:error_spec", "//xla:literal", "//xla:literal_util", @@ -4255,8 +3151,13 @@ xla_test( "//xla:test_helpers", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/transforms:algebraic_simplifier", + "//xla/hlo/transforms:dynamic_dimension_simplifier", + "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms:tuple_simplifier", "//xla/hlo/utils:hlo_matchers", "//xla/tests:client_library_test_base", "//xla/tests:hlo_test_base", @@ -4265,6 +3166,7 @@ xla_test( "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/protobuf:error_codes_proto_impl_cc", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -4275,7 +3177,6 @@ xla_test( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_benchmark", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -4290,7 +3191,7 @@ xla_cc_test( "//xla:test", "//xla:test_helpers", "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:filecheck", @@ -4302,23 +3203,6 @@ xla_cc_test( ], ) -xla_cc_test( - name = "reshape_mover_test", - srcs = ["reshape_mover_test.cc"], - deps = [ - ":algebraic_simplifier", - ":hlo_verifier", - ":pattern_matcher", - ":pattern_matcher_gmock", - ":reshape_mover", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - ], -) - cc_library( name = "computation_placer", srcs = ["computation_placer.cc"], @@ -4332,11 +3216,11 @@ cc_library( "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/stream_executor", "//xla/stream_executor:platform", "//xla/stream_executor/cuda:cuda_platform_id", "//xla/stream_executor/host:host_platform_id", "//xla/stream_executor/rocm:rocm_platform_id", + "//xla/stream_executor/sycl:sycl_platform_id", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", @@ -4356,7 +3240,7 @@ cc_library( ":global_device_id", "//xla:array2d", "//xla:xla_data_proto_cc", - "//xla/stream_executor", + "//xla/stream_executor:platform", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -4411,9 +3295,11 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/stream_executor:event", "//xla/stream_executor:memory_allocation", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/status", @@ -4438,8 +3324,9 @@ xla_cc_test( "//xla:shape_tree", "//xla:shape_util", "//xla:types", - "//xla/stream_executor", + "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", "//xla/stream_executor/host:host_platform", "//xla/stream_executor/host:host_platform_id", @@ -4464,13 +3351,13 @@ cc_library( "//xla:window_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/tsl/lib/gtl:map_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/lib/gtl:map_util", "@local_tsl//tsl/platform:errors", ], ) @@ -4481,7 +3368,6 @@ xla_cc_test( deps = [ ":cpu_plugin", ":hlo_cost_analysis", - ":hlo_parser", ":local_service", ":service", "//xla:shape_util", @@ -4490,10 +3376,11 @@ xla_cc_test( "//xla/client", "//xla/client:client_library", "//xla/client:local_client", - "//xla/client:padding", - "//xla/client:xla_builder", - "//xla/client:xla_computation", + "//xla/hlo/builder:padding", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/status:statusor", @@ -4536,7 +3423,6 @@ xla_cc_test( name = "hlo_computation_test", srcs = ["hlo_computation_test.cc"], deps = [ - ":hlo_parser", ":pattern_matcher", ":pattern_matcher_gmock", "//xla:comparison_util", @@ -4546,9 +3432,11 @@ xla_cc_test( "//xla:test", "//xla:test_helpers", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", @@ -4563,7 +3451,6 @@ xla_cc_test( srcs = ["hlo_module_test.cc"], deps = [ ":computation_placer_hdr", - ":hlo_memory_scheduler", ":test_compilation_environment_proto_cc", "//xla:literal", "//xla:shape_util", @@ -4571,6 +3458,7 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/transforms:hlo_memory_scheduler", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -4615,8 +3503,8 @@ cc_library( deps = [ ":buffer_value", ":logical_buffer", + "//xla/tsl/lib/gtl:compactptrset", "@com_google_absl//absl/container:flat_hash_set", - "@local_tsl//tsl/lib/gtl:compactptrset", ], ) @@ -4631,9 +3519,9 @@ cc_library( "//xla:types", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/gtl:int_type", ], ) @@ -4664,62 +3552,8 @@ cc_library( cc_library( name = "hlo_dataflow_analysis", - srcs = ["hlo_dataflow_analysis.cc"], hdrs = ["hlo_dataflow_analysis.h"], - deps = [ - ":call_graph", - ":hlo_phi_graph", - ":hlo_value", - "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/functional:function_ref", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - ], -) - -xla_cc_test( - name = "hlo_dataflow_analysis_test", - srcs = ["hlo_dataflow_analysis_test.cc"], - deps = [ - ":flatten_call_graph", - ":hlo_creation_utils", - ":hlo_dataflow_analysis", - ":hlo_dce", - ":hlo_graph_dumper", - ":hlo_ordering", - ":hlo_value", - "//xla:comparison_util", - "//xla:literal_util", - "//xla:shape_util", - "//xla:test", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - ], + deps = ["//xla/hlo/analysis:hlo_dataflow_analysis"], ) cc_library( @@ -4748,117 +3582,23 @@ xla_cc_test( cc_library( name = "hlo_value_semantics_analysis", - srcs = ["hlo_value_semantics_analysis.cc"], hdrs = ["hlo_value_semantics_analysis.h"], - deps = [ - ":hlo_value", - "//xla:shape_tree", - "//xla:shape_util", - "//xla:side_effect_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "hlo_value_semantics_analysis_test", - srcs = ["hlo_value_semantics_analysis_test.cc"], - deps = [ - ":hlo_value_semantics_analysis", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/analysis:hlo_value_semantics_analysis instead.", + deps = ["//xla/hlo/analysis:hlo_value_semantics_analysis"], ) cc_library( name = "hlo_replication_analysis", - srcs = ["hlo_replication_analysis.cc"], hdrs = ["hlo_replication_analysis.h"], - deps = [ - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], -) - -xla_cc_test( - name = "hlo_replication_analysis_test", - srcs = ["hlo_replication_analysis_test.cc"], - deps = [ - ":hlo_replication_analysis", - "//xla:shape_util", - "//xla:types", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/analysis:hlo_replication_analysis instead.", + deps = ["//xla/hlo/analysis:hlo_replication_analysis"], ) cc_library( name = "hlo_liveness_analysis", - srcs = ["hlo_liveness_analysis.cc"], hdrs = ["hlo_liveness_analysis.h"], - deps = [ - ":call_graph", - ":hlo_value", - "//xla:shape_tree", - "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla/hlo/ir:hlo", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/functional:function_ref", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - ], -) - -xla_cc_test( - name = "hlo_liveness_analysis_test", - srcs = ["hlo_liveness_analysis_test.cc"], - deps = [ - ":hlo_liveness_analysis", - "//xla:literal", - "//xla:shape_util", - "//xla:status_macros", - "//xla:test", - "//xla:test_helpers", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/analysis:hlo_liveness_analysis instead.", + deps = ["//xla/hlo/analysis:hlo_liveness_analysis"], ) cc_library( @@ -4882,121 +3622,23 @@ cc_library( cc_library( name = "hlo_alias_analysis", - srcs = ["hlo_alias_analysis.cc"], hdrs = ["hlo_alias_analysis.h"], - deps = [ - ":hlo_buffer", - ":hlo_dataflow_analysis", - ":hlo_ordering", - ":hlo_value", - "//xla:comparison_util", - "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - ], -) - -xla_cc_test( - name = "hlo_alias_analysis_test", - srcs = ["hlo_alias_analysis_test.cc"], - deps = [ - ":flatten_call_graph", - ":hlo_alias_analysis", - ":hlo_graph_dumper", - ":hlo_ordering", - ":instruction_fusion", - "//xla:literal", - "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:test", - ], -) - -cc_library( - name = "logical_buffer_analysis", - srcs = ["logical_buffer_analysis.cc"], - hdrs = ["logical_buffer_analysis.h"], - deps = [ - ":logical_buffer", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - ], -) - -cc_library( - name = "tuple_points_to_analysis", - srcs = ["tuple_points_to_analysis.cc"], - hdrs = ["tuple_points_to_analysis.h"], - deps = [ - ":hlo_dataflow_analysis", - ":logical_buffer", - ":logical_buffer_analysis", - "//xla:shape_tree", - "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/gtl:compactptrset", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/analysis:hlo_alias_analysis instead.", + deps = ["//xla/hlo/analysis:hlo_alias_analysis"], ) - -xla_cc_test( - name = "tuple_points_to_analysis_test", - srcs = ["tuple_points_to_analysis_test.cc"], - deps = [ - ":logical_buffer", - ":tuple_points_to_analysis", - "//xla:literal_util", - "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - ], + +cc_library( + name = "logical_buffer_analysis", + hdrs = ["logical_buffer_analysis.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/analysis:logical_buffer_analysis instead.", + deps = ["//xla/hlo/analysis:logical_buffer_analysis"], +) + +cc_library( + name = "tuple_points_to_analysis", + hdrs = ["tuple_points_to_analysis.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/analysis:tuple_points_to_analysis instead.", + deps = ["//xla/hlo/analysis:tuple_points_to_analysis"], ) cc_library( @@ -5026,11 +3668,8 @@ cc_library( deps = [ ":call_graph", ":computation_layout", - ":hlo_dce", ":hlo_graph_dumper", ":logical_buffer", - ":tuple_points_to_analysis", - ":tuple_simplifier", "//xla:permutation_util", "//xla:shape_layout", "//xla:shape_util", @@ -5038,8 +3677,11 @@ cc_library( "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:tuple_points_to_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms:tuple_simplifier", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -5069,22 +3711,23 @@ cc_library( deps = [ ":call_graph", ":dump", - ":hlo_alias_analysis", ":hlo_buffer", - ":hlo_dataflow_analysis", - ":hlo_dce", ":hlo_graph_dumper", - ":hlo_ordering", ":hlo_value", - ":tuple_simplifier", "//xla:frontend_attributes", "//xla:shape_tree", "//xla:shape_util", "//xla:status_macros", "//xla:util", + "//xla/hlo/analysis:hlo_alias_analysis", + "//xla/hlo/analysis:hlo_dataflow_analysis", + "//xla/hlo/analysis:hlo_ordering", + "//xla/hlo/analysis:hlo_reachability", "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_reachability", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/utils:hlo_query", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -5107,12 +3750,12 @@ cc_library( srcs = ["loop_schedule_linearizer.cc"], hdrs = ["loop_schedule_linearizer.h"], deps = [ - ":hlo_alias_analysis", - ":hlo_dataflow_analysis", ":hlo_graph_dumper", ":hlo_value", "//xla:shape_tree", "//xla:shape_util", + "//xla/hlo/analysis:hlo_alias_analysis", + "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", @@ -5136,7 +3779,6 @@ xla_cc_test( ":copy_insertion", ":hlo_graph_dumper", ":hlo_module_config", - ":hlo_parser", "//xla:comparison_util", "//xla:debug_options_flags", "//xla:literal_util", @@ -5145,7 +3787,9 @@ xla_cc_test( "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/utils:hlo_matchers", + "//xla/hlo/utils:hlo_query", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/log", @@ -5176,48 +3820,16 @@ xla_cc_test( cc_library( name = "memory_space_propagation", - srcs = ["memory_space_propagation.cc"], hdrs = ["memory_space_propagation.h"], - deps = [ - ":hlo_dataflow_analysis", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - ], -) - -xla_cc_test( - name = "memory_space_propagation_test", - srcs = ["memory_space_propagation_test.cc"], - deps = [ - ":hlo_parser", - ":memory_space_propagation", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:memory_space_propagation instead.", + deps = ["//xla/hlo/transforms:memory_space_propagation"], ) cc_library( name = "hlo_dce", - srcs = ["hlo_dce.cc"], hdrs = ["hlo_dce.h"], - deps = [ - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:hlo_dce instead.", + deps = ["//xla/hlo/transforms:hlo_dce"], ) cc_library( @@ -5225,15 +3837,15 @@ cc_library( srcs = ["hlo_module_dce.cc"], hdrs = ["hlo_module_dce.h"], deps = [ - ":hlo_dce", - ":hlo_liveness_analysis", - ":tuple_simplifier", ":while_loop_simplifier", "//xla:status_macros", "//xla:types", "//xla:util", + "//xla/hlo/analysis:hlo_liveness_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms:tuple_simplifier", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:errors", @@ -5278,7 +3890,6 @@ xla_cc_test( srcs = ["hlo_verifier_test.cc"], deps = [ ":hlo_module_config", - ":hlo_parser", ":hlo_verifier", ":layout_assignment", "//xla:literal_util", @@ -5286,6 +3897,8 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", @@ -5323,8 +3936,8 @@ xla_cc_test( srcs = ["cpu_gpu_shape_verifier_test.cc"], deps = [ ":cpu_gpu_shape_verifier", - ":hlo_parser", ":hlo_verifier", + "//xla/hlo/parser:hlo_parser", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", @@ -5335,101 +3948,17 @@ xla_cc_test( cc_library( name = "hlo_rematerialization", - srcs = ["hlo_rematerialization.cc"], hdrs = ["hlo_rematerialization.h"], - deps = [ - ":call_graph", - ":hlo_cost_analysis", - ":hlo_dataflow_analysis", - ":hlo_dce", - ":logical_buffer", - ":tuple_points_to_analysis", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/utils:hlo_query", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/functional:function_ref", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/platform:errors", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:hlo_rematerialization instead.", + deps = ["//xla/hlo/transforms:hlo_rematerialization"], ) cc_library( name = "hlo_rematerialization_test_utils", testonly = 1, hdrs = ["hlo_rematerialization_test_utils.h"], - deps = [ - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - ], -) - -xla_cc_test( - name = "hlo_rematerialization_test_utils_test", - srcs = ["hlo_rematerialization_test_utils_test.cc"], - deps = [ - ":hlo_rematerialization_test_utils", - "//xla/hlo/ir:hlo", - "//xla/tests:xla_internal_test_main", - ], -) - -xla_cc_test( - name = "hlo_rematerialization_test", - srcs = ["hlo_rematerialization_test.cc"], - deps = [ - ":hlo_cost_analysis", - ":hlo_memory_scheduler", - ":hlo_rematerialization", - ":hlo_rematerialization_test_utils", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - ], -) - -xla_cc_test( - name = "hlo_dce_test", - srcs = ["hlo_dce_test.cc"], - deps = [ - ":hlo_dce", - ":pattern_matcher", - ":pattern_matcher_gmock", - "//xla:literal_util", - "//xla:shape_util", - "//xla:types", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:literal_test_util", - "//xla/tests:test_utils", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:hlo_rematerialization_test_utils instead.", + deps = ["//xla/hlo/transforms:hlo_rematerialization_test_utils"], ) xla_cc_test( @@ -5452,9 +3981,7 @@ xla_cc_test( name = "layout_assignment_test", srcs = ["layout_assignment_test.cc"], deps = [ - ":algebraic_simplifier", ":computation_layout", - ":hlo_parser", ":layout_assignment", ":logical_buffer", ":pattern_matcher", @@ -5468,6 +3995,8 @@ xla_cc_test( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/transforms:algebraic_simplifier", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", @@ -5481,14 +4010,13 @@ xla_cc_test( ], ) -# Deprecated, use //third_party/tensorflow/compiler/xla/hlo/pass:hlo_pass # -# instead. cc_library( name = "hlo_pass", hdrs = [ "hlo_pass_fix.h", "hlo_pass_interface.h", ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/pass:hlo_pass instead.", deps = [ "//xla:status_macros", "//xla:types", @@ -5501,11 +4029,10 @@ cc_library( ], ) -# Deprecated, use -# //third_party/tensorflow/compiler/xla/hlo/pass:hlo_pass_pipeline instead. cc_library( name = "hlo_pass_pipeline", hdrs = ["hlo_pass_pipeline.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/pass:hlo_pass_pipeline instead.", deps = [ ":compilation_stats", ":hlo_pass", @@ -5539,7 +4066,6 @@ xla_cc_test( srcs = ["hlo_cse_test.cc"], deps = [ ":hlo_cse", - ":hlo_parser", ":pattern_matcher", ":pattern_matcher_gmock", "//xla:literal", @@ -5547,6 +4073,7 @@ xla_cc_test( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", @@ -5559,49 +4086,8 @@ xla_cc_test( cc_library( name = "hlo_constant_folding", - srcs = ["hlo_constant_folding.cc"], hdrs = ["hlo_constant_folding.h"], - deps = [ - ":slow_operation_alarm", - "//xla:literal", - "//xla:shape_util", - "//xla/hlo/evaluator:hlo_evaluator", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/time", - "@local_tsl//tsl/platform:errors", - ], -) - -xla_cc_test( - name = "hlo_constant_folding_test", - srcs = ["hlo_constant_folding_test.cc"], - deps = [ - ":hlo_constant_folding", - ":hlo_parser", - ":pattern_matcher", - ":pattern_matcher_gmock", - "//xla:literal", - "//xla:literal_util", - "//xla:permutation_util", - "//xla:shape_util", - "//xla:test", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:statusor", - ], + deps = ["//xla/hlo/transforms:hlo_constant_folding"], ) cc_library( @@ -5609,13 +4095,18 @@ cc_library( srcs = ["hlo_domain_map.cc"], hdrs = ["hlo_domain_map.h"], deps = [ - "//xla:types", + "//xla:status_macros", "//xla:util", "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) @@ -5668,11 +4159,11 @@ xla_cc_test( ":hlo_domain_isolator", ":hlo_domain_remover", ":hlo_domain_verifier", - ":hlo_parser", ":sharding_propagation", "//xla:debug_options_flags", "//xla:test", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", @@ -5681,59 +4172,16 @@ xla_cc_test( cc_library( name = "hlo_element_type_converter", - srcs = ["hlo_element_type_converter.cc"], hdrs = ["hlo_element_type_converter.h"], - deps = [ - "//xla:literal", - "//xla:shape_util", - "//xla:types", - "//xla/hlo/evaluator:hlo_evaluator", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/utils:hlo_query", - "@local_tsl//tsl/platform:errors", - ], -) - -xla_cc_test( - name = "hlo_element_type_converter_test", - srcs = ["hlo_element_type_converter_test.cc"], - deps = [ - ":hlo_element_type_converter", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:hlo_element_type_converter instead.", + deps = ["//xla/hlo/transforms:hlo_element_type_converter"], ) cc_library( name = "conditional_canonicalizer", - srcs = ["conditional_canonicalizer.cc"], hdrs = ["conditional_canonicalizer.h"], - deps = [ - "//xla:status_macros", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - ], -) - -xla_cc_test( - name = "conditional_canonicalizer_test", - srcs = ["conditional_canonicalizer_test.cc"], - deps = [ - ":conditional_canonicalizer", - ":hlo_parser", - "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:literal_test_util", - "//xla/tests:test_utils", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:conditional_canonicalizer instead.", + deps = ["//xla/hlo/transforms:conditional_canonicalizer"], ) cc_library( @@ -5766,6 +4214,7 @@ cc_library( "//xla:xla_data_proto_cc", "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", "@local_tsl//tsl/platform:statusor", ], ) @@ -5820,22 +4269,21 @@ xla_test( deps = [ ":elemental_ir_emitter", ":hlo_module_config", - ":hlo_parser", "//xla:error_spec", - "//xla:execution_options_util", "//xla:literal", "//xla:literal_util", - "//xla:status_macros", "//xla:test", + "//xla:types", "//xla/hlo/ir:hlo", "//xla/service/llvm_ir:ir_array", - "//xla/tests:client_library_test_base", "//xla/tests:hlo_test_base", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", "@llvm-project//llvm:ir_headers", + "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:statusor", ], ) @@ -5905,6 +4353,9 @@ cc_library( "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", "//xla/stream_executor:dnn", + "//xla/tsl/lib/gtl:map_util", + "//xla/tsl/lib/io:zlib_compression_options", + "//xla/tsl/lib/io:zlib_outputbuffer", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -5917,9 +4368,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/lib/gtl:map_util", - "@local_tsl//tsl/lib/io:zlib_compression_options", - "@local_tsl//tsl/lib/io:zlib_outputbuffer", "@local_tsl//tsl/platform:base64", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", @@ -5981,7 +4429,7 @@ xla_cc_test( "//xla:test", "//xla:test_helpers", "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/service/gpu:ir_emission_utils", @@ -5996,36 +4444,9 @@ xla_cc_test( cc_library( name = "zero_sized_hlo_elimination", - srcs = ["zero_sized_hlo_elimination.cc"], hdrs = ["zero_sized_hlo_elimination.h"], - deps = [ - "//xla:literal", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - ], -) - -xla_cc_test( - name = "zero_sized_hlo_elimination_test", - srcs = ["zero_sized_hlo_elimination_test.cc"], - deps = [ - ":zero_sized_hlo_elimination", - "//xla:literal_util", - "//xla:shape_util", - "//xla:test", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:zero_sized_hlo_elimination instead.", + deps = ["//xla/hlo/transforms:zero_sized_hlo_elimination"], ) cc_library( @@ -6033,7 +4454,7 @@ cc_library( srcs = ["stream_pool.cc"], hdrs = ["stream_pool.h"], deps = [ - "//xla/stream_executor", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/strings:str_format", ], ) @@ -6044,8 +4465,8 @@ xla_cc_test( deps = [ ":stream_pool", "//xla:test_helpers", - "//xla/stream_executor", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/host:host_platform", "//xla/tests:xla_internal_test_main", ], @@ -6088,12 +4509,12 @@ cc_library( deps = [ ":computation_placer", ":executable", - ":hlo_parser", "//xla:status_macros", "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", ], @@ -6109,7 +4530,6 @@ cc_library( ":computation_placer", ":executable", ":hlo_module_util", - ":hlo_parser", ":hlo_runner_interface", ":transfer_manager", "//xla:shape_util", @@ -6119,7 +4539,9 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", - "//xla/stream_executor", + "//xla/hlo/parser:hlo_parser", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", @@ -6143,7 +4565,7 @@ cc_library( "//xla:shape_util", "//xla:status_macros", "//xla:util", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", "//xla/pjrt:host_memory_spaces", "//xla/pjrt:pjrt_client", @@ -6167,67 +4589,22 @@ cc_library( ":human_readable_profile_builder", "//xla:types", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "sort_simplifier", - srcs = ["sort_simplifier.cc"], - hdrs = ["sort_simplifier.h"], - deps = [ - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - ], -) - -xla_cc_test( - name = "sort_simplifier_test", - srcs = ["sort_simplifier_test.cc"], - deps = [ - ":hlo_parser", - ":pattern_matcher", - ":pattern_matcher_gmock", - ":sort_simplifier", - "//xla:test", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/strings", ], ) cc_library( - name = "stable_sort_expander", - srcs = ["stable_sort_expander.cc"], - hdrs = ["stable_sort_expander.h"], - deps = [ - ":op_expander_pass", - "//xla/hlo/ir:hlo", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - ], + name = "sort_simplifier", + hdrs = ["sort_simplifier.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:sort_simplifier instead.", + deps = ["//xla/hlo/transforms:sort_simplifier"], ) -xla_cc_test( - name = "stable_sort_expander_test", - srcs = ["stable_sort_expander_test.cc"], - deps = [ - ":algebraic_simplifier", - ":hlo_parser", - ":pattern_matcher", - ":pattern_matcher_gmock", - ":stable_sort_expander", - "//xla:test", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - ], +cc_library( + name = "stable_sort_expander", + hdrs = ["stable_sort_expander.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:stable_sort_expander instead.", + deps = ["//xla/hlo/transforms:stable_sort_expander"], ) cc_library( @@ -6255,37 +4632,25 @@ xla_cc_test( srcs = ["tuple_util_test.cc"], deps = [ ":hlo_module_config", - ":hlo_parser", ":tuple_util", "//xla:shape_util", "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", ], ) cc_library( name = "root_instruction_sinker", - srcs = ["root_instruction_sinker.cc"], hdrs = ["root_instruction_sinker.h"], - deps = [ - ":tuple_util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - ], -) - -xla_cc_test( - name = "root_instruction_sinker_test", - srcs = ["root_instruction_sinker_test.cc"], - deps = [ - ":root_instruction_sinker", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:root_instruction_sinker instead.", + deps = ["//xla/hlo/transforms:root_instruction_sinker"], ) cc_library( @@ -6298,127 +4663,23 @@ cc_library( cc_library( name = "convert_memory_placement_to_internal_annotations", - srcs = ["convert_memory_placement_to_internal_annotations.cc"], hdrs = ["convert_memory_placement_to_internal_annotations.h"], - deps = [ - ":host_memory_offload_annotations_hdr", - "//xla:side_effect_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - ], -) - -xla_cc_test( - name = "convert_memory_placement_to_internal_annotations_test", - srcs = ["convert_memory_placement_to_internal_annotations_test.cc"], - deps = [ - ":convert_memory_placement_to_internal_annotations", - ":host_memory_offload_annotations_hdr", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status:statusor", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:convert_memory_placement_to_internal_annotations instead.", + deps = ["//xla/hlo/transforms:convert_memory_placement_to_internal_annotations"], ) cc_library( name = "host_memory_transfer_asyncifier", - srcs = ["host_memory_transfer_asyncifier.cc"], hdrs = ["host_memory_transfer_asyncifier.h"], - deps = [ - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "host_memory_transfer_asyncifier_test", - srcs = ["host_memory_transfer_asyncifier_test.cc"], - deps = [ - ":host_memory_transfer_asyncifier", - ":pattern_matcher", - ":pattern_matcher_gmock", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:host_memory_transfer_asyncifier instead.", + deps = ["//xla/hlo/transforms:host_memory_transfer_asyncifier"], ) cc_library( name = "host_offload_legalize", - srcs = ["host_offload_legalize.cc"], hdrs = ["host_offload_legalize.h"], - deps = [ - ":call_graph", - ":hlo_alias_analysis", - ":hlo_value", - ":host_memory_offload_annotations_hdr", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "host_offload_legalize_test", - srcs = ["host_offload_legalize_test.cc"], - shard_count = 12, - deps = [ - ":host_memory_offload_annotations_hdr", - ":host_offload_legalize", - ":pattern_matcher", - ":pattern_matcher_gmock", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:host_offload_legalize instead.", + deps = ["//xla/hlo/transforms:host_offload_legalize"], ) cc_library( @@ -6432,6 +4693,7 @@ cc_library( ":pattern_matcher", "//xla:literal_util", "//xla:shape_util", + "//xla:side_effect_util", "//xla:status_macros", "//xla:util", "//xla/hlo/ir:hlo", @@ -6475,102 +4737,16 @@ xla_cc_test( cc_library( name = "host_offloader", - srcs = ["host_offloader.cc"], hdrs = ["host_offloader.h"], - deps = [ - ":call_graph", - ":hlo_alias_analysis", - ":hlo_buffer", - ":hlo_cse", - ":hlo_value", - ":host_memory_offload_annotations_hdr", - ":host_offload_utils", - ":pattern_matcher", - "//xla:literal_util", - "//xla:shape_tree", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "host_offloader_test", - srcs = ["host_offloader_test.cc"], - shard_count = 12, - deps = [ - ":hlo_verifier", - ":host_memory_offload_annotations_hdr", - ":host_offload_legalize", - ":host_offloader", - ":pattern_matcher", - ":pattern_matcher_gmock", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:host_offloader instead.", + deps = ["//xla/hlo/transforms:host_offloader"], ) cc_library( name = "host_offloading_prepare", - srcs = ["host_offloading_prepare.cc"], hdrs = ["host_offloading_prepare.h"], - deps = [ - ":call_graph", - ":host_memory_offload_annotations_hdr", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "host_offloading_prepare_test", - srcs = ["host_offloading_prepare_test.cc"], - deps = [ - ":host_memory_offload_annotations_hdr", - ":host_offloading_prepare", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:host_offloading_prepare instead.", + deps = ["//xla/hlo/transforms:host_offloading_prepare"], ) cc_library( @@ -6580,6 +4756,7 @@ cc_library( deps = [ ":call_inliner", ":hlo_creation_utils", + ":pattern_matcher", ":tuple_util", "//xla:comparison_util", "//xla:literal_util", @@ -6590,7 +4767,9 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -6607,10 +4786,11 @@ xla_cc_test( "//xla:test", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:statusor", @@ -6624,10 +4804,10 @@ cc_library( deps = [ ":call_graph", ":collective_ops_utils", - ":hlo_replication_analysis", "//xla:literal_util", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_replication_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", @@ -6661,8 +4841,6 @@ cc_library( srcs = ["while_loop_concat_code_motion.cc"], hdrs = ["while_loop_concat_code_motion.h"], deps = [ - ":hlo_dce", - ":tuple_simplifier", ":while_loop_simplifier", "//xla:shape_util", "//xla:status_macros", @@ -6671,6 +4849,8 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms:tuple_simplifier", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -6707,13 +4887,13 @@ cc_library( "while_loop_invariant_code_motion.h", ], deps = [ - ":hlo_dce", - ":while_loop_analysis", ":while_util", "//xla:shape_util", "//xla:util", + "//xla/hlo/analysis:while_loop_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/transforms:hlo_dce", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -6731,13 +4911,13 @@ xla_cc_test( name = "while_loop_invariant_code_motion_test", srcs = ["while_loop_invariant_code_motion_test.cc"], deps = [ - ":hlo_parser", ":while_loop_invariant_code_motion", "//xla:literal_util", "//xla:shape_util", "//xla:test", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -6752,71 +4932,44 @@ cc_library( srcs = ["while_loop_expensive_invariant_code_motion.cc"], hdrs = ["while_loop_expensive_invariant_code_motion.h"], deps = [ - ":while_loop_analysis", ":while_util", "//xla:shape_util", "//xla:util", + "//xla/hlo/analysis:while_loop_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - ], -) - -xla_cc_test( - name = "while_loop_expensive_invariant_code_motion_test", - srcs = ["while_loop_expensive_invariant_code_motion_test.cc"], - deps = [ - ":hlo_parser", - ":while_loop_expensive_invariant_code_motion", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "fusion_constant_sinking", - srcs = ["fusion_constant_sinking.cc"], - hdrs = ["fusion_constant_sinking.h"], - deps = [ - ":hlo_dce", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:statusor", ], ) xla_cc_test( - name = "fusion_constant_sinking_test", - srcs = ["fusion_constant_sinking_test.cc"], + name = "while_loop_expensive_invariant_code_motion_test", + srcs = ["while_loop_expensive_invariant_code_motion_test.cc"], deps = [ - ":fusion_constant_sinking", - ":pattern_matcher", - ":pattern_matcher_gmock", - "//xla:test", + ":while_loop_expensive_invariant_code_motion", + "//xla:util", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", ], ) +cc_library( + name = "fusion_constant_sinking", + hdrs = ["fusion_constant_sinking.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:fusion_constant_sinking instead.", + deps = ["//xla/hlo/transforms:fusion_constant_sinking"], +) + cc_library( name = "while_loop_constant_sinking", srcs = ["while_loop_constant_sinking.cc"], @@ -6828,8 +4981,15 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", ], ) @@ -6838,11 +4998,13 @@ xla_cc_test( srcs = ["while_loop_constant_sinking_test.cc"], deps = [ ":while_loop_constant_sinking", + "//xla:literal_util", "//xla:test", + "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", ], ) @@ -6851,16 +5013,23 @@ cc_library( srcs = ["while_loop_fusible_sinking.cc"], hdrs = ["while_loop_fusible_sinking.h"], deps = [ + ":pattern_matcher", ":while_util", + "//xla:comparison_util", + "//xla:shape_util", "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:while_loop_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/utils:hlo_query", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:errors", @@ -6873,29 +5042,21 @@ xla_cc_test( srcs = ["while_loop_fusible_sinking_test.cc"], deps = [ ":while_loop_fusible_sinking", - "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/hlo/transforms:flatten_call_graph", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/log:check", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", ], ) cc_library( name = "despecializer", - srcs = ["despecializer.cc"], hdrs = ["despecializer.h"], - deps = [ - ":defuser", - ":float_normalization", - ":hlo_memory_scheduler", - ":sub_byte_normalization", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/pass:hlo_pass_pipeline", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/status:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:despecializer instead.", + deps = ["//xla/hlo/transforms:despecializer"], ) cc_library( @@ -6911,135 +5072,28 @@ cc_library( cc_library( name = "indexed_array_analysis", - srcs = ["indexed_array_analysis.cc"], hdrs = ["indexed_array_analysis.h"], - deps = [ - "//xla:literal", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/evaluator:hlo_evaluator", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "indexed_array_analysis_test", - srcs = ["indexed_array_analysis_test.cc"], - deps = [ - ":indexed_array_analysis", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/analysis:indexed_array_analysis instead.", + deps = ["//xla/hlo/analysis:indexed_array_analysis"], ) cc_library( name = "hlo_parser", - srcs = ["hlo_parser.cc"], hdrs = ["hlo_parser.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/parser:hlo_parser instead.", deps = [ - ":computation_layout", - ":hlo_lexer", - ":hlo_module_config", - ":hlo_proto_cc", - ":name_uniquer", - ":shape_inference", - "//xla:array", - "//xla:comparison_util", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_layout", - "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/ir:tile_assignment", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/functional:function_ref", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@eigen_archive//:eigen3", - "@local_tsl//tsl/lib/gtl:map_util", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - ], -) - -xla_cc_test( - name = "hlo_parser_test", - size = "small", - srcs = ["hlo_parser_test.cc"], - deps = [ - ":hlo_lexer", - ":hlo_module_config", - ":hlo_parser", - ":pattern_matcher", - ":pattern_matcher_gmock", - "//xla:array", - "//xla:shape_util", - "//xla:window_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/tests:verified_hlo_module", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", + "//xla/hlo/parser:hlo_parser", ], ) cc_library( name = "hlo_lexer", - srcs = ["hlo_lexer.cc"], hdrs = [ "hlo_lexer.h", ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/parser:hlo_lexer instead.", deps = [ - "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "@com_google_absl//absl/base", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:numbers", - "@local_tsl//tsl/platform:regexp", + "//xla/hlo/parser:hlo_lexer", ], ) @@ -7064,68 +5118,16 @@ cc_library( cc_library( name = "optimize_input_output_buffer_alias", - srcs = ["optimize_input_output_buffer_alias.cc"], hdrs = ["optimize_input_output_buffer_alias.h"], - deps = [ - "//xla:shape_util", - "//xla:status_macros", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - ], -) - -xla_cc_test( - name = "optimize_input_output_buffer_alias_test", - srcs = ["optimize_input_output_buffer_alias_test.cc"], - deps = [ - ":optimize_input_output_buffer_alias", - "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:test_utils", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:test", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:optimize_input_output_buffer_alias instead.", + deps = ["//xla/hlo/transforms:optimize_input_output_buffer_alias"], ) cc_library( name = "ar_crs_combiner", - srcs = ["ar_crs_combiner.cc"], hdrs = ["ar_crs_combiner.h"], - deps = [ - ":call_graph", - ":hlo_replication_analysis", - ":pattern_matcher", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "//xla/hlo/utils:hlo_query", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:ar_crs_combiner instead.", + deps = ["//xla/hlo/transforms:ar_crs_combiner"], ) cc_library( @@ -7143,48 +5145,9 @@ cc_library( cc_library( name = "dynamic_index_splitter", - srcs = ["dynamic_index_splitter.cc"], hdrs = ["dynamic_index_splitter.h"], - deps = [ - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], -) - -xla_cc_test( - name = "dynamic_index_splitter_test", - srcs = ["dynamic_index_splitter_test.cc"], - deps = [ - ":dynamic_index_splitter", - "//xla:test", - "//xla:test_helpers", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - ], -) - -xla_cc_test( - name = "ar_crs_combiner_test", - srcs = ["ar_crs_combiner_test.cc"], - deps = [ - ":ar_crs_combiner", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:dynamic_index_splitter instead.", + deps = ["//xla/hlo/transforms:dynamic_index_splitter"], ) xla_cc_test( @@ -7249,14 +5212,9 @@ xla_cc_test( cc_library( name = "slice_sinker", - srcs = ["slice_sinker.cc"], hdrs = ["slice_sinker.h"], - deps = [ - "//xla:shape_util", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/types:span", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:slice_sinker instead.", + deps = ["//xla/hlo/transforms:slice_sinker"], ) cc_library( @@ -7355,54 +5313,18 @@ cc_library( deps = [":custom_call_status"], ) -xla_cc_test( - name = "slice_sinker_test", - srcs = ["slice_sinker_test.cc"], - deps = [ - ":hlo_dce", - ":hlo_parser", - ":pattern_matcher", - ":pattern_matcher_gmock", - ":slice_sinker", - "//xla:literal_util", - "//xla:shape_util", - "//xla:types", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - ], -) - cc_library( name = "rng_expander", - srcs = ["rng_expander.cc"], - hdrs = ["rng_expander.h"], - deps = [ - ":hlo_creation_utils", - ":op_expander_pass", - "//xla:literal_util", - "//xla:shape_util", - "//xla/client:xla_builder", - "//xla/client/lib:prng", - ], + hdrs = ["rng_expander.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:rng_expander instead.", + deps = ["//xla/hlo/transforms:rng_expander"], ) cc_library( name = "rng_bit_generator_expander", - srcs = ["rng_bit_generator_expander.cc"], hdrs = ["rng_bit_generator_expander.h"], - deps = [ - ":op_expander_pass", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/client/lib:prng", - "//xla/hlo/ir:hlo", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:rng_bit_generator_expander instead.", + deps = ["//xla/hlo/transforms:rng_bit_generator_expander"], ) cc_library( @@ -7448,37 +5370,9 @@ cc_library( cc_library( name = "collective_transformation_reorderer", - srcs = ["collective_transformation_reorderer.cc"], hdrs = ["collective_transformation_reorderer.h"], - deps = [ - ":hlo_dce", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "collective_transformation_reorderer_test", - srcs = ["collective_transformation_reorderer_test.cc"], - deps = [ - ":collective_transformation_reorderer", - ":hlo_verifier", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:collective_transformation_reorderer instead.", + deps = ["//xla/hlo/transforms:collective_transformation_reorderer"], ) xla_cc_test( @@ -7488,10 +5382,10 @@ xla_cc_test( ":collective_ops_utils", ":computation_placer", ":global_device_id", - ":hlo_parser", "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", @@ -7507,11 +5401,12 @@ cc_library( srcs = ["topk_rewriter.cc"], hdrs = ["topk_rewriter.h"], deps = [ + ":hlo_creation_utils", ":pattern_matcher", "//xla:shape_util", "//xla:util", - "//xla/client:xla_builder", - "//xla/client/lib:comparators", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder/lib:comparators", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "@com_google_absl//absl/algorithm:container", @@ -7525,12 +5420,12 @@ xla_cc_test( name = "topk_rewriter_test", srcs = ["topk_rewriter_test.cc"], deps = [ - ":hlo_dce", ":pattern_matcher", ":pattern_matcher_gmock", ":topk_rewriter", - ":tuple_simplifier", "//xla/hlo/ir:hlo", + "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms:tuple_simplifier", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", @@ -7544,67 +5439,15 @@ xla_cc_test( cc_library( name = "operand_upcaster", - srcs = ["operand_upcaster.cc"], hdrs = ["operand_upcaster.h"], - deps = [ - ":hlo_creation_utils", - ":op_expander_pass", - ":shape_inference", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "operand_upcaster_test", - srcs = ["operand_upcaster_test.cc"], - deps = [ - ":operand_upcaster", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:statusor", - ], + deps = ["//xla/hlo/transforms:operand_upcaster"], ) cc_library( name = "result_caster", - srcs = ["result_caster.cc"], hdrs = ["result_caster.h"], - deps = [ - ":op_expander_pass", - ":shape_inference", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - ], -) - -xla_cc_test( - name = "result_caster_test", - srcs = ["result_caster_test.cc"], - deps = [ - ":result_caster", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:result_caster instead.", + deps = ["//xla/hlo/transforms:result_caster"], ) cc_library( @@ -7613,40 +5456,17 @@ cc_library( hdrs = ["global_device_id.h"], deps = [ "//xla:types", + "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/gtl:int_type", ], ) cc_library( name = "convert_operand_folding", - srcs = ["convert_operand_folding.cc"], hdrs = ["convert_operand_folding.h"], - deps = [ - ":op_expander_pass", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - ], -) - -xla_cc_test( - name = "convert_operand_folding_test", - srcs = ["convert_operand_folding_test.cc"], - deps = [ - ":convert_operand_folding", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:convert_operand_folding instead.", + deps = ["//xla/hlo/transforms:convert_operand_folding"], ) cc_library( @@ -7687,7 +5507,6 @@ xla_cc_test( xla_py_proto_library( name = "hlo_pb2", - api_version = 2, visibility = ["//visibility:public"], deps = [":hlo_proto"], ) @@ -7887,18 +5706,9 @@ cc_library( cc_library( name = "instruction_hoister", - srcs = ["instruction_hoister.cc"], hdrs = ["instruction_hoister.h"], - deps = [ - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:status", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:instruction_hoister instead.", + deps = ["//xla/hlo/transforms:instruction_hoister"], ) cc_library( @@ -7909,12 +5719,12 @@ cc_library( ":call_inliner", ":gather_scatter_utils", ":hlo_creation_utils", - ":op_expander_pass", "//xla:permutation_util", "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/transforms:op_expander_pass", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/types:span", @@ -7926,9 +5736,9 @@ xla_cc_test( name = "scatter_simplifier_test", srcs = ["scatter_simplifier_test.cc"], deps = [ - ":hlo_parser", ":scatter_simplifier", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", "//xla/tests:hlo_test_base", @@ -7942,9 +5752,9 @@ cc_library( hdrs = ["select_and_scatter_expander.h"], deps = [ ":call_inliner", - ":op_expander_pass", "//xla:literal_util", "//xla/hlo/ir:hlo", + "//xla/hlo/transforms:op_expander_pass", ], ) @@ -7994,28 +5804,21 @@ cc_library( hdrs = ["gather_scatter_utils.h"], deps = [ ":hlo_creation_utils", + "//xla:literal_util", "//xla:permutation_util", "//xla:shape_util", + "//xla:util", "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/types:span", ], ) cc_library( name = "gather_simplifier", - srcs = ["gather_simplifier.cc"], hdrs = ["gather_simplifier.h"], - deps = [ - ":gather_scatter_utils", - ":hlo_creation_utils", - ":op_expander_pass", - "//xla:literal_util", - "//xla:permutation_util", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "@com_google_absl//absl/algorithm:container", - "@local_tsl//tsl/platform:statusor", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:gather_simplifier instead.", + deps = ["//xla/hlo/transforms:gather_simplifier"], ) cc_library( @@ -8023,10 +5826,11 @@ cc_library( srcs = ["batched_gather_scatter_normalizer.cc"], hdrs = ["batched_gather_scatter_normalizer.h"], deps = [ - ":op_expander_pass", "//xla:shape_util", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/transforms:op_expander_pass", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -8039,73 +5843,16 @@ cc_library( cc_library( name = "reduce_window_rewriter", - srcs = ["reduce_window_rewriter.cc"], hdrs = ["reduce_window_rewriter.h"], - deps = [ - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla:window_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "reduce_window_rewriter_test", - srcs = ["reduce_window_rewriter_test.cc"], - deps = [ - ":reduce_window_rewriter", - "//xla:test", - "//xla:xla_data_proto_cc", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:test_main", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:reduce_window_rewriter instead.", + deps = ["//xla/hlo/transforms:reduce_window_rewriter"], ) cc_library( name = "stochastic_convert_decomposer", - srcs = ["stochastic_convert_decomposer.cc"], hdrs = ["stochastic_convert_decomposer.h"], - deps = [ - ":hlo_creation_utils", - ":shape_inference", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "stochastic_convert_decomposer_test", - srcs = ["stochastic_convert_decomposer_test.cc"], - deps = [ - ":hlo_parser", - ":stochastic_convert_decomposer", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/tests:hlo_test_base", - "@local_tsl//tsl/platform:test_main", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:stochastic_convert_decomposer instead.", + deps = ["//xla/hlo/transforms:stochastic_convert_decomposer"], ) cc_library( @@ -8121,46 +5868,17 @@ cc_library( cc_library( name = "sub_byte_normalization", - srcs = ["sub_byte_normalization.cc"], hdrs = ["sub_byte_normalization.h"], - deps = [ - "//xla:shape_layout", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:sub_byte_normalization instead.", + deps = ["//xla/hlo/transforms:sub_byte_normalization"], ) cc_library( name = "sharding_format_picker", testonly = True, - srcs = ["sharding_format_picker.cc"], hdrs = ["sharding_format_picker.h"], - deps = [ - "//xla/hlo/ir:hlo", - "//xla/hlo/ir:tile_assignment", - "//xla/hlo/pass:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - ], -) - -xla_cc_test( - name = "gather_simplifier_test", - srcs = ["gather_simplifier_test.cc"], - deps = [ - ":gather_simplifier", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:sharding_format_picker instead.", + deps = ["//xla/hlo/transforms:sharding_format_picker"], ) xla_cc_test( @@ -8342,9 +6060,9 @@ xla_cc_test( ":xla_aot_compile_test_gpu_executable_convolution_runtime_autotuning", ]), tags = [ + "cuda-only", "gpu", "no_oss", - "no_rocm", "nomsan", # Pulls in precompiled NVIDIA libraries which cause false positives in msan. "requires-gpu-sm60-only", ], @@ -8453,7 +6171,7 @@ tf_proto_library( make_default_target_header_only = True, protodeps = [ ":hlo_proto", - "@local_tsl//tsl/protobuf:status_proto", + "//xla/tsl/protobuf:status_proto", ] + if_google(["@com_google_protobuf//:duration"]), visibility = ["//visibility:public"], ) @@ -8464,8 +6182,8 @@ cc_library( hdrs = ["algorithm_util.h"], deps = [ "//xla:xla_data_proto_cc", - "//xla/stream_executor", "//xla/stream_executor:blas", + "//xla/stream_executor:device_description", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", @@ -8474,52 +6192,71 @@ cc_library( cc_library( name = "add_original_value", - srcs = ["add_original_value.cc"], hdrs = ["add_original_value.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:add_original_value instead.", + deps = ["//xla/hlo/transforms:add_original_value"], +) + +xla_cc_test( + name = "propagate_original_value_test", + srcs = ["propagate_original_value_test.cc"], deps = [ - "//xla:shape_util", + ":instruction_fusion", + "//xla:xla_data_proto_cc", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "infeed_token_propagation", + hdrs = ["infeed_token_propagation.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/transforms:infeed_token_propagation instead.", + deps = ["//xla/hlo/transforms:infeed_token_propagation"], +) + +cc_library( + name = "while_loop_pipeline_unroller", + srcs = ["while_loop_pipeline_unroller.cc"], + hdrs = ["while_loop_pipeline_unroller.h"], + deps = [ + ":call_inliner", + ":while_util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/transforms:flatten_call_graph", + "//xla/hlo/transforms:hlo_dce", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) xla_cc_test( - name = "add_original_value_test", - srcs = ["add_original_value_test.cc"], + name = "while_loop_pipeline_unroller_test", + srcs = ["while_loop_pipeline_unroller_test.cc"], deps = [ - ":add_original_value", - ":pattern_matcher", - ":pattern_matcher_gmock", - "//xla:shape_util", - "//xla:window_util", - "//xla:xla_data_proto_cc", + ":copy_insertion", + ":while_loop_pipeline_unroller", + "//xla:test_helpers", + "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", - "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/container:inlined_vector", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) -xla_cc_test( - name = "propagate_original_value_test", - srcs = ["propagate_original_value_test.cc"], - deps = [ - ":instruction_fusion", - "//xla:xla_data_proto_cc", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - ], +cc_library( + name = "collective_utils", + hdrs = ["collective_utils.h"], ) exports_files(["xla_aot_compile_test_gpu_target_config.prototxt"]) diff --git a/third_party/xla/xla/service/add_original_value.h b/third_party/xla/xla/service/add_original_value.h index 8dd1655e0eb0e8..2a68cca88b0e94 100644 --- a/third_party/xla/xla/service/add_original_value.h +++ b/third_party/xla/xla/service/add_original_value.h @@ -16,23 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_ADD_ORIGINAL_VALUE_H_ #define XLA_SERVICE_ADD_ORIGINAL_VALUE_H_ -#include "absl/status/statusor.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// This pass adds to each op in the HLO graph the original_value attribute, -// which is used for HLO value tracking. See go/hlo-value-tracking for more -// details. -class AddOriginalValue : public HloModulePass { - public: - absl::string_view name() const override { return "add-original-value"; } - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/add_original_value.h" #endif // XLA_SERVICE_ADD_ORIGINAL_VALUE_H_ diff --git a/third_party/xla/xla/service/algebraic_simplifier.h b/third_party/xla/xla/service/algebraic_simplifier.h index 9ff96b248d398a..82fc943903041e 100644 --- a/third_party/xla/xla/service/algebraic_simplifier.h +++ b/third_party/xla/xla/service/algebraic_simplifier.h @@ -16,697 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_ #define XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_ -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/literal.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/util.h" - -namespace xla { - -class AlgebraicSimplifierOptions { - public: - // Platform dependent callback to determine if a reshape `from_shape` to - // `to_shape` is a bitcast. - using ReshapeIsBitcastCallback = - std::function; - // Platform dependent callback to determine if a set of reverse dimensions is - // lowerable - using ConvIsLowerableCallback = std::function; - - explicit AlgebraicSimplifierOptions( - ReshapeIsBitcastCallback reshape_is_bitcast_callback = {}, - ConvIsLowerableCallback conv_is_lowerable_callback = {}) - : reshape_is_bitcast_callback_(std::move(reshape_is_bitcast_callback)), - conv_is_lowerable_callback_(std::move(conv_is_lowerable_callback)) {} - - // Use the platform specific callback if set. It is not sensible to return - // true here if the options are not layout sensitive. - bool ReshapeIsBitcast(const Shape& from_shape, const Shape& to_shape) const { - if (!is_layout_sensitive_) { - return false; - } - if (!reshape_is_bitcast_callback_) { - return ShapeUtil::ReshapeIsBitcast(from_shape, to_shape); - } - return reshape_is_bitcast_callback_(from_shape, to_shape); - } - - // Use the platform specific callback if set. Otherwise, return true. - bool ConvIsLowerable(HloInstruction* reverse_dims) const { - if (!conv_is_lowerable_callback_) { - return true; - } - return conv_is_lowerable_callback_(reverse_dims); - } - - void set_conv_is_lowerable_callback( - ConvIsLowerableCallback conv_is_lowerable_callback) { - conv_is_lowerable_callback_ = std::move(conv_is_lowerable_callback); - } - - // If is_layout_sensitive is true, then the simplifier preserves layout during - // transformation. Otherwise, layout is ignored. - void set_is_layout_sensitive(bool is_layout_sensitive) { - is_layout_sensitive_ = is_layout_sensitive; - } - - bool is_layout_sensitive() const { return is_layout_sensitive_; } - - void set_use_associative_reordering(bool use_associative_reordering) { - use_associative_reordering_ = use_associative_reordering; - } - - bool use_associative_reordering() const { - return use_associative_reordering_; - } - - void set_associative_reordering_threshold( - double associative_reordering_threshold) { - associative_reordering_threshold_ = associative_reordering_threshold; - } - - double associative_reordering_threshold() const { - return associative_reordering_threshold_; - } - - void set_use_convert_constant_folding(bool use_convert_constant_folding) { - use_convert_constant_folding_ = use_convert_constant_folding; - } - - bool use_convert_constant_folding() const { - return use_convert_constant_folding_; - } - - void set_raise_slice_and_reduce_through_dot( - bool raise_slice_and_reduce_through_dot) { - raise_slice_and_reduce_through_dot_ = raise_slice_and_reduce_through_dot; - } - - bool raise_slice_and_reduce_through_dot() const { - return raise_slice_and_reduce_through_dot_; - } - - void set_raise_slice_and_reduce_through_dot_threshold( - double raise_slice_and_reduce_through_dot_threshold) { - raise_slice_and_reduce_through_dot_threshold_ = - raise_slice_and_reduce_through_dot_threshold; - } - - double raise_slice_and_reduce_through_dot_threshold() const { - return raise_slice_and_reduce_through_dot_threshold_; - } - - // Enable dot simplification on platforms where it is profitable. - void set_enable_dot_strength_reduction(bool enable_dot_strength_reduction) { - enable_dot_strength_reduction_ = enable_dot_strength_reduction; - } - - bool enable_dot_strength_reduction() const { - return enable_dot_strength_reduction_; - } - - // Enable dot->multiple rewrite for dot as an outer-product - void set_enable_dot_to_multiply_rewrite(bool enable_dot_to_multiply_rewrite) { - enable_dot_to_multiply_rewrite_ = enable_dot_to_multiply_rewrite; - } - - bool enable_dot_to_multiply_rewrite() const { - return enable_dot_to_multiply_rewrite_; - } - - void set_enable_move_dot_param_to_rhs(bool enable_move_dot_param_to_rhs) { - enable_move_dot_param_to_rhs_ = enable_move_dot_param_to_rhs; - } - - bool enable_move_dot_param_to_rhs() const { - return enable_move_dot_param_to_rhs_; - } - - // This platform will not run the DotDecomposer to canonicalize dots. - void set_supports_non_canonical_dots(bool supports_non_canonical_dots) { - supports_non_canonical_dots_ = supports_non_canonical_dots; - } - bool supports_non_canonical_dots() const { - return supports_non_canonical_dots_; - } - - // Enable convolution simplification on platforms where it is profitable. - void set_enable_conv_simplification(bool enable_conv_simplification) { - enable_conv_simplification_ = enable_conv_simplification; - } - bool enable_conv_simplification() const { - return enable_conv_simplification_; - } - - // Enable convolution operand swapping on platforms where it is supported. - void set_enable_conv_operand_swap(bool enable_conv_operand_swap) { - enable_conv_operand_swap_ = enable_conv_operand_swap; - } - bool enable_conv_operand_swap() const { return enable_conv_operand_swap_; } - - // Move constant scalar multiply to one operand or output of convolutions with - // the smallest tensor size, to reduce the number of scalar multiply. - void set_enable_scalar_multiply_reduction( - bool enable_scalar_multiply_reduction) { - enable_scalar_multiply_reduction_ = enable_scalar_multiply_reduction; - } - - bool enable_scalar_multiply_reduction() const { - return enable_scalar_multiply_reduction_; - } - - // Also the algebraic simplifer to treat floating point values like real - // numbers. - void set_enable_floats_are_real(bool enable_floats_are_real) { - enable_floats_are_real_ = enable_floats_are_real; - } - - bool enable_floats_are_real() const { return enable_floats_are_real_; } - - // If enable_window_reduce_replacement is true, the kReduceWindow instruction - // can be optimized by replacement with simpler operations. - void set_enable_window_reduce_to_reduce_replacement( - bool enable_window_reduce_to_reduce_replacement) { - enable_window_reduce_to_reduce_replacement_ = - enable_window_reduce_to_reduce_replacement; - } - - bool enable_window_reduce_to_reduce_replacement() const { - return enable_window_reduce_to_reduce_replacement_; - } - - // Sets the size of a gather operand that can be unrolled into many selects. - void set_very_small_gather_size(int64_t size) { - very_small_gather_size_ = size; - } - - int64_t very_small_gather_size() const { return very_small_gather_size_; } - - void set_cudnn_batchnorm_forward_training_metadata(const std::string& c) { - metadata_.cudnn_batchnorm_forward_training_metadata = c; - } - - const std::string& get_cudnn_batchnorm_forward_training_metadata() const { - return metadata_.cudnn_batchnorm_forward_training_metadata; - } - - void set_enable_reduce_of_reshape(bool enable_reduce_of_reshape) { - enable_reduce_of_reshape_ = enable_reduce_of_reshape; - } - - bool enable_reduce_of_reshape() const { return enable_reduce_of_reshape_; } - - void set_enable_negative_padding_replacement( - bool enable_negative_padding_replacement) { - enable_negative_padding_replacement_ = enable_negative_padding_replacement; - } - - bool enable_negative_padding_replacement() const { - return enable_negative_padding_replacement_; - } - - void set_enable_sink_broadcast(bool enable_sink_broadcast) { - enable_sink_broadcast_ = enable_sink_broadcast; - } - - bool enable_sink_broadcast() const { return enable_sink_broadcast_; } - - // If true, always simplify reduce(transpose(x)) and reduce(reshape(x)), even - // if the transpose/reshape has multiple users. This can be beneficial - // on platforms where the extra transpose/reshape isn't as expensive as - // the optimization benefits brought about by simplifying the graph. - bool unconditionally_simplify_reduce_of_transpose_or_reshape() const { - return unconditionally_simplify_reduce_of_transpose_or_reshape_; - } - void set_unconditionally_simplify_reduce_of_transpose_or_reshape(bool val) { - unconditionally_simplify_reduce_of_transpose_or_reshape_ = val; - } - - // If true, min(x, NaN) = NaN. If false, min(x, NaN) = x. - // - // TODO(b/209827141): Remove this and make minmax_propagate_nan - // unconditionally true. - bool minmax_propagate_nan() const { return minmax_propagate_nan_; } - void set_minmax_propagate_nan(bool val) { minmax_propagate_nan_ = val; } - - // When true, always replaces Reduce(concat({a,b,...})) with - // map(reduce(a),map(reduce(b),...,)). If false, only does the replacement if - // the shapes of a,b,... have the same dimensions. - bool enable_unconditional_reduce_of_concat_replacement() const { - return enable_unconditional_reduce_of_concat_replacement_; - } - void set_enable_unconditional_reduce_of_concat_replacement( - bool enable_unconditional_reduce_of_concat_replacement) { - enable_unconditional_reduce_of_concat_replacement_ = - enable_unconditional_reduce_of_concat_replacement; - } - - // Indicates whether running on CPU - bool executing_on_cpu() const { return executing_on_cpu_; } - void set_executing_on_cpu(bool executing_on_cpu) { - executing_on_cpu_ = executing_on_cpu; - } - - // Option to disable conversion of dynamic-slice to slice. - void set_disable_dynamic_slice_to_slice_conversion(bool disable) { - disable_dynamic_slice_to_slice_conversion_ = disable; - } - bool disable_dynamic_slice_to_slice_conversion() const { - return disable_dynamic_slice_to_slice_conversion_; - } - - private: - // Metadata struct can be used to store any metadata information encapsulated - // with the AlgebraicSimplifierOptions that can be later used in an - // AlgebraicSimplifier pass. For example, - // cudnn_batchnorm_forward_training_metadata can be used to store the name of - // a custom call. If the custom call is - // __cudnn$batchNormalizationForwardTraining, the output with index 2 is - // guaranteed to be positive. This property has been used to recursively - // determine if the operand of an instruction is always positive. - struct Metadata { - std::string cudnn_batchnorm_forward_training_metadata{""}; - Metadata() {} - }; - ReshapeIsBitcastCallback reshape_is_bitcast_callback_; - ConvIsLowerableCallback conv_is_lowerable_callback_; - bool is_layout_sensitive_{false}; - bool enable_dot_strength_reduction_{true}; - bool supports_non_canonical_dots_{true}; - bool enable_dot_to_multiply_rewrite_{true}; - bool enable_move_dot_param_to_rhs_{false}; - bool enable_conv_simplification_{true}; - bool enable_conv_operand_swap_{true}; - bool enable_scalar_multiply_reduction_{false}; - bool enable_floats_are_real_{false}; - bool enable_window_reduce_to_reduce_replacement_{true}; - bool enable_reduce_of_reshape_{true}; - bool enable_negative_padding_replacement_{true}; - bool enable_sink_broadcast_{true}; - bool unconditionally_simplify_reduce_of_transpose_or_reshape_{false}; - int64_t very_small_gather_size_{4}; - bool minmax_propagate_nan_{true}; - bool enable_unconditional_reduce_of_concat_replacement_{true}; - bool executing_on_cpu_{false}; - bool use_associative_reordering_{false}; - double associative_reordering_threshold_{2.0}; - bool raise_slice_and_reduce_through_dot_{false}; - double raise_slice_and_reduce_through_dot_threshold_{2.0}; - bool use_convert_constant_folding_{false}; - bool disable_dynamic_slice_to_slice_conversion_{false}; - Metadata metadata_; -}; - -// A pass which performs algebraic simplifications. -class AlgebraicSimplifier : public HloModulePass { - public: - // If is_layout_sensitive is true, then the simplifier preserves layout during - // transformation. Otherwise, layout is ignored. - explicit AlgebraicSimplifier(const AlgebraicSimplifierOptions& options) - : options_(options) {} - ~AlgebraicSimplifier() override = default; - absl::string_view name() const override { return "algsimp"; } - - // Run algebraic simplification on the given computation. Returns whether the - // computation was changed. - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - // Create constant from literal with tiles and element size updated in the - // constant's layout. - std::unique_ptr CreateConstantWithLayoutUpdated( - Literal literal) { - auto constant = HloInstruction::CreateConstant(std::move(literal)); - UpdateLayout(constant->mutable_shape()); - return constant; - } - - protected: - AlgebraicSimplifierOptions options_; -}; - -// AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain -// algebraic expressions to simplified forms. Note: This only supports -// simplifications that simply look at the operands of an instruction. For the -// more general case a worklist based approach would be needed. -class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { - public: - explicit AlgebraicSimplifierVisitor(const AlgebraicSimplifierOptions& options, - AlgebraicSimplifier* simplifier) - : options_(options), simplifier_(simplifier) {} - - absl::Status HandleAbs(HloInstruction* abs) override; - - absl::Status HandleAdd(HloInstruction* add) override; - - absl::Status HandleAllToAll(HloInstruction* all_to_all) override; - - absl::Status HandleAnd(HloInstruction* logical_and) override; - - absl::Status HandleBitcast(HloInstruction* bitcast) override; - - absl::Status HandleBitcastConvert(HloInstruction* bitcast) override; - - absl::Status HandleBroadcast(HloInstruction* broadcast) override; - - absl::Status HandleCompare(HloInstruction* compare) override; - - absl::Status HandleConcatenate(HloInstruction* concatenate) override; - - absl::Status HandleConstant(HloInstruction* constant) override; - - absl::Status HandleCopy(HloInstruction* copy) override; - - absl::Status HandleConvert(HloInstruction* convert) override; - - absl::Status HandleComplex(HloInstruction* complex) override; - - absl::Status HandleCustomCall(HloInstruction* custom_call) override; - - absl::Status HandleReal(HloInstruction* real) override; - - absl::Status HandleImag(HloInstruction* imag) override; - - absl::Status HandleIota(HloInstruction* instruction) override; - - absl::Status HandleConvolution(HloInstruction* convolution) override; - - absl::Status HandleDivide(HloInstruction* divide) override; - - absl::Status HandleDot(HloInstruction* dot) override; - - absl::Status HandleGather(HloInstruction* gather) override; - - absl::Status HandleGetTupleElement( - HloInstruction* get_tuple_element) override; - - absl::Status HandleLog(HloInstruction* log) override; - - absl::Status HandleMaximum(HloInstruction* maximum) override; - - absl::Status HandleMinimum(HloInstruction* minimum) override; - - absl::Status HandleClamp(HloInstruction* clamp) override; - - absl::Status HandleMultiply(HloInstruction* multiply) override; - - absl::Status HandleNegate(HloInstruction* negate) override; - - absl::Status HandleNot(HloInstruction* logical_not) override; - - absl::Status HandleOptimizationBarrier(HloInstruction* barrier) override; - - absl::Status HandleOr(HloInstruction* logical_or) override; - - absl::Status HandlePad(HloInstruction* pad) override; - - absl::Status HandlePower(HloInstruction* power) override; - - absl::Status HandleRemainder(HloInstruction* remainder) override; - - absl::Status HandleReshape(HloInstruction* reshape) override; - - absl::Status HandleReduce(HloInstruction* hlo) override; - - absl::Status HandleReduceWindow(HloInstruction* hlo) override; - - absl::Status HandleReverse(HloInstruction* reverse) override; - - absl::Status HandleRsqrt(HloInstruction* rsqrt) override; - - absl::Status HandleSlice(HloInstruction* slice) override; - - absl::Status HandleSqrt(HloInstruction* sqrt) override; - - absl::Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; - - absl::Status HandleDynamicUpdateSlice( - HloInstruction* dynamic_update_slice) override; - absl::Status HandleScatter(HloInstruction* hlo) override; - - absl::Status HandleSelect(HloInstruction* select) override; - - absl::Status HandleSort(HloInstruction* sort) override; - - absl::Status HandleTranspose(HloInstruction* transpose) override; - - absl::Status HandleSubtract(HloInstruction* sub) override; - - absl::Status HandleMap(HloInstruction* map) override; - - // Runs the visitor on a computation. - bool Run(HloComputation* computation, - const AlgebraicSimplifierOptions& options, - AlgebraicSimplifier* simplifier); - - // Compute a function that maps from bitcasted dimensions to the resulting - // ones. Returns the function as a vector if successful; std::optional - // otherwise. - static std::optional>> ComputeBitcastDimMap( - const Shape& bitcast_shape, const Shape& operand_shape); - // Invert the directions of the given bitcast dimension map. - static std::vector> InvertBitcastDimMap( - const Shape& original_shape, const Shape& bitcast_shape, - const std::vector>& original_map); - - // Checks if the output of a given instruction is guaranteed to be - // non-negative. e.g. abs - static bool IsNonNegative(const HloInstruction* hlo, - const AlgebraicSimplifierOptions& options); - - // Check if the opcode of a given instruction is a non-decreasing function - // asymptotically satisfying |f(x)| <= |x| - static bool IsNondecreasingSublinear(const HloInstruction* hlo); - - // Modify the layout dimensions of result_shape, so that it becomes the - // re-shaped result of applying bitcast to the original_shape, by using - // dim_map to re-shape layout dimensions of original_shape. Returns the - // result_shape with modified layout if the conversion succeeds; Returns - // std::nullopt if fails. - static std::optional ReshapeLayoutDimensions( - const Shape& original_shape, const Shape& result_shape, - const std::vector>& original_map, - const std::vector>& result_map); - - // Allow backend constraints on tiling etc. to invalidate optimizations. - virtual bool IsValidLayout(const Shape& shape) { return true; } - // Allow backend targets to determine whether a layout is inefficient. - virtual bool ShouldStrengthReduceDotToReduce(const HloInstruction* hlo) { - return true; - } - - protected: - // The backend-specific options selected for the algebraic simplifier. - const AlgebraicSimplifierOptions& options_; - - private: - // Removes degenerate dimension from dot. - absl::StatusOr RemoveDegenerateDimensionFromDot(HloDotInstruction* dot); - - // Moves the transpose to the broadcast if possible. Can also be called with a - // bitcast transpose. - absl::Status SimplifyTransposeOfBroadcast( - HloInstruction* transpose, absl::Span dimensions); - - // Converts to primitive type if the input hlo is not that type, otherwise - // returns the original hlo. - HloInstruction* AsType(HloInstruction* hlo, - const PrimitiveType element_type) { - if (hlo->shape().element_type() == element_type) { - return hlo; - } - Shape changed_shape = - ShapeUtil::ChangeElementType(hlo->shape(), element_type); - simplifier_->UpdateLayout(&changed_shape); - return computation_->AddInstruction( - HloInstruction::CreateConvert(changed_shape, hlo)); - } - - // Transposes a dot operand such that the batch dimensions are the most major, - // and the contracting dimensions are most minor. - absl::StatusOr - NormalizeDotOperandToBatchMajorAndContractingMinor( - HloInstruction* dot_operand, absl::Span batch_dimensions, - absl::Span contracting_dimensions); - - // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)) (or - // transpose(dot(a,b)) if only the batch dims are transposed). - // - // Requires the dot has been canonicalized by DotDecomposer into - // - // LHS [batch dims..., non-contracting dim, contracting dim] - // RHS [batch dims..., contracting dim, non-contracting dim]. - absl::StatusOr RemoveTransposesFromDotOperands(HloDotInstruction* dot); - - // Swap the operands of dots, if one operand is "parameter-like" (i.e. a - // parameter, or a pointwise transformation of a parameter), so the - // "parameter-like" operand (e.g. a weight tensor) is placed on the RHS. - absl::StatusOr MoveDotParamToRhs(HloDotInstruction* dot); - - // Helper method to perform and add reduction on a list of dimensions. - HloInstruction* AddReduce(HloInstruction* hlo, absl::Span dims, - PrimitiveType type); - - // Move scalar multiply to the smallest side of convolution to - // reduce multiply computations. - absl::Status ScalarMultiplyReduction(HloInstruction* dot); - - // Convenience method for replacing an instruction with a bitcast. If operand - // is not null, then the bitcast will use the specified operand instead of the - // operand of the instruction. - void ReplaceWithBitcast(HloInstruction* instruction, - HloInstruction* operand = nullptr); - - // Change copy(bitcast...(copy)) into copy(bitcast) or bitcast(copy) so that - // the replicated copies are combined when allowed by layout/tiling assignment - // constraints. - bool SwapCopyBitcastCopy(HloInstruction* root_copy); - - // Replace old instruction with new instruction if old and new instructions - // are compatible (have the same shape and replacement preserves sharding). - // Updates uses and root instruction. Returns whether a replacement was made. - bool ReplaceInstructionIfCompatible(HloInstruction* old_instruction, - HloInstruction* new_instruction); - // Similar to above but tuplizes `new_instructions` if there are more than 1 - // instructions. - bool ReplaceInstructionIfCompatible( - HloInstruction* old_instruction, - absl::Span new_instructions); - - // Returns whether the shape of the output of the given instructions are the - // same for the purposes of simplification. If options_.is_layout_sensitive() - // is true, then this tests shape equality including layout - // (ShapeUtil::Equal). If options_.is_layout_sensitive() is false, then the - // tests shape compatibility (ShapeUtil::Compatible). - bool SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const; - - // Same as above but takes shape arguments directly. - bool SameShape(const Shape& lhs, const Shape& rhs) const; - - // A Broadcast that feeds an element-wise operation with a unique non-scalar - // operand can sink to after the operation. - absl::StatusOr TryToSinkBroadcastAfterOpWithUniqueNonScalarOperand( - HloInstruction* broadcast); - - absl::StatusOr OptimizeDotOfConcat(HloInstruction* dot); - absl::StatusOr OptimizeDotOfConcatHelper( - HloInstruction* dot, HloInstruction* lhs, int64_t lhs_contracting_dim, - HloInstruction* rhs, int64_t rhs_contracting_dim, bool swapped); - - absl::StatusOr OptimizeDotOfGather(HloInstruction* dot); - - absl::StatusOr OptimizeDotOfReorderContractingDims( - HloInstruction* dot); - - absl::StatusOr AssociativeReorderDotOperator( - HloDotInstruction* dot); - - HloComputation* GetOrCreateScalarAddComputation(PrimitiveType type) { - HloComputation*& scalar_add_computation = scalar_add_computations_[type]; - if (scalar_add_computation) { - return scalar_add_computation; - } - - HloComputation::Builder b("scalar_add_computation"); - Shape shape = ShapeUtil::MakeShape(type, {}); - simplifier_->UpdateLayout(&shape); - auto scalar_lhs = b.AddInstruction( - HloInstruction::CreateParameter(0, shape, "scalar_lhs")); - auto scalar_rhs = b.AddInstruction( - HloInstruction::CreateParameter(1, shape, "scalar_rhs")); - auto scalar_op = b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kAdd, scalar_lhs, scalar_rhs)); - scalar_add_computation = - computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); - return scalar_add_computation; - } - - // Tries to fold a kPad in the input or filter into the convolution - // instruction's window. - virtual absl::StatusOr FoldConvInputPad(HloInstruction* convolution); - absl::StatusOr FoldConvFilterPad(HloInstruction* convolution); - - // Tries to swap convolution operands if they would result in a more efficient - // convolution. - absl::StatusOr SwapConvOperands(HloInstruction* convolution); - - // Checks if the given convolution is in BF16 and is oneDNN rewritable, if not - // then it promotes the data type of the convolution to F32 - absl::StatusOr IsOneDnnRewritableBF16Conv(HloInstruction** convolution); - - // Tries to use a kDot in place of the given convolution. - absl::StatusOr SimplifyConvToDot(HloInstruction* convolution); - // Tries to use a multiplication in place of the given convolution. - absl::StatusOr SimplifyConvToMultiply(HloInstruction* convolution); - - // Tries to simplify a slice where the result of the slice is a scalar. - absl::StatusOr TrySimplifyScalarSlice(HloInstruction* slice); - - // Tries to convert slice(reshape(X)) into reshape(slice(X)) - absl::StatusOr TryToReorderSliceAndReshape(HloInstruction* slice); - - // Tries to convert slice(reverse(X)) into reverse(slice(X)) - absl::StatusOr TryToReorderSliceAndReverse(HloInstruction* slice); - - // Tries to simplify `(and (< a N) (< a K))` in cases where `N <= K` into - // `(< a N)`. This is crucial for being able to figure out the loop trip - // count. - // - // Assumes that the input is conjunction. - absl::StatusOr TrySimplifyTautologicalCompare( - HloInstruction* conjunction); - - // Tries to simlplify (bitcast-convert (concat (bitcast-convert A) ...)) where - // the types of inner and outer bitcast-convert cancel out. - absl::StatusOr TrySimplifyTautologicalBitcastConvert( - HloInstruction* bitcast); - - // Tries to remove surrounding converts around a binary op where the op has a - // more precise type than its inputs and output. - // - // convert(bin_op(convert(data1), - // convert(data2))) - // where TS is a smaller point type than TL (ex, TS=fp16, TL=fp32) - // -> - // bin_op(data1, data2) - absl::Status TryRemoveUpcastAndDowncastSurroundingBinaryOp( - HloInstruction* convert_instruction); - - // Useful when we want to use the same visitor over multiple computations. - void ResetState(HloComputation* computation); - - // Current HloComputation instance the AlgebraicSimplifierVisitor is - // traversing. - HloComputation* computation_; - - // Cached computation for adding two scalars of a given type. - absl::flat_hash_map scalar_add_computations_; - - AlgebraicSimplifier* simplifier_ = nullptr; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" #endif // XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_ diff --git a/third_party/xla/xla/service/algebraic_simplifier_overflow_test.cc b/third_party/xla/xla/service/algebraic_simplifier_overflow_test.cc index 8e011d6d24edf7..071f9994b54a08 100644 --- a/third_party/xla/xla/service/algebraic_simplifier_overflow_test.cc +++ b/third_party/xla/xla/service/algebraic_simplifier_overflow_test.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include +#include +#include "xla/error_spec.h" #include "xla/tests/hlo_test_base.h" namespace xla { diff --git a/third_party/xla/xla/service/algorithm_util.cc b/third_party/xla/xla/service/algorithm_util.cc index 4e037056730909..3b8724329d60ca 100644 --- a/third_party/xla/xla/service/algorithm_util.cc +++ b/third_party/xla/xla/service/algorithm_util.cc @@ -41,11 +41,15 @@ absl::StatusOr GetBlasComputationType( switch (algorithm) { case PrecisionConfig::ALG_DOT_F16_F16_F16: return se::blas::ComputationType::kF16; + case PrecisionConfig::ALG_DOT_BF16_BF16_F32: + return se::blas::ComputationType::kBF16AsF32; case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32: case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM: case PrecisionConfig::ALG_DOT_F16_F16_F32: - case PrecisionConfig::ALG_DOT_BF16_BF16_F32: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6: case PrecisionConfig::ALG_DOT_F32_F32_F32: + case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3: return se::blas::ComputationType::kF32; case PrecisionConfig::ALG_DOT_TF32_TF32_F32: return se::blas::ComputationType::kTF32AsF32; @@ -95,15 +99,26 @@ bool HasFastAccum(PrecisionConfig::Algorithm algorithm) { return algorithm == PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM; } +bool IsAmpere(stream_executor::GpuComputeCapability gpu_compute_capability) { + return std::holds_alternative( + gpu_compute_capability) && + std::get(gpu_compute_capability).major == + stream_executor::CudaComputeCapability::AMPERE; +} + // It's clear that those libraries could support more, but we only list the ones // which we explicitly test for now. -bool IsSupportedByCublasOrCublasLt(PrecisionConfig::Algorithm algorithm) { +bool IsSupportedByCublasOrCublasLt( + PrecisionConfig::Algorithm algorithm, + stream_executor::GpuComputeCapability gpu_compute_capability) { switch (algorithm) { + case PrecisionConfig::ALG_DOT_BF16_BF16_F32: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6: case PrecisionConfig::ALG_UNSET: case PrecisionConfig::ALG_DOT_F16_F16_F32: case PrecisionConfig::ALG_DOT_F32_F32_F32: case PrecisionConfig::ALG_DOT_F64_F64_F64: - case PrecisionConfig::ALG_DOT_BF16_BF16_F32: case PrecisionConfig::ALG_DOT_TF32_TF32_F32: case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32: case PrecisionConfig::ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM: @@ -174,13 +189,20 @@ bool IsSupportedDotAlgorithmOnGpu( return input_storage_type == F16 && (output_storage_type == F16 || output_storage_type == F32); case PrecisionConfig::ALG_DOT_BF16_BF16_F32: - return (is_cuda_ge_ampere || is_rocm_mi100_and_above) && - input_storage_type == BF16 && - (output_storage_type == BF16 || output_storage_type == F32); + if (!is_cuda_ge_ampere && !is_rocm_mi100_and_above) return false; + switch (input_storage_type) { + case BF16: + return output_storage_type == BF16 || output_storage_type == F32; + case F32: + return output_storage_type == F32; + default: + return false; + } case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3: case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6: return (is_cuda_ge_ampere || is_rocm_mi100_and_above) && input_storage_type == F32 && output_storage_type == F32; + case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3: case PrecisionConfig::ALG_DOT_TF32_TF32_F32: return (is_cuda_ge_ampere || is_rocm_mi100_and_above) && input_storage_type == F32 && output_storage_type == F32; diff --git a/third_party/xla/xla/service/algorithm_util.h b/third_party/xla/xla/service/algorithm_util.h index 9ce28f552dd5a0..e990ae9757b259 100644 --- a/third_party/xla/xla/service/algorithm_util.h +++ b/third_party/xla/xla/service/algorithm_util.h @@ -52,7 +52,9 @@ bool HasFastAccum(PrecisionConfig::Algorithm algorithm); // // We may want to also check storage types, but for now those are checked in // IsSupportedDotAlgorithmOnGpu. -bool IsSupportedByCublasOrCublasLt(PrecisionConfig::Algorithm algorithm); +bool IsSupportedByCublasOrCublasLt( + PrecisionConfig::Algorithm algorithm, + stream_executor::GpuComputeCapability gpu_compute_capability); // Checks if we support the given algorithm using cuDNN. bool IsSupportedByCudnn(PrecisionConfig::Algorithm algorithm); diff --git a/third_party/xla/xla/service/all_gather_broadcast_reorder.h b/third_party/xla/xla/service/all_gather_broadcast_reorder.h index 0759f8ebfbbc79..ce722207a37a62 100644 --- a/third_party/xla/xla/service/all_gather_broadcast_reorder.h +++ b/third_party/xla/xla/service/all_gather_broadcast_reorder.h @@ -16,26 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_ALL_GATHER_BROADCAST_REORDER_H_ #define XLA_SERVICE_ALL_GATHER_BROADCAST_REORDER_H_ -#include "absl/status/statusor.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// A pass that reorders all-gather(broadcast(x)) -> broadcast(all-gather(x)). -// The intent is to reduce the size of all-gather when possible by doing an -// all-gather on the (smaller) pre-broadcasted data and then applying the -// broadcast. -class AllGatherBroadcastReorder : public HloModulePass { - public: - absl::string_view name() const override { return "all-gather-bcast-reorder"; } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/collectives/all_gather_broadcast_reorder.h" #endif // XLA_SERVICE_ALL_GATHER_BROADCAST_REORDER_H_ diff --git a/third_party/xla/xla/service/all_gather_combiner.h b/third_party/xla/xla/service/all_gather_combiner.h index 79bf388322081c..9c7029207c6c3d 100644 --- a/third_party/xla/xla/service/all_gather_combiner.h +++ b/third_party/xla/xla/service/all_gather_combiner.h @@ -16,76 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_ALL_GATHER_COMBINER_H_ #define XLA_SERVICE_ALL_GATHER_COMBINER_H_ -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/service/hlo_domain_map.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Combines small non-dependent AllGather ops into larger combined -// AllGather ops. A typical AllGather implementation has a minimum -// latency-induced time for a AllGather op so a single combined op can be -// more efficient than many small ones. -class AllGatherCombiner : public HloModulePass { - public: - AllGatherCombiner(int64_t combine_threshold_in_bytes, - int64_t combine_threshold_count, bool combine_by_dim, - bool combine_different_dtypes = true); - - absl::string_view name() const override { return "all-gather-combiner"; } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - // The group key encapsulates all of the properties which must match for it to - // be possible to combine the instructions. - // The field of the key corresponds to the following: - // 1. all_gather_dimension - // 2. domain_metadata_id - // 3. channel_id - // 4. use_global_device_ids - // 5. data_type - // 6. replica_groups - // 7. extra arguments in string format. - using GroupKey = - std::tuple, int64_t, bool, bool, PrimitiveType, - std::vector>, std::string>; - - static std::string& GetGroupKeyExtraArgs(GroupKey& key); - - // Returns a key that will be equal for instructions that might be combined, - // or different if not. - static std::optional CombineKey( - const HloInstruction* instruction, const HloDomainMap& domain_map, - bool combine_by_dim, bool combine_different_dtypes = true); - - protected: - absl::StatusOr RunWithKeyCombiner( - HloModule* module, - const absl::flat_hash_set& execution_threads, - absl::FunctionRef( - const HloInstruction*, const HloDomainMap&, bool, bool)> - combine_key); - - private: - // Combine all gather ops up to this threshold. - int64_t combine_threshold_in_bytes_; - - // Combine all gather ops up to this threshold (number of operands). - int64_t combine_threshold_count_; - - // Combine only all-gather ops with the same gather dimension. - bool combine_by_dim_; - - // Combine all-gather ops with different dtypes. - bool combine_different_dtypes_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/collectives/all_gather_combiner.h" #endif // XLA_SERVICE_ALL_GATHER_COMBINER_H_ diff --git a/third_party/xla/xla/service/all_gather_decomposer.cc b/third_party/xla/xla/service/all_gather_decomposer.cc index 98443b9113f976..ce3ed5f5f44026 100644 --- a/third_party/xla/xla/service/all_gather_decomposer.cc +++ b/third_party/xla/xla/service/all_gather_decomposer.cc @@ -33,6 +33,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/all_gather_decomposer_test.cc b/third_party/xla/xla/service/all_gather_decomposer_test.cc index 214134648a7ea4..b857ce032959fc 100644 --- a/third_party/xla/xla/service/all_gather_decomposer_test.cc +++ b/third_party/xla/xla/service/all_gather_decomposer_test.cc @@ -23,8 +23,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/service/hlo_parser.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/all_reduce_combiner.h b/third_party/xla/xla/service/all_reduce_combiner.h index bd1aa811f97160..f0f3a200f22f1f 100644 --- a/third_party/xla/xla/service/all_reduce_combiner.h +++ b/third_party/xla/xla/service/all_reduce_combiner.h @@ -16,39 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_ALL_REDUCE_COMBINER_H_ #define XLA_SERVICE_ALL_REDUCE_COMBINER_H_ -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/array2d.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Combines small non-dependent AllReduce ops into larger combined -// AllReduce ops. A typical AllReduce implementation has a minimum -// latency-induced time for a AllReduce op so a single combined op can be -// more efficient than many small ones. -class AllReduceCombiner : public HloModulePass { - public: - AllReduceCombiner(int64_t combine_threshold_in_bytes, - int64_t combine_threshold_count); - - absl::string_view name() const override { return "all-reduce-combiner"; } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - // Combine all reduce ops up to this threshold. - int64_t combine_threshold_in_bytes_; - - // Combine all reduce ops up to this threshold (number of operands). - int64_t combine_threshold_count_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/collectives/all_reduce_combiner.h" #endif // XLA_SERVICE_ALL_REDUCE_COMBINER_H_ diff --git a/third_party/xla/xla/service/all_reduce_contiguous.h b/third_party/xla/xla/service/all_reduce_contiguous.h index 102245cd2ee36a..7dc1a6501259d4 100644 --- a/third_party/xla/xla/service/all_reduce_contiguous.h +++ b/third_party/xla/xla/service/all_reduce_contiguous.h @@ -16,24 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_ALL_REDUCE_CONTIGUOUS_H_ #define XLA_SERVICE_ALL_REDUCE_CONTIGUOUS_H_ -#include "absl/status/statusor.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// Concatenates all-reduce operands together, so the all-reduce is performed -// over a single, contiguous buffer. -class AllReduceContiguous : public HloModulePass { - public: - absl::string_view name() const override { return "all-reduce-contiguous"; } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/collectives/all_reduce_contiguous.h" #endif // XLA_SERVICE_ALL_REDUCE_CONTIGUOUS_H_ diff --git a/third_party/xla/xla/service/all_reduce_folder.h b/third_party/xla/xla/service/all_reduce_folder.h index 77706bbff34d26..6054de621c1d03 100644 --- a/third_party/xla/xla/service/all_reduce_folder.h +++ b/third_party/xla/xla/service/all_reduce_folder.h @@ -16,33 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_ALL_REDUCE_FOLDER_H_ #define XLA_SERVICE_ALL_REDUCE_FOLDER_H_ -#include "absl/status/statusor.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// A pass that folds an all-reduce feeding into another all-reduce by expanding -// the replica groups. As an example: -// -// ar0 = all-reduce(x) replica_groups={{0,1},{2,3},{4,5},{6,7}} -// ar1 = all-reduce(all-reduce0) replica_groups={{0,2},{1,3},{4,6},{5,7}} -// -// Can be combined into a single all-reduce: -// -// ar1 = all-reduce(x) replica_groups={{0,1,2,3},{4,5,6,7}} -// - -class AllReduceFolder : public HloModulePass { - public: - absl::string_view name() const override { return "all-reduce-folder"; } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/all_reduce_folder.h" #endif // XLA_SERVICE_ALL_REDUCE_FOLDER_H_ diff --git a/third_party/xla/xla/service/all_reduce_folder_test.cc b/third_party/xla/xla/service/all_reduce_folder_test.cc deleted file mode 100644 index e984d089adb196..00000000000000 --- a/third_party/xla/xla/service/all_reduce_folder_test.cc +++ /dev/null @@ -1,249 +0,0 @@ -/* Copyright 2021 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/all_reduce_folder.h" - -#include -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/utils/hlo_matchers.h" -#include "xla/test.h" -#include "xla/tests/hlo_test_base.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace { - -namespace m = xla::testing::opcode_matchers; -using ::testing::HasSubstr; - -class AllReduceFolderTest : public HloTestBase { - public: - absl::StatusOr> RunPass( - absl::string_view hlo_module, bool expect_change) { - TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_module)); - auto changed = AllReduceFolder().Run(module.get()); - if (!changed.ok()) { - return changed.status(); - } - EXPECT_EQ(changed.value(), expect_change); - return absl::StatusOr>(std::move(module)); - } - - size_t AllReduceCount(std::unique_ptr &module) { - return absl::c_count_if(module->entry_computation()->instructions(), - HloPredicateIsOp); - } -}; - -TEST_F(AllReduceFolderTest, Simple) { - absl::string_view hlo_string = R"( -HloModule m - -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) -} - -ENTRY main { - p0 = f32[8] parameter(0) - ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, to_apply=sum - ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3}}, to_apply=sum -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - RunPass(hlo_string, /*expect_change=*/true)); - EXPECT_EQ(AllReduceCount(module), 1); - HloInstruction *root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, m::AllReduce(m::Parameter(0))); - EXPECT_THAT(root->ToString(), HasSubstr("replica_groups={{0,1,2,3}}")); -} - -// Same as Simple, but groups for the 2 all-reduce's are swapped. -TEST_F(AllReduceFolderTest, SimpleSwap) { - absl::string_view hlo_string = R"( -HloModule m - -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) -} - -ENTRY main { - p0 = f32[8] parameter(0) - ar0 = f32[8] all-reduce(p0), replica_groups={{0,2},{1,3}}, to_apply=sum - ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,1},{2,3}}, to_apply=sum -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - RunPass(hlo_string, /*expect_change=*/true)); - EXPECT_EQ(AllReduceCount(module), 1); - HloInstruction *root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, m::AllReduce(m::Parameter(0))); - EXPECT_THAT(root->ToString(), HasSubstr("replica_groups={{0,1,2,3}}")); -} - -TEST_F(AllReduceFolderTest, EmptyReplicaGroups) { - absl::string_view hlo_string = R"( -HloModule m - -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) -} - -ENTRY main { - p0 = f32[8] parameter(0) - ar0 = f32[8] all-reduce(p0), replica_groups={}, to_apply=sum - ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3}}, to_apply=sum -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - RunPass(hlo_string, /*expect_change=*/false)); -} - -TEST_F(AllReduceFolderTest, MismatchOtherProperties0) { - absl::string_view hlo_string = R"( -HloModule m - -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) -} - -ENTRY main { - p0 = f32[8] parameter(0) - ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, channel_id=1, to_apply=sum - ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3}}, to_apply=sum -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - RunPass(hlo_string, /*expect_change=*/false)); -} - -TEST_F(AllReduceFolderTest, MismatchOtherProperties1) { - absl::string_view hlo_string = R"( -HloModule m - -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) -} - -mul { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT mul = f32[] multiply(a, b) -} - -ENTRY main { - p0 = f32[8] parameter(0) - ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, to_apply=sum - ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3}}, to_apply=mul -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - RunPass(hlo_string, /*expect_change=*/false)); -} - -TEST_F(AllReduceFolderTest, NotFoldable) { - absl::string_view hlo_string = R"( -HloModule m - -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) -} - -ENTRY main { - p0 = f32[8] parameter(0) - ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3}}, to_apply=sum - ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,1},{2,3}}, to_apply=sum -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - RunPass(hlo_string, /*expect_change=*/false)); -} - -TEST_F(AllReduceFolderTest, Foldable0) { - absl::string_view hlo_string = R"( -HloModule m - -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) -} - -ENTRY main { - p0 = f32[8] parameter(0) - ar0 = f32[8] all-reduce(p0), replica_groups={{0,4},{1,5},{2,3},{6,7}}, to_apply=sum - ROOT ar1 = f32[8] all-reduce(ar0), replica_groups={{0,5},{4,1},{2,7},{3,6}}, to_apply=sum -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - RunPass(hlo_string, /*expect_change=*/true)); - EXPECT_EQ(AllReduceCount(module), 1); - HloInstruction *root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, m::AllReduce(m::Parameter(0))); - EXPECT_THAT(root->ToString(), - HasSubstr("replica_groups={{0,1,4,5},{2,3,6,7}}")); -} - -// Verify that a chain of foldable all-reduce's folds in a single pass -// invocation. -TEST_F(AllReduceFolderTest, FoldableChain) { - absl::string_view hlo_string = R"( -HloModule m - -sum { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT add.2 = f32[] add(a, b) -} - -ENTRY main { - p0 = f32[8] parameter(0) - ar0 = f32[8] all-reduce(p0), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=sum - ar1 = f32[8] all-reduce(ar0), replica_groups={{0,2},{1,3},{4,6},{5,7}}, to_apply=sum - ROOT ar2 = f32[8] all-reduce(ar1), replica_groups={{0,4},{1,5},{2,6},{3,7}}, to_apply=sum -} -)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - RunPass(hlo_string, /*expect_change=*/true)); - std::cerr << module->ToString(); - EXPECT_EQ(AllReduceCount(module), 1); - HloInstruction *root = module->entry_computation()->root_instruction(); - EXPECT_THAT(root, m::AllReduce(m::Parameter(0))); - EXPECT_THAT(root->ToString(), - HasSubstr("replica_groups={{0,1,2,3,4,5,6,7}}")); -} - -} // namespace -} // namespace xla diff --git a/third_party/xla/xla/service/all_reduce_key.cc b/third_party/xla/xla/service/all_reduce_key.cc index bd2fd49dc6be51..82319b09c3e0f8 100644 --- a/third_party/xla/xla/service/all_reduce_key.cc +++ b/third_party/xla/xla/service/all_reduce_key.cc @@ -25,6 +25,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo_domain_map.h" +#include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/service/all_reduce_key.h b/third_party/xla/xla/service/all_reduce_key.h index 53a444d8a95c5b..fd72f7e4230bae 100644 --- a/third_party/xla/xla/service/all_reduce_key.h +++ b/third_party/xla/xla/service/all_reduce_key.h @@ -24,6 +24,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo_domain_map.h" +#include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/service/all_reduce_promotion.cc b/third_party/xla/xla/service/all_reduce_promotion.cc index b0328759c7d310..0e60d59b6a24be 100644 --- a/third_party/xla/xla/service/all_reduce_promotion.cc +++ b/third_party/xla/xla/service/all_reduce_promotion.cc @@ -19,6 +19,19 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/xla_data.pb.h" + namespace xla { namespace { diff --git a/third_party/xla/xla/service/all_reduce_promotion.h b/third_party/xla/xla/service/all_reduce_promotion.h index a1ad33033187f1..e6459f82e00dc2 100644 --- a/third_party/xla/xla/service/all_reduce_promotion.h +++ b/third_party/xla/xla/service/all_reduce_promotion.h @@ -17,7 +17,15 @@ limitations under the License. #define XLA_SERVICE_ALL_REDUCE_PROMOTION_H_ #include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/service/change_op_data_type.h" +#include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/service/all_reduce_promotion_test.cc b/third_party/xla/xla/service/all_reduce_promotion_test.cc index 380c1c3cf8e246..86d5fde6eb71c5 100644 --- a/third_party/xla/xla/service/all_reduce_promotion_test.cc +++ b/third_party/xla/xla/service/all_reduce_promotion_test.cc @@ -15,9 +15,15 @@ limitations under the License. #include "xla/service/all_reduce_promotion.h" +#include +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/tests/hlo_test_base.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/all_reduce_reassociate.cc b/third_party/xla/xla/service/all_reduce_reassociate.cc index c7becb2c436c0b..6063eef7b6e6b0 100644 --- a/third_party/xla/xla/service/all_reduce_reassociate.cc +++ b/third_party/xla/xla/service/all_reduce_reassociate.cc @@ -38,6 +38,7 @@ limitations under the License. #include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" namespace xla { diff --git a/third_party/xla/xla/service/all_reduce_reassociate.h b/third_party/xla/xla/service/all_reduce_reassociate.h index f2ff998b4b6f04..9fbeb32e6bf81f 100644 --- a/third_party/xla/xla/service/all_reduce_reassociate.h +++ b/third_party/xla/xla/service/all_reduce_reassociate.h @@ -16,7 +16,9 @@ limitations under the License. #ifndef XLA_SERVICE_ALL_REDUCE_REASSOCIATE_H_ #define XLA_SERVICE_ALL_REDUCE_REASSOCIATE_H_ +#include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" diff --git a/third_party/xla/xla/service/all_reduce_reassociate_test.cc b/third_party/xla/xla/service/all_reduce_reassociate_test.cc index b7130508e878ea..c0a91a93be215c 100644 --- a/third_party/xla/xla/service/all_reduce_reassociate_test.cc +++ b/third_party/xla/xla/service/all_reduce_reassociate_test.cc @@ -32,6 +32,7 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/all_reduce_simplifier.cc b/third_party/xla/xla/service/all_reduce_simplifier.cc index 0760433bda4489..c51492f0550cc9 100644 --- a/third_party/xla/xla/service/all_reduce_simplifier.cc +++ b/third_party/xla/xla/service/all_reduce_simplifier.cc @@ -24,15 +24,17 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/analysis/hlo_replication_analysis.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal_util.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_replication_analysis.h" #include "xla/shape_util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/all_reduce_simplifier.h b/third_party/xla/xla/service/all_reduce_simplifier.h index 1c44b945bdf697..1a8463075198cb 100644 --- a/third_party/xla/xla/service/all_reduce_simplifier.h +++ b/third_party/xla/xla/service/all_reduce_simplifier.h @@ -16,7 +16,9 @@ limitations under the License. #ifndef XLA_SERVICE_ALL_REDUCE_SIMPLIFIER_H_ #define XLA_SERVICE_ALL_REDUCE_SIMPLIFIER_H_ +#include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" @@ -28,9 +30,6 @@ namespace xla { // replaced by a multiply with the replica count. class AllReduceSimplifier : public HloModulePass { public: - explicit AllReduceSimplifier(int64_t replica_count) - : replica_count_(replica_count) {} - ~AllReduceSimplifier() override = default; absl::string_view name() const override { return "all-reduce-simp"; } // Run all-reduce simplification on the given computation. Returns whether the @@ -39,9 +38,6 @@ class AllReduceSimplifier : public HloModulePass { absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; - - private: - int64_t replica_count_; }; } // namespace xla diff --git a/third_party/xla/xla/service/all_reduce_simplifier_test.cc b/third_party/xla/xla/service/all_reduce_simplifier_test.cc index 35f5955076ad7e..7048bf20a61639 100644 --- a/third_party/xla/xla/service/all_reduce_simplifier_test.cc +++ b/third_party/xla/xla/service/all_reduce_simplifier_test.cc @@ -20,17 +20,13 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tsl/lib/core/status_test_util.h" -#include "xla/types.h" -#include "xla/window_util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -80,7 +76,7 @@ test { )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( kModuleStr, /*replica_count=*/8)); - AllReduceSimplifier simplifier(/*replica_count=*/8); + AllReduceSimplifier simplifier; ASSERT_TRUE(simplifier.Run(module.get()).value()); EXPECT_THAT( module->entry_computation()->root_instruction(), @@ -116,7 +112,7 @@ test { )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( kModuleStr, /*replica_count=*/8)); - AllReduceSimplifier simplifier(/*replica_count=*/8); + AllReduceSimplifier simplifier; ASSERT_TRUE(simplifier.Run(module.get()).value()); EXPECT_THAT(module->entry_computation()->root_instruction(), GmockMatch(m::MultiplyAnyOrder( @@ -157,7 +153,7 @@ test { )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( kModuleStr, /*replica_count=*/8)); - AllReduceSimplifier simplifier(/*replica_count=*/8); + AllReduceSimplifier simplifier; ASSERT_TRUE(simplifier.Run(module.get()).value()); EXPECT_THAT( module->entry_computation()->root_instruction(), @@ -187,7 +183,7 @@ test { )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule( kModuleStr, /*replica_count=*/8)); - AllReduceSimplifier simplifier(/*replica_count=*/8); + AllReduceSimplifier simplifier; EXPECT_TRUE(simplifier.Run(module.get()).value()); EXPECT_THAT(module->entry_computation()->root_instruction(), GmockMatch(m::Parameter(0))); @@ -217,7 +213,7 @@ test { auto module, ParseAndReturnVerifiedModule(kModuleStr, /*replica_count=*/1, /*num_partitions=*/8)); module->mutable_config().set_use_spmd_partitioning(true); - AllReduceSimplifier simplifier(/*replica_count=*/1); + AllReduceSimplifier simplifier; EXPECT_TRUE(simplifier.Run(module.get()).value()); EXPECT_THAT(module->entry_computation()->root_instruction(), GmockMatch(m::Parameter(0))); @@ -252,7 +248,7 @@ test { auto module, ParseAndReturnVerifiedModule(kModuleStr, /*replica_count=*/1, /*num_partitions=*/8)); module->mutable_config().set_use_spmd_partitioning(true); - AllReduceSimplifier simplifier(/*replica_count=*/1); + AllReduceSimplifier simplifier; EXPECT_FALSE(simplifier.Run(module.get()).value()); } @@ -279,7 +275,7 @@ test { /*num_partitions=*/1)); // Mark as MPMD. module->mutable_config().set_use_spmd_partitioning(false); - AllReduceSimplifier simplifier(/*replica_count=*/2); + AllReduceSimplifier simplifier; EXPECT_FALSE(simplifier.Run(module.get()).value()); } } // namespace diff --git a/third_party/xla/xla/service/all_to_all_decomposer.cc b/third_party/xla/xla/service/all_to_all_decomposer.cc index ecb08af7660382..dabea315b81c40 100644 --- a/third_party/xla/xla/service/all_to_all_decomposer.cc +++ b/third_party/xla/xla/service/all_to_all_decomposer.cc @@ -18,11 +18,13 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/layout_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/all_to_all_decomposer.h b/third_party/xla/xla/service/all_to_all_decomposer.h index 3ef1891a412665..f05e586692b810 100644 --- a/third_party/xla/xla/service/all_to_all_decomposer.h +++ b/third_party/xla/xla/service/all_to_all_decomposer.h @@ -16,8 +16,10 @@ limitations under the License. #ifndef XLA_SERVICE_ALL_TO_ALL_DECOMPOSER_H_ #define XLA_SERVICE_ALL_TO_ALL_DECOMPOSER_H_ +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/op_expander_pass.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" namespace xla { diff --git a/third_party/xla/xla/service/allocation_tracker.cc b/third_party/xla/xla/service/allocation_tracker.cc index 95168eba9c6c61..507107723093ab 100644 --- a/third_party/xla/xla/service/allocation_tracker.cc +++ b/third_party/xla/xla/service/allocation_tracker.cc @@ -31,6 +31,7 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/allocation_tracker.h b/third_party/xla/xla/service/allocation_tracker.h index f7748d7162ace2..cea193eaea8568 100644 --- a/third_party/xla/xla/service/allocation_tracker.h +++ b/third_party/xla/xla/service/allocation_tracker.h @@ -22,9 +22,15 @@ limitations under the License. #include #include +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" #include "xla/service/backend.h" +#include "xla/service/shaped_buffer.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "xla/types.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/ar_crs_combiner.h b/third_party/xla/xla/service/ar_crs_combiner.h index ea3acd95a24dfd..57b36ee2b1599d 100644 --- a/third_party/xla/xla/service/ar_crs_combiner.h +++ b/third_party/xla/xla/service/ar_crs_combiner.h @@ -16,185 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_AR_CRS_COMBINER_H_ #define XLA_SERVICE_AR_CRS_COMBINER_H_ -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/service/call_graph.h" - -namespace xla { - -// When the HLO graph contains a cross-module AllReduce (N separate AllReduce -// ops that share the same channel_id for MPMD partitioning, or 1 AllReduce op -// for SPMD partitioning), followed by some simple linear operations, followed -// by a cross-replica AllReduce (also known as cross-replica sum, or CRS), we -// can combine the CMAR and the CRAR, to use an efficient AllReduce -// implementation that fully utilizes the interconnect bandwidth. -// -// Such sequences appear in spatially partitioned models (either MPMD or SPMD). -// This pass must run right after spatial partitioning, when the code is still -// in a single HLO module. -// -// The steps are: -// 1) Find CMARs followed by simple ops followed by CRARs. -// 2) Group CMARs by channel_id. They must all be rewritten. For SPMD -// partitioning, there will only be a single CMAR for each channel_id. -// 3) Prove that the CMAR patterns in each core produce the same result. -// 4) Eliminate the CMAR, and if it feeds an addition/subtraction, divide the -// other operand by the number of spatial partitions. -// 5) Turn the CRAR into an all-core AllReduce. -// -// The pass also handles the case where multiple CMARs lead to the same CRAR, -// and eliminates all CMARs. This graph: -// -// Y -// | -// X CMAR_2 Z -// | \ / -// CMAR_1 + -// \ / -// + -// | -// CRAR -// -// gets rewritten to: -// -// Z num_partitions -// \ / -// Y div -// \ / -// X + -// \ / -// + -// | -// all-core AR -// -class ArCrsCombiner : public HloModulePass { - public: - ArCrsCombiner(int num_spatial_partitions, bool spmd_partition) - : num_spatial_partitions_(num_spatial_partitions), - spmd_partition_(spmd_partition) {} - absl::string_view name() const override { return "ar-crs-combiner"; } - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - // Helper method to allow testing of InstructionsComputeSameValue. - static bool TestInstructionsComputeSameValue(HloInstruction* i1, - HloInstruction* i2); - - private: - // We used this struct because multiple ARs could be paired with the same CRS. - // In this case, we want to select the AR that is furthest from the CRS, - // because it makes it easier to eliminate all ARs during RewriteGraph. - struct ArCrsPair { - HloInstruction* ar; - HloInstruction* crs; - // The length of the path from AR to CRS in the HLO graph. - int64_t distance; - - ArCrsPair(HloInstruction* all_reduce, HloInstruction* cross_replica_sum, - int64_t dist) - : ar(all_reduce), crs(cross_replica_sum), distance(dist) {} - - std::string ToString() { - std::string result; - absl::StrAppend(&result, "("); - HloInstruction* instruction = ar; - while (instruction != crs) { - absl::StrAppend(&result, instruction->name(), ","); - instruction = instruction->users()[0]; - } - absl::StrAppend(&result, instruction->name(), - ")[id:", *(ar->channel_id()), ",dist:", distance, "]"); - return result; - } - }; - - std::optional MatchesArCrsPattern( - HloInstruction* instruction); - - // If the passed instruction is a while parameter, and the while body is only - // called by a single while instruction, return the while instruction. - std::optional WhileFromBodyParameter( - HloInstruction* instruction); - - // If the passed instruction is a parameter in one of the branch computations, - // and the branch body is only called by a single instruction, return the - // conditional instruction. - std::optional ConditionalFromBodyParameter( - HloInstruction* instruction); - - // Returns a vector of tuple instructions. - // If all instructions that flow to "instruction" are tuples, return them. - // Otherwise, return std::nullopt. Returns an empty vector if the instruction - // is already in the visited set. - std::optional> GetAllTuples( - HloInstruction* instruction, - absl::flat_hash_set* visited); - - // Checks whether two different elements in the same tuple compute the same - // value. - bool TupleElementsComputeSameValue( - HloInstruction* tuple_shaped_instruction, int64_t i1, int64_t i2, - absl::flat_hash_map* visited_pairs); - - // Returns whether the instructions i1 and i2 can be shown to evaluate to the - // same value. Handling WHILE requires recursion, which may cause us to visit - // the same instruction again. To avoid infinite loops, we pass a cache of - // visited instruction pairs. - bool InstructionsComputeSameValue( - HloInstruction* i1, HloInstruction* i2, - absl::flat_hash_map* visited_pairs); - - // Populates all_reduce_map_. - void GroupAllReducesById(HloModule* module); - - // Looks at each AllReduce group in all_reduce_map_, and keeps only the - // groups for which it's safe to move the AllReduce later in the HLO graph. - absl::Status KeepProvablyEqualInstructionGroupsMPMD(); - - // Same as above, but runs on SPMD partitioned module instead of MPMD. - absl::Status KeepProvablyEqualInstructionGroupsSPMD(HloModule* module); - - // Performs the graph rewrite that eliminates the early AllReduce and turns - // the later CRS into an AllReduce. - absl::StatusOr RewriteGraph(); - - int num_spatial_partitions_; - - // Run this combiner pass assuming the input module is an SPMD partitioned - // module (as opposed to MPMD partitioned). - // - // The main difference between the two w.r.t. this pass is that there would be - // N all-reduce ops for each channel in MPMD mode, whereas there is only 1 - // for each channel in SPMD mode. Also we use HloReplicationAnalysis for HLO - // equivalence check in SPMD mode. - bool spmd_partition_; - - // Map from all-reduce ids to the AR/CRS pairs. - absl::flat_hash_map> all_reduce_map_; - - // Map from a CRS instruction to the all-reduce ID of the AR paired with the - // CRS. Sometimes, several ARs in the code could be paired with the same CRS. - // We use this map to pick a single AR/CRS path to rewrite. - absl::flat_hash_map crs_reserved_map_; - - std::unique_ptr call_graph_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/ar_crs_combiner.h" #endif // XLA_SERVICE_AR_CRS_COMBINER_H_ diff --git a/third_party/xla/xla/service/async_collective_creator.h b/third_party/xla/xla/service/async_collective_creator.h index 5a542cf1e48c59..f3141f50ece42a 100644 --- a/third_party/xla/xla/service/async_collective_creator.h +++ b/third_party/xla/xla/service/async_collective_creator.h @@ -16,53 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_ASYNC_COLLECTIVE_CREATOR_H_ #define XLA_SERVICE_ASYNC_COLLECTIVE_CREATOR_H_ -#include -#include -#include - -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// Transforms each all-reduce instruction to a pair of all-reduce-start and -// all-reduce-done. -class AsyncCollectiveCreator : public HloModulePass { - public: - // Function to query the shape of the "context" for collectives that use - // HLO async-start/async-done. - using ContextShapeQuery = - std::function(const HloInstruction *)>; - struct CollectiveCreatorConfig { - HloPredicate convert_all_reduce = HloPredicateFalse; - HloPredicate convert_all_gather = HloPredicateFalse; - HloPredicate convert_collective_broadcast = HloPredicateFalse; - HloPredicate convert_collective_permute = HloPredicateFalse; - HloPredicate convert_all_to_all = HloPredicateFalse; - HloPredicate convert_reduce_scatter = HloPredicateFalse; - ContextShapeQuery get_context_shapes = [](const HloInstruction *) { - return std::vector{}; - }; - int64_t all_reduce_min_threshold_in_bytes = 0; - }; - explicit AsyncCollectiveCreator(CollectiveCreatorConfig creator_config) - : config_(std::move(creator_config)) {} - absl::string_view name() const override { return "async-collective-creator"; } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule *module, - const absl::flat_hash_set &execution_threads) override; - - std::vector MatchCollectives(HloComputation *computation); - absl::StatusOr ReplaceCollectives( - HloComputation *computation, - std::vector &supported_collectives); - const CollectiveCreatorConfig *config() const { return &config_; } - - private: - CollectiveCreatorConfig config_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/collectives/async_collective_creator.h" #endif // XLA_SERVICE_ASYNC_COLLECTIVE_CREATOR_H_ diff --git a/third_party/xla/xla/service/backend.cc b/third_party/xla/xla/service/backend.cc index eea05d78293e93..5ed8d66ddca365 100644 --- a/third_party/xla/xla/service/backend.cc +++ b/third_party/xla/xla/service/backend.cc @@ -13,6 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/service/computation_placer.h" +#include "xla/service/stream_pool.h" +#include "xla/service/transfer_manager.h" +#include "xla/stream_executor/platform.h" +#include "tsl/platform/statusor.h" #define EIGEN_USE_THREADS #include "xla/service/backend.h" diff --git a/third_party/xla/xla/service/backend.h b/third_party/xla/xla/service/backend.h index ba54e008333989..cbbec594bc9020 100644 --- a/third_party/xla/xla/service/backend.h +++ b/third_party/xla/xla/service/backend.h @@ -23,17 +23,23 @@ limitations under the License. #include #include +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/service/compiler.h" #include "xla/service/computation_placer.h" #include "xla/service/stream_pool.h" #include "xla/service/transfer_manager.h" #include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" +#include "tsl/platform/threadpool.h" namespace Eigen { struct ThreadPoolDevice; diff --git a/third_party/xla/xla/service/batch_dot_simplification.h b/third_party/xla/xla/service/batch_dot_simplification.h index 6ba3cf13e69f27..381b67955adf09 100644 --- a/third_party/xla/xla/service/batch_dot_simplification.h +++ b/third_party/xla/xla/service/batch_dot_simplification.h @@ -16,27 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_ #define XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_ -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { -// Simplifies batch dot operations. -// -// Normally these would live in the algebraic simplifier, but we want to run -// this to fixpoint (this pass reaches fixed point in one execution) before we -// run the DotDecomposer. -class BatchDotSimplification : public HloModulePass { - public: - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - absl::string_view name() const override { return "batch-dot-simplification"; } - - private: - absl::StatusOr ElideDegenerateBatchDimensionFromBatchDot( - HloInstruction* batch_dot); -}; -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/batch_dot_simplification.h" #endif // XLA_SERVICE_BATCH_DOT_SIMPLIFICATION_H_ diff --git a/third_party/xla/xla/service/batched_gather_scatter_normalizer.cc b/third_party/xla/xla/service/batched_gather_scatter_normalizer.cc index 441c3b69f3da28..c29a4cb65eb0cc 100644 --- a/third_party/xla/xla/service/batched_gather_scatter_normalizer.cc +++ b/third_party/xla/xla/service/batched_gather_scatter_normalizer.cc @@ -29,9 +29,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/primitive_util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" namespace xla { @@ -46,42 +48,75 @@ bool IsBatchScatter(const HloScatterInstruction* scatter) { return !dims.input_batching_dims().empty(); } +// If `type` is an integral type in which `size` doesn't fit, promote it to S32 +// or S64 (depending on `size`). +PrimitiveType PromoteTypeForSize(PrimitiveType type, int64_t size) { + // Gather/Scatter should have an integral type, but we check just in case. + if (!primitive_util::IsIntegralType(type) || + primitive_util::FitsInIntegralType(size, type)) { + return type; + } + if (primitive_util::FitsInIntegralType(size, PrimitiveType::S32)) { + return PrimitiveType::S32; + } + return PrimitiveType::S64; +} + +// If `indices_batching_dims` and `updated_index_map` are both sorted, then the +// `indices_are_sorted` property is preserved. +// +// This is because each concatenated iota is monotonically increasing, sorted +// indices batching dims mean their order corresponds to the order of batching +// dims in the operand, and a sorted updated start index map means the order of +// the index vector dim corresponds to the order of operand dims. +bool GetUpdatedIndicesAreSorted(bool indices_are_sorted, + absl::Span indices_batching_dims, + absl::Span updated_index_map) { + return indices_are_sorted && absl::c_is_sorted(indices_batching_dims) && + absl::c_is_sorted(updated_index_map); +} + // Update gather/scater indices by adding fake batching iota dimensions. HloInstruction* CreateConcatIndices( HloInstruction* inst, HloInstruction* indices, int64_t index_vector_dim, absl::Span indices_batching_dims, BatchedGatherScatterNormalizer* normalizer) { - const bool index_vector_dim_on_last_dim = - index_vector_dim == indices->shape().rank(); + // The batching dim sizes might not fit in the existing element type, + // in which case we need to promote it. + PrimitiveType element_type = indices->shape().element_type(); + for (int64_t indices_batching_dim : indices_batching_dims) { + element_type = PromoteTypeForSize( + element_type, indices->shape().dimensions(indices_batching_dim)); + } + if (element_type != indices->shape().element_type()) { + Shape indices_shape = indices->shape(); + indices_shape.set_element_type(element_type); + indices = inst->parent()->AddInstruction( + HloInstruction::CreateConvert(indices_shape, indices)); + } Shape iota_shape = indices->shape(); + const bool index_vector_dim_on_last_dim = + index_vector_dim == iota_shape.rank(); if (index_vector_dim_on_last_dim) { std::vector dimensions(iota_shape.dimensions().begin(), iota_shape.dimensions().end()); dimensions.push_back(1); - iota_shape = ShapeUtil::MakeShape(iota_shape.element_type(), dimensions); + iota_shape = ShapeUtil::MakeShape(element_type, dimensions); + indices = inst->AddInstruction( + HloInstruction::CreateReshape(iota_shape, indices)); } iota_shape.set_dimensions(index_vector_dim, 1); normalizer->UpdateLayout(&iota_shape); std::vector indices_to_concat; + indices_to_concat.reserve(indices_batching_dims.size() + 1); for (int64_t indices_batching_dim : indices_batching_dims) { indices_to_concat.push_back(inst->parent()->AddInstruction( HloInstruction::CreateIota(iota_shape, indices_batching_dim))); } - if (index_vector_dim_on_last_dim) { - std::vector dimensions(indices->shape().dimensions().begin(), - indices->shape().dimensions().end()); - dimensions.push_back(1); - Shape reshape_shape = - ShapeUtil::MakeShape(indices->shape().element_type(), dimensions); - normalizer->UpdateLayout(&reshape_shape); - HloInstruction* reshaped_indices = inst->AddInstruction( - HloInstruction::CreateReshape(reshape_shape, indices)); - indices_to_concat.push_back(reshaped_indices); - } else { - indices_to_concat.push_back(indices); - } + indices_to_concat.push_back(indices); + Shape concat_shape = iota_shape; concat_shape.set_dimensions( index_vector_dim, @@ -121,7 +156,10 @@ absl::StatusOr NormalizeBatchGather( dims.index_vector_dim()); return gather->AddInstruction(HloInstruction::CreateGather( gather->shape(), gather_operand, gather_indices, updated_dims, - gather->gather_slice_sizes(), gather->indices_are_sorted())); + gather->gather_slice_sizes(), + GetUpdatedIndicesAreSorted(gather->indices_are_sorted(), + dims.start_indices_batching_dims(), + start_index_map))); } absl::StatusOr NormalizeBatchScatter( @@ -154,7 +192,10 @@ absl::StatusOr NormalizeBatchScatter( scatter_dims_to_operand_dims, dims.index_vector_dim()); return scatter->AddInstruction(HloInstruction::CreateScatter( scatter->shape(), scatter_operands, scatter_indices, scatter_updates, - scatter->to_apply(), updated_dims, scatter->indices_are_sorted(), + scatter->to_apply(), updated_dims, + GetUpdatedIndicesAreSorted(scatter->indices_are_sorted(), + dims.scatter_indices_batching_dims(), + scatter_dims_to_operand_dims), scatter->unique_indices())); } diff --git a/third_party/xla/xla/service/batched_gather_scatter_normalizer.h b/third_party/xla/xla/service/batched_gather_scatter_normalizer.h index 4b5560d38dceec..50c1d43def0293 100644 --- a/third_party/xla/xla/service/batched_gather_scatter_normalizer.h +++ b/third_party/xla/xla/service/batched_gather_scatter_normalizer.h @@ -18,7 +18,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/op_expander_pass.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" namespace xla { diff --git a/third_party/xla/xla/service/batched_gather_scatter_normalizer_test.cc b/third_party/xla/xla/service/batched_gather_scatter_normalizer_test.cc index 22bbdea6fb9be0..ea6995651389d3 100644 --- a/third_party/xla/xla/service/batched_gather_scatter_normalizer_test.cc +++ b/third_party/xla/xla/service/batched_gather_scatter_normalizer_test.cc @@ -79,6 +79,126 @@ ENTRY %Gather (input_tensor: f32[50,49,48,47,46,512,1024,100], start_indices: s6 )"); } +TEST_F(BatchedGatherScatterNormalizerTest, + NormalizeBatchGatherIndicesBecomeUnsorted) { + constexpr absl::string_view kModuleStr = R"( +HloModule StringifyGather, entry_computation_layout={(f32[2,3,4,512]{3,2,1,0}, s64[3,4,1]{2,1,0})->f32[3,4,5]{2,1,0}} + +ENTRY %Gather (input_tensor: f32[2,3,4,512], start_indices: s64[3,4,1]) -> f32[3,4,5] { + %input_tensor = f32[2,3,4,512]{3,2,1,0} parameter(0) + %start_indices = s64[3,4,1]{2,1,0} parameter(1) + ROOT %gather = f32[3,4,5]{2,1,0} + gather(f32[2,3,4,512]{3,2,1,0} %input_tensor, s64[3,4,1]{2,1,0} %start_indices), + offset_dims={2}, collapsed_slice_dims={1}, start_index_map={1}, operand_batching_dims={0,2}, + start_indices_batching_dims={0,1}, index_vector_dim=2, slice_sizes={1,1,1,5}, + indices_are_sorted=true +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA1:.*]] = s64[3,4,1]{{.*}} iota(), iota_dimension=0 + CHECK: %[[IOTA2:.*]] = s64[3,4,1]{{.*}} iota(), iota_dimension=1 + CHECK: %[[INDICES_CONCAT:.*]] = s64[3,4,3]{{.*}} concatenate(%[[IOTA1]], %[[IOTA2]], %start_indices) + CHECK: ROOT %[[GATHER:.*]] = f32[3,4,5]{{.*}} gather( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]]), + CHECK-SAME: offset_dims={2}, + CHECK-SAME: collapsed_slice_dims={0,1,2}, + CHECK-SAME: start_index_map={0,2,1}, + CHECK-SAME: index_vector_dim=2, + CHECK-SAME: slice_sizes={1,1,1,5} + CHECK-NOT: indices_are_sorted + )"); +} + +TEST_F(BatchedGatherScatterNormalizerTest, + NormalizeBatchGatherIndicesBecomeUnsorted2) { + constexpr absl::string_view kModuleStr = R"( +HloModule StringifyGather, entry_computation_layout={(f32[2,3,4,512]{3,2,1,0}, s64[3,2,1]{2,1,0})->f32[3,2,5]{2,1,0}} + +ENTRY %Gather (input_tensor: f32[2,3,4,512], start_indices: s64[3,2,1]) -> f32[3,2,5] { + %input_tensor = f32[2,3,4,512]{3,2,1,0} parameter(0) + %start_indices = s64[3,2,1]{2,1,0} parameter(1) + ROOT %gather = f32[3,2,5]{2,1,0} + gather(f32[2,3,4,512]{3,2,1,0} %input_tensor, s64[3,2,1]{2,1,0} %start_indices), + offset_dims={2}, collapsed_slice_dims={2}, start_index_map={2}, operand_batching_dims={0,1}, + start_indices_batching_dims={1,0}, index_vector_dim=2, slice_sizes={1,1,1,5}, + indices_are_sorted=true +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA1:.*]] = s64[3,2,1]{{.*}} iota(), iota_dimension=1 + CHECK: %[[IOTA2:.*]] = s64[3,2,1]{{.*}} iota(), iota_dimension=0 + CHECK: %[[INDICES_CONCAT:.*]] = s64[3,2,3]{{.*}} concatenate(%[[IOTA1]], %[[IOTA2]], %start_indices) + CHECK: ROOT %[[GATHER:.*]] = f32[3,2,5]{{.*}} gather( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]]), + CHECK-SAME: offset_dims={2}, + CHECK-SAME: collapsed_slice_dims={0,1,2}, + CHECK-SAME: start_index_map={0,1,2}, + CHECK-SAME: index_vector_dim=2, + CHECK-SAME: slice_sizes={1,1,1,5} + CHECK-NOT: indices_are_sorted + )"); +} + +TEST_F(BatchedGatherScatterNormalizerTest, + NormalizeBatchGatherIndicesRemainSorted) { + constexpr absl::string_view kModuleStr = R"( +HloModule StringifyGather, entry_computation_layout={(f32[2,3,4,512]{3,2,1,0}, s64[2,3,1]{2,1,0})->f32[2,3,5]{2,1,0}} + +ENTRY %Gather (input_tensor: f32[2,3,4,512], start_indices: s64[2,3,1]) -> f32[2,3,5] { + %input_tensor = f32[2,3,4,512]{3,2,1,0} parameter(0) + %start_indices = s64[2,3,1]{2,1,0} parameter(1) + ROOT %gather = f32[2,3,5]{2,1,0} + gather(f32[2,3,4,512]{3,2,1,0} %input_tensor, s64[2,3,1]{2,1,0} %start_indices), + offset_dims={2}, collapsed_slice_dims={2}, start_index_map={2}, operand_batching_dims={0,1}, + start_indices_batching_dims={0,1}, index_vector_dim=2, slice_sizes={1,1,1,5}, + indices_are_sorted=true +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA1:.*]] = s64[2,3,1]{{.*}} iota(), iota_dimension=0 + CHECK: %[[IOTA2:.*]] = s64[2,3,1]{{.*}} iota(), iota_dimension=1 + CHECK: %[[INDICES_CONCAT:.*]] = s64[2,3,3]{{.*}} concatenate(%[[IOTA1]], %[[IOTA2]], %start_indices) + CHECK: ROOT %[[GATHER:.*]] = f32[2,3,5]{{.*}} gather( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]]), + CHECK-SAME: offset_dims={2}, + CHECK-SAME: collapsed_slice_dims={0,1,2}, + CHECK-SAME: start_index_map={0,1,2}, + CHECK-SAME: index_vector_dim=2, + CHECK-SAME: slice_sizes={1,1,1,5} + CHECK-SAME: indices_are_sorted=true + )"); +} + +TEST_F(BatchedGatherScatterNormalizerTest, + NormalizeBatchGatherIndicesRemainUnsorted) { + constexpr absl::string_view kModuleStr = R"( +HloModule StringifyGather, entry_computation_layout={(f32[2,3,4,512]{3,2,1,0}, s64[2,3,1]{2,1,0})->f32[2,3,5]{2,1,0}} + +ENTRY %Gather (input_tensor: f32[2,3,4,512], start_indices: s64[2,3,1]) -> f32[2,3,5] { + %input_tensor = f32[2,3,4,512]{3,2,1,0} parameter(0) + %start_indices = s64[2,3,1]{2,1,0} parameter(1) + ROOT %gather = f32[2,3,5]{2,1,0} + gather(f32[2,3,4,512]{3,2,1,0} %input_tensor, s64[2,3,1]{2,1,0} %start_indices), + offset_dims={2}, collapsed_slice_dims={2}, start_index_map={2}, operand_batching_dims={0,1}, + start_indices_batching_dims={0,1}, index_vector_dim=2, slice_sizes={1,1,1,5}, + indices_are_sorted=false +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA1:.*]] = s64[2,3,1]{{.*}} iota(), iota_dimension=0 + CHECK: %[[IOTA2:.*]] = s64[2,3,1]{{.*}} iota(), iota_dimension=1 + CHECK: %[[INDICES_CONCAT:.*]] = s64[2,3,3]{{.*}} concatenate(%[[IOTA1]], %[[IOTA2]], %start_indices) + CHECK: ROOT %[[GATHER:.*]] = f32[2,3,5]{{.*}} gather( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]]), + CHECK-SAME: offset_dims={2}, + CHECK-SAME: collapsed_slice_dims={0,1,2}, + CHECK-SAME: start_index_map={0,1,2}, + CHECK-SAME: index_vector_dim=2, + CHECK-SAME: slice_sizes={1,1,1,5} + CHECK-NOT: indices_are_sorted + )"); +} + TEST_F(BatchedGatherScatterNormalizerTest, NormalizeBatchGatherDimSizeZero) { constexpr absl::string_view kModuleStr = R"( HloModule StringifyGather, entry_computation_layout={(f32[50,49,48,47,46,0]{5,4,3,2,1,0}, s64[10,9,8,7,5,0]{5,4,3,2,1,0})->f32[10,9,8,7,30,29,28,27,26,0]{9,8,7,6,5,4,3,2,1,0}} @@ -180,6 +300,42 @@ ENTRY %Scatter (input_tensor: f32[50,49,48,47,46,512,1024,100], scatter_indices: )"); } +TEST_F(BatchedGatherScatterNormalizerTest, + NormalizeBatchScatterIndicesRemainSorted) { + constexpr absl::string_view kModuleStr = R"( +HloModule StringifyScatter, entry_computation_layout={(f32[2,3,4,512]{3,2,1,0}, s64[2,3,1]{2,1,0}, f32[2,3,5]{2,1,0})->f32[2,3,4,512]{3,2,1,0}} + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +} + +ENTRY %Scatter (input_tensor: f32[2,3,4,512], scatter_indices: s64[2,3,1], updates: f32[2,3,5]) -> f32[2,3,4,512] { + %input_tensor = f32[2,3,4,512]{3,2,1,0} parameter(0) + %scatter_indices = s64[2,3,1]{2,1,0} parameter(1) + %updates = f32[2,3,5]{2,1,0} parameter(2) + ROOT %scatter = f32[2,3,4,512]{3,2,1,0} + scatter(f32[2,3,4,512]{3,2,1,0} %input_tensor, s64[2,3,1]{2,1,0} %scatter_indices, f32[2,3,5]{2,1,0} %updates), + update_window_dims={2}, inserted_window_dims={2}, scatter_dims_to_operand_dims={2}, input_batching_dims={0,1}, + scatter_indices_batching_dims={0,1}, index_vector_dim=2, indices_are_sorted=true, to_apply=%add_F32.v3 +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA1:.*]] = s64[2,3,1]{{.*}} iota(), iota_dimension=0 + CHECK: %[[IOTA2:.*]] = s64[2,3,1]{{.*}} iota(), iota_dimension=1 + CHECK: %[[INDICES_CONCAT:.*]] = s64[2,3,3]{{.*}} concatenate(%[[IOTA1]], %[[IOTA2]], %scatter_indices) + CHECK: ROOT %[[SCATTER:.*]] = f32[2,3,4,512]{{.*}} scatter( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]], %updates), + CHECK-SAME: update_window_dims={2}, + CHECK-SAME: inserted_window_dims={0,1,2}, + CHECK-SAME: scatter_dims_to_operand_dims={0,1,2}, + CHECK-SAME: index_vector_dim=2, + CHECK-SAME: indices_are_sorted=true + CHECK-SAME: to_apply=%add_F32.v3 + )"); +} + TEST_F(BatchedGatherScatterNormalizerTest, NormalizeBatchScatterDimSizeZero) { constexpr absl::string_view kModuleStr = R"( @@ -245,5 +401,121 @@ ENTRY %Gather (input_tensor: f32[50,512,1024], start_indices: s64[10,9,8,7,6,512 )"); } +TEST_F(BatchedGatherScatterNormalizerTest, + BatchingDimSizeDoesNotOverflowIndicesType) { + constexpr absl::string_view kModuleStr = R"( +HloModule StringifyGather, entry_computation_layout={(f32[2,127,512]{2,1,0}, s8[2,127,1]{2,1,0})->f32[2,127,5]{2,1,0}} + +ENTRY %Gather (input_tensor: f32[2,127,512], start_indices: s8[2,127,1]) -> f32[2,127,5] { + %input_tensor = f32[2,127,512]{2,1,0} parameter(0) + %start_indices = s8[2,127,1]{2,1,0} parameter(1) + ROOT %gather = f32[2,127,5]{2,1,0} + gather(f32[2,127,512]{2,1,0} %input_tensor, s8[2,127,1]{2,1,0} %start_indices), + offset_dims={2}, collapsed_slice_dims={}, start_index_map={2}, operand_batching_dims={0,1}, + start_indices_batching_dims={0,1}, index_vector_dim=2, slice_sizes={1,1,5} +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA1:.*]] = s8[2,127,1]{{.*}} iota(), iota_dimension=0 + CHECK: %[[IOTA2:.*]] = s8[2,127,1]{{.*}} iota(), iota_dimension=1 + CHECK: %[[INDICES_CONCAT:.*]] = s8[2,127,3]{{.*}} concatenate(%[[IOTA1]], %[[IOTA2]], %start_indices) + CHECK: ROOT %[[GATHER:.*]] = f32[2,127,5]{{.*}} gather( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]]), + CHECK-SAME: offset_dims={2}, + CHECK-SAME: collapsed_slice_dims={0,1}, + CHECK-SAME: start_index_map={0,1,2}, + CHECK-SAME: index_vector_dim=2, + CHECK-SAME: slice_sizes={1,1,5} + )"); +} + +TEST_F(BatchedGatherScatterNormalizerTest, + BatchingDimSizeOverflowsIndicesType) { + constexpr absl::string_view kModuleStr = R"( +HloModule StringifyGather, entry_computation_layout={(f32[2,128,512]{2,1,0}, s8[2,128,1]{2,1,0})->f32[2,128,5]{2,1,0}} + +ENTRY %Gather (input_tensor: f32[2,128,512], start_indices: s8[2,128,1]) -> f32[2,128,5] { + %input_tensor = f32[2,128,512]{2,1,0} parameter(0) + %start_indices = s8[2,128,1]{2,1,0} parameter(1) + ROOT %gather = f32[2,128,5]{2,1,0} + gather(f32[2,128,512]{2,1,0} %input_tensor, s8[2,128,1]{2,1,0} %start_indices), + offset_dims={2}, collapsed_slice_dims={}, start_index_map={2}, operand_batching_dims={0,1}, + start_indices_batching_dims={0,1}, index_vector_dim=2, slice_sizes={1,1,5} +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA1:.*]] = s32[2,128,1]{{.*}} iota(), iota_dimension=0 + CHECK: %[[IOTA2:.*]] = s32[2,128,1]{{.*}} iota(), iota_dimension=1 + CHECK: %[[CONVERT:.*]] = s32[2,128,1]{{.*}} convert(%start_indices) + CHECK: %[[INDICES_CONCAT:.*]] = s32[2,128,3]{{.*}} concatenate(%[[IOTA1]], %[[IOTA2]], %[[CONVERT]]) + CHECK: ROOT %[[GATHER:.*]] = f32[2,128,5]{{.*}} gather( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]]), + CHECK-SAME: offset_dims={2}, + CHECK-SAME: collapsed_slice_dims={0,1}, + CHECK-SAME: start_index_map={0,1,2}, + CHECK-SAME: index_vector_dim=2, + CHECK-SAME: slice_sizes={1,1,5} + )"); +} + +TEST_F(BatchedGatherScatterNormalizerTest, + BatchingDimSizeOverflowsIndicesTypeAndS32) { + constexpr absl::string_view kModuleStr = R"( +HloModule StringifyGather, entry_computation_layout={(f32[2147483648,2,512]{2,1,0}, s8[2147483648,2,1]{2,1,0})->f32[2147483648,2,5]{2,1,0}} + +ENTRY %Gather (input_tensor: f32[2147483648,2,512], start_indices: s8[2147483648,2,1]) -> f32[2147483648,2,5] { + %input_tensor = f32[2147483648,2,512]{2,1,0} parameter(0) + %start_indices = s8[2147483648,2,1]{2,1,0} parameter(1) + ROOT %gather = f32[2147483648,2,5]{2,1,0} + gather(f32[2147483648,2,512]{2,1,0} %input_tensor, s8[2147483648,2,1]{2,1,0} %start_indices), + offset_dims={2}, collapsed_slice_dims={}, start_index_map={2}, operand_batching_dims={0,1}, + start_indices_batching_dims={0,1}, index_vector_dim=2, slice_sizes={1,1,5} +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA1:.*]] = s64[2147483648,2,1]{{.*}} iota(), iota_dimension=0 + CHECK: %[[IOTA2:.*]] = s64[2147483648,2,1]{{.*}} iota(), iota_dimension=1 + CHECK: %[[CONVERT:.*]] = s64[2147483648,2,1]{{.*}} convert(%start_indices) + CHECK: %[[INDICES_CONCAT:.*]] = s64[2147483648,2,3]{{.*}} concatenate(%[[IOTA1]], %[[IOTA2]], %[[CONVERT]]) + CHECK: ROOT %[[GATHER:.*]] = f32[2147483648,2,5]{{.*}} gather( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]]), + CHECK-SAME: offset_dims={2}, + CHECK-SAME: collapsed_slice_dims={0,1}, + CHECK-SAME: start_index_map={0,1,2}, + CHECK-SAME: index_vector_dim=2, + CHECK-SAME: slice_sizes={1,1,5} + )"); +} + +TEST_F(BatchedGatherScatterNormalizerTest, + BatchingDimSizeOverflowsAndIndexVectorDimOnLastDim) { + constexpr absl::string_view kModuleStr = R"( +HloModule StringifyGather, entry_computation_layout={(f32[2,128,512]{2,1,0}, s8[2,128]{1,0})->f32[2,128,5]{2,1,0}} + +ENTRY %Gather (input_tensor: f32[2,128,512], start_indices: s8[2,128]) -> f32[2,128,5] { + %input_tensor = f32[2,128,512]{2,1,0} parameter(0) + %start_indices = s8[2,128]{1,0} parameter(1) + ROOT %gather = f32[2,128,5]{2,1,0} + gather(f32[2,128,512]{2,1,0} %input_tensor, s8[2,128]{1,0} %start_indices), + offset_dims={2}, collapsed_slice_dims={}, start_index_map={2}, operand_batching_dims={0,1}, + start_indices_batching_dims={0,1}, index_vector_dim=2, slice_sizes={1,1,5} +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA1:.*]] = s32[2,128,1]{{.*}} iota(), iota_dimension=0 + CHECK: %[[IOTA2:.*]] = s32[2,128,1]{{.*}} iota(), iota_dimension=1 + CHECK: %[[CONVERT:.*]] = s32[2,128]{{.*}} convert(%start_indices) + CHECK: %[[RESHAPE:.*]] = s32[2,128,1]{{.*}} reshape(%[[CONVERT]]) + CHECK: %[[INDICES_CONCAT:.*]] = s32[2,128,3]{{.*}} concatenate(%[[IOTA1]], %[[IOTA2]], %[[RESHAPE]]) + CHECK: ROOT %[[GATHER:.*]] = f32[2,128,5]{{.*}} gather( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]]), + CHECK-SAME: offset_dims={2}, + CHECK-SAME: collapsed_slice_dims={0,1}, + CHECK-SAME: start_index_map={0,1,2}, + CHECK-SAME: index_vector_dim=2, + CHECK-SAME: slice_sizes={1,1,5} + )"); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/batchnorm_expander.h b/third_party/xla/xla/service/batchnorm_expander.h index 0ae50afe13eb3c..15738efdc44158 100644 --- a/third_party/xla/xla/service/batchnorm_expander.h +++ b/third_party/xla/xla/service/batchnorm_expander.h @@ -18,6 +18,9 @@ limitations under the License. #include +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" diff --git a/third_party/xla/xla/service/batchnorm_expander_test.cc b/third_party/xla/xla/service/batchnorm_expander_test.cc index e4bb01e9f486da..25cfa87004be01 100644 --- a/third_party/xla/xla/service/batchnorm_expander_test.cc +++ b/third_party/xla/xla/service/batchnorm_expander_test.cc @@ -18,18 +18,17 @@ limitations under the License. #include #include +#include "xla/error_spec.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/utils/hlo_matchers.h" -#include "xla/layout_util.h" -#include "xla/literal.h" -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/bfloat16_conversion_folding.h b/third_party/xla/xla/service/bfloat16_conversion_folding.h index c8bc39a98c4f74..deb5675fc85cbe 100644 --- a/third_party/xla/xla/service/bfloat16_conversion_folding.h +++ b/third_party/xla/xla/service/bfloat16_conversion_folding.h @@ -16,42 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_ #define XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_ -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/service/float_support.h" - -namespace xla { - -// A pass which folds F32 <-> BF16 conversions to their operands or users, when -// it is supported by the backend. -// -// This pass follows the passed-in backend-specific BF16 support rules, but can -// introduce mixed precision in individual HLOs which breaks the assumption of -// some other HLO passes. So it should be used at the end of the HLO -// optimization pipeline followed by a DCE pass. If other passes are needed -// after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the -// changed made by this pass. -class BFloat16ConversionFolding : public HloModulePass { - public: - explicit BFloat16ConversionFolding(const FloatSupport* bfloat16_support) - : bfloat16_support_(bfloat16_support) { - DCHECK(bfloat16_support->LowPrecisionType() == BF16); - } - - ~BFloat16ConversionFolding() override = default; - absl::string_view name() const override { return "bfloat16-fold"; } - - // Run BF16 conversion folding on the given computation. Returns whether the - // computation was changed. - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - const FloatSupport* bfloat16_support_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/bfloat16_conversion_folding.h" #endif // XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_ diff --git a/third_party/xla/xla/service/bfloat16_propagation.h b/third_party/xla/xla/service/bfloat16_propagation.h index 3f292823f6edee..e3a0e0fab40b4c 100644 --- a/third_party/xla/xla/service/bfloat16_propagation.h +++ b/third_party/xla/xla/service/bfloat16_propagation.h @@ -16,216 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_BFLOAT16_PROPAGATION_H_ #define XLA_SERVICE_BFLOAT16_PROPAGATION_H_ -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/service/float_support.h" -#include "xla/service/hlo_dataflow_analysis.h" - -namespace xla { - -// HLO pass which reduces the precision of some HLO instructions to BF16 -// according to the backend-specific FloatSupport rule provided by the -// caller. -// -// This pass can be used to reduce instruction precision without affecting the -// numerical accuracy of the module, i.e., the final output of the module would -// be bitwise identical to that without this pass; this is possible if the -// backend already reduces precision to BF16 on some HLO instructions. -// -// This pass will not modify the signature of a computation, unless it is a -// fusion computation or its only caller is a while. -// -// !!! WARNING !!! This pass can introduce mixed precision in individual HLOs, -// which has two issues: -// -// 1) It does not guarantee to respect the passed-in FloatSupport -// specification in terms of mixed precision, so the backend may not support an -// HLO that has mixed precision produced by this pass. To address this issue, -// run FloatNormalization with the same FloatSupport after this pass. -// -// 2) In general, mixed precision may break the assumptions of some other HLO -// passes even if the specific backend supports the individual HLOs. Such -// assumptions include that there are no HLOs using mixed precision, or that the -// precision of an HLO's output is determined by its inputs. It should be used -// at the end of the HLO optimization pipeline but before -// BFloat16ConversionFolding. If other passes are needed after this pass, run -// BFloat16MixedPrecisionRemoval first to undo some of the changes made by this -// pass. -class BFloat16Propagation : public HloModulePass { - public: - explicit BFloat16Propagation(const FloatSupport* bfloat16_support); - - ~BFloat16Propagation() override = default; - - absl::string_view name() const override { return "bfloat16-propagation"; } - - // Runs the pass on the given module. Returns whether the module was changed - // (precision reductions were added). - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - // Returns whether we should avoid changing the precision of inst regardless - // of the producers and users. - virtual bool ShouldKeepPrecisionUnchanged(const HloInstruction* inst); - - // Determines whether we should consider changing the precision of the given - // instruction in the forward pass. - virtual bool InstructionIsCandidateForBF16Output(HloInstruction* hlo); - - protected: - const FloatSupport* bfloat16_support_; - - private: - // *************************** - // Function called and state produced by the forward analysis pass (from - // parameters to root) that determines the candidate HLOs to use BF16 outputs. - - // The set of instructions to consider using bfloat16, computed in the forward - // pass. - absl::flat_hash_set consider_using_bfloat16_; - - // *************************** - // Functions called and state produced by the backward pass (from root to - // parameters) that finds opportunities to use BF16. - - // Determines the precision for the given instruction in the - // opportunity-finding pass. - void DetermineInstructionPrecision(HloInstruction* hlo, bool skip_parameters); - - // Special handling in the opportunity-finding pass for fusion computations. - // - // Precondition: hlo->opcode() == kFusion - void DetermineFusionComputationPrecision(HloInstruction* fusion); - - // Reverts changes to BF16 that will not propagate outside a fusion - // computation. This avoids BF16 casts overhead inside a fusion which won't - // save memory bandwidth. - // - // Precondition: hlo->opcode() == kFusion - void RevertIfFusionInternalBF16Changes(HloInstruction* fusion); - - // Special handling in the opportunity-finding pass for while computations. - // - // Precondition: hlo->opcode() == kWhile - void DetermineWhileComputationsPrecision(HloInstruction* while_hlo); - - // Special handling in the opportunity-finding pass for conditional branches. - // - // Precondition: hlo->opcode() == kConditional - void DetermineConditionalComputationsPrecision(HloInstruction* cond); - - // The set of HloInstructions that have been visited in the - // opportunity-finding pass. - absl::flat_hash_set - instructions_visited_in_backward_pass_; - - // The set of HloComputations that have been visited in the - // opportunity-finding pass. - absl::flat_hash_set - computations_visited_in_backward_pass_; - - // *************************** - // Functions called by the final inconsistency resolving pass. - - // Adjusts the output shapes of HloInstructions such that if two - // HloInstructions have aliasing buffers in their outputs, they must have the - // same precision. - void ResolveInconsistencyOfAliasingBuffers( - HloModule* module, - const absl::flat_hash_set& execution_threads); - - // Resolves inconsistency of aliasing buffers for the given computation, and - // recursively runs on a while instruction's condition and body until a fixed - // point is reached. - bool ResolveInconsistencyOfAliasingBuffersHelper( - HloComputation* computation, - absl::flat_hash_set* visited_computations); - - // Makes the parameters of called computations match how they are called by - // the given HLO. - void AdjustCalledComputationParameters(HloInstruction* hlo); - - // Makes the root instructions of called computations match how they are used - // by the given HLO. - void AdjustCalledComputationRoot(HloInstruction* hlo); - - // *************************** - // Functions called after changes in changes_to_bf16_ are applied. - - // Resolves inconsistencies introduced by this pass for fusions with - // tuple-type output. - absl::Status ResolveInconsistentFusions( - HloModule* module, - const absl::flat_hash_set& execution_threads); - - // Converts the literals in kConstant HLOs which have their types changed to - // BF16 by this pass. - absl::Status ResolveConvertedConstants( - HloModule* module, - const absl::flat_hash_set& execution_threads); - - // Skips no-op conversions (same source and target shapes) that can be - // produced this pass, i.e., replaces them in their uses with their operands. - absl::Status SkipNoopConversions( - HloModule* module, - const absl::flat_hash_set& execution_threads); - - // *************************** - // Functions called and state used by two or more passes. - - // Returns whether all uses of the given HloInstruction can consume BF16 - // input. - bool AllUsersConsumeBF16(const HloInstruction& hlo, - const ShapeIndex& index) const; - - // The output element type of the HLO at the given shape index after changes - // in changes_to_bf16_ are applied. - PrimitiveType OutputTypeAfterChange(HloInstruction* hlo, - const ShapeIndex& index) const; - - // The element type of the HLO value after changes in changes_to_bf16_ are - // applied. - PrimitiveType ValueTypeAfterChange(const HloValue* value) const; - - // If target_type == BF16, adds the HLO at the given index to - // changes_to_bf16_; otherwise, target_type must be F32 and this function - // removes the HLO at the given index from changes_to_bf16_ if it was earlier - // added. - void AddToOrRemoveFromBF16ChangeSet(HloInstruction* hlo, - const ShapeIndex& index, - PrimitiveType target_type); - - // The set of F32 HLO values that must be kept in F32. - absl::flat_hash_set values_that_must_be_kept_as_f32_; - - // Mapping from each HloComputation to the number of callers to it in the - // module. Populated at the beginning of this pass. - absl::flat_hash_map caller_counts_; - - // We first store the potential F32-to-BF16 changes to changes_to_bf16_, which - // are subject to further adjustment, then finally applied to the HLOs. This - // avoids setting changed_ to true but all changes are reverted during - // adjustment. - // - // For each HloInstruction, changes_to_bf16_ stores the affected buffers in - // the output as a map from in-place pointers to subshapes to shape indices. - absl::flat_hash_map> - changes_to_bf16_; - - // Whether the last processed HLO module has been changed by this pass. - bool changed_ = false; - - std::unique_ptr dataflow_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/bfloat16_propagation.h" #endif // XLA_SERVICE_BFLOAT16_PROPAGATION_H_ diff --git a/third_party/xla/xla/service/bitcast_dtypes_expander.h b/third_party/xla/xla/service/bitcast_dtypes_expander.h index f103c37878a603..7824af39cf5829 100644 --- a/third_party/xla/xla/service/bitcast_dtypes_expander.h +++ b/third_party/xla/xla/service/bitcast_dtypes_expander.h @@ -13,36 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/op_expander_pass.h" - #ifndef XLA_SERVICE_BITCAST_DTYPES_EXPANDER_H_ #define XLA_SERVICE_BITCAST_DTYPES_EXPANDER_H_ -namespace xla { - -// A pass which expands bitcast-convert between differently sized dtypes to a -// reduction. -class BitcastDtypesExpander : public OpExpanderPass { - public: - absl::string_view name() const override { return "bitcast_dtypes_expander"; } - - protected: - bool InstructionMatchesPattern(HloInstruction* instruction) override; - - absl::StatusOr ExpandInstruction( - HloInstruction* instruction) override; - - private: - absl::flat_hash_map computation_cache_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/expanders/bitcast_dtypes_expander.h" #endif // XLA_SERVICE_BITCAST_DTYPES_EXPANDER_H_ diff --git a/third_party/xla/xla/service/broadcast_canonicalizer.h b/third_party/xla/xla/service/broadcast_canonicalizer.h index 0cf3bc71c750e2..efedf3ed3481ab 100644 --- a/third_party/xla/xla/service/broadcast_canonicalizer.h +++ b/third_party/xla/xla/service/broadcast_canonicalizer.h @@ -16,25 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_BROADCAST_CANONICALIZER_H_ #define XLA_SERVICE_BROADCAST_CANONICALIZER_H_ -#include - -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// This transform ensures that dimensions in all broadcast operations are -// sorted. -class BroadcastCanonicalizer : public HloModulePass { - public: - explicit BroadcastCanonicalizer(); - - absl::string_view name() const override { return "broadcast_canonicalizer"; } - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/broadcast_canonicalizer.h" #endif // XLA_SERVICE_BROADCAST_CANONICALIZER_H_ diff --git a/third_party/xla/xla/service/buffer_assignment.cc b/third_party/xla/xla/service/buffer_assignment.cc index 77aa9283aaade0..4ab09252d2fc8e 100644 --- a/third_party/xla/xla/service/buffer_assignment.cc +++ b/third_party/xla/xla/service/buffer_assignment.cc @@ -18,42 +18,54 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include +#include #include #include #include #include #include #include +#include +#include #include #include #include "absl/algorithm/container.h" -#include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" +#include "xla/hlo/analysis/hlo_ordering.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_op_metadata.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/map_util.h" #include "xla/service/buffer_value.h" -#include "xla/service/buffer_value_containers.h" +#include "xla/service/call_graph.h" #include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" #include "xla/service/hlo_value.h" +#include "xla/service/logical_buffer.h" +#include "xla/service/memory_space_assignment/memory_space_assignment.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/types.h" #include "xla/util.h" #include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" #include "tsl/platform/numbers.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -91,8 +103,8 @@ BuildIdToLogicalBufferMap( << "Expected logical buffer to have location information in the proto."; TF_RET_CHECK(id_to_hlo_instruction.contains( logical_buffer_proto.defined_at().instruction_id())) - << "Expected hlo instruction " - << "with the id '" << logical_buffer_proto.defined_at().instruction_id() + << "Expected hlo instruction " << "with the id '" + << logical_buffer_proto.defined_at().instruction_id() << "' in the proto to also exist in the " "HLO module."; // Assumption: An hlo module loaded from an hlo proto @@ -128,7 +140,7 @@ absl::Status GatherComputationsByAllocationType( // be thread-local. std::deque> worklist; worklist.push_back(std::make_pair(module->entry_computation(), - /*is_thread_local*/ false)); + /*is_thread_local=*/false)); // Sets for quickly checking membership. Computations are returned in vectors // for stable iteration. @@ -179,18 +191,9 @@ absl::Status GatherComputationsByAllocationType( case HloOpcode::kAsyncStart: case HloOpcode::kAsyncUpdate: case HloOpcode::kAsyncDone: - // Call, conditional, while, and async operations must be called - // from a computation with global allocations as they may return - // references to buffers inside the called computation which cannot - // be thread-local. - if (is_thread_local) { - return InvalidArgument( - "computation %s cannot contain call/while op because it " - "requires thread-local buffer allocations", - computation->name()); - } - worklist.push_back(std::make_pair(subcomputation, - false)); // Not thread local. + // Call, conditional, while, and async operations inherit their + // thread-locality from their parent computation. + worklist.push_back(std::make_pair(subcomputation, is_thread_local)); break; case HloOpcode::kCustomCall: case HloOpcode::kAllReduce: @@ -204,12 +207,11 @@ absl::Status GatherComputationsByAllocationType( case HloOpcode::kSort: case HloOpcode::kFusion: // Map/reduce etc computations are always thread-local. - worklist.push_back(std::make_pair(subcomputation, - true)); // Thread local. + worklist.push_back(std::make_pair(subcomputation, true)); break; default: return Internal("Unexpected calling opcode: %s", - HloOpcodeString(instruction->opcode())); + HloOpcodeString(instruction->opcode())); } } } @@ -340,9 +342,14 @@ static const HloInstruction* GetOutputInstruction( return nullptr; } -std::string BufferAllocation::ToShortString() const { +std::string BufferAllocation::ToShortString(bool human_readable_size) const { std::string output; - StrAppendFormat(&output, "allocation %d: size %d", index_, size()); + if (human_readable_size) { + StrAppendFormat(&output, "allocation %d: size %s", index_, + HumanReadableNumBytes(size())); + } else { + StrAppendFormat(&output, "allocation %d: size %d", index_, size()); + } if (color() != 0) { StrAppend(&output, ", color ", color()); } @@ -838,6 +845,151 @@ std::string BufferAssignment::ToString() const { return output; } +std::string BufferAssignment::MemoryUsageReport(float percentile, + int64_t more_than_k) const { + std::string output; + int64_t total_size = 0; + for (auto& allocation : allocations_) { + total_size += allocation.size(); + } + absl::StrAppend(&output, "Total bytes used: ", total_size, " (", + HumanReadableNumBytes(total_size), ")\n"); + + absl::StrAppend(&output, "\nAllocations sorted by size:\n\n"); + auto allocations = allocations_; + std::sort(allocations.begin(), allocations.end(), + [](const BufferAllocation& a, const BufferAllocation& b) { + if (a.size() > b.size()) return true; + if (a.size() < b.size()) return false; + return a.index() < b.index(); + }); + + int64_t cumulative_size = 0; + absl::StrAppend( + &output, "cumulative_size; total_size - cumulative_size; allocation\n"); + absl::StrAppend(&output, + "------------------------------------------------------------" + "------------------\n"); + int64_t index = 0; + for (auto& allocation : allocations) { + cumulative_size += allocation.size(); + absl::StrAppend( + &output, + absl::StrFormat("%10s(%3.0f%%); %10s; %s", + HumanReadableNumBytes(cumulative_size), + 100. * cumulative_size / total_size, + HumanReadableNumBytes(total_size - cumulative_size), + allocation.ToShortString(true))); + + // Skip the rest of the allocations if they are less than percentile of the + // total size and not more than k. + if (++index > more_than_k && + total_size - cumulative_size < total_size * percentile) { + absl::StrAppend( + &output, + absl::StrFormat( + "The rest %d allocations are less than %d%% of the total " + "size and not shown.\n", + allocations.size() - index, static_cast(percentile * 100))); + break; + } + } + + absl::StrAppend(&output, + "\n\nAllocations sorted by size with their values:\n"); + for (auto& allocation : allocations) { + if (allocation.assigned_buffers().size() == 1) { + absl::StrAppend(&output, allocation.ToShortString(true)); + } else { + StrAppendFormat( + &output, "%s\n%s\n", allocation.ToShortString(true), + allocation.MemoryUsageReport("\t", percentile, more_than_k)); + } + } + return output; +} + +std::string BufferAllocation::MemoryUsageReport(const std::string& prefix, + float percentile, + int64_t more_than_k) const { + std::string output; + + struct OffsetInfo { + std::vector values; + OffsetSize offset_size; + }; + + // Group the values by their offset in the allocation. + absl::flat_hash_map offset_to_buffers; + for (const auto& element : assigned_buffers_) { + const HloValue* value = element.first; + OffsetInfo& offset_info = offset_to_buffers[element.second.offset]; + offset_info.values.push_back(value); + offset_info.offset_size.offset = element.second.offset; + offset_info.offset_size.size = + std::max(offset_info.offset_size.size, element.second.size); + } + + // Sort the offset infos by the max size of the values in the group. + std::vector sorted_offset_infos; + int64_t total_size = 0; + for (auto& element : offset_to_buffers) { + total_size += element.second.offset_size.size; + sorted_offset_infos.push_back(std::move(element.second)); + } + absl::c_sort(sorted_offset_infos, + [](const OffsetInfo& a, const OffsetInfo& b) { + return a.offset_size.size > b.offset_size.size; + }); + + StrAppend(&output, prefix, + "cumulative_size; size; offset; used_by_n_values; " + "shapes_list\n"); + StrAppend(&output, prefix, + "------------------------------------------------------------\n"); + int64_t cumulative_size = 0; + int64_t index = 0; + for (const auto& offset_info : sorted_offset_infos) { + cumulative_size += offset_info.offset_size.size; + StrAppendFormat(&output, "%s%9s(%3.0f%%); %10s; %12d; %16d; ", prefix, + xla::HumanReadableNumBytes(cumulative_size), + 100. * cumulative_size / total_size, + xla::HumanReadableNumBytes(offset_info.offset_size.size), + offset_info.offset_size.offset, offset_info.values.size()); + + // Count the number of values with the same shape and append them at the end + // of the line. + absl::flat_hash_map shapes; + for (auto& value : offset_info.values) shapes[value->shape().ToString()]++; + + StrAppend( + &output, + absl::StrJoin(shapes, ", ", [](std::string* out, const auto& pair) { + if (pair.second == 1) { + return absl::StrAppend(out, pair.first); + } + return absl::StrAppend(out, pair.second, "×", pair.first); + })); + + StrAppend(&output, "\n"); + + // Skip the rest of the values if they are less than percentile of the + // total size and not more than k. + if (++index > more_than_k && + total_size - cumulative_size < total_size * percentile) { + StrAppendFormat( + &output, + "%sThe rest %d values are less than %d%% of the total size and not " + "shown.\n", + prefix, sorted_offset_infos.size() - index, + static_cast(percentile * 100)); + break; + } + } + + return output; +} + // Returns the largest k buffers present at the point of peak memory usage // across allocations as a vector of pairs with their corresponding sizes. std::vector> TopKPeakBuffers( diff --git a/third_party/xla/xla/service/buffer_assignment.h b/third_party/xla/xla/service/buffer_assignment.h index 337d9faa9f64ac..ebc6176d3abbac 100644 --- a/third_party/xla/xla/service/buffer_assignment.h +++ b/third_party/xla/xla/service/buffer_assignment.h @@ -33,6 +33,9 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" +#include "xla/hlo/analysis/hlo_ordering.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -42,10 +45,7 @@ limitations under the License. #include "xla/service/call_graph.h" #include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" -#include "xla/service/hlo_dataflow_analysis.h" -#include "xla/service/hlo_ordering.h" #include "xla/service/hlo_value.h" #include "xla/service/logical_buffer.h" #include "xla/service/memory_space_assignment/memory_space_assignment.h" @@ -228,7 +228,18 @@ class BufferAllocation { Slice GetSlice(const HloValue& buffer) const; std::string ToString() const; - std::string ToShortString() const; + std::string ToShortString(bool human_readable_size = false) const; + std::string ValuesToString() const; + + // The function returns memory usage report for the values belonging to the + // buffer allocation. The values are grouped by their offset in the + // allocation. The groups are sorted by the max size(Z-A) of the values in the + // group. Percentile and more_than_k are used to control the number of groups + // being reported. + std::string MemoryUsageReport(const std::string& prefix, + float percentile = 0.05, + int64_t more_than_k = 50) const; + BufferAllocationProto ToProto() const; // Whether the buffer is a parameter to or live out of the entry computation. @@ -486,10 +497,18 @@ class BufferAssignment { // Returns the HloLiveRange object used to construct this assignment. const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; } + // Is in use by many compilers to dump the buffer-assignment info. std::string ToString() const; + + // Returns a memory usage report with the list of buffer allocations ordered + // by the size(Z-A) and the values assigned to each buffer allocation. + std::string MemoryUsageReport(float percentile = 0.05, + int64_t more_than_k = 50) const; // Verbose string tailored to debugging OOMs, includes the Hlo op metadata for // every buffer associated with each allocation. std::string ToVerboseString(size_t max_buffers_to_show) const; + + // Is in use by tpu compiler to dump the buffer info. std::string BufferInfoString() const; // Convert BufferAssignment to or from a proto. diff --git a/third_party/xla/xla/service/buffer_assignment_test.cc b/third_party/xla/xla/service/buffer_assignment_test.cc index 04238c4fd39f5a..7435d0d65c4fd3 100644 --- a/third_party/xla/xla/service/buffer_assignment_test.cc +++ b/third_party/xla/xla/service/buffer_assignment_test.cc @@ -15,40 +15,49 @@ limitations under the License. #include "xla/service/buffer_assignment.h" +#include #include #include -#include #include #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/comparison_util.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/analysis/hlo_ordering.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/transforms/simplifiers/flatten_call_graph.h" +#include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" #include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/service/buffer_value.h" #include "xla/service/call_graph.h" #include "xla/service/copy_insertion.h" -#include "xla/service/flatten_call_graph.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_alias_analysis.h" -#include "xla/service/hlo_dce.h" -#include "xla/service/hlo_memory_scheduler.h" -#include "xla/service/hlo_ordering.h" -#include "xla/service/hlo_parser.h" -#include "xla/service/hlo_proto_util.h" +#include "xla/service/hlo_buffer.h" +#include "xla/service/hlo_value.h" +#include "xla/service/logical_buffer.h" +#include "xla/service/memory_space_assignment/memory_space_assignment.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" namespace xla { @@ -1622,6 +1631,67 @@ TEST_F(BufferAssignmentTest, CustomCallEmbeddedComputationBuffers) { EXPECT_TRUE(map_root_alloc.is_thread_local()); } +TEST_F(BufferAssignmentTest, CustomCallSubcomputationBuffers) { + // Verify that buffers for subcomputations in a custom call are properly + // marked as thread-local. + auto module = CreateNewVerifiedModule(); + auto scalar_shape = ShapeUtil::MakeShape(F32, {}); + + auto subcomputation_builder = + HloComputation::Builder(TestName() + "_subcomputation"); + auto subcomputation_param = subcomputation_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "subcomputation_param")); + auto subcomputation_root = + subcomputation_builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape, HloOpcode::kNegate, subcomputation_param)); + auto subcomputation = + module->AddEmbeddedComputation(subcomputation_builder.Build()); + + // Create a scalar computation to use in a map. + auto map_builder = HloComputation::Builder(TestName() + "_map"); + auto map_param = map_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "map_param")); + auto map_root = map_builder.AddInstruction( + HloInstruction::CreateCall(scalar_shape, {map_param}, subcomputation)); + auto map_computation = module->AddEmbeddedComputation(map_builder.Build()); + + // Create entry computation with a custom call on map_computation. + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + builder.AddInstruction(HloInstruction::CreateCustomCall( + scalar_shape, {param}, map_computation, "call_name")); + module->AddEntryComputation(builder.Build()); + + auto assignment = RunBufferAssignment(module.get()); + + // Allocations for the map computation should be thread-local and not + // live-out. + auto& map_param_alloc = GetTopLevelAllocation(*assignment, map_param); + EXPECT_FALSE(map_param_alloc.is_entry_computation_parameter()); + EXPECT_FALSE(map_param_alloc.maybe_live_out()); + EXPECT_TRUE(map_param_alloc.is_thread_local()); + + auto& map_root_alloc = GetTopLevelAllocation(*assignment, map_root); + EXPECT_FALSE(map_root_alloc.is_entry_computation_parameter()); + EXPECT_FALSE(map_root_alloc.maybe_live_out()); + EXPECT_TRUE(map_root_alloc.is_thread_local()); + + // Allocations for the subcomputation should be thread-local and not + // live-out. + auto& subcomputation_param_alloc = + GetTopLevelAllocation(*assignment, subcomputation_param); + EXPECT_FALSE(subcomputation_param_alloc.is_entry_computation_parameter()); + EXPECT_FALSE(subcomputation_param_alloc.maybe_live_out()); + EXPECT_TRUE(subcomputation_param_alloc.is_thread_local()); + + auto& subcomputation_root_alloc = + GetTopLevelAllocation(*assignment, subcomputation_root); + EXPECT_FALSE(subcomputation_root_alloc.is_entry_computation_parameter()); + EXPECT_FALSE(subcomputation_root_alloc.maybe_live_out()); + EXPECT_TRUE(subcomputation_root_alloc.is_thread_local()); +} + TEST_F(BufferAssignmentTest, TupleParameterAsOutput) { // Test a computation that returns a tuple parameter. auto builder = HloComputation::Builder(TestName()); diff --git a/third_party/xla/xla/service/buffer_value_containers.h b/third_party/xla/xla/service/buffer_value_containers.h index 2e02dd8df7dec3..9b2cfaffee730b 100644 --- a/third_party/xla/xla/service/buffer_value_containers.h +++ b/third_party/xla/xla/service/buffer_value_containers.h @@ -19,7 +19,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "xla/service/buffer_value.h" #include "xla/service/logical_buffer.h" -#include "tsl/lib/gtl/compactptrset.h" +#include "xla/tsl/lib/gtl/compactptrset.h" namespace xla { diff --git a/third_party/xla/xla/service/call_inliner.cc b/third_party/xla/xla/service/call_inliner.cc index 0605fbd6457ff7..1fb8652110a77c 100644 --- a/third_party/xla/xla/service/call_inliner.cc +++ b/third_party/xla/xla/service/call_inliner.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/call_inliner.h" #include +#include #include #include @@ -24,16 +25,19 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/match.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding_metadata.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/service/call_graph.h" -#include "xla/service/hlo_dce.h" #include "xla/service/hlo_domain_isolator.h" +#include "xla/service/spmd/shardy/constants.h" #include "xla/status_macros.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -133,6 +137,29 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { CallInliner::InlinedInstructionMap subcomputation_hlo_to_new_hlo_; }; +// Specific inlining rules when needing to round-trip from MLIR->HLO->MLIR when +// using Shardy (github.com/openxla/shardy). +// +// - shmap_body: We don't want to inline the bodies of JAX shard maps in order +// to import them into an `sdy.ManualComputationOp`. This is for the MHLO +// round-trip pipeline +// - kManualComputationBodyFuncName: Same as shmap_body except for the SDY +// round-trip pipeline. +bool InlineUnderShardy(HloInstruction* instruction) { + return !(instruction->GetModule()->config().use_shardy_partitioner() && + (absl::StrContains(instruction->to_apply()->name(), "shmap_body") || + absl::StartsWith(instruction->to_apply()->name(), + sdy::kManualComputationBodyFuncName.str()))); +} + +bool InlineComposites( + HloInstruction* instruction, + const absl::flat_hash_set& composites_to_preserve) { + return !instruction->is_composite() || + !composites_to_preserve.contains( + instruction->frontend_attributes().map().at("composite.name")); +} + } // namespace /* static */ absl::StatusOr @@ -152,6 +179,29 @@ CallInliner::Inline(HloInstruction* call) { const auto& callees = call->called_computations(); TF_RET_CHECK(callees.size() == 1); HloComputation* callee = callees[0]; + + // Propagate the frontend attributes related to fusion from the call to the + // inlined instructions. + if (call->has_frontend_attributes()) { + const FrontendAttributes& call_attributes = call->frontend_attributes(); + std::string has_fuse = + call_attributes.map().contains("MUST_FUSE") ? "MUST_FUSE" + : call_attributes.map().contains("MAXIMAL_FUSE") ? "MAXIMAL_FUSE" + : ""; + if (!has_fuse.empty()) { + for (auto instruction : callee->instructions()) { + // Do so for only fusible instructions. + if (instruction->IsFusible()) { + FrontendAttributes frontend_attributes = + instruction->frontend_attributes(); + frontend_attributes.mutable_map()->insert( + {has_fuse, call_attributes.map().at(has_fuse)}); + instruction->set_frontend_attributes(frontend_attributes); + } + } + } + } + // We visit the callee, cloning its body into its caller. SubcomputationInsertionVisitor visitor(call); TF_RETURN_IF_ERROR(callee->Accept(&visitor)); @@ -160,7 +210,10 @@ CallInliner::Inline(HloInstruction* call) { bool CallInliner::IsInlineableCallOp(HloInstruction* instruction) const { return instruction->opcode() == HloOpcode::kCall && - !instruction->parent()->IsAsyncComputation(); + !instruction->has_backend_config() && + !instruction->parent()->IsAsyncComputation() && + InlineUnderShardy(instruction) && + InlineComposites(instruction, composites_to_preserve_); } absl::StatusOr CallInliner::Run( diff --git a/third_party/xla/xla/service/call_inliner.h b/third_party/xla/xla/service/call_inliner.h index 7fd584ad5eeba8..3eb2b7f1175702 100644 --- a/third_party/xla/xla/service/call_inliner.h +++ b/third_party/xla/xla/service/call_inliner.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef XLA_SERVICE_CALL_INLINER_H_ #define XLA_SERVICE_CALL_INLINER_H_ +#include +#include + #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -41,9 +44,12 @@ class CallInliner : public HloModulePass { // inlined. // If update_domain is true, the exit domains could be updated for calls which // are being inlined if necessary. - explicit CallInliner(bool single_call_site = false, - bool update_domain = false) - : single_call_site_(single_call_site), update_domain_(update_domain) {} + explicit CallInliner( + bool single_call_site = false, bool update_domain = false, + absl::flat_hash_set composites_to_preserve = {}) + : single_call_site_(single_call_site), + update_domain_(update_domain), + composites_to_preserve_(std::move(composites_to_preserve)) {} ~CallInliner() override = default; absl::string_view name() const override { return "call-inliner"; } @@ -59,6 +65,7 @@ class CallInliner : public HloModulePass { private: bool single_call_site_; bool update_domain_; + absl::flat_hash_set composites_to_preserve_; }; } // namespace xla diff --git a/third_party/xla/xla/service/call_inliner_test.cc b/third_party/xla/xla/service/call_inliner_test.cc index ad6ee73eb14e8a..b41606d1a93e75 100644 --- a/third_party/xla/xla/service/call_inliner_test.cc +++ b/third_party/xla/xla/service/call_inliner_test.cc @@ -16,17 +16,16 @@ limitations under the License. #include "xla/service/call_inliner.h" #include -#include #include #include "absl/log/log.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/service/hlo_parser.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" @@ -377,5 +376,118 @@ TEST_F(CallInlinerTest, InlineCompositeCall) { EXPECT_TRUE((*inst)->frontend_attributes().map().empty()); } +TEST_F(CallInlinerTest, PreserveCompositeCall) { + const absl::string_view hlo_string = R"( + HloModule composite + + %add (lhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] constant(2) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) + } + + ENTRY %main () -> f32[] { + %lhs = f32[] constant(42) + ROOT %call = f32[] call(f32[] %lhs), to_apply=%add, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="1"} + })"; + + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + CallInliner call_inliner( + /*single_call_site=*/true, /*update_domain=*/false, + /*composites_to_preserve=*/{"foo.bar"}); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + ASSERT_FALSE(mutated); + + auto inst = module->entry_computation()->instructions().begin(); + EXPECT_THAT(*inst, op::Constant()); + ++inst; + EXPECT_THAT(*inst, op::Call()); + EXPECT_FALSE((*inst)->frontend_attributes().map().empty()); +} + +TEST_F(CallInlinerTest, UseShardyMhloToHloShmapBodyNotInlined) { + const char* const hloString = R"( + HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}} + + %prefix_shmap_body_suffix.4 (Arg_0.5: f32[1,8]) -> f32[1,8] { + %Arg_0.5 = f32[1,8]{1,0} parameter(0) + ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11} + } + + ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] { + %Arg_0.1 = f32[8,8]{1,0} parameter(0) + %custom-call.2 = f32[8,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="Sharding", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=3} + %custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %custom-call.2), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4} + %call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%prefix_shmap_body_suffix.4 + %custom-call.8 = f32[1,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="Sharding", sharding={manual}, metadata={source_file="-" source_line=6} + ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %custom-call.8), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7} + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString)); + module->mutable_config().set_use_shardy_partitioner(true); + TF_ASSERT_OK_AND_ASSIGN(bool changed, CallInliner().Run(module.get())); + VLOG(1) << module->ToString(); + // The single call in the module is not inlined. + EXPECT_FALSE(changed); + + HloInstruction* call = FindInstruction(module.get(), xla::HloOpcode::kCall); + EXPECT_NE(call, nullptr); + EXPECT_TRUE(call->has_to_apply()); + EXPECT_EQ(call->to_apply()->name(), "prefix_shmap_body_suffix.4"); +} + +// Don't inline when the name starts with "xla.sdy.manual_computation_body". +TEST_F(CallInlinerTest, UseShardManualComputationBodyNotInlined) { + const char* const hloString = R"( + HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}} + + %xla.sdy.manual_computation_body.4 (Arg_0.5: f32[1,8]) -> f32[1,8] { + %Arg_0.5 = f32[1,8]{1,0} parameter(0) + ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11} + } + + ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] { + %Arg_0.1 = f32[8,8]{1,0} parameter(0) + %custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4} + %call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%xla.sdy.manual_computation_body.4 + ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7} + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString)); + module->mutable_config().set_use_shardy_partitioner(true); + TF_ASSERT_OK_AND_ASSIGN(bool changed, CallInliner().Run(module.get())); + VLOG(1) << module->ToString(); + // The single call in the module is not inlined. + EXPECT_FALSE(changed); + + HloInstruction* call = FindInstruction(module.get(), xla::HloOpcode::kCall); + EXPECT_NE(call, nullptr); + EXPECT_TRUE(call->has_to_apply()); + EXPECT_EQ(call->to_apply()->name(), "xla.sdy.manual_computation_body.4"); +} + +// Inliner only checks if the name of the function has +// "xla.sdy.manual_computation_body" a prefix, not if it contains it. +TEST_F(CallInlinerTest, UseShardManualComputationBodyInlined) { + const char* const hloString = R"( + HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}} + + %prefix_xla.sdy.manual_computation_body.4 (Arg_0.5: f32[1,8]) -> f32[1,8] { + %Arg_0.5 = f32[1,8]{1,0} parameter(0) + ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11} + } + + ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] { + %Arg_0.1 = f32[8,8]{1,0} parameter(0) + %custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4} + %call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%prefix_xla.sdy.manual_computation_body.4 + ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7} + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString)); + module->mutable_config().set_use_shardy_partitioner(true); + TF_ASSERT_OK_AND_ASSIGN(bool changed, CallInliner().Run(module.get())); + VLOG(1) << module->ToString(); + // Will be inlined. + EXPECT_TRUE(changed); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/change_op_data_type.cc b/third_party/xla/xla/service/change_op_data_type.cc index 3c7875a2836ceb..365308ae24cd50 100644 --- a/third_party/xla/xla/service/change_op_data_type.cc +++ b/third_party/xla/xla/service/change_op_data_type.cc @@ -63,12 +63,7 @@ absl::StatusOr ChangeOpDataType::Run( continue; } #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) - if (instr->opcode() == HloOpcode::kDot && - cpu::OneDnnContractionRewriter::ShouldRewriteDot(instr, true)) { - continue; - } - if (instr->opcode() == HloOpcode::kConvolution && - cpu::OneDnnContractionRewriter::ShouldRewriteConv(instr)) { + if (cpu::OneDnnContractionRewriter::ShouldRewriteInstr(instr, true)) { continue; } #endif // INTEL_MKL && ENABLE_ONEDNN_V3 diff --git a/third_party/xla/xla/service/cholesky_expander.h b/third_party/xla/xla/service/cholesky_expander.h index 3178d36e949b19..7e9e7332e917f0 100644 --- a/third_party/xla/xla/service/cholesky_expander.h +++ b/third_party/xla/xla/service/cholesky_expander.h @@ -16,33 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_CHOLESKY_EXPANDER_H_ #define XLA_SERVICE_CHOLESKY_EXPANDER_H_ -#include "absl/container/flat_hash_map.h" -#include "xla/client/xla_builder.h" -#include "xla/service/op_expander_pass.h" - -namespace xla { - -class CholeskyExpander : public OpExpanderPass { - public: - absl::string_view name() const override { return "cholesky_expander"; } - - protected: - bool InstructionMatchesPattern(HloInstruction* instruction) override; - - absl::StatusOr ExpandInstruction( - HloInstruction* instruction) override; - - virtual absl::StatusOr> CholeskyUnblocked( - XlaOp a, PrecisionConfig::Precision precision); - - private: - XlaOp BuildCholesky(XlaOp a, int64_t block_size, - PrecisionConfig::Precision precision); - - // Mapping from op signatures to existing computations. - absl::flat_hash_map computation_cache_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/expanders/cholesky_expander.h" #endif // XLA_SERVICE_CHOLESKY_EXPANDER_H_ diff --git a/third_party/xla/xla/service/collective_combiner_utils.h b/third_party/xla/xla/service/collective_combiner_utils.h index 29ef053dc6d60b..5fb45edf907d2e 100644 --- a/third_party/xla/xla/service/collective_combiner_utils.h +++ b/third_party/xla/xla/service/collective_combiner_utils.h @@ -28,8 +28,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/hlo/analysis/hlo_reachability.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_reachability.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/collective_ops_utils.cc b/third_party/xla/xla/service/collective_ops_utils.cc index 3f2cbec7778a2b..caa3a951c1af8b 100644 --- a/third_party/xla/xla/service/collective_ops_utils.cc +++ b/third_party/xla/xla/service/collective_ops_utils.cc @@ -596,6 +596,7 @@ bool IsNonFusionCollective(const HloInstruction* instruction) { case HloOpcode::kCollectivePermute: case HloOpcode::kCollectivePermuteStart: case HloOpcode::kCollectivePermuteDone: + case HloOpcode::kRaggedAllToAll: case HloOpcode::kReduceScatter: return true; case HloOpcode::kAsyncStart: diff --git a/third_party/xla/xla/service/collective_ops_utils_test.cc b/third_party/xla/xla/service/collective_ops_utils_test.cc index eaa2941b47ea4f..6ec2cf7bb742cc 100644 --- a/third_party/xla/xla/service/collective_ops_utils_test.cc +++ b/third_party/xla/xla/service/collective_ops_utils_test.cc @@ -28,9 +28,9 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/service/computation_placer.h" #include "xla/service/global_device_id.h" -#include "xla/service/hlo_parser.h" #include "xla/shape_util.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/collective_opt_utils.cc b/third_party/xla/xla/service/collective_opt_utils.cc index 8da8a37155b87f..f2c0c411c88dc9 100644 --- a/third_party/xla/xla/service/collective_opt_utils.cc +++ b/third_party/xla/xla/service/collective_opt_utils.cc @@ -318,22 +318,18 @@ std::optional MatchReduceScatter( return spec; } -bool AllGatherDynamicSliceCancellation( +std::optional AllGatherDynamicSliceCancellation( const HloAllGatherInstruction* ag, int64_t num_partitions, int64_t num_replicas, bool allow_multiple_split_dims, bool allow_intervening_reshape, int64_t min_rank, HloPredicate match_partition_id, HloPredicate match_replica_id, bool allow_intervening_bitcast, bool allow_multiple_users) { - auto spec = MatchWithDynamicSlice( + return MatchWithDynamicSlice( ag, num_partitions, num_replicas, allow_multiple_split_dims, allow_intervening_reshape, min_rank, match_partition_id, match_replica_id, ag->constrain_layout(), ag->use_global_device_ids(), ag->channel_id() && ag->opcode() == HloOpcode::kAllGather, allow_intervening_bitcast, allow_multiple_users); - if (spec.has_value()) { - return true; - } - return false; } std::optional MatchWithDynamicSlice( diff --git a/third_party/xla/xla/service/collective_opt_utils.h b/third_party/xla/xla/service/collective_opt_utils.h index 6131028d5f684c..d49a47ab747bb6 100644 --- a/third_party/xla/xla/service/collective_opt_utils.h +++ b/third_party/xla/xla/service/collective_opt_utils.h @@ -44,7 +44,7 @@ std::optional MatchReduceScatter( bool allow_intervening_bitcast = false); // Check whether AG(ICI) and its user DS(ICI) can be canceled out. -bool AllGatherDynamicSliceCancellation( +std::optional AllGatherDynamicSliceCancellation( const HloAllGatherInstruction* ag, int64_t num_partitions, int64_t num_replicas, bool allow_multiple_split_dims = false, bool allow_intervening_reshape = false, int64_t min_rank = 1, diff --git a/third_party/xla/xla/service/collective_permute_decomposer_test.cc b/third_party/xla/xla/service/collective_permute_decomposer_test.cc index eac5ab0707418a..c31a1f1a1d5725 100644 --- a/third_party/xla/xla/service/collective_permute_decomposer_test.cc +++ b/third_party/xla/xla/service/collective_permute_decomposer_test.cc @@ -22,11 +22,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/hlo_parser.h" #include "xla/tests/hlo_test_base.h" namespace xla { @@ -345,7 +345,8 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) { select = f32[2,2] select(broadcast, cp_back, cp_forward) - matmul = f32[2,2] dot(weights, select), lhs_contracting_dims={1}, rhs_contracting_dims={0} + matmul = f32[2,2] dot(weights, select), lhs_contracting_dims={1}, + rhs_contracting_dims={0} ROOT result = (u32[], f32[2,2], f32[2,2]) tuple(next_iter, matmul, weights) } @@ -361,8 +362,10 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) { start_iter = u32[] constant(0) input_data = f32[2,2] parameter(0) input_weights = f32[2,2] parameter(1) - input = (u32[], f32[2,2], f32[2,2]) tuple(start_iter, input_data, input_weights) - while_result = (u32[], f32[2,2], f32[2,2]) while(input), condition=while_cond, body=while_body + input = (u32[], f32[2,2], f32[2,2]) tuple(start_iter, input_data, + input_weights) + while_result = (u32[], f32[2,2], f32[2,2]) while(input), + condition=while_cond, body=while_body ROOT data_out = f32[2,2] get-tuple-element(while_result), index=1 } )"; @@ -378,7 +381,9 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) { // an XLA invariant that shouldn't be broken (see // https://openxla.org/xla/operation_semantics#send for details of the // semantics). - HloInstruction* recv_bwd = FindInstruction(transformed_module, "recv"); + HloComputation* while_body = + FindComputation(transformed_module, "while_body"); + HloInstruction* recv_bwd = hlo_query::FindInstruction(while_body, "recv"); EXPECT_EQ(recv_bwd->channel_id().value(), 1); auto recv_bwd_frontend_attributes = recv_bwd->frontend_attributes().map(); EXPECT_EQ(recv_bwd_frontend_attributes.size(), 3); @@ -388,12 +393,12 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) { EXPECT_EQ(recv_bwd_frontend_attributes.at(kSendRecvSourceTargetPairsAttr), "{{3,0}}"); - HloInstruction* send_bwd = FindInstruction(transformed_module, "send"); + HloInstruction* send_bwd = hlo_query::FindInstruction(while_body, "send"); auto send_bwd_frontend_attributes = send_bwd->frontend_attributes().map(); EXPECT_THAT(send_bwd_frontend_attributes.at(kSendRecvSourceTargetPairsAttr), "{{3,0}}"); - HloInstruction* recv_fwd = FindInstruction(transformed_module, "recv.1"); + HloInstruction* recv_fwd = hlo_query::FindInstruction(while_body, "recv.1"); EXPECT_EQ(recv_fwd->channel_id().value(), 2); auto recv_fwd_frontend_attributes = recv_fwd->frontend_attributes().map(); EXPECT_EQ(recv_fwd_frontend_attributes.size(), 3); @@ -401,31 +406,18 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) { EXPECT_EQ(recv_fwd_frontend_attributes.at(kSendRecvSourceTargetPairsAttr), "{{0,1},{1,2},{2,3}}"); - HloInstruction* send_fwd = FindInstruction(transformed_module, "send.1"); + HloInstruction* send_fwd = hlo_query::FindInstruction(while_body, "send.1"); auto send_fwd_frontend_attributes = send_fwd->frontend_attributes().map(); EXPECT_EQ(send_fwd_frontend_attributes.size(), 3); EXPECT_EQ(send_fwd_frontend_attributes.at(kSendRecvPipelineAttr), "1"); EXPECT_EQ(send_fwd_frontend_attributes.at(kSendRecvSourceTargetPairsAttr), "{{0,1},{1,2},{2,3}}"); - HloComputation* while_body = - FindComputation(transformed_module, "while_body"); EXPECT_NE(while_body, nullptr); - EXPECT_TRUE(hlo_query::IsBeforeInComputation(while_body, "recv", "send")); - EXPECT_TRUE( - hlo_query::IsBeforeInComputation(while_body, "recv", "recv-done")); - EXPECT_TRUE( - hlo_query::IsBeforeInComputation(while_body, "send", "recv-done")); - EXPECT_TRUE( - hlo_query::IsBeforeInComputation(while_body, "send", "send-done")); - EXPECT_TRUE( - hlo_query::IsBeforeInComputation(while_body, "send-done", "send-done.1")); - EXPECT_TRUE( - hlo_query::IsBeforeInComputation(while_body, "recv-done", "send-done.1")); - EXPECT_TRUE(hlo_query::IsBeforeInComputation(while_body, "recv-done.1", - "send-done.1")); - auto recv_done_fwd = FindInstruction(transformed_module, "recv-done"); - auto recv_done_bwd = FindInstruction(transformed_module, "recv-done.1"); + HloInstruction* recv_done_fwd = + hlo_query::FindInstruction(while_body, "recv-done"); + HloInstruction* recv_done_bwd = + hlo_query::FindInstruction(while_body, "recv-done.1"); // TODO: b/356201477 - Investigate potential NCCL deadlock in // collective_permute_decomposer diff --git a/third_party/xla/xla/service/collective_pipeliner.cc b/third_party/xla/xla/service/collective_pipeliner.cc index e32749ceb7f2f8..eb5c6360da70db 100644 --- a/third_party/xla/xla/service/collective_pipeliner.cc +++ b/third_party/xla/xla/service/collective_pipeliner.cc @@ -40,6 +40,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/comparison_util.h" +#include "xla/hlo/analysis/tuple_points_to_analysis.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/dfs_hlo_visitor.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -48,6 +49,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction_utils.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/literal.h" #include "xla/literal_util.h" @@ -56,9 +59,6 @@ limitations under the License. #include "xla/service/call_graph.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/constant_value.h" -#include "xla/service/hlo_dce.h" -#include "xla/service/hlo_parser.h" -#include "xla/service/tuple_points_to_analysis.h" #include "xla/service/value_range.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -275,8 +275,8 @@ bool CollectSimpleDependencies(HloInstruction* i, for (HloInstruction* op : i->mutable_operands()) { absl::InlinedVector to_add; if (op->opcode() == HloOpcode::kBroadcast) { - to_add.push_back(op); if (deps_set.insert(op).second) { + to_add.push_back(op); op = op->mutable_operand(0); if (op->opcode() == HloOpcode::kConstant) { if (deps_set.insert(op).second) { @@ -318,6 +318,7 @@ CheckStoreIntoSliceIsCompatible(HloInstruction* instr, absl::flat_hash_set added_instructions; HloInstruction* folded_instr = instr; std::vector formatting_ops; + absl::flat_hash_set formatting_set; // Returns if this is an acceptable user of a pipelined instruction. // Generic elementwise ops can have multiple operands that require the inputs // of being saved across the loop. So protect them through @@ -411,11 +412,12 @@ CheckStoreIntoSliceIsCompatible(HloInstruction* instr, auto& data = stack.back(); HloInstruction* instr = data.first; if (data.second == 0 && instr != folded_instr) { - if (!CollectSimpleDependencies(instr, formatting_ops, - added_instructions)) { + if (!CollectSimpleDependencies(instr, formatting_ops, formatting_set)) { return empty_pair; } - formatting_ops.push_back(instr); + if (formatting_set.insert(instr).second) { + formatting_ops.push_back(instr); + } } if (data.second == instr->user_count()) { stack.pop_back(); @@ -1656,7 +1658,8 @@ absl::Status TransformLoopForward( int64_t level_to_operate_on, bool pipeline_use_tree, bool process_different_sized_ops, HloPredicate should_process, HloPredicate acceptable_formatting, HloPredicate reuse_output_buffer, - int64_t& next_channel_id) { + int64_t& next_channel_id, + CollectivePipeliner::HloPostprocessor post_processing_fn) { // Defining some maps/sets to keep track of instructions duplicated. InstructionMap while_body_to_peeled; absl::flat_hash_set to_skip_set; @@ -1879,7 +1882,8 @@ absl::Status TransformLoopForward( base->shape(), base, to_insert, indices)); }; auto process_slice = - [&next_channel_id, insert_non_alias_custom_call, level_to_operate_on]( + [&next_channel_id, &post_processing_fn, insert_non_alias_custom_call, + level_to_operate_on]( HloInstruction* stacked_data, const InstructionMap& pipelined_values_map, const WhileMoveInfo& move_info) -> absl::StatusOr { @@ -1897,6 +1901,10 @@ absl::Status TransformLoopForward( CollectivePipeliner::kInsertedByPreviousStep)); } + if (post_processing_fn.has_value()) { + TF_RETURN_IF_ERROR((*post_processing_fn)(processed)); + } + InstructionMap cloned_map = pipelined_values_map; cloned_map[move_info.collectives_to_move.front()] = processed; for (auto* formatting_op : move_info.formatting_ops) { @@ -2330,9 +2338,9 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, // Create the new tuple with the original while tuple size. std::vector new_output_tuple; new_output_tuple.resize(operands_indices_count, nullptr); + InstructionMap pipelined_map; // Reproduce computation to the output after the loop on the full shape. for (auto& to_move : loop_analysis.GetMoveInfos()) { - InstructionMap pipelined_map; for (int64_t i = 0; i < to_move.collectives_to_move.size(); ++i) { HloInstruction* collective = to_move.collectives_to_move[i]; int64_t gte_index = collective_to_new_tuple_index[collective]; @@ -2419,6 +2427,9 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, // an effect on the instruction itself (like say broadcast, slices ... // etc). for (HloInstruction* formatting_op : to_move.formatting_ops) { + if (pipelined_map.contains(formatting_op)) { + continue; + } if (!to_add_batch_set.contains(formatting_op) && formatting_op->opcode() != HloOpcode::kBroadcast) { HloInstruction* cloned_not_to_batch = loop_computation->AddInstruction( @@ -2614,12 +2625,14 @@ static absl::Status TransformLoopBackward( HloPredicate should_process, HloPredicate acceptable_formatting, CollectivePipeliner::HloPostprocessor postprocess_peeled, CollectivePipeliner::HloPostprocessor postprocess_rotated, - int64_t& next_channel_id) { + int64_t& next_channel_id, + CollectivePipeliner::HloPostprocessor post_processing_fn) { // Defining some maps/sets to keep track of instructions duplicated. absl::flat_hash_map while_body_to_peeled; absl::flat_hash_map collective_to_move_map; absl::flat_hash_set is_pipelined_instruction; - absl::flat_hash_map is_output_instruction; + absl::flat_hash_map> + is_output_instruction; absl::flat_hash_set sideeffect_unused_instructions; int64_t count = 0; // Add instructions to duplicate into a set. @@ -2659,8 +2672,8 @@ static absl::Status TransformLoopBackward( CHECK_EQ(while_body->root_instruction()->opcode(), HloOpcode::kTuple); // Record instructions that are part of the output of the loop. for (int i = 0; i < while_body->root_instruction()->operand_count(); ++i) { - is_output_instruction[while_body->root_instruction()->mutable_operand(i)] = - i; + is_output_instruction[while_body->root_instruction()->mutable_operand(i)] + .push_back(i); } // Collect the new parameter shapes with the additional state for the indices @@ -2713,6 +2726,9 @@ static absl::Status TransformLoopBackward( *loop_analysis.GetLoopIterationIdx(), next_channel_id)); + if (post_processing_fn.has_value()) { + TF_RETURN_IF_ERROR((*post_processing_fn)(new_init_operands[idx])); + } if (postprocess_peeled.has_value()) { TF_RETURN_IF_ERROR(postprocess_peeled.value()(new_init_operands[idx])); } @@ -2764,6 +2780,9 @@ static absl::Status TransformLoopBackward( *loop_analysis.GetLoopIterationIdx(), next_channel_id, &loop_variant_parameter_info)); + if (post_processing_fn.has_value()) { + TF_RETURN_IF_ERROR((*post_processing_fn)(cloned_instr)); + } if (postprocess_rotated.has_value()) { TF_RETURN_IF_ERROR(postprocess_rotated.value()(cloned_instr)); } @@ -2887,8 +2906,9 @@ static absl::Status TransformLoopBackward( HloInstruction::CreateGetTupleElement(new_while_loop, tuple_idx)); while_body_replacement_map[instr] = pipelined_value; if (instruction_is_output_it != is_output_instruction.end()) { - output_tuple_instructions[instruction_is_output_it->second] = - pipelined_value; + for (int64_t index : instruction_is_output_it->second) { + output_tuple_instructions[index] = pipelined_value; + } } continue; } @@ -2904,8 +2924,9 @@ static absl::Status TransformLoopBackward( loop_analysis)); while_body_replacement_map[instr] = cloned_instr; if (instruction_is_output_it != is_output_instruction.end()) { - output_tuple_instructions[instruction_is_output_it->second] = - cloned_instr; + for (int64_t index : instruction_is_output_it->second) { + output_tuple_instructions[index] = cloned_instr; + } } } // Substitute old loop with the result of the last peeled iteration. @@ -2953,7 +2974,7 @@ absl::StatusOr CollectivePipeliner::RunPipeliner( tuple_points_to_analysis.get(), call_graph.get()); loop_analysis->ComputeLoopStatistics(); if (loop_analysis->GetLoopIterationCount() && - loop_analysis->GetLoopIterationCount()->GetUnsignedValue() > 0) { + loop_analysis->GetLoopIterationCount()->GetUnsignedValue() > 1) { loop_analyses.push_back( std::make_pair(instruction, std::move(loop_analysis))); } @@ -2990,7 +3011,8 @@ absl::StatusOr CollectivePipeliner::RunPipeliner( *loop_analysis, !config_.last_run, config_.level_to_operate_on, config_.pipeline_use_tree, config_.process_different_sized_ops, config_.should_process, config_.acceptable_formatting, - config_.reuse_pipelined_op_buffer, next_channel_id)); + config_.reuse_pipelined_op_buffer, next_channel_id, + config_.postprocess_pipelined_ops)); } else if (config_.pipelining_direction == PipeliningDirection::kForwardSink) { TF_RETURN_IF_ERROR(TransformLoopForwardSink( @@ -3003,7 +3025,8 @@ absl::StatusOr CollectivePipeliner::RunPipeliner( *loop_analysis, !config_.last_run, config_.level_to_operate_on, config_.process_different_sized_ops, config_.should_process, config_.acceptable_formatting, config_.postprocess_backward_peeled_op, - config_.postprocess_backward_rotated_op, next_channel_id)); + config_.postprocess_backward_rotated_op, next_channel_id, + config_.postprocess_pipelined_ops)); } ++transformed_loops; changed = true; diff --git a/third_party/xla/xla/service/collective_pipeliner.h b/third_party/xla/xla/service/collective_pipeliner.h index f01cce9cb77fb2..079c6ef9dad5e1 100644 --- a/third_party/xla/xla/service/collective_pipeliner.h +++ b/third_party/xla/xla/service/collective_pipeliner.h @@ -103,6 +103,8 @@ class CollectivePipeliner : public HloModulePass { // Determines whether a loop invariant instruction can be considered // in the pipelining chain. bool should_add_loop_invariant_op_in_chain = false; + // Postprocessing hook which runs for every successfully pipelined op. + HloPostprocessor postprocess_pipelined_ops = std::nullopt; }; static const char* const kInsertedByPreviousStep; static const char* const kSunkByPreviousStep; diff --git a/third_party/xla/xla/service/collective_pipeliner_test.cc b/third_party/xla/xla/service/collective_pipeliner_test.cc index 73efde7ce53096..be230a95a77b1a 100644 --- a/third_party/xla/xla/service/collective_pipeliner_test.cc +++ b/third_party/xla/xla/service/collective_pipeliner_test.cc @@ -36,10 +36,13 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/service/hlo_parser.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/hlo_verifier.h" #include "xla/service/host_memory_offload_annotations.h" +#include "xla/test_helpers.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" @@ -1501,6 +1504,159 @@ ENTRY entry { EXPECT_NE(FindInstruction(module.get(), "ag.2"), nullptr); } +TEST_F(CollectivePipelinerTest, LoopVariantAppearingInRootTupleMultipleTimes) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[3,8,128], bf16[3,1,2,128], s32[], s32[]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(3) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[3,8,128], bf16[3,1,2,128], s32[], s32[]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1 + get-tuple-element.k = bf16[3,1,2,128] get-tuple-element(param), index=2 + constant.2561 = s32[] constant(0) + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(3) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(2) + add.232 = s32[] add(subtract.139, constant.2562) + add.233 = s32[] add(add.232, constant.2557) + select.1348 = s32[] select(compare.747, add.232, add.231) + dynamic-slice.k = bf16[1,1,2,128] dynamic-slice(get-tuple-element.k, select.1348, constant.2561, constant.2561, constant.2561), dynamic_slice_sizes={1,1,2,128} + r = bf16[1,2,128] reshape(dynamic-slice.k) + // To be peeled. + custom-call = bf16[1,2,128] custom-call(r), custom_call_target="MoveToDevice" + a = bf16[1,2,128] add(custom-call, custom-call), control-predecessors={constant.2559} + // To be peeled. + ag = bf16[1,8,128] all-gather(a), dimensions={1}, replica_groups={} + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul = bf16[1,8,128] multiply(dynamic-slice.99, ag) + ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1 + dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, ar.1, select.1348, constant.2561, constant.2561) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,1,2,128], s32[], s32[]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.k, add.233, add.233), control-predecessors={a} +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + p1 = bf16[3,1,2,128] parameter(1) + p2 = s32[] parameter(2) + tuple = (s32[], bf16[3,8,128], bf16[3,1,2,128], s32[], s32[]) tuple(c0, p0, p1, p2, p2) + while = (s32[], bf16[3,8,128], bf16[3,1,2,128], s32[], s32[]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 +} +)"; + auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value(); + auto is_all_gather_or_offloading = [](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kAllGather || + instruction->IsCustomCall( + host_memory_offload_annotations::kMoveToDeviceCustomCallTarget); + }; + EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true, 0, + /*pipeline_use_tree=*/false, + /*process_different_sized_ops=*/false, + CollectivePipeliner::PipeliningDirection::kBackward, + is_all_gather_or_offloading) + .value()); +} + +TEST_F(CollectivePipelinerTest, TwoIterations) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[3,8,128], bf16[3,1,2,128], bf16[3,8,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(2) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[3,8,128], bf16[3,1,2,128], bf16[3,8,128]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1 + get-tuple-element.k = bf16[3,1,2,128] get-tuple-element(param), index=2 + param3 = bf16[3,8,128] get-tuple-element(param), index=3 + constant.2561 = s32[] constant(0) + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(3) + dynamic-slice.k = bf16[1,1,2,128] dynamic-slice(get-tuple-element.k, get-tuple-element.394, constant.2561, constant.2561, constant.2561), dynamic_slice_sizes={1,1,2,128} + r = bf16[1,2,128] reshape(dynamic-slice.k) + // To be peeled. + custom-call = bf16[1,2,128] custom-call(r), custom_call_target="MoveToDevice" + a = bf16[1,2,128] add(custom-call, custom-call), control-predecessors={constant.2559} + // To be peeled. + ag = bf16[1,8,128] all-gather(a), dimensions={1}, replica_groups={} + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.395, get-tuple-element.394, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul = bf16[1,8,128] multiply(dynamic-slice.99, ag) + ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1 + dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, ar.1, get-tuple-element.394, constant.2561, constant.2561) + ar.2 = bf16[1,8,128] custom-call(ar.1), custom_call_target="MoveToHost" + hmm = bf16[3,8,128] dynamic-update-slice(param3, ar.2, get-tuple-element.394, constant.2561, constant.2561) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,1,2,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.k, hmm), control-predecessors={a} +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + p1 = bf16[3,1,2,128] parameter(1) + p2 = bf16[3,8,128] parameter(2) + tuple = (s32[], bf16[3,8,128], bf16[3,1,2,128], bf16[3,8,128]) tuple(c0, p0, p1, p2) + while = (s32[], bf16[3,8,128], bf16[3,1,2,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 +} +)"; + auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value(); + auto is_all_gather_or_offloading = [](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kAllGather || + instruction->IsCustomCall(host_memory_offload_annotations:: + kMoveToDeviceCustomCallTarget) || + instruction->IsCustomCall( + host_memory_offload_annotations::kMoveToHostCustomCallTarget); + }; + bool changed = + RunOptimizer(module.get(), /*last_run=*/true, /*level_to_operate_on=*/0, + /*pipeline_use_tree=*/true, + /*process_different_sized_ops=*/true, + CollectivePipeliner::PipeliningDirection::kBackward, + is_all_gather_or_offloading) + .value(); + ASSERT_TRUE(changed); + XLA_VLOG_LINES(1, module->ToString()); + changed = + RunOptimizer(module.get(), /*last_run=*/true, /*level_to_operate_on=*/0, + /*pipeline_use_tree=*/true, + /*process_different_sized_ops=*/true, + CollectivePipeliner::PipeliningDirection::kForward, + is_all_gather_or_offloading) + .value(); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_FALSE(changed); +} + TEST_F(CollectivePipelinerTest, TransformIncrementIndexByOneBackwardsCollectivePermute) { constexpr absl::string_view hlo_string = R"( @@ -2123,6 +2279,76 @@ ENTRY entry { EXPECT_EQ(select_instr_loop->opcode(), HloOpcode::kSelect); } +TEST_F(CollectivePipelinerTest, ForwardSinkLinearShape4097) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[3,4097], bf16[3,4097]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(3) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[3,4097], bf16[3,4097]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[3,4097] get-tuple-element(param), index=1 + get-tuple-element.35 = bf16[3,4097] get-tuple-element(param), index=2 + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(3) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + constant.2561 = s32[] constant(0) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(2) + add.232 = s32[] add(subtract.139, constant.2562) + select.1348 = s32[] select(compare.747, add.232, add.231) + dynamic-slice.99 = bf16[1,4097] dynamic-slice(get-tuple-element.35, select.1348, constant.2561), dynamic_slice_sizes={1,4097} + mul = bf16[1,4097] multiply(dynamic-slice.99, dynamic-slice.99) + ar.1 = bf16[1,4097] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1 + c = bf16[] custom-call(), custom_call_target="Boh" + b = bf16[1,4097] broadcast(c), dimensions={} + a = bf16[1,4097] add(ar.1, b) + dynamic-update-slice.35 = bf16[3,4097] dynamic-update-slice(get-tuple-element.395, a, select.1348, constant.2561) + ROOT tuple = (s32[], bf16[3,4097], bf16[3,4097]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.35), control-predecessors={select.1348} +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,4097] parameter(0) + tuple = (s32[], bf16[3,4097], bf16[3,4097]) tuple(c0, p0, p0) + while = (s32[], bf16[3,4097], bf16[3,4097]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[3,4097] get-tuple-element(while), index=1 +} +)"; + auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value(); + EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/false, + /*level_to_operate_on=*/0, + /*pipeline_use_tree=*/true, + /*process_different_sized_ops=*/true, + CollectivePipeliner::kForwardSink) + .value()); + XLA_VLOG_LINES(1, module->ToString()); + const HloInstruction* while_instr = + FindInstruction(module.get(), HloOpcode::kWhile); + const HloComputation* comp = while_instr->while_body(); + const HloInstruction* root_loop = comp->root_instruction(); + EXPECT_TRUE(root_loop->HasControlDependencies()); + EXPECT_EQ(root_loop->control_predecessors().size(), 1); + const HloInstruction* select_instr_loop = + root_loop->control_predecessors()[0]; + EXPECT_EQ(select_instr_loop->opcode(), HloOpcode::kSelect); +} + TEST_F(CollectivePipelinerTest, TransformIncrementIndexByOneNotFirstIdxSinkCustomCall) { constexpr absl::string_view hlo_string = R"( @@ -3766,5 +3992,88 @@ ENTRY entry { while_instr->while_body()->root_instruction()->operand(8))); } +TEST_F(CollectivePipelinerTest, NoRedundantBroadcastsInFormattingOps) { + constexpr absl::string_view hlo_string = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +add.1 { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(3) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) parameter(0) + get-tuple-element.394 = s32[] get-tuple-element(param), index=0 + get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1 + get-tuple-element.396 = bf16[3,8,128] get-tuple-element(param), index=2 + get-tuple-element.35 = bf16[3,8,128] get-tuple-element(param), index=3 + constant.2557 = s32[] constant(1) + add.230 = s32[] add(get-tuple-element.394, constant.2557) + constant.2559 = s32[] constant(3) + subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) + constant.2560 = s32[] constant(-1) + add.231 = s32[] add(subtract.139, constant.2560) + constant.2561 = s32[] constant(0) + compare.747 = pred[] compare(add.231, constant.2561), direction=LT + constant.2562 = s32[] constant(2) + add.232 = s32[] add(subtract.139, constant.2562) + select.1348 = s32[] select(compare.747, add.232, add.231) + dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128} + mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99) + ar.1 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add, channel_id=1 + convert = bf16[] convert(add.232) + broadcast = bf16[1,8,128] broadcast(convert) + add.1 = bf16[1,8,128] add(ar.1, broadcast) + dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, add.1, select.1348, constant.2561, constant.2561) + ar.2 = bf16[1,8,128] all-reduce(mul), replica_groups={}, to_apply=add.1, channel_id=2 + add.2 = bf16[1,8,128] add(ar.2, broadcast) + dynamic-update-slice.36 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.396, add.2, select.1348, constant.2561, constant.2561) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, dynamic-update-slice.36, get-tuple-element.35) +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + tuple = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0, p0) + while = (s32[], bf16[3,8,128], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 +} +)"; + auto module = ParseAndReturnUnverifiedModule(hlo_string, config_).value(); + EXPECT_TRUE(RunOptimizer(module.get(), /*last_run=*/true, + /*level_to_operate_on=*/0, + /*pipeline_use_tree=*/true, + /*process_different_sized_ops=*/true, + CollectivePipeliner::kForwardSink) + .value()); + XLA_VLOG_LINES(1, module->ToString()); + // There should be only one broadcast instruction using a get-tuple-element + // from the while instruction. + EXPECT_EQ(absl::c_count_if(module->entry_computation()->instructions(), + [](const HloInstruction* instr) { + return instr->opcode() == + HloOpcode::kBroadcast && + instr->operand(0)->opcode() == + HloOpcode::kGetTupleElement && + instr->operand(0)->operand(0)->opcode() == + HloOpcode::kWhile; + }), + 1); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/collective_quantizer.h b/third_party/xla/xla/service/collective_quantizer.h index 2803523ca7c3ff..b63a3138b91e0b 100644 --- a/third_party/xla/xla/service/collective_quantizer.h +++ b/third_party/xla/xla/service/collective_quantizer.h @@ -16,44 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_COLLECTIVE_QUANTIZER_H_ #define XLA_SERVICE_COLLECTIVE_QUANTIZER_H_ -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// Reduces the amount of data transferred in all-gather, all-to-all, -// collective-broadcast and collective-permute ops by exchanging the collectives -// with subsequent quantizations or type conversions to a narrower type as well -// as preceding dequantizations or type conversions to a wider type. When -// present, unary ops such as bitcasts, copies, reshapes and slices between -// collective and quantization/dequantiation/type conversion are shifted, i.e. -// transforms -// -// collective --> unary --> quantization/type conversion -// -// into -// -// quantization/type conversion --> collective --> unary -// -// and -// -// dequantization/type conversion --> unary --> collective -// -// into -// -// unary --> collective --> dequantization/type conversion. -class CollectiveQuantizer : public HloModulePass { - public: - absl::string_view name() const override { return "collective-quantizer"; } - - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/collectives/collective_quantizer.h" #endif // XLA_SERVICE_COLLECTIVE_QUANTIZER_H_ diff --git a/third_party/xla/xla/service/collective_transformation_reorderer.h b/third_party/xla/xla/service/collective_transformation_reorderer.h index 1cb07b13c6a98f..2bbae612c5e4c9 100644 --- a/third_party/xla/xla/service/collective_transformation_reorderer.h +++ b/third_party/xla/xla/service/collective_transformation_reorderer.h @@ -16,61 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_COLLECTIVE_TRANSFORMATION_REORDERER_H_ #define XLA_SERVICE_COLLECTIVE_TRANSFORMATION_REORDERER_H_ -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// Transforms -// -- all-gather + reshape into reshape + all-gather and -// -- reshape + all-reduce into all-reduce + reshape. -// Both transformations require that there are no other users affected, i.e., -// reshape user count should be 1. -// all-gather transformation requires the reshape to only change the shape of -// the all-gather shards, i.e., not reshaping across the all-gather dimension. -// all-reduce transformation requires all-reduce to be not layout constrained. - -// all-gather + reshape example: - -// input = [C_0, C_1, ..., C_i, ..., C_{n-1}, C_n] ... -// all-gather = [C_0, C_1, ..., P*C_i, ... C_{n-1}, C_n] all-gather(input) -// reshape = [D_0, D_1, ..., P*D_j, ..., D_{m-1}, D_m] reshape(all-gather) - -// can be transformed to: - -// input = [C_0, C_1, ..., C_i, ..., C_{n-1}, C_n] ... -// reshape = [D_0, D_1, ..., D_j, ..., D_{m-1}, D_m] reshape(input) -// all-gather = [D_0, D_1, ..., P*D_j, ... D_{m-1}, D_m] all-gather(input) - -// if and only if C_0 * C_1 * ... * C_{i-1} = D_0 * D_1 * ... * D_{j-1} -// and C_{i+1} * ... * C_{n-1} * C_n = D_{j+1} * ... * D_{m-1} * D_{m}. - -class CollectiveTransformationReorder : public HloModulePass { - public: - CollectiveTransformationReorder() = default; - ~CollectiveTransformationReorder() override = default; - absl::string_view name() const override { - return "collective-transformation-reorderer"; - } - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - absl::StatusOr ReorderAllGatherTransformations( - HloModule* module, - const absl::flat_hash_set& execution_threads); - absl::StatusOr ReorderAllReduceTransformations( - HloModule* module, - const absl::flat_hash_set& execution_threads); -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/collectives/collective_transformation_reorderer.h" #endif // XLA_SERVICE_COLLECTIVE_TRANSFORMATION_REORDERER_H_ diff --git a/third_party/xla/xla/service/collective_utils.h b/third_party/xla/xla/service/collective_utils.h new file mode 100644 index 00000000000000..916e007dc9b2eb --- /dev/null +++ b/third_party/xla/xla/service/collective_utils.h @@ -0,0 +1,37 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_COLLECTIVE_UTILS_H_ +#define XLA_SERVICE_COLLECTIVE_UTILS_H_ + +#include + +namespace xla { + +// Defines the default threshold for `AllReduceCombiner` up to which the pass +// will combine collectives. +constexpr int64_t kDefaultAllReduceCombineThreshold = 30 * 1024 * 1024 + 7; + +// Defines the default threshold for `AllGatherCombiner` up to which the pass +// will combine collectives. +constexpr int64_t kDefaultAllGatherCombineThreshold = 30 * 1024 * 1024 + 7; + +// Defines the default threshold for `ReduceScatterCombiner` up to which the +// pass will combine collectives. +constexpr int64_t kDefaultReduceScatterCombineThreshold = 30 * 1024 * 1024 + 7; + +} // namespace xla + +#endif // XLA_SERVICE_COLLECTIVE_UTILS_H_ diff --git a/third_party/xla/xla/service/collectives_schedule_linearizer.h b/third_party/xla/xla/service/collectives_schedule_linearizer.h index 1f3ca55b9ec584..27f0de0032e2fa 100644 --- a/third_party/xla/xla/service/collectives_schedule_linearizer.h +++ b/third_party/xla/xla/service/collectives_schedule_linearizer.h @@ -16,37 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_COLLECTIVES_SCHEDULE_LINEARIZER_H_ #define XLA_SERVICE_COLLECTIVES_SCHEDULE_LINEARIZER_H_ -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/util.h" - -namespace xla { - -// Enforces a total order on all collectives present in the module, based on the -// order given to the instructions. -// -// Does not insert inter-computation dependencies, only linearizes the order -// within each computation. -class CollectivesScheduleLinearizer : public HloModulePass { - public: - explicit CollectivesScheduleLinearizer(HloModulePredicate is_enabled = {}) - : is_enabled_(is_enabled) {} - - absl::string_view name() const override { - return "collectives-schedule-linearizer"; - } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - HloModulePredicate is_enabled_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/collectives/collectives_schedule_linearizer.h" #endif // XLA_SERVICE_COLLECTIVES_SCHEDULE_LINEARIZER_H_ diff --git a/third_party/xla/xla/service/comparison_expander.h b/third_party/xla/xla/service/comparison_expander.h index 205a45ae92bf56..333375478e59b7 100644 --- a/third_party/xla/xla/service/comparison_expander.h +++ b/third_party/xla/xla/service/comparison_expander.h @@ -16,43 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_COMPARISON_EXPANDER_H_ #define XLA_SERVICE_COMPARISON_EXPANDER_H_ -#include -#include - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/primitive_util.h" -#include "xla/service/op_expander_pass.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// A pass which performs expansion of the comparison operator to support total -// order comparison of floating point numbers. -class ComparisonExpander : public OpExpanderPass { - public: - explicit ComparisonExpander( - absl::Span> - expand_via_upcast = {}) - : expand_via_upcast_(expand_via_upcast.begin(), expand_via_upcast.end()) { - } - ~ComparisonExpander() override = default; - absl::string_view name() const override { return "comparison-expander"; } - - private: - // Returns `true` if `instruction` should be expanded by this pass. - bool InstructionMatchesPattern(HloInstruction* instruction) override; - // Returns a replacement for `instruction`, or nullptr if no replacement is - // needed (e.g. only the to_apply subcomputation of the instruction was - // modified). - absl::StatusOr ExpandInstruction( - HloInstruction* instruction) override; - - std::vector> expand_via_upcast_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/expanders/comparison_expander.h" #endif // XLA_SERVICE_COMPARISON_EXPANDER_H_ diff --git a/third_party/xla/xla/service/compiler.h b/third_party/xla/xla/service/compiler.h index a64578841ed43b..45dc7298c4e8d4 100644 --- a/third_party/xla/xla/service/compiler.h +++ b/third_party/xla/xla/service/compiler.h @@ -157,10 +157,6 @@ class Compiler { // on which compilation is performed. std::optional target_config; - // Registry of MLIR dialects and plugins to be loaded during optimization. - // If non-null, it will be used to construct relevant MLIR contexts. - mlir::DialectRegistry* registry = nullptr; - MultiProcessKeyValueStore key_value_store; }; diff --git a/third_party/xla/xla/service/computation_placer.cc b/third_party/xla/xla/service/computation_placer.cc index ee0cf2932a1e86..43f351a5489592 100644 --- a/third_party/xla/xla/service/computation_placer.cc +++ b/third_party/xla/xla/service/computation_placer.cc @@ -31,6 +31,7 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/host/host_platform_id.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" +#include "xla/stream_executor/sycl/sycl_platform_id.h" #include "xla/types.h" #include "xla/util.h" #include "tsl/platform/errors.h" @@ -222,6 +223,8 @@ static bool InitModule() { stream_executor::cuda::kCudaPlatformId, &CreateComputationPlacer); xla::ComputationPlacer::RegisterComputationPlacer( stream_executor::rocm::kROCmPlatformId, &CreateComputationPlacer); + xla::ComputationPlacer::RegisterComputationPlacer( + stream_executor::sycl::kSyclPlatformId, &CreateComputationPlacer); return true; } static bool module_initialized = InitModule(); diff --git a/third_party/xla/xla/service/conditional_canonicalizer.h b/third_party/xla/xla/service/conditional_canonicalizer.h index a4ae0c05b2e194..6a857fc4cf208c 100644 --- a/third_party/xla/xla/service/conditional_canonicalizer.h +++ b/third_party/xla/xla/service/conditional_canonicalizer.h @@ -15,27 +15,7 @@ limitations under the License. #ifndef XLA_SERVICE_CONDITIONAL_CANONICALIZER_H_ #define XLA_SERVICE_CONDITIONAL_CANONICALIZER_H_ -#include - -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// Canonicalize output of conditionals, make non-tuple outputs into tuple with -// single element output. After this pass, all conditional instructions have -// tuple outputs. -class ConditionalCanonicalizer : public HloModulePass { - public: - absl::string_view name() const override { - return "conditional-canonicalizer"; - } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/conditional_canonicalizer.h" #endif // XLA_SERVICE_CONDITIONAL_CANONICALIZER_H_ diff --git a/third_party/xla/xla/service/conditional_code_motion.cc b/third_party/xla/xla/service/conditional_code_motion.cc index 00f22ef9cf8703..ee079970f2c46e 100644 --- a/third_party/xla/xla/service/conditional_code_motion.cc +++ b/third_party/xla/xla/service/conditional_code_motion.cc @@ -39,12 +39,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/literal.h" #include "xla/map_util.h" #include "xla/service/hlo_cse.h" -#include "xla/service/hlo_dce.h" #include "xla/service/hlo_verifier.h" -#include "xla/service/tuple_simplifier.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" diff --git a/third_party/xla/xla/service/convert_async_collectives_to_sync.h b/third_party/xla/xla/service/convert_async_collectives_to_sync.h index d574f06ce3b693..3e3884b98a0fbb 100644 --- a/third_party/xla/xla/service/convert_async_collectives_to_sync.h +++ b/third_party/xla/xla/service/convert_async_collectives_to_sync.h @@ -16,57 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_ #define XLA_SERVICE_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_ -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/util.h" - -namespace xla { - -// Convert asynchronous collectives to synchronous (after HLO scheduling) if -// there are no compute operations overlapping with them. - -class ConvertAsyncCollectivesToSync : public HloModulePass { - public: - explicit ConvertAsyncCollectivesToSync(HloPredicate is_nop = {}) - : is_nop_(is_nop) {} - absl::string_view name() const override { - return "convert-async-collectives-to-sync"; - } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - virtual absl::Status ConvertAsyncInstructionsToSync( - HloComputation* computation, - absl::Span> async_pairs) - const { - return ReplaceAsyncInstructionsWithSync(computation, async_pairs); - } - - // Helper utility to replace a list of pairs of async-start/done ops in a - // computation with their synchronous variants and update the schedule. - static absl::Status ReplaceAsyncInstructionsWithSync( - HloComputation* computation, - absl::Span> - async_pairs); - - static constexpr char kAsyncCollectiveNameAttributeName[] = - "async_collective_name"; - - private: - absl::StatusOr RunOnComputation(HloComputation* computation); - HloPredicate is_nop_; -}; -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/collectives/convert_async_collectives_to_sync.h" #endif // XLA_SERVICE_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_ diff --git a/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.h b/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.h index dd8186d8aa765b..17f629fd058847 100644 --- a/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.h +++ b/third_party/xla/xla/service/convert_memory_placement_to_internal_annotations.h @@ -16,31 +16,7 @@ #ifndef XLA_SERVICE_CONVERT_MEMORY_PLACEMENT_TO_INTERNAL_ANNOTATIONS_H_ #define XLA_SERVICE_CONVERT_MEMORY_PLACEMENT_TO_INTERNAL_ANNOTATIONS_H_ -#include -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -class ConvertMemoryPlacementToInternalAnnotations : public HloModulePass { - public: - ConvertMemoryPlacementToInternalAnnotations() = default; - - absl::string_view name() const override { - return "convert-memory-placement-to-internal-annotations"; - } - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/convert_memory_placement_to_internal_annotations.h" #endif // XLA_SERVICE_CONVERT_MEMORY_PLACEMENT_TO_INTERNAL_ANNOTATIONS_H_ diff --git a/third_party/xla/xla/service/convert_mover.h b/third_party/xla/xla/service/convert_mover.h index b531730ab95e01..a335a4583caecd 100644 --- a/third_party/xla/xla/service/convert_mover.h +++ b/third_party/xla/xla/service/convert_mover.h @@ -16,39 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_CONVERT_MOVER_H_ #define XLA_SERVICE_CONVERT_MOVER_H_ -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// Moves narrowing conversions up the graph and widening conversions down the -// graph, when we can do so with no effect on numerics. Motivations: -// -// - It's preferable to spend more of our time in lower precision and less of -// our time in higher precision. -// -// - Moving these converts exposes optimization opportunities. For example, in -// reshape(convert-big-to-small(reshape(convert-small-to-big(x)))), we can -// commute one of the converts with one of the reshapes. This leaves us with -// convert(convert(reshape(reshape))), which can probably be simplified -// further by algsimp. -class ConvertMover : public HloModulePass { - public: - ConvertMover() = default; - - absl::string_view name() const override { return "convert-mover"; } - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/convert_mover.h" #endif // XLA_SERVICE_CONVERT_MOVER_H_ diff --git a/third_party/xla/xla/service/convert_operand_folding.h b/third_party/xla/xla/service/convert_operand_folding.h index e3a7ce2811d533..863cd7da8d4914 100644 --- a/third_party/xla/xla/service/convert_operand_folding.h +++ b/third_party/xla/xla/service/convert_operand_folding.h @@ -16,29 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_CONVERT_OPERAND_FOLDING_H_ #define XLA_SERVICE_CONVERT_OPERAND_FOLDING_H_ -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/op_expander_pass.h" - -namespace xla { - -// Folds Convert operands to wider types into instructions that supports wider -// result accumulation than the shape inference type. -// -// e.g. s32 hlo(s32 convert(s8), s32 convert(s8)) -> s32 hlo(s8, s8) -class ConvertOperandFolding : public OpExpanderPass { - public: - absl::string_view name() const override { return "convert_operand_folding"; } - - protected: - bool InstructionMatchesPattern(HloInstruction* instruction) override; - - absl::StatusOr ExpandInstruction( - HloInstruction* instruction) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/convert_operand_folder.h" #endif // XLA_SERVICE_CONVERT_OPERAND_FOLDING_H_ diff --git a/third_party/xla/xla/service/convolution_4d_expander.h b/third_party/xla/xla/service/convolution_4d_expander.h index c27c603598b59f..2a290290ebddef 100644 --- a/third_party/xla/xla/service/convolution_4d_expander.h +++ b/third_party/xla/xla/service/convolution_4d_expander.h @@ -16,24 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_CONVOLUTION_4D_EXPANDER_H_ #define XLA_SERVICE_CONVOLUTION_4D_EXPANDER_H_ -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/op_expander_pass.h" - -namespace xla { - -class Convolution4DExpander : public OpExpanderPass { - public: - absl::string_view name() const override { return "convolution_4d_expander"; } - - protected: - bool InstructionMatchesPattern(HloInstruction* instruction) override; - - absl::StatusOr ExpandInstruction( - HloInstruction* instruction) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/expanders/convolution_4d_expander.h" #endif // XLA_SERVICE_CONVOLUTION_4D_EXPANDER_H_ diff --git a/third_party/xla/xla/service/convolution_group_converter.h b/third_party/xla/xla/service/convolution_group_converter.h index c9ec5f40d3ec7b..21d68d2751a0fc 100644 --- a/third_party/xla/xla/service/convolution_group_converter.h +++ b/third_party/xla/xla/service/convolution_group_converter.h @@ -16,54 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_CONVOLUTION_GROUP_CONVERTER_H_ #define XLA_SERVICE_CONVOLUTION_GROUP_CONVERTER_H_ -#include - -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/status_macros.h" - -namespace xla { - -// A pass which rewrites convolutions with feature_group_count > 1 into -// convolutions with feature_group_count = 1. -class ConvolutionGroupConverter : public HloModulePass { - public: - ConvolutionGroupConverter(std::function should_expand, - std::function is_cost_viable, - bool convert_batch_groups_only, - bool filter_expansion = true) - : should_expand_(should_expand), - is_cost_viable_(is_cost_viable), - convert_batch_groups_only_(convert_batch_groups_only), - filter_expansion_(filter_expansion) {} - - absl::string_view name() const override { - return "convolution-group-converter"; - } - - // Run convolution rewriting on the given computation. Returns whether the - // computation was changed. - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - // Predicate that determines whether this pass should rewrite a given - // convolution. - std::function should_expand_; - - // Lambda containing cost model that decides whether to expand - // batch_group_count. - std::function is_cost_viable_; - - // Decides whether to convert batch groups or feature groups. - bool convert_batch_groups_only_; - - // Tells whether filter expansion is required. - bool filter_expansion_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/convolution_group_converter.h" #endif // XLA_SERVICE_CONVOLUTION_GROUP_CONVERTER_H_ diff --git a/third_party/xla/xla/service/convolution_pred_expander.h b/third_party/xla/xla/service/convolution_pred_expander.h index d4dd0919be4bee..84c57681afb00a 100644 --- a/third_party/xla/xla/service/convolution_pred_expander.h +++ b/third_party/xla/xla/service/convolution_pred_expander.h @@ -16,29 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_CONVOLUTION_PRED_EXPANDER_H_ #define XLA_SERVICE_CONVOLUTION_PRED_EXPANDER_H_ -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/op_expander_pass.h" - -namespace xla { - -// A pass that rewrites boolean convolutions to floating point and converts the -// result back to boolean. This is necessary, as the convolutions on GPUs are -// implemented using custom call to cuDNN, which only supports FP and S8 inputs. -class ConvolutionPredExpander : public OpExpanderPass { - public: - absl::string_view name() const override { - return "convolution-pred-expander"; - } - - protected: - bool InstructionMatchesPattern(HloInstruction* instruction) override; - - absl::StatusOr ExpandInstruction( - HloInstruction* instruction) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/expanders/convolution_pred_expander.h" #endif // XLA_SERVICE_CONVOLUTION_PRED_EXPANDER_H_ diff --git a/third_party/xla/xla/service/copy_insertion.cc b/third_party/xla/xla/service/copy_insertion.cc index 6e2fc858d09589..6a87d2bd618fc3 100644 --- a/third_party/xla/xla/service/copy_insertion.cc +++ b/third_party/xla/xla/service/copy_insertion.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" @@ -35,23 +36,25 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/frontend_attributes.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" +#include "xla/hlo/analysis/hlo_ordering.h" +#include "xla/hlo/analysis/hlo_reachability.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/ir/hlo_reachability.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/map_util.h" #include "xla/service/call_graph.h" #include "xla/service/compile_time_cap.h" #include "xla/service/dump.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" -#include "xla/service/hlo_dataflow_analysis.h" -#include "xla/service/hlo_dce.h" -#include "xla/service/hlo_ordering.h" #include "xla/service/hlo_value.h" -#include "xla/service/tuple_simplifier.h" #include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" @@ -186,6 +189,22 @@ DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to, return std::make_pair(from_deep_copy, to_deep_copy); } +bool IsSendRecv(const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kSend || + instruction->opcode() == HloOpcode::kRecv; +} + +bool IsSendRecvDone(const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kSendDone || + instruction->opcode() == HloOpcode::kRecvDone; +} + +bool IsSendRecvInInit(const HloInstruction* init, const ShapeIndex& index) { + if (index.empty()) return false; + int64_t i = index.front(); + return i < init->operand_count() && IsSendRecv(init->operand(i)); +} + // Compute the indices of the loop state which need copies in order to avoid // live range interference. Generally, an element in the loop state does not // need to be copied if the element is passed through transparently through the @@ -202,9 +221,14 @@ bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow, for (auto& pair : *indices_to_copy) { const ShapeIndex& index = pair.first; bool& should_copy = pair.second; - // If there is any ambiguity, then loop state must be copied. - if (dataflow.GetValueSet(init, index).values().size() > 1 || - dataflow.GetValueSet(xla_while, index).values().size() > 1) { + if (IsSendRecvInInit(init, index)) { + // Do not copy partially pipelined send/recv ops. The required copies will + // be inserted specifically for the send/recv ops. + should_copy = false; + continue; + } else if (dataflow.GetValueSet(init, index).values().size() > 1 || + dataflow.GetValueSet(xla_while, index).values().size() > 1) { + // If there is any ambiguity, then loop state must be copied. should_copy = true; } else { // If the output of the while instruction is not the same as the init @@ -1307,42 +1331,6 @@ class CopyRemover { if (buffer.values().at(0)->defining_instruction()->IsFused()) { continue; } - if (check_live_range_ordering) { - // Skip checking if execution thread is not included. - auto should_skip_value = [&execution_threads](const HloValue* value) { - return value->defining_instruction()->parent() != nullptr && - !HloInstruction::IsThreadIncluded(value->defining_instruction() - ->parent() - ->execution_thread(), - execution_threads); - }; - // Verify values contained in the buffer are strictly ordered. This - // should always be the case after adding copies to eliminate - // interference. Specifically, the addition of the control flow edges - // between copies added around aliased operations (kWhile) guarantees - // this strict order. - for (const HloValue* value_a : buffer.values()) { - if (value_a->shape().IsToken()) { - // Token values have no representation and cannot interfere. - continue; - } - if (should_skip_value(value_a)) { - continue; - } - for (const HloValue* value_b : buffer.values()) { - if (!should_skip_value(value_b) && value_a != value_b) { - DCHECK(ordering_->LiveRangeStrictlyBefore( - *value_a, *value_b, dataflow_, - /*use_is_always_before_def_in_same_instr=*/true) || - ordering_->LiveRangeStrictlyBefore( - *value_b, *value_a, dataflow_, - /*use_is_always_before_def_in_same_instr=*/true)) - << value_a->ToString() << " and " << value_b->ToString() - << " are not ordered"; - } - } - } - } std::vector values = buffer.values(); absl::c_sort(values, [this, &instruction_ids](const HloValue* a, @@ -1526,6 +1514,18 @@ class CopyRemover { int64_t* region_analysis_limit) { VLOG(2) << "Trying to remove " << copy->name(); CHECK_NE(region_analysis_limit, nullptr); + if (copy->shape().has_layout() && copy->operand(0)->shape().has_layout()) { + if (copy->shape().layout().memory_space() == Layout::kHostMemorySpace && + copy->operand(0)->shape().layout().memory_space() != + Layout::kHostMemorySpace) { + return false; + } + if (copy->shape().layout().memory_space() != Layout::kHostMemorySpace && + copy->operand(0)->shape().layout().memory_space() == + Layout::kHostMemorySpace) { + return false; + } + } if (!ContainsKey(copy_map_, copy)) { VLOG(2) << copy->name() << " is not removable"; @@ -2014,6 +2014,123 @@ absl::Status CopyInsertion::AddCopiesForConditional( return absl::OkStatus(); } +HloInstruction* FindAsyncSendRecvDoneInWhileBody( + const HloComputation* while_body, const HloInstruction* start_op) { + // Partially pipelined send/recv must have a single user. + if (start_op->user_count() != 1) return nullptr; + HloInstruction* unique_user = start_op->users().front(); + // Send/recv must be consumed by send/recv-done op or be passed through the + // loop. + if (IsSendRecvDone(unique_user)) return unique_user; + if (unique_user->opcode() != HloOpcode::kTuple || !unique_user->IsRoot()) + return nullptr; + int64_t index = unique_user->operand_index(start_op); + for (const HloInstruction* it : + while_body->parameter_instruction(0)->users()) { + const auto* gte = DynCast(it); + if (gte->tuple_index() == index) { + CHECK_EQ(gte->user_count(), 1) << "send/recv in next loop iteration must " + "be consumed by unique send/recv-done."; + HloInstruction* next_unique_user = gte->users().front(); + if (IsSendRecvDone(next_unique_user)) return next_unique_user; + } + } + return nullptr; +} + +// Add copies for partially pipelined async send/recv. Copies are added before +// starting to send and after finishing to recv. This is to prevent overlapping +// live times of the buffers. The control flow edges from the added copy to the +// recv or send-done operation guarantee disjoint live times of the buffers. +// Note that we have anchor these control flow edges to the copies as the send +// and recv-done ops are aliasing. +// +// +// Before: +// +// kParameter kParameter +// | | +// kSendDone kRecvDone +// | +// ... consumer +// +// producer ... +// | +// kSend kRecv +// | | +// (body root) (body root) +// +// +// After: +// +// kParameter kParameter +// | | +// kSendDone ----+ kRecvDone +// | | +// ctrl kCopy ----+ +// producer edge | | +// | | consumer ctrl +// kCopy <-----+ edge +// | | +// kSend kRecv <---+ +// | | +// (body root) (body root) +// +absl::Status CopyInsertion::AddCopiesForAsyncSendRecv( + const HloAliasAnalysis& alias_analysis, HloInstruction* start_op) { + // If start op has multiple users, this must be the synchronous use of + // send/recv. + // TODO(b/369589022): Disambiguate sync and async use of send/recv. + if (start_op->users().size() != 1) return absl::OkStatus(); + + // If start feeds directly into done, the live time is contained and we don't + // need to add any copies. + HloInstruction* unique_user = start_op->users().front(); + const HloOpcode done_opcode = start_op->opcode() == HloOpcode::kSend + ? HloOpcode::kSendDone + : HloOpcode::kRecvDone; + if (unique_user->opcode() == done_opcode) { + return absl::OkStatus(); + } + + // For send/recv outside of the while loop, live times are disjoint. No copies + // needed. + HloComputation* while_body = start_op->parent(); + if (!while_body->IsWhileBodyComputation()) return absl::OkStatus(); + + // Handle send case. + HloInstruction* done_op = + FindAsyncSendRecvDoneInWhileBody(while_body, start_op); + // TODO(b/369589022): Disambiguate sync and async use of send/recv. + if (done_op == nullptr) return absl::OkStatus(); + if (start_op->opcode() == HloOpcode::kSend) { + HloInstruction* operand = start_op->mutable_operand(0); + HloInstruction* copied_operand = + while_body->AddInstruction(HloInstruction::CreateUnary( + operand->shape(), HloOpcode::kCopy, operand)); + TF_RETURN_IF_ERROR(operand->ReplaceUseWith(start_op, copied_operand)); + TF_RETURN_IF_ERROR(done_op->AddControlDependencyTo(copied_operand)); + return absl::OkStatus(); + } + + // Handle recv case. + CHECK_EQ(start_op->opcode(), HloOpcode::kRecv); + PtrVec done_op_users = done_op->users(); + ShapeTree copies_added(done_op->shape()); + TF_ASSIGN_OR_RETURN(HloInstruction * done_op_copy, + while_body->DeepCopyInstruction( + done_op, /*indices_to_copy=*/nullptr, &copies_added)); + for (auto [shape_index, instr] : copies_added) { + if (instr != nullptr) + TF_RETURN_IF_ERROR(instr->AddControlDependencyTo(start_op)); + } + TF_RETURN_IF_ERROR(done_op->AddControlDependencyTo(start_op)); + for (HloInstruction* it : done_op_users) { + TF_RETURN_IF_ERROR(done_op->ReplaceUseWith(it, done_op_copy)); + } + return absl::OkStatus(); +} + // Add kCopy instructions to the given module to guarantee there is no // live-range interference. Generally interference can only occur around kWhile // instructions which have update-in-place semantics. @@ -2034,6 +2151,10 @@ absl::Status CopyInsertion::AddCopiesToResolveInterference( } else if (instruction->opcode() == HloOpcode::kConditional) { TF_RETURN_IF_ERROR( AddCopiesForConditional(*alias_analysis, instruction)); + } else if (IsSendRecv(instruction)) { + // TODO(b/371225893): Generalize this to all async collectives. + TF_RETURN_IF_ERROR( + AddCopiesForAsyncSendRecv(*alias_analysis, instruction)); } else { // When an operand is a tuple, we avoid copying the operand multiple // times by recording and checking the operand number of operands that @@ -2293,8 +2414,6 @@ absl::Status CopyInsertion::RemoveUnnecessaryCopies( } } - std::unique_ptr call_graph = CallGraph::Build(module); - int64_t num_existing_copies = GetNumExistingCopies(module, execution_threads); bool changed = true; int64_t num_iterations = -1; diff --git a/third_party/xla/xla/service/copy_insertion.h b/third_party/xla/xla/service/copy_insertion.h index b76d47cd1a871b..0b2ba86e3ef3eb 100644 --- a/third_party/xla/xla/service/copy_insertion.h +++ b/third_party/xla/xla/service/copy_insertion.h @@ -20,13 +20,13 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/service/call_graph.h" -#include "xla/service/hlo_alias_analysis.h" -#include "xla/service/hlo_dataflow_analysis.h" namespace xla { @@ -107,6 +107,10 @@ class CopyInsertion : public HloModulePass { virtual absl::Status AddCopiesForConditional( const HloAliasAnalysis& alias_analysis, HloInstruction* conditional); + // Add copies for async send/recv instructions. + absl::Status AddCopiesForAsyncSendRecv(const HloAliasAnalysis& alias_analysis, + HloInstruction* async); + // Backend specific function that decides whether an instruction can share // buffer with its operand. HloDataflowAnalysis::CanShareBuffer can_share_buffer_; diff --git a/third_party/xla/xla/service/copy_insertion_test.cc b/third_party/xla/xla/service/copy_insertion_test.cc index 5250e9842895df..38c1ba53fe170a 100644 --- a/third_party/xla/xla/service/copy_insertion_test.cc +++ b/third_party/xla/xla/service/copy_insertion_test.cc @@ -30,12 +30,13 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/utils/hlo_matchers.h" +#include "xla/hlo/utils/hlo_query.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/literal_util.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" @@ -3869,5 +3870,298 @@ ENTRY main { EXPECT_EQ(CountCopies(*module), 0); } +TEST_F(CopyInsertionTest, PartiallyPipelinedAsyncRecv) { + constexpr absl::string_view kModuleString = R"( + HloModule test, entry_computation_layout={()->f32[16]{0}}, num_partitions=4 + + while_body { + param = ((f32[16]{0}, u32[], token[])) parameter(0) + prev_recv = (f32[16]{0}, u32[], token[]) get-tuple-element(param), index=0 + recv_done = (f32[16]{0}, token[]) recv-done(prev_recv), channel_id=1 + after_all = token[] after-all() + recv = (f32[16]{0}, u32[], token[]) recv(after_all), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + ROOT tuple = ((f32[16]{0}, u32[], token[])) tuple(recv) + } + + // Infinite loop to keep IR small. + while_condition { + param = ((f32[16]{0}, u32[], token[])) parameter(0) + ROOT infinite_loop = pred[] constant(true) + } + + ENTRY main_spmd { + after_all = token[] after-all() + recv = (f32[16]{0}, u32[], token[]) recv(after_all), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + init = ((f32[16]{0}, u32[], token[])) tuple(recv) + while = ((f32[16]{0}, u32[], token[])) while(init), + condition=while_condition, body=while_body + recv_ctx = (f32[16]{0}, u32[], token[]) get-tuple-element(while), index=0 + recv_done = (f32[16]{0}, token[]) recv-done(recv_ctx), channel_id=1 + ROOT result = f32[16]{0} get-tuple-element(recv_done), index=0 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kModuleString)); + CopyInsertion copy_insertion(nullptr, + /*use_region_based_live_range_analysis=*/-1); + + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + VLOG(2) << module->ToString(); + + // All async start/end will be ordered so that all copies are removable. + EXPECT_EQ(CountCopies(*module), 0); + + // Expect control dependency from recv-done to recv. + HloComputation* while_body = + hlo_query::FindComputation(module.get(), "while_body"); + HloInstruction* recv_done = + hlo_query::FindInstruction(while_body, HloOpcode::kRecvDone); + HloInstruction* recv = + hlo_query::FindInstruction(while_body, HloOpcode::kRecv); + EXPECT_THAT(recv->control_predecessors(), UnorderedElementsAre(recv_done)); +} + +TEST_F(CopyInsertionTest, PartiallyPipelinedAsyncRecvMultipleUses) { + constexpr absl::string_view kModuleString = R"( + HloModule test, entry_computation_layout={(f32[16]{0})->f32[16]{0}}, + num_partitions=4 + + while_body { + param = ((f32[16]{0}, u32[], token[]), f32[16]{0}) parameter(0) + prev_recv = (f32[16]{0}, u32[], token[]) get-tuple-element(param), index=0 + recv_done = (f32[16]{0}, token[]) recv-done(prev_recv), channel_id=1 + recv_data = f32[16]{0} get-tuple-element(recv_done), index=0 + after_all = token[] after-all() + recv = (f32[16]{0}, u32[], token[]) recv(after_all), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + + // `recv_data` is again here, which extends it's live range. + ROOT tuple = ((f32[16]{0}, u32[], token[]), f32[16]{0}) tuple(recv, + recv_data) + } + + // Infinite loop to keep IR small. + while_condition { + param = ((f32[16]{0}, u32[], token[]), f32[16]{0}) parameter(0) + ROOT infinite_loop = pred[] constant(true) + } + + ENTRY main_spmd { + data = f32[16]{0} parameter(0) + after_all = token[] after-all() + recv = (f32[16]{0}, u32[], token[]) recv(after_all), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + init = ((f32[16]{0}, u32[], token[]), f32[16]{0}) tuple(recv, data) + while = ((f32[16]{0}, u32[], token[]), f32[16]{0}) while(init), + condition=while_condition, body=while_body + recv_ctx = (f32[16]{0}, u32[], token[]) get-tuple-element(while), index=0 + recv_done = (f32[16]{0}, token[]) recv-done(recv_ctx), channel_id=1 + ROOT result = f32[16]{0} get-tuple-element(recv_done), index=0 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kModuleString)); + CopyInsertion copy_insertion(nullptr, + /*use_region_based_live_range_analysis=*/-1); + + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + VLOG(2) << module->ToString(); + + // All async start/end will be ordered so that all copies, except for an extra + // use of the recv result, are removable. Additionally, there will be one copy + // leading into the loop. + HloComputation* while_body = + hlo_query::FindComputation(module.get(), "while_body"); + EXPECT_EQ(CountCopies(*module), 2); + EXPECT_EQ(CountCopies(*while_body), 1); + + // Expect control dependency from recv-done to recv. + HloInstruction* recv_done = + hlo_query::FindInstruction(while_body, HloOpcode::kRecvDone); + HloInstruction* recv = + hlo_query::FindInstruction(while_body, HloOpcode::kRecv); + HloInstruction* recv_done_copy = + hlo_query::FindInstruction(while_body, HloOpcode::kCopy); + EXPECT_THAT(recv_done_copy, op::Copy(op::GetTupleElement(recv_done))); + EXPECT_THAT(recv->control_predecessors(), + UnorderedElementsAre(recv_done, recv_done_copy)); +} + +TEST_F(CopyInsertionTest, PartiallyPipelinedAsyncSendMultipleUses) { + constexpr absl::string_view kModuleString = R"( + HloModule test, entry_computation_layout={(f32[16]{0})->f32[16]{0}}, + num_partitions=4 + + while_body { + param = ((f32[16]{0}, u32[], token[]), f32[16]{0}) parameter(0) + prev_send = (f32[16]{0}, u32[], token[]) get-tuple-element(param), index=0 + data = f32[16]{0} get-tuple-element(param), index=1 + send_done = (f32[16]{0}, token[]) send-done(prev_send), channel_id=1 + after_all = token[] after-all() + send = (f32[16]{0}, u32[], token[]) send(data, after_all), channel_id=1, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + + // `data` is used again here, which extends it's live range beyond `send`. + ROOT tuple = ((f32[16]{0}, u32[], token[]), f32[16]{0}) tuple(send, data) + } + + // Infinite loop to keep IR small. + while_condition { + param = ((f32[16]{0}, u32[], token[]), f32[16]{0}) parameter(0) + ROOT infinite_loop = pred[] constant(true) + } + + ENTRY main_spmd { + data = f32[16]{0} parameter(0) + after_all = token[] after-all() + send = (f32[16]{0}, u32[], token[]) send(data, after_all), channel_id=1, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + init = ((f32[16]{0}, u32[], token[]), f32[16]{0}) tuple(send, data) + while = ((f32[16]{0}, u32[], token[]), f32[16]{0}) while(init), + condition=while_condition, body=while_body + send_ctx = (f32[16]{0}, u32[], token[]) get-tuple-element(while), index=0 + send_done = (f32[16]{0}, token[]) send-done(send_ctx), channel_id=1 + ROOT data_ = f32[16]{0} get-tuple-element(while), index=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kModuleString)); + CopyInsertion copy_insertion(nullptr, + /*use_region_based_live_range_analysis=*/-1); + + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + VLOG(2) << module->ToString(); + + // All async start/end will be ordered so that all copies, except for an extra + // use of the send operand, are removable. Additionally, there will be 2 + // copies leading into the loop and returning copying the result. + HloComputation* while_body = + hlo_query::FindComputation(module.get(), "while_body"); + EXPECT_EQ(CountCopies(*module), 3); + EXPECT_EQ(CountCopies(*while_body), 1); + + // Expect control dependency from send-done to send. + HloInstruction* send_done = + hlo_query::FindInstruction(while_body, HloOpcode::kSendDone); + HloInstruction* send = + hlo_query::FindInstruction(while_body, HloOpcode::kSend); + HloInstruction* send_operand_copy = + hlo_query::FindInstruction(while_body, HloOpcode::kCopy); + EXPECT_THAT(send, op::Send(send_operand_copy, op::AfterAll())); + EXPECT_THAT(send_operand_copy->control_predecessors(), + UnorderedElementsAre(send_done)); +} + +TEST_F(CopyInsertionTest, PartiallyPipelinedAsyncSendRecvPipelineParallelism) { + constexpr absl::string_view kModuleString = R"( + HloModule test, entry_computation_layout={(f32[16]{0})->f32[16]{0}}, + num_partitions=4 + + while_body { + param = ((f32[16]{0}, u32[], token[]), (f32[16]{0}, u32[], token[]), + f32[16]{0}, f32[16]{0}) parameter(0) + + prev_fwd = f32[16]{0} get-tuple-element(param), index=3 + + prev_send = (f32[16]{0}, u32[], token[]) get-tuple-element(param), index=0 + send_done = (f32[16]{0}, token[]) send-done(prev_send), channel_id=1 + prev_recv = (f32[16]{0}, u32[], token[]) get-tuple-element(param), index=1 + recv_done = (f32[16]{0}, token[]) recv-done(prev_recv), channel_id=2 + + fwd = f32[16]{0} get-tuple-element(recv_done), index=0 + + after_all = token[] after-all() + send = (f32[16]{0}, u32[], token[]) send(prev_fwd, after_all), + channel_id=1, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + recv = (f32[16]{0}, u32[], token[]) recv(after_all), channel_id=2, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + + // Both, the data that was sent and the data that was received are live + // until the end of the while loop. + ROOT tuple = ((f32[16]{0}, u32[], token[]), (f32[16]{0}, u32[], token[]), + f32[16]{0}, f32[16]{0}) tuple(send, recv, prev_fwd, fwd) + } + + // Infinite loop to keep IR small. + while_condition { + param = ((f32[16]{0}, u32[], token[]), (f32[16]{0}, u32[], token[]), + f32[16]{0}, f32[16]{0}) parameter(0) + ROOT infinite_loop = pred[] constant(true) + } + + ENTRY main_spmd { + data = f32[16]{0} parameter(0) + after_all = token[] after-all() + recv = (f32[16]{0}, u32[], token[]) recv(after_all), channel_id=1, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + send = (f32[16]{0}, u32[], token[]) send(data, after_all), channel_id=2, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + init = ((f32[16]{0}, u32[], token[]), (f32[16]{0}, u32[], token[]), + f32[16]{0}, f32[16]{0}) tuple(send, recv, data, data) + while = ((f32[16]{0}, u32[], token[]), (f32[16]{0}, u32[], token[]), + f32[16]{0}, f32[16]{0}) while(init), condition=while_condition, + body=while_body + recv_ctx = (f32[16]{0}, u32[], token[]) get-tuple-element(while), index=0 + recv_done = (f32[16]{0}, token[]) recv-done(recv_ctx), channel_id=1 + send_ctx = (f32[16]{0}, u32[], token[]) get-tuple-element(while), index=0 + send_done = (f32[16]{0}, token[]) send-done(send_ctx), channel_id=2 + ROOT data_ = f32[16]{0} get-tuple-element(recv_done), index=0 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kModuleString)); + CopyInsertion copy_insertion(nullptr, + /*use_region_based_live_range_analysis=*/-1); + + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + VLOG(2) << module->ToString(); + + // All async start/end will be ordered so that all copies but two are + // removable: + // - The copy for the extra use of the send operand. + // - The copy for the extra use of the recv result. + // The copy removal heuristic fails on removing one data copy, so the total + // number of copies in the while loop body is 3. + HloComputation* while_body = + hlo_query::FindComputation(module.get(), "while_body"); + EXPECT_EQ(CountCopies(*module), 6); + EXPECT_EQ(CountCopies(*while_body), 3); + + // Expect control dependency from send-done to send. + HloInstruction* send_done = + hlo_query::FindInstruction(while_body, HloOpcode::kSendDone); + HloInstruction* send = + hlo_query::FindInstruction(while_body, HloOpcode::kSend); + HloInstruction* send_operand_copy = send->mutable_operand(0); + EXPECT_THAT(send_operand_copy, op::Copy()); + EXPECT_THAT(send, op::Send(send_operand_copy, op::AfterAll())); + EXPECT_THAT(send_operand_copy->control_predecessors(), + UnorderedElementsAre(send_done)); + + // Expect control dependency from recv-done to recv. + HloInstruction* recv_done = + hlo_query::FindInstruction(while_body, HloOpcode::kRecvDone); + HloInstruction* recv = + hlo_query::FindInstruction(while_body, HloOpcode::kRecv); + HloInstruction* recv_done_copy = *absl::c_find_if( + recv->control_predecessors(), HloPredicateIsOp); + EXPECT_THAT(recv_done_copy, op::Copy(op::GetTupleElement(recv_done))); + EXPECT_THAT(recv->control_predecessors(), + UnorderedElementsAre(recv_done, recv_done_copy)); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index d56ecda8559fd5..aadbf246c46377 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -81,6 +81,7 @@ filegroup( "runtime_single_threaded_matmul_c128.cc", "runtime_single_threaded_matmul_c64.cc", "runtime_single_threaded_matmul_common.h", + "runtime_single_threaded_matmul_f8.cc", "runtime_single_threaded_matmul_f16.cc", "runtime_single_threaded_matmul_f32.cc", "runtime_single_threaded_matmul_f64.cc", @@ -179,8 +180,9 @@ cc_library( "//xla/service:compiler", "//xla/service:generic_transfer_manager", "//xla/service:transfer_manager", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/host:host_platform_id", "@com_google_absl//absl/base", "@com_google_absl//absl/status:statusor", @@ -240,105 +242,107 @@ cc_library( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/backends/cpu/runtime:thunk", + "//xla/hlo/analysis:hlo_ordering", + "//xla/hlo/analysis:indexed_array_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/transforms:algebraic_simplifier", + "//xla/hlo/transforms:batch_dot_simplification", + "//xla/hlo/transforms:bitcast_dtypes_expander", + "//xla/hlo/transforms:broadcast_canonicalizer", + "//xla/hlo/transforms:cholesky_expander", + "//xla/hlo/transforms:comparison_expander", + "//xla/hlo/transforms:conditional_canonicalizer", + "//xla/hlo/transforms:convolution_group_converter", + "//xla/hlo/transforms:dot_decomposer", + "//xla/hlo/transforms:dynamic_dimension_simplifier", + "//xla/hlo/transforms:dynamic_index_splitter", + "//xla/hlo/transforms:eigh_expander", + "//xla/hlo/transforms:flatten_call_graph", + "//xla/hlo/transforms:float_normalization", + "//xla/hlo/transforms:hlo_constant_folding", + "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms:hlo_memory_scheduler", + "//xla/hlo/transforms:logistic_expander", + "//xla/hlo/transforms:operand_upcaster", + "//xla/hlo/transforms:optimization_barrier_expander", + "//xla/hlo/transforms:optimize_input_output_buffer_alias", + "//xla/hlo/transforms:qr_expander", + "//xla/hlo/transforms:reduce_decomposer", + "//xla/hlo/transforms:reduce_window_rewriter", + "//xla/hlo/transforms:reshape_decomposer", + "//xla/hlo/transforms:reshape_mover", + "//xla/hlo/transforms:result_caster", + "//xla/hlo/transforms:rng_bit_generator_expander", + "//xla/hlo/transforms:rng_expander", + "//xla/hlo/transforms:simplify_fp_conversions", + "//xla/hlo/transforms:slice_sinker", + "//xla/hlo/transforms:sort_simplifier", + "//xla/hlo/transforms:stochastic_convert_decomposer", + "//xla/hlo/transforms:sub_byte_normalization", + "//xla/hlo/transforms:tree_reduction_rewriter", + "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/transforms:while_loop_trip_count_annotator", + "//xla/hlo/transforms:zero_sized_hlo_elimination", + "//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", + "//xla/hlo/translate/hlo_to_mhlo:hlo_utils", "//xla/mlir_hlo", "//xla/mlir_hlo:all_passes", "//xla/mlir_hlo:mhlo_passes", "//xla/mlir_hlo:transforms_passes", - "//xla/service:algebraic_simplifier", "//xla/service:all_reduce_promotion", "//xla/service:all_to_all_decomposer", - "//xla/service:batch_dot_simplification", "//xla/service:batched_gather_scatter_normalizer", "//xla/service:batchnorm_expander", - "//xla/service:bitcast_dtypes_expander", - "//xla/service:broadcast_canonicalizer", "//xla/service:buffer_assignment", "//xla/service:call_graph", "//xla/service:call_inliner", "//xla/service:change_op_data_type", - "//xla/service:cholesky_expander", - "//xla/service:comparison_expander", "//xla/service:compiler", - "//xla/service:conditional_canonicalizer", "//xla/service:conditional_simplifier", "//xla/service:conditional_to_select", - "//xla/service:convolution_group_converter", "//xla/service:copy_insertion", "//xla/service:cpu_gpu_shape_verifier", - "//xla/service:dot_decomposer", "//xla/service:dump", "//xla/service:dynamic_dimension_inference", - "//xla/service:dynamic_dimension_simplifier", - "//xla/service:dynamic_index_splitter", "//xla/service:dynamic_padder", - "//xla/service:eigh_expander", "//xla/service:executable", - "//xla/service:flatten_call_graph", - "//xla/service:float_normalization", "//xla/service:float_support", "//xla/service:gather_expander", - "//xla/service:hlo_constant_folding", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_cse", - "//xla/service:hlo_dce", "//xla/service:hlo_execution_profile", - "//xla/service:hlo_memory_scheduler", "//xla/service:hlo_module_config", - "//xla/service:hlo_ordering", "//xla/service:hlo_profile_printer_data_cc", "//xla/service:hlo_proto_cc", "//xla/service:hlo_proto_util", "//xla/service:hlo_verifier", - "//xla/service:indexed_array_analysis", "//xla/service:layout_assignment", "//xla/service:llvm_compiler", "//xla/service:logical_buffer", - "//xla/service:logistic_expander", "//xla/service:map_inliner", - "//xla/service:operand_upcaster", - "//xla/service:optimization_barrier_expander", - "//xla/service:optimize_input_output_buffer_alias", - "//xla/service:qr_expander", - "//xla/service:reduce_decomposer", - "//xla/service:reduce_window_rewriter", - "//xla/service:reshape_decomposer", - "//xla/service:reshape_mover", - "//xla/service:result_caster", - "//xla/service:rng_bit_generator_expander", - "//xla/service:rng_expander", "//xla/service:scatter_expander", "//xla/service:select_and_scatter_expander", "//xla/service:sharding_propagation", "//xla/service:sharding_remover", - "//xla/service:simplify_fp_conversions", - "//xla/service:slice_sinker", "//xla/service:slow_operation_alarm", - "//xla/service:sort_simplifier", - "//xla/service:stochastic_convert_decomposer", - "//xla/service:sub_byte_normalization", "//xla/service:topk_rewriter", "//xla/service:transpose_folding", - "//xla/service:tree_reduction_rewriter", "//xla/service:triangular_solve_expander", - "//xla/service:tuple_simplifier", "//xla/service:while_loop_constant_sinking", "//xla/service:while_loop_invariant_code_motion", "//xla/service:while_loop_simplifier", - "//xla/service:while_loop_trip_count_annotator", - "//xla/service:zero_sized_hlo_elimination", "//xla/service/llvm_ir:llvm_command_line_options", "//xla/service/llvm_ir:llvm_util", "//xla/service/spmd:stateful_rng_spmd_partitioner", "//xla/service/spmd/shardy:shardy_xla_pass", - "//xla/stream_executor", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/host:host_platform_id", - "//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", - "//xla/translate/hlo_to_mhlo:hlo_utils", "//xla/tsl/concurrency:async_value", + "//xla/tsl/protobuf:error_codes_proto_impl_cc", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", @@ -393,7 +397,6 @@ cc_library( "@local_tsl//tsl/platform:threadpool_async_executor", "@local_tsl//tsl/profiler/lib:traceme", "@local_tsl//tsl/profiler/lib:traceme_encode", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", ] + if_llvm_aarch64_available([ "@llvm-project//llvm:AArch64CodeGen", # fixdeps: keep ]) + if_llvm_powerpc_available([ @@ -428,7 +431,7 @@ cc_library( "//xla/service:hlo_profile_printer_data_cc", "//xla/service:hlo_proto_cc", "//xla/service:llvm_compiler", - "//xla/stream_executor", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/host:host_platform_id", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -497,6 +500,7 @@ cc_library( "//xla:util", "//xla/service:custom_call_target_registry", "//xla/service:llvm_compiler", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/memory", @@ -512,6 +516,7 @@ cc_library( "@llvm-project//llvm:TargetParser", "@llvm-project//mlir:mlir_c_runner_utils", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:platform_port", ] + xla_internal(["service/cpu:named_orc_jit_memory_mapper"]), ) @@ -556,43 +561,35 @@ cc_library( srcs = ["cpu_executable.cc"], hdrs = ["cpu_executable.h"], deps = [ - ":buffer_desc", ":cpu_runtime", ":simple_orc_jit", - ":xla_framework", "//xla:executable_run_options", "//xla:literal", "//xla:shape_tree", "//xla:shape_util", "//xla:status_macros", - "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", "//xla/backends/cpu/runtime:buffer_allocations", + "//xla/backends/cpu/runtime:thread_pool_task_runner", "//xla/backends/cpu/runtime:thunk", "//xla/backends/cpu/runtime:thunk_executor", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", - "//xla/service:computation_layout", "//xla/service:custom_call_status", "//xla/service:custom_call_status_internal", "//xla/service:executable", - "//xla/service:hlo_dataflow_analysis", "//xla/service:hlo_execution_profile", "//xla/service:hlo_value", - "//xla/service:logical_buffer", "//xla/service:maybe_owning_device_memory", "//xla/service:shaped_buffer", "//xla/service:xla_debug_info_manager", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", - "//xla/stream_executor/host:host_kernel_c_api", "//xla/stream_executor/host:host_stream", "//xla/tsl/concurrency:async_value", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", - "@com_google_absl//absl/cleanup", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -600,12 +597,8 @@ cc_library( "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", "@llvm-project//llvm:Core", - "@llvm-project//llvm:ExecutionEngine", - "@llvm-project//llvm:OrcJIT", "@llvm-project//llvm:OrcShared", "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:Parser", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", @@ -681,11 +674,11 @@ xla_cc_test( ":ir_emitter", ":ir_function", ":target_machine_features_fake", + "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/service:buffer_assignment", "//xla/service:hlo_module_config", - "//xla/service:hlo_ordering", - "//xla/service:hlo_parser", "//xla/service:logical_buffer", "//xla/tests:hlo_test_base", "@llvm-project//llvm:Core", @@ -706,11 +699,11 @@ xla_cc_test( "//xla:cpu_function_runtime", "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/service:buffer_assignment", "//xla/service:hlo_module_config", - "//xla/service:hlo_ordering", - "//xla/service:hlo_parser", "//xla/service:logical_buffer", "//xla/service/llvm_ir:ir_array", "//xla/service/llvm_ir:llvm_util", @@ -768,6 +761,7 @@ cc_library( "//xla/service/llvm_ir:llvm_util", "//xla/service/llvm_ir:loop_emitter", "//xla/service/llvm_ir:tuple_ops", + "//xla/tsl/lib/math:math_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", @@ -782,7 +776,6 @@ cc_library( "@llvm-project//llvm:Core", "@llvm-project//llvm:TargetParser", "@llvm-project//mlir:IR", - "@local_tsl//tsl/lib/math:math_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", @@ -980,8 +973,8 @@ xla_cc_binary( "//xla/client:client_library", "//xla/client:global_data", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/platform:logging", @@ -1038,11 +1031,12 @@ cc_library( "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/hlo/parser:hlo_parser", "//xla/service:collective_ops_utils", "//xla/service:computation_placer", "//xla/service:global_device_id", - "//xla/service:hlo_parser", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", @@ -1281,6 +1275,7 @@ cc_library( "runtime_single_threaded_matmul_f16.cc", "runtime_single_threaded_matmul_f32.cc", "runtime_single_threaded_matmul_f64.cc", + "runtime_single_threaded_matmul_f8.cc", "runtime_single_threaded_matmul_s32.cc", "runtime_single_threaded_matmul_u8.cc", ], @@ -1293,6 +1288,7 @@ cc_library( "//xla/tsl/framework/contraction:eigen_contraction_kernel_no_mkl", "@com_google_absl//absl/base:core_headers", "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:ml_dtypes", ], ) @@ -1305,6 +1301,7 @@ cc_library( deps = [ ":runtime_single_threaded_matmul_impl", "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:ml_dtypes", ], ) @@ -1525,7 +1522,6 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", - "//xla/service:algebraic_simplifier", "//xla/service:computation_layout", "//xla/tests:hlo_test_base", "//xla/tests:test_utils", @@ -1623,11 +1619,11 @@ xla_cc_test( ":backend_config_proto_cc", ":cpu_executable", ":parallel_task_assignment", + ":target_machine_features", ":target_machine_features_fake", "//xla:test", "//xla/hlo/ir:hlo", "//xla/service:hlo_cost_analysis", - "//xla/service/cpu:target_machine_features", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", @@ -1723,8 +1719,8 @@ xla_cc_test( deps = [ ":ir_emitter", ":target_machine_features_fake", + "//xla/hlo/analysis:hlo_ordering", "//xla/service:buffer_assignment", - "//xla/service:hlo_ordering", "//xla/service:logical_buffer", "//xla/tests:hlo_test_base", "@com_google_googletest//:gtest_main", @@ -1757,6 +1753,8 @@ cc_library( copts = runtime_copts() + tsl_copts(), visibility = ["//visibility:public"], deps = [ + ":backend_config_proto_cc", + ":onednn_config_proto_cc", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "@eigen_archive//:eigen3", diff --git a/third_party/xla/xla/service/cpu/benchmarks/BUILD b/third_party/xla/xla/service/cpu/benchmarks/BUILD index aa828150798a57..1a92411eb88552 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/BUILD +++ b/third_party/xla/xla/service/cpu/benchmarks/BUILD @@ -21,13 +21,13 @@ cc_library( hdrs = ["hlo_benchmark_runner.h"], deps = [ "//xla:literal", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_executable", "//xla/pjrt/cpu:cpu_client", "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", diff --git a/third_party/xla/xla/service/cpu/benchmarks/hlo_benchmark_runner.cc b/third_party/xla/xla/service/cpu/benchmarks/hlo_benchmark_runner.cc index ed6aac57504c41..94fe5fd3ab2b85 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/hlo_benchmark_runner.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/hlo_benchmark_runner.cc @@ -22,14 +22,14 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_replace.h" #include "absl/types/span.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/literal.h" #include "xla/pjrt/cpu/cpu_client.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_executable.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test_benchmark.h" diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index ada4b7ba8dfae8..21feed2849ce32 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -78,6 +78,8 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" #include "xla/backends/cpu/runtime/thunk.h" #include "xla/cpu_function_runtime.h" +#include "xla/hlo/analysis/hlo_ordering.h" +#include "xla/hlo/analysis/indexed_array_analysis.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -89,29 +91,58 @@ limitations under the License. #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/pass/hlo_pass_fix.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/transforms/expanders/bitcast_dtypes_expander.h" +#include "xla/hlo/transforms/expanders/cholesky_expander.h" +#include "xla/hlo/transforms/expanders/comparison_expander.h" +#include "xla/hlo/transforms/expanders/dot_decomposer.h" +#include "xla/hlo/transforms/expanders/dynamic_index_splitter.h" +#include "xla/hlo/transforms/expanders/eigh_expander.h" +#include "xla/hlo/transforms/expanders/logistic_expander.h" +#include "xla/hlo/transforms/expanders/optimization_barrier_expander.h" +#include "xla/hlo/transforms/expanders/qr_expander.h" +#include "xla/hlo/transforms/expanders/reduce_decomposer.h" +#include "xla/hlo/transforms/expanders/reshape_decomposer.h" +#include "xla/hlo/transforms/expanders/rng_bit_generator_expander.h" +#include "xla/hlo/transforms/expanders/rng_expander.h" +#include "xla/hlo/transforms/expanders/stochastic_convert_decomposer.h" +#include "xla/hlo/transforms/operand_upcaster.h" +#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" +#include "xla/hlo/transforms/simplifiers/batch_dot_simplification.h" +#include "xla/hlo/transforms/simplifiers/broadcast_canonicalizer.h" +#include "xla/hlo/transforms/simplifiers/conditional_canonicalizer.h" +#include "xla/hlo/transforms/simplifiers/convolution_group_converter.h" +#include "xla/hlo/transforms/simplifiers/dynamic_dimension_simplifier.h" +#include "xla/hlo/transforms/simplifiers/flatten_call_graph.h" +#include "xla/hlo/transforms/simplifiers/float_normalization.h" +#include "xla/hlo/transforms/simplifiers/hlo_constant_folding.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" +#include "xla/hlo/transforms/simplifiers/optimize_input_output_buffer_alias.h" +#include "xla/hlo/transforms/simplifiers/reduce_window_rewriter.h" +#include "xla/hlo/transforms/simplifiers/reshape_mover.h" +#include "xla/hlo/transforms/simplifiers/result_caster.h" +#include "xla/hlo/transforms/simplifiers/sort_simplifier.h" +#include "xla/hlo/transforms/simplifiers/sub_byte_normalization.h" +#include "xla/hlo/transforms/simplifiers/tree_reduction_rewriter.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" +#include "xla/hlo/transforms/simplifiers/zero_sized_hlo_elimination.h" +#include "xla/hlo/transforms/while_loop_trip_count_annotator.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "xla/literal.h" #include "xla/map_util.h" #include "xla/mlir_hlo/transforms/passes.h" #include "xla/primitive_util.h" -#include "xla/service/algebraic_simplifier.h" #include "xla/service/all_reduce_promotion.h" #include "xla/service/all_to_all_decomposer.h" -#include "xla/service/batch_dot_simplification.h" #include "xla/service/batched_gather_scatter_normalizer.h" #include "xla/service/batchnorm_expander.h" -#include "xla/service/bitcast_dtypes_expander.h" -#include "xla/service/broadcast_canonicalizer.h" #include "xla/service/buffer_assignment.h" #include "xla/service/call_graph.h" #include "xla/service/call_inliner.h" #include "xla/service/change_op_data_type.h" -#include "xla/service/cholesky_expander.h" -#include "xla/service/comparison_expander.h" #include "xla/service/compiler.h" -#include "xla/service/conditional_canonicalizer.h" #include "xla/service/conditional_simplifier.h" #include "xla/service/conditional_to_select.h" -#include "xla/service/convolution_group_converter.h" #include "xla/service/copy_insertion.h" #include "xla/service/cpu/buffer_info_util.h" #include "xla/service/cpu/compiler_functor.h" @@ -129,75 +160,44 @@ limitations under the License. #include "xla/service/cpu/target_machine_features.h" #include "xla/service/cpu/thunk_emitter.h" #include "xla/service/cpu_gpu_shape_verifier.h" -#include "xla/service/dot_decomposer.h" #include "xla/service/dump.h" #include "xla/service/dynamic_dimension_inference.h" -#include "xla/service/dynamic_dimension_simplifier.h" -#include "xla/service/dynamic_index_splitter.h" #include "xla/service/dynamic_padder.h" -#include "xla/service/eigh_expander.h" #include "xla/service/executable.h" -#include "xla/service/flatten_call_graph.h" -#include "xla/service/float_normalization.h" #include "xla/service/float_support.h" #include "xla/service/gather_expander.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_constant_folding.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_cse.h" -#include "xla/service/hlo_dce.h" #include "xla/service/hlo_execution_profile.h" -#include "xla/service/hlo_memory_scheduler.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_ordering.h" #include "xla/service/hlo_profile_printer_data.pb.h" #include "xla/service/hlo_verifier.h" -#include "xla/service/indexed_array_analysis.h" #include "xla/service/layout_assignment.h" #include "xla/service/llvm_compiler.h" #include "xla/service/llvm_ir/llvm_command_line_options.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/service/logical_buffer.h" -#include "xla/service/logistic_expander.h" #include "xla/service/map_inliner.h" -#include "xla/service/operand_upcaster.h" -#include "xla/service/optimization_barrier_expander.h" -#include "xla/service/optimize_input_output_buffer_alias.h" -#include "xla/service/qr_expander.h" -#include "xla/service/reduce_decomposer.h" -#include "xla/service/reduce_window_rewriter.h" -#include "xla/service/reshape_decomposer.h" -#include "xla/service/reshape_mover.h" -#include "xla/service/result_caster.h" -#include "xla/service/rng_bit_generator_expander.h" -#include "xla/service/rng_expander.h" #include "xla/service/scatter_expander.h" #include "xla/service/select_and_scatter_expander.h" #include "xla/service/sharding_propagation.h" #include "xla/service/sharding_remover.h" #include "xla/service/slow_operation_alarm.h" -#include "xla/service/sort_simplifier.h" #include "xla/service/spmd/shardy/shardy_xla_pass.h" #include "xla/service/spmd/stateful_rng_spmd_partitioner.h" -#include "xla/service/stochastic_convert_decomposer.h" -#include "xla/service/sub_byte_normalization.h" #include "xla/service/topk_rewriter.h" #include "xla/service/transpose_folding.h" -#include "xla/service/tree_reduction_rewriter.h" #include "xla/service/triangular_solve_expander.h" -#include "xla/service/tuple_simplifier.h" #include "xla/service/while_loop_constant_sinking.h" #include "xla/service/while_loop_invariant_code_motion.h" #include "xla/service/while_loop_simplifier.h" -#include "xla/service/while_loop_trip_count_annotator.h" -#include "xla/service/zero_sized_hlo_elimination.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/host/host_platform_id.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "xla/tsl/concurrency/async_value.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/chain.h" @@ -221,10 +221,10 @@ limitations under the License. #endif #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) +#include "xla/hlo/transforms/simplifiers/simplify_fp_conversions.h" #include "xla/service/cpu/cpu_float_support.h" #include "xla/service/cpu/onednn_contraction_rewriter.h" #include "xla/service/cpu/onednn_ops_rewriter.h" -#include "xla/service/simplify_fp_conversions.h" #endif namespace xla { @@ -600,6 +600,8 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( #endif FloatSupport f8e5m2_support(F8E5M2, F16); pipeline.AddPass(&f8e5m2_support); + FloatSupport f8e4m3_support(F8E4M3, F16); + pipeline.AddPass(&f8e4m3_support); FloatSupport f8e4m3fn_support(F8E4M3FN, F16); pipeline.AddPass(&f8e4m3fn_support); FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ, F16); @@ -608,6 +610,8 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(&f8e5m2fnuz_support); FloatSupport f8e4m3fnuz_support(F8E4M3FNUZ, F16); pipeline.AddPass(&f8e4m3fnuz_support); + FloatSupport f8e3m4_support(F8E3M4, F16); + pipeline.AddPass(&f8e3m4_support); // After canonicalization, there may be more batch dots that can be // simplified. pipeline.AddPass(); @@ -862,7 +866,18 @@ absl::Status CpuCompiler::RunHloPassesAfterLayoutAssn( // interfering with the rewrites. pipeline.AddPass(); pipeline.AddPass(true); - pipeline.AddPass(); + + // If enabled we'll use more precise region based analysis for copy removal. + if (module->config() + .debug_options() + .xla_cpu_copy_insertion_use_region_analysis()) { + pipeline.AddPass( + /*can_share_buffer=*/nullptr, + /*use_region_based_live_range_analysis=*/-1); + } else { + pipeline.AddPass(); + } + pipeline.AddPass(); return pipeline.Run(module).status(); } @@ -999,10 +1014,11 @@ absl::Status CreateHloProfilingArtifacts( absl::StatusOr> CpuCompiler::RunHloPasses( std::unique_ptr module, se::StreamExecutor* /*stream_exec*/, const CompileOptions& options) { + auto& config = module->config(); std::unique_ptr jit_target_machine = SimpleOrcJIT::InferTargetMachineForJIT( - CompilerTargetOptions(module->config()), - CodeGenOptLevel(module->config())); + CompilerTargetOptions(config), CodeGenOptLevel(config), + config.debug_options().xla_cpu_max_isa()); TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false, jit_target_machine.get(), @@ -1302,6 +1318,18 @@ static bool HasLargeConstants(llvm::Module& module) { return false; } +inline void VlogMaxIsa(absl::string_view max_cpu_isa) { + if (VLOG_IS_ON(1) && !max_cpu_isa.empty()) { + if (tsl::port::IsX86CPU()) { + VLOG(1) << "`xla_cpu_max_isa` is set. Will not use features newer than: " + << max_cpu_isa; + } else { + VLOG(1) << "`xla_cpu_max_isa` is set to `" << max_cpu_isa + << "`. This flag is not supported on non-x86 CPUs yet."; + } + } +} + absl::StatusOr> CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { TraceMe trace([&] { @@ -1331,6 +1359,7 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { // parallel compilation at run time. size_t parallel_codegen_split_count = debug_options.xla_cpu_parallel_codegen_split_count(); + VlogMaxIsa(debug_options.xla_cpu_max_isa()); auto jit = SimpleOrcJIT::Create( CompilerTargetOptions(module->config()), @@ -1341,7 +1370,7 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { llvm_ir::GetCpuFastMathFlags(module->config()), pre_optimization_ir_hook, post_optimization_ir_hook, CreateOrcJITPostCompilationHook(module.get(), &obj_files), - parallel_codegen_split_count); + parallel_codegen_split_count, debug_options.xla_cpu_max_isa()); if (!jit) { return Internal("Creating JIT failed: %s", llvm::toString(jit.takeError())); } @@ -2033,15 +2062,18 @@ CpuExecutableAotCompilationResult::LoadExecutable( compiler->BufferSizeBytesFunction(), /*can_share_buffer=*/nullptr)); + const DebugOptions& debug_options = module->config().debug_options(); + VlogMaxIsa(debug_options.xla_cpu_max_isa()); auto jit = SimpleOrcJIT::Create( CompilerTargetOptions(module->config()), CodeGenOptLevel(module->config()), options::OptimizeForSizeRequested(module->config()), - module->config().debug_options().xla_llvm_disable_expensive_passes(), + debug_options.xla_llvm_disable_expensive_passes(), options::SlpVectorizerDisabled(module->config()), llvm_ir::GetCpuFastMathFlags(module->config()), /*pre_optimization_hook=*/nullptr, /*post_optimization_hook=*/nullptr, - /*post_codegen_hook=*/nullptr); + /*post_codegen_hook=*/nullptr, /*num_jit_dylibs=*/1, + debug_options.xla_cpu_max_isa()); if (!jit) { return Internal("Creating JIT failed: %s", llvm::toString(jit.takeError())); } diff --git a/third_party/xla/xla/service/cpu/cpu_executable.cc b/third_party/xla/xla/service/cpu/cpu_executable.cc index e1f4b213170651..de8b79351cd18b 100644 --- a/third_party/xla/xla/service/cpu/cpu_executable.cc +++ b/third_party/xla/xla/service/cpu/cpu_executable.cc @@ -42,6 +42,7 @@ limitations under the License. #include "llvm/IR/Mangler.h" #include "llvm/Support/Error.h" #include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/thread_pool_task_runner.h" #include "xla/backends/cpu/runtime/thunk.h" #include "xla/backends/cpu/runtime/thunk_executor.h" #include "xla/executable_run_options.h" @@ -389,15 +390,15 @@ absl::Status CpuExecutable::ExecuteThunks( Thunk::CustomCallExecuteParams::Create(run_options)); // Use the intra-op thread pool to offload thunk executor tasks. - Thunk::TaskRunner task_runner = [run_options](Thunk::Task task) { - run_options->intra_op_thread_pool()->getPool()->Schedule(std::move(task)); - }; + auto* intra_op_thread_pool = run_options->intra_op_thread_pool(); + ThreadPoolTaskRunner task_runner( + intra_op_thread_pool ? intra_op_thread_pool->getPool() : nullptr); Thunk::ExecuteParams execute_params = { &*function_registry_, &allocations, runtime::GetXfeedManager(runtime::GetDeviceOrdinal(run_options)), - run_options->intra_op_thread_pool(), + intra_op_thread_pool, &task_runner, &collective_execute_params, &custom_call_execute_params}; diff --git a/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc b/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc index 1743cad5c2c5ec..957bcfb22d851c 100644 --- a/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc +++ b/third_party/xla/xla/service/cpu/cpu_instruction_fusion.cc @@ -79,22 +79,23 @@ FusionDecision CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, if (CanBeOutputFused(producer, consumer)) { VLOG(2) << "Fusion OK: Can create output fusion."; - return {}; + return FusionDecision::Allow(); } if (CanBeOutputFusedIntoSomeOperand(producer)) { - return "Bailing because producer can be output-fused into some operand."; + return FusionDecision::Forbid( + "Bailing because producer can be output-fused into some operand."); } if (!CanBeLoopFused(*producer)) { - return "Producer is not loop-fusible."; + return FusionDecision::Forbid("Producer is not loop-fusible."); } // Cost condition: not fuse (simple, expensive producers) and (consumers who // reuse operand elements). if (producer->opcode() != HloOpcode::kFusion && is_expensive(*producer) && ReusesOperandElements(consumer, operand_index)) { - return "Fusion is not profitable."; + return FusionDecision::Forbid("Fusion is not profitable."); } RETURN_IF_NOT_FUSIBLE(InstructionFusion::ShouldFuse(consumer, operand_index)); @@ -103,12 +104,14 @@ FusionDecision CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, // just a constant and another node. if (producer->opcode() == HloOpcode::kConstant && consumer->opcode() != HloOpcode::kFusion) { - return "Not fusing: insufficient non-constant nodes."; + return FusionDecision::Forbid( + "Not fusing: insufficient non-constant nodes."); } // Output fusion is not currently supported on CPUs. if (producer->opcode() == HloOpcode::kFusion) { - return "Not fusing: producer is itself a fusion node."; + return FusionDecision::Forbid( + "Not fusing: producer is itself a fusion node."); } // Don't fuse if fusing would cause too much code duplication because of @@ -126,7 +129,7 @@ FusionDecision CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, } if (fusion_node_evaluations_.at(consumer).CodeDuplicationTooHigh( producer)) { - return "Code duplication too high"; + return FusionDecision::Forbid("Code duplication too high"); } } @@ -149,13 +152,13 @@ FusionDecision CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, ShapeUtil::ByteSizeOfElements(consumer->operand(0)->shape()) < kFusionThresholdBytes) { VLOG(2) << "Fusing small matrix-vector product."; - return {}; + return FusionDecision::Allow(); } else if (consumer->operand(1)->shape().rank() == 1 && operand_index == 0 && ShapeUtil::ByteSizeOfElements(consumer->operand(1)->shape()) < kFusionThresholdBytes) { VLOG(2) << "Fusing small matrix-vector product."; - return {}; + return FusionDecision::Allow(); } } } @@ -166,26 +169,28 @@ FusionDecision CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, !absl::c_linear_search( consumer->dimensions(), LayoutUtil::Minor(consumer->operand(0)->shape().layout(), 0))) { - return "Not fusing reductions over major dimensions"; + return FusionDecision::Forbid( + "Not fusing reductions over major dimensions"); } if (producer->opcode() == HloOpcode::kReduce && !absl::c_linear_search( producer->dimensions(), LayoutUtil::Minor(producer->operand(0)->shape().layout(), 0))) { - return "Not fusing reductions over major dimensions"; + return FusionDecision::Forbid( + "Not fusing reductions over major dimensions"); } if (consumer->IsLoopFusion()) { VLOG(2) << "Fusing: consumer is a fusion node."; - return {}; + return FusionDecision::Allow(); } if (CanBeLoopFused(*consumer)) { VLOG(2) << "Fusing: consumer is elementwise or fusible."; - return {}; + return FusionDecision::Allow(); } - return "Not fusing: not found a fusible case"; + return FusionDecision::Forbid("Not fusing: not found a fusible case"); } HloInstruction::FusionKind CpuInstructionFusion::ChooseKind( diff --git a/third_party/xla/xla/service/cpu/cpu_layout_assignment_test.cc b/third_party/xla/xla/service/cpu/cpu_layout_assignment_test.cc index 3824dfdf143843..405b2ac1decf79 100644 --- a/third_party/xla/xla/service/cpu/cpu_layout_assignment_test.cc +++ b/third_party/xla/xla/service/cpu/cpu_layout_assignment_test.cc @@ -28,7 +28,6 @@ limitations under the License. #include "xla/hlo/utils/hlo_matchers.h" #include "xla/layout_util.h" #include "xla/literal.h" -#include "xla/service/algebraic_simplifier.h" #include "xla/service/computation_layout.h" #include "xla/service/cpu/target_machine_features_fake.h" #include "xla/shape_layout.h" diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.cc b/third_party/xla/xla/service/cpu/cpu_runtime.cc index 4e209e61f283c6..db8660c811eec4 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.cc +++ b/third_party/xla/xla/service/cpu/cpu_runtime.cc @@ -39,6 +39,7 @@ limitations under the License. #include "absl/time/time.h" #include "absl/types/span.h" #include "xla/executable_run_options.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/layout_util.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/computation_placer.h" @@ -47,7 +48,6 @@ limitations under the License. #include "xla/service/cpu/in_process_collectives.h" #include "xla/service/cpu/xfeed_manager.h" #include "xla/service/global_device_id.h" -#include "xla/service/hlo_parser.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream_executor.h" @@ -117,6 +117,10 @@ extern const char* const kEigenConv3DF32SymbolName = extern const char* const kDuccFftSymbolName = "__xla_cpu_runtime_DuccFft"; extern const char* const kDuccSingleThreadedFftSymbolName = "__xla_cpu_runtime_DuccSingleThreadedFft"; +extern const char* const kEigenSingleThreadedMatMulF8E4M3FNSymbolName = + "__xla_cpu_runtime_EigenSingleThreadedMatMulF8E4M3FN"; +extern const char* const kEigenSingleThreadedMatMulF8E5M2SymbolName = + "__xla_cpu_runtime_EigenSingleThreadedMatMulF8E5M2"; extern const char* const kEigenSingleThreadedMatMulF16SymbolName = "__xla_cpu_runtime_EigenSingleThreadedMatMulF16"; extern const char* const kEigenSingleThreadedMatMulF32SymbolName = diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.h b/third_party/xla/xla/service/cpu/cpu_runtime.h index 92beff43a3c0ea..5ac8e39101c844 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.h +++ b/third_party/xla/xla/service/cpu/cpu_runtime.h @@ -62,6 +62,8 @@ extern const char* const kDuccSingleThreadedFftSymbolName; extern const char* const kEigenSingleThreadedMatMulF16SymbolName; extern const char* const kEigenSingleThreadedMatMulF32SymbolName; extern const char* const kEigenSingleThreadedMatMulF64SymbolName; +extern const char* const kEigenSingleThreadedMatMulF8E4M3FNSymbolName; +extern const char* const kEigenSingleThreadedMatMulF8E5M2SymbolName; extern const char* const kEigenSingleThreadedMatMulC64SymbolName; extern const char* const kEigenSingleThreadedMatMulC128SymbolName; extern const char* const kEigenSingleThreadedMatMulS32SymbolName; diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index e043b5c2e13bec..ed63804c7d1ea6 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -87,10 +87,10 @@ limitations under the License. #include "xla/service/llvm_ir/tuple_ops.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/tsl/lib/math/math_util.h" #include "xla/util.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/math/math_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" @@ -109,13 +109,21 @@ using llvm_ir::SetToFirstInsertPoint; namespace cpu { +bool IsNativeConvertSupportedOnTargetCPU(std::string feature_string) { + return (absl::StrContains(feature_string, "+avxneconvert") || + absl::StrContains(feature_string, "+amx-bf16")); +} + class IrEmitter::CpuElementalIrEmitter : public ElementalIrEmitter { public: CpuElementalIrEmitter(const HloModuleConfig& module_config, IrEmitter* ir_emitter, llvm::Module* module) : ElementalIrEmitter( module, ir_emitter->b(), - Options{/*xla_cpu_use_truncate_f32_to_bf16_conversion=*/true}), + Options{/*xla_cpu_use_truncate_f32_to_bf16_conversion=*/ + !IsNativeConvertSupportedOnTargetCPU( + ir_emitter->target_machine_features_ + .get_target_feature_string())}), hlo_module_config_(module_config), ir_emitter_(ir_emitter) {} @@ -3836,12 +3844,13 @@ llvm::Value* IrEmitter::ProfilingState::ReadCycleCounter(llvm::IRBuilder<>* b) { llvm::Module* module = b->GetInsertBlock()->getModule(); if (!use_rdtscp_) { llvm::Function* func_llvm_readcyclecounter = - llvm::Intrinsic::getDeclaration(module, - llvm::Intrinsic::readcyclecounter); + llvm::Intrinsic::getOrInsertDeclaration( + module, llvm::Intrinsic::readcyclecounter); return b->CreateCall(func_llvm_readcyclecounter); } llvm::Function* func_llvm_x86_rdtscp = - llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::x86_rdtscp); + llvm::Intrinsic::getOrInsertDeclaration(module, + llvm::Intrinsic::x86_rdtscp); llvm::Value* rdtscp_call = b->CreateCall(func_llvm_x86_rdtscp); return b->CreateExtractValue(rdtscp_call, {0}); } diff --git a/third_party/xla/xla/service/cpu/ir_emitter.h b/third_party/xla/xla/service/cpu/ir_emitter.h index d2c94a913c0c42..1112fcf0556ff8 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.h +++ b/third_party/xla/xla/service/cpu/ir_emitter.h @@ -69,6 +69,8 @@ namespace cpu { // Forward declare emitter for XLA:CPU thunks. class IrEmitter2; +bool IsNativeConvertSupportedOnTargetCPU(std::string feature_string); + // This class is the top-level API for the XLA HLO --> LLVM IR compiler. It // implements the DfsHloVisitor interface and emits HLO computations as LLVM IR // functions. diff --git a/third_party/xla/xla/service/cpu/ir_emitter2_test.cc b/third_party/xla/xla/service/cpu/ir_emitter2_test.cc index 13e9ee64880273..698af33ae2580c 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter2_test.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter2_test.cc @@ -29,13 +29,13 @@ limitations under the License. #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Type.h" #include "xla/cpu_function_runtime.h" +#include "xla/hlo/analysis/hlo_ordering.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/ir_emitter.h" #include "xla/service/cpu/target_machine_features_fake.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_ordering.h" -#include "xla/service/hlo_parser.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/service/logical_buffer.h" diff --git a/third_party/xla/xla/service/cpu/ir_emitter_test.cc b/third_party/xla/xla/service/cpu/ir_emitter_test.cc index 7102d20421df4b..ce5c8c403bce7b 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter_test.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter_test.cc @@ -28,13 +28,13 @@ limitations under the License. #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/Support/Casting.h" +#include "xla/hlo/analysis/hlo_ordering.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/ir_function.h" #include "xla/service/cpu/target_machine_features_fake.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_ordering.h" -#include "xla/service/hlo_parser.h" #include "xla/service/logical_buffer.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" @@ -120,5 +120,64 @@ TEST_F(IrEmitterTest, ComputeFuncStack) { ir_emitter.PopComputeFunction(); } +TEST_F(IrEmitterTest, CheckNativeConvertSupportOnTargetCPU) { + std::string spr_feature_string = + "+prfchw,+cldemote,+avx,+aes,+sahf,+pclmul,-xop,+crc32,+xsaves,+" + "avx512fp16,-usermsr,-sm4,-egpr,+sse4.1,+avx512ifma,+xsave,+sse4.2,+" + "tsxldtrk,-sm3,+ptwrite,-widekl,+invpcid,+64bit,+xsavec,-avx10.1-512,+" + "avx512vpopcntdq,+cmov,-avx512vp2intersect,+avx512cd,+movbe,-avxvnniint8," + "-ccmp,+amx-int8,-kl,-avx10.1-256,+evex512,+avxvnni,-rtm,+adx,+avx2,-" + "hreset,+movdiri,+serialize,-sha512,+vpclmulqdq,+avx512vl,+uintr,-cf,+" + "clflushopt,-raoint,-cmpccxadd,+bmi,+amx-tile,+sse,-avx10.2-256,+gfni,-" + "avxvnniint16,-amx-fp16,-zu,-ndd,+xsaveopt,+rdrnd,+avx512f,+amx-bf16,+" + "avx512bf16,+avx512vnni,-push2pop2,+cx8,+avx512bw,+sse3,+pku,-nf,+" + "fsgsbase,-clzero,-mwaitx,-lwp,+lzcnt,+sha,+movdir64b,-ppx,+wbnoinvd,+" + "enqcmd,-avx10.2-512,-avxneconvert,-tbm,-pconfig,-amx-complex,+ssse3,+" + "cx16,+bmi2,+fma,+popcnt,-avxifma,+f16c,+avx512bitalg,-rdpru,+clwb,+mmx,+" + "sse2,+rdseed,+avx512vbmi2,-prefetchi,+rdpid,-fma4,+avx512vbmi,+shstk,+" + "vaes,+waitpkg,-sgx,+fxsr,+avx512dq,-sse4a"; + + std::string skx_feature_string = + "+prfchw,-cldemote,+avx,+aes,+sahf,+pclmul,-xop,+crc32,+xsaves,-" + "avx512fp16,-usermsr,-sm4,-egpr,+sse4.1,-avx512ifma,+xsave,+sse4.2,-" + "tsxldtrk,-sm3,-ptwrite,-widekl,+invpcid,+64bit,+xsavec,-avx10.1-512,-" + "avx512vpopcntdq,+cmov,-avx512vp2intersect,+avx512cd,+movbe,-avxvnniint8," + "-ccmp,-amx-int8,-kl,-avx10.1-256,+evex512,-avxvnni,+rtm,+adx,+avx2,-" + "hreset,-movdiri,-serialize,-sha512,-vpclmulqdq,+avx512vl,-uintr,-cf,+" + "clflushopt,-raoint,-cmpccxadd,+bmi,-amx-tile,+sse,-avx10.2-256,-gfni,-" + "avxvnniint16,-amx-fp16,-zu,-ndd,+xsaveopt,+rdrnd,+avx512f,-amx-bf16,-" + "avx512bf16,-avx512vnni,-push2pop2,+cx8,+avx512bw,+sse3,+pku,-nf,+" + "fsgsbase,-clzero,-mwaitx,-lwp,+lzcnt,-sha,-movdir64b,-ppx,-wbnoinvd,-" + "enqcmd,-avx10.2-512,-avxneconvert,-tbm,-pconfig,-amx-complex,+ssse3,+" + "cx16,+bmi2,+fma,+popcnt,-avxifma,+f16c,-avx512bitalg,-rdpru,+clwb,+mmx,+" + "sse2,+rdseed,-avx512vbmi2,-prefetchi,-rdpid,-fma4,-avx512vbmi,-shstk,-" + "vaes,-waitpkg,-sgx,+fxsr,+avx512dq,-sse4a"; + + std::string srf_feature_string = + "+prfchw,+cldemote,+avx,+aes,+sahf,+pclmul,-xop,+crc32,+xsaves,-" + "avx512fp16,-usermsr,-sm4,-egpr,+sse4.1,-avx512ifma,+xsave,+sse4.2,-" + "tsxldtrk,-sm3,+ptwrite,-widekl,+invpcid,+64bit,+xsavec,-avx10.1-512,-" + "avx512vpopcntdq,+cmov,-avx512vp2intersect,-avx512cd,+movbe,+avxvnniint8," + "-ccmp,-amx-int8,-kl,-avx10.1-256,-sha512,+avxvnni,-rtm,+adx,+avx2,-" + "hreset,+movdiri,+serialize,+vpclmulqdq,-avx512vl,+uintr,-cf,+clflushopt," + "-raoint,+cmpccxadd,+bmi,-amx-tile,+sse,-avx10.2-256,+gfni,-avxvnniint16," + "-amx-fp16,-zu,-ndd,+xsaveopt,+rdrnd,-avx512f,-amx-bf16,-avx512bf16,-" + "avx512vnni,-push2pop2,+cx8,-avx512bw,+sse3,+pku,-nf,+fsgsbase,-clzero,-" + "mwaitx,-lwp,+lzcnt,+sha,+movdir64b,-ppx,+wbnoinvd,+enqcmd,-avx10.2-512,+" + "avxneconvert,-tbm,+pconfig,-amx-complex,+ssse3,+cx16,+bmi2,+fma,+popcnt," + "+avxifma,+f16c,-avx512bitalg,-rdpru,+clwb,+mmx,+sse2,+rdseed,-" + "avx512vbmi2,-prefetchi,+rdpid,-fma4,-avx512vbmi,+shstk,+vaes,+waitpkg,+" + "sgx,+fxsr,-avx512dq,-sse4a"; + + // Testing sapphire-rapids target + ASSERT_TRUE(IsNativeConvertSupportedOnTargetCPU(spr_feature_string)); + + // Testing skylake target + ASSERT_FALSE(IsNativeConvertSupportedOnTargetCPU(skx_feature_string)); + + // Testing sierra-forest target + ASSERT_TRUE(IsNativeConvertSupportedOnTargetCPU(srf_feature_string)); +} + } // namespace } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/onednn_config.proto b/third_party/xla/xla/service/cpu/onednn_config.proto index 9f38673eaacebd..44829a6857f1f9 100644 --- a/third_party/xla/xla/service/cpu/onednn_config.proto +++ b/third_party/xla/xla/service/cpu/onednn_config.proto @@ -113,4 +113,6 @@ message OneDnnConvolutionConfig { OneDnnFusionConfig fusions = 6; uint64 feature_groups = 7; + + OneDnnOptimizationConfig optimization_config = 8; } diff --git a/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc index 19122b393ce23b..01ffb340e07c1b 100644 --- a/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc +++ b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc @@ -111,6 +111,16 @@ inline auto OneDnnMatmulInstr(HloInstruction** instr) { return m::CustomCall(instr, {"__onednn$matmul"}); } +inline auto OneDnnConvolutionInstr(HloInstruction** instr) { + return m::CustomCall(instr, {"__onednn$convolution"}); +} + +inline auto OneDnnFusibleInstr(HloInstruction** instr) { + return m::AnyOf( + m::CustomCall(instr, {"__onednn$matmul"}), + m::CustomCall(instr, {"__onednn$convolution"})); +} + inline auto ConvertBF16ToF32(HloInstruction** instr) { return m::Convert(m::Op(instr).WithElementType(PrimitiveType::BF16)) .WithElementType(PrimitiveType::F32); @@ -275,11 +285,12 @@ auto GELUActivation(HloInstruction* instr, HloInstruction** src) { return OneDnnFusionConfig::UNDEFINED; } -// OneDNN matmul can fuse add operation with automatic broadcasting along the -// addend's dimensions that are 1s. When compatible, Broadcast can be replaced -// by Bitcast, which is much cheaper. Compute new shape for the Bitcast. +// OneDNN matmul / convolution can fuse add operation with automatic +// broadcasting along the addend's dimensions that are 1s. When compatible, +// Broadcast can be replaced by Bitcast, which is much cheaper. Compute new +// shape for the Bitcast. absl::StatusOr AdjustBiasShape(const HloInstruction* broadcast_instr, - const Shape& dot_shape) { + const Shape& instr_shape) { if (broadcast_instr->opcode() != HloOpcode::kBroadcast) { return absl::InvalidArgumentError( "Hlo instruction is not a Broadcast insruction."); @@ -303,9 +314,9 @@ absl::StatusOr AdjustBiasShape(const HloInstruction* broadcast_instr, } } - // If rank(new_shape) > rank(dot), extra dimensions with value = 1 can be + // If rank(new_shape) > rank(instr), extra dimensions with value = 1 can be // deleted from the new_shape. - int64_t rank_difference = new_shape.rank() - dot_shape.rank(); + int64_t rank_difference = new_shape.rank() - instr_shape.rank(); auto new_dims = new_shape.dimensions(); std::vector dims_to_delete; for (int i = 0; i < rank_difference; ++i) { @@ -316,8 +327,8 @@ absl::StatusOr AdjustBiasShape(const HloInstruction* broadcast_instr, new_shape = ShapeUtil::DeleteDimensions(dims_to_delete, new_shape); // New shape for bias should satisfy the condition: - // rank(new_shape) <= rank(dot). - if (new_shape.rank() > dot_shape.rank()) { + // rank(new_shape) <= rank(instr). + if (new_shape.rank() > instr_shape.rank()) { return absl::CancelledError( "Bias shape could not be adjusted for a fusion."); } @@ -325,20 +336,20 @@ absl::StatusOr AdjustBiasShape(const HloInstruction* broadcast_instr, return new_shape; }; -inline bool IsOperandFusible(HloInstruction* operand, HloInstruction* dot) { - // Check if the operand's shape is compatible with matmul for fusion. +inline bool IsOperandFusible(HloInstruction* operand, HloInstruction* instr) { + // Check if the operand's shape is compatible for fusion. // An operand is fusable if - // 1. rank(operand) <= rank(dot) and + // 1. rank(operand) <= rank(instr) and // 2. Starting from the last dim in backward direction, the dimension // size of operand is either 1 or same to dot. auto operand_dims = operand->shape().dimensions(); - auto dot_dims = dot->shape().dimensions(); - if (operand_dims.size() > dot_dims.size()) return false; + auto instr_dims = instr->shape().dimensions(); + if (operand_dims.size() > instr_dims.size()) return false; int operand_idx = operand_dims.size() - 1; - int dot_idx = dot_dims.size() - 1; - for (; operand_idx >= 0; --operand_idx, --dot_idx) { + int instr_idx = instr_dims.size() - 1; + for (; operand_idx >= 0; --operand_idx, --instr_idx) { if (operand_dims[operand_idx] != 1 && - operand_dims[operand_idx] != dot_dims[dot_idx]) + operand_dims[operand_idx] != instr_dims[instr_idx]) return false; } return true; @@ -367,6 +378,7 @@ inline auto OptionalConvertAndBitcast(HloInstruction** optional_convert, bool OneDnnContractionRewriter::ShouldRewriteDot( const HloInstruction* dot_instr, bool before_layout_assignment) { + if (dot_instr->opcode() != HloOpcode::kDot) return false; // Currently, blocking control dependencies if (dot_instr->HasControlDependencies()) return false; if (!IsSupportedType(dot_instr->shape().element_type())) return false; @@ -429,6 +441,7 @@ bool OneDnnContractionRewriter::ShouldRewriteDot( bool OneDnnContractionRewriter::ShouldRewriteConv( const HloInstruction* conv_instr) { + if (conv_instr->opcode() != HloOpcode::kConvolution) return false; if (conv_instr->HasControlDependencies()) return false; if (!IsSupportedType(conv_instr->shape().element_type())) return false; if (conv_instr->batch_group_count() != 1) return false; @@ -566,14 +579,14 @@ class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor { } absl::Status HandleAdd(HloInstruction* instr) override { - // Try to do a fusion for Dot(onednn-matmul) + Add. However, + // Try to fuse Add to the instr. However, // HLO Add instruction might receive the addends after additional // processing like Broadcast, Bitcast, Convert, etc. is applied to the raw // addends. Here, the following possible pattern is matched. // // clang-format off // - // Dot addend + // Dot / Conv addend // | | // v v // optional instructions optional instructions @@ -586,148 +599,154 @@ class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor { // // clang-format on - HloInstruction *addend_intermediate, *dot; - HloInstruction* optional_dot_bitcast = nullptr; - HloInstruction* optional_dot_convert = nullptr; + HloInstruction *addend_intermediate, *contraction; + HloInstruction* optional_contraction_bitcast = nullptr; + HloInstruction* optional_contraction_convert = nullptr; auto pattern = m::AddAnyOrder( &instr, - OptionalConvertAndBitcast(&optional_dot_convert, &optional_dot_bitcast, - OneDnnMatmulInstr(&dot)) + OptionalConvertAndBitcast(&optional_contraction_convert, + &optional_contraction_bitcast, + OneDnnFusibleInstr(&contraction)) .WithOneUser(), m::Op(&addend_intermediate)); if (Match(instr, pattern)) { - if (!IsSupportedType(dot->shape().element_type())) - return absl::OkStatus(); - // TODO(intel-tf): Remove the condition below when the fusion Dot + - // Add(bias) + Add(e.g., residual) is enabled. - if (!dot->backend_config() - ->mutable_onednn_matmul_config() - ->mutable_fusions() - ->ops() - .empty() && - dot->backend_config() - ->mutable_onednn_matmul_config() - ->mutable_fusions() - ->ops(0) == OneDnnFusionConfig::BIAS) { - return absl::OkStatus(); - } - std::vector new_operands; - for (auto operand : dot->operands()) { - new_operands.push_back(operand); - } + HANDLE_OP_INTERNAL(HandleAddInternal, contraction, instr, + addend_intermediate, optional_contraction_convert, + optional_contraction_bitcast); + } - // At this point, the addend could have one of the following - // possiblities that the current fusion can handle: - // - // - addend -> Convert -> Broadcast -> Add - // - addend -> Broadcast -> Convert -> Add - // - addend -> Convert - // - addend -> Broadcast - // - addend - // - // Hunt for addend through possible sequences above and check the addend - // is compatible to onednn-matmul fusion. - HloInstruction* addend = nullptr; - HloInstruction* optional_addend_broadcast = nullptr; - auto addend_pattern = m::AnyOf( - m::Broadcast(&optional_addend_broadcast, - m::Convert(&addend, m::Op())), - m::Convert(m::Broadcast(&optional_addend_broadcast, m::Op(&addend))), - m::Convert(&addend, m::Op()), - m::Broadcast(&optional_addend_broadcast, m::Op(&addend)), - m::Op(&addend)); - if (!Match(addend_intermediate, addend_pattern)) return absl::OkStatus(); - - if (optional_addend_broadcast && addend->shape().rank() != 1) { - auto new_shape = - AdjustBiasShape(optional_addend_broadcast, dot->shape()); - if (new_shape.ok()) { - addend = addend->AddInstruction( - HloInstruction::CreateBitcast(new_shape.value(), addend)); - } else { - VLOG(2) << new_shape.status(); - return absl::OkStatus(); - } - } + return absl::OkStatus(); + } - // Validate addend for fusion. - if (IsSupportedType(addend->shape().element_type()) && - IsOperandFusible(addend, dot)) { - new_operands.push_back(addend); + template + absl::Status HandleAddInternal(HloInstruction* contraction, + HloInstruction* instr, + HloInstruction* addend_intermediate, + HloInstruction* optional_contraction_convert, + HloInstruction* optional_contraction_bitcast) { + if (!IsSupportedType(contraction->shape().element_type())) + return absl::OkStatus(); + // TODO(intel-tf): Remove the condition below when the fusion Contraction + + // Add(bias) + Add(e.g., residual) is enabled. + auto contraction_config = contraction->backend_config(); + if (!GetKernelConfig(&contraction_config) + ->mutable_fusions() + ->ops() + .empty() && + GetKernelConfig(&contraction_config) + ->mutable_fusions() + ->ops(0) == OneDnnFusionConfig::BIAS) { + return absl::OkStatus(); + } + std::vector new_operands; + for (auto operand : contraction->operands()) { + new_operands.push_back(operand); + } + + // At this point, the addend could have one of the following + // possiblities that the current fusion can handle: + // + // - addend -> Convert -> Broadcast -> Add + // - addend -> Broadcast -> Convert -> Add + // - addend -> Convert + // - addend -> Broadcast + // - addend + // + // Hunt for addend through possible sequences above and check the addend + // is compatible for onednn fusion. + HloInstruction* addend = nullptr; + HloInstruction* optional_addend_broadcast = nullptr; + auto addend_pattern = m::AnyOf( + m::Broadcast(&optional_addend_broadcast, m::Convert(&addend, m::Op())), + m::Convert(m::Broadcast(&optional_addend_broadcast, m::Op(&addend))), + m::Convert(&addend, m::Op()), + m::Broadcast(&optional_addend_broadcast, m::Op(&addend)), + m::Op(&addend)); + if (!Match(addend_intermediate, addend_pattern)) return absl::OkStatus(); + + if (optional_addend_broadcast && addend->shape().rank() != 1) { + auto new_shape = + AdjustBiasShape(optional_addend_broadcast, contraction->shape()); + if (new_shape.ok()) { + addend = addend->AddInstruction( + HloInstruction::CreateBitcast(new_shape.value(), addend)); } else { + VLOG(2) << new_shape.status(); return absl::OkStatus(); } + } - // TODO(intel-tf): Remove this restriction once oneDNN has an optimized - // implementation for broadcasted add across all dimensions. - OneDnnFusionConfig_FusionKind kind = OneDnnFusionConfig::UNDEFINED; - kind = (addend->shape().rank() == 1) - ? (dot->backend_config() - ->mutable_onednn_matmul_config() - ->fusions() - .ops() - .empty() - ? OneDnnFusionConfig::BIAS - : OneDnnFusionConfig::UNDEFINED) - : OneDnnFusionConfig::BINARY_ADD; - if (kind == OneDnnFusionConfig::UNDEFINED) return absl::OkStatus(); + // Validate addend for fusion. + if (IsSupportedType(addend->shape().element_type()) && + IsOperandFusible(addend, contraction)) { + new_operands.push_back(addend); + } else { + return absl::OkStatus(); + } - auto matmul_call = Cast(instr->AddInstruction( - dot->CloneWithNewOperands(dot->shape(), new_operands))); + auto custom_call = Cast(instr->AddInstruction( + contraction->CloneWithNewOperands(contraction->shape(), new_operands))); - auto backend_config = matmul_call->backend_config(); - backend_config->mutable_onednn_matmul_config() - ->mutable_fusions() - ->add_ops(kind); + auto backend_config = custom_call->backend_config(); - if (optional_addend_broadcast) { - backend_config->mutable_onednn_matmul_config() - ->mutable_optimization_config() - ->set_bias_broadcast(true); - } - TF_RETURN_IF_ERROR(matmul_call->set_backend_config(*backend_config)); + // TODO(intel-tf): Remove this restriction once oneDNN has an optimized + // implementation for broadcasted add across all dimensions. + OneDnnFusionConfig_FusionKind kind = OneDnnFusionConfig::UNDEFINED; + kind = + (addend->shape().rank() == 1) + ? (GetKernelConfig(&backend_config)->fusions().ops().empty() + ? OneDnnFusionConfig::BIAS + : OneDnnFusionConfig::UNDEFINED) + : OneDnnFusionConfig::BINARY_ADD; + if (kind == OneDnnFusionConfig::UNDEFINED) return absl::OkStatus(); - HloInstruction* new_instr; - // If matched pattern has custom-call -> bitcast -> add, then we need to - // insert bitcast after the new fusion to maintain the correct shape - // (new-custom-call -> bitcast). Also, this will optionally be followed - // by -> convert for bf16 case to avoid datatype mismatch. - if (optional_dot_bitcast != nullptr && - optional_dot_bitcast->opcode() == HloOpcode::kBitcast) { - if (optional_dot_convert != nullptr && - optional_dot_convert->opcode() == HloOpcode::kConvert) { - auto bitcast_call = - matmul_call->AddInstruction(HloInstruction::CreateBitcast( - ShapeUtil::ChangeElementType( - instr->shape(), matmul_call->shape().element_type()), - matmul_call)); - new_instr = - bitcast_call->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType( - bitcast_call->shape(), - optional_dot_convert->shape().element_type()), - bitcast_call)); - } else { - new_instr = matmul_call->AddInstruction( - HloInstruction::CreateBitcast(instr->shape(), matmul_call)); - } + GetKernelConfig(&backend_config)->mutable_fusions()->add_ops(kind); + + if (optional_addend_broadcast) { + GetKernelConfig(&backend_config) + ->mutable_optimization_config() + ->set_bias_broadcast(true); + } + TF_RETURN_IF_ERROR(custom_call->set_backend_config(*backend_config)); + + HloInstruction* new_instr; + // If matched pattern has custom-call -> bitcast -> add, then we need to + // insert bitcast after the new fusion to maintain the correct shape + // (new-custom-call -> bitcast). Also, this will optionally be followed + // by -> convert for bf16 case to avoid datatype mismatch. + if (optional_contraction_bitcast != nullptr && + optional_contraction_bitcast->opcode() == HloOpcode::kBitcast) { + if (optional_contraction_convert != nullptr && + optional_contraction_convert->opcode() == HloOpcode::kConvert) { + auto bitcast_call = + custom_call->AddInstruction(HloInstruction::CreateBitcast( + ShapeUtil::ChangeElementType( + instr->shape(), custom_call->shape().element_type()), + custom_call)); + new_instr = bitcast_call->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType( + bitcast_call->shape(), + optional_contraction_convert->shape().element_type()), + bitcast_call)); } else { - if (optional_dot_convert != nullptr && - optional_dot_convert->opcode() == HloOpcode::kConvert) { - new_instr = matmul_call->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType( - matmul_call->shape(), - optional_dot_convert->shape().element_type()), - matmul_call)); - } else { - new_instr = matmul_call; - } + new_instr = custom_call->AddInstruction( + HloInstruction::CreateBitcast(instr->shape(), custom_call)); + } + } else { + if (optional_contraction_convert != nullptr && + optional_contraction_convert->opcode() == HloOpcode::kConvert) { + new_instr = custom_call->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType( + custom_call->shape(), + optional_contraction_convert->shape().element_type()), + custom_call)); + } else { + new_instr = custom_call; } - TF_RETURN_IF_ERROR(ReplaceInstruction(instr, new_instr)); } - + TF_RETURN_IF_ERROR(ReplaceInstruction(instr, new_instr)); return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.h b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.h index 503d8a8ee25630..2706d05d1ef920 100644 --- a/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.h +++ b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.h @@ -50,12 +50,30 @@ class OneDnnContractionRewriter : public HloModulePass { static bool ShouldRewriteDot(const HloInstruction* dot_instr, bool before_layout_assignment = false); static bool ShouldRewriteConv(const HloInstruction* conv_instr); + static bool ShouldRewriteInstr(const HloInstruction* instr, + bool before_layout_assignment = false) { + return ShouldRewriteDot(instr, before_layout_assignment) || + ShouldRewriteConv(instr); + } private: int intra_op_parallelism_; const tsl::thread::ThreadPool* compile_threadpool_; }; +#define HANDLE_OP_INTERNAL(internal_callee, contraction, ...) \ + switch (contraction->backend_config() \ + ->backend_config_oneof_case()) { \ + case BackendConfig::BackendConfigOneofCase::kOnednnMatmulConfig: \ + return internal_callee< \ + BackendConfig::BackendConfigOneofCase::kOnednnMatmulConfig>( \ + contraction, __VA_ARGS__); \ + default: \ + return internal_callee< \ + BackendConfig::BackendConfigOneofCase::kOnednnConvConfig>( \ + contraction, __VA_ARGS__); \ + } + } // namespace cpu } // namespace xla diff --git a/third_party/xla/xla/service/cpu/onednn_convolution.cc b/third_party/xla/xla/service/cpu/onednn_convolution.cc index 7ed7987137ad60..30e91fb4aae3e7 100644 --- a/third_party/xla/xla/service/cpu/onednn_convolution.cc +++ b/third_party/xla/xla/service/cpu/onednn_convolution.cc @@ -63,6 +63,13 @@ dnnl::memory::format_tag GetFormatTag(const int dims) { : dnnl::memory::format_tag::any; } +template <> +typename PrimitiveTrait::pointer_type +GetKernelConfig( + absl::StatusOr* backend_config) { + return (*backend_config)->mutable_onednn_conv_config(); +} + ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnConvolution( void* result, void** args) { // args[0]: ptr to nargs @@ -154,7 +161,6 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnConvolution( MemrefInfo ker_minfo(args[arg_indx++]); MemrefInfo res_minfo(result); - // Permute memory descriptors auto inp_md = inp_minfo.GetOneDnnMemDesc(); auto ker_md = ker_minfo.GetOneDnnMemDesc(); auto res_md = res_minfo.GetOneDnnMemDesc(); @@ -174,6 +180,50 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnConvolution( new_ker_md = new_ker_md.reshape(corr_dims); } + const int64_t num_fused_operands = num_args - arg_indx; + std::vector fused_mds; + std::vector fused_bufs; + for (int64_t i = 0; i < num_fused_operands; ++i) { + MemrefInfo operand_minfo(args[arg_indx++]); + fused_mds.push_back(operand_minfo.GetOneDnnMemDesc()); + fused_bufs.push_back(operand_minfo.Data()); + } + + std::vector> postop_args; + + auto bias_md = memory::desc(); + + dnnl::post_ops post_ops; + int fused_operand_idx = 0; + for (auto& fused_op : conv_config.fusions().ops()) { + switch (fused_op) { + case OneDnnFusionConfig::BIAS: { + bias_md = fused_mds.at(fused_operand_idx); + postop_args.emplace_back( + DNNL_ARG_BIAS, + dnnl::memory(bias_md, cpu_engine, fused_bufs[fused_operand_idx])); + fused_operand_idx++; + } break; + case OneDnnFusionConfig::BINARY_ADD: { + auto binary_md = fused_mds.at(fused_operand_idx); + binary_md = binary_md.permute_axes(out_axes); + auto arg_idx = + DNNL_ARG_ATTR_MULTIPLE_POST_OP(post_ops.len()) | DNNL_ARG_SRC_1; + postop_args.emplace_back( + arg_idx, + dnnl::memory(binary_md, cpu_engine, fused_bufs[fused_operand_idx])); + post_ops.append_binary(dnnl::algorithm::binary_add, binary_md); + fused_operand_idx++; + } break; + default: + LOG(FATAL) + << __FILE__ << ":" << __LINE__ + << " Attempt to call OneDNN Convolution runtime library with " + "unsupported post op." + << std::endl; + } + } + auto any_ker_md = memory::desc(new_ker_md.get_dims(), new_ker_md.get_data_type(), dnnl::memory::format_tag::any); @@ -187,37 +237,41 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_OneDnnConvolution( XLA_LIGHTWEIGHT_CHECK(num_args == arg_indx); dnnl::primitive_attr attrs; + if (post_ops.len() > 0) { + attrs.set_post_ops(post_ops); + } + + auto conv_pd = std::make_unique( + cpu_engine, prop_kind::forward_inference, algorithm::convolution_direct, + any_inp_md, any_ker_md, bias_md, any_res_md, strides, rhs_dilations, + pad_left, pad_right, attrs); auto inp_mem = memory(new_inp_md, cpu_engine, inp_minfo.Data()); auto ker_mem = memory(new_ker_md, cpu_engine, ker_minfo.Data()); auto res_mem = memory(new_res_md, cpu_engine, res_minfo.Data()); - auto conv_pd = convolution_forward::primitive_desc( - cpu_engine, prop_kind::forward_inference, algorithm::convolution_direct, - any_inp_md, any_ker_md, any_res_md, strides, rhs_dilations, pad_left, - pad_right, attrs); - - auto new_inp_mem = (conv_pd.src_desc() == inp_mem.get_desc()) + auto new_inp_mem = (conv_pd->src_desc() == inp_mem.get_desc()) ? inp_mem - : ReorderMemory(cpu_engine, conv_pd.src_desc(), + : ReorderMemory(cpu_engine, conv_pd->src_desc(), inp_mem, onednn_stream); - auto new_ker_mem = (conv_pd.weights_desc() == ker_mem.get_desc()) + auto new_ker_mem = (conv_pd->weights_desc() == ker_mem.get_desc()) ? ker_mem - : ReorderMemory(cpu_engine, conv_pd.weights_desc(), + : ReorderMemory(cpu_engine, conv_pd->weights_desc(), ker_mem, onednn_stream); - auto new_res_mem = (conv_pd.dst_desc() == res_mem.get_desc()) + auto new_res_mem = (conv_pd->dst_desc() == res_mem.get_desc()) ? res_mem - : memory(conv_pd.dst_desc(), cpu_engine); + : memory(conv_pd->dst_desc(), cpu_engine); - auto conv_prim = convolution_forward(conv_pd); + auto conv_prim = convolution_forward(*conv_pd); std::unordered_map conv_args{{DNNL_ARG_SRC, new_inp_mem}, {DNNL_ARG_WEIGHTS, new_ker_mem}, {DNNL_ARG_DST, new_res_mem}}; + conv_args.insert(postop_args.begin(), postop_args.end()); conv_prim.execute(onednn_stream, conv_args); - if (conv_pd.dst_desc() == res_mem.get_desc()) { + if (conv_pd->dst_desc() == res_mem.get_desc()) { res_mem = new_res_mem; } else { dnnl::reorder(new_res_mem, res_mem) diff --git a/third_party/xla/xla/service/cpu/onednn_convolution.h b/third_party/xla/xla/service/cpu/onednn_convolution.h index 19cbbe2e2a371a..657cddffb21afd 100644 --- a/third_party/xla/xla/service/cpu/onednn_convolution.h +++ b/third_party/xla/xla/service/cpu/onednn_convolution.h @@ -17,13 +17,22 @@ limitations under the License. #define XLA_SERVICE_CPU_ONEDNN_CONVOLUTION_H_ #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) +#include "xla/service/cpu/onednn_util.h" + namespace xla { namespace cpu { +constexpr auto kOnednnConvConfig = BackendConfigOneofCase::kOnednnConvConfig; + extern "C" { extern void __xla_cpu_runtime_OneDnnConvolution(void* result, void** args); } // extern "C" +template <> +struct PrimitiveTrait { + using pointer_type = xla::cpu::OneDnnConvolutionConfig*; +}; + } // namespace cpu } // namespace xla diff --git a/third_party/xla/xla/service/cpu/onednn_matmul.cc b/third_party/xla/xla/service/cpu/onednn_matmul.cc index 1b2dbee81c661b..77f25f1b17ec5a 100644 --- a/third_party/xla/xla/service/cpu/onednn_matmul.cc +++ b/third_party/xla/xla/service/cpu/onednn_matmul.cc @@ -104,11 +104,6 @@ Shape OneDnnMatMulOptWeightsShape(const Shape& input_shape, return MemDescToXlaShapeFlattened(optimized_weights_md); } -struct FusedOperandsRef { - const std::vector& bufs; - std::vector>& postop_args; -}; - std::unique_ptr CreateMatMulPrimDesc( const engine& cpu_engine, const memory::desc& input_md, const memory::desc& plain_weights_md, const memory::desc& output_md, @@ -123,84 +118,9 @@ std::unique_ptr CreateMatMulPrimDesc( memory::format_tag::any); } - dnnl::post_ops post_ops; - int fused_operand_idx = 0; - for (auto& fused_op : matmul_config.fusions().ops()) { - switch (fused_op) { - case OneDnnFusionConfig::RELU: - post_ops.append_eltwise(dnnl::algorithm::eltwise_relu, 0.f, 0.f); - break; - case OneDnnFusionConfig::TANH: - post_ops.append_eltwise(dnnl::algorithm::eltwise_tanh, 0.f, 0.f); - break; - case OneDnnFusionConfig::GELU_TANH: - post_ops.append_eltwise(dnnl::algorithm::eltwise_gelu_tanh, 0.f, 0.f); - break; - case OneDnnFusionConfig::GELU_ERF: - post_ops.append_eltwise(dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f); - break; - case OneDnnFusionConfig::RELU6: - post_ops.append_eltwise(dnnl::algorithm::eltwise_clip_v2, 0.f, 6.0f); - break; - case OneDnnFusionConfig::SIGMOID: - post_ops.append_eltwise(dnnl::algorithm::eltwise_logistic, 0.f, 0.f); - break; - case OneDnnFusionConfig::BIAS: { - bias_md = fused_mds.at(fused_operand_idx); - // Extend bias rank to match result rank. - auto missed_rank = output_md.get_ndims() - bias_md.get_ndims(); - XLA_LIGHTWEIGHT_CHECK(missed_rank >= 0); - if (missed_rank > 0) { - auto bias_dims = bias_md.get_dims(); - bias_dims.insert(bias_dims.begin(), missed_rank, 1); - bias_md = bias_md.reshape(bias_dims); - } - if (fused_operands_ref) { - fused_operands_ref->postop_args.emplace_back( - DNNL_ARG_BIAS, - dnnl::memory(bias_md, cpu_engine, - fused_operands_ref->bufs[fused_operand_idx])); - } - fused_operand_idx++; - } break; - case OneDnnFusionConfig::ELU: - post_ops.append_eltwise(dnnl::algorithm::eltwise_elu, 1.0f, 0.0f); - break; - case OneDnnFusionConfig::BINARY_ADD: { - auto binary_md = fused_mds.at(fused_operand_idx); - // Extend addend rank to match result rank. - auto missed_rank = output_md.get_ndims() - binary_md.get_ndims(); - XLA_LIGHTWEIGHT_CHECK(missed_rank >= 0); - if (missed_rank > 0) { - auto binary_dims = binary_md.get_dims(); - binary_dims.insert(binary_dims.begin(), missed_rank, 1); - binary_md = binary_md.reshape(binary_dims); - } - if (fused_operands_ref) { - auto arg_idx = - DNNL_ARG_ATTR_MULTIPLE_POST_OP(post_ops.len()) | DNNL_ARG_SRC_1; - fused_operands_ref->postop_args.emplace_back( - arg_idx, - dnnl::memory(binary_md, cpu_engine, - fused_operands_ref->bufs[fused_operand_idx])); - } - post_ops.append_binary(dnnl::algorithm::binary_add, binary_md); - fused_operand_idx++; - } break; - case OneDnnFusionConfig::LINEAR: { - float const_float; - *(reinterpret_cast(&const_float)) = - matmul_config.fusions().alpha_typecast(); - post_ops.append_eltwise(dnnl::algorithm::eltwise_linear, const_float, - 0.f); - } break; - default: - LOG(FATAL) << __FILE__ << ":" << __LINE__ - << " Attempt to call OneDNN MatMul runtime library with " - "unsupported post op." - << std::endl; - } - } + dnnl::post_ops post_ops = PopulateOneDnnPostOps( + cpu_engine, fused_mds, &matmul_config.fusions(), output_md.get_ndims(), + fused_operands_ref, &bias_md); dnnl::primitive_attr attrs; if (matmul_config.optimization_config().user_scratchpad()) { @@ -230,6 +150,13 @@ std::unique_ptr CreateMatMulPrimDesc( weights_md, output_md, fused_mds, matmul_config); } +template <> +typename PrimitiveTrait::pointer_type +GetKernelConfig( + absl::StatusOr* backend_config) { + return (*backend_config)->mutable_onednn_matmul_config(); +} + template <> std::unique_ptr CreateOneDnnPrimDesc(HloInstruction* instr) { diff --git a/third_party/xla/xla/service/cpu/onednn_matmul.h b/third_party/xla/xla/service/cpu/onednn_matmul.h index 09a2d6752ec29b..bf452e9d9f0518 100644 --- a/third_party/xla/xla/service/cpu/onednn_matmul.h +++ b/third_party/xla/xla/service/cpu/onednn_matmul.h @@ -19,11 +19,15 @@ limitations under the License. #include "dnnl.hpp" #include "xla/service/cpu/backend_config.pb.h" +#include "xla/service/cpu/onednn_util.h" #include "xla/shape.h" namespace xla { namespace cpu { +constexpr auto kOnednnMatmulConfig = + BackendConfigOneofCase::kOnednnMatmulConfig; + Shape OneDnnMatMulOptWeightsShape(const Shape& input_shape, const Shape& weights_shape, const Shape& bias_shape, @@ -36,6 +40,11 @@ extern void __xla_cpu_runtime_OneDnnMatMul(void* result, void* scratch, extern void __xla_cpu_runtime_OneDnnMatMulReorder(void* result, void** args); } // extern "C" +template <> +struct PrimitiveTrait { + using pointer_type = xla::cpu::OneDnnMatMulConfig*; +}; + } // namespace cpu } // namespace xla diff --git a/third_party/xla/xla/service/cpu/onednn_memory_util.h b/third_party/xla/xla/service/cpu/onednn_memory_util.h index c0c956a32dc0b1..2fef54861722f1 100644 --- a/third_party/xla/xla/service/cpu/onednn_memory_util.h +++ b/third_party/xla/xla/service/cpu/onednn_memory_util.h @@ -71,7 +71,7 @@ inline dnnl::memory::data_type ToOneDnnDataType(PrimitiveType ptype) { // TODO(intel-tf): properly handle not supported types: // S16, S64, U16, U32, U64, C64, C128, F8E5M2, F8E4M3FN, S4, U4, - // F8E4M3B11FNUZ + // F8E4M3B11FNUZ, F8E4M3, F8E3M4 default: return dt::undef; } diff --git a/third_party/xla/xla/service/cpu/onednn_util.cc b/third_party/xla/xla/service/cpu/onednn_util.cc index 16d1ec07b9016a..17d09230ef63a9 100644 --- a/third_party/xla/xla/service/cpu/onednn_util.cc +++ b/third_party/xla/xla/service/cpu/onednn_util.cc @@ -40,6 +40,92 @@ dnnl::stream MakeOneDnnStream( : dnnl::stream(cpu_engine); } +dnnl::post_ops PopulateOneDnnPostOps( + const dnnl::engine& cpu_engine, + const std::vector& fused_mds, + const OneDnnFusionConfig* fusion_config, const int output_ndims, + FusedOperandsRef* fused_operands_ref, dnnl::memory::desc* bias_md) { + dnnl::post_ops post_ops; + int fused_operand_idx = 0; + for (auto& fused_op : fusion_config->ops()) { + switch (fused_op) { + case OneDnnFusionConfig::RELU: + post_ops.append_eltwise(dnnl::algorithm::eltwise_relu, 0.f, 0.f); + break; + case OneDnnFusionConfig::TANH: + post_ops.append_eltwise(dnnl::algorithm::eltwise_tanh, 0.f, 0.f); + break; + case OneDnnFusionConfig::GELU_TANH: + post_ops.append_eltwise(dnnl::algorithm::eltwise_gelu_tanh, 0.f, 0.f); + break; + case OneDnnFusionConfig::GELU_ERF: + post_ops.append_eltwise(dnnl::algorithm::eltwise_gelu_erf, 0.f, 0.f); + break; + case OneDnnFusionConfig::RELU6: + post_ops.append_eltwise(dnnl::algorithm::eltwise_clip_v2, 0.f, 6.0f); + break; + case OneDnnFusionConfig::SIGMOID: + post_ops.append_eltwise(dnnl::algorithm::eltwise_logistic, 0.f, 0.f); + break; + case OneDnnFusionConfig::BIAS: { + *bias_md = fused_mds.at(fused_operand_idx); + // TODO(intel-tf): Move this check to the rewriter file + // Extend bias rank to match result rank. + auto missed_rank = output_ndims - bias_md->get_ndims(); + if (missed_rank > 0) { + auto bias_dims = bias_md->get_dims(); + bias_dims.insert(bias_dims.begin(), missed_rank, 1); + *bias_md = bias_md->reshape(bias_dims); + } + if (fused_operands_ref) { + fused_operands_ref->postop_args.emplace_back( + DNNL_ARG_BIAS, + dnnl::memory(*bias_md, cpu_engine, + fused_operands_ref->bufs[fused_operand_idx])); + } + fused_operand_idx++; + } break; + case OneDnnFusionConfig::ELU: + post_ops.append_eltwise(dnnl::algorithm::eltwise_elu, 1.0f, 0.0f); + break; + case OneDnnFusionConfig::BINARY_ADD: { + auto binary_md = fused_mds.at(fused_operand_idx); + // TODO(intel-tf): Move this check to the rewriter file + // Extend addend rank to match result rank. + auto missed_rank = output_ndims - binary_md.get_ndims(); + if (missed_rank > 0) { + auto binary_dims = binary_md.get_dims(); + binary_dims.insert(binary_dims.begin(), missed_rank, 1); + binary_md = binary_md.reshape(binary_dims); + } + if (fused_operands_ref) { + auto arg_idx = + DNNL_ARG_ATTR_MULTIPLE_POST_OP(post_ops.len()) | DNNL_ARG_SRC_1; + fused_operands_ref->postop_args.emplace_back( + arg_idx, + dnnl::memory(binary_md, cpu_engine, + fused_operands_ref->bufs[fused_operand_idx])); + } + post_ops.append_binary(dnnl::algorithm::binary_add, binary_md); + fused_operand_idx++; + } break; + case OneDnnFusionConfig::LINEAR: { + float const_float; + *(reinterpret_cast(&const_float)) = + fusion_config->alpha_typecast(); + post_ops.append_eltwise(dnnl::algorithm::eltwise_linear, const_float, + 0.f); + } break; + default: + LOG(FATAL) << __FILE__ << ":" << __LINE__ + << " Attempt to call OneDNN runtime library with " + "unsupported post op." + << std::endl; + } + } + return post_ops; +} + } // namespace cpu } // namespace xla diff --git a/third_party/xla/xla/service/cpu/onednn_util.h b/third_party/xla/xla/service/cpu/onednn_util.h index aaba304fc083fa..02537883f60b22 100644 --- a/third_party/xla/xla/service/cpu/onednn_util.h +++ b/third_party/xla/xla/service/cpu/onednn_util.h @@ -23,6 +23,8 @@ limitations under the License. #include "unsupported/Eigen/CXX11/Tensor" #include "dnnl.hpp" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/cpu/backend_config.pb.h" +#include "xla/service/cpu/onednn_config.pb.h" #include "xla/tsl/util/onednn_threadpool.h" #include "xla/xla_data.pb.h" #include "tsl/platform/cpu_info.h" @@ -51,6 +53,11 @@ inline bool IsSupportedType(xla::PrimitiveType dtype) { return false; } +struct FusedOperandsRef { + const std::vector& bufs; + std::vector>& postop_args; +}; + std::unique_ptr CreateOneDnnThreadPool( const Eigen::ThreadPoolDevice* threadpool_device); @@ -58,11 +65,27 @@ dnnl::stream MakeOneDnnStream( const dnnl::engine& cpu_engine, dnnl::threadpool_interop::threadpool_iface* thread_pool); -// This template function must have explicit specialization at the definition +typedef BackendConfig::BackendConfigOneofCase BackendConfigOneofCase; + +// These template functions must have explicit specialization at the definition // site. template std::unique_ptr CreateOneDnnPrimDesc(HloInstruction*); +template +struct PrimitiveTrait; + +template +typename PrimitiveTrait::pointer_type GetKernelConfig( + absl::StatusOr*); + +dnnl::post_ops PopulateOneDnnPostOps( + const dnnl::engine& cpu_engine, + const std::vector& fused_mds, + const OneDnnFusionConfig* fusion_config, const int output_ndims, + FusedOperandsRef* fused_operands_ref = nullptr, + dnnl::memory::desc* bias_md = nullptr); + } // namespace cpu } // namespace xla diff --git a/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc index 6b07b41aad00cc..874d9b3fe1b508 100644 --- a/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc +++ b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc @@ -114,7 +114,7 @@ static absl::Status BuildAndCallFfi( } // For FFI handlers backend config must be a compatible MLIR dictionary. - ffi::CallFrameBuilder::FlatAttributesMap attributes; + ffi::CallFrameBuilder::AttributesMap attributes; if (!backend_config.empty() && backend_config != "{}") { // Backend config not empty, so proceed to parse it into an MLIR attribute // and build an MLIR compatible map of attributes out of it. diff --git a/third_party/xla/xla/service/cpu/runtime_single_threaded_matmul.h b/third_party/xla/xla/service/cpu/runtime_single_threaded_matmul.h index 7f99e89ed54523..f23291b5510671 100644 --- a/third_party/xla/xla/service/cpu/runtime_single_threaded_matmul.h +++ b/third_party/xla/xla/service/cpu/runtime_single_threaded_matmul.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "Eigen/Core" +#include "tsl/platform/ml_dtypes.h" extern "C" { @@ -65,6 +66,18 @@ extern void __xla_cpu_runtime_EigenSingleThreadedMatMulU8( uint8_t* lhs, uint8_t* rhs, int64_t m, int64_t n, int64_t k, int32_t transpose_lhs, int32_t transpose_rhs); +extern void __xla_cpu_runtime_EigenSingleThreadedMatMulF8E5M2( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, + tsl::float8_e5m2* out, tsl::float8_e5m2* lhs, tsl::float8_e5m2* rhs, + int64_t m, int64_t n, int64_t k, int32_t transpose_lhs, + int32_t transpose_rhs); + +extern void __xla_cpu_runtime_EigenSingleThreadedMatMulF8E4M3FN( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, + tsl::float8_e4m3fn* out, tsl::float8_e4m3fn* lhs, tsl::float8_e4m3fn* rhs, + int64_t m, int64_t n, int64_t k, int32_t transpose_lhs, + int32_t transpose_rhs); + } // extern "C" #endif // XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_H_ diff --git a/third_party/xla/xla/service/cpu/runtime_single_threaded_matmul_f8.cc b/third_party/xla/xla/service/cpu/runtime_single_threaded_matmul_f8.cc new file mode 100644 index 00000000000000..d29015456a5f3e --- /dev/null +++ b/third_party/xla/xla/service/cpu/runtime_single_threaded_matmul_f8.cc @@ -0,0 +1,39 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/base/attributes.h" +#include "xla/service/cpu/runtime_single_threaded_matmul.h" +#include "xla/service/cpu/runtime_single_threaded_matmul_common.h" +#include "tsl/platform/ml_dtypes.h" + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulF8E5M2( + const void* run_options_ptr, tsl::float8_e5m2* out, tsl::float8_e5m2* lhs, + tsl::float8_e5m2* rhs, int64_t m, int64_t n, int64_t k, + int32_t transpose_lhs, int32_t transpose_rhs) { + xla::SingleThreadedMatMulDispatch( + run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); +} + +ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedMatMulF8E4M3FN( + const void* run_options_ptr, tsl::float8_e4m3fn* out, + tsl::float8_e4m3fn* lhs, tsl::float8_e4m3fn* rhs, int64_t m, int64_t n, + int64_t k, int32_t transpose_lhs, int32_t transpose_rhs) { + xla::SingleThreadedMatMulDispatch( + run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs, transpose_rhs); +} diff --git a/third_party/xla/xla/service/cpu/sample_harness.cc b/third_party/xla/xla/service/cpu/sample_harness.cc index f43ff05bcb237a..118d72578314ee 100644 --- a/third_party/xla/xla/service/cpu/sample_harness.cc +++ b/third_party/xla/xla/service/cpu/sample_harness.cc @@ -24,8 +24,8 @@ limitations under the License. #include "xla/client/client_library.h" #include "xla/client/global_data.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/literal.h" #include "xla/types.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/cpu/scoped_ir_builder_test.cc b/third_party/xla/xla/service/cpu/scoped_ir_builder_test.cc index c907193626b0a8..1d8094b4899945 100644 --- a/third_party/xla/xla/service/cpu/scoped_ir_builder_test.cc +++ b/third_party/xla/xla/service/cpu/scoped_ir_builder_test.cc @@ -20,10 +20,10 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "xla/hlo/analysis/hlo_ordering.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/ir_emitter.h" #include "xla/service/cpu/target_machine_features_fake.h" -#include "xla/service/hlo_ordering.h" #include "xla/service/logical_buffer.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/cpu/simple_orc_jit.cc b/third_party/xla/xla/service/cpu/simple_orc_jit.cc index 44b9f109bf0f6d..e789538e128309 100644 --- a/third_party/xla/xla/service/cpu/simple_orc_jit.cc +++ b/third_party/xla/xla/service/cpu/simple_orc_jit.cc @@ -24,11 +24,13 @@ limitations under the License. #include #include #include +#include #include #include // NOLINT #include #include +#include "absl/container/flat_hash_map.h" #include "absl/functional/any_invocable.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -81,6 +83,7 @@ limitations under the License. #include "xla/service/custom_call_target_registry.h" #include "xla/service/llvm_compiler.h" #include "xla/util.h" +#include "tsl/platform/cpu_info.h" #include "tsl/platform/logging.h" #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) @@ -97,15 +100,6 @@ extern "C" uint16_t __truncsfbf2(float); extern "C" uint16_t __truncdfbf2(double); namespace xla::cpu { - -std::vector DetectMachineAttributes() { - std::vector result; - for (const auto& [feature, enabled] : llvm::sys::getHostCPUFeatures()) { - result.push_back((enabled ? '+' : '-') + std::string(feature)); - } - return result; -} - namespace { class DefaultMemoryMapper final @@ -302,30 +296,166 @@ bool ContiguousSectionMemoryManager::finalizeMemory(std::string* err_msg) { return false; } +using tsl::port::CPUFeature; + +// Returns the earliest CPU generation that supports the instruction set. +llvm::StringRef CPUTargetFromMaxFeature(CPUFeature max_feature) { + switch (max_feature) { + case CPUFeature::SSE4_2: + return "nehalem"; + case CPUFeature::AVX: + return "sandybridge"; + case CPUFeature::AVX2: + return "haswell"; + case CPUFeature::AVX512F: + return "skylake-avx512"; + case CPUFeature::AVX512_VNNI: + return "cascadelake"; + case CPUFeature::AVX512_BF16: + return "cooperlake"; + case CPUFeature::AMX_BF16: + case CPUFeature::AMX_INT8: + return "sapphirerapids"; + case CPUFeature::AMX_FP16: + return "graniterapids"; + default: + LOG(FATAL) << "Unsupported max feature: " << max_feature; + } +} + } // namespace +std::optional ISAStringToFeature( + const absl::string_view feature_string) { + if (feature_string.empty()) return std::nullopt; + + // Non-exhaustive list of CPU features. (Only the ones we care about.) + // TODO(penporn): Handle ARM + static auto* x86 = [] { + return new absl::flat_hash_map( + {{"SSE4_2", CPUFeature::SSE4_2}, + {"AVX", CPUFeature::AVX}, + {"AVX2", CPUFeature::AVX2}, + {"AVX512", CPUFeature::AVX512F}, + {"AVX512_VNNI", CPUFeature::AVX512_VNNI}, + {"AVX512_BF16", CPUFeature::AVX512_BF16}, + {"AMX", CPUFeature::AMX_BF16}, // Includes AMX_INT8. + {"AMX_FP16", CPUFeature::AMX_FP16}}); + }(); + + // Assume that `feature_string` always contains all uppercase letters. + if (auto it = x86->find(feature_string); it != x86->end()) return it->second; + LOG(WARNING) << "Unknown CPU ISA: " << feature_string; + return std::nullopt; +} + +// Disable any feature that is newer than `max_feature`. +bool ShouldEnableCPUFeature(const llvm::StringRef feature, + const CPUFeature& max_feature) { + // x86 CPUs have backward compatibility so newer CPUs have all features of + // older CPUs. We go through switch cases from oldest features to newest. + // - Each case looks for features that are introduced in the next + // generation, i.e., features that should be disabled if `max_feature` is + // older or equal to the case's ISA. + // - We combine all features that needs to be disabled from all ISAs newer + // than `max_feature` by falling through cases. + // + // For example, if `max_feature` is AVX2, we start by disabling + // AVX512-generation features in the AVX2 case, then fall through to the + // AVX512 case to disable next-gen features (AVX512_VNNI), etc, all the way + // down to the newest one. + // + // TODO(https://github.com/openxla/xla/issues/17758): Figure out if we need to + // support AVX10 and where to put it. + switch (max_feature) { + case CPUFeature::SSE4_2: + if (feature.starts_with("avx") || feature == "f16c" || + feature == "vpclmulqdq" || feature == "vaes") { + return false; + } + [[fallthrough]]; + case CPUFeature::AVX: + if (feature.starts_with("avx2") || feature.starts_with("fma")) { + return false; + } + [[fallthrough]]; + case CPUFeature::AVX2: + if (feature.starts_with("avx512") || feature == "evex512") return false; + [[fallthrough]]; + case CPUFeature::AVX512F: + if (feature == "avx512vnni") return false; + [[fallthrough]]; + case CPUFeature::AVX512_VNNI: + if (feature == "avx512bf16") return false; + [[fallthrough]]; + case CPUFeature::AVX512_BF16: + if (feature.starts_with("amx")) return false; + [[fallthrough]]; + case CPUFeature::AMX_INT8: + case CPUFeature::AMX_BF16: + if (feature == "amx-fp16") return false; + [[fallthrough]]; + default: + // Leave all other features enabled. + return true; + } +} + +DetectedMachineAttributes DetectMachineAttributes( + std::optional max_feature) { + DetectedMachineAttributes result; + result.features_filtered = false; + // We only have x86 constraints. Skip the check if we are on non-x86 CPUs. + bool no_feature_constraint = + !max_feature.has_value() || !tsl::port::IsX86CPU(); + for (const auto& [feature, enabled] : llvm::sys::getHostCPUFeatures()) { + bool should_enable = + enabled && (no_feature_constraint || + ShouldEnableCPUFeature(feature, *max_feature)); + result.features.push_back( + absl::StrCat(should_enable ? "+" : "-", std::string(feature))); + result.features_filtered |= (should_enable != enabled); + } + std::sort(result.features.begin(), result.features.end()); + return result; +} + +std::vector DetectMachineAttributes() { + return DetectMachineAttributes(std::nullopt).features; +} + /*static*/ std::unique_ptr SimpleOrcJIT::InferTargetMachineForJIT( - const llvm::TargetOptions& target_options, - llvm::CodeGenOptLevel opt_level) { - std::vector attrs = DetectMachineAttributes(); - llvm::SmallVector llvm_attrs(attrs.begin(), attrs.end()); + const llvm::TargetOptions& target_options, llvm::CodeGenOptLevel opt_level, + absl::string_view max_cpu_isa) { + std::optional max_feature = ISAStringToFeature(max_cpu_isa); + auto result = DetectMachineAttributes(max_feature); + llvm::SmallVector llvm_attrs(result.features.begin(), + result.features.end()); + // If `max_feature` is newer than the host CPU, we should keep the host CPU + // name, e.g., we don't want to set the target CPU to Skylake when we are on + // a Broadwell host. + llvm::StringRef target_cpu = result.features_filtered + ? CPUTargetFromMaxFeature(*max_feature) + : llvm::sys::getHostCPUName(); std::unique_ptr target_machine( llvm::EngineBuilder() .setTargetOptions(target_options) .setOptLevel(opt_level) .selectTarget( /*TargetTriple=*/llvm::Triple(), /*MArch=*/"", - /*MCPU=*/llvm::sys::getHostCPUName(), + /*MCPU=*/target_cpu, /*MAttrs=*/llvm_attrs)); CHECK(target_machine != nullptr); return target_machine; } static CompilerFunctor::TargetMachineBuilder CreateTargetMachineBuilder( - llvm::TargetOptions target_options, llvm::CodeGenOptLevel opt_level) { - return [target_options, opt_level]() { - return SimpleOrcJIT::InferTargetMachineForJIT(target_options, opt_level); + llvm::TargetOptions target_options, llvm::CodeGenOptLevel opt_level, + absl::string_view max_cpu_isa) { + return [target_options, opt_level, max_cpu_isa]() { + return SimpleOrcJIT::InferTargetMachineForJIT(target_options, opt_level, + max_cpu_isa); }; } @@ -338,9 +468,9 @@ SimpleOrcJIT::SimpleOrcJIT( LLVMCompiler::ModuleHook pre_optimization_hook, LLVMCompiler::ModuleHook post_optimization_hook, absl::AnyInvocable post_codegen_hook, - size_t num_jit_dylibs) + size_t num_jit_dylibs, absl::string_view max_cpu_isa) : target_machine_builder_( - CreateTargetMachineBuilder(target_options, opt_level)), + CreateTargetMachineBuilder(target_options, opt_level, max_cpu_isa)), target_machine_(target_machine_builder_()), target_triple_(target_machine_->getTargetTriple()), data_layout_(target_machine_->createDataLayout()), @@ -426,7 +556,7 @@ llvm::Expected> SimpleOrcJIT::Create( LLVMCompiler::ModuleHook pre_optimization_hook, LLVMCompiler::ModuleHook post_optimization_hook, absl::AnyInvocable post_codegen_hook, - size_t num_jit_dylibs) { + size_t num_jit_dylibs, absl::string_view max_cpu_isa) { auto SSP = std::make_shared(); auto target_process_control = llvm::orc::SelfExecutorProcessControl::Create(std::move(SSP)); @@ -441,7 +571,7 @@ llvm::Expected> SimpleOrcJIT::Create( target_options, opt_level, optimize_for_size, disable_expensive_passes, disable_slp_vectorizer, fast_math_flags, std::move(pre_optimization_hook), std::move(post_optimization_hook), std::move(post_codegen_hook), - num_jit_dylibs); + num_jit_dylibs, std::move(max_cpu_isa)); } llvm::orc::ExecutorSymbolDef SimpleOrcJIT::ResolveRuntimeSymbol( @@ -561,6 +691,8 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConv3DF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConv3DF32); REGISTER_CPU_RUNTIME_SYMBOL(DuccSingleThreadedFft); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF8E4M3FN); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF8E5M2); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); diff --git a/third_party/xla/xla/service/cpu/simple_orc_jit.h b/third_party/xla/xla/service/cpu/simple_orc_jit.h index 9adec42216cd57..728db2ab255f67 100644 --- a/third_party/xla/xla/service/cpu/simple_orc_jit.h +++ b/third_party/xla/xla/service/cpu/simple_orc_jit.h @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -26,6 +27,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/functional/any_invocable.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ExecutionEngine/JITEventListener.h" @@ -47,6 +49,7 @@ limitations under the License. #include "llvm/TargetParser/Triple.h" #include "xla/service/cpu/compiler_functor.h" #include "xla/service/llvm_compiler.h" +#include "tsl/platform/cpu_info.h" namespace xla::cpu { @@ -79,7 +82,7 @@ class SimpleOrcJIT : public llvm::JITEventListener { LLVMCompiler::ModuleHook post_optimization_hook, absl::AnyInvocable post_codegen_hook, - size_t num_jit_dylibs = 1); + size_t num_jit_dylibs, absl::string_view max_cpu_isa); static llvm::Expected> Create( const llvm::TargetOptions& target_options, @@ -90,7 +93,7 @@ class SimpleOrcJIT : public llvm::JITEventListener { LLVMCompiler::ModuleHook post_optimization_hook, absl::AnyInvocable post_codegen_hook, - size_t num_jit_dylibs = 1); + size_t num_jit_dylibs, absl::string_view max_cpu_isa); ~SimpleOrcJIT() override; @@ -117,7 +120,7 @@ class SimpleOrcJIT : public llvm::JITEventListener { // the current machine. static std::unique_ptr InferTargetMachineForJIT( const llvm::TargetOptions& target_options, - llvm::CodeGenOptLevel opt_level); + llvm::CodeGenOptLevel opt_level, absl::string_view max_cpu_isa); int64_t SizeOfGeneratedCodeInBytes() const { return size_of_generated_code_in_bytes_; @@ -167,6 +170,22 @@ class SimpleOrcJIT : public llvm::JITEventListener { llvm::JITEventListener* perf_jit_event_listener_; }; +std::optional ISAStringToFeature( + absl::string_view feature_string); + +bool ShouldEnableCPUFeature(llvm::StringRef feature, + const tsl::port::CPUFeature& max_feature); + +struct DetectedMachineAttributes { + std::vector features; + bool features_filtered; +}; + +DetectedMachineAttributes DetectMachineAttributes( + std::optional max_feature); + +// TODO(penporn): PJRT's CPU client also calls this function. We should +// make it get the same filtered attributes according to the `max_isa` setting. std::vector DetectMachineAttributes(); } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/target_machine_features.cc b/third_party/xla/xla/service/cpu/target_machine_features.cc index 05b2d0128bc099..d675f5380eb97a 100644 --- a/third_party/xla/xla/service/cpu/target_machine_features.cc +++ b/third_party/xla/xla/service/cpu/target_machine_features.cc @@ -54,5 +54,9 @@ int64_t LLVMTargetMachineFeatures::minimum_alignment_for_allocation( cpu_function_runtime::MinAlign()); } +std::string LLVMTargetMachineFeatures::get_target_feature_string() const { + return target_machine_->getTargetFeatureString().str(); +} + } // namespace cpu } // namespace xla diff --git a/third_party/xla/xla/service/cpu/target_machine_features.h b/third_party/xla/xla/service/cpu/target_machine_features.h index c8bcc1da74e55a..8ec78485ae6b94 100644 --- a/third_party/xla/xla/service/cpu/target_machine_features.h +++ b/third_party/xla/xla/service/cpu/target_machine_features.h @@ -59,6 +59,8 @@ class TargetMachineFeatures { // this functionality). virtual int vector_register_count(const llvm::Function& function) const = 0; + virtual std::string get_target_feature_string() const = 0; + // Returns the minimum alignment for a buffer of size size_bytes. virtual int64_t minimum_alignment_for_allocation( int64_t size_bytes) const = 0; @@ -102,6 +104,8 @@ class LLVMTargetMachineFeatures : public TargetMachineFeatures { int64_t minimum_alignment_for_allocation(int64_t size_bytes) const override; + std::string get_target_feature_string() const override; + private: llvm::TargetTransformInfo* GetTargetTransformInfoFor( const llvm::Function& function) const; diff --git a/third_party/xla/xla/service/cpu/target_machine_features_fake.h b/third_party/xla/xla/service/cpu/target_machine_features_fake.h index 2823770177f000..c232cb7a5927f0 100644 --- a/third_party/xla/xla/service/cpu/target_machine_features_fake.h +++ b/third_party/xla/xla/service/cpu/target_machine_features_fake.h @@ -52,6 +52,10 @@ class TargetMachineFeaturesWithFakeAlignmentLogic return fake_alignment_logic_(size_bytes); } + std::string get_target_feature_string() const override { + LOG(FATAL) << "Unexpected call to " << __func__; + } + private: std::function fake_alignment_logic_; }; diff --git a/third_party/xla/xla/service/cpu/tests/BUILD b/third_party/xla/xla/service/cpu/tests/BUILD index be51e842c6b5e0..fccda722a35e6e 100644 --- a/third_party/xla/xla/service/cpu/tests/BUILD +++ b/third_party/xla/xla/service/cpu/tests/BUILD @@ -50,9 +50,9 @@ xla_cc_test( "//xla/service:executable", "//xla/service:platform_util", "//xla/service/cpu:cpu_compiler", - "//xla/stream_executor", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_h", "//xla/tests:hlo_test_base", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", @@ -134,9 +134,9 @@ xla_cc_test( "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", - "//xla/service:hlo_ordering", "//xla/service:logical_buffer", "//xla/service/llvm_ir:alias_analysis", "//xla/service/llvm_ir:ir_array", @@ -230,9 +230,9 @@ xla_cc_test( "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/client/lib:arithmetic", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/builder/lib:arithmetic", "//xla/service", "//xla/service:cpu_plugin", "//xla/tests:client_library_test_base", @@ -307,9 +307,9 @@ xla_cc_test( ":cpu_codegen_test", "//xla:shape_util", "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/client/lib:sorting", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/builder/lib:sorting", "//xla/hlo/ir:hlo", "//xla/service:hlo_module_config", "//xla/service/cpu:cpu_compiler", @@ -330,6 +330,7 @@ xla_cc_test( "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service/cpu:cpu_compiler", + "//xla/service/cpu:simple_orc_jit", "//xla/tests:hlo_test_base", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", @@ -337,6 +338,7 @@ xla_cc_test( "@llvm-project//llvm:ARMCodeGen", # fixdeps: keep "@llvm-project//llvm:Target", "@llvm-project//llvm:X86CodeGen", # fixdeps: keep + "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", ], @@ -400,6 +402,7 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:platform_port", ], ) diff --git a/third_party/xla/xla/service/cpu/tests/cpu_fusion_test.cc b/third_party/xla/xla/service/cpu/tests/cpu_fusion_test.cc index fcea12916cbd88..4ef2c8ec0adfe1 100644 --- a/third_party/xla/xla/service/cpu/tests/cpu_fusion_test.cc +++ b/third_party/xla/xla/service/cpu/tests/cpu_fusion_test.cc @@ -43,7 +43,7 @@ class CpuFusionTest : public HloTestBase { ErrorSpec error_spec_{0.0001, 1e-5}; private: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.add_xla_disable_hlo_passes("layout-assignment"); return debug_options; diff --git a/third_party/xla/xla/service/cpu/tests/cpu_infeed_test.cc b/third_party/xla/xla/service/cpu/tests/cpu_infeed_test.cc index 4a90b10f5c5bf4..aa25e0f2f99f2f 100644 --- a/third_party/xla/xla/service/cpu/tests/cpu_infeed_test.cc +++ b/third_party/xla/xla/service/cpu/tests/cpu_infeed_test.cc @@ -25,10 +25,10 @@ limitations under the License. #include -#include "xla/client/lib/arithmetic.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/literal.h" #include "xla/shape_util.h" #include "xla/test_helpers.h" diff --git a/third_party/xla/xla/service/cpu/tests/cpu_intrinsic_test.cc b/third_party/xla/xla/service/cpu/tests/cpu_intrinsic_test.cc index 0e9d32beb5ae13..ea52375327f14f 100644 --- a/third_party/xla/xla/service/cpu/tests/cpu_intrinsic_test.cc +++ b/third_party/xla/xla/service/cpu/tests/cpu_intrinsic_test.cc @@ -81,7 +81,7 @@ class CpuUnaryIntrinsicTest } private: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); HloTestBase::SetAotFastMathDebugOptions(&debug_options); return debug_options; diff --git a/third_party/xla/xla/service/cpu/tests/cpu_noalias_test.cc b/third_party/xla/xla/service/cpu/tests/cpu_noalias_test.cc index 5e486f216e932a..42a561b12cb35f 100644 --- a/third_party/xla/xla/service/cpu/tests/cpu_noalias_test.cc +++ b/third_party/xla/xla/service/cpu/tests/cpu_noalias_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/Support/Casting.h" +#include "xla/hlo/analysis/hlo_ordering.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -33,7 +34,6 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/tests/cpu_codegen_test.h" -#include "xla/service/hlo_ordering.h" #include "xla/service/llvm_ir/alias_analysis.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/llvm_util.h" diff --git a/third_party/xla/xla/service/cpu/tests/cpu_topk_test.cc b/third_party/xla/xla/service/cpu/tests/cpu_topk_test.cc index 618fd0f02a904c..3f230aef53c276 100644 --- a/third_party/xla/xla/service/cpu/tests/cpu_topk_test.cc +++ b/third_party/xla/xla/service/cpu/tests/cpu_topk_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include -#include "xla/client/lib/sorting.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/sorting.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/cpu/cpu_compiler.h" #include "xla/service/cpu/test_target_triple_helper.h" diff --git a/third_party/xla/xla/service/cpu/tests/cpu_vectorization_test.cc b/third_party/xla/xla/service/cpu/tests/cpu_vectorization_test.cc index ec29e43b3aff93..1b4f77d8ddb79e 100644 --- a/third_party/xla/xla/service/cpu/tests/cpu_vectorization_test.cc +++ b/third_party/xla/xla/service/cpu/tests/cpu_vectorization_test.cc @@ -15,21 +15,25 @@ limitations under the License. #include #include +#include #include #include "absl/algorithm/container.h" #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_replace.h" #include "llvm-c/Target.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/cpu/cpu_compiler.h" +#include "xla/service/cpu/simple_orc_jit.h" #include "xla/service/cpu/tests/cpu_codegen_test.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/cpu_info.h" #include "tsl/platform/test.h" namespace xla { @@ -79,7 +83,7 @@ class CpuVectorizationTest } private: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); HloTestBase::SetAotFastMathDebugOptions(&debug_options); return debug_options; @@ -140,6 +144,139 @@ INSTANTIATE_TEST_SUITE_P(CpuVectorizationTestInstantiation, ::testing::ValuesIn(CpuVectorizationTestCases), CpuVectorizationTest::Name); +struct MaxIsaTestSpec { + std::string max_isa; + std::string feature; + bool should_enable; +}; + +class MaxIsaTest : public CpuCodegenTest, + public ::testing::WithParamInterface { + public: + static std::string Name( + const ::testing::TestParamInfo& info) { + // Test names cannot contain '-'. Replace it with '_'. + std::string feature = info.param.feature; + absl::c_replace_if( + feature, [](char c) { return c != '_' && !absl::ascii_isalnum(c); }, + '_'); + return absl::StrCat(info.param.max_isa, "_feature_", feature); + } +}; + +TEST_P(MaxIsaTest, ShouldEnableFeature) { + HloComputation::Builder builder(TestName()); + MaxIsaTestSpec spec = GetParam(); + + auto max_feature = ISAStringToFeature(spec.max_isa); + bool should_enable = ShouldEnableCPUFeature(spec.feature, *max_feature); + EXPECT_EQ(should_enable, spec.should_enable); +} + +std::vector GetMaxIsaTestCases() { + return std::vector({ + MaxIsaTestSpec{"AVX2", "avx", true}, + MaxIsaTestSpec{"AVX2", "avx2", true}, + MaxIsaTestSpec{"AVX2", "avx512f", false}, + MaxIsaTestSpec{"AVX2", "avx512vnni", false}, + MaxIsaTestSpec{"AVX2", "evex512", false}, + MaxIsaTestSpec{"AVX512", "avx512f", true}, + MaxIsaTestSpec{"AVX512", "avx512vnni", false}, + MaxIsaTestSpec{"AVX512", "amx-bf16", false}, + }); +} + +INSTANTIATE_TEST_SUITE_P(MaxIsaTestInstantiation, MaxIsaTest, + ::testing::ValuesIn(GetMaxIsaTestCases()), + MaxIsaTest::Name); + +struct JitVectorizationTestSpec { + HloOpcode opcode; + std::string max_isa; + std::string check_template; + int num_vector_elements; +}; + +class JitVectorizationTest + : public CpuCodegenTest, + public ::testing::WithParamInterface { + public: + static std::string Name( + const ::testing::TestParamInfo& info) { + std::string op_name(HloOpcodeString(info.param.opcode)); + op_name[0] = toupper(op_name[0]); + return absl::StrCat(op_name, "_max_", info.param.max_isa); + } + + private: + DebugOptions GetDebugOptionsForTest() const override { + JitVectorizationTestSpec spec = GetParam(); + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_cpu_max_isa(spec.max_isa); + // For AVX512, we have to override the default `prefer_vector_width=256` + // setting. Otherwise, LLVM won't generate AVX512. + // TODO(penporn): Change the setting for actual AVX512 codegen too. + if (spec.max_isa == "AVX512") { + debug_options.set_xla_cpu_prefer_vector_width(512); + } + return debug_options; + } +}; + +TEST_P(JitVectorizationTest, JitUpToIsa) { + if (!tsl::port::IsX86CPU()) { + GTEST_SKIP() << "This feature only works for x86 CPUs."; + } + HloComputation::Builder builder(TestName()); + JitVectorizationTestSpec spec = GetParam(); + + // If the CPU doesn't have the `max_isa` feature, e.g., `max_isa=AVX512` but + // we are running on an AVX2 machine, update the `check_lines` accordingly. + using tsl::port::CPUFeature; + auto feature = ISAStringToFeature(spec.max_isa); + if (!tsl::port::TestCPUFeature(*feature)) { + if (tsl::port::TestCPUFeature(CPUFeature::AVX)) { + spec.num_vector_elements = 8; + } else { + spec.num_vector_elements = 4; + } + } + std::string check_lines = absl::StrReplaceAll( + spec.check_template, {{"%d", absl::StrCat(spec.num_vector_elements)}}); + + // Build HLO module. + auto shape = ShapeUtil::MakeShape(F32, {1024}); + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + builder.AddInstruction( + HloInstruction::CreateBinary(shape, spec.opcode, a, b)); + std::unique_ptr computation = builder.Build(); + + auto hlo_module = CreateNewVerifiedModule(); + hlo_module->AddEntryComputation(std::move(computation)); + + CompileAndVerifyIr(std::move(hlo_module), check_lines, + /*match_optimized_ir=*/true); +} + +std::vector GetJitVectorizationTestCases() { + return std::vector({ + JitVectorizationTestSpec{HloOpcode::kMultiply, "SSE4_2", + R"(CHECK: fmul <%d x float>)", 4}, + JitVectorizationTestSpec{HloOpcode::kMultiply, "AVX2", + R"(CHECK: fmul <%d x float>)", 8}, + JitVectorizationTestSpec{HloOpcode::kMultiply, "AVX512", + R"(CHECK: fmul <%d x float>)", 16}, + }); +} + +INSTANTIATE_TEST_SUITE_P(JitVectorizationTestInstantiation, + JitVectorizationTest, + ::testing::ValuesIn(GetJitVectorizationTestCases()), + JitVectorizationTest::Name); + } // namespace } // namespace cpu } // namespace xla diff --git a/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc b/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc index 6bceebc7343c8e..a710898ca8350f 100644 --- a/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc +++ b/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "absl/strings/str_replace.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal.h" #include "xla/service/cpu/onednn_contraction_rewriter.h" @@ -32,78 +33,224 @@ limitations under the License. namespace xla { namespace cpu { -class ConvolutionTest : public HloTestBase { +class ConvolutionTest : public HloTestBase, + public ::testing::WithParamInterface { protected: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_cpu_use_thunk_runtime(false); return debug_options; } - const char* conv_rewrite_str_ = R"( + PrimitiveType dtype_; + std::string dtypeString_; + bool user_scratchpad_; + bool weights_prepacked_; + float atol_; + float rtol_; + + constexpr static const char* kConvRewriteStr = R"( ; CHECK: custom_call_target="__onednn$convolution", ; CHECK: backend_config={ ; CHECK-DAG: "outer_dimension_partitions":[], - ; CHECK-DAG: "onednn_conv_config":{ - ; CHECK-DAG: } - ; CHECK: } + ; CHECK-DAG: "onednn_conv_config":{$fusions_str,$opt_config + ; CHECK-DAG: } + ; CHECK: } )"; + + constexpr static const char* kConvRewriteFusionsStr = R"( + ; CHECK-DAG: "fusions":{ + ; CHECK-DAG: "ops":[$fused_ops] + ; CHECK-DAG: },)"; + + constexpr static const char* kConvRewriteOptimizationsStr = R"( + ; CHECK-DAG: "optimization_config":{ + ; CHECK-DAG: "weights_prepacked":$weights_prepacked, + ; CHECK-DAG: "user_scratchpad":$user_scratchpad, + ; CHECK-DAG: })"; + + ConvolutionTest() { + dtype_ = GetParam(); + atol_ = rtol_ = (dtype_ == F32) ? 1e-4 : 1e-2; + // TODO(intel-tf): Set default value of user_scratchpad to true after + // enabling feature + user_scratchpad_ = false; + weights_prepacked_ = false; + dtypeString_ = primitive_util::LowercasePrimitiveTypeName(dtype_); + } + + void SetUp() override { + if (!IsSupportedType(dtype_)) { + GTEST_SKIP() << "CPU does not support " << dtypeString_; + } + } + + void SetWeightsPrepacked(bool value) { weights_prepacked_ = value; } + + void SetUserScratchpad(bool value) { user_scratchpad_ = value; } + + std::string GetOptimizationsString() { + return (user_scratchpad_ || weights_prepacked_) + ? absl::StrReplaceAll(kConvRewriteOptimizationsStr, + {{"$weights_prepacked", + weights_prepacked_ ? "true" : "false"}, + {"$user_scratchpad", + user_scratchpad_ ? "true" : "false"}}) + : ""; + } + + std::string ConvStringWithOptimizations( + const std::vector fused_ops) { + std::ostringstream stream; + std::for_each( + fused_ops.begin(), fused_ops.end(), + [&](const absl::string_view& arg) { stream << "\"" << arg << "\","; }); + std::string fusions = stream.str(); + if (fused_ops.size() > 0) { + fusions.pop_back(); + return absl::StrReplaceAll( + kConvRewriteStr, + {{"$fusions_str,", absl::StrReplaceAll(kConvRewriteFusionsStr, + {{"$fused_ops", fusions}})}, + {"$opt_config", GetOptimizationsString()}}); + } + return absl::StrReplaceAll( + kConvRewriteStr, + {{"$fusions_str,", ""}, {"$opt_config", GetOptimizationsString()}}); + } + + // TODO(intel-tf): Remove this and simplify patterns when Elemental BF16 is + // fully supported. + PrimitiveType PromotedDtype() { + // BF16 is promoted to F32 because not all HLO Instructions currently + // support BF16 computations. Meanwhile, FP32 and FP16 elementwise + // instructions are not promoted and remain unchanged. + return (dtype_ == BF16) ? F32 : dtype_; + } + + void AdjustToleranceForDtype(PrimitiveType for_type, float atol, float rtol) { + if (dtype_ == for_type) { + atol_ = atol; + rtol_ = rtol; + } + } + + std::string PromotedDtypeToString() { + return primitive_util::LowercasePrimitiveTypeName(PromotedDtype()); + } + + void RunCompareAndMatchOptimizedHlo( + const absl::string_view outline, + const std::vector fused_ops) { + const std::string convolution_module_str = absl::StrReplaceAll( + outline, + {{"$dtype", dtypeString_}, {"$pdtype", PromotedDtypeToString()}}); + EXPECT_TRUE(RunAndCompare(convolution_module_str, ErrorSpec{atol_, rtol_})); + MatchOptimizedHlo(convolution_module_str, + ConvStringWithOptimizations(fused_ops)); + } }; -TEST_F(ConvolutionTest, Simple2DTestF32) { - const char* convolution_module_str = R"( - HloModule convolution.test.f32 - - ENTRY convolution.test.f32 { - arg.0 = f32[1,22,22,1] parameter(0), parameter_replication={false} - reshape.0 = f32[1,22,22,1] reshape(arg.0) - arg.1 = f32[8,8,1,1] parameter(1), parameter_replication={false} - reshape.1 = f32[8,8,1,1] reshape(arg.1) - convolution.0 = f32[1,11,11,1] convolution(reshape.0, reshape.1), window={size=8x8 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f - reshape.2 = f32[1,11,11,1] reshape(convolution.0) - tuple.0 = (f32[1,11,11,1]) tuple(reshape.2) - ROOT get-tuple-element.0 = f32[1,11,11,1] get-tuple-element(tuple.0), index=0 +TEST_P(ConvolutionTest, Simple2DTest1) { + const absl::string_view outline = R"( + HloModule convolution.test + + ENTRY convolution.test { + arg.0 = $dtype[1,22,22,1] parameter(0) + reshape.0 = $dtype[1,22,22,1] reshape(arg.0) + arg.1 = $dtype[8,8,1,1] parameter(1) + reshape.1 = $dtype[8,8,1,1] reshape(arg.1) + convolution.0 = $dtype[1,11,11,1] convolution(reshape.0, reshape.1), + window={size=8x8 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f + reshape.2 = $dtype[1,11,11,1] reshape(convolution.0) + tuple.0 = ($dtype[1,11,11,1]) tuple(reshape.2) + ROOT gte.0 = $dtype[1,11,11,1] get-tuple-element(tuple.0), index=0 })"; - EXPECT_TRUE(RunAndCompare(convolution_module_str, ErrorSpec{1e-4, 1e-4})); - MatchOptimizedHlo(convolution_module_str, conv_rewrite_str_); + RunCompareAndMatchOptimizedHlo(outline, {}); } -TEST_F(ConvolutionTest, Simple3DTestBF16) { - if (!IsSupportedType(PrimitiveType::BF16)) { - GTEST_SKIP() << "CPU does not support BF16."; - } +TEST_P(ConvolutionTest, Simple3DTest1) { + const absl::string_view outline = R"( + HloModule convolution.test - const char* convolution_module_str = R"( - HloModule convolution.test.bf16 - - ENTRY convolution.test.bf16 { - p0 = bf16[8,4,5,5,1] parameter(0) - p1 = bf16[3,3,3,1,32] parameter(1) - ROOT conv = bf16[8,4,5,5,32] convolution(p0, p1), window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=b012f_012io->b012f + ENTRY convolution.test { + p0 = $dtype[8,4,5,5,1] parameter(0) + p1 = $dtype[3,3,3,1,32] parameter(1) + ROOT conv = $dtype[8,4,5,5,32] convolution(p0, p1), + window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=b012f_012io->b012f })"; - EXPECT_TRUE(RunAndCompare(convolution_module_str, ErrorSpec{1e-4, 1e-4})); - MatchOptimizedHlo(convolution_module_str, conv_rewrite_str_); + RunCompareAndMatchOptimizedHlo(outline, {}); } -TEST_F(ConvolutionTest, Simple2DTestF16) { - if (!IsSupportedType(PrimitiveType::F16)) { - GTEST_SKIP() << "CPU does not support F16."; - } +TEST_P(ConvolutionTest, Conv3DWithBiasTest) { + const absl::string_view outline = R"( + HloModule convolution.test.with.bias - const char* convolution_module_str = R"( - HloModule convolution.test.f16 - ENTRY convolution.test.bf16 { - p0 = f16[8,4,5,5,1] parameter(0) - p1 = f16[3,3,3,1,32] parameter(1) - ROOT conv = f16[8,4,5,5,32] convolution(p0, p1), window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=b012f_012io->b012f + ENTRY convolution.test.with.bias { + arg.0 = $dtype[15,4,5,5,28] parameter(0) + arg.1 = $dtype[3,3,3,28,64] parameter(1) + conv = $dtype[15,4,5,5,64] convolution(arg.0, arg.1), + window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=b012f_012io->b012f + bias = $dtype[64] parameter(2) + broadcasted_bias = $dtype[15,4,5,5,64] broadcast(bias), dimensions={4} + ROOT add = $dtype[15,4,5,5,64] add(conv, broadcasted_bias) })"; - EXPECT_TRUE(RunAndCompare(convolution_module_str, ErrorSpec{1e-4, 1e-4})); - MatchOptimizedHlo(convolution_module_str, conv_rewrite_str_); + RunCompareAndMatchOptimizedHlo(outline, {"BIAS"}); +} + +TEST_P(ConvolutionTest, Conv2DWithBinaryAddTest) { + const absl::string_view outline = R"( + HloModule convolution.test.with.binaryadd + + ENTRY convolution.test.with.binaryadd { + arg0.1 = $dtype[1,22,22,1] parameter(0) + constant.3 = $dtype[] constant(1) + broadcast.4 = $dtype[8,8,1,1] broadcast(constant.3), dimensions={} + convolution.0 = $dtype[1,11,11,1] convolution(arg0.1, broadcast.4), + window={size=8x8 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f + constant.5 = $dtype[] constant(15) + broadcast.6 = $dtype[1] broadcast(constant.5), dimensions={} + broadcast.9 = $dtype[1,11,11,1] broadcast(broadcast.6), dimensions={3} + ROOT add.10 = $dtype[1,11,11,1] add(convolution.0, broadcast.9) + })"; + + RunCompareAndMatchOptimizedHlo(outline, {"BINARY_ADD"}); } +// This test should match BIAS + RESIDUAL ADD when the residual add fusion is +// re-enabled. +TEST_P(ConvolutionTest, Conv2DWithBiasAndBinaryAddTest) { + const absl::string_view outline = R"( + HloModule convolution.add.test + + ENTRY convolution.add.test { + arg0.1 = $dtype[1,22,22,1] parameter(0) + arg0.2 = $dtype[8,8,1,10] parameter(1) + convolution.0 = $dtype[1,11,11,10] convolution(arg0.1, arg0.2), + window={size=8x8 stride=2x2 pad=3_3x3_3}, dim_labels=b01f_01io->b01f + const.0 = $dtype[10] constant(15) + bcast.1 = $dtype[1,11,11,10] broadcast(const.0), dimensions={3} + add.0 = $dtype[1,11,11,10] add(convolution.0, bcast.1) + const.1 = $dtype[1,11,11,10] constant({...}) + ROOT add.1 = $dtype[1,11,11,10] add(add.0, const.1) + })"; + + RunCompareAndMatchOptimizedHlo(outline, {"BIAS"}); +} + +INSTANTIATE_TEST_SUITE_P( + OneDnnConvolutionTestSuite, ConvolutionTest, + ::testing::Values(F32, BF16, F16), + [](const ::testing::TestParamInfo& info) { + auto test_name = primitive_util::LowercasePrimitiveTypeName(info.param); + std::transform(test_name.begin(), test_name.end(), test_name.begin(), + [](auto c) { return std::toupper(c); }); + return test_name; + }); + } // namespace cpu } // namespace xla diff --git a/third_party/xla/xla/service/cpu/tests/onednn_layer_norm_test.cc b/third_party/xla/xla/service/cpu/tests/onednn_layer_norm_test.cc index 92ca5061724faf..5704b02ed64581 100644 --- a/third_party/xla/xla/service/cpu/tests/onednn_layer_norm_test.cc +++ b/third_party/xla/xla/service/cpu/tests/onednn_layer_norm_test.cc @@ -24,7 +24,7 @@ namespace { class LayerNormTest : public HloTestBase { protected: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_cpu_use_thunk_runtime(false); return debug_options; diff --git a/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc b/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc index 57f7c09aba11e8..9234a56c9dd66a 100644 --- a/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc +++ b/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc @@ -36,7 +36,7 @@ namespace cpu { class MatmulTest : public HloTestBase { protected: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_cpu_use_thunk_runtime(false); return debug_options; diff --git a/third_party/xla/xla/service/cpu/tests/onednn_softmax_test.cc b/third_party/xla/xla/service/cpu/tests/onednn_softmax_test.cc index 1fff5d88a736e5..692e38716075a3 100644 --- a/third_party/xla/xla/service/cpu/tests/onednn_softmax_test.cc +++ b/third_party/xla/xla/service/cpu/tests/onednn_softmax_test.cc @@ -47,7 +47,7 @@ class OneDnnSoftmaxTest : public HloTestBase, public ::testing::WithParamInterface> { protected: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_cpu_use_thunk_runtime(false); return debug_options; diff --git a/third_party/xla/xla/service/cpu_gpu_shape_verifier_test.cc b/third_party/xla/xla/service/cpu_gpu_shape_verifier_test.cc index 56096e9703fb7e..ae237dac76eb23 100644 --- a/third_party/xla/xla/service/cpu_gpu_shape_verifier_test.cc +++ b/third_party/xla/xla/service/cpu_gpu_shape_verifier_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/service/hlo_verifier.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" @@ -38,7 +38,7 @@ class CpuGpuShapeVerifierTest : public HloTestBase { HloVerifierOpts opts; std::unique_ptr metadata = std::make_unique(std::move(opts)); - hlo_verifier_ = std::make_unique(std::move(metadata)); + set_hlo_verifier(std::make_unique(std::move(metadata))); } }; diff --git a/third_party/xla/xla/service/defuser.h b/third_party/xla/xla/service/defuser.h index d70552e1383a7b..46ad02630dfce0 100644 --- a/third_party/xla/xla/service/defuser.h +++ b/third_party/xla/xla/service/defuser.h @@ -16,29 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_DEFUSER_H_ #define XLA_SERVICE_DEFUSER_H_ -#include - -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// A pass which replaces all fusion instructions with the equivalent un-fused -// instructions. -class Defuser : public HloModulePass { - public: - Defuser() {} - ~Defuser() override {} - absl::string_view name() const override { return "defuser"; } - - // Run defusion on the given module. Returns whether the module was - // changed. - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/defuser.h" #endif // XLA_SERVICE_DEFUSER_H_ diff --git a/third_party/xla/xla/service/despecializer.h b/third_party/xla/xla/service/despecializer.h index 9bfe95d3f8715d..c230c27805b012 100644 --- a/third_party/xla/xla/service/despecializer.h +++ b/third_party/xla/xla/service/despecializer.h @@ -16,87 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_DESPECIALIZER_H_ #define XLA_SERVICE_DESPECIALIZER_H_ -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/status/statusor.h" -#include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/hlo/pass/hlo_pass_pipeline.h" - -namespace xla { - -// Creates an HloPassPipeline containing multiple HloPasses that can -// despecialize an optimized HloModule. This is useful to run an HloModule -// optimized for one specific platform on a different platform (undoing platform -// specific passes) with matching numerics for comparison. -// -// Current despecialization passes are HloDescheduler, ControlDepRemover, -// Defuser and BFloat16MixedPrecisionRemoval. -class Despecializer : public HloModulePass { - public: - Despecializer(); - void AddReduceWindowToReduceBroadcastDeconstruct(); - void AddAssumeGatherIndicesInBoundRewriteToCopy(); - absl::string_view name() const override { return "despecializer"; } - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - HloPassPipeline pipeline_; -}; - -class AssumeGatherIndicesInBoundRewriteToCopy : public HloModulePass { - public: - AssumeGatherIndicesInBoundRewriteToCopy() = default; - absl::string_view name() const override { - return "AssumeGatherIndicesInBoundRewriteToCopy"; - } - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -class DeconstructReduceWindowToReduceBroadcast : public HloModulePass { - public: - DeconstructReduceWindowToReduceBroadcast() = default; - absl::string_view name() const override { - return "ReduceWindowToReduceAndBroadcast"; - } - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -// Pass which strips control dependencies from all instructions in the module. -class ControlDepRemover : public HloModulePass { - public: - ControlDepRemover() = default; - absl::string_view name() const override { return "control-dep-remover"; } - - using HloPassInterface::Run; - absl::StatusOr Run(HloModule* module, - const absl::flat_hash_set& - execution_threads) override { - bool changed = false; - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - changed |= !instruction->control_predecessors().empty(); - TF_RETURN_IF_ERROR(instruction->DropAllControlDeps()); - } - } - return changed; - } -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/despecializer.h" #endif // XLA_SERVICE_DESPECIALIZER_H_ diff --git a/third_party/xla/xla/service/dot_as_convolution_util.cc b/third_party/xla/xla/service/dot_as_convolution_util.cc index 25d6b6a48c9d48..8de34a379922f0 100644 --- a/third_party/xla/xla/service/dot_as_convolution_util.cc +++ b/third_party/xla/xla/service/dot_as_convolution_util.cc @@ -21,6 +21,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/shape_inference.h" #include "xla/status_macros.h" +#include "xla/util.h" namespace xla { namespace dot_as_convolution_util { @@ -202,30 +203,30 @@ DotConvolutionDimsInfo ParseDotGeneralFromDot(const HloInstruction* dot) { dnums.contracting_dims.back().output = -1; dnums.contracting_dims.back().spatial_dim = -1; } - for (int64_t i = 0; i < dot->operand(0)->shape().rank(); ++i) { - if (!absl::c_linear_search(dot_dim_numbs.lhs_batch_dimensions(), i) && - !absl::c_linear_search(dot_dim_numbs.lhs_contracting_dimensions(), i)) { - dnums.lhs_non_contracting_dims.emplace_back(); - dnums.lhs_non_contracting_dims.back().lhs = i; - dnums.lhs_non_contracting_dims.back().rhs = -1; - dnums.lhs_non_contracting_dims.back().output = - dot_dim_numbs.lhs_batch_dimensions_size() + - dnums.lhs_non_contracting_dims.size() - 1; - dnums.lhs_non_contracting_dims.back().spatial_dim = -1; - } + for (auto i : + GetNonContractingDims(dot->operand(0)->shape().rank(), + dot_dim_numbs.lhs_contracting_dimensions(), + dot_dim_numbs.lhs_batch_dimensions())) { + dnums.lhs_non_contracting_dims.emplace_back(); + dnums.lhs_non_contracting_dims.back().lhs = i; + dnums.lhs_non_contracting_dims.back().rhs = -1; + dnums.lhs_non_contracting_dims.back().output = + dot_dim_numbs.lhs_batch_dimensions_size() + + dnums.lhs_non_contracting_dims.size() - 1; + dnums.lhs_non_contracting_dims.back().spatial_dim = -1; } - for (int64_t i = 0; i < dot->operand(1)->shape().rank(); ++i) { - if (!absl::c_linear_search(dot_dim_numbs.rhs_batch_dimensions(), i) && - !absl::c_linear_search(dot_dim_numbs.rhs_contracting_dimensions(), i)) { - dnums.rhs_non_contracting_dims.emplace_back(); - dnums.rhs_non_contracting_dims.back().lhs = -1; - dnums.rhs_non_contracting_dims.back().rhs = i; - dnums.rhs_non_contracting_dims.back().output = - dot_dim_numbs.lhs_batch_dimensions_size() + - dnums.lhs_non_contracting_dims.size() + - dnums.rhs_non_contracting_dims.size() - 1; - dnums.rhs_non_contracting_dims.back().spatial_dim = -1; - } + for (auto i : + GetNonContractingDims(dot->operand(1)->shape().rank(), + dot_dim_numbs.rhs_contracting_dimensions(), + dot_dim_numbs.rhs_batch_dimensions())) { + dnums.rhs_non_contracting_dims.emplace_back(); + dnums.rhs_non_contracting_dims.back().lhs = -1; + dnums.rhs_non_contracting_dims.back().rhs = i; + dnums.rhs_non_contracting_dims.back().output = + dot_dim_numbs.lhs_batch_dimensions_size() + + dnums.lhs_non_contracting_dims.size() + + dnums.rhs_non_contracting_dims.size() - 1; + dnums.rhs_non_contracting_dims.back().spatial_dim = -1; } dnums.lhs_shape_rank = dot->operand(0)->shape().rank(); diff --git a/third_party/xla/xla/service/dot_decomposer.h b/third_party/xla/xla/service/dot_decomposer.h index 362f4958f30e78..1e6f4015f169a3 100644 --- a/third_party/xla/xla/service/dot_decomposer.h +++ b/third_party/xla/xla/service/dot_decomposer.h @@ -16,29 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_DOT_DECOMPOSER_H_ #define XLA_SERVICE_DOT_DECOMPOSER_H_ -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// DotDecomposer is a pass which converts dots into a canonical form where -// non-contracting and contracting dimensions are reshaped together and batch -// dimensions are the most major dimensions. -class DotDecomposer : public HloModulePass { - public: - absl::string_view name() const override { return "dot_decomposer"; } - - // Run DotDecomposer pass on computations in 'module'. - // Returns whether the 'module' was changed. - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/expanders/dot_decomposer.h" #endif // XLA_SERVICE_DOT_DECOMPOSER_H_ diff --git a/third_party/xla/xla/service/dot_dimension_merger.h b/third_party/xla/xla/service/dot_dimension_merger.h index a9e511e0c10f16..dcc23bc149217a 100644 --- a/third_party/xla/xla/service/dot_dimension_merger.h +++ b/third_party/xla/xla/service/dot_dimension_merger.h @@ -16,27 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_DOT_DIMENSION_MERGER_H_ #define XLA_SERVICE_DOT_DIMENSION_MERGER_H_ -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// Merge consecutive batch dimensions of a dot() by inserting reshapes. -class DotDimensionMerger : public HloModulePass { - public: - absl::string_view name() const override { return "dot_dimension_merger"; } - - // Run the pass on computations in 'module'. - // Return whether the 'module' was changed. - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/dot_dimension_merger.h" #endif // XLA_SERVICE_DOT_DIMENSION_MERGER_H_ diff --git a/third_party/xla/xla/service/dot_merger.h b/third_party/xla/xla/service/dot_merger.h index 37081edd76c423..5f8c1160686c27 100644 --- a/third_party/xla/xla/service/dot_merger.h +++ b/third_party/xla/xla/service/dot_merger.h @@ -16,59 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_DOT_MERGER_H_ #define XLA_SERVICE_DOT_MERGER_H_ -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// Merges dots that share an operand. Transforms -// -// x = dot(a, b) -// y = dot(a, c) -// -// into -// -// z = dot(a, concat(b, c)) -// x = slice(z) -// y = slice(z). -// -// This requires that x and y are independent -- that is, x does not -// transitively depend on y, and y does not transitively depend on x. -// -// This is a good transformation if the merged dot runs faster than the original -// dots. On the other hand, merging the dots results in a single result buffer -// z whose live range is the union of x and y's live ranges, so can lead to -// increased memory pressure. You probably only want to do this optimization on -// "small" dots which cannot saturate your device when run alone. -// -// We thus allow backends to set a max size above which an op will not be -// merged. The input+output bytes of at least one dot must be below the -// threshold otherwise we won't merge. (We don't require that both dots be -// below the threshold because backends likely want to allow merging a "small" -// dot into a "large" dot while preventing two large dots from being merged.) -// -// Will skip gemms with more than one non-contracting dimension in the dot -// operands to be concatenated. -class DotMerger : public HloModulePass { - public: - explicit DotMerger(int64_t max_size_to_merge) - : max_size_to_merge_(max_size_to_merge) {} - - absl::string_view name() const override { return "dot-merger"; } - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - int64_t max_size_to_merge_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/dot_merger.h" #endif // XLA_SERVICE_DOT_MERGER_H_ diff --git a/third_party/xla/xla/service/dump.cc b/third_party/xla/xla/service/dump.cc index 3aa3a8862011a3..7e819d8c039bd0 100644 --- a/third_party/xla/xla/service/dump.cc +++ b/third_party/xla/xla/service/dump.cc @@ -50,10 +50,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo_graph_dumper.h" #include "xla/service/hlo_proto_util.h" +#include "xla/tsl/lib/io/zlib_compression_options.h" +#include "xla/tsl/lib/io/zlib_outputbuffer.h" #include "xla/tsl/lib/strings/proto_serialization.h" #include "xla/util.h" -#include "tsl/lib/io/zlib_compression_options.h" -#include "tsl/lib/io/zlib_outputbuffer.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/file_system.h" @@ -118,6 +118,7 @@ struct CanonicalDebugOptions { dump_module_metadata(opts.xla_dump_module_metadata()), dump_compress_protos(opts.xla_dump_compress_protos()), dump_hlo_metadata(!opts.xla_dump_disable_metadata()), + dump_fdo_profiles(opts.xla_gpu_experimental_dump_fdo_profiles()), dump_as_long_text(opts.xla_dump_hlo_as_long_text()), dump_mlir_pretty_form(opts.xla_dump_enable_mlir_pretty_form()), dump_large_constants(opts.xla_dump_large_constants()), @@ -236,6 +237,7 @@ struct CanonicalDebugOptions { bool dump_module_metadata; bool dump_compress_protos; bool dump_hlo_metadata; + bool dump_fdo_profiles; bool dump_as_long_text; bool dump_mlir_pretty_form; bool dump_large_constants; @@ -460,13 +462,17 @@ static std::vector DumpHloModuleImpl( file_paths.push_back(DumpToFileInDirOrStdoutImpl( StrCat(filename, ".txt"), module.ToString(print_options), opts)); if (buffer_assn) { - DataProducer data_producer; - data_producer.Append([&] { return buffer_assn->ToString(); }); - data_producer.Append([&] { return "\n\n"; }); - data_producer.Append( + DataProducer buffer_assignment; + buffer_assignment.Append([&] { return buffer_assn->ToString(); }); + buffer_assignment.Append([&] { return "\n\n"; }); + buffer_assignment.Append( [&] { return buffer_assn->hlo_live_range().ToString(); }); file_paths.push_back(DumpToFileInDirOrStdoutImpl( - StrCat(filename, "-buffer-assignment.txt"), data_producer, opts)); + StrCat(filename, "-buffer-assignment.txt"), buffer_assignment, opts)); + DataProducer summary_report; + summary_report.Append([&] { return buffer_assn->MemoryUsageReport(); }); + file_paths.push_back(DumpToFileInDirOrStdoutImpl( + StrCat(filename, "-memory-usage-report.txt"), summary_report, opts)); } } @@ -523,6 +529,12 @@ static std::vector DumpHloModuleImpl( } } + if (opts.dump_fdo_profiles) { + file_paths.push_back( + DumpToFileInDirImpl(StrFormat("%s.fdo_profile", filename), + module.config().fdo_profile(), opts)); + } + // Special case for rendering graphs as URLs. We'll dump them to a file // because why not, but we always log them to stdout as well. if (opts.dump_as_url) { diff --git a/third_party/xla/xla/service/dump_test.cc b/third_party/xla/xla/service/dump_test.cc index 6df547d96fdfce..7d9e4c794a3141 100644 --- a/third_party/xla/xla/service/dump_test.cc +++ b/third_party/xla/xla/service/dump_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include "absl/strings/match.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" #include "tsl/platform/env.h" @@ -151,5 +151,35 @@ TEST(DumpTest, DumpProtobufToFileWhenDisabled) { EXPECT_THAT(matches, IsEmpty()); } +TEST(DumpTest, DumpFdoProfileToFileWhenEnabled) { + std::string fdo_profile = "fdo_profile"; + HloModuleConfig config; + *config.mutable_fdo_profile() = fdo_profile; + DebugOptions options = config.debug_options(); + auto env = tsl::Env::Default(); + std::string dump_dir; + ASSERT_TRUE(env->LocalTempFilename(&dump_dir)); + options.set_xla_dump_to(dump_dir); + options.set_xla_gpu_experimental_dump_fdo_profiles(true); + config.set_debug_options(options); + const char* kModuleStr = R"( + HloModule m + test { + p0 = s32[11] parameter(0) + c = s32[11] constant({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) + ROOT x = s32[11] multiply(p0, c) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, + ParseAndReturnUnverifiedModule(kModuleStr, config)); + std::string dump_name = "dump"; + auto paths = DumpHloModuleIfEnabled(*m, dump_name); + EXPECT_EQ(paths.size(), 2); + + std::string data; + EXPECT_TRUE(ReadFileToString(env, paths[1], &data).ok()); + EXPECT_TRUE(absl::StrContains(data, fdo_profile)); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/dynamic_dimension_inference.cc b/third_party/xla/xla/service/dynamic_dimension_inference.cc index 683ab191992a66..97436ca78d229e 100644 --- a/third_party/xla/xla/service/dynamic_dimension_inference.cc +++ b/third_party/xla/xla/service/dynamic_dimension_inference.cc @@ -39,6 +39,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/comparison_util.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/dynamic_parameter_binding.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -52,7 +53,6 @@ limitations under the License. #include "xla/service/call_inliner.h" #include "xla/service/dynamic_window_utils.h" #include "xla/service/hlo_creation_utils.h" -#include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_value.h" #include "xla/service/tuple_util.h" #include "xla/service/while_util.h" @@ -2461,7 +2461,9 @@ absl::StatusOr DynamicDimensionInferenceVisitor::RequiresPadToStatic( return true; } if (use.instruction->opcode() != HloOpcode::kCustomCall || - use.instruction->custom_call_target() != "PadToStatic") { + !use.instruction->IsCustomCall({"PadToStatic", "Sharding", + "SPMDShardToFullShape", + "SPMDFullToShardShape"})) { if (parent_->op_supports_dynamism_handler_ == nullptr) { return true; } diff --git a/third_party/xla/xla/service/dynamic_dimension_inference_test.cc b/third_party/xla/xla/service/dynamic_dimension_inference_test.cc index 9dc9de161aa4bd..ad4d0648528ca1 100644 --- a/third_party/xla/xla/service/dynamic_dimension_inference_test.cc +++ b/third_party/xla/xla/service/dynamic_dimension_inference_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "xla/service/dynamic_dimension_inference.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -35,8 +35,6 @@ limitations under the License. #include "tsl/platform/statusor.h" #include "tsl/platform/test_benchmark.h" -namespace op = xla::testing::opcode_matchers; - namespace xla { namespace { diff --git a/third_party/xla/xla/service/dynamic_dimension_simplifier.h b/third_party/xla/xla/service/dynamic_dimension_simplifier.h index 5fd7eedbf53643..0824118cb48bb5 100644 --- a/third_party/xla/xla/service/dynamic_dimension_simplifier.h +++ b/third_party/xla/xla/service/dynamic_dimension_simplifier.h @@ -15,26 +15,7 @@ limitations under the License. #ifndef XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_ #define XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_ -#include - -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// This pass simplifies operations on dynamic dimension sizes so that it can be -// easily analyzed by later passes. -class DynamicDimensionSimplifier : public HloModulePass { - public: - absl::string_view name() const override { - return "dynamic-dimension-simplifier"; - } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/dynamic_dimension_simplifier.h" #endif // XLA_SERVICE_DYNAMIC_DIMENSION_SIMPLIFIER_H_ diff --git a/third_party/xla/xla/service/dynamic_index_splitter.h b/third_party/xla/xla/service/dynamic_index_splitter.h index 87bf0a95bc3fd6..670d297da852ce 100644 --- a/third_party/xla/xla/service/dynamic_index_splitter.h +++ b/third_party/xla/xla/service/dynamic_index_splitter.h @@ -16,25 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_DYNAMIC_INDEX_SPLITTER_H_ #define XLA_SERVICE_DYNAMIC_INDEX_SPLITTER_H_ -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// Convert R1 index operands to DynamicSlice and DynamicUpdateSlice ops into -// separate scalars. -class DynamicIndexSplitter : public HloModulePass { - public: - DynamicIndexSplitter() = default; - absl::string_view name() const override { return "dynamic-index-splitter"; } - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/expanders/dynamic_index_splitter.h" #endif // XLA_SERVICE_DYNAMIC_INDEX_SPLITTER_H_ diff --git a/third_party/xla/xla/service/dynamic_padder.cc b/third_party/xla/xla/service/dynamic_padder.cc index 8d5b6005fb3849..47ad65f3634ae0 100644 --- a/third_party/xla/xla/service/dynamic_padder.cc +++ b/third_party/xla/xla/service/dynamic_padder.cc @@ -32,8 +32,8 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "xla/client/xla_builder.h" #include "xla/comparison_util.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/dynamic_parameter_binding.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -42,12 +42,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/literal_util.h" #include "xla/service/call_graph.h" #include "xla/service/dynamic_dimension_inference.h" #include "xla/service/dynamic_window_utils.h" #include "xla/service/hlo_creation_utils.h" -#include "xla/service/hlo_dce.h" #include "xla/service/pattern_matcher.h" #include "xla/service/shape_inference.h" #include "xla/service/tuple_util.h" @@ -2021,6 +2021,11 @@ absl::Status DynamicShapeRemovingVisitor::HandleCustomCall( // nature they support dynamic lowering. return absl::OkStatus(); } + if (hlo->IsCustomCall( + {"Sharding", "SPMDShardToFullShape", "SPMDFullToShardShape"})) { + // Sharding ops are purely symbolic. + return absl::OkStatus(); + } return DefaultAction(hlo); } @@ -2235,7 +2240,6 @@ absl::StatusOr DynamicPadder::Run( // the output tensor to be in dynamic form. bool require_dynamic_output = options_.slice_dynamic_output && computation == module->entry_computation(); - changed |= require_dynamic_output; TF_ASSIGN_OR_RETURN(bool c, DynamicShapeRemovingVisitor::Run( computation, options_.op_supports_dynamism_handler, diff --git a/third_party/xla/xla/service/dynamic_padder_test.cc b/third_party/xla/xla/service/dynamic_padder_test.cc index 972bc38ae8c40b..83acce7980d4f2 100644 --- a/third_party/xla/xla/service/dynamic_padder_test.cc +++ b/third_party/xla/xla/service/dynamic_padder_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/log/check.h" @@ -26,41 +27,37 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_replace.h" #include "absl/types/span.h" -#include "xla/client/xla_builder.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" +#include "xla/hlo/transforms/simplifiers/dynamic_dimension_simplifier.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/service/algebraic_simplifier.h" #include "xla/service/dynamic_dimension_inference.h" -#include "xla/service/dynamic_dimension_simplifier.h" -#include "xla/service/hlo_dce.h" -#include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/service/tuple_simplifier.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/status_macros.h" #include "xla/test.h" #include "xla/test_helpers.h" -#include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/literal_test_util.h" #include "xla/tests/llvm_irgen_test_base.h" #include "xla/tests/test_macros.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" -#include "tsl/platform/test_benchmark.h" -#include "tsl/protobuf/error_codes.pb.h" namespace xla { namespace { @@ -2394,5 +2391,22 @@ ENTRY gds { EXPECT_TRUE(status.ok()); } +TEST_F(DynamicPadderTest, ShardingDynamicShapes) { + constexpr std::string_view hlo = R"( +HloModule main + +ENTRY main { + token.0 = after-all() + infeed_tuple.0 = (s32[<=32], token[]) infeed(token.0), sharding={{manual}, {manual}} + infeed.0 = get-tuple-element(infeed_tuple.0), index=0 + ROOT sharding.0 = s32[<=32] custom-call(infeed.0), custom_call_target="Sharding", sharding={manual} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + RunPadder(/*slice_dynamic_output=*/true)); + EXPECT_FALSE(changed); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/dynamic_update_slice_test.cc b/third_party/xla/xla/service/dynamic_update_slice_test.cc index eb8932a9478094..96298fb6437ffd 100644 --- a/third_party/xla/xla/service/dynamic_update_slice_test.cc +++ b/third_party/xla/xla/service/dynamic_update_slice_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "xla/execution_options_util.h" -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/status_macros.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" diff --git a/third_party/xla/xla/service/eigh_expander.h b/third_party/xla/xla/service/eigh_expander.h index 074e09305354d9..5ef10cffe0bbcc 100644 --- a/third_party/xla/xla/service/eigh_expander.h +++ b/third_party/xla/xla/service/eigh_expander.h @@ -16,32 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_EIGH_EXPANDER_H_ #define XLA_SERVICE_EIGH_EXPANDER_H_ -#include "absl/container/flat_hash_map.h" -#include "xla/client/xla_builder.h" -#include "xla/service/op_expander_pass.h" - -namespace xla { - -class EighExpander : public OpExpanderPass { - public: - absl::string_view name() const override { return "eigh_expander"; } - - protected: - bool InstructionMatchesPattern(HloInstruction* instruction) override; - - absl::StatusOr ExpandInstruction( - HloInstruction* instruction) override; - - virtual XlaOp BuildEigh(XlaOp a, bool lower, int64_t max_iter, float tol, - bool sort_eigenvalues); - - absl::Status SortByEigenvalues(XlaOp& v, XlaOp& w); - - private: - // Mapping from op signatures to existing computations. - absl::flat_hash_map computation_cache_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/expanders/eigh_expander.h" #endif // XLA_SERVICE_EIGH_EXPANDER_H_ diff --git a/third_party/xla/xla/service/elemental_ir_emitter.cc b/third_party/xla/xla/service/elemental_ir_emitter.cc index 3d89cd967f6342..ae70a60e44cbac 100644 --- a/third_party/xla/xla/service/elemental_ir_emitter.cc +++ b/third_party/xla/xla/service/elemental_ir_emitter.cc @@ -64,11 +64,6 @@ limitations under the License. namespace xla { using absl::StrCat; -using llvm::PatternMatch::m_BitCast; -using llvm::PatternMatch::m_Intrinsic; -using llvm::PatternMatch::m_Select; -using llvm::PatternMatch::m_Value; -using llvm::PatternMatch::match; using llvm_ir::IrArray; using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; @@ -220,6 +215,90 @@ absl::StatusOr EmitReducePrecisionIR( return result; } +template +llvm::Value* handle_halfway_points_F16ToF8(llvm::Value* f16_abs_bits, + llvm::Value* f8_bits, + llvm::IRBuilder<>* b) { + using llvm::APInt; + using llvm::Value; + static_assert(3 <= f8_exponent_bits && f8_exponent_bits <= 4); + + llvm::IntegerType* i8_type = b->getInt8Ty(); + llvm::IntegerType* i16_type = b->getInt16Ty(); + auto i8_const = [i8_type](int val) { + return llvm::ConstantInt::get(i8_type, val); + }; + auto i16_const = [i16_type](int val) { + return llvm::ConstantInt::get(i16_type, val); + }; + // F16 values that are halfway between denormal F8 values. This is used to + // determine how to round to denormal F8 values. + const int halfway_points_e4[8] = { + 0x1400, // 0x1.0p-10 ; halfway between [0/8 * 2^-6, 1/8 * 2^-6] + 0x1A00, // 0x1.8p-9 ; halfway between [1/8 * 2^-6, 2/8 * 2^-6] + 0x1D00, // 0x1.4p-8 ; halfway between [2/8 * 2^-6, 3/8 * 2^-6] + 0x1F00, // 0x1.Cp-8 ; halfway between [3/8 * 2^-6, 4/8 * 2^-6] + 0x2080, // 0x1.2p-7 ; halfway between [4/8 * 2^-6, 5/8 * 2^-6] + 0x2180, // 0x1.6p-7 ; halfway between [5/8 * 2^-6, 6/8 * 2^-6] + 0x2280, // 0x1.Ap-7 ; halfway between [6/8 * 2^-6, 7/8 * 2^-6] + 0x2380, // 0x1.Ep-7 ; halfway between [7/8 * 2^-6, 8/8 * 2^-6] + }; + + const int halfway_points_e3[16] = { + 0x2000, // 0x1.0p-7; halfway between [0/16 * 2^-2, 1/16 * 2^-2] + 0x2600, // 0x1.8p-6; halfway between [1/16 * 2^-2, 2/16 * 2^-2] + 0x2900, // 0x1.4p-5; halfway between [2/16 * 2^-2, 3/16 * 2^-2] + 0x2B00, // 0x1.Cp-5; halfway between [3/16 * 2^-2, 4/16 * 2^-2] + 0x2C80, // 0x1.2p-4; halfway between [4/16 * 2^-2, 5/16 * 2^-2] + 0x2D80, // 0x1.6p-4; halfway between [5/16 * 2^-2, 6/16 * 2^-2] + 0x2E80, // 0x1.Ap-4; halfway between [6/16 * 2^-2, 7/16 * 2^-2] + 0x2F80, // 0x1.Ep-4; halfway between [7/16 * 2^-2, 8/16 * 2^-2] + 0x3040, // 0x1.1p-3; halfway between [8/16 * 2^-2, 9/16 * 2^-2] + 0x30C0, // 0x1.3p-3; halfway between [9/16 * 2^-2, 10/16 * 2^-2] + 0x3140, // 0x1.5p-3; halfway between [10/16 * 2^-2, 11/16 * 2^-2] + 0x31C0, // 0x1.7p-3; halfway between [11/16 * 2^-2, 12/16 * 2^-2] + 0x3240, // 0x1.9p-3; halfway between [12/16 * 2^-2, 13/16 * 2^-2] + 0x32C0, // 0x1.Bp-3; halfway between [13/16 * 2^-2, 14/16 * 2^-2] + 0x3340, // 0x1.Dp-3; halfway between [14/16 * 2^-2, 15/16 * 2^-2] + 0x33C0, // 0x1.Fp-3; halfway between [15/16 * 2^-2, 16/16 * 2^-2] + }; + + const int* halfway_points; + int arr_sz; + if constexpr (f8_exponent_bits == 4) { + halfway_points = halfway_points_e4; + arr_sz = 8; + } else if constexpr (f8_exponent_bits == 3) { + halfway_points = halfway_points_e3; + arr_sz = 16; + } + + // Handle case where output is denormal. If we're rounding to a denormal + // value, ignore the current value of f8_bits and set it to the correct + // denormal value. We emit the equivalent of the following: + // + // if (f16_abs_bits <= halfway_points[0]) { + // f8_bits = 0; + // } else if (f16_abs_bits < halfway_points[1]) { + // f8_bits = 1; + // } else if (f16_abs_bits <= halfway_points[2]) { + // ... // More if-else statements. The comparisons alternate between <= + // ... // and < to handle round-to-even properly. + // } else if (f16_abs_bits < halfway_points[7]) { + // f8_bits = 7; + // } + for (int i = arr_sz - 1; i >= 0; i--) { + Value* comparison; + if (i % 2 == 0) { + comparison = b->CreateICmpULE(f16_abs_bits, i16_const(halfway_points[i])); + } else { + comparison = b->CreateICmpULT(f16_abs_bits, i16_const(halfway_points[i])); + } + f8_bits = b->CreateSelect(comparison, i8_const(i), f8_bits); + } + return f8_bits; +} + absl::StatusOr EmitF16ToF8e5m2(llvm::Value* f16_value, llvm::IRBuilder<>* b) { TF_ASSIGN_OR_RETURN( @@ -242,6 +321,223 @@ llvm::Value* EmitF8e5m2ToF16(llvm::Value* f8_value, llvm::IRBuilder<>* b) { return b->CreateBitCast(shifted, b->getHalfTy()); } +template +absl::StatusOr EmitF16ToF8e(llvm::Value* f16_value, + llvm::IRBuilder<>* b) { + static_assert(3 <= f8_exponent_bits && f8_exponent_bits <= 4); + constexpr int f8_mantissa_bits = 7 - f8_exponent_bits; + using llvm::APInt; + using llvm::Value; + + llvm::IntegerType* i8_type = b->getInt8Ty(); + llvm::IntegerType* i16_type = b->getInt16Ty(); + auto i16_const = [i16_type](int val) { + return llvm::ConstantInt::get(i16_type, val); + }; + + // Cast the input value to an integer for bitwise manipulation. Get the + // absolute value of the input value. + // f16_as_int = bitcast(f16_value, int) + // f16_abs_bits = f16_as_int & 0x7FFF + Value* f16_as_int = b->CreateBitCast(f16_value, i16_type); + llvm::Value* f16_abs_bits = b->CreateAnd(f16_as_int, i16_const(0x7FFF)); + + // Get the sign. + // f8_sign = (f16_as_int & 0x8000) >> 8 + Value* f16_sign = b->CreateAnd(f16_as_int, i16_const(0x8000)); + f16_sign = b->CreateLShr(f16_sign, i16_const(8)); + Value* f8_sign = b->CreateTrunc(f16_sign, i8_type); + + // Truncate the mantissa to f8 mantissa bits and exponent to f8 exponent bits + // Denormal values are not handled properly here and are + // dealt with later in this function. + absl::StatusOr f16_reduced_statusor = EmitReducePrecisionIR( + /*src_ty=*/F16, f16_value, + /*dest_exponent_bits=*/f8_exponent_bits, + /*dest_mantissa_bits=*/f8_mantissa_bits, + /*quiet_nans=*/true, b); + CHECK_OK(f16_reduced_statusor.status()); // Crash OK + Value* f16_reduced = f16_reduced_statusor.value(); + f16_reduced = b->CreateBitCast(f16_reduced, i16_type); + + // Remove the sign bit. + // f16_reduced = f16_reduced & 0x7FFF + f16_reduced = b->CreateAnd(f16_reduced, i16_const(0x7FFF)); + + // F16 inf in binary: 0 11111 0000000000 + constexpr int f16_inf_value = 0x7C00; + constexpr int f8_bias = (1 << (f8_exponent_bits - 1)) - 1; + constexpr int exponent_bias_difference = 15 - f8_bias; + constexpr int f16_mantissa_bits = 10; // e5m10 + constexpr int mantissa_bits_difference = f16_mantissa_bits - f8_mantissa_bits; + constexpr int min_normal_value = (exponent_bias_difference + 1) + << f16_mantissa_bits; + + // Round values smaller than the smallest F8 normal value up to the smallest + // F8 normal value. The case where we round to a denormal value is handled + // later. + // f16_reduced = max(f16_reduced, min_normal_value) + f16_reduced = b->CreateSelect( + b->CreateICmpULT(f16_reduced, i16_const(min_normal_value)), + i16_const(min_normal_value), f16_reduced); + + // Adjust the exponent by subtracting the difference in exponent bias: + // f16_reduced -= (exponent_bias_difference << f16_mantissa_bits) + // For infinity/NaN values, subtract twice the difference in exponent bias + // to ensure the leading exponent bit(s) of f16_reduced are set to zero. + f16_reduced = b->CreateSub( + f16_reduced, + b->CreateSelect( + b->CreateICmpULT(f16_reduced, i16_const(f16_inf_value)), + i16_const(exponent_bias_difference << f16_mantissa_bits), + i16_const(exponent_bias_difference << (f16_mantissa_bits + 1)))); + + // Shift to convert to F8. + // f16_reduced = f16_reduced >> mantissa_bits_difference; + f16_reduced = b->CreateLShr(f16_reduced, i16_const(mantissa_bits_difference)); + + Value* f8_bits = b->CreateTrunc(f16_reduced, i8_type); + + // Handle F16 values that are halfway between denormal F8 values. + f8_bits = + handle_halfway_points_F16ToF8(f16_abs_bits, f8_bits, b); + + // Set the sign bit. + // f8_bits |= f8_sign + f8_bits = b->CreateOr(f8_bits, f8_sign); + return f8_bits; +} + +template +llvm::Value* EmitToF16F8e(llvm::Value* f8_value, llvm::IRBuilder<>* b) { + using llvm::APInt; + using llvm::Value; + static_assert(3 <= f8_exponent_bits && f8_exponent_bits <= 4); + + llvm::IntegerType* i8_type = b->getInt8Ty(); + llvm::IntegerType* i16_type = b->getInt16Ty(); + auto i8_const = [i8_type](int val) { + return llvm::ConstantInt::get(i8_type, val); + }; + auto i16_const = [i16_type](int val) { + return llvm::ConstantInt::get(i16_type, val); + }; + + // Map from F8 denormal value to F16 value. + const int f8_denormal_to_f16_e4[8] = { + 0x0000, // 0 + 0x1800, // 1/8 * 2^-6 + 0x1C00, // 2/8 * 2^-6 + 0x1E00, // 3/8 * 2^-6 + 0x2000, // 4/8 * 2^-6 + 0x2100, // 5/8 * 2^-6 + 0x2200, // 6/8 * 2^-6 + 0x2300, // 7/8 * 2^-6 + }; + + // Map from F8 denormal value to F16 value. + const int f8_denormal_to_f16_e3[16] = { + 0x0000, // 0 + 0x2400, // 1/16 * 2^-2 + 0x2800, // 2/16 * 2^-2 + 0x2A00, // 3/16 * 2^-2 + 0x2C00, // 4/16 * 2^-2 + 0x2D00, // 5/16 * 2^-2 + 0x2E00, // 6/16 * 2^-2 + 0x2F00, // 7/16 * 2^-2 + 0x3000, // 8/16 * 2^-2 + 0x3080, // 9/16 * 2^-2 + 0x3100, // 10/16 * 2^-2 + 0x3180, // 11/16 * 2^-2 + 0x3200, // 12/16 * 2^-2 + 0x3280, // 13/16 * 2^-2 + 0x3300, // 14/16 * 2^-2 + 0x3380, // 15/16 * 2^-2 + }; + + // Cast the input value to an integer for bitwise manipulation. Get the + // absolute value of the input value. + // f8_as_int = bitcast(f16_value, int) + // f8_abs_bits = f8_as_int & 0x7F + Value* f8_as_int = b->CreateBitCast(f8_value, i8_type); + Value* f8_abs_bits = b->CreateAnd(f8_as_int, i8_const(0x7F)); + + // We assume below that the value is neither NaN nor denormal. If it NaN or + // denormal, the output is set to NaN or zero at the end using Select + // instructions. + + // Get the sign: + // f16_sign = (f8_as_int & 0x80) << 8 + Value* f8_sign = b->CreateAnd(f8_as_int, i8_const(0x80)); + Value* f16_sign = b->CreateZExt(f8_sign, i16_type); + f16_sign = b->CreateShl(f16_sign, i16_const(8)); + + int exponent_mask; + const int* f8_denormal_to_f16; + int f8_denormal_size; + if constexpr (f8_exponent_bits == 4) { + exponent_mask = 0x78; + f8_denormal_to_f16 = f8_denormal_to_f16_e4; + f8_denormal_size = 8; + } else if constexpr (f8_exponent_bits == 3) { + exponent_mask = 0x70; + f8_denormal_to_f16 = f8_denormal_to_f16_e3; + f8_denormal_size = 16; + } + constexpr int f8_bias = (1 << (f8_exponent_bits - 1)) - 1; + constexpr int exponent_bias_difference = 15 - f8_bias; + constexpr int f16_mantissa_bits = 10; // e5m10 + constexpr int f8_mantissa_bits = 7 - f8_exponent_bits; + constexpr int mantissa_bits_difference = f16_mantissa_bits - f8_mantissa_bits; + constexpr int f8_mantissa_mask = (1 << f8_mantissa_bits) - 1; + + // Get the exponent: + // f8_exponent = (f8_as_int & exponent_mask) >> f8_mantissa_bits + Value* f8_exponent_bits_v = b->CreateAnd(f8_as_int, i8_const(exponent_mask)); + Value* f8_exponent = + b->CreateLShr(f8_exponent_bits_v, i8_const(f8_mantissa_bits)); + + // Adjust the exponent by adding the difference in exponent bias: + // f16_exponent = (f8_exponent + exponent_bias_difference) + // << f16_mantissa_bits + Value* f16_exponent = + b->CreateAdd(f8_exponent, i8_const(exponent_bias_difference)); + f16_exponent = b->CreateZExt(f16_exponent, i16_type); + f16_exponent = b->CreateShl(f16_exponent, i16_const(f16_mantissa_bits)); + + // Set output exponent to 11111 if input exponent is 111 (Inf or NaN) + // 0.11111.0000000000 is 0x7C00 + Value* is_exp_1111 = + b->CreateICmpEQ(f8_exponent_bits_v, i8_const(exponent_mask)); + f16_exponent = b->CreateSelect(is_exp_1111, i16_const(0x7C00), f16_exponent); + + // Get the mantissa: + // f16_mantissa = (f8_mantissa & f8_mantissa_mask) + // << mantissa_bits_difference + Value* f8_mantissa = b->CreateAnd(f8_as_int, i8_const(f8_mantissa_mask)); + Value* f16_mantissa = b->CreateZExt(f8_mantissa, i16_type); + f16_mantissa = + b->CreateShl(f16_mantissa, i16_const(mantissa_bits_difference)); + + // Combine the exponent and mantissa: + // f16_as_int = f16_exponent | f16_mantissa + Value* f16_as_int = b->CreateOr(f16_exponent, f16_mantissa); + + // If the F8 value is denormal, use the map above to determine the correct F16 + // value. + // if (f8_abs_bits < 8) { f16_as_int = f8_denormal_to_f16[f8_abs_bits]; } + for (int i = 0; i < f8_denormal_size; i++) { + Value* is_denormal_value = b->CreateICmpEQ(f8_abs_bits, i8_const(i)); + f16_as_int = b->CreateSelect(is_denormal_value, + i16_const(f8_denormal_to_f16[i]), f16_as_int); + } + + // Set the sign bit. + // f16_as_int |= f16_sign + f16_as_int = b->CreateOr(f16_as_int, f16_sign); + return b->CreateBitCast(f16_as_int, b->getHalfTy()); +} + llvm::Value* EmitF16ToF8e4m3fn(llvm::Value* f16_value, llvm::IRBuilder<>* b) { using llvm::APInt; using llvm::Value; @@ -277,7 +573,7 @@ llvm::Value* EmitF16ToF8e4m3fn(llvm::Value* f16_value, llvm::IRBuilder<>* b) { /*dest_exponent_bits=*/5, /*dest_mantissa_bits=*/3, /*quiet_nans=*/false, b); - CHECK(f16_reduced_statusor.ok()); // Crash OK + CHECK_OK(f16_reduced_statusor.status()); // Crash OK Value* f16_reduced = f16_reduced_statusor.value(); f16_reduced = b->CreateBitCast(f16_reduced, i16_type); @@ -297,6 +593,7 @@ llvm::Value* EmitF16ToF8e4m3fn(llvm::Value* f16_value, llvm::IRBuilder<>* b) { i16_const(min_normal_value), f16_reduced); constexpr int exponent_bias_difference = 15 - 7; + constexpr int f8_exponent_bits = 4; constexpr int f16_mantissa_bits = 10; constexpr int f8_mantissa_bits = 3; constexpr int mantissa_bits_difference = f16_mantissa_bits - f8_mantissa_bits; @@ -322,42 +619,9 @@ llvm::Value* EmitF16ToF8e4m3fn(llvm::Value* f16_value, llvm::IRBuilder<>* b) { b->CreateICmpUGT(f16_abs_bits, i16_const(max_finite_value)), i8_const(0x7F), f8_bits); - // F16 values that are halfway between denormal F8 values. This is used to - // determine how to round to denormal F8 values. - const int halfway_points[8] = { - 0x1400, // 2**-10; halfway between [0, 2**-9] - 0x1A00, // 1.5 * 2**-9; halfway between [2**-9, 2**-8] - 0x1D00, // 1.25 * 2**-8; halfway between [2**-8, 1.5 * 2**-8] - 0x1F00, // 1.75 * 2**-8; halfway between [1.5 * 2**-8, 2**-7] - 0x2080, // 1.125 * 2**-7; halfway between [2**-7, 1.25 * 2**-7] - 0x2180, // 1.375 * 2**-7; halfway between [1.25 * 2**-7, 1.5 * 2**-7] - 0x2280, // 1.625 * 2**-7; halfway between [1.5 * 2**-7, 1.75 * 2**-7] - 0x2380, // 1.875 * 2**-7; halfway between [1.75 * 2**-7, 2**-6] - }; - - // Handle case where output is denormal. If we're rounding to a denormal - // value, ignore the current value of f8_bits and set it to the correct - // denormal value. We emit the equivalent of the following: - // - // if (f16_abs_bits <= halfway_points[0]) { - // f8_bits = 0; - // } else if (f16_abs_bits < halfway_points[1]) { - // f8_bits = 1; - // } else if (f16_abs_bits <= halfway_points[2]) { - // ... // More if-else statements. The comparisons alternate between <= - // ... // and < to handle round-to-even properly. - // } else if (f16_abs_bits < halfway_points[7]) { - // f8_bits = 7; - // } - for (int i = ABSL_ARRAYSIZE(halfway_points) - 1; i >= 0; i--) { - Value* comparison; - if (i % 2 == 0) { - comparison = b->CreateICmpULE(f16_abs_bits, i16_const(halfway_points[i])); - } else { - comparison = b->CreateICmpULT(f16_abs_bits, i16_const(halfway_points[i])); - } - f8_bits = b->CreateSelect(comparison, i8_const(i), f8_bits); - } + // Handle F16 values that are halfway between denormal F8 values. + f8_bits = + handle_halfway_points_F16ToF8(f16_abs_bits, f8_bits, b); // Set the sign bit. // f8_bits |= f8_sign @@ -408,7 +672,7 @@ llvm::Value* EmitF8e4m3fnToF16(llvm::Value* f8_value, llvm::IRBuilder<>* b) { b->CreateLShr(f8_exponent_bits, i8_const(f8_mantissa_bits)); // Adjust the exponent by adding the difference in exponent bias: - // f16_exponent = (f8_exopnent + exponent_bias_difference) + // f16_exponent = (f8_exponent + exponent_bias_difference) // << f16_mantissa_bits Value* f16_exponent = b->CreateAdd(f8_exponent, i8_const(exponent_bias_difference)); @@ -435,13 +699,13 @@ llvm::Value* EmitF8e4m3fnToF16(llvm::Value* f8_value, llvm::IRBuilder<>* b) { // Map from F8 denormal value to F16 value. int f8_denormal_to_f16[8] = { 0x0000, // 0 - 0x1800, // 2**-9 - 0x1C00, // 2**-8 - 0x1E00, // 1.5 * 2**-8 - 0x2000, // 2**-7 - 0x2100, // 1.25 * 2**-7 - 0x2200, // 1.5 * 2**-7 - 0x2300, // 1.75 * 2**-7 + 0x1800, // 1/8 * 2^-6 + 0x1C00, // 2/8 * 2^-6 + 0x1E00, // 3/8 * 2^-6 + 0x2000, // 4/8 * 2^-6 + 0x2100, // 5/8 * 2^-6 + 0x2200, // 6/8 * 2^-6 + 0x2300, // 7/8 * 2^-6 }; // If the F8 value is denormal, use the map above to determine the correct F16 @@ -471,8 +735,8 @@ llvm::Value* EmitF16ToF8e4m3b11fnuz(llvm::Value* f16_value, auto type = f16_value->getType(); auto f16_abs_value = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {f16_value}, {type}, b); - auto f16_zero = llvm::ConstantFP::getZero(type); - auto is_zero = b->CreateFCmpOEQ(f16_abs_value, f16_zero); + auto f16_zero_or_underflow = llvm::ConstantFP::get(type, 0x1.004p-14); + auto is_zero = b->CreateFCmpOLT(f16_abs_value, f16_zero_or_underflow); auto f8_overflow_threshold = llvm::ConstantFP::get(type, 0x1.fp+4); auto no_overflow = b->CreateFCmpOLT(f16_abs_value, f8_overflow_threshold); @@ -604,6 +868,12 @@ absl::StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( b_), b_); } + if (to_type == F8E4M3) { + return EmitF16ToF8e<4>( + EmitIntegralToFloating(operand_value, from_type, F16, module_, + b_), + b_); + } if (to_type == F8E4M3FN) { return EmitF16ToF8e4m3fn( EmitIntegralToFloating(operand_value, from_type, F16, module_, @@ -623,6 +893,12 @@ absl::StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( b_), to_type, b_); } + if (to_type == F8E3M4) { + return EmitF16ToF8e<3>( + EmitIntegralToFloating(operand_value, from_type, F16, module_, + b_), + b_); + } return EmitIntegralToFloating(operand_value, from_type, to_type, module_, b_); } @@ -789,6 +1065,14 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return operand_value; } } + if (from_type == F8E4M3) { + TF_RET_CHECK(to_type != F8E4M3); + operand_value = EmitToF16F8e<4>(operand_value, b_); + from_type = F16; + if (from_type == to_type) { + return operand_value; + } + } if (from_type == F8E4M3FN) { TF_RET_CHECK(to_type != F8E4M3FN); operand_value = EmitF8e4m3fnToF16(operand_value, b_); @@ -817,6 +1101,14 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return operand_value; } } + if (from_type == F8E3M4) { + TF_RET_CHECK(to_type != F8E3M4); + operand_value = EmitToF16F8e<3>(operand_value, b_); + from_type = F16; + if (from_type == to_type) { + return operand_value; + } + } if (primitive_util::IsComplexType(to_type)) { PrimitiveType to_component_type = primitive_util::ComplexComponentType(to_type); @@ -844,6 +1136,14 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } return EmitF16ToF8e5m2(operand_value, b_); } + if (to_type == F8E4M3) { + // Cast to F16 first. Casts to F8E4M3 must be from F16. + if (from_type != F16) { + operand_value = b_->CreateFPCast( + operand_value, llvm_ir::PrimitiveTypeToIrType(F16, module_)); + } + return EmitF16ToF8e<4>(operand_value, b_); + } if (to_type == F8E4M3FN) { // Cast to F16 first. Casts to F8E4M3FN must be from F16. if (from_type != F16) { @@ -863,6 +1163,14 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( if (to_type == F8E5M2FNUZ || to_type == F8E4M3FNUZ) { return EmitFloatingToF8fnuz(from_type, operand_value, to_type, b_); } + if (to_type == F8E3M4) { + // Cast to F16 first. Casts to F8E3M4 must be from F16. + if (from_type != F16) { + operand_value = b_->CreateFPCast( + operand_value, llvm_ir::PrimitiveTypeToIrType(F16, module_)); + } + return EmitF16ToF8e<3>(operand_value, b_); + } if (to_type == PRED) { return b_->CreateZExt( FCmpUNE(operand_value, @@ -1398,6 +1706,9 @@ absl::StatusOr ElementalIrEmitter::EmitFloatBinaryOp( if (operand_type == F8E5M2) { lhs_value = EmitF8e5m2ToF16(lhs_value, b_); rhs_value = EmitF8e5m2ToF16(rhs_value, b_); + } else if (operand_type == F8E4M3) { + lhs_value = EmitToF16F8e<4>(lhs_value, b_); + rhs_value = EmitToF16F8e<4>(rhs_value, b_); } else if (operand_type == F8E4M3FN) { lhs_value = EmitF8e4m3fnToF16(lhs_value, b_); rhs_value = EmitF8e4m3fnToF16(rhs_value, b_); @@ -1408,6 +1719,9 @@ absl::StatusOr ElementalIrEmitter::EmitFloatBinaryOp( TF_ASSIGN_OR_RETURN( rhs_value, EmitF8fnuzToFloating(operand_type, rhs_value, F16, b_, module_)); + } else if (operand_type == F8E3M4) { + lhs_value = EmitToF16F8e<3>(lhs_value, b_); + rhs_value = EmitToF16F8e<3>(rhs_value, b_); } switch (op->comparison_direction()) { case ComparisonDirection::kEq: @@ -1844,7 +2158,6 @@ absl::StatusOr ElementalIrEmitter::EmitComplexRsqrt( llvm::Value* neg_one = llvm::ConstantFP::get(type, -1); llvm::Value* inf = llvm::ConstantFP::getInfinity(type); llvm::Value* nan = llvm::ConstantFP::getNaN(type); - // llvm::Value* neg_inf = llvm::ConstantFP::getInfinity(type, true); llvm::Value* a_signed_zero = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::copysign, {zero, a}, {a->getType()}, b_); llvm::Value* b_signed_zero = llvm_ir::EmitCallToIntrinsic( diff --git a/third_party/xla/xla/service/elemental_ir_emitter_test.cc b/third_party/xla/xla/service/elemental_ir_emitter_test.cc index 9ee2680065a26f..c0a0ddea66f920 100644 --- a/third_party/xla/xla/service/elemental_ir_emitter_test.cc +++ b/third_party/xla/xla/service/elemental_ir_emitter_test.cc @@ -19,8 +19,11 @@ limitations under the License. #include #include #include +#include #include +#include +#include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/IR/IRBuilder.h" @@ -36,6 +39,8 @@ limitations under the License. #include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/types.h" +#include "tsl/platform/ml_dtypes.h" #include "tsl/platform/statusor.h" namespace xla { @@ -68,7 +73,7 @@ class ElementalIrEmitterExecutionTest : public HloTestBase { class ElementalIrEmitterExecutionTestWithoutFastMinMax : public ElementalIrEmitterExecutionTest { protected: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = ElementalIrEmitterExecutionTest::GetDebugOptionsForTest(); debug_options.set_xla_cpu_enable_fast_min_max(false); @@ -77,6 +82,23 @@ class ElementalIrEmitterExecutionTestWithoutFastMinMax } }; +template +class ElementalIrEmitterExecutionTypedTest + : public ElementalIrEmitterExecutionTest { + protected: + const std::string& TypeName() { + return primitive_util::LowercasePrimitiveTypeName( + primitive_util::NativeToPrimitiveType()); + } +}; + +using FloatTypes = + ::testing::Types; + +TYPED_TEST_SUITE(ElementalIrEmitterExecutionTypedTest, FloatTypes); + XLA_TEST_F(ElementalIrEmitterExecutionTest, DotFusion) { const std::string hlo_text = R"( HloModule FusedDot @@ -229,473 +251,212 @@ XLA_TEST_F(ElementalIrEmitterExecutionTest, EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{(0.)})); } -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertFloatsToBF16) { - RunTypeConversionTest(R"( - HloModule convertToBF16 - ENTRY ConvertToBF16 - (f16_ f16[], f32_ f32[], f64_ f64[]) -> (bf16[], bf16[], bf16[]) { - f16_ = f16[] parameter(0) - f32_ = f32[] parameter(1) - f64_ = f64[] parameter(2) - converted_f16 = bf16[] convert(f16[] f16_) - converted_f32 = bf16[] convert(f32[] f32_) - converted_f64 = bf16[] convert(f64[] f64_) - ROOT tuple = (bf16[], bf16[], bf16[]) tuple(converted_f16, converted_f32, - converted_f64) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertSignedToBF16) { - RunTypeConversionTest(R"( - HloModule convertToBF16 - ENTRY ConvertToBF16 (s8_ s8[], s16_ s16[], s32_ s32[], s64_ s64[]) -> - (bf16[], bf16[], bf16[], bf16[]) { - s8_ = s8[] parameter(0) - s16_ = s16[] parameter(1) - s32_ = s32[] parameter(2) - s64_ = s64[] parameter(3) - converted_s8 = bf16[] convert(s8[] s8_) - converted_s16 = bf16[] convert(s16[] s16_) - converted_s32 = bf16[] convert(s32[] s32_) - converted_s64 = bf16[] convert(s64[] s64_) - ROOT tuple = (bf16[], bf16[], bf16[], bf16[]) tuple( - converted_s8, converted_s16, converted_s32, converted_s64) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertUnsignedToBF16) { - RunTypeConversionTest(R"( - HloModule convertToBF16 - ENTRY ConvertToBF16 (u8_ u8[], u16_ u16[], u32_ u32[], u64_ u64[]) -> - (bf16[], bf16[], bf16[], bf16[]) { - u8_ = u8[] parameter(0) - u16_ = u16[] parameter(1) - u32_ = u32[] parameter(2) - u64_ = u64[] parameter(3) - converted_u8 = bf16[] convert(u8[] u8_) - converted_u16 = bf16[] convert(u16[] u16_) - converted_u32 = bf16[] convert(u32[] u32_) - converted_u64 = bf16[] convert(u64[] u64_) - ROOT tuple = (bf16[], bf16[], bf16[], bf16[]) tuple( - converted_u8, converted_u16, converted_u32, converted_u64) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertBF16ToFloat) { - RunTypeConversionTest(R"( - HloModule convertFromBF16 - ENTRY ConvertFromBF16 - (to_f16 bf16[], to_f32 bf16[], to_f64 bf16[]) -> (f16[], f32[], f64[]) { - to_f16 = bf16[] parameter(0) - to_f32 = bf16[] parameter(1) - to_f64 = bf16[] parameter(2) - f16_ = f16[] convert(bf16[] to_f16) - f32_ = f32[] convert(bf16[] to_f32) - f64_ = f64[] convert(bf16[] to_f64) - ROOT tuple = (f16[], f32[], f64[]) tuple(f16_, f32_, f64_) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertBF16ToSigned) { - RunTypeConversionTest(R"( - HloModule convertFromBF16 - ENTRY ConvertFromBF16(to_s8 bf16[], to_s16 bf16[], to_s32 bf16[], - to_s64 bf16[]) -> (s8[], s16[], s32[], s64[]) { - to_s8 = bf16[] parameter(0) - to_s16 = bf16[] parameter(1) - to_s32 = bf16[] parameter(2) - to_s64 = bf16[] parameter(3) - s8_ = s8[] convert(bf16[] to_s8) - s16_ = s16[] convert(bf16[] to_s16) - s32_ = s32[] convert(bf16[] to_s32) - s64_ = s64[] convert(bf16[] to_s64) - ROOT tuple = (s8[], s16[], s32[], s64[]) tuple(s8_, s16_, s32_, s64_) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertBF16ToUnsigned) { - RunTypeConversionTest(R"( - HloModule convertFromBF16 - ENTRY ConvertFromBF16(to_u8 bf16[], to_u16 bf16[], to_u32 bf16[], - to_u64 bf16[]) -> (u8[], u16[], u32[], u64[]) { - to_u8 = bf16[] parameter(0) - to_u16 = bf16[] parameter(1) - to_u32 = bf16[] parameter(2) - to_u64 = bf16[] parameter(3) - u8_ = u8[] convert(bf16[] to_u8) - u16_ = u16[] convert(bf16[] to_u16) - u32_ = u32[] convert(bf16[] to_u32) - u64_ = u64[] convert(bf16[] to_u64) - ROOT tuple = (u8[], u16[], u32[], u64[]) tuple(u8_, u16_, u32_, u64_) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertBF16ToComplex) { - RunTypeConversionTest(R"( - HloModule convertFromBF16 - ENTRY ConvertFromBF16 - (to_c64 bf16[], to_c128 bf16[]) -> (c64[], c128[]) { - to_c64 = bf16[] parameter(0) - to_c128 = bf16[] parameter(1) - c64_ = c64[] convert(bf16[] to_c64) - c128_ = c128[] convert(bf16[] to_c128) - ROOT tuple = (c64[], c128[]) tuple(c64_, c128_) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, CompareBF16) { - constexpr char hlo_text[] = R"( - HloModule compareBF16 - ENTRY main { - p0 = bf16[4] parameter(0) - p1 = bf16[4] parameter(1) - ROOT cmp = pred[4] compare(p0, p1), direction=LT -})"; - - Literal lhs = LiteralUtil::CreateR1({1, 2, 3, 4}); - Literal rhs = LiteralUtil::CreateR1({4, 3, 2, 1}); - lhs = LiteralUtil::ConvertF32ToBF16(lhs); - rhs = LiteralUtil::ConvertF32ToBF16(rhs); - RunTest(hlo_text, {&lhs, &rhs}); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, IotaBF16) { - constexpr char hlo_text[] = R"( - HloModule IotaBF16 - ENTRY main { - ROOT iota_ = bf16[4] iota(), iota_dimension=0 +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, ConvertFloatsToFloat) { + auto tname = this->TypeName(); + if (std::is_same() || + std::is_same() || + std::is_same() || + std::is_same()) { + GTEST_SKIP() << "Skipping test for type " << tname; } - )"; - - RunTest(hlo_text, {}); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, BatchDotBF16) { - const char* const hlo_text = R"( - HloModule matmul - + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m ENTRY main { - x = bf16[8,16] parameter(0) - y = bf16[8,16,32] parameter(1) - ROOT dot = bf16[8,32] dot(x, y), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={1} + f16_ = f16[] parameter(0) + f32_ = f32[] parameter(1) + f64_ = f64[] parameter(2) + bf16_ = bf16[] parameter(3) + converted_f16 = ${tname}[] convert(f16_) + converted_f32 = ${tname}[] convert(f32_) + converted_f64 = ${tname}[] convert(f64_) + converted_bf16 = ${tname}[] convert(bf16_) + ROOT tuple = (${tname}[], ${tname}[], ${tname}[], ${tname}[]) tuple( + converted_f16, converted_f32, converted_f64, converted_bf16) } - )"; - HloModuleConfig config; - DebugOptions debug_options = GetDebugOptionsForTest(); - config.set_debug_options(debug_options); - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_text, config)); - EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5})); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertFloatsToF8E4FNUZ) { - RunTypeConversionTest(R"( - HloModule convertToF8E4FNUZ - ENTRY ConvertToF8E4FNUZ - (f16_ f16[], f32_ f32[], f64_ f64[], bf16_ bf16[]) -> (f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[]) { - f16_ = f16[] parameter(0) - f32_ = f32[] parameter(1) - f64_ = f64[] parameter(2) - bf16_ = bf16[] parameter(3) - converted_f16 = f8e4m3fnuz[] convert(f16[] f16_) - converted_f32 = f8e4m3fnuz[] convert(f32[] f32_) - converted_f64 = f8e4m3fnuz[] convert(f64[] f64_) - converted_bf16 = f8e4m3fnuz[] convert(bf16[] bf16_) - ROOT tuple = (f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[]) tuple( - converted_f16, converted_f32, converted_f64, converted_bf16) - } - )"); + )", + {{"${tname}", tname}}); + ElementalIrEmitterExecutionTest::RunTypeConversionTest(hlo_text); } -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertSignedToF8E4FNUZ) { - RunTypeConversionTest(R"( - HloModule convertToF8E4FNUZ - ENTRY ConvertToF8E4FNUZ (s8_ s8[], s16_ s16[], s32_ s32[], s64_ s64[]) -> - (f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[]) { +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, ConvertSignedToFloat) { + auto tname = this->TypeName(); + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m + ENTRY main { s8_ = s8[] parameter(0) s16_ = s16[] parameter(1) s32_ = s32[] parameter(2) s64_ = s64[] parameter(3) - converted_s8 = f8e4m3fnuz[] convert(s8[] s8_) - converted_s16 = f8e4m3fnuz[] convert(s16[] s16_) - converted_s32 = f8e4m3fnuz[] convert(s32[] s32_) - converted_s64 = f8e4m3fnuz[] convert(s64[] s64_) - ROOT tuple = (f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[]) tuple( + converted_s8 = ${tname}[] convert(s8_) + converted_s16 = ${tname}[] convert(s16_) + converted_s32 = ${tname}[] convert(s32_) + converted_s64 = ${tname}[] convert(s64_) + ROOT tuple = (${tname}[], ${tname}[], ${tname}[], ${tname}[]) tuple( converted_s8, converted_s16, converted_s32, converted_s64) } - )"); + )", + {{"${tname}", tname}}); + ElementalIrEmitterExecutionTest::RunTypeConversionTest(hlo_text); } -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertUnsignedToF8E4FNUZ) { - RunTypeConversionTest(R"( - HloModule convertToF8E4FNUZ - ENTRY ConvertToF8E4FNUZ (u8_ u8[], u16_ u16[], u32_ u32[], u64_ u64[]) -> - (f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[]) { +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, ConvertUnsignedToFloat) { + auto tname = this->TypeName(); + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m + ENTRY main { u8_ = u8[] parameter(0) u16_ = u16[] parameter(1) u32_ = u32[] parameter(2) u64_ = u64[] parameter(3) - converted_u8 = f8e4m3fnuz[] convert(u8[] u8_) - converted_u16 = f8e4m3fnuz[] convert(u16[] u16_) - converted_u32 = f8e4m3fnuz[] convert(u32[] u32_) - converted_u64 = f8e4m3fnuz[] convert(u64[] u64_) - ROOT tuple = (f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[], f8e4m3fnuz[]) tuple( + converted_u8 = ${tname}[] convert(u8_) + converted_u16 = ${tname}[] convert(u16_) + converted_u32 = ${tname}[] convert(u32_) + converted_u64 = ${tname}[] convert(u64_) + ROOT tuple = (${tname}[], ${tname}[], ${tname}[], ${tname}[]) tuple( converted_u8, converted_u16, converted_u32, converted_u64) } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E4FNUZToFloat) { - RunTypeConversionTest(R"( - HloModule convertFromF8E4FNUZ - ENTRY ConvertFromF8E4FNUZ - (to_f16 f8e4m3fnuz[], to_f32 f8e4m3fnuz[], to_f64 f8e4m3fnuz[], to_bf16 f8e4m3fnuz[]) -> (f16[], f32[], f64[], bf16[]) { - to_f16 = f8e4m3fnuz[] parameter(0) - to_f32 = f8e4m3fnuz[] parameter(1) - to_f64 = f8e4m3fnuz[] parameter(2) - to_bf16 = f8e4m3fnuz[] parameter(3) - f16_ = f16[] convert(f8e4m3fnuz[] to_f16) - f32_ = f32[] convert(f8e4m3fnuz[] to_f32) - f64_ = f64[] convert(f8e4m3fnuz[] to_f64) - bf16_ = bf16[] convert(f8e4m3fnuz[] to_f64) + )", + {{"${tname}", tname}}); + ElementalIrEmitterExecutionTest::RunTypeConversionTest(hlo_text); +} + +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, ConvertFloatToFloats) { + auto tname = this->TypeName(); + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m + ENTRY main { + to_f16 = ${tname}[] parameter(0) + to_f32 = ${tname}[] parameter(1) + to_f64 = ${tname}[] parameter(2) + to_bf16 = ${tname}[] parameter(3) + f16_ = f16[] convert(to_f16) + f32_ = f32[] convert(to_f32) + f64_ = f64[] convert(to_f64) + bf16_ = bf16[] convert(to_f64) ROOT tuple = (f16[], f32[], f64[], bf16[]) tuple(f16_, f32_, f64_, bf16_) } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E4FNUZToSigned) { - RunTypeConversionTest(R"( - HloModule convertFromF8E4FNUZ - ENTRY ConvertFromF8E4FNUZ(to_s8 f8e4m3fnuz[], to_s16 f8e4m3fnuz[], to_s32 f8e4m3fnuz[], - to_s64 f8e4m3fnuz[]) -> (s8[], s16[], s32[], s64[]) { - to_s8 = f8e4m3fnuz[] parameter(0) - to_s16 = f8e4m3fnuz[] parameter(1) - to_s32 = f8e4m3fnuz[] parameter(2) - to_s64 = f8e4m3fnuz[] parameter(3) - s8_ = s8[] convert(f8e4m3fnuz[] to_s8) - s16_ = s16[] convert(f8e4m3fnuz[] to_s16) - s32_ = s32[] convert(f8e4m3fnuz[] to_s32) - s64_ = s64[] convert(f8e4m3fnuz[] to_s64) + )", + {{"${tname}", tname}}); + ElementalIrEmitterExecutionTest::RunTypeConversionTest(hlo_text); +} + +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, ConvertFloatToSigned) { + auto tname = this->TypeName(); + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m + ENTRY main { + to_s8 = ${tname}[] parameter(0) + to_s16 = ${tname}[] parameter(1) + to_s32 = ${tname}[] parameter(2) + to_s64 = ${tname}[] parameter(3) + s8_ = s8[] convert(to_s8) + s16_ = s16[] convert(to_s16) + s32_ = s32[] convert(to_s32) + s64_ = s64[] convert(to_s64) ROOT tuple = (s8[], s16[], s32[], s64[]) tuple(s8_, s16_, s32_, s64_) } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E4FNUZToUnsigned) { - RunTypeConversionTest(R"( - HloModule convertFromF8E4FNUZ - ENTRY ConvertFromF8E4FNUZ(to_u8 f8e4m3fnuz[], to_u16 f8e4m3fnuz[], to_u32 f8e4m3fnuz[], - to_u64 f8e4m3fnuz[]) -> (u8[], u16[], u32[], u64[]) { - to_u8 = f8e4m3fnuz[] parameter(0) - to_u16 = f8e4m3fnuz[] parameter(1) - to_u32 = f8e4m3fnuz[] parameter(2) - to_u64 = f8e4m3fnuz[] parameter(3) - u8_ = u8[] convert(f8e4m3fnuz[] to_u8) - u16_ = u16[] convert(f8e4m3fnuz[] to_u16) - u32_ = u32[] convert(f8e4m3fnuz[] to_u32) - u64_ = u64[] convert(f8e4m3fnuz[] to_u64) + )", + {{"${tname}", tname}}); + ElementalIrEmitterExecutionTest::RunTypeConversionTest(hlo_text); +} + +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, ConvertFloatToUnsigned) { + auto tname = this->TypeName(); + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m + ENTRY main { + to_u8 = ${tname}[] parameter(0) + to_u16 = ${tname}[] parameter(1) + to_u32 = ${tname}[] parameter(2) + to_u64 = ${tname}[] parameter(3) + u8_ = u8[] convert(to_u8) + u16_ = u16[] convert(to_u16) + u32_ = u32[] convert(to_u32) + u64_ = u64[] convert(to_u64) ROOT tuple = (u8[], u16[], u32[], u64[]) tuple(u8_, u16_, u32_, u64_) } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E4FNUZToComplex) { - RunTypeConversionTest(R"( - HloModule convertFromF8E4FNUZ - ENTRY ConvertFromF8E4FNUZ - (to_c64 f8e4m3fnuz[], to_c128 f8e4m3fnuz[]) -> (c64[], c128[]) { - to_c64 = f8e4m3fnuz[] parameter(0) - to_c128 = f8e4m3fnuz[] parameter(1) - c64_ = c64[] convert(f8e4m3fnuz[] to_c64) - c128_ = c128[] convert(f8e4m3fnuz[] to_c128) + )", + {{"${tname}", tname}}); + ElementalIrEmitterExecutionTest::RunTypeConversionTest(hlo_text); +} + +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, ConvertFloatToComplex) { + auto tname = this->TypeName(); + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m + ENTRY main { + to_c64 = ${tname}[] parameter(0) + to_c128 = ${tname}[] parameter(1) + c64_ = c64[] convert(to_c64) + c128_ = c128[] convert(to_c128) ROOT tuple = (c64[], c128[]) tuple(c64_, c128_) } - )"); + )", + {{"${tname}", tname}}); + ElementalIrEmitterExecutionTest::RunTypeConversionTest(hlo_text); } -XLA_TEST_F(ElementalIrEmitterExecutionTest, CompareF8E4FNUZ) { - constexpr char hlo_text[] = R"( - HloModule compareF8E4FNUZ +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, CompareFloat) { + auto tname = this->TypeName(); + if (std::is_same()) { + GTEST_SKIP() << "Skipping test for type " << tname; + } + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m ENTRY main { - p0 = f8e4m3fnuz[4] parameter(0) - p1 = f8e4m3fnuz[4] parameter(1) + p0 = ${tname}[4] parameter(0) + p1 = ${tname}[4] parameter(1) ROOT cmp = pred[4] compare(p0, p1), direction=LT -})"; - - Literal lhs = LiteralUtil::CreateR1({1, 2, 3, 4}); - Literal rhs = LiteralUtil::CreateR1({4, 3, 2, 1}); - lhs = LiteralUtil::ConvertF32ToF8E4M3FNUZ(lhs); - rhs = LiteralUtil::ConvertF32ToF8E4M3FNUZ(rhs); - RunTest(hlo_text, {&lhs, &rhs}); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, IotaF8E4FNUZ) { - constexpr char hlo_text[] = R"( - HloModule IotaF8E4FNUZ +})", + {{"${tname}", tname}}); + Literal lhs = LiteralUtil::CreateR1( + {TypeParam(1.), TypeParam(2.), TypeParam(3.), TypeParam(4.)}); + Literal rhs = LiteralUtil::CreateR1( + {TypeParam(4.), TypeParam(4.), TypeParam(2.), TypeParam(1.)}); + ElementalIrEmitterExecutionTest::RunTest(hlo_text, {&lhs, &rhs}); +} + +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, IotaFloat) { + auto tname = this->TypeName(); + if (std::is_same() || + std::is_same() || + std::is_same() || + std::is_same() || + std::is_same()) { + GTEST_SKIP() << "Skipping test for type " << tname; + } + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule m ENTRY main { - ROOT iota_ = f8e4m3fnuz[4] iota(), iota_dimension=0 + ROOT iota_ = ${tname}[4] iota(), iota_dimension=0 } - )"; - - RunTest(hlo_text, {}); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertFloatsToF8E5FNUZ) { - RunTypeConversionTest(R"( - HloModule convertToF8E5FNUZ - ENTRY ConvertToF8E5FNUZ - (f16_ f16[], f32_ f32[], f64_ f64[], bf16_ bf16[]) -> (f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[]) { - f16_ = f16[] parameter(0) - f32_ = f32[] parameter(1) - f64_ = f64[] parameter(2) - bf16_ = bf16[] parameter(3) - converted_f16 = f8e5m2fnuz[] convert(f16[] f16_) - converted_f32 = f8e5m2fnuz[] convert(f32[] f32_) - converted_f64 = f8e5m2fnuz[] convert(f64[] f64_) - converted_bf16 = f8e5m2fnuz[] convert(bf16[] bf16_) - ROOT tuple = (f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[]) tuple( - converted_f16, converted_f32, converted_f64, converted_bf16) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertSignedToF8E5FNUZ) { - RunTypeConversionTest(R"( - HloModule convertToF8E5FNUZ - ENTRY ConvertToF8E5FNUZ (s8_ s8[], s16_ s16[], s32_ s32[], s64_ s64[]) -> - (f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[]) { - s8_ = s8[] parameter(0) - s16_ = s16[] parameter(1) - s32_ = s32[] parameter(2) - s64_ = s64[] parameter(3) - converted_s8 = f8e5m2fnuz[] convert(s8[] s8_) - converted_s16 = f8e5m2fnuz[] convert(s16[] s16_) - converted_s32 = f8e5m2fnuz[] convert(s32[] s32_) - converted_s64 = f8e5m2fnuz[] convert(s64[] s64_) - ROOT tuple = (f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[]) tuple( - converted_s8, converted_s16, converted_s32, converted_s64) - } - )"); + )", + {{"${tname}", tname}}); + ElementalIrEmitterExecutionTest::RunTest(hlo_text, {}); } -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertUnsignedToF8E5FNUZ) { - RunTypeConversionTest(R"( - HloModule convertToF8E5FNUZ - ENTRY ConvertToF8E5FNUZ (u8_ u8[], u16_ u16[], u32_ u32[], u64_ u64[]) -> - (f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[]) { - u8_ = u8[] parameter(0) - u16_ = u16[] parameter(1) - u32_ = u32[] parameter(2) - u64_ = u64[] parameter(3) - converted_u8 = f8e5m2fnuz[] convert(u8[] u8_) - converted_u16 = f8e5m2fnuz[] convert(u16[] u16_) - converted_u32 = f8e5m2fnuz[] convert(u32[] u32_) - converted_u64 = f8e5m2fnuz[] convert(u64[] u64_) - ROOT tuple = (f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[], f8e5m2fnuz[]) tuple( - converted_u8, converted_u16, converted_u32, converted_u64) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E5FNUZToFloat) { - RunTypeConversionTest(R"( - HloModule convertFromF8E5FNUZ - ENTRY ConvertFromF8E5FNUZ - (to_f16 f8e5m2fnuz[], to_f32 f8e5m2fnuz[], to_f64 f8e5m2fnuz[]) -> (f16[], f32[], f64[]) { - to_f16 = f8e5m2fnuz[] parameter(0) - to_f32 = f8e5m2fnuz[] parameter(1) - to_f64 = f8e5m2fnuz[] parameter(2) - f16_ = f16[] convert(f8e5m2fnuz[] to_f16) - f32_ = f32[] convert(f8e5m2fnuz[] to_f32) - f64_ = f64[] convert(f8e5m2fnuz[] to_f64) - ROOT tuple = (f16[], f32[], f64[]) tuple(f16_, f32_, f64_) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E5FNUZToSigned) { - RunTypeConversionTest(R"( - HloModule convertFromF8E5FNUZ - ENTRY ConvertFromF8E5FNUZ(to_s8 f8e5m2fnuz[], to_s16 f8e5m2fnuz[], to_s32 f8e5m2fnuz[], - to_s64 f8e5m2fnuz[]) -> (s8[], s16[], s32[], s64[]) { - to_s8 = f8e5m2fnuz[] parameter(0) - to_s16 = f8e5m2fnuz[] parameter(1) - to_s32 = f8e5m2fnuz[] parameter(2) - to_s64 = f8e5m2fnuz[] parameter(3) - s8_ = s8[] convert(f8e5m2fnuz[] to_s8) - s16_ = s16[] convert(f8e5m2fnuz[] to_s16) - s32_ = s32[] convert(f8e5m2fnuz[] to_s32) - s64_ = s64[] convert(f8e5m2fnuz[] to_s64) - ROOT tuple = (s8[], s16[], s32[], s64[]) tuple(s8_, s16_, s32_, s64_) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E5FNUZToUnsigned) { - RunTypeConversionTest(R"( - HloModule convertFromF8E5FNUZ - ENTRY ConvertFromF8E5FNUZ(to_u8 f8e5m2fnuz[], to_u16 f8e5m2fnuz[], to_u32 f8e5m2fnuz[], - to_u64 f8e5m2fnuz[]) -> (u8[], u16[], u32[], u64[]) { - to_u8 = f8e5m2fnuz[] parameter(0) - to_u16 = f8e5m2fnuz[] parameter(1) - to_u32 = f8e5m2fnuz[] parameter(2) - to_u64 = f8e5m2fnuz[] parameter(3) - u8_ = u8[] convert(f8e5m2fnuz[] to_u8) - u16_ = u16[] convert(f8e5m2fnuz[] to_u16) - u32_ = u32[] convert(f8e5m2fnuz[] to_u32) - u64_ = u64[] convert(f8e5m2fnuz[] to_u64) - ROOT tuple = (u8[], u16[], u32[], u64[]) tuple(u8_, u16_, u32_, u64_) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, ConvertF8E5FNUZToComplex) { - RunTypeConversionTest(R"( - HloModule convertFromF8E5FNUZ - ENTRY ConvertFromF8E5FNUZ - (to_c64 f8e5m2fnuz[], to_c128 f8e5m2fnuz[]) -> (c64[], c128[]) { - to_c64 = f8e5m2fnuz[] parameter(0) - to_c128 = f8e5m2fnuz[] parameter(1) - c64_ = c64[] convert(f8e5m2fnuz[] to_c64) - c128_ = c128[] convert(f8e5m2fnuz[] to_c128) - ROOT tuple = (c64[], c128[]) tuple(c64_, c128_) - } - )"); -} - -XLA_TEST_F(ElementalIrEmitterExecutionTest, CompareF8E5FNUZ) { - constexpr char hlo_text[] = R"( - HloModule compareF8E5FNUZ - ENTRY main { - p0 = f8e5m2fnuz[4] parameter(0) - p1 = f8e5m2fnuz[4] parameter(1) - ROOT cmp = pred[4] compare(p0, p1), direction=LT -})"; - - Literal lhs = LiteralUtil::CreateR1({1, 2, 3, 4}); - Literal rhs = LiteralUtil::CreateR1({4, 3, 2, 1}); - lhs = LiteralUtil::ConvertF32ToF8E5M2FNUZ(lhs); - rhs = LiteralUtil::ConvertF32ToF8E5M2FNUZ(rhs); - RunTest(hlo_text, {&lhs, &rhs}); -} +TYPED_TEST(ElementalIrEmitterExecutionTypedTest, BatchDotFloat) { + auto tname = this->TypeName(); + const auto hlo_text = absl::StrReplaceAll(R"( + HloModule matmul -XLA_TEST_F(ElementalIrEmitterExecutionTest, IotaF8E5FNUZ) { - constexpr char hlo_text[] = R"( - HloModule IotaF8E5FNUZ ENTRY main { - ROOT iota_ = f8e5m2fnuz[4] iota(), iota_dimension=0 + x = ${tname}[8,16] parameter(0) + y = ${tname}[8,16,32] parameter(1) + ROOT dot = ${tname}[8,32] dot(x, y), lhs_batch_dims={0}, + rhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={1} } - )"; + )", + {{"${tname}", tname}}); + HloModuleConfig config; + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + config.set_debug_options(debug_options); - RunTest(hlo_text, {}); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + HloTestBase::ParseAndReturnVerifiedModule(hlo_text, config)); + EXPECT_TRUE( + HloTestBase::RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5})); } XLA_TEST_F(ElementalIrEmitterExecutionTestWithoutFastMinMax, diff --git a/third_party/xla/xla/service/executable.cc b/third_party/xla/xla/service/executable.cc index aa81fce3e80e1c..0c522d4f37b83e 100644 --- a/third_party/xla/xla/service/executable.cc +++ b/third_party/xla/xla/service/executable.cc @@ -57,25 +57,6 @@ void ExecutionInput::SetUnownedBuffer(const ShapeIndex& index, unowned_indices_.insert(index); } -absl::StatusOr ExecutionInput::ToShapedBuffer( - se::DeviceMemoryAllocator* allocator, int device_ordinal) const { - const Shape& input_shape = shape(); - ShapedBuffer shaped_buffer(input_shape, device_ordinal); - for (const auto& index_buffer : Buffers()) { - const tensorflow::se::OwningDeviceMemory* mem = - index_buffer.second.AsOwningDeviceMemory(); - if (mem != nullptr && (mem->allocator() != allocator || - mem->device_ordinal() != device_ordinal)) { - return tsl::errors::InvalidArgument("Device buffer at index ", - index_buffer.first.ToString(), - " has mismatching allocator/device"); - } - shaped_buffer.set_buffer(index_buffer.second.AsDeviceMemoryBase(), - index_buffer.first); - } - return std::move(shaped_buffer); -} - absl::StatusOr Executable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, absl::Span arguments, diff --git a/third_party/xla/xla/service/executable.h b/third_party/xla/xla/service/executable.h index f1a08cd570b88b..e9c4abe32e1b8b 100644 --- a/third_party/xla/xla/service/executable.h +++ b/third_party/xla/xla/service/executable.h @@ -18,11 +18,14 @@ limitations under the License. #include #include +#include #include #include +#include "absl/base/thread_annotations.h" #include "absl/log/check.h" #include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "absl/types/variant.h" #include "xla/debug_options_flags.h" @@ -98,9 +101,6 @@ class ExecutionInput { absl::Status SetDynamicShape(Shape dynamic_shape); - absl::StatusOr ToShapedBuffer( - se::DeviceMemoryAllocator* allocator, int device_ordinal) const; - void SetBuffer(const ShapeIndex& index, MaybeOwningDeviceMemory buffer) { *buffers_.mutable_element(index) = std::move(buffer); } @@ -376,6 +376,12 @@ class Executable { // Dumping helpers. void set_hlo_proto(std::unique_ptr hlo_proto) { + // Despite the mutex lock, this function is NOT thread-safe. + // The mutex is needed for the lazy HLO module loading in `hlo_proto()`. + // Since both `hlo_proto()` and `buffer_assignment_proto()` return a + // pointer to hlo_proto_, having the mutex is not enough to make this + // function thread-safe. + absl::MutexLock lock(&hlo_proto_mutex_); hlo_proto_ = std::move(hlo_proto); } bool dumping_snapshot() const { @@ -385,6 +391,7 @@ class Executable { } HloProto const* hlo_proto() const { + absl::MutexLock lock(&hlo_proto_mutex_); if (hlo_proto_ != nullptr && !hlo_proto_->has_hlo_module()) { *hlo_proto_->mutable_hlo_module() = module().ToProto(); } @@ -392,6 +399,7 @@ class Executable { } const BufferAssignmentProto* buffer_assignment_proto() const { + absl::MutexLock lock(&hlo_proto_mutex_); return hlo_proto_ != nullptr && hlo_proto_->has_buffer_assignment() ? &hlo_proto_->buffer_assignment() : nullptr; @@ -441,7 +449,8 @@ class Executable { // hlo_proto_->buffer_assignment is set and hlo_proto_->hlo_module isn't, the // hlo_module proto will be computed on the fly when requested with // hlo_proto(). This avoids wasting CPU and memory if the proto isn't needed. - std::unique_ptr hlo_proto_; + std::unique_ptr hlo_proto_ ABSL_GUARDED_BY(hlo_proto_mutex_); + mutable absl::Mutex hlo_proto_mutex_; }; } // namespace xla diff --git a/third_party/xla/xla/service/executable_test.cc b/third_party/xla/xla/service/executable_test.cc new file mode 100644 index 00000000000000..388b7be1bd44a7 --- /dev/null +++ b/third_party/xla/xla/service/executable_test.cc @@ -0,0 +1,86 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/executable.h" + +#include +#include +#include +#include +#include + +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_execution_profile.h" +#include "xla/service/service_executable_run_options.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/env.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" +#include "tsl/platform/threadpool.h" + +namespace xla { +namespace { + +class TestExecutable : public Executable { + public: + explicit TestExecutable(std::shared_ptr module) + : Executable{std::move(module)} {} + + absl::StatusOr ExecuteAsyncOnStream( + const ServiceExecutableRunOptions* run_options, + std::vector arguments, + HloExecutionProfile* hlo_execution_profile) override { + return absl::UnimplementedError("Not needed for this test."); + } +}; + +class ExecutableTest : public HloTestBase {}; + +TEST_F(ExecutableTest, HloProtoGetterIsThreadCompatible) { + // Executable::hlo_proto() is doing some lazy initialization of a + // part of `hlo_proto_`. This test ensures that this is done in a + // thread-compatible way. + // Note that this test needs to run with --config=tsan to reliably + // detect any potential data races. + constexpr std::string_view kHloModule = R"( + HloModule module + + ENTRY main { + ROOT c = s32[] constant(1) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + + TestExecutable executable(module); + + auto proto = std::make_unique(); + executable.set_hlo_proto(std::move(proto)); + + { + tsl::thread::ThreadPool pool(tsl::Env::Default(), "test", + /*num_threads=*/2); + for (int i = 0; i < 2; ++i) { + pool.Schedule([&] { executable.hlo_proto()->SerializeAsString(); }); + } + } +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/service/flatten_call_graph.h b/third_party/xla/xla/service/flatten_call_graph.h index a3fe5aada4f1c5..ff5af7039ee3b7 100644 --- a/third_party/xla/xla/service/flatten_call_graph.h +++ b/third_party/xla/xla/service/flatten_call_graph.h @@ -18,26 +18,7 @@ limitations under the License. #ifndef XLA_SERVICE_FLATTEN_CALL_GRAPH_H_ #define XLA_SERVICE_FLATTEN_CALL_GRAPH_H_ -#include "absl/status/statusor.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// Flattening associates each call site with a unique computation (for -// sequential calling contexts) This simplifies buffer assignment and -// points-to analysis (see b/36865746 for details). -class FlattenCallGraph : public HloModulePass { - public: - absl::string_view name() const override { return "flatten-call-graph"; } - - // Duplicates computations called from multiple call- or while-nodes to - // flatten the call graph. - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/flatten_call_graph.h" #endif // XLA_SERVICE_FLATTEN_CALL_GRAPH_H_ diff --git a/third_party/xla/xla/service/float8_fnuz_ir_emitter.cc b/third_party/xla/xla/service/float8_fnuz_ir_emitter.cc index fe3a1041933cb5..22916aa084fc47 100644 --- a/third_party/xla/xla/service/float8_fnuz_ir_emitter.cc +++ b/third_party/xla/xla/service/float8_fnuz_ir_emitter.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "llvm/ADT/APFloat.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Intrinsics.h" #include "xla/primitive_util.h" @@ -39,6 +40,10 @@ namespace { absl::StatusOr PrimitiveTypeToAPFloatSemantics( PrimitiveType type) { switch (type) { + case F8E3M4: + return &llvm::APFloat::Float8E3M4(); + case F8E4M3: + return &llvm::APFloat::Float8E4M3(); case F8E4M3B11FNUZ: return &llvm::APFloat::Float8E4M3B11FNUZ(); case F8E4M3FN: @@ -67,6 +72,8 @@ absl::StatusOr PrimitiveTypeToAPFloatSemantics( absl::StatusOr PrimitiveTypeToLLVMType(llvm::IRBuilder<>* b, PrimitiveType type) { switch (type) { + case F8E3M4: + case F8E4M3: case F8E4M3B11FNUZ: case F8E4M3FN: case F8E4M3FNUZ: diff --git a/third_party/xla/xla/service/float_normalization.h b/third_party/xla/xla/service/float_normalization.h index 1862edcb14d8a6..db54be02642d8d 100644 --- a/third_party/xla/xla/service/float_normalization.h +++ b/third_party/xla/xla/service/float_normalization.h @@ -16,90 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_FLOAT_NORMALIZATION_H_ #define XLA_SERVICE_FLOAT_NORMALIZATION_H_ -#include - -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/service/float_support.h" - -namespace xla { - -// A pass which adds type conversions (e.g. F32 <-> BF16) for HLO instructions -// that do not support low-precision input/output or mixed precision, according -// to the passed-in backend-specific FloatSupport instance. -class FloatNormalization : public HloModulePass { - public: - explicit FloatNormalization(const FloatSupport* float_support) - : float_support_(float_support), - name_("float-normalization-" + - primitive_util::LowercasePrimitiveTypeName( - float_support_->LowPrecisionType())) {} - - ~FloatNormalization() override = default; - absl::string_view name() const override { return name_; } - - // Run float normalization on the given computation. Returns whether the - // computation was changed. - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - const FloatSupport* float_support_; - std::string name_; -}; - -// A pass that unconditionally removes the mixed F32/BF16 uses in HLO -// instructions (excluding convert) by adding F32 <-> BF16 conversions. Unlike -// FloatNormalization, this pass does not use a backend-specific -// FloatSupport, and does not change HLOs that have BF16 data if they do not -// use mixed precision; it removes mixed precision even if the backend supports -// it. This pass is used to make the HLO module valid for other HLO passes which -// do not support mixed precision. Currently, this pass is only used by the -// Despecializer, not by our normal compilation flow on TPU. -class BFloat16MixedPrecisionRemoval : public HloModulePass { - public: - BFloat16MixedPrecisionRemoval() = default; - - ~BFloat16MixedPrecisionRemoval() override = default; - - absl::string_view name() const override { - return "bf16-mixed-precision-removal"; - } - - // Run mixed precision removal on the given computation. Returns whether the - // computation was changed. - using HloPassInterface::Run; - absl::StatusOr Run(HloModule* module, - const absl::flat_hash_set& - execution_threads) override { - FloatNormalization normalization(&no_mixed_precision_support_); - return normalization.Run(module, execution_threads); - } - - private: - class BFloat16SupportForMixedPrecisionRemoval : public FloatSupport { - public: - BFloat16SupportForMixedPrecisionRemoval() : FloatSupport(BF16) {} - - ~BFloat16SupportForMixedPrecisionRemoval() override = default; - - bool SupportsLowPrecisionOperand(const HloInstruction& hlo, - int64_t operand_index) const override { - return true; - } - - bool SupportsLowPrecisionOutput(const HloInstruction& hlo) const override { - return true; - } - - bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { - return false; - } - } no_mixed_precision_support_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/float_normalization.h" #endif // XLA_SERVICE_FLOAT_NORMALIZATION_H_ diff --git a/third_party/xla/xla/service/float_support.cc b/third_party/xla/xla/service/float_support.cc index 3bcbfdd7dcb144..e2e6bf28daf6af 100644 --- a/third_party/xla/xla/service/float_support.cc +++ b/third_party/xla/xla/service/float_support.cc @@ -97,6 +97,7 @@ bool FloatSupport::EffectiveOperandPrecisionIsOutputPrecision( case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kPad: + case HloOpcode::kRaggedAllToAll: case HloOpcode::kReshape: case HloOpcode::kReverse: case HloOpcode::kSlice: diff --git a/third_party/xla/xla/service/fusion_constant_sinking.h b/third_party/xla/xla/service/fusion_constant_sinking.h index 96e3fb95e7adc8..15f15e41af05c9 100644 --- a/third_party/xla/xla/service/fusion_constant_sinking.h +++ b/third_party/xla/xla/service/fusion_constant_sinking.h @@ -16,24 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_FUSION_CONSTANT_SINKING_H_ #define XLA_SERVICE_FUSION_CONSTANT_SINKING_H_ -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// A pass which sinks constants into fusion computations. -class FusionConstantSinking : public HloModulePass { - public: - absl::string_view name() const override { return "fusion_constant_sinking"; } - - // Run fusion constant sinking operations on the given module. Returns whether - // the module was changed (constant expressions folded). - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/fusion_constant_sinking.h" #endif // XLA_SERVICE_FUSION_CONSTANT_SINKING_H_ diff --git a/third_party/xla/xla/service/fusion_node_indexing_evaluation_test.cc b/third_party/xla/xla/service/fusion_node_indexing_evaluation_test.cc index 5c3790e60c41ea..8becc665ecde1e 100644 --- a/third_party/xla/xla/service/fusion_node_indexing_evaluation_test.cc +++ b/third_party/xla/xla/service/fusion_node_indexing_evaluation_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/service/instruction_fusion.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/gather_expander.cc b/third_party/xla/xla/service/gather_expander.cc index 8277a7b902ad02..095cb318b28338 100644 --- a/third_party/xla/xla/service/gather_expander.cc +++ b/third_party/xla/xla/service/gather_expander.cc @@ -15,12 +15,18 @@ limitations under the License. #include "xla/service/gather_expander.h" +#include +#include #include #include "absl/algorithm/container.h" +#include "absl/log/check.h" #include "absl/status/statusor.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal_util.h" +#include "xla/service/gather_scatter_utils.h" #include "xla/service/hlo_creation_utils.h" #include "xla/service/while_util.h" #include "xla/util.h" @@ -107,45 +113,18 @@ absl::StatusOr AdjustBatchDimsInAccumulator( return ExpandFirstDimIntoNDims(accumulator, batch_dim_bounds); } -// Expand an index vector from the start_indices tensor into a vector that can -// be used to dynamic-slice out of the gather operand. -absl::StatusOr ExpandIndexVectorIntoOperandSpace( - HloInstruction* index_vector, const GatherDimensionNumbers& dim_numbers, - int64_t operand_rank) { - HloComputation* computation = index_vector->parent(); - const Shape& index_shape = index_vector->shape(); - - if (operand_rank == 0) { - // This is Gather from a scalar. So, the index vector in operand space must - // be a zero-sized vector. - return computation->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0}))); - } - - HloInstruction* zero = - computation->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1}))); - - // We extract out individual components from the smaller index and concatenate - // them (interspersing zeros as needed) into the larger index. - std::vector expanded_index_components; - - for (int i = 0; i < operand_rank; i++) { - int64_t index_vector_dim_index = - FindIndex(dim_numbers.start_index_map(), i); - if (index_vector_dim_index != dim_numbers.start_index_map_size()) { - TF_ASSIGN_OR_RETURN( - HloInstruction * component_to_concat, - MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index}, - /*limit_indices=*/{index_vector_dim_index + 1}, - /*strides=*/{1})); - expanded_index_components.push_back(component_to_concat); - } else { - expanded_index_components.push_back(zero); - } - } - - return MakeConcatHlo(expanded_index_components, /*dimension=*/0); +// Returns the dimensions in a slice that are either collapsed or corresponding +// to an operand batching dimension. +std::vector GetDegeneratedSliceDims( + const GatherDimensionNumbers& dim_numbers) { + absl::Span collapsed_slice_dims = + dim_numbers.collapsed_slice_dims(); + absl::Span batching_dims = dim_numbers.operand_batching_dims(); + std::vector removed_dims; + removed_dims.reserve(collapsed_slice_dims.size() + batching_dims.size()); + absl::c_copy(collapsed_slice_dims, std::back_inserter(removed_dims)); + absl::c_copy(batching_dims, std::back_inserter(removed_dims)); + return removed_dims; } // This generates the body of the while that implements the main data movement @@ -158,11 +137,11 @@ absl::StatusOr> GatherLoopBody( HloInstruction* const operand = incoming_loop_state[0]; HloInstruction* const start_indices = incoming_loop_state[1]; HloInstruction* const output_accumulator = incoming_loop_state[2]; + const Shape& orig_start_indices_shape = gather.operand(1)->shape(); bool has_scalar_indices = start_indices->shape().dimensions_size() == 1; - CHECK_EQ(has_scalar_indices, - dim_numbers.index_vector_dim() == - gather.operand(1)->shape().dimensions_size()); + CHECK_EQ(has_scalar_indices, dim_numbers.index_vector_dim() == + orig_start_indices_shape.dimensions_size()); HloInstruction* induction_var_as_vector = MakeBroadcastHlo(induction_var, /*broadcast_dimensions=*/{}, @@ -197,8 +176,11 @@ absl::StatusOr> GatherLoopBody( TF_ASSIGN_OR_RETURN( HloInstruction * gathered_slice_start, - ExpandIndexVectorIntoOperandSpace(index_vector, dim_numbers, - operand->shape().dimensions_size())); + ExpandIndexVectorIntoOperandSpace( + orig_start_indices_shape, operand->shape().dimensions_size(), + dim_numbers.index_vector_dim(), dim_numbers.start_index_map(), + dim_numbers.start_indices_batching_dims(), + dim_numbers.operand_batching_dims(), index_vector, induction_var)); TF_ASSIGN_OR_RETURN(HloInstruction * gathered_slice, MakeDynamicSliceHlo(operand, gathered_slice_start, @@ -206,7 +188,8 @@ absl::StatusOr> GatherLoopBody( TF_ASSIGN_OR_RETURN( HloInstruction* const gathered_slice_with_dims_collapsed, - ElideDegenerateDims(gathered_slice, dim_numbers.collapsed_slice_dims())); + ElideDegenerateDims(gathered_slice, + GetDegeneratedSliceDims(dim_numbers))); TF_ASSIGN_OR_RETURN( HloInstruction* const gathered_slice_for_update, @@ -239,7 +222,8 @@ HloInstruction* CreateGatherLoopAccumulatorInitValue( accumulator_state_shape_dims.reserve(1 + slice_sizes.size()); accumulator_state_shape_dims.push_back(gather_loop_trip_count); for (int64_t i = 0; i < slice_sizes.size(); i++) { - if (!absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) { + if (!absl::c_linear_search(dim_numbers.collapsed_slice_dims(), i) && + !absl::c_linear_search(dim_numbers.operand_batching_dims(), i)) { accumulator_state_shape_dims.push_back(slice_sizes[i]); } } @@ -287,7 +271,7 @@ int64_t GatherLoopTripCount(HloInstruction* gather_instr) { return trip_count; } -int64_t GatherIsBroadcast(HloInstruction* gather_instr) { +bool GatherIsBroadcast(HloInstruction* gather_instr) { return absl::c_equal(gather_instr->gather_slice_sizes(), gather_instr->operand(0)->shape().dimensions()); } @@ -336,7 +320,7 @@ absl::StatusOr GatherExpander::ExpandInstruction( return MakeScalarLike(gather_instr, 0); } Shape broadcast_operand_shape = ShapeUtil::DeleteDimensions( - gather_instr->gather_dimension_numbers().collapsed_slice_dims(), + GetDegeneratedSliceDims(gather_instr->gather_dimension_numbers()), gather_instr->operand(0)->shape()); TF_ASSIGN_OR_RETURN(HloInstruction * broadcast_operand, MakeReshapeHlo(broadcast_operand_shape, @@ -412,8 +396,7 @@ bool GatherExpander::InstructionMatchesPattern(HloInstruction* inst) { // which can be represented without a loop -- i.e. we only simplify // gathers which have a trip count of 1. (mode_ == kEliminateAllGathers || GatherLoopTripCount(inst) == 1 || - absl::c_equal(inst->gather_slice_sizes(), - inst->operand(0)->shape().dimensions())); + GatherIsBroadcast(inst)); } } // namespace xla diff --git a/third_party/xla/xla/service/gather_expander.h b/third_party/xla/xla/service/gather_expander.h index 8f43141c3119e8..334024cb12e0d7 100644 --- a/third_party/xla/xla/service/gather_expander.h +++ b/third_party/xla/xla/service/gather_expander.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_GATHER_EXPANDER_H_ #define XLA_SERVICE_GATHER_EXPANDER_H_ -#include "xla/service/op_expander_pass.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" namespace xla { diff --git a/third_party/xla/xla/service/gather_expander_test.cc b/third_party/xla/xla/service/gather_expander_test.cc index 271fabdbda2c62..a7f39c326336c1 100644 --- a/third_party/xla/xla/service/gather_expander_test.cc +++ b/third_party/xla/xla/service/gather_expander_test.cc @@ -15,6 +15,12 @@ limitations under the License. #include "xla/service/gather_expander.h" +#include + +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/filecheck.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" @@ -23,7 +29,19 @@ limitations under the License. namespace xla { namespace { -using GatherExpanderTest = HloTestBase; +class GatherExpanderTest : public HloTestBase { + protected: + void CheckWhileBody(HloModule* module, absl::string_view expected) { + std::vector while_instructions = + FindInstructions(module, HloOpcode::kWhile); + EXPECT_EQ(while_instructions.size(), 1); + HloComputation* while_body = while_instructions[0]->while_body(); + EXPECT_TRUE(*RunFileCheck( + while_body->ToString( + HloPrintOptions{}.set_include_layout_in_shapes(false)), + expected)); + } +}; TEST_F(GatherExpanderTest, ErrorStatusOnTooManyIndices) { const std::string hlo_text = R"( @@ -230,5 +248,169 @@ ENTRY main { module->VerifyOrAddFailure("after-gather-expander."); } +TEST_F(GatherExpanderTest, GatherIsBroadcastBatchDim) { + const std::string hlo_text = R"( +HloModule test + +ENTRY main { + operand = s32[1,3,1] parameter(0) + indices = s32[1,5] parameter(1) + ROOT gather = s32[1,3,5] gather(operand, indices), + offset_dims={1}, + collapsed_slice_dims={2}, + start_index_map={0}, + index_vector_dim=2, + slice_sizes={1,3,1}, + operand_batching_dims={0}, + start_indices_batching_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text)); + GatherExpander pass(GatherExpander::kEliminateSimpleGathers); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get())); + ASSERT_TRUE(changed); + ASSERT_FALSE(hlo_query::ContainsInstrWithOpcode(module->entry_computation(), + {HloOpcode::kGather})); + ASSERT_TRUE(hlo_query::ContainsInstrWithOpcode(module->entry_computation(), + {HloOpcode::kBroadcast})); + module->VerifyOrAddFailure("after-gather-expander."); +} + +TEST_F(GatherExpanderTest, GatherToLoopWithBatchDims) { + const std::string hlo_text = R"( +HloModule GatherWithBatchDims + +ENTRY main { + operand = s32[5,2] parameter(0) + indices = s32[5,1] parameter(1) + ROOT gather = s32[5,1] gather(operand, indices), + offset_dims={1}, + collapsed_slice_dims={}, + start_index_map={1}, + index_vector_dim=1, + operand_batching_dims={0}, + start_indices_batching_dims={0}, + slice_sizes={1,1} +} +)"; + const std::string expected = R"( + //CHECK: (s32[], s32[5,2], s32[5,1], s32[5,1])) -> (s32[], s32[5,2], s32[5,1], s32[5,1]) { + //CHECK: %[[PARAM:.*]] = (s32[], s32[5,2], s32[5,1], s32[5,1]) parameter(0) + //CHECK: %[[I:.*]] = s32[] get-tuple-element((s32[], s32[5,2], s32[5,1], s32[5,1]) %[[PARAM]]), index= + //CHECK: %[[CONSTANT1:.*]] = s32[] constant(1) + //CHECK: %[[I_PLUS_1:.*]] = s32[] add(s32[] %[[I]], s32[] %[[CONSTANT1]]) + //CHECK: %[[OPERAND:.*]] = s32[5,2] get-tuple-element((s32[], s32[5,2], s32[5,1], s32[5,1]) %[[PARAM]]), index=1 + //CHECK: %[[START_INDICES:.*]] = s32[5,1] get-tuple-element((s32[], s32[5,2], s32[5,1], s32[5,1]) %[[PARAM]]), index=2 + //CHECK: %[[RESULT:.*]] = s32[5,1] get-tuple-element((s32[], s32[5,2], s32[5,1], s32[5,1]) %[[PARAM]]), index=3 + + //CHECK: %[[I_1D_1:.*]] = s32[1] broadcast(s32[] %[[I]]) + //CHECK: %[[I_1D_2:.*]] = s32[1] broadcast(s32[] %[[I]]) + + //CHECK: %[[START_INDICES_INDEX_D1_PAD:.*]] = s32[] constant(0) + //CHECK: %[[START_INDICES_INDEX_VECTOR:.*]] = s32[2] pad(s32[1] %[[I_1D_2]], s32[] %[[START_INDICES_INDEX_D1_PAD]]), padding=0_1 + //CHECK: %[[START_INDICES_INDEX_D0_SLICE:.*]] = s32[1] slice(s32[2] %[[START_INDICES_INDEX_VECTOR]]), slice={[0:1]} + //CHECK: %[[START_INDICES_INDEX_D0:.*]] = s32[] reshape(s32[1] %[[START_INDICES_INDEX_D0_SLICE]]) + //CHECK: %[[START_INDICES_INDEX_D1_SLICE:.*]] = s32[1] slice(s32[2] %[[START_INDICES_INDEX_VECTOR]]), slice={[1:2]} + //CHECK: %[[START_INDICES_INDEX_D1:.*]] = s32[] reshape(s32[1] %[[START_INDICES_INDEX_D1_SLICE]]) + //CHECK: %[[INDEX_VECTOR:.*]] = s32[1,1] dynamic-slice(s32[5,1] %[[START_INDICES]], s32[] %[[START_INDICES_INDEX_D0]], s32[] %[[START_INDICES_INDEX_D1]]) + + //CHECK: %[[OFFSET_RAW:.*]] = s32[1] reshape(s32[1,1] %[[INDEX_VECTOR]]) + //CHECK: %[[OFFSET:.*]] = s32[1] slice(s32[1] %[[OFFSET_RAW]]) + //CHECK: %[[OPERAND_INDEX:.*]] = s32[2] concatenate(s32[1] %[[I_1D_1]], s32[1] %[[OFFSET]]) + //CHECK: %[[OPERAND_INDEX_D0_RAW:.*]] = s32[1] slice(s32[2] %[[OPERAND_INDEX]]), slice={[0:1]} + //CHECK: %[[OPERAND_INDEX_D0:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D0_RAW]]) + //CHECK: %[[OPERAND_INDEX_D1_RAW:.*]] = s32[1] slice(s32[2] %[[OPERAND_INDEX]]), slice={[1:2]} + //CHECK: %[[OPERAND_INDEX_D1:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D1_RAW]]) + //CHECK: %[[RESULT_SLICE_RAW0:.*]] = s32[1,1] dynamic-slice(s32[5,2] %[[OPERAND]], s32[] %[[OPERAND_INDEX_D0]], s32[] %[[OPERAND_INDEX_D1]]) + + //CHECK: %[[RESULT_SLICE_RAW1:.*]] = s32[1] reshape(s32[1,1] %[[RESULT_SLICE_RAW0]]) + //CHECK: %[[RESULT_SLICE:.*]] = s32[1,1] reshape(s32[1] %[[RESULT_SLICE_RAW1]]) + //CHECK: %[[RESULT_INDEX_D1_PAD:.*]] = s32[] constant(0) + //CHECK: %[[RESULT_INDEX_VECTOR:.*]] = s32[2] pad(s32[1] %[[I_1D_2]], s32[] %[[RESULT_INDEX_D1_PAD]]), padding=0_1 + //CHECK: %[[RESULT_INDEX_D0_SLICE:.*]] = s32[1] slice(s32[2] %[[RESULT_INDEX_VECTOR]]), slice={[0:1]} + //CHECK: %[[RESULT_INDEX_D0:.*]] = s32[] reshape(s32[1] %[[RESULT_INDEX_D0_SLICE]]) + //CHECK: %[[RESULT_INDEX_D1_SLICE:.*]] = s32[1] slice(s32[2] %[[RESULT_INDEX_VECTOR]]), slice={[1:2]} + //CHECK: %[[RESULT_INDEX_D1:.*]] = s32[] reshape(s32[1] %[[RESULT_INDEX_D1_SLICE]]) + //CHECK: %[[UPDATED_RESULT:.*]] = s32[5,1] dynamic-update-slice(s32[5,1] %[[RESULT]], s32[1,1] %[[RESULT_SLICE]], s32[] %[[RESULT_INDEX_D0]], s32[] %[[RESULT_INDEX_D1]]) + + //CHECK: ROOT %{{.*}} = (s32[], s32[5,2], s32[5,1], s32[5,1]) tuple(s32[] %[[I_PLUS_1]], s32[5,2] %[[OPERAND]], s32[5,1] %[[START_INDICES]], s32[5,1] %[[UPDATED_RESULT]]) +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + GatherExpander{GatherExpander::kEliminateAllGathers}.Run(module.get())); + ASSERT_TRUE(changed); + CheckWhileBody(module.get(), expected); +} + +TEST_F(GatherExpanderTest, GatherToLoopWithBatchDimsAndCollapsedDims) { + const std::string hlo_text = R"( +HloModule GatherWithBatchAndCollapsedDims + +ENTRY main { + operand = s32[7,3,4,5] parameter(0) + indices = s32[5,2,7] parameter(1) + ROOT gather = s32[5,3,2,7] gather(operand, indices), + offset_dims={1}, + collapsed_slice_dims={2}, + start_index_map={2}, + index_vector_dim=3, + operand_batching_dims={0,3}, + start_indices_batching_dims={2,0}, + slice_sizes={1,3,1,1} +} +)"; + // Compared with the previous test, this test adds complexity in calculating + // the indices for the operand. As such, we mostly check the operand indices + // here. + const std::string expected = R"( + //CHECK: (s32[], s32[7,3,4,5], s32[70], s32[70,3])) -> (s32[], s32[7,3,4,5], s32[70], s32[70,3]) { + //CHECK: %[[PARAM:.*]] = (s32[], s32[7,3,4,5], s32[70], s32[70,3]) parameter(0) + //CHECK: %[[I:.*]] = s32[] get-tuple-element((s32[], s32[7,3,4,5], s32[70], s32[70,3]) %[[PARAM]]), index=0 + + //CHECK: %[[CONSTANT1:.*]] = s32[] constant(1) + //CHECK: %[[I_PLUS_1:.*]] = s32[] add(s32[] %[[I]], s32[] %[[CONSTANT1]]) + //CHECK: %[[OPERAND:.*]] = s32[7,3,4,5] get-tuple-element((s32[], s32[7,3,4,5], s32[70], s32[70,3]) %[[PARAM]]), index=1 + //CHECK: %[[START_INDICES:.*]] = s32[70] get-tuple-element((s32[], s32[7,3,4,5], s32[70], s32[70,3]) %[[PARAM]]), index=2 + + //CHECK: %[[CONSTANT7:.*]] = s32[] constant(7) + //CHECK: %[[BD0_RAW:.*]] = s32[] remainder(s32[] %[[I]], s32[] %[[CONSTANT7]]) + //CHECK: %[[BD0:.*]] = s32[1] broadcast(s32[] %[[BD0_RAW]]) + //CHECK: %[[CONSTANT0:.*]] = s32[1] constant({0}) + //CHECK: %[[I_1D_1:.*]] = s32[1] broadcast(s32[] %[[I]]) + //CHECK: %[[START_INDICES_INDEX_RAW:.*]] = s32[1] slice(s32[1] %[[I_1D_1]]) + //CHECK: %[[START_INDICES_INDEX:.*]] = s32[] reshape(s32[1] %[[START_INDICES_INDEX_RAW]]) + //CHECK: %[[INDEX_VECTOR:.*]] = s32[1] dynamic-slice(s32[70] %[[START_INDICES]], s32[] %[[START_INDICES_INDEX]]) + + //CHECK: %[[OFFSET:.*]] = s32[1] slice(s32[1] %[[INDEX_VECTOR]]) + //CHECK: %[[BD1:.*]] = s32[] divide(s32[] %[[I]], s32[] %[[CONSTANT7]]) + //CHECK: %[[CONSTANT2:.*]] = s32[] constant(2) + //CHECK: %[[BD2_RAW:.*]] = s32[] divide(s32[] %[[BD1]], s32[] %[[CONSTANT2]]) + //CHECK: %[[BD2:.*]] = s32[1] broadcast(s32[] %[[BD2_RAW]]) + //CHECK: %[[OPERAND_INDEX:.*]] = s32[4] concatenate(s32[1] %[[BD0]], s32[1] %[[CONSTANT0]], s32[1] %[[OFFSET]], s32[1] %[[BD2]]) + + //CHECK: %[[OPERAND_INDEX_D0_RAW:.*]] = s32[1] slice(s32[4] %[[OPERAND_INDEX]]), slice={[0:1]} + //CHECK: %[[OPERAND_INDEX_D0:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D0_RAW]]) + //CHECK: %[[OPERAND_INDEX_D1_RAW:.*]] = s32[1] slice(s32[4] %[[OPERAND_INDEX]]), slice={[1:2]} + //CHECK: %[[OPERAND_INDEX_D1:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D1_RAW]]) + //CHECK: %[[OPERAND_INDEX_D2_RAW:.*]] = s32[1] slice(s32[4] %[[OPERAND_INDEX]]), slice={[2:3]} + //CHECK: %[[OPERAND_INDEX_D2:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D2_RAW]]) + //CHECK: %[[OPERAND_INDEX_D3_RAW:.*]] = s32[1] slice(s32[4] %[[OPERAND_INDEX]]), slice={[3:4]} + //CHECK: %[[OPERAND_INDEX_D3:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D3_RAW]]) + //CHECK: %{{.*}} = s32[1,3,1,1] dynamic-slice(s32[7,3,4,5] %[[OPERAND]], s32[] %[[OPERAND_INDEX_D0]], s32[] %[[OPERAND_INDEX_D1]], s32[] %[[OPERAND_INDEX_D2]], s32[] %[[OPERAND_INDEX_D3]]) +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + GatherExpander{GatherExpander::kEliminateAllGathers}.Run(module.get())); + ASSERT_TRUE(changed); + CheckWhileBody(module.get(), expected); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/gather_scatter_utils.cc b/third_party/xla/xla/service/gather_scatter_utils.cc index 897f1e7429b50e..1b759524754662 100644 --- a/third_party/xla/xla/service/gather_scatter_utils.cc +++ b/third_party/xla/xla/service/gather_scatter_utils.cc @@ -15,16 +15,79 @@ limitations under the License. #include "xla/service/gather_scatter_utils.h" +#include #include #include #include +#include "absl/algorithm/container.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/layout_util.h" +#include "xla/literal_util.h" #include "xla/permutation_util.h" #include "xla/service/hlo_creation_utils.h" +#include "xla/shape.h" +#include "xla/util.h" namespace xla { +namespace { + +// Generates the HLO to calculate the implicit and explicit batch dimension +// indices and returns the explicit batch dimension to the HLO indices in the +// order of major to minor. +std::vector GenerateExplicitBatchDimIndices( + const Shape& start_indices_shape, int64_t index_vector_dim, + absl::Span start_indices_batching_dims, + HloInstruction* induction_var) { + if (start_indices_batching_dims.empty()) { + return {}; + } + + int64_t rank = start_indices_shape.dimensions_size(); + int64_t num_batch_dims = (rank == index_vector_dim) ? rank : rank - 1; + HloComputation* computation = induction_var->parent(); + HloInstruction* divident = induction_var; + const Shape& shape = induction_var->shape(); + + std::vector explicit_batch_dim_indices( + start_indices_batching_dims.size()); + + for (int64_t i = start_indices_shape.dimensions_size() - 1; i >= 0; i--) { + if (i == index_vector_dim) { + continue; + } + auto it = absl::c_find(start_indices_batching_dims, i); + num_batch_dims--; // Reuse the variable to count remaining batch dims. + if (num_batch_dims == 0) { + if (it != start_indices_batching_dims.end()) { + // Avoid generating a remainder that just returns the divident itself. + explicit_batch_dim_indices[it - start_indices_batching_dims.begin()] = + divident; + } + break; + } + + HloInstruction* divisor = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(start_indices_shape.dimensions(i)))); + if (it != start_indices_batching_dims.end()) { + explicit_batch_dim_indices[it - start_indices_batching_dims.begin()] = + computation->AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kRemainder, divident, divisor)); + } + + divident = computation->AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDivide, divident, divisor)); + } + + return explicit_batch_dim_indices; +} + +} // namespace + absl::StatusOr TransformStartIndices( HloInstruction* indices, int64_t index_vector_dim) { int64_t rank = indices->shape().rank(); @@ -99,4 +162,56 @@ absl::StatusOr MoveDimensionToEnd(HloInstruction* operand, return MaybeTranspose(operand, permutation); } +absl::StatusOr ExpandIndexVectorIntoOperandSpace( + const Shape& start_indices_shape, int64_t operand_rank, + int64_t index_vector_dim, absl::Span start_index_map, + absl::Span start_indices_batching_dims, + absl::Span operand_batching_dims, + HloInstruction* index_vector, HloInstruction* induction_var) { + HloComputation* computation = index_vector->parent(); + const Shape& index_shape = index_vector->shape(); + + if (operand_rank == 0) { + // This is a Gather/Scatter from/on a scalar. Return a zero-sized vector of + // indices. + return computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0}))); + } + + HloInstruction* zero = + computation->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1}))); + + // We extract out individual components from the smaller index and concatenate + // them (interspersing zeros as needed) into the larger index. + std::vector expanded_index_components; + std::vector explicit_batch_dim_indices = + GenerateExplicitBatchDimIndices(start_indices_shape, index_vector_dim, + start_indices_batching_dims, + induction_var); + int64_t seen_explicit_batch_dims = 0; + for (int i = 0; i < operand_rank; i++) { + int64_t index_vector_dim_index = FindIndex(start_index_map, i); + if (index_vector_dim_index != start_index_map.size()) { + TF_ASSIGN_OR_RETURN( + HloInstruction * component_to_concat, + MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index}, + /*limit_indices=*/{index_vector_dim_index + 1}, + /*strides=*/{1})); + expanded_index_components.push_back(component_to_concat); + } else { + if (absl::c_linear_search(operand_batching_dims, i)) { + expanded_index_components.push_back(MakeBroadcastHlo( + explicit_batch_dim_indices[seen_explicit_batch_dims++], + /*broadcast_dimensions=*/{}, + /*result_shape_bounds=*/{1})); + } else { + expanded_index_components.push_back(zero); + } + } + } + + return MakeConcatHlo(expanded_index_components, /*dimension=*/0); +} + } // namespace xla diff --git a/third_party/xla/xla/service/gather_scatter_utils.h b/third_party/xla/xla/service/gather_scatter_utils.h index 3ce7eb43701b09..cf1368373583b2 100644 --- a/third_party/xla/xla/service/gather_scatter_utils.h +++ b/third_party/xla/xla/service/gather_scatter_utils.h @@ -53,6 +53,15 @@ absl::StatusOr MoveDimensionToEnd(HloInstruction* operand, size_t dimension, size_t rank); +// Expands an index vector from the start_indices tensor into a vector that can +// be used to dynamic-slice out of the gather/scatter operand. +absl::StatusOr ExpandIndexVectorIntoOperandSpace( + const Shape& start_indices_shape, int64_t operand_rank, + int64_t index_vector_dim, absl::Span start_index_map, + absl::Span start_indices_batching_dims, + absl::Span operand_batching_dims, + HloInstruction* index_vector, HloInstruction* induction_var); + } // namespace xla #endif // XLA_SERVICE_GATHER_SCATTER_UTILS_H_ diff --git a/third_party/xla/xla/service/gather_simplifier.h b/third_party/xla/xla/service/gather_simplifier.h index 6d2b37502cd4c0..0cbcadc09b0d26 100644 --- a/third_party/xla/xla/service/gather_simplifier.h +++ b/third_party/xla/xla/service/gather_simplifier.h @@ -16,36 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_GATHER_SIMPLIFIER_H_ #define XLA_SERVICE_GATHER_SIMPLIFIER_H_ -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/op_expander_pass.h" - -namespace xla { - -// This pass rewrites gather operations into a combination of transposes, -// reshapes and a simpler gather. -// -// The output gather's attributes will have the following characteristics: -// - start_indices is a two-dimensional tensor -// - index_vector_dim is 1 -// - start_index_map is [0, 1, ...] -// - collapsed_slice_dims is [] -// - offset_dims is [1, 2, ...] -// -// The purpose of this pass is to check whether this transformation has any -// performance implications. -class GatherSimplifier : public OpExpanderPass { - public: - absl::string_view name() const override { return "gather_simplifier"; } - - static bool IsSimplifiedGather(const HloGatherInstruction* gather); - - protected: - bool InstructionMatchesPattern(HloInstruction* inst) override; - - absl::StatusOr ExpandInstruction( - HloInstruction* inst) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/gather_simplifier.h" #endif // XLA_SERVICE_GATHER_SIMPLIFIER_H_ diff --git a/third_party/xla/xla/service/generic_transfer_manager.cc b/third_party/xla/xla/service/generic_transfer_manager.cc index 436bf144e6a7cf..85168eb9a969c9 100644 --- a/third_party/xla/xla/service/generic_transfer_manager.cc +++ b/third_party/xla/xla/service/generic_transfer_manager.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/layout_util.h" #include "xla/literal.h" #include "xla/primitive_util.h" #include "xla/service/shaped_buffer.h" @@ -300,4 +301,15 @@ Shape GenericTransferManager::HostShapeToDeviceShape( return device_shape; } +absl::StatusOr GenericTransferManager::ChooseCompactLayoutForShape( + const Shape& host_shape) const { + Shape compact_shape = LayoutUtil::GetWithDefaultLayout(host_shape); + if (PackSubbyteTypes() && + primitive_util::IsSubByteNonPredType(compact_shape.element_type())) { + compact_shape.mutable_layout()->set_element_size_in_bits( + primitive_util::BitWidth(compact_shape.element_type())); + } + return compact_shape; +} + } // namespace xla diff --git a/third_party/xla/xla/service/generic_transfer_manager.h b/third_party/xla/xla/service/generic_transfer_manager.h index 3503cff66b7dc0..22a2178792b16e 100644 --- a/third_party/xla/xla/service/generic_transfer_manager.h +++ b/third_party/xla/xla/service/generic_transfer_manager.h @@ -83,6 +83,9 @@ class GenericTransferManager : public TransferManager { Shape HostShapeToDeviceShape(const Shape& host_shape) const override; + absl::StatusOr ChooseCompactLayoutForShape( + const Shape& host_shape) const override; + private: // Transfer a memory block of the given size from the device source into the // 'destination' buffer. diff --git a/third_party/xla/xla/service/generic_transfer_manager_test.cc b/third_party/xla/xla/service/generic_transfer_manager_test.cc index eb8cb7afa85004..8153090dbcbf55 100644 --- a/third_party/xla/xla/service/generic_transfer_manager_test.cc +++ b/third_party/xla/xla/service/generic_transfer_manager_test.cc @@ -181,5 +181,13 @@ TEST_F(GenericTransferManagerTest, TransferLiteralFromDeviceInt4) { } } +TEST_F(GenericTransferManagerTest, ChooseCompactLayoutForShape) { + auto shape = ShapeUtil::MakeShape(S4, {2, 2}); + TF_ASSERT_OK_AND_ASSIGN(auto compact_shape, + transfer_manager_.ChooseCompactLayoutForShape(shape)); + EXPECT_TRUE(Shape::Equal().IgnoreLayout()(compact_shape, shape)); + EXPECT_EQ(compact_shape.layout().element_size_in_bits(), 4); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/global_device_id.h b/third_party/xla/xla/service/global_device_id.h index 78f4c0a3dc914a..92f30b9f1c11cb 100644 --- a/third_party/xla/xla/service/global_device_id.h +++ b/third_party/xla/xla/service/global_device_id.h @@ -20,7 +20,7 @@ limitations under the License. #include #include "absl/types/span.h" -#include "tsl/lib/gtl/int_type.h" +#include "xla/tsl/lib/gtl/int_type.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index cbe63c0f75718e..ff08fad968f9f5 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -73,7 +73,7 @@ tf_proto_library( protodeps = [ "//xla:xla_data_proto", "//xla:autotuning_proto", - "@local_tsl//tsl/protobuf:dnn_proto", + "//xla/tsl/protobuf:dnn_proto", ], ) @@ -114,10 +114,10 @@ cc_library( name = "gpu_memory_space_assignment", hdrs = ["gpu_memory_space_assignment.h"], deps = [ + "//xla/hlo/analysis:hlo_alias_analysis", + "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", - "//xla/service:hlo_alias_analysis", - "//xla/service:hlo_ordering", "//xla/service:hlo_value", "@com_google_absl//absl/status", ], @@ -155,17 +155,19 @@ xla_test( "//xla:status_macros", "//xla:test_helpers", "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/client/lib:constants", "//xla/ffi", "//xla/ffi:execution_context", "//xla/ffi:ffi_api", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder/lib:constants", "//xla/hlo/ir:hlo", "//xla/service:custom_call_status", "//xla/service:custom_call_target_registry", "//xla/service:executable", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:scratch_allocator", + "//xla/stream_executor:stream", "//xla/stream_executor/gpu:gpu_types_header", "//xla/tests:client_library_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep @@ -367,6 +369,8 @@ cc_library( "//xla/service/gpu/runtime:nccl_all_reduce_thunk", "//xla/service/gpu/runtime:nccl_all_to_all_thunk", "//xla/service/gpu/runtime:nccl_api", + "//xla/service/gpu/runtime:nccl_clique", + "//xla/service/gpu/runtime:nccl_clique_key", "//xla/service/gpu/runtime:nccl_collective_broadcast_thunk", "//xla/service/gpu/runtime:nccl_collective_permute_thunk", "//xla/service/gpu/runtime:nccl_collective_thunk", @@ -392,6 +396,7 @@ cc_library( "//xla/stream_executor:launch_dim", "//xla/stream_executor/gpu:gpu_blas_lt", "//xla/stream_executor/integrations:device_mem_allocator", + "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -419,7 +424,6 @@ cc_library( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:human_readable_json", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:dnn_proto_cc", "@triton//:TritonDialects", ] + if_gpu_is_configured([ "//xla/service/gpu/runtime:cholesky_thunk", @@ -530,7 +534,7 @@ cc_library( hdrs = ["buffer_allocations.h"], deps = [ "//xla/service:buffer_assignment", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", @@ -594,9 +598,8 @@ cc_library( "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/cuda:cuda_platform_id", - "//xla/stream_executor/gpu:gpu_executor_header", - "//xla/stream_executor/gpu:scoped_activate_context", "//xla/stream_executor/rocm:rocm_platform_id", + "//xla/stream_executor/sycl:sycl_platform_id", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -723,7 +726,7 @@ xla_cc_test( deps = [ ":reduction_utils", "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", + "//xla/hlo/parser:hlo_parser", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep "@com_google_absl//absl/strings", @@ -816,17 +819,15 @@ cc_library( hdrs = ["triton_fusion_analysis.h"], deps = [ ":cudnn_support_utils", - ":matmul_utils", + ":matmul_indexing_utils", ":triton_tiling_propagation", "//xla:autotuning_proto_cc", - "//xla:shape_util", "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", "//xla/service:instruction_fusion", - "//xla/tools:hlo_decomposer_lib", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -861,6 +862,7 @@ cc_library( hdrs = ["split_k_gemm_rewriter.h"], deps = [ ":ir_emission_utils", + ":matmul_indexing_utils", ":matmul_utils", ":triton_fusion_analysis", ":triton_tiling_propagation", @@ -915,6 +917,43 @@ xla_cc_test( ], ) +cc_library( + name = "matmul_indexing_utils", + srcs = ["matmul_indexing_utils.cc"], + hdrs = ["matmul_indexing_utils.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//xla:autotuning_proto_cc", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "matmul_indexing_utils_test", + srcs = ["matmul_indexing_utils_test.cc"], + deps = [ + ":matmul_indexing_utils", + "//xla:shape_util", + "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # build_cleaner: keep + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + ], +) + cc_library( name = "matmul_utils", srcs = ["matmul_utils.cc"], @@ -926,6 +965,7 @@ cc_library( deps = [ ":backend_configs_cc", ":ir_emission_utils", + ":matmul_indexing_utils", "//xla:autotuning_proto_cc", "//xla:shape_util", "//xla:status_macros", @@ -935,9 +975,11 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/mlir_hlo", "//xla/service:algorithm_util", - "//xla/stream_executor", "//xla/stream_executor:blas", + "//xla/stream_executor:device_memory", "//xla/stream_executor:numeric_options", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:gpu_blas_lt", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", @@ -959,7 +1001,6 @@ cc_library( "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", ]) + if_rocm_is_configured([ #keep sorted - "//xla/stream_executor/platform:dso_loader", "//xla/stream_executor/rocm:amdhipblaslt_plugin", "//xla/stream_executor/rocm:hipblas_lt_header", "@local_config_rocm//rocm:rocm_headers", @@ -976,7 +1017,7 @@ xla_cc_test( "//xla:shape_util", "//xla:test", "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", + "//xla/hlo/parser:hlo_parser", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # build_cleaner: keep "@com_google_absl//absl/strings", @@ -997,9 +1038,10 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/stream_executor:dnn", "//xla/stream_executor:lazy_op_runner", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -1026,9 +1068,10 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/stream_executor:dnn", "//xla/stream_executor:lazy_op_runner", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -1050,8 +1093,10 @@ cc_library( "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/stream_executor", "//xla/stream_executor:blas", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:gpu_stream", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -1111,7 +1156,7 @@ xla_cc_test( ":gpu_device_info_for_tests", "//xla:test", "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", + "//xla/hlo/parser:hlo_parser", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/tests:hlo_test_base", @@ -1148,7 +1193,7 @@ xla_cc_test( "//xla:test", "//xla:util", "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", + "//xla/hlo/parser:hlo_parser", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", @@ -1206,11 +1251,14 @@ cc_library( "//xla/service:generic_transfer_manager", "//xla/service:shaped_buffer", "//xla/service:transfer_manager", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/stream_executor:event", "//xla/stream_executor:memory_allocation", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/cuda:cuda_platform_id", "//xla/stream_executor/rocm:rocm_platform_id", + "//xla/stream_executor/sycl:sycl_platform_id", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", @@ -1239,6 +1287,7 @@ cc_library( "//xla/service:float_support", "//xla/service/gpu/fusions/triton:triton_support", "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:check", ], ) @@ -1266,18 +1315,18 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_dataflow_analysis", + "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/service:buffer_value", "//xla/service:dump", - "//xla/service:hlo_dataflow_analysis", - "//xla/service:hlo_ordering", "//xla/service:hlo_proto_cc", "//xla/service:logical_buffer", "//xla/service/gpu/runtime:sequential_thunk", "//xla/service/gpu/runtime:thunk", - "//xla/stream_executor", "//xla/stream_executor:device_description", + "//xla/stream_executor:platform", "//xla/stream_executor/rocm:rocm_platform_id", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", @@ -1299,6 +1348,25 @@ cc_library( ], ) +cc_library( + name = "fusion_dispatch_pipeline", + srcs = ["fusion_dispatch_pipeline.cc"], + hdrs = ["fusion_dispatch_pipeline.h"], + deps = [ + "//xla:shape_util", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/service:hlo_cost_analysis", + "//xla/service:pattern_matcher", + "//xla/service/gpu/transforms:fusion_block_level_rewriter", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + ], +) + cc_library( name = "fusion_pipeline", srcs = ["fusion_pipeline.cc"], @@ -1307,13 +1375,14 @@ cc_library( "//xla:xla_proto_cc", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/transforms:hlo_dce", "//xla/service:cpu_gpu_shape_verifier", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_cse", - "//xla/service:hlo_dce", "//xla/service:hlo_verifier", "//xla/service:layout_assignment", "//xla/service/gpu/model:gpu_hlo_cost_analysis", + "//xla/service/gpu/transforms:fusion_block_level_rewriter", "//xla/service/gpu/transforms:fusion_merger", "//xla/service/gpu/transforms:horizontal_input_fusion", "//xla/service/gpu/transforms:horizontal_loop_fusion", @@ -1332,12 +1401,12 @@ cc_library( hdrs = ["prepare_hlo_for_ir_emitting_pipeline.h"], deps = [ "//xla:xla_proto_cc", + "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/transforms:hlo_dce", "//xla/service:copy_insertion", "//xla/service:cpu_gpu_shape_verifier", - "//xla/service:hlo_dataflow_analysis", - "//xla/service:hlo_dce", "//xla/service:hlo_verifier", "//xla/service:layout_assignment", "//xla/service:loop_schedule_linearizer", @@ -1364,6 +1433,7 @@ cc_library( ":cublas_cudnn", ":executable_proto_cc", ":execution_stream_assignment", + ":fusion_dispatch_pipeline", ":fusion_pipeline", ":gpu_constants", ":gpu_executable", @@ -1405,10 +1475,69 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "//xla/hlo/analysis:hlo_dataflow_analysis", + "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/transforms:algebraic_simplifier", + "//xla/hlo/transforms:all_gather_broadcast_reorder", + "//xla/hlo/transforms:all_gather_combiner", + "//xla/hlo/transforms:all_reduce_combiner", + "//xla/hlo/transforms:all_reduce_contiguous", + "//xla/hlo/transforms:all_reduce_folder", + "//xla/hlo/transforms:async_collective_creator", + "//xla/hlo/transforms:bitcast_dtypes_expander", + "//xla/hlo/transforms:broadcast_canonicalizer", + "//xla/hlo/transforms:collective_quantizer", + "//xla/hlo/transforms:collectives_schedule_linearizer", + "//xla/hlo/transforms:comparison_expander", + "//xla/hlo/transforms:conditional_canonicalizer", + "//xla/hlo/transforms:convert_async_collectives_to_sync", + "//xla/hlo/transforms:convert_memory_placement_to_internal_annotations", + "//xla/hlo/transforms:convert_mover", + "//xla/hlo/transforms:convolution_4d_expander", + "//xla/hlo/transforms:convolution_pred_expander", + "//xla/hlo/transforms:dot_decomposer", + "//xla/hlo/transforms:dot_merger", + "//xla/hlo/transforms:dynamic_dimension_simplifier", + "//xla/hlo/transforms:dynamic_index_splitter", + "//xla/hlo/transforms:eigh_expander", + "//xla/hlo/transforms:flatten_call_graph", + "//xla/hlo/transforms:float_normalization", + "//xla/hlo/transforms:gather_simplifier", + "//xla/hlo/transforms:hlo_computation_deduplicator", + "//xla/hlo/transforms:hlo_constant_folding", + "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms:hlo_rematerialization", + "//xla/hlo/transforms:host_memory_transfer_asyncifier", + "//xla/hlo/transforms:host_offload_legalize", + "//xla/hlo/transforms:host_offloader", + "//xla/hlo/transforms:logistic_expander", + "//xla/hlo/transforms:operand_upcaster", + "//xla/hlo/transforms:optimization_barrier_expander", + "//xla/hlo/transforms:optimize_input_output_buffer_alias", + "//xla/hlo/transforms:qr_expander", + "//xla/hlo/transforms:real_imag_expander", + "//xla/hlo/transforms:reduce_decomposer", + "//xla/hlo/transforms:reduce_window_rewriter", + "//xla/hlo/transforms:reshape_decomposer", + "//xla/hlo/transforms:reshape_mover", + "//xla/hlo/transforms:result_caster", + "//xla/hlo/transforms:rng_bit_generator_expander", + "//xla/hlo/transforms:rng_expander", + "//xla/hlo/transforms:simplify_fp_conversions", + "//xla/hlo/transforms:slice_sinker", + "//xla/hlo/transforms:sort_simplifier", + "//xla/hlo/transforms:stable_sort_expander", + "//xla/hlo/transforms:stochastic_convert_decomposer", + "//xla/hlo/transforms:sub_byte_normalization", + "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/transforms:while_loop_trip_count_annotator", + "//xla/hlo/transforms:zero_sized_hlo_elimination", + "//xla/hlo/translate/hlo_to_mhlo:hlo_utils", + "//xla/hlo/translate/mhlo_to_hlo:location_exporter", "//xla/pjrt/distributed:key_value_store_interface", "//xla/service/gpu/autotuning:autotuner_util", "//xla/service/gpu/autotuning:custom_kernel_fusion_autotuner", @@ -1431,10 +1560,12 @@ cc_library( "//xla/service/gpu/transforms:convert_async_collectives_to_sync", "//xla/service/gpu/transforms:cudnn_custom_call_converter", "//xla/service/gpu/transforms:custom_kernel_fusion_rewriter", + "//xla/service/gpu/transforms:dot_algorithm_rewriter", "//xla/service/gpu/transforms:dot_dimension_sorter", "//xla/service/gpu/transforms:dot_operand_converter", "//xla/service/gpu/transforms:double_buffer_loop_unrolling", "//xla/service/gpu/transforms:dynamic_slice_fusion_rewriter", + "//xla/service/gpu/transforms:fusion_block_level_rewriter", "//xla/service/gpu/transforms:fusion_wrapper", "//xla/service/gpu/transforms:gemm_broadcast_folding_rewriter", "//xla/service/gpu/transforms:gemm_fusion", @@ -1463,105 +1594,49 @@ cc_library( "//xla/service/gpu/transforms:windowed_einsum_handler", "//xla/service/llvm_ir:llvm_util", "//xla/service/spmd:collective_permute_motion", - "//xla/service:algebraic_simplifier", - "//xla/service:all_gather_broadcast_reorder", - "//xla/service:all_gather_combiner", - "//xla/service:all_reduce_combiner", - "//xla/service:all_reduce_contiguous", - "//xla/service:all_reduce_folder", "//xla/service:all_reduce_promotion", "//xla/service:all_reduce_reassociate", - "//xla/service:async_collective_creator", + "//xla/service:all_reduce_simplifier", "//xla/service:batched_gather_scatter_normalizer", "//xla/service:batchnorm_expander", - "//xla/service:bitcast_dtypes_expander", - "//xla/service:broadcast_canonicalizer", "//xla/service:buffer_assignment", "//xla/service:buffer_value", "//xla/service:call_inliner", "//xla/service:collective_permute_decomposer", "//xla/service:collective_pipeliner", - "//xla/service:collective_quantizer", - "//xla/service:collectives_schedule_linearizer", - "//xla/service:comparison_expander", "//xla/service:compiler", - "//xla/service:conditional_canonicalizer", "//xla/service:conditional_simplifier", - "//xla/service:convert_async_collectives_to_sync", - "//xla/service:convert_memory_placement_to_internal_annotations", - "//xla/service:convert_mover", - "//xla/service:convolution_4d_expander", - "//xla/service:convolution_pred_expander", "//xla/service:copy_insertion", "//xla/service:cpu_gpu_shape_verifier", - "//xla/service:dot_decomposer", - "//xla/service:dot_merger", "//xla/service:dump", "//xla/service:dynamic_dimension_inference", - "//xla/service:dynamic_dimension_simplifier", - "//xla/service:dynamic_index_splitter", "//xla/service:dynamic_padder", - "//xla/service:eigh_expander", "//xla/service:executable", "//xla/service:export_hlo", - "//xla/service:flatten_call_graph", - "//xla/service:float_normalization", "//xla/service:float_support", "//xla/service:gather_expander", - "//xla/service:gather_simplifier", - "//xla/service:hlo_computation_deduplicator", - "//xla/service:hlo_constant_folding", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_cse", - "//xla/service:hlo_dataflow_analysis", - "//xla/service:hlo_dce", "//xla/service:hlo_module_config", - "//xla/service:hlo_ordering", "//xla/service:hlo_proto_cc", - "//xla/service:hlo_rematerialization", "//xla/service:hlo_verifier", - "//xla/service:host_memory_transfer_asyncifier", - "//xla/service:host_offload_legalize", - "//xla/service:host_offloader", "//xla/service:layout_assignment", "//xla/service:layout_normalization", "//xla/service:llvm_compiler", "//xla/service:logical_buffer", - "//xla/service:logistic_expander", "//xla/service:loop_schedule_linearizer", - "//xla/service:operand_upcaster", - "//xla/service:optimization_barrier_expander", - "//xla/service:optimize_input_output_buffer_alias", - "//xla/service:qr_expander", - "//xla/service:real_imag_expander", - "//xla/service:reduce_decomposer", "//xla/service:reduce_scatter_combiner", "//xla/service:reduce_scatter_reassociate", - "//xla/service:reduce_window_rewriter", - "//xla/service:reshape_decomposer", - "//xla/service:reshape_mover", - "//xla/service:result_caster", - "//xla/service:rng_bit_generator_expander", - "//xla/service:rng_expander", + "//xla/service:scatter_determinism_expander", "//xla/service:scatter_expander", "//xla/service:scatter_simplifier", "//xla/service:sharding_remover", - "//xla/service:simplify_fp_conversions", - "//xla/service:slice_sinker", "//xla/service:slow_operation_alarm", - "//xla/service:sort_simplifier", - "//xla/service:stable_sort_expander", - "//xla/service:stochastic_convert_decomposer", - "//xla/service:sub_byte_normalization", "//xla/service:topk_rewriter", "//xla/service:transpose_folding", - "//xla/service:tuple_simplifier", "//xla/service:while_loop_all_reduce_code_motion", "//xla/service:while_loop_constant_sinking", "//xla/service:while_loop_simplifier", - "//xla/service:while_loop_trip_count_annotator", - "//xla/service:zero_sized_hlo_elimination", - "//xla/stream_executor", "//xla/stream_executor/gpu:gpu_driver_header", "//xla/stream_executor/integrations:device_mem_allocator", "//xla/stream_executor:device_description", @@ -1569,8 +1644,6 @@ cc_library( "//xla/stream_executor:dnn", "//xla/stream_executor:platform_manager", "//xla/stream_executor:semantic_version", - "//xla/translate/hlo_to_mhlo:hlo_utils", - "//xla/translate/mhlo_to_hlo:location_exporter", "//xla/tsl/lib/monitoring:counter", "//xla:autotune_results_proto_cc", "//xla:debug_options_flags", @@ -1595,7 +1668,11 @@ cc_library( # go/keep-sorted end ]) + xla_internal(["service:export_hlo"]) + if_google([ "//xla/hlo/experimental/auto_sharding", - ]), + ]) + [ + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream_executor_h", + ], ) xla_test( @@ -1608,6 +1685,7 @@ xla_test( backends = ["gpu"], data = ["gpu_compiler_test_autotune_db.textproto"], deps = [ + ":backend_configs_cc", ":gpu_compiler", ":gpu_hlo_schedule", ":metrics", @@ -1632,8 +1710,12 @@ xla_test( "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", + "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/monitoring:collected_metrics", + "//xla/tsl/lib/monitoring:collection_registry", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", @@ -1661,11 +1743,11 @@ xla_test( "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/hlo/transforms:hlo_memory_scheduler", + "//xla/hlo/transforms:hlo_rematerialization", "//xla/hlo/utils:hlo_matchers", "//xla/service:buffer_value", "//xla/service:hlo_cost_analysis", - "//xla/service:hlo_memory_scheduler", - "//xla/service:hlo_rematerialization", "//xla/service/gpu/transforms:stream_attribute_annotator", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -1701,9 +1783,9 @@ cc_library( "nvptx_compiler_registration.cc", ], tags = [ + "cuda-only", "gpu", "manual", - "no_rocm", ], deps = [ ":nvptx_compiler_impl", @@ -1723,9 +1805,9 @@ cc_library( "nvptx_compiler.h", ], tags = [ + "cuda-only", "gpu", "manual", - "no_rocm", ], deps = [ ":buffer_sharing", @@ -1739,30 +1821,31 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", + "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/transforms:algebraic_simplifier", + "//xla/hlo/transforms:convert_mover", + "//xla/hlo/transforms:dot_dimension_merger", + "//xla/hlo/transforms:float_normalization", + "//xla/hlo/transforms:hlo_constant_folding", + "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms:reshape_mover", + "//xla/hlo/transforms:tuple_simplifier", "//xla/pjrt/distributed:key_value_store_interface", - "//xla/service:algebraic_simplifier", "//xla/service:call_inliner", - "//xla/service:convert_mover", - "//xla/service:dot_dimension_merger", "//xla/service:dump", - "//xla/service:float_normalization", "//xla/service:float_support", - "//xla/service:hlo_constant_folding", "//xla/service:hlo_cse", - "//xla/service:hlo_dataflow_analysis", - "//xla/service:hlo_dce", "//xla/service:hlo_module_config", "//xla/service:hlo_verifier", - "//xla/service:reshape_mover", - "//xla/service:tuple_simplifier", "//xla/service/gpu/autotuning:autotuner_util", "//xla/service/gpu/autotuning:conv_algorithm_picker", "//xla/service/gpu/autotuning:gemm_algorithm_picker", "//xla/service/gpu/autotuning:gemm_fusion_autotuner", "//xla/service/gpu/llvm_gpu_backend", + "//xla/service/gpu/llvm_gpu_backend:nvptx_utils", "//xla/service/gpu/transforms:algebraic_simplifier", "//xla/service/gpu/transforms:conv_padding_legalization", "//xla/service/gpu/transforms:conv_rewriter", @@ -1781,12 +1864,13 @@ cc_library( "//xla/service/gpu/transforms:sort_rewriter", "//xla/service/gpu/transforms:triangular_solve_rewriter", "//xla/service/llvm_ir:llvm_util", - "//xla/stream_executor", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:dnn", "//xla/stream_executor:semantic_version", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/cuda:cuda_asm_compiler", "//xla/stream_executor/cuda:cuda_diagnostics", - "//xla/stream_executor/cuda:cuda_driver", "//xla/stream_executor/cuda:cuda_platform_id", "//xla/stream_executor/cuda:nvjitlink", "//xla/stream_executor/cuda:nvjitlink_support", @@ -1835,7 +1919,7 @@ xla_test( "gpu_a100", ], tags = [ - "no_rocm", + "cuda-only", "nomsan", # Pulls in precompiled NVIDIA libraries which cause false positives in msan. ], deps = [ @@ -1845,12 +1929,12 @@ xla_test( ":nvptx_compiler_impl", "//xla:util", "//xla:xla_proto_cc", + "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", "//xla/service:backend", "//xla/service:buffer_assignment", "//xla/service:buffer_value", - "//xla/service:hlo_ordering", "//xla/service:logical_buffer", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", @@ -1873,7 +1957,7 @@ xla_test( "gpu", ], tags = [ - "no_rocm", + "cuda-only", "nomsan", # Pulls in precompiled NVIDIA libraries which cause false positives in msan. ], deps = [ @@ -1937,9 +2021,9 @@ xla_cc_test( "//xla/service:gpu_plugin", "//xla/service:platform_util", "//xla/service/gpu/fusions/triton:triton_support", - "//xla/stream_executor", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_h", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # build_cleaner: keep "@com_google_absl//absl/strings", @@ -1988,17 +2072,17 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/service:algebraic_simplifier", + "//xla/hlo/transforms:algebraic_simplifier", + "//xla/hlo/transforms:convert_mover", + "//xla/hlo/transforms:dot_dimension_merger", + "//xla/hlo/transforms:float_normalization", + "//xla/hlo/transforms:hlo_constant_folding", + "//xla/hlo/transforms:reshape_mover", + "//xla/hlo/transforms:tuple_simplifier", "//xla/service:call_inliner", - "//xla/service:convert_mover", - "//xla/service:dot_dimension_merger", - "//xla/service:float_normalization", "//xla/service:float_support", - "//xla/service:hlo_constant_folding", "//xla/service:hlo_module_config", "//xla/service:hlo_verifier", - "//xla/service:reshape_mover", - "//xla/service:tuple_simplifier", "//xla/service/gpu/autotuning:autotuner_util", "//xla/service/gpu/autotuning:conv_algorithm_picker", "//xla/service/gpu/autotuning:gemm_algorithm_picker", @@ -2084,10 +2168,10 @@ cc_library( "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/transforms:hlo_memory_scheduler", "//xla/hlo/utils:hlo_query", "//xla/service:buffer_value", "//xla/service:collective_ops_utils", - "//xla/service:hlo_memory_scheduler", "//xla/service:latency_hiding_scheduler", "//xla/service:p2p_schedule_preparation", "//xla/service:profile_guided_latency_estimator", @@ -2121,13 +2205,14 @@ xla_test( ], backends = ["gpu"], deps = [ + ":gpu_compiler", ":gpu_hlo_schedule", "//xla:shape_util", + "//xla/hlo/analysis:hlo_ordering", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", "//xla/service:backend", "//xla/service:hlo_module_config", - "//xla/service:hlo_ordering", "//xla/stream_executor:device_description", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", @@ -2154,10 +2239,10 @@ cc_library( deps = [ "//xla:util", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass_pipeline", "//xla/service:collective_ops_utils", "//xla/service:collective_pipeliner", - "//xla/service:hlo_parser", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -2173,9 +2258,9 @@ xla_cc_test( ":gpu_p2p_pipeliner", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass_pipeline", "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", "//xla/service:hlo_verifier", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", @@ -2195,18 +2280,18 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/transforms:algebraic_simplifier", + "//xla/hlo/transforms:hlo_constant_folding", "//xla/hlo/transforms:hlo_constant_splitter", - "//xla/service:algebraic_simplifier", + "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms:reshape_mover", + "//xla/hlo/transforms:sort_simplifier", + "//xla/hlo/transforms:tuple_simplifier", "//xla/service:conditional_simplifier", "//xla/service:gather_expander", - "//xla/service:hlo_constant_folding", - "//xla/service:hlo_dce", "//xla/service:hlo_module_config", - "//xla/service:reshape_mover", "//xla/service:scatter_expander", "//xla/service:sharding_propagation", - "//xla/service:sort_simplifier", - "//xla/service:tuple_simplifier", "//xla/service:while_loop_constant_sinking", "//xla/service:while_loop_simplifier", "//xla/service/gpu/transforms:algebraic_simplifier", @@ -2230,10 +2315,11 @@ xla_cc_test( "//xla:util", "//xla/client:executable_build_options", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/service:algebraic_simplifier", + "//xla/hlo/transforms:algebraic_simplifier", "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", + "//xla/service/spmd/shardy:constants", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -2257,18 +2343,27 @@ xla_cc_test( "//xla:shape_util", "//xla:test", "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:while_loop_analysis", "//xla/hlo/ir:hlo", - "//xla/service:while_loop_analysis", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", ], ) -cuda_library( +cc_library( + name = "stream_executor_util_kernel_stub", + srcs = ["stream_executor_util_kernel_stub.cc"], +) + +gpu_kernel_library( name = "stream_executor_util_kernel", srcs = ["stream_executor_util_kernel.cu.cc"], - tags = ["no_rocm"], - deps = ["@local_config_cuda//cuda:cuda_headers"], + tags = ["gpu"], + deps = if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_headers", + ]), ) cc_library( @@ -2276,7 +2371,6 @@ cc_library( srcs = ["stream_executor_util.cc"], hdrs = ["stream_executor_util.h"], copts = tsl_copts(), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":cublas_cudnn", ":launch_dimensions", @@ -2286,12 +2380,17 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:hlo_module_config", - "//xla/stream_executor", "//xla/stream_executor:data_type", + "//xla/stream_executor:device_memory", "//xla/stream_executor:dnn", + "//xla/stream_executor:kernel", "//xla/stream_executor:kernel_spec", "//xla/stream_executor:launch_dim", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:typed_kernel_factory", + "//xla/tsl/protobuf:dnn_proto_cc", "//xla/tsl/util:env_var", "//xla/tsl/util/proto:proto_utils", "@com_google_absl//absl/algorithm:container", @@ -2308,10 +2407,10 @@ cc_library( "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:dnn_proto_cc", - ] + if_cuda_is_configured([ - ":stream_executor_util_kernel", - ]), + ] + if_gpu_is_configured( + if_false = [":stream_executor_util_kernel_stub"], + if_true = [":stream_executor_util_kernel"], + ), ) xla_cc_test( @@ -2400,7 +2499,6 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla/service:hlo_module_config", - "//xla/stream_executor", "//xla/stream_executor:device_memory_handle", "//xla/stream_executor:typed_kernel_factory", "//xla/stream_executor/gpu:asm_compiler", @@ -2415,7 +2513,13 @@ cc_library( ]) + if_rocm_is_configured([ # keep sorted "@local_config_rocm//rocm:rocm_headers", - ]), + ]) + [ + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:kernel", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", + ], ) gpu_kernel_library( @@ -2444,10 +2548,11 @@ xla_test( "//xla:shape_util", "//xla:types", "//xla/service:hlo_module_config", - "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:device_memory_handle", + "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream", "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:test", @@ -2493,8 +2598,8 @@ cc_library( "//xla:permutation_util", "//xla:shape_util", "//xla:util", + "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", - "//xla/service:hlo_dataflow_analysis", "//xla/service:instruction_fusion", "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", @@ -2517,7 +2622,7 @@ xla_cc_test( deps = [ ":gpu_fusible", "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", + "//xla/hlo/parser:hlo_parser", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", @@ -2610,13 +2715,16 @@ xla_cc_test( "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:float_normalization", + "//xla/hlo/transforms:float_normalization", "//xla/service:hlo_verifier", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", ], ) @@ -2628,6 +2736,21 @@ cc_library( "//xla/tsl/lib/monitoring:counter", "//xla/tsl/lib/monitoring:gauge", "//xla/tsl/lib/monitoring:sampler", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:stacktrace", + ], +) + +xla_cc_test( + name = "metrics_test", + srcs = ["metrics_test.cc"], + deps = [ + ":metrics", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/monitoring:collected_metrics", + "//xla/tsl/lib/monitoring:collection_registry", + "@local_tsl//tsl/platform:test", ], ) @@ -2638,8 +2761,11 @@ cc_library( deps = [ "//xla:types", "//xla:util", - "//xla/stream_executor", "//xla/stream_executor:device_memory", + "//xla/stream_executor:kernel", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:typed_kernel_factory", "//xla/stream_executor/gpu:gpu_stream_header", "@com_google_absl//absl/status", @@ -2673,7 +2799,10 @@ tsl_gpu_library( "//xla/service:custom_call_status", "//xla/service:custom_call_target_registry", "//xla/service:platform_util", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream", "//xla/stream_executor:stream_finder", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -2720,7 +2849,7 @@ xla_cc_test( ], deps = [ ":hlo_fusion_stats", - "//xla/service:hlo_parser", + "//xla/hlo/parser:hlo_parser", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", @@ -2877,13 +3006,21 @@ xla_test( "//xla:literal", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/service:backend", + "//xla/service:platform_util", "//xla/service/gpu/autotuning:autotuner_util", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", - "//xla/stream_executor/gpu:gpu_timer", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor/gpu:mock_gpu_executor", + "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", "//xla/tests:test_utils", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", @@ -2997,3 +3134,121 @@ xla_cc_test( "@local_tsl//tsl/platform:test", ], ) + +cc_library( + name = "gpu_collective_combiner_utils", + srcs = ["gpu_collective_combiner_utils.cc"], + hdrs = ["gpu_collective_combiner_utils.h"], + deps = [ + ":backend_configs_cc", + "//xla/hlo/ir:hlo", + "//xla/service:collective_utils", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "gpu_collective_combiner_utils_test", + srcs = ["gpu_collective_combiner_utils_test.cc"], + deps = [ + ":backend_configs_cc", + ":gpu_collective_combiner_utils", + ":gpu_hlo_schedule", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/utils:hlo_query", + "//xla/service:collective_pipeliner", + "//xla/service:collective_utils", + "//xla/service:hlo_module_config", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "gpu_all_gather_combiner", + srcs = ["all_gather_combiner.cc"], + hdrs = ["all_gather_combiner.h"], + deps = [ + ":backend_configs_cc", + ":gpu_collective_combiner_utils", + ":gpu_hlo_schedule", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/transforms:all_gather_combiner", + "//xla/service:hlo_domain_map", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = + "gpu_all_gather_combiner_test", + srcs = ["all_gather_combiner_test.cc"], + deps = [ + ":gpu_all_gather_combiner", + "//xla/hlo/ir:hlo", + "//xla/service:collective_utils", + "//xla/stream_executor:device_description", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "reduce_scatter_combiner", + srcs = ["reduce_scatter_combiner.cc"], + hdrs = ["reduce_scatter_combiner.h"], + deps = [ + ":backend_configs_cc", + ":gpu_collective_combiner_utils", + ":gpu_hlo_schedule", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:hlo_domain_map", + "//xla/service:reduce_scatter_combiner", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "reduce_scatter_combiner_test", + srcs = ["reduce_scatter_combiner_test.cc"], + deps = [ + ":reduce_scatter_combiner", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:filecheck", + "//xla/service:collective_utils", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) diff --git a/third_party/xla/xla/service/gpu/all_gather_combiner.cc b/third_party/xla/xla/service/gpu/all_gather_combiner.cc new file mode 100644 index 00000000000000..17acef2a2623c7 --- /dev/null +++ b/third_party/xla/xla/service/gpu/all_gather_combiner.cc @@ -0,0 +1,95 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/all_gather_combiner.h" + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/transforms/collectives/all_gather_combiner.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/gpu_collective_combiner_utils.h" +#include "xla/service/gpu/gpu_hlo_schedule.h" +#include "xla/service/hlo_domain_map.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { + +namespace { + +std::optional PipelinedCombinerKey( + const HloInstruction* instruction, const HloDomainMap& domain_map, + bool combine_by_dim, bool combine_different_dtypes) { + auto combined_key = AllGatherCombiner::CombineKey( + instruction, domain_map, combine_by_dim, combine_different_dtypes); + if (!combined_key.has_value()) { + return std::nullopt; + } + auto backend_config = instruction->backend_config(); + if (!backend_config.ok()) { + return std::nullopt; + } + bool is_pipelined = + backend_config->collective_backend_config().is_pipelined(); + if (!is_pipelined) { + return std::nullopt; + } + AllGatherCombiner::GetGroupKeyExtraArgs(*combined_key) + .append(" " + std::to_string(static_cast(is_pipelined))); + return combined_key.value(); +} + +} // namespace + +absl::StatusOr GpuAllGatherCombiner::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + // Combiner threshold is specified. Running parent pass code. + if (combine_threshold_in_bytes_ != default_combine_threshold_in_bytes_) { + return AllGatherCombiner::Run(module, execution_threads); + } + + // Pass configuration heuristics are not enabled. Running parent pass code. + if (!module->config() + .debug_options() + .xla_gpu_enable_heuristic_pass_configuration()) { + return AllGatherCombiner::Run(module, execution_threads); + } + + // Combine as much as possible for pipelined collectives. + int previous_combiner_threshold = combine_threshold_in_bytes_; + combine_threshold_in_bytes_ = ComputeSuggestedCombinerThreshold( + *module, device_info_, ScheduleGpuModuleWithMemoryScheduler, + HloOpcode::kAllGather, pointer_size_); + TF_ASSIGN_OR_RETURN( + bool combined_pipelined_instructions, + RunWithKeyCombiner(module, execution_threads, PipelinedCombinerKey)); + + // Use previous combiner thresholds after we combine pipelined collectives. + // The rest is combined by the parent pass code. + combine_threshold_in_bytes_ = previous_combiner_threshold; + TF_ASSIGN_OR_RETURN(bool combined_rest, + AllGatherCombiner::Run(module, execution_threads)); + return combined_pipelined_instructions || combined_rest; +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/all_gather_combiner.h b/third_party/xla/xla/service/gpu/all_gather_combiner.h new file mode 100644 index 00000000000000..5ed8af754b5a79 --- /dev/null +++ b/third_party/xla/xla/service/gpu/all_gather_combiner.h @@ -0,0 +1,63 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_ALL_GATHER_COMBINER_H_ +#define XLA_SERVICE_GPU_ALL_GATHER_COMBINER_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/hlo/transforms/collectives/all_gather_combiner.h" +#include "xla/stream_executor/device_description.h" + +namespace xla::gpu { + +// Similarly to `AllGatherCombiner` pass, combines `AllGather` ops into a single +// larger `AllGather` op to maximize network bandwidth usage. Additionally, if +// no flags are set for combiner thresholds, the pass will try to figure out the +// optimal combiner threshold by itself. +class GpuAllGatherCombiner : public AllGatherCombiner { + public: + GpuAllGatherCombiner(const se::DeviceDescription& device_info, + const int default_combine_threshold_in_bytes, + int64_t combine_threshold_in_bytes, + int64_t combine_threshold_count, bool combine_by_dim, + bool combine_different_dtypes, int64_t pointer_size) + : AllGatherCombiner(combine_threshold_in_bytes, combine_threshold_count, + combine_by_dim, combine_different_dtypes), + device_info_(device_info), + default_combine_threshold_in_bytes_(default_combine_threshold_in_bytes), + pointer_size_(pointer_size) {} + + absl::string_view name() const override { return "gpu-all-gather-combiner"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + const se::DeviceDescription& device_info_; + const int default_combine_threshold_in_bytes_; + int64_t pointer_size_; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_ALL_GATHER_COMBINER_H_ diff --git a/third_party/xla/xla/service/gpu/all_gather_combiner_test.cc b/third_party/xla/xla/service/gpu/all_gather_combiner_test.cc new file mode 100644 index 00000000000000..b819bdfbb3ac85 --- /dev/null +++ b/third_party/xla/xla/service/gpu/all_gather_combiner_test.cc @@ -0,0 +1,356 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/all_gather_combiner.h" + +#include +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/collective_utils.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/filecheck.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla::gpu { +namespace { + +using GpuAllGatherCombinerTest = HloTestBase; + +using ::stream_executor::DeviceDescription; + +TEST_F(GpuAllGatherCombinerTest, + CombinesPipelinedCollectivesUpToSuggestedThreshold) { + // The IR is the minimal valid example of a while loop with AG inside. Three + // are annotated as pipelined and three are not. Various configurations of the + // combiner are tested to ensure the expected behaviour. + constexpr absl::string_view kHloString = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], + bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(8) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], + bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0) + param.0 = s32[] get-tuple-element(param), index=0 + param.pipelined.0 = bf16[6,8,128] get-tuple-element(param), index=1 + param.pipelined.1 = bf16[6,8,128] get-tuple-element(param), index=2 + param.pipelined.2 = bf16[6,8,128] get-tuple-element(param), index=3 + param.nonpipelined.0 = bf16[6,8,128] get-tuple-element(param), index=4 + param.nonpipelined.1 = bf16[6,8,128] get-tuple-element(param), index=5 + param.nonpipelined.2 = bf16[6,8,128] get-tuple-element(param), index=6 + param.7 = bf16[3,1,2,128] get-tuple-element(param), index=7 + zero = bf16[] constant(0) + one = s32[] constant(1) + it = s32[] add(param.0, one) + ag.pipelined.0 = bf16[6,8,128] all-gather(param.pipelined.0), dimensions={0}, + backend_config={"collective_backend_config": {"is_pipelined": true}} + ag.pipelined.1 = bf16[6,8,128] all-gather(param.pipelined.1), dimensions={0}, + backend_config={"collective_backend_config": {"is_pipelined": true}} + ag.pipelined.2 = bf16[6,8,128] all-gather(param.pipelined.2), dimensions={0}, + backend_config={"collective_backend_config": {"is_pipelined": true}} + ag.nonpipelined.0 = bf16[6,8,128] all-gather(param.nonpipelined.0), + dimensions={0} + ag.nonpipelined.1 = bf16[6,8,128] all-gather(param.nonpipelined.1), + dimensions={0} + ag.nonpipelined.2 = bf16[6,8,128] all-gather(param.nonpipelined.2), + dimensions={0} + ROOT tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(it, ag.pipelined.0, ag.pipelined.1, ag.pipelined.2, ag.nonpipelined.0, ag.nonpipelined.1, ag.nonpipelined.2, param.7) +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[6,8,128] parameter(0) + p1 = bf16[3,1,2,128] parameter(1) + tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(c0, p0, p0, p0, p0, p0, p0, p1) + while = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) while(tuple), condition=while_cond, body=while_body + ROOT _ = bf16[6,8,128] get-tuple-element(while), index=1 +} +)"; + auto config = + GetModuleConfigForTest(/*replica_count=*/1, /*num_partitions=*/2); + config.mutable_debug_options() + .set_xla_gpu_enable_heuristic_pass_configuration(true); + DeviceDescription device_info; + // Combine at most 2 collectives. + int collective_size = 2 * 6 * 8 * 128; + int threshold_bytes = 2 * collective_size; + int current_peak_mem = 90604; + int pointer_size = 4; + device_info.set_device_memory_size(current_peak_mem + threshold_bytes * 4); + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloString, config)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + GpuAllGatherCombiner(device_info, /*default_combine_threshold_in_bytes=*/ + threshold_bytes, + /*combine_threshold_in_bytes=*/threshold_bytes, + /*combine_threshold_count=*/256, + /*combine_by_dim=*/false, + /*combine_different_dtypes=*/true, pointer_size) + .Run(module.get())); + + VLOG(1) << module->ToString(); + EXPECT_TRUE(changed); + // Pipelined all gathers were combined up to the predefined max available + // device mem limit. + const absl::string_view kExpected = R"( + // CHECK-DAG: %[[PIPELINED_PARAM_0:.*]] = {{.*}} index=1 + // CHECK-DAG: %[[PIPELINED_PARAM_1:.*]] = {{.*}} index=2 + // CHECK-DAG: %[[PIPELINED_PARAM_2:.*]] = {{.*}} index=3 + // CHECK-DAG: %[[NONPIPELINED_PARAM_0:.*]] = {{.*}} index=4 + // CHECK-DAG: %[[NONPIPELINED_PARAM_1:.*]] = {{.*}} index=5 + // CHECK-DAG: %[[NONPIPELINED_PARAM_2:.*]] = {{.*}} index=6 + // CHECK-DAG: all-gather(%[[PIPELINED_PARAM_0]], %[[PIPELINED_PARAM_1]], %[[PIPELINED_PARAM_2]]) + // CHECK-DAG: all-gather(%[[NONPIPELINED_PARAM_0]], %[[NONPIPELINED_PARAM_1]]) + // CHECK-DAG: all-gather(%[[NONPIPELINED_PARAM_2]]) + )"; + EXPECT_TRUE( + *RunFileCheck(module->ToString(HloPrintOptions() + .set_print_operand_shape(false) + .set_print_result_shape(false)), + kExpected)); +} + +TEST_F(GpuAllGatherCombinerTest, CombinesCollectivesUpToSpecifiedThreshold) { + // The IR is the minimal valid example of a while loop with AG inside. Three + // are annotated as pipelined and three are not. Various configurations of the + // combiner are tested to ensure the expected behaviour. + constexpr absl::string_view kHloString = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], + bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(8) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], + bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0) + param.0 = s32[] get-tuple-element(param), index=0 + param.pipelined.0 = bf16[6,8,128] get-tuple-element(param), index=1 + param.pipelined.1 = bf16[6,8,128] get-tuple-element(param), index=2 + param.pipelined.2 = bf16[6,8,128] get-tuple-element(param), index=3 + param.nonpipelined.0 = bf16[6,8,128] get-tuple-element(param), index=4 + param.nonpipelined.1 = bf16[6,8,128] get-tuple-element(param), index=5 + param.nonpipelined.2 = bf16[6,8,128] get-tuple-element(param), index=6 + param.7 = bf16[3,1,2,128] get-tuple-element(param), index=7 + zero = bf16[] constant(0) + one = s32[] constant(1) + it = s32[] add(param.0, one) + ag.pipelined.0 = bf16[6,8,128] all-gather(param.pipelined.0), dimensions={0}, + backend_config={"collective_backend_config": {"is_pipelined": true}} + ag.pipelined.1 = bf16[6,8,128] all-gather(param.pipelined.1), dimensions={0}, + backend_config={"collective_backend_config": {"is_pipelined": true}} + ag.pipelined.2 = bf16[6,8,128] all-gather(param.pipelined.2), dimensions={0}, + backend_config={"collective_backend_config": {"is_pipelined": true}} + ag.nonpipelined.0 = bf16[6,8,128] all-gather(param.nonpipelined.0), + dimensions={0} + ag.nonpipelined.1 = bf16[6,8,128] all-gather(param.nonpipelined.1), + dimensions={0} + ag.nonpipelined.2 = bf16[6,8,128] all-gather(param.nonpipelined.2), + dimensions={0} + ROOT tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(it, ag.pipelined.0, ag.pipelined.1, ag.pipelined.2, ag.nonpipelined.0, ag.nonpipelined.1, ag.nonpipelined.2, param.7) +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[6,8,128] parameter(0) + p1 = bf16[3,1,2,128] parameter(1) + tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(c0, p0, p0, p0, p0, p0, p0, p1) + while = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) while(tuple), condition=while_cond, body=while_body + ROOT _ = bf16[6,8,128] get-tuple-element(while), index=1 +} +)"; + auto config = + GetModuleConfigForTest(/*replica_count=*/1, /*num_partitions=*/2); + config.mutable_debug_options() + .set_xla_gpu_enable_heuristic_pass_configuration(true); + DeviceDescription device_info; + // Combine at most 2 collectives. + int collective_size = 2 * 6 * 8 * 128; + int threshold_bytes = 2 * collective_size; + int current_peak_mem = 90604; + int pointer_size = 4; + device_info.set_device_memory_size(current_peak_mem + threshold_bytes * 4); + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloString, config)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + GpuAllGatherCombiner(device_info, /*default_combine_threshold_in_bytes=*/ + kDefaultAllGatherCombineThreshold, + /*combine_threshold_in_bytes=*/threshold_bytes, + /*combine_threshold_count=*/256, + /*combine_by_dim=*/false, + /*combine_different_dtypes=*/true, pointer_size) + .Run(module.get())); + + VLOG(1) << module->ToString(); + EXPECT_TRUE(changed); + // Pipelined all gathers were combined up to the predefined max available + // device mem limit. + const absl::string_view kExpected = R"( + // CHECK-DAG: %[[PIPELINED_PARAM_0:.*]] = {{.*}} index=1 + // CHECK-DAG: %[[PIPELINED_PARAM_1:.*]] = {{.*}} index=2 + // CHECK-DAG: %[[PIPELINED_PARAM_2:.*]] = {{.*}} index=3 + // CHECK-DAG: %[[NONPIPELINED_PARAM_0:.*]] = {{.*}} index=4 + // CHECK-DAG: %[[NONPIPELINED_PARAM_1:.*]] = {{.*}} index=5 + // CHECK-DAG: %[[NONPIPELINED_PARAM_2:.*]] = {{.*}} index=6 + // CHECK-DAG: all-gather(%[[PIPELINED_PARAM_0]], %[[PIPELINED_PARAM_1]]) + // CHECK-DAG: all-gather(%[[PIPELINED_PARAM_2]], %[[NONPIPELINED_PARAM_0]]) + // CHECK-DAG: all-gather(%[[NONPIPELINED_PARAM_1]], %[[NONPIPELINED_PARAM_2]]) + )"; + + EXPECT_TRUE( + *RunFileCheck(module->ToString(HloPrintOptions() + .set_print_operand_shape(false) + .set_print_result_shape(false)), + kExpected)); +} + +TEST_F(GpuAllGatherCombinerTest, + CombinesCollectivesUpToDefaultThresholdIfFlagDisabled) { + // The IR is the minimal valid example of a while loop with AG inside. Three + // are annotated as pipelined and three are not. Various configurations of the + // combiner are tested to ensure the expected behaviour. + constexpr absl::string_view kHloString = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], + bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(8) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], + bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0) + param.0 = s32[] get-tuple-element(param), index=0 + param.pipelined.0 = bf16[6,8,128] get-tuple-element(param), index=1 + param.pipelined.1 = bf16[6,8,128] get-tuple-element(param), index=2 + param.pipelined.2 = bf16[6,8,128] get-tuple-element(param), index=3 + param.nonpipelined.0 = bf16[6,8,128] get-tuple-element(param), index=4 + param.nonpipelined.1 = bf16[6,8,128] get-tuple-element(param), index=5 + param.nonpipelined.2 = bf16[6,8,128] get-tuple-element(param), index=6 + param.7 = bf16[3,1,2,128] get-tuple-element(param), index=7 + zero = bf16[] constant(0) + one = s32[] constant(1) + it = s32[] add(param.0, one) + ag.pipelined.0 = bf16[6,8,128] all-gather(param.pipelined.0), dimensions={0}, + backend_config={"collective_backend_config": {"is_pipelined": true}} + ag.pipelined.1 = bf16[6,8,128] all-gather(param.pipelined.1), dimensions={0}, + backend_config={"collective_backend_config": {"is_pipelined": true}} + ag.pipelined.2 = bf16[6,8,128] all-gather(param.pipelined.2), dimensions={0}, + backend_config={"collective_backend_config": {"is_pipelined": true}} + ag.nonpipelined.0 = bf16[6,8,128] all-gather(param.nonpipelined.0), + dimensions={0} + ag.nonpipelined.1 = bf16[6,8,128] all-gather(param.nonpipelined.1), + dimensions={0} + ag.nonpipelined.2 = bf16[6,8,128] all-gather(param.nonpipelined.2), + dimensions={0} + ROOT tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(it, ag.pipelined.0, ag.pipelined.1, ag.pipelined.2, ag.nonpipelined.0, ag.nonpipelined.1, ag.nonpipelined.2, param.7) +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[6,8,128] parameter(0) + p1 = bf16[3,1,2,128] parameter(1) + tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(c0, p0, p0, p0, p0, p0, p0, p1) + while = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) while(tuple), condition=while_cond, body=while_body + ROOT _ = bf16[6,8,128] get-tuple-element(while), index=1 +} +)"; + auto config = + GetModuleConfigForTest(/*replica_count=*/1, /*num_partitions=*/2); + config.mutable_debug_options() + .set_xla_gpu_enable_heuristic_pass_configuration(false); + DeviceDescription device_info; + // Combine at most 2 collectives. + int collective_size = 2 * 6 * 8 * 128; + int threshold_bytes = 2 * collective_size; + int current_peak_mem = 90604; + int pointer_size = 4; + device_info.set_device_memory_size(current_peak_mem + threshold_bytes * 4); + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloString, config)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + GpuAllGatherCombiner(device_info, /*default_combine_threshold_in_bytes=*/ + threshold_bytes, + /*combine_threshold_in_bytes=*/threshold_bytes, + /*combine_threshold_count=*/256, + /*combine_by_dim=*/false, + /*combine_different_dtypes=*/true, pointer_size) + .Run(module.get())); + + VLOG(1) << module->ToString(); + EXPECT_TRUE(changed); + // Pipelined all gathers were combined up to the predefined max available + // device mem limit. + const absl::string_view kExpected = R"( + // CHECK-DAG: %[[PIPELINED_PARAM_0:.*]] = {{.*}} index=1 + // CHECK-DAG: %[[PIPELINED_PARAM_1:.*]] = {{.*}} index=2 + // CHECK-DAG: %[[PIPELINED_PARAM_2:.*]] = {{.*}} index=3 + // CHECK-DAG: %[[NONPIPELINED_PARAM_0:.*]] = {{.*}} index=4 + // CHECK-DAG: %[[NONPIPELINED_PARAM_1:.*]] = {{.*}} index=5 + // CHECK-DAG: %[[NONPIPELINED_PARAM_2:.*]] = {{.*}} index=6 + // CHECK-DAG: all-gather(%[[PIPELINED_PARAM_0]], %[[PIPELINED_PARAM_1]]) + // CHECK-DAG: all-gather(%[[PIPELINED_PARAM_2]], %[[NONPIPELINED_PARAM_0]]) + // CHECK-DAG: all-gather(%[[NONPIPELINED_PARAM_1]], %[[NONPIPELINED_PARAM_2]]) + )"; + EXPECT_TRUE( + *RunFileCheck(module->ToString(HloPrintOptions() + .set_print_operand_shape(false) + .set_print_result_shape(false)), + kExpected)); +} + +} // namespace + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/amdgpu_compiler.cc b/third_party/xla/xla/service/gpu/amdgpu_compiler.cc index b26bc359c3dcbb..0406274f03b925 100644 --- a/third_party/xla/xla/service/gpu/amdgpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/amdgpu_compiler.cc @@ -28,11 +28,14 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/pass/hlo_pass_fix.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" -#include "xla/service/algebraic_simplifier.h" +#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" +#include "xla/hlo/transforms/simplifiers/convert_mover.h" +#include "xla/hlo/transforms/simplifiers/dot_dimension_merger.h" +#include "xla/hlo/transforms/simplifiers/float_normalization.h" +#include "xla/hlo/transforms/simplifiers/hlo_constant_folding.h" +#include "xla/hlo/transforms/simplifiers/reshape_mover.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/service/call_inliner.h" -#include "xla/service/convert_mover.h" -#include "xla/service/dot_dimension_merger.h" -#include "xla/service/float_normalization.h" #include "xla/service/float_support.h" #include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/autotuning/conv_algorithm_picker.h" @@ -49,13 +52,9 @@ limitations under the License. #include "xla/service/gpu/transforms/gpusolver_rewriter.h" #include "xla/service/gpu/transforms/sort_rewriter.h" #include "xla/service/gpu/transforms/triangular_solve_rewriter.h" -#include "xla/service/hlo_constant_folding.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_verifier.h" -#include "xla/service/reshape_mover.h" -#include "xla/service/tuple_simplifier.h" #include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/semantic_version.h" @@ -100,7 +99,6 @@ class ConvBfloat16Support : public FloatSupport { absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( HloModule* hlo_module, se::GpuComputeCapability gpu_version, se::dnn::VersionInfo dnn_version, - se::DeviceMemoryAllocator* device_allocator, const se::SemanticVersion& toolkit_version) { // Convert convolutions into CustomCalls to MIOpen, then canonicalize them // (PadInsertion). diff --git a/third_party/xla/xla/service/gpu/amdgpu_compiler.h b/third_party/xla/xla/service/gpu/amdgpu_compiler.h index cf1248fd6f86d8..89af84cbe1d334 100644 --- a/third_party/xla/xla/service/gpu/amdgpu_compiler.h +++ b/third_party/xla/xla/service/gpu/amdgpu_compiler.h @@ -25,7 +25,6 @@ limitations under the License. #include "xla/service/gpu/gpu_compiler.h" #include "xla/service/hlo_module_config.h" #include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/semantic_version.h" #include "xla/stream_executor/stream_executor.h" @@ -43,7 +42,6 @@ class AMDGPUCompiler : public GpuCompiler { absl::Status OptimizeHloConvolutionCanonicalization( HloModule* hlo_module, se::GpuComputeCapability gpu_version, se::dnn::VersionInfo dnn_version, - se::DeviceMemoryAllocator* device_allocator, const se::SemanticVersion& toolkit_version) override; absl::Status OptimizeHloPostLayoutAssignment( diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD index be63f3888442af..2e21ea0b3cd02c 100644 --- a/third_party/xla/xla/service/gpu/autotuning/BUILD +++ b/third_party/xla/xla/service/gpu/autotuning/BUILD @@ -30,8 +30,8 @@ cc_library( srcs = ["gemm_fusion_autotuner.cc"], hdrs = ["gemm_fusion_autotuner.h"], tags = [ + "cuda-only", "gpu", - "no_rocm", ], deps = [ ":autotuner_compile_util", @@ -45,12 +45,14 @@ cc_library( "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/transforms:float_normalization", "//xla/hlo/utils:hlo_query", "//xla/pjrt/distributed:key_value_store_interface", "//xla/service:algorithm_util", + "//xla/service:call_inliner", "//xla/service:dump", "//xla/service:executable", - "//xla/service:float_normalization", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_graph_dumper", "//xla/service:hlo_module_config", @@ -60,18 +62,24 @@ cc_library( "//xla/service/gpu:gpu_float_support", "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:matmul_indexing_utils", "//xla/service/gpu:matmul_utils", "//xla/service/gpu:split_k_gemm_rewriter", "//xla/service/gpu:stream_executor_util", + "//xla/service/gpu/kernels:custom_kernel", + "//xla/service/gpu/kernels:custom_kernel_fusion", + "//xla/service/gpu/kernels:custom_kernel_fusion_pattern", "//xla/service/gpu/transforms:cudnn_fusion_compiler", + "//xla/service/gpu/transforms:custom_kernel_fusion_rewriter", + "//xla/service/gpu/transforms:dot_algorithm_rewriter", "//xla/service/gpu/transforms:fusion_wrapper", "//xla/service/gpu/transforms:gemm_rewriter", "//xla/service/gpu/transforms:priority_fusion", - "//xla/stream_executor", "//xla/stream_executor:device_description", "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:semantic_version", - "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/stream_executor:stream", "//xla/stream_executor/gpu:redzone_allocator", "//xla/tools:hlo_decomposer_lib", "//xla/tsl/lib/core:bits", @@ -109,8 +117,8 @@ xla_test( "gpu_h100", ], tags = [ + "cuda-only", "no_mac", - "no_rocm", ], deps = [ ":autotuner_compile_util", @@ -137,6 +145,7 @@ xla_test( "//xla/stream_executor:device_description", "//xla/stream_executor:device_description_proto_cc", "//xla/stream_executor:semantic_version", + "//xla/stream_executor:stream_executor_h", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:test_utils", @@ -146,6 +155,7 @@ xla_test( "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:env", @@ -180,10 +190,11 @@ cc_library( "//xla/service/gpu:matmul_utils", "//xla/service/gpu:stream_executor_util", "//xla/service/gpu:variant_visitor", - "//xla/stream_executor", "//xla/stream_executor:blas", + "//xla/stream_executor:device_description", "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:redzone_allocator", "//xla/tsl/util/proto:proto_utils", "@com_google_absl//absl/container:flat_hash_set", @@ -265,7 +276,9 @@ cc_library( "//xla/service:shaped_buffer", "//xla/service/gpu:gpu_executable_run_options", "//xla/service/gpu:ir_emission_utils", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream", "//xla/stream_executor/gpu:redzone_allocator", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log:check", @@ -320,12 +333,12 @@ xla_test( "//xla/stream_executor:semantic_version", "//xla/tests:hlo_test_base", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/log", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/protobuf:dnn_proto_cc", ], ) @@ -358,12 +371,16 @@ cc_library( "//xla/service/gpu:gpu_conv_runner", "//xla/service/gpu:hlo_algorithm_denylist", "//xla/service/gpu:stream_executor_util", - "//xla/stream_executor", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:dnn", "//xla/stream_executor:lazy_op_runner", "//xla/stream_executor:numeric_options", + "//xla/stream_executor:platform", "//xla/stream_executor:scratch_allocator", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/cuda:cuda_platform_id", "//xla/stream_executor/rocm:rocm_platform_id", "//xla/tsl/util:env_var", @@ -377,7 +394,6 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:numbers", "@local_tsl//tsl/platform:status", @@ -398,7 +414,7 @@ xla_test( "gpu_amd_any", ], tags = [ - "no_rocm", + "cuda-only", "noasan", "nomsan", ], @@ -410,10 +426,10 @@ xla_test( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/transforms:tuple_simplifier", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/service:platform_util", - "//xla/service:tuple_simplifier", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:stream_executor_util", "//xla/service/gpu/transforms:conv_rewriter", @@ -450,8 +466,9 @@ cc_library( "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu/kernels:custom_kernel", "//xla/service/gpu/kernels:custom_kernel_fusion", - "//xla/stream_executor", "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_memory_allocator", "//xla/tools:hlo_decomposer_lib", "@com_google_absl//absl/algorithm:container", @@ -473,7 +490,7 @@ xla_test( backends = [ "gpu", ], - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":autotuner_util", ":custom_kernel_fusion_autotuner", @@ -509,7 +526,6 @@ xla_cc_test( ], tags = [ "gpu", - "no_rocm", ], deps = [ ":autotuner_util", diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc index a652c3d8a103e6..237d06a83b6245 100644 --- a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc @@ -463,10 +463,13 @@ absl::StatusOr> TryFindInCache( // Cache miss. if (config.should_require_complete_aot_autotune_results()) { - return NotFound( + absl::Status s = NotFound( "Complete XLA AOT autotuning results are required, but no AOT result " "was found for key: %s", key.ToString()); + tsl::errors::InsertPayloads( + s, {{std::string(kAutotuneCacheRequiredErrorPayloadKey), ""}}); + return s; } TF_ASSIGN_OR_RETURN(AutotuneResult autotune_result, autotune_fn()); @@ -593,5 +596,8 @@ AutotunerUtil::CreateRedzoneAllocator(const AutotuneConfig& config, autotune_cache_stats = CacheStats(); } +constexpr absl::string_view kAutotuneCacheRequiredErrorPayloadKey = + "https://openxla.org/gpu/autotune_cache_hit_required/"; + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h index e70b252abb30a0..48bb3e3b291442 100644 --- a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h @@ -57,6 +57,12 @@ struct DevicelessConfig { se::DeviceDescription device_description; }; +// Status payload key to put errors at when autotune cache hits are required. +// See absl::Status docs for full details, but methods like +// {Get,Set,Clear}Payload allow manipulating it. The value of the payload is not +// specified and individual sources of this error may provide different values. +extern const absl::string_view kAutotuneCacheRequiredErrorPayloadKey; + class AutotuneCacheKey { public: AutotuneCacheKey(const se::DeviceDescription& device_description, diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_util_test.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_util_test.cc index 9c1d7e016a8f09..c34807e254fd75 100644 --- a/third_party/xla/xla/service/gpu/autotuning/autotuner_util_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/gpu/autotuning/autotuner_util.h" #include +#include #include #include @@ -58,6 +59,7 @@ namespace { using ::testing::ElementsAre; using ::testing::HasSubstr; using ::testing::IsEmpty; +using ::testing::Ne; using ::testing::Not; using ::testing::TempDir; using ::testing::UnorderedElementsAre; @@ -221,13 +223,16 @@ TEST_F(AutotunerUtilTest, FailIfRequireCompleteAotAutotuning) { auto options = DebugOptions(); options.set_xla_gpu_require_complete_aot_autotune_results(true); AutotuneConfig config(DeviceConfig{executor}, options); + absl::Status s = AutotunerUtil::Autotune(instruction, config, [&] { + return AutotuneResult(); + }).status(); EXPECT_THAT( - AutotunerUtil::Autotune(instruction, config, - [&] { return AutotuneResult(); }), - StatusIs( - absl::StatusCode::kNotFound, - HasSubstr("Complete XLA AOT autotuning results are required, but " - "no AOT result was found for key: operator()(const HloInstruction* gemm, @@ -388,20 +390,18 @@ class GemmAutotuner { << best.status(); return AutotuneResult{}; } // GetBestAlgorithm -}; // GemmAutotuner +}; // class GemmAutotuner // Do Gemm Autotune without stream executor. Use results from autotune cache // only. absl::StatusOr RunOnInstruction(HloInstruction* gemm, - const AutotuneConfig& config, - size_t* num_algorithms_left) { + GemmAutotuner& autotuner) { VLOG(3) << "Loading the autotune result of GemmThunk " << gemm->ToString(); GpuBackendConfig gpu_config = gemm->backend_config().value(); GemmBackendConfig& backend_config = *gpu_config.mutable_gemm_backend_config(); - *num_algorithms_left = 0; // Degenerate gemms replaced with memzero operation, no need to auto tune it. if (backend_config.alpha_real() == 0.0 && backend_config.alpha_imag() == 0.0 && backend_config.beta() == 0.0) { @@ -409,13 +409,12 @@ absl::StatusOr RunOnInstruction(HloInstruction* gemm, return false; } + const AutotuneConfig& config = autotuner.config(); AutotuneCacheKey key(config.GetModelStr(), *gemm); - GemmAutotuner autotuner(config); TF_ASSIGN_OR_RETURN(AutotuneResult algorithm, AutotunerUtil::Autotune( gemm, config, [&] { return autotuner(gemm, key); })); - *num_algorithms_left = autotuner.num_algorithms_left(); auto old_algorithm = backend_config.selected_algorithm(); bool update_algorithm = IsCublasLtMatmulF8(*gemm) || @@ -442,9 +441,8 @@ absl::StatusOr RunOnInstruction(HloInstruction* gemm, if (new_algorithm == old_algorithm && backend_config.has_selected_algorithm()) { - // We don't need to update the backend config if - // the algorithm hasn't changed unless previously - // the algorithm wasn't set explicitly. + // We don't need to update the backend config if the algorithm was not + // changed unless previously the algorithm wasn't set explicitly. return false; } @@ -457,17 +455,16 @@ absl::StatusOr RunOnInstruction(HloInstruction* gemm, } absl::StatusOr RunOnComputation(HloComputation* computation, - AutotuneConfig config, + GemmAutotuner& autotuner, size_t* num_algorithms_left) { bool changed = false; for (HloInstruction* instr : computation->instructions()) { if (IsCublasGemm(*instr)) { - size_t num_left; - TF_ASSIGN_OR_RETURN(bool result, - RunOnInstruction(instr, config, &num_left)); + TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr, autotuner)); // Gathering statistics on the algorithms left after tuning (for testing) - *num_algorithms_left = std::max(*num_algorithms_left, num_left); + *num_algorithms_left = + std::max(*num_algorithms_left, autotuner.num_algorithms_left()); changed |= result; } } @@ -487,11 +484,11 @@ absl::StatusOr GemmAlgorithmPicker::Run( VLOG(2) << "GEMM auto-tuning disabled, GemmAlgorithmPicker returning early"; return false; } - + GemmAutotuner autotuner(config_); bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { - TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation, config_, + TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation, autotuner, &num_algorithms_left_)); changed |= result; } diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc index b88c2b0916ce08..6526e3338fb6c5 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc @@ -35,10 +35,10 @@ limitations under the License. #include "xla/stream_executor/semantic_version.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "xla/xla.pb.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/dnn.pb.h" namespace xla::gpu { namespace { @@ -50,25 +50,21 @@ class GemmAlgorithmPickerTest : public HloTestBase, public: GemmAlgorithmPickerTest() { AutotunerUtil::ClearAutotuneResults(); } - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_gpu_enable_cublaslt(GetParam()); debug_options.set_xla_gpu_enable_triton_gemm(false); return debug_options; } - const se::DeviceDescription& device_desc() { - return backend().default_stream_executor()->GetDeviceDescription(); - } - se::StreamExecutor* stream_exec() { return backend().default_stream_executor(); } - const se::DeviceDescription& gpu_device_desc() { + const se::DeviceDescription& device_desc() { return stream_exec()->GetDeviceDescription(); } const se::GpuComputeCapability& gpu_comp() { - return gpu_device_desc().gpu_compute_capability(); + return device_desc().gpu_compute_capability(); } void SetUp() override { @@ -103,7 +99,7 @@ class GemmAlgorithmPickerTest : public HloTestBase, }; TEST_P(GemmAlgorithmPickerTest, BlasGetVersion) { - auto* blas = backend().default_stream_executor()->AsBlas(); + auto* blas = stream_exec()->AsBlas(); ASSERT_TRUE(blas != nullptr); std::string version; ASSERT_TRUE(blas->GetVersion(&version).ok()); @@ -148,6 +144,15 @@ ENTRY main { if (num_left1 < 2) { GTEST_SKIP() << "Too few algorithms left after the first step"; } + + // Test that the function to get current stream value works fine: + auto* blas = stream_exec()->AsBlas(); + ASSERT_TRUE(blas != nullptr); + TF_ASSERT_OK_AND_ASSIGN(bool is_main_stream, blas->IsMainStreamSet()); + // ROCM only: CUDA blas API does not reset stream after each blas call. + if (std::holds_alternative(gpu_comp())) { + ASSERT_TRUE(is_main_stream); + } } // Clear cache before the second run! @@ -291,7 +296,7 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(kHlo, module_cfg)); changed = false; - DevicelessConfig deviceless_config{gpu_device_desc()}; + DevicelessConfig deviceless_config{device_desc()}; AutotuneConfig deviceless_cfg{deviceless_config, opts}; TF_ASSERT_OK_AND_ASSIGN( changed, diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc index 79524924584c97..4068a75c6c67af 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc @@ -51,13 +51,15 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/transforms/simplifiers/float_normalization.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/primitive_util.h" #include "xla/service/algorithm_util.h" +#include "xla/service/call_inliner.h" #include "xla/service/dump.h" #include "xla/service/executable.h" -#include "xla/service/float_normalization.h" #include "xla/service/gpu/autotuning/autotuner_compile_util.h" #include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/backend_configs.pb.h" @@ -65,10 +67,16 @@ limitations under the License. #include "xla/service/gpu/gpu_float_support.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/service/gpu/kernels/custom_kernel_fusion.h" +#include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h" +#include "xla/service/gpu/matmul_indexing_utils.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/split_k_gemm_rewriter.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/transforms/cudnn_fusion_compiler.h" +#include "xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h" +#include "xla/service/gpu/transforms/dot_algorithm_rewriter.h" #include "xla/service/gpu/transforms/fusion_wrapper.h" #include "xla/service/gpu/transforms/gemm_rewriter.h" #include "xla/service/gpu/transforms/priority_fusion.h" @@ -85,7 +93,6 @@ limitations under the License. #include "xla/stream_executor/gpu/redzone_allocator.h" #include "xla/stream_executor/semantic_version.h" #include "xla/stream_executor/stream.h" -#include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tools/hlo_decomposer.h" #include "xla/tsl/lib/core/bits.h" #include "xla/tsl/util/proto/proto_utils.h" @@ -140,76 +147,6 @@ constexpr std::array kNumCtas = {1, 2, 4, 8, 16}; using AutoTuneCacheKeyCount = absl::flat_hash_map; -class GemmFusionAutotunerVisitor : public DfsHloRewriteVisitor { - public: - explicit GemmFusionAutotunerVisitor(const AutotuneConfig& config) - : config_(config) {} - - absl::Status HandleFusion(HloInstruction* hlo) override { - TF_ASSIGN_OR_RETURN(auto gpu_config, - hlo->backend_config()); - FusionBackendConfig& backend_config = - *gpu_config.mutable_fusion_backend_config(); - if (backend_config.kind() != kTritonGemmFusionKind && - backend_config.kind() != kCuDnnFusionKind) { - return absl::OkStatus(); - } - - VLOG(4) << "Processing " << hlo->ToString(); - if (!backend_config.has_triton_gemm_config() && - !backend_config.has_cudnn_fusion_config()) { - TF_ASSIGN_OR_RETURN( - AutotuneResult autotune_result, - AutotunerUtil::Autotune( - hlo, config_, [&]() -> absl::StatusOr { - if (config_.IsDeviceless()) { - return absl::InternalError(absl::StrCat( - "Expect autotune result cache hit for deviceless " - "compilation (HLO: ", - hlo->ToString(), ")")); - } - return absl::InternalError("Expect autotune result cache hit."); - })); - VLOG(4) << "Result: " << autotune_result.ShortDebugString(); - - if (autotune_result.has_triton()) { - *backend_config.mutable_triton_gemm_config() = autotune_result.triton(); - TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config)); - } else if (autotune_result.has_gemm()) { - // Falling back to cuBLAS: Converting the fusion to a Call, so that it - // can be inlined back again. - HloComputation* const computation = hlo->parent(); - HloInstruction* const call = computation->AddInstruction( - HloInstruction::CreateCall(hlo->shape(), hlo->operands(), - hlo->fused_instructions_computation())); - TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, call)); - hlo = call; - } else { - CHECK(autotune_result.has_algorithm()); - backend_config.set_kind(std::string(kCuDnnFusionKind)); - backend_config.mutable_cudnn_fusion_config()->set_plan_id( - autotune_result.algorithm().algo_id()); - TF_RETURN_IF_ERROR(hlo->set_backend_config(gpu_config)); - } - } - - if (backend_config.has_triton_gemm_config()) { - TF_ASSIGN_OR_RETURN( - const TritonGemmConfig config, - TritonGemmConfig::FromProto(backend_config.triton_gemm_config())); - if (config.split_k > 1) { - TF_RETURN_IF_ERROR(MakeDotSplitKBatch(hlo, config)); - } - } - - MarkAsChanged(); - return absl::OkStatus(); - } - - private: - AutotuneConfig config_; -}; - class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault { public: explicit GemmConfigSetCollector(GemmFusionAutotunerImpl* impl) @@ -259,7 +196,9 @@ class GemmConfigSetCollector : public ConstDfsHloVisitorWithDefault { bool missing_config = (backend_config.kind() == kTritonGemmFusionKind && !backend_config.has_triton_gemm_config()) || (backend_config.kind() == kCuDnnFusionKind && - !backend_config.has_cudnn_fusion_config()); + !backend_config.has_cudnn_fusion_config()) || + (backend_config.kind() == kCustomFusionKind && + !backend_config.has_custom_fusion_config()); if (missing_config) { if (error_out_on_cache_miss_) { return absl::NotFoundError(absl::StrCat( @@ -324,20 +263,12 @@ absl::StatusOr GetLimits(const HloDotInstruction& dot) { int GetLogEveryN() { return VLOG_IS_ON(3) ? 100 : 1000; } -int64_t PriorityFusionShapeSize(const Shape& shape) { +HloCostAnalysis::Options PriorityFusionOptions() { // The real pointer size is set in GpuCompiler. In HloCostAnalysis, the // pointer size is used only to determine the size of tuple types. We // shouldn't have any tuples in the autotuned module, so it's safe to use - // a constant here, instead of piping the real value. - constexpr int64_t kPointerSize = 8; - return ShapeUtil::ByteSizeOf(shape, kPointerSize); -} - -HloCostAnalysis::Options PriorityFusionOptions() { - return {/*shape_size=*/PriorityFusionShapeSize, - /*per_second_rates=*/{}, - /*min_latencies_seconds=*/{}, - /*count_multiple_input_accesses=*/true}; + // the default value here, instead of piping the real value. + return {.count_multiple_input_accesses = true}; } absl::StatusOr> TritonGemmAutotuneExtractor( @@ -402,9 +333,7 @@ absl::StatusOr> CublasGemmAutotuneExtractor( // don't use cuBlas in the end. This assumes that the substituting algorithm // has result which are close enough for the check in this file. if (dot->precision_config().algorithm() == - PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 || - dot->precision_config().algorithm() == - PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6) { + PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3) { dot->mutable_precision_config()->set_algorithm( PrecisionConfig::ALG_DOT_F32_F32_F32); } @@ -412,11 +341,13 @@ absl::StatusOr> CublasGemmAutotuneExtractor( for (GemmRewriterOptions::DType dtype : {GemmRewriterOptions::DType::kFp8Only, GemmRewriterOptions::DType::kNonFp8Only}) { - GemmRewriter rewriter(config.GetGpuComputeCapability(), toolkit_version, - GemmRewriterOptions{dtype}); + GemmRewriter gemm_rewriter(config.GetGpuComputeCapability(), + toolkit_version, GemmRewriterOptions{dtype}); + DotAlgorithmRewriter dot_algorithm_rewriter; PriorityFusion fusion_pass( /*thread_pool=*/nullptr, gpu_device_info, PriorityFusionOptions()); - TF_RETURN_IF_ERROR(rewriter.Run(new_module.get()).status()); + TF_RETURN_IF_ERROR(dot_algorithm_rewriter.Run(new_module.get()).status()); + TF_RETURN_IF_ERROR(gemm_rewriter.Run(new_module.get()).status()); TF_RETURN_IF_ERROR(fusion_pass.Run(new_module.get()).status()); } // TODO(tdanyluk): Consider running GemmAlgorithmPicker here for better cuBLAS @@ -427,6 +358,46 @@ absl::StatusOr> CublasGemmAutotuneExtractor( return new_module; } +absl::Status UpdateFusionInstructionKernelIndex( + HloInstruction* fusion_instruction, int kernel_index) { + GpuBackendConfig gpu_config = + fusion_instruction->backend_config().value(); + gpu_config.mutable_fusion_backend_config() + ->mutable_custom_fusion_config() + ->set_kernel_index(kernel_index); + TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config(gpu_config)); + + return absl::OkStatus(); +} + +absl::StatusOr> CustomFusionKernelAutotuneExtractor( + const GemmFusionAutotunerImpl::CustomKernelFusionConfig& cutlass_config, + const AutotuneConfig& config, const se::SemanticVersion& toolkit_version, + const HloFusionInstruction* fusion, const DebugOptions& debug_opts) { + const HloComputation* fusion_computation = fusion->called_computation(); + std::unique_ptr new_module = + ExtractComputationIntoNewModule(*fusion_computation); + new_module->mutable_config().set_debug_options(debug_opts); + + CustomKernelFusionRewriter rewriter( + &config.GetExecutor()->GetDeviceDescription()); + PriorityFusion fusion_pass( + /*thread_pool=*/nullptr, config.GetExecutor()->GetDeviceDescription(), + PriorityFusionOptions()); + TF_RETURN_IF_ERROR(rewriter.Run(new_module.get()).status()); + TF_RETURN_IF_ERROR(fusion_pass.Run(new_module.get()).status()); + + // Select custom kernel fusion kernel. + HloInstruction* custom_kernel_fusion = + hlo_query::GetFirstInstructionWithOpcode(*new_module->entry_computation(), + HloOpcode::kFusion); + int64_t kernel_index = cutlass_config.kernel_index; + TF_RETURN_IF_ERROR( + UpdateFusionInstructionKernelIndex(custom_kernel_fusion, kernel_index)); + + return new_module; +} + absl::StatusOr> FusionExtractor( const HloFusionInstruction& fusion, const DebugOptions& debug_opts) { std::unique_ptr module = ExtractInstructionIntoNewModule(fusion); @@ -475,6 +446,11 @@ AutotuneResult FromConfig(const BackendConfig& config) { AutotuneResult res; if (std::holds_alternative(config)) { res.mutable_gemm()->set_algorithm(CUBLAS_GEMM_DEFAULT); + } else if (std::holds_alternative< + GemmFusionAutotunerImpl::CustomKernelFusionConfig>(config)) { + res.mutable_custom_kernel_fusion()->set_kernel_index( + std::get(config) + .kernel_index); } else if (std::holds_alternative( config)) { res.mutable_algorithm()->set_algo_id( @@ -574,6 +550,129 @@ std::string Serialize(const BackendConfig& config) { } // anonymous namespace +absl::Status RewriteGemmFusionToCall(HloInstruction* fusion_instr) { + // Falling back to cuBLAS: Converting the fusion to a Call, so that it + // can be inlined back again. + HloComputation* const computation = fusion_instr->parent(); + HloInstruction* const call = + computation->AddInstruction(HloInstruction::CreateCall( + fusion_instr->shape(), fusion_instr->operands(), + fusion_instr->fused_instructions_computation())); + return computation->ReplaceInstruction(fusion_instr, call); +} + +absl::Status RewriteGemmFusionToCustomKernelFusion( + HloInstruction* fusion_instr, se::DeviceDescription device_description, + int64_t kernel_index) { + // Rewrites gemm fusion to custom kernel fusion. + // First convert the fusion to a call. Then inlines the call. Then + // rewrites to custom kernel fusion. + HloComputation* const computation = fusion_instr->parent(); + HloInstruction* const call = + computation->AddInstruction(HloInstruction::CreateCall( + fusion_instr->shape(), fusion_instr->operands(), + fusion_instr->fused_instructions_computation())); + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(fusion_instr, call)); + HloPassPipeline pipeline("autotuner_custom_kernel_fusion_rewriter"); + pipeline.AddPass(); + pipeline.AddPass(&device_description, + kernel_index); + HloModule* hlo_module = call->GetModule(); + return pipeline.Run(hlo_module).status(); +} + +absl::Status HandleTritonGemm(HloInstruction* fusion_instr, + FusionBackendConfig& fusion_backend_config) { + TF_ASSIGN_OR_RETURN( + const TritonGemmConfig config, + TritonGemmConfig::FromProto(fusion_backend_config.triton_gemm_config())); + if (config.split_k > 1) { + TF_RETURN_IF_ERROR(MakeDotSplitKBatch(fusion_instr, config)); + } + return absl::OkStatus(); +} + +absl::Status GemmFusionAutotunerRewriterVisitor::HandleFusion( + HloInstruction* fusion_instr) { + TF_ASSIGN_OR_RETURN(auto gpu_config, + fusion_instr->backend_config()); + FusionBackendConfig& fusion_backend_config = + *gpu_config.mutable_fusion_backend_config(); + + // Only autotune Triton, cuDNN, and custom kernel fusions. + if (fusion_backend_config.kind() != kTritonGemmFusionKind && + fusion_backend_config.kind() != kCuDnnFusionKind && + fusion_backend_config.kind() != kCustomFusionKind) { + return absl::OkStatus(); + } + + // Do not autotune if the backend config has already assigned tiling config. + if (fusion_backend_config.has_triton_gemm_config()) { + TF_RETURN_IF_ERROR(HandleTritonGemm(fusion_instr, fusion_backend_config)); + MarkAsChanged(); + return absl::OkStatus(); + } + + // Do not autotune if the backend config has valid config. + if (fusion_backend_config.has_cudnn_fusion_config() || + fusion_backend_config.has_custom_fusion_config()) { + return absl::OkStatus(); + } + + VLOG(4) << "Autotuning fusion instruction: " << fusion_instr->ToString(); + TF_ASSIGN_OR_RETURN( + AutotuneResult autotune_result, + AutotunerUtil::Autotune( + fusion_instr, config_, [&]() -> absl::StatusOr { + absl::Status s; + if (config_.IsDeviceless()) { + s = absl::InternalError(absl::StrCat( + "Expect autotune result cache hit for deviceless " + "compilation (HLO: ", + fusion_instr->ToString(), ")")); + } else { + s = absl::InternalError("Expect autotune result cache hit."); + } + tsl::errors::InsertPayloads( + s, {{std::string(kAutotuneCacheRequiredErrorPayloadKey), ""}}); + + return s; + })); + VLOG(4) << "Autotuning result: " << autotune_result.ShortDebugString(); + + if (autotune_result.has_triton()) { + *fusion_backend_config.mutable_triton_gemm_config() = + autotune_result.triton(); + TF_RETURN_IF_ERROR(fusion_instr->set_backend_config(gpu_config)); + TF_RETURN_IF_ERROR(HandleTritonGemm(fusion_instr, fusion_backend_config)); + MarkAsChanged(); + return absl::OkStatus(); + } + + if (autotune_result.has_gemm()) { + TF_RETURN_IF_ERROR(RewriteGemmFusionToCall(fusion_instr)); + MarkAsChanged(); + return absl::OkStatus(); + } + + if (autotune_result.has_custom_kernel_fusion()) { + TF_RETURN_IF_ERROR(RewriteGemmFusionToCustomKernelFusion( + fusion_instr, config_.GetExecutor()->GetDeviceDescription(), + autotune_result.custom_kernel_fusion().kernel_index())); + MarkAsChanged(); + return absl::OkStatus(); + } + + // Autotune result has a cuDNN fusion. + CHECK(autotune_result.has_algorithm()); + fusion_backend_config.set_kind(std::string(kCuDnnFusionKind)); + fusion_backend_config.mutable_cudnn_fusion_config()->set_plan_id( + autotune_result.algorithm().algo_id()); + TF_RETURN_IF_ERROR(fusion_instr->set_backend_config(gpu_config)); + MarkAsChanged(); + return absl::OkStatus(); +} + // Methods required for sorting the configs. bool GemmFusionAutotunerImpl::CuBlasConfig::operator<( const CuBlasConfig& other) const { @@ -583,6 +682,10 @@ bool GemmFusionAutotunerImpl::CuDnnConfig::operator<( const CuDnnConfig& other) const { return plan_id < other.plan_id; } +bool GemmFusionAutotunerImpl::CustomKernelFusionConfig::operator<( + const CustomKernelFusionConfig& other) const { + return false; +} bool GemmFusionAutotunerImpl::IsAutotuningEnabled() const { return debug_options_.xla_gpu_autotune_level() > 0 && @@ -603,6 +706,72 @@ bool GemmFusionAutotunerImpl::IsAutotuningEnabled() const { } } +std::vector GenerateCustomKernelFusionConfigs( + const HloFusionInstruction& fusion, + se::DeviceDescription device_description) { + std::vector configs; + const CustomKernelFusionPatternRegistry* patterns = + CustomKernelFusionPatternRegistry::Default(); + HloComputation* computation = fusion.called_computation(); + // Get the first dot instruction in the fusion body. + HloInstruction* dot_instruction = + hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot); + std::vector match = + patterns->Match(device_description, dot_instruction); + + // For Cutlass we expect only one match for a GEMM fusion. + if (match.size() == 1) { + CustomKernelFusionRegistry* registry = + CustomKernelFusionRegistry::Default(); + auto* custom_kernel_fusion = registry->Lookup(match[0].config().name()); + + // If custom fusion is not found it means that some of the build targets + // might not be statically linked into the binary. + if (custom_kernel_fusion != nullptr) { + // There can be multiple kernels for a single fusion pattern, which are + // selected by the kernel_index. + // To get the number of kernels we can rewrite the fusion to custom kernel + // fusion and count the number of loaded kernels. + const HloComputation* fusion_computation = fusion.called_computation(); + std::unique_ptr new_module = + ExtractComputationIntoNewModule(*fusion_computation); + CustomKernelFusionRewriter rewriter(&device_description); + absl::StatusOr changed = rewriter.Run(new_module.get()); + if (!changed.ok() || !changed.value()) { + VLOG(2) << "Skip custom kernel config. Failed to rewrite custom kernel " + "fusion: " + << changed.status(); + return configs; + } + + HloInstruction* custom_kernel_fusion_instr = + hlo_query::GetFirstInstructionWithOpcode( + *new_module->entry_computation(), HloOpcode::kFusion); + if (custom_kernel_fusion_instr == nullptr) { + VLOG(2) << "Skip custom kernel config. Failed to find custom kernel " + "fusion instruction in the rewritten module."; + return configs; + } + absl::StatusOr> kernels = + custom_kernel_fusion->LoadKernels( + device_description, + custom_kernel_fusion_instr->fused_instructions_computation()); + if (!kernels.ok()) { + VLOG(2) << "Skip custom kernel config. Failed to load custom kernels: " + << kernels.status(); + } else { + for (int i = 0; i < kernels.value().size(); ++i) { + GemmFusionAutotunerImpl::CustomKernelFusionConfig config{ + /*kernel_index=*/i}; + configs.push_back(config); + } + } + } + } + + return configs; +} + absl::StatusOr> GemmFusionAutotunerImpl::GenerateConfigs(const HloFusionInstruction& fusion) { const HloDotInstruction* dot = @@ -613,7 +782,7 @@ GemmFusionAutotunerImpl::GenerateConfigs(const HloFusionInstruction& fusion) { if (!debug_options_.xla_gpu_experimental_disable_binary_libraries()) { // Add cuBLAS reference config, if available. if (algorithm_util::IsSupportedByCublasOrCublasLt( - dot->precision_config().algorithm()) && + dot->precision_config().algorithm(), GetComputeCapability()) && !dot->sparse_operands() && IsAutotuningEnabled()) { configs.push_back(CuBlasConfig{}); } @@ -642,6 +811,19 @@ GemmFusionAutotunerImpl::GenerateConfigs(const HloFusionInstruction& fusion) { } } + // Add CustomKernelFusion (Cutlass) configs, if available. + // Go through all the instructions in the fusion body try to match them to + // a custom kernel fusion pattern. + if ((IsFusionKind(fusion, kCustomFusionKind) || + IsFusionKind(fusion, kTritonGemmFusionKind)) && + IsAutotuningEnabled() && !config_.IsDeviceless()) { + std::vector custom_kernel_fusion_configs = + GenerateCustomKernelFusionConfigs( + fusion, config_.GetExecutor()->GetDeviceDescription()); + configs.insert(configs.end(), custom_kernel_fusion_configs.begin(), + custom_kernel_fusion_configs.end()); + } + // Add triton configs. TF_ASSIGN_OR_RETURN(std::vector triton_configs, GenerateTritonConfigs(*dot)); @@ -805,6 +987,14 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util, config_, config_.GetExecutor()->GetDeviceDescription(), toolkit_version_, fusion, opts); })); + } else if (std::holds_alternative(config)) { + TF_ASSIGN_OR_RETURN(executable, + compile_util.Compile([&](const DebugOptions& opts) { + return CustomFusionKernelAutotuneExtractor( + std::get(config), + config_, toolkit_version_, fusion, opts); + })); + } else { LOG(FATAL) << "Unsupported config type: " << config.index(); } @@ -902,13 +1092,6 @@ absl::StatusOr> GemmFusionAutotunerImpl::Profile( return absl::StrFormat("XlaAutotunerMeasurement:#hlo_op=%s#", fusion.name()); }); - se::DeviceMemoryAllocator* allocator = config_.GetAllocator(); - std::unique_ptr owned_allocator; - if (allocator == nullptr) { - owned_allocator = - std::make_unique(stream_exec); - allocator = owned_allocator.get(); - } TF_ASSIGN_OR_RETURN(se::Stream* const stream, config_.GetStream()); const HloInstruction& root = *fusion_computation->root_instruction(); @@ -1006,10 +1189,6 @@ GemmFusionAutotunerImpl::GetExhaustiveTritonConfigs() const { debug_options_.xla_gpu_enable_triton_hopper() && cc.IsAtLeastHopper(); for (int num_stages : kNumStages) { - // Volta doesn't support num_stages > 2. - if (!cc.IsAtLeastAmpere() && num_stages > 2) { - break; - } for (int tile_m : kBlockSizes) { for (int tile_n : kBlockSizes) { for (int tile_k : kBlockSizes) { @@ -1052,28 +1231,22 @@ std::vector GemmFusionAutotunerImpl::GetDefaultTritonConfigs() const { using Config = TritonGemmConfig; std::vector configs = { - Config(32, 32, 256, 1, 1, 4), Config(64, 32, 32, 16, 1, 4), - Config(32, 64, 64, 4, 1, 4), Config(128, 128, 64, 4, 1, 4), - Config(16, 16, 256, 1, 1, 4), Config(16, 128, 32, 16, 1, 4), - Config(16, 64, 128, 1, 1, 4), Config(16, 128, 32, 8, 1, 4), - Config(16, 16, 512, 1, 1, 4), Config(32, 16, 512, 1, 1, 4), - Config(64, 32, 64, 1, 2, 8)}; - if (GetComputeCapability().IsAtLeastAmpere()) { - absl::c_copy( - std::vector{ - Config(128, 256, 32, 1, 3, 8), Config(256, 128, 32, 1, 3, 8), - Config(256, 64, 32, 1, 4, 4), Config(64, 256, 32, 1, 4, 4), - Config(128, 64, 32, 1, 4, 4), Config(64, 128, 32, 1, 4, 4), - Config(256, 128, 128, 1, 3, 8), Config(256, 64, 128, 1, 4, 4), - Config(64, 256, 128, 1, 4, 4), Config(128, 128, 128, 1, 4, 4), - Config(128, 64, 64, 1, 4, 4), Config(64, 128, 64, 1, 4, 4), - Config(128, 32, 64, 1, 4, 4), Config(64, 32, 64, 1, 4, 4), - Config(32, 128, 32, 1, 4, 4), Config(128, 128, 32, 1, 4, 4), - Config(16, 16, 256, 1, 3, 4), Config(128, 128, 64, 2, 1, 8), - Config(64, 64, 64, 1, 2, 4), Config(16, 64, 256, 8, 1, 4), - Config(256, 256, 128, 1, 3, 8)}, - std::back_inserter(configs)); - } + Config(32, 32, 256, 1, 1, 4), Config(64, 32, 32, 16, 1, 4), + Config(32, 64, 64, 4, 1, 4), Config(128, 128, 64, 4, 1, 4), + Config(16, 16, 256, 1, 1, 4), Config(16, 128, 32, 16, 1, 4), + Config(16, 64, 128, 1, 1, 4), Config(16, 128, 32, 8, 1, 4), + Config(16, 16, 512, 1, 1, 4), Config(32, 16, 512, 1, 1, 4), + Config(64, 32, 64, 1, 2, 8), Config(128, 256, 32, 1, 3, 8), + Config(256, 128, 32, 1, 3, 8), Config(256, 64, 32, 1, 4, 4), + Config(64, 256, 32, 1, 4, 4), Config(128, 64, 32, 1, 4, 4), + Config(64, 128, 32, 1, 4, 4), Config(256, 128, 128, 1, 3, 8), + Config(256, 64, 128, 1, 4, 4), Config(64, 256, 128, 1, 4, 4), + Config(128, 128, 128, 1, 4, 4), Config(128, 64, 64, 1, 4, 4), + Config(64, 128, 64, 1, 4, 4), Config(128, 32, 64, 1, 4, 4), + Config(64, 32, 64, 1, 4, 4), Config(32, 128, 32, 1, 4, 4), + Config(128, 128, 32, 1, 4, 4), Config(16, 16, 256, 1, 3, 4), + Config(128, 128, 64, 2, 1, 8), Config(64, 64, 64, 1, 2, 4), + Config(16, 64, 256, 8, 1, 4), Config(256, 256, 128, 1, 3, 8)}; if (GetComputeCapability().IsAtLeastHopper()) { absl::c_copy( std::vector{ @@ -1305,8 +1478,8 @@ absl::StatusOr GemmFusionAutotuner::Run( } } - return GemmFusionAutotunerVisitor(config_).RunOnModule(module, - execution_threads); + return GemmFusionAutotunerRewriterVisitor(config_).RunOnModule( + module, execution_threads); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h index 7c262ffc8c613b..17272607532c20 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h @@ -29,7 +29,9 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/autotuning.pb.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" @@ -46,6 +48,18 @@ limitations under the License. namespace xla { namespace gpu { +// Uses profile results to rewrite a gemm fusion to use the best backend. +class GemmFusionAutotunerRewriterVisitor : public DfsHloRewriteVisitor { + public: + explicit GemmFusionAutotunerRewriterVisitor(const AutotuneConfig& config) + : config_(config) {} + + absl::Status HandleFusion(HloInstruction* fusion_instr) override; + + private: + AutotuneConfig config_; +}; + // Takes a gemm fusion and chooses between cuBLAS, cuDNN, and Triton backends. // In the case of Triton, it also chooses the best tiling configuration. // @@ -99,8 +113,13 @@ class GemmFusionAutotunerImpl { int64_t plan_id; bool operator<(const CuDnnConfig& other) const; }; + struct CustomKernelFusionConfig { + int64_t kernel_index; + bool operator<(const CustomKernelFusionConfig& other) const; + }; using BackendConfig = - std::variant; + std::variant; using BackendConfigs = std::vector< std::pair>>; diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc index f47003ecea4256..a7b79d36e549e2 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc @@ -25,7 +25,9 @@ limitations under the License. #include #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" #include "xla/autotuning.pb.h" #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -51,6 +53,7 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_description.pb.h" #include "xla/stream_executor/semantic_version.h" +#include "xla/stream_executor/stream_executor.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_utils.h" @@ -103,7 +106,7 @@ ENTRY entry { // Destroy the original module to be sure that the extracted one has no // dependency on it. - module.release(); + module = nullptr; EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); @@ -143,7 +146,7 @@ ENTRY entry { // Destroy the original module to be sure that the extracted one has no // dependency on it. - module.release(); + module = nullptr; EXPECT_THAT(extracted_module->entry_computation()->root_instruction(), GmockMatch(m::Dot(m::Convert(m::Parameter()), m::Parameter()))); @@ -175,11 +178,126 @@ class StatelessAutotunerTest : public HloTestBase { AutotunerUtil::ClearAutotuneResults(); HloTestBase::TearDown(); } + + absl::StatusOr> + GetPossibleMatmulAutotuneConfigs( + const HloModule& module, + const se::CudaComputeCapability& compute_capability, + const se::SemanticVersion& toolkit_version, + const DebugOptions& debug_options) { + const HloFusionInstruction& fusion = *Cast( + module.entry_computation()->root_instruction()); + se::GpuDeviceInfoProto deviceless_proto; + auto ccc = deviceless_proto.mutable_cuda_compute_capability(); + ccc->set_major(compute_capability.major); + ccc->set_minor(compute_capability.minor); + + DeviceConfig test_config{backend().default_stream_executor(), + backend().memory_allocator()}; + AutotuneConfig autotune_config{test_config, debug_options}; + GemmFusionAutotunerImpl autotuner(autotune_config, toolkit_version, + debug_options, nullptr); + return autotuner.GenerateConfigs(fusion); + } + + se::CudaComputeCapability GetCudaComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + } + + // Returns the config for the current device. + absl::StatusOr> + GetPossibleMatmulAutotuneConfigs(const HloModule& module) { + DeviceConfig device_config{backend().default_stream_executor(), + backend().memory_allocator()}; + AutotuneConfig autotune_config{device_config, GetDebugOptionsForTest()}; + GemmFusionAutotunerImpl autotuner(autotune_config, GetToolkitVersion(), + GetDebugOptionsForTest(), nullptr); + const HloFusionInstruction& fusion = *Cast( + module.entry_computation()->root_instruction()); + return autotuner.GenerateConfigs(fusion); + } + + bool hasCublasConfig( + const std::vector& configs) { + return std::any_of( + configs.begin(), configs.end(), + [](const GemmFusionAutotunerImpl::BackendConfig& config) { + return std::holds_alternative( + config); + }); + } }; +constexpr absl::string_view kHloDotFusionWithAlgorithm = R"( + HloModule module + + computation { + p0 = f32[1024,1024] parameter(0) + p1 = f32[1024,1024] parameter(1) + ROOT r = f32[1024,1024] dot(p0, p1), + algorithm=$0, + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + } + + ENTRY main { + p0 = f32[1024,1024] parameter(0) + p1 = f32[1024,1024] parameter(1) + ROOT computation = f32[1024,1024] fusion(f32[1024,1024] p0,f32[1024,1024] p1), + kind=kCustom, + calls=computation + } +)"; + +TEST_F(StatelessAutotunerTest, NoCublasFallbackForTf32Tf32F32X3Algorithm) { + // There is no cublas implementation for dot_tf32_tf32_f32_x3 at the moment. + // At the same time cublas f32 is faster than triton for this algorithm. + // But we don't want to fallback to cuBLAS in this case because we lose the + // precision guarantees. + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(absl::Substitute( + kHloDotFusionWithAlgorithm, "dot_tf32_tf32_f32_x3"))); + + TF_ASSERT_OK_AND_ASSIGN(auto configs, + GetPossibleMatmulAutotuneConfigs(*module)); + EXPECT_FALSE(hasCublasConfig(configs)) + << "There is no cublas implementation for dot_tf32_tf32_f32_x3. That is " + "why we don't want to fallback to cublas."; +} + +TEST_F(StatelessAutotunerTest, + NoCublasFallbackForBf16Bf16F32AlgorithmOnHopper) { + // There is no cublas implementation for dot_bf16_bf16_f32 at the moment. + // At the same time cublas f32 is faster than triton for this algorithm. + // But we don't want to fallback to cuBLAS in this case because we lose the + // precision guarantees. + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(absl::Substitute( + kHloDotFusionWithAlgorithm, "dot_bf16_bf16_f32"))); + + TF_ASSERT_OK_AND_ASSIGN(auto configs, + GetPossibleMatmulAutotuneConfigs(*module)); + switch (GetCudaComputeCapability().major) { + case se::CudaComputeCapability::AMPERE: + EXPECT_TRUE(hasCublasConfig(configs)) + << "There is a cublas implementation for dot_bf16_bf16_f32 on Ampere"; + break; + case se::CudaComputeCapability::HOPPER: + EXPECT_TRUE(hasCublasConfig(configs)) + << "There is a cublas implementation for dot_bf16_bf16_f32 on Hopper"; + break; + default: + // We don't know what to expect for other compute capabilities. + EXPECT_FALSE(hasCublasConfig(configs)); + } +} + class GemmFusionAutotunerTest : public StatelessAutotunerTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = StatelessAutotunerTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_enable_triton_gemm(true); @@ -188,13 +306,6 @@ class GemmFusionAutotunerTest : public StatelessAutotunerTest { return debug_options; } - se::CudaComputeCapability GetCudaComputeCapability() { - return backend() - .default_stream_executor() - ->GetDeviceDescription() - .cuda_compute_capability(); - } - void CheckTritonAutotuning(absl::string_view hlo, absl::string_view expected) { HloPassPipeline pipeline("gemm_rewrite"); @@ -238,7 +349,7 @@ class GemmFusionAutotunerTest : public StatelessAutotunerTest { class GemmFusionAutotunerTestWithMorePreciseReduction : public GemmFusionAutotunerTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GemmFusionAutotunerTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_triton_gemm_disable_reduced_precision_reduction( @@ -247,7 +358,8 @@ class GemmFusionAutotunerTestWithMorePreciseReduction } }; -absl::StatusOr> GetPossibleMatmulAutotuneConfigs( +absl::StatusOr> +GetPossibleMatmulAutotuneTritonConfigs( const HloDotInstruction& dot, const se::CudaComputeCapability& compute_capability, const se::SemanticVersion& toolkit_version, @@ -276,7 +388,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( + GetPossibleMatmulAutotuneTritonConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -298,7 +410,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( + GetPossibleMatmulAutotuneTritonConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -320,7 +432,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( + GetPossibleMatmulAutotuneTritonConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -451,7 +563,6 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } -// Modify block_k back to 16 once b/337839570 is fixed. // TODO(b/344770374): Make this test not fragile. TEST_F(GemmFusionAutotunerTest, DoNotRunAutotuningKernelSpillingRegisters) { const std::string kHloText = R"( @@ -470,7 +581,7 @@ ENTRY %e { %get-tuple-element.7020 = s8[12288,1536]{1,0} parameter(0) %convert = s8[4,12288]{1,0} parameter(1) ROOT %triton = s8[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %get-tuple-element.7020, s8[4,12288]{1,0} %convert), kind=kCustom, calls=%triton_gemm_dot, - backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"32","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}} + backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"16","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}} })"; auto module = ParseAndReturnVerifiedModule(kHloText).value(); @@ -494,7 +605,6 @@ ENTRY %e { ::testing::HasSubstr("Insufficient registers")))); } -// Modify block_k back to 16 once b/337839570 is fixed. // TODO(b/344770374): Make this test not fragile. TEST_F(GemmFusionAutotunerTest, DoNotFilterOutAutotuningKernelSpillingRegisters) { @@ -517,7 +627,7 @@ ENTRY %e { %get-tuple-element.7020 = s8[12288,1536]{1,0} parameter(0) %convert = s8[4,12288]{1,0} parameter(1) ROOT %triton = s8[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %get-tuple-element.7020, s8[4,12288]{1,0} %convert), kind=kCustom, calls=%triton_gemm_dot, - backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"32","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}} + backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"16","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}} })"; auto module = ParseAndReturnVerifiedModule(kHloText).value(); @@ -540,7 +650,6 @@ ENTRY %e { EXPECT_NE(executable, nullptr); } -// Modify block_k back to 16 once b/337839570 is fixed. TEST_F(GemmFusionAutotunerTest, RunAutotuningKernelNotSpillingRegisters) { const std::string kHloText = R"( HloModule m @@ -556,7 +665,7 @@ ENTRY %e { %p0 = s8[12288,1536]{1,0} parameter(0) %p1 = f16[4,12288]{1,0} parameter(1) ROOT %triton_dot = f16[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %p0, f16[4,12288]{1,0} %p1), kind=kCustom, calls=%triton_gemm_dot, - backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"16","block_n":"32","block_k":"32","split_k":"1","num_stages":"1","num_warps":"2","num_ctas":"1"}}} + backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"16","block_n":"32","block_k":"16","split_k":"1","num_stages":"1","num_warps":"2","num_ctas":"1"}}} })"; auto module = ParseAndReturnVerifiedModule(kHloText).value(); @@ -741,7 +850,7 @@ ENTRY e { class GemmFusionAutotunerLevelTest : public StatelessAutotunerTest, public ::testing::WithParamInterface { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = StatelessAutotunerTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_autotune_level(GetParam()); @@ -828,7 +937,7 @@ INSTANTIATE_TEST_SUITE_P(GemmFusionAutotunerLevelSweep, class GemmFusionAutotunerExhaustiveTest : public GemmFusionAutotunerTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GemmFusionAutotunerTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_exhaustive_tiling_search(true); @@ -875,7 +984,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( + GetPossibleMatmulAutotuneTritonConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -886,7 +995,7 @@ ENTRY e { class GemmFusionAutotunerDisableSplitK : public GemmFusionAutotunerTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GemmFusionAutotunerTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_enable_split_k_autotuning(false); @@ -907,7 +1016,7 @@ ENTRY e { se::CudaComputeCapability::AMPERE, /*minor=*/0}; TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( + GetPossibleMatmulAutotuneTritonConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), GetDebugOptionsForTest())); @@ -938,7 +1047,7 @@ ENTRY wais { TF_ASSERT_OK_AND_ASSIGN( const std::vector configs, - GetPossibleMatmulAutotuneConfigs( + GetPossibleMatmulAutotuneTritonConfigs( *Cast( module->entry_computation()->root_instruction()), compute_capability, GetToolkitVersion(), debug_options)); @@ -1002,6 +1111,187 @@ ENTRY entry { CHECK_OK(autotuner.CompileAll(*compile_util, configs)); } +TEST_F(GemmFusionAutotunerTest, CreatesCustomKernelFusionConfigs) { + const std::string kHlo = R"( + HloModule module, entry_computation_layout={(bf16[1024,1024]{1,0}, bf16[1024,1024]{1,0})->f32[1024,1024]{1,0}} + + %gemm_fusion_r_computation { + %parameter_0 = bf16[1024,1024]{1,0} parameter(0) + %convert.2 = f32[1024,1024]{1,0} convert(%parameter_0) + %parameter_1 = bf16[1024,1024]{1,0} parameter(1) + %convert.3 = f32[1024,1024]{1,0} convert(%parameter_1) + ROOT %r.1 = f32[1024,1024]{1,0} dot(%convert.2, %convert.3), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + + ENTRY main { + %p0 = bf16[1024,1024]{1,0} parameter(0) + %p1 = bf16[1024,1024]{1,0} parameter(1) + ROOT %gemm_fusion_r = f32[1024,1024]{1,0} fusion(%p0, %p1), kind=kCustom, calls=gemm_fusion_r_computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false} + })"; + + std::unique_ptr module = + ParseAndReturnVerifiedModule(kHlo).value(); + const se::CudaComputeCapability compute_capability{ + se::CudaComputeCapability::AMPERE, /*minor=*/0}; + + TF_ASSERT_OK_AND_ASSIGN( + const std::vector configs, + GetPossibleMatmulAutotuneConfigs(*module, compute_capability, + GetToolkitVersion(), + GetDebugOptionsForTest())); + EXPECT_TRUE(std::any_of( + configs.begin(), configs.end(), + [](const GemmFusionAutotunerImpl::BackendConfig& config) { + return std::holds_alternative< + GemmFusionAutotunerImpl::CustomKernelFusionConfig>(config); + })); +} + +TEST_F(GemmFusionAutotunerTest, GeneratesConfigForUpcastGemmWithPrologue) { + const std::string kHlo = R"( + HloModule module + + %gemm_fusion_r_computation (parameter_0.1: f32[1,256,4,4096], parameter_1.1: bf16[1,4,4096,4096]) -> f32[256,4096] { + %parameter_0.1 = f32[1,256,4,4096]{3,2,1,0} parameter(0) + %bitcast.60 = f32[256,16384]{1,0} bitcast(f32[1,256,4,4096]{3,2,1,0} %parameter_0.1) + %parameter_1.1 = bf16[1,4,4096,4096]{3,2,1,0} parameter(1) + %bitcast.61 = bf16[16384,4096]{1,0} bitcast(bf16[1,4,4096,4096]{3,2,1,0} %parameter_1.1) + %convert.22 = f32[16384,4096]{1,0} convert(bf16[16384,4096]{1,0} %bitcast.61) + ROOT r = f32[256,4096]{1,0} dot(f32[256,16384]{1,0} %bitcast.60, f32[16384,4096]{1,0} %convert.22), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + + ENTRY main { + %p0 = f32[1,256,4,4096] parameter(0) + %p1 = bf16[1,4,4096,4096] parameter(1) + ROOT %gemm_fusion_r = f32[256,4096] fusion(%p0, %p1), kind=kCustom, + calls=gemm_fusion_r_computation, + backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false} + } +)"; + + std::unique_ptr module = + ParseAndReturnVerifiedModule(kHlo).value(); + const se::CudaComputeCapability compute_capability{ + se::CudaComputeCapability::AMPERE, /*minor=*/0}; + + TF_ASSERT_OK_AND_ASSIGN( + const std::vector configs, + GetPossibleMatmulAutotuneConfigs(*module, compute_capability, + GetToolkitVersion(), + GetDebugOptionsForTest())); + EXPECT_TRUE(std::any_of( + configs.begin(), configs.end(), + [](const GemmFusionAutotunerImpl::BackendConfig& config) { + return std::holds_alternative< + GemmFusionAutotunerImpl::CustomKernelFusionConfig>(config); + })); +} + +TEST_F(GemmFusionAutotunerTest, + GeneratesConfigForUpcastGemmWithPrologueAndEpilogue) { + const std::string kHlo = R"( + HloModule module + + %gemm_fusion_r_computation (parameter_0.1: f32[1,256,4,4096], parameter_1.1: bf16[1,4,4096,4096]) -> bf16[1048576] { + %parameter_0.1 = f32[1,256,4,4096]{3,2,1,0} parameter(0) + %bitcast.60 = f32[256,16384]{1,0} bitcast(f32[1,256,4,4096]{3,2,1,0} %parameter_0.1) + %parameter_1.1 = bf16[1,4,4096,4096]{3,2,1,0} parameter(1) + %bitcast.61 = bf16[16384,4096]{1,0} bitcast(bf16[1,4,4096,4096]{3,2,1,0} %parameter_1.1) + %convert.22 = f32[16384,4096]{1,0} convert(bf16[16384,4096]{1,0} %bitcast.61) + %dot.5 = f32[256,4096]{1,0} dot(f32[256,16384]{1,0} %bitcast.60, f32[16384,4096]{1,0} %convert.22), lhs_contracting_dims={1}, rhs_contracting_dims={0} + %convert.23 = bf16[256,4096]{1,0} convert(f32[256,4096]{1,0} %dot.5) + %bitcast.62 = bf16[1,256,4096]{2,1,0} bitcast(bf16[256,4096]{1,0} %convert.23) + %transpose.18 = bf16[1,4096,256]{2,1,0} transpose(bf16[1,256,4096]{2,1,0} %bitcast.62), dimensions={0,2,1} + ROOT %bitcast.63 = bf16[1048576]{0} bitcast(bf16[1,4096,256]{2,1,0} %transpose.18) + } + + ENTRY main { + %p0 = f32[1,256,4,4096] parameter(0) + %p1 = bf16[1,4,4096,4096] parameter(1) + ROOT %gemm_fusion_r = bf16[1048576] fusion(%p0, %p1), kind=kCustom, + calls=gemm_fusion_r_computation, + backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false} + } +)"; + + std::unique_ptr module = + ParseAndReturnVerifiedModule(kHlo).value(); + const se::CudaComputeCapability compute_capability{ + se::CudaComputeCapability::AMPERE, /*minor=*/0}; + + TF_ASSERT_OK_AND_ASSIGN( + const std::vector configs, + GetPossibleMatmulAutotuneConfigs(*module, compute_capability, + GetToolkitVersion(), + GetDebugOptionsForTest())); + EXPECT_TRUE(std::any_of( + configs.begin(), configs.end(), + [](const GemmFusionAutotunerImpl::BackendConfig& config) { + return std::holds_alternative< + GemmFusionAutotunerImpl::CustomKernelFusionConfig>(config); + })); +} + +TEST_F(GemmFusionAutotunerTest, RewritesGemmFusionToCustomKernelFusion) { + const std::string kHlo = R"( + HloModule module, entry_computation_layout={(bf16[1024,1024]{1,0}, bf16[1024,1024]{1,0})->f32[1024,1024]{1,0}} + + %gemm_fusion_r_computation { + %parameter_0 = bf16[1024,1024]{1,0} parameter(0) + %convert.2 = f32[1024,1024]{1,0} convert(%parameter_0) + %parameter_1 = bf16[1024,1024]{1,0} parameter(1) + %convert.3 = f32[1024,1024]{1,0} convert(%parameter_1) + ROOT %r.1 = f32[1024,1024]{1,0} dot(%convert.2, %convert.3), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + + ENTRY main { + %p0 = bf16[1024,1024]{1,0} parameter(0) + %p1 = bf16[1024,1024]{1,0} parameter(1) + ROOT %gemm_fusion_r = f32[1024,1024]{1,0} fusion(%p0, %p1), kind=kCustom, calls=gemm_fusion_r_computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false} + } +)"; + + std::unique_ptr module = + ParseAndReturnVerifiedModule(kHlo).value(); + + DebugOptions opts; + AutotuneConfig autotune_config{ + DeviceConfig{backend().default_stream_executor(), + backend().memory_allocator()}, + opts}; + AutotuneCacheKey cache_key(autotune_config.GetModelStr(), + *module->entry_computation()->root_instruction()); + TF_ASSERT_OK_AND_ASSIGN(AutotuneResults autotune_results_override, + ParseTextProto(R"pb( + version: 3 + results { + device: "..." + hlo: "..." + result { + custom_kernel_fusion { kernel_index: 1 } + run_time { nanos: 14 } + } + })pb")); + autotune_results_override.mutable_results(0)->set_device( + std::string(cache_key.GetModelStr())); + autotune_results_override.mutable_results(0)->set_hlo( + std::string(cache_key.GetHlo())); + + GemmFusionAutotunerRewriterVisitor visitor(autotune_config); + + CHECK_OK(AutotunerUtil::LoadAutotuneResults(autotune_results_override)); + visitor.RunOnModule(module.get(), {}).value(); + std::string pattern = R"( + CHECK: ROOT %cutlass_gemm_with_upcast + CHECK-SAME: fusion + CHECK-SAME: kind=kCustom + CHECK-SAME: "kernel_index":1 + )"; + TF_ASSERT_OK_AND_ASSIGN(bool file_check_matches, + RunFileCheck(module->ToString(), pattern)); + EXPECT_TRUE(file_check_matches); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/backend_configs.proto b/third_party/xla/xla/service/gpu/backend_configs.proto index b23fe4f95629a7..0522e8f2c11080 100644 --- a/third_party/xla/xla/service/gpu/backend_configs.proto +++ b/third_party/xla/xla/service/gpu/backend_configs.proto @@ -3,8 +3,8 @@ syntax = "proto3"; package xla.gpu; import "xla/autotuning.proto"; +import "xla/tsl/protobuf/dnn.proto"; import "xla/xla_data.proto"; -import "tsl/protobuf/dnn.proto"; // Backend configs for XLA:GPU. // @@ -124,6 +124,9 @@ message BitcastBackendConfig { message CollectiveBackendConfig { bool is_sync = 1; bool no_parallel_custom_call = 2; + // Determines whether the collective op of interested has been pipelined + // within a loop. + bool is_pipelined = 3; } // Backend config for cost model estimates. diff --git a/third_party/xla/xla/service/gpu/buffer_sharing.cc b/third_party/xla/xla/service/gpu/buffer_sharing.cc index 0ffb8e3fe63de9..92dd1f174a7f38 100644 --- a/third_party/xla/xla/service/gpu/buffer_sharing.cc +++ b/third_party/xla/xla/service/gpu/buffer_sharing.cc @@ -28,6 +28,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/primitive_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/hlo_fusion_analysis.h" @@ -66,10 +67,10 @@ std::optional FusionCanShareBufferHint(const HloInstruction* user, } // The buffers needed for 'user_subshape' and 'operand_shape' need to have // the same size, otherwise they cannot be shared. We already checked that - // the number of elements are the same, so now we check the number of bytes + // the number of elements are the same, so now we check the number of bits // needed for the element types. - if (ShapeUtil::ByteSizeOfPrimitiveType(operand_shape.element_type()) != - ShapeUtil::ByteSizeOfPrimitiveType(user_subshape.element_type())) { + if (primitive_util::BitWidth(operand_shape.element_type()) != + primitive_util::BitWidth(user_subshape.element_type())) { return false; } } diff --git a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc index 39f00c826fdc7c..ba51d1ee15fbf5 100644 --- a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc +++ b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc @@ -42,6 +42,8 @@ limitations under the License. #include "mlir/Pass/PassManager.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" +#include "xla/hlo/analysis/hlo_ordering.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/buffer_assignment.h" @@ -54,8 +56,6 @@ limitations under the License. #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/ir_emitter_unnested.h" #include "xla/service/gpu/metrics.h" -#include "xla/service/hlo_dataflow_analysis.h" -#include "xla/service/hlo_ordering.h" #include "xla/service/logical_buffer.h" #include "xla/shape.h" #include "xla/status_macros.h" diff --git a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h index a451af5a149fad..3e75fa156ac8c6 100644 --- a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h +++ b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/buffer_assignment.h" #include "xla/service/buffer_value.h" @@ -35,7 +36,6 @@ limitations under the License. #include "xla/service/gpu/runtime/sequential_thunk.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_dataflow_analysis.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" diff --git a/third_party/xla/xla/service/gpu/cublas_cudnn.cc b/third_party/xla/xla/service/gpu/cublas_cudnn.cc index 18e131eee8f108..a9d94e8ed8ae33 100644 --- a/third_party/xla/xla/service/gpu/cublas_cudnn.cc +++ b/third_party/xla/xla/service/gpu/cublas_cudnn.cc @@ -128,11 +128,13 @@ bool IsCustomCallToDnnNorm(const HloInstruction& hlo) { } bool IsFwdCustomCallTofMHAF8(const HloInstruction& hlo) { - if (hlo.opcode() != HloOpcode::kCustomCall) { - return false; - } - const auto& target = hlo.custom_call_target(); - return target == kCudnnfMHASoftmaxF8CallTarget; + return hlo.opcode() == HloOpcode::kCustomCall && + hlo.custom_call_target() == kCudnnfMHASoftmaxF8CallTarget; +} + +bool IsBwdCustomCallTofMHAF8(const HloInstruction& hlo) { + return hlo.opcode() == HloOpcode::kCustomCall && + hlo.custom_call_target() == kCudnnfMHASoftmaxBackwardF8CallTarget; } bool IsFwdCustomCallTofMHA(const HloInstruction& hlo) { @@ -169,7 +171,7 @@ bool IsCustomCallTofMHA(const HloInstruction& hlo) { } bool IsCustomCallTofMHAF8(const HloInstruction& hlo) { - return IsFwdCustomCallTofMHAF8(hlo); + return IsFwdCustomCallTofMHAF8(hlo) || IsBwdCustomCallTofMHAF8(hlo); } bool IsCubDeviceRadixSort(const HloInstruction& hlo) { diff --git a/third_party/xla/xla/service/gpu/cublas_cudnn.h b/third_party/xla/xla/service/gpu/cublas_cudnn.h index 9befcbb60901b1..d3f0a1ce22fdb3 100644 --- a/third_party/xla/xla/service/gpu/cublas_cudnn.h +++ b/third_party/xla/xla/service/gpu/cublas_cudnn.h @@ -188,6 +188,7 @@ extern const absl::string_view kCudnnfMHASoftmaxDropoutCallTarget; extern const absl::string_view kCudnnfMHAScaleBiasSoftmaxDropoutCallTarget; extern const absl::string_view kCudnnfMHAScaleBiasSoftmaxCallTarget; // Backward calls +extern const absl::string_view kCudnnfMHASoftmaxBackwardF8CallTarget; extern const absl::string_view kCudnnfMHASoftmaxBackwardCallTarget; extern const absl::string_view kCudnnfMHASoftmaxDropoutBackwardCallTarget; extern const absl::string_view @@ -195,6 +196,7 @@ extern const absl::string_view extern const absl::string_view kCudnnfMHAScaleBiasSoftmaxBackwardCallTarget; bool IsFwdCustomCallTofMHAF8(const HloInstruction& hlo); +bool IsBwdCustomCallTofMHAF8(const HloInstruction& hlo); bool IsCustomCallTofMHAF8(const HloInstruction& hlo); bool IsFwdCustomCallTofMHA(const HloInstruction& hlo); bool IsBwdCustomCallTofMHA(const HloInstruction& hlo); diff --git a/third_party/xla/xla/service/gpu/cudnn_support_utils_test.cc b/third_party/xla/xla/service/gpu/cudnn_support_utils_test.cc index 0cc170fb1a32f9..48ca7ef87579d3 100644 --- a/third_party/xla/xla/service/gpu/cudnn_support_utils_test.cc +++ b/third_party/xla/xla/service/gpu/cudnn_support_utils_test.cc @@ -29,7 +29,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" diff --git a/third_party/xla/xla/service/gpu/custom_call_test.cc b/third_party/xla/xla/service/gpu/custom_call_test.cc index 2de44347f8ce35..21ebba8347e9dd 100644 --- a/third_party/xla/xla/service/gpu/custom_call_test.cc +++ b/third_party/xla/xla/service/gpu/custom_call_test.cc @@ -39,11 +39,11 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" #include "xla/ffi/execution_context.h" #include "xla/ffi/ffi.h" #include "xla/ffi/ffi_api.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -52,7 +52,6 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream.h" @@ -78,6 +77,20 @@ limitations under the License. #define gpuMemcpyHostToDevice hipMemcpyHostToDevice #endif +namespace xla { + +struct Range { + int64_t lo; + int64_t hi; +}; + +} // namespace xla + +// Register struct types with XLA:FFI to enable automatic decoding from +// dictionary attributes to structs. +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(::xla::Range, StructMember("lo"), + StructMember("hi")); + namespace xla { namespace { @@ -390,9 +403,10 @@ static absl::Status Memcpy(se::Stream* stream, ffi::AnyBuffer src, XLA_FFI_DEFINE_HANDLER(kMemcpy, Memcpy, ffi::Ffi::Bind() .Ctx() - .Arg() // src - .Ret() // dst -); + .Arg() // src + .Ret(), // dst + {ffi::Traits::kCmdBufferCompatible}); + XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$memcpy", PLATFORM, kMemcpy); @@ -614,20 +628,26 @@ TEST_F(CustomCallTest, ExportedFfiWithStatusSucceeded) { //===----------------------------------------------------------------------===// static absl::Status FfiAttributes(ffi::Result, - absl::Span i32_arr) { + absl::Span i32_arr, + Range range) { if (i32_arr.size() != 4) return absl::InternalError("i32_arr size does not match"); if (i32_arr[0] != 1 || i32_arr[1] != 2 || i32_arr[2] != 3 || i32_arr[3] != 4) return absl::InternalError("i32_arr values do not match"); + if (range.lo != 0 || range.hi != 42) { + return absl::InternalError("range values do not match"); + } + return absl::OkStatus(); } XLA_FFI_DEFINE_HANDLER(kFfiAttributes, FfiAttributes, ffi::Ffi::Bind() .Ret() - .Attr>("i32_arr")); + .Attr>("i32_arr") + .Attr("range")); XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "xla.gpu.ffi_attributes", PLATFORM, kFfiAttributes); @@ -636,7 +656,9 @@ TEST_F(CustomCallTest, FfiAttributes) { XlaBuilder b(TestName()); CustomCall(&b, "xla.gpu.ffi_attributes", /*operands=*/{}, ShapeUtil::MakeShape(F32, {}), - /*opaque=*/"{ i32_arr = array }", + /*opaque=*/ + "{ i32_arr = array," + " range = { lo = 0 : i64, hi = 42 : i64 } }", /*has_side_effect=*/false, /*output_operand_aliasing=*/{}, /*literal=*/nullptr, /*schedule=*/CustomCallSchedule::SCHEDULE_NONE, @@ -650,7 +672,6 @@ TEST_F(CustomCallTest, FfiAttributes) { static absl::Status MemcpyWithCalledComputation( se::Stream* stream, int32_t device_ordinal, - se::DeviceMemoryAllocator* allocator, se::OwningScratchAllocator<> scratch_allocator, ffi::AnyBuffer src, ffi::Result dst, const HloComputation* called_computation) { if (called_computation == nullptr) @@ -674,7 +695,6 @@ XLA_FFI_DEFINE_HANDLER(kMemcpyWithCalledComputation, ffi::Ffi::Bind() .Ctx() .Ctx() // device_ordinal - .Ctx() // allocator .Ctx() // scratch .Arg() // src .Ret() // dst @@ -715,6 +735,7 @@ TEST_F(CustomCallTest, WithCalledComputation) { struct SomeExtraContext { explicit SomeExtraContext(int32_t value) : value(value) {} int32_t value; + bool prepared = false; bool initialized = false; bool executed = false; }; @@ -723,15 +744,25 @@ template static absl::Status ExecutionContext(ffi::Result, SomeExtraContext* ctx) { if (ctx->value != 42) return absl::InternalError("Unexpected value"); - if constexpr (stage == ffi::ExecutionStage::kInitialize) { + if constexpr (stage == ffi::ExecutionStage::kPrepare) { + ctx->prepared = true; + } else if constexpr (stage == ffi::ExecutionStage::kInitialize) { ctx->initialized = true; - } else { + } else if constexpr (stage == ffi::ExecutionStage::kExecute) { ctx->executed = true; + } else { + return absl::InternalError("Unexpected stage"); } return absl::OkStatus(); } +XLA_FFI_DEFINE_HANDLER(kExecutionContextPrepare, + ExecutionContext, + ffi::Ffi::Bind() + .Ret() + .Ctx>()); + XLA_FFI_DEFINE_HANDLER(kExecutionContextInitialize, ExecutionContext, ffi::Ffi::Bind() @@ -748,7 +779,7 @@ XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "xla.gpu.ffi_execution_context", PLATFORM, { /*instantiate=*/nullptr, - /*prepare=*/nullptr, + /*prepare=*/kExecutionContextPrepare, /*initialize=*/kExecutionContextInitialize, /*execute=*/kExecutionContextExecute, }); @@ -774,6 +805,7 @@ TEST_F(CustomCallTest, FfiExecutionContext) { // Check that FFI handler was called during initialization and execution. TF_ASSERT_OK_AND_ASSIGN(auto* user_context, execution_context.Lookup()); + EXPECT_TRUE(user_context->prepared); EXPECT_TRUE(user_context->initialized); EXPECT_TRUE(user_context->executed); } diff --git a/third_party/xla/xla/service/gpu/determinism_test.cc b/third_party/xla/xla/service/gpu/determinism_test.cc index cc6747af620056..19a3d1390fff1d 100644 --- a/third_party/xla/xla/service/gpu/determinism_test.cc +++ b/third_party/xla/xla/service/gpu/determinism_test.cc @@ -19,17 +19,27 @@ limitations under the License. #include #include +#include #include +#include "absl/log/log.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" +#include "xla/service/backend.h" #include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/service/platform_util.h" #include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/gpu/gpu_timer.h" +#include "xla/stream_executor/gpu/mock_gpu_executor.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" #include "tsl/platform/statusor.h" @@ -40,9 +50,6 @@ class DeterminismTest : public GpuCodegenTest { public: DeterminismTest() : debug_options_(HloTestBase::GetDebugOptionsForTest()) { debug_options_.set_xla_gpu_exclude_nondeterministic_ops(true); - // Randomize timer durations to better test autotuning does not introduce - // nondeterminism. - se::gpu::GpuTimer::ReturnRandomDurationsForTesting(); } // Runs the HLO several times with the same random inputs, and asserts the @@ -77,10 +84,57 @@ class DeterminismTest : public GpuCodegenTest { } } - DebugOptions GetDebugOptionsForTest() override { return debug_options_; } + DebugOptions GetDebugOptionsForTest() const override { + return debug_options_; + } DebugOptions debug_options_; + enum class TimerCreation { kAllowed, kForbidden }; + + // Runs the HLO passes with the given HLO string and matches the + // resulting HLO against the given expect_hlo_regex using FileCheck. + // + // Calls to GpuExecutor::CreateEventBasedTimer can be forbidden by setting + // timer_creation to kForbidden. (The test fails when the function is called + // in this case.) + void MatchOptimizedHlo(absl::string_view hlo_string, + absl::string_view expected_hlo_regex, + TimerCreation timer_creation) { + if (timer_creation == TimerCreation::kAllowed) { + HloTestBase::MatchOptimizedHlo(hlo_string, expected_hlo_regex); + return; + } + + // If timer creation is forbidden we inject a mock GPU executor that + // prevents timer creation. + TF_ASSERT_OK_AND_ASSIGN(stream_executor::Platform * default_platform, + PlatformUtil::GetDefaultPlatform()); + stream_executor::gpu::MockGpuExecutor executor(default_platform, 0); + EXPECT_CALL(executor, CreateEventBasedTimer).Times(0); + EXPECT_CALL(executor, GetDeviceDescription) + .WillRepeatedly([this]() -> const se::DeviceDescription& { + return backend().default_stream_executor()->GetDeviceDescription(); + }); + EXPECT_CALL(executor, GetPlatform).WillRepeatedly([&]() { + return default_platform; + }); + EXPECT_CALL(executor, AsDnn).WillRepeatedly([&]() { + return backend().default_stream_executor()->AsDnn(); + }); + EXPECT_CALL(executor, device_ordinal).WillRepeatedly([]() { return 0; }); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(auto optimized_module, + backend().compiler()->RunHloPasses( + std::move(module), &executor, GetAllocator())); + absl::StatusOr filecheck_result = + RunFileCheck(optimized_module->ToString(), expected_hlo_regex); + TF_ASSERT_OK(filecheck_result.status()); + EXPECT_TRUE(filecheck_result.value()); + } + bool IsVoltaOrLater() const { return backend() .default_stream_executor() @@ -122,12 +176,14 @@ ENTRY e { } debug_options_.set_xla_gpu_triton_fusion_level(0); - MatchOptimizedHlo(kHloText, R"(; CHECK: custom_call_target="__cublas$gemm")"); + MatchOptimizedHlo(kHloText, R"(; CHECK: custom_call_target="__cublas$gemm")", + TimerCreation::kForbidden); AssertDeterminism(kHloText); debug_options_.set_xla_gpu_enable_cublaslt(true); MatchOptimizedHlo(kHloText, - R"(; CHECK: custom_call_target="__cublas$lt$matmul")"); + R"(; CHECK: custom_call_target="__cublas$lt$matmul")", + TimerCreation::kForbidden); AssertDeterminism(kHloText); } @@ -152,7 +208,8 @@ ENTRY e { MatchOptimizedHlo(kHloText, R"( CHECK: __triton_gemm CHECK: {"block_m":"32","block_n":"32","block_k":"32","split_k":"1","num_stages":"1","num_warps":"4","num_ctas":"1"} - )"); + )", + TimerCreation::kForbidden); AssertDeterminism(kHloText, /*num_runs=*/3); } @@ -177,7 +234,8 @@ ENTRY e { R"( CHECK: __triton_gemm CHECK-NOT: {"block_m":"32","block_n":"32","block_k":"32","split_k":"1","num_stages":"1","num_warps":"4","num_ctas":"1"} - )"); + )", + TimerCreation::kAllowed); } TEST_F(DeterminismTest, Conv) { diff --git a/third_party/xla/xla/service/gpu/dot_algorithm_support_test.cc b/third_party/xla/xla/service/gpu/dot_algorithm_support_test.cc index bb17f1a6ed58f1..ce4c90d000ff7d 100644 --- a/third_party/xla/xla/service/gpu/dot_algorithm_support_test.cc +++ b/third_party/xla/xla/service/gpu/dot_algorithm_support_test.cc @@ -115,7 +115,7 @@ class DotAlgorithmSupportTest return GetDeviceDescription().gpu_compute_capability(); } - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); // Setting this explicitly to make sure that we also test the case when the // dot's dimensions are under the rewrite size threshold: @@ -173,10 +173,10 @@ TEST_P(DotAlgorithmSupportTest, AlgorithmIsSupportedFromCudaCapability) { if (params.backend_restriction == BackendRestriction::kTritonOnly) { MatchOptimizedHlo(hlo_text, R"( - ;CHECK: ENTRY - ;CHECK: ROOT - ;CHECK-SAME: kCustom - ;CHECK-SAME: "triton_gemm_config" + ;CHECK: ENTRY + ;CHECK: ROOT + ;CHECK-SAME: kCustom + ;CHECK-SAME: "triton_gemm_config" )"); } } else { @@ -215,22 +215,27 @@ INSTANTIATE_TEST_SUITE_P(DotF16F16F32Tests, DotAlgorithmSupportTest, Values(Sizes{32, 32}, Sizes{16, 2})), TestParamsToString); -INSTANTIATE_TEST_SUITE_P(DotBf16Bf16F32Tests, DotAlgorithmSupportTest, - Combine(Values(PC::ALG_DOT_BF16_BF16_F32), - Values(BF16), Values(BF16, F32), - Values(CC(8, 0)), +INSTANTIATE_TEST_SUITE_P(DotF32ForBf16Bf16F32Tests, DotAlgorithmSupportTest, + Combine(Values(PC::ALG_DOT_BF16_BF16_F32), Values(F32), + Values(F32), Values(CC(8, 0)), Values(SemanticVersion{6, 0, 0}), Values(BackendRestriction::kNoRestriction), Values(Sizes{32, 32}, Sizes{16, 2})), TestParamsToString); -INSTANTIATE_TEST_SUITE_P(DotBf16Bf16F32XnTests, DotAlgorithmSupportTest, +INSTANTIATE_TEST_SUITE_P(DotBf16Bf16F32X3Tests, DotAlgorithmSupportTest, + Combine(Values(PC::ALG_DOT_BF16_BF16_F32_X3), + Values(F32), Values(F32), Values(CC(8, 0)), + Values(SemanticVersion{6, 0, 0}), + Values(BackendRestriction::kNoRestriction), + Values(Sizes{32, 32}, Sizes{16, 2})), + TestParamsToString); - Combine(Values(PC::ALG_DOT_BF16_BF16_F32_X3, - PC::ALG_DOT_BF16_BF16_F32_X6), +INSTANTIATE_TEST_SUITE_P(DotBf16Bf16F32X6Tests, DotAlgorithmSupportTest, + Combine(Values(PC::ALG_DOT_BF16_BF16_F32_X6), Values(F32), Values(F32), Values(CC(8, 0)), Values(SemanticVersion{6, 0, 0}), - Values(BackendRestriction::kTritonOnly), + Values(BackendRestriction::kNoRestriction), Values(Sizes{32, 32}, Sizes{16, 2})), TestParamsToString); diff --git a/third_party/xla/xla/service/gpu/float_support_test.cc b/third_party/xla/xla/service/gpu/float_support_test.cc index 5822d10a8deeda..c7c294bfbd2920 100644 --- a/third_party/xla/xla/service/gpu/float_support_test.cc +++ b/third_party/xla/xla/service/gpu/float_support_test.cc @@ -39,7 +39,7 @@ class FloatSupportTest : public HloTestBase { class FloatSupportTestWithCublas : public FloatSupportTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = FloatSupportTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_enable_triton_gemm(false); return debug_options; @@ -48,7 +48,7 @@ class FloatSupportTestWithCublas : public FloatSupportTest { class FloatSupportTestWithTriton : public FloatSupportTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = FloatSupportTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_enable_triton_gemm(true); debug_options.set_xla_gpu_triton_gemm_any(true); diff --git a/third_party/xla/xla/service/gpu/fusion_dispatch_pipeline.cc b/third_party/xla/xla/service/gpu/fusion_dispatch_pipeline.cc new file mode 100644 index 00000000000000..435ad70c473282 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusion_dispatch_pipeline.cc @@ -0,0 +1,138 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/fusion_dispatch_pipeline.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "llvm/Support/MathExtras.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/layout_util.h" +#include "xla/service/gpu/transforms/fusion_block_level_rewriter.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/pattern_matcher.h" +#include "xla/stream_executor/device_description.h" +#include "xla/xla.pb.h" + +namespace xla { +namespace gpu { + +namespace { + +namespace m = ::xla::match; + +bool IsSlowLoopTransposeFusion(const HloFusionInstruction* fusion) { + const HloInstruction* root = + fusion->fused_instructions_computation()->root_instruction(); + + bool is_loop_transpose_fusion = + fusion->fusion_kind() == HloInstruction::FusionKind::kLoop && + root->opcode() == HloOpcode::kTranspose; + + if (!is_loop_transpose_fusion) { + return false; + } + + // The slow transposes are those when the minormost dimension in the input + // is neither the minormost nor the second minormost dimension in the output, + // and the output minormost dimension is swapped with the new minormost + // dimension. + int64_t rank = root->shape().rank(); + + // The transpose dimension grouper has run, so it should be enough to check + // that the minormost dimension's index within the result is smaller than + // rank - 2, and that the new minormost dimension is swapped with it. + // This only triggers for transposes with major-to-minor layout. + bool has_major_to_minor_layout = + LayoutUtil::IsMonotonicWithDim0Major(root->shape().layout()); + absl::Span transpose_dimensions = root->dimensions(); + int64_t result_minormost_dim_in_operand = transpose_dimensions.back(); + + return has_major_to_minor_layout && + transpose_dimensions[result_minormost_dim_in_operand] == rank - 1 && + transpose_dimensions[rank - 1] < rank - 2; +} + +// Pattern-matches slow loop fusions that can likely be handled better by +// Triton than by other emitters. +// TODO(b/370690811,b/372187266): generalize this to other slow transposes. +bool FusionWillBeHandledBetterByTriton( + const HloFusionInstruction* fusion, + const se::DeviceDescription& device_description) { + if (!IsSlowLoopTransposeFusion(fusion)) { + return false; + } + + const HloInstruction* root = + fusion->fused_instructions_computation()->root_instruction(); + + // Because of Triton's power-of-two restriction, we're only guaranteed to + // handle the bitcast case when the bitcast's minor dimension is a power of + // two. This ensures that we can tile it reasonably even if the bitcast's + // input has that dimension collapsed. (See comments in `symbolic_tile.cc` + // around destructuring summations to understand why this is important.) + auto can_bitcast_input_be_tiled_efficiently = + [](const HloInstruction* bitcast) { + return llvm::isPowerOf2_64(bitcast->shape().dimensions_minor(0)); + }; + + bool is_pure_transpose = ::xla::Match(root, m::Transpose(m::Parameter())); + bool is_bitcasted_transpose_with_power_of_two_minor_dim = ::xla::Match( + root, + m::Transpose(m::Bitcast(m::Parameter()) + .WithPredicate(can_bitcast_input_be_tiled_efficiently))); + return is_pure_transpose || + is_bitcasted_transpose_with_power_of_two_minor_dim; +} + +} // anonymous namespace + +HloPassPipeline FusionDispatchPipeline( + const se::DeviceDescription& device_description, + HloCostAnalysis::ShapeSizeFunction shape_size_fn) { + std::function(const HloFusionInstruction*)> + try_rewrite_fusion_if = + [&device_description]( + const HloFusionInstruction* fusion) -> absl::StatusOr { + bool should_always_rewrite_to_block_level = + fusion->GetModule() + ->config() + .debug_options() + .xla_gpu_experimental_enable_fusion_block_level_rewriter(); + + // TODO(b/370690811): this rewrite may no longer be necessary once MLIR + // emitters transposes are faster. + return should_always_rewrite_to_block_level || + FusionWillBeHandledBetterByTriton(fusion, device_description); + }; + + // Even though this is a single pass, we need to create a pipeline in order + // to make sure the pass's run is recorded in the `HloModuleMetadata`. + HloPassPipeline pipeline("fusion-dispatch-pipeline"); + pipeline.AddPass(device_description, shape_size_fn, + std::move(try_rewrite_fusion_if)); + return pipeline; +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusion_dispatch_pipeline.h b/third_party/xla/xla/service/gpu/fusion_dispatch_pipeline.h new file mode 100644 index 00000000000000..7256f9d2d9567d --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusion_dispatch_pipeline.h @@ -0,0 +1,36 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_FUSION_DISPATCH_PIPELINE_H_ +#define XLA_SERVICE_GPU_FUSION_DISPATCH_PIPELINE_H_ + +#include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/stream_executor/device_description.h" +#include "xla/xla.pb.h" + +namespace xla { +namespace gpu { + +// Returns a pipeline that attempts to redirect fusions to the most efficient +// emitter possible. +HloPassPipeline FusionDispatchPipeline( + const se::DeviceDescription& device_description, + HloCostAnalysis::ShapeSizeFunction shape_size_fn); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSION_DISPATCH_PIPELINE_H_ diff --git a/third_party/xla/xla/service/gpu/fusion_pipeline.cc b/third_party/xla/xla/service/gpu/fusion_pipeline.cc index 7bc4f170980ea1..4b56e92a952c2c 100644 --- a/third_party/xla/xla/service/gpu/fusion_pipeline.cc +++ b/third_party/xla/xla/service/gpu/fusion_pipeline.cc @@ -20,6 +20,7 @@ limitations under the License. #include "xla/hlo/pass/hlo_pass_fix.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/service/cpu_gpu_shape_verifier.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/transforms/fusion_merger.h" @@ -31,7 +32,6 @@ limitations under the License. #include "xla/service/gpu/transforms/variadic_op_splitter.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_cse.h" -#include "xla/service/hlo_dce.h" #include "xla/service/hlo_verifier.h" #include "xla/service/layout_assignment.h" #include "xla/stream_executor/device_description.h" @@ -50,29 +50,23 @@ HloPassPipeline FusionPipeline( // We try to split variadic ops with many parameters into several such ops // to avoid exceeding the parameter space. fusion.AddPass(); + HloVerifierOpts opts = + HloVerifierOpts().MakeLayoutSensitive().WithInstructionCanChangeLayout( + LayoutAssignment::InstructionCanChangeLayout); + opts.verify_unique_channel_ids = + !debug_options.xla_experimental_ignore_channel_id(); fusion.AddInvariantCheckerDebug( - std::make_unique( - HloVerifierOpts() - .MakeLayoutSensitive() - .WithInstructionCanChangeLayout( - LayoutAssignment::InstructionCanChangeLayout)), + std::make_unique(std::move(opts)), "hlo verifier (debug)"); - if (debug_options.xla_gpu_enable_priority_fusion()) { - GpuHloCostAnalysis::Options cost_analysis_options{ - shape_size_bytes_function, - /*per_second_rates=*/{}, - /*min_latencies_seconds=*/{}, - /*count_multiple_input_accesses=*/true}; - fusion.AddPass(thread_pool, gpu_device_info, - std::move(cost_analysis_options)); - } else { - fusion.AddPass(/*may_duplicate=*/false, - gpu_device_info); - fusion.AddPass(/*may_duplicate=*/true, - gpu_device_info); - fusion.AddPass(gpu_device_info, shape_size_bytes_function); - } + GpuHloCostAnalysis::Options cost_analysis_options{ + shape_size_bytes_function, + /*per_second_rates=*/{}, + /*min_latencies_seconds=*/{}, + /*count_multiple_input_accesses=*/true}; + fusion.AddPass(thread_pool, gpu_device_info, + std::move(cost_analysis_options)); + // Running CSE affects how many users an op has. This plays a role in what // we detect as a tiled transpose fusion. fusion.AddPass(/*is_layout_sensitive=*/true, @@ -82,6 +76,7 @@ HloPassPipeline FusionPipeline( fusion.AddPass(/*is_layout_sensitive=*/true, /*only_fusion_computations=*/true); fusion.AddPass(); + return std::move(fusion); } diff --git a/third_party/xla/xla/service/gpu/fusion_process_dump_test.cc b/third_party/xla/xla/service/gpu/fusion_process_dump_test.cc index 37eb3bee29c836..fcaa298b4489db 100644 --- a/third_party/xla/xla/service/gpu/fusion_process_dump_test.cc +++ b/third_party/xla/xla/service/gpu/fusion_process_dump_test.cc @@ -20,9 +20,9 @@ limitations under the License. #include #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/service/gpu/fusion_process_dump.pb.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" -#include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/test.h" diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index 290c451dfffb8b..5976590a4eae5f 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -16,7 +16,9 @@ cc_library( "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu/fusions/mlir:computation_partitioner", @@ -61,6 +63,7 @@ cc_library( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":fusion_emitter", + "//xla:literal", "//xla:shape_util", "//xla:status_macros", "//xla:util", @@ -126,11 +129,11 @@ xla_test( "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/client/lib:constants", "//xla/ffi", "//xla/ffi:ffi_api", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/builder/lib:constants", "//xla/hlo/ir:hlo", "//xla/service:custom_call_target_registry", "//xla/service:executable", @@ -142,8 +145,9 @@ xla_test( "//xla/service/gpu/runtime:sequential_thunk", "//xla/service/gpu/runtime:thunk", "//xla/service/gpu/transforms:dynamic_slice_fusion_rewriter", - "//xla/stream_executor", "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:stream", "//xla/stream_executor/gpu:gpu_types_header", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", @@ -343,6 +347,7 @@ cc_library( "//xla:shape_util", "//xla:status_macros", "//xla/hlo/ir:hlo", + "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:ir_emission_utils", @@ -353,6 +358,7 @@ cc_library( "//xla/service/gpu:matmul_utils", "//xla/service/gpu:triton_fusion_analysis", "//xla/service/gpu/fusions/triton:triton_fusion_emitter", + "//xla/service/gpu/fusions/triton:triton_fusion_emitter_legacy_matmul", "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", "//xla/service/gpu/runtime:kernel_thunk", "//xla/service/llvm_ir:ir_array", @@ -448,6 +454,7 @@ xla_test( "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -518,6 +525,7 @@ cc_library( ":reduction_base", "//xla:shape_util", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:hlo_traversal", @@ -530,10 +538,13 @@ cc_library( "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", "//xla/service/gpu/fusions/mlir:type_util", "//xla/service/gpu/model:indexing_analysis", + "//xla/stream_executor:launch_dim", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", @@ -565,6 +576,7 @@ cc_library( srcs = ["concatenate_mlir.cc"], hdrs = ["concatenate_mlir.h"], deps = [ + "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:hlo_fusion_analysis", @@ -573,6 +585,7 @@ cc_library( "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", "//xla/service/gpu/model:indexing_analysis", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc index d4bbe647f152d2..e977408fde2a8f 100644 --- a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/log/log.h" @@ -42,6 +43,7 @@ limitations under the License. #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/shape.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.h b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.h index b98db45690389c..c6223d8fc38a4f 100644 --- a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.h @@ -31,6 +31,7 @@ limitations under the License. #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/shape.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc b/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc index 9b513ba16871fa..fb206ef9dd5506 100644 --- a/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "xla/comparison_util.h" #include "xla/debug_options_flags.h" @@ -46,6 +47,7 @@ limitations under the License. #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tsl/platform/env.h" +#include "tsl/platform/errors.h" #include "tsl/platform/path.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -56,7 +58,7 @@ namespace { class CuDnnFusionTest : public GpuCodegenTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); // Let this group of tests just use first available plan skipping // autotuning. @@ -95,7 +97,7 @@ class CuDnnFusionFileCheckTest : public CuDnnFusionTest { } } - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions options = CuDnnFusionTest::GetDebugOptionsForTest(); options.set_xla_dump_to(output_directory_); return options; @@ -554,7 +556,7 @@ ENTRY e { class CuDnnFusionCommandBufferTest : public CuDnnFusionTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = CuDnnFusionTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_graph_min_graph_size(1); return debug_options; @@ -609,7 +611,7 @@ ENTRY e { class CuDnnFusionLevel2Test : public CuDnnFusionExecutionTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = CuDnnFusionExecutionTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_cudnn_gemm_fusion_level(2); @@ -834,7 +836,7 @@ ENTRY e { class CuDnnFusionLevel3Test : public CuDnnFusionExecutionTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = CuDnnFusionExecutionTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_cudnn_gemm_fusion_level(3); @@ -947,7 +949,7 @@ INSTANTIATE_TEST_SUITE_P( HloOpcode::kNegate, HloOpcode::kRsqrt, HloOpcode::kSin, HloOpcode::kSqrt, HloOpcode::kTan, HloOpcode::kTanh}), - ::testing::Values(5e-4)), + ::testing::Values(1e-3)), ElementwiseTestParamsToString); using BinaryElementwiseTest = ElementwiseTest; @@ -1091,7 +1093,7 @@ INSTANTIATE_TEST_SUITE_P(SelectTestSuite, SelectTest, class CuDnnFusionRewriteTest : public CuDnnFusionTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = CuDnnFusionTest::GetDebugOptionsForTest(); // Reset autotuning level to default. debug_options.set_xla_gpu_autotune_level( diff --git a/third_party/xla/xla/service/gpu/fusions/custom.cc b/third_party/xla/xla/service/gpu/fusions/custom.cc index c89cdaba4f89db..949cf62d1ffa52 100644 --- a/third_party/xla/xla/service/gpu/fusions/custom.cc +++ b/third_party/xla/xla/service/gpu/fusions/custom.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -27,6 +28,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "mlir/AsmParser/AsmParser.h" @@ -40,6 +42,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal.h" #include "xla/service/buffer_assignment.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" @@ -61,6 +64,7 @@ limitations under the License. #include "xla/service/gpu/runtime/kernel_thunk.h" #include "xla/service/gpu/runtime/nccl_all_reduce_thunk.h" #include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/gpu/runtime/nccl_collective_thunk.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/service/gpu/stream_executor_util.h" @@ -176,51 +180,6 @@ absl::StatusOr GetOperandSlice( return absl::InternalError("WTF"); } -// Returns true if `offset` is a loop iteration number. This pattern matching -// detects HLOs that generated by `jax.lax.scan` and will miss slightly -// different patterns that still compute slice offset as loop iteration number. -static bool IsLoopIterationOffset(const HloInstruction* offset) { - const HloComputation* parent = offset->parent(); - if (!parent->IsWhileBodyComputation()) return false; - - // Scan loops trip count must be known at compile time as it iterates over the - // leading dimension of the statically shaped input. - const HloInstruction* while_instr = parent->WhileCallInstruction(); - auto config = while_instr->backend_config(); - if (!config.ok() || !config->has_known_trip_count()) return false; - int32_t trip_count = config->known_trip_count().n(); - - // Check that offset is defined by a loop fusion that computes offset - // from the loop iteration number. - if (!offset->IsLoopFusion() || - !Match( - offset->fused_expression_root(), - m::Select(m::Compare(m::Parameter(0), m::ConstantScalar(0)), - m::Add(m::Parameter(0), m::ConstantScalar(trip_count)), - m::Parameter(0)))) { - return false; - } - - // Check that we get loop iteration directly from loop parameters bundle. - HloInstruction* get_loop_iteration; - if (!Match(const_cast(offset->operand(0)), - m::GetTupleElement(&get_loop_iteration, m::Parameter(0)))) { - return false; - } - int32_t loop_iter_idx = get_loop_iteration->tuple_index(); - - // Check that loop iteration counter updated with a +1 fusion. - const HloInstruction* loop_inc = - parent->root_instruction()->operand(loop_iter_idx); - if (!loop_inc->IsLoopFusion() || - !Match(loop_inc->fused_expression_root(), - m::Add(m::Parameter(0), m::ConstantScalar(1)))) { - return false; - } - - return true; -} - // Returns the constant literal, if the offset is from an offset array. Returns // `std::nullopt` otherwise. std::optional GetOffsetArray(const HloInstruction* inst) { @@ -276,10 +235,6 @@ absl::Status CollectSliceInfo( "Unsupported constant offset shape: ", cst->shape().ToString())); } - } else if (IsLoopIterationOffset(offset_value)) { - // Loop offset defined by a loop iteration number. - arg_offsets.emplace_back() = DynamicSliceThunk::LoopIter(); - } else { // Loop offset computed on device and has to be transferred to host. TF_ASSIGN_OR_RETURN(arg_offsets.emplace_back(), @@ -298,26 +253,86 @@ absl::Status CollectSliceInfo( return absl::OkStatus(); } +// This function assumes that the computation graph for `fusion_instr` looks +// like: +// +// ... +// root_tuple_operand = (... ty[shape], ...) ... +// ROOT root_tuple = (... (... ty[shape], ...), ...) +// tuple(... root_tuple_operand, ...) +// +// Given such a pattern and a (complete) index into `root_tuple_operand`, we +// recover the slice of `root_tuple` that corresponds to that index. +absl::StatusOr GetResultSliceForPartiallyUnnestedTuple( + const BufferAssignment& buffer_assignment, + const HloFusionInstruction& fusion_instr, + const HloInstruction& root_tuple_operand, + const ShapeIndex& root_tuple_operand_shape_idx, + const HloInstruction& root_tuple) { + int64_t operand_index = root_tuple.operand_index(&root_tuple_operand); + ShapeIndex slice_shape_index; + slice_shape_index.push_back(operand_index); + absl::c_copy(root_tuple_operand_shape_idx, + std::back_inserter(slice_shape_index)); + return GetAllocationSlice(buffer_assignment, &fusion_instr, + slice_shape_index); +} + absl::StatusOr GetResultSlice( const BufferAssignment& buffer_assignment, const HloFusionAdaptor& adaptor, - const HloInstruction& fusion_instr, const HloInstruction& start_instr, + const HloFusionInstruction& fusion_instr, const HloInstruction& start_instr, std::vector& slice_instrs, const ShapeIndex& shape_idx, unsigned arg_idx) { auto* start = const_cast(&start_instr); + if (start->IsRoot()) { + return GetAllocationSlice(buffer_assignment, &fusion_instr, shape_idx); + } + // Walk through ShapeIndex to find the real "user" (i.e. not get-tuple-element // user). Otherwise one sliced element will mark all buffers of all other // elements "sliced" too. if (start->shape().IsTuple()) { - for (auto idx : shape_idx) { - std::vector gte_users( - start->shape().tuple_shapes_size(), nullptr); - for (auto* user : start->users()) - if (auto* gte = DynCast(user)) - gte_users[gte->tuple_index()] = gte; - - start = static_cast(gte_users[idx]); - if (start == nullptr) - return GetAllocationSlice(buffer_assignment, &fusion_instr, shape_idx); + for (auto [index_nesting_level, index_in_shape] : + llvm::enumerate(shape_idx)) { + HloInstruction* gte_user = nullptr; + for (auto* user : start->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement && + user->tuple_index() == index_in_shape) { + gte_user = user; + break; + } + } + + if (gte_user == nullptr) { + // At this point, two things are known: + // 1. `start` was not the root instruction of the fusion at the + // beginning of this function call; + // 2. `start` still has a tuple shape because we haven't managed to + // unwrap the entire shape index. + // We also know, by definition of the surrounding pass, that all the + // results of the custom call must be materialized at the output of + // the fusion, which indicates that `start` is currently *not* the + // root. Since we can't slice/bitcast/reshape a tuple, then the + // only possible consumer should be a `tuple` instruction, which + // logically should be the root of the fusion. + HloInstruction* start_user = start->users().front(); + if (start->user_count() != 1 || + start_user->opcode() != HloOpcode::kTuple || + !start_user->IsRoot()) { + return absl::InternalError( + "Expected the user of a nested tuple shape to be a root tuple " + "instruction." + "Expected a single user of the tuple-shaped instruction"); + } + + ShapeIndex remaining_shape_index( + shape_idx.begin() + index_nesting_level, shape_idx.end()); + return GetResultSliceForPartiallyUnnestedTuple( + buffer_assignment, fusion_instr, *start, remaining_shape_index, + *start_user); + } + + start = gte_user; } } @@ -343,7 +358,56 @@ absl::StatusOr GetResultSlice( } } - return GetAllocationSlice(buffer_assignment, &fusion_instr, shape_idx); + constexpr absl::string_view kNonContiguousDynamicUpdateSliceError = + "DynamicSliceFusion only handles contiguous slices currently"; + + // At this point, we've fully unfolded a tuple that was not the root of the + // computation. There are two options; either, the root is a tuple, or it is + // not. + // + // If the root is not a tuple, we can simply get the buffer slice assigned to + // the fusion itself---there is nothing else to choose from. + if (fusion_instr.shape().IsArray()) { + HloInstruction* root = fusion_instr.fused_expression_root(); + if (root->opcode() == HloOpcode::kDynamicUpdateSlice && + !IsContiguousSlice(*root)) { + return absl::InternalError(kNonContiguousDynamicUpdateSliceError); + } + return GetAllocationSlice(buffer_assignment, &fusion_instr, {}); + } + + // If the root is a tuple however, it may be a nested tuple. Go all the way + // to the root to figure out the index that our array occupies within that + // tuple. + HloInstruction* current_hlo = start; + std::vector reversed_shape_index; + do { + TF_RET_CHECK(current_hlo->user_count() == 1); + HloInstruction* user = current_hlo->users().front(); + // We may encounter three ops here: dynamic-update-slice, tuple, or bitcast. + switch (user->opcode()) { + case HloOpcode::kBitcast: + break; + case HloOpcode::kDynamicUpdateSlice: + if (!IsContiguousSlice(*user)) { + return absl::InternalError(kNonContiguousDynamicUpdateSliceError); + } + break; + case HloOpcode::kTuple: + reversed_shape_index.push_back(user->operand_index(current_hlo)); + break; + default: + return absl::InternalError( + absl::StrCat("Unexpected opcode while processing the epilogue of a " + "DynamicSliceFusion: ", + HloOpcodeString(user->opcode()))); + }; + current_hlo = user; + } while (!current_hlo->IsRoot()); + + return GetAllocationSlice( + buffer_assignment, &fusion_instr, + ShapeIndex(reversed_shape_index.rbegin(), reversed_shape_index.rend())); } absl::StatusOr EmitGemm( @@ -681,15 +745,15 @@ absl::StatusOr EmitCustomCall( auto ffi_thunk = [&](Slices ops, Slices res) { auto& called_computations = custom_call.called_computations(); return CustomCallThunk::Create( - thunk_info, registration->bundle, std::move(ops), std::move(res), - std::move(attributes), + thunk_info, call_target_name, registration->bundle, std::move(ops), + std::move(res), std::move(attributes), called_computations.empty() ? nullptr : called_computations[0]); }; auto legacy_thunk = [&](Slices ops, Slices res) { - return CustomCallThunk::Create(thunk_info, std::move(custom_call_target), - std::move(ops), std::move(res), - std::move(opaque)); + return CustomCallThunk::Create( + thunk_info, call_target_name, std::move(custom_call_target), + std::move(ops), std::move(res), std::move(opaque)); }; std::vector> fake_allocations(num_args); diff --git a/third_party/xla/xla/service/gpu/fusions/custom.h b/third_party/xla/xla/service/gpu/fusions/custom.h index c5e758da715a8c..bd0ccc71688426 100644 --- a/third_party/xla/xla/service/gpu/fusions/custom.h +++ b/third_party/xla/xla/service/gpu/fusions/custom.h @@ -50,6 +50,16 @@ class CustomFusion : public FusionInterface { // compile-time instead of allocating a new buffer for it at runtime by // translating the static slice into offset + size of the original buffer passed // into the custom call `%gemm`. +// +// It is possible to inscribe the results of the custom call within a larger +// array. In that case, the affected results are each fed into a +// `dynamic-update-slice` operation, whose result is one of the fusion's +// outputs. +// +// The pass makes the assumption that, for each one of the custom-call's outputs +// there is exactly one path to the fusion root. The resulting shape for the +// dynamic slice fusion may be an unwrapped array, a flat tuple, or even a +// nested tuple. class DynamicSliceFusion : public FusionInterface { public: explicit DynamicSliceFusion(const HloFusionAnalysis& analysis) diff --git a/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc b/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc index 212e6b51e5445d..3d49c08c05b3ea 100644 --- a/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc @@ -21,11 +21,11 @@ limitations under the License. #include #include "absl/status/status.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" #include "xla/error_spec.h" #include "xla/ffi/ffi.h" #include "xla/ffi/ffi_api.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/custom_call_target_registry.h" @@ -282,6 +282,94 @@ TEST_F(DynamicSliceFusionTest, CublasGemmWithWorkspace) { /*run_hlo_passes=*/false)); } +TEST_F(DynamicSliceFusionTest, NestedTupleOutputForCublasGemmWithWorkspace) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_ref = R"( + HloModule nested_tuple + + ENTRY main { + p0 = f16[2,8,8]{2,1,0} parameter(0) + p1 = f16[2,8,8]{2,1,0} parameter(1) + slice_1 = f16[1,8,8]{2,1,0} slice(p0), slice={[1:2], [0:8], [0:8]} + bitcast_1 = f16[8,8]{1,0} bitcast(slice_1) + slice_2 = f16[1,8,8]{2,1,0} slice(p1), slice={[1:2], [0:8], [0:8]} + bitcast_2 = f16[8,8]{1,0} bitcast(slice_2) + + custom-call = (f16[8,8]{1,0}, s8[256]{0}) custom-call(bitcast_1, bitcast_2), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + result = f16[8,8]{1,0} get-tuple-element(custom-call), index=0 + workspace = s8[256]{0} get-tuple-element(custom-call), index=1 + nested_tuple = (s8[256]{0}) tuple(workspace) + ROOT tuple = (f16[8,8]{1,0}, (s8[256]{0})) tuple(result, nested_tuple) + })"; + + const char* hlo_opt = R"( + HloModule jit_slice + + fused_computation { + p0 = f16[2,8,8]{2,1,0} parameter(0) + p1 = f16[2,8,8]{2,1,0} parameter(1) + slice_1 = f16[1,8,8]{2,1,0} slice(p0), slice={[1:2], [0:8], [0:8]} + bitcast_1 = f16[8,8]{1,0} bitcast(slice_1) + slice_2 = f16[1,8,8]{2,1,0} slice(p1), slice={[1:2], [0:8], [0:8]} + bitcast_2 = f16[8,8]{1,0} bitcast(slice_2) + + custom-call = (f16[8,8]{1,0}, s8[256]{0}) custom-call(bitcast_1, bitcast_2), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + result = f16[8,8]{1,0} get-tuple-element(custom-call), index=0 + workspace = s8[256]{0} get-tuple-element(custom-call), index=1 + nested_tuple = (s8[256]{0}) tuple(workspace) + ROOT tuple = (f16[8,8]{1,0}, (s8[256]{0})) tuple(result, nested_tuple) + } + + ENTRY main.9 { + p0 = f16[2,8,8]{2,1,0} parameter(0) + p1 = f16[2,8,8]{2,1,0} parameter(1) + ROOT fusion = (f16[8,8]{1,0}, (s8[256]{0})) fusion(p0, p1), kind=kCustom, calls=fused_computation, + backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} + })"; + + EXPECT_TRUE(RunAndCompareTwoModules( + hlo_ref, hlo_opt, GetModuleConfigWithDeterministicOps(), + GetModuleConfigWithDeterministicOps(), error_spec, + /*run_hlo_passes=*/false)); +} + TEST_F(DynamicSliceFusionTest, ContiguousSlice) { ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; @@ -2960,6 +3048,7 @@ TEST_F(DynamicSliceFusionTest, ReduceScatterDUSLoopIterationOffset) { HloModuleConfig ref_config; debugoptions.set_xla_gpu_enable_dynamic_slice_fusion(false); + debugoptions.set_xla_gpu_enable_pipelined_reduce_scatter(false); ref_config.set_debug_options(debugoptions); TF_ASSERT_OK_AND_ASSIGN(auto ref_module, ParseAndReturnVerifiedModule(hlo_ref, ref_config)); @@ -3409,93 +3498,6 @@ TEST_F(DynamicSliceFusionTest, OffsetArrayTestU64) { ".*"); } -TEST_F(DynamicSliceFusionTest, ReduceScatterDegenerateCollective) { - const char* hlo_ref = R"( - HloModule test, replica_count=2 - - add.clone { - x.1 = f16[] parameter(0) - y.1 = f16[] parameter(1) - ROOT add.462 = f16[] add(x.1, y.1) - } - - ENTRY %main.9 { - param_0 = f16[128,128]{1,0} parameter(0) - param_1 = f16[128,128]{1,0} parameter(1) - constant_20 = u32[] constant(20) - constant_0 = u32[] constant(0) - dynamic-slice = f16[64,128]{1,0} dynamic-slice(param_0, constant_0, constant_0), dynamic_slice_sizes={64,128} - reduce-scatter = f16[64,128]{1,0} reduce-scatter(dynamic-slice), channel_id=64, replica_groups={{0},{1}}, use_global_device_ids=true, dimensions={0}, to_apply=add.clone - ROOT dynamic-update-slice = f16[128,128]{1,0} dynamic-update-slice(param_1, reduce-scatter, constant_20, constant_0) - })"; - - const char* hlo_opt = R"( - HloModule test, replica_count=2 - - %add { - %param_0 = f16[] parameter(0) - %param_1 = f16[] parameter(1) - ROOT %add.1 = f16[] add(%param_0, %param_1) - } - - %address-computation { - %p0 = f16[128,128]{1,0} parameter(0) - %p1 = f16[128,128]{1,0} parameter(1) - %p2 = u32[] parameter(2) - %p3 = u32[] parameter(3) - %dynamic-slice = f16[64,128]{1,0} dynamic-slice(%p0, %p3, %p3), dynamic_slice_sizes={64,128} - %reduce-scatter.1 = f16[64,128]{1,0} reduce-scatter(%dynamic-slice), channel_id=64, replica_groups={{0},{1}}, use_global_device_ids=true, dimensions={0}, to_apply=%add - ROOT %loop_dynamic_update_slice_fusion.1 = f16[128,128]{1,0} dynamic-update-slice(%p1, %reduce-scatter.1, %p2, %p3) - } - - ENTRY %main.9 { - %param_0.1 = f16[128,128]{1,0} parameter(0) - %param_1.1 = f16[128,128]{1,0} parameter(1) - %constant_20 = u32[] constant(20) - %constant_0 = u32[] constant(0) - ROOT %address_computation = f16[128,128]{1,0} fusion(%param_0.1, %param_1.1, %constant_20, %constant_0), kind=kCustom, calls=%address-computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}},"force_earliest_schedule":false} - })"; - - // Dynamic slice fusion is turned on by default. So, we need to turn that off - // while parsing. - HloModuleConfig ref_config; - DebugOptions debug_options; - debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); - ref_config.set_debug_options(debug_options); - - TF_ASSERT_OK_AND_ASSIGN(auto ref_module, - ParseAndReturnVerifiedModule(hlo_ref, ref_config)); - TF_ASSERT_OK_AND_ASSIGN(auto ref_module_opt, - GetOptimizedModule(std::move(ref_module))); - - TF_ASSERT_OK_AND_ASSIGN(auto module_with_fusion, - ParseAndReturnVerifiedModule(hlo_opt, ref_config)); - TF_ASSERT_OK_AND_ASSIGN(auto module_with_fusion_opt, - GetOptimizedModule(std::move(module_with_fusion))); - - // Check that the thunk is a d2d copy thunk because the collective is - // degenerate. - auto module_with_fusion_opt_clone = module_with_fusion_opt->Clone(); - TF_ASSERT_OK_AND_ASSIGN( - auto exec, - CreateExecutable(std::move(module_with_fusion_opt_clone), false)); - GpuExecutable* gpu_exec = dynamic_cast(exec.get()); - auto& child_thunk = gpu_exec->GetThunk().thunks()[1]; - ASSERT_EQ(child_thunk->kind(), Thunk::kDynamicSlice); - auto* ds_thunk = dynamic_cast(child_thunk.get()); - ASSERT_EQ(ds_thunk->embedded_thunk()->kind(), Thunk::kSequential); - auto* embedded_thunk = - dynamic_cast(ds_thunk->embedded_thunk()); - ASSERT_EQ(embedded_thunk->thunks().size(), 1ul); - ASSERT_EQ(embedded_thunk->thunks()[0]->kind(), Thunk::kCopy); - ErrorSpec error{/*aabs=*/1e-3, /*arel=*/1e-3}; - - // Comparing the outputs. - EXPECT_TRUE(RunAndCompareTwoModulesReplicated( - std::move(ref_module_opt), std::move(module_with_fusion_opt), - /*run_hlo_passes=*/false, /*use_threads=*/true, error)); -} - TEST_F(DynamicSliceFusionTest, ReduceScatterSlice) { const char* hlo_ref = R"( HloModule jit_slice, replica_count=2 @@ -3599,57 +3601,6 @@ TEST_F(DynamicSliceFusionTest, ReduceScatterDynamicSlice) { false, true, error)); } -TEST_F(DynamicSliceFusionTest, ReduceScatterDegenerateSlice) { - const char* hlo_ref = R"( - HloModule test_module, replica_count=2 - - add { - a = s32[] parameter(0) - b = s32[] parameter(1) - ROOT add = s32[] add(a, b) - } - - ENTRY main { - p0 = s32[2,4,8] parameter(0) - slice = s32[1,4,8] slice(p0), slice={[1:2], [0:4], [0:8]} - bc = s32[4,8] reshape(slice) - ROOT rs = s32[4,8] reduce-scatter(bc), channel_id=64, replica_groups={{0},{1}}, use_global_device_ids=true, dimensions={0}, to_apply=add - } - )"; - HloModuleConfig config; - DebugOptions options; - options.set_xla_gpu_enable_dynamic_slice_fusion(false); - options.clear_xla_gpu_enable_command_buffer(); - config.set_debug_options(options); - TF_ASSERT_OK_AND_ASSIGN(auto module_ref, - ParseAndReturnVerifiedModule(hlo_ref, config)); - - options.set_xla_gpu_enable_dynamic_slice_fusion(true); - options.clear_xla_gpu_enable_command_buffer(); - config.set_debug_options(options); - TF_ASSERT_OK_AND_ASSIGN(auto module_new, - ParseAndReturnVerifiedModule(hlo_ref, config)); - - TF_ASSERT_OK_AND_ASSIGN(auto module_ref_opt, - GetOptimizedModule(std::move(module_ref))); - TF_ASSERT_OK_AND_ASSIGN(auto module_new_opt, - GetOptimizedModule(std::move(module_new))); - - ASSERT_TRUE(GetDynamicSliceFusions(*module_ref_opt).empty()); - ASSERT_FALSE(GetDynamicSliceFusions(*module_new_opt).empty()); - - auto module_new_opt_clone = module_new_opt->Clone(); - TF_ASSERT_OK_AND_ASSIGN( - auto exec, CreateExecutable(std::move(module_new_opt_clone), false)); - GpuExecutable* gpu_exec = dynamic_cast(exec.get()); - ASSERT_EQ(gpu_exec->GetThunk().thunks()[0]->kind(), Thunk::kCopy); - - ErrorSpec error{/*aabs=*/1e-3, /*arel=*/1e-3}; - EXPECT_TRUE(RunAndCompareTwoModulesReplicated(std::move(module_ref_opt), - std::move(module_new_opt), - false, true, error)); -} - } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc index 432d600701d1ab..17388caa36d3db 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc @@ -168,21 +168,20 @@ IndexingMap KernelFusionInterface::GetDefaultThreadIdIndexingMap( divisor *= shape.dimensions(dimension); } - std::vector dim_vars = { - {{0, static_cast(launch_dims.thread_counts_per_block().x) - 1}}, - {{0, static_cast(launch_dims.thread_counts_per_block().y) - 1}}, - {{0, static_cast(launch_dims.thread_counts_per_block().z) - 1}}, - {{0, static_cast(launch_dims.block_counts().x) - 1}}, - {{0, static_cast(launch_dims.block_counts().y) - 1}}, - {{0, static_cast(launch_dims.block_counts().z) - 1}}, - }; - std::vector range_vars; + std::vector dim_vars = DimVarsFromGPUGrid( + {static_cast(launch_dims.thread_counts_per_block().x), + static_cast(launch_dims.thread_counts_per_block().y), + static_cast(launch_dims.thread_counts_per_block().z), + static_cast(launch_dims.block_counts().x), + static_cast(launch_dims.block_counts().y), + static_cast(launch_dims.block_counts().z)}); + std::vector range_vars; int64_t num_elements = ShapeUtil::ElementsIn(shape); - range_vars.push_back( - {{0, CeilOfRatio(num_elements, - static_cast(launch_dims.launch_bound()) * - unroll_factor) - - 1}}); + range_vars.push_back(IndexingMap::Variable{ + {0, CeilOfRatio(num_elements, + static_cast(launch_dims.launch_bound()) * + unroll_factor) - + 1}}); range_vars.push_back({0, unroll_factor - 1}); IndexingMap indexing_map( mlir::AffineMap::get(/*dimCount=*/6, diff --git a/third_party/xla/xla/service/gpu/fusions/fusions.cc b/third_party/xla/xla/service/gpu/fusions/fusions.cc index 200f06f8461db5..ff15b29724c79c 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusions.cc +++ b/third_party/xla/xla/service/gpu/fusions/fusions.cc @@ -20,8 +20,6 @@ limitations under the License. #include #include "absl/algorithm/container.h" -#include "absl/log/check.h" -#include "absl/log/log.h" #include "absl/strings/match.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -42,7 +40,6 @@ limitations under the License. #include "xla/service/gpu/fusions/legacy/scatter.h" #include "xla/service/gpu/fusions/legacy/transpose.h" #include "xla/service/gpu/fusions/loop_mlir.h" -#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" #include "xla/service/gpu/fusions/reduction_mlir.h" #include "xla/service/gpu/fusions/scatter_mlir.h" #include "xla/service/gpu/fusions/transpose_mlir.h" diff --git a/third_party/xla/xla/service/gpu/fusions/fusions.h b/third_party/xla/xla/service/gpu/fusions/fusions.h index f7406b463b9117..6e9f16aa1deca8 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusions.h +++ b/third_party/xla/xla/service/gpu/fusions/fusions.h @@ -23,6 +23,7 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/ir_emission_utils.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc index 823b3c8f765ad6..db4605be601a29 100644 --- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc @@ -69,7 +69,8 @@ LaunchDimensions MlirInPlaceDynamicUpdateSliceFusion::launch_dimensions() const { const auto& update_shape = dus_ops_.front().GetOperand(kDUSUpdateIndex).shape(); - return CalculateLaunchDimensions(update_shape, analysis_.device_info()); + return CalculateLaunchDimensions(update_shape, analysis_.device_info(), + config_); } std::optional @@ -84,7 +85,7 @@ MlirInPlaceDynamicUpdateSliceFusion::ComputeThreadIdToInputIndexing( // It is guaranteed that all DUS ops have the same output shape at this point. const auto& update_shape = dus_ops_.front().GetOperand(kDUSUpdateIndex).shape(); - return GetDefaultThreadIdIndexingMap(launch_dims, /*unroll_factor=*/1, + return GetDefaultThreadIdIndexingMap(launch_dims, config_.unroll_factor, update_shape, indexing_context); } diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h index 2ed84a06522b16..ab1a71f1c6c8a4 100644 --- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h @@ -26,7 +26,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" +#include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/indexing_map.h" @@ -47,8 +49,9 @@ class MlirInPlaceDynamicUpdateSliceFusion : public MlirFusionEmitterBase { explicit MlirInPlaceDynamicUpdateSliceFusion( const HloFusionAnalysis& analysis) : analysis_(analysis), - dus_ops_( - GetOutputDefiningDynamicUpdateSlices(analysis.fusion_roots())) {} + dus_ops_(GetOutputDefiningDynamicUpdateSlices(analysis.fusion_roots())), + config_(ComputeLoopFusionConfig( + analysis, dus_ops_[0].instruction().operand(1)->shape())) {} LaunchDimensions launch_dimensions() const override; @@ -77,6 +80,7 @@ class MlirInPlaceDynamicUpdateSliceFusion : public MlirFusionEmitterBase { private: const HloFusionAnalysis& analysis_; std::vector dus_ops_; + LaunchDimensionsConfig config_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc index e595571c519830..ae6730de098ff6 100644 --- a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc @@ -78,8 +78,7 @@ MlirInputSlicesFusion::GetEpilogues(const HloFusionInstruction& fusion, // We don't actually use epilogues here, but this is how we tell the base // class not to emit code for the slices. - return {mlir_converter::EpilogueSpecification::FromOutputIndexing( - analysis_, roots, roots, *this, mlir_context)}; + return {GetEpilogueForOutputIndexing(analysis_, roots, roots, mlir_context)}; } LaunchDimensions MlirInputSlicesFusion::launch_dimensions() const { diff --git a/third_party/xla/xla/service/gpu/fusions/ir/BUILD b/third_party/xla/xla/service/gpu/fusions/ir/BUILD index 16fb100182636e..d7e85cc8c89cc0 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/ir/BUILD @@ -68,6 +68,14 @@ gentbl_cc_library( name = "xla_gpu_attrs_inc_gen", strip_include_prefix = ".", tbl_outs = [ + ( + ["-gen-enum-decls"], + "xla_gpu_enums.h.inc", + ), + ( + ["-gen-enum-defs"], + "xla_gpu_enums.cc.inc", + ), ( [ "-gen-attrdef-decls", @@ -127,7 +135,6 @@ cc_library( ":xla_gpu_ops_inc_gen", ":xla_gpu_types_inc_gen", "//xla/service/gpu/model:indexing_analysis", - "@com_google_absl//absl/strings:str_format", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:BytecodeOpInterface", @@ -138,7 +145,6 @@ cc_library( "@llvm-project//mlir:InliningUtils", "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", - "@stablehlo//:stablehlo_type_inference", ], ) diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/indexing_map_attr.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/attrs.mlir similarity index 58% rename from third_party/xla/xla/service/gpu/fusions/ir/tests/indexing_map_attr.mlir rename to third_party/xla/xla/service/gpu/fusions/ir/tests/attrs.mlir index 76a74dd7908eca..6a199f5f024241 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/tests/indexing_map_attr.mlir +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/attrs.mlir @@ -8,18 +8,16 @@ // CHECK-SAME: d2 in [10, 12], // CHECK-SAME: s0 in [0, 32], // CHECK-SAME: d0 + s0 in [1, 10], -// CHECK-SAME: d0 mod 2 in [0, 1], -// CHECK-SAME: is_simplified: true +// CHECK-SAME: d0 mod 2 in [0, 1] // CHECK-SAME: > -#map = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0), - domain: - d0 in [1, 2], - d1 in [5, 8], - d2 in [10, 12], - s0 in [0, 32], - d0 mod 2 in [0, 1], - d0 + s0 in [1, 10], - is_simplified: true +#map = #xla_gpu.indexing_map<"(d0, d1, d2)[s0] -> (d0)," + "domain:" + "d0 in [1, 2]," + "d1 in [5, 8]," + "d2 in [10, 12]," + "s0 in [0, 32]," + "d0 mod 2 in [0, 1]," + "d0 + s0 in [1, 10]" > func.func private @indexing_map_attr(!xla_gpu.indexed_vector<64x64x32xf64, #map>) @@ -39,20 +37,19 @@ func.func private @indexing_map_attr(!xla_gpu.indexed_vector<64x64x32xf64, #map> // CHECK-SAME: d0 + s0 in [1, 10] // CHECK-SAME: d0 mod 2 in [0, 1] // CHECK-SAME: d1 + s1 + s2 in [1, 32] -// CHECK-SAME: is_simplified: false // CHECK-SAME: > -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2), - domain: - d0 in [1, 2], - d1 in [5, 8], - s0 in [0, 10], - s1 in [0, 5], - s2 in [0, 32], - d0 mod 2 in [0, 1], - d0 + s0 in [1, 10], - d1 + s1 + s2 in [1, 32], - is_simplified: false - > +#map = #xla_gpu.indexing_map< + "(d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2)," + "domain:" + "d0 in [1, 2]," + "d1 in [5, 8]," + "s0 in [0, 10]," + "s1 in [0, 5]," + "s2 in [0, 32]," + "d0 mod 2 in [0, 1]," + "d0 + s0 in [1, 10]," + "d1 + s1 + s2 in [1, 32]" + > func.func private @more_range_vars(!xla_gpu.indexed_vector<100x32xf64, #map>) // CHECK-LABEL: @more_range_vars // CHECK: !xla_gpu.indexed_vector<100x32xf64, #[[$INDEX_MAP]]> @@ -64,13 +61,11 @@ func.func private @more_range_vars(!xla_gpu.indexed_vector<100x32xf64, #map>) // CHECK-SAME: domain: // CHECK-SAME: d0 in [0, 100] // CHECK-SAME: s0 in [-3, -1] -// CHECK-SAME: is_simplified: false // CHECK-SAME: > -#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0), - domain: - d0 in [0, 100], - s0 in [-3, -1], - is_simplified: false +#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0)," + "domain:" + "d0 in [0, 100]," + "s0 in [-3, -1]" > func.func private @indexing_map_small(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-LABEL: @indexing_map_small @@ -85,15 +80,13 @@ func.func private @indexing_map_small(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-SAME: d1 in [5, 8] // CHECK-SAME: d2 in [10, 12] // CHECK-SAME: s0 in [0, 32] -// CHECK-SAME: is_simplified: false // CHECK-SAME: > -#map = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0), - domain: - d0 in [1, 2], - d1 in [5, 8], - d2 in [10, 12], - s0 in [0, 32], - is_simplified: false +#map = #xla_gpu.indexing_map<"(d0, d1, d2)[s0] -> (d0)," + "domain:" + "d0 in [1, 2]," + "d1 in [5, 8]," + "d2 in [10, 12]," + "s0 in [0, 32]" > func.func private @no_constraints(!xla_gpu.indexed_vector<32xf64, #map>) // CHECK-LABEL: @no_constraints @@ -106,13 +99,11 @@ func.func private @no_constraints(!xla_gpu.indexed_vector<32xf64, #map>) // CHECK-SAME: domain: // CHECK-SAME: s0 in [3, 5] // CHECK-SAME: s0 mod 2 in [0, 1] -// CHECK-SAME: is_simplified: false // CHECK-SAME: > -#map = #xla_gpu.indexing_map<()[s0] -> (s0), - domain: - s0 in [3, 5], - s0 mod 2 in [0, 1], - is_simplified: false +#map = #xla_gpu.indexing_map<"()[s0] -> (s0)," + "domain:" + "s0 in [3, 5]," + "s0 mod 2 in [0, 1]" > func.func private @no_dimensions(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-LABEL: @no_dimensions @@ -125,13 +116,11 @@ func.func private @no_dimensions(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-SAME: domain: // CHECK-SAME: d0 in [3, 5] // CHECK-SAME: d0 mod 2 in [0, 1] -// CHECK-SAME: is_simplified: false // CHECK-SAME: > -#map = #xla_gpu.indexing_map<(d0) -> (d0), - domain: - d0 in [3, 5], - d0 mod 2 in [0, 1], - is_simplified: false +#map = #xla_gpu.indexing_map<"(d0) -> (d0)," + "domain:" + "d0 in [3, 5]," + "d0 mod 2 in [0, 1]," > func.func private @no_symbols(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-LABEL: @no_symbols @@ -142,7 +131,15 @@ func.func private @no_symbols(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< // CHECK-SAME: () -> () // CHECK-SAME: > -#map = #xla_gpu.indexing_map<() -> ()> +#map = #xla_gpu.indexing_map<"() -> ()"> func.func private @empty(!xla_gpu.indexed_vector<100xf64, #map>) // CHECK-LABEL: @empty // CHECK: !xla_gpu.indexed_vector<100xf64, #[[$INDEX_MAP]]> + +// ----- + +func.func private @tensor_layout( + %in0: tensor<42xf32, #xla_gpu.layout<"shmem", + "(d0) -> ()," "domain: d0 in [0, 42]">>) +// CHECK: #layout = #xla_gpu.layout<"shmem", "(d0) -> (), domain: +// CHECK: tensor<42xf32, #layout> diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir index 495456a5ab36d4..08086e34f60b05 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir @@ -1,15 +1,12 @@ // RUN: mlir_fusions_opt %s --split-input-file -canonicalize | FileCheck %s -#map0 = #xla_gpu.indexing_map<()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2), - domain: s0 in [-10, 10], s1 in [0, 2], - is_simplified: false> +#map0 = #xla_gpu.indexing_map<"()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2), domain: s0 in [-10, 10], s1 in [0, 2]"> func.func @simplify_apply_indexing(%s0: index, %s1: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing #map0 [%s0, %s1] func.return %0#0, %0#1 : index, index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 + 1, d0 mod 2), -// CHECK-SAME: domain: d0 in [-10, 10] -// CHECK-SAME: is_simplified: true> +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 + 1, d0 mod 2), +// CHECK-SAME: domain: d0 in [-10, 10]"> // CHECK-LABEL: func.func @simplify_apply_indexing // CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) @@ -17,14 +14,13 @@ func.func @simplify_apply_indexing(%s0: index, %s1: index) -> (index, index) { // ----- -#map0 = #xla_gpu.indexing_map<(d0, d1, d2)[s0, s1] -> (1 + s0 + s1 mod 4 - s1, s0 mod 2, d0 + d2), - domain: d0 in [0, 1], d1 in [0, 2], d2 in [0, 3], s0 in [-11, 11], s1 in [0, 3], is_simplified: false> +#map0 = #xla_gpu.indexing_map<"(d0, d1, d2)[s0, s1] -> (1 + s0 + s1 mod 4 - s1, s0 mod 2, d0 + d2), domain: d0 in [0, 1], d1 in [0, 2], d2 in [0, 3], s0 in [-11, 11], s1 in [0, 3]"> func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index, %d2: index, %s0: index, %s1: index) -> (index, index, index) { %0:3 = xla_gpu.apply_indexing #map0(%d0, %d1, %d2)[%s0, %s1] func.return %0#0, %0#1, %0#2 : index, index, index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1, d2) -> (d2 + 1, d2 mod 2, d0 + d1), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d2 + 1, d2 mod 2, d0 + d1), // CHECK-SAME: domain: d0 in [0, 1], d1 in [0, 3], d2 in [-11, 11] // CHECK-LABEL: func.func @simplify_apply_indexing_remove_dims @@ -38,23 +34,13 @@ func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index, // ----- -#map0 = #xla_gpu.indexing_map<(d0) -> (d0 mod 10), domain: d0 in [0, 9], is_simplified: true> -func.func @do_not_simplify_if_is_simplified_is_true(%d0: index) -> (index) { - %0 = xla_gpu.apply_indexing #map0(%d0) - func.return %0 : index -} -// CHECK: #xla_gpu.indexing_map<(d0) -> (d0 mod 10) - -// ----- - -#map0 = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 + s0, 4, d1, 1, s0), - domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1], is_simplified: false> +#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0 + s0, 4, d1, 1, s0), domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1]"> func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index) -> (index, index, index, index, index) { %0:5 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] func.return %0#0, %0#1, %0#2, %0#3, %0#4 : index, index, index, index, index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), // CHECK-LABEL: func.func @fold_indexing_map_results // CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) @@ -67,13 +53,13 @@ func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index) // ----- -#map0 = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 + s0, s0 + 4, d1 mod 2, 1 + d1, s0), - domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1], is_simplified: false> +#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0 + s0, s0 + 4, d1 mod 2, 1 + d1, s0)," + "domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1]"> func.func @remove_unused_results(%d0: index, %d1: index, %s0: index) -> (index) { %0:5 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] func.return %0#2 : index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 mod 2), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 mod 2), // CHECK-SAME: domain: d0 in [0, 2] // CHECK-LABEL: func.func @remove_unused_results @@ -84,8 +70,8 @@ func.func @remove_unused_results(%d0: index, %d1: index, %s0: index) -> (index) // ----- -#map0 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + d1 + s0 + s1 mod 3), - domain: d0 in [0, 10], d1 in [0, 5], s0 in [-10, 10], s1 in [0, 4], is_simplified: false> +#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + d1 + s0 + s1 mod 3)," + "domain: d0 in [0, 10], d1 in [0, 5], s0 in [-10, 10], s1 in [0, 4]"> func.func @fold_operands(%d0: index) -> index { %d1 = arith.constant 1 : index %s0 = arith.constant 2 : index @@ -93,7 +79,7 @@ func.func @fold_operands(%d0: index) -> index { %0 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0, %s1] func.return %0 : index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 + 3), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 + 3), // CHECK-SAME: domain: d0 in [0, 10] // CHECK-LABEL: func.func @fold_operands @@ -104,8 +90,8 @@ func.func @fold_operands(%d0: index) -> index { func.func @fold_operands_and_results(%arg0: index, %arg1: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (0, d1), - domain: d0 in [0, 4], d1 in [0, 5], is_simplified: false>(%arg0, %arg1) + %0:2 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (0, d1)," + "domain: d0 in [0, 4], d1 in [0, 5]">(%arg0, %arg1) return %0#0, %0#1 : index, index } @@ -117,14 +103,14 @@ func.func @fold_operands_and_results(%arg0: index, %arg1: index) // ----- func.func @fold_sequence(%arg0: index, %arg1: index) -> index { - %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), - domain: d0 in [0, 5], d1 in [0, 4], is_simplified: false>(%arg0, %arg1) - %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 mod 100 + 42), - domain: d0 in [0, 10000], is_simplified: false>(%0) + %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map< + "(d0, d1) -> (d0 + d1), domain: d0 in [0, 5], d1 in [0, 4]">(%arg0, %arg1) + %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 mod 100 + 42)," + "domain: d0 in [0, 10000]">(%0) func.return %1 : index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1 + 42), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1 + 42), // CHECK-SAME: domain: d0 in [0, 5], d1 in [0, 4] // CHECK-LABEL: func.func @fold_sequence // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) @@ -134,14 +120,14 @@ func.func @fold_sequence(%arg0: index, %arg1: index) -> index { // ----- func.func @fold_sequence_sym(%arg0: index, %arg1: index) -> index { - %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), - domain: d0 in [0, 5], d1 in [0, 4], is_simplified: false>(%arg0, %arg1) - %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<()[s0] -> (s0 mod 100 + 42), - domain: s0 in [0, 10000], is_simplified: false>(%0) + %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), " + "domain: d0 in [0, 5], d1 in [0, 4]">(%arg0, %arg1) + %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map< + "()[s0] -> (s0 mod 100 + 42), domain: s0 in [0, 10000]">(%0) func.return %1 : index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1 + 42), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1 + 42), // CHECK-SAME: domain: d0 in [0, 5], d1 in [0, 4] // CHECK-LABEL: func.func @fold_sequence_sym // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) @@ -150,12 +136,11 @@ func.func @fold_sequence_sym(%arg0: index, %arg1: index) -> index { // ----- -#indexing_map1 = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 2 + d0 + 8512), - domain: d0 in [0, 1], d1 in [0, 607], is_simplified: false> -#indexing_map2 = #xla_gpu.indexing_map< - (d0, d1, d2) -> (((d1 floordiv 32 + 1) mod 3) * 64 - + (d1 mod 32) * 2 + (d0 floordiv 192) * 192 + d2), - domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1], is_simplified: false> +#indexing_map1 = #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 2 + d0 + 8512)," + "domain: d0 in [0, 1], d1 in [0, 607]"> +#indexing_map2 = #xla_gpu.indexing_map<"(d0, d1, d2) -> (" + "((d1 floordiv 32 + 1) mod 3) * 64 + (d1 mod 32) * 2 + (d0 floordiv 192) * 192 + d2)," + "domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1]"> func.func @fold_sequence_no_simplification_needed(%i: index) -> index { %thread_id_x = gpu.thread_id x {xla.range = [0 : index, 607 : index]} @@ -168,12 +153,12 @@ func.func @fold_sequence_no_simplification_needed(%i: index) -> index { // ----- -#indexing_map1 = #xla_gpu.indexing_map<(d0) -> (3 * d0), - domain: d0 in [0, 9407], is_simplified: false> -#indexing_map2 = #xla_gpu.indexing_map<(d0, d1, d2) -> (d0 floordiv 32 + 1), - domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1], is_simplified: false> -#indexing_map3 = #xla_gpu.indexing_map<(d0, d1, d2) -> (d0 floordiv 32 + 2), - domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1], is_simplified: false> +#indexing_map1 = #xla_gpu.indexing_map< + "(d0) -> (3 * d0), domain: d0 in [0, 9407]"> +#indexing_map2 = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 floordiv 32 + 1)," + "domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1]"> +#indexing_map3 = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 floordiv 32 + 2)," + "domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1]"> func.func @no_fold_when_producer_has_two_users(%i: index) -> (index, index) { %thread_id_x = gpu.thread_id x {xla.range = [0 : index, 607 : index]} @@ -187,14 +172,14 @@ func.func @no_fold_when_producer_has_two_users(%i: index) -> (index, index) { // ----- func.func @fold_sequence_shared_operands(%arg0: index, %arg1: index) -> index { - %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), - domain: d0 in [0, 5], d1 in [0, 4], is_simplified: false>(%arg0, %arg1) - %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), - domain: d0 in [0, 4], d1 in [0, 10000], is_simplified: false>(%arg1, %0) + %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1)," + "domain: d0 in [0, 5], d1 in [0, 4]">(%arg0, %arg1) + %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1)," + "domain: d0 in [0, 4], d1 in [0, 10000]">(%arg1, %0) func.return %1 : index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 2 + d1), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1), // CHECK-SAME: domain: d0 in [0, 4], d1 in [0, 5] // CHECK-LABEL: func.func @fold_sequence_shared_operands // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) @@ -235,15 +220,15 @@ func.func @atomic_rmw_cst(%in: tensor<2x3xf32>, %i: index, %j: index) // ----- -#map0 = #xla_gpu.indexing_map<(d0)[s0] -> (2 * d0 * s0), - domain: d0 in [0, 3], s0 in [0, 2], is_simplified: false> +#map0 = #xla_gpu.indexing_map<"(d0)[s0] -> (2 * d0 * s0)," + "domain: d0 in [0, 3], s0 in [0, 2]"> func.func @apply_indexing_move_syms_to_dims(%dim0: index, %sym0: index) -> index { %0 = xla_gpu.apply_indexing #map0(%dim0)[%sym0] func.return %0 : index } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> ((d0 * d1) * 2), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> ((d0 * d1) * 2), // CHECK-SAME: domain: d0 in [0, 3], d1 in [0, 2] // CHECK-LABEL: func.func @apply_indexing_move_syms_to_dims // CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]] @@ -251,8 +236,9 @@ func.func @apply_indexing_move_syms_to_dims(%dim0: index, %sym0: index) // // ----- -#map0 = #xla_gpu.indexing_map<(d0) -> (4 * d0), domain: d0 in [0, 3], is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 12], s0 in [0, 1024], s1 in [0, 32], is_simplified: false> +#map0 = #xla_gpu.indexing_map<"(d0) -> (4 * d0), domain: d0 in [0, 3]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1)," + "domain: d0 in [0, 12], s0 in [0, 1024], s1 in [0, 32]"> func.func @loop_of_apply_indexing(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { %idx = xla_gpu.apply_indexing #map0(%dim) %sum = xla_gpu.loop (%idx)[%i, %j] -> (%r0, %r1) in #map1 iter_args(%sum_ = %init) -> (f32) { @@ -263,7 +249,7 @@ func.func @loop_of_apply_indexing(%input: tensor<1024x32xf32>, %init: f32, %dim: func.return %sum : f32 } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 * 4 + s0, s1), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 * 4 + s0, s1), // CHECK-SAME: domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32] // CHECK-LABEL: func.func @loop_of_apply_indexing // CHECK-SAME: %[[ARG0:.*]]: tensor<1024x32xf32>, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: index) @@ -272,8 +258,10 @@ func.func @loop_of_apply_indexing(%input: tensor<1024x32xf32>, %init: f32, %dim: // ----- -#map0 = #xla_gpu.indexing_map<(d0)[s0] -> (2 * d0 * s0), domain: d0 in [0, 3], s0 in [0, 2], is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0 + s1), domain: d0 in [0, 12], s0 in [0, 1024], s1 in [0, 32], is_simplified: false> +#map0 = #xla_gpu.indexing_map<"(d0)[s0] -> (2 * d0 * s0)," + "domain: d0 in [0, 3], s0 in [0, 2]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0 + s1)," + "domain: d0 in [0, 12], s0 in [0, 1024], s1 in [0, 32]"> func.func @loop_of_apply_indexing_with_syms(%dim0: index, %sym0: index, %input: tensor<1024x32xf32>, %init: f32) -> (f32) { %0 = xla_gpu.apply_indexing #map0(%dim0)[%sym0] %sum = xla_gpu.loop (%0)[%i, %j] -> (%r0) in #map1 iter_args(%sum_ = %init) -> (f32) { @@ -284,7 +272,7 @@ func.func @loop_of_apply_indexing_with_syms(%dim0: index, %sym0: index, %input: func.return %sum : f32 } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> ((d0 * d1) * 2 + s0 + s1), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> ((d0 * d1) * 2 + s0 + s1), // CHECK-SAME: domain: d0 in [0, 3], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32] // CHECK-LABEL: func.func @loop_of_apply_indexing_with_syms // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir index 922b3f3bbfff0e..35064858b23150 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir @@ -1,13 +1,6 @@ // RUN: mlir_fusions_opt %s -split-input-file -verify-diagnostics -#map0 = #xla_gpu.indexing_map< - (d0, d1)[s0] -> (d0, d1 + s0), - domain: - d0 in [1, 2], - d1 in [5, 8], - s0 in [0, 32], - is_simplified: false -> +#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0, d1 + s0), domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32]"> func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) { // expected-error @+1 {{operand count must match the number of dimensions and symbols in the affine map}} %0:2 = xla_gpu.apply_indexing #map0 (%d0) @@ -16,16 +9,7 @@ func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) // ----- -#map0 = #xla_gpu.indexing_map< - (d0, d1)[s0] -> (d0, d1 + s0), - domain: - d0 in [1, 2], - d1 in [5, 8], - s0 in [0, 32], - d0 mod 2 in [0, 1], - d0 + s0 in [1, 10], - is_simplified: false -> +#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0, d1 + s0), domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32], d0 mod 2 in [0, 1], d0 + s0 in [1, 10]"> func.func @cannot_have_constraints(%d0: index, %d1: index, %s0: index) -> (index, index) { // expected-error @+1 {{apply indexing op cannot have any constraints}} %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] @@ -34,7 +18,7 @@ func.func @cannot_have_constraints(%d0: index, %d1: index, %s0: index) -> (index // ----- -#map = #xla_gpu.indexing_map<()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32], is_simplified: false> +#map = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]"> func.func @loop_result_num_mismatch(%input: tensor<1024x32xf32>, %init: f32) -> (f32) { // expected-error @+1 {{mismatch in number of loop-carried values and results}} @@ -52,7 +36,7 @@ func.func @loop_result_num_mismatch(%input: tensor<1024x32xf32>, // ----- -#map = #xla_gpu.indexing_map<()[s0] -> (s0, s0), domain: s0 in [0, 1024], is_simplified: false> +#map = #xla_gpu.indexing_map<"()[s0] -> (s0, s0), domain: s0 in [0, 1024]"> func.func @loop_iv_num_mismatch(%input: tensor<1024x32xf32>, %init: f32) -> (f32) { // expected-error @+1 {{mismatch in number of induction variables 2 and RangeVars}} @@ -70,8 +54,7 @@ func.func @loop_iv_num_mismatch(%input: tensor<1024x32xf32>, // ----- -#map = #xla_gpu.indexing_map<()[s0, s1] -> (s0, s1), - domain: s0 in [0, 1024], s1 in [0, 32], is_simplified: false> +#map = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]"> func.func @loop_types_mismatch(%input: tensor<1024x32xf32>, %init: f32) -> (i32) { // expected-error @+1 {{block iter arg type = 'f32', result type = 'i32' and init operand type = 'f32' should match}} @@ -89,8 +72,7 @@ func.func @loop_types_mismatch(%input: tensor<1024x32xf32>, %init: f32) -> (i32) // ----- -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0, s1), - domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32]"> func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { // expected-error @+1 {{mismatch in number of dims operands 0 and DimVars in the indexing map}} @@ -105,9 +87,7 @@ func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32 // ----- -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]"> func.func @indicies_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map> { @@ -119,10 +99,8 @@ func.func @indicies_mismatch(%input: tensor<32x64xf32>, %thread_id: index, // ----- -#map = #xla_gpu.indexing_map<()[s0, s1] -> (s0, s1), - domain: s0 in [0, 1024], s1 in [0, 32], is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false> +#map = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> func.func @no_thread_id_in(%input: tensor<32x64xf32>, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { @@ -134,10 +112,8 @@ func.func @no_thread_id_in(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false> -#map1 = #xla_gpu.indexing_map<()[s0, s1] -> (s0, s1), - domain: s0 in [0, 1024], s1 in [0, 32], is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> +#map1 = #xla_gpu.indexing_map<"()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]"> func.func @no_thread_id_out(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { @@ -149,10 +125,8 @@ func.func @no_thread_id_out(%input: tensor<32x64xf32>, %thread_id: index, // ----- -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 64], s0 in [0, 1024], s1 in [0, 32], is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 64], s0 in [0, 1024], s1 in [0, 32]"> func.func @thread_id_bounds_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { // expected-error @+1 {{thread_id dimension must have the same bounds in both indexing maps}} %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> @@ -161,11 +135,8 @@ func.func @thread_id_bounds_mismatch(%input: tensor<32x64xf32>, %thread_id: inde // ----- -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], d0 + s0 in [0, 1024], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], d0 + s0 in [0, 1024]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> func.func @thread_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) @@ -178,10 +149,8 @@ func.func @thread_id_constraints_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0, s0), - domain: d0 in [0, 32], s0 in [0, 1024], is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 + s0, s0), domain: d0 in [0, 32], s0 in [0, 1024]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> func.func @symbol_count_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { // expected-error @+1 {{number of symbols in both indexing_maps must match}} %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> @@ -190,10 +159,8 @@ func.func @symbol_count_mismatch(%input: tensor<32x64xf32>, %thread_id: index, % // ----- -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]"> func.func @symbol_domain_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { // expected-error @+1 {{domain of symbols of indexing_maps must match}} %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> @@ -202,12 +169,8 @@ func.func @symbol_domain_mismatch(%input: tensor<32x64xf32>, %thread_id: index, // ----- -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32]"> func.func @symbol_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { @@ -219,12 +182,8 @@ func.func @symbol_constraints_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 mod 2 in [0, 0], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 mod 2 in [0, 0]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32]"> func.func @symbol_constraint_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) @@ -236,12 +195,8 @@ func.func @symbol_constraint_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32]"> func.func @symbol_constraint_interval_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) @@ -254,12 +209,8 @@ func.func @symbol_constraint_interval_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1), - domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64]"> func.func @vector_mapping_depends_on_block_id(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { @@ -271,13 +222,8 @@ func.func @vector_mapping_depends_on_block_id(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], - d1 mod 2 in [0, 0], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]"> func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) @@ -290,13 +236,8 @@ func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], - d1 mod 2 in [0, 0], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]"> func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) @@ -309,14 +250,8 @@ func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], - d1 mod 2 in [0, 0], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), - domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], - d1 mod 4 in [0, 0], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 4 in [0, 0]"> func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) @@ -329,12 +264,8 @@ func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, // ----- -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), - domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 mod 16 + s0, d1), - domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0 mod 16 + s0, d1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1]"> func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { @@ -346,12 +277,8 @@ func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, // ----- -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), - domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0, d1, d2) -> (d0 mod 16, d1, d2), - domain: d0 in [0, 32], d1 in [0, 2], d2 in [0, 5], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 mod 16, d1, d2), domain: d0 in [0, 32], d1 in [0, 2], d2 in [0, 5]"> func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir index 572202bf148ce2..f6fd03d8f1ed24 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir @@ -56,19 +56,13 @@ func.func @caller(%a: f32, %b: f32) -> f32 { // ----- -#map0 = #xla_gpu.indexing_map< -(d0, d1)[s0] -> (d0, d1 + s0), - domain: - d0 in [1, 2], - d1 in [5, 8], - s0 in [0, 32], - is_simplified: false -> +#map0 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d0, d1 + s0)," + "domain: d0 in [1, 2], d1 in [5, 8], s0 in [0, 32]"> func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] func.return %0#0, %0#1 : index, index } -// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map< +// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<" // CHECK-SAME: (d0, d1)[s0] -> (d0, d1 + s0) // CHECK-SAME: domain: // CHECK-SAME: d0 in [1, 2] @@ -83,18 +77,13 @@ func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) // ----- -#map0 = #xla_gpu.indexing_map< -(d0, d1) -> (d0, d1), - domain: - d0 in [0, 2], - d1 in [1, 3], - is_simplified: false -> +#map0 = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1)," + "domain: d0 in [0, 2], d1 in [1, 3]"> func.func @apply_indexing_no_symbols(%d0: index, %d1: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1) func.return %0#0, %0#1 : index, index } -// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map< +// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<" // CHECK-SAME: (d0, d1) -> (d0, d1) // CHECK-SAME: domain: // CHECK-SAME: d0 in [0, 2] @@ -108,17 +97,13 @@ func.func @apply_indexing_no_symbols(%d0: index, %d1: index) -> (index, index) { // ----- -#map0 = #xla_gpu.indexing_map< - ()[s0] -> (s0, s0), - domain: - s0 in [2, 4], - is_simplified: false -> +#map0 = #xla_gpu.indexing_map<"()[s0] -> (s0, s0)," + "domain: s0 in [2, 4]"> func.func @apply_indexing_no_dims(%s0: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing #map0 [%s0] func.return %0#0, %0#1 : index, index } -// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map< +// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<" // CHECK-SAME: ()[s0] -> (s0, s0) // CHECK-SAME: domain: // CHECK-SAME: s0 in [2, 4] @@ -130,8 +115,8 @@ func.func @apply_indexing_no_dims(%s0: index) -> (index, index) { // ----- -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0, s1), - domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, s1), " + "domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32]"> func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { %sum = xla_gpu.loop (%dim)[%i, %j] -> (%r0, %r1) @@ -155,15 +140,12 @@ func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, func.func private @exp(%p0: tensor<32x64xf32>, %i: index, %j: index) -> f32 -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1), - domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (s0, s1), - domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32], - is_simplified: false> -#map2 = #xla_gpu.indexing_map<(d0, d1) -> (d0, d1), - domain: d0 in [0, 32], d1 in [0, 2], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1)," + "domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (s0, s1)," + "domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]"> +#map2 = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1)," + "domain: d0 in [0, 32], d1 in [0, 2]"> func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { @@ -174,12 +156,12 @@ func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index, func.return %1 : tensor<32x64xf32> } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1) +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1) // CHECK-SAME: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32] -// CHECK: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (s0, s1) +// CHECK: #[[$MAP1:.*]] = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (s0, s1) // CHECK-SAME: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32] -// CHECK: #[[$MAP2:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0, d1) -// CHECK-SAME: d0 in [0, 32], d1 in [0, 2], +// CHECK: #[[$MAP2:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1) +// CHECK-SAME: d0 in [0, 32], d1 in [0, 2]"> // CHECK-LABEL: @materialize_and_insert // CHECK: %[[MATERIALIZED:.*]] = xla_gpu.materialize @exp(%{{.*}}) at // CHECK-SAME: #[[$MAP]](%{{.*}}, %{{.*}}) @@ -233,13 +215,14 @@ func.func @reduce_middle_dim(%in: tensor<16x8x4xf32>, %init: f32) // ----- -#map = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 64 + d1), domain: d0 in [0, 15], d1 in [0, 63], is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 64 + d1)," + "domain: d0 in [0, 15], d1 in [0, 63]"> func.func @reindex(%in0: tensor<1024xf32>) -> tensor<16x64xf32> { %0 = xla_gpu.reindex %in0 at #map : tensor<1024xf32> -> tensor<16x64xf32> func.return %0 : tensor<16x64xf32> } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 64 + d1) +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 64 + d1) // CHECK-LABEL: func.func @reindex( // CHECK-SAME: %[[IN1:.*]]: tensor<1024xf32> // CHECK: xla_gpu.reindex %[[IN1]] at #[[$MAP]] : @@ -247,7 +230,8 @@ func.func @reindex(%in0: tensor<1024xf32>) -> tensor<16x64xf32> { // ----- -#map = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 64 + d1), domain: d0 in [0, 15], d1 in [0, 63], is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 64 + d1)," + "domain: d0 in [0, 15], d1 in [0, 63]"> func.func @reindex_pad(%in0: tensor<1022xf32>) -> tensor<16x64xf32> { %c0 = arith.constant 0.0 : f32 %0 = xla_gpu.reindex %in0 at #map default %c0 @@ -255,7 +239,7 @@ func.func @reindex_pad(%in0: tensor<1022xf32>) -> tensor<16x64xf32> { func.return %0 : tensor<16x64xf32> } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 64 + d1) +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 64 + d1) // CHECK-LABEL: func.func @reindex_pad( // CHECK-SAME: %[[IN1:.*]]: tensor<1022xf32> // CHECK: %[[C0:.*]] = arith.constant 0.00 @@ -278,4 +262,4 @@ func.func @shuffler(%a: f32, %b: i32) -> (f32, i32) { // CHECK: xla_gpu.shuffle_reduce(%[[IN1]], %[[IN2]]) to 4 // CHECK-SAME: combiner=@do_nothing {xla.range = [0 : index, 42 : index]} -// CHECK-SAME: : f32, i32 \ No newline at end of file +// CHECK-SAME: : f32, i32 diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc index 8fec5e91c9c3a1..622320885e5ac9 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc @@ -14,12 +14,13 @@ limitations under the License. ==============================================================================*/ #include +#include +#include #include #include -#include "absl/strings/str_format.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep #include "llvm/Support/LogicalResult.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -28,11 +29,9 @@ limitations under the License. #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/model/indexing_map.h" - -#define GET_ATTRDEF_LIST -#define GET_ATTRDEF_CLASSES -#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.h.inc" +#include "xla/service/gpu/model/indexing_map_serialization.h" namespace xla { namespace gpu { @@ -46,137 +45,35 @@ using mlir::AsmPrinter; using mlir::failure; using mlir::success; -constexpr llvm::StringRef kIsSimplifiedKeyword = "is_simplified"; - -ParseResult ParseInterval(AsmParser& parser, Interval& interval) { - // ParseResult converts to `true` if parsing failed. - return failure(parser.parseLSquare() || parser.parseInteger(interval.lower) || - parser.parseComma() || parser.parseInteger(interval.upper) || - parser.parseRSquare()); -} - -ParseResult parseBool(AsmParser& parser, bool* result) { - if (succeeded(parser.parseOptionalKeyword("true"))) { - *result = true; - return success(); - } - if (succeeded(parser.parseOptionalKeyword("false"))) { - *result = false; - return success(); - } - return failure(); -} - -void PrintDimVars(AsmPrinter& p, ArrayRef dim_vars) { - for (const auto [index, dim_var] : llvm::enumerate(dim_vars)) { - p << "d" << index << " in " << dim_var.bounds << ", "; - } -} - -ParseResult ParseDimVars(AsmParser& parser, ArrayRef dim_names, - SmallVector& dim_vars) { - dim_vars.reserve(dim_names.size()); - for (const auto& [index, dim_name] : llvm::enumerate(dim_names)) { - if (parser.parseKeyword(dim_name) || parser.parseKeyword("in") || - ParseInterval(parser, dim_vars.emplace_back().bounds) || - parser.parseComma()) { - return failure(); - } - } - return success(); -} - -void PrintRangeVars(AsmPrinter& p, ArrayRef range_vars) { - for (const auto [index, range_var] : llvm::enumerate(range_vars)) { - p << "s" << index << " in " << range_var.range << ", "; - } -} - -ParseResult ParseRangeVars(AsmParser& parser, - ArrayRef range_symbol_names, - SmallVector& range_vars) { - range_vars.reserve(range_symbol_names.size()); - for (const auto& [index, range_symbol_name] : - llvm::enumerate(range_symbol_names)) { - if (parser.parseKeyword(range_symbol_name) || parser.parseKeyword("in") || - ParseInterval(parser, range_vars.emplace_back().range) || - parser.parseComma()) { - return failure(); - } - } - return success(); -} - -void PrintConstraints(AsmPrinter& p, - ArrayRef> constraints) { - for (const auto& [expr, interval] : constraints) { - p << expr << " in " << interval << ", "; - } +// Parses a chain of string attributes into an indexing map. +// Example: +// "()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2)," +// " domain: s0 in [-10, 10], s1 in [0, 2]" +// will be parsed as 3 StringAttrs, concatenated into a single string, and then +// parsed into an IndexingMap. +std::optional parseChainOfStringsAsIndexingMap( + mlir::AsmParser& parser) { + mlir::StringAttr indexing_map_attr; + std::string indexing_map_str; + while (parser.parseOptionalAttribute(indexing_map_attr).has_value()) { + indexing_map_str.append(indexing_map_attr.getValue()); + } + return ParseIndexingMap(indexing_map_str, parser.getContext()); } mlir::Attribute IndexingMapAttr::parse(mlir::AsmParser& parser, mlir::Type) { - mlir::AffineMap map; - if (parser.parseLess() || parser.parseAffineMap(map)) { + if (parser.parseLess()) { return {}; } - - // Store real strings to back up StringRef throughout ParseConstraints. - SmallVector dim_strings(map.getNumDims()); - SmallVector symbol_strings(map.getNumSymbols()); - SmallVector> symbolSet; - symbolSet.reserve(map.getNumDims() + map.getNumSymbols()); - for (int i = 0; i < map.getNumDims(); ++i) { - dim_strings[i] = absl::StrFormat("d%d", i); - symbolSet.push_back( - {dim_strings[i], mlir::getAffineDimExpr(i, parser.getContext())}); - } - for (int i = 0; i < map.getNumSymbols(); ++i) { - symbol_strings[i] = absl::StrFormat("s%d", i); - symbolSet.push_back( - {symbol_strings[i], mlir::getAffineSymbolExpr(i, parser.getContext())}); - } - if (map.getNumDims() + map.getNumSymbols() == 0) { - if (parser.parseGreater()) return {}; - return IndexingMapAttr::get(parser.getContext(), map, /*dim_vars=*/{}, - /*range_vars=*/{}, - /*constraints=*/{}, /*is_simplified=*/true); - } - if (parser.parseComma() || parser.parseKeyword("domain") || - parser.parseColon()) { - return {}; - } - - SmallVector dim_vars; - if (ParseDimVars(parser, dim_strings, dim_vars)) { - return {}; - } - SmallVector range_vars; - if (ParseRangeVars(parser, symbol_strings, range_vars)) { - return {}; - } - - SmallVector> constraints; - while (failed(parser.parseOptionalKeyword(kIsSimplifiedKeyword))) { - auto& constraint = constraints.emplace_back(); - if (parser.parseAffineExpr(symbolSet, constraint.first) || - parser.parseKeyword("in") || ParseInterval(parser, constraint.second) || - parser.parseComma()) { - return {}; - } - constraints.push_back(constraint); - } - - bool is_simplified = false; - if (parser.parseColon() || parseBool(parser, &is_simplified) || - parser.parseGreater()) { + auto indexing_map = parseChainOfStringsAsIndexingMap(parser); + if (!indexing_map.has_value() || parser.parseGreater()) { return {}; } - return IndexingMapAttr::get(parser.getContext(), map, dim_vars, range_vars, - constraints, is_simplified); + return IndexingMapAttr::get(parser.getContext(), *indexing_map); } void IndexingMapAttr::print(mlir::AsmPrinter& printer) const { - printer << "<" << getIndexingMap().ToString() << ">"; + printer << "<\"" << ToString(getIndexingMap()) << "\">"; } IndexingMapAttr IndexingMapAttr::get(mlir::MLIRContext* context, @@ -186,34 +83,57 @@ IndexingMapAttr IndexingMapAttr::get(mlir::MLIRContext* context, constraints.push_back({constraint.first, constraint.second}); } return get(context, indexing_map.GetAffineMap(), indexing_map.GetDimVars(), - indexing_map.GetRangeVars(), constraints, - indexing_map.IsSimplified()); + indexing_map.GetRangeVars(), constraints); } mlir::LogicalResult IndexingMapAttr::verify( mlir::function_ref emitError, - mlir::AffineMap map, ArrayRef dim_vars, - ArrayRef range_vars, - ArrayRef> constraints, bool is_simplified) { - if (map.getNumDims() != dim_vars.size()) { - return emitError() << "dim size must match the number of dimensions in " - "the affine map"; + mlir::AffineMap map, ArrayRef dim_vars, + ArrayRef range_vars, + ArrayRef> constraints) { + auto indexing_map = + IndexingMap(map, dim_vars, range_vars, /*rt_vars=*/{}, constraints); + std::stringstream ss; + if (!indexing_map.Verify(ss)) { + return emitError() << ss.str(); } - if (map.getNumSymbols() != range_vars.size()) { - return emitError() - << "range size must match the number of symbols in the affine map"; - } - return mlir::success(); + return success(); } IndexingMap IndexingMapAttr::getIndexingMap() const { return IndexingMap(getMap(), getDimVars(), getRangeVars(), /*rt_vars=*/{}, - getConstraints(), getIsSimplified()); + getConstraints()); } int64_t IndexingMapAttr::getNumResults() const { return getMap().getNumResults(); } +mlir::Attribute LayoutAttr::parse(mlir::AsmParser& parser, mlir::Type) { + mlir::StringAttr memory_space_str; + if (parser.parseLess() || parser.parseAttribute(memory_space_str) || + parser.parseComma()) { + return {}; + } + std::optional memspace = + symbolizeMemorySpace(memory_space_str.getValue()); + if (!memspace.has_value()) { + return {}; + } + std::optional indexing_map = + parseChainOfStringsAsIndexingMap(parser); + if (!indexing_map.has_value() || parser.parseGreater()) { + return {}; + } + auto* context = parser.getContext(); + return LayoutAttr::get(context, MemorySpaceAttr::get(context, *memspace), + IndexingMapAttr::get(context, *indexing_map)); +} + +void LayoutAttr::print(mlir::AsmPrinter& printer) const { + printer << "<\"" << stringifyMemorySpace(getMemorySpace().getValue()) + << "\", \"" << ToString(getThreadMap().getIndexingMap()) << "\">"; +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td index 3bcdc79e4ff119..adcef4e52c95d5 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td @@ -17,6 +17,7 @@ limitations under the License. #define XLA_SERVICE_GPU_FUSIONS_MLIR_ATTRS include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/EnumAttr.td" include "xla/service/gpu/fusions/ir/xla_gpu_dialect.td" class XLAGPU_Attr traits = []> : @@ -27,16 +28,11 @@ def XLAGPU_AffineMapParameter : AttrOrTypeParameter<"::mlir::AffineMap", ""> { } -def XLAGPU_DimVarsParameter : ArrayRefParameter<"::xla::gpu::DimVar", - "DimVarArray"> { +def XLAGPU_IndexingMapVariableParameter + : ArrayRefParameter<"::xla::gpu::IndexingMap::Variable", + "IndexingMapVariableArray"> { } -def XLAGPU_RangeVarsParameter : ArrayRefParameter<"::xla::gpu::RangeVar", - "RangeVarArray"> { -} - -def XLAGPU_BoolParameter : AttrOrTypeParameter<"bool", ""> {} - def XLAGPU_ConstraintsParameter : ArrayRefParameter<"::std::pair<::mlir::AffineExpr, ::xla::gpu::Interval>", "ContraintsArray"> { @@ -49,10 +45,9 @@ def XLAGPU_IndexingMapAttr : XLAGPU_Attr<"IndexingMap"> { https://openxla.org/xla/indexing for more details. }]; let parameters = (ins XLAGPU_AffineMapParameter:$map, - XLAGPU_DimVarsParameter:$dim_vars, - XLAGPU_RangeVarsParameter:$range_vars, - XLAGPU_ConstraintsParameter:$constraints, - XLAGPU_BoolParameter:$is_simplified); + XLAGPU_IndexingMapVariableParameter:$dim_vars, + XLAGPU_IndexingMapVariableParameter:$range_vars, + XLAGPU_ConstraintsParameter:$constraints); let hasCustomAssemblyFormat = 1; let builders = [ AttrBuilder<(ins "const ::xla::gpu::IndexingMap&":$indexing_map)>, @@ -81,4 +76,41 @@ def XLAGPU_LaunchGridAttr : XLAGPU_Attr<"LaunchGrid"> { }]; } +//===----------------------------------------------------------------------===// +// Tensor layout attribute +//===----------------------------------------------------------------------===// + +def XLAGPU_MemorySpace : I32EnumAttr<"MemorySpace", + "element-wise op type", [ + I32EnumAttrCase<"kRegisters", 0, "registers">, + I32EnumAttrCase<"kSharedMemory", 1, "shmem"> + ]> { + let cppNamespace = "::xla::gpu"; + let genSpecializedAttr = 0; +} + +def XLAGPU_MemorySpaceAttr : EnumAttr< + XlaGpuDialect, XLAGPU_MemorySpace, "memory_space"> { + let assemblyFormat = "`<` $value `>`"; +} + +def XLAGPU_LayoutAttr : XLAGPU_Attr<"Layout"> { + let mnemonic = "layout"; + let summary = "Layout consists of a thread ID indexing map + memory space."; + let description = [{ + This attribute is used as an encoding for RankedTensorType. It indicates in + which memory space the tensor is stored and the access pattern from the + warps/threads. + ```mlir + tensor<42xf32, #xla_gpu.layout<"shmem", (d0) -> (), domain: d0 in [0, 42]>> + ``` + }]; + let parameters = (ins + AttrParameter<"MemorySpaceAttr", "memory_space">:$memory_space, + AttrParameter<"IndexingMapAttr", "thread_map">:$thread_map + ); + let hasCustomAssemblyFormat = 1; +} + + #endif // XLA_SERVICE_GPU_FUSIONS_MLIR_ATTRS diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc index 5ac9c59ce773df..c46561a98d0d45 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc @@ -18,6 +18,9 @@ limitations under the License. #include "mlir/IR/OpImplementation.h" // IWYU pragma: keep #include "mlir/Transforms/InliningUtils.h" #include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" + +// The order of these includes is important. +#include "xla/service/gpu/fusions/ir/xla_gpu_enums.cc.inc" #define GET_ATTRDEF_CLASSES #include "xla/service/gpu/fusions/ir/xla_gpu_attrs.cc.inc" #define GET_TYPEDEF_CLASSES @@ -112,6 +115,10 @@ struct XlaGpuOpAsmDialectInterface : public mlir::OpAsmDialectInterface { os << "indexing_map"; return AliasResult::FinalAlias; } + if (llvm::isa(attr)) { + os << "layout"; + return AliasResult::FinalAlias; + } return AliasResult::NoAlias; } }; diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc index d31c5bef66ac34..322636e5dbd107 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc @@ -23,7 +23,6 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SmallBitVector.h" -#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep #include "llvm/Support/Casting.h" #include "llvm/Support/LogicalResult.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -46,9 +45,9 @@ limitations under the License. #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "stablehlo/dialect/TypeInference.h" #include "xla/service/gpu/fusions/ir/xla_gpu_dialect.cc.inc" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" namespace xla { namespace gpu { @@ -142,8 +141,8 @@ void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result, void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result, ValueRange operands, AffineMap affine_map, - ArrayRef dim_vars, - ArrayRef range_vars) { + ArrayRef dim_vars, + ArrayRef range_vars) { IndexingMap indexing_map(affine_map, dim_vars, range_vars, {}); build(builder, result, operands, indexing_map); } @@ -184,13 +183,13 @@ ParseResult ApplyIndexingOp::parse(OpAsmParser& parser, parser.parseOptionalAttrDict(result.attributes)) { return failure(); } - auto map = indexing_map_attr.getMap(); + auto map = indexing_map_attr.getIndexingMap().GetAffineMap(); result.addTypes(SmallVector(map.getNumResults(), index_type)); return success(); } void ApplyIndexingOp::print(OpAsmPrinter& p) { - AffineMap affine_map = getIndexingMapAttr().getMap(); + AffineMap affine_map = getIndexingMapAttr().getIndexingMap().GetAffineMap(); p << " " << getIndexingMapAttr(); auto operands = getOperands(); @@ -215,14 +214,14 @@ void ApplyIndexingOp::print(OpAsmPrinter& p) { } LogicalResult ApplyIndexingOp::verify() { - auto affine_map = getIndexingMapAttr().getMap(); + auto affine_map = getIndexingMapAttr().getIndexingMap().GetAffineMap(); unsigned num_variables = affine_map.getNumDims() + affine_map.getNumSymbols(); if (getOperands().size() != num_variables) { return emitOpError( "operand count must match the number of dimensions and symbols in the " "affine map"); } - if (!getIndexingMapAttr().getConstraints().empty()) { + if (!getIndexingMap().GetConstraints().empty()) { return emitOpError("apply indexing op cannot have any constraints"); } return success(); @@ -311,11 +310,10 @@ struct SimplifyIndexingMap : public mlir::OpRewritePattern { LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op, PatternRewriter& rewriter) const override { IndexingMap indexing_map = indexing_op.getIndexingMap(); - if (indexing_map.IsSimplified()) { + if (!indexing_map.Simplify()) { return rewriter.notifyMatchFailure(indexing_op, "IndexingMap is already simplified"); } - indexing_map.Simplify(); rewriter.replaceOpWithNewOp( indexing_op, indexing_op.getOperands(), indexing_map); return success(); @@ -467,9 +465,9 @@ struct FoldApplyIndexingOperands unsigned new_num_operands = indexing_op->getNumOperands() - num_constants; SmallVector new_operands; new_operands.reserve(new_num_operands); - SmallVector new_dim_vars; + SmallVector new_dim_vars; new_dim_vars.reserve(num_dims); - SmallVector new_range_vars; + SmallVector new_range_vars; new_range_vars.reserve(num_symbols); unsigned new_num_dims = 0; @@ -836,13 +834,13 @@ LogicalResult LoopOp::verify() { return emitOpError() << "mismatch in number of induction variables " << getNumInductionVars() << " and RangeVars in the indexing map " - << indexing_map.ToString(); + << ToString(indexing_map); } if (indexing_map.GetDimVarsCount() != getDims().size()) { return emitOpError() << "mismatch in number of dims operands " << getDims().size() << " and DimVars in the indexing map " - << indexing_map.ToString(); + << ToString(indexing_map); } for (auto [bb_arg, result_type, init] : llvm::zip(getRegionIterArgs(), getResultTypes(), getInits())) { @@ -959,9 +957,9 @@ VariableConstraints GetConstraintsForVariables(const IndexingMap& map) { for (const auto& constraint : map.GetConstraints()) { constraint.first.walk([&](mlir::AffineExpr leaf) { if (auto dim = mlir::dyn_cast(leaf)) { - result.constraints_for_dims[dim.getPosition()].push_back(constraint); + result.constraints_for_dims[dim.getPosition()].insert(constraint); } else if (auto sym = mlir::dyn_cast(leaf)) { - result.constraints_for_symbols[sym.getPosition()].push_back(constraint); + result.constraints_for_symbols[sym.getPosition()].insert(constraint); } }); } @@ -982,7 +980,7 @@ LogicalResult MaterializeOp::verify() { return emitOpError() << "must have thread_id dimension in both indexing maps"; } - if (map_in.GetDimVars(0) != map_out.GetDimVars(0)) { + if (map_in.GetDimVars(0).bounds != map_out.GetDimVars(0).bounds) { return emitOpError() << "thread_id dimension must have the same bounds in " "both indexing maps"; } @@ -1002,7 +1000,7 @@ LogicalResult MaterializeOp::verify() { } for (auto const& [range_in, range_out] : llvm::zip(map_in.GetRangeVars(), map_out.GetRangeVars())) { - if (range_in.range != range_out.range) { + if (range_in.bounds != range_out.bounds) { return emitOpError() << "domain of symbols of indexing_maps must match"; } } @@ -1047,12 +1045,12 @@ LogicalResult MaterializeOp::verify() { //===----------------------------------------------------------------------===// LogicalResult InsertOp::verify() { - if (!getMap().getRangeVars().empty()) { + if (!getMap().getIndexingMap().GetRangeVars().empty()) { return emitOpError() << "insert_op map must not have any symbols"; } int64_t vector_map_num_results = getSource().getType().getIndexingMapAttr().getNumResults(); - if (vector_map_num_results != getMap().getDimVars().size()) { + if (vector_map_num_results != getMap().getIndexingMap().GetDimVars().size()) { return emitOpError() << "source map result count must equal insert_op's " "map's dimension count"; } diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h index e025bd90b37e64..28e46396a34091 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h @@ -17,6 +17,7 @@ limitations under the License. #include +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Bytecode/BytecodeOpInterface.h" // IWYU pragma: keep #include "mlir/Dialect/Func/IR/FuncOps.h" // IWYU pragma: keep @@ -30,6 +31,7 @@ limitations under the License. #include "mlir/Interfaces/InferTypeOpInterface.h" // IWYU pragma: keep #include "mlir/Interfaces/SideEffectInterfaces.h" // IWYU pragma: keep #include "xla/service/gpu/fusions/ir/xla_gpu_dialect.h.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_enums.h.inc" #include "xla/service/gpu/model/indexing_map.h" // IWYU pragma: keep #define GET_ATTRDEF_CLASSES #include "xla/service/gpu/fusions/ir/xla_gpu_attrs.h.inc" @@ -41,9 +43,9 @@ limitations under the License. namespace xla::gpu { struct VariableConstraints { - llvm::SmallVector>> + llvm::SmallVector> constraints_for_dims; - llvm::SmallVector>> + llvm::SmallVector> constraints_for_symbols; }; VariableConstraints GetConstraintsForVariables(const IndexingMap& map); diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.td b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.td index bd47ade3a4d087..ce951932692581 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.td +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.td @@ -274,8 +274,8 @@ def ApplyIndexingOp : XLAGPU_Op<"apply_indexing", [Pure]> { OpBuilder<(ins "mlir::ValueRange":$operands, "const IndexingMap&":$indexing_map)>, OpBuilder<(ins "mlir::ValueRange":$operands, "mlir::AffineMap":$affine_map, - "llvm::ArrayRef":$dim_vars, - "llvm::ArrayRef":$range_vars)>, + "llvm::ArrayRef":$dim_vars, + "llvm::ArrayRef":$range_vars)>, ]; let extraClassDeclaration = [{ // Returns the indexing map constructed from IndexingMapAttr. diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops_test.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops_test.cc index 2d9076d7803280..b301ad26edd93e 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/test.h" @@ -34,19 +35,19 @@ class XLAGPUOpsTest : public HloTestBase { }; TEST_F(XLAGPUOpsTest, GetConstraintsForVariables) { - auto map = IndexingMap( - ParseAffineMap("(d0, d1)[s0, s1] -> (d0+s0, d1+s1)", &mlir_context_), - /*dimensions=*/{{Interval{0, 5}}, {Interval{0, 2}}}, - /*range_vars=*/{{Interval{0, 32}}, {Interval{0, 1024}}}, /*rt_vars=*/{}); - map.AddConstraint(ParseAffineExpr("s0 mod 4", &mlir_context_), - Interval{0, 1}); - map.AddConstraint(ParseAffineExpr("s1 mod 4", &mlir_context_), - Interval{0, 2}); - map.AddConstraint(ParseAffineExpr("s0 + s1", &mlir_context_), Interval{0, 3}); - map.AddConstraint(ParseAffineExpr("s1 + d1", &mlir_context_), Interval{0, 4}); - map.AddConstraint(ParseAffineExpr("d1 mod 32", &mlir_context_), - Interval{0, 6}); - + auto map = *ParseIndexingMap(R"( + (d0, d1)[s0, s1] -> (d0 + s0, d1 + s1), + domain: d0 in [0, 5], + d1 in [0, 2], + s0 in [0, 32], + s1 in [0, 1024], + d1 + s1 in [0, 4], + d1 mod 32 in [0, 6], + s0 + s1 in [0, 3], + s0 mod 4 in [0, 1], + s1 mod 4 in [0, 2] + )", + &mlir_context_); auto constraints_for_variables = GetConstraintsForVariables(map); EXPECT_THAT(constraints_for_variables.constraints_for_dims[0], UnorderedElementsAre()); @@ -69,10 +70,14 @@ TEST_F(XLAGPUOpsTest, GetConstraintsForVariables) { } TEST_F(XLAGPUOpsTest, GetConstraintsForVariablesEmpty) { - auto map = IndexingMap( - ParseAffineMap("(d0, d1)[s0, s1] -> (d0+s0, d1+s1)", &mlir_context_), - /*dimensions=*/{{Interval{0, 5}}, {Interval{0, 2}}}, - /*range_vars=*/{{Interval{0, 32}}, {Interval{0, 1024}}}, /*rt_vars=*/{}); + auto map = *ParseIndexingMap(R"( + (d0, d1)[s0, s1] -> (d0 + s0, d1 + s1), + domain: d0 in [0, 5], + d1 in [0, 2], + s0 in [0, 32], + s1 in [0, 1024], + )", + &mlir_context_); auto constraints_for_variables = GetConstraintsForVariables(map); EXPECT_THAT(constraints_for_variables.constraints_for_dims, ElementsAre(IsEmpty(), IsEmpty())); diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.cc index dbcc20b36f9951..86f2dffa74f4f2 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.cc @@ -21,14 +21,9 @@ limitations under the License. #include "mlir/IR/OpImplementation.h" // IWYU pragma: keep #include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_dialect.h.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/model/indexing_map.h" // IWYU pragma: keep -#define GET_ATTRDEF_CLASSES -#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.h.inc" - -#define GET_TYPEDEF_CLASSES -#include "xla/service/gpu/fusions/ir/xla_gpu_types.h.inc" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/BUILD b/third_party/xla/xla/service/gpu/fusions/legacy/BUILD index 98d8ade7c5e5c3..02347e0ef49e2a 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/legacy/BUILD @@ -13,6 +13,7 @@ cc_library( deps = [ "//xla/hlo/ir:hlo", "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:ir_emitter", "//xla/service/gpu:ir_emitter_context", @@ -35,10 +36,11 @@ xla_cc_test( srcs = ["in_place_dynamic_update_slice_test.cc"], deps = [ ":in_place_dynamic_update_slice", + "//xla:xla_proto_cc", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/fusions", - "//xla/service/gpu/model:affine_map_printer", + "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", @@ -89,7 +91,7 @@ xla_cc_test( "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/fusions", "//xla/service/gpu/fusions:fusion_emitter", - "//xla/service/gpu/model:affine_map_printer", + "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", @@ -97,6 +99,7 @@ xla_cc_test( "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:statusor", ], ) @@ -137,16 +140,18 @@ xla_cc_test( srcs = ["scatter_test.cc"], deps = [ ":scatter", + "//xla:xla_proto_cc", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/fusions", - "//xla/service/gpu/model:affine_map_printer", + "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_googletest//:gtest", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:statusor", ], ) @@ -205,6 +210,7 @@ cc_library( "//xla/service/gpu/fusions:fusion_emitter", "//xla/service/gpu/fusions:reduction_base", "//xla/service/gpu/fusions:thunk_util", + "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/runtime:kernel_thunk", "//xla/service/gpu/runtime:thunk", "//xla/service/llvm_ir:fused_ir_emitter", @@ -214,6 +220,7 @@ cc_library( "//xla/service/llvm_ir:llvm_util", "//xla/service/llvm_ir:loop_emitter", "//xla/stream_executor:device_description", + "//xla/stream_executor:launch_dim", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", @@ -224,6 +231,7 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", @@ -237,17 +245,13 @@ xla_cc_test( srcs = ["reduction_test.cc"], deps = [ ":reduction", - "//xla/hlo/ir:hlo", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu:ir_emitter_context", - "//xla/service/gpu/fusions:fusion_emitter", "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest", "@llvm-project//mlir:IR", @@ -287,16 +291,18 @@ xla_cc_test( srcs = ["concatenate_test.cc"], deps = [ ":concatenate", + "//xla:xla_proto_cc", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/fusions", - "//xla/service/gpu/model:affine_map_printer", + "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_googletest//:gtest", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -322,6 +328,7 @@ cc_library( "//xla/service/llvm_ir:ir_array", "//xla/service/llvm_ir:llvm_util", "//xla/service/llvm_ir:loop_emitter", + "//xla/stream_executor:device_description", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", @@ -341,9 +348,11 @@ xla_cc_test( deps = [ ":transpose", "//xla:status_macros", + "//xla:xla_proto_cc", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/fusions", + "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", @@ -392,10 +401,11 @@ xla_cc_test( srcs = ["input_slices_test.cc"], deps = [ ":input_slices", + "//xla:xla_proto_cc", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/fusions", - "//xla/service/gpu/model:affine_map_printer", + "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc index 9a9bdc2dd488b2..adc7b37b3f1de2 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc @@ -14,39 +14,30 @@ limitations under the License. ==============================================================================*/ #include "xla/service/gpu/fusions/legacy/concatenate.h" -#include - #include #include #include "mlir/IR/MLIRContext.h" +#include "mlir/Support/LLVM.h" #include "xla/service/gpu/fusions/fusions.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" +#include "xla/xla.pb.h" namespace xla { namespace gpu { namespace { class ConcatenateTest : public HloTestBase { - public: - void SetUp() override { - HloTestBase::SetUp(); - printer_ = - AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id"}); - } - protected: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { auto opts = HloTestBase::GetDebugOptionsForTest(); opts.set_xla_gpu_mlir_emitter_level(0); return opts; } - AffineMapPrinter printer_; mlir::MLIRContext mlir_context_; }; @@ -93,26 +84,25 @@ TEST_F(ConcatenateTest, ThreadIndexing) { bl_z in [0, 0], chunk_id in [0, 0], unroll_id in [0, 0], - bl_x * 128 + th_x in [0, 399], - is_simplified: true + bl_x * 128 + th_x in [0, 399] )"; + mlir::SmallVector dim_names = {"th_x", "th_y", "th_z", + "bl_x", "bl_y", "bl_z"}; + mlir::SmallVector range_names = {"chunk_id", "unroll_id"}; EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_) - ->ToString(printer_), + ToString(*fusion->ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_), + dim_names, range_names, {}), MatchIndexingString(kIndexing)); EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/1, &mlir_context_) - ->ToString(printer_), + ToString(*fusion->ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/1, &mlir_context_), + dim_names, range_names, {}), MatchIndexingString(kIndexing)); EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_) - ->ToString(printer_), + ToString(*fusion->ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_), + dim_names, range_names, {}), MatchIndexingString(kIndexing)); } diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h index db12c3cbbf4643..dac691d63688a7 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h +++ b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h @@ -27,6 +27,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/launch_dimensions.h" diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc index 53be6363567cdd..c28fd0a8b3686b 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc @@ -22,10 +22,11 @@ limitations under the License. #include "xla/service/gpu/fusions/fusions.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" +#include "xla/xla.pb.h" #include "tsl/platform/statusor.h" namespace xla { @@ -33,21 +34,12 @@ namespace gpu { namespace { class InPlaceDynamicUpdateSliceFusionTest : public HloTestBase { - public: - void SetUp() override { - HloTestBase::SetUp(); - printer_ = - AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id"}); - } - protected: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { auto opts = HloTestBase::GetDebugOptionsForTest(); opts.set_xla_gpu_mlir_emitter_level(0); return opts; } - AffineMapPrinter printer_; mlir::MLIRContext mlir_context_; stream_executor::DeviceDescription device_info_ = TestGpuDeviceInfo::RTXA6000DeviceInfo(); @@ -83,7 +75,9 @@ TEST_F(InPlaceDynamicUpdateSliceFusionTest, ThreadIndexing) { auto thread_id_update_indexing = fusion->ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/1, &mlir_context_); - EXPECT_THAT(thread_id_update_indexing->ToString(printer_), + EXPECT_THAT(ToString(*thread_id_update_indexing, + {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, + {"chunk_id", "unroll_id"}, {}), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( th_x floordiv 6, th_x mod 6), @@ -95,8 +89,7 @@ TEST_F(InPlaceDynamicUpdateSliceFusionTest, ThreadIndexing) { bl_y in [0, 0], bl_z in [0, 0], chunk_id in [0, 0], - unroll_id in [0, 0], - is_simplified: true + unroll_id in [0, 0] )")); auto thread_id_dst_indexing = fusion->ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc index 0c604502bd51d1..ce233ff6bc4e5b 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc @@ -22,31 +22,23 @@ limitations under the License. #include "xla/service/gpu/fusions/fusions.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" +#include "xla/xla.pb.h" namespace xla { namespace gpu { namespace { class InputSlicesTest : public HloTestBase { - public: - void SetUp() override { - HloTestBase::SetUp(); - printer_ = - AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id"}); - } - protected: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { auto opts = HloTestBase::GetDebugOptionsForTest(); opts.set_xla_gpu_mlir_emitter_level(0); return opts; } - AffineMapPrinter printer_; mlir::MLIRContext mlir_context_; }; @@ -80,7 +72,9 @@ TEST_F(InputSlicesTest, ThreadIndexing) { auto thread_id_to_output_indexing = fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_); - EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), + EXPECT_THAT(ToString(*thread_id_to_output_indexing, + {"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, + {"chunk_id", "unroll_id"}, {}), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (0, ((bl_x * 128 + th_x) floordiv 3) mod 2, @@ -95,8 +89,7 @@ TEST_F(InputSlicesTest, ThreadIndexing) { bl_z in [0, 0], chunk_id in [0, 0], unroll_id in [0, 0], - bl_x * 128 + th_x in [0, 29], - is_simplified: true + bl_x * 128 + th_x in [0, 29] )")); } diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/loop.cc b/third_party/xla/xla/service/gpu/fusions/legacy/loop.cc index e6ce5f113c713b..afec7da45465d5 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/loop.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/loop.cc @@ -23,15 +23,12 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" -#include "absl/numeric/bits.h" #include "absl/status/status.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Type.h" #include "mlir/IR/MLIRContext.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/layout_util.h" #include "xla/service/gpu/elemental_ir_emitter.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/hlo_fusion_analysis.h" @@ -45,9 +42,7 @@ limitations under the License. #include "xla/service/llvm_ir/fused_ir_emitter.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/shape.h" -#include "xla/shape_util.h" #include "xla/util.h" -#include "tsl/platform/macros.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc index 82a9de34c7cc49..17ad16b157d748 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc @@ -20,11 +20,12 @@ limitations under the License. #include #include "absl/status/statusor.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/Support/LLVM.h" #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/fusions/fusions.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_description.h" @@ -36,19 +37,9 @@ namespace gpu { namespace { class LoopTest : public HloTestBase { - public: - void SetUp() override { - HloTestBase::SetUp(); - - printer_ = - AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id"}); - } - protected: stream_executor::DeviceDescription device_info_ = TestGpuDeviceInfo::RTXA6000DeviceInfo(); - AffineMapPrinter printer_; mlir::MLIRContext mlir_context_; }; @@ -85,8 +76,12 @@ TEST_F(LoopTest, ThreadIndexingUnrolled) { loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0, &mlir_context_); - EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), - MatchIndexingString(R"( + mlir::SmallVector dim_names = {"th_x", "th_y", "th_z", + "bl_x", "bl_y", "bl_z"}; + mlir::SmallVector range_names = {"chunk_id", "unroll_id"}; + EXPECT_THAT( + ToString(*thread_id_to_output_indexing, dim_names, range_names, {}), + MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( (bl_x * 128 + th_x) floordiv 15000, ((bl_x * 128 + th_x) floordiv 75) mod 200, @@ -101,8 +96,7 @@ TEST_F(LoopTest, ThreadIndexingUnrolled) { bl_z in [0, 0], chunk_id in [0, 0], unroll_id in [0, 3], - bl_x * 128 + th_x in [0, 1499999], - is_simplified: true + bl_x * 128 + th_x in [0, 1499999] )")); } @@ -128,8 +122,12 @@ TEST_F(LoopTest, ThreadIndexingNotUnrolled) { auto thread_id_to_output_indexing = loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0, &mlir_context_); - EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), - MatchIndexingString(R"( + mlir::SmallVector dim_names = {"th_x", "th_y", "th_z", + "bl_x", "bl_y", "bl_z"}; + mlir::SmallVector range_names = {"chunk_id", "unroll_id"}; + EXPECT_THAT( + ToString(*thread_id_to_output_indexing, dim_names, range_names, {}), + MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x), domain: th_x in [0, 19], @@ -139,14 +137,14 @@ TEST_F(LoopTest, ThreadIndexingNotUnrolled) { bl_y in [0, 0], bl_z in [0, 0], chunk_id in [0, 0], - unroll_id in [0, 0], - is_simplified: true + unroll_id in [0, 0] )")); auto thread_id_to_input_indexing = loop_fusion->ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); - EXPECT_THAT(thread_id_to_input_indexing->ToString(printer_), - MatchIndexingString(R"( + EXPECT_THAT( + ToString(*thread_id_to_input_indexing, dim_names, range_names, {}), + MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x), domain: th_x in [0, 19], @@ -156,8 +154,7 @@ TEST_F(LoopTest, ThreadIndexingNotUnrolled) { bl_y in [0, 0], bl_z in [0, 0], chunk_id in [0, 0], - unroll_id in [0, 0], - is_simplified: true + unroll_id in [0, 0] )")); } @@ -183,8 +180,12 @@ TEST_F(LoopTest, Broadcast) { auto thread_id_to_output_indexing = loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0, &mlir_context_); - EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), - MatchIndexingString(R"( + mlir::SmallVector dim_names = {"th_x", "th_y", "th_z", + "bl_x", "bl_y", "bl_z"}; + mlir::SmallVector range_names = {"chunk_id", "unroll_id"}; + EXPECT_THAT( + ToString(*thread_id_to_output_indexing, dim_names, range_names, {}), + MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( (bl_x * 128 + th_x) floordiv 600, ((bl_x * 128 + th_x) floordiv 30) mod 20, @@ -198,14 +199,14 @@ TEST_F(LoopTest, Broadcast) { bl_z in [0, 0], chunk_id in [0, 0], unroll_id in [0, 0], - bl_x * 128 + th_x in [0, 5999], - is_simplified: true + bl_x * 128 + th_x in [0, 5999] )")); auto thread_id_to_input_indexing = loop_fusion->ComputeThreadIdToInputIndexing( /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); - EXPECT_THAT(thread_id_to_input_indexing->ToString(printer_), - MatchIndexingString(R"( + EXPECT_THAT( + ToString(*thread_id_to_input_indexing, dim_names, range_names, {}), + MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (((bl_x * 128 + th_x) floordiv 30) mod 20), domain: @@ -217,8 +218,7 @@ TEST_F(LoopTest, Broadcast) { bl_z in [0, 0], chunk_id in [0, 0], unroll_id in [0, 0], - bl_x * 128 + th_x in [0, 5999], - is_simplified: true + bl_x * 128 + th_x in [0, 5999] )")); } diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/reduction.cc b/third_party/xla/xla/service/gpu/fusions/legacy/reduction.cc index e009ea18e0b48c..0448b5d7c1c7d0 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/reduction.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/reduction.cc @@ -43,6 +43,9 @@ limitations under the License. #include "llvm/IR/Value.h" #include "llvm/Support/AtomicOrdering.h" #include "llvm/Support/Casting.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/Support/LLVM.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -63,6 +66,8 @@ limitations under the License. #include "xla/service/gpu/kernel_arguments.h" #include "xla/service/gpu/kernel_reuse_cache.h" #include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/parallel_loop_emitter.h" #include "xla/service/gpu/reduction_utils.h" #include "xla/service/gpu/runtime/kernel_thunk.h" @@ -78,6 +83,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" @@ -1223,14 +1229,9 @@ std::optional ReductionInfo::ComputeThreadIdToOutputIndexing( auto physical_shape = ShapeUtil::DeleteDimensions(hero.dimensions(), hero.operand(0)->shape()); - std::vector dimension_ranges{ - {{0, tiling_.GetNumThreadsPerBlock() - 1}}, - {}, - {}, - {{0, tiling_.GetNumBlocks() - 1}}, - {{0, static_cast(groups_.grouped_roots.size() - 1)}}, - {}, - }; + std::vector dimension_ranges = DimVarsFromGPUGrid( + {tiling_.GetNumThreadsPerBlock(), 1, 1, tiling_.GetNumBlocks(), + static_cast(groups_.grouped_roots.size()), 1}); constexpr int kRowKept = ReductionDimensions::kRowKeptDimension; constexpr int kRowMinorReduced = @@ -1264,7 +1265,7 @@ std::optional ReductionInfo::ComputeThreadIdToOutputIndexing( mlir::SmallVector projected_dims{ block_offsets.getResult(kColMajorKept), block_offsets.getResult(kColMinorKept) + thread_ids[kColReduced]}; - std::vector range_vars; + std::vector range_vars; if (thread_ids.size() == 4) { int vector_size = tiling_.GetThreadTileSize().back(); range_vars.push_back({0, vector_size - 1}); diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/reduction.h b/third_party/xla/xla/service/gpu/fusions/legacy/reduction.h index 131b4ec38c7693..a94ac72b5bc0b4 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/reduction.h +++ b/third_party/xla/xla/service/gpu/fusions/legacy/reduction.h @@ -22,6 +22,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" #include "llvm/IR/IRBuilder.h" +#include "mlir/IR/MLIRContext.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/fusions/fusion_emitter.h" @@ -30,6 +31,7 @@ limitations under the License. #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_map.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/shape.h" diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc index 144159ce442424..8109a7b1c4068d 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc @@ -19,15 +19,11 @@ limitations under the License. #include #include -#include "absl/status/status.h" #include "absl/status/statusor.h" #include "mlir/IR/MLIRContext.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/ir_emitter_context.h" -#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" @@ -73,42 +69,40 @@ TEST_F(ReductionTest, ThreadIndexingRowReduction) { ReductionFusion fusion(analysis); EXPECT_THAT( - fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), + ToString(*fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> ( - d3 floordiv 8, - (d3 mod 8) * 8 + d0 floordiv 32, - (d0 mod 32) * 2 + s2 * 64 + s3 + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1, s2, s3] -> ( + bl_x floordiv 8, + (bl_x mod 8) * 8 + th_x floordiv 32, + (th_x mod 32) * 2 + s2 * 64 + s3 ), domain: - d0 in [0, 255], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 799], - d4 in [0, 0], - d5 in [0, 0], + th_x in [0, 255], + th_y in [0, 0], + th_z in [0, 0], + bl_x in [0, 799], + bl_y in [0, 0], + bl_z in [0, 0], s0 in [0, 0], s1 in [0, 0], s2 in [0, 7], - s3 in [0, 1], - is_simplified: true + s3 in [0, 1] )")); EXPECT_THAT( - fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), + ToString(*fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5) -> ( - d3 floordiv 8, - (d3 mod 8) * 8 + d0 floordiv 32 + (th_x, th_y, th_z, bl_x, bl_y, bl_z) -> ( + bl_x floordiv 8, + (bl_x mod 8) * 8 + th_x floordiv 32 ), domain: - d0 in [0, 224], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 799], - d4 in [0, 0], - d5 in [0, 0], - d0 mod 32 in [0, 0], - is_simplified: true + th_x in [0, 224], + th_y in [0, 0], + th_z in [0, 0], + bl_x in [0, 799], + bl_y in [0, 0], + bl_z in [0, 0], + th_x mod 32 in [0, 0] )")); } diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/scatter.cc b/third_party/xla/xla/service/gpu/fusions/legacy/scatter.cc index 07987886a73120..c06c3e143eb913 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/scatter.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/scatter.cc @@ -36,14 +36,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/elemental_ir_emitter.h" -#include "xla/service/gpu/fusions/legacy/loop.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/ir_emitter_nested.h" #include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/parallel_loop_emitter.h" #include "xla/service/llvm_ir/fused_ir_emitter.h" diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc index 8c6674d4a2b546..4055af6a7ca873 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc @@ -15,17 +15,20 @@ limitations under the License. #include "xla/service/gpu/fusions/legacy/scatter.h" #include +#include #include #include #include "mlir/IR/MLIRContext.h" +#include "mlir/Support/LLVM.h" #include "xla/service/gpu/fusions/fusions.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" +#include "xla/xla.pb.h" #include "tsl/platform/statusor.h" namespace xla { @@ -33,21 +36,13 @@ namespace gpu { namespace { class ScatterFusionTest : public HloTestBase { - public: - void SetUp() override { - HloTestBase::SetUp(); - printer_ = - AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id", "index_id"}); - } - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { auto opts = HloTestBase::GetDebugOptionsForTest(); opts.set_xla_gpu_mlir_emitter_level(0); return opts; } protected: - AffineMapPrinter printer_; mlir::MLIRContext mlir_context_; }; @@ -163,34 +158,33 @@ TEST_F(ScatterFusionTest, ThreadIdIndexing) { bl_z in [0, 0], chunk_id in [0, 0], unroll_id in [0, 0], - bl_x * 128 + th_x in [0, 8399], - is_simplified: true + bl_x * 128 + th_x in [0, 8399] )"; + mlir::SmallVector dim_names = {"th_x", "th_y", "th_z", + "bl_x", "bl_y", "bl_z"}; + mlir::SmallVector range_names = {"chunk_id", "unroll_id"}; EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/3, &mlir_context_) - ->ToString(printer_), + ToString(*fusion->ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/3, &mlir_context_), + dim_names, range_names, {}), MatchIndexingString(kUpdatesIndexing)); EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/4, &mlir_context_) - ->ToString(printer_), + ToString(*fusion->ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/4, &mlir_context_), + dim_names, range_names, {}), MatchIndexingString(kUpdatesIndexing)); EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/1, /*hero_operand_index=*/3, &mlir_context_) - ->ToString(printer_), + ToString(*fusion->ComputeThreadIdToInputIndexing( + /*root_index=*/1, /*hero_operand_index=*/3, &mlir_context_), + dim_names, range_names, {}), MatchIndexingString(kUpdatesIndexing)); EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/1, /*hero_operand_index=*/4, &mlir_context_) - ->ToString(printer_), + ToString(*fusion->ComputeThreadIdToInputIndexing( + /*root_index=*/1, /*hero_operand_index=*/4, &mlir_context_), + dim_names, range_names, {}), MatchIndexingString(kUpdatesIndexing)); + range_names.push_back("index_id"); constexpr auto kIndicesIndexing = R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id, index_id] -> ((bl_x * 128 + th_x) floordiv 200, 0), @@ -204,20 +198,17 @@ TEST_F(ScatterFusionTest, ThreadIdIndexing) { chunk_id in [0, 0], unroll_id in [0, 0], index_id in [0, 0], - bl_x * 128 + th_x in [0, 8399], - is_simplified: true + bl_x * 128 + th_x in [0, 8399] )"; EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_) - ->ToString(printer_), + ToString(*fusion->ComputeThreadIdToInputIndexing( + /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_), + dim_names, range_names, {}), MatchIndexingString(kIndicesIndexing)); EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/1, /*hero_operand_index=*/2, &mlir_context_) - ->ToString(printer_), + ToString(*fusion->ComputeThreadIdToInputIndexing( + /*root_index=*/1, /*hero_operand_index=*/2, &mlir_context_), + dim_names, range_names, {}), MatchIndexingString(kIndicesIndexing)); } diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.cc b/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.cc index a1a7acb58388a7..b01dc613aa98e6 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.cc @@ -24,7 +24,9 @@ limitations under the License. #include "absl/log/check.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/IRBuilder.h" @@ -36,6 +38,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/target_util.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/kernel_support_library.h" @@ -333,9 +336,8 @@ IndexingMap GetIndexingMapForTiling(AffineMap block_offsets, llvm::zip(block_offsets.getResults(), thread_offsets.getResults())) { offsets.push_back(block + thread); } - std::vector dimension_ranges{ - {{0, threads_per_block - 1}}, {}, {}, {{0, num_blocks - 1}}, {}, {}, - }; + std::vector dimension_ranges = + DimVarsFromGPUGrid({threads_per_block, 1, 1, num_blocks, 1, 1}); auto affine_map = mlir::AffineMap::get(block_offsets.getNumDims(), block_offsets.getNumSymbols(), offsets, mlir_context); diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/transpose.cc b/third_party/xla/xla/service/gpu/fusions/legacy/transpose.cc index f91a0a4b6b120f..5940893415d132 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/transpose.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/transpose.cc @@ -35,6 +35,7 @@ limitations under the License. #include "llvm/IR/Value.h" #include "llvm/Support/AtomicOrdering.h" #include "mlir/IR/AffineMap.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/permutation_util.h" @@ -52,6 +53,7 @@ limitations under the License. #include "xla/service/llvm_ir/llvm_util.h" #include "xla/service/llvm_ir/loop_emitter.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/transpose.h b/third_party/xla/xla/service/gpu/fusions/legacy/transpose.h index 3366130c05546b..9596a6036698df 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/transpose.h +++ b/third_party/xla/xla/service/gpu/fusions/legacy/transpose.h @@ -31,6 +31,7 @@ limitations under the License. #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/llvm_ir/ir_array.h" +#include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc index bba3d721368e5b..a19dc94eef0dc1 100644 --- a/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc @@ -24,10 +24,12 @@ limitations under the License. #include "xla/service/gpu/fusions/fusions.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" +#include "xla/xla.pb.h" #include "tsl/platform/statusor.h" namespace xla { @@ -36,7 +38,7 @@ namespace { class TransposeTest : public HloTestBase { protected: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { auto opts = HloTestBase::GetDebugOptionsForTest(); opts.set_xla_gpu_mlir_emitter_level(0); return opts; @@ -77,46 +79,44 @@ TEST_F(TransposeTest, ThreadIndexing021) { mlir::MLIRContext mlir_context; EXPECT_THAT( - fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), + ToString(*fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - d3 floordiv 2, - d0 floordiv 32 + s1 * 4, - (d3 mod 2) * 32 + d0 mod 32 + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1, s2] -> ( + bl_x floordiv 2, + th_x floordiv 32 + s1 * 4, + (bl_x mod 2) * 32 + th_x mod 32 ), domain: - d0 in [0, 127], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 199], - d4 in [0, 0], - d5 in [0, 0], + th_x in [0, 127], + th_y in [0, 0], + th_z in [0, 0], + bl_x in [0, 199], + bl_y in [0, 0], + bl_z in [0, 0], s0 in [0, 0], s1 in [0, 7], - s2 in [0, 0], - is_simplified: true + s2 in [0, 0] )")); EXPECT_THAT( - fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), + ToString(*fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - d3 floordiv 2, - (d3 mod 2) * 32 + s1 * 4 + d0 floordiv 32, - d0 mod 32 + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1, s2] -> ( + bl_x floordiv 2, + (bl_x mod 2) * 32 + s1 * 4 + th_x floordiv 32, + th_x mod 32 ), domain: - d0 in [0, 127], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 199], - d4 in [0, 0], - d5 in [0, 0], + th_x in [0, 127], + th_y in [0, 0], + th_z in [0, 0], + bl_x in [0, 199], + bl_y in [0, 0], + bl_z in [0, 0], s0 in [0, 0], s1 in [0, 7], - s2 in [0, 0], - is_simplified: true + s2 in [0, 0] )")); } @@ -141,46 +141,44 @@ TEST_F(TransposeTest, ThreadIndexing201_SimplifiedTo021) { TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis)); mlir::MLIRContext mlir_context; EXPECT_THAT( - fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), + ToString(*fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1, s2] -> ( 0, - d3 * 32 + s1 * 4 + d0 floordiv 32, - d0 mod 32 + bl_x * 32 + s1 * 4 + th_x floordiv 32, + th_x mod 32 ), domain: - d0 in [0, 127], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 199], - d4 in [0, 0], - d5 in [0, 0], + th_x in [0, 127], + th_y in [0, 0], + th_z in [0, 0], + bl_x in [0, 199], + bl_y in [0, 0], + bl_z in [0, 0], s0 in [0, 0], s1 in [0, 7], - s2 in [0, 0], - is_simplified: true + s2 in [0, 0] )")); EXPECT_THAT( - fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), + ToString(*fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1, s2] -> ( 0, - d0 floordiv 32 + s1 * 4, - d3 * 32 + d0 mod 32 + th_x floordiv 32 + s1 * 4, + bl_x * 32 + th_x mod 32 ), domain: - d0 in [0, 127], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 199], - d4 in [0, 0], - d5 in [0, 0], + th_x in [0, 127], + th_y in [0, 0], + th_z in [0, 0], + bl_x in [0, 199], + bl_y in [0, 0], + bl_z in [0, 0], s0 in [0, 0], s1 in [0, 7], - s2 in [0, 0], - is_simplified: true + s2 in [0, 0] )")); } @@ -207,46 +205,44 @@ TEST_F(TransposeTest, ThreadIndexingPartialBlock) { TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis)); mlir::MLIRContext mlir_context; EXPECT_THAT( - fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), + ToString(*fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - d0 floordiv 32 + s0 * 4, - d3, - d0 mod 32 + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1, s2] -> ( + th_x floordiv 32 + s0 * 4, + bl_x, + th_x mod 32 ), domain: - d0 in [0, 127], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 1], - d4 in [0, 0], - d5 in [0, 0], + th_x in [0, 127], + th_y in [0, 0], + th_z in [0, 0], + bl_x in [0, 1], + bl_y in [0, 0], + bl_z in [0, 0], s0 in [0, 5], s1 in [0, 0], s2 in [0, 0], - d0 mod 32 in [0, 23], - is_simplified: true + th_x mod 32 in [0, 23] )")); EXPECT_THAT( - fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), + ToString(*fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - d0 floordiv 32 + s0 * 4, - d3, - d0 mod 32 + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1, s2] -> ( + th_x floordiv 32 + s0 * 4, + bl_x, + th_x mod 32 ), domain: - d0 in [0, 127], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 1], - d4 in [0, 0], - d5 in [0, 0], + th_x in [0, 127], + th_y in [0, 0], + th_z in [0, 0], + bl_x in [0, 1], + bl_y in [0, 0], + bl_z in [0, 0], s0 in [0, 5], s1 in [0, 0], s2 in [0, 0], - d0 mod 32 in [0, 23], - is_simplified: true + th_x mod 32 in [0, 23] )")); } @@ -274,8 +270,8 @@ TEST_F(TransposeTest, SameInputIndexingForRealHeroAndSideOutput) { mlir::MLIRContext mlir_context; EXPECT_THAT( - fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), - fusion->ComputeThreadIdToInputIndexing(1, 0, &mlir_context)->ToString()); + ToString(*fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)), + ToString(*fusion->ComputeThreadIdToInputIndexing(1, 0, &mlir_context))); } TEST_F(TransposeTest, ThreadIndexingSideOutput) { @@ -305,45 +301,43 @@ TEST_F(TransposeTest, ThreadIndexingSideOutput) { // Check if side output `%broadcast` get the correct input indexing, which // should corresponds to `%input1` with shape [100,32]. EXPECT_THAT( - fusion->ComputeThreadIdToInputIndexing(1, 0, &mlir_context)->ToString(), + ToString(*fusion->ComputeThreadIdToInputIndexing(1, 0, &mlir_context)), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - d3 floordiv 2, - d0 floordiv 32 + s1 * 4 + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1, s2] -> ( + bl_x floordiv 2, + th_x floordiv 32 + s1 * 4 ), domain: - d0 in [0, 127], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 199], - d4 in [0, 0], - d5 in [0, 0], + th_x in [0, 127], + th_y in [0, 0], + th_z in [0, 0], + bl_x in [0, 199], + bl_y in [0, 0], + bl_z in [0, 0], s0 in [0, 0], s1 in [0, 7], - s2 in [0, 0], - is_simplified: true + s2 in [0, 0] )")); EXPECT_THAT( - fusion->ComputeThreadIdToOutputIndexing(1, &mlir_context)->ToString(), + ToString(*fusion->ComputeThreadIdToOutputIndexing(1, &mlir_context)), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - d3 floordiv 2, - d0 floordiv 32 + s1 * 4, - (d3 mod 2) * 32 + d0 mod 32 + (th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1, s2] -> ( + bl_x floordiv 2, + th_x floordiv 32 + s1 * 4, + (bl_x mod 2) * 32 + th_x mod 32 ), domain: - d0 in [0, 127], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 199], - d4 in [0, 0], - d5 in [0, 0], + th_x in [0, 127], + th_y in [0, 0], + th_z in [0, 0], + bl_x in [0, 199], + bl_y in [0, 0], + bl_z in [0, 0], s0 in [0, 0], s1 in [0, 7], - s2 in [0, 0], - is_simplified: true + s2 in [0, 0] )")); } diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/BUILD b/third_party/xla/xla/service/gpu/fusions/mlir/BUILD index f6890b24806648..5f2d295d7eade6 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/mlir/BUILD @@ -20,13 +20,10 @@ cc_library( deps = [ ":type_util", "//xla:shape_util", - "//xla:union_find", + "//xla:util", "//xla/hlo/ir:hlo", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu/fusions:fusion_emitter", "//xla/service/gpu/model:indexing_analysis", "//xla/service/llvm_ir:llvm_util", - "//xla/translate/hlo_to_mhlo:hlo_utils", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -39,6 +36,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Support", ], ) @@ -68,7 +66,9 @@ cc_library( "//xla:comparison_util", "//xla:shape_util", "//xla:status_macros", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/translate/hlo_to_mhlo:hlo_utils", "//xla/mlir/utils:type_util", "//xla/mlir_hlo", "//xla/mlir_hlo:map_mhlo_to_scalar_op", @@ -79,7 +79,6 @@ cc_library( "//xla/service/gpu/model:indexing_analysis", "//xla/service/llvm_ir:llvm_util", "//xla/stream_executor:device_description", - "//xla/translate/hlo_to_mhlo:hlo_utils", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -115,8 +114,8 @@ xla_cc_test( ":elemental_hlo_to_mlir", "//xla:status_macros", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/mlir_hlo", - "//xla/service:hlo_parser", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/model:indexing_analysis", @@ -130,6 +129,7 @@ xla_cc_test( "@com_google_googletest//:gtest", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:DLTIDialect", "@llvm-project//mlir:FuncDialect", @@ -164,6 +164,7 @@ cc_library( "//xla/mlir_hlo:mhlo_passes", "//xla/service:buffer_assignment", "//xla/service:dump", + "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:ir_emitter_context", "//xla/service/gpu:kernel_arguments", "//xla/service/gpu:kernel_reuse_cache", @@ -240,6 +241,7 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", "@llvm-project//llvm:Support", @@ -273,8 +275,8 @@ cc_library( deps = [ "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/hlo/translate/hlo_to_mhlo:hlo_utils", "//xla/mlir/utils:type_util", - "//xla/translate/hlo_to_mhlo:hlo_utils", "@com_google_absl//absl/log:check", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/README.md b/third_party/xla/xla/service/gpu/fusions/mlir/README.md index d692bd279bce98..e04c771208ab16 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/README.md +++ b/third_party/xla/xla/service/gpu/fusions/mlir/README.md @@ -98,7 +98,7 @@ do not. ### Gather We only support canonical gathers as produced by [`gather_simplifier`]( -https://github.com/openxla/xla/blob/main/xla/service/gather_simplifier.h). +https://github.com/openxla/xla/blob/main/xla/hlo/transforms/simplifiers/gather_simplifier.h). ## Emission of functions diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.cc b/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.cc index 53d8678e953074..d34c05cc28e37c 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -32,9 +31,10 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" @@ -43,15 +43,16 @@ limitations under the License. #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Support/LLVM.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/fusions/mlir/type_util.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape.h" +#include "xla/shape_util.h" namespace xla { namespace gpu { @@ -81,49 +82,6 @@ EpilogueSpecification EpilogueSpecification::FromIdentityIndexing( return result; } -EpilogueSpecification EpilogueSpecification::FromOutputIndexing( - const HloFusionAnalysis& analysis, - const std::vector& heroes, - const std::vector& roots, - const KernelFusionInterface& fusion, mlir::MLIRContext* mlir_context) { - EpilogueSpecification result; - - absl::flat_hash_map - root_to_hero; - for (auto [root, hero] : - llvm::zip(analysis.fusion_roots(), analysis.fusion_heroes())) { - root_to_hero[&root.instruction()] = &hero.instruction(); - } - absl::flat_hash_map root_to_index; - for (auto [index, root] : llvm::enumerate(analysis.fusion_roots())) { - root_to_index[&root.instruction()] = root_to_index.size(); - } - - result.root_indexing.reserve(roots.size()); - for (auto* root : roots) { - auto indexing = fusion.ComputeThreadIdToOutputIndexing(root_to_index[root], - mlir_context); - if (result.index_ranges.empty()) { - result.index_ranges.reserve(indexing->GetDimensionCount() + - indexing->GetSymbolCount()); - for (const auto& dim : indexing->GetDimensionBounds()) { - result.index_ranges.push_back(dim.upper + 1); - } - for (const auto& sym : indexing->GetSymbolBounds()) { - result.index_ranges.push_back(sym.upper + 1); - } - } - auto* hero = root_to_hero[root]; - auto epilogue_indexing = ComputeEpilogueInputToOutputIndexing( - {*hero, &analysis.fusion()}, {*root, &analysis.fusion()}, mlir_context); - result.root_indexing.push_back( - ComposeIndexingMaps(*indexing, epilogue_indexing)); - } - result.heroes = heroes; - result.roots = roots; - return result; -} - std::string PartitionedComputation::Subgraph::ToString(int indentation) const { std::string indent(indentation, ' '); std::ostringstream ss; diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.h b/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.h index 5d5c78c4cd64aa..e3775029d20f26 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.h +++ b/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.h @@ -20,17 +20,16 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/types/span.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Builders.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/fusions/fusion_emitter.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/util.h" namespace xla { namespace gpu { @@ -41,13 +40,6 @@ struct EpilogueSpecification { static EpilogueSpecification FromIdentityIndexing( const HloInstruction* hero, const HloInstruction* root, mlir::MLIRContext* mlir_context); - // Creates an epilogue with the raw thread/block/symbol indices, as defined - // by the fusion's thread->output mapping. - static EpilogueSpecification FromOutputIndexing( - const HloFusionAnalysis& analysis, - const std::vector& heroes, - const std::vector& roots, - const KernelFusionInterface& fusion, mlir::MLIRContext* mlir_context); std::vector heroes; std::vector roots; diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc index d94c6d5a038461..fd0dc7c570c3e3 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc @@ -66,6 +66,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h" #include "xla/primitive_util.h" @@ -77,7 +78,7 @@ limitations under the License. #include "xla/service/gpu/model/indexing_map.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" namespace xla { @@ -611,7 +612,7 @@ SmallVector MapHloOp(mlir::Type result_type, Value result = mhlo::MhloOpToStdScalarOp::mapOpOfType( b.getLoc(), result_type, arg_types, typename MhloOp::Adaptor(args, std::forward(extra_args)...), - &b); + /*attributes=*/std::nullopt, &b); if (result.getType().isInteger(1)) { result = b.create(b.getI8Type(), result); } @@ -853,7 +854,7 @@ absl::StatusOr> EmitConvert( } auto out = mhlo::MhloOpToStdScalarOp::mapConvertOpToStdScalarOp( builder.getLoc(), result_type_with_sign, result_element_type, arg_types, - operands, &builder); + operands, /*attributes=*/std::nullopt, &builder); if (auto int_ty = mlir::dyn_cast(out.getType())) { auto in = operands[0]; if (auto float_ty = mlir::dyn_cast(in.getType())) { @@ -918,7 +919,7 @@ absl::StatusOr> EmitIota(const HloInstruction* instr, index = builder.create(index_type, index); return {{mhlo::MhloOpToStdScalarOp::mapConvertOpToStdScalarOp( builder.getLoc(), result_type_with_sign, result_element_type, - {index_type}, {index}, &builder)}}; + {index_type}, {index}, /*attributes=*/std::nullopt, &builder)}}; } absl::StatusOr> EmitCompare( @@ -933,7 +934,8 @@ absl::StatusOr> EmitCompare( auto result_types = llvm::to_vector(mlir::TypeRange{builder.getI1Type()}); auto i1 = mhlo::MhloOpToStdScalarOp::mapOpOfType( builder.getLoc(), result_types, arg_types, - mhlo::CompareOp::Adaptor(operands, nullptr, properties), &builder); + mhlo::CompareOp::Adaptor(operands, nullptr, properties), + /*attributes=*/std::nullopt, &builder); return {{builder.create(builder.getI8Type(), i1) .getResult()}}; } @@ -1626,9 +1628,9 @@ ValueRange EmitLoopNest(ImplicitLocOpBuilder& b, ValueRange dim_values, remainder.GetMutableSymbolBound(sym_index).lower = bound.upper; remainder.Simplify(); - VLOG(5) << "Peeled indexing map " << indexing_map.ToString() << "\n into " - << peeled_map.ToString() << "\nand remainder\n" - << remainder.ToString(); + VLOG(5) << "Peeled indexing map " << indexing_map << "\n into " + << peeled_map << "\nand remainder\n" + << remainder; return EmitLoopNestImpl(b, dim_values, first_results, remainder, create_body, vectorize); } diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h index 8f109aa3f452fe..c92f2445709e89 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h @@ -22,8 +22,10 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeRange.h" +#include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc index 33046ea54085aa..fff26b55878065 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "mlir/AsmParser/AsmParser.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -31,20 +32,18 @@ limitations under the License. #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" -#include "xla/service/hlo_parser.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/status_macros.h" -#include "xla/stream_executor/launch_dim.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" @@ -235,10 +234,10 @@ TEST_F(ElementalHloToMlirTest, ReduceWindow) { // CHECK: %[[INIT:.*]] = tensor.extract %[[ARG1]][] // CHECK: %[[RET:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C7]] // CHECK-SAME: step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) - // CHECK: %[[J0:.*]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 * 4), domain: d0 in [0, 2], is_simplified: true>(%[[Y]]) + // CHECK: %[[J0:.*]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 * 4), domain: d0 in [0, 2]">(%[[Y]]) // CHECK: %[[J1:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1 - 3), - // CHECK-SAME: d0 in [0, 7], d1 in [0, 6], is_simplified: true>(%[[Z]], %[[I]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1 - 3), + // CHECK-SAME: d0 in [0, 7], d1 in [0, 6]">(%[[Z]], %[[I]]) // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]] // CHECK-SAME: [%[[X]], %[[J0]], %[[J1]]] // CHECK: %[[UPD:.*]] = func.call @add_sum(%[[ACC]], @@ -285,8 +284,8 @@ TEST_F(ElementalHloToMlirTest, ReduceWindowWithRescaling) { // If symbol rescaling wasn't working we would have a // `d1 floordiv ` in the map: // CHECK: %[[K:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 * 2 + d1), - // CHECK-SAME: d0 in [0, 18], d1 in [0, 3], is_simplified: true>(%[[X]], %[[I]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1), + // CHECK-SAME: d0 in [0, 18], d1 in [0, 3]">(%[[X]], %[[I]]) // CHECK: tensor.extract %[[ARG0]][%[[K]], %[[Y]], %[[Z]]] )")); @@ -506,7 +505,7 @@ TEST_F(ElementalHloToMlirTest, Pad) { // CHECK-DAG: %[[C4:.*]] = arith.constant 4 // CHECK-DAG: %[[C7:.*]] = arith.constant 7 // CHECK: %[[CONSTRAINT_VAL:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7], is_simplified: true>(%[[X]]) + // CHECK-SAME: <"(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7]">(%[[X]]) // CHECK: %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]] // CHECK-DAG: %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]] // CHECK-DAG: %[[X_H:.*]] = arith.cmpi sle, %[[X]], %[[C7]] @@ -518,9 +517,9 @@ TEST_F(ElementalHloToMlirTest, Pad) { // CHECK: %[[FROM_INPUT:.*]] = arith.andi %[[X_AND_CONSTRAINT]], %[[Y_BOUNDS]] // CHECK: %[[RET:.*]] = scf.if %[[FROM_INPUT]] // CHECK: %[[IN0:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7], is_simplified: true>(%[[X]]) + // CHECK-SAME: <"(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7]">(%[[X]]) // CHECK: %[[IN1:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> (d0 - 4), domain: d0 in [4, 7], is_simplified: true>(%[[Y]]) + // CHECK-SAME: <"(d0) -> (d0 - 4), domain: d0 in [4, 7]">(%[[Y]]) // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[IN0]], %[[IN1]]] // CHECK: scf.yield %[[VAL]] // CHECK: } else { @@ -548,7 +547,7 @@ TEST_F(ElementalHloToMlirTest, PadUnsigned) { // CHECK-DAG: %[[C4:.*]] = arith.constant 4 // CHECK-DAG: %[[C7:.*]] = arith.constant 7 // CHECK: %[[CONSTRAINT_VAL:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7], is_simplified: true>(%[[X]]) + // CHECK-SAME: <"(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7]">(%[[X]]) // CHECK: %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]] // CHECK-DAG: %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]] // CHECK-DAG: %[[X_H:.*]] = arith.cmpi sle, %[[X]], %[[C7]] @@ -560,9 +559,9 @@ TEST_F(ElementalHloToMlirTest, PadUnsigned) { // CHECK: %[[FROM_INPUT:.*]] = arith.andi %[[X_AND_CONSTRAINT]], %[[Y_BOUNDS]] // CHECK: %[[RET:.*]] = scf.if %[[FROM_INPUT]] // CHECK: %[[IN0:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7], is_simplified: true>(%[[X]]) + // CHECK-SAME: <"(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7]">(%[[X]]) // CHECK: %[[IN1:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> (d0 - 4), domain: d0 in [4, 7], is_simplified: true>(%[[Y]]) + // CHECK-SAME: <"(d0) -> (d0 - 4), domain: d0 in [4, 7]">(%[[Y]]) // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[IN0]], %[[IN1]]] // CHECK: scf.yield %[[VAL]] // CHECK: } else { @@ -879,11 +878,11 @@ TEST_F(ElementalHloToMlirTest, ConvolutionSimple) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), - // CHECK-SAME: d0 in [0, 5], d1 in [0, 2], is_simplified: true>(%[[W]], %[[X]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), + // CHECK-SAME: d0 in [0, 5], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), - // CHECK-SAME: d0 in [0, 7], d1 in [0, 4], is_simplified: true>(%[[H]], %[[Y]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), + // CHECK-SAME: d0 in [0, 7], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -925,11 +924,11 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithWindowStrides) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 * 2 + d1), - // CHECK-SAME: d0 in [0, 2], d1 in [0, 2], is_simplified: true>(%[[W]], %[[X]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1), + // CHECK-SAME: d0 in [0, 2], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 * 2 + d1), - // CHECK-SAME: d0 in [0, 3], d1 in [0, 4], is_simplified: true>(%[[H]], %[[Y]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1), + // CHECK-SAME: d0 in [0, 3], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -972,21 +971,21 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithPadding) { // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { - // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), domain: d0 in [0, 7], d1 in [0, 2], is_simplified: true>(%[[W]], %[[X]]) + // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), domain: d0 in [0, 7], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK-DAG: %[[TXGE:.+]] = arith.cmpi sge, %[[TESTX]], %[[C1]] : index // CHECK-DAG: %[[TXLE:.+]] = arith.cmpi sle, %[[TESTX]], %[[C8]] : index // CHECK-DAG: %[[TX:.+]] = arith.andi %[[TXGE]], %[[TXLE]] : i1 - // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), domain: d0 in [0, 11], d1 in [0, 4], is_simplified: true>(%[[H]], %[[Y]]) + // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), domain: d0 in [0, 11], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[TYGE:.+]] = arith.cmpi sge, %[[TESTY]], %[[C2]] : index // CHECK-DAG: %[[TYLE:.+]] = arith.cmpi sle, %[[TESTY]], %[[C13]] : index // CHECK-DAG: %[[TY:.+]] = arith.andi %[[TYGE]], %[[TYLE]] : i1 // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1 - 1), - // CHECK-SAME: d0 in [0, 7], d1 in [0, 2], is_simplified: true>(%[[W]], %[[X]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1 - 1), + // CHECK-SAME: d0 in [0, 7], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1 - 2), - // CHECK-SAME: d0 in [0, 11], d1 in [0, 4], is_simplified: true>(%[[H]], %[[Y]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1 - 2), + // CHECK-SAME: d0 in [0, 11], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1026,17 +1025,17 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithLhsDilation) { // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { - // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> ((d0 + d1) mod 2), domain: d0 in [0, 12], d1 in [0, 2], is_simplified: true>(%[[W]], %[[X]]) + // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> ((d0 + d1) mod 2), domain: d0 in [0, 12], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK-DAG: %[[TX:.+]] = arith.cmpi eq, %[[TESTX]], %[[C0]] : index - // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> ((d0 + d1) mod 2), domain: d0 in [0, 18], d1 in [0, 4], is_simplified: true>(%[[H]], %[[Y]]) + // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0, d1) -> ((d0 + d1) mod 2), domain: d0 in [0, 18], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[TY:.+]] = arith.cmpi eq, %[[TESTY]], %[[C0]] : index // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> ((d0 + d1) floordiv 2), - // CHECK-SAME: d0 in [0, 12], d1 in [0, 2], is_simplified: true>(%[[W]], %[[X]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> ((d0 + d1) floordiv 2), + // CHECK-SAME: d0 in [0, 12], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> ((d0 + d1) floordiv 2), - // CHECK-SAME: d0 in [0, 18], d1 in [0, 4], is_simplified: true>(%[[H]], %[[Y]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> ((d0 + d1) floordiv 2), + // CHECK-SAME: d0 in [0, 18], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1078,11 +1077,11 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithRhsDilation) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d1 * 2 + d0), - // CHECK-SAME: d0 in [0, 3], d1 in [0, 2], is_simplified: true>(%[[W]], %[[X]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 2 + d0), + // CHECK-SAME: d0 in [0, 3], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d1 * 2 + d0), - // CHECK-SAME: d0 in [0, 3], d1 in [0, 4], is_simplified: true>(%[[H]], %[[Y]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 2 + d0), + // CHECK-SAME: d0 in [0, 3], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1124,14 +1123,14 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithFeatureGroupCount) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), - // CHECK-SAME: d0 in [0, 5], d1 in [0, 2], is_simplified: true>(%[[W]], %[[X]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), + // CHECK-SAME: d0 in [0, 5], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), - // CHECK-SAME: d0 in [0, 7], d1 in [0, 4], is_simplified: true>(%[[H]], %[[Y]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), + // CHECK-SAME: d0 in [0, 7], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK: %[[XX2:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> ((d0 floordiv 8) * 2 + d1), - // CHECK-SAME: d0 in [0, 15], d1 in [0, 1], is_simplified: true>(%[[O]], %[[I]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> ((d0 floordiv 8) * 2 + d1), + // CHECK-SAME: d0 in [0, 15], d1 in [0, 1]">(%[[O]], %[[I]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[XX2]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<2x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1175,11 +1174,11 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithBatchGroupCount) { // CHECK-NEXT: %[[R3:.+]] = scf.for %[[G:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A2]]) -> (f32) { // CHECK: %[[R4:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), - // CHECK-SAME: d0 in [0, 5], d1 in [0, 2], is_simplified: true>(%[[W]], %[[X]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), + // CHECK-SAME: d0 in [0, 5], d1 in [0, 2]">(%[[W]], %[[X]]) // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), - // CHECK-SAME: d0 in [0, 7], d1 in [0, 4], is_simplified: true>(%[[H]], %[[Y]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 + d1), + // CHECK-SAME: d0 in [0, 7], d1 in [0, 4]">(%[[H]], %[[Y]]) // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[G]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1645,8 +1644,8 @@ TEST_F(ElementalHloToMlirTest, MixedIndexingTuple) { // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}} // CHECK: %[[A:.*]] = tensor.extract %[[P0]][%[[X]], %[[Y]]] // CHECK: %[[IDX:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 * 10 + d1), - // CHECK-SAME: d0 in [0, 9], d1 in [0, 9], is_simplified: true>(%[[X]], %[[Y]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 10 + d1), + // CHECK-SAME: d0 in [0, 9], d1 in [0, 9]">(%[[X]], %[[Y]]) // CHECK: %[[B:.*]] = tensor.extract %[[P1]][%[[IDX]]] // CHECK: return %[[A]], %[[B]] )")); @@ -1669,8 +1668,8 @@ TEST_F(ElementalHloToMlirTest, NestedTuple) { // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}} // CHECK: %[[P0_V:.*]] = xla_gpu.pure_call @main_p0 // CHECK: %[[IDX:.*]] = - // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 * 10 + d1), - // CHECK-SAME: d0 in [0, 9], d1 in [0, 9], is_simplified: true>(%[[X]], %[[Y]]) + // CHECK-SAME: #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 10 + d1), + // CHECK-SAME: d0 in [0, 9], d1 in [0, 9]">(%[[X]], %[[Y]]) // CHECK: %[[P1_V:.*]] = xla_gpu.pure_call @main_p1 // CHECK-SAME: (%[[P0]], %[[P1]], %[[IDX]]) // CHECK: return %[[P0_V]], %[[P1_V]], %[[P1_V]], %[[P1_V]], %[[P0_V]] diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc index 45d66ebca0108f..e5f806f04c0384 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc @@ -92,11 +92,11 @@ limitations under the License. #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" #include "xla/service/gpu/fusions/mlir/type_util.h" #include "xla/service/gpu/fusions/transforms/passes.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/kernel_arguments.h" #include "xla/service/gpu/kernel_reuse_cache.h" #include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/runtime/kernel_thunk.h" #include "xla/service/gpu/target_util.h" #include "xla/service/llvm_ir/llvm_util.h" @@ -416,6 +416,50 @@ MlirFusionEmitterBase::CreateMLIRModule( return module; } +mlir_converter::EpilogueSpecification +MlirFusionEmitterBase::GetEpilogueForOutputIndexing( + const HloFusionAnalysis& analysis, + const std::vector& heroes, + const std::vector& roots, + mlir::MLIRContext* mlir_context) const { + mlir_converter::EpilogueSpecification result; + + absl::flat_hash_map + root_to_hero; + for (auto [root, hero] : + llvm::zip(analysis.fusion_roots(), analysis.fusion_heroes())) { + root_to_hero[&root.instruction()] = &hero.instruction(); + } + absl::flat_hash_map root_to_index; + for (auto [index, root] : llvm::enumerate(analysis.fusion_roots())) { + root_to_index[&root.instruction()] = root_to_index.size(); + } + + result.root_indexing.reserve(roots.size()); + for (auto* root : roots) { + auto indexing = + ComputeThreadIdToOutputIndexing(root_to_index[root], mlir_context); + if (result.index_ranges.empty()) { + result.index_ranges.reserve(indexing->GetDimensionCount() + + indexing->GetSymbolCount()); + for (const auto& dim : indexing->GetDimensionBounds()) { + result.index_ranges.push_back(dim.upper + 1); + } + for (const auto& sym : indexing->GetSymbolBounds()) { + result.index_ranges.push_back(sym.upper + 1); + } + } + auto* hero = root_to_hero[root]; + auto epilogue_indexing = ComputeEpilogueInputToOutputIndexing( + {*hero, &analysis.fusion()}, {*root, &analysis.fusion()}, mlir_context); + result.root_indexing.push_back( + ComposeIndexingMaps(*indexing, epilogue_indexing)); + } + result.heroes = heroes; + result.roots = roots; + return result; +} + absl::Status MlirFusionEmitterBase::EmitMlir( mlir::ModuleOp module, FuncOp entry_function, const HloFusionInstruction& fusion) const { @@ -542,6 +586,7 @@ void AddXlaGpuOpsOptimizationPasses(mlir::OpPassManager& pm) { void AddLoopTransformationPasses(mlir::OpPassManager& pm) { pm.addNestedPass(CreateLowerXlaGpuToScfPass()); + pm.addNestedPass(CreateFuseLoopsPass()); pm.addPass(mlir::createInlinerPass({}, [&](mlir::OpPassManager& pm) { // CSE after inlining because inlining can introduce duplicates. pm.addPass(mlir::createCSEPass()); diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h index 68ce87f4374aab..9f410694434a91 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h @@ -40,6 +40,7 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/stream_executor/device_description.h" @@ -79,6 +80,14 @@ class MlirFusionEmitterBase : public KernelFusionInterface { return {}; } + // Creates an epilogue with the raw thread/block/symbol indices, as defined + // by the fusion's thread->output mapping. + mlir_converter::EpilogueSpecification GetEpilogueForOutputIndexing( + const HloFusionAnalysis& analysis, + const std::vector& heroes, + const std::vector& roots, + mlir::MLIRContext* mlir_context) const; + virtual absl::Status EmitEntryFunction( const mlir_converter::PartitionedComputations& computations, const mlir_converter::CallTargetProvider& call_targets, diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc index f896d5e6b37475..35e14a98200c66 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "llvm/IR/LLVMContext.h" #include "llvm/Support/raw_ostream.h" diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/type_util.cc b/third_party/xla/xla/service/gpu/fusions/mlir/type_util.cc index 81e568f956c3b2..76d4b284ebc331 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/type_util.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/type_util.cc @@ -20,11 +20,11 @@ limitations under the License. #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Types.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/layout_util.h" #include "xla/mlir/utils/type_util.h" #include "xla/primitive_util.h" #include "xla/shape.h" -#include "xla/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_base.cc b/third_party/xla/xla/service/gpu/fusions/reduction_base.cc index b7f62e3c7d1d54..dfac23affc34e3 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_base.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_base.cc @@ -31,7 +31,6 @@ limitations under the License. #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "mlir/IR/AffineExpr.h" -#include "mlir/IR/MLIRContext.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/primitive_util.h" @@ -40,14 +39,10 @@ limitations under the License. #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/reduction_utils.h" #include "xla/shape.h" -#include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/launch_dim.h" #include "xla/union_find.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_base.h b/third_party/xla/xla/service/gpu/fusions/reduction_base.h index ad99e6f40140c9..0239838c2b8bb3 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_base.h +++ b/third_party/xla/xla/service/gpu/fusions/reduction_base.h @@ -22,6 +22,7 @@ limitations under the License. #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/reduction_utils.h" +#include "xla/util.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc index b2db8fa5cd1730..b59cbb3f8ca02b 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include #include -#include +#include #include #include #include @@ -27,6 +27,8 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" @@ -61,7 +63,9 @@ limitations under the License. #include "xla/service/gpu/reduction_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" namespace xla { namespace gpu { @@ -75,7 +79,6 @@ using mlir::Value; using mlir::ValueRange; using mlir_converter::PartitionedComputations; -constexpr int kRowMajorReduced = ReductionDimensions::kRowMajorReducedDimension; constexpr int kRowKept = ReductionDimensions::kRowKeptDimension; constexpr int kRowMinorReduced = ReductionDimensions::kRowMinorReducedDimension; constexpr int kColMajorKept = ReductionDimensions::kColMajorKeptDimension; @@ -359,7 +362,6 @@ MlirReductionFusion::MlirReductionFusion(const HloFusionAnalysis& analysis) : analysis_(analysis) { auto* hero_reduction = analysis.FindHeroReduction(); CHECK_NE(hero_reduction, nullptr); - Shape input_shape = hero_reduction->operand(0)->shape(); reduction_dimensions_ = GetReductionKindAndContiguousComponents(*hero_reduction); VLOG(10) << reduction_dimensions_; @@ -400,12 +402,11 @@ IndexingMap MlirReductionFusion::GetIndexingMap( absl::Span symbol_sizes) const { auto* ctx = results.front().getContext(); auto num_groups = static_cast(reduction_heroes_.size()); - return IndexingMap{ - AffineMap::get(6, symbol_sizes.size(), results, ctx), - DimVarsFromTensorSizes( - {Product(num_threads_), 1, 1, Product(num_blocks_), num_groups, 1}), - RangeVarsFromTensorSizes(symbol_sizes), - /*rt_vars=*/{}}; + return IndexingMap{AffineMap::get(6, symbol_sizes.size(), results, ctx), + DimVarsFromGPUGrid({Product(num_threads_), 1, 1, + Product(num_blocks_), num_groups, 1}), + RangeVarsFromTensorSizes(symbol_sizes), + /*rt_vars=*/{}}; } IndexingMap MlirReductionFusion::GetThreadIndexingMap( @@ -414,10 +415,13 @@ IndexingMap MlirReductionFusion::GetThreadIndexingMap( absl::Span symbol_sizes) const { auto affine_map = AffineMap::get(1, symbol_sizes.size(), results, results.front().getContext()); - return IndexingMap{affine_map, - DimVarsFromTensorSizes({Product(num_threads_)}), - RangeVarsFromTensorSizes(symbol_sizes), - /*rt_vars=*/{}, constraints}; + return IndexingMap{ + affine_map, + {IndexingMap::Variable{0, Product(num_threads_) - 1, + ToVariableName(VariableKind::kThreadX)}}, + RangeVarsFromTensorSizes(symbol_sizes), + /*rt_vars=*/{}, + constraints}; } LaunchDimensions MlirReductionFusion::launch_dimensions() const { @@ -436,8 +440,7 @@ MlirReductionFusion::GetEpilogues(const HloFusionInstruction& fusion, for (const auto& [heroes, roots] : llvm::zip(reduction_heroes_, reduction_roots_)) { epilogues.push_back( - mlir_converter::EpilogueSpecification::FromOutputIndexing( - analysis_, heroes, roots, *this, mlir_context)); + GetEpilogueForOutputIndexing(analysis_, heroes, roots, mlir_context)); } // Add empty epilogues for the side outputs. This ensures their roots don't // get "fused" into the tuple function. @@ -770,32 +773,11 @@ llvm::SmallVector MlirSmallColumnReductionFusion::EmitReduction( shared_rows_ / 2); } -std::unique_ptr CreateMlirReductionFusion( - const HloFusionAnalysis& analysis) { - auto* hero_reduction = analysis.FindHeroReduction(); - CHECK_NE(hero_reduction, nullptr); - ReductionDimensions reduction_dimensions = - GetReductionKindAndContiguousComponents(*hero_reduction); - if (reduction_dimensions.is_row_reduction) { - if (RowReductionGetRowsPerWarp( - reduction_dimensions.dimensions[kRowMinorReduced]) > 1) { - return std::make_unique(analysis); - } - return std::make_unique(analysis); - } - - if (WarpSize() % reduction_dimensions.dimensions[kColMinorKept] == 0) { - return std::make_unique(analysis); - } - return std::make_unique(analysis); -} - MlirRowReductionFusion::MlirRowReductionFusion( const HloFusionAnalysis& analysis) : MlirReductionFusion(analysis) { CHECK(reduction_dimensions_.is_row_reduction); Vector3 shape = reduction_dimensions_.dimensions; - CHECK_EQ(RowReductionGetRowsPerWarp(shape[kRowMinorReduced]), 1); constexpr int64_t kMinorReducedElementsPerThread = 16; int64_t num_threads_kept = 1; @@ -931,33 +913,28 @@ llvm::SmallVector MlirRowReductionFusion::EmitReduction( } MlirMultiRowReductionFusion::MlirMultiRowReductionFusion( - const HloFusionAnalysis& analysis) + const HloFusionAnalysis& analysis, int vector_size) : MlirReductionFusion(analysis) { CHECK(reduction_dimensions_.is_row_reduction); Vector3 shape = reduction_dimensions_.dimensions; - int64_t rows_per_warp = RowReductionGetRowsPerWarp(shape[kRowMinorReduced]); input_shape_ = {shape[0], shape[1], shape[2]}; - CHECK_GT(rows_per_warp, 1); - - auto compute_block_size = [&](int vector_size) { - int64_t num_threads_reduced = shape[kRowMinorReduced] / vector_size; - - constexpr int64_t kThreadsPerBlockTarget = 256; - int64_t kept_size = reduction_dimensions_.dimensions[kRowKept]; - int64_t num_threads_kept = 1; - if (kept_size * num_threads_reduced <= kThreadsPerBlockTarget) { - num_threads_kept = kept_size; - } else { - num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced; - } - num_threads_ = {num_threads_kept, num_threads_reduced}; - tile_sizes_per_thread_ = {shape[0], vector_size}; - num_blocks_ = {CeilOfRatio(input_shape_[kRowKept], num_threads_kept)}; - }; + num_threads_ = GetNumThreads(reduction_dimensions_, vector_size); + num_blocks_ = {GetNumBlocks(reduction_dimensions_, num_threads_)}; + tile_sizes_per_thread_ = {shape[0], vector_size}; +} - // Compute the launch grid without vectorization. We use the results to - // compute the vectorized launch grid. - compute_block_size(1); +std::unique_ptr MlirMultiRowReductionFusion::TryCreate( + const HloFusionAnalysis& analysis) { + auto* hero_reduction = analysis.FindHeroReduction(); + CHECK_NE(hero_reduction, nullptr); + auto reduction_dimensions = + GetReductionKindAndContiguousComponents(*hero_reduction); + auto shape = reduction_dimensions.dimensions; + // This emitter only supports reductions where the reduced dimension is a + // power of 2. + if (shape[kRowMinorReduced] & (shape[kRowMinorReduced] - 1)) { + return nullptr; + } // Normally, we only consider input types for vectorization. However, in // multi-row reductions, the input:output ratio is much higher, so we consider @@ -965,24 +942,80 @@ MlirMultiRowReductionFusion::MlirMultiRowReductionFusion( int smallest_input_or_output_bits = std::min(analysis.input_output_info().smallest_input_dtype_bits, analysis.input_output_info().smallest_output_dtype_bits); + int largest_input_or_output_bits = + std::max(analysis.input_output_info().smallest_input_dtype_bits, + analysis.input_output_info().smallest_output_dtype_bits); + // Handle the case when there are no inputs. + if (largest_input_or_output_bits == std::numeric_limits::max()) { + largest_input_or_output_bits = + analysis.input_output_info().smallest_output_dtype_bits; + } - // This vector size is always valid: we know that the reduced dimension is a - // power of 2, since otherwise RowReductionGetRowsPerWarp would have - // returned 1. // Our codegen can't currently deal with vectorization across rows, so we // limit the vector size to the size of the row. Note that this emitter // essentially reverts to the loop emitter in this case, except for side // outputs. - int vector_size = std::min(static_cast(input_shape_[kRowMinorReduced]), - 32 / smallest_input_or_output_bits); - - // We target 8 warps per block, which means there could be up to 8 blocks per - // SM, but we have no good way of knowing. In practice, enabling vectorization - // for decently sized reductions at least does not hurt. - if (num_blocks_.front() > analysis.device_info().core_count() && - vector_size > 1) { - compute_block_size(vector_size); + int vector_size = std::min(static_cast(shape[kRowMinorReduced]), + 64 / smallest_input_or_output_bits); + + // Very large vector sizes for f32 can be detrimental, so we limit the vector + // size to 16 bytes if we have some >= 32 bit inputs or outputs. This is still + // a bit on the high side, but remember that we also have very small inputs + // or outputs. + if (largest_input_or_output_bits >= 32) { + vector_size = std::min(128 / largest_input_or_output_bits, vector_size); + } + + // The reduced dimension must fit into a single warp. + if (shape[kRowMinorReduced] > WarpSize() * vector_size) { + return nullptr; + } + + // At the very least, we want to have work for every SM. + // TODO(jreiffers): This limit is probably too low: if we have as many blocks + // as SMs, we'll only run about 8 warps per SM, so occupancy will be very low. + // Further measurements are needed to refine this heuristic. + int64_t min_desired_blocks = analysis.device_info().core_count(); + while (vector_size > 1 && + GetNumBlocks(reduction_dimensions, + GetNumThreads(reduction_dimensions, vector_size)) < + min_desired_blocks) { + vector_size /= 2; + } + + // Check again that the reduced dimension fits after potentially reducing the + // vector size. + if (shape[kRowMinorReduced] > WarpSize() * vector_size) { + return nullptr; } + + return std::make_unique(analysis, vector_size); +} + +absl::InlinedVector MlirMultiRowReductionFusion::GetNumThreads( + const ReductionDimensions& reduction_dimensions, int vector_size) { + int64_t num_threads_reduced = + reduction_dimensions.dimensions[kRowMinorReduced] / vector_size; + + constexpr int64_t kThreadsPerBlockTarget = 256; + int64_t kept_size = reduction_dimensions.dimensions[kRowKept]; + int64_t num_threads_kept = 1; + if (kept_size * num_threads_reduced <= kThreadsPerBlockTarget) { + num_threads_kept = kept_size; + } else { + num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced; + } + return {num_threads_kept, num_threads_reduced}; +} + +int64_t MlirMultiRowReductionFusion::GetNumBlocks( + const ReductionDimensions& reduction_dimensions, + const absl::InlinedVector& num_threads) { + CHECK_EQ(num_threads.size(), 2) + << "Expected num_threads to contain the number of threads in the {kept, " + "reduced} dimensions."; + return CeilOfRatio(reduction_dimensions.dimensions[kRowKept], + num_threads.front()); } IndexingMap MlirMultiRowReductionFusion::ComputeReductionInputIndexing( @@ -1013,8 +1046,7 @@ IndexingMap MlirMultiRowReductionFusion::ComputeReductionOutputIndexing( : mlir::getAffineDimExpr(3, ctx); IndexingMap projected_index = GetIndexingMap(block_id * num_threads_[0] + thread_id[0]); - projected_index.AddConstraint(thread_id[1] % (WarpSize() / GetRowsPerWarp()), - {0, 0}); + projected_index.AddConstraint(thread_id[1] % num_threads_[1], {0, 0}); // We don't need a constraint on the loop dimensions, because they are removed // by GetIndexingMap (since they don't show up in the output index // computation). @@ -1034,10 +1066,30 @@ llvm::SmallVector MlirMultiRowReductionFusion::EmitReduction( auto per_thread = state.EmitPerThreadElements(group_id, inits, state.FusionOutputs()); auto reduced = state.ShuffleReduce(reductions, per_thread.reduction_scalars, - WarpSize() / 2 / GetRowsPerWarp()); + num_threads_[1] / 2); return EvaluateEpilogue(reduced, std::move(per_thread.outputs), state, group_id, /*symbol_values=*/{}); } +std::unique_ptr CreateMlirReductionFusion( + const HloFusionAnalysis& analysis) { + auto* hero_reduction = analysis.FindHeroReduction(); + CHECK_NE(hero_reduction, nullptr); + ReductionDimensions reduction_dimensions = + GetReductionKindAndContiguousComponents(*hero_reduction); + if (reduction_dimensions.is_row_reduction) { + auto multi_row_emitter = MlirMultiRowReductionFusion::TryCreate(analysis); + if (multi_row_emitter != nullptr) { + return multi_row_emitter; + } + return std::make_unique(analysis); + } + + if (WarpSize() % reduction_dimensions.dimensions[kColMinorKept] == 0) { + return std::make_unique(analysis); + } + return std::make_unique(analysis); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h index 838729254070ac..b4deb0ee862b9e 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h @@ -16,10 +16,12 @@ limitations under the License. #define XLA_SERVICE_GPU_FUSIONS_REDUCTION_MLIR_H_ #include +#include #include #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "absl/types/span.h" @@ -168,9 +170,23 @@ class MlirRowReductionFusion : public MlirReductionFusion { class MlirMultiRowReductionFusion : public MlirReductionFusion { public: - explicit MlirMultiRowReductionFusion(const HloFusionAnalysis& analysis); + MlirMultiRowReductionFusion(const HloFusionAnalysis& analysis, + int vector_size); + + // Attempts to create a multi-row reduction emitter for the given analysis. + // Returns nullptr if the fusion is not supported. + static std::unique_ptr TryCreate( + const HloFusionAnalysis& analysis); protected: + // Returns the number of {kept, reduced} threads for the given reduction and + // vector size. + static absl::InlinedVector GetNumThreads( + const ReductionDimensions& reduction_dimensions, int vector_size); + static int64_t GetNumBlocks( + const ReductionDimensions& reduction_dimensions, + const absl::InlinedVector& num_threads); + int GetRowsPerWarp() const; llvm::SmallVector EmitReduction( int group_id, EmitterState& state) const override; diff --git a/third_party/xla/xla/service/gpu/fusions/tests/concatenate/concat_1d.hlo b/third_party/xla/xla/service/gpu/fusions/tests/concatenate/concat_1d.hlo index f99ff371ef38d1..875bb871d287e3 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/concatenate/concat_1d.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/concatenate/concat_1d.hlo @@ -8,10 +8,10 @@ fusion { param2 = f32[300] parameter(2) ROOT concat = f32[900] concatenate(param0, param1, param2), dimensions={0} } -// CHECK-DAG: #[[MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 128 + d0) -// CHECK-DAG: #[[LOOPMAP_1:.*]] = #xla_gpu.indexing_map<(d0, d1, d2, d3, d4, d5)[s0, s1] -> (d3 * 128 + d0) -// CHECK-DAG: #[[LOOPMAP_2:.*]] = #xla_gpu.indexing_map<(d0, d1, d2, d3, d4, d5)[s0, s1] -> (d3 * 128 + d0 + 200) -// CHECK-DAG: #[[LOOPMAP_3:.*]] = #xla_gpu.indexing_map<(d0, d1, d2, d3, d4, d5)[s0, s1] -> (d3 * 128 + d0 + 600) +// CHECK-DAG: #[[MAP:.*]] = #xla_gpu.indexing_map<"(th_x, bl_x) -> (bl_x * 128 + th_x) +// CHECK-DAG: #[[LOOPMAP_1:.*]] = #xla_gpu.indexing_map<"(th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1] -> (bl_x * 128 + th_x) +// CHECK-DAG: #[[LOOPMAP_2:.*]] = #xla_gpu.indexing_map<"(th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1] -> (bl_x * 128 + th_x + 200) +// CHECK-DAG: #[[LOOPMAP_3:.*]] = #xla_gpu.indexing_map<"(th_x, th_y, th_z, bl_x, bl_y, bl_z)[s0, s1] -> (bl_x * 128 + th_x + 600) // CHECK: func.func @main // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9]*]]: {{[^,]*}}, diff --git a/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/int4.hlo b/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/int4.hlo new file mode 100644 index 00000000000000..92914a55dc3c42 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/int4.hlo @@ -0,0 +1,14 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \ +// RUN: -xla-gpu-test-transform-loops | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=dus:1 + +f { + input = s4[100,9] parameter(0) + slice = s4[100,6] parameter(1) + c0 = s32[] constant(0) + ROOT dus = s4[100,9] dynamic-update-slice(input, slice, c0, c0) +} + +// CHECK: vector.transfer_read +// CHECK: tensor.insert +// CHECK: tensor.insert diff --git a/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/vectorize_x1_too_small.hlo b/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/vectorize_x1_too_small.hlo new file mode 100644 index 00000000000000..b9dbddaa26bb4d --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/vectorize_x1_too_small.hlo @@ -0,0 +1,14 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \ +// RUN: -xla-gpu-test-transform-loops | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=dus:1 + +dus { + %input = f32[40,40,300] parameter(0) + %update = f32[1,1,40] parameter(1) + %idx = s32[] parameter(2) + %zero = s32[] constant(0) + ROOT dus = f32[40,40,300] dynamic-update-slice(%input, %update, %idx, %zero, %zero) +} + +// CHECK-NOT: vector.transfer_read {{.*}} vector<4xf32> +// CHECK-NOT: vector.transfer_write {{.*}} vector<4xf32> diff --git a/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/vectorize_x4.hlo b/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/vectorize_x4.hlo new file mode 100644 index 00000000000000..775dd248d0bb7b --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/dynamic_update_slice/vectorize_x4.hlo @@ -0,0 +1,14 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \ +// RUN: -xla-gpu-test-transform-loops | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=dus:1 + +dus { + %input = f32[40,40,300] parameter(0) + %update = f32[20,40,300] parameter(1) + %idx = s32[] parameter(2) + %zero = s32[] constant(0) + ROOT dus = f32[40,40,300] dynamic-update-slice(%input, %update, %idx, %zero, %zero) +} + +// CHECK: vector.transfer_read {{.*}} vector<4xf32> +// CHECK: vector.transfer_write {{.*}} vector<4xf32> diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_constant_s4.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_constant_s4.hlo new file mode 100644 index 00000000000000..5719809d3a327b --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_constant_s4.hlo @@ -0,0 +1,13 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt --xla-gpu-test-optimize \ +// RUN: --inline="default-pipeline='cse'" | FileCheck %s +// RUN: test_correctness %s --bijection_outputs=broadcast + +bcast { + x = s4[] constant(-2) + ROOT broadcast = s4[3]{0} broadcast(x), dimensions={} +} +// CHECK: func.func @main(%[[ARG0:.*]]: tensor<3xi4> +// CHECK: xla_gpu.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]]) in +// CHECK-SAME: iter_args(%[[ITER:.*]] = %[[ARG0]]) +// CHECK: %[[CST:.*]] = arith.constant -2 +// CHECK: %[[INSERTED:.*]] = tensor.insert %[[CST]] into %[[ITER]][%[[RA]]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_non_scalar_constant_s4.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_non_scalar_constant_s4.hlo new file mode 100644 index 00000000000000..782b2e17753f0d --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_non_scalar_constant_s4.hlo @@ -0,0 +1,19 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt --xla-gpu-test-optimize \ +// RUN: --inline="default-pipeline='cse'" | FileCheck %s +// RUN: test_correctness %s --bijection_outputs=broadcast + +bcast { + x = s4[3]{0} constant({-2, -3, -4}) + ROOT broadcast = s4[3,31]{1,0} broadcast(x), dimensions={0} +} + +ENTRY main { + ROOT res = s4[3,31]{1,0} fusion(), kind=kLoop, calls=bcast +} + +// CHECK: func.func @main(%[[ARG0:.*]]: tensor<3x31xi4> +// CHECK: xla_gpu.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]], %[[RB:.*]]) in +// CHECK-SAME: iter_args(%[[ITER:.*]] = %[[ARG0]]) +// CHECK: %[[CST:.*]] = arith.constant dense<[-2, -3, -4]> +// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[CST]][%[[RA]]] +// CHECK: %[[INSERTED:.*]] = tensor.insert %[[EXTRACTED]] into %[[ITER]][%[[RA]], %[[RB]]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_heterogeneous.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_heterogeneous.hlo index 3b5e454584137a..4f93eacbfab93d 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_heterogeneous.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_heterogeneous.hlo @@ -12,8 +12,8 @@ fusion { ROOT tuple = (f64[8], f64[2,4]) tuple(minimum, bc) } -// CHECK: #[[MAJOR:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 4), -// CHECK: #[[MINOR:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 mod 4), +// CHECK: #[[MAJOR:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 4), +// CHECK: #[[MINOR:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 mod 4), // CHECK: xla_gpu.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]]) in // CHECK-DAG: %[[MAJOR_IDX:.*]] = xla_gpu.apply_indexing #[[MAJOR]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/f16_v4.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/f16_v4.hlo new file mode 100644 index 00000000000000..d2e3928bdfd564 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/f16_v4.hlo @@ -0,0 +1,22 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \ +// RUN: -xla-gpu-test-transform-loops | FileCheck %s + +// The reference implementation reduces in f64, so we need a larger tolerance. +// RUN: test_correctness %s --bijection_inputs=reduce:0 \ +// RUN: --bijection_outputs=reduce --abs_error_bound=0.005 --rel_error_bound=0.005 + +add { + lhs = f16[] parameter(0) + rhs = f16[] parameter(1) + ROOT add = f16[] add(lhs, rhs) +} + +fusion { + param_0 = f16[2048,64] parameter(0) + c = f16[] constant(0) + ROOT reduce = f16[2048] reduce(param_0, c), dimensions={1}, to_apply=add +} + +// If unvectorized, this would be a regular row reduction. However, since we can +// vectorize to size four, we can emit this as a multi-row reduction. +// CHECK: vector.transfer_read {{.*}} vector<4xf16> diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/f32_x8_no_inputs.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/f32_x8_no_inputs.hlo new file mode 100644 index 00000000000000..6e8220723cd512 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/f32_x8_no_inputs.hlo @@ -0,0 +1,21 @@ +// RUN: fusion_to_mlir %s | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +fusion { + one = f32[] constant(1) + bc = f32[1024,4] broadcast(one), dimensions={} + c = f32[] constant(0) + ROOT reduce = f32[1024] reduce(bc, c), dimensions={1}, to_apply=add +} + +// Multi-row reductions do not use shared memory. +// CHECK-NOT: allocate_shared +// There should be 8 elements per warp. +// CHECK: shuffle_reduce(%{{.*}}) to 2 +// CHECK-NOT: allocate_shared diff --git a/third_party/xla/xla/service/gpu/fusions/tests/scatter/unique_indices.hlo b/third_party/xla/xla/service/gpu/fusions/tests/scatter/unique_indices.hlo index a0663dd88308fb..8abb0d548d1c06 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/scatter/unique_indices.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/scatter/unique_indices.hlo @@ -24,7 +24,7 @@ scatter { unique_indices=true, to_apply=add } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 2) +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(th_x) -> (th_x floordiv 2) // CHECK-LABEL: func.func @main( // CHECK-SAME: %[[OPERAND:[a-zA-Z0-9]*]]: tensor<10x5xf32> @@ -60,4 +60,4 @@ scatter { // CHECK: %[[COMBINED:.*]] = arith.addf %[[CURRENT]], %[[UPD_ELEM]] // CHECK: %[[UPDATED:.*]] = tensor.insert %[[COMBINED]] // CHECK-SAME: into %{{[a-z0-9]+}}[%{{.*}}, %[[RC]]] : tensor<10x5xf32> -// CHECK: xla_gpu.yield %[[UPDATED]] : tensor<10x5xf32> \ No newline at end of file +// CHECK: xla_gpu.yield %[[UPDATED]] : tensor<10x5xf32> diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/epilogue.hlo b/third_party/xla/xla/service/gpu/fusions/tests/transpose/epilogue.hlo index 1ca70362596696..25695c8212f7d0 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/transpose/epilogue.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/epilogue.hlo @@ -11,10 +11,9 @@ fusion { // CHECK-SAME: }, %[[OUT:.*]]: tensor<20x170x160xf32> // CHECK: %[[SHMEM:.*]] = xla_gpu.allocate_shared : tensor<1x32x33xf32> -// CHECK: %[[SHMEM_WITH_VALS:.*]] = xla_gpu.loop -// CHECK-SAME: iter_args(%[[SHMEM_:.*]] = %[[SHMEM]]) -// CHECK: %[[EXP:.*]] = xla_gpu.pure_call @fusion_p0 -// CHECK: tensor.insert %[[EXP]] into %[[SHMEM_]] +// CHECK: %[[MATERIALIZED:.*]] = xla_gpu.materialize @fusion_p0 +// CHECK: %[[SHMEM_WITH_VALS:.*]] = xla_gpu.insert %[[MATERIALIZED]] +// CHECK-SAME: into %[[SHMEM]] at #indexing_map // CHECK: %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_021.hlo b/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_021.hlo index 60e5cd404e1504..7c2d63c78a47ef 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_021.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_021.hlo @@ -9,13 +9,14 @@ fusion { ROOT %abs = f32[20,170,160] abs(%transpose) } // CHECK-LABEL: func.func @main( +// CHECK-SAME: %[[INPUT:.*]]: tensor<20x160x170xf32> { // CHECK-SAME: }, %[[OUT:.*]]: tensor<20x170x160xf32> // CHECK: %[[SHMEM:.*]] = xla_gpu.allocate_shared : tensor<1x32x33xf32> -// CHECK: %[[SHMEM_WITH_VALS:.*]] = xla_gpu.loop -// CHECK-SAME: iter_args(%[[SHMEM_:.*]] = %[[SHMEM]]) -// CHECK: %[[EXP:.*]] = xla_gpu.pure_call @fusion_exp -// CHECK: tensor.insert %[[EXP]] into %[[SHMEM_]] +// CHECK: %[[MATERIALIZED:.*]] = xla_gpu.materialize +// CHECK-SAME: @fusion_exp(%[[INPUT]]) at #indexing_map +// CHECK: %[[SHMEM_WITH_VALS:.*]] = xla_gpu.insert %[[MATERIALIZED]] +// CHECK-SAME: into %[[SHMEM]] at #indexing_map // CHECK: %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_102.hlo b/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_102.hlo index 55c2976d32b341..2fc3855efe4c11 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_102.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_102.hlo @@ -10,10 +10,9 @@ fusion { // CHECK-SAME: }, %[[OUT:.*]]: tensor<170x160x3xi8> // CHECK: %[[SHMEM:.*]] = xla_gpu.allocate_shared : tensor<32x33x3xi8> -// CHECK: %[[SHMEM_WITH_VALS:.*]] = xla_gpu.loop -// CHECK-SAME: iter_args(%[[SHMEM_:.*]] = %[[SHMEM]]) -// CHECK: %[[P0:.*]] = xla_gpu.pure_call @fusion_p0 -// CHECK: tensor.insert %[[P0]] into %[[SHMEM_]] +// CHECK: %[[MATERIALIZED:.*]] = xla_gpu.materialize @fusion_p0 +// CHECK: %[[SHMEM_WITH_VALS:.*]] = xla_gpu.insert %[[MATERIALIZED]] +// CHECK-SAME: into %[[SHMEM]] at #indexing_map // CHECK: %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_210.hlo b/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_210.hlo index 0dd4a27547514f..97e23f171b713a 100644 --- a/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_210.hlo +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/fused_transpose_210.hlo @@ -12,10 +12,9 @@ fusion { // CHECK-SAME: }, %[[OUT:.*]]: tensor<170x160x20xf32> // // CHECK: %[[SHMEM:.*]] = xla_gpu.allocate_shared : tensor<32x1x33xf32> -// CHECK: %[[SHMEM_WITH_VALS:.*]] = xla_gpu.loop -// CHECK-SAME: iter_args(%[[SHMEM_:.*]] = %[[SHMEM]]) -// CHECK: %[[EXP:.*]] = xla_gpu.pure_call @fusion_exp -// CHECK: tensor.insert %[[EXP]] into %[[SHMEM_]] +// CHECK: %[[MATERIALIZED:.*]] = xla_gpu.materialize @fusion_exp +// CHECK: %[[SHMEM_WITH_VALS:.*]] = xla_gpu.insert %[[MATERIALIZED]] +// CHECK-SAME: into %[[SHMEM]] at #indexing_map // CHECK: %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/transpose/side_outputs_inplace.hlo b/third_party/xla/xla/service/gpu/fusions/tests/transpose/side_outputs_inplace.hlo new file mode 100644 index 00000000000000..b17954109d1702 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/transpose/side_outputs_inplace.hlo @@ -0,0 +1,22 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize |\ +// RUN: FileCheck %s +// RUN: test_correctness %s + +HloModule module, input_output_alias={ {0}: (0, {}) } + +transpose_fusion { + p0 = f32[1024,2048]{1,0} parameter(0) + p1 = f32[1024,2048]{1,0} parameter(1) + add = f32[1024,2048]{1,0} add(p0, p1) + bitcast = f32[2097152]{0} bitcast(p0) + transpose = f32[2048,1024]{1,0} transpose(p1), dimensions={1,0} + ROOT res = (f32[1024,2048]{1,0}, f32[2048,1024]{1,0}, f32[2097152]{0}) tuple(add, transpose, bitcast) +} + +ENTRY module { + param0 = f32[1024,2048]{1,0} parameter(0) + param1 = f32[1024,2048]{1,0} parameter(1) + ROOT f = (f32[1024,2048]{1,0}, f32[2048,1024]{1,0}, f32[2097152]{0}) fusion(param0, param1), kind=kInput, calls=transpose_fusion +} + +// CHECK: xla_gpu.allocate_shared diff --git a/third_party/xla/xla/service/gpu/fusions/tools/test_correctness.cc b/third_party/xla/xla/service/gpu/fusions/tools/test_correctness.cc index 72529cd6545c4d..c812a24fab915c 100644 --- a/third_party/xla/xla/service/gpu/fusions/tools/test_correctness.cc +++ b/third_party/xla/xla/service/gpu/fusions/tools/test_correctness.cc @@ -32,6 +32,7 @@ limitations under the License. #include "xla/service/gpu/fusions/tools/test_lib.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/shape.h" #include "xla/tests/hlo_test_base.h" @@ -71,7 +72,7 @@ absl::Status TestBijection(const IndexingMap& map, auto status = VerifyBijection(map, intervals); if (status.ok()) return status; return absl::FailedPreconditionError( - absl::StrCat(status.message(), " in map ", map.ToString())); + absl::StrCat(status.message(), " in map ", ToString(map))); } TEST_F(CorrectnessTest, RunAndCompare) { diff --git a/third_party/xla/xla/service/gpu/fusions/tools/test_lib.cc b/third_party/xla/xla/service/gpu/fusions/tools/test_lib.cc index 11b82ddd517072..30fa08451a4109 100644 --- a/third_party/xla/xla/service/gpu/fusions/tools/test_lib.cc +++ b/third_party/xla/xla/service/gpu/fusions/tools/test_lib.cc @@ -77,6 +77,8 @@ absl::StatusOr> LoadTestModule( auto* new_entry = module->AddComputationAndUnifyNamesAndIds( builder.Build(), /*is_entry=*/false); module->ReplaceEntryComputation(new_entry); + *module->mutable_entry_computation_layout() = + module->compute_computation_layout(); } return module; diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/BUILD b/third_party/xla/xla/service/gpu/fusions/transforms/BUILD index e06494acb6262e..8f2adb7c38a88b 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/transforms/BUILD @@ -38,6 +38,7 @@ cc_library( "erase_dead_functions.cc", "expand_float_ops.cc", "flatten_tensors.cc", + "fuse_loops.cc", "lower_tensors.cc", "lower_to_llvm.cc", "lower_xla_gpu_to_scf.cc", @@ -59,6 +60,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/mlir_hlo", "//xla/mlir_hlo:map_mhlo_to_scalar_op", + "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", @@ -72,6 +74,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineToStandard", diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/convert_float_nvidia.cc b/third_party/xla/xla/service/gpu/fusions/transforms/convert_float_nvidia.cc index 8b07bb810727a4..8f899228f0fb94 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/convert_float_nvidia.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/convert_float_nvidia.cc @@ -85,7 +85,8 @@ struct RewriteTruncFPattern : public mlir::OpRewritePattern { b.create(0, 8)); auto cvtIntr = to_ty.isFloat8E4M3FN() ? "llvm.nvvm.f16x2.to.e4m3x2.rn" : "llvm.nvvm.f16x2.to.e5m2x2.rn"; - cvtOp = b.create(b.getIntegerType(16), cvtIntr, + cvtOp = b.create(b.getIntegerType(16), + b.getStringAttr(cvtIntr), mlir::ValueRange{vec}); } else { // Other FP types get converted to F32 first. @@ -97,7 +98,8 @@ struct RewriteTruncFPattern : public mlir::OpRewritePattern { } auto cvtIntr = to_ty.isFloat8E4M3FN() ? "llvm.nvvm.ff.to.e4m3x2.rn" : "llvm.nvvm.ff.to.e5m2x2.rn"; - cvtOp = b.create(b.getIntegerType(16), cvtIntr, + cvtOp = b.create(b.getIntegerType(16), + b.getStringAttr(cvtIntr), mlir::ValueRange{value, value}); } Value res = b.create(b.getIntegerType(8), cvtOp.getResults()); @@ -211,7 +213,8 @@ struct RewriteExtFPattern : public mlir::OpRewritePattern { : "llvm.nvvm.e5m2x2.to.f16x2.rn"; mlir::FloatType f16_ty = b.getF16Type(); auto cvtOp = b.create( - ml::getFixedVectorType(f16_ty, 2), cvtIntr, mlir::ValueRange{input}); + ml::getFixedVectorType(f16_ty, 2), b.getStringAttr(cvtIntr), + mlir::ValueRange{input}); Value res = b.create( cvtOp.getResults(), b.create(0, 8)); if (to_ty.getWidth() > f16_ty.getWidth()) { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/expand_float_ops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/expand_float_ops.cc index aebb7f44608559..6fea3a97527f9b 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/expand_float_ops.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/expand_float_ops.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/log/check.h" @@ -175,12 +176,19 @@ Value IsInf(Value value, mlir::ImplicitLocOpBuilder& b) { } assert(ty.getIntOrFloatBitWidth() == 8); - if (!ty.isFloat8E5M2()) { - // F8E5M2 is the only 8 bit float with infinities. + // F8E5M2, F8E4M3, F8E3M4 are the only 8 bit float with infinities. + if (ty.isFloat8E5M2()) { + Val bits{b.create(b.getI8Type(), value), &b}; + return (bits & 0x7F) == 0x7C; + } else if (ty.isFloat8E4M3()) { + Val bits{b.create(b.getI8Type(), value), &b}; + return (bits & 0x7F) == 0x78; + } else if (ty.isFloat8E3M4()) { + Val bits{b.create(b.getI8Type(), value), &b}; + return (bits & 0x7F) == 0x70; + } else { return b.create(false, b.getI1Type()); } - Val bits{b.create(b.getI8Type(), value), &b}; - return (bits & 0x7F) == 0x7C; } Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) { @@ -193,8 +201,12 @@ Value IsNaN(Value value, mlir::ImplicitLocOpBuilder& b) { Val bits{b.create(b.getI8Type(), value), &b}; if (ty.isFloat8E5M2()) { return (bits & 0b0111'1111).cmp(ma::CmpIPredicate::ugt, 0b0111'1100); + } else if (ty.isFloat8E4M3()) { + return (bits & 0b0111'1111).cmp(ma::CmpIPredicate::ugt, 0b0111'1000); } else if (ty.isFloat8E4M3FN()) { return (bits & 0b0111'1111) == 0b0111'1111; + } else if (ty.isFloat8E3M4()) { + return (bits & 0b0111'1111).cmp(ma::CmpIPredicate::ugt, 0b0111'0000); } return bits == 0x80; } @@ -207,7 +219,8 @@ Value EmitReducePrecision(Value value, int exponent_bits, int mantissa_bits, return mlir::mhlo::MhloOpToStdScalarOp::mapOpOfType< mlir::mhlo::ReducePrecisionOp>( b.getLoc(), value.getType(), {value.getType()}, - mlir::mhlo::ReducePrecisionOp::Adaptor(value, nullptr, properties), &b); + mlir::mhlo::ReducePrecisionOp::Adaptor(value, nullptr, properties), + /*attributes=*/std::nullopt, &b); } Value EmitF16ToF8e5m2(Value in, mlir::ImplicitLocOpBuilder& b) { @@ -544,7 +557,8 @@ struct RewriteF8Cst : public mlir::OpRewritePattern { int64_t constant = rhs_cst.bitcastToAPInt().getZExtValue(); // If we're comparing to +-0, compare the absolute values. if (rhs_cst.isZero() && - (lhs.getType().isFloat8E4M3FN() || lhs.getType().isFloat8E5M2())) { + (lhs.getType().isFloat8E3M4() || lhs.getType().isFloat8E4M3() || + lhs.getType().isFloat8E4M3FN() || lhs.getType().isFloat8E5M2())) { int_value = int_value & 0x7f; constant &= 0x7f; } diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/fuse_loops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/fuse_loops.cc new file mode 100644 index 00000000000000..161ad93aac11ea --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/transforms/fuse_loops.cc @@ -0,0 +1,325 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" +#include "xla/service/gpu/model/indexing_map.h" + +namespace xla { +namespace gpu { +namespace { + +using mlir::MLIRContext; +using mlir::Operation; +using mlir::SmallVector; +using mlir::Value; +using mlir::ValueRange; +namespace mv = ::mlir::vector; + +#define GEN_PASS_DEF_FUSELOOPSPASS +#include "xla/service/gpu/fusions/transforms/passes.h.inc" + +bool LoopsUseSameDimOps(LoopOp& loop1, LoopOp& loop2) { + for (auto [dim1, dim2] : llvm::zip(loop1.getDims(), loop2.getDims())) { + if (dim1.getDefiningOp() != dim2.getDefiningOp()) { + return false; + } + } + return true; +} + +bool LoopsHaveTheSameDomain(LoopOp& loop1, LoopOp& loop2) { + auto map1 = loop1.getIndexingMap(); + auto map2 = loop2.getIndexingMap(); + if (map1.GetDimVarsCount() != map2.GetDimVarsCount() || + map1.GetRangeVarsCount() != map2.GetRangeVarsCount() || + map1.GetConstraintsCount() != map2.GetConstraintsCount()) { + return false; + } + for (auto [d1, d2] : llvm::zip(map1.GetDimVars(), map2.GetDimVars())) { + if (d1 != d2) return false; + } + for (auto [r1, r2] : llvm::zip(map1.GetRangeVars(), map2.GetRangeVars())) { + if (r1 != r2) return false; + } + if (map1.GetConstraints() != map2.GetConstraints()) return false; + + // Check dimensions come from the same op. This is technically not a + // requirement and could be modified to handle different dim args. + return LoopsUseSameDimOps(loop1, loop2); +} + +// Check that the loops: +// 1. insert and extract from the same location within each iteration, +// 2. use all their IVs (so we don't overwrite the values in another iteration), +// 3. all indices are IVs (so they are confirmed injective). +bool IndicesAreEqualAndInjective(int64_t iv_count, mv::InsertOp insert, + mv::ExtractOp extract) { + auto insert_indices = insert.getDynamicPosition(); + auto extract_indices = extract.getDynamicPosition(); + if (insert_indices.size() != extract_indices.size()) { + return false; + } + if (insert_indices.size() != iv_count) { + return false; + } + + SmallVector matched_indices(iv_count, false); + for (auto [in, ex] : llvm::zip(insert_indices, extract_indices)) { + auto in_arg = mlir::dyn_cast(in); + auto ex_arg = mlir::dyn_cast(ex); + if (!in_arg || !ex_arg || in_arg.getArgNumber() != ex_arg.getArgNumber()) { + return false; + } + // Check #3 - all indices are IVs. + if (in_arg.getArgNumber() >= iv_count) { + return false; + } + matched_indices[in_arg.getArgNumber()] = true; + } + // If there is a loop IV that we didn't use in the insert op, then don't + // match. It's possible that we overwrite the value on a subsequent iteration + // so the loops cannot be fused. + return llvm::all_of(matched_indices, [](bool matched) { return matched; }); +} + +bool LoopDominatesLoop(LoopOp dominator /*lastloop*/, LoopOp dominatee) { + mlir::DominanceInfo dom; + return llvm::all_of(dominatee.getResults(), [&](Value result) { + return llvm::all_of(result.getUsers(), [&](Operation* user) { + return dom.properlyDominates(dominator, user, + /*enclosingOpOk*/ false); + }); + }); +} + +// Fuse insert_loop and extract_loop into a single loop, and remove the +// vector.insert and vector.extract ops. +void FuseExtractInsertLoopPair(MLIRContext* mlir_context, LoopOp insert_loop, + LoopOp extract_loop, mv::InsertOp insert, + mv::ExtractOp extract) { + mlir::IRRewriter rewriter(mlir_context); + rewriter.setInsertionPointAfter(extract_loop); + // Create a new map that has the results of both loops. + // map = (d0...dn)[s0...sn] -> + // (insert_loop_results..., extract_loop_results...) + auto insert_loop_map = insert_loop.getIndexingMap(); + auto extract_loop_map = extract_loop.getIndexingMap(); + auto map = insert_loop_map.GetAffineMap(); + for (auto res : extract_loop_map.GetAffineMap().getResults()) { + map = map.insertResult(res, map.getNumResults()); + } + IndexingMap new_map(map, insert_loop_map.GetDimVars(), + insert_loop_map.GetRangeVars(), + /*rt_vars=*/{}, insert_loop_map.GetConstraints()); + + auto new_loop = + rewriter.create(insert_loop.getLoc(), new_map, + insert_loop.getDims(), extract_loop.getInits()); + + // Make the loops independent of the vector.insert/extract & erase. + auto vector_cst = insert_loop.getInits().back(); + insert_loop->replaceAllUsesWith(ValueRange(vector_cst)); + extract_loop->replaceAllUsesWith(new_loop.getResults()); + extract.replaceAllUsesWith(insert.getSource()); + auto insert_loop_yield = + mlir::dyn_cast(insert_loop.getRegion().front().back()); + rewriter.eraseOp(insert_loop_yield); + rewriter.eraseOp(extract); + rewriter.eraseOp(insert); + + // Map old loop arguments to new loop arguments. + // new_args = [s0...sn, insert_loop_results..., extract_loop_results..., + // extract_inits...] + auto new_args = new_loop.getRegion().front().getArguments(); + auto range_vars = new_args.take_front(new_map.GetRangeVarsCount()); + new_args = new_args.drop_front(range_vars.size()); + auto in_loop_results = new_args.take_front(insert_loop_map.GetNumResults()); + new_args = new_args.drop_front(in_loop_results.size()); + auto ex_loop_results = new_args.take_front(extract_loop_map.GetNumResults()); + auto extract_inits = new_args.take_back(extract_loop.getInits().size()); + + // old_insert_args = [s0...sn, insert_loop_results..., vector_cst] + SmallVector old_insert_args; + old_insert_args.append(range_vars.begin(), range_vars.end()); + old_insert_args.append(in_loop_results.begin(), in_loop_results.end()); + old_insert_args.push_back(vector_cst); + + // old_insert_args = [s0...sn, extract_loop_results..., extract_inits...] + SmallVector old_extract_args; + old_extract_args.append(range_vars.begin(), range_vars.end()); + old_extract_args.append(ex_loop_results.begin(), ex_loop_results.end()); + old_extract_args.append(extract_inits.begin(), extract_inits.end()); + + // Merge the loops: first insert, then extract. + rewriter.mergeBlocks(&insert_loop.getRegion().front(), + &new_loop.getRegion().front(), old_insert_args); + rewriter.mergeBlocks(&extract_loop.getRegion().front(), + &new_loop.getRegion().front(), old_extract_args); + rewriter.eraseOp(insert_loop); + rewriter.eraseOp(extract_loop); +} + +// Fuse loops that have the same map, same dim variables, & can be rewritten as +// a single loop, each stacked on top of the next. +void FuseIndependentLoops(MLIRContext* mlir_context, + SmallVector& loops) { + auto last_loop = loops.back(); + auto map = last_loop.getIndexingMap(); + mlir::IRRewriter rewriter(mlir_context); + rewriter.setInsertionPointAfter(last_loop); + + SmallVector inits; + SmallVector results; + for (auto loop : loops) { + inits.append(loop.getInits().begin(), loop.getInits().end()); + auto yield_op = loop.getBody()->getTerminator(); + auto yields = yield_op->getOperands(); + results.append(yields.begin(), yields.end()); + yield_op->erase(); + } + auto new_loop = rewriter.create(last_loop.getLoc(), map, + last_loop.getDims(), inits); + + auto new_args = new_loop.getRegion().front().getArguments(); + int common_args_count = map.GetRangeVarsCount() + map.GetNumResults(); + auto common_args = new_args.take_front(common_args_count); + auto init_args = new_args.drop_front(common_args_count); + auto new_results = new_loop.getResults(); + + for (auto loop : loops) { + int num_results = loop.getNumResults(); + loop->replaceAllUsesWith(new_results.take_front(num_results)); + new_results = new_results.drop_front(num_results); + SmallVector old_args(common_args); + auto old_inits = init_args.take_front(num_results); + old_args.append(old_inits.begin(), old_inits.end()); + init_args = init_args.drop_front(num_results); + + rewriter.mergeBlocks(&loop.getRegion().front(), + &new_loop.getRegion().front(), old_args); + rewriter.eraseOp(loop); + } + rewriter.setInsertionPointToEnd(new_loop.getBody()); + rewriter.create(new_loop.getLoc(), results); +} + +void FuseSameMapLoopsIfPossible(MLIRContext* mlir_context, + SmallVector& loops) { + if (loops.size() < 2) return; + auto last_loop = loops.back(); + loops.pop_back(); + SmallVector eligible_loops; + for (auto loop : loops) { + if (LoopDominatesLoop(/*dominator=*/last_loop, /*dominatee=*/loop) && + LoopsUseSameDimOps(last_loop, loop)) { + eligible_loops.push_back(loop); + } + } + eligible_loops.push_back(last_loop); + + if (eligible_loops.size() < 2) return; + FuseIndependentLoops(mlir_context, eligible_loops); +} + +void FuseExtractIfPossible(MLIRContext* mlir_context, mv::ExtractOp extract) { + // Check that it has the following pattern: + // %insert_loop = { %insert = vector.insert ... } + // %extract_loop = { %extract = vector.extract %insert_loop } + auto extract_loop = extract->getParentOfType(); + if (!extract_loop) return; + if (!extract.getVector().getDefiningOp()) return; + auto insert_loop = + mlir::dyn_cast(extract.getVector().getDefiningOp()); + if (!insert_loop) return; + SmallVector inserts; + // If necessary, the insert_loop result size constraint may be relaxed. + if (insert_loop.getResults().size() != 1) return; + for (auto user : insert_loop.getRegionIterArgs().back().getUsers()) { + if (auto insert = mlir::dyn_cast(user)) { + inserts.push_back(insert); + } + } + if (inserts.size() != 1) return; + auto insert = inserts.front(); + + // Check that the vector isn't being used anywhere else so it can be + // removed entirely; we already know from above it's being used by + // extract so it should have exactly one use. + if (!insert_loop.getResult(0).hasOneUse()) return; + + if (!LoopsHaveTheSameDomain(insert_loop, extract_loop)) return; + // Only fuse loops if we are extracting from the same position that we are + // inserting into on each iteration. + if (!IndicesAreEqualAndInjective(insert_loop.getNumInductionVars(), insert, + extract)) { + return; + } + + // All requirements have been met: fuse loops. + FuseExtractInsertLoopPair(mlir_context, insert_loop, extract_loop, insert, + extract); +} + +struct FuseLoopsPass : public impl::FuseLoopsPassBase { + void runOnOperation() override { + auto mlir_context = &getContext(); + + SmallVector extracts; + getOperation()->walk([&](Operation* op) -> void { + if (auto extract = mlir::dyn_cast(op)) { + extracts.push_back(extract); + } + }); + for (auto extract : extracts) { + FuseExtractIfPossible(mlir_context, extract); + } + + // Fuse loops with the same map & that do not affect each other. + mlir::DenseMap> loops_by_map; + getOperation()->walk([&](Operation* op) -> void { + if (auto loop = mlir::dyn_cast(op)) { + loops_by_map[loop.getIndexingMapAttr()].push_back(loop); + } + }); + for (auto [_, loops] : loops_by_map) { + FuseSameMapLoopsIfPossible(mlir_context, loops); + } + } +}; + +} // namespace + +std::unique_ptr CreateFuseLoopsPass() { + return std::make_unique(); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/lower_tensors.cc b/third_party/xla/xla/service/gpu/fusions/transforms/lower_tensors.cc index 63e5c75f56c03a..7fa43e51b7881b 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/lower_tensors.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/lower_tensors.cc @@ -19,10 +19,12 @@ limitations under the License. #include #include #include +#include #include "absl/algorithm/container.h" #include "absl/log/check.h" #include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/LogicalResult.h" @@ -58,6 +60,7 @@ limitations under the License. #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/stream_executor/device_description.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace xla { @@ -189,10 +192,10 @@ std::tuple GetI4IndexAndNibble(Value linear_index, } mlir::LLVM::GEPOp CreateGep(TypedValue tensor, - Value linear_index, mlir::ImplicitLocOpBuilder& b, - Type element_type = nullptr) { - if (!element_type) { - element_type = tensor.getType().getElementType(); + Value linear_index, mlir::ImplicitLocOpBuilder& b) { + Type element_type = tensor.getType().getElementType(); + if (element_type == b.getI4Type()) { + element_type = b.getI8Type(); } auto ptr = mlir::LLVM::LLVMPointerType::get(b.getContext()); auto tensor_ptr = @@ -221,12 +224,11 @@ struct RewriteTensorExtract : mlir::OpRewritePattern { Type element_type = op.getTensor().getType().getElementType(); Value is_low_nibble = nullptr; if (element_type == rewriter.getI4Type()) { - element_type = rewriter.getI8Type(); std::tie(linear_index, is_low_nibble) = GetI4IndexAndNibble(linear_index, b); } - auto gep = CreateGep(op.getTensor(), linear_index, b, element_type); + auto gep = CreateGep(op.getTensor(), linear_index, b); auto load = rewriter .create(gep.getLoc(), gep.getElemType(), gep) @@ -281,14 +283,12 @@ struct RewriteTransferRead if (vector_type.getElementType().isInteger(1)) { vector_type = vector_type.cloneWith(std::nullopt, b.getI8Type()); } - mlir::Type gep_element_type = vector_type.getElementType(); if (op.getVectorType().getElementType().isInteger(4)) { linear_index = b.create( linear_index, b.create(1, linear_index.getType())); - gep_element_type = b.getI8Type(); } - auto gep = CreateGep(source, linear_index, b, gep_element_type); + auto gep = CreateGep(source, linear_index, b); mlir::LLVMTypeConverter converter(b.getContext()); auto llvm_vector_type = converter.convertType(vector_type); @@ -325,45 +325,67 @@ struct RewriteTensorInsert : mlir::OpRewritePattern { mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); auto tensor_dest = mlir::cast>(dest); auto linear_index = GetLinearIndex(op.getIndices(), b); - auto element_type = tensor_dest.getType().getElementType(); - Value is_low_nibble = nullptr; + auto scalar_value = op.getScalar(); - if (element_type == rewriter.getI4Type()) { - element_type = rewriter.getI8Type(); + // For i4 we store 2 values into one byte. This needs special handling here. + if (tensor_dest.getType().getElementType() == rewriter.getI4Type()) { + // We need to use directly op.getDest() as input, otherwise the following + // rewrite might remove the only user of it. + tensor_dest = op.getDest(); + Value is_low_nibble; std::tie(linear_index, is_low_nibble) = GetI4IndexAndNibble(linear_index, b); - } - auto gep = CreateGep(tensor_dest, linear_index, b, element_type); - auto scalar_value = op.getScalar(); - - if (is_low_nibble) { - Value current_value = - b.create(gep.getElemType(), gep); - auto ty = current_value.getType(); + // Technically we should half the number of elements when going to i8 + // element type, but it doesn't really matter because we only actually use + // the element type. Indexing is done by linear index, and GEP ops don't + // care about the number of elements. The tensor types will disappear + // completely after the LowerTensors pass. + Type ty = b.getI8Type(); + Type tensor_ty = tensor_dest.getType().clone(ty); + auto tensor_dest_i8 = + b.create(tensor_ty, tensor_dest) + .getResult(0); scalar_value = b.create(ty, scalar_value); - Value low_updated = b.create( - b.create( - current_value, b.create(0xf0, ty)), - scalar_value); - Value high_updated = b.create( - b.create( - current_value, b.create(0x0f, ty)), - b.create( - scalar_value, b.create(4, ty))); - scalar_value = b.create(is_low_nibble, low_updated, - high_updated); + + // We need AtomicRMWOp because it can happen that different threads try to + // access the same memory location. + auto atomic_rmw = b.create(tensor_dest_i8, linear_index); + mlir::ImplicitLocOpBuilder body_builder(atomic_rmw.getLoc(), + atomic_rmw.getBodyBuilder()); + Value current_value = atomic_rmw.getCurrentValue(); + Value low_updated = body_builder.create( + body_builder.create( + current_value, + body_builder.create(0xf0, ty)), + body_builder.create( + scalar_value, + body_builder.create(0x0f, ty))); + Value high_updated = body_builder.create( + body_builder.create( + current_value, + body_builder.create(0x0f, ty)), + body_builder.create( + scalar_value, + body_builder.create(4, ty))); + Value new_value = body_builder.create( + is_low_nibble, low_updated, high_updated); + body_builder.create(new_value); + Value casted_result = b.create( + tensor_dest.getType(), atomic_rmw.getResult()) + .getResult(0); + op.replaceAllUsesWith(casted_result); + } else { + auto gep = CreateGep(tensor_dest, linear_index, b); + mlir::LLVMTypeConverter converter(getContext()); + auto llvm_type = converter.convertType(scalar_value.getType()); + scalar_value = + b.create(llvm_type, scalar_value) + .getResult(0); + b.create(scalar_value, gep); + op.replaceAllUsesWith(op.getDest()); } - mlir::LLVMTypeConverter converter(getContext()); - auto llvm_type = converter.convertType(scalar_value.getType()); - scalar_value = rewriter - .create( - gep.getLoc(), llvm_type, scalar_value) - .getResult(0); - rewriter.create(gep.getLoc(), scalar_value, gep); - - op.replaceAllUsesWith(op.getDest()); op.erase(); return success(); } @@ -382,7 +404,6 @@ struct RewriteTransferWrite mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); auto tensor_dest = mlir::cast>(dest); auto linear_index = GetLinearIndex(op.getIndices(), b); - auto element_type = tensor_dest.getType().getElementType(); mlir::Value vector_value = op.getVector(); if (op.getVectorType().getElementType().isInteger(1)) { @@ -394,12 +415,11 @@ struct RewriteTransferWrite linear_index = b.create( linear_index, b.create(1, linear_index.getType())); - element_type = rewriter.getI8Type(); // LLVM and XLA pack i4s in opposite order, so we have to reshuffle the // elements. vector_value = PermutePairsInVector(vector_value, b); } - auto gep = CreateGep(tensor_dest, linear_index, b, element_type); + auto gep = CreateGep(tensor_dest, linear_index, b); mlir::LLVMTypeConverter converter(getContext()); auto llvm_type = converter.convertType(vector_value.getType()); @@ -453,11 +473,28 @@ mlir::LLVM::GlobalOp CreateGlobalOp(mlir::Attribute value, } Type element_type = shaped_ty.getElementType(); + int64_t num_elements = shaped_ty.getNumElements(); // Needed to support complex element type. mlir::LLVMTypeConverter converter(b.getContext()); auto llvm_element_type = converter.convertType(element_type); - auto array_ty = mlir::LLVM::LLVMArrayType::get(llvm_element_type, - shaped_ty.getNumElements()); + if (mlir::isa(element_type)) { + int bit_width = mlir::cast(element_type).getWidth(); + if (bit_width == 4) { + num_elements = CeilOfRatio(num_elements, 2); + llvm_element_type = b.getI8Type(); + auto unpacked_data = + mlir::cast(value).getRawData(); + std::vector packed_data(num_elements); + absl::Span packed_data_span = + absl::MakeSpan(packed_data.data(), packed_data.size()); + PackIntN(4, unpacked_data, packed_data_span); + value = mlir::DenseElementsAttr::getFromRawBuffer( + mlir::RankedTensorType::get({num_elements}, llvm_element_type), + packed_data); + } + } + auto array_ty = + mlir::LLVM::LLVMArrayType::get(llvm_element_type, num_elements); std::string name; int index = 0; do { @@ -1040,6 +1077,10 @@ class LowerTensorsPass : public impl::LowerTensorsPassBase { while (auto gep = addr.getDefiningOp()) { addr = gep.getBase(); } + while (auto cast = + addr.getDefiningOp()) { + addr = cast.getOperand(0); + } if (addr.getDefiningOp() || addr.getDefiningOp() || addr.getDefiningOp()) { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/optimize_loops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/optimize_loops.cc index e483bfebedb979..dd2177c0e97e71 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/optimize_loops.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/optimize_loops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include +#include #include #include #include @@ -42,6 +43,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" +#include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/model/indexing_map.h" namespace xla { @@ -201,7 +203,8 @@ struct PipelineLoad : mlir::OpRewritePattern { auto plus_one_map = mlir::AffineMap::get( 1, 0, mlir::getAffineDimExpr(0, this->getContext()) + 1); b.setInsertionPoint(next_value); - IndexingMap indexing_map(plus_one_map, {DimVar{0, ub.getSExtValue() - 1}}, + IndexingMap indexing_map(plus_one_map, + {IndexingMap::Variable{0, ub.getSExtValue() - 1}}, /*range_vars=*/{}, /*rt_vars=*/{}); auto induction_plus_one = b.create(new_for.getInductionVar(), indexing_map) @@ -240,9 +243,10 @@ int GetUnrollingFactor(mlir::scf::ForOp op) { // Get a rough estimate of the size of the loop body. int64_t size = 0; + bool can_unroll = true; op.getBodyRegion().walk([&](mlir::Operation* op) { if (mlir::isa(op)) { - size += kMaxSize; + can_unroll = false; return; } @@ -272,6 +276,16 @@ int GetUnrollingFactor(mlir::scf::ForOp op) { size += this_size; }); + if (!can_unroll) { + return 1; + } + + // Always unroll if the trip count is smaller than the max unroll factor, + // because it's very likely that the loop was meant to be unrolled. + if (trip_count <= MaxUnrollFactor()) { + return trip_count; + } + int factor = std::min(trip_count, kMaxSize / size); while (factor > 1 && trip_count % factor) { --factor; diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/passes.h b/third_party/xla/xla/service/gpu/fusions/transforms/passes.h index 470a333f70ccca..99304ed9a1f8da 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/passes.h +++ b/third_party/xla/xla/service/gpu/fusions/transforms/passes.h @@ -51,6 +51,7 @@ std::unique_ptr CreateLowerXlaGpuToScfPass(); std::unique_ptr CreateLowerXlaGpuLoopsToScfPass(); std::unique_ptr CreateMergePointersToSameSlicePass(); std::unique_ptr CreateOptimizeLoopsPass(); +std::unique_ptr CreateFuseLoopsPass(); std::unique_ptr CreatePeelLoopsPass(); std::unique_ptr CreatePropagateSliceIndicesPass(); std::unique_ptr CreateRewriteReductionsPass(); diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/passes.td b/third_party/xla/xla/service/gpu/fusions/transforms/passes.td index 52a0dacbc3db8f..f19e984fd9e6ec 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/passes.td +++ b/third_party/xla/xla/service/gpu/fusions/transforms/passes.td @@ -273,6 +273,38 @@ def VectorizeLoadsAndStoresPass : let constructor = "CreateVectorizeLoadsAndStoresPass()"; } +def FuseLoopsPass : Pass<"xla-gpu-fuse-loops", "mlir::func::FuncOp"> { + let summary = "Fuse xla_gpu.loop."; + let description = [{ + This pass fuses similar xla_gpu.loops into one if the second one is + extracting the same value from a vector in which the first one inserts to. + + Before fuse-loops: + %loop0 = xla_gpu.loop (%tid, %bid) -> (%ra, %rb, %rc)[%i, %j] + in #indexing_map iter_args(%iter = %cst) -> (vector<8x1xf32>) { + %extracted = tensor.extract %arg0[%ra, %rb, %rc] + %1 = vector.insert %extracted, %iter [%i, %j] + xla_gpu.yield %1 + } + %loop1 = xla_gpu.loop (%tid, %bid) -> (%ra, %rb)[%i, %j] + in #indexing_map1 iter_args(%iter = %shmem) -> (tensor<32x33xf32>) { + %2 = vector.extract %loop0 [%i, %j] + %inserted = tensor.insert %iter[%ra, %rb] + xla_gpu.yield %extracted + } + + After fuse-loops: + %loop = xla_gpu.loop (%tid, %bid) -> (%ra, %rb, %rc, %rd, %re)[%i, %j] + in #indexing_map iter_args(%iter = %shmem) -> (tensor<32x33xf32>) { + %extracted = tensor.extract %arg0[%ra, %rb, %rc] + %inserted = tensor.insert %extracted into %iter[%rd, %re] + xla_gpu.yield %inserted + } + }]; + let dependentDialects = ["xla::gpu::XlaGpuDialect"]; + let constructor = "CreateFuseLoopsPass()"; +} + def PeelLoopsPass : Pass<"xla-gpu-peel-loops", "mlir::func::FuncOp"> { let summary = "Peels xla_gpu.loop."; let description = [{ diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/peel_loops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/peel_loops.cc index 9e95e3c3264239..63d9bb3d6923dc 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/peel_loops.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/peel_loops.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" namespace xla { namespace gpu { @@ -88,9 +89,9 @@ struct PeelLoop : public OpRewritePattern { tail_map.Simplify(); VLOG(5) << "Peeled indexing map\n" - << indexing_map.ToString() << "into\n" - << peeled_map.ToString() << "and\n" - << tail_map.ToString() << "\n"; + << ToString(indexing_map) << "into\n" + << ToString(peeled_map) << "and\n" + << ToString(tail_map) << "\n"; indexing_maps.pop_back(); indexing_maps.push_back(tail_map); indexing_maps.push_back(peeled_map); diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/rewrite_reductions.cc b/third_party/xla/xla/service/gpu/fusions/transforms/rewrite_reductions.cc index 620fa8e51d30a8..50969b8bd6bbd8 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/rewrite_reductions.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/rewrite_reductions.cc @@ -12,14 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include #include +#include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" @@ -30,6 +34,7 @@ limitations under the License. #include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" #include "xla/util.h" namespace xla { @@ -62,23 +67,55 @@ int GetNumThreads(mlir::Operation* op) { return Product(grid.getThreadCounts()); } -std::pair GetNumAndSizeOfMinorReducedDimensions(ReduceOp op) { +struct DimensionGroup { + int64_t size; + int64_t stride; + int first_dimension; + int num_dimensions; +}; + +DimensionGroup GetMinorMostReduction(ReduceOp op) { llvm::ArrayRef dims = op.getDimensions(); + auto input_ty = GetInputType(op); - int64_t cumulative_size = 1; - for (int i = 0; i < dims.size(); ++i) { - // The expected next reduction dimension if it is contiguous with the - // previously reduced dimensions. - int expected_dim = input_ty.getRank() - 1 - i; - // If the next reduced dimension is not the expected one, it is not - // contiguous (i.e., it's not part of the minor reduced dimensions, there is - // a kept dimension in between). - if (dims[dims.size() - 1 - i] != expected_dim) { - return {i, cumulative_size}; + DimensionGroup result{1, 1, static_cast(input_ty.getRank()), 0}; + llvm::SmallBitVector reduced_dims(input_ty.getRank()); + for (int64_t dim : dims) { + reduced_dims.set(dim); + } + + // Look for the first group of consecutive reduced dimensions and compute the + // stride and size of the group. + bool in_reduction = false; + for (int dim = input_ty.getRank() - 1; + dim >= 0 && (!in_reduction || reduced_dims[dim]); --dim) { + assert(input_ty.getDimSize(dim) > 1 && + "degenerate dimensions are not allowed"); + --result.first_dimension; + if (reduced_dims[dim]) { + in_reduction = true; + result.size *= input_ty.getDimSize(dim); + ++result.num_dimensions; + } else { + result.stride *= input_ty.getDimSize(dim); } - cumulative_size *= input_ty.getDimSize(input_ty.getRank() - 1 - i); } - return {dims.size(), cumulative_size}; + + return result; +} + +llvm::SmallVector ReindexTensors( + mlir::OpBuilder& b, mlir::ValueRange tensors, mlir::ValueRange defaults, + llvm::ArrayRef new_shape, const IndexingMap& map) { + llvm::SmallVector reindexed; + reindexed.reserve(tensors.size()); + for (auto [tensor, def] : llvm::zip(tensors, defaults)) { + auto new_ty = + mlir::cast(tensor.getType()).clone(new_shape); + reindexed.push_back( + b.create(tensor.getLoc(), new_ty, tensor, def, map)); + } + return reindexed; } // Rewrites large row reductions to three reductions: @@ -94,13 +131,12 @@ struct RewriteRowReduction : mlir::OpRewritePattern { ReduceOp op, mlir::PatternRewriter& rewriter) const override { auto* ctx = op.getContext(); - auto [num_minor_dims, reduced_size] = - GetNumAndSizeOfMinorReducedDimensions(op); - if (num_minor_dims == 0) { + auto minor_reduction = GetMinorMostReduction(op); + if (minor_reduction.stride > 1) { return rewriter.notifyMatchFailure(op, "not a row reduction"); } - if (reduced_size <= WarpSize()) { + if (minor_reduction.size <= WarpSize()) { return rewriter.notifyMatchFailure(op, "small minor dimension"); } @@ -108,9 +144,9 @@ struct RewriteRowReduction : mlir::OpRewritePattern { assert(num_threads % WarpSize() == 0); llvm::ArrayRef input_shape = GetInputType(op).getShape(); - llvm::SmallVector projected_input_shape{ - input_shape.begin(), input_shape.end() - num_minor_dims}; - projected_input_shape.push_back(reduced_size); + auto projected_input_shape = llvm::to_vector( + input_shape.take_front(minor_reduction.first_dimension)); + projected_input_shape.push_back(minor_reduction.size); // Collapse the minor dimensions into one. // [..., 123, 456] -> [..., 123 * 456] @@ -120,14 +156,14 @@ struct RewriteRowReduction : mlir::OpRewritePattern { // Pad the new minor dimension to a multiple of the number of threads. For // example, for 128 threads, 123 * 456 = 56088 is padded to 56192. auto padded_projected_input_shape = projected_input_shape; - int64_t padded_size = RoundUpTo(reduced_size, num_threads); + int64_t padded_size = RoundUpTo(minor_reduction.size, num_threads); padded_projected_input_shape.back() = padded_size; // Reshape the padded minor dimension so that we can reduce it per thread // and then per warp. // [..., 56192] -> [..., 439, 4, 32] - llvm::SmallVector per_thread_reduction_input_shape( - input_shape.begin(), input_shape.end() - num_minor_dims); + auto per_thread_reduction_input_shape = llvm::to_vector( + input_shape.take_front(minor_reduction.first_dimension)); per_thread_reduction_input_shape.push_back(padded_size / num_threads); per_thread_reduction_input_shape.push_back(num_threads / WarpSize()); per_thread_reduction_input_shape.push_back(WarpSize()); @@ -141,24 +177,18 @@ struct RewriteRowReduction : mlir::OpRewritePattern { mlir::getAffineDimExpr(per_thread_input_rank - 1, ctx) + mlir::getAffineDimExpr(per_thread_input_rank - 2, ctx) * num_threads, - {0, reduced_size - 1}); - - // Reshape the inputs. - llvm::SmallVector new_operands; - new_operands.reserve(op.getOperands().size()); - for (auto [operand, init] : llvm::zip(op.getInputs(), op.getInits())) { - auto new_input_ty = mlir::cast(operand.getType()) - .clone(per_thread_reduction_input_shape); - new_operands.push_back(rewriter.create( - operand.getLoc(), new_input_ty, operand, init, reindex_map)); - } + {0, minor_reduction.size - 1}); + + auto new_inputs = + ReindexTensors(rewriter, op.getInputs(), op.getInits(), + per_thread_reduction_input_shape, reindex_map); // Reduce the non-minor dimensions and the third to last dimension. - auto dims_for_first_reduction = - llvm::to_vector(op.getDimensions().drop_back(num_minor_dims)); + auto dims_for_first_reduction = llvm::to_vector( + op.getDimensions().drop_back(minor_reduction.num_dimensions)); dims_for_first_reduction.push_back(per_thread_input_rank - 3); auto first_reduction = - rewriter.create(op.getLoc(), new_operands, op.getInits(), + rewriter.create(op.getLoc(), new_inputs, op.getInits(), dims_for_first_reduction, op.getCombiner()); // Reduce the last and the second-to-last dimensions. First to produce one @@ -175,9 +205,130 @@ struct RewriteRowReduction : mlir::OpRewritePattern { } }; +// Rewrites column reductions to a reduce-transpose-reduce. +struct RewriteColumnReduction : mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + ReduceOp op, mlir::PatternRewriter& rewriter) const override { + auto* ctx = op.getContext(); + + auto minor_reduction = GetMinorMostReduction(op); + + if (minor_reduction.stride == 1) { + return rewriter.notifyMatchFailure(op, "not a column reduction"); + } + + int64_t num_threads = GetNumThreads(op); + + // If the stride is larger than the number of threads, we can efficiently + // emit this reduction as a simple loop, assuming there's no excessive + // padding. + // TODO(jreiffers): Is there anything we can do if the number of threads + // doesn't divide the stride? + if (minor_reduction.stride >= num_threads) { + return rewriter.notifyMatchFailure(op, "efficient loop reduction"); + } + + // A column reduction reduces [a, b] to [b]. We do this in four steps: + // 1. reshape [a, b] to [a ceildiv c, c, b] + // 2. reduce [a ceildiv c, c, b] to [c, b] via a loop + // 3. transpose [c, b] to [b, c] + // 4. emit a row reduction on [b, c]. + // + // We are constrained in our choice for `c`: + // + // - we need one element of shared memory (or a register) for each element + // of the intermediate results, so a larger c needs more shared memory. + // - we can have at most WarpSize intermediate results per final result, + // so c can be at most 32. + // - c must be a power of two so we can use a warp shuffle. + // - c * b should be less than the number of threads (but as close to it + // as possible, so we don't have excessive padding). + // + // All of this assumes no vectorization. + // TODO(jreiffers): Handle vectorization here. + + // Emitters always choose `c = 32` if `b` is not a small power of two. + // Also, reductions are tiled so `b = 32`. The number of threads is always + // 1024. This satisfies all the constraints above. + // Reduce the size of the reduction dimension. The maximum size we can + // handle is the warp size. + + assert(num_threads > minor_reduction.stride); + int64_t c = std::min(WarpSize(), num_threads / minor_reduction.stride); + + llvm::ArrayRef input_shape = GetInputType(op).getShape(); + auto projected_input_shape = llvm::to_vector( + input_shape.take_front(minor_reduction.first_dimension)); + projected_input_shape.push_back(minor_reduction.size); + projected_input_shape.push_back(minor_reduction.stride); + auto projection_map = + GetBitcastMap(projected_input_shape, input_shape, ctx); + int64_t projected_rank = projected_input_shape.size(); + + // Pad the new minor dimension to a multiple of c. + auto padded_projected_input_shape = projected_input_shape; + int64_t padded_size = RoundUpTo(minor_reduction.size, c); + padded_projected_input_shape[projected_rank - 2] = padded_size; + + // Reshape the input to [..., a ceildiv c, c, b] + auto reshaped_input_shape = llvm::to_vector( + input_shape.take_front(minor_reduction.first_dimension)); + reshaped_input_shape.push_back(padded_size / c); + reshaped_input_shape.push_back(c); + reshaped_input_shape.push_back(minor_reduction.stride); + int64_t reshaped_rank = reshaped_input_shape.size(); + + auto reindex_map = + GetBitcastMap(reshaped_input_shape, padded_projected_input_shape, ctx) * + projection_map; + reindex_map.AddConstraint( + mlir::getAffineDimExpr(reshaped_rank - 2, ctx) + + mlir::getAffineDimExpr(reshaped_rank - 3, ctx) * c, + {0, minor_reduction.size - 1}); + + auto new_inputs = ReindexTensors(rewriter, op.getInputs(), op.getInits(), + reshaped_input_shape, reindex_map); + + // Reduce the non-minor dimensions and the third to last dimension. + // [..., a ceildiv c, c, b] -> [..., c, b] + auto dims_for_first_reduction = llvm::to_vector( + op.getDimensions().drop_back(minor_reduction.num_dimensions)); + dims_for_first_reduction.push_back(reshaped_rank - 3); + auto first_reduction = + rewriter.create(op.getLoc(), new_inputs, op.getInits(), + dims_for_first_reduction, op.getCombiner()); + + // Transpose [..., c, b] to [..., b, c] + auto shape = GetOutputType(first_reduction).getShape(); + int64_t first_reduction_rank = shape.size(); + llvm::SmallVector permutation(first_reduction_rank); + std::iota(permutation.begin(), permutation.end(), 0); + std::swap(permutation[first_reduction_rank - 1], + permutation[first_reduction_rank - 2]); + + auto transposed_shape = llvm::to_vector(shape); + std::swap(transposed_shape[first_reduction_rank - 1], + transposed_shape[first_reduction_rank - 2]); + IndexingMap transpose_map( + mlir::AffineMap::getPermutationMap(permutation, ctx), + DimVarsFromTensorSizes(transposed_shape), {}, {}); + + auto transposed = + ReindexTensors(rewriter, first_reduction.getResults(), op.getInits(), + transposed_shape, transpose_map); + + rewriter.replaceOpWithNewOp( + op, transposed, op.getInits(), + llvm::ArrayRef{first_reduction_rank - 1}, op.getCombiner()); + return mlir::success(); + } +}; + void RewriteReductionsPass::runOnOperation() { mlir::RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); + patterns.add(&getContext()); if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc index acbd9d3735ea46..dd28ebd6a13db8 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc @@ -223,15 +223,16 @@ struct RewriteAffineApply : OpRewritePattern { LogicalResult matchAndRewrite(mlir::affine::AffineApplyOp op, PatternRewriter& rewriter) const override { AffineMap affine_map = op.getAffineMap(); - std::vector dim_ranges(affine_map.getNumDims()); - std::vector symbol_ranges(affine_map.getNumSymbols()); + std::vector dim_ranges(affine_map.getNumDims()); + std::vector symbol_ranges( + affine_map.getNumSymbols()); for (int i = 0; i < affine_map.getNumInputs(); ++i) { if (auto range = GetRange(op->getOperand(i))) { if (i >= dim_ranges.size()) { - symbol_ranges[i - dim_ranges.size()] = RangeVar{*range}; + symbol_ranges[i - dim_ranges.size()] = IndexingMap::Variable{*range}; } else { - dim_ranges[i] = DimVar{*range}; + dim_ranges[i] = IndexingMap::Variable{*range}; } } else { return rewriter.notifyMatchFailure(op, "failed to deduce range"); @@ -357,6 +358,17 @@ std::optional GetIVRange(mlir::Value iv) { return {{lb.getSExtValue(), ub.getSExtValue() - 1}}; } } + if (auto loop_op = mlir::dyn_cast(parent)) { + const auto& indexing_map = loop_op.getIndexingMap(); + if (bbarg.getArgNumber() >= loop_op.getNumInductionVars() && + bbarg.getArgNumber() < + loop_op.getNumInductionVars() + indexing_map.GetNumResults()) { + RangeEvaluator range_evaluator = indexing_map.GetRangeEvaluator(); + return range_evaluator.ComputeExpressionRange( + indexing_map.GetAffineMap().getResult(bbarg.getArgNumber() - + loop_op.getNumInductionVars())); + } + } return std::nullopt; } diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/simplify_arith.cc b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_arith.cc index f3d67e24ee3248..72b1c0a22628e4 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/simplify_arith.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_arith.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" @@ -228,6 +229,35 @@ struct RewriteTruncExtShuffle : public OpRewritePattern { } }; +static std::optional GetSelectRange(mlir::Operation* sel) { + // Match |x| implemented as (x >= 0) ? x : (0 - x). + mlir::Value x = sel->getOperand(1); + auto m_x = mlir::matchers::m_Val(x); + if (!x.getType().isSignlessIntOrIndex() || + !mlir::matchPattern( + sel, mlir::m_Op( + mlir::m_Op(m_x, mlir::m_Zero()), m_x, + mlir::m_Op(mlir::m_Zero(), m_x)))) { + return std::nullopt; + } + if (sel->getOperand(0).getDefiningOp().getPredicate() != + CmpIPredicate::sge) { + return std::nullopt; + } + // Annotate |x| as >= 0. + Interval result{0, + static_cast( + (1ull << (x.getType().getIntOrFloatBitWidth() - 1)) - 1)}; + std::optional x_range = GetRange(x); + if (x_range.has_value()) { + Interval positive_range = x_range->max({0, 0}); + Interval negative_range = -x_range->min({0, 0}); + Interval abs_range = positive_range.Union(negative_range); + return result.Intersect(abs_range); + } + return result; +} + void AnnotateRanges(mlir::func::FuncOp func) { func->walk([](mlir::Operation* op) { if (op->getNumResults() != 1) { @@ -262,6 +292,10 @@ void AnnotateRanges(mlir::func::FuncOp func) { } else { out_range = lhs_range * rhs_range; } + } else if (mlir::isa(op)) { + out_range = GetRange(op->getOperand(0)); + } else if (mlir::isa(op)) { + out_range = GetSelectRange(op); } if (out_range) { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir index 1691d3fd748c23..d35dc71ddad023 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir @@ -8,7 +8,7 @@ func.func @tensor_extract( : tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>> func.return %v : f32 } -// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 2 + d0), domain: d0 in [0, 1], d1 in [0, 2], is_simplified: true> +// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<"(d0, d1) -> (d1 * 2 + d0), domain: d0 in [0, 1], d1 in [0, 2]"> // CHECK-LABEL: func.func @tensor_extract( // CHECK-SAME: %[[SRC:.*]]: tensor<6xf32>, @@ -67,7 +67,7 @@ func.func @atomic_rmw(%in: tensor<2x4xf32>, %i: index, %j: index) } return %ret : tensor<2x4xf32> } -// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 4 + d1), domain: d0 in [0, 1], d1 in [0, 3], is_simplified: true> +// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 4 + d1), domain: d0 in [0, 1], d1 in [0, 3]"> // CHECK-LABEL: func.func @atomic_rmw( // CHECK-SAME: %[[TENSOR:.*]]: tensor<8xf32>, %[[I:.*]]: index, // CHECK-SAME: %[[J:.*]]: index) -> tensor<8xf32> { @@ -93,8 +93,8 @@ func.func @for_loop(%t0: tensor<32x1024xf32>, %t1: tensor<64x8x4xf32>) } {some_attr} return %for#0, %for#1, %c0_f32 : tensor<32x1024xf32>, tensor<64x8x4xf32>, f32 } -// CHECK: #[[$MAP0:.+]] = #xla_gpu.indexing_map<(d0) -> (d0 + 1024) -// CHECK: #[[$MAP1:.+]] = #xla_gpu.indexing_map<(d0) -> (d0 * 32 + 5) +// CHECK: #[[$MAP0:.+]] = #xla_gpu.indexing_map<"(d0) -> (d0 + 1024) +// CHECK: #[[$MAP1:.+]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 32 + 5) // CHECK-LABEL: func.func @for_loop( // CHECK-SAME: %[[T0:.*]]: tensor<32768xf32>, // CHECK-SAME: %[[T1:.*]]: tensor<2048xf32>) -> (tensor<32768xf32>, tensor<2048xf32>, f32) { @@ -114,12 +114,9 @@ func.func @for_loop(%t0: tensor<32x1024xf32>, %t1: tensor<64x8x4xf32>) // ----- -#map = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 128 + d0) floordiv 36), - domain: d0 in [0, 127], d1 in [0, 393749], is_simplified: true> -#map1 = #xla_gpu.indexing_map<(d0, d1) -> (((d1 * 128 + d0) floordiv 9) mod 4), - domain: d0 in [0, 127], d1 in [0, 393749], is_simplified: true> -#map2 = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 128 + d0) mod 9), - domain: d0 in [0, 127], d1 in [0, 393749], is_simplified: true> +#map = #xla_gpu.indexing_map<"(d0, d1) -> ((d1 * 128 + d0) floordiv 36), domain: d0 in [0, 127], d1 in [0, 393749]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1) -> (((d1 * 128 + d0) floordiv 9) mod 4), domain: d0 in [0, 127], d1 in [0, 393749]"> +#map2 = #xla_gpu.indexing_map<"(d0, d1) -> ((d1 * 128 + d0) mod 9), domain: d0 in [0, 127], d1 in [0, 393749]"> func.func @if_op(%arg0: tensor<4000x4x9xf32>, %arg1: tensor<1400x1xi32>, %arg2: tensor<1400x1x4x9xf32>, %arg3: tensor<4000x4x9xf32>) -> tensor<4000x4x9xf32> { @@ -225,7 +222,7 @@ func.func @vector_extract(%arg0: vector<2x3xf32>, %arg1: index) -> f32 { %v = vector.extract %arg0[%arg1, 2] : f32 from vector<2x3xf32> func.return %v : f32 } -// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<(d0) -> (d0 * 3 + 2), +// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 3 + 2), // CHECK-SAME: domain: d0 in [0, 1] // CHECK-LABEL: func.func @vector_extract( @@ -241,7 +238,7 @@ func.func @vector_insert(%arg0: vector<10x24xf32>, %i: index) %out = vector.insert %scalar, %arg0 [1, %i] : f32 into vector<10x24xf32> func.return %out : vector<10x24xf32> } -// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<(d0) -> (d0 + 24), +// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<"(d0) -> (d0 + 24), // CHECK-SAME: domain: d0 in [0, 23] // CHECK-LABEL: func.func @vector_insert( // CHECK-SAME: %[[VECTOR:.*]]: vector<240xf32>, %[[I:.*]]: index) -> @@ -290,8 +287,8 @@ func.func @for_loop_vector(%t0: vector<32x1024xf32>, %t1: vector<64x8x4xf32>) return %for#0, %for#1, %c0_f32 : vector<32x1024xf32>, vector<64x8x4xf32>, f32 } -// CHECK: #[[$MAP0:.+]] = #xla_gpu.indexing_map<(d0) -> (d0 + 1024) -// CHECK: #[[$MAP1:.+]] = #xla_gpu.indexing_map<(d0) -> (d0 * 32 + 5) +// CHECK: #[[$MAP0:.+]] = #xla_gpu.indexing_map<"(d0) -> (d0 + 1024) +// CHECK: #[[$MAP1:.+]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 32 + 5) // CHECK-LABEL: func.func @for_loop_vector( // CHECK-SAME: %[[V0:.*]]: vector<32768xf32>, // CHECK-SAME: %[[V1:.*]]: vector<2048xf32>) -> diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir new file mode 100644 index 00000000000000..f2926856623ca7 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/fuse_loops.mlir @@ -0,0 +1,385 @@ +// RUN: mlir_fusions_opt -split-input-file %s -xla-gpu-fuse-loops \ +// RUN: | FileCheck %s + +#indexing_map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +" (d1 floordiv 30," +" ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32," +" (d1 mod 6) * 32 + d0 mod 32)," +" domain:" +" d0 in [0, 127], d1 in [0, 599]," +" s0 in [0, 7], s1 in [0, 0]," +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> +#indexing_map1 = #xla_gpu.indexing_map<"(th_x, d1)[s0, s1] ->" +" (0," +" th_x mod 32," +" th_x floordiv 32 + s0 * 4)," +" domain:" +" th_x in [0, 127], d1 in [0, 599]," +" s0 in [0, 7], s1 in [0, 0]," +" (d1 mod 6) * 32 + th_x mod 32 in [0, 169]"> +func.func @fuse_loops(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { + %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> + %c0 = arith.constant 0 : index + %tid = gpu.thread_id x {xla.range = [0 : index, 127 : index]} + %bid = gpu.block_id x {xla.range = [0 : index, 599 : index]} + %shmem = xla_gpu.allocate_shared : tensor<1x32x33xf32> + %xla_loop = xla_gpu.loop (%tid, %bid)[%i, %j] + -> (%ra, %rb, %rc) in #indexing_map iter_args(%iter = %cst) -> (vector<8x1xf32>) { + %extracted = tensor.extract %arg0[%ra, %rb, %rc] : tensor<20x160x170xf32> + %0 = math.exp %extracted : f32 + %1 = vector.insert %0, %iter [%i, %j] : f32 into vector<8x1xf32> + xla_gpu.yield %1 : vector<8x1xf32> + } + %xla_loop_0 = xla_gpu.loop (%tid, %bid)[%i, %j] + -> (%ra, %rb, %rc) in #indexing_map1 iter_args(%iter = %shmem) -> (tensor<1x32x33xf32>) { + %0 = vector.extract %xla_loop[%i, %j] : f32 from vector<8x1xf32> + %inserted = tensor.insert %0 into %iter[%ra, %rb, %rc] : tensor<1x32x33xf32> + xla_gpu.yield %inserted : tensor<1x32x33xf32> + } + %synced_tensor = xla_gpu.sync_threads %xla_loop_0 : tensor<1x32x33xf32> + return %synced_tensor : tensor<1x32x33xf32> +} + + +// CHECK: #[[$FUSED_MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> +// CHECK-SAME: (d1 floordiv 30, ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32, +// CHECK-SAME: (d1 mod 6) * 32 + d0 mod 32, 0, d0 mod 32, d0 floordiv 32 + s0 * 4), +// CHECK-SAME: domain: d0 in [0, 127], d1 in [0, 599], +// CHECK-SAME: s0 in [0, 7], s1 in [0, 0], (d1 mod 6) * 32 + d0 mod 32 in [0, 169] + +// CHECK: %[[FUSED_LOOP:.*]] = xla_gpu.loop {{.*}} in #[[$FUSED_MAP]] +// CHECK-NOT: vector.insert +// CHECK-NOT: vector.extract +// CHECK: %[[EXTRACTED:.*]] = tensor.extract +// CHECK: %[[EXP:.*]] = math.exp %[[EXTRACTED]] +// CHECK: tensor.insert %[[EXP]] + +// CHECK: xla_gpu.sync_threads %[[FUSED_LOOP]] + +// ----- + +#indexing_map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +" (d1 floordiv 30," +" ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32," +" (d1 mod 6) * 32 + d0 mod 32)," +" domain:" +" d0 in [0, 127], d1 in [0, 599]," +" s0 in [0, 7], s1 in [0, 0]," +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> +#indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +" (0," +" d0 mod 32," +" d0 floordiv 32 + s0 * 4)," +" domain:" +" d0 in [0, 127], d1 in [0, 599]," +" s0 in [0, 7], s1 in [0, 0]," +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> +func.func @do_not_fuse_index_mismatch(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { + %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> + %c0 = arith.constant 0 : index + %tid = gpu.thread_id x {xla.range = [0 : index, 127 : index]} + %bid = gpu.block_id x {xla.range = [0 : index, 599 : index]} + %shmem = xla_gpu.allocate_shared : tensor<1x32x33xf32> + %xla_loop = xla_gpu.loop (%tid, %bid)[%i, %j] + -> (%ra, %rb, %rc) in #indexing_map iter_args(%iter = %cst) -> (vector<8x1xf32>) { + %extracted = tensor.extract %arg0[%ra, %rb, %rc] : tensor<20x160x170xf32> + %0 = math.exp %extracted : f32 + %1 = vector.insert %0, %iter [%j, %i] : f32 into vector<8x1xf32> + xla_gpu.yield %1 : vector<8x1xf32> + } + %xla_loop_0 = xla_gpu.loop (%tid, %bid)[%i, %j] + -> (%ra, %rb, %rc) in #indexing_map1 iter_args(%iter = %shmem) -> (tensor<1x32x33xf32>) { + %0 = vector.extract %xla_loop[%i, %j] : f32 from vector<8x1xf32> + %inserted = tensor.insert %0 into %iter[%ra, %rb, %rc] : tensor<1x32x33xf32> + xla_gpu.yield %inserted : tensor<1x32x33xf32> + } + return %xla_loop_0 : tensor<1x32x33xf32> +} + +// CHECK-LABEL: @do_not_fuse_index_mismatch +// CHECK: xla_gpu.loop +// CHECK: vector.insert +// CHECK: xla_gpu.loop +// CHECK: vector.extract + +// ----- + +#indexing_map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +" (d1 floordiv 30," +" ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32," +" (d1 mod 6) * 32 + d0 mod 32)," +" domain:" +" d0 in [0, 127], d1 in [0, 599]," +" s0 in [0, 7], s1 in [0, 0]," +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> +#indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +" (0," +" d0 mod 32," +" d0 floordiv 32 + s0 * 4)," +" domain:" +" d0 in [0, 127], d1 in [0, 599]," +" s0 in [0, 7], s1 in [0, 0]," +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> +func.func @do_not_fuse_multiple_uses(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { + %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> + %c0 = arith.constant 0 : index + %tid = gpu.thread_id x {xla.range = [0 : index, 127 : index]} + %bid = gpu.block_id x {xla.range = [0 : index, 599 : index]} + %shmem = xla_gpu.allocate_shared : tensor<1x32x33xf32> + %xla_loop = xla_gpu.loop (%tid, %bid)[%i, %j] + -> (%ra, %rb, %rc) in #indexing_map iter_args(%iter = %cst) -> (vector<8x1xf32>) { + %extracted = tensor.extract %arg0[%ra, %rb, %rc] : tensor<20x160x170xf32> + %0 = math.exp %extracted : f32 + %1 = vector.insert %0, %iter [%i, %j] : f32 into vector<8x1xf32> + xla_gpu.yield %1 : vector<8x1xf32> + } + %xla_loop_0 = xla_gpu.loop (%tid, %bid)[%i, %j] + -> (%ra, %rb, %rc) in #indexing_map1 iter_args(%iter = %shmem) -> (tensor<1x32x33xf32>) { + %0 = vector.extract %xla_loop[%i, %j] : f32 from vector<8x1xf32> + %inserted = tensor.insert %0 into %iter[%ra, %rb, %rc] : tensor<1x32x33xf32> + xla_gpu.yield %inserted : tensor<1x32x33xf32> + } + %synced_tensor = xla_gpu.sync_threads %xla_loop_0 : tensor<1x32x33xf32> + %0 = vector.extract %xla_loop [2, 0] : f32 from vector<8x1xf32> + return %synced_tensor : tensor<1x32x33xf32> +} + +// CHECK-LABEL: @do_not_fuse_multiple_uses +// CHECK: xla_gpu.loop +// CHECK: vector.insert +// CHECK: xla_gpu.loop +// CHECK: vector.extract + +// ----- + +#indexing_map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +" (d1 floordiv 30," +" ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32," +" (d1 mod 6) * 32 + d0 mod 32)," +" domain:" +" d0 in [0, 127], d1 in [0, 599]," +" s0 in [0, 7], s1 in [0, 0]," +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> +#indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +" (0," +" d0 mod 32," +" d0 floordiv 32 + s0 * 4)," +" domain:" +" d0 in [0, 127], d1 in [0, 599]," +" s0 in [0, 5], s1 in [0, 0]," +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> +func.func @do_not_fuse_map_domain_mismatch(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { + %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> + %c0 = arith.constant 0 : index + %tid = gpu.thread_id x {xla.range = [0 : index, 127 : index]} + %bid = gpu.block_id x {xla.range = [0 : index, 599 : index]} + %shmem = xla_gpu.allocate_shared : tensor<1x32x33xf32> + %xla_loop = xla_gpu.loop (%tid, %bid)[%i, %j] + -> (%ra, %rb, %rc) in #indexing_map iter_args(%iter = %cst) -> (vector<8x1xf32>) { + %extracted = tensor.extract %arg0[%ra, %rb, %rc] : tensor<20x160x170xf32> + %0 = math.exp %extracted : f32 + %1 = vector.insert %0, %iter [%i, %j] : f32 into vector<8x1xf32> + xla_gpu.yield %1 : vector<8x1xf32> + } + %xla_loop_0 = xla_gpu.loop (%tid, %bid)[%i, %j] + -> (%ra, %rb, %rc) in #indexing_map1 iter_args(%iter = %shmem) -> (tensor<1x32x33xf32>) { + %0 = vector.extract %xla_loop[%i, %j] : f32 from vector<8x1xf32> + %inserted = tensor.insert %0 into %iter[%ra, %rb, %rc] : tensor<1x32x33xf32> + xla_gpu.yield %inserted : tensor<1x32x33xf32> + } + %synced_tensor = xla_gpu.sync_threads %xla_loop_0 : tensor<1x32x33xf32> + return %synced_tensor : tensor<1x32x33xf32> +} + +// CHECK-LABEL: @do_not_fuse_map_domain_mismatch +// CHECK: xla_gpu.loop +// CHECK: vector.insert +// CHECK: xla_gpu.loop +// CHECK: vector.extract + +// ----- + +#indexing_map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +" (d1 floordiv 30," +" ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32," +" (d1 mod 6) * 32 + d0 mod 32)," +" domain:" +" d0 in [0, 127], d1 in [0, 599]," +" s0 in [0, 7], s1 in [0, 0]," +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> +#indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +" (0," +" d0 mod 32," +" d0 floordiv 32 + s0 * 4)," +" domain:" +" d0 in [0, 127], d1 in [0, 599]," +" s0 in [0, 7], s1 in [0, 0]," +" (d1 mod 5) * 32 + d0 mod 32 in [0, 169]"> +func.func @do_not_fuse_map_constraint_mismatch(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { + %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> + %c0 = arith.constant 0 : index + %tid = gpu.thread_id x {xla.range = [0 : index, 127 : index]} + %bid = gpu.block_id x {xla.range = [0 : index, 599 : index]} + %shmem = xla_gpu.allocate_shared : tensor<1x32x33xf32> + %xla_loop = xla_gpu.loop (%tid, %bid)[%i, %j] + -> (%ra, %rb, %rc) in #indexing_map iter_args(%iter = %cst) -> (vector<8x1xf32>) { + %extracted = tensor.extract %arg0[%ra, %rb, %rc] : tensor<20x160x170xf32> + %0 = math.exp %extracted : f32 + %1 = vector.insert %0, %iter [%i, %j] : f32 into vector<8x1xf32> + xla_gpu.yield %1 : vector<8x1xf32> + } + %xla_loop_0 = xla_gpu.loop (%tid, %bid)[%i, %j] + -> (%ra, %rb, %rc) in #indexing_map1 iter_args(%iter = %shmem) -> (tensor<1x32x33xf32>) { + %0 = vector.extract %xla_loop[%i, %j] : f32 from vector<8x1xf32> + %inserted = tensor.insert %0 into %iter[%ra, %rb, %rc] : tensor<1x32x33xf32> + xla_gpu.yield %inserted : tensor<1x32x33xf32> + } + %synced_tensor = xla_gpu.sync_threads %xla_loop_0 : tensor<1x32x33xf32> + return %synced_tensor : tensor<1x32x33xf32> +} + +// CHECK-LABEL: @do_not_fuse_map_constraint_mismatch +// CHECK: xla_gpu.loop +// CHECK: vector.insert +// CHECK: xla_gpu.loop +// CHECK: vector.extract + +// ----- + +#indexing_map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1, s2] ->" +" (d1 floordiv 30," +" ((d1 floordiv 6) mod 5) * 32 + s0 * 4 + d0 floordiv 32," +" (d1 mod 6) * 32 + d0 mod 32)," +" domain:" +" d0 in [0, 127], d1 in [0, 599]," +" s0 in [0, 7], s1 in [0, 0], s2 in [0, 1]," +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> +#indexing_map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1, s2] ->" +" (0," +" d0 mod 32," +" d0 floordiv 32 + s0 * 4)," +" domain:" +" d0 in [0, 127], d1 in [0, 599]," +" s0 in [0, 7], s1 in [0, 0], s2 in [0, 1]," +" (d1 mod 6) * 32 + d0 mod 32 in [0, 169]"> +func.func @do_not_fuse_unused_loop_iv(%arg0: tensor<20x160x170xf32>) -> tensor<1x32x33xf32> { + %cst = arith.constant dense<0.000000e+00> : vector<8x1xf32> + %c0 = arith.constant 0 : index + %tid = gpu.thread_id x {xla.range = [0 : index, 127 : index]} + %bid = gpu.block_id x {xla.range = [0 : index, 599 : index]} + %shmem = xla_gpu.allocate_shared : tensor<1x32x33xf32> + %xla_loop = xla_gpu.loop (%tid, %bid)[%i, %j, %k] + -> (%ra, %rb, %rc) in #indexing_map iter_args(%iter = %cst) -> (vector<8x1xf32>) { + %extracted = tensor.extract %arg0[%ra, %rb, %rc] : tensor<20x160x170xf32> + %0 = math.exp %extracted : f32 + %1 = vector.insert %0, %iter [%i, %j] : f32 into vector<8x1xf32> + xla_gpu.yield %1 : vector<8x1xf32> + } + %xla_loop_0 = xla_gpu.loop (%tid, %bid)[%i, %j, %k] + -> (%ra, %rb, %rc) in #indexing_map1 iter_args(%iter = %shmem) -> (tensor<1x32x33xf32>) { + %0 = vector.extract %xla_loop[%i, %j] : f32 from vector<8x1xf32> + %inserted = tensor.insert %0 into %iter[%ra, %rb, %rc] : tensor<1x32x33xf32> + xla_gpu.yield %inserted : tensor<1x32x33xf32> + } + %synced_tensor = xla_gpu.sync_threads %xla_loop_0 : tensor<1x32x33xf32> + return %synced_tensor : tensor<1x32x33xf32> +} + +// CHECK-LABEL: @do_not_fuse_unused_loop_iv +// CHECK: xla_gpu.loop +// CHECK: vector.insert +// CHECK: xla_gpu.loop +// CHECK: vector.extract + +// ----- + +#indexing_map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +" ((d0 floordiv 32) * 8192 + d1 * 8 + s0 * 32768 + (d0 floordiv 4) mod 8," +" d0 mod 4)," +" domain:" +" d0 in [0, 127], d1 in [0, 1023]," +" s0 in [0, 2], s1 in [0, 0]"> +#indexing_map1 = #xla_gpu.indexing_map<"(d0) -> " +" ((d0 floordiv 4) mod 8192," +" d0 mod 4)," +" domain:" +" d0 in [0, 98303]"> +func.func @fuse_identical_independent_loops(%arg0: tensor<8192x4xf64>, + %arg1: tensor<98304x4xf64>, %arg2: tensor<98304x4xf64>) -> + tensor<98304x4xf64> { + %tid = gpu.thread_id x {xla.range = [0 : index, 127 : index]} + %bid = gpu.block_id x {xla.range = [0 : index, 1023 : index]} + %cst_2 = arith.constant 0.50000000000000089 : f64 + %cst = arith.constant 0 : index + %xla_loop = xla_gpu.loop (%tid, %bid)[%i, %j] -> (%ra, %rb) in #indexing_map + iter_args(%iter = %arg1) -> (tensor<98304x4xf64>) { + %0:2 = xla_gpu.apply_indexing #indexing_map1(%ra) + %extracted = tensor.extract %arg0[%0#0, %0#1] : tensor<8192x4xf64> + %3 = arith.mulf %extracted, %cst_2 : f64 + %inserted = tensor.insert %3 into %iter[%ra, %rb] : tensor<98304x4xf64> + xla_gpu.yield %inserted : tensor<98304x4xf64> + } + %xla_loop_1 = xla_gpu.loop (%tid, %bid)[%i, %j] -> (%ra, %rb) in #indexing_map + iter_args(%iter = %arg2) -> (tensor<98304x4xf64>) { + %0:2 = xla_gpu.apply_indexing #indexing_map1(%ra) + %extracted = tensor.extract %arg0[%0#0, %0#1] : tensor<8192x4xf64> + %inserted = tensor.insert %extracted into %iter[%ra, %rb] : + tensor<98304x4xf64> + xla_gpu.yield %inserted : tensor<98304x4xf64> + } + return %xla_loop_1 : tensor<98304x4xf64> +} + +// CHECK-LABEL: @fuse_identical_independent_loops +// CHECK-SAME: (%[[ARG0:.*]]: tensor<8192x4xf64>, +// CHECK-SAME: %[[ARG1:.*]]: tensor<98304x4xf64>, +// CHECK-SAME: %[[ARG2:.*]]: tensor<98304x4xf64>) +// CHECK: %[[LOOP0:.*]], %[[LOOP1:.*]] = xla_gpu.loop +// CHECK-SAME: -> (%[[RA:.*]], %[[RB:.*]]) in +// CHECK-SAME: iter_args(%[[ITER0:.*]] = %[[ARG1]], %[[ITER1:.*]] = %[[ARG2]]) +// CHECK: tensor.insert {{.*}} into %[[ITER0]][%[[RA]], %[[RB]]] +// CHECK: tensor.insert {{.*}} into %[[ITER1]][%[[RA]], %[[RB]]] +// CHECK: xla_gpu.yield {{.*}} : tensor<98304x4xf64>, tensor<98304x4xf64> + +// ----- + +#indexing_map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] ->" +" ((d0 floordiv 32) * 8192 + d1 * 8 + s0 * 32768 + (d0 floordiv 4) mod 8," +" d0 mod 4)," +" domain:" +" d0 in [0, 127], d1 in [0, 1023]," +" s0 in [0, 2], s1 in [0, 0]"> +#indexing_map1 = #xla_gpu.indexing_map<"(d0) -> " +" ((d0 floordiv 4) mod 8192," +" d0 mod 4)," +" domain:" +" d0 in [0, 98303]"> +func.func @do_not_fuse_dependent_loops(%arg0: tensor<8192x4xf64>, + %arg1: tensor<98304x4xf64>) -> tensor<98304x4xf64> { + %tid = gpu.thread_id x {xla.range = [0 : index, 127 : index]} + %bid = gpu.block_id x {xla.range = [0 : index, 1023 : index]} + %cst_2 = arith.constant 0.50000000000000089 : f64 + %cst = arith.constant 0 : index + %xla_loop = xla_gpu.loop (%tid, %bid)[%i, %j] -> (%ra, %rb) in #indexing_map + iter_args(%iter = %arg1) -> (tensor<98304x4xf64>) { + %0:2 = xla_gpu.apply_indexing #indexing_map1(%ra) + %extracted = tensor.extract %arg0[%0#0, %0#1] : tensor<8192x4xf64> + %3 = arith.mulf %extracted, %cst_2 : f64 + %inserted = tensor.insert %3 into %iter[%ra, %rb] : tensor<98304x4xf64> + xla_gpu.yield %inserted : tensor<98304x4xf64> + } + %dependency = tensor.insert %cst_2 into %xla_loop[%cst, %cst] : + tensor<98304x4xf64> + %xla_loop_1 = xla_gpu.loop (%tid, %bid)[%i, %j] -> (%ra, %rb) in #indexing_map + iter_args(%iter = %dependency) -> (tensor<98304x4xf64>) { + %0:2 = xla_gpu.apply_indexing #indexing_map1(%ra) + %extracted = tensor.extract %arg0[%0#0, %0#1] : tensor<8192x4xf64> + %inserted = tensor.insert %extracted into %iter[%ra, %rb] : + tensor<98304x4xf64> + xla_gpu.yield %inserted : tensor<98304x4xf64> + } + return %xla_loop_1 : tensor<98304x4xf64> +} + +// CHECK-LABEL: @do_not_fuse_dependent_loops +// CHECK-COUNT-2: xla_gpu.loop \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir index 822c3a85c9a2a0..a894c13dce1293 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir @@ -102,7 +102,8 @@ func.func @store_control_flow( %arg0: tensor<2xf32>, %arg1: index) } func.return %result : tensor<2xf32> } -// CHECK: @store_control_flow(%[[ARG0:.*]]: !llvm.ptr, %[[X:.*]]: index) { +// CHECK-LABEL: @store_control_flow( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, %[[X:.*]]: index) { // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index @@ -150,7 +151,7 @@ func.func @vector_constant() -> vector<2xindex> { func.return %c1 : vector<2xindex> } // vector constants should not be rewritten. -// CHECK: @vector_constant +// CHECK-LABEL: @vector_constant // CHECK-NEXT: arith.constant // ----- @@ -164,7 +165,8 @@ func.func @complex_tensor_insert( %out = tensor.insert %complex into %arg0[%c1] : tensor<10xcomplex> func.return %out : tensor<10xcomplex> } -// CHECK: @complex_tensor_insert(%[[ARG0:.*]]: !llvm.ptr +// CHECK-LABEL: @complex_tensor_insert( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr // CHECK: %[[C:.*]] = complex.create // CHECK: %[[GEP:.*]] = llvm.getelementptr inbounds %[[ARG0]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f32, f32)> // CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[C]] : complex to !llvm.struct<(f32, f32)> @@ -178,7 +180,8 @@ func.func @complex_tensor_extract( %v2 = tensor.extract %arg0[%c1] : tensor<10xcomplex> func.return %v2 : complex } -// CHECK: @complex_tensor_extract(%[[ARG0:.*]]: !llvm.ptr +// CHECK-LABEL: @complex_tensor_extract( +// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr // CHECK: %[[GEP:.*]] = llvm.getelementptr inbounds %[[ARG0]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f32, f32)> // CHECK: %[[LOAD:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> !llvm.struct<(f32, f32)> // CHECK: builtin.unrealized_conversion_cast %[[LOAD]] : !llvm.struct<(f32, f32)> to complex @@ -237,7 +240,7 @@ func.func @atomic_rmw_f32(%in: tensor<8xf32>, %i: index) -> (tensor<8xf32>) { } return %ret : tensor<8xf32> } -// CHECK: @atomic_rmw_f32 +// CHECK-LABEL: @atomic_rmw_f32 // CHECK: %[[ADDR:.*]] = llvm.getelementptr // CHECK-NEXT: %[[INIT:.*]] = llvm.load %[[ADDR]] // CHECK-NEXT: scf.while (%[[VAR:.*]] = %[[INIT]]) @@ -256,7 +259,7 @@ func.func @atomic_rmw_f16(%in: tensor<8xf16>, %i: index) } return %ret : tensor<8xf16> } -// CHECK: @atomic_rmw_f16 +// CHECK-LABEL: @atomic_rmw_f16 // CHECK: %[[ADDR:.*]] = llvm.getelementptr // CHECK-NEXT: %[[ADDR_INT:.*]] = llvm.ptrtoint %[[ADDR]] // CHECK-NEXT: %[[OFFSET:.*]] = llvm.and %[[ADDR_INT]], %{{.*}} @@ -285,7 +288,7 @@ func.func @atomic_rmw_overwrite(%in: tensor<8xf16>, %i: index) } return %ret : tensor<8xf16> } -// CHECK: @atomic_rmw_overwrite +// CHECK-LABEL: @atomic_rmw_overwrite // CHECK: %[[ADDR:.*]] = llvm.getelementptr // CHECK-NEXT: %[[ADDR_INT:.*]] = llvm.ptrtoint %[[ADDR]] // CHECK-NEXT: %[[OFFSET:.*]] = llvm.and %[[ADDR_INT]], %{{.*}} @@ -307,7 +310,7 @@ func.func @shared_complex() -> tensor<10xcomplex> { return %shared : tensor<10xcomplex> } // CHECK: llvm.mlir.global private @{{.*}}() {addr_space = 3 : i32} : !llvm.array<10 x struct<(f32, f32)>> -// CHECK: @shared_complex +// CHECK-LABEL: @shared_complex // ----- @@ -317,14 +320,33 @@ func.func @i4_load_store(%arg: tensor<10xi4>, %i: index, %j: index) %r = tensor.insert %v into %arg[%j] : tensor<10xi4> return %r : tensor<10xi4> } -// CHECK: @i4_load_store +// CHECK-LABEL: @i4_load_store +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : i8 +// CHECK-DAG: %[[C15:.*]] = arith.constant 15 : i8 +// CHECK-DAG: %[[C_NEG16:.*]] = arith.constant -16 : i8 // CHECK: llvm.getelementptr // CHECK-SAME: -> !llvm.ptr, i8 -// CHECK: llvm.load +// CHECK: %[[VALUE_I8:.*]] = arith.extui {{.*}} : i4 to i8 // CHECK: llvm.getelementptr // CHECK-SAME: -> !llvm.ptr, i8 -// CHECK: llvm.load -// CHECK: llvm.store +// CHECK: %[[CURRENT_I32:.*]] = llvm.load +// CHECK-SAME: !llvm.ptr -> i32 +// CHECK: scf.while (%[[INIT:.*]] = %[[CURRENT_I32]]) +// CHECK: %[[SHIFTED:.*]] = llvm.lshr %[[INIT]] +// CHECK: %[[CURRENT:.*]] = llvm.trunc %[[SHIFTED]] +// CHECK: %[[MASKED_CURRENT_LO:.*]] = arith.andi %[[CURRENT]], %[[C_NEG16]] : i8 +// CHECK: %[[MASKED_VALUE_I8:.*]] = arith.andi %[[VALUE_I8]], %[[C15]] : i8 +// CHECK: %[[NEW_LO:.*]] = arith.ori %[[MASKED_CURRENT_LO]], %[[MASKED_VALUE_I8]] : i8 +// CHECK: %[[MASKED_CURRENT_HI:.*]] = arith.andi %[[CURRENT]], %[[C15]] : i8 +// CHECK: %[[VALUE_HI:.*]] = arith.shli %[[VALUE_I8]], %[[C4]] : i8 +// CHECK: %[[NEW_HI:.*]] = arith.ori %[[MASKED_CURRENT_HI]], %[[VALUE_HI]] : i8 +// CHECK: %[[NEW_VALUE:.*]] = arith.select %{{.*}}, %[[NEW_LO]], %[[NEW_HI]] : i8 +// CHECK: %[[NEW_VALUE_I32:.*]] = llvm.zext %[[NEW_VALUE]] +// CHECK: %[[MASKED_INIT:.*]] = llvm.and %[[INIT]] +// CHECK: %[[NEW_VALUE_SHIFTED:.*]] = llvm.shl %[[NEW_VALUE_I32]] +// CHECK: %[[NEW_INIT:.*]] = llvm.or %[[MASKED_INIT]], %[[NEW_VALUE_SHIFTED]] +// CHECK: llvm.cmpxchg %{{.*}}, %[[INIT]], %[[NEW_INIT]] seq_cst seq_cst +// CHECK: scf.condition // ----- @@ -337,7 +359,7 @@ func.func @direct_atomic_rmw_overwrite(%in: tensor<8xi32>, } return %ret : tensor<8xi32> } -// CHECK: @direct_atomic_rmw_overwrite +// CHECK-LABEL: @direct_atomic_rmw_overwrite // CHECK: %[[C2:.*]] = arith.constant 2 // CHECK: %[[ADDR:.*]] = llvm.getelementptr // CHECK: llvm.store %[[C2]], %[[ADDR]] atomic unordered {alignment = 4 : i64} @@ -354,7 +376,7 @@ func.func @direct_atomic_rmw_addi(%in: tensor<8xi32>, } return %ret : tensor<8xi32> } -// CHECK: @direct_atomic_rmw_addi +// CHECK-LABEL: @direct_atomic_rmw_addi // CHECK: %[[C2:.*]] = arith.constant 2 // CHECK: %[[ADDR:.*]] = llvm.getelementptr // CHECK: llvm.atomicrmw add %[[ADDR]], %[[C2]] seq_cst @@ -371,7 +393,7 @@ func.func @direct_atomic_rmw_maxsi(%in: tensor<8xi32>, } return %ret : tensor<8xi32> } -// CHECK: @direct_atomic_rmw_maxsi +// CHECK-LABEL: @direct_atomic_rmw_maxsi // CHECK: %[[C2:.*]] = arith.constant 2 // CHECK: %[[ADDR:.*]] = llvm.getelementptr // CHECK: llvm.atomicrmw max %[[ADDR]], %[[C2]] seq_cst @@ -388,7 +410,7 @@ func.func @direct_atomic_rmw_maxui(%in: tensor<8xi32>, } return %ret : tensor<8xi32> } -// CHECK: @direct_atomic_rmw_maxui +// CHECK-LABEL: @direct_atomic_rmw_maxui // CHECK: %[[C2:.*]] = arith.constant 2 // CHECK: %[[ADDR:.*]] = llvm.getelementptr // CHECK: llvm.atomicrmw umax %[[ADDR]], %[[C2]] seq_cst @@ -405,7 +427,7 @@ func.func @direct_atomic_rmw_minsi(%in: tensor<8xi32>, } return %ret : tensor<8xi32> } -// CHECK: @direct_atomic_rmw_minsi +// CHECK-LABEL: @direct_atomic_rmw_minsi // CHECK: %[[C2:.*]] = arith.constant 2 // CHECK: %[[ADDR:.*]] = llvm.getelementptr // CHECK: llvm.atomicrmw min %[[ADDR]], %[[C2]] seq_cst @@ -422,7 +444,7 @@ func.func @direct_atomic_rmw_minui(%in: tensor<8xi32>, } return %ret : tensor<8xi32> } -// CHECK: @direct_atomic_rmw_minui +// CHECK-LABEL: @direct_atomic_rmw_minui // CHECK: %[[C2:.*]] = arith.constant 2 // CHECK: %[[ADDR:.*]] = llvm.getelementptr // CHECK: llvm.atomicrmw umin %[[ADDR]], %[[C2]] seq_cst @@ -698,3 +720,15 @@ func.func @transfer_read_i1(%arg0: tensor<43xi1> {xla.slice_index = 1}) -> vecto // CHECK: %[[CAST:.*]] = arith.cmpi ne, %[[LOADED]], %[[C0]] // CHECK: return %[[CAST]] : vector<2xi1> +// ----- + +func.func @int4_constant(%arg0: tensor<3xi4>, %arg1: index) -> i4 { + %cst = arith.constant dense<[1, 2, 3]> : tensor<3xi4> + %extracted = tensor.extract %arg0[%arg1] : tensor<3xi4> + %extracted_0 = tensor.extract %cst[%arg1] : tensor<3xi4> + %0 = arith.addi %extracted, %extracted_0 : i4 + return %0 : i4 +} +// CHECK: llvm.mlir.global private constant +// CHECK-SAME: dense<[18, 48]> +// CHECK-LABEL: @int4_constant diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir index f02f7012b80cf4..f981cef83029d8 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir @@ -1,9 +1,8 @@ // RUN: mlir_fusions_opt %s -xla-gpu-lower-xla-gpu-loops-to-scf \ // RUN: --split-input-file | FileCheck %s -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0 + 1, s1 - 1), - domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], s0 + s1 in [0, 90], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0 + 1, s1 - 1)," + "domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], s0 + s1 in [0, 90]"> func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { %sum = xla_gpu.loop (%dim)[%i, %j] -> (%ra, %rb) @@ -15,9 +14,9 @@ func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32 func.return %sum : f32 } -// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0 + s1), -// CHECK-DAG: #[[$MAPA:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0 + 1), -// CHECK-DAG: #[[$MAPB:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s1 - 1), +// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0 + s1), +// CHECK-DAG: #[[$MAPA:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0 + 1), +// CHECK-DAG: #[[$MAPB:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s1 - 1), // CHECK-LABEL: func.func @loop_op( // CHECK-SAME: %[[IN:.*]]: tensor<1024x32xf32>, @@ -60,9 +59,8 @@ func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32 // ----- -#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0 + 1, s1 - 1), - domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], s0 + s1 in [0, 90], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0 + 1, s1 - 1)," + "domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], s0 + s1 in [0, 90]"> func.func @loop_yields_value_from_above(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir index dd15bdaafc533f..f53ccc1e8ae54f 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir @@ -124,12 +124,8 @@ func.func @predicated_extract( func.func private @exp(%p0: tensor<32x64xf32>, %i: index, %j: index) -> f32 -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), - domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0*2+s0, s1), - domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], s1 in [0, 1], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], s1 in [0, 1]"> func.func @materialize(%input: tensor<32x64xf32>, %i: index, %j: index) -> !xla_gpu.indexed_vector<32x2x2xf32, #map1> { @@ -137,8 +133,8 @@ func.func @materialize(%input: tensor<32x64xf32>, %i: index, %j: index) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x2x2xf32, #map1> func.return %0 : !xla_gpu.indexed_vector<32x2x2xf32, #map1> } -// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d1 * 32 + d0 * 2 + s0, s1) -// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 * 2 + s0, s1) +// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1 * 32 + d0 * 2 + s0, s1) +// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0 * 2 + s0, s1) // CHECK: @materialize(%[[INPUT:.*]]: tensor<32x64xf32>, %[[INDEX1:.*]]: index, %[[INDEX2:.*]]: index) @@ -153,12 +149,8 @@ func.func @materialize(%input: tensor<32x64xf32>, %i: index, %j: index) // ----- -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), - domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0, d1) -> (d0 mod 16, d1), - domain: d0 in [0, 32], d1 in [0, 2], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1) -> (d0 mod 16, d1), domain: d0 in [0, 32], d1 in [0, 2]"> func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { @@ -166,8 +158,8 @@ func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, : !xla_gpu.indexed_vector<32x64xf32, #map> -> tensor<32x64xf32> func.return %0 : tensor<32x64xf32> } -// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d1 * 32 + d0 * 2 + s0, s1) -// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 mod 16, d1) +// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1 * 32 + d0 * 2 + s0, s1) +// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 mod 16, d1) // CHECK: @insert(%[[INPUT:.*]]: !xla_gpu.indexed_vector<32x64xf32, #[[$MAP]]>, // CHECK-SAME: %[[I:.*]]: index, %[[J:.*]]: index, @@ -179,7 +171,7 @@ func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, // CHECK: %[[SCALAR:.*]] = vector.extract %{{.*}}[%[[S0]], %[[S1]]] // CHECK-SAME: : f32 from vector<2x2xf32> -// CHECK: %[[MAP1_RESULT:.*]]:2 = xla_gpu.apply_indexing +// CHECK: %[[MAP1_RESULT:.*]]:2 = xla_gpu.apply_indexing // CHECK-SAME: #[[$MAP1]](%[[MAP_RESULT1]], %[[MAP_RESULT2]]) // CHECK: %[[NEW_TENSOR:.*]] = tensor.insert %[[SCALAR]] // CHECK-SAME: into %[[TENSOR]][%[[MAP1_RESULT]]#0, %[[MAP1_RESULT]]#1] @@ -189,15 +181,9 @@ func.func @insert(%input: !xla_gpu.indexed_vector<32x64xf32, #map>, func.func private @exp(%p0: tensor<32x64xf32>, %i: index, %j: index) -> f32 -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), - domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1], - is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0*2+s0, s1), - domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], s1 in [0, 1], - is_simplified: false> -#map2 = #xla_gpu.indexing_map<(d0, d1) -> (d0, d1), - domain: d0 in [0, 32], d1 in [0, 2], - is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 1], s1 in [0, 1]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1], s1 in [0, 1]"> +#map2 = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1), domain: d0 in [0, 32], d1 in [0, 2]"> func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { @@ -213,12 +199,8 @@ func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index, func.func private @exp(%p0: tensor<32x64xcomplex>, %i: index, %j: index) -> complex -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), - domain: d0 in [0, 32], d1 in [0, 8], - s0 in [0, 2], s1 in [0, 3], is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0*2+s0, s1), - domain: d0 in [0, 32], d1 in [0, 2], - s0 in [0, 2], s1 in [0, 3], is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d1*32+d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 8], s0 in [0, 2], s1 in [0, 3]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 2], s1 in [0, 3]"> func.func @materialize_complex( %input: tensor<32x64xcomplex>, %output: tensor<32x64xcomplex>, @@ -245,11 +227,8 @@ func.func @materialize_complex( // ----- -#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0*2+s0, s1), - domain: d0 in [0, 32], d1 in [0, 2], - s0 in [0, 2], s1 in [0, 3], is_simplified: false> -#map2 = #xla_gpu.indexing_map<(d0, d1) -> (d0, d1), - domain: d0 in [0, 32], d1 in [0, 2], is_simplified: false> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0, s1] -> (d0*2+s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 2], s1 in [0, 3]"> +#map2 = #xla_gpu.indexing_map<"(d0, d1) -> (d0, d1), domain: d0 in [0, 32], d1 in [0, 2]"> func.func @insert_complex( %input: !xla_gpu.indexed_vector<32x3x4xcomplex, #map1>, %output: tensor<32x64xcomplex>, @@ -274,4 +253,4 @@ func.func @insert_complex( // CHECK: %[[IMAG:.*]] = vector.extract %[[VECTOR]][%[[C1]], %[[I]], %[[J]]] // CHECK: %[[COMPLEX:.*]] = complex.create %[[REAL]], %[[IMAG]] // CHECK: %[[INSERTED:.*]] = tensor.insert %[[COMPLEX]] into %[[ITER]] -// CHECK: xla_gpu.yield %[[INSERTED]] : tensor<32x64xcomplex> \ No newline at end of file +// CHECK: xla_gpu.yield %[[INSERTED]] : tensor<32x64xcomplex> diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir index dd7d639e3273e6..1094b51a2a6841 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir @@ -1,11 +1,8 @@ // RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-optimize-loops | FileCheck %s -#map = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 8), - domain: d0 in [0, 31], is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0) -> (d0 mod 8), - domain: d0 in [0, 31], is_simplified: false> -#map2 = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512), - domain: d0 in [0, 1], d1 in [0, 255], s0 in [0, 7], is_simplified: false> +#map = #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 8), domain: d0 in [0, 31]"> +#map1 = #xla_gpu.indexing_map<"(d0) -> (d0 mod 8), domain: d0 in [0, 31]"> +#map2 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512), domain: d0 in [0, 1], d1 in [0, 255], s0 in [0, 7]"> module { func.func @fully_unroll(%arg0: tensor<4x8x4096xf32>, %arg1: tensor<4096xbf16>, %arg2: tensor<4x8xf32>, %arg3: tensor<4096xbf16>, @@ -127,7 +124,7 @@ module { } } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 + 1), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 + 1), // CHECK-LABEL: @pipeline_extract // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C30:.*]] = arith.constant 30 : index @@ -154,7 +151,7 @@ module { %cst = arith.constant dense<[0.0, 0.0]> : vector<2xf32> %cst0 = arith.constant 0.0 : f32 %ret = scf.for %i = %c0 to %c17 step %c1 iter_args (%iter = %cst) -> (vector<2xf32>) { - %base = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 * 2), domain: d0 in [0, 15], is_simplified: false>(%i) + %base = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 * 2), domain: d0 in [0, 15]">(%i) %val = vector.transfer_read %arg[%base], %cst0 : tensor<34xf32>, vector<2xf32> %log = math.log %val : vector<2xf32> %add = arith.addf %log, %iter : vector<2xf32> @@ -164,8 +161,8 @@ module { } } -// CHECK-DAG: #[[$MAP0:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 2), -// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 + 1), +// CHECK-DAG: #[[$MAP0:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 2), +// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 + 1), // CHECK-LABEL: @pipeline_transfer // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir index 8959fbb826bdda..9ffd7bdc0fbfd1 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir @@ -1,16 +1,9 @@ // RUN: mlir_fusions_opt -split-input-file %s -xla-gpu-peel-loops \ // RUN: | FileCheck %s -#map = #xla_gpu.indexing_map< - (d0)[s0, s1] -> (s0, s1), - domain: - d0 in [0, 3], - s0 in [0, 7], - s1 in [0, 10], - d0 + s0 in [0, 9], - d0 + s1 in [0, 12], - is_simplified: false -> +#map = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, s1), domain:" + "d0 in [0, 3], s0 in [0, 7], s1 in [0, 10], d0 + s0 in [0, 9]," + "d0 + s1 in [0, 12]"> func.func @peel_both_loops(%input: tensor<16x32xf32>, %init: f32, %dim: index) -> (f32) { %sum = xla_gpu.loop (%dim)[%i, %j] -> (%r0, %r1) @@ -21,9 +14,9 @@ func.func @peel_both_loops(%input: tensor<16x32xf32>, } func.return %sum : f32 } -// CHECK: #[[$PEELED_MAP:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 6], s1 in [0, 9], is_simplified: true> -// CHECK: #[[$TAIL_MAP0:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (7, s1), domain: d0 in [0, 2], s0 in [7, 7], s1 in [0, 9], is_simplified: true> -// CHECK: #[[$TAIL_MAP1:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0, 10), domain: d0 in [0, 2], s0 in [0, 7], s1 in [10, 10], is_simplified: true> +// CHECK: #[[$PEELED_MAP:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 6], s1 in [0, 9]"> +// CHECK: #[[$TAIL_MAP0:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (7, s1), domain: d0 in [0, 2], s0 in [7, 7], s1 in [0, 9]"> +// CHECK: #[[$TAIL_MAP1:.*]] = #xla_gpu.indexing_map<"(d0)[s0, s1] -> (s0, 10), domain: d0 in [0, 2], s0 in [0, 7], s1 in [10, 10]"> // CHECK-LABEL: func.func @peel_both_loops( // CHECK-SAME: %[[INPUT:.*]]: tensor<16x32xf32>, @@ -48,13 +41,8 @@ func.func @peel_both_loops(%input: tensor<16x32xf32>, // ----- -#map = #xla_gpu.indexing_map< - (d0)[s0] -> (s0), - domain: - d0 in [0, 3], - s0 in [0, 7], - is_simplified: false -> +#map = #xla_gpu.indexing_map<"(d0)[s0] -> (s0)," + "domain: d0 in [0, 3], s0 in [0, 7]"> func.func @not_constrained_symbol(%input: tensor<16xf32>, %init: f32, %dim: index) -> (f32) { %sum = xla_gpu.loop (%dim)[%i] -> (%r0) @@ -72,13 +60,11 @@ func.func @not_constrained_symbol(%input: tensor<16xf32>, %init: f32, // ----- #map = #xla_gpu.indexing_map< - (d0)[s0] -> (s0), - domain: - d0 in [0, 3], - s0 in [0, 7], - s0 mod 5 in [0, 1], - is_simplified: false -> +" (d0)[s0] -> (s0)," +" domain:" +" d0 in [0, 3]," +" s0 in [0, 7]," +" s0 mod 5 in [0, 1]"> func.func @constraint_exists_after_peeling(%input: tensor<16xf32>, %init: f32, %dim: index) -> (f32) { %sum = xla_gpu.loop (%dim)[%i] -> (%r0) @@ -91,4 +77,4 @@ func.func @constraint_exists_after_peeling(%input: tensor<16xf32>, %init: f32, } // CHECK-LABEL: func.func @constraint_exists_after_peeling // CHECK: xla_gpu.loop -// CHECK-NOT: xla_gpu.loop \ No newline at end of file +// CHECK-NOT: xla_gpu.loop diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/rewrite_reductions.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/rewrite_reductions.mlir index fe04e7f95463e0..5f8b9ba5413d84 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/rewrite_reductions.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/rewrite_reductions.mlir @@ -19,7 +19,7 @@ func.func @row_reduction(%arg0: tensor<128x1027xf32>) return %0 : tensor<128xf32> } -// CHECK: #[[$PAD_AND_RESHAPE:.*]] = #xla_gpu.indexing_map<(d0, d1, d2, d3) -> (d0, d1 * 128 + d2 * 32 + d3), +// CHECK: #[[$PAD_AND_RESHAPE:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2, d3) -> (d0, d1 * 128 + d2 * 32 + d3), // CHECK-SAME: domain: d0 in [0, 127], d1 in [0, 8], d2 in [0, 3], d3 in [0, 31], d1 * 128 + d2 * 32 + d3 in [0, 1026] // CHECK-LABEL: @row_reduction // CHECK-SAME: %[[IN:.*]]: tensor<128x1027xf32> @@ -37,8 +37,8 @@ func.func @add(%a: f32, %b: f32) -> f32 { return %0 : f32 } -func.func @row_reduction_with_major_reduced_dim(%arg0: tensor<1x42x128x32x8xf32>) - -> tensor<1x128xf32> attributes { +func.func @row_reduction_with_major_reduced_dim(%arg0: tensor<2x42x128x32x8xf32>) + -> tensor<2x128xf32> attributes { xla_gpu.launch_grid = #xla_gpu.launch_grid< block_counts = [42, 1, 1], thread_counts = [128, 1, 1] @@ -46,12 +46,48 @@ func.func @row_reduction_with_major_reduced_dim(%arg0: tensor<1x42x128x32x8xf32> } { %c0 = arith.constant 0.0 : f32 %0 = xla_gpu.reduce (%arg0) inits(%c0) dimensions=[1, 3, 4] combiner=@add - : tensor<1x42x128x32x8xf32> to tensor<1x128xf32> - return %0 : tensor<1x128xf32> + : tensor<2x42x128x32x8xf32> to tensor<2x128xf32> + return %0 : tensor<2x128xf32> } -// CHECK: %[[REINDEXED:.*]] = xla_gpu.reindex -// CHECK-SAME: : tensor<1x42x128x32x8xf32> -> tensor<1x42x128x2x4x32xf32> -// CHECK: xla_gpu.reduce(%[[REINDEXED]]) -// CHECK-SAME: dimensions=[1, 3] -// CHECK-SAME: : tensor<1x42x128x2x4x32xf32> +// CHECK-LABEL: @row_reduction_with_major_reduced_dim +// CHECK: %[[REINDEXED:.*]] = xla_gpu.reindex +// CHECK-SAME: : tensor<2x42x128x32x8xf32> -> tensor<2x42x128x2x4x32xf32> +// CHECK: xla_gpu.reduce(%[[REINDEXED]]) +// CHECK-SAME: dimensions=[1, 3] +// CHECK-SAME: : tensor<2x42x128x2x4x32xf32> + +// ----- + +func.func @add(%a: f32, %b: f32) -> f32 { + %0 = arith.addf %a, %b : f32 + return %0 : f32 +} + +func.func @column(%arg0: tensor<2x32x32xf32>) + -> tensor<2x32xf32> attributes { + xla_gpu.launch_grid = #xla_gpu.launch_grid< + block_counts = [42, 1, 1], + thread_counts = [128, 1, 1] + > + } { + %c0 = arith.constant 0.0 : f32 + %0 = xla_gpu.reduce (%arg0) inits(%c0) dimensions=[1] combiner=@add + : tensor<2x32x32xf32> to tensor<2x32xf32> + return %0 : tensor<2x32xf32> +} + +// CHECK: #[[$RESHAPE:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3) +// CHECK-SAME: d1 * 4 + d2 in [0, 31] +// CHECK: #[[$TRANSPOSE:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0, d2, d1) +// CHECK-LABEL: @column +// CHECK-SAME: %[[IN:.*]]: tensor<2x32x32xf32> +// CHECK: %[[C0:.*]] = arith.constant 0.00 +// CHECK: %[[REINDEXED:.*]] = xla_gpu.reindex %[[IN]] at #[[$RESHAPE]] default %[[C0]] +// CHECK-SAME: -> tensor<2x8x4x32xf32> +// CHECK: %[[R1:.*]] = xla_gpu.reduce(%[[REINDEXED]]) inits(%[[C0]]) dimensions=[1] +// CHECK-SAME: to tensor<2x4x32xf32> +// CHECK: %[[TRANSPOSED:.*]] = xla_gpu.reindex %[[R1]] at #[[$TRANSPOSE]] +// CHECK-SAME: -> tensor<2x32x4xf32> +// CHECK: %[[R2:.*]] = xla_gpu.reduce(%[[TRANSPOSED]]) inits(%[[C0]]) dimensions=[2] +// CHECK: return %[[R2]] : tensor<2x32xf32> diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir index db78b88abd51e0..e62a530de0e7db 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir @@ -63,8 +63,9 @@ func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.pt %1 = gpu.block_id x scf.for %i = %c0 to %c4 step %c1 { %2 = xla_gpu.apply_indexing - #xla_gpu.indexing_map<()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 + (s1 floordiv 128) + (s2 floordiv 4)), - domain: s0 in [0, 3071], s1 in [0, 127], s2 in [0, 3], is_simplified: false>[%1, %0, %i] + #xla_gpu.indexing_map< + "()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 + (s1 floordiv 128) + (s2 floordiv 4))," + "domain: s0 in [0, 3071], s1 in [0, 127], s2 in [0, 3]">[%1, %0, %i] %3 = arith.index_castui %2 : index to i64 %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32 %5 = llvm.load %4 invariant : !llvm.ptr -> f32 @@ -92,8 +93,9 @@ func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.pt func.func @arg_ranges(%arg0: index, %arg1: index) -> index { %0 = xla_gpu.apply_indexing - #xla_gpu.indexing_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100), - domain: s0 in [0, 42], s1 in [0, 1000], is_simplified: false>[%arg0, %arg1] + #xla_gpu.indexing_map< + "()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)," + "domain: s0 in [0, 42], s1 in [0, 1000]">[%arg0, %arg1] return %0 : index } @@ -106,8 +108,8 @@ func.func @arg_ranges(%arg0: index, %arg1: index) -> index { func.func @cant_lower(%arg0: index, %arg1: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing - #xla_gpu.indexing_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100, s0 + s1), - domain: s0 in [-10, 42], s1 in [0, 1000], is_simplified: false>[%arg0, %arg1] + #xla_gpu.indexing_map<"()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100, s0 + s1)," + "domain: s0 in [-10, 42], s1 in [0, 1000]">[%arg0, %arg1] return %0#0, %0#1 : index, index } @@ -124,8 +126,9 @@ func.func @order_summands(%arg1: index) { scf.for %arg2 = %c0 to %c4 step %c1 { scf.for %arg3 = %c0 to %c4 step %c1 { %0 = xla_gpu.apply_indexing - #xla_gpu.indexing_map<()[s0, s1, s2] -> ((s0 + s1) floordiv 3 + s0 * 512 + s1 * 4 + s2 * 10), - domain: s0 in [0, 3], s1 in [0, 3], s2 in [0, 3], is_simplified: false>[%arg2, %arg1, %arg3] + #xla_gpu.indexing_map< + "()[s0, s1, s2] -> ((s0 + s1) floordiv 3 + s0 * 512 + s1 * 4 + s2 * 10)," + "domain: s0 in [0, 3], s1 in [0, 3], s2 in [0, 3]">[%arg2, %arg1, %arg3] "dummy.op"(%0) : (index) -> () } } diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir index 5f776d0f338862..e6fea946e6e827 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir @@ -248,7 +248,8 @@ func.func @refine_constraints(%tensor: tensor<100xf32>) -> tensor<100xf32> { %c42_f32 = arith.constant 42.0 : f32 %loop = scf.for %i = %c0 to %c3 step %c1 iter_args(%in_ = %tensor) -> (tensor<100xf32>) { - %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 mod 4), domain: d0 in [0, 9], is_simplified: false>(%i) + %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<"(d0) -> (d0 mod 4)," + "domain: d0 in [0, 9]">(%i) %updated = tensor.insert %c42_f32 into %in_[%0] : tensor<100xf32> scf.yield %updated :tensor<100xf32> } @@ -262,10 +263,11 @@ func.func @refine_constraints(%tensor: tensor<100xf32>) -> tensor<100xf32> { // ----- -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768) mod 2400000), - domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 73], s1 in [0, 3], is_simplified: false> -#map1 = #xla_gpu.indexing_map<(d0, d1)[s0] -> ((d0 * 4 + d1 * 512 + s0) mod 9), - domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 3], is_simplified: false> +#map = #xla_gpu.indexing_map< + "(d0, d1)[s0, s1] -> (((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768) mod 2400000)," + "domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 73], s1 in [0, 3]"> +#map1 = #xla_gpu.indexing_map<"(d0, d1)[s0] -> ((d0 * 4 + d1 * 512 + s0) mod 9)," + "domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 3]"> func.func @refine_constraints_for_symbol(%arg0: tensor<2400000x9xf32>, %arg1: tensor<2400000x9xf32>) -> tensor<2400000x9xf32> { %c0 = arith.constant 0 : index @@ -289,5 +291,116 @@ func.func @refine_constraints_for_symbol(%arg0: tensor<2400000x9xf32>, } return %0 : tensor<2400000x9xf32> } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1, d2, d3) -> (d2 * 32768 + (d0 * 4 + d1 * 512 + d3) floordiv 9), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2, d3) -> (d2 * 32768 + (d0 * 4 + d1 * 512 + d3) floordiv 9), // CHECK-LABEL: func.func @refine_constraints_for_symbol + +// ----- + +#map = #xla_gpu.indexing_map< + "(d0, d1, d2, d3, d4, d5)[s0] -> ((d0 * 4 + s0) floordiv 6, (d0 * 4 + s0) mod 6)," + "domain:" + "d0 in [0, 29]," + "d1 in [0, 0]," + "d2 in [0, 0]," + "d3 in [0, 0]," + "d4 in [0, 0]," + "d5 in [0, 0]," + "s0 in [0, 3]," + "d0 * 4 + s0 in [0, 29]"> +func.func @dus(%arg0: tensor<20x30xf32>, %arg1: tensor<5x6xf32>, %arg2: i32, %arg3: i32, %arg4: tensor<20x30xf32>) -> tensor<20x30xf32> { + %c24 = arith.constant 24 : index + %c15 = arith.constant 15 : index + %c0 = arith.constant 0 : index + %thread_id_x = gpu.thread_id x + %thread_id_y = gpu.thread_id y + %thread_id_z = gpu.thread_id z + %block_id_x = gpu.block_id x + %block_id_y = gpu.block_id y + %block_id_z = gpu.block_id z + %0 = arith.index_cast %arg2 : i32 to index + %1 = arith.minsi %0, %c15 : index + %2 = arith.maxsi %1, %c0 : index + %3 = arith.index_cast %arg3 : i32 to index + %4 = arith.minsi %3, %c24 : index + %5 = arith.maxsi %4, %c0 : index + %xla_loop = xla_gpu.loop (%thread_id_x, %thread_id_y, %thread_id_z, %block_id_x, %block_id_y, %block_id_z)[%i] -> (%ra, %rb) in #map iter_args(%iter = %arg4) -> (tensor<20x30xf32>) { + %6 = arith.addi %2, %ra : index + %7 = arith.addi %5, %rb : index + %extracted = tensor.extract %arg1[%ra, %rb] : tensor<5x6xf32> + %inserted = tensor.insert %extracted into %iter[%6, %7] : tensor<20x30xf32> + xla_gpu.yield %inserted : tensor<20x30xf32> + } + return %xla_loop : tensor<20x30xf32> +} + +// CHECK-LABEL: func.func @dus +// CHECK: arith.minsi +// CHECK-SAME: xla.range = [-9223372036854775808 : index, 15 : index] +// CHECK: arith.maxsi +// CHECK-SAME: xla.range = [0 : index, 15 : index] +// CHECK: arith.minsi +// CHECK-SAME: xla.range = [-9223372036854775808 : index, 24 : index] +// CHECK: arith.maxsi +// CHECK-SAME: xla.range = [0 : index, 24 : index] +// CHECK: xla_gpu.loop +// CHECK: arith.addi +// CHECK-SAME: xla.range = [0 : index, 19 : index] +// CHECK: arith.addi +// CHECK-SAME: xla.range = [0 : index, 29 : index] + +// ----- + +module { + func.func @annotate_range_abs_index(%v: i32) -> index { + %c0_i32 = arith.constant 0 : i32 + %0 = arith.cmpi sge, %v, %c0_i32 : i32 + %1 = arith.subi %c0_i32, %v : i32 + %2 = arith.select %0, %v, %1 : i32 + %3 = arith.index_cast %2 : i32 to index + return %3: index + } +} + +// CHECK-LABEL: @annotate_range_abs +// CHECK: arith.select +// CHECK-SAME: xla.range = [0 : index, 2147483647 : index] +// CHECK-NEXT: arith.index_cast +// CHECK-SAME: xla.range = [0 : index, 2147483647 : index] + +// ----- + +module { + func.func @annotate_range_abs_index(%v: i32 {xla.range = [-31 : i32, 17 : i32]}) -> index { + %c0_i32 = arith.constant 0 : i32 + %0 = arith.cmpi sge, %v, %c0_i32 : i32 + %1 = arith.subi %c0_i32, %v : i32 + %2 = arith.select %0, %v, %1 : i32 + %3 = arith.index_cast %2 : i32 to index + return %3: index + } +} + +// CHECK-LABEL: @annotate_range_abs +// CHECK: arith.select +// CHECK-SAME: xla.range = [0 : index, 31 : index] +// CHECK-NEXT: arith.index_cast +// CHECK-SAME: xla.range = [0 : index, 31 : index] + +// ----- + +module { + func.func @annotate_range_abs_index(%v: i32 {xla.range = [-5 : i32, 3 : i32]}) -> index { + %c0_i32 = arith.constant 0 : i32 + %0 = arith.cmpi sge, %v, %c0_i32 : i32 + %1 = arith.subi %c0_i32, %v : i32 + %2 = arith.select %0, %v, %1 : i32 + %3 = arith.index_cast %2 : i32 to index + return %3: index + } +} + +// CHECK-LABEL: @annotate_range_abs +// CHECK: arith.select +// CHECK-SAME: xla.range = [0 : index, 5 : index] +// CHECK-NEXT: arith.index_cast +// CHECK-SAME: xla.range = [0 : index, 5 : index] diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir index c77d035e6271b3..b9c2b1b4086278 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir @@ -1,8 +1,8 @@ // RUN: mlir_fusions_opt -allow-unregistered-dialect %s -split-input-file \ // RUN: -xla-gpu-vectorize-loads-stores -cse -canonicalize | FileCheck %s -#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0), - domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true> +#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 2 + s0)," + "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -20,7 +20,7 @@ func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) { } return %outer : f32 } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 2), domain: d0 in [0, 63], is_simplified: true> +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 2), domain: d0 in [0, 63]"> // CHECK-LABEL: @simple_read // CHECK-SAME: (%[[ARG0:.*]]: tensor // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index @@ -36,8 +36,78 @@ func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) { // ----- -#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0 + 1), - domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true> +#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 4 + s0)," + "domain: d0 in [0, 63], s0 in [0, 3]"> +func.func @simple_read(%arg0: tensor<256xf16>) -> (f16) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %cst = arith.constant 0.0 : f16 + %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f16 { + %inner = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter1 = %iter) -> f16 { + %idx = xla_gpu.apply_indexing #map(%i)[%j] + %extracted = tensor.extract %arg0[%idx] : tensor<256xf16> + %added = arith.addf %iter1, %extracted : f16 + scf.yield %added : f16 + } + scf.yield %inner : f16 + } + return %outer : f16 +} +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 4), domain: d0 in [0, 63]"> +// CHECK-LABEL: @simple_read +// CHECK-SAME: (%[[ARG0:.*]]: tensor +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index +// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C64]] step %[[C1]] iter_args(%[[ITER:.*]] = +// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]]) +// CHECK-NEXT: %[[V:.*]] = vector.transfer_read %[[ARG0]][%[[BASE]]] +// CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] +// CHECK-NEXT: vector.extract %[[V]][%[[J]]] +// CHECK-NEXT: addf + +// ----- + +#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 8 + s0)," + "domain: d0 in [0, 63], s0 in [0, 7]"> +func.func @simple_read(%arg0: tensor<512xi8>) -> (i8) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c64 = arith.constant 64 : index + %cst = arith.constant 0 : i8 + %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> i8 { + %inner = scf.for %j = %c0 to %c8 step %c1 iter_args(%iter1 = %iter) -> i8 { + %idx = xla_gpu.apply_indexing #map(%i)[%j] + %extracted = tensor.extract %arg0[%idx] : tensor<512xi8> + %added = arith.addi %iter1, %extracted : i8 + scf.yield %added : i8 + } + scf.yield %inner : i8 + } + return %outer : i8 +} +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 8), domain: d0 in [0, 63]"> +// CHECK-LABEL: @simple_read +// CHECK-SAME: (%[[ARG0:.*]]: tensor +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index +// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C64]] step %[[C1]] iter_args(%[[ITER:.*]] = +// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]]) +// CHECK-NEXT: %[[V:.*]] = vector.transfer_read %[[ARG0]][%[[BASE]]] +// CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] +// CHECK-NEXT: vector.extract %[[V]][%[[J]]] +// CHECK-NEXT: addi + +// ----- + +#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 2 + s0 + 1)," + "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @misaligned_indexing_map(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -60,8 +130,8 @@ func.func @misaligned_indexing_map(%arg0: tensor<128xf32>) -> (f32) { // ----- -#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 3 + s0), - domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true> +#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 3 + s0)," + "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @misaligned_indexing_map_2(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -84,8 +154,8 @@ func.func @misaligned_indexing_map_2(%arg0: tensor<128xf32>) -> (f32) { // ----- -#map = #xla_gpu.indexing_map<(d0)[s0] -> (3 * d0 + s0), - domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true> +#map = #xla_gpu.indexing_map<"(d0)[s0] -> (3 * d0 + s0)," + "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @misaligned_shape(%arg0: tensor<192xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -108,8 +178,8 @@ func.func @misaligned_shape(%arg0: tensor<192xf32>) -> (f32) { // ----- -#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 * 2), - domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true> +#map = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 + s0 * 2)," + "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @wrong_stride(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -134,8 +204,8 @@ func.func @wrong_stride(%arg0: tensor<128xf32>) -> (f32) { // We could vectorize this as a float vector load of double the size, but we // don't currently. -#map = #xla_gpu.indexing_map<(d0)[s0] -> (2 * d0 + s0), - domain: d0 in [0, 127], s0 in [0, 1], is_simplified: true> +#map = #xla_gpu.indexing_map<"(d0)[s0] -> (2 * d0 + s0)," + "domain: d0 in [0, 127], s0 in [0, 1]"> func.func @simple_read_complex(%arg0: tensor<128xcomplex>, %i: index) -> (complex) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -250,10 +320,11 @@ func.func @write_not_yielded(%arg0: tensor<64xf32>) -> tensor<64xf32> { // ----- -#map = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512), - domain: d0 in [0, 7], d1 in [0, 255], s0 in [0, 7], is_simplified: true> -#map1 = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0 * 32 + d2 * 2 + d1 + s0 * 512), - domain: d0 in [0, 7], d1 in [0, 1], d2 in [0, 255], s0 in [0, 7], is_simplified: true> +#map = #xla_gpu.indexing_map<"(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512)," + "domain: d0 in [0, 7], d1 in [0, 255], s0 in [0, 7]"> +#map1 = #xla_gpu.indexing_map< + "(d0, d1, d2)[s0] -> (d0 * 32 + d2 * 2 + d1 + s0 * 512)," + "domain: d0 in [0, 7], d1 in [0, 1], d2 in [0, 255], s0 in [0, 7]"> func.func @multiple(%arg0: tensor<131072xf32>, %arg1: tensor<4096xbf16>, %arg2: tensor<32xf32>, %arg3: tensor<131072xf32>, %arg4: index) -> (tensor<131072xf32>, f32) { @@ -280,8 +351,8 @@ func.func @multiple(%arg0: tensor<131072xf32>, %arg1: tensor<4096xbf16>, } return %0#0, %0#1 : tensor<131072xf32>, f32 } -// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 2 + d1 * 512), domain: d0 in [0, 255], d1 in [0, 7], is_simplified: true> -// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1, d2) -> (d0 * 32 + d1 * 2 + d2 * 512), domain: d0 in [0, 7], d1 in [0, 255], d2 in [0, 7], is_simplified: true> +// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0, d1) -> (d0 * 2 + d1 * 512), domain: d0 in [0, 255], d1 in [0, 7]"> +// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<"(d0, d1, d2) -> (d0 * 32 + d1 * 2 + d2 * 512), domain: d0 in [0, 7], d1 in [0, 255], d2 in [0, 7]"> // CHECK-LABEL: @multiple // CHECK-SAME: (%[[ARG0:.*]]: tensor{{.*}}, %[[ARG1:.*]]: tensor{{.*}}, %[[ARG2:.*]]: tensor{{.*}}, %[[ARG3:.*]]: tensor{{.*}}, %[[ARG4:.*]]: index) // CHECK: %[[C0:.*]] = arith.constant 0 : index @@ -304,8 +375,8 @@ func.func @multiple(%arg0: tensor<131072xf32>, %arg1: tensor<4096xbf16>, // ----- -#map = #xla_gpu.indexing_map<(d0)[s0] -> ((d0 * 4) mod 64 + s0), - domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true> +#map = #xla_gpu.indexing_map<"(d0)[s0] -> ((d0 * 4) mod 64 + s0)," + "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @remainder_with_modulo(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -323,7 +394,7 @@ func.func @remainder_with_modulo(%arg0: tensor<128xf32>) -> (f32) { } return %outer : f32 } -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> ((d0 mod 16) * 4), +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> ((d0 mod 16) * 4), // CHECK-LABEL: @remainder_with_modulo // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: scf.for %[[I:.*]] = %[[C0]] @@ -332,8 +403,8 @@ func.func @remainder_with_modulo(%arg0: tensor<128xf32>) -> (f32) { // ----- -#map = #xla_gpu.indexing_map<(d0)[s0] -> ((d0 * 4) mod 65 + s0), - domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true> +#map = #xla_gpu.indexing_map<"(d0)[s0] -> ((d0 * 4) mod 65 + s0)," + "domain: d0 in [0, 63], s0 in [0, 1]"> func.func @remainder_with_modulo_misaligned(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index @@ -356,10 +427,10 @@ func.func @remainder_with_modulo_misaligned(%arg0: tensor<128xf32>) -> (f32) { // ----- -#map0 = #xla_gpu.indexing_map<(d0) -> (d0 + 5), - domain: d0 in [0, 63], is_simplified: true> -#map1 = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0), - domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true> +#map0 = #xla_gpu.indexing_map<"(d0) -> (d0 + 5)," + "domain: d0 in [0, 63]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 2 + s0)," + "domain: d0 in [0, 63], s0 in [0, 1]"> module { func.func @apply_indexing_sequence(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index @@ -381,8 +452,8 @@ module { } } -// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 2 + 10), -// CHECK-SAME: domain: d0 in [0, 63], is_simplified: true> +// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 2 + 10), +// CHECK-SAME: domain: d0 in [0, 63]"> // CHECK-LABEL: @apply_indexing_sequence // CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP0]] // CHECK: vector.transfer_read {{.*}}[%[[BASE]]] @@ -390,10 +461,10 @@ module { // ----- -#map0 = #xla_gpu.indexing_map<(d0) -> (d0 + 5), - domain: d0 in [0, 63], is_simplified: true> -#map1 = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0), - domain: d0 in [0, 63], s0 in [0, 1], is_simplified: true> +#map0 = #xla_gpu.indexing_map<"(d0) -> (d0 + 5)," + "domain: d0 in [0, 63]"> +#map1 = #xla_gpu.indexing_map<"(d0)[s0] -> (d0 * 2 + s0)," + "domain: d0 in [0, 63], s0 in [0, 1]"> module { func.func @apply_indexing_sequence_same_block(%arg0: tensor<128xf32>) -> (f32) { %c0 = arith.constant 0 : index @@ -418,4 +489,4 @@ module { } // CHECK-LABEL: @apply_indexing_sequence_same_block -// CHECK-NOT: vector.transfer_read \ No newline at end of file +// CHECK-NOT: vector.transfer_read diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/vectorize_loads_stores.cc b/third_party/xla/xla/service/gpu/fusions/transforms/vectorize_loads_stores.cc index 9795a96e387f53..34e90b1ebb3368 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/vectorize_loads_stores.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/vectorize_loads_stores.cc @@ -139,7 +139,7 @@ mlir::VectorType GetVectorType(mlir::RankedTensorType tensor_type, } std::optional vector_size = mlir::getConstantIntValue(loop.getUpperBound()); - if (vector_size != 2 && vector_size != 4) { + if (vector_size != 2 && vector_size != 4 && vector_size != 8) { return nullptr; // Unsupported vector size. } if (tensor_type.getShape().back() % *vector_size) { diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc index fd18cef310a8fb..9be7989dc1c87b 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/MLIRContext.h" @@ -211,8 +212,10 @@ IndexingMap MlirTransposeFusion::GetSharedMemoryIndexing( std::vector dim_var_sizes(6, 1); dim_var_sizes[KernelFusionInterface::kIndexingMapThreadIdxDims[0]] = kNumThreadsPerBlock; + dim_var_sizes[KernelFusionInterface::kIndexingMapBlockIdxDims[0]] = + Product(block_counts_); return {mlir::AffineMap::get(6, 2, thread_offsets, ctx), - DimVarsFromTensorSizes(dim_var_sizes), + DimVarsFromGPUGrid(dim_var_sizes), RangeVarsFromTensorSizes({block_size_ / kNumRows, vector_size_}), {}}; } @@ -233,44 +236,72 @@ MlirTransposeFusion::WriteResult MlirTransposeFusion::EmitWriteToShMemMlir( } else { ++shmem_tensor_size.back(); } + int num_inputs = fusion.fused_instructions_computation()->num_parameters(); + SmallVector callee_operands( + entry_function.getArguments().take_front(num_inputs)); + auto tids_and_bids = EmitThreadAndBlockIds(builder); + auto identity_map = + IndexingMapAttr::get(ctx, CreateIdentityMap(shmem_tensor_size, ctx)); + + // We can assume that all transpose operands have the same shape. + Shape operand_shape = shmem_transposes_.front()->operand(0)->shape(); - // Allocate shared memory. - SmallVector inits; + // Indexing for MaterializeOp to read from input. + auto indexing = GetIndexing(/*input=*/true, operand_shape, ctx); + + // Indexing for InsertOp to write into shared memory. + IndexingMap write_indexing = GetSharedMemoryIndexing(/*read=*/false, ctx); + // As we are writing the same elements that we are reading, any read + // constraints can also be constraints for the write. + for (auto constraint : indexing.GetConstraints()) { + write_indexing.AddConstraint(constraint.first, constraint.second); + } + for (auto [index, bound] : llvm::enumerate(indexing.GetSymbolBounds())) { + write_indexing.GetMutableSymbolBound(index) = bound; + } + write_indexing.Simplify(); + auto dimensions = SmallVector(operand_shape.dimensions().begin(), + operand_shape.dimensions().end()); + SmallVector shmem_tensors; for (auto* transpose : shmem_transposes_) { auto elem_type = mlir_converter::PrimitiveTypeToMlirType( transpose->shape().element_type(), builder); - inits.push_back(builder.create( - RankedTensorType::get(shmem_tensor_size, elem_type))); + auto shmem = builder.create( + RankedTensorType::get(shmem_tensor_size, elem_type)); + auto indexed_vector = + IndexedVectorType::get(ctx, shmem_tensor_size, elem_type, + IndexingMapAttr::get(ctx, write_indexing)); + auto callee = + mlir::SymbolRefAttr::get(call_target_provider(transpose->operand(0))); + + auto materialized = builder.create( + /* result_type=*/indexed_vector, + /*input=*/callee_operands, + /*indices(dimensions)=*/tids_and_bids, + /*callee=*/callee, + /*map=*/IndexingMapAttr::get(ctx, indexing)); + + auto insert = builder.create( + /*result_type=*/shmem.getType(), + /*source=*/materialized.getResult(), + /*indices(dimensions)=*/tids_and_bids, + /*dest=*/shmem, + /*map=*/identity_map); + shmem_tensors.push_back(insert.getResult()); } - // Add output arguments for side outputs. - int num_inputs = fusion.fused_instructions_computation()->num_parameters(); + // Produce all side outputs and then write them. + SmallVector side_output_inits; for (int index : side_output_root_indices_) { - inits.push_back(entry_function.getArgument(num_inputs + index)); + side_output_inits.push_back(entry_function.getArgument(num_inputs + index)); } - - IndexingMap write_indexing = GetSharedMemoryIndexing(/*read=*/false, ctx); auto body_builder = [&](ValueRange symbol_values, ValueRange map_results, ValueRange output_tensors) -> SmallVector { auto input_indices = [&](const HloInstruction* instr) { return ApplyIndexing(GetIndexing(/*input=*/true, instr->shape(), ctx), thread_and_block_ids, symbol_values, builder); }; - SmallVector result_tensors; - auto shmem_indices = ApplyIndexing(write_indexing, thread_and_block_ids, - symbol_values, builder); - for (auto [transpose, output] : - llvm::zip(shmem_transposes_, output_tensors)) { - // Emit loop that writes subgraphs of transpose operands to shmem. - auto result_scalar = mlir_converter::ProvideParameter( - root_computation, transpose, - /*operand_index=*/0, input_indices(transpose->operand(0)), - call_target_provider, entry_function, builder)[0]; - result_tensors.push_back(builder.create( - result_scalar, output, shmem_indices)); - } - // Produce all side outputs and then write them. SmallVector side_outputs; SmallVector> side_output_indices; auto* root_tuple = fusion.fused_expression_root(); @@ -283,22 +314,21 @@ MlirTransposeFusion::WriteResult MlirTransposeFusion::EmitWriteToShMemMlir( side_outputs.append(param_values.begin(), param_values.end()); } + SmallVector result_tensors; for (const auto& [value, indices, output] : - llvm::zip(side_outputs, side_output_indices, - output_tensors.take_back(side_output_roots_.size()))) { + llvm::zip(side_outputs, side_output_indices, output_tensors)) { result_tensors.push_back( builder.create(value, output, indices)); } return result_tensors; }; - - auto indexing = GetIndexing( - /*input=*/true, shmem_transposes_.front()->operand(0)->shape(), ctx); - auto written_vector = mlir_converter::EmitXlaLoopOp( - builder, thread_and_block_ids, inits, indexing, body_builder); - ValueRange written = written_vector; - auto shmem_tensors = written.take_front(shmem_transposes_.size()); + mlir::ValueRange side_output_vector; + if (!side_output_inits.empty()) { + side_output_vector = mlir_converter::EmitXlaLoopOp( + builder, thread_and_block_ids, side_output_inits, indexing, + body_builder); + } WriteResult result; result.shmem_tensors = @@ -307,8 +337,7 @@ MlirTransposeFusion::WriteResult MlirTransposeFusion::EmitWriteToShMemMlir( .getResults(); result.updated_outputs = output_args; for (auto [index, side_output_result] : - llvm::zip(side_output_root_indices_, - written.take_back(side_output_roots_.size()))) { + llvm::zip(side_output_root_indices_, side_output_vector)) { result.updated_outputs[index] = side_output_result; } return result; @@ -362,9 +391,8 @@ std::vector MlirTransposeFusion::GetEpilogues(const HloFusionInstruction& fusion, MLIRContext* mlir_context) const { std::vector epilogues{ - mlir_converter::EpilogueSpecification::FromOutputIndexing( - analysis_, shmem_transposes_, shmem_transpose_roots_, *this, - mlir_context)}; + GetEpilogueForOutputIndexing(analysis_, shmem_transposes_, + shmem_transpose_roots_, mlir_context)}; // Add empty epilogues for the side outputs. This ensures their roots don't // get "fused" into the tuple function. for (const auto* root : side_output_roots_) { diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h index 9602242fe4745a..a3f040008ae479 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h @@ -23,12 +23,14 @@ limitations under the License. #include "absl/status/status.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" @@ -36,6 +38,7 @@ limitations under the License. #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/shape.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/fusions/triton.cc b/third_party/xla/xla/service/gpu/fusions/triton.cc index 7d235c132989c4..c7f35de5398045 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton.cc @@ -37,8 +37,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h" +#include "xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/ir_emitter_context.h" diff --git a/third_party/xla/xla/service/gpu/fusions/triton/BUILD b/third_party/xla/xla/service/gpu/fusions/triton/BUILD index 7b3eeff1028b6c..011ea8b3cd3884 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/triton/BUILD @@ -1,4 +1,5 @@ -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") load("@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") load("//xla:xla.bzl", "xla_cc_test") @@ -22,6 +23,39 @@ package_group( ], ) +cc_library( + name = "emitter_helpers", + srcs = ["emitter_helpers.cc"], + hdrs = ["emitter_helpers.h"], + deps = [ + "//xla:literal", + "//xla:shape_util", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/mlir_hlo", + "//xla/mlir_hlo:map_mhlo_to_scalar_op", + "//xla/mlir_hlo:transformation_helpers", + "//xla/service/gpu:target_util", + "//xla/service/llvm_ir:llvm_util", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + "@triton//:TritonDialects", + ], +) + cc_library( name = "triton_fusion_emitter", srcs = if_gpu_is_configured( @@ -34,11 +68,11 @@ cc_library( ]), hdrs = ["triton_fusion_emitter.h"], deps = [ + ":emitter_helpers", ":passes", + ":triton_fusion_emitter_legacy_matmul", + ":triton_support", "//xla:autotuning_proto_cc", - "//xla:comparison_util", - "//xla:debug_options_flags", - "//xla:literal", "//xla:permutation_util", "//xla:shape_util", "//xla:status_macros", @@ -46,10 +80,8 @@ cc_library( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", + "//xla/hlo/translate/hlo_to_mhlo:hlo_function_importer", "//xla/mlir_hlo", - "//xla/mlir_hlo:map_mhlo_to_scalar_op", - "//xla/service:algorithm_util", "//xla/service:dump", "//xla/service:hlo_module_config", "//xla/service:instruction_fusion", @@ -58,38 +90,29 @@ cc_library( "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:matmul_utils", - "//xla/service/gpu:target_util", "//xla/service/gpu:triton_fusion_analysis", - "//xla/service/gpu:triton_tiling_propagation", "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/transforms:passes", - "//xla/service/gpu/llvm_gpu_backend", - "//xla/service/gpu/model:affine_map_printer", "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/model:symbolic_tile_analysis", - "//xla/service/gpu/model:symbolic_tiled_hlo_instruction", "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", "//xla/service/gpu/model:triton_emitter_constraints", "//xla/service/llvm_ir:llvm_util", "//xla/stream_executor:device_description", "//xla/stream_executor:launch_dim", "//xla/tools:hlo_decomposer_lib", - "//xla/translate/hlo_to_mhlo:hlo_function_importer", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Linker", "@llvm-project//llvm:Support", - "@llvm-project//llvm:TargetParser", "@llvm-project//llvm:ir_headers", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineToStandard", @@ -98,30 +121,23 @@ cc_library( "@llvm-project//mlir:BuiltinToLLVMIRTranslation", "@llvm-project//mlir:ControlFlowToLLVM", "@llvm-project//mlir:ExecutionEngineUtils", - "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", "@llvm-project//mlir:IndexToLLVM", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:LLVMIRTransforms", "@llvm-project//mlir:LLVMToLLVMIRTranslation", - "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:NVVMToLLVMIRTranslation", "@llvm-project//mlir:Pass", "@llvm-project//mlir:ROCDLToLLVMIRTranslation", - "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:SCFToControlFlow", "@llvm-project//mlir:Support", "@llvm-project//mlir:ToLLVMIRTranslation", "@llvm-project//mlir:Transforms", - "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:tensor_float_32_utils", "@triton//:TritonDialects", "@triton//:TritonTransforms", ] + if_gpu_is_configured([ @@ -132,14 +148,124 @@ cc_library( "@triton//:TritonLLVMIR", ]) + if_cuda_is_configured([ "@triton//third_party/nvidia:NVGPUToLLVM", + "//xla/service/gpu/llvm_gpu_backend:nvptx_libdevice_path", "@triton//third_party/nvidia:TritonNVIDIAGPUToLLVM", ]) + if_rocm_is_configured([ "@local_tsl//tsl/platform:rocm_rocdl_path", + "//xla/service/gpu/llvm_gpu_backend:llvm_gpu_backend", "@triton//third_party/amd:TritonAMDGPUToLLVM", "@triton//third_party/amd:TritonAMDGPUTransforms", ]), ) +cc_library( + name = "triton_fusion_emitter_legacy_matmul", + srcs = if_gpu_is_configured( + ["triton_fusion_emitter_legacy_matmul.cc"], + ["triton_fusion_emitter_legacy_matmul_stub.cc"], + ), + hdrs = ["triton_fusion_emitter_legacy_matmul.h"], + deps = [ + "//xla:comparison_util", + "//xla:literal", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/mlir_hlo", + "//xla/mlir_hlo:map_mhlo_to_scalar_op", + "//xla/mlir_hlo:transformation_helpers", + "//xla/service:algorithm_util", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu:matmul_indexing_utils", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu:target_util", + "//xla/service/gpu:triton_fusion_analysis", + "//xla/service/gpu:triton_tiling_propagation", + "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", + "//xla/service/llvm_ir:llvm_util", + "//xla/stream_executor:device_description", + "//xla/stream_executor:launch_dim", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", + "@triton//:TritonDialects", + ], +) + +cc_library( + name = "triton_fusion_emitter_stub_for_testing", + srcs = [ + "triton_fusion_emitter_legacy_matmul_stub.cc", + "triton_fusion_emitter_stub.cc", + ], + hdrs = [ + "triton_fusion_emitter.h", + "triton_fusion_emitter_legacy_matmul.h", + ], + deps = [ + "//xla:autotuning_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu:triton_fusion_analysis", + "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", + "//xla/stream_executor:device_description", + "//xla/stream_executor:launch_dim", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@triton//:TritonDialects", + ], +) + +xla_cc_test( + name = "triton_fusion_emitter_stub_test", + srcs = ["triton_fusion_emitter_stub_test.cc"], + deps = [ + ":triton_fusion_emitter_stub_for_testing", + "//xla:literal", + "//xla:literal_util", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@local_tsl//tsl/platform:test", + ], +) + gentbl_cc_library( name = "passes_inc_gen", tbl_outs = [ @@ -194,9 +320,97 @@ cc_library( ], ) +td_library( + name = "xla_triton_td_files", + srcs = glob(["*.td"]), + includes = ["."], + deps = [ + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + ], +) + +gentbl_cc_library( + name = "xla_triton_dialect_inc_gen", + strip_include_prefix = ".", + tbl_outs = [ + ( + [ + "-gen-dialect-decls", + ], + "xla_triton_dialect.h.inc", + ), + ( + [ + "-gen-dialect-defs", + ], + "xla_triton_dialect.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "xla_triton_dialect.td", + deps = [":xla_triton_td_files"], +) + +gentbl_cc_library( + name = "xla_triton_ops_inc_gen", + strip_include_prefix = ".", + tbl_outs = [ + ( + [ + "-gen-op-decls", + ], + "xla_triton_ops.h.inc", + ), + ( + [ + "-gen-op-defs", + ], + "xla_triton_ops.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "xla_triton_ops.td", + deps = [ + ":xla_triton_td_files", + "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:SideEffectInterfacesTdFiles", + "@triton//:td_files", + ], +) + +cc_library( + name = "xla_triton", + srcs = ["xla_triton_ops.cc"], + hdrs = ["xla_triton_ops.h"], + deps = [ + ":xla_triton_dialect_inc_gen", + ":xla_triton_ops_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:BytecodeOpInterface", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:SideEffectInterfaces", + "@triton//:TritonDialects", + "@triton//:triton_gpu_attr_inc_gen", + "@triton//:triton_gpu_types_inc_gen", + "@triton//:triton_ops_inc_gen", + ], +) + xla_test( name = "triton_fusion_emitter_device_legacy_test", srcs = if_gpu_is_configured(["triton_fusion_emitter_device_legacy_test.cc"]), + # TODO(b/372714955): Fix the memory leak! + backend_args = if_google( + { + "gpu_h100": ["--heap_check="], + "gpu_a100": ["--heap_check="], + }, + {}, + ), backends = [ "gpu_a100", "gpu_h100", @@ -211,10 +425,10 @@ xla_test( ":triton_test_utils", "//xla:autotuning_proto_cc", "//xla:error_spec", - "//xla:literal", - "//xla:literal_util", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:filecheck", + "//xla/hlo/testlib:verified_hlo_module", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/service/gpu:backend_configs_cc", @@ -222,15 +436,9 @@ xla_test( "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", - "//xla/stream_executor/cuda:cublas_plugin", - "//xla/tests:filecheck", - "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/lib/core:status_test_util", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", "@llvm-project//llvm:ir_headers", "@llvm-project//mlir:IR", @@ -238,13 +446,55 @@ xla_test( "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], ) +xla_test( + name = "dot_algorithms_test", + srcs = if_gpu_is_configured(["dot_algorithms_test.cc"]), + backend_args = if_google( + { + "gpu_h100": ["--heap_check="], + "gpu_a100": ["--heap_check="], + }, + {}, + ), + backends = [ + "gpu_a100", + "gpu_h100", + "gpu_amd_any", + ], + tags = [ + "no_mac", + ], + deps = [ + ":kernel_name_tracer", + ":triton_test_utils", + "//xla:autotuning_proto_cc", + "//xla:error_spec", + "//xla:literal", + "//xla:literal_util", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:filecheck", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu/tests:gpu_codegen_test", + "//xla/stream_executor:device_description", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + xla_test( name = "triton_fusion_emitter_device_test", srcs = if_gpu_is_configured(["triton_fusion_emitter_device_test.cc"]), @@ -261,6 +511,8 @@ xla_test( ":triton_test_utils", "//xla:autotuning_proto_cc", "//xla:error_spec", + "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/service/gpu:backend_configs_cc", @@ -268,7 +520,6 @@ xla_test( "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", - "//xla/stream_executor/cuda:cublas_plugin", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/lib/core:status_test_util", @@ -285,6 +536,38 @@ xla_test( ], ) +cc_library( + name = "kernel_name_tracer_cuda", + testonly = True, + srcs = if_cuda(["kernel_name_tracer_cuda.cc"]), + hdrs = ["kernel_name_tracer.h"], + tags = ["manual"], # Need to exclude this from wildcard builds + deps = [ + "//xla/backends/profiler/gpu:cupti_collector", + "//xla/backends/profiler/gpu:cupti_tracer", + "//xla/tsl/profiler/utils:time_utils", + "@com_google_absl//absl/algorithm:container", + ], +) + +cc_library( + name = "kernel_name_tracer_noop", + testonly = True, + srcs = ["kernel_name_tracer_noop.cc"], + hdrs = ["kernel_name_tracer.h"], + tags = ["manual"], # Need to exclude this from wildcard builds +) + +cc_library( + name = "kernel_name_tracer", + testonly = True, + hdrs = ["kernel_name_tracer.h"], + deps = if_cuda( + [":kernel_name_tracer_cuda"], + [":kernel_name_tracer_noop"], + ), +) + cc_library( name = "triton_test_utils", testonly = True, @@ -296,8 +579,8 @@ cc_library( "//xla:status_macros", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/transforms:float_normalization", "//xla/hlo/utils:hlo_query", - "//xla/service:float_normalization", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:gpu_float_support", @@ -401,7 +684,6 @@ xla_test( "//xla/hlo/ir:hlo", "//xla/service/gpu/tests:gpu_codegen_test", "//xla/stream_executor:device_description", - "//xla/stream_executor/cuda:cublas_plugin", "//xla/tests:xla_internal_test_main", # fixdeps: keep "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", @@ -434,6 +716,7 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/platform:tensor_float_32_utils", ], ) @@ -480,6 +763,7 @@ xla_test( tags = ["no_mac", "no_rocm"], # TODO(rocm) 240729 deps = [ + ":kernel_name_tracer", ":triton_fusion_emitter", ":triton_support", ":triton_test_utils", diff --git a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc index 46a569d265bcdd..7b5a6bc26faea6 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc @@ -25,7 +25,7 @@ limitations under the License. #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" #include "xla/service/gpu/fusions/triton/passes.h" -#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "xla/service/gpu/llvm_gpu_backend/nvptx_libdevice_path.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/hlo_module_config.h" #include "xla/stream_executor/device_description.h" @@ -58,6 +58,7 @@ absl::Status CreateTritonPipeline( pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createLoopInvariantCodeMotionPass()); pm.addPass(mlir::createSymbolDCEPass()); + pm.addPass(mt::createLoopUnrollPass()); // Based on make_ttgir() in // @triton//:third_party/nvidia/backend/compiler.py diff --git a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc index 2a95ea833f4bcc..c9e6d553cd5d98 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc @@ -70,6 +70,7 @@ absl::Status CreateTritonPipeline( pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createLoopInvariantCodeMotionPass()); pm.addPass(mlir::createSymbolDCEPass()); + pm.addPass(mt::createLoopUnrollPass()); // Based on make_ttgir() in // @triton//:third_party/amd/backend/compiler.py @@ -90,12 +91,15 @@ absl::Status CreateTritonPipeline( pm.addPass(mlir::createTritonAMDGPUStreamPipelinePass()); pm.addPass(mlir::createCanonicalizerPass()); } + pm.addPass(mt::createInsertInstructionSchedHintsPass()); pm.addPass(mt::gpu::createTritonGPUOptimizeDotOperands({true})); pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); pm.addPass(mt::gpu::createTritonGPUReduceDataDuplication()); if (block_level_parameters.num_stages != kAmdDoubleBuffering) { pm.addPass(mt::gpu::createTritonGPUReorderInstructions()); } + pm.addPass(mlir::createTritonAMDGPUCanonicalizePointersPass()); + pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createSymbolDCEPass()); @@ -119,6 +123,7 @@ absl::Status CreateTritonPipeline( pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); pm.addPass(mlir::createSymbolDCEPass()); + pm.addPass(mt::createLowerInstructionSchedHintsPass("default")); pm.addPass(mt::createConvertBuiltinFuncToLLVMPass()); // There is no clusters in ROCm for now. out_cluster_info.clusterDimX = 1; diff --git a/third_party/xla/xla/service/gpu/fusions/triton/dot_algorithms_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/dot_algorithms_test.cc new file mode 100644 index 00000000000000..7cf3e449a5b0c7 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/dot_algorithms_test.cc @@ -0,0 +1,1233 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" +#include "xla/autotuning.pb.h" +#include "xla/error_spec.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/testlib/filecheck.h" +#include "xla/hlo/testlib/verified_hlo_module.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/fusions/triton/kernel_name_tracer.h" +#include "xla/service/gpu/fusions/triton/triton_test_utils.h" +#include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/xla.pb.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +class AlgorithmTest : public GpuCodegenTest { + public: + DebugOptions GetDebugOptionsForTest() const override { + DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); + debug_options.set_xla_dump_to("sponge"); + debug_options.set_xla_dump_hlo_pass_re(".*"); + debug_options.set_xla_gpu_dump_autotuned_gemm_fusions(true); + + // Enable triton fusion for all supported GEMMs. + debug_options.set_xla_gpu_triton_gemm_any(true); + + return debug_options; + } + + stream_executor::CudaComputeCapability GetCudaComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + } + + const stream_executor::GpuComputeCapability& GpuComputeComp() { + return device_desc().gpu_compute_capability(); + } + stream_executor::GpuComputeCapability CudaAmpereOrRocm() { + if (std::holds_alternative( + GpuComputeComp())) { + return stream_executor::GpuComputeCapability{ + device_desc().rocm_compute_capability()}; + } else { + return stream_executor::GpuComputeCapability{ + stream_executor::CudaComputeCapability{ + stream_executor::CudaComputeCapability::AMPERE, 0}}; + } + } + + protected: + const stream_executor::DeviceDescription& device_desc() { + return backend().default_stream_executor()->GetDeviceDescription(); + } +}; + +// In these tests, we depend on "algorithm" annotations for selecting the 6XBF16 +// algorithm. +class Triton6xBF16GemmTest : public AlgorithmTest { + public: + DebugOptions GetDebugOptionsForTest() const override { + DebugOptions debug_options = AlgorithmTest::GetDebugOptionsForTest(); + // These 2 flags are not strictly necessary now, but we're adding them to be + // on the safe side against future flakiness. + // + // Do not fall back to cuBLAS, we are testing Triton. + debug_options.set_xla_gpu_cublas_fallback(false); + + // Do not autotune split-k by default, since this prevents deterministically + // matching the optimized HLO. + debug_options.set_xla_gpu_enable_split_k_autotuning(false); + return debug_options; + } + + protected: + void SetUp() override { + if (!SupportsBF16(GpuComputeComp())) { + GTEST_SKIP() << "BF16 not supported."; + } + } +}; + +// In these tests, we depend on debug option flags for selecting the 6XBF16 +// algorithm. +// TODO(b/316147294): Remove this class and the --xla_gpu_enable_bf16_6way_gemm +// flag after we will support the algorithm values through the entire stack. +class Triton6xBF16GemmTestWithFlag : public AlgorithmTest { + public: + DebugOptions GetDebugOptionsForTest() const override { + DebugOptions debug_options = AlgorithmTest::GetDebugOptionsForTest(); + // Do not fall back to cuBLAS, we are testing Triton. + debug_options.set_xla_gpu_cublas_fallback(false); + // Do not autotune split-k by default, since this prevents deterministically + // matching the optimized HLO. + debug_options.set_xla_gpu_enable_split_k_autotuning(false); + // Enable bf16_6way gemm to compute F32 matmul. + debug_options.set_xla_gpu_enable_bf16_6way_gemm(true); + return debug_options; + } +}; + +class BlasAlgorithmTest : public AlgorithmTest { + public: + DebugOptions GetDebugOptionsForTest() const override { + DebugOptions debug_options = AlgorithmTest::GetDebugOptionsForTest(); + // Do not autotune split-k by default, since this prevents deterministically + // matching the optimized HLO. + debug_options.set_xla_gpu_enable_split_k_autotuning(false); + debug_options.set_xla_gpu_enable_triton_gemm(false); + return debug_options; + } +}; + +class TritonAlgorithmTest : public AlgorithmTest { + public: + DebugOptions GetDebugOptionsForTest() const override { + DebugOptions debug_options = AlgorithmTest::GetDebugOptionsForTest(); + // Do not fall back to cuBLAS, we are testing Triton. + debug_options.set_xla_gpu_cublas_fallback(false); + // Enable gemm for any hlo including pure matmuls. + debug_options.set_xla_gpu_triton_gemm_any(true); + // Do not autotune split-k by default, since this prevents deterministically + // matching the optimized HLO. + debug_options.set_xla_gpu_enable_split_k_autotuning(false); + return debug_options; + } +}; + +TEST_F(AlgorithmTest, Algorithm3xBF16) { + constexpr std::string_view kHloText = R"( + HloModule Algorithm3xBF16 + + ENTRY e { + p0 = f32[128,128] parameter(0) + p1 = f32[128,128] parameter(1) + ROOT dot = f32[128,128] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x3 + } + )"; + EXPECT_TRUE( + RunAndCompare(kHloText, ErrorSpec{/*aabs=*/0.001, /*arel=*/0.001})); +} + +TEST_F(AlgorithmTest, Algorithm6xBF16) { + constexpr std::string_view kHloText = R"( + HloModule Algorithm6xBF16 + + ENTRY e { + p0 = f32[128,128] parameter(0) + p1 = f32[128,128] parameter(1) + ROOT dot = f32[128,128] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x6 + } + )"; + EXPECT_TRUE( + RunAndCompare(kHloText, ErrorSpec{/*aabs=*/0.001, /*arel=*/0.001})); +} + +TEST_F(BlasAlgorithmTest, Algorithm_BF16_BF16_F32) { + // We check that the algorithm is propagated to the BLAS call. + // We also check that the kernel name matches the algorithm for Ampere. + // The algorithm for Hopper is not the one we expect because it uses TF32. + + if (!SupportsBF16(GpuComputeComp())) { + GTEST_SKIP() << "BF16 not supported."; + } + constexpr std::string_view kHloText = R"( + HloModule Algorithm_BF16_BF16_F32 + + ENTRY main { + lhs = f32[8512,256]{1,0} parameter(0) + rhs = f32[256,8512]{1,0} parameter(1) + ROOT dot = f32[8512,8512]{1,0} dot(lhs, rhs), + algorithm=dot_bf16_bf16_f32, + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + } + )"; + const std::string pattern = R"( + CHECK: %convert{{.*}} = bf16[ + CHECK: %convert{{.*}} = bf16[ + CHECK: "algorithm":"ALG_UNSET" + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); + TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern)); + ASSERT_TRUE(ok); + + auto tracer = KernelNameTracer::Create(); + if (tracer == nullptr) { + GTEST_SKIP() << "KernelNameTracer is not implemented."; + } + tracer->start(); + EXPECT_TRUE(Run(std::move(module), /*run_hlo_passes=*/false)); + auto kernel_names = tracer->stop(); + + auto cc = GetCudaComputeCapability(); + using CudaComputeCapabilities = + stream_executor::CudaComputeCapability::CudaComputeCapabilities; + switch (cc.major) { + case CudaComputeCapabilities::BLACKWELL: + GTEST_SKIP() << "CudaComputeCapabilities::BLACKWELL has the kernel name: " + << kernel_names[0]; + break; + case CudaComputeCapabilities::AMPERE: + EXPECT_THAT(kernel_names, ::testing::UnorderedElementsAre( + ::testing::Eq("wrapped_convert"), + ::testing::Eq("wrapped_convert_1"), + ::testing::HasSubstr("gemm_bf16_"))); + break; + case CudaComputeCapabilities::HOPPER: + // Convert to bf16+cublas works faster than dot with algorithm. + EXPECT_THAT(kernel_names, + ::testing::UnorderedElementsAre( + ::testing::Eq("wrapped_convert"), + ::testing::Eq("wrapped_convert_1"), + ::testing::HasSubstr("gemm_bf16f32_bf16f32"))); + break; + default: + GTEST_SKIP() << "Unsupported compute capability: " << cc.major + << " has the kernel name: " << kernel_names[0]; + } +} + +TEST_F(BlasAlgorithmTest, Algorithm_BF16_BF16_F32_X3) { + if (!SupportsBF16(GpuComputeComp())) { + GTEST_SKIP() << "BF16 not supported."; + } + constexpr std::string_view kHloText = R"( + HloModule Algorithm_BF16_BF16_F32_X3 + + ENTRY main { + lhs = f32[8512,256]{1,0} parameter(0) + rhs = f32[256,8512]{1,0} parameter(1) + ROOT dot = f32[8512,8512]{1,0} dot(lhs, rhs), + algorithm=dot_bf16_bf16_f32_x3, + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + } + )"; + // Single dot was replaced with 3 dots. + const std::string pattern = R"( + CHECK-COUNT-3: custom_call_target="__cublas$gemm" + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); + TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern)); + ASSERT_TRUE(ok); + + auto tracer = KernelNameTracer::Create(); + if (tracer == nullptr) { + GTEST_SKIP() << "KernelNameTracer is not implemented."; + } + tracer->start(); + EXPECT_TRUE(Run(std::move(module), /*run_hlo_passes=*/false)); + auto kernel_names = tracer->stop(); + + auto cc = GetCudaComputeCapability(); + using CudaComputeCapabilities = + stream_executor::CudaComputeCapability::CudaComputeCapabilities; + switch (cc.major) { + case CudaComputeCapabilities::BLACKWELL: + GTEST_SKIP() << "CudaComputeCapabilities::BLACKWELL has the kernel name: " + << kernel_names[0]; + break; + case CudaComputeCapabilities::AMPERE: + ASSERT_EQ(kernel_names.size(), 1); + EXPECT_THAT(kernel_names[0], ::testing::Eq("loop_convert_fusion_1")); + break; + case CudaComputeCapabilities::HOPPER: + EXPECT_THAT(kernel_names, + ::testing::UnorderedElementsAre( + ::testing::Eq("loop_convert_fusion_1"), + ::testing::HasSubstr("gemm_bf16f32_bf16f32_f32_"))); + break; + default: + GTEST_SKIP() << "Unsupported compute capability: " << cc.major + << " has the kernel name: " << kernel_names[0]; + } +} + +TEST_F(BlasAlgorithmTest, Algorithm_BF16_BF16_F32_X6) { + if (!SupportsBF16(GpuComputeComp())) { + GTEST_SKIP() << "BF16 not supported."; + } + constexpr std::string_view kHloText = R"( + HloModule Algorithm_BF16_BF16_F32_X6 + + ENTRY main { + lhs = f32[8512,256]{1,0} parameter(0) + rhs = f32[256,8512]{1,0} parameter(1) + ROOT dot = f32[8512,8512]{1,0} dot(lhs, rhs), + algorithm=dot_bf16_bf16_f32_x6, + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + } + )"; + // Single dot was replaced with 3 dots. + const std::string pattern = R"( + CHECK-COUNT-6: custom_call_target="__cublas$gemm" + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); + TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern)); + ASSERT_TRUE(ok); + + auto tracer = KernelNameTracer::Create(); + if (tracer == nullptr) { + GTEST_SKIP() << "KernelNameTracer is not implemented."; + } + tracer->start(); + EXPECT_TRUE(Run(std::move(module), /*run_hlo_passes=*/false)); + auto kernel_names = tracer->stop(); + + auto cc = GetCudaComputeCapability(); + using CudaComputeCapabilities = + stream_executor::CudaComputeCapability::CudaComputeCapabilities; + switch (cc.major) { + case CudaComputeCapabilities::BLACKWELL: + GTEST_SKIP() << "CudaComputeCapabilities::BLACKWELL has the kernel name: " + << kernel_names[0]; + break; + case CudaComputeCapabilities::AMPERE: + ASSERT_EQ(kernel_names.size(), 1); + EXPECT_THAT(kernel_names[0], ::testing::Eq("loop_convert_fusion_1")); + break; + case CudaComputeCapabilities::HOPPER: + EXPECT_THAT(kernel_names, + ::testing::UnorderedElementsAre( + ::testing::HasSubstr("loop_convert_fusion"), + ::testing::HasSubstr("gemm_bf16f32_bf16f32_f32_"))); + break; + default: + GTEST_SKIP() << "Unsupported compute capability: " << cc.major + << " has the kernel name: " << kernel_names[0]; + } +} + +TEST_F(BlasAlgorithmTest, Algorithm_TF32_TF32_F32_X3) { + // We check that the algorithm is propagated to the BLAS call. + // We also check that the kernel name matches the algorithm for Ampere. + + constexpr std::string_view kHloText = R"( + HloModule Algorithm_TF32_TF32_F32_X3 + + ENTRY main { + lhs = f32[8512,256]{1,0} parameter(0) + rhs = f32[256,8512]{1,0} parameter(1) + ROOT dot = f32[8512,8512]{1,0} dot(lhs, rhs), + algorithm=dot_tf32_tf32_f32_x3, + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + } + )"; + const std::string pattern = + R"(CHECK: "algorithm":"ALG_DOT_TF32_TF32_F32_X3")"; + TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); + TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern)); + ASSERT_TRUE(ok); + + auto tracer = KernelNameTracer::Create(); + if (tracer == nullptr) { + GTEST_SKIP() << "KernelNameTracer is not implemented."; + } + tracer->start(); + EXPECT_TRUE(Run(std::move(module), /*run_hlo_passes=*/false)); + auto kernel_names = tracer->stop(); + + auto cc = GetCudaComputeCapability(); + using CudaComputeCapabilities = + stream_executor::CudaComputeCapability::CudaComputeCapabilities; + switch (cc.major) { + case CudaComputeCapabilities::BLACKWELL: + GTEST_SKIP() << "CudaComputeCapabilities::BLACKWELL has the kernel name: " + << kernel_names[0]; + break; + case CudaComputeCapabilities::AMPERE: + // There is no support for TF32_TF32_F32_X3 on Ampere. We use F32_F32_F32. + EXPECT_THAT( + kernel_names, + ::testing::Contains(::testing::HasSubstr("ampere_sgemm_128x64_nn"))); + break; + case CudaComputeCapabilities::HOPPER: + // There is no support for TF32_TF32_F32_X3 on Hopper. We use F32_F32_F32. + EXPECT_THAT( + kernel_names, + ::testing::Contains(::testing::HasSubstr("gemm_f32f32_f32f32_f32"))); + break; + default: + GTEST_SKIP() << "Unsupported compute capability: " << cc.major + << " has the kernel name: " << kernel_names[0]; + } +} + +TEST_F(Triton6xBF16GemmTest, Emit6xBF16GemmWhenBothInputsAreF32) { + constexpr std::string_view kHloText = R"( +HloModule Emit6xBF16GemmWhenBothInputsAreF32 + +triton_dot { + p0 = f32[5,7] parameter(0) + p1 = f32[7,33] parameter(1) + ROOT dot = f32[5,33] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x6 +} + +ENTRY e { + p0 = f32[5,7]{1,0} parameter(0) + p1 = f32[7,33]{1,0} parameter(1) + ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} +} +)"; + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( +CHECK: %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32> +CHECK: %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32> +CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32x32xf32> +CHECK: %[[CAST_I32:.*]] = tt.bitcast %{{.*}} : tensor<32x32xf32> -> tensor<32x32xi32> +CHECK: %[[EXTRACT_HI:.*]] = arith.andi %[[CAST_I32]], %[[C_MASK]] : tensor<32x32xi32> +CHECK: %[[CAST_HI:.*]] = tt.bitcast %[[EXTRACT_HI]] : tensor<32x32xi32> -> tensor<32x32xf32> +CHECK: %[[TRUNC_TO_BF16:.*]] = arith.truncf %[[CAST_HI]] : tensor<32x32xf32> to tensor<32x32xbf16> +CHECK-COUNT-5: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> +CHECK: %[[ABS:.*]] = math.absf +CHECK: %[[CMP:.*]] = arith.cmpf ogt, %[[INFINITY]], %[[ABS]] : tensor<32x32xf32> +CHECK: %[[SELECT:.*]] = arith.select %[[CMP]], %{{.*}}, %[[C0]] : tensor<32x32xi1>, tensor<32x32xf32> +CHECK: %[[DOT_LAST:.*]] = tt.dot %{{.*}}, %{{.*}}, %[[SELECT]] : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> +CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf32> + )")); + + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-6, + /*arel=*/1e-6})); +} + +TEST_F(Triton6xBF16GemmTestWithFlag, Emit6xBF16GemmWhenBothInputsAreF32) { + constexpr std::string_view kHloText = R"( +HloModule Emit6xBF16GemmWhenBothInputsAreF32 + +triton_dot { + p0 = f32[5,7] parameter(0) + p1 = f32[7,33] parameter(1) + ROOT dot = f32[5,33] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f32[5,7]{1,0} parameter(0) + p1 = f32[7,33]{1,0} parameter(1) + ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} +} +)"; + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( +CHECK: %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32> +CHECK: %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32> +CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32x32xf32> +CHECK: %[[CAST_I32:.*]] = tt.bitcast %{{.*}} : tensor<32x32xf32> -> tensor<32x32xi32> +CHECK: %[[EXTRACT_HI:.*]] = arith.andi %[[CAST_I32]], %[[C_MASK]] : tensor<32x32xi32> +CHECK: %[[CAST_HI:.*]] = tt.bitcast %[[EXTRACT_HI]] : tensor<32x32xi32> -> tensor<32x32xf32> +CHECK: %[[TRUNC_TO_BF16:.*]] = arith.truncf %[[CAST_HI]] : tensor<32x32xf32> to tensor<32x32xbf16> +CHECK-COUNT-5: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> +CHECK: %[[ABS:.*]] = math.absf +CHECK: %[[CMP:.*]] = arith.cmpf ogt, %[[INFINITY]], %[[ABS]] : tensor<32x32xf32> +CHECK: %[[SELECT:.*]] = arith.select %[[CMP]], %{{.*}}, %[[C0]] : tensor<32x32xi1>, tensor<32x32xf32> +CHECK: %[[DOT_LAST:.*]] = tt.dot %{{.*}}, %{{.*}}, %[[SELECT]] : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> +CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf32> + )")); + + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-6, + /*arel=*/1e-6})); +} + +TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmWorksForLongContractingDimension) { + constexpr std::string_view kHloText = R"( +HloModule Triton6xBF16GemmWorksForLongContractingDimension + +triton_dot { + p0 = f32[5,2048] parameter(0) + p1 = f32[2048,33] parameter(1) + ROOT dot = f32[5,33] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x6 +} + +ENTRY e { + p0 = f32[5,2048]{1,0} parameter(0) + p1 = f32[2048,33]{1,0} parameter(1) + ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":64,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":4, "num_ctas":1}}} +} +)"; + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( +CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<64x32xbf16> * tensor<32x32xbf16> -> tensor<64x32xf32> + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-5, + /*arel=*/1e-5})); +} + +TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmCanHandleInfinity) { + constexpr std::string_view kHloText = R"( +HloModule Triton6xBF16GemmCanHandleInfinity + +triton_dot { + p0 = f32[2,2] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT dot = f32[2,2] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x6 +} + +ENTRY e { + p0 = f32[2,2]{1, 0} parameter(0) + p1 = f32[2,2]{1, 0} parameter(1) + ROOT _ = f32[2,2] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} +} +)"; + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( +CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> + )")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + std::vector arguments(2); + arguments[0] = + LiteralUtil::CreateR2({{+std::numeric_limits::infinity(), + +std::numeric_limits::infinity()}, + {+std::numeric_limits::infinity(), + +std::numeric_limits::infinity()}}); + arguments[1] = LiteralUtil::CreateR2({{1.0f, 1.0f}, {1.0f, 1.0f}}); + std::vector argument_ptrs; + absl::c_transform( + arguments, std::back_inserter(argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), argument_ptrs, + ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmCanHandleNaN) { + constexpr std::string_view kHloText = R"( +HloModule Triton6xBF16GemmCanHandleNaN + +triton_dot { + p0 = f32[2,2] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT dot = f32[2,2] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x6 +} + +ENTRY e { + p0 = f32[2,2]{1, 0} parameter(0) + p1 = f32[2,2]{1, 0} parameter(1) + ROOT _ = f32[2,2] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} +} +)"; + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( +CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> + )")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + std::vector arguments(2); + arguments[0] = + LiteralUtil::CreateR2({{std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}}); + arguments[1] = LiteralUtil::CreateR2( + {{1.0f, +std::numeric_limits::infinity()}, + {1.0f, +std::numeric_limits::infinity()}}); + std::vector argument_ptrs; + absl::c_transform( + arguments, std::back_inserter(argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), argument_ptrs, + ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +// Test case shows that why we truncate the middle term instead of rounding. +// If we round the middle term, the splitted terms may disagree in sign. This +// could result in wrong results for extreme values. +// For example, consider: +// x = -3.40282347e+38 +// If we round the middle term, its decomposition would be: +// x_hi: -3.38953139e+38 +// x_mid: -1.3240357e+36 +// x_lo: 5.17201445e+33 +// The result of x*x would be NaN instead of positive infinity. +TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmWorksForInputsWithLargeExponent) { + constexpr std::string_view kHloText = R"( +HloModule Triton6xBF16GemmWorksForInputsWithLargeExponent + +triton_dot { + p0 = f32[2,2] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT dot = f32[2,2] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x6 +} + +ENTRY e { + p0 = f32[2,2]{1, 0} parameter(0) + p1 = f32[2,2]{1, 0} parameter(1) + ROOT _ = f32[2,2] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} +} +)"; + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( +CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> + )")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + std::vector arguments(2); + constexpr float kLargeExponentFloat = 0x1.0103p72f; + arguments[0] = LiteralUtil::CreateR2( + {{kLargeExponentFloat, 1.0f}, {-kLargeExponentFloat, 1.0f}}); + arguments[1] = LiteralUtil::CreateR2( + {{kLargeExponentFloat, 1.0f}, {-kLargeExponentFloat, 1.0f}}); + std::vector argument_ptrs; + absl::c_transform( + arguments, std::back_inserter(argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + + EXPECT_TRUE( + RunAndCompareNoHloPasses(std::move(module), argument_ptrs, + ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6})); +} + +TEST_F(Triton6xBF16GemmTest, Emit6xBF16GemmEndToEnd) { + if (std::holds_alternative(GpuComputeComp())) { + GTEST_SKIP() << "ALG_DOT_BF16_BF16_F32_X6 not supported on ROCM."; + } + constexpr std::string_view kHloText = R"( +HloModule Emit6xBF16GemmEndToEnd + +ENTRY e { + p0 = f32[5,32] parameter(0) + p1 = f32[32,7] parameter(1) + ROOT dot = f32[5,7] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x6 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, + ParseAndReturnVerifiedModule(kHloText)); + CompileAndOptionallyVerifyPtx(std::move(verified_module), + R"( +CHECK: mma.sync.aligned.{{.*}}.row.col.f32.bf16.bf16.f32 +CHECK-NOT: mma.sync.aligned.{{.*}}.row.col.f32.tf32.tf32.f32 +)"); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-6, + /*arel=*/1e-6})); +} + +// In these tests, we depend on "algorithm" annotations for selecting the 3XBF16 +// algorithm. +class Triton3xBF16GemmTest : public AlgorithmTest { + public: + DebugOptions GetDebugOptionsForTest() const override { + DebugOptions debug_options = AlgorithmTest::GetDebugOptionsForTest(); + // These 2 flags are not strictly necessary now, but we're adding them the + // to be on the safe side against future flakiness. + // + // Enable triton fusion for all supported GEMMs. + debug_options.set_xla_gpu_triton_gemm_any(true); + // Do not fall back to cuBLAS, we are testing Triton. + debug_options.set_xla_gpu_cublas_fallback(false); + + // Do not autotune split-k by default, since this prevents deterministically + // matching the optimized HLO. + debug_options.set_xla_gpu_enable_split_k_autotuning(false); + return debug_options; + } +}; + +// In these tests, we depend on debug option flags for selecting the 3XBF16 +// algorithm. +// TODO(b/316147294): Remove this class and the --xla_gpu_enable_bf16_3way_gemm +// flag after we will support the algorithm values through the entire stack. +class Triton3xBF16GemmTestWithFlag : public AlgorithmTest { + public: + DebugOptions GetDebugOptionsForTest() const override { + DebugOptions debug_options = AlgorithmTest::GetDebugOptionsForTest(); + // Enable triton fusion for all supported GEMMs. + debug_options.set_xla_gpu_triton_gemm_any(true); + // Do not fall back to cuBLAS, we are testing Triton. + debug_options.set_xla_gpu_cublas_fallback(false); + // Do not autotune split-k by default, since this prevents deterministically + // matching the optimized HLO. + debug_options.set_xla_gpu_enable_split_k_autotuning(false); + // Enable bf16_3way gemm to compute F32 matmul. + debug_options.set_xla_gpu_enable_bf16_3way_gemm(true); + return debug_options; + } + + protected: + void SetUp() override { + if (!SupportsBF16(GpuComputeComp())) { + GTEST_SKIP() << "BF16 not supported."; + } + } +}; + +TEST_F(Triton3xBF16GemmTest, Emit3xBF16GemmWhenBothInputsAreF32) { + constexpr std::string_view kHloText = R"( +HloModule Emit3xBF16GemmWhenBothInputsAreF32 + +triton_dot { + p0 = f32[5,7] parameter(0) + p1 = f32[7,33] parameter(1) + ROOT dot = f32[5,33] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x3 +} + +ENTRY e { + p0 = f32[5,7]{1,0} parameter(0) + p1 = f32[7,33]{1,0} parameter(1) + ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} +} +)"; + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( +CHECK: %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32> +CHECK: %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32> +CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32x32xf32> +CHECK: %[[CAST_I32:.*]] = tt.bitcast %{{.*}} : tensor<32x32xf32> -> tensor<32x32xi32> +CHECK: %[[EXTRACT_HI:.*]] = arith.andi %[[CAST_I32]], %[[C_MASK]] : tensor<32x32xi32> +CHECK: %[[CAST_HI:.*]] = tt.bitcast %[[EXTRACT_HI]] : tensor<32x32xi32> -> tensor<32x32xf32> +CHECK: %[[TRUNC_TO_BF16:.*]] = arith.truncf %[[CAST_HI]] : tensor<32x32xf32> to tensor<32x32xbf16> +CHECK-COUNT-2: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> +CHECK: %[[ABS:.*]] = math.absf +CHECK: %[[CMP:.*]] = arith.cmpf ogt, %[[INFINITY]], %[[ABS]] : tensor<32x32xf32> +CHECK: %[[SELECT:.*]] = arith.select %[[CMP]], %{{.*}}, %[[C0]] : tensor<32x32xi1>, tensor<32x32xf32> +CHECK: %[[DOT_LAST:.*]] = tt.dot %{{.*}}, %{{.*}}, %[[SELECT]] : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> +CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf32> + )")); + + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-5, + /*arel=*/1e-5})); +} + +TEST_F(Triton3xBF16GemmTestWithFlag, Emit3xBF16GemmWhenBothInputsAreF32) { + constexpr std::string_view kHloText = R"( +HloModule Emit3xBF16GemmWhenBothInputsAreF32 + +triton_dot { + p0 = f32[5,7] parameter(0) + p1 = f32[7,33] parameter(1) + ROOT dot = f32[5,33] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f32[5,7]{1,0} parameter(0) + p1 = f32[7,33]{1,0} parameter(1) + ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} +} +)"; + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( +CHECK: %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32> +CHECK: %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32> +CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32x32xf32> +CHECK: %[[CAST_I32:.*]] = tt.bitcast %{{.*}} : tensor<32x32xf32> -> tensor<32x32xi32> +CHECK: %[[EXTRACT_HI:.*]] = arith.andi %[[CAST_I32]], %[[C_MASK]] : tensor<32x32xi32> +CHECK: %[[CAST_HI:.*]] = tt.bitcast %[[EXTRACT_HI]] : tensor<32x32xi32> -> tensor<32x32xf32> +CHECK: %[[TRUNC_TO_BF16:.*]] = arith.truncf %[[CAST_HI]] : tensor<32x32xf32> to tensor<32x32xbf16> +CHECK-COUNT-2: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> +CHECK: %[[ABS:.*]] = math.absf +CHECK: %[[CMP:.*]] = arith.cmpf ogt, %[[INFINITY]], %[[ABS]] : tensor<32x32xf32> +CHECK: %[[SELECT:.*]] = arith.select %[[CMP]], %{{.*}}, %[[C0]] : tensor<32x32xi1>, tensor<32x32xf32> +CHECK: %[[DOT_LAST:.*]] = tt.dot %{{.*}}, %{{.*}}, %[[SELECT]] : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> +CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf32> + )")); + + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-5, + /*arel=*/1e-5})); +} + +TEST_F(Triton3xBF16GemmTestWithFlag, NoEmit3xBF16GemmWhenBothInputsAreNotF32) { + constexpr std::string_view kHloText = R"( +HloModule NoEmit3xBF16GemmWhenBothInputsAreNotF32 + +triton_dot { + p0 = f16[5,7] parameter(0) + p1 = f16[7,33] parameter(1) + ROOT dot = f16[5,33] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY e { + p0 = f16[5,7]{1,0} parameter(0) + p1 = f16[7,33]{1,0} parameter(1) + ROOT _ = f16[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} +} +)"; + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( +CHECK: tt.dot +CHECK-SAME: tensor<32x32xf16> * tensor<32x32xf16> -> tensor<32x32xf32> +CHECK-NOT: tt.dot + )")); +} + +TEST_F(Triton3xBF16GemmTest, Triton3xBF16GemmWorksForLongContractingDimension) { + constexpr std::string_view kHloText = R"( +HloModule Triton3xBF16GemmWorksForLongContractingDimension + +triton_dot { + p0 = f32[5,2048] parameter(0) + p1 = f32[2048,33] parameter(1) + ROOT dot = f32[5,33] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x3 +} + +ENTRY e { + p0 = f32[5,2048]{1,0} parameter(0) + p1 = f32[2048,33]{1,0} parameter(1) + ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":64,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":4, "num_ctas":1}}} +} +)"; + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( +CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<64x32xbf16> * tensor<32x32xbf16> -> tensor<64x32xf32> + )")); + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-4, + /*arel=*/1e-4})); +} + +TEST_F(Triton3xBF16GemmTest, Triton3xBF16GemmCanHandleInfinity) { + constexpr std::string_view kHloText = R"( +HloModule Triton3xBF16GemmCanHandleInfinity + +triton_dot { + p0 = f32[2,2] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT dot = f32[2,2] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x3 +} + +ENTRY e { + p0 = f32[2,2]{1, 0} parameter(0) + p1 = f32[2,2]{1, 0} parameter(1) + ROOT _ = f32[2,2] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} +} +)"; + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( +CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> + )")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + std::vector arguments(2); + arguments[0] = + LiteralUtil::CreateR2({{+std::numeric_limits::infinity(), + +std::numeric_limits::infinity()}, + {+std::numeric_limits::infinity(), + +std::numeric_limits::infinity()}}); + arguments[1] = LiteralUtil::CreateR2({{1.0f, 1.0f}, {1.0f, 1.0f}}); + std::vector argument_ptrs; + absl::c_transform( + arguments, std::back_inserter(argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), argument_ptrs, + ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +TEST_F(BlasAlgorithmTest, Blas3xBF16GemmCanHandleInfinity) { + constexpr std::string_view kHloText = R"( +HloModule Blas3xBF16GemmCanHandleInfinity + +ENTRY e { + p0 = f32[2,2] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT dot = f32[2,2] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x3 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + std::vector arguments(2); + arguments[0] = + LiteralUtil::CreateR2({{+std::numeric_limits::infinity(), + +std::numeric_limits::infinity()}, + {+std::numeric_limits::infinity(), + +std::numeric_limits::infinity()}}); + arguments[1] = LiteralUtil::CreateR2({{1.0f, 1.0f}, {1.0f, 1.0f}}); + std::vector argument_ptrs; + absl::c_transform( + arguments, std::back_inserter(argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), argument_ptrs, + ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +TEST_F(BlasAlgorithmTest, Blas3xBF16GemmCanHandleNaN) { + constexpr std::string_view kHloText = R"( +HloModule Blas3xBF16GemmCanHandleNaN + +ENTRY e { + p0 = f32[2,2] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT dot = f32[2,2] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x3 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + std::vector arguments(2); + arguments[0] = + LiteralUtil::CreateR2({{std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}}); + arguments[1] = LiteralUtil::CreateR2( + {{1.0f, +std::numeric_limits::infinity()}, + {1.0f, +std::numeric_limits::infinity()}}); + std::vector argument_ptrs; + absl::c_transform( + arguments, std::back_inserter(argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), argument_ptrs, + ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +TEST_F(BlasAlgorithmTest, Blas3xBF16GemmWorksForInputsWithLargeExponent) { + constexpr std::string_view kHloText = R"( +HloModule Blas3xBF16GemmWorksForInputsWithLargeExponent + +ENTRY e { + p0 = f32[2,2] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT dot = f32[2,2] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x3 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + std::vector arguments(2); + constexpr float kLargeExponentFloat = 0x1.0103p72f; + arguments[0] = LiteralUtil::CreateR2( + {{kLargeExponentFloat, 1.0f}, {-kLargeExponentFloat, 1.0f}}); + arguments[1] = LiteralUtil::CreateR2( + {{kLargeExponentFloat, 1.0f}, {-kLargeExponentFloat, 1.0f}}); + std::vector argument_ptrs; + absl::c_transform( + arguments, std::back_inserter(argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + + EXPECT_TRUE( + RunAndCompareNoHloPasses(std::move(module), argument_ptrs, + ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-4})); +} + +TEST_F(TritonAlgorithmTest, Triton3xBF16GemmCanHandleNaN) { + constexpr std::string_view kHloText = R"( +HloModule Triton3xBF16GemmCanHandleNaN + +triton_dot { + p0 = f32[2,2] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT dot = f32[2,2] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x3 +} + +ENTRY e { + p0 = f32[2,2]{1, 0} parameter(0) + p1 = f32[2,2]{1, 0} parameter(1) + ROOT _ = f32[2,2] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} +} +)"; + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( +CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> + )")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + std::vector arguments(2); + arguments[0] = + LiteralUtil::CreateR2({{std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}}); + arguments[1] = LiteralUtil::CreateR2( + {{1.0f, +std::numeric_limits::infinity()}, + {1.0f, +std::numeric_limits::infinity()}}); + std::vector argument_ptrs; + absl::c_transform( + arguments, std::back_inserter(argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), argument_ptrs, + ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +TEST_F(TritonAlgorithmTest, Triton3xBF16GemmWorksForInputsWithLargeExponent) { + constexpr std::string_view kHloText = R"( +HloModule Triton3xBF16GemmWorksForInputsWithLargeExponent + +triton_dot { + p0 = f32[2,2] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT dot = f32[2,2] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x3 +} + +ENTRY e { + p0 = f32[2,2]{1, 0} parameter(0) + p1 = f32[2,2]{1, 0} parameter(1) + ROOT _ = f32[2,2] fusion(p0, p1), kind=kCustom, calls=triton_dot, + backend_config={"fusion_backend_config": {kind: "__triton_gemm", + triton_gemm_config: + {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} +} +)"; + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( +CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> + )")); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetOptimizedModule(kHloText)); + std::vector arguments(2); + constexpr float kLargeExponentFloat = 0x1.0103p72f; + arguments[0] = LiteralUtil::CreateR2( + {{kLargeExponentFloat, 1.0f}, {-kLargeExponentFloat, 1.0f}}); + arguments[1] = LiteralUtil::CreateR2( + {{kLargeExponentFloat, 1.0f}, {-kLargeExponentFloat, 1.0f}}); + std::vector argument_ptrs; + absl::c_transform( + arguments, std::back_inserter(argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + + EXPECT_TRUE( + RunAndCompareNoHloPasses(std::move(module), argument_ptrs, + ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-4})); +} + +TEST_F(Triton3xBF16GemmTest, Emit3xBF16GemmEndToEnd) { + if (std::holds_alternative(GpuComputeComp())) { + GTEST_SKIP() << "ALG_DOT_BF16_BF16_F32_X3 not supported on ROCM."; + } + constexpr std::string_view kHloText = R"( +HloModule Emit3xBF16GemmEndToEnd + +ENTRY e { + p0 = f32[5,32] parameter(0) + p1 = f32[32,7] parameter(1) + ROOT dot = f32[5,7] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0}, + algorithm=dot_bf16_bf16_f32_x3 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, + ParseAndReturnVerifiedModule(kHloText)); + CompileAndOptionallyVerifyPtx(std::move(verified_module), + R"( +CHECK: mma.sync.aligned.{{.*}}.row.col.f32.bf16.bf16.f32 +CHECK-NOT: mma.sync.aligned.{{.*}}.row.col.f32.tf32.tf32.f32 +)"); + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-5, + /*arel=*/1e-5})); +} + +TEST_F(TritonAlgorithmTest, Algorithm_BF16_BF16_F32_X3) { + const std::string kHloText = R"( + HloModule Algorithm_BF16_BF16_F32_X3 + + ENTRY main { + lhs = f32[8512,64]{1,0} parameter(0) + rhs = f32[64,8512]{1,0} parameter(1) + ROOT dot = f32[8512,8512]{1,0} dot(lhs, rhs), + algorithm=dot_bf16_bf16_f32_x3, + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + } + )"; + const std::string pattern = + R"(CHECK: "kind":"__triton_gemm","triton_gemm_config")"; + TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); + TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern)); + EXPECT_TRUE(ok); +} + +TEST_F(TritonAlgorithmTest, Algorithm_BF16_BF16_F32_X6) { + const std::string kHloText = R"( + HloModule Algorithm_BF16_BF16_F32_X6 + + ENTRY main { + lhs = f32[8512,64]{1,0} parameter(0) + rhs = f32[64,8512]{1,0} parameter(1) + ROOT dot = f32[8512,8512]{1,0} dot(lhs, rhs), + algorithm=dot_bf16_bf16_f32_x6, + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + } + )"; + const std::string pattern = + R"(CHECK: "kind":"__triton_gemm","triton_gemm_config")"; + TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); + TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern)); + EXPECT_TRUE(ok); +} + +TEST_F(TritonAlgorithmTest, Algorithm_TF32_TF32_F32_X3) { + const std::string kHloText = R"( + HloModule Algorithm_TF32_TF32_F32_X3 + + ENTRY main { + lhs = f32[8512,64]{1,0} parameter(0) + rhs = f32[64,8512]{1,0} parameter(1) + ROOT dot = f32[8512,8512]{1,0} dot(lhs, rhs), + algorithm=dot_tf32_tf32_f32_x3, + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + } + )"; + const std::string pattern = + R"(CHECK: "kind":"__triton_gemm","triton_gemm_config")"; + TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); + TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern)); + EXPECT_TRUE(ok); +} + +TEST_F(TritonAlgorithmTest, Algorithm_BF16_BF16_F32) { + if (!SupportsBF16(GpuComputeComp())) { + GTEST_SKIP() << "BF16 not supported."; + } + const std::string kHloText = R"( + HloModule Algorithm_BF16_BF16_F32 + + ENTRY main { + lhs = f32[8512,64]{1,0} parameter(0) + rhs = f32[64,8512]{1,0} parameter(1) + ROOT dot = f32[8512,8512]{1,0} dot(lhs, rhs), + algorithm=dot_bf16_bf16_f32, + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + } + )"; + const std::string pattern = + R"(CHECK: "kind":"__triton_gemm","triton_gemm_config")"; + TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); + TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern)); + EXPECT_TRUE(ok); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/triton/emitter_helpers.cc b/third_party/xla/xla/service/gpu/fusions/triton/emitter_helpers.cc new file mode 100644 index 00000000000000..0f2dacf571baa8 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/emitter_helpers.cc @@ -0,0 +1,441 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/fusions/triton/emitter_helpers.h" + +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/TargetParser/Triple.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h" +#include "xla/mlir_hlo/mhlo/transforms/transformation_helpers.h" +#include "xla/primitive_util.h" +#include "xla/service/gpu/target_util.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/xla.pb.h" +#include "tsl/platform/statusor.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace xla::gpu::triton { + +using ::llvm::SmallVector; +using ::mlir::ArrayRef; +using ::mlir::ImplicitLocOpBuilder; +using ::mlir::ShapedType; +using ::mlir::Type; +using ::mlir::Value; +using ::mlir::ValueRange; + +namespace ma = ::mlir::arith; +namespace mm = ::mlir::math; +namespace mt = ::mlir::triton; + +ScalarOrTensor::ScalarOrTensor(mlir::Value value) { + if (auto tt = mlir::dyn_cast(value.getType())) { + CHECK_GT(tt.getRank(), 0); + value_ = TensorValue{value}; + } else { + value_ = ScalarValue{value}; + } +} + +SmallVector GetPaddedTileSizes(ArrayRef tile_sizes) { + SmallVector result; + result.reserve(tile_sizes.size()); + for (int64_t value : tile_sizes) { + result.push_back(llvm::PowerOf2Ceil(value)); + } + return result; +} + +absl::StatusOr TritonType(mlir::OpBuilder b, PrimitiveType t) { + switch (t) { + case F64: + return b.getF64Type(); + case F32: + return b.getF32Type(); + case F16: + return b.getF16Type(); + case BF16: + return b.getBF16Type(); + case S64: + return b.getI64Type(); + case S32: + return b.getI32Type(); + case S16: + return b.getI16Type(); + case PRED: + return b.getI1Type(); + case S8: + return b.getI8Type(); + case F8E5M2: + return b.getFloat8E5M2Type(); + case F8E4M3FN: + return b.getFloat8E4M3FNType(); + default: + return absl::UnimplementedError( + absl::StrCat("This type is not supported yet: ", + primitive_util::LowercasePrimitiveTypeName(t))); + } +} + +Type StorageType(mlir::OpBuilder b, Type t) { + if (t.isInteger(1)) { + return b.getI8Type(); + } + return t; +} + +bool IsFp8Type(Type t) { + return t.isFloat8E5M2() || t.isFloat8E4M3FN() || t.isFloat8E5M2FNUZ() || + t.isFloat8E4M3FNUZ() || t.isFloat8E4M3B11FNUZ(); +} + +Value Cast(ImplicitLocOpBuilder& b, Value value, Type dst_element_ty) { + Type src_ty = value.getType(); + Type src_element_ty = src_ty; + Type fp32_ty = b.getF32Type(); + Type dst_ty = dst_element_ty; + if (auto src_shaped_ty = mlir::dyn_cast(src_ty)) { + src_element_ty = src_shaped_ty.getElementType(); + dst_ty = src_shaped_ty.clone(src_shaped_ty.getShape(), dst_element_ty); + fp32_ty = src_shaped_ty.clone(src_shaped_ty.getShape(), b.getF32Type()); + } + if (src_ty == dst_ty) { + return value; + } + + // All operations on bf16 are done through f32. + if (src_element_ty.isBF16()) { + return Cast(b, b.create(fp32_ty, value), dst_element_ty); + } + if (dst_element_ty.isBF16()) { + // S8 -> BF16 is directly supported and doesn't need to go through f32. + if (!src_element_ty.isInteger(8)) { + return b.create(dst_ty, Cast(b, value, b.getF32Type())); + } + } + + // float => float + auto src_fp_element_ty = mlir::dyn_cast(src_element_ty); + auto dst_fp_element_ty = mlir::dyn_cast(dst_element_ty); + if (src_fp_element_ty && dst_fp_element_ty) { + // F8 <-> FP16, BF16, FP32, FP64 need to be handled via Triton's tt.fp_to_fp + // because LLVM doesn't support casts from/to FP8. + // TODO(b/266862493): Add end-to-end test once FP8 support lands in XLA as + // we can't test the code below without patching the feature. + if (IsFp8Type(src_element_ty)) { + return b.create(dst_ty, value); + } + if (IsFp8Type(dst_element_ty)) { + return b.create( + dst_ty, value, + mt::RoundingModeAttr::get(b.getContext(), mt::RoundingMode::RTNE)); + } + + if (src_fp_element_ty.getFPMantissaWidth() > + dst_fp_element_ty.getFPMantissaWidth()) { + return b.create(dst_ty, value); + } else { + return b.create(dst_ty, value); + } + } + // int => int + if (mlir::isa(src_element_ty) && + mlir::isa(dst_element_ty)) { + if (src_element_ty.getIntOrFloatBitWidth() < + dst_element_ty.getIntOrFloatBitWidth()) { + if (src_element_ty.isInteger(1)) { + return b.create(dst_ty, value); + } + return b.create(dst_ty, value); + } + return b.create(dst_ty, value); + } + // int => float + if (mlir::isa(src_element_ty) && dst_fp_element_ty) { + // TODO(b/266862493): Support unsigned integer types. + if (src_element_ty.isInteger(1)) { + return b.create(dst_ty, value); + } + return b.create(dst_ty, value); + } + // float => int + if (src_fp_element_ty && mlir::isa(dst_element_ty)) { + if (dst_element_ty.isInteger(1)) { + return b.create(ma::CmpFPredicate::UNE, value, + ZerosLike(b, value)); + } + // TODO(b/266862493): Support unsigned integer types. + // The current logic handles signed integer types only. Additional handling + // is needed for unsigned integer types. + auto cst_int = [&](int64_t x) { + if (auto src_shaped_ty = mlir::dyn_cast(src_ty)) { + return CreateConst(b, dst_element_ty, x, src_shaped_ty.getShape()) + .UnwrapUnsafe(); + } else { + return CreateConst(b, dst_element_ty, x).UnwrapUnsafe(); + } + }; + auto cst_float = [&](int64_t x) { + if (auto src_shaped_ty = mlir::dyn_cast(src_ty)) { + return CreateConst(b, src_fp_element_ty, x, src_shaped_ty.getShape()) + .UnwrapUnsafe(); + } else { + return CreateConst(b, src_fp_element_ty, x).UnwrapUnsafe(); + } + }; + auto fptosi = b.create(dst_ty, value); + int64_t min = llvm::minIntN(dst_element_ty.getIntOrFloatBitWidth()); + int64_t max = llvm::maxIntN(dst_element_ty.getIntOrFloatBitWidth()); + + // value <= static_cast(INT_MIN) ? INT_MIN : ... + auto clamped = b.create( + b.create(mlir::arith::CmpFPredicate::OLE, value, + cst_float(min)), + cst_int(min), fptosi); + // value >= static_cast(INT_MAX) ? INT_MAX : ... + clamped = b.create( + b.create(mlir::arith::CmpFPredicate::OGE, value, + cst_float(max)), + cst_int(max), clamped); + // isnan(value) ? 0 : ... + return b.create( + b.create(mlir::arith::CmpFPredicate::UNO, value, + value), + cst_int(0), clamped); + } + + LOG(FATAL) << "Type conversion not supported: " + << llvm_ir::DumpToString(src_element_ty) << " -> " + << llvm_ir::DumpToString(dst_element_ty); +} + +Value Subtract(ImplicitLocOpBuilder& b, ValueRange values) { + if (mlir::isa(mlir::getElementTypeOrSelf(values[0]))) { + return b.create(values[0], values[1]); + } else { + return b.create(values[0], values[1]); + } +} + +Value Compare(ImplicitLocOpBuilder& b, ValueRange values, + mlir::mhlo::ComparisonDirection direction) { + const Type type = mlir::getElementTypeOrSelf(values[0]); + if (mlir::isa(type)) { + return b.create( + mlir::mhlo::impl::getCmpPredicate( + direction, + /*isSigned=*/!type.isInteger(1)) + .value(), + values[0], values[1]); + } + return b.create( + mlir::mhlo::impl::getCmpPredicate(direction, + /*isSigned=*/true) + .value(), + values[0], values[1]); +} + +Value Maximum(ImplicitLocOpBuilder& b, const se::DeviceDescription& device_info, + ValueRange values) { + if (mlir::isa(mlir::getElementTypeOrSelf(values[0]))) { + return b.create(values); + } + // logic: isNaN(lhs) || (!isNan(rhs) && lhs >= rhs) ? lhs : rhs + // See also: IEEE Std 754-2008 5.11. + // + // This also works, but we wanted to make it similar to minimum. + // logic: isNaN(lhs) || lhs >= rhs ? lhs : rhs + Value lhs_is_nan = + Compare(b, {values[0], values[0]}, mlir::mhlo::ComparisonDirection::NE); + Value rhs_is_not_nan = + Compare(b, {values[1], values[1]}, mlir::mhlo::ComparisonDirection::EQ); + Value lhs_is_ge = Compare(b, values, mlir::mhlo::ComparisonDirection::GE); + return b.create( + b.create(lhs_is_nan, + b.create(rhs_is_not_nan, lhs_is_ge)), + values[0], values[1]); +} + +Value Minimum(ImplicitLocOpBuilder& b, const se::DeviceDescription& device_info, + ValueRange values) { + if (mlir::isa(mlir::getElementTypeOrSelf(values[0]))) { + return b.create(values); + } + // logic: isNaN(lhs) || (!isNan(rhs) && lhs <= rhs) ? lhs : rhs + // See also: IEEE Std 754-2008 5.11. + // + // This should also work, but the tests show that it doesn't work for + // minimum(x, NaN): + // logic: isNaN(lhs) || lhs <= rhs ? lhs : rhs + Value lhs_is_nan = + Compare(b, {values[0], values[0]}, mlir::mhlo::ComparisonDirection::NE); + Value rhs_is_not_nan = + Compare(b, {values[1], values[1]}, mlir::mhlo::ComparisonDirection::EQ); + Value lhs_is_le = Compare(b, values, mlir::mhlo::ComparisonDirection::LE); + return b.create( + b.create(lhs_is_nan, + b.create(rhs_is_not_nan, lhs_is_le)), + values[0], values[1]); +} + +ScalarOrTensor Splat(ImplicitLocOpBuilder& b, ScalarOrTensor value, + ArrayRef shape) { + CHECK(!shape.empty()); + auto type = mlir::RankedTensorType::get(shape, value.Type()); + return ScalarOrTensor(b.create(type, value.UnwrapUnsafe())); +} + +absl::StatusOr EmitElementwise(ImplicitLocOpBuilder& b, + absl::string_view libdevice_path, + const se::DeviceDescription& device_info, + const HloInstruction& hlo, + ValueRange inputs) { + if (mlir::getElementTypeOrSelf(inputs[0]).isF32() || + mlir::getElementTypeOrSelf(inputs[0]).isF64()) { + auto dev_fn_id = GetTargetDeviceFunctionID(hlo.opcode()); + if (dev_fn_id.ok()) { + llvm::Triple triple("nvptx64-unknown-unknown"); + if (std::holds_alternative( + device_info.gpu_compute_capability())) { + triple.setTriple("amdgcn-unknown-unknown"); + } + return b.create( + inputs[0].getType(), inputs, "libdevice", libdevice_path, + ObtainDeviceFunctionName(dev_fn_id.value(), + hlo.shape().element_type(), triple), + /*pure=*/true); + } + } + const bool is_integer = + mlir::isa(mlir::getElementTypeOrSelf(inputs[0])); + + switch (hlo.opcode()) { + case HloOpcode::kCopy: + // Dimension transformations are taken care of separately. + return inputs[0]; + case HloOpcode::kAbs: + if (is_integer) { + return b.create(inputs[0]); + } + return b.create(inputs[0]); + case HloOpcode::kCeil: + return b.create(inputs[0]); + case HloOpcode::kFloor: + return b.create(inputs[0]); + case HloOpcode::kNot: + return b.create(inputs[0], OnesLike(b, inputs[0])); + case HloOpcode::kNegate: + // NegFOp is not supported by Triton. + return Subtract(b, {ZerosLike(b, inputs[0]), inputs[0]}); + case HloOpcode::kConvert: { + TF_ASSIGN_OR_RETURN(Type dst_ty, + TritonType(b, hlo.shape().element_type())); + return Cast(b, inputs[0], dst_ty); + } + case HloOpcode::kAdd: + if (is_integer) { + return b.create(inputs[0], inputs[1]); + } + return b.create(inputs[0], inputs[1]); + case HloOpcode::kSubtract: + return Subtract(b, inputs); + case HloOpcode::kMultiply: + if (is_integer) { + return b.create(inputs[0], inputs[1]); + } + return b.create(inputs[0], inputs[1]); + case HloOpcode::kMaximum: + return Maximum(b, device_info, inputs); + case HloOpcode::kMinimum: + return Minimum(b, device_info, inputs); + case HloOpcode::kClamp: + return Maximum( + b, device_info, + {Minimum(b, device_info, {inputs[1], inputs[2]}), inputs[0]}); + case HloOpcode::kAnd: + return b.create(inputs[0], inputs[1]); + case HloOpcode::kOr: + return b.create(inputs[0], inputs[1]); + case HloOpcode::kXor: + return b.create(inputs[0], inputs[1]); + case HloOpcode::kDivide: + if (is_integer) { + // Unsigned not supported yet. + return b.create(inputs[0], inputs[1]); + } + return b.create(inputs[0], inputs[1]); + case HloOpcode::kCompare: + return Compare( + b, inputs, + mlir::mhlo::symbolizeComparisonDirection( + ComparisonDirectionToString(hlo.comparison_direction())) + .value()); + case HloOpcode::kSelect: + return b.create( + Compare(b, {inputs[0], ZerosLike(b, inputs[0])}, + mlir::mhlo::ComparisonDirection::NE), + inputs[1], inputs[2]); + case HloOpcode::kReducePrecision: + return mlir::mhlo::reducePrecision( + b.getLoc(), inputs[0], hlo.exponent_bits(), hlo.mantissa_bits(), &b); + default: + return absl::InvalidArgumentError( + absl::StrCat("Unsupported elementwise operation ", hlo.ToString())); + } +} + +absl::StatusOr EmitConstant(ImplicitLocOpBuilder& b, + const HloInstruction& constant) { + TF_ASSIGN_OR_RETURN(Type ty, TritonType(b, constant.shape().element_type())); + llvm::SmallVector shape{constant.shape().dimensions().begin(), + constant.shape().dimensions().end()}; + + if (constant.shape().IsInteger()) { + if (constant.shape().element_type() == U64) { + return CreateConst(b, ty, ScalarConstantValue(constant, U64), + shape); + } else { + return CreateConst(b, ty, ScalarConstantValue(constant, S64), + shape); + } + } + return CreateConst(b, ty, ScalarConstantValue(constant, F64), shape); +} + +} // namespace xla::gpu::triton diff --git a/third_party/xla/xla/service/gpu/fusions/triton/emitter_helpers.h b/third_party/xla/xla/service/gpu/fusions/triton/emitter_helpers.h new file mode 100644 index 00000000000000..1fc7372bd2f751 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/emitter_helpers.h @@ -0,0 +1,198 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_EMITTER_HELPERS_H_ +#define XLA_SERVICE_GPU_FUSIONS_TRITON_EMITTER_HELPERS_H_ + +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/literal.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/xla.pb.h" +#include "tsl/platform/status.h" + +namespace xla::gpu::triton { + +// This is a wrapper around mlir::Value that can hold either a scalar or a +// non-0D tensor. An attempt to use this class with 0D tensors will CHECK-fail +// because 0D tensors are not supported by Triton. +class ScalarOrTensor { + public: + ScalarOrTensor() = default; + + // Wraps the given value in a ScalarOrTensor. CHECK-fails if the + // value is a 0D tensor, because Triton does not support 0D tensors. + explicit ScalarOrTensor(mlir::Value value); + + bool IsScalar() const { return std::holds_alternative(value_); } + bool IsTensor() const { return std::holds_alternative(value_); } + + mlir::Value UnwrapScalar() { + CHECK(IsScalar()); + return std::get(value_).scalar_value; + } + + mlir::Value UnwrapTensor() { + CHECK(IsTensor()); + return std::get(value_).tensor_value; + } + + // Returns the underlying value regardless of whether it is a scalar or a + // tensor. Only call this method in contexts where the consumer of the result + // both needs to use an `mlir::Value` and functions identically for scalars + // and tensors. In other cases, prefer to use the `UnwrapScalar` or + // `UnwrapTensor` methods. + mlir::Value UnwrapUnsafe() { + if (auto* scalar = std::get_if(&value_)) { + return scalar->scalar_value; + } + return std::get(value_).tensor_value; + } + + mlir::Type Type() { return UnwrapUnsafe().getType(); } + + private: + struct ScalarValue { + mlir::Value scalar_value; + }; + + struct TensorValue { + mlir::Value tensor_value; + }; + + std::variant value_; +}; + +// Triton requires that all block dimensions are a power of 2. +// TODO(b/353484968): Delete this function once we have constraints to only +// propagate tile sizes that are a power of 2. +llvm::SmallVector GetPaddedTileSizes( + llvm::ArrayRef tile_sizes); + +// XLA -> Triton type conversions. +absl::StatusOr TritonType(mlir::OpBuilder b, PrimitiveType t); + +mlir::Type StorageType(mlir::OpBuilder b, mlir::Type t); + +// Get the value of the scalar constant's literal in a C++ type. +template +T ScalarConstantValue(const HloInstruction& instr, PrimitiveType dst_type) { + CHECK_EQ(instr.opcode(), HloOpcode::kConstant); + CHECK(ShapeUtil::IsEffectiveScalar(instr.shape())); + absl::StatusOr converted = instr.literal().Convert(dst_type); + TF_CHECK_OK(converted.status()); + return converted.value().GetFirstElement(); +} + +// Create a scalar constant. +template +ScalarOrTensor CreateConst(mlir::ImplicitLocOpBuilder b, mlir::Type type, + T value) { + if (mlir::isa(type)) { + auto result = + b.create(b.getIntegerAttr(type, value)); + return ScalarOrTensor(result); + } + if (mlir::isa(type)) { + auto result = b.create( + b.getFloatAttr(type, static_cast(value))); + return ScalarOrTensor(result); + } + LOG(FATAL) << "Constant type not supported: " << llvm_ir::DumpToString(type); +} + +// Create a tensor constant. +template +ScalarOrTensor CreateConst(mlir::ImplicitLocOpBuilder& b, mlir::Type type, + T value, llvm::ArrayRef shape) { + if (shape.empty()) { + return CreateConst(b, type, value); + } + auto tensor_type = mlir::RankedTensorType::get(shape, type); + if (auto int_type = mlir::dyn_cast(type)) { + auto result = + b.create(mlir::DenseElementsAttr::get( + tensor_type, mlir::APInt(int_type.getIntOrFloatBitWidth(), value))); + return ScalarOrTensor(result); + } + if (auto float_type = mlir::dyn_cast(type)) { + auto result = + b.create(mlir::DenseElementsAttr::get( + tensor_type, b.getFloatAttr(type, static_cast(value)))); + return ScalarOrTensor(result); + } + LOG(FATAL) << "Constant type not supported: " << llvm_ir::DumpToString(type); +} + +// Create a constant of the same shape as `like` but with a new type and value. +template +mlir::Value ConstLike(mlir::ImplicitLocOpBuilder& b, mlir::Value like, + T new_value) { + if (auto src_shaped_ty = mlir::dyn_cast(like.getType())) { + mlir::Type src_ty = src_shaped_ty.getElementType(); + return CreateConst(b, src_ty, new_value, src_shaped_ty.getShape()) + .UnwrapUnsafe(); + } + return CreateConst(b, like.getType(), new_value).UnwrapUnsafe(); +} + +inline mlir::Value ZerosLike(mlir::ImplicitLocOpBuilder& b, mlir::Value x) { + return ConstLike(b, x, 0); +} + +inline mlir::Value OnesLike(mlir::ImplicitLocOpBuilder& b, mlir::Value x) { + return ConstLike(b, x, 1); +} + +bool IsFp8Type(mlir::Type t); + +ScalarOrTensor Splat(mlir::ImplicitLocOpBuilder& b, ScalarOrTensor value, + llvm::ArrayRef shape); + +// Triton type conversions. +mlir::Value Cast(mlir::ImplicitLocOpBuilder& b, mlir::Value value, + mlir::Type dst_element_ty); + +// Emits a scalar constant. +absl::StatusOr EmitConstant(mlir::ImplicitLocOpBuilder& b, + const HloInstruction& constant); + +absl::StatusOr EmitElementwise( + mlir::ImplicitLocOpBuilder& b, absl::string_view libdevice_path, + const se::DeviceDescription& device_info, const HloInstruction& hlo, + mlir::ValueRange inputs); + +} // namespace xla::gpu::triton + +#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_EMITTER_HELPERS_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer.h b/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer.h new file mode 100644 index 00000000000000..487228eee63406 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer.h @@ -0,0 +1,43 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_KERNEL_NAME_TRACER_H_ +#define XLA_SERVICE_GPU_FUSIONS_TRITON_KERNEL_NAME_TRACER_H_ + +#include +#include +#include + +namespace xla::gpu { + +// In some cases we need to know what exact kernel was used. It happens when we +// have no direct way to get this information from the HLO. For example, when we +// have a fusion with a custom call to cuBLAS or another third party library. +// This class allows to get the names of the kernels that were used. +class KernelNameTracer { + public: + static std::unique_ptr Create(); + + virtual void start() = 0; + + // It should return the names of the kernels that were executed on GPU:0. + virtual std::vector stop() = 0; + + virtual ~KernelNameTracer() = default; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_KERNEL_NAME_TRACER_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer_cuda.cc b/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer_cuda.cc new file mode 100644 index 00000000000000..b975d18e2a87d5 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer_cuda.cc @@ -0,0 +1,77 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "xla/backends/profiler/gpu/cupti_collector.h" +#include "xla/backends/profiler/gpu/cupti_tracer.h" +#include "xla/service/gpu/fusions/triton/kernel_name_tracer.h" +#include "xla/tsl/profiler/utils/time_utils.h" + +namespace xla::gpu { + +// This class allows to get the name of the kernel that was used. +// It works only on CUDA. It uses CuptiTracer to get the kernel name. +class KernelNameTracerCuda : public KernelNameTracer { + public: + KernelNameTracerCuda() + : cupti_tracer_(profiler::CuptiTracer::GetCuptiTracerSingleton()) {} + + void start() override; + + std::vector stop() override; + + private: + profiler::CuptiTracer* cupti_tracer_; // Not owned. + std::unique_ptr cupti_collector_; +}; + +std::unique_ptr KernelNameTracer::Create() { + return std::make_unique(); +} + +void KernelNameTracerCuda::start() { + profiler::CuptiTracerCollectorOptions collector_options; + collector_options.num_gpus = profiler::CuptiTracer::NumGpus(); + auto start_gputime_ns = profiler::CuptiTracer::GetTimestamp(); + auto start_walltime_ns = tsl::profiler::GetCurrentTimeNanos(); + cupti_collector_ = profiler::CreateCuptiCollector( + collector_options, start_walltime_ns, start_gputime_ns); + profiler::CuptiTracerOptions options; + options.activities_selected = {CUPTI_ACTIVITY_KIND_KERNEL}; + cupti_tracer_->Enable(options, cupti_collector_.get()); +} + +std::vector KernelNameTracerCuda::stop() { + cupti_tracer_->Disable(); + uint64_t end_gpu_ns = cupti_collector_->GetTracingEndTimeNs(); + auto space = std::make_unique(); + cupti_collector_->Export(space.get(), end_gpu_ns); + for (const auto& plane : space->planes()) { + std::vector names; + if (plane.name() == "/device:GPU:0") { + for (const auto& metadata : plane.event_metadata()) { + names.push_back(metadata.second.name()); + } + return names; + } + } + return {}; +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/python/outfeed_receiver_py.h b/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer_noop.cc similarity index 64% rename from third_party/xla/xla/python/outfeed_receiver_py.h rename to third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer_noop.cc index 0148ff0ac312b7..4d6b361e84a97e 100644 --- a/third_party/xla/xla/python/outfeed_receiver_py.h +++ b/third_party/xla/xla/service/gpu/fusions/triton/kernel_name_tracer_noop.cc @@ -1,4 +1,4 @@ -/* Copyright 2020 The OpenXLA Authors. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,16 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_PYTHON_OUTFEED_RECEIVER_PY_H_ -#define XLA_PYTHON_OUTFEED_RECEIVER_PY_H_ +#include -// placeholder for index annotation headers -#include "nanobind/nanobind.h" +#include "xla/service/gpu/fusions/triton/kernel_name_tracer.h" -namespace xla { +namespace xla::gpu { -void BuildOutfeedReceiverSubmodule(nanobind::module_& m); +std::unique_ptr KernelNameTracer::Create() { return nullptr; } -} // namespace xla - -#endif // XLA_PYTHON_OUTFEED_RECEIVER_PY_H_ +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc index a6e6ac3e2181cc..7eb6cedd67bb04 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc @@ -15,29 +15,21 @@ limitations under the License. #include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h" -#include -#include -#include #include #include -#include -#include #include #include -#include #include #include // NOLINT(build/c++11): required to interface with LLVM #include #include #include -#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" -#include "absl/strings/cord.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" @@ -48,9 +40,7 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/Linker/Linker.h" #include "llvm/Support/FileSystem.h" -#include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/TargetParser/Triple.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" @@ -63,11 +53,8 @@ limitations under the License. #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/IR/AffineExpr.h" -#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" @@ -94,40 +81,31 @@ limitations under the License. #include "mlir/Target/LLVMIR/Export.h" #include "mlir/Transforms/Passes.h" #include "xla/autotuning.pb.h" -#include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/utils/hlo_query.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h" #include "xla/layout_util.h" -#include "xla/literal.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h" #include "xla/permutation_util.h" -#include "xla/primitive_util.h" -#include "xla/service/algorithm_util.h" #include "xla/service/dump.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" #include "xla/service/gpu/fusions/transforms/passes.h" +#include "xla/service/gpu/fusions/triton/emitter_helpers.h" #include "xla/service/gpu/fusions/triton/passes.h" -#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.h" #include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" -#include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/model/symbolic_tile_analysis.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/gpu/model/tiled_hlo_instruction.h" #include "xla/service/gpu/model/triton_emitter_constraints.h" -#include "xla/service/gpu/target_util.h" #include "xla/service/gpu/triton_fusion_analysis.h" -#include "xla/service/gpu/triton_tiling_propagation.h" #include "xla/service/hlo_module_config.h" #include "xla/service/instruction_fusion.h" #include "xla/service/llvm_ir/llvm_util.h" @@ -136,15 +114,12 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/launch_dim.h" #include "xla/tools/hlo_decomposer.h" -#include "xla/translate/hlo_to_mhlo/hlo_function_importer.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" -#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" -#include "tsl/platform/tensor_float_32_utils.h" #include "triton/Conversion/TritonGPUToLLVM/Passes.h" #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h" #include "triton/Dialect/Triton/IR/Dialect.h" @@ -156,456 +131,53 @@ namespace xla { namespace gpu { namespace ma = ::mlir::arith; -namespace mm = ::mlir::math; namespace mn = ::mlir::NVVM; namespace mt = ::mlir::triton; using ::llvm::SmallVector; -using mlir::ArrayRef; -using mlir::ImplicitLocOpBuilder; +using ::mlir::ArrayRef; +using ::mlir::ImplicitLocOpBuilder; using ::mlir::ShapedType; using ::mlir::Type; using ::mlir::Value; -using mlir::ValueRange; +using ::mlir::ValueRange; -namespace { - -// Triton requires that all block dimensions are a power of 2. -// TODO(b/353484968): Delete this function once we have constraints to only -// propagate tile sizes that are a power of 2. -llvm::SmallVector GetPaddedTileSizes(ArrayRef tile_sizes) { - llvm::SmallVector result; - result.reserve(tile_sizes.size()); - for (int64_t value : tile_sizes) { - result.push_back(llvm::PowerOf2Ceil(value)); - } - return result; -} - -// XLA -> Triton type conversions. -absl::StatusOr TritonType(mlir::OpBuilder b, PrimitiveType t) { - switch (t) { - case F64: - return b.getF64Type(); - case F32: - return b.getF32Type(); - case F16: - return b.getF16Type(); - case BF16: - return b.getBF16Type(); - case S64: - return b.getI64Type(); - case S32: - return b.getI32Type(); - case S16: - return b.getI16Type(); - case PRED: - return b.getI1Type(); - case S8: - return b.getI8Type(); - case S4: // The unpacking to i8 is supported by the emitter. - // We pass the s4 tensor as i8 tensor with the minor dimension having 2x - // less elements and unpack in the inner loop of the triton kernel. - return b.getI8Type(); - case F8E5M2: - return b.getFloat8E5M2Type(); - case F8E4M3FN: - return b.getFloat8E4M3FNType(); - default: - return absl::UnimplementedError( - absl::StrCat("This type is not supported yet: ", - primitive_util::LowercasePrimitiveTypeName(t))); - } -} - -Type StorageType(mlir::OpBuilder b, Type t) { - if (t.isInteger(1)) { - return b.getI8Type(); - } - return t; -} - -// Get the value of the scalar constant's literal in a C++ type. -template -T ScalarConstantValue(const HloInstruction& instr, PrimitiveType dst_type) { - CHECK(hlo_query::IsScalarConstant(&instr)); - absl::StatusOr converted = instr.literal().Convert(dst_type); - TF_CHECK_OK(converted.status()); - return converted.value().GetFirstElement(); -} - -// Create a scalar constant. -template -ma::ConstantOp CreateConst(ImplicitLocOpBuilder b, Type type, T value) { - if (mlir::isa(type)) { - return b.create(b.getIntegerAttr(type, value)); - } - if (mlir::isa(type)) { - return b.create( - b.getFloatAttr(type, static_cast(value))); - } - LOG(FATAL) << "Constant type not supported: " << llvm_ir::DumpToString(type); -} - -// Create a tensor constant. -template -ma::ConstantOp CreateConst(ImplicitLocOpBuilder& b, Type type, T value, - ArrayRef shape) { - auto tensor_type = mlir::RankedTensorType::get(shape, type); - if (auto int_type = mlir::dyn_cast(type)) { - return b.create(mlir::DenseElementsAttr::get( - tensor_type, mlir::APInt(int_type.getIntOrFloatBitWidth(), value))); - } - if (auto float_type = mlir::dyn_cast(type)) { - return b.create(mlir::DenseElementsAttr::get( - tensor_type, b.getFloatAttr(type, static_cast(value)))); - } - LOG(FATAL) << "Constant type not supported: " << llvm_ir::DumpToString(type); -} - -Value ZerosLike(ImplicitLocOpBuilder& b, Value x) { - if (auto src_shaped_ty = mlir::dyn_cast(x.getType())) { - Type src_ty = src_shaped_ty.getElementType(); - return CreateConst(b, src_ty, 0, src_shaped_ty.getShape()); - } - return CreateConst(b, x.getType(), 0); -} - -Value OnesLike(ImplicitLocOpBuilder& b, Value x) { - if (auto src_shaped_ty = mlir::dyn_cast(x.getType())) { - Type src_ty = src_shaped_ty.getElementType(); - return CreateConst(b, src_ty, 1, src_shaped_ty.getShape()); - } - return CreateConst(b, x.getType(), 1); -} - -bool IsFp8Type(Type t) { - return t.isFloat8E5M2() || t.isFloat8E4M3FN() || t.isFloat8E5M2FNUZ() || - t.isFloat8E4M3FNUZ() || t.isFloat8E4M3B11FNUZ(); -} - -// Triton type conversions. -Value Cast(ImplicitLocOpBuilder& b, Value value, Type dst_element_ty) { - Type src_ty = value.getType(); - Type src_element_ty = src_ty; - Type fp32_ty = b.getF32Type(); - Type dst_ty = dst_element_ty; - if (auto src_shaped_ty = mlir::dyn_cast(src_ty)) { - src_element_ty = src_shaped_ty.getElementType(); - dst_ty = src_shaped_ty.clone(src_shaped_ty.getShape(), dst_element_ty); - fp32_ty = src_shaped_ty.clone(src_shaped_ty.getShape(), b.getF32Type()); - } - if (src_ty == dst_ty) { - return value; - } - - // All operations on bf16 are done through f32. - if (src_element_ty.isBF16()) { - return Cast(b, b.create(fp32_ty, value), dst_element_ty); - } - if (dst_element_ty.isBF16()) { - // S8 -> BF16 is directly supported and doesn't need to go through f32. - if (!src_element_ty.isInteger(8)) { - return b.create(dst_ty, Cast(b, value, b.getF32Type())); - } - } - - // float => float - auto src_fp_element_ty = mlir::dyn_cast(src_element_ty); - auto dst_fp_element_ty = mlir::dyn_cast(dst_element_ty); - if (src_fp_element_ty && dst_fp_element_ty) { - // F8 <-> FP16, BF16, FP32, FP64 need to be handled via Triton's tt.fp_to_fp - // because LLVM doesn't support casts from/to FP8. - // TODO(b/266862493): Add end-to-end test once FP8 support lands in XLA as - // we can't test the code below without patching the feature. - if (IsFp8Type(src_element_ty)) { - return b.create(dst_ty, value); - } - if (IsFp8Type(dst_element_ty)) { - return b.create( - dst_ty, value, - mt::RoundingModeAttr::get(b.getContext(), mt::RoundingMode::RTNE)); - } - - if (src_fp_element_ty.getFPMantissaWidth() > - dst_fp_element_ty.getFPMantissaWidth()) { - return b.create(dst_ty, value); - } else { - return b.create(dst_ty, value); - } - } - // int => int - if (mlir::isa(src_element_ty) && - mlir::isa(dst_element_ty)) { - if (src_element_ty.getIntOrFloatBitWidth() < - dst_element_ty.getIntOrFloatBitWidth()) { - if (src_element_ty.isInteger(1)) { - return b.create(dst_ty, value); - } - return b.create(dst_ty, value); - } - return b.create(dst_ty, value); - } - // int => float - if (mlir::isa(src_element_ty) && dst_fp_element_ty) { - // TODO(b/266862493): Support unsigned integer types. - if (src_element_ty.isInteger(1)) { - return b.create(dst_ty, value); - } - return b.create(dst_ty, value); - } - // float => int - if (src_fp_element_ty && mlir::isa(dst_element_ty)) { - if (dst_element_ty.isInteger(1)) { - return b.create(ma::CmpFPredicate::UNE, value, - ZerosLike(b, value)); - } - // TODO(b/266862493): Support unsigned integer types. - // The current logic handles signed integer types only. Additional handling - // is needed for unsigned integer types. - auto cst_int = [&](int64_t x) { - if (auto src_shaped_ty = mlir::dyn_cast(src_ty)) { - return CreateConst(b, dst_element_ty, x, src_shaped_ty.getShape()); - } else { - return CreateConst(b, dst_element_ty, x); - } - }; - auto cst_float = [&](int64_t x) { - if (auto src_shaped_ty = mlir::dyn_cast(src_ty)) { - return CreateConst(b, src_fp_element_ty, x, src_shaped_ty.getShape()); - } else { - return CreateConst(b, src_fp_element_ty, x); - } - }; - auto fptosi = b.create(dst_ty, value); - int64_t min = llvm::minIntN(dst_element_ty.getIntOrFloatBitWidth()); - int64_t max = llvm::maxIntN(dst_element_ty.getIntOrFloatBitWidth()); - - // value <= static_cast(INT_MIN) ? INT_MIN : ... - auto clamped = b.create( - b.create(mlir::arith::CmpFPredicate::OLE, value, - cst_float(min)), - cst_int(min), fptosi); - // value >= static_cast(INT_MAX) ? INT_MAX : ... - clamped = b.create( - b.create(mlir::arith::CmpFPredicate::OGE, value, - cst_float(max)), - cst_int(max), clamped); - // isnan(value) ? 0 : ... - return b.create( - b.create(mlir::arith::CmpFPredicate::UNO, value, - value), - cst_int(0), clamped); - } - - LOG(FATAL) << "Type conversion not supported: " - << llvm_ir::DumpToString(src_element_ty) << " -> " - << llvm_ir::DumpToString(dst_element_ty); -} - -Value Subtract(ImplicitLocOpBuilder& b, ValueRange values) { - if (mlir::isa(mlir::getElementTypeOrSelf(values[0]))) { - return b.create(values[0], values[1]); - } else { - return b.create(values[0], values[1]); - } -} - -Value Compare(ImplicitLocOpBuilder& b, ValueRange values, - mlir::mhlo::ComparisonDirection direction) { - const Type type = mlir::getElementTypeOrSelf(values[0]); - if (mlir::isa(type)) { - return b.create( - mlir::mhlo::impl::getCmpPredicate( - direction, - /*isSigned=*/!type.isInteger(1)) - .value(), - values[0], values[1]); - } - return b.create( - mlir::mhlo::impl::getCmpPredicate(direction, - /*isSigned=*/true) - .value(), - values[0], values[1]); -} - -Value Maximum(ImplicitLocOpBuilder& b, const se::DeviceDescription& device_info, - ValueRange values) { - if (mlir::isa(mlir::getElementTypeOrSelf(values[0]))) { - return b.create(values); - } - // logic: isNaN(lhs) || (!isNan(rhs) && lhs >= rhs) ? lhs : rhs - // See also: IEEE Std 754-2008 5.11. - // - // This also works, but we wanted to make it similar to minimum. - // logic: isNaN(lhs) || lhs >= rhs ? lhs : rhs - Value lhs_is_nan = - Compare(b, {values[0], values[0]}, mlir::mhlo::ComparisonDirection::NE); - Value rhs_is_not_nan = - Compare(b, {values[1], values[1]}, mlir::mhlo::ComparisonDirection::EQ); - Value lhs_is_ge = Compare(b, values, mlir::mhlo::ComparisonDirection::GE); - return b.create( - b.create(lhs_is_nan, - b.create(rhs_is_not_nan, lhs_is_ge)), - values[0], values[1]); -} - -Value Minimum(ImplicitLocOpBuilder& b, const se::DeviceDescription& device_info, - ValueRange values) { - if (mlir::isa(mlir::getElementTypeOrSelf(values[0]))) { - return b.create(values); - } - // logic: isNaN(lhs) || (!isNan(rhs) && lhs <= rhs) ? lhs : rhs - // See also: IEEE Std 754-2008 5.11. - // - // This should also work, but the tests show that it doesn't work for - // minimum(x, NaN): - // logic: isNaN(lhs) || lhs <= rhs ? lhs : rhs - Value lhs_is_nan = - Compare(b, {values[0], values[0]}, mlir::mhlo::ComparisonDirection::NE); - Value rhs_is_not_nan = - Compare(b, {values[1], values[1]}, mlir::mhlo::ComparisonDirection::EQ); - Value lhs_is_le = Compare(b, values, mlir::mhlo::ComparisonDirection::LE); - return b.create( - b.create(lhs_is_nan, - b.create(rhs_is_not_nan, lhs_is_le)), - values[0], values[1]); -} +using ::xla::gpu::triton::Cast; +using ::xla::gpu::triton::CreateConst; +using ::xla::gpu::triton::EmitConstant; +using ::xla::gpu::triton::EmitElementwise; +using ::xla::gpu::triton::GetPaddedTileSizes; +using ::xla::gpu::triton::ScalarOrTensor; +using ::xla::gpu::triton::StorageType; +using ::xla::gpu::triton::TritonType; -// TODO(b/269489810): Contribute nicer builders to Triton, so we don't need to -// define these utilities. -Value Splat(ImplicitLocOpBuilder& b, Value value, ArrayRef shape) { - auto type = mlir::RankedTensorType::get(shape, value.getType()); - return b.create(type, value); -} +namespace { using TensorValue = mlir::TypedValue; -Value Broadcast(ImplicitLocOpBuilder& b, TensorValue value, - ArrayRef shape) { - return b.create(value.getType().clone(shape), value); +ScalarOrTensor Broadcast(ImplicitLocOpBuilder& b, TensorValue value, + ArrayRef shape) { + return ScalarOrTensor( + b.create(value.getType().clone(shape), value)); } -Value Range(ImplicitLocOpBuilder& b, int32_t limit) { +ScalarOrTensor Range(ImplicitLocOpBuilder& b, int32_t limit) { auto type = mlir::RankedTensorType::get(limit, b.getI32Type()); - return b.create(type, 0, limit); + return ScalarOrTensor(b.create(type, 0, limit)); } Value AddPtr(ImplicitLocOpBuilder& b, Value ptr, Value offset) { return b.create(ptr.getType(), ptr, offset); } -absl::StatusOr EmitElementwise(ImplicitLocOpBuilder& b, - absl::string_view libdevice_path, - const se::DeviceDescription& device_info, - const HloInstruction& hlo, - ValueRange inputs) { - if (mlir::getElementTypeOrSelf(inputs[0]).isF32() || - mlir::getElementTypeOrSelf(inputs[0]).isF64()) { - auto dev_fn_id = GetTargetDeviceFunctionID(hlo.opcode()); - if (dev_fn_id.ok()) { - llvm::Triple triple("nvptx64-unknown-unknown"); - if (std::holds_alternative( - device_info.gpu_compute_capability())) { - triple.setTriple("amdgcn-unknown-unknown"); - } - return b.create( - inputs[0].getType(), inputs, "libdevice", libdevice_path, - ObtainDeviceFunctionName(dev_fn_id.value(), - hlo.shape().element_type(), triple), - /*pure=*/true); - } - } - const bool is_integer = - mlir::isa(mlir::getElementTypeOrSelf(inputs[0])); - - switch (hlo.opcode()) { - case HloOpcode::kCopy: - // Dimension transformations are taken care of separately. - return inputs[0]; - case HloOpcode::kAbs: - if (is_integer) { - return b.create(inputs[0]); - } - return b.create(inputs[0]); - case HloOpcode::kCeil: - return b.create(inputs[0]); - case HloOpcode::kFloor: - return b.create(inputs[0]); - case HloOpcode::kNot: - return b.create(inputs[0], OnesLike(b, inputs[0])); - case HloOpcode::kNegate: - // NegFOp is not supported by Triton. - return Subtract(b, {ZerosLike(b, inputs[0]), inputs[0]}); - case HloOpcode::kConvert: { - TF_ASSIGN_OR_RETURN(Type dst_ty, - TritonType(b, hlo.shape().element_type())); - return Cast(b, inputs[0], dst_ty); - } - case HloOpcode::kAdd: - if (is_integer) { - return b.create(inputs[0], inputs[1]); - } - return b.create(inputs[0], inputs[1]); - case HloOpcode::kSubtract: - return Subtract(b, inputs); - case HloOpcode::kMultiply: - if (is_integer) { - return b.create(inputs[0], inputs[1]); - } - return b.create(inputs[0], inputs[1]); - case HloOpcode::kMaximum: - return Maximum(b, device_info, inputs); - case HloOpcode::kMinimum: - return Minimum(b, device_info, inputs); - case HloOpcode::kClamp: - return Maximum( - b, device_info, - {Minimum(b, device_info, {inputs[1], inputs[2]}), inputs[0]}); - case HloOpcode::kAnd: - return b.create(inputs[0], inputs[1]); - case HloOpcode::kOr: - return b.create(inputs[0], inputs[1]); - case HloOpcode::kXor: - return b.create(inputs[0], inputs[1]); - case HloOpcode::kDivide: - if (is_integer) { - // Unsigned not supported yet. - return b.create(inputs[0], inputs[1]); - } - return b.create(inputs[0], inputs[1]); - case HloOpcode::kCompare: - return Compare( - b, inputs, - mlir::mhlo::symbolizeComparisonDirection( - ComparisonDirectionToString(hlo.comparison_direction())) - .value()); - case HloOpcode::kSelect: - return b.create( - Compare(b, {inputs[0], ZerosLike(b, inputs[0])}, - mlir::mhlo::ComparisonDirection::NE), - inputs[1], inputs[2]); - default: - return absl::InvalidArgumentError( - absl::StrCat("Unsupported elementwise operation ", hlo.ToString())); - } -} - -Value EmitParameterLoad(ImplicitLocOpBuilder& b, Value pointer, - ArrayRef boundary_checks) { - // 0-D MakeTensorPtrOp - // - // Triton tries to access the -1 element of a vector and segfaults when - // lowering the code to load a 0-D tensor to LLVM. The workaround is to load a - // regular pointer + a splat. +ScalarOrTensor EmitParameterLoad(ImplicitLocOpBuilder& b, Value pointer, + ArrayRef boundary_checks) { if (auto make_tensor_ptr = pointer.getDefiningOp()) { if (make_tensor_ptr.getOffsets().empty()) { - return Splat(b, - b.create(make_tensor_ptr.getBase(), - mt::CacheModifier::NONE, - mt::EvictionPolicy::NORMAL, - /*isVolatile=*/false), - {}); + return ScalarOrTensor(b.create(make_tensor_ptr.getBase(), + mt::CacheModifier::NONE, + mt::EvictionPolicy::NORMAL, + /*isVolatile=*/false)); } } @@ -615,123 +187,28 @@ Value EmitParameterLoad(ImplicitLocOpBuilder& b, Value pointer, if (!boundary_checks.empty()) { padding = mt::PaddingOption::PAD_ZERO; } - return b.create(pointer, boundary_checks, padding, - mt::CacheModifier::NONE, - mt::EvictionPolicy::NORMAL, - /*isVolatile=*/false); + return ScalarOrTensor(b.create(pointer, boundary_checks, + padding, mt::CacheModifier::NONE, + mt::EvictionPolicy::NORMAL, + /*isVolatile=*/false)); } // Non-tensor pointer. - // - // TODO(b/343013366): Remove this after we delete the legacy SoftMax code. - // It's the only place where this code-path is used. - return Splat(b, - b.create(pointer, mt::CacheModifier::NONE, - mt::EvictionPolicy::NORMAL, - /*isVolatile=*/false), - {}); + return ScalarOrTensor(b.create(pointer, mt::CacheModifier::NONE, + mt::EvictionPolicy::NORMAL, + /*isVolatile=*/false)); } -absl::StatusOr EmitConstant(ImplicitLocOpBuilder& b, - const HloInstruction& constant) { - TF_ASSIGN_OR_RETURN(Type ty, TritonType(b, constant.shape().element_type())); - if (constant.shape().IsInteger()) { - if (constant.shape().element_type() == U64) { - return CreateConst(b, ty, ScalarConstantValue(constant, U64)); - } else { - return CreateConst(b, ty, ScalarConstantValue(constant, S64)); - } - } - return CreateConst(b, ty, ScalarConstantValue(constant, F64)); -} - -// Grouped properties of tiled dimensions used to generate block pointers. -struct DimProperties { - DimProperties(int64_t index, Value pid, int block_size, int split_value) - : index(index), - pid(pid), - block_size(block_size), - split_value(split_value) {} - - // Logical index of the dimension at the tiling-defining operation. - int64_t index; - // Block program ID corresponding to this dimension. - Value pid; - // Elements of the dimension to process per block program. - int block_size; - // Size of the major part of the dimension if it's split into two parts. - int split_value; -}; - -struct Side { - explicit Side(TritonFusionAnalysis::Scope scope, - std::vector tiled_dims = {}, - std::optional batch_dim_idx = std::nullopt) - : scope(scope), tiled_dims(tiled_dims), batch_dim_idx(batch_dim_idx) {} - TritonFusionAnalysis::Scope scope; - std::vector tiled_dims; - std::optional batch_dim_idx; - int64_t unpack_dim_idx = 0; -}; - -absl::StatusOr EmitBroadcast(ImplicitLocOpBuilder& b, - const TritonFusionAnalysis* analysis, - const Side& side, - const HloInstruction& broadcast, - Value input) { - TF_RET_CHECK(analysis != nullptr); - std::vector out_shape; - for (const DimProperties& dim : side.tiled_dims) { - const TensorIterationSpec::DimIterationSpec* spec = - analysis->IterSpec(side.scope, &broadcast, dim.index); - if (spec != nullptr && spec->at(0).stride > 0) { - out_shape.push_back(dim.block_size); - } - } - auto tensor_input = mlir::dyn_cast(input); - if (!tensor_input) { - // Input is scalar. - return Splat(b, input, out_shape); - } - if (tensor_input.getType().getRank() == out_shape.size()) { - // No dimensions to broadcast. - return input; - } - // Add broadcasted dimensions one by one. - Value expanded_input = tensor_input; - int dim_idx = 0; - for (const DimProperties& dim : side.tiled_dims) { - if (auto* spec = analysis->IterSpec(side.scope, &broadcast, dim.index); - spec != nullptr && spec->at(0).stride > 0) { - if (analysis->IterSpec(side.scope, broadcast.operand(0), dim.index) == - nullptr) { - // Broadcasted dimension. - expanded_input = b.create(expanded_input, dim_idx); - } - ++dim_idx; - } - } - return Broadcast(b, mlir::cast(expanded_input), out_shape); -} - -absl::StatusOr EmitScope( +absl::StatusOr EmitScope( ImplicitLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, - const TritonFusionAnalysis* analysis, const Side& side, + const TritonFusionAnalysis* analysis, absl::Span instructions, - absl::flat_hash_map& values); + absl::flat_hash_map& values); -// Adds `n` leading `1` dimensions to the input tensor. -Value LeftExpandDimNTimes(ImplicitLocOpBuilder& b, Value input, int64_t n) { - for (int i = 0; i < n; ++i) { - input = b.create(input, /*axis=*/0); - } - return input; -} - -absl::StatusOr EmitReduce( +absl::StatusOr EmitReduce( ImplicitLocOpBuilder& b, const TiledHloInstruction& tiled_hlo_reduce, - absl::flat_hash_map& values, + absl::flat_hash_map& values, absl::string_view libdevice_path, const se::DeviceDescription& device_info) { // At the moment, we should only emit a full reduction over a single @@ -741,9 +218,9 @@ absl::StatusOr EmitReduce( TF_RET_CHECK(hlo_reduce.operand_count() == 2); TF_RET_CHECK(hlo_reduce.dimensions().size() == 1); - Value input = values[tiled_hlo_reduce.operand(0)]; + ScalarOrTensor input = values[tiled_hlo_reduce.operand(0)]; llvm::ArrayRef input_shape = - mlir::cast(input.getType()).getShape(); + mlir::cast(input.Type()).getShape(); absl::Span source_tensor_shape = hlo_reduce.operand(0)->shape().dimensions(); @@ -755,14 +232,14 @@ absl::StatusOr EmitReduce( // structure (element_type, hlo_reduce.to_apply(), hlo_reduce.operand(1))--- // up to floating-point inaccuracies. Masking the input using // hlo_reduce.operand(1) is thus always the right choice to ensure that the - // reduction is computed correctly, since it is the neutral value with regards - // to the reducer. + // reduction is computed correctly, since it is the neutral value with + // regards to the reducer. int64_t source_tensor_reduction_dimension_size = source_tensor_shape[reduction_dimension]; int64_t input_reduction_dimension_size = input_shape[reduction_dimension]; if (input_reduction_dimension_size != source_tensor_reduction_dimension_size) { - Value range = Range(b, input_reduction_dimension_size); + Value range = Range(b, input_reduction_dimension_size).UnwrapUnsafe(); // Triton's broadcast requires that the rank of the source and broadcasted // result are equal. for (int i = 0; i < input_shape.size() - 1; i++) { @@ -772,23 +249,32 @@ absl::StatusOr EmitReduce( range = b.create(range, /*axis=*/i + 1); } } - Value mask = Broadcast(b, mlir::cast(range), input_shape); - Value constant = - CreateConst(b, b.getI32Type(), source_tensor_reduction_dimension_size); - Value constant_tensor = Splat(b, constant, input_shape); - mask = b.create(ma::CmpIPredicate::slt, mask, constant_tensor); - - Value neutral = values[tiled_hlo_reduce.operand(1)]; + Value mask = Broadcast(b, mlir::cast(range), input_shape) + .UnwrapUnsafe(); + ScalarOrTensor constant = CreateConst( + b, b.getI32Type(), source_tensor_reduction_dimension_size, input_shape); + mask = b.create(ma::CmpIPredicate::slt, mask, + constant.UnwrapUnsafe()); + + ScalarOrTensor neutral = values[tiled_hlo_reduce.operand(1)]; // Triton's broadcast requires that the rank of the source and broadcasted // result are equal. - for (int i = 0; i < input_shape.size(); i++) { - neutral = b.create(neutral, /*axis=*/0); + if (neutral.IsScalar()) { + neutral = Splat(b, neutral, input_shape); + } else { + for (int i = 0; i < input_shape.size(); i++) { + neutral = ScalarOrTensor( + b.create(neutral.UnwrapUnsafe(), /*axis=*/0)); + } + neutral = Broadcast(b, mlir::cast(neutral.UnwrapUnsafe()), + input_shape); } - neutral = Broadcast(b, mlir::cast(neutral), input_shape); - input = b.create(mask, input, neutral); + input = ScalarOrTensor(b.create(mask, input.UnwrapUnsafe(), + neutral.UnwrapUnsafe())); } - mt::ReduceOp reduction = b.create(input, reduction_dimension); + mt::ReduceOp reduction = + b.create(input.UnwrapUnsafe(), reduction_dimension); { TF_ASSIGN_OR_RETURN(Type result_ty, TritonType(b, hlo_reduce.shape().element_type())); @@ -799,16 +285,16 @@ absl::StatusOr EmitReduce( HloComputation* reduction_computation = hlo_reduce.to_apply(); std::vector to_emit; - absl::flat_hash_map region_values; + absl::flat_hash_map region_values; for (const HloInstruction* instr : reduction_computation->MakeInstructionPostOrder()) { if (instr->opcode() == HloOpcode::kParameter) { int parameter_number = instr->parameter_number(); TF_RET_CHECK(parameter_number < 2); - TF_RET_CHECK( - region_values - .insert({instr, reducer->getArgument(parameter_number)}) - .second); + TF_RET_CHECK(region_values + .insert({instr, ScalarOrTensor(reducer->getArgument( + parameter_number))}) + .second); } else { to_emit.push_back(instr); } @@ -818,23 +304,14 @@ absl::StatusOr EmitReduce( b.setInsertionPointToStart(reducer); TF_ASSIGN_OR_RETURN( - Value result, - EmitScope(b, libdevice_path, device_info, /*analysis=*/nullptr, - Side(TritonFusionAnalysis::Scope::OUTPUT), to_emit, + ScalarOrTensor result, + EmitScope(b, libdevice_path, device_info, /*analysis=*/nullptr, to_emit, region_values)); - b.create(SmallVector({result})); + b.create(SmallVector({result.UnwrapUnsafe()})); b.setInsertionPointAfter(reduction); } - Value result = reduction.getResult().front(); - - // We want to return a tensor, but the ReturnReduceOp produces a raw scalar - // when reducing a single dim. To convert to a tensor we splat the result. - if (!mlir::dyn_cast(reduction.getResult().front())) { - result = Splat(b, result, {}); - } - - return result; + return ScalarOrTensor(reduction.getResult().front()); } // Emit code corresponding to a fusion instruction somehow nested within the @@ -843,17 +320,17 @@ absl::StatusOr EmitReduce( // fusion, we simply flatten the fusion inside the computation. // // TODO(b/331413981): get rid of this special handling once this is solved. -absl::StatusOr EmitNestedFusion( +absl::StatusOr EmitNestedFusion( ImplicitLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, const HloFusionInstruction& fusion_instruction, - absl::flat_hash_map& values) { - // TODO(b/331402498): revisit the order of scope once we completely deprecate - // Triton fusion analysis. + absl::flat_hash_map& values) { + // TODO(b/331402498): revisit the order of scope once we completely + // deprecate Triton fusion analysis. const HloComputation* fusion_computation = fusion_instruction.fused_instructions_computation(); - absl::flat_hash_map region_values; + absl::flat_hash_map region_values; std::vector to_emit; for (const HloInstruction* instr : @@ -871,35 +348,38 @@ absl::StatusOr EmitNestedFusion( TF_RET_CHECK(to_emit.back() == fusion_computation->root_instruction()); return EmitScope(b, libdevice_path, device_info, /*analysis=*/nullptr, - Side(TritonFusionAnalysis::Scope::OUTPUT), to_emit, - region_values); + to_emit, region_values); } -Value EmitTiledBroadcast( +ScalarOrTensor EmitTiledBroadcast( ImplicitLocOpBuilder& b, const TiledHloInstruction& tiled_broadcast, - absl::flat_hash_map& values) { + absl::flat_hash_map& values) { const llvm::SmallVector& input_tile_shape = tiled_broadcast.operand(0)->tile_sizes(); const llvm::SmallVector& output_tile_shape = tiled_broadcast.tile_sizes(); + + if (input_tile_shape.empty() && output_tile_shape.empty()) { + return values[tiled_broadcast.operand(0)]; + } + CHECK(!output_tile_shape.empty()); + SmallVector padded_output_tile_shape = GetPaddedTileSizes(output_tile_shape); - Value expanded_input = values[tiled_broadcast.operand(0)]; + ScalarOrTensor input = values[tiled_broadcast.operand(0)]; + // Handle the 0d special case. + if (input.IsScalar()) { + return Splat(b, input, padded_output_tile_shape); + } + + Value expanded_input = input.UnwrapTensor(); // Returns true if `dim_id` is broadcasted. auto is_broadcasted_dim = [&](int64_t dim_id) { return !llvm::is_contained(tiled_broadcast.hlo()->dimensions(), dim_id); }; - // Handle the 0d special case. - if (input_tile_shape.empty()) { - expanded_input = - LeftExpandDimNTimes(b, expanded_input, output_tile_shape.size()); - return Broadcast(b, mlir::cast(expanded_input), - padded_output_tile_shape); - } - // The loop below iterates over output dimensions and tracks matching dims in // input_tile_shape and expended_input value. // `input_dim_id != expanded_input_dim_id`, because size-1 dims are present in @@ -928,20 +408,128 @@ Value EmitTiledBroadcast( padded_output_tile_shape); } -Value EmitTiledReshape(ImplicitLocOpBuilder& b, ArrayRef tile_sizes, - Value input) { +absl::StatusOr EmitTiledIota( + ImplicitLocOpBuilder& b, ValueRange tile_multi_index, + const TiledHloInstruction& tiled_iota) { + const HloIotaInstruction* hlo_iota = + ::xla::Cast(tiled_iota.hlo()); + int64_t iota_dim = hlo_iota->iota_dimension(); + + SmallVector padded_tile_sizes = + GetPaddedTileSizes(tiled_iota.tile_sizes()); + + // We can treat iota more or less as a parameter load, except that we need to + // generate the right values in the right place as opposed to loading them. + TF_ASSIGN_OR_RETURN(IndexingMap tile_offsets_indexing, + tiled_iota.tile_offsets_indexing()); + + auto iota_dim_offset = b.create( + b.getI32Type(), mlir_converter::ApplyIndexing( + tile_offsets_indexing, /*dims=*/tile_multi_index, + /*symbols=*/{}, b)[iota_dim]); + + // First, stride as needed between the iota components. + Value range = b.create( + Range(b, padded_tile_sizes[iota_dim]).UnwrapTensor(), + Splat(b, + CreateConst(b, b.getI32Type(), tiled_iota.tile_strides()[iota_dim]), + padded_tile_sizes[iota_dim]) + .UnwrapTensor()); + + // Then, add the base offset to the iota components. + range = b.create(range, Splat(b, ScalarOrTensor(iota_dim_offset), + padded_tile_sizes[iota_dim]) + .UnwrapTensor()); + + // Cast the result to the targeted type. + TF_ASSIGN_OR_RETURN(Type iota_element_type, + TritonType(b, hlo_iota->shape().element_type())); + + range = Cast(b, range, iota_element_type); + + // And finally, produce a broadcast along the non-iota dimensions in order to + // produce the whole iota tile. + for (int i = 0; i < padded_tile_sizes.size() - 1; i++) { + if (i < iota_dim) { + range = b.create(range, /*axis=*/0); + } else { + range = b.create(range, /*axis=*/i + 1); + } + } + + return Broadcast(b, mlir::cast(range), padded_tile_sizes); +} + +// Reshapes a non-0D tensor of shape [1, 1, 1, ...] to a scalar. +ScalarOrTensor ReshapeTensorToScalar(ImplicitLocOpBuilder& b, Value input) { + auto element_type = mlir::cast(input.getType()).getElementType(); + + // First, reshape to a 1D tensor if not already the case. This is needed + // because triton::ReduceOp can only reduce 1 dimension at a time. + auto single_dim_tensor = input; + if (mlir::cast(input.getType()).getRank() > 1) { + Type output_tensor_type = mlir::RankedTensorType::get({1}, element_type); + single_dim_tensor = b.create(output_tensor_type, input, + /*allow_reorder=*/true); + } + + // Second, reduce to a scalar. + mt::ReduceOp reduction = + b.create(single_dim_tensor, /*axis=*/0); + + mlir::Location loc = b.getLoc(); + mlir::Block* reducer = b.createBlock( + &reduction->getRegion(0), /*insertPt=*/{}, + /*argTypes=*/{element_type, element_type}, /*locs=*/{loc, loc}); + + b.setInsertionPointToStart(reducer); + Value result = mlir::isa(element_type) + ? b.create(reducer->getArgument(0), + reducer->getArgument(1)) + .getResult() + : b.create(reducer->getArgument(0), + reducer->getArgument(1)) + .getResult(); + b.create(SmallVector({result})); + b.setInsertionPointAfter(reduction); + + return ScalarOrTensor(reduction.getResult().front()); +} + +absl::StatusOr EmitTiledReshape(ImplicitLocOpBuilder& b, + ArrayRef tile_sizes, + ScalarOrTensor input) { SmallVector padded_tile_sizes = GetPaddedTileSizes(tile_sizes); - Type input_element_type = - mlir::cast(input.getType()).getElementType(); - Type output_tensor_type = - mlir::RankedTensorType::get(padded_tile_sizes, input_element_type); + if (input.IsScalar()) { + if (tile_sizes.empty()) { + // Nothing to do. + return input; + } + // Convert the scalar to a tensor. + return Splat(b, input, padded_tile_sizes); + } + + // At this point we know that the input is a non-0D tensor. + + auto input_shaped_type = mlir::cast(input.Type()); + + // Handle the case of reshaping [1,1,1...] to a scalar. + if (tile_sizes.empty()) { + return ReshapeTensorToScalar(b, input.UnwrapTensor()); + } + + // At this point we know that neither the input nor the output are 0D tensors. + + Type output_tensor_type = mlir::RankedTensorType::get( + padded_tile_sizes, input_shaped_type.getElementType()); // Conservatively prevent Triton from reordering elements within the tile. // TODO(b/353637689): see if this restriction can be lifted. bool allow_reorder = false; - return b.create(output_tensor_type, input, allow_reorder) - .getResult(); + auto reshape = b.create(output_tensor_type, + input.UnwrapUnsafe(), allow_reorder); + return ScalarOrTensor(reshape.getResult()); } Value EmitTiledTranspose(ImplicitLocOpBuilder& b, ArrayRef tile_sizes, @@ -958,8 +546,9 @@ Value EmitTiledTranspose(ImplicitLocOpBuilder& b, ArrayRef tile_sizes, return b.create(output_tensor_type, input, order); } -Value EmitTiledBitcast(ImplicitLocOpBuilder& b, - const TiledHloInstruction& tiled_bitcast, Value input) { +absl::StatusOr EmitTiledBitcast( + ImplicitLocOpBuilder& b, const TiledHloInstruction& tiled_bitcast, + Value input) { // Any Bitcast is decomposable to a transpose+reshape+transpose. auto trt = ShapeUtil::DecomposeBitcastToTrt( tiled_bitcast.hlo()->operand(0)->shape(), tiled_bitcast.hlo()->shape()); @@ -989,29 +578,35 @@ Value EmitTiledBitcast(ImplicitLocOpBuilder& b, // are a permutation (according to transpose2_dims) of the tile sizes of // the reshape. Since we know the tile sizes of the final transpose and need // the tile sizes of the reshape, we compute the tile sizes backwards, taking - // the inreverse permutation. + // the inverse permutation. std::vector reshape_tile_sizes = PermuteInverse(tiled_bitcast.tile_sizes(), trt.transpose2_dims); - Value normalized_reshape = - ShapeUtil::Equal(trt.transpose1_shape, trt.reshape_shape) - ? normalized_input - : EmitTiledReshape(b, reshape_tile_sizes, normalized_input); + Value normalized_reshape; + if (ShapeUtil::Equal(trt.transpose1_shape, trt.reshape_shape)) { + normalized_reshape = normalized_input; + } else { + TF_ASSIGN_OR_RETURN(auto reshape, + EmitTiledReshape(b, reshape_tile_sizes, + ScalarOrTensor(normalized_input))); + normalized_reshape = reshape.UnwrapUnsafe(); + } // The final transpose simply uses the tile sizes computed for the original // bitcast by the tiling analysis. - return trt.IsTranspose2Identity() - ? normalized_reshape - : EmitTiledTranspose(b, tiled_bitcast.tile_sizes(), - llvm::to_vector(trt.transpose2_dims), - normalized_reshape); + return ScalarOrTensor{ + trt.IsTranspose2Identity() + ? normalized_reshape + : EmitTiledTranspose(b, tiled_bitcast.tile_sizes(), + llvm::to_vector(trt.transpose2_dims), + normalized_reshape)}; } -absl::StatusOr EmitTiledHloInstruction( +absl::StatusOr EmitTiledHloInstruction( ImplicitLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, const HloFusionInstruction* fusion, const TiledHloInstruction& tiled_hlo, mlir::triton::FuncOp fn, ValueRange tile_multi_index, - absl::flat_hash_map& values) { + absl::flat_hash_map& values) { const HloInstruction* hlo = tiled_hlo.hlo(); if (fusion->IsUserOf(hlo)) { @@ -1020,15 +615,14 @@ absl::StatusOr EmitTiledHloInstruction( b, tile_multi_index, tiled_hlo, fn.getArgument(fusion->operand_index(hlo)))); - Value parameter = + ScalarOrTensor parameter = EmitParameterLoad(b, make_tensor.op, make_tensor.boundary_checks); // Some types are stored using different types, e.g. i1 is stored in memory - // as i8. It's important to type checking that we perform a conversion - // after loading if the type of the loaded parameter does not match what - // is expected. - Type loaded_element_type = - mlir::cast(parameter.getType()).getElementType(); + // as i8. It's important to type checking that we perform a conversion after + // loading if the type of the loaded parameter does not match what is + // expected. + Type loaded_element_type = getElementTypeOrSelf(parameter.Type()); TF_ASSIGN_OR_RETURN(Type expected_element_type, TritonType(b, hlo->shape().element_type())); @@ -1041,22 +635,25 @@ absl::StatusOr EmitTiledHloInstruction( "while lowering ", fusion->called_computation()->ToString())); } - parameter = Cast(b, parameter, expected_element_type); + parameter = ScalarOrTensor( + Cast(b, parameter.UnwrapUnsafe(), expected_element_type)); } return parameter; } if (hlo->opcode() == HloOpcode::kConstant) { - if (ShapeUtil::IsScalar(hlo->shape())) { - TF_ASSIGN_OR_RETURN(Value constant, EmitConstant(b, *hlo)); - // Splat makes it a tensor to avoid type mismatches. - return Splat(b, constant, {}); + if (ShapeUtil::IsEffectiveScalar(hlo->shape())) { + return EmitConstant(b, *hlo); } return absl::UnimplementedError( absl::StrCat("Unsupported non-scalar constant ", hlo->ToString())); } + if (hlo->opcode() == HloOpcode::kIota) { + return EmitTiledIota(b, tile_multi_index, tiled_hlo); + } + if (hlo->opcode() == HloOpcode::kBroadcast) { return EmitTiledBroadcast(b, tiled_hlo, values); } @@ -1070,9 +667,12 @@ absl::StatusOr EmitTiledHloInstruction( operands.reserve(hlo->operands().size()); for (const TiledHloInstruction* operand : tiled_hlo.operands()) { - operands.push_back(values[operand]); + operands.push_back(values[operand].UnwrapUnsafe()); } - return EmitElementwise(b, libdevice_path, device_info, *hlo, operands); + TF_ASSIGN_OR_RETURN( + Value result, + EmitElementwise(b, libdevice_path, device_info, *hlo, operands)); + return ScalarOrTensor(result); } if (hlo->opcode() == HloOpcode::kReshape) { @@ -1081,15 +681,16 @@ absl::StatusOr EmitTiledHloInstruction( } if (hlo->opcode() == HloOpcode::kBitcast) { - return EmitTiledBitcast(b, tiled_hlo, values[tiled_hlo.operand(0)]); + return EmitTiledBitcast(b, tiled_hlo, + values[tiled_hlo.operand(0)].UnwrapUnsafe()); } if (hlo->opcode() == HloOpcode::kTranspose) { auto transpose = ::xla::Cast(tiled_hlo.hlo()); - return EmitTiledTranspose(b, tiled_hlo.tile_sizes(), - llvm::to_vector(transpose->dimensions()), - values[tiled_hlo.operand(0)]); + return ScalarOrTensor(EmitTiledTranspose( + b, tiled_hlo.tile_sizes(), llvm::to_vector(transpose->dimensions()), + values[tiled_hlo.operand(0)].UnwrapUnsafe())); } // Slice is currently supported only as an operation on indices @@ -1104,17 +705,17 @@ absl::StatusOr EmitTiledHloInstruction( // Emit sequence of instructions using compatible tiling ordered producers // before consumers. -absl::StatusOr EmitTiledScope( +absl::StatusOr EmitTiledScope( ImplicitLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, const HloFusionInstruction* fusion, const TiledHloComputation& tiled_computation, mlir::triton::FuncOp fn, ValueRange tile_multi_index) { - absl::flat_hash_map values; + absl::flat_hash_map values; for (const TiledHloInstruction* tiled_hlo : tiled_computation.instructions()) { TF_ASSIGN_OR_RETURN( - Value result, + ScalarOrTensor result, EmitTiledHloInstruction(b, libdevice_path, device_info, fusion, *tiled_hlo, fn, tile_multi_index, values)); TF_RET_CHECK(values.insert({tiled_hlo, result}).second) @@ -1125,58 +726,18 @@ absl::StatusOr EmitTiledScope( return values[tiled_computation.GetRoot()]; } -// Emit sequence of operations for unpacking 2xi4 -> i8. -absl::StatusOr EmitUnpackInt4(ImplicitLocOpBuilder& b, - const HloInstruction* hlo, - const Side& side, Value& value) { - VLOG(6) << "EmitUnpackInt4: " << hlo->ToString(); - auto input_type = mlir::cast(value.getType()); - if (input_type.getShape().size() != 2) { - return absl::InvalidArgumentError( - absl::StrCat("UnpackInt4 works only for 2d inputs: ", hlo->ToString())); - } - // We use shifts instead the mask because we need to keep the sign bit. - Value shift4 = - Splat(b, CreateConst(b, b.getI8Type(), 4), input_type.getShape()); - Value lo = b.create(b.create(value, shift4), shift4); - Value hi = b.create(value, shift4); - Value result = b.create(hi, lo); - if (side.unpack_dim_idx == 0) { - result = b.create(result, b.getDenseI32ArrayAttr({0, 2, 1})); - } - SmallVector result_shape(input_type.getShape()); - result_shape[side.unpack_dim_idx] *= 2; - auto type = mlir::RankedTensorType::get(result_shape, b.getI8Type()); - return b.create(type, result, /*allow_reorder=*/false); -} - // Emit sequence of instructions using compatible tiling ordered producers // before consumers. -absl::StatusOr EmitScope( +absl::StatusOr EmitScope( ImplicitLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, - const TritonFusionAnalysis* analysis, const Side& side, + const TritonFusionAnalysis* analysis, absl::Span instructions, - absl::flat_hash_map& values) { + absl::flat_hash_map& values) { for (const HloInstruction* hlo : instructions) { - Value result; - if (hlo->opcode() == HloOpcode::kConvert && - hlo->operand(0)->shape().element_type() == S4) { - if (!hlo->GetModule() - ->config() - .debug_options() - .xla_gpu_enable_triton_gemm_int4()) { - return absl::UnimplementedError( - "Int4 support is not enabled in the debug options."); - } - - TF_ASSIGN_OR_RETURN( - auto unpacked, EmitUnpackInt4(b, hlo, side, values[hlo->operand(0)])); - std::vector operands({unpacked}); - TF_ASSIGN_OR_RETURN(result, EmitElementwise(b, libdevice_path, - device_info, *hlo, operands)); - } else if (hlo->opcode() == HloOpcode::kConcatenate || - hlo->opcode() == HloOpcode::kDynamicSlice) { + ScalarOrTensor result; + if (hlo->opcode() == HloOpcode::kConcatenate || + hlo->opcode() == HloOpcode::kDynamicSlice) { // Parameter loads and their concatenations are handled outside EmitScope. TF_RET_CHECK(values.contains(hlo)) << hlo->ToString(); continue; @@ -1188,20 +749,20 @@ absl::StatusOr EmitScope( TF_RET_CHECK(values.contains(hlo)) << hlo->ToString(); continue; } else if (hlo->opcode() == HloOpcode::kConstant) { - TF_ASSIGN_OR_RETURN(Value constant, EmitConstant(b, *hlo)); - // Splat makes it a tensor to avoid type mismatches. - result = Splat(b, constant, {}); + return EmitConstant(b, *hlo); } else if (hlo->opcode() == HloOpcode::kBroadcast) { - TF_ASSIGN_OR_RETURN(result, EmitBroadcast(b, analysis, side, *hlo, - values[hlo->operand(0)])); + return absl::InvalidArgumentError( + "Broadcast is not yet supported in EmitScope()."); } else if (HloInstruction::IsOpElementwise(hlo->opcode())) { std::vector operands; operands.reserve(hlo->operands().size()); for (const HloInstruction* operand : hlo->operands()) { - operands.push_back(values[operand]); + operands.push_back(values[operand].UnwrapUnsafe()); } - TF_ASSIGN_OR_RETURN(result, EmitElementwise(b, libdevice_path, - device_info, *hlo, operands)); + TF_ASSIGN_OR_RETURN( + Value elementwise_result, + EmitElementwise(b, libdevice_path, device_info, *hlo, operands)); + result = ScalarOrTensor(elementwise_result); } else if (hlo->opcode() == HloOpcode::kTuple) { TF_RET_CHECK(hlo->IsRoot()) << hlo->ToString(); } else if (hlo->opcode() == HloOpcode::kBitcast || @@ -1228,1524 +789,6 @@ absl::StatusOr EmitScope( return values[instructions.back()]; } -const TensorIterationSpec::DimIterationSpec* GetLhsNoncontractingSplitSpec( - const TritonFusionAnalysis& analysis, int64_t lhs_noncontracting_dim_idx) { - const TensorIterationSpec::DimIterationSpec* result = nullptr; - for (const HloInstruction* lhs_param : - analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS)) { - const TensorIterationSpec::DimIterationSpec* spec = - analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, lhs_param, - lhs_noncontracting_dim_idx); - if (spec != nullptr && spec->size() > 1) { - CHECK_EQ(spec->size(), 2); - if (result != nullptr) { - CHECK_EQ(result->at(0).count, spec->at(0).count); - CHECK_EQ(result->at(1).count, spec->at(1).count); - } - result = spec; - } - } - return result; -} - -// Structure for parameters relating to the MatMul shape and dimension indices. -// -// Variable naming: lhs [m, k] x rhs [k, n] -> out [m, n]. -// -// The logical output dimensions are always ordered as: -// split-K, batch, non-contracting LHS, non-contracting RHS, -// where split-K and batch are optional. -struct MatMulDims { - static absl::StatusOr Create( - const TritonGemmConfig& config, const HloDotInstruction& dot, - const TritonFusionAnalysis& analysis); - - std::optional out_split_k_dim_idx = std::nullopt; - - std::optional lhs_batch_dim_idx = std::nullopt; - std::optional rhs_batch_dim_idx = std::nullopt; - std::optional out_batch_dim_idx = std::nullopt; - - // The LHS non-contracting can be split into two. - std::optional lhs_noncontracting_split = std::nullopt; - - int lhs_contracting_dim_idx; - int lhs_noncontracting_dim_idx; - int rhs_contracting_dim_idx; - int rhs_noncontracting_dim_idx; - // The index of the LHS noncontracting dim in the output. - int out_lhs_noncontracting_dim_idx; - // The index of the RHS noncontracting dim in the output. - int out_rhs_noncontracting_dim_idx; - - int64_t m; - int64_t n; - int64_t k; - - private: - MatMulDims() = default; -}; - -// Structure for parameters relating to the MatMul launch grid. -struct MatMulLaunchConfig { - explicit MatMulLaunchConfig(const TritonGemmConfig& config, - const HloDotInstruction& dot, - const MatMulDims& dims); - - int64_t grid_m; - int64_t grid_n; - LaunchDimensions launch_dims; - mt::ProgramIDDim batch_program_id_dim; - mt::ProgramIDDim noncontracting_program_id_dim; -}; - -/*static*/ absl::StatusOr MatMulDims::Create( - const TritonGemmConfig& config, const HloDotInstruction& dot, - const TritonFusionAnalysis& analysis) { - MatMulDims matmul_dims; - if (config.split_k > 1) { - // split-k is always the first logical dimension. - matmul_dims.out_split_k_dim_idx = 0; - } - - int64_t num_split_k_dims = config.split_k > 1 ? 1 : 0; - const auto& dims = dot.dot_dimension_numbers(); - matmul_dims.lhs_contracting_dim_idx = dims.lhs_contracting_dimensions(0); - matmul_dims.lhs_noncontracting_dim_idx = - GetNonContractingDims(dot.operand(0)->shape(), - dims.lhs_batch_dimensions(), - dims.lhs_contracting_dimensions()) - .value()[0]; - matmul_dims.rhs_contracting_dim_idx = dims.rhs_contracting_dimensions(0); - matmul_dims.rhs_noncontracting_dim_idx = - GetNonContractingDims(dot.operand(1)->shape(), - dims.rhs_batch_dimensions(), - dims.rhs_contracting_dimensions()) - .value()[0]; - - if (dims.lhs_batch_dimensions_size() > num_split_k_dims) { - matmul_dims.lhs_batch_dim_idx = *dims.lhs_batch_dimensions().rbegin(); - matmul_dims.rhs_batch_dim_idx = *dims.rhs_batch_dimensions().rbegin(); - // The batch dimension (if present) comes after the split-k dimension (if - // present, otherwise it's the first dimension). - matmul_dims.out_batch_dim_idx = num_split_k_dims; - } - - // Logical output dimensions are always ordered as: - // split-K, batch, non-contracting LHS, non-contracting RHS, - // where split-K and batch are optional. - matmul_dims.out_rhs_noncontracting_dim_idx = dot.shape().rank() - 1; - matmul_dims.out_lhs_noncontracting_dim_idx = dot.shape().rank() - 2; - - auto* root = dot.parent()->root_instruction(); - auto iter_spec = - analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, root, - matmul_dims.out_rhs_noncontracting_dim_idx); - TF_RET_CHECK(iter_spec != nullptr); - matmul_dims.n = iter_spec->at(0).count; - // Contracting dimension length. - if (config.split_k > 1 && - dot.operand(1)->operand(0)->opcode() == HloOpcode::kPad) { - // Unpadded LHS shape: [..., k, ...] - // Padded LHS shape: [..., padded_k, ...] - // Bitcasted LHS shape: [..., split_k, padded_k / split_k, ...] - TF_RET_CHECK(dot.operand(1)->opcode() == HloOpcode::kBitcast); - const Shape& unpadded_rhs_shape = - dot.operand(1)->operand(0)->operand(0)->shape(); - matmul_dims.k = - unpadded_rhs_shape.dimensions(dims.rhs_contracting_dimensions(0) - 1); - } else { - matmul_dims.k = - dot.operand(1)->shape().dimensions(dims.rhs_contracting_dimensions(0)) * - config.split_k; - } - - auto* lhs_noncontracting_split_spec = GetLhsNoncontractingSplitSpec( - analysis, matmul_dims.lhs_noncontracting_dim_idx); - if (lhs_noncontracting_split_spec != nullptr) { - // Just the fastest-varying part of it if the dimension is split. - matmul_dims.m = lhs_noncontracting_split_spec->at(0).count; - matmul_dims.lhs_noncontracting_split = - lhs_noncontracting_split_spec->at(1).count; - } else { - matmul_dims.m = analysis - .IterSpec(TritonFusionAnalysis::Scope::OUTPUT, root, - matmul_dims.out_lhs_noncontracting_dim_idx) - ->at(0) - .count; - } - - // For now split non-contracting and batch are not supported - // simultaneously because they are implemented via same mechanism. - TF_RET_CHECK(!(matmul_dims.out_batch_dim_idx.has_value() && - matmul_dims.lhs_noncontracting_split.has_value())); - - TF_RET_CHECK(matmul_dims.m >= 1); - TF_RET_CHECK(matmul_dims.n >= 1); - return std::move(matmul_dims); -} - -MatMulLaunchConfig::MatMulLaunchConfig(const TritonGemmConfig& config, - const HloDotInstruction& dot, - const MatMulDims& dims) - : grid_m((dims.m + config.block_m - 1) / config.block_m), - grid_n((dims.n + config.block_n - 1) / config.block_n) { - int64_t batch_size = dims.lhs_noncontracting_split.value_or( - dims.out_batch_dim_idx.has_value() - ? dot.shape().dimensions(*dims.out_batch_dim_idx) - : 1); - // X block size is 32-bit, Y and Z are 16-bit. Use X for large dimensions. - constexpr int64_t kBlockCountYZLimit = 65536; - - // In the imaginary situation where both batch size and grid_m * grid_n - // are over 65535 we have to give up. Given the minimal m, n block sizes of 16 - // this requires at least 256 GB of output. - CHECK_LT(batch_size * grid_m * grid_n, - kBlockCountYZLimit * kBlockCountYZLimit); - - const bool large_batch = batch_size >= kBlockCountYZLimit; - if (large_batch) { - batch_program_id_dim = mt::ProgramIDDim::X; - noncontracting_program_id_dim = mt::ProgramIDDim::Y; - launch_dims = LaunchDimensions( - se::BlockDim(batch_size, grid_m * grid_n, config.split_k), - se::ThreadDim(config.num_warps * WarpSize(), 1, 1)); - } else { - batch_program_id_dim = mt::ProgramIDDim::Y; - noncontracting_program_id_dim = mt::ProgramIDDim::X; - launch_dims = LaunchDimensions( - se::BlockDim(grid_m * grid_n, batch_size, config.split_k), - se::ThreadDim(config.num_warps * WarpSize(), 1, 1)); - } -} - -absl::Status ValidateMatMulConfig(const TritonGemmConfig& config, - const HloDotInstruction& dot) { - TF_RET_CHECK(config.split_k >= 1); - TF_RET_CHECK(config.block_m >= 16); - TF_RET_CHECK(config.block_k >= 16); - TF_RET_CHECK(config.block_n >= 16); - - const auto& dims = dot.dot_dimension_numbers(); - int num_batch_dims = - dims.lhs_batch_dimensions_size() - (config.split_k > 1 ? 1 : 0); - TF_RET_CHECK(num_batch_dims <= 1); - if (config.split_k > 1) { - // Split-K dimension has to be the first batch one and have an index - // just before the contracting one. - const int lhs_split_k_dim_idx = dims.lhs_contracting_dimensions(0) - 1; - const int rhs_split_k_dim_idx = dims.rhs_contracting_dimensions(0) - 1; - // Size of this dimension has to match the split_k value. - TF_RET_CHECK(dims.lhs_batch_dimensions(0) == lhs_split_k_dim_idx); - TF_RET_CHECK(dims.rhs_batch_dimensions(0) == rhs_split_k_dim_idx); - TF_RET_CHECK(config.split_k == - dot.operand(0)->shape().dimensions(lhs_split_k_dim_idx)); - TF_RET_CHECK(config.split_k == - dot.operand(1)->shape().dimensions(rhs_split_k_dim_idx)); - } - - // Rely on dot decomposer: there is just one contracting and one - // non-contracting dimension on each side + batch ones optionally. - TF_RET_CHECK(dims.lhs_contracting_dimensions_size() == 1); - TF_RET_CHECK(dims.rhs_contracting_dimensions_size() == 1); - - TF_RET_CHECK(dot.operand(0)->shape().rank() == - 2 + (config.split_k > 1 ? 1 : 0) + num_batch_dims); - return absl::OkStatus(); -} - -// if (index < limits[0]) { -// return choices[0]; -// } else if (index < limits[1]) { -// return choices[1]; -// } else if (...) { -// ... -// } else { -// return choices.back(); -// } -absl::StatusOr EmitMultiSelect(ImplicitLocOpBuilder b, Value index, - ValueRange limits, ValueRange choices) { - TF_RET_CHECK(choices.size() - 1 == limits.size()); - Value result = choices[0]; - for (int i = 0; i < choices.size() - 1; ++i) { - result = b.create( - b.create(ma::CmpIPredicate::slt, index, limits[i]), result, - choices[i + 1]); - } - return result; -} - -absl::Status UncompilableMatmul(absl::string_view explanation) { - absl::Status s = absl::CancelledError(explanation); - s.SetPayload(kUncompilableFusion, absl::Cord(explanation)); - return s; -} - -bool IsFp8Matmul(const HloDotInstruction* dot_instr) { - return absl::c_all_of(std::array{0, 1}, [&](int idx) { - return primitive_util::IsF8Type( - dot_instr->operand(idx)->shape().element_type()); - }); -} - -class MatMulEmitterHelper { - public: - MatMulEmitterHelper(absl::string_view libdevice_path, - const se::DeviceDescription& device_info, - const HloDotInstruction* dot_instr, - ImplicitLocOpBuilder& b, Type index_ty, MatMulDims dims, - const MatMulLaunchConfig& launch_config, - const TritonFusionAnalysis& analysis) - : b_(b), - libdevice_path_(libdevice_path), - device_info_(device_info), - dot_instr_(dot_instr), - index_ty_(index_ty), - analysis_(analysis), - dims_(dims), - launch_config_(launch_config) {} - - // TODO(b/266862493): Accumulator can be integer too. - // Otherwise only f64 x f64 -> f64 uses f64 accumulator. - absl::StatusOr GetDotAccumulatorType() { - const PrecisionConfig::Algorithm algorithm = - dot_instr_->precision_config().algorithm(); - - if (algorithm == PrecisionConfig::ALG_UNSET) { - TF_ASSIGN_OR_RETURN(Type dot_output_ty, - TritonType(b_, dot_instr_->shape().element_type())); - // The code below assumes that lhs and rhs have the same type. However - // it's not always the case with fp8 matmuls, e.g. e4m3×e5m2 is supported - // at the hardware level. NVidia GPU currently only supports f32 - // accumulator for such matmuls. - if (IsFp8Matmul(dot_instr_)) { - return b_.getF32Type(); - } - - // Data type of dot() immediate inputs. - TF_ASSIGN_OR_RETURN( - const Type lhs_ty, - TritonType(b_, dot_instr_->operand(0)->shape().element_type())); - TF_ASSIGN_OR_RETURN( - const Type rhs_ty, - TritonType(b_, dot_instr_->operand(1)->shape().element_type())); - TF_RET_CHECK(lhs_ty == rhs_ty); - Type dot_input_ty = lhs_ty; - // TODO(b/266862493): Accumulator can be integer too. - // Otherwise only f64 x f64 -> f64 uses f64 accumulator. - return (dot_output_ty.isF64() && dot_input_ty.isF64()) ? b_.getF64Type() - : b_.getF32Type(); - } - - absl::StatusOr accum_type = - algorithm_util::GetDotAccumulatorType(algorithm); - CHECK(accum_type.ok()) << "Unexpected algorithm: " - << PrecisionConfig::Algorithm_Name(algorithm); - TF_ASSIGN_OR_RETURN(Type mlir_accum_type, - TritonType(b_, accum_type.value())); - if (auto float_accum_type = - mlir::dyn_cast(mlir_accum_type)) { - return float_accum_type; - } - LOG(FATAL) << "Only floating point accumulator types are supported for " - "now, but we got: " - << llvm_ir::DumpToString(mlir_accum_type); - } - - std::vector EpiloguePostOrderTransitiveOperands( - const HloInstruction* root) { - // Collect all instructions of the dot's output scope. - absl::flat_hash_set to_order; - { - std::queue to_add; - if (root != dot_instr_) { - to_add.push(root); - } - while (!to_add.empty()) { - const HloInstruction* current = to_add.front(); - for (const HloInstruction* operand : current->operands()) { - if (!to_order.contains(operand)) { - if (operand != dot_instr_) { - to_add.push(operand); - } - } - } - to_order.insert(current); - to_add.pop(); - } - } - // Order them producers before consumers. - std::vector to_emit; - for (const HloInstruction* hlo : - dot_instr_->parent()->MakeInstructionPostOrder()) { - if (to_order.contains(hlo)) { - to_emit.push_back(hlo); - } - } - return to_emit; - } - - Value MakeInput(const Side& side, int64_t operand_index, - absl::flat_hash_map& values) { - return *EmitScope( - b_, libdevice_path_, device_info_, &analysis_, side, - dot_instr_->parent()->MakeInstructionPostOrderFrom( - const_cast(*dot_instr_->operand(operand_index))), - values); - } - - int64_t GetNonContractingDimIdxForOperandScope( - TritonFusionAnalysis::Scope scope) { - if (scope == TritonFusionAnalysis::Scope::LHS) { - return dims_.lhs_noncontracting_dim_idx; - } else if (scope == TritonFusionAnalysis::Scope::RHS) { - return dims_.rhs_noncontracting_dim_idx; - } else { - CHECK(false) << "This shouldn't be called for the output scope."; - } - } - - // Return the batch stride of the HLO passed as a parameter. If the - // parameter HLO has no batch dimension, a zero stride is returned. - // Also sets offset_batch and updates has_batch_offset as a side effect. - absl::StatusOr GetBatchStride(const Side& side, - const HloInstruction* hlo_param, - int64_t& offset_batch, - bool& has_batch_offset) { - int64_t stride_batch = 0; - if (side.scope != TritonFusionAnalysis::Scope::RHS && - dims_.lhs_noncontracting_split) { - const TensorIterationSpec::DimIterationSpec* spec = - analysis_.IterSpec(side.scope, hlo_param, side.tiled_dims[0].index); - if (spec != nullptr) { - if (spec->size() > 1) { - // Support one specific kind of output transpose that splits the - // dimension originating from the split LHS non-contracting one. - stride_batch = spec->at(1).stride; - } else { - // Because the major part of the split is implemented using the - // batch logic stride_batch is populated here as the stride of - // the minor part times its size. - stride_batch = spec->at(0).stride * - (spec->at(0).count / *dims_.lhs_noncontracting_split); - } - TF_RET_CHECK(stride_batch != 0); - } - } else if (side.batch_dim_idx.has_value()) { - const TensorIterationSpec::DimIterationSpec* spec = - analysis_.IterSpec(side.scope, hlo_param, *side.batch_dim_idx); - if (spec != nullptr) { - stride_batch = spec->at(0).stride; - offset_batch = spec->at(0).slice_start; - TF_RET_CHECK(stride_batch != 0); - } - } - - has_batch_offset |= stride_batch != 0; - return Cst(stride_batch); - } - - // bases: The base pointers of each argument. - absl::StatusOr EmitTensorPointer( - const HloInstruction* hlo, const Side& side, ValueRange bases, - Value pid_k, std::vector& boundary_checks) { - // Parameters of MakeTensorPtrOp to be generated by this function. - Value base; - std::vector bounds; - std::vector strides; - std::vector strides_sizes; // We use it to detect the minor dim. - // Offsets from tensor origin, same for all thread blocks. - std::vector tensor_offsets; - std::vector block_dims; - std::vector dim_order; - - // Offsets for a given thread block, typically pid * block size. - // Used in a one-off AdvanceOp applied to the generated MakeTensorPtrOp. - std::vector block_offsets; - - // Concatenations of parameters are handled during generation of block - // pointers because of a limitation of implementation of block pointers - // in the Triton compiler: block pointers are not supported inside - // conditionals. - // Therefore instead of directly using a conditional to emit a concatenation - // and emitting its inputs inside the cases a single block pointer is - // emitted for all inputs, but all its properties (base, strides etc) get - // generated conditionally on the position of the current thread block - // within the concatenated dimension. - - // Index of concatenated dimension if present, -1 otherwise. - int concat_dim_idx; - // Offsets along the concatenated dimension at which operands change. - std::vector concat_boundaries; - // Block index along the concatenated dimension * block size. - Value concat_dim_pid_offset; - - if (hlo->opcode() == HloOpcode::kConcatenate) { - // For now only non-contracting dimension can be concatenated. - concat_dim_idx = (side.scope == TritonFusionAnalysis::Scope::LHS) - ? dims_.lhs_noncontracting_dim_idx - : dims_.rhs_noncontracting_dim_idx; - const DimProperties& properties = [&] { - for (const DimProperties& dim : side.tiled_dims) { - if (dim.index == concat_dim_idx) { - return dim; - } - } - LOG(FATAL) << "Missing dimension."; - }(); - TF_RET_CHECK(bases.size() == hlo->operand_count()); - - concat_boundaries.reserve(hlo->operand_count() - 1); - for (int i = 0; i < hlo->operand_count() - 1; ++i) { - const TensorIterationSpec::IterationSpecFragment& fragment = - analysis_.IterSpec(side.scope, hlo->operand(i), concat_dim_idx) - ->at(0); - if (fragment.sliced_count % properties.block_size != 0) { - return UncompilableMatmul( - "Operand is not divisible by the block size."); - } - concat_boundaries.push_back( - Cst32(-fragment.slice_start + fragment.sliced_count)); - } - - concat_dim_pid_offset = - b_.create(properties.pid, Cst32(properties.block_size)); - TF_ASSIGN_OR_RETURN(base, EmitMultiSelect(b_, concat_dim_pid_offset, - concat_boundaries, bases)); - } else { - concat_dim_idx = -1; - base = bases[0]; - } - - auto add_dim = [&](const DimProperties& properties) -> absl::Status { - if (analysis_.IterSpec(side.scope, hlo, properties.index) == nullptr) { - return absl::OkStatus(); - } - Value pid_offset = - (properties.pid == nullptr) - ? Cst32(0) - : b_.create(properties.pid, - Cst32(properties.block_size)); - std::vector inputs; - if (hlo->opcode() == HloOpcode::kConcatenate) { - inputs.insert(inputs.end(), hlo->operands().cbegin(), - hlo->operands().cend()); - } else { - inputs = {hlo}; - } - std::vector specs; - std::vector input_strides; - std::vector input_offsets; - std::vector input_bounds; - specs.reserve(inputs.size()); - input_strides.reserve(inputs.size()); - input_offsets.reserve(inputs.size()); - input_bounds.reserve(inputs.size()); - for (const HloInstruction* input : inputs) { - specs.push_back( - analysis_.IterSpec(side.scope, input, properties.index)); - const auto stride = specs.back()->at(0).stride; - strides_sizes.push_back(stride); - input_strides.push_back(Cst64(stride)); - input_offsets.push_back(b_.create( - pid_offset, Cst32(specs.back()->at(0).slice_start))); - input_bounds.push_back(Cst64(specs.back()->at(0).count)); - } - TF_ASSIGN_OR_RETURN(Value select_value, - EmitMultiSelect(b_, concat_dim_pid_offset, - concat_boundaries, input_strides)); - strides.push_back(select_value); - if (properties.index == concat_dim_idx) { - TF_ASSIGN_OR_RETURN( - select_value, - EmitMultiSelect(b_, pid_offset, concat_boundaries, input_offsets)); - block_offsets.push_back(select_value); - TF_ASSIGN_OR_RETURN( - select_value, - EmitMultiSelect(b_, pid_offset, concat_boundaries, input_bounds)); - bounds.push_back(select_value); - tensor_offsets.push_back(Cst32(specs.front()->at(0).slice_start)); - } else if (hlo->opcode() == HloOpcode::kDynamicSlice && - (side.scope == TritonFusionAnalysis::Scope::LHS || - side.scope == TritonFusionAnalysis::Scope::RHS) && - properties.index == - GetNonContractingDimIdxForOperandScope(side.scope)) { - // Here we compute the offset of where we should read the slice from. - // TODO(b/323255699): Add support for slices of the contracting dim. - // Dynamic slices are guaranteed to only be offset along the majormost - // dimension. - - // The only fragment of the non-contracting dim of the dot's input in - // the current scope: - TF_RET_CHECK(specs.back()->size() == 1); - const TensorIterationSpec::IterationSpecFragment - only_fragment_of_nc_dim = specs.back()->at(0); - // The majormost dim index in the dynamic slice's output. - const int majormost_dim = hlo->shape().layout().minor_to_major().back(); - - // dynamic slice operands are (input, start_index0, start_index1, ...) - // so the start index corresponding to the ith dimension is bases[i+1]. - Value majormost_dim_start_index_ptr_val = bases[majormost_dim + 1]; - Value majormost_dim_start_index_val = b_.create( - majormost_dim_start_index_ptr_val, mt::CacheModifier::NONE, - mt::EvictionPolicy::NORMAL, - /*isVolatile=*/false); - int64_t majormost_dim_start_index_upper_limit = - hlo->operand(0)->shape().dimensions(majormost_dim) - - hlo->dynamic_slice_sizes().at(majormost_dim); - // We don't want to cast S64 indices to S32, because that could result - // in an incorrect value. - if (majormost_dim_start_index_val.getType().isInteger() && - majormost_dim_start_index_val.getType().getIntOrFloatBitWidth() == - 64) { - return UncompilableMatmul( - "64 bit dynamic-slice indices are not supported yet."); - } - majormost_dim_start_index_val = - Cast(b_, majormost_dim_start_index_val, b_.getI32Type()); - majormost_dim_start_index_val = - b_.create(majormost_dim_start_index_val, Cst32(0)); - majormost_dim_start_index_val = b_.create( - majormost_dim_start_index_val, - Cst32(majormost_dim_start_index_upper_limit)); - - // How many "rows" (non-contracting dim values) are there in a slice of - // size 1? - int64_t rows_per_majormost_dim = 1; - for (int i = 0; i < hlo->shape().dimensions().size() - 1; ++i) { - rows_per_majormost_dim *= hlo->shape().dimensions_minor(i); - } - rows_per_majormost_dim = - rows_per_majormost_dim / only_fragment_of_nc_dim.stride; - Value rows_per_majormost_dim_val = Cst32(rows_per_majormost_dim); - - Value tensor_offset_val_i32 = b_.create( - majormost_dim_start_index_val, rows_per_majormost_dim_val); - tensor_offsets.push_back(tensor_offset_val_i32); - - // tt.make_tensor_ptr expects an i64 for shape and size, but expects - // i32 for offsets. We extend the offset to calculate the upper bound. - Value tensor_offset_val_i64 = - b_.create(i64_ty_, tensor_offset_val_i32); - Value sliced_count_val = Cst64(only_fragment_of_nc_dim.sliced_count); - Value upper_bound_val = - b_.create(tensor_offset_val_i64, sliced_count_val); - bounds.push_back(upper_bound_val); - - block_offsets.push_back(pid_offset); - } else { - tensor_offsets.push_back(Cst32(specs.front()->at(0).slice_start)); - block_offsets.push_back(pid_offset); - int64_t dim_bound = specs.front()->at(0).count; - if (side.scope == TritonFusionAnalysis::Scope::OUTPUT && - properties.index == dims_.out_lhs_noncontracting_dim_idx && - specs.front()->size() == 1 && - dims_.lhs_noncontracting_split.has_value()) { - // Dimension of the output produced by the non-contracting LHS one - // is logically split, major part is addressed using pid_batch. - dim_bound /= *dims_.lhs_noncontracting_split; - } - bounds.push_back(Cst64(dim_bound)); - if (dim_bound % (properties.block_size * properties.split_value) != 0) { - boundary_checks.push_back(bounds.size() - 1); - } - if (hlo->shape().element_type() == PrimitiveType::S4) { - // For s4 type we need to divide the minor dim bound by 2 because it - // is the packing dimension. But if the minor dim has length == 1 then - // the major dim stride is also 1 and it is the packing dimension. - if (strides_sizes.back() == 1) { - // For the odd bounds we need to add 1 in advance. - // Otherwise we will loose the last element. - bounds[bounds.size() - 1] = Cst64((dim_bound + 1) / 2); - } else { - int last_stride_index = strides.size() - 1; - strides[last_stride_index] = - b_.create(strides[last_stride_index], Cst64(2)); - } - } - } - block_dims.push_back(properties.block_size); - dim_order.emplace(dim_order.begin(), dim_order.size()); - return absl::OkStatus(); - }; - - for (const DimProperties& dim : side.tiled_dims) { - TF_RETURN_IF_ERROR(add_dim(dim)); - } - - int64_t offset_batch = 0; - bool has_batch_offset = false; - Value batch_stride; - - if (hlo->opcode() == HloOpcode::kConcatenate) { - std::vector batch_strides; - batch_strides.reserve(hlo->operands().size()); - for (const HloInstruction* operand : hlo->operands()) { - TF_ASSIGN_OR_RETURN( - Value op_stride, - GetBatchStride(side, operand, offset_batch, has_batch_offset)); - batch_strides.push_back(op_stride); - } - TF_ASSIGN_OR_RETURN(batch_stride, - EmitMultiSelect(b_, concat_dim_pid_offset, - concat_boundaries, batch_strides)); - } else { - TF_ASSIGN_OR_RETURN(batch_stride, GetBatchStride(side, hlo, offset_batch, - has_batch_offset)); - } - - // Avoid generating logic to compute batch offset if unnecessary. - if (has_batch_offset) { - Value pid_batch = - b_.create(launch_config_.batch_program_id_dim); - - Value pid_offset_batch = b_.create( - b_.create(Cst(offset_batch), ConvertScalar(pid_batch)), - batch_stride); - - if (hlo->shape().element_type() == PrimitiveType::S4) { - pid_offset_batch = b_.create(pid_offset_batch, Cst(2)); - } - base = AddPtr(b_, base, pid_offset_batch); - } - - if (dims_.out_split_k_dim_idx.has_value()) { - const TensorIterationSpec::DimIterationSpec* spec = analysis_.IterSpec( - TritonFusionAnalysis::Scope::OUTPUT, hlo, *dims_.out_split_k_dim_idx); - if (spec != nullptr) { - TF_RET_CHECK(pid_k != nullptr); - base = AddPtr(b_, base, - b_.create(ConvertScalar(pid_k), - Cst(spec->at(0).stride))); - } - } - - if (block_dims.empty()) { - // Load of a scalar. - return base; - } - auto tensor_ptr = mlir::cast( - b_.create(base, bounds, strides, tensor_offsets, - block_dims, dim_order) - .getResult()); - tensor_ptr = b_.create(tensor_ptr.getType(), tensor_ptr, - block_offsets); - return tensor_ptr; - } - - private: - // Extend int32 indexes to int64, if necessary. - Value ConvertScalar(Value value) { - if (index_ty_.getIntOrFloatBitWidth() == 64) { - return b_.create(index_ty_, value); - } - return value; - } - - Value Cst(int64_t v) { return CreateConst(b_, index_ty_, v); } - Value Cst32(int32_t v) { return CreateConst(b_, i32_ty_, v); } - Value Cst64(int64_t v) { return CreateConst(b_, i64_ty_, v); } - - ImplicitLocOpBuilder& b_; - absl::string_view libdevice_path_; - const se::DeviceDescription& device_info_; - const HloDotInstruction* dot_instr_; - Type index_ty_; - TritonFusionAnalysis analysis_; - MatMulDims dims_; - MatMulLaunchConfig launch_config_; - Type i32_ty_ = b_.getI32Type(); - Type i64_ty_ = b_.getI64Type(); -}; - -} // namespace - -absl::StatusOr GetMatMulLaunchDimensions( - const TritonFusionAnalysis& analysis, const HloFusionAdaptor& fusion, - const TritonGemmConfig& config) { - auto dot = HloBfsFindIf(fusion.GetRoots(), fusion, [](auto node) { - return node.opcode() == HloOpcode::kDot; - }); - TF_RET_CHECK(dot != std::nullopt); - const auto& dot_instr = - *static_cast(&dot->instruction()); - TF_ASSIGN_OR_RETURN(MatMulDims dims, - MatMulDims::Create(config, dot_instr, analysis)); - MatMulLaunchConfig launch_config(config, dot_instr, dims); - return launch_config.launch_dims; -} - -absl::StatusOr> GetArguments(mlir::triton::FuncOp fn, - const HloInstruction& input) { - if (input.opcode() == HloOpcode::kParameter) { - return {{fn.getArgument(input.parameter_number())}}; - } else if (input.opcode() == HloOpcode::kConcatenate || - input.opcode() == HloOpcode::kDynamicSlice) { - // As defined in GemmFusion, all inputs of concatenate and dynamic slice are - // parameters. - SmallVector result; - for (const HloInstruction* operand : input.operands()) { - TF_RET_CHECK(operand->opcode() == HloOpcode::kParameter); - result.push_back(fn.getArgument(operand->parameter_number())); - } - return result; - } - LOG(FATAL) << "Unexpected opcode: " << input.opcode(); -} - -// Concatenations can currently only be applied directly to parameters; -// all concatenated parameters share the same block pointer. This function -// returns all inputs of a kernel: concatenations of parameters and standalone -// parameters. -ConstHloInstructionSet ScopeInputs(const TritonFusionAnalysis& analysis, - const TritonFusionAnalysis::Scope scope) { - ConstHloInstructionSet result; - for (const HloInstruction* parameter : analysis.ScopeParameters(scope)) { - if (absl::c_any_of(parameter->users(), [](const HloInstruction* user) { - return user->opcode() == HloOpcode::kConcatenate || - user->opcode() == HloOpcode::kDynamicSlice; - })) { - // Concatenation is always the only user of its parameters by - // construction. - CHECK_EQ(parameter->users().size(), 1); - for (const HloInstruction* operand : parameter->users()[0]->operands()) { - // All operands of a concatenation have to be computation parameters. - CHECK_EQ(operand->opcode(), HloOpcode::kParameter); - } - result.insert(parameter->users()[0]); - } else { - result.insert(parameter); - } - } - return result; -} - -// Truncates |input| of F32 type to the number representable in Bf16 toward -// zero. -// It is used for Emit6xBfloat16MatMul. -Value TruncateToBF16TowardsZero(ImplicitLocOpBuilder& b, Value input) { - ShapedType input_type = mlir::dyn_cast(input.getType()); - Type input_type_as_i32 = input_type.clone(b.getI32Type()); - Value input_as_i32 = b.create(input_type_as_i32, input); - Value mask = CreateConst(b, b.getI32Type(), 0xFFFF0000u, - input_type.getShape()); - Value high_bits = b.create(input_type_as_i32, input_as_i32, mask); - - return b.create(input_type, high_bits); -} - -// Finds the middle 8 bits of |input|'s mantissa. -// It is used for Emit6xBfloat16MatMul. -Value SoftMiddleEight(ImplicitLocOpBuilder& b, Value input) { - Value high = TruncateToBF16TowardsZero(b, input); - return b.create(input, high); -} - -// Finds the low 8 bits of |input|'s mantissa. -// It is used for Emit6xBfloat16MatMul. -Value SoftLowEight(ImplicitLocOpBuilder& b, Value input) { - // Find the middle bits of the middle bits, and these are the low eight - // bits. - return SoftMiddleEight(b, SoftMiddleEight(b, input)); -} - -// Rounds |input| to BF16 type. -// It is used for Emit6xBfloat16MatMul. -Value RoundToBF16(ImplicitLocOpBuilder& b, Value input) { - return Cast(b, input, b.getBF16Type()); -} - -// Checks |input| is finite f32 (not Nan and not infinite). -// It is used for Emit6xBfloat16MatMul and Emit3xBfloat16MatMul. -Value CheckFiniteF32(ImplicitLocOpBuilder& b, Value input) { - Value positive_inf = CreateConst( - b, b.getF32Type(), std::numeric_limits::infinity(), - mlir::cast(input.getType()).getShape()); - Value abs_input = b.create(input); - return b.create(ma::CmpFPredicate::OGT, positive_inf, abs_input); -} - -// Leverages BF16 datatype for F32 matmul computation. It follows the guidance -// from https://arxiv.org/pdf/1904.06376.pdf. -absl::StatusOr Emit6xBfloat16MatMul(ImplicitLocOpBuilder& b, Value lhs, - Value rhs, Value acc) { - Type f32 = b.getF32Type(); - TF_RET_CHECK(mlir::cast(lhs.getType()).getElementType() == f32); - TF_RET_CHECK(mlir::cast(rhs.getType()).getElementType() == f32); - TF_RET_CHECK(mlir::cast(acc.getType()).getElementType() == f32); - - Value lhs_high = RoundToBF16(b, TruncateToBF16TowardsZero(b, lhs)); - Value lhs_middle = - RoundToBF16(b, TruncateToBF16TowardsZero(b, SoftMiddleEight(b, lhs))); - Value lhs_low = - RoundToBF16(b, TruncateToBF16TowardsZero(b, SoftLowEight(b, lhs))); - - Value rhs_high = RoundToBF16(b, TruncateToBF16TowardsZero(b, rhs)); - Value rhs_middle = - RoundToBF16(b, TruncateToBF16TowardsZero(b, SoftMiddleEight(b, rhs))); - Value rhs_low = - RoundToBF16(b, TruncateToBF16TowardsZero(b, SoftLowEight(b, rhs))); - - auto bf16_dot = [&](Value lhs_bf16, Value rhs_bf16, - Value accumulator) -> Value { - return b.create(lhs_bf16, rhs_bf16, accumulator, - /*inputPrecision=*/mt::InputPrecision::IEEE, - /*maxNumImpreciseAcc=*/0); - }; - - Value local_acc = ZerosLike(b, acc); - Value result = bf16_dot(lhs_middle, rhs_middle, local_acc); - result = bf16_dot(lhs_low, rhs_high, result); - result = bf16_dot(lhs_high, rhs_low, result); - result = bf16_dot(lhs_middle, rhs_high, result); - result = bf16_dot(lhs_high, rhs_middle, result); - // If lhs is 1.0, we will have lhs_high = 1.0 and lhs_low = 0.0. - // If rhs is +infinity, we will have: - // +infinity * 1.0 = +infinity - // +infinity * 0.0 = NaN - // We would get the wrong result if we sum these partial products. Instead, we - // must override any accumulated result if the last partial product is - // non-finite. See b/115844437. - Value is_finite = CheckFiniteF32(b, result); - result = b.create(is_finite, result, ZerosLike(b, result)); - result = bf16_dot(lhs_high, rhs_high, result); - result = b.create(acc, result); - return result; -} - -// Compute F32 matmul with 3 BF16 dots. It is less accurate than -// Emit6xBfloat16MatMul. -absl::StatusOr Emit3xBfloat16MatMul(ImplicitLocOpBuilder& b, Value lhs, - Value rhs, Value acc) { - Type f32 = b.getF32Type(); - TF_RET_CHECK(mlir::cast(lhs.getType()).getElementType() == f32); - TF_RET_CHECK(mlir::cast(rhs.getType()).getElementType() == f32); - TF_RET_CHECK(mlir::cast(acc.getType()).getElementType() == f32); - - Value lhs_high = RoundToBF16(b, TruncateToBF16TowardsZero(b, lhs)); - Value lhs_low = RoundToBF16(b, SoftMiddleEight(b, lhs)); - - Value rhs_high = RoundToBF16(b, TruncateToBF16TowardsZero(b, rhs)); - Value rhs_low = RoundToBF16(b, SoftMiddleEight(b, rhs)); - - auto bf16_dot = [&](Value lhs_bf16, Value rhs_bf16, - Value accumulator) -> Value { - return b.create(lhs_bf16, rhs_bf16, accumulator, - /*inputPrecision=*/mt::InputPrecision::IEEE, - /*maxNumImpreciseAcc=*/0); - }; - - Value local_acc = ZerosLike(b, acc); - Value result = bf16_dot(lhs_low, rhs_high, local_acc); - result = bf16_dot(lhs_high, rhs_low, result); - Value is_finite = CheckFiniteF32(b, result); - result = b.create(is_finite, result, ZerosLike(b, result)); - result = bf16_dot(lhs_high, rhs_high, result); - result = b.create(acc, result); - return result; -} - -namespace { - -bool IsTf32Allowed(const HloDotInstruction* dot_instr) { - const PrecisionConfig::Algorithm algorithm = - dot_instr->precision_config().algorithm(); - - if (algorithm == PrecisionConfig::ALG_UNSET) { - return tsl::tensor_float_32_execution_enabled() && - absl::c_none_of(dot_instr->precision_config().operand_precision(), - [](const int precision) { - return precision != PrecisionConfig::DEFAULT; - }); - } - - return algorithm_util::HasTf32InputType(algorithm); -} - -bool Is6xBfloat16MatMul(const HloDotInstruction* dot_instr, - mlir::OpBuilder& builder, Value dot_input_lhs, - Value dot_input_rhs, - const se::DeviceDescription& device_info) { - const PrecisionConfig::Algorithm algorithm = - dot_instr->precision_config().algorithm(); - - if (algorithm == PrecisionConfig::ALG_UNSET) { - const HloModule* hlo_module = dot_instr->GetModule(); - Type f32 = builder.getF32Type(); - return hlo_module->config() - .debug_options() - .xla_gpu_enable_bf16_6way_gemm() && - mlir::cast(dot_input_lhs.getType()).getElementType() == - f32 && - mlir::cast(dot_input_rhs.getType()).getElementType() == - f32; - } - - return algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6; -} - -bool Is3xBfloat16MatMul(const HloDotInstruction* dot_instr, - mlir::OpBuilder& builder, Value dot_input_lhs, - Value dot_input_rhs, - const se::DeviceDescription& device_info) { - const PrecisionConfig::Algorithm algorithm = - dot_instr->precision_config().algorithm(); - - if (algorithm == PrecisionConfig::ALG_UNSET) { - const HloModule* hlo_module = dot_instr->GetModule(); - Type f32 = builder.getF32Type(); - return hlo_module->config() - .debug_options() - .xla_gpu_enable_bf16_3way_gemm() && - mlir::cast(dot_input_lhs.getType()).getElementType() == - f32 && - mlir::cast(dot_input_rhs.getType()).getElementType() == - f32; - } - - return algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3; -} - -// This is a heuristic that serves as a proxy for register usage and code size. -// -// We have noticed that tilings with very long LLVM IR code are both slow to -// compile and slow to run. This can be for example due to register spills. So -// we should skip these tilings to save time. But it's better to skip them -// before the LLVM IR is generated. To do that, we came up with a formula that -// strongly correlates with the LLVM IR size. The formula is the size of the two -// input and the output thread block tiles divided by the number of warps. We -// read https://developer.nvidia.com/blog/cutlass-linear-algebra-cuda/ as a -// reference, and found the formula by trial and error. -// -// To regenerate the limit, we have to run an exhaustive search on all tilings -// for a few different HLOs, printing the runtimes and the heuristic values. -// -// From that, we can find a limit, such that all tilings within alpha * -// optimal_runtime have a heuristic value less than or equal to the limit. -// -// In our measurements, all tilings which were within 1.13 * optimal_runtime had -// a complexity_heuristic_value <= kComplexityHeuristicLimit. -// -// See go/tiling-heuristic for more details. -absl::Status CheckGemmTilingComplexityHeuristic( - const TritonGemmConfig& config) { - constexpr int64_t kComplexityHeuristicLimit = 9000; - int64_t complexity_heuristic_value = - (config.block_m * config.block_n + - (config.block_m + config.block_n) * config.block_k) / - config.num_warps; - VLOG(2) << "Complexity heuristic: " << complexity_heuristic_value; - if (complexity_heuristic_value > kComplexityHeuristicLimit) { - return ResourceExhausted("Tiling complexity heuristic exceeded: %d > %d", - complexity_heuristic_value, - kComplexityHeuristicLimit); - } - return absl::OkStatus(); -} - -class Scopes { - public: - Scopes(ImplicitLocOpBuilder& b, const HloInstruction* dot_instr, - const TritonFusionAnalysis& analysis, const MatMulDims& dims, - const TritonGemmConfig& config, const MatMulLaunchConfig launch_config, - bool is_sparse) - : lhs_(TritonFusionAnalysis::Scope::LHS), - rhs_(TritonFusionAnalysis::Scope::RHS), - out_(TritonFusionAnalysis::Scope::OUTPUT) { - constexpr int group_m = 8; - const int64_t width = group_m * launch_config.grid_n; - - auto c32 = [&](int64_t v) { return CreateConst(b, b.getI32Type(), v); }; - - auto pid_nc = b.create( - launch_config.noncontracting_program_id_dim); - pid_k_ = (config.split_k > 1) - ? b.create(mt::ProgramIDDim::Z) - : Value{}; - - auto group_id = b.create(pid_nc, c32(width)); - ma::ConstantOp group_m_op = c32(group_m); - auto first_pid_m = b.create(group_id, group_m_op); - auto sub0 = b.create(c32(launch_config.grid_m), first_pid_m); - auto group_size = b.create( - b.create(ma::CmpIPredicate::slt, sub0, group_m_op), sub0, - group_m_op); - - pid_m_ = b.create(first_pid_m, - b.create(pid_nc, group_size)); - - pid_n_ = b.create(b.create(pid_nc, c32(width)), - group_size); - - int lhs_non_contracting_block_size = config.block_m; - int lhs_contracting_block_size = config.block_k; - int lhs_unpack_bound_idx = 0; - if (is_int4_param(analysis, TritonFusionAnalysis::Scope::LHS)) { - auto minor_dim = std::max(dims.lhs_contracting_dim_idx, - dims.lhs_noncontracting_dim_idx); - auto minor_bound = analysis - .IterSpec(TritonFusionAnalysis::Scope::LHS, - dot_instr->operand(0), minor_dim) - ->at(0) - .count; - if (minor_bound == - 1) { // Assuming that the contracting dimension is major. - lhs_contracting_block_size /= 2; - lhs_unpack_bound_idx = 1; - } else if (dims.lhs_contracting_dim_idx > - dims.lhs_noncontracting_dim_idx) { - // lhs is int4 and the contracting dimension is minor. - lhs_contracting_block_size /= 2; - lhs_unpack_bound_idx = 1; - } else { - // lhs is int4 and the contracting dimension is major. - lhs_non_contracting_block_size /= 2; - lhs_unpack_bound_idx = 0; - } - } - if (is_sparse) { - lhs_contracting_block_size /= 2; - } - lhs_.tiled_dims = { - DimProperties(dims.lhs_noncontracting_dim_idx, pid_m_, - lhs_non_contracting_block_size, - /*split_value=*/1), - DimProperties(dims.lhs_contracting_dim_idx, pid_k_, - lhs_contracting_block_size, config.split_k)}; - lhs_.batch_dim_idx = dims.lhs_batch_dim_idx; - lhs_.unpack_dim_idx = lhs_unpack_bound_idx; - - int rhs_contracting_block_size = config.block_k; - int rhs_non_contracting_block_size = config.block_n; - int rhs_unpack_bound_idx = 0; - if (is_int4_param(analysis, TritonFusionAnalysis::Scope::RHS)) { - auto minor_dim = std::max(dims.rhs_contracting_dim_idx, - dims.rhs_noncontracting_dim_idx); - auto minor_bound = analysis - .IterSpec(TritonFusionAnalysis::Scope::RHS, - dot_instr->operand(1), minor_dim) - ->at(0) - .count; - - if (minor_bound == 1) { // rhs is int4 and the _minor_ bound is 1. - rhs_contracting_block_size /= 2; - } else if (dims.rhs_contracting_dim_idx > - dims.rhs_noncontracting_dim_idx) { - // rhs is int4 and the contracting dimension is minor. - rhs_contracting_block_size /= 2; - } else { - // rhs is int4 and the contracting dimension is major. - rhs_non_contracting_block_size /= 2; - rhs_unpack_bound_idx = 1; - } - } - rhs_.tiled_dims = { - DimProperties(dims.rhs_contracting_dim_idx, pid_k_, - rhs_contracting_block_size, config.split_k), - DimProperties(dims.rhs_noncontracting_dim_idx, pid_n_, - rhs_non_contracting_block_size, - /*split_value=*/1)}; - rhs_.batch_dim_idx = dims.rhs_batch_dim_idx; - rhs_.unpack_dim_idx = rhs_unpack_bound_idx; - - out_.tiled_dims = {DimProperties(dims.out_lhs_noncontracting_dim_idx, - pid_m_, config.block_m, - /*split_value=*/1), - DimProperties(dims.out_rhs_noncontracting_dim_idx, - pid_n_, config.block_n, - /*split_value=*/1)}; - out_.batch_dim_idx = dims.out_batch_dim_idx; - - if (is_sparse) { - meta_ = Side{TritonFusionAnalysis::Scope::META, - /*tiled_dims=*/ - {DimProperties(dims.lhs_noncontracting_dim_idx, pid_m_, - config.block_m, - /*split_value=*/1), - DimProperties(dims.lhs_contracting_dim_idx, pid_k_, - config.block_k / 16, config.split_k)}, - dims.lhs_batch_dim_idx}; - } - } - - std::vector input_scopes() const { - if (meta_.has_value()) { - return {&lhs_, &rhs_, &meta_.value()}; - } - return {&lhs_, &rhs_}; - } - const Side& lhs() const { return lhs_; } - const Side& rhs() const { return rhs_; } - const Side& out() const { return out_; } - const std::optional& meta() const { return meta_; } - const Value& pid_m() const { return pid_m_; } - const Value& pid_k() const { return pid_k_; } - const Value& pid_n() const { return pid_n_; } - - static bool is_int4_param(const TritonFusionAnalysis& analysis, - TritonFusionAnalysis::Scope scope) { - const ConstHloInstructionSet& params = analysis.ScopeParameters(scope); - return params.size() == 1 && - (*params.cbegin())->shape().element_type() == S4; - } - - private: - Side lhs_; - Side rhs_; - Side out_; - std::optional meta_; - - Value pid_m_; - Value pid_k_; - Value pid_n_; -}; - -enum MaskExpandDimension { kMajor = 0, kMinor = 1 }; - -Value EmitMaskOnInput(ImplicitLocOpBuilder& b, - MaskExpandDimension expand_dimension, Value input, - int denom, Value k, int64_t dims_k, int64_t block_k, - Value pid_k) { - auto c32 = [&](int64_t v) { return CreateConst(b, b.getI32Type(), v); }; - int size = block_k / denom; - auto elements_in_tile = b.create(c32(dims_k / denom), k); - auto cond = - b.create(ma::CmpIPredicate::slt, elements_in_tile, c32(size)); - auto if_op = b.create( - cond, /*thenBranch=*/ - [&](mlir::OpBuilder& builder, mlir::Location loc) { - ImplicitLocOpBuilder b(loc, builder); - auto range_k = Range(b, size); - if (pid_k != nullptr) { - range_k = b.create( - range_k, Splat(b, b.create(pid_k, c32(size)), size)); - } - auto ty = mlir::cast(input.getType()); - TensorValue range_expanded = mlir::cast( - b.create(range_k, expand_dimension).getResult()); - Value mask = b.create( - ty.clone(b.getI1Type()), - b.create(ma::CmpIPredicate::slt, range_expanded, - Splat(b, elements_in_tile, - range_expanded.getType().getShape()))); - auto result = b.create(mask, input, ZerosLike(b, input)); - b.create(mlir::ValueRange(result)); - }, - /*elseBranch=*/ - [&](mlir::OpBuilder& b, mlir::Location loc) { - b.create(loc, mlir::ValueRange(input)); - }); - return if_op.getResult(0); -} - -} // namespace - -// Variable naming: lhs [m, k] x rhs [k, n] -> out [m, n]. -absl::Status EmitMatMul(mlir::OpBuilder builder, - absl::string_view libdevice_path, - const se::DeviceDescription& device_info, - const HloFusionInstruction* fusion, - mlir::triton::FuncOp fn, const BlockLevelParameters&) { - auto backend_config = - fusion->backend_config()->fusion_backend_config(); - - if (!backend_config.has_triton_gemm_config()) { - // TODO(bchetioui): consolidate default parameters. At the moment, these - // may be constructed in two distinct places. - LOG(WARNING) << "Using fallback triton GEMM config for op " - << fusion->name(); - auto& triton_config = *backend_config.mutable_triton_gemm_config(); - triton_config.set_block_m(64); - triton_config.set_block_k(64); - triton_config.set_block_n(64); - triton_config.set_split_k(1); - triton_config.set_num_stages(1); - triton_config.set_num_warps(2); - triton_config.set_num_ctas(1); - } - - TF_ASSIGN_OR_RETURN( - TritonGemmConfig config, - TritonGemmConfig::FromProto(backend_config.triton_gemm_config())); - TF_ASSIGN_OR_RETURN(auto analysis, - TritonFusionAnalysis::Execute( - *fusion->called_computation(), config.split_k)); - - TF_RETURN_IF_ERROR(CheckGemmTilingComplexityHeuristic(config)); - - const HloComputation* computation = fusion->fused_instructions_computation(); - const HloInstruction* instr = - hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot); - const HloDotInstruction* dot_instr = DynCast(instr); - bool is_sparse = dot_instr->sparse_operands() > 0; - - // Use 32-bit indexing if addressing any of the inputs or the output (which - // could grow if split_k is set) does not cross the INT_MAX boundary. - // Otherwise, fall back to 64-bit indexing, which is slower. - bool use_64bit_indexing = - ShapeUtil::ElementsIn(dot_instr->operand(0)->shape()) > INT_MAX || - ShapeUtil::ElementsIn(dot_instr->operand(1)->shape()) > INT_MAX || - ShapeUtil::ElementsIn(dot_instr->shape()) * config.split_k > INT_MAX; - Type index_ty = builder.getIntegerType(use_64bit_indexing ? 64 : 32); - - const HloInstruction* root = dot_instr->parent()->root_instruction(); - TF_RET_CHECK(!root->shape().IsTuple()); - - // TODO(b/320659359) Allow TF32 for 8-bit or less types with F32. - bool is_unsupported_bitwidth = - HloBfsAnyOf({dot_instr}, [&](const HloInstruction* node) { - if (node->opcode() != HloOpcode::kConvert) { - return false; - } - int in_width = - primitive_util::BitWidth(node->operand(0)->shape().element_type()); - return in_width <= 8 && node->shape().element_type() == F32; - }); - - // We'll be creating a lot of instructions from a single dot, use an - // implicit loc builder so we don't have to pass around the location all the - // time. - auto loc = mlir::NameLoc::get(builder.getStringAttr(dot_instr->name())); - ImplicitLocOpBuilder b(loc, builder); - - TF_RETURN_IF_ERROR(ValidateMatMulConfig(config, *dot_instr)); - const int split_k = config.split_k; - const int block_m = config.block_m; - const int block_k = config.block_k; - const int block_n = config.block_n; - - TF_ASSIGN_OR_RETURN(const MatMulDims dims, - MatMulDims::Create(config, *dot_instr, analysis)); - const MatMulLaunchConfig launch_config(config, *dot_instr, dims); - VLOG(6) << analysis.ToString(); - - MatMulEmitterHelper emitter(libdevice_path, device_info, dot_instr, b, - index_ty, dims, launch_config, analysis); - - TF_ASSIGN_OR_RETURN(mlir::FloatType acc_ty, emitter.GetDotAccumulatorType()); - - ma::ConstantOp accumulator_init = - CreateConst(b, acc_ty, 0, {block_m, block_n}); - - // Parameters are passed to the loop in non-trivial order, these maps help - // finding them and their attributes. - absl::flat_hash_map iter_args_to_inputs; - absl::flat_hash_map> iter_args_to_boundary_checks; - - // Calculate the sizes of the lhs, rhs, meta, and output sides. - Scopes scopes(b, dot_instr, analysis, dims, config, launch_config, is_sparse); - - auto c32 = [&](int64_t v) { return CreateConst(b, b.getI32Type(), v); }; - - constexpr size_t kLhsMetaOperandIdx = HloDotInstruction::kOperands; - size_t lsize = ScopeInputs(analysis, TritonFusionAnalysis::Scope::LHS).size(); - size_t rsize = ScopeInputs(analysis, TritonFusionAnalysis::Scope::RHS).size(); - - absl::flat_hash_map triton_type_for_input; - for (const Side& side : {scopes.lhs(), scopes.rhs()}) { - for (const HloInstruction* input : ScopeInputs(analysis, side.scope)) { - TF_ASSIGN_OR_RETURN(Type input_ty, - TritonType(b, input->shape().element_type())); - triton_type_for_input.insert({input, input_ty}); - } - } - - auto body_builder = [&](mlir::OpBuilder&, mlir::Location, Value ki, - ValueRange iter_args) -> void { - SmallVector iter_args_next; - iter_args_next.reserve(iter_args.size()); - std::array, 3> values; - - // Load tiles of all parameters of LHS and RHS scopes and advance pointers. - for (int i = 0; i < iter_args.size() - 1; ++i) { - const int index = i < lsize ? 0 : i < lsize + rsize ? 1 : 2; - const Side& side = *(scopes.input_scopes()[index]); - - const HloInstruction* param_hlo = iter_args_to_inputs[i]; - Type param_ty = index == kLhsMetaOperandIdx - ? b.getI16Type() - : triton_type_for_input.at(param_hlo); - Type param_storage_ty = StorageType(b, param_ty); - Value param_value = - EmitParameterLoad(b, iter_args[i], iter_args_to_boundary_checks[i]); - if (param_ty != param_storage_ty) { - // For example cast i8 to i1. - param_value = Cast(b, param_value, param_ty); - } - - CHECK(values[index].insert({param_hlo, param_value}).second); - SmallVector increments; - for (const DimProperties& dim : side.tiled_dims) { - const TensorIterationSpec::DimIterationSpec* spec = - analysis.IterSpec(side.scope, iter_args_to_inputs[i], dim.index); - if (spec == nullptr || spec->at(0).stride == 0) { - continue; - } - // Only the contracting dimensions are advanced. - if (dim.index == (index == 0 || index == kLhsMetaOperandIdx - ? dims.lhs_contracting_dim_idx - : dims.rhs_contracting_dim_idx)) { - increments.push_back(c32(dim.block_size * split_k)); - } else { - increments.push_back(c32(0)); - } - } - if (increments.empty()) { - iter_args_next.push_back(iter_args[i]); - } else { - iter_args_next.push_back(b.create( - iter_args[i].getType(), iter_args[i], increments)); - } - } - - // Emit all operations of LHS and RHS scopes. - Value dot_input_lhs = emitter.MakeInput(scopes.lhs(), 0, values[0]); - Value dot_input_rhs = emitter.MakeInput(scopes.rhs(), 1, values[1]); - Value dot_input_meta = - is_sparse ? emitter.MakeInput(*scopes.meta(), 2, values[2]) : Value{}; - - // Operation in the fusion before the dot can alter the elements of the - // tiles that were zero masked during loads. These have to be zeroed here - // again just before the dot so that they do not affect the output. - // Only the K dimension needs masking here because unnecessary elements in - // the other two get discarded by the masked store at the end. - const bool need_masking = dims.k % (block_k * split_k) > 0; - if (need_masking) { - dot_input_lhs = EmitMaskOnInput(b, MaskExpandDimension::kMajor, - dot_input_lhs, is_sparse ? 2 : 1, ki, - dims.k, block_k, scopes.pid_k()); - dot_input_rhs = - EmitMaskOnInput(b, MaskExpandDimension::kMinor, dot_input_rhs, 1, ki, - dims.k, block_k, scopes.pid_k()); - // Masking the metadata is not necessary, as the inputs are masked - // (i.e. zeroed out), so the padded metadata can hold any values. - } - - if (is_sparse) { - iter_args_next.push_back(b.create( - dot_input_lhs, dot_input_rhs, iter_args.back(), dot_input_meta)); - b.create(iter_args_next); - return; - } - - const HloModule* hlo_module = dot_instr->GetModule(); - if (hlo_module->config().debug_options().xla_gpu_enable_bf16_3way_gemm() && - hlo_module->config().debug_options().xla_gpu_enable_bf16_6way_gemm()) { - LOG(WARNING) << "Both BF16 6way gemm and 3way gemm are enabled." - << " Fallback to BF16 6way gemm."; - } - - Value accumulator_next; - if (Is6xBfloat16MatMul(dot_instr, b, dot_input_lhs, dot_input_rhs, - device_info)) { - absl::StatusOr accumulator_next_or = Emit6xBfloat16MatMul( - b, dot_input_lhs, dot_input_rhs, iter_args.back()); - TF_CHECK_OK(accumulator_next_or.status()); - accumulator_next = accumulator_next_or.value(); - } else if (Is3xBfloat16MatMul(dot_instr, b, dot_input_lhs, dot_input_rhs, - device_info)) { - absl::StatusOr accumulator_next_or = Emit3xBfloat16MatMul( - b, dot_input_lhs, dot_input_rhs, iter_args.back()); - TF_CHECK_OK(accumulator_next_or.status()); - accumulator_next = accumulator_next_or.value(); - } else { - // Execute matrix multiplication of input tiles and pass the accumulator. - // TODO(manany): Should be looked into once we enable Hopper workloads. - // maxNumImpreciseAcc flag was introduced for Hopper to accumulate in a - // lower precision than the output type. The change was introduced here: - // https://github.com/openai/triton/commit/31b0c521427109a8eda609b58d756c380b21599a - auto input_precision = - IsTf32Allowed(dot_instr) && !is_unsupported_bitwidth - ? mt::InputPrecision::TF32 - : mt::InputPrecision::IEEE; - // For fp8 matmuls, disable accumulator promotion, as it's what cublas - // does. It may make sense to enable frequent accumulator promotion at - // higher matmul precisions set in the config. - int max_num_imprecise_acc = - IsFp8Matmul(dot_instr) ? std::numeric_limits::max() : 0; - accumulator_next = - b.create(dot_input_lhs, dot_input_rhs, iter_args.back(), - /*inputPrecision=*/input_precision, - /*maxNumImpreciseAcc=*/max_num_imprecise_acc); - } - iter_args_next.push_back(accumulator_next); - - b.create(iter_args_next); - return; - }; - - // Pointers to inputs of LHS scope, then RHS, then the accumulator - // that change with every loop iteration and are passed between them. - SmallVector iter_args; - iter_args.reserve(lsize + rsize + 1 + is_sparse); - - for (const Side* side : scopes.input_scopes()) { - for (const HloInstruction* input : ScopeInputs(analysis, side->scope)) { - TF_RET_CHECK( - iter_args_to_inputs.insert({iter_args.size(), input}).second); - TF_ASSIGN_OR_RETURN(SmallVector arguments, - GetArguments(fn, *input)); - TF_ASSIGN_OR_RETURN(Value tensor_ptr, - emitter.EmitTensorPointer( - input, *side, arguments, scopes.pid_k(), - iter_args_to_boundary_checks[iter_args.size()])); - iter_args.push_back(tensor_ptr); - } - } - - iter_args.push_back(accumulator_init); - Value acc_final = b.create( - /*lowerBound=*/c32(0), - /*upperBound=*/c32(dims.k), - /*step=*/c32(block_k * split_k), - /*iterArgs=*/iter_args, body_builder) - .getResult(iter_args.size() - 1); - absl::flat_hash_map values_out; - TF_ASSIGN_OR_RETURN(Type acc_final_ty, - TritonType(b, dot_instr->shape().element_type())); - values_out[dot_instr] = Cast(b, acc_final, acc_final_ty); - - // Emit the output scope. - if (std::vector to_emit = - emitter.EpiloguePostOrderTransitiveOperands(root); - !to_emit.empty()) { - for (const HloInstruction* input : - ScopeInputs(analysis, TritonFusionAnalysis::Scope::OUTPUT)) { - std::vector boundary_checks; - TF_ASSIGN_OR_RETURN(SmallVector arguments, - GetArguments(fn, *input)); - TF_ASSIGN_OR_RETURN( - Value tensor_pointer, - emitter.EmitTensorPointer(input, scopes.out(), arguments, - scopes.pid_k(), boundary_checks)); - TF_RET_CHECK(values_out - .insert({input, EmitParameterLoad(b, tensor_pointer, - boundary_checks)}) - .second); - } - TF_RETURN_IF_ERROR(EmitScope(b, libdevice_path, device_info, &analysis, - scopes.out(), to_emit, values_out) - .status()); - } - - // Emit tensor store operations for all outputs. - for (int i = 0; - i < fn.getNumArguments() - dot_instr->parent()->num_parameters(); ++i) { - const HloInstruction* producer = - root->shape().IsTuple() ? root->operand(i) : root; - std::vector boundary_checks; - TF_ASSIGN_OR_RETURN( - Value tensor_pointer, - emitter.EmitTensorPointer( - producer, scopes.out(), - {fn.getArgument(i + dot_instr->parent()->num_parameters())}, - scopes.pid_k(), boundary_checks)); - b.create(tensor_pointer, values_out[producer], boundary_checks, - mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL); - } - return absl::OkStatus(); -} - // Computes the base pointer offset for the given tile multi-index and hlo shape // taking into account the physical layout of the hlo buffer. absl::StatusOr ComputeBasePtrOffset( @@ -2772,6 +815,8 @@ absl::StatusOr ComputeBasePtrOffset( /*symbols=*/{}, b)[0]); } +} // namespace + namespace ir_emitter_triton_internal { SmallVector ComputeDelinearizedTileIndex( @@ -2811,7 +856,8 @@ absl::StatusOr CreateMakeTensorPtrOp( int64_t current_stride = 1; for (int64_t cur_dim : LayoutUtil::MinorToMajor(shape)) { strides[cur_dim] = - CreateConst(b, b.getI64Type(), tile_strides[cur_dim] * current_stride); + CreateConst(b, b.getI64Type(), tile_strides[cur_dim] * current_stride) + .UnwrapScalar(); current_stride *= shape.dimensions(cur_dim); } @@ -2850,11 +896,12 @@ absl::StatusOr CreateMakeTensorPtrOp( // compute a "residual shape" which is the original parent shape minus // the offsets. Value parent_size = - CreateConst(b, b.getI64Type(), shape.dimensions(dim_idx)); + CreateConst(b, b.getI64Type(), shape.dimensions(dim_idx)) + .UnwrapScalar(); Value offset = b.create(b.getI64Type(), tile_offsets_as_indices[dim_idx]); residual_shape.push_back(b.create(parent_size, offset)); - offsets.push_back(CreateConst(b, b.getI32Type(), 0)); + offsets.push_back(CreateConst(b, b.getI32Type(), 0).UnwrapScalar()); // TODO(b/342989850): Clarify and comment what `order` exactly is. It's not // entirely clear from the Triton docs. @@ -2882,6 +929,8 @@ absl::StatusOr CreateMakeTensorPtrOp( } // namespace ir_emitter_triton_internal +namespace { +// Generate Triton IR inside 'fn', using the given block_level_parameters. absl::Status EmitGeneric(mlir::OpBuilder builder, absl::string_view libdevice_path, const se::DeviceDescription& device_info, @@ -2916,7 +965,7 @@ absl::Status EmitGeneric(mlir::OpBuilder builder, b, tiled_hlo_computation.num_output_tiles_per_dim()); TF_ASSIGN_OR_RETURN( - Value result, + ScalarOrTensor result, EmitTiledScope(b, libdevice_path, device_info, fusion, tiled_hlo_computation, fn, tile_multi_index)); @@ -2924,25 +973,38 @@ absl::Status EmitGeneric(mlir::OpBuilder builder, // as i8. It's important to type checking that we perform a conversion before // storing if the type of the result does not match the type of the output // pointer. - Type result_element_type = - mlir::cast(result.getType()).getElementType(); + Type result_element_type = getElementTypeOrSelf(result.Type()); Type result_storage_type = StorageType(b, result_element_type); if (result_element_type != result_storage_type) { - result = Cast(b, result, result_storage_type); + result = + ScalarOrTensor(Cast(b, result.UnwrapUnsafe(), result_storage_type)); } const auto& tiled_hlo = *tiled_hlo_computation.GetRoot(); + + Value parent_base_ptr = fn.getArgument(computation->num_parameters()); + + if (result.IsScalar()) { + b.create(parent_base_ptr, result.UnwrapScalar(), + mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL); + return absl::OkStatus(); + } + + CHECK(tiled_hlo.hlo()->shape().IsArray() && + tiled_hlo.hlo()->shape().rank() > 0); TF_ASSIGN_OR_RETURN(auto make_tensor, ir_emitter_triton_internal::CreateMakeTensorPtrOp( - b, tile_multi_index, tiled_hlo, - fn.getArgument(computation->num_parameters()))); - b.create(make_tensor.op, result, make_tensor.boundary_checks, - mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL); + b, tile_multi_index, tiled_hlo, parent_base_ptr)); + b.create(make_tensor.op, result.UnwrapTensor(), + make_tensor.boundary_checks, mt::CacheModifier::NONE, + mt::EvictionPolicy::NORMAL); return absl::OkStatus(); } +} // namespace + void LoadMlirDialectsForTriton(mlir::MLIRContext& mlir_context) { mlir_context .loadDialect> TranslateLLVMToLLVMIR( - llvm::LLVMContext* llvmContext, mlir::ModuleOp module, - absl::string_view libdevice_path) { + llvm::LLVMContext* llvmContext, mlir::ModuleOp module) { mlir::DialectRegistry registry; mlir::registerBuiltinDialectTranslation(registry); mlir::registerLLVMDialectTranslation(registry); @@ -2992,15 +1053,6 @@ absl::Status CreateInternalError(std::string_view message, return absl::InternalError(err); } -absl::Status DoSupportType(const DebugOptions& debug_options, - PrimitiveType type) { - if (type == S4 && !debug_options.xla_gpu_enable_triton_gemm_int4()) { - return absl::FailedPreconditionError( - "Int4 support is not enabled in the debug options."); - } - return absl::OkStatus(); -} - absl::StatusOr> CreateTritonModule( absl::string_view fn_name, const HloFusionInstruction* fusion, const se::DeviceDescription& device_info, @@ -3022,10 +1074,11 @@ absl::StatusOr> CreateTritonModule( SmallVector fn_arg_types; for (HloInstruction* p : hlo_computation->parameter_instructions()) { PrimitiveType type = p->shape().element_type(); - TF_RETURN_IF_ERROR(DoSupportType(debug_options, type)); Type ir_type; if (type == U16) { ir_type = b.getI16Type(); + } else if (type == S4) { + ir_type = b.getI8Type(); } else { TF_ASSIGN_OR_RETURN(ir_type, TritonType(b, type)); } @@ -3232,8 +1285,7 @@ absl::StatusOr CompileTritonToLLVM( if (emit_kernel) { TF_ASSIGN_OR_RETURN( std::unique_ptr ll_triton_module, - TranslateLLVMToLLVMIR(&llvm_module->getContext(), triton_module, - GetLibdevicePath(hlo_config, device_info))); + TranslateLLVMToLLVMIR(&llvm_module->getContext(), triton_module)); VLogModule(5, *ll_triton_module); if (should_verify) { VerifyModule(*ll_triton_module); diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h index 9c7cd49cd3d862..c987359dfe674b 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h @@ -65,28 +65,6 @@ struct TritonWrapperResult { std::optional cluster_dim; }; -// Generate Triton IR inside 'fn'. This uses the given output_tile_sizes -// and the SymbolicTileAnalysis from the computation. The provided -// TritonFusionAnalysis and TritonGemmConfig are ignored. -absl::Status EmitGeneric(mlir::OpBuilder b, absl::string_view libdevice_path, - const se::DeviceDescription& device_info, - const HloFusionInstruction* fusion, - mlir::triton::FuncOp fn, - const BlockLevelParameters& block_level_parameters); - -// Compute the launch dimensions for the given Triton MatMul. -absl::StatusOr GetMatMulLaunchDimensions( - const TritonFusionAnalysis& analysis, const HloFusionAdaptor& fusion, - const TritonGemmConfig& config); - -// Use tiling and execution parameters from 'config'. output_tile_sizes is -// ignored. -absl::Status EmitMatMul(mlir::OpBuilder b, absl::string_view libdevice_path, - const se::DeviceDescription& device_info, - const HloFusionInstruction* fusion, - mlir::triton::FuncOp fn, - const BlockLevelParameters& block_level_parameters); - // Load the MLIR dialects required for Triton IR generation. void LoadMlirDialectsForTriton(mlir::MLIRContext& mlir_context); diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc index d07abdb4811224..ee78e17bf7e044 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc @@ -14,8 +14,6 @@ limitations under the License. ==============================================================================*/ #include -#include -#include #include #include #include @@ -24,12 +22,9 @@ limitations under the License. #include #include -#include "absl/algorithm/container.h" -#include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" -#include "absl/types/span.h" #include "llvm/IR/LLVMContext.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/PassManager.h" @@ -39,8 +34,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/literal.h" -#include "xla/literal_util.h" +#include "xla/hlo/testlib/filecheck.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h" #include "xla/service/gpu/fusions/triton/triton_test_utils.h" @@ -50,14 +45,11 @@ limitations under the License. #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/stream_executor/device_description.h" -#include "xla/tests/filecheck.h" -#include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" -#include "tsl/platform/status.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -67,6 +59,7 @@ namespace gpu { namespace { namespace m = ::xla::match; +using tsl::testing::StatusIs; class TritonTest : public GpuCodegenTest { public: @@ -100,7 +93,21 @@ class TritonTest : public GpuCodegenTest { class TritonGemmTest : public TritonTest { public: - DebugOptions GetDebugOptionsForTest() override { + se::GpuComputeCapability GetGpuComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(); + } + + void SetUp() override { + if (std::holds_alternative( + GetGpuComputeCapability())) { + GTEST_SKIP() << "Not supported on ROCm until Triton is re-enabled."; + } + } + + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = TritonTest::GetDebugOptionsForTest(); // Do not fall back to cuBLAS, we are testing Triton. debug_options.set_xla_gpu_cublas_fallback(false); @@ -109,7 +116,6 @@ class TritonGemmTest : public TritonTest { debug_options.set_xla_gpu_enable_split_k_autotuning(false); // Always rewrite Gemms with Triton regardless of size. debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0); - debug_options.set_xla_gpu_enable_triton_gemm_int4(true); return debug_options; } @@ -122,7 +128,7 @@ class TritonGemmTest : public TritonTest { class TritonGemmTestWithSplitK : public TritonGemmTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = TritonGemmTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_enable_split_k_autotuning(true); return debug_options; @@ -131,15 +137,53 @@ class TritonGemmTestWithSplitK : public TritonGemmTest { class TritonGemmTestWithoutTritonGemmAny : public TritonGemmTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = TritonGemmTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_triton_gemm_any(false); return debug_options; } }; +TEST_F(TritonGemmTest, RejectDotInt4HLO) { + constexpr std::string_view kHloText = R"( + HloModule t + + ENTRY main { + lhs = s4[16,32,64]{2,1,0} parameter(0) + rhs = s4[16,64,16]{2,1,0} parameter(1) + ROOT dot = s4[16,32,16]{2,1,0} dot(lhs, rhs), + lhs_contracting_dims={2}, + rhs_contracting_dims={1}, + lhs_batch_dims={0}, + rhs_batch_dims={0} + } + )"; + EXPECT_THAT(GetOptimizedModule(kHloText).status(), + StatusIs(tsl::error::INVALID_ARGUMENT)); +} + +TEST_F(TritonGemmTest, RejectInt4NegatePlusConvertHLO) { + constexpr std::string_view kHloText = R"( + HloModule t + + ENTRY main { + lhs = s4[16,32,64]{2,1,0} parameter(0) + lhs_negated = s4[16,32,64]{2,1,0} negate(lhs) + lhs_converted = bf16[16,32,64]{2,1,0} convert(lhs_negated) + rhs = bf16[16,64,16]{2,1,0} parameter(1) + ROOT dot = bf16[16,32,16]{2,1,0} dot(lhs_converted, rhs), + lhs_contracting_dims={2}, + rhs_contracting_dims={1}, + lhs_batch_dims={0}, + rhs_batch_dims={0} + } + )"; + EXPECT_THAT(GetOptimizedModule(kHloText).status(), + StatusIs(tsl::error::INVALID_ARGUMENT)); +} + TEST_F(TritonGemmTest, RejectTritonFusionForInt4WithMinorBatchDim) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t ENTRY main { @@ -153,8 +197,9 @@ TEST_F(TritonGemmTest, RejectTritonFusionForInt4WithMinorBatchDim) { rhs_batch_dims={0} } )"; + const std::string pattern = - R"(CHECK-NOT: ""kind":"__triton_gemm","triton_gemm_config"")"; + R"(CHECK-NOT: "kind":"__triton_gemm","triton_gemm_config")"; TF_ASSERT_OK_AND_ASSIGN(auto module, GetOptimizedModule(kHloText)); TF_ASSERT_OK_AND_ASSIGN(auto ok, RunFileCheck(module->ToString(), pattern)); EXPECT_TRUE(ok); @@ -163,7 +208,7 @@ TEST_F(TritonGemmTest, RejectTritonFusionForInt4WithMinorBatchDim) { TEST_F(TritonGemmTest, LHSInt4WithMinorDimEqualTo1) { // We prove that triton can handle int4 dot with non contracting dim size // equal to 1. - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_computation { @@ -192,7 +237,7 @@ TEST_F(TritonGemmTest, LHSInt4WithMinorDimEqualTo1) { TEST_F(TritonGemmTest, RHSInt4WithMinorDimEqualTo1) { // We prove that triton can handle int4 dot with non contracting dim size // equal to 1. - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_computation { @@ -222,7 +267,7 @@ TEST_F(TritonGemmTest, RHSInt4WithMinorDimEqualTo1) { TEST_F(TritonGemmTest, LHSInt4NonMinorContractingDim) { // We prove that triton can handle int4 dot with non minor // lhs_contracting_dim. - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_computation { @@ -250,7 +295,7 @@ TEST_F(TritonGemmTest, LHSInt4NonMinorContractingDim) { TEST_F(TritonGemmTest, LHSInt4NonMinorContractingDimWithBatchDim0) { // We prove that triton can handle int4 dot with non minor // lhs_contracting_dim. - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_computation { @@ -278,7 +323,7 @@ TEST_F(TritonGemmTest, LHSInt4NonMinorContractingDimWithBatchDim0) { TEST_F(TritonGemmTest, LHSInt4MinorContractingDim) { // We prove that triton can handle int4 dot with minor lhs_contracting_dim. - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_computation { @@ -302,7 +347,7 @@ TEST_F(TritonGemmTest, LHSInt4MinorContractingDim) { } TEST_F(TritonGemmTest, Int4ConvertPlusNegate) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_computation { @@ -328,7 +373,7 @@ TEST_F(TritonGemmTest, Int4ConvertPlusNegate) { TEST_F(TritonGemmTest, LHSInt4MinorContractingDimWithBatchDim0) { // We prove that triton can handle int4 dot with minor lhs_contracting_dim. - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_computation { @@ -355,7 +400,7 @@ TEST_F(TritonGemmTest, LHSInt4MinorContractingDimWithBatchDim0) { } TEST_F(TritonGemmTest, RHSInt4TestWithMinorContractingDim) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_computation { @@ -380,7 +425,7 @@ TEST_F(TritonGemmTest, RHSInt4TestWithMinorContractingDim) { } TEST_F(TritonGemmTest, RHSInt4TestWithNotMinorContractingDim) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_computation { @@ -405,7 +450,7 @@ TEST_F(TritonGemmTest, RHSInt4TestWithNotMinorContractingDim) { } TEST_F(TritonGemmTest, RHSInt4TestWithMinorContractingDimWithBatchDim) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_computation { @@ -432,7 +477,7 @@ TEST_F(TritonGemmTest, RHSInt4TestWithMinorContractingDimWithBatchDim) { } TEST_F(TritonGemmTest, RHSInt4TestWithNotMinorContractingDimWithBatchDim0) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_computation { @@ -459,7 +504,7 @@ TEST_F(TritonGemmTest, RHSInt4TestWithNotMinorContractingDimWithBatchDim0) { } TEST_F(TritonTest, TestGemm) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t, is_scheduled=true triton_gemm_r { @@ -551,7 +596,7 @@ CHECK: } } TEST_F(TritonTest, TestGemmWithTrivialNonContractingDimension) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t, is_scheduled=true triton_dot { @@ -641,7 +686,7 @@ CHECK: } } TEST_F(TritonTest, PredParametersAreTruncatedToI1) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m triton_gemm_computation { @@ -682,7 +727,7 @@ CHECK: %{{.*}} = arith.andi %[[TRUNCI]], %{{.*}} : tensor<16x16xi1> } TEST_F(TritonTest, CodegenBatchedDotWithConcatenationWithCorrectBatchStride) { - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t, is_scheduled=true triton_gemm { @@ -725,7 +770,7 @@ CHECK: %[[BLOCK_BASE_PTR:.*]] = tt.addptr %[[ARG_PTR]], %[[OFFSET]] TEST_F(TritonTest, CodegenDynamicSliceWithCorrectOffsets) { // The start index(es) for the non-majormost dimension(s) are constant zero(s) // because we don't support dynamic slice on those dimensions. - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_gemm { @@ -775,7 +820,7 @@ CHECK-DAG: tt.make_tensor_ptr %[[DYNAMIC_SLICE_INPUT]], [%[[C2_i64]], %[[ROW_L } TEST_F(TritonTest, SparseDot) { - const char* kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_dot { @@ -806,7 +851,7 @@ CHECK: triton_gpu.sparse_dot %[[LHS]], %[[RHS]], %{{[^:]+}}, %[[META]] : } TEST_F(TritonTest, SparseDotWithMasking) { - const char* kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_dot { @@ -843,7 +888,7 @@ CHECK: triton_gpu.sparse_dot %[[LHS_MASKED]], %[[RHS_MASKED]], %{{[^:]+}}, %[[ME } TEST_F(TritonTest, SparseDotBroadcastMetadata) { - const char* kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t triton_dot { @@ -880,7 +925,7 @@ CHECK: triton_gpu.sparse_dot %[[LHS]], %[[RHS]], %{{[^:]+}}, %[[META]] : } TEST_F(TritonGemmTest, DoNotUseTensorCoresWithNonDefaultPrecision) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_gemm_r { parameter_0 = s8[80,15]{1,0} parameter(0) convert.3 = f32[80,15]{1,0} convert(parameter_0) @@ -910,7 +955,7 @@ CHECK-NOT: mma } TEST_F(TritonGemmTest, DebugOptionsArePropagated) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p0 = f16[30,30] parameter(0) p1 = s8[30,30] parameter(1) @@ -962,7 +1007,7 @@ ENTRY main { } TEST_F(TritonGemmTest, UseTensorCoresForF32OnAmpere) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_gemm_r { parameter_0 = f16[80,15]{1,0} parameter(0) convert.3 = f32[80,15]{1,0} convert(parameter_0) @@ -994,7 +1039,7 @@ TEST_F(TritonGemmTest, FailIfTooMuchShmem) { if (std::holds_alternative(GpuComputeComp())) { GTEST_SKIP() << "GEMM padding requirements for ROCM not included yet."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule module, is_scheduled=true triton_gemm_dot { @@ -1048,9 +1093,8 @@ ENTRY entry { EXPECT_THAT( TritonWrapper("test_fn", triton_dot_fusion, CudaAmpereOrRocm(), dev_info, block_level_parameters, &llvm_module, mlir_context), - tsl::testing::StatusIs( - tsl::error::RESOURCE_EXHAUSTED, - ::testing::HasSubstr("Shared memory size limit exceeded"))); + StatusIs(tsl::error::RESOURCE_EXHAUSTED, + ::testing::HasSubstr("Shared memory size limit exceeded"))); config.set_block_m(64); config.set_block_n(128); @@ -1071,7 +1115,7 @@ TEST_F(TritonGemmTestWithSplitK, // The condition mentioned in the test name is fulfilled by // GemmKey(16, 64, 256, 8, 1, 4), which was part of the default configs for // Ampere at the time of the addition of this test case. - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule extracted ENTRY e { @@ -1225,7 +1269,7 @@ ENTRY e { } TEST_F(TritonGemmTest, SplitAndTransposeLhsExecutesCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -1255,7 +1299,7 @@ TEST_F(TritonGemmTest, NondefaultOperandLayoutIsSupported) { #ifndef NDEBUG GTEST_SKIP() << "This test times out when -UNDEBUG is set."; #endif - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY r { p1 = f16[9,140,128]{2,1,0} parameter(1) cp = f16[9,140,128]{2,0,1} copy(p1) @@ -1428,7 +1472,7 @@ ENTRY e { } TEST_F(TritonGemmTest, MultipleBatchRequireSeparateTranspose) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -1451,7 +1495,7 @@ ENTRY e { } TEST_F(TritonGemmTest, CanCodegenNonBatchedDotWithConcatenationCorrectly) { - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { parameter_0 = f32[3,10]{1,0} parameter(0) parameter_1 = f32[10,128]{1,0} parameter(1) @@ -1475,7 +1519,7 @@ ENTRY e { } TEST_F(TritonGemmTest, CanCodegenBatchedDotWithConcatenationCorrectly) { - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { parameter_0 = f32[2,3,10]{2,1,0} parameter(0) parameter_1 = f32[2,10,128]{2,1,0} parameter(1) @@ -1520,7 +1564,7 @@ ENTRY e { } TEST_F(TritonTest, FloatToSignedIntConversion) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t, is_scheduled=true triton_gemm_r { @@ -1581,7 +1625,7 @@ ENTRY e { // This tests the complexity heuristics in TritonWrapper. TEST_F(TritonGemmTest, FailForTooComplexTiling) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule module, is_scheduled=true triton_gemm_dot { @@ -1634,9 +1678,8 @@ ENTRY entry { EXPECT_THAT( TritonWrapper("test_fn", triton_dot_fusion, CudaAmpereOrRocm(), dev_info, block_level_parameters, &llvm_module, mlir_context), - tsl::testing::StatusIs( - tsl::error::RESOURCE_EXHAUSTED, - "Tiling complexity heuristic exceeded: 147456 > 9000")); + StatusIs(tsl::error::RESOURCE_EXHAUSTED, + "Tiling complexity heuristic exceeded: 147456 > 9000")); // Succeeds if the tiling is not too complex. config.set_block_m(32); @@ -1743,7 +1786,7 @@ ENTRY e { class TritonGemmTestAny : public TritonGemmTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = TritonGemmTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_triton_gemm_any(true); return debug_options; @@ -1822,7 +1865,7 @@ TEST_F(TritonGemmTest, DynamicSliceIsSupportedInLhsEndToEnd) { // is not strictly needed, because we also support clamping the indices. // The start index(es) for the non-majormost dimension(s) are constant zero(s) // because we don't support dynamic slice on those dimensions. - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -1853,7 +1896,7 @@ ENTRY e { TEST_F(TritonGemmTest, DynamicSliceIsSupportedInRhs) { // The start index(es) for the non-majormost dimension(s) are constant zero(s) // because we don't support dynamic slice on those dimensions. - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m triton_gemm { @@ -1886,7 +1929,7 @@ ENTRY e { } TEST_F(TritonGemmTest, MultiplePathsToSameOperandWorks) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_computation { p0 = bf16[8192,512]{1,0} parameter(0) p1 = bf16[512,512]{1,0} parameter(1) @@ -1969,7 +2012,7 @@ TEST_F(TritonGemmTest, DynamicSliceOfMajormostContractingDimIsSupported) { // dimension is contracted. // The start index(es) for the non-majormost dimension(s) are constant zero(s) // because we don't support dynamic slice on those dimensions. - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m triton_gemm { @@ -2006,7 +2049,7 @@ TEST_F(TritonGemmTest, DynamicSliceOfMajormostBatchDimIsSupported) { // dimension is a batch. // The start index(es) for the non-majormost dimension(s) are constant zero(s) // because we don't support dynamic slice on those dimensions. - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m triton_gemm { @@ -2045,7 +2088,7 @@ TEST_F(TritonGemmTest, DynamicSliceSingleDimensionIntoReshapeIsSupported) { // layer weights and extracting them with dynamic slice. // The start index(es) for the non-majormost dimension(s) are constant zero(s) // because we don't support dynamic slice on those dimensions. - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m triton_gemm { @@ -2112,7 +2155,7 @@ ENTRY e { } TEST_F(TritonGemmTest, BroadcastOfScalarWorksCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( fusion { p0 = f16[2,18] parameter(0) p1 = f16[256,2] parameter(1) @@ -2165,7 +2208,7 @@ ENTRY e { class TritonGemmLevel2Test : public TritonGemmTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = TritonGemmTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_triton_fusion_level(2); return debug_options; @@ -2174,7 +2217,7 @@ class TritonGemmLevel2Test : public TritonGemmTest { class TritonGemmLevel2TestAny : public TritonGemmLevel2Test { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = TritonGemmLevel2Test::GetDebugOptionsForTest(); debug_options.set_xla_gpu_triton_gemm_any(true); return debug_options; @@ -2182,7 +2225,7 @@ class TritonGemmLevel2TestAny : public TritonGemmLevel2Test { }; TEST_F(TritonGemmLevel2Test, BinaryOperationWithSmallInputsIsFused) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -2208,7 +2251,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, BinaryOperationWithLargeInputsIsNotFused) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -2239,7 +2282,7 @@ ENTRY e { TEST_F(TritonGemmLevel2Test, ParametersWithDifferentLayoutsAreSupportedInOneScope) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p0 = s8[5,3] parameter(0) p0c = f16[5,3] convert(p0) @@ -2262,7 +2305,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, BinaryOperationOnLargeParametersIsFused) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -2287,7 +2330,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, LinkingLibdeviceTwiceWorks) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p0 = s8[7,3] parameter(0) c0 = f32[7,3] convert(p0) @@ -2318,7 +2361,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, BroadcastOfScalarParameterIsFused) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p0 = f16[64,256] parameter(0) p0c = f32[64,256] convert(p0) @@ -2339,7 +2382,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, BroadcastOfScalarConstantIsFused) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -2365,7 +2408,7 @@ TEST_F(TritonGemmLevel2Test, DoubleBroadcastOfScalarConstantIsHandled) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { c = s32[] constant(1) bc1 = s32[21]{0} broadcast(c), dimensions={} @@ -2389,7 +2432,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, BroadcastOfVectorConstantIsFused) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -2413,7 +2456,7 @@ TEST_F(TritonGemmLevel2Test, AlwaysFuseScalarConstantAtBroadcastInput) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p0 = bf16[2,3,3]{2,1,0} parameter(0) p1 = bf16[3,2,3]{2,1,0} parameter(1) @@ -2440,7 +2483,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, BroadcastOfVectorParameterIsFused) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_dot { p0 = f16[75] parameter(0) bc0 = f16[75,67] broadcast(p0), dimensions={0} @@ -2469,7 +2512,7 @@ TEST_F(TritonGemmLevel2Test, FuseConcatenation) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( e { p0 = s8[153,1536] parameter(0) p1 = s8[153,128] parameter(1) @@ -2495,7 +2538,7 @@ e { } TEST_F(TritonGemmLevel2TestAny, MinimumHandlesNaNsOnTheLeft) { - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t ENTRY e { @@ -2518,7 +2561,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2TestAny, MinimumHandlesNaNsOnTheRight) { - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t ENTRY e { @@ -2541,7 +2584,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2TestAny, MaximumHandlesNaNsOnTheLeft) { - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t ENTRY e { @@ -2564,7 +2607,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2TestAny, MaximumHandlesNaNsOnTheRight) { - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t ENTRY e { @@ -2587,7 +2630,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2TestAny, MinimumReturnsLHS) { - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t ENTRY e { @@ -2612,7 +2655,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2TestAny, MinimumReturnsRHS) { - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t ENTRY e { @@ -2637,7 +2680,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2TestAny, MaximumReturnsLHS) { - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t ENTRY e { @@ -2662,7 +2705,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2TestAny, MaximumReturnsRHS) { - constexpr absl::string_view kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t ENTRY e { @@ -2687,7 +2730,7 @@ ENTRY e { } TEST_F(TritonGemmTest, SineOutputIsNotFused) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -2710,7 +2753,7 @@ ENTRY e { } TEST_F(TritonGemmTest, SliceInputIsFused) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p0 = f16[97,121] parameter(0) s0 = f16[7,101] slice(p0), slice={[3:10], [10:111]} @@ -2731,7 +2774,7 @@ ENTRY e { } TEST_F(TritonGemmTest, SliceInputWithReshapeIsFused) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p0 = f32[363,1536] parameter(0) p1 = f32[4,1536,611] parameter(1) @@ -2753,7 +2796,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, NestedSlicingWorks) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p1 = f32[6,24] parameter(1) slice1 = f32[5,20] slice(p1), slice={[1:6], [3:23]} @@ -2775,7 +2818,7 @@ ENTRY e { } TEST_F(TritonGemmTest, SlicedBatchDimensionIsSupported) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p0 = f16[3,3,256] parameter(0) s0 = f16[3,3,128] slice(p0), slice={[0:3], [0:3], [123:251]} @@ -2800,7 +2843,7 @@ ENTRY e { TEST_F(TritonGemmTestWithSplitK, SplitKDoesNotBreakSlicedFragmentedContractingDimension) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p0 = f16[16,8,128]{2,1,0} parameter(0) s0 = f16[16,4,128]{2,1,0} slice(p0), @@ -2824,7 +2867,7 @@ ENTRY e { } TEST_F(TritonGemmTestWithSplitK, SplitKWithTrivialDimension) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY entry_computation { p0 = f16[1001,1]{1,0} parameter(0) convert = f32[1001,1]{1,0} convert(p0) @@ -2837,7 +2880,7 @@ ENTRY entry_computation { } TEST_F(TritonGemmLevel2Test, NarrowingConvertOutputIsFused) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -2863,7 +2906,7 @@ TEST_F(TritonGemmLevel2Test, ParameterAfterDotIsFused) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -2895,7 +2938,7 @@ TEST_F(TritonGemmLevel2Test, OutputFusionExecutesCorrectly) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -2931,7 +2974,7 @@ TEST_F(TritonGemmLevel2Test, SplitLHSOutputTransposeAloneIsNotFused) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -2964,7 +3007,7 @@ TEST_F(TritonGemmLevel2Test, SplitLHSInputOutputIsFused) { GTEST_SKIP() << "Skipped until corresponding issue on ROCm is fixed."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p0t = (s8[5,18,20,150]) parameter(0) p0 = s8[5,18,20,150] get-tuple-element(p0t), index=0 @@ -2989,7 +3032,7 @@ ENTRY e { } TEST_F(TritonGemmLevel2Test, SupportPredParametersUsedInExpressions) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p = pred[2,2]{1,0} parameter(0) a = f32[2,2]{1,0} parameter(1) @@ -4290,12 +4333,15 @@ triton_dot { cvt1 = f32[3,3,2,16]{1,3,2,0} convert(p1) p0 = f16[9,32]{0,1} parameter(0) b0 = f16[3,3,2,16]{1,0,3,2} bitcast(p0) - cp0 = f16[3,3,2,16]{1,3,2,0} copy(b0) - cvt0 = f32[3,3,2,16]{1,3,2,0} convert(cp0) + cp0b0 = f16[2,16,3,3]{3,2,1,0} bitcast(b0) + cp0t0 = f16[3,2,16,3]{3,2,1,0} transpose(cp0b0), dimensions={2,0,1,3} + cp0b1 = f16[3,3,2,16]{1,3,2,0} bitcast(cp0t0) + cvt0 = f32[3,3,2,16]{1,3,2,0} convert(cp0b1) m = f32[3,3,2,16]{1,3,2,0} multiply(cvt1, cvt0) cvt2 = f16[3,3,2,16]{1,3,2,0} convert(m) - cp1 = f16[3,3,2,16]{3,2,1,0} copy(cvt2) - b1 = f16[9,32]{1,0} bitcast(cp1) + cp1b0 = f16[3,2,16,3]{3,2,1,0} bitcast(cvt2) + cp1t0 = f16[3,3,2,16]{3,2,1,0} transpose(cp1b0), dimensions={0,3,1,2} + b1 = f16[9,32]{1,0} bitcast(cp1t0) p2 = f16[32,32]{1,0} parameter(2) ROOT r = f16[9,32]{1,0} dot(b1, p2), lhs_contracting_dims={1}, rhs_contracting_dims={0} @@ -4319,12 +4365,15 @@ ENTRY e { cvt1 = f32[3,3,2,16]{1,3,2,0} convert(p1) p0 = f16[9,32]{0,1} parameter(0) b0 = f16[3,3,2,16]{1,0,3,2} bitcast(p0) - cp0 = f16[3,3,2,16]{1,3,2,0} copy(b0) - cvt0 = f32[3,3,2,16]{1,3,2,0} convert(cp0) + cp0b0 = f16[2,16,3,3]{3,2,1,0} bitcast(b0) + cp0t0 = f16[3,2,16,3]{3,2,1,0} transpose(cp0b0), dimensions={2,0,1,3} + cp0b1 = f16[3,3,2,16]{1,3,2,0} bitcast(cp0t0) + cvt0 = f32[3,3,2,16]{1,3,2,0} convert(cp0b1) m = f32[3,3,2,16]{1,3,2,0} multiply(cvt1, cvt0) cvt2 = f16[3,3,2,16]{1,3,2,0} convert(m) - cp1 = f16[3,3,2,16]{3,2,1,0} copy(cvt2) - b1 = f16[9,32]{1,0} bitcast(cp1) + cp1b0 = f16[3,2,16,3]{3,2,1,0} bitcast(cvt2) + cp1t0 = f16[3,3,2,16]{3,2,1,0} transpose(cp1b0), dimensions={0,3,1,2} + b1 = f16[9,32]{1,0} bitcast(cp1t0) p2 = f16[32,32]{1,0} parameter(2) ROOT r = f16[9,32]{1,0} dot(b1, p2), lhs_contracting_dims={1}, rhs_contracting_dims={0} @@ -4337,7 +4386,7 @@ ENTRY e { class TritonGemmContractionDims : public TritonGemmTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = TritonGemmTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_ensure_minor_dot_contraction_dims(true); debug_options.set_xla_gpu_triton_gemm_any(true); @@ -4350,7 +4399,7 @@ TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_0) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -4375,7 +4424,7 @@ TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_2_1_2) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -4400,7 +4449,7 @@ TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_2_0_1) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -4426,7 +4475,7 @@ TEST_F(TritonGemmContractionDims, TritonDotForceContractionDims_1_1) { if (!SupportsBF16(GpuComputeComp())) { GTEST_SKIP() << "BF16 not supported."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -4446,680 +4495,6 @@ ENTRY e { .WithShape(BF16, {16, 40}, {1, 0}))); } -// In these tests, we depend on "algorithm" annotations for selecting the 6XBF16 -// algorithm. -class Triton6xBF16GemmTest : public TritonTest { - public: - DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = TritonTest::GetDebugOptionsForTest(); - // These 2 flags are not strictly necessary now, but we're adding them to be - // on the safe side against future flakiness. - // - // Enable triton fusion for all supported GEMMs. - debug_options.set_xla_gpu_triton_gemm_any(true); - // Do not fall back to cuBLAS, we are testing Triton. - debug_options.set_xla_gpu_cublas_fallback(false); - - // Do not autotune split-k by default, since this prevents deterministically - // matching the optimized HLO. - debug_options.set_xla_gpu_enable_split_k_autotuning(false); - return debug_options; - } - - protected: - void SetUp() override { - if (!SupportsBF16(GpuComputeComp())) { - GTEST_SKIP() << "BF16 not supported."; - } - } -}; - -// In these tests, we depend on debug option flags for selecting the 6XBF16 -// algorithm. -// TODO(b/316147294): Remove this class and the --xla_gpu_enable_bf16_6way_gemm -// flag after we will support the algorithm values through the entire stack. -class Triton6xBF16GemmTestWithFlag : public TritonTest { - public: - DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = TritonTest::GetDebugOptionsForTest(); - // Enable triton fusion for all supported GEMMs. - debug_options.set_xla_gpu_triton_gemm_any(true); - // Do not fall back to cuBLAS, we are testing Triton. - debug_options.set_xla_gpu_cublas_fallback(false); - // Do not autotune split-k by default, since this prevents deterministically - // matching the optimized HLO. - debug_options.set_xla_gpu_enable_split_k_autotuning(false); - // Enable bf16_6way gemm to compute F32 matmul. - debug_options.set_xla_gpu_enable_bf16_6way_gemm(true); - return debug_options; - } -}; - -TEST_F(Triton6xBF16GemmTest, Emit6xBF16GemmWhenBothInputsAreF32) { - const char* kHloText = R"( -HloModule t - -triton_dot { - p0 = f32[5,7] parameter(0) - p1 = f32[7,33] parameter(1) - ROOT dot = f32[5,33] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0}, - algorithm=dot_bf16_bf16_f32_x6 -} - -ENTRY e { - p0 = f32[5,7]{1,0} parameter(0) - p1 = f32[7,33]{1,0} parameter(1) - ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: - {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} -} -)"; - TF_ASSERT_OK( - CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( -CHECK: %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32> -CHECK: %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32> -CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32x32xf32> -CHECK: %[[CAST_I32:.*]] = tt.bitcast %{{.*}} : tensor<32x32xf32> -> tensor<32x32xi32> -CHECK: %[[EXTRACT_HI:.*]] = arith.andi %[[CAST_I32]], %[[C_MASK]] : tensor<32x32xi32> -CHECK: %[[CAST_HI:.*]] = tt.bitcast %[[EXTRACT_HI]] : tensor<32x32xi32> -> tensor<32x32xf32> -CHECK: %[[TRUNC_TO_BF16:.*]] = arith.truncf %[[CAST_HI]] : tensor<32x32xf32> to tensor<32x32xbf16> -CHECK-COUNT-5: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> -CHECK: %[[ABS:.*]] = math.absf -CHECK: %[[CMP:.*]] = arith.cmpf ogt, %[[INFINITY]], %[[ABS]] : tensor<32x32xf32> -CHECK: %[[SELECT:.*]] = arith.select %[[CMP]], %{{.*}}, %[[C0]] : tensor<32x32xi1>, tensor<32x32xf32> -CHECK: %[[DOT_LAST:.*]] = tt.dot %{{.*}}, %{{.*}}, %[[SELECT]] : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> -CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf32> - )")); - - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-6, - /*arel=*/1e-6})); -} - -TEST_F(Triton6xBF16GemmTestWithFlag, Emit6xBF16GemmWhenBothInputsAreF32) { - const char* kHloText = R"( -HloModule t - -triton_dot { - p0 = f32[5,7] parameter(0) - p1 = f32[7,33] parameter(1) - ROOT dot = f32[5,33] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -ENTRY e { - p0 = f32[5,7]{1,0} parameter(0) - p1 = f32[7,33]{1,0} parameter(1) - ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: - {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} -} -)"; - TF_ASSERT_OK( - CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( -CHECK: %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32> -CHECK: %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32> -CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32x32xf32> -CHECK: %[[CAST_I32:.*]] = tt.bitcast %{{.*}} : tensor<32x32xf32> -> tensor<32x32xi32> -CHECK: %[[EXTRACT_HI:.*]] = arith.andi %[[CAST_I32]], %[[C_MASK]] : tensor<32x32xi32> -CHECK: %[[CAST_HI:.*]] = tt.bitcast %[[EXTRACT_HI]] : tensor<32x32xi32> -> tensor<32x32xf32> -CHECK: %[[TRUNC_TO_BF16:.*]] = arith.truncf %[[CAST_HI]] : tensor<32x32xf32> to tensor<32x32xbf16> -CHECK-COUNT-5: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> -CHECK: %[[ABS:.*]] = math.absf -CHECK: %[[CMP:.*]] = arith.cmpf ogt, %[[INFINITY]], %[[ABS]] : tensor<32x32xf32> -CHECK: %[[SELECT:.*]] = arith.select %[[CMP]], %{{.*}}, %[[C0]] : tensor<32x32xi1>, tensor<32x32xf32> -CHECK: %[[DOT_LAST:.*]] = tt.dot %{{.*}}, %{{.*}}, %[[SELECT]] : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> -CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf32> - )")); - - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-6, - /*arel=*/1e-6})); -} - -TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmWorksForLongContractingDimension) { - const char* kHloText = R"( -HloModule t - -triton_dot { - p0 = f32[5,2048] parameter(0) - p1 = f32[2048,33] parameter(1) - ROOT dot = f32[5,33] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0}, - algorithm=dot_bf16_bf16_f32_x6 -} - -ENTRY e { - p0 = f32[5,2048]{1,0} parameter(0) - p1 = f32[2048,33]{1,0} parameter(1) - ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: - {"block_m":64,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":4, "num_ctas":1}}} -} -)"; - TF_ASSERT_OK( - CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( -CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<64x32xbf16> * tensor<32x32xbf16> -> tensor<64x32xf32> - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-5, - /*arel=*/1e-5})); -} - -TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmCanHandleInfinity) { - const char* kHloText = R"( -HloModule t - -triton_dot { - p0 = f32[2,2] parameter(0) - p1 = f32[2,2] parameter(1) - ROOT dot = f32[2,2] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0}, - algorithm=dot_bf16_bf16_f32_x6 -} - -ENTRY e { - p0 = f32[2,2]{1, 0} parameter(0) - p1 = f32[2,2]{1, 0} parameter(1) - ROOT _ = f32[2,2] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: - {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} -} -)"; - TF_ASSERT_OK( - CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( -CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> - )")); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - GetOptimizedModule(kHloText)); - std::vector arguments(2); - arguments[0] = - LiteralUtil::CreateR2({{+std::numeric_limits::infinity(), - +std::numeric_limits::infinity()}, - {+std::numeric_limits::infinity(), - +std::numeric_limits::infinity()}}); - arguments[1] = LiteralUtil::CreateR2({{1.0f, 1.0f}, {1.0f, 1.0f}}); - std::vector argument_ptrs; - absl::c_transform( - arguments, std::back_inserter(argument_ptrs), - [](const Literal& literal) { return const_cast(&literal); }); - - EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), argument_ptrs, - ErrorSpec{/*aabs=*/0, /*arel=*/0})); -} - -TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmCanHandleNaN) { - const char* kHloText = R"( -HloModule t - -triton_dot { - p0 = f32[2,2] parameter(0) - p1 = f32[2,2] parameter(1) - ROOT dot = f32[2,2] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0}, - algorithm=dot_bf16_bf16_f32_x6 -} - -ENTRY e { - p0 = f32[2,2]{1, 0} parameter(0) - p1 = f32[2,2]{1, 0} parameter(1) - ROOT _ = f32[2,2] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: - {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} -} -)"; - TF_ASSERT_OK( - CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( -CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> - )")); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - GetOptimizedModule(kHloText)); - std::vector arguments(2); - arguments[0] = - LiteralUtil::CreateR2({{std::numeric_limits::quiet_NaN(), - std::numeric_limits::quiet_NaN()}, - {std::numeric_limits::quiet_NaN(), - std::numeric_limits::quiet_NaN()}}); - arguments[1] = LiteralUtil::CreateR2( - {{1.0f, +std::numeric_limits::infinity()}, - {1.0f, +std::numeric_limits::infinity()}}); - std::vector argument_ptrs; - absl::c_transform( - arguments, std::back_inserter(argument_ptrs), - [](const Literal& literal) { return const_cast(&literal); }); - - EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), argument_ptrs, - ErrorSpec{/*aabs=*/0, /*arel=*/0})); -} - -// Test case shows that why we truncate the middle term instead of rounding. -// If we round the middle term, the splitted terms may disagree in sign. This -// could result in wrong results for extreme values. -// For example, consider: -// x = -3.40282347e+38 -// If we round the middle term, its decomposition would be: -// x_hi: -3.38953139e+38 -// x_mid: -1.3240357e+36 -// x_lo: 5.17201445e+33 -// The result of x*x would be NaN instead of positive infinity. -TEST_F(Triton6xBF16GemmTest, Triton6xBF16GemmWorksForInputsWithLargeExponent) { - const char* kHloText = R"( -HloModule t - -triton_dot { - p0 = f32[2,2] parameter(0) - p1 = f32[2,2] parameter(1) - ROOT dot = f32[2,2] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0}, - algorithm=dot_bf16_bf16_f32_x6 -} - -ENTRY e { - p0 = f32[2,2]{1, 0} parameter(0) - p1 = f32[2,2]{1, 0} parameter(1) - ROOT _ = f32[2,2] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: - {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} -} -)"; - TF_ASSERT_OK( - CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( -CHECK-COUNT-6: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> - )")); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - GetOptimizedModule(kHloText)); - std::vector arguments(2); - constexpr float kLargeExponentFloat = 0x1.0103p72f; - arguments[0] = LiteralUtil::CreateR2( - {{kLargeExponentFloat, 1.0f}, {-kLargeExponentFloat, 1.0f}}); - arguments[1] = LiteralUtil::CreateR2( - {{kLargeExponentFloat, 1.0f}, {-kLargeExponentFloat, 1.0f}}); - std::vector argument_ptrs; - absl::c_transform( - arguments, std::back_inserter(argument_ptrs), - [](const Literal& literal) { return const_cast(&literal); }); - - EXPECT_TRUE( - RunAndCompareNoHloPasses(std::move(module), argument_ptrs, - ErrorSpec{/*aabs=*/1e-6, /*arel=*/1e-6})); -} - -TEST_F(Triton6xBF16GemmTest, Emit6xBF16GemmEndToEnd) { - if (std::holds_alternative(GpuComputeComp())) { - GTEST_SKIP() << "ALG_DOT_BF16_BF16_F32_X6 not supported on ROCM."; - } - const char* kHloText = R"( -HloModule t - -ENTRY e { - p0 = f32[5,32] parameter(0) - p1 = f32[32,7] parameter(1) - ROOT dot = f32[5,7] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0}, - algorithm=dot_bf16_bf16_f32_x6 -} -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, - ParseAndReturnVerifiedModule(kHloText)); - CompileAndOptionallyVerifyPtx(std::move(verified_module), - R"( -CHECK: mma.sync.aligned.{{.*}}.row.col.f32.bf16.bf16.f32 -CHECK-NOT: mma.sync.aligned.{{.*}}.row.col.f32.tf32.tf32.f32 -)"); - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-6, - /*arel=*/1e-6})); -} - -// In these tests, we depend on "algorithm" annotations for selecting the 3XBF16 -// algorithm. -class Triton3xBF16GemmTest : public TritonTest { - public: - DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = TritonTest::GetDebugOptionsForTest(); - // These 2 flags are not strictly necessary now, but we're adding them the - // to be on the safe side against future flakiness. - // - // Enable triton fusion for all supported GEMMs. - debug_options.set_xla_gpu_triton_gemm_any(true); - // Do not fall back to cuBLAS, we are testing Triton. - debug_options.set_xla_gpu_cublas_fallback(false); - - // Do not autotune split-k by default, since this prevents deterministically - // matching the optimized HLO. - debug_options.set_xla_gpu_enable_split_k_autotuning(false); - return debug_options; - } -}; - -// In these tests, we depend on debug option flags for selecting the 3XBF16 -// algorithm. -// TODO(b/316147294): Remove this class and the --xla_gpu_enable_bf16_3way_gemm -// flag after we will support the algorithm values through the entire stack. -class Triton3xBF16GemmTestWithFlag : public TritonTest { - public: - DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = TritonTest::GetDebugOptionsForTest(); - // Enable triton fusion for all supported GEMMs. - debug_options.set_xla_gpu_triton_gemm_any(true); - // Do not fall back to cuBLAS, we are testing Triton. - debug_options.set_xla_gpu_cublas_fallback(false); - // Do not autotune split-k by default, since this prevents deterministically - // matching the optimized HLO. - debug_options.set_xla_gpu_enable_split_k_autotuning(false); - // Enable bf16_3way gemm to compute F32 matmul. - debug_options.set_xla_gpu_enable_bf16_3way_gemm(true); - return debug_options; - } - - protected: - void SetUp() override { - if (!SupportsBF16(GpuComputeComp())) { - GTEST_SKIP() << "BF16 not supported."; - } - } -}; - -TEST_F(Triton3xBF16GemmTest, Emit3xBF16GemmWhenBothInputsAreF32) { - const char* kHloText = R"( -HloModule t - -triton_dot { - p0 = f32[5,7] parameter(0) - p1 = f32[7,33] parameter(1) - ROOT dot = f32[5,33] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0}, - algorithm=dot_bf16_bf16_f32_x3 -} - -ENTRY e { - p0 = f32[5,7]{1,0} parameter(0) - p1 = f32[7,33]{1,0} parameter(1) - ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: - {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} -} -)"; - TF_ASSERT_OK( - CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( -CHECK: %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32> -CHECK: %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32> -CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32x32xf32> -CHECK: %[[CAST_I32:.*]] = tt.bitcast %{{.*}} : tensor<32x32xf32> -> tensor<32x32xi32> -CHECK: %[[EXTRACT_HI:.*]] = arith.andi %[[CAST_I32]], %[[C_MASK]] : tensor<32x32xi32> -CHECK: %[[CAST_HI:.*]] = tt.bitcast %[[EXTRACT_HI]] : tensor<32x32xi32> -> tensor<32x32xf32> -CHECK: %[[TRUNC_TO_BF16:.*]] = arith.truncf %[[CAST_HI]] : tensor<32x32xf32> to tensor<32x32xbf16> -CHECK-COUNT-2: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> -CHECK: %[[ABS:.*]] = math.absf -CHECK: %[[CMP:.*]] = arith.cmpf ogt, %[[INFINITY]], %[[ABS]] : tensor<32x32xf32> -CHECK: %[[SELECT:.*]] = arith.select %[[CMP]], %{{.*}}, %[[C0]] : tensor<32x32xi1>, tensor<32x32xf32> -CHECK: %[[DOT_LAST:.*]] = tt.dot %{{.*}}, %{{.*}}, %[[SELECT]] : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> -CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf32> - )")); - - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-5, - /*arel=*/1e-5})); -} - -TEST_F(Triton3xBF16GemmTestWithFlag, Emit3xBF16GemmWhenBothInputsAreF32) { - const char* kHloText = R"( -HloModule t - -triton_dot { - p0 = f32[5,7] parameter(0) - p1 = f32[7,33] parameter(1) - ROOT dot = f32[5,33] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -ENTRY e { - p0 = f32[5,7]{1,0} parameter(0) - p1 = f32[7,33]{1,0} parameter(1) - ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: - {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} -} -)"; - TF_ASSERT_OK( - CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( -CHECK: %[[INFINITY:.*]] = arith.constant dense<0x7F800000> : tensor<32x32xf32> -CHECK: %[[C_MASK:.*]] = arith.constant dense<-65536> : tensor<32x32xi32> -CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32x32xf32> -CHECK: %[[CAST_I32:.*]] = tt.bitcast %{{.*}} : tensor<32x32xf32> -> tensor<32x32xi32> -CHECK: %[[EXTRACT_HI:.*]] = arith.andi %[[CAST_I32]], %[[C_MASK]] : tensor<32x32xi32> -CHECK: %[[CAST_HI:.*]] = tt.bitcast %[[EXTRACT_HI]] : tensor<32x32xi32> -> tensor<32x32xf32> -CHECK: %[[TRUNC_TO_BF16:.*]] = arith.truncf %[[CAST_HI]] : tensor<32x32xf32> to tensor<32x32xbf16> -CHECK-COUNT-2: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> -CHECK: %[[ABS:.*]] = math.absf -CHECK: %[[CMP:.*]] = arith.cmpf ogt, %[[INFINITY]], %[[ABS]] : tensor<32x32xf32> -CHECK: %[[SELECT:.*]] = arith.select %[[CMP]], %{{.*}}, %[[C0]] : tensor<32x32xi1>, tensor<32x32xf32> -CHECK: %[[DOT_LAST:.*]] = tt.dot %{{.*}}, %{{.*}}, %[[SELECT]] : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> -CHECK: %[[ACC:.*]] = arith.addf %[[DOT_LAST]], %[[C0]] : tensor<32x32xf32> - )")); - - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-5, - /*arel=*/1e-5})); -} - -TEST_F(Triton3xBF16GemmTestWithFlag, NoEmit3xBF16GemmWhenBothInputsAreNotF32) { - const char* kHloText = R"( -HloModule t - -triton_dot { - p0 = f16[5,7] parameter(0) - p1 = f16[7,33] parameter(1) - ROOT dot = f16[5,33] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -} - -ENTRY e { - p0 = f16[5,7]{1,0} parameter(0) - p1 = f16[7,33]{1,0} parameter(1) - ROOT _ = f16[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: - {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1,"num_ctas":1}}} -} -)"; - TF_ASSERT_OK( - CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( -CHECK: tt.dot -CHECK-SAME: tensor<32x32xf16> * tensor<32x32xf16> -> tensor<32x32xf32> -CHECK-NOT: tt.dot - )")); -} - -TEST_F(Triton3xBF16GemmTest, Triton3xBF16GemmWorksForLongContractingDimension) { - const char* kHloText = R"( -HloModule t - -triton_dot { - p0 = f32[5,2048] parameter(0) - p1 = f32[2048,33] parameter(1) - ROOT dot = f32[5,33] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0}, - algorithm=dot_bf16_bf16_f32_x3 -} - -ENTRY e { - p0 = f32[5,2048]{1,0} parameter(0) - p1 = f32[2048,33]{1,0} parameter(1) - ROOT _ = f32[5,33] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: - {"block_m":64,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":4, "num_ctas":1}}} -} -)"; - TF_ASSERT_OK( - CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( -CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<64x32xbf16> * tensor<32x32xbf16> -> tensor<64x32xf32> - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/1e-4, - /*arel=*/1e-4})); -} - -TEST_F(Triton3xBF16GemmTest, Triton3xBF16GemmCanHandleInfinity) { - const char* kHloText = R"( -HloModule t - -triton_dot { - p0 = f32[2,2] parameter(0) - p1 = f32[2,2] parameter(1) - ROOT dot = f32[2,2] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0}, - algorithm=dot_bf16_bf16_f32_x3 -} - -ENTRY e { - p0 = f32[2,2]{1, 0} parameter(0) - p1 = f32[2,2]{1, 0} parameter(1) - ROOT _ = f32[2,2] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: - {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} -} -)"; - TF_ASSERT_OK( - CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( -CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> - )")); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - GetOptimizedModule(kHloText)); - std::vector arguments(2); - arguments[0] = - LiteralUtil::CreateR2({{+std::numeric_limits::infinity(), - +std::numeric_limits::infinity()}, - {+std::numeric_limits::infinity(), - +std::numeric_limits::infinity()}}); - arguments[1] = LiteralUtil::CreateR2({{1.0f, 1.0f}, {1.0f, 1.0f}}); - std::vector argument_ptrs; - absl::c_transform( - arguments, std::back_inserter(argument_ptrs), - [](const Literal& literal) { return const_cast(&literal); }); - - EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), argument_ptrs, - ErrorSpec{/*aabs=*/0, /*arel=*/0})); -} - -TEST_F(Triton3xBF16GemmTest, Triton3xBF16GemmCanHandleNaN) { - const char* kHloText = R"( -HloModule t - -triton_dot { - p0 = f32[2,2] parameter(0) - p1 = f32[2,2] parameter(1) - ROOT dot = f32[2,2] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0}, - algorithm=dot_bf16_bf16_f32_x3 -} - -ENTRY e { - p0 = f32[2,2]{1, 0} parameter(0) - p1 = f32[2,2]{1, 0} parameter(1) - ROOT _ = f32[2,2] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: - {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} -} -)"; - TF_ASSERT_OK( - CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( -CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> - )")); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - GetOptimizedModule(kHloText)); - std::vector arguments(2); - arguments[0] = - LiteralUtil::CreateR2({{std::numeric_limits::quiet_NaN(), - std::numeric_limits::quiet_NaN()}, - {std::numeric_limits::quiet_NaN(), - std::numeric_limits::quiet_NaN()}}); - arguments[1] = LiteralUtil::CreateR2( - {{1.0f, +std::numeric_limits::infinity()}, - {1.0f, +std::numeric_limits::infinity()}}); - std::vector argument_ptrs; - absl::c_transform( - arguments, std::back_inserter(argument_ptrs), - [](const Literal& literal) { return const_cast(&literal); }); - - EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), argument_ptrs, - ErrorSpec{/*aabs=*/0, /*arel=*/0})); -} - -TEST_F(Triton3xBF16GemmTest, Triton3xBF16GemmWorksForInputsWithLargeExponent) { - const char* kHloText = R"( -HloModule t - -triton_dot { - p0 = f32[2,2] parameter(0) - p1 = f32[2,2] parameter(1) - ROOT dot = f32[2,2] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0}, - algorithm=dot_bf16_bf16_f32_x3 -} - -ENTRY e { - p0 = f32[2,2]{1, 0} parameter(0) - p1 = f32[2,2]{1, 0} parameter(1) - ROOT _ = f32[2,2] fusion(p0, p1), kind=kCustom, calls=triton_dot, - backend_config={"fusion_backend_config": {kind: "__triton_gemm", - triton_gemm_config: - {"block_m":32,"block_n":32,"block_k":32,"split_k":1,"num_stages":1,"num_warps":1, "num_ctas":1}}} -} -)"; - TF_ASSERT_OK( - CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_dot", R"( -CHECK-COUNT-3: %{{.*}} = tt.dot %{{.*}}, %{{.*}}, %{{.*}} : tensor<32x32xbf16> * tensor<32x32xbf16> -> tensor<32x32xf32> - )")); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - GetOptimizedModule(kHloText)); - std::vector arguments(2); - constexpr float kLargeExponentFloat = 0x1.0103p72f; - arguments[0] = LiteralUtil::CreateR2( - {{kLargeExponentFloat, 1.0f}, {-kLargeExponentFloat, 1.0f}}); - arguments[1] = LiteralUtil::CreateR2( - {{kLargeExponentFloat, 1.0f}, {-kLargeExponentFloat, 1.0f}}); - std::vector argument_ptrs; - absl::c_transform( - arguments, std::back_inserter(argument_ptrs), - [](const Literal& literal) { return const_cast(&literal); }); - - EXPECT_TRUE( - RunAndCompareNoHloPasses(std::move(module), argument_ptrs, - ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-4})); -} - -TEST_F(Triton3xBF16GemmTest, Emit3xBF16GemmEndToEnd) { - if (std::holds_alternative(GpuComputeComp())) { - GTEST_SKIP() << "ALG_DOT_BF16_BF16_F32_X3 not supported on ROCM."; - } - const char* kHloText = R"( -HloModule t - -ENTRY e { - p0 = f32[5,32] parameter(0) - p1 = f32[32,7] parameter(1) - ROOT dot = f32[5,7] dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0}, - algorithm=dot_bf16_bf16_f32_x3 -} -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, - ParseAndReturnVerifiedModule(kHloText)); - CompileAndOptionallyVerifyPtx(std::move(verified_module), - R"( -CHECK: mma.sync.aligned.{{.*}}.row.col.f32.bf16.bf16.f32 -CHECK-NOT: mma.sync.aligned.{{.*}}.row.col.f32.tf32.tf32.f32 -)"); - EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-5, - /*arel=*/1e-5})); -} - // This test could be modified to allow TF32 once this bug is fixed. // TODO(b/320659359) Allow TF32 for 8-bit or less types with F32. TEST_F(TritonTest, NoTF32For8BitOrLessWithF32) { @@ -5299,7 +4674,7 @@ TEST_F(TritonGemmTest, TestNoAutotuner) { if (std::holds_alternative(GpuComputeComp())) { GTEST_SKIP() << "Autotuner is always in pipeline on Cuda."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( ENTRY e { p0 = f16[30,30] parameter(0) p1 = s8[30,30] parameter(1) diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc index 3cd5da3117c673..32590c9c07d4ba 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" #include "llvm/IR/LLVMContext.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/PassManager.h" @@ -30,6 +31,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/primitive_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h" #include "xla/service/gpu/fusions/triton/triton_test_utils.h" @@ -40,6 +42,7 @@ limitations under the License. #include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -59,7 +62,7 @@ class TritonEmitterTest : public GpuCodegenTest { }; TEST_F(TritonEmitterTest, ReductionOnMinormostAxisIsEmittedCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t maximum { Arg_0 = f32[] parameter(0) @@ -87,7 +90,7 @@ CHECK: "tt.reduce"(%[[LOAD:.*]]) <{axis = 1 : i32}> } TEST_F(TritonEmitterTest, ReductionOnMajormostAxisIsEmittedCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t maximum { Arg_0 = f32[] parameter(0) @@ -115,7 +118,7 @@ CHECK: "tt.reduce"(%[[LOAD:.*]]) <{axis = 0 : i32}> } TEST_F(TritonEmitterTest, ReductionOnIntermediateAxisIsEmittedCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t maximum { Arg_0 = f32[] parameter(0) @@ -145,7 +148,7 @@ CHECK: "tt.reduce"(%[[SELECT:.*]]) <{axis = 2 : i32}> } TEST_F(TritonEmitterTest, TestReductionWithTileSizeLargerThanSourceTensor) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t maximum { Arg_0 = f32[] parameter(0) @@ -186,7 +189,7 @@ CHECK: }) // TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be // moved to deviceless test file. TEST_F(TritonEmitterTest, TestGenericEmitterWithSoftMaxSingleParameter) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t add { Arg_0 = f32[] parameter(0) @@ -213,7 +216,7 @@ ENTRY main { "num_warps":"1"}}}})"; TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_softmax_computation", R"( -CHECK: #indexing_map = #xla_gpu.indexing_map<(d0) -> (d0 * 127), domain: d0 in [0, 124], is_simplified: true> +CHECK: #indexing_map = #xla_gpu.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 124]"> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 CHECK-DAG: %[[C125:.*]] = arith.constant 125 : i64 @@ -247,7 +250,7 @@ CHECK: } // TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be // moved to deviceless test file. TEST_F(TritonEmitterTest, TestGenericEmitterWithMultipleParameters) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t add { @@ -278,7 +281,7 @@ ENTRY main { "num_warps":"1"}}}})"; TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_softmax_computation", R"( -CHECK: #indexing_map = #xla_gpu.indexing_map<(d0) -> (d0 * 127), domain: d0 in [0, 124], is_simplified: true> +CHECK: #indexing_map = #xla_gpu.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 124]"> CHECK: tt.func @triton_fn( CHECK-SAME: %[[P0:[A-Za-z0-9_]*]]: !tt.ptr CHECK-SAME: %[[P1:[A-Za-z0-9_]*]]: !tt.ptr @@ -312,7 +315,7 @@ CHECK-DAG: tt.store {{.*}} : !tt.ptr> } TEST_F(TritonEmitterTest, TestGenericEmitterWithMultipleTiledDimensions) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t max { @@ -349,9 +352,9 @@ ENTRY main { TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_softmax_computation", R"( -CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 125), domain: d0 in [0, 1249], is_simplified: true> -CHECK: #[[MAP1:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 mod 125), domain: d0 in [0, 1249], is_simplified: true> -CHECK: #[[MAP2:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 127), domain: d0 in [0, 1249], is_simplified: true> +CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 125), domain: d0 in [0, 1249]"> +CHECK: #[[MAP1:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 mod 125), domain: d0 in [0, 1249]"> +CHECK: #[[MAP2:.*]] = #xla_gpu.indexing_map<"(d0) -> (d0 * 127), domain: d0 in [0, 1249]"> CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P3:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[ZERO:.*]] = arith.constant 0 : i32 CHECK-DAG: %[[ZERO_64:.*]] = arith.constant 0 : i64 @@ -395,7 +398,7 @@ CHECK-NEXT: tt.store {{.*}} : !tt.ptr> TEST_F( TritonEmitterTest, DiamondWithAdditionalDiamondParameterBroadcastedAlongReductionDimProducesAccurateResults) { // NOLINT(whitespace/line_length) - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule h1 max_computation { @@ -433,7 +436,7 @@ TEST_F(TritonEmitterTest, NestedReducerFusionGetsCodegenedCorrectly) { GTEST_SKIP() << "BF16 not supported."; } - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule softmax fused_convert { @@ -472,7 +475,7 @@ ENTRY main { TEST_F( TritonEmitterTest, DiamondWithAdditionalDiamondParameterBroadcastedAlongBatchDimProducesAccurateResults) { // NOLINT(whitespace/line_length) - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule h1 max_computation { @@ -505,7 +508,7 @@ ENTRY main { TEST_F( TritonEmitterTest, DiamondWithAdditionalSplatDiamondScalarParameterProducesAccurateResults) { // NOLINT(whitespace/line_length) - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule h1 max_computation { @@ -542,8 +545,8 @@ ENTRY main { TF_ASSERT_OK(CreateTritonIrAndFileCheck(this, kHloText, "triton_softmax_computation", R"( -// CHECK: #xla_gpu.indexing_map<(d0) -> (d0 floordiv 32), domain: d0 in [0, 2047], is_simplified: true> -// CHECK: #xla_gpu.indexing_map<(d0) -> (d0 mod 32), domain: d0 in [0, 2047], is_simplified: true> +// CHECK: #xla_gpu.indexing_map<"(d0) -> (d0 floordiv 32), domain: d0 in [0, 2047]"> +// CHECK: #xla_gpu.indexing_map<"(d0) -> (d0 mod 32), domain: d0 in [0, 2047]"> // CHECK-LABEL: tt.func @triton_fn( // CHECK-SAME: %[[P0:[A-Za-z0-9_]*]]: !tt.ptr // CHECK-SAME: %[[P1:[A-Za-z0-9_]*]]: !tt.ptr @@ -560,7 +563,7 @@ ENTRY main { TEST_F( TritonEmitterTest, DiamondWithAdditionalBroadcastOf1DParameterAlongNonReductionDimensionsProducesAccurateResults) { // NOLINT(whitespace/line_length) - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule h1 max_computation { @@ -594,7 +597,7 @@ ENTRY main { // TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be // moved to deviceless test file. TEST_F(TritonEmitterTest, EmitterFailsIfComputeCapabilityIsBelowAmpere) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_computation { p0 = f32[10,10] parameter(0) p1 = f32[10,10] parameter(1) @@ -694,7 +697,7 @@ ENTRY entry_computation { // TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should b // moved to deviceless test file. TEST_F(TritonEmitterTest, TestGenericEmitterReductionFusion) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t add { Arg_0 = f32[] parameter(0) @@ -736,7 +739,7 @@ CHECK: tt.store {{.*}} : !tt.ptr> TEST_F(TritonEmitterTest, TestGenericEmitterWithReductonAndMultidimensionalTile) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule t max { Arg_0 = f32[] parameter(0) @@ -764,7 +767,7 @@ ENTRY main { } TEST_F(TritonEmitterTest, TestSoftMaxWithTileElementsNotAllContiguous) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m region { @@ -793,7 +796,7 @@ ENTRY entry_computation { } TEST_F(TritonEmitterTest, TestSliceWithTileThatNeedsMasking) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m fused_computation { @@ -812,7 +815,7 @@ ENTRY entry_computation { } TEST_F(TritonEmitterTest, TestSliceWithTileElementsNotAllContiguous) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m fused_computation { @@ -831,7 +834,7 @@ ENTRY entry_computation { } TEST_F(TritonEmitterTest, TestSliceWithTileElementsNotAllContiguousUnaligned) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m fused_computation { @@ -854,7 +857,7 @@ ENTRY entry_computation { } TEST_F(TritonEmitterTest, ReshapeIntoBroadcastIsLoweredCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_computation { param_0 = f32[128,256]{1,0} parameter(0) reshape = f32[64,2,256]{2,1,0} reshape(param_0) @@ -880,7 +883,7 @@ CHECK: tt.reshape } TEST_F(TritonEmitterTest, BitcastIntoBroadcastIsLoweredCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_computation { param_0 = f32[128,256]{1,0} parameter(0) bitcast = f32[64,2,256]{2,1,0} bitcast(param_0) @@ -906,7 +909,7 @@ CHECK: tt.reshape } TEST_F(TritonEmitterTest, BitcastNormalizedLayoutsIsLoweredCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_computation { p = s8[5,42] parameter(0) ROOT bitcast = s8[5,6,7] bitcast(p) @@ -934,7 +937,7 @@ CHECK: tt.store } TEST_F(TritonEmitterTest, BitcastNonNormalizedInputLayoutIsLoweredCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_computation { p = s8[42,5]{0,1} parameter(0) ROOT bitcast = s8[5,6,7] bitcast(p) @@ -962,7 +965,7 @@ CHECK: tt.store } TEST_F(TritonEmitterTest, BitcastNonNormalizedOutputLayoutIsLoweredCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_computation { p = s8[5,42] parameter(0) ROOT bitcast = s8[5,6,7]{1,2,0} bitcast(p) @@ -991,7 +994,7 @@ CHECK: tt.store TEST_F(TritonEmitterTest, BitcastNonNormalizedInputOutputLayoutIsLoweredCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_computation { p = s8[42,5]{0,1} parameter(0) ROOT bitcast = s8[5,6,7]{1,2,0} bitcast(p) @@ -1019,7 +1022,7 @@ CHECK: tt.store } TEST_F(TritonEmitterTest, BitcastTransposeOnlyIsLoweredCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_computation { p = s8[42,5]{0,1} parameter(0) ROOT bitcast = s8[5,42] bitcast(p) @@ -1048,17 +1051,14 @@ CHECK: tt.store // TODO(b/353484968): move this test to a deviceless file. TEST_F(TritonEmitterTest, GenericEmitterLowersBroadcastFrom0dOperandCorrectly) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( triton_computation { - // TODO(b/348565795): make this a 0D scalar directly once this is known to be - // supported. - param_0 = f32[1] parameter(0) - reshape = f32[] reshape(param_0) - ROOT broadcast = f32[127,125]{1,0} broadcast(reshape), dimensions={} + param_0 = f32[] parameter(0) + ROOT broadcast = f32[127,125]{1,0} broadcast(param_0), dimensions={} } ENTRY main { - param_0 = f32[1] parameter(0) + param_0 = f32[] parameter(0) ROOT triton_fusion = f32[127,125]{1,0} fusion(param_0), kind=kCustom, calls=triton_computation, backend_config={"fusion_backend_config": @@ -1068,15 +1068,14 @@ ENTRY main { })"; TF_EXPECT_OK( CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( -CHECK: tt.broadcast -CHECK-SAME: tensor<1x1xf32> -> tensor<8x4xf32> +CHECK: tt.splat {{.*}} f32 -> tensor<8x4xf32> )")); } TEST_F(TritonEmitterTest, PredOutputIsStoredCorrectly) { // The 'pred' element type in XLA is unpacked and uses i8 for storage. This // is the only sub-byte type to have this behavior. - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m triton_computation { @@ -1109,7 +1108,7 @@ CHECK: tt.store {{.*}} %[[CASTED_OUT]] TEST_F(TritonEmitterTest, PredInputIsLoadedCorrectly) { // The 'pred' element type in XLA is unpacked and uses i8 for storage. This // is the only sub-byte type to have this behavior. - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m triton_computation { @@ -1145,7 +1144,7 @@ CHECK: arith.trunci %[[I8_PARAM]] : tensor<4xi8> to tensor<4xi1> } TEST_F(TritonEmitterTest, Transpose3D) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m triton_computation { @@ -1175,7 +1174,7 @@ CHECK: tt.trans %[[TILE]] {order = array} : tensor<8x4x1xf32> // TODO(b/353484968): Delete this test once we have constraints to only // propagate tile sizes that are a power of 2. TEST_F(TritonEmitterTest, Transpose3D_TileFullDimThatIsNotPowerOf2) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m triton_computation { @@ -1196,6 +1195,214 @@ ENTRY main { RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0})); } +TEST_F(TritonEmitterTest, StridedIota4DIsCodegeneratedCorrectly) { + constexpr std::string_view kHloText = R"( +triton_computation { + iota = f32[3,4,1000,5] iota(), iota_dimension=2 + ROOT slice = f32[3,4,182,5] slice(iota), slice={[0:3], [0:4], [91:1000:5], [0:5]} +} + +ENTRY main { + ROOT triton_fusion = f32[3,4,182,5] fusion(), + kind=kCustom, calls=triton_computation, + backend_config={"fusion_backend_config": + {"kind":"__triton", + "block_level_fusion_config":{"output_tile_sizes":["1","2","64","8"], + "num_warps":"1"}}} +})"; + + TF_EXPECT_OK( + CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( +CHECK: %[[RANGE:.*]] = tt.make_range {{.*}} : tensor<64xi32> +CHECK: arith.muli{{.*}} %[[RANGE]] +)")); + + EXPECT_TRUE( + RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +class IotaEmitterParametrizedTest + : public TritonEmitterTest, + public ::testing::WithParamInterface {}; + +TEST_P(IotaEmitterParametrizedTest, Iota4DIsCodegeneratedCorrectly) { + auto data_type = GetParam(); + const std::string kHloText = + absl::Substitute(R"( +triton_computation { + ROOT iota = $0[3,4,1000,5] iota(), iota_dimension=2 +} + +ENTRY main { + ROOT triton_fusion = $0[3,4,1000,5] fusion(), + kind=kCustom, calls=triton_computation, + backend_config={"fusion_backend_config": + {"kind":"__triton", + "block_level_fusion_config":{"output_tile_sizes":["1","2","64","8"], + "num_warps":"1"}}} +})", + primitive_util::LowercasePrimitiveTypeName(data_type)); + + TF_EXPECT_OK( + CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( +CHECK: %[[RANGE:.*]] = tt.make_range {{.*}} : tensor<64xi32> +CHECK: arith.addi{{.*}} %[[RANGE]] + // Omit the data type below, since it depends on a test parameter + // and is not abbreviated the same as in HLO. +CHECK: tt.broadcast {{.*}} -> tensor<1x2x64x8x +)")); + + EXPECT_TRUE( + RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +INSTANTIATE_TEST_SUITE_P(IotaEmitterParametrizedTestSuite, + IotaEmitterParametrizedTest, + ::testing::ValuesIn({S8, S16, S32, S64, BF16, F16, F32, + F64})); + +TEST_F(TritonEmitterTest, ReducePrecisionIsLoweredCorrectly) { + const std::string kHloText = R"( +triton_computation { + p = f32[5,7] parameter(0) + ROOT rp = f32[5,7] reduce-precision(p), exponent_bits=2, mantissa_bits=2 +} + +ENTRY entry_computation { + p = f32[5,7] parameter(0) + ROOT fusion = f32[5,7] fusion(p), kind=kCustom, calls=triton_computation, + backend_config={ + "fusion_backend_config":{ "kind":"__triton", "block_level_fusion_config":{ + "output_tile_sizes":["4","4"], "num_warps":"1"}} + } +})"; + TF_EXPECT_OK( + CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( +CHECK: tt.load +)")); + + EXPECT_TRUE( + RunAndCompareNoHloPasses(kHloText, ErrorSpec{/*aabs=*/0, /*arel=*/0})); +} + +TEST_F(TritonEmitterTest, Chaining0DElementwiseScalarsIsSupported) { + const std::string kHloText = R"( +triton_computation { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + exp0 = f32[] exponential(p0) + exp1 = f32[] exponential(p1) + neg0 = f32[] negate(exp0) + neg1 = f32[] negate(exp1) + add = f32[] add(neg0, neg1) + mul = f32[] multiply(add, add) + div = f32[] divide(mul, p0) + conv = bf16[] convert(div) + const = bf16[] constant(0.5) + ROOT sub = bf16[] subtract(conv, const) +} + +ENTRY entry_computation { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT fusion = bf16[] fusion(p0, p1), kind=kCustom, calls=triton_computation, + backend_config={ + "fusion_backend_config":{ "kind":"__triton", "block_level_fusion_config":{ + "output_tile_sizes":[], "num_warps":"1"}} + } +})"; + TF_EXPECT_OK( + CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( +CHECK: tt.load {{.*}} !tt.ptr +CHECK: tt.extern_elementwise {{.*}} (f32) -> f32 +CHECK: arith.subf {{.*}} f32 +CHECK: tt.load {{.*}} !tt.ptr +CHECK: tt.extern_elementwise {{.*}} (f32) -> f32 +CHECK: arith.subf {{.*}} f32 +CHECK: arith.addf {{.*}} f32 +CHECK: arith.mulf {{.*}} f32 +CHECK: arith.divf {{.*}} f32 +CHECK: arith.truncf {{.*}} f32 to bf16 +CHECK: arith.subf {{.*}} bf16 +CHECK: tt.store {{.*}} !tt.ptr +)")); + + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/6e-1, /*arel=*/6e-1})); +} + +TEST_F(TritonEmitterTest, Multiple0DBroadcastsAreSupported) { + const std::string kHloText = R"( +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +triton_computation { + p = f32[] parameter(0) + exp = f32[] exponential(p) + b1 = f32[10] broadcast(exp), dimensions={} + b2 = f32[10,10] broadcast(exp), dimensions={} + b3 = f32[10,10] broadcast(b1), dimensions={0} + add = f32[10,10] add(b2,b3) + c = f32[] constant(0) + reduce1 = f32[10] reduce(add, c), dimensions={0}, to_apply=add + ROOT reduce2 = f32[] reduce(reduce1, c), dimensions={0}, to_apply=add +} + +ENTRY entry_computation { + p = f32[] parameter(0) + ROOT fusion = f32[] fusion(p), kind=kCustom, calls=triton_computation, + backend_config={ + "fusion_backend_config":{ "kind":"__triton", "block_level_fusion_config":{ + "output_tile_sizes":[], "num_warps":"1"}} + } +})"; + TF_EXPECT_OK( + CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( +CHECK: tt.load +CHECK: tt.splat +CHECK: arith.addf +CHECK: tt.reduce +CHECK: tt.store {{.*}} !tt.ptr +)")); + + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/6e-1, /*arel=*/6e-1})); +} + +TEST_F(TritonEmitterTest, ReshapeTo0DIsSupported) { + const std::string kHloText = R"( +triton_computation { + p0 = f32[1,1,1,1] parameter(0) + p1 = f32[1] parameter(1) + reshape1 = f32[] reshape(p0) + reshape2 = f32[] reshape(p1) + ROOT add = f32[] add(reshape1, reshape2) +} + +ENTRY entry_computation { + p0 = f32[1,1,1,1] parameter(0) + p1 = f32[1] parameter(1) + ROOT fusion = f32[] fusion(p0, p1), kind=kCustom, calls=triton_computation, + backend_config={ + "fusion_backend_config":{ "kind":"__triton", "block_level_fusion_config":{ + "output_tile_sizes":[], "num_warps":"1"}} + } +})"; + TF_EXPECT_OK( + CreateTritonIrAndFileCheck(this, kHloText, "triton_computation", R"( +CHECK: tt.reshape +CHECK: tt.reduce{{.*}}axis = 0 +CHECK-NOT: tt.reshape +CHECK: tt.reduce{{.*}}axis = 0 +CHECK: tt.store {{.*}} !tt.ptr +)")); + + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloText, ErrorSpec{0, 0})); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc index a50776c6c54e9f..3ffa43bca72bc9 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_large_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include "absl/log/check.h" @@ -28,7 +29,21 @@ namespace { class TritonGemmTest : public GpuCodegenTest { public: - DebugOptions GetDebugOptionsForTest() override { + se::GpuComputeCapability GetGpuComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(); + } + + void SetUp() override { + if (std::holds_alternative( + GetGpuComputeCapability())) { + GTEST_SKIP() << "Not supported on ROCm until Triton is re-enabled."; + } + } + + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_gpu_cublas_fallback(false); return debug_options; @@ -73,7 +88,7 @@ ENTRY e { } TEST_F(TritonGemmTest, LargeNonContractingProductWorks) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -97,7 +112,7 @@ ENTRY e { } TEST_F(TritonGemmTest, LargeBatchWorks) { - const std::string kHloText = R"( + constexpr std::string_view kHloText = R"( HloModule m ENTRY e { @@ -121,7 +136,7 @@ ENTRY e { class TritonSoftmaxTest : public GpuCodegenTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); debug_options .set_xla_gpu_experimental_enable_triton_softmax_priority_fusion(true); @@ -130,7 +145,7 @@ class TritonSoftmaxTest : public GpuCodegenTest { }; TEST_F(TritonSoftmaxTest, - CanFuseAndEmitDiamondWithInputNumberOfElementsLargerThanInt32Max) { + CanEmitDiamondWithInputNumberOfElementsLargerThanInt32Max) { const std::string hlo_text = R"( HloModule softmax @@ -140,26 +155,26 @@ max_computation { ROOT maximum = f16[] maximum(arg_0, arg_1) } -ENTRY main { +triton_fusion_computation { param_0 = f16[65538,32768]{1,0} parameter(0) constant_neg_inf = f16[] constant(-inf) reduce = f16[65538]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation broadcast = f16[65538,32768]{1,0} broadcast(reduce), dimensions={0} ROOT subtract = f16[65538,32768]{1,0} subtract(param_0, broadcast) } -)"; - MatchOptimizedHlo(hlo_text, R"( -; CHECK: ENTRY -; CHECK: %[[P0:.*]] = f16[65538,32768]{1,0} parameter(0) -; CHECK: ROOT -; CHECK-SAME: fusion(%[[P0]]) -; CHECK-SAME: kind=kCustom -; CHECK-SAME: __triton -)"); +ENTRY main { + param_0 = f16[65538,32768]{1,0} parameter(0) + ROOT fusion = f16[65538,32768]{1,0} fusion(param_0), kind=kCustom, + calls=triton_fusion_computation, + backend_config={"fusion_backend_config": + {"kind":"__triton", + "block_level_fusion_config":{"output_tile_sizes":["1","32768"], + "num_warps":"1"}}} +})"; // Checking that this does not crash should be enough. - EXPECT_TRUE(Run(hlo_text)); + EXPECT_TRUE(Run(hlo_text, /*run_hlo_passes=*/false)); } } // namespace diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.cc new file mode 100644 index 00000000000000..113579a0ee16c6 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.cc @@ -0,0 +1,2268 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/Support/MathExtras.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LLVM.h" +#include "xla/comparison_util.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/literal.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h" +#include "xla/mlir_hlo/mhlo/transforms/transformation_helpers.h" +#include "xla/primitive_util.h" +#include "xla/service/algorithm_util.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/matmul_indexing_utils.h" +#include "xla/service/gpu/model/tiled_hlo_computation.h" +#include "xla/service/gpu/target_util.h" +#include "xla/service/gpu/triton_fusion_analysis.h" +#include "xla/service/gpu/triton_tiling_propagation.h" +#include "xla/service/llvm_ir/llvm_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/tensor_float_32_utils.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace xla::gpu { + +namespace ma = ::mlir::arith; +namespace mm = ::mlir::math; +namespace mt = ::mlir::triton; + +using ::llvm::SmallVector; +using ::mlir::ArrayRef; +using ::mlir::ImplicitLocOpBuilder; +using ::mlir::ShapedType; +using ::mlir::Type; +using ::mlir::Value; +using ::mlir::ValueRange; + +namespace { + +absl::StatusOr TritonType(mlir::OpBuilder b, PrimitiveType t) { + switch (t) { + case F64: + return b.getF64Type(); + case F32: + return b.getF32Type(); + case F16: + return b.getF16Type(); + case BF16: + return b.getBF16Type(); + case S64: + return b.getI64Type(); + case S32: + return b.getI32Type(); + case S16: + return b.getI16Type(); + case PRED: + return b.getI1Type(); + case S8: + return b.getI8Type(); + case S4: // The unpacking to i8 is supported by the emitter. + // We pass the s4 tensor as i8 tensor with the minor dimension having 2x + // less elements and unpack in the inner loop of the triton kernel. + return b.getI8Type(); + case F8E5M2: + return b.getFloat8E5M2Type(); + case F8E4M3FN: + return b.getFloat8E4M3FNType(); + default: + return absl::UnimplementedError( + absl::StrCat("This type is not supported yet: ", + primitive_util::LowercasePrimitiveTypeName(t))); + } +} + +Type StorageType(mlir::OpBuilder b, Type t) { + if (t.isInteger(1)) { + return b.getI8Type(); + } + return t; +} + +// Create a scalar constant. +template +mlir::arith::ConstantOp CreateConst(mlir::ImplicitLocOpBuilder b, + mlir::Type type, T value) { + if (mlir::isa(type)) { + return b.create(b.getIntegerAttr(type, value)); + } + if (mlir::isa(type)) { + return b.create( + b.getFloatAttr(type, static_cast(value))); + } + LOG(FATAL) << "Constant type not supported: " << llvm_ir::DumpToString(type); +} + +// Create a tensor constant. +template +mlir::arith::ConstantOp CreateConst(mlir::ImplicitLocOpBuilder& b, + mlir::Type type, T value, + llvm::ArrayRef shape) { + auto tensor_type = mlir::RankedTensorType::get(shape, type); + if (auto int_type = mlir::dyn_cast(type)) { + return b.create(mlir::DenseElementsAttr::get( + tensor_type, mlir::APInt(int_type.getIntOrFloatBitWidth(), value))); + } + if (auto float_type = mlir::dyn_cast(type)) { + return b.create(mlir::DenseElementsAttr::get( + tensor_type, b.getFloatAttr(type, static_cast(value)))); + } + LOG(FATAL) << "Constant type not supported: " << llvm_ir::DumpToString(type); +} + +Value ZerosLike(ImplicitLocOpBuilder& b, Value x) { + if (auto src_shaped_ty = mlir::dyn_cast(x.getType())) { + Type src_ty = src_shaped_ty.getElementType(); + return CreateConst(b, src_ty, 0, src_shaped_ty.getShape()); + } + return CreateConst(b, x.getType(), 0); +} + +Value OnesLike(ImplicitLocOpBuilder& b, Value x) { + if (auto src_shaped_ty = mlir::dyn_cast(x.getType())) { + Type src_ty = src_shaped_ty.getElementType(); + return CreateConst(b, src_ty, 1, src_shaped_ty.getShape()); + } + return CreateConst(b, x.getType(), 1); +} + +bool IsFp8Type(Type t) { + return t.isFloat8E5M2() || t.isFloat8E4M3FN() || t.isFloat8E5M2FNUZ() || + t.isFloat8E4M3FNUZ() || t.isFloat8E4M3B11FNUZ(); +} + +Value Cast(ImplicitLocOpBuilder& b, Value value, Type dst_element_ty) { + Type src_ty = value.getType(); + Type src_element_ty = src_ty; + Type fp32_ty = b.getF32Type(); + Type dst_ty = dst_element_ty; + if (auto src_shaped_ty = mlir::dyn_cast(src_ty)) { + src_element_ty = src_shaped_ty.getElementType(); + dst_ty = src_shaped_ty.clone(src_shaped_ty.getShape(), dst_element_ty); + fp32_ty = src_shaped_ty.clone(src_shaped_ty.getShape(), b.getF32Type()); + } + if (src_ty == dst_ty) { + return value; + } + + // All operations on bf16 are done through f32. + if (src_element_ty.isBF16()) { + return Cast(b, b.create(fp32_ty, value), dst_element_ty); + } + if (dst_element_ty.isBF16()) { + // S8 -> BF16 is directly supported and doesn't need to go through f32. + if (!src_element_ty.isInteger(8)) { + return b.create(dst_ty, Cast(b, value, b.getF32Type())); + } + } + + // float => float + auto src_fp_element_ty = mlir::dyn_cast(src_element_ty); + auto dst_fp_element_ty = mlir::dyn_cast(dst_element_ty); + if (src_fp_element_ty && dst_fp_element_ty) { + // F8 <-> FP16, BF16, FP32, FP64 need to be handled via Triton's tt.fp_to_fp + // because LLVM doesn't support casts from/to FP8. + // TODO(b/266862493): Add end-to-end test once FP8 support lands in XLA as + // we can't test the code below without patching the feature. + if (IsFp8Type(src_element_ty)) { + return b.create(dst_ty, value); + } + if (IsFp8Type(dst_element_ty)) { + return b.create( + dst_ty, value, + mt::RoundingModeAttr::get(b.getContext(), mt::RoundingMode::RTNE)); + } + + if (src_fp_element_ty.getFPMantissaWidth() > + dst_fp_element_ty.getFPMantissaWidth()) { + return b.create(dst_ty, value); + } else { + return b.create(dst_ty, value); + } + } + // int => int + if (mlir::isa(src_element_ty) && + mlir::isa(dst_element_ty)) { + if (src_element_ty.getIntOrFloatBitWidth() < + dst_element_ty.getIntOrFloatBitWidth()) { + if (src_element_ty.isInteger(1)) { + return b.create(dst_ty, value); + } + return b.create(dst_ty, value); + } + return b.create(dst_ty, value); + } + // int => float + if (mlir::isa(src_element_ty) && dst_fp_element_ty) { + // TODO(b/266862493): Support unsigned integer types. + if (src_element_ty.isInteger(1)) { + return b.create(dst_ty, value); + } + return b.create(dst_ty, value); + } + // float => int + if (src_fp_element_ty && mlir::isa(dst_element_ty)) { + if (dst_element_ty.isInteger(1)) { + return b.create(ma::CmpFPredicate::UNE, value, + ZerosLike(b, value)); + } + // TODO(b/266862493): Support unsigned integer types. + // The current logic handles signed integer types only. Additional handling + // is needed for unsigned integer types. + auto cst_int = [&](int64_t x) { + if (auto src_shaped_ty = mlir::dyn_cast(src_ty)) { + return CreateConst(b, dst_element_ty, x, src_shaped_ty.getShape()); + } else { + return CreateConst(b, dst_element_ty, x); + } + }; + auto cst_float = [&](int64_t x) { + if (auto src_shaped_ty = mlir::dyn_cast(src_ty)) { + return CreateConst(b, src_fp_element_ty, x, src_shaped_ty.getShape()); + } else { + return CreateConst(b, src_fp_element_ty, x); + } + }; + auto fptosi = b.create(dst_ty, value); + int64_t min = llvm::minIntN(dst_element_ty.getIntOrFloatBitWidth()); + int64_t max = llvm::maxIntN(dst_element_ty.getIntOrFloatBitWidth()); + + // value <= static_cast(INT_MIN) ? INT_MIN : ... + auto clamped = b.create( + b.create(mlir::arith::CmpFPredicate::OLE, value, + cst_float(min)), + cst_int(min), fptosi); + // value >= static_cast(INT_MAX) ? INT_MAX : ... + clamped = b.create( + b.create(mlir::arith::CmpFPredicate::OGE, value, + cst_float(max)), + cst_int(max), clamped); + // isnan(value) ? 0 : ... + return b.create( + b.create(mlir::arith::CmpFPredicate::UNO, value, + value), + cst_int(0), clamped); + } + + LOG(FATAL) << "Type conversion not supported: " + << llvm_ir::DumpToString(src_element_ty) << " -> " + << llvm_ir::DumpToString(dst_element_ty); +} + +Value Subtract(ImplicitLocOpBuilder& b, ValueRange values) { + if (mlir::isa(mlir::getElementTypeOrSelf(values[0]))) { + return b.create(values[0], values[1]); + } else { + return b.create(values[0], values[1]); + } +} + +Value Compare(ImplicitLocOpBuilder& b, ValueRange values, + mlir::mhlo::ComparisonDirection direction) { + const Type type = mlir::getElementTypeOrSelf(values[0]); + if (mlir::isa(type)) { + return b.create( + mlir::mhlo::impl::getCmpPredicate( + direction, + /*isSigned=*/!type.isInteger(1)) + .value(), + values[0], values[1]); + } + return b.create( + mlir::mhlo::impl::getCmpPredicate(direction, + /*isSigned=*/true) + .value(), + values[0], values[1]); +} + +Value Maximum(ImplicitLocOpBuilder& b, const se::DeviceDescription& device_info, + ValueRange values) { + if (mlir::isa(mlir::getElementTypeOrSelf(values[0]))) { + return b.create(values); + } + // logic: isNaN(lhs) || (!isNan(rhs) && lhs >= rhs) ? lhs : rhs + // See also: IEEE Std 754-2008 5.11. + // + // This also works, but we wanted to make it similar to minimum. + // logic: isNaN(lhs) || lhs >= rhs ? lhs : rhs + Value lhs_is_nan = + Compare(b, {values[0], values[0]}, mlir::mhlo::ComparisonDirection::NE); + Value rhs_is_not_nan = + Compare(b, {values[1], values[1]}, mlir::mhlo::ComparisonDirection::EQ); + Value lhs_is_ge = Compare(b, values, mlir::mhlo::ComparisonDirection::GE); + return b.create( + b.create(lhs_is_nan, + b.create(rhs_is_not_nan, lhs_is_ge)), + values[0], values[1]); +} + +Value Minimum(ImplicitLocOpBuilder& b, const se::DeviceDescription& device_info, + ValueRange values) { + if (mlir::isa(mlir::getElementTypeOrSelf(values[0]))) { + return b.create(values); + } + // logic: isNaN(lhs) || (!isNan(rhs) && lhs <= rhs) ? lhs : rhs + // See also: IEEE Std 754-2008 5.11. + // + // This should also work, but the tests show that it doesn't work for + // minimum(x, NaN): + // logic: isNaN(lhs) || lhs <= rhs ? lhs : rhs + Value lhs_is_nan = + Compare(b, {values[0], values[0]}, mlir::mhlo::ComparisonDirection::NE); + Value rhs_is_not_nan = + Compare(b, {values[1], values[1]}, mlir::mhlo::ComparisonDirection::EQ); + Value lhs_is_le = Compare(b, values, mlir::mhlo::ComparisonDirection::LE); + return b.create( + b.create(lhs_is_nan, + b.create(rhs_is_not_nan, lhs_is_le)), + values[0], values[1]); +} + +Value Splat(ImplicitLocOpBuilder& b, Value value, ArrayRef shape) { + auto type = mlir::RankedTensorType::get(shape, value.getType()); + return b.create(type, value); +} + +absl::StatusOr EmitElementwise(ImplicitLocOpBuilder& b, + absl::string_view libdevice_path, + const se::DeviceDescription& device_info, + const HloInstruction& hlo, + ValueRange inputs) { + if (mlir::getElementTypeOrSelf(inputs[0]).isF32() || + mlir::getElementTypeOrSelf(inputs[0]).isF64()) { + auto dev_fn_id = GetTargetDeviceFunctionID(hlo.opcode()); + if (dev_fn_id.ok()) { + llvm::Triple triple("nvptx64-unknown-unknown"); + if (std::holds_alternative( + device_info.gpu_compute_capability())) { + triple.setTriple("amdgcn-unknown-unknown"); + } + return b.create( + inputs[0].getType(), inputs, "libdevice", libdevice_path, + ObtainDeviceFunctionName(dev_fn_id.value(), + hlo.shape().element_type(), triple), + /*pure=*/true); + } + } + const bool is_integer = + mlir::isa(mlir::getElementTypeOrSelf(inputs[0])); + + switch (hlo.opcode()) { + case HloOpcode::kCopy: + // Dimension transformations are taken care of separately. + return inputs[0]; + case HloOpcode::kAbs: + if (is_integer) { + return b.create(inputs[0]); + } + return b.create(inputs[0]); + case HloOpcode::kCeil: + return b.create(inputs[0]); + case HloOpcode::kFloor: + return b.create(inputs[0]); + case HloOpcode::kNot: + return b.create(inputs[0], OnesLike(b, inputs[0])); + case HloOpcode::kNegate: + // NegFOp is not supported by Triton. + return Subtract(b, {ZerosLike(b, inputs[0]), inputs[0]}); + case HloOpcode::kConvert: { + TF_ASSIGN_OR_RETURN(Type dst_ty, + TritonType(b, hlo.shape().element_type())); + return Cast(b, inputs[0], dst_ty); + } + case HloOpcode::kAdd: + if (is_integer) { + return b.create(inputs[0], inputs[1]); + } + return b.create(inputs[0], inputs[1]); + case HloOpcode::kSubtract: + return Subtract(b, inputs); + case HloOpcode::kMultiply: + if (is_integer) { + return b.create(inputs[0], inputs[1]); + } + return b.create(inputs[0], inputs[1]); + case HloOpcode::kMaximum: + return Maximum(b, device_info, inputs); + case HloOpcode::kMinimum: + return Minimum(b, device_info, inputs); + case HloOpcode::kClamp: + return Maximum( + b, device_info, + {Minimum(b, device_info, {inputs[1], inputs[2]}), inputs[0]}); + case HloOpcode::kAnd: + return b.create(inputs[0], inputs[1]); + case HloOpcode::kOr: + return b.create(inputs[0], inputs[1]); + case HloOpcode::kXor: + return b.create(inputs[0], inputs[1]); + case HloOpcode::kDivide: + if (is_integer) { + // Unsigned not supported yet. + return b.create(inputs[0], inputs[1]); + } + return b.create(inputs[0], inputs[1]); + case HloOpcode::kCompare: + return Compare( + b, inputs, + mlir::mhlo::symbolizeComparisonDirection( + ComparisonDirectionToString(hlo.comparison_direction())) + .value()); + case HloOpcode::kSelect: + return b.create( + Compare(b, {inputs[0], ZerosLike(b, inputs[0])}, + mlir::mhlo::ComparisonDirection::NE), + inputs[1], inputs[2]); + case HloOpcode::kReducePrecision: + return mlir::mhlo::reducePrecision( + b.getLoc(), inputs[0], hlo.exponent_bits(), hlo.mantissa_bits(), &b); + default: + return absl::InvalidArgumentError( + absl::StrCat("Unsupported elementwise operation ", hlo.ToString())); + } +} + +absl::StatusOr EmitConstant(ImplicitLocOpBuilder& b, + const HloInstruction& constant) { + CHECK_EQ(constant.opcode(), HloOpcode::kConstant); + CHECK(ShapeUtil::IsEffectiveScalar(constant.shape())); + + TF_ASSIGN_OR_RETURN(Type ty, TritonType(b, constant.shape().element_type())); + + if (constant.shape().element_type() == U64) { + TF_ASSIGN_OR_RETURN(Literal converted, constant.literal().Convert(U64)); + return CreateConst(b, ty, converted.GetFirstElement()); + } + + if (constant.shape().IsInteger()) { + TF_ASSIGN_OR_RETURN(Literal converted, constant.literal().Convert(S64)); + return CreateConst(b, ty, converted.GetFirstElement()); + } + + TF_ASSIGN_OR_RETURN(Literal converted, constant.literal().Convert(F64)); + return CreateConst(b, ty, converted.GetFirstElement()); +} + +// Emit sequence of operations for unpacking 2xi4 -> i8. +absl::StatusOr EmitUnpackInt4(ImplicitLocOpBuilder& b, + const HloInstruction* hlo, + int64_t unpack_dim_idx, Value& value) { + VLOG(6) << "EmitUnpackInt4: " << hlo->ToString(); + auto input_type = mlir::cast(value.getType()); + if (input_type.getShape().size() != 2) { + return absl::InvalidArgumentError( + absl::StrCat("UnpackInt4 works only for 2d inputs: ", hlo->ToString())); + } + // We use shifts instead the mask because we need to keep the sign bit. + Value shift4 = + Splat(b, CreateConst(b, b.getI8Type(), 4), input_type.getShape()); + Value lo = b.create(b.create(value, shift4), shift4); + Value hi = b.create(value, shift4); + Value result = b.create(hi, lo); + if (unpack_dim_idx == 0) { + result = b.create(result, b.getDenseI32ArrayAttr({0, 2, 1})); + } + SmallVector result_shape(input_type.getShape()); + result_shape[unpack_dim_idx] *= 2; + auto type = mlir::RankedTensorType::get(result_shape, b.getI8Type()); + return b.create(type, result, /*allow_reorder=*/false); +} + +using TensorValue = mlir::TypedValue; + +Value Broadcast(ImplicitLocOpBuilder& b, TensorValue value, + ArrayRef shape) { + return b.create(value.getType().clone(shape), value); +} + +Value Range(ImplicitLocOpBuilder& b, int32_t limit) { + auto type = mlir::RankedTensorType::get(limit, b.getI32Type()); + return b.create(type, 0, limit); +} + +Value AddPtr(ImplicitLocOpBuilder& b, Value ptr, Value offset) { + return b.create(ptr.getType(), ptr, offset); +} + +Value EmitParameterLoad(ImplicitLocOpBuilder& b, Value pointer, + ArrayRef boundary_checks) { + // 0-D MakeTensorPtrOp + // + // Triton tries to access the -1 element of a vector and segfaults when + // lowering the code to load a 0-D tensor to LLVM. The workaround is to load a + // regular pointer + a splat. + if (auto make_tensor_ptr = pointer.getDefiningOp()) { + if (make_tensor_ptr.getOffsets().empty()) { + return Splat(b, + b.create(make_tensor_ptr.getBase(), + mt::CacheModifier::NONE, + mt::EvictionPolicy::NORMAL, + /*isVolatile=*/false), + {}); + } + } + + // Any other tensor pointer. + if (mt::isTensorPointerType(pointer.getType())) { + std::optional padding; + if (!boundary_checks.empty()) { + padding = mt::PaddingOption::PAD_ZERO; + } + return b.create(pointer, boundary_checks, padding, + mt::CacheModifier::NONE, + mt::EvictionPolicy::NORMAL, + /*isVolatile=*/false); + } + + // Non-tensor pointer. + // + // TODO(b/343013366): Remove this after we delete the legacy SoftMax code. + // It's the only place where this code-path is used. + return Splat(b, + b.create(pointer, mt::CacheModifier::NONE, + mt::EvictionPolicy::NORMAL, + /*isVolatile=*/false), + {}); +} + +// Grouped properties of tiled dimensions used to generate block pointers. +struct DimProperties { + DimProperties(int64_t index, Value pid, int block_size, int split_value) + : index(index), + pid(pid), + block_size(block_size), + split_value(split_value) {} + + // Logical index of the dimension at the tiling-defining operation. + int64_t index; + // Block program ID corresponding to this dimension. + Value pid; + // Elements of the dimension to process per block program. + int block_size; + // Size of the major part of the dimension if it's split into two parts. + int split_value; +}; + +struct Side { + explicit Side(TritonFusionAnalysis::Scope scope, + std::vector tiled_dims = {}, + std::optional batch_dim_idx = std::nullopt) + : scope(scope), tiled_dims(tiled_dims), batch_dim_idx(batch_dim_idx) {} + TritonFusionAnalysis::Scope scope; + std::vector tiled_dims; + std::optional batch_dim_idx; + int64_t unpack_dim_idx = 0; +}; + +absl::StatusOr EmitBroadcast(ImplicitLocOpBuilder& b, + const TritonFusionAnalysis* analysis, + const Side& side, + const HloInstruction& broadcast, + Value input) { + TF_RET_CHECK(analysis != nullptr); + std::vector out_shape; + for (const DimProperties& dim : side.tiled_dims) { + const TensorIterationSpec::DimIterationSpec* spec = + analysis->IterSpec(side.scope, &broadcast, dim.index); + if (spec != nullptr && spec->at(0).stride > 0) { + out_shape.push_back(dim.block_size); + } + } + auto tensor_input = mlir::dyn_cast(input); + if (!tensor_input) { + // Input is scalar. + return Splat(b, input, out_shape); + } + if (tensor_input.getType().getRank() == out_shape.size()) { + // No dimensions to broadcast. + return input; + } + // Add broadcasted dimensions one by one. + Value expanded_input = tensor_input; + int dim_idx = 0; + for (const DimProperties& dim : side.tiled_dims) { + if (auto* spec = analysis->IterSpec(side.scope, &broadcast, dim.index); + spec != nullptr && spec->at(0).stride > 0) { + if (analysis->IterSpec(side.scope, broadcast.operand(0), dim.index) == + nullptr) { + // Broadcasted dimension. + expanded_input = b.create(expanded_input, dim_idx); + } + ++dim_idx; + } + } + return Broadcast(b, mlir::cast(expanded_input), out_shape); +} + +// Emit sequence of instructions using compatible tiling ordered producers +// before consumers. +absl::StatusOr EmitScope( + ImplicitLocOpBuilder& b, absl::string_view libdevice_path, + const se::DeviceDescription& device_info, + const TritonFusionAnalysis* analysis, const Side& side, + absl::Span instructions, + absl::flat_hash_map& values) { + for (const HloInstruction* hlo : instructions) { + Value result; + if (hlo->opcode() == HloOpcode::kConvert && + hlo->operand(0)->shape().element_type() == S4) { + TF_ASSIGN_OR_RETURN( + auto unpacked, + EmitUnpackInt4(b, hlo, side.unpack_dim_idx, values[hlo->operand(0)])); + std::vector operands({unpacked}); + TF_ASSIGN_OR_RETURN(result, EmitElementwise(b, libdevice_path, + device_info, *hlo, operands)); + } else if (hlo->opcode() == HloOpcode::kConcatenate || + hlo->opcode() == HloOpcode::kDynamicSlice) { + // Parameter loads and their concatenations are handled outside EmitScope. + TF_RET_CHECK(values.contains(hlo)) << hlo->ToString(); + continue; + } else if (hlo->opcode() == HloOpcode::kParameter) { + if (hlo->users()[0]->opcode() == HloOpcode::kConcatenate || + hlo->users()[0]->opcode() == HloOpcode::kDynamicSlice) { + continue; + } + TF_RET_CHECK(values.contains(hlo)) << hlo->ToString(); + continue; + } else if (hlo->opcode() == HloOpcode::kConstant) { + TF_ASSIGN_OR_RETURN(Value constant, EmitConstant(b, *hlo)); + // Splat makes it a tensor to avoid type mismatches. + result = Splat(b, constant, {}); + } else if (hlo->opcode() == HloOpcode::kBroadcast) { + TF_ASSIGN_OR_RETURN(result, EmitBroadcast(b, analysis, side, *hlo, + values[hlo->operand(0)])); + } else if (HloInstruction::IsOpElementwise(hlo->opcode())) { + std::vector operands; + operands.reserve(hlo->operands().size()); + for (const HloInstruction* operand : hlo->operands()) { + operands.push_back(values[operand]); + } + TF_ASSIGN_OR_RETURN(result, EmitElementwise(b, libdevice_path, + device_info, *hlo, operands)); + } else if (hlo->opcode() == HloOpcode::kTuple) { + TF_RET_CHECK(hlo->IsRoot()) << hlo->ToString(); + } else if (hlo->opcode() == HloOpcode::kBitcast || + hlo->opcode() == HloOpcode::kTranspose || + hlo->opcode() == HloOpcode::kSlice || + hlo->opcode() == HloOpcode::kReshape || + hlo->opcode() == HloOpcode::kPad) { + // All these are currently supported only as operations on indices + // which are pushed to loads and stores. No operations on tiles are + // performed here. + result = values[hlo->operand(0)]; + } else { + return absl::InvalidArgumentError( + absl::StrCat("Unsupported operation ", hlo->ToString())); + } + TF_RET_CHECK(values.insert({hlo, result}).second) << hlo->ToString(); + VLOG(8) << "Emitted " << hlo->ToString(HloPrintOptions::ShortParsable()); + } + return values[instructions.back()]; +} + +const TensorIterationSpec::DimIterationSpec* GetLhsNoncontractingSplitSpec( + const TritonFusionAnalysis& analysis, int64_t lhs_noncontracting_dim_idx) { + const TensorIterationSpec::DimIterationSpec* result = nullptr; + for (const HloInstruction* lhs_param : + analysis.ScopeParameters(TritonFusionAnalysis::Scope::LHS)) { + const TensorIterationSpec::DimIterationSpec* spec = + analysis.IterSpec(TritonFusionAnalysis::Scope::LHS, lhs_param, + lhs_noncontracting_dim_idx); + if (spec != nullptr && spec->size() > 1) { + CHECK_EQ(spec->size(), 2); + if (result != nullptr) { + CHECK_EQ(result->at(0).count, spec->at(0).count); + CHECK_EQ(result->at(1).count, spec->at(1).count); + } + result = spec; + } + } + return result; +} + +// Structure for parameters relating to the MatMul shape and dimension indices. +// +// Variable naming: lhs [m, k] x rhs [k, n] -> out [m, n]. +// +// The logical output dimensions are always ordered as: +// split-K, batch, non-contracting LHS, non-contracting RHS, +// where split-K and batch are optional. +struct MatMulDims { + static absl::StatusOr Create( + const TritonGemmConfig& config, const HloDotInstruction& dot, + const TritonFusionAnalysis& analysis); + + std::optional out_split_k_dim_idx = std::nullopt; + + std::optional lhs_batch_dim_idx = std::nullopt; + std::optional rhs_batch_dim_idx = std::nullopt; + std::optional out_batch_dim_idx = std::nullopt; + + // The LHS non-contracting can be split into two. + std::optional lhs_noncontracting_split = std::nullopt; + + int lhs_contracting_dim_idx; + int lhs_noncontracting_dim_idx; + int rhs_contracting_dim_idx; + int rhs_noncontracting_dim_idx; + // The index of the LHS noncontracting dim in the output. + int out_lhs_noncontracting_dim_idx; + // The index of the RHS noncontracting dim in the output. + int out_rhs_noncontracting_dim_idx; + + int64_t m; + int64_t n; + int64_t k; + + private: + MatMulDims() = default; +}; + +// Structure for parameters relating to the MatMul launch grid. +struct MatMulLaunchConfig { + explicit MatMulLaunchConfig(const TritonGemmConfig& config, + const HloDotInstruction& dot, + const MatMulDims& dims); + + int64_t grid_m; + int64_t grid_n; + LaunchDimensions launch_dims; + mt::ProgramIDDim batch_program_id_dim; + mt::ProgramIDDim noncontracting_program_id_dim; +}; + +/*static*/ absl::StatusOr MatMulDims::Create( + const TritonGemmConfig& config, const HloDotInstruction& dot, + const TritonFusionAnalysis& analysis) { + MatMulDims matmul_dims; + if (config.split_k > 1) { + // split-k is always the first logical dimension. + matmul_dims.out_split_k_dim_idx = 0; + } + + int64_t num_split_k_dims = config.split_k > 1 ? 1 : 0; + const auto& dims = dot.dot_dimension_numbers(); + matmul_dims.lhs_contracting_dim_idx = dims.lhs_contracting_dimensions(0); + matmul_dims.lhs_noncontracting_dim_idx = + GetNonContractingDims(dot.operand(0)->shape(), + dims.lhs_batch_dimensions(), + dims.lhs_contracting_dimensions()) + .value()[0]; + matmul_dims.rhs_contracting_dim_idx = dims.rhs_contracting_dimensions(0); + matmul_dims.rhs_noncontracting_dim_idx = + GetNonContractingDims(dot.operand(1)->shape(), + dims.rhs_batch_dimensions(), + dims.rhs_contracting_dimensions()) + .value()[0]; + + if (dims.lhs_batch_dimensions_size() > num_split_k_dims) { + matmul_dims.lhs_batch_dim_idx = *dims.lhs_batch_dimensions().rbegin(); + matmul_dims.rhs_batch_dim_idx = *dims.rhs_batch_dimensions().rbegin(); + // The batch dimension (if present) comes after the split-k dimension (if + // present, otherwise it's the first dimension). + matmul_dims.out_batch_dim_idx = num_split_k_dims; + } + + // Logical output dimensions are always ordered as: + // split-K, batch, non-contracting LHS, non-contracting RHS, + // where split-K and batch are optional. + matmul_dims.out_rhs_noncontracting_dim_idx = dot.shape().rank() - 1; + matmul_dims.out_lhs_noncontracting_dim_idx = dot.shape().rank() - 2; + + auto* root = dot.parent()->root_instruction(); + auto iter_spec = + analysis.IterSpec(TritonFusionAnalysis::Scope::OUTPUT, root, + matmul_dims.out_rhs_noncontracting_dim_idx); + TF_RET_CHECK(iter_spec != nullptr); + matmul_dims.n = iter_spec->at(0).count; + // Contracting dimension length. + if (config.split_k > 1 && + dot.operand(1)->operand(0)->opcode() == HloOpcode::kPad) { + // Unpadded LHS shape: [..., k, ...] + // Padded LHS shape: [..., padded_k, ...] + // Bitcasted LHS shape: [..., split_k, padded_k / split_k, ...] + TF_RET_CHECK(dot.operand(1)->opcode() == HloOpcode::kBitcast); + const Shape& unpadded_rhs_shape = + dot.operand(1)->operand(0)->operand(0)->shape(); + matmul_dims.k = + unpadded_rhs_shape.dimensions(dims.rhs_contracting_dimensions(0) - 1); + } else { + matmul_dims.k = + dot.operand(1)->shape().dimensions(dims.rhs_contracting_dimensions(0)) * + config.split_k; + } + + auto* lhs_noncontracting_split_spec = GetLhsNoncontractingSplitSpec( + analysis, matmul_dims.lhs_noncontracting_dim_idx); + if (lhs_noncontracting_split_spec != nullptr) { + // Just the fastest-varying part of it if the dimension is split. + matmul_dims.m = lhs_noncontracting_split_spec->at(0).count; + matmul_dims.lhs_noncontracting_split = + lhs_noncontracting_split_spec->at(1).count; + } else { + matmul_dims.m = analysis + .IterSpec(TritonFusionAnalysis::Scope::OUTPUT, root, + matmul_dims.out_lhs_noncontracting_dim_idx) + ->at(0) + .count; + } + + // For now split non-contracting and batch are not supported + // simultaneously because they are implemented via same mechanism. + TF_RET_CHECK(!(matmul_dims.out_batch_dim_idx.has_value() && + matmul_dims.lhs_noncontracting_split.has_value())); + + TF_RET_CHECK(matmul_dims.m >= 1); + TF_RET_CHECK(matmul_dims.n >= 1); + return std::move(matmul_dims); +} + +MatMulLaunchConfig::MatMulLaunchConfig(const TritonGemmConfig& config, + const HloDotInstruction& dot, + const MatMulDims& dims) + : grid_m((dims.m + config.block_m - 1) / config.block_m), + grid_n((dims.n + config.block_n - 1) / config.block_n) { + int64_t batch_size = dims.lhs_noncontracting_split.value_or( + dims.out_batch_dim_idx.has_value() + ? dot.shape().dimensions(*dims.out_batch_dim_idx) + : 1); + // X block size is 32-bit, Y and Z are 16-bit. Use X for large dimensions. + constexpr int64_t kBlockCountYZLimit = 65536; + + // In the imaginary situation where both batch size and grid_m * grid_n + // are over 65535 we have to give up. Given the minimal m, n block sizes of 16 + // this requires at least 256 GB of output. + CHECK_LT(batch_size * grid_m * grid_n, + kBlockCountYZLimit * kBlockCountYZLimit); + + const bool large_batch = batch_size >= kBlockCountYZLimit; + if (large_batch) { + batch_program_id_dim = mt::ProgramIDDim::X; + noncontracting_program_id_dim = mt::ProgramIDDim::Y; + launch_dims = LaunchDimensions( + se::BlockDim(batch_size, grid_m * grid_n, config.split_k), + se::ThreadDim(config.num_warps * WarpSize(), 1, 1)); + } else { + batch_program_id_dim = mt::ProgramIDDim::Y; + noncontracting_program_id_dim = mt::ProgramIDDim::X; + launch_dims = LaunchDimensions( + se::BlockDim(grid_m * grid_n, batch_size, config.split_k), + se::ThreadDim(config.num_warps * WarpSize(), 1, 1)); + } +} + +absl::Status ValidateMatMulConfig(const TritonGemmConfig& config, + const HloDotInstruction& dot) { + TF_RET_CHECK(config.split_k >= 1); + TF_RET_CHECK(config.block_m >= 16); + TF_RET_CHECK(config.block_k >= 16); + TF_RET_CHECK(config.block_n >= 16); + + const auto& dims = dot.dot_dimension_numbers(); + int num_batch_dims = + dims.lhs_batch_dimensions_size() - (config.split_k > 1 ? 1 : 0); + TF_RET_CHECK(num_batch_dims <= 1); + if (config.split_k > 1) { + // Split-K dimension has to be the first batch one and have an index + // just before the contracting one. + const int lhs_split_k_dim_idx = dims.lhs_contracting_dimensions(0) - 1; + const int rhs_split_k_dim_idx = dims.rhs_contracting_dimensions(0) - 1; + // Size of this dimension has to match the split_k value. + TF_RET_CHECK(dims.lhs_batch_dimensions(0) == lhs_split_k_dim_idx); + TF_RET_CHECK(dims.rhs_batch_dimensions(0) == rhs_split_k_dim_idx); + TF_RET_CHECK(config.split_k == + dot.operand(0)->shape().dimensions(lhs_split_k_dim_idx)); + TF_RET_CHECK(config.split_k == + dot.operand(1)->shape().dimensions(rhs_split_k_dim_idx)); + } + + // Rely on dot decomposer: there is just one contracting and one + // non-contracting dimension on each side + batch ones optionally. + TF_RET_CHECK(dims.lhs_contracting_dimensions_size() == 1); + TF_RET_CHECK(dims.rhs_contracting_dimensions_size() == 1); + + TF_RET_CHECK(dot.operand(0)->shape().rank() == + 2 + (config.split_k > 1 ? 1 : 0) + num_batch_dims); + return absl::OkStatus(); +} + +// if (index < limits[0]) { +// return choices[0]; +// } else if (index < limits[1]) { +// return choices[1]; +// } else if (...) { +// ... +// } else { +// return choices.back(); +// } +absl::StatusOr EmitMultiSelect(ImplicitLocOpBuilder b, Value index, + ValueRange limits, ValueRange choices) { + TF_RET_CHECK(choices.size() - 1 == limits.size()); + Value result = choices[0]; + for (int i = 0; i < choices.size() - 1; ++i) { + result = b.create( + b.create(ma::CmpIPredicate::slt, index, limits[i]), result, + choices[i + 1]); + } + return result; +} + +absl::Status UncompilableMatmul(absl::string_view explanation) { + absl::Status s = absl::CancelledError(explanation); + s.SetPayload(kUncompilableFusion, absl::Cord(explanation)); + return s; +} + +bool IsFp8Matmul(const HloDotInstruction* dot_instr) { + return absl::c_all_of(std::array{0, 1}, [&](int idx) { + return primitive_util::IsF8Type( + dot_instr->operand(idx)->shape().element_type()); + }); +} + +class MatMulEmitterHelper { + public: + MatMulEmitterHelper(absl::string_view libdevice_path, + const se::DeviceDescription& device_info, + const HloDotInstruction* dot_instr, + ImplicitLocOpBuilder& b, Type index_ty, MatMulDims dims, + const MatMulLaunchConfig& launch_config, + const TritonFusionAnalysis& analysis) + : b_(b), + libdevice_path_(libdevice_path), + device_info_(device_info), + dot_instr_(dot_instr), + index_ty_(index_ty), + analysis_(analysis), + dims_(dims), + launch_config_(launch_config) {} + + // TODO(b/266862493): Accumulator can be integer too. + // Otherwise only f64 x f64 -> f64 uses f64 accumulator. + absl::StatusOr GetDotAccumulatorType() { + const PrecisionConfig::Algorithm algorithm = + dot_instr_->precision_config().algorithm(); + + if (algorithm == PrecisionConfig::ALG_UNSET) { + TF_ASSIGN_OR_RETURN(Type dot_output_ty, + TritonType(b_, dot_instr_->shape().element_type())); + // The code below assumes that lhs and rhs have the same type. However + // it's not always the case with fp8 matmuls, e.g. e4m3×e5m2 is supported + // at the hardware level. NVidia GPU currently only supports f32 + // accumulator for such matmuls. + if (IsFp8Matmul(dot_instr_)) { + return b_.getF32Type(); + } + + // Data type of dot() immediate inputs. + TF_ASSIGN_OR_RETURN( + const Type lhs_ty, + TritonType(b_, dot_instr_->operand(0)->shape().element_type())); + TF_ASSIGN_OR_RETURN( + const Type rhs_ty, + TritonType(b_, dot_instr_->operand(1)->shape().element_type())); + TF_RET_CHECK(lhs_ty == rhs_ty); + Type dot_input_ty = lhs_ty; + // TODO(b/266862493): Accumulator can be integer too. + // Otherwise only f64 x f64 -> f64 uses f64 accumulator. + return (dot_output_ty.isF64() && dot_input_ty.isF64()) ? b_.getF64Type() + : b_.getF32Type(); + } + + absl::StatusOr accum_type = + algorithm_util::GetDotAccumulatorType(algorithm); + CHECK(accum_type.ok()) << "Unexpected algorithm: " + << PrecisionConfig::Algorithm_Name(algorithm); + TF_ASSIGN_OR_RETURN(Type mlir_accum_type, + TritonType(b_, accum_type.value())); + if (auto float_accum_type = + mlir::dyn_cast(mlir_accum_type)) { + return float_accum_type; + } + LOG(FATAL) << "Only floating point accumulator types are supported for " + "now, but we got: " + << llvm_ir::DumpToString(mlir_accum_type); + } + + std::vector EpiloguePostOrderTransitiveOperands( + const HloInstruction* root) { + // Collect all instructions of the dot's output scope. + absl::flat_hash_set to_order; + { + std::queue to_add; + if (root != dot_instr_) { + to_add.push(root); + } + while (!to_add.empty()) { + const HloInstruction* current = to_add.front(); + for (const HloInstruction* operand : current->operands()) { + if (!to_order.contains(operand)) { + if (operand != dot_instr_) { + to_add.push(operand); + } + } + } + to_order.insert(current); + to_add.pop(); + } + } + // Order them producers before consumers. + std::vector to_emit; + for (const HloInstruction* hlo : + dot_instr_->parent()->MakeInstructionPostOrder()) { + if (to_order.contains(hlo)) { + to_emit.push_back(hlo); + } + } + return to_emit; + } + + Value MakeInput(const Side& side, int64_t operand_index, + absl::flat_hash_map& values) { + return *EmitScope( + b_, libdevice_path_, device_info_, &analysis_, side, + dot_instr_->parent()->MakeInstructionPostOrderFrom( + const_cast(*dot_instr_->operand(operand_index))), + values); + } + + int64_t GetNonContractingDimIdxForOperandScope( + TritonFusionAnalysis::Scope scope) { + if (scope == TritonFusionAnalysis::Scope::LHS) { + return dims_.lhs_noncontracting_dim_idx; + } else if (scope == TritonFusionAnalysis::Scope::RHS) { + return dims_.rhs_noncontracting_dim_idx; + } else { + CHECK(false) << "This shouldn't be called for the output scope."; + } + } + + // Return the batch stride of the HLO passed as a parameter. If the + // parameter HLO has no batch dimension, a zero stride is returned. + // Also sets offset_batch and updates has_batch_offset as a side effect. + absl::StatusOr GetBatchStride(const Side& side, + const HloInstruction* hlo_param, + int64_t& offset_batch, + bool& has_batch_offset) { + int64_t stride_batch = 0; + if (side.scope != TritonFusionAnalysis::Scope::RHS && + dims_.lhs_noncontracting_split) { + const TensorIterationSpec::DimIterationSpec* spec = + analysis_.IterSpec(side.scope, hlo_param, side.tiled_dims[0].index); + if (spec != nullptr) { + if (spec->size() > 1) { + // Support one specific kind of output transpose that splits the + // dimension originating from the split LHS non-contracting one. + stride_batch = spec->at(1).stride; + } else { + // Because the major part of the split is implemented using the + // batch logic stride_batch is populated here as the stride of + // the minor part times its size. + stride_batch = spec->at(0).stride * + (spec->at(0).count / *dims_.lhs_noncontracting_split); + } + TF_RET_CHECK(stride_batch != 0); + } + } else if (side.batch_dim_idx.has_value()) { + const TensorIterationSpec::DimIterationSpec* spec = + analysis_.IterSpec(side.scope, hlo_param, *side.batch_dim_idx); + if (spec != nullptr) { + stride_batch = spec->at(0).stride; + offset_batch = spec->at(0).slice_start; + TF_RET_CHECK(stride_batch != 0); + } + } + + has_batch_offset |= stride_batch != 0; + return Cst(stride_batch); + } + + // bases: The base pointers of each argument. + absl::StatusOr EmitTensorPointer( + const HloInstruction* hlo, const Side& side, ValueRange bases, + Value pid_k, std::vector& boundary_checks) { + // Parameters of MakeTensorPtrOp to be generated by this function. + Value base; + std::vector bounds; + std::vector strides; + std::vector strides_sizes; // We use it to detect the minor dim. + // Offsets from tensor origin, same for all thread blocks. + std::vector tensor_offsets; + std::vector block_dims; + std::vector dim_order; + + // Offsets for a given thread block, typically pid * block size. + // Used in a one-off AdvanceOp applied to the generated MakeTensorPtrOp. + std::vector block_offsets; + + // Concatenations of parameters are handled during generation of block + // pointers because of a limitation of implementation of block pointers + // in the Triton compiler: block pointers are not supported inside + // conditionals. + // Therefore instead of directly using a conditional to emit a concatenation + // and emitting its inputs inside the cases a single block pointer is + // emitted for all inputs, but all its properties (base, strides etc) get + // generated conditionally on the position of the current thread block + // within the concatenated dimension. + + // Index of concatenated dimension if present, -1 otherwise. + int concat_dim_idx; + // Offsets along the concatenated dimension at which operands change. + std::vector concat_boundaries; + // Block index along the concatenated dimension * block size. + Value concat_dim_pid_offset; + + if (hlo->opcode() == HloOpcode::kConcatenate) { + // For now only non-contracting dimension can be concatenated. + concat_dim_idx = (side.scope == TritonFusionAnalysis::Scope::LHS) + ? dims_.lhs_noncontracting_dim_idx + : dims_.rhs_noncontracting_dim_idx; + const DimProperties& properties = [&] { + for (const DimProperties& dim : side.tiled_dims) { + if (dim.index == concat_dim_idx) { + return dim; + } + } + LOG(FATAL) << "Missing dimension."; + }(); + TF_RET_CHECK(bases.size() == hlo->operand_count()); + + concat_boundaries.reserve(hlo->operand_count() - 1); + for (int i = 0; i < hlo->operand_count() - 1; ++i) { + const TensorIterationSpec::IterationSpecFragment& fragment = + analysis_.IterSpec(side.scope, hlo->operand(i), concat_dim_idx) + ->at(0); + if (fragment.sliced_count % properties.block_size != 0) { + return UncompilableMatmul( + "Operand is not divisible by the block size."); + } + concat_boundaries.push_back( + Cst32(-fragment.slice_start + fragment.sliced_count)); + } + + concat_dim_pid_offset = + b_.create(properties.pid, Cst32(properties.block_size)); + TF_ASSIGN_OR_RETURN(base, EmitMultiSelect(b_, concat_dim_pid_offset, + concat_boundaries, bases)); + } else { + concat_dim_idx = -1; + base = bases[0]; + } + + auto add_dim = [&](const DimProperties& properties) -> absl::Status { + if (analysis_.IterSpec(side.scope, hlo, properties.index) == nullptr) { + return absl::OkStatus(); + } + Value pid_offset = + (properties.pid == nullptr) + ? Cst32(0) + : b_.create(properties.pid, + Cst32(properties.block_size)); + std::vector inputs; + if (hlo->opcode() == HloOpcode::kConcatenate) { + inputs.insert(inputs.end(), hlo->operands().cbegin(), + hlo->operands().cend()); + } else { + inputs = {hlo}; + } + std::vector specs; + std::vector input_strides; + std::vector input_offsets; + std::vector input_bounds; + specs.reserve(inputs.size()); + input_strides.reserve(inputs.size()); + input_offsets.reserve(inputs.size()); + input_bounds.reserve(inputs.size()); + for (const HloInstruction* input : inputs) { + specs.push_back( + analysis_.IterSpec(side.scope, input, properties.index)); + const auto stride = specs.back()->at(0).stride; + strides_sizes.push_back(stride); + input_strides.push_back(Cst64(stride)); + input_offsets.push_back(b_.create( + pid_offset, Cst32(specs.back()->at(0).slice_start))); + input_bounds.push_back(Cst64(specs.back()->at(0).count)); + } + TF_ASSIGN_OR_RETURN(Value select_value, + EmitMultiSelect(b_, concat_dim_pid_offset, + concat_boundaries, input_strides)); + strides.push_back(select_value); + if (properties.index == concat_dim_idx) { + TF_ASSIGN_OR_RETURN( + select_value, + EmitMultiSelect(b_, pid_offset, concat_boundaries, input_offsets)); + block_offsets.push_back(select_value); + TF_ASSIGN_OR_RETURN( + select_value, + EmitMultiSelect(b_, pid_offset, concat_boundaries, input_bounds)); + bounds.push_back(select_value); + tensor_offsets.push_back(Cst32(specs.front()->at(0).slice_start)); + } else if (hlo->opcode() == HloOpcode::kDynamicSlice && + (side.scope == TritonFusionAnalysis::Scope::LHS || + side.scope == TritonFusionAnalysis::Scope::RHS) && + properties.index == + GetNonContractingDimIdxForOperandScope(side.scope)) { + // Here we compute the offset of where we should read the slice from. + // TODO(b/323255699): Add support for slices of the contracting dim. + // Dynamic slices are guaranteed to only be offset along the majormost + // dimension. + + // The only fragment of the non-contracting dim of the dot's input in + // the current scope: + TF_RET_CHECK(specs.back()->size() == 1); + const TensorIterationSpec::IterationSpecFragment + only_fragment_of_nc_dim = specs.back()->at(0); + // The majormost dim index in the dynamic slice's output. + const int majormost_dim = hlo->shape().layout().minor_to_major().back(); + + // dynamic slice operands are (input, start_index0, start_index1, ...) + // so the start index corresponding to the ith dimension is bases[i+1]. + Value majormost_dim_start_index_ptr_val = bases[majormost_dim + 1]; + Value majormost_dim_start_index_val = b_.create( + majormost_dim_start_index_ptr_val, mt::CacheModifier::NONE, + mt::EvictionPolicy::NORMAL, + /*isVolatile=*/false); + int64_t majormost_dim_start_index_upper_limit = + hlo->operand(0)->shape().dimensions(majormost_dim) - + hlo->dynamic_slice_sizes().at(majormost_dim); + // We don't want to cast S64 indices to S32, because that could result + // in an incorrect value. + if (majormost_dim_start_index_val.getType().isInteger() && + majormost_dim_start_index_val.getType().getIntOrFloatBitWidth() == + 64) { + return UncompilableMatmul( + "64 bit dynamic-slice indices are not supported yet."); + } + majormost_dim_start_index_val = + Cast(b_, majormost_dim_start_index_val, b_.getI32Type()); + majormost_dim_start_index_val = + b_.create(majormost_dim_start_index_val, Cst32(0)); + majormost_dim_start_index_val = b_.create( + majormost_dim_start_index_val, + Cst32(majormost_dim_start_index_upper_limit)); + + // How many "rows" (non-contracting dim values) are there in a slice of + // size 1? + int64_t rows_per_majormost_dim = 1; + for (int i = 0; i < hlo->shape().dimensions().size() - 1; ++i) { + rows_per_majormost_dim *= hlo->shape().dimensions_minor(i); + } + rows_per_majormost_dim = + rows_per_majormost_dim / only_fragment_of_nc_dim.stride; + Value rows_per_majormost_dim_val = Cst32(rows_per_majormost_dim); + + Value tensor_offset_val_i32 = b_.create( + majormost_dim_start_index_val, rows_per_majormost_dim_val); + tensor_offsets.push_back(tensor_offset_val_i32); + + // tt.make_tensor_ptr expects an i64 for shape and size, but expects + // i32 for offsets. We extend the offset to calculate the upper bound. + Value tensor_offset_val_i64 = + b_.create(i64_ty_, tensor_offset_val_i32); + Value sliced_count_val = Cst64(only_fragment_of_nc_dim.sliced_count); + Value upper_bound_val = + b_.create(tensor_offset_val_i64, sliced_count_val); + bounds.push_back(upper_bound_val); + + block_offsets.push_back(pid_offset); + } else { + tensor_offsets.push_back(Cst32(specs.front()->at(0).slice_start)); + block_offsets.push_back(pid_offset); + int64_t dim_bound = specs.front()->at(0).count; + if (side.scope == TritonFusionAnalysis::Scope::OUTPUT && + properties.index == dims_.out_lhs_noncontracting_dim_idx && + specs.front()->size() == 1 && + dims_.lhs_noncontracting_split.has_value()) { + // Dimension of the output produced by the non-contracting LHS one + // is logically split, major part is addressed using pid_batch. + dim_bound /= *dims_.lhs_noncontracting_split; + } + bounds.push_back(Cst64(dim_bound)); + if (dim_bound % (properties.block_size * properties.split_value) != 0) { + boundary_checks.push_back(bounds.size() - 1); + } + if (hlo->shape().element_type() == PrimitiveType::S4) { + // For s4 type we need to divide the minor dim bound by 2 because it + // is the packing dimension. But if the minor dim has length == 1 then + // the major dim stride is also 1 and it is the packing dimension. + if (strides_sizes.back() == 1) { + // For the odd bounds we need to add 1 in advance. + // Otherwise we will loose the last element. + bounds[bounds.size() - 1] = Cst64((dim_bound + 1) / 2); + } else { + int last_stride_index = strides.size() - 1; + strides[last_stride_index] = + b_.create(strides[last_stride_index], Cst64(2)); + } + } + } + block_dims.push_back(properties.block_size); + dim_order.emplace(dim_order.begin(), dim_order.size()); + return absl::OkStatus(); + }; + + for (const DimProperties& dim : side.tiled_dims) { + TF_RETURN_IF_ERROR(add_dim(dim)); + } + + int64_t offset_batch = 0; + bool has_batch_offset = false; + Value batch_stride; + + if (hlo->opcode() == HloOpcode::kConcatenate) { + std::vector batch_strides; + batch_strides.reserve(hlo->operands().size()); + for (const HloInstruction* operand : hlo->operands()) { + TF_ASSIGN_OR_RETURN( + Value op_stride, + GetBatchStride(side, operand, offset_batch, has_batch_offset)); + batch_strides.push_back(op_stride); + } + TF_ASSIGN_OR_RETURN(batch_stride, + EmitMultiSelect(b_, concat_dim_pid_offset, + concat_boundaries, batch_strides)); + } else { + TF_ASSIGN_OR_RETURN(batch_stride, GetBatchStride(side, hlo, offset_batch, + has_batch_offset)); + } + + // Avoid generating logic to compute batch offset if unnecessary. + if (has_batch_offset) { + Value pid_batch = + b_.create(launch_config_.batch_program_id_dim); + + Value pid_offset_batch = b_.create( + b_.create(Cst(offset_batch), ConvertScalar(pid_batch)), + batch_stride); + + if (hlo->shape().element_type() == PrimitiveType::S4) { + pid_offset_batch = b_.create(pid_offset_batch, Cst(2)); + } + base = AddPtr(b_, base, pid_offset_batch); + } + + if (dims_.out_split_k_dim_idx.has_value()) { + const TensorIterationSpec::DimIterationSpec* spec = analysis_.IterSpec( + TritonFusionAnalysis::Scope::OUTPUT, hlo, *dims_.out_split_k_dim_idx); + if (spec != nullptr) { + TF_RET_CHECK(pid_k != nullptr); + base = AddPtr(b_, base, + b_.create(ConvertScalar(pid_k), + Cst(spec->at(0).stride))); + } + } + + if (block_dims.empty()) { + // Load of a scalar. + return base; + } + auto tensor_ptr = mlir::cast( + b_.create(base, bounds, strides, tensor_offsets, + block_dims, dim_order) + .getResult()); + tensor_ptr = b_.create(tensor_ptr.getType(), tensor_ptr, + block_offsets); + return tensor_ptr; + } + + private: + // Extend int32 indexes to int64, if necessary. + Value ConvertScalar(Value value) { + if (index_ty_.getIntOrFloatBitWidth() == 64) { + return b_.create(index_ty_, value); + } + return value; + } + + Value Cst(int64_t v) { return CreateConst(b_, index_ty_, v); } + Value Cst32(int32_t v) { return CreateConst(b_, i32_ty_, v); } + Value Cst64(int64_t v) { return CreateConst(b_, i64_ty_, v); } + + ImplicitLocOpBuilder& b_; + absl::string_view libdevice_path_; + const se::DeviceDescription& device_info_; + const HloDotInstruction* dot_instr_; + Type index_ty_; + TritonFusionAnalysis analysis_; + MatMulDims dims_; + MatMulLaunchConfig launch_config_; + Type i32_ty_ = b_.getI32Type(); + Type i64_ty_ = b_.getI64Type(); +}; + +absl::StatusOr> GetArguments(mlir::triton::FuncOp fn, + const HloInstruction& input) { + if (input.opcode() == HloOpcode::kParameter) { + return {{fn.getArgument(input.parameter_number())}}; + } else if (input.opcode() == HloOpcode::kConcatenate || + input.opcode() == HloOpcode::kDynamicSlice) { + // As defined in GemmFusion, all inputs of concatenate and dynamic slice are + // parameters. + SmallVector result; + for (const HloInstruction* operand : input.operands()) { + TF_RET_CHECK(operand->opcode() == HloOpcode::kParameter); + result.push_back(fn.getArgument(operand->parameter_number())); + } + return result; + } + LOG(FATAL) << "Unexpected opcode: " << input.opcode(); +} + +// Concatenations can currently only be applied directly to parameters; +// all concatenated parameters share the same block pointer. This function +// returns all inputs of a kernel: concatenations of parameters and standalone +// parameters. +ConstHloInstructionSet ScopeInputs(const TritonFusionAnalysis& analysis, + const TritonFusionAnalysis::Scope scope) { + ConstHloInstructionSet result; + for (const HloInstruction* parameter : analysis.ScopeParameters(scope)) { + if (absl::c_any_of(parameter->users(), [](const HloInstruction* user) { + return user->opcode() == HloOpcode::kConcatenate || + user->opcode() == HloOpcode::kDynamicSlice; + })) { + // Concatenation is always the only user of its parameters by + // construction. + CHECK_EQ(parameter->users().size(), 1); + for (const HloInstruction* operand : parameter->users()[0]->operands()) { + // All operands of a concatenation have to be computation parameters. + CHECK_EQ(operand->opcode(), HloOpcode::kParameter); + } + result.insert(parameter->users()[0]); + } else { + result.insert(parameter); + } + } + return result; +} + +// Truncates |input| of F32 type to the number representable in Bf16 toward +// zero. +// It is used for Emit6xBfloat16MatMul. +Value TruncateToBF16TowardsZero(ImplicitLocOpBuilder& b, Value input) { + ShapedType input_type = mlir::dyn_cast(input.getType()); + Type input_type_as_i32 = input_type.clone(b.getI32Type()); + Value input_as_i32 = b.create(input_type_as_i32, input); + Value mask = CreateConst(b, b.getI32Type(), 0xFFFF0000u, + input_type.getShape()); + Value high_bits = b.create(input_type_as_i32, input_as_i32, mask); + + return b.create(input_type, high_bits); +} + +// Finds the middle 8 bits of |input|'s mantissa. +// It is used for Emit6xBfloat16MatMul. +Value SoftMiddleEight(ImplicitLocOpBuilder& b, Value input) { + Value high = TruncateToBF16TowardsZero(b, input); + return b.create(input, high); +} + +// Finds the low 8 bits of |input|'s mantissa. +// It is used for Emit6xBfloat16MatMul. +Value SoftLowEight(ImplicitLocOpBuilder& b, Value input) { + // Find the middle bits of the middle bits, and these are the low eight + // bits. + return SoftMiddleEight(b, SoftMiddleEight(b, input)); +} + +// Rounds |input| to BF16 type. +// It is used for Emit6xBfloat16MatMul. +Value RoundToBF16(ImplicitLocOpBuilder& b, Value input) { + return Cast(b, input, b.getBF16Type()); +} + +// Checks |input| is finite f32 (not Nan and not infinite). +// It is used for Emit6xBfloat16MatMul and Emit3xBfloat16MatMul. +Value CheckFiniteF32(ImplicitLocOpBuilder& b, Value input) { + Value positive_inf = CreateConst( + b, b.getF32Type(), std::numeric_limits::infinity(), + mlir::cast(input.getType()).getShape()); + Value abs_input = b.create(input); + return b.create(ma::CmpFPredicate::OGT, positive_inf, abs_input); +} + +// Leverages BF16 datatype for F32 matmul computation. It follows the guidance +// from https://arxiv.org/pdf/1904.06376.pdf. +absl::StatusOr Emit6xBfloat16MatMul(ImplicitLocOpBuilder& b, Value lhs, + Value rhs, Value acc) { + Type f32 = b.getF32Type(); + TF_RET_CHECK(mlir::cast(lhs.getType()).getElementType() == f32); + TF_RET_CHECK(mlir::cast(rhs.getType()).getElementType() == f32); + TF_RET_CHECK(mlir::cast(acc.getType()).getElementType() == f32); + + Value lhs_high = RoundToBF16(b, TruncateToBF16TowardsZero(b, lhs)); + Value lhs_middle = + RoundToBF16(b, TruncateToBF16TowardsZero(b, SoftMiddleEight(b, lhs))); + Value lhs_low = + RoundToBF16(b, TruncateToBF16TowardsZero(b, SoftLowEight(b, lhs))); + + Value rhs_high = RoundToBF16(b, TruncateToBF16TowardsZero(b, rhs)); + Value rhs_middle = + RoundToBF16(b, TruncateToBF16TowardsZero(b, SoftMiddleEight(b, rhs))); + Value rhs_low = + RoundToBF16(b, TruncateToBF16TowardsZero(b, SoftLowEight(b, rhs))); + + auto bf16_dot = [&](Value lhs_bf16, Value rhs_bf16, + Value accumulator) -> Value { + return b.create(lhs_bf16, rhs_bf16, accumulator, + /*inputPrecision=*/mt::InputPrecision::IEEE, + /*maxNumImpreciseAcc=*/0); + }; + + Value local_acc = ZerosLike(b, acc); + Value result = bf16_dot(lhs_middle, rhs_middle, local_acc); + result = bf16_dot(lhs_low, rhs_high, result); + result = bf16_dot(lhs_high, rhs_low, result); + result = bf16_dot(lhs_middle, rhs_high, result); + result = bf16_dot(lhs_high, rhs_middle, result); + // If lhs is 1.0, we will have lhs_high = 1.0 and lhs_low = 0.0. + // If rhs is +infinity, we will have: + // +infinity * 1.0 = +infinity + // +infinity * 0.0 = NaN + // We would get the wrong result if we sum these partial products. Instead, we + // must override any accumulated result if the last partial product is + // non-finite. See b/115844437. + Value is_finite = CheckFiniteF32(b, result); + result = b.create(is_finite, result, ZerosLike(b, result)); + result = bf16_dot(lhs_high, rhs_high, result); + result = b.create(acc, result); + return result; +} + +// Compute F32 matmul with 3 BF16 dots. It is less accurate than +// Emit6xBfloat16MatMul. +absl::StatusOr Emit3xBfloat16MatMul(ImplicitLocOpBuilder& b, Value lhs, + Value rhs, Value acc) { + Type f32 = b.getF32Type(); + TF_RET_CHECK(mlir::cast(lhs.getType()).getElementType() == f32); + TF_RET_CHECK(mlir::cast(rhs.getType()).getElementType() == f32); + TF_RET_CHECK(mlir::cast(acc.getType()).getElementType() == f32); + + Value lhs_high = RoundToBF16(b, TruncateToBF16TowardsZero(b, lhs)); + Value lhs_low = RoundToBF16(b, SoftMiddleEight(b, lhs)); + + Value rhs_high = RoundToBF16(b, TruncateToBF16TowardsZero(b, rhs)); + Value rhs_low = RoundToBF16(b, SoftMiddleEight(b, rhs)); + + auto bf16_dot = [&](Value lhs_bf16, Value rhs_bf16, + Value accumulator) -> Value { + return b.create(lhs_bf16, rhs_bf16, accumulator, + /*inputPrecision=*/mt::InputPrecision::IEEE, + /*maxNumImpreciseAcc=*/0); + }; + + Value local_acc = ZerosLike(b, acc); + Value result = bf16_dot(lhs_low, rhs_high, local_acc); + result = bf16_dot(lhs_high, rhs_low, result); + Value is_finite = CheckFiniteF32(b, result); + result = b.create(is_finite, result, ZerosLike(b, result)); + result = bf16_dot(lhs_high, rhs_high, result); + result = b.create(acc, result); + return result; +} + +bool IsTf32Allowed(const HloDotInstruction* dot_instr) { + const PrecisionConfig::Algorithm algorithm = + dot_instr->precision_config().algorithm(); + + if (algorithm == PrecisionConfig::ALG_UNSET) { + return tsl::tensor_float_32_execution_enabled() && + absl::c_none_of(dot_instr->precision_config().operand_precision(), + [](const int precision) { + return precision != PrecisionConfig::DEFAULT; + }); + } + + return algorithm_util::HasTf32InputType(algorithm); +} + +mt::InputPrecision InferDotPrecision(const HloDotInstruction* dot_instr) { + auto algorithm = dot_instr->precision_config().algorithm(); + if (algorithm == PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3) { + return mt::InputPrecision::TF32x3; + } + // TODO(b/320659359) Allow TF32 for 8-bit or less types with F32. + bool is_unsupported_bitwidth = + HloBfsAnyOf({dot_instr}, [&](const HloInstruction* node) { + if (node->opcode() != HloOpcode::kConvert) { + return false; + } + int in_width = + primitive_util::BitWidth(node->operand(0)->shape().element_type()); + return in_width <= 8 && node->shape().element_type() == F32; + }); + + return IsTf32Allowed(dot_instr) && !is_unsupported_bitwidth + ? mt::InputPrecision::TF32 + : mt::InputPrecision::IEEE; +} + +bool Is6xBfloat16MatMul(const HloDotInstruction* dot_instr, + mlir::OpBuilder& builder, Value dot_input_lhs, + Value dot_input_rhs, + const se::DeviceDescription& device_info) { + const PrecisionConfig::Algorithm algorithm = + dot_instr->precision_config().algorithm(); + + if (algorithm == PrecisionConfig::ALG_UNSET) { + const HloModule* hlo_module = dot_instr->GetModule(); + Type f32 = builder.getF32Type(); + return hlo_module->config() + .debug_options() + .xla_gpu_enable_bf16_6way_gemm() && + mlir::cast(dot_input_lhs.getType()).getElementType() == + f32 && + mlir::cast(dot_input_rhs.getType()).getElementType() == + f32; + } + + return algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6; +} + +bool Is3xBfloat16MatMul(const HloDotInstruction* dot_instr, + mlir::OpBuilder& builder, Value dot_input_lhs, + Value dot_input_rhs, + const se::DeviceDescription& device_info) { + const PrecisionConfig::Algorithm algorithm = + dot_instr->precision_config().algorithm(); + + if (algorithm == PrecisionConfig::ALG_UNSET) { + const HloModule* hlo_module = dot_instr->GetModule(); + Type f32 = builder.getF32Type(); + return hlo_module->config() + .debug_options() + .xla_gpu_enable_bf16_3way_gemm() && + mlir::cast(dot_input_lhs.getType()).getElementType() == + f32 && + mlir::cast(dot_input_rhs.getType()).getElementType() == + f32; + } + + return algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3; +} + +// This is a heuristic that serves as a proxy for register usage and code size. +// +// We have noticed that tilings with very long LLVM IR code are both slow to +// compile and slow to run. This can be for example due to register spills. So +// we should skip these tilings to save time. But it's better to skip them +// before the LLVM IR is generated. To do that, we came up with a formula that +// strongly correlates with the LLVM IR size. The formula is the size of the two +// input and the output thread block tiles divided by the number of warps. We +// read https://developer.nvidia.com/blog/cutlass-linear-algebra-cuda/ as a +// reference, and found the formula by trial and error. +// +// To regenerate the limit, we have to run an exhaustive search on all tilings +// for a few different HLOs, printing the runtimes and the heuristic values. +// +// From that, we can find a limit, such that all tilings within alpha * +// optimal_runtime have a heuristic value less than or equal to the limit. +// +// In our measurements, all tilings which were within 1.13 * optimal_runtime had +// a complexity_heuristic_value <= kComplexityHeuristicLimit. +// +// See go/tiling-heuristic for more details. +absl::Status CheckGemmTilingComplexityHeuristic( + const TritonGemmConfig& config) { + constexpr int64_t kComplexityHeuristicLimit = 9000; + int64_t complexity_heuristic_value = + (config.block_m * config.block_n + + (config.block_m + config.block_n) * config.block_k) / + config.num_warps; + VLOG(2) << "Complexity heuristic: " << complexity_heuristic_value; + if (complexity_heuristic_value > kComplexityHeuristicLimit) { + return ResourceExhausted("Tiling complexity heuristic exceeded: %d > %d", + complexity_heuristic_value, + kComplexityHeuristicLimit); + } + return absl::OkStatus(); +} + +class Scopes { + public: + Scopes(ImplicitLocOpBuilder& b, const HloInstruction* dot_instr, + const TritonFusionAnalysis& analysis, const MatMulDims& dims, + const TritonGemmConfig& config, const MatMulLaunchConfig launch_config, + bool is_sparse) + : lhs_(TritonFusionAnalysis::Scope::LHS), + rhs_(TritonFusionAnalysis::Scope::RHS), + out_(TritonFusionAnalysis::Scope::OUTPUT) { + constexpr int group_m = 8; + const int64_t width = group_m * launch_config.grid_n; + + auto c32 = [&](int64_t v) { return CreateConst(b, b.getI32Type(), v); }; + + auto pid_nc = b.create( + launch_config.noncontracting_program_id_dim); + pid_k_ = (config.split_k > 1) + ? b.create(mt::ProgramIDDim::Z) + : Value{}; + + auto group_id = b.create(pid_nc, c32(width)); + ma::ConstantOp group_m_op = c32(group_m); + auto first_pid_m = b.create(group_id, group_m_op); + auto sub0 = b.create(c32(launch_config.grid_m), first_pid_m); + auto group_size = b.create( + b.create(ma::CmpIPredicate::slt, sub0, group_m_op), sub0, + group_m_op); + + pid_m_ = b.create(first_pid_m, + b.create(pid_nc, group_size)); + + pid_n_ = b.create(b.create(pid_nc, c32(width)), + group_size); + + int lhs_non_contracting_block_size = config.block_m; + int lhs_contracting_block_size = config.block_k; + int lhs_unpack_bound_idx = 0; + if (is_int4_param(analysis, TritonFusionAnalysis::Scope::LHS)) { + auto minor_dim = std::max(dims.lhs_contracting_dim_idx, + dims.lhs_noncontracting_dim_idx); + auto minor_bound = analysis + .IterSpec(TritonFusionAnalysis::Scope::LHS, + dot_instr->operand(0), minor_dim) + ->at(0) + .count; + if (minor_bound == + 1) { // Assuming that the contracting dimension is major. + lhs_contracting_block_size /= 2; + lhs_unpack_bound_idx = 1; + } else if (dims.lhs_contracting_dim_idx > + dims.lhs_noncontracting_dim_idx) { + // lhs is int4 and the contracting dimension is minor. + lhs_contracting_block_size /= 2; + lhs_unpack_bound_idx = 1; + } else { + // lhs is int4 and the contracting dimension is major. + lhs_non_contracting_block_size /= 2; + lhs_unpack_bound_idx = 0; + } + } + if (is_sparse) { + lhs_contracting_block_size /= 2; + } + lhs_.tiled_dims = { + DimProperties(dims.lhs_noncontracting_dim_idx, pid_m_, + lhs_non_contracting_block_size, + /*split_value=*/1), + DimProperties(dims.lhs_contracting_dim_idx, pid_k_, + lhs_contracting_block_size, config.split_k)}; + lhs_.batch_dim_idx = dims.lhs_batch_dim_idx; + lhs_.unpack_dim_idx = lhs_unpack_bound_idx; + + int rhs_contracting_block_size = config.block_k; + int rhs_non_contracting_block_size = config.block_n; + int rhs_unpack_bound_idx = 0; + if (is_int4_param(analysis, TritonFusionAnalysis::Scope::RHS)) { + auto minor_dim = std::max(dims.rhs_contracting_dim_idx, + dims.rhs_noncontracting_dim_idx); + auto minor_bound = analysis + .IterSpec(TritonFusionAnalysis::Scope::RHS, + dot_instr->operand(1), minor_dim) + ->at(0) + .count; + + if (minor_bound == 1) { // rhs is int4 and the _minor_ bound is 1. + rhs_contracting_block_size /= 2; + } else if (dims.rhs_contracting_dim_idx > + dims.rhs_noncontracting_dim_idx) { + // rhs is int4 and the contracting dimension is minor. + rhs_contracting_block_size /= 2; + } else { + // rhs is int4 and the contracting dimension is major. + rhs_non_contracting_block_size /= 2; + rhs_unpack_bound_idx = 1; + } + } + rhs_.tiled_dims = { + DimProperties(dims.rhs_contracting_dim_idx, pid_k_, + rhs_contracting_block_size, config.split_k), + DimProperties(dims.rhs_noncontracting_dim_idx, pid_n_, + rhs_non_contracting_block_size, + /*split_value=*/1)}; + rhs_.batch_dim_idx = dims.rhs_batch_dim_idx; + rhs_.unpack_dim_idx = rhs_unpack_bound_idx; + + out_.tiled_dims = {DimProperties(dims.out_lhs_noncontracting_dim_idx, + pid_m_, config.block_m, + /*split_value=*/1), + DimProperties(dims.out_rhs_noncontracting_dim_idx, + pid_n_, config.block_n, + /*split_value=*/1)}; + out_.batch_dim_idx = dims.out_batch_dim_idx; + + if (is_sparse) { + meta_ = Side{TritonFusionAnalysis::Scope::META, + /*tiled_dims=*/ + {DimProperties(dims.lhs_noncontracting_dim_idx, pid_m_, + config.block_m, + /*split_value=*/1), + DimProperties(dims.lhs_contracting_dim_idx, pid_k_, + config.block_k / 16, config.split_k)}, + dims.lhs_batch_dim_idx}; + } + } + + std::vector input_scopes() const { + if (meta_.has_value()) { + return {&lhs_, &rhs_, &meta_.value()}; + } + return {&lhs_, &rhs_}; + } + const Side& lhs() const { return lhs_; } + const Side& rhs() const { return rhs_; } + const Side& out() const { return out_; } + const std::optional& meta() const { return meta_; } + const Value& pid_m() const { return pid_m_; } + const Value& pid_k() const { return pid_k_; } + const Value& pid_n() const { return pid_n_; } + + static bool is_int4_param(const TritonFusionAnalysis& analysis, + TritonFusionAnalysis::Scope scope) { + const ConstHloInstructionSet& params = analysis.ScopeParameters(scope); + return params.size() == 1 && + (*params.cbegin())->shape().element_type() == S4; + } + + private: + Side lhs_; + Side rhs_; + Side out_; + std::optional meta_; + + Value pid_m_; + Value pid_k_; + Value pid_n_; +}; + +enum MaskExpandDimension { kMajor = 0, kMinor = 1 }; + +Value EmitMaskOnInput(ImplicitLocOpBuilder& b, + MaskExpandDimension expand_dimension, Value input, + int denom, Value k, int64_t dims_k, int64_t block_k, + Value pid_k) { + auto c32 = [&](int64_t v) { return CreateConst(b, b.getI32Type(), v); }; + int size = block_k / denom; + auto elements_in_tile = b.create(c32(dims_k / denom), k); + auto cond = + b.create(ma::CmpIPredicate::slt, elements_in_tile, c32(size)); + auto if_op = b.create( + cond, /*thenBranch=*/ + [&](mlir::OpBuilder& builder, mlir::Location loc) { + ImplicitLocOpBuilder b(loc, builder); + auto range_k = Range(b, size); + if (pid_k != nullptr) { + range_k = b.create( + range_k, Splat(b, b.create(pid_k, c32(size)), size)); + } + auto ty = mlir::cast(input.getType()); + TensorValue range_expanded = mlir::cast( + b.create(range_k, expand_dimension).getResult()); + Value mask = b.create( + ty.clone(b.getI1Type()), + b.create(ma::CmpIPredicate::slt, range_expanded, + Splat(b, elements_in_tile, + range_expanded.getType().getShape()))); + auto result = b.create(mask, input, ZerosLike(b, input)); + b.create(mlir::ValueRange(result)); + }, + /*elseBranch=*/ + [&](mlir::OpBuilder& b, mlir::Location loc) { + b.create(loc, mlir::ValueRange(input)); + }); + return if_op.getResult(0); +} + +} // namespace + +// Use tiling and execution parameters from 'config'. BlockLevelParameters are +// ignored. +// Variable naming: lhs [m, k] x rhs [k, n] -> out [m, n]. +absl::Status EmitMatMul(mlir::OpBuilder builder, + absl::string_view libdevice_path, + const se::DeviceDescription& device_info, + const HloFusionInstruction* fusion, + mlir::triton::FuncOp fn, const BlockLevelParameters&) { + auto backend_config = + fusion->backend_config()->fusion_backend_config(); + + if (!backend_config.has_triton_gemm_config()) { + // TODO(bchetioui): consolidate default parameters. At the moment, these + // may be constructed in two distinct places. + LOG(WARNING) << "Using fallback triton GEMM config for op " + << fusion->name(); + auto& triton_config = *backend_config.mutable_triton_gemm_config(); + triton_config.set_block_m(64); + triton_config.set_block_k(64); + triton_config.set_block_n(64); + triton_config.set_split_k(1); + triton_config.set_num_stages(1); + triton_config.set_num_warps(2); + triton_config.set_num_ctas(1); + } + + TF_ASSIGN_OR_RETURN( + TritonGemmConfig config, + TritonGemmConfig::FromProto(backend_config.triton_gemm_config())); + TF_ASSIGN_OR_RETURN(auto analysis, + TritonFusionAnalysis::Execute( + *fusion->called_computation(), config.split_k)); + + TF_RETURN_IF_ERROR(CheckGemmTilingComplexityHeuristic(config)); + + const HloComputation* computation = fusion->fused_instructions_computation(); + const HloInstruction* instr = + hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot); + const HloDotInstruction* dot_instr = DynCast(instr); + bool is_sparse = dot_instr->sparse_operands() > 0; + + // Use 32-bit indexing if addressing any of the inputs or the output (which + // could grow if split_k is set) does not cross the INT_MAX boundary. + // Otherwise, fall back to 64-bit indexing, which is slower. + bool use_64bit_indexing = + ShapeUtil::ElementsIn(dot_instr->operand(0)->shape()) > INT_MAX || + ShapeUtil::ElementsIn(dot_instr->operand(1)->shape()) > INT_MAX || + ShapeUtil::ElementsIn(dot_instr->shape()) * config.split_k > INT_MAX; + Type index_ty = builder.getIntegerType(use_64bit_indexing ? 64 : 32); + + const HloInstruction* root = dot_instr->parent()->root_instruction(); + TF_RET_CHECK(!root->shape().IsTuple()); + + // We'll be creating a lot of instructions from a single dot, use an + // implicit loc builder so we don't have to pass around the location all the + // time. + auto loc = mlir::NameLoc::get(builder.getStringAttr(dot_instr->name())); + ImplicitLocOpBuilder b(loc, builder); + + TF_RETURN_IF_ERROR(ValidateMatMulConfig(config, *dot_instr)); + const int split_k = config.split_k; + const int block_m = config.block_m; + const int block_k = config.block_k; + const int block_n = config.block_n; + + TF_ASSIGN_OR_RETURN(const MatMulDims dims, + MatMulDims::Create(config, *dot_instr, analysis)); + const MatMulLaunchConfig launch_config(config, *dot_instr, dims); + VLOG(6) << analysis.ToString(); + + MatMulEmitterHelper emitter(libdevice_path, device_info, dot_instr, b, + index_ty, dims, launch_config, analysis); + + TF_ASSIGN_OR_RETURN(mlir::FloatType acc_ty, emitter.GetDotAccumulatorType()); + + ma::ConstantOp accumulator_init = + CreateConst(b, acc_ty, 0, {block_m, block_n}); + + // Parameters are passed to the loop in non-trivial order, these maps help + // finding them and their attributes. + absl::flat_hash_map iter_args_to_inputs; + absl::flat_hash_map> iter_args_to_boundary_checks; + + // Calculate the sizes of the lhs, rhs, meta, and output sides. + Scopes scopes(b, dot_instr, analysis, dims, config, launch_config, is_sparse); + + auto c32 = [&](int64_t v) { return CreateConst(b, b.getI32Type(), v); }; + + constexpr size_t kLhsMetaOperandIdx = HloDotInstruction::kOperands; + size_t lsize = ScopeInputs(analysis, TritonFusionAnalysis::Scope::LHS).size(); + size_t rsize = ScopeInputs(analysis, TritonFusionAnalysis::Scope::RHS).size(); + + absl::flat_hash_map triton_type_for_input; + for (const Side& side : {scopes.lhs(), scopes.rhs()}) { + for (const HloInstruction* input : ScopeInputs(analysis, side.scope)) { + TF_ASSIGN_OR_RETURN(Type input_ty, + TritonType(b, input->shape().element_type())); + triton_type_for_input.insert({input, input_ty}); + } + } + + auto body_builder = [&](mlir::OpBuilder&, mlir::Location, Value ki, + ValueRange iter_args) -> void { + SmallVector iter_args_next; + iter_args_next.reserve(iter_args.size()); + std::array, 3> values; + + // Load tiles of all parameters of LHS and RHS scopes and advance pointers. + for (int i = 0; i < iter_args.size() - 1; ++i) { + const int index = i < lsize ? 0 : i < lsize + rsize ? 1 : 2; + const Side& side = *(scopes.input_scopes()[index]); + + const HloInstruction* param_hlo = iter_args_to_inputs[i]; + Type param_ty = index == kLhsMetaOperandIdx + ? b.getI16Type() + : triton_type_for_input.at(param_hlo); + Type param_storage_ty = StorageType(b, param_ty); + Value param_value = + EmitParameterLoad(b, iter_args[i], iter_args_to_boundary_checks[i]); + if (param_ty != param_storage_ty) { + // For example cast i8 to i1. + param_value = Cast(b, param_value, param_ty); + } + + CHECK(values[index].insert({param_hlo, param_value}).second); + SmallVector increments; + for (const DimProperties& dim : side.tiled_dims) { + const TensorIterationSpec::DimIterationSpec* spec = + analysis.IterSpec(side.scope, iter_args_to_inputs[i], dim.index); + if (spec == nullptr || spec->at(0).stride == 0) { + continue; + } + // Only the contracting dimensions are advanced. + if (dim.index == (index == 0 || index == kLhsMetaOperandIdx + ? dims.lhs_contracting_dim_idx + : dims.rhs_contracting_dim_idx)) { + increments.push_back(c32(dim.block_size * split_k)); + } else { + increments.push_back(c32(0)); + } + } + if (increments.empty()) { + iter_args_next.push_back(iter_args[i]); + } else { + iter_args_next.push_back(b.create( + iter_args[i].getType(), iter_args[i], increments)); + } + } + + // Emit all operations of LHS and RHS scopes. + Value dot_input_lhs = emitter.MakeInput(scopes.lhs(), 0, values[0]); + Value dot_input_rhs = emitter.MakeInput(scopes.rhs(), 1, values[1]); + Value dot_input_meta = + is_sparse ? emitter.MakeInput(*scopes.meta(), 2, values[2]) : Value{}; + + // Operation in the fusion before the dot can alter the elements of the + // tiles that were zero masked during loads. These have to be zeroed here + // again just before the dot so that they do not affect the output. + // Only the K dimension needs masking here because unnecessary elements in + // the other two get discarded by the masked store at the end. + const bool need_masking = dims.k % (block_k * split_k) > 0; + if (need_masking) { + dot_input_lhs = EmitMaskOnInput(b, MaskExpandDimension::kMajor, + dot_input_lhs, is_sparse ? 2 : 1, ki, + dims.k, block_k, scopes.pid_k()); + dot_input_rhs = + EmitMaskOnInput(b, MaskExpandDimension::kMinor, dot_input_rhs, 1, ki, + dims.k, block_k, scopes.pid_k()); + // Masking the metadata is not necessary, as the inputs are masked + // (i.e. zeroed out), so the padded metadata can hold any values. + } + + if (is_sparse) { + iter_args_next.push_back(b.create( + dot_input_lhs, dot_input_rhs, iter_args.back(), dot_input_meta)); + b.create(iter_args_next); + return; + } + + const HloModule* hlo_module = dot_instr->GetModule(); + if (hlo_module->config().debug_options().xla_gpu_enable_bf16_3way_gemm() && + hlo_module->config().debug_options().xla_gpu_enable_bf16_6way_gemm()) { + LOG(WARNING) << "Both BF16 6way gemm and 3way gemm are enabled." + << " Fallback to BF16 6way gemm."; + } + + Value accumulator_next; + if (Is6xBfloat16MatMul(dot_instr, b, dot_input_lhs, dot_input_rhs, + device_info)) { + absl::StatusOr accumulator_next_or = Emit6xBfloat16MatMul( + b, dot_input_lhs, dot_input_rhs, iter_args.back()); + TF_CHECK_OK(accumulator_next_or.status()); + accumulator_next = accumulator_next_or.value(); + } else if (Is3xBfloat16MatMul(dot_instr, b, dot_input_lhs, dot_input_rhs, + device_info)) { + absl::StatusOr accumulator_next_or = Emit3xBfloat16MatMul( + b, dot_input_lhs, dot_input_rhs, iter_args.back()); + TF_CHECK_OK(accumulator_next_or.status()); + accumulator_next = accumulator_next_or.value(); + } else { + // Execute matrix multiplication of input tiles and pass the accumulator. + // TODO(manany): Should be looked into once we enable Hopper workloads. + // maxNumImpreciseAcc flag was introduced for Hopper to accumulate in a + // lower precision than the output type. The change was introduced here: + // https://github.com/openai/triton/commit/31b0c521427109a8eda609b58d756c380b21599a + auto dot_precision = InferDotPrecision(dot_instr); + + // Cast F32 inputs to BF16 if the algorithm is BF16_BF16_F32. + if (dot_instr->precision_config().algorithm() == + PrecisionConfig::ALG_DOT_BF16_BF16_F32) { + if (dot_instr->operand(0)->shape().element_type() == F32) { + dot_input_lhs = Cast(b, dot_input_lhs, b.getBF16Type()); + } + if (dot_instr->operand(1)->shape().element_type() == F32) { + dot_input_rhs = Cast(b, dot_input_rhs, b.getBF16Type()); + } + } + + // For fp8 matmuls, disable accumulator promotion, as it's what cublas + // does. It may make sense to enable frequent accumulator promotion at + // higher matmul precisions set in the config. + int max_num_imprecise_acc = + IsFp8Matmul(dot_instr) ? std::numeric_limits::max() : 0; + accumulator_next = + b.create(dot_input_lhs, dot_input_rhs, iter_args.back(), + /*inputPrecision=*/dot_precision, + /*maxNumImpreciseAcc=*/max_num_imprecise_acc); + } + iter_args_next.push_back(accumulator_next); + + b.create(iter_args_next); + return; + }; + + // Pointers to inputs of LHS scope, then RHS, then the accumulator + // that change with every loop iteration and are passed between them. + SmallVector iter_args; + iter_args.reserve(lsize + rsize + 1 + is_sparse); + + for (const Side* side : scopes.input_scopes()) { + for (const HloInstruction* input : ScopeInputs(analysis, side->scope)) { + TF_RET_CHECK( + iter_args_to_inputs.insert({iter_args.size(), input}).second); + TF_ASSIGN_OR_RETURN(SmallVector arguments, + GetArguments(fn, *input)); + TF_ASSIGN_OR_RETURN(Value tensor_ptr, + emitter.EmitTensorPointer( + input, *side, arguments, scopes.pid_k(), + iter_args_to_boundary_checks[iter_args.size()])); + iter_args.push_back(tensor_ptr); + } + } + + iter_args.push_back(accumulator_init); + Value acc_final = b.create( + /*lowerBound=*/c32(0), + /*upperBound=*/c32(dims.k), + /*step=*/c32(block_k * split_k), + /*iterArgs=*/iter_args, body_builder) + .getResult(iter_args.size() - 1); + absl::flat_hash_map values_out; + TF_ASSIGN_OR_RETURN(Type acc_final_ty, + TritonType(b, dot_instr->shape().element_type())); + values_out[dot_instr] = Cast(b, acc_final, acc_final_ty); + + // Emit the output scope. + if (std::vector to_emit = + emitter.EpiloguePostOrderTransitiveOperands(root); + !to_emit.empty()) { + for (const HloInstruction* input : + ScopeInputs(analysis, TritonFusionAnalysis::Scope::OUTPUT)) { + std::vector boundary_checks; + TF_ASSIGN_OR_RETURN(SmallVector arguments, + GetArguments(fn, *input)); + TF_ASSIGN_OR_RETURN( + Value tensor_pointer, + emitter.EmitTensorPointer(input, scopes.out(), arguments, + scopes.pid_k(), boundary_checks)); + TF_RET_CHECK(values_out + .insert({input, EmitParameterLoad(b, tensor_pointer, + boundary_checks)}) + .second); + } + TF_RETURN_IF_ERROR(EmitScope(b, libdevice_path, device_info, &analysis, + scopes.out(), to_emit, values_out) + .status()); + } + + // Emit tensor store operations for all outputs. + for (int i = 0; + i < fn.getNumArguments() - dot_instr->parent()->num_parameters(); ++i) { + const HloInstruction* producer = + root->shape().IsTuple() ? root->operand(i) : root; + std::vector boundary_checks; + TF_ASSIGN_OR_RETURN( + Value tensor_pointer, + emitter.EmitTensorPointer( + producer, scopes.out(), + {fn.getArgument(i + dot_instr->parent()->num_parameters())}, + scopes.pid_k(), boundary_checks)); + b.create(tensor_pointer, values_out[producer], boundary_checks, + mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL); + } + return absl::OkStatus(); +} + +absl::StatusOr GetMatMulLaunchDimensions( + const TritonFusionAnalysis& analysis, const HloFusionAdaptor& fusion, + const TritonGemmConfig& config) { + auto dot = HloBfsFindIf(fusion.GetRoots(), fusion, [](auto node) { + return node.opcode() == HloOpcode::kDot; + }); + TF_RET_CHECK(dot != std::nullopt); + const auto& dot_instr = + *static_cast(&dot->instruction()); + TF_ASSIGN_OR_RETURN(MatMulDims dims, + MatMulDims::Create(config, dot_instr, analysis)); + MatMulLaunchConfig launch_config(config, dot_instr, dims); + return launch_config.launch_dims; +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.h b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.h new file mode 100644 index 00000000000000..ae6fd6116cb461 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.h @@ -0,0 +1,50 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_TRITON_FUSION_EMITTER_LEGACY_MATMUL_H_ +#define XLA_SERVICE_GPU_FUSIONS_TRITON_TRITON_FUSION_EMITTER_LEGACY_MATMUL_H_ + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/Builders.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/model/tiled_hlo_computation.h" +#include "xla/service/gpu/triton_fusion_analysis.h" +#include "xla/stream_executor/device_description.h" +#include "triton/Dialect/Triton/IR/Dialect.h" + +namespace xla::gpu { + +// Compute the launch dimensions for the given Triton MatMul. +absl::StatusOr GetMatMulLaunchDimensions( + const TritonFusionAnalysis& analysis, const HloFusionAdaptor& fusion, + const TritonGemmConfig& config); + +// Use tiling and execution parameters from 'config'. BlockLevelParameters are +// ignored. +// Variable naming: lhs [m, k] x rhs [k, n] -> out [m, n]. +absl::Status EmitMatMul(mlir::OpBuilder builder, + absl::string_view libdevice_path, + const se::DeviceDescription& device_info, + const HloFusionInstruction* fusion, + mlir::triton::FuncOp fn, const BlockLevelParameters&); + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_TRITON_FUSION_EMITTER_LEGACY_MATMUL_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul_stub.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul_stub.cc new file mode 100644 index 00000000000000..682472cf5bdc4d --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul_stub.cc @@ -0,0 +1,35 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.h" + +namespace xla::gpu { + +// Compute the launch dimensions for the given Triton MatMul. +absl::StatusOr GetMatMulLaunchDimensions( + const TritonFusionAnalysis& analysis, const HloFusionAdaptor& fusion, + const TritonGemmConfig& config) { + return absl::UnimplementedError("not supported for this build configuration"); +} + +absl::Status EmitMatMul(mlir::OpBuilder builder, + absl::string_view libdevice_path, + const se::DeviceDescription& device_info, + const HloFusionInstruction* fusion, + mlir::triton::FuncOp fn, const BlockLevelParameters&) { + return absl::UnimplementedError("not supported for this build configuration"); +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc index a3cafcdf40628f..76d53d8390e1cc 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc @@ -53,7 +53,22 @@ struct MixTypeParams { class MixedTypeTest : public GpuCodegenTest, public ::testing::WithParamInterface { public: - DebugOptions GetDebugOptionsForTest() override { + se::GpuComputeCapability GetGpuComputeCapability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(); + } + + void SetUp() override { + if (std::holds_alternative( + GetGpuComputeCapability())) { + GTEST_SKIP() + << "Related fusions are not performed on ROCm without Triton."; + } + } + + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); // We are testing Triton, remove cuBLAS fallback for these tests. debug_options.set_xla_gpu_cublas_fallback(false); @@ -125,9 +140,7 @@ INSTANTIATE_TEST_SUITE_P(RewriteTestSuite, MixedTypeTest, // TritonRewriteTest2Params{F32, F16}, // TritonRewriteTest2Params{F32, BF16}, MixTypeParams{S8, BF16, 24, 40, 8}, - // Modify the case below to use k = 32 instead of - // 16 once b/337839570 is fixed. - MixTypeParams{S8, F16, 80, 32, 32, 1e-3, 1e-6}, + MixTypeParams{S8, F16, 80, 16, 32, 1e-3, 1e-6}, MixTypeParams{F16, F32, 127, 3, 300, 1e-2, 1e-2}, MixTypeParams{F16, BF16, 544, 96, 16, 1e-3, 1e-3}, MixTypeParams{BF16, F32, 77, 500, 333, 3e-3, 3e-3}, @@ -136,7 +149,7 @@ INSTANTIATE_TEST_SUITE_P(RewriteTestSuite, MixedTypeTest, class TritonTest : public GpuCodegenTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_triton_gemm_any(true); debug_options.set_xla_gpu_cublas_fallback(false); @@ -787,7 +800,7 @@ INSTANTIATE_TEST_SUITE_P( class TritonSoftmaxTest : public GpuCodegenTest, public ::testing::WithParamInterface { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); debug_options .set_xla_gpu_experimental_enable_triton_softmax_priority_fusion(true); @@ -1563,6 +1576,7 @@ ENTRY main { } EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec(/*aabs=*/0, /*arel=*/0), /*reference_preprocessor=*/nullptr, + /*test_preprocessor=*/nullptr, max_bits_of_precision)); } diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub.cc index 33c1e0666dd90d..fe984689ac1131 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "llvm/ADT/SmallVector.h" #include "llvm/IR/Module.h" #include "mlir/IR/Builders.h" @@ -56,28 +57,6 @@ absl::Status EmitGeneric(mlir::OpBuilder b, absl::string_view libdevice_path, return absl::UnimplementedError("not supported for this build configuration"); } -absl::StatusOr GetMatMulLaunchDimensions( - const TritonFusionAnalysis& analysis, const HloFusionAdaptor& fusion, - const TritonGemmConfig& config) { - return absl::UnimplementedError("not supported for this build configuration"); -} - -absl::Status EmitMatMul(mlir::OpBuilder b, absl::string_view libdevice_path, - const se::DeviceDescription& device_info, - const HloFusionInstruction* fusion, - mlir::triton::FuncOp fn, - const BlockLevelParameters& block_level_parameters) { - return absl::UnimplementedError("not supported for this build configuration"); -} - -absl::Status EmitSoftMax(mlir::OpBuilder b, absl::string_view libdevice_path, - const se::DeviceDescription& device_info, - const HloFusionInstruction* fusion, - mlir::triton::FuncOp fn, - const BlockLevelParameters& block_level_parameters) { - return absl::UnimplementedError("not supported for this build configuration"); -} - void LoadMlirDialectsForTriton(mlir::MLIRContext& mlir_context) {} absl::StatusOr TritonWrapper( @@ -103,7 +82,7 @@ absl::StatusOr CompileTritonToLLVM( const se::DeviceDescription& device_info, const BlockLevelParameters& block_level_parameters, mlir::ModuleOp triton_module, llvm::Module* llvm_module, - mlir::MLIRContext& mlir_context) { + mlir::MLIRContext& mlir_context, bool emit_kernel) { return absl::UnimplementedError("not supported for this build configuration"); } @@ -121,10 +100,16 @@ std::string GetLibdevicePath(const HloModuleConfig& hlo_config, namespace ir_emitter_triton_internal { +llvm::SmallVector ComputeDelinearizedTileIndex( + mlir::ImplicitLocOpBuilder& b, + absl::Span num_output_tiles_per_dim) { + return {}; +} + absl::StatusOr CreateMakeTensorPtrOp( mlir::ImplicitLocOpBuilder& b, mlir::ValueRange tile_multi_index, const TiledHloInstruction& tiled_hlo, mlir::Value argument_block) { - return MakeTensorPtrOpAndBoundaryChecks(); + return absl::UnimplementedError("not supported for this build configuration"); } } // namespace ir_emitter_triton_internal diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub_test.cc new file mode 100644 index 00000000000000..0e777e59364f21 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_stub_test.cc @@ -0,0 +1,83 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include "mlir/IR/Builders.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Pass/PassManager.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h" +#include "xla/service/gpu/fusions/triton/triton_fusion_emitter_legacy_matmul.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/model/tiled_hlo_instruction.h" +#include "tsl/platform/test.h" + +namespace mlir::triton::nvidia_gpu { +// We define ClusterInfo here in order to avoid having to import a GPU-only +// header. +struct ClusterInfo {}; + +} // namespace mlir::triton::nvidia_gpu + +namespace xla::gpu { +namespace { + +TEST(TritonStub, CallStubApi) { + mlir::MLIRContext context; + + LoadMlirDialectsForTriton(context); + EXPECT_FALSE(TritonWrapper({}, nullptr, {}, {}, {}, nullptr, context).ok()); + EXPECT_FALSE(CreateTritonModule({}, nullptr, {}, {}, context).ok()); + EXPECT_FALSE( + CompileTritonToLLVM({}, {}, {}, {}, {}, {}, nullptr, context, {}).ok()); + + mlir::OpPassManager pm; + mt::nvidia_gpu::ClusterInfo cluster_info; + + EXPECT_FALSE(CreateTritonPipeline(pm, {}, {}, cluster_info).ok()); + EXPECT_EQ(GetLibdevicePath({}, {}), ""); + + mlir::ImplicitLocOpBuilder builder(mlir::UnknownLoc::get(&context), &context); + + EXPECT_TRUE( + ir_emitter_triton_internal::ComputeDelinearizedTileIndex(builder, {}) + .empty()); + + HloConstantInstruction constant(LiteralUtil::CreateR1({1, 1})); + auto tiled_hlo = TiledHloInstruction::Create(&constant, {}, {1}, {1}, {}); + EXPECT_TRUE(tiled_hlo.ok()); + + EXPECT_FALSE(ir_emitter_triton_internal::CreateMakeTensorPtrOp( + builder, {}, *tiled_hlo.value(), {}) + .ok()); +} + +TEST(TritonStub, CallLegacyMatMulApis) { + HloConstantInstruction constant(Literal{}); + auto adaptor = HloFusionAdaptor::ForInstruction(&constant); + EXPECT_FALSE(GetMatMulLaunchDimensions({}, *adaptor.get(), {}).ok()); + + mlir::MLIRContext context; + mlir::OpBuilder builder(&context); + EXPECT_FALSE(EmitMatMul(builder, {}, {}, nullptr, {}, {}).ok()); +} + +} // namespace + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc index 69eeb7461bba6f..776f8920a1d24c 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc @@ -106,6 +106,11 @@ absl::flat_hash_set TritonSupportedUnaryElementwiseOps( HloOpcode::kCeil}; ret.insert(additional_opcodes.begin(), additional_opcodes.end()); } + + if (primitive_util::IsFloatingPointType(element_type)) { + ret.insert(HloOpcode::kReducePrecision); + } + return ret; } @@ -117,7 +122,7 @@ CodegenDecision IsTritonSupportedConversion( }; auto error_message = [&]() { - return CodegenDecision( + return CodegenDecision::Forbid( absl::StrCat("Unsupported conversion in Triton: ", primitive_util::LowercasePrimitiveTypeName(input), " to ", primitive_util::LowercasePrimitiveTypeName(output))); @@ -137,9 +142,8 @@ CodegenDecision IsTritonSupportedConversion( } if (IsTritonSupportedDataType(input, gpu_version) && - (IsTritonSupportedDataType(output, gpu_version) || - output == PrimitiveType::S4)) { - return CodegenDecision{}; + IsTritonSupportedDataType(output, gpu_version)) { + return CodegenDecision::Allow(); } return error_message(); @@ -224,7 +228,8 @@ CodegenDecision CanTritonHandleReduce( const se::GpuComputeCapability& gpu_version) { if (reduce.shape().element_type() == PrimitiveType::F8E4M3FN || reduce.shape().element_type() == PrimitiveType::F8E5M2) { - return "F8E4M3FN and F8E5M2 are not supported for reductions."; + return CodegenDecision::Forbid( + "F8E4M3FN and F8E5M2 are not supported for reductions."); } bool is_triton_supported_reduction_computation = absl::c_all_of( @@ -232,19 +237,21 @@ CodegenDecision CanTritonHandleReduce( return IsTritonSupportedInstructionImpl(*instr, gpu_version).CanFuse(); }); if (!is_triton_supported_reduction_computation) { - return "Unsupported reduction computation by Triton."; + return CodegenDecision::Forbid( + "Unsupported reduction computation by Triton."); } if (reduce.dimensions().size() == 1 && reduce.operand_count() == 2) { - return CodegenDecision{}; + return CodegenDecision::Allow(); } - return "Reduction is not a row-reduction of a single operand."; + return CodegenDecision::Forbid( + "Reduction is not a row-reduction of a single operand."); } CodegenDecision IsTritonSupportedInstructionImpl( const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) { if (internal::IsTritonUnsupportedOpcode(instr.opcode())) { - return "Unsupported opcode."; + return CodegenDecision::Forbid("Unsupported opcode."); } // Special handling for the kConvert instruction, which has a non-standard @@ -259,7 +266,7 @@ CodegenDecision IsTritonSupportedInstructionImpl( bool output_type_is_supported = IsTritonSupportedDataType(type, gpu_version); if (!output_type_is_supported) { - return "Unsupported output data type."; + return CodegenDecision::Forbid("Unsupported output data type."); } bool input_types_are_supported = @@ -269,16 +276,25 @@ CodegenDecision IsTritonSupportedInstructionImpl( }); if (!input_types_are_supported) { - return "Unsupported input data type."; + return CodegenDecision::Forbid("Unsupported input data type."); } // Const is technically an elementwise op, so this check must be before the // elementwise check. if (instr.opcode() == HloOpcode::kConstant) { - return ShapeUtil::IsScalar(instr.shape()) - ? CodegenDecision{} - : CodegenDecision{ - "Only scalar constants are supported in Triton."}; + return ShapeUtil::IsEffectiveScalar(instr.shape()) + ? CodegenDecision::Allow() + : CodegenDecision::Forbid( + "Only scalar constants are supported in Triton."); + } + + if (instr.opcode() == HloOpcode::kIota) { + PrimitiveType element_type = instr.shape().element_type(); + return element_type != PrimitiveType::F8E4M3FN && + element_type != PrimitiveType::F8E5M2 + ? CodegenDecision::Allow() + : CodegenDecision::Forbid( + "F8E4M3FN and F8E5M2 are not supported for iota."); } if (instr.IsElementwise()) { @@ -289,9 +305,9 @@ CodegenDecision IsTritonSupportedInstructionImpl( // operand. instr.operand(instr.operand_count() - 1)->shape().element_type(), gpu_version)) { - return "Unsupported elementwise operation."; + return CodegenDecision::Forbid("Unsupported elementwise operation."); } - return CodegenDecision{}; + return CodegenDecision::Allow(); } // TODO(bchetioui): support kDot, kPad, and kDynamicSlice. @@ -306,12 +322,12 @@ CodegenDecision IsTritonSupportedInstructionImpl( case HloOpcode::kBroadcast: case HloOpcode::kBitcast: case HloOpcode::kReshape: - return CodegenDecision{}; + return CodegenDecision::Allow(); default: VLOG(2) << "Unsupported instruction: " << instr.ToString(); break; } - return "Unsupported opcode."; + return CodegenDecision::Forbid("Unsupported opcode."); } } // namespace @@ -354,6 +370,7 @@ bool IsTritonUnsupportedOpcode(HloOpcode opcode) { case HloOpcode::kOutfeed: case HloOpcode::kPad: case HloOpcode::kPartitionId: + case HloOpcode::kRaggedAllToAll: case HloOpcode::kRecv: case HloOpcode::kRecvDone: case HloOpcode::kReduceWindow: @@ -410,6 +427,20 @@ CodegenDecision IsTritonSupportedInstruction( return decision; } +CodegenDecision IsTritonSupportedComputation( + const HloComputation& computation, + const se::GpuComputeCapability& gpu_compute_capability) { + for (const auto* instruction : computation.instructions()) { + if (CodegenDecision can_codegen = + IsTritonSupportedInstruction(*instruction, gpu_compute_capability); + !can_codegen) { + return can_codegen; + } + } + + return CodegenDecision::Allow(); +} + bool IsTritonFusedComputation(const HloComputation& computation) { HloFusionInstruction* fusion = static_cast(computation.FusionInstruction()); diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_support.h b/third_party/xla/xla/service/gpu/fusions/triton/triton_support.h index 391bb3daa46b11..879a8ea375bba4 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_support.h +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_support.h @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + #ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_TRITON_SUPPORT_H_ #define XLA_SERVICE_GPU_FUSIONS_TRITON_TRITON_SUPPORT_H_ @@ -22,6 +23,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/instruction_fusion.h" +#include "xla/shape.h" #include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" @@ -48,6 +50,16 @@ absl::Status EnsureTritonSupportsComputeCapability( CodegenDecision IsTritonSupportedInstruction( const HloInstruction& instr, const se::GpuComputeCapability& gpu_version); +// Returns `CodegenDecision`'s equivalent of `true` if all the instructions in +// the parameter computation are supported by the Triton emitters for the given +// compute capability. +// +// This function has the same caveats as `IsTritonSupportedInstruction` as +// defined in the present namespace. +CodegenDecision IsTritonSupportedComputation( + const HloComputation& computation, + const se::GpuComputeCapability& gpu_compute_capability); + // Returns `true` if the parameter computation is a Triton fused computation, // i.e. the calling fusion instruction has `FusionKind::kCustom` and // `backend_config()` with `kind` set to diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc index b07630b7cb7734..dc8740807e2fc5 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/triton/triton_support.h" - #include #include #include @@ -22,12 +20,14 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/log/check.h" +#include "absl/strings/str_format.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/layout.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/fusions/triton/triton_support.h" #include "xla/service/gpu/variant_visitor.h" #include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" @@ -115,29 +115,22 @@ bool IsTritonSupportedDataType(PrimitiveType type, CodegenDecision IsInstructionSupportsDataTypes( const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) { if (!IsTritonSupportedDataType(instr.shape().element_type(), gpu_version)) { - return "Unsupported output data type."; + return CodegenDecision::Forbid("Unsupported output data type."); } for (const HloInstruction* operand : instr.operands()) { const auto operand_type = operand->shape().element_type(); switch (instr.opcode()) { case HloOpcode::kConvert: - // TODO(b/358580281): remove DebugOptions from this function after - // enabling int4 in Triton GEMM. - if (operand_type == S4 && instr.GetModule() - ->config() - .debug_options() - .xla_gpu_enable_triton_gemm_int4()) { - continue; - } + if (operand_type == S4) continue; [[fallthrough]]; default: if (!IsTritonSupportedDataType(operand_type, gpu_version)) { - return "Unsupported input data type."; + return CodegenDecision::Forbid("Unsupported input data type."); } } } - return CodegenDecision{}; + return CodegenDecision::Allow(); } std::vector TritonSupportedUnaryElementwiseUpToFloatNormalization( @@ -211,12 +204,12 @@ CodegenDecision CanTritonHandleElementwise( return decision; } if (instr.opcode() == HloOpcode::kConstant) { - return CodegenDecision{}; + return CodegenDecision::Allow(); } else if (!IsTritonSupportedElementwiseUpToFloatNormalization( instr.opcode(), instr.operand(0)->shape().element_type())) { - return "Unsupported elementwise operation."; + return CodegenDecision::Forbid("Unsupported elementwise operation."); } - return CodegenDecision{}; + return CodegenDecision::Allow(); } bool IsDotAlgorithmSupportedByTriton( @@ -227,6 +220,7 @@ bool IsDotAlgorithmSupportedByTriton( auto rocm_compute_capability = std::get_if(&gpu_version); switch (algorithm) { + case PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3: case PrecisionConfig::ALG_DOT_TF32_TF32_F32: if (cuda_compute_capability) { return true; @@ -268,37 +262,40 @@ CodegenDecision CanTritonHandleGEMM( if (!tsl::tensor_float_32_execution_enabled() || absl::c_any_of(dot.precision_config().operand_precision(), [](int x) { return x != PrecisionConfig::DEFAULT; })) { - return "Having non-default operand precisions or TensorFloat-32 disabled " - "for Dot op with unset algorithm."; + return CodegenDecision::Forbid( + "Having non-default operand precisions or TensorFloat-32 disabled " + "for Dot op with unset algorithm."); } } else { if (!IsDotAlgorithmSupportedByTriton(dot.precision_config().algorithm(), gpu_version)) { - return "Unsupported algorithm on the current device(s)."; + return CodegenDecision::Forbid(absl::StrFormat( + "Unsupported algorithm on the current device(s): %s", + PrecisionConfig::Algorithm_Name(dot.precision_config().algorithm()))); } } // TODO(b/266862493): Support more output types. if (!IsTritonSupportedDotOutputType(dot.shape().element_type(), gpu_version)) { - return "Unsupported output data type for Dot op."; + return CodegenDecision::Forbid("Unsupported output data type for Dot op."); } if (!IsTritonSupportedDataType(dot.operand(0)->shape().element_type(), gpu_version) || !IsTritonSupportedDataType(dot.operand(1)->shape().element_type(), gpu_version)) { - return "Unsupported input data type for Dot op."; + return CodegenDecision::Forbid("Unsupported input data type for Dot op."); } const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); // TODO(b/269580541): support multiple batch dimensions. if (dim_numbers.lhs_batch_dimensions().size() > 1) { - return "Multiple batch dimensions."; + return CodegenDecision::Forbid("Multiple batch dimensions."); } - return CodegenDecision{}; + return CodegenDecision::Allow(); } bool NoNonContractingDimension(const HloDotInstruction& dot) { @@ -323,7 +320,7 @@ CodegenDecision IsTritonSupportedDynamicSlice( case S32: break; // supported default: - return CodegenDecision( + return CodegenDecision::Forbid( "Dynamic slice is only supported with S8, S16, or S32 indices."); } } @@ -341,14 +338,14 @@ CodegenDecision IsTritonSupportedDynamicSlice( if (i == majormost_dim_id) { continue; } else if (input->shape().dimensions(i) != instr.slice_sizes(i)) { - return CodegenDecision( + return CodegenDecision::Forbid( "Unsupported dynamic slice on non-major-most dimension."); } } // TODO(b/343143854): Check the subtleties of which dynamic slices are // supported, for example that a fragmented dimension cannot be sliced. - return CodegenDecision{}; + return CodegenDecision::Allow(); } CodegenDecision IsTritonSupportedInstruction( @@ -362,15 +359,15 @@ CodegenDecision IsTritonSupportedInstruction( auto* dot = Cast(&instr); // Cases where lhs or rhs have no non-contracting dims are not handled. if (NoNonContractingDimension(*dot)) { - return "No non-contracting dimensions."; + return CodegenDecision::Forbid("No non-contracting dimensions."); } return CanTritonHandleGEMM(*dot, gpu_version); } case HloOpcode::kTuple: { if (instr.IsRoot()) { - return CodegenDecision{}; + return CodegenDecision::Allow(); } - return "Only supports root tuples."; + return CodegenDecision::Forbid("Only supports root tuples."); } case HloOpcode::kDynamicSlice: { return IsTritonSupportedDynamicSlice( @@ -384,11 +381,11 @@ CodegenDecision IsTritonSupportedInstruction( case HloOpcode::kConcatenate: case HloOpcode::kParameter: case HloOpcode::kBroadcast: - return CodegenDecision{}; + return CodegenDecision::Allow(); default: break; } - return "Unsupported opcode."; + return CodegenDecision::Forbid("Unsupported opcode."); } } // namespace legacy_triton diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_test.cc index 6b21a8c7c82acc..3f926c70ef1825 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_test.cc @@ -86,6 +86,7 @@ bool DoesOpSupportType(HloOpcode opcode, PrimitiveType type) { case HloOpcode::kXor: case HloOpcode::kNot: return type == PRED || pu::IsIntegralType(type); + case HloOpcode::kAtan2: case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kExpm1: @@ -94,13 +95,13 @@ bool DoesOpSupportType(HloOpcode opcode, PrimitiveType type) { case HloOpcode::kRsqrt: case HloOpcode::kSin: case HloOpcode::kSqrt: - case HloOpcode::kCbrt: case HloOpcode::kTan: case HloOpcode::kTanh: case HloOpcode::kReal: case HloOpcode::kImag: case HloOpcode::kLogistic: return pu::IsFloatingPointType(type) || pu::IsComplexType(type); + case HloOpcode::kCbrt: case HloOpcode::kErf: case HloOpcode::kFloor: case HloOpcode::kCeil: @@ -120,7 +121,6 @@ bool DoesOpSupportType(HloOpcode opcode, PrimitiveType type) { return pu::IsSignedIntegralType(type) || pu::IsFloatingPointType(type) || pu::IsComplexType(type); case HloOpcode::kPower: - case HloOpcode::kAtan2: case HloOpcode::kDivide: case HloOpcode::kRemainder: case HloOpcode::kSubtract: @@ -246,6 +246,19 @@ ENTRY triton_computation { RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16}, cc); } +TEST_P(BitcastOrReshapeTest, IsTritonSupported0DBitcastOrReshape) { + auto [data_type, opcode, cc] = GetParam(); + const std::string kHloTestTemplate = R"( +ENTRY triton_computation { + parameter_0 = $0[1,1,1] parameter(0) + ROOT bitcast_or_reshape = $0[] $1(parameter_0) +})"; + TF_ASSERT_OK_AND_ASSIGN( + TestedInstruction ti, + ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{}, cc); +} + constexpr std::array kTestedOpsBitcastReshape = {HloOpcode::kBitcast, HloOpcode::kReshape}; @@ -445,6 +458,39 @@ ENTRY triton_computation { skip_failure_branch_to_avoid_crash); } +TEST_P(BinaryElementwiseTest, IsTritonSupportedBinaryElementwise0D) { + auto [data_type, opcode, cc] = GetParam(); + const std::string kHloTestTemplate = R"( +ENTRY triton_computation { + parameter_0 = $0[] parameter(0) + parameter_1 = $0[] parameter(1) + ROOT binary = $0[] $1(parameter_0, parameter_1) +})"; + + const std::string kHloCompareTestTemplate = R"( +ENTRY triton_computation { + parameter_0 = $0[] parameter(0) + parameter_1 = $0[] parameter(1) + ROOT compare = pred[] $1(parameter_0, parameter_1), direction=GE +})"; + + TF_ASSERT_OK_AND_ASSIGN( + TestedInstruction ti, + ParseTemplateAndGetInstruction(opcode == HloOpcode::kCompare + ? kHloCompareTestTemplate + : kHloTestTemplate, + data_type, opcode)); + + bool skip_failure_branch_to_avoid_crash = + opcode == HloOpcode::kDivide && + (data_type == PrimitiveType::BF16 || data_type == PrimitiveType::F16 || + data_type == PrimitiveType::F8E5M2 || + data_type == PrimitiveType::F8E4M3FN); + + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{}, cc, + skip_failure_branch_to_avoid_crash); +} + constexpr std::array kTestedOpsBinaryElementwise = { HloOpcode::kAnd, HloOpcode::kOr, @@ -1062,6 +1108,24 @@ INSTANTIATE_TEST_SUITE_P( using ConstantTest = TritonSupportTestWithTypeAndDeviceParam; +TEST_P(ConstantTest, ConstantEffectiveScalar) { + // The IsTritonSupportedReduction effectively tests the scalar constant + // support. + auto [data_type, cc] = GetParam(); + bool dtype_is_complex = data_type == C64 || data_type == C128; + const std::string kHloTestTemplate = + absl::Substitute(R"( +ENTRY triton_computation { + ROOT const = $0[1,1] constant({{$1}}) +})", + "$0", dtype_is_complex ? "(0, 0)" : "0"); + + TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti, ParseTemplateAndGetInstruction( + kHloTestTemplate, data_type, + HloOpcode::kConstant)); + RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 1}, cc); +} + TEST_P(ConstantTest, Constant2D) { // The IsTritonSupportedReduction effectively tests the scalar constant // support. @@ -1165,6 +1229,7 @@ constexpr std::array kUnsupportedOps = {HloOpcode::kAddDependency, HloOpcode::kOutfeed, HloOpcode::kPad, HloOpcode::kPartitionId, + HloOpcode::kRaggedAllToAll, HloOpcode::kRecv, HloOpcode::kRecvDone, HloOpcode::kReduceWindow, diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_test_utils.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_test_utils.cc index 6fa9635663999e..881fdbc89e634f 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_test_utils.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_test_utils.cc @@ -39,9 +39,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/transforms/simplifiers/float_normalization.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/primitive_util.h" -#include "xla/service/float_normalization.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" @@ -70,9 +70,10 @@ bool SupportsBF16(const stream_executor::GpuComputeCapability& cc) { CHECK(false); } -absl::Status CreateTritonIrAndFileCheck( - HloTestBase* test, absl::string_view hlo_text, - absl::string_view triton_fusion_name, absl::string_view filecheck_pattern) { +absl::Status CreateTritonIrAndFileCheck(HloTestBase* test, + absl::string_view hlo_text, + absl::string_view triton_fusion_name, + absl::string_view filecheck_pattern) { TF_ASSIGN_OR_RETURN(std::unique_ptr verified_module, test->ParseAndReturnVerifiedModule(hlo_text)); auto* comp = verified_module->GetComputationWithName(triton_fusion_name); @@ -235,7 +236,7 @@ absl::Status ConvertEntryToTritonFusion(HloModule* module) { } // namespace -DebugOptions TritonSupportTestBase::GetDebugOptionsForTest() { +DebugOptions TritonSupportTestBase::GetDebugOptionsForTest() const { auto options = HloTestBase::GetDebugOptionsForTest(); // It's necessary to set this manually, because it's disabled in optimized // builds and there are some ASAN builds that run on TAP with -c opt. diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_test_utils.h b/third_party/xla/xla/service/gpu/fusions/triton/triton_test_utils.h index 0a7ec78bc00432..b8e178fc6905ee 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_test_utils.h +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_test_utils.h @@ -72,7 +72,7 @@ absl::StatusOr ApplyFloatNormalization( class TritonSupportTestBase : public HloTestBase { protected: - DebugOptions GetDebugOptionsForTest() override; + DebugOptions GetDebugOptionsForTest() const override; // An HLO module together with a reference to the instruction of interest // that's being tested. See ParseTemplateAndGetInstruction for more details. diff --git a/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_dialect.td b/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_dialect.td new file mode 100644 index 00000000000000..4d8948a0e3c0b5 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_dialect.td @@ -0,0 +1,32 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_XLA_TRITON_DIALECT_TD_ +#define XLA_SERVICE_GPU_FUSIONS_TRITON_XLA_TRITON_DIALECT_TD_ + +include "mlir/IR/DialectBase.td" + +def XlaTritonDialect : Dialect { + let name = "triton_xla"; + + let description = [{ + This dialect contains ops included in the xla extension point for Triton. + }]; + + let cppNamespace = "::mlir::triton::xla"; +} + +#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_XLA_TRITON_DIALECT_TD_ diff --git a/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_ops.cc b/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_ops.cc new file mode 100644 index 00000000000000..56260cd43d781b --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_ops.cc @@ -0,0 +1,127 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/fusions/triton/xla_triton_ops.h" + +#include + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep +#include "llvm/Support/Casting.h" +#include "llvm/Support/LogicalResult.h" +#include "mlir/IR/Builders.h" // IWYU pragma: keep +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep +#include "mlir/IR/MLIRContext.h" // IWYU pragma: keep +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" // IWYU pragma: keep +#include "mlir/IR/Region.h" +#include "mlir/IR/TypeUtilities.h" // IWYU pragma: keep +#include "mlir/IR/ValueRange.h" +#include "xla/service/gpu/fusions/triton/xla_triton_dialect.cc.inc" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" + +using mlir::Dialect; +using mlir::DictionaryAttr; +using mlir::Location; +using mlir::LogicalResult; +using mlir::MLIRContext; +using mlir::OpaqueProperties; +using mlir::RankedTensorType; +using mlir::RegionRange; +using mlir::SmallVectorImpl; +using mlir::TensorOrMemDesc; +using mlir::Type; +using mlir::ValueRange; + +namespace mlir::triton::xla { + +void XlaTritonDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "xla/service/gpu/fusions/triton/xla_triton_ops.cc.inc" + >(); +} + +LogicalResult SparseDotOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + return DotOp::inferReturnTypes(context, location, operands, attributes, + properties, regions, inferredReturnTypes); +} + +LogicalResult SparseDotOp::verify() { + // Implied properties of 2:4 sparse dots. + constexpr int kContractingFactor = 2; + constexpr int kMetadataElementsPerPackedValue = 8; + // Verify operand A. + auto aTensorTy = llvm::cast(getOperand(0).getType()); + auto aElemTy = aTensorTy.getElementType(); + if (!aElemTy.isF16() && !aElemTy.isBF16()) + return emitError("element type of operand A is not supported"); + auto aShape = aTensorTy.getShape(); + if (aShape.size() != 2) return emitError("shape of operand A is incorrect"); + + // Verify operand B. + auto bTensorTy = llvm::cast(getOperand(1).getType()); + auto bElemTy = bTensorTy.getElementType(); + if (!bElemTy.isF16() && !bElemTy.isBF16()) + return emitError("element type of operand B is not supported"); + auto bShape = bTensorTy.getShape(); + if (bShape.size() != 2) return emitError("shape of operand B is incorrect"); + + // Verify operand C. + auto cTensorTy = llvm::cast(getOperand(2).getType()); + auto cElemTy = cTensorTy.getElementType(); + if (!cElemTy.isF32()) + return emitError("element type of operand C is not supported"); + auto cShape = cTensorTy.getShape(); + if (cShape.size() != 2) return emitError("shape of operand C is incorrect"); + + // Check operand dependencies. + if (aShape[0] != cShape[0] || bShape[1] != cShape[1] || + bShape[0] != aShape[1] * kContractingFactor) + return emitError("operand shape dimensions are incorrect"); + if (aElemTy != bElemTy) + return emitError("operand element types do not match"); + + // Verify sparse metadata. + auto metaTy = llvm::cast(getOperand(3).getType()); + auto metaShape = metaTy.getShape(); + if (!metaTy.getElementType().isInteger(16) || metaShape.size() != 2) + return emitError("sparse metadata tensor is invalid"); + if (metaShape[0] != aShape[0] || + metaShape[1] * kMetadataElementsPerPackedValue != aShape[1]) + return emitError("sparse metadata shape dimensions are incorrect"); + + // Verify tensor encoding. + auto aEncoding = aTensorTy.getEncoding(); + auto bEncoding = bTensorTy.getEncoding(); + if (!aEncoding && !bEncoding) return mlir::success(); + if (!aEncoding || !bEncoding) + return emitError("mismatching encoding between A and B operands"); + + Dialect &dialect = aEncoding.getDialect(); + auto interface = llvm::cast(&dialect); + return interface->verifyDotOpEncodingCompatibility(getOperation(), aEncoding, + bEncoding); +} + +} // namespace mlir::triton::xla + +#define GET_OP_CLASSES +#include "xla/service/gpu/fusions/triton/xla_triton_ops.cc.inc" diff --git a/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_ops.h b/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_ops.h new file mode 100644 index 00000000000000..ffa6dc460be1e0 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_ops.h @@ -0,0 +1,31 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_XLA_TRITON_OPS_H_ +#define XLA_SERVICE_GPU_FUSIONS_TRITON_XLA_TRITON_OPS_H_ + +#include "mlir/IR/Attributes.h" // IWYU pragma: keep +#include "mlir/IR/BuiltinTypes.h" // IWYU pragma: keep +#include "mlir/IR/Dialect.h" // IWYU pragma: keep +#include "mlir/IR/MLIRContext.h" // IWYU pragma: keep +#include "mlir/IR/OpDefinition.h" // IWYU pragma: keep +#include "mlir/IR/OpImplementation.h" // IWYU pragma: keep +#include "mlir/Interfaces/InferTypeOpInterface.h" // IWYU pragma: keep +#include "mlir/Interfaces/SideEffectInterfaces.h" // IWYU pragma: keep +#include "xla/service/gpu/fusions/triton/xla_triton_dialect.h.inc" // IWYU pragma: keep +#include "triton/Dialect/Triton/IR/Dialect.h" // IWYU pragma: keep +#define GET_OP_CLASSES +#include "xla/service/gpu/fusions/triton/xla_triton_ops.h.inc" + +#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_XLA_TRITON_OPS_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_ops.td b/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_ops.td new file mode 100644 index 00000000000000..284a2757b92c87 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/xla_triton_ops.td @@ -0,0 +1,49 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + + +#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_XLA_TRITON_OPS_TD_ +#define XLA_SERVICE_GPU_FUSIONS_TRITON_XLA_TRITON_OPS_TD_ + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" // Pure +include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType +include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td" +include "xla/service/gpu/fusions/triton/xla_triton_dialect.td" +include "triton/Dialect/Triton/IR/TritonTypes.td" + +class TTXLA_Op traits = []> : + Op { +} + +def TTXLA_SparseDotOp : TTXLA_Op<"sparse_dot", [ + Pure, DeclareOpInterfaceMethods, + TypesMatchWith<"result's type matches accumulator's type", "d", "c", "$_self">]> { + let summary = "sparse dot"; + + let arguments = (ins + TT_TensorOrMemDesc:$a, + TT_TensorOrMemDesc:$b, + TT_FpIntTensor:$c, + TT_IntTensor: $aMeta); + let results = (outs TT_FpIntTensor:$d); + let assemblyFormat = [{ + $a`,` $b`,` $c`,` $aMeta attr-dict + `:` type($a) `meta` type($aMeta) `*` type($b) `->` type($d) + }]; + let hasVerifier = 1; +} + +#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_XLA_TRITON_OPS_TD_ diff --git a/third_party/xla/xla/service/gpu/gpu_collective_combiner_utils.cc b/third_party/xla/xla/service/gpu/gpu_collective_combiner_utils.cc new file mode 100644 index 00000000000000..d72666468cd5a1 --- /dev/null +++ b/third_party/xla/xla/service/gpu/gpu_collective_combiner_utils.cc @@ -0,0 +1,79 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/service/collective_utils.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/stream_executor/device_description.h" +#include "tsl/platform/errors.h" + +namespace xla::gpu { + +using MemoryAwareScheduler = std::function( + const HloModule*, int64_t, int64_t*)>; + +namespace { + +int64_t GetDefaultValue(HloOpcode opcode) { + if (opcode == HloOpcode::kAllGather) { + return kDefaultAllGatherCombineThreshold; + } else if (opcode == HloOpcode::kAllReduce) { + return kDefaultAllReduceCombineThreshold; + } else if (opcode == HloOpcode::kReduceScatter) { + return kDefaultReduceScatterCombineThreshold; + } else { + LOG(FATAL) << "Expected collective op. Got: " << opcode; + } + return -1; +} + +} // namespace + +int64_t ComputeSuggestedCombinerThreshold( + const HloModule& module, const se::DeviceDescription& device_info, + MemoryAwareScheduler scheduler, HloOpcode collective_opcode, + int64_t pointer_size) { + int64_t base_limit = module.config().device_memory_size() != 0 + ? module.config().device_memory_size() + : device_info.device_memory_size(); + int64_t peak_memory_bytes = -1; + auto mem_schedule = scheduler(&module, pointer_size, &peak_memory_bytes); + + if (!mem_schedule.ok() || peak_memory_bytes == -1) { + VLOG(1) << "Cannot schedule module: " << mem_schedule.status().message(); + return GetDefaultValue(collective_opcode); + } + + int32_t slop_factor = + module.config().debug_options().xla_gpu_memory_limit_slop_factor(); + return base_limit * slop_factor / 100 - peak_memory_bytes; +} + +absl::Status AppendPipelinedInstruction(HloInstruction* instr) { + auto config = instr->backend_config(); + config->mutable_collective_backend_config()->set_is_pipelined(true); + TF_RETURN_IF_ERROR(instr->set_backend_config(*config)); + return absl::OkStatus(); +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/gpu_collective_combiner_utils.h b/third_party/xla/xla/service/gpu/gpu_collective_combiner_utils.h new file mode 100644 index 00000000000000..171f132dae541a --- /dev/null +++ b/third_party/xla/xla/service/gpu/gpu_collective_combiner_utils.h @@ -0,0 +1,51 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_GPU_COLLECTIVE_COMBINER_UTILS_H_ +#define XLA_SERVICE_GPU_GPU_COLLECTIVE_COMBINER_UTILS_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/stream_executor/device_description.h" + +namespace xla::gpu { + +// Suggests a combiner threshold to the caller (combiner). At the moment it only +// suggests a lower value than a default combiner threshold if it exceeds +// available memory on a device. If the scheduling of a `module` failed for any +// reason the method return a default value of a combiner threshold for +// `collective_opcode`. +int64_t ComputeSuggestedCombinerThreshold( + const HloModule& module, const se::DeviceDescription& device_info, + std::function(const HloModule*, int64_t, + int64_t*)> + scheduler, + HloOpcode collective_opcode, int64_t pointer_size); + +// Adds information that `instr` has been pipelined to the +// `CollectiveBackendInfo`. It is up to the caller to decide when to invoke +// this. +absl::Status AppendPipelinedInstruction(HloInstruction* instr); + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_GPU_COLLECTIVE_COMBINER_UTILS_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_collective_combiner_utils_test.cc b/third_party/xla/xla/service/gpu/gpu_collective_combiner_utils_test.cc new file mode 100644 index 00000000000000..753945761b4f9b --- /dev/null +++ b/third_party/xla/xla/service/gpu/gpu_collective_combiner_utils_test.cc @@ -0,0 +1,322 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/gpu_collective_combiner_utils.h" + +#include +#include + +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/service/collective_pipeliner.h" +#include "xla/service/collective_utils.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/gpu_hlo_schedule.h" +#include "xla/service/hlo_module_config.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla::gpu { +namespace { + +using CollectiveCombinerUtilsTest = HloTestBase; + +TEST_F(CollectiveCombinerUtilsTest, + ComputeSuggestedCombinerThresholdReturnsMemoryThresholdForDeviceInfo) { + absl::string_view kHloText = R"( + HloModule m + + ENTRY ar { + p0 = f32[32,32] parameter(0) + p1 = f32[32,32] parameter(1) + + ROOT _ = f32[32,32]{1,0} custom-call(p0, p1), + custom_call_target="__cublas$gemm" + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + int pointer_size = 4; + stream_executor::DeviceDescription device_info; + device_info.set_device_memory_size(20000); + + int64_t suggested_threshold = ComputeSuggestedCombinerThreshold( + *module, device_info, gpu::ScheduleGpuModuleWithMemoryScheduler, + HloOpcode::kAllReduce, pointer_size); + + // device size = 20000 bytes + // slop factor = 0.95 + // peak memory = parameters + output = (2*32*32 + 32*32) * 4 bytes = 12288 + // suggested thresholds = device size * slop factor - peak memory + EXPECT_EQ(suggested_threshold, 6712); +} + +TEST_F(CollectiveCombinerUtilsTest, + ComputeSuggestedCombinerThresholdReturnsMemoryThresholdForModuleConfig) { + absl::string_view kHloText = R"( + HloModule m + + ENTRY ar { + p0 = f32[32,32] parameter(0) + p1 = f32[32,32] parameter(1) + + ROOT _ = f32[32,32]{1,0} custom-call(p0, p1), + custom_call_target="__cublas$gemm" + })"; + + HloModuleConfig config = GetModuleConfigForTest(); + config.set_device_memory_size(20000); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloText, config)); + int pointer_size = 4; + stream_executor::DeviceDescription device_info; + + int64_t suggested_threshold = ComputeSuggestedCombinerThreshold( + *module, device_info, gpu::ScheduleGpuModuleWithMemoryScheduler, + HloOpcode::kAllReduce, pointer_size); + + // device size = 20000 bytes + // slop factor = 0.95 + // peak memory = parameters + output = (2*32*32 + 32*32) * 4 bytes = 12288 + // suggested thresholds = device size * slop factor - peak memory + EXPECT_EQ(suggested_threshold, 6712); +} + +TEST_F( + CollectiveCombinerUtilsTest, + ComputeSuggestedCombinerThresholdReturnsDefaultValueUponSchedulingFailure) { + absl::string_view kHloText = R"( + HloModule m + + ENTRY ar { + p0 = f32[32,32] parameter(0) + p1 = f32[32,32] parameter(1) + + ROOT _ = f32[32,32]{1,0} custom-call(p0, p1), + custom_call_target="__cublas$gemm" + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + int pointer_size = 4; + stream_executor::DeviceDescription device_info; + device_info.set_device_memory_size(20000); + + auto sched_fun = [](const HloModule* m, int64_t p_sz, + int64_t* p) -> absl::StatusOr { + return absl::UnimplementedError("Fail."); + }; + + int64_t suggested_threshold_all_reduce = ComputeSuggestedCombinerThreshold( + *module, device_info, sched_fun, HloOpcode::kAllReduce, pointer_size); + int64_t suggested_threshold_all_gather = ComputeSuggestedCombinerThreshold( + *module, device_info, sched_fun, HloOpcode::kAllGather, pointer_size); + int64_t suggested_threshold_reduce_scatter = + ComputeSuggestedCombinerThreshold(*module, device_info, sched_fun, + HloOpcode::kReduceScatter, + pointer_size); + + EXPECT_EQ(suggested_threshold_all_reduce, kDefaultAllReduceCombineThreshold); + EXPECT_EQ(suggested_threshold_all_gather, kDefaultAllGatherCombineThreshold); + EXPECT_EQ(suggested_threshold_reduce_scatter, + kDefaultReduceScatterCombineThreshold); +} + +TEST_F(CollectiveCombinerUtilsTest, + AppendPipelinedInstructionAppendsPipelinedInstructionInfoForward) { + // This is just a canonical IR which makes it easy to pipeline a collective + // forward – in this example AllReduce. + absl::string_view kHloText = R"( + HloModule module + add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) + } + + while_cond { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(3) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT + } + + while_body { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + current-loop-index = s32[] get-tuple-element(param), index=0 + output-buffer = bf16[3,8,128] get-tuple-element(param), index=1 + input-buffer = bf16[3,8,128] get-tuple-element(param), index=2 + constant.1 = s32[] constant(1) + next-loop-index = s32[] add(current-loop-index, constant.1) + constant.0 = s32[] constant(0) + sliced-input-buffer = bf16[1,8,128] dynamic-slice(input-buffer, + current-loop-index, constant.0, constant.0), dynamic_slice_sizes={1,8,128} + all-reduce = bf16[1,8,128] all-reduce(sliced-input-buffer), + replica_groups={}, to_apply=add, channel_id=1 + bitcast.0 = bf16[3,8,128] bitcast(all-reduce) + dynamic-update-slice = bf16[3,8,128] dynamic-update-slice(output-buffer, + bitcast.0, current-loop-index, constant.0, constant.0) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(next-loop-index, + dynamic-update-slice, input-buffer) + } + + ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0) + while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), + condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + // This config is taken from the gpu_compiler.cc configuration of the forward + // pipeliner. + CollectivePipeliner::Config config{ + /*level_to_operate_on=*/0, + /*max_pipelining_per_loop=*/INT64_MAX, + /*last_run=*/true, + /*pipeline_use_tree=*/false, + /*process_different_sized_ops=*/true, + /*pipelining_direction=*/ + CollectivePipeliner::PipeliningDirection::kForward, + /*should_process=*/HloPredicateIsOp, + /*acceptable_formatting=*/HloPredicateTrue, + /*reuse_pipelined_op_buffer=*/HloPredicateFalse, + }; + config.postprocess_pipelined_ops = AppendPipelinedInstruction; + + HloPassPipeline pipeline("collective-pipeliner"); + pipeline.AddPass(config); + pipeline.AddPass(/*remove_cross_partition_collective_ops=*/true); + TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get())); + EXPECT_TRUE(changed); + + hlo_query::ForEachInstructionWithOpcode( + *module, HloOpcode::kAllReduce, [](HloInstruction* instr) { + EXPECT_TRUE(instr->backend_config() + ->collective_backend_config() + .is_pipelined()); + }); + + hlo_query::ForEachInstructionWithPred( + *module, HloPredicateIsNotOp, + [](HloInstruction* instr) { + EXPECT_FALSE(instr->backend_config() + ->collective_backend_config() + .is_pipelined()); + }); +} + +TEST_F(CollectiveCombinerUtilsTest, + AppendPipelinedInstructionAppendsPipelinedInstructionInfoBackward) { + // This is just the simple IR which makes it easy for the pipeliner to + // pipeline a collective. The pipelined collective is AllGather so the main + // complexity comes from a fact that we have to slice it at the end of the + // loop (so that we can gather it again in the next iteration). + absl::string_view kHloText = R"( + HloModule module + + while_cond { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(3) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT + } + + while_body { + param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0) + current-loop-index = s32[] get-tuple-element(param), index=0 + output-buffer = bf16[3,8,128] get-tuple-element(param), index=1 + input-buffer = bf16[3,8,128] get-tuple-element(param), index=2 + constant.1 = s32[] constant(1) + next-loop-index = s32[] add(current-loop-index, constant.1) + constant.0 = s32[] constant(0) + sliced-input-buffer = bf16[1,8,128] dynamic-slice(input-buffer, + current-loop-index, constant.0, constant.0), dynamic_slice_sizes={1,8,128} + all-gather = bf16[3,8,128] all-gather(sliced-input-buffer), dimensions={0} + dynamic-update-slice = bf16[3,8,128] dynamic-update-slice(output-buffer, + all-gather, current-loop-index, constant.0, constant.0) + ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(next-loop-index, + dynamic-update-slice, input-buffer) + } + + ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[3,8,128] parameter(0) + tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0) + while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), + condition=while_cond, body=while_body + ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1 + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + // This config is taken from the gpu_compiler.cc configuration of the backward + // pipeliner. + CollectivePipeliner::Config config{ + /*level_to_operate_on=*/0, + /*max_pipelining_per_loop=*/INT64_MAX, + /*last_run=*/true, + /*pipeline_use_tree=*/false, + /*process_different_sized_ops=*/true, + /*pipelining_direction=*/ + CollectivePipeliner::PipeliningDirection::kBackward, + /*should_process=*/HloPredicateIsOp, + /*acceptable_formatting=*/HloPredicateTrue, + /*reuse_pipelined_op_buffer=*/HloPredicateFalse, + /*should_allow_loop_variant_parameter_in_chain=*/HloPredicateFalse, + /*should_allow_control_dependencies=*/false, + /*postprocess_backward_peeled_op=*/std::nullopt, + /*postprocess_backward_rotated_op=*/std::nullopt, + /*should_add_loop_invariant_op_in_chain=*/true, + }; + config.postprocess_pipelined_ops = AppendPipelinedInstruction; + + HloPassPipeline pipeline("collective-pipeliner"); + pipeline.AddPass(config); + pipeline.AddPass(/*remove_cross_partition_collective_ops=*/true); + TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get())); + EXPECT_TRUE(changed); + + hlo_query::ForEachInstructionWithOpcode( + *module, HloOpcode::kAllGather, [](HloInstruction* instr) { + EXPECT_TRUE(instr->backend_config() + ->collective_backend_config() + .is_pipelined()); + }); + + hlo_query::ForEachInstructionWithPred( + *module, HloPredicateIsNotOp, + [](HloInstruction* instr) { + EXPECT_FALSE(instr->backend_config() + ->collective_backend_config() + .is_pipelined()); + }); +} + +} // namespace +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc old mode 100644 new mode 100755 index 86f07cfb649404..8adc8747bd169b --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -61,7 +60,9 @@ limitations under the License. #include "mlir/IR/Diagnostics.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LLVM.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" @@ -70,57 +71,88 @@ limitations under the License. #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/pass/hlo_pass_fix.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/transforms/collectives/all_gather_broadcast_reorder.h" +#include "xla/hlo/transforms/collectives/all_gather_combiner.h" +#include "xla/hlo/transforms/collectives/all_reduce_combiner.h" +#include "xla/hlo/transforms/collectives/all_reduce_contiguous.h" +#include "xla/hlo/transforms/collectives/async_collective_creator.h" +#include "xla/hlo/transforms/collectives/collective_quantizer.h" +#include "xla/hlo/transforms/collectives/collectives_schedule_linearizer.h" +#include "xla/hlo/transforms/convert_memory_placement_to_internal_annotations.h" +#include "xla/hlo/transforms/expanders/bitcast_dtypes_expander.h" +#include "xla/hlo/transforms/expanders/comparison_expander.h" +#include "xla/hlo/transforms/expanders/convolution_4d_expander.h" +#include "xla/hlo/transforms/expanders/convolution_pred_expander.h" +#include "xla/hlo/transforms/expanders/dot_decomposer.h" +#include "xla/hlo/transforms/expanders/dynamic_index_splitter.h" +#include "xla/hlo/transforms/expanders/eigh_expander.h" +#include "xla/hlo/transforms/expanders/logistic_expander.h" +#include "xla/hlo/transforms/expanders/optimization_barrier_expander.h" +#include "xla/hlo/transforms/expanders/qr_expander.h" +#include "xla/hlo/transforms/expanders/real_imag_expander.h" +#include "xla/hlo/transforms/expanders/reduce_decomposer.h" +#include "xla/hlo/transforms/expanders/reshape_decomposer.h" +#include "xla/hlo/transforms/expanders/rng_bit_generator_expander.h" +#include "xla/hlo/transforms/expanders/rng_expander.h" +#include "xla/hlo/transforms/expanders/stable_sort_expander.h" +#include "xla/hlo/transforms/expanders/stochastic_convert_decomposer.h" +#include "xla/hlo/transforms/host_offload_legalize.h" +#include "xla/hlo/transforms/host_offloader.h" +#include "xla/hlo/transforms/operand_upcaster.h" +#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" +#include "xla/hlo/transforms/simplifiers/all_reduce_folder.h" +#include "xla/hlo/transforms/simplifiers/broadcast_canonicalizer.h" +#include "xla/hlo/transforms/simplifiers/conditional_canonicalizer.h" +#include "xla/hlo/transforms/simplifiers/convert_mover.h" +#include "xla/hlo/transforms/simplifiers/dot_merger.h" +#include "xla/hlo/transforms/simplifiers/dynamic_dimension_simplifier.h" +#include "xla/hlo/transforms/simplifiers/flatten_call_graph.h" +#include "xla/hlo/transforms/simplifiers/float_normalization.h" +#include "xla/hlo/transforms/simplifiers/gather_simplifier.h" +#include "xla/hlo/transforms/simplifiers/hlo_computation_deduplicator.h" +#include "xla/hlo/transforms/simplifiers/hlo_constant_folding.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/hlo_rematerialization.h" +#include "xla/hlo/transforms/simplifiers/host_memory_transfer_asyncifier.h" +#include "xla/hlo/transforms/simplifiers/optimize_input_output_buffer_alias.h" +#include "xla/hlo/transforms/simplifiers/reduce_window_rewriter.h" +#include "xla/hlo/transforms/simplifiers/reshape_mover.h" +#include "xla/hlo/transforms/simplifiers/result_caster.h" +#include "xla/hlo/transforms/simplifiers/simplify_fp_conversions.h" +#include "xla/hlo/transforms/simplifiers/slice_sinker.h" +#include "xla/hlo/transforms/simplifiers/sort_simplifier.h" +#include "xla/hlo/transforms/simplifiers/sub_byte_normalization.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" +#include "xla/hlo/transforms/simplifiers/zero_sized_hlo_elimination.h" +#include "xla/hlo/transforms/while_loop_trip_count_annotator.h" #include "xla/maybe_owning.h" -#include "xla/service/algebraic_simplifier.h" -#include "xla/service/all_gather_broadcast_reorder.h" -#include "xla/service/all_gather_combiner.h" -#include "xla/service/all_reduce_combiner.h" -#include "xla/service/all_reduce_contiguous.h" -#include "xla/service/all_reduce_folder.h" #include "xla/service/all_reduce_promotion.h" #include "xla/service/all_reduce_reassociate.h" -#include "xla/service/async_collective_creator.h" +#include "xla/service/all_reduce_simplifier.h" #include "xla/service/batched_gather_scatter_normalizer.h" #include "xla/service/batchnorm_expander.h" -#include "xla/service/bitcast_dtypes_expander.h" -#include "xla/service/broadcast_canonicalizer.h" #include "xla/service/buffer_assignment.h" #include "xla/service/call_inliner.h" #include "xla/service/collective_permute_decomposer.h" #include "xla/service/collective_pipeliner.h" -#include "xla/service/collective_quantizer.h" -#include "xla/service/collectives_schedule_linearizer.h" -#include "xla/service/comparison_expander.h" #include "xla/service/compiler.h" -#include "xla/service/conditional_canonicalizer.h" #include "xla/service/conditional_simplifier.h" -#include "xla/service/convert_memory_placement_to_internal_annotations.h" -#include "xla/service/convert_mover.h" -#include "xla/service/convolution_4d_expander.h" -#include "xla/service/convolution_pred_expander.h" #include "xla/service/copy_insertion.h" #include "xla/service/cpu_gpu_shape_verifier.h" -#include "xla/service/dot_decomposer.h" -#include "xla/service/dot_merger.h" #include "xla/service/dump.h" #include "xla/service/dynamic_dimension_inference.h" -#include "xla/service/dynamic_dimension_simplifier.h" -#include "xla/service/dynamic_index_splitter.h" #include "xla/service/dynamic_padder.h" -#include "xla/service/eigh_expander.h" #include "xla/service/executable.h" #include "xla/service/export_hlo.h" -#include "xla/service/flatten_call_graph.h" -#include "xla/service/float_normalization.h" #include "xla/service/float_support.h" #include "xla/service/gather_expander.h" -#include "xla/service/gather_simplifier.h" #include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.h" #include "xla/service/gpu/compile_module_to_llvm_ir.h" #include "xla/service/gpu/conv_layout_normalization.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/execution_stream_assignment.h" +#include "xla/service/gpu/fusion_dispatch_pipeline.h" #include "xla/service/gpu/fusion_pipeline.h" #include "xla/service/gpu/fusions/triton/triton_support.h" #include "xla/service/gpu/gpu_executable.h" @@ -157,6 +189,7 @@ limitations under the License. #include "xla/service/gpu/transforms/convert_async_collectives_to_sync.h" #include "xla/service/gpu/transforms/cudnn_custom_call_converter.h" #include "xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h" +#include "xla/service/gpu/transforms/dot_algorithm_rewriter.h" #include "xla/service/gpu/transforms/dot_dimension_sorter.h" #include "xla/service/gpu/transforms/dot_operand_converter.h" #include "xla/service/gpu/transforms/double_buffer_loop_unrolling.h" @@ -188,54 +221,25 @@ limitations under the License. #include "xla/service/gpu/transforms/triton_fusion_numerics_verifier.h" #include "xla/service/gpu/transforms/windowed_einsum_handler.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_computation_deduplicator.h" -#include "xla/service/hlo_constant_folding.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_cse.h" -#include "xla/service/hlo_dataflow_analysis.h" -#include "xla/service/hlo_dce.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_rematerialization.h" #include "xla/service/hlo_verifier.h" -#include "xla/service/host_memory_transfer_asyncifier.h" -#include "xla/service/host_offload_legalize.h" -#include "xla/service/host_offloader.h" #include "xla/service/layout_assignment.h" #include "xla/service/layout_normalization.h" #include "xla/service/llvm_ir/llvm_util.h" -#include "xla/service/logistic_expander.h" -#include "xla/service/operand_upcaster.h" -#include "xla/service/optimization_barrier_expander.h" -#include "xla/service/optimize_input_output_buffer_alias.h" -#include "xla/service/qr_expander.h" -#include "xla/service/real_imag_expander.h" -#include "xla/service/reduce_decomposer.h" #include "xla/service/reduce_scatter_combiner.h" #include "xla/service/reduce_scatter_reassociate.h" -#include "xla/service/reduce_window_rewriter.h" -#include "xla/service/reshape_decomposer.h" -#include "xla/service/reshape_mover.h" -#include "xla/service/result_caster.h" -#include "xla/service/rng_bit_generator_expander.h" -#include "xla/service/rng_expander.h" +#include "xla/service/scatter_determinism_expander.h" #include "xla/service/scatter_expander.h" #include "xla/service/scatter_simplifier.h" #include "xla/service/sharding_remover.h" -#include "xla/service/simplify_fp_conversions.h" -#include "xla/service/slice_sinker.h" #include "xla/service/slow_operation_alarm.h" -#include "xla/service/sort_simplifier.h" -#include "xla/service/stable_sort_expander.h" -#include "xla/service/stochastic_convert_decomposer.h" -#include "xla/service/sub_byte_normalization.h" #include "xla/service/topk_rewriter.h" #include "xla/service/transpose_folding.h" -#include "xla/service/tuple_simplifier.h" #include "xla/service/while_loop_all_reduce_code_motion.h" #include "xla/service/while_loop_constant_sinking.h" #include "xla/service/while_loop_simplifier.h" -#include "xla/service/while_loop_trip_count_annotator.h" -#include "xla/service/zero_sized_hlo_elimination.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" @@ -474,8 +478,10 @@ GpuCompiler::GpuCompiler(se::Platform::Id platform_id, namespace { // Adds the HloVerifier for GPU to the given pipeline. -void AddHloVerifier(HloPassPipeline* pipeline, HloVerifierOpts&& opts = {}, - bool debug_only = false) { +void AddHloVerifier(HloPassPipeline* pipeline, + bool verify_unique_channel_ids = false, + HloVerifierOpts&& opts = {}, bool debug_only = false) { + opts.verify_unique_channel_ids = verify_unique_channel_ids; std::unique_ptr verifier_metadata = std::make_unique(std::move(opts)); if (debug_only) { @@ -522,6 +528,10 @@ AlgebraicSimplifierOptions LayoutInsensitiveAlgebraicSimplifierOptions( // GPU only supports canonical convolutions. layout_insensitive_algsimp_opts.set_supports_non_canonical_dots(false); + // On GPU it helps to reorder them so that the fused cuDNN kernel can be + // used. + layout_insensitive_algsimp_opts.set_enable_conv_add_multiply_reorder(true); + // "slow" minmax means we propagate nan. layout_insensitive_algsimp_opts.set_minmax_propagate_nan( !hlo_module_config.debug_options().xla_gpu_enable_fast_min_max()); @@ -648,7 +658,8 @@ absl::Status RunOptimizationPasses( const DebugOptions& debug_options = hlo_module->config().debug_options(); HloPassPipeline pipeline("optimization"); - AddHloVerifier(&pipeline); + AddHloVerifier(&pipeline, + !debug_options.xla_experimental_ignore_channel_id()); if (debug_options.xla_gpu_multi_streamed_windowed_einsum()) { pipeline.AddPass(); } @@ -692,6 +703,7 @@ absl::Status RunOptimizationPasses( if (RequireDeterminism(hlo_module->config())) { // Scatter can be indeterministic if indices are not unique or a non // associative combiner function is used. Eliminate these Scatter ops. + pipeline.AddPass(); pipeline.AddPass( ScatterExpander::kEliminateIndeterministicScatters); } @@ -768,7 +780,9 @@ absl::Status RunOptimizationPasses( // point. [&, &pipeline = pipeline.AddPass>("simplification")] { - AddHloVerifier(&pipeline, HloVerifierOpts{}, /*debug_only=*/true); + AddHloVerifier(&pipeline, + !debug_options.xla_experimental_ignore_channel_id(), + HloVerifierOpts{}, /*debug_only=*/true); // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. @@ -786,9 +800,21 @@ absl::Status RunOptimizationPasses( // AlgebraicSimplifier may add contracting dimensions to a dot. pipeline.AddPass(); pipeline.AddPass(); - // Only merge "smallish" dots. This threshold was not set carefully, but - // so far we know that 1mb is too small. - pipeline.AddPass(/*max_size_to_merge=*/int64_t{32} << 20); + // Only merge "smallish" dots. This threshold defaults to 32MB today, with + // a flag to override. + // Do not merge dots when they are assigned different stream ids. + std::function + can_merge = [&](const HloInstruction* dot_a, + const HloInstruction* dot_b) -> bool { + return dot_a->backend_config()->operation_queue_id() == + dot_b->backend_config()->operation_queue_id(); + }; + pipeline.AddPass( + /*max_size_to_merge=*/int64_t{debug_options + .xla_gpu_dot_merger_threshold_mb()} + << 20, + can_merge); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); @@ -822,8 +848,34 @@ absl::Status RunOptimizationPasses( return pipeline.Run(hlo_module).status(); } -absl::Status AddCollectivePipelinerPasses( - const DebugOptions& debug_options, HloPassPipeline& collectives_pipeline) { +absl::Status RunCollectiveOptimizationPasses( + HloModule* hlo_module, + const AlgebraicSimplifierOptions& layout_insensitive_algsimp_opts, + se::GpuComputeCapability gpu_version) { + // Optimize collectives generated by SPMD partitioning. Enable these passes + // otherwise as well so that all collectives can get these optimizations. + const DebugOptions& debug_options = hlo_module->config().debug_options(); + + HloPassPipeline collectives_pipeline("collective-optimizations"); + collectives_pipeline.AddPass(); + collectives_pipeline.AddPass(); + collectives_pipeline.AddPass(); + collectives_pipeline.AddPass(); + collectives_pipeline.AddPass(); + collectives_pipeline.AddPass( + debug_options.xla_gpu_enable_reassociation_for_converted_ar()); + collectives_pipeline.AddPass(); + + collectives_pipeline.AddPass( + /*enable_reduce_scatter=*/debug_options + .xla_gpu_enable_while_loop_reduce_scatter_code_motion()); + + // Moves collectives' subsequent quantization before the collective to + // minimize data transfers. + collectives_pipeline.AddPass(); + // Remove dead computations after collective quantization. + collectives_pipeline.AddPass(); + if (debug_options.xla_gpu_enable_pipelined_collectives() || debug_options.xla_gpu_enable_pipelined_all_reduce()) { CollectivePipeliner::Config config{ @@ -875,49 +927,7 @@ absl::Status AddCollectivePipelinerPasses( /*reuse_pipelined_op_buffer=*/HloPredicateFalse}; collectives_pipeline.AddPass(config); } - return absl::OkStatus(); -} - -absl::Status RunPostLayoutCollectivePipelinerPasses(HloModule* hlo_module) { - const DebugOptions& debug_options = hlo_module->config().debug_options(); - HloPassPipeline collectives_pipeline("collective-pipeliner-optimizations"); - if (debug_options.xla_gpu_run_post_layout_collective_pipeliner()) { - TF_RETURN_IF_ERROR( - AddCollectivePipelinerPasses(debug_options, collectives_pipeline)); - // We call WhileLoopTripCountAnnotator at the end of the collective - // pipeline, which might have changed the loop trip count. - collectives_pipeline.AddPass(); - // Flatten call graph after loop peeling. - collectives_pipeline.AddPass(); - } - return collectives_pipeline.Run(hlo_module).status(); -} - -absl::Status RunCollectiveOptimizationPasses( - HloModule* hlo_module, - const AlgebraicSimplifierOptions& layout_insensitive_algsimp_opts, - se::GpuComputeCapability gpu_version) { - // Optimize collectives generated by SPMD partitioning. Enable these passes - // otherwise as well so that all collectives can get these optimizations. - const DebugOptions& debug_options = hlo_module->config().debug_options(); - HloPassPipeline collectives_pipeline("collective-optimizations"); - collectives_pipeline.AddPass(); - collectives_pipeline.AddPass(); - collectives_pipeline.AddPass(); - collectives_pipeline.AddPass(); - collectives_pipeline.AddPass( - debug_options.xla_gpu_enable_reassociation_for_converted_ar()); - collectives_pipeline.AddPass(); - - collectives_pipeline.AddPass( - /*enable_reduce_scatter=*/debug_options - .xla_gpu_enable_while_loop_reduce_scatter_code_motion()); - - if (!debug_options.xla_gpu_run_post_layout_collective_pipeliner()) { - TF_RETURN_IF_ERROR( - AddCollectivePipelinerPasses(debug_options, collectives_pipeline)); - } collectives_pipeline.AddPass(); collectives_pipeline.AddPass( @@ -952,12 +962,6 @@ absl::Status RunCollectiveOptimizationPasses( // Remove dead computations left over after ar/rs promotion. collectives_pipeline.AddPass(); - // Moves collectives' subsequent quantization before the collective to - // minimize data transfers. - collectives_pipeline.AddPass(); - // Remove dead computations after collective quantization. - collectives_pipeline.AddPass(); - // Run WhileLoopTripCountAnnotator after collective pipelining and before // layout assignment and fusion.This pass does some pattern-matching on // while bodies/conditions, and this is where the HLO is "nicest". @@ -993,6 +997,12 @@ absl::Status RunLayoutAssignmentPasses(HloModule* hlo_module, pipeline.AddPass( SubByteNormalization::SET_ELEMENT_SIZE); pipeline.AddPass(true); + // Run HostOffloadLegalize before LayoutNormalization to prevent + // the creation of invalid transpose/bitcast operations within + // host memory offloading segments. + pipeline.AddPass( + static_cast(stream_executor::MemoryType::kHost), + /* after_layout= */ true); return pipeline.Run(hlo_module).status(); } @@ -1033,6 +1043,43 @@ absl::Status RunFusionPasses(HloModule* hlo_module, return absl::OkStatus(); } +// Adds unrolling while loop optimization. Mostly to get rid of extra D2D +// copies, but also there are some performance benefits (better comm-compute +// overlap) when collectives are present within a while loop. +void AddDoubleBufferingPasses(const DebugOptions& opts, + HloPassPipeline& pipeline) { + std::optional unroll_strategy = + std::nullopt; + // Support old flag. + if (opts.xla_gpu_enable_while_loop_double_buffering()) { + unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer; + } + // Support new flag setting style, override the old one. + if (opts.xla_gpu_enable_while_loop_unrolling() == + DebugOptions::WHILE_LOOP_UNROLLING_DOUBLE_BUFFER) { + unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer; + } + if (opts.xla_gpu_enable_while_loop_unrolling() == + DebugOptions::WHILE_LOOP_UNROLLING_FULL_UNROLL) { + LOG_IF(WARNING, unroll_strategy != std::nullopt) + << "Overriding double buffering set via " + "`xla_gpu_enable_while_loop_double_buffering` flag."; + unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll; + } + if (opts.xla_gpu_enable_while_loop_unrolling() == + DebugOptions::WHILE_LOOP_UNROLLING_AUTO_UNROLL && + opts.xla_gpu_enable_heuristic_pass_configuration() && + !opts.xla_gpu_enable_while_loop_double_buffering()) { + unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kAuto; + } + if (unroll_strategy != std::nullopt) { + pipeline.AddPass(); + pipeline.AddPass(*unroll_strategy); + pipeline.AddPass(); + pipeline.AddPass(); + } +} + absl::Status RunPostFusionPasses( HloModule* hlo_module, std::function @@ -1065,29 +1112,7 @@ absl::Status RunPostFusionPasses( pipeline.AddPass(blueconnect_num_devices_per_host); } - std::optional unroll_strategy = - std::nullopt; - // Support old flag. - if (opts.xla_gpu_enable_while_loop_double_buffering()) { - unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer; - } - // Support new flag setting style, override the old one. - if (opts.xla_gpu_enable_while_loop_unrolling() == - DebugOptions::WHILE_LOOP_UNROLLING_DOUBLE_BUFFER) { - unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kDoubleBuffer; - } - if (opts.xla_gpu_enable_while_loop_unrolling() == - DebugOptions::WHILE_LOOP_UNROLLING_FULL_UNROLL) { - LOG_IF(WARNING, unroll_strategy != std::nullopt) - << "Overriding double buffering set via " - "`xla_gpu_enable_while_loop_double_buffering` flag."; - unroll_strategy = DoubleBufferLoopUnrolling::UnrollStrategy::kFullUnroll; - } - if (unroll_strategy != std::nullopt) { - pipeline.AddPass(*unroll_strategy); - pipeline.AddPass(); - pipeline.AddPass(); - } + AddDoubleBufferingPasses(opts, pipeline); return pipeline.Run(hlo_module).status(); } @@ -1199,6 +1224,7 @@ absl::Status RunLayoutNormalizationPasses( opts.set_supports_non_canonical_dots(false); opts.set_is_layout_sensitive(true); opts.set_enable_conv_operand_swap(false); + opts.set_enable_conv_add_multiply_reorder(true); // "slow" minmax means we propagate nan. opts.set_minmax_propagate_nan(!debug_options.xla_gpu_enable_fast_min_max()); opts.set_enable_unconditional_reduce_of_concat_replacement(false); @@ -1307,7 +1333,7 @@ absl::Status GpuCompiler::OptimizeHloModule( } TF_RETURN_IF_ERROR(OptimizeHloConvolutionCanonicalization( - hlo_module, gpu_version, dnn_version, options.device_allocator, + hlo_module, gpu_version, dnn_version, gpu_target_config.device_description.runtime_version())); TF_RETURN_IF_ERROR( @@ -1320,8 +1346,6 @@ absl::Status GpuCompiler::OptimizeHloModule( hlo_module, stream_exec, options, gpu_target_config, thread_pool.get_mutable())); - TF_RETURN_IF_ERROR(RunPostLayoutCollectivePipelinerPasses(hlo_module)); - // This is a "low effort, high impact" fusion that should be run first. TF_RETURN_IF_ERROR(RunDynamicSliceFusionPasses(hlo_module, PlatformId())); @@ -1379,6 +1403,12 @@ void AddGemmRewriterPasses(HloPassPipeline& pipeline, bias_mode = GemmRewriterOptions::BiasMode::kNoBias; } + // Rewrite dots with the algorithms that cannot be handled by cublas directly. + // I.e. transform single dot into a chain of dots with the default algorithm + // that cublas can handle. These dots were inlined by the CallInliner pass + // above. + pipeline.AddPass(); + pipeline.AddPass( gpu_version, toolkit_version, GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only, bias_mode}); @@ -1402,6 +1432,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( opts.set_supports_non_canonical_dots(false); opts.set_is_layout_sensitive(true); opts.set_enable_conv_operand_swap(false); + opts.set_enable_conv_add_multiply_reorder(true); // "slow" minmax means we propagate nan. opts.set_minmax_propagate_nan(!debug_options.xla_gpu_enable_fast_min_max()); opts.set_enable_unconditional_reduce_of_concat_replacement(false); @@ -1413,19 +1444,23 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( // Lambdas and related constants: const GpuFloatSupport bf16_support(gpu_version, BF16); const GpuFloatSupport f8e5m2_support(gpu_version, F8E5M2, F16); + const GpuFloatSupport f8e4m3_support(gpu_version, F8E4M3, F16); const GpuFloatSupport f8e4m3fn_support(gpu_version, F8E4M3FN, F16); const FloatSupport f8e4m3b11fnuz_support(F8E4M3B11FNUZ, F16); const GpuFloatSupport f8e5m2fnuz_support(gpu_version, F8E5M2FNUZ, F16); const GpuFloatSupport f8e4m3fnuz_support(gpu_version, F8E4M3FNUZ, F16); + const GpuFloatSupport f8e3m4_support(gpu_version, F8E3M4, F16); auto add_float_normalization = [&](HloPassPipeline& pipeline) { auto& sub_pipeline = pipeline.AddPass("float_normalization"); sub_pipeline.AddPass(&bf16_support); sub_pipeline.AddPass(&f8e5m2_support); + sub_pipeline.AddPass(&f8e4m3_support); sub_pipeline.AddPass(&f8e4m3fn_support); sub_pipeline.AddPass(&f8e4m3b11fnuz_support); sub_pipeline.AddPass(&f8e5m2fnuz_support); sub_pipeline.AddPass(&f8e4m3fnuz_support); + sub_pipeline.AddPass(&f8e3m4_support); // Remove `f32 -> bf16 -> f32` casts inserted by bf16 normalization. if (debug_options.xla_allow_excess_precision()) { sub_pipeline.AddPass(); @@ -1472,9 +1507,8 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( const auto* rocm_cc = std::get_if(&gpu_version); if (debug_options.xla_gpu_enable_triton_gemm() && - ((cuda_cc != nullptr && - cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) || - rocm_cc != nullptr)) { + (cuda_cc != nullptr && + cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE))) { pipeline.AddPass(); pipeline.AddPass(gpu_version); } else if (cuda_cc != nullptr && @@ -1531,17 +1565,18 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( } pipeline.AddPass(); - // Do not split small reduction dimensions unless priority fusion is - // enabled, which handles such cases well. - bool ignore_small_reduce_dims = - !debug_options.xla_gpu_enable_priority_fusion(); - pipeline.AddPass>(ignore_small_reduce_dims); + pipeline.AddPass>( + /*ignore_small_reduce_dims=*/false); pipeline.AddPass>(gpu_version); + // Normalization passes might have introduced s4 tensors without bit width + // annotations, this pass will add the annotations. + pipeline.AddPass( + SubByteNormalization::SET_ELEMENT_SIZE); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } HloPassPipeline pipeline("post-layout_assignment"); - AddHloVerifier(&pipeline, + AddHloVerifier(&pipeline, !debug_options.xla_experimental_ignore_channel_id(), HloVerifierOpts{} .MakeLayoutSensitive() .WithInstructionCanChangeLayout( @@ -1562,15 +1597,19 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( pipeline.AddPass(); // TODO(tdanyluk): Apply CublasPadForGemms to the cuBLAS GEMMs generated // here for possibly better cuBLAS performance. + AddGemmRewriterPasses(pipeline, debug_options, gpu_version, gpu_target_config.device_description.runtime_version()); // Rewrite GEMMs with broadcasted inputs as strided GEMMs. pipeline.AddPass(); - pipeline.AddPass( - static_cast(stream_executor::MemoryType::kHost), - /* after_layout= */ true); + pipeline.AddPass(&NormalizeLayoutForGpuCustomCalls); + + // Layout normalization will create scatters that are not simplified and + // also have unsorted update_window_dims. + pipeline.AddPass(); + pipeline.AddPass( static_cast(stream_executor::MemoryType::kHost)); @@ -1606,14 +1645,16 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( #ifdef NDEBUG // Verify the module in non-debug builds. For debug builds, the verifier // already runs after every pass. + HloVerifierOpts opts = HloVerifierOpts{} + .MakeLayoutSensitive() + .WithInstructionCanChangeLayout( + LayoutAssignment::InstructionCanChangeLayout) + .VerifyBroadcastDimensionsOrder() + .VerifyReshapeIsBitcast(); + opts.verify_unique_channel_ids = + !debug_options.xla_experimental_ignore_channel_id(); pipeline.AddPass( - std::make_unique( - HloVerifierOpts{} - .MakeLayoutSensitive() - .WithInstructionCanChangeLayout( - LayoutAssignment::InstructionCanChangeLayout) - .VerifyBroadcastDimensionsOrder() - .VerifyReshapeIsBitcast()), + std::make_unique(std::move(opts)), "end-of-post-layout_assignment"); #endif // NDEBUG @@ -1693,6 +1734,19 @@ absl::StatusOr> GpuCompiler::RunHloPasses( TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get())); + const auto* cuda_cc = std::get_if( + &gpu_target_config.device_description.gpu_compute_capability()); + if (cuda_cc != nullptr && cuda_cc->IsAtLeastAmpere()) { + // This needs to run after every pass affecting fusions, which includes + // `CopyFusion`, which itself must run in the + // `PrepareHloModuleForIrEmitting` pipeline. + TF_RETURN_IF_ERROR( + FusionDispatchPipeline(gpu_target_config.device_description, + ShapeSizeBytesFunction()) + .Run(module.get()) + .status()); + } + uint64_t end_usecs = tsl::Env::Default()->NowMicros(); // This won't record values for calls that error out (because if they error @@ -2242,6 +2296,9 @@ absl::StatusOr> GpuCompiler::RunBackend( return absl::StrFormat("XlaCompileBackend:#module=%s,program_id=%d#", module->name(), module->unique_id()); }}; + + RecordGpuCompilerStacktrace(); + BinaryMap dnn_compiled_graphs; if (stream_exec) { TF_RETURN_IF_ERROR(RunCudnnCompilerPasses(module.get(), stream_exec, @@ -2576,9 +2633,12 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines( pipeline.AddPass(); } - if (module->config().debug_options().xla_gpu_enable_pgle_accuracy_checker()) { - AddHloVerifier(&main_pipeline, - HloVerifierOpts{}.VerifyInstructionNameUnchanged()); + if (module->config().debug_options().xla_gpu_pgle_accuracy_checker() == + DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR) { + AddHloVerifier( + &main_pipeline, + module->config().debug_options().xla_experimental_ignore_channel_id(), + HloVerifierOpts{}.VerifyInstructionNameUnchanged()); } return main_pipeline.Run(module).status(); } diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.h b/third_party/xla/xla/service/gpu/gpu_compiler.h index b18b48abfcc4d9..b8fe422dbe1fac 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.h +++ b/third_party/xla/xla/service/gpu/gpu_compiler.h @@ -26,11 +26,12 @@ limitations under the License. #include "absl/status/statusor.h" #include "llvm/IR/Module.h" #include "xla/autotune_results.pb.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" #include "xla/pjrt/distributed/key_value_store_interface.h" -#include "xla/service/algebraic_simplifier.h" #include "xla/service/compiler.h" #include "xla/service/executable.h" #include "xla/service/gpu/autotuning/autotuner_util.h" @@ -40,12 +41,10 @@ limitations under the License. #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_cost_analysis.h" -#include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_module_config.h" #include "xla/service/llvm_compiler.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_description.pb.h" -#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/semantic_version.h" @@ -227,7 +226,6 @@ class GpuCompiler : public LLVMCompiler { virtual absl::Status OptimizeHloConvolutionCanonicalization( HloModule* hlo_module, se::GpuComputeCapability gpu_version, se::dnn::VersionInfo dnn_version, - se::DeviceMemoryAllocator* device_allocator, const se::SemanticVersion& toolkit_version) = 0; // TODO(timshen): Replace `debug_module` with some portable debug information diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc old mode 100644 new mode 100755 index f163e5b786bba9..2f5662aaab75d3 --- a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc @@ -15,8 +15,10 @@ limitations under the License. #include "xla/service/gpu/gpu_compiler.h" +#include #include #include +#include #include #include #include @@ -25,9 +27,11 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/statusor.h" +#include "absl/strings/match.h" #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "xla/autotune_results.pb.h" @@ -43,6 +47,7 @@ limitations under the License. #include "xla/service/compiler.h" #include "xla/service/executable.h" #include "xla/service/gpu/autotuning/autotuner_util.h" +#include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_hlo_schedule.h" #include "xla/service/gpu/metrics.h" #include "xla/service/hlo_module_config.h" @@ -56,7 +61,10 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" #include "xla/tests/literal_test_util.h" +#include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/monitoring/collected_metrics.h" +#include "xla/tsl/lib/monitoring/collection_registry.h" #include "xla/xla_data.pb.h" #include "tsl/platform/casts.h" #include "tsl/platform/env.h" @@ -86,11 +94,20 @@ class GpuCompilerTest : public HloTestBase { return tensorflow::down_cast(compiler) ->RunPostSchedulingPipelines(module, 4 * 1024 * 1024, gpu_device_info); } +<<<<<<< HEAD const auto& device_desc() { return backend().default_stream_executor()->GetDeviceDescription(); } const se::GpuComputeCapability& GpuComputeComp() { return device_desc().gpu_compute_capability(); +======= + + const stream_executor::GpuComputeCapability& GpuComputeComp() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(); +>>>>>>> master } }; @@ -117,6 +134,44 @@ ENTRY main { EXPECT_EQ(GetCompiledProgramsCount(), 1); } +TEST_F(GpuCompilerTest, RecordsStreamzStackTrace) { + const char* hlo_text = R"( +HloModule test + +ENTRY main { + p = f32[10]{0} parameter(0) + ROOT neg = f32[10]{0} negate(p) +} +)"; + + auto module = ParseAndReturnVerifiedModule(hlo_text).value(); + + std::unique_ptr executable = + backend() + .compiler() + ->RunBackend(std::move(module), backend().default_stream_executor(), + {/*device_allocator=*/nullptr, + /*thread_pool=*/nullptr, + /*layout_canonicalization_callback=*/{}, + /*is_autotuning_compilation=*/false}) + .value(); + + const std::string kGpuCompilerStacktraceMetricName = + "/xla/service/gpu/compiler_stacktrace_count"; + tsl::monitoring::CollectionRegistry::CollectMetricsOptions options; + std::unique_ptr metrics = + tsl::monitoring::CollectionRegistry::Default()->CollectMetrics(options); + + EXPECT_TRUE(metrics->point_set_map.find(kGpuCompilerStacktraceMetricName) != + metrics->point_set_map.end()); + + // Since Streamz is recorded every call, we expect at least one point. + // All other callers may increment the counter as well. + EXPECT_GT( + metrics->point_set_map[kGpuCompilerStacktraceMetricName]->points.size(), + 0); +} + TEST_F(GpuCompilerTest, GenerateDebugInfoForNonAutotuningCompilations) { const char* hlo_text = R"( HloModule test @@ -283,7 +338,7 @@ ENTRY e { return str; } - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions options = HloTestBase::GetDebugOptionsForTest(); options.set_xla_gpu_dump_autotune_results_to( xla_gpu_dump_autotune_results_to_); @@ -405,7 +460,19 @@ ENTRY main { HloOpcode::kAllGatherDone); } -TEST_F(GpuCompilerTest, +class GpuCompilerTestWithAutotuneDb : public GpuCompilerTest { + public: + static void SetUpTestSuite() { + std::string path = + tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu", + "gpu_compiler_test_autotune_db.textproto"); + TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(path)); + } + + static void TearDownTestSuite() { AutotunerUtil::ClearAutotuneResults(); } +}; + +TEST_F(GpuCompilerTestWithAutotuneDb, GemmFusionIsNoOpWhenGemmFusionAutotunerFallsBackToCublas) { if (std::holds_alternative(GpuComputeComp())) { GTEST_SKIP() << "Folder structure differences prevents finding of gpu_compiler_test_autotune_db.textproto."; @@ -455,17 +522,10 @@ ENTRY main { config.set_replica_count(1); config.set_num_partitions(1); - // Load autotuning DB. We shouldn't depend on actual execution times in a unit - // test. - std::string path = - tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu", - "gpu_compiler_test_autotune_db.textproto"); - TF_EXPECT_OK(AutotunerUtil::LoadAutotuneResultsFromFile(path)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_string, config)); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr triton_enabled_module, GetOptimizedModule(std::move(module))); - AutotunerUtil::ClearAutotuneResults(); DebugOptions triton_disabled_debug_options = GetDebugOptionsForTest(); triton_disabled_debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); triton_disabled_debug_options.set_xla_gpu_enable_triton_gemm(false); @@ -485,6 +545,54 @@ ENTRY main { triton_disabled_module->computation_count()); } +TEST_F(GpuCompilerTestWithAutotuneDb, + CublasF8NumericallySameWithTritonFallbackAndWithoutTriton) { + auto cc = backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + if (!cc.IsAtLeastHopper()) { + GTEST_SKIP() + << "Autotuning results have only been generated for Hopper GPUs"; + } + const absl::string_view hlo_string = R"( +HloModule test + +ENTRY main { + p0 = f8e4m3fn[12288,4096]{0,1} parameter(0) + p1 = f8e4m3fn[4096,16384]{0,1} parameter(1) + dot = bf16[12288,16384]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + bitcast = bf16[] constant(0.956) + broadcast = bf16[12288,16384]{1,0} broadcast(bitcast), dimensions={} + ROOT multiply = bf16[12288,16384]{1,0} multiply(dot, broadcast) + })"; + + HloModuleConfig config; + DebugOptions triton_enabled_debug_options = GetDebugOptionsForTest(); + triton_enabled_debug_options + .set_xla_gpu_require_complete_aot_autotune_results(true); + config.set_debug_options(triton_enabled_debug_options); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string, config)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr triton_enabled_module, + GetOptimizedModule(std::move(module))); + + DebugOptions triton_disabled_debug_options = GetDebugOptionsForTest(); + triton_disabled_debug_options.set_xla_gpu_enable_triton_gemm(false); + triton_disabled_debug_options.set_xla_gpu_cublas_fallback(true); + config.set_debug_options(triton_disabled_debug_options); + + TF_ASSERT_OK_AND_ASSIGN(module, + ParseAndReturnVerifiedModule(hlo_string, config)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr triton_disabled_module, + GetOptimizedModule(std::move(module))); + + EXPECT_TRUE(RunAndCompareTwoModules(std::move(triton_enabled_module), + std::move(triton_disabled_module), + ErrorSpec{1e-6, 1e-6}, false)); +} + class FloatNormalizationTest : public GpuCompilerTest, public ::testing::WithParamInterface< std::pair> {}; @@ -739,7 +847,7 @@ class KernelCacheTest : public HloTestBase { } } - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_gpu_kernel_cache_file(cache_file_name_); debug_options.set_xla_gpu_enable_llvm_module_compilation_parallelism(true); @@ -899,7 +1007,7 @@ ENTRY e { class KernelCacheTestSingleThreaded : public KernelCacheTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = KernelCacheTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_force_compilation_parallelism(1); return debug_options; @@ -916,7 +1024,7 @@ TEST_F(KernelCacheTestSingleThreaded, CacheIsGenerated) { class NoKernelCacheTest : public KernelCacheTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = KernelCacheTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_enable_llvm_module_compilation_parallelism(false); return debug_options; @@ -978,10 +1086,112 @@ TEST_F(GpuCompilerTest, TestFlag_xla_gpu_unsafe_pipelined_loop_annotator) { EXPECT_TRUE(filecheck_matched); } +bool HasBlockLevelFusionConfig(const HloInstruction* fusion) { + return fusion->opcode() == HloOpcode::kFusion && + fusion->has_backend_config() && + fusion->backend_config().ok() && + fusion->backend_config() + ->fusion_backend_config() + .has_block_level_fusion_config(); +} + +TEST_F(GpuCompilerTest, + LoopFusionRootedInTransposeIsRewrittenToBlockLevelByDefaultPostAmpere) { + auto cc = backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + + constexpr absl::string_view transpose_fusion_module = R"( +transpose { + p0 = f32[1024,1024,1024] parameter(0) + ROOT transpose = f32[1024,1024,1024] transpose(p0), dimensions={2,1,0} +} + +ENTRY main { + p0 = f32[1024,1024,1024] parameter(0) + ROOT fusion = f32[1024,1024,1024] fusion(p0), kind=kLoop, calls=transpose +})"; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(transpose_fusion_module)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, + GetOptimizedModule(std::move(module))); + + if (cc.IsAtLeastAmpere()) { + EXPECT_TRUE(HasBlockLevelFusionConfig( + optimized_module->entry_computation()->root_instruction())); + } else { + EXPECT_FALSE(HasBlockLevelFusionConfig( + optimized_module->entry_computation()->root_instruction())); + } +} + +TEST_F( + GpuCompilerTest, + FusionBlockLevelRewriterRewritesKLoopTransposeWithBitcastIfTheSmallMinorDimIsAPowerOfTwo) { // NOLINT(whitespace/line_length) + auto cc = backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + if (!cc.IsAtLeastAmpere()) { + GTEST_SKIP() << "FusionBlockLevelRewriter requires Ampere+ to run."; + } + + // If this test starts failing, then it's likely that this no longer generates + // a kLoop transpose. That's great---it probably means the rewrite in question + // is no longer necessary! + // + // The small minor dimension here is a power of two, so the rewrite should + // succeed. + constexpr absl::string_view rewritable_transpose_string = R"( +ENTRY main { + p0 = f32[1024,4096]{1,0} parameter(0) + reshape = f32[1024,1024,4]{2,1,0} reshape(p0) + ROOT transpose = f32[4,1024,1024]{2,1,0} transpose(reshape), dimensions={2,1,0} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr rewritable_transpose_module, + ParseAndReturnVerifiedModule(rewritable_transpose_string)); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr rewritable_transpose_optimized_module, + GetOptimizedModule(std::move(rewritable_transpose_module))); + EXPECT_TRUE(HasBlockLevelFusionConfig( + rewritable_transpose_optimized_module->entry_computation() + ->root_instruction())); + + // The small minor dimension here is not a power of two, so the rewrite should + // fail. + constexpr absl::string_view unrewritable_transpose_string = R"( +ENTRY main { + p0 = f32[1024,6144]{1,0} parameter(0) + reshape = f32[1024,1024,6]{2,1,0} reshape(p0) + ROOT transpose = f32[6,1024,1024]{2,1,0} transpose(reshape), dimensions={2,1,0} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr unrewritable_transpose_module, + ParseAndReturnVerifiedModule(unrewritable_transpose_string)); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr unrewritable_transpose_optimized_module, + GetOptimizedModule(std::move(unrewritable_transpose_module))); + EXPECT_FALSE(HasBlockLevelFusionConfig( + unrewritable_transpose_optimized_module->entry_computation() + ->root_instruction())); +} + using GpuCompilerPassTest = GpuCompilerTest; TEST_F(GpuCompilerPassTest, GpuCompilerRunsTritonGemmRewriterByDefaultFromAmpere) { + if (std::holds_alternative(GpuComputeComp())) { + GTEST_SKIP() << "TritonGemmRewriter disabled for ROCm until autotuner " + << "is included."; + } auto cc = backend() .default_stream_executor() ->GetDeviceDescription() @@ -1052,6 +1262,111 @@ ENTRY main { expect_custom_kernel_fusion_rewriter_has_run); } +class PassOrderTest : public GpuCompilerTest { + public: + void SetDebugOptions(const DebugOptions& options) { + HloModuleConfig config = GetModuleConfigForTest(); + config.set_debug_options(options); + CompileModule(config); + } + + // Fails if any of the passes with names matching the regular expression + // first_pass_regex run after any of the passes matching last_pass_regex or if + // none of the executed passes matches first_pass_regex or last_pass_regex. + void VerifyPassOrder(absl::string_view first_pass_regex, + absl::string_view last_pass_regex) { + if (!optimized_module_) { + CompileModule(GetModuleConfigForTest()); + } + int first_pass_latest_run = -1; + int last_pass_earliest_run = std::numeric_limits::max(); + int run_index = 0; + for (const HloPassMetadata& pass_metadata : + optimized_module_->metadata()->proto().pass_metadata()) { + if (RE2::FullMatch(pass_metadata.pass_name(), first_pass_regex)) { + VLOG(2) << "Pass " << pass_metadata.pass_name() + << " matches first_pass_regex." << std::endl; + first_pass_latest_run = std::max(first_pass_latest_run, run_index); + } + if (RE2::FullMatch(pass_metadata.pass_name(), last_pass_regex)) { + VLOG(2) << "Pass " << pass_metadata.pass_name() + << " matches last_pass_regex." << std::endl; + last_pass_earliest_run = std::min(last_pass_earliest_run, run_index); + } + ++run_index; + } + + EXPECT_GT(first_pass_latest_run, -1) + << "Did not run a pass matching " << first_pass_regex; + EXPECT_LT(last_pass_earliest_run, std::numeric_limits::max()) + << "Did not run a pass matching " << last_pass_regex; + EXPECT_LE(first_pass_latest_run, last_pass_earliest_run) + << "One or more passes matching " << first_pass_regex + << " ran after passes matching " << last_pass_regex; + } + + private: + void CompileModule(const HloModuleConfig& config) { + constexpr absl::string_view constant_module = R"( +ENTRY main { + ROOT constant = f32[] constant(0) +})"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(constant_module, config)); + TF_ASSERT_OK_AND_ASSIGN(optimized_module_, + GetOptimizedModule(std::move(module))); + } + + std::unique_ptr optimized_module_; +}; + +TEST_F(PassOrderTest, PassesAreRunInCorrectOrder) { + VerifyPassOrder(/*first_pass_regex=*/"layout-assignment", + /*last_pass_regex=*/"priority-fusion"); + VerifyPassOrder(/*first_pass_regex=*/"layout-assignment", + /*last_pass_regex=*/"layout_normalization"); + VerifyPassOrder(/*first_pass_regex=*/"host-offload-legalize", + /*last_pass_regex=*/"layout_normalization"); +} + +TEST_F(PassOrderTest, FusionBlockLevelRewriterRunsAfterAllFusionPasses) { + auto cc = backend() + .default_stream_executor() + ->GetDeviceDescription() + .cuda_compute_capability(); + if (!cc.IsAtLeastAmpere()) { + GTEST_SKIP() << "FusionBlockLevelRewriter requires Ampere+ to run."; + } + + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_experimental_enable_fusion_block_level_rewriter( + true); + SetDebugOptions(debug_options); + + VerifyPassOrder(/*first_pass_regex=*/".*fusion.*", + /*last_pass_regex=*/"fusion-block-level-rewriter"); +} + +TEST_F(PassOrderTest, CollectivePipelinerRunsAfterCollectiveQuantizer) { + DebugOptions options = GetDebugOptionsForTest(); + options.set_xla_gpu_enable_pipelined_collectives(true); + SetDebugOptions(options); + + VerifyPassOrder(/*first_pass_regex=*/"collective-quantizer", + /*last_pass_regex=*/"collective-pipeliner.*"); +} + +TEST_F(PassOrderTest, + AllGatherDynamicSliceSimplifierRunsAfterAllGatherOptimizer) { + DebugOptions options = GetDebugOptionsForTest(); + SetDebugOptions(options); + + VerifyPassOrder( + /*first_pass_regex=*/".*all-gather-optimizer.*", + /*last_pass_regex=*/".*all-gather-dynamic-slice-simplifier.*"); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test_autotune_db.textproto b/third_party/xla/xla/service/gpu/gpu_compiler_test_autotune_db.textproto index 51caadb7bd2d06..4a2493a64bb415 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler_test_autotune_db.textproto +++ b/third_party/xla/xla/service/gpu/gpu_compiler_test_autotune_db.textproto @@ -71,3 +71,38 @@ results { } } } +results { + device: "CUDA: 9.0, Cores: 132, GPU clock: 1.98 GHz, Memory bandwidth: 3352 GB/s, L2 cache: 50 MB" + hlo: "(bf16[12288,16384]{1,0}, s8[33554432]{0}) custom-call(f8e4m3fn[4096,12288]{0,1}, f8e4m3fn[4096,16384]{0,1}, f32[], f32[], f32[], f32[]), custom_call_target=\"__cublas$lt$matmul$f8\", backend_config={\"force_earliest_schedule\":false,\"gemm_backend_config\":{\"alpha_imag\":0,\"alpha_real\":0.95703125,\"beta\":0,\"damax_output\":false,\"dot_dimension_numbers\":{\"lhs_batch_dimensions\":[],\"lhs_contracting_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[],\"rhs_contracting_dimensions\":[\"0\"]},\"epilogue\":\"DEFAULT\",\"grad_x\":false,\"grad_y\":false,\"lhs_stride\":\"50331648\",\"precision_config\":{\"algorithm\":\"ALG_UNSET\",\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"rhs_stride\":\"67108864\"},\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[]}" + result { + gemm { + } + run_time { + nanos: 1 + } + } +} +results { + device: "CUDA: 9.0, Cores: 132, GPU clock: 1.98 GHz, Memory bandwidth: 3352 GB/s, L2 cache: 50 MB" + hlo: "{\n tmp_0 = f8e4m3fn[12288,4096]{0,1} parameter(0)\n tmp_1 = f8e4m3fn[4096,16384]{0,1} parameter(1)\n tmp_2 = bf16[12288,16384]{1,0} dot(f8e4m3fn[12288,4096]{0,1} tmp_0, f8e4m3fn[4096,16384]{0,1} tmp_1), lhs_contracting_dims={1}, rhs_contracting_dims={0}\n tmp_3 = bf16[] constant({...})\n tmp_4 = bf16[12288,16384]{1,0} broadcast(bf16[] tmp_3), dimensions={}\n ROOT tmp_5 = bf16[12288,16384]{1,0} multiply(bf16[12288,16384]{1,0} tmp_2, bf16[12288,16384]{1,0} tmp_4)\n}" + result { + gemm { + algorithm: -1 + } + run_time { + nanos: 1 + } + } +} +results { + device: "CUDA: 9.0, Cores: 132, GPU clock: 1.98 GHz, Memory bandwidth: 3352 GB/s, L2 cache: 50 MB" + hlo: "{\n tmp_0 = f8e4m3fn[12288,4096]{0,1} parameter(0)\n tmp_1 = f8e4m3fn[4096,12288]{1,0} bitcast(f8e4m3fn[12288,4096]{0,1} tmp_0)\n tmp_2 = f8e4m3fn[4096,16384]{0,1} parameter(1)\n tmp_3 = bf16[12288,16384]{1,0} dot(f8e4m3fn[4096,12288]{1,0} tmp_1, f8e4m3fn[4096,16384]{0,1} tmp_2), lhs_contracting_dims={0}, rhs_contracting_dims={0}\n tmp_4 = bf16[] constant({...})\n tmp_5 = bf16[12288,16384]{1,0} broadcast(bf16[] tmp_4), dimensions={}\n ROOT tmp_6 = bf16[12288,16384]{1,0} multiply(bf16[12288,16384]{1,0} tmp_3, bf16[12288,16384]{1,0} tmp_5)\n}" + result { + gemm { + algorithm: -1 + } + run_time { + nanos: 1 + } + } +} \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc index e3d939e873a228..1de6f9106a01ad 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable.cc @@ -71,16 +71,13 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/event_based_timer.h" -#include "xla/stream_executor/gpu/scoped_activate_context.h" -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/stream_executor/gpu/gpu_executor.h" -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "xla/stream_executor/module_spec.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/scoped_module_handle.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/sycl/sycl_platform_id.h" #include "xla/util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" @@ -177,6 +174,8 @@ absl::Status GpuExecutable::CheckCompatibilityWithServiceExecutableRunOptions( << std::get(gpu_version_).ToString() << "}, but was {" << std::get(cc).ToString() << "}"; + } else if (platform_id == stream_executor::sycl::kSyclPlatformId) { + // TODO: Add check. } else { return Internal("Unknown platform"); } @@ -622,7 +621,7 @@ GpuExecutable::ResolveConstantGlobals(se::Stream* stream) { if (!(executor->GetPlatform()->id() == stream_executor::cuda::kCudaPlatformId && binary().empty() && text().empty())) { - TF_RETURN_IF_ERROR(executor->LoadModule(module_spec, &module_handle)); + TF_ASSIGN_OR_RETURN(module_handle, executor->LoadModule(module_spec)); } // A flag signalling if constant initialization submitted memcpy operations @@ -805,12 +804,9 @@ absl::StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( se::DeviceMemoryAllocator* const memory_allocator = run_options->allocator(); se::StreamExecutor* executor = run_options->stream()->parent(); -#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM // GpuExecutable always bound to a single GpuContext during its execution, so // we activate it once to skip expensive context activations later. - se::gpu::GpuExecutor* gpu_executor = se::gpu::ExtractGpuExecutor(executor); - se::gpu::ScopedActivateContext activation(gpu_executor); -#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM + auto activation = executor->Activate(); // Force synchronous execution if the allocator requires it. const bool block_host_until_done = diff --git a/third_party/xla/xla/service/gpu/gpu_float_support.cc b/third_party/xla/xla/service/gpu/gpu_float_support.cc index 1403ad021a217d..c02e1583499620 100644 --- a/third_party/xla/xla/service/gpu/gpu_float_support.cc +++ b/third_party/xla/xla/service/gpu/gpu_float_support.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/log/check.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -93,6 +94,7 @@ bool GpuFloatSupport::IsSupported(const HloInstruction& hlo) const { case HloOpcode::kTranspose: // Other special ops. case HloOpcode::kBitcast: + case HloOpcode::kReducePrecision: return true; // Elementwise ops. case HloOpcode::kAdd: @@ -106,6 +108,13 @@ bool GpuFloatSupport::IsSupported(const HloInstruction& hlo) const { } return false; } + // Reduction. + case HloOpcode::kReduce: + return absl::c_all_of(hlo.called_computations().front()->instructions(), + [this](const HloInstruction* hlo) { + return hlo->opcode() == HloOpcode::kParameter || + this->IsSupported(*hlo); + }); default: return false; } diff --git a/third_party/xla/xla/service/gpu/gpu_float_support_test.cc b/third_party/xla/xla/service/gpu/gpu_float_support_test.cc index 1d2f6c167bb090..1c1c7f4e8149d4 100644 --- a/third_party/xla/xla/service/gpu/gpu_float_support_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_float_support_test.cc @@ -21,11 +21,13 @@ limitations under the License. #include #include "absl/log/check.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/strings/substitute.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/float_normalization.h" +#include "xla/hlo/transforms/simplifiers/float_normalization.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/hlo_verifier.h" @@ -35,6 +37,7 @@ limitations under the License. #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace xla::gpu { namespace { @@ -195,7 +198,7 @@ TEST_F(FloatSupportTest, ShouldAlwaysConvertFp8Dot) { /*should_convert_rhs=*/false, F8E5M2); } -TEST_F(FloatSupportTest, ShouldConverTritonUnsupportedFp8Dot) { +TEST_F(FloatSupportTest, ShouldConvertTritonUnsupportedFp8Dot) { TestTritonFusedDot(F8E4M3FN, F8E4M3FN, F16, se::CudaComputeCapability::Hopper(), /*should_convert_lhs=*/true, @@ -253,5 +256,51 @@ TEST_F(FloatSupportTest, ShouldKeepBf16OnHopper) { /*should_convert_rhs=*/false, BF16); } +TEST_F(FloatSupportTest, Bf16ReducePrecisionIsNotNormalized) { + auto cc = se::CudaComputeCapability::Ampere(); + constexpr absl::string_view kHloModule = R"( +HloModule m + +ENTRY main { + p0 = bf16[] parameter(0) + ROOT r = bf16[] reduce-precision(p0), exponent_bits=8, mantissa_bits=7 +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloModule)); + EXPECT_FALSE(Normalize(module.get(), cc, BF16, F32)); +} + +TEST_F(FloatSupportTest, + BF16ReductionOnHopperIsOnlyNormalizedIfReducerIsUnsupported) { + auto cc = se::CudaComputeCapability::Hopper(); + constexpr absl::string_view kHloModuleTemplate = R"( +HloModule m + +reducer { + p0 = bf16[] parameter(0) + p1 = bf16[] parameter(1) + ROOT reducer = bf16[] $0(p0, p1) +} + +ENTRY main { + p0 = bf16[1024] parameter(0) + init = bf16[] constant(1337) + ROOT r = bf16[] reduce(p0, init), dimensions={0}, to_apply=reducer +})"; + + // add.bf16 was added in Hopper. + TF_ASSERT_OK_AND_ASSIGN(auto module_with_supported_reducer, + ParseAndReturnVerifiedModule( + absl::Substitute(kHloModuleTemplate, "add"))); + EXPECT_FALSE(Normalize(module_with_supported_reducer.get(), cc, BF16, F32)); + + // There is no bf16 instruction for divide, however. + TF_ASSERT_OK_AND_ASSIGN(auto module_with_unsupported_reducer, + ParseAndReturnVerifiedModule( + absl::Substitute(kHloModuleTemplate, "divide"))); + EXPECT_TRUE(Normalize(module_with_unsupported_reducer.get(), cc, BF16, F32)); +} + } // namespace } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/gpu_fusible.cc b/third_party/xla/xla/service/gpu/gpu_fusible.cc index bae2e880a9f199..e997de13e7cc45 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible.cc +++ b/third_party/xla/xla/service/gpu/gpu_fusible.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/synchronization/mutex.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -39,7 +40,6 @@ limitations under the License. #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/reduction_utils.h" -#include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/instruction_fusion.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -69,8 +69,7 @@ const Shape& GetElementShape(const HloFusionAnalysis& analysis) { // Computes the maximum valid unroll factor for a given instruction. int ComputeMaxUnrollFactor(int64_t num_elements) { - constexpr int kMaxUnrollFactor = 4; - for (int i = kMaxUnrollFactor; i > 1; i /= 2) { + for (int i = MaxUnrollFactor(); i > 1; i /= 2) { if (num_elements % i == 0) { return i; } @@ -266,15 +265,15 @@ FusionDecision FusionHeroesAreCompatible(const HloInstruction* hero1, if (hero1_is_unnested_reduce && hero2_is_unnested_reduce && !AreReductionsMultiOutputFusionCompatible(hero2, hero1)) { - return "tiled reductions with different shapes"; + return FusionDecision::Forbid("tiled reductions with different shapes"); } else if (hero1_is_unnested_transpose && hero2_is_unnested_transpose && // After normalization to rank 3, the transposes should have the // same shape and permute the same dimensions. !tiled_transpose_hero1->IsEquivalent(*tiled_transpose_hero2)) { - return "tiled transposes with different shapes"; + return FusionDecision::Forbid("tiled transposes with different shapes"); } else if ((hero1_is_unnested_transpose && hero2_is_unnested_reduce) || (hero1_is_unnested_reduce && hero2_is_unnested_transpose)) { - return "MOF-fusion of a transpose and a reduction"; + return FusionDecision::Forbid("MOF-fusion of a transpose and a reduction"); } // If we are dealing with unnested transpose, make sure that we can still // treat them as unnested transpose after the sibling fusion. @@ -303,18 +302,18 @@ FusionDecision FusionHeroesAreCompatible(const HloInstruction* hero1, int64_t operand_idx = fusion2->operand_index(fusion1); auto hlo = fusion2->fused_parameter(operand_idx); if (!check_path_of_intermediate_ops(hlo)) { - return "tiled transpose would become untiled"; + return FusionDecision::Forbid("tiled transpose would become untiled"); } } else if (hero2_is_unnested_transpose && fusion1->IsUserOf(fusion2)) { int64_t operand_idx = fusion1->operand_index(fusion2); auto hlo = fusion1->fused_parameter(operand_idx); if (!check_path_of_intermediate_ops(hlo)) { - return "tiled transpose would become untiled"; + return FusionDecision::Forbid("tiled transpose would become untiled"); } } } } - return {}; + return FusionDecision::Allow(); } FusionDecision ShapesCompatibleForMultiOutputFusion( @@ -356,9 +355,9 @@ FusionDecision ShapesCompatibleForMultiOutputFusion( (!accept_unequal_shape || !ShapeUtil::IsReshapeOrTransposeBitcast(l1, l2, /*ignore_element_type=*/true))) { - return "different loop shapes"; + return FusionDecision::Forbid("different loop shapes"); } - return {}; + return FusionDecision::Allow(); } bool IsInputFusibleScatter(const HloInstruction& instr) { @@ -469,10 +468,10 @@ static bool AllSatisfy(const HloInstruction& instr, FusionDecision CanEmitInputFusedScatter(const HloInstruction& producer, const HloInstruction& consumer) { if (IsInputFusibleScatter(producer)) { - return "do not fuse into the output of scatter"; + return FusionDecision::Forbid("do not fuse into the output of scatter"); } if (!IsInputFusibleScatter(consumer)) { - return {}; + return FusionDecision::Allow(); } const HloInstruction* inplace_operand; @@ -485,19 +484,21 @@ FusionDecision CanEmitInputFusedScatter(const HloInstruction& producer, inplace_operand = consumer.operand(0); } if (inplace_operand == &producer) { - return "do not fuse into the in-place operand of scatter"; + return FusionDecision::Forbid( + "do not fuse into the in-place operand of scatter"); } if (absl::c_linear_search(producer.operands(), inplace_operand)) { - return "Producer uses the in-place operand of a scatter"; + return FusionDecision::Forbid( + "Producer uses the in-place operand of a scatter"); } - return {}; + return FusionDecision::Allow(); } FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, const HloInstruction& consumer) { if (!IsLoopFusibleAsProducer(producer) && !IsInputFusibleTranspose(producer)) { - return "the producer is not loop-fusible"; + return FusionDecision::Forbid("the producer is not loop-fusible"); } if (IsInputFusibleReduction(producer)) { @@ -505,7 +506,8 @@ FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, ->config() .debug_options() .xla_gpu_enable_reduction_epilogue_fusion()) { - return "Reduction epilogue fusion is not enabled."; + return FusionDecision::Forbid( + "Reduction epilogue fusion is not enabled."); } const HloInstruction& reduce_hero = producer.opcode() == HloOpcode::kFusion @@ -514,16 +516,19 @@ FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, if (!ReductionIsRaceFree( reduce_hero.GetModule()->config(), GetReductionKindAndContiguousComponents(reduce_hero))) { - return "Reduction output fusion only works for race free reductions"; + return FusionDecision::Forbid( + "Reduction output fusion only works for race free reductions"); } if (!AllSatisfy(consumer, [](const HloInstruction* hlo) { return IsIntermediate(hlo, /*allowed_operand_count=*/1); })) { - return "Reductions from/to continuous dims epilogue not fusible"; + return FusionDecision::Forbid( + "Reductions from/to continuous dims epilogue not fusible"); } if (producer.user_count() > 1) { - return "reduction output fusion only works for single user"; + return FusionDecision::Forbid( + "reduction output fusion only works for single user"); } } @@ -532,12 +537,14 @@ FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, } if (!IsInputFusible(consumer) && !IsLoopFusibleAsConsumer(consumer)) { - return "the consumer is not input-fusible and not loop-fusible"; + return FusionDecision::Forbid( + "the consumer is not input-fusible and not loop-fusible"); } // Skip multiple output fusion. It's not yet supported. if (producer.IsMultiOutputFusion()) { - return "the producer is not fusible as it is a multi-output fusion"; + return FusionDecision::Forbid( + "the producer is not fusible as it is a multi-output fusion"); } // Fuse scalar constants into loop fusion nodes. This reduces the number of @@ -551,7 +558,7 @@ FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, if (producer.opcode() == HloOpcode::kConstant && (!ShapeUtil::IsEffectiveScalar(producer.shape()) || consumer.opcode() != HloOpcode::kFusion)) { - return "not fusing constant"; + return FusionDecision::Forbid("not fusing constant"); } // Make sure the new fusion obeys the in-place semantics. @@ -561,7 +568,7 @@ FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, FusionDecision IsProducerMultiOutputFusible(const HloInstruction& producer) { // Skip multiple output fusion. It's not yet supported. if (producer.IsMultiOutputFusion()) { - return "Producer is a multi-output fusion"; + return FusionDecision::Forbid("Producer is a multi-output fusion"); } // Allowing multi-output fusions that contain in-place operations makes code @@ -589,18 +596,18 @@ FusionDecision IsProducerMultiOutputFusible(const HloInstruction& producer) { // contract that describes what multi-output fusion scenarios are supported by // codegen and then changing this check to allow exactly those fusions). if (!HloDataflowAnalysis::GetInPlaceInputOutputPairs(&producer).empty()) { - return "In-place operations are present"; + return FusionDecision::Forbid("In-place operations are present"); } if (!IsLoopFusibleAsProducer(producer)) { - return "producer is not loop-fusible"; + return FusionDecision::Forbid("producer is not loop-fusible"); } if (IsPhysicallyTransposing(producer)) { - return "producer is physically transposing"; + return FusionDecision::Forbid("producer is physically transposing"); } - return {}; + return FusionDecision::Allow(); } // Returns an estimate of the shared memory usage for a given instruction in @@ -751,16 +758,17 @@ FusionDecision FusionFitsInBudget(const HloInstruction& instr1, FusionInfoCache* cache /*=nullptr*/) { if (SharedMemoryUsage(instr1, cache) + SharedMemoryUsage(instr2, cache) > device_info.shared_memory_per_block()) { - return FusionDecision{} - << "shared memory usage would be over the budget of " + return FusionDecision::Forbid( + "shared memory usage would be over the budget of ") << device_info.shared_memory_per_block() << "B"; } if (NumUnnestedReductions(instr1, cache) + NumUnnestedReductions(instr2, cache) > kMaxUnnestedReductionOutputsPerFusion) { - return FusionDecision{} << "over " << kMaxUnnestedReductionOutputsPerFusion - << " unnested reductions in fusion"; + return FusionDecision::Forbid("over ") + << kMaxUnnestedReductionOutputsPerFusion + << " unnested reductions in fusion"; } // Compute the number of outputs of the (possibly multi-output) fusion node @@ -791,7 +799,7 @@ FusionDecision FusionFitsInBudget(const HloInstruction& instr1, if (instr1.operand_count() + instr2.operand_count() - 1 + num_output_buffers <= MaxOperandsAndOutputsPerFusion()) { - return {}; + return FusionDecision::Allow(); } else { VLOG(5) << "Operand count of " << "(" << instr1.ToString() << " ) = " << instr1.operand_count() << " and ( " @@ -816,15 +824,16 @@ FusionDecision FusionFitsInBudget(const HloInstruction& instr1, // consumer numbers of output. So no need to check it. if (is_consumer_producer_fusion && operands.size() <= instr1.operands().size()) { - return {}; + return FusionDecision::Allow(); } // Does the new fusion have more operands and outputs than the max? if (operands.size() + num_output_buffers > MaxOperandsAndOutputsPerFusion()) { - return "Number of operands and output buffers is larger than allowed " - "budget per fusion"; + return FusionDecision::Forbid( + "Number of operands and output buffers is larger than allowed budget " + "per fusion"); } - return {}; + return FusionDecision::Allow(); } bool CreatesHeavyComputation(const HloInstruction& producer, @@ -896,8 +905,10 @@ bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr) { // with any other instruction. // Note that scatter cannot be the root of a multi-output fusion because // its emitter doesn't support it. + // + // Custom fusions cannot be fused with anything. - return instr.IsFusible() && + return instr.IsFusible() && !instr.IsCustomFusion() && (IsInputFusibleReduction(instr) || IsInputFusibleTranspose(instr) || instr.IsLoopFusion() || // TODO(b/130013493): Use IsLoopFusible here. instr.IsElementwise()); diff --git a/third_party/xla/xla/service/gpu/gpu_fusible.h b/third_party/xla/xla/service/gpu/gpu_fusible.h index 0dadbfa36f5476..e9c4309ad3b82a 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible.h +++ b/third_party/xla/xla/service/gpu/gpu_fusible.h @@ -228,6 +228,9 @@ bool IsGenericTritonFusion(const HloInstruction& instr); // instructions it contains. bool MayPreventVectorization(const HloFusionAdaptor& fusion); +// Returns the max loop unroll factor. +inline constexpr int64_t MaxUnrollFactor() { return 4; } + LaunchDimensionsConfig ComputeLoopFusionConfig( const HloFusionAnalysis& analysis); diff --git a/third_party/xla/xla/service/gpu/gpu_fusible_test.cc b/third_party/xla/xla/service/gpu/gpu_fusible_test.cc index 735709cbd346f8..51af8693af2f30 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_fusible_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" @@ -488,23 +488,14 @@ TEST_F(GpuFusibleTest, TEST_F(GpuFusibleTest, CustomFusionIsNotFusibleAsConsumer) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( -HloModule m - triton_fusion { - p0 = f16[20,3]{1,0} parameter(0) - p1 = f16[3,40]{1,0} parameter(1) - dot = f16[20,40]{1,0} dot(p0, p1), - lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT c = f16[20,40]{0,1} copy(dot) + p = s32[20,3] parameter(0) + ROOT neg = s32[20,3] negate(p) } ENTRY e { - p0 = f16[20,3]{1,0} parameter(0) - n = f16[20,3]{1,0} negate(p0) - p1 = f16[3,40]{1,0} parameter(1) - ROOT r = f16[20,40]{0,1} fusion(n, p1), - kind=kCustom, - calls=triton_fusion + p = s32[20,3] parameter(0) + ROOT r = s32[20,3] fusion(p), kind=kCustom, calls=triton_fusion })")); const HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_FALSE(IsFusibleAsMultiOutputFusionRoot(*root)); diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc index ee2e9874e358cf..d89635fe871e97 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc @@ -42,6 +42,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/buffer_value.h" #include "xla/service/collective_ops_utils.h" @@ -51,7 +52,6 @@ limitations under the License. #include "xla/service/gpu/transforms/pgle_accuracy_checker.h" #include "xla/service/gpu/transforms/schedule_postprocessing.h" #include "xla/service/gpu/transforms/scheduling_instruction_annotator.h" -#include "xla/service/hlo_memory_scheduler.h" #include "xla/service/latency_hiding_scheduler.h" #include "xla/service/p2p_schedule_preparation.h" #include "xla/service/profile_guided_latency_estimator.h" @@ -247,20 +247,10 @@ HloInstructionSequence PostprocessorToScheduleSyncCollectives( return result; } -absl::StatusOr ScheduleGpuModuleWithMemoryScheduler( - const HloModule* module, int64_t pointer_size) { - return ScheduleModule( - module, - [pointer_size](const BufferValue& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); - }, - ComputationSchedulerToModuleScheduler(DefaultMemoryScheduler, - PostProcessSchedule)); -} - // Latency hiding scheduler support. -SchedulerConfig GetSchedulerConfig(int64_t memory_limit) { +SchedulerConfig GetSchedulerConfig(int64_t memory_limit, + int64_t collective_resource) { SchedulerConfig config; config.all_reduce_overlap_limit = 1; config.collective_broadcast_overlap_limit = 1; @@ -269,6 +259,7 @@ SchedulerConfig GetSchedulerConfig(int64_t memory_limit) { config.aggressive_scheduling_policies = true; config.schedule_send_recvs = true; config.memory_limit = memory_limit; + config.parallel_collective_overlap_limit = collective_resource; return config; } @@ -458,16 +449,47 @@ absl::StatusOr ScheduleGpuModule( VLOG(1) << "Fingerprint before LHS for module " << module->name() << "(" << module->unique_id() << ") = " << fingerprint; + const DebugOptions& options = module->config().debug_options(); const bool enable_latency_hiding_scheduler = - module->config() - .debug_options() - .xla_gpu_enable_latency_hiding_scheduler(); + options.xla_gpu_enable_latency_hiding_scheduler(); if (!enable_latency_hiding_scheduler) { return ScheduleMetadata{memory_limit}; } - SchedulerConfig config = GetSchedulerConfig(memory_limit); + if (options.xla_gpu_pgle_profile_file_or_directory_path().empty() && + module->config().fdo_profile().empty() && + options.xla_gpu_pgle_accuracy_checker() == + DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR) { + return absl::InvalidArgumentError( + "xla_gpu_pgle_accuracy_checker is set to ERROR, but no profile " + "path specified in xla_gpu_pgle_profile_file_or_directory_path"); + } + + if (options.xla_gpu_pgle_profile_file_or_directory_path().empty() && + module->config().fdo_profile().empty() && + options.xla_gpu_pgle_accuracy_checker(), + DebugOptions::PGLE_STRICTNESS_LEVEL_WARN) { + LOG(WARNING) + << "xla_gpu_pgle_accuracy_checker is set to WARN, but no profile path " + "specified in xla_gpu_pgle_profile_file_or_directory_path"; + } + + SchedulerConfig config = GetSchedulerConfig( + memory_limit, + module->config() + .debug_options() + .xla_gpu_experimental_parallel_collective_overlap_limit()); + CHECK((config.collective_broadcast_overlap_limit <= + config.parallel_collective_overlap_limit) && + (config.all_to_all_overlap_limit <= + config.parallel_collective_overlap_limit) && + (config.all_gather_overlap_limit <= + config.parallel_collective_overlap_limit) && + (config.all_reduce_overlap_limit <= + config.parallel_collective_overlap_limit) && + (config.reduce_scatter_overlap_limit <= + config.parallel_collective_overlap_limit)); auto gpu_latency_estimator = std::make_unique(pointer_size); @@ -476,9 +498,7 @@ absl::StatusOr ScheduleGpuModule( ReadPGLEProfile(module, fingerprint); const bool enable_analytical_latency_estimator = - module->config() - .debug_options() - .xla_gpu_enable_analytical_latency_estimator(); + options.xla_gpu_enable_analytical_latency_estimator(); HloPassPipeline pipeline("latency-hiding-scheduler"); if (profile.has_value()) { auto aggregator = std::make_unique(); @@ -487,9 +507,10 @@ absl::StatusOr ScheduleGpuModule( std::move(aggregator)); LOG(INFO) << "Found profile, using profile guided latency estimator"; VLOG(1) << "Profile:\n" << profile->DebugString(); - if (module->config() - .debug_options() - .xla_gpu_enable_pgle_accuracy_checker()) { + if (options.xla_gpu_pgle_accuracy_checker() == + DebugOptions::PGLE_STRICTNESS_LEVEL_WARN || + options.xla_gpu_pgle_accuracy_checker() == + DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR) { pipeline.AddPass(*pg_latency_estimator); } latency_estimator = std::move(pg_latency_estimator); @@ -506,9 +527,7 @@ absl::StatusOr ScheduleGpuModule( } auto async_tracker = [&]() -> std::unique_ptr { - return module->config() - .debug_options() - .xla_gpu_lhs_enable_gpu_async_tracker() + return options.xla_gpu_lhs_enable_gpu_async_tracker() ? std::make_unique(config) : std::make_unique(config); }(); @@ -533,6 +552,18 @@ absl::StatusOr ScheduleGpuModule( return ScheduleMetadata{memory_limit}; } +absl::StatusOr ScheduleGpuModuleWithMemoryScheduler( + const HloModule* module, int64_t pointer_size, int64_t* peak_memory_bytes) { + return ScheduleModule( + module, + [pointer_size](const BufferValue& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), pointer_size); + }, + ComputationSchedulerToModuleScheduler(DefaultMemoryScheduler, + PostProcessSchedule), + /*execution_threads=*/{}, /*peak_memory=*/peak_memory_bytes); +} + HloInstructionSequence PostProcessSchedule( const HloInstructionSequence& input) { HloInstructionSequence result = PostprocessorToScheduleSyncCollectives(input); diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.h b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.h index b71226c20710a9..af0e7dd3fa9afc 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.h +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.h @@ -37,6 +37,13 @@ absl::StatusOr ScheduleGpuModule( HloModule* module, int64_t pointer_size, const se::DeviceDescription& gpu_device_info); +// Schedules a GPU module with `DefaultMemoryScheduler` and +// `PostProcessSchedule` postprocessing. If `peak_memory_bytes` is not nullptr, +// then the it will be set to peak memory usage in bytes. +absl::StatusOr ScheduleGpuModuleWithMemoryScheduler( + const HloModule* module, int64_t pointer_size, + int64_t* peak_memory_bytes = nullptr); + HloInstructionSequence PostProcessSchedule(const HloInstructionSequence& input); constexpr absl::string_view kFingerprintBeforeLHS = "fingerprint_before_lhs"; diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc index 0f9c8412bcdfc1..2f4756c3ed71ee 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "xla/hlo/analysis/hlo_ordering.h" #include "xla/hlo/ir/collective_device_list.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -38,8 +39,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/backend.h" +#include "xla/service/gpu/gpu_compiler.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_ordering.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" @@ -347,6 +348,32 @@ TEST_F(GpuHloScheduleTest, LHSCostModel) { EXPECT_TRUE(HasValidFingerprint(module.get())); } +TEST_F(GpuHloScheduleTest, + ScheduleGpuModuleWithMemorySchedulerReturnsPeakMemoryBytes) { + absl::string_view kHloText = R"( + HloModule m + + ENTRY ar { + p0 = f32[32,32] parameter(0) + p1 = f32[32,32] parameter(1) + + ROOT _ = f32[32,32]{1,0} custom-call(p0, p1), + custom_call_target="__cublas$gemm" + })"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule( + kHloText, GetModuleConfig(/*enable_latency_hiding_scheduler=*/true))); + int64_t pointer_size = + dynamic_cast(backend().compiler())->GetPointerSize(); + int64_t peak_memory_bytes = -1; + TF_ASSERT_OK_AND_ASSIGN(auto schedule, + ScheduleGpuModuleWithMemoryScheduler( + module.get(), pointer_size, &peak_memory_bytes)); + EXPECT_GT(peak_memory_bytes, 0); +} + TEST_F(GpuHloScheduleTest, LHSCostModelCostlyAR) { const char* hlo_text = R"( HloModule AsyncAR @@ -539,7 +566,8 @@ TEST_F(GpuHloScheduleTest, ProfileGuidedCostModelFailsWithIncompleteProfile) { HloModuleConfig config(module->config()); DebugOptions dboptions(config.debug_options()); - dboptions.set_xla_gpu_enable_pgle_accuracy_checker(true); + dboptions.set_xla_gpu_pgle_accuracy_checker( + DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR); config.set_debug_options(dboptions); module->set_config(config); @@ -1637,5 +1665,62 @@ TEST_F(GpuHloScheduleTest, AsyncOps) { HloOpcode::kAsyncDone, HloOpcode::kAdd)); } +// This test verifies that the latency hiding scheduler overlaps host memory +// offloading (copy-start/copy-done) with computation. +TEST_F(GpuHloScheduleTest, CopyStartDoneScheduled) { + constexpr absl::string_view kHloCopyStartDone = R"( + HloModule offloading + ENTRY main { + param.0 = f32[512,1024]{1,0} parameter(0) + tanh.14 = f32[512,1024]{1,0} tanh(param.0) + copy-start.1 = (f32[512,1024]{1,0:S(5)}, f32[512,1024]{1,0}, u32[]) copy-start(param.0) + copy-done.1 = f32[512,1024]{1,0:S(5)} copy-done(copy-start.1) + copy-start.3 = (f32[512,1024]{1,0}, f32[512,1024]{1,0:S(5)}, u32[]) copy-start(copy-done.1) + copy-done.3 = f32[512,1024]{1,0} copy-done(copy-start.3) + ROOT add.0 = f32[512,1024]{1,0} add(copy-done.3, tanh.14) + })"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule( + kHloCopyStartDone, + GetModuleConfig(/*enable_latency_hiding_scheduler=*/true))); + TF_CHECK_OK(ScheduleGpuModule( + module.get(), /*pointer_size=*/8, + backend().default_stream_executor()->GetDeviceDescription()) + .status()); + EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( +// CHECK: ENTRY +// CHECK: copy-start.3 = (f32[512,1024]{1,0}, f32[512,1024]{1,0:S(5)}, u32[]) copy-start +// CHECK: tanh.14 = f32[512,1024]{1,0} tanh +// CHECK: copy-done.3 = f32[512,1024]{1,0} copy-done +)")); +} + +TEST_F(GpuHloScheduleTest, InvalidPGLEOptions) { + const char* hlo = R"( + HloModule test + ENTRY add { + a = s32[] parameter(0) + b = s32[] parameter(1) + ROOT add = add(a,b) + } + )"; + + HloModuleConfig config; + DebugOptions options; + options.set_xla_gpu_pgle_accuracy_checker( + DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR); + options.set_xla_gpu_enable_latency_hiding_scheduler(true); + config.set_debug_options(options); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo, config)); + + GTEST_FLAG_SET(death_test_style, "threadsafe"); + EXPECT_DEATH(BuildHloOrdering(module.get()), + "xla_gpu_pgle_accuracy_checker is set to ERROR, but no profile " + "path specified in xla_gpu_pgle_profile_file_or_directory_path"); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.cc b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.cc index 4daaba52dff7bd..418d29dbbd3362 100644 --- a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.cc +++ b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.cc @@ -91,18 +91,23 @@ std::pair GetP2PResourceAndUsage( return {resource, usage}; } +// Marks async start operations to be scheduled as early as possible. +// It allows maximum overlap of operations while respecting dependencies. +// Besides async collectives, copy-start is async memcpy D2H/H2D, the beginning +// of a host offloading segment. bool IsGpuAsyncStart(const HloInstruction& hlo) { return (hlo_query::IsAsyncCollectiveStartOp(&hlo, /*include_send_recv=*/true) && !IsSyncCollective(&hlo)) || - IsAsyncComputeOp(hlo); + IsAsyncComputeOp(hlo) || hlo.opcode() == HloOpcode::kCopyStart; } +// Marks async done operations to be scheduled as late as possible. bool IsGpuAsyncDone(const HloInstruction& hlo) { return (hlo_query::IsAsyncCollectiveDoneOp(&hlo, /*include_send_recv=*/true) && !IsSyncCollective(hlo.operand(0))) || - IsAsyncComputeOp(hlo); + IsAsyncComputeOp(hlo) || hlo.opcode() == HloOpcode::kCopyDone; } bool IsAsyncPair(const HloInstruction& from, const HloInstruction& target) { @@ -183,7 +188,7 @@ void GpuAsyncTrackerBase::PostProcessScheduleGraph( GpuAsyncTracker::GpuAsyncTracker(const SchedulerConfig& config) : GpuAsyncTrackerBase(config) {} -ResourcesVector GpuAsyncTracker::GetResourcesFromInstruction( +ResourcesVector GpuAsyncTracker::GetResourcesFromInstructionImpl( const HloInstruction& instr) const { CanonicalAsyncOp op = GetCanonicalAsyncOp(instr); if (op.outer == HloOpcode::kAsyncStart || op.outer == HloOpcode::kAsyncDone) { @@ -203,7 +208,7 @@ ResourcesVector GpuAsyncTracker::GetResourcesFromInstruction( GetFirstTargetDefinedResource() + static_cast(resource), usage)}; } - return GpuAsyncTrackerBase::GetResourcesFromInstruction(instr); + return GpuAsyncTrackerBase::GetResourcesFromInstructionImpl(instr); } int64_t GpuAsyncTracker::GetNumTargetDefinedResources() const { @@ -239,6 +244,11 @@ int64_t GpuAsyncTracker::GetNumAvailableResources(int64_t resource_type) const { return 2; } + if ((resource_type - first_target_resource) == + static_cast(GpuResourceType::kGpuAsyncStreamCollectives)) { + return config_.parallel_collective_overlap_limit; + } + return 1; } @@ -393,7 +403,8 @@ ApproximateLatencyEstimator::TimeCost GpuLatencyEstimator::GetLatencyBetween( void GPUProfileStatisticsAggregator::HandleMissingInstructionCost( const HloInstruction& instruction) { if (!IsNopInstruction(instruction) && - instruction.opcode() != HloOpcode::kWhile) { + HloPredicateIsNotOp(&instruction) && + HloPredicateIsNotOp(&instruction)) { missing_instructions_.insert(&instruction); } } diff --git a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.h b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.h index b0db29c812cb37..ad6f67d774924b 100644 --- a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.h +++ b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler.h @@ -76,7 +76,7 @@ class GpuAsyncTracker : public GpuAsyncTrackerBase { explicit GpuAsyncTracker(const SchedulerConfig& config); // Returns resources used (occupied or released) by `instr`. - ResourcesVector GetResourcesFromInstruction( + ResourcesVector GetResourcesFromInstructionImpl( const HloInstruction& instr) const override; // Returns the number of target defined resources diff --git a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc index e2f5b7318f54e6..42e05cf9db71cd 100644 --- a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc @@ -39,21 +39,33 @@ using ::testing::Property; using ::testing::UnorderedElementsAre; using ::tsl::testing::StatusIs; +int GetIndexByName(absl::Span instruction_sequence, + absl::string_view hlo_name) { + return absl::c_find_if(instruction_sequence, + [hlo_name](HloInstruction* instruction) { + return instruction->name() == hlo_name; + }) - + instruction_sequence.begin(); +} + // TODO(b/346918304): Separate relevant tests from gpu_hlo_schedule_test.cc // into broader GPU scheduling related tests vs. tests related to components of // GPU LHS. class GpuLatencyHidingSchedulerBaseTest : public HloTestBase { protected: - absl::StatusOr ScheduleModule(HloModule* module) { + absl::StatusOr ScheduleModule( + HloModule* module, int64_t num_parallel_resources = 1, + DebugOptions::PGLEStrictnessLevel strictness = + DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR) { auto& test_backend = backend(); const auto& gpu_device_info = test_backend.default_stream_executor()->GetDeviceDescription(); - HloModuleConfig config(module->config()); - DebugOptions dboptions(config.debug_options()); - dboptions.set_xla_gpu_enable_pgle_accuracy_checker(true); - config.set_debug_options(dboptions); - module->set_config(config); + DebugOptions& options = module->mutable_config().mutable_debug_options(); + options.set_xla_gpu_experimental_parallel_collective_overlap_limit( + num_parallel_resources); + options.set_xla_gpu_pgle_accuracy_checker(strictness); + TF_RETURN_IF_ERROR( ScheduleGpuModule(module, /*pointer_size=*/8, gpu_device_info) .status()); @@ -110,6 +122,44 @@ TEST_F(GpuLatencyHidingSchedulerBaseTest, } } +// Copies are not fusion wrapped. We ran a fusion wrapper prior to scheduling +// which wrapped copies and some copies were prevented from copy elision by copy +// insertion pass which runs after scheduling. Potentially we might end up with +// unrecognized instructions at scheduling time. +// +// See b/373800086 for more context. +TEST_F(GpuLatencyHidingSchedulerBaseTest, + GPUProfileStatisticsAggregatorDoesNotCountCopies) { + GPUProfileStatisticsAggregator aggregator; + ProfileStatisticsAggregator::Statistics before_stats = aggregator.GetStats(); + + ASSERT_EQ(before_stats.missing_instructions.size(), 0); + ASSERT_EQ(before_stats.found_instructions_count, 0); + + absl::string_view kFdoProfile = ""; + absl::string_view kHloModule = R"( + HloModule m + + ENTRY main { + parameter.0 = f32[] parameter(0) + ROOT copy.0 = copy(parameter.0) + } + )"; + + auto config = GetModuleConfig(kFdoProfile); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloModule, config)); + + for (const HloInstruction* instr : + module->entry_computation()->instructions()) { + aggregator.HandleMissingInstructionCost(*instr); + + ProfileStatisticsAggregator::Statistics after_stats = aggregator.GetStats(); + EXPECT_EQ(after_stats.missing_instructions.size(), 0); + EXPECT_EQ(after_stats.found_instructions_count, 0); + } +} + TEST_F(GpuLatencyHidingSchedulerBaseTest, GPUProfileStatisticsAggregatorCountsMissingInstruction) { GPUProfileStatisticsAggregator aggregator; @@ -282,5 +332,59 @@ TEST_F(GpuLatencyHidingSchedulerBaseTest, TF_EXPECT_OK(ScheduleModule(module.get())); } +TEST_F(GpuLatencyHidingSchedulerBaseTest, + MultipleParallelResourceShouldOverlapCollectives) { + absl::string_view kFdoProfile = R"pb( + costs { name: "add_0" cost_us: 100000.0 } + costs { name: "ar_0" cost_us: 10.0 } + costs { name: "rs_0" cost_us: 10.0 } + )pb"; + ; + absl::string_view kHloModule = R"( + HloModule m + + reduce { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT _ = f32[] add(x, y) + } + + ENTRY main { + p0 = f32[] parameter(0) + p1 = f32[2] parameter(1) + p2 = f32[2] parameter(2) + ar_0 = f32[] all-reduce-start(p0), to_apply=reduce + ar_1 = f32[] all-reduce-done(ar_0) + rs_0 = ((f32[2]), f32[1]) reduce-scatter-start(p1), to_apply=reduce, dimensions={0} + rs_1 = f32[1] reduce-scatter-done(rs_0) + add_0 = f32[2] add(p1, p2) + ROOT _ = (f32[], f32[1], f32[2]) tuple(ar_1, rs_1, add_0) + } + )"; + + auto config = GetModuleConfig(kFdoProfile); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloModule, config)); + + TF_EXPECT_OK(ScheduleModule(module.get(), /*num_parallel_resources=*/2)); + auto schedule = module->schedule(); + std::vector instruction_sequence = + schedule.sequence(module->entry_computation()).instructions(); + // Since we allow 2 collectives in-flight, we should expect this pattern: + // ar(rs)-start -> rs(ar)-start -> add -> ar(rs)-done -> ar(rs)-done + EXPECT_TRUE(GetIndexByName(instruction_sequence, "ar_0") < + GetIndexByName(instruction_sequence, "rs_1") && + GetIndexByName(instruction_sequence, "rs_0") < + GetIndexByName(instruction_sequence, "ar_1")); + EXPECT_TRUE(GetIndexByName(instruction_sequence, "add_0") > + GetIndexByName(instruction_sequence, "ar_0") && + GetIndexByName(instruction_sequence, "add_0") > + GetIndexByName(instruction_sequence, "rs_0") && + GetIndexByName(instruction_sequence, "add_0") < + GetIndexByName(instruction_sequence, "ar_1") && + GetIndexByName(instruction_sequence, "add_0") < + GetIndexByName(instruction_sequence, "rs_1")); +} + } // namespace } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/gpu_memory_space_assignment.h b/third_party/xla/xla/service/gpu/gpu_memory_space_assignment.h index fff43614afd98e..9dc80eb45a7c64 100644 --- a/third_party/xla/xla/service/gpu/gpu_memory_space_assignment.h +++ b/third_party/xla/xla/service/gpu/gpu_memory_space_assignment.h @@ -19,10 +19,10 @@ limitations under the License. #include #include "absl/status/status.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/analysis/hlo_ordering.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/hlo_alias_analysis.h" -#include "xla/service/hlo_ordering.h" #include "xla/service/hlo_value.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/gpu_offloading_test.cc b/third_party/xla/xla/service/gpu/gpu_offloading_test.cc index 3aa9d79977bd81..a3ac9f2f95b12f 100644 --- a/third_party/xla/xla/service/gpu/gpu_offloading_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_offloading_test.cc @@ -27,14 +27,14 @@ limitations under the License. #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" +#include "xla/hlo/transforms/simplifiers/hlo_rematerialization.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/layout.h" #include "xla/service/buffer_value.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/transforms/stream_attribute_annotator.h" #include "xla/service/hlo_cost_analysis.h" -#include "xla/service/hlo_memory_scheduler.h" -#include "xla/service/hlo_rematerialization.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" diff --git a/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner.cc b/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner.cc index d097d1b6382517..9920fc93dacf33 100644 --- a/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner.cc +++ b/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner.cc @@ -27,10 +27,10 @@ limitations under the License. #include "absl/strings/str_join.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/collective_pipeliner.h" -#include "xla/service/hlo_parser.h" #include "xla/util.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc b/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc index df46055ce0d7e5..c3e603d055f4e4 100644 --- a/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc @@ -29,9 +29,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" #include "xla/service/hlo_verifier.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" diff --git a/third_party/xla/xla/service/gpu/gpu_spmd_pipeline.cc b/third_party/xla/xla/service/gpu/gpu_spmd_pipeline.cc index b6d99a85194d5f..e2299be1074b29 100644 --- a/third_party/xla/xla/service/gpu/gpu_spmd_pipeline.cc +++ b/third_party/xla/xla/service/gpu/gpu_spmd_pipeline.cc @@ -25,22 +25,22 @@ limitations under the License. #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/pass/hlo_pass_fix.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" -#include "xla/hlo/transforms/hlo_constant_splitter.h" -#include "xla/service/algebraic_simplifier.h" +#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" +#include "xla/hlo/transforms/simplifiers/hlo_constant_folding.h" +#include "xla/hlo/transforms/simplifiers/hlo_constant_splitter.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/reshape_mover.h" +#include "xla/hlo/transforms/simplifiers/sort_simplifier.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/service/conditional_simplifier.h" #include "xla/service/gather_expander.h" #include "xla/service/gpu/transforms/algebraic_simplifier.h" -#include "xla/service/hlo_constant_folding.h" -#include "xla/service/hlo_dce.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/reshape_mover.h" #include "xla/service/scatter_expander.h" #include "xla/service/sharding_propagation.h" -#include "xla/service/sort_simplifier.h" #include "xla/service/spmd/collective_permute_motion.h" #include "xla/service/spmd/shardy/shardy_xla_pass.h" #include "xla/service/spmd/stateful_rng_spmd_partitioner.h" -#include "xla/service/tuple_simplifier.h" #include "xla/service/while_loop_constant_sinking.h" #include "xla/service/while_loop_simplifier.h" #include "xla/stream_executor/device_description.h" @@ -102,6 +102,15 @@ void AddSPMDPasses( /*is_spmd=*/true, /*propagate_metadata=*/false, config.allow_spmd_sharding_propagation_to_output()); } + std::optional oper_size_threshold = std::nullopt; + if (hlo_module->config() + .debug_options() + .xla_gpu_operand_bytes_threshold_for_windowed_einsum() >= 0) { + oper_size_threshold = + hlo_module->config() + .debug_options() + .xla_gpu_operand_bytes_threshold_for_windowed_einsum(); + } spmd_pipeline.AddPass( num_partitions, hlo_module->config().replica_count(), hlo_module->config() @@ -111,7 +120,7 @@ void AddSPMDPasses( .debug_options() .xla_gpu_multi_streamed_windowed_einsum(), /*skip_checking_windowed_einsum_users=*/true, - /*disable_ag_rewrite_for_multiple_consumers=*/true); + /*disable_ag_rewrite_for_multiple_consumers=*/true, oper_size_threshold); spmd_pipeline.AddPass(); } diff --git a/third_party/xla/xla/service/gpu/gpu_spmd_pipeline.h b/third_party/xla/xla/service/gpu/gpu_spmd_pipeline.h index e969965f5842b2..9b567d55a100bb 100644 --- a/third_party/xla/xla/service/gpu/gpu_spmd_pipeline.h +++ b/third_party/xla/xla/service/gpu/gpu_spmd_pipeline.h @@ -21,7 +21,7 @@ limitations under the License. #include "absl/functional/function_ref.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" -#include "xla/service/algebraic_simplifier.h" +#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/gpu_spmd_pipeline_test.cc b/third_party/xla/xla/service/gpu/gpu_spmd_pipeline_test.cc index c0a7d68ebc3af5..2348dee1afd63b 100644 --- a/third_party/xla/xla/service/gpu/gpu_spmd_pipeline_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_spmd_pipeline_test.cc @@ -25,10 +25,11 @@ limitations under the License. #include "absl/log/log.h" #include "xla/client/executable_build_options.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" -#include "xla/service/algebraic_simplifier.h" +#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" +#include "xla/service/spmd/shardy/constants.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" @@ -51,6 +52,11 @@ class GpuSpmdPartitioningTest : public HloTestBase, config.set_use_shardy_partitioner(UseShardy()); TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_module, config)); + if (UseShardy()) { + FrontendAttributes attrs; + attrs.mutable_map()->try_emplace(xla::sdy::kImportMhloShardings, "t"); + module->add_frontend_attributes(attrs); + } HloPassPipeline spmd_pipeline("spmd-partitioner"); se::CudaComputeCapability ampere(8, 0); @@ -66,7 +72,7 @@ class GpuSpmdPartitioningTest : public HloTestBase, protected: bool UseShardy() const { return GetParam(); } - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); return debug_options; } diff --git a/third_party/xla/xla/service/gpu/gpu_transfer_manager.cc b/third_party/xla/xla/service/gpu/gpu_transfer_manager.cc index dc770514bdda2b..17876412ab49fb 100644 --- a/third_party/xla/xla/service/gpu/gpu_transfer_manager.cc +++ b/third_party/xla/xla/service/gpu/gpu_transfer_manager.cc @@ -48,6 +48,7 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/sycl/sycl_platform_id.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" @@ -238,7 +239,8 @@ static absl::Status ForEachChunk( for (int64_t chunk_index = 0; chunk_index < num_chunks; ++chunk_index) { TF_RETURN_IF_ERROR(callback( /*chunk_offset=*/chunk_index * chunk_size, - /*chunk_size=*/std::min(chunk_size, size - chunk_index * chunk_size))); + /*chunk_size=*/std::min( + chunk_size, static_cast(size - chunk_index * chunk_size)))); } return absl::OkStatus(); } @@ -369,11 +371,20 @@ static std::unique_ptr CreateAMDGPUTransferManager() { .getPointerSize(0 /* default address space */)); } +static std::unique_ptr CreateSYCLTransferManager() { + return std::make_unique( + /*id=*/stream_executor::sycl::kSyclPlatformId, + /*pointer_size=*/llvm::DataLayout(xla::gpu::spir::DataLayout()) + .getPointerSize(0 /* default address space */)); +} + static bool InitModule() { xla::TransferManager::RegisterTransferManager( stream_executor::cuda::kCudaPlatformId, &CreateNVPTXTransferManager); xla::TransferManager::RegisterTransferManager( stream_executor::rocm::kROCmPlatformId, &CreateAMDGPUTransferManager); + xla::TransferManager::RegisterTransferManager( + stream_executor::sycl::kSyclPlatformId, &CreateSYCLTransferManager); return true; } diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_stats_test.cc b/third_party/xla/xla/service/gpu/hlo_fusion_stats_test.cc index c2d33da7fcf408..abc10887f3f8e7 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_stats_test.cc +++ b/third_party/xla/xla/service/gpu/hlo_fusion_stats_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "absl/strings/match.h" -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" diff --git a/third_party/xla/xla/service/gpu/infeed_manager.cc b/third_party/xla/xla/service/gpu/infeed_manager.cc index e0fa03b9137790..4524d4c09d58b5 100644 --- a/third_party/xla/xla/service/gpu/infeed_manager.cc +++ b/third_party/xla/xla/service/gpu/infeed_manager.cc @@ -42,7 +42,7 @@ constexpr int kMaxInfeedsInFlight = 8; InfeedManager::InfeedManager(se::StreamExecutor* executor) : BlockingXfeedQueue(/*max_pending_xfeeds=*/kMaxInfeedsInFlight), stream_(executor->CreateStream().value()) { - stream_->set_name("Infeed manager"); + stream_->SetName("Infeed manager"); } static absl::StatusOr CopyBufferToDevice( diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.cc b/third_party/xla/xla/service/gpu/ir_emission_utils.cc index 406fcd9534a9dc..b0e52de8c3a0f6 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.cc @@ -101,7 +101,8 @@ bool IsMatrixMultiplication(const HloInstruction& dot) { PrimitiveType output_primitive_type = dot.shape().element_type(); bool type_is_allowed = - (output_primitive_type == F8E4M3FN || output_primitive_type == F8E5M2 || + (output_primitive_type == F8E3M4 || output_primitive_type == F8E4M3 || + output_primitive_type == F8E4M3FN || output_primitive_type == F8E5M2 || output_primitive_type == F8E4M3FNUZ || output_primitive_type == F8E5M2FNUZ || output_primitive_type == F16 || output_primitive_type == BF16 || output_primitive_type == F32 || @@ -271,7 +272,7 @@ llvm::Value* EmitNVPTXShflDown(llvm::Value* value, llvm::Value* offset, llvm_intrinsic_id = llvm::Intrinsic::nvvm_shfl_sync_down_i32; } llvm::Function* intrinsic = - llvm::Intrinsic::getDeclaration(module, llvm_intrinsic_id, {}); + llvm::Intrinsic::getOrInsertDeclaration(module, llvm_intrinsic_id, {}); return b->CreateCall( intrinsic, {b->getInt32(-1), value, offset, b->getInt32(WarpSize() - 1)}); } diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index bfc49082958ba4..28850ef8830b4b 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -43,7 +44,6 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" -#include "llvm/ADT/TypeSwitch.h" #include "llvm/IR/Argument.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -58,10 +58,8 @@ limitations under the License. #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Parser/Parser.h" #include "mlir/Support/LLVM.h" @@ -128,6 +126,8 @@ limitations under the License. #include "xla/service/gpu/runtime/nccl_all_reduce_thunk.h" #include "xla/service/gpu/runtime/nccl_all_to_all_thunk.h" #include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_clique.h" +#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/gpu/runtime/nccl_collective_broadcast_thunk.h" #include "xla/service/gpu/runtime/nccl_collective_permute_thunk.h" #include "xla/service/gpu/runtime/nccl_collective_thunk.h" @@ -157,14 +157,13 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" -#include "xla/stream_executor/integrations/device_mem_allocator.h" #include "xla/stream_executor/launch_dim.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/human_readable_json.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/dnn.pb.h" #include "triton/Dialect/Triton/IR/Dialect.h" #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -740,8 +739,7 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunk( absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8( const HloCustomCallInstruction* instr) { - TF_RET_CHECK(instr->operand_count() == 6 || instr->operand_count() == 7 || - instr->operand_count() == 8); + TF_RET_CHECK(instr->operand_count() > 3 && instr->operand_count() < 8); TF_ASSIGN_OR_RETURN(const auto gpu_config, instr->backend_config()); const xla::gpu::GemmBackendConfig& config = gpu_config.gemm_backend_config(); @@ -777,22 +775,22 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8( TF_ASSIGN_OR_RETURN( BufferAllocation::Slice b_scale, GetAllocationSliceForHlo(instr->operand(a_scale_index + 1))); + + // cublasLT requires c_scale/d_scale to be null when C/D is not FP8. + // Currently, C cannot be FP8. + BufferAllocation::Slice c_scale, d_scale; #if GOOGLE_CUDA - TF_ASSIGN_OR_RETURN( - BufferAllocation::Slice c_scale, - GetAllocationSliceForHlo(instr->operand(a_scale_index + 2))); - TF_ASSIGN_OR_RETURN( - BufferAllocation::Slice d_scale, - GetAllocationSliceForHlo(instr->operand(a_scale_index + 3))); -#else // TENSORFLOW_USE_ROCM - BufferAllocation::Slice c_scale; - BufferAllocation::Slice d_scale; + if (instr->shape().tuple_shapes(0).element_type() == F8E4M3FN || + instr->shape().tuple_shapes(0).element_type() == F8E5M2) { + TF_ASSIGN_OR_RETURN(d_scale, + GetAllocationSliceForHlo(instr->operands().back())); + } #endif BufferAllocation::Slice bias; if (has_vector_bias) { TF_ASSIGN_OR_RETURN( - bias, GetAllocationSliceForHlo(instr->operand(a_scale_index + 4))); + bias, GetAllocationSliceForHlo(instr->operand(a_scale_index + 2))); } BufferAllocation::Slice d_amax; @@ -1221,14 +1219,15 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk( auto ffi_thunk = [&] { auto& called_computations = instr->called_computations(); return CustomCallThunk::Create( - Thunk::ThunkInfo::WithProfileAnnotation(instr), registration->bundle, - std::move(operands), std::move(results), std::move(attributes), + Thunk::ThunkInfo::WithProfileAnnotation(instr), call_target_name, + registration->bundle, std::move(operands), std::move(results), + std::move(attributes), called_computations.empty() ? nullptr : called_computations[0]); }; auto legacy_thunk = [&] { return CustomCallThunk::Create( - Thunk::ThunkInfo::WithProfileAnnotation(instr), + Thunk::ThunkInfo::WithProfileAnnotation(instr), call_target_name, std::move(custom_call_target), std::move(operands), std::move(results), std::move(opaque)); }; @@ -1516,6 +1515,22 @@ absl::Status IrEmitterUnnested::EmitFusion(const HloFusionInstruction* instr) { return absl::OkStatus(); } +absl::Status IrEmitterUnnested::EmitCopy(const HloInstruction* instr) { + TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( + instr->operand(0)->shape(), instr->shape(), + Layout::Equal().MinorToMajorOnly())); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice src_buffer, + GetAllocationSliceForHlo(instr->operand(0))); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dst_buffer, + GetAllocationSliceForHlo(instr)); + AddThunkToThunkSequence(std::make_unique( + Thunk::ThunkInfo::WithProfileAnnotation(instr), + /*source_buffer=*/src_buffer, + /*destination_buffer=*/dst_buffer, + /*mem_size=*/src_buffer.size())); + return absl::OkStatus(); +} + absl::Status IrEmitterUnnested::EmitAsyncCustomCallStart( const HloInstruction* instr) { const HloInstruction* wrapped = instr->async_wrapped_instruction(); @@ -2140,7 +2155,9 @@ absl::Status IrEmitterUnnested::EmitNcclThunk( thunk_info.profile_annotation = async_start->name(); } auto thunk = std::make_unique( - thunk_info, NcclApi::Default(), inst, /*buffers=*/std::move(buffers)); + thunk_info, NcclApi::Default(), inst, + /*buffers=*/std::move(buffers), + ir_emitter_context_->debug_options().xla_gpu_use_memcpy_local_p2p()); GetCollectivesAsyncEvents().insert({async_start, thunk->async_events()}); AddThunkToThunkSequence(std::move(thunk)); return absl::OkStatus(); @@ -2175,40 +2192,128 @@ absl::Status IrEmitterUnnested::EmitNcclThunk( return absl::OkStatus(); } +// Find the canonical send/recv start op for one of send, recv, send-done, or +// recv-done. For trivial cases send/recv and send-done/recv-done come in pairs +// and the canonical start op is the send/recv op of the pair. If send/recv is +// partially pipelined, we will use the send/recv leading into the while loop as +// the canonical start op, which will serve as a key for the async events. +// +// Example: +// ``` +// send_ctx = send(src, ...) <-- canonical start op +// send_ctx_final = while(send_ctx) { +// send_ctx_in = parameter(0) +// send-done(send_ctx_in) +// ... +// ROOT send_ctx_out = send(next_src, ...) +// } +// send-done(send_ctx_final) +// ``` +static const HloInstruction* FindCanonicalSendRecvStartOp( + const HloInstruction* inst) { + CHECK(inst->opcode() == HloOpcode::kSend || + inst->opcode() == HloOpcode::kRecv || + inst->opcode() == HloOpcode::kSendDone || + inst->opcode() == HloOpcode::kRecvDone); + + // Find container while loop and index for the send/recv case or return + // canonical start op directly. + const HloInstruction* while_op = nullptr; + int64_t i = -1; + if (inst->opcode() == HloOpcode::kSend || + inst->opcode() == HloOpcode::kRecv) { + CHECK_EQ(inst->users().size(), 1); + const HloInstruction* unique_user = inst->users().front(); + + // Return send/recv inst directly if this is a simple send/recv pair. + if (unique_user->opcode() == HloOpcode::kSendDone || + unique_user->opcode() == HloOpcode::kRecvDone) { + return inst; + } + + // Find while loop and index, otherwise. + CHECK(unique_user->opcode() == HloOpcode::kTuple || + unique_user->opcode() == HloOpcode::kWhile); + if (unique_user->IsRoot()) { + // send/recv op in the loop body. + CHECK(unique_user->parent()->IsWhileBodyComputation()); + while_op = unique_user->parent()->WhileCallInstruction(); + i = unique_user->operand_index(inst); + } else { + // send/recv leading into the loop. + CHECK_EQ(unique_user->users().size(), 1); + CHECK(unique_user->users().front()->opcode() == HloOpcode::kWhile); + while_op = unique_user->users().front(); + i = unique_user->operand_index(inst); + } + } + + // Find container while loop and index for the send-done/recv-done case or + // return canonical start op directly. + if (inst->opcode() == HloOpcode::kSendDone || + inst->opcode() == HloOpcode::kRecvDone) { + const HloInstruction* operand = inst->operand(0); + + // Return send/recv inst directly if this is a simple send/recv pair. + if (operand->opcode() == HloOpcode::kSend || + operand->opcode() == HloOpcode::kRecv) { + return operand; + } + + // Find while loop and index, otherwise. + CHECK(operand->opcode() == HloOpcode::kGetTupleElement); + const auto* gte = Cast(operand); + const HloInstruction* iter_tuple = operand->operand(0); + if (iter_tuple->opcode() == HloOpcode::kParameter) { + // send-done/recv-done in the loop body. + CHECK(Cast(iter_tuple)->parameter_number() == 0); + CHECK(operand->parent()->IsWhileBodyComputation()); + while_op = iter_tuple->parent()->WhileCallInstruction(); + i = gte->tuple_index(); + } else { + // send-done/recv-done proceeding the loop. + CHECK(iter_tuple->opcode() == HloOpcode::kWhile); + while_op = iter_tuple; + i = gte->tuple_index(); + } + } + + // Extract canonical start op from while loop's init. + CHECK(while_op != nullptr); + CHECK(0 <= i && i < while_op->shape().tuple_shapes_size()); + const HloInstruction* init = while_op->operand(0); + const HloInstruction* canonical_start_op = init->operand(i); + CHECK(canonical_start_op->opcode() == HloOpcode::kSend || + canonical_start_op->opcode() == HloOpcode::kRecv); + return canonical_start_op; +} + absl::Status IrEmitterUnnested::EmitNcclAsyncDone(Thunk::Kind kind, const HloInstruction* inst) { + // Partial pipelining is only implemented for send/recv. + bool is_send_recv = + kind == Thunk::Kind::kNcclRecvDone || kind == Thunk::Kind::kNcclSendDone; + const HloInstruction* start = + is_send_recv ? FindCanonicalSendRecvStartOp(inst) : inst->operand(0); + + // Find canonical async event. CollectivesAsyncEvents& collectives_async_events = GetCollectivesAsyncEvents(); - if (kind == Thunk::Kind::kNcclRecvDone || - kind == Thunk::Kind::kNcclSendDone) { - const HloChannelInstruction* done = DynCast(inst); - int64_t channel_id = done->channel_id().value(); - // We only pipeline Send/Recv when channel_id > 0, and allows multiple - // and potentially interleaving Send/Recv chains using channel_id = 0. - if (MayPipelineSendRecvChannel(channel_id)) { - auto it = collectives_async_events.find( - GetSendRecvAsyncEventsKey(kind, channel_id)); - TF_RET_CHECK(it != collectives_async_events.end()) - << "couldn't find async events for channel_id " << channel_id; - AddThunkToThunkSequence(std::make_unique( - kind, Thunk::ThunkInfo::WithProfileAnnotation(inst), it->second, - GetStreamKindForSendRecv(DynCast(inst)))); - return absl::OkStatus(); - } - } - - const HloInstruction* start = inst->operand(0); - auto async_events = collectives_async_events.extract(start); - TF_RET_CHECK(async_events) + auto async_events_it = collectives_async_events.find(start); + TF_RET_CHECK(async_events_it != collectives_async_events.end()) << "couldn't find async events for start operation"; // Can be null if no start thunk was created (e.g. if the start op is // degenerate), in which case there's nothing to do here. - if (async_events.mapped()) { - AddThunkToThunkSequence(std::make_unique( - kind, Thunk::ThunkInfo::WithProfileAnnotation(inst), - std::move(async_events.mapped()), AsyncStreamKind::kCollective)); + if (!async_events_it->second) return absl::OkStatus(); + + AsyncStreamKind stream_kind = AsyncStreamKind::kCollective; + if (is_send_recv) { + stream_kind = GetStreamKindForSendRecv(Cast(start)); } + AddThunkToThunkSequence(std::make_unique( + kind, Thunk::ThunkInfo::WithProfileAnnotation(inst), + async_events_it->second, stream_kind)); return absl::OkStatus(); } @@ -2463,8 +2568,10 @@ absl::Status IrEmitterUnnested::EmitCopyDoneThunk(const HloInstruction* instr) { } absl::Status IrEmitterUnnested::EmitSendThunk(const HloSendInstruction* instr) { - if (!instr->channel_id().has_value()) + // TODO(b/372306903): Do not require channel id for send. + if (!instr->channel_id().has_value()) { return absl::InternalError("Unknown send instruction channel id"); + } const HloInstruction* src = instr->operand(0); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer, @@ -2488,20 +2595,12 @@ absl::Status IrEmitterUnnested::EmitSendThunk(const HloSendInstruction* instr) { instr, replica_count, partition_count, nccl_buffer); CollectivesAsyncEvents& collectives_async_events = GetCollectivesAsyncEvents(); - int64_t channel_id = instr->channel_id().value(); - if (MayPipelineSendRecvChannel(channel_id)) { - std::pair async_events_key = - GetSendRecvAsyncEventsKey(Thunk::Kind::kNcclSendDone, channel_id); - auto it = collectives_async_events.find(async_events_key); - if (it != collectives_async_events.end()) { - VLOG(0) << "Found async events " << it->second.get(); - thunk->set_async_events(it->second); - } else { - VLOG(0) << "Used Async events create for thunk " - << thunk->async_events().get(); - collectives_async_events.emplace(async_events_key, - thunk->async_events()); - } + + // Wire up async events. + const HloInstruction* canonical_send_instr = + FindCanonicalSendRecvStartOp(instr); + if (collectives_async_events.contains(canonical_send_instr)) { + thunk->set_async_events(collectives_async_events[canonical_send_instr]); } else { collectives_async_events.try_emplace(instr, thunk->async_events()); } @@ -2521,8 +2620,10 @@ absl::Status IrEmitterUnnested::EmitSendThunk(const HloSendInstruction* instr) { absl::Status IrEmitterUnnested::EmitSendDoneThunk( const HloSendDoneInstruction* instr) { - if (!instr->channel_id().has_value()) + // TODO(b/372306903): Do not require channel id for send-done. + if (!instr->channel_id().has_value()) { return absl::InternalError("Unknown send done instruction channel id"); + } if (!instr->is_host_transfer()) { return EmitNcclAsyncDone(Thunk::kNcclSendDone, instr); @@ -2536,8 +2637,11 @@ absl::Status IrEmitterUnnested::EmitSendDoneThunk( } absl::Status IrEmitterUnnested::EmitRecvThunk(const HloRecvInstruction* instr) { - if (!instr->channel_id().has_value()) + // TODO(b/372306903): Do not require channel id for recv. + if (!instr->channel_id().has_value()) { return absl::InternalError("Unknown recv instruction channel id"); + } + TF_RET_CHECK(instr->shape().IsTuple()); TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer, GetAllocationSliceForHlo(instr, {0})); @@ -2563,18 +2667,12 @@ absl::Status IrEmitterUnnested::EmitRecvThunk(const HloRecvInstruction* instr) { instr, replica_count, partition_count, nccl_buffer); CollectivesAsyncEvents& collectives_async_events = GetCollectivesAsyncEvents(); - int64_t channel_id = instr->channel_id().value(); - if (MayPipelineSendRecvChannel(channel_id)) { - std::pair async_events_key = - GetSendRecvAsyncEventsKey(Thunk::Kind::kNcclRecvDone, channel_id); - auto it = collectives_async_events.find(async_events_key); - - if (it != GetCollectivesAsyncEvents().end()) { - thunk->set_async_events(it->second); - } else { - collectives_async_events.emplace(async_events_key, - thunk->async_events()); - } + + // Wire up async events. + const HloInstruction* canonical_recv_instr = + FindCanonicalSendRecvStartOp(instr); + if (collectives_async_events.contains(canonical_recv_instr)) { + thunk->set_async_events(collectives_async_events[canonical_recv_instr]); } else { collectives_async_events.try_emplace(instr, thunk->async_events()); } @@ -2595,8 +2693,10 @@ absl::Status IrEmitterUnnested::EmitRecvThunk(const HloRecvInstruction* instr) { absl::Status IrEmitterUnnested::EmitRecvDoneThunk( const HloRecvDoneInstruction* instr) { - if (!instr->channel_id().has_value()) + // TODO(b/372306903): Do not require channel id for send-done. + if (!instr->channel_id().has_value()) { return absl::InternalError("Unknown recv done instruction channel id"); + } if (!instr->is_host_transfer()) { return EmitNcclAsyncDone(Thunk::kNcclRecvDone, instr); @@ -2772,9 +2872,10 @@ absl::Status IrEmitterUnnested::EmitHloInstruction( } return EmitCustomCallThunk(custom_call); } - case HloOpcode::kFusion: { + case HloOpcode::kFusion: return EmitFusion(Cast(instr)); - } + case HloOpcode::kCopy: + return EmitCopy(instr); case HloOpcode::kInfeed: return EmitInfeed(Cast(instr)); case HloOpcode::kOutfeed: diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h index d19dd5d9c4172c..5b5f4b11a19260 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h @@ -156,6 +156,7 @@ class IrEmitterUnnested : public IrEmitter { absl::Status EmitCustomCallThunk(const HloCustomCallInstruction* instr); absl::Status EmitFftThunk(const HloFftInstruction* instr); absl::Status EmitFusion(const HloFusionInstruction* instr); + absl::Status EmitCopy(const HloInstruction* instr); absl::Status EmitAsyncCustomCallStart(const HloInstruction* instr); absl::Status EmitSelectAndScatter( const HloSelectAndScatterInstruction* instr); diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index db3a909357eda6..2e237acd8f6671 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -62,8 +62,8 @@ cc_library( hdrs = ["custom_kernel.h"], visibility = [":friends"], deps = [ - "//xla/stream_executor", "//xla/stream_executor:kernel_spec", + "//xla/stream_executor:launch_dim", "@com_google_absl//absl/strings:str_format", ], ) @@ -82,8 +82,8 @@ cc_library( srcs = ["cutlass_gemm_fusion.cc"], hdrs = ["cutlass_gemm_fusion.h"], tags = [ + "cuda-only", "gpu", - "no_rocm", ], deps = [ ":custom_kernel", @@ -115,7 +115,7 @@ xla_test( backends = ["gpu"], # TODO(b/332820384): Enable when it passes on H100. disabled_backends = DEFAULT_DISABLED_BACKENDS + ["gpu_h100"], - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":custom_kernel_fusion_pattern", ":cutlass_gemm_custom_kernel", @@ -149,13 +149,12 @@ cc_library( ":topk_kernel_gpu", "//xla:shape_util", "//xla:types", - "//xla:util", "//xla:xla_data_proto_cc", - "//xla/stream_executor", # build_cleaner: keep - "//xla/stream_executor:platform", + "//xla/stream_executor:kernel", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:typed_kernel_factory", - "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:gpu_types_header", "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -177,9 +176,12 @@ gpu_kernel_library( compatible_with = [], deps = [ "//xla:types", - "//xla/stream_executor/gpu:gpu_types_header", - "@local_tsl//tsl/lib/math:math_util", - ], + "//xla/tsl/lib/math:math_util", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_headers", + ]), ) xla_test( @@ -190,9 +192,10 @@ xla_test( ":topk_kernel", "//xla:types", "//xla:xla_data_proto_cc", - "//xla/stream_executor", # build_cleaner: keep "//xla/stream_executor:device_memory_handle", + "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream", "//xla/stream_executor/gpu:gpu_init", "//xla/stream_executor/gpu:gpu_types_header", "//xla/stream_executor/host:host_platform", @@ -219,8 +222,10 @@ cc_library( ":custom_kernel", "//xla:types", "//xla:xla_data_proto_cc", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:kernel", "//xla/stream_executor:kernel_spec", + "//xla/stream_executor:launch_dim", "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -240,10 +245,11 @@ xla_test( "//xla:types", "//xla:xla_data_proto_cc", "//xla/service:platform_util", - "//xla/stream_executor", + "//xla/stream_executor:kernel", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", - "//xla/stream_executor/cuda:cuda_platform", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/random", "@com_google_absl//absl/strings", @@ -264,16 +270,18 @@ cc_library( srcs = ["cutlass_gemm_custom_kernel.cc"], hdrs = ["cutlass_gemm_custom_kernel.h"], tags = [ + "cuda-only", "gpu", - "no_rocm", ], deps = [ ":custom_kernel", ":cutlass_gemm", ":cutlass_gemm_kernels", # build_cleaner: keep "//xla:xla_data_proto_cc", - "//xla/stream_executor", + "//xla/stream_executor:device_description", + "//xla/stream_executor:kernel", "//xla/stream_executor:kernel_spec", + "//xla/stream_executor:launch_dim", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -288,13 +296,16 @@ xla_test( srcs = ["cutlass_gemm_custom_kernel_test.cc"], backends = ["gpu"], data = [":cutlass_gemm_kernel_f32xf32_to_f32.so"], - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":cutlass_gemm_custom_kernel", "//xla:xla_data_proto_cc", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:kernel", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/cuda:cuda_platform", "//xla/tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:path", @@ -309,16 +320,20 @@ xla_cc_binary( testonly = 1, srcs = ["cutlass_gemm_custom_kernel_benchmarks.cc"], tags = [ + "cuda-only", "gpu", - "no_rocm", ], deps = [ ":cutlass_gemm_custom_kernel", "//xla:xla_data_proto_cc", "//xla/service:gpu_plugin", - "//xla/stream_executor", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:kernel", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/cuda:cuda_platform", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", @@ -346,7 +361,7 @@ cuda_library( [], ["-Wno-unknown-attributes"], ), # __grid_constant__ is not supported by clang - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":cutlass_gemm", "@cutlass_archive//:cutlass", @@ -355,7 +370,7 @@ cuda_library( cuda_library( name = "cutlass_gemm_epilogue", - tags = ["no_rocm"], + tags = ["cuda-only"], # TODO(ezhulenev): Update to regular hdrs after fixing CUTLASS headers. textual_hdrs = ["cutlass_gemm_epilogue.cu.h"], deps = ["@cutlass_archive//:cutlass"], @@ -371,8 +386,8 @@ cuda_library( cc_library( name = "cutlass_gemm_kernels", tags = [ + "cuda-only", "gpu", - "no_rocm", ], deps = [ ":cutlass_gemm_kernel_bf16xbf16_to_bf16", @@ -399,7 +414,7 @@ cuda_library( [], ["-Wno-unknown-attributes"], ), - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":cutlass_gemm_adaptor", "@cutlass_archive//:cutlass", @@ -414,7 +429,7 @@ cuda_library( [], ["-Wno-unknown-attributes"], ), - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":cutlass_gemm_adaptor", "@cutlass_archive//:cutlass", @@ -432,7 +447,7 @@ cuda_library( [], ["-Wno-unknown-attributes"], ), - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":cutlass_gemm_adaptor", "@cutlass_archive//:cutlass", @@ -450,7 +465,7 @@ cuda_library( [], ["-Wno-unknown-attributes"], ), - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":cutlass_gemm_adaptor", "@cutlass_archive//:cutlass", @@ -463,8 +478,8 @@ cuda_library( srcs = ["cutlass_gemm_kernel_bf16xs8_to_f32.cu.cc"], copts = ["-Wno-unknown-attributes -mllvm -unroll-threshold=100000"], tags = [ + "cuda-only", "gpu", - "no_rocm", ], deps = [ ":cutlass_gemm_adaptor", @@ -483,8 +498,51 @@ cc_binary( linkshared = True, linkstatic = False, tags = [ + "cuda-only", "gpu", - "no_rocm", ], deps = [":cutlass_gemm"], ) + +#===--------------------------------------------------------------------------------------------===# +# PTX Custom Kernels +#===--------------------------------------------------------------------------------------------===# + +cc_library( + name = "ptx_custom_kernel", + srcs = ["ptx_custom_kernel.cc"], + hdrs = ["ptx_custom_kernel.h"], + tags = [ + "cuda-only", + "gpu", + ], + deps = [ + ":custom_kernel", + "//xla/stream_executor:kernel", + "//xla/stream_executor:kernel_spec", + "//xla/stream_executor:launch_dim", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +xla_test( + name = "ptx_custom_kernel_test", + srcs = ["ptx_custom_kernel_test.cc"], + backends = ["gpu"], + tags = ["cuda-only"], + deps = [ + ":custom_kernel", + ":ptx_custom_kernel", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:kernel", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor/cuda:cuda_platform", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) diff --git a/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel.cc b/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel.cc new file mode 100644 index 00000000000000..228804d0d83b0f --- /dev/null +++ b/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel.cc @@ -0,0 +1,56 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/kernels/ptx_custom_kernel.h" + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" + +namespace xla::gpu::kernel { + +namespace se = ::stream_executor; + +absl::StatusOr> +KernelArgsPacking(const se::Kernel &kernel, const se::KernelArgs &args) { + auto *mem_args = se::Cast(&args); + + return se::PackKernelArgs(mem_args->device_memory_args(), + mem_args->number_of_shared_bytes()); +} + +// Note: Make sure that the kernel_name matches the kernel name in the ptx, +// otherwise you will get a "CUDA_ERROR_NOT_FOUND: named symbol not found.". +// E.g. `.visible .entry AddI32(...)` would have a kernel name of "AddI32". +absl::StatusOr GetPtxCustomKernel(std::string kernel_name, + std::string_view ptx, + int num_args, + se::BlockDim block_dim, + se::ThreadDim thread_dim, + size_t shared_memory_bytes) { + se::MultiKernelLoaderSpec kernel_spec(/*arity=*/num_args, KernelArgsPacking); + kernel_spec.AddCudaPtxInMemory(ptx, kernel_name); + return CustomKernel(kernel_name, kernel_spec, block_dim, thread_dim, + /*shared_memory_bytes=*/shared_memory_bytes); +}; + +} // namespace xla::gpu::kernel diff --git a/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel.h b/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel.h new file mode 100644 index 00000000000000..7ebe304df9c466 --- /dev/null +++ b/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel.h @@ -0,0 +1,37 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_KERNELS_PTX_CUSTOM_KERNEL_H_ +#define XLA_SERVICE_GPU_KERNELS_PTX_CUSTOM_KERNEL_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/stream_executor/launch_dim.h" + +namespace xla::gpu::kernel { + +absl::StatusOr GetPtxCustomKernel(std::string kernel_name, + std::string_view ptx, + int num_args, + se::BlockDim block_dim, + se::ThreadDim thread_dim, + size_t shared_memory_bytes = 0); +} + +#endif // XLA_SERVICE_GPU_KERNELS_PTX_CUSTOM_KERNEL_H_ diff --git a/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel_test.cc b/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel_test.cc new file mode 100644 index 00000000000000..bf6f650876a6ea --- /dev/null +++ b/third_party/xla/xla/service/gpu/kernels/ptx_custom_kernel_test.cc @@ -0,0 +1,116 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/kernels/ptx_custom_kernel.h" + +#include +#include +#include +#include + +#include "xla/service/gpu/kernels/custom_kernel.h" +#include "xla/stream_executor/cuda/cuda_platform.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla::gpu::kernel { + +namespace se = ::stream_executor; + +constexpr std::string_view kAddI32KernelPtx = R"( +.version 4.0 +.target sm_50 +.address_size 64 + +.visible .entry AddI32( + .param .u64 AddI32_param_0, + .param .u64 AddI32_param_1, + .param .u64 AddI32_param_2 +) +{ + .reg .b32 %r<8>; + .reg .b64 %rd<11>; + .loc 1 1 0 + + ld.param.u64 %rd1, [AddI32_param_0]; + ld.param.u64 %rd2, [AddI32_param_1]; + ld.param.u64 %rd3, [AddI32_param_2]; + .loc 1 3 3 + cvta.to.global.u64 %rd4, %rd3; + cvta.to.global.u64 %rd5, %rd2; + cvta.to.global.u64 %rd6, %rd1; + mov.u32 %r1, %tid.x; + mov.u32 %r2, %ctaid.x; + mov.u32 %r3, %ntid.x; + mad.lo.s32 %r4, %r2, %r3, %r1; + .loc 1 4 3 + mul.wide.s32 %rd7, %r4, 4; + add.s64 %rd8, %rd6, %rd7; + ld.global.u32 %r5, [%rd8]; + add.s64 %rd9, %rd5, %rd7; + ld.global.u32 %r6, [%rd9]; + add.s32 %r7, %r6, %r5; + add.s64 %rd10, %rd4, %rd7; + st.global.u32 [%rd10], %r7; + .loc 1 5 1 + ret; + +})"; + +TEST(PtxCustomKernelTest, GetPtxCustomKernel) { + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + se::gpu::CudaPlatform platform; + TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, + platform.ExecutorForDevice(0)); + TF_ASSERT_OK_AND_ASSIGN( + CustomKernel custom_kernel, + GetPtxCustomKernel("AddI32", kAddI32KernelPtx, 3, se::BlockDim(4), + se::ThreadDim(1), byte_length)); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr kernel, + executor->LoadKernel(custom_kernel.kernel_spec())); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + executor->CreateStream()); + se::DeviceMemory a = executor->AllocateArray(length, 0); + se::DeviceMemory b = executor->AllocateArray(length, 0); + se::DeviceMemory c = executor->AllocateArray(length, 0); + TF_CHECK_OK(stream->Memset32(&a, 1, byte_length)); + TF_CHECK_OK(stream->Memset32(&b, 2, byte_length)); + TF_CHECK_OK(stream->MemZero(&c, byte_length)); + + se::KernelArgsDeviceMemoryArray args( + std::vector({a, b, c}), + custom_kernel.shared_memory_bytes()); + TF_CHECK_OK(stream->Launch(custom_kernel.thread_dims(), + custom_kernel.block_dims(), *kernel, args)); + + TF_CHECK_OK(stream->BlockHostUntilDone()); + + std::vector dst(4, 42); + TF_CHECK_OK(stream->Memcpy(dst.data(), c, byte_length)); + + std::vector expected = {3, 3, 3, 3}; + ASSERT_EQ(dst, expected); +} + +} // namespace xla::gpu::kernel diff --git a/third_party/xla/xla/service/gpu/kernels/topk_kernel.cu.h b/third_party/xla/xla/service/gpu/kernels/topk_kernel.cu.h index ee3d71f4a36423..5a68b56efd7ef7 100644 --- a/third_party/xla/xla/service/gpu/kernels/topk_kernel.cu.h +++ b/third_party/xla/xla/service/gpu/kernels/topk_kernel.cu.h @@ -25,8 +25,7 @@ limitations under the License. #include #include "xla/service/gpu/kernels/topk_kernel_common.h" -#include "xla/stream_executor/gpu/gpu_types.h" -#include "tsl/lib/math/math_util.h" +#include "xla/tsl/lib/math/math_util.h" #if GOOGLE_CUDA diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD index b1f957b0a4a812..613ec0f42fa23f 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD @@ -11,7 +11,11 @@ load( "if_cuda_is_configured", ) load("//xla:xla.bzl", "xla_cc_test") -load("//xla/tsl:tsl.bzl", "internal_visibility") +load( + "//xla/tsl:tsl.bzl", + "if_google", + "internal_visibility", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -30,16 +34,17 @@ cc_library( name = "llvm_gpu_backend", srcs = [ "gpu_backend_lib.cc", - "utils.cc", ], hdrs = [ "gpu_backend_lib.h", - "utils.h", ], local_defines = if_cuda_is_configured([ "GOOGLE_CUDA=1", ]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), deps = [ + ":load_ir_module", + ":nvptx_libdevice_path", + ":utils", "//xla:status_macros", "//xla:types", "//xla:util", @@ -76,7 +81,7 @@ cc_library( "@llvm-project//llvm:Target", "@llvm-project//mlir:NVVMDialect", "@local_config_cuda//cuda:cuda_headers", - "@local_tsl//tsl/platform:cuda_libdevice_path", + "@local_tsl//tsl/platform:cuda_root_path", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", @@ -98,32 +103,98 @@ cc_library( ]), ) +cc_library( + name = "load_ir_module", + hdrs = ["load_ir_module.h"], + deps = [ + "@com_google_absl//absl/strings:string_view", + ] + if_google( + ["//xla/service/gpu/llvm_gpu_backend/google:load_ir_module"], + ["//xla/service/gpu/llvm_gpu_backend/default:load_ir_module"], + ), +) + +cc_library( + name = "nvptx_libdevice_path", + hdrs = ["nvptx_libdevice_path.h"], + deps = [ + "@com_google_absl//absl/strings:string_view", + ] + if_google( + ["//xla/service/gpu/llvm_gpu_backend/google:nvptx_libdevice_path"], + ["//xla/service/gpu/llvm_gpu_backend/default:nvptx_libdevice_path"], + ), +) + +cc_library( + name = "nvptx_utils", + srcs = ["nvptx_utils.cc"], + hdrs = ["nvptx_utils.h"], + deps = [ + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:cuda_root_path", + ], +) + +cc_library( + name = "utils", + srcs = ["utils.cc"], + hdrs = ["utils.h"], + deps = [ + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + ], +) + xla_cc_test( - name = "utils_test", + name = "gpu_backend_lib_test", size = "small", - srcs = ["utils_test.cc"], + srcs = ["gpu_backend_lib_test.cc"], + deps = [ + ":llvm_gpu_backend", + "//xla/stream_executor:device_description", + "//xla/stream_executor:semantic_version", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:test", + ], +) + +xla_cc_test( + name = "load_ir_module_test", + size = "small", + srcs = ["load_ir_module_test.cc"], data = [ "tests_data/saxpy.ll", ], deps = [ - ":llvm_gpu_backend", + ":load_ir_module", "//xla/tests:xla_internal_test_main", - "@llvm-project//llvm:Core", + "@llvm-project//llvm:ir_headers", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:test", ], ) xla_cc_test( - name = "gpu_backend_lib_test", + name = "nvptx_utils_test", + srcs = ["nvptx_utils_test.cc"], + deps = [ + ":nvptx_utils", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:test", + ], +) + +xla_cc_test( + name = "utils_test", size = "small", - srcs = ["gpu_backend_lib_test.cc"], + srcs = ["utils_test.cc"], deps = [ - ":llvm_gpu_backend", - "//xla/stream_executor:device_description", - "//xla/stream_executor:semantic_version", + ":utils", "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/default/BUILD b/third_party/xla/xla/service/gpu/llvm_gpu_backend/default/BUILD new file mode 100644 index 00000000000000..95f2fbcc68b72e --- /dev/null +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/default/BUILD @@ -0,0 +1,37 @@ +# Description: +# Default implementations for llvm_gpu_backend + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//xla/service/gpu/llvm_gpu_backend:__pkg__"], + licenses = ["notice"], +) + +cc_library( + name = "nvptx_libdevice_path", + srcs = ["nvptx_libdevice_path.cc"], + deps = [ + "//xla/service/gpu/llvm_gpu_backend:nvptx_utils", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:cuda_root_path", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:path", + ], + alwayslink = True, +) + +cc_library( + name = "load_ir_module", + srcs = ["load_ir_module.cc"], + deps = [ + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@local_tsl//tsl/platform:logging", + ], + alwayslink = True, +) diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/default/load_ir_module.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/default/load_ir_module.cc new file mode 100644 index 00000000000000..ea64e722cd927b --- /dev/null +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/default/load_ir_module.cc @@ -0,0 +1,51 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Support/SourceMgr.h" +#include "tsl/platform/logging.h" + +namespace { + +static void DieWithSMDiagnosticError(llvm::SMDiagnostic* diagnostic) { + LOG(FATAL) << diagnostic->getFilename().str() << ":" + << diagnostic->getLineNo() << ":" << diagnostic->getColumnNo() + << ": " << diagnostic->getMessage().str(); +} + +} // namespace + +namespace xla::gpu { + +std::unique_ptr LoadIRModule(const std::string& filename, + llvm::LLVMContext* llvm_context) { + llvm::SMDiagnostic diagnostic_err; + std::unique_ptr module = + llvm::getLazyIRFileModule(filename, diagnostic_err, *llvm_context, + /*ShouldLazyLoadMetadata=*/true); + + if (module == nullptr) { + DieWithSMDiagnosticError(&diagnostic_err); + } + + return module; +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/default/nvptx_libdevice_path.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/default/nvptx_libdevice_path.cc new file mode 100644 index 00000000000000..b74a172bdb9058 --- /dev/null +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/default/nvptx_libdevice_path.cc @@ -0,0 +1,71 @@ + +/* Copyright 2024 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "absl/base/const_init.h" +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "xla/service/gpu/llvm_gpu_backend/nvptx_utils.h" +#include "tsl/platform/cuda_root_path.h" +#include "tsl/platform/env.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/path.h" + +namespace xla::gpu::nvptx { +namespace { + +std::string GetLibdeviceDir(absl::string_view xla_gpu_cuda_data_dir) { + for (const std::string& cuda_root : + tsl::CandidateCudaRoots(std::string{xla_gpu_cuda_data_dir})) { + std::string libdevice_dir = + tsl::io::JoinPath(cuda_root, "nvvm", "libdevice"); + VLOG(2) << "Looking for libdevice at " << libdevice_dir; + if (tsl::Env::Default()->IsDirectory(libdevice_dir).ok()) { + VLOG(2) << "Found libdevice dir " << libdevice_dir; + return libdevice_dir; + } + } + LOG(WARNING) << CantFindCudaMessage( + "Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice. This may " + "result in compilation or runtime failures, if the program we try to run " + "uses routines from libdevice.", + xla_gpu_cuda_data_dir); + // GetCudaRootCandidates always includes ".", but if everything fails, we + // return it anyway. Better than returning the empty string. + return "."; +} + +} // namespace + +std::string LibDevicePath(absl::string_view xla_gpu_cuda_data_dir) { + static absl::Mutex libdevice_cache_mu(absl::kConstInit); + static auto& libdevice_dir_path_cache ABSL_GUARDED_BY(libdevice_cache_mu) = + *new absl::flat_hash_map(); + std::string libdevice_dir_path = [&] { + absl::MutexLock l(&libdevice_cache_mu); + auto it = libdevice_dir_path_cache.find(xla_gpu_cuda_data_dir); + if (it != libdevice_dir_path_cache.end()) { + return it->second; + } + auto [it2, inserted] = libdevice_dir_path_cache.emplace( + xla_gpu_cuda_data_dir, GetLibdeviceDir(xla_gpu_cuda_data_dir)); + return it2->second; + }(); + // CUDA 9+ uses a single libdevice file for all devices, and we don't support + // older CUDAs. + return tsl::io::JoinPath(libdevice_dir_path, "libdevice.10.bc"); +} + +} // namespace xla::gpu::nvptx diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index f58accf3e2583b..7a113fe4e8ea0b 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -32,9 +32,6 @@ limitations under the License. #include #include "absl/base/call_once.h" -#include "absl/base/const_init.h" -#include "absl/base/thread_annotations.h" -#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" @@ -42,7 +39,6 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" #include "llvm/ADT/Any.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSet.h" @@ -77,6 +73,8 @@ limitations under the License. #include "llvm/Transforms/IPO/AlwaysInliner.h" #include "llvm/Transforms/IPO/Internalize.h" #include "llvm/Transforms/Scalar.h" +#include "xla/service/gpu/llvm_gpu_backend/load_ir_module.h" +#include "xla/service/gpu/llvm_gpu_backend/nvptx_libdevice_path.h" #include "xla/service/gpu/llvm_gpu_backend/utils.h" #include "xla/service/gpu/metrics.h" #include "xla/service/llvm_ir/llvm_command_line_options.h" @@ -86,7 +84,6 @@ limitations under the License. #include "xla/tsl/util/env_var.h" #include "xla/util.h" #include "xla/xla.pb.h" -#include "tsl/platform/cuda_libdevice_path.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" @@ -272,14 +269,31 @@ absl::Status LinkWithBitcodeVector( return absl::OkStatus(); } +// Links libdevice into the given module if the module needs libdevice. +absl::Status LinkLibdeviceIfNecessary(llvm::Module* module, + const std::string& libdevice_path) { + if (!CouldNeedDeviceBitcode(*module)) { + return absl::OkStatus(); + } + + if (!tsl::Env::Default()->FileExists(libdevice_path).ok()) { + LOG(WARNING) + << "libdevice is required by this HLO module but was not found at " + << libdevice_path; + return xla::Internal("libdevice not found at %s", libdevice_path); + } + + VLOG(1) << "Linking with libdevice from: " << libdevice_path; + return LinkWithBitcodeVector(module, {libdevice_path}); +} + absl::Status NVPTXTargetModuleLinker(llvm::Module* module, se::GpuComputeCapability gpu_version, const DebugOptions& debug_options, const std::string& device_bitcode_path) { // Link the input module with libdevice, to pull in implementations of some // builtins. - TF_RETURN_IF_ERROR( - nvptx::LinkLibdeviceIfNecessary(module, device_bitcode_path)); + TF_RETURN_IF_ERROR(LinkLibdeviceIfNecessary(module, device_bitcode_path)); // Set the flush-denormals-to-zero flag on the module so the NVVM reflect pass // can access it. @@ -554,76 +568,6 @@ std::string GetSmName(se::CudaComputeCapability compute_capability) { return absl::StrCat("sm_", sm_version, extension); } -std::string CantFindCudaMessage(absl::string_view msg, - absl::string_view xla_gpu_cuda_data_dir) { - return absl::StrCat( - msg, "\nSearched for CUDA in the following directories:\n ", - absl::StrJoin(tsl::CandidateCudaRoots(std::string{xla_gpu_cuda_data_dir}), - "\n "), - "\nYou can choose the search directory by setting xla_gpu_cuda_data_dir " - "in HloModule's DebugOptions. For most apps, setting the environment " - "variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work."); -} - -static std::string GetLibdeviceDir(absl::string_view xla_gpu_cuda_data_dir) { - for (const std::string& cuda_root : - tsl::CandidateCudaRoots(std::string{xla_gpu_cuda_data_dir})) { - std::string libdevice_dir = - tsl::io::JoinPath(cuda_root, "nvvm", "libdevice"); - VLOG(2) << "Looking for libdevice at " << libdevice_dir; - if (tsl::Env::Default()->IsDirectory(libdevice_dir).ok()) { - VLOG(2) << "Found libdevice dir " << libdevice_dir; - return libdevice_dir; - } - } - LOG(WARNING) << CantFindCudaMessage( - "Can't find libdevice directory ${CUDA_DIR}/nvvm/libdevice. This may " - "result in compilation or runtime failures, if the program we try to run " - "uses routines from libdevice.", - xla_gpu_cuda_data_dir); - - // GetCudaRootCandidates always includes ".", but if everything fails, we - // return it anyway. Better than returning the empty string. - return "."; -} - -std::string LibDevicePath(absl::string_view xla_gpu_cuda_data_dir) { - static absl::Mutex libdevice_cache_mu(absl::kConstInit); - static auto& libdevice_dir_path_cache ABSL_GUARDED_BY(libdevice_cache_mu) = - *new absl::flat_hash_map(); - std::string libdevice_dir_path = [&] { - absl::MutexLock l(&libdevice_cache_mu); - auto it = libdevice_dir_path_cache.find(xla_gpu_cuda_data_dir); - if (it != libdevice_dir_path_cache.end()) { - return it->second; - } - auto [it2, inserted] = libdevice_dir_path_cache.emplace( - xla_gpu_cuda_data_dir, GetLibdeviceDir(xla_gpu_cuda_data_dir)); - return it2->second; - }(); - // CUDA 9+ uses a single libdevice file for all devices, and we don't support - // older CUDAs. - return tsl::io::JoinPath(libdevice_dir_path, "libdevice.10.bc"); -} - -// Links libdevice into the given module if the module needs libdevice. -absl::Status LinkLibdeviceIfNecessary(llvm::Module* module, - const std::string& libdevice_path) { - if (!CouldNeedDeviceBitcode(*module)) { - return absl::OkStatus(); - } - - if (!tsl::Env::Default()->FileExists(libdevice_path).ok()) { - LOG(WARNING) - << "libdevice is required by this HLO module but was not found at " - << libdevice_path; - return xla::Internal("libdevice not found at %s", libdevice_path); - } - - VLOG(1) << "Linking with libdevice from: " << libdevice_path; - return LinkWithBitcodeVector(module, {libdevice_path}); -} - absl::StatusOr CompileToPtx( llvm::Module* module, se::GpuComputeCapability gpu_version, const DebugOptions& debug_options, diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h index 42f84b503e5c84..a93a1d3e1590de 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h @@ -43,16 +43,6 @@ namespace nvptx { std::string GetSmName( stream_executor::CudaComputeCapability compute_capability); -std::string CantFindCudaMessage(absl::string_view msg, - absl::string_view xla_gpu_cuda_data_dir); - -// Get path to NVVM libdevice file. -std::string LibDevicePath(absl::string_view xla_gpu_cuda_data_dir); - -// Link libdevice if functions using it are detected in the module. -absl::Status LinkLibdeviceIfNecessary(llvm::Module* module, - const std::string& libdevice_path); - // Compiles the argument module and returns it. libdevice_dir_path is the parent // directory of the libdevice bitcode libraries. The contents of the module may // be changed. diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/load_ir_module.h b/third_party/xla/xla/service/gpu/llvm_gpu_backend/load_ir_module.h new file mode 100644 index 00000000000000..c5c610c7cb74a5 --- /dev/null +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/load_ir_module.h @@ -0,0 +1,48 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_LLVM_GPU_BACKEND_LOAD_IR_MODULE_H_ +#define XLA_SERVICE_GPU_LLVM_GPU_BACKEND_LOAD_IR_MODULE_H_ + +#include +#include + +#include "absl/strings/string_view.h" + +namespace llvm { +class LLVMContext; +class Module; +} // namespace llvm + +namespace xla::gpu { + +// Convenience function for loading a LLVM module from an IR file. The module +// is created in the given LLVM context. +// +// If loading fails for some reason, dies printing a diagnostic error. +std::unique_ptr LoadIRModule(const std::string& filename, + llvm::LLVMContext* llvm_context); + +// Convenience function for replacing the extension of the given filename. +// If the filename has no extension, the new extension is appended to its name. +// +// For example: +// ReplaceFilenameExtension("/foo/baz.txt", "cc") --> "/foo/baz.cc" +std::string ReplaceFilenameExtension(absl::string_view filename, + absl::string_view new_extension); + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_LLVM_GPU_BACKEND_LOAD_IR_MODULE_H_ diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/load_ir_module_test.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/load_ir_module_test.cc new file mode 100644 index 00000000000000..93f306de781d64 --- /dev/null +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/load_ir_module_test.cc @@ -0,0 +1,46 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/llvm_gpu_backend/load_ir_module.h" + +#include +#include + +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "tsl/platform/path.h" +#include "tsl/platform/test.h" + +namespace xla::gpu { +namespace { + +std::string SaxpyIRFile() { + return tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu", + "llvm_gpu_backend", "tests_data", "saxpy.ll"); +} + +TEST(LoadIrModuleTest, TestLoadIRModule) { + llvm::LLVMContext llvm_context; + std::string test_srcdir = tsl::testing::TensorFlowSrcRoot(); + std::unique_ptr module = + LoadIRModule(SaxpyIRFile(), &llvm_context); + // Sanity check that the module was loaded properly. + ASSERT_NE(nullptr, module); + ASSERT_NE(std::string::npos, module->getModuleIdentifier().find("saxpy.ll")); + ASSERT_NE(nullptr, module->getFunction("cuda_saxpy")); +} + +} // namespace +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_libdevice_path.h b/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_libdevice_path.h new file mode 100644 index 00000000000000..b9be9cf61e9280 --- /dev/null +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_libdevice_path.h @@ -0,0 +1,25 @@ +/* Copyright 2024 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_LLVM_GPU_BACKEND_NVPTX_LIBDEVICE_PATH_H_ +#define XLA_SERVICE_GPU_LLVM_GPU_BACKEND_NVPTX_LIBDEVICE_PATH_H_ +#include + +#include "absl/strings/string_view.h" + +namespace xla::gpu::nvptx { + +// Returns path to libdevice file. +std::string LibDevicePath(absl::string_view xla_gpu_data_dir); + +} // namespace xla::gpu::nvptx + +#endif // XLA_SERVICE_GPU_LLVM_GPU_BACKEND_NVPTX_LIBDEVICE_PATH_H_ diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_utils.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_utils.cc new file mode 100644 index 00000000000000..f37268ac02d670 --- /dev/null +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_utils.cc @@ -0,0 +1,34 @@ +/* Copyright 2024 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/llvm_gpu_backend/nvptx_utils.h" + +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "tsl/platform/cuda_root_path.h" + +namespace xla::gpu::nvptx { + +std::string CantFindCudaMessage(absl::string_view msg, + absl::string_view xla_gpu_cuda_data_dir) { + return absl::StrCat( + msg, "\nSearched for CUDA in the following directories:\n ", + absl::StrJoin(tsl::CandidateCudaRoots(std::string{xla_gpu_cuda_data_dir}), + "\n "), + "\nYou can choose the search directory by setting xla_gpu_cuda_data_dir " + "in HloModule's DebugOptions. For most apps, setting the environment " + "variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work."); +} + +} // namespace xla::gpu::nvptx diff --git a/tensorflow/cc/experimental/libtf/impl/string_test.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_utils.h similarity index 59% rename from tensorflow/cc/experimental/libtf/impl/string_test.cc rename to third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_utils.h index 4cc07d07dfa095..b4545f96855220 100644 --- a/tensorflow/cc/experimental/libtf/impl/string_test.cc +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_utils.h @@ -1,31 +1,26 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/cc/experimental/libtf/impl/string.h" +#include + +#include "absl/strings/string_view.h" + +#ifndef XLA_SERVICE_GPU_LLVM_GPU_BACKEND_NVPTX_UTILS_H_ +#define XLA_SERVICE_GPU_LLVM_GPU_BACKEND_NVPTX_UTILS_H_ -#include "tensorflow/core/platform/test.h" +namespace xla::gpu::nvptx { -namespace tf { -namespace libtf { -namespace impl { +std::string CantFindCudaMessage(absl::string_view msg, + absl::string_view xla_gpu_cuda_data_dir); -TEST(StringTest, TestBasicInterning) { - String s1("foo"); - String s2("foo"); - EXPECT_EQ(&s1.str(), &s2.str()); } -} // namespace impl -} // namespace libtf -} // namespace tf +#endif // XLA_SERVICE_GPU_LLVM_GPU_BACKEND_NVPTX_UTILS_H_ diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_utils_test.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_utils_test.cc new file mode 100644 index 00000000000000..9b587473ec275e --- /dev/null +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/nvptx_utils_test.cc @@ -0,0 +1,26 @@ +/* Copyright 2024 The OpenXLA Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/llvm_gpu_backend/nvptx_utils.h" + +#include +#include "tsl/platform/test.h" + +namespace xla::gpu::nvptx { +namespace { + +TEST(NvptxUtilsTest, CantFindCudaMessageTest) { + auto msg = CantFindCudaMessage("foo", "/bar"); + EXPECT_NE(msg.length(), 0); +} + +} // namespace +} // namespace xla::gpu::nvptx diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/utils.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/utils.cc index 84254bff68873e..e9922c272ae991 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/utils.cc +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/utils.cc @@ -15,44 +15,14 @@ limitations under the License. #include "xla/service/gpu/llvm_gpu_backend/utils.h" -#include #include #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Module.h" -#include "llvm/IRReader/IRReader.h" -#include "llvm/Support/SourceMgr.h" -#include "tsl/platform/logging.h" - -namespace { - -static void DieWithSMDiagnosticError(llvm::SMDiagnostic* diagnostic) { - LOG(FATAL) << diagnostic->getFilename().str() << ":" - << diagnostic->getLineNo() << ":" << diagnostic->getColumnNo() - << ": " << diagnostic->getMessage().str(); -} - -} // namespace namespace xla { namespace gpu { -std::unique_ptr LoadIRModule(const std::string& filename, - llvm::LLVMContext* llvm_context) { - llvm::SMDiagnostic diagnostic_err; - std::unique_ptr module = - llvm::getLazyIRFileModule(filename, diagnostic_err, *llvm_context, - /*ShouldLazyLoadMetadata=*/true); - - if (module == nullptr) { - DieWithSMDiagnosticError(&diagnostic_err); - } - - return module; -} - std::string ReplaceFilenameExtension(absl::string_view filename, absl::string_view new_extension) { auto pos = filename.rfind('.'); diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/utils.h b/third_party/xla/xla/service/gpu/llvm_gpu_backend/utils.h index b355852ea4fe08..e99e03325d09dc 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/utils.h +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/utils.h @@ -16,26 +16,13 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_LLVM_GPU_BACKEND_UTILS_H_ #define XLA_SERVICE_GPU_LLVM_GPU_BACKEND_UTILS_H_ -#include #include #include "absl/strings/string_view.h" -namespace llvm { -class LLVMContext; -class Module; -} // namespace llvm - namespace xla { namespace gpu { -// Convenience function for loading a LLVM module from an IR file. The module -// is created in the given LLVM context. -// -// If loading fails for some reason, dies printing a diagnostic error. -std::unique_ptr LoadIRModule(const std::string& filename, - llvm::LLVMContext* llvm_context); - // Convenience function for replacing the extension of the given filename. // If the filename has no extension, the new extension is appended to its name. // diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/utils_test.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/utils_test.cc index c7b707c839ce20..413aab81713d19 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/utils_test.cc +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/utils_test.cc @@ -15,34 +15,12 @@ limitations under the License. #include "xla/service/gpu/llvm_gpu_backend/utils.h" -#include -#include - -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Module.h" -#include "tsl/platform/path.h" #include "tsl/platform/test.h" namespace xla { namespace gpu { namespace { -std::string SaxpyIRFile() { - return tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu", - "llvm_gpu_backend", "tests_data", "saxpy.ll"); -} - -TEST(UtilsTest, TestLoadIRModule) { - llvm::LLVMContext llvm_context; - std::string test_srcdir = tsl::testing::TensorFlowSrcRoot(); - std::unique_ptr module = - LoadIRModule(SaxpyIRFile(), &llvm_context); - // Sanity check that the module was loaded properly. - ASSERT_NE(nullptr, module); - ASSERT_NE(std::string::npos, module->getModuleIdentifier().find("saxpy.ll")); - ASSERT_NE(nullptr, module->getFunction("cuda_saxpy")); -} - TEST(UtilsTest, TestReplaceFilenameExtension) { ASSERT_EQ(ReplaceFilenameExtension("baz.tx", "cc"), "baz.cc"); ASSERT_EQ(ReplaceFilenameExtension("/foo/baz.txt", "cc"), "/foo/baz.cc"); diff --git a/third_party/xla/xla/service/gpu/matmul_indexing_utils.cc b/third_party/xla/xla/service/gpu/matmul_indexing_utils.cc new file mode 100644 index 00000000000000..7359fede2cc51d --- /dev/null +++ b/third_party/xla/xla/service/gpu/matmul_indexing_utils.cc @@ -0,0 +1,78 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/types/span.h" +#include "xla/autotuning.pb.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/shape.h" +#include "xla/status_macros.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +absl::StatusOr> GetNonContractingDims( + const Shape& shape, absl::Span batch_dims, + absl::Span contracting_dims) { + auto nc = + ::xla::GetNonContractingDims(shape.rank(), contracting_dims, batch_dims); + + TF_RET_CHECK(batch_dims.size() + contracting_dims.size() + nc.size() == + shape.rank()); + return std::vector(nc.begin(), nc.end()); +} + +const tsl::protobuf::RepeatedField& BatchDimensionsForOperand( + const HloInstruction& dot, const int operand_number) { + const DotDimensionNumbers& dimension_numbers = dot.dot_dimension_numbers(); + if (operand_number == 0) { + return dimension_numbers.lhs_batch_dimensions(); + } + return dimension_numbers.rhs_batch_dimensions(); +} + +absl::StatusOr ContractingDimensionIndex(const HloInstruction& dot, + const int operand_number) { + const DotDimensionNumbers& dimension_numbers = dot.dot_dimension_numbers(); + if (operand_number == 0) { + TF_RET_CHECK(dimension_numbers.lhs_contracting_dimensions().size() == 1); + return dimension_numbers.lhs_contracting_dimensions(0); + } + TF_RET_CHECK(dimension_numbers.rhs_contracting_dimensions().size() == 1); + return dimension_numbers.rhs_contracting_dimensions(0); +} + +absl::StatusOr NonContractingDimensionIndex(const HloInstruction& dot, + const int operand_number) { + TF_ASSIGN_OR_RETURN(int64_t contracting_dim, + ContractingDimensionIndex(dot, operand_number)); + TF_ASSIGN_OR_RETURN( + std::vector non_contracting_dims, + GetNonContractingDims(dot.operand(operand_number)->shape(), + BatchDimensionsForOperand(dot, operand_number), + {contracting_dim})); + TF_RET_CHECK(non_contracting_dims.size() == 1); + return non_contracting_dims.front(); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/matmul_indexing_utils.h b/third_party/xla/xla/service/gpu/matmul_indexing_utils.h new file mode 100644 index 00000000000000..9ceb4b78b01215 --- /dev/null +++ b/third_party/xla/xla/service/gpu/matmul_indexing_utils.h @@ -0,0 +1,52 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MATMUL_INDEXING_UTILS_H_ +#define XLA_SERVICE_GPU_MATMUL_INDEXING_UTILS_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/shape.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace gpu { + +// Ordered non-contracting dimensions for a dot instruction operand. +absl::StatusOr> GetNonContractingDims( + const Shape& shape, absl::Span batch_dims, + absl::Span contracting_dims); + +// Batch dimensions of an operand of a dot instruction. +// Just an unified accessor to lhs_batch_dimensions and rhs_batch_dimensions. +const tsl::protobuf::RepeatedField& BatchDimensionsForOperand( + const HloInstruction& dot, int operand_number); + +// Index of the only contracting dimension of dot instruction operand. +absl::StatusOr ContractingDimensionIndex(const HloInstruction& dot, + int operand_number); + +// Index of the only non-contracting dimension of dot instruction operand. +absl::StatusOr NonContractingDimensionIndex(const HloInstruction& dot, + int operand_number); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MATMUL_INDEXING_UTILS_H_ diff --git a/third_party/xla/xla/service/gpu/matmul_indexing_utils_test.cc b/third_party/xla/xla/service/gpu/matmul_indexing_utils_test.cc new file mode 100644 index 00000000000000..099b64c0471e16 --- /dev/null +++ b/third_party/xla/xla/service/gpu/matmul_indexing_utils_test.cc @@ -0,0 +1,40 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/matmul_indexing_utils.h" + +#include "absl/strings/string_view.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/shape.h" +#include "xla/test.h" +#include "tsl/platform/status_matchers.h" + +namespace xla { +namespace gpu { +namespace { + +using ::testing::ElementsAre; +using ::tsl::testing::IsOkAndHolds; + +TEST(GetNonContractingDimsTest, Valid) { + Shape shape = ParseShape("f32[1,2,3,4,5,6]").value(); + EXPECT_THAT(GetNonContractingDims(shape, /*batch_dims=*/{4}, + /*contracting_dims=*/{1, 5}), + IsOkAndHolds(ElementsAre(0, 2, 3))); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/matmul_utils.cc b/third_party/xla/xla/service/gpu/matmul_utils.cc index 49270de65ecd3f..333e7a1c2cd456 100644 --- a/third_party/xla/xla/service/gpu/matmul_utils.cc +++ b/third_party/xla/xla/service/gpu/matmul_utils.cc @@ -38,6 +38,7 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/service/algorithm_util.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/matmul_indexing_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" @@ -45,68 +46,17 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" #include "xla/stream_executor/numeric_options.h" +#include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" -#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" namespace xla { namespace gpu { -absl::StatusOr> GetNonContractingDims( - const Shape& shape, absl::Span batch_dims, - absl::Span contracting_dims) { - std::vector non_contracting_dims; - // This is O(rank**2), but we expect rank to be small. - for (int64_t dim = 0; dim < shape.rank(); ++dim) { - bool is_batch = absl::c_count(batch_dims, dim) != 0; - bool is_contracting = absl::c_count(contracting_dims, dim) != 0; - TF_RET_CHECK(!(is_batch && is_contracting)); - if (!(is_batch || is_contracting)) non_contracting_dims.push_back(dim); - } - - TF_RET_CHECK(batch_dims.size() + contracting_dims.size() + - non_contracting_dims.size() == - shape.rank()); - return non_contracting_dims; -} - -const tsl::protobuf::RepeatedField& BatchDimensionsForOperand( - const HloInstruction& dot, const int operand_number) { - const DotDimensionNumbers& dimension_numbers = dot.dot_dimension_numbers(); - if (operand_number == 0) { - return dimension_numbers.lhs_batch_dimensions(); - } - return dimension_numbers.rhs_batch_dimensions(); -} - -absl::StatusOr ContractingDimensionIndex(const HloInstruction& dot, - const int operand_number) { - const DotDimensionNumbers& dimension_numbers = dot.dot_dimension_numbers(); - if (operand_number == 0) { - TF_RET_CHECK(dimension_numbers.lhs_contracting_dimensions().size() == 1); - return dimension_numbers.lhs_contracting_dimensions(0); - } - TF_RET_CHECK(dimension_numbers.rhs_contracting_dimensions().size() == 1); - return dimension_numbers.rhs_contracting_dimensions(0); -} - -absl::StatusOr NonContractingDimensionIndex(const HloInstruction& dot, - const int operand_number) { - TF_ASSIGN_OR_RETURN(int64_t contracting_dim, - ContractingDimensionIndex(dot, operand_number)); - TF_ASSIGN_OR_RETURN( - std::vector non_contracting_dims, - GetNonContractingDims(dot.operand(operand_number)->shape(), - BatchDimensionsForOperand(dot, operand_number), - {contracting_dim})); - TF_RET_CHECK(non_contracting_dims.size() == 1); - return non_contracting_dims.front(); -} - absl::StatusOr GetBatchRowColumnShape( const Shape& shape, absl::Span batch_dims, absl::Span row_dims, absl::Span col_dims) { @@ -484,8 +434,8 @@ bool IsTf32Allowed(PrecisionConfig::Algorithm algorithm, if (has_vector_bias) { int vector_bias_index = has_matrix_bias ? 3 : 2; if (primitive_util::IsF8Type(lhs_shape.element_type())) { - // FP8 gemms have 4 scales as inputs which come before the vector bias. - vector_bias_index += 4; + // FP8 gemms have 2 scales as inputs which come before the vector bias. + vector_bias_index += 2; } vector_bias_shape = gemm->operand(vector_bias_index)->shape(); } diff --git a/third_party/xla/xla/service/gpu/matmul_utils.h b/third_party/xla/xla/service/gpu/matmul_utils.h index 5f128e418af58c..1f2692eea5bdb6 100644 --- a/third_party/xla/xla/service/gpu/matmul_utils.h +++ b/third_party/xla/xla/service/gpu/matmul_utils.h @@ -22,7 +22,6 @@ limitations under the License. #include #include #include -#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -43,24 +42,6 @@ limitations under the License. namespace xla { namespace gpu { -// Ordered non-contracting dimensions for a dot instruction operand. -absl::StatusOr> GetNonContractingDims( - const Shape& shape, absl::Span batch_dims, - absl::Span contracting_dims); - -// Batch dimensions of an operand of a dot instruction. -// Just an unified accessor to lhs_batch_dimensions and rhs_batch_dimensions. -const tsl::protobuf::RepeatedField& BatchDimensionsForOperand( - const HloInstruction& dot, int operand_number); - -// Index of the only contracting dimension of dot instruction operand. -absl::StatusOr ContractingDimensionIndex(const HloInstruction& dot, - int operand_number); - -// Index of the only non-contracting dimension of dot instruction operand. -absl::StatusOr NonContractingDimensionIndex(const HloInstruction& dot, - int operand_number); - // Normalize shape to (batch, rows, columns) logical dimensions. absl::StatusOr GetBatchRowColumnShape( const Shape& shape, absl::Span batch_dims, diff --git a/third_party/xla/xla/service/gpu/matmul_utils_test.cc b/third_party/xla/xla/service/gpu/matmul_utils_test.cc index c3ccdb517438b7..d758d04169b7e0 100644 --- a/third_party/xla/xla/service/gpu/matmul_utils_test.cc +++ b/third_party/xla/xla/service/gpu/matmul_utils_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/shape.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" @@ -32,16 +32,8 @@ namespace xla { namespace gpu { namespace { -using ::testing::ElementsAre; using ::tsl::testing::IsOkAndHolds; -TEST(GetNonContractingDimsTest, Valid) { - Shape shape = ParseShape("f32[1,2,3,4,5,6]").value(); - EXPECT_THAT(GetNonContractingDims(shape, /*batch_dims=*/{4}, - /*contracting_dims=*/{1, 5}), - IsOkAndHolds(ElementsAre(0, 2, 3))); -} - using CanFoldTransposeOperandIntoDotTest = HloTestBase; TEST_F(CanFoldTransposeOperandIntoDotTest, ArgTransposeFoldGemm) { diff --git a/third_party/xla/xla/service/gpu/metrics.cc b/third_party/xla/xla/service/gpu/metrics.cc index 87f7452639a14d..b3bd4861912852 100644 --- a/third_party/xla/xla/service/gpu/metrics.cc +++ b/third_party/xla/xla/service/gpu/metrics.cc @@ -16,10 +16,17 @@ limitations under the License. #include "xla/service/gpu/metrics.h" #include +#include +#include +#include "absl/strings/ascii.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "xla/tsl/lib/monitoring/counter.h" #include "xla/tsl/lib/monitoring/gauge.h" #include "xla/tsl/lib/monitoring/sampler.h" +#include "tsl/platform/stacktrace.h" namespace xla { namespace { @@ -40,6 +47,10 @@ auto* xla_device_binary_size = tsl::monitoring::Gauge::New( "/xla/service/gpu/xla_device_binary_size", "The size of the XLA binary loaded onto the GPU device."); +auto* gpu_compiler_stacktrace_count = tsl::monitoring::Counter<1>::New( + "/xla/service/gpu/compiler_stacktrace_count", + "The number of times a compiler stacktrace was called.", "stacktrace"); + } // namespace void RecordHloPassesDuration(const uint64_t time_usecs) { @@ -93,4 +104,31 @@ void RecordXlaDeviceBinarySize(const int64_t size) { xla_device_binary_size->GetCell()->Set(size); } +void RecordGpuCompilerStacktrace() { + std::string tsl_stacktrace = tsl::CurrentStackTrace(); + + // tsl::CurrentStackTrace() adds a prefix and postfix lines, so remove them. + std::deque stack = absl::StrSplit(tsl_stacktrace, '\n'); + stack.pop_front(); + stack.pop_back(); + + // Stack traces with addresses would make too many unique streamz cells. + // We only care about the actual call stack. + // Format chars added by tsl::CurrentStackTrace(). + constexpr unsigned kFormatChars = 8; + constexpr unsigned kAddressFormat = kFormatChars + 2 * sizeof(void*); + for (int i = 0; i < stack.size(); ++i) { + stack[i] = std::string(absl::StripAsciiWhitespace( + absl::ClippedSubstr(stack[i], kAddressFormat))); + } + + std::string stacktrace = absl::StrJoin(stack, ";\n"); + gpu_compiler_stacktrace_count->GetCell(stacktrace)->IncrementBy(1); +} + +int GetGpuCompilerStacktraceCount(absl::string_view stacktrace) { + return gpu_compiler_stacktrace_count->GetCell(std::string(stacktrace)) + ->value(); +} + } // namespace xla diff --git a/third_party/xla/xla/service/gpu/metrics.h b/third_party/xla/xla/service/gpu/metrics.h index 61173560e2f74b..8244ff77d83b71 100644 --- a/third_party/xla/xla/service/gpu/metrics.h +++ b/third_party/xla/xla/service/gpu/metrics.h @@ -18,6 +18,8 @@ limitations under the License. #include +#include "absl/strings/string_view.h" + namespace xla { // HLO passes (HLO -> HLO). @@ -52,6 +54,13 @@ int64_t GetCompiledProgramsCount(); // Records the size of the XLA device binary in bytes. void RecordXlaDeviceBinarySize(int64_t size); +// Records the stacktrace of the GPU compiler. +void RecordGpuCompilerStacktrace(); + +// Returns the number of times the GPU compiler was called with the given +// stacktrace. +int GetGpuCompilerStacktraceCount(absl::string_view stacktrace); + } // namespace xla #endif // XLA_SERVICE_GPU_METRICS_H_ diff --git a/third_party/xla/xla/service/gpu/metrics_test.cc b/third_party/xla/xla/service/gpu/metrics_test.cc new file mode 100644 index 00000000000000..836c32d0563cb5 --- /dev/null +++ b/third_party/xla/xla/service/gpu/metrics_test.cc @@ -0,0 +1,49 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/metrics.h" + +#include +#include +#include + +#include "xla/tsl/lib/monitoring/collected_metrics.h" +#include "xla/tsl/lib/monitoring/collection_registry.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +TEST(MetricsTest, RecordsGpuCompilerStacktrace) { + const std::string kGpuCompilerStacktraceMetricName = + "/xla/service/gpu/compiler_stacktrace_count"; + + RecordGpuCompilerStacktrace(); + + tsl::monitoring::CollectionRegistry::CollectMetricsOptions options; + std::unique_ptr metrics = + tsl::monitoring::CollectionRegistry::Default()->CollectMetrics(options); + + EXPECT_TRUE(metrics->point_set_map.find(kGpuCompilerStacktraceMetricName) != + metrics->point_set_map.end()); + EXPECT_EQ( + metrics->point_set_map[kGpuCompilerStacktraceMetricName]->points.size(), + 1); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index af58af0a37647b..9617ba64403899 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -1,3 +1,4 @@ +load("@bazel_skylib//rules:build_test.bzl", "build_test") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library") load("@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") @@ -5,7 +6,7 @@ load("//xla:xla.bzl", "xla_cc_test") # Libraries for performance modeling of HLO. load("//xla/tests:build_defs.bzl", "xla_test") -load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.bzl", "if_google", "internal_visibility") load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable") package( @@ -86,7 +87,7 @@ xla_cc_test( srcs = ["fusion_analysis_cache_test.cc"], deps = [ ":fusion_analysis_cache", - "//xla/service:hlo_parser", + "//xla/hlo/parser:hlo_parser", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:hlo_fusion_analysis", "//xla/stream_executor:device_description", @@ -246,6 +247,7 @@ cc_library( ":gpu_hlo_cost_analysis", ":gpu_performance_model_base", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:hlo_fusion_analysis", @@ -273,6 +275,7 @@ xla_cc_test( "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/service:hlo_cost_analysis", "//xla/service:hlo_module_config", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:gpu_device_info_for_tests", @@ -303,9 +306,9 @@ cc_library( ":indexing_analysis", "//xla:shape_util", "//xla:util", + "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", "//xla/service:hlo_cost_analysis", - "//xla/service:hlo_dataflow_analysis", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:hlo_fusion_analysis", @@ -394,6 +397,7 @@ xla_cc_test( "//xla:shape_util", "//xla:test_helpers", "//xla/hlo/ir:hlo", + "//xla/service:hlo_cost_analysis", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:hlo_traversal", @@ -413,32 +417,6 @@ xla_cc_test( ], ) -cc_library( - name = "affine_map_printer", - srcs = ["affine_map_printer.cc"], - hdrs = ["affine_map_printer.h"], - deps = [ - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - ], -) - -xla_cc_test( - name = "affine_map_printer_test", - srcs = ["affine_map_printer_test.cc"], - deps = [ - ":affine_map_printer", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:test", - ], -) - cc_library( name = "affine_map_evaluator", srcs = ["affine_map_evaluator.cc"], @@ -469,21 +447,22 @@ cc_library( srcs = [ "indexing_analysis.cc", "indexing_map.cc", + "indexing_map_serialization.cc", ], hdrs = [ "indexing_analysis.h", "indexing_map.h", + "indexing_map_serialization.h", ], deps = [ - ":affine_map_printer", "//xla:permutation_util", "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:gather_simplifier", + "//xla/hlo/transforms:gather_simplifier", "//xla/service/gpu:hlo_traversal", - "//xla/service/gpu:matmul_utils", + "//xla/service/gpu:matmul_indexing_utils", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -491,9 +470,11 @@ cc_library( "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", "@com_google_absl//absl/numeric:int128", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", + "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:logging", @@ -504,15 +485,14 @@ xla_cc_test( name = "indexing_map_test", srcs = ["indexing_map_test.cc"], deps = [ - ":affine_map_printer", ":indexing_analysis", ":indexing_test_utils", - "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/hash:hash_testing", - "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", "@llvm-project//mlir:IR", @@ -521,6 +501,21 @@ xla_cc_test( ], ) +xla_cc_test( + name = "indexing_map_serialization_test", + srcs = ["indexing_map_serialization_test.cc"], + deps = [ + ":indexing_analysis", + ":indexing_test_utils", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:test", + ], +) + cc_library( name = "indexing_test_utils", testonly = True, @@ -530,6 +525,7 @@ cc_library( ":indexing_analysis", "//xla:status_macros", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", "@com_google_absl//absl/container:flat_hash_map", @@ -571,7 +567,6 @@ cc_library( hdrs = ["symbolic_tile.h"], deps = [ ":affine_map_evaluator", - ":affine_map_printer", ":indexing_analysis", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -650,6 +645,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/service:name_uniquer", "//xla/service/gpu:backend_configs_cc", + "//xla/tsl/lib/gtl:iterator_range", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", @@ -660,7 +656,6 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", - "@local_tsl//tsl/lib/gtl:iterator_range", "@local_tsl//tsl/platform:errors", ], ) @@ -697,7 +692,6 @@ cc_library( hdrs = ["symbolic_tile_analysis.h"], deps = [ ":affine_map_evaluator", - ":affine_map_printer", ":indexing_analysis", ":symbolic_tile", ":symbolic_tiled_hlo_instruction", @@ -737,6 +731,7 @@ xla_cc_test( ":tiled_hlo_instruction_or_computation", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", "//xla/service:instruction_fusion", "//xla/service/gpu:hlo_traversal", "//xla/tests:hlo_test_base", @@ -749,6 +744,7 @@ xla_cc_test( "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", ], ) @@ -783,6 +779,7 @@ xla_cc_test( ":symbolic_tile_analysis", ":triton_emitter_constraints", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:verified_hlo_module", "//xla/service:instruction_fusion", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:hlo_traversal", @@ -805,16 +802,17 @@ cc_library( deps = [ ":affine_map_evaluator", ":indexing_analysis", + ":tiled_hlo_instruction_or_computation", "//xla:shape_util", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:hlo_traversal", "//xla/service/gpu/fusions:fusion_emitter", + "//xla/stream_executor:device_description", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -827,6 +825,9 @@ xla_cc_test( srcs = ["coalescing_analysis_test.cc"], deps = [ ":coalescing_analysis", + ":symbolic_tile", + ":symbolic_tile_analysis", + ":tiled_hlo_instruction_or_computation", "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/service:hlo_module_config", @@ -838,8 +839,10 @@ xla_cc_test( "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], ) @@ -909,6 +912,7 @@ cc_library( "//xla/service:gpu_plugin", "//xla/service:hlo_module_config", "//xla/service:hlo_runner", + "//xla/service:hlo_verifier", "//xla/service:interpreter_plugin", "//xla/stream_executor:device_description", "//xla/tests:test_utils", @@ -954,6 +958,7 @@ cc_library( "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status", ], ) @@ -965,9 +970,24 @@ cc_library( ] ] +build_test( + name = "hlo_op_profiler_build_test", + targets = [ + ":hlo_op_profiler_run_sm80", + ], +) + xla_test( name = "hlo_op_profiler_test", srcs = ["hlo_op_profiler_test.cc"], + # TODO(b/372714955): Fix the memory leak! + backend_args = if_google( + { + "gpu_h100": ["--heap_check="], + "gpu_a100": ["--heap_check="], + }, + {}, + ), backends = ["gpu"], local_defines = if_cuda(["GOOGLE_CUDA"]), deps = [ @@ -976,6 +996,8 @@ xla_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:test_main", ], ) diff --git a/third_party/xla/xla/service/gpu/model/affine_map_printer.cc b/third_party/xla/xla/service/gpu/model/affine_map_printer.cc deleted file mode 100644 index 83b68eca0473d8..00000000000000 --- a/third_party/xla/xla/service/gpu/model/affine_map_printer.cc +++ /dev/null @@ -1,269 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/model/affine_map_printer.h" - -#include -#include -#include -#include - -#include "absl/strings/str_cat.h" -#include "absl/types/span.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/Support/LLVM.h" - -namespace xla { -namespace gpu { -namespace { - -using mlir::AffineBinaryOpExpr; -using mlir::AffineConstantExpr; -using mlir::AffineDimExpr; -using mlir::AffineExpr; -using mlir::AffineExprKind; -using mlir::AffineMap; -using mlir::AffineSymbolExpr; - -} // namespace - -AffineMapPrinter::AffineMapPrinter( - absl::Span dim_names, - absl::Span symbol_names) { - dim_id_to_name_.reserve(dim_names.size()); - for (const auto& [index, name] : llvm::enumerate(dim_names)) { - dim_id_to_name_[index] = name; - } - symbol_id_to_name_.reserve(symbol_names.size()); - for (const auto& [index, name] : llvm::enumerate(symbol_names)) { - symbol_id_to_name_[index] = name; - } -} - -void AffineMapPrinter::Print(std::ostream& out, AffineMap affine_map) const { - out << ToString(affine_map); -} - -std::string AffineMapPrinter::ToString(AffineMap affine_map) const { - std::string s; - llvm::raw_string_ostream ss(s); - - if (dim_id_to_name_.empty() && symbol_id_to_name_.empty()) { - affine_map.print(ss); - return s; - } - // Dimension identifiers. - int dim_count = affine_map.getNumDims(); - ss << '('; - for (int i = 0; i < dim_count - 1; ++i) { - ss << GetDimensionName(i) << ", "; - } - if (dim_count >= 1) { - ss << GetDimensionName(dim_count - 1); - } - ss << ')'; - // Symbolic identifiers. - int symbol_count = affine_map.getNumSymbols(); - if (symbol_count != 0) { - ss << '['; - for (unsigned i = 0; i < symbol_count - 1; ++i) { - ss << GetSymbolName(i) << ", "; - } - if (affine_map.getNumSymbols() >= 1) { - ss << GetSymbolName(symbol_count - 1); - } - ss << ']'; - } - // Result affine expressions. - ss << " -> ("; - llvm::interleaveComma(affine_map.getResults(), ss, [&](AffineExpr expr) { - PrintExprImpl(expr, /*add_parentheses=*/false, ss); - }); - ss << ')'; - return s; -} - -void AffineMapPrinter::Print(std::ostream& out, - mlir::AffineExpr affine_expr) const { - out << ToString(affine_expr); -} - -std::string AffineMapPrinter::ToString(mlir::AffineExpr affine_expr) const { - std::string s; - llvm::raw_string_ostream ss(s); - PrintExprImpl(affine_expr, /*add_parentheses=*/false, ss); - return s; -} - -void AffineMapPrinter::PrintExprImpl(const mlir::AffineExpr affine_expr, - bool add_parentheses, - llvm::raw_ostream& os) const { - const char* binopSpelling = nullptr; - switch (affine_expr.getKind()) { - case AffineExprKind::SymbolId: { - unsigned symbol_id = - mlir::cast(affine_expr).getPosition(); - os << GetSymbolName(symbol_id); - return; - } - case AffineExprKind::DimId: { - unsigned dim_id = mlir::cast(affine_expr).getPosition(); - os << GetDimensionName(dim_id); - return; - } - case AffineExprKind::Constant: - os << mlir::cast(affine_expr).getValue(); - return; - case AffineExprKind::Add: - binopSpelling = " + "; - break; - case AffineExprKind::Mul: - binopSpelling = " * "; - break; - case AffineExprKind::FloorDiv: - binopSpelling = " floordiv "; - break; - case AffineExprKind::CeilDiv: - binopSpelling = " ceildiv "; - break; - case AffineExprKind::Mod: - binopSpelling = " mod "; - break; - } - - auto binOp = mlir::cast(affine_expr); - AffineExpr lhsExpr = binOp.getLHS(); - AffineExpr rhsExpr = binOp.getRHS(); - - // Handle tightly binding binary operators. - if (binOp.getKind() != AffineExprKind::Add) { - if (add_parentheses) { - os << '('; - } - - // Pretty print multiplication with -1. - auto rhsConst = mlir::dyn_cast(rhsExpr); - if (rhsConst && binOp.getKind() == AffineExprKind::Mul && - rhsConst.getValue() == -1) { - os << "-"; - PrintExprImpl(lhsExpr, /*add_parentheses=*/true, os); - if (add_parentheses) { - os << ')'; - } - return; - } - - PrintExprImpl(lhsExpr, /*add_parentheses=*/true, os); - - os << binopSpelling; - PrintExprImpl(rhsExpr, /*add_parentheses=*/true, os); - - if (add_parentheses) { - os << ')'; - } - return; - } - - // Print out special "pretty" forms for add. - if (add_parentheses) { - os << '('; - } - - // Pretty print addition to a product that has a negative operand as a - // subtraction. - if (auto rhs = mlir::dyn_cast(rhsExpr)) { - if (rhs.getKind() == AffineExprKind::Mul) { - AffineExpr rrhsExpr = rhs.getRHS(); - if (auto rrhs = mlir::dyn_cast(rrhsExpr)) { - if (rrhs.getValue() == -1) { - PrintExprImpl(lhsExpr, /*add_parentheses=*/false, os); - os << " - "; - if (rhs.getLHS().getKind() == AffineExprKind::Add) { - PrintExprImpl(rhs.getLHS(), /*add_parentheses=*/true, os); - } else { - PrintExprImpl(rhs.getLHS(), /*add_parentheses=*/false, os); - } - - if (add_parentheses) { - os << ')'; - } - return; - } - - if (rrhs.getValue() < -1) { - PrintExprImpl(lhsExpr, /*add_parentheses=*/false, os); - os << " - "; - PrintExprImpl(rhs.getLHS(), /*add_parentheses=*/true, os); - os << " * " << -rrhs.getValue(); - if (add_parentheses) { - os << ')'; - } - return; - } - } - } - } - - // Pretty print addition to a negative number as a subtraction. - if (auto rhsConst = mlir::dyn_cast(rhsExpr)) { - if (rhsConst.getValue() < 0) { - PrintExprImpl(lhsExpr, /*add_parentheses=*/false, os); - os << " - " << -rhsConst.getValue(); - if (add_parentheses) { - os << ')'; - } - return; - } - } - - PrintExprImpl(lhsExpr, /*add_parentheses=*/false, os); - - os << " + "; - PrintExprImpl(rhsExpr, /*add_parentheses=*/false, os); - - if (add_parentheses) { - os << ')'; - } -} - -void AffineMapPrinter::SetSymbolName(int64_t symbol_id, llvm::StringRef name) { - symbol_id_to_name_[symbol_id] = name; -} - -void AffineMapPrinter::SetDimensionName(int64_t dim_id, llvm::StringRef name) { - dim_id_to_name_[dim_id] = name; -} - -std::string AffineMapPrinter::GetSymbolName(int64_t symbol_id) const { - auto it = symbol_id_to_name_.find(symbol_id); - if (it == symbol_id_to_name_.end()) { - return absl::StrCat("s", symbol_id); - } - return it->second; -} - -std::string AffineMapPrinter::GetDimensionName(int64_t dim_id) const { - auto it = dim_id_to_name_.find(dim_id); - if (it == dim_id_to_name_.end()) { - return absl::StrCat("d", dim_id); - } - return it->second; -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/affine_map_printer.h b/third_party/xla/xla/service/gpu/model/affine_map_printer.h deleted file mode 100644 index bb1f6fbeb6902b..00000000000000 --- a/third_party/xla/xla/service/gpu/model/affine_map_printer.h +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_MODEL_AFFINE_MAP_PRINTER_H_ -#define XLA_SERVICE_GPU_MODEL_AFFINE_MAP_PRINTER_H_ - -#include -#include -#include -#include - -#include "absl/types/span.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/ADT/Twine.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" - -namespace xla { -namespace gpu { - -// AffineMapPrinter allows to "pretty print" mlir::AffineMap by setting custom -// symbol and dimension names. -class AffineMapPrinter { - public: - AffineMapPrinter() = default; - AffineMapPrinter(AffineMapPrinter&& other) = default; - AffineMapPrinter& operator=(AffineMapPrinter&& other) = default; - AffineMapPrinter(absl::Span dim_names, - absl::Span symbol_names); - - void SetSymbolName(int64_t symbol_id, llvm::StringRef name); - void SetDimensionName(int64_t dim_id, llvm::StringRef name); - - std::string GetSymbolName(int64_t symbol_id) const; - std::string GetDimensionName(int64_t dim_id) const; - - void Print(std::ostream& out, mlir::AffineMap affine_map) const; - std::string ToString(mlir::AffineMap affine_map) const; - - void Print(std::ostream& out, mlir::AffineExpr affine_expr) const; - std::string ToString(mlir::AffineExpr affine_expr) const; - - private: - void PrintExprImpl(mlir::AffineExpr affine_expr, bool add_parentheses, - llvm::raw_ostream& os) const; - - llvm::DenseMap dim_id_to_name_; - llvm::DenseMap symbol_id_to_name_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_MODEL_AFFINE_MAP_PRINTER_H_ diff --git a/third_party/xla/xla/service/gpu/model/affine_map_printer_test.cc b/third_party/xla/xla/service/gpu/model/affine_map_printer_test.cc deleted file mode 100644 index 01c6092b4d02c3..00000000000000 --- a/third_party/xla/xla/service/gpu/model/affine_map_printer_test.cc +++ /dev/null @@ -1,59 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/model/affine_map_printer.h" - -#include -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/MLIRContext.h" -#include "xla/tests/hlo_test_base.h" -#include "tsl/platform/test.h" - -namespace xla { -namespace gpu { -namespace { - -using ::mlir::AffineExpr; -using ::mlir::AffineMap; -using ::mlir::bindDims; -using ::mlir::bindSymbols; -using ::testing::HasSubstr; - -class IndexingMapTest : public HloTestBase { - public: - mlir::MLIRContext mlir_context_; - AffineMapPrinter printer_; -}; - -TEST_F(IndexingMapTest, AffineMapPrinterTest) { - AffineExpr d0, d1, s0, s1; - bindDims(&mlir_context_, d0, d1); - bindSymbols(&mlir_context_, s0, s1); - - // (d0, d1)[s0, s1] -> (d0 + d1 floordiv 8, s0 + s1 mod 16). - auto map = - AffineMap::get(2, 2, {d0 + d1.floorDiv(8), s0 + s1 % 16}, &mlir_context_); - - printer_.SetDimensionName(0, "offset"); - printer_.SetSymbolName(1, "linear_index"); - EXPECT_THAT(printer_.ToString(map), - HasSubstr("(offset, d1)[s0, linear_index] -> " - "(offset + d1 floordiv 8, s0 + linear_index mod 16)")); -} - -} // namespace -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/analytical_latency_estimator_test.cc b/third_party/xla/xla/service/gpu/model/analytical_latency_estimator_test.cc index 1fa86f0caa024c..95940241d8b0de 100644 --- a/third_party/xla/xla/service/gpu/model/analytical_latency_estimator_test.cc +++ b/third_party/xla/xla/service/gpu/model/analytical_latency_estimator_test.cc @@ -96,12 +96,6 @@ class AnalyticalLatencyHidingSchedulerTest : public GpuCodegenTest { ->GetDeviceDescription() .cuda_compute_capability(); } - HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const { - return [&](const Shape& shape) { - constexpr int64_t kPointerSize = 8; - return ShapeUtil::ByteSizeOf(shape, kPointerSize); - }; - } }; TEST_F(AnalyticalLatencyHidingSchedulerTest, TestAnalyticalLatencyEstimator) { @@ -149,7 +143,8 @@ ENTRY entry { auto scheduler_config = GetDefaultSchedulerConfig(); auto latency_estimator = std::make_unique( scheduler_config, std::make_unique(), - dev_info, ShapeSizeBytesFunction(), hlo_module->entry_computation()); + dev_info, HloCostAnalysis::DefaultShapeSize, + hlo_module->entry_computation()); EXPECT_TRUE(RunScheduler(hlo_module.get(), scheduler_config, std::move(latency_estimator)) .ok()); diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc b/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc index 9e7a685d590a29..31bed0893218fb 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc @@ -25,9 +25,8 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" -#include "absl/log/check.h" -#include "absl/log/log.h" #include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/MathExtras.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -43,8 +42,11 @@ limitations under the License. #include "xla/service/gpu/model/affine_map_evaluator.h" #include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/tiled_hlo_instruction.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/util.h" namespace xla { namespace gpu { @@ -97,15 +99,94 @@ bool IsReadCoalescedHeuristic(HloFusionAnalysis::EmitterFusionKind fusion_kind, return true; } +double BandwidthUtilizationRateHeuristicForTiledMemoryAccess( + const TiledHloInstruction& hbm_access_instr, + const se::DeviceDescription& device_info) { + const HloInstruction* hlo = hbm_access_instr.hlo(); + const Shape& shape = hlo->shape(); + + // Compute the number of elements in the contiguous part of the tile. + int64_t contiguous_elements = 1; + for (const auto dim_idx : shape.layout().minor_to_major()) { + // This dimension is strided, so it's not contiguous. + if (hbm_access_instr.tile_stride(dim_idx) != 1) { + break; + } + + int64_t tile_size = hbm_access_instr.tile_size(dim_idx); + int64_t dim_size = shape.dimensions(dim_idx); + + // Make sure to ignore the mask if there is one. + contiguous_elements *= std::min(tile_size, dim_size); + + // This dimension is only partially captured, so more major dimensions are + // necessarily not captured contiguously. + if (tile_size < dim_size) { + break; + } + } + + // Compute the size of the contiguous part of the tile in bytes. + int64_t contiguous_bytes_accessed = + contiguous_elements * + ShapeUtil::ByteSizeOfPrimitiveType(hlo->shape().element_type()); + + // Memory accesses are fully coalesced if the memory access uses exactly a + // multiple of the DRAM->L2 cache line size contiguously. + int64_t transaction_size_bytes = + device_info.dram_to_l2_transaction_size_bytes(); + int64_t effective_bytes_accessed = + transaction_size_bytes * + CeilOfRatio(contiguous_bytes_accessed, transaction_size_bytes); + return 1.0 * contiguous_bytes_accessed / effective_bytes_accessed; +} + +bool IsTiledReadCoalescedHeuristic(const TiledHloInstruction& operand, + const se::DeviceDescription& device_info) { + const Shape& shape = operand.hlo()->shape(); + + // Compute the number of elements in the contiguous part of the tile. + int64_t contiguous_read_elements = 1; + for (const auto dim_idx : shape.layout().minor_to_major()) { + // This dimension is strided, so it's not contiguous. + if (operand.tile_stride(dim_idx) != 1) { + break; + } + + int64_t tile_size = operand.tile_size(dim_idx); + int64_t dim_size = shape.dimensions(dim_idx); + + // Make sure to ignore the mask if there is one. + contiguous_read_elements *= std::min(tile_size, dim_size); + + // This dimension is only partially captured, so more major dimensions are + // necessarily not captured contiguously. + if (tile_size < dim_size) { + break; + } + } + + // Compute the size of the contiguous part of the tile in bytes. + int64_t contiguous_bytes_accessed = + contiguous_read_elements * + ShapeUtil::ByteSizeOfPrimitiveType(operand.hlo()->shape().element_type()); + + // We consider a read coalesced if the contiguous part of the read covers the + // whole DRAM->L2 cache line. + // + // TODO(b/332714755): note that we don't check that we fully exploit all the + // cache lines we read from if we happen to read through several of them. + return contiguous_bytes_accessed >= + device_info.dram_to_l2_transaction_size_bytes(); +} + namespace { using ::mlir::AffineBinaryOpExpr; using ::mlir::AffineConstantExpr; -using ::mlir::AffineDimExpr; using ::mlir::AffineExpr; using ::mlir::AffineExprKind; using ::mlir::AffineMap; -using ::mlir::AffineSymbolExpr; using ::mlir::getAffineConstantExpr; using ::mlir::MLIRContext; @@ -233,11 +314,10 @@ void AssignValuesToRTVars(IndexingMap* indexing_map) { symbol_replacements.push_back( mlir::getAffineSymbolExpr(symbol_id, mlir_context)); } - for (const RTVar& rt_var : indexing_map->GetRTVars()) { + for (const IndexingMap::Variable& rt_var : indexing_map->GetRTVars()) { // Take midpoint of the feasible interval for the RT variable. symbol_replacements.push_back(getAffineConstantExpr( - (rt_var.feasible_values.lower + rt_var.feasible_values.upper) / 2, - mlir_context)); + (rt_var.bounds.lower + rt_var.bounds.upper) / 2, mlir_context)); } AffineMap thread_x_to_input_no_dim_symbols = indexing_map->GetAffineMap().replaceDimsAndSymbols( @@ -263,7 +343,7 @@ void AssignValuesToOuterLoopIVs(IndexingMap* indexing_map) { for (int64_t symbol_id = 0; symbol_id < indexing_map->GetRangeVarsCount() - 1; ++symbol_id) { symbol_replacements.push_back(getAffineConstantExpr( - indexing_map->GetRangeVar(symbol_id).range.lower, mlir_context)); + indexing_map->GetRangeVar(symbol_id).bounds.lower, mlir_context)); } symbol_replacements.push_back(mlir::getAffineSymbolExpr(0, mlir_context)); @@ -498,7 +578,7 @@ bool IsIndexingCoalesced(IndexingMap& thread_x_to_linearized_input, AffineExpr c0 = getAffineConstantExpr(0, mlir_context); IndexingMap thread_x_first_32_elements{ AffineMap::get(1, 0, {thread_x_dim, c0, c0, c0, c0, c0}, mlir_context), - {DimVar{{0, 31}}}, + {IndexingMap::Variable{{0, 31}}}, /*range_vars=*/{}, /*rt_vars=*/{}}; IndexingMap thread_x_to_input_sample = diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis.h b/third_party/xla/xla/service/gpu/model/coalescing_analysis.h index da2c6872b191e9..ca6c5465e87a20 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis.h +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis.h @@ -23,7 +23,8 @@ limitations under the License. #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/hlo_traversal.h" -#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/tiled_hlo_instruction.h" +#include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { @@ -72,6 +73,27 @@ bool IsReadCoalescedHeuristic(HloFusionAnalysis::EmitterFusionKind fusion_kind, const HloInstruction* producer, const HloInstruction* consumer = nullptr); +// Returns the bandwidth utilization rate of the memory access for the given +// tiled HLO instruction. Naturally, values are between 0 and 1, where a +// perfectly coalesced read has a utilization rate of 1. +// +// Note: the assumption is that the tile sizes do not include padding beyond +// the end of the shape. +double BandwidthUtilizationRateHeuristicForTiledMemoryAccess( + const TiledHloInstruction& hbm_access_instr, + const se::DeviceDescription& device_info); + +// Returns true if read of this tiled hlo operand is coalesced. +// +// We consider a read coalesced if the operand tile consist of contiguous chunk +// of memory that saturate DRAM->L2 cache line. For post-V100 NVIDIA GPUs, that +// is 64 bytes by default. +// +// TODO(b/332714755): check whether we should bump up the granularity of +// memory transactions. +bool IsTiledReadCoalescedHeuristic(const TiledHloInstruction& operand, + const se::DeviceDescription& device_info); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc index aefe84294472a2..cd545247e0b0aa 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc @@ -15,12 +15,14 @@ limitations under the License. #include "xla/service/gpu/model/coalescing_analysis.h" +#include #include #include #include #include #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "mlir/IR/MLIRContext.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -29,11 +31,15 @@ limitations under the License. #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/model/symbolic_tile_analysis.h" +#include "xla/service/gpu/model/tiled_hlo_computation.h" +#include "xla/service/gpu/model/tiled_hlo_instruction.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { @@ -148,7 +154,7 @@ TEST_F(CoalescingTest, OutputAndLhsTransposedLayout) { fusion { p0 = f32[100, 200]{1, 0} parameter(0) p1 = f32[100, 200]{0, 1} parameter(1) - ROOT exp = f32[100, 200]{1, 0} add(p0, p1) + ROOT add = f32[100, 200]{1, 0} add(p0, p1) } ENTRY e { p0 = f32[100, 200]{1, 0} parameter(0) @@ -510,6 +516,200 @@ TEST_F(CoalescingTest, Param) { EXPECT_THAT(IsReadCoalescedPerOperand(ir), ElementsAre(true, true, true)); } +class CoalescingForTiledHloTest : public CoalescingTest { + public: + std::vector IsTiledReadCoalescedPerOperand( + const HloInstruction* root, absl::Span tile_sizes) { + auto fusion_adaptor = HloFusionAdaptor::ForInstruction(root); + + SymbolicTileAnalysis symbolic_tile_analysis = + std::get(SymbolicTileAnalysis::AnalyzeFusion( + *fusion_adaptor, &mlir_context_)); + + TiledHloComputation tiled_hlo_computation = + *symbolic_tile_analysis.ComputeTiledHloInstructions( + tile_sizes, /*constraints_are_known_satisfied=*/true, + /*compute_all_tile_offset_indexing_maps=*/true); + + const TiledHloInstruction* tiled_hlo_root = tiled_hlo_computation.GetRoot(); + std::vector result; + for (const TiledHloInstruction* operand : tiled_hlo_root->operands()) { + result.push_back(IsTiledReadCoalescedHeuristic(*operand, device_info_)); + } + return result; + } + + std::vector EffectiveBandwidthUtilizationRatePerOperand( + const HloInstruction* root, absl::Span tile_sizes) { + auto fusion_adaptor = HloFusionAdaptor::ForInstruction(root); + + SymbolicTileAnalysis symbolic_tile_analysis = + std::get(SymbolicTileAnalysis::AnalyzeFusion( + *fusion_adaptor, &mlir_context_)); + + TiledHloComputation tiled_hlo_computation = + *symbolic_tile_analysis.ComputeTiledHloInstructions( + tile_sizes, /*constraints_are_known_satisfied=*/true, + /*compute_all_tile_offset_indexing_maps=*/true); + + const TiledHloInstruction* tiled_hlo_root = tiled_hlo_computation.GetRoot(); + std::vector result; + for (const TiledHloInstruction* operand : tiled_hlo_root->operands()) { + result.push_back(BandwidthUtilizationRateHeuristicForTiledMemoryAccess( + *operand, device_info_)); + } + return result; + } +}; + +TEST_F(CoalescingForTiledHloTest, TiledReadCoalescedHeuristic_Transpose) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule m + +ENTRY main { + p0 = f32[2048, 48] parameter(0) + ROOT transpose = f32[48, 2048] transpose(p0), dimensions={1, 0} +})")); + + const HloInstruction* root = module->entry_computation()->root_instruction(); + + // The operand is not coalesced because the tile has stride 48. + EXPECT_THAT(IsTiledReadCoalescedPerOperand(root, {1, 2048}), + ElementsAre(false)); + + // The operand is coalesced because we read 48 contiguous elements. + EXPECT_THAT(IsTiledReadCoalescedPerOperand(root, {48, 32}), + ElementsAre(true)); +} + +TEST_F(CoalescingForTiledHloTest, + TiledReadCoalescedHeuristic_MaskingIsHandledCorrectly) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule m + +ENTRY main { + p0 = f32[2048, 12] parameter(0) + ROOT transpose = f32[12, 2048] transpose(p0), dimensions={1, 0} +})")); + + const HloInstruction* root = module->entry_computation()->root_instruction(); + + constexpr int kNumBytesPerParamRow = 12 * 4; + + // The transaction size can be configured in different ways, and the minimum + // possible value on A100 is 32 bytes---which would make this test fail. + // Ensure that the transaction size is configured to be large enough. + ASSERT_GT(device_info_.dram_to_l2_transaction_size_bytes(), + kNumBytesPerParamRow); + + // The operand is coalesced because we read 4 * 12 = 48 contiguous elements + // (though the tile contains 64 elements due to the mask). + EXPECT_THAT(IsTiledReadCoalescedPerOperand(root, {16, 4}), ElementsAre(true)); + + // The mask should be ignored when checking whether reads are coalesced. + EXPECT_THAT(IsTiledReadCoalescedPerOperand(root, {1024, 1}), + ElementsAre(false)); +} + +TEST_F(CoalescingForTiledHloTest, RhsTransposedLayout) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule m + +ENTRY main { + p0 = f32[256, 512]{1,0} parameter(0) + p1 = f32[256, 512]{0,1} parameter(1) + ROOT add = f32[256, 512]{1,0} add(p0, p1) +})")); + + const HloInstruction* root = module->entry_computation()->root_instruction(); + + constexpr int kExpectedDramToL2TransactionSize = 64; + ASSERT_EQ(device_info_.dram_to_l2_transaction_size_bytes(), + kExpectedDramToL2TransactionSize); + + EXPECT_THAT(IsTiledReadCoalescedPerOperand(root, {1, 16}), + ElementsAre(true, false)); + EXPECT_THAT(IsTiledReadCoalescedPerOperand(root, {16, 1}), + ElementsAre(false, true)); + EXPECT_THAT(IsTiledReadCoalescedPerOperand(root, {16, 16}), + ElementsAre(true, true)); + EXPECT_THAT(IsTiledReadCoalescedPerOperand(root, {8, 8}), + ElementsAre(false, false)); +} + +TEST_F(CoalescingForTiledHloTest, SmallDataTypes) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule m + +ENTRY main { + p0 = s8[256, 512] parameter(0) + p1 = s8[256, 512] parameter(1) + ROOT add = s8[256, 512] add(p0, p1) +})")); + + const HloInstruction* root = module->entry_computation()->root_instruction(); + + constexpr int kExpectedDramToL2TransactionSize = 64; + ASSERT_EQ(device_info_.dram_to_l2_transaction_size_bytes(), + kExpectedDramToL2TransactionSize); + + // To be coalesced, a contiguous chunk of memory load should be at least + // kExpectedDramToL2TransactionSize bytes long. + EXPECT_THAT(IsTiledReadCoalescedPerOperand(root, {16, 16}), + ElementsAre(false, false)); + EXPECT_THAT(IsTiledReadCoalescedPerOperand(root, {16, 32}), + ElementsAre(false, false)); + EXPECT_THAT(IsTiledReadCoalescedPerOperand(root, {16, 64}), + ElementsAre(true, true)); + EXPECT_THAT(IsTiledReadCoalescedPerOperand(root, {16, 128}), + ElementsAre(true, true)); +} + +TEST_F( + CoalescingForTiledHloTest, + EffectiveBandwidthUtilizationRateIsComputedCorrectlyForTiledMemoryAccess) { // NOLINT(whitespace/line_length) + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule m + +ENTRY main { + p0 = s8[256, 16] parameter(0) + ROOT convert = s8[256, 16] convert(p0) +})")); + + const HloInstruction* root = module->entry_computation()->root_instruction(); + + // Note: the tests below rely strongly on this value for the transaction size. + // If the transaction size is changed, the tests will need to be updated. + constexpr int kExpectedDramToL2TransactionSize = 64; + ASSERT_EQ(device_info_.dram_to_l2_transaction_size_bytes(), + kExpectedDramToL2TransactionSize); + + // By reading only one byte at a time, we expect to exploit exactly + // 1 / kExpectedDramToL2TransactionSize of the bandwidth. + EXPECT_THAT(EffectiveBandwidthUtilizationRatePerOperand(root, {1, 1}), + ElementsAre(1.0 / kExpectedDramToL2TransactionSize)); + + // Reading one full row won't cut it; by reading 16 bytes at a time, we expect + // to exploit exactly 16 / kExpectedDramToL2TransactionSize of the bandwidth. + EXPECT_THAT(EffectiveBandwidthUtilizationRatePerOperand(root, {1, 16}), + ElementsAre(16.0 / kExpectedDramToL2TransactionSize)); + + // Reading 4 rows at a time will allow us to exploit 100% of the bandwidth. + EXPECT_THAT(EffectiveBandwidthUtilizationRatePerOperand(root, {4, 16}), + ElementsAre(1.0)); + + // Reading 8 rows at a time will allow us to exploit 100% of the bandwidth. + EXPECT_THAT(EffectiveBandwidthUtilizationRatePerOperand(root, {8, 16}), + ElementsAre(1.0)); + + // Reading 6 rows at a time will however only allow us to exploit 75% of the + // bandwidth; the first four rows are read fully coalesced, but the last two + // rows use only half of the transaction size---i.e. 3/4 of the transactions + // are coalesced. + EXPECT_THAT(EffectiveBandwidthUtilizationRatePerOperand(root, {6, 16}), + ElementsAre(0.75)); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc index 1750c38c3c41ae..820d9925ab3193 100644 --- a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc +++ b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include #include "absl/strings/string_view.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/hlo_parser.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model.cc index aad3343260c945..46dcff11343140 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model.cc @@ -24,11 +24,11 @@ limitations under the License. #include "absl/log/log.h" #include "absl/strings/numbers.h" #include "absl/time/time.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" -#include "xla/service/hlo_dataflow_analysis.h" #include "xla/stream_executor/device_description.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/gpu/model/gpu_cost_model_stats_collection_test.cc b/third_party/xla/xla/service/gpu/model/gpu_cost_model_stats_collection_test.cc index 2f8fe7b18b3233..8b1d92eb55df7b 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_cost_model_stats_collection_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_cost_model_stats_collection_test.cc @@ -25,30 +25,17 @@ limitations under the License. #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/hlo_cost_analysis.h" -#include "xla/shape.h" -#include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" #include "tsl/platform/statusor.h" namespace xla { namespace gpu { class GpuCostModelStatsCollectionTest : public HloTestBase { - HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const { - return [&](const Shape& shape) { - constexpr int64_t kPointerSize = 8; - return ShapeUtil::ByteSizeOf(shape, kPointerSize); - }; - } - public: GpuCostModelStatsCollection cost_model_stats_{ TestGpuDeviceInfo::RTXA6000DeviceInfo(), - GpuHloCostAnalysis::Options{ShapeSizeBytesFunction(), - /*per_second_rates=*/{}, - /*min_latencies_seconds=*/{}, - /*count_multiple_input_accesses=*/true}}; + GpuHloCostAnalysis::Options{.count_multiple_input_accesses = true}}; }; TEST_F(GpuCostModelStatsCollectionTest, FusinInEntryComputation) { diff --git a/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc b/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc index 6115b812c912fc..9f591ac8c25e6a 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc @@ -24,8 +24,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/model/hlo_op_profiles.h" #include "xla/service/hlo_cost_analysis.h" -#include "xla/shape.h" -#include "xla/shape_util.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" @@ -35,18 +33,8 @@ namespace xla { namespace gpu { class GpuHloCostAnalysisTest : public HloTestBase { - HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const { - return [&](const Shape& shape) { - constexpr int64_t kPointerSize = 8; - return ShapeUtil::ByteSizeOf(shape, kPointerSize); - }; - } - public: - HloCostAnalysis::Options options_{ShapeSizeBytesFunction(), - /*per_second_rates=*/{}, - /*min_latencies_seconds=*/{}, - /*count_multiple_input_accesses=*/true}; + HloCostAnalysis::Options options_{.count_multiple_input_accesses = true}; GpuHloCostAnalysis analysis_{options_}; GpuHloCostAnalysisTest() : HloTestBase() {} }; diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc index 4ec9f347dff90d..462b44bdf543ec 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -58,6 +58,16 @@ namespace xla { namespace gpu { namespace { +// Information about an operand read. +struct OperandReadInfo { + // Total number of bytes read from the operand. + int64_t total_bytes_read = 0; + + // Factor, between 0 and 1, determining how much of the chip's HBM bandwidth + // is actually attained when reading this operand. + double read_bandwidth_utilization_rate = 1.0; +}; + // Returns the number of elements in the tile after each dimension is padded to // the next power of 2. // TODO(b/353484968): Delete this function once we have constraints to only @@ -77,6 +87,24 @@ int64_t GetPaddedTileSize(absl::Span tile_sizes) { // heuristic tries to be safe and increase recall at the cost of precision. bool DoesTileFitsInRegisters(int64_t tile_size, const se::DeviceDescription& device_info) { + // This is a conservative estimate to make sure that we don't get a tile that + // is too big and results in register spills. + // + // We had the following reasoning for the value of this constant: + // * Whenever a block needs to use a tile more than once, it needs to + // either (1) load the tile from HBM several times, or (2) store the tile + // in registers at the same time as some of the results. That is the case + // for normalization diamonds for instance, where the input tile is used + // twice. + // * We expect kernels without reuse to benefit from smaller tile sizes + // anyway. + // * We use around 20% of the registers as working memory for indexing + // computations and expensive instructions like exponential or cosine. + // + // This value was empirically determined in September 2024 and may change in + // the future. + constexpr double kFractionOfRegistersAvailableToStoreTile = 0.4; + // Register allocation happens at PTX->SASS level, so we can't know the exact // number of registers used by a kernel. We make a few assumptions about the // kernel we will generate (this may not hold in the future): @@ -95,18 +123,19 @@ bool DoesTileFitsInRegisters(int64_t tile_size, // data type. `registers_per_block_limit()` returns the number of 32-bit // registers. Check if 64-bit types need twice as many registers. Check if // smaller types can fit into one register. - return tile_size <= device_info.registers_per_block_limit(); + return tile_size <= kFractionOfRegistersAvailableToStoreTile * + device_info.registers_per_block_limit(); } // Returns the number of warps to use based on the tile size. The numbers were // originally selected from Triton SoftMax reduction row length. // TODO(b/332714755): Make it smarter. int64_t GetNumWarps(int64_t tile_size) { - if (tile_size <= 512) return 1; - if (tile_size <= 1024) return 2; - if (tile_size <= 16384) return 4; - if (tile_size <= 32768) return 8; - if (tile_size <= 65536) return 16; + if (tile_size <= 256) return 1; + if (tile_size <= 512) return 2; + if (tile_size <= 1024) return 4; + if (tile_size <= 2048) return 8; + if (tile_size <= 4096) return 16; return 32; } @@ -246,7 +275,7 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForFusion( auto element_type = instr->shape().element_type(); int64_t n_bytes_total = 0; for (const auto& indexing_map : indexing_maps) { - VLOG(10) << indexing_map.ToString(); + VLOG(10) << indexing_map; int64_t num_iters = GetIterationSpaceSize(indexing_map, instr); @@ -266,9 +295,10 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForFusion( VLogOperandRead(instr, n_bytes_total, n_bytes_net, is_coalesced); - read_time += - ReadTimeWithDRAMHeuristic(*device_info_, num_blocks, n_bytes_net, - n_bytes_total, element_type, is_coalesced); + read_time += ReadTimeWithDRAMHeuristic( + *device_info_, num_blocks, n_bytes_net, n_bytes_total, element_type, + GetCoalescingUtilizationRate(element_type, *device_info_, + is_coalesced)); } } @@ -347,7 +377,7 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledHloComputation( const HloFusionAdaptor& fusion_adaptor, const TiledHloComputation& tiled_hlo_computation, const LaunchDimensions& launch_dimensions) { - absl::flat_hash_map n_bytes_total_map; + absl::flat_hash_map n_bytes_total_map; int64_t flops = 0; int64_t bytes_read = 0; @@ -405,19 +435,39 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledHloComputation( int64_t tile_bytes_read = element_type_size * num_elements; bytes_read += tile_bytes_read; - n_bytes_total_map[hlo] += tile_bytes_read; + + double effective_bandwidth_utilization_rate = + BandwidthUtilizationRateHeuristicForTiledMemoryAccess(*tiled_hlo, + *device_info_); + + OperandReadInfo& operand_read_info = n_bytes_total_map[hlo]; + operand_read_info.total_bytes_read += tile_bytes_read; + // TODO(b/332714755): using std::min is more pessimistic than it needs to + // be since it'll end up assuming that if one read is done with lower + // bandwidth, all other reads of the same operand will also be done with + // lower bandwidth. But it's a start. We should refactor this function to + // properly track each read independently later. + operand_read_info.read_bandwidth_utilization_rate = + std::min(operand_read_info.read_bandwidth_utilization_rate, + effective_bandwidth_utilization_rate); } } absl::Duration read_time = absl::ZeroDuration(); - for (const auto& [hlo, n_bytes_total] : n_bytes_total_map) { + for (const auto& [hlo, operand_read_info] : n_bytes_total_map) { int64_t operand_size = shape_size_(hlo->shape()); - int64_t n_bytes_net = std::min(operand_size, n_bytes_total); + int64_t n_bytes_net = + std::min(operand_size, operand_read_info.total_bytes_read); + // TODO(b/332714755): use + // `BandwidthUtilizationRateHeuristicForTiledMemoryAccess` to compute read + // time as well. read_time += ReadTimeWithDRAMHeuristic( - *device_info_, num_blocks, n_bytes_net, n_bytes_total, + *device_info_, num_blocks, n_bytes_net, + operand_read_info.total_bytes_read, /*element_type=*/hlo->shape().element_type(), - /*coalesced=*/true); + /*hbm_bandwidth_utilization_rate=*/ + operand_read_info.read_bandwidth_utilization_rate); } int64_t bytes_written = @@ -426,7 +476,13 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledHloComputation( absl::Duration compute_time = ComputeTime(*device_info_, flops, launch_dimensions.num_blocks(), launch_dimensions.num_threads_per_block()); - absl::Duration write_time = WriteTime(*device_info_, bytes_written); + + int64_t effective_bandwidth = + BandwidthUtilizationRateHeuristicForTiledMemoryAccess( + *tiled_hlo_computation.GetRoot(), *device_info_) * + device_info_->memory_bandwidth(); + absl::Duration write_time = + absl::Seconds(1.0 * bytes_written / effective_bandwidth); absl::Duration memory_access_time = read_time + write_time; absl::Duration exec_time = CombineComputeAndMemoryAccessTime( compute_time, memory_access_time, @@ -488,9 +544,16 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTriton( LaunchDimensions GpuPerformanceModelWithIndexingAnalysis::GetLaunchDimensionsForTiledFusion( const TiledHloComputation& tiled_hlo_computation) { - const auto* tiled_root = tiled_hlo_computation.GetRoot(); int64_t num_blocks = tiled_hlo_computation.num_output_tiles(); - int64_t num_warps = GetNumWarps(GetPaddedTileSize(tiled_root->tile_sizes())); + + // Decide on the number of warps to use based on the largest live tile size + // at any given point within the computation. + int64_t largest_live_tile_size = 1; + for (const auto& tiled_hlo : tiled_hlo_computation.instructions()) { + largest_live_tile_size = std::max( + largest_live_tile_size, GetPaddedTileSize(tiled_hlo->tile_sizes())); + } + int64_t num_warps = GetNumWarps(largest_live_tile_size); return {static_cast(num_blocks), static_cast(num_warps * WarpSize())}; @@ -543,7 +606,7 @@ GpuPerformanceModelWithIndexingAnalysis::TryFindBestTilingForFusion( } if (!best_tiled_run_time_data.has_value()) { - return FusionDecision("No valid tilings found."); + return FusionDecision::Forbid("No valid tilings found."); } return *best_tiled_run_time_data; } diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc index 18ce88e553277f..f9ed7a1b355ddb 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc @@ -37,6 +37,7 @@ limitations under the License. #include "xla/service/gpu/model/gpu_performance_model_base.h" #include "xla/service/gpu/model/symbolic_tile_analysis.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" +#include "xla/service/hlo_cost_analysis.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" @@ -56,20 +57,13 @@ using ::tsl::testing::StatusIs; class GpuIndexingPerformanceModelTest : public HloTestBase { public: - GpuHloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const { - return [&](const Shape& shape) { - constexpr int64_t kPointerSize = 8; - return ShapeUtil::ByteSizeOf(shape, kPointerSize); - }; - } - mlir::MLIRContext mlir_context_; // The reference times in the test cases below are measured // on A6000 by profiling the execution of the HLOs. se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()}; HloFusionAnalysisCache fusion_analysis_cache_{device_info_}; GpuPerformanceModelWithIndexingAnalysis indexing_cost_model_{ - &device_info_, &fusion_analysis_cache_, ShapeSizeBytesFunction(), + &device_info_, &fusion_analysis_cache_, HloCostAnalysis::DefaultShapeSize, &mlir_context_}; GpuIndexingPerformanceModelTest() : HloTestBase() {} @@ -336,7 +330,7 @@ ENTRY main { EXPECT_THAT(tiled_runtime_data.block_level_parameters.output_tile_sizes, ElementsAre(4, 911)); - EXPECT_EQ(tiled_runtime_data.block_level_parameters.num_warps, 4); + EXPECT_EQ(tiled_runtime_data.block_level_parameters.num_warps, 16); EXPECT_EQ(tiled_runtime_data.runtime_data.bytes_read, kExpectedBytesRead); EXPECT_EQ(tiled_runtime_data.runtime_data.bytes_written, kOutputSizeBytes); @@ -360,30 +354,29 @@ max_computation { } softmax { - param_0 = f16[65538,32768]{1,0} parameter(0) + param_0 = f16[131076,16384]{1,0} parameter(0) constant_neg_inf = f16[] constant(-inf) - reduce = f16[65538]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation - broadcast = f16[65538,32768]{1,0} broadcast(reduce), dimensions={0} - ROOT subtract = f16[65538,32768]{1,0} subtract(param_0, broadcast) + reduce = f16[131076]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation + broadcast = f16[131076,16384]{1,0} broadcast(reduce), dimensions={0} + ROOT subtract = f16[131076,16384]{1,0} subtract(param_0, broadcast) } ENTRY main { - param_0 = f16[65538,32768]{1,0} parameter(0) - ROOT fusion = f16[65538,32768]{1,0} fusion(param_0), kind=kCustom, calls=softmax -} -)")); + param_0 = f16[131076,16384]{1,0} parameter(0) + ROOT fusion = f16[131076,16384]{1,0} fusion(param_0), kind=kCustom, calls=softmax +})")); auto fusion_adaptor = HloFusionAdaptor::ForInstruction( module->entry_computation()->root_instruction()); - LaunchDimensions launch_dimensions{65538LL * 32768LL, 32}; + LaunchDimensions launch_dimensions{131076LL * 16384LL, 32}; TF_ASSERT_OK_AND_ASSIGN( auto runtime_data, indexing_cost_model_.EstimateRunTimeForTiledFusion( *fusion_adaptor, launch_dimensions, /*output_tile_sizes=*/{1, 1})); - EXPECT_NEAR(absl::ToDoubleSeconds(runtime_data.read_time), 183, 1); - EXPECT_NEAR(absl::ToDoubleSeconds(runtime_data.compute_time), 39, 1); - EXPECT_NEAR(absl::ToDoubleSeconds(runtime_data.exec_time), 185, 1); + EXPECT_NEAR(absl::ToDoubleSeconds(runtime_data.read_time), 2932, 2); + EXPECT_NEAR(absl::ToDoubleSeconds(runtime_data.compute_time), 19, 1); + EXPECT_NEAR(absl::ToDoubleSeconds(runtime_data.exec_time), 2932, 2); } // TODO(b/351342921): Remove this test once there is no special filter for @@ -463,35 +456,31 @@ add { } triton_softmax_computation { - param_0 = f32[16,40000] parameter(0) + param_0 = f32[16,16000] parameter(0) constant_0 = f32[] constant(0) reduce_0 = f32[16] reduce(param_0, constant_0), dimensions={1}, to_apply=add - broadcast = f32[16,40000] broadcast(reduce_0), dimensions={0} - ROOT multiply = f32[16,40000] multiply(param_0, broadcast) + broadcast = f32[16,16000] broadcast(reduce_0), dimensions={0} + ROOT multiply = f32[16,16000] multiply(param_0, broadcast) } ENTRY main { - param_0 = f32[16,40000] parameter(0) - ROOT triton_softmax = f32[16,40000] fusion(param_0), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}} + param_0 = f32[16,16000] parameter(0) + ROOT triton_softmax = f32[16,16000] fusion(param_0), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}} } )")); auto fusion_adaptor = HloFusionAdaptor::ForInstruction( module->entry_computation()->root_instruction()); - TF_ASSERT_OK_AND_ASSIGN( - auto tiling_result, - indexing_cost_model_.TryFindBestTilingForFusion(*fusion_adaptor)); - TF_ASSERT_OK_AND_ASSIGN(auto res1, indexing_cost_model_.EstimateRunTimeForTiledFusion( *fusion_adaptor, /*launch_dimensions=*/{16, 32}, - /*output_tile_sizes=*/{1, 40000})); - EXPECT_NEAR(absl::ToDoubleMicroseconds(res1.exec_time), 7, 1); + /*output_tile_sizes=*/{1, 16000})); + EXPECT_NEAR(absl::ToDoubleMicroseconds(res1.exec_time), 3, 1); TF_ASSERT_OK_AND_ASSIGN(auto res2, indexing_cost_model_.EstimateRunTimeForTiledFusion( *fusion_adaptor, /*launch_dimensions=*/{8, 32}, - /*output_tile_sizes=*/{2, 40000})); + /*output_tile_sizes=*/{2, 16000})); EXPECT_TRUE(res2.IsInfinite()); } @@ -515,10 +504,6 @@ ENTRY main { auto fusion_adaptor = HloFusionAdaptor::ForInstruction( module->entry_computation()->root_instruction()); - TF_ASSERT_OK_AND_ASSIGN( - auto tiling_result, - indexing_cost_model_.TryFindBestTilingForFusion(*fusion_adaptor)); - TF_ASSERT_OK_AND_ASSIGN( auto res, indexing_cost_model_.EstimateRunTimeForTiledFusion( *fusion_adaptor, /*launch_dimensions=*/{1, 2 * WarpSize()}, @@ -536,6 +521,97 @@ ENTRY main { EXPECT_EQ(res.flops, kPaddedOutputTileSize * kAddFlops); } +TEST_F( + GpuIndexingPerformanceModelTest, + EstimateRunTimeForTiledFusion_UncoalescedReadsAreScaledBasedOnWasteTransactionPercentage) { // NOLINT(whitespace/line_length) + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule m + +triton_softmax_computation { + param_0 = f32[2048,512] parameter(0) + param_1 = f32[2048,512] parameter(1) + ROOT add = f32[2048,512] add(param_0, param_1) +} + +ENTRY main { + param_0 = f32[2048,512] parameter(0) + param_1 = f32[2048,512] parameter(1) + ROOT triton_softmax = f32[2048,512] fusion(param_0, param_1), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}} +} +)")); + auto fusion_adaptor = HloFusionAdaptor::ForInstruction( + module->entry_computation()->root_instruction()); + + TF_ASSERT_OK_AND_ASSIGN( + auto res_coalesced, + indexing_cost_model_.EstimateRunTimeForTiledFusion( + *fusion_adaptor, /*launch_dimensions=*/{4096, 2 * WarpSize()}, + /*output_tile_sizes=*/{2, 128})); + + TF_ASSERT_OK_AND_ASSIGN( + auto res_uncoalesced, + indexing_cost_model_.EstimateRunTimeForTiledFusion( + *fusion_adaptor, /*launch_dimensions=*/{4096, 2 * WarpSize()}, + /*output_tile_sizes=*/{128, 2})); + + // The number of bytes read is the same for coalesced and uncoalesced reads. + constexpr int64_t kParamSizeBytes = 2048 * 512 * 4; + EXPECT_EQ(res_coalesced.bytes_read, 2 * kParamSizeBytes); + EXPECT_EQ(res_uncoalesced.bytes_read, 2 * kParamSizeBytes); + + // But we expect to waste 7/8th of read transaction time in the + // uncoalesced case, making the read time 8 times slower. + EXPECT_NEAR( + absl::FDivDuration(res_uncoalesced.read_time, res_coalesced.read_time), 8, + 0.001); +} + +TEST_F( + GpuIndexingPerformanceModelTest, + EstimateRunTimeForTiledFusion_UncoalescedWritesAreScaledBasedOnWasteTransactionPercentage) { // NOLINT(whitespace/line_length) + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule m + +add { + param_0 = s8[2048,512] parameter(0) + param_1 = s8[2048,512] parameter(1) + ROOT add = s8[2048,512] add(param_0, param_1) +} + +ENTRY main { + param_0 = s8[2048,512] parameter(0) + param_1 = s8[2048,512] parameter(1) + ROOT fusion = s8[2048,512] fusion(param_0, param_1), + kind=kCustom, calls=add, + backend_config={"fusion_backend_config": {"kind":"__triton"}} +})")); + auto fusion_adaptor = HloFusionAdaptor::ForInstruction( + module->entry_computation()->root_instruction()); + + TF_ASSERT_OK_AND_ASSIGN( + auto res_coalesced, + indexing_cost_model_.EstimateRunTimeForTiledFusion( + *fusion_adaptor, /*launch_dimensions=*/{512, WarpSize()}, + /*output_tile_sizes=*/{16, 128})); + + TF_ASSERT_OK_AND_ASSIGN( + auto res_uncoalesced, + indexing_cost_model_.EstimateRunTimeForTiledFusion( + *fusion_adaptor, /*launch_dimensions=*/{512, WarpSize()}, + /*output_tile_sizes=*/{128, 16})); + + // The number of bytes read is the same for coalesced and uncoalesced reads. + constexpr int64_t kParamSizeBytes = 2048 * 512; + EXPECT_EQ(res_coalesced.bytes_read, 2 * kParamSizeBytes); + EXPECT_EQ(res_uncoalesced.bytes_read, 2 * kParamSizeBytes); + + // But we expect to waste 3/4th of write transaction time in the + // uncoalesced case, making the write time 4 times slower. + EXPECT_NEAR( + absl::FDivDuration(res_uncoalesced.write_time, res_coalesced.write_time), + 4, 0.001); +} + TEST_F(GpuIndexingPerformanceModelTest, GetLaunchDimensionsForTiledFusion_IsSupported) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( @@ -573,8 +649,55 @@ ENTRY main { // Tile size is 9 * 9 * 9 = 729 that corresponds to 2 warps. But we estimate // the number of warps for padded tile that has size of 16 * 16 * 16 = 4096 - // and corresponds to 4 warps. - EXPECT_EQ(launch_dimensions.num_threads_per_block(), 4 * WarpSize()); + // and corresponds to 16 warps. + EXPECT_EQ(launch_dimensions.num_threads_per_block(), 16 * WarpSize()); +} + +TEST_F(GpuIndexingPerformanceModelTest, + NumberOfWarpsDependsOnLargestLiveTileSize) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule m + +add { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT add = f32[] add(param_0, param_1) +} + +fusion_computation { + param_0 = f32[1,4096] parameter(0) + c0 = f32[] constant(0) + ROOT reduce = f32[1] reduce(param_0, c0), dimensions={1}, to_apply=add +} + +ENTRY main { + param_0 = f32[1,4096] parameter(0) + ROOT fusion = f32[1] fusion(param_0), kind=kCustom, + calls=fusion_computation, + backend_config={"fusion_backend_config": {"kind":"__triton"}} +} +)")); + auto fusion_adaptor = HloFusionAdaptor::ForInstruction( + module->entry_computation()->root_instruction()); + + SymbolicTileAnalysisOrError analysis_or_error = + SymbolicTileAnalysis::AnalyzeFusion( + *fusion_adaptor, &mlir_context_, + /*emitter_specific_constraints_builder=*/nullptr); + ASSERT_TRUE(std::holds_alternative(analysis_or_error)); + + TF_ASSERT_OK_AND_ASSIGN( + TiledHloComputation tiled_hlo_computation, + std::get(analysis_or_error) + .ComputeTiledHloInstructions(/*tile_parameters=*/{1})); + + LaunchDimensions launch_dimensions = GpuPerformanceModelWithIndexingAnalysis:: + GetLaunchDimensionsForTiledFusion(tiled_hlo_computation); + EXPECT_EQ(launch_dimensions.num_blocks(), 1); + + // The largest tile size is 1 * 4096, for which our implementation recommends + // using 16 warps. + EXPECT_EQ(launch_dimensions.num_threads_per_block(), 16 * WarpSize()); } class FlopsPerElementTest : public GpuIndexingPerformanceModelTest { @@ -584,10 +707,7 @@ class FlopsPerElementTest : public GpuIndexingPerformanceModelTest { ParseAndReturnVerifiedModule(hlo_module_string)); GpuHloCostAnalysis cost_analysis( - GpuHloCostAnalysis::Options{ShapeSizeBytesFunction(), - /*per_second_rates=*/{}, - /*min_latencies_seconds=*/{}, - /*count_multiple_input_accesses=*/true}, + GpuHloCostAnalysis::Options{.count_multiple_input_accesses = true}, device_info_); ASSERT_IS_OK(module->entry_computation()->Accept(&cost_analysis)); diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc index 48cce140367be1..63a9bea3999288 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc @@ -37,6 +37,7 @@ limitations under the License. #include "xla/service/gpu/model/gpu_performance_model_base.h" #include "xla/stream_executor/device_description.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/status.h" namespace xla { @@ -82,12 +83,14 @@ GpuPerformanceModel::EstimateRunTimeForInstruction( bytes_read += n_bytes_total; bool coalesced = coalescing_analysis.IsReadCoalesced(operand); + PrimitiveType element_type = operand->shape().element_type(); VLogOperandRead(operand, n_bytes_total, n_bytes_net, coalesced); read_time += ReadTimeWithDRAMHeuristic( device_info, num_blocks, n_bytes_net, n_bytes_total, - operand->shape().element_type(), coalesced); + operand->shape().element_type(), + GetCoalescingUtilizationRate(element_type, device_info, coalesced)); } absl::Duration write_time = WriteTime(device_info, bytes_written); @@ -229,12 +232,14 @@ absl::Duration GpuPerformanceModel::EstimateUnfusedExecTime( bytes_read += n_bytes_total; bool coalesced = coalescing_analysis.IsReadCoalesced(operand); + PrimitiveType element_type = operand->shape().element_type(); VLogOperandRead(operand, n_bytes_total, n_bytes_net, coalesced); read_time += ReadTimeWithDRAMHeuristic( device_info, launch_dimensions.num_blocks(), n_bytes_net, n_bytes_total, - operand->shape().element_type(), coalesced); + operand->shape().element_type(), + GetCoalescingUtilizationRate(element_type, device_info, coalesced)); } auto exec_time = CombineComputeAndMemoryAccessTime( diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc index 46c98903ab293b..e906992df789ea 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc @@ -59,20 +59,6 @@ bool FusionUsesParameterElementwiseFromRoot( fusion->fused_expression_root()) == 1.f; } -int GetCoalescingWasteFactor(PrimitiveType element_type, - const se::DeviceDescription& gpu_device_info) { - int64_t element_size_bytes = - element_type == PrimitiveType::TUPLE || - element_type == PrimitiveType::TOKEN - ? 4 /* Dummy value. TODO(jreiffers): Model this case. */ - : ShapeUtil::ByteSizeOfPrimitiveType(element_type); - // Assume we use one element from the cache line and waste the remaining - // bandwidth. For example, if we're reading f32s, we use 1/16nd of the cache - // line. - return gpu_device_info.dram_to_l2_transaction_size_bytes() / - element_size_bytes; -} - // Limit the bandwidth for low occupancy cases. Each SM can issue at most // one 32B memory transaction per clock. H100 needs at least 56.8 active SMs // (1830 MHz) to saturate the memory bandwidth (3.35 TB/s). @@ -321,13 +307,11 @@ absl::Duration GpuPerformanceModelBase::ReadTime( absl::Duration GpuPerformanceModelBase::ReadTimeWithDRAMHeuristic( const se::DeviceDescription& gpu_device_info, int64_t num_blocks, int64_t n_bytes_net, int64_t n_bytes_total, PrimitiveType element_type, - bool coalesced) { - int waste_factor = - coalesced ? 1 : GetCoalescingWasteFactor(element_type, gpu_device_info); - + double hbm_bandwidth_utilization_rate) { // The first read of the input buffer always happens from DRAM. If reads are // no coaleced, bandwidth is reduced by the waste factor. - float dram_bandwidth = gpu_device_info.memory_bandwidth() / waste_factor; + float dram_bandwidth = + gpu_device_info.memory_bandwidth() * hbm_bandwidth_utilization_rate; // Two things can happed on re-reading the buffer: // - If the buffer fits into cache, the L1/L2 cache speedup is applied. @@ -341,7 +325,7 @@ absl::Duration GpuPerformanceModelBase::ReadTimeWithDRAMHeuristic( rest_bandwidth *= kL1CacheSpeedup; } } else { - rest_bandwidth /= waste_factor; + rest_bandwidth *= hbm_bandwidth_utilization_rate; } dram_bandwidth = AdjustBandwidth(gpu_device_info, dram_bandwidth, num_blocks); @@ -441,5 +425,21 @@ void GpuPerformanceModelBase::VLogOperandRead(const HloInstruction* operand, << ", n_bytes_net: " << n_bytes_net << ", coalesced: " << coalesced; } +double GetCoalescingUtilizationRate( + PrimitiveType element_type, const se::DeviceDescription& gpu_device_info, + bool coalesced) { + int64_t element_size_bytes = + element_type == PrimitiveType::TUPLE || + element_type == PrimitiveType::TOKEN + ? 4 /* Dummy value. TODO(jreiffers): Model this case. */ + : ShapeUtil::ByteSizeOfPrimitiveType(element_type); + // Assume we use one element from the cache line and waste the remaining + // bandwidth. For example, if we're reading f32s, we use 1/16nd of the cache + // line. + return coalesced ? 1.0 + : 1.0 * element_size_bytes / + gpu_device_info.dram_to_l2_transaction_size_bytes(); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.h b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.h index 3a034a4985bb62..3f5cf9aeaf33da 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.h +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.h @@ -157,9 +157,7 @@ struct GpuPerformanceModelOptions { } static GpuPerformanceModelOptions ForModule(const HloModule* module) { - return module->config().debug_options().xla_gpu_enable_priority_fusion() - ? PriorityFusion() // Only cache within priority fusion. - : Default(); + return PriorityFusion(); } }; @@ -226,11 +224,12 @@ class GpuPerformanceModelBase { // given GPU. // // Assumes that the first n_bytes_net are always read from DRAM, but next - // reads can be cached. Applies waste factor if read from DRAM is uncoalesced. + // reads can be cached. Restricts the effective HBM bandwidth using the + // utilization rate passed as a parameter to model not-fully-coalesced reads. static absl::Duration ReadTimeWithDRAMHeuristic( const se::DeviceDescription& gpu_device_info, int64_t num_blocks, int64_t n_bytes_net, int64_t n_bytes_total, PrimitiveType element_type, - bool coalesced); + double hbm_bandwidth_utilization_rate); // Tells input access time of the producer alone if fused_consumer // is not specified. Otherwise estimates the access time to producer's @@ -259,6 +258,17 @@ class GpuPerformanceModelBase { bool coalesced); }; +// Given an element type and whether the read is coalesced, returns the +// utilization rate of the HBM bandwidth. +// +// TODO(b/332714755): to avoid interfering with the cost model as it exists +// right now, this duplicates pre-existing logic and doesn't take into account +// how much of the memory access is actually useful and just assumes the worst +// possible utilization if the read is uncoalesced. +double GetCoalescingUtilizationRate( + PrimitiveType element_type, const se::DeviceDescription& gpu_device_info, + bool coalesced); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc index f96e9ee3767744..77c357d3cbdc69 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc @@ -25,8 +25,6 @@ limitations under the License. #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" -#include "xla/shape.h" -#include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" @@ -38,17 +36,7 @@ namespace { class GpuPerformanceModelBaseTest : public HloTestBase { public: - GpuHloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const { - return [&](const Shape& shape) { - constexpr int64_t kPointerSize = 8; - return ShapeUtil::ByteSizeOf(shape, kPointerSize); - }; - } - - GpuHloCostAnalysis::Options options_{ShapeSizeBytesFunction(), - /*per_second_rates=*/{}, - /*min_latencies_seconds=*/{}, - /*count_multiple_input_accesses=*/true}; + GpuHloCostAnalysis::Options options_{.count_multiple_input_accesses = true}; // The reference times in the test cases below are measured // on A6000 by profiling the execution of the HLOs. se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()}; diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc index dc80d18a4a3a37..7baaa33a438d5c 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc @@ -35,6 +35,7 @@ limitations under the License. #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/gpu/model/gpu_indexing_performance_model.h" #include "xla/service/gpu/model/gpu_performance_model_base.h" +#include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -50,13 +51,6 @@ namespace gpu { namespace { class GpuPerformanceModelTest : public HloTestBase { - GpuHloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const { - return [&](const Shape& shape) { - constexpr int64_t kPointerSize = 8; - return ShapeUtil::ByteSizeOf(shape, kPointerSize); - }; - } - public: GpuPerformanceModel::RunTimes EstimateRunTimesDefault( const HloInstruction* producer, @@ -85,10 +79,7 @@ class GpuPerformanceModelTest : public HloTestBase { } mlir::MLIRContext mlir_context_; - GpuHloCostAnalysis::Options options_{ShapeSizeBytesFunction(), - /*per_second_rates=*/{}, - /*min_latencies_seconds=*/{}, - /*count_multiple_input_accesses=*/true}; + GpuHloCostAnalysis::Options options_{.count_multiple_input_accesses = true}; // The reference times in the test cases below are measured // on A6000 by profiling the execution of the HLOs. se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()}; @@ -97,7 +88,7 @@ class GpuPerformanceModelTest : public HloTestBase { GpuPerformanceModelCache gpu_performance_model_cache_; GpuPerformanceModelWithIndexingAnalysis indexing_cost_model_{ - &device_info_, &fusion_analysis_cache_, ShapeSizeBytesFunction(), + &device_info_, &fusion_analysis_cache_, HloCostAnalysis::DefaultShapeSize, &mlir_context_}; GpuPerformanceModelTest() : HloTestBase() {} diff --git a/third_party/xla/xla/service/gpu/model/hlo_op_profiler.cc b/third_party/xla/xla/service/gpu/model/hlo_op_profiler.cc index 149bd262cc4889..5c733adac314d1 100644 --- a/third_party/xla/xla/service/gpu/model/hlo_op_profiler.cc +++ b/third_party/xla/xla/service/gpu/model/hlo_op_profiler.cc @@ -36,6 +36,7 @@ limitations under the License. #include "xla/service/gpu/model/hlo_op_profile.pb.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_runner.h" +#include "xla/service/hlo_verifier.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" @@ -158,6 +159,9 @@ absl::StatusOr HloOpProfiler::MeasureOpChainDuration( std::unique_ptr module = MakeModuleForMeasurements(op, data_type, chain_length); + HloVerifier verifier(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/false); + TF_RETURN_IF_ERROR(verifier.Run(&*module).status()); std::minstd_rand0 engine; // Some operations have dynamic duration that depends on the input values. diff --git a/third_party/xla/xla/service/gpu/model/hlo_op_profiler_run.cc b/third_party/xla/xla/service/gpu/model/hlo_op_profiler_run.cc index 8205ec3b478af4..b12d35961f2bc4 100644 --- a/third_party/xla/xla/service/gpu/model/hlo_op_profiler_run.cc +++ b/third_party/xla/xla/service/gpu/model/hlo_op_profiler_run.cc @@ -49,11 +49,11 @@ void WriteOutput(const DeviceHloInstructionProfiles& literal, std::string file_name; std::string output_directory; if (tsl::io::GetTestUndeclaredOutputsDir(&output_directory)) { - std::string filename = tsl::io::JoinPath( + file_name = tsl::io::JoinPath( output_directory, absl::StrFormat("profiles-%d-%s", tsl::Env::Default()->NowMicros(), name)); - file_name = absl::StrCat(filename, ".textproto"); + absl::StrAppend(&file_name, ".textproto"); } else { file_name = tsl::io::GetTempFilename(absl::StrCat(name, ".textproto")); } @@ -123,11 +123,11 @@ int RunProfiler(int argc, char** argv) { } } - VLOG(1) << "\n" << instr_profiles; + VLOG(1) << "\n" << instr_profiles.DebugString(); - auto profile_name = HloOpProfiles::GetProfileName(&dev_info); DeviceHloInstructionProfiles device_profiles; - device_profiles.mutable_entries()->insert({profile_name, instr_profiles}); + device_profiles.mutable_entries()->insert( + {HloOpProfiles::GetProfileName(dev_info), instr_profiles}); if (!output_file.empty()) { WriteOutput(device_profiles, output_file); } diff --git a/third_party/xla/xla/service/gpu/model/hlo_op_profiler_test.cc b/third_party/xla/xla/service/gpu/model/hlo_op_profiler_test.cc index 95dfc7c19f5ad4..f8e102a05e58da 100644 --- a/third_party/xla/xla/service/gpu/model/hlo_op_profiler_test.cc +++ b/third_party/xla/xla/service/gpu/model/hlo_op_profiler_test.cc @@ -19,18 +19,23 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status_matchers.h" namespace xla { namespace gpu { namespace { -using HloOpProfilerTest = HloTestBase; - -TEST_F(HloOpProfilerTest, BasicMeasurementsAreCorrect) { +class HloOpProfilerTest : public HloTestBase { + void SetUp() override { #ifndef GOOGLE_CUDA - GTEST_SKIP() << "Not built with --config=cuda"; + GTEST_SKIP() << "Not built with --config=cuda"; #endif - HloOpProfiler profiler(test_runner_); + } +}; + +TEST_F(HloOpProfilerTest, BasicMeasurementsAreCorrect) { + HloOpProfiler profiler(test_runner_as_hlo_runner()); // f32 is fast but measurable. EXPECT_GT(profiler.MeasureClockCyclesPerOp(HloOpcode::kAdd, F32) .value() @@ -48,6 +53,12 @@ TEST_F(HloOpProfilerTest, BasicMeasurementsAreCorrect) { 1000); } +TEST_F(HloOpProfilerTest, UnsupportedCombinationsDoNotCrash) { + HloOpProfiler profiler(test_runner_as_hlo_runner()); + EXPECT_THAT(profiler.MeasureClockCyclesPerOp(HloOpcode::kCbrt, S8), + tsl::testing::StatusIs(tsl::error::INVALID_ARGUMENT)); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis.cc b/third_party/xla/xla/service/gpu/model/indexing_analysis.cc index 18ed0d526862a7..e0a433211e7e90 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis.cc @@ -44,12 +44,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/transforms/simplifiers/gather_simplifier.h" #include "xla/layout.h" #include "xla/permutation_util.h" -#include "xla/service/gather_simplifier.h" #include "xla/service/gpu/hlo_traversal.h" -#include "xla/service/gpu/matmul_utils.h" -#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/service/gpu/matmul_indexing_utils.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -75,6 +74,206 @@ HloInstructionIndexing CreateUnknownIndexing(int64_t count = 1) { return indexing; } +struct HLORTVar { + Interval feasible_values; + const HloInstruction* hlo; + // This is a map from the iteration space of the corresponding indexing map to + // the iteration space of `hlo`. It shows what element of `hlo` we need to + // extract to get the runtime value for the RTVar. + mlir::AffineMap map; +}; + +bool operator==(const HLORTVar& lhs, const HLORTVar& rhs) { + return lhs.feasible_values == rhs.feasible_values && lhs.hlo == rhs.hlo && + lhs.map == rhs.map; +} + +inline bool operator!=(const HLORTVar& lhs, const HLORTVar& rhs) { + return !(lhs == rhs); +} + +// The return type of `OptimizeRTVar` below +struct RTVarOptimizationResult { + // An affine expr which maps the old RTVar to the new, optimized RTVar: + // `()[sk] -> s'k` (with k being `symbol_index` in the `OptimizeRTVar` call). + // If `expr` doesn't depend on `sk` it means the RTVar could be optimized + // away completely and the value of `rt_var` can be ignored. + AffineExpr remapped_symbol; + + // The new, optimized RTVar + HLORTVar rt_var; +}; + +// Tries to optimize the given RTVar by removing some parts (or entirety) of +// the dependent HLO graph: +// +// 1. If no optimization is possible it returns `{sk, rt_var}` - the +// identity expr and the unchanged rt_var. +// +// 2. If full optimization is possible, it returns +// `{const, rt_var}` - an affine expr that does not anymore depend +// on `sk` and an arbitrary rt_var. +// +// 3. if partial optimization is possible, it returns +// `{()[sk] -> f(sk), rt_var_new }` - an affine expression that maps from the +// old RTVar to the new RTVar, and the new RTVar itself. The new RTVar now +// references some HLO subgraph of the old RTVar's HLO. +RTVarOptimizationResult OptimizeRTVar(HLORTVar rt_var, int64_t symbol_index, + MLIRContext* mlir_context) { + const auto symbol = getAffineSymbolExpr(symbol_index, mlir_context); + auto result_expr = symbol; + + while (true) { + if (auto constant_expr = DynCast(rt_var.hlo)) { + if (rt_var.map.isConstant()) { + const auto idx = rt_var.map.getConstantResults(); + result_expr = result_expr.replace( + symbol, getAffineConstantExpr( + constant_expr->literal().GetIntegralAsS64(idx).value(), + mlir_context)); + } + return {result_expr, rt_var}; + } + + if (auto iota_expr = DynCast(rt_var.hlo)) { + auto iota_dimension = iota_expr->iota_dimension(); + CHECK(iota_dimension < rt_var.map.getNumResults()); + return { + result_expr.replace(symbol, rt_var.map.getResults()[iota_dimension]), + rt_var}; + } + + auto is_indexing_transformation = [](const HloInstruction* instr) { + return instr->opcode() == HloOpcode::kBitcast || + instr->opcode() == HloOpcode::kBroadcast || + instr->opcode() == HloOpcode::kReshape || + instr->opcode() == HloOpcode::kReverse || + instr->opcode() == HloOpcode::kSlice || + instr->opcode() == HloOpcode::kTranspose; + }; + + if (is_indexing_transformation(rt_var.hlo)) { + auto instr_indexing_map = + *ComputeOutputToInputIndexing(rt_var.hlo, 0, mlir_context) + .indexing_maps[0] + .begin(); + + rt_var.hlo = rt_var.hlo->operand(0); + rt_var.map = instr_indexing_map.GetAffineMap().compose(rt_var.map); + continue; + } + + if (rt_var.hlo->opcode() == HloOpcode::kNegate) { + rt_var.hlo = rt_var.hlo->operand(0); + result_expr = result_expr.replace(symbol, -symbol); + continue; + } + + if (rt_var.hlo->opcode() == HloOpcode::kAdd || + rt_var.hlo->opcode() == HloOpcode::kSubtract || + rt_var.hlo->opcode() == HloOpcode::kMultiply || + rt_var.hlo->opcode() == HloOpcode::kDivide) { + const auto apply_op = [&](const AffineExpr& lhs, + const AffineExpr& rhs) -> AffineExpr { + switch (rt_var.hlo->opcode()) { + case HloOpcode::kAdd: + return lhs + rhs; + case HloOpcode::kSubtract: + return lhs - rhs; + case HloOpcode::kMultiply: + return lhs * rhs; + case HloOpcode::kDivide: + return lhs.floorDiv(rhs); + default: + ABSL_UNREACHABLE(); + } + }; + + auto lhs = OptimizeRTVar( + HLORTVar{rt_var.feasible_values, rt_var.hlo->operand(0), rt_var.map}, + symbol_index, mlir_context); + + if (!lhs.remapped_symbol.isFunctionOfSymbol(symbol_index)) { + // This means that lhs is constant-like and we can eliminate the + // operand. + result_expr = + result_expr.replace(symbol, apply_op(lhs.remapped_symbol, symbol)); + + // We continue optimizing the `rhs` operand + rt_var.hlo = rt_var.hlo->operand(1); + continue; + } + + auto rhs = OptimizeRTVar( + HLORTVar{rt_var.feasible_values, rt_var.hlo->operand(1), rt_var.map}, + symbol_index, mlir_context); + + if (!rhs.remapped_symbol.isFunctionOfSymbol(symbol_index)) { + // This means that rhs is constant-like and we can eliminate the + // operand. + result_expr = + result_expr.replace(symbol, apply_op(symbol, rhs.remapped_symbol)); + + // We can also take advantage of the optimization already done for lhs: + result_expr = result_expr.replace(symbol, lhs.remapped_symbol); + rt_var = lhs.rt_var; + continue; + } + } + + return {result_expr, rt_var}; + } +} + +std::vector ConvertHLORTVarsToRTVars( + const std::vector& hlo_rt_vars) { + std::vector rt_vars; + rt_vars.reserve(hlo_rt_vars.size()); + for (const HLORTVar& hlo_rt_var : hlo_rt_vars) { + rt_vars.push_back(IndexingMap::Variable{hlo_rt_var.feasible_values}); + } + return rt_vars; +} + +IndexingMap FoldRTVarsAndConstructIndexingMap( + AffineMap affine_map, std::vector dim_vars, + std::vector hlo_rt_vars) { + if (hlo_rt_vars.empty()) { + return IndexingMap(affine_map, std::move(dim_vars), /*range_vars=*/{}, + ConvertHLORTVarsToRTVars(hlo_rt_vars)); + } + + auto* ctx = affine_map.getContext(); + + for (auto symbol_index = 0; symbol_index < hlo_rt_vars.size(); + ++symbol_index) { + auto& rt_var = hlo_rt_vars[symbol_index]; + + // range_vars and rt_vars share the symbol space, with the rt_vars coming + // after the range_vars. + auto rt_var_symbol = getAffineSymbolExpr(symbol_index, ctx); + + RTVarOptimizationResult result = OptimizeRTVar(rt_var, symbol_index, ctx); + + if (result.remapped_symbol != rt_var_symbol) { + affine_map = affine_map.replace({{rt_var_symbol, result.remapped_symbol}}, + affine_map.getNumDims(), + affine_map.getNumSymbols()); + + llvm::DenseMap replacements; + } + + if (result.remapped_symbol.isFunctionOfSymbol(symbol_index)) { + // If we still depend on the rt_var, then we update it. + if (rt_var != result.rt_var) { + rt_var = std::move(result.rt_var); + } + } + } + return IndexingMap(affine_map, std::move(dim_vars), /*range_vars=*/{}, + ConvertHLORTVarsToRTVars(hlo_rt_vars)); +} + HloInstructionIndexing ComputeOutputToInputCwiseOpIndexing( const HloInstruction* instr, MLIRContext* mlir_context) { IndexingMap identity_map = CreateIdentityMap(instr->shape(), mlir_context); @@ -160,7 +359,8 @@ HloInstructionIndexing ComputeOutputToInputConcatenateOpIndexing( // be adjusted for a particular operand_id. mlir::MutableAffineMap affine_map = AffineMap::getMultiDimIdentityMap(operand_0_dims.size(), mlir_context); - std::vector dim_vars = DimVarsFromTensorSizes(operand_0_dims); + std::vector dim_vars = + DimVarsFromTensorSizes(operand_0_dims); HloInstructionIndexing concat_indexing; concat_indexing.indexing_maps.resize(concat->operand_count()); @@ -170,7 +370,8 @@ HloInstructionIndexing ComputeOutputToInputConcatenateOpIndexing( for (const auto [operand_id, operand] : llvm::enumerate(concat->operands())) { affine_map.setResult(concat_dim, concat_dim_expr - offset); int64_t operand_concat_dim = operand->shape().dimensions()[concat_dim]; - dim_vars[concat_dim] = DimVar{{offset, offset + operand_concat_dim - 1}}; + dim_vars[concat_dim] = + IndexingMap::Variable{{offset, offset + operand_concat_dim - 1}}; concat_indexing.indexing_maps[operand_id].insert( IndexingMap(affine_map.getAffineMap(), dim_vars, /*range_vars=*/{}, /*rt_vars=*/{})); @@ -313,7 +514,7 @@ HloInstructionIndexing ComputeOutputToInputDynamicSliceOpIndexing( IndexingMap start_indices_map = IndexingMap::FromTensorSizes( empty_results_affine_map, output_shape.dimensions(), {}); - std::vector offsets_rt_vars; + std::vector offsets_rt_vars; offsets_rt_vars.reserve(rank); std::vector exprs; exprs.reserve(rank); @@ -322,17 +523,16 @@ HloInstructionIndexing ComputeOutputToInputDynamicSliceOpIndexing( exprs.push_back(getAffineDimExpr(dim, mlir_context) + getAffineSymbolExpr(dim, mlir_context)); offsets_rt_vars.push_back( - RTVar{Interval{0, input_shape.dimensions(dim) - slice_size}, - dynamic_slice->operand(dim + first_index_num), - empty_results_affine_map}); + HLORTVar{Interval{0, input_shape.dimensions(dim) - slice_size}, + dynamic_slice->operand(dim + first_index_num), + empty_results_affine_map}); } std::vector indexing_maps(dynamic_slice->operand_count(), start_indices_map); - indexing_maps.front() = - IndexingMap{AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/rank, exprs, - mlir_context), - start_indices_map.GetDimVars(), /*range_vars=*/{}, - std::move(offsets_rt_vars)}; + indexing_maps.front() = FoldRTVarsAndConstructIndexingMap( + AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/rank, exprs, + mlir_context), + start_indices_map.GetDimVars(), std::move(offsets_rt_vars)); return HloInstructionIndexing::FromIndexingMaps(indexing_maps); } @@ -362,19 +562,19 @@ HloInstructionIndexing ComputeOutputToInputDynamicUpdateSliceOpIndexing( // update: (d_0 - s_0, ..., d_{N-1} - s_{N-1}) std::vector exprs; exprs.reserve(rank); - std::vector rt_vars; + std::vector rt_vars; rt_vars.reserve(rank); for (auto [dim, slice_size] : llvm::enumerate(update_shape.dimensions())) { exprs.push_back(getAffineDimExpr(dim, mlir_context) - getAffineSymbolExpr(dim, mlir_context)); Interval feasible_values{0, output_shape.dimensions(dim) - slice_size}; - rt_vars.push_back(RTVar{feasible_values, dus->operand(2 + dim), - empty_results_affine_map}); + rt_vars.push_back(HLORTVar{feasible_values, dus->operand(2 + dim), + empty_results_affine_map}); } - IndexingMap update_map{AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/rank, - /*results=*/exprs, mlir_context), - operand_map.GetDimVars(), - /*range_vars=*/{}, rt_vars}; + IndexingMap update_map = FoldRTVarsAndConstructIndexingMap( + AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/rank, + /*results=*/exprs, mlir_context), + operand_map.GetDimVars(), std::move(rt_vars)); std::vector indexing_maps(dus->operand_count(), start_indices_map); @@ -402,20 +602,20 @@ HloInstructionIndexing ComputeOutputToInputGatherOpIndexing( // (d_0, ... d_{rank - 1}) -> (d_0, s_0), // where 0 <= s_0 <= indices_shape[1] - 1. AffineExpr indices_id_dim = getAffineDimExpr(0, mlir_context); - std::vector dim_vars = + std::vector dim_vars = DimVarsFromTensorSizes(output_shape.dimensions()); IndexingMap indices_map{ AffineMap::get(output_rank, 1, {indices_id_dim, getAffineSymbolExpr(0, mlir_context)}, mlir_context), dim_vars, - {RangeVar{{0, index_vector_length - 1}}}, + {IndexingMap::Variable{{0, index_vector_length - 1}}}, /*rt_vars=*/{}}; // A map for the `operand` operand of gather, from which we extract slices. // (d_0, ... d_{rank - 1}) -> (d_1 + s0, d_2 + s_1, ...), // where s_i are RTVars that extract indices from the `indices` operand. - std::vector rt_vars; + std::vector rt_vars; std::vector exprs; exprs.reserve(operand_shape.rank()); for (auto [operand_dim_id, slice_size] : @@ -425,7 +625,7 @@ HloInstructionIndexing ComputeOutputToInputGatherOpIndexing( if (operand_dim_id >= index_vector_length) continue; - rt_vars.push_back(RTVar{ + rt_vars.push_back(HLORTVar{ Interval{0, operand_shape.dimensions(operand_dim_id) - slice_size}, gather->operand(1), AffineMap::get(output_rank, /*symbolCount=*/0, @@ -435,10 +635,10 @@ HloInstructionIndexing ComputeOutputToInputGatherOpIndexing( exprs.back() = exprs.back() + getAffineSymbolExpr(operand_dim_id, mlir_context); } - IndexingMap operand_map = { + IndexingMap operand_map = FoldRTVarsAndConstructIndexingMap( AffineMap::get(/*dimCount=*/output_rank, /*symbolCount=*/index_vector_length, exprs, mlir_context), - std::move(dim_vars), /*range_vars=*/{}, std::move(rt_vars)}; + std::move(dim_vars), std::move(rt_vars)); return HloInstructionIndexing::FromIndexingMaps({operand_map, indices_map}); } @@ -451,16 +651,16 @@ IndexingMap ComputeOutputToInputPadOpIndexingImpl( std::vector exprs; std::vector> constraints; - std::vector dim_vars; + std::vector dim_vars; exprs.reserve(output_rank); constraints.reserve(output_rank); int64_t output_dim_id = 0; for (const auto [output_dim, pad_low, pad_high, pad_interior] : llvm::zip(output_dims, padding_low, padding_high, padding_interior)) { AffineExpr dim_expr = getAffineDimExpr(output_dim_id, mlir_context); - dim_vars.push_back( - {Interval{std::max(int64_t{0}, pad_low), - std::min(output_dim - 1, output_dim - 1 - pad_high)}}); + dim_vars.push_back({IndexingMap::Variable{ + std::max(int64_t{0}, pad_low), + std::min(output_dim - 1, output_dim - 1 - pad_high)}}); if (pad_interior == 0) { exprs.push_back(dim_expr - pad_low); } else { @@ -529,7 +729,7 @@ HloInstructionIndexing ComputeOutputToInputReduceOpIndexing( output_shape.dimensions(), parallel_dims_sizes); IndexingMap inits_indexing_map = IndexingMap::FromTensorSizes( AffineMap::get(output_shape.rank(), /*symbolCount=*/0, {}, mlir_context), - output_shape.dimensions(), {}, /*is_simplified=*/true); + output_shape.dimensions(), {}); HloInstructionIndexing instr_indexing; instr_indexing.indexing_maps.resize(reduce->operand_count()); @@ -604,8 +804,8 @@ IndexingMap ComposeIndexingMapsForWindow( padding_interior.reserve(rank); padded_input_dimensions.reserve(rank); SmallVector exprs; - std::vector dim_vars; - std::vector range_vars; + std::vector dim_vars; + std::vector range_vars; exprs.reserve(rank); dim_vars.reserve(rank); range_vars.reserve(rank); @@ -625,8 +825,9 @@ IndexingMap ComposeIndexingMapsForWindow( exprs.push_back(symbol_expr * window_config.window_dilation() + window_config.stride() * dim_expr); - dim_vars.push_back({Interval{0, output_dimensions[dim_id] - 1}}); - range_vars.push_back({Interval{0, window_config.size() - 1}}); + dim_vars.push_back( + {IndexingMap::Variable{0, output_dimensions[dim_id] - 1}}); + range_vars.push_back({IndexingMap::Variable{0, window_config.size() - 1}}); } // Indexing map for pad op that pads the input. IndexingMap padded_input_indexing = ComputeOutputToInputPadOpIndexingImpl( @@ -662,8 +863,7 @@ HloInstructionIndexing ComputeOutputToInputReduceWindowOpIndexing( // Indexing map for the init value. IndexingMap inits_indexing_map = IndexingMap::FromTensorSizes( AffineMap::get(output_shape.rank(), /*symbolCount=*/0, {}, mlir_context), - output_shape.dimensions(), /*symbol_upper_bounds=*/{}, - /*is_simplified=*/true); + output_shape.dimensions(), /*symbol_upper_bounds=*/{}); HloInstructionIndexing instr_indexing; instr_indexing.indexing_maps.resize(reduce_window->operand_count()); @@ -735,8 +935,9 @@ HloInstructionIndexing ComputeOutputToInputConvolutionOpIndexing( kernel_exprs[dnums.kernel_output_feature_dimension()] = dim_expr; // Build initial symbol ranges. - std::vector input_symbols = input_spatial_indexing.GetRangeVars(); - std::vector kernel_symbols = + std::vector input_symbols = + input_spatial_indexing.GetRangeVars(); + std::vector kernel_symbols = RangeVarsFromTensorSizes(kernel_spatial_sizes); // Add symbol for input feature dimension. @@ -748,8 +949,8 @@ HloInstructionIndexing ComputeOutputToInputConvolutionOpIndexing( int64_t input_group_size = kernel_shape.dimensions(dnums.kernel_input_feature_dimension()); Interval input_feature_range{0, input_group_size - 1}; - input_symbols.push_back({input_feature_range}); - kernel_symbols.push_back({input_feature_range}); + input_symbols.push_back(IndexingMap::Variable{input_feature_range}); + kernel_symbols.push_back(IndexingMap::Variable{input_feature_range}); // With multiple feature groups, the input feature dimension is equally split. if (convolution->feature_group_count() > 1) { @@ -770,7 +971,8 @@ HloInstructionIndexing ComputeOutputToInputConvolutionOpIndexing( output_shape.dimensions(dnums.output_batch_dimension()); AffineExpr batch_group_expr = getAffineSymbolExpr(input_symbols.size(), mlir_context); - input_symbols.push_back({{0, convolution->batch_group_count() - 1}}); + input_symbols.push_back( + IndexingMap::Variable{{0, convolution->batch_group_count() - 1}}); input_exprs[dnums.input_batch_dimension()] = batch_group_expr * batch_group_size + batch_dim_expr; } else { @@ -1151,18 +1353,20 @@ std::vector ToTransposeDimensions(const Layout& l) { } // namespace +IndexingMap CreateIdentityMap(absl::Span dimensions, + mlir::MLIRContext* mlir_context) { + return IndexingMap::FromTensorSizes( + AffineMap::getMultiDimIdentityMap(dimensions.size(), mlir_context), + /*dim_upper_bounds=*/dimensions, /*symbol_upper_bounds=*/{}); +} + IndexingMap CreateIdentityMap(const Shape& shape, MLIRContext* mlir_context) { if (shape.IsTuple()) { // Should happen only for variadic reduce. In that case all tuple shapes are // equal. return CreateIdentityMap(shape.tuple_shapes(0), mlir_context); } - - auto dimensions = shape.dimensions(); - IndexingMap identity_map = IndexingMap::FromTensorSizes( - AffineMap::getMultiDimIdentityMap(dimensions.size(), mlir_context), - dimensions, {}, /*is_simplified=*/dimensions.empty()); - return identity_map; + return CreateIdentityMap(shape.dimensions(), mlir_context); } llvm::SmallVector DelinearizeInBoundsIndex( @@ -1254,33 +1458,25 @@ HloInstructionIndexing HloInstructionIndexing::FromIndexingMaps( return instr_indexing; } -std::string HloInstructionIndexing::ToString( - const AffineMapPrinter& printer) const { - std::string s; - std::stringstream ss(s); - Print(ss, printer); +std::string HloInstructionIndexing::ToString() const { + std::stringstream ss; + ss << *this; return ss.str(); } -void HloInstructionIndexing::Print(std::ostream& out, - const AffineMapPrinter& printer) const { +std::ostream& operator<<(std::ostream& out, + const HloInstructionIndexing& instr_indexing) { for (const auto& [operand_id, indexing_maps] : - llvm::enumerate(indexing_maps)) { + llvm::enumerate(instr_indexing.indexing_maps)) { out << "operand id = " << operand_id << ' '; for (const auto& indexing_map : indexing_maps) { if (indexing_map.IsUndefined()) { out << "unknown indexing"; continue; } - indexing_map.Print(out, printer); + out << indexing_map; } } -} - -std::ostream& operator<<(std::ostream& out, - const HloInstructionIndexing& instr_indexing) { - AffineMapPrinter printer; - instr_indexing.Print(out, printer); return out; } @@ -1507,13 +1703,5 @@ IndexingMap ComputeEpilogueInputToOutputIndexing( return root_indexing; } -IndexingMap GetIndexingMapForInstruction(const HloInstruction* instr, - int64_t operand_idx, - mlir::MLIRContext* mlir_context) { - HloInstructionIndexing indexing = - ComputeOutputToInputIndexing(instr, operand_idx, mlir_context); - return *indexing.indexing_maps[0].begin(); -} - } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis.h b/third_party/xla/xla/service/gpu/model/indexing_analysis.h index 965b060da30be7..e05a598cc2a322 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis.h +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis.h @@ -31,7 +31,6 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/hlo_traversal.h" -#include "xla/service/gpu/model/affine_map_printer.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/shape.h" @@ -43,9 +42,7 @@ using IndexingMapSet = absl::flat_hash_set; // Contains indexing maps for all N-dimensional tensor input operands that // correspond to a particular output. struct HloInstructionIndexing { - std::string ToString( - const AffineMapPrinter& printer = AffineMapPrinter()) const; - void Print(std::ostream& out, const AffineMapPrinter& printer) const; + std::string ToString() const; // Returns true if the indexing was simplified. bool Simplify(); @@ -163,6 +160,8 @@ std::vector DelinearizeIndex(absl::Span dims, // Creates an identity indexing map corresponding to the parameter shape. IndexingMap CreateIdentityMap(const Shape& shape, mlir::MLIRContext* mlir_context); +IndexingMap CreateIdentityMap(absl::Span dimensions, + mlir::MLIRContext* mlir_context); llvm::SmallVector DelinearizeInBoundsIndex( mlir::AffineExpr linear, absl::Span sizes); diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc index 440c06f44bb19b..cb4aefc17e6423 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "tsl/platform/test.h" @@ -63,15 +64,13 @@ TEST_F(IndexingAnalysisTest, FuseProducerConsumerOutputToInputIndexing) { (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: false + d1 in [0, 999] )"))), Pair(transpose, ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: false + d1 in [0, 999] )"))))); } @@ -97,29 +96,25 @@ TEST_F(IndexingAnalysisTest, ComputeGroupedOutputToInputIndexing) { (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: false + d1 in [0, 999] )"))), Pair(transpose, ElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: true + d1 in [0, 999] )"))), Pair(parameter, UnorderedElementsAre(MatchIndexingMap(R"( (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: true + d1 in [0, 999] )"), MatchIndexingMap(R"( (d0, d1) -> (d1, d0), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: true + d1 in [0, 999] )"))))); } @@ -158,34 +153,29 @@ TEST_F(IndexingAnalysisTest, Pair(root, ElementsAre(MatchIndexingMap(R"( (d0) -> (d0), domain: - d0 in [0, 31], - is_simplified: false + d0 in [0, 31] )"))), Pair(root->operand(0), ElementsAre(MatchIndexingMap(R"( (d0)[s0] -> (d0, s0), domain: d0 in [0, 31], - s0 in [0, 39], - is_simplified: true + s0 in [0, 39] )"))), Pair(root->operand(1), ElementsAre(MatchIndexingMap(R"( (d0)[s0] -> (d0, s0), domain: d0 in [0, 31], - s0 in [0, 39], - is_simplified: true + s0 in [0, 39] )"))), Pair(root->operand(2), ElementsAre(MatchIndexingMap(R"( (d0) -> (), domain: - d0 in [0, 31], - is_simplified: true + d0 in [0, 31] )"))), Pair(root->operand(3), ElementsAre(MatchIndexingMap(R"( (d0) -> (), domain: - d0 in [0, 31], - is_simplified: true + d0 in [0, 31] )"))))); } @@ -215,8 +205,7 @@ TEST_F(IndexingAnalysisTest, ComputeGroupedOutputToInputIndexing_SingleOp) { (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: false + d1 in [0, 999] )"))))); } @@ -260,8 +249,7 @@ TEST_F(IndexingAnalysisTest, d0 in [0, 14], d1 in [0, 31], d2 in [0, 19], - d3 in [0, 63], - is_simplified: false + d3 in [0, 63] )"))), Pair(¶meter_0.instruction(), ElementsAre(MatchIndexingMap(R"( (d0, d1, d2, d3) -> (d0, d2), @@ -269,8 +257,7 @@ TEST_F(IndexingAnalysisTest, d0 in [0, 14], d1 in [0, 31], d2 in [0, 19], - d3 in [0, 63], - is_simplified: true + d3 in [0, 63] )"))))); } @@ -290,8 +277,7 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestOutputPermutation) { domain: d0 in [0, 29], d1 in [0, 9], - d2 in [0, 19], - is_simplified: false + d2 in [0, 19] )")); auto output_indexing = GetInputToOutputIndexing(root, /*input_id=*/0, @@ -302,8 +288,7 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestOutputPermutation) { domain: d0 in [0, 9], d1 in [0, 19], - d2 in [0, 29], - is_simplified: false + d2 in [0, 29] )")); } @@ -366,8 +351,7 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputPermutation) { domain: d0 in [0, 9], d1 in [0, 19], - d2 in [0, 29], - is_simplified: false + d2 in [0, 29] )")); auto output_indexing = GetInputToOutputIndexing(root, /*input_id=*/0, @@ -378,8 +362,7 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputPermutation) { domain: d0 in [0, 29], d1 in [0, 9], - d2 in [0, 19], - is_simplified: false + d2 in [0, 19] )")); } @@ -399,8 +382,7 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputAndOutputPermutation) { domain: d0 in [0, 29], d1 in [0, 9], - d2 in [0, 19], - is_simplified: false + d2 in [0, 19] )")); auto output_indexing = GetInputToOutputIndexing(root, /*input_id=*/0, @@ -411,8 +393,7 @@ TEST_F(IndexingAnalysisTest, PhysicalLayoutTestInputAndOutputPermutation) { domain: d0 in [0, 29], d1 in [0, 9], - d2 in [0, 19], - is_simplified: false + d2 in [0, 19] )")); } @@ -431,14 +412,12 @@ TEST_F(IndexingAnalysisTest, ElementwiseOp) { (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 19], - is_simplified: false + d1 in [0, 19] operand id = 1 (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 19], - is_simplified: false + d1 in [0, 19] )")); auto output_indexing_0 = GetInputToOutputIndexing(root, /*input_id=*/0); @@ -447,8 +426,7 @@ TEST_F(IndexingAnalysisTest, ElementwiseOp) { (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 19], - is_simplified: false + d1 in [0, 19] )")); auto output_indexing_1 = GetInputToOutputIndexing(root, /*input_id=*/1); @@ -457,8 +435,7 @@ TEST_F(IndexingAnalysisTest, ElementwiseOp) { (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 19], - is_simplified: false + d1 in [0, 19] )")); } @@ -482,14 +459,12 @@ TEST_F(IndexingAnalysisTest, Map) { (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 19], - is_simplified: false + d1 in [0, 19] operand id = 1 (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 19], - is_simplified: false + d1 in [0, 19] )")); auto output_indexing_0 = GetInputToOutputIndexing(root, /*input_id=*/0); @@ -498,8 +473,7 @@ TEST_F(IndexingAnalysisTest, Map) { (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 19], - is_simplified: false + d1 in [0, 19] )")); auto output_indexing_1 = GetInputToOutputIndexing(root, /*input_id=*/1); @@ -508,8 +482,7 @@ TEST_F(IndexingAnalysisTest, Map) { (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 19], - is_simplified: false + d1 in [0, 19] )")); } @@ -527,8 +500,7 @@ TEST_F(IndexingAnalysisTest, BitcastIsReshape) { domain: d0 in [0, 3], d1 in [0, 7], - d2 in [0, 3], - is_simplified: true + d2 in [0, 3] )")); } @@ -547,8 +519,7 @@ TEST_F(IndexingAnalysisTest, BitcastIsTranspose) { d0 in [0, 2], d1 in [0, 5], d2 in [0, 127], - d3 in [0, 12287], - is_simplified: true + d3 in [0, 12287] )")); } @@ -566,8 +537,7 @@ TEST_F(IndexingAnalysisTest, BitcastIsTransposeReshapeTranspose) { (d0, d1) -> (d1, d0 floordiv 3, d0 mod 3), domain: d0 in [0, 50], - d1 in [0, 15], - is_simplified: true + d1 in [0, 15] )")); auto output_indexing = GetInputToOutputIndexing(root); EXPECT_THAT(output_indexing.ToString(), MatchIndexingString(R"( @@ -576,8 +546,7 @@ TEST_F(IndexingAnalysisTest, BitcastIsTransposeReshapeTranspose) { domain: d0 in [0, 15], d1 in [0, 16], - d2 in [0, 2], - is_simplified: true + d2 in [0, 2] )")); } @@ -596,8 +565,7 @@ TEST_F(IndexingAnalysisTest, BroadcastOp) { domain: d0 in [0, 9], d1 in [0, 19], - d2 in [0, 29], - is_simplified: false + d2 in [0, 29] )")); auto output_indexing = GetInputToOutputIndexing(root); EXPECT_THAT(output_indexing.ToString(), MatchIndexingString(R"( @@ -606,8 +574,7 @@ TEST_F(IndexingAnalysisTest, BroadcastOp) { domain: d0 in [0, 19], s0 in [0, 9], - s1 in [0, 29], - is_simplified: false + s1 in [0, 29] )")); } @@ -640,22 +607,19 @@ TEST_F(IndexingAnalysisTest, ConcatenateOp) { domain: d0 in [0, 1], d1 in [0, 4], - d2 in [0, 6], - is_simplified: false + d2 in [0, 6] operand id = 1 (d0, d1, d2) -> (d0, d1 - 5, d2), domain: d0 in [0, 1], d1 in [5, 15], - d2 in [0, 6], - is_simplified: false + d2 in [0, 6] operand id = 2 (d0, d1, d2) -> (d0, d1 - 16, d2), domain: d0 in [0, 1], d1 in [16, 32], - d2 in [0, 6], - is_simplified: false + d2 in [0, 6] )")); auto output_indexing_0 = GetInputToOutputIndexing(root, /*input_id=*/0); @@ -665,8 +629,7 @@ TEST_F(IndexingAnalysisTest, ConcatenateOp) { domain: d0 in [0, 1], d1 in [0, 4], - d2 in [0, 6], - is_simplified: false + d2 in [0, 6] )")); auto output_indexing_1 = GetInputToOutputIndexing(root, /*input_id=*/1); @@ -676,8 +639,7 @@ TEST_F(IndexingAnalysisTest, ConcatenateOp) { domain: d0 in [0, 1], d1 in [0, 10], - d2 in [0, 6], - is_simplified: false + d2 in [0, 6] )")); auto output_indexing_2 = GetInputToOutputIndexing(root, /*input_id=*/2); @@ -687,8 +649,7 @@ TEST_F(IndexingAnalysisTest, ConcatenateOp) { domain: d0 in [0, 1], d1 in [0, 16], - d2 in [0, 6], - is_simplified: false + d2 in [0, 6] )")); } @@ -707,42 +668,32 @@ TEST_F(IndexingAnalysisTest, DynamicSliceOp) { )")); EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"( operand id = 0 - (d0, d1, d2)[s0, s1, s2] -> (d0 + s0, d1 + s1, d2 + s2), + (d0, d1, d2){rt0, rt1, rt2} -> (d0 + rt0, d1 + rt1, d2 + rt2), domain: d0 in [0, 0], d1 in [0, 1], d2 in [0, 31], - s0 in [0, 1], - hlo: %of1 = s32[] parameter(1), - (d0, d1, d2) -> (), - s1 in [0, 0], - hlo: %of2 = s32[] parameter(2), - (d0, d1, d2) -> (), - s2 in [0, 226], - hlo: %of3 = s32[] parameter(3), - (d0, d1, d2) -> (), - is_simplified: false + rt0 in [0, 1], + rt1 in [0, 0], + rt2 in [0, 226] operand id = 1 (d0, d1, d2) -> (), domain: d0 in [0, 0], d1 in [0, 1], - d2 in [0, 31], - is_simplified: false + d2 in [0, 31] operand id = 2 (d0, d1, d2) -> (), domain: d0 in [0, 0], d1 in [0, 1], - d2 in [0, 31], - is_simplified: false + d2 in [0, 31] operand id = 3 (d0, d1, d2) -> (), domain: d0 in [0, 0], d1 in [0, 1], - d2 in [0, 31], - is_simplified: false + d2 in [0, 31] )")); } @@ -763,32 +714,24 @@ TEST_F(IndexingAnalysisTest, DynamicUpdateSliceOp) { (d0, d1) -> (d0, d1), domain: d0 in [0, 19], - d1 in [0, 29], - is_simplified: false + d1 in [0, 29] operand id = 1 - (d0, d1)[s0, s1] -> (d0 - s0, d1 - s1), + (d0, d1){rt0, rt1} -> (d0 - rt0, d1 - rt1), domain: d0 in [0, 19], d1 in [0, 29], - s0 in [0, 15], - hlo: %of1 = s32[] parameter(2), - (d0, d1) -> (), - s1 in [0, 20], - hlo: %of2 = s32[] parameter(3), - (d0, d1) -> (), - is_simplified: false + rt0 in [0, 15], + rt1 in [0, 20] operand id = 2 (d0, d1) -> (), domain: d0 in [0, 19], - d1 in [0, 29], - is_simplified: false + d1 in [0, 29] operand id = 3 (d0, d1) -> (), domain: d0 in [0, 19], - d1 in [0, 29], - is_simplified: false + d1 in [0, 29] )")); } @@ -810,13 +753,11 @@ TEST_F(IndexingAnalysisTest, FusionOpWithSingleBinaryOp) { operand id = 0 (d0) -> (d0), domain: - d0 in [0, 99], - is_simplified: true + d0 in [0, 99] operand id = 1 (d0) -> (d0), domain: - d0 in [0, 99], - is_simplified: true + d0 in [0, 99] )")); } @@ -890,8 +831,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { d3 in [0, 0], d4 in [0, 5], d5 in [0, 127], - s0 in [0, 767], - is_simplified: true + s0 in [0, 767] operand id = 1 (d0, d1, d2, d3, d4, d5)[s0] -> (d0 * 768 + s0), domain: @@ -901,8 +841,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { d3 in [0, 0], d4 in [0, 5], d5 in [0, 127], - s0 in [0, 767], - is_simplified: true + s0 in [0, 767] operand id = 2 (d0, d1, d2, d3, d4, d5) -> (d1), domain: @@ -911,8 +850,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { d2 in [0, 2], d3 in [0, 0], d4 in [0, 5], - d5 in [0, 127], - is_simplified: true + d5 in [0, 127] operand id = 3 (d0, d1, d2, d3, d4, d5)[s0] -> (d1, d0 * 768 + s0), domain: @@ -922,8 +860,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { d3 in [0, 0], d4 in [0, 5], d5 in [0, 127], - s0 in [0, 767], - is_simplified: true + s0 in [0, 767] operand id = 4 (d0, d1, d2, d3, d4, d5)[s0] -> (d1, d0 * 768 + s0), domain: @@ -933,8 +870,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { d3 in [0, 0], d4 in [0, 5], d5 in [0, 127], - s0 in [0, 767], - is_simplified: true + s0 in [0, 767] operand id = 5 (d0, d1, d2, d3, d4, d5) -> (d2, d4, d5), domain: @@ -943,8 +879,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDot) { d2 in [0, 2], d3 in [0, 0], d4 in [0, 5], - d5 in [0, 127], - is_simplified: true + d5 in [0, 127] )")); } @@ -1001,16 +936,14 @@ TEST_F(IndexingAnalysisTest, FusionOpWithSoftmax) { d0 in [0, 1], d1 in [0, 64], d2 in [0, 124], - s0 in [0, 124], - is_simplified: true + s0 in [0, 124] )"), MatchIndexingMap(R"( (d0, d1, d2) -> (d0, d1, d2), domain: d0 in [0, 1], d1 in [0, 64], - d2 in [0, 124], - is_simplified: true + d2 in [0, 124] )")))); } @@ -1032,15 +965,13 @@ TEST_F(IndexingAnalysisTest, FusionOpTensorPlusTransposedTensor) { (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: true + d1 in [0, 999] )"), MatchIndexingMap(R"( (d0, d1) -> (d1, d0), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: true + d1 in [0, 999] )")))); } @@ -1070,38 +1001,32 @@ TEST_F(IndexingAnalysisTest, FusionExponentialDuplication) { ElementsAre(UnorderedElementsAre(MatchIndexingMap(R"( (d0) -> (d0 + 1), domain: - d0 in [0, 1], - is_simplified: true + d0 in [0, 1] )"), MatchIndexingMap(R"( (d0) -> (d0), domain: - d0 in [0, 1], - is_simplified: true + d0 in [0, 1] )"), MatchIndexingMap(R"( (d0) -> (d0 + 2), domain: - d0 in [0, 1], - is_simplified: true + d0 in [0, 1] )")), UnorderedElementsAre(MatchIndexingMap(R"( (d0) -> (d0 + 2), domain: - d0 in [0, 1], - is_simplified: true + d0 in [0, 1] )"), MatchIndexingMap(R"( (d0) -> (d0 + 1), domain: - d0 in [0, 1], - is_simplified: true + d0 in [0, 1] )"), MatchIndexingMap(R"( (d0) -> (d0), domain: - d0 in [0, 1], - is_simplified: true + d0 in [0, 1] )")))); } @@ -1118,19 +1043,14 @@ TEST_F(IndexingAnalysisTest, GatherOp) { )")); EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"( operand id = 0 - (d0, d1, d2, d3)[s0, s1] -> (d1 + s0, d2 + s1, d3), + (d0, d1, d2, d3){rt0, rt1} -> (d1 + rt0, d2 + rt1, d3), domain: d0 in [0, 1805], d1 in [0, 6], d2 in [0, 7], d3 in [0, 3], - s0 in [0, 26], - hlo: %indices = s32[1806,2]{1,0} parameter(1), - (d0, d1, d2, d3) -> (d0, 0), - s1 in [0, 68], - hlo: %indices = s32[1806,2]{1,0} parameter(1), - (d0, d1, d2, d3) -> (d0, 1), - is_simplified: false + rt0 in [0, 26], + rt1 in [0, 68] operand id = 1 (d0, d1, d2, d3)[s0] -> (d0, s0), domain: @@ -1138,8 +1058,7 @@ TEST_F(IndexingAnalysisTest, GatherOp) { d1 in [0, 6], d2 in [0, 7], d3 in [0, 3], - s0 in [0, 1], - is_simplified: false + s0 in [0, 1] )")); } @@ -1172,13 +1091,11 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReduceOfReduce) { d0 in [0, 9], s0 in [0, 149], s1 in [0, 49], - s2 in [0, 19], - is_simplified: true + s2 in [0, 19] operand id = 1 (d0) -> (), domain: - d0 in [0, 9], - is_simplified: true + d0 in [0, 9] )")); } @@ -1210,14 +1127,12 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReduceOfBroadcast) { domain: d0 in [0, 14], d1 in [0, 63], - s0 in [0, 19], - is_simplified: true + s0 in [0, 19] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 14], - d1 in [0, 63], - is_simplified: true + d1 in [0, 63] )")); } @@ -1252,8 +1167,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithTransposeOfTranspose) { domain: d0 in [0, 9], d1 in [0, 49], - d2 in [0, 19], - is_simplified: true + d2 in [0, 19] )")); } @@ -1285,13 +1199,11 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReducedSlice) { domain: d0 in [0, 31], s0 in [0, 15], - s1 in [0, 127], - is_simplified: true + s1 in [0, 127] operand id = 1 (d0) -> (), domain: - d0 in [0, 31], - is_simplified: true + d0 in [0, 31] )")); } @@ -1312,8 +1224,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReshape_CollapseOfExpand) { operand id = 0 (d0) -> (d0), domain: - d0 in [0, 127], - is_simplified: true + d0 in [0, 127] )")); } @@ -1335,8 +1246,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReshape_ExpandOfCollapse) { (d0, d1) -> (d0, d1), domain: d0 in [0, 7], - d1 in [0, 15], - is_simplified: true + d1 in [0, 15] )")); } @@ -1359,8 +1269,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithReshape_ChainedGenericReshapes) { domain: d0 in [0, 9], d1 in [0, 9], - d2 in [0, 9], - is_simplified: true + d2 in [0, 9] )")); } @@ -1385,8 +1294,7 @@ TEST_F(IndexingAnalysisTest, FusionOpWithSliceOfSlice) { domain: d0 in [0, 6], d1 in [0, 8], - d2 in [0, 23], - is_simplified: true + d2 in [0, 23] )")); } @@ -1418,47 +1326,34 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDynSliceOfDynSlice) { )")); EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"( operand id = 0 - (d0, d1)[s0, s1, s2, s3] -> (d0 + s0 + s2, d1 + s1 + s3), + (d0, d1){rt0, rt1, rt2, rt3} -> (d0 + rt0 + rt2, d1 + rt1 + rt3), domain: d0 in [0, 24], d1 in [0, 15], - s0 in [0, 100], - hlo: %of11 = s32[] parameter(1), - (d0, d1) -> (), - s1 in [0, 32], - hlo: %of12 = s32[] parameter(2), - (d0, d1) -> (), - s2 in [0, 25], - hlo: %of21 = s32[] parameter(3), - (d0, d1) -> (), - s3 in [0, 16], - hlo: %of22 = s32[] parameter(4), - (d0, d1) -> (), - is_simplified: true + rt0 in [0, 100], + rt1 in [0, 32], + rt2 in [0, 25], + rt3 in [0, 16] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 24], - d1 in [0, 15], - is_simplified: true + d1 in [0, 15] operand id = 2 (d0, d1) -> (), domain: d0 in [0, 24], - d1 in [0, 15], - is_simplified: true + d1 in [0, 15] operand id = 3 (d0, d1) -> (), domain: d0 in [0, 24], - d1 in [0, 15], - is_simplified: true + d1 in [0, 15] operand id = 4 (d0, d1) -> (), domain: d0 in [0, 24], - d1 in [0, 15], - is_simplified: true + d1 in [0, 15] )")); } @@ -1487,22 +1382,19 @@ TEST_F(IndexingAnalysisTest, FusionOpSliceOfAllConcatenateOpInputs) { domain: d0 in [0, 1], d1 in [0, 1], - d2 in [0, 6], - is_simplified: true + d2 in [0, 6] operand id = 1 (d0, d1, d2) -> (d0, d1 * 3 - 5, d2), domain: d0 in [0, 1], d1 in [2, 5], - d2 in [0, 6], - is_simplified: true + d2 in [0, 6] operand id = 2 (d0, d1, d2) -> (d0, d1 * 3 - 16, d2), domain: d0 in [0, 1], d1 in [6, 10], - d2 in [0, 6], - is_simplified: true + d2 in [0, 6] )")); } @@ -1531,8 +1423,7 @@ TEST_F(IndexingAnalysisTest, FusionOpSliceOfOneOfConcatenateOpInputs) { domain: d0 in [0, 1], d1 in [0, 2], - d2 in [0, 6], - is_simplified: true + d2 in [0, 6] operand id = 1 KNOWN EMPTY operand id = 2 @@ -1561,15 +1452,13 @@ TEST_F(IndexingAnalysisTest, FusionOpReshapeOfConcat) { domain: d0 in [0, 3], d1 in [0, 7], - d0 * 8 + d1 in [0, 1], - is_simplified: true + d0 * 8 + d1 in [0, 1] operand id = 1 (d0, d1) -> (d0 * 8 + d1 - 2), domain: d0 in [0, 3], d1 in [0, 7], - d0 * 8 + d1 in [2, 31], - is_simplified: true + d0 * 8 + d1 in [2, 31] )")); } @@ -1596,8 +1485,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpCollapseShape) { operand id = 0 (d0) -> (d0 floordiv 8, d0 mod 8), domain: - d0 in [0, 31], - is_simplified: true + d0 in [0, 31] )")); } @@ -1614,8 +1502,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpExpandShape) { (d0, d1) -> (d0 * 8 + d1), domain: d0 in [0, 3], - d1 in [0, 7], - is_simplified: true + d1 in [0, 7] )")); } @@ -1634,8 +1521,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpExpandAndCollapseShape) { domain: d0 in [0, 31], d1 in [0, 2], - d2 in [0, 3], - is_simplified: true + d2 in [0, 3] )")); auto output_indexing = GetInputToOutputIndexing(root); @@ -1645,8 +1531,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpExpandAndCollapseShape) { domain: d0 in [0, 3], d1 in [0, 7], - d2 in [0, 11], - is_simplified: true + d2 in [0, 11] )")); } @@ -1664,8 +1549,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpExpandSubshapeOnly) { domain: d0 in [0, 3], d1 in [0, 3], - d2 in [0, 7], - is_simplified: true + d2 in [0, 7] )")); } @@ -1683,8 +1567,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpGenericReshape2DTo3D) { domain: d0 in [0, 1], d1 in [0, 3], - d2 in [0, 3], - is_simplified: true + d2 in [0, 3] )")); } @@ -1703,8 +1586,7 @@ TEST_F(IndexingAnalysisTest, ReshapeOpGenericReshape3DTo2D) { d1 mod 4), domain: d0 in [0, 3], - d1 in [0, 7], - is_simplified: true + d1 in [0, 7] )")); } @@ -1723,14 +1605,12 @@ TEST_F(IndexingAnalysisTest, PadOp) { domain: d0 in [1, 7], d1 in [4, 7], - (d0 - 1) mod 2 in [0, 0], - is_simplified: false + (d0 - 1) mod 2 in [0, 0] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 11], - d1 in [0, 15], - is_simplified: false + d1 in [0, 15] )")); } @@ -1748,14 +1628,12 @@ TEST_F(IndexingAnalysisTest, PadOpNoInterior) { (d0, d1) -> (d0 - 1, d1), domain: d0 in [1, 2], - d1 in [0, 7], - is_simplified: false + d1 in [0, 7] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 9], - d1 in [0, 7], - is_simplified: false + d1 in [0, 7] )")); } @@ -1778,13 +1656,11 @@ TEST_F(IndexingAnalysisTest, PadOpNegativePadding) { (d0) -> ((d0 + 3) floordiv 2), domain: d0 in [0, 4], - (d0 + 3) mod 2 in [0, 0], - is_simplified: false + (d0 + 3) mod 2 in [0, 0] operand id = 1 (d0) -> (), domain: - d0 in [0, 4], - is_simplified: false + d0 in [0, 4] )")); } @@ -1811,14 +1687,12 @@ TEST_F(IndexingAnalysisTest, ReduceOp) { d0 in [0, 149], d1 in [0, 9], s0 in [0, 19], - s1 in [0, 49], - is_simplified: false + s1 in [0, 49] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 149], - d1 in [0, 9], - is_simplified: true + d1 in [0, 9] )")); auto output_indexing_0 = GetInputToOutputIndexing(root, 0); @@ -1829,8 +1703,7 @@ TEST_F(IndexingAnalysisTest, ReduceOp) { d0 in [0, 149], d1 in [0, 19], d2 in [0, 9], - d3 in [0, 49], - is_simplified: false + d3 in [0, 49] )")); auto output_indexing_1 = GetInputToOutputIndexing(root, 1); EXPECT_THAT(output_indexing_1.ToString(), MatchIndexingString(R"( @@ -1838,8 +1711,7 @@ TEST_F(IndexingAnalysisTest, ReduceOp) { ()[s0, s1] -> (s0, s1), domain: s0 in [0, 149], - s1 in [0, 9], - is_simplified: false + s1 in [0, 9] )")); } @@ -1872,24 +1744,20 @@ TEST_F(IndexingAnalysisTest, VariadicReduceOp) { (d0)[s0] -> (s0, d0), domain: d0 in [0, 9], - s0 in [0, 255], - is_simplified: false + s0 in [0, 255] operand id = 1 (d0)[s0] -> (s0, d0), domain: d0 in [0, 9], - s0 in [0, 255], - is_simplified: false + s0 in [0, 255] operand id = 2 (d0) -> (), domain: - d0 in [0, 9], - is_simplified: true + d0 in [0, 9] operand id = 3 (d0) -> (), domain: - d0 in [0, 9], - is_simplified: true + d0 in [0, 9] )")); auto output_indexing_1 = GetOutputToInputIndexing(root, /*output_id=*/1); @@ -1898,32 +1766,27 @@ TEST_F(IndexingAnalysisTest, VariadicReduceOp) { (d0)[s0] -> (s0, d0), domain: d0 in [0, 9], - s0 in [0, 255], - is_simplified: false + s0 in [0, 255] operand id = 1 (d0)[s0] -> (s0, d0), domain: d0 in [0, 9], - s0 in [0, 255], - is_simplified: false + s0 in [0, 255] operand id = 2 (d0) -> (), domain: - d0 in [0, 9], - is_simplified: true + d0 in [0, 9] operand id = 3 (d0) -> (), domain: - d0 in [0, 9], - is_simplified: true + d0 in [0, 9] )")); constexpr std::string_view kInputToOutputIndexing = R"( (d0, d1) -> (d1), domain: d0 in [0, 255], - d1 in [0, 9], - is_simplified: false + d1 in [0, 9] )"; auto input_indexing_0 = GetInputToOutputIndexing(root, /*input_id=*/0); EXPECT_THAT( @@ -1940,8 +1803,7 @@ TEST_F(IndexingAnalysisTest, VariadicReduceOp) { constexpr std::string_view kInitToOutputIndexing = R"( ()[s0] -> (s0), domain: - s0 in [0, 9], - is_simplified: false + s0 in [0, 9] )"; auto input_indexing_2 = GetInputToOutputIndexing(root, /*input_id=*/2); EXPECT_THAT( @@ -1977,14 +1839,12 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_NoPadding) { domain: d0 in [0, 1023], d1 in [0, 2], - s0 in [0, 511], - is_simplified: true + s0 in [0, 511] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 1023], - d1 in [0, 2], - is_simplified: true + d1 in [0, 2] )")); } @@ -2013,14 +1873,12 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_PaddingAndWindowStride) { s0 in [0, 2], s1 in [0, 1], d0 * 2 + s0 in [1, 13], - d1 + s1 in [0, 16], - is_simplified: true + d1 + s1 in [0, 16] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 6], - d1 in [0, 16], - is_simplified: true + d1 in [0, 16] )")); } @@ -2047,14 +1905,12 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_BaseDilation) { d0 in [0, 2], d1 in [0, 4], d0 mod 2 in [0, 0], - d1 mod 2 in [0, 0], - is_simplified: true + d1 mod 2 in [0, 0] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 2], - d1 in [0, 4], - is_simplified: true + d1 in [0, 4] )")); } @@ -2080,14 +1936,12 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_WindowDilation) { domain: d0 in [0, 3], d1 in [0, 2], - s0 in [0, 1], - is_simplified: true + s0 in [0, 1] operand id = 1 (d0, d1) -> (), domain: d0 in [0, 3], - d1 in [0, 2], - is_simplified: true + d1 in [0, 2] )")); } @@ -2121,28 +1975,24 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_Variadic) { d0 in [0, 0], d1 in [0, 1], s0 in [0, 1], - s1 in [0, 1], - is_simplified: true + s1 in [0, 1] operand id = 1 (d0, d1)[s0, s1] -> (s0, d1 + s1), domain: d0 in [0, 0], d1 in [0, 1], s0 in [0, 1], - s1 in [0, 1], - is_simplified: true + s1 in [0, 1] operand id = 2 (d0, d1) -> (), domain: d0 in [0, 0], - d1 in [0, 1], - is_simplified: true + d1 in [0, 1] operand id = 3 (d0, d1) -> (), domain: d0 in [0, 0], - d1 in [0, 1], - is_simplified: true + d1 in [0, 1] )")); auto input_indexing_1 = GetOutputToInputIndexing(root, /*output_id=*/1); EXPECT_THAT(input_indexing_1.ToString(), MatchIndexingString(R"( @@ -2152,28 +2002,24 @@ TEST_F(IndexingAnalysisTest, ReduceWindowOp_Variadic) { d0 in [0, 0], d1 in [0, 1], s0 in [0, 1], - s1 in [0, 1], - is_simplified: true + s1 in [0, 1] operand id = 1 (d0, d1)[s0, s1] -> (s0, d1 + s1), domain: d0 in [0, 0], d1 in [0, 1], s0 in [0, 1], - s1 in [0, 1], - is_simplified: true + s1 in [0, 1] operand id = 2 (d0, d1) -> (), domain: d0 in [0, 0], - d1 in [0, 1], - is_simplified: true + d1 in [0, 1] operand id = 3 (d0, d1) -> (), domain: d0 in [0, 0], - d1 in [0, 1], - is_simplified: true + d1 in [0, 1] )")); } @@ -2198,8 +2044,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_NoPadding) { d3 in [0, 7], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] operand id = 1 (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3), domain: @@ -2209,8 +2054,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_NoPadding) { d3 in [0, 7], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] )")); } @@ -2237,8 +2081,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_PaddingAndWindowStride) { s1 in [0, 4], s2 in [0, 3], d1 * 2 + s0 in [1, 12], - d2 * 2 + s1 in [2, 11], - is_simplified: false + d2 * 2 + s1 in [2, 11] operand id = 1 (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3), domain: @@ -2248,8 +2091,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_PaddingAndWindowStride) { d3 in [0, 7], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] )")); } @@ -2276,8 +2118,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_LhsDilation) { s1 in [0, 4], s2 in [0, 3], (d1 + s0) mod 2 in [0, 0], - (d2 + s1) mod 2 in [0, 0], - is_simplified: false + (d2 + s1) mod 2 in [0, 0] operand id = 1 (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3), domain: @@ -2287,8 +2128,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_LhsDilation) { d3 in [0, 7], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] )")); } @@ -2313,8 +2153,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_RhsDilation) { d3 in [0, 7], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] operand id = 1 (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3), domain: @@ -2324,8 +2163,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_RhsDilation) { d3 in [0, 7], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] )")); } @@ -2350,8 +2188,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_FeatureGroups) { d3 in [0, 47], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] operand id = 1 (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3), domain: @@ -2361,8 +2198,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_FeatureGroups) { d3 in [0, 47], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] )")); } @@ -2388,8 +2224,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_BatchGroups) { s0 in [0, 2], s1 in [0, 4], s2 in [0, 3], - s3 in [0, 6], - is_simplified: false + s3 in [0, 6] operand id = 1 (d0, d1, d2, d3)[s0, s1, s2] -> (s2, s0, s1, d3), domain: @@ -2399,8 +2234,7 @@ TEST_F(IndexingAnalysisTest, ConvolutionOp_BatchGroups) { d3 in [0, 20], s0 in [0, 2], s1 in [0, 4], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] )")); } @@ -2420,8 +2254,7 @@ TEST_F(IndexingAnalysisTest, ReverseOp) { d0 in [0, 0], d1 in [0, 16], d2 in [0, 8], - d3 in [0, 8], - is_simplified: false + d3 in [0, 8] )")); auto output_indexing = GetInputToOutputIndexing(root); @@ -2432,8 +2265,7 @@ TEST_F(IndexingAnalysisTest, ReverseOp) { d0 in [0, 0], d1 in [0, 16], d2 in [0, 8], - d3 in [0, 8], - is_simplified: false + d3 in [0, 8] )")); } @@ -2458,8 +2290,7 @@ TEST_F(IndexingAnalysisTest, ReverseReshape) { (d0, d1) -> (d0, d1), domain: d0 in [0, 9], - d1 in [0, 10], - is_simplified: true + d1 in [0, 10] )")); } @@ -2479,8 +2310,7 @@ TEST_F(IndexingAnalysisTest, SliceOp) { domain: d0 in [0, 4], d1 in [0, 2], - d2 in [0, 24], - is_simplified: false + d2 in [0, 24] )")); auto output_indexing = GetInputToOutputIndexing(root); EXPECT_THAT(output_indexing.ToString(), MatchIndexingString(R"( @@ -2495,8 +2325,7 @@ TEST_F(IndexingAnalysisTest, SliceOp) { d1 in [3, 17], d2 in [0, 48], (d1 - 3) mod 7 in [0, 0], - d2 mod 2 in [0, 0], - is_simplified: false + d2 mod 2 in [0, 0] )")); } @@ -2516,8 +2345,7 @@ TEST_F(IndexingAnalysisTest, TransposeOp) { d0 in [0, 2], d1 in [0, 5], d2 in [0, 127], - d3 in [0, 12287], - is_simplified: false + d3 in [0, 12287] )")); EXPECT_THAT(GetInputToOutputIndexing(root).ToString(), MatchIndexingString(R"( operand id = 0 @@ -2526,8 +2354,7 @@ TEST_F(IndexingAnalysisTest, TransposeOp) { d0 in [0, 2], d1 in [0, 12287], d2 in [0, 5], - d3 in [0, 127], - is_simplified: false + d3 in [0, 127] )")); } @@ -2546,8 +2373,7 @@ TEST_F(IndexingAnalysisTest, TransposeOp4D) { d0 in [0, 2], d1 in [0, 5], d2 in [0, 127], - d3 in [0, 12287], - is_simplified: true + d3 in [0, 12287] )")); } @@ -2573,8 +2399,7 @@ TEST_F(IndexingAnalysisTest, DotOp) { d4 in [0, 15], d5 in [0, 21], s0 in [0, 17], - s1 in [0, 16], - is_simplified: false + s1 in [0, 16] operand id = 1 (d0, d1, d2, d3, d4, d5)[s0, s1] -> (s1, d0, d4, s0, d5, d1), domain: @@ -2585,8 +2410,7 @@ TEST_F(IndexingAnalysisTest, DotOp) { d4 in [0, 15], d5 in [0, 21], s0 in [0, 17], - s1 in [0, 16], - is_simplified: false + s1 in [0, 16] )")); } @@ -2647,8 +2471,7 @@ TEST_F(IndexingAnalysisTest, FusionWithUnsupportedOp) { (d0, d1) -> (d0 * 6, d1 * 2), domain: d0 in [0, 3], - d1 in [0, 2], - is_simplified: true + d1 in [0, 2] operand id = 1 unknown indexing operand id = 2 @@ -2679,15 +2502,13 @@ TEST_F(IndexingAnalysisTest, EpilogueIndexing) { HloInstructionAdaptor log(*computation->GetInstructionWithName("log"), fusion.get()); - EXPECT_THAT( - ComputeEpilogueInputToOutputIndexing(transpose, log, &mlir_context_) - .ToString(), - MatchIndexingString(R"( + EXPECT_THAT(ToString(ComputeEpilogueInputToOutputIndexing(transpose, log, + &mlir_context_)), + MatchIndexingString(R"( (d0, d1) -> (d1 * 1000 + d0), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: true + d1 in [0, 999] )")); } @@ -2710,15 +2531,13 @@ TEST_F(IndexingAnalysisTest, EpilogueIndexing_NoEpilogue) { HloInstructionAdaptor transpose(*computation->GetInstructionWithName("t"), fusion.get()); - EXPECT_THAT( - ComputeEpilogueInputToOutputIndexing(transpose, transpose, &mlir_context_) - .ToString(), - MatchIndexingString(R"( + EXPECT_THAT(ToString(ComputeEpilogueInputToOutputIndexing( + transpose, transpose, &mlir_context_)), + MatchIndexingString(R"( (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: false + d1 in [0, 999] )")); } @@ -2736,21 +2555,245 @@ TEST_F(IndexingAnalysisTest, BroadcastingElementwise) { operand id = 0 (d0, d1) -> (), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: false + d1 in [0, 999] operand id = 1 (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: false + d1 in [0, 999] operand id = 2 (d0, d1) -> (d0, d1), domain: d0 in [0, 999], - d1 in [0, 999], - is_simplified: false + d1 in [0, 999] )")); } +TEST_F(IndexingAnalysisTest, FusionWithRTVarsSimplification_ScalarConstant) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo( + HloModule m + fused_computation { + p0 = s32[4096] parameter(0) + offset = s64[] constant(42) + ROOT dynamic-slice = s32[10] + dynamic-slice(p0, offset), dynamic_slice_sizes={10} + } + ENTRY main { + p0 = s32[4096] parameter(0) + ROOT fusion = s32[10] fusion(p0), kind=kInput, calls=fused_computation + } + )hlo")); + + EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"( + operand id = 0 + (d0) -> (d0 + 42), + domain: + d0 in [0, 9] + )")); +} + +TEST_F(IndexingAnalysisTest, FusionWithRTVarsSimplification_Iota) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo( + HloModule m + fused_computation { + p0 = f32[33,76] parameter(0) + iota = s64[42,1] iota(), iota_dimension=0 + ROOT gather = f32[42,1,1] gather(p0, iota), + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0}, + index_vector_dim=1, + slice_sizes={1,1} + } + ENTRY main { + p0 = f32[33,76] parameter(0) + ROOT fusion = f32[42,1,1] fusion(p0), kind=kInput, calls=fused_computation + } + )hlo")); + EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"( + operand id = 0 + (d0, d1, d2) -> (d0, 0), + domain: + d0 in [0, 41], + d1 in [0, 0], + d2 in [0, 0] + )")); +} + +TEST_F(IndexingAnalysisTest, FusionWithRTVarsSimplification_IotaAsConstant) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo( + HloModule m + fused_computation { + p0 = f32[33,76] parameter(0) + iota = s64[42,1] iota(), iota_dimension=1 + ROOT gather = f32[42,1,1] gather(p0, iota), + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0}, + index_vector_dim=1, + slice_sizes={1,1} + } + ENTRY main { + p0 = f32[33,76] parameter(0) + ROOT fusion = f32[42,1,1] fusion(p0), kind=kInput, calls=fused_computation + } + )hlo")); + EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"( + operand id = 0 + (d0, d1, d2) -> (0, 0), + domain: + d0 in [0, 41], + d1 in [0, 0], + d2 in [0, 0] + )")); +} + +TEST_F(IndexingAnalysisTest, FusionWithRTVarsSimplification_Broadcast) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo( + HloModule m + fused_computation { + p0 = f32[33,76] parameter(0) + c42 = s64[] constant(42) + bcast = s64[42, 1] broadcast(s64[] c42), dimensions={} + ROOT gather = f32[42,1,1] gather(p0, bcast), + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0}, + index_vector_dim=1, + slice_sizes={1,1} + } + ENTRY main { + p0 = f32[33,76] parameter(0) + ROOT fusion = f32[42,1,1] fusion(p0), kind=kInput, calls=fused_computation + } + )hlo")); + EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"( + operand id = 0 + (d0, d1, d2) -> (42, 0), + domain: + d0 in [0, 41], + d1 in [0, 0], + d2 in [0, 0] + )")); +} + +TEST_F(IndexingAnalysisTest, FusionWithRTVarsSimplification_Reverse) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo( + HloModule m + fused_computation { + p0 = f32[33,76] parameter(0) + iota = s64[42,1] iota(), iota_dimension=0 + reverse = s64[42,1] reverse(iota), dimensions={0} + ROOT gather = f32[42,1,1] gather(p0, reverse), + offset_dims={1,2}, + collapsed_slice_dims={}, + start_index_map={0}, + index_vector_dim=1, + slice_sizes={1,1} + } + ENTRY main { + p0 = f32[33,76] parameter(0) + ROOT fusion = f32[42,1,1] fusion(p0), kind=kInput, calls=fused_computation + } + )hlo")); + EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"( + operand id = 0 + (d0, d1, d2) -> (-d0 + 41, 0), + domain: + d0 in [0, 41], + d1 in [0, 0], + d2 in [0, 0] + )")); +} + +TEST_F(IndexingAnalysisTest, FusionWithRTVarsSimplification_Add) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo( + HloModule m + fused_computation { + p0 = s32[4096] parameter(0) + p1 = s64[] parameter(1) + c42 = s64[] constant(42) + add = s64[] add(c42, p1) + ROOT dynamic-slice = s32[10] + dynamic-slice(p0, add), dynamic_slice_sizes={10} + } + ENTRY main { + p0 = s32[4096] parameter(0) + p1 = s64[] parameter(1) + ROOT fusion = s32[10] fusion(p0, p1), kind=kInput, calls=fused_computation + } + )hlo")); + EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"( + operand id = 0 (d0){rt0} -> (d0 + rt0 + 42), + domain: + d0 in [0, 9], + rt0 in [0, 4086] + operand id = 1 + (d0) -> (), + domain: + d0 in [0, 9] + )")); +} + +TEST_F(IndexingAnalysisTest, FusionWithRTVarsSimplification_Multiply) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo( + HloModule m + fused_computation { + p0 = s32[4096] parameter(0) + p1 = s64[] parameter(1) + c42 = s64[] constant(42) + add = s64[] multiply(c42, p1) + ROOT dynamic-slice = s32[10] + dynamic-slice(p0, add), dynamic_slice_sizes={10} + } + ENTRY main { + p0 = s32[4096] parameter(0) + p1 = s64[] parameter(1) + ROOT fusion = s32[10] fusion(p0, p1), kind=kInput, calls=fused_computation + } + )hlo")); + // TODO: Figure out why the bounds are not updated. + EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"( + operand id = 0 (d0){rt0} -> (d0 + rt0 * 42), + domain: + d0 in [0, 9], + rt0 in [0, 4086] + operand id = 1 + (d0) -> (), + domain: + d0 in [0, 9] + )")); +} + +TEST_F(IndexingAnalysisTest, FusionWithRTVarsSimplification_ChainedOps) { + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo( + HloModule m + fused_computation { + p0 = s32[4096] parameter(0) + p1 = s64[] parameter(1) + c42 = s64[] constant(42) + c2 = s64[] constant(2) + add = s64[] add(c42, p1) + multiply = s64[] multiply(c2, add) + ROOT dynamic-slice = s32[10] + dynamic-slice(p0, multiply), dynamic_slice_sizes={10} + } + ENTRY main { + p0 = s32[4096] parameter(0) + p1 = s64[] parameter(1) + ROOT fusion = s32[10] fusion(p0, p1), kind=kInput, calls=fused_computation + } + )hlo")); + EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"( + operand id = 0 + (d0){rt0} -> (d0 + rt0 * 2 + 84), + domain: d0 in [0, 9], + rt0 in [0, 4086] + operand id = 1 + (d0) -> (), + domain: + d0 in [0, 9] + )")); +} + TEST_F(IndexingAnalysisTest, FusionOpWithDUS) { auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"hlo( HloModule m @@ -2772,21 +2815,17 @@ TEST_F(IndexingAnalysisTest, FusionOpWithDUS) { )hlo")); EXPECT_THAT(input_indexing.ToString(), MatchIndexingString(R"( operand id = 0 - (d0, d1)[s0] -> (0, d1 + s0 - 4096), + (d0, d1){rt0} -> (0, d1 + rt0 - 4096), domain: d0 in [0, 0], d1 in [0, 4095], - s0 in [0, 4096], - hlo: %slice = s32[1]{0} parameter(1), - (d0, d1) -> (0), - d1 + s0 in [4096, 8191], - is_simplified: true + rt0 in [0, 4096], + d1 + rt0 in [4096, 8191] operand id = 1 (d0, d1) -> (0), domain: d0 in [0, 0], - d1 in [0, 4095], - is_simplified: true + d1 in [0, 4095] )")); } diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc index 7add913b3e5942..cea65b2df29f6d 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc @@ -30,7 +30,10 @@ limitations under the License. #include #include "absl/base/optimization.h" +#include "absl/log/check.h" #include "absl/numeric/int128.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/types/span.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" @@ -48,7 +51,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/gpu/model/affine_map_printer.h" #include "tsl/platform/logging.h" // IWYU pragma: keep namespace xla { @@ -783,15 +785,17 @@ SmallVector MapSymbolsToComposedSymbolsList( } // namespace static constexpr std::string_view kVarKindDefault = "default"; -static constexpr std::string_view kVarKindThreadX = "thread_x"; -static constexpr std::string_view kVarKindThreadY = "thread_y"; -static constexpr std::string_view kVarKindThreadZ = "thread_z"; -static constexpr std::string_view kVarKindBlockX = "block_x"; -static constexpr std::string_view kVarKindBlockY = "block_y"; -static constexpr std::string_view kVarKindBlockZ = "block_z"; - -std::string_view ToString(VariableKind type) { - switch (type) { +static constexpr std::string_view kVarKindThreadX = "th_x"; +static constexpr std::string_view kVarKindThreadY = "th_y"; +static constexpr std::string_view kVarKindThreadZ = "th_z"; +static constexpr std::string_view kVarKindBlockX = "bl_x"; +static constexpr std::string_view kVarKindBlockY = "bl_y"; +static constexpr std::string_view kVarKindBlockZ = "bl_z"; +static constexpr std::string_view kVarKindWarp = "warp"; +static constexpr std::string_view kVarKindWarpThread = "th_w"; + +std::string_view ToVariableName(VariableKind var_kind) { + switch (var_kind) { case VariableKind::kDefault: return kVarKindDefault; case VariableKind::kThreadX: @@ -806,39 +810,46 @@ std::string_view ToString(VariableKind type) { return kVarKindBlockY; case VariableKind::kBlockZ: return kVarKindBlockZ; + case VariableKind::kWarp: + return kVarKindWarp; + case VariableKind::kWarpThread: + return kVarKindWarpThread; } llvm_unreachable("Unknown VariableType"); } -VariableKind ToVariableType(std::string_view type_name) { - if (type_name == kVarKindDefault) return VariableKind::kDefault; - if (type_name == kVarKindThreadX) return VariableKind::kThreadX; - if (type_name == kVarKindThreadY) return VariableKind::kThreadY; - if (type_name == kVarKindThreadZ) return VariableKind::kThreadZ; - if (type_name == kVarKindBlockX) return VariableKind::kBlockX; - if (type_name == kVarKindBlockY) return VariableKind::kBlockY; - if (type_name == kVarKindBlockZ) return VariableKind::kBlockZ; - llvm_unreachable("Unknown VariableType name"); +VariableKind ToVariableType(std::string_view var_name) { + if (var_name == kVarKindThreadX) return VariableKind::kThreadX; + if (var_name == kVarKindThreadY) return VariableKind::kThreadY; + if (var_name == kVarKindThreadZ) return VariableKind::kThreadZ; + if (var_name == kVarKindBlockX) return VariableKind::kBlockX; + if (var_name == kVarKindBlockY) return VariableKind::kBlockY; + if (var_name == kVarKindBlockZ) return VariableKind::kBlockZ; + if (var_name == kVarKindWarp) return VariableKind::kWarp; + if (var_name == kVarKindWarpThread) return VariableKind::kWarpThread; + return VariableKind::kDefault; } std::ostream& operator<<(std::ostream& out, VariableKind var_type) { - out << ToString(var_type); + out << ToVariableName(var_type); return out; } -// Returns the output-to-input indexing map of the first output of `instr` -IndexingMap GetIndexingMapForInstruction(const HloInstruction* instr, - int64_t operand_idx, - mlir::MLIRContext* mlir_context); +std::ostream& operator<<(std::ostream& out, const Interval& interval) { + out << absl::StrFormat("[%d, %d]", interval.lower, interval.upper); + return out; +} std::string Interval::ToString() const { std::stringstream ss; - Print(ss); + ss << *this; return ss.str(); } -void Interval::Print(std::ostream& out) const { - out << '[' << lower << ", " << upper << "]"; +inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, + const Interval& interval) { + os << absl::StrFormat("[%d, %d]", interval.lower, interval.upper); + return os; } int64_t Interval::GetLoopTripCount() const { @@ -958,49 +969,49 @@ Interval Interval::FloorDiv(int64_t rhs) const { return {std::min(a, b), std::max(a, b)}; } -std::ostream& operator<<(std::ostream& out, const Interval& range) { - range.Print(out); - return out; -} - -bool operator==(const DimVar& lhs, const DimVar& rhs) { +bool operator==(const IndexingMap::Variable& lhs, + const IndexingMap::Variable& rhs) { return lhs.bounds == rhs.bounds; } -bool operator==(const RangeVar& lhs, const RangeVar& rhs) { - return lhs.range == rhs.range; -} - -bool operator==(const RTVar& lhs, const RTVar& rhs) { - return lhs.feasible_values == rhs.feasible_values && lhs.hlo == rhs.hlo && - lhs.map == rhs.map; -} - -std::vector DimVarsFromTensorSizes( +std::vector DimVarsFromTensorSizes( absl::Span tensor_sizes) { - std::vector ranges; + std::vector ranges; ranges.reserve(tensor_sizes.size()); for (int64_t size : tensor_sizes) { - ranges.push_back({Interval{0, size - 1}}); + ranges.push_back(IndexingMap::Variable{0, size - 1}); } return ranges; } +std::vector DimVarsFromGPUGrid( + absl::Span grid_sizes) { + CHECK_EQ(grid_sizes.size(), 6) + << "Grid must be 6-dimensional (th_x, th_y, th_z, bl_x, bl_y, bl_z)"; + return { + IndexingMap::Variable{0, grid_sizes[0] - 1, kVarKindThreadX}, + IndexingMap::Variable{0, grid_sizes[1] - 1, kVarKindThreadY}, + IndexingMap::Variable{0, grid_sizes[2] - 1, kVarKindThreadZ}, + IndexingMap::Variable{0, grid_sizes[3] - 1, kVarKindBlockX}, + IndexingMap::Variable{0, grid_sizes[4] - 1, kVarKindBlockY}, + IndexingMap::Variable{0, grid_sizes[5] - 1, kVarKindBlockZ}, + }; +} -std::vector RangeVarsFromTensorSizes( +std::vector RangeVarsFromTensorSizes( absl::Span tensor_sizes) { - std::vector ranges; + std::vector ranges; ranges.reserve(tensor_sizes.size()); for (int64_t size : tensor_sizes) { - ranges.push_back({Interval{0, size - 1}}); + ranges.push_back({IndexingMap::Variable{0, size - 1}}); } return ranges; } IndexingMap::IndexingMap( - AffineMap affine_map, std::vector dimensions, - std::vector range_vars, std::vector rt_vars, - absl::Span const> constraints, - bool is_simplified) + AffineMap affine_map, std::vector dimensions, + std::vector range_vars, + std::vector rt_vars, + absl::Span const> constraints) : affine_map_(affine_map), dim_vars_(std::move(dimensions)), range_vars_(std::move(range_vars)), @@ -1012,12 +1023,12 @@ IndexingMap::IndexingMap( for (const auto& [expr, range] : constraints) { AddConstraint(expr, range); } - is_simplified_ = is_simplified; } IndexingMap::IndexingMap( - AffineMap affine_map, std::vector dimensions, - std::vector range_vars, std::vector rt_vars, + AffineMap affine_map, std::vector dimensions, + std::vector range_vars, + std::vector rt_vars, const llvm::DenseMap& constraints) : affine_map_(affine_map), dim_vars_(std::move(dimensions)), @@ -1032,13 +1043,10 @@ IndexingMap::IndexingMap( IndexingMap IndexingMap::FromTensorSizes( AffineMap affine_map, absl::Span dim_upper_bounds, - absl::Span symbol_upper_bounds, bool is_simplified) { - return IndexingMap{affine_map, - DimVarsFromTensorSizes(dim_upper_bounds), + absl::Span symbol_upper_bounds) { + return IndexingMap{affine_map, DimVarsFromTensorSizes(dim_upper_bounds), RangeVarsFromTensorSizes(symbol_upper_bounds), - /*rt_vars=*/{}, - /*constraints=*/{}, - is_simplified}; + /*rt_vars=*/{}}; } RangeEvaluator IndexingMap::GetRangeEvaluator() const { @@ -1050,7 +1058,6 @@ const Interval& IndexingMap::GetDimensionBound(int64_t dim_id) const { } Interval& IndexingMap::GetMutableDimensionBound(int64_t dim_id) { - is_simplified_ = false; return dim_vars_[dim_id].bounds; } @@ -1068,28 +1075,27 @@ const Interval& IndexingMap::GetSymbolBound(int64_t symbol_id) const { // we have to pick the correct bounds. int64_t range_var_count = GetRangeVarsCount(); return symbol_id < range_var_count - ? range_vars_[symbol_id].range - : rt_vars_[symbol_id - range_var_count].feasible_values; + ? range_vars_[symbol_id].bounds + : rt_vars_[symbol_id - range_var_count].bounds; } Interval& IndexingMap::GetMutableSymbolBound(int64_t symbol_id) { - is_simplified_ = false; // Because affine map symbols are packed like [range_vars, rt_vars], // we have to pick the correct bounds. int64_t range_var_count = GetRangeVarsCount(); return symbol_id < range_var_count - ? range_vars_[symbol_id].range - : rt_vars_[symbol_id - range_var_count].feasible_values; + ? range_vars_[symbol_id].bounds + : rt_vars_[symbol_id - range_var_count].bounds; } std::vector IndexingMap::GetSymbolBounds() const { std::vector bounds; bounds.reserve(affine_map_.getNumSymbols()); for (const auto& range_var : range_vars_) { - bounds.push_back(range_var.range); + bounds.push_back(range_var.bounds); } for (const auto& rt_var : rt_vars_) { - bounds.push_back(rt_var.feasible_values); + bounds.push_back(rt_var.bounds); } return bounds; } @@ -1129,7 +1135,6 @@ void IndexingMap::AddConstraint(mlir::AffineExpr expr, Interval range) { ResetToKnownEmpty(); } } - is_simplified_ = false; } void IndexingMap::EraseConstraint(mlir::AffineExpr expr) { @@ -1256,77 +1261,10 @@ Interval RangeEvaluator::ComputeExpressionRange(AffineExpr expr) { return result; } -std::string IndexingMap::ToString(const AffineMapPrinter& printer) const { - std::stringstream ss; - Print(ss, printer); - return ss.str(); -} - -void PrintRTVars(const std::vector& rt_vars, - int first_rt_var_symbol_index, std::ostream& out, - const AffineMapPrinter& printer) { - for (const auto& [index, rt_var] : llvm::enumerate(rt_vars)) { - out << printer.GetSymbolName( - static_cast(first_rt_var_symbol_index + index)) - << " in "; - rt_var.feasible_values.Print(out); - out << ", hlo: " - << (rt_var.hlo == nullptr ? "NULL" : rt_var.hlo->ToString()) << ", "; - printer.Print(out, rt_var.map); - out << ", "; - } -} - -void IndexingMap::Print(std::ostream& out, - const AffineMapPrinter& printer) const { - if (IsKnownEmpty()) { - out << "KNOWN EMPTY\n"; - return; - } - printer.Print(out, affine_map_); - if (dim_vars_.empty() && range_vars_.empty() && rt_vars_.empty()) { - return; - } - out << ", domain: "; - for (const auto& [index, dim_var] : llvm::enumerate(dim_vars_)) { - out << printer.GetDimensionName(static_cast(index)) << " in "; - dim_var.bounds.Print(out); - out << ", "; - } - for (const auto& [index, range_var] : llvm::enumerate(range_vars_)) { - out << printer.GetSymbolName(static_cast(index)) << " in "; - range_var.range.Print(out); - out << ", "; - } - int64_t range_vars_count = GetRangeVarsCount(); - PrintRTVars(rt_vars_, /*first_rt_var_symbol_index=*/range_vars_count, out, - printer); - std::vector expr_range_strings; - expr_range_strings.reserve(constraints_.size()); - for (const auto& [expr, range] : constraints_) { - std::stringstream ss; - printer.Print(ss, expr); - ss << " in "; - range.Print(ss); - expr_range_strings.push_back(ss.str()); - } - std::sort(expr_range_strings.begin(), expr_range_strings.end()); - for (const auto& expr_range_string : expr_range_strings) { - out << expr_range_string << ", "; - } - out << "is_simplified: " << (is_simplified_ ? "true" : "false"); -} - MLIRContext* IndexingMap::GetMLIRContext() const { return IsUndefined() ? nullptr : affine_map_.getContext(); } -std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map) { - AffineMapPrinter printer; - indexing_map.Print(out, printer); - return out; -} - bool operator==(const IndexingMap& lhs, const IndexingMap& rhs) { return lhs.GetAffineMap() == rhs.GetAffineMap() && lhs.GetDimVars() == rhs.GetDimVars() && @@ -1339,6 +1277,23 @@ IndexingMap operator*(const IndexingMap& lhs, const IndexingMap& rhs) { return ComposeIndexingMaps(lhs, rhs); } +bool IndexingMap::Verify(std::ostream& out) const { + if (IsUndefined()) { + return true; + } + if (affine_map_.getNumDims() != dim_vars_.size()) { + out << "dim size must match the number of dimensions in " + "the affine map"; + return false; + } + if (affine_map_.getNumSymbols() != range_vars_.size() + rt_vars_.size()) { + out << "range vars size + rt var size must match the number of " + "symbols in the affine map"; + return false; + } + return true; +} + // Simplification of IndexingMap has two main parts. // At first we optimized constraints to make the domain as small and simple as // possible. And only then we simplify the affine_map, because its @@ -1353,9 +1308,7 @@ IndexingMap operator*(const IndexingMap& lhs, const IndexingMap& rhs) { // simplification, because the ranges of constraints were already optimized once // when IndexingMap was constructed. bool IndexingMap::Simplify() { - if (IsSimplified() || IsUndefined() || IsKnownEmpty()) return false; - - bool rtvars_were_eliminated = ReplaceConstantRTVars(); + if (IsUndefined() || IsKnownEmpty()) return false; // Simplify constraints to shrink the lower/upper bounds of dims and symbols. bool constraints_were_simplified = false; @@ -1384,9 +1337,7 @@ bool IndexingMap::Simplify() { if (affine_map_was_simplified) { affine_map_ = simplified_affine_map; } - is_simplified_ = true; - return affine_map_was_simplified || constraints_were_simplified || - rtvars_were_eliminated; + return affine_map_was_simplified || constraints_were_simplified; } bool AffineExprSimplifier::SimplifyConstraintExprs(IndexingMap& map) { @@ -1600,7 +1551,7 @@ bool IndexingMap::CompressVars(const llvm::SmallBitVector& unused_dims, SmallVector dim_replacements; if (num_dims_changed) { affine_map_ = mlir::compressDims(affine_map_, unused_dims); - std::vector compressed_dim_vars; + std::vector compressed_dim_vars; dim_replacements = SmallVector( num_dims_before, getAffineConstantExpr(0, mlir_context)); int64_t used_dims_count = 0; @@ -1619,8 +1570,8 @@ bool IndexingMap::CompressVars(const llvm::SmallBitVector& unused_dims, affine_map_ = mlir::compressSymbols(affine_map_, unused_symbols); symbol_replacements = SmallVector( num_symbols_before, getAffineConstantExpr(0, mlir_context)); - std::vector compressed_range_vars; - std::vector compressed_rt_vars; + std::vector compressed_range_vars; + std::vector compressed_rt_vars; MLIRContext* mlir_context = GetMLIRContext(); int64_t used_symbols_count = 0; auto range_vars_count = range_vars_.size(); @@ -1683,24 +1634,24 @@ void IndexingMap::ResetToKnownEmpty() { dim_var.bounds = Interval{0, -1}; } for (auto& range_var : range_vars_) { - range_var.range = Interval{0, -1}; + range_var.bounds = Interval{0, -1}; } constraints_.clear(); is_known_empty_ = true; - is_simplified_ = true; } bool IndexingMap::VerifyVariableIntervals() { + // TODO: Check if the variable names are unique. return llvm::all_of(dim_vars_, - [](const DimVar& dim_var) { + [](const IndexingMap::Variable& dim_var) { return dim_var.bounds.IsFeasible(); }) && llvm::all_of(range_vars_, - [](const RangeVar& range_var) { - return range_var.range.IsFeasible(); + [](const IndexingMap::Variable& range_var) { + return range_var.bounds.IsFeasible(); }) && - llvm::all_of(rt_vars_, [](const RTVar& rt_var) { - return rt_var.feasible_values.IsFeasible(); + llvm::all_of(rt_vars_, [](const IndexingMap::Variable& rt_var) { + return rt_var.bounds.IsFeasible(); }); } @@ -1813,17 +1764,19 @@ IndexingMap ComposeIndexingMaps(const IndexingMap& first, // The symbols in the composed map, i.e. combined // producer_map.compose(consumer_map) are packed as // [range_vars(second)|rt_vars(second)|range_vars(first)|rt_vars(first)]. - std::vector combined_range_vars; + std::vector combined_range_vars; combined_range_vars.reserve(second.GetRangeVarsCount() + first.GetRangeVarsCount()); - for (const RangeVar& range_var : llvm::concat( - second.GetRangeVars(), first.GetRangeVars())) { + for (const IndexingMap::Variable& range_var : + llvm::concat(second.GetRangeVars(), + first.GetRangeVars())) { combined_range_vars.push_back(range_var); } - std::vector combined_rt_vars; + std::vector combined_rt_vars; combined_rt_vars.reserve(second.GetRTVarsCount() + first.GetRTVarsCount()); - for (const RTVar& rt_var : - llvm::concat(second.GetRTVars(), first.GetRTVars())) { + for (const IndexingMap::Variable& rt_var : + llvm::concat(second.GetRTVars(), + first.GetRTVars())) { combined_rt_vars.push_back(rt_var); } // The symbols in the composed map have to be permuted to keep the invariant @@ -1921,7 +1874,7 @@ bool IndexingMap::RescaleSymbols() { symbol_expr, constant_expr * symbol_expr + shift_value, affine_map_.getNumDims(), affine_map_.getNumSymbols()); - auto& symbol_range = range_vars_[symbol_expr.getPosition()].range; + auto& symbol_range = range_vars_[symbol_expr.getPosition()].bounds; symbol_range.lower = (symbol_range.lower - shift_value) / scaling_factor; symbol_range.upper = (symbol_range.upper - shift_value) / scaling_factor; } @@ -1937,191 +1890,6 @@ bool IndexingMap::RescaleSymbols() { return !to_delete.empty(); } -// The return type of `OptimizeRTVar` below -struct RTVarOptimizationResult { - // An affine expr which maps the old RTVar to the new, optimized RTVar: - // `()[sk] -> s'k` (with k being `symbol_index` in the `OptimizeRTVar` call). - // If `expr` doesn't depend on `sk` it means the RTVar could be optimized - // away completely and the value of `rt_var` can be ignored. - AffineExpr remapped_symbol; - - // The new, optimized RTVar - RTVar rt_var; -}; - -namespace { -// Tries to optimize the given RTVar by removing some parts (or entirety) of -// the dependent HLO graph: -// -// 1. If no optimization is possible it returns `{sk, rt_var}` - the -// identity expr and the unchanged rt_var. -// -// 2. If full optimization is possible, it returns -// `{const, rt_var}` - an affine expr that does not anymore depend -// on `sk` and an arbitrary rt_var. -// -// 3. if partial optimization is possible, it returns -// `{()[sk] -> f(sk), rt_var_new }` - an affine expression that maps from the -// old RTVar to the new RTVar, and the new RTVar itself. The new RTVar now -// references some HLO subgraph of the old RTVar's HLO. -RTVarOptimizationResult OptimizeRTVar(RTVar rt_var, int64_t symbol_index, - MLIRContext* mlir_context) { - const auto symbol = getAffineSymbolExpr(symbol_index, mlir_context); - auto result_expr = symbol; - - while (true) { - if (auto constant_expr = DynCast(rt_var.hlo)) { - if (rt_var.map.isConstant()) { - const auto idx = rt_var.map.getConstantResults(); - result_expr = result_expr.replace( - symbol, getAffineConstantExpr( - constant_expr->literal().GetIntegralAsS64(idx).value(), - mlir_context)); - } - return {result_expr, rt_var}; - } - - if (auto iota_expr = DynCast(rt_var.hlo)) { - auto iota_dimension = iota_expr->iota_dimension(); - CHECK(iota_dimension < rt_var.map.getNumResults()); - return { - result_expr.replace(symbol, rt_var.map.getResults()[iota_dimension]), - rt_var}; - } - - auto is_indexing_transformation = [](const HloInstruction* instr) { - return instr->opcode() == HloOpcode::kBitcast || - instr->opcode() == HloOpcode::kBroadcast || - instr->opcode() == HloOpcode::kReshape || - instr->opcode() == HloOpcode::kReverse || - instr->opcode() == HloOpcode::kSlice || - instr->opcode() == HloOpcode::kTranspose; - }; - - if (is_indexing_transformation(rt_var.hlo)) { - auto instr_indexing_map = - GetIndexingMapForInstruction(rt_var.hlo, 0, mlir_context); - - rt_var.hlo = rt_var.hlo->operand(0); - rt_var.map = instr_indexing_map.GetAffineMap().compose(rt_var.map); - continue; - } - - if (rt_var.hlo->opcode() == HloOpcode::kNegate) { - rt_var.hlo = rt_var.hlo->operand(0); - result_expr = result_expr.replace(symbol, -symbol); - continue; - } - - if (rt_var.hlo->opcode() == HloOpcode::kAdd || - rt_var.hlo->opcode() == HloOpcode::kSubtract || - rt_var.hlo->opcode() == HloOpcode::kMultiply || - rt_var.hlo->opcode() == HloOpcode::kDivide) { - const auto apply_op = [&](const AffineExpr& lhs, - const AffineExpr& rhs) -> AffineExpr { - switch (rt_var.hlo->opcode()) { - case HloOpcode::kAdd: - return lhs + rhs; - case HloOpcode::kSubtract: - return lhs - rhs; - case HloOpcode::kMultiply: - return lhs * rhs; - case HloOpcode::kDivide: - return lhs.floorDiv(rhs); - default: - ABSL_UNREACHABLE(); - } - }; - - auto lhs = OptimizeRTVar( - RTVar{rt_var.feasible_values, rt_var.hlo->operand(0), rt_var.map}, - symbol_index, mlir_context); - - if (!lhs.remapped_symbol.isFunctionOfSymbol(symbol_index)) { - // This means that lhs is constant-like and we can eliminate the - // operand. - result_expr = - result_expr.replace(symbol, apply_op(lhs.remapped_symbol, symbol)); - - // We continue optimizing the `rhs` operand - rt_var.hlo = rt_var.hlo->operand(1); - continue; - } - - auto rhs = OptimizeRTVar( - RTVar{rt_var.feasible_values, rt_var.hlo->operand(1), rt_var.map}, - symbol_index, mlir_context); - - if (!rhs.remapped_symbol.isFunctionOfSymbol(symbol_index)) { - // This means that rhs is constant-like and we can eliminate the - // operand. - result_expr = - result_expr.replace(symbol, apply_op(symbol, rhs.remapped_symbol)); - - // We can also take advantage of the optimization already done for lhs: - result_expr = result_expr.replace(symbol, lhs.remapped_symbol); - rt_var = lhs.rt_var; - continue; - } - } - - return {result_expr, rt_var}; - } -} -} // namespace - -bool IndexingMap::ReplaceConstantRTVars() { - if (rt_vars_.empty()) return false; - - bool did_simplify = false; - - for (auto index = 0; index < rt_vars_.size(); ++index) { - auto& rt_var = rt_vars_[index]; - - // range_vars and rt_vars share the symbol space, with the rt_vars coming - // after the range_vars. - auto symbol_index = range_vars_.size() + index; - auto rt_var_symbol = getAffineSymbolExpr(symbol_index, GetMLIRContext()); - - RTVarOptimizationResult result = - OptimizeRTVar(rt_var, symbol_index, GetMLIRContext()); - - if (result.remapped_symbol != rt_var_symbol) { - did_simplify = true; - affine_map_ = affine_map_.replace( - {{rt_var_symbol, result.remapped_symbol}}, affine_map_.getNumDims(), - affine_map_.getNumSymbols()); - - llvm::DenseMap replacements; - - for (const auto& [constraint, interval] : constraints_) { - auto modified_constraint = - constraint.replace(rt_var_symbol, result.remapped_symbol); - - if (constraint == modified_constraint) continue; - replacements[constraint] = modified_constraint; - } - - for (const auto& [old_expr, new_expr] : replacements) { - auto interval = constraints_.at(old_expr); - constraints_.erase(old_expr); - constraints_[new_expr] = interval; - } - } - - if (result.remapped_symbol.isFunctionOfSymbol(symbol_index)) { - // If we still depend on the rt_var, then we update it. - if (rt_var != result.rt_var) { - rt_var = std::move(result.rt_var); - did_simplify = true; - } - } else { - did_simplify = true; - } - } - return did_simplify; -} - bool IndexingMap::IsRangeVarSymbol(mlir::AffineSymbolExpr symbol) const { unsigned int position = symbol.getPosition(); CHECK_LE(position, GetSymbolCount()); @@ -2144,7 +1912,7 @@ IndexingMap IndexingMap::ConvertSymbolsToDimensions() const { MLIRContext* mlir_context = GetMLIRContext(); int64_t num_vars = num_dims + num_symbols; - std::vector new_dim_vars; + std::vector new_dim_vars; new_dim_vars.reserve(num_vars); // // Populate the existing dims. @@ -2153,13 +1921,10 @@ IndexingMap IndexingMap::ConvertSymbolsToDimensions() const { // Capture the existing symbols as dims. SmallVector syms_replacements; int64_t symbol_id = num_dims; - for (const auto& range_var : range_vars_) { - syms_replacements.push_back(getAffineDimExpr(symbol_id++, mlir_context)); - new_dim_vars.push_back(DimVar{range_var.range}); - } - for (const auto& rt_var : rt_vars_) { + for (const IndexingMap::Variable& var : + llvm::concat(range_vars_, rt_vars_)) { syms_replacements.push_back(getAffineDimExpr(symbol_id++, mlir_context)); - new_dim_vars.push_back(DimVar{rt_var.feasible_values}); + new_dim_vars.push_back(IndexingMap::Variable{var.bounds}); } // Update constraints. @@ -2172,8 +1937,7 @@ IndexingMap IndexingMap::ConvertSymbolsToDimensions() const { AffineMap canonical_map = affine_map_.replaceDimsAndSymbols({}, syms_replacements, num_vars, 0); IndexingMap new_indexing_map(canonical_map, new_dim_vars, /*range_vars=*/{}, - /*rt_vars=*/{}, new_constraints, - /*is_simplified=*/false); + /*rt_vars=*/{}, new_constraints); return new_indexing_map; } diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.h b/third_party/xla/xla/service/gpu/model/indexing_map.h index 25d40abd47c3f1..980792c251a95e 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.h +++ b/third_party/xla/xla/service/gpu/model/indexing_map.h @@ -26,7 +26,6 @@ limitations under the License. #include #include -#include "absl/strings/str_format.h" #include "absl/types/span.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Hashing.h" @@ -37,7 +36,6 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "mlir/Support/LLVM.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/model/affine_map_printer.h" namespace xla { namespace gpu { @@ -52,17 +50,19 @@ enum class VariableKind : char { kThreadX, kThreadY, kThreadZ, + // GPU warp ID. + kWarp, + // GPU thread ID in the warp. + kWarpThread }; -std::string_view ToString(VariableKind type); -VariableKind ToVariableType(std::string_view type_name); +std::string_view ToVariableName(VariableKind var_kind); +VariableKind ToVariableType(std::string_view var_name); std::ostream& operator<<(std::ostream& out, VariableKind var_type); // Interval represents a closed interval [lower_bound, upper_bound]. struct Interval { std::string ToString() const; - void Print(std::ostream& out) const; - bool IsPoint() const { return lower == upper; } bool IsFeasible() const { return lower <= upper; } @@ -161,12 +161,9 @@ struct Interval { int64_t upper = 0; }; -std::ostream& operator<<(std::ostream& out, const Interval& range); +std::ostream& operator<<(std::ostream& out, const Interval& interval); inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, - const Interval& interval) { - os << absl::StrFormat("[%d, %d]", interval.lower, interval.upper); - return os; -} + const Interval& interval); template H AbslHashValue(H h, const Interval& range) { @@ -205,82 +202,15 @@ class RangeEvaluator { bool use_constraints_; }; -// Dimension variable represents a dimension of a tensor or a GPU grid. -// Dimensions correspond to the dimension parameter of `affine_map_`. -struct DimVar { - Interval bounds; -}; -bool operator==(const DimVar& lhs, const DimVar& rhs); -inline bool operator!=(const DimVar& lhs, const DimVar& rhs) { - return !(lhs == rhs); -} - -template -H AbslHashValue(H h, const DimVar& dimension) { - return H::combine(std::move(h), dimension.bounds); -} - -inline size_t hash_value(const DimVar& dim_var) { - return llvm::hash_combine(dim_var.bounds); -} - -// RangeSymbol variable represents a range of values, e.g. to compute a single -// element of the reduction's result we need a range of values from the input -// tensor. RangeSymbol variables correspond to the front portion of the -// symbols in `affine_map_`. -struct RangeVar { - Interval range; -}; -bool operator==(const RangeVar& lhs, const RangeVar& rhs); -inline bool operator!=(const RangeVar& lhs, const RangeVar& rhs) { - return !(lhs == rhs); -} - -template -H AbslHashValue(H h, const RangeVar& range_var) { - return H::combine(std::move(h), range_var.range); -} - -inline size_t hash_value(const RangeVar& range_var) { - return llvm::hash_combine(range_var.range); -} - -// RTSymbol variable represents a runtime symbol, e.g. a dynamic offset in -// HLO dynamic-update-slice op. RTSymbol variables correspond to the back -// portion of the symbols in `affine_map_`. -struct RTVar { - Interval feasible_values; - const HloInstruction* hlo; - // This is a map from the iteration space of the corresponding indexing map to - // the iteration space of `hlo`. It shows what element of `hlo` we need to - // extract to get the runtime value for the RTVar. - mlir::AffineMap map; -}; -bool operator==(const RTVar& lhs, const RTVar& rhs); -inline bool operator!=(const RTVar& lhs, const RTVar& rhs) { - return !(lhs == rhs); -} - -template -H AbslHashValue(H h, const RTVar& rt_var) { - llvm::hash_code map_hash = llvm::hash_combine(rt_var.map); - return H::combine(std::move(h), rt_var.feasible_values, rt_var.hlo, - static_cast(map_hash)); -} - -std::vector DimVarsFromTensorSizes( - absl::Span tensor_sizes); - -std::vector RangeVarsFromTensorSizes( - absl::Span tensor_sizes); - -// Contains an affine map with N dimension expressions and M symbols: -// (d0, ..., d_{N - 1})[s_0, ..., s_{M - 1}] -> f(d_i, s_j) -// Dimensions d_i correspond to the iteration space of the output tensor. Some -// or all of the dimensions of the input operands can be expressed as a function -// of dimensions of output. For example, for broadcasts and cwise ops all -// dimensions of the inputs are covered by the output dimensions. -// Domain specifies for what ranges of values the indexing map is specified. +// Contains an affine map with N dimension expressions and M + K symbols: +// (d0, ..., d_{N - 1})[s_0, ..., s_{M - 1}]{r_0, ..., r_{K - 1}} -> f(d_i, s_j) +// Dimensions d_i correspond to the iteration space of the output tensor. +// Symbols s_j correspond to ranges of the input dimensions. +// Runtime variables r_k correspond to the runtime variables. +// Some or all of the dimensions of the input operands can be expressed as a +// function of dimensions of output. For example, for broadcasts and cwise ops +// all dimensions of the inputs are covered by the output dimensions. Domain +// specifies for what ranges of values the indexing map is specified. // // Example: // @@ -298,17 +228,28 @@ std::vector RangeVarsFromTensorSizes( // reverse = f32[1, 17, 9, 9] reverse(%p0), dimensions={1, 2} // ``` // can be written as `(d0, d1, d2, d3) -> (d0, -d1 + 16, -d2 + 8, d3)` with -// d0 in [0, 1), d1 in [0, 16], d2 in [0, 8] and d3 in [0, 8]. +// d0 in [0, 0], d1 in [0, 16], d2 in [0, 8] and d3 in [0, 8]. class IndexingMap { public: + // Variable represents dimension, range or runtime variable. + struct Variable { + Variable() = default; + explicit Variable(Interval bounds, llvm::StringRef name = "") + : bounds(bounds), name(name) {} + Variable(int64_t lb, int64_t ub, llvm::StringRef name = "") + : Variable(Interval{lb, ub}, name) {} + + Interval bounds; + std::string name = ""; + }; + IndexingMap( - mlir::AffineMap affine_map, std::vector dimensions, - std::vector range_vars, std::vector rt_vars, - absl::Span const> constraints = {}, - bool is_simplified = false); + mlir::AffineMap affine_map, std::vector dimensions, + std::vector range_vars, std::vector rt_vars, + absl::Span const> constraints = {}); - IndexingMap(mlir::AffineMap affine_map, std::vector dimensions, - std::vector range_vars, std::vector rt_vars, + IndexingMap(mlir::AffineMap affine_map, std::vector dimensions, + std::vector range_vars, std::vector rt_vars, const llvm::DenseMap& constraints); IndexingMap(const IndexingMap&) = default; @@ -321,13 +262,10 @@ class IndexingMap { static IndexingMap FromTensorSizes( mlir::AffineMap affine_map, absl::Span dim_upper_bounds, - absl::Span symbol_upper_bounds, - bool is_simplified = false); + absl::Span symbol_upper_bounds); - std::string ToString( - const AffineMapPrinter& printer = AffineMapPrinter()) const; - - void Print(std::ostream& out, const AffineMapPrinter& printer) const; + // Returns true if the indexing map is valid. + bool Verify(std::ostream& out) const; // Returns true if the map was simplified. bool Simplify(); @@ -346,18 +284,18 @@ class IndexingMap { RangeEvaluator GetRangeEvaluator() const; // Getters for dimension vars. - const DimVar& GetDimVars(int64_t id) const { return dim_vars_[id]; } - const std::vector& GetDimVars() const { return dim_vars_; } + const Variable& GetDimVars(int64_t id) const { return dim_vars_[id]; } + const std::vector& GetDimVars() const { return dim_vars_; } int64_t GetDimVarsCount() const { return dim_vars_.size(); } // Getters for range vars. - const RangeVar& GetRangeVar(int64_t id) const { return range_vars_[id]; } - const std::vector& GetRangeVars() const { return range_vars_; } + const Variable& GetRangeVar(int64_t id) const { return range_vars_[id]; } + const std::vector& GetRangeVars() const { return range_vars_; } int64_t GetRangeVarsCount() const { return range_vars_.size(); } // Getters for runtime vars. - const RTVar& GetRTVar(int64_t id) const { return rt_vars_[id]; } - const std::vector& GetRTVars() const { return rt_vars_; } + const Variable& GetRTVar(int64_t id) const { return rt_vars_[id]; } + const std::vector& GetRTVars() const { return rt_vars_; } int64_t GetRTVarsCount() const { return rt_vars_.size(); } // Gets bounds of `affine_map_` dimensions. @@ -406,10 +344,6 @@ class IndexingMap { // satisfies both constraints. bool IsKnownEmpty() const { return is_known_empty_; } - // Returns true if the indexing map is simplified. - void SetIsSimplified(bool is_simplified) { is_simplified_ = is_simplified; } - bool IsSimplified() const { return is_simplified_; } - bool IsUndefined() const { return affine_map_ == mlir::AffineMap(); } // Removes unused symbols from the `affine_map_` and constraints. @@ -440,9 +374,9 @@ class IndexingMap { // Returns a new indexing map with all RangeVars and RTVars converted to // DimVars. // For example, - // (d0, d1, d2)[s0, s1] -> (d0, d1, d2, s0, s1) + // (d0, d1, d2)[s0, s1]{r0} -> (d0, d1, d2, s0, s1, r0) // will be converted to - // (d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4) + // (d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5) IndexingMap ConvertSymbolsToDimensions() const; private: @@ -452,10 +386,6 @@ class IndexingMap { // Returns true if simplification was performed. bool MergeModConstraints(); - // Replace RTVars that yield constants by indexing expressions. - // Returns true if simplification was performed. - bool ReplaceConstantRTVars(); - // Removes DimVars, RangeVars, RTVars that correspond to the unused dimensions // and symbols. If unused_dims is empty, then dims won't be removed. The same // applies to unused_symbols. Returns true, if anything was removed. @@ -463,8 +393,8 @@ class IndexingMap { const llvm::SmallBitVector& unused_symbols); // Resets the indexing map to the canonical "known" empty indexing map, i.e. - // (d0...)[s0...] -> (0...) affine map. Does not change the number of symbols, - // dimensions or results. + // (d0...)[s0...]{r0...} -> (0...) affine map. + // Does not change the number of symbols, dimensions or results. void ResetToKnownEmpty(); // Verify if all intervals for DimVars, RangeVars and RTVars are feasible. @@ -474,17 +404,28 @@ class IndexingMap { bool VerifyConstraintIntervals(); mlir::AffineMap affine_map_; - std::vector dim_vars_; - std::vector range_vars_; - std::vector rt_vars_; + + // Dimension variable represents a dimension of a tensor or a GPU grid. + // Dimensions correspond to the dimension parameter of `affine_map_`. + std::vector dim_vars_; + + // RangeSymbol variable represents a range of values, e.g. to compute a single + // element of the reduction's result we need a range of values from the input + // tensor. RangeSymbol variables correspond to the front portion of the + // symbols in `affine_map_`. + std::vector range_vars_; + + // RTSymbol variable represents a runtime symbol, e.g. a dynamic offset in + // HLO dynamic-update-slice op. RTSymbol variables correspond to the back + // portion of the symbols in `affine_map_`. + std::vector rt_vars_; + // Inequality constraints for affine expressions. They restrict the feasible // set for the domain of the indexing map. It contains affine expressions // other than AffineDimExpr and AffineSymbolExpr. llvm::DenseMap constraints_; // Flag to indicate that the domain is empty. bool is_known_empty_ = false; - // Flag to indicate that the indexing map is simplified. - bool is_simplified_ = false; }; std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map); bool operator==(const IndexingMap& lhs, const IndexingMap& rhs); @@ -493,25 +434,26 @@ inline bool operator!=(const IndexingMap& lhs, const IndexingMap& rhs) { } IndexingMap operator*(const IndexingMap& lhs, const IndexingMap& rhs); +bool operator==(const IndexingMap::Variable& lhs, + const IndexingMap::Variable& rhs); +inline bool operator!=(const IndexingMap::Variable& lhs, + const IndexingMap::Variable& rhs) { + return !(lhs == rhs); +} + +template +H AbslHashValue(H h, const IndexingMap::Variable& dimension) { + return H::combine(std::move(h), dimension.bounds); +} + +inline size_t hash_value(const IndexingMap::Variable& dim_var) { + return llvm::hash_combine(dim_var.bounds); +} + // Composes affine maps, i.e. second ∘ first. IndexingMap ComposeIndexingMaps(const IndexingMap& first, const IndexingMap& second); -// Prints the RTVars. -// -// This is exposed to allow SymbolicTile to reuse it. -// -// `first_rt_var_symbol_index`: The index of the symbol associated with the -// first RTVar. The RTVars will be printed with consequent symbol indices -// starting with `first_rt_var_symbol_index`. For example, if `rt_vars.size() -// == 3` and `first_rt_var_symbol_index == 4`, then the symbol names "s4", -// "s5" and "s6" will be used. -// -// TODO(b/334043862): Unexpose this function if possible. -void PrintRTVars(const std::vector& rt_vars, - int first_rt_var_symbol_index, std::ostream& out, - const AffineMapPrinter& printer); - template H AbslHashValue(H h, const IndexingMap& indexing_map) { llvm::hash_code affine_map_hash = @@ -529,6 +471,15 @@ H AbslHashValue(H h, const IndexingMap& indexing_map) { return h; } +std::vector DimVarsFromTensorSizes( + absl::Span tensor_sizes); + +std::vector DimVarsFromGPUGrid( + absl::Span grid_sizes); + +std::vector RangeVarsFromTensorSizes( + absl::Span tensor_sizes); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc b/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc new file mode 100644 index 00000000000000..5c929c243cc7f5 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/indexing_map_serialization.cc @@ -0,0 +1,922 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/indexing_map_serialization.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/AsmParser/AsmParser.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Support/LLVM.h" +#include "xla/service/gpu/model/indexing_map.h" + +namespace xla { +namespace gpu { +namespace { + +using llvm::SmallVector; +using llvm::SmallVectorImpl; +using llvm::StringRef; +using mlir::AffineBinaryOpExpr; +using mlir::AffineConstantExpr; +using mlir::AffineDimExpr; +using mlir::AffineExpr; +using mlir::AffineExprKind; +using mlir::AffineMap; +using mlir::AffineMapAttr; +using mlir::AffineSymbolExpr; +using mlir::ArrayRef; +using mlir::MLIRContext; + +enum class Delimeter { kParen, kBracket, kBrace }; + +struct Token { + enum class Kind { + // Variable name, e.g. "d0", "s1". + kVarName, + // Integer literal. + kIntLiteral, + kBoolLiteral, + // Keywords + kKeywordDomain, + kKeywordIn, + kKeywordIsSimplified, + // Arithmetic operation, e.g. "+", "-", "*", "floorDiv", "mod". + kPlus, + kMinus, + kTimes, + kFloorDiv, + kMod, + // Punctuation. + kArrow, + kLParen, // ( + kRParen, // ) + kLBracket, // [ + kRBracket, // ] + kLBrace, // { + kRBrace, // } + kComma, + kColon, + // Status. + kError, + kEOF + }; + StringRef spelling; + Token::Kind kind; +}; + +Token::Kind GetSingleCharTokenType(char c) { + switch (c) { + case '(': + return Token::Kind::kLParen; + case ')': + return Token::Kind::kRParen; + case '[': + return Token::Kind::kLBracket; + case ']': + return Token::Kind::kRBracket; + case '{': + return Token::Kind::kLBrace; + case '}': + return Token::Kind::kRBrace; + case ',': + return Token::Kind::kComma; + case ':': + return Token::Kind::kColon; + case '+': + return Token::Kind::kPlus; + case '-': + return Token::Kind::kMinus; + case '*': + return Token::Kind::kTimes; + default: + return Token::Kind::kError; + } +} + +bool IsPartOfAffineExpr(Token token) { + return token.kind == Token::Kind::kVarName || + token.kind == Token::Kind::kIntLiteral || + token.kind == Token::Kind::kPlus || + token.kind == Token::Kind::kMinus || + token.kind == Token::Kind::kTimes || + token.kind == Token::Kind::kFloorDiv || + token.kind == Token::Kind::kMod; +} + +class Parser { + public: + explicit Parser(llvm::StringRef input) : input_(input), it_(input.begin()) { + // Set the parser to the first token. + current_token_ = GetNextTokenImpl(); + } + + const Token& GetCurrentToken() const { return current_token_; }; + void Advance() { + if (current_token_.kind == Token::Kind::kError || + current_token_.kind == Token::Kind::kEOF) { + return; + } + current_token_ = GetNextTokenImpl(); + } + Token GetNextToken() { + Advance(); + return current_token_; + } + + bool ConsumeToken(Token::Kind kind); + bool ParseVarName(std::string* var_name); + bool ParseInt(int64_t* value); + bool ParseBool(bool* boolean); + bool ParseInterval(Interval* interval); + bool ParseAffineExprString(std::string* affine_expr_str); + std::pair GetDelimiterPair(Delimeter delimeter); + bool ParseCommaSeparatedVarList( + Delimeter delimeter, + llvm::function_ref parse_element_fn); + + private: + void ConsumeWhitespace() { + while (it_ != input_.end() && std::isspace(*it_)) ++it_; + } + + // Parses the next token from the input and sets the iterator to the position + // right after it. + Token GetNextTokenImpl(); + + llvm::StringRef input_; + llvm::StringRef::iterator it_; + Token current_token_; +}; + +bool Parser::ParseVarName(std::string* var_name) { + if (current_token_.kind != Token::Kind::kVarName) { + llvm::errs() << "Expected var name, got: " << current_token_.spelling + << "\n"; + return false; + } + *var_name = current_token_.spelling.str(); + Advance(); + return true; +} + +bool Parser::ParseInt(int64_t* value) { + int val; + if (current_token_.kind != Token::Kind::kIntLiteral || + current_token_.spelling.getAsInteger(/*radix=*/0, val)) { + llvm::errs() << "Expected int literal, got: " << current_token_.spelling + << "\n"; + return false; + } + *value = static_cast(val); + Advance(); + return true; +} + +bool Parser::ParseBool(bool* boolean) { + if (current_token_.kind != Token::Kind::kBoolLiteral) { + llvm::errs() << "Expected bool literal, got: " << current_token_.spelling + << "\n"; + return false; + } + *boolean = current_token_.spelling.compare("true") == 0; + Advance(); + return true; +} + +bool Parser::ParseInterval(Interval* interval) { + if (!ConsumeToken(Token::Kind::kLBracket) || !ParseInt(&interval->lower) || + !ConsumeToken(Token::Kind::kComma) || !ParseInt(&interval->upper) || + !ConsumeToken(Token::Kind::kRBracket)) { + return false; + } + return interval; +} + +bool Parser::ParseAffineExprString(std::string* affine_expr_str) { + unsigned num_unmatched_parens = 0; + while (true) { + if (IsPartOfAffineExpr(current_token_)) { + affine_expr_str->append(current_token_.spelling); + affine_expr_str->push_back(' '); + Advance(); + continue; + } + if (ConsumeToken(Token::Kind::kLParen)) { + affine_expr_str->push_back('('); + ++num_unmatched_parens; + continue; + } + if (current_token_.kind == Token::Kind::kRParen && + num_unmatched_parens > 0) { + affine_expr_str->push_back(')'); + --num_unmatched_parens; + Advance(); + continue; + } + break; + } + return current_token_.kind != Token::Kind::kError; +} + +std::pair Parser::GetDelimiterPair( + Delimeter delimeter) { + switch (delimeter) { + case Delimeter::kParen: + return {Token::Kind::kLParen, Token::Kind::kRParen}; + case Delimeter::kBracket: + return {Token::Kind::kLBracket, Token::Kind::kRBracket}; + case Delimeter::kBrace: + return {Token::Kind::kLBrace, Token::Kind::kRBrace}; + default: + llvm::errs() << "Unsupported delimiter: " << static_cast(delimeter) + << "\n"; + return {Token::Kind::kError, Token::Kind::kError}; + } +} + +bool Parser::ParseCommaSeparatedVarList( + Delimeter delimeter, + llvm::function_ref parse_element_fn) { + auto [left_delimiter, right_delimiter] = GetDelimiterPair(delimeter); + if (!ConsumeToken(left_delimiter)) { + return false; + } + if (ConsumeToken(right_delimiter)) { + return true; + } + std::string element; + while (parse_element_fn(*this)) { + if (ConsumeToken(Token::Kind::kComma)) continue; + return ConsumeToken(right_delimiter); + } + return false; +} + +bool Parser::ConsumeToken(Token::Kind kind) { + Token token = GetCurrentToken(); + if (token.kind != kind) { + return false; + } + GetNextToken(); + return true; +} + +Token Parser::GetNextTokenImpl() { + ConsumeWhitespace(); + if (it_ == input_.end()) { + return Token{"", Token::Kind::kEOF}; + } + auto start = it_; + if (std::isalpha(*it_)) { + // Variable name. + while (it_ != input_.end() && + (std::isalpha(*it_) || std::isdigit(*it_) || *it_ == '_')) { + ++it_; + } + StringRef spelling = input_.substr(start - input_.data(), it_ - start); + if (spelling == "true" || spelling == "false") { + return Token{spelling, Token::Kind::kBoolLiteral}; + } + if (spelling == "domain") { + return Token{spelling, Token::Kind::kKeywordDomain}; + } + if (spelling == "in") { + return Token{spelling, Token::Kind::kKeywordIn}; + } + if (spelling == "mod") { + return Token{spelling, Token::Kind::kMod}; + } + if (spelling == "floorDiv") { + return Token{spelling, Token::Kind::kFloorDiv}; + } + return Token{spelling, Token::Kind::kVarName}; + } + if (std::isdigit(*it_)) { + auto start = it_; + while (it_ != input_.end() && std::isdigit(*it_)) { + ++it_; + } + + StringRef spelling = input_.substr(start - input_.data(), it_ - start); + return Token{spelling, Token::Kind::kIntLiteral}; + } + if (*it_ == '-') { + ++it_; + if (it_ != input_.end()) { + if (*it_ == '>') { + ++it_; + return Token{"->", Token::Kind::kArrow}; + } else if (std::isdigit(*it_)) { + auto start = it_ - 1; + while (it_ != input_.end() && std::isdigit(*it_)) { + ++it_; + } + StringRef spelling = input_.substr(start - input_.data(), it_ - start); + return Token{spelling, Token::Kind::kIntLiteral}; + } else { + return Token{"-", Token::Kind::kMinus}; + } + } + } + StringRef spelling = input_.substr(start - input_.data(), 1); + return Token{spelling, GetSingleCharTokenType(*(it_++))}; +} + +// Parses a comma separated list of variable names. It is used to parse the +// lists of dimension and symbol variables. +bool ParseVarNames(Parser& parser, Delimeter delimeter, + SmallVectorImpl& var_names) { + auto parse_var_name_fn = [&](Parser& parser) { + std::string var_name; + if (!parser.ParseVarName(&var_name)) { + return false; + } + var_names.push_back(var_name); + return true; + }; + return parser.ParseCommaSeparatedVarList(delimeter, parse_var_name_fn); +} + +// Parses a comma separated list of affine expressions. It is used to parse +// the list of affine map results. +bool ParseAffineMapResults(Parser& parser, + SmallVectorImpl& affine_expr_strs) { + auto parse_var_name_fn = [&](Parser& parser) { + std::string affine_expr_str; + if (!parser.ParseAffineExprString(&affine_expr_str)) { + return false; + } + affine_expr_strs.push_back(affine_expr_str); + return true; + }; + return parser.ParseCommaSeparatedVarList(Delimeter::kParen, + parse_var_name_fn); +} + +// Assembles an affine map from the given dimension and symbol names and the +// affine expressions for the results. +bool ParseAffineExprsWithMLIR(ArrayRef dim_var_names, + ArrayRef symbol_var_names, + ArrayRef affine_expr_strings, + MLIRContext* context, + SmallVectorImpl& affine_exprs) { + std::stringstream ss; + ss << "affine_map<(" << absl::StrJoin(dim_var_names, ", ") << ") "; + if (!symbol_var_names.empty()) { + ss << '[' << absl::StrJoin(symbol_var_names, ", ") << "] "; + } + ss << " -> (" << absl::StrJoin(affine_expr_strings, ", ") << ")>"; + auto affine_map_attr = mlir::parseAttribute(ss.str(), context); + if (!affine_map_attr) { + llvm::errs() << "Failed to parse affine map: " << ss.str() << "\n"; + return false; + } + AffineMap affine_map = mlir::cast(affine_map_attr).getValue(); + affine_exprs = llvm::to_vector(affine_map.getResults()); + return true; +} + +std::string GetVarName(int64_t id, std::string_view name, + std::string_view prefix) { + if (!name.empty()) { + return std::string(name); + } + return absl::StrFormat("%s%d", prefix, id); +} + +std::string GetDimVarName(int64_t dim_id, std::string_view dim_name = "") { + return GetVarName(dim_id, dim_name, "d"); +} + +std::string GetRangeVarName(int64_t range_id, + std::string_view range_name = "") { + return GetVarName(range_id, range_name, "s"); +} + +std::string GetRTVarName(int64_t rt_id, std::string_view rt_name = "") { + return GetVarName(rt_id, rt_name, "rt"); +} + +std::string GetAffineSymbolName( + int64_t id, absl::Span symbol_names = {}) { + if (id < symbol_names.size()) { + const auto& name = symbol_names[id]; + if (!name.empty()) { + return name; + } + } + return absl::StrFormat("%s%d", "s", id); +} + +std::string GetAffineDimensionName( + int64_t id, absl::Span dim_names = {}) { + if (id < dim_names.size()) { + const auto& name = dim_names[id]; + if (!name.empty()) { + return name; + } + } + return absl::StrFormat("%s%d", "d", id); +} + +void PrintAffineExprImpl(const AffineExpr affine_expr, + absl::Span dim_names, + absl::Span symbol_names, + bool add_parentheses, llvm::raw_ostream& os) { + const char* binopSpelling = nullptr; + switch (affine_expr.getKind()) { + case AffineExprKind::SymbolId: { + unsigned symbol_id = + mlir::cast(affine_expr).getPosition(); + os << GetAffineSymbolName(symbol_id, symbol_names); + return; + } + case AffineExprKind::DimId: { + unsigned dim_id = mlir::cast(affine_expr).getPosition(); + os << GetAffineDimensionName(dim_id, dim_names); + return; + } + case AffineExprKind::Constant: + os << mlir::cast(affine_expr).getValue(); + return; + case AffineExprKind::Add: + binopSpelling = " + "; + break; + case AffineExprKind::Mul: + binopSpelling = " * "; + break; + case AffineExprKind::FloorDiv: + binopSpelling = " floordiv "; + break; + case AffineExprKind::CeilDiv: + binopSpelling = " ceildiv "; + break; + case AffineExprKind::Mod: + binopSpelling = " mod "; + break; + } + + auto binOp = mlir::cast(affine_expr); + AffineExpr lhsExpr = binOp.getLHS(); + AffineExpr rhsExpr = binOp.getRHS(); + + // Handle tightly binding binary operators. + if (binOp.getKind() != AffineExprKind::Add) { + if (add_parentheses) { + os << '('; + } + + // Pretty print multiplication with -1. + auto rhsConst = mlir::dyn_cast(rhsExpr); + if (rhsConst && binOp.getKind() == AffineExprKind::Mul && + rhsConst.getValue() == -1) { + os << "-"; + PrintAffineExprImpl(lhsExpr, dim_names, symbol_names, + /*add_parentheses=*/true, os); + if (add_parentheses) { + os << ')'; + } + return; + } + PrintAffineExprImpl(lhsExpr, dim_names, symbol_names, + /*add_parentheses=*/true, os); + + os << binopSpelling; + PrintAffineExprImpl(rhsExpr, dim_names, symbol_names, + /*add_parentheses=*/true, os); + + if (add_parentheses) { + os << ')'; + } + return; + } + + // Print out special "pretty" forms for add. + if (add_parentheses) { + os << '('; + } + + // Pretty print addition to a product that has a negative operand as a + // subtraction. + if (auto rhs = mlir::dyn_cast(rhsExpr)) { + if (rhs.getKind() == AffineExprKind::Mul) { + AffineExpr rrhsExpr = rhs.getRHS(); + if (auto rrhs = mlir::dyn_cast(rrhsExpr)) { + if (rrhs.getValue() == -1) { + PrintAffineExprImpl(lhsExpr, dim_names, symbol_names, + /*add_parentheses=*/false, os); + os << " - "; + if (rhs.getLHS().getKind() == AffineExprKind::Add) { + PrintAffineExprImpl(rhs.getLHS(), dim_names, symbol_names, + /*add_parentheses=*/true, os); + } else { + PrintAffineExprImpl(rhs.getLHS(), dim_names, symbol_names, + /*add_parentheses=*/false, os); + } + if (add_parentheses) { + os << ')'; + } + return; + } + + if (rrhs.getValue() < -1) { + PrintAffineExprImpl(lhsExpr, dim_names, symbol_names, + /*add_parentheses=*/false, os); + os << " - "; + PrintAffineExprImpl(rhs.getLHS(), dim_names, symbol_names, + /*add_parentheses=*/true, os); + os << " * " << -rrhs.getValue(); + if (add_parentheses) { + os << ')'; + } + return; + } + } + } + } + + // Pretty print addition to a negative number as a subtraction. + if (auto rhsConst = mlir::dyn_cast(rhsExpr)) { + if (rhsConst.getValue() < 0) { + PrintAffineExprImpl(lhsExpr, dim_names, symbol_names, + /*add_parentheses=*/false, os); + os << " - " << -rhsConst.getValue(); + if (add_parentheses) { + os << ')'; + } + return; + } + } + + PrintAffineExprImpl(lhsExpr, dim_names, symbol_names, + /*add_parentheses=*/false, os); + + os << " + "; + PrintAffineExprImpl(rhsExpr, dim_names, symbol_names, + /*add_parentheses=*/false, os); + + if (add_parentheses) { + os << ')'; + } +} + +} // namespace + +std::optional ParseIndexingMap(llvm::StringRef input, + MLIRContext* context) { + Parser parser(input); + + // Parse variable names. + SmallVector dim_var_names; + SmallVector range_var_names; + SmallVector rt_var_names; + if (!ParseVarNames(parser, Delimeter::kParen, dim_var_names) || + (parser.GetCurrentToken().kind == Token::Kind::kLBracket && + !ParseVarNames(parser, Delimeter::kBracket, range_var_names)) || + (parser.GetCurrentToken().kind == Token::Kind::kLBrace && + !ParseVarNames(parser, Delimeter::kBrace, rt_var_names))) { + llvm::errs() << "Failed to parse variable names\n"; + return std::nullopt; + } + + // Parse affine map results. + SmallVector affine_expr_strs; + if (!parser.ConsumeToken(Token::Kind::kArrow) || + !ParseAffineMapResults(parser, affine_expr_strs)) { + llvm::errs() << "Failed to parse affine map results\n"; + return std::nullopt; + } + int num_affine_map_results = affine_expr_strs.size(); + + // Special case: no domain is printed for the empty map. + if (dim_var_names.empty() && range_var_names.empty() && + rt_var_names.empty()) { + if (num_affine_map_results != 0 || + parser.GetCurrentToken().kind != Token::Kind::kEOF) { + llvm::errs() << "Expected an empty indexing map\n"; + return std::nullopt; + } + return IndexingMap{AffineMap::get(context), /*dimensions=*/{}, + /*range_vars=*/{}, /*rt_vars=*/{}}; + } + + if (!parser.ConsumeToken(Token::Kind::kComma) || + !parser.ConsumeToken(Token::Kind::kKeywordDomain) || + !parser.ConsumeToken(Token::Kind::kColon)) { + llvm::errs() << "Failed to parse domain keyword\n"; + return std::nullopt; + } + // Parse dimension variables. + std::vector dim_vars; + for (const auto& [dim_id, dim_name] : llvm::enumerate(dim_var_names)) { + std::string var_name; + Interval interval; + if (!parser.ParseVarName(&var_name) || + !parser.ConsumeToken(Token::Kind::kKeywordIn) || + !parser.ParseInterval(&interval) || + (parser.GetCurrentToken().kind != Token::Kind::kEOF && + !parser.ConsumeToken(Token::Kind::kComma))) { + llvm::errs() << "Failed to parse DimVar " << dim_name << " interval\n"; + return std::nullopt; + } + if (var_name != dim_name) { + llvm::errs() << "Dimension name mismatch " << dim_name + << " != " << var_name << "\n"; + return std::nullopt; + } + if (var_name == GetDimVarName(dim_id)) { + var_name = ""; + } + dim_vars.push_back(IndexingMap::Variable{interval, var_name}); + } + // Parse range variables. + std::vector range_vars; + for (const auto& [index, name] : llvm::enumerate(range_var_names)) { + std::string var_name; + Interval interval; + if (!parser.ParseVarName(&var_name) || + !parser.ConsumeToken(Token::Kind::kKeywordIn) || + !parser.ParseInterval(&interval) || + (parser.GetCurrentToken().kind != Token::Kind::kEOF && + !parser.ConsumeToken(Token::Kind::kComma))) { + llvm::errs() << "Failed to parse RangeVar " << name << " interval\n"; + return std::nullopt; + } + if (var_name != name) { + llvm::errs() << "Range var name mismatch " << name << " != " << var_name + << "\n"; + return std::nullopt; + } + if (var_name == GetRangeVarName(index)) { + var_name = ""; + } + range_vars.push_back(IndexingMap::Variable{interval, var_name}); + } + // Parse runtime variables. + std::vector rt_vars; + for (const auto& [index, name] : llvm::enumerate(rt_var_names)) { + std::string var_name; + Interval interval; + if (!parser.ParseVarName(&var_name) || + !parser.ConsumeToken(Token::Kind::kKeywordIn) || + !parser.ParseInterval(&interval) || + (parser.GetCurrentToken().kind != Token::Kind::kEOF && + !parser.ConsumeToken(Token::Kind::kComma))) { + llvm::errs() << "Failed to parse RuntimeVar " << name << " interval\n"; + return std::nullopt; + } + if (var_name != name) { + llvm::errs() << "Runtime var name mismatch " << name << " != " << var_name + << "\n"; + return std::nullopt; + } + if (var_name == GetRTVarName(index)) { + var_name = ""; + } + rt_vars.push_back(IndexingMap::Variable{interval, var_name}); + } + // Parse constraints. + SmallVector constraint_bounds; + while (!parser.ConsumeToken(Token::Kind::kEOF)) { + std::string affine_expr_str; + Interval interval; + if (!parser.ParseAffineExprString(&affine_expr_str) || + !parser.ConsumeToken(Token::Kind::kKeywordIn) || + !parser.ParseInterval(&interval) || + (parser.GetCurrentToken().kind != Token::Kind::kEOF && + !parser.ConsumeToken(Token::Kind::kComma))) { + llvm::errs() << "Failed to parse constraint\n"; + return std::nullopt; + } + affine_expr_strs.push_back(affine_expr_str); + constraint_bounds.push_back(interval); + } + // Parse affine expressions. + SmallVector symbol_var_names; + symbol_var_names.reserve(range_var_names.size() + rt_var_names.size()); + symbol_var_names.append(range_var_names.begin(), range_var_names.end()); + symbol_var_names.append(rt_var_names.begin(), rt_var_names.end()); + SmallVector affine_exprs; + if (!ParseAffineExprsWithMLIR(dim_var_names, symbol_var_names, + affine_expr_strs, context, affine_exprs)) { + llvm::errs() << "Failed to parse affine expressions\n"; + return std::nullopt; + } + ArrayRef affine_map_results = + ArrayRef(affine_exprs).take_front(num_affine_map_results); + ArrayRef constraint_exprs = + ArrayRef(affine_exprs).drop_front(num_affine_map_results); + + // Populate constraints. + SmallVector> constraints; + constraints.reserve(constraint_exprs.size()); + for (const auto& [expr, bounds] : + llvm::zip(constraint_exprs, constraint_bounds)) { + constraints.push_back(std::make_pair(expr, bounds)); + } + auto map = AffineMap::get(dim_vars.size(), range_vars.size() + rt_vars.size(), + affine_map_results, context); + return IndexingMap{map, std::move(dim_vars), std::move(range_vars), + std::move(rt_vars), constraints}; +} + +std::string ToString(AffineExpr affine_expr, + absl::Span dim_names, + absl::Span symbol_names) { + std::string s; + llvm::raw_string_ostream ss(s); + PrintAffineExprImpl(affine_expr, dim_names, symbol_names, + /*add_parentheses=*/false, ss); + return s; +} + +std::string ToString(AffineExpr affine_expr) { + return ToString(affine_expr, /*dim_names=*/{}, /*symbol_names=*/{}); +} + +std::ostream& operator<<(std::ostream& out, AffineExpr affine_expr) { + out << ToString(affine_expr); + return out; +} + +std::string ToString(AffineMap affine_map, + absl::Span dim_names, + absl::Span range_names, + absl::Span rt_names) { + CHECK_EQ(dim_names.size(), affine_map.getNumDims()); + CHECK_EQ(range_names.size() + rt_names.size(), affine_map.getNumSymbols()); + + std::string s; + llvm::raw_string_ostream ss(s); + + // Dimension identifiers. + ss << '(' << absl::StrJoin(dim_names, ", ") << ')'; + // Range identifiers. + if (!range_names.empty()) { + ss << '[' << absl::StrJoin(range_names, ", ") << ']'; + } + // Runtime identifiers. + if (!rt_names.empty()) { + ss << '{' << absl::StrJoin(rt_names, ", ") << '}'; + } + // Result affine expressions. + ss << " -> ("; + SmallVector symbol_names; + symbol_names.reserve(range_names.size() + rt_names.size()); + symbol_names.append(range_names.begin(), range_names.end()); + symbol_names.append(rt_names.begin(), rt_names.end()); + llvm::interleaveComma(affine_map.getResults(), ss, [&](AffineExpr expr) { + PrintAffineExprImpl(expr, dim_names, symbol_names, + /*add_parentheses=*/false, ss); + }); + ss << ')'; + return s; +} + +std::string ToString(AffineMap affine_map) { + int dim_count = affine_map.getNumDims(); + SmallVector dim_names; + dim_names.reserve(affine_map.getNumDims()); + for (int64_t dim_id = 0; dim_id < dim_count; ++dim_id) { + dim_names.push_back(GetAffineDimensionName(dim_id)); + } + int symbol_count = affine_map.getNumSymbols(); + SmallVector symbol_names; + symbol_names.reserve(affine_map.getNumSymbols()); + for (int64_t symbol_id = 0; symbol_id < symbol_count; ++symbol_id) { + symbol_names.push_back(GetAffineSymbolName(symbol_id)); + } + // AffineMap concats ranges and runtime variables and printed as + // "[dims](ranges, rt_vars)". + return ToString(affine_map, dim_names, symbol_names, {}); +} + +std::ostream& operator<<(std::ostream& out, AffineMap affine_map) { + out << ToString(affine_map); + return out; +} + +std::string ToString(const IndexingMap& indexing_map, + absl::Span dim_names, + absl::Span range_names, + absl::Span rt_names) { + std::stringstream ss; + if (indexing_map.IsKnownEmpty()) { + ss << "KNOWN EMPTY\n"; + return ss.str(); + } + const auto& dim_vars = indexing_map.GetDimVars(); + CHECK_EQ(dim_names.size(), dim_vars.size()); + const auto& range_vars = indexing_map.GetRangeVars(); + CHECK_EQ(range_names.size(), range_vars.size()); + const auto& rt_vars = indexing_map.GetRTVars(); + CHECK_EQ(rt_names.size(), rt_vars.size()); + SmallVector symbol_names; + symbol_names.reserve(range_names.size() + rt_names.size()); + symbol_names.append(range_names.begin(), range_names.end()); + symbol_names.append(rt_names.begin(), rt_names.end()); + ss << ToString(indexing_map.GetAffineMap(), dim_names, range_names, rt_names); + if (dim_vars.empty() && range_vars.empty() && rt_vars.empty()) { + return ss.str(); + } + ss << ", domain: "; + int64_t remaining_vars_to_print = + dim_vars.size() + range_vars.size() + rt_vars.size(); + for (const auto& [index, dim_var] : llvm::enumerate(dim_vars)) { + ss << dim_names[index] << " in " << dim_var.bounds; + if (--remaining_vars_to_print > 0) { + ss << ", "; + } + } + for (const auto& [index, range_var] : llvm::enumerate(range_vars)) { + ss << symbol_names[index] << " in " << range_var.bounds; + if (--remaining_vars_to_print > 0) { + ss << ", "; + } + } + for (const auto& [index, rt_var] : llvm::enumerate(rt_vars)) { + ss << rt_names[index] << " in " << rt_var.bounds; + if (--remaining_vars_to_print > 0) { + ss << ", "; + } + } + std::vector expr_range_strings; + const auto& constraints = indexing_map.GetConstraints(); + expr_range_strings.reserve(constraints.size()); + for (const auto& [expr, range] : constraints) { + expr_range_strings.push_back(absl::StrCat( + ToString(expr, dim_names, symbol_names), " in ", range.ToString())); + } + std::sort(expr_range_strings.begin(), expr_range_strings.end()); + if (!expr_range_strings.empty()) { + ss << ", " << absl::StrJoin(expr_range_strings, ", "); + } + return ss.str(); +} + +std::string ToString(const IndexingMap& indexing_map) { + // Get variable names for DimVars. + SmallVector dim_names; + dim_names.reserve(indexing_map.GetDimensionCount()); + for (const auto& [index, dim_var] : + llvm::enumerate(indexing_map.GetDimVars())) { + dim_names.push_back(GetDimVarName(index, dim_var.name)); + } + // Get variable names for RangeVars. + SmallVector range_names; + range_names.reserve(indexing_map.GetRangeVarsCount()); + for (const auto& [index, range_var] : + llvm::enumerate(indexing_map.GetRangeVars())) { + range_names.push_back(GetRangeVarName(index, range_var.name)); + } + // Get variable names for RTVars. + SmallVector rt_names; + rt_names.reserve(indexing_map.GetRTVarsCount()); + for (const auto& [index, rt_var] : + llvm::enumerate(indexing_map.GetRTVars())) { + rt_names.push_back(GetRTVarName(index, rt_var.name)); + } + return ToString(indexing_map, dim_names, range_names, rt_names); +} + +std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map) { + out << ToString(indexing_map); + return out; +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_serialization.h b/third_party/xla/xla/service/gpu/model/indexing_map_serialization.h new file mode 100644 index 00000000000000..5e077956eea67e --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/indexing_map_serialization.h @@ -0,0 +1,75 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_MODEL_INDEXING_MAP_SERIALIZATION_H_ +#define XLA_SERVICE_GPU_MODEL_INDEXING_MAP_SERIALIZATION_H_ + +#include +#include +#include + +#include "absl/types/span.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/MLIRContext.h" +#include "xla/service/gpu/model/indexing_map.h" + +namespace xla { +namespace gpu { + +// Parses the given string into an IndexingMap. +std::optional ParseIndexingMap(llvm::StringRef input, + mlir::MLIRContext* context); + +// Prints AffineExpr using the default (d0, d1, ..., s0, s1, ...) variable +// names. +std::string ToString(mlir::AffineExpr affine_expr); + +// Prints AffineExpr using the provided variable names. +std::string ToString(mlir::AffineExpr affine_expr, + absl::Span dim_names, + absl::Span symbol_names); + +std::ostream& operator<<(std::ostream& out, mlir::AffineExpr affine_expr); + +// Prints AffineMap using the default (d0, d1, ..., s0, s1, ...) variable names. +// Mixes range and runtime variables into a single symbol list. +std::string ToString(mlir::AffineMap affine_map); + +// Prints AffineMap using the provided variable names. +std::string ToString(mlir::AffineMap affine_map, + absl::Span dim_names, + absl::Span range_names, + absl::Span rt_names); + +std::ostream& operator<<(std::ostream& out, mlir::AffineMap affine_map); + +// Prints IndexingMap using the default (d0, d1, ..., s0, s1, ..., r0, r1, ...) +// variable names. +std::string ToString(const IndexingMap& indexing_map); + +// Prints IndexingMap using the provided variable names. +std::string ToString(const IndexingMap& indexing_map, + absl::Span dim_names, + absl::Span range_names, + absl::Span rt_names); + +std::ostream& operator<<(std::ostream& out, const IndexingMap& indexing_map); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_INDEXING_MAP_SERIALIZATION_H_ diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_serialization_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_serialization_test.cc new file mode 100644 index 00000000000000..fa4f89773bd3ad --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/indexing_map_serialization_test.cc @@ -0,0 +1,185 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/indexing_map_serialization.h" + +#include +#include +#include "absl/strings/string_view.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/MLIRContext.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +using ::testing::HasSubstr; + +class IndexingMapSerializationTest : public HloTestBase { + public: + mlir::MLIRContext mlir_context_; + void ParseAndCheck(absl::string_view indexing_map_str) { + auto indexing_map = ParseIndexingMap(indexing_map_str, &mlir_context_); + ASSERT_TRUE(indexing_map.has_value()); + EXPECT_THAT(ToString(*indexing_map), MatchIndexingString(indexing_map_str)); + } +}; + +TEST_F(IndexingMapSerializationTest, EmptyMap) { ParseAndCheck("() -> ()"); } + +TEST_F(IndexingMapSerializationTest, DimsOnly) { + ParseAndCheck(R"( + (d0, d1) -> (d0 mod 2 + d1), + domain: + d0 in [0, 3], + d1 in [-4, 4] + )"); +} + +TEST_F(IndexingMapSerializationTest, SymbolsOnly) { + ParseAndCheck(R"( + ()[s0, s1] -> (s0 floordiv s1), + domain: + s0 in [0, 3], + s1 in [0, 4] + )"); +} + +TEST_F(IndexingMapSerializationTest, RuntimeOnly) { + ParseAndCheck(R"( + (){r0, r1} -> (r0), + domain: + r0 in [1, 1], + r1 in [0, 1] + )"); +} + +TEST_F(IndexingMapSerializationTest, DimsAndRuntime) { + ParseAndCheck(R"( + (d0){r0, r1} -> (d0 floordiv r0 + r1), + domain: + d0 in [0, 3], + r0 in [1, 1], + r1 in [0, 1] + )"); +} + +TEST_F(IndexingMapSerializationTest, DimsAndSymbols) { + ParseAndCheck(R"( + (d0, d1)[s0, s1, s2] -> (s2, d0 + d1, s1, s0), + domain: + d0 in [0, 3], + d1 in [0, 4], + s0 in [0, 1], + s1 in [0, 1], + s2 in [0, 3] + )"); +} + +TEST_F(IndexingMapSerializationTest, DimsAndRanges) { + ParseAndCheck(R"( + (d0, d1)[s0, s1, s2] -> (s2, d0 + d1, s1, s0), + domain: + d0 in [0, 3], + d1 in [0, 4], + s0 in [0, 1], + s1 in [0, 1], + s2 in [0, 3] + )"); +} + +TEST_F(IndexingMapSerializationTest, DimsRangesAndRuntime) { + ParseAndCheck(R"( + (d0, d1)[s0, s1, s2]{r0} -> (s2, d0 + d1, s1, s0 * r0), + domain: + d0 in [0, 3], + d1 in [0, 4], + s0 in [0, 1], + s1 in [0, 1], + s2 in [0, 3], + r0 in [0, 4] + )"); +} + +TEST_F(IndexingMapSerializationTest, DimRangesRuntimeAndConstraints) { + ParseAndCheck(R"( + (d0, d1)[s0, s1, s2]{r0} -> (s2, d0 + d1, s1, s0, r0), + domain: + d0 in [0, 3], + d1 in [0, 4], + s0 in [0, 1], + s1 in [0, 1], + s2 in [0, 3], + r0 in [0, 100], + (r0 + 1) mod 5 in [1, 1], + d0 mod 4 in [0, 0], + d1 + s0 in [0, 45] + )"); +} + +TEST_F(IndexingMapSerializationTest, AffineExprsWithParens) { + ParseAndCheck(R"( + (d0, d1)[s0, s1] -> ((d0 + d0 mod 3) floordiv 3 + + s0 + (s0 * 2) mod 3 + (d0 + s0) mod 3), + domain: + d0 in [0, 9], + d1 in [0, 19], + s0 in [0, 29], + s1 in [0, 39] + )"); +} + +// This test will be updated when the printing uses types of variables. +TEST_F(IndexingMapSerializationTest, CustomNames) { + ParseAndCheck(R"( + (th_x, bl_x)[s0, vector_elem, s2]{gpu} + -> (s2, th_x * gpu + bl_x, vector_elem, s0), + domain: + th_x in [0, 3], + bl_x in [0, 4], + s0 in [0, 1], + vector_elem in [0, 1], + s2 in [0, 3], + gpu in [0, 1], + (th_x * gpu) mod 4 in [0, 0], + bl_x + s0 in [0, 45] + )"); +} + +TEST_F(IndexingMapSerializationTest, AffineMapPrinterTest) { + mlir::AffineExpr d0, d1, s0, s1, r0, r1; + mlir::bindDims(&mlir_context_, d0, d1); + mlir::bindSymbols(&mlir_context_, s0, s1, r0, r1); + + // (d0, d1)[s0, s1]{r0, r1} -> + // (d0 + d1 floordiv 8 - r0 * 64, s0 + s1 mod 16 + r1). + auto map = mlir::AffineMap::get( + 2, 4, {d0 + d1.floorDiv(8) - r0 * 64, s0 + s1 % 16 + r1}, &mlir_context_); + EXPECT_THAT(ToString(map, {"offset", "d1"}, {"s0", "linear_index"}, + {"gpu_index", "r1"}), + HasSubstr("(offset, d1)[s0, linear_index]{gpu_index, r1} -> " + "(offset + d1 floordiv 8 - gpu_index * 64, s0 + " + "linear_index mod 16 + r1)")); + EXPECT_THAT(ToString(map), + HasSubstr("(d0, d1)[s0, s1, s2, s3] -> " + "(d0 + d1 floordiv 8 - s2 * 64, s0 + s1 mod 16 + s3)")); +} +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc index c7bd056b072a5b..aa577cc34d5e05 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -26,16 +27,15 @@ limitations under the License. #include #include #include "absl/hash/hash_testing.h" -#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/MLIRContext.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/model/affine_map_printer.h" +#include "xla/hlo/testlib/verified_hlo_module.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -49,8 +49,13 @@ using ::testing::ElementsAre; class IndexingMapTest : public HloTestBase { public: + IndexingMap Parse(absl::string_view indexing_map_str) { + auto indexing_map = ParseIndexingMap(indexing_map_str, &mlir_context_); + EXPECT_TRUE(indexing_map.has_value()); + return *indexing_map; + } + mlir::MLIRContext mlir_context_; - AffineMapPrinter printer_; }; std::vector ConvertToSTL(const llvm::SmallBitVector& bit_vector) { @@ -64,58 +69,77 @@ std::vector ConvertToSTL(const llvm::SmallBitVector& bit_vector) { TEST_F(IndexingMapTest, VariableKind) { EXPECT_EQ(ToVariableType("default"), VariableKind::kDefault); - EXPECT_EQ(ToVariableType("thread_x"), VariableKind::kThreadX); - EXPECT_EQ(ToVariableType("thread_y"), VariableKind::kThreadY); - EXPECT_EQ(ToVariableType("thread_z"), VariableKind::kThreadZ); - EXPECT_EQ(ToVariableType("block_x"), VariableKind::kBlockX); - EXPECT_EQ(ToVariableType("block_y"), VariableKind::kBlockY); - EXPECT_EQ(ToVariableType("block_z"), VariableKind::kBlockZ); - - EXPECT_EQ(ToString(VariableKind::kDefault), "default"); - EXPECT_EQ(ToString(VariableKind::kThreadX), "thread_x"); - EXPECT_EQ(ToString(VariableKind::kThreadY), "thread_y"); - EXPECT_EQ(ToString(VariableKind::kThreadZ), "thread_z"); - EXPECT_EQ(ToString(VariableKind::kBlockX), "block_x"); - EXPECT_EQ(ToString(VariableKind::kBlockY), "block_y"); - EXPECT_EQ(ToString(VariableKind::kBlockZ), "block_z"); + EXPECT_EQ(ToVariableType("th_x"), VariableKind::kThreadX); + EXPECT_EQ(ToVariableType("th_y"), VariableKind::kThreadY); + EXPECT_EQ(ToVariableType("th_z"), VariableKind::kThreadZ); + EXPECT_EQ(ToVariableType("bl_x"), VariableKind::kBlockX); + EXPECT_EQ(ToVariableType("bl_y"), VariableKind::kBlockY); + EXPECT_EQ(ToVariableType("bl_z"), VariableKind::kBlockZ); + EXPECT_EQ(ToVariableType("warp"), VariableKind::kWarp); + EXPECT_EQ(ToVariableType("th_w"), VariableKind::kWarpThread); + + EXPECT_EQ(ToVariableName(VariableKind::kDefault), "default"); + EXPECT_EQ(ToVariableName(VariableKind::kThreadX), "th_x"); + EXPECT_EQ(ToVariableName(VariableKind::kThreadY), "th_y"); + EXPECT_EQ(ToVariableName(VariableKind::kThreadZ), "th_z"); + EXPECT_EQ(ToVariableName(VariableKind::kBlockX), "bl_x"); + EXPECT_EQ(ToVariableName(VariableKind::kBlockY), "bl_y"); + EXPECT_EQ(ToVariableName(VariableKind::kBlockZ), "bl_z"); + EXPECT_EQ(ToVariableName(VariableKind::kWarp), "warp"); + EXPECT_EQ(ToVariableName(VariableKind::kWarpThread), "th_w"); +} + +TEST_F(IndexingMapTest, VerifyDimensions) { + auto indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0) -> (d0)", &mlir_context_), + /*dim_upper_bounds=*/{10, 10}, /*symbol_upper_bounds=*/{}); + + std::stringstream ss; + EXPECT_FALSE(indexing_map.Verify(ss)); + EXPECT_EQ(ss.str(), + "dim size must match the number of dimensions in the affine map"); +} + +TEST_F(IndexingMapTest, VerifySymbols) { + auto indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("(d0) -> (d0)", &mlir_context_), + /*dim_upper_bounds=*/{10}, /*symbol_upper_bounds=*/{10}); + + std::stringstream ss; + EXPECT_FALSE(indexing_map.Verify(ss)); + EXPECT_EQ(ss.str(), + "range vars size + rt var size must match the number of symbols in " + "the affine map"); } TEST_F(IndexingMapTest, RTVar) { - auto zero_dim_map = AffineMap::get(&mlir_context_); - std::vector rt_vars{RTVar{Interval{0, 2}, - /*instr=*/nullptr, zero_dim_map}, - RTVar({Interval{0, 7}, - /*instr=*/nullptr, zero_dim_map})}; - IndexingMap indexing_map( - ParseAffineMap("(d0, d1)[s0, s1, s2] -> (d1, d0, s0 + s1, s1)", + ParseAffineMap("(d0, d1)[range, rt0, rt1] -> (d1, d0, range + rt0, rt1)", &mlir_context_), - {DimVar{{0, 99}}, DimVar{{0, 43}}}, {RangeVar{{-99, 99}}}, - std::move(rt_vars)); - printer_.SetSymbolName(0, "range"); - printer_.SetSymbolName(1, "rt_0"); - printer_.SetSymbolName(2, "rt_1"); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( - (d0, d1)[range, rt_0, rt_1] -> (d1, d0, range + rt_0, rt_0), + {IndexingMap::Variable{0, 99, "d0"}, IndexingMap::Variable{0, 43, "d1"}}, + {IndexingMap::Variable{-99, 99, "range"}}, + {IndexingMap::Variable{Interval{0, 2}}, + IndexingMap::Variable({Interval{0, 7}})}); + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( + (d0, d1)[range]{rt0, rt1} -> (d1, d0, range + rt0, rt1), domain: d0 in [0, 99], d1 in [0, 43], range in [-99, 99], - rt_0 in [0, 2], - hlo: NULL, - () -> (), - rt_1 in [0, 7], - hlo: NULL, - () -> (), - is_simplified: false + rt0 in [0, 2], + rt1 in [0, 7] )")); } TEST_F(IndexingMapTest, Evaluation) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", &mlir_context_), - {4, 4}, {2, 2}); - + IndexingMap indexing_map = Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 3], + d1 in [0, 3], + s0 in [0, 1], + s1 in [0, 1] + )"); auto results = indexing_map.Evaluate( mlir::getAffineConstantExprs({1, 2}, &mlir_context_), mlir::getAffineConstantExprs({3, 4}, &mlir_context_)); @@ -136,13 +160,20 @@ TEST_F(IndexingMapTest, Evaluation) { } TEST_F(IndexingMapTest, Composition_Permutation) { - IndexingMap producer = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", &mlir_context_), - {4, 4}, {2, 2}); - - IndexingMap consumer = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), {4}, {4}); - + IndexingMap producer = Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 3], + d1 in [0, 3], + s0 in [0, 1], + s1 in [0, 1] + )"); + IndexingMap consumer = Parse(R"( + (d0)[s0] -> (d0, s0), + domain: + d0 in [0, 3], + s0 in [0, 3] + )"); auto composed = ComposeIndexingMaps(consumer, producer); EXPECT_THAT(composed, MatchIndexingMap(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0), @@ -150,18 +181,26 @@ TEST_F(IndexingMapTest, Composition_Permutation) { d0 in [0, 3], s0 in [0, 1], s1 in [0, 1], - s2 in [0, 3], - is_simplified: false + s2 in [0, 3] )")); } TEST_F(IndexingMapTest, Composition_RestrictedInterval) { - IndexingMap producer = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", &mlir_context_), - {5, 6}, {7, 2}); - - IndexingMap consumer = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), {10}, {8}); + IndexingMap producer = Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 4], + d1 in [0, 5], + s0 in [0, 6], + s1 in [0, 1] + )"); + + IndexingMap consumer = Parse(R"( + (d0)[s0] -> (d0, s0), + domain: + d0 in [0, 9], + s0 in [0, 7] + )"); auto composed = ComposeIndexingMaps(consumer, producer); EXPECT_THAT(composed, MatchIndexingMap(R"( @@ -170,26 +209,30 @@ TEST_F(IndexingMapTest, Composition_RestrictedInterval) { d0 in [0, 4], s0 in [0, 6], s1 in [0, 1], - s2 in [0, 5], - is_simplified: false + s2 in [0, 5] )")); } TEST_F(IndexingMapTest, Composition_ProducerAndConsumerHaveConstraints) { - IndexingMap producer = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", &mlir_context_), - {50, 60}, {70, 20}); - producer.AddConstraint(ParseAffineExpr("d0 mod 8", &mlir_context_), - Interval{0, 0}); - producer.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), - Interval{1, 1}); - - IndexingMap consumer = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), {10}, {8}); - consumer.AddConstraint(ParseAffineExpr("d0 + s0", &mlir_context_), - Interval{0, 20}); - consumer.AddConstraint(ParseAffineExpr("s0 mod 4", &mlir_context_), - Interval{0, 0}); + IndexingMap producer = Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 19], + d0 mod 8 in [0, 0], + s0 mod 3 in [1, 1] + )"); + + IndexingMap consumer = Parse(R"( + (d0)[s0] -> (d0, s0), + domain: + d0 in [0, 9], + s0 in [0, 7], + d0 + s0 in [0, 20], + s0 mod 4 in [0, 0] + )"); auto composed = ComposeIndexingMaps(consumer, producer); EXPECT_THAT(composed, MatchIndexingMap(R"( @@ -202,8 +245,7 @@ TEST_F(IndexingMapTest, Composition_ProducerAndConsumerHaveConstraints) { d0 + s2 in [0, 20], d0 mod 8 in [0, 0], s0 mod 3 in [1, 1], - s2 mod 4 in [0, 0], - is_simplified: false + s2 mod 4 in [0, 0] )")); EXPECT_TRUE(composed.Simplify()); EXPECT_THAT(composed, MatchIndexingMap(R"( @@ -215,110 +257,86 @@ TEST_F(IndexingMapTest, Composition_ProducerAndConsumerHaveConstraints) { s2 in [0, 4], d0 mod 8 in [0, 0], s0 mod 3 in [1, 1], - s2 mod 4 in [0, 0], - is_simplified: true + s2 mod 4 in [0, 0] )")); } TEST_F(IndexingMapTest, Composition_RTVar) { - auto zero_dim_map = AffineMap::get(&mlir_context_); - std::vector rt_vars{ - RTVar{Interval{0, 0}, - /*instr=*/nullptr, zero_dim_map}, - RTVar({Interval{0, 1}, /*instr=*/nullptr, zero_dim_map}), - RTVar({Interval{0, 226}, /*instr=*/nullptr, zero_dim_map})}; + std::vector rt_vars{ + IndexingMap::Variable{Interval{0, 0}}, + IndexingMap::Variable({Interval{0, 1}}), + IndexingMap::Variable({Interval{0, 226}})}; IndexingMap producer( - ParseAffineMap("(d0, d1, d2)[s0, s1, s2] -> (d0 + s0, d1 + s1, d2 + s2)", - &mlir_context_), - {DimVar{{0, 0}}, DimVar{{0, 1}}, DimVar{{0, 226}}}, {}, - std::move(rt_vars)); + ParseAffineMap( + "(d0, d1, d2)[rt0, rt1, rt2] -> (d0 + rt0, d1 + rt1, d2 + rt2)", + &mlir_context_), + {IndexingMap::Variable{{0, 0}}, IndexingMap::Variable{{0, 1}}, + IndexingMap::Variable{{0, 226}}}, + {}, std::move(rt_vars)); IndexingMap consumer( - ParseAffineMap("(d0, d1)[s0] -> (0, d1, s0)", &mlir_context_), - {DimVar{{0, 0}}, DimVar{{0, 1}}}, {RangeVar{0, 31}}, {}); - printer_.SetSymbolName(0, "s"); - printer_.SetSymbolName(1, "rt_0"); - printer_.SetSymbolName(2, "rt_1"); - printer_.SetSymbolName(3, "rt_2"); + ParseAffineMap("(d0, d1)[s] -> (0, d1, s)", &mlir_context_), + {IndexingMap::Variable{0, 0}, IndexingMap::Variable{0, 1}}, + {IndexingMap::Variable{0, 31, "s"}}, {}); auto composed = ComposeIndexingMaps(consumer, producer); - EXPECT_THAT(composed.ToString(printer_), MatchIndexingString(R"( - (d0, d1)[s, rt_0, rt_1, rt_2] -> (rt_0, d1 + rt_1, s + rt_2), + EXPECT_THAT(ToString(composed), MatchIndexingString(R"( + (d0, d1)[s]{rt0, rt1, rt2} -> (rt0, d1 + rt1, s + rt2), domain: d0 in [0, 0], d1 in [0, 1], s in [0, 31], - rt_0 in [0, 0], - hlo: NULL, - () -> (), - rt_1 in [0, 1], - hlo: NULL, - () -> (), - rt_2 in [0, 226], - hlo: NULL, - () -> (), - is_simplified: false + rt0 in [0, 0], + rt1 in [0, 1], + rt2 in [0, 226] )")); } TEST_F(IndexingMapTest, Composition_OnlyRTVars) { - auto zero_dim_map = AffineMap::get(&mlir_context_); - IndexingMap producer( ParseAffineMap("(d0, d1)[s0, s1] -> (d0 + s0, d1 + 4 * s1)", &mlir_context_), - {DimVar{{0, 24}}, DimVar{{0, 15}}}, {}, - {RTVar({Interval{0, 2}, /*instr=*/nullptr, zero_dim_map}), - RTVar({Interval{0, 1}, /*instr=*/nullptr, zero_dim_map})}); + {IndexingMap::Variable{0, 24}, IndexingMap::Variable{0, 15}}, {}, + {IndexingMap::Variable{Interval{0, 2}, "ps_0"}, + IndexingMap::Variable{Interval{0, 1}, "ps_1"}}); - std::vector consumer_rt_vars; + std::vector consumer_rt_vars; IndexingMap consumer( ParseAffineMap("(d0, d1)[s0, s1] -> (d0 + 2 * s0, d1 + 3 * s1)", &mlir_context_), - {DimVar{{0, 24}}, DimVar{{0, 15}}}, {}, - {RTVar({Interval{0, 25}, /*instr=*/nullptr, zero_dim_map}), - RTVar({Interval{0, 16}, /*instr=*/nullptr, zero_dim_map})}); - - printer_.SetSymbolName(0, "ps_0"); - printer_.SetSymbolName(1, "ps_1"); - printer_.SetSymbolName(2, "cs_0"); - printer_.SetSymbolName(3, "cs_1"); + {IndexingMap::Variable{0, 24}, IndexingMap::Variable{0, 15}}, {}, + {IndexingMap::Variable{Interval{0, 25}, "cs_0"}, + IndexingMap::Variable{Interval{0, 16}, "cs_1"}}); auto composed = ComposeIndexingMaps(consumer, producer); - EXPECT_THAT(composed.ToString(printer_), MatchIndexingString(R"( - (d0, d1)[ps_0, ps_1, cs_0, cs_1] -> + EXPECT_THAT(ToString(composed), MatchIndexingString(R"( + (d0, d1){ps_0, ps_1, cs_0, cs_1} -> (d0 + cs_0 * 2 + ps_0, d1 + cs_1 * 3 + ps_1 * 4), domain: d0 in [0, 24], d1 in [0, 15], ps_0 in [0, 2], - hlo: NULL, - () -> (), ps_1 in [0, 1], - hlo: NULL, - () -> (), cs_0 in [0, 25], - hlo: NULL, - () -> (), cs_1 in [0, 16], - hlo: NULL, - () -> (), d0 + cs_0 * 2 in [0, 24], - d1 + cs_1 * 3 in [0, 15], - is_simplified: false + d1 + cs_1 * 3 in [0, 15] )")); } TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintUsesDim) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, s0, s1)", &mlir_context_), - {50, 60}, {70, 20}); // This constraint cannot be removed, because it contains a dimension. - indexing_map.AddConstraint(ParseAffineExpr("s0 + d0", &mlir_context_), - Interval{1, 100}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), - Interval{0, 0}); + auto indexing_map = Parse(R"( + (d0, d1)[s0, s1] -> (d1, s0, s1), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 19], + d0 + s0 in [1, 100], + s0 mod 3 in [0, 0] + )"); indexing_map.RemoveUnusedVars(); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (d1, s0, s1), @@ -328,59 +346,69 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintUsesDim) { s0 in [0, 69], s1 in [0, 19], d0 + s0 in [1, 100], - s0 mod 3 in [0, 0], - is_simplified: false + s0 mod 3 in [0, 0] )")); } TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintUsesUnusedDim) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (s0, d1, s1)", &mlir_context_), - {50, 60}, {70, 20}); // This constraint can be removed, because it contains only the unused dim. - indexing_map.AddConstraint(ParseAffineExpr("d0 mod 3", &mlir_context_), - Interval{0, 0}); + auto indexing_map = Parse(R"( + (d0, d1)[s0, s1] -> (s0, d1, s1), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 19], + d0 mod 3 in [0, 0] + )"); indexing_map.RemoveUnusedVars(); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0)[s0, s1] -> (s0, d0, s1), domain: d0 in [0, 59], s0 in [0, 69], - s1 in [0, 19], - is_simplified: false + s1 in [0, 19] )")); } TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSym) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d0, d1, s1)", &mlir_context_), - {50, 60}, {70, 20}); // This constraint can be removed, because it contains only the unused symbol. - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), - Interval{0, 0}); + auto indexing_map = Parse(R"( + (d0, d1)[s0, s1] -> (d0, d1, s1), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 19], + s0 mod 3 in [0, 0] + )"); indexing_map.RemoveUnusedSymbols(); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0] -> (d0, d1, s0), domain: d0 in [0, 49], d1 in [0, 59], - s0 in [0, 19], - is_simplified: false + s0 in [0, 19] )")); } TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintsWithManyDims) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap( - "(d0, d1, d2, d3, d4)[s0, s1, s2] -> (s0 * 4 + d1 + d3 - 42)", - &mlir_context_), - {1, 2, 3, 4, 5}, {32, 64, 96}); - indexing_map.AddConstraint( - ParseAffineExpr("s0 * 4 + d1 + d3", &mlir_context_), Interval{24, 459}); - indexing_map.AddConstraint(ParseAffineExpr("s0 + s2", &mlir_context_), - Interval{0, 512}); - auto unused_vars = indexing_map.RemoveUnusedVars(); + auto indexing_map = Parse(R"( + (d0, d1, d2, d3, d4)[s0, s1, s2] -> (s0 * 4 + d1 + d3 - 42), + domain: + d0 in [0, 0], + d1 in [0, 1], + d2 in [0, 2], + d3 in [0, 3], + d4 in [0, 4], + s0 in [0, 31], + s1 in [0, 63], + s2 in [0, 95], + s0 * 4 + d1 + d3 in [24, 459], + s0 + s2 in [0, 512] + )"); // dimensions d0, d2, d4 and symbol s1 will be removed. + auto unused_vars = indexing_map.RemoveUnusedVars(); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (d0 + s0 * 4 + d1 - 42), domain: @@ -389,8 +417,7 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintsWithManyDims) { s0 in [0, 31], s1 in [0, 95], d0 + s0 * 4 + d1 in [24, 459], - s0 + s1 in [0, 512], - is_simplified: false + s0 + s1 in [0, 512] )")); EXPECT_THAT(ConvertToSTL(unused_vars), ::testing::ElementsAreArray( @@ -398,14 +425,17 @@ TEST_F(IndexingMapTest, RemoveUnusedVars_ConstraintsWithManyDims) { } TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesSymbol) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1)", &mlir_context_), - {50, 60}, {70, 20}); + auto indexing_map = Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 19], + s0 + s1 in [1, 100], + s0 mod 3 in [0, 0] + )"); // This constraint cannot be removed, because it contains a "used symbol". - indexing_map.AddConstraint(ParseAffineExpr("s0 + s1", &mlir_context_), - Interval{1, 100}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), - Interval{0, 0}); indexing_map.RemoveUnusedSymbols(); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0, s1] -> (d1, d0, s1), @@ -415,62 +445,68 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesSymbol) { s0 in [0, 69], s1 in [0, 19], s0 + s1 in [1, 100], - s0 mod 3 in [0, 0], - is_simplified: false + s0 mod 3 in [0, 0] )")); } TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintUsesOnlyUnusedSymbols) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1)", &mlir_context_), - {50, 60}, {70, 20}); + auto indexing_map = Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 19], + s0 mod 3 in [0, 0] + )"); // This constraint can be removed, because it contains only the unused symbol. - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), - Interval{0, 0}); indexing_map.RemoveUnusedSymbols(); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0, d1)[s0] -> (d1, d0, s0), domain: d0 in [0, 49], d1 in [0, 59], - s0 in [0, 19], - is_simplified: false + s0 in [0, 19] )")); } TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintIsAConstantWithinRange) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {50}, {}); - indexing_map.AddConstraint(ParseAffineExpr("0", &mlir_context_), - Interval{-10, 5}); + auto indexing_map = Parse(R"( + (d0) -> (d0), + domain: + d0 in [0, 49], + 0 in [-10, 5] + )"); EXPECT_THAT(indexing_map, MatchIndexingMap(R"( (d0) -> (d0), domain: - d0 in [0, 49], - is_simplified: false + d0 in [0, 49] )")); } TEST_F(IndexingMapTest, KnownEmpty_CreatingIndexingMapWithInfeasibleRange) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {-1}, {}); + auto indexing_map = Parse(R"( + (d0) -> (d0), + domain: + d0 in [0, -2] + )"); EXPECT_THAT(indexing_map, MatchIndexingMap("KNOWN EMPTY")); } TEST_F(IndexingMapTest, KnownEmpty_AddingConstraintOutOfRange) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {50}, {}); + auto indexing_map = Parse(R"( + (d0) -> (d0), + domain: + d0 in [0, 49], + 0 in [10, 15] + )"); // Addition of this constraint makes the domain empty. - indexing_map.AddConstraint(ParseAffineExpr("0", &mlir_context_), - Interval{10, 15}); EXPECT_THAT(indexing_map, MatchIndexingMap("KNOWN EMPTY")); } TEST_F(IndexingMapTest, KnownEmpty_Composition) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {50}, {}); - IndexingMap known_empty = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (0)", &mlir_context_), {0}, {}); + auto indexing_map = Parse("(d0) -> (d0), domain: d0 in [0, 49]"); + auto known_empty = Parse("(d0) -> (d0), domain: d0 in [0, -1]"); EXPECT_THAT(known_empty, MatchIndexingMap("KNOWN EMPTY")); EXPECT_THAT(indexing_map * known_empty, MatchIndexingMap("KNOWN EMPTY")); EXPECT_THAT(known_empty * indexing_map, MatchIndexingMap("KNOWN EMPTY")); @@ -480,22 +516,31 @@ TEST_F(IndexingMapTest, KnownEmpty_Composition) { TEST_F(IndexingMapTest, KnownEmpty_AddingConstraintOutOfRangeAfterSimplification) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1)", &mlir_context_), - {50, 60}, {70, 20}); - indexing_map.AddConstraint(ParseAffineExpr("s1 floordiv 20", &mlir_context_), - Interval{2, 2}); + auto indexing_map = Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 19], + s1 floordiv 20 in [2, 2] + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_THAT(indexing_map, MatchIndexingMap("KNOWN EMPTY")); } TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithManySymbols) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0, s1, s2, s3, s4] -> (d0 * 4 + s1 + s3 - 42)", - &mlir_context_), - {32}, {1, 2, 3, 4, 5}); - indexing_map.AddConstraint( - ParseAffineExpr("d0 * 4 + s1 + s3", &mlir_context_), Interval{24, 459}); + auto indexing_map = Parse(R"( + (d0)[s0, s1, s2, s3, s4] -> (d0 * 4 + s1 + s3 - 42), + domain: + d0 in [0, 31], + s0 in [0, 0], + s1 in [0, 1], + s2 in [0, 2], + s3 in [0, 3], + s4 in [0, 4], + d0 * 4 + s1 + s3 in [24, 459] + )"); indexing_map.RemoveUnusedSymbols(); // Symbols s0, s2, s4 will be removed and s1 and s3 will become s0 and s1. EXPECT_THAT(indexing_map, MatchIndexingMap(R"( @@ -504,47 +549,42 @@ TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithManySymbols) { d0 in [0, 31], s0 in [0, 1], s1 in [0, 3], - d0 * 4 + s0 + s1 in [24, 459], - is_simplified: false + d0 * 4 + s0 + s1 in [24, 459] )")); } TEST_F(IndexingMapTest, RemoveUnusedSymbols_ConstraintsWithRTVars) { - auto zero_dim_map = AffineMap::get(&mlir_context_); IndexingMap indexing_map( ParseAffineMap("(d0)[s0, s1, s2, s3, s4] -> (d0 * 4 + s1 + s3 - 42)", &mlir_context_), - {DimVar{{0, 31}}}, {RangeVar{{0, 0}}, RangeVar{{0, 1}}, RangeVar{{0, 2}}}, - {RTVar{Interval{0, 3}, - /*instr=*/nullptr, zero_dim_map}, - RTVar{Interval{0, 4}, - /*instr=*/nullptr, zero_dim_map}}); + {IndexingMap::Variable{{0, 31}}}, + {IndexingMap::Variable{{0, 0}}, IndexingMap::Variable{{0, 1}}, + IndexingMap::Variable{{0, 2}}}, + {IndexingMap::Variable{Interval{0, 3}}, + IndexingMap::Variable{Interval{0, 4}}}); indexing_map.AddConstraint( ParseAffineExpr("d0 * 4 + s1 + s3", &mlir_context_), Interval{24, 459}); indexing_map.RemoveUnusedSymbols(); // Symbols s0, s2, s4 will be removed and s1 and s3 will become s0 and s1. EXPECT_THAT(indexing_map, MatchIndexingMap(R"( - (d0)[s0, s1] -> (d0 * 4 + s0 + s1 - 42), + (d0)[s0]{rt0} -> (d0 * 4 + s0 + rt0 - 42), domain: d0 in [0, 31], s0 in [0, 1], - s1 in [0, 3], - hlo: NULL, - () -> (), - d0 * 4 + s0 + s1 in [24, 459], - is_simplified: false + rt0 in [0, 3], + d0 * 4 + s0 + rt0 in [24, 459] )")); }; TEST_F(IndexingMapTest, ConvertSymbolsToDimensions) { - auto zero_dim_map = AffineMap::get(&mlir_context_); IndexingMap indexing_map( ParseAffineMap( "(d0)[s0, s1, s2, s3] -> (d0 * 4 + s0 + s1 + 2 * s2 + 3 * s3 - 42)", &mlir_context_), - {DimVar{{0, 31}}}, {RangeVar{{0, 0}}, RangeVar{{0, 1}}}, - {RTVar{Interval{0, 3}, /*instr=*/nullptr, zero_dim_map}, - RTVar{Interval{0, 4}, /*instr=*/nullptr, zero_dim_map}}); + {IndexingMap::Variable{{0, 31}}}, + {IndexingMap::Variable{{0, 0}}, IndexingMap::Variable{{0, 1}}}, + {IndexingMap::Variable{Interval{0, 3}}, + IndexingMap::Variable{Interval{0, 4}}}); indexing_map.AddConstraint( ParseAffineExpr("d0 * 4 + s0 + 2 * s2", &mlir_context_), Interval{24, 459}); @@ -556,196 +596,194 @@ TEST_F(IndexingMapTest, ConvertSymbolsToDimensions) { d2 in [0, 1], d3 in [0, 3], d4 in [0, 4], - d0 * 4 + d1 + d3 * 2 in [24, 459], - is_simplified: false + d0 * 4 + d1 + d3 * 2 in [24, 459] )")); } TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {100}, {}); - - indexing_map.AddConstraint(ParseAffineExpr("(d0 mod 8) + 5", &mlir_context_), - Interval{50, 54}); + auto indexing_map = Parse(R"( + (d0) -> (d0), + domain: + d0 in [0, 99], + d0 mod 8 + 5 in [50, 54] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0), domain: d0 in [0, 99], - d0 mod 8 in [45, 49], - is_simplified: true + d0 mod 8 in [45, 49] )")); } TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum_IndependentOfSymbol) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0, s1] -> (d0 * 6 + s0 * 3 + s1)", &mlir_context_), - {2000}, {2, 3}); - - indexing_map.AddConstraint( - ParseAffineExpr("d0 * 6 + s0 * 3 + s1", &mlir_context_), - Interval{0, 599}); + auto indexing_map = Parse(R"( + (d0)[s0, s1] -> (d0 * 6 + s0 * 3 + s1), + domain: + d0 in [0, 1999], + s0 in [0, 1], + s1 in [0, 2], + d0 * 6 + s0 * 3 + s1 in [0, 599] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0, s1] -> (d0 * 6 + s0 * 3 + s1), domain: d0 in [0, 99], s0 in [0, 1], - s1 in [0, 2], - is_simplified: true + s1 in [0, 2] )")); } TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum_NotIndependentOfSymbol) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0, s1] -> (d0 * 6 + s0 * 3 + s1)", &mlir_context_), - {2000}, {2, 3}); - - indexing_map.AddConstraint( - ParseAffineExpr("d0 * 6 + s0 * 3 + s1", &mlir_context_), - Interval{0, 598}); + auto indexing_map = Parse(R"( + (d0)[s0, s1] -> (d0 * 6 + s0 * 3 + s1), + domain: + d0 in [0, 1999], + s0 in [0, 1], + s1 in [0, 2], + d0 * 6 + s0 * 3 + s1 in [0, 598] + )"); EXPECT_FALSE(indexing_map.Simplify()); } TEST_F(IndexingMapTest, ConstraintIntervalSimplification_Sum_GcdGreaterOne) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0] -> (d0 * 6 + s0 * 3)", &mlir_context_), {2000}, - {2}); - - indexing_map.AddConstraint(ParseAffineExpr("d0 * 6 + s0 * 3", &mlir_context_), - Interval{0, 599}); + auto indexing_map = Parse(R"( + (d0)[s0] -> (d0 * 6 + s0 * 3), + domain: + d0 in [0, 1999], + s0 in [0, 1], + d0 * 6 + s0 * 3 in [0, 599] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0 * 6 + s0 * 3), domain: d0 in [0, 99], - s0 in [0, 1], - is_simplified: true + s0 in [0, 1] )")); } TEST_F(IndexingMapTest, ConstraintIntervalSimplification_FloorDivPositiveDivisorPositiveBounds) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {100}, {}); - - indexing_map.AddConstraint(ParseAffineExpr("d0 floordiv 8", &mlir_context_), - Interval{5, 11}); + auto indexing_map = Parse(R"( + (d0) -> (d0), + domain: + d0 in [0, 99], + d0 floordiv 8 in [5, 11] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0), domain: - d0 in [40, 95], - is_simplified: true + d0 in [40, 95] )")); } TEST_F(IndexingMapTest, ConstraintIntervalSimplification_FloorDivPositiveDivisorNegativeBounds) { - IndexingMap indexing_map = - IndexingMap(ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), - {DimVar{{0, 99}}}, {RangeVar{{-99, 99}}}, /*rt_vars=*/{}); - - indexing_map.AddConstraint(ParseAffineExpr("s0 floordiv 3", &mlir_context_), - Interval{-11, -5}); + auto indexing_map = Parse(R"( + (d0)[s0] -> (d0), + domain: + d0 in [0, 99], + s0 in [-99, 99], + s0 floordiv 3 in [-11, -5] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0), domain: d0 in [0, 99], - s0 in [-33, -13], - is_simplified: true + s0 in [-33, -13] )")); } TEST_F(IndexingMapTest, ConstraintIntervalSimplification_FloorDivNegativeDivisorNegativeBounds) { - IndexingMap indexing_map = - IndexingMap(ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), - {DimVar{{0, 99}}}, {RangeVar{{-99, 99}}}, /*rt_vars=*/{}); - - indexing_map.AddConstraint(ParseAffineExpr("s0 floordiv -3", &mlir_context_), - Interval{-11, -5}); + auto indexing_map = Parse(R"( + (d0)[s0] -> (d0), + domain: + d0 in [0, 99], + s0 in [-99, 99], + s0 floordiv -3 in [-11, -5] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0), domain: d0 in [0, 99], - s0 in [15, 35], - is_simplified: true + s0 in [15, 35] )")); } TEST_F(IndexingMapTest, ConstraintIntervalSimplification_MulPositiveMultiplierPositiveBounds) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0) -> (d0)", &mlir_context_), {100}, {}); - - indexing_map.AddConstraint(ParseAffineExpr("d0 * 8", &mlir_context_), - Interval{14, 33}); + auto indexing_map = Parse(R"( + (d0) -> (d0), + domain: + d0 in [0, 99], + d0 * 8 in [14, 33] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0), domain: - d0 in [2, 4], - is_simplified: true + d0 in [2, 4] )")); } TEST_F(IndexingMapTest, ConstraintIntervalSimplification_MulPositiveMultiplierNegativeBounds) { - IndexingMap indexing_map = - IndexingMap(ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), - {DimVar{{0, 99}}}, {RangeVar{{-99, 99}}}, /*rt_vars=*/{}); - - indexing_map.AddConstraint(ParseAffineExpr("s0 * 3", &mlir_context_), - Interval{-11, -5}); + auto indexing_map = Parse(R"( + (d0)[s0] -> (d0), + domain: + d0 in [0, 99], + s0 in [-99, 99], + s0 * 3 in [-11, -5] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0), domain: d0 in [0, 99], - s0 in [-3, -2], - is_simplified: true + s0 in [-3, -2] )")); } TEST_F(IndexingMapTest, ConstraintIntervalSimplification_MulNegativeMultiplierNegativeBounds) { - IndexingMap indexing_map = - IndexingMap(ParseAffineMap("(d0)[s0] -> (d0)", &mlir_context_), - {DimVar{{0, 99}}}, {RangeVar{{-99, 99}}}, /*rt_vars=*/{}); - - indexing_map.AddConstraint(ParseAffineExpr("s0 * -3", &mlir_context_), - Interval{-11, -5}); + auto indexing_map = Parse(R"( + (d0)[s0] -> (d0), + domain: + d0 in [0, 99], + s0 in [-99, 99], + s0 * -3 in [-11, -5] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0), domain: d0 in [0, 99], - s0 in [2, 3], - is_simplified: true + s0 in [2, 3] )")); } TEST_F(IndexingMapTest, ConstraintMerge_Mod) { - IndexingMap indexing_map( - ParseAffineMap("(d0)[s0, s1] -> (d0, s1, s0)", &mlir_context_), - {DimVar{{0, 4}}}, {RangeVar{{-21, -1}}, RangeVar{{0, 10}}}, - /*rt_vars=*/{}); - indexing_map.AddConstraint(ParseAffineExpr("d0 mod 3", &mlir_context_), - Interval{0, 0}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 2", &mlir_context_), - Interval{0, 0}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), - Interval{0, 0}); - indexing_map.AddConstraint(ParseAffineExpr("s1 mod 5", &mlir_context_), - Interval{1, 1}); + auto indexing_map = Parse(R"( + (d0)[s0, s1] -> (d0, s1, s0), + domain: + d0 in [0, 3], + s0 in [-21, -2], + s1 in [0, 10], + d0 mod 3 in [0, 0], + s0 mod 2 in [0, 0], + s0 mod 3 in [0, 0], + s1 mod 5 in [1, 1] + )"); EXPECT_TRUE(indexing_map.Simplify()); - - EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0, s1] -> (d0, s1, s0), domain: d0 in [0, 3], @@ -753,32 +791,36 @@ TEST_F(IndexingMapTest, ConstraintMerge_Mod) { s1 in [1, 6], d0 mod 3 in [0, 0], s0 mod 6 in [0, 0], - s1 mod 5 in [1, 1], - is_simplified: true + s1 mod 5 in [1, 1] )")); } TEST_F(IndexingMapTest, AffineMapSimplification_ConstantDims) { - IndexingMap indexing_map = - IndexingMap(ParseAffineMap("(d0) -> (d0)", &mlir_context_), - {DimVar{{5, 5}}}, /*range_vars=*/{}, /*rt_vars=*/{}); + auto indexing_map = Parse(R"( + (d0) -> (d0), + domain: + d0 in [5, 5] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (5), domain: - d0 in [5, 5], - is_simplified: true + d0 in [5, 5] )")); } TEST_F(IndexingMapTest, AffineMapSimplification_SumOrderRegression) { // This is a regression test for a bug where we didn't canonicalize the order // of summands correctly, leading to `Simplify` not being idempotent. - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (((((d0 + (d0 mod 3)) floordiv 3) + " - "(s0 + ((s0 + s0) mod 3))) + (((d0 + s0) mod 3) + 0)))", - &mlir_context_), - {10, 20}, {30, 40}); + auto indexing_map = Parse(R"( + (d0, d1)[s0, s1] -> (((((d0 + (d0 mod 3)) floordiv 3) + + (s0 + ((s0 + s0) mod 3))) + (((d0 + s0) mod 3) + 0))), + domain: + d0 in [0, 9], + d1 in [0, 19], + s0 in [0, 29], + s1 in [0, 39] + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_FALSE(indexing_map.Simplify()); } @@ -786,236 +828,250 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SumOrderRegression) { TEST_F(IndexingMapTest, AffineMapSimplification_SumOrderRegression2) { // This is a regression test for a bug where we didn't simplify the affine // expression fully after a single iteration. - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0] -> ((((s0 + d0) + d0) floordiv 2))", - &mlir_context_), - {10, 20}, {30, 40}); + auto indexing_map = Parse(R"( + (d0)[s0] -> ((((s0 + d0) + d0) floordiv 2)), + domain: + d0 in [0, 9], + s0 in [0, 19] + )"); EXPECT_TRUE(indexing_map.Simplify()); EXPECT_FALSE(indexing_map.Simplify()); } TEST_F(IndexingMapTest, AffineMapSimplification_FloorDivRegression) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap( - "(d0, d1) -> (((d0 floordiv 3) * 3 + d1 floordiv 2) floordiv 6)", - &mlir_context_), - {12, 6}, {}); + auto indexing_map = Parse(R"( + (d0, d1) -> (((d0 floordiv 3) * 3 + d1 floordiv 2) floordiv 6), + domain: + d0 in [0, 11], + d1 in [0, 5] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1) -> (d0 floordiv 6), domain: d0 in [0, 11], - d1 in [0, 5], - is_simplified: true + d1 in [0, 5] )")); } TEST_F(IndexingMapTest, AffineMapSimplification_ModIsSub) { - IndexingMap indexing_map( - ParseAffineMap("(d0) -> (d0 mod 42)", &mlir_context_), {{53, 71}}, {}, - {}); + auto indexing_map = Parse(R"( + (d0) -> (d0 mod 42), + domain: + d0 in [53, 71] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0 - 42), domain: - d0 in [53, 71], - is_simplified: true + d0 in [53, 71] )")); } TEST_F(IndexingMapTest, AffineMapSimplification_ModIsAdd) { - IndexingMap indexing_map(ParseAffineMap("(d0) -> (d0 mod 5)", &mlir_context_), - {{-5, -1}}, {}, {}); + auto indexing_map = Parse(R"( + (d0) -> (d0 mod 5), + domain: + d0 in [-5, -1] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> (d0 + 5), domain: - d0 in [-5, -1], - is_simplified: true + d0 in [-5, -1] )")); } TEST_F(IndexingMapTest, AffineMapSimplification_ModIsNotAdd) { - IndexingMap indexing_map1( - ParseAffineMap("(d0) -> (d0 mod 5)", &mlir_context_), {{-4, 0}}, {}, {}); + auto indexing_map1 = Parse("(d0) -> (d0 mod 5), domain: d0 in [-4, 0]"); EXPECT_FALSE(indexing_map1.Simplify()); - IndexingMap indexing_map2( - ParseAffineMap("(d0) -> (d0 mod 5)", &mlir_context_), {{-6, -1}}, {}, {}); + auto indexing_map2 = Parse("(d0) -> (d0 mod 5), domain: d0 in [-6, -1]"); EXPECT_FALSE(indexing_map2.Simplify()); } TEST_F(IndexingMapTest, AffineMapSimplification_SubIsMod) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0] -> (d0 - (s0 floordiv 3) * 3 + s0)", - &mlir_context_), - {2}, {4}); + auto indexing_map = Parse(R"( + (d0)[s0] -> (d0 - (s0 floordiv 3) * 3 + s0), + domain: + d0 in [0, 1], + s0 in [0, 3] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0 + s0 mod 3), domain: d0 in [0, 1], - s0 in [0, 3], - is_simplified: true + s0 in [0, 3] )")); } TEST_F(IndexingMapTest, AffineMapSimplification_SubIsModMultiplied) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0] -> (d0 - (s0 floordiv 3) * 12 + s0 * 7)", - &mlir_context_), - {2}, {4}); + auto indexing_map = Parse(R"( + (d0)[s0] -> (d0 - (s0 floordiv 3) * 12 + s0 * 7), + domain: + d0 in [0, 1], + s0 in [0, 3] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0 + (s0 mod 3) * 4 + s0 * 3), domain: d0 in [0, 1], - s0 in [0, 3], - is_simplified: true + s0 in [0, 3] )")); } TEST_F(IndexingMapTest, AffineMapSimplification_SubIsModSum) { - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0)[s0] -> (1 + d0 - ((s0 + 1) floordiv 3) * 3 + s0)", - &mlir_context_), - {2}, {4}); + auto indexing_map = Parse(R"( + (d0)[s0] -> (1 + d0 - ((s0 + 1) floordiv 3) * 3 + s0), + domain: + d0 in [0, 1], + s0 in [0, 3] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0] -> (d0 + (s0 + 1) mod 3), domain: d0 in [0, 1], - s0 in [0, 3], - is_simplified: true + s0 in [0, 3] )")); } TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsIfSmallerThanDivisor) { - auto serialized_map = "(d0, d1) -> (d0 + d1 floordiv 16, d1 mod 16)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {8, 16}, {}); + auto indexing_map = Parse(R"( + (d0, d1) -> (d0 + d1 floordiv 16, d1 mod 16), + domain: + d0 in [0, 7], + d1 in [0, 15] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1) -> (d0, d1), domain: d0 in [0, 7], - d1 in [0, 15], - is_simplified: true + d1 in [0, 15] )")); } TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithMultipliers) { - auto serialized_map = - "(d0, d1, d2) -> ((d0 * 100 + d1 * 10 + d2) floordiv 100, " - "((d0 * 100 + d1 * 10 + d2) mod 100) floordiv 10, " - "d2 mod 10)"; - - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {9, 9, 9}, {}); + auto indexing_map = Parse(R"( + (d0, d1, d2) -> ((d0 * 100 + d1 * 10 + d2) floordiv 100, + ((d0 * 100 + d1 * 10 + d2) mod 100) floordiv 10, + d2 mod 10), + domain: + d0 in [0, 8], + d1 in [0, 8], + d2 in [0, 8] + )"); EXPECT_TRUE(indexing_map.Simplify()); - - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1, d2) -> (d0, d1, d2), domain: d0 in [0, 8], d1 in [0, 8], - d2 in [0, 8], - is_simplified: true + d2 in [0, 8] )")); } TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithDivisibleMultipliers) { - auto serialized_map = - "(d0, d1, d2) -> ((d0 * 16 + d1 * 4 + d2) floordiv 8, " - " (d0 * 16 + d1 * 4 + d2) mod 8)"; - - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {10, 10, 10}, {}); + auto indexing_map = Parse(R"( + (d0, d1, d2) -> ((d0 * 16 + d1 * 4 + d2) floordiv 8, + (d0 * 16 + d1 * 4 + d2) mod 8), + domain: + d0 in [0, 9], + d1 in [0, 9], + d2 in [0, 9] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1, d2) -> (d0 * 2 + (d1 * 4 + d2) floordiv 8, (d1 * 4 + d2) mod 8), domain: d0 in [0, 9], d1 in [0, 9], - d2 in [0, 9], - is_simplified: true + d2 in [0, 9] )")); } TEST_F(IndexingMapTest, AffineMapSimplification_DivsAndModsWithReverse) { - auto serialized_map = - "(d0, d1) -> (-((d0 * -11 - d1 + 109) floordiv 11) + 9, " - "d0 * 11 + d1 + ((d0 * -11 - d1 + 109) floordiv 11) * 11 - 99)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {8, 9}, {}); + auto indexing_map = Parse(R"( + (d0, d1) -> (-((d0 * -11 - d1 + 109) floordiv 11) + 9, + d0 * 11 + d1 + ((d0 * -11 - d1 + 109) floordiv 11) * 11 - 99), + domain: + d0 in [0, 7], + d1 in [0, 8] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1) -> (d0, d1), domain: d0 in [0, 7], - d1 in [0, 8], - is_simplified: true + d1 in [0, 8] )")); } TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape) { - auto serialized_map = - "()[s0] -> ((s0 * 128) mod 715 + ((s0 * 128) floordiv 715) * 715)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {128}); + auto indexing_map = Parse(R"( + ()[s0] -> ((s0 * 128) mod 715 + ((s0 * 128) floordiv 715) * 715), + domain: + s0 in [0, 127] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( ()[s0] -> (s0 * 128), domain: - s0 in [0, 127], - is_simplified: true + s0 in [0, 127] )")); } TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape2) { - auto serialized_map = - "(d0, d1) -> ((d0 mod 8) * 128 + d1 + (d0 floordiv 8) * 1024)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {1024, 128}, {}); + auto indexing_map = Parse(R"( + (d0, d1) -> ((d0 mod 8) * 128 + d1 + (d0 floordiv 8) * 1024), + domain: + d0 in [0, 1023], + d1 in [0, 127] + )"); + ; EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1) -> (d0 * 128 + d1), domain: d0 in [0, 1023], - d1 in [0, 127], - is_simplified: true + d1 in [0, 127] )")); } TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape3) { - auto serialized_map = - "(d0, d1) -> (((d1 * 2 + d0 floordiv 64) mod 3) * 256 + (d0 mod 64) * 4 " - "+ ((d1 * 128 + d0) floordiv 192) * 768)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {128, 3072}, {}); + auto indexing_map = Parse(R"( + (d0, d1) -> (((d1 * 2 + d0 floordiv 64) mod 3) * 256 + (d0 mod 64) * 4 + + ((d1 * 128 + d0) floordiv 192) * 768), + domain: + d0 in [0, 127], + d1 in [0, 3071] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1) -> (d0 * 4 + d1 * 512), domain: d0 in [0, 127], - d1 in [0, 3071], - is_simplified: true + d1 in [0, 3071] )")); } TEST_F(IndexingMapTest, AffineMapSimplification_ModWithNegativeMultiplerDoesNotGetSimplified) { - auto serialized_map = "(d0) -> ((-d0) mod 2)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {128}, {}); + auto indexing_map = Parse(R"( + (d0) -> ((-d0) mod 2), + domain: + d0 in [0, 127] + )"); EXPECT_FALSE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0) -> ((-d0) mod 2), domain: - d0 in [0, 127], - is_simplified: true + d0 in [0, 127] )")); } @@ -1024,84 +1080,89 @@ TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyBitcastAndBack) { // `((d0 * 2 + d1 floordiv 64) floordiv 3) floordiv 1024`. // This test verifies that we can still simplify the map after the // simplification of the floordiv. - auto serialized_map = - "(d0, d1) -> ((d0 floordiv 1536) * 786432 + (((d0 * 2 + d1 floordiv " - "64) floordiv 3) mod 1024) * 768 + ((d0 * 2 + d1 floordiv 64) mod 3) * " - "256 + (d1 mod 64) * 4)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {3072, 128}, {}); + auto indexing_map = Parse(R"( + (d0, d1) -> ((d0 floordiv 1536) * 786432 + + (((d0 * 2 + d1 floordiv 64) floordiv 3) mod 1024) * 768 + + ((d0 * 2 + d1 floordiv 64) mod 3) * 256 + (d1 mod 64) * 4), + domain: + d0 in [0, 3071], + d1 in [0, 127] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0, d1) -> (d0 * 512 + d1 * 4), domain: d0 in [0, 3071], - d1 in [0, 127], - is_simplified: true + d1 in [0, 127] )")); } TEST_F(IndexingMapTest, AffineMapSimplification_SimplifyReshape_Regression) { // We have s0 * 128 in the mod, but s0 * 64 in the floordiv *. - auto serialized_map = - "()[s0] -> ((s0 * 128) mod 715 + ((s0 * 64) floordiv 715) * 715)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {128}); + auto indexing_map = Parse(R"( + ()[s0] -> ((s0 * 128) mod 715 + ((s0 * 64) floordiv 715) * 715), + domain: + s0 in [0, 127] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( ()[s0] -> (((s0 * 64) floordiv 715) * 715 + (s0 * 128) mod 715), domain: - s0 in [0, 127], - is_simplified: true + s0 in [0, 127] )")); } TEST_F(IndexingMapTest, AffineMapSimplification_DivsInSequence) { - auto serialized_map = - "()[s0] -> (s0 - ((s0 floordiv 2) floordiv 7) * 14 + (s0 floordiv 14) * " - "14)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {1234}); + auto indexing_map = Parse(R"( + ()[s0] -> (s0 - ((s0 floordiv 2) floordiv 7) * 14 + (s0 floordiv 14) * 14), + domain: + s0 in [0, 1233] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( ()[s0] -> (s0), domain: - s0 in [0, 1233], - is_simplified: true + s0 in [0, 1233] )")); } TEST_F(IndexingMapTest, AffineMapSimplification_DivDiv) { - auto serialized_map = "()[s0, s1] -> ((s0 * 2 + s1 floordiv 64) floordiv 3)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {1234, 128}); + auto indexing_map = Parse(R"( + ()[s0, s1] -> ((s0 * 2 + s1 floordiv 64) floordiv 3), + domain: + s0 in [0, 1233], + s1 in [0, 127] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( ()[s0, s1] -> ((s0 * 128 + s1) floordiv 192), domain: s0 in [0, 1233], - s1 in [0, 127], - is_simplified: true + s1 in [0, 127] )")); } TEST_F(IndexingMapTest, AffineMapSimplification_DivSumConstant) { - auto serialized_map = "()[s0] -> ((s0 * 6 + 9) floordiv 18)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {1234}); + auto indexing_map = Parse(R"( + ()[s0] -> ((s0 * 6 + 9) floordiv 18), + domain: + s0 in [0, 1233] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( ()[s0] -> ((s0 * 2 + 3) floordiv 6), domain: - s0 in [0, 1233], - is_simplified: true + s0 in [0, 1233] )")); } TEST_F(IndexingMapTest, AffineMapSimplification_DivSumDiv) { - auto serialized_map = - "()[s0, s1] -> ((s0 floordiv 3 + s1 floordiv 3) floordiv 6)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {1234, 128}); + auto indexing_map = Parse(R"( + ()[s0, s1] -> ((s0 floordiv 3 + s1 floordiv 3) floordiv 6), + domain: + s0 in [0, 1233], + s1 in [0, 127] + )"); // The rewrite tested in AffineMapSimplification_DivDiv must not trigger here. EXPECT_FALSE(indexing_map.Simplify()); } @@ -1110,20 +1171,25 @@ TEST_F(IndexingMapTest, AffineMapSimplification_NegativeDiv) { // (s0 floordiv 2) floordiv -7 is not s0 floordiv -14: // 15 // 2 // -7 = -1 // 15 // -14 = -2 - auto serialized_map = "()[s0] -> ((s0 floordiv 2) floordiv -7)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {1234}); + auto indexing_map = Parse(R"( + ()[s0] -> ((s0 floordiv 2) floordiv -7), + domain: + s0 in [0, 1233] + )"); EXPECT_FALSE(indexing_map.Simplify()); } TEST_F(IndexingMapTest, AffineMapSimplification_ExtractFromMod) { - auto serialized_map = - "()[s0, s1, s2, s3] -> ((s0 * 458752 + s1 + s2 * 4 + s3 * 512) mod " - "20000)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {872, 4, 128, 896}); + auto indexing_map = Parse(R"( + ()[s0, s1, s2, s3] -> ((s0 * 458752 + s1 + s2 * 4 + s3 * 512) mod 20000), + domain: + s0 in [0, 871], + s1 in [0, 3], + s2 in [0, 127], + s3 in [0, 895] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( ()[s0, s1, s2, s3] -> ( ((s0 * 114688 + s3 * 128 + s2) mod 5000) * 4 + s1 ), @@ -1131,123 +1197,131 @@ TEST_F(IndexingMapTest, AffineMapSimplification_ExtractFromMod) { s0 in [0, 871], s1 in [0, 3], s2 in [0, 127], - s3 in [0, 895], - is_simplified: true + s3 in [0, 895] )")); } TEST_F(IndexingMapTest, AffineMapSimplification_ExtractFromDiv_NegativeMultiplier) { - auto serialized_map = - "()[s0, s1] -> ((s0 * 16 - (s1 floordiv 4) floordiv 2 + (s1 floordiv 8) " - "* 2) floordiv 4)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {}, {2, 128}); + auto indexing_map = Parse(R"( + ()[s0, s1] -> ((s0 * 16 - (s1 floordiv 4) floordiv 2 + (s1 floordiv 8) * 2) + floordiv 4), + domain: + s0 in [0, 1], + s1 in [0, 127] + )"); EXPECT_TRUE(indexing_map.Simplify()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( ()[s0, s1] -> ( s0 * 4 + s1 floordiv 32 ), domain: s0 in [0, 1], - s1 in [0, 127], - is_simplified: true + s1 in [0, 127] )")); } TEST_F(IndexingMapTest, RescaleSymbols_Simple) { - auto serialized_map = "(d0)[s0, s1, s2] -> (s2, d0, s1, s0 floordiv 6)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {4}, {7, 2, 6}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 6", &mlir_context_), - Interval{0, 0}); - + auto indexing_map = Parse(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0 floordiv 6), + domain: + d0 in [0, 3], + s0 in [0, 6], + s1 in [0, 1], + s2 in [0, 5], + s0 mod 6 in [0, 0] + )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0), domain: d0 in [0, 3], s0 in [0, 1], s1 in [0, 1], - s2 in [0, 5], - is_simplified: false + s2 in [0, 5] )")); } TEST_F(IndexingMapTest, RescaleSymbols_WithShift) { - auto serialized_map = "(d0)[s0, s1, s2] -> (s2, d0, s1, s0)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {4}, {42, 2, 6}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 6", &mlir_context_), - Interval{3, 3}); - + auto indexing_map = Parse(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0), + domain: + d0 in [0, 3], + s0 in [0, 41], + s1 in [0, 1], + s2 in [0, 5], + s0 mod 6 in [3, 3] + )"); // [BEFORE] Allowed values for s0: 3, 9, 15, ..., 39 = (6 * 6 + 3) // [AFTER] Allowed values for s0: 0, 1, 2, ..., 6 EXPECT_TRUE(indexing_map.RescaleSymbols()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0 * 6 + 3), domain: d0 in [0, 3], s0 in [0, 6], s1 in [0, 1], - s2 in [0, 5], - is_simplified: false + s2 in [0, 5] )")); } TEST_F(IndexingMapTest, RescaleSymbols_TwoModConstraints) { - auto serialized_map = "(d0)[s0, s1, s2] -> (s2, d0, s1, s0 floordiv 6)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {4}, {7, 2, 6}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 2", &mlir_context_), - Interval{0, 0}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 3", &mlir_context_), - Interval{0, 0}); - + auto indexing_map = Parse(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0 floordiv 6), + domain: + d0 in [0, 3], + s0 in [0, 7], + s1 in [0, 1], + s2 in [0, 5], + s0 mod 2 in [0, 0], + s0 mod 3 in [0, 0] + )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0), domain: d0 in [0, 3], s0 in [0, 1], s1 in [0, 1], - s2 in [0, 5], - is_simplified: false + s2 in [0, 5] )")); } TEST_F(IndexingMapTest, RescaleSymbols_RescaledSymbolInOtherNonModConstraint) { - auto serialized_map = "(d0)[s0, s1, s2] -> (s2, d0, s1, s0)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {4}, {10, 2, 6}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 6", &mlir_context_), - Interval{3, 3}); - indexing_map.AddConstraint(ParseAffineExpr("s0 * s2", &mlir_context_), - Interval{0, 28}); - + auto indexing_map = Parse(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0), + domain: + d0 in [0, 3], + s0 in [0, 9], + s1 in [0, 1], + s2 in [0, 5], + s0 * s2 in [0, 28], + s0 mod 6 in [3, 3] + )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( + EXPECT_THAT(ToString(indexing_map), MatchIndexingString(R"( (d0)[s0, s1, s2] -> (s2, d0, s1, s0 * 6 + 3), domain: d0 in [0, 3], s0 in [0, 1], s1 in [0, 1], s2 in [0, 5], - (s0 * 6 + 3) * s2 in [0, 28], - is_simplified: false + (s0 * 6 + 3) * s2 in [0, 28] )")); } TEST_F(IndexingMapTest, RescaleSymbols_TwoModConstraintsForTheSameSymbolWhichCannotBeMerged) { - auto serialized_map = "(d0)[s0, s1, s2] -> (s2, d0, s1, s0)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {4}, {100, 2, 6}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 6", &mlir_context_), - Interval{3, 3}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 7", &mlir_context_), - Interval{5, 5}); - + auto indexing_map = Parse(R"( + (d0)[s0, s1, s2] -> (s2, d0, s1, s0), + domain: + d0 in [0, 3], + s0 in [0, 99], + s1 in [0, 1], + s2 in [0, 5], + s0 mod 6 in [3, 3], + s0 mod 7 in [5, 5] + )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); const mlir::AffineExpr result3 = indexing_map.GetAffineMap().getResult(3); @@ -1274,14 +1348,16 @@ TEST_F(IndexingMapTest, } TEST_F(IndexingMapTest, RescaleSymbolsKeepsHashmapConsistent) { - auto serialized_map = "(d0)[s0, s1, s2] -> (s2, d0, s0, s0 floordiv 6)"; - IndexingMap indexing_map = IndexingMap::FromTensorSizes( - ParseAffineMap(serialized_map, &mlir_context_), {4}, {7, 2, 6}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 6", &mlir_context_), - Interval{0, 0}); - indexing_map.AddConstraint(ParseAffineExpr("s0 * s1", &mlir_context_), - Interval{0, 100}); - + auto indexing_map = Parse(R"( + (d0)[s0, s1, s2] -> (s2, d0, s0, s0 floordiv 6), + domain: + d0 in [0, 3], + s0 in [0, 6], + s1 in [0, 1], + s2 in [0, 5], + s0 mod 6 in [0, 0], + s0 * s1 in [0, 100] + )"); EXPECT_TRUE(indexing_map.RescaleSymbols()); for (auto& [expr, interval] : indexing_map.GetConstraints()) { @@ -1291,13 +1367,14 @@ TEST_F(IndexingMapTest, RescaleSymbolsKeepsHashmapConsistent) { } TEST_F(IndexingMapTest, RangeEvaluatorTest) { - auto serialized_map = "(d0, d1, d2, d3)[] -> (0)"; - IndexingMap indexing_map(ParseAffineMap(serialized_map, &mlir_context_), - {{Interval{0, 9}}, - {Interval{-10, -1}}, - {Interval{-1, 2}}, - {Interval{0, 0}}}, - {}, {}); + auto indexing_map = Parse(R"( + (d0, d1, d2, d3)[] -> (0), + domain: + d0 in [0, 9], + d1 in [-10, -1], + d2 in [-1, 2], + d3 in [0, 0] + )"); RangeEvaluator range_evaluator(indexing_map, &mlir_context_); mlir::AffineExpr d0, d1, d2, d3; bindDims(&mlir_context_, d0, d1, d2, d3); @@ -1450,410 +1527,6 @@ TEST(IntervalMathTest, MultiplicationSaturating) { EXPECT_THAT(any * neg_one, IntervalIs(any)); } -TEST_F(IndexingMapTest, ReplaceConstantRTVars_ScalarConstant) { - absl::StatusOr> hlo_module = - ParseAndReturnVerifiedModule(R"hlo( - HloModule m - - ENTRY e { - ROOT %constant = s64[] constant(42) - } - )hlo"); - - ASSERT_TRUE(hlo_module.ok()); - - IndexingMap indexing_map( - ParseAffineMap("()[s0] -> (s0)", &mlir_context_), - /*dimensions=*/{}, - /*range_vars=*/{}, - {RTVar{Interval{42, 42}, - hlo_module.value()->entry_computation()->root_instruction(), - AffineMap::get(0, 0, {}, &mlir_context_)}}); - - EXPECT_TRUE(indexing_map.Simplify()); - indexing_map.RemoveUnusedSymbols(); - - EXPECT_THAT(indexing_map.ToString(printer_), - MatchIndexingString("() -> (42)")); -} - -TEST_F(IndexingMapTest, ReplaceConstantRTVars_StaticIndexIntoTensorConstant) { - absl::StatusOr> hlo_module = - ParseAndReturnVerifiedModule(R"hlo( - HloModule m - - ENTRY e { - ROOT %constant = s64[2, 4]{1,0} constant({{1, 2, 3, 4}, {11, 12, 13, 14}}) - } - )hlo"); - - ASSERT_TRUE(hlo_module.ok()); - - IndexingMap indexing_map( - ParseAffineMap("()[s0] -> (s0)", &mlir_context_), - /*dimensions=*/{}, - /*range_vars=*/{}, - {RTVar{Interval{1, 14}, - hlo_module.value()->entry_computation()->root_instruction(), - ParseAffineMap("() -> (1,2)", &mlir_context_)}}); - - EXPECT_TRUE(indexing_map.Simplify()); - indexing_map.RemoveUnusedSymbols(); - - EXPECT_THAT(indexing_map.ToString(printer_), - MatchIndexingString("() -> (13)")); -} - -TEST_F(IndexingMapTest, ReplaceConstantRTVars_NonFoldableTensor) { - absl::StatusOr> hlo_module = - ParseAndReturnVerifiedModule(R"hlo( - HloModule m - - ENTRY e { - ROOT %constant = s64[2, 4]{1,0} constant({{1, 2, 3, 4}, {11, 12, 13, 14}}) - } - )hlo"); - - ASSERT_TRUE(hlo_module.ok()); - - IndexingMap indexing_map( - ParseAffineMap("(d0)[s0] -> (s0)", &mlir_context_), - /*dimensions=*/{}, - /*range_vars=*/{}, - {RTVar{Interval{1, 14}, - hlo_module.value()->entry_computation()->root_instruction(), - ParseAffineMap("(d0) -> (1, d0)", &mlir_context_)}}); - - EXPECT_FALSE(indexing_map.Simplify()); -} - -TEST_F(IndexingMapTest, ReplaceConstantRTVars_Iota) { - absl::StatusOr> hlo_module = - ParseAndReturnVerifiedModule(R"hlo( - HloModule m - - ENTRY e { - ROOT %iota = s64[10, 10]{1,0} iota(), iota_dimension=0 - } - )hlo"); - - ASSERT_TRUE(hlo_module.ok()); - - IndexingMap indexing_map( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), - /*dimensions=*/{{0, 255}}, - /*range_vars=*/{}, - {RTVar{Interval{0, 9}, - hlo_module.value()->entry_computation()->root_instruction(), - ParseAffineMap("(d0) -> (d0, 7)", &mlir_context_)}}); - - EXPECT_TRUE(indexing_map.Simplify()); - indexing_map.RemoveUnusedSymbols(); - - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( - (d0) -> (d0, d0), - domain: - d0 in [0, 255], - is_simplified: true - )")); -} - -TEST_F(IndexingMapTest, ReplaceConstantRTVars_IotaAsConstant) { - absl::StatusOr> hlo_module = - ParseAndReturnVerifiedModule(R"hlo( - HloModule m - - ENTRY e { - ROOT %iota = s64[10, 10]{1,0} iota(), iota_dimension=1 - } - )hlo"); - - ASSERT_TRUE(hlo_module.ok()); - - IndexingMap indexing_map( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), - /*dimensions=*/{{0, 255}}, - /*range_vars=*/{}, - {RTVar{Interval{0, 9}, - hlo_module.value()->entry_computation()->root_instruction(), - ParseAffineMap("(d0) -> (d0, 7)", &mlir_context_)}}); - - EXPECT_TRUE(indexing_map.Simplify()); - indexing_map.RemoveUnusedSymbols(); - - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( - (d0) -> (d0, 7), - domain: - d0 in [0, 255], - is_simplified: true - )")); -} - -TEST_F(IndexingMapTest, ReplaceConstantRTVars_ConstraintsGetUpdated) { - absl::StatusOr> hlo_module = - ParseAndReturnVerifiedModule(R"hlo( - HloModule m - - ENTRY e { - ROOT %iota = s64[10, 10]{1,0} iota(), iota_dimension=0 - } - )hlo"); - - ASSERT_TRUE(hlo_module.ok()); - - IndexingMap indexing_map( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), - /*dimensions=*/{{0, 255}}, - /*range_vars=*/{}, - {RTVar{Interval{0, 9}, - hlo_module.value()->entry_computation()->root_instruction(), - ParseAffineMap("(d0) -> (d0, 7)", &mlir_context_)}}); - indexing_map.AddConstraint(ParseAffineExpr("s0 mod 2", &mlir_context_), - Interval{0, 0}); - - EXPECT_TRUE(indexing_map.Simplify()); - indexing_map.RemoveUnusedSymbols(); - - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( - (d0) -> (d0, d0), - domain: - d0 in [0, 254], - d0 mod 2 in [0, 0], - is_simplified: true - )")); -} - -TEST_F(IndexingMapTest, ReplaceConstantRTVars_Broadcast) { - absl::StatusOr> hlo_module = - ParseAndReturnVerifiedModule(R"hlo( - HloModule m - - ENTRY e { - %iota = s64[12]{0} iota(), iota_dimension=0 - ROOT %broadcast = s64[32, 12]{1,0} broadcast(s64[12]{0} %iota), dimensions={1} - } - )hlo"); - - ASSERT_TRUE(hlo_module.ok()); - - // (d0, 11): d0 maps into the broadcasted dimension, so it doesn't matter - // and 11 maps to 11 in iota. - IndexingMap indexing_map( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), - /*dimensions=*/{{0, 31}}, - /*range_vars=*/{}, - {RTVar{Interval{0, 11}, - hlo_module.value()->entry_computation()->root_instruction(), - ParseAffineMap("(d0) -> (d0, 11)", &mlir_context_)}}); - - EXPECT_TRUE(indexing_map.Simplify()); - indexing_map.RemoveUnusedSymbols(); - - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( - (d0) -> (d0, 11), - domain: - d0 in [0, 31], - is_simplified: true - )")); -} - -TEST_F(IndexingMapTest, ReplaceConstantRTVars_ChainedNoncomputeOps) { - absl::StatusOr> hlo_module = - ParseAndReturnVerifiedModule(R"hlo( - HloModule m - - ENTRY e { - %iota = s64[12]{0} iota(), iota_dimension=0 - %reverse = s64[12]{0} reverse(s64[12]{0} %iota), dimensions={0} - %reshape = s64[3,4]{1,0} reshape(s64[12]{0} %reverse) - ROOT %broadcast = s64[36,3,4]{2,1,0} broadcast(s64[3,4]{1,0} %reshape), dimensions={1,2} - } - )hlo"); - - ASSERT_TRUE(hlo_module.ok()); - - // - Iota: [0, 1, ,,,, 11] - // - Reverse: [11, 10, ..., 0] - // - Reshape: [[11, 10, 9, 8], [7, 6, 5, 4], [3, 2, 1, 0]] - // - Coordinates: (d0 floordiv 12, 3) - // - y-coordinate=3 means we index into [8, 4, 0] - // - x-coordinate=(d0 floordiv 12) means our constant looks like this: - // [8, ..., 8, 4, ..., 4, 0, ..., 0] - // - Hence our final expression: (d0 floordiv 12) * -4 + 8 - IndexingMap indexing_map( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), - /*dimensions=*/{{0, 35}}, - /*range_vars=*/{}, - {RTVar{ - Interval{0, 11}, - hlo_module.value()->entry_computation()->root_instruction(), - ParseAffineMap("(d0) -> (d0, d0 floordiv 12, 3)", &mlir_context_)}}); - - EXPECT_TRUE(indexing_map.Simplify()); - indexing_map.RemoveUnusedSymbols(); - - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( - (d0) -> (d0, (d0 floordiv 12) * -4 + 8), - domain: - d0 in [0, 35], - is_simplified: true - )")); -} - -TEST_F(IndexingMapTest, ReplaceConstantRTVars_PartialRTVarRemoval) { - absl::StatusOr> hlo_module = - ParseAndReturnVerifiedModule(R"hlo( - HloModule m - - ENTRY e { - %constant = s64[12]{0} constant({...}) - ROOT %broadcast = s64[24,12]{1,0} broadcast(s64[12]{0} %constant), dimensions={1} - } - )hlo"); - - ASSERT_TRUE(hlo_module.ok()); - - // (d0, d0 floordiv 2): d0 maps into the broadcasted dimension, so it can't be - // removed, but d0 floordiv 2 doesn't yield an affine expression so we need to - // keep the RTVar, but can optimize it by removing the broadcast. - IndexingMap indexing_map( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), - /*dimensions=*/{{0, 23}}, - /*range_vars=*/{}, - {RTVar{Interval{0, 512}, - hlo_module.value()->entry_computation()->root_instruction(), - ParseAffineMap("(d0) -> (d0, d0 floordiv 2)", &mlir_context_)}}); - - EXPECT_TRUE(indexing_map.Simplify()); - - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( - (d0)[s0] -> (d0, s0), - domain: - d0 in [0, 23], - s0 in [0, 512], - hlo: %constant = s64[12]{0} constant({...}), - (d0) -> (d0 floordiv 2), - is_simplified: true - )")); -} - -TEST_F(IndexingMapTest, ReplaceConstantRTVars_Add) { - absl::StatusOr> hlo_module = - ParseAndReturnVerifiedModule(R"hlo( - HloModule m - - ENTRY e { - %constant = s64[] constant(42) - %broadcast = s64[12,13,24]{2,1,0} broadcast(s64[] %constant), dimensions={} - %iota = s64[12,13,24]{2,1,0} iota(), iota_dimension=2 - ROOT %add = s64[12,13,24]{2,1,0} add(s64[12,13,24]{2,1,0} %broadcast, s64[12,13,24]{2,1,0} %iota) - } - )hlo"); - - ASSERT_TRUE(hlo_module.ok()); - - // The iota dimension is the last dimension in (d0, 7, 2 * d0), hence this - // composes to 42 + 2 * d0 - IndexingMap indexing_map( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), - /*dimensions=*/{{0, 11}}, - /*range_vars=*/{}, - {RTVar{Interval{0, 11}, - hlo_module.value()->entry_computation()->root_instruction(), - ParseAffineMap("(d0) -> (d0, 7, 2 * d0)", &mlir_context_)}}); - - EXPECT_TRUE(indexing_map.Simplify()); - indexing_map.RemoveUnusedSymbols(); - - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( - (d0) -> (d0, d0 * 2 + 42), - domain: - d0 in [0, 11], - is_simplified: true - )")); -} - -TEST_F(IndexingMapTest, ReplaceConstantRTVars_Multiply) { - absl::StatusOr> hlo_module = - ParseAndReturnVerifiedModule(R"hlo( - HloModule m - - ENTRY e { - %iota0 = s64[12,12]{1,0} iota(), iota_dimension=0 - %iota1 = s64[12]{0} iota(), iota_dimension=0 - %broadcast = s64[12,12]{1,0} broadcast(s64[12]{0} %iota1), dimensions={1} - %multiply = s64[12,12]{1,0} multiply(s64[12,12]{1,0} %iota0, s64[12,12]{1,0} %broadcast) - ROOT %reverse = s64[12,12]{1,0} reverse(s64[12,12]{1,0} %multiply), dimensions={0} - } - )hlo"); - - ASSERT_TRUE(hlo_module.ok()); - - // Iota0: [[0, ..., 0], [1, ..., 1], ..., [11, ..., 11]] - // Iota1: [0, ..., 11] - // Broadcast1: [[0, 1, ..., 11], [0, 1, ..., 11], ..., [0, 1, ..., 11]] - // Mul: [[0, .., 0], [0, 1, ..., 11], [0, 2, ..., 22], ..., [0, 11, ..., 121]] - // Reverse: [[0, 11, ..., 121], [0, 10, ..., 110], ..., [0, ..., 0]] - // Therefore (d0, d0) evaluates to: (11 - d0) * d0. - IndexingMap indexing_map( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), - /*dimensions=*/{{0, 11}}, - /*range_vars=*/{}, - {RTVar{Interval{0, 11}, - hlo_module.value()->entry_computation()->root_instruction(), - ParseAffineMap("(d0) -> (d0, d0)", &mlir_context_)}}); - - EXPECT_TRUE(indexing_map.Simplify()); - indexing_map.RemoveUnusedSymbols(); - - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( - (d0) -> (d0, (-d0 + 11) * d0), - domain: - d0 in [0, 11], - is_simplified: true - )")); -} - -TEST_F(IndexingMapTest, ReplaceConstantRTVars_PartiallyOptimizableAdd) { - absl::StatusOr> hlo_module = - ParseAndReturnVerifiedModule(R"hlo( - HloModule m - - ENTRY e { - %constant = s64[12]{0} constant({...}) - %broadcast = s64[12,13,24]{2,1,0} broadcast(s64[12]{0} %constant), dimensions={0} - %iota = s64[12,13,24]{2,1,0} iota(), iota_dimension=2 - ROOT %add = s64[12,13,24]{2,1,0} add(s64[12,13,24]{2,1,0} %broadcast, s64[12,13,24]{2,1,0} %iota) - } - )hlo"); - - ASSERT_TRUE(hlo_module.ok()); - - // The iota dimension is the last dimension in (d0, 7, 2 * d0), the constant - // only depends on the first dimension. The constant consists of some - // arbitrary values that cannot be represent as an affine expression, hence - // the RTVar remains in-place. - IndexingMap indexing_map( - ParseAffineMap("(d0)[s0] -> (d0, s0)", &mlir_context_), - /*dimensions=*/{{0, 11}}, - /*range_vars=*/{}, - {RTVar{Interval{0, 11}, - hlo_module.value()->entry_computation()->root_instruction(), - ParseAffineMap("(d0) -> (d0, 7, 2 * d0)", &mlir_context_)}}); - - EXPECT_TRUE(indexing_map.Simplify()); - - EXPECT_THAT(indexing_map.ToString(printer_), MatchIndexingString(R"( - (d0)[s0] -> (d0, d0 * 2 + s0), - domain: - d0 in [0, 11], - s0 in [0, 11], - hlo: %constant = s64[12]{0} constant({...}), - (d0) -> (d0), - is_simplified: true - )")); -} - template void ExpectSupportsAbslHashAndEqAndNe(absl::Span values) { EXPECT_TRUE(absl::VerifyTypeImplementsAbslHashCorrectly(values)); @@ -1895,103 +1568,104 @@ TEST_F(IndexingMapTest, IntervalSupportsLlvmStyleHashingAndEqAndNe) { } TEST_F(IndexingMapTest, DimVarSupportsAbslHashAndEqAndNe) { - ExpectSupportsAbslHashAndEqAndNe( - {DimVar{1, 1}, DimVar{0, 1}, DimVar{1, 2}}); + ExpectSupportsAbslHashAndEqAndNe( + {IndexingMap::Variable{1, 1}, IndexingMap::Variable{0, 1}, + IndexingMap::Variable{1, 2}}); } TEST_F(IndexingMapTest, RangeVarSupportsAbslHashAndEqAndNe) { - ExpectSupportsAbslHashAndEqAndNe( - {RangeVar{1, 1}, RangeVar{0, 1}, RangeVar{1, 2}}); + ExpectSupportsAbslHashAndEqAndNe( + {IndexingMap::Variable{1, 1}, IndexingMap::Variable{0, 1}, + IndexingMap::Variable{1, 2}}); } TEST_F(IndexingMapTest, RTVarSupportsAbslHashAndEqAndNe) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, ParseAndReturnVerifiedModule(R"( -HloModule m - -ENTRY e { - ROOT %constant = s64[] constant(42) -})")); + HloModule m + ENTRY e { + ROOT %constant = s64[] constant(42) + } + )")); ASSERT_NE(hlo_module, nullptr); - const HloInstruction* constant_instr = - hlo_module->entry_computation()->root_instruction(); - - ExpectSupportsAbslHashAndEqAndNe( - {RTVar{Interval{1, 1}, nullptr, - ParseAffineMap("(d0) -> (d0)", &mlir_context_)}, - RTVar{Interval{1, 2}, nullptr, - ParseAffineMap("(d0) -> (d0)", &mlir_context_)}, - RTVar{ - Interval{1, 2}, - nullptr, - ParseAffineMap("(d0) -> (d0 * 2)", &mlir_context_), - }, - RTVar{ - Interval{1, 2}, - constant_instr, - ParseAffineMap("(d0) -> (d0 * 2)", &mlir_context_), - }}); + + ExpectSupportsAbslHashAndEqAndNe( + {IndexingMap::Variable{Interval{1, 1}}, + IndexingMap::Variable{Interval{1, 2}}, + IndexingMap::Variable{Interval{1, 2}}, + IndexingMap::Variable{Interval{1, 2}}}); } TEST_F(IndexingMapTest, IndexingMapSupportsAbslHashAndEqAndNe) { - auto zero_dim_map = AffineMap::get(&mlir_context_); ExpectSupportsAbslHashAndEqAndNe( - {IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", - &mlir_context_), - {50, 60}, {70, 80}), - IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1 * 2, d0, s1, s0)", - &mlir_context_), - {50, 60}, {70, 80}), - IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", - &mlir_context_), - {51, 60}, {70, 80}), - IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", - &mlir_context_), - {50, 60}, {71, 80}), - [&] { - auto m = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", - &mlir_context_), - {50, 60}, {70, 80}); - m.AddConstraint(ParseAffineExpr("d0 mod 8", &mlir_context_), - Interval{0, 0}); - m.AddConstraint(ParseAffineExpr("d0 mod 16", &mlir_context_), - Interval{0, 0}); - return m; - }(), - [&] { - auto m = IndexingMap::FromTensorSizes( - ParseAffineMap("(d0, d1)[s0, s1] -> (d1, d0, s1, s0)", - &mlir_context_), - {50, 60}, {70, 80}); - m.AddConstraint(ParseAffineExpr("d0 mod 8", &mlir_context_), - Interval{0, 0}); - m.AddConstraint(ParseAffineExpr("d0 mod 32", &mlir_context_), - Interval{0, 0}); - return m; - }(), + {Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 79] + )"), + Parse(R"( + (d0, d1)[s0, s1] -> (d1 * 2, d0, s1, s0), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 79] + )"), + Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 50], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 79] + )"), + Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 79] + )"), + Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 79], + d0 mod 8 in [0, 0], + d0 mod 16 in [0, 0] + )"), + Parse(R"( + (d0, d1)[s0, s1] -> (d1, d0, s1, s0), + domain: + d0 in [0, 49], + d1 in [0, 59], + s0 in [0, 69], + s1 in [0, 79], + d0 mod 8 in [0, 0], + d0 mod 32 in [0, 0] + )"), IndexingMap( ParseAffineMap("(d0)[s0, s1, s2, s3, s4] -> (d0 * 4 + s1 + s3 - 42)", &mlir_context_), - {DimVar{{0, 31}}}, - {RangeVar{{0, 0}}, RangeVar{{0, 1}}, RangeVar{{0, 2}}}, - {RTVar{Interval{0, 3}, - /*instr=*/nullptr, zero_dim_map}, - RTVar{Interval{0, 4}, - /*instr=*/nullptr, zero_dim_map}}), + {IndexingMap::Variable{{0, 31}}}, + {IndexingMap::Variable{{0, 0}}, IndexingMap::Variable{{0, 1}}, + IndexingMap::Variable{{0, 2}}}, + {IndexingMap::Variable{Interval{0, 3}}, + IndexingMap::Variable{Interval{0, 4}}}), IndexingMap( ParseAffineMap("(d0)[s0, s1, s2, s3, s4] -> (d0 * 4 + s1 + s3 - 42)", &mlir_context_), - {DimVar{{0, 31}}}, - {RangeVar{{0, 0}}, RangeVar{{0, 1}}, RangeVar{{0, 2}}}, - {RTVar{Interval{0, 3}, - /*instr=*/nullptr, zero_dim_map}, - RTVar{Interval{0, 5}, - /*instr=*/nullptr, zero_dim_map}})}); + {IndexingMap::Variable{{0, 31}}}, + {IndexingMap::Variable{{0, 0}}, IndexingMap::Variable{{0, 1}}, + IndexingMap::Variable{{0, 2}}}, + {IndexingMap::Variable{Interval{0, 3}}, + IndexingMap::Variable{Interval{0, 5}}})}); } } // namespace diff --git a/third_party/xla/xla/service/gpu/model/indexing_test_utils.cc b/third_party/xla/xla/service/gpu/model/indexing_test_utils.cc index d014b597dba2b6..24a1e88131977a 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_test_utils.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_test_utils.cc @@ -345,16 +345,16 @@ absl::Status VerifyExprsAreIdentical( mlir::AffineExpr reference, mlir::AffineExpr other, absl::Span dimension_ranges, absl::Span symbol_ranges) { - std::vector dims; + std::vector dims; dims.reserve(dimension_ranges.size()); for (const auto& interval : dimension_ranges) { - dims.push_back({interval}); + dims.push_back(IndexingMap::Variable{interval}); } - std::vector symbols; + std::vector symbols; symbols.reserve(symbol_ranges.size()); for (const auto& interval : symbol_ranges) { - symbols.push_back({interval}); + symbols.push_back(IndexingMap::Variable{interval}); } IndexingMap map(mlir::AffineMap::get(dimension_ranges.size(), diff --git a/third_party/xla/xla/service/gpu/model/indexing_test_utils.h b/third_party/xla/xla/service/gpu/model/indexing_test_utils.h index c4bb2910fa36cd..5584c208685e5b 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_test_utils.h +++ b/third_party/xla/xla/service/gpu/model/indexing_test_utils.h @@ -31,8 +31,10 @@ limitations under the License. #include "mlir/IR/AffineMap.h" #include "mlir/IR/MLIRContext.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" @@ -49,7 +51,7 @@ MATCHER_P(MatchIndexingMap, indexing_string, "") { return false; } return ExplainMatchResult( - true, ApproximateMatch(indexing_string, arg.ToString()), result_listener); + true, ApproximateMatch(indexing_string, ToString(arg)), result_listener); } MATCHER_P(MatchIndexingString, indexing_string, "") { diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc index f2e14e0c655bc0..0cb570ebf8e0a0 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "llvm/ADT/DenseMap.h" @@ -41,8 +42,8 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "mlir/Support/LLVM.h" #include "xla/service/gpu/model/affine_map_evaluator.h" -#include "xla/service/gpu/model/affine_map_printer.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" namespace xla { namespace gpu { @@ -231,8 +232,7 @@ ExtractSizesAndStridesFromMultivariateSummation( std::optional maybe_size_and_stride = ExtractSizeAndStride(summand, dimension_intervals, symbol_intervals); if (!maybe_size_and_stride.has_value()) { - VLOG(1) << "Couldn't extract size and stride from " - << AffineMapPrinter().ToString(summand); + VLOG(1) << "Couldn't extract size and stride from " << ToString(summand); return std::nullopt; } sizes_and_strides.push_back(*maybe_size_and_stride); @@ -320,8 +320,8 @@ std::optional TryGetSizeExpressionRangeSize( // working well with concatenations. Nevertheless, we can take a look // later. VLOG(1) << "Attempted to combine strides but got dimension " - << AffineMapPrinter().ToString(size) << " with lower bound " - << interval.lower << " != 0"; + << ToString(size) << " with lower bound " << interval.lower + << " != 0"; return std::nullopt; } // We need to add 1 to the upper bound of the interval to describe the @@ -364,7 +364,7 @@ std::optional CombineStrides( for (const SizeAndStrideExpression& size_and_stride : sizes_and_strides) { if (size_and_stride.stride.getKind() != AffineExprKind::Constant) { VLOG(1) << "Attempted to combine non-constant stride: " - << AffineMapPrinter().ToString(size_and_stride.stride); + << ToString(size_and_stride.stride); return std::nullopt; } @@ -379,7 +379,7 @@ std::optional CombineStrides( size_and_stride.size.getKind() != AffineExprKind::DimId) { VLOG(1) << "Attempted to combine strides but got non-constant, " "non-dimension size " - << AffineMapPrinter().ToString(size_and_stride.size); + << ToString(size_and_stride.size); return std::nullopt; } } @@ -567,9 +567,8 @@ std::optional CombineSizesAndStrides( if (VLOG_IS_ON(1)) { for (const SizeAndStrideExpression& size_and_stride : sizes_and_strides) { LOG(INFO) << "CombineSizesAndStrides:"; - LOG(INFO) << "size: " << AffineMapPrinter().ToString(size_and_stride.size) - << " stride: " - << AffineMapPrinter().ToString(size_and_stride.stride); + LOG(INFO) << "size: " << ToString(size_and_stride.size) + << " stride: " << ToString(size_and_stride.stride); } } @@ -603,7 +602,6 @@ std::optional ExtractSizeAndStride( AffineExpr strided_indexing, absl::Span dimension_intervals, absl::Span symbol_intervals) { MLIRContext* ctx = strided_indexing.getContext(); - AffineMapPrinter printer; switch (strided_indexing.getKind()) { case AffineExprKind::DimId: @@ -711,9 +709,8 @@ std::optional TryIntersectConjointConstraints( auto& [result_expr, result_interval] = *result_it; result_interval = result_interval.Intersect(interval); if (!result_interval.IsFeasible()) { - AffineMapPrinter printer; VLOG(1) << "Got two incompatible intervals for expression " - << printer.ToString(expr); + << ToString(expr); return std::nullopt; } } else { @@ -866,15 +863,13 @@ bool ConstraintExpression::IsSatisfiedBy( return constraints_are_satisfied; } -std::string ConstraintExpression::ToString( - const AffineMapPrinter& printer) const { +std::string ConstraintExpression::ToString() const { std::stringstream ss; - Print(ss, printer); + Print(ss); return ss.str(); } -void ConstraintExpression::Print(std::ostream& out, - const AffineMapPrinter& printer) const { +void ConstraintExpression::Print(std::ostream& out) const { if (IsAlwaysSatisfied()) { out << "always satisfied"; } else if (is_satisfiable()) { @@ -886,11 +881,8 @@ void ConstraintExpression::Print(std::ostream& out, std::vector constraint_strings; constraint_strings.reserve(disjunction.size()); for (const auto& [expr, interval] : disjunction) { - std::stringstream ss; - printer.Print(ss, expr); - ss << " in "; - interval.Print(ss); - constraint_strings.push_back(ss.str()); + constraint_strings.push_back(absl::StrCat(xla::gpu::ToString(expr), + " in ", interval.ToString())); } std::sort(constraint_strings.begin(), constraint_strings.end()); conjunction_strings.push_back(absl::StrJoin(constraint_strings, " && ")); @@ -1019,7 +1011,7 @@ void ConstraintExpression::Simplify() { /*static*/ std::optional SymbolicTile::FromIndexingMap( IndexingMap indexing_map) { - VLOG(1) << "SymbolicTile::FromIndexingMap: " << indexing_map.ToString(); + VLOG(1) << "SymbolicTile::FromIndexingMap: " << indexing_map; // We do not handle indexing maps with pre-existing constraints for now. // Let's try to simplify the indexing map, because the constraints my be @@ -1030,7 +1022,7 @@ void ConstraintExpression::Simplify() { if (indexing_map.GetConstraintsCount() != 0) { VLOG(1) << "Deriving symbolic tile from indexing map with pre-existing " << "constraints might produce spurious constraints. Bailing out. " - << indexing_map.ToString(); + << indexing_map; return std::nullopt; } @@ -1104,17 +1096,16 @@ void ConstraintExpression::Simplify() { offset = offset + size * stride - stride; stride = -stride; } else if (!constant) { - AffineMapPrinter printer; VLOG(1) << "Unexpected non-constant stride expression: " - << printer.ToString(stride); + << xla::gpu::ToString(stride); } } // DimVars in `indexing_map` represent indices, but in `tile_map` they will // represent the size of the tile. So we need to add 1 to the bounds. // For example: indices: [0, 9] -> sizes: [1, 10]. - std::vector tile_sizes = indexing_map.GetDimVars(); - for (DimVar& tile_size : tile_sizes) { + std::vector tile_sizes = indexing_map.GetDimVars(); + for (IndexingMap::Variable& tile_size : tile_sizes) { tile_size.bounds.lower += 1; tile_size.bounds.upper += 1; } @@ -1139,45 +1130,33 @@ void ConstraintExpression::Simplify() { /*rt_vars=*/indexing_map.GetRTVars()); tile_map.RemoveUnusedSymbols(); CHECK_EQ(tile_map.GetRangeVarsCount(), 0); - VLOG(1) << "tile_map: " << tile_map.ToString(); + VLOG(1) << "tile_map: " << tile_map; constraints.Simplify(); return SymbolicTile(std::move(tile_map), std::move(constraints)); } -std::string SymbolicTile::RtVarsToString( - const AffineMapPrinter& printer) const { - std::string s; - std::stringstream ss(s); - PrintRTVars(tile_map_.GetRTVars(), /*first_rt_var_symbol_index=*/0, ss, - printer); - return ss.str(); -} - -std::string SymbolicTile::ToString(const AffineMapPrinter& printer) const { - std::string s; - std::stringstream ss(s); - Print(ss, printer); +std::string SymbolicTile::ToString() const { + std::stringstream ss; + Print(ss); return ss.str(); } -void SymbolicTile::Print(std::ostream& out, - const AffineMapPrinter& printer) const { +void SymbolicTile::Print(std::ostream& out) const { out << "Symbolic tile with \n"; - out << "\toffset_map: "; - printer.Print(out, offset_map()); - out << "\n\tsize_map: "; - printer.Print(out, size_map()); - out << "\n\tstride_map: "; - printer.Print(out, stride_map()); - const std::vector& rt_vars = tile_map_.GetRTVars(); + out << "\toffset_map: " << offset_map(); + out << "\n\tsize_map: " << size_map(); + out << "\n\tstride_map: " << stride_map(); + const std::vector& rt_vars = tile_map_.GetRTVars(); if (!rt_vars.empty()) { out << "\n\trt_vars: "; - PrintRTVars(rt_vars, /*first_rt_var_symbol_index=*/0, out, printer); + for (const auto& [index, rt_var] : llvm::enumerate(rt_vars)) { + out << 's' << index << " in " << rt_var.bounds << ", "; + } } if (!constraints_.IsAlwaysSatisfied()) { out << "\n\tconstraints: "; - constraints_.Print(out, printer); + constraints_.Print(out); } } diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile.h b/third_party/xla/xla/service/gpu/model/symbolic_tile.h index a86cd363daf1e8..c8e19f27112e4e 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile.h +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile.h @@ -29,7 +29,6 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" -#include "xla/service/gpu/model/affine_map_printer.h" #include "xla/service/gpu/model/indexing_map.h" namespace xla { @@ -110,10 +109,9 @@ class ConstraintExpression { return disjoint_conjoint_constraints_; } - std::string ToString( - const AffineMapPrinter& printer = AffineMapPrinter()) const; + std::string ToString() const; - void Print(std::ostream& out, const AffineMapPrinter& printer) const; + void Print(std::ostream& out) const; // Simplifies the constraint expression. // @@ -285,12 +283,9 @@ class SymbolicTile { static std::optional FromIndexingMap(IndexingMap indexing_map); // For printing in tests. - std::string RtVarsToString( - const AffineMapPrinter& printer = AffineMapPrinter()) const; - std::string ToString( - const AffineMapPrinter& printer = AffineMapPrinter()) const; + std::string ToString() const; - void Print(std::ostream& out, const AffineMapPrinter& printer) const; + void Print(std::ostream& out) const; mlir::AffineMap offset_map() const; mlir::AffineMap size_map() const; diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc index 1ab971ff2a4a94..43358a5781ce1d 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc @@ -50,9 +50,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/hlo_traversal.h" -#include "xla/service/gpu/model/affine_map_printer.h" #include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/symbolic_tile.h" #include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" @@ -132,7 +132,7 @@ absl::StatusOr ComputeTileOffsetIndexing( })) { return absl::FailedPreconditionError( absl::StrCat("Symbol lower bound is not zero. ", - tiled_hlo.indexing_map().ToString())); + ToString(tiled_hlo.indexing_map()))); } std::vector symbol_lower_bounds( @@ -228,9 +228,8 @@ FusionDecision ShouldProceedWithSymbolicTileDerivation( // Bail out on instructions that are known to cause problems down the // line. This is not an inherent limitation of the approach, but simply // issues to be resolved in the current implementation. - if (hlo->opcode() == HloOpcode::kDot || - hlo->opcode() == HloOpcode::kConcatenate) { - return FusionDecision{} << "Bailing out on " << hlo->ToString(); + if (hlo->opcode() == HloOpcode::kConcatenate) { + return FusionDecision::Forbid("Bailing out on ") << hlo->ToString(); } // Due to the issue highlighted in b/365727080, and the related workaround @@ -254,13 +253,13 @@ FusionDecision ShouldProceedWithSymbolicTileDerivation( SymbolicTile::FromIndexingMap(reshape_indexing_map); if (!reshape_symbolic_tile.has_value()) { - return FusionDecision{} << "Bailing out on reshape " << hlo->ToString() - << " with indexing map " - << reshape_indexing_map.ToString(); + return FusionDecision::Forbid("Bailing out on reshape ") + << hlo->ToString() << " with indexing map " + << ToString(reshape_indexing_map); } } - return {}; + return FusionDecision::Allow(); } // Sets a SymbolicTile for each tiled hlo instruction and computes their @@ -292,16 +291,14 @@ SetSymbolicTilesAndComputeConstraints( auto symbolic_tile = SymbolicTile::FromIndexingMap(indexing_map); if (!symbolic_tile.has_value()) { - return FusionDecision{} << "Failed to compute symbolic tile for " - << indexing_map.ToString() << " for HLO " - << hlo->ToString(); + return FusionDecision::Forbid("Failed to compute symbolic tile for ") + << ToString(indexing_map) << " for HLO " << hlo->ToString(); } if (!symbolic_tile->is_satisfiable()) { - return FusionDecision{} << "Symbolic tile " << symbolic_tile->ToString() - << " is not satisfiable for " - << indexing_map.ToString() << " for HLO " - << hlo->ToString(); + return FusionDecision::Forbid("Symbolic tile ") + << symbolic_tile->ToString() << " is not satisfiable for " + << ToString(indexing_map) << " for HLO " << hlo->ToString(); } constraints = ConstraintExpression::And(std::move(constraints), @@ -309,7 +306,7 @@ SetSymbolicTilesAndComputeConstraints( constraints.Simplify(); if (!constraints.is_satisfiable()) { - return FusionDecision{} << "Fusion has unsatisfiable constraints"; + return FusionDecision::Forbid("Fusion has unsatisfiable constraints"); } tiled_hlo_instruction->set_symbolic_tile(*std::move(symbolic_tile)); @@ -347,7 +344,7 @@ void SortTiledHloInstructionsInPostOrder( }); } -} // namespace +} // anonymous namespace /*static*/ SymbolicTileAnalysisOrError SymbolicTileAnalysis::AnalyzeComputation( const HloComputation& computation, MLIRContext* ctx, @@ -365,8 +362,8 @@ void SortTiledHloInstructionsInPostOrder( auto roots = fusion.GetRoots(); if (roots.size() > 1) { - return FusionDecision{} << "Multi-output fusions are not supported. " - << fusion.ToString(); + return FusionDecision::Forbid("Multi-output fusions are not supported. ") + << fusion.ToString(); } auto& root = roots[0]; @@ -399,8 +396,8 @@ void SortTiledHloInstructionsInPostOrder( ComposeIndexingMaps(tiled_hlo_instruction->indexing_map(), *operand_indexing_map_set.begin()); if (operand_indexing_map.IsUndefined()) { - return FusionDecision{} - << "Couldn't derive indexing map for instruction " + return FusionDecision::Forbid( + "Couldn't derive indexing map for instruction ") << tiled_hlo_instruction->hlo()->ToString() << " and operand " << operand.instruction().ToString(); } @@ -564,7 +561,8 @@ SymbolicTileAnalysis::ComputeTiledHloInstructions( std::optional tile_offset_indexing; if (compute_all_tile_offset_indexing_maps || - parameters_with_offset_indexing.contains(symbolic_tiled_hlo->hlo())) { + parameters_with_offset_indexing.contains(symbolic_tiled_hlo->hlo()) || + symbolic_tiled_hlo->hlo()->opcode() == HloOpcode::kIota) { TF_ASSIGN_OR_RETURN( tile_offset_indexing, ComputeTileOffsetIndexing( @@ -594,8 +592,7 @@ SymbolicTileAnalysis::ComputeTiledHloInstructions( output_tiling_info.num_output_tiles_per_dim); } -std::string SymbolicTileAnalysis::ToString( - const AffineMapPrinter& printer) const { +std::string SymbolicTileAnalysis::ToString() const { std::stringstream ss; NameUniquer name_uniquer("_"); absl::flat_hash_map tile_names; diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h index 58a08afde9ba17..e8e0cea2fef4b8 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h @@ -31,7 +31,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/hlo_traversal.h" -#include "xla/service/gpu/model/affine_map_printer.h" #include "xla/service/gpu/model/symbolic_tile.h" #include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" @@ -91,7 +90,7 @@ class SymbolicTileAnalysis { // Returns a graph of HLO instructions tiled with the given tile parameters. // The provided tile parameters must satisfy the analysis's constraints. - // By default, `ComputetiledHloInstructions` performs a check that the + // By default, `ComputeTiledHloInstructions` performs a check that the // constraints are satisfied by the chosen tiled parameters. Setting // `constraints_are_known_satisfied` to true bypasses this check. // @@ -141,8 +140,7 @@ class SymbolicTileAnalysis { // Returns a string representation of the analysis. Used only for error // messages and debugging. - std::string ToString( - const AffineMapPrinter& printer = AffineMapPrinter()) const; + std::string ToString() const; // Returns a list of tilings for the symbolic tiled HLO computation of the // analysis that are expected to perform well. diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc index 99f136d3aecae2..c9ca6877e0bcdc 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/types/span.h" #include "mlir/IR/MLIRContext.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/service/gpu/model/symbolic_tile.h" @@ -40,9 +41,9 @@ limitations under the License. #include "xla/service/gpu/model/tiled_hlo_instruction.h" #include "xla/service/instruction_fusion.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" +#include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" namespace xla { @@ -168,8 +169,7 @@ ENTRY main { (d0, d1) -> (d0, d1 * 10), domain: d0 in [0, 1], - d1 in [0, 9], - is_simplified: true + d1 in [0, 9] )")); auto p0_from_subtract0 = root->operand(0); @@ -182,8 +182,7 @@ ENTRY main { (d0, d1) -> (d0, d1 * 10), domain: d0 in [0, 1], - d1 in [0, 9], - is_simplified: true + d1 in [0, 9] )")); EXPECT_THAT(*p0_from_subtract1, MatchTiledHloInstruction( @@ -193,8 +192,7 @@ ENTRY main { (d0, d1) -> (d0, 0), domain: d0 in [0, 1], - d1 in [0, 9], - is_simplified: true + d1 in [0, 9] )")); } @@ -286,8 +284,7 @@ ENTRY main { (d0, d1) -> (d0, 0), domain: d0 in [0, 1], - d1 in [0, 0], - is_simplified: true + d1 in [0, 0] )")); } @@ -321,8 +318,7 @@ ENTRY main { domain: d0 in [0, 1], d1 in [0, 1], - d2 in [0, 7], - is_simplified: true + d2 in [0, 7] )")); EXPECT_THAT(*root->operand(0), @@ -333,8 +329,7 @@ ENTRY main { domain: d0 in [0, 1], d1 in [0, 1], - d2 in [0, 7], - is_simplified: true + d2 in [0, 7] )")); } @@ -371,8 +366,7 @@ ENTRY main { (d0, d1) -> (d0 * 2, d1 * 2), domain: d0 in [0, 1], - d1 in [0, 3], - is_simplified: true + d1 in [0, 3] )")); EXPECT_THAT(*p0_from_slice0, @@ -382,8 +376,7 @@ ENTRY main { (d0, d1) -> (d0 * 2, d1 * 2 + 2), domain: d0 in [0, 1], - d1 in [0, 3], - is_simplified: true + d1 in [0, 3] )")); EXPECT_THAT(*p0_from_slice1, @@ -393,28 +386,63 @@ ENTRY main { (d0, d1) -> (d0 * 2 + 3, d1 * 2 + 4), domain: d0 in [0, 1], - d1 in [0, 3], - is_simplified: true + d1 in [0, 3] )")); } -TEST_F(SymbolicTileAnalysisTest, BailOutOnUnsupportedDot) { +TEST_F(SymbolicTileAnalysisTest, DotOffsetIndexingIsCorrect) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( fusion { - p0 = f32[1,2]{1,0} parameter(0) - p1 = f32[2,3]{1,0} parameter(1) - ROOT dot = f32[1,3]{1,0} dot(p0, p1), - lhs_batch_dims={}, rhs_batch_dims={}, + p0 = f32[4,8] parameter(0) + p1 = f32[8,16] parameter(1) + ROOT dot = f32[4,16] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} } ENTRY main { - p0 = f32[1,2]{1,0} parameter(0) - p1 = f32[2,3]{1,0} parameter(1) - ROOT fusion = f32[1,3]{1,0} fusion(p0, p1), kind=kLoop, calls=fusion + p0 = f32[4,8] parameter(0) + p1 = f32[8,16] parameter(1) + ROOT fusion = f32[4,16] fusion(p0, p1), kind=kLoop, calls=fusion })")); - EXPECT_FALSE(TryAnalyzeModule(module.get()).has_value()); + std::optional analysis = TryAnalyzeModule(module.get()); + ASSERT_TRUE(analysis.has_value()); + + TF_ASSERT_OK_AND_ASSIGN(TiledHloComputation tiled_hlo_computation, + analysis->ComputeTiledHloInstructions( + /*tile_parameters=*/{2, 2}, + /*constraints_are_known_satisfied=*/false, + /*compute_all_tile_offset_indexing_maps=*/true)); + + const TiledHloInstruction* dot = tiled_hlo_computation.GetRoot(); + EXPECT_THAT(*dot, MatchTiledHloInstruction( + /*tile_sizes=*/{2, 2}, /*tile_strides=*/{1, 1}, + /*tile_offsets_indexing=*/R"( + (d0, d1) -> (d0 * 2, d1 * 2), + domain: + d0 in [0, 1], + d1 in [0, 7] + )")); + + const TiledHloInstruction* lhs = dot->operand(0); + EXPECT_THAT(*lhs, MatchTiledHloInstruction( + /*tile_sizes=*/{2, 8}, /*tile_strides=*/{1, 1}, + /*tile_offsets_indexing=*/R"( + (d0, d1) -> (d0 * 2, 0), + domain: + d0 in [0, 1], + d1 in [0, 7] + )")); + + const TiledHloInstruction* rhs = dot->operand(1); + EXPECT_THAT(*rhs, MatchTiledHloInstruction( + /*tile_sizes=*/{8, 2}, /*tile_strides=*/{1, 1}, + /*tile_offsets_indexing=*/R"( + (d0, d1) -> (0, d1 * 2), + domain: + d0 in [0, 1], + d1 in [0, 7] + )")); } TEST_F(SymbolicTileAnalysisTest, DoesNotBailOutOnConstrainedReshape) { @@ -871,8 +899,7 @@ ENTRY main { (d0, d1) -> (d0, d1), domain: d0 in [0, 65537], - d1 in [0, 32767], - is_simplified: true + d1 in [0, 32767] )")); } @@ -927,25 +954,19 @@ ENTRY main { (d0, d1) -> (0, d1, 0), domain: d0 in [0, 0], - d1 in [0, 1], - is_simplified: true + d1 in [0, 1] )")); EXPECT_THAT(*param_0_tile, MatchTiledHloInstruction( /*tile_sizes=*/{1, 1, 32}, /*tile_strides=*/{0, 1, 1}, /*tile_offsets_indexing=*/R"( - (d0, d1)[s0, s1] -> (s0, d1, s1), + (d0, d1){rt0, rt1} -> (rt0, d1, rt1), domain: d0 in [0, 0], d1 in [0, 1], - s0 in [0, 1], - hlo: %of1 = s32[] parameter(1), - (d0, d1, d2) -> (), - s1 in [0, 226], - hlo: %of3 = s32[] parameter(3), - (d0, d1, d2) -> (), - is_simplified: true + rt0 in [0, 1], + rt1 in [0, 226] )")); } @@ -1020,6 +1041,29 @@ ENTRY main { EXPECT_TRUE(analysis.has_value()); } +TEST_F(SymbolicTileAnalysisTest, IotaAlwaysHasTileOffsetsIndexingSet) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +fusion { + ROOT iota = s32[100] iota(), iota_dimension=0 +} + +ENTRY main { + ROOT fusion = s32[100] fusion(), kind=kLoop, calls=fusion +})")); + std::optional analysis = TryAnalyzeModule(module.get()); + ASSERT_TRUE(analysis.has_value()); + + TF_ASSERT_OK_AND_ASSIGN(TiledHloComputation tiled_hlo_computation, + analysis->ComputeTiledHloInstructions( + /*tile_parameters=*/{4}, + /*constraints_are_known_satisfied=*/false, + /*compute_all_tile_offset_indexing_maps=*/false)); + + const TiledHloInstruction* iota = tiled_hlo_computation.GetRoot(); + EXPECT_THAT(iota->tile_offsets_indexing().status(), ::tsl::testing::IsOk()); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc index 78e7068ee45196..493a331e52555f 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc @@ -407,11 +407,7 @@ TEST_F(SymbolicTileTest, CanPropagateTileThroughDynamicSlice) { stride_map: (d0, d1, d2) -> (0, 1, 1) rt_vars: s0 in [0, 1], - hlo: %of1 = s32[] parameter(1), - (d0, d1, d2) -> (), s1 in [0, 226], - hlo: %of3 = s32[] parameter(3), - (d0, d1, d2) -> (), )"))); for (int i = 1; i <= 3; i++) { EXPECT_THAT( @@ -459,11 +455,7 @@ TEST_F(SymbolicTileTest, CanPropagateTileThroughDynamicUpdateSlice) { stride_map: (d0, d1) -> (1, 1) rt_vars: s0 in [0, 15], - hlo: %of1 = s32[] parameter(2), - (d0, d1) -> (), s1 in [0, 20], - hlo: %of2 = s32[] parameter(3), - (d0, d1) -> (), )"))); for (int i = 2; i <= 3; i++) { EXPECT_THAT( @@ -502,11 +494,7 @@ TEST_F(SymbolicTileTest, CanPropagateTileThroughGather) { stride_map: (d0, d1, d2, d3) -> (1, 1, 1) rt_vars: s0 in [0, 26], - hlo: %indices = s32[1806,2]{1,0} parameter(1), - (d0, d1, d2, d3) -> (d0, 0), s1 in [0, 68], - hlo: %indices = s32[1806,2]{1,0} parameter(1), - (d0, d1, d2, d3) -> (d0, 1), )"))); EXPECT_THAT( diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tiled_hlo_instruction.cc b/third_party/xla/xla/service/gpu/model/symbolic_tiled_hlo_instruction.cc index 4a6c067638cebd..fb687372b9d0be 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tiled_hlo_instruction.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tiled_hlo_instruction.cc @@ -49,7 +49,7 @@ std::string SymbolicTiledHloInstruction::ToString() const { std::stringstream ss; ss << "\thlo: " << hlo_->ToString() << "\n"; ss << "\t" << symbolic_tile().ToString() << "\n"; - ss << "\tindexing map: " << indexing_map_.ToString() << "\n"; + ss << "\tindexing map: " << indexing_map_ << "\n"; return ss.str(); } diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tiled_hlo_instruction_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tiled_hlo_instruction_test.cc index 62f603d3c54dff..21a270cfd7866e 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tiled_hlo_instruction_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tiled_hlo_instruction_test.cc @@ -28,7 +28,6 @@ limitations under the License. #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/model/symbolic_tile.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/model/tiled_hlo_computation.h b/third_party/xla/xla/service/gpu/model/tiled_hlo_computation.h index 5708a4c3401c36..13d7456dfccaec 100644 --- a/third_party/xla/xla/service/gpu/model/tiled_hlo_computation.h +++ b/third_party/xla/xla/service/gpu/model/tiled_hlo_computation.h @@ -27,8 +27,8 @@ limitations under the License. #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/model/tiled_hlo_instruction.h" +#include "xla/tsl/lib/gtl/iterator_range.h" #include "xla/util.h" -#include "tsl/lib/gtl/iterator_range.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.cc b/third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.cc index e68db3040c816a..997556007fbcb6 100644 --- a/third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.cc +++ b/third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.cc @@ -31,6 +31,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/indexing_map_serialization.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/util.h" #include "tsl/platform/errors.h" @@ -67,7 +68,7 @@ absl::Status VerifyTiledHloInstructionConstructorPreconditions( return absl::InvalidArgumentError(absl::StrFormat( "tile_offsets_indexing must have the same number of results as the " "rank of the hlo shape. tile_offsets_indexing = %s, hlo = %s", - tile_offsets_indexing->ToString(), hlo->ToString())); + ToString(*tile_offsets_indexing), hlo->ToString())); } return absl::OkStatus(); @@ -97,8 +98,9 @@ std::string TiledHloInstruction::ToString() const { ss << "\ttile_sizes: (" << absl::StrJoin(tile_sizes_, ", ") << ")\n"; ss << "\ttile_strides: (" << absl::StrJoin(tile_strides_, ", ") << ")\n"; ss << "\ttile_offsets_indexing: " - << (tile_offsets_indexing_.has_value() ? tile_offsets_indexing_->ToString() - : "nullopt"); + << (tile_offsets_indexing_.has_value() + ? gpu::ToString(*tile_offsets_indexing_) + : "nullopt"); return ss.str(); } diff --git a/third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.h b/third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.h index 146035b0cb1e55..409bbdf6b4e62a 100644 --- a/third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.h +++ b/third_party/xla/xla/service/gpu/model/tiled_hlo_instruction.h @@ -71,12 +71,14 @@ class TiledHloInstruction { // Returns the tile sizes. The number of tile sizes is equal to the rank of // the output shape. const llvm::SmallVector& tile_sizes() const { return tile_sizes_; } + int64_t tile_size(int64_t dim_idx) const { return tile_sizes_[dim_idx]; } // Returns the tile strides. The number of tile strides is equal to the rank // of the output shape. const llvm::SmallVector& tile_strides() const { return tile_strides_; } + int64_t tile_stride(int64_t dim_idx) const { return tile_strides_[dim_idx]; } // Returns the indexing map from tile multi-index to tile offsets. The map has // a form of `(d0, d1, ...) -> (tile_offset0, tile_offset1, ...)`. The number diff --git a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints_test.cc b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints_test.cc index 7a171bf1c76a6d..c4dc4a61173466 100644 --- a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints_test.cc +++ b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints_test.cc @@ -26,13 +26,13 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/model/symbolic_tile_analysis.h" #include "xla/service/instruction_fusion.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.cc b/third_party/xla/xla/service/gpu/nvptx_compiler.cc index e29420c8bfb8b3..eda7b9e05c10f3 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.cc +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.cc @@ -42,17 +42,22 @@ limitations under the License. #include "llvm/IRReader/IRReader.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/pass/hlo_pass_fix.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" +#include "xla/hlo/transforms/simplifiers/convert_mover.h" +#include "xla/hlo/transforms/simplifiers/dot_dimension_merger.h" +#include "xla/hlo/transforms/simplifiers/float_normalization.h" +#include "xla/hlo/transforms/simplifiers/hlo_constant_folding.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/reshape_mover.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/pjrt/distributed/key_value_store_interface.h" -#include "xla/service/algebraic_simplifier.h" #include "xla/service/call_inliner.h" -#include "xla/service/convert_mover.h" -#include "xla/service/dot_dimension_merger.h" #include "xla/service/dump.h" -#include "xla/service/float_normalization.h" #include "xla/service/float_support.h" #include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/autotuning/conv_algorithm_picker.h" @@ -64,6 +69,7 @@ limitations under the License. #include "xla/service/gpu/gpu_compiler.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" +#include "xla/service/gpu/llvm_gpu_backend/nvptx_utils.h" #include "xla/service/gpu/metrics.h" #include "xla/service/gpu/target_constants.h" #include "xla/service/gpu/transforms/algebraic_simplifier.h" @@ -83,18 +89,12 @@ limitations under the License. #include "xla/service/gpu/transforms/gpusolver_rewriter.h" #include "xla/service/gpu/transforms/sort_rewriter.h" #include "xla/service/gpu/transforms/triangular_solve_rewriter.h" -#include "xla/service/hlo_constant_folding.h" #include "xla/service/hlo_cse.h" -#include "xla/service/hlo_dataflow_analysis.h" -#include "xla/service/hlo_dce.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_verifier.h" #include "xla/service/llvm_ir/llvm_util.h" -#include "xla/service/reshape_mover.h" -#include "xla/service/tuple_simplifier.h" #include "xla/stream_executor/cuda/cuda_asm_compiler.h" #include "xla/stream_executor/cuda/cuda_diagnostics.h" -#include "xla/stream_executor/cuda/cuda_driver.h" // IWYU pragma : keep - Needed for GpuContext #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/cuda/nvjitlink.h" #include "xla/stream_executor/cuda/nvjitlink_support.h" @@ -103,11 +103,9 @@ limitations under the License. #include "xla/stream_executor/cuda/ptx_compiler_support.h" #include "xla/stream_executor/cuda/ptx_linking_method.h" #include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/gpu/gpu_asm_opts.h" #include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/semantic_version.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/util/env_var.h" @@ -187,7 +185,6 @@ class MatmulBfloat16Support : public FloatSupport { absl::Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( HloModule* hlo_module, se::GpuComputeCapability gpu_version, se::dnn::VersionInfo dnn_version, - se::DeviceMemoryAllocator* device_allocator, const se::SemanticVersion& toolkit_version) { auto cuda_compute_capability = std::get(gpu_version); @@ -228,6 +225,7 @@ absl::Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( GetAlgebraicSimplifierOptions(hlo_module->config()); algsimp_options.set_supports_non_canonical_dots(false); algsimp_options.set_enable_conv_operand_swap(false); + algsimp_options.set_enable_conv_add_multiply_reorder(true); algsimp_options.set_enable_unconditional_reduce_of_concat_replacement(false); pipeline.AddPass>(algsimp_options, gpu_version); @@ -296,6 +294,7 @@ absl::Status NVPTXCompiler::OptimizeHloPostLayoutAssignment( alg_sim_options.set_supports_non_canonical_dots(false); alg_sim_options.set_is_layout_sensitive(true); alg_sim_options.set_enable_conv_operand_swap(false); + alg_sim_options.set_enable_conv_add_multiply_reorder(true); // "slow" minmax means we propagate nan. alg_sim_options.set_minmax_propagate_nan( !hlo_module->config().debug_options().xla_gpu_enable_fast_min_max()); @@ -1003,8 +1002,7 @@ absl::StatusOr> NVPTXCompiler::LinkModules( return LinkUsingNvlink(cc, debug_options.xla_gpu_cuda_data_dir(), cubin_images); } - return LinkGpuAsm(cc, se::gpu::ExtractGpuExecutor(stream_exec)->gpu_context(), - cubin_images); + return LinkGpuAsm(cc, stream_exec, cubin_images); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.h b/third_party/xla/xla/service/gpu/nvptx_compiler.h index 78591bb2c42a7d..b22af269d17829 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.h +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.h @@ -29,17 +29,16 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "llvm/IR/Module.h" #include "xla/autotune_results.pb.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/gpu_compiler.h" #include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_module_config.h" #include "xla/stream_executor/cuda/ptx_linking_method.h" #include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/semantic_version.h" #include "xla/stream_executor/stream_executor.h" @@ -59,7 +58,6 @@ class NVPTXCompiler : public GpuCompiler { absl::Status OptimizeHloConvolutionCanonicalization( HloModule* hlo_module, se::GpuComputeCapability gpu_version, se::dnn::VersionInfo dnn_version, - se::DeviceMemoryAllocator* device_allocator, const se::SemanticVersion& toolkit_version) override; absl::Status OptimizeHloPostLayoutAssignment( diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler_test.cc b/third_party/xla/xla/service/gpu/nvptx_compiler_test.cc index f43066672b2bef..43894e0a7ba7a3 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/nvptx_compiler_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/analysis/hlo_ordering.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" @@ -30,7 +31,6 @@ limitations under the License. #include "xla/service/gpu/gpu_constants.h" #include "xla/service/gpu/gpu_hlo_schedule.h" #include "xla/service/gpu/gpu_latency_hiding_scheduler.h" -#include "xla/service/hlo_ordering.h" #include "xla/service/logical_buffer.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" @@ -87,7 +87,7 @@ class NVPTXCompilerTest : public HloTestBase { class NVPTXCompilerTestTriton : public NVPTXCompilerTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_gpu_cublas_fallback(false); return debug_options; diff --git a/third_party/xla/xla/service/gpu/outfeed_manager.cc b/third_party/xla/xla/service/gpu/outfeed_manager.cc index 000531bde3e020..2d76f26ffd3348 100644 --- a/third_party/xla/xla/service/gpu/outfeed_manager.cc +++ b/third_party/xla/xla/service/gpu/outfeed_manager.cc @@ -22,6 +22,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" +#include "xla/stream_executor/stream_executor.h" #include "tsl/platform/logging.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc b/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc index 8704737b711e25..ef1097a00ad1d4 100644 --- a/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc +++ b/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc @@ -19,16 +19,16 @@ limitations under the License. #include #include +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/service/copy_insertion.h" #include "xla/service/cpu_gpu_shape_verifier.h" #include "xla/service/gpu/transforms/alias_passthrough_params.h" #include "xla/service/gpu/transforms/copy_fusion.h" #include "xla/service/gpu/transforms/horizontal_loop_fusion.h" #include "xla/service/gpu/transforms/sanitize_constant_names.h" -#include "xla/service/hlo_dataflow_analysis.h" -#include "xla/service/hlo_dce.h" #include "xla/service/hlo_verifier.h" #include "xla/service/layout_assignment.h" #include "xla/service/loop_schedule_linearizer.h" @@ -48,12 +48,13 @@ HloPassPipeline PrepareHloModuleForIrEmittingPipeline( // (b/27180329). Therefore, in that case, we set the output to be a copy of // the parameter. HloPassPipeline pipeline("GPU-ir-emit-prepare"); + HloVerifierOpts opts = + HloVerifierOpts{}.MakeLayoutSensitive().WithInstructionCanChangeLayout( + LayoutAssignment::InstructionCanChangeLayout); + opts.verify_unique_channel_ids = + !debug_options.xla_experimental_ignore_channel_id(); std::unique_ptr verifier_metadata = - std::make_unique( - HloVerifierOpts{} - .MakeLayoutSensitive() - .WithInstructionCanChangeLayout( - LayoutAssignment::InstructionCanChangeLayout)); + std::make_unique(std::move(opts)); pipeline.AddInvariantCheckerDebug(std::move(verifier_metadata), "hlo verifier (debug)"); diff --git a/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.h b/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.h index 095907f39794ac..89c925e2308fef 100644 --- a/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.h +++ b/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_PREPARE_HLO_FOR_IR_EMITTING_PIPELINE_H_ #define XLA_SERVICE_GPU_PREPARE_HLO_FOR_IR_EMITTING_PIPELINE_H_ +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" -#include "xla/service/hlo_dataflow_analysis.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/ptx_compilation_test.cc b/third_party/xla/xla/service/gpu/ptx_compilation_test.cc index 967c0e494d67cb..8c4524a4d4cf99 100644 --- a/third_party/xla/xla/service/gpu/ptx_compilation_test.cc +++ b/third_party/xla/xla/service/gpu/ptx_compilation_test.cc @@ -217,7 +217,7 @@ class NVPTXCompilationTests debug_options->set_xla_llvm_force_inline_before_split(false); } - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { auto debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_gpu_autotune_level(0); return debug_options; diff --git a/third_party/xla/xla/service/gpu/reduce_scatter_combiner.cc b/third_party/xla/xla/service/gpu/reduce_scatter_combiner.cc new file mode 100644 index 00000000000000..a289aa12899401 --- /dev/null +++ b/third_party/xla/xla/service/gpu/reduce_scatter_combiner.cc @@ -0,0 +1,93 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/reduce_scatter_combiner.h" + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/gpu_collective_combiner_utils.h" +#include "xla/service/gpu/gpu_hlo_schedule.h" +#include "xla/service/hlo_domain_map.h" +#include "xla/service/reduce_scatter_combiner.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { +namespace { + +std::optional PipelinedCombinerKey( + const HloInstruction* instruction, const HloDomainMap& domain_map, + bool combine_by_dim) { + auto combined_key = ReduceScatterCombiner::CombineKey(instruction, domain_map, + combine_by_dim); + if (!combined_key.has_value()) { + return std::nullopt; + } + auto backend_config = instruction->backend_config(); + if (!backend_config.ok()) { + return std::nullopt; + } + bool is_pipelined = + backend_config->collective_backend_config().is_pipelined(); + if (!is_pipelined) { + return std::nullopt; + } + ReduceScatterCombiner::GetGroupKeyExtraArgs(*combined_key) + .append(" " + std::to_string(static_cast(is_pipelined))); + return combined_key.value(); +} + +} // namespace + +absl::StatusOr GpuReduceScatterCombiner::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + // Combiner threshold is specified. Running parent pass code. + if (combine_threshold_in_bytes_ != default_combine_threshold_in_bytes_) { + return ReduceScatterCombiner::Run(module, execution_threads); + } + + // Pass configuration heuristics are not enabled. Running parent pass code. + if (!module->config() + .debug_options() + .xla_gpu_enable_heuristic_pass_configuration()) { + return ReduceScatterCombiner::Run(module, execution_threads); + } + + // Combine as much as possible for pipelined collectives. + int previous_combiner_threshold = combine_threshold_in_bytes_; + combine_threshold_in_bytes_ = ComputeSuggestedCombinerThreshold( + *module, device_info_, ScheduleGpuModuleWithMemoryScheduler, + HloOpcode::kReduceScatter, pointer_size_); + TF_ASSIGN_OR_RETURN( + bool combined_pipelined_instructions, + RunWithKeyCombiner(module, execution_threads, PipelinedCombinerKey)); + + // Use previous combiner thresholds after we combine pipelined collectives. + // The rest is combined by the parent pass code. + combine_threshold_in_bytes_ = previous_combiner_threshold; + TF_ASSIGN_OR_RETURN(bool combined_rest, + ReduceScatterCombiner::Run(module, execution_threads)); + return combined_pipelined_instructions || combined_rest; +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/reduce_scatter_combiner.h b/third_party/xla/xla/service/gpu/reduce_scatter_combiner.h new file mode 100644 index 00000000000000..b09f39ba3f171c --- /dev/null +++ b/third_party/xla/xla/service/gpu/reduce_scatter_combiner.h @@ -0,0 +1,64 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_REDUCE_SCATTER_COMBINER_H_ +#define XLA_SERVICE_GPU_REDUCE_SCATTER_COMBINER_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/service/reduce_scatter_combiner.h" +#include "xla/stream_executor/device_description.h" +namespace xla::gpu { + +// Similarly to `ReduceScatterCombiner` pass, combines `ReduceScatter` ops into +// a single larger `ReduceScatter` op to maximize network bandwidth usage. +// Additionally, if no flags are set for combiner thresholds, the pass will try +// to figure out the optimal combiner threshold by itself. +class GpuReduceScatterCombiner : public ReduceScatterCombiner { + public: + GpuReduceScatterCombiner(const se::DeviceDescription& device_info, + const int default_combine_threshold_in_bytes, + int64_t combine_threshold_in_bytes, + int64_t combine_threshold_count, bool combine_by_dim, + int64_t pointer_size) + : ReduceScatterCombiner(combine_threshold_in_bytes, + combine_threshold_count, combine_by_dim), + device_info_(device_info), + default_combine_threshold_in_bytes_(default_combine_threshold_in_bytes), + pointer_size_(pointer_size) {} + + absl::string_view name() const override { + return "gpu-reduce-scatter-combiner"; + } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + const se::DeviceDescription& device_info_; + const int default_combine_threshold_in_bytes_; + int64_t pointer_size_; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_REDUCE_SCATTER_COMBINER_H_ diff --git a/third_party/xla/xla/service/gpu/reduce_scatter_combiner_test.cc b/third_party/xla/xla/service/gpu/reduce_scatter_combiner_test.cc new file mode 100644 index 00000000000000..539ff0031a85ee --- /dev/null +++ b/third_party/xla/xla/service/gpu/reduce_scatter_combiner_test.cc @@ -0,0 +1,362 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/reduce_scatter_combiner.h" + +#include +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/testlib/filecheck.h" +#include "xla/service/collective_utils.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { +namespace { + +using GpuReduceScatterCombinerTest = HloTestBase; + +using ::stream_executor::DeviceDescription; + +TEST_F(GpuReduceScatterCombinerTest, + CombinesPipelinedCollectivesUpToSuggestedThreshold) { + // The IR is the minimal valid example of a while loop with RS inside. Three + // are annotated as pipelined and three are not. Various configurations of the + // combiner are tested to ensure the expected behaviour. + constexpr absl::string_view kHloString = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], + bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(8) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], + bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0) + param.0 = s32[] get-tuple-element(param), index=0 + param.pipelined.0 = bf16[6,8,128] get-tuple-element(param), index=1 + param.pipelined.1 = bf16[6,8,128] get-tuple-element(param), index=2 + param.pipelined.2 = bf16[6,8,128] get-tuple-element(param), index=3 + param.nonpipelined.0 = bf16[6,8,128] get-tuple-element(param), index=4 + param.nonpipelined.1 = bf16[6,8,128] get-tuple-element(param), index=5 + param.nonpipelined.2 = bf16[6,8,128] get-tuple-element(param), index=6 + param.7 = bf16[3,1,2,128] get-tuple-element(param), index=7 + zero = bf16[] constant(0) + one = s32[] constant(1) + it = s32[] add(param.0, one) + ag.pipelined.0 = bf16[6,8,128] reduce-scatter(param.pipelined.0), + dimensions={0}, to_apply=add, + backend_config={"collective_backend_config": {"is_pipelined": true}} + ag.pipelined.1 = bf16[6,8,128] reduce-scatter(param.pipelined.1), + dimensions={0}, to_apply=add, + backend_config={"collective_backend_config": {"is_pipelined": true}} + ag.pipelined.2 = bf16[6,8,128] reduce-scatter(param.pipelined.2), + to_apply=add, dimensions={0}, + backend_config={"collective_backend_config": {"is_pipelined": true}} + ag.nonpipelined.0 = bf16[6,8,128] reduce-scatter(param.nonpipelined.0), + dimensions={0}, to_apply=add + ag.nonpipelined.1 = bf16[6,8,128] reduce-scatter(param.nonpipelined.1), + dimensions={0}, to_apply=add + ag.nonpipelined.2 = bf16[6,8,128] reduce-scatter(param.nonpipelined.2), + dimensions={0}, to_apply=add + ROOT tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(it, ag.pipelined.0, ag.pipelined.1, ag.pipelined.2, ag.nonpipelined.0, ag.nonpipelined.1, ag.nonpipelined.2, param.7) +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[6,8,128] parameter(0) + p1 = bf16[3,1,2,128] parameter(1) + tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(c0, p0, p0, p0, p0, p0, p0, p1) + while = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) while(tuple), condition=while_cond, body=while_body + ROOT _ = bf16[6,8,128] get-tuple-element(while), index=1 +} +)"; + auto config = + GetModuleConfigForTest(/*replica_count=*/1, /*num_partitions=*/2); + config.mutable_debug_options() + .set_xla_gpu_enable_heuristic_pass_configuration(true); + DeviceDescription device_info; + // Combine at most 2 collectives. + int collective_size = 2 * 6 * 8 * 128; + int threshold_bytes = 2 * collective_size; + int current_peak_mem = 87625; + int pointer_size = 4; + device_info.set_device_memory_size(current_peak_mem + 4 * threshold_bytes); + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloString, config)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, GpuReduceScatterCombiner( + device_info, /*default_combine_threshold_in_bytes=*/ + threshold_bytes, + /*combine_threshold_in_bytes=*/threshold_bytes, + /*combine_threshold_count=*/256, + /*combine_by_dim=*/false, pointer_size) + .Run(module.get())); + + VLOG(1) << module->ToString(); + EXPECT_TRUE(changed); + // Pipelined all gathers were combined up to the predefined max available + // device mem limit. + const absl::string_view kExpected = R"( + // CHECK-DAG: %[[PIPELINED_PARAM_0:.*]] = {{.*}} index=1 + // CHECK-DAG: %[[PIPELINED_PARAM_1:.*]] = {{.*}} index=2 + // CHECK-DAG: %[[PIPELINED_PARAM_2:.*]] = {{.*}} index=3 + // CHECK-DAG: %[[NONPIPELINED_PARAM_0:.*]] = {{.*}} index=4 + // CHECK-DAG: %[[NONPIPELINED_PARAM_1:.*]] = {{.*}} index=5 + // CHECK-DAG: %[[NONPIPELINED_PARAM_2:.*]] = {{.*}} index=6 + // CHECK-DAG: reduce-scatter(%[[PIPELINED_PARAM_0]], %[[PIPELINED_PARAM_1]], %[[PIPELINED_PARAM_2]]) + // CHECK-DAG: reduce-scatter(%[[NONPIPELINED_PARAM_0]], %[[NONPIPELINED_PARAM_1]]) + // CHECK-DAG: reduce-scatter(%[[NONPIPELINED_PARAM_2]]) + )"; + EXPECT_TRUE( + *RunFileCheck(module->ToString(HloPrintOptions() + .set_print_operand_shape(false) + .set_print_result_shape(false)), + kExpected)); +} + +TEST_F(GpuReduceScatterCombinerTest, + CombinesCollectivesUpToSpecifiedThreshold) { + // The IR is the minimal valid example of a while loop with RS inside. Three + // are annotated as pipelined and three are not. Various configurations of the + // combiner are tested to ensure the expected behaviour. + constexpr absl::string_view kHloString = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], + bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(8) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], + bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0) + param.0 = s32[] get-tuple-element(param), index=0 + param.pipelined.0 = bf16[6,8,128] get-tuple-element(param), index=1 + param.pipelined.1 = bf16[6,8,128] get-tuple-element(param), index=2 + param.pipelined.2 = bf16[6,8,128] get-tuple-element(param), index=3 + param.nonpipelined.0 = bf16[6,8,128] get-tuple-element(param), index=4 + param.nonpipelined.1 = bf16[6,8,128] get-tuple-element(param), index=5 + param.nonpipelined.2 = bf16[6,8,128] get-tuple-element(param), index=6 + param.7 = bf16[3,1,2,128] get-tuple-element(param), index=7 + zero = bf16[] constant(0) + one = s32[] constant(1) + it = s32[] add(param.0, one) + ag.pipelined.0 = bf16[6,8,128] reduce-scatter(param.pipelined.0), + dimensions={0}, to_apply=add, + backend_config={"collective_backend_config": {"is_pipelined": true}} + ag.pipelined.1 = bf16[6,8,128] reduce-scatter(param.pipelined.1), + dimensions={0}, to_apply=add, + backend_config={"collective_backend_config": {"is_pipelined": true}} + ag.pipelined.2 = bf16[6,8,128] reduce-scatter(param.pipelined.2), + dimensions={0}, to_apply=add, + backend_config={"collective_backend_config": {"is_pipelined": true}} + ag.nonpipelined.0 = bf16[6,8,128] reduce-scatter(param.nonpipelined.0), + dimensions={0}, to_apply=add + ag.nonpipelined.1 = bf16[6,8,128] reduce-scatter(param.nonpipelined.1), + dimensions={0}, to_apply=add + ag.nonpipelined.2 = bf16[6,8,128] reduce-scatter(param.nonpipelined.2), + dimensions={0}, to_apply=add + ROOT tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(it, ag.pipelined.0, ag.pipelined.1, ag.pipelined.2, ag.nonpipelined.0, ag.nonpipelined.1, ag.nonpipelined.2, param.7) +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[6,8,128] parameter(0) + p1 = bf16[3,1,2,128] parameter(1) + tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(c0, p0, p0, p0, p0, p0, p0, p1) + while = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) while(tuple), condition=while_cond, body=while_body + ROOT _ = bf16[6,8,128] get-tuple-element(while), index=1 +} +)"; + auto config = + GetModuleConfigForTest(/*replica_count=*/1, /*num_partitions=*/2); + config.mutable_debug_options() + .set_xla_gpu_enable_heuristic_pass_configuration(true); + DeviceDescription device_info; + // Combine at most 2 collectives. + int collective_size = 2 * 6 * 8 * 128; + int threshold_bytes = 2 * collective_size; + int current_peak_mem = 87625; + int pointer_size = 4; + device_info.set_device_memory_size(current_peak_mem + threshold_bytes * 4); + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloString, config)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, GpuReduceScatterCombiner( + device_info, /*default_combine_threshold_in_bytes=*/ + kDefaultReduceScatterCombineThreshold, + /*combine_threshold_in_bytes=*/threshold_bytes, + /*combine_threshold_count=*/256, + /*combine_by_dim=*/false, pointer_size) + .Run(module.get())); + + VLOG(1) << module->ToString(); + EXPECT_TRUE(changed); + // Pipelined all gathers were combined up to the predefined max available + // device mem limit. + const absl::string_view kExpected = R"( + // CHECK-DAG: %[[PIPELINED_PARAM_0:.*]] = {{.*}} index=1 + // CHECK-DAG: %[[PIPELINED_PARAM_1:.*]] = {{.*}} index=2 + // CHECK-DAG: %[[PIPELINED_PARAM_2:.*]] = {{.*}} index=3 + // CHECK-DAG: %[[NONPIPELINED_PARAM_0:.*]] = {{.*}} index=4 + // CHECK-DAG: %[[NONPIPELINED_PARAM_1:.*]] = {{.*}} index=5 + // CHECK-DAG: %[[NONPIPELINED_PARAM_2:.*]] = {{.*}} index=6 + // CHECK-DAG: reduce-scatter(%[[PIPELINED_PARAM_0]], %[[PIPELINED_PARAM_1]]) + // CHECK-DAG: reduce-scatter(%[[PIPELINED_PARAM_2]], %[[NONPIPELINED_PARAM_0]]) + // CHECK-DAG: reduce-scatter(%[[NONPIPELINED_PARAM_1]], %[[NONPIPELINED_PARAM_2]]) + )"; + + EXPECT_TRUE( + *RunFileCheck(module->ToString(HloPrintOptions() + .set_print_operand_shape(false) + .set_print_result_shape(false)), + kExpected)); +} + +TEST_F(GpuReduceScatterCombinerTest, + CombinesCollectivesUpToDefaultThresholdIfFlagDisabled) { + // The IR is the minimal valid example of a while loop with RS inside. Three + // are annotated as pipelined and three are not. Various configurations of the + // combiner are tested to ensure the expected behaviour. + constexpr absl::string_view kHloString = R"( +HloModule module + +add { + lhs = bf16[] parameter(0) + rhs = bf16[] parameter(1) + ROOT add = bf16[] add(lhs, rhs) +} + +while_cond { + param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], + bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0) + gte = s32[] get-tuple-element(param), index=0 + constant.1 = s32[] constant(8) + ROOT cmp = pred[] compare(gte, constant.1), direction=LT +} + +while_body { + param = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], + bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) parameter(0) + param.0 = s32[] get-tuple-element(param), index=0 + param.pipelined.0 = bf16[6,8,128] get-tuple-element(param), index=1 + param.pipelined.1 = bf16[6,8,128] get-tuple-element(param), index=2 + param.pipelined.2 = bf16[6,8,128] get-tuple-element(param), index=3 + param.nonpipelined.0 = bf16[6,8,128] get-tuple-element(param), index=4 + param.nonpipelined.1 = bf16[6,8,128] get-tuple-element(param), index=5 + param.nonpipelined.2 = bf16[6,8,128] get-tuple-element(param), index=6 + param.7 = bf16[3,1,2,128] get-tuple-element(param), index=7 + zero = bf16[] constant(0) + one = s32[] constant(1) + it = s32[] add(param.0, one) + ag.pipelined.0 = bf16[6,8,128] reduce-scatter(param.pipelined.0), + dimensions={0}, to_apply=add, + backend_config={"collective_backend_config": {"is_pipelined": true}} + ag.pipelined.1 = bf16[6,8,128] reduce-scatter(param.pipelined.1), + dimensions={0}, to_apply=add, + backend_config={"collective_backend_config": {"is_pipelined": true}} + ag.pipelined.2 = bf16[6,8,128] reduce-scatter(param.pipelined.2), + dimensions={0}, to_apply=add, + backend_config={"collective_backend_config": {"is_pipelined": true}} + ag.nonpipelined.0 = bf16[6,8,128] reduce-scatter(param.nonpipelined.0), + dimensions={0}, to_apply=add + ag.nonpipelined.1 = bf16[6,8,128] reduce-scatter(param.nonpipelined.1), + dimensions={0}, to_apply=add + ag.nonpipelined.2 = bf16[6,8,128] reduce-scatter(param.nonpipelined.2), + dimensions={0}, to_apply=add + ROOT tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(it, ag.pipelined.0, ag.pipelined.1, ag.pipelined.2, ag.nonpipelined.0, ag.nonpipelined.1, ag.nonpipelined.2, param.7) +} + +ENTRY entry { + c0 = s32[] constant(0) + p0 = bf16[6,8,128] parameter(0) + p1 = bf16[3,1,2,128] parameter(1) + tuple = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) tuple(c0, p0, p0, p0, p0, p0, p0, p1) + while = (s32[], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[6,8,128], bf16[3,1,2,128]) while(tuple), condition=while_cond, body=while_body + ROOT _ = bf16[6,8,128] get-tuple-element(while), index=1 +} +)"; + auto config = + GetModuleConfigForTest(/*replica_count=*/1, /*num_partitions=*/2); + config.mutable_debug_options() + .set_xla_gpu_enable_heuristic_pass_configuration(false); + DeviceDescription device_info; + // Combine at most 2 collectives. + int collective_size = 2 * 6 * 8 * 128; + int threshold_bytes = 2 * collective_size; + int current_peak_mem = 87625; + int pointer_size = 4; + device_info.set_device_memory_size(current_peak_mem + threshold_bytes * 4); + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloString, config)); + + TF_ASSERT_OK_AND_ASSIGN( + bool changed, GpuReduceScatterCombiner( + device_info, /*default_combine_threshold_in_bytes=*/ + kDefaultReduceScatterCombineThreshold, + /*combine_threshold_in_bytes=*/threshold_bytes, + /*combine_threshold_count=*/256, + /*combine_by_dim=*/false, pointer_size) + .Run(module.get())); + + VLOG(1) << module->ToString(); + EXPECT_TRUE(changed); + // Pipelined all gathers were combined up to the predefined max available + // device mem limit. + const absl::string_view kExpected = R"( + // CHECK-DAG: %[[PIPELINED_PARAM_0:.*]] = {{.*}} index=1 + // CHECK-DAG: %[[PIPELINED_PARAM_1:.*]] = {{.*}} index=2 + // CHECK-DAG: %[[PIPELINED_PARAM_2:.*]] = {{.*}} index=3 + // CHECK-DAG: %[[NONPIPELINED_PARAM_0:.*]] = {{.*}} index=4 + // CHECK-DAG: %[[NONPIPELINED_PARAM_1:.*]] = {{.*}} index=5 + // CHECK-DAG: %[[NONPIPELINED_PARAM_2:.*]] = {{.*}} index=6 + // CHECK-DAG: reduce-scatter(%[[PIPELINED_PARAM_0]], %[[PIPELINED_PARAM_1]]) + // CHECK-DAG: reduce-scatter(%[[PIPELINED_PARAM_2]], %[[NONPIPELINED_PARAM_0]]) + // CHECK-DAG: reduce-scatter(%[[NONPIPELINED_PARAM_1]], %[[NONPIPELINED_PARAM_2]]) + )"; + EXPECT_TRUE( + *RunFileCheck(module->ToString(HloPrintOptions() + .set_print_operand_shape(false) + .set_print_result_shape(false)), + kExpected)); +} + +} // namespace +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/reduction_utils_test.cc b/third_party/xla/xla/service/gpu/reduction_utils_test.cc index 11868007c4631e..4a7db5677f79ef 100644 --- a/third_party/xla/xla/service/gpu/reduction_utils_test.cc +++ b/third_party/xla/xla/service/gpu/reduction_utils_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "absl/strings/str_cat.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/tests/hlo_test_base.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index 76dd3bff1304dd..b76f904a363943 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -58,15 +58,20 @@ cc_library( deps = [ ":annotation", ":custom_call_thunk", + ":dynamic_slice_thunk", ":nccl_all_gather_thunk", ":nccl_all_reduce_thunk", + ":nccl_all_to_all_thunk", ":nccl_api", ":nccl_clique_key", ":nccl_collective_broadcast_thunk", ":nccl_collective_thunk", ":thunk", + ":while_thunk", "//xla:debug_options_flags", "//xla:executable_run_options", + "//xla:shape_util", + "//xla:status_macros", "//xla:types", "//xla:util", "//xla/ffi:call_frame", @@ -78,18 +83,22 @@ cc_library( "//xla/service:computation_placer", "//xla/service:custom_call_status_internal", "//xla/service:custom_call_status_public_headers", - "//xla/service:executable", "//xla/service:global_device_id", "//xla/service/gpu:buffer_allocations", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:matmul_utils", "//xla/service/gpu:stream_executor_util", "//xla/service/gpu/kernels:custom_kernel", - "//xla/stream_executor", "//xla/stream_executor:command_buffer", + "//xla/stream_executor:device_memory", "//xla/stream_executor:dnn", - "//xla/stream_executor:lazy_op_runner", + "//xla/stream_executor:kernel", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor:memory_allocation", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:trace_command_buffer_factory", + "//xla/stream_executor/gpu:gpu_blas_lt", "//xla/stream_executor/gpu:gpu_stream_header", "//xla/stream_executor/gpu:gpu_types_header", "//xla/tsl/concurrency:ref_count", @@ -101,9 +110,11 @@ cc_library( "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", @@ -128,14 +139,15 @@ cc_library( ":memset_thunk", ":nccl_all_gather_thunk", ":nccl_all_reduce_thunk", + ":nccl_all_to_all_thunk", ":nccl_collective_thunk", ":replica_id_thunk", ":sequential_thunk", + ":thunk", ":wait_for_streams_thunk", ":while_thunk", "//xla:util", "//xla/service:buffer_assignment", - "//xla/service/gpu/runtime:thunk", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -158,10 +170,11 @@ xla_test( "//xla/service:platform_util", "//xla/service/gpu:buffer_allocations", "//xla/service/gpu:launch_dimensions", - "//xla/stream_executor", "//xla/stream_executor:command_buffer", + "//xla/stream_executor:device_memory", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", "//xla/stream_executor/gpu:gpu_test_kernels_fatbin", "//xla/tsl/lib/core:status_test_util", @@ -203,8 +216,10 @@ cc_library( "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/service:collective_ops_utils", - "//xla/stream_executor", - "//xla/stream_executor/gpu:scoped_activate_context", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/tsl/concurrency:ref_count", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", @@ -241,7 +256,8 @@ cc_library( "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/service:collective_ops_utils", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:stream", "//xla/tsl/concurrency:ref_count", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -264,13 +280,12 @@ cc_library( "//xla/service:global_device_id", "//xla/service:lockable", "//xla/service:rendezvous", - "//xla/stream_executor", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/functional:function_ref", - "@com_google_absl//absl/hash", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -293,6 +308,7 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ "//xla/service:global_device_id", + "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/crc:crc32c", "@com_google_absl//absl/status", @@ -301,7 +317,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/gtl:int_type", + "@local_tsl//tsl/platform:logging", ], ) @@ -337,8 +353,10 @@ cc_library( "//xla/service:buffer_assignment", "//xla/service/gpu:buffer_allocations", "//xla/service/gpu:ir_emission_utils", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/stream_executor:memory_allocation", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", @@ -379,16 +397,20 @@ xla_test( "//xla/service/gpu:buffer_allocations", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:matmul_utils", - "//xla/stream_executor", "//xla/stream_executor:blas", "//xla/stream_executor:command_buffer", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", "//xla/stream_executor/gpu:gpu_test_kernels", "//xla/stream_executor/gpu:gpu_types_header", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:statusor", @@ -414,7 +436,6 @@ cc_library( "//xla/service:buffer_assignment", "//xla/hlo/ir:hlo", "@local_tsl//tsl/platform:logging", - "//xla/stream_executor", "//xla/stream_executor:blas", "//xla/stream_executor:device_memory", "//xla/stream_executor/gpu:gpu_asm_opts", @@ -422,7 +443,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", - ]), + ]) + ["//xla/stream_executor:stream_executor_h"], ) cc_library( @@ -436,8 +457,9 @@ cc_library( ":thunk", "//xla/service:buffer_assignment", # build_cleaner: keep "//xla/service/gpu:buffer_allocations", # build_cleaner: keep - "//xla/stream_executor", "//xla/stream_executor:command_buffer", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", @@ -481,16 +503,20 @@ xla_test( "//xla/service/gpu:buffer_allocations", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:matmul_utils", - "//xla/stream_executor", "//xla/stream_executor:blas", "//xla/stream_executor:command_buffer", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:kernel", "//xla/stream_executor:kernel_spec", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", "//xla/stream_executor/gpu:gpu_test_kernels", "//xla/stream_executor/gpu:gpu_test_kernels_fatbin", "//xla/stream_executor/gpu:gpu_types_header", + "//xla/tests:hlo_test_base", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -515,8 +541,9 @@ cc_library( "//xla:util", "//xla/service:buffer_assignment", "//xla/service/gpu:variant_visitor", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/stream_executor:memory_allocation", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", @@ -540,9 +567,10 @@ cc_library( "//xla:util", "//xla/service:buffer_assignment", "//xla/service/gpu:gpu_conv_runner", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/stream_executor:dnn", "//xla/stream_executor:scratch_allocator", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", @@ -565,8 +593,9 @@ cc_library( ":thunk", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/stream_executor:event", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", @@ -594,14 +623,16 @@ cc_library( "//xla/service:buffer_assignment", "//xla/service/gpu:buffer_allocations", "//xla/service/gpu/runtime:thunk", - "//xla/stream_executor", "//xla/stream_executor/gpu:gpu_stream", "//xla/stream_executor/gpu:gpu_types_header", "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", "@local_tsl//tsl/platform:errors", - ] + ["//xla/service/gpu:cub_sort_kernel_" + suffix for suffix in get_cub_sort_kernel_types()]), + ] + ["//xla/service/gpu:cub_sort_kernel_" + suffix for suffix in get_cub_sort_kernel_types()]) + [ + "//xla/stream_executor:device_memory", + "//xla/stream_executor:stream", + ], ) cc_library( @@ -626,8 +657,9 @@ cc_library( "//xla/service:custom_call_status", "//xla/service:custom_call_status_internal", "//xla/service/gpu:buffer_allocations", - "//xla/stream_executor", "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream", "//xla/stream_executor/gpu:gpu_stream_header", "//xla/stream_executor/gpu:gpu_types_header", "@com_google_absl//absl/memory", @@ -651,10 +683,12 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/service:buffer_assignment", - "//xla/stream_executor", "//xla/stream_executor:blas", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:fft", "//xla/stream_executor:scratch_allocator", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", @@ -689,13 +723,13 @@ cc_library( srcs = ["gpublas_lt_matmul_thunk.cc"], hdrs = ["gpublas_lt_matmul_thunk.h"], deps = [ + ":thunk", "//xla:status_macros", "//xla/service:buffer_assignment", "//xla/service/gpu:buffer_allocations", "//xla/service/gpu:matmul_utils", - "//xla/service/gpu/runtime:thunk", - "//xla/stream_executor", "//xla/stream_executor:device_memory", + "//xla/stream_executor:stream", "//xla/stream_executor/gpu:gpu_blas_lt", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -719,8 +753,9 @@ cc_library( "//xla/service/gpu:buffer_allocations", "//xla/service/gpu:gpu_transfer_manager", "//xla/service/gpu:io_feed_manager", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_handle", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -741,7 +776,10 @@ cc_library( "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:stream_executor_util", "//xla/service/gpu/kernels:custom_kernel", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:kernel", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", @@ -762,7 +800,7 @@ cc_library( deps = [ ":thunk", "//xla/service:buffer_assignment", - "//xla/stream_executor", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/status", ], ) @@ -774,12 +812,12 @@ cc_library( deps = [ ":nccl_api", ":nccl_collective_thunk", + ":thunk", "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/service:collective_ops_utils", "//xla/service/gpu:backend_configs_cc", - "//xla/service/gpu/runtime:thunk", - "//xla/stream_executor", + "//xla/stream_executor:stream", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -796,13 +834,13 @@ cc_library( deps = [ ":nccl_api", ":nccl_collective_thunk", + ":thunk", "//xla:status_macros", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:collective_ops_utils", "//xla/service/gpu:backend_configs_cc", - "//xla/service/gpu/runtime:thunk", - "//xla/stream_executor", + "//xla/stream_executor:stream", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", @@ -819,16 +857,23 @@ cc_library( hdrs = ["nccl_all_to_all_thunk.h"], deps = [ ":nccl_api", + ":nccl_clique_key", ":nccl_collective_thunk", + ":thunk", "//xla:shape_util", "//xla:status_macros", "//xla/hlo/ir:hlo", "//xla/service:collective_ops_utils", - "//xla/service/gpu:ir_emission_utils", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:stream", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@llvm-project//mlir:IR", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", @@ -846,7 +891,8 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:collective_ops_utils", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:stream", "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", @@ -868,7 +914,7 @@ cc_library( "//xla/service:collective_ops_utils", "//xla/service:global_device_id", "//xla/service/gpu:backend_configs_cc", - "//xla/stream_executor", + "//xla/stream_executor:stream", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:node_hash_map", @@ -895,6 +941,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/translate/mhlo_to_hlo:attribute_exporter", "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", "//xla/service:computation_placer", @@ -903,12 +950,13 @@ cc_library( "//xla/service/gpu:buffer_allocations", "//xla/service/gpu:ir_emission_utils", "//xla/service/llvm_ir:llvm_util", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/stream_executor:event", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:gpu_driver_header", "//xla/stream_executor/gpu:gpu_stream", "//xla/stream_executor/gpu:gpu_types_header", - "//xla/translate/mhlo_to_hlo:attribute_exporter", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -941,8 +989,8 @@ cc_library( "//xla:status_macros", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/service:collective_ops_utils", - "//xla/service:hlo_parser", "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -964,13 +1012,13 @@ cc_library( ":nccl_clique_key", ":nccl_collective_thunk", ":nccl_p2p_thunk_common", + ":thunk", "//xla:status_macros", "//xla/hlo/ir:hlo", "//xla/service:collective_ops_utils", "//xla/service:computation_placer", "//xla/service:global_device_id", - "//xla/service/gpu/runtime:thunk", - "//xla/stream_executor", + "//xla/stream_executor:stream", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", @@ -989,13 +1037,13 @@ cc_library( ":nccl_clique_key", ":nccl_collective_thunk", ":nccl_p2p_thunk_common", + ":thunk", "//xla:status_macros", "//xla/hlo/ir:hlo", "//xla/service:collective_ops_utils", "//xla/service:computation_placer", "//xla/service:global_device_id", - "//xla/service/gpu/runtime:thunk", - "//xla/stream_executor", + "//xla/stream_executor:stream", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", @@ -1010,12 +1058,12 @@ cc_library( srcs = ["norm_thunk.cc"], hdrs = ["norm_thunk.h"], deps = [ + ":thunk", "//xla:util", "//xla:xla_data_proto_cc", "//xla/service:buffer_assignment", "//xla/service/gpu:gpu_norm_runner", - "//xla/service/gpu/runtime:thunk", - "//xla/stream_executor", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", @@ -1037,7 +1085,8 @@ cc_library( "//xla/service/gpu:buffer_allocations", "//xla/service/gpu:gpu_transfer_manager", "//xla/service/gpu:io_feed_manager", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@local_tsl//tsl/platform:errors", @@ -1049,9 +1098,9 @@ cc_library( srcs = ["replica_id_thunk.cc"], hdrs = ["replica_id_thunk.h"], deps = [ + ":thunk", "//xla/service:buffer_assignment", "//xla/service:global_device_id", - "//xla/service/gpu/runtime:thunk", "@com_google_absl//absl/status", "@local_tsl//tsl/platform:statusor", ], @@ -1066,7 +1115,7 @@ cc_library( ]), deps = [ ":annotation", - "//xla/service/gpu/runtime:thunk", + ":thunk", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", @@ -1084,8 +1133,9 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/service:buffer_assignment", "//xla/service:global_device_id", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/stream_executor:event", + "//xla/stream_executor:stream_executor_h", "//xla/tsl/concurrency:async_value", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -1111,6 +1161,7 @@ cc_library( "//xla:executable_run_options", "//xla/ffi:execution_context", "//xla/hlo/ir:hlo", + "//xla/hlo/translate/mhlo_to_hlo:location_exporter", "//xla/service:buffer_assignment", "//xla/service:executable", "//xla/service:global_device_id", @@ -1118,8 +1169,9 @@ cc_library( "//xla/service/gpu:buffer_allocations", "//xla/service/gpu:gpu_executable_run_options", "//xla/service/gpu:ir_emission_utils", - "//xla/stream_executor", - "//xla/translate/mhlo_to_hlo:location_exporter", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", + "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", @@ -1129,7 +1181,6 @@ cc_library( "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@llvm-project//mlir:IR", - "@local_tsl//tsl/lib/gtl:int_type", "@local_tsl//tsl/platform:statusor", ], ) @@ -1183,14 +1234,16 @@ cc_library( "//xla/service/gpu:buffer_allocations", "//xla/service/gpu:make_batch_pointers", "//xla/service/gpu/runtime:thunk", - "//xla/stream_executor", "//xla/stream_executor:blas", "//xla/stream_executor:device_memory", "//xla/stream_executor/gpu:gpu_asm_opts", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:status", - ]) + ["//xla:status_macros"], + ]) + [ + "//xla:status_macros", + "//xla/stream_executor:stream_executor_h", + ], ) cc_library( @@ -1201,8 +1254,9 @@ cc_library( ":sequential_thunk", ":thunk", "//xla/service:buffer_assignment", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/stream_executor:memory_allocation", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", @@ -1237,7 +1291,7 @@ cc_library( ":thunk", "//xla/service:buffer_assignment", "//xla/service/gpu:kernel_arguments", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/stream_executor:dnn", "@com_google_absl//absl/base", "@com_google_absl//absl/status", diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc index abfa6eb465cd99..580775c588c2d6 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/runtime/command_buffer_cmd.h" +#include #include #include #include @@ -24,6 +25,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/algorithm/container.h" @@ -38,6 +40,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" #include "xla/debug_options_flags.h" #include "xla/executable_run_options.h" #include "xla/ffi/call_frame.h" @@ -51,21 +54,26 @@ limitations under the License. #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/runtime/annotation.h" +#include "xla/service/gpu/runtime/dynamic_slice_thunk.h" #include "xla/service/gpu/runtime/nccl_all_gather_thunk.h" #include "xla/service/gpu/runtime/nccl_all_reduce_thunk.h" +#include "xla/service/gpu/runtime/nccl_all_to_all_thunk.h" #include "xla/service/gpu/runtime/nccl_api.h" #include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/gpu/runtime/nccl_collective_broadcast_thunk.h" #include "xla/service/gpu/runtime/nccl_collective_thunk.h" #include "xla/service/gpu/runtime/thunk.h" +#include "xla/service/gpu/runtime/while_thunk.h" #include "xla/service/gpu/stream_executor_util.h" -#include "xla/service/service_executable_run_options.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/gpu/gpu_blas_lt.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" -#include "xla/stream_executor/lazy_op_runner.h" +#include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/trace_command_buffer_factory.h" @@ -87,17 +95,6 @@ limitations under the License. namespace xla::gpu { -namespace { -std::optional AssignBufferIfNotNull( - const BufferAllocations& buffer_allocations, - BufferAllocation::Slice& slice) { - return slice.allocation() != nullptr - ? std::optional{buffer_allocations - .GetDeviceAddress(slice)} - : std::nullopt; -} -} // namespace - using ExecutionScopeId = se::CommandBuffer::ExecutionScopeId; using MemoryAccess = CommandBufferCmd::MemoryAccess; @@ -1422,7 +1419,8 @@ absl::Status CustomCallCmd::RecordLegacyCustomCall( } ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); - VLOG(5) << "CustomCallCmd: execution_scope_id=" << execution_scope_id.value(); + VLOG(5) << "CustomCallCmd: target_name=" << target_name_ + << ", execution_scope_id=" << execution_scope_id.value(); for (int i = 0; i < operands_.size(); ++i) { if (operands_[i].has_value()) { VLOG(5) << " Operand " << i << ": " << operands_[i]->slice << " (" @@ -1477,7 +1475,8 @@ absl::Status CustomCallCmd::RecordXlaFfiCall( ffi::CallFrameBuilder builder(operands_.size(), results_.size()); ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); - VLOG(5) << "CustomCallCmd: execution_scope_id=" << execution_scope_id.value(); + VLOG(5) << "CustomCallCmd: target_name=" << target_name_ + << ", execution_scope_id=" << execution_scope_id.value(); for (int i = 0; i < operands_.size(); ++i) { const std::optional& slice = operands_[i]; @@ -1511,7 +1510,7 @@ absl::Status CustomCallCmd::RecordXlaFfiCall( execute_params.buffer_allocations->GetDeviceAddress(slice->slice); VLOG(5) << " Result " << i << ": " << slice->slice << " (" << buffer.opaque() << ")"; - builder.AddBufferArg(buffer, slice->shape.element_type(), + builder.AddBufferRet(buffer, slice->shape.element_type(), slice->shape.dimensions()); } @@ -1529,7 +1528,7 @@ absl::Status CustomCallCmd::RecordXlaFfiCall( ffi::CallOptions options = { execute_params.buffer_allocations->device_ordinal(), ffi::CallOptions::GpuOptions{ - execute_params.stream, + stream, execute_params.buffer_allocations->memory_allocator()}, /*called_computation=*/nullptr, // TODO(b/342285364) execute_params.ffi_execution_context}; @@ -1780,6 +1779,76 @@ CommandBufferCmd::BufferUsageVector ReduceScatterCmd::buffers() { return buffer_usage; } +//===----------------------------------------------------------------------===// +// AllToAllCmd +//===----------------------------------------------------------------------===// + +AllToAllCmd::AllToAllCmd(ExecutionStreamId execution_stream_id, + ExecutionStreamId async_from_stream_id, + NcclApi* nccl_api, NcclCollectiveConfig config, + bool has_split_dimension, + absl::Span buffers) + : CollectiveCmd(CommandBufferCmdType::kAllToAll, execution_stream_id, + async_from_stream_id, nccl_api, std::move(config)), + has_split_dimension_(has_split_dimension), + buffers_(buffers.begin(), buffers.end()) {} + +absl::Status AllToAllCmd::Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) { + TF_RETURN_IF_ERROR(BarrierIfAsync( + command_buffer, execute_params.stream->parent(), record_params)); + + TF_ASSIGN_OR_RETURN( + std::vector device_buffers, + ConvertToDeviceBuffers(execute_params.buffer_allocations, buffers_, + config().operand_element_type)); + + ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); + VLOG(5) << "AllToAllCmd, has_split_dimension=" << has_split_dimension_ + << ", execution_scope_id=" << execution_scope_id.value(); + + for (size_t i = 0; i < device_buffers.size(); ++i) { + VLOG(5) << " Src: " << buffers_[i].source_buffer << " (" + << device_buffers[i].source_buffer.opaque() << ")"; + VLOG(5) << " Dst: " << buffers_[i].destination_buffer << " (" + << device_buffers[i].destination_buffer.opaque() << ")"; + } + + if (!execute_params.collective_params || !execute_params.collective_cliques) { + return absl::InvalidArgumentError( + "ReduceScatterCmd requires collective parameters and cliques"); + } + + TF_ASSIGN_OR_RETURN( + NcclCommHandleWrapper comm_handle, + GetNcclComm(*execute_params.collective_params, + *execute_params.collective_cliques, config().replica_groups, + config().group_mode, nccl_stream_id(), GetAsyncStreamKind())); + NcclApi::NcclCommHandle comm = comm_handle.comm_handle; + // Use custom allocator for persistent execution plans. + NcclApi::ScopedPersistentPlanAllocator scoped_allocator( + comm, tsl::MakeRef( + execute_params.buffer_allocations->device_ordinal(), + execute_params.buffer_allocations->memory_allocator(), + execute_params.stream)); + + return AddTracedCommandBuffer( + execute_params, record_params, command_buffer, [&](se::Stream* stream) { + return RunAllToAll(nccl_api(), has_split_dimension_, device_buffers, + *stream, comm); + }); +} + +CommandBufferCmd::BufferUsageVector AllToAllCmd::buffers() { + BufferUsageVector buffer_usage; + for (auto& buffer : buffers_) { + buffer_usage.emplace_back(buffer.source_buffer, MemoryAccess::kRead); + buffer_usage.emplace_back(buffer.destination_buffer, MemoryAccess::kWrite); + } + return buffer_usage; +} + //===----------------------------------------------------------------------===// // AllGatherCmd //===----------------------------------------------------------------------===// @@ -1917,4 +1986,262 @@ CommandBufferCmd::BufferUsageVector CollectiveBroadcastCmd::buffers() { return buffer_usage; } +//===----------------------------------------------------------------------===// +// DynamicSliceFusionCmd +//===----------------------------------------------------------------------===// + +DynamicSliceFusionCmd::DynamicSliceFusionCmd( + ExecutionStreamId execution_stream_id, + std::unique_ptr embedded_commands, + std::vector> arguments, + std::vector> fake_allocations, + std::vector>> offsets, + std::vector> orig_shapes, + std::vector> sliced_shapes, + std::vector> offset_byte_sizes) + : CommandBufferCmd(CommandBufferCmdType::kDynamicSliceFusionCmd, + execution_stream_id), + embedded_commands_(std::move(embedded_commands)), + fake_allocations_(std::move(fake_allocations)) { + // Zip all arguments together to create a list of SliceDef. + for (auto [arg, offset, orig_shape, sliced_shape, offset_byte_size] : + llvm::zip_equal(arguments, offsets, orig_shapes, sliced_shapes, + offset_byte_sizes)) { + slices_.push_back(DynamicSliceThunk::SliceDef{ + std::move(arg), + std::move(offset), + std::move(orig_shape), + std::move(sliced_shape), + std::move(offset_byte_size), + }); + } + + for (auto [argument_idx, slice] : llvm::enumerate(slices_)) { + embeded_to_origin_slice_map_[argument_idx] = slice.embedded_thunk_argument; + } + + // Find how many offsets we might have to transfer from device to host and + // pre-compute host allocation requirements. + for (DynamicSliceThunk::SliceDef& slice : slices_) { + offsets_allocs_base_.push_back(offsets_allocs_size_); + if (slice.sliced_shape.has_value()) { + offsets_allocs_size_ += slice.sliced_shape->rank() * sizeof(int64_t); + } + } +} + +// Force update the command when there is any non-constant value slice offset, +// because the memory address might changed if the offset is loop +// iterator or operator outputs even if the parent command's memory pointers do +// not change. +bool DynamicSliceFusionCmd::force_update() { + return !absl::c_all_of(slices_, [](const DynamicSliceThunk::SliceDef& slice) { + if (!slice.offsets.has_value()) return true; + return absl::c_all_of(slice.offsets.value(), + [](DynamicSliceThunk::Offset offset) { + return std::holds_alternative(offset); + }); + }); +} + +absl::Status DynamicSliceFusionCmd::Initialize( + const Thunk::InitializeParams& params, StateManager& state) { + TF_RETURN_IF_ERROR(embedded_commands_->Initialize(params, state)); + absl::MutexLock lock(&mutex_); + if (offsets_allocs_.contains(params.executor)) return absl::OkStatus(); + + VLOG(2) << "Allocate " << offsets_allocs_size_ + << " bytes for transferring offsets on executor: " << params.executor; + TF_ASSIGN_OR_RETURN( + std::unique_ptr allocation, + params.executor->HostMemoryAllocate(offsets_allocs_size_)); + offsets_allocs_.emplace(params.executor, std::move(allocation)); + return absl::OkStatus(); +} + +absl::Status DynamicSliceFusionCmd::Prepare( + const Thunk::PrepareParams& params, + Thunk::ResourceRequests& resource_requests) { + for (DynamicSliceThunk::SliceDef& slice : slices_) { + if (slice.offsets.has_value()) { + TF_RET_CHECK(slice.embedded_thunk_argument.has_value()); + TF_RET_CHECK(slice.orig_shape.has_value()); + TF_RET_CHECK(slice.sliced_shape.has_value()); + TF_RET_CHECK(slice.offset_byte_size.has_value()); + + TF_RET_CHECK(slice.orig_shape->IsArray()); + TF_RET_CHECK(slice.sliced_shape->IsArray()); + + TF_RET_CHECK(slice.offsets->size() == slice.orig_shape->rank()); + TF_RET_CHECK(slice.sliced_shape->rank() == slice.orig_shape->rank()); + } + } + TF_RETURN_IF_ERROR(embedded_commands_->Prepare(params, resource_requests)); + return absl::OkStatus(); +} + +absl::Status DynamicSliceFusionCmd::Record( + const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, se::CommandBuffer* command_buffer) { + se::Stream& stream = *execute_params.stream; + ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); + + const BufferAllocations& orig_allocations = + *execute_params.buffer_allocations; + absl::InlinedVector slice_buffers( + slices_.size(), se::DeviceMemoryBase()); + + // Get memory allocation for copying offsets from device. + int64_t* offsets_alloc = [&] { + absl::MutexLock lock(&mutex_); + return reinterpret_cast( + offsets_allocs_.at(stream.parent())->opaque()); + }(); + + auto offset_value = [&](int64_t arg_idx, int64_t offset_idx) -> int64_t& { + return offsets_alloc[offsets_allocs_base_.at(arg_idx) + offset_idx]; + }; + + VLOG(2) << "Execute address computation thunk: slices=" << slices_.size(); + for (auto [argument_idx, slice] : llvm::enumerate(slices_)) { + // Skip arguments that do not have buffer slices (tokens). + if (!slice.embedded_thunk_argument.has_value()) { + continue; + } + + // `argument_buffer` will contain the original offset for slice + // `argument_slice` within `orig_allocations` + se::DeviceMemoryBase argument_buffer = + orig_allocations.GetDeviceAddress(*slice.embedded_thunk_argument); + + // If argument is not sliced, just use the original buffer. + if (!slice.offsets.has_value()) { + slice_buffers[argument_idx] = argument_buffer; + continue; + } + + const Shape& src_shape = *slice.orig_shape; + const Shape& dst_shape = *slice.sliced_shape; + + absl::InlinedVector slice_starts; + slice_starts.reserve(dst_shape.rank()); + + // Number of issues d2h transfers to copy offset values from device to + // host. + int64_t num_transfers = 0; + + // Get offset for `argument_idx`-th argument, which has `dst_shape.rank()` + // components. + for (auto [offset_idx, values] : llvm::enumerate(llvm::zip( + *slice.offsets, src_shape.dimensions(), dst_shape.dimensions()))) { + auto [offset, src_dim, dst_dim] = values; + if (uint64_t* const_offset = std::get_if(&offset)) { + // Forward slice offsets that are known constant values + VLOG(2) << " - arg " << argument_idx << "[" << offset_idx + << "]: constant offset = " << *const_offset; + offset_value(argument_idx, offset_idx) = *const_offset; + + } else if (std::holds_alternative(offset)) { + // Get slice offset from the current loop iteration. + TF_ASSIGN_OR_RETURN(int64_t iter, WhileThunk::CurrentLoopIteration()); + VLOG(2) << " - arg " << argument_idx << "[" << offset_idx + << "]: loop iteration offset = " << iter; + offset_value(argument_idx, offset_idx) = iter; + + } else if (DynamicSliceThunk::OffsetArray* offset_array = + std::get_if(&offset)) { + TF_ASSIGN_OR_RETURN(int64_t iter, WhileThunk::CurrentLoopIteration()); + VLOG(2) << " - arg " << argument_idx << "[" << offset_idx + << "]: offset array offset = " << offset_array->values[iter]; + offset_value(argument_idx, offset_idx) = offset_array->values[iter]; + + } else { + // Transfer slice offset value from device to host. + auto alloc_slice = std::get(offset); + VLOG(2) << " - arg " << argument_idx << "[" << offset_idx + << "]: transfer offset from device " << alloc_slice.ToString(); + + se::DeviceMemoryBase offset_src = + orig_allocations.GetDeviceAddress(alloc_slice); + int64_t* offset_dst = &offset_value(argument_idx, offset_idx); + + // Copy the `offset_idx`-th component of the offset for the + // `argument_idx`-th argument from device to host. + TF_RETURN_IF_ERROR( + stream.Memcpy(offset_dst, offset_src, *slice.offset_byte_size)); + ++num_transfers; + } + } + + // Wait for the completion of all transfers. + if (num_transfers > 0) { + VLOG(2) << "Wait for completion of " << num_transfers << " transfer"; + TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); + } + + // Clamp start indices: + // start_indices[i] = min(max(start_indices[i], 0), + // operand.dimension_size[i] - size_indices[i]) + for (auto [offset_idx, values] : llvm::enumerate( + llvm::zip(src_shape.dimensions(), dst_shape.dimensions()))) { + auto [src_dim, dst_dim] = values; + int64_t start_index = + std::min(std::max(offset_value(argument_idx, offset_idx), int64_t{0}), + src_dim - dst_dim); + VLOG(2) << "arg idx: " << argument_idx << " offset_idx " << offset_idx + << " with offset_value " << offset_value(argument_idx, offset_idx) + << " start_idx: " << start_index << " src_dim: " << src_dim + << " dst_dim:" << dst_dim; + slice_starts.push_back(start_index); + } + + // Compute new slice. No need to copy the content to new buffers as we can + // reuse the original buffers since slices are contiguous. + int64_t new_size = ShapeUtil::ByteSizeOf(dst_shape); + + int64_t new_offset = 0; + for (auto [start, stride] : + llvm::zip(slice_starts, *ShapeUtil::ByteStrides(src_shape))) { + new_offset += start * stride; + } + + VLOG(2) << "Create sliced argument " << argument_idx << " of shape " + << slice.sliced_shape->ToString() + << " by slicing argument of shape " << slice.orig_shape->ToString() + << " at offset " << new_offset << " with " << new_size; + slice_buffers[argument_idx] = + argument_buffer.GetByteSlice(new_offset, new_size); + } + + // Safe to create a local BufferAllocations here since buffers are only slices + // of bigger ones allocated elsewhere. + BufferAllocations slice_allocations(slice_buffers, + orig_allocations.device_ordinal(), + orig_allocations.memory_allocator()); + + Thunk::ExecuteParams new_params = + Thunk::ExecuteParams::CloneWithNewAllocations(execute_params, + slice_allocations); + auto nested_command_buffer = + execute_params.stream->parent() + ->CreateCommandBuffer(se::CommandBuffer::Mode::kNested) + .value(); + TF_RETURN_IF_ERROR(embedded_commands_->Record(new_params, record_params, + nested_command_buffer.get())); + return command_buffer->AddNestedCommandBuffer(execution_scope_id, + *nested_command_buffer); +} + +CommandBufferCmd::BufferUsageVector DynamicSliceFusionCmd::buffers() { + CommandBufferCmd::BufferUsageVector buffers; + auto embed_buffers = embedded_commands_->buffers(); + for (auto buffer_usage : embed_buffers) { + CHECK(embeded_to_origin_slice_map_[buffer_usage.slice.index()].has_value()); + buffers.emplace_back( + embeded_to_origin_slice_map_[buffer_usage.slice.index()].value(), + buffer_usage.access); + } + return buffers; +} + } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h index 27e8fea0d86366..d760f457b30712 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h @@ -26,12 +26,14 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" @@ -44,13 +46,18 @@ limitations under the License. #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/runtime/custom_call_thunk.h" +#include "xla/service/gpu/runtime/dynamic_slice_thunk.h" #include "xla/service/gpu/runtime/nccl_api.h" #include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/gpu/runtime/nccl_collective_thunk.h" #include "xla/service/gpu/runtime/thunk.h" +#include "xla/shape.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/gpu/gpu_blas_lt.h" #include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" @@ -78,8 +85,10 @@ namespace xla::gpu { V(kCollectiveCmd, "CollectiveCmd") \ V(kAllReduceCmd, "AllReduceCmd") \ V(kReduceScatter, "ReduceScatterCmd") \ + V(kAllToAll, "AllToAllCmd") \ V(kAllGatherCmd, "AllGatherCmd") \ V(kCollectiveBroadcastCmd, "CollectiveBroadcastCmd") \ + V(kDynamicSliceFusionCmd, "DynamicSliceFusionCmd") \ V(kUnknownCmd, "UnknownCmd") \ // clang-format on @@ -888,25 +897,28 @@ class CustomCallCmd : public CommandBufferCmd { // has different meaning in different translation units. We need to get rid of // GOOGLE_CUDA defines all over XLA to fix this! As a workaround just keep // constructor in a header file. - CustomCallCmd(ExecutionStreamId execution_stream_id, + CustomCallCmd(ExecutionStreamId execution_stream_id, std::string target_name, CustomCallTarget call_target, std::vector> operands, std::vector> results, absl::string_view opaque) : CommandBufferCmd(CommandBufferCmdType::kCustomCallCmd, execution_stream_id), + target_name_(std::move(target_name)), call_target_(std::move(call_target)), opaque_(opaque), operands_(std::move(operands)), results_(std::move(results)) {} - CustomCallCmd(ExecutionStreamId execution_stream_id, XLA_FFI_Handler* handler, + CustomCallCmd(ExecutionStreamId execution_stream_id, std::string target_name, + XLA_FFI_Handler* handler, std::vector> operands, std::vector> results, AttributesMap attributes, const HloComputation* called_computation) : CommandBufferCmd(CommandBufferCmdType::kCustomCallCmd, execution_stream_id), + target_name_(std::move(target_name)), handler_(handler), attributes_(std::move(attributes)), called_computation_(called_computation), @@ -928,6 +940,8 @@ class CustomCallCmd : public CommandBufferCmd { const RecordParams& record_params, se::CommandBuffer* command_buffer); + std::string target_name_; + // This is a legacy custom call API that is discouraged, and will be // deprecated once XLA:FFI mechanism is ready. CustomCallTarget call_target_; @@ -1073,6 +1087,32 @@ class ReduceScatterCmd : public CollectiveCmd { std::vector buffers_; }; +//===----------------------------------------------------------------------===// +// AllToAllCmd +//===----------------------------------------------------------------------===// + +class AllToAllCmd : public CollectiveCmd { + public: + AllToAllCmd(ExecutionStreamId execution_stream_id, + ExecutionStreamId async_from_stream_id, NcclApi* nccl_api, + NcclCollectiveConfig config, bool has_split_dimension, + absl::Span buffers); + + absl::Status Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; + + BufferUsageVector buffers() override; + + AsyncStreamKind GetAsyncStreamKind() override { + return AsyncStreamKind::kCollective; + }; + + private: + bool has_split_dimension_; + std::vector buffers_; +}; + //===----------------------------------------------------------------------===// // AllGatherCmd //===----------------------------------------------------------------------===// @@ -1119,6 +1159,62 @@ class CollectiveBroadcastCmd : public CollectiveCmd { std::vector buffers_; }; +//===----------------------------------------------------------------------===// +// DynamicSliceFusionCmd +//===----------------------------------------------------------------------===// + +class DynamicSliceFusionCmd : public CommandBufferCmd { + public: + DynamicSliceFusionCmd( + ExecutionStreamId execution_stream_id, + std::unique_ptr embedded_commands, + std::vector> arguments, + std::vector> fake_allocations_, + std::vector>> + offsets, + std::vector> orig_shapes, + std::vector> sliced_shapes, + std::vector> offset_byte_sizes); + + absl::Status Initialize(const Thunk::InitializeParams& params, + StateManager& state); + + absl::Status Prepare(const Thunk::PrepareParams& params, + Thunk::ResourceRequests& resource_requests) final; + + absl::Status Record(const Thunk::ExecuteParams& execute_params, + const RecordParams& record_params, + se::CommandBuffer* command_buffer) override; + + BufferUsageVector buffers() override; + + bool force_update() override; + + bool IsNestedCommandBuffer() const final { return true; } + + private: + std::unique_ptr embedded_commands_; + std::vector slices_; + std::vector> fake_allocations_; + + // Pinned host memory for transferring offset values from device to host. + absl::Mutex mutex_; + absl::flat_hash_map> + offsets_allocs_ ABSL_GUARDED_BY(mutex_); + + // Pre-computed size requirement for `offsets_allocs_`. + int64_t offsets_allocs_size_ = 0; + + // A mapping from argument index to the base offset in the `offsets_allocs_`. + std::vector offsets_allocs_base_; + + // mapping from original allocation index to allocation index of embedded + // command sequences. + absl::flat_hash_map> + embeded_to_origin_slice_map_; +}; + } // namespace xla::gpu #endif // XLA_SERVICE_GPU_RUNTIME_COMMAND_BUFFER_CMD_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc index cb3801ef13c1da..8daec80710a35f 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc @@ -23,7 +23,6 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "xla/service/buffer_assignment.h" #include "xla/service/gpu/runtime/command_buffer_cmd.h" #include "xla/service/gpu/runtime/conditional_thunk.h" #include "xla/service/gpu/runtime/copy_thunk.h" @@ -35,6 +34,7 @@ limitations under the License. #include "xla/service/gpu/runtime/memset_thunk.h" #include "xla/service/gpu/runtime/nccl_all_gather_thunk.h" #include "xla/service/gpu/runtime/nccl_all_reduce_thunk.h" +#include "xla/service/gpu/runtime/nccl_all_to_all_thunk.h" #include "xla/service/gpu/runtime/nccl_collective_thunk.h" #include "xla/service/gpu/runtime/replica_id_thunk.h" #include "xla/service/gpu/runtime/sequential_thunk.h" @@ -173,6 +173,13 @@ static absl::StatusOr Convert( thunk.buffers()); } +static absl::StatusOr Convert(const NcclAllToAllStartThunk& thunk) { + return std::make_unique( + thunk.nccl_execution_stream_id(), thunk.execution_stream_id(), + thunk.nccl_api(), thunk.config(), thunk.has_split_dimension(), + thunk.buffers()); +} + static absl::StatusOr Convert(const NcclAllGatherStartThunk& thunk) { return std::make_unique( thunk.nccl_execution_stream_id(), thunk.execution_stream_id(), @@ -184,6 +191,26 @@ static absl::StatusOr Convert(const NcclCollectiveDoneThunk& thunk) { thunk.nccl_execution_stream_id()); } +static absl::StatusOr Convert(const DynamicSliceThunk& thunk) { + auto cmd_sequence = std::make_unique(); + auto embed_thunk = thunk.get_embeded_thunk(); + TF_RETURN_IF_ERROR(AppendCommands( + *cmd_sequence, embed_thunk->thunks(), + CommandBufferCmdSequence::SynchronizationMode::kAutomatic)); + + auto& thunk_fake_allocations = thunk.get_fake_allocations(); + std::vector> fake_allocations; + for (auto it = thunk_fake_allocations.begin(); + it != thunk_fake_allocations.end(); ++it) { + fake_allocations.push_back(std::make_unique(**it)); + } + return std::make_unique( + thunk.execution_stream_id(), std::move(cmd_sequence), + thunk.get_arguments(), std::move(fake_allocations), thunk.get_offsets(), + thunk.get_orig_shapes(), thunk.get_sliced_shapes(), + thunk.get_offset_byte_sizes()); +} + static absl::StatusOr Convert(const PartitionIdThunk& thunk) { return std::make_unique(thunk.execution_stream_id(), thunk.dest(), @@ -197,9 +224,16 @@ static absl::StatusOr Convert(const ReplicaIdThunk& thunk) { } static absl::StatusOr Convert(const CustomCallThunk& thunk) { - return std::make_unique(thunk.execution_stream_id(), - thunk.call_target(), thunk.operands(), - thunk.results(), thunk.opaque()); + if (auto bundle = thunk.bundle(); bundle.has_value()) { + return std::make_unique( + thunk.execution_stream_id(), thunk.target_name(), bundle->execute, + thunk.operands(), thunk.results(), thunk.attributes(), + /*called_computation=*/nullptr); // TODO(b/342285364) + } else { + return std::make_unique( + thunk.execution_stream_id(), thunk.target_name(), thunk.call_target(), + thunk.operands(), thunk.results(), thunk.opaque()); + } } static absl::StatusOr Convert(const CuDnnThunk& thunk) { @@ -270,6 +304,8 @@ static absl::Status AppendCommands( return append(Convert(thunk)); case Thunk::Kind::kNcclReduceScatterStart: return append(Convert(thunk)); + case Thunk::Kind::kNcclAllToAllStart: + return append(Convert(thunk)); case Thunk::Kind::kPartitionId: return append(Convert(thunk)); case Thunk::Kind::kReplicaId: @@ -289,8 +325,12 @@ static absl::Status AppendCommands( case Thunk::Kind::kNcclAllGatherDone: case Thunk::Kind::kNcclAllReduceDone: case Thunk::Kind::kNcclReduceScatterDone: + case Thunk::Kind::kNcclAllToAllDone: return append(Convert(thunk)); + case Thunk::Kind::kDynamicSlice: + return append(Convert(thunk)); + case Thunk::Kind::kWaitForStreams: return append(Convert(thunk)); diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc index bce9d1927d05ea..385e762fabe40f 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc @@ -51,6 +51,7 @@ limitations under the License. #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" +#include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" // IWYU pragma: keep #include "xla/xla_data.pb.h" @@ -707,6 +708,161 @@ TEST(CommandBufferThunkTest, GemmCmd) { ASSERT_EQ(dst, std::vector({10, 10, 10, 26, 26, 26})); } +TEST(CommandBufferThunkTest, DynamicSliceFusionCmd) { + if (!IsAtLeastCuda12300()) { + GTEST_SKIP() << "CUDA graph tracing is not supported"; + } + + se::StreamExecutor* executor = GpuExecutor(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t lhs_length = sizeof(float) * 4 * 4; + int64_t fake_lhs_length = sizeof(float) * 2 * 4; + int64_t rhs_length = sizeof(float) * 4 * 3; + int64_t out_length = sizeof(float) * 2 * 3; + + // Prepare arguments: + // lhs = [1.0, 2.0, 3.0, 4.0 + // 5.0, 6.0, 7.0, 8.0] + // rhs = [1.0, 1.0, 1.0 + // 1.0, 1.0, 1.0 + // 1.0, 1.0, 1.0 + // 1.0, 1.0, 1.0] + se::DeviceMemory lhs = executor->AllocateArray(4 * 4); + std::vector lhs_arr{0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8}; + TF_ASSERT_OK(stream->Memcpy(&lhs, lhs_arr.data(), lhs_length)); + + se::DeviceMemory rhs = executor->AllocateArray(4 * 3); + std::vector rhs_arr(12, 1); + TF_ASSERT_OK(stream->Memcpy(&rhs, rhs_arr.data(), rhs_length)); + + se::DeviceMemory out = executor->AllocateArray(2 * 3); + TF_ASSERT_OK(stream->MemZero(&out, out_length)); + + se::DeviceMemory workspace = + executor->AllocateArray(1024 * 1024); + TF_ASSERT_OK(stream->MemZero(&workspace, 1024 * 1024)); + + // Prepare buffer allocations for recording command buffer. + std::vector> fake_allocations(4); + fake_allocations[0] = std::make_unique( + /*index=*/0, fake_lhs_length, /*color=*/0); + fake_allocations[1] = + std::make_unique(/*index=*/1, rhs_length, /*color=*/0); + fake_allocations[2] = + std::make_unique(/*index=*/2, out_length, + /*color=*/0); + + fake_allocations[3] = + std::make_unique(/*index=*/3, 1024 * 1024, + /*color=*/0); + BufferAllocation::Slice fake_slice_lhs(fake_allocations[0].get(), 0, + fake_lhs_length); + BufferAllocation::Slice slice_rhs(fake_allocations[1].get(), 0, rhs_length); + BufferAllocation::Slice slice_out(fake_allocations[2].get(), 0, out_length); + BufferAllocation::Slice slice_workspace(fake_allocations[3].get(), 0, + 1024 * 1024); + auto config = + GemmConfig::For(ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), {}, {1}, + ShapeUtil::MakeShape(PrimitiveType::F32, {4, 3}), {}, {0}, + ShapeUtil::MakeShape(PrimitiveType::F32, {2, 3}), 1.0, + 0.0, 0.0, PrecisionConfig::ALG_UNSET, std::nullopt, + se::blas::kDefaultComputePrecision, false, false); + ASSERT_TRUE(config.ok()); + + // Prepare commands sequence for constructing command buffer. + std::unique_ptr embed_commands = + std::make_unique(); + embed_commands->Emplace(s0, config.value(), fake_slice_lhs, + slice_rhs, slice_out, slice_workspace, + /*deterministic=*/true); + + BufferAllocation alloc_lhs(/*index=*/0, lhs_length, /*color=*/0); + BufferAllocation::Slice slice_lhs(&alloc_lhs, 0, lhs_length); + + std::vector lhs_offsets = { + DynamicSliceThunk::Offset(2UL), DynamicSliceThunk::Offset(0UL)}; + + std::vector> arguments = { + std::optional(slice_lhs), + std::optional(slice_rhs), + std::optional(slice_out), + std::optional(slice_workspace)}; + + std::vector>> offsets = { + lhs_offsets, std::nullopt, std::nullopt, std::nullopt}; + + std::vector> orig_shapes = { + ShapeUtil::MakeShape(PrimitiveType::F32, {4, 4}), std::nullopt, + std::nullopt, std::nullopt}; + std::vector> sliced_shapes = { + ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), std::nullopt, + std::nullopt, std::nullopt}; + std::vector> offset_byte_sizes = { + sizeof(int64_t), std::nullopt, std::nullopt, std::nullopt}; + + CommandBufferCmdSequence commands; + commands.Emplace( + s0, std::move(embed_commands), arguments, std::move(fake_allocations), + offsets, orig_shapes, sliced_shapes, offset_byte_sizes); + + // Construct a thunk with command sequence. + CommandBufferThunk thunk(std::move(commands), Thunk::ThunkInfo()); + + ServiceExecutableRunOptions run_options; + se::StreamExecutorMemoryAllocator allocator(executor); + BufferAllocations allocations({lhs, rhs, out, workspace}, 0, &allocator); + + Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( + run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); + + Thunk::ExecutableSource source = {/*text=*/"", /*binary=*/{}}; + TF_ASSERT_OK(thunk.Initialize( + {executor, source, &allocations, stream.get(), stream.get()})); + + // Execute command buffer thunk and verify that it executed a GEMM. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `out` data back to host. + std::vector dst(6, 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), out, out_length)); + + ASSERT_EQ(dst, std::vector({10, 10, 10, 26, 26, 26})); + + // Prepare buffer allocation for updating command buffer. + se::DeviceMemory updated_out = executor->AllocateArray(2 * 3); + TF_ASSERT_OK(stream->MemZero(&updated_out, out_length)); + + // Update buffer allocation to updated `out` buffer. + allocations = + BufferAllocations({lhs, rhs, updated_out, workspace}, 0, &allocator); + + // Thunk execution should automatically update underlying command buffer. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `updated_out` data back to host. + std::fill(dst.begin(), dst.end(), 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), updated_out, out_length)); + + ASSERT_EQ(dst, std::vector({10, 10, 10, 26, 26, 26})); + + // Try to update the command buffer with the same buffers. + TF_ASSERT_OK(stream->MemZero(&updated_out, out_length)); + + // Thunk execution should automatically update underlying command buffer. + TF_ASSERT_OK(thunk.ExecuteOnStream(params)); + TF_ASSERT_OK(stream->BlockHostUntilDone()); + + // Copy `updated_out` data back to host. + std::fill(dst.begin(), dst.end(), 0); + TF_ASSERT_OK(stream->Memcpy(dst.data(), updated_out, out_length)); + + ASSERT_EQ(dst, std::vector({10, 10, 10, 26, 26, 26})); +} + TEST(CommandBufferThunkTest, CublasLtCmd) { if (!IsAtLeastCuda12300()) { GTEST_SKIP() << "CUDA graph tracing is not supported"; @@ -1313,4 +1469,96 @@ TEST(CommandBufferThunkTest, WhileCmd) { // maybe add a CustomLaunchCmd and wrap loop update into custom kernel. } +class CmdBufferTest : public HloTestBase { + public: + DebugOptions GetDebugOptionsForTest() const override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(true); + debug_options.set_xla_gpu_graph_min_graph_size(1); + debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION); + debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUBLAS); + debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUBLASLT); + debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUSTOM_CALL); + debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN); + debug_options.add_xla_gpu_enable_command_buffer( + DebugOptions::DYNAMIC_SLICE); + return debug_options; + } +}; + +TEST_F(CmdBufferTest, DynamicSliceFusionCmd) { + // Hlo generated by below jax code + // def scan_body(carry, x): + // sliced_x = lax.slice(x, (0, 0), (128, 128)) + // result = jnp.dot(carry, sliced_x) + // new_carry = result + // return new_carry, result + // @jax.jit + // def run_scan(initial_carry, xs): + // final_carry, outputs = lax.scan(scan_body, initial_carry, xs, length=2) + // return final_carry, outputs + + const char* module_str = R"( +HloModule jit_run_scan + +None.7 { + Arg_0.8 = f32[128,128]{1,0} parameter(0) + Arg_1.9 = f32[128,128]{1,0} parameter(1) + dot.10 = f32[128,128]{1,0} dot(Arg_0.8, Arg_1.9), lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT tuple.11 = (f32[128,128]{1,0}, f32[128,128]{1,0}) tuple(dot.10, dot.10) +} + +region_0.12 { + arg_tuple.13 = (s32[], f32[128,128]{1,0}, f32[2,128,128]{2,1,0}, f32[2,128,128]{2,1,0}) parameter(0) + get-tuple-element.14 = s32[] get-tuple-element(arg_tuple.13), index=0 + constant.18 = s32[] constant(1) + add.34 = s32[] add(get-tuple-element.14, constant.18) + get-tuple-element.15 = f32[128,128]{1,0} get-tuple-element(arg_tuple.13), index=1 + get-tuple-element.17 = f32[2,128,128]{2,1,0} get-tuple-element(arg_tuple.13), index=3 + constant.20 = s32[] constant(0) + compare.21 = pred[] compare(get-tuple-element.14, constant.20), direction=LT + constant.19 = s32[] constant(2) + add.22 = s32[] add(get-tuple-element.14, constant.19) + select.23 = s32[] select(compare.21, add.22, get-tuple-element.14) + dynamic-slice.24 = f32[1,128,128]{2,1,0} dynamic-slice(get-tuple-element.17, select.23, constant.20, constant.20), dynamic_slice_sizes={1,128,128} + reshape.25 = f32[128,128]{1,0} reshape(dynamic-slice.24) + call.26 = (f32[128,128]{1,0}, f32[128,128]{1,0}) call(get-tuple-element.15, reshape.25), to_apply=None.7 + get-tuple-element.27 = f32[128,128]{1,0} get-tuple-element(call.26), index=0 + get-tuple-element.16 = f32[2,128,128]{2,1,0} get-tuple-element(arg_tuple.13), index=2 + get-tuple-element.28 = f32[128,128]{1,0} get-tuple-element(call.26), index=1 + reshape.29 = f32[1,128,128]{2,1,0} reshape(get-tuple-element.28) + compare.30 = pred[] compare(get-tuple-element.14, constant.20), direction=LT + add.31 = s32[] add(get-tuple-element.14, constant.19) + select.32 = s32[] select(compare.30, add.31, get-tuple-element.14) + dynamic-update-slice.33 = f32[2,128,128]{2,1,0} dynamic-update-slice(get-tuple-element.16, reshape.29, select.32, constant.20, constant.20) + ROOT tuple.35 = (s32[], f32[128,128]{1,0}, f32[2,128,128]{2,1,0}, f32[2,128,128]{2,1,0}) tuple(add.34, get-tuple-element.27, dynamic-update-slice.33, get-tuple-element.17) +} // region_0.12 + +region_1.36 { + arg_tuple.37 = (s32[], f32[128,128]{1,0}, f32[2,128,128]{2,1,0}, f32[2,128,128]{2,1,0}) parameter(0) + get-tuple-element.39 = f32[128,128]{1,0} get-tuple-element(arg_tuple.37), index=1 + get-tuple-element.40 = f32[2,128,128]{2,1,0} get-tuple-element(arg_tuple.37), index=2 + get-tuple-element.41 = f32[2,128,128]{2,1,0} get-tuple-element(arg_tuple.37), index=3 + get-tuple-element.38 = s32[] get-tuple-element(arg_tuple.37), index=0 + constant.42 = s32[] constant(2) + ROOT compare.43 = pred[] compare(get-tuple-element.38, constant.42), direction=LT +} // region_1.36 + +ENTRY main.49 { + constant.3 = s32[] constant(0) + Arg_0.1 = f32[128,128]{1,0} parameter(0) + constant.4 = f32[] constant(0) + broadcast.5 = f32[2,128,128]{2,1,0} broadcast(constant.4), dimensions={} + Arg_1.2 = f32[2,128,128]{2,1,0} parameter(1) + tuple.6 = (s32[], f32[128,128]{1,0}, f32[2,128,128]{2,1,0}, f32[2,128,128]{2,1,0}) tuple(constant.3, Arg_0.1, broadcast.5, Arg_1.2) + while.44 = (s32[], f32[128,128]{1,0}, f32[2,128,128]{2,1,0}, f32[2,128,128]{2,1,0}) while(tuple.6), condition=region_1.36, body=region_0.12 + get-tuple-element.45 = s32[] get-tuple-element(while.44), index=0 + get-tuple-element.46 = f32[128,128]{1,0} get-tuple-element(while.44), index=1 + get-tuple-element.47 = f32[2,128,128]{2,1,0} get-tuple-element(while.44), index=2 + ROOT tuple.48 = (f32[128,128]{1,0}, f32[2,128,128]{2,1,0}) tuple(get-tuple-element.46, get-tuple-element.47) +} +)"; + EXPECT_TRUE(RunAndCompare(module_str, ErrorSpec{1e-3, 1e-3})); +} + } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc index 7cf44c109cb267..c12e79d60e613c 100644 --- a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc @@ -55,17 +55,17 @@ using xla::ffi::CallFrameBuilder; using xla::ffi::CallOptions; absl::StatusOr> CustomCallThunk::Create( - ThunkInfo thunk_info, CustomCallTarget call_target, + ThunkInfo thunk_info, std::string target_name, CustomCallTarget call_target, std::vector> operands, std::vector> results, const std::string& opaque) { - return absl::WrapUnique( - new CustomCallThunk(thunk_info, std::move(call_target), - std::move(operands), std::move(results), opaque)); + return absl::WrapUnique(new CustomCallThunk( + thunk_info, std::move(target_name), std::move(call_target), + std::move(operands), std::move(results), opaque)); } absl::StatusOr> CustomCallThunk::Create( - ThunkInfo thunk_info, XLA_FFI_Handler_Bundle bundle, - std::vector> operands, + ThunkInfo thunk_info, std::string target_name, + XLA_FFI_Handler_Bundle bundle, std::vector> operands, std::vector> results, AttributesMap attributes, const HloComputation* called_computation) { auto execution_state = std::make_unique(); @@ -89,28 +89,31 @@ absl::StatusOr> CustomCallThunk::Create( } return absl::WrapUnique(new CustomCallThunk( - thunk_info, bundle, std::move(operands), std::move(results), - std::move(attributes), std::move(execution_state), called_computation)); + thunk_info, std::move(target_name), bundle, std::move(operands), + std::move(results), std::move(attributes), std::move(execution_state), + called_computation)); } -CustomCallThunk::CustomCallThunk(ThunkInfo thunk_info, +CustomCallThunk::CustomCallThunk(ThunkInfo thunk_info, std::string target_name, CustomCallTarget call_target, std::vector> operands, std::vector> results, const std::string& opaque) : Thunk(Thunk::kCustomCall, thunk_info), + target_name_(std::move(target_name)), operands_(std::move(operands)), results_(std::move(results)), call_target_(std::move(call_target)), opaque_(opaque) {} CustomCallThunk::CustomCallThunk( - ThunkInfo thunk_info, XLA_FFI_Handler_Bundle bundle, - std::vector> operands, + ThunkInfo thunk_info, std::string target_name, + XLA_FFI_Handler_Bundle bundle, std::vector> operands, std::vector> results, AttributesMap attributes, std::unique_ptr execution_state, const HloComputation* called_computation) : Thunk(Thunk::kCustomCall, thunk_info), + target_name_(std::move(target_name)), operands_(std::move(operands)), results_(std::move(results)), bundle_(bundle), @@ -156,19 +159,27 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) { } absl::Status CustomCallThunk::ExecuteFfiHandler( - XLA_FFI_Handler* handler, XLA_FFI_ExecutionStage stage, - int32_t device_ordinal, se::Stream* stream, - se::DeviceMemoryAllocator* allocator, + XLA_FFI_Handler* handler, XLA_FFI_ExecutionStage stage, se::Stream* stream, const ffi::ExecutionContext* execution_context, const BufferAllocations* buffer_allocations) { if (handler == nullptr) { return absl::InternalError("FFI execute handler is not set"); } + if (stage != XLA_FFI_ExecutionStage_PREPARE && + !(buffer_allocations && stream)) { + return absl::InternalError("buffer allocations and stream are required"); + } // TODO(ezhulenev): This is not the most optimal approach, as we'll be doing // a lot of extra allocation on every call. We have to keep attributes // separate from arguments, as they do not change after thunk is constructed. CallFrameBuilder builder(operands_.size(), results_.size()); + auto device_address = + [buffer_allocations]( + BufferAllocation::Slice slice) -> se::DeviceMemoryBase { + return buffer_allocations ? buffer_allocations->GetDeviceAddress(slice) + : se::DeviceMemoryBase{}; + }; for (auto& operand : operands_) { if (!operand.has_value()) { @@ -179,7 +190,7 @@ absl::Status CustomCallThunk::ExecuteFfiHandler( if (!operand->slice.allocation()) return Internal("custom call argument missing buffer allocation"); - builder.AddBufferArg(buffer_allocations->GetDeviceAddress(operand->slice), + builder.AddBufferArg(device_address(operand->slice), operand->shape.element_type(), operand->shape.dimensions()); } @@ -193,7 +204,7 @@ absl::Status CustomCallThunk::ExecuteFfiHandler( if (!result->slice.allocation()) return Internal("custom call result missing buffer allocation"); - builder.AddBufferRet(buffer_allocations->GetDeviceAddress(result->slice), + builder.AddBufferRet(device_address(result->slice), result->shape.element_type(), result->shape.dimensions()); } @@ -204,6 +215,13 @@ absl::Status CustomCallThunk::ExecuteFfiHandler( builder.AddAttributes(attrs.Build()); CallFrame call_frame = builder.Build(); + int32_t device_ordinal = -1; + se::DeviceMemoryAllocator* allocator = nullptr; + if (stage != XLA_FFI_ExecutionStage_PREPARE) { + device_ordinal = buffer_allocations->device_ordinal(); + allocator = buffer_allocations->memory_allocator(); + } + CallOptions options = { device_ordinal, CallOptions::GpuOptions{stream, allocator}, called_computation_, execution_context, execution_state_.get()}; @@ -212,10 +230,14 @@ absl::Status CustomCallThunk::ExecuteFfiHandler( absl::Status CustomCallThunk::Prepare(const PrepareParams& params, ResourceRequests& resource_requests) { - if (bundle_ && bundle_->prepare) { - return absl::InternalError("FFI prepare stage is not yet supported"); + if (!bundle_ || !bundle_->prepare) { + return absl::OkStatus(); } - return absl::OkStatus(); + + return ExecuteFfiHandler(bundle_->prepare, XLA_FFI_ExecutionStage_PREPARE, + /*stream=*/nullptr, + /*execution_context=*/nullptr, + /*buffer_allocations=*/nullptr); } absl::Status CustomCallThunk::Initialize(const InitializeParams& params) { @@ -224,19 +246,15 @@ absl::Status CustomCallThunk::Initialize(const InitializeParams& params) { } return ExecuteFfiHandler( - bundle_->initialize, XLA_FFI_ExecutionStage_INITIALIZE, - params.buffer_allocations->device_ordinal(), params.stream, - params.buffer_allocations->memory_allocator(), + bundle_->initialize, XLA_FFI_ExecutionStage_INITIALIZE, params.stream, params.ffi_execution_context, params.buffer_allocations); } absl::Status CustomCallThunk::ExecuteOnStream(const ExecuteParams& params) { if (bundle_.has_value()) { - return ExecuteFfiHandler( - bundle_->execute, XLA_FFI_ExecutionStage_EXECUTE, - params.buffer_allocations->device_ordinal(), params.stream, - params.buffer_allocations->memory_allocator(), - params.ffi_execution_context, params.buffer_allocations); + return ExecuteFfiHandler(bundle_->execute, XLA_FFI_ExecutionStage_EXECUTE, + params.stream, params.ffi_execution_context, + params.buffer_allocations); } return ExecuteCustomCall(params); } diff --git a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h index c65676381f9c8a..097bdff9400a77 100644 --- a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.h @@ -17,7 +17,6 @@ limitations under the License. #define XLA_SERVICE_GPU_RUNTIME_CUSTOM_CALL_THUNK_H_ #include -#include #include #include #include @@ -77,17 +76,17 @@ class CustomCallThunk : public Thunk { Shape shape; }; - using Attribute = ffi::CallFrameBuilder::FlatAttribute; - using AttributesMap = ffi::CallFrameBuilder::FlatAttributesMap; + using Attribute = ffi::CallFrameBuilder::Attribute; + using AttributesMap = ffi::CallFrameBuilder::AttributesMap; static absl::StatusOr> Create( - ThunkInfo thunk_info, CustomCallTarget call_target, - std::vector> operands, + ThunkInfo thunk_info, std::string target_name, + CustomCallTarget call_target, std::vector> operands, std::vector> results, const std::string& opaque); static absl::StatusOr> Create( - ThunkInfo thunk_info, XLA_FFI_Handler_Bundle bundle, - std::vector> operands, + ThunkInfo thunk_info, std::string target_name, + XLA_FFI_Handler_Bundle bundle, std::vector> operands, std::vector> results, AttributesMap attributes, const HloComputation* called_computation); @@ -96,20 +95,27 @@ class CustomCallThunk : public Thunk { absl::Status Initialize(const InitializeParams& params) override; absl::Status ExecuteOnStream(const ExecuteParams& params) override; - const CustomCallTarget& call_target() const { return call_target_; } + const std::string& target_name() const { return target_name_; } + CustomCallTarget call_target() const { return call_target_; } + std::optional bundle() const { return bundle_; } + const AttributesMap& attributes() const { return attributes_; } + const std::vector>& operands() const { return operands_; } const std::vector>& results() const { return results_; } + absl::string_view opaque() const { return opaque_; } private: - CustomCallThunk(ThunkInfo thunk_info, CustomCallTarget call_target, + CustomCallThunk(ThunkInfo thunk_info, std::string target_name, + CustomCallTarget call_target, std::vector> operands, std::vector> results, const std::string& opaque); - CustomCallThunk(ThunkInfo thunk_info, XLA_FFI_Handler_Bundle bundle, + CustomCallThunk(ThunkInfo thunk_info, std::string target_name, + XLA_FFI_Handler_Bundle bundle, std::vector> operands, std::vector> results, AttributesMap attributes, @@ -120,11 +126,12 @@ class CustomCallThunk : public Thunk { absl::Status ExecuteFfiHandler(XLA_FFI_Handler* handler, XLA_FFI_ExecutionStage stage, - int32_t device_ordinal, se::Stream* stream, - se::DeviceMemoryAllocator* allocator, + se::Stream* stream, const ffi::ExecutionContext* execution_context, const BufferAllocations* buffer_allocations); + std::string target_name_; + std::vector> operands_; std::vector> results_; diff --git a/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.cc b/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.cc index f8670a5d78e32a..94e76d38addaa1 100644 --- a/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.cc @@ -57,7 +57,12 @@ DynamicSliceThunk::DynamicSliceThunk( : Thunk(Kind::kDynamicSlice, thunk_info), embedded_thunk_(std::make_unique( ThunkInfo(), std::move(*embedded_thunk))), - fake_allocations_(std::move(fake_allocations)) { + arguments_(arguments), + fake_allocations_(std::move(fake_allocations)), + offsets_(offsets), + orig_shapes_(orig_shapes), + sliced_shapes_(sliced_shapes), + offset_byte_sizes_(offset_byte_sizes) { // Zip all arguments together to create a list of SliceDef. for (auto [arg, offsets, orig_shape, sliced_shape, offset_byte_size] : llvm::zip_equal(arguments, offsets, orig_shapes, sliced_shapes, @@ -251,6 +256,10 @@ absl::Status DynamicSliceThunk::ExecuteOnStream(const ExecuteParams& params) { int64_t start_index = std::min(std::max(offset_value(argument_idx, offset_idx), int64_t{0}), src_dim - dst_dim); + VLOG(2) << "arg idx: " << argument_idx << " offset_idx " << offset_idx + << " with offset_value " << offset_value(argument_idx, offset_idx) + << " start_idx: " << start_index << " src_dim: " << src_dim + << " dst_dim:" << dst_dim; slice_starts.push_back(start_index); } diff --git a/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.h b/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.h index aac0bedad9d04b..5882b6b1e93ee9 100644 --- a/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk.h @@ -67,7 +67,7 @@ class DynamicSliceThunk : public Thunk { DynamicSliceThunk( ThunkInfo thunk_info, std::unique_ptr embedded_thunk, std::vector> arguments, - std::vector> fake_allocations_, + std::vector> fake_allocations, std::vector>> offsets, std::vector> orig_shapes, std::vector> sliced_shapes, @@ -83,10 +83,6 @@ class DynamicSliceThunk : public Thunk { absl::Status Initialize(const InitializeParams& params) override; absl::Status ExecuteOnStream(const ExecuteParams& params) override; - private: - std::unique_ptr embedded_thunk_; - std::vector> fake_allocations_; - // Definition of a dynamic slice that extract a slice from the original buffer // defined by `embedded_thunk_argument` at given `offsets`. struct SliceDef { @@ -97,6 +93,44 @@ class DynamicSliceThunk : public Thunk { std::optional offset_byte_size; }; + const SequentialThunk* get_embeded_thunk() const { + return embedded_thunk_.get(); + } + + std::vector> get_arguments() const { + return arguments_; + } + + const std::vector>& get_fake_allocations() + const { + return fake_allocations_; + } + + std::vector>> get_offsets() const { + return offsets_; + } + + std::vector> get_orig_shapes() const { + return orig_shapes_; + } + + std::vector> get_sliced_shapes() const { + return sliced_shapes_; + } + + std::vector> get_offset_byte_sizes() const { + return offset_byte_sizes_; + } + + private: + std::unique_ptr embedded_thunk_; + std::vector> arguments_; + std::vector> fake_allocations_; + std::vector>> offsets_; + std::vector> orig_shapes_; + std::vector> sliced_shapes_; + std::vector> offset_byte_sizes_; + std::vector slices_; // Pinned host memory for transferring offset values from device to host. diff --git a/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk_test.cc index 75c700b75b4e34..5a657b0574fc33 100644 --- a/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/ascii.h" #include "xla/ffi/ffi.h" @@ -461,8 +462,8 @@ TEST(DynamicSliceThunkTest, SlicedMemcpy) { ThunkSequence seq; TF_ASSERT_OK_AND_ASSIGN( seq.emplace_back(), - CustomCallThunk::Create(Thunk::ThunkInfo(), registration->bundle, - operands, results, + CustomCallThunk::Create(Thunk::ThunkInfo(), "__xla_test$$memcpy", + registration->bundle, operands, results, /*attributes=*/CustomCallThunk::AttributesMap(), /*called_computation=*/nullptr)); @@ -621,8 +622,8 @@ TEST(DynamicSliceThunkTest, SlicedOutputMemcpy) { ThunkSequence seq; TF_ASSERT_OK_AND_ASSIGN( seq.emplace_back(), - CustomCallThunk::Create(Thunk::ThunkInfo(), registration->bundle, - operands, results, + CustomCallThunk::Create(Thunk::ThunkInfo(), "__xla_test$$memcpy", + registration->bundle, operands, results, /*attributes=*/CustomCallThunk::AttributesMap(), /*called_computation=*/nullptr)); @@ -1262,8 +1263,8 @@ TEST(DynamicSliceThunkTest, SlicedMemcpyOOB) { ThunkSequence seq; TF_ASSERT_OK_AND_ASSIGN( seq.emplace_back(), - CustomCallThunk::Create(Thunk::ThunkInfo(), registration->bundle, - operands, results, + CustomCallThunk::Create(Thunk::ThunkInfo(), "__xla_test$$memcpy", + registration->bundle, operands, results, /*attributes=*/CustomCallThunk::AttributesMap(), /*called_computation=*/nullptr)); diff --git a/third_party/xla/xla/service/gpu/runtime/fft_thunk.cc b/third_party/xla/xla/service/gpu/runtime/fft_thunk.cc index 7d620522146acf..a493a20031005e 100644 --- a/third_party/xla/xla/service/gpu/runtime/fft_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/fft_thunk.cc @@ -175,13 +175,13 @@ absl::Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape, batch_size, &scratch_allocator); TF_RET_CHECK(fft_plan != nullptr) << "Failed to create cuFFT batched plan with scratch allocator"; - fft_plan_ptr->scale_factor = 1.0f / output_distance; + fft_plan_ptr->scale_factor = output_distance; } else { fft->UpdatePlanWithScratchAllocator(stream, fft_plan.get(), &scratch_allocator); } - float scale_factor = fft_plan_ptr->scale_factor; + uint64_t scale_factor = fft_plan_ptr->scale_factor; bool launch_ok; switch (fft_type) { @@ -205,7 +205,7 @@ absl::Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape, TF_ASSIGN_OR_RETURN(auto blas, GetBlas(stream)); launch_ok = blas->DoBlasScal(stream, ShapeUtil::ElementsIn(output_shape), - complex64(scale_factor), &output_data, 1); + complex64(1.0f / scale_factor), &output_data, 1); } break; } @@ -217,7 +217,7 @@ absl::Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape, TF_ASSIGN_OR_RETURN(auto blas, GetBlas(stream)); launch_ok = blas->DoBlasScal(stream, ShapeUtil::ElementsIn(output_shape), - complex128(scale_factor), &output_data, 1); + complex128(1.0 / scale_factor), &output_data, 1); } break; } @@ -241,7 +241,7 @@ absl::Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape, TF_ASSIGN_OR_RETURN(auto blas, GetBlas(stream)); launch_ok = blas->DoBlasScal(stream, ShapeUtil::ElementsIn(output_shape), - scale_factor, &output_data, 1); + 1.0f / scale_factor, &output_data, 1); } break; } @@ -253,7 +253,7 @@ absl::Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape, TF_ASSIGN_OR_RETURN(auto blas, GetBlas(stream)); launch_ok = blas->DoBlasScal(stream, ShapeUtil::ElementsIn(output_shape), - scale_factor, &output_data, 1); + 1.0 / scale_factor, &output_data, 1); } break; } @@ -264,7 +264,7 @@ absl::Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape, return absl::OkStatus(); } return Internal("Unable to launch fft with type %s", - FftTypeToString(fft_type)); + FftTypeToString(fft_type)); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/runtime/fft_thunk.h b/third_party/xla/xla/service/gpu/runtime/fft_thunk.h index ffd45ed804fda9..eedb75fb80fe6d 100644 --- a/third_party/xla/xla/service/gpu/runtime/fft_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/fft_thunk.h @@ -42,7 +42,7 @@ struct FftPlan { // protect each plan with a mutex. absl::Mutex mu; std::unique_ptr plan ABSL_GUARDED_BY(mu); - float scale_factor ABSL_GUARDED_BY(mu); + uint64_t scale_factor ABSL_GUARDED_BY(mu); }; class FftPlanCache { diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_all_gather_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_all_gather_thunk.cc index 49aae84589b97a..2ef818d62d6056 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_all_gather_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_all_gather_thunk.cc @@ -66,7 +66,8 @@ absl::Status CheckImplementableInst(const HloAllGatherInstruction* inst) { NcclAllGatherStartThunk::NcclAllGatherStartThunk( ThunkInfo thunk_info, NcclApi* nccl_api, - const HloAllGatherInstruction* inst, std::vector buffers) + const HloAllGatherInstruction* inst, std::vector buffers, + bool p2p_memcpy_enabled) : NcclCollectiveThunk(Thunk::kNcclAllGatherStart, thunk_info, nccl_api, IsSyncCollective(inst)), config_(impl::GetNcclAllGatherConfig(inst)), @@ -103,7 +104,7 @@ absl::Status RunAllGather(NcclApi* nccl_api, int device_ordinal = stream.parent()->device_ordinal(); VLOG(3) << "Performing all-gather from device ordinal: " << device_ordinal; TF_RETURN_IF_ERROR( - MaybeRegisterBuffers(nccl_api, device_ordinal, buffers, comm)); + MaybeRegisterBuffers(nccl_api, stream.parent(), buffers, comm)); TF_RETURN_IF_ERROR(nccl_api->GroupStart()); diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_all_gather_thunk.h b/third_party/xla/xla/service/gpu/runtime/nccl_all_gather_thunk.h index 4b5d63e6639e3e..aba61ffb6c6acb 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_all_gather_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_all_gather_thunk.h @@ -39,7 +39,8 @@ class NcclAllGatherStartThunk : public NcclCollectiveThunk { public: NcclAllGatherStartThunk(ThunkInfo thunk_info, NcclApi* nccl_api, const HloAllGatherInstruction* inst, - std::vector buffers); + std::vector buffers, + bool p2p_memcpy_enabled = false); static const char* GetHloOpName() { return "all-gather-start"; } diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_all_reduce_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_all_reduce_thunk.cc index 532d8d03b6bb6a..087ce5006746ff 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_all_reduce_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_all_reduce_thunk.cc @@ -48,7 +48,7 @@ absl::Status RunAllReduce(NcclApi* nccl_api, ReductionKind reduction_kind, int device_ordinal = stream.parent()->device_ordinal(); VLOG(3) << "Performing all-reduce from device ordinal: " << device_ordinal; TF_RETURN_IF_ERROR( - MaybeRegisterBuffers(nccl_api, device_ordinal, buffers, comm)); + MaybeRegisterBuffers(nccl_api, stream.parent(), buffers, comm)); TF_RETURN_IF_ERROR(nccl_api->GroupStart()); for (DeviceBufferPair& buffer : buffers) { @@ -156,7 +156,8 @@ NcclAllReduceReduceScatterThunkBase::NcclAllReduceReduceScatterThunkBase( NcclAllReduceStartThunk::NcclAllReduceStartThunk( ThunkInfo thunk_info, NcclApi* nccl_api, - const HloAllReduceInstruction* inst, std::vector buffers) + const HloAllReduceInstruction* inst, std::vector buffers, + bool p2p_memcpy_enabled) : NcclAllReduceReduceScatterThunkBase( Thunk::kNcclAllReduceStart, thunk_info, nccl_api, impl::GetNcclAllReduceConfigInst(inst), std::move(buffers), @@ -189,7 +190,8 @@ absl::Status NcclAllReduceStartThunk::RunNcclCollective( NcclReduceScatterStartThunk::NcclReduceScatterStartThunk( ThunkInfo thunk_info, NcclApi* nccl_api, - const HloReduceScatterInstruction* inst, std::vector buffers) + const HloReduceScatterInstruction* inst, std::vector buffers, + bool p2p_memcpy_enabled) : NcclAllReduceReduceScatterThunkBase( Thunk::kNcclReduceScatterStart, thunk_info, nccl_api, impl::GetNcclAllReduceConfigInst(inst), std::move(buffers), @@ -230,7 +232,7 @@ absl::Status RunReduceScatter(NcclApi* nccl_api, ReductionKind reduction_kind, VLOG(3) << "Performing reduce-scatter from device ordinal: " << device_ordinal; TF_RETURN_IF_ERROR( - MaybeRegisterBuffers(nccl_api, device_ordinal, buffers, comm)); + MaybeRegisterBuffers(nccl_api, stream.parent(), buffers, comm)); TF_ASSIGN_OR_RETURN(int32_t num_participants, nccl_api->CommCount(comm)); diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_all_reduce_thunk.h b/third_party/xla/xla/service/gpu/runtime/nccl_all_reduce_thunk.h index 7d70edaf2dab56..f36727c5081a31 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_all_reduce_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_all_reduce_thunk.h @@ -63,7 +63,8 @@ class NcclAllReduceStartThunk : public NcclAllReduceReduceScatterThunkBase { public: NcclAllReduceStartThunk(ThunkInfo thunk_info, NcclApi* nccl_api, const HloAllReduceInstruction* inst, - std::vector buffers); + std::vector buffers, + bool p2p_memcpy_enabled = false); static const char* GetHloOpName() { return "all-reduce-start"; } @@ -87,7 +88,8 @@ class NcclReduceScatterStartThunk : public NcclAllReduceReduceScatterThunkBase { public: NcclReduceScatterStartThunk(ThunkInfo thunk_info, NcclApi* nccl_api, const HloReduceScatterInstruction* inst, - std::vector buffers); + std::vector buffers, + bool p2p_memcpy_enabled = false); static const char* GetHloOpName() { return "reduce-scatter-start"; } diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc index 90bbd448c03fc2..ba80b7eabc2fb4 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc @@ -21,15 +21,19 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_map.h" #include "absl/status/status.h" #include "absl/strings/substitute.h" -#include "mlir/IR/Value.h" +#include "absl/synchronization/mutex.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/gpu/runtime/nccl_collective_thunk.h" +#include "xla/service/gpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" @@ -41,7 +45,6 @@ limitations under the License. namespace xla { namespace gpu { - namespace { NcclAllToAllConfig GetNcclAllToAllConfig(const HloAllToAllInstruction* instr) { @@ -58,11 +61,12 @@ NcclAllToAllConfig GetNcclAllToAllConfig(const HloAllToAllInstruction* instr) { NcclAllToAllStartThunk::NcclAllToAllStartThunk( ThunkInfo thunk_info, NcclApi* nccl_api, const HloAllToAllInstruction* instr, - std::vector buffers) + std::vector buffers, bool p2p_memcpy_enabled) : NcclCollectiveThunk(Thunk::kNcclAllToAllStart, thunk_info, nccl_api, IsSyncCollective(instr)), config_(GetNcclAllToAllConfig(instr)), - buffers_(std::move(buffers)) { + buffers_(std::move(buffers)), + p2p_memcpy_enabled_(p2p_memcpy_enabled) { CHECK_EQ(config_.config.operand_count, buffers_.size()); } @@ -92,6 +96,77 @@ NcclAllToAllStartThunk::NcclAllToAllStartThunk( return GetNcclAllToAllConfig(instr).config.group_mode; } +absl::Status NcclAllToAllStartThunk::Initialize( + const InitializeParams& params) { + TF_RETURN_IF_ERROR(NcclCollectiveThunk::Initialize(params)); + device_count_ = params.local_device_count; + CHECK_GT(device_count_, 0); + VLOG(5) << "Local device count: " << device_count_; + + if (is_local() && p2p_memcpy_enabled_) { + const NcclStreamId stream_id = nccl_stream_id(); + AsyncStreamKind stream_kind = GetAsyncStreamKind(); + TF_ASSIGN_OR_RETURN( + NcclCommHandleWrapper comm_wrapper, + GetNcclComm(*params.collective_params, *params.collective_cliques, + config().replica_groups, config().group_mode, stream_id, + stream_kind)); + TF_ASSIGN_OR_RETURN(int32_t num_participants, + nccl_api()->CommCount(comm_wrapper.comm_handle)); + int local_id = params.stream->parent()->device_ordinal() % num_participants; + { + absl::MutexLock lock(&pointer_maps_mutex_); + if (!send_pointer_maps_.count(local_id)) { + for (int i = 0; i < num_participants; ++i) { + if (!params.stream->parent()->HostMemoryRegister( + &send_pointer_maps_[local_id][i], sizeof(void*))) { + VLOG(5) << "Registering host send pointer for memcpy failed."; + } + if (!params.stream->parent()->HostMemoryRegister( + &receive_pointer_maps_[local_id][i], sizeof(void*))) { + VLOG(5) << "Registering host recv pointer for memcpy failed."; + } + } + } + } + } + return absl::OkStatus(); +} + +absl::Status NcclAllToAllStartThunk::Cleanup(const CleanupParams& params) { + if (p2p_memcpy_enabled_) { + const NcclStreamId stream_id = nccl_stream_id(); + AsyncStreamKind stream_kind = GetAsyncStreamKind(); + TF_ASSIGN_OR_RETURN( + NcclCommHandleWrapper comm_wrapper, + GetNcclComm(*params.collective_params, *params.collective_cliques, + config().replica_groups, config().group_mode, stream_id, + stream_kind)); + TF_ASSIGN_OR_RETURN(int32_t num_participants, + nccl_api()->CommCount(comm_wrapper.comm_handle)); + + int local_id = params.executor->device_ordinal() % num_participants; + { + absl::MutexLock lock(&pointer_maps_mutex_); + if (send_pointer_maps_.count(local_id)) { + for (auto& [id, value] : send_pointer_maps_[local_id]) { + if (!params.executor->HostMemoryUnregister((void*)value)) { + VLOG(5) << "Unregistering host send pointer for memcpy failed."; + } + } + } + if (receive_pointer_maps_.count(local_id)) { + for (auto& [id, value] : receive_pointer_maps_[local_id]) { + if (!params.executor->HostMemoryUnregister((void*)value)) { + VLOG(5) << "Unregistering host recv pointer for memcpy failed."; + } + } + } + } + } + return absl::OkStatus(); +} + absl::Status NcclAllToAllStartThunk::RunNcclCollective( const ExecuteParams& params, se::Stream& stream, NcclCommHandleWrapper comm_wrapper) { @@ -99,18 +174,52 @@ absl::Status NcclAllToAllStartThunk::RunNcclCollective( std::vector device_buffers, ConvertToDeviceBuffers(params, buffers_, config_.config.operand_element_type)); + TF_ASSIGN_OR_RETURN(int32_t num_participants, + nccl_api()->CommCount(comm_wrapper.comm_handle)); + + if (is_local() && p2p_memcpy_enabled_) { + int local_id = stream.parent()->device_ordinal() % num_participants; + absl::flat_hash_map* send_pointer_map = nullptr; + absl::flat_hash_map* receive_pointer_map = nullptr; + { + absl::MutexLock lock(&pointer_maps_mutex_); + send_pointer_map = &send_pointer_maps_[local_id]; + receive_pointer_map = &receive_pointer_maps_[local_id]; + } + return xla::gpu::RunMemCpyAllToAll( + nccl_api(), config_.has_split_dimension, device_buffers, stream, + comm_wrapper.comm_handle, *send_pointer_map, *receive_pointer_map); + } return xla::gpu::RunAllToAll(nccl_api(), config_.has_split_dimension, device_buffers, stream, comm_wrapper.comm_handle); } +AsyncStreamKind NcclAllToAllStartThunk::GetAsyncStreamKind() const { + return (is_local() && p2p_memcpy_enabled_) ? AsyncStreamKind::kMemCpyP2P + : AsyncStreamKind::kCollective; +} + +bool NcclAllToAllStartThunk::is_local() const { + for (const auto& replica_group : config_.config.replica_groups) { + const int64_t node_id = replica_group.replica_ids().at(0) / device_count_; + if (!absl::c_all_of(replica_group.replica_ids(), + [this, node_id](const int64_t rank) { + return rank / device_count_ == node_id; + })) { + return false; + } + } + return true; +} + absl::Status RunAllToAll(NcclApi* nccl_api, bool has_split_dimension, std::vector& buffers, se::Stream& stream, NcclApi::NcclCommHandle comm) { int device_ordinal = stream.parent()->device_ordinal(); VLOG(3) << "Performing all-to-all from device ordinal: " << device_ordinal; TF_RETURN_IF_ERROR( - MaybeRegisterBuffers(nccl_api, device_ordinal, buffers, comm)); + MaybeRegisterBuffers(nccl_api, stream.parent(), buffers, comm)); TF_ASSIGN_OR_RETURN(int32_t num_participants, nccl_api->CommCount(comm)); @@ -163,5 +272,84 @@ absl::Status RunAllToAll(NcclApi* nccl_api, bool has_split_dimension, return nccl_api->GroupEnd(); } +absl::Status RunMemCpyAllToAll( + NcclApi* nccl_api, bool has_split_dimension, + std::vector& buffers, se::Stream& stream, + NcclApi::NcclCommHandle comm, + absl::flat_hash_map& send_pointer_map, + absl::flat_hash_map& receive_pointer_map) { + int device_ordinal = stream.parent()->device_ordinal(); + VLOG(3) << "Performing mem-copy-all-to-all from device ordinal: " + << device_ordinal; + TF_RETURN_IF_ERROR( + MaybeRegisterBuffers(nccl_api, stream.parent(), buffers, comm)); + + TF_ASSIGN_OR_RETURN(int32_t num_participants, nccl_api->CommCount(comm)); + + // AllToAll can operate in two modes. Either it specifies a split dimension, + // in which case inputs are split and outputs concatenated in that dimension + // (here, we only support dimension 0), or it takes a list of inputs + // and produces a tuple of outputs. + if (has_split_dimension) { + for (DeviceBufferPair& buffer : buffers) { + TF_RET_CHECK(buffer.element_count % num_participants == 0) + << "Buffer was not an exact multiple of the number of participants."; + + size_t chunk_elements = buffer.element_count / num_participants; + + TF_RETURN_IF_ERROR(nccl_api->GroupStart()); + for (int peer = 0; peer < num_participants; ++peer) { + se::DeviceMemoryBase recv_slice = + NcclApi::Slice(buffer.destination_buffer, buffer.element_type, + peer * chunk_elements, chunk_elements); + send_pointer_map[peer] = (uint64_t)recv_slice.opaque(); + + TF_RETURN_IF_ERROR(nccl_api->SendPtrToPeer(&send_pointer_map[peer], + peer, comm, &stream)); + TF_RETURN_IF_ERROR(nccl_api->RecvPtrFromPeer(&receive_pointer_map[peer], + peer, comm, &stream)); + } + TF_RETURN_IF_ERROR(nccl_api->GroupEnd()); + TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); + + for (int peer = 0; peer < num_participants; ++peer) { + se::DeviceMemoryBase send_slice = + NcclApi::Slice(buffer.source_buffer, buffer.element_type, + peer * chunk_elements, chunk_elements); + se::DeviceMemoryBase dst_addr = + se::DeviceMemoryBase((void*)receive_pointer_map[peer]); + TF_RETURN_IF_ERROR( + stream.MemcpyD2D(&dst_addr, send_slice, send_slice.size())); + } + } + } else { + TF_RET_CHECK(buffers.size() == num_participants) + << "Number of inputs didn't match the number of participants."; + + TF_RETURN_IF_ERROR(nccl_api->GroupStart()); + for (int peer = 0; peer < num_participants; ++peer) { + send_pointer_map[peer] = + (uint64_t)buffers[peer].destination_buffer.opaque(); + + TF_RETURN_IF_ERROR(nccl_api->SendPtrToPeer(&send_pointer_map[peer], peer, + comm, &stream)); + TF_RETURN_IF_ERROR(nccl_api->RecvPtrFromPeer(&receive_pointer_map[peer], + peer, comm, &stream)); + } + TF_RETURN_IF_ERROR(nccl_api->GroupEnd()); + TF_RETURN_IF_ERROR(stream.BlockHostUntilDone()); + + for (int peer = 0; peer < num_participants; ++peer) { + // double buffer, exchange data with peer + se::DeviceMemoryBase dst_addr = + se::DeviceMemoryBase((void*)receive_pointer_map[peer]); + TF_RETURN_IF_ERROR(stream.MemcpyD2D(&dst_addr, + buffers[peer].source_buffer, + buffers[peer].source_buffer.size())); + } + } + return absl::OkStatus(); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.h b/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.h index 3bc8e2e78cb192..49c616a99a3938 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.h @@ -19,10 +19,16 @@ limitations under the License. #include #include +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_map.h" #include "absl/status/status.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/gpu/runtime/nccl_collective_thunk.h" #include "xla/stream_executor/stream.h" @@ -39,7 +45,7 @@ class NcclAllToAllStartThunk : public NcclCollectiveThunk { public: NcclAllToAllStartThunk(ThunkInfo thunk_info, NcclApi* nccl_api, const HloAllToAllInstruction* instr, - std::vector buffers); + std::vector buffers, bool p2p_memcpy_enabled); // Returns whether the given instruction can be lowered to a nccl all-to-all // call. @@ -47,26 +53,51 @@ class NcclAllToAllStartThunk : public NcclCollectiveThunk { int64_t replica_count, int64_t partition_count); + absl::Status Initialize(const InitializeParams& params) override; + + absl::Status Cleanup(const CleanupParams& params) override; + static const char* GetHloOpName() { return "all-to-all-start"; } static CollectiveOpGroupMode GetGroupMode( const HloAllToAllInstruction* instr); - protected: const NcclCollectiveConfig& config() const override { return config_.config; } + bool has_split_dimension() const { return config_.has_split_dimension; } + absl::Span buffers() const { return buffers_; } + + protected: absl::Status RunNcclCollective(const ExecuteParams& params, se::Stream& stream, NcclCommHandleWrapper comm_wrapper) override; + AsyncStreamKind GetAsyncStreamKind() const override; + + bool is_local() const; + private: const NcclAllToAllConfig config_; const std::vector buffers_; + int64_t device_count_ = 1; + bool p2p_memcpy_enabled_ = false; + absl::Mutex pointer_maps_mutex_; + absl::node_hash_map> + send_pointer_maps_ ABSL_GUARDED_BY(pointer_maps_mutex_); + absl::node_hash_map> + receive_pointer_maps_ ABSL_GUARDED_BY(pointer_maps_mutex_); }; absl::Status RunAllToAll(NcclApi* nccl_api, bool has_split_dimension, std::vector& buffers, se::Stream& stream, NcclApi::NcclCommHandle comm); +absl::Status RunMemCpyAllToAll( + NcclApi* nccl_api, bool has_split_dimension, + std::vector& buffers, se::Stream& stream, + NcclApi::NcclCommHandle comm, + absl::flat_hash_map& send_pointer_map, + absl::flat_hash_map& receive_pointer_map); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_api.cc b/third_party/xla/xla/service/gpu/runtime/nccl_api.cc index 77f022da6ec64f..95e7151c50d857 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_api.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_api.cc @@ -36,8 +36,8 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/scoped_activate_context.h" #include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" #include "xla/tsl/concurrency/ref_count.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" @@ -112,6 +112,8 @@ static absl::StatusOr ToNcclDataType(PrimitiveType dtype, case S8: case F8E5M2: case F8E4M3FN: + case F8E5M2FNUZ: + case F8E4M3FNUZ: return ncclInt8; case PRED: case U8: @@ -330,10 +332,14 @@ class DefaultNcclApi final : public NcclApi { absl::Status Send(se::DeviceMemoryBase send_buffer, PrimitiveType dtype, size_t count, int32_t peer, NcclCommHandle comm, se::Stream* stream) final; + absl::Status SendPtrToPeer(void* ptr, int32_t peer, NcclCommHandle comm, + se::Stream* stream) final; absl::Status Recv(se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, size_t count, int32_t peer, NcclCommHandle comm, se::Stream* stream) final; + absl::Status RecvPtrFromPeer(void* ptr, int32_t peer, NcclCommHandle comm, + se::Stream* stream) final; absl::StatusOr RegisterBuffer( NcclCommHandle comm, se::DeviceMemoryBase buffer) final; @@ -347,6 +353,8 @@ NcclApi* NcclApi::Default() { return nccl_api; } +bool NcclApi::HasNcclSupport() { return true; } + static_assert(NCCL_UNIQUE_ID_BYTES == NcclCliqueId::kSize, "size of nccl unique id must match the clique id size"); @@ -391,8 +399,7 @@ DefaultNcclApi::CommInitRanks(int32_t nranks, const NcclCliqueId& clique_id, VLOG(1) << "Initialize NCCL communicator for rank #" << ranks[i].rank << " of " << nranks << "; fingerprint(id)=" << clique_id.fingerprint(); - - se::gpu::ScopedActivateContext activate_context(ranks[i].device); + auto activate_context = ranks[i].device->Activate(); XLA_NCCL_RETURN_IF_ERROR(ncclCommInitRankConfig( &comm_handles[i], nranks, AsNcclUniqueId(clique_id), ranks[i].rank, @@ -609,6 +616,17 @@ absl::Status DefaultNcclApi::Send(se::DeviceMemoryBase send_buffer, peer, Cast(comm), se::gpu::AsGpuStreamValue(stream))); } +absl::Status DefaultNcclApi::SendPtrToPeer(void* ptr, int32_t peer, + NcclCommHandle comm, + se::Stream* stream) { + VLOG(3) << absl::StreamFormat( + "Launch NCCL RecvPtrFromPeer operation on device #%d; " + "peer=%d; comm=%p; stream=%p", + stream->parent()->device_ordinal(), peer, comm, stream); + return XLA_NCCL_STATUS(ncclSend(ptr, 1, ncclUint64, peer, Cast(comm), + se::gpu::AsGpuStreamValue(stream))); +} + absl::Status DefaultNcclApi::Recv(se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, size_t count, int32_t peer, NcclCommHandle comm, @@ -627,6 +645,18 @@ absl::Status DefaultNcclApi::Recv(se::DeviceMemoryBase recv_buffer, peer, Cast(comm), se::gpu::AsGpuStreamValue(stream))); } +absl::Status DefaultNcclApi::RecvPtrFromPeer(void* ptr, int32_t peer, + NcclCommHandle comm, + se::Stream* stream) { + VLOG(3) << absl::StreamFormat( + "Launch NCCL RecvPtrFromPeer operation on device #%d; " + "peer=%d; comm=%p; stream=%p", + stream->parent()->device_ordinal(), peer, comm, stream); + + return XLA_NCCL_STATUS(ncclRecv(ptr, 1, ncclUint64, peer, Cast(comm), + se::gpu::AsGpuStreamValue(stream))); +} + absl::StatusOr DefaultNcclApi::RegisterBuffer(NcclCommHandle comm, se::DeviceMemoryBase buffer) { diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_api.h b/third_party/xla/xla/service/gpu/runtime/nccl_api.h index 76747b64f703c3..d44603f3c95838 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_api.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_api.h @@ -60,6 +60,10 @@ class NcclApi { // NCCL or a stub if XLA compiled without NCCL or CUDA support. static NcclApi* Default(); + // Returns true if XLA is compiled with NCCL support, otherwise returns false. + // If false, Default() will return a stub implementation. + static bool HasNcclSupport(); + // Forward declarations of opaque structs corresponding to underlying platform // types (also defined as opaque structs). struct NcclComm; @@ -247,6 +251,10 @@ class NcclApi { virtual absl::Status Send(se::DeviceMemoryBase send_buffer, PrimitiveType dtype, size_t count, int32_t peer, NcclCommHandle comm, se::Stream* stream) = 0; + // Send a pointer `ptr` to rank `peer`. + virtual absl::Status SendPtrToPeer(void* ptr, int32_t peer, + NcclCommHandle comm, + se::Stream* stream) = 0; // Receive data from rank `peer` into `recv_buff`. // @@ -254,6 +262,10 @@ class NcclApi { virtual absl::Status Recv(se::DeviceMemoryBase recv_buffer, PrimitiveType dtype, size_t count, int32_t peer, NcclCommHandle comm, se::Stream* stream) = 0; + // Receive a pointer from rank `peer` into `ptr`. + virtual absl::Status RecvPtrFromPeer(void* ptr, int32_t peer, + NcclCommHandle comm, + se::Stream* stream) = 0; // Register `buffer` with communicator `comm` for zero-copy communication. // Returned handle can be used for future unregistration. diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_api_stub.cc b/third_party/xla/xla/service/gpu/runtime/nccl_api_stub.cc index b0cfad8fc23dfe..9cf030ad9fe5cb 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_api_stub.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_api_stub.cc @@ -149,11 +149,21 @@ class NcclApiStub final : public NcclApi { return UnimplementedError(); } + absl::Status SendPtrToPeer(void* ptr, int32_t peer, NcclCommHandle comm, + se::Stream* stream) final { + return UnimplementedError(); + } + absl::Status Recv(se::DeviceMemoryBase, PrimitiveType, size_t, int32_t, NcclCommHandle, se::Stream*) final { return UnimplementedError(); } + absl::Status RecvPtrFromPeer(void* ptr, int32_t peer, NcclCommHandle comm, + se::Stream* stream) final { + return UnimplementedError(); + } + absl::StatusOr RegisterBuffer( NcclCommHandle, se::DeviceMemoryBase) final { return UnimplementedError(); @@ -170,4 +180,6 @@ NcclApi* NcclApi::Default() { return nccl_api; } +bool NcclApi::HasNcclSupport() { return false; } + } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.cc b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.cc index 9bbc6f4019eab1..2cdb0fd2be1705 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "xla/service/global_device_id.h" +#include "tsl/platform/logging.h" namespace xla::gpu { @@ -44,7 +45,19 @@ NcclCliqueKey::NcclCliqueKey( : devices_(std::move(devices)), stream_id_(stream_id), stream_kind_(stream_kind), - participant_groups_(std::move(participant_groups)) {} + participant_groups_(std::move(participant_groups)) { + for (std::vector& group : participant_groups_) { + absl::c_sort(group); + } + // Compare the groups by their first element. + auto compare_groups = [](const std::vector& lhs, + const std::vector& rhs) { + CHECK(!lhs.empty()); + CHECK(!rhs.empty()); + return lhs[0] < rhs[0]; + }; + absl::c_sort(participant_groups_, compare_groups); +} absl::Span NcclCliqueKey::devices() const { return devices_; diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.h b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.h index 0946ce62ef7275..22cd6af46359bb 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.h @@ -29,7 +29,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/service/global_device_id.h" -#include "tsl/lib/gtl/int_type.h" +#include "xla/tsl/lib/gtl/int_type.h" namespace xla::gpu { @@ -56,10 +56,11 @@ enum class AsyncStreamKind : int64_t { kCollective = 0, // Stream for asynchronous collective ops. kP2P0 = 1, // One Stream for P2P Send and Recv ops. kP2P1 = 2, // Another Stream for P2P Send and Recv ops. + kMemCpyP2P = 3, // Stream for MemCpyP2P }; constexpr static int64_t kAsyncStreamTotal = - static_cast(AsyncStreamKind::kP2P1) + 1; + static_cast(AsyncStreamKind::kMemCpyP2P) + 1; // Assigns a unique ID to a stream for asynchronous or synchronous execution. // These IDs can be used, for example, to look up the NCCL communicator. diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key_test.cc b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key_test.cc index e55e401ace8ee0..50f43b116145fb 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key_test.cc @@ -89,6 +89,27 @@ TEST(NcclCliqueKeyTest, CompareWithParticipantGroups) { EXPECT_EQ(key0_nogroups, key1_nogroups); } +TEST(NcclCliqueKeyTest, CompareWithPermutedParticipantGroups) { + GlobalDeviceId id0 = GlobalDeviceId(0); + GlobalDeviceId id1 = GlobalDeviceId(1); + GlobalDeviceId id2 = GlobalDeviceId(2); + GlobalDeviceId id3 = GlobalDeviceId(3); + + // The keys are equal because the replica groups are same up to permutation. + NcclCliqueKey key0( + {id0, id1}, NcclStreamId(0), AsyncStreamKind::kCollective, + std::vector>{{id3, id2}, {id0, id1}}); + NcclCliqueKey key1( + {id0, id1}, NcclStreamId(0), AsyncStreamKind::kCollective, + std::vector>{{id0, id1}, {id2, id3}}); + EXPECT_EQ(key0, key1); + + NcclCliqueKey key_other( + {id0, id1}, NcclStreamId(0), AsyncStreamKind::kCollective, + std::vector>{{id0, id2}, {id1, id3}}); + EXPECT_FALSE(key0 == key_other); +} + TEST(NcclCliqueKeyTest, BtreeIterationOrder) { GlobalDeviceId id0 = GlobalDeviceId(0); GlobalDeviceId id1 = GlobalDeviceId(1); diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.cc index 370dd1189acc5c..dd9da283791be2 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.cc @@ -37,7 +37,8 @@ namespace xla::gpu { NcclCollectiveBroadcastStartThunk::NcclCollectiveBroadcastStartThunk( ThunkInfo thunk_info, NcclApi* nccl_api, - const HloCollectiveBroadcastInstruction* instr, std::vector buffers) + const HloCollectiveBroadcastInstruction* instr, std::vector buffers, + bool p2p_memcpy_enabled) : NcclCollectiveThunk(Thunk::kNcclCollectiveBroadcastStart, thunk_info, nccl_api, IsSyncCollective(instr)), config_(GetNcclCollectiveConfig(instr, std::nullopt)), diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.h b/third_party/xla/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.h index 4b19b785c025d7..14e32e1b4172cc 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_collective_broadcast_thunk.h @@ -47,7 +47,7 @@ class NcclCollectiveBroadcastStartThunk : public NcclCollectiveThunk { NcclCollectiveBroadcastStartThunk( ThunkInfo thunk_info, NcclApi* nccl_api, const HloCollectiveBroadcastInstruction* instr, - std::vector buffers); + std::vector buffers, bool p2p_memcpy_enabled = false); protected: absl::Status RunNcclCollective(const ExecuteParams& params, diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_collective_permute_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_collective_permute_thunk.cc index 6803e5b9874252..88a8ec85095fb5 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_collective_permute_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_collective_permute_thunk.cc @@ -227,7 +227,7 @@ absl::Status RunCollectivePermute( VLOG(3) << "Performing collective permute from device ordinal: " << device_ordinal << " current_id " << current_id; TF_RETURN_IF_ERROR( - MaybeRegisterBuffers(nccl_api, device_ordinal, {buffer}, comm)); + MaybeRegisterBuffers(nccl_api, stream.parent(), {buffer}, comm)); const std::optional source_id = source_target.source; const std::optional target_id = source_target.target; diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc index 8e075c8d01c730..29fd2befc75a90 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc @@ -52,6 +52,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -92,6 +93,8 @@ bool IsTypeSupportedByNccl(PrimitiveType element_type, // they involve actual computation and not just data movement. case F8E5M2: case F8E4M3FN: + case F8E5M2FNUZ: + case F8E4M3FNUZ: return !IsReductionCollective(reduction_op); default: return false; @@ -300,7 +303,7 @@ absl::StatusOr> ConvertToDeviceBuffers( return device_buffers; } -absl::Status RegisterBufferOnce(NcclApi* nccl_api, int device_ordinal, +absl::Status RegisterBufferOnce(NcclApi* nccl_api, se::StreamExecutor* executor, NcclApi::NcclCommHandle comm, se::DeviceMemoryBase buffer) { // Keep track of which communicators we have registered for already. @@ -319,39 +322,34 @@ absl::Status RegisterBufferOnce(NcclApi* nccl_api, int device_ordinal, // Since each XLA buffer is a slice into a larger BFCAllocator chunk, first // get the base address of buffer. We will use the base address to keep track // of which chunks we have registered. - void* base_ptr; - size_t base_size; -#ifdef GOOGLE_CUDA - TF_RETURN_IF_ERROR(se::gpu::GpuDriver::GetPointerAddressRange( - reinterpret_cast(buffer.opaque()), - reinterpret_cast(&base_ptr), &base_size)); -#else // GOOGLE_CUDA - base_ptr = nullptr; - base_size = 0; -#endif // GOOGLE_CUDA + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase base_buffer, + executor->GetMemoryRange(buffer)); absl::MutexLock lock(&all_registered.mu); - if (!all_registered.records.contains({device_ordinal, comm, base_ptr})) { + if (!all_registered.records.contains( + {executor->device_ordinal(), comm, base_buffer.opaque()})) { // ncclCommRegister will internally get and use the base address/size of the // address we provide. TF_ASSIGN_OR_RETURN(NcclApi::NcclRegisteredBufferHandle handle, nccl_api->RegisterBuffer(comm, buffer)); all_registered.handles.push_back(handle); - all_registered.records.insert({device_ordinal, comm, base_ptr}); + all_registered.records.insert( + {executor->device_ordinal(), comm, base_buffer.opaque()}); } return absl::OkStatus(); } -absl::Status MaybeRegisterBuffers(NcclApi* nccl_api, int device_ordinal, +absl::Status MaybeRegisterBuffers(NcclApi* nccl_api, + se::StreamExecutor* executor, const std::vector& buffers, NcclApi::NcclCommHandle comm) { for (int i = 0; i < buffers.size(); ++i) { if (buffers[i].source_memory_space == kCollectiveMemorySpaceColor) { - TF_RETURN_IF_ERROR(RegisterBufferOnce(nccl_api, device_ordinal, comm, + TF_RETURN_IF_ERROR(RegisterBufferOnce(nccl_api, executor, comm, buffers[i].source_buffer)); } if (buffers[i].destination_memory_space == kCollectiveMemorySpaceColor) { - TF_RETURN_IF_ERROR(RegisterBufferOnce(nccl_api, device_ordinal, comm, + TF_RETURN_IF_ERROR(RegisterBufferOnce(nccl_api, executor, comm, buffers[i].destination_buffer)); } } diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.h b/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.h index 2a549cdd81f520..91092847eb53f3 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.h @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/IR/Value.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/global_device_id.h" @@ -48,7 +49,6 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/stream.h" -#include "xla/translate/mhlo_to_hlo/attribute_exporter.h" #include "xla/xla_data.pb.h" namespace xla { @@ -326,7 +326,8 @@ absl::StatusOr> ConvertToDeviceBuffers( // communicator to enable zero-copy collectives. // // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/bufferreg.html -absl::Status MaybeRegisterBuffers(NcclApi* nccl_api, int device_ordinal, +absl::Status MaybeRegisterBuffers(NcclApi* nccl_api, + se::StreamExecutor* executor, const std::vector& buffers, NcclApi::NcclCommHandle comm); diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_p2p_thunk_common.cc b/third_party/xla/xla/service/gpu/runtime/nccl_p2p_thunk_common.cc index 6da04e4370c6f4..494fde4a32850a 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_p2p_thunk_common.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_p2p_thunk_common.cc @@ -28,9 +28,9 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" #include "xla/executable_run_options.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/runtime/nccl_clique_key.h" -#include "xla/service/hlo_parser.h" #include "xla/shape.h" #include "xla/status_macros.h" #include "xla/stream_executor/stream_executor.h" diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_recv_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_recv_thunk.cc index 800c684ade4aa1..fbf94f2e301c6e 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_recv_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_recv_thunk.cc @@ -93,8 +93,10 @@ absl::Status NcclRecvThunk::RunNcclCollective( // source, just memzero() the destination buffer. int device_ordinal = stream.parent()->device_ordinal(); VLOG(3) << "Performing Recv from device ordinal: " << device_ordinal - << "current_id " << current_id; - TF_RETURN_IF_ERROR(MaybeRegisterBuffers(nccl_api(), device_ordinal, {buffer}, + << ", current_id: " << current_id << ", group mode: " + << CollectiveOpGroupModeToString(config_.config.group_mode); + ; + TF_RETURN_IF_ERROR(MaybeRegisterBuffers(nccl_api(), stream.parent(), {buffer}, comm_wrapper.comm_handle)); const std::optional source_id = source_target.source; @@ -135,8 +137,7 @@ absl::Status NcclRecvThunk::RunNcclCollective( } else { // If there is no source peer, i.e. no sender to this instance, zero out // the destination buffer. - VLOG(3) << absl::StreamFormat("%s : collective-Permute: Issuing MemZero", - device_string); + VLOG(3) << absl::StreamFormat("%s : Recv: Issuing MemZero", device_string); TF_RETURN_IF_ERROR(stream.MemZero(&dest_addr, dest_addr.size())); } return absl::OkStatus(); diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_send_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_send_thunk.cc index 8e719449d07f44..9f76cafba6db02 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_send_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_send_thunk.cc @@ -92,9 +92,10 @@ absl::Status NcclSendThunk::RunNcclCollective( // Determine the target IDs for this instance. The target ID is the ID // to which this instance will copy its data. int device_ordinal = stream.parent()->device_ordinal(); - VLOG(3) << "Performing collective permute from device ordinal: " - << device_ordinal << "current_id " << current_id; - TF_RETURN_IF_ERROR(MaybeRegisterBuffers(nccl_api(), device_ordinal, {buffer}, + VLOG(3) << "Performing Send from device ordinal: " << device_ordinal + << ", current_id: " << current_id << ", group mode: " + << CollectiveOpGroupModeToString(config_.config.group_mode); + TF_RETURN_IF_ERROR(MaybeRegisterBuffers(nccl_api(), stream.parent(), {buffer}, comm_wrapper.comm_handle)); const std::optional target_id = source_target.target; diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.cc b/third_party/xla/xla/service/gpu/runtime/thunk.cc index d2ca3ae1184b52..90319ff9ee1651 100644 --- a/third_party/xla/xla/service/gpu/runtime/thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/thunk.cc @@ -34,6 +34,7 @@ limitations under the License. #include "xla/executable_run_options.h" #include "xla/ffi/execution_context.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/translate/mhlo_to_hlo/location_exporter.h" #include "xla/service/global_device_id.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/buffer_allocations.h" @@ -43,7 +44,6 @@ limitations under the License. #include "xla/service/gpu/runtime/nccl_clique_key.h" #include "xla/service/service_executable_run_options.h" #include "xla/stream_executor/stream.h" -#include "xla/translate/mhlo_to_hlo/location_exporter.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.h b/third_party/xla/xla/service/gpu/runtime/thunk.h index db0e49c355f102..1c3e09f8fb2889 100644 --- a/third_party/xla/xla/service/gpu/runtime/thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/thunk.h @@ -46,7 +46,7 @@ limitations under the License. #include "xla/service/service_executable_run_options.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" -#include "tsl/lib/gtl/int_type.h" +#include "xla/tsl/lib/gtl/int_type.h" namespace xla { namespace gpu { @@ -400,6 +400,23 @@ class Thunk { bool mock_collectives = false); }; + //===--------------------------------------------------------------------===// + // CleanupParams + //===--------------------------------------------------------------------===// + + // Parameters passed to Cleanup. Before returning from executable execution, + // thunks may need to clean up any resource allocated or registered through + // runtime APIs. + struct CleanupParams { + se::StreamExecutor* executor = nullptr; + + // Parameters for executing collective operations. + CollectiveExecuteParams* collective_params = nullptr; + + // Collective cliques acquired based on resource requests. + CollectiveCliques* collective_cliques = nullptr; + }; + //===--------------------------------------------------------------------===// // The hlo_instruction argument is meant to be the instruction this thunk was @@ -444,6 +461,14 @@ class Thunk { // Precondition: Initialize(initialize_params) has been called. virtual absl::Status ExecuteOnStream(const ExecuteParams& params) = 0; + // Cleans up any resources after thunk execution. + // + // This may be called multiple times. Its main purpose is to free up + // any resources occupied after initialization and execution. + virtual absl::Status Cleanup(const CleanupParams& params) { + return absl::OkStatus(); + } + static absl::string_view KindToString(Thunk::Kind kind); ExecutionStreamId execution_stream_id() const { return execution_stream_id_; } diff --git a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc index ea5607516d6e3d..85c7e26d58b7cc 100644 --- a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc @@ -42,7 +42,7 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/service/gpu/fusions/triton/triton_support_legacy.h" #include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/matmul_indexing_utils.h" #include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/service/gpu/triton_tiling_propagation.h" #include "xla/service/hlo_creation_utils.h" diff --git a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc index 8c17196090f3ca..cf6f98a6400068 100644 --- a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc @@ -690,7 +690,7 @@ class SplitKTestWithMorePreciseReduction : public HloTestBase, public ::testing::WithParamInterface { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_gpu_triton_gemm_disable_reduced_precision_reduction( true); diff --git a/third_party/xla/xla/service/gpu/stream_executor_util.cc b/third_party/xla/xla/service/gpu/stream_executor_util.cc index 5a515a8a2d5ce8..961b7bcf6a81e6 100644 --- a/third_party/xla/xla/service/gpu/stream_executor_util.cc +++ b/third_party/xla/xla/service/gpu/stream_executor_util.cc @@ -58,13 +58,13 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/typed_kernel_factory.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "xla/tsl/util/proto/proto_utils.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/ml_dtypes.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/dnn.pb.h" namespace xla { namespace gpu { @@ -493,7 +493,6 @@ static void InitializeTypedBuffer(se::Stream* stream, // Nothing more to do return; } -#ifdef GOOGLE_CUDA // Repeat the host_buffer_size elements at the start of `buf` to the end CHECK_EQ(elements_to_fill, buffer.size() / sizeof(T) - host_buffer_size); se::StreamExecutor* executor = stream->parent(); @@ -514,7 +513,6 @@ static void InitializeTypedBuffer(se::Stream* stream, se::BlockDim(blocks_per_grid, 1, 1), *kernel, buffer, host_buffer_bytes, static_cast(buffer.size()))); -#endif } void InitializeBuffer(se::Stream* stream, PrimitiveType buffer_type, diff --git a/third_party/xla/xla/service/gpu/stream_executor_util.h b/third_party/xla/xla/service/gpu/stream_executor_util.h index 877ad8bcc62f48..d0338595f9f17d 100644 --- a/third_party/xla/xla/service/gpu/stream_executor_util.h +++ b/third_party/xla/xla/service/gpu/stream_executor_util.h @@ -37,8 +37,8 @@ limitations under the License. #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/protobuf/dnn.pb.h" // Helper functions for interacting with StreamExecutor. diff --git a/third_party/xla/xla/service/gpu/stream_executor_util_kernel_stub.cc b/third_party/xla/xla/service/gpu/stream_executor_util_kernel_stub.cc new file mode 100644 index 00000000000000..392d08eb63705b --- /dev/null +++ b/third_party/xla/xla/service/gpu/stream_executor_util_kernel_stub.cc @@ -0,0 +1,21 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +namespace xla::gpu::repeat_buffer_kernel { + +// Stub to make CPU build linker find undefined symbol. +void* kernel() { return nullptr; } + +} // namespace xla::gpu::repeat_buffer_kernel diff --git a/third_party/xla/xla/service/gpu/target_util.cc b/third_party/xla/xla/service/gpu/target_util.cc index 123763f856526d..a18ee0887312bd 100644 --- a/third_party/xla/xla/service/gpu/target_util.cc +++ b/third_party/xla/xla/service/gpu/target_util.cc @@ -428,7 +428,7 @@ llvm::CallInst* EmitCallToTargetIntrinsic( LOG(FATAL) << "Invalid triple " << target_triple.str(); } - llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration( + llvm::Function* intrinsic = llvm::Intrinsic::getOrInsertDeclaration( module, llvm_intrinsic_id, llvm_ir::AsArrayRef(overloaded_types)); return b->CreateCall(intrinsic, llvm_ir::AsArrayRef(operands)); } diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD index def6442e46967b..f0f189e645c44f 100644 --- a/third_party/xla/xla/service/gpu/tests/BUILD +++ b/third_party/xla/xla/service/gpu/tests/BUILD @@ -83,13 +83,14 @@ xla_test( "//xla:shape_util", "//xla/ffi", "//xla/ffi:ffi_api", - "//xla/stream_executor", "//xla/tests:hlo_test_base", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@local_tsl//tsl/platform:test", ], ) + [ + "//xla/stream_executor:device_memory", + "//xla/stream_executor:stream", "//xla/tests:xla_internal_test_main", ], ) @@ -209,7 +210,7 @@ xla_test( "swap_conv_operands_test.cc", ], backends = ["gpu"], - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":gpu_codegen_test", "//xla:error_spec", @@ -227,7 +228,7 @@ xla_test( deps = [ ":gpu_codegen_test", "//xla:error_spec", - "//xla/service:hlo_parser", + "//xla/hlo/parser:hlo_parser", "//xla/stream_executor:device_description", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:statusor", @@ -385,7 +386,7 @@ xla_test( "//xla:shape_util", "//xla:test_helpers", "//xla/client:local_client", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "//xla/tests:client_library_test_base", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:test_main", @@ -583,6 +584,7 @@ xla_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", ], ) @@ -654,6 +656,7 @@ lit_test_suite( "transpose_210.hlo", "transpose_210_extra_output.hlo", "triton_naming.hlo", + "zero_clamp_abs_index.hlo", ], include = [ "*.hlo", @@ -677,6 +680,16 @@ lit_test_suite( ], default_tags = tf_cuda_tests_tags(), hermetic_cuda_data_dir = "%S/../../../../../cuda_nvcc", +<<<<<<< HEAD +======= + tags_override = { + "element_wise_row_vectorization.hlo": ["cuda-only"], + "scatter_bf16.hlo": ["cuda-only"], + "single_instruction.hlo": ["cuda-only"], + "reduce_unnested.hlo": ["cuda-only"], + "reduction_vectorization_sm_all.hlo": ["cuda-only"], + }, +>>>>>>> master tools = [ "//xla/tools:hlo-opt", "@llvm-project//llvm:FileCheck", @@ -701,6 +714,7 @@ lit_test_suite( # deps = [ # "//xla/service/gpu/fusions/transforms:passes", # "//xla/service/gpu/fusions/triton:passes", +# "//xla/service/gpu/fusions/triton:xla_triton", # "@llvm-project//mlir:AllExtensions", # "@llvm-project//mlir:MlirOptLib", # "@triton//:AllPassesAndDialects", @@ -748,7 +762,11 @@ xla_test( "//xla:shape_util", "//xla:types", "//xla:xla_proto_cc", + "//xla/stream_executor:device_description", + "//xla/stream_executor:kernel", + "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:status", @@ -758,7 +776,6 @@ xla_test( ] + if_cuda_is_configured([ "//xla/stream_executor/cuda:cuda_asm_compiler", "//xla/service/gpu:gpu_asm_opts_util", - "//xla/stream_executor", "//xla/service/gpu:stream_executor_util", "//xla/stream_executor:device_memory", ]), @@ -786,6 +803,10 @@ xla_test( "gpu_a100", "gpu_h100", ], +<<<<<<< HEAD +======= + tags = ["cuda-only"], +>>>>>>> master deps = if_cuda_is_configured( [ ":gpu_codegen_test", @@ -837,28 +858,23 @@ xla_test( "//xla:error_spec", "//xla:literal", "//xla:literal_util", - "//xla:reference_util", "//xla:shape_util", "//xla:test_helpers", "//xla:types", "//xla:xla_data_proto_cc", - "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", "//xla/service:hlo_module_config", - "//xla/service/gpu:cublas_cudnn", "//xla/service/gpu:stream_executor_util", - "//xla/stream_executor", "//xla/stream_executor:device_description", "//xla/stream_executor:dnn", "//xla/stream_executor:semantic_version", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/cuda:cuda_platform_id", - "//xla/tests:client_library_test_base", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", "//xla/tests:test_macros_header", - "//xla/tests:test_utils", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/service/gpu/tests/concatenate_emitter_test.cc b/third_party/xla/xla/service/gpu/tests/concatenate_emitter_test.cc index 66920a3f182792..9ad5b19e0b6c70 100644 --- a/third_party/xla/xla/service/gpu/tests/concatenate_emitter_test.cc +++ b/third_party/xla/xla/service/gpu/tests/concatenate_emitter_test.cc @@ -26,7 +26,7 @@ namespace { class ConcatenateEmitterTest : public gpu::GpuCodegenTest { protected: ConcatenateEmitterTest() = default; - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { auto opts = HloTestBase::GetDebugOptionsForTest(); opts.set_xla_gpu_mlir_emitter_level(0); return opts; diff --git a/third_party/xla/xla/service/gpu/tests/dot_bf16.hlo b/third_party/xla/xla/service/gpu/tests/dot_bf16.hlo index 28db5c95903ac3..a88d1b17befc91 100644 --- a/third_party/xla/xla/service/gpu/tests/dot_bf16.hlo +++ b/third_party/xla/xla/service/gpu/tests/dot_bf16.hlo @@ -18,7 +18,7 @@ ENTRY %computation1 { // ----- // CHECK-SM70: (f32[6144,32]{1,0}, s8[4194304]{0}) custom-call(f32[1536,6144]{1,0} {{.*}}, f32[32,1536]{1,0} {{.*}}), custom_call_target="__cublas$gemm" -// CHECK-SM80: (f32[6144,32]{1,0}, s8[4194304]{0}) custom-call(bf16[1536,6144]{1,0} %convert.1.0, bf16[32,1536]{1,0} %b.1), custom_call_target="__cublas$gemm" +// CHECK-SM80: (f32[6144,32]{1,0}, s8[4194304]{0}) custom-call(bf16[1536,6144]{1,0} %convert.2.0, bf16[32,1536]{1,0} %b.1), custom_call_target="__cublas$gemm" HloModule module2 diff --git a/third_party/xla/xla/service/gpu/tests/float_conversions_test.cc b/third_party/xla/xla/service/gpu/tests/float_conversions_test.cc index 34b5c703798c23..16383324dfb016 100644 --- a/third_party/xla/xla/service/gpu/tests/float_conversions_test.cc +++ b/third_party/xla/xla/service/gpu/tests/float_conversions_test.cc @@ -23,131 +23,53 @@ namespace gpu { class FloatConversionTest : public GpuCodegenTest {}; -TEST_F(FloatConversionTest, F8E5M2ToF16) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f8e5m2[] parameter(0) - ROOT %c = f16[] convert(%p) - })", - ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(FloatConversionTest, F8E4M3FNToF16) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f8e4m3fn[] parameter(0) - ROOT %c = f16[] convert(%p) - })", - ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(FloatConversionTest, F8E4M3B11FNUZToF16) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f8e4m3b11fnuz[] parameter(0) - ROOT %c = f16[] convert(%p) - })", - ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(FloatConversionTest, F8E5M2FNUZToF16) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f8e5m2fnuz[] parameter(0) - ROOT %c = f16[] convert(%p) - })", - ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(FloatConversionTest, F8E4M3FNUZToF16) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f8e4m3fnuz[] parameter(0) - ROOT %c = f16[] convert(%p) - })", - ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(FloatConversionTest, BF16ToF32) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = bf16[] parameter(0) - ROOT %c = f32[] convert(%p) - })", - ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(FloatConversionTest, F16ToF32) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f16[] parameter(0) - ROOT %c = f32[] convert(%p) - })", - ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(FloatConversionTest, F64ToF32) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f64[] parameter(0) - ROOT %c = f32[] convert(%p) - })", - ErrorSpec{1e-5, 1e-5})); -} +class FloatConversionParamTest + : public GpuCodegenTest, + public ::testing::WithParamInterface {}; -TEST_F(FloatConversionTest, F16ToF8E5M2) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f16[] parameter(0) - ROOT %c = f8e5m2[] convert(%p) - })", - ErrorSpec{1e-5, 1e-5})); -} +INSTANTIATE_TEST_SUITE_P(FloatConversionParamSuite, FloatConversionParamTest, + ::testing::Values("f64", "f32", "f16", "bf16", + "f8e5m2", "f8e5m2fnuz", "f8e4m3", + "f8e4m3fn", "f8e4m3fnuz", + "f8e4m3b11fnuz", "f8e3m4")); -TEST_F(FloatConversionTest, F16ToF8E4M3FN) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f16[] parameter(0) - ROOT %c = f8e4m3fn[] convert(%p) +TEST_P(FloatConversionParamTest, FloatToF16) { + auto type_name = GetParam(); + EXPECT_TRUE(RunAndCompare(absl::StrFormat(R"(ENTRY m { + p0 = %s[] parameter(0) + ROOT c1 = f16[] convert(p0) })", + type_name), ErrorSpec{1e-5, 1e-5})); } -TEST_F(FloatConversionTest, F16ToF8E4M3B11FNUZ) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f16[] parameter(0) - ROOT %c = f8e4m3b11fnuz[] convert(%p) +TEST_P(FloatConversionParamTest, F16ToFloat) { + auto type_name = GetParam(); + EXPECT_TRUE(RunAndCompare(absl::StrFormat(R"(ENTRY m { + p0 = f16[] parameter(0) + ROOT c1 = %s[] convert(p0) })", + type_name), ErrorSpec{1e-5, 1e-5})); } -TEST_F(FloatConversionTest, F16ToF8E5M2FNUZ) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f16[] parameter(0) - ROOT %c = f8e5m2fnuz[] convert(%p) +TEST_P(FloatConversionParamTest, FloatToF32) { + auto type_name = GetParam(); + EXPECT_TRUE(RunAndCompare(absl::StrFormat(R"(ENTRY m { + p0 = %s[] parameter(0) + ROOT c1 = f32[] convert(p0) })", + type_name), ErrorSpec{1e-5, 1e-5})); } -TEST_F(FloatConversionTest, F16ToF8E4M3FNUZ) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f16[] parameter(0) - ROOT %c = f8e4m3fnuz[] convert(%p) - })", - ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(FloatConversionTest, F32ToBF16) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f32[] parameter(0) - ROOT %c = bf16[] convert(%p) - })", - ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(FloatConversionTest, F32ToF16) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f32[] parameter(0) - ROOT %c = f16[] convert(%p) - })", - ErrorSpec{1e-5, 1e-5})); -} - -TEST_F(FloatConversionTest, F32ToF64) { - EXPECT_TRUE(RunAndCompare(R"(ENTRY m { - %p = f32[] parameter(0) - ROOT %c = f64[] convert(%p) +TEST_P(FloatConversionParamTest, F32ToFloat) { + auto type_name = GetParam(); + EXPECT_TRUE(RunAndCompare(absl::StrFormat(R"(ENTRY m { + p0 = f32[] parameter(0) + ROOT c1 = %s[] convert(p0) })", + type_name), ErrorSpec{1e-5, 1e-5})); } diff --git a/third_party/xla/xla/service/gpu/tests/gpu_compilation_parallelism_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_compilation_parallelism_test.cc index 18b5feeeb6e733..716797f1ba36b4 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_compilation_parallelism_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_compilation_parallelism_test.cc @@ -27,7 +27,7 @@ namespace gpu { namespace { class CompilationParallelismTest : public GpuCodegenTest { - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); // Use multiple threads for compilation debug_options.set_xla_gpu_force_compilation_parallelism(4); diff --git a/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc index b0e2d9c86c95a9..5028af0a6be959 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc @@ -26,9 +26,9 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/array3d.h" #include "xla/array4d.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -89,7 +89,7 @@ class MultiHeadedAttentionTest : public GpuCodegenTest { ErrorSpec mha_error_spec_{2.5E-3, 1e-5}; protected: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { auto debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_gpu_enable_cudnn_fmha(true); debug_options.clear_xla_gpu_enable_command_buffer(); @@ -1263,94 +1263,7 @@ class FlashAttentionBMMScaleSlidingWindowMaskSoftmaxBMM } }; -class FlashAttentionBMMScalePaddingMaskSoftmaxBMMF8 - : public MultiHeadedAttentionTest { - protected: - void TestImpl_Flash_Attention_Inference_BMM1_NoMask_Softmax_BMM2_F8() { - if (skip_reason_) GTEST_SKIP() << *skip_reason_; - if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) < - se::dnn::VersionInfo(9, 1, 0)) { - GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.1.0."; - } - XlaBuilder builder(TestName()); - std::string hlo_string_ref = - R"( - HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0})->bf16[4,4,16,16]{3,1,2,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true} - clip.33 { - Arg_2.36 = bf16[] parameter(2) - broadcast.39 = bf16[4,16,4,16]{3,2,1,0} broadcast(Arg_2.36), dimensions={} - Arg_1.35 = bf16[] parameter(1) - broadcast.37 = bf16[4,16,4,16]{3,2,1,0} broadcast(Arg_1.35), dimensions={} - Arg_0.34 = bf16[4,16,4,16]{3,2,1,0} parameter(0) - maximum.38 = bf16[4,16,4,16]{3,2,1,0} maximum(broadcast.37, Arg_0.34) - ROOT minimum.40 = bf16[4,16,4,16]{3,2,1,0} minimum(broadcast.39, maximum.38) - } // clip.33 - ENTRY main.106 { - Arg_0.1 = bf16[4,16,4,16]{3,2,1,0} parameter(0) - constant.6 = bf16[] constant(1) - broadcast.7 = bf16[4,16,4,16]{3,2,1,0} broadcast(constant.6), dimensions={} - divide.8 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_0.1, broadcast.7) - constant.5 = bf16[] constant(-448) - constant.4 = bf16[] constant(448) - call.17 = bf16[4,16,4,16]{3,2,1,0} call(divide.8, constant.5, constant.4), to_apply=clip.33 - convert.18 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.17) - convert.19 = bf16[4,16,4,16]{3,2,1,0} convert(convert.18) - Arg_1.2 = bf16[4,16,4,16]{3,2,1,0} parameter(1) - divide.20 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_1.2, broadcast.7) - call.29 = bf16[4,16,4,16]{3,2,1,0} call(divide.20, constant.5, constant.4), to_apply=clip.33 - convert.30 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.29) - convert.31 = bf16[4,16,4,16]{3,2,1,0} convert(convert.30) - Arg_2.3 = bf16[4,16,4,16]{3,2,1,0} parameter(2) - divide.32 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_2.3, broadcast.7) - call.41 = bf16[4,16,4,16]{3,2,1,0} call(divide.32, constant.5, constant.4), to_apply=clip.33 - convert.42 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.41) - convert.43 = bf16[4,16,4,16]{3,2,1,0} convert(convert.42) - custom-call.4.0 = (bf16[4,4,16,16]{3,1,2,0}, u8[16]{0}) custom-call(convert.19, convert.31, convert.43), custom_call_target="__cudnn$fmhaSoftmax", operand_layout_constraints={bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0}}, api_version=API_VERSION_STATUS_RETURNING, backend_config={"operation_queue_id": "0", "wait_on_operation_queues": [], "cudnn_fmha_backend_config": {"algorithm": {"algo_id": "0", "math_type": "TENSOR_OP_MATH", "tuning_knobs": {"17": "1", "24": "0"}, "is_cudnn_frontend": true, "workspace_size": "0"}, "fmha_scale": 1.0, "dropout_rate": 0.0, "intermediate_tensor_shape": {"element_type": "BF16", "dimensions": ["4", "4", "16", "16"], "tuple_shapes": [], "layout": {"dim_level_types": [], "dim_unique": [], "dim_ordered": [], "minor_to_major": ["3", "2", "1", "0"], "tiles": [], "element_size_in_bits": "0", "memory_space": "0", "index_primitive_type": "PRIMITIVE_TYPE_INVALID", "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", "dynamic_shape_metadata_prefix_bytes": "0"}, "is_dynamic_dimension": [false, false, false, false]}, "seed": 42, "is_flash_attention": true, "mask_type": "NO_MASK", "bmm1_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["3"], "lhs_batch_dimensions": ["0", "2"], "rhs_batch_dimensions": ["0", "2"]}, "bmm2_dot_dimension_numbers": {"lhs_contracting_dimensions": ["3"], "rhs_contracting_dimensions": ["1"], "lhs_batch_dimensions": ["0", "1"], "rhs_batch_dimensions": ["0", "2"]}}} - ROOT get-tuple-element.5.0 = bf16[4,4,16,16]{3,1,2,0} get-tuple-element(custom-call.4.0), index=0 - } // main.106 - )"; // NOLINT - std::string hlo_string = R"( - HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0})->bf16[4,4,16,16]{3,1,2,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true}, allow_spmd_sharding_propagation_to_output={true} - clip.33 { - Arg_2.36 = bf16[] parameter(2) - broadcast.39 = bf16[4,16,4,16]{3,2,1,0} broadcast(Arg_2.36), dimensions={} - Arg_1.35 = bf16[] parameter(1) - broadcast.37 = bf16[4,16,4,16]{3,2,1,0} broadcast(Arg_1.35), dimensions={} - Arg_0.34 = bf16[4,16,4,16]{3,2,1,0} parameter(0) - maximum.38 = bf16[4,16,4,16]{3,2,1,0} maximum(broadcast.37, Arg_0.34) - ROOT minimum.40 = bf16[4,16,4,16]{3,2,1,0} minimum(broadcast.39, maximum.38) - } // clip.33 - ENTRY main.106 { - constant.99 = f32[] constant(1) - broadcast.99 = f32[1,1,1,1]{3,2,1,0} broadcast(constant.99), dimensions={} - Arg_0.1 = bf16[4,16,4,16]{3,2,1,0} parameter(0) - constant.6 = bf16[] constant(1) - broadcast.7 = bf16[4,16,4,16]{3,2,1,0} broadcast(constant.6), dimensions={} - divide.8 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_0.1, broadcast.7) - constant.5 = bf16[] constant(-448) - constant.4 = bf16[] constant(448) - call.17 = bf16[4,16,4,16]{3,2,1,0} call(divide.8, constant.5, constant.4), to_apply=clip.33 - convert.18 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.17) - convert.19 = bf16[4,16,4,16]{3,2,1,0} convert(convert.18) - Arg_1.2 = bf16[4,16,4,16]{3,2,1,0} parameter(1) - divide.20 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_1.2, broadcast.7) - call.29 = bf16[4,16,4,16]{3,2,1,0} call(divide.20, constant.5, constant.4), to_apply=clip.33 - convert.30 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.29) - convert.31 = bf16[4,16,4,16]{3,2,1,0} convert(convert.30) - Arg_2.3 = bf16[4,16,4,16]{3,2,1,0} parameter(2) - divide.32 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_2.3, broadcast.7) - call.41 = bf16[4,16,4,16]{3,2,1,0} call(divide.32, constant.5, constant.4), to_apply=clip.33 - convert.42 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.41) - convert.43 = bf16[4,16,4,16]{3,2,1,0} convert(convert.42) - custom-call.21.0 = (f8e4m3fn[4,4,16,16]{3,1,2,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, u8[16]{0}) custom-call(convert.18, convert.30, convert.42, broadcast.99, broadcast.99, /*index=5*/broadcast.99, broadcast.99, broadcast.99, broadcast.99), custom_call_target="__cudnn$fmhaSoftmaxF8", operand_layout_constraints={f8e4m3fn[4,16,4,16]{3,2,1,0}, f8e4m3fn[4,16,4,16]{3,2,1,0}, f8e4m3fn[4,16,4,16]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}} - get-tuple-element.5.0 = f8e4m3fn[4,4,16,16]{3,1,2,0} get-tuple-element(custom-call.21.0), index=0 - ROOT out = bf16[4,4,16,16]{3,1,2,0} convert(get-tuple-element.5.0) - } // main.106 - )"; // NOLINT - EXPECT_TRUE(RunAndCompareTwoModules(hlo_string, hlo_string_ref, - ErrorSpec{1e-2, 1e-2})); - } -}; +class FlashAttentionBMMScaleSoftmaxBMMF8 : public MultiHeadedAttentionTest {}; class FlashAttentionBMMScaleSoftmaxDropoutBMM : public MultiHeadedAttentionTest { @@ -1465,10 +1378,442 @@ XLA_TEST_F(FlashAttentionBMMScaleSlidingWindowMaskSoftmaxBMM, bfloat16>(); // NOLINT } +absl::string_view GetModuleFlashAttentionBMMScaleSoftmaxBMMCommonRef() { + static constexpr absl::string_view hlo_text = + R"( + HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0})->bf16[4,16,4,16]{3,2,1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true} + + clip.33 { + Arg_2.36 = bf16[] parameter(2) + broadcast.39 = bf16[4,16,4,16]{3,2,1,0} broadcast(Arg_2.36), dimensions={} + Arg_1.35 = bf16[] parameter(1) + broadcast.37 = bf16[4,16,4,16]{3,2,1,0} broadcast(Arg_1.35), dimensions={} + Arg_0.34 = bf16[4,16,4,16]{3,2,1,0} parameter(0) + maximum.38 = bf16[4,16,4,16]{3,2,1,0} maximum(broadcast.37, Arg_0.34) + ROOT minimum.40 = bf16[4,16,4,16]{3,2,1,0} minimum(broadcast.39, maximum.38) + } + + ENTRY main.106 { + Arg_0.1 = bf16[4,16,4,16]{3,2,1,0} parameter(0) + Arg_1.2 = bf16[4,16,4,16]{3,2,1,0} parameter(1) + Arg_2.3 = bf16[4,16,4,16]{3,2,1,0} parameter(2) + + constant.6 = bf16[] constant(1) + broadcast.7 = bf16[4,16,4,16]{3,2,1,0} broadcast(constant.6), dimensions={} + + divide.8 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_0.1, broadcast.7) + call.17 = bf16[4,16,4,16]{3,2,1,0} call(divide.8, bf16[] constant(-448), bf16[] constant(448)), to_apply=clip.33 + convert.18 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.17) + convert.19 = bf16[4,16,4,16]{3,2,1,0} convert(convert.18) + + divide.20 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_1.2, broadcast.7) + call.29 = bf16[4,16,4,16]{3,2,1,0} call(divide.20, bf16[] constant(-448), bf16[] constant(448)), to_apply=clip.33 + convert.30 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.29) + convert.31 = bf16[4,16,4,16]{3,2,1,0} convert(convert.30) + + divide.32 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_2.3, broadcast.7) + call.41 = bf16[4,16,4,16]{3,2,1,0} call(divide.32, bf16[] constant(-448), bf16[] constant(448)), to_apply=clip.33 + convert.42 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.41) + convert.43 = bf16[4,16,4,16]{3,2,1,0} convert(convert.42) + )"; + return hlo_text; +} + +absl::string_view GetModuleFlashAttentionBMMScaleSoftmaxBMMCommonF8() { + static constexpr absl::string_view hlo_text = R"( + HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0}, bf16[4,16,4,16]{3,2,1,0})->bf16[4,16,4,16]{3,2,1,0}}, allow_spmd_sharding_propagation_to_parameters={true,true,true} + clip.33 { + Arg_2.36 = bf16[] parameter(2) + broadcast.39 = bf16[4,16,4,16]{3,2,1,0} broadcast(Arg_2.36), dimensions={} + Arg_1.35 = bf16[] parameter(1) + broadcast.37 = bf16[4,16,4,16]{3,2,1,0} broadcast(Arg_1.35), dimensions={} + Arg_0.34 = bf16[4,16,4,16]{3,2,1,0} parameter(0) + maximum.38 = bf16[4,16,4,16]{3,2,1,0} maximum(broadcast.37, Arg_0.34) + ROOT minimum.40 = bf16[4,16,4,16]{3,2,1,0} minimum(broadcast.39, maximum.38) + } // clip.33 + ENTRY main.106 { + constant.99 = f32[] constant(1) + broadcast.99 = f32[1,1,1,1]{3,2,1,0} broadcast(constant.99), dimensions={} + Arg_0.1 = bf16[4,16,4,16]{3,2,1,0} parameter(0) + constant.6 = bf16[] constant(1) + broadcast.7 = bf16[4,16,4,16]{3,2,1,0} broadcast(constant.6), dimensions={} + divide.8 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_0.1, broadcast.7) + constant.5 = bf16[] constant(-448) + constant.4 = bf16[] constant(448) + call.17 = bf16[4,16,4,16]{3,2,1,0} call(divide.8, constant.5, constant.4), to_apply=clip.33 + convert.18 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.17) + convert.19 = bf16[4,16,4,16]{3,2,1,0} convert(convert.18) + Arg_1.2 = bf16[4,16,4,16]{3,2,1,0} parameter(1) + divide.20 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_1.2, broadcast.7) + call.29 = bf16[4,16,4,16]{3,2,1,0} call(divide.20, constant.5, constant.4), to_apply=clip.33 + convert.30 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.29) + convert.31 = bf16[4,16,4,16]{3,2,1,0} convert(convert.30) + Arg_2.3 = bf16[4,16,4,16]{3,2,1,0} parameter(2) + divide.32 = bf16[4,16,4,16]{3,2,1,0} divide(Arg_2.3, broadcast.7) + call.41 = bf16[4,16,4,16]{3,2,1,0} call(divide.32, constant.5, constant.4), to_apply=clip.33 + convert.42 = f8e4m3fn[4,16,4,16]{3,2,1,0} convert(call.41) + convert.43 = bf16[4,16,4,16]{3,2,1,0} convert(convert.42) + )"; + return hlo_text; +} // BMM1 - Scale - Softmax - BMM2 fp8 -XLA_TEST_F(FlashAttentionBMMScalePaddingMaskSoftmaxBMMF8, - Flash_Attention_Inference_BMM1_NoMask_Softmax_BMM2_F8) { - TestImpl_Flash_Attention_Inference_BMM1_NoMask_Softmax_BMM2_F8(); +XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMMF8, + Flash_Attention_Inference_BMM1_NoMask_Softmax_BMM2_BNTH_F8) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) < + se::dnn::VersionInfo(9, 1, 0)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.1.0."; + } + auto cc = GetCudaComputeCapability(); + if (!cc.IsAtLeastHopper()) { + GTEST_SKIP() << "Flash Attention fp8 requires at least Hopper."; + } + XlaBuilder builder(TestName()); + std::string ref_bnth = R"( + custom-call.4.0 = ( + bf16[4,4,16,16]{3,1,2,0}, + u8[0]{0} + ) custom-call( + convert.19, + convert.31, + convert.43 + ), + custom_call_target="__cudnn$fmhaSoftmax", + operand_layout_constraints={ + bf16[4,16,4,16]{3,2,1,0}, + bf16[4,16,4,16]{3,2,1,0}, + bf16[4,16,4,16]{3,2,1,0} + }, + api_version=API_VERSION_STATUS_RETURNING, + backend_config={ + "operation_queue_id": "0", + "wait_on_operation_queues": [], + "cudnn_fmha_backend_config": { + "algorithm": { + "algo_id": "0", + "math_type": "TENSOR_OP_MATH", + "tuning_knobs": { + "17": "1", + "24": "0" + }, + "is_cudnn_frontend": true, + "workspace_size": "0" + }, + "fmha_scale": 0.75, + "dropout_rate": 0.0, + "intermediate_tensor_shape": { + "element_type": "BF16", + "dimensions": ["4", "4", "16", "16"], + "tuple_shapes": [], + "layout": { + "dim_level_types": [], + "dim_unique": [], + "dim_ordered": [], + "minor_to_major": ["3", "2", "1", "0"], + "tiles": [], + "element_size_in_bits": "0", + "memory_space": "0", + "index_primitive_type": "PRIMITIVE_TYPE_INVALID", + "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", + "dynamic_shape_metadata_prefix_bytes": "0" + }, + "is_dynamic_dimension": [false, false, false, false] + }, + "seed": 42, + "is_flash_attention": true, + "mask_type": "NO_MASK", + "sliding_window_length": 0, + "bmm1_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["3"], + "lhs_batch_dimensions": ["0", "2"], + "rhs_batch_dimensions": ["0", "2"] + }, + "bmm2_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["1"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "2"] + } + } + } + get-tuple-element.5.0 = bf16[4,4,16,16]{3,1,2,0} get-tuple-element(custom-call.4.0), index=0 + ROOT transpose.7 = bf16[4,16,4,16]{3,2,1,0} transpose(get-tuple-element.5.0), dimensions={0,2,1,3} + } +)"; + + std::string fp8_bnth = R"( + custom-call.21.0 = ( + f8e4m3fn[4,4,16,16]{3,1,2,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + u8[16]{0} + ) custom-call( + convert.18, + convert.30, + convert.42, + broadcast.99, + broadcast.99, + /*index=5*/broadcast.99, + broadcast.99, + broadcast.99, + broadcast.99 + ), + custom_call_target="__cudnn$fmhaSoftmaxF8", + operand_layout_constraints={ + f8e4m3fn[4,16,4,16]{3,2,1,0}, + f8e4m3fn[4,16,4,16]{3,2,1,0}, + f8e4m3fn[4,16,4,16]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0} + }, + api_version=API_VERSION_STATUS_RETURNING, + backend_config={ + "operation_queue_id": "0", + "wait_on_operation_queues": [], + "cudnn_fmha_backend_config": { + "algorithm": { + "algo_id": "0", + "math_type": "TENSOR_OP_MATH", + "tuning_knobs": { + "17": "1", + "24": "0" + }, + "is_cudnn_frontend": true, + "workspace_size": "0" + }, + "fmha_scale": 0.75, + "intermediate_tensor_shape": { + "element_type": "BF16", + "dimensions": ["4", "4", "16", "16"], + "tuple_shapes": [], + "layout": { + "dim_level_types": [], + "dim_unique": [], + "dim_ordered": [], + "minor_to_major": ["3", "2", "1", "0"], + "tiles": [], + "element_size_in_bits": "0", + "memory_space": "0", + "index_primitive_type": "PRIMITIVE_TYPE_INVALID", + "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", + "dynamic_shape_metadata_prefix_bytes": "0" + }, + "is_dynamic_dimension": [false, false, false, false] + }, + "is_flash_attention": true, + "mask_type": "NO_MASK", + "bmm1_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["3"], + "lhs_batch_dimensions": ["0", "2"], + "rhs_batch_dimensions": ["0", "2"] + }, + "bmm2_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["1"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "2"] + } + } + } + get-tuple-element.5.0 = f8e4m3fn[4,4,16,16]{3,1,2,0} get-tuple-element(custom-call.21.0), index=0 + transpose.26 = f8e4m3fn[4,16,4,16]{3,2,1,0} transpose(get-tuple-element.5.0), dimensions={0,2,1,3} + ROOT out = bf16[4,16,4,16]{3,2,1,0} convert(transpose.26) + } + )"; + + std::string hlo_string = + std::string(GetModuleFlashAttentionBMMScaleSoftmaxBMMCommonF8()) + + fp8_bnth; + std::string hlo_string_ref = + std::string(GetModuleFlashAttentionBMMScaleSoftmaxBMMCommonRef()) + + ref_bnth; + EXPECT_TRUE(RunAndCompareTwoModules(hlo_string, hlo_string_ref, + ErrorSpec{5e-2, 5e-2})); +} + +XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMMF8, + Flash_Attention_Inference_BMM1_NoMask_Softmax_BMM2_BTNH_F8) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) < + se::dnn::VersionInfo(9, 1, 0)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.1.0."; + } + auto cc = GetCudaComputeCapability(); + if (!cc.IsAtLeastHopper()) { + GTEST_SKIP() << "Flash Attention fp8 requires at least Hopper."; + } + XlaBuilder builder(TestName()); + + std::string ref_btnh = R"( + custom-call.4.0 = ( + bf16[4,16,4,16]{3,2,1,0}, + u8[0]{0} + ) custom-call( + convert.19, + convert.31, + convert.43 + ), + custom_call_target="__cudnn$fmhaSoftmax", + operand_layout_constraints={ + bf16[4,16,4,16]{3,2,1,0}, + bf16[4,16,4,16]{3,2,1,0}, + bf16[4,16,4,16]{3,2,1,0} + }, + api_version=API_VERSION_STATUS_RETURNING, + backend_config={ + "operation_queue_id": "0", + "wait_on_operation_queues": [], + "cudnn_fmha_backend_config": { + "algorithm": { + "algo_id": "0", + "math_type": "TENSOR_OP_MATH", + "tuning_knobs": { + "17": "1", + "24": "0" + }, + "is_cudnn_frontend": true, + "workspace_size": "0" + }, + "fmha_scale": 0.75, + "dropout_rate": 0.0, + "intermediate_tensor_shape": { + "element_type": "BF16", + "dimensions": ["4", "16", "4", "4"], + "tuple_shapes": [], + "layout": { + "dim_level_types": [], + "dim_unique": [], + "dim_ordered": [], + "minor_to_major": ["3", "2", "1", "0"], + "tiles": [], + "element_size_in_bits": "0", + "memory_space": "0", + "index_primitive_type": "PRIMITIVE_TYPE_INVALID", + "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", + "dynamic_shape_metadata_prefix_bytes": "0" + }, + "is_dynamic_dimension": [false, false, false, false] + }, + "seed": 42, + "is_flash_attention": true, + "mask_type": "NO_MASK", + "sliding_window_length": 0, + "bmm1_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["3"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + }, + "bmm2_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["2"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + } + } + } + ROOT get-tuple-element.5.0 = bf16[4,16,4,16]{3,2,1,0} get-tuple-element(custom-call.4.0), index=0 + } +)"; + + std::string fp8_btnh = R"( + custom-call.21.0 = ( + f8e4m3fn[4,16,4,16]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + u8[16]{0} + ) custom-call( + convert.18, + convert.30, + convert.42, + broadcast.99, + broadcast.99, + /*index=5*/broadcast.99, + broadcast.99, + broadcast.99, + broadcast.99 + ), + custom_call_target="__cudnn$fmhaSoftmaxF8", + operand_layout_constraints={ + f8e4m3fn[4,16,4,16]{3,2,1,0}, + f8e4m3fn[4,16,4,16]{3,2,1,0}, + f8e4m3fn[4,16,4,16]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0} + }, + api_version=API_VERSION_STATUS_RETURNING, + backend_config={ + "operation_queue_id": "0", + "wait_on_operation_queues": [], + "cudnn_fmha_backend_config": { + "algorithm": { + "algo_id": "0", + "math_type": "TENSOR_OP_MATH", + "tuning_knobs": { + "17": "1", + "24": "0" + }, + "is_cudnn_frontend": true, + "workspace_size": "0" + }, + "fmha_scale": 0.75, + "intermediate_tensor_shape": { + "element_type": "BF16", + "dimensions": ["4", "16", "4", "4"], + "tuple_shapes": [], + "layout": { + "dim_level_types": [], + "dim_unique": [], + "dim_ordered": [], + "minor_to_major": ["3", "2", "1", "0"], + "tiles": [], + "element_size_in_bits": "0", + "memory_space": "0", + "index_primitive_type": "PRIMITIVE_TYPE_INVALID", + "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", + "dynamic_shape_metadata_prefix_bytes": "0" + }, + "is_dynamic_dimension": [false, false, false, false] + }, + "is_flash_attention": true, + "mask_type": "NO_MASK", + "bmm1_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["3"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + }, + "bmm2_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["2"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + } + } + } + get-tuple-element.5.0 = f8e4m3fn[4,16,4,16]{3,2,1,0} get-tuple-element(custom-call.21.0), index=0 + ROOT out = bf16[4,16,4,16]{3,2,1,0} convert(get-tuple-element.5.0) + } + )"; + + std::string hlo_string = + std::string(GetModuleFlashAttentionBMMScaleSoftmaxBMMCommonF8()) + + fp8_btnh; + std::string hlo_string_ref = + std::string(GetModuleFlashAttentionBMMScaleSoftmaxBMMCommonRef()) + + ref_btnh; + EXPECT_TRUE(RunAndCompareTwoModules(hlo_string, hlo_string_ref, + ErrorSpec{5e-2, 5e-2})); } // BMM1 - Scale - Softmax - BMM2 fp8 @@ -1476,6 +1821,268 @@ XLA_TEST_F(FlashAttentionBMMScaleSoftmaxDropoutBMM, Flash_Attention_Training_BMM1_Softmax_Dropout_BMM2) { TestImpl_Flash_Attention_Training_BMM1_Softmax_Dropout_BMM2(); } + +XLA_TEST_F(FlashAttentionBMMScaleSoftmaxBMMF8, + Flash_Attention_Bwd_BMM1_NoMask_Softmax_BMM2_F8) { + if (skip_reason_) GTEST_SKIP() << *skip_reason_; + if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) < + se::dnn::VersionInfo(9, 1, 0)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.1.0."; + } + auto cc = GetCudaComputeCapability(); + if (!cc.IsAtLeastHopper()) { + GTEST_SKIP() << "Flash Attention fp8 requires at least Hopper."; + } + XlaBuilder builder(TestName()); + std::string hlo_string_ref = R"( + HloModule fmha_cudnn_custom_call_bwd + // Process inputs: clip, convert to f8e4m3fn, and convert back to bf16 + cast_to_representable { + // Parameters + input = bf16[1,1,128,128] parameter(0) + min_val = bf16[] parameter(1) + max_val = bf16[] parameter(2) + + // Broadcasting min and max values + min_broadcast = bf16[1,1,128,128] broadcast(min_val), dimensions={} + max_broadcast = bf16[1,1,128,128] broadcast(max_val), dimensions={} + + // Clipping the scaled input + clipped_min = bf16[1,1,128,128] maximum(min_broadcast, input) + clipped = bf16[1,1,128,128] minimum(max_broadcast, clipped_min) + + // Converting to f8e4m3fn and back to bf16 + converted_f8 = f8e4m3fn[1,1,128,128] convert(clipped) + ROOT converted_bf16 = bf16[1,1,128,128] convert(converted_f8) + } + // Main function + ENTRY main { + // Input parameters + query = bf16[1,1,128,128] parameter(0) + key = bf16[1,1,128,128] parameter(1) + value = bf16[1,1,128,128] parameter(2) + grad_output = bf16[1,1,128,128] parameter(3) + fwd_output = bf16[1,1,128,128] parameter(4) + score = f32[1,1,128] parameter(5) + + // Constants + one_f32 = f32[] constant(1) + one_f32_broadcast = f32[1,1,1,1] broadcast(one_f32), dimensions={} + min_clip_val = bf16[] constant(-448) + max_clip_val = bf16[] constant(448) + + query_processed = bf16[1,1,128,128] call(query, min_clip_val, max_clip_val), to_apply=cast_to_representable + key_processed = bf16[1,1,128,128] call(key, min_clip_val, max_clip_val), to_apply=cast_to_representable + value_processed = bf16[1,1,128,128] call(value, min_clip_val, max_clip_val), to_apply=cast_to_representable + grad_output_processed = bf16[1,1,128,128] call(grad_output, min_clip_val, max_clip_val), to_apply=cast_to_representable + fwd_output_processed = bf16[1,1,128,128] call(fwd_output, min_clip_val, max_clip_val), to_apply=cast_to_representable + + // FMHA Forward Backward custom call + fmha_result = (bf16[1,1,128,128], bf16[1,1,128,128], bf16[1,1,128,128], u8[0]) custom-call( + query_processed, key_processed, value_processed, + score, fwd_output_processed, grad_output_processed + ), + custom_call_target="__cudnn$fmhaSoftmaxBackward", + operand_layout_constraints={ + bf16[1,1,128,128]{3,2,1,0}, bf16[1,1,128,128]{3,2,1,0}, + bf16[1,1,128,128]{3,2,1,0}, f32[1,1,128]{2,1,0}, + bf16[1,1,128,128]{3,2,1,0}, bf16[1,1,128,128]{3,2,1,0} + }, + api_version=API_VERSION_STATUS_RETURNING, + backend_config={ + "operation_queue_id": "0", + "wait_on_operation_queues": [], + "cudnn_fmha_backend_config": { + "algorithm": { + "algo_id": "0", + "math_type": "TENSOR_OP_MATH", + "tuning_knobs": {"17": "1", "24": "0"}, + "is_cudnn_frontend": true, + "workspace_size": "0" + }, + "fmha_scale": 1.0, + "dropout_rate": 0.0, + "intermediate_tensor_shape": { + "element_type": "BF16", + "dimensions": ["1", "1", "128", "128"], + "tuple_shapes": [], + "layout": { + "dim_level_types": [], + "dim_unique": [], + "dim_ordered": [], + "minor_to_major": ["3", "2", "1", "0"], + "tiles": [], + "element_size_in_bits": "0", + "memory_space": "0", + "index_primitive_type": "PRIMITIVE_TYPE_INVALID", + "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", + "dynamic_shape_metadata_prefix_bytes": "0" + }, + "is_dynamic_dimension": [false, false, false, false] + }, + "seed": 42, + "is_flash_attention": true, + "mask_type": "NO_MASK", + "sliding_window_length": 0, + "bmm1_grad_gemm1_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["2"], + "rhs_contracting_dimensions": ["2"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + }, + "bmm1_grad_gemm2_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["2"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + }, + "bmm2_grad_gemm1_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["2"], + "rhs_contracting_dimensions": ["2"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + }, + "bmm2_grad_gemm2_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["3"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + } + } + } + + ROOT output = bf16[1,1,128,128] get-tuple-element(fmha_result), index=0 + })"; + + std::string hlo_string = R"( + HloModule fmha_cudnn_custom_call_bwd_f8 + // Process inputs: clip, convert to f8e4m3fn + cast_to_representable { + // Parameters + input = bf16[1,1,128,128] parameter(0) + min_val = bf16[] parameter(1) + max_val = bf16[] parameter(2) + + // Broadcasting min and max values + min_broadcast = bf16[1,1,128,128] broadcast(min_val), dimensions={} + max_broadcast = bf16[1,1,128,128] broadcast(max_val), dimensions={} + + // Clipping the scaled input + clipped_min = bf16[1,1,128,128] maximum(min_broadcast, input) + clipped = bf16[1,1,128,128] minimum(max_broadcast, clipped_min) + + // Converting to f8e4m3fn and back to bf16 + ROOT converted_f8 = f8e4m3fn[1,1,128,128] convert(clipped) + } + + // Main function + ENTRY main { + // Input parameters + query = bf16[1,1,128,128] parameter(0) + key = bf16[1,1,128,128] parameter(1) + value = bf16[1,1,128,128] parameter(2) + grad_output = bf16[1,1,128,128] parameter(3) + fwd_output = bf16[1,1,128,128] parameter(4) + score = f32[1,1,128] parameter(5) + + // Constants + one_f32 = f32[] constant(1) + one_f32_broadcast = f32[1,1,1,1] broadcast(one_f32), dimensions={} + min_clip_val = bf16[] constant(-448) + max_clip_val = bf16[] constant(448) + + query_processed = f8e4m3fn[1,1,128,128] call(query, min_clip_val, max_clip_val), to_apply=cast_to_representable + key_processed = f8e4m3fn[1,1,128,128] call(key, min_clip_val, max_clip_val), to_apply=cast_to_representable + value_processed = f8e4m3fn[1,1,128,128] call(value, min_clip_val, max_clip_val), to_apply=cast_to_representable + grad_output_processed = f8e4m3fn[1,1,128,128] call(grad_output, min_clip_val, max_clip_val), to_apply=cast_to_representable + fwd_output_processed = f8e4m3fn[1,1,128,128] call(fwd_output, min_clip_val, max_clip_val), to_apply=cast_to_representable + + // FMHA Softmax Backward custom call + fmha_result = (f8e4m3fn[1,1,128,128], f8e4m3fn[1,1,128,128], f8e4m3fn[1,1,128,128], + f32[1,1,1,1], f32[1,1,1,1], f32[1,1,1,1], f32[1,1,1,1], u8[0]) custom-call( + query_processed, key_processed, value_processed, + grad_output_processed, fwd_output_processed, score, + one_f32_broadcast, one_f32_broadcast, one_f32_broadcast, one_f32_broadcast, + one_f32_broadcast, one_f32_broadcast, one_f32_broadcast, one_f32_broadcast, + one_f32_broadcast, one_f32_broadcast, one_f32_broadcast, one_f32_broadcast + ), + custom_call_target="__cudnn$fmhaSoftmaxBackwardF8", + operand_layout_constraints={ + f8e4m3fn[1,1,128,128]{3,2,1,0}, f8e4m3fn[1,1,128,128]{3,2,1,0}, + f8e4m3fn[1,1,128,128]{3,2,1,0}, f8e4m3fn[1,1,128,128]{3,2,1,0}, + f8e4m3fn[1,1,128,128]{3,2,1,0}, f32[1,1,128]{2,1,0}, + f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, + f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0}, f32[1,1,1,1]{3,2,1,0} + }, + api_version=API_VERSION_STATUS_RETURNING, + backend_config={ + "operation_queue_id": "0", + "wait_on_operation_queues": [], + "cudnn_fmha_backend_config": { + "algorithm": { + "algo_id": "0", + "math_type": "TENSOR_OP_MATH", + "tuning_knobs": {"17": "1", "24": "0"}, + "is_cudnn_frontend": true, + "workspace_size": "0" + }, + "fmha_scale": 1.0, + "intermediate_tensor_shape": { + "element_type": "BF16", + "dimensions": ["1", "1", "128", "128"], + "tuple_shapes": [], + "layout": { + "dim_level_types": [], + "dim_unique": [], + "dim_ordered": [], + "minor_to_major": ["3", "2", "1", "0"], + "tiles": [], + "element_size_in_bits": "0", + "memory_space": "0", + "index_primitive_type": "PRIMITIVE_TYPE_INVALID", + "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID", + "dynamic_shape_metadata_prefix_bytes": "0" + }, + "is_dynamic_dimension": [false, false, false, false] + }, + "is_flash_attention": true, + "mask_type": "NO_MASK", + "bmm1_grad_gemm1_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["2"], + "rhs_contracting_dimensions": ["2"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + }, + "bmm1_grad_gemm2_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["2"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + }, + "bmm2_grad_gemm1_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["2"], + "rhs_contracting_dimensions": ["2"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + }, + "bmm2_grad_gemm2_dot_dimension_numbers": { + "lhs_contracting_dimensions": ["3"], + "rhs_contracting_dimensions": ["3"], + "lhs_batch_dimensions": ["0", "1"], + "rhs_batch_dimensions": ["0", "1"] + } + } + } + + fmha_output = f8e4m3fn[1,1,128,128] get-tuple-element(fmha_result), index=0 + ROOT output = bf16[1,1,128,128] convert(fmha_output) + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_string_ref, hlo_string, + ErrorSpec{2e-1, 2e-1})); +} } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/tests/gpu_fusion_pipeline_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_fusion_pipeline_test.cc index 29876e9cf1b3b7..00a331bd8a1c28 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_fusion_pipeline_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_fusion_pipeline_test.cc @@ -34,13 +34,6 @@ namespace gpu { namespace { class GpuFusionPipelineTest : public GpuCodegenTest { - HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const { - return [&](const Shape& shape) { - constexpr int64_t kPointerSize = 8; - return ShapeUtil::ByteSizeOf(shape, kPointerSize); - }; - } - public: void CheckGpuFusionPipeline(absl::string_view hlo, std::optional expected) { @@ -50,8 +43,10 @@ class GpuFusionPipelineTest : public GpuCodegenTest { pipeline.AddPass(/*may_duplicate=*/false, device_info); pipeline.AddPass(/*may_duplicate=*/true, device_info); - pipeline.AddPass(device_info, ShapeSizeBytesFunction()); - pipeline.AddPass(device_info, ShapeSizeBytesFunction()); + pipeline.AddPass(device_info, + HloCostAnalysis::DefaultShapeSize); + pipeline.AddPass(device_info, + HloCostAnalysis::DefaultShapeSize); RunAndFilecheckHloRewrite(hlo, std::move(pipeline), expected); } diff --git a/third_party/xla/xla/service/gpu/tests/gpu_int4_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_int4_test.cc index e0f399dc017e61..aeebc7902785b5 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_int4_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_int4_test.cc @@ -116,7 +116,7 @@ TEST_F(GpuInt4Test, TestOddElements) { ; CHECK-NEXT: br i1 %[[in_bounds]], label %[[in_bounds_true:.*]], label %[[in_bounds_after:.*]] ; CHECK: [[in_bounds_true]]: ; CHECK: %{{.*}} = load i8, ptr %{{.*}}, align 1 - ; CHECK: store i8 %{{.*}}, ptr %{{.*}}, align 1 + ; CHECK: cmpxchg ptr %{{.*}} ; CHECK: br label %[[in_bounds_after]] ; CHECK: [[in_bounds_after]]: ; CHECK-NEXT: ret void)"; diff --git a/third_party/xla/xla/service/gpu/tests/gpu_sparse_dot_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_sparse_dot_test.cc index 6133a5b38f4bc5..6f6dde7a185df7 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_sparse_dot_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_sparse_dot_test.cc @@ -43,7 +43,7 @@ class SparseDotTest : public GpuCodegenTest, public ::testing::WithParamInterface> { protected: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_enable_triton_gemm(true); debug_options.set_xla_gpu_autotune_level(0); diff --git a/third_party/xla/xla/service/gpu/tests/gpu_spmd_e2e_compile_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_spmd_e2e_compile_test.cc index cc4c36507fec94..e8b7208b049dd8 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_spmd_e2e_compile_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_spmd_e2e_compile_test.cc @@ -34,7 +34,7 @@ namespace { class GpuSpmdE2ECompileTest : public GpuCodegenTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_autotune_level(0); return debug_options; diff --git a/third_party/xla/xla/service/gpu/tests/in_place_op_test.cc b/third_party/xla/xla/service/gpu/tests/in_place_op_test.cc index 853aaf18430943..17af5a49919eaa 100644 --- a/third_party/xla/xla/service/gpu/tests/in_place_op_test.cc +++ b/third_party/xla/xla/service/gpu/tests/in_place_op_test.cc @@ -24,7 +24,7 @@ namespace { class InPlaceOpTest : public HloTestBase { // Don't override any flags. - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { return GetDebugOptionsFromFlags(); } }; diff --git a/third_party/xla/xla/service/gpu/tests/infeed_test.cc b/third_party/xla/xla/service/gpu/tests/infeed_test.cc index 933313f39710b5..3d20d884691b9d 100644 --- a/third_party/xla/xla/service/gpu/tests/infeed_test.cc +++ b/third_party/xla/xla/service/gpu/tests/infeed_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "xla/array3d.h" #include "xla/array4d.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/literal.h" diff --git a/third_party/xla/xla/service/gpu/tests/mock_custom_call_test.cc b/third_party/xla/xla/service/gpu/tests/mock_custom_call_test.cc index 90380c4a5db72a..5bcb680ac162ac 100644 --- a/third_party/xla/xla/service/gpu/tests/mock_custom_call_test.cc +++ b/third_party/xla/xla/service/gpu/tests/mock_custom_call_test.cc @@ -35,7 +35,7 @@ TEST_F(UnknownCustomCallFails, UnknownCustomCallFails) { } class MockedCustomCall : public GpuCodegenTest { - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions opts; opts.set_xla_gpu_mock_custom_calls(true); return opts; diff --git a/third_party/xla/xla/service/gpu/tests/parallel_reduction_test.cc b/third_party/xla/xla/service/gpu/tests/parallel_reduction_test.cc index 2e0a975bbfca9d..f2dfd27cae0b3f 100644 --- a/third_party/xla/xla/service/gpu/tests/parallel_reduction_test.cc +++ b/third_party/xla/xla/service/gpu/tests/parallel_reduction_test.cc @@ -38,7 +38,7 @@ namespace { class ParallelReductionTest : public GpuCodegenTest { protected: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); // The test contains a MOF fusion and the XLA optimizer passes // don't like this. diff --git a/third_party/xla/xla/service/gpu/tests/reduction_vectorization_test.cc b/third_party/xla/xla/service/gpu/tests/reduction_vectorization_test.cc index 680391c2fa7db6..fd2a1c4fb41e02 100644 --- a/third_party/xla/xla/service/gpu/tests/reduction_vectorization_test.cc +++ b/third_party/xla/xla/service/gpu/tests/reduction_vectorization_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "absl/strings/str_replace.h" #include "xla/error_spec.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "xla/service/hlo_parser.h" #include "xla/stream_executor/device_description.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -33,7 +33,7 @@ namespace { class ReductionVectorizationTest : public GpuCodegenTest {}; class ReductionVectorizationNoOptTest : public GpuCodegenTest { - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); // The test MultiOutputStore contain a MOF fusion and XLA optimizer pass // doesn't like this. diff --git a/third_party/xla/xla/service/gpu/tests/simplify_fp_conversions_test.cc b/third_party/xla/xla/service/gpu/tests/simplify_fp_conversions_test.cc index b68df87498c8e1..70e7a65e697278 100644 --- a/third_party/xla/xla/service/gpu/tests/simplify_fp_conversions_test.cc +++ b/third_party/xla/xla/service/gpu/tests/simplify_fp_conversions_test.cc @@ -24,7 +24,7 @@ namespace { class SimplifyFPConversionsTest : public HloTestBase { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_allow_excess_precision( enable_simplify_all_fp_conversions_); diff --git a/third_party/xla/xla/service/gpu/tests/sorting_test.cc b/third_party/xla/xla/service/gpu/tests/sorting_test.cc index abcb2066e95234..ba1468f04845b5 100644 --- a/third_party/xla/xla/service/gpu/tests/sorting_test.cc +++ b/third_party/xla/xla/service/gpu/tests/sorting_test.cc @@ -56,10 +56,12 @@ compare { ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } + ENTRY TestComputation { x = f32[3, 2]{1, 0} parameter(0) - x.copy = f32[3, 2]{0, 1} copy(x) - ROOT sort = f32[3, 2]{0, 1} sort(x.copy), dimensions={1}, to_apply=compare + tr = f32[2, 3]{1, 0} transpose(x), dimensions={1,0} + b = f32[3, 2]{0, 1} bitcast(tr) + ROOT sort = f32[3, 2]{0, 1} sort(b), dimensions={1}, to_apply=compare } )"; diff --git a/third_party/xla/xla/service/gpu/tests/sparse_xla_triton_op.mlir b/third_party/xla/xla/service/gpu/tests/sparse_xla_triton_op.mlir new file mode 100644 index 00000000000000..f856f07f0b4616 --- /dev/null +++ b/third_party/xla/xla/service/gpu/tests/sparse_xla_triton_op.mlir @@ -0,0 +1,23 @@ +// RUN: xla-opt %s | FileCheck %s + +#mma = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [2, 2], + CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], + instrShape = [16, 8]}> +#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=2}> +#dot_meta_enc = #triton_gpu.sparse_dot_meta<{parent=#mma}> + +module attributes {"triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: sparse_xla_triton_op + tt.func @sparse_xla_triton_op(%A_dot: tensor<32x32xf16, #dot_operand_a>, + %B_dot: tensor<64x32xf16, #dot_operand_b>, + %meta_reg: tensor<32x4xi16, #dot_meta_enc>) { + %acc = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma> + // CHECK-LABEL: triton_xla.sparse_dot + %D = triton_xla.sparse_dot %A_dot, %B_dot, %acc, %meta_reg : + tensor<32x32xf16, #dot_operand_a> meta tensor<32x4xi16, + #dot_meta_enc> * tensor<64x32xf16, #dot_operand_b> + -> tensor<32x32xf32, #mma> + tt.return + } +} diff --git a/third_party/xla/xla/service/gpu/tests/tensor_float_32_global_var_test.cc b/third_party/xla/xla/service/gpu/tests/tensor_float_32_global_var_test.cc index a8d338b0663872..472618bebfc771 100644 --- a/third_party/xla/xla/service/gpu/tests/tensor_float_32_global_var_test.cc +++ b/third_party/xla/xla/service/gpu/tests/tensor_float_32_global_var_test.cc @@ -45,7 +45,7 @@ class TensorFloat32GlobalVarTest : public ::testing::WithParamInterface, tsl::enable_tensor_float_32_execution(true); } - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); const bool enable_triton_gemm = GetParam(); if (enable_triton_gemm) { diff --git a/third_party/xla/xla/service/gpu/tests/xla-opt.cc b/third_party/xla/xla/service/gpu/tests/xla-opt.cc index f27b6f82366230..88951049042455 100644 --- a/third_party/xla/xla/service/gpu/tests/xla-opt.cc +++ b/third_party/xla/xla/service/gpu/tests/xla-opt.cc @@ -17,12 +17,14 @@ limitations under the License. #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "xla/service/gpu/fusions/transforms/passes.h" #include "xla/service/gpu/fusions/triton/passes.h" +#include "xla/service/gpu/fusions/triton/xla_triton_ops.h" #include "third_party/triton/bin/RegisterTritonDialects.h" int main(int argc, char **argv) { mlir::DialectRegistry registry; mlir::registerAllExtensions(registry); registerTritonDialects(registry); // This registers all passes as well. + registry.insert(); xla::gpu::registerTritonFusionTransformsPasses(); xla::gpu::registerGpuFusionTransformsPasses(); diff --git a/third_party/xla/xla/service/gpu/tests/zero_clamp_abs_index.hlo b/third_party/xla/xla/service/gpu/tests/zero_clamp_abs_index.hlo new file mode 100644 index 00000000000000..59f448644172d4 --- /dev/null +++ b/third_party/xla/xla/service/gpu/tests/zero_clamp_abs_index.hlo @@ -0,0 +1,13 @@ +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck %s + +e { + p0 = s32[8,9] parameter(0) + p1 = s32[5] parameter(1) + a = s32[5] abs(p1) + ROOT r = s32[5,2,3] gather(p0, a), + offset_dims={1,2}, collapsed_slice_dims={}, start_index_map={0}, + index_vector_dim=1, slice_sizes={2,3} +} + +// CHECK: llvm.smin.i32 +// CHECK-NOT: llvm.smax.i32 diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index 96cdb279211540..3c2d885e0b13b1 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -16,6 +16,23 @@ package( licenses = ["notice"], ) +cc_library( + name = "dot_algorithm_rewriter", + srcs = ["dot_algorithm_rewriter.cc"], + hdrs = ["dot_algorithm_rewriter.h"], + deps = [ + "//xla:literal_util", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:status", + ], +) + cc_library( name = "algebraic_simplifier", srcs = [ @@ -25,18 +42,23 @@ cc_library( "algebraic_simplifier.h", ], deps = [ + "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/service:algebraic_simplifier", + "//xla/hlo/transforms:algebraic_simplifier", + "//xla/service:pattern_matcher", "//xla/service/gpu:matmul_utils", "//xla/service/gpu/fusions/triton:triton_support", "//xla/stream_executor:device_description", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) @@ -46,7 +68,9 @@ xla_cc_test( deps = [ ":algebraic_simplifier", "//xla/hlo/ir:hlo", - "//xla/service:algebraic_simplifier", + "//xla/hlo/transforms:algebraic_simplifier", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -223,6 +247,7 @@ xla_cc_test( ":reduce_scatter_creator", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/testlib:filecheck", "//xla/service:hlo_module_config", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", @@ -300,6 +325,7 @@ xla_cc_test( "//xla:literal_util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/testlib:verified_hlo_module", "//xla/service:hlo_proto_cc", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", @@ -316,9 +342,11 @@ cc_library( hdrs = ["all_gather_dynamic_slice_simplifier.h"], deps = [ "//xla/hlo/ir:hlo", + "//xla/hlo/transforms:op_expander_pass", "//xla/service:collective_opt_utils", - "//xla/service:hlo_creation_utils", - "//xla/service:op_expander_pass", + "//xla/service:hlo_module_config", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", ], ) @@ -330,11 +358,71 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", + "//xla/service:hlo_module_config", "//xla/tests:hlo_test_base", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_main", ], ) +cc_library( + name = "fusion_block_level_rewriter", + srcs = ["fusion_block_level_rewriter.cc"], + hdrs = ["fusion_block_level_rewriter.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/service:hlo_cost_analysis", + "//xla/service:instruction_fusion", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu/fusions/triton:triton_support", + "//xla/service/gpu/model:fusion_analysis_cache", + "//xla/service/gpu/model:gpu_indexing_performance_model", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "fusion_block_level_rewriter_test", + srcs = ["fusion_block_level_rewriter_test.cc"], + deps = [ + ":fusion_block_level_rewriter", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_cost_analysis", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu/fusions/triton:triton_support", + "//xla/service/gpu/model:gpu_hlo_cost_analysis", + "//xla/service/gpu/model:symbolic_tile_analysis", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + ], +) + cc_library( name = "collective_permute_cycle_decomposer", srcs = ["collective_permute_cycle_decomposer.cc"], @@ -346,10 +434,10 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass", "//xla/hlo/utils:hlo_query", "//xla/service:collective_ops_utils", - "//xla/service:hlo_parser", "//xla/service/gpu:backend_configs_cc", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", @@ -361,16 +449,49 @@ cc_library( ], ) +cc_library( + name = "collective_send_recv_combiner", + srcs = ["collective_send_recv_combiner.cc"], + hdrs = ["collective_send_recv_combiner.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "collective_send_recv_combiner_test", + srcs = ["collective_send_recv_combiner_test.cc"], + deps = [ + ":collective_send_recv_combiner", + "//xla/hlo/ir:hlo", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + xla_cc_test( name = "collective_permute_cycle_decomposer_test", srcs = ["collective_permute_cycle_decomposer_test.cc"], deps = [ ":collective_permute_cycle_decomposer", "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", + "//xla/hlo/parser:hlo_parser", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:test_utils", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", @@ -386,6 +507,7 @@ cc_library( "//xla:comparison_util", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/service:collective_ops_utils", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", @@ -419,11 +541,11 @@ cc_library( hdrs = ["collective_permute_valid_iteration_annotator.h"], deps = [ "//xla:literal_util", + "//xla/hlo/analysis:while_loop_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:collective_ops_utils", "//xla/service:pattern_matcher", - "//xla/service:while_loop_analysis", ], ) @@ -434,8 +556,8 @@ xla_cc_test( ":collective_permute_valid_iteration_annotator", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/transforms:while_loop_trip_count_annotator", "//xla/service:collective_ops_utils", - "//xla/service:while_loop_trip_count_annotator", "//xla/tests:hlo_test_base", "@local_tsl//tsl/platform:test_main", ], @@ -483,7 +605,7 @@ xla_test( deps = [ ":command_buffer_scheduling", "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", + "//xla/hlo/parser:hlo_parser", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:gpu_executable", "//xla/stream_executor:device_description", @@ -601,7 +723,7 @@ cc_library( hdrs = ["convert_async_collectives_to_sync.h"], deps = [ "//xla/hlo/ir:hlo", - "//xla/service:convert_async_collectives_to_sync", + "//xla/hlo/transforms:convert_async_collectives_to_sync", "//xla/service/gpu:backend_configs_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", @@ -755,9 +877,10 @@ cc_library( "//xla/service:pattern_matcher", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", - "//xla/stream_executor", + "//xla/stream_executor:device_description", "//xla/stream_executor:dnn", "//xla/stream_executor:semantic_version", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -796,13 +919,13 @@ xla_test( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/service:algebraic_simplifier", - "//xla/service:convert_mover", - "//xla/service:hlo_constant_folding", + "//xla/hlo/transforms:algebraic_simplifier", + "//xla/hlo/transforms:convert_mover", + "//xla/hlo/transforms:hlo_constant_folding", + "//xla/hlo/transforms:reshape_mover", "//xla/service:hlo_module_config", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", - "//xla/service:reshape_mover", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", "//xla/service/gpu:stream_executor_util", @@ -842,9 +965,11 @@ cc_library( "//xla/service:pattern_matcher", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", + "//xla/service/gpu:matmul_indexing_utils", "//xla/service/gpu:matmul_utils", "//xla/service/gpu:stream_executor_util", - "//xla/stream_executor", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory", "//xla/stream_executor:dnn", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -869,7 +994,7 @@ xla_test( "gpu", ], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":cudnn_fused_mha_rewriter", ":cudnn_fused_mha_transpose_fusion", @@ -878,17 +1003,17 @@ xla_test( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:algebraic_simplifier", + "//xla/hlo/parser:hlo_parser", + "//xla/hlo/transforms:algebraic_simplifier", + "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms:reshape_decomposer", "//xla/service:computation_layout", "//xla/service:hlo_cse", - "//xla/service:hlo_dce", "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", "//xla/service:hlo_verifier", "//xla/service:layout_normalization", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", - "//xla/service:reshape_decomposer", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", "//xla/stream_executor:device_description", @@ -920,7 +1045,7 @@ cc_library( "//xla/service:pattern_matcher", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", - "//xla/service/gpu:matmul_utils", + "//xla/service/gpu:matmul_indexing_utils", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -967,7 +1092,7 @@ cc_library( "//xla/stream_executor/cuda:cudnn_plugin", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", - ]), + ]) + ["//xla/service/gpu:matmul_indexing_utils"], ) cc_library( @@ -986,7 +1111,8 @@ cc_library( "//xla/service:pattern_matcher", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", - "//xla/stream_executor", + "//xla/stream_executor:device_description", + "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -997,7 +1123,6 @@ cc_library( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:dnn_proto_cc", ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cudnn_header", @@ -1039,7 +1164,7 @@ cc_library( "//xla/service/gpu:cublas_cudnn", "//xla/service/gpu:cudnn_support_utils", "//xla/service/gpu:stream_executor_util", - "//xla/stream_executor", + "//xla/stream_executor:device_description", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/status", @@ -1057,7 +1182,7 @@ xla_cc_test( srcs = ["cudnn_pad_for_convolutions_test.cc"], deps = [ ":cudnn_pad_for_convolutions", - "//xla/service:hlo_parser", + "//xla/hlo/parser:hlo_parser", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/service/gpu:cublas_cudnn", @@ -1103,12 +1228,12 @@ xla_cc_test( "//xla:util", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/service:algebraic_simplifier", + "//xla/hlo/transforms:algebraic_simplifier", + "//xla/hlo/transforms:reshape_mover", + "//xla/hlo/transforms:tuple_simplifier", "//xla/service:call_inliner", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", - "//xla/service:reshape_mover", - "//xla/service:tuple_simplifier", "//xla/stream_executor:device_description", "//xla/stream_executor:dnn", "//xla/tests:hlo_test_base", @@ -1131,16 +1256,17 @@ cc_library( deps = [ "//xla:shape_util", "//xla:util", - "//xla/client:xla_builder", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/service:hlo_creation_utils", "//xla/service:hlo_module_config", "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", "//xla/service/gpu:cudnn_support_utils", "//xla/service/gpu:stream_executor_util", - "//xla/stream_executor", + "//xla/stream_executor:device_description", "//xla/stream_executor:dnn", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -1161,8 +1287,8 @@ xla_cc_test( deps = [ ":cudnn_vectorize_convolutions", "//xla:util", + "//xla/hlo/parser:hlo_parser", "//xla/service:call_inliner", - "//xla/service:hlo_parser", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/service/gpu:backend_configs_cc", @@ -1290,6 +1416,37 @@ xla_test( ], ) +cc_library( + name = "dot_normalizer", + srcs = ["dot_normalizer.cc"], + hdrs = ["dot_normalizer.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/transforms:op_expander_pass", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "dot_normalizer_test", + srcs = ["dot_normalizer_test.cc"], + deps = [ + ":dot_normalizer", + "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tests:hlo_test_base", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + cc_library( name = "dot_operand_converter", srcs = ["dot_operand_converter.cc"], @@ -1298,7 +1455,7 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", - "//xla/service:op_expander_pass", + "//xla/hlo/transforms:op_expander_pass", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:errors", @@ -1377,11 +1534,11 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_instruction_utils", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/transforms:flatten_call_graph", "//xla/hlo/utils:hlo_query", "//xla/service:collective_ops_utils", - "//xla/service:flatten_call_graph", - "//xla/service:hlo_parser", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -1405,13 +1562,16 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/testlib:filecheck", + "//xla/hlo/transforms:tuple_simplifier", "//xla/hlo/utils:hlo_query", - "//xla/service:tuple_simplifier", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", + "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", ], @@ -1454,29 +1614,28 @@ xla_cc_test( name = "dynamic_slice_fusion_rewriter_test", srcs = ["dynamic_slice_fusion_rewriter_test.cc"], tags = [ + "cuda-only", "gpu", - "no_rocm", ], deps = [ ":dynamic_slice_fusion_rewriter", "//xla:shape_util", - "//xla/client:xla_builder", - "//xla/client/lib:constants", "//xla/ffi", "//xla/ffi:ffi_api", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder/lib:constants", "//xla/hlo/ir:hlo", + "//xla/hlo/transforms:hlo_memory_scheduler", "//xla/service:buffer_value", "//xla/service:custom_call_target_registry", "//xla/service:executable", - "//xla/service:hlo_memory_scheduler", "//xla/service:hlo_module_config", "//xla/service/gpu:gpu_device_info_for_tests", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:stream", "//xla/stream_executor/gpu:gpu_types_header", "//xla/tests:hlo_test_base", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -1640,6 +1799,8 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:filecheck", + "//xla/hlo/testlib:verified_hlo_module", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/service/gpu:cublas_padding_requirements", @@ -1664,6 +1825,7 @@ cc_library( deps = [ "//xla:literal", "//xla:literal_util", + "//xla:permutation_util", "//xla:shape_util", "//xla:status_macros", "//xla:types", @@ -1683,6 +1845,7 @@ cc_library( "//xla/stream_executor:device_description", "//xla/stream_executor:semantic_version", "//xla/stream_executor/gpu:gpu_blas_lt", + "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -1694,7 +1857,6 @@ cc_library( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:dnn_proto_cc", ], ) @@ -1702,6 +1864,7 @@ xla_test( name = "gemm_rewriter_test", srcs = ["gemm_rewriter_test.cc"], backends = ["gpu"], + shard_count = 5, deps = [ ":gemm_rewriter", "//xla:error_spec", @@ -1787,7 +1950,6 @@ cc_library( "//xla/hlo/pass:hlo_pass", "//xla/service/gpu:cusolver_context", "//xla/service/gpu:ir_emission_utils", - "//xla/stream_executor", "//xla/stream_executor:blas", "//xla/stream_executor:device_memory_allocator", "@local_tsl//tsl/platform:errors", @@ -1848,8 +2010,8 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/transforms:sub_byte_normalization", "//xla/service:hlo_creation_utils", - "//xla/service:sub_byte_normalization", "//xla/service/gpu:gpu_fusible", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -1877,10 +2039,10 @@ xla_test( "//xla:shape_util", "//xla:test", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", - "//xla/service:hlo_dce", - "//xla/service:hlo_parser", + "//xla/hlo/transforms:hlo_dce", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/service/gpu:gpu_device_info_for_tests", @@ -1960,10 +2122,11 @@ cc_library( "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:cublas_cudnn", "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:matmul_indexing_utils", "//xla/service/gpu:matmul_utils", "//xla/service/gpu:reduction_utils", "//xla/service/gpu:stream_executor_util", - "//xla/stream_executor", + "//xla/stream_executor:device_description", "//xla/stream_executor:dnn", "//xla/tsl/util:env_var", "@com_google_absl//absl/log", @@ -1985,8 +2148,8 @@ xla_cc_test( "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/service:computation_layout", - "//xla/service:hlo_parser", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/service/gpu:stream_executor_util", @@ -2041,8 +2204,8 @@ cc_library( deps = [ "//xla:debug_options_flags", "//xla:shape_util", + "//xla/hlo/analysis:hlo_dfs_reachability", "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_dfs_reachability", "//xla/hlo/pass:hlo_pass", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_graph_dumper", @@ -2087,6 +2250,68 @@ xla_cc_test( ], ) +cc_library( + name = "nest_gemm_fusion", + srcs = ["nest_gemm_fusion.cc"], + hdrs = ["nest_gemm_fusion.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/utils:hlo_query", + "//xla/service:call_graph", + "//xla/service:instruction_fusion", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:matmul_indexing_utils", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu/model:symbolic_tile_analysis", + "//xla/service/gpu/model:symbolic_tiled_hlo_instruction", + "//xla/service/gpu/model:tiled_hlo_instruction_or_computation", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "nest_gemm_fusion_test", + srcs = ["nest_gemm_fusion_test.cc"], + tags = [ + "nomsan", + ], + deps = [ + ":nest_gemm_fusion", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_cost_analysis", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:gpu_fusible", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/platform:statusor", + ], +) + cc_library( name = "pipelined_p2p_rewriter", srcs = ["pipelined_p2p_rewriter.cc"], @@ -2117,6 +2342,7 @@ xla_cc_test( deps = [ ":pipelined_p2p_rewriter", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:filecheck", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "@com_google_absl//absl/strings:string_view", @@ -2192,7 +2418,9 @@ xla_cc_test( "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/fusions/triton:triton_support", "//xla/service/gpu/model:gpu_hlo_cost_analysis", + "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", @@ -2379,7 +2607,7 @@ xla_cc_test( "//xla:shape_util", "//xla:test", "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", + "//xla/hlo/parser:hlo_parser", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/tests:hlo_test_base", @@ -2489,6 +2717,7 @@ xla_cc_test( deps = [ ":scatter_slice_simplifier", "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/tests:hlo_test_base", @@ -2523,7 +2752,7 @@ xla_cc_test( ":schedule_postprocessing", "//xla:util", "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", + "//xla/hlo/parser:hlo_parser", "//xla/service/gpu:backend_configs_cc", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -2554,6 +2783,7 @@ xla_cc_test( deps = [ ":scheduling_instruction_annotator", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:filecheck", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "@com_google_absl//absl/strings:string_view", @@ -2615,8 +2845,10 @@ xla_cc_test( deps = [ ":softmax_rewriter_triton", "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", + "//xla/service:hlo_cost_analysis", "//xla/service:instruction_fusion", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", @@ -2642,6 +2874,7 @@ cc_library( ["sort_rewriter_stub.cc"], ), hdrs = ["sort_rewriter.h"], + visibility = ["//xla/service/gpu:__subpackages__"] + if_google(["//learning/brain/engprod/xwatch:__subpackages__"]), deps = [ "//xla:comparison_util", "//xla:shape_util", @@ -2649,7 +2882,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/service:stable_sort_expander", + "//xla/hlo/transforms:stable_sort_expander", "//xla/service/gpu:cublas_cudnn", "//xla/service/gpu/runtime:cub_sort_thunk", "@com_google_absl//absl/container:flat_hash_set", @@ -2699,6 +2932,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", @@ -2712,6 +2946,7 @@ xla_cc_test( deps = [ ":stream_attribute_annotator", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:filecheck", "//xla/service/gpu:backend_configs_cc", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", @@ -2736,6 +2971,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", ], @@ -2793,6 +3029,7 @@ xla_test( "//xla/tests:xla_internal_test_main", # fixdeps: keep "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:statusor", @@ -2827,7 +3064,7 @@ xla_cc_test( deps = [ ":topk_splitter", "//xla/hlo/ir:hlo", - "//xla/service:hlo_dce", + "//xla/hlo/transforms:hlo_dce", "//xla/service:pattern_matcher", "//xla/service:topk_rewriter", "//xla/tests:hlo_test_base", @@ -2951,6 +3188,7 @@ cc_library( "//xla:shape_util", "//xla:status_macros", "//xla:util", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:executable", @@ -2963,6 +3201,7 @@ cc_library( "//xla/service/gpu/autotuning:autotuner_util", "//xla/stream_executor:stream", "//xla/tools:hlo_decomposer_lib", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", @@ -2985,7 +3224,10 @@ xla_test( ":triton_fusion_numerics_verifier", "//xla:shape_util", "//xla:test_helpers", + "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/service:backend", "//xla/service:platform_util", "//xla/service/gpu/autotuning:autotuner_compile_util", "//xla/service/gpu/autotuning:autotuner_util", @@ -3031,7 +3273,7 @@ xla_cc_test( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", + "//xla/hlo/parser:hlo_parser", "//xla/service:pattern_matcher", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -3046,13 +3288,14 @@ cc_library( deps = [ "//xla:literal", "//xla:literal_util", + "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/transforms:algebraic_simplifier", + "//xla/hlo/transforms:hlo_constant_folding", "//xla/hlo/utils:hlo_query", - "//xla/service:algebraic_simplifier", - "//xla/service:hlo_constant_folding", "//xla/service:hlo_creation_utils", "//xla/service:pattern_matcher", "//xla/service:shape_inference", @@ -3063,6 +3306,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", @@ -3076,6 +3320,7 @@ xla_cc_test( deps = [ ":windowed_einsum_handler", "//xla/hlo/ir:hlo", + "//xla/hlo/testlib:filecheck", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/service/gpu:backend_configs_cc", diff --git a/third_party/xla/xla/service/gpu/transforms/README.md b/third_party/xla/xla/service/gpu/transforms/README.md new file mode 100644 index 00000000000000..c176c23c90f435 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/README.md @@ -0,0 +1 @@ +This folder consolidates GPU specific HLO transformation passes. \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc index d59ae2b6a1d039..48250f8c3576f2 100644 --- a/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc +++ b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc @@ -16,15 +16,69 @@ limitations under the License. #include "xla/service/gpu/transforms/algebraic_simplifier.h" #include "absl/log/check.h" +#include "absl/status/status.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" #include "xla/service/gpu/fusions/triton/triton_support_legacy.h" #include "xla/service/gpu/matmul_utils.h" +#include "xla/service/pattern_matcher.h" +#include "xla/shape_util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla::gpu { +namespace m = ::xla::match; + +absl::StatusOr +GpuAlgebraicSimplifierVisitor::TryToSinkBroadcastOperandsOfChainedAdds( + HloInstruction* add) { + if (!options_.enable_sink_broadcast()) { + return false; + } + + HloInstruction *conv, *constant_0, *broadcast_0, *add_0, *constant_1, + *broadcast_1; + if (!Match(add, m::AddAnyOrder( + m::AddAnyOrder( + &add_0, m::Convolution(&conv, m::Op(), m::Op()), + m::Broadcast(&broadcast_0, m::Constant(&constant_0))), + m::Broadcast(&broadcast_1, m::Constant(&constant_1))))) { + return false; + } + + // Skip when the broadcast shapes and dimensions don't match. + if (!ShapeUtil::Equal(constant_0->shape(), constant_1->shape()) || + broadcast_0->dimensions() != broadcast_1->dimensions()) { + return false; + } + + HloInstruction* new_constant_add = + add->AddInstruction(HloInstruction::CreateBinary( + constant_0->shape(), HloOpcode::kAdd, constant_0, constant_1)); + HloInstruction* new_bcast = + add->AddInstruction(HloInstruction::CreateBroadcast( + broadcast_0->shape(), new_constant_add, broadcast_0->dimensions())); + TF_RETURN_IF_ERROR(ReplaceWithNewInstruction( + add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, + new_bcast, conv))); + return true; +} + +absl::Status GpuAlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { + TF_ASSIGN_OR_RETURN(bool replaced, + TryToSinkBroadcastOperandsOfChainedAdds(add)); + if (replaced) { + return absl::OkStatus(); + } + + return AlgebraicSimplifierVisitor::HandleAdd(add); +} + bool GpuAlgebraicSimplifierVisitor::ShouldStrengthReduceDotToReduce( const HloInstruction* hlo) { if (!options_.enable_dot_strength_reduction()) { diff --git a/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.h b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.h index c63cf1e51a4c75..9e6b1d785a9994 100644 --- a/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.h +++ b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.h @@ -19,11 +19,12 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/service/algebraic_simplifier.h" +#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" #include "xla/stream_executor/device_description.h" #include "xla/util.h" @@ -38,9 +39,20 @@ class GpuAlgebraicSimplifierVisitor : public AlgebraicSimplifierVisitor { : AlgebraicSimplifierVisitor(options, simplifier), compute_capability_(std::move(compute_capability)) {} + absl::Status HandleAdd(HloInstruction* add) override; + bool ShouldStrengthReduceDotToReduce(const HloInstruction* hlo) override; private: + // Try to convert add(broadcast(const_0), add(broadcast(const_1), conv(...))) + // into add(broadcast(add(const_0, const_1)), conv(...)) and return true if + // successful. The particular sink happens only when enable_sink_broadcast is + // true and the broadcast shapes and dimensions match. The sink only happens + // when following a convolution to avoid having a side input when the + // instructions are fused to cudnnConvolutionBiasActivationForward later. + absl::StatusOr TryToSinkBroadcastOperandsOfChainedAdds( + HloInstruction* add); + se::GpuComputeCapability compute_capability_; }; diff --git a/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier_test.cc b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier_test.cc index c1e52e90a417c0..686b2e15d7a58f 100644 --- a/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier_test.cc @@ -17,9 +17,12 @@ limitations under the License. #include +#include #include #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/algebraic_simplifier.h" +#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" +#include "xla/service/pattern_matcher.h" +#include "xla/service/pattern_matcher_gmock.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" @@ -27,8 +30,122 @@ limitations under the License. namespace xla::gpu { namespace { +namespace m = ::xla::match; + class GpuAlgebraicSimplifierTest : public HloTestBase {}; +TEST_F(GpuAlgebraicSimplifierTest, SinkBroadcastOperandsOfChainedAdds) { + const std::string& hlo_string = R"( + HloModule m + test { + in = bf16[1,3,3,1] parameter(0) + filter = bf16[2,2,1,1] constant({{{{1.1}}, {{2.1}}}, + {{{3.1}}, {{4.1}}}}) + conv = bf16[1,2,2,1] convolution(in, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f + const0 = bf16[2] constant({0, 0.25}) + bcast0 = bf16[1,2,2,1] broadcast(const0), dimensions={1} + add0 = bf16[1,2,2,1] add(conv, bcast0) + const1 = bf16[2] constant({1, 1.25}) + bcast1 = bf16[1,2,2,1] broadcast(const1), dimensions={1} + ROOT add1 = bf16[1,2,2,1] add(add0, bcast1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AlgebraicSimplifierOptions options; + options.set_enable_sink_broadcast(true); + ASSERT_TRUE( + GpuAlgebraicSimplifier(options, se::CudaComputeCapability::Ampere()) + .Run(module.get()) + .value()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::AddAnyOrder( + m::Broadcast(m::Add(m::Constant(), m::Constant())), + m::Convolution(m::Op(), m::Op())))); +} + +TEST_F(GpuAlgebraicSimplifierTest, + DoNotSinkBroadcastOperandsOfChainedAddsWhenDisabled) { + const std::string& hlo_string = R"( + HloModule m + test { + in = bf16[1,3,3,1] parameter(0) + filter = bf16[2,2,1,1] constant({{{{1.1}}, {{2.1}}}, + {{{3.1}}, {{4.1}}}}) + conv = bf16[1,2,2,1] convolution(in, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f + const0 = bf16[2] constant({0, 0.25}) + bcast0 = bf16[1,2,2,1] broadcast(const0), dimensions={1} + add0 = bf16[1,2,2,1] add(conv, bcast0) + const1 = bf16[2] constant({1, 1.25}) + bcast1 = bf16[1,2,2,1] broadcast(const1), dimensions={1} + ROOT add1 = bf16[1,2,2,1] add(add0, bcast1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AlgebraicSimplifierOptions options; + options.set_enable_sink_broadcast(false); + EXPECT_FALSE( + GpuAlgebraicSimplifier(options, se::CudaComputeCapability::Ampere()) + .Run(module.get()) + .value()); +} + +TEST_F(GpuAlgebraicSimplifierTest, + DoNotSinkBroadcastOperandsOfChainedAddsWithoutConvolution) { + const std::string& hlo_string = R"( + HloModule m + test { + p = bf16[4, 4] parameter(0) + const0 = bf16[4] constant({0, 0.25, 0.5, 0.75}) + bcast0 = bf16[4,4] broadcast(const0), dimensions={0} + add0 = bf16[4,4] add(p, bcast0) + const1 = bf16[4] constant({1, 1.25, 1.5, 1.75}) + bcast1 = bf16[4,4] broadcast(const1), dimensions={0} + ROOT add1 = bf16[4,4] add(add0, bcast1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AlgebraicSimplifierOptions options; + options.set_enable_sink_broadcast(true); + EXPECT_FALSE( + GpuAlgebraicSimplifier(options, se::CudaComputeCapability::Ampere()) + .Run(module.get()) + .value()); +} + +TEST_F( + GpuAlgebraicSimplifierTest, + DoNotSinkBroadcastOperandsOfChainedAddsWithMismatchedBroadcastDimensions) { + const std::string& hlo_string = R"( + HloModule m + test { + in = bf16[1,3,3,1] parameter(0) + filter = bf16[2,2,1,1] constant({{{{1.1}}, {{2.1}}}, + {{{3.1}}, {{4.1}}}}) + conv = bf16[1,2,2,1] convolution(in, filter), + window={size=2x2}, dim_labels=b01f_01io->b01f + const0 = bf16[2] constant({0, 0.25}) + bcast0 = bf16[1,2,2,1] broadcast(const0), dimensions={1} + add0 = bf16[1,2,2,1] add(conv, bcast0) + const1 = bf16[2] constant({1, 1.25}) + bcast1 = bf16[1,2,2,1] broadcast(const1), dimensions={2} + ROOT add1 = bf16[1,2,2,1] add(add0, bcast1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + AlgebraicSimplifierOptions options; + options.set_enable_sink_broadcast(true); + EXPECT_FALSE( + GpuAlgebraicSimplifier(options, se::CudaComputeCapability::Ampere()) + .Run(module.get()) + .value()); +} + TEST_F(GpuAlgebraicSimplifierTest, VectorVectorDotShouldBeStrengthReduced) { const std::string& hlo_string = R"( HloModule m diff --git a/third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.cc b/third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.cc index 4035b80606cdff..86a05fe63a64e1 100644 --- a/third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.cc +++ b/third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.cc @@ -15,8 +15,15 @@ limitations under the License. #include "xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h" +#include + +#include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/collective_opt_utils.h" +#include "xla/service/hlo_module_config.h" namespace xla { bool AllGatherDynamicSliceSimplifier::InstructionMatchesPattern( @@ -46,19 +53,19 @@ bool AllGatherDynamicSliceSimplifier::InstructionMatchesPattern( is_reshape ? Cast(operand->mutable_operand(0)) : Cast(operand); - bool match = AllGatherDynamicSliceCancellation( + std::optional spec = AllGatherDynamicSliceCancellation( all_gather, config.num_partitions(), config.replica_count(), - /*allow_multiple_split_dims=*/true, - /*allow_intervening_reshape=*/true, /*min_rank=*/1, - HloPredicateIsOp, + config_.allow_multiple_split_dims, config_.allow_intervening_reshape, + config_.min_rank, HloPredicateIsOp, HloPredicateIsOp, - /*allow_intervening_bitcast=*/false, - /*allow_multiple_users=*/true); + config_.allow_intervening_bitcast, config_.allow_multiple_users); - return match; + return spec.has_value() && + spec->split_dim == all_gather->all_gather_dimension(); } -StatusOr AllGatherDynamicSliceSimplifier::ExpandInstruction( +absl::StatusOr +AllGatherDynamicSliceSimplifier::ExpandInstruction( HloInstruction* instruction) { HloDynamicSliceInstruction* dynamic_slice = Cast(instruction); diff --git a/third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h b/third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h index f0fb673ad1f6fa..dc0b1c06ae736f 100644 --- a/third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h +++ b/third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier.h @@ -16,7 +16,12 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_DYNAMIC_SLICE_SIMPLIFIER_H_ #define XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_DYNAMIC_SLICE_SIMPLIFIER_H_ -#include "xla/service/op_expander_pass.h" +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" namespace xla { @@ -32,6 +37,20 @@ namespace xla { class AllGatherDynamicSliceSimplifier : public OpExpanderPass { public: + struct Config { + bool allow_multiple_split_dims = false; + bool allow_intervening_reshape = true; + int min_rank = 1; + bool allow_intervening_bitcast = false; + bool allow_multiple_users = false; + }; + + static Config DefaultConfig() { return {}; } + + explicit AllGatherDynamicSliceSimplifier( + Config config = AllGatherDynamicSliceSimplifier::DefaultConfig()) + : config_(std::move(config)) {} + absl::string_view name() const override { return "all-gather-dynamic-slice-simplifier"; } @@ -39,8 +58,11 @@ class AllGatherDynamicSliceSimplifier : public OpExpanderPass { protected: bool InstructionMatchesPattern(HloInstruction* instruction) override; - StatusOr ExpandInstruction( + absl::StatusOr ExpandInstruction( HloInstruction* instruction) override; + + private: + Config config_; }; } // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier_test.cc b/third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier_test.cc index c7f4391bc00923..d59af461e2a0f2 100644 --- a/third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/all_gather_dynamic_slice_simplifier_test.cc @@ -17,37 +17,43 @@ limitations under the License. #include #include -#include +#include -#include "xla/hlo/ir/hlo_casting_utils.h" +#include +#include +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" +#include "xla/service/hlo_module_config.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace xla { namespace gpu { namespace { -using ::testing::Matcher; namespace op = xla::testing::opcode_matchers; class AllGatherDynamicSliceSimplifierTest : public HloTestBase { public: absl::StatusOr> RunPass( absl::string_view hlo_module, int64_t num_replicas, - int64_t num_partitions, bool expect_change) { + int64_t num_partitions, bool expect_change, + bool allow_multiple_users = false) { HloModuleConfig config = GetModuleConfigForTest( /*replica_count=*/num_replicas, /*num_partitions=*/num_partitions); config.set_use_spmd_partitioning(num_partitions > 1); TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_module, config)); - auto changed = AllGatherDynamicSliceSimplifier().Run(module.get()); + AllGatherDynamicSliceSimplifier::Config pass_config; + pass_config.allow_multiple_users = allow_multiple_users; + auto changed = + AllGatherDynamicSliceSimplifier(pass_config).Run(module.get()); if (!changed.ok()) { return changed.status(); } @@ -198,6 +204,29 @@ TEST_F(AllGatherDynamicSliceSimplifierTest, IncorrectAllGatherDimension) { op::Constant())); } +TEST_F(AllGatherDynamicSliceSimplifierTest, + AllGatherDimDoesNotMatchDynamicSlice) { + absl::string_view hlo_string = R"( + HloModule m + + ENTRY root { + param = f32[2,16] parameter(0) + ag = f32[16,16] all-gather(%param), dimensions={0} + pid = u32[] partition-id() + pid_s32 = s32[] convert(%pid) + slice_size = s32[] constant(2) + offset = s32[] multiply(%pid_s32, %slice_size) + zero = s32[] constant(0) + ROOT _ = f32[16,2] dynamic-slice(ag, zero, offset), + dynamic_slice_sizes={16,2} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string, + /*num_replicas=*/1, + /*num_partitions=*/8, + /*expect_change=*/false)); +} + // Test cancellation of all-gather followed by dynamic-slice across all replicas // with reshape and multiple users of the all-gather. TEST_F(AllGatherDynamicSliceSimplifierTest, @@ -223,7 +252,8 @@ TEST_F(AllGatherDynamicSliceSimplifierTest, TF_ASSERT_OK_AND_ASSIGN(auto module, RunPass(hlo_string, /*num_replicas=*/1, /*num_partitions=*/8, - /*expect_change=*/true)); + /*expect_change=*/true, + /*allow_multiple_users=*/true)); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Tuple(op::Reshape(op::Parameter(0)), op::AllGather(op::Parameter(0)))); diff --git a/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.h b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.h index d1c8ebd9fb48c4..6250876fb5c322 100644 --- a/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.h +++ b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.h @@ -25,10 +25,13 @@ limitations under the License. namespace xla { namespace gpu { -// Transforms binary_op(all-gather(reduce_scatter(a)), -// all-gather(reduce_scatter(b))) to allgather(binary_op(reduce_scatter(a), -// reduce_scatter(b))) - +// Transforms +// binary_op(all-gather(op1(a)),all-gather(op2(b))) +// to +// allgather(binary_op(op1(a),op2(b))) +// +// Where binary_op is commutative and takes exactly two operands as input. +// class AllGatherOptimizer : public HloModulePass { public: AllGatherOptimizer() = default; diff --git a/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter_test.cc b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter_test.cc index 0ab8c6f3b894f8..581237dd5f0479 100644 --- a/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter_test.cc @@ -30,9 +30,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/testlib/filecheck.h" #include "xla/service/gpu/transforms/reduce_scatter_creator.h" #include "xla/service/hlo_module_config.h" -#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" diff --git a/third_party/xla/xla/service/gpu/transforms/async_wrapper_test.cc b/third_party/xla/xla/service/gpu/transforms/async_wrapper_test.cc index b6c20041715e5c..ef8bd95e88b42c 100644 --- a/third_party/xla/xla/service/gpu/transforms/async_wrapper_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/async_wrapper_test.cc @@ -24,11 +24,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" -#include "xla/tests/verified_hlo_module.h" #include "tsl/platform/status_matchers.h" namespace xla::gpu { diff --git a/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.cc b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.cc index 4bd671f0e1ac96..a8462e252d0866 100644 --- a/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.cc +++ b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.cc @@ -32,11 +32,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/literal_util.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/hlo_parser.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.h b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.h index 5aaf46c0028433..7663d878745b2c 100644 --- a/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.h +++ b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.h @@ -56,9 +56,6 @@ class CollectivePermuteCycleDecomposer : public HloModulePass { return "collective-permute-cycle-decomposer"; } - using HloPassInterface::Run; - // Runs CollectivePermuteCycleDecomposer pass on computations in 'module'. - // Returns whether the 'module' was changed. absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; diff --git a/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer_test.cc b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer_test.cc index 98f53fd5be7070..ab4b5466dda500 100644 --- a/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer_test.cc @@ -25,10 +25,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_parser.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { @@ -36,9 +35,17 @@ namespace { using ::testing::HasSubstr; using CollectivePermuteCycleDecomposerTest = HloTestBase; +using Decomposer = CollectivePermuteCycleDecomposer; -TEST_F(CollectivePermuteCycleDecomposerTest, TrivialNotTransformed) { - const absl::string_view kModuleStr = R"( +HloPrintOptions PrintOptions() { + HloPrintOptions options; + options.set_print_operand_shape(false); + options.set_include_layout_in_shapes(false); + return options; +} + +TEST_F(CollectivePermuteCycleDecomposerTest, NoCycle_NotTransformed) { + absl::string_view kHlo = R"( HloModule test ENTRY test_computation { p = u32[8,8] parameter(0) @@ -47,17 +54,14 @@ TEST_F(CollectivePermuteCycleDecomposerTest, TrivialNotTransformed) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule((kModuleStr))); - CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_FALSE(changed); + TF_ASSERT_OK(RunAndCheckHloRewrite(kHlo, Decomposer(0), false)); } -TEST_F(CollectivePermuteCycleDecomposerTest, BelowThresholdNotTransformed) { +TEST_F(CollectivePermuteCycleDecomposerTest, HonorsThreshold) { // When `size of data` > `threshold`, then it is decomposed, otherwise it // stays as it is. - const absl::string_view kModuleStr = R"( + // u32[4,2] = 4*4*2 = 32 bytes + absl::string_view hlo = R"( HloModule test ENTRY test_computation { p = u32[4,2] parameter(0) @@ -66,16 +70,9 @@ TEST_F(CollectivePermuteCycleDecomposerTest, BelowThresholdNotTransformed) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule((kModuleStr))); - CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/33); - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - RunHloPass(CollectivePermuteCycleDecomposer(33), module.get())); - EXPECT_FALSE(changed); - TF_ASSERT_OK_AND_ASSIGN( - changed, RunHloPass(CollectivePermuteCycleDecomposer(16), module.get())); - EXPECT_TRUE(changed); + TF_ASSERT_OK(RunAndCheckHloRewrite(hlo, Decomposer(33), false)); + TF_ASSERT_OK(RunAndCheckHloRewrite(hlo, Decomposer(32), true)); + TF_ASSERT_OK(RunAndCheckHloRewrite(hlo, Decomposer(16), true)); } TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycle) { @@ -84,7 +81,7 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycle) { // 2. They should split over the value of partition-id. // 3. The metadata and frontend_attributes are propagated to split // collectives. - const absl::string_view kModuleStr = R"( + absl::string_view hlo = R"( HloModule test ENTRY test_computation { p = u32[8,8] parameter(0) @@ -94,30 +91,21 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycle) { metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35} } )"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule((kModuleStr))); - CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_TRUE(changed); - - TF_CHECK_OK(VerifyHloModule(module.get(), false, true)); - HloPrintOptions options; - options.set_print_operand_shape(false); - options.set_include_layout_in_shapes(false); - EXPECT_TRUE(*RunFileCheck(module->ToString(options), R"( + TF_ASSERT_OK_AND_ASSIGN(auto module, + RunAndCheckHloRewrite(hlo, Decomposer(0), true)); + EXPECT_TRUE(*RunFileCheck(module->ToString(PrintOptions()), R"( // CHECK: ENTRY %test_computation (p: u32[8,8]) -> u32[8,8] { // CHECK-DAG: %[[partition_id:.+]] = u32[] partition-id() // CHECK-DAG: %[[c0:.+]] = u32[] constant(0) // CHECK-DAG: %[[compare:.+]] = pred[] compare(%[[partition_id]], %[[c0]]), direction=EQ // CHECK-DAG: %{{.+}} = u32[8,8] parameter(0) - - // CHECK-DAG: %[[cp1:.+]] = u32[8,8] collective-permute(%{{.+}}), channel_id=1, + + // CHECK-DAG: %[[cp1:.+]] = u32[8,8] collective-permute(%{{.+}}), channel_id=1, // CHECK-SAME{LITERAL}: source_target_pairs={{3,0}}, frontend_attributes={_xla_send_recv_validation={{3,10}}}, metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35} - - // CHECK-DAG: %[[cp2:.+]] = u32[8,8] collective-permute(%{{.+}}), channel_id=2, + + // CHECK-DAG: %[[cp2:.+]] = u32[8,8] collective-permute(%{{.+}}), channel_id=2, // CHECK-SAME{LITERAL}: source_target_pairs={{0,1},{1,2},{2,3}}, frontend_attributes={_xla_send_recv_validation={{0,7},{1,8},{2,9}}}, metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35} - + // CHECK-DAG: ROOT %select = u32[8,8] select(%[[compare]], %[[cp1]], %[[cp2]]) // CHECK-DAG: } )")); @@ -127,7 +115,7 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleNoChannel) { // For a forward cycle, this checks: // 1. Split collectives should not have channel-id // 2. Split collectives are combined based on replica-id. - const absl::string_view kModuleStr = R"( + absl::string_view hlo = R"( HloModule test ENTRY test_computation { p = u32[8,8] parameter(0) @@ -136,17 +124,9 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleNoChannel) { } )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule((kModuleStr))); - CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_TRUE(changed); - TF_CHECK_OK(VerifyHloModule(module.get(), false, true)); - - HloPrintOptions options; - options.set_print_operand_shape(false); - options.set_include_layout_in_shapes(false); - EXPECT_TRUE(*RunFileCheck(module->ToString(options), R"( + TF_ASSERT_OK_AND_ASSIGN(auto module, + RunAndCheckHloRewrite(hlo, Decomposer(0), true)); + EXPECT_TRUE(*RunFileCheck(module->ToString(PrintOptions()), R"( // CHECK: ENTRY %test_computation (p: u32[8,8]) -> u32[8,8] { // CHECK-DAG: %[[replica_id:.+]] = u32[] replica-id() // CHECK-DAG: %[[c0:.+]] = u32[] constant(0) @@ -155,17 +135,17 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleNoChannel) { // CHECK-DAG: %[[cp1:.+]] = u32[8,8] collective-permute(%{{.+}}), source_target_pairs= // CHECK-SAME{LITERAL}: {{3,0}} - + // CHECK-DAG: %[[cp2:.+]] = u32[8,8] collective-permute(%{{.+}}), source_target_pairs= // CHECK-SAME{LITERAL}: {{0,1},{1,2},{2,3}} - + // CHECK-DAG: ROOT %select = u32[8,8] select(%[[compare]], %[[cp1]], %[[cp2]]) // CHECK-DAG: } )")); } TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleWithMatmul) { - const absl::string_view kModuleStr = R"( + absl::string_view hlo = R"( HloModule test while_cond { @@ -198,11 +178,8 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleWithMatmul) { while_res = (u32[], f32[2,2], f32[2,2]) while(input), condition=while_cond, body=while_body ROOT data_out = f32[2,2] get-tuple-element(while_res), index=1 })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule((kModuleStr))); - CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_TRUE(changed); + TF_ASSERT_OK_AND_ASSIGN(auto module, + RunAndCheckHloRewrite(hlo, Decomposer(0), true)); HloCollectivePermuteInstruction* cp1 = DynCast( FindInstruction(module.get(), "cp.backward")); @@ -222,7 +199,7 @@ TEST_F(CollectivePermuteCycleDecomposerTest, BackwardCycle) { // 1. Metadata is propagated to split collectives. // 2. Frontend attributes are accurately split. // 3. The split collectives have channel IDs. - const absl::string_view kModuleStr = R"( + absl::string_view hlo = R"( HloModule test ENTRY test_computation { p = u32[8,8] parameter(0) @@ -232,29 +209,21 @@ TEST_F(CollectivePermuteCycleDecomposerTest, BackwardCycle) { metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule((kModuleStr))); - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - RunHloPass(CollectivePermuteCycleDecomposer(0), module.get())); - EXPECT_TRUE(changed); - TF_CHECK_OK(VerifyHloModule(module.get(), true, false)); - HloPrintOptions options; - options.set_print_operand_shape(false); - options.set_include_layout_in_shapes(false); - EXPECT_TRUE(*RunFileCheck(module->ToString(options), R"( + TF_ASSERT_OK_AND_ASSIGN(auto module, + RunAndCheckHloRewrite(hlo, Decomposer(0), true)); + EXPECT_TRUE(*RunFileCheck(module->ToString(PrintOptions()), R"( // CHECK: ENTRY %test_computation (p: u32[8,8]) -> u32[8,8] { // CHECK-DAG: %[[partition:.+]] = u32[] partition-id() // CHECK-DAG: %[[three:.+]] = u32[] constant(3) // CHECK-DAG: %[[compare:.+]] = pred[] compare(%[[partition]], %[[three]]), direction=EQ // CHECK-DAG: %{{.+}} = u32[8,8] parameter(0) - + // CHECK-DAG: %[[cp1:.+]] = u32[8,8] collective-permute(%{{.+}}), channel_id=1, source_target_pairs= // CHECK-SAME{LITERAL}: {{0,3}}, frontend_attributes={_xla_send_recv_validation={{0,7}}}, metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35} - + // CHECK-DAG: %[[cp2:.+]] = u32[8,8] collective-permute(%{{.+}}), channel_id=2, source_target_pairs= // CHECK-SAME{LITERAL}: {{1,0},{2,1},{3,2}}, frontend_attributes={_xla_send_recv_validation={{1,8},{2,9},{3,10}}}, metadata={op_name="op1/op2/add" source_file="foo/bar/mysource.py" source_line=35} - + // CHECK-DAG: ROOT %{{.+}} = u32[8,8] select(%[[compare]], %[[cp1]], %[[cp2]]) // CHECK-DAG: } )")); @@ -264,7 +233,7 @@ TEST_F(CollectivePermuteCycleDecomposerTest, BackwardCycleNoChannel) { // For backward cycle, this checks: // 1. Split collectives do not have a channel-id // 2. Split collectives are combined based on the value of replica-id. - const absl::string_view kModuleStr = R"( + absl::string_view hlo = R"( HloModule test ENTRY test_computation { p = u32[8,8] parameter(0) @@ -273,28 +242,21 @@ TEST_F(CollectivePermuteCycleDecomposerTest, BackwardCycleNoChannel) { frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10}}"} })"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule((kModuleStr))); - CollectivePermuteCycleDecomposer decomposer(/*threshold_in_bytes=*/0); - TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); - EXPECT_TRUE(changed); - HloPrintOptions options; - options.set_print_operand_shape(false); - options.set_include_layout_in_shapes(false); - TF_CHECK_OK(VerifyHloModule(module.get(), false, true)); - EXPECT_TRUE(*RunFileCheck(module->ToString(options), R"( + TF_ASSERT_OK_AND_ASSIGN(auto module, + RunAndCheckHloRewrite(hlo, Decomposer(0), true)); + EXPECT_TRUE(*RunFileCheck(module->ToString(PrintOptions()), R"( // CHECK: ENTRY %test_computation (p: u32[8,8]) -> u32[8,8] { // CHECK-DAG: %[[replica_id:.+]] = u32[] replica-id() // CHECK-DAG: %[[three:.+]] = u32[] constant(3) // CHECK-DAG: %[[compare:.+]] = pred[] compare(%[[replica_id]], %[[three]]), direction=EQ // CHECK-DAG: %{{.+}} = u32[8,8] parameter(0) - + // CHECK-DAG: %[[cp1:.+]] = u32[8,8] collective-permute(%{{.+}}), source_target_pairs= // CHECK-SAME{LITERAL}: {{0,3}}, frontend_attributes={_xla_send_recv_validation={{0,7}}} - + // CHECK-DAG: %[[cp2:.+]] = u32[8,8] collective-permute(%{{.+}}), source_target_pairs= // CHECK-SAME{LITERAL}: {{1,0},{2,1},{3,2}}, frontend_attributes={_xla_send_recv_validation={{1,8},{2,9},{3,10}}} - + // CHECK-DAG: ROOT %select = u32[8,8] select(%[[compare]], %[[cp1]], %[[cp2]]) // CHECK-DAG: } )")); diff --git a/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.cc b/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.cc index e9df22abeddcfe..cdca7af15832da 100644 --- a/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.cc +++ b/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.cc @@ -12,10 +12,10 @@ limitations under the License. #include "xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.h" +#include "xla/hlo/analysis/while_loop_analysis.h" #include "xla/literal_util.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/pattern_matcher.h" -#include "xla/service/while_loop_analysis.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator_test.cc b/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator_test.cc index 41a2957568c5e0..45578e4f7dbed1 100644 --- a/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/transforms/while_loop_trip_count_annotator.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/while_loop_trip_count_annotator.h" #include "xla/tests/hlo_test_base.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/transforms/collective_select_folder.cc b/third_party/xla/xla/service/gpu/transforms/collective_select_folder.cc index 87f7e083a01058..c383e6c2cc6917 100644 --- a/third_party/xla/xla/service/gpu/transforms/collective_select_folder.cc +++ b/third_party/xla/xla/service/gpu/transforms/collective_select_folder.cc @@ -32,6 +32,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/collective_ops_utils.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -41,21 +42,20 @@ namespace { using SourceTargetPair = std::pair; using SourceTargetPairs = std::vector; -struct SelectPredInfo { - int64_t constant; - Comparison::Direction direction; - HloOpcode device_id_type; // kReplicaId or kPartitionId +struct FoldableSelect { + Comparison::Direction cmp_direction; + int64_t constant_id; + CollectiveOpGroupMode collective_mode; HloInstruction* true_operand; HloInstruction* false_operand; }; // Returns handy references to %constant, %true_operand, %false_operand of the -// select(broadcast(compare(current_device_id, constant)), true_operand, -// false_operand) +// `select(broadcast(compare(current_id, constant)), true_operand, +// false_operand)` // or -// select(compare(current_device_id, constant), true_operand, -// false_operand) -std::optional GetPredSelectInfo(HloInstruction* select) { +// select(compare(current_id, constant), true_operand, false_operand)` +std::optional MatchFoldableSelect(HloInstruction* select) { if (select->opcode() != HloOpcode::kSelect) { return std::nullopt; } @@ -71,79 +71,96 @@ std::optional GetPredSelectInfo(HloInstruction* select) { const HloCompareInstruction* compare = DynCast(compare_candidate); + if (compare->direction() != Comparison::Direction::kEq && + compare->direction() != Comparison::Direction::kNe) { + return std::nullopt; + } + + const HloInstruction* id_op = compare->operand(0); + CollectiveOpGroupMode mode; + if (id_op->opcode() == HloOpcode::kReplicaId) { + mode = CollectiveOpGroupMode::kCrossReplica; + } else if (id_op->opcode() == HloOpcode::kPartitionId) { + mode = CollectiveOpGroupMode::kCrossPartition; + } else { + return std::nullopt; + } - if ((compare->operand(0)->opcode() != HloOpcode::kReplicaId && - compare->operand(0)->opcode() != HloOpcode::kPartitionId) || - compare->operand(1)->opcode() != HloOpcode::kConstant) { + if (compare->operand(1)->opcode() != HloOpcode::kConstant) { return std::nullopt; } int64_t id_value = compare->operand(1)->literal().GetFirstInteger().value_or(-1); - return SelectPredInfo{id_value, compare->direction(), - compare->operand(0)->opcode(), + return FoldableSelect{compare->direction(), id_value, mode, select->mutable_operand(1), select->mutable_operand(2)}; } -bool IsUniqueSource(int64_t device_id, const SourceTargetPairs& pairs) { - if (pairs.size() == 1 && pairs[0].first == device_id) return true; - return false; -} - -bool IsNotPresentInSource(int64_t device_id, const SourceTargetPairs& pairs) { - return absl::c_none_of( - pairs, [device_id](const auto& pair) { return pair.first == device_id; }); -} - -inline absl::StatusOr update(HloInstruction* cp, HloInstruction* data) { - TF_RETURN_IF_ERROR(cp->ReplaceOperandWith(0, data)); - return true; -} +std::optional StaticallyEvaluatePredicateForAllSourceIDs( + FoldableSelect select_match, SourceTargetPairs pairs) { + // If there are no pairs, the predicate is undefined. + if (pairs.empty()) return std::nullopt; + + // Evaluate the select predicate for the first source target pair. + CHECK(select_match.cmp_direction == Comparison::Direction::kEq || + select_match.cmp_direction == Comparison::Direction::kNe); + auto select_predicate_eval = [&select_match](const SourceTargetPair& pair) { + int64_t src_id = pair.first; + return select_match.cmp_direction == Comparison::Direction::kEq + ? src_id == select_match.constant_id + : src_id != select_match.constant_id; + }; + bool result_candidate = select_predicate_eval(pairs.front()); + + // Check that the result is the same for all source target pairs. If not, + // we have a contradiction and cannot statically evaluate the predicate. We + // return std::nullopt in this case. + if (!absl::c_all_of(pairs, [&](const SourceTargetPair& it) -> bool { + return result_candidate == select_predicate_eval(it); + })) { + return std::nullopt; + } -// We have to maintain integrity of relationship between partition/replica -// and collective-permute's channel_id. -// That is we can only fold select when -// 1. cp has channel_id and condition is based on partition_id -// 2. cp has no channel_id and condition is based on replica_id -// See enum class CollectiveOpGroupMode for details. -bool IsShardingConsistent(HloCollectivePermuteInstruction* cp, - HloOpcode device_id_type) { - auto id = cp->channel_id(); - return (device_id_type == HloOpcode::kPartitionId && id.has_value()) || - (device_id_type == HloOpcode::kReplicaId && !id.has_value()); + // The predicate statically evaluates to the same value for all source target + // pairs. + return result_candidate; } // Recognizes the pattern and update if applicable. -absl::StatusOr TryFoldSelect(HloInstruction* in) { - if (in->opcode() != HloOpcode::kCollectivePermute) return false; - auto select_info_opt = GetPredSelectInfo(in->mutable_operand(0)); - if (!select_info_opt.has_value()) return false; - auto select_info = select_info_opt.value(); - +absl::StatusOr TryFoldColectivePermuteOfSelect(HloInstruction* inst) { + // Root op must be a collective-permute. HloCollectivePermuteInstruction* cp = - Cast(in); - if (!IsShardingConsistent(cp, select_info.device_id_type)) return false; - - int64_t device_id = select_info.constant; - SourceTargetPairs pairs = cp->source_target_pairs(); - - if (select_info.direction == Comparison::Direction::kEq) { - if (IsUniqueSource(device_id, pairs)) { - return update(cp, select_info.true_operand); - } else if (IsNotPresentInSource(device_id, pairs)) { - return update(cp, select_info.false_operand); - } - } - - if (select_info.direction == Comparison::Direction::kNe) { - if (IsNotPresentInSource(device_id, pairs)) { - return update(cp, select_info.true_operand); - } else if (IsUniqueSource(device_id, pairs)) { - return update(cp, select_info.false_operand); - } - } - return false; + DynCast(inst); + if (cp == nullptr) return false; + + // Operand must be a foldable select, i.e. a select op that this pass' + // analysis supports. + std::optional select_match = + MatchFoldableSelect(inst->mutable_operand(0)); + if (!select_match.has_value()) return false; + + // We have to maintain integrity of relationship between the predicate, which + // is based on partition or replica ID, and the collevtive mode of the + // collective-permute op. + TF_ASSIGN_OR_RETURN( + CollectiveOpGroupMode collective_mode, + GetCollectiveOpGroupMode(cp->channel_id().has_value(), + /*use_global_device_ids=*/std::nullopt)); + if (collective_mode != select_match->collective_mode) return false; + + // We can only actually fold the select if we can evaluate the predicate + // statically to a known value for all relevant source IDs. + std::optional predicate_value = + StaticallyEvaluatePredicateForAllSourceIDs(*select_match, + cp->source_target_pairs()); + if (!predicate_value.has_value()) return false; + + // Fold select and forward the correct operand. + HloInstruction* new_operand = *predicate_value ? select_match->true_operand + : select_match->false_operand; + TF_RETURN_IF_ERROR(cp->ReplaceOperandWith(0, new_operand)); + return true; } } // namespace @@ -152,9 +169,10 @@ absl::StatusOr CollectiveSelectFolder::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - TF_ASSIGN_OR_RETURN(bool local_changed, TryFoldSelect(instruction)); + for (HloComputation* comp : module->computations()) { + for (HloInstruction* inst : comp->instructions()) { + TF_ASSIGN_OR_RETURN(bool local_changed, + TryFoldColectivePermuteOfSelect(inst)); changed |= local_changed; } } diff --git a/third_party/xla/xla/service/gpu/transforms/collective_select_folder.h b/third_party/xla/xla/service/gpu/transforms/collective_select_folder.h index 3e14ecbf054e1b..c53eb2ca508b3f 100644 --- a/third_party/xla/xla/service/gpu/transforms/collective_select_folder.h +++ b/third_party/xla/xla/service/gpu/transforms/collective_select_folder.h @@ -24,37 +24,54 @@ limitations under the License. namespace xla { -// When collective-permute operates on a comparison to a device id -// and the senders match the condition's branch -// we can link collective-permute to the original data skipping the comparison. -// For example -// condition = broadcast(compare(replica_id, X), direction=EQ -// data_snd = select(condition, compare_true_data, compare_false_data) -// rcv = collective-permute(data_snd compare_true_data), pairs={{X,0}} -// can be transformed to -// rcv = collective-permute(compare_true_data), pairs={{X,0}} +// If a collective-permute selects its source data based on a partition or +// replica ID and we can prove that the condition is either always true or +// always false, we can fold the redundant select op and use the correct source +// data directly. // -// The pass is *only* handling compare direction={EQ,NE}. -// The pass handles Compare with and without preceding Broadcast. +// Example: +// +// condition = compare(replica-id(), X), direction=EQ +// snd_data = select(condition, true_data, false_data) +// rcv_data = collective-permute(snd_data), source_target_pairs={{X,0}} +// +// The condition is always true for the only relevant replica X and the IR can +// be folded into +// +// rcv_data = collective-permute(true_data), source_target_pairs={{X,0}} +// +// The pass only supports simple partion/replica-based predicates, comparing +// partition/replica-id with a constant. Only comparison directions {EQ,NE} are +// supported. The predicate may be broadcasted. +// +// This pass is motivated by pipeline parallelism, where it removes undesired +// data dependencies. +// +// Example: // -// This pass is particularly useful in the pipeline parallelism generated module -// such as: // fwd_data = ... // bwd_data = // is_first_device = ... // is_last_device = ... -// data_snd = select(is_last_device, bwd_data, fwd_data) -// bwd_data_rcv = collective-permute(data_snd), pairs={{3,0}} -// fwd_data_rcv = collective-permute(data_snd), pairs={{0,1},{1,2},{2,3}} -// ROOT data_rcv = select(is_first_device, bwd_data_rcv, fwd_data_rcv) +// snd_data = select(is_last_device, bwd_data, fwd_data) +// rcv_bwd_data = collective-permute(snd_data), +// source_target_pairs={{LAST_ID,0}} +// rcv_fwd_data = collective-permute(snd_data), +// source_target_pairs={{0,1},{1,2},...,{LAST_ID,0}} +// ROOT rcv_data = select(is_first_device, rcv_bwd_data, rcv_fwd_data) // -// After the transformation, the module will become: -// fwd_data_snd = ... -// bwd_data_snd = ... +// The select can be removed on both paths resulting in +// +// fwd_data = ... +// bwd_data = // is_first_device = ... -// bwd_data_rcv = collective-permute(bwd_data_snd), pairs={{3,0}} -// fwd_data_rcv = collective-permute(fwd_data_snd), pairs={{0,1},{1,2},{2,3}} -// ROOT data_rcv = select(is_first_device, bwd_data_rcv, fwd_data_rcv) +// is_last_device = ... +// rcv_bwd_data = collective-permute(bwd_data), +// source_target_pairs={{LAST_ID,0}} +// rcv_fwd_data = collective-permute(fwd_data), +// source_target_pairs={{0,1},{1,2},...,{LAST_ID,0}} +// ROOT rcv_data = select(is_first_device, rcv_bwd_data, rcv_fwd_data) +// class CollectiveSelectFolder : public HloModulePass { public: absl::string_view name() const override { return "collective-select-folder"; } diff --git a/third_party/xla/xla/service/gpu/transforms/collective_select_folder_test.cc b/third_party/xla/xla/service/gpu/transforms/collective_select_folder_test.cc index 59b346abfa4c3d..2f11ad958fd89f 100644 --- a/third_party/xla/xla/service/gpu/transforms/collective_select_folder_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/collective_select_folder_test.cc @@ -17,14 +17,11 @@ limitations under the License. #include #include -#include -#include #include #include #include "absl/log/log.h" #include "absl/status/statusor.h" -#include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -38,37 +35,11 @@ namespace xla { namespace { using ::testing::HasSubstr; - -HloPrintOptions LeastPrintOptions() { - HloPrintOptions options; - options.set_print_operand_shape(false) - .set_include_layout_in_shapes(false) - .set_print_percent(false); - return options; -} - class CollectiveSelectFolderTest : public HloTestBase { public: - using FixedMapping = - std::initializer_list>; - - absl::StatusOr> RunTranform( - bool expect_changed, std::string_view hlo_template, FixedMapping params) { - std::string hlo_string = absl::StrReplaceAll(hlo_template, params); - SCOPED_TRACE("Input HLO: " + hlo_string); - VLOG(7) << "Input HLO: " << hlo_string; - - TF_ASSIGN_OR_RETURN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_string)); - TF_ASSIGN_OR_RETURN(bool changed, - RunHloPass(CollectiveSelectFolder(), module.get())); - VLOG(7) << "Output HLO: " << module->ToString(LeastPrintOptions()); - EXPECT_EQ(changed, expect_changed); - return module; - } - absl::Status ExpectNoTranform(std::string_view hlo_template) { - return RunTranform(/*expect_changed=*/false, hlo_template, {}).status(); + return RunAndCheckHloRewrite(hlo_template, CollectiveSelectFolder(), false) + .status(); } }; @@ -121,33 +92,39 @@ const char* kSPMD2cp = R"( TEST_F(CollectiveSelectFolderTest, SimpleForwardCycle) { TF_ASSERT_OK_AND_ASSIGN( - auto module, RunTranform(/*expect_changed=*/true, kSPMD2cp, - {{"$first_id_constant", "0"}, - {"$last_id_constant", "3"}, - {"$forward_pairs", "{{0,1},{1,2},{2,3}}"}, - {"$backward_pairs", "{{3,0}}"}})); + auto module, + RunAndCheckHloRewrite(kSPMD2cp, CollectiveSelectFolder(), + /*expect_change=*/true, + {{"$first_id_constant", "0"}, + {"$last_id_constant", "3"}, + {"$forward_pairs", "{{0,1},{1,2},{2,3}}"}, + {"$backward_pairs", "{{3,0}}"}})); VerifyDirectDataFeedSPMD(module.get(), "fwd_data", "bwd_data"); } TEST_F(CollectiveSelectFolderTest, SimpleBackwardCycle) { TF_ASSERT_OK_AND_ASSIGN( - auto module, RunTranform(/*expect_changed=*/true, kSPMD2cp, - {{"$first_id_constant", "3"}, - {"$last_id_constant", "0"}, - {"$forward_pairs", "{{3,2},{2,1},{1,0}}"}, - {"$backward_pairs", "{{0,3}}"}})); + auto module, + RunAndCheckHloRewrite(kSPMD2cp, CollectiveSelectFolder(), + /*expect_change=*/true, + {{"$first_id_constant", "3"}, + {"$last_id_constant", "0"}, + {"$forward_pairs", "{{3,2},{2,1},{1,0}}"}, + {"$backward_pairs", "{{0,3}}"}})); VerifyDirectDataFeedSPMD(module.get(), "fwd_data", "bwd_data"); } TEST_F(CollectiveSelectFolderTest, CompareNEForwardCycle) { TF_ASSERT_OK_AND_ASSIGN( - auto module, RunTranform(/*expect_changed=*/true, kSPMD2cp, - {{"$first_id_constant", "0"}, - {"$last_id_constant", "3"}, - {"$forward_pairs", "{{0,1},{1,2},{2,3}}"}, - {"$backward_pairs", "{{3,0}}"}, - {"direction=EQ", "direction=NE"}})); + auto module, + RunAndCheckHloRewrite(kSPMD2cp, CollectiveSelectFolder(), + /*expect_change=*/true, + {{"$first_id_constant", "0"}, + {"$last_id_constant", "3"}, + {"$forward_pairs", "{{0,1},{1,2},{2,3}}"}, + {"$backward_pairs", "{{3,0}}"}, + {"direction=EQ", "direction=NE"}})); // Compared with SimpleForwardCycle above, this test flips the condition // and therefore the data being forwarded. VerifyDirectDataFeedSPMD(module.get(), "bwd_data", "fwd_data"); @@ -159,11 +136,13 @@ TEST_F(CollectiveSelectFolderTest, CompareNEForwardCycle) { // to the select. TEST_F(CollectiveSelectFolderTest, LastDeviceIdMismatch) { TF_ASSERT_OK_AND_ASSIGN( - auto module, RunTranform(/*expect_changed=*/true, kSPMD2cp, - {{"$first_id_constant", "0"}, - {"$last_id_constant", "2"}, // mismatch - {"$forward_pairs", "{{0,1},{1,2},{2,3}}"}, - {"$backward_pairs", "{{3,0}}"}})); + auto module, + RunAndCheckHloRewrite(kSPMD2cp, CollectiveSelectFolder(), + /*expect_change=*/true, + {{"$first_id_constant", "0"}, + {"$last_id_constant", "2"}, // mismatch + {"$forward_pairs", "{{0,1},{1,2},{2,3}}"}, + {"$backward_pairs", "{{3,0}}"}})); VerifyDirectDataFeedSPMD(module.get(), "data_snd", "fwd_data"); } @@ -183,41 +162,49 @@ const char* kSelectBasecase = R"( )"; TEST_F(CollectiveSelectFolderTest, EqualTrueBranchTransform) { - TF_ASSERT_OK_AND_ASSIGN(auto module, - RunTranform(/*expect_changed=*/true, kSelectBasecase, - {{"$device_id_constant", "3"}, - {"$direction", "EQ"}, - {"$pairs", "{{3,0}}"}})); + TF_ASSERT_OK_AND_ASSIGN( + auto module, + RunAndCheckHloRewrite(kSelectBasecase, CollectiveSelectFolder(), + /*expect_change=*/true, + {{"$device_id_constant", "3"}, + {"$direction", "EQ"}, + {"$pairs", "{{3,0}}"}})); auto root = module->entry_computation()->root_instruction(); EXPECT_EQ(root->operand(0)->name(), "compare_true_data"); } TEST_F(CollectiveSelectFolderTest, EqualFalseBranchTransform) { - TF_ASSERT_OK_AND_ASSIGN(auto module, - RunTranform(/*expect_changed=*/true, kSelectBasecase, - {{"$device_id_constant", "3"}, - {"$direction", "EQ"}, - {"$pairs", "{{0,1},{1,2}}"}})); + TF_ASSERT_OK_AND_ASSIGN( + auto module, + RunAndCheckHloRewrite(kSelectBasecase, CollectiveSelectFolder(), + /*expect_change=*/true, + {{"$device_id_constant", "3"}, + {"$direction", "EQ"}, + {"$pairs", "{{0,1},{1,2}}"}})); auto root = module->entry_computation()->root_instruction(); EXPECT_EQ(root->operand(0)->name(), "compare_false_data"); } TEST_F(CollectiveSelectFolderTest, NotEqualFalseBranchTransform) { - TF_ASSERT_OK_AND_ASSIGN(auto module, - RunTranform(/*expect_changed=*/true, kSelectBasecase, - {{"$device_id_constant", "3"}, - {"$direction", "NE"}, - {"$pairs", "{{3,0}}"}})); + TF_ASSERT_OK_AND_ASSIGN( + auto module, + RunAndCheckHloRewrite(kSelectBasecase, CollectiveSelectFolder(), + /*expect_change=*/true, + {{"$device_id_constant", "3"}, + {"$direction", "NE"}, + {"$pairs", "{{3,0}}"}})); auto root = module->entry_computation()->root_instruction(); EXPECT_EQ(root->operand(0)->name(), "compare_false_data"); } TEST_F(CollectiveSelectFolderTest, NotEqualTrueTrueTransform) { TF_ASSERT_OK_AND_ASSIGN( - auto module, RunTranform(/*expect_changed=*/true, kSelectBasecase, - {{"$device_id_constant", "3"}, - {"$direction", "NE"}, - {"$pairs", "{{0,1},{1,2},{4,5},{5,6}}"}})); + auto module, + RunAndCheckHloRewrite(kSelectBasecase, CollectiveSelectFolder(), + /*expect_change=*/true, + {{"$device_id_constant", "3"}, + {"$direction", "NE"}, + {"$pairs", "{{0,1},{1,2},{4,5},{5,6}}"}})); auto root = module->entry_computation()->root_instruction(); EXPECT_EQ(root->operand(0)->name(), "compare_true_data"); } @@ -225,17 +212,19 @@ TEST_F(CollectiveSelectFolderTest, NotEqualTrueTrueTransform) { TEST_F(CollectiveSelectFolderTest, MoreThanOnePair_NotTransformed) { // The cp contains sources 0 and 1, and therefore doesn't match // equal(1) and not equal(1) - TF_ASSERT_OK(RunTranform(/*expect_changed=*/false, kSelectBasecase, - {{"$device_id_constant", "1"}, - {"$direction", "EQ"}, - {"$pairs", "{{0,1},{1,2}}"}})); + TF_ASSERT_OK(RunAndCheckHloRewrite(kSelectBasecase, CollectiveSelectFolder(), + /*expect_change=*/false, + {{"$device_id_constant", "1"}, + {"$direction", "EQ"}, + {"$pairs", "{{0,1},{1,2}}"}})); // The cp contains sources 0 and 1, and therefore doesn't match // not_equal(1) and not not_equal(1) - TF_ASSERT_OK(RunTranform(/*expect_changed=*/false, kSelectBasecase, - {{"$device_id_constant", "1"}, - {"$direction", "NE"}, - {"$pairs", "{{0,1},{1,2}}"}})); + TF_ASSERT_OK(RunAndCheckHloRewrite(kSelectBasecase, CollectiveSelectFolder(), + /*expect_change=*/false, + {{"$device_id_constant", "1"}, + {"$direction", "NE"}, + {"$pairs", "{{0,1},{1,2}}"}})); } const char* kSelectNoBroadcast = R"( @@ -254,10 +243,12 @@ const char* kSelectNoBroadcast = R"( TEST_F(CollectiveSelectFolderTest, SelectNoBroadcastTransform) { TF_ASSERT_OK_AND_ASSIGN( - auto module, RunTranform(/*expect_changed=*/true, kSelectNoBroadcast, - {{"$device_id_constant", "3"}, - {"$direction", "EQ"}, - {"$pairs", "{{3,0}}"}})); + auto module, + RunAndCheckHloRewrite(kSelectNoBroadcast, CollectiveSelectFolder(), + /*expect_change=*/true, + {{"$device_id_constant", "3"}, + {"$direction", "EQ"}, + {"$pairs", "{{3,0}}"}})); auto root = module->entry_computation()->root_instruction(); EXPECT_EQ(root->operand(0)->name(), "compare_true_data"); } diff --git a/third_party/xla/xla/service/gpu/transforms/collective_send_recv_combiner.cc b/third_party/xla/xla/service/gpu/transforms/collective_send_recv_combiner.cc new file mode 100644 index 00000000000000..198f97b11dd0be --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/collective_send_recv_combiner.cc @@ -0,0 +1,140 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/collective_send_recv_combiner.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "tsl/platform/errors.h" +namespace xla { + +namespace { + +// BuildWrappedComputationForAsyncStart is a side-effecting function that +// returns a clone of the given instruction and populates the async_start_inputs +// and async_start_input_shapes vectors with the operands and operand shapes of +// the cloned instruction. +HloInstruction* BuildWrappedComputationForAsyncStart( + HloComputation::Builder& builder, HloInstruction* instruction, + std::vector& async_start_inputs, + std::vector& async_start_input_shapes) { + int operand_counter = 0; + std::vector operands; + for (auto src_operand : instruction->operands()) { + operands.push_back(builder.AddInstruction(HloInstruction::CreateParameter( + operand_counter, src_operand->shape(), + absl::StrCat("param", operand_counter)))); + async_start_inputs.push_back(src_operand); + async_start_input_shapes.push_back(src_operand->shape()); + ++operand_counter; + } + return builder.AddInstruction( + instruction->CloneWithNewOperands(instruction->shape(), operands)); +} + +absl::Status UpdateControlDependencies(HloInstruction* old_instruction, + HloInstruction* new_instruction) { + if (!old_instruction->HasControlDependencies()) { + return absl::OkStatus(); + } + for (HloInstruction* predecessor : old_instruction->control_predecessors()) { + TF_RETURN_IF_ERROR(predecessor->RemoveControlDependencyTo(old_instruction)); + TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(new_instruction)); + } + for (HloInstruction* successor : old_instruction->control_successors()) { + TF_RETURN_IF_ERROR(old_instruction->RemoveControlDependencyTo(successor)); + TF_RETURN_IF_ERROR(new_instruction->AddControlDependencyTo(successor)); + } + return absl::OkStatus(); +} + +absl::Status CreateAsyncStartAndAsyncDone( + HloInstruction* root, HloComputation::Builder& builder, + HloInstruction* instruction, HloComputation* computation, HloModule* module, + std::vector& async_start_inputs, + std::vector& async_start_input_shapes, bool& changed) { + for (auto instruction_user : instruction->users()) { + if (instruction_user->opcode() != HloOpcode::kSendDone && + instruction_user->opcode() != HloOpcode::kRecvDone) { + // Ignore instruction users that are not send-done or recv-done. + continue; + } + Shape async_start_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeTupleShape(async_start_input_shapes), root->shape(), + ShapeUtil::MakeScalarShape(S32)}); + auto async_start = + computation->AddInstruction(HloInstruction::CreateAsyncStart( + async_start_shape, async_start_inputs, + module->AddEmbeddedComputation(builder.Build(root)))); + auto async_done = computation->AddInstruction( + HloInstruction::CreateAsyncDone(root->shape(), async_start)); + TF_RETURN_IF_ERROR(UpdateControlDependencies(instruction, async_start)); + TF_RETURN_IF_ERROR(UpdateControlDependencies(instruction_user, async_done)); + TF_RETURN_IF_ERROR( + instruction_user->ReplaceAllUsesWithDifferentShape(async_done)); + TF_RETURN_IF_ERROR( + instruction_user->parent()->RemoveInstruction(instruction_user)); + TF_RETURN_IF_ERROR( + instruction->ReplaceAllUsesWithDifferentShape(async_start)); + TF_RETURN_IF_ERROR(instruction->parent()->RemoveInstruction(instruction)); + changed = true; + } + return absl::OkStatus(); +} + +} // namespace + +absl::StatusOr CollectiveSendRecvCombiner::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + bool changed = false; + int wrapped_computation_index = 0; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kSend && + instruction->opcode() != HloOpcode::kRecv) { + continue; + } + + // Create a new computation that wraps the send/recv instruction. + ++wrapped_computation_index; + auto builder = HloComputation::Builder(absl::StrCat( + "wrapped_", instruction->name(), wrapped_computation_index)); + std::vector async_start_inputs; + std::vector async_start_input_shapes; + auto root = BuildWrappedComputationForAsyncStart( + builder, instruction, async_start_inputs, async_start_input_shapes); + + TF_RETURN_IF_ERROR(CreateAsyncStartAndAsyncDone( + root, builder, instruction, computation, module, async_start_inputs, + async_start_input_shapes, changed)); + } + } + return changed; +} + +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/collective_send_recv_combiner.h b/third_party/xla/xla/service/gpu/transforms/collective_send_recv_combiner.h new file mode 100644 index 00000000000000..de8a88517035d2 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/collective_send_recv_combiner.h @@ -0,0 +1,47 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_SEND_RECV_COMBINER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_SEND_RECV_COMBINER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { + +// CollectiveSendRecvCombiner is a pass that scans for all send/recv pairs +// which are part of the same computation, and transforms them into wrapped +// single-op computations that are executed asynchronously. This pass also +// replaces the corresponding send-done and recv-done instructions with +// async-done functions. This pass is primarily used for pipelining send/recv +// and send-done/recv-done instructions across while loop iteration boundaries. +class CollectiveSendRecvCombiner : public HloModulePass { + public: + absl::string_view name() const override { + return "collective-send-recv-combiner"; + } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; +} // namespace xla + +#endif // XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_SEND_RECV_COMBINER_H_ diff --git a/third_party/xla/xla/service/gpu/transforms/collective_send_recv_combiner_test.cc b/third_party/xla/xla/service/gpu/transforms/collective_send_recv_combiner_test.cc new file mode 100644 index 00000000000000..7f8b36ae4831ff --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/collective_send_recv_combiner_test.cc @@ -0,0 +1,164 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/collective_send_recv_combiner.h" + +#include + +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/tests/filecheck.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +using CollectiveSendRecvCombinerTest = HloTestBase; + +// TODO: b/372132451 - add unit test in collective send/recv combiner to check +// control dependencies + +TEST_F(CollectiveSendRecvCombinerTest, TransformedNoFrontEndAttr) { + const char* kHloStr = R"( + ENTRY main { + data = f32[] constant(5) + recv-start = token[] after-all() + recv = (f32[], u32[], token[]) recv(recv-start), channel_id=1 + send = (f32[], u32[], token[]) send(data, recv-start), channel_id=1 + recv-done = (f32[], token[]) recv-done(recv), channel_id=1 + send-done = token[] send-done(send), channel_id=1 + ROOT out = f32[] get-tuple-element(recv-done), index=0 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule((kHloStr))); + CollectiveSendRecvCombiner combiner; + TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_TRUE(RunFileCheck(module->ToString(), R"( + CHECK: ENTRY %[[MAIN:.*]] () -> f32[] { + CHECK: %[[RECV_START:.*]] = token[] after-all() + CHECK: %[[RECV_ASYNC:.*]] = ((token[]), (f32[], u32[], token[]), s32[]) + recv-start(token[] %[[RECV_START:.*]]), channel_id=1 + CHECK: %[[RECV_DONE:.*]] = (f32[], u32[], token[]) + recv-done(((token[]), (f32[], u32[], token[]), s32[]) %[[RECV_ASYNC:.*]]) + CHECK: ROOT %[[OUT:.*]] = f32[] get-tuple-element((f32[], u32[], token[]) + %[[RECV_DONE:.*]]), index=0 + CHECK: %[[DATA:.*]] = f32[] constant(5) + CHECK: %[[SEND_ASYNC:.*]] = ((f32[], token[]), (f32[], u32[], token[]), s32[]) + send-start(f32[] %[[DATA]], token[] %[[RECV_ASYNC:.*]]) + CHECK: %[[SEND_DONE:.*]] = (f32[], u32[], token[]) + send-done(((f32[], token[]), (f32[], u32[], token[]), s32[]) %[[SEND_ASYNC:.*]]) + )") + .value()); +} + +TEST_F(CollectiveSendRecvCombinerTest, TrivialNoTransform) { + const char* kHloStr = R"( + ENTRY main { + zero = f32[] constant(0) + five = f32[] constant(5) + ROOT out = f32[] add(zero, five) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule((kHloStr))); + CollectiveSendRecvCombiner combiner; + TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(CollectiveSendRecvCombinerTest, PartiallyPipelinedSendRecvNoTransform) { + const char* const kModuleStr = R"( + HloModule test + + while_body { + param = ((f32[16], u32[], token[]), f32[16]) parameter(0) + prev_send = (f32[16], u32[], token[]) get-tuple-element(param), index=0 + data = f32[16] get-tuple-element(param), index=1 + send_done = (f32[16], token[]) send-done(prev_send), channel_id=1 + after_all = token[] after-all() + send = (f32[16], u32[], token[]) send(data, after_all), channel_id=1, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + ROOT tuple = ((f32[16], u32[], token[]), f32[16]) tuple(send, data) + } + + // Infinite loop to keep IR small. + while_condition { + param = ((f32[16], u32[], token[]), f32[16]) parameter(0) + ROOT infinite_loop = pred[] constant(true) + } + + ENTRY main_spmd { + data = f32[16] parameter(0) + after_all = token[] after-all() + send = (f32[16], u32[], token[]) send(data, after_all), channel_id=1, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + init = ((f32[16], u32[], token[]), f32[16]) tuple(send, data) + while = ((f32[16], u32[], token[]), f32[16]) while(init), + condition=while_condition, body=while_body + send_ctx = (f32[16], u32[], token[]) get-tuple-element(while), index=0 + ROOT send_done = (f32[16], token[]) send-done(send_ctx), channel_id=1 + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule((kModuleStr))); + CollectiveSendRecvCombiner combiner; + TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(CollectiveSendRecvCombinerTest, TransformedWithControlDependency) { + const char* kHloStr = R"( + ENTRY main { + data = f32[] constant(5) + recv-start = token[] after-all() + recv = (f32[], u32[], token[]) recv(recv-start), channel_id=1 + send = (f32[], u32[], token[]) send(data, recv-start), channel_id=1 + recv-done = (f32[], token[]) recv-done(recv), channel_id=1, + control-predecessors={send} + send-done = token[] send-done(send), channel_id=1 + ROOT out = f32[] get-tuple-element(recv-done), index=0 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule((kHloStr))); + CollectiveSendRecvCombiner combiner; + TF_ASSERT_OK_AND_ASSIGN(bool changed, combiner.Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_TRUE(RunFileCheck(module->ToString(), R"( + CHECK: ENTRY %[[MAIN:.*]] () -> f32[] { + CHECK: %[[DATA:.*]] = f32[] constant(5) + CHECK: %[[RECV_START:.*]] = token[] after-all() + CHECK: %[[SEND_ASYNC:.*]] = ((f32[], token[]), (f32[], u32[], token[]), s32[]) + send-start(f32[] %[[DATA]], token[] %[[RECV_START:.*]]) + CHECK: %[[RECV_ASYNC:.*]] = ((token[]), (f32[], u32[], token[]), s32[]) + recv-start(token[] %[[RECV_START:.*]]), channel_id=1 + CHECK: %[[RECV_DONE:.*]] = (f32[], u32[], token[]) + recv-done(((token[]), (f32[], u32[], token[]), s32[]) %[[RECV_ASYNC:.*]]) + CHECK: ROOT %[[OUT:.*]] = f32[] get-tuple-element((f32[], u32[], token[]) + %[[RECV_DONE:.*]]), index=0 + CHECK: %[[SEND_DONE:.*]] = (f32[], u32[], token[]) + send-done(((f32[], token[]), (f32[], u32[], token[]), s32[]) %[[SEND_ASYNC:.*]]) + )") + .value()); +} +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc index c772ef052b5a1a..5391dc8031eff7 100644 --- a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc @@ -43,7 +43,6 @@ limitations under the License. #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/variant_visitor.h" #include "xla/shape.h" @@ -112,12 +111,14 @@ static bool IsAsyncStartCommand(const HloInstruction* hlo, if (hlo->async_wrapped_opcode() == HloOpcode::kFusion) { return config.enabled_commands.contains(DebugOptions::FUSION); } - if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter) { + if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter || + hlo->async_wrapped_opcode() == HloOpcode::kAllToAll) { return config.enabled_commands.contains(DebugOptions::COLLECTIVES); } } - if (hlo->opcode() == HloOpcode::kReduceScatter) { + if (hlo->opcode() == HloOpcode::kReduceScatter || + hlo->opcode() == HloOpcode::kAllToAll) { return config.enabled_commands.contains(DebugOptions::COLLECTIVES); } @@ -138,7 +139,8 @@ static bool IsAsyncDoneCommand(const HloInstruction* hlo, if (hlo->async_wrapped_opcode() == HloOpcode::kFusion) { return config.enabled_commands.contains(DebugOptions::FUSION); } - if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter) { + if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter || + hlo->async_wrapped_opcode() == HloOpcode::kAllToAll) { return config.enabled_commands.contains(DebugOptions::COLLECTIVES); } } @@ -241,7 +243,8 @@ static bool IsCommand(const HloInstruction* hlo, return config.enabled_commands.contains(DebugOptions::CUDNN); } const auto& custom_config = backend_config.custom_fusion_config(); - if (custom_config.name() == "address_computation") { + if ((custom_config.name() == "address_computation") || + (custom_config.name() == "dynamic_address_computation")) { auto fusion_analysis = HloFusionAnalysis::Create(*hlo, config.device_description); const HloFusionAdaptor& adaptor = fusion_analysis.fusion(); @@ -251,10 +254,16 @@ static bool IsCommand(const HloInstruction* hlo, node.opcode() == HloOpcode::kReduceScatter; }); const HloInstruction* hero = &hero_adaptor->instruction(); - return IsCommand(hero, config) || IsAsyncStartCommand(hero, config); - } - if (custom_config.name() == "dynamic_address_computation") { - return false; + + if (custom_config.name() == "address_computation") { + return IsCommand(hero, config) || IsAsyncStartCommand(hero, config); + } else { + // DynamicSliceFusionRewriter currently only rewrites for dynamic slice + // fusion with constant or loop iteration offset values, which are all + // supported by command buffer. + return (config.enabled_commands.contains(DebugOptions::DYNAMIC_SLICE) && + (IsCommand(hero, config) || IsAsyncStartCommand(hero, config))); + } } return config.enabled_commands.contains(DebugOptions::FUSION); } @@ -334,6 +343,40 @@ CommandBufferScheduling::CollectCommandBufferSequences( auto& instructions = schedule.instructions(); + // we currently require that when lowering DynamicSliceFusion, the offset + // value should not come from the output of operators that are already + // captured in command buffer. + auto check_dynamic_slice_operand_not_from_seq = + [&](const HloInstructionSequence& seq, const HloInstruction* inst) { + if (!config.enabled_commands.contains(DebugOptions::DYNAMIC_SLICE)) + return true; + const auto* fusion = DynCast(inst); + if (!fusion) return true; + + auto gpu_config = fusion->backend_config(); + const FusionBackendConfig& backend_config = + gpu_config->fusion_backend_config(); + const auto& custom_config = backend_config.custom_fusion_config(); + if (custom_config.name() != "dynamic_address_computation") return true; + + auto* fused_computation = fusion->called_computation(); + return !absl::c_any_of( + fused_computation->instructions(), [&](const HloInstruction* inst) { + const auto* dynamic_inst = + DynCast(inst); + if (!dynamic_inst) return false; + for (auto* operand : dynamic_inst->index_operands()) { + const auto* param = DynCast(operand); + const auto* fusion_operand = + fusion->operand(param->parameter_number()); + if (seq.contains(fusion_operand)) { + return true; + } + } + return false; + }); + }; + // Collect the sequence of instructions that contains the async start and its // corresponding done instruction. If there is another start instruction // between the original start and done, we may potentially extend the sequence @@ -370,7 +413,11 @@ CommandBufferScheduling::CollectCommandBufferSequences( // we do not capture unmatched async done instruction. auto check_async_region = [&](const HloInstructionSequence& seq) { if (!absl::c_all_of(seq.instructions(), [&](HloInstruction* inst) { - return IsNoOp(inst) || IsCommand(inst, config) || + return IsNoOp(inst) || + (IsCommand(inst, config) && + check_dynamic_slice_operand_not_from_seq(seq, inst) && + check_dynamic_slice_operand_not_from_seq(current_seq, + inst)) || IsAsyncStartCommand(inst, config) || IsAsyncDoneCommand(inst, config); })) { @@ -404,7 +451,8 @@ CommandBufferScheduling::CollectCommandBufferSequences( } // Synchronous commands always can be added to instruction sequence. - if (IsCommand(inst, config)) { + if (IsCommand(inst, config) && + check_dynamic_slice_operand_not_from_seq(current_seq, inst)) { num_commands_in_current_seq++; current_seq.push_back(inst); continue; @@ -438,7 +486,9 @@ CommandBufferScheduling::CollectCommandBufferSequences( // the beginning of the computation. This simplifies the construction of command // buffer computations because we don't need to deal with parameters and // constants that have users outside of a command buffer. -absl::Status CommandBufferScheduling::MoveParametersAndConstantsToFront( +// Returns true if there is a change in the order of instructions, false +// otherwise. +absl::StatusOr CommandBufferScheduling::MoveParametersAndConstantsToFront( HloComputation* computation) { HloInstructionSequence new_sequence; HloSchedule& schedule = computation->parent()->schedule(); @@ -468,7 +518,11 @@ absl::Status CommandBufferScheduling::MoveParametersAndConstantsToFront( } schedule.set_sequence(computation, new_sequence); - return absl::OkStatus(); + for (auto [old_i, new_i] : + llvm::zip(sequence.instructions(), new_sequence.instructions())) { + if (old_i != new_i) return true; + } + return false; } //===----------------------------------------------------------------------===// @@ -782,6 +836,7 @@ absl::StatusOr CommandBufferScheduling::Run( std::reverse(order.begin(), order.end()); absl::flat_hash_set processed_command_buffers; + auto changed = false; for (HloComputation* comp : order) { // Skip special computations that do not have lowering to thunks. if (comp->IsFusionComputation() || comp->IsAsyncComputation() || @@ -791,7 +846,8 @@ absl::StatusOr CommandBufferScheduling::Run( // Skip computations that already part of command buffers. if (processed_command_buffers.contains(comp)) continue; - TF_RETURN_IF_ERROR(MoveParametersAndConstantsToFront(comp)); + TF_ASSIGN_OR_RETURN(bool changed_, MoveParametersAndConstantsToFront(comp)); + changed |= changed_; std::vector sequences = CollectCommandBufferSequences( @@ -804,6 +860,7 @@ absl::StatusOr CommandBufferScheduling::Run( TF_ASSIGN_OR_RETURN( HloComputation * command_buffer_computation, RewriteCommandBuffer(comp, seq, std::move(command_buffer))); + changed = true; // All computations reachable from a command buffer computation are nested // command buffers (i.e. body computations attached to a while operation). @@ -815,7 +872,7 @@ absl::StatusOr CommandBufferScheduling::Run( } TF_RETURN_IF_ERROR(module->schedule().Update()); - return true; + return changed; } } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.h b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.h index 15f0b2dd4d4da9..71d5b421c1ee56 100644 --- a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.h +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.h @@ -99,7 +99,9 @@ class CommandBufferScheduling : public HloModulePass { // the beginning of the computation. This simplifies the construction of // command buffer computations because we don't need to deal with parameters // and constants that have users outside of a command buffer. - static absl::Status MoveParametersAndConstantsToFront( + // Returns true if there is a change in the order of instructions, false + // otherwise. + static absl::StatusOr MoveParametersAndConstantsToFront( HloComputation* computation); struct CommandBuffer { diff --git a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc index ab04e5e6441130..b58b75be7462f1 100644 --- a/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc @@ -23,9 +23,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/gpu_executable.h" -#include "xla/service/hlo_parser.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" @@ -43,7 +43,7 @@ class CommandBufferSchedulingTest : public HloTestBase { return TestGpuDeviceInfo::CudaOrRocmDeviceInfo(); } - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { auto debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION); debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CONDITIONALS); @@ -56,8 +56,15 @@ class CommandBufferSchedulingTest : public HloTestBase { } const se::GpuComputeCapability& GetGpuComputeCapability() { +<<<<<<< HEAD return backend().default_stream_executor() ->GetDeviceDescription().gpu_compute_capability(); +======= + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(); +>>>>>>> master } }; @@ -1085,6 +1092,35 @@ TEST_F(CommandBufferSchedulingTest, AsyncFusion) { }); } +TEST_F(CommandBufferSchedulingTest, AsyncAlltoAll) { + const char* hlo = R"( + HloModule m, is_scheduled=true + + async_computation.1 { + param.1 = f32[4,8,128]{2,1,0} parameter(0) + ROOT all-to-all.1 = f32[4,8,128]{2,1,0} all-to-all(param.1), channel_id=1, dimensions={1} + } + + ENTRY main { + param.0 = f32[4,8,128]{2,1,0} parameter(0) + all-to-all-start = ((f32[4,8,128]{2,1,0}), f32[4,8,128]{2,1,0}) async-start(param.0), calls=async_computation.1 + ROOT all-to-all-done = f32[4,8,128]{2,1,0} async-done(all-to-all-start) + })"; + + const char* expected = R"( + CHECK: %command_buffer ([[P:.+]]: f32[4,8,128]) -> f32[4,8,128] { + CHECK: %[[P]] = f32[4,8,128]{2,1,0} parameter(0) + CHECK: %[[S1:.+]] = ((f32[4,8,128]{2,1,0}), f32[4,8,128]{2,1,0}) all-to-all-start(%[[P]]), channel_id=1, replica_groups={}, dimensions={1} + CHECK: ROOT {{.*}} = f32[4,8,128]{2,1,0} all-to-all-done(%[[S1]]) + CHECK: })"; + + RunAndFilecheckHloRewrite(hlo, CommandBufferScheduling(device_desc()), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); +} + TEST_F(CommandBufferSchedulingTest, DynamicSliceFusionDynamicSlicing) { if (backend().platform()->Name() == "Host") { GTEST_SKIP() << "GPU support required for this test"; @@ -1108,8 +1144,14 @@ TEST_F(CommandBufferSchedulingTest, DynamicSliceFusionDynamicSlicing) { rs = s32[4,32] reduce-scatter(input), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add ROOT dus = s32[8,32] dynamic-update-slice(p1, rs, c0, c0) })"; + TF_ASSERT_OK_AND_ASSIGN(auto original_module, + ParseAndReturnVerifiedModule(hlo)); + DebugOptions& original_options = + original_module->mutable_config().mutable_debug_options(); + original_options.set_xla_gpu_enable_dynamic_slice_fusion(true); - TF_ASSERT_OK_AND_ASSIGN(auto m, GetOptimizedModule(hlo)); + TF_ASSERT_OK_AND_ASSIGN(auto m, + GetOptimizedModule(std::move(original_module))); HloModuleConfig config(m->config()); DebugOptions options(config.debug_options()); @@ -1241,5 +1283,52 @@ TEST_F(CommandBufferSchedulingTest, DynamicSliceFusionStaticSlicing) { false, true, std::nullopt)); } +TEST_F(CommandBufferSchedulingTest, ReturnFalseWhenNoChange) { + const char* hlo = R"( + HloModule module, is_scheduled=true + ENTRY main { + a = s32[8,8] parameter(0) + b = s32[8,8] parameter(1) + ROOT call = s32[8,8] custom-call(a,b), custom_call_target="__cublas$gemm" + } + )"; + + HloModuleConfig config; + DebugOptions options = GetDebugOptionsForTest(); + options.clear_xla_gpu_enable_command_buffer(); + options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES); + config.set_debug_options(options); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo, config)); + RunAndFilecheckHloRewrite(hlo, CommandBufferScheduling(device_desc()), + std::nullopt); +} + +TEST_F(CommandBufferSchedulingTest, ReturnTrueWhenOnlyParamMoved) { + const char* hlo = R"( + HloModule module, is_scheduled=true + ENTRY main { + a = s32[8,8] parameter(0) + b = s32[8,8] parameter(1) + call = s32[8,8] custom-call(a,b), custom_call_target="__cublas$gemm" + c = s32[8,8] parameter(2) + ROOT call2 = s32[8,8] custom-call(call, c), custom_call_target="__cublas$gemm" + } + )"; + + HloModuleConfig config; + DebugOptions options = GetDebugOptionsForTest(); + options.clear_xla_gpu_enable_command_buffer(); + options.add_xla_gpu_enable_command_buffer(DebugOptions::COLLECTIVES); + config.set_debug_options(options); + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo, config)); + RunAndFilecheckHloRewrite(hlo, CommandBufferScheduling(device_desc()), R"( + // CHECK: %{{.+}} = {{.+}} parameter(0) + // CHECK: %{{.+}} = {{.+}} parameter(1) + // CHECK: %{{.+}} = {{.+}} parameter(2) + // CHECK: %{{.+}} = {{.+}} custom-call + // CHECK: %{{.+}} = {{.+}} custom-call + )"); +} + } // namespace } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.h b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.h index 6507080a5fa49b..85ae15ada81142 100644 --- a/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.h +++ b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.h @@ -23,7 +23,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/convert_async_collectives_to_sync.h" +#include "xla/hlo/transforms/collectives/convert_async_collectives_to_sync.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms_test.cc b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms_test.cc index 77a32c935e412e..45c47fa8128658 100644 --- a/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms_test.cc @@ -39,7 +39,7 @@ class CublasGemmPadForTensorCoresTest : public HloTestBase { .value(); } - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); // Some pads would not be added if we detect that Triton will handle the // given dot operation. diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc index 1c01b3f47cd878..9f7668c9e226bd 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc @@ -199,7 +199,7 @@ absl::StatusOr HloCustomCallToCuDnnGraph( dnn_support, lhs_bmm1, rhs_bmm1, rhs_bmm2, output, activation, static_cast(config.fmha_scale()), dnn_mask_type)); return std::move(graph); - } else { + } else if (IsBwdCustomCallTofMHA(*custom_call)) { TF_ASSIGN_OR_RETURN( auto gpu_config, custom_call->backend_config()); @@ -314,6 +314,75 @@ absl::StatusOr HloCustomCallToCuDnnGraph( config.fmha_scale(), dropout_rate > 0.0, bias != std::nullopt, dnn_mask_type, force_deterministic, sliding_window_length)); return std::move(graph); + } else { + TF_ASSIGN_OR_RETURN( + auto gpu_config, + custom_call->backend_config()); + xla::gpu::CudnnfMHABackendConfig &config = + *gpu_config.mutable_cudnn_fmha_backend_config(); + + Shape bmm1_grad_gemm1_rhs_shape = custom_call->operand(0)->shape(); + Shape bmm1_grad_gemm2_rhs_shape = custom_call->operand(1)->shape(); + Shape bmm2_grad_gemm2_rhs_shape = custom_call->operand(2)->shape(); + + Shape fwd_output_shape = custom_call->operand(3)->shape(); + Shape d_output_shape = custom_call->operand(4)->shape(); + + Shape bmm2_grad_gemm1_lhs_shape(config.intermediate_tensor_shape()); + + Shape d_bmm1_lhs_shape = ShapeUtil::GetSubshape(custom_call->shape(), {0}); + Shape d_bmm1_rhs_shape = ShapeUtil::GetSubshape(custom_call->shape(), {1}); + Shape d_bmm2_rhs_shape = ShapeUtil::GetSubshape(custom_call->shape(), {2}); + + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor bmm1_grad_gemm1_rhs, + MatmulTensorDescriptorFor( + bmm1_grad_gemm1_rhs_shape, + config.bmm1_grad_gemm1_dot_dimension_numbers(), RHS)); + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor bmm1_grad_gemm2_rhs, + MatmulTensorDescriptorFor( + bmm1_grad_gemm2_rhs_shape, + config.bmm1_grad_gemm2_dot_dimension_numbers(), RHS)); + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor bmm2_grad_gemm1_lhs, + MatmulTensorDescriptorFor( + bmm2_grad_gemm1_lhs_shape, + config.bmm2_grad_gemm1_dot_dimension_numbers(), LHS)); + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor bmm2_grad_gemm2_rhs, + MatmulTensorDescriptorFor( + bmm2_grad_gemm2_rhs_shape, + config.bmm2_grad_gemm2_dot_dimension_numbers(), RHS)); + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor d_output, + MatmulTensorDescriptorFor( + d_output_shape, config.bmm2_grad_gemm1_dot_dimension_numbers(), + RHS)); + + TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm1_lhs, + TensorDescriptorFor(d_bmm1_lhs_shape)); + TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm1_rhs, + TensorDescriptorFor(d_bmm1_rhs_shape)); + TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm2_rhs, + TensorDescriptorFor(d_bmm2_rhs_shape)); + // 3 gradients, 4 amaxs and one workspace + TF_RET_CHECK(8 == custom_call->shape().tuple_shapes().size()); + + TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config)); + + TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, + AsCudnnFmhaMaskKind(config.mask_type())); + TF_ASSIGN_OR_RETURN( + se::dnn::FMHAMaskKind dnn_mask_type, + GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(cudnn_mask_type)); + TF_ASSIGN_OR_RETURN( + se::gpu::CudnnGraph graph, + se::gpu::GetCudnnFlashAttentionBackwardF8OperationGraph( + dnn_support, bmm1_grad_gemm1_rhs, bmm1_grad_gemm2_rhs, + bmm2_grad_gemm1_lhs, bmm2_grad_gemm2_rhs, d_output, d_bmm1_lhs, + d_bmm1_rhs, d_bmm2_rhs, config.fmha_scale(), dnn_mask_type)); + return std::move(graph); } } diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc index 7e9e3710569406..c4ebb27d62ab71 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc @@ -35,18 +35,18 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/pass/hlo_pass_fix.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" -#include "xla/service/algebraic_simplifier.h" -#include "xla/service/convert_mover.h" +#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" +#include "xla/hlo/transforms/simplifiers/convert_mover.h" +#include "xla/hlo/transforms/simplifiers/hlo_constant_folding.h" +#include "xla/hlo/transforms/simplifiers/reshape_mover.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/service/gpu/transforms/conv_rewriter.h" -#include "xla/service/hlo_constant_folding.h" #include "xla/service/hlo_module_config.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/service/reshape_mover.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/semantic_version.h" @@ -67,9 +67,9 @@ namespace m = match; using ::testing::HasSubstr; using ::testing::Not; -constexpr std::initializer_list kf16f32f64{"f16", "f32", - "f64"}; -constexpr std::initializer_list kf16f32{"f16", "f32"}; +static const std::initializer_list kf16f32f64{"f16", "f32", + "f64"}; +static const std::initializer_list kf16f32{"f16", "f32"}; class CudnnFusedConvRewriterHloTest : public HloTestBase { public: diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter.cc index b7bce9d4ea2549..89a0408f29e8eb 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter.cc @@ -41,7 +41,7 @@ limitations under the License. #include "xla/permutation_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/matmul_indexing_utils.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/service/pattern_matcher.h" #include "xla/shape.h" diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter_test.cc index a64fd0624bea62..f61987cf88e67e 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter_test.cc @@ -26,20 +26,20 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/algebraic_simplifier.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/transforms/expanders/reshape_decomposer.h" +#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/service/computation_layout.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.h" #include "xla/service/hlo_cse.h" -#include "xla/service/hlo_dce.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" #include "xla/service/hlo_verifier.h" #include "xla/service/layout_normalization.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/service/reshape_decomposer.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/dnn.h" #include "xla/test_helpers.h" @@ -105,7 +105,7 @@ class CudnnFusedMhaRewriterTestHloTest : public HloTestBase { }); } - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { auto debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_gpu_enable_cudnn_fmha(true); debug_options.set_xla_gpu_fused_attention_use_cudnn_rng(true); diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.cc index 7299643818f5c7..9315cb85ff6c29 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.cc @@ -30,7 +30,7 @@ limitations under the License. #include "xla/permutation_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/matmul_indexing_utils.h" #include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc index e64cc8409acf21..4b024154669f8a 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc @@ -48,7 +48,7 @@ limitations under the License. #include "xla/service/gpu/cudnn_support_utils.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/kernel_reuse_cache.h" -#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/matmul_indexing_utils.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.cc index 5d5e089933fd88..28270b163ffab5 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.cc @@ -43,12 +43,12 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "xla/types.h" #include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/dnn.pb.h" #if GOOGLE_CUDA #include "third_party/gpus/cuda/include/cuda.h" // IWYU pragma: keep @@ -113,15 +113,19 @@ using NormMetadataMap = absl::flat_hash_map; // HloInstruction: // UniqueHloInstruction x; // bool m = Match( -// instr, m::Divide(m::Cos(m::Op().WithPredicate(x.capture_and_verify)), -// m::Sin(m::Op().WithPredicate(x.capture_and_verify)))); +// instr, m::Divide(m::Cos(m::Op().WithPredicate(x.CaptureOrVerifyFn())), +// m::Sin(m::Op().WithPredicate(x.CaptureOrVerifyFn())))); // m is true and x.Instr() returns an HloInstruction pointer to the operand of // cosine and sine iff HloInstruction *instr points to a division of a cosine by // a sine that operate on the same instruction. class UniqueHloInstruction { public: UniqueHloInstruction() - : is_set_(false), instr_(nullptr), capture_or_verify_() {} + : is_set_(false), + instr_(nullptr), + capture_or_verify_([this](const HloInstruction* instr) -> bool { + return CaptureOrVerify(const_cast(instr)); + }) {} HloInstruction* Instr() const { return instr_; } void SetInstr(HloInstruction* instr) { is_set_ = true; @@ -143,12 +147,7 @@ class UniqueHloInstruction { // Returns a std::function for capturing or verifying an instruction using // WithPredicate. - std::function GetCaptureOrVerifyFn() { - if (!capture_or_verify_) { - capture_or_verify_ = [this](const HloInstruction* instr) -> bool { - return CaptureOrVerify(const_cast(instr)); - }; - } + std::function CaptureOrVerifyFn() const { return capture_or_verify_; } @@ -597,7 +596,7 @@ auto Expectation(UniqueHloInstruction* expectation, Pattern pattern) { .WithPredicate([](const HloInstruction* instr) { return CalculatesExpectation(instr); }) - .WithPredicate(expectation->GetCaptureOrVerifyFn())); + .WithPredicate(expectation->CaptureOrVerifyFn())); return m::AnyOf(m::Broadcast(shared_subpattern), shared_subpattern); } @@ -612,7 +611,7 @@ auto Expectation(UniqueHloInstruction* expectation, HloInstruction** reduce, .WithPredicate([](const HloInstruction* instr) { return CalculatesExpectation(instr); }) - .WithPredicate(expectation->GetCaptureOrVerifyFn())); + .WithPredicate(expectation->CaptureOrVerifyFn())); return m::AnyOf(m::Broadcast(shared_subpattern), shared_subpattern); } @@ -624,19 +623,19 @@ auto Variance(UniqueHloInstruction* variance, UniqueHloInstruction* expectation, return m::AnyOf( Subtract( Expectation(Square(OptionalSupportedTransform( - m::Op().WithPredicate(x->GetCaptureOrVerifyFn())))), - Square(Expectation(expectation, - OptionalSupportedTransform(m::Op().WithPredicate( - x->GetCaptureOrVerifyFn()))))) - .WithPredicate(variance->GetCaptureOrVerifyFn()), + m::Op().WithPredicate(x->CaptureOrVerifyFn())))), + Square(Expectation( + expectation, OptionalSupportedTransform( + m::Op().WithPredicate(x->CaptureOrVerifyFn()))))) + .WithPredicate(variance->CaptureOrVerifyFn()), Expectation( Square(Subtract( OptionalSupportedTransform( - m::Op().WithPredicate(x->GetCaptureOrVerifyFn())), + m::Op().WithPredicate(x->CaptureOrVerifyFn())), Expectation(expectation, - OptionalSupportedTransform(m::Op().WithPredicate( - x->GetCaptureOrVerifyFn())))))) - .WithPredicate(variance->GetCaptureOrVerifyFn())); + OptionalSupportedTransform( + m::Op().WithPredicate(x->CaptureOrVerifyFn())))))) + .WithPredicate(variance->CaptureOrVerifyFn())); } // Reciprocal of the square root of variance + epsilon with optional broadcast. @@ -647,7 +646,7 @@ auto NormFactor(HloInstruction** norm_factor, UniqueHloInstruction* x, auto shared_subpattern = m::SharedSubpattern(Rsqrt( norm_factor, AddAnyOrder(Variance(variance, expectation, x), m::Broadcast(m::ConstantScalar().WithPredicate( - epsilon->GetCaptureOrVerifyFn()))))); + epsilon->CaptureOrVerifyFn()))))); return m::AnyOf(m::Broadcast(shared_subpattern), shared_subpattern); } @@ -696,10 +695,10 @@ auto SubtractMultiplyAddAnyOrder(P0 p0, P1 p1, P2 p2, P3 p3, P4 p4) { // Expectation fused into a layer norm Custom Call. auto FusedExpectation(UniqueHloInstruction* custom_call) { - auto shared_subpattern = m::SharedSubpattern(m::GetTupleElement( - m::CustomCall({kCudnnNormCallTarget}) - .WithPredicate(custom_call->GetCaptureOrVerifyFn()), - 1)); + auto shared_subpattern = m::SharedSubpattern( + m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget}) + .WithPredicate(custom_call->CaptureOrVerifyFn()), + 1)); return m::AnyOf(shared_subpattern, BitcastOrReshape(shared_subpattern)); } @@ -708,21 +707,20 @@ auto FusedExpectation(UniqueHloInstruction* custom_call) { auto FusedExpectation(UniqueHloInstruction* fused_expectation, UniqueHloInstruction* custom_call) { auto shared_subpattern = m::SharedSubpattern( - m::GetTupleElement( - m::CustomCall({kCudnnNormCallTarget}) - .WithPredicate(custom_call->GetCaptureOrVerifyFn()), - 1) - .WithPredicate(fused_expectation->GetCaptureOrVerifyFn())); + m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget}) + .WithPredicate(custom_call->CaptureOrVerifyFn()), + 1) + .WithPredicate(fused_expectation->CaptureOrVerifyFn())); return m::AnyOf(shared_subpattern, BitcastOrReshape(shared_subpattern)); } // Norm factor fused into a layer norm Custom Call. auto FusedNormFactor(UniqueHloInstruction* custom_call) { - auto shared_subpattern = m::SharedSubpattern(m::GetTupleElement( - m::CustomCall({kCudnnNormCallTarget}) - .WithPredicate(custom_call->GetCaptureOrVerifyFn()), - 2)); + auto shared_subpattern = m::SharedSubpattern( + m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget}) + .WithPredicate(custom_call->CaptureOrVerifyFn()), + 2)); return m::AnyOf(shared_subpattern, BitcastOrReshape(shared_subpattern)); } @@ -731,11 +729,10 @@ auto FusedNormFactor(UniqueHloInstruction* custom_call) { auto FusedNormFactor(UniqueHloInstruction* fused_norm_factor, UniqueHloInstruction* custom_call) { auto shared_subpattern = m::SharedSubpattern( - m::GetTupleElement( - m::CustomCall({kCudnnNormCallTarget}) - .WithPredicate(custom_call->GetCaptureOrVerifyFn()), - 2) - .WithPredicate(fused_norm_factor->GetCaptureOrVerifyFn())); + m::GetTupleElement(m::CustomCall({kCudnnNormCallTarget}) + .WithPredicate(custom_call->CaptureOrVerifyFn()), + 2) + .WithPredicate(fused_norm_factor->CaptureOrVerifyFn())); return m::AnyOf(shared_subpattern, BitcastOrReshape(shared_subpattern)); } @@ -784,7 +781,7 @@ auto XCenter(UniqueHloInstruction* x_center, UniqueHloInstruction* x, }; return Subtract(m::Op(), m::Broadcast(FusedExpectation(fused_expectation, custom_call))) - .WithPredicate(x_center->GetCaptureOrVerifyFn()) + .WithPredicate(x_center->CaptureOrVerifyFn()) .WithPredicate(capture_or_verify_x); } @@ -806,7 +803,7 @@ auto F0(UniqueHloInstruction* custom_call, UniqueHloInstruction* scale, reduce, MultiplyMultiplyAnyOrder( XCenter(x, custom_call, norm_metadata), m::Broadcast(m::Op().WithPredicate(capture_or_verify_scale)), - m::Op().WithPredicate(dy->GetCaptureOrVerifyFn()))); + m::Op().WithPredicate(dy->CaptureOrVerifyFn()))); } // Product of XCenter and the scaled and broadcasted product of F0 and @@ -872,7 +869,7 @@ auto F2(UniqueHloInstruction* fused_norm_factor, UniqueHloInstruction* scale, m::Broadcast( BitcastOrReshape(FusedNormFactor(fused_norm_factor, custom_call))), MultiplyAnyOrder(m::Broadcast().WithPredicate(capture_or_verify_scale), - m::Op().WithPredicate(dy->GetCaptureOrVerifyFn()))); + m::Op().WithPredicate(dy->CaptureOrVerifyFn()))); } class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { @@ -902,10 +899,10 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { instr, SubtractMultiplyAddAnyOrder( OptionalSupportedTransform( - m::Op().WithPredicate(x.GetCaptureOrVerifyFn())), + m::Op().WithPredicate(x.CaptureOrVerifyFn())), Expectation(&expectation, &reduce, - OptionalSupportedTransform(m::Op().WithPredicate( - x.GetCaptureOrVerifyFn()))), + OptionalSupportedTransform( + m::Op().WithPredicate(x.CaptureOrVerifyFn()))), NormFactor(&norm_factor, &x, &variance, &expectation, &epsilon), m::Broadcast(&broadcast_scale, m::Op(&scale)), m::Broadcast(&broadcast_bias, m::Op(&bias))))) { @@ -949,7 +946,36 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { } // Verify the element types. The element types of input and output and the - // shapes of scale and bias must match. + // shapes of scale and bias must match. If a conversion to the type of the + // input is the only user of the output, set the output to the conversion. + // Similarly, to ensure the scale and bias have the same type, if the + // scale/bias is a conversion from the type of the bias/scale, set the + // scale/bias to the operand of the conversion. If scale and bias are type + // conversions from the same type, set both to the operands of the + // conversions. + if (instr->user_count() == 1 && + instr->users()[0]->opcode() == HloOpcode::kConvert && + ShapeUtil::SameElementType(instr->users()[0]->shape(), + x.Instr()->shape())) { + instr = instr->users()[0]; + } + if (scale->opcode() == HloOpcode::kConvert && + ShapeUtil::SameElementType(scale->operand(0)->shape(), + bias->shape())) { + scale = scale->mutable_operand(0); + } + if (bias->opcode() == HloOpcode::kConvert && + ShapeUtil::SameElementType(bias->operand(0)->shape(), + scale->shape())) { + bias = bias->mutable_operand(0); + } + if (scale->opcode() == HloOpcode::kConvert && + bias->opcode() == HloOpcode::kConvert && + ShapeUtil::SameElementType(scale->operand(0)->shape(), + bias->operand(0)->shape())) { + scale = scale->mutable_operand(0); + bias = bias->mutable_operand(0); + } if (!CompatibleElementType(instr) || !CompatibleElementType(scale) || !CompatibleElementType(bias) || !ShapeUtil::SameElementType(instr->shape(), x.Instr()->shape()) || @@ -1134,12 +1160,11 @@ class CudnnNormRewriterVisitor : public DfsHloRewriteVisitor { UniqueHloInstruction& epsilon) { HloInstruction* gte = custom_call->users()[0]; if (Match(instr, - m::Divide( - m::Op(), - AddAnyOrder( - m::Op().WithPredicate(variance.GetCaptureOrVerifyFn()), - m::Broadcast(m::ConstantScalar().WithPredicate( - epsilon.GetCaptureOrVerifyFn())))))) { + m::Divide(m::Op(), + AddAnyOrder( + m::Op().WithPredicate(variance.CaptureOrVerifyFn()), + m::Broadcast(m::ConstantScalar().WithPredicate( + epsilon.CaptureOrVerifyFn())))))) { // Verify the uniqueness of the operands. if (!variance.Instr() || !epsilon.Instr()) { VLOG(1) << "Layer norm operands not unique."; diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter_test.cc index a3dbc71132949a..fcace976245171 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter_test.cc @@ -40,7 +40,7 @@ class CudnnNormRewriterTest : public GpuCodegenTest { .cuda_compute_capability(); } - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_enable_cudnn_layer_norm(true); return debug_options; @@ -535,6 +535,69 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D3IncorrectScaleBroadcast) { TestNorm(hlo_text, optimized_hlo); } +TEST_F(CudnnNormRewriterTest, LayerNorm4D3TypeConversion) { + const char* hlo_text = R"( + HloModule test + + apply { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT c = f32[] add(a,b) + } + + ENTRY test { + input = f16[2,4,6,8] parameter(0) + input_f32 = f32[2,4,6,8] convert(input) + input_square = f32[2,4,6,8] multiply(input_f32, input_f32) + c0 = f32[] constant(0) + input_square_sum = f32[2,4,6] reduce(input_square, c0), dimensions={3}, to_apply=apply + r_nelems = f32[] constant(0.125) + r_nelems_bcast = f32[2,4,6] broadcast(r_nelems), dimensions={} + input_square_mean = f32[2,4,6] multiply(input_square_sum, r_nelems_bcast) + input_sum = f32[2,4,6] reduce(input_f32, c0), dimensions={3}, to_apply=apply + input_mean = f32[2,4,6] multiply(input_sum, r_nelems_bcast) + input_mean_square = f32[2,4,6] multiply(input_mean, input_mean) + variance = f32[2,4,6] subtract(input_square_mean, input_mean_square) + epsilon = f32[] constant(0.001) + epsilon_bcast = f32[2,4,6] broadcast(epsilon), dimensions={} + variance_plus_epsilon = f32[2,4,6] add(variance, epsilon_bcast) + norm_factor = f32[2,4,6] rsqrt(variance_plus_epsilon) + norm_factor_bcast = f32[2,4,6,8] broadcast(norm_factor), dimensions={0,1,2} + input_mean_bcast = f32[2,4,6,8] broadcast(input_mean), dimensions={0,1,2} + input_center = f32[2,4,6,8] subtract(input_f32, input_mean_bcast) + norm = f32[2,4,6,8] multiply(norm_factor_bcast, input_center) + scale = f16[8] parameter(1) + scale_f32 = f32[8] convert(scale) + scale_bcast = f32[2,4,6,8] broadcast(scale_f32), dimensions={3} + norm_scale = f32[2,4,6,8] multiply(norm, scale_bcast) + bias = f16[8] parameter(2) + bias_f32 = f32[8] convert(bias) + bias_bcast = f32[2,4,6,8] broadcast(bias_f32), dimensions={3} + norm_scale_bias = f32[2,4,6,8] add(norm_scale, bias_bcast) + ROOT out = f16[2,4,6,8] convert(norm_scale_bias) + })"; + + const char* optimized_hlo = R"( + +; CHECK-LABEL: ENTRY %test ({{.*}}: f16[2,4,6,8], {{.*}}: f16[8], {{.*}}: f16[8]) -> f16[2,4,6,8] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f16[2,4,6,8]{3,2,1,0} parameter(0) +; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f16[48,8,1,1]{3,2,1,0} bitcast([[P0]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = f16[8]{0} parameter(1) +; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f16[1,8,1,1]{3,2,1,0} bitcast([[P1]]) +; CHECK-NEXT: [[P2:%[^ ]+]] = f16[8]{0} parameter(2) +; CHECK-NEXT: [[P2_BITCAST:%[^ ]+]] = f16[1,8,1,1]{3,2,1,0} bitcast([[P2]]) +; CHECK-NEXT: [[CC:%[^ ]+]] = (f16[48,8,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P2_BITCAST]]), +; CHECK: custom_call_target="__cudnn$norm", +; CHECK: backend_config={ +; CHECK-DAG: "epsilon":0.001 +; CHECK: } +; CHECK-NEXT: [[GTE:%[^ ]+]] = f16[48,8,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0 +; CHECK-NEXT: ROOT {{.*}} = f16[2,4,6,8]{3,2,1,0} bitcast([[GTE]]) + )"; + + TestNorm(hlo_text, optimized_hlo); +} + TEST_F(CudnnNormRewriterTest, LayerNorm4D3InputOutputTypeMismatch) { const char* hlo_text = R"( HloModule test diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions_test.cc index 7cee2c54f166e7..8c1cd26e228389 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions_test.cc @@ -17,8 +17,8 @@ limitations under the License. #include #include +#include "xla/hlo/parser/hlo_parser.h" #include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/tests/hlo_test_base.h" diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc index 6437c208b1878a..2ee606ce1a2820 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc @@ -26,15 +26,15 @@ limitations under the License. #include "absl/types/span.h" #include "xla/hlo/pass/hlo_pass_fix.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" +#include "xla/hlo/transforms/simplifiers/reshape_mover.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/literal.h" -#include "xla/service/algebraic_simplifier.h" #include "xla/service/call_inliner.h" #include "xla/service/gpu/transforms/cudnn_pad_for_convolutions.h" #include "xla/service/gpu/transforms/cudnn_vectorize_convolutions.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/service/reshape_mover.h" -#include "xla/service/tuple_simplifier.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/dnn.h" #include "xla/tests/hlo_test_base.h" diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc index 698b8fb73dd579..d939347874bb95 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc @@ -28,8 +28,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_computation.h" @@ -40,6 +40,7 @@ limitations under the License. #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/cudnn_support_utils.h" #include "xla/service/gpu/stream_executor_util.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -96,24 +97,6 @@ static std::vector GetRelevantConvs( return convs; } -// Converts an XlaBuilder into an HloComputation in the same module as -// `sibling_computation`. -// -// Yes, we serialize/deserialize as a proto. :) -static absl::StatusOr BuilderToHloComputation( - XlaBuilder& b, XlaOp root, HloComputation* sibling_computation) { - TF_ASSIGN_OR_RETURN(XlaComputation comp, b.Build(root)); - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comp.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, - HloModule::CreateFromProto(comp.proto(), config)); - - HloModule* dest_module = sibling_computation->parent(); - HloCloneContext context(dest_module); - return dest_module->DeepCloneComputation(new_module->entry_computation(), - &context); -} - // Reshapes `instr` so that it has an extra dimension of size `vect_size` right // after `dim`. static XlaOp SplitAtDim(XlaOp instr, int64_t dim, int64_t vect_size) { @@ -460,11 +443,11 @@ static absl::StatusOr TryRevectorizeConv( new_conv_result, dnums->output_feature_dimension(), *output_vect_dim, /*orig_vect_size=*/output_shape.dimensions(*output_vect_dim)); + XlaOp root = Tuple(&b, {new_conv_result_unrevectorized, new_conv_scratch}); + TF_ASSIGN_OR_RETURN(XlaComputation comp, b.Build(root)); TF_ASSIGN_OR_RETURN( HloComputation * new_conv_comp, - BuilderToHloComputation( - b, Tuple(&b, {new_conv_result_unrevectorized, new_conv_scratch}), - conv->parent())); + XlaComputationToHloComputation(comp, conv->parent()->parent())); // Set the name on the new conv. This is purely cosmetic, but we attempt to // preserve e.g. "cudnn-conv.42" instead of "custom-call.42". @@ -599,11 +582,11 @@ static absl::StatusOr TryVectorizeConv( Collapse(new_conv_result, {dnums->output_feature_dimension(), dnums->output_feature_dimension() + 1}); + XlaOp root = Tuple(&b, {conv_result_collapsed, new_conv_scratch}); + TF_ASSIGN_OR_RETURN(XlaComputation comp, b.Build(root)); TF_ASSIGN_OR_RETURN( HloComputation * new_conv_comp, - BuilderToHloComputation( - b, Tuple(&b, {conv_result_collapsed, new_conv_scratch}), - conv->parent())); + XlaComputationToHloComputation(comp, conv->parent()->parent())); // Create a tuple and replace the old conv with it! VLOG(1) << "Vectorized conv to: " << new_conv_comp->ToString(); diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions_test.cc index 7528870af4c605..b955f6f0f1b325 100644 --- a/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions_test.cc @@ -23,10 +23,10 @@ limitations under the License. #include #include "absl/algorithm/container.h" #include "absl/status/statusor.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/service/call_inliner.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/stream_executor/device_description.h" diff --git a/third_party/xla/xla/service/gpu/transforms/dot_algorithm_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/dot_algorithm_rewriter.cc new file mode 100644 index 00000000000000..09b063d3b44662 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/dot_algorithm_rewriter.cc @@ -0,0 +1,186 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/dot_algorithm_rewriter.h" + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal_util.h" +#include "xla/shape.h" +#include "tsl/platform/status.h" + +namespace xla::gpu { + +namespace { + +HloInstruction* Truncate(HloInstruction* f32_param) { + // Cast to int32 first, then zero out the high bits. Then cast back to f32. + Shape u32_shape = f32_param->shape(); + u32_shape.set_element_type(PrimitiveType::U32); + HloInstruction* u32_param = f32_param->AddInstruction( + HloInstruction::CreateBitcastConvert(u32_shape, f32_param)); + HloInstruction* mask_constant = + f32_param->parent()->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(0xFFFF0000))); + HloInstruction* u32_mask = u32_param->AddInstruction( + HloInstruction::CreateBroadcast(u32_shape, mask_constant, {})); + HloInstruction* masked_u32 = + u32_param->AddInstruction(HloInstruction::CreateBinary( + u32_shape, HloOpcode::kAnd, u32_param, u32_mask)); + return masked_u32->AddInstruction( + HloInstruction::CreateBitcastConvert(f32_param->shape(), masked_u32)); +} + +HloInstruction* Sub(HloInstruction* instr, HloInstruction* high) { + return instr->AddInstruction(HloInstruction::CreateBinary( + instr->shape(), HloOpcode::kSubtract, instr, high)); +} + +HloInstruction* RoundToBF16(HloInstruction* instr) { + Shape new_shape = instr->shape(); + new_shape.set_element_type(PrimitiveType::BF16); + return instr->AddInstruction(HloInstruction::CreateConvert(new_shape, instr)); +} + +std::pair Split2x(HloInstruction* f32_param) { + HloInstruction* high_f32 = Truncate(f32_param); + HloInstruction* low_f32 = Sub(f32_param, high_f32); + return std::make_pair(RoundToBF16(high_f32), RoundToBF16(low_f32)); +} + +std::tuple Split3x( + HloInstruction* f32_param) { + HloInstruction* high_f32_t = Truncate(f32_param); + HloInstruction* mid_f32 = Sub(f32_param, high_f32_t); + HloInstruction* mid_f32_t = Truncate(mid_f32); + HloInstruction* low_f32_t = Truncate(Sub(mid_f32, mid_f32_t)); + return std::make_tuple(RoundToBF16(high_f32_t), RoundToBF16(mid_f32_t), + RoundToBF16(low_f32_t)); +} + +void RewriteF32ToBF16X3(HloInstruction* instr) { + HloComputation* computation = instr->parent(); + HloDotInstruction* dot = Cast(instr); + PrecisionConfig precision_config = dot->precision_config(); + precision_config.clear_algorithm(); + const Shape& shape = dot->shape(); + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + + auto [lhs_high_bf16, lhs_low_bf16] = Split2x(dot->mutable_operand(0)); + auto [rhs_high_bf16, rhs_low_bf16] = Split2x(dot->mutable_operand(1)); + + HloInstruction* high_dot = + computation->AddInstruction(HloInstruction::CreateDot( + shape, lhs_high_bf16, rhs_high_bf16, dnums, precision_config)); + HloInstruction* left_low = + computation->AddInstruction(HloInstruction::CreateDot( + shape, lhs_high_bf16, rhs_low_bf16, dnums, precision_config)); + HloInstruction* right_low = + computation->AddInstruction(HloInstruction::CreateDot( + shape, lhs_low_bf16, rhs_high_bf16, dnums, precision_config)); + HloInstruction* low_sum = + computation->AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, left_low, right_low)); + HloInstruction* sum = + computation->AddInstruction(HloInstruction::CreateBinary( + dot->shape(), HloOpcode::kAdd, low_sum, high_dot)); + TF_CHECK_OK(dot->ReplaceAllUsesWith(sum)); + TF_CHECK_OK(dot->parent()->RemoveInstruction(dot)); +} + +void RewriteF32ToBF16X6(HloInstruction* instr) { + HloComputation* computation = instr->parent(); + HloDotInstruction* original_dot = Cast(instr); + PrecisionConfig precision_config = original_dot->precision_config(); + precision_config.clear_algorithm(); + const Shape& shape = original_dot->shape(); + const DotDimensionNumbers& dnums = original_dot->dot_dimension_numbers(); + auto dot = [&](HloInstruction* lhs, HloInstruction* rhs) { + return computation->AddInstruction( + HloInstruction::CreateDot(shape, lhs, rhs, dnums, precision_config)); + }; + auto sum = [&](HloInstruction* lhs, HloInstruction* rhs) { + return computation->AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs, rhs)); + }; + + auto [lhs_high_bf16, lhs_mid_bf16, lhs_low_bf16] = + Split3x(original_dot->mutable_operand(0)); + auto [rhs_high_bf16, rhs_mid_bf16, rhs_low_bf16] = + Split3x(original_dot->mutable_operand(1)); + + HloInstruction* middle_middle_dot = dot(lhs_mid_bf16, rhs_mid_bf16); + HloInstruction* high_low_dot = dot(lhs_high_bf16, rhs_low_bf16); + HloInstruction* low_high_dot = dot(lhs_low_bf16, rhs_high_bf16); + HloInstruction* high_middle_dot = dot(lhs_high_bf16, rhs_mid_bf16); + HloInstruction* middle_high_dot = dot(lhs_mid_bf16, rhs_high_bf16); + HloInstruction* high_high_dot = dot(lhs_high_bf16, rhs_high_bf16); + + HloInstruction* result = nullptr; + result = sum(middle_middle_dot, high_low_dot); + result = sum(result, low_high_dot); + result = sum(result, high_middle_dot); + result = sum(result, middle_high_dot); + result = sum(result, high_high_dot); + + TF_CHECK_OK(original_dot->ReplaceAllUsesWith(result)); + TF_CHECK_OK(original_dot->parent()->RemoveInstruction(original_dot)); +} + +} // namespace + +absl::StatusOr DotAlgorithmRewriter::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + bool changed = false; + for (HloComputation* computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kDot) { + continue; + } + auto algorithm = instruction->precision_config().algorithm(); + switch (algorithm) { + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3: + RewriteF32ToBF16X3(instruction); + changed = true; + break; + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6: + RewriteF32ToBF16X6(instruction); + changed = true; + break; + default: + break; + } + } + } + return changed; +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/transforms/dot_algorithm_rewriter.h b/third_party/xla/xla/service/gpu/transforms/dot_algorithm_rewriter.h new file mode 100644 index 00000000000000..fdc55bb14ba3d2 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/dot_algorithm_rewriter.h @@ -0,0 +1,40 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_TRANSFORMS_DOT_ALGORITHM_REWRITER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_DOT_ALGORITHM_REWRITER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla::gpu { + +class DotAlgorithmRewriter : public HloModulePass { + public: + DotAlgorithmRewriter() = default; + absl::string_view name() const override { return "dot-algorithm-rewriter"; } + using HloPassInterface::Run; + + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_TRANSFORMS_DOT_ALGORITHM_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter_test.cc b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter_test.cc index 364c1405f09267..8b6efd2e757da8 100644 --- a/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter_test.cc @@ -31,7 +31,7 @@ namespace { class WithoutDotDimensionSorterTest : public GpuCodegenTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); // The pass is disabled here to preserve suboptimal dimension order in // 1) UnsortedDimsCreateTransposes to reveal the transposes. diff --git a/third_party/xla/xla/service/gpu/transforms/dot_normalizer.cc b/third_party/xla/xla/service/gpu/transforms/dot_normalizer.cc new file mode 100644 index 00000000000000..77768f32711fe8 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/dot_normalizer.cc @@ -0,0 +1,59 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/dot_normalizer.h" + +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "tsl/platform/errors.h" + +namespace xla::gpu { + +bool DotNormalizer::InstructionMatchesPattern(HloInstruction* instruction) { + if (instruction->opcode() != HloOpcode::kDot) { + return false; + } + return instruction->dot_dimension_numbers() + .lhs_contracting_dimensions() + .empty(); +} + +absl::StatusOr DotNormalizer::ExpandInstruction( + HloInstruction* instruction) { + HloDotInstruction* dot = Cast(instruction); + HloInstruction* lhs = dot->mutable_operand(0); + Shape new_lhs_shape = lhs->shape(); + ShapeUtil::AppendMinorDimension(1, &new_lhs_shape); + HloInstruction* normalized_lhs = + dot->AddInstruction(HloInstruction::CreateBitcast(new_lhs_shape, lhs)); + TF_RETURN_IF_ERROR(dot->ReplaceOperandWithDifferentShape(0, normalized_lhs)); + HloInstruction* rhs = dot->mutable_operand(1); + Shape new_rhs_shape = rhs->shape(); + ShapeUtil::AppendMinorDimension(1, &new_rhs_shape); + HloInstruction* normalized_rhs = + dot->AddInstruction(HloInstruction::CreateBitcast(new_rhs_shape, rhs)); + TF_RETURN_IF_ERROR(dot->ReplaceOperandWithDifferentShape(1, normalized_rhs)); + DotDimensionNumbers* dnums = dot->mutable_dot_dimension_numbers(); + dnums->add_lhs_contracting_dimensions(new_lhs_shape.rank() - 1); + dnums->add_rhs_contracting_dimensions(new_rhs_shape.rank() - 1); + return nullptr; +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/transforms/dot_normalizer.h b/third_party/xla/xla/service/gpu/transforms/dot_normalizer.h new file mode 100644 index 00000000000000..97e85229195f1d --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/dot_normalizer.h @@ -0,0 +1,48 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_TRANSFORMS_DOT_NORMALIZER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_DOT_NORMALIZER_H_ + +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" +#include "xla/util.h" + +namespace xla::gpu { + +// Ensures that a dot has at least 1 contracting dimension. If there are no +// contracting dimensions, a trivial 1-sized contracting dimension is added. +// This pass is expected to be run after layout assignment. +class DotNormalizer : public OpExpanderPass { + public: + explicit DotNormalizer(HloPredicate extra_filter = nullptr) + : OpExpanderPass(std::move(extra_filter)) {} + + absl::string_view name() const override { return "dot_normalizer"; } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + absl::StatusOr ExpandInstruction( + HloInstruction* instruction) override; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_TRANSFORMS_DOT_NORMALIZER_H_ diff --git a/third_party/xla/xla/service/gpu/transforms/dot_normalizer_test.cc b/third_party/xla/xla/service/gpu/transforms/dot_normalizer_test.cc new file mode 100644 index 00000000000000..9242c1f73bf7a7 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/dot_normalizer_test.cc @@ -0,0 +1,72 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/dot_normalizer.h" + +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/pattern_matcher.h" +#include "xla/service/pattern_matcher_gmock.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla::gpu { +namespace { + +namespace m = ::xla::match; + +using DotNormalizerTest = HloTestBase; +using ::tsl::testing::IsOkAndHolds; + +TEST_F(DotNormalizerTest, DotWithoutContractingDims) { + constexpr char kHlo[] = R"( + HloModule test + + ENTRY main { + p0 = f16[5,15]{1,0} parameter(0) + p1 = f16[5,16,17]{2,1,0} parameter(1) + ROOT r = f16[5,15,16,17]{3,2,1,0} dot(p0, p1), + lhs_batch_dims={0}, rhs_batch_dims={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kHlo)); + EXPECT_THAT(DotNormalizer().Run(m.get()), IsOkAndHolds(true)); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch( + m::Dot(m::Bitcast().WithShape(F16, {5, 15, 1}, {2, 1, 0}), + m::Bitcast().WithShape(F16, {5, 16, 17, 1}, {3, 2, 1, 0})) + .WithContractingDims({2}, {3}))); +} + +TEST_F(DotNormalizerTest, DotWithContractingDims) { + constexpr char kHlo[] = R"( + HloModule test + + ENTRY main { + p0 = f16[5,15,3]{2,1,0} parameter(0) + p1 = f16[5,17,3]{2,1,0} parameter(1) + ROOT r = f16[5,15,17]{2,1,0} dot(p0, p1), + lhs_batch_dims={0}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={2} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kHlo)); + EXPECT_THAT(DotNormalizer().Run(m.get()), IsOkAndHolds(false)); +} + +} // namespace +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.h b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.h index b269bed8b6a6f3..7b3331efc96707 100644 --- a/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.h +++ b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.h @@ -21,7 +21,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/op_expander_pass.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" #include "xla/util.h" namespace xla::gpu { diff --git a/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc index f1d9248ae12d94..7d217aac5674ee 100644 --- a/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc +++ b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc @@ -38,10 +38,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/transforms/simplifiers/flatten_call_graph.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/flatten_call_graph.h" -#include "xla/service/hlo_parser.h" #include "xla/status_macros.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -523,6 +523,23 @@ absl::StatusOr DoubleBufferingUnroll(HloInstruction* while_instr, return true; // changed } +// Function performs double buffering unrolling strategy iff there is any +// collective operation within a body computation. +absl::StatusOr AutoUnroll(HloInstruction* while_instr, + HloModule* module) { + CHECK_EQ(while_instr->opcode(), HloOpcode::kWhile); + + bool any_collective_present = absl::c_any_of( + while_instr->while_body()->MakeInstructionPostOrder(), + [](HloInstruction* instr) { + return hlo_query::IsCollectiveCommunicationOp(instr->opcode()); + }); + if (any_collective_present) { + return DoubleBufferingUnroll(while_instr, module); + } + return false; // IR not changed. +} + } // namespace absl::StatusOr DoubleBufferLoopUnrolling::Run( @@ -555,6 +572,8 @@ absl::StatusOr DoubleBufferLoopUnrolling::Run( TF_ASSIGN_OR_RETURN(changed, FullyUnroll(while_instr, module)); } else if (unroll_strategy_ == UnrollStrategy::kDoubleBuffer) { TF_ASSIGN_OR_RETURN(changed, DoubleBufferingUnroll(while_instr, module)); + } else if (unroll_strategy_ == UnrollStrategy::kAuto) { + TF_ASSIGN_OR_RETURN(changed, AutoUnroll(while_instr, module)); } else { LOG(FATAL) << absl::StrCat("Unhandled unrolling strategy: ", unroll_strategy_); diff --git a/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.h b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.h index c3d774fe60b176..aa4803457a1815 100644 --- a/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.h +++ b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.h @@ -47,7 +47,7 @@ namespace gpu { // unrolled. class DoubleBufferLoopUnrolling : public HloModulePass { public: - enum class UnrollStrategy { kDoubleBuffer, kFullUnroll }; + enum class UnrollStrategy { kDoubleBuffer, kFullUnroll, kAuto }; explicit DoubleBufferLoopUnrolling( UnrollStrategy unroll_strategy = UnrollStrategy::kDoubleBuffer) diff --git a/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling_test.cc b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling_test.cc index a12376e931defd..c4c99c819519cb 100644 --- a/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling_test.cc @@ -21,14 +21,16 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/testlib/filecheck.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/hlo/utils/hlo_query.h" -#include "xla/service/tuple_simplifier.h" #include "xla/test.h" -#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" @@ -55,13 +57,103 @@ int64_t CountInstructions(HloModule& module, HloOpcode opcode) { return count; } -class GpuLoopDoubleBufferTransformerTest : public HloTestBase { - DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_while_loop_double_buffering(true); - return debug_options; - } -}; +using GpuLoopDoubleBufferTransformerTest = HloTestBase; + +TEST_F(GpuLoopDoubleBufferTransformerTest, + AutoUnrollLoopWhenCollectivesArePresent) { + absl::string_view kModuleString = R"( +HloModule m +condition { + input_tuple = (f32[], s32[]) parameter(0) + cond = s32[] get-tuple-element(input_tuple), index=1 + trip_count = s32[] constant(10) + ROOT done = pred[] compare(cond, trip_count), direction=LT +} + +ar_add { + Arg_1 = f32[] parameter(1) + Arg_0 = f32[] parameter(0) + ROOT add_ar = f32[] add(Arg_1, Arg_0) +} + +body { + input_tuple = (f32[], s32[]) parameter(0) + param_0 = f32[] get-tuple-element(input_tuple), index=0 + cond = s32[] get-tuple-element(input_tuple), index=1 + all-reduce-start = f32[] all-reduce-start(param_0), channel_id=8, replica_groups={{0}}, to_apply=ar_add, backend_config="{\"is_sync\":false}" + one = s32[] constant(1) + all-reduce-done = f32[] all-reduce-done(all-reduce-start) + cond_plus_1 = s32[] add(cond, one) + ROOT output_tuple = (f32[], s32[]) tuple(all-reduce-done, cond_plus_1) +} + +ENTRY main { + param_0 = f32[] parameter(0) + param_2 = s32[] constant(0) + tuple = (f32[], s32[]) tuple(param_0, param_2) + ROOT while = (f32[], s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kModuleString)); + HloPassPipeline pipeline("double-buffering-pipeline"); + DoubleBufferLoopUnrolling unroller( + DoubleBufferLoopUnrolling::UnrollStrategy::kAuto); + TF_ASSERT_OK_AND_ASSIGN(bool changed, unroller.Run(module.get())); + + EXPECT_TRUE(changed); + + HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode( + *module->entry_computation(), HloOpcode::kWhile); + TF_ASSERT_OK_AND_ASSIGN( + WhileLoopBackendConfig config, + while_instruction->backend_config()); + EXPECT_EQ(config.known_trip_count().n(), 5); + EXPECT_EQ(CountInstructions((*while_instruction->while_body()), + HloOpcode::kAllReduceStart), + 2); +} + +TEST_F(GpuLoopDoubleBufferTransformerTest, + DoNotAutoUnrollLoopWhenCollectivesAreNotPresent) { + absl::string_view kModuleString = R"( +HloModule m +condition { + input_tuple = (s32[]) parameter(0) + cond = s32[] get-tuple-element(input_tuple), index=0 + trip_count = s32[] constant(10) + ROOT done = pred[] compare(cond, trip_count), direction=LT +} + +body { + input_tuple = (s32[]) parameter(0) + cond = s32[] get-tuple-element(input_tuple), index=0 + one = s32[] constant(1) + cond_plus_1 = s32[] add(cond, one) + ROOT output_tuple = (s32[]) tuple(cond_plus_1) +} + +ENTRY main { + param_0 = s32[] constant(0) + tuple = (s32[]) tuple(param_0) + ROOT while = (s32[]) while(tuple), condition=condition, body=body, backend_config={"known_trip_count":{"n":"10"}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kModuleString)); + DoubleBufferLoopUnrolling unroller( + DoubleBufferLoopUnrolling::UnrollStrategy::kAuto); + TF_ASSERT_OK_AND_ASSIGN(bool changed, unroller.Run(module.get())); + + EXPECT_FALSE(changed); + + HloInstruction* while_instruction = hlo_query::GetFirstInstructionWithOpcode( + *module->entry_computation(), HloOpcode::kWhile); + TF_ASSERT_OK_AND_ASSIGN( + WhileLoopBackendConfig config, + while_instruction->backend_config()); + EXPECT_EQ(config.known_trip_count().n(), 10); +} TEST_F(GpuLoopDoubleBufferTransformerTest, FullUnrollOddTripCountTest) { const char* const kModuleString = R"( diff --git a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc index 8ea3bc3801062e..f6a8e975239abf 100644 --- a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc @@ -159,9 +159,11 @@ bool IsAlignedSlice(const HloInstruction* slice) { // param = (s32[], s32[], s32[16]{0}, s32[16]{0}) parameter(0) // // the index in `gte` has to be the loop iteration index // gte = s32[] get-tuple-element(param), index=0 -// c0 = s32[] constant(0) compare = pred[] compare(gte, c0), direction=LT +// c0 = s32[] constant(0) +// compare = pred[] compare(gte, c0), direction=LT // c_trip_count = s32[] constant(16) -// add = s32[] add(gte, c_trip_count) select = s32[] select(compare, add, gte) +// add = s32[] add(gte, c_trip_count) +// select = s32[] select(compare, add, gte) // clang-format on bool IsLoopIterationNumber(const HloInstruction& offset) { diff --git a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc index 9a71c9930adc78..d477dc436a4752 100644 --- a/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include "absl/status/status.h" -#include "xla/client/lib/constants.h" -#include "xla/client/xla_builder.h" #include "xla/ffi/ffi.h" #include "xla/ffi/ffi_api.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/service/custom_call_target_registry.h" diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_block_level_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/fusion_block_level_rewriter.cc new file mode 100644 index 00000000000000..809183953d306d --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/fusion_block_level_rewriter.cc @@ -0,0 +1,158 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. + +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/fusion_block_level_rewriter.h" + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/MLIRContext.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/fusions/triton/triton_support.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/model/fusion_analysis_cache.h" +#include "xla/service/gpu/model/gpu_indexing_performance_model.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/instruction_fusion.h" +#include "xla/stream_executor/device_description.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +namespace { + +using ::mlir::MLIRContext; + +absl::StatusOr ProcessFusionInstruction( + HloFusionInstruction* fusion_instruction, + const se::DeviceDescription& device_info, + HloCostAnalysis::ShapeSizeFunction shape_size, MLIRContext* ctx) { + const HloComputation* fusion_computation = + fusion_instruction->fused_instructions_computation(); + if (CodegenDecision can_codegen = IsTritonSupportedComputation( + *fusion_computation, device_info.gpu_compute_capability()); + !can_codegen) { + VLOG(2) << "Can't rewrite fusion " << fusion_instruction->ToString() + << " because one or more instructions is not supported by Triton: " + << can_codegen.Explain(); + return false; + } + + TF_ASSIGN_OR_RETURN(auto backend_config, + fusion_instruction->backend_config()); + + if (backend_config.has_fusion_backend_config() && + backend_config.fusion_backend_config().has_block_level_fusion_config()) { + // Fusion is already block-level! Skip. + return false; + } + + HloFusionAnalysisCache fusion_analysis_cache(device_info); + GpuPerformanceModelWithIndexingAnalysis indexing_performance_model( + &device_info, &fusion_analysis_cache, shape_size, ctx); + + auto fusion_adaptor = HloFusionAdaptor::ForInstruction( + Cast(fusion_instruction)); + + TF_ASSIGN_OR_RETURN( + TiledRunTimeDataOrError tiled_runtime_data_or_error, + indexing_performance_model.TryFindBestTilingForFusion(*fusion_adaptor)); + + if (const auto* fusion_decision = + std::get_if(&tiled_runtime_data_or_error)) { + // Can't rewrite this fusion because we can't tile it, skip! + VLOG(2) << "Can't rewrite fusion " << fusion_instruction->ToString() + << " because tiling search failed. (The most likely cause for " + << "is that SymbolicTileAnalysis failed.)"; + return false; + } + + TiledRunTimeData tiled_runtime_data = + std::get(std::move(tiled_runtime_data_or_error)); + VLOG(1) + << "Found parameters " + << absl::StrCat( + "sizes=[", + absl::StrJoin( + tiled_runtime_data.block_level_parameters.output_tile_sizes, + ", "), + "], num_warps=", + tiled_runtime_data.block_level_parameters.num_warps) + << " for fusion computation " << fusion_computation->ToString(); + + *backend_config.mutable_fusion_backend_config() + ->mutable_block_level_fusion_config() = + tiled_runtime_data.block_level_parameters.ToBlockLevelFusionConfig(); + backend_config.mutable_fusion_backend_config()->set_kind( + std::string(kTritonFusionKind)); + TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config(backend_config)); + fusion_instruction->set_fusion_kind(HloInstruction::FusionKind::kCustom); + return true; +} + +} // anonymous namespace + +absl::StatusOr FusionBlockLevelRewriter::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + TF_RETURN_IF_ERROR(EnsureTritonSupportsComputeCapability( + device_info_.gpu_compute_capability())); + + MLIRContext ctx; + bool has_changed = false; + + for (HloComputation* computation : + module->MakeComputationSorted(execution_threads)) { + if (!computation->IsFusionComputation()) { + continue; + } + + HloFusionInstruction* fusion_instruction = + ::xla::Cast(computation->FusionInstruction()); + + TF_ASSIGN_OR_RETURN(bool should_try_rewrite, + should_try_rewrite_if_(fusion_instruction)); + if (!should_try_rewrite) { + continue; + } + + TF_ASSIGN_OR_RETURN( + bool changed, ProcessFusionInstruction(fusion_instruction, device_info_, + shape_size_, &ctx)); + + has_changed |= changed; + } + + return has_changed; +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_block_level_rewriter.h b/third_party/xla/xla/service/gpu/transforms/fusion_block_level_rewriter.h new file mode 100644 index 00000000000000..6cf8f988242f97 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/fusion_block_level_rewriter.h @@ -0,0 +1,64 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. + +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_TRANSFORMS_FUSION_BLOCK_LEVEL_REWRITER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_FUSION_BLOCK_LEVEL_REWRITER_H_ + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/functional/any_invocable.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/stream_executor/device_description.h" + +namespace xla { +namespace gpu { + +class FusionBlockLevelRewriter : public HloModulePass { + public: + explicit FusionBlockLevelRewriter( + const se::DeviceDescription& device_info, + HloCostAnalysis::ShapeSizeFunction shape_size, + absl::AnyInvocable(const HloFusionInstruction*)> + should_try_rewrite_if) + : device_info_(device_info), + shape_size_(shape_size), + should_try_rewrite_if_(std::move(should_try_rewrite_if)) {} + + absl::string_view name() const override { + return "fusion-block-level-rewriter"; + } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + const se::DeviceDescription& device_info_; + HloCostAnalysis::ShapeSizeFunction shape_size_; + absl::AnyInvocable(const HloFusionInstruction*)> + should_try_rewrite_if_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_TRANSFORMS_FUSION_BLOCK_LEVEL_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_block_level_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/fusion_block_level_rewriter_test.cc new file mode 100644 index 00000000000000..d574fc106282ad --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/fusion_block_level_rewriter_test.cc @@ -0,0 +1,174 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. + +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/fusion_block_level_rewriter.h" + +#include +#include +#include + +#include +#include +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/MLIRContext.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/fusions/triton/triton_support.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/model/symbolic_tile_analysis.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +using ::tsl::testing::IsOkAndHolds; + +bool HasTritonBlockLevelFusionConfig(const HloInstruction* fusion) { + return fusion->opcode() == HloOpcode::kFusion && + fusion->has_backend_config() && + fusion->backend_config().ok() && + fusion->backend_config() + ->fusion_backend_config() + .has_block_level_fusion_config() && + fusion->backend_config() + ->fusion_backend_config() + .kind() == kTritonFusionKind; +} + +class FusionBlockLevelRewriterTest : public HloTestBase { + protected: + se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo( + se::CudaComputeCapability::Ampere())}; +}; + +bool RewriteEverythingPossible(const HloFusionInstruction* fusion) { + return true; +} + +TEST_F(FusionBlockLevelRewriterTest, + DoesNotRewriteFusionThatIsAlreadyBlockLevel) { + const absl::string_view hlo_text = R"( +fusion_computation { + ROOT param_0 = f32[10,10] parameter(0) +} + +ENTRY entry { + param_0 = f32[10,10] parameter(0) + ROOT fusion = f32[10,10] fusion(param_0), kind=kCustom, + calls=fusion_computation, + backend_config={"fusion_backend_config": + {"kind":"__triton", "block_level_fusion_config":{}}} +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + EXPECT_THAT( + FusionBlockLevelRewriter(device_info_, HloCostAnalysis::DefaultShapeSize, + RewriteEverythingPossible) + .Run(module.get()), + IsOkAndHolds(false)); +} + +TEST_F(FusionBlockLevelRewriterTest, + RewritesFusionThatIsNotBlockLevelAndCanBeTiledAndCodegenedCorrectly) { + const absl::string_view hlo_text = R"( +fusion_computation { + ROOT param_0 = f32[10,10] parameter(0) +} + +ENTRY entry { + param_0 = f32[10,10] parameter(0) + ROOT fusion = f32[10,10] fusion(param_0), kind=kLoop, + calls=fusion_computation +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + + EXPECT_THAT( + FusionBlockLevelRewriter(device_info_, HloCostAnalysis::DefaultShapeSize, + RewriteEverythingPossible) + .Run(module.get()), + IsOkAndHolds(true)); + const HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kFusion); + EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kCustom); + EXPECT_TRUE(HasTritonBlockLevelFusionConfig(root)); +} + +TEST_F(FusionBlockLevelRewriterTest, + DoesNotRewriteFusionThatIsNotBlockLevelAndCannotBeTiledCorrectly) { + const absl::string_view hlo_text = R"( +fusion_computation { + param_0 = f32[10,10] parameter(0) + ROOT bitcast = f32[25,4] bitcast(param_0) +} + +ENTRY entry { + param_0 = f32[10,10] parameter(0) + ROOT fusion = f32[25,4] fusion(param_0), kind=kLoop, + calls=fusion_computation +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + mlir::MLIRContext ctx; + + ASSERT_FALSE(std::holds_alternative( + SymbolicTileAnalysis::AnalyzeComputation( + *module->GetComputationWithName("fusion_computation"), &ctx))); + EXPECT_THAT( + FusionBlockLevelRewriter(device_info_, HloCostAnalysis::DefaultShapeSize, + RewriteEverythingPossible) + .Run(module.get()), + IsOkAndHolds(false)); +} + +TEST_F(FusionBlockLevelRewriterTest, + DoesNotRewriteFusionThatIsNotBlockLevelAndCannotBeCodegenedCorrectly) { + const absl::string_view hlo_text = R"( +fusion_computation { + param_0 = f8e4m3fn[10,10] parameter(0) + ROOT add = f8e4m3fn[10,10] add(param_0, param_0) +} + +ENTRY entry { + param_0 = f8e4m3fn[10,10] parameter(0) + ROOT fusion = f8e4m3fn[10,10] fusion(param_0), kind=kLoop, + calls=fusion_computation +})"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text)); + ASSERT_FALSE(IsTritonSupportedComputation( + *module->GetComputationWithName("fusion_computation"), + device_info_.gpu_compute_capability())); + EXPECT_THAT( + FusionBlockLevelRewriter(device_info_, HloCostAnalysis::DefaultShapeSize, + RewriteEverythingPossible) + .Run(module.get()), + IsOkAndHolds(false)); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_merger.cc b/third_party/xla/xla/service/gpu/transforms/fusion_merger.cc index 37986219faae16..d132cf6f3ae682 100644 --- a/third_party/xla/xla/service/gpu/transforms/fusion_merger.cc +++ b/third_party/xla/xla/service/gpu/transforms/fusion_merger.cc @@ -211,7 +211,7 @@ FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) { // merge. if (producer->users().empty()) { ++num_fail_no_users_; - return "fusion has no users"; + return FusionDecision::Forbid("fusion has no users"); } // Skip 'producer' instruction if it is not a loop fusion. Library fusion @@ -220,7 +220,7 @@ FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) { // kReduce), so they shouldn't be further fused either. if (!producer->IsLoopFusion()) { ++num_fail_not_loop_fusion_; - return "not a loop fusion"; + return FusionDecision::Forbid("not a loop fusion"); } auto producer_hero = GetRealHeroForMultiOutputFusion(*producer); @@ -229,11 +229,11 @@ FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) { for (const HloInstruction* user : producer->users()) { if (user->opcode() == HloOpcode::kBitcast) { ++num_fail_merge_all_users_; - return "not fusing bitcast ops"; + return FusionDecision::Forbid("not fusing bitcast ops"); } if (user->IsCustomFusion()) { ++num_fail_merge_all_users_; - return "not fusing custom fusions"; + return FusionDecision::Forbid("not fusing custom fusions"); } auto consumer_hero = GetRealHeroForMultiOutputFusion(*user); if (auto compatible = @@ -256,7 +256,7 @@ FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) { // it to a producer which transposes most data. if (has_reduction_user && TransposesMostData(*producer)) { ++num_fail_uncoalesced_read_; - return "would read mostly uncoalesced"; + return FusionDecision::Forbid("would read mostly uncoalesced"); } for (const HloInstruction* user : producer->users()) { @@ -285,8 +285,8 @@ FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) { for (const HloInstruction* user : producer->users()) { if (cost_analysis_->ProducerConsumerMergedTooLarge(*producer, *user)) { ++num_fail_inefficient_fusion_emitter_; - return FusionDecision{} << "if merged with " << user->name() - << " will generate huge IR"; + return FusionDecision::Forbid("if merged with ") + << user->name() << " will generate huge IR"; } } @@ -295,10 +295,10 @@ FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) { GpuPerformanceModelOptions::Default(), producer->users()); if (t.time_fused > t.time_unfused) { ++num_fail_slower_if_fused_; - return "will execute slower if fused"; + return FusionDecision::Forbid("will execute slower if fused"); } - return {}; + return FusionDecision::Allow(); } absl::StatusOr FusionMerger::Run( diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_merger_test.cc b/third_party/xla/xla/service/gpu/transforms/fusion_merger_test.cc index 5068f65a49b867..8d52e4f183e407 100644 --- a/third_party/xla/xla/service/gpu/transforms/fusion_merger_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/fusion_merger_test.cc @@ -41,16 +41,9 @@ namespace { namespace m = ::xla::match; class FusionMergerTest : public HloTestBase { - HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const { - return [&](const Shape& shape) { - constexpr int64_t kPointerSize = 8; - return ShapeUtil::ByteSizeOf(shape, kPointerSize); - }; - } - public: FusionMerger fusion_merger_{TestGpuDeviceInfo::RTXA6000DeviceInfo(), - ShapeSizeBytesFunction()}; + HloCostAnalysis::DefaultShapeSize}; FusionMergerTest() : HloTestBase() {} }; diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.cc b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.cc index 16957f80d370e0..12bac0af0e8758 100644 --- a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.cc +++ b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.cc @@ -62,7 +62,6 @@ absl::StatusOr FusionWrapper::Run( case HloOpcode::kConcatenate: case HloOpcode::kConvolution: case HloOpcode::kConvert: - case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kDivide: case HloOpcode::kDot: diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc index be5f0d7dfd49c6..afc0e21e9ca31a 100644 --- a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc @@ -124,6 +124,22 @@ TEST_F(FusionWrapperTest, ControlDependency) { // CHECK-SAME: control-predecessors={%fusion})"); } +TEST_F(FusionWrapperTest, Copy) { + // Avoid rewriting copies, so that the rematerialization pass + // can avoid rematerializing copies inserted by copy-insertion + // (the rematerialization could read overwritten data). + RunAndFilecheckHloRewrite(R"( + HloModule Copy + + ENTRY %main (parameter.1: f32[5]) -> f32[5] { + %parameter.1 = f32[5]{0} parameter(0) + ROOT %copy.3 = f32[5]{0} copy(f32[5]{0} %parameter.1) + })", + FusionWrapper(), + // No change + std::nullopt); +} + TEST_F(FusionWrapperTest, While) { RunAndFilecheckHloRewrite(R"( HloModule While @@ -148,8 +164,8 @@ TEST_F(FusionWrapperTest, While) { })", FusionWrapper(), R"( // CHECK: %wrapped_broadcast_computation {{.*}} { -// CHECK: %param_0.1 = f32[] parameter(0) -// CHECK: ROOT %broadcast.0 = f32[5]{0} broadcast(%param_0.1), dimensions={} +// CHECK: %param_0 = f32[] parameter(0) +// CHECK: ROOT %broadcast.0 = f32[5]{0} broadcast(%param_0), dimensions={} // CHECK: } // CHECK: %body {{.*}} { // CHECK: %parameter.5 = (f32[5]{0}) parameter(0) @@ -161,14 +177,10 @@ TEST_F(FusionWrapperTest, While) { // CHECK: %parameter.12 = (f32[5]{0}) parameter(0) // CHECK: ROOT %constant_1 = pred[] constant(false) // CHECK: } -// CHECK: %wrapped_copy_computation {{.*}} { -// CHECK: %param_0 = f32[5]{0} parameter(0) -// CHECK: ROOT %copy.0 = f32[5]{0} copy(%param_0) -// CHECK: } // CHECK: ENTRY %main {{.*}} { // CHECK: %parameter.1 = f32[5]{0} parameter(0) -// CHECK: %wrapped_copy = f32[5]{0} fusion(%parameter.1), kind=kLoop, calls=%wrapped_copy_computation -// CHECK: %tuple = (f32[5]{0}) tuple(%wrapped_copy) +// CHECK: %copy.3 = f32[5]{0} copy(%parameter.1) +// CHECK: %tuple = (f32[5]{0}) tuple(%copy.3) // CHECK: ROOT %while.19 = (f32[5]{0}) while(%tuple), condition=%cond, body=%body // CHECK: })"); } diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter_test.cc index 228c32a9babf51..6973521e783734 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter_test.cc @@ -39,7 +39,7 @@ class GemmBroadcastFoldingRewriteTest : public GpuCodegenTest { .gpu_compute_capability(); } - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); // These tests test the cuBLAS rewriter so we have to make sure that we use // cuBLAS for them. diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc b/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc index e1d21f1827c6f6..c52c44ab514c32 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc @@ -635,14 +635,14 @@ class Decision { // Returns true if it's profitable to fuse. bool WantToFuse() const { return fusing_decision_.CanFuse(); } - static Decision Accept() { return {FusionDecision(), true}; }; + static Decision Allow() { return {FusionDecision::Allow(), true}; }; - static Decision Decline(std::string_view value) { - return {FusionDecision(value), false}; + static Decision Deny(std::string_view value) { + return {FusionDecision::Forbid(value), false}; } static Decision NotProfitable(std::string_view value) { - return {FusionDecision(value), true}; + return {FusionDecision::Forbid(value), true}; } private: @@ -670,7 +670,7 @@ absl::StatusOr CreateDotFusion( legacy_triton::IsTritonSupportedInstruction(dot, gpu_version); !is_supported) { VLOG(3) << is_supported.Explain(); - return Decision::Decline(is_supported.Explain()); + return Decision::Deny(is_supported.Explain()); } // Verify sparse dot constraints. @@ -729,7 +729,7 @@ absl::StatusOr CreateDotFusion( dot, TritonFusionAnalysis::Scope::LHS) || !analysis.IsBatchDimMinorForInt4Parameter( dot, TritonFusionAnalysis::Scope::RHS)) { - return Decision::Decline( + return Decision::Deny( "Fusion is not possible because the parameter with the type S4 has " "minor batch dimension."); } @@ -740,9 +740,11 @@ absl::StatusOr CreateDotFusion( dot.precision_config().algorithm(); if (algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6 || algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3 || + algorithm == PrecisionConfig::ALG_DOT_BF16_BF16_F32 || + algorithm == PrecisionConfig::ALG_DOT_TF32_TF32_F32_X3 || dot.GetModule()->config().debug_options().xla_gpu_triton_gemm_any() || dot.sparse_operands()) { - return Decision::Accept(); + return Decision::Allow(); } bool is_pure_matmul = true; @@ -757,10 +759,10 @@ absl::StatusOr CreateDotFusion( } return absl::OkStatus(); }); - if (is_pure_matmul) { - return Decision::NotProfitable("Pure Matmul"); - } - return Decision::Accept(); + + if (is_pure_matmul) return Decision::NotProfitable("Pure Matmul"); + + return Decision::Allow(); } // Extracts into fused computations parts of HLO graph including dot() diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc index 38cf1e66bfc551..ebda001f2a7ba1 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc @@ -27,14 +27,14 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/filecheck.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/service/gpu/cublas_padding_requirements.h" #include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/stream_executor/device_description.h" -#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tsl/platform/status_matchers.h" @@ -55,7 +55,7 @@ class GemmFusionTest : public HloTestBase { : HloTestBase(/*verifier_layout_sensitive=*/true, /*allow_mixed_precision_in_hlo_verifier=*/false) {} - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_gpu_triton_gemm_any(false); debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0); @@ -440,7 +440,7 @@ ENTRY e { class GemmFusionLevel2Test : public GemmFusionTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GemmFusionTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_triton_fusion_level(2); return debug_options; @@ -1241,7 +1241,7 @@ ENTRY e { // A test fixture class for testing the threshold for small matrices. class SmallDotGemmFusionTest : public GemmFusionTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GemmFusionTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_gemm_rewrite_size_threshold(100); return debug_options; @@ -1352,30 +1352,21 @@ ENTRY main { EXPECT_FALSE(result.ok()); } -constexpr auto kInt4Dot = R"( -ENTRY e { - p0 = s8[16,16] parameter(0) - p1 = s4[16,16] parameter(1) - p1c = bf16[16,16] convert(p1) - ROOT dot = bf16[16,16] dot(p0, p1c), - lhs_contracting_dims={1}, rhs_contracting_dims={0} -})"; - TEST_F(SmallDotGemmFusionTest, Int4DotIsRewritten) { + constexpr auto kInt4Dot = R"( + ENTRY e { + p0 = s8[16,16] parameter(0) + p1 = s4[16,16] parameter(1) + p1c = bf16[16,16] convert(p1) + ROOT dot = bf16[16,16] dot(p0, p1c), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kInt4Dot)); - module->mutable_config() - .mutable_debug_options() - .set_xla_gpu_enable_triton_gemm_int4(true); EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); } -TEST_F(SmallDotGemmFusionTest, Int4DotIsNotRewritten) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kInt4Dot)); - EXPECT_FALSE(GemmFusion(gpu_version_).Run(module.get()).value()); -} - TEST_F(SmallDotGemmFusionTest, Int4ConcatPlusConvertIsRewritten) { const std::string kInt4Dot = R"( ENTRY main { @@ -1390,9 +1381,6 @@ TEST_F(SmallDotGemmFusionTest, Int4ConcatPlusConvertIsRewritten) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kInt4Dot)); - module->mutable_config() - .mutable_debug_options() - .set_xla_gpu_enable_triton_gemm_int4(true); EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); // Check that the fusion is present and that the lhs is not converted. @@ -1417,9 +1405,6 @@ TEST_F(SmallDotGemmFusionTest, Int4ConvertPlusNegateIsRewritten) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kInt4Dot)); - module->mutable_config() - .mutable_debug_options() - .set_xla_gpu_enable_triton_gemm_int4(true); EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); // Check that the fusion is present and that convert and negation is fused in // it. @@ -1446,9 +1431,6 @@ TEST_F(SmallDotGemmFusionTest, Int4WithMinorBatchDimIsNotRewritten) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kInt4Dot)); - module->mutable_config() - .mutable_debug_options() - .set_xla_gpu_enable_triton_gemm_int4(true); TF_ASSERT_OK_AND_ASSIGN(auto result, GemmFusion(gpu_version_).Run(module.get())); EXPECT_FALSE(result); diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc index 32ed147415b962..8581c5c35ca4cb 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc @@ -46,8 +46,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/layout.h" +#include "xla/layout_util.h" #include "xla/literal.h" #include "xla/literal_util.h" +#include "xla/permutation_util.h" #include "xla/primitive_util.h" #include "xla/service/algorithm_util.h" #include "xla/service/gpu/backend_configs.pb.h" @@ -63,13 +65,13 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" #include "xla/stream_executor/semantic_version.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/ml_dtypes.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/dnn.pb.h" namespace xla { namespace gpu { @@ -362,27 +364,61 @@ std::optional MatchFp8Param(HloInstruction *instr) { // dimension. Keeps the layout the same. HloInstruction *TransposeMatrix(HloInstruction *instr, int64_t contracting_dim, absl::Span batch_dims) { + auto input_shape = instr->shape(); // Identify the dimensional order which describes a transpose of the // contracting and non-contracting dimensions of the GEMM. - std::vector permutation(instr->shape().dimensions_size(), -1); + std::vector permutation(input_shape.dimensions_size(), -1); // Discard the batch dimensions. for (int64_t batch_dim : batch_dims) { permutation[batch_dim] = batch_dim; } // Identify the non-contracting dimension. int non_contracting_dim; - for (int i = 0; i < instr->shape().dimensions_size(); ++i) { + for (int i = 0; i < input_shape.dimensions_size(); ++i) { if (permutation[i] == -1 && contracting_dim != i) { non_contracting_dim = i; } } - permutation[non_contracting_dim] = contracting_dim; - permutation[contracting_dim] = non_contracting_dim; - Shape new_shape = ShapeUtil::PermuteDimensions(permutation, instr->shape()); - *new_shape.mutable_layout() = instr->shape().layout(); - return instr->AddInstruction( - HloInstruction::CreateTranspose(new_shape, instr, permutation)); + if (Layout::Equal()(input_shape.layout(), + LayoutUtil::GetDefaultLayoutForShape(input_shape))) { + permutation[non_contracting_dim] = contracting_dim; + permutation[contracting_dim] = non_contracting_dim; + + Shape new_shape = ShapeUtil::PermuteDimensions(permutation, input_shape); + *new_shape.mutable_layout() = input_shape.layout(); + + return instr->AddInstruction( + HloInstruction::CreateTranspose(new_shape, instr, permutation)); + } + + Shape normalized_input_shape = + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + input_shape); + auto a0 = MakeBitcastHlo(instr, normalized_input_shape); + + std::vector layout_permuation( + input_shape.layout().minor_to_major().begin(), + input_shape.layout().minor_to_major().end()); + absl::c_reverse(layout_permuation); + auto inv_perm = InversePermutation(layout_permuation); + + int new_contracting_dim = inv_perm[contracting_dim]; + int new_non_contracting_dim = inv_perm[non_contracting_dim]; + absl::c_iota(permutation, 0); + std::swap(permutation[new_contracting_dim], + permutation[new_non_contracting_dim]); + + Shape transpose_shape = + ShapeUtil::PermuteDimensions(permutation, a0->shape()); + *transpose_shape.mutable_layout() = a0->shape().layout(); + + HloInstruction *normalized_transpose = instr->AddInstruction( + HloInstruction::CreateTranspose(transpose_shape, a0, permutation)); + + Shape final_shape = ShapeUtil::PermuteDimensions(inv_perm, transpose_shape); + *final_shape.mutable_layout() = input_shape.layout(); + return MakeBitcastHlo(normalized_transpose, final_shape); } // If the bias is a sequence of ops that depend only on broadcasts of @@ -656,23 +692,53 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { break; } case GemmRewriterOptions::DType::kNonFp8Only: { - // Rewrite non-FP8 GEMMs into a cublas or cublasLT Custom Call. - TF_ASSIGN_OR_RETURN( - absl::string_view gemm_custom_call_target, - GetNonFp8GemmCustomCallTarget(*instr, gemm_backend_config)); - const Shape &output_shape = instr->shape(); - HloInstruction *gemm_call = - instr->AddInstruction(HloInstruction::CreateCustomCall( - output_shape, - {instr->mutable_operand(0), instr->mutable_operand(1)}, - gemm_custom_call_target)); - TF_RETURN_IF_ERROR(gemm_call->set_backend_config(gpu_backend_config)); - TF_RETURN_IF_ERROR(ReplaceInstruction(instr, gemm_call)); + if (gemm_backend_config.precision_config().algorithm() == + PrecisionConfig::ALG_DOT_BF16_BF16_F32) { + TF_RETURN_IF_ERROR(TurnDotIntoConvertAndDotForBF16BF16F32( + instr, gemm_backend_config, gpu_backend_config)); + } else { + // Rewrite non-FP8 GEMMs into a cublas or cublasLT Custom Call. + TF_ASSIGN_OR_RETURN( + absl::string_view gemm_custom_call_target, + GetNonFp8GemmCustomCallTarget(*instr, gemm_backend_config)); + const Shape &output_shape = instr->shape(); + HloInstruction *gemm_call = + instr->AddInstruction(HloInstruction::CreateCustomCall( + output_shape, + {instr->mutable_operand(0), instr->mutable_operand(1)}, + gemm_custom_call_target)); + TF_RETURN_IF_ERROR(gemm_call->set_backend_config(gpu_backend_config)); + TF_RETURN_IF_ERROR(ReplaceInstruction(instr, gemm_call)); + } } break; }; return absl::OkStatus(); } + absl::Status TurnDotIntoConvertAndDotForBF16BF16F32( + HloInstruction *instr, GemmBackendConfig &gemm_backend_config, + GpuBackendConfig &gpu_backend_config) { + auto lhs_shape = instr->operand(0)->shape(); + lhs_shape.set_element_type(BF16); + auto lhs_convert = instr->mutable_operand(0)->AddInstruction( + HloInstruction::CreateConvert(lhs_shape, instr->mutable_operand(0))); + auto rhs_shape = instr->operand(1)->shape(); + rhs_shape.set_element_type(BF16); + auto rhs_convert = instr->mutable_operand(1)->AddInstruction( + HloInstruction::CreateConvert(rhs_shape, instr->mutable_operand(1))); + gemm_backend_config.mutable_precision_config()->clear_algorithm(); + TF_ASSIGN_OR_RETURN( + absl::string_view gemm_custom_call_target, + GetNonFp8GemmCustomCallTarget(*instr, gemm_backend_config)); + const Shape &output_shape = instr->shape(); + HloInstruction *gemm_call = + instr->AddInstruction(HloInstruction::CreateCustomCall( + output_shape, {lhs_convert, rhs_convert}, gemm_custom_call_target)); + TF_RETURN_IF_ERROR(gemm_call->set_backend_config(gpu_backend_config)); + TF_RETURN_IF_ERROR(ReplaceInstruction(instr, gemm_call)); + return absl::OkStatus(); + } + absl::Status HandleMultiply(HloInstruction *instr) override { HloInstruction *alpha, *existing_gemm; if (Match(instr, @@ -1083,12 +1149,18 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { // cuBLASLt FP8 GEMM kernels require the scaling factors to be in F32 // format. Set the factors to one when no scaling factors were captured. - Literal one_literal = LiteralUtil::One(F32); - HloInstruction *one = instr->AddInstruction( - HloInstruction::CreateConstant(one_literal.Clone())); std::array mult_scale{a.mult_scale, b.mult_scale}; std::array scales{a.scale, b.scale}, inv_scales, scales_f32; + HloInstruction *one_constant = nullptr; + auto one = [&one_constant, instr]() -> HloInstruction * { + if (!one_constant) { + one_constant = instr->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::One(F32))); + } + return one_constant; + }; + for (int i = 0; i < scales.size(); ++i) { if (scales[i]) { if (!ShapeUtil::IsScalar(scales[i]->shape())) { @@ -1099,7 +1171,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } if (!mult_scale[i]) { inv_scales[i] = instr->AddInstruction(HloInstruction::CreateBinary( - scales[i]->shape(), HloOpcode::kDivide, one, scales[i])); + scales[i]->shape(), HloOpcode::kDivide, one(), scales[i])); } scales_f32[i] = mult_scale[i] ? scales[i] : inv_scales[i]; if (scales_f32[i]->shape().element_type() != F32) { @@ -1107,7 +1179,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { ShapeUtil::MakeScalarShape(F32), scales_f32[i])); } } else { - scales_f32[i] = one; + scales_f32[i] = one(); } } @@ -1249,7 +1321,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { PadShapeToMultipleOf16(instr->shape(), out_batch_dims); std::vector operands_list = { - a.fp8_input, b.fp8_input, scales_f32[0], scales_f32[1], one, one}; + a.fp8_input, b.fp8_input, scales_f32[0], scales_f32[1]}; HloInstruction *new_custom_call = instr->AddInstruction(HloInstruction::CreateCustomCall( @@ -1415,13 +1487,16 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } } - // If necessary, invert the scaling factor of D and convert to F32. + // If necessary, invert the scaling factor of D and convert to F32. When no + // scaling factor was captured, set the factor to one. if (d_scale) { TF_ASSIGN_OR_RETURN(d_scale, InvertAndConvertScalar(d_scale, !mult_scale)); - TF_RETURN_IF_ERROR(existing_gemm->ReplaceOperandWith( - gemm_backend_config.beta() == 0.0 ? 5 : 6, d_scale)); + } else { + d_scale = instr->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::One(F32))); } + existing_gemm->AppendOperand(d_scale); // If present, elide the calculation of the maximum of the absolute values // of the result of the GEMM. @@ -1887,12 +1962,11 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { if (!absl::c_linear_search(supported_type, output_type)) return false; TF_ASSIGN_OR_RETURN(const se::blas::DataType output_dtype, se::gpu::AsBlasDataType(output_type)); - // TODO(tdanyluk): Investigate why don't we use the actual precision (and - // algorithm) here? Why do we use the default? - TF_ASSIGN_OR_RETURN(const se::blas::ComputationType compute_type, - se::gpu::GetBlasComputationType( - PrecisionConfig::ALG_UNSET, a_dtype, output_type, - stream_executor::blas::kDefaultComputePrecision)); + TF_ASSIGN_OR_RETURN( + const se::blas::ComputationType compute_type, + se::gpu::GetBlasComputationType( + instr.precision_config().algorithm(), a_dtype, output_type, + stream_executor::blas::kDefaultComputePrecision)); se::blas::DataType scale_type = se::gpu::GetScaleType(output_dtype, compute_type); @@ -1983,7 +2057,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { backend_config.precision_config().operand_precision()); const PrecisionConfig::Algorithm algorithm = backend_config.precision_config().algorithm(); - if (!algorithm_util::IsSupportedByCublasOrCublasLt(algorithm)) return false; + if (!algorithm_util::IsSupportedByCublasOrCublasLt(algorithm, gpu_version_)) + return false; TF_ASSIGN_OR_RETURN( const se::blas::ComputationType compute_type, diff --git a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc index 5ae8571d626704..9e691d859a316c 100644 --- a/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc @@ -94,7 +94,7 @@ class GemmRewriteTest : public GpuCodegenTest { } } - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); // These tests test the cuBLAS rewriter so we have to make sure that we use // cuBLAS for them. @@ -288,7 +288,7 @@ class ParameterizedGemmRewriteTest replacements_[kCustomCallTargetPlaceholder] = kUsingCublasLt ? "__cublas$lt$matmul" : "__cublas$gemm"; } - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_enable_cublaslt(GetParam()); debug_options.set_xla_gpu_enable_triton_gemm(false); @@ -1447,7 +1447,7 @@ INSTANTIATE_TEST_SUITE_P(CublasTestsBothLegacyAndLt, // A test fixture class for tests which are specific to legacy cublas class LegacyCublasGemmRewriteTest : public GemmRewriteTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_enable_triton_gemm(false); debug_options.set_xla_gpu_enable_cublaslt(false); @@ -2204,7 +2204,7 @@ ENTRY test { // A test fixture class for tests which are specific to cublasLt class CublasLtGemmRewriteTest : public GemmRewriteTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_enable_cublaslt(true); debug_options.set_xla_gpu_enable_triton_gemm(false); @@ -4951,11 +4951,11 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDF8) { )"; if (IsRocm() && GetToolkitVersion() < se::SemanticVersion{6, 2, 0}) { checks.append( - R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), + R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]]), )"); } else { checks.append( - R"(; CHECK-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), + R"(; CHECK-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]]), )"); } checks.append( @@ -5010,7 +5010,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDMatrixBiasF8) { ; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[C1:[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), +; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5033,6 +5033,58 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDMatrixBiasF8) { )"); } +TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDColMajorLhsF8) { + const char* hlo_text = R"( +HloModule test + ENTRY test { + x = <>[2,64,32]{1,2,0} parameter(0) + y = <>[2,32,16]{2,1,0} parameter(1) + x_scale = f32[] parameter(2) + y_scale = f32[] parameter(3) + dq_scale = f32[] multiply(x_scale, y_scale) + dq_scale_bcast = f32[2,64,16] broadcast(dq_scale), dimensions={} + out.0 = f32[2,64,16] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0} + ROOT out = f32[2,64,16] multiply(out.0, dq_scale_bcast) + } +)"; + + CheckFp8IfSupported(hlo_text); + RunAndFilecheckHloRewrite( + hlo_text, + GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), + R"( +; CHECK-LABEL: ENTRY %test ({{.*}}: <>[2,64,32], {{.*}}: <>[2,32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[2,64,16] { +; CHECK-NEXT: [[P0:%[^ ]+]] = <>[2,64,32]{1,2,0} parameter(0) +; CHECK-NEXT: [[P0_BT:%[^ ]+]] = <>[2,32,64]{2,1,0} bitcast([[P0]]) +; CHECK-NEXT: [[P0_TR:%[^ ]+]] = <>[2,64,32]{2,1,0} transpose([[P0_BT]]), dimensions={0,2,1} +; CHECK-NEXT: [[P0_BT1:%[^ ]+]] = <>[2,32,64]{1,2,0} bitcast([[P0_TR]]) +; CHECK-NEXT: [[P1:%[^ ]+]] = <>[2,32,16]{2,1,0} parameter(1) +; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[2,16,32]{2,1,0} transpose([[P1]]), dimensions={0,2,1} +; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) +; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) +; CHECK-NEXT: [[DQ:%[^ ]+]] = f32[] multiply([[P2]], [[P3]]) +; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,64,16]{2,1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BT1]], [[P1_TRANSPOSE]], [[DQ]], [[C1]]), +; CHECK: custom_call_target="__cublas$lt$matmul$f8", +; CHECK: backend_config={ +; CHECK-DAG: "alpha_real":1 +; CHECK-DAG: "alpha_imag":0 +; CHECK-DAG: "beta":0 +; CHECK-DAG: "dot_dimension_numbers":{ +; CHECK-DAG: "lhs_contracting_dimensions":["1"] +; CHECK-DAG: "rhs_contracting_dimensions":["2"] +; CHECK-DAG: "lhs_batch_dimensions":["0"] +; CHECK-DAG: "rhs_batch_dimensions":["0"] +; CHECK-DAG: } +; CHECK-DAG: "precision_config":{ +; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"] +; CHECK-DAG: } +; CHECK-DAG: "epilogue":"DEFAULT" +; CHECK: } + )"); +} + TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8) { const char* hlo_text = R"( HloModule test @@ -5065,8 +5117,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8) { ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5122,8 +5173,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDPaddedF8) { ; CHECK-NEXT: [[P1_TRANSPOSE_PADDED:%[^ ]+]] = <>[32,32]{1,0} pad([[P1_TRANSPOSE]], [[C1]]) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[C4:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (f32[16,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_TRANSPOSE_PADDED]], [[P2]], [[P3]], [[C4]], /*index=5*/[[C4]]), +; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (f32[16,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_TRANSPOSE_PADDED]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5206,7 +5256,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDWithConvertF8) { ; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]], [[C1]], /*index=5*/[[C1]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C1]], [[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5270,8 +5320,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDUnaryOpsF8) { ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_U3]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C2]], /*index=5*/[[C2]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_U3]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5329,7 +5378,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[P1:%[^ ]+]] = <>[32,16]{1,0} parameter(1) ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_U3]], [[P1_TRANSPOSE]], [[C2]], [[C2]], [[C2]], /*index=5*/[[C2]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_U3]], [[P1_TRANSPOSE]], [[C2]], [[C2]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5390,8 +5439,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDDynamicSliceF8) { ; CHECK-NEXT: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[DYN_SLICE]], [[P1]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[DYN_SLICE]], [[P1]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5457,8 +5505,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDSelectF8) { ; CHECK-NEXT: [[SELECT:%[^ ]+]] = <>[16,32]{1,0} select([[P4]], [[P1]], [[C0_CONVERT]]) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[SELECT]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[SELECT]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5542,8 +5589,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, BatchedScaledABUnscaledDF8) { ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[10,16,32]{2,1,0} transpose([[P1]]), dimensions={0,2,1} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[10,16,16]{2,1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[10,16,16]{2,1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5599,8 +5645,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABAlphaDF8) { ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":3 @@ -5656,8 +5701,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDReluActivationF8) { ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -5732,15 +5776,14 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[XS:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = bf16[] parameter(3) ; CHECK-NEXT: [[XS1:%[^ ]+]] = f32[] convert([[P3]]) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) )"; if (IsRocm() && GetToolkitVersion() < se::SemanticVersion{6, 2, 0}) { checks += - R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]), + R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]]), )"; } else { checks += R"(; CHECK-NEXT: [[B:%[^ ]+]] = bf16[16]{0} parameter(4) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (bf16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]], [[B]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (bf16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[B]]), )"; } checks += R"(; CHECK: custom_call_target="__cublas$lt$matmul$f8", @@ -5832,15 +5875,14 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[XS:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = bf16[] parameter(3) ; CHECK-NEXT: [[XS1:%[^ ]+]] = f32[] convert([[P3]]) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) )"; if (IsRocm() && GetToolkitVersion() < se::SemanticVersion{6, 2, 0}) { checks += - R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]), + R"(; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]]), )"; } else { checks += - R"(; CHECK-NEXT: [[OUT:%[^ ]+]] = (bf16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]], [[C1]], /*index=5*/[[C1]]), + R"(; CHECK-NEXT: [[OUT:%[^ ]+]] = (bf16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[XS]], [[XS1]]), )"; } checks += R"(; CHECK: custom_call_target="__cublas$lt$matmul$f8", @@ -5944,8 +5986,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasF8) { ; CHECK: [[C0:%[^ ]+]] = f32[16,16]{1,0} add({{.*}}) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[P2]], [[P3]], /*index=5*/[[C1]], [[C1]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: output_to_operand_aliasing={ ; CHECK-SAME: {0}: (2, {}) @@ -6010,8 +6051,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasPaddedF8) { ; CHECK-NEXT: [[P2_PADDED:%[^ ]+]] = f32[16,16]{1,0} pad([[P2]], [[C2]]), padding=0_2x0_2 ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) -; CHECK-NEXT: [[C3:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_TRANSPOSE_PADDED]], [[P2_PADDED]], [[P3]], [[P4]], /*index=5*/[[C3]], [[C3]]), +; CHECK-NEXT: [[DOT_TUPLE:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_TRANSPOSE_PADDED]], [[P2_PADDED]], [[P3]], [[P4]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6068,8 +6108,9 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledDF8) { ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P2_INV:%[^ ]+]] = f32[] divide([[C0]], [[P2]]) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]), -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]), +; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) +; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C2]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6118,7 +6159,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DF8) { ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P2_INV:%[^ ]+]] = f32[] divide([[C0]], [[P2]]) ; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]], [[C1]], /*index=5*/[[C1]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_INV]], [[C1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6165,7 +6206,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABInvScaledF32DF8) { ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[C0]], [[C0]], /*index=5*/[[C0]]), +; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[C0]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6216,7 +6257,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DMatrixBiasF8) { ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[16,16]{1,0} parameter(2) ; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[C0]], [[C0]], /*index=5*/[[C0]], [[C0]]), +; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[C0]], [[C0]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6281,12 +6322,11 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDF8) { ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) -; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), +; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[P4_INV]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6391,12 +6431,11 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) { ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) -; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), +; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[P4_INV]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6473,11 +6512,10 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasWithDAmaxF8) { ; CHECK: [[C0:%[^ ]+]] = f16[16,16]{1,0} add({{.*}}) ; CHECK-NEXT: [[P2:%[^ ]+]] = f16[] parameter(3) ; CHECK: [[P3:%[^ ]+]] = f16[] parameter(4) -; CHECK: [[C1:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX: [[P4:%[^ ]+]] = f16[] parameter(5) -; CHECK-PTX: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[C1]], [[DUMMY2:%[^ ]+]]), +; CHECK-PTX: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[DUMMY2:%[^ ]+]]), ; CHECK-NOT: output_to_operand_aliasing -; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]], /*index=5*/[[C1]], [[DUMMY2:%[^ ]+]]), +; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[C0]], [[DUMMY0:%[^ ]+]], [[DUMMY1:%[^ ]+]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6544,14 +6582,13 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDVectorBiasF8) { ; CHECK-NEXT: [[CV:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4) ; CHECK-NEXT: [[CV1:%[^ ]+]] = f32[] convert([[P3]]) -; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) +; CHECK-NEXT: [[VB:%[^ ]+]] = f16[16]{0} parameter(2) ; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f16[] constant(1) ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f16[] parameter(5) ; CHECK-PTX-NEXT: [[DV:%[^ ]+]] = f16[] divide([[C2]], [[P4]]) ; CHECK-PTX-NEXT: [[CV2:%[^ ]+]] = f32[] convert([[DV]]) -; CHECK-NEXT: [[VB:%[^ ]+]] = f16[16]{0} parameter(2) -; CHECK-PTX: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[CV2]], [[VB]]), -; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]), +; CHECK-PTX: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[VB]], /*index=5*/[[CV2]]), +; CHECK-GCN: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[VB]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6608,10 +6645,9 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF32VectorBiasF8) { ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4) -; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) ; CHECK-NEXT: [[VB:%[^ ]+]] = f32[16]{0} parameter(2) ; CHECK-NEXT: [[VBC:%[^ ]+]] = bf16[16]{0} convert([[VB]]) -; CHECK: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]], [[VBC]]), +; CHECK: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[VBC]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6671,9 +6707,8 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[CV:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4) ; CHECK-NEXT: [[CV1:%[^ ]+]] = f32[] convert([[P3]]) -; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) ; CHECK-NEXT: [[VB:%[^ ]+]] = f16[16]{0} parameter(2) -; CHECK : ROOT [[OUT:%[^ ]+]] = f16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[C]], /*index=5*/[[C]], [[VB]]), +; CHECK : ROOT [[OUT:%[^ ]+]] = f16[16,16]{1,0} custom-call([[P0]], [[P1_TRANSPOSE]], [[CV]], [[CV1]], [[VB]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6745,10 +6780,9 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) { ; CHECK-NEXT: [[P2_CV:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4) ; CHECK-NEXT: [[P3_CV:%[^ ]+]] = f32[] convert([[P3]]) -; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) ; CHECK-NEXT: [[B:%[^ ]+]] = f32[32]{0} parameter(2) ; CHECK-NEXT: [[B_F16:%[^ ]+]] = f16[32]{0} convert([[B]]) -; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f16[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[P2_CV]], [[P3_CV]], [[C]], /*index=5*/[[C]], [[B_F16]]), +; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f16[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[P2_CV]], [[P3_CV]], [[B_F16]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6829,12 +6863,11 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[P2_CV:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(4) ; CHECK-NEXT: [[P3_CV:%[^ ]+]] = f32[] convert([[P3]]) -; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) ; CHECK-NEXT: [[B:%[^ ]+]] = f32[31]{0} parameter(2) ; CHECK-NEXT: [[B_F16:%[^ ]+]] = f16[31]{0} convert([[B]]) ; CHECK-NEXT: [[C3:%[^ ]+]] = f16[] constant(0) ; CHECK-NEXT: [[P2_PAD:%[^ ]+]] = f16[32]{0} pad([[B_F16]], [[C3]]), padding=0_1 -; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f16[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PAD]], [[P1_PAD]], [[P2_CV]], [[P3_CV]], [[C]], /*index=5*/[[C]], [[P2_PAD]]), +; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f16[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PAD]], [[P1_PAD]], [[P2_CV]], [[P3_CV]], [[P2_PAD]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6907,8 +6940,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) { ; CHECK-NEXT: [[B_BITCAST:%[^ ]+]] = f32[64,32]{1,0} bitcast([[B]]) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4) -; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[B_BITCAST]], [[P2]], [[P3]], /*index=5*/[[C]], [[C]]), +; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[64,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BITCAST]], [[P1_TRANSPOSE]], [[B_BITCAST]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -6989,8 +7021,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[P2_PADDED:%[^ ]+]] = f32[48,32]{1,0} pad([[B_BITCAST]], [[C3]]), padding=0_3x0_1 ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4) -; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[48,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2_PADDED]], [[P2]], [[P3]], /*index=5*/[[C]], [[C]]), +; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[48,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_PADDED]], [[P1_PADDED]], [[P2_PADDED]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7055,8 +7086,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[32,16]{1,0} transpose([[P1]]), dimensions={1,0} ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(3) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(4) -; CHECK-NEXT: [[C:%[^ ]+]] = f32[] constant(1) -; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[48,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]), +; CHECK-NEXT: [[GEMM_TUPLE:%[^ ]+]] = (f32[48,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7118,8 +7148,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllGatherF8) { ; CHECK: [[P1_TRANSPOSE:%[^ ]+]] = <>[32,64]{1,0} transpose([[AG1]]), dimensions={1,0} ; CHECK: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK: [[C:%[^ ]+]] = f32[] constant(1) -; CHECK: [[GEMM_TUPLE:%[^ ]+]] = (f32[16,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AG]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]), +; CHECK: [[GEMM_TUPLE:%[^ ]+]] = (f32[16,32]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AG]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7176,8 +7205,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllToAllF8) { ; CHECK: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1) ; CHECK: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK: [[C:%[^ ]+]] = f32[] constant(1) -; CHECK: [[GEMM:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AA]], [[P1]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]), +; CHECK: [[GEMM:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AA]], [[P1]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7234,8 +7262,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK: [[P1:%[^ ]+]] = <>[16,32]{1,0} parameter(1) ; CHECK: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK: [[C:%[^ ]+]] = f32[] constant(1) -; CHECK: [[GEMM:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AA]], [[P1]], [[P2]], [[P3]], [[C]], /*index=5*/[[C]]), +; CHECK: [[GEMM:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[AA]], [[P1]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7297,8 +7324,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[CV0:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(5) ; CHECK-NEXT: [[CV1:%[^ ]+]] = f32[] convert([[P3]]) -; CHECK: [[C1:%[^ ]+]] = f32[] constant(1) -; CHECK: [[GEMMOUT_TUPLE:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[MB]], [[CV0]], [[CV1]], /*index=5*/[[C1]], [[C1]]), +; CHECK: [[GEMMOUT_TUPLE:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[MB]], [[CV0]], [[CV1]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7373,12 +7399,11 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) { ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) -; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), +; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[P4_INV]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7454,13 +7479,12 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[P2_CONVERT:%[^ ]+]] = f32[] convert([[P2]]) ; CHECK-NEXT: [[P3:%[^ ]+]] = f16[] parameter(3) ; CHECK-NEXT: [[P3_CONVERT:%[^ ]+]] = f32[] convert([[P3]]) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f16[] constant(1) ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f16[] parameter(4) ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f16[] divide([[C2]], [[P4]]) ; CHECK-PTX-NEXT: [[P4_INV_CONVERT:%[^ ]+]] = f32[] convert([[P4_INV]]) -; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]], /*index=5*/[[P4_INV_CONVERT]]), -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[C1]], /*index=5*/[[C1]]), +; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]], [[P4_INV_CONVERT]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f16[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2_CONVERT]], [[P3_CONVERT]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7534,12 +7558,11 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]) ; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX-NEXT: [[C2:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX-NEXT: [[P4:%[^ ]+]] = f32[] parameter(4) ; CHECK-PTX-NEXT: [[P4_INV:%[^ ]+]] = f32[] divide([[C2]], [[P4]]) -; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[P4_INV]]), -; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[C1]], /*index=5*/[[C1]]), +; CHECK-PTX-NEXT: [[OUT:%[^ ]+]] = (<>[16,16]{1,0}, f32[], s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]], [[P4_INV]]), +; CHECK-GCN-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0]], [[P1_TRANSPOSE]], [[P2]], [[P3]]), ; CHECK: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ ; CHECK-DAG: "alpha_real":1 @@ -7809,7 +7832,6 @@ TEST_P(ParameterizedFp8GemmRewriteTest, FnuzTypeF8) { ; CHECK-GCN-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <>[16,32]{1,0} transpose([[P1]]) ; CHECK-GCN-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2) ; CHECK-GCN-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3) -; CHECK-GCN-NEXT: [[C1:%[^ ]+]] = f32[] constant(1) ; CHECK-PTX: custom_call_target="<>", ; CHECK-GCN: custom_call_target="__cublas$lt$matmul$f8", ; CHECK: backend_config={ @@ -8024,7 +8046,7 @@ class GemmRewriteAllocationTest : public GpuCodegenTest { ASSERT_EQ(allocations.size(), expected_number_of_allocations); } - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); // Make sure the rewriter does not skip the rewrite for being too small. debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0); @@ -8057,7 +8079,7 @@ ENTRY AddDotsFunc { class SmallDotGemmRewriteTest : public GemmRewriteTest { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = GemmRewriteTest::GetDebugOptionsForTest(); debug_options.set_xla_gpu_gemm_rewrite_size_threshold(100); return debug_options; @@ -8105,6 +8127,34 @@ ENTRY DotFunc { )"); } +TEST_F(SmallDotGemmRewriteTest, RewriteForALG_BF16_BF16_F32) { + if (!HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) { + GTEST_SKIP() + << "There is no autotuning starting with the Nvidia Ampere generation"; + } + + const char* hlo_text = R"( + HloModule RewriteForALG_BF16_BF16_F32 + + ENTRY DotFunc { + x = f32[1024,1024] parameter(0) + y = f32[1024,1024] parameter(1) + ROOT out = f32[1024,1024] dot(x, y), + algorithm=dot_bf16_bf16_f32, + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + } + )"; + + MatchOptimizedHlo(hlo_text, + R"( +; CHECK-LABEL: ENTRY %DotFunc ({{.*}}: f32[1024,1024], {{.*}}: f32[1024,1024]) -> f32[1024,1024] { +; CHECK-NEXT: [[P0:%[^ ]+]] = f32[1024,1024]{1,0} parameter(0) +; CHECK-NEXT: [[P1:%[^ ]+]] = f32[1024,1024]{1,0} parameter(1) +; CHECK: [[GEMM:%[^ ]+]] = {{.*}} custom-call({{.*}}), custom_call_target="__cublas$gemm", {{.*}},"algorithm":"ALG_UNSET" +)"); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc index befe869ac072df..bb09aa02c77de5 100644 --- a/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc @@ -95,7 +95,8 @@ std::vector FindAndSortFusionCandidates( // Find out the input fusion instructions whose only consumer is `consumer`. // This guarantees that fusing these candidates will never create cycles, as // there is no back edge. - if (IsInputFusibleReduction(*predecessor) && + if (!predecessor->IsCustomFusion() && + IsInputFusibleReduction(*predecessor) && IsConsumerTheOnlyNonRootUser(*predecessor, *consumer)) { if (fusion_instr_set.insert(predecessor).second) { fusion_instrs.push_back(predecessor); diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion_test.cc index 5fc1a54acd8d53..dc8f5f3bfa1b5f 100644 --- a/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion_test.cc @@ -265,6 +265,38 @@ TEST_F(HorizontalInputFusionTest, NonfusionInstrs) { GmockMatch(m::Tuple(m::Reduce(), m::Reduce()))); } +TEST_F(HorizontalInputFusionTest, DoesNotFuseCustomFusions) { + auto module = ParseAndReturnVerifiedModule(R"( +max { + p0 = f16[] parameter(0) + p1 = f16[] parameter(1) + ROOT max = f16[] maximum(p0, p1) +} + +triton_a { + p = f16[128,256] parameter(0) + c = f16[] constant(0) + ROOT n = f16[128] reduce(p, c), dimensions={1}, to_apply=max +} + +triton_b { + p = f16[128,256] parameter(0) + c = f16[] constant(0) + ROOT n = f16[128] reduce(p, c), dimensions={1}, to_apply=max +} + + ENTRY entry_computation { + p = f16[128,256] parameter(0) + fa = f16[128] fusion(p), kind=kCustom, calls=triton_a + fb = f16[128] fusion(p), kind=kCustom, calls=triton_b + ROOT tuple = (f16[128], f16[128]) tuple(fa, fb) + } +)") + .value(); + + EXPECT_FALSE(horizontal_input_fusion_.Run(module.get()).value()); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc index 0a3d705103c416..c45061805c68f5 100644 --- a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc @@ -37,10 +37,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/transforms/simplifiers/sub_byte_normalization.h" #include "xla/layout_util.h" +#include "xla/primitive_util.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/hlo_creation_utils.h" -#include "xla/service/sub_byte_normalization.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" @@ -307,7 +308,6 @@ void HorizontalLoopFusionImpl::FusionCandidates::Initialize( << " other fusion candidates, instr: " << instr->ToString(); continue; } else { - VLOG(2) << "Find a fusion candidate " << instr->ToString(); // Encapsulate it into a fusion computation for unified representation // for later processing. fusible_instrs_.push_back(instr); @@ -346,6 +346,9 @@ HorizontalLoopFusionImpl::FusionCandidates::GetNextSpanOfFusions() { return absl::Span(); } + // CUDA has a parameter size limit of ~4k bytes. + constexpr int64_t kMaxCudaParamSize = 4000; + // Fusing too many computations at a time may not be easily profitable and // may increase compile time due to large kernels. Set a limit to it. // From profiling results, we found an issue that large fused horizontal @@ -366,23 +369,21 @@ HorizontalLoopFusionImpl::FusionCandidates::GetNextSpanOfFusions() { }(); size_t left = pos_; - size_t right = pos_ + 1; - size_t first_output_size = GetOutputSizeOfFusible(*fusible_instrs_[left]); - PrimitiveType first_output_type = - GetUniqueOutputTypeOfFusible(*fusible_instrs_[left]); - // CUDA has a parameter size limit of ~4k bytes. - constexpr int64_t kMaxCudaParamSize = 4000; - size_t accum_io_size = 0; + size_t right = pos_; size_t accum_num_outputs = 0; + size_t accum_io_size = 0; + for (; right < fusible_instrs_.size(); ++right) { - PrimitiveType cur_output_type = - GetUniqueOutputTypeOfFusible(*fusible_instrs_[right]); - if (first_output_type != cur_output_type) { + if (GetUniqueOutputTypeOfFusible(*fusible_instrs_[left]) != + GetUniqueOutputTypeOfFusible(*fusible_instrs_[right])) { // Cannot fuse computations who have multiple output types. + VLOG(2) << "different multiple output types"; break; } - if (first_output_size != GetOutputSizeOfFusible(*fusible_instrs_[right])) { + if (GetOutputSizeOfFusible(*fusible_instrs_[left]) != + GetOutputSizeOfFusible(*fusible_instrs_[right])) { // Cannot fuse computations who have different numbers of outputs. + VLOG(2) << "different number of outputs"; break; } if (GetInstrCountOfFusible(*fusible_instrs_[left]) != @@ -391,6 +392,7 @@ HorizontalLoopFusionImpl::FusionCandidates::GetNextSpanOfFusions() { // introduce control divergence. This is a very simple heuristic to avoid // fusing computations with too much discrepancy and we may improve it // when the needs arise. + VLOG(2) << "different instruction count"; break; } if (!sliced_input_fusion_ && @@ -399,19 +401,30 @@ HorizontalLoopFusionImpl::FusionCandidates::GetNextSpanOfFusions() { GetOutputsOfFusible(*fusible_instrs_[right])[0]->shape())) { // This is for fusing into kLoop type kernel, so we requires that each // fusion operand have the same shape + VLOG(2) << "different output shape"; break; } size_t num_outputs = GetOutputSizeOfFusible(*fusible_instrs_[right]); accum_num_outputs += num_outputs; if (accum_num_outputs >= kMaxFusionBatchSize) { // Hit max fusion batch size. + VLOG(2) << "hit max fusion batch size: " << accum_num_outputs; break; } accum_io_size += fusible_instrs_.at(right)->operand_count() + num_outputs; if (accum_io_size * 8 >= kMaxCudaParamSize) { + VLOG(2) << "hit max cuda param size: " << accum_io_size; break; } } + + // If right was not incremented, it means that the `left` instruction already + // exceeds one of the limits. We can't do anything about that fusion here, so + // we return a span of one instruction. + if (left == right) { + ++right; + } + VLOG(2) << "horizontal fuse get instruction span with " << (right - left) << " instructions for sliced_input_fusion=" << sliced_input_fusion_ << " fusion"; @@ -438,6 +451,7 @@ absl::StatusOr HorizontalLoopFusionImpl::FuseConsumerOperands( // `Fuse()`. std::vector fusion_instrs; for (HloInstruction* instr : fusibles) { + VLOG(2) << "next candidate: " << instr->ToString(); if (instr->opcode() == HloOpcode::kFusion) { fusion_instrs.push_back(instr); } else { diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc index d3fb82e9d4b05f..b40922b72f9d6e 100644 --- a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc @@ -26,12 +26,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/pass/hlo_pass_fix.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/transforms/instruction_fusion.h" -#include "xla/service/hlo_dce.h" -#include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" diff --git a/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc b/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc index 5e32f2ec0c2ee1..bfd8c5bbb6b0a9 100644 --- a/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc @@ -91,31 +91,34 @@ FusionDecision GpuInstructionFusion::ShouldFuseInexpensiveChecks( // Output fusions are not currently supported on GPUs. if (producer->opcode() == HloOpcode::kFusion) { - return "the producer is a fusion"; + return FusionDecision::Forbid("the producer is a fusion"); } if (consumer->IsCustomFusion()) { - return "the consumer is a custom fusion"; + return FusionDecision::Forbid("the consumer is a custom fusion"); } // Cost condition: not fuse (simple, expensive producers) and (consumers who // reuse operand elements). if (is_expensive(*producer) && ReusesOperandElements(consumer, operand_index)) { - return "the producer is expensive, and the consumer reuses inputs"; + return FusionDecision::Forbid( + "the producer is expensive, and the consumer reuses inputs"); } // Do not fuse into fusions if the resulting kernel would suffer from // uncoalesced reads due to a transposed memory access pattern. if (IsInputFusibleReduction(*consumer) && IsPhysicallyTransposing(*producer)) { - return "fusing the producer would break read coalescing"; + return FusionDecision::Forbid( + "fusing the producer would break read coalescing"); } RETURN_IF_NOT_FUSIBLE(IsProducerConsumerFusible(*producer, *consumer)); if (CreatesHeavyComputation(*producer, *consumer)) { - return "the fusion would create a heavy computation"; + return FusionDecision::Forbid( + "the fusion would create a heavy computation"); } return InstructionFusion::ShouldFuse(consumer, operand_index); @@ -133,7 +136,7 @@ FusionDecision GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, /*is_consumer_producer_fusion=*/true)); if (consumer->opcode() != HloOpcode::kFusion) { - return {}; + return FusionDecision::Allow(); } // Also check that our emitter can handle the fusion node. We currently can @@ -149,9 +152,10 @@ FusionDecision GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, FusionNodeIndexingEvaluation(consumer)); } if (fusion_node_evaluations_.at(consumer).CodeDuplicationTooHigh(producer)) { - return "the fusion would result in an overly large code duplication"; + return FusionDecision::Forbid( + "the fusion would result in an overly large code duplication"); } - return {}; + return FusionDecision::Allow(); } HloInstruction::FusionKind GpuInstructionFusion::ChooseKind( diff --git a/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc index 9af3e7e04d4d47..d4c44f6b78df7a 100644 --- a/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc +++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc @@ -39,6 +39,7 @@ limitations under the License. #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/matmul_indexing_utils.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/reduction_utils.h" #include "xla/service/gpu/stream_executor_util.h" @@ -363,8 +364,11 @@ absl::Status GpuLayoutAssignment::AddBackendConstraints( output_shape.dimensions_size() == 2 && lhs_shape.dimensions_size() == 2 && rhs_shape.dimensions_size() == 2); + bool is_fp8_to_fp8 = + (lhs_shape.element_type() == PrimitiveType::F8E4M3FN && + rhs_shape.element_type() == PrimitiveType::F8E4M3FN); - if (is_s8_to_s32 || + if (is_s8_to_s32 || is_fp8_to_fp8 || (is_bf16_to_bf16 && debug_options.xla_gpu_ensure_minor_dot_contraction_dims())) { TF_RETURN_IF_ERROR(SetOperandMajorToMinorLayout( diff --git a/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc b/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc index 4dbd453e1d4850..a0324839cd014c 100644 --- a/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc @@ -25,11 +25,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/service/computation_layout.h" #include "xla/service/gpu/stream_executor_util.h" -#include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" @@ -655,6 +655,35 @@ ENTRY main { LayoutUtil::MakeLayout({1, 3, 2, 0}).minor_to_major()); } +TEST_F(LayoutAssignmentTest, AutoLayoutE4M3ContractingMinorFirst) { + const char* hlo = R"( + + HloModule jit_dot_general_f8e4m3fn + + ENTRY main { + p0 = f8e4m3fn[128,5120] parameter(0) + p1 = f8e4m3fn[5120,10240] parameter(1) + ROOT dot = f32[128,10240] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr m, + ParseAndReturnUnverifiedModule( + hlo, {}, HloParserOptions().set_fill_missing_layouts(false))); + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape(), + /*ignore_layouts=*/false); + GpuLayoutAssignment layout_assignment( + &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true)); + EXPECT_THAT( + m->entry_computation()->root_instruction(), + GmockMatch( + m::Dot(m::Parameter(0).WithShape(F8E4M3FN, {128, 5120}, {1, 0}), + m::Parameter(1).WithShape(F8E4M3FN, {5120, 10240}, {0, 1})) + .WithShape(F32, {128, 10240}, {1, 0}))); +} + TEST_F(LayoutAssignmentTest, VariadicReduceSameOperandLayout) { const char* module_str = R"( HloModule variadic_reduce diff --git a/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc index 04456d8131ac76..e83c2b9a1ae7ae 100644 --- a/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc @@ -30,8 +30,8 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "xla/debug_options_flags.h" +#include "xla/hlo/analysis/hlo_dfs_reachability.h" #include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/hlo/ir/hlo_dfs_reachability.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -86,14 +86,16 @@ const HloSliceInstruction* FindUniqueSlice(const HloInstruction* parent, FusionDecision ParameterSlicesAreNonOverlapping(const HloInstruction& instr1, const HloInstruction& instr2, const HloInstruction* parent) { - if (parent->shape().IsTuple()) return {}; + if (parent->shape().IsTuple()) return FusionDecision::Allow(); // Allow MOF if the parameter is small, even if there's no overlap. 1024 bytes // were arbitrarily chosen as the threshold. - if (ShapeUtil::ByteSizeOfElements(parent->shape()) < 1024) return {}; + if (ShapeUtil::ByteSizeOfElements(parent->shape()) < 1024) { + return FusionDecision::Allow(); + } const HloSliceInstruction* slice1 = FindUniqueSlice(parent, &instr1); const HloSliceInstruction* slice2 = FindUniqueSlice(parent, &instr2); - if (!slice1 || !slice2) return {}; + if (!slice1 || !slice2) return FusionDecision::Allow(); // TODO(jreiffers): Check strides as well. auto& starts1 = slice1->slice_starts(); @@ -104,10 +106,10 @@ FusionDecision ParameterSlicesAreNonOverlapping(const HloInstruction& instr1, for (int64_t dim = 0; dim < parent->shape().rank(); ++dim) { bool overlap = starts1[dim] < limits2[dim] && starts2[dim] < limits1[dim]; if (!overlap) { - return "slices are non-overlapping"; + return FusionDecision::Forbid("slices are non-overlapping"); } } - return {}; + return FusionDecision::Allow(); } FusionDecision LegalToFuse(const HloInstruction& instr1, @@ -125,7 +127,7 @@ FusionDecision LegalToFuse(const HloInstruction& instr1, (instr2.opcode() == HloOpcode::kFusion && instr2.fused_expression_root()->opcode() == HloOpcode::kDynamicUpdateSlice)) { - return "can't fuse multiple DUSs"; + return FusionDecision::Forbid("can't fuse multiple DUSs"); } // Do this check last, as it may be expensive. @@ -175,11 +177,11 @@ FusionDecision OperandReachableFromProducer( << "Reachability map is incomplete. This should never " "happen."; if (&producer != operand && reachability.IsReachable(&producer, operand)) { - return { - absl::StrCat(producer.name(), " would introduce a cycle when fused")}; + return FusionDecision::Forbid( + absl::StrCat(producer.name(), " would introduce a cycle when fused")); } } - return {}; + return FusionDecision::Allow(); } FusionDecision ProducerCandidateIsFusible( @@ -188,7 +190,8 @@ FusionDecision ProducerCandidateIsFusible( const se::DeviceDescription& device_info, GpuHloCostAnalysis* cost_analysis) { if (!IsFusibleAsMultiOutputFusionRoot(consumer)) { - return "consumer not eligible as multi-output fusion root."; + return FusionDecision::Forbid( + "consumer not eligible as multi-output fusion root."); } RETURN_IF_NOT_FUSIBLE( @@ -202,7 +205,7 @@ FusionDecision ProducerCandidateIsFusible( /*is_consumer_producer_fusion=*/false, fusion_info_cache)); if (cost_analysis->ProducerConsumerMergedTooLarge(producer, consumer)) { - return "will generate too large IR"; + return FusionDecision::Forbid("will generate too large IR"); } GpuPerformanceModel::RunTimes t = GpuPerformanceModel::EstimateRunTimes( @@ -211,10 +214,10 @@ FusionDecision ProducerCandidateIsFusible( /*fused_consumers=*/{&consumer}, /*multi_output=*/true); if (t.time_fused > t.time_unfused) { - return "will execute slower if fused"; + return FusionDecision::Forbid("will execute slower if fused"); } - return {}; + return FusionDecision::Allow(); } std::vector GetProducerConsumerMultiOutputFusionCandidates( @@ -283,8 +286,9 @@ FusionDecision CanFuseSiblings(const HloInstruction& sibling_consumer_1, FusionInfoCache* fusion_info_cache, const se::DeviceDescription& device_info) { if (reachability.IsConnected(&sibling_consumer_1, &sibling_consumer_2)) { - return {absl::StrCat(sibling_consumer_1.name(), " and ", - sibling_consumer_2.name(), " are connected")}; + return FusionDecision::Forbid( + absl::StrCat(sibling_consumer_1.name(), " and ", + sibling_consumer_2.name(), " are connected")); } RETURN_IF_NOT_FUSIBLE(ShapesCompatibleForMultiOutputFusion( @@ -302,7 +306,7 @@ FusionDecision CanFuseSiblings(const HloInstruction& sibling_consumer_1, // This check should be last, as it may be expensive. RETURN_IF_NOT_FUSIBLE(LegalToFuse(sibling_consumer_1, sibling_consumer_2, device_info, fusion_info_cache)); - return {}; + return FusionDecision::Allow(); } } // namespace diff --git a/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.h b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.h index b56a3e38e52da5..a69c21c444d30d 100644 --- a/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.h +++ b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.h @@ -21,8 +21,8 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/analysis/hlo_dfs_reachability.h" #include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_dfs_reachability.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" diff --git a/third_party/xla/xla/service/gpu/transforms/multi_output_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion_test.cc index abd45a15538959..581e953727897a 100644 --- a/third_party/xla/xla/service/gpu/transforms/multi_output_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion_test.cc @@ -42,23 +42,16 @@ namespace gpu { namespace m = ::xla::match; class MultiOutputFusionTest : public HloTestBase { - HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const { - return [&](const Shape& shape) { - constexpr int64_t kPointerSize = 8; - return ShapeUtil::ByteSizeOf(shape, kPointerSize); - }; - } - public: MultiOutputFusion mof_{TestGpuDeviceInfo::RTXA6000DeviceInfo(), - ShapeSizeBytesFunction()}; + HloCostAnalysis::DefaultShapeSize}; void CheckMultiOutputFusion(absl::string_view hlo, std::optional expected) { RunAndFilecheckHloRewrite( hlo, MultiOutputFusion{TestGpuDeviceInfo::RTXA6000DeviceInfo(), - ShapeSizeBytesFunction()}, + HloCostAnalysis::DefaultShapeSize}, expected); } }; @@ -1762,7 +1755,7 @@ TEST_F(MultiOutputFusionTest, OverlappingRead) { } class TransposeMultiOutputFusionTest : public MultiOutputFusionTest { - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = MultiOutputFusionTest::GetDebugOptionsForTest(); // Only the MLIR transpose emitter supports unpadded 2D transposes. diff --git a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc new file mode 100644 index 00000000000000..2b5b75cce90782 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.cc @@ -0,0 +1,463 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/nest_gemm_fusion.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/MLIRContext.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/service/call_graph.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/matmul_indexing_utils.h" +#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/model/symbolic_tile_analysis.h" +#include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" +#include "xla/service/gpu/model/tiled_hlo_computation.h" +#include "xla/service/instruction_fusion.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/macros.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { + +namespace { +// Fuses the given instructions together. The instructions are expected to be +// passed in def-before-use order. The resulting fusion has a single root +// instruction, which is the last instructions in the input span. We only +// replace the uses of the root in 'consumer', and leave other users alone. +absl::Status FuseInstructionsForConsumer( + absl::Span instructions, HloInstruction& consumer) { + HloComputation::Builder builder(instructions.back()->name()); + + absl::flat_hash_map + old_to_new_mapping; + std::vector parameters; + + auto add_parameter = [&](HloInstruction* instruction) -> void { + int param_index = parameters.size(); + old_to_new_mapping[instruction] = + builder.AddInstruction(HloInstruction::CreateParameter( + param_index, instruction->shape(), + absl::StrCat("parameter_", param_index))); + parameters.push_back(instruction); + }; + + for (HloInstruction* instruction : instructions) { + if (old_to_new_mapping.contains(instruction)) { + continue; + } + + if (instruction->opcode() == HloOpcode::kParameter) { + add_parameter(instruction); + continue; + } + std::vector new_operands; + for (HloInstruction* operand : instruction->mutable_operands()) { + if (!old_to_new_mapping.contains(operand)) { + add_parameter(operand); + } + new_operands.push_back(old_to_new_mapping[operand]); + } + old_to_new_mapping[instruction] = builder.AddInstruction( + instruction->CloneWithNewOperands(instruction->shape(), new_operands)); + } + + HloInstruction* old_root = instructions.back(); + old_to_new_mapping[old_root]->MarkAsRoot(); + + HloComputation* computation = + old_root->GetModule()->AddComputationAndUnifyNamesAndIds( + builder.Build(), /*is_entry=*/false); + HloInstruction* fusion = + old_root->parent()->AddInstruction(HloInstruction::CreateFusion( + old_root->shape(), HloInstruction::FusionKind::kCustom, parameters, + computation)); + fusion->GetModule()->SetAndUniquifyInstrName(fusion, "block_fusion"); + + TF_ASSIGN_OR_RETURN(auto gpu_config, + fusion->backend_config()); + FusionBackendConfig& backend_config = + *gpu_config.mutable_fusion_backend_config(); + backend_config.set_kind(std::string(kTritonFusionKind)); + TF_RETURN_IF_ERROR(fusion->set_backend_config(gpu_config)); + + for (int64_t operand_index : consumer.OperandIndices(old_root)) { + TF_RETURN_IF_ERROR(consumer.ReplaceOperandWith(operand_index, fusion)); + } + + return absl::OkStatus(); +} + +// Annotates the given nested fusion with the given tile sizes. +// Implementation for AnnotateDotLhs/RhsNestedFusion(). +absl::Status AnnotateDotOperandNestedFusionImpl( + HloFusionInstruction& nested_fusion, const HloDotInstruction& dot, + const TritonGemmConfig& config, + absl::Span contracting_dimensions, // Must be single element + absl::Span batch_dimensions, int64_t contracting_dim_size, + int64_t non_contracting_dim_size) { + if (contracting_dimensions.size() != 1) { + return absl::InternalError( + absl::StrCat("Expected a single lhs contracting dimension but got ", + contracting_dimensions.size())); + } + + TF_ASSIGN_OR_RETURN( + std::vector non_contracting_dimensions, + GetNonContractingDims(dot.operand(0)->shape(), batch_dimensions, + contracting_dimensions)); + + if (non_contracting_dimensions.size() != 1) { + return absl::InternalError( + absl::StrCat("Expected a single non-contracting dimension but got ", + non_contracting_dimensions.size())); + } + + // We have a single contracting dimension, and a single non-contracting + // dimension. All the other output tile sizes are set to 1. + std::vector output_tile_sizes(dot.operand(0)->shape().rank(), 1); + output_tile_sizes[contracting_dimensions[0]] = contracting_dim_size; + output_tile_sizes[non_contracting_dimensions[0]] = non_contracting_dim_size; + + BlockLevelParameters block_level_parameters; + block_level_parameters.output_tile_sizes = std::move(output_tile_sizes); + + TF_ASSIGN_OR_RETURN(auto backend_config, + nested_fusion.backend_config()); + *backend_config.mutable_fusion_backend_config() + ->mutable_block_level_fusion_config() = + block_level_parameters.ToBlockLevelFusionConfig(); + TF_RETURN_IF_ERROR(nested_fusion.set_backend_config(backend_config)); + + return absl::OkStatus(); +} + +absl::Status AnnotateDotLhsNestedFusion(HloFusionInstruction& nested_fusion, + const HloDotInstruction& dot, + const TritonGemmConfig& config) { + const DotDimensionNumbers& dimension_numbers = dot.dot_dimension_numbers(); + return AnnotateDotOperandNestedFusionImpl( + nested_fusion, dot, config, + dimension_numbers.lhs_contracting_dimensions(), + dimension_numbers.lhs_batch_dimensions(), config.block_k, config.block_m); +} + +absl::Status AnnotateDotRhsNestedFusion(HloFusionInstruction& nested_fusion, + const HloDotInstruction& dot, + const TritonGemmConfig& config) { + const DotDimensionNumbers& dimension_numbers = dot.dot_dimension_numbers(); + return AnnotateDotOperandNestedFusionImpl( + nested_fusion, dot, config, + dimension_numbers.rhs_contracting_dimensions(), + dimension_numbers.rhs_batch_dimensions(), config.block_k, config.block_n); +} + +// Finds tile sizes for the root of the analysis that satisfy the +// requirements of the dot. That is, the tile sizes need to satisfy the +// constraints of the analysis and map to the given config of the dot. +absl::StatusOr> FindOutputTileSizesForEpilogue( + const SymbolicTiledHloInstruction& tiled_dot, + const SymbolicTileAnalysis& analysis, const TritonGemmConfig& config) { + int64_t dot_rank = tiled_dot.symbolic_tile().tile_map().GetDimensionCount(); + llvm::SmallVector expected_dot_tile_sizes(dot_rank, 1); + // We always expect the shape of the dot to be [1, ..., block_m, block_n]. + expected_dot_tile_sizes[dot_rank - 2] = config.block_m; + expected_dot_tile_sizes[dot_rank - 1] = config.block_n; + + if (VLOG_IS_ON(1)) { + std::ostringstream oss; + for (const auto& size : expected_dot_tile_sizes) { + oss << size << " "; + } + LOG(INFO) << "FindOutputTileSizesForEpilogue: " << tiled_dot.ToString() + << "Constraints: " << analysis.GetConstraints().ToString() + << "Expected dot tile sizes: " << oss.str(); + } + // Try all permutations of the dot tile sizes to see if any of them satisfy + // the constraints of the analysis and map to the given config of the dot. + llvm::SmallVector output_tile_sizes = expected_dot_tile_sizes; + std::sort(output_tile_sizes.begin(), output_tile_sizes.end()); + do { + TF_ASSIGN_OR_RETURN( + bool parameters_satisfy_constraints, + analysis.ParametersSatisfyConstraints(output_tile_sizes)); + if (!parameters_satisfy_constraints) { + continue; + } + auto mapped_dot_tile_sizes = tiled_dot.TileSizes(output_tile_sizes); + if (mapped_dot_tile_sizes == expected_dot_tile_sizes) { + return output_tile_sizes; + } + } while (std::next_permutation(output_tile_sizes.begin(), + output_tile_sizes.end())); + + return absl::InternalError(absl::StrCat( + "Couldn't find output tile sizes that satisfy ", tiled_dot.ToString())); +} + +// Extracts the TritonGemmConfig from the given fusion's backend config. +absl::StatusOr GetTritonGemmConfig( + const HloFusionInstruction& fusion) { + TF_ASSIGN_OR_RETURN(auto gpu_config, + fusion.backend_config()); + const FusionBackendConfig& backend_config = + gpu_config.fusion_backend_config(); + if (!backend_config.has_triton_gemm_config()) { + return absl::InternalError( + "The fusion's backend config doesn't have a triton_gemm_config."); + } + return TritonGemmConfig::FromProto(backend_config.triton_gemm_config()); +} + +// Transforms a fusion into an equivalent nested fusion if it has a single dot. +// Returns ok if the transformation was successful. +absl::Status MakeNestedFusionFromGemmFusion( + HloFusionInstruction* fusion, const TritonGemmConfig& config, + const SymbolicTileAnalysis& analysis, + const SymbolicTiledHloInstruction& tiled_dot, HloDotInstruction* dot) { + DCHECK(GetTritonGemmConfig(*fusion).value() == config); + DCHECK_EQ(tiled_dot.hlo(), dot); + + HloComputation* computation = fusion->called_computation(); + + // Left-hand side of the dot. + TF_RETURN_IF_ERROR(FuseInstructionsForConsumer( + computation->MakeInstructionPostOrderFrom(*dot->mutable_operand(0)), + *dot)); + TF_RETURN_IF_ERROR(AnnotateDotLhsNestedFusion( + *::xla::Cast(dot->mutable_operand(0)), *dot, + config)); + + // Right-hand side of the dot. + TF_RETURN_IF_ERROR(FuseInstructionsForConsumer( + computation->MakeInstructionPostOrderFrom(*dot->mutable_operand(1)), + *dot)); + TF_RETURN_IF_ERROR(AnnotateDotRhsNestedFusion( + *::xla::Cast(dot->mutable_operand(1)), *dot, + config)); + + // Delete newly unused instructions, if any. + TF_ASSIGN_OR_RETURN([[maybe_unused]] bool changed, + HloDCE::RunOnComputation( + computation, + /*remove_cross_partition_collective_ops=*/false)); + + // Annotate the fusion itself. + TF_ASSIGN_OR_RETURN( + llvm::SmallVector output_tile_sizes, + FindOutputTileSizesForEpilogue(tiled_dot, analysis, config)); + + TF_ASSIGN_OR_RETURN(auto gpu_config, + fusion->backend_config()); + FusionBackendConfig& backend_config = + *gpu_config.mutable_fusion_backend_config(); + backend_config.set_kind(std::string(kTritonFusionKind)); + + BlockLevelParameters block_level_parameters; + block_level_parameters.output_tile_sizes.assign(output_tile_sizes.begin(), + output_tile_sizes.end()); + + *backend_config.mutable_block_level_fusion_config() = + block_level_parameters.ToBlockLevelFusionConfig(); + TF_RETURN_IF_ERROR(fusion->set_backend_config(gpu_config)); + + return absl::OkStatus(); +} + +size_t GetDotCount(HloComputation* computation) { + return absl::c_count_if(computation->instructions(), [](HloInstruction* hlo) { + return hlo->opcode() == HloOpcode::kDot; + }); +} + +// Returns the transitive producers of 'instruction'. Returns an error if that +// set is not closed, i.e. if there are any consumers outside of the set other +// than 'instruction' itself. +absl::StatusOr GetClosedProducerSet( + HloInstruction* instruction) { + HloInstructionSet producers; + std::deque worklist(instruction->operands().begin(), + instruction->operands().end()); + do { + HloInstruction* front = worklist.front(); + worklist.pop_front(); + if (!producers.insert(front).second) { + continue; // Already in producer set. + } + worklist.insert(worklist.end(), front->operands().begin(), + front->operands().end()); + for (HloInstruction* user : front->users()) { + if (TF_PREDICT_TRUE(user == instruction || producers.count(user) > 0)) { + continue; // User is instruction itself or in producer set. + } + return absl::FailedPreconditionError(absl::StrCat( + "Instruction ", front->ToString(), " has consumer ", user->ToString(), + ", which is not in the producer set of, or ", instruction->ToString(), + " itself.")); + } + } while (!worklist.empty()); + return producers; +} + +// Hoists the given 'bitcast' out of its computation, to the parent of each +// caller. +absl::Status HoistBitcastToCallers(HloInstruction* bitcast, + CallGraph* call_graph) { + TF_ASSIGN_OR_RETURN(HloInstructionSet producers, + GetClosedProducerSet(bitcast)); + + // Check that it's safe to hoist the bitcast. + for (HloInstruction* instruction : producers) { + if (!instruction->IsElementwise() && !instruction->IsConstant() && + instruction->opcode() != HloOpcode::kParameter) { + return absl::InternalError( + absl::StrCat("Cannot hoist bitcast past ", instruction->ToString())); + } + } + + // Adjust the shape of of every instruction in the backward slice. + Shape shape = bitcast->shape(); + for (HloInstruction* instruction : producers) { + *instruction->mutable_shape() = shape; + if (instruction->opcode() != HloOpcode::kParameter) { + continue; + } + int64_t number = instruction->parameter_number(); + for (HloInstruction* caller : + call_graph->GetComputationCallers(instruction->parent())) { + HloInstruction* new_bitcast = + caller->AddInstruction(HloInstruction::CreateBitcast( + instruction->shape(), caller->mutable_operand(number))); + TF_RETURN_IF_ERROR( + caller->ReplaceOperandWithDifferentShape(number, new_bitcast)); + } + } + + TF_RETURN_IF_ERROR(bitcast->ReplaceAllUsesWith(bitcast->mutable_operand(0))); + + return absl::OkStatus(); +} + +// Hoists all bitcasts in the computation to its callers. +absl::Status HoistBitcastsInComputationToCallers(HloComputation* computation, + CallGraph* call_graph) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kBitcast) { + continue; + } + TF_RETURN_IF_ERROR(HoistBitcastToCallers(instruction, call_graph)); + } + return absl::OkStatus(); +} + +class NestGemmFusionVisitor : public DfsHloRewriteVisitor { + public: + explicit NestGemmFusionVisitor(mlir::MLIRContext* ctx, CallGraph* call_graph) + : ctx_(ctx), call_graph_(call_graph) {} + + absl::Status HandleFusion(HloInstruction* instruction) override { + HloFusionInstruction* fusion = Cast(instruction); + + absl::StatusOr config = GetTritonGemmConfig(*fusion); + if (!config.ok()) { + return absl::OkStatus(); // Skip because it's not a Triton gemm fusion. + } + + HloComputation* computation = fusion->called_computation(); + HloInstruction* dot = + hlo_query::GetFirstInstructionWithOpcode(*computation, HloOpcode::kDot); + if (dot == nullptr) { + return absl::OkStatus(); // Skip because fusion has no dot. + } + DCHECK_EQ(GetDotCount(computation), 1) << "Fusion has more than one dot."; + + TF_RETURN_IF_ERROR( + HoistBitcastsInComputationToCallers(computation, call_graph_)); + SymbolicTileAnalysisOrError analysis_or = + SymbolicTileAnalysis::AnalyzeComputation( + *fusion->called_computations()[0], ctx_); + + if (std::holds_alternative(analysis_or)) { + return absl::InternalError( + absl::StrCat("Failed to analyze the computation (", + std::get(analysis_or).Explain(), + "): ", fusion->called_computation()->ToString())); + } + + auto& analysis = std::get(analysis_or); + const auto& tiled_instructions = analysis.GetSymbolicTiledHloComputation(); + auto is_dot = [&](const auto& instr) { return instr->hlo() == dot; }; + auto tiled_dot_it = absl::c_find_if(tiled_instructions, is_dot); + if (tiled_dot_it == tiled_instructions.end()) { + return absl::InternalError(absl::StrCat( + "Couldn't find a symbolic tiled instruction for ", dot->ToString())); + } + + TF_RETURN_IF_ERROR(MakeNestedFusionFromGemmFusion( + fusion, config.value(), analysis, **tiled_dot_it, + Cast(dot))); + this->MarkAsChanged(); + return absl::OkStatus(); + } + + private: + mlir::MLIRContext* ctx_; + CallGraph* call_graph_; +}; + +} // namespace + +absl::StatusOr NestGemmFusion::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + bool changed = false; + auto call_graph = CallGraph::Build(module, execution_threads); + mlir::MLIRContext ctx; + for (HloComputation* computation : + module->MakeNonfusionComputations(execution_threads)) { + NestGemmFusionVisitor visitor(&ctx, call_graph.get()); + TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + changed |= visitor.changed(); + } + return changed; +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.h b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.h new file mode 100644 index 00000000000000..aee2ece23afd33 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion.h @@ -0,0 +1,49 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_TRANSFORMS_NEST_GEMM_FUSION_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_NEST_GEMM_FUSION_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla::gpu { + +// Rewrites Triton GEMM fusions to generic Triton fusions. Any other fusions are +// left unchanged. +// +// The fusion's backend config is set to a BlockLevelFusionConfig, derived from +// a previously set TritonGemmConfig. +// +// The operands of the dot (including their prologues) are fused into two new +// nested fusions, each with their own BlockLevelFusionConfig. +class NestGemmFusion : public HloModulePass { + public: + absl::string_view name() const override { return "nest_gemm_fusion"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_TRANSFORMS_NEST_GEMM_FUSION_H_ diff --git a/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc new file mode 100644 index 00000000000000..718e8a474d1568 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/nest_gemm_fusion_test.cc @@ -0,0 +1,190 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/nest_gemm_fusion.h" + +#include + +#include +#include +#include "absl/status/status.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/pattern_matcher.h" +#include "xla/service/pattern_matcher_gmock.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" + +using ::testing::ElementsAre; +using ::testing::HasSubstr; +using ::tsl::testing::StatusIs; + +namespace xla { + +// Gtest hook to pretty-print an HloInstruction. +static void PrintTo(const HloInstruction& hlo, std::ostream* os) { + *os << hlo.ToString(); +} + +namespace gpu { +namespace { + +// Wraps a matcher for a fusion instruction's output tile sizes. +// Proto matchers would be nice, but b/229726259 is P2. +MATCHER_P(OutputTileSizesIs, matcher, "") { + auto backend_config = arg.template backend_config(); + if (!backend_config.ok()) { + *result_listener << "failed to get backend config: " + << backend_config.status(); + return false; + } + FusionBackendConfig fusion_backend_config = + backend_config->fusion_backend_config(); + if (!fusion_backend_config.has_block_level_fusion_config()) { + *result_listener << "has no block level fusion config"; + return false; + } + auto output_tile_sizes = + fusion_backend_config.block_level_fusion_config().output_tile_sizes(); + return ExplainMatchResult(matcher, output_tile_sizes, result_listener); +} + +class NestGemmFusionTest : public HloTestBase {}; + +TEST_F(NestGemmFusionTest, BasicTest) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule module + +dot { + lhs = bf16[8192,512] parameter(0) + rhs = bf16[512,512] parameter(1) + ROOT dot = bf16[8192,512] dot(lhs, rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY entry { + p0 = bf16[8192,512] parameter(0) + p1 = bf16[512,512] parameter(1) + ROOT fusion = bf16[8192,512] fusion(p0, p1), + kind=kCustom, calls=dot, backend_config={ + "fusion_backend_config": { + "kind":"__triton_gemm", "triton_gemm_config": { + "block_m":"64", "block_n":"256", "block_k":"32", + "split_k":"1", "num_stages":"1", "num_warps":"1", "num_ctas":"1" + } + } + } +} +)")); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, NestGemmFusion().Run(module.get())) + EXPECT_TRUE(changed); + TF_ASSERT_OK(verifier().Run(module.get()).status()); + + const HloInstruction* fusion = nullptr; + ASSERT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(match::Fusion(&fusion))); + EXPECT_THAT(*fusion, OutputTileSizesIs(ElementsAre(64, 256))); + + const HloInstruction* lhs = nullptr; + const HloInstruction* rhs = nullptr; + EXPECT_THAT(fusion->fused_expression_root(), + GmockMatch(match::Dot(match::Fusion(&lhs), match::Fusion(&rhs)))); + EXPECT_THAT(*lhs, OutputTileSizesIs(ElementsAre(64, 32))); + EXPECT_THAT(*rhs, OutputTileSizesIs(ElementsAre(32, 256))); +} + +// Tests hoisting of bitcasts which would otherwise trigger unsatisfiable +// constraints during symbolic tile analysis. +TEST_F(NestGemmFusionTest, BitcastsAreHoistedOutOfGemmFusions) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule module + +dot { + lhs = f32[21] parameter(0) + bitcast = f32[3,7]{0,1} bitcast(lhs) + rhs = f32[7,11] parameter(1) + ROOT dot = f32[3,11] dot(bitcast, rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY entry { + p0 = f32[21] parameter(0) + p1 = f32[7,11] parameter(1) + ROOT fusion = f32[3,11] fusion(p0, p1), + kind=kCustom, calls=dot, backend_config={ + "fusion_backend_config": { + "kind":"__triton_gemm", "triton_gemm_config": { + "block_m":"32", "block_n":"64", "block_k":"16", + "split_k":"1", "num_stages":"1", "num_warps":"1", "num_ctas":"1" + } + } + } +} +)")); + + TF_ASSERT_OK_AND_ASSIGN(bool changed, NestGemmFusion().Run(module.get())) + EXPECT_TRUE(changed); + TF_ASSERT_OK(verifier().Run(module.get()).status()); + + const HloInstruction* fusion = nullptr; + ASSERT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(match::Fusion(&fusion))); + EXPECT_THAT(fusion->operand(0), GmockMatch(match::Bitcast())); + EXPECT_THAT(*fusion, OutputTileSizesIs(ElementsAre(32, 64))); + + const HloInstruction* lhs = nullptr; + const HloInstruction* rhs = nullptr; + EXPECT_THAT(fusion->fused_expression_root(), + GmockMatch(match::Dot(match::Fusion(&lhs), match::Fusion(&rhs)))); + EXPECT_THAT(*lhs, OutputTileSizesIs(ElementsAre(32, 16))); + EXPECT_THAT(*rhs, OutputTileSizesIs(ElementsAre(16, 64))); +} + +TEST_F(NestGemmFusionTest, FailsOnBitcastWithOpenProducerSet) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( +HloModule module + +dot { + p0 = f32[32] parameter(0) + lhs = f32[4,8] bitcast(p0) + rhs = f32[8,4] bitcast(p0) + ROOT dot = f32[4,4] dot(lhs, rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY entry { + p0 = f32[32] parameter(0) + ROOT fusion = f32[4,4] fusion(p0), + kind=kCustom, calls=dot, backend_config={ + "fusion_backend_config": { + "kind":"__triton_gemm", "triton_gemm_config": { + "block_m":"4", "block_n":"4", "block_k":"8", + "split_k":"1", "num_stages":"1", "num_warps":"1", "num_ctas":"1" + } + } + } +} +)")); + + EXPECT_THAT(NestGemmFusion().Run(module.get()).status(), + StatusIs(absl::StatusCode::kFailedPrecondition, + HasSubstr("not in the producer set"))); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker_test.cc b/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker_test.cc index 3f2d1ab6426fd0..d1994a1f6f0687 100644 --- a/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker_test.cc @@ -148,6 +148,10 @@ TEST_F(PGLEAccuracyCheckerTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); *module->mutable_config().mutable_fdo_profile() = kProfileString; + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_pgle_accuracy_checker( + DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR); auto pgle_estimator = GetProfileGuidedLatencyEstimator(profile); PGLEAccuracyChecker pgle_accuracy_checker(*pgle_estimator); diff --git a/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter_test.cc index 287603c6d0de93..d6fd4b37dd8a74 100644 --- a/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/tests/filecheck.h" +#include "xla/hlo/testlib/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc b/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc index f887161d869fd9..e51f3a457244c5 100644 --- a/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc @@ -90,8 +90,8 @@ bool IsFusible(const HloInstruction& instr) { // Other non-elementwise ops also supported by elemental fusion. switch (instr.opcode()) { case HloOpcode::kFusion: - return instr.fusion_kind() != HloInstruction::FusionKind::kCustom; - + return IsGenericTritonFusion(instr) || + instr.fusion_kind() != HloInstruction::FusionKind::kCustom; case HloOpcode::kCopy: case HloOpcode::kIota: case HloOpcode::kConstant: @@ -149,7 +149,7 @@ class PriorityFusionQueue { mlir::MLIRContext* mlir_context, HloFusionAnalysisCache& fusion_analysis_cache, FusionDeduplicationCache& fusion_deduplication_cache, - bool triton_softmax_priority_fusion_enabled) + bool triton_heroless_fusion_enabled) : computation_(computation), device_info_(device_info), cost_analysis_(cost_analysis_options, *device_info), @@ -161,8 +161,7 @@ class PriorityFusionQueue { mlir_context_(mlir_context), fusion_analysis_cache_(fusion_analysis_cache), fusion_deduplication_cache_(fusion_deduplication_cache), - triton_softmax_priority_fusion_enabled_( - triton_softmax_priority_fusion_enabled) { + triton_heroless_fusion_enabled_(triton_heroless_fusion_enabled) { VLOG(2) << "Running full HLO cost analysis for " << computation_->name(); TF_CHECK_OK(computation_->Accept(&cost_analysis_)); @@ -267,8 +266,7 @@ class PriorityFusionQueue { } absl::Status UpdatePerformanceModelCache(HloInstruction* producer) { - bool is_triton_fusion = IsGenericTritonFusion(*producer); - if (!IsFusible(*producer) && !is_triton_fusion) { + if (!IsFusible(*producer)) { return absl::OkStatus(); } @@ -277,7 +275,7 @@ class PriorityFusionQueue { } EstimateRunTimeData runtime_data; - if (is_triton_fusion) { + if (IsGenericTritonFusion(*producer)) { TF_ASSIGN_OR_RETURN( runtime_data, gpu_indexing_performance_model_.EstimateRunTimeForTriton(producer)); @@ -541,6 +539,10 @@ class PriorityFusionQueue { } FusionDecision IsTritonSupported(const HloInstruction& instruction) { + if (IsGenericTritonFusion(instruction)) { + return FusionDecision::Allow(); + } + if (instruction.opcode() != HloOpcode::kFusion) { return IsTritonSupportedInstruction( instruction, device_info_->gpu_compute_capability()); @@ -555,7 +557,7 @@ class PriorityFusionQueue { } } - return {}; + return FusionDecision::Allow(); } TiledRunTimeDataOrError GetTiledRunTimeDataCached( @@ -587,17 +589,17 @@ class PriorityFusionQueue { if (result_or_status.ok()) { return *result_or_status; } else { - return FusionDecision{ + return FusionDecision::Forbid( absl::StrCat("TiledRunTimeDataOrError return status: ", - result_or_status.status().message())}; + result_or_status.status().message())); } }(); if (const auto* fusion_decision = std::get_if(&tiled_run_time_data_or_error)) { - tiled_run_time_data_or_error = FusionDecision{ + tiled_run_time_data_or_error = FusionDecision::Forbid( absl::StrCat("Fusion can not be tiled with SymbolicTileAnalysis: ", - fusion_decision->Explain())}; + fusion_decision->Explain())); } absl::MutexLock lock(&tiled_run_time_data_cache_mutex_); @@ -607,28 +609,17 @@ class PriorityFusionQueue { FusionDecision CanFuseTriton(HloInstruction* producer, HloInstruction* consumer) { - if (!triton_softmax_priority_fusion_enabled_) { - return "triton softmax fusion is not enabled"; + if (!IsGenericTritonFusion(*producer) && + !IsGenericTritonFusion(*consumer) && !triton_heroless_fusion_enabled_) { + return FusionDecision::Forbid("triton heroless fusion is not enabled"); } - if (IsGenericTritonFusion(*producer)) { - if (!IsFusible(*consumer)) { - return "the consumer is not fusible"; - } - - if (auto fusion_decision = IsTritonSupported(*consumer); - !fusion_decision) { - return fusion_decision; - } - } else { - if (!IsFusible(*producer)) { - return "the producer is not fusible"; - } + if (auto fusion_decision = IsTritonSupported(*producer); !fusion_decision) { + return fusion_decision; + } - if (auto fusion_decision = IsTritonSupported(*producer); - !fusion_decision) { - return fusion_decision; - } + if (auto fusion_decision = IsTritonSupported(*consumer); !fusion_decision) { + return fusion_decision; } TiledRunTimeDataOrError tiled_run_time_data_or_error = @@ -651,24 +642,43 @@ class PriorityFusionQueue { tiled_run_time_data.block_level_parameters; } - return {}; + return FusionDecision::Allow(); } FusionDecision CanFuse(HloInstruction* producer, HloInstruction* consumer) { - if (IsGenericTritonFusion(*producer) || IsGenericTritonFusion(*consumer)) { - return CanFuseTriton(producer, consumer); + // Don't fuse across a root instruction. There are situation when a root + // instruction is not the last in the computation. Instructions after the + // root are not necessary dead. They can be inputs to instructions with side + // effects, like outfeed. + if (producer == producer->parent()->root_instruction()) { + return FusionDecision::Forbid( + "not fusing into the output of the root instruction"); } if (!IsFusible(*producer)) { - return "the producer is not fusible"; + return FusionDecision::Forbid("the producer is not fusible"); } if (!IsFusible(*consumer)) { - return "the consumer is not fusible"; + return FusionDecision::Forbid("the consumer is not fusible"); + } + + // Fusing with Triton is our preferred choice. If the producer-consumer + // fusion is supported by Triton and all necessary flags are enabled, the + // result will be a Triton fusion. If either `producer` or `consumer` is + // already a Triton fusion, we can fuse only if the result will also be a + // Triton fusion. + // + // Otherwise, we'll check if the fusion is supported by the emitter. + FusionDecision can_fuse_triton = CanFuseTriton(producer, consumer); + if (IsGenericTritonFusion(*producer) || IsGenericTritonFusion(*consumer) || + can_fuse_triton) { + return can_fuse_triton; } if (consumer->opcode() == HloOpcode::kBitcast) { - return "not fusing into a single bitcast as consumer"; + return FusionDecision::Forbid( + "not fusing into a single bitcast as consumer"); } // Scatter is special as it has no elemental version but is still input @@ -698,7 +708,8 @@ class PriorityFusionQueue { }; if (contains_significant_reduce(producer) && contains_significant_reduce(consumer)) { - return "both the producer and the consumer contain a reduce"; + return FusionDecision::Forbid( + "both the producer and the consumer contain a reduce"); } // Avoid doing fusions into the output of an "input" fusion when it would @@ -712,8 +723,8 @@ class PriorityFusionQueue { fusion_analysis_cache_.Get(*producer, *consumer); if (analysis_fused.GetEmitterFusionKind() == HloFusionAnalysis::EmitterFusionKind::kLoop) { - return "fusion into output of a reduce fusion would create a loop " - "fusion"; + return FusionDecision::Forbid( + "fusion into output of a reduce fusion would create a loop fusion"); } } @@ -731,15 +742,8 @@ class PriorityFusionQueue { // kernels, in which case we don't want to fuse. // TODO(b/119692968): Remove this once we have fixed our fusion emitter. if (cost_analysis_.ProducerConsumerMergedTooLarge(*producer, *consumer)) { - return "the fusion would result in an overly large code duplication"; - } - - // Don't fuse across a root instruction. There are situation when a root - // instruction is not the last in the computation. Instructions after the - // root are not necessary dead. They can be inputs to instructions with side - // effects, like outfeed. - if (producer == producer->parent()->root_instruction()) { - return "not fusing into the output of the root instruction"; + return FusionDecision::Forbid( + "the fusion would result in an overly large code duplication"); } return InstructionFusion::ShouldFuseInPlaceOp(producer, consumer); @@ -764,7 +768,7 @@ class PriorityFusionQueue { // override any value. { absl::MutexLock lock(&can_fuse_cache_mutex_); - can_fuse_cache_[producer][consumer] = fusion_decision; + can_fuse_cache_[producer].insert_or_assign(consumer, fusion_decision); } return fusion_decision; @@ -772,10 +776,9 @@ class PriorityFusionQueue { FusionDecision CanFuseWithAllNonBitcastUsers(HloInstruction* producer) { if (producer->users().empty()) { - return "No users to fuse"; + return FusionDecision::Forbid("No users to fuse"); } - FusionDecision result; bool has_non_bitcast_user = false; for (const auto& user : producer->users()) { if (user->opcode() == HloOpcode::kBitcast) { @@ -790,9 +793,10 @@ class PriorityFusionQueue { } } if (!has_non_bitcast_user) { - return "not fusing because there are only bitcast users"; + return FusionDecision::Forbid( + "not fusing because there are only bitcast users"); } - return {}; + return FusionDecision::Allow(); } // Store computation for cost analysis. @@ -872,7 +876,8 @@ class PriorityFusionQueue { // like shared memory usage or number of unnested reductions of fusion nodes. FusionInfoCache fusion_info_cache_; - bool triton_softmax_priority_fusion_enabled_; + // If true, redirect all fusion decisions to Triton fusion. + bool triton_heroless_fusion_enabled_; bool dump_fusion_visualization_; }; @@ -900,6 +905,36 @@ bool PriorityFusion::ConsumeFuel(HloInstruction* producer, }); }; +FusionDecision PriorityFusion::CanFuseConstant(const HloInstruction* constant, + const HloInstruction* user) { + // If user is a scatter, verify that we can fuse the constant correctly. + if (auto fusion_decision = CanEmitInputFusedScatter(*constant, *user); + !fusion_decision) { + return fusion_decision; + } + + // If user is a Triton fusion, verify that the constant is supported + // by Triton. + // + // Note: `IsFusible` should not be used for Triton fusions. Generally, + // `IsFusible` returns `false` for Triton fusions, because Triton fusions have + // kCustom fusion kind, but sometimes `IsFusible` will return `true` if the + // fusion contains only elementwise instructions. + // We can always fuse a producer into Triton fusions if the producer is + // supported by Triton, so it's enough to check if the constant is supported. + if (IsGenericTritonFusion(*user)) { + return IsTritonSupportedInstruction(*constant, + device_info_.gpu_compute_capability()); + } + + // Verify that the user is fusible. + if (!IsFusible(*user)) { + return FusionDecision::Forbid("User is not fusible"); + } + + return FusionDecision::Allow(); +} + absl::StatusOr PriorityFusion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { @@ -937,10 +972,10 @@ absl::StatusOr PriorityFusion::Run( module->ToString(HloPrintOptions::ShortParsable())); } - bool triton_softmax_priority_fusion_enabled = + bool triton_heroless_fusion_enabled = module->config() .debug_options() - .xla_gpu_experimental_enable_triton_softmax_priority_fusion(); + .xla_gpu_experimental_enable_triton_heroless_priority_fusion(); FusionDeduplicationCache fusion_deduplication_cache = FusionDeduplicationCache::Create(*module); @@ -953,7 +988,7 @@ absl::StatusOr PriorityFusion::Run( computation, cost_analysis_options_, &device_info_, fusion_process_dump_.get(), thread_pool_, &mlir_context_, fusion_analysis_cache_, fusion_deduplication_cache, - triton_softmax_priority_fusion_enabled); + triton_heroless_fusion_enabled); while (fusion_queue->DequeueNextProducer()) { auto producer = fusion_queue->current_producer(); @@ -986,6 +1021,8 @@ absl::StatusOr PriorityFusion::Run( if (backend_config_it != block_level_parameters_map.end()) { TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config( GetTritonGpuBackendConfig(backend_config_it->second))); + fusion_instruction->set_fusion_kind( + HloInstruction::FusionKind::kCustom); } changed = true; @@ -1020,11 +1057,11 @@ absl::StatusOr PriorityFusion::Run( constants.push_back(instruction); } } + for (auto* constant : constants) { auto users = constant->users(); for (auto* user : users) { - if ((IsFusible(*user) || IsGenericTritonFusion(*user)) && - CanEmitInputFusedScatter(*constant, *user)) { + if (CanFuseConstant(constant, user)) { Fuse(constant, user); changed = true; } diff --git a/third_party/xla/xla/service/gpu/transforms/priority_fusion.h b/third_party/xla/xla/service/gpu/transforms/priority_fusion.h index f1ea4198d0e910..f1d19532a755f5 100644 --- a/third_party/xla/xla/service/gpu/transforms/priority_fusion.h +++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_TRANSFORMS_PRIORITY_FUSION_H_ #define XLA_SERVICE_GPU_TRANSFORMS_PRIORITY_FUSION_H_ - #include #include @@ -32,6 +31,7 @@ limitations under the License. #include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/hlo_cost_analysis.h" +#include "xla/service/instruction_fusion.h" #include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" #include "tsl/platform/threadpool.h" @@ -67,6 +67,10 @@ class PriorityFusion : public HloModulePass { // continue with the transformation. bool ConsumeFuel(HloInstruction* producer, HloInstruction* consumer); + // Returns the decision if the constant can be fused into the user. + FusionDecision CanFuseConstant(const HloInstruction* constant, + const HloInstruction* user); + tsl::thread::ThreadPool* thread_pool_; se::DeviceDescription device_info_; diff --git a/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc index 2a0254f55294e4..1e3c3d4a7248b2 100644 --- a/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc @@ -15,8 +15,6 @@ limitations under the License. #include "xla/service/gpu/transforms/priority_fusion.h" -#include - #include #include #include @@ -26,10 +24,8 @@ limitations under the License. #include #include #include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/gpu_fusible.h" @@ -38,10 +34,8 @@ limitations under the License. #include "xla/service/hlo_cost_analysis.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/shape.h" -#include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" @@ -55,13 +49,6 @@ namespace xla { namespace gpu { class PriorityFusionTest : public HloTestBase { - HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const { - return [&](const Shape& shape) { - constexpr int64_t kPointerSize = 8; - return ShapeUtil::ByteSizeOf(shape, kPointerSize); - }; - } - public: std::vector RunAndGetFusionKinds( absl::string_view hlo) { @@ -72,30 +59,17 @@ class PriorityFusionTest : public HloTestBase { for (auto computation : module->computations()) { if (!computation->FusionInstruction()) continue; - auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); auto analysis = HloFusionAnalysis::Create( - *computation->FusionInstruction(), device_info); + *computation->FusionInstruction(), device_info_); kinds.push_back(analysis.GetEmitterFusionKind()); } return kinds; } + se::DeviceDescription device_info_ = TestGpuDeviceInfo::RTXA6000DeviceInfo(); PriorityFusion priority_fusion_{ - /*thread_pool=*/nullptr, TestGpuDeviceInfo::RTXA6000DeviceInfo(), - GpuHloCostAnalysis::Options{ShapeSizeBytesFunction(), - /*per_second_rates=*/{}, - /*min_latencies_seconds=*/{}, - /*count_multiple_input_accesses=*/true}}; -}; - -class PriorityFusionWithTritonEnabledTest : public PriorityFusionTest { - public: - DebugOptions GetDebugOptionsForTest() override { - DebugOptions debug_options = PriorityFusionTest::GetDebugOptionsForTest(); - debug_options - .set_xla_gpu_experimental_enable_triton_softmax_priority_fusion(true); - return debug_options; - } + /*thread_pool=*/nullptr, device_info_, + GpuHloCostAnalysis::Options{.count_multiple_input_accesses = true}}; }; TEST_F(PriorityFusionTest, FuseWithSharedArgument) { @@ -901,8 +875,7 @@ TEST_F(PriorityFusionTest, DoNotFuseProducerConsumerMergedTooLarge) { EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false)); } -TEST_F(PriorityFusionWithTritonEnabledTest, - CanMergeTritonFusionWithBothProducerAndConsumer) { +TEST_F(PriorityFusionTest, CanMergeTritonFusionWithBothProducerAndConsumer) { const std::string kHloText = R"( HloModule t add { @@ -958,8 +931,7 @@ ENTRY main { 2); } -TEST_F(PriorityFusionWithTritonEnabledTest, - FuseTritonProducerWithTwoConsumers) { +TEST_F(PriorityFusionTest, FuseTritonProducerWithTwoConsumers) { const std::string kHloText = R"( HloModule t add { @@ -1021,8 +993,7 @@ ENTRY main { 2); } -TEST_F(PriorityFusionWithTritonEnabledTest, - TritonProducerNotSupported_DoNotFuse) { +TEST_F(PriorityFusionTest, TritonProducerNotSupported_DoNotFuse) { const std::string kHloText = R"( HloModule t @@ -1051,8 +1022,7 @@ ENTRY main { EXPECT_FALSE(priority_fusion_.Run(module.get()).value()); } -TEST_F(PriorityFusionWithTritonEnabledTest, - TritonConsumerNotSupported_DoNotFuse) { +TEST_F(PriorityFusionTest, TritonConsumerNotSupported_DoNotFuse) { const std::string kHloText = R"( HloModule t @@ -1106,5 +1076,34 @@ TEST_F(PriorityFusionTest, DoNotFuseInsideReducer) { EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false)); } +class PriorityFusionWithTritonEnabledTest : public PriorityFusionTest { + public: + DebugOptions GetDebugOptionsForTest() const override { + DebugOptions debug_options = PriorityFusionTest::GetDebugOptionsForTest(); + debug_options + .set_xla_gpu_experimental_enable_triton_heroless_priority_fusion(true); + return debug_options; + } +}; + +TEST_F(PriorityFusionWithTritonEnabledTest, + TwoElementwiseOpsAreFusedWithTriton) { + auto module = *ParseAndReturnVerifiedModule(R"( +HloModule m + +ENTRY main { + p0 = f32[2048] parameter(0) + p1 = f32[2048] parameter(1) + add = f32[2048] add(p0, p1) + ROOT mul = f32[2048] multiply(add, p0) +})"); + EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true)); + EXPECT_TRUE(verifier().Run(module.get()).status().ok()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); + EXPECT_TRUE(IsGenericTritonFusion(*root)); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_splitter_test.cc b/third_party/xla/xla/service/gpu/transforms/reduction_splitter_test.cc index 4b9f6fb130ed0f..be773fbdb225f5 100644 --- a/third_party/xla/xla/service/gpu/transforms/reduction_splitter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_splitter_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier_test.cc b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier_test.cc index 8f1c93c1ec31d0..e867f2421d7560 100644 --- a/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/xla_data.pb.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc index 0c9c6e675e1fa7..01659a11f6e66d 100644 --- a/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc @@ -22,8 +22,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/hlo_parser.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator_test.cc b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator_test.cc index abe8d50a63c09b..04b050f8945031 100644 --- a/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/tests/filecheck.h" +#include "xla/hlo/testlib/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc index 31a1b7e30cbb56..4b2e12c1ce36b8 100644 --- a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc @@ -157,6 +157,10 @@ inline bool HasOneUse(const HloInstruction* instr) { // Unsupported case #4: // p = f32[a,b] parameter(0) // b = f32[a,x,b] broadcast(p), dimensions={0,2} +// +// Unsupported case #5: +// p = f32[] parameter(0) +// b = f32[x] broadcast(p), dimensions={} bool IsBatchOrReductionDimBroadcast(const HloInstruction& hlo) { CHECK_EQ(hlo.opcode(), HloOpcode::kBroadcast) << "Expected broadcast " << hlo.ToShortString(); @@ -169,9 +173,10 @@ bool IsBatchOrReductionDimBroadcast(const HloInstruction& hlo) { const HloParameterInstruction* parameter = Cast(hlo.operand(0)); - // Support only one dim broadcast. - if (parameter->shape().dimensions_size() + 1 != - broadcast->shape().dimensions_size()) { + // Support only one dim broadcast. Scalar parameters are handled elsewhere. + if (broadcast->dimensions().empty() || + parameter->shape().dimensions_size() + 1 != + broadcast->shape().dimensions_size()) { return false; } @@ -514,8 +519,8 @@ DecideIfShouldFuseAndMaybeSetBlockLevelParameters( if (const auto* fusion_decision = std::get_if(&tiled_runtime_data_or)) { - return FusionDecision{absl::StrCat("SymbolicTileAnalysis failed: ", - fusion_decision->Explain())}; + return FusionDecision::Forbid(absl::StrCat("SymbolicTileAnalysis failed: ", + fusion_decision->Explain())); } TiledRunTimeData tiled_runtime_data = @@ -534,8 +539,9 @@ DecideIfShouldFuseAndMaybeSetBlockLevelParameters( if (run_time_without_softmax_rewriter < tiled_runtime_data.runtime_data.exec_time) { - return "Run time estimate for without applying the custom normalization " - "rewrite is faster."; + return FusionDecision::Forbid( + "Run time estimate for without applying the custom normalization " + "rewrite is faster."); } } @@ -547,7 +553,7 @@ DecideIfShouldFuseAndMaybeSetBlockLevelParameters( TF_RETURN_IF_ERROR(softmax_fusion->set_backend_config(backend_config)); VLOG(5) << "Fusing with backend config: " << backend_config.DebugString(); - return FusionDecision{}; + return FusionDecision::Allow(); } absl::StatusOr MaybeFuseDiamondChainImpl( @@ -615,12 +621,12 @@ FusionDecision ShouldFuseReduction(const HloInstruction& reduce, const se::GpuComputeCapability& cc) { if (CodegenDecision is_supported = IsTritonSupportedInstruction(reduce, cc); !is_supported) { - return FusionDecision(is_supported.Explain()); + return FusionDecision::Forbid(is_supported.Explain()); } if (reduce.dimensions().size() != 1 || reduce.dimensions(0) != reduce.operand(0)->shape().rank() - 1) { - return FusionDecision( + return FusionDecision::Forbid( "The reductions in the diamond must reduce 1 dimension and that " "dimension must be the last dimension of the operand."); } @@ -634,21 +640,23 @@ FusionDecision ShouldFuseReduction(const HloInstruction& reduce, identity->operand(0)->opcode() == HloOpcode::kConstant && IsTritonSupportedInstruction(*identity, cc)); if (!should_fuse_identity) { - return "Reduction identity is not a constant or a supported convert of a " - "constant."; + return FusionDecision::Forbid( + "Reduction identity is not a constant or a supported convert of a " + "constant."); } - return {}; + return FusionDecision::Allow(); } DiamondMatchingDecision MatchesTritonCompatibleClosedReductionDiamondImpl( HloInstruction* instr, const se::GpuComputeCapability& cc) { if (!instr->IsElementwiseBinary()) { - return "Root is not elementwise binary."; + return FusionDecision::Forbid("Root is not elementwise binary."); } if (!IsTritonSupportedInstruction(*instr, cc)) { - return "Root is not supported for Triton instruction."; + return FusionDecision::Forbid( + "Root is not supported for Triton instruction."); } HloInstruction* producer; @@ -657,18 +665,21 @@ DiamondMatchingDecision MatchesTritonCompatibleClosedReductionDiamondImpl( if (!TrivialEdge(&broadcast, instr->mutable_operand(1), HloOpcode::kBroadcast, cc)) { - return "Could not find a trivial connection from root to a broadcast."; + return FusionDecision::Forbid( + "Could not find a trivial connection from root to a broadcast."); } if (!TrivialEdge(&reduce, broadcast->mutable_operand(0), HloOpcode::kReduce, cc)) { - return "Could not find a trivial connection from matched broadcast to a " - "reduction."; + return FusionDecision::Forbid( + "Could not find a trivial connection from matched broadcast to a " + "reduction."); } if (!(HasDefaultLayout(broadcast->shape()) && HasDefaultLayout(reduce->shape()))) { - return "Broadcast or reduce have non-default layouts."; + return FusionDecision::Forbid( + "Broadcast or reduce have non-default layouts."); } if (FusionDecision should_fuse_reduction = ShouldFuseReduction(*reduce, cc); @@ -686,19 +697,21 @@ DiamondMatchingDecision MatchesTritonCompatibleClosedReductionDiamondImpl( identity->operand(0)->opcode() == HloOpcode::kConstant && IsTritonSupportedInstruction(*identity, cc)); if (!should_fuse_identity) { - return "Reduction identity is not a constant or a supported convert of a " - "constant."; + return FusionDecision::Forbid( + "Reduction identity is not a constant or a supported convert of a " + "constant."); } if (!HasOneUse(broadcast) || !HasOneUse(reduce)) { - return "More than one use of broadcast or reduce."; + return FusionDecision::Forbid("More than one use of broadcast or reduce."); } producer = reduce->mutable_operand(0); if (absl::c_linear_search(broadcast->dimensions(), broadcast->shape().rank() - 1)) { - return "Broadcast is not along the reduction dimension."; + return FusionDecision::Forbid( + "Broadcast is not along the reduction dimension."); } while (IsTriviallyFusible(producer, cc)) { @@ -706,16 +719,16 @@ DiamondMatchingDecision MatchesTritonCompatibleClosedReductionDiamondImpl( } if (!HasDefaultLayout(producer->shape())) { - return "Producer has non-default layout."; + return FusionDecision::Forbid("Producer has non-default layout."); } if (!IsTriviallyConnectedProducerOf(producer, instr->mutable_operand(0), cc)) { - return "Producer is not trivially connected."; + return FusionDecision::Forbid("Producer is not trivially connected."); } if (producer != instr->operand(0) && instr->operand(0)->user_count() != 1) { - return "Unsupported root-producer connection."; + return FusionDecision::Forbid("Unsupported root-producer connection."); } VLOG(5) << "Matched Softmax diamond with: "; diff --git a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton_test.cc index 80bed2552becf7..06b481bf0db232 100644 --- a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton_test.cc @@ -29,14 +29,13 @@ limitations under the License. #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/fusions/triton/triton_support.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" -#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/hlo_cost_analysis.h" #include "xla/service/instruction_fusion.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/shape.h" -#include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/status_matchers.h" @@ -48,13 +47,6 @@ namespace m = ::xla::match; using ::testing::HasSubstr; -GpuHloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() { - return [&](const Shape& shape) { - constexpr int64_t kPointerSize = 8; - return ShapeUtil::ByteSizeOf(shape, kPointerSize); - }; -} - bool HasBlockLevelFusionConfig(const HloInstruction* fusion) { return fusion->opcode() == HloOpcode::kFusion && fusion->has_backend_config() && @@ -70,7 +62,7 @@ class SoftmaxRewriterTritonTest protected: se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()}; SoftmaxRewriterTriton fusion_rewriter_{device_info_, - ShapeSizeBytesFunction()}; + HloCostAnalysis::DefaultShapeSize}; }; TEST_F(SoftmaxRewriterTritonTest, CanFuseExactSoftmaxF32) { @@ -836,7 +828,7 @@ ENTRY main { SoftmaxRewriterTriton( TestGpuDeviceInfo::RTXA6000DeviceInfo( se::CudaComputeCapability{se::CudaComputeCapability::VOLTA, 0}), - ShapeSizeBytesFunction()) + HloCostAnalysis::DefaultShapeSize) .Run(module.get()), tsl::testing::StatusIs( tsl::error::FAILED_PRECONDITION, @@ -864,7 +856,7 @@ ENTRY main { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); EXPECT_TRUE(SoftmaxRewriterTriton(TestGpuDeviceInfo::AMDMI210DeviceInfo(), - ShapeSizeBytesFunction()) + HloCostAnalysis::DefaultShapeSize) .Run(module.get()) .ok()); } @@ -1043,7 +1035,8 @@ ENTRY main { } )"; auto module = ParseAndReturnVerifiedModule(hlo_string).value(); - SoftmaxRewriterTriton fusion_rewriter(device_info_, ShapeSizeBytesFunction()); + SoftmaxRewriterTriton fusion_rewriter(device_info_, + HloCostAnalysis::DefaultShapeSize); EXPECT_FALSE(fusion_rewriter_.Run(module.get()).value()); } @@ -1285,8 +1278,8 @@ ENTRY main { )"; auto module = ParseAndReturnVerifiedModule(hlo_string).value(); - SoftmaxRewriterTriton softmax_rewriter_triton(device_info_, - ShapeSizeBytesFunction()); + SoftmaxRewriterTriton softmax_rewriter_triton( + device_info_, HloCostAnalysis::DefaultShapeSize); int unmatched = 0, matched = 0; for (HloInstruction* instruction : module->entry_computation()->MakeInstructionPostOrder()) { @@ -1597,7 +1590,7 @@ ENTRY main { // Verify that SoftmaxRewriterTriton without Cost Model will fuse the // normalization diamond. SoftmaxRewriterTriton fusion_rewriter_without_cost_model{ - device_info_, ShapeSizeBytesFunction(), + device_info_, HloCostAnalysis::DefaultShapeSize, /*only_fuse_if_profitable=*/false}; auto module = ParseAndReturnVerifiedModule(hlo_string).value(); @@ -1612,7 +1605,7 @@ ENTRY main { // SoftmaxRewriterTriton with Cost Model will discard the normalization // diamond, because row size is too large. SoftmaxRewriterTriton fusion_rewriter_with_cost_model{ - device_info_, ShapeSizeBytesFunction(), + device_info_, HloCostAnalysis::DefaultShapeSize, /*only_fuse_if_profitable=*/true}; auto module = ParseAndReturnVerifiedModule(hlo_string).value(); @@ -1620,6 +1613,34 @@ ENTRY main { } } +TEST_F(SoftmaxRewriterTritonTest, DoesNotCrashOnScalarBroadcast) { + const std::string hlo_string = R"( +HloModule softmax +max_computation { + arg_0 = f32[] parameter(0) + arg_1 = f32[] parameter(1) + ROOT maximum = f32[] maximum(arg_0, arg_1) +} +ENTRY main { + param_0 = f32[127,125]{1,0} parameter(0) + param_1 = f32[] parameter(1) + broadcast_from_scalar = f32[127] broadcast(param_1), dimensions={} + constant_neg_inf = f32[] constant(-inf) + reduce = f32[127]{0} reduce(param_0, constant_neg_inf), dimensions={1}, to_apply=max_computation + add = f32[127]{0} add(broadcast_from_scalar, reduce) + broadcast = f32[127,125]{1,0} broadcast(add), dimensions={0} + subtract = f32[127,125]{1,0} subtract(param_0, broadcast) + ROOT abs = f32[127,125]{1,0} abs(subtract) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + EXPECT_TRUE(fusion_rewriter_.Run(module.get()).value()); + EXPECT_TRUE(verifier().Run(module.get()).status().ok()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Fusion(m::Parameter(), m::Parameter()) + .WithPredicate(HasBlockLevelFusionConfig))); +} + } // anonymous namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc index b299db8d19316a..8c47ecb1b605d7 100644 --- a/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc @@ -32,9 +32,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/transforms/expanders/stable_sort_expander.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/runtime/cub_sort_thunk.h" -#include "xla/service/stable_sort_expander.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" @@ -181,7 +181,8 @@ std::optional AnalyzeSortOp( // Create runner for CUB sort operation. absl::StatusOr> CreateRunner( - HloSortInstruction* sort_op, const SortComputationAnalysis& sort_config) { + const HloSortInstruction* sort_op, + const SortComputationAnalysis& sort_config) { int value_index = 1 - sort_config.key_operand; return CubSortRunnerInterface::Create( sort_op->operand(sort_config.key_operand)->shape().element_type(), @@ -190,37 +191,6 @@ absl::StatusOr> CreateRunner( : std::nullopt); } -// Verify that the sort tensor shape is supported by CUB. -bool IsCubCompatibleSort(HloSortInstruction* sort_op) { - VLOG(1) << "Sort instruction: " << sort_op->name(); - if (sort_op->operand_count() != 1 && sort_op->operand_count() != 2) { - VLOG(2) << "Unsupported operand count: " << sort_op->operand_count(); - return false; - } - - const Shape& operand_shape = sort_op->operand(0)->shape(); - if (sort_op->sort_dimension() != operand_shape.rank() - 1) { - VLOG(2) << "Sort dimension should be the minor one"; - return false; - } - if (Product(operand_shape.dimensions()) < SortRewriter::SortSizeThreshold()) { - VLOG(2) << "Tensor shape size is too small to see an improvement"; - return false; - } - - auto sort_config = AnalyzeSortOp(*sort_op); - if (!sort_config.has_value()) { - VLOG(2) << "Only simple compare computations are supported"; - return false; - } - if (!CreateRunner(sort_op, *sort_config).ok()) { - VLOG(2) << "Unsupported operand types (no compiled CUB kernels)"; - return false; - } - VLOG(2) << "Sort operation is compatible"; - return true; -} - // Restore the result shape after sorting a pair of tensors. // The trailing argument is the scratch buffer which should be discarded. HloInstruction* UnpackResultPair(HloSortInstruction* sort_op, @@ -338,5 +308,35 @@ absl::StatusOr SortRewriter::Run( return changed; } +bool IsCubCompatibleSort(const HloSortInstruction* sort_op) { + VLOG(1) << "Sort instruction: " << sort_op->name(); + if (sort_op->operand_count() != 1 && sort_op->operand_count() != 2) { + VLOG(2) << "Unsupported operand count: " << sort_op->operand_count(); + return false; + } + + const Shape& operand_shape = sort_op->operand(0)->shape(); + if (sort_op->sort_dimension() != operand_shape.rank() - 1) { + VLOG(2) << "Sort dimension should be the minor one"; + return false; + } + if (Product(operand_shape.dimensions()) < SortRewriter::SortSizeThreshold()) { + VLOG(2) << "Tensor shape size is too small to see an improvement"; + return false; + } + + auto sort_config = AnalyzeSortOp(*sort_op); + if (!sort_config.has_value()) { + VLOG(2) << "Only simple compare computations are supported"; + return false; + } + if (!CreateRunner(sort_op, *sort_config).ok()) { + VLOG(2) << "Unsupported operand types (no compiled CUB kernels)"; + return false; + } + VLOG(2) << "Sort operation is compatible"; + return true; +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/sort_rewriter.h b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.h index 5763c9b4ed9e86..96835aab6306e0 100644 --- a/third_party/xla/xla/service/gpu/transforms/sort_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.h @@ -57,6 +57,9 @@ class SortRewriter : public HloModulePass { static inline int sort_size_threshold_ = 16385; }; +// Verify that the sort tensor shape is supported by CUB. +bool IsCubCompatibleSort(const HloSortInstruction* sort_op); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc index 68805b1ddc3c0c..404e09141caa91 100644 --- a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc +++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc index c7d2ca59cff0e9..0e178e557b4135 100644 --- a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc @@ -25,8 +25,9 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/filecheck.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.cc index be0eb6fc7ac5e0..bf44757d8e46cb 100644 --- a/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.cc +++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.cc @@ -24,6 +24,7 @@ limitations under the License. #include "xla/service/gpu/runtime/thunk.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/transforms/topk_specializer_test.cc b/third_party/xla/xla/service/gpu/transforms/topk_specializer_test.cc index 9feb6414a57d81..ac0c53a1586dff 100644 --- a/third_party/xla/xla/service/gpu/transforms/topk_specializer_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/topk_specializer_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" diff --git a/third_party/xla/xla/service/gpu/transforms/topk_splitter_test.cc b/third_party/xla/xla/service/gpu/transforms/topk_splitter_test.cc index 8236c26d4056ae..0814b0ef71b726 100644 --- a/third_party/xla/xla/service/gpu/transforms/topk_splitter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/topk_splitter_test.cc @@ -28,11 +28,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/service/pattern_matcher.h" #include "xla/service/topk_rewriter.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc index b39a50bde50203..8f86dc7ccf3e5e 100644 --- a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc +++ b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc @@ -43,6 +43,7 @@ limitations under the License. #include "xla/stream_executor/stream.h" #include "xla/tools/hlo_decomposer.h" #include "xla/util.h" +#include "xla/xla.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -177,6 +178,16 @@ absl::Status VerifyTritonFusion(AutotunerCompileUtil& util, return status; } +TritonFusionNumericsVerifier::FusionCacheKey CacheKeyForFusion( + const HloFusionInstruction& fusion) { + std::unique_ptr module = ExtractInstructionIntoNewModule(fusion); + HloPrintOptions print_options = HloPrintOptions::ModuleFingerprint() + .set_print_only_essential_constants(false) + .set_print_backend_config(true) + .set_sort_backend_config(true); + return module->ToString(print_options); +} + } // namespace absl::StatusOr TritonFusionNumericsVerifier::Run( @@ -200,8 +211,16 @@ absl::StatusOr TritonFusionNumericsVerifier::Run( TF_RETURN_IF_ERROR(triton_fusion_numerics_pass_internal::ForAllTritonFusions( *module, execution_threads, [&](const HloFusionInstruction& fusion) { - return VerifyTritonFusion(*opt_compile_util, fusion, config_, - debug_options); + auto key = CacheKeyForFusion(fusion); + if (auto it = fusion_result_cache_.find(key); + it != fusion_result_cache_.end()) { + ++cache_hits_; + return it->second; + } + auto result = VerifyTritonFusion(*opt_compile_util, fusion, config_, + debug_options); + fusion_result_cache_[key] = result; + return result; })); return false; } diff --git a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.h b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.h index f23a90bff8e4b7..f9e2cc742187eb 100644 --- a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.h +++ b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_TRANSFORMS_TRITON_FUSION_NUMERICS_VERIFIER_H_ #define XLA_SERVICE_GPU_TRANSFORMS_TRITON_FUSION_NUMERICS_VERIFIER_H_ +#include + +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/functional/any_invocable.h" #include "absl/status/status.h" @@ -30,6 +33,7 @@ limitations under the License. #include "xla/service/shaped_buffer.h" #include "xla/shape.h" #include "xla/stream_executor/stream.h" +#include "xla/xla.pb.h" namespace xla::gpu { @@ -49,8 +53,17 @@ class TritonFusionNumericsVerifier : public HloModulePass { HloModule* module, const absl::flat_hash_set& execution_threads) override; + using FusionCacheKey = std::string; + + int CacheHitsForTestingOnly() const { return cache_hits_; } + private: AutotuneConfig config_; + + // In some models there are many identical fusions. These are cached to avoid + // expensive recomputations. + absl::flat_hash_map fusion_result_cache_; + int cache_hits_ = 0; // used for testing only. }; namespace triton_fusion_numerics_pass_internal { diff --git a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc index dd562e07d38aa8..2762166278cdbf 100644 --- a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc @@ -27,13 +27,15 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/primitive_util.h" +#include "xla/service/backend.h" #include "xla/service/gpu/autotuning/autotuner_compile_util.h" #include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/platform_util.h" #include "xla/stream_executor/platform.h" -#include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/status_matchers.h" namespace xla::gpu { @@ -43,7 +45,7 @@ class TritonFusionNumericsVerifierTest : public HloTestBase, public ::testing::WithParamInterface { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { auto options = HloTestBase::GetDebugOptionsForTest(); options.set_xla_gpu_experimental_enable_triton_softmax_priority_fusion( true); @@ -245,6 +247,61 @@ ENTRY main { ::testing::HasSubstr("Failed to compile Triton fusion")); } +TEST_F(TritonFusionNumericsVerifierTest, CacheIsUsed) { + absl::string_view hlo_text = R"( +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +max { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] maximum(p0, p1) +} + +reduce_0 { + p = f32[16,16] parameter(0) + c = f32[] constant(0) + ROOT reduce_0 = f32[16]{0} reduce(p, c), dimensions={1}, to_apply=add +} + +reduce_1 { + p = f32[16,16] parameter(0) + c = f32[] constant(0) + ROOT reduce_0 = f32[16]{0} reduce(p, c), dimensions={1}, to_apply=max +} + +// Identical to reduce_0. +reduce_2 { + p = f32[16,16] parameter(0) + c = f32[] constant(0) + ROOT reduce_0 = f32[16]{0} reduce(p, c), dimensions={1}, to_apply=add +} + +ENTRY main { + p0 = f32[16,16] parameter(0) + p1 = f32[16,16] parameter(1) + p2 = f32[16,16] parameter(2) + r0 = f32[16] fusion(p0), kind=kCustom, calls=reduce_0, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["16"],"num_warps":"1"}}} + r1 = f32[16] fusion(p1), kind=kCustom, calls=reduce_1, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["16"],"num_warps":"1"}}} + r2 = f32[16] fusion(p2), kind=kCustom, calls=reduce_2, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["16"],"num_warps":"1"}}} + add_0_1 = f32[16] add(r0, r1) + ROOT add_0_2 = f32[16] add(add_0_1, r2) +} + )"; + + std::unique_ptr module = + *ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest()); + AutotuneConfig autotune_config{ + DeviceConfig{backend().default_stream_executor(), GetAllocator()}, + module->config().debug_options()}; + TritonFusionNumericsVerifier verifier(autotune_config); + TF_EXPECT_OK(RunHloPass(verifier, module.get())); + EXPECT_EQ(verifier.CacheHitsForTestingOnly(), 1); +} + INSTANTIATE_TEST_SUITE_P(TritonFusionNumericsVerifierTestSuite, TritonFusionNumericsVerifierTest, ::testing::Values(F32, F16, BF16)); diff --git a/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter_test.cc b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter_test.cc index 1d726136a3a8ee..7d1101091f0d87 100644 --- a/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/literal_util.h" -#include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" diff --git a/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc index b7ac16438ecb5d..2a21a0725dfa7e 100644 --- a/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc +++ b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc @@ -24,22 +24,27 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" +#include "xla/hlo/transforms/simplifiers/hlo_constant_folding.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/service/algebraic_simplifier.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/hlo_constant_folding.h" #include "xla/service/hlo_creation_utils.h" #include "xla/service/pattern_matcher.h" #include "xla/service/shape_inference.h" #include "xla/service/while_loop_unroller.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" @@ -56,12 +61,15 @@ namespace m = match; // and type conversions of FP8 operands into the bodies of their while loops, // i.e. rewrites // -// inputs --> dequant --> while loop {collective-permute/dot/etc} +// inputs --> dequant --> (unary) --> while loop {collective-permute/dot/etc} // // into // -// inputs --> while loop {dequant --> collective-permute/dot/etc}. -// Returns whether the input computation has been changed. +// inputs --> (unary) --> while loop {dequant --> collective-permute/dot/etc}. +// +// Unary bitcast, broadcast, copy, reshape and transpose ops are allowed between +// dequantization and while loop. Returns whether the input computation has been +// changed. absl::StatusOr ShiftDequantizationF8(HloComputation* while_body) { HloInstruction* while_instr = while_body->WhileCallInstruction(); // The input of the while loop will be modified and must have no other users. @@ -73,8 +81,21 @@ absl::StatusOr ShiftDequantizationF8(HloComputation* while_body) { // while loop. HloInstruction* param_tuple = while_instr->mutable_operand(0); std::array binaries, operands, scales; + std::array, 2> unaries; for (int k = 0; k < 2; ++k) { - if (!Match(param_tuple->mutable_operand(k), + HloInstruction* operand = param_tuple->mutable_operand(k); + // Capture bitcast, broadcast, copy, reshape and transpose ops between + // dequantization and the loop. + while (operand->opcode() == HloOpcode::kBitcast || + operand->opcode() == HloOpcode::kBroadcast || + operand->opcode() == HloOpcode::kCopy || + operand->opcode() == HloOpcode::kReshape || + operand->opcode() == HloOpcode::kTranspose) { + unaries[k].emplace_back(operand); + operand = operand->mutable_operand(0); + } + std::reverse(unaries[k].begin(), unaries[k].end()); + if (!Match(operand, m::AnyOf( m::Divide(&binaries[k], m::Convert(m::Op(&operands[k])), m::Broadcast(m::Op(&scales[k]))), @@ -156,6 +177,22 @@ absl::StatusOr ShiftDequantizationF8(HloComputation* while_body) { return false; } + // Replace any dequantized bitcast, broadcast, copy, reshape and transpose ops + // before the while loop with FP8 unary ops. + for (int k = 0; k < 2; ++k) { + for (HloInstruction* unary : unaries[k]) { + Shape new_shape = ShapeUtil::MakeShapeWithDenseLayout( + operands[k]->shape().element_type(), unary->shape().dimensions(), + unary->shape().layout().minor_to_major()); + + operands[k] = unary->AddInstruction(unary->CloneWithNewOperands( + ShapeUtil::MakeShapeWithDenseLayout( + operands[k]->shape().element_type(), unary->shape().dimensions(), + unary->shape().layout().minor_to_major()), + {operands[k]})); + } + } + // Replace the dequantized dot operands in the parameter tuple used by while // with FP8 operands. for (int k = 0; k < 2; ++k) { @@ -467,6 +504,136 @@ bool ShouldAddToChain(const HloInstruction* inst) { return false; } } + +HloComputation* MakeSumComputation(PrimitiveType type, HloModule* module) { + HloComputation::Builder sum_b("add"); + auto x = sum_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(type, {}), "x")); + auto y = sum_b.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(type, {}), "y")); + sum_b.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(type, {}), HloOpcode::kAdd, x, y)); + HloComputation* reduction = module->AddEmbeddedComputation(sum_b.Build()); + return reduction; +} + +// Transform partial accumulations into a reduction on a contiguous buffer. +// Partial accumulations will impact the overlap between dots because the +// dot+add pattern will be fused into a single gemm later in gemm rewriter +// which adds data dependencies between gemms. Instead we write all +// intermediate results into a larger buffer and perform a one-shot reduction. +// The high-level transformation is: +// +// 'prev_res' is previously partially accumulated result. +// +// shape(x,y) prev_res shape(x,y) dot0 +// \ / +// \ / +// shape(x,y) add0 shape(x,y) dot1 +// \ / +// \ / +// shape(x,y) add1 +// | +// shape(x,y) loop output +// +// transformed into: +// shape(x,y) prev_res shape(x,y) dot0 shape(x,y) dot1 +// \ / / +// \ / / +// shape(n,x,y) concatenate on first axis, n is the number of partitions +// | +// shape(n,x,y) loop output +// | +// shape(x,y) reduction on first axis +// +// The final reduction is pulled outside of the loop to overlap with other +// collectives. +absl::Status MoveAccumulationOutsideLoop( + std::vector& partial_accumulations, + HloComputation* while_body, HloInstruction* loop) { + // The input of the while loop will be modified and must have no other users. + if (!loop || loop->operand(0)->user_count() != 1) { + return absl::OkStatus(); + } + + std::vector partials_to_concat; + + // We reshape it to a N+1 dimensioned tensor with left-most dim being 1. + Shape shape = partial_accumulations[0]->shape(); + shape = ShapeUtil::PrependMajorDimension(1, shape); + + for (auto& inst : partial_accumulations) { + HloInstruction* reshaped_partial = + while_body->AddInstruction(HloInstruction::CreateReshape(shape, inst)); + partials_to_concat.push_back(reshaped_partial); + } + Shape concat_shape = partial_accumulations[0]->shape(); + concat_shape = ShapeUtil::PrependMajorDimension(partial_accumulations.size(), + concat_shape); + + HloInstruction* concat = while_body->AddInstruction( + HloInstruction::CreateConcatenate(concat_shape, partials_to_concat, 0)); + + HloComputation* comp = loop->parent(); + HloInstruction* windowed_lhs = loop->mutable_operand(0)->mutable_operand(0); + // Add a broadcasted zero of the same type as windowed_lhs. This holds all + // the partial accumulations and will be fed to a global reduction after + // this windowed einsum loop. We move the reduction outside of the loop so + // it can be fused or overlap with other instructions in the main + // computation. + Literal zero_literal = + LiteralUtil::Zero(windowed_lhs->shape().element_type()); + HloInstruction* zero = comp->AddInstruction( + HloInstruction::CreateConstant(std::move(zero_literal))); + Shape zero_bcast_shape = ShapeUtil::ChangeElementType( + concat_shape, windowed_lhs->shape().element_type()); + HloInstruction* zero_bcast = MakeBroadcastHlo(zero, {}, zero_bcast_shape); + loop->mutable_operand(0)->AppendOperand(zero_bcast); + ShapeUtil::AppendShapeToTuple(zero_bcast->shape(), + loop->mutable_operand(0)->mutable_shape()); + + // Update the parameter tuples of while's body and condition + // computations. + for (HloComputation* while_comp : {while_body, loop->while_condition()}) { + while_comp->ReplaceParameter( + 0, HloInstruction::CreateParameter( + 0, loop->mutable_operand(0)->shape(), + while_comp->parameter_instruction(0)->name())); + } + HloInstruction* root = while_body->root_instruction(); + std::vector original_operands(root->operands().begin(), + root->operands().end()); + original_operands.push_back(concat); + HloInstruction* new_output_tuple = while_body->AddInstruction( + HloInstruction::CreateTuple(original_operands)); + TF_RETURN_IF_ERROR( + while_body->ReplaceInstructionWithDifferentShape(root, new_output_tuple)); + + // Update the shape of the while loop instruction. + *loop->mutable_shape() = loop->operand(0)->shape(); + + // The final reduction + HloInstruction* concat_result_gte = + comp->AddInstruction(HloInstruction::CreateGetTupleElement( + loop, (loop->operand(0)->shape().tuple_shapes_size() - 1))); + HloInstruction* reduced_result = + comp->AddInstruction(HloInstruction::CreateReduce( + partial_accumulations[0]->shape(), concat_result_gte, zero, {0}, + MakeSumComputation(shape.element_type(), loop->GetModule()))); + + // Replace the original output if present. + HloInstruction* original_output_gte; + auto it = absl::c_find_if(loop->users(), [&](HloInstruction* instr) { + // Index of the original output. It's fixed to be the third element in the + // tuple. + return instr->tuple_index() == 2; + }); + if (it != loop->users().end()) { + original_output_gte = *it; + TF_RETURN_IF_ERROR(original_output_gte->ReplaceAllUsesWith(reduced_result)); + } + return absl::OkStatus(); +} absl::Status PostProcessUnrolledLoop(HloInstruction* loop, int64_t stream_id) { HloComputation* while_body = loop->while_body(); // This is to set force delay for the first collective permute so it can @@ -477,6 +644,7 @@ absl::Status PostProcessUnrolledLoop(HloInstruction* loop, int64_t stream_id) { WindowedEinsumHandler::kWindowedEinsumRsLoopName) == 0 ? 2 : 0; + std::vector partial_accumulations; for (HloInstruction* inst : while_body->MakeInstructionPostOrder()) { HloInstruction* matched_cp; if (Match(inst, @@ -492,6 +660,20 @@ absl::Status PostProcessUnrolledLoop(HloInstruction* loop, int64_t stream_id) { TF_RETURN_IF_ERROR(UpdateDotAndConsumerConfig(inst, stream_id)); ++stream_id; } + // If dot's result is accumulated, this means we found a loop with + // contracting dim sharded. + HloInstruction* partial_dot; + if (Match(inst, m::AddAnyOrder(m::Op(), + m::Dot(&partial_dot, m::Op(), m::Op())))) { + partial_accumulations.push_back(partial_dot); + } + } + if (partial_accumulations.size() > 0 && + while_body->name().find( + WindowedEinsumHandler::kWindowedEinsumAgLoopName) != + std::string::npos) { + TF_RETURN_IF_ERROR( + MoveAccumulationOutsideLoop(partial_accumulations, while_body, loop)); } return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler_test.cc b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler_test.cc index e5f1e57f306306..06333ec97a8450 100644 --- a/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler_test.cc @@ -25,10 +25,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/filecheck.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" @@ -634,127 +634,142 @@ CHECK: ROOT {{.*}} = bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0} reshape(bf16[1,4,1,204 TEST_F(WindowedEinsumHandlerTest, AllGatherF8) { constexpr absl::string_view kHloString = R"( -HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[], f32[])->f32[2,2048,24576]{2,1,0}}, num_partitions=4 +HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[1536,24576]{1,0}, f32[], f32[])->f32[2,2048,24576]{2,1,0}}, num_partitions=4 windowed_dot_general_body_ag { - param.1 = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) parameter(0) - get-tuple-element.lhs = f32[2,512,24576]{2,1,0} get-tuple-element(param.1), index=0 - collective-permute.send_first_lhs_shard = f32[2,512,24576]{2,1,0} collective-permute(get-tuple-element.lhs), channel_id=4, source_target_pairs={{0,3},{1,0},{2,1},{3,2}} - collective-permute.send_second_lhs_shard = f32[2,512,24576]{2,1,0} collective-permute(collective-permute.send_first_lhs_shard), channel_id=5, source_target_pairs={{0,3},{1,0},{2,1},{3,2}} - get-tuple-element.rhs = f32[24576,24576]{1,0} get-tuple-element(param.1), index=1 - get-tuple-element.3 = f32[2,2048,24576]{2,1,0} get-tuple-element(param.1), index=2 - dot.first_shard_dot = f32[2,512,24576]{2,1,0} dot(get-tuple-element.lhs, get-tuple-element.rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0} - constant.12 = s32[] constant(0) - constant.13 = s32[4]{0} constant({0, 512, 1024, 1536}) - get-tuple-element.5 = u32[] get-tuple-element(param.1), index=4 - partition-id = u32[] partition-id() - add = u32[] add(get-tuple-element.5, partition-id) - constant.11 = u32[] constant(4) - remainder = u32[] remainder(add, constant.11) - dynamic-slice = s32[1]{0} dynamic-slice(constant.13, remainder), dynamic_slice_sizes={1} - reshape = s32[] reshape(dynamic-slice) - dynamic-update-slice.update_first_shard_result = f32[2,2048,24576]{2,1,0} dynamic-update-slice(get-tuple-element.3, dot.first_shard_dot, constant.12, reshape, constant.12) - dot.second_shard_dot = f32[2,512,24576]{2,1,0} dot(collective-permute.send_first_lhs_shard, get-tuple-element.rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0} - constant.15 = u32[] constant(1) - add.1 = u32[] add(get-tuple-element.5, constant.15) - add.2 = u32[] add(add.1, partition-id) - remainder.1 = u32[] remainder(add.2, constant.11) - dynamic-slice.1 = s32[1]{0} dynamic-slice(constant.13, remainder.1), dynamic_slice_sizes={1} - reshape.1 = s32[] reshape(dynamic-slice.1) - dynamic-update-slice.update_second_shard_result = f32[2,2048,24576]{2,1,0} dynamic-update-slice(dynamic-update-slice.update_first_shard_result, dot.second_shard_dot, constant.12, reshape.1, constant.12) - get-tuple-element.4 = f32[2,2048,24576]{2,1,0} get-tuple-element(param.1), index=3 - add.3 = u32[] add(add.1, constant.15) - ROOT tuple = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) tuple(collective-permute.send_second_lhs_shard, get-tuple-element.rhs, dynamic-update-slice.update_second_shard_result, get-tuple-element.4, add.3) + input = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) parameter(0) + lhs = f32[2,512,24576]{2,1,0} get-tuple-element(input), index=0 + permuted_lhs0 = f32[2,512,24576]{2,1,0} collective-permute(lhs), channel_id=4, source_target_pairs={{0,3},{1,0},{2,1},{3,2}} + permuted_lhs1 = f32[2,512,24576]{2,1,0} collective-permute(permuted_lhs0), channel_id=5, source_target_pairs={{0,3},{1,0},{2,1},{3,2}} + rhs = f32[24576,24576]{1,0} get-tuple-element(input), index=1 + partial_dot_output = f32[2,2048,24576]{2,1,0} get-tuple-element(input), index=2 + dot0 = f32[2,512,24576]{2,1,0} dot(lhs, rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0} + c0 = s32[] constant(0) + dot_update_slice_offsets = s32[4]{0} constant({0, 512, 1024, 1536}) + loop_counter = u32[] get-tuple-element(input), index=4 + partition_id = u32[] partition-id() + loop_counter_plus_partition_id = u32[] add(loop_counter, partition_id) + c4 = u32[] constant(4) + dot_update_slice_offsets_index0 = u32[] remainder(loop_counter_plus_partition_id, c4) + dot_update_slice_offset0 = s32[1]{0} dynamic-slice(dot_update_slice_offsets, dot_update_slice_offsets_index0), dynamic_slice_sizes={1} + dot_update_slice_offset_scalar0 = s32[] reshape(dot_update_slice_offset0) + updated_dot_output0 = f32[2,2048,24576]{2,1,0} dynamic-update-slice(partial_dot_output, dot0, c0, dot_update_slice_offset_scalar0, c0) + dot1 = f32[2,512,24576]{2,1,0} dot(permuted_lhs0, rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0} + c1 = u32[] constant(1) + loop_counter_plus_one = u32[] add(loop_counter, c1) + loop_counter_plus_partiion_id_plus_one = u32[] add(loop_counter_plus_one, partition_id) + dot_update_slice_offsets_index1 = u32[] remainder(loop_counter_plus_partiion_id_plus_one, c4) + dot_update_slice_offset1 = s32[1]{0} dynamic-slice(dot_update_slice_offsets, dot_update_slice_offsets_index1), dynamic_slice_sizes={1} + dot_update_slice_offset1_scalar = s32[] reshape(dot_update_slice_offset1) + updated_dot_output1 = f32[2,2048,24576]{2,1,0} dynamic-update-slice(updated_dot_output0, dot1, c0, dot_update_slice_offset1_scalar, c0) + pass_through = f32[2,2048,24576]{2,1,0} get-tuple-element(input), index=3 + next_loop_counter = u32[] add(loop_counter_plus_one, c1) + ROOT tuple = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) tuple(permuted_lhs1, rhs, updated_dot_output1, pass_through, next_loop_counter) } // windowed_dot_general_body_ag windowed_dot_general_cond_ag { - param = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) parameter(0) - get-tuple-element = u32[] get-tuple-element(param), index=4 - constant.10 = u32[] constant(4) - ROOT compare = pred[] compare(get-tuple-element, constant.10), direction=LT + input = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) parameter(0) + loop_counter = u32[] get-tuple-element(input), index=4 + loop_limit = u32[] constant(4) + ROOT compare = pred[] compare(loop_counter, loop_limit), direction=LT } -ENTRY test_main { - param.4 = f8e4m3fn[2,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} - reshape.8 = f8e4m3fn[2,512,24576]{2,1,0} reshape(param.4) - param.5 = f8e4m3fn[24576,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} - constant.18 = f32[] constant(0) - broadcast = f32[2,2048,24576]{2,1,0} broadcast(constant.18), dimensions={} - constant.20 = u32[] constant(0) +ENTRY main { + lhs = f8e4m3fn[2,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} + rhs = f8e4m3fn[1536,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} + c0_f32 = f32[] constant(0) + c0_f32_bcast = f32[2,2048,24576]{2,1,0} broadcast(c0_f32), dimensions={} + c0_u32 = u32[] constant(0) scale_lhs = f32[] parameter(2) scale_lhs_bcast = f32[2,512,24576]{2,1,0} broadcast(scale_lhs), dimensions={} - lhs_bf32 = f32[2,512,24576]{2,1,0} convert(reshape.8) - lhs_scaled = f32[2,512,24576]{2,1,0} multiply(lhs_bf32, scale_lhs_bcast) + lhs_f32 = f32[2,512,24576]{2,1,0} convert(lhs) + lhs_scaled = f32[2,512,24576]{2,1,0} multiply(lhs_f32, scale_lhs_bcast) scale_rhs = f32[] parameter(3) - scale_rhs_bcast = f32[24576,24576]{1,0} broadcast(scale_rhs), dimensions={} - rhs_bf32 = f32[24576,24576]{1,0} convert(param.5) - rhs_scaled = f32[24576,24576]{1,0} multiply(rhs_bf32, scale_rhs_bcast) - tuple.2 = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) tuple(lhs_scaled, rhs_scaled, broadcast, broadcast, constant.20) - while = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag + scale_rhs_bcast = f32[1536,24576]{1,0} broadcast(scale_rhs), dimensions={} + rhs_f32 = f32[1536,24576]{1,0} convert(rhs) + rhs_scaled = f32[1536,24576]{1,0} multiply(rhs_f32, scale_rhs_bcast) + rhs_bcast = f32[16,1536,24576]{2,1,0} broadcast(rhs_scaled), dimensions={1,2} + rhs_reshaped = f32[24576,24576]{1,0} reshape(rhs_bcast) + while_input = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) tuple(lhs_scaled, rhs_reshaped, c0_f32_bcast, c0_f32_bcast, c0_u32) + while = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) while(while_input), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag ROOT get-tuple-element.13 = f32[2,2048,24576]{2,1,0} get-tuple-element(while), index=2 } )"; RunAndFilecheckHloRewrite(kHloString, WindowedEinsumHandler(), R"( -; CHECK-LABEL: unrolled_windowed_dot_general_body_ag -; CHECK-NEXT: [[P0:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) parameter(0) -; CHECK-NEXT: [[GTE0:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} get-tuple-element([[P0]]), index=0 -; CHECK-NEXT: [[CP0:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} collective-permute([[GTE0]]), channel_id=6 -; CHECK-NEXT: [[CP1:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} collective-permute([[CP0]]), channel_id=7 -; CHECK-NEXT: [[GTE1:%[^ ]+]] = f8e4m3fn[24576,24576]{1,0} get-tuple-element([[P0]]), index=1 -; CHECK-NEXT: [[GTE2:%[^ ]+]] = f32[2,2048,24576]{2,1,0} get-tuple-element([[P0]]), index=2 -; CHECK-NEXT: [[CONVERT0:%[^ ]+]] = f32[2,512,24576]{2,1,0} convert([[GTE0]]) -; CHECK-NEXT: [[GTE3:%[^ ]+]] = f32[] get-tuple-element([[P0]]), index=5 -; CHECK-NEXT: [[BCAST0:%[^ ]+]] = f32[2,512,24576]{2,1,0} broadcast([[GTE3]]), dimensions={} -; CHECK-NEXT: [[MUL0:%[^ ]+]] = f32[2,512,24576]{2,1,0} multiply([[CONVERT0]], [[BCAST0]]) -; CHECK-NEXT: [[CONVERT1:%[^ ]+]] = f32[24576,24576]{1,0} convert([[GTE1]]) -; CHECK-NEXT: [[GTE4:%[^ ]+]] = f32[] get-tuple-element([[P0]]), index=6 -; CHECK-NEXT: [[BCAST1:%[^ ]+]] = f32[24576,24576]{1,0} broadcast([[GTE4]]), dimensions={} -; CHECK-NEXT: [[MUL1:%[^ ]+]] = f32[24576,24576]{1,0} multiply([[CONVERT1]], [[BCAST1]]) -; CHECK-NEXT: [[DOT0:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[MUL0]], [[MUL1]]), +; CHECK-LABEL: %unrolled_windowed_dot_general_body_ag +; CHECK-NEXT: [[INPUT:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) parameter(0) +; CHECK-NEXT: [[LHS:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} get-tuple-element([[INPUT]]), index=0 +; CHECK-NEXT: [[PERMUTED_LHS0:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} collective-permute([[LHS]]), channel_id=6 +; CHECK-NEXT: [[PERMUTED_LHS1:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} collective-permute([[PERMUTED_LHS0]]), channel_id=7 +; CHECK-NEXT: [[RHS:%[^ ]+]] = f8e4m3fn[24576,24576]{1,0} get-tuple-element([[INPUT]]), index=1 +; CHECK-NEXT: [[PARTIAL_DOT_OUTPUT:%[^ ]+]] = f32[2,2048,24576]{2,1,0} get-tuple-element([[INPUT]]), index=2 +; CHECK-NEXT: [[LHS_F32:%[^ ]+]] = f32[2,512,24576]{2,1,0} convert([[LHS]]) +; CHECK-NEXT: [[SCALE_LHS:%[^ ]+]] = f32[] get-tuple-element([[INPUT]]), index=5 +; CHECK-NEXT: [[SCALE_LHS_BCAST:%[^ ]+]] = f32[2,512,24576]{2,1,0} broadcast([[SCALE_LHS]]), dimensions={} +; CHECK-NEXT: [[LHS_SCALED:%[^ ]+]] = f32[2,512,24576]{2,1,0} multiply([[LHS_F32]], [[SCALE_LHS_BCAST]]) +; CHECK-NEXT: [[RHS_F32:%[^ ]+]] = f32[24576,24576]{1,0} convert([[RHS]]) +; CHECK-NEXT: [[SCALE_RHS:%[^ ]+]] = f32[] get-tuple-element([[INPUT]]), index=6 +; CHECK-NEXT: [[SCALE_RHS_BCAST:%[^ ]+]] = f32[24576,24576]{1,0} broadcast([[SCALE_RHS]]), dimensions={} +; CHECK-NEXT: [[RHS_SCALED:%[^ ]+]] = f32[24576,24576]{1,0} multiply([[RHS_F32]], [[SCALE_RHS_BCAST]]) +; CHECK-NEXT: [[DOT0:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[LHS_SCALED]], [[RHS_SCALED]]), ; CHECK-DAG: lhs_contracting_dims={2}, ; CHECK-DAG: rhs_contracting_dims={0}, ; CHECK-DAG: backend_config={ ; CHECK-DAG: "operation_queue_id":"[[OPQUEUEID:[0-9]+]]", ; CHECK-DAG: "wait_on_operation_queues":[], ; CHECK-DAG: "force_earliest_schedule":false} -; CHECK-NEXT: [[C0:%[^ ]+]] = s32[] constant(0) -; CHECK-NEXT: [[C4:%[^ ]+]] = u32[] constant(0) +; CHECK-NEXT: [[C0_S32:%[^ ]+]] = s32[] constant(0) +; CHECK-NEXT: [[C0_U32:%[^ ]+]] = u32[] constant(0) ; CHECK-NEXT: [[C5:%[^ ]+]] = u32[] constant(0) -; CHECK-NEXT: [[PID:%[^ ]+]] = u32[] partition-id() -; CHECK-NEXT: [[ADD0:%[^ ]+]] = u32[] add([[C5]], [[PID]]) -; CHECK-NEXT: [[C2:%[^ ]+]] = u32[] constant(3) -; CHECK-NEXT: [[AND0:%[^ ]+]] = u32[] and([[ADD0]], [[C2]]) -; CHECK-NEXT: [[CLAMP0:%[^ ]+]] = u32[] clamp([[C4]], [[AND0]], [[C2]]) +; CHECK-NEXT: [[PARTITION_ID:%[^ ]+]] = u32[] partition-id() +; CHECK-NEXT: [[ADD0:%[^ ]+]] = u32[] add([[C5]], [[PARTITION_ID]]) +; CHECK-NEXT: [[C3:%[^ ]+]] = u32[] constant(3) +; CHECK-NEXT: [[AND0:%[^ ]+]] = u32[] and([[ADD0]], [[C3]]) +; CHECK-NEXT: [[CLAMP0:%[^ ]+]] = u32[] clamp([[C0_U32]], [[AND0]], [[C3]]) ; CHECK-NEXT: [[CONVERT3:%[^ ]+]] = s32[] convert([[CLAMP0]]) -; CHECK-NEXT: [[C6:%[^ ]+]] = s32[] constant(512) -; CHECK-NEXT: [[MUL3:%[^ ]+]] = s32[] multiply([[CONVERT3]], [[C6]]) +; CHECK-NEXT: [[C512:%[^ ]+]] = s32[] constant(512) +; CHECK-NEXT: [[MUL3:%[^ ]+]] = s32[] multiply([[CONVERT3]], [[C512]]) ; CHECK-NEXT: [[RESHAPE0:%[^ ]+]] = s32[] reshape([[MUL3]]) -; CHECK-NEXT: [[DUPDATESLICE0:%[^ ]+]] = f32[2,2048,24576]{2,1,0} dynamic-update-slice([[GTE2]], [[DOT0]], [[C0]], [[RESHAPE0]], [[C0]]), +; CHECK-NEXT: [[UPDATED_DOT_OUTPUT0:%[^ ]+]] = f32[2,2048,24576]{2,1,0} dynamic-update-slice([[PARTIAL_DOT_OUTPUT]], [[DOT0]], [[C0_S32]], [[RESHAPE0]], [[C0_S32]]), ; CHECK-DAG: backend_config={ ; CHECK-DAG: "operation_queue_id":"0", ; CHECK-DAG: "wait_on_operation_queues":["[[OPQUEUEID]]"], ; CHECK-DAG: "force_earliest_schedule":false} -; CHECK-NEXT: [[CONVERT2:%[^ ]+]] = f32[2,512,24576]{2,1,0} convert([[CP0]]) -; CHECK-NEXT: [[MUL2:%[^ ]+]] = f32[2,512,24576]{2,1,0} multiply([[CONVERT2]], [[BCAST0]]) -; CHECK-NEXT: [[DOT1:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[MUL2]], [[MUL1]]), +; CHECK-NEXT: [[PERMUTED_LHS0_F32:%[^ ]+]] = f32[2,512,24576]{2,1,0} convert([[PERMUTED_LHS0]]) +; CHECK-NEXT: [[PERMUTED_LHS_SCALED:%[^ ]+]] = f32[2,512,24576]{2,1,0} multiply([[PERMUTED_LHS0_F32]], [[SCALE_LHS_BCAST]]) +; CHECK-NEXT: [[DOT1:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[PERMUTED_LHS_SCALED]], [[RHS_SCALED]]), ; CHECK-DAG: lhs_contracting_dims={2}, ; CHECK-DAG: rhs_contracting_dims={0} -; CHECK-NEXT: [[GTE7:%[^ ]+]] = u32[] get-tuple-element([[P0]]), index=4 -; CHECK-NEXT: [[C3:%[^ ]+]] = u32[] constant(1) -; CHECK-NEXT: [[ADD1:%[^ ]+]] = u32[] add([[GTE7]], [[C3]]) -; CHECK-NEXT: [[ADD2:%[^ ]+]] = u32[] add([[ADD1]], [[PID]]) -; CHECK-NEXT: [[AND1:%[^ ]+]] = u32[] and([[ADD2]], [[C2]]) -; CHECK-NEXT: [[CLAMP1:%[^ ]+]] = u32[] clamp([[C4]], [[AND1]], [[C2]]) +; CHECK-NEXT: [[LOOP_COUNTER:%[^ ]+]] = u32[] get-tuple-element([[INPUT]]), index=4 +; CHECK-NEXT: [[C1:%[^ ]+]] = u32[] constant(1) +; CHECK-NEXT: [[LOOP_COUNTER_PLUS_ONE:%[^ ]+]] = u32[] add([[LOOP_COUNTER]], [[C1]]) +; CHECK-NEXT: [[LOOP_COUNTER_PLUS_ONE_PLUS_PARTITION_ID:%[^ ]+]] = u32[] add([[LOOP_COUNTER_PLUS_ONE]], [[PARTITION_ID]]) +; CHECK-NEXT: [[AND1:%[^ ]+]] = u32[] and([[LOOP_COUNTER_PLUS_ONE_PLUS_PARTITION_ID]], [[C3]]) +; CHECK-NEXT: [[CLAMP1:%[^ ]+]] = u32[] clamp([[C0_U32]], [[AND1]], [[C3]]) ; CHECK-NEXT: [[CONVERT4:%[^ ]+]] = s32[] convert([[CLAMP1]]) -; CHECK-NEXT: [[MUL4:%[^ ]+]] = s32[] multiply([[CONVERT4]], [[C6]]) +; CHECK-NEXT: [[MUL4:%[^ ]+]] = s32[] multiply([[CONVERT4]], [[C512]]) ; CHECK-NEXT: [[RESHAPE1:%[^ ]+]] = s32[] reshape([[MUL4]]) -; CHECK-NEXT: [[DUPDATESLICE1:%[^ ]+]] = f32[2,2048,24576]{2,1,0} dynamic-update-slice([[DUPDATESLICE0]], [[DOT1]], [[C0]], [[RESHAPE1]], [[C0]]) -; CHECK-NEXT: [[GTE6:%[^ ]+]] = f32[2,2048,24576]{2,1,0} get-tuple-element([[P0]]), index=3 -; CHECK-NEXT: [[C7:%[^ ]+]] = u32[] constant(2) -; CHECK-NEXT: [[ADD3:%[^ ]+]] = u32[] add([[GTE7]], [[C7]]) -; CHECK-NEXT: [[TUPLE0:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) tuple([[CP1]], [[GTE1]], [[DUPDATESLICE1]], [[GTE6]], [[ADD3]], /*index=5*/[[GTE3]], [[GTE4]]) +; CHECK-NEXT: [[UPDATED_DOT_OUTPUT1:%[^ ]+]] = f32[2,2048,24576]{2,1,0} dynamic-update-slice([[UPDATED_DOT_OUTPUT0]], [[DOT1]], [[C0_S32]], [[RESHAPE1]], [[C0_S32]]) +; CHECK-NEXT: [[PASS_THROUGH:%[^ ]+]] = f32[2,2048,24576]{2,1,0} get-tuple-element([[INPUT]]), index=3 +; CHECK-NEXT: [[C2:%[^ ]+]] = u32[] constant(2) +; CHECK-NEXT: [[NEXT_LOOP_COUNTER:%[^ ]+]] = u32[] add([[LOOP_COUNTER]], [[C2]]) +; CHECK-NEXT: [[TUPLE:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) tuple([[PERMUTED_LHS1]], [[RHS]], [[UPDATED_DOT_OUTPUT1]], [[PASS_THROUGH]], [[NEXT_LOOP_COUNTER]], /*index=5*/[[SCALE_LHS]], [[SCALE_RHS]]) +; CHECK-LABEL: ENTRY %main +; CHECK: [[LHS:%[^ ]+]] = f8e4m3fn[2,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} +; CHECK-NEXT: [[RHS:%[^ ]+]] = f8e4m3fn[1536,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} +; CHECK-NEXT: [[RHS_BCAST:%[^ ]+]] = f8e4m3fn[16,1536,24576]{2,1,0} broadcast([[RHS]]), dimensions={1,2} +; CHECK-NEXT: [[RHS_RESHAPED:%[^ ]+]] = f8e4m3fn[24576,24576]{1,0} reshape([[RHS_BCAST]]) +; CHECK-NEXT: [[C0:%[^ ]+]] = f32[] constant(0) +; CHECK-NEXT: [[C0_BCAST:%[^ ]+]] = f32[2,2048,24576]{2,1,0} broadcast([[C0]]), dimensions={} +; CHECK-NEXT: [[C0_U32:%[^ ]+]] = u32[] constant(0) +; CHECK-NEXT: [[SCALE_LHS:%[^ ]+]] = f32[] parameter(2) +; CHECK-NEXT: [[SCALE_RHS:%[^ ]+]] = f32[] parameter(3) +; CHECK-NEXT: [[WHILE_INPUT:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) tuple([[LHS]], [[RHS_RESHAPED]], [[C0_BCAST]], [[C0_BCAST]], [[C0_U32]], /*index=5*/[[SCALE_LHS]], [[SCALE_RHS]]) +; CHECK: [[WHILE:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) while([[WHILE_INPUT]]), +; CHECK-DAG: condition=%unrolled_windowed_dot_general_cond_ag, +; CHECK-DAG: body=%unrolled_windowed_dot_general_body_ag )"); } @@ -1064,7 +1079,7 @@ ENTRY main { TEST_F(WindowedEinsumHandlerTest, AgLoopsMultipleConsumersAreChainedWithShardedContratingDim) { constexpr absl::string_view kHloString = R"( -HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0})->bf16[4096,6288]{1,0}}, num_partitions=8 +HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0})->(bf16[16,2048,6288]{2,1,0}, bf16[4096,6288]{1,0})}, num_partitions=8 windowed_dot_general_body_ag { param.195 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) parameter(0) @@ -1114,10 +1129,11 @@ ENTRY main.12_spmd { constant.24 = u32[] constant(0) tuple.2 = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) tuple(param.4, param.5, broadcast, broadcast, constant.24) while = (bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0}, bf16[16,2048,6288]{2,1,0}, u32[]) while(tuple.2), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag - get-tuple-element.13 = bf16[16,2048,6288]{2,1,0} get-tuple-element(while), index=2 + get-tuple-element.result = bf16[16,2048,6288]{2,1,0} get-tuple-element(while), index=2 all-gather = bf16[16,2048,4096]{2,1,0} all-gather(param.4), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={2}, use_global_device_ids=true param.6 = bf16[16,2048,6288]{2,1,0} parameter(2) - ROOT dot.7 = bf16[4096,6288]{1,0} dot(all-gather, param.6), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1} + dot.7 = bf16[4096,6288]{1,0} dot(all-gather, param.6), lhs_contracting_dims={0,1}, rhs_contracting_dims={0,1} + ROOT tuple.output = (bf16[16,2048,6288]{2,1,0}, bf16[4096,6288]{1,0}) tuple(get-tuple-element.result, dot.7) } )"; @@ -1137,6 +1153,12 @@ ENTRY main.12_spmd { EXPECT_EQ(inst->operand(0)->opcode(), HloOpcode::kGetTupleElement); EXPECT_EQ(inst->operand(0)->tuple_index(), 5); EXPECT_EQ(inst->operand(0)->operand(0), ag_loop); + + EXPECT_EQ(ag_loop->operand(0)->shape().tuple_shapes_size(), 7); + // The root instruction's first operand should now be a reduction. + EXPECT_EQ( + module->entry_computation()->root_instruction()->operand(0)->opcode(), + HloOpcode::kReduce); } } // namespace diff --git a/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc b/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc index d3b003c7add4bd..1bc8e0cf899554 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc +++ b/third_party/xla/xla/service/gpu/triton_fusion_analysis.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/service/gpu/triton_fusion_analysis.h" #include -#include #include #include #include @@ -33,17 +32,15 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/gpu/cudnn_support_utils.h" -#include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/matmul_indexing_utils.h" #include "xla/service/gpu/triton_tiling_propagation.h" #include "xla/service/instruction_fusion.h" -#include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/tools/hlo_decomposer.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -112,22 +109,6 @@ namespace triton_fusion { return context; } -namespace { - -// Tells how many new parameters does a fusion gain by fusing the operation as -// an input. -int64_t NumAddedParameters(const HloInstruction& hlo) { - // Non-scalar constant is equivalent to a parameter: one input, one output. - if (hlo.opcode() == HloOpcode::kConstant && - !ShapeUtil::IsScalar(hlo.shape())) { - return 0; - } - // All other instructions add all own inputs and remove own single output. - return hlo.operand_count() - 1; -} - -} // namespace - bool FusionContext::CombineDimOrdersAndReqs(const DimOrdersAndReqs& update) { // First check that all updates to insert are compatible to avoid // incomplete merges. @@ -224,41 +205,6 @@ absl::StatusOr TritonFusionAnalysis::Execute( return analysis; } -absl::Status TritonFusionAnalysis::ExecuteForProducerConsumer( - const HloInstruction& producer, const HloInstruction& consumer, - int split_k) { - // TODO(shyshkov): Use HloFusionAdaptor to avoid the need to materialize the - // hlo fusion. - std::unique_ptr new_module = - ExtractProducerConsumerIntoNewModule(producer, consumer); - - auto* new_producer = - new_module->entry_computation()->GetInstructionWithName(producer.name()); - auto* new_consumer = - new_module->entry_computation()->GetInstructionWithName(consumer.name()); - - std::unique_ptr fusion_instruction_holder; - HloInstruction* fusion_instruction; - if (new_consumer->opcode() == HloOpcode::kFusion) { - fusion_instruction = new_consumer; - } else { - fusion_instruction_holder = HloInstruction::CreateFusion( - new_consumer->shape(), new_producer->fusion_kind(), new_consumer); - fusion_instruction = fusion_instruction_holder.get(); - } - - // Try to merge the producer into candidate fusion. - if (new_producer->opcode() == HloOpcode::kFusion) { - fusion_instruction->MergeFusionInstruction(new_producer); - } else { - fusion_instruction->FuseInstruction(new_producer); - } - - auto* fused_computation = - fusion_instruction->fused_instructions_computation(); - return Execute(*fused_computation, split_k).status(); -} - bool TritonFusionAnalysis::IsBatchDimMinorForInt4Parameter( const HloInstruction& dot, Scope scope) const { CHECK(scope == Scope::LHS || scope == Scope::RHS); diff --git a/third_party/xla/xla/service/gpu/triton_fusion_analysis.h b/third_party/xla/xla/service/gpu/triton_fusion_analysis.h index 405dc79176e202..e3894a793cf572 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_analysis.h +++ b/third_party/xla/xla/service/gpu/triton_fusion_analysis.h @@ -49,14 +49,6 @@ class TritonFusionAnalysis { static absl::StatusOr Execute( const HloDotInstruction& dot, int split_k = 1); - // Execute the analysis of a produce-consumer fusion. Returns absl::OkStatus, - // if the analysis can find a valid tiling for the producer-consumer fusion. - // `split_k` indicates whether this operation was converted to the split-K - // form and tells the analysis how to interpret the batch dimensions. - static absl::Status ExecuteForProducerConsumer(const HloInstruction& producer, - const HloInstruction& consumer, - int split_k = 1); - // A scope is an HLO graph that can be tiled efficiently using same or // compatible tile shapes on all operations. GEMM fusion has 3 or 4 scopes // defined by left operand, right operand, optional meta (third operand) and diff --git a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc index 1943af5343b427..1d5230ee66855e 100644 --- a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc +++ b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc @@ -284,7 +284,7 @@ Int64OrError CombineSplitDimMajorPartSizeReqs(int64_t a, int64_t b) { if (a == kNoSplitRequirement) { return b; } - return FusionDecision("Conflicting splits of splittable dimension"); + return FusionDecision::Forbid("Conflicting splits of splittable dimension"); } } // namespace @@ -318,7 +318,7 @@ DotRequirementsOrError GetRequirementsIfSupportedOrder( CHECK(!dim_fragments.empty()); for (int i = 0; i < dim_fragments.size() - 1; ++i) { if (tensor_dim_fragments[dim_fragments[i]].is_sliced()) { - return "Sliced non-major-most fragment."; + return FusionDecision::Forbid("Sliced non-major-most fragment."); } } int group_counter = 0; @@ -342,7 +342,7 @@ DotRequirementsOrError GetRequirementsIfSupportedOrder( } if (last_seen_group_last_fragment_index > *fragment_it) { - return "Transpose within a dimension."; + return FusionDecision::Forbid("Transpose within a dimension."); } ++group_counter; @@ -356,14 +356,16 @@ DotRequirementsOrError GetRequirementsIfSupportedOrder( if (group_counter == 2) { if (split_dim_major_part != kNoSplitRequirement && split_dim_major_part != grouped_size) { - return "Conflicting splits of splittable dimension"; + return FusionDecision::Forbid( + "Conflicting splits of splittable dimension"); } split_dim_major_part = grouped_size; } else if (group_counter > 2) { - return "2nd split of a splittable dimension."; + return FusionDecision::Forbid( + "2nd split of a splittable dimension."); } } else { - return "Unsupported split of a dimension."; + return FusionDecision::Forbid("Unsupported split of a dimension."); } } @@ -479,7 +481,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForBitcast( }; if (dst_remaining_size >= src_dim->full_count()) { if (dst_remaining_size % src_dim->full_count()) { - return "Unsupported bitcast"; + return FusionDecision::Forbid("Unsupported bitcast"); } // Source dimension fragment completely fits into the destination one: // just copy it as is. @@ -497,7 +499,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForBitcast( // If there is a remaining fragment of a previous destination dimension // assign it first. if (src_remaining_size % dst_remaining_size || (src_dim->is_sliced())) { - return "Unsupported bitcast"; + return FusionDecision::Forbid("Unsupported bitcast"); } add_new_fragment( Fragment{src_dim->dst_dim_number(), dst_remaining_size}); @@ -515,13 +517,13 @@ DimOrderMapOrError GetPropagatedDimOrdersForBitcast( // size assign the remainder of the source and carry over the // remainder of the destination. if (dst_dim_size % src_remaining_size) { - return "Unsupported bitcast"; + return FusionDecision::Forbid("Unsupported bitcast"); } dst_remaining_size = dst_dim_size / src_remaining_size; new_fragment_size = src_remaining_size; } if (src_dim->is_sliced()) { - return "Unsupported bitcast"; + return FusionDecision::Forbid("Unsupported bitcast"); } add_new_fragment( Fragment{src_dim->dst_dim_number(), new_fragment_size}); @@ -537,7 +539,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForBitcast( // give up. while (dst_dim_it != dst_dim_end) { if (dst_shape.dimensions(*dst_dim_it) != 1) { - return "Unsupported bitcast"; + return FusionDecision::Forbid("Unsupported bitcast"); } if (!dst_fragments_order.empty()) { dst_fragments_order.push_back( @@ -582,7 +584,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( Fragments src_fragments_order = src_dim_order.TensorFragmentsOrder(); if (hlo.opcode() == HloOpcode::kSlice && ShapeUtil::IsEffectiveScalar(hlo.shape())) { - return FusionDecision("Slice to scalar is not implemented yet."); + return FusionDecision::Forbid("Slice to scalar is not implemented yet."); } // Every HLO dimension can correspond to a group of subdimensions in // dim_order_. For the easier handling of permutations: group dim_order_ by @@ -595,7 +597,8 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( // It's not supported currently to further propagate dimensions after // reaching a trivial sized tensor. We could probably support it, but now we // just prevent crashing here. - return FusionDecision("Cannot propagate further from trivial sized tensor"); + return FusionDecision::Forbid( + "Cannot propagate further from trivial sized tensor"); } auto src_fragment_it = src_fragments_order.begin(); for (int64_t dim_index : src.shape().layout().minor_to_major()) { @@ -652,17 +655,17 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( dst_logical.resize(src_logical.size() + reduce->dimensions().size()); if (reduce->dimensions().size() != 1) { - return FusionDecision("Unsupported reduction."); + return FusionDecision::Forbid("Unsupported reduction."); } else if (reduce->dimensions().front() != reduce->operand(0)->shape().rank() - 1) { - return FusionDecision("Only row reductions are supported."); + return FusionDecision::Forbid("Only row reductions are supported."); } } else if (hlo.opcode() == HloOpcode::kConcatenate) { dst_logical.resize(src_logical.size()); for (int i = 0; i < src_logical.size(); ++i) { if (i == hlo.concatenate_dimension()) { if (src_logical[i].size() != 1 || src_logical[i][0]->is_sliced()) { - return FusionDecision("Unsupported concatenation."); + return FusionDecision::Forbid("Unsupported concatenation."); } const Fragment& src_fragment = *src_logical[i][0]; Fragment& dst_fragment = new_fragments.emplace_back( @@ -733,7 +736,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( if (slice->slice_limits(dim) - slice->slice_starts(dim) != dst->shape().dimensions(dim)) { if (dst_logical[dim].size() > 1) { - return FusionDecision("Slicing of fragmented dimension."); + return FusionDecision::Forbid("Slicing of fragmented dimension."); } auto fragment = dst_logical[dim].front(); fragment->set_count(dst->shape().dimensions(dim)); @@ -755,7 +758,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( dst_logical[dim] = src_logical[dim]; if (dynamic_slice->slice_sizes(dim) != dst->shape().dimensions(dim)) { if (dst_logical[dim].size() > 1) { - return FusionDecision("Slicing of fragmented dimension."); + return FusionDecision::Forbid("Slicing of fragmented dimension."); } auto fragment = dst_logical[dim].front(); fragment->set_count(dst->shape().dimensions(dim)); @@ -767,7 +770,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( } } } else { - return FusionDecision("Function called on a wrong instruction."); + return FusionDecision::Forbid("Function called on a wrong instruction."); } // Destination logical -> destination physical and ungroup subdimensions. // Map original fragments to the resulting ones to derive their new @@ -794,7 +797,7 @@ DimOrderMapOrError GetPropagatedDimOrdersForDimAlteringOp( if (hlo.opcode() == HloOpcode::kBroadcast && src_fragments_order[fragment_number].full_count() > 1 && dim_numbers_present_in_dst.contains(dim_index)) { - return FusionDecision("Unsupported broadcast"); + return FusionDecision::Forbid("Unsupported broadcast"); } continue; } @@ -818,7 +821,8 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, return (user->opcode() == HloOpcode::kConcatenate || user->opcode() == HloOpcode::kDynamicSlice); })) { - return "No fusion into concatenations or dynamic slice."; + return FusionDecision::Forbid( + "No fusion into concatenations or dynamic slice."); } if (hlo.opcode() == HloOpcode::kParameter || hlo_query::IsScalarConstant(&hlo)) { @@ -830,7 +834,7 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, properties); } else if (hlo.opcode() == HloOpcode::kBroadcast) { if (direction != TransformDirection::kOutputToInput) { - return "Unsupported broadcast direction."; + return FusionDecision::Forbid("Unsupported broadcast direction."); } return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, properties); @@ -838,7 +842,7 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, // Pad ops are only supported when they are generated as part of the split-k // transform of dot fusions. if (direction != TransformDirection::kOutputToInput) { - return "Unsupported pad direction."; + return FusionDecision::Forbid("Unsupported pad direction."); } return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, properties); @@ -852,7 +856,7 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, } else if (hlo.opcode() == HloOpcode::kSlice) { // TODO(b/316637896) Add support for slices in softmax. if (direction != TransformDirection::kOutputToInput) { - return "Unsupported slice direction."; + return FusionDecision::Forbid("Unsupported slice direction."); } return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, @@ -870,7 +874,7 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, properties); } else if (hlo.opcode() == HloOpcode::kReshape) { if (!ShapeUtil::ReshapeIsBitcast(hlo.operand(0)->shape(), hlo.shape())) { - return "Non-bitcast reshape."; + return FusionDecision::Forbid("Non-bitcast reshape."); } return GetPropagatedDimOrdersForBitcast(hlo, direction, src_dim_order, properties); @@ -885,15 +889,16 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, if (noncontracting_dim_fragment_order_it != src_dim_fragments_orders.end()) { if (noncontracting_dim_fragment_order_it->second.size() > 1) { - return "Concatenations on split non-contracting dimensions are " - "unsupported."; + return FusionDecision::Forbid( + "Concatenations on split non-contracting dimensions are " + "unsupported."); } } auto dim = LogicalIndexOfLabeledDimension(hlo.shape(), src_dim_order, noncontracting_dim_label); if (!dim.has_value() || dim.value() != hlo.concatenate_dimension()) { - return "Unsupported concatenation."; + return FusionDecision::Forbid("Unsupported concatenation."); } if (absl::c_any_of(hlo.operands(), [&hlo](const HloInstruction* operand) { // In the current simple implementation of concatenation the size of @@ -907,13 +912,13 @@ DimOrderMapOrError GetPropagatedDimOrders(const HloInstruction& hlo, kMinConcatFragmentSize != 0; })) { - return FusionDecision( + return FusionDecision::Forbid( "At least one operand of concatenation can not be perfectly tiled."); } return GetPropagatedDimOrdersForDimAlteringOp(hlo, direction, src_dim_order, properties); } - return "Unimplemented instruction."; + return FusionDecision::Forbid("Unimplemented instruction."); } // Difference of input and output data volumes of an instruction. @@ -966,9 +971,9 @@ FusionDecision IsConversionWorthFusing(const HloInstruction& input, // output fusion - then it should be fused here anyway! if (ShapeUtil::ByteSizeOf(input.operand(0)->shape()) > ShapeUtil::ByteSizeOf(input.shape())) { - return "Narrowing conversion."; + return FusionDecision::Forbid("Narrowing conversion."); } - return FusionDecision{}; + return FusionDecision::Allow(); } } // namespace @@ -1004,16 +1009,16 @@ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( if (hlo.opcode() == HloOpcode::kTuple || hlo.opcode() == HloOpcode::kGetTupleElement) { - return "Unsupported instruction."; + return FusionDecision::Forbid("Unsupported instruction."); } if (hlo.opcode() == HloOpcode::kReduce || hlo.opcode() == HloOpcode::kAllReduce || hlo.opcode() == HloOpcode::kAllReduceStart || hlo.opcode() == HloOpcode::kAllReduceDone) { - return "Reductions are not fused yet."; + return FusionDecision::Forbid("Reductions are not fused yet."); } if (hlo.opcode() == HloOpcode::kPad) { - return "Pads are not fused yet."; + return FusionDecision::Forbid("Pads are not fused yet."); } if (auto decision = legacy_triton::IsTritonSupportedInstruction(hlo, gpu_version); @@ -1042,7 +1047,7 @@ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( return decision; } } else if (hlo.IsElementwise() && hlo.opcode() != HloOpcode::kCopy) { - return "Ignored elementwise operation"; + return FusionDecision::Forbid("Ignored elementwise operation"); } } else { // Exception for binary elementwise operations: in most cases these are @@ -1068,12 +1073,14 @@ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( } } if (!accepted && !IsInputWorthFusing(hlo)) { - return "Not obviously profitable to fuse as input."; + return FusionDecision::Forbid( + "Not obviously profitable to fuse as input."); } } } else { if (fusion_level < 2) { - return "Skipping fusing outputs at low fusion levels."; + return FusionDecision::Forbid( + "Skipping fusing outputs at low fusion levels."); } for (int i = 0; i < hlo.operand_count(); ++i) { const HloInstruction* operand = hlo.operand(i); @@ -1088,10 +1095,12 @@ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( operand->opcode() == HloOpcode::kParameter) { continue; } - return "Has multiple inputs - not properly analyzed yet."; + return FusionDecision::Forbid( + "Has multiple inputs - not properly analyzed yet."); } if (!IsOutputWorthFusing(hlo)) { - return "Not obviously profitable to fuse as output."; + return FusionDecision::Forbid( + "Not obviously profitable to fuse as output."); } } return dim_orders_and_requirements; diff --git a/third_party/xla/xla/service/gpu/while_transformer_test.cc b/third_party/xla/xla/service/gpu/while_transformer_test.cc index a5bf72cb4d8b62..e530583508b2db 100644 --- a/third_party/xla/xla/service/gpu/while_transformer_test.cc +++ b/third_party/xla/xla/service/gpu/while_transformer_test.cc @@ -17,12 +17,12 @@ limitations under the License. #include #include "xla/comparison_util.h" +#include "xla/hlo/analysis/while_loop_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal_util.h" -#include "xla/service/while_loop_analysis.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" diff --git a/third_party/xla/xla/service/graphcycles/graphcycles.cc b/third_party/xla/xla/service/graphcycles/graphcycles.cc index 019087c1a98276..056cdcd74a01c7 100644 --- a/third_party/xla/xla/service/graphcycles/graphcycles.cc +++ b/third_party/xla/xla/service/graphcycles/graphcycles.cc @@ -38,7 +38,6 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" -#include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "xla/service/graphcycles/ordered_set.h" diff --git a/third_party/xla/xla/service/heap_simulator/BUILD b/third_party/xla/xla/service/heap_simulator/BUILD index 847541a5c0d0c6..c2b3944f4e4fa6 100644 --- a/third_party/xla/xla/service/heap_simulator/BUILD +++ b/third_party/xla/xla/service/heap_simulator/BUILD @@ -39,12 +39,12 @@ cc_library( ":allocation_block", "//xla:comparison_util", "//xla:util", + "//xla/hlo/analysis:hlo_alias_analysis", + "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_live_range", "//xla/service:buffer_value", - "//xla/service:hlo_alias_analysis", "//xla/service:hlo_buffer", - "//xla/service:hlo_dataflow_analysis", "//xla/service:hlo_proto_cc", "//xla/service:hlo_value", "//xla/service:logical_buffer", @@ -76,12 +76,12 @@ xla_cc_test( "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_alias_analysis", + "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/service:buffer_value", - "//xla/service:hlo_alias_analysis", - "//xla/service:hlo_dataflow_analysis", "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", "//xla/service:hlo_value", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", diff --git a/third_party/xla/xla/service/heap_simulator/heap_simulator.cc b/third_party/xla/xla/service/heap_simulator/heap_simulator.cc index 8c3236b7f93168..9ceb861e0fce2d 100644 --- a/third_party/xla/xla/service/heap_simulator/heap_simulator.cc +++ b/third_party/xla/xla/service/heap_simulator/heap_simulator.cc @@ -44,6 +44,8 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "xla/comparison_util.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" @@ -52,9 +54,7 @@ limitations under the License. #include "xla/service/buffer_value.h" #include "xla/service/heap_simulator/allocation_block.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" -#include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_value.h" #include "xla/service/logical_buffer.h" #include "xla/service/time_utils.h" diff --git a/third_party/xla/xla/service/heap_simulator/heap_simulator.h b/third_party/xla/xla/service/heap_simulator/heap_simulator.h index c7d722bf873b21..7328f87722b600 100644 --- a/third_party/xla/xla/service/heap_simulator/heap_simulator.h +++ b/third_party/xla/xla/service/heap_simulator/heap_simulator.h @@ -39,6 +39,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_schedule.h" @@ -46,7 +47,6 @@ limitations under the License. #include "xla/service/buffer_value.h" #include "xla/service/heap_simulator/allocation_block.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_value.h" #include "xla/service/logical_buffer.h" diff --git a/third_party/xla/xla/service/heap_simulator/heap_simulator_test.cc b/third_party/xla/xla/service/heap_simulator/heap_simulator_test.cc index 8e11953de05c51..d27dbd14d81cce 100644 --- a/third_party/xla/xla/service/heap_simulator/heap_simulator_test.cc +++ b/third_party/xla/xla/service/heap_simulator/heap_simulator_test.cc @@ -29,18 +29,18 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "xla/comparison_util.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/literal_util.h" #include "xla/service/buffer_value.h" #include "xla/service/heap_simulator/allocation_block.h" -#include "xla/service/hlo_alias_analysis.h" -#include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" #include "xla/service/hlo_value.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/service/hlo_alias_analysis.h b/third_party/xla/xla/service/hlo_alias_analysis.h index 28b319cbdb12c0..e2789adda9f4bf 100644 --- a/third_party/xla/xla/service/hlo_alias_analysis.h +++ b/third_party/xla/xla/service/hlo_alias_analysis.h @@ -16,111 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_ALIAS_ANALYSIS_H_ #define XLA_SERVICE_HLO_ALIAS_ANALYSIS_H_ -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/types/span.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_buffer.h" -#include "xla/service/hlo_dataflow_analysis.h" -#include "xla/service/hlo_ordering.h" -#include "xla/types.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Analysis which allocates HloBuffers to HloValues. -class HloAliasAnalysis { - public: - // The callgraph of the given HloModule must be flattened - // (xla::FlattenCallGraph) prior to running the analysis. - static absl::StatusOr> Run( - const HloModule* module, - const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr); - - std::string ToString() const; - - // Return the buffer containing the given value. - const HloBuffer& GetBufferContainingValue(const HloValue& value) const { - return *value_to_buffer_.at(&value); - } - HloBuffer& GetBufferContainingValue(const HloValue& value) { - return *value_to_buffer_.at(&value); - } - - // Return the HloBuffer with the given ID. - const HloBuffer& GetBuffer(HloBuffer::Id buffer_id) const { - return buffers_.at(buffer_id); - } - HloBuffer& GetBuffer(HloBuffer::Id buffer_id) { - return buffers_.at(buffer_id); - } - - // Returns the unique buffer at the given position. CHECK fails if the buffer - // set at that position does not contain exactly one buffer. - const HloBuffer& GetUniqueBufferAt(const HloInstruction* instruction, - const ShapeIndex& index = {}) const; - HloBuffer& GetUniqueBufferAt(const HloInstruction* instruction, - const ShapeIndex& index = {}); - - // Compute the set of buffers at the given instruction and index and return as - // a vector. This set is exactly the union of the buffers containing the - // HloValues at this position. - std::vector ComputeBuffersAt( - const HloInstruction* instruction, const ShapeIndex& index = {}) const; - - // Return a vector of all HloBuffers stabily sorted by HloBuffer::Id. This - // vector is lazily computed. Mutating operations on HloAliasAnalysis may - // invalidate the underlying vector requiring recomputation. - const std::vector& buffers() const { return buffers_; } - - // Returns the underlying dataflow analysis used by this alias analysis. - HloDataflowAnalysis& dataflow_analysis() const { return *dataflow_analysis_; } - - // Returns true if a buffer lives out of the module. - bool BufferLivesOut(const HloBuffer& buffer) const { - return live_out_buffers_.contains(&buffer); - } - - // Returns true if a hlo value lives out of the module. - bool ValueLivesOut(const HloValue& value) const { - return live_out_buffers_.contains(&GetBufferContainingValue(value)); - } - - std::vector LiveOutBuffers() const { - std::vector results(live_out_buffers_.begin(), - live_out_buffers_.end()); - absl::c_sort(results, HloBuffer::IdLessThan); - return results; - } - - protected: - explicit HloAliasAnalysis(const HloModule* module); - - // Verify various invariants of the alias analysis. - absl::Status Verify() const; - - const HloModule* module_; - - // A set of buffers that live out the module. - absl::flat_hash_set live_out_buffers_; - - // The underlying dataflow analysis used by this alias analysis. - std::unique_ptr dataflow_analysis_; - - // A map indicating which buffer a value is contained in. - absl::flat_hash_map value_to_buffer_; - - // A lazily constructed vector containing all HloBuffers sorted by - // HloBuffer::Id. - std::vector buffers_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/analysis/hlo_alias_analysis.h" #endif // XLA_SERVICE_HLO_ALIAS_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/hlo_computation_deduplicator.h b/third_party/xla/xla/service/hlo_computation_deduplicator.h index 64eac6de00452a..bf82bc4ff4204c 100644 --- a/third_party/xla/xla/service/hlo_computation_deduplicator.h +++ b/third_party/xla/xla/service/hlo_computation_deduplicator.h @@ -16,36 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_COMPUTATION_DEDUPLICATOR_H_ #define XLA_SERVICE_HLO_COMPUTATION_DEDUPLICATOR_H_ -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// Deduplicate computations inside a `HloModule`: If two computations are -// identical then keep the first one (in postorder terms) and remove the rest. -class HloComputationDeduplicator : public HloModulePass { - public: - // Setting mark_fusion_duplications to true will only process fusions in the - // HLO. The comparator in this pass will mark duplicate fusions which is - // needed for groupings in analysis (e.g. Xprof). Currently, the pass - // doesn't change the HLO if the flag is set to true. - explicit HloComputationDeduplicator(bool mark_fusion_duplications = false) - : mark_fusion_duplications_(mark_fusion_duplications) {} - absl::string_view name() const override { return "computation-deduplicator"; } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - bool ContainsLargeConstants(HloComputation* comp); - bool mark_fusion_duplications_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/hlo_computation_deduplicator.h" #endif // XLA_SERVICE_HLO_COMPUTATION_DEDUPLICATOR_H_ diff --git a/third_party/xla/xla/service/hlo_computation_test.cc b/third_party/xla/xla/service/hlo_computation_test.cc index a7190b33f2088d..d46996dd3accd5 100644 --- a/third_party/xla/xla/service/hlo_computation_test.cc +++ b/third_party/xla/xla/service/hlo_computation_test.cc @@ -30,9 +30,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal_util.h" -#include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" @@ -41,6 +41,7 @@ limitations under the License. #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" @@ -972,5 +973,40 @@ TEST_F(HloComputationTest, CompositeCall) { EXPECT_EQ(composite_call->frontend_attributes().map().size(), 3); } +TEST_F(HloComputationTest, CloneComputationWithAsyncInstructions) { + constexpr std::string_view hlo = R"( +HloModule main + +comp.0 { + ROOT custom-call.0 = () custom-call(), custom_call_target="foo" +} + +ENTRY main { + in.0 = () parameter(0) + call.0 = () call(), to_apply=comp.0 + ROOT out.0 = () tuple() +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + HloComputation* comp0 = FindComputation(module.get(), "comp.0"); + HloInstruction* custom_call = FindInstruction(module.get(), "custom-call.0"); + TF_ASSERT_OK(comp0->CreateAsyncInstructions( + custom_call, /*context_shapes=*/{ShapeUtil::MakeScalarShape(U32)}, + /*async_execution_thread=*/HloInstruction::kMainExecutionThread, + /*replace=*/true, + /*override_names=*/true)); + + HloComputation* comp1 = module->AddEmbeddedComputation(comp0->Clone()); + HloComputation* comp2 = module->AddEmbeddedComputation(comp0->Clone()); + EXPECT_NE(comp0->root_instruction()->name(), + comp1->root_instruction()->name()); + EXPECT_NE(comp0->root_instruction()->operand(0)->name(), + comp1->root_instruction()->operand(0)->name()); + EXPECT_NE(comp1->root_instruction()->name(), + comp2->root_instruction()->name()); + EXPECT_NE(comp1->root_instruction()->operand(0)->name(), + comp2->root_instruction()->operand(0)->name()); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/hlo_constant_folding.h b/third_party/xla/xla/service/hlo_constant_folding.h index 4f56d7fa35562c..5f82f95d863ebb 100644 --- a/third_party/xla/xla/service/hlo_constant_folding.h +++ b/third_party/xla/xla/service/hlo_constant_folding.h @@ -16,36 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_CONSTANT_FOLDING_H_ #define XLA_SERVICE_HLO_CONSTANT_FOLDING_H_ -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// A pass which performs constant folding in order to avoid unnecessary -// computation on constants. -class HloConstantFolding : public HloModulePass { - public: - absl::string_view name() const override { return "constant_folding"; } - - // Run constant folding operations on the given module. Returns whether the - // module was changed (constant expressions folded). - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - // Number of slow constant-folds we've encountered. Used for firing - // SlowOperationAlarms. - static std::atomic slow_op_counter_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/hlo_constant_folding.h" #endif // XLA_SERVICE_HLO_CONSTANT_FOLDING_H_ diff --git a/third_party/xla/xla/service/hlo_cost_analysis.cc b/third_party/xla/xla/service/hlo_cost_analysis.cc index c49476e8f927e3..19f497c673c8cd 100644 --- a/third_party/xla/xla/service/hlo_cost_analysis.cc +++ b/third_party/xla/xla/service/hlo_cost_analysis.cc @@ -35,9 +35,9 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/tsl/lib/gtl/map_util.h" #include "xla/util.h" #include "xla/window_util.h" -#include "tsl/lib/gtl/map_util.h" #include "tsl/platform/errors.h" namespace xla { @@ -1039,6 +1039,10 @@ absl::Status HloCostAnalysis::HandleAllToAll(const HloInstruction* hlo) { return absl::OkStatus(); } +absl::Status HloCostAnalysis::HandleRaggedAllToAll(const HloInstruction* hlo) { + return absl::OkStatus(); +} + absl::Status HloCostAnalysis::HandleCollectiveBroadcast( const HloInstruction* /*hlo*/) { return absl::OkStatus(); @@ -1528,4 +1532,8 @@ bool HloCostAnalysis::KeyToCopyFromSubcomputation(absl::string_view key) const { !absl::StartsWith(key, kUtilizationKey); } +int64_t HloCostAnalysis::DefaultShapeSize(const Shape& shape) { + return ShapeUtil::ByteSizeOf(shape, kDefaultPointerSize); +} + } // namespace xla diff --git a/third_party/xla/xla/service/hlo_cost_analysis.h b/third_party/xla/xla/service/hlo_cost_analysis.h index a1e700491ce9e8..268be3bb9ba5d3 100644 --- a/third_party/xla/xla/service/hlo_cost_analysis.h +++ b/third_party/xla/xla/service/hlo_cost_analysis.h @@ -392,6 +392,9 @@ class HloCostAnalysis : public ConstDfsHloVisitor { // buffer of a shape. using ShapeSizeFunction = std::function; + static constexpr int64_t kDefaultPointerSize = 8; + static int64_t DefaultShapeSize(const Shape& shape); + // A struct to encapsulate hardware-related options. This includes the shape // size function, which is used to encode hardware-specific padding and per // second rates of FLOPs, bytes per second (available bandwidth), and @@ -400,7 +403,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { // Function which computes the size of the top-level of a given shape (not // including nested elements, if any). If null then bytes_accessed methods // return an error. - ShapeSizeFunction shape_size; + ShapeSizeFunction shape_size = DefaultShapeSize; // How much of each property can be processed per second. E.g. if the // property is bytes accessed, this is the number of bytes that can be // processed per second. Is empty if no rates have been set. @@ -454,7 +457,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { }; explicit HloCostAnalysis(const Options& options); - explicit HloCostAnalysis(ShapeSizeFunction shape_size, + explicit HloCostAnalysis(ShapeSizeFunction shape_size = DefaultShapeSize, const Properties& per_second_rates = {}, const Properties& min_latency_seconds = {}); @@ -502,6 +505,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { absl::Status HandleAllReduceStart(const HloInstruction* hlo) override; absl::Status HandleAllReduceDone(const HloInstruction* hlo) override; absl::Status HandleAllToAll(const HloInstruction* hlo) override; + absl::Status HandleRaggedAllToAll(const HloInstruction* hlo) override; absl::Status HandleCollectiveBroadcast(const HloInstruction* hlo) override; absl::Status HandleCollectivePermute(const HloInstruction* hlo) override; absl::Status HandleCollectivePermuteStart(const HloInstruction* hlo) override; diff --git a/third_party/xla/xla/service/hlo_cost_analysis_test.cc b/third_party/xla/xla/service/hlo_cost_analysis_test.cc index 4bac5768a2d54e..05515ee8e0bc2d 100644 --- a/third_party/xla/xla/service/hlo_cost_analysis_test.cc +++ b/third_party/xla/xla/service/hlo_cost_analysis_test.cc @@ -24,12 +24,12 @@ limitations under the License. #include "xla/client/client.h" #include "xla/client/client_library.h" #include "xla/client/local_client.h" -#include "xla/client/padding.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/service/local_service.h" #include "xla/service/service.h" #include "xla/shape_util.h" @@ -41,12 +41,6 @@ limitations under the License. namespace xla { namespace { -constexpr int64_t kPointerSize = 8; - -int64_t ShapeSize(const Shape& shape) { - return ShapeUtil::ByteSizeOf(shape, kPointerSize); -} - // This test suite tests the HLO cost analysis by first building a computation // using the client computation builder and running the HloCostAnalysis that // returns the number of floating point and transcendental operations in the @@ -146,7 +140,7 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) { // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -180,7 +174,7 @@ TEST_F(HloCostAnalysisTest, DotGeneral) { // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -216,7 +210,7 @@ TEST_F(HloCostAnalysisTest, DotGeneral2) { // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -246,7 +240,7 @@ TEST_F(HloCostAnalysisTest, DotGeneral3) { // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -273,7 +267,7 @@ TEST_F(HloCostAnalysisTest, Map) { // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -303,7 +297,7 @@ TEST_F(HloCostAnalysisTest, Convolution) { // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -345,7 +339,7 @@ TEST_F(HloCostAnalysisTest, ConvolutionSame) { // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -380,7 +374,7 @@ TEST_F(HloCostAnalysisTest, ConvolutionExtreme) { // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -403,7 +397,7 @@ TEST_F(HloCostAnalysisTest, ConvolutionExtreme2) { // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -426,7 +420,7 @@ TEST_F(HloCostAnalysisTest, ConvolutionWithFeatureGroup) { // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -455,7 +449,7 @@ TEST_F(HloCostAnalysisTest, Reduce) { // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -480,7 +474,7 @@ TEST_F(HloCostAnalysisTest, ReduceWindow) { // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -507,7 +501,7 @@ TEST_F(HloCostAnalysisTest, ReduceWindowWithOverlaps) { int n_output_elements = 3 * 4; // Run HLO cost analysis. - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK(root->Accept(&analysis)); // Each of the output elements are generated from reducing [4x5] elements. @@ -539,7 +533,7 @@ ENTRY fusion.50 { } )"; auto hlo_module = ParseAndReturnUnverifiedModule(hlo_text).value(); - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); EXPECT_EQ(analysis.flop_count(), (2 * 3 * 1024) + (1024 - 1)); @@ -571,7 +565,7 @@ TEST_F(HloCostAnalysisTest, ReduceWindowVariadic) { // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -597,7 +591,7 @@ TEST_F(HloCostAnalysisTest, SelectAndScatter) { // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -619,7 +613,7 @@ TEST_F(HloCostAnalysisTest, Broadcast) { XlaBuilder b("broadcast"); Broadcast(ConstantR0(&b, 42), {10, 7}); auto hlo_module = BuildHloGraph(&b); - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); EXPECT_EQ(analysis.flop_count(), 0); @@ -635,8 +629,8 @@ TEST_F(HloCostAnalysisTest, BroadcastCountMultipleInputAccesses) { XlaBuilder b("broadcast"); Broadcast(ConstantR0(&b, 42), {10, 7}); auto hlo_module = BuildHloGraph(&b); - HloCostAnalysis analysis(HloCostAnalysis::Options{ - .shape_size = ShapeSize, .count_multiple_input_accesses = true}); + HloCostAnalysis analysis( + HloCostAnalysis::Options{.count_multiple_input_accesses = true}); ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); EXPECT_EQ(analysis.flop_count(), 0); @@ -661,7 +655,7 @@ TEST_F(HloCostAnalysisTest, FullyConnectedForward) { // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -672,7 +666,7 @@ TEST_F(HloCostAnalysisTest, FullyConnectedForward) { } TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { - HloCostAnalysis conv_analysis(ShapeSize); + HloCostAnalysis conv_analysis; { XlaBuilder builder("conv_looking_matmul"); auto lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}), @@ -685,7 +679,7 @@ TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) { &conv_analysis)); } - HloCostAnalysis matmul_analysis(ShapeSize); + HloCostAnalysis matmul_analysis; { XlaBuilder builder("matmul"); auto lhs = @@ -716,7 +710,7 @@ TEST_F(HloCostAnalysisTest, LatencyBoundedOptimalTime) { ParseAndReturnUnverifiedModule(hlo_string)); const HloInstruction* add = module->entry_computation()->root_instruction(); - HloCostAnalysis::Options options{ShapeSize}; + HloCostAnalysis::Options options; const float clock_cycle_seconds = 10.0f; options.set_flops_per_second(1024); options.set_bytes_per_second(1024); @@ -756,7 +750,7 @@ TEST_F(FusionCostAnalysis, LoopFusionDynUpdateSlice) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_fusion_module_str)); - HloCostAnalysis fusion_analysis(ShapeSize); + HloCostAnalysis fusion_analysis; HloInstruction* fusion = module->entry_computation()->root_instruction(); ASSERT_IS_OK(fusion->Accept(&fusion_analysis)); @@ -773,7 +767,7 @@ TEST_F(FusionCostAnalysis, LoopFusionDynUpdateSlice) { )"; TF_ASSERT_OK_AND_ASSIGN(auto dus_module, ParseAndReturnVerifiedModule(hlo_dus_module_str)); - HloCostAnalysis dus_analysis(ShapeSize); + HloCostAnalysis dus_analysis; auto dus = dus_module->entry_computation()->root_instruction(); ASSERT_IS_OK(dus->Accept(&dus_analysis)); EXPECT_EQ(fusion_analysis.operand_bytes_accessed(*fusion, 0), 0); @@ -832,7 +826,7 @@ TEST_F(FusionCostAnalysis, LoopFusion) { // The time given these rates at i == 0 is exactly even among the properties // at 1.0 seconds. For other values, one of the rates is slower so that it // becomes the bottleneck. - HloCostAnalysis::Options options{ShapeSize}; + HloCostAnalysis::Options options; options.set_flops_per_second(16 * (i == 1 ? 1 / 2.0 : 1.0)); options.set_transcendentals_per_second(4 * (i == 2 ? 1 / 4.0 : 1.0)); options.set_bytes_per_second(64 * (i == 3 ? 1 / 8.0 : 1.0)); @@ -899,13 +893,13 @@ ENTRY temp { )"; TF_ASSERT_OK_AND_ASSIGN(auto nested_fusion_module, ParseAndReturnVerifiedModule(nested_fusion_text)); - HloCostAnalysis nested_analysis(ShapeSize); + HloCostAnalysis nested_analysis; auto* nested_root = nested_fusion_module->entry_computation()->root_instruction(); ASSERT_IS_OK(nested_root->Accept(&nested_analysis)); TF_ASSERT_OK_AND_ASSIGN(auto fusion_module, ParseAndReturnVerifiedModule(fusion_text)); - HloCostAnalysis fusion_analysis(ShapeSize); + HloCostAnalysis fusion_analysis; auto* fusion_root = fusion_module->entry_computation()->root_instruction(); ASSERT_IS_OK(fusion_root->Accept(&fusion_analysis)); // The nested fusion should only access the bytes size amount of the parameter @@ -967,13 +961,13 @@ ENTRY temp { )"; TF_ASSERT_OK_AND_ASSIGN(auto nested_fusion_module, ParseAndReturnVerifiedModule(nested_fusion_text)); - HloCostAnalysis nested_analysis(ShapeSize); + HloCostAnalysis nested_analysis; auto* nested_root = nested_fusion_module->entry_computation()->root_instruction(); ASSERT_IS_OK(nested_root->Accept(&nested_analysis)); TF_ASSERT_OK_AND_ASSIGN(auto fusion_module, ParseAndReturnVerifiedModule(fusion_text)); - HloCostAnalysis fusion_analysis(ShapeSize); + HloCostAnalysis fusion_analysis; auto* fusion_root = fusion_module->entry_computation()->root_instruction(); ASSERT_IS_OK(fusion_root->Accept(&fusion_analysis)); // The nested fusion should only access the bytes size amount of the parameter @@ -1010,7 +1004,7 @@ ENTRY temp { )"; TF_ASSERT_OK_AND_ASSIGN(auto fusion_module, ParseAndReturnVerifiedModule(hlo_text)); - HloCostAnalysis fusion_analysis(ShapeSize); + HloCostAnalysis fusion_analysis; auto* fusion_root = fusion_module->entry_computation()->root_instruction(); ASSERT_IS_OK(fusion_root->Accept(&fusion_analysis)); EXPECT_EQ(1073741824, fusion_analysis.bytes_accessed(*fusion_root)); @@ -1044,7 +1038,7 @@ ENTRY temp { )"; TF_ASSERT_OK_AND_ASSIGN(auto fusion_module, ParseAndReturnVerifiedModule(hlo_text)); - HloCostAnalysis fusion_analysis(ShapeSize); + HloCostAnalysis fusion_analysis; auto* fusion_root = fusion_module->entry_computation()->root_instruction(); ASSERT_IS_OK(fusion_root->Accept(&fusion_analysis)); EXPECT_EQ(1610612736, fusion_analysis.bytes_accessed(*fusion_root)); @@ -1083,7 +1077,7 @@ TEST_F(FusionCostAnalysis, LoopFusionTupleOutput) { auto* fusion = computation->CreateFusionInstruction( {tuple2, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop); - HloCostAnalysis fusion_analysis(ShapeSize); + HloCostAnalysis fusion_analysis; ASSERT_IS_OK(fusion->Accept(&fusion_analysis)); EXPECT_EQ(fusion_analysis.flop_count(), 16); @@ -1137,7 +1131,7 @@ TEST_F(FusionCostAnalysis, NoLayout) { auto* fusion = computation->CreateFusionInstruction( {add, broadcast}, HloInstruction::FusionKind::kLoop); - HloCostAnalysis fusion_analysis(ShapeSize); + HloCostAnalysis fusion_analysis; ASSERT_IS_OK(fusion->Accept(&fusion_analysis)); EXPECT_EQ(fusion_analysis.flop_count(), 120); @@ -1177,7 +1171,7 @@ ENTRY entry { HloInstruction* fusion = module->entry_computation()->root_instruction(); - HloCostAnalysis fusion_analysis(ShapeSize); + HloCostAnalysis fusion_analysis; ASSERT_IS_OK(fusion->Accept(&fusion_analysis)); EXPECT_EQ(fusion_analysis.bytes_accessed(*fusion), sizeof(float) * 3 * 2 * 3); @@ -1212,7 +1206,7 @@ ENTRY entry { HloInstruction* fusion = module->entry_computation()->root_instruction(); - HloCostAnalysis fusion_analysis(ShapeSize); + HloCostAnalysis fusion_analysis; ASSERT_IS_OK(fusion->Accept(&fusion_analysis)); EXPECT_EQ(fusion_analysis.bytes_accessed(*fusion), sizeof(float) * 2 * 2 * 4); @@ -1244,15 +1238,18 @@ ENTRY e { HloInstruction* root = module->entry_computation()->root_instruction(); - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK(root->Accept(&analysis)); EXPECT_EQ(analysis.output_bytes_accessed(*root), 3); // 2-element tuple (pointers) + its 3-element shape #0 - EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), 2 * kPointerSize + 3); + EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), + 2 * HloCostAnalysis::kDefaultPointerSize + 3); // Same as above + non-scalar constant c1 + output. - EXPECT_EQ(analysis.bytes_accessed(*root), 2 * kPointerSize + 3 + 3 + 3); - EXPECT_EQ(analysis.bytes_accessed(), 2 * kPointerSize + 3 + 3 + 3); + EXPECT_EQ(analysis.bytes_accessed(*root), + 2 * HloCostAnalysis::kDefaultPointerSize + 3 + 3 + 3); + EXPECT_EQ(analysis.bytes_accessed(), + 2 * HloCostAnalysis::kDefaultPointerSize + 3 + 3 + 3); } TEST_F(FusionCostAnalysis, InfeedOutfeed) { @@ -1278,7 +1275,7 @@ ENTRY entry { HloInstruction* outfeed = module->entry_computation()->GetInstructionWithName("outfeed"); - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK(infeed->Accept(&analysis)); ASSERT_IS_OK(outfeed->Accept(&analysis)); @@ -1313,7 +1310,7 @@ ENTRY entry { HloInstruction* all_reduce = module->entry_computation()->root_instruction(); - HloCostAnalysis all_reduce_analysis(ShapeSize); + HloCostAnalysis all_reduce_analysis; ASSERT_IS_OK(all_reduce->Accept(&all_reduce_analysis)); EXPECT_EQ(all_reduce_analysis.bytes_accessed(*all_reduce), @@ -1327,7 +1324,7 @@ ENTRY entry { } TEST_F(HloCostAnalysisTest, TupleCost) { - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; XlaBuilder builder("tuple"); auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {123}), "x"); @@ -1340,17 +1337,19 @@ TEST_F(HloCostAnalysisTest, TupleCost) { EXPECT_EQ(analysis.flop_count(), 0); EXPECT_EQ(analysis.transcendental_count(), 0); - EXPECT_EQ(analysis.bytes_accessed(), kPointerSize * 2); + EXPECT_EQ(analysis.bytes_accessed(), + HloCostAnalysis::kDefaultPointerSize * 2); HloInstruction* root = hlo_module->entry_computation()->root_instruction(); EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), 0); EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), 0); - EXPECT_EQ(analysis.output_bytes_accessed(*root), kPointerSize * 2); + EXPECT_EQ(analysis.output_bytes_accessed(*root), + HloCostAnalysis::kDefaultPointerSize * 2); } using DomainCostAnalysis = HloTestBase; TEST_F(DomainCostAnalysis, DomainCost) { - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; HloComputation::Builder builder("domain"); auto x = builder.AddInstruction(HloInstruction::CreateParameter( @@ -1392,7 +1391,7 @@ TEST_F(HloCostAnalysisTest, BaseDilatedConvolution) { // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -1407,7 +1406,7 @@ TEST_F(HloCostAnalysisTest, Slice) { auto hlo_module = BuildHloGraph(&builder); // Run HLO cost analysis. - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -1427,7 +1426,7 @@ TEST_F(HloCostAnalysisTest, DynamicSlice) { auto hlo_module = BuildHloGraph(&builder); // Run HLO cost analysis. - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -1449,7 +1448,7 @@ TEST_F(HloCostAnalysisTest, DynamicUpdateSlice) { auto hlo_module = BuildHloGraph(&builder); // Run HLO cost analysis. - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -1481,7 +1480,7 @@ TEST_F(HloCostAnalysisTest, Gather) { auto hlo_module = BuildHloGraph(&builder); // Run HLO cost analysis. - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -1493,6 +1492,38 @@ TEST_F(HloCostAnalysisTest, Gather) { EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 2 * 3); } +TEST_F(HloCostAnalysisTest, GatherBatchingDims) { + // Test the analysis on a gather. + XlaBuilder builder("gather"); + Shape operand_shape = ShapeUtil::MakeShape(S32, {5, 3, 3}); + Shape indices_shape = ShapeUtil::MakeShape(S32, {5}); + + auto operand = Parameter(&builder, 0, operand_shape, "operand"); + auto indices = Parameter(&builder, 1, indices_shape, "indices"); + GatherDimensionNumbers dim_numbers; + dim_numbers.add_offset_dims(1); + dim_numbers.add_collapsed_slice_dims(1); + dim_numbers.add_operand_batching_dims(0); + dim_numbers.add_start_indices_batching_dims(0); + dim_numbers.add_start_index_map(1); + dim_numbers.set_index_vector_dim(1); + Gather(operand, indices, dim_numbers, {1, 1, 3}); + + auto hlo_module = BuildHloGraph(&builder); + + // Run HLO cost analysis. + HloCostAnalysis analysis; + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.bytes_accessed(), 140); + + HloInstruction* root = hlo_module->entry_computation()->root_instruction(); + EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(float) * 5 * 3); + EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), sizeof(int32_t) * 5); + EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 5 * 3); +} + TEST_F(HloCostAnalysisTest, Scatter) { // Test the analysis on a scatter. XlaBuilder builder("scatter"); @@ -1513,7 +1544,7 @@ TEST_F(HloCostAnalysisTest, Scatter) { auto hlo_module = BuildHloGraph(&builder); // Run HLO cost analysis. - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -1526,6 +1557,41 @@ TEST_F(HloCostAnalysisTest, Scatter) { EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 2 * 3); } +TEST_F(HloCostAnalysisTest, ScatterBatchingDims) { + // Test the analysis on a scatter. + XlaBuilder builder("scatter"); + Shape operand_shape = ShapeUtil::MakeShape(F32, {5, 3, 3}); + Shape indices_shape = ShapeUtil::MakeShape(S32, {5}); + Shape values_shape = ShapeUtil::MakeShape(F32, {5, 3}); + + auto operand = Parameter(&builder, 0, operand_shape, "operand"); + auto indices = Parameter(&builder, 1, indices_shape, "indices"); + auto values = Parameter(&builder, 2, values_shape, "values"); + ScatterDimensionNumbers dim_numbers; + dim_numbers.set_index_vector_dim(1); + dim_numbers.add_update_window_dims(1); + dim_numbers.add_inserted_window_dims(1); + dim_numbers.add_input_batching_dims(0); + dim_numbers.add_scatter_indices_batching_dims(0); + dim_numbers.add_scatter_dims_to_operand_dims(1); + Scatter(operand, indices, values, add_, dim_numbers); + + auto hlo_module = BuildHloGraph(&builder); + + // Run HLO cost analysis. + HloCostAnalysis analysis; + ASSERT_IS_OK( + hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); + + EXPECT_EQ(analysis.bytes_accessed(), 4 * (5 + 3 * (5 * 3))); + + HloInstruction* root = hlo_module->entry_computation()->root_instruction(); + EXPECT_EQ(analysis.operand_bytes_accessed(*root, 0), sizeof(float) * 5 * 3); + EXPECT_EQ(analysis.operand_bytes_accessed(*root, 1), sizeof(int32_t) * 5); + EXPECT_EQ(analysis.operand_bytes_accessed(*root, 2), sizeof(float) * 5 * 3); + EXPECT_EQ(analysis.output_bytes_accessed(*root), sizeof(float) * 5 * 3); +} + TEST_F(HloCostAnalysisTest, MultioutputScatter) { // Test the analysis on a scatter. XlaBuilder builder("scatter"); @@ -1561,7 +1627,7 @@ TEST_F(HloCostAnalysisTest, MultioutputScatter) { auto hlo_module = BuildHloGraph(&builder); // Run HLO cost analysis. - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK( hlo_module->entry_computation()->root_instruction()->Accept(&analysis)); @@ -1581,7 +1647,7 @@ TEST_F(HloCostAnalysisTest, GetShapeSizeIgnoreUnsupportedShape) { Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); *shape.mutable_layout() = LayoutUtil::MakeLayout({1, 0}, {DIM_DENSE, DIM_COMPRESSED}); - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; EXPECT_TRUE(LayoutUtil::IsSparseArray(shape)); EXPECT_EQ(0, analysis.GetShapeSize(shape)); } @@ -1609,7 +1675,7 @@ ENTRY e { HloInstruction* root = module->entry_computation()->root_instruction(); - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; ASSERT_IS_OK(root->Accept(&analysis)); EXPECT_EQ(analysis.output_bytes_accessed(*root), 10000); @@ -1636,7 +1702,7 @@ TEST_F(FusionCostAnalysis, RevisitModifiedFusion) { HloInstruction* fusion = computation->CreateFusionInstruction( {neg, mul, add}, HloInstruction::FusionKind::kLoop); - HloCostAnalysis::Options options{ShapeSize}; + HloCostAnalysis::Options options; HloCostAnalysis analysis(options); ASSERT_IS_OK(fusion->Accept(&analysis)); @@ -1702,7 +1768,7 @@ ENTRY e { HloInstruction* root = module->entry_computation()->root_instruction(); - HloCostAnalysis modified_analysis(ShapeSize); + HloCostAnalysis modified_analysis; ASSERT_IS_OK(root->Accept(&modified_analysis)); HloInstruction* fusion_root = root->called_computations()[0]->root_instruction(); @@ -1717,7 +1783,7 @@ ENTRY e { module->entry_computation()->ComputeProgramShape()); ASSERT_IS_OK(modified_analysis.RevisitInstruction(root)); - HloCostAnalysis unmodified_analysis(ShapeSize); + HloCostAnalysis unmodified_analysis; ASSERT_IS_OK(root->Accept(&unmodified_analysis)); EXPECT_FLOAT_EQ(modified_analysis.operand_utilization(*fusion_root, 0), 0.2); @@ -1747,7 +1813,7 @@ ENTRY e { ParseAndReturnVerifiedModule(hlo_string)); HloInstruction* root = module->entry_computation()->root_instruction(); - HloCostAnalysis analysis(ShapeSize); + HloCostAnalysis analysis; // add_computation is shared by two reductions - r0 and r1. // Removing/revisiting one of them should not affect the other one. diff --git a/third_party/xla/xla/service/hlo_creation_utils.cc b/third_party/xla/xla/service/hlo_creation_utils.cc index a94e23d21066e5..46f1cf8473e381 100644 --- a/third_party/xla/xla/service/hlo_creation_utils.cc +++ b/third_party/xla/xla/service/hlo_creation_utils.cc @@ -31,10 +31,10 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "xla/client/lib/comparators.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" #include "xla/comparison_util.h" +#include "xla/hlo/builder/lib/comparators.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -597,12 +597,22 @@ HloInstruction* MaybeMakeTuple(absl::Span operands) { HloInstruction::CreateTuple(operands)); } +absl::StatusOr XlaComputationToHloComputation( + XlaComputation& src_comp, HloModule* dest_module) { + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, src_comp.GetProgramShape()); + HloModuleConfig config(program_shape); + TF_ASSIGN_OR_RETURN(auto new_module, + HloModule::CreateFromProto(src_comp.proto(), config)); + HloCloneContext context(dest_module); + return dest_module->DeepCloneComputation(new_module->entry_computation(), + &context); +} + absl::StatusOr MakeSortHlo( const Shape& sort_shape, absl::Span operands, int64_t dimension_to_sort, bool is_stable, HloComputation::Builder* builder, HloModule* module, const OpMetadata* metadata) { CHECK(!operands.empty()) << "Sort Hlo requires at least one operand."; - HloComputation* compare_computation; XlaBuilder b("Sort.Compare"); if (metadata != nullptr) { b.SetOpMetadata(*metadata); @@ -612,13 +622,8 @@ absl::StatusOr MakeSortHlo( operand_types[i] = operands[i]->shape().element_type(); } XlaComputation comparator = CreateScalarLtComputation(operand_types, &b); - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, - HloModule::CreateFromProto(comparator.proto(), config)); - HloCloneContext context(module); - compare_computation = - module->DeepCloneComputation(new_module->entry_computation(), &context); + TF_ASSIGN_OR_RETURN(HloComputation * compare_computation, + XlaComputationToHloComputation(comparator, module)); return builder->AddInstruction(HloInstruction::CreateSort( sort_shape, dimension_to_sort, operands, compare_computation, is_stable)); } diff --git a/third_party/xla/xla/service/hlo_creation_utils.h b/third_party/xla/xla/service/hlo_creation_utils.h index 2db4a7045fc0e2..d9599663ea7fea 100644 --- a/third_party/xla/xla/service/hlo_creation_utils.h +++ b/third_party/xla/xla/service/hlo_creation_utils.h @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/types/span.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/literal_util.h" @@ -257,6 +258,11 @@ absl::StatusOr MakeSelectHlo( // instruction with all the operands. Crashes if `operands` is empty. HloInstruction* MaybeMakeTuple(absl::Span operands); +// Creates a HloComputation in the destination module from a builder's +// XlaComputation. +absl::StatusOr XlaComputationToHloComputation( + XlaComputation& src_comp, HloModule* dest_module); + // Creates a Sort HLO instruction and adds it to the computation containing the // operands. All operands must be in the same computation. Also creates a // default compare sub-computation which sorts the first operand into ascending diff --git a/third_party/xla/xla/service/hlo_cse_test.cc b/third_party/xla/xla/service/hlo_cse_test.cc index f6378353b8d507..00364bbfd74c50 100644 --- a/third_party/xla/xla/service/hlo_cse_test.cc +++ b/third_party/xla/xla/service/hlo_cse_test.cc @@ -25,10 +25,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/layout_util.h" #include "xla/literal.h" -#include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/service/hlo_dataflow_analysis.h b/third_party/xla/xla/service/hlo_dataflow_analysis.h index fecffcae047089..571638e53cf80f 100644 --- a/third_party/xla/xla/service/hlo_dataflow_analysis.h +++ b/third_party/xla/xla/service/hlo_dataflow_analysis.h @@ -20,388 +20,7 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_ #define XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_ -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/hash/hash.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/call_graph.h" -#include "xla/service/hlo_phi_graph.h" -#include "xla/service/hlo_value.h" -#include "xla/shape_util.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Identifies one array input of an HloInstruction. -struct HloOperandIndex { - using MyTuple = std::tuple; - - template - friend H AbslHashValue(H h, const HloOperandIndex& hlo_operand_index) { - return H::combine(std::move(h), hlo_operand_index.ToTuple()); - } - - friend bool operator==(const HloOperandIndex& lhs, - const HloOperandIndex& rhs) { - return lhs.ToTuple() == rhs.ToTuple(); - } - - bool operator!=(const HloOperandIndex& other) const { - return !(*this == other); - } - - MyTuple ToTuple() const { - return std::make_tuple(operand_number, std::cref(operand_index)); - } - - // The operand number in which the array value appears. - int64_t operand_number; - - // The shape index within the operand in which the array value appears. - ShapeIndex operand_index; -}; - -// Analysis which identifies all HLO values and their uses in an HLO module. -class HloDataflowAnalysis { - public: - // Infrastructure for passing may-alias hints: HLO passes can populate the - // may-alias table. If an empty optional is returned, default rules are used. - // - // Must-alias rules (as defined by GetInPlaceInputOutputPairs) cannot be - // overriden using backend-specific overrides. - // - // The first parameter of the function should be the instruction, the - // second parameter should be an operand of the instruction. The third - // parameter should be the output index of the instruction. - using CanShareBuffer = std::function( - const HloInstruction* instr, const HloInstruction* operand, - const ShapeIndex& user_index)>; - - // Infrastructure for overriding whether an instruction defines a new value. - // - // The first parameter is the instruction and the second parameter is the - // output index. If an empty optional is used, default rules are used. If a - // ForwardedOperand object is returned, the value at the corresponding - // operand's index is used for the output, overriding all default logic. - struct ForwardedOperand { - int64_t operand_number; - ShapeIndex operand_index; - }; - using ForwardsValue = std::function( - const HloInstruction* instr, const ShapeIndex& index)>; - - // Runs dataflow analysis on the given module. Parameters: - // - // ssa_form : If true then new values are defined at the merge points of - // kWhile instructions. Abusing nomenclature somewhat, we call these "phi - // values". The merge is formed by the init value and loop backedge. The - // SSA form is minimal in that a new phi value is defined only if the - // merge point is reachable by multiple different values. The SSA form is - // also in loop-closed form in that no values defined inside of a loop - // (while body) is used outside of the loop. Example use of this ssa_form - // mode is to reason about live range interference of buffers. - // - // If ssa_form is false, then merge points do not define new - // values. Rather, the HloValueSet for the merge point contains the union - // of the merged HloValues. - // - // bitcast_defines_value : If true then the Bitcast HLO instruction defines - // a new HLO value in the analysis. If false then Bitcast forwards the - // value of its operand. - static absl::StatusOr> Run( - const HloModule& module, bool ssa_form = false, - bool bitcast_defines_value = false, - const CanShareBuffer& can_share_buffer = nullptr, - const ForwardsValue& forwards_value = nullptr, - absl::flat_hash_set execution_threads = {}); - - // Returns true if 'instruction' defines an HLO value at the given shape index - // of its output. - bool ValueIsDefinedAt(const HloInstruction* instruction, - const ShapeIndex& index = {}) const; - - // Returns the HloValue defined by 'instruction' at the given shape index of - // its output. - // - // Precondition: ValueIsDefinedAt is true for this instruction and index. - const HloValue& GetValueDefinedAt(const HloInstruction* instruction, - const ShapeIndex& index = {}) const; - HloValue& GetValueDefinedAt(const HloInstruction* instruction, - const ShapeIndex& index = {}); - - // Returns the InstructionValueSet for the given instruction. - const InstructionValueSet& GetInstructionValueSet( - const HloInstruction* instruction) const; - InstructionValueSet& GetInstructionValueSet( - const HloInstruction* instruction); - - // Returns all values that are contained in the output of this instruction in - // a flattened set. - HloValueSet GetFlattenedValueSet(const HloInstruction* instruction) const; - - // Returns the HloValueSet for the given instruction at the given index or the - // given position. - const HloValueSet& GetValueSet(const HloInstruction* instruction, - const ShapeIndex& index = {}) const; - const HloValueSet& GetValueSet(const HloPosition& position) const; - HloValueSet& GetValueSet(const HloPosition& position); - HloValueSet& GetValueSet(const HloInstruction* instruction, - const ShapeIndex& index = {}); - - // Returns the unique value in the HloValueSet at the given instruction and - // shape index. CHECKs if the value set does not contain a exactly one value. - const HloValue& GetUniqueValueAt(const HloInstruction* instruction, - const ShapeIndex& index = {}) const { - return GetValueSet(instruction, index).GetUniqueValue(); - } - HloValue& GetUniqueValueAt(const HloInstruction* instruction, - const ShapeIndex& index = {}) { - return GetValue(GetValueSet(instruction, index).GetUniqueValue().id()); - } - - // Returns the HloValue with the given Id. - const HloValue& GetValue(HloValue::Id value_id) const; - HloValue& GetValue(HloValue::Id value_id); - - // Returns the total number of HloValues. - int64_t value_count() const { return values_.size(); } - - // Returns a vector of all HloValues stabily sorted by HloValue::Id. - const std::vector& values() const { return values_vector_; } - - // Returns the call graph used for computing the dataflow. - const CallGraph& call_graph() const { return *call_graph_; } - - std::string ToString() const; - - // Returns true if 'user' cannot possibly use the buffer at 'index' in - // 'operand'. Returns false otherwise. - // - // 'operand' does not have to be an operand of 'user'. This can be the - // case with indirect uses. - bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user) const; - - // Returns true if 'user' (at 'user_index') can share a buffer with its - // operand 'operand' (at 'operand_index'). Returns false otherwise. - // - // REQUIRES: 'operand' is an operand of 'user'. - bool CanShareOperandBufferWithUser(HloInstruction* operand, - const ShapeIndex& operand_index, - HloInstruction* user, - const ShapeIndex& user_index) const; - - const HloModule& module() const { return module_; } - - // Returns true if the operation is an in-place operation and its operand 0 - // must alias with the output. - static bool IsInPlaceOperation(HloOpcode opcode); - - // Returns true if the operation is the start/done of an asynchronous - // operation, where the buffer used/produced by the op needs to stay alive - // until the asynchronous operation completes. - static bool IsAsynchronousOperationStart(HloOpcode opcode); - static bool IsAsynchronousOperationDone(HloOpcode opcode); - - // Returns the pairs of inputs and outputs that must share the same buffer, - // according to the aliasing rules for that instruction. - // - // This function only considers array values as inputs and outputs, so - // when tuples are present it "sees through" to the array values inside. The - // HloUse describing the input parameter contains not only the operand number - // but also a shape index describing its position inside a nested tuple shape - // (if any). Similarly, the output parameter is described by a shape index - // into the nested tuple shape (if any) of the output value. - // - // For example, for this hypothetical op: - // %foo = (f32[1], (f32[2], f32[3])) - // op((f32[4], f32[5]) %arg0, f32[6] %arg1) - // - // ... the results can include any of the 3 * 3 = 9 possible pairs of - // input and output arrays. - static std::vector> - GetInPlaceInputOutputPairs(const HloInstruction* instruction); - - // Verifies various invariants of the dataflow analysis. - absl::Status Verify() const; - - private: - static bool AreTransitiveUsesElementwiseOrTuple(const HloInstruction* inst); - - HloDataflowAnalysis(const HloModule& module, bool ssa_form, - bool bitcast_defines_value, - const CanShareBuffer& can_share_buffer, - const ForwardsValue& forwards_value, - absl::flat_hash_set execution_threads); - - // 1. During value propagation (Propagate function), always create phi - // values once it see multiple inputs merging at the same point. It then - // records those phi values as well as their inputs in a phi graph. - // - // 2. Post value propagation, Dataflow analysis can then do certain - // optimization(OptimizePhiValues) on the phi graph to prune uncessary phi - // nodes. - // - // Note that this applies in SSA form, and Both of the functions are - // guaranteed to exit. - // - void OptimizePhiValues(); - - // Returns a new HloValue defined at the given instruction and shape index. - HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index, - bool is_phi); - - // Marks the HloValue with the given ID for deletion. - void MarkValueForDeletion(HloValue::Id value_id); - - // Deletes all HloValues marked for deletion. Should be called after - // propagation is complete. - void DeleteMarkedValues(); - - // Constructs and initializes the InstructionValueSets of all instructions to - // contain exactly the HloValues defined by each instruction. These values can - // then propagated throughout the HLO graph by calling Propagate. - absl::Status InitializeInstructionValueSets(); - - // Updates the value set of the given instruction based on the values flowing - // into the instruction (operands and cross-computation dataflow). - bool UpdateInstructionValueSet(HloInstruction* instruction); - - // Updates the value set for a particular instruction type. Returns whether - // the instruction value set changed. - bool UpdateBitcastValueSet(HloInstruction* bitcast); - bool UpdateCallValueSet(HloInstruction* call); - bool UpdateConditionalValueSet(HloInstruction* conditional); - bool UpdateCopyValueSet(HloInstruction* copy); - bool UpdateCustomCallValueSet(HloInstruction* custom_call); - bool UpdateDomainValueSet(HloInstruction* domain); - bool UpdateGetTupleElementValueSet(HloInstruction* gte); - bool UpdateParameterValueSet(HloInstruction* parameter); - // Async op propagation rules: - // - Operand of async-start to parameter of async wrapped computation and at - // index {0, operand_number} of async-start and async-update outputs. - // - Root of async wrapped computation to index {1} of async-start and - // async-update and index {} of async-done. - // - The contexts in indices {2+} of async-start to the same indices of - // async-update. - // - // As a result of this, the operands/outputs of async-start and async-done - // instructions share the same values as the parameters/roots of the async - // wrapped computation. - bool UpdateAsyncStartValueSet(HloInstruction* async_start); - bool UpdateAsyncUpdateValueSet(HloInstruction* async_update); - bool UpdateAsyncDoneValueSet(HloInstruction* async_done); - bool UpdateCopyStartValueSet(HloInstruction* copy_start); - bool UpdateCopyDoneValueSet(HloInstruction* copy_done); - bool UpdateOptimizationBarrierValueSet(HloInstruction* barrier); - bool UpdateRecvDoneValueSet(HloInstruction* recv_done); - bool UpdateSendValueSet(HloInstruction* send); - bool UpdateTupleValueSet(HloInstruction* tuple); - bool UpdateWhileValueSet(HloInstruction* xla_while); - bool UpdateAddDependencyValueSet(HloInstruction* add_dependency); - bool UpdateAllGatherStartValueSet(HloInstruction* all_gather_start); - bool UpdateAllGatherDoneValueSet(HloInstruction* all_gather_done); - bool UpdateAllReduceDoneValueSet(HloInstruction* all_reduce_done); - bool UpdateCollectivePermuteStartValueSet( - HloInstruction* collective_permute_start); - bool UpdateCollectivePermuteDoneValueSet( - HloInstruction* collective_permute_done); - - // Propagates the dataflow through the module. In particular, it propagates - // the HloValueSet from its defining instruction to the users of the - // instructions. - void Propagate(); - - // Returns the result of the SSA Phi function applied to the given inputs at - // the given instruction. - bool Phi(HloInstruction* instruction, - absl::Span inputs); - - // Updates the positions of the HloValues in the output of the given - // instruction. This should be called after the instruction value set of - // 'instruction' has been changed. 'prev_value_set' must point to the previous - // state of the value set prior to the change. 'prev_value_set' may be null if - // this is the first time positions are being computed. The previous state is - // necessary to efficiently remove positions which have been eliminated due to - // changes in the instructions' InstructionValueSet. - void UpdatePositionsOfValuesAt( - HloInstruction* instruction, const InstructionValueSet& new_value_set, - const InstructionValueSet* prev_value_set = nullptr); - - const HloModule& module_; - const absl::flat_hash_set execution_threads_; - const bool ssa_form_; - const bool bitcast_defines_value_; - - std::unique_ptr call_graph_; - - // The map of all HloValues in the module. We pass around pointers to the - // mapped HloValues, so the underlying container must keep them valid despite - // mutations touching other map entries. - absl::flat_hash_map> values_; - - // A map from instruction to InstructionValueSet. - absl::flat_hash_map> - value_sets_; - - // Values marked for deletion during construction. We don't delete them - // immediately because references to them may remain in ValueSets temporarily - // during propagation. After construction, these values are deleted. - std::vector value_ids_to_delete_; - - // A vector containing all HloValues sorted by HloValue::Id. - std::vector values_vector_; - - // The Id to use for the next HloValue. - HloValue::Id next_value_id_ = 0; - - // An explicit graph holding phi values and edges. - PhiGraph phi_graph_; - - // Backend specific function that decides whether an instruction can share - // a buffer with its operand. - CanShareBuffer can_share_buffer_ = nullptr; - - ForwardsValue forwards_value_ = nullptr; -}; - -// Removes layers of tuple indirection introduced via 'tuple' and -// 'get-tuple-element' instructions to more directly identify the source of the -// given HLO value (identified by the given `ShapeIndex` into the output of the -// given `HloInstruction`). -// -// e.g. for the following: -// %x = some-op(...) -// %foo = get-tuple-element(%x), index=0 -// %bar = tuple(%y, %foo) -// -// ... FollowTupleIndirection(%bar, {1}) == {%x, {0}} (output 1 of 'bar' comes -// from output 0 of %x). -// -// Note that all 'tuple' instructions are followed before all -// 'get-tuple-element' instructions are followed. This is because it is assumed -// that tupling a value and then extracting it from the tuple again will not -// occur in properly-optimized IR. -std::pair FollowTupleIndirection( - const HloInstruction* instruction, ShapeIndex operand_index); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #endif // XLA_SERVICE_HLO_DATAFLOW_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/hlo_dce.h b/third_party/xla/xla/service/hlo_dce.h index acd344bb025bc2..d0ce0665d0d0df 100644 --- a/third_party/xla/xla/service/hlo_dce.h +++ b/third_party/xla/xla/service/hlo_dce.h @@ -16,63 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_DCE_H_ #define XLA_SERVICE_HLO_DCE_H_ -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// HLO pass which removes dead instructions from each computation in the module -// and removes dead computations from the module. -// -// An instruction is dead if it is not reachable from the root. A computation is -// dead if it is not the entry computation of the module and it is not reachable -// from the entry computation. -// -// This pass does not remove dead parameter instructions, as parameter -// instructions cannot be deleted. -class HloDCE : public HloModulePass { - public: - HloDCE() : remove_cross_partition_collective_ops_(false) {} - explicit HloDCE(bool remove_cross_partition_collective_ops) - : remove_cross_partition_collective_ops_( - remove_cross_partition_collective_ops) {} - ~HloDCE() override {} - absl::string_view name() const override { return "dce"; } - - // Run DCE on a computation. - static absl::StatusOr RunOnComputation( - HloComputation* computation, bool remove_cross_partition_collective_ops); - - // Run the pass on the given module. Returns whether the module was changed - // (instructions were removed). - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - // Finds all computations that are not called by any instruction and removes - // them from the module. Returns whether any dead code was removed. - absl::StatusOr RecursivelyRemoveDeadComputations(HloModule* module); - - // Given a dead computation, decrements the ref count of all its called - // computations and checks if any of the subcomputations become dead after the - // removal. Returns whether all dead computations were successfully removed - // from the module. - absl::Status RecursivelyRemoveDeadComputation( - HloModule* module, HloComputation* computation, - absl::flat_hash_map& live_call_counts); - - bool remove_cross_partition_collective_ops_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" #endif // XLA_SERVICE_HLO_DCE_H_ diff --git a/third_party/xla/xla/service/hlo_domain_map.cc b/third_party/xla/xla/service/hlo_domain_map.cc index 1b3083d1b53d75..6543b850a39d93 100644 --- a/third_party/xla/xla/service/hlo_domain_map.cc +++ b/third_party/xla/xla/service/hlo_domain_map.cc @@ -15,16 +15,28 @@ limitations under the License. #include "xla/service/hlo_domain_map.h" -#include +#include #include #include #include +#include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_domain_metadata.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/map_util.h" -#include "xla/types.h" +#include "xla/status_macros.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/hlo_domain_map.h b/third_party/xla/xla/service/hlo_domain_map.h index d2285d06e4f855..f54627a57f363a 100644 --- a/third_party/xla/xla/service/hlo_domain_map.h +++ b/third_party/xla/xla/service/hlo_domain_map.h @@ -16,24 +16,27 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_DOMAIN_MAP_H_ #define XLA_SERVICE_HLO_DOMAIN_MAP_H_ +#include #include +#include +#include #include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_domain_metadata.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "tsl/platform/status.h" namespace xla { // The HloDomainMap splits a set of instructions within a module or computation, // into different domains, separated by kDomain instructions. // A domain is composed by a set of instructions which can reach each other via -// operand/user edges, without crossing a kDomain insutrction of a given kind. +// operand/user edges, without crossing a kDomain instruction of a given kind. // A domain never crosses computation boundaries. class HloDomainMap { public: diff --git a/third_party/xla/xla/service/hlo_domain_test.cc b/third_party/xla/xla/service/hlo_domain_test.cc index 13f80fdf6b441b..e0603cff03bd27 100644 --- a/third_party/xla/xla/service/hlo_domain_test.cc +++ b/third_party/xla/xla/service/hlo_domain_test.cc @@ -20,11 +20,11 @@ limitations under the License. #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_domain_metadata.h" #include "xla/hlo/ir/hlo_sharding_metadata.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/service/call_inliner.h" #include "xla/service/hlo_domain_isolator.h" #include "xla/service/hlo_domain_remover.h" #include "xla/service/hlo_domain_verifier.h" -#include "xla/service/hlo_parser.h" #include "xla/service/sharding_propagation.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" diff --git a/third_party/xla/xla/service/hlo_element_type_converter.h b/third_party/xla/xla/service/hlo_element_type_converter.h index e78edec9d27a22..3fed0142430401 100644 --- a/third_party/xla/xla/service/hlo_element_type_converter.h +++ b/third_party/xla/xla/service/hlo_element_type_converter.h @@ -16,35 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_ #define XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_ -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// A pass that eliminates certain element types as the input or output of ops by -// inserting Convert ops. This allows a backend to support an element type while -// only actually implementing the Convert op for that element type. This is -// generally not the fastest approach, but it works. -class HloElementTypeConverter : public HloModulePass { - public: - // eliminate_type is the type to eliminate as the input or output of ops, - // using Convert ops to replace it with replace_with_type. - HloElementTypeConverter(PrimitiveType eliminate_type, - PrimitiveType replace_with_type); - - absl::string_view name() const override { return "element_type_converter"; } - - // Returns the pass on the module and returns whether the module was modified. - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - PrimitiveType eliminate_type_; - PrimitiveType replace_with_type_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/hlo_element_type_converter.h" #endif // XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_ diff --git a/third_party/xla/xla/service/hlo_graph_dumper.cc b/third_party/xla/xla/service/hlo_graph_dumper.cc index de91298040015d..dcd1c0b98bc001 100644 --- a/third_party/xla/xla/service/hlo_graph_dumper.cc +++ b/third_party/xla/xla/service/hlo_graph_dumper.cc @@ -73,12 +73,12 @@ limitations under the License. #include "xla/service/pattern_matcher.h" #include "xla/shape_util.h" #include "xla/stream_executor/dnn.h" +#include "xla/tsl/lib/gtl/map_util.h" +#include "xla/tsl/lib/io/zlib_compression_options.h" +#include "xla/tsl/lib/io/zlib_outputbuffer.h" #include "xla/types.h" #include "xla/util.h" #include "xla/window_util.h" -#include "tsl/lib/gtl/map_util.h" -#include "tsl/lib/io/zlib_compression_options.h" -#include "tsl/lib/io/zlib_outputbuffer.h" #include "tsl/platform/base64.h" #include "tsl/platform/env.h" #include "tsl/platform/numbers.h" @@ -1266,6 +1266,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kPartitionId: + case HloOpcode::kRaggedAllToAll: case HloOpcode::kRecv: case HloOpcode::kRecvDone: case HloOpcode::kSend: diff --git a/third_party/xla/xla/service/hlo_input_output_alias_config_test.cc b/third_party/xla/xla/service/hlo_input_output_alias_config_test.cc index 8b7a99b385db67..7c2cc1d945c0d5 100644 --- a/third_party/xla/xla/service/hlo_input_output_alias_config_test.cc +++ b/third_party/xla/xla/service/hlo_input_output_alias_config_test.cc @@ -20,12 +20,10 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "xla/hlo/analysis/hlo_ordering.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/hlo_dce.h" -#include "xla/service/hlo_memory_scheduler.h" -#include "xla/service/hlo_ordering.h" #include "xla/shape_util.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" diff --git a/third_party/xla/xla/service/hlo_instruction_test.cc b/third_party/xla/xla/service/hlo_instruction_test.cc index 7709bda6032e7f..19ab4d15e8131c 100644 --- a/third_party/xla/xla/service/hlo_instruction_test.cc +++ b/third_party/xla/xla/service/hlo_instruction_test.cc @@ -2898,6 +2898,30 @@ TEST_F(HloInstructionTest, BackendConfigNotCopiedToDerivedWithDiffOpcode) { EXPECT_FALSE(add2->has_backend_config()); } +TEST_F(HloInstructionTest, BackendConfigNotCopiedToDerivedWithConfig) { + HloComputation::Builder b(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); + auto p0 = b.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); + auto p1 = b.AddInstruction(HloInstruction::CreateParameter(0, shape, "p1")); + auto add = b.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p1)); + + gpu::GpuBackendConfig gpu_config0; + gpu::GpuBackendConfig gpu_config1; + gpu_config0.set_operation_queue_id(2); + gpu_config1.set_operation_queue_id(3); + + TF_ASSERT_OK(add->set_backend_config(gpu_config0)); + auto add2 = b.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p0)); + TF_ASSERT_OK(add2->set_backend_config(gpu_config1)); + + add->SetupDerivedInstruction(add2); + auto backend_config = add2->backend_config(); + EXPECT_TRUE(backend_config.ok()); + EXPECT_EQ(backend_config->operation_queue_id(), 3); +} + TEST_F(HloInstructionTest, MergeMultiOutputProducerFusionIntoMultiOutputFusion) { const std::string& hlo_string = R"( @@ -3058,5 +3082,104 @@ TEST_F(HloInstructionTest, m::Add(m::Parameter(0), m::Parameter(1))))); } +TEST_F(HloInstructionTest, UnfuseInstruction) { + const std::string& hlo_string = R"( + HloModule mof + fusion_comp { + param0 = f32[10]{0} parameter(0) + param1 = f32[10]{0} parameter(1) + add = f32[10]{0} add(param0, param1) + ROOT res = (f32[10]{0}, f32[10]{0}) tuple(param1, add) + } + + ENTRY main { + p0 = f32[10]{0} parameter(0) + p1 = f32[10]{0} parameter(1) + fusion.1 = (f32[10]{0}, f32[10]{0}) fusion(p0, p1), kind=kLoop, calls=fusion_comp + gte0 = f32[10]{0} get-tuple-element(fusion.1), index=0 + gte1 = f32[10]{0} get-tuple-element(fusion.1), index=1 + ROOT res = (f32[10]{0}, f32[10]{0}) tuple(gte0, gte1) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + HloInstruction* fusion = FindInstruction(module.get(), "fusion.1"); + HloInstruction* add = fusion->fused_instructions_computation() + ->root_instruction() + ->mutable_operand(1); + TF_ASSERT_OK_AND_ASSIGN(auto unfused, fusion->UnfuseInstruction(add)); + EXPECT_THAT(unfused, GmockMatch(m::Add(m::Parameter(0), m::Parameter(1)))); +} + +TEST_F(HloInstructionTest, UnfuseInstruction2) { + const std::string& hlo_string = R"( + HloModule mof + fusion_comp { + param0 = f32[10]{0} parameter(0) + param1 = f32[10]{0} parameter(1) + add = f32[10]{0} add(param0, param1) + add2 = f32[10]{0} add(add, param1) + ROOT res = (f32[10]{0}, f32[10]{0}) tuple(param1, add2) + } + + ENTRY main { + p0 = f32[10]{0} parameter(0) + p1 = f32[10]{0} parameter(1) + fusion.1 = (f32[10]{0}, f32[10]{0}) fusion(p0, p1), kind=kLoop, calls=fusion_comp + gte0 = f32[10]{0} get-tuple-element(fusion.1), index=0 + gte1 = f32[10]{0} get-tuple-element(fusion.1), index=1 + ROOT res = (f32[10]{0}, f32[10]{0}) tuple(gte0, gte1) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + HloInstruction* fusion = FindInstruction(module.get(), "fusion.1"); + HloInstruction* add2 = fusion->fused_instructions_computation() + ->root_instruction() + ->mutable_operand(1); + HloInstruction* add = add2->mutable_operand(0); + + // add2 is not unfusable since it has non-const non-parameter operands. + EXPECT_FALSE(fusion->UnfuseInstruction(add2).ok()); + + TF_ASSERT_OK_AND_ASSIGN(auto unfused, fusion->UnfuseInstruction(add)); + EXPECT_THAT(unfused, GmockMatch(m::Add(m::Parameter(0), m::Parameter(1)))); +} + +TEST_F(HloInstructionTest, UnfuseInstructionWithConstantOperand) { + const std::string& hlo_string = R"( + HloModule mof + fusion_comp { + param0 = f32[10]{0} parameter(0) + param1 = f32[10]{0} parameter(1) + const = f32[] constant(1.0) + broadcast = f32[10]{0} broadcast(const), dimensions={} + add = f32[10]{0} add(param0, broadcast) + ROOT res = (f32[10]{0}, f32[10]{0}) tuple(param1, add) + } + + ENTRY main { + p0 = f32[10]{0} parameter(0) + p1 = f32[10]{0} parameter(1) + fusion.1 = (f32[10]{0}, f32[10]{0}) fusion(p0, p1), kind=kLoop, calls=fusion_comp + gte0 = f32[10]{0} get-tuple-element(fusion.1), index=0 + gte1 = f32[10]{0} get-tuple-element(fusion.1), index=1 + ROOT res = (f32[10]{0}, f32[10]{0}) tuple(gte0, gte1) + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + HloInstruction* fusion = FindInstruction(module.get(), "fusion.1"); + HloInstruction* add = fusion->fused_instructions_computation() + ->root_instruction() + ->mutable_operand(1); + TF_ASSERT_OK_AND_ASSIGN(auto unfused, fusion->UnfuseInstruction(add)); + EXPECT_THAT(unfused, + GmockMatch(m::Add(m::Parameter(0), m::Broadcast(m::Constant())))); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/hlo_lexer.h b/third_party/xla/xla/service/hlo_lexer.h index 8a7547ff679834..aad399ed291f3a 100644 --- a/third_party/xla/xla/service/hlo_lexer.h +++ b/third_party/xla/xla/service/hlo_lexer.h @@ -16,204 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_LEXER_H_ #define XLA_SERVICE_HLO_LEXER_H_ -#include -#include -#include - -#include "absl/strings/string_view.h" -#include "xla/shape.h" -#include "xla/types.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/regexp.h" - -namespace xla { - -// Defines different kinds of tokens used by the HLO lexer. -// -// You shouldn't need to use this directly unless you're using HloLexer -// directly, and you probably don't need to do that. Use hlo_parser instead. -enum class TokKind { - // Markers - kEof, - kError, - - // Tokens with no info. - kEqual, // = - kComma, // , - kColon, // : - kAsterisk, // * - kQuestionMark, // ? - kOctothorp, // # - kPlus, // + - kTilde, // ~ - kLsquare, - kRsquare, // [ ] - kLbrace, - kRbrace, // { } - kLparen, - kRparen, // ( ) - kDots, // ... - - kArrow, // -> - kLeq, // <= - - // Keywords - kw_HloModule, - kw_ENTRY, - kw_ROOT, - kw_true, - kw_false, - kw_maximal, - kw_replicated, - kw_manual, - kw_last_tile_dim_replicate, - kw_shard_as, - kw_shard_like, - kw_unknown, - kw_inf, - - kNegInf, // -inf - - // Typed tokens. - kPrimitiveType, // F32, PRED, etc. - kName, // %foo - kAttributeName, // dimensions= - kDimLabels, // [0-9bf?]{2,}_[0-9io?]{2,}->[0-9bf?]{2,} - kDxD, // [0-9]+(x[0-9]+)+ - kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* - kSparsityDesc, // ([LR]\.[0-9]+@[0-9]+:[0-9]+_?)+ - kIdent, // other identifiers - kString, // "abcd\"\n" - kInt, // 42 - kDecimal, // 4.2 -}; - -std::string TokKindToString(TokKind kind); - -// Lexer for the HloModule::ToString() format text. -// -// This class is meant to be used by hlo_parser.cc. You shouldn't need to use -// it directly. -class HloLexer { - public: - explicit HloLexer(absl::string_view buf) : buf_(buf) { - current_ptr_ = buf_.data(); - } - - TokKind Lex() { return token_state_.current_kind = LexToken(); } - - TokKind GetKind() const { return token_state_.current_kind; } - std::string GetStrVal() const { - switch (GetKind()) { - case TokKind::kName: - case TokKind::kAttributeName: - case TokKind::kDimLabels: - case TokKind::kDxD: - case TokKind::kPad: - case TokKind::kSparsityDesc: - case TokKind::kString: - case TokKind::kIdent: - return token_state_.str_val; - default: - LOG(FATAL) << "This token does not have string value"; - } - } - int64_t GetInt64Val() const { - CHECK(GetKind() == TokKind::kInt) << TokKindToString(GetKind()); - return token_state_.int64_val; - } - double GetDecimalVal() const { - CHECK(GetKind() == TokKind::kDecimal); - return token_state_.decimal_val; - } - PrimitiveType GetPrimitiveTypeVal() const { - CHECK(GetKind() == TokKind::kPrimitiveType); - return token_state_.primitive_type_val; - } - - typedef const char* LocTy; - - // Returns the location of the current token. - LocTy GetLoc() const { return token_state_.token_start; } - - // Returns the line and column of a location in the buffer. - std::pair GetLineAndColumn(LocTy location) const; - - // Returns the whole line given the location. - absl::string_view GetLine(LocTy loc) const; - - // Looks ahead one token and returns it. Lexer state is unchanged. - TokKind LookAhead(); - - // Lexes a string delimited by matching curly braces. Curlies contained - // inside double quotes don't count. - // - // Requires that you've already lexed the open curly brace. - // - // The returned string value includes the outer curlies. - // - // Returns TokKind::kString on success. - TokKind LexJsonDict(); - - private: - // Returns the current character. If it's neither the end of input buffer nor - // an invalid character, moves the pointer forward. - int GetNextChar(); - - // Returns the current character. - int PeekCurrentChar() const; - - // Creates string_view with the given begin and end. Exits if the begin > end, - // or it's out of the range of the current buffer. - absl::string_view StringViewFromPointers(const char* begin, - const char* end) const; - - // Returns true if the given ptr is dereferenceable within the range of the - // current buffer. - bool CanDereference(const char* ptr) const; - - TokKind LexToken(); - - TokKind LexIdentifier(); - TokKind LexPercent(); - TokKind LexShape(); - TokKind LexConstant(); - TokKind LexNumberOrPattern(); - TokKind LexString(); - - std::optional LexNanPayload(absl::string_view& consumable); - - absl::string_view buf_; - const char* current_ptr_; - - // Information about the current token. - struct TokenState { - const char* token_start = nullptr; - TokKind current_kind; - std::string str_val; - int64_t int64_val; - double decimal_val; - PrimitiveType primitive_type_val; - }; - TokenState token_state_; - - struct LineNoCacheTy { - const char* last_query; - unsigned line_no_of_query; - }; - // This caches the line number of the previous query. - mutable LineNoCacheTy line_no_cache_{nullptr, 0}; -}; - -// Does this string start with "{", end with "}", and contain valid-ish JSON -// in-between? If so, hlo_parser can parse e.g. backend_config={blah: "blah"} -// instead of the much uglier backend_config="{blah: \"blah\"}". -// -// (Technically we're not checking for fully-valid JSON, just something we can -// find the end of reasonably.) -bool LexesAsJsonDict(absl::string_view str); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/parser/hlo_lexer.h" #endif // XLA_SERVICE_HLO_LEXER_H_ diff --git a/third_party/xla/xla/service/hlo_liveness_analysis.h b/third_party/xla/xla/service/hlo_liveness_analysis.h index 81a358c4d2a738..fd590408d53934 100644 --- a/third_party/xla/xla/service/hlo_liveness_analysis.h +++ b/third_party/xla/xla/service/hlo_liveness_analysis.h @@ -16,52 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_ #define XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_ -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/call_graph.h" -#include "xla/service/hlo_value.h" -#include "xla/shape_tree.h" -#include "xla/shape_util.h" - -namespace xla { - -// Analysis which identifies all live {HloInstruction, ShapeIndex} pairs in -// an HLO module. -// -// HloLivenessAnalysis marks the shape index of each live output of each -// instruction in the module, by propagating live shape index information -// from an instruction to its called computations and operands. -class HloLivenessAnalysis { - public: - // Maps from an HloInstruction to its live/dead output shape indices. - using HloIndexMap = absl::flat_hash_map>>; - - // Runs liveness analysis on 'module'. Returns HloLivenessAnalysis object - // which exports liveness for each {HloInstruction, ShapeIndex} in 'module'. - static absl::StatusOr> Run( - const HloModule& module); - - // Returns true if output of 'instruction' at 'shape_index' is live. - // Returns false otherwise. - bool IsLive(const HloInstruction* instruction, - const ShapeIndex& shape_index) const; - - private: - HloLivenessAnalysis(const HloModule& module); - - void RunAnalysis(); - - const HloModule& module_; - std::unique_ptr call_graph_; - HloIndexMap live_index_map_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/analysis/hlo_liveness_analysis.h" #endif // XLA_SERVICE_HLO_LIVENESS_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/hlo_memory_scheduler.h b/third_party/xla/xla/service/hlo_memory_scheduler.h index fd9ae679110afb..09d8b432f998db 100644 --- a/third_party/xla/xla/service/hlo_memory_scheduler.h +++ b/third_party/xla/xla/service/hlo_memory_scheduler.h @@ -16,176 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ #define XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_schedule.h" -#include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/service/hlo_alias_analysis.h" -#include "xla/service/logical_buffer.h" -#include "xla/service/tuple_points_to_analysis.h" - -namespace xla { - -// Postprocessor of the HloInstructionSequence. This is an opt-in postprocessing -// function to MemorySchedulerAlgorithm to enforce certain hlo schedule -// constraints desired for custom-calls. -using MemorySchedulerPostprocessor = - std::function; - -// A memory scheduler computes an execution sequence for the HLO instructions in -// 'computation' that minimizes peak memory (or finds a balance between memory -// and available concurrency), given a points-to analysis result that describes -// buffer aliasing, together with a target-specific size function that maps a -// tensor's logical size to its padded size. peak_memory (may be nullptr) is set -// to the peak memory of the resulting schedule according to the HeapSimulator. -// -// TODO(yunxing): Cleanup usage of TuplePointsToAnalysis. -using MemorySchedulerAlgorithm = - std::function( - HloComputation*, const TuplePointsToAnalysis&, const HloAliasAnalysis&, - const LogicalBuffer::SizeFunction&, - const MemorySchedulerPostprocessor&, - /*peak_memory*/ int64_t*)>; - -// Scheduler for the entire module. -using ModuleSchedulerAlgorithm = std::function( - const HloModule*, const TuplePointsToAnalysis&, const HloAliasAnalysis&, - const LogicalBuffer::SizeFunction&, - const absl::flat_hash_set& execution_threads, - /*peak_memory*/ int64_t*)>; - -// Lift a computation scheduler into a module scheduler by calling the -// computation scheduler on all computations in a module. -ModuleSchedulerAlgorithm ComputationSchedulerToModuleScheduler( - const MemorySchedulerAlgorithm&, const MemorySchedulerPostprocessor& = {}); - -// List scheduler -absl::StatusOr ListMemoryScheduler( - HloComputation* computation, - const TuplePointsToAnalysis& points_to_analysis, - const HloAliasAnalysis& alias_analysis, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); - -// DFS-order scheduler -absl::StatusOr DFSMemoryScheduler( - HloComputation* computation, - const TuplePointsToAnalysis& points_to_analysis, - const HloAliasAnalysis& alias_analysis, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); - -// BFS-order scheduler -// -// BFS-order scheduler is a simple memory scheduler that schedules instructions -// in a breadth-first order, which maximizes the available concurrency at the -// cost of increased memory usage (HLO operations that do not have buffer -// conflicts can be executed in parallel). -// -// This is the most trivial scheduling optimized for maximum concurrency. In -// practice it is only useful for CPU backend where memory is cheap and we have -// a lot of available compute cores, and cheap concurrency primitives. -absl::StatusOr BFSMemoryScheduler( - HloComputation* computation, - const TuplePointsToAnalysis& points_to_analysis, - const HloAliasAnalysis& alias_analysis, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); - -// Naive Post Order scheduler -absl::StatusOr PostOrderMemoryScheduler( - HloComputation* computation, - const TuplePointsToAnalysis& points_to_analysis, - const HloAliasAnalysis& alias_analysis, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); - -// The default scheduling algorithm. Runs the list scheduler, the DFS scheduler, -// and the post-order scheduler and chooses whichever returns a lower min- -// memory, not accounting for fragmentation. peak_memory (may be nullptr) is set -// to the peak memory of the resulting schedule according to the HeapSimulator. -absl::StatusOr DefaultMemoryScheduler( - HloComputation* computation, - const TuplePointsToAnalysis& points_to_analysis, - const HloAliasAnalysis& alias_analysis, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); - -absl::StatusOr DefaultModuleScheduler( - const HloModule* module, const TuplePointsToAnalysis& points_to_analysis, - const HloAliasAnalysis& alias_analysis, - const LogicalBuffer::SizeFunction& size_function, - const absl::flat_hash_set& execution_threads, - int64_t* peak_memory); - -// Returns an HloSchedule which seeks to minimize the memory required for the -// module. size_function is the function returning the number of bytes required -// for a LogicalBuffer. peak_memory (if not nullptr) is set to the largest peak -// memory (according to the HeapSimulator) of all computations in the module. -absl::StatusOr ScheduleModule( - const HloModule* module, const LogicalBuffer::SizeFunction& size_function, - const ModuleSchedulerAlgorithm& algorithm = {}, - const absl::flat_hash_set& execution_threads = {}, - int64_t* peak_memory = nullptr); - -// A pass which schedules the HLO instructions in a module. The HloModule's -// schedule field is set to the resulting HloSchedule using -// HloModule::set_schedule. -class HloMemoryScheduler : public HloModulePass { - public: - // size_function is the function returning the number of bytes required for a - // LogicalBuffer. algorithm is the memory scheduling algorithm to use. If not - // specified, then DefaultMemoryScheduler is used. - explicit HloMemoryScheduler(const LogicalBuffer::SizeFunction& size_function, - const ModuleSchedulerAlgorithm& algorithm = {}); - - ~HloMemoryScheduler() override = default; - - absl::string_view name() const override { return "hlo-memory-scheduler"; } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - LogicalBuffer::SizeFunction size_function_; - - ModuleSchedulerAlgorithm algorithm_; -}; - -// A pass which produces a naive, but correct schedule. The schedule is produced -// using a DFS traversal of the graph with no attempt to minimize memory use. -class HloTrivialScheduler : public HloModulePass { - public: - absl::string_view name() const override { return "hlo-trivial-scheduler"; } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -// A trivial pass which clears the schedule currently set on the -// HloModule. After this pass runs HloModule::has_schedule will return false. -class HloDescheduler : public HloModulePass { - public: - HloDescheduler() = default; - ~HloDescheduler() override = default; - absl::string_view name() const override { return "hlo-descheduler"; } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" #endif // XLA_SERVICE_HLO_MEMORY_SCHEDULER_H_ diff --git a/third_party/xla/xla/service/hlo_module_config.cc b/third_party/xla/xla/service/hlo_module_config.cc index a5400c866c63d0..4b4a1cae218586 100644 --- a/third_party/xla/xla/service/hlo_module_config.cc +++ b/third_party/xla/xla/service/hlo_module_config.cc @@ -73,6 +73,9 @@ std::string HloModuleConfig::compilation_cache_key() const { static std::atomic counter{0}; StrAppend(&key, "forcing recompile ", counter++); } + StrAppend(&key, "::exec_time_optimization_effort=", + exec_time_optimization_effort()); + StrAppend(&key, "::memory_fitting_effort=", memory_fitting_effort()); if (replica_count() != 1) { StrAppend(&key, "::replica_count=", replica_count()); } @@ -280,6 +283,8 @@ HloModuleConfigProto HloModuleConfig::ToProto() const { for (int64_t partitioning_id : auto_spmd_partitioning_mesh_ids_) { proto.add_auto_spmd_partitioning_mesh_ids(partitioning_id); } + proto.set_exec_time_optimization_effort(exec_time_optimization_effort_); + proto.set_memory_fitting_effort(memory_fitting_effort_); proto.set_deduplicate_hlo(deduplicate_hlo_); proto.set_intra_op_parallelism_threads(intra_op_parallelism_threads_); proto.set_device_type(device_type_); @@ -351,6 +356,9 @@ HloModuleConfig::CreateFromProto(const HloModuleConfigProto& proto) { config->auto_spmd_partitioning_mesh_ids_.assign( proto.auto_spmd_partitioning_mesh_ids().begin(), proto.auto_spmd_partitioning_mesh_ids().end()); + config->exec_time_optimization_effort_ = + proto.exec_time_optimization_effort(); + config->memory_fitting_effort_ = proto.memory_fitting_effort(); config->deduplicate_hlo_ = proto.deduplicate_hlo(); config->intra_op_parallelism_threads_ = proto.intra_op_parallelism_threads(); config->device_type_ = proto.device_type(); diff --git a/third_party/xla/xla/service/hlo_module_config.h b/third_party/xla/xla/service/hlo_module_config.h index 69c3ba2861cce7..8fce19e8baa547 100644 --- a/third_party/xla/xla/service/hlo_module_config.h +++ b/third_party/xla/xla/service/hlo_module_config.h @@ -209,6 +209,18 @@ class HloModuleConfig { return auto_spmd_partitioning_mesh_ids_; } + void set_exec_time_optimization_effort(float exec_time_optimization_effort) { + exec_time_optimization_effort_ = exec_time_optimization_effort; + } + float exec_time_optimization_effort() const { + return exec_time_optimization_effort_; + } + + void set_memory_fitting_effort(float memory_fitting_effort) { + memory_fitting_effort_ = memory_fitting_effort; + } + float memory_fitting_effort() const { return memory_fitting_effort_; } + // If enabled, deduplicate equivalent hlos into function calls to reduce code // size. void set_deduplicate_hlo(bool deduplicate_hlo) { @@ -431,6 +443,24 @@ class HloModuleConfig { std::vector auto_spmd_partitioning_mesh_ids_; + // The amount of effort to spend on optimizing for minimizing program + // execution time, as a value in [-1.0, +1.0]. The baseline is 0.0, which + // strongly prioritizes execution time at the cost of longer compile times, + // suitable for production workloads. A value of -0.5 would be appropriate for + // research use cases that prefer faster compilations to iterate more quickly. + // Positive values, on the other hand, might enable costly optimizations that + // are off by default. + float exec_time_optimization_effort_ = 0.0f; + + // The amount of effort to spend on making the program fit in memory (where + // "fit in memory" here has a backend-dependent meaning), as a value in [-1.0, + // +1.0]. The baseline is 0.0, which expends significant effort on attempting + // to make the program fit. A value of -1.0 would be appropriate for use cases + // that wish to spend minimal effort here and fail as quickly as possible + // instead. Positive values, on the other hand, might enable costly algorithms + // to reduce memory usage that are off by default. + float memory_fitting_effort_ = 0.0f; + // If enabled, deduplicate equivalent hlos into function calls to reduce code // size. bool deduplicate_hlo_ = false; diff --git a/third_party/xla/xla/service/hlo_module_dce.cc b/third_party/xla/xla/service/hlo_module_dce.cc index 81f8380bfb8172..fa4da849792503 100644 --- a/third_party/xla/xla/service/hlo_module_dce.cc +++ b/third_party/xla/xla/service/hlo_module_dce.cc @@ -19,13 +19,13 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/hlo/analysis/hlo_liveness_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/hlo_dce.h" -#include "xla/service/hlo_liveness_analysis.h" -#include "xla/service/tuple_simplifier.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/service/while_loop_simplifier.h" #include "xla/status_macros.h" #include "xla/types.h" diff --git a/third_party/xla/xla/service/hlo_module_group_metadata.cc b/third_party/xla/xla/service/hlo_module_group_metadata.cc index 895cc33c0f7580..22a833d4870e72 100644 --- a/third_party/xla/xla/service/hlo_module_group_metadata.cc +++ b/third_party/xla/xla/service/hlo_module_group_metadata.cc @@ -21,11 +21,11 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/analysis/tuple_points_to_analysis.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/hlo_alias_analysis.h" -#include "xla/service/tuple_points_to_analysis.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/hlo_module_group_metadata.h b/third_party/xla/xla/service/hlo_module_group_metadata.h index 959dcf8f227b2e..a53fac2206acb6 100644 --- a/third_party/xla/xla/service/hlo_module_group_metadata.h +++ b/third_party/xla/xla/service/hlo_module_group_metadata.h @@ -25,10 +25,10 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_alias_analysis.h" #include "tsl/platform/status.h" namespace xla { diff --git a/third_party/xla/xla/service/hlo_module_group_util.cc b/third_party/xla/xla/service/hlo_module_group_util.cc index 55aec4dc6ab83a..de4b61736e7474 100644 --- a/third_party/xla/xla/service/hlo_module_group_util.cc +++ b/third_party/xla/xla/service/hlo_module_group_util.cc @@ -25,10 +25,10 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" +#include "xla/hlo/analysis/hlo_reachability.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/ir/hlo_reachability.h" #include "xla/status_macros.h" #include "xla/types.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/hlo_module_group_util.h b/third_party/xla/xla/service/hlo_module_group_util.h index 7f12fb6ca660e7..9f3e28a60686c6 100644 --- a/third_party/xla/xla/service/hlo_module_group_util.h +++ b/third_party/xla/xla/service/hlo_module_group_util.h @@ -25,9 +25,9 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/hlo/analysis/hlo_reachability.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_reachability.h" #include "xla/service/hlo_module_group_metadata.h" #include "tsl/platform/status.h" diff --git a/third_party/xla/xla/service/hlo_module_test.cc b/third_party/xla/xla/service/hlo_module_test.cc index f2375751a90f55..339feeb8fd2d4e 100644 --- a/third_party/xla/xla/service/hlo_module_test.cc +++ b/third_party/xla/xla/service/hlo_module_test.cc @@ -28,10 +28,11 @@ limitations under the License. #include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_original_value.h" +#include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal.h" #include "xla/service/computation_placer.h" -#include "xla/service/hlo_memory_scheduler.h" #include "xla/service/test_compilation_environment.pb.h" #include "xla/shape_util.h" #include "xla/test.h" @@ -947,6 +948,35 @@ ENTRY main { EXPECT_EQ(stack_frame.column, location->column()); } +TEST_F(HloModuleTest, PrintOriginalValue) { + // Create a module with a single computation. + auto module = CreateNewVerifiedModule(); + auto builder = HloComputation::Builder("Constant"); + std::vector values(16, 42.0); + auto instruction = + HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f)); + auto original_value = std::make_shared(instruction->shape()); + for (auto& leaf : original_value->leaves()) { + leaf.second = {std::string(instruction->name()), leaf.first}; + } + instruction->set_original_value(original_value); + builder.AddInstruction(std::move(instruction)); + module->AddEntryComputation(builder.Build()); + + EXPECT_EQ( + "HloModule PrintOriginalValue, " + "entry_computation_layout={()->f32[]}\n\nENTRY %Constant () -> " + "f32[] {\n ROOT %constant = f32[] constant(42), " + "origin={{\"constant\"}}\n}\n\n", + module->ToString(HloPrintOptions().set_print_original_value(true))); + + EXPECT_EQ( + "HloModule PrintOriginalValue, " + "entry_computation_layout={()->f32[]}\n\nENTRY %Constant () -> " + "f32[] {\n ROOT %constant = f32[] constant(42)\n}\n\n", + module->ToString(HloPrintOptions().set_print_original_value(false))); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/hlo_module_util.cc b/third_party/xla/xla/service/hlo_module_util.cc index ca67634cd1e14d..1bc65eef147ef9 100644 --- a/third_party/xla/xla/service/hlo_module_util.cc +++ b/third_party/xla/xla/service/hlo_module_util.cc @@ -118,6 +118,10 @@ absl::StatusOr> CreateModuleConfig( config->set_auto_spmd_partitioning_mesh_ids(std::vector( execution_options->auto_spmd_partitioning_mesh_ids().begin(), execution_options->auto_spmd_partitioning_mesh_ids().end())); + config->set_exec_time_optimization_effort( + execution_options->exec_time_optimization_effort()); + config->set_memory_fitting_effort( + execution_options->memory_fitting_effort()); config->set_deduplicate_hlo(execution_options->deduplicate_hlo()); config->set_seed(execution_options->seed()); config->set_launch_id(execution_options->launch_id()); diff --git a/third_party/xla/xla/service/hlo_ordering.h b/third_party/xla/xla/service/hlo_ordering.h index 6b070798f9ebb4..d035368156aed2 100644 --- a/third_party/xla/xla/service/hlo_ordering.h +++ b/third_party/xla/xla/service/hlo_ordering.h @@ -16,228 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_ORDERING_H_ #define XLA_SERVICE_HLO_ORDERING_H_ -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_reachability.h" -#include "xla/hlo/ir/hlo_schedule.h" -#include "xla/service/call_graph.h" -#include "xla/service/hlo.pb.h" -#include "xla/service/hlo_dataflow_analysis.h" -#include "xla/service/hlo_value.h" -#include "xla/types.h" - -namespace xla { - -// Base class for describing a partial ordering of HLO instructions. Used to -// determine live range overlap of HLO instruction output buffers. -class HloOrdering { - public: - explicit HloOrdering(const HloModule* module) - : module_(module), call_graph_(CallGraph::Build(module)) {} - virtual ~HloOrdering() = default; - - // Specify the ordering constraints between a pair of instructions a and b. - enum class ExecutionConstraint { - // Indicate a and b are the same instruction; - kIsSame, - // Indicate a runs before b starts; - kRunBeforeStart, - // Indicate a runs before b ends but after b starts, e.g., when b is a - // conditional or while loop; - kRunBeforeEnd, - // Only one of a or b runs each time their common ancestor is evaluated, - // and a is in an earlier branch than b. - kRunExclusiveBefore, - // Only one of a or b runs each time, and a is in a later branch than b. - kRunExclusiveAfter, - // Indicate a runs after b ends. - kRunAfter, - // An order cannot be detrermined as a and b do not have a common ancestor. - kUnordered, - }; - // Return the execution constraint between a and b. - HloOrdering::ExecutionConstraint GetExecutionConstraint( - const HloInstruction* a, const HloInstruction* b) const; - - // Returns true if instruction 'a' executes before instruction 'b'. This is - // not reflexive, that is, an instruction does not execute before itself. - bool ExecutesBefore(const HloInstruction* a, const HloInstruction* b) const; - - // Returns whether the value 'a' is defined before the value 'b' under the - // given ordering. - bool IsDefinedBefore(const HloValue& a, const HloValue& b) const; - - // Returns whether the given use is before the given value definition under - // the given ordering. Set use_is_always_before_def_in_same_instr to false if - // you want the analysis to always consider a use at an instruction's operand - // to be strictly before that instructions definition. The configuration needs - // to be false when result will be used to remove unnecessary copy - // instructions, due to additional buffer sharing constraints. - bool UsesBeforeValueDefinition( - absl::Span uses, const HloValue& value, - const HloDataflowAnalysis& dataflow, - bool use_is_always_before_def_in_same_instr = false) const; - // Returns whether the given values interfere. Two values interfere if they - // may both be simultaneously live. - bool MayInterfere(const HloValue& a, const HloValue& b, - const HloDataflowAnalysis& dataflow) const; - - // Returns true if the live range of the given value 'a' is strictly before - // the live range of value 'b' using the given HLO ordering. - bool LiveRangeStrictlyBefore( - const HloValue& a, const HloValue& b, const HloDataflowAnalysis& dataflow, - bool use_is_always_before_def_in_same_instr = false) const; - - // Returns the sequential instruction order for the given computation, or - // nullptr if the computation does not have a sequential ordering. - virtual const HloInstructionSequence* SequentialOrder( - const HloComputation& computation) const = 0; - - // Return the call graph of the module used to compute ordering. - const CallGraph& call_graph() const { return *call_graph_; } - - virtual std::string ToString() const = 0; - - protected: - // Returns true if instruction 'a' executes before instruction 'b'. - // Precondition: 'a' and 'b' are in the same computation. - // - // Derived classes should implement this method for determining order of - // instructions in the same computation. ExecutesBefore() analyzes the - // callgraph and uses this method to determine ordering of instructions in - // different computations. - virtual bool ExecutesBeforeInSameComputation( - const HloInstruction* a, const HloInstruction* b) const = 0; - - const HloModule* module_; - - std::unique_ptr call_graph_; -}; - -// Base class for partial orderings implemented by a map of predecessors for -// each instruction. Subclasses should fill in predecessors_. -class PredecessorHloOrdering : public HloOrdering { - public: - ~PredecessorHloOrdering() override = default; - - // Returns nullptr indicating the computation does not have a sequential - // ordering. - const HloInstructionSequence* SequentialOrder( - const HloComputation& computation) const override { - return nullptr; - } - - HloReachabilityMap& reachability_map(const HloComputation* computation) { - return *predecessors_.at(computation); - } - const HloReachabilityMap& reachability_map( - const HloComputation* computation) const { - return *predecessors_.at(computation); - } - - protected: - explicit PredecessorHloOrdering(const HloModule* module); - std::string ToStringHelper(const std::string& name) const; - - bool ExecutesBeforeInSameComputation(const HloInstruction* a, - const HloInstruction* b) const override; - - // For each computation in the module, this is the set of the instruction's - // predecessors. An instruction is an element of its own predecessor set. - // - // Subclasses should fill this in to define the desired ordering. - absl::flat_hash_map> - predecessors_; -}; - -// An HLO ordering based on data dependencies in the HLO graph. In this partial -// order, instruction A executes before instruction B only if there is a path -// from A to B in the HLO graph. For example, given the following graph: -/* - param - / \ - negate exp - \ / - add -*/ -// DependencyHloOrdering gives the following executes-before relations: -// param executes before negate, exp, and add -// negate executes before add -// exp executes before add -// add executes before nothing -// negate and exp are not ordered because the dependencies allow either to -// execute before the other (or in parallel). DependencyHloOrdering ordering -// allows maximum parallelism and enables any execution order which satisfies -// data dependencies. This requires pessimistic assumptions about buffer live -// ranges and can result in more memory used than more constrained orderings. -class DependencyHloOrdering : public PredecessorHloOrdering { - public: - explicit DependencyHloOrdering(const HloModule* module); - ~DependencyHloOrdering() override = default; - - std::string ToString() const override; -}; - -// An HLO ordering based on a total order of instructions in each computation. -// The computation total order is a sequencing of all of its instructions in -// the computation (eg, {inst0, inst1, inst2,...}) as in single-threaded -// execution. For example, given the following HLO graph: -/* - param - / \ - negate exp - \ / - add -*/ -// and the following sequence: -// -// {param, negate, exp, add} -// -// SequentialHloOrdering gives the following executes-before relations: -// param executes before negate, exp, and add -// negate executes before exp and add -// exp executes before add -// add executes before nothing -// This is more constrained than DependencyHloOrdering in this example because -// negate and exp are ordered (negate before exp). This enables param to share -// the same buffer as exp (param buffer is dead after exp). Generally, this -// ordering enables more buffer sharing (reduced memory usage) because buffer -// interference is reduced relative to DependencyHloOrdering. -class SequentialHloOrdering : public HloOrdering { - public: - explicit SequentialHloOrdering(const HloSchedule& schedule); - explicit SequentialHloOrdering(HloSchedule&& schedule); - ~SequentialHloOrdering() override = default; - - // Returns the sequential instruction order for the given computation. - const HloInstructionSequence* SequentialOrder( - const HloComputation& computation) const override; - - std::string ToString() const override; - - protected: - void Initialize(); - - bool ExecutesBeforeInSameComputation(const HloInstruction* a, - const HloInstruction* b) const override; - - const HloSchedule schedule_; - - // The position of every instruction in the HLO module in its respective - // computation sequence (a value of zero indicates the instruction is first in - // the sequence, etc). Instructions from all computations are contained in - // this map so more than one instruction may have the same position - // value. This is not a problem because ExecutesBefore also verifies - // instructions are in the same computation. - absl::flat_hash_map order_position_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/analysis/hlo_ordering.h" #endif // XLA_SERVICE_HLO_ORDERING_H_ diff --git a/third_party/xla/xla/service/hlo_parser.h b/third_party/xla/xla/service/hlo_parser.h index 2628c15eb00db8..6a9e8d8be6039d 100644 --- a/third_party/xla/xla/service/hlo_parser.h +++ b/third_party/xla/xla/service/hlo_parser.h @@ -16,93 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_PARSER_H_ #define XLA_SERVICE_HLO_PARSER_H_ -#include -#include - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_lexer.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Given a string in the HloModule::ToString() format, parses the string and -// creates a HloModule with the given config. -// Note: Tests derived from HloTestBase should use -// ParseAndReturnVerifiedModule() instead! -absl::StatusOr> ParseAndReturnUnverifiedModule( - absl::string_view str, const HloModuleConfig& config, - bool set_to_default_entry_computation_layout = true); - -// Given a string in the HloModule::ToString() format, parses the string and -// creates a HloModule with default config. -// Note: Tests derived from HloTestBase should use -// ParseAndReturnVerifiedModule() instead! -absl::StatusOr> ParseAndReturnUnverifiedModule( - absl::string_view str, bool set_to_default_entry_computation_layout = true); - -// Parses sharding from str. str is supposed to contain the body of the -// sharding, i.e. just the rhs of the "sharding={...}" attribute string, e.g., -// "{replicated}". -absl::StatusOr ParseSharding(absl::string_view str); - -// Parses frontend attributes from str. str is supposed to contain the body of -// the frontend attributes , i.e. just the rhs of the -// "frontend_attributes={...}" attribute string, e.g., -// "{attr_a=a,attr_b=b}". -absl::StatusOr ParseFrontendAttributes( - absl::string_view str); - -// Parses statistics viz from str. str is supposed to contain the body of the -// statistics visualization, i.e. just the rhs of the "statistics={...}" -// attribute string, e.g., "{visualizing_index=1,nan_percent=50}". -absl::StatusOr ParseStatisticsViz(absl::string_view str); - -// Parses parameter replication from str. str is supposed to contain the body of -// the parameter replication, i.e. just the rhs of the -// "parameter_replication={...}" attribute string, e.g., "{true, false}". -absl::StatusOr> ParseParameterReplication( - absl::string_view str); - -// Parses the result of window_util::ToString(const Window&). -absl::StatusOr ParseWindow(absl::string_view str); - -// Parses the result of ConvolutionDimensionNumbersToString(), e.g. -// "b0f_0io->b0f". -absl::StatusOr ParseConvolutionDimensionNumbers( - absl::string_view str); - -// Parses the result of PaddingConfigToString(), e.g. "0_0x1_1". -absl::StatusOr ParsePaddingConfig(absl::string_view str); - -// Parses and returns a Shape::ToString-format string. -absl::StatusOr ParseShape(absl::string_view str); - -// Parses and returns a Layout::ToString-format string. -absl::StatusOr ParseLayout(absl::string_view str); - -// Parses and returns a std::vector from str. str is supposed to -// contain a list of the replica groups, i.e. just the rhs of the -// "replica_groups={...}" attribute string, e.g., "{{0,1}, {2,3}}". -absl::StatusOr> ParseReplicaGroupsOnly( - absl::string_view str); - -class HloParser { - public: - // Runs the parser and constructs the resulting HLO in the given (empty) - // HloModule. Returns the error status in case an error occurred. - virtual absl::Status Run(HloModule* module) = 0; - virtual ~HloParser() {} - - private: - static std::unique_ptr CreateHloParserForTests( - absl::string_view str); - friend class VerifiedHloModule; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/parser/hlo_parser.h" #endif // XLA_SERVICE_HLO_PARSER_H_ diff --git a/third_party/xla/xla/service/hlo_rematerialization.h b/third_party/xla/xla/service/hlo_rematerialization.h index 4eba8f1a2bdc8d..0dcdcee3636247 100644 --- a/third_party/xla/xla/service/hlo_rematerialization.h +++ b/third_party/xla/xla/service/hlo_rematerialization.h @@ -15,234 +15,7 @@ #ifndef XLA_SERVICE_HLO_REMATERIALIZATION_H_ #define XLA_SERVICE_HLO_REMATERIALIZATION_H_ -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_schedule.h" -#include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/service/call_graph.h" -#include "xla/service/hlo_cost_analysis.h" -#include "xla/service/tuple_points_to_analysis.h" -#include "xla/shape.h" - -namespace xla { - -// HLO pass which rematerializes instructions to reduce peak memory use, where -// memory use is defined as the total size of all live HLO instruction -// values. Parameters and constants are included in memory use estimates. -// -// CSE will undo the effects of this optimization and should not be run after -// this pass. In general, this pass should be run very late, immediately before -// code generation. -class HloRematerialization : public HloModulePass { - public: - using ShapeSizeFunction = std::function; - - using CompactShapeFunction = - std::function(const Shape&)>; - - // Helper struct that communicates the before / after sizes for the - // rematerialization process. - struct RematerializationSizes { - int64_t before_bytes = -1; - int64_t after_bytes = -1; - }; - - // Mode in which the rematerialization algorithm should be run. - struct RematerializationModeConfig { - RematerializationModeConfig(bool recompute, bool compress, - bool host_offload) - : recompute(recompute), - compress(compress), - host_offload(host_offload) {} - bool recompute; // Enables the kCompress RematStrategy. - bool compress; // Enables the kRecompute RematStrategy. - bool host_offload; // Enables the kHostOffload RematStrategy. - }; - - // This is a struct containing configuration options that are specific to the - // Host Memory Offload strategy. - struct HostMemoryOffloadConfig { - explicit HostMemoryOffloadConfig(int64_t host_memory_space, - float bandwidth_to_host_bytes_per_second, - float bandwidth_from_host_bytes_per_second) - : host_memory_space(host_memory_space), - bandwidth_to_host_bytes_per_second( - bandwidth_to_host_bytes_per_second), - bandwidth_from_host_bytes_per_second( - bandwidth_from_host_bytes_per_second) {} - - // The host memory space, which is used during the host offload strategy. - int64_t host_memory_space; - - float bandwidth_to_host_bytes_per_second; - - float bandwidth_from_host_bytes_per_second; - }; - - static Shape DefaultCompactShapeFunction(const Shape& shape) { return shape; } - - struct Options { - explicit Options(HloCostAnalysis& hlo_cost_analysis, - const RematerializationModeConfig& remat_mode_config, - int64_t memory_limit_bytes, int block_size_limit, - int block_rematerialization_factor, int64_t min_remat_size, - CompactShapeFunction compact_shape_function, - std::optional - host_memory_offload_config = std::nullopt, - absl::flat_hash_map - async_computation_parallelism = {}) - : hlo_cost_analysis(hlo_cost_analysis), - remat_mode_config(remat_mode_config), - memory_limit_bytes(memory_limit_bytes), - block_size_limit(block_size_limit), - block_rematerialization_factor(block_rematerialization_factor), - min_remat_size(min_remat_size), - compact_shape_function(compact_shape_function == nullptr - ? DefaultCompactShapeFunction - : std::move(compact_shape_function)), - host_memory_offload_config(host_memory_offload_config), - async_computation_parallelism(async_computation_parallelism) {} - - // The cost model used for decisions during rematerialization for host - // memory offload. It is also used for getting Shape size. - HloCostAnalysis& hlo_cost_analysis; - - // Holds the rematerialization strategy configuration to be used by the - // pass. - RematerializationModeConfig remat_mode_config; - - // Function which computes the size of the top-level buffer of a shape. - const ShapeSizeFunction size_function; - - // The threshold number of bytes to reduce memory use to via - // rematerialization. Size of aliased outputs should be subtracted - // from this. - int64_t memory_limit_bytes; - - // Maximum number of consecutive instructions to consider for - // rematerialization. - int block_size_limit; - - // Controls the amount of effort spent trying to find large blocks for - // rematerialization. Larger values leads to longer compilation times in - // return for potentially reduced memory consumption. - int block_rematerialization_factor; - - // The minimum size, in bytes, of a tensor to be considered for - // rematerialization. All tensors smaller than this size will be skipped - // over. - int64_t min_remat_size; - - // Converts a shape into compact form, returns the same shape if a shape is - // already considered compact. - CompactShapeFunction compact_shape_function; - - std::optional host_memory_offload_config; - - // Collection of async entry computations and their number of parallel - // invocations. - absl::flat_hash_map async_computation_parallelism; - }; - - explicit HloRematerialization(Options options, RematerializationSizes& sizes) - : options_(std::move(options)), sizes_(sizes) {} - - ~HloRematerialization() override = default; - - absl::string_view name() const override { return "rematerialization"; } - - // Get the next available channel id and increment count. - int64_t NextChannelId() { return next_channel_id_++; } - - // Get the peak memory for the computation. - int64_t ComputationPeakMemory(const HloComputation* computation) const { - return computation_peak_memory_.at(computation); - } - - // Runs rematerialization on the given module. Returns whether the module was - // changed. Requires that the module has a schedule set - // (HloModule::has_schedule() is true) before running. Returns whether any - // instructions were rematerialized. If memory use is already below the limit - // specified in the constructor then no instructions are rematerialized and - // false is returned. - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - protected: - // Rematerializes instructions within the given computation. 'schedule' - // constains the order in which the computation's instructions will be emitted - // in the backend. Rematerialized instructions will be added to the HLO - // computation and inserted into 'schedule'. - virtual absl::StatusOr RematerializeComputation( - HloComputation* computation, HloSchedule* schedule, - int64_t memory_limit_bytes, int64_t min_remat_size, - const absl::flat_hash_set& execution_threads); - - // Computes and returns the peak memory used by the given computation. The - // peak memory is the maximum total size of all live HLO instruction values at - // any program point. 'order' is the order in which the HLO instructions will - // be emitted which is used to determine lifespans of HLO values. - absl::StatusOr ComputePeakMemory( - const HloComputation* computation, const HloInstructionSequence& order, - const absl::flat_hash_set& execution_threads) const; - - // Returns the peak memory usage of the called computations for the given - // instruction. Zero is returned if the instruction calls no computations. - absl::StatusOr CalledComputationsMemoryUsage( - const HloInstruction* instruction, - const absl::flat_hash_set& execution_threads) const; - - const Options options_; - - // Reference to data structure which records the peak memory usage of the HLO - // module before/after rematerialization. - RematerializationSizes& sizes_; - - // Call graph of the hlo_module. - std::unique_ptr call_graph_; - - // The peak memory usage of each computation. The map contains only those - // computations called from sequential context - // (CallContext::kSequential). These values are updated as rematerialization - // occurs. - absl::flat_hash_map computation_peak_memory_; - - std::unique_ptr points_to_analysis_; - - // Set of computations which have had rematerialization - // applied. Rematerialization is only applied once per computation. - absl::flat_hash_set rematerialized_computations_; - - // Count of the total instructions rematerialized. - int64_t instructions_rematerialized_ = 0; - - // Count of the net instructions added to the HLO module by - // rematerialization. This can be different than instructions_rematerialized_ - // because some rematerializations are effectively moves in the HLO - // schedule. In these cases, the rematerialization instruction replaces all - // uses of the original instruction and the original instruction is - // dead. Hence, no net instructions were added. - int64_t net_instructions_added_ = 0; - - // Size of the largest block that has been rematerialized. This is actually an - // upper bound (within a factor of 2) on the block size. - int max_rematerialized_block_size_ = 0; - - // Tracking available channel id numbers to use to apply to rematerialized - // channel instructions - int64_t next_channel_id_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/hlo_rematerialization.h" #endif // XLA_SERVICE_HLO_REMATERIALIZATION_H_ diff --git a/third_party/xla/xla/service/hlo_rematerialization_test_utils.h b/third_party/xla/xla/service/hlo_rematerialization_test_utils.h index 069494536f2637..8837169bec82fd 100644 --- a/third_party/xla/xla/service/hlo_rematerialization_test_utils.h +++ b/third_party/xla/xla/service/hlo_rematerialization_test_utils.h @@ -18,130 +18,7 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_REMATERIALIZATION_TEST_UTILS_H_ #define XLA_SERVICE_HLO_REMATERIALIZATION_TEST_UTILS_H_ -#include -#include - -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/shape_util.h" -#include "xla/tests/hlo_test_base.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -class RematerializationTestBase : public HloTestBase { - protected: - // Creates and returns a computation which can benefit from - // rematerialization. The computation looks like: - // - // F32[1] %param = {...} - // F32[] %reshape = reshape(F32[], param) - // F32[1024] %bcast = broadcast(%param) - // F32[1024] %negate = negate(%bcast) - // F32[2048] %concat_1 = concat({%negate, %negate}) - // F32[1] %slice_1 = slice(%concat_1, {0:1}) - // F32[1025] %concat_2 = concat({%bcast, %slice_1}) - // F32[1] %slice_2 = slice(%concat_2, {0:1}); - // - // The instruction %bcast can be rematerialized before its use at %concat_2 - // to reduce peak memory usage. This avoids %bcast and %concat_1 being - // simultaneously live. Peak memory use is about 16KB before rematerialization - // (during execution of %concat_1) and about 12KB after rematerializing %bcast - // for its use in %concat_2. - std::unique_ptr MakeRematerializableComputation( - const std::string& suffix = "") { - auto builder = HloComputation::Builder(TestName() + suffix); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, vec1_shape_, "param")); - auto reshape = builder.AddInstruction( - HloInstruction::CreateReshape(scalar_shape_, param)); - auto bcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {})); - auto negate = builder.AddInstruction( - HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kNegate, bcast)); - auto concat_1 = builder.AddInstruction(HloInstruction::CreateConcatenate( - ShapeUtil::MakeShape(xla::F32, {2048}), {negate, negate}, - /*dimension=*/0)); - auto slice_1 = builder.AddInstruction(HloInstruction::CreateSlice( - vec1_shape_, concat_1, /*start_indices=*/{0}, - /*limit_indices=*/{1}, - /*strides=*/{1})); - auto concat_2 = builder.AddInstruction(HloInstruction::CreateConcatenate( - ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, slice_1}, - /*dimension=*/0)); - // Add a final slice to make the parameter shape match the output shape - // which is necessary to use this computation in a while. - builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat_2, - /*start_indices=*/{0}, - /*limit_indices=*/{1}, - /*strides=*/{1})); - return builder.Build(); - } - - // Creates and returns a computation which includes a while and can benefit - // from rematerialization. The computation looks like: - // - // F32[] %param = {...} - // F32[1024] %bcast = broadcast(%param) - // F32[1] %slice_1 = slice(%bcast, {0:1}) - // F32[1] %while = while(%slice_1, while_body, while_cond) - // F32[1025] %concat = concat({%bcast, %while}) - // F32[1] %slice_2 = slice(%concat, {0:1}); - // - // The instruction %bcast can be rematerialized before its use at %concat to - // reduce peak memory usage. This avoids %bcast being live during execution of - // the while. Peak memory use is maximum of 8K and 4K plus the memory use of - // the while subcomputations. - std::unique_ptr MakeRematerializableWhileComputation( - HloComputation* while_cond, HloComputation* while_body, - const std::string& suffix = "") { - auto builder = HloComputation::Builder(TestName() + suffix); - auto param = builder.AddInstruction( - HloInstruction::CreateParameter(0, vec1_shape_, "param")); - auto reshape = builder.AddInstruction( - HloInstruction::CreateReshape(scalar_shape_, param)); - auto bcast = builder.AddInstruction( - HloInstruction::CreateBroadcast(vec1024_shape_, reshape, {})); - auto slice_1 = builder.AddInstruction( - HloInstruction::CreateSlice(vec1_shape_, bcast, /*start_indices=*/{0}, - /*limit_indices=*/{1}, - /*strides=*/{1})); - auto while_inst = builder.AddInstruction(HloInstruction::CreateWhile( - vec1_shape_, while_cond, while_body, slice_1)); - auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate( - ShapeUtil::MakeShape(xla::F32, {1025}), {bcast, while_inst}, - /*dimension=*/0)); - builder.AddInstruction(HloInstruction::CreateSlice(vec1_shape_, concat, - /*start_indices=*/{0}, - /*limit_indices=*/{1}, - /*strides=*/{1})); - return builder.Build(); - } - - // Create and return a trivial computation appropriate for use as a while - // condition. - std::unique_ptr MakeConditionComputation() { - auto builder = HloComputation::Builder(TestName() + ".cond"); - builder.AddInstruction( - HloInstruction::CreateParameter(0, vec1_shape_, "param")); - builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); - return builder.Build(); - } - - // Return the byte size of the top-level buffer of the given shape. - static int64_t ByteSizeOf(const Shape& shape) { - return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); - } - - protected: - // Various shapes used in the canned computations. - const Shape scalar_shape_ = ShapeUtil::MakeShape(xla::F32, {}); - const Shape vec1_shape_ = ShapeUtil::MakeShape(xla::F32, {1}); - const Shape vec1024_shape_ = ShapeUtil::MakeShape(xla::F32, {1024}); -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/hlo_rematerialization_test_utils.h" #endif // XLA_SERVICE_HLO_REMATERIALIZATION_TEST_UTILS_H_ diff --git a/third_party/xla/xla/service/hlo_replication_analysis.h b/third_party/xla/xla/service/hlo_replication_analysis.h index 417598705bea91..85289cb01adb5e 100644 --- a/third_party/xla/xla/service/hlo_replication_analysis.h +++ b/third_party/xla/xla/service/hlo_replication_analysis.h @@ -16,138 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_REPLICATION_ANALYSIS_H_ #define XLA_SERVICE_HLO_REPLICATION_ANALYSIS_H_ -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// An HLO pass that determines whether each instruction in the module outputs -// the same value across replicas or across partitions (depending on the value -// `cross_partition_spmd`). It propagates sources of replicated values to -// the rest of the module, where sources include cross-replica-sum, annotated -// entry parameters, and constants. -class HloReplicationAnalysis { - public: - // Runs the analysis on module and returns the result or an error. - static absl::StatusOr> Run( - const HloModule* module, bool cross_partition_spmd); - - // Same as above, but the caller can provide additional annotations: a set of - // while loops that are known to have the same iteration counts across - // replicas or partitions. - static absl::StatusOr> Run( - const HloModule* module, bool cross_partition_spmd, - const absl::flat_hash_set* - loops_known_with_same_iterations); - - // Same as above but supports finding partially replicated HLOs. - static absl::StatusOr> - RunWithPartialReplication(const HloModule* module, bool cross_partition_spmd); - - // Returns if the HLO instruction outputs the same value (i.e., replicated) at - // the given index across all replicas or partitions. - bool HloInstructionIsReplicatedAt(const HloInstruction* inst, - const ShapeIndex& index) const; - - bool HloInstructionIsReplicatedAt( - const HloInstruction* inst, const ShapeIndex& index, - absl::Span replica_groups) const; - - private: - // A data structure that represents how an HLO is replicated among a set of - // devices. Device ID could be either partition ID or replica ID. - // We represent partial replication by grouping devices that have the same - // value into the same set. - class HloReplication { - public: - static HloReplication ReplicatedOnAllDevices(); - static HloReplication UniqueOnAllDevices(); - static HloReplication PartiallyReplicated( - absl::Span> device_sets); - HloReplication(); - HloReplication(const HloReplication& other) = default; - HloReplication(HloReplication&& other) = default; - HloReplication& operator=(HloReplication&& other) = default; - HloReplication Merge(const HloReplication& other) const; - bool Equal(const HloReplication& other) const; - bool IsReplicatedOnAllDevices() const; - bool IsUniqueOnAllDevices() const; - bool IsReplicatedWithinSubgroup(absl::Span device_ids) const; - std::string ToString() const; - - private: - enum class State { - kReplicatedOnAllDevices = 0, - kUniqueOnAllDevices = 1, - kPartiallyReplicated = 2, - }; - explicit HloReplication(State state, - absl::Span device_set_root); - State state_; - // Empty if state_ is kReplicatedOnAllDevices or kUniqueOnAllDevices. - // Otherwise, its size equals to the number of devices (either partitions - // or replications). Maps each device ID to the smallest device ID in the - // set. - std::vector device_set_root_; - }; - - static HloReplication DetermineHloInstructionIsReplicated( - const HloInstruction* hlo, const ShapeIndex& index, - bool cross_partition_spmd, - const absl::flat_hash_map>& hlo_replication, - bool support_partial_replication); - - HloReplicationAnalysis(const HloModule* module, bool cross_partition_spmd, - const absl::flat_hash_set* - loops_known_with_same_iterations, - bool support_partial_replication) - : module_(module), - cross_partition_spmd_(cross_partition_spmd), - loops_known_with_same_iterations_(*loops_known_with_same_iterations), - support_partial_replication_(support_partial_replication) {} - - // Computes hlo_replication_. - absl::Status ComputeHloReplication(); - - // A helper function to recursively compute hlo_replication on a computation. - // Returns whether hlo_replication_ is changed. - bool ComputeHloReplicationOnComputation(const HloComputation* computation, - bool mark_everything_not_replicated); - - const HloModule* module_; - - // If true, run this replication analysis for replicated values across - // partitions (not across replicas) on an SPMD partitioned module. This means - // that HloInstructionIsReplicatedAt() returns true if the value is identical - // across partitions for each replica. The module-level parameter and root - // instructions may have HloSharding attributes that indicate whether values - // are identical across partitions. - // - // If false, HloReplicationAnalysis runs across replicas. - bool cross_partition_spmd_; - - // A set of while loops that are known to have the same iteration counts - // across replicas or partitions. This is provided by the caller as additional - // annotations. - const absl::flat_hash_set& - loops_known_with_same_iterations_; - - const bool support_partial_replication_; - - // A map from each analyzed HLO instruction to a shape tree that represents - // whether the instruction outputs the same value across replicas or - // partitions at each shape index. - absl::flat_hash_map> - hlo_replication_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/analysis/hlo_replication_analysis.h" #endif // XLA_SERVICE_HLO_REPLICATION_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/hlo_runner.cc b/third_party/xla/xla/service/hlo_runner.cc index 35f03cf3535960..16a3e4a0ac601a 100644 --- a/third_party/xla/xla/service/hlo_runner.cc +++ b/third_party/xla/xla/service/hlo_runner.cc @@ -23,10 +23,10 @@ limitations under the License. #include "unsupported/Eigen/CXX11/Tensor" #include "xla/hlo/ir/hlo_module_group.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/layout_util.h" #include "xla/service/executable.h" #include "xla/service/hlo_module_util.h" -#include "xla/service/hlo_parser.h" #include "xla/service/transfer_manager.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/service/hlo_runner.h b/third_party/xla/xla/service/hlo_runner.h index f8b49e6bd8fbc4..f2387c04bca6d4 100644 --- a/third_party/xla/xla/service/hlo_runner.h +++ b/third_party/xla/xla/service/hlo_runner.h @@ -187,10 +187,14 @@ class HloRunner : public HloRunnerInterface { absl::string_view Name() const override; - DeviceShapeRepresentationFn device_shape_representation_fn() { + DeviceShapeRepresentationFn device_shape_representation_fn() const override { return device_shape_representation_fn_; } + DeviceShapeSizeFn device_shape_size_fn() const override { + return backend().compiler()->ShapeSizeBytesFunction(); + } + private: absl::StatusOr ExecuteWithExecutionInputs( Executable* executable, std::vector arguments, diff --git a/third_party/xla/xla/service/hlo_runner_interface.cc b/third_party/xla/xla/service/hlo_runner_interface.cc index bf0076071ad95f..f3f3303851952a 100644 --- a/third_party/xla/xla/service/hlo_runner_interface.cc +++ b/third_party/xla/xla/service/hlo_runner_interface.cc @@ -15,7 +15,7 @@ limitations under the License. #include "xla/service/hlo_runner_interface.h" -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" namespace xla { diff --git a/third_party/xla/xla/service/hlo_runner_interface.h b/third_party/xla/xla/service/hlo_runner_interface.h index d73f0be87eb316..ab6ab7f121b13b 100644 --- a/third_party/xla/xla/service/hlo_runner_interface.h +++ b/third_party/xla/xla/service/hlo_runner_interface.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_RUNNER_INTERFACE_H_ #define XLA_SERVICE_HLO_RUNNER_INTERFACE_H_ +#include +#include #include #include #include @@ -84,14 +86,16 @@ class HloRunnerInterface { bool use_threads = false; }; - HloRunnerInterface() = default; + using DeviceShapeRepresentationFn = std::function; + using DeviceShapeSizeFn = std::function; + HloRunnerInterface() = default; virtual ~HloRunnerInterface() = default; // Converts an HloModule from the given hlo textual IR string (in // HloModule::ToString format). static absl::StatusOr> CreateModuleFromString( - const absl::string_view hlo_string, const DebugOptions& debug_options); + absl::string_view hlo_string, const DebugOptions& debug_options); // Reads the proto file in xla.HloProto format, creates and returns the // HloModule. @@ -215,7 +219,13 @@ class HloRunnerInterface { // Returns the name of this runner. virtual absl::string_view Name() const = 0; - typedef std::function DeviceShapeRepresentationFn; + // Return the device shape representation of 'host_shape'. + virtual DeviceShapeRepresentationFn device_shape_representation_fn() + const = 0; + // Return the device shape size of 'host_shape'. + // This function is used e.g. to create a VerifiedHloModule. It returns an + // integer representing the size of the shape in bytes as opposed to a Shape. + virtual DeviceShapeSizeFn device_shape_size_fn() const = 0; }; } // namespace xla diff --git a/third_party/xla/xla/service/hlo_runner_pjrt.cc b/third_party/xla/xla/service/hlo_runner_pjrt.cc index 3965bf61870f3a..6406e6269854d6 100644 --- a/third_party/xla/xla/service/hlo_runner_pjrt.cc +++ b/third_party/xla/xla/service/hlo_runner_pjrt.cc @@ -26,7 +26,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/layout.h" #include "xla/pjrt/host_memory_spaces.h" @@ -114,6 +114,7 @@ absl::StatusOr GenerateExecuteOptions(const HloModule& module) { // TODO(b/245550554): Remove the use of PjRtWrappedExecutable. class PjRtWrappedExecutable : public Executable { public: + // Takes ownership of the provided executable. explicit PjRtWrappedExecutable(std::shared_ptr hlo_module, PjRtLoadedExecutable* pjrt_loaded_executable) : Executable(hlo_module), @@ -125,11 +126,11 @@ class PjRtWrappedExecutable : public Executable { HloExecutionProfile* hlo_execution_profile) override; PjRtLoadedExecutable* GetPjRtLoadedExecutable() const { - return pjrt_loaded_executable_; + return pjrt_loaded_executable_.get(); } private: - PjRtLoadedExecutable* pjrt_loaded_executable_; + std::unique_ptr pjrt_loaded_executable_; }; absl::StatusOr PjRtWrappedExecutable::ExecuteAsyncOnStream( @@ -144,9 +145,11 @@ static const int kDeviceIdx = 0; HloRunnerPjRt::HloRunnerPjRt( std::unique_ptr pjrt_client, - DeviceShapeRepresentationFn device_shape_representation_fn) + DeviceShapeRepresentationFn device_shape_representation_fn, + DeviceShapeSizeFn device_shape_size_fn) : pjrt_client_(std::move(pjrt_client)), - device_shape_representation_fn_(device_shape_representation_fn) {} + device_shape_representation_fn_(device_shape_representation_fn), + device_shape_size_fn_(device_shape_size_fn) {} HloRunnerPjRt::~HloRunnerPjRt() = default; @@ -373,9 +376,7 @@ absl::StatusOr> HloRunnerPjRt::CreateExecutable( std::move(pjrt_executable->GetHloModules().value()[0])), pjrt_executable.release()); - std::unique_ptr exec = - static_cast>(executable.release()); - return exec; + return executable; } absl::StatusOr> HloRunnerPjRt::ExecuteReplicated( diff --git a/third_party/xla/xla/service/hlo_runner_pjrt.h b/third_party/xla/xla/service/hlo_runner_pjrt.h index f0d4d2ec9051a1..c6f9c685a2168b 100644 --- a/third_party/xla/xla/service/hlo_runner_pjrt.h +++ b/third_party/xla/xla/service/hlo_runner_pjrt.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_RUNNER_PJRT_H_ #define XLA_SERVICE_HLO_RUNNER_PJRT_H_ +#include #include #include -#include #include #include "xla/pjrt/pjrt_client.h" @@ -34,7 +34,8 @@ class HloRunnerPjRt : public HloRunnerInterface { public: explicit HloRunnerPjRt( std::unique_ptr pjrt_client, - DeviceShapeRepresentationFn device_shape_representation_fn); + DeviceShapeRepresentationFn device_shape_representation_fn, + DeviceShapeSizeFn device_shape_size_fn); ~HloRunnerPjRt() override; @@ -97,9 +98,18 @@ class HloRunnerPjRt : public HloRunnerInterface { absl::string_view Name() const override; + DeviceShapeRepresentationFn device_shape_representation_fn() const override { + return device_shape_representation_fn_; + } + + DeviceShapeSizeFn device_shape_size_fn() const override { + return device_shape_size_fn_; + } + private: std::unique_ptr pjrt_client_; DeviceShapeRepresentationFn device_shape_representation_fn_; + DeviceShapeSizeFn device_shape_size_fn_; std::vector BufferVecToPointerVec( const std::vector>& buffer); diff --git a/third_party/xla/xla/service/hlo_schedule_test.cc b/third_party/xla/xla/service/hlo_schedule_test.cc index 4f96b30498b1c6..d18c8527893c81 100644 --- a/third_party/xla/xla/service/hlo_schedule_test.cc +++ b/third_party/xla/xla/service/hlo_schedule_test.cc @@ -22,12 +22,12 @@ limitations under the License. #include #include "absl/algorithm/container.h" #include "absl/log/log.h" +#include "xla/hlo/analysis/hlo_ordering.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/hlo_dce.h" -#include "xla/service/hlo_memory_scheduler.h" -#include "xla/service/hlo_ordering.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/hlo_memory_scheduler.h" #include "xla/shape_util.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" diff --git a/third_party/xla/xla/service/hlo_sharding_test.cc b/third_party/xla/xla/service/hlo_sharding_test.cc index 9dc0014208ea96..b6db61bc9d31cc 100644 --- a/third_party/xla/xla/service/hlo_sharding_test.cc +++ b/third_party/xla/xla/service/hlo_sharding_test.cc @@ -22,8 +22,8 @@ limitations under the License. #include #include "absl/hash/hash.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/protobuf_util.h" -#include "xla/service/hlo_parser.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" diff --git a/third_party/xla/xla/service/hlo_unstacker.cc b/third_party/xla/xla/service/hlo_unstacker.cc index 03c6f5c2eb6356..f6b747673293e7 100644 --- a/third_party/xla/xla/service/hlo_unstacker.cc +++ b/third_party/xla/xla/service/hlo_unstacker.cc @@ -285,6 +285,15 @@ bool UnstackWhileOperandAtIndex( bool PropagateGteShapeChange(HloInstruction* gte, UnstackerTransformer& unstacker) { VLOG(5) << "PropagateGteShapeChange(" << gte->name() << ")"; + + HloInstruction* parent_while = nullptr; + if (unstacker.GetMetadata().bodies.contains(gte->parent())) { + parent_while = unstacker.GetMetadata().bodies.at(gte->parent()); + if (parent_while->while_body() != gte->parent()) { + parent_while = nullptr; + } + } + std::vector handled_instrs; // TODO: b/343457903 - Use HloDataflowAnalysis to track the usage of a value // instead of manually applying bfs @@ -296,6 +305,7 @@ bool PropagateGteShapeChange(HloInstruction* gte, std::deque worklist; worklist.push_back(gte); visited.insert({gte, gte->tuple_index()}); + unstacker.AddOperandChange(gte, gte->tuple_index()); while (!worklist.empty()) { HloInstruction* changed_instr_to_propagate = worklist.front(); // The index of the changed operand that needs to be propagated. @@ -320,11 +330,24 @@ bool PropagateGteShapeChange(HloInstruction* gte, // instruction will get the new shape eventually and the // change_operand_index does not matter. visited.insert({user, changed_operand_index}); + unstacker.AddOperandChange(user, changed_operand_index); worklist.push_back(user); } else if (user->opcode() == HloOpcode::kTuple) { - int64_t use_index = user->operand_index(changed_instr_to_propagate); - visited.insert({user, {use_index}}); - worklist.push_back(user); + for (int64_t i = 0; i < user->operand_count(); ++i) { + if (user->operand(i) == changed_instr_to_propagate) { + visited.insert({user, i}); + unstacker.AddOperandChange(user, i); + worklist.push_back(user); + if (parent_while != nullptr && user->IsRoot() && + i != gte->tuple_index()) { + bool changed_nested_while = + CanUnstackWhileOperand(parent_while, unstacker, i); + if (!changed_nested_while) { + return false; + } + } + } + } } else if (user->opcode() == HloOpcode::kWhile) { // Recursively check the inner while for unstacking and populate // unstacker instance. @@ -334,6 +357,7 @@ bool PropagateGteShapeChange(HloInstruction* gte, return false; } visited.insert({user, changed_operand_index}); + unstacker.AddOperandChange(user, changed_operand_index); worklist.push_back(user); } else { if (absl::c_find(handled_instrs, user) != handled_instrs.end()) { @@ -358,6 +382,8 @@ bool PropagateGteShapeChange(HloInstruction* gte, for (HloInstruction* handled_instr_user : instr->users()) { if (user->shape() == gte->shape()) { visited.insert({handled_instr_user, changed_operand_index}); + unstacker.AddOperandChange(handled_instr_user, + changed_operand_index); worklist.push_back(handled_instr_user); } } @@ -366,9 +392,6 @@ bool PropagateGteShapeChange(HloInstruction* gte, } } } - for (const auto& [instr, index] : visited) { - unstacker.AddOperandChange(instr, index); - } return true; } @@ -402,6 +425,42 @@ bool CanPropagateGteShapeChangesInComputation( return true; } +std::unique_ptr DynamicSliceToSlice( + HloInstruction* dynamic_slice, HloInstruction* input, int64_t i) { + std::vector new_start_indices; + new_start_indices.reserve(dynamic_slice->shape().rank()); + std::vector new_limit_indices; + new_limit_indices.reserve(dynamic_slice->shape().rank()); + std::vector new_strides; + new_strides.reserve(dynamic_slice->shape().rank()); + new_start_indices.push_back(i); + new_limit_indices.push_back(i + 1); + new_strides.push_back(1); + for (int64_t j = 1; j < dynamic_slice->shape().rank(); ++j) { + new_start_indices.push_back(0); + new_limit_indices.push_back( + dynamic_slice->mutable_operand(0)->shape().dimensions(j)); + new_strides.push_back(1); + } + return HloInstruction::CreateSlice(dynamic_slice->shape(), input, + new_start_indices, new_limit_indices, + new_strides); +} + +bool ShouldUnfuseSlices(const UnstackerMetadata& metadata, HloInstruction* ds) { + HloInstruction* input = ds->mutable_operand(0); + for (int64_t i = 0; i < input->shape().dimensions(0); ++i) { + HloInstruction* slice = + ds->AddInstruction(DynamicSliceToSlice(ds, input, i)); + if (!metadata.unfuse_slice(slice)) { + CHECK_OK(slice->parent()->RemoveInstruction(slice)); + return false; + } + CHECK_OK(slice->parent()->RemoveInstruction(slice)); + } + return true; +} + // This function is responsible for: // 1. Hoisting the unstacking computation outside the while_instr. // 2. Replacing the input of the while_instr with the new unstacked version. @@ -451,32 +510,15 @@ void UnstackWhileInput(const UnstackerTransformer& unstacker, if (unstacker.GetPatternType() == PatternType::DSFusionPattern || unstacker.GetPatternType() == PatternType::NestedDSFusionPattern || unstacker.GetPatternType() == PatternType::DSFusionNoBitcastPattern) { - HloInstruction* dynamic_slice = nullptr; if (unstacker.GetPatternType() == PatternType::DSFusionPattern || unstacker.GetPatternType() == PatternType::NestedDSFusionPattern) { - dynamic_slice = root_instr->mutable_operand(0); + slice = while_instr->AddInstruction(DynamicSliceToSlice( + root_instr->mutable_operand(0), old_while_input, i)); } else if (unstacker.GetPatternType() == PatternType::DSFusionNoBitcastPattern) { - dynamic_slice = root_instr; - } - std::vector new_start_indices; - new_start_indices.reserve(dynamic_slice->shape().rank()); - std::vector new_limit_indices; - new_limit_indices.reserve(dynamic_slice->shape().rank()); - std::vector new_strides; - new_strides.reserve(dynamic_slice->shape().rank()); - new_start_indices.push_back(i); - new_limit_indices.push_back(i + 1); - new_strides.push_back(1); - for (int64_t j = 1; j < dynamic_slice->shape().rank(); ++j) { - new_start_indices.push_back(0); - new_limit_indices.push_back( - dynamic_slice->mutable_operand(0)->shape().dimensions(j)); - new_strides.push_back(1); + slice = while_instr->AddInstruction( + DynamicSliceToSlice(root_instr, old_while_input, i)); } - slice = while_instr->AddInstruction(HloInstruction::CreateSlice( - dynamic_slice->shape(), old_while_input, new_start_indices, - new_limit_indices, new_strides)); } if (slice == nullptr || !unstacker.GetMetadata().unfuse_slice(slice)) { std::vector operands = { @@ -530,19 +572,30 @@ bool CanUnstackWhileOperand(const HloInstruction* while_instr, return false; } - const HloInstruction* root_operand = - while_instr->while_body()->root_instruction()->operand(index); + HloInstruction* root_operand = + while_instr->while_body()->root_instruction()->mutable_operand(index); if (root_operand == nullptr) { return false; } - if (Match(root_operand, match::GetTupleElement(match::While()))) { - VLOG(3) << "Faced a gte originating from loop: " - << root_operand->ToString(); - bool loop_feeding_root_changes_collected = CanUnstackWhileOperand( - root_operand->operand(0), unstacker, root_operand->tuple_index()); - if (!loop_feeding_root_changes_collected) { - VLOG(3) << "Failed: loop " << root_operand->operand(0)->name() - << " output at " << index << " is not unstackable"; + + HloInstruction* gte_operand = nullptr; + // Currently, we only support unstacking of while operands that either: + // 1. Are parameters of the while_body. + // 2. Are get-tuple-elements of another while instruction. + if (Match(root_operand, match::GetTupleElement(match::Op(>e_operand)))) { + if (Match(gte_operand, match::While())) { + VLOG(3) << "Faced a gte originating from loop: " + << root_operand->ToString(); + bool loop_feeding_root_changes_collected = CanUnstackWhileOperand( + root_operand->operand(0), unstacker, root_operand->tuple_index()); + if (!loop_feeding_root_changes_collected) { + VLOG(3) << "Failed: loop " << root_operand->operand(0)->name() + << " output at " << index << " is not unstackable"; + return false; + } + } else if (!Match(gte_operand, match::Parameter().WithParameterNum(0))) { + VLOG(3) << "Failed: root operand of while_body at " << index + << " is not a parameter"; return false; } } @@ -660,13 +713,15 @@ Shape MakeUnstackedShapeFromSlice(const Shape& slice_shape, int64_t layers) { // parameters inside an unrollable loop. If so, it returns the loop config. std::optional IsFusionInsideUnrollableLoopWithNumParameter( const UnstackerMetadata& metadata, const HloInstruction* instr, - int64_t num_fusion_params) { + std::optional num_fusion_params) { if (instr->opcode() != HloOpcode::kFusion) { return std::nullopt; } - if (instr->fused_parameters().size() != num_fusion_params) { - VLOG(3) << "Fusion has different number of parameters"; - return std::nullopt; + if (num_fusion_params.has_value()) { + if (instr->fused_parameters().size() != num_fusion_params) { + VLOG(3) << "Fusion has different number of parameters"; + return std::nullopt; + } } if (!metadata.unrollable_loop_bodies.contains(instr->parent())) { VLOG(5) << "Fusion not inside unrollable while body, " << instr->name() @@ -683,7 +738,7 @@ std::optional IsFusionInsideUnrollableLoopWithNumParameter( // dynamic-slice instruction. HloInstruction* GetMostMajorEffectivelyStaticDynamicSliceInFusion( const UnstackerMetadata& metadata, const HloInstruction* instr, - int64_t num_fusion_params, int64_t stacked_operand_idx) { + std::optional num_fusion_params, int64_t stacked_operand_idx) { std::optional while_instr_config = IsFusionInsideUnrollableLoopWithNumParameter(metadata, instr, num_fusion_params); @@ -756,6 +811,9 @@ std::optional GetDSFusionPattern(const UnstackerMetadata& metadata, if (shape_covering_instr == nullptr) { return std::nullopt; } + if (!ShouldUnfuseSlices(metadata, shape_covering_instr)) { + return std::nullopt; + } HloInstruction* bitcast_operand = nullptr; if (Match(instr->fused_instructions_computation()->root_instruction(), match::Bitcast(match::Op(&bitcast_operand)))) { @@ -763,7 +821,6 @@ std::optional GetDSFusionPattern(const UnstackerMetadata& metadata, PatternInfo pattern_info; pattern_info.type = PatternType::DSFusionPattern; pattern_info.instr = instr; - // const Shape& slice_shape = instr->shape(); const Shape& slice_shape = shape_covering_instr->shape(); const int64_t num_layers = instr->operand(0)->shape().dimensions(0); pattern_info.unstacked_shape = @@ -1299,6 +1356,9 @@ std::optional GetReduceFusionPattern( if (shape_covering_instr == nullptr) { return std::nullopt; } + if (!ShouldUnfuseSlices(metadata, shape_covering_instr)) { + return std::nullopt; + } HloInstruction* reduce_operand = nullptr; HloInstruction* fusion_root = instr->fused_instructions_computation()->root_instruction(); @@ -1381,6 +1441,7 @@ absl::StatusOr HloUnstacker::Run( } } + int64_t num_unstacked = 0; bool unstacked = false; std::vector unstacked_instructions; for (HloInstruction* loop : entry_loops) { @@ -1393,8 +1454,12 @@ absl::StatusOr HloUnstacker::Run( VLOG(3) << "Attempting to unstack " << loop->name() << " at " << i << " = " << loop->while_init()->operand(i)->shape().ToString(true) << loop->while_init()->operand(i)->ToShortString(); - unstacked |= + bool current_unstacked = UnstackWhileOperandAtIndex(metadata, loop, i, unstacked_instructions); + if (current_unstacked) { + num_unstacked++; + unstacked = true; + } VLOG(3) << "###################"; } } @@ -1427,6 +1492,7 @@ absl::StatusOr HloUnstacker::Run( CHECK(unrolled); } VLOG(3) << "after unstacking \n" << module->ToString(); + VLOG(3) << "Num unstacked: " << num_unstacked; return true; } diff --git a/third_party/xla/xla/service/hlo_unstacker_test.cc b/third_party/xla/xla/service/hlo_unstacker_test.cc index 3b00f9236a1ae7..9878e0805f6669 100644 --- a/third_party/xla/xla/service/hlo_unstacker_test.cc +++ b/third_party/xla/xla/service/hlo_unstacker_test.cc @@ -99,6 +99,172 @@ TEST_F(UnstackerTest, UnstackDSFusionPattern) { std::nullopt, false)); } +TEST_F(UnstackerTest, NotUnstackDSFusionPattern) { + std::string hlo_string = R"( + HloModule SimpleLoop + %fused_computation.slice (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[128,128] { + %param_0.51117 = s8[3,128,128] parameter(0) + p1 = s32[] parameter(1) + %constant.85694 = s32[] constant(0) + %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} + ROOT %bitcast.31250 = s8[128,128] bitcast(s8[1,128,128] %dynamic-slice.22040) + } + + %fused_computation.tuple { + %param_0.51117 = s8[3,128,128] parameter(0) + mult = multiply(param_0.51117, param_0.51117) + ROOT out = tuple(param_0.51117, mult) + } + + %while.body (wide_param: (s32[], bf16[8,128], s8[3,128,128])) -> (s32[], bf16[8,128], s8[3,128,128]) { + wide_p = (s32[], bf16[8,128], s8[3,128,128]) parameter(0) + i = s32[] get-tuple-element(wide_p), index=0 + p0 = bf16[8,128] get-tuple-element(wide_p), index=1 + p1 = s8[3,128,128] get-tuple-element(wide_p), index=2 + one = s32[] constant(1) + inc = s32[] add(i, one) + %fusion.67830 = s8[128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.slice + conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf + fusion_mult = (s8[3,128,128], s8[3,128,128]) fusion(s8[3,128,128] p1), kind=kLoop, calls=%fused_computation.tuple + mult = s8[3,128,128] get-tuple-element(fusion_mult), index=1 + ROOT out = (s32[], bf16[8,128], s8[3,128,128]) tuple(inc, conv, mult) + } + + %while.cond (wide_param: (s32[], bf16[8,128], s8[3,128,128])) -> pred[] { + wide_p = (s32[], bf16[8,128], s8[3,128,128]) parameter(0) + i = s32[] get-tuple-element(wide_p), index=0 + %constant.12857 = s32[] constant(3) + ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, s32[] %constant.12857), direction=LT + } + + ENTRY main { + p0 = s8[3,128,128] parameter(0) + p1 = bf16[8,128] parameter(1) + init = s32[] constant(0) + while.input = (s32[], bf16[8,128], s8[3,128,128]) tuple(init, p1, p0) + while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body + while_use = s8[3,128,128] get-tuple-element(while.out), index=2 + ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + auto original = module->Clone(); + TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); + EXPECT_FALSE(unstacked); +} + +TEST_F(UnstackerTest, UnstackDSFusionPatternMultipleLoopRootUse) { + std::string hlo_string = R"( + HloModule SimpleLoop + %fused_computation.slice (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[128,128] { + %param_0.51117 = s8[3,128,128] parameter(0) + p1 = s32[] parameter(1) + %constant.85694 = s32[] constant(0) + %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} + ROOT %bitcast.31250 = s8[128,128] bitcast(s8[1,128,128] %dynamic-slice.22040) + } + + %while.body (wide_param: (s32[], bf16[8,128], s8[3,128,128], s8[3,128,128])) -> (s32[], bf16[8,128], s8[3,128,128], s8[3,128,128]) { + wide_p = (s32[], bf16[8,128], s8[3,128,128], s8[3,128,128]) parameter(0) + i = s32[] get-tuple-element(wide_p), index=0 + p0 = bf16[8,128] get-tuple-element(wide_p), index=1 + p2 = s8[3,128,128] get-tuple-element(wide_p), index=3 + one = s32[] constant(1) + inc = s32[] add(i, one) + %fusion.67830 = s8[128,128] fusion(s8[3,128,128] p2, i), kind=kLoop, calls=%fused_computation.slice + conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf + ROOT out = (s32[], bf16[8,128], s8[3,128,128], s8[3,128,128]) tuple(inc, conv, p2, p2) + } + + %while.cond (wide_param: (s32[], bf16[8,128], s8[3,128,128], s8[3,128,128])) -> pred[] { + wide_p = (s32[], bf16[8,128], s8[3,128,128], s8[3,128,128]) parameter(0) + i = s32[] get-tuple-element(wide_p), index=0 + %constant.12857 = s32[] constant(3) + ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, s32[] %constant.12857), direction=LT + } + + ENTRY main { + p0 = s8[3,128,128] parameter(0) + p1 = bf16[8,128] parameter(1) + init = s32[] constant(0) + zero = s8[] constant(0) + buffer = s8[3,128,128] broadcast(zero), dimensions={} + while.input = (s32[], bf16[8,128], s8[3,128,128], s8[3,128,128]) tuple(init, p1, p0, buffer) + while.out = (s32[], bf16[8,128], s8[3,128,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body + ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + auto original = module->Clone(); + TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); + EXPECT_TRUE(unstacked); + // Check for the creation of slice instructions. + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 6); + // Check that the bitcast is unfused and there are not fusions. + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kFusion), + 0); + EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), + std::nullopt, false)); +} + +TEST_F(UnstackerTest, UnstackDSFusionPatternWithUnusedOperand) { + std::string hlo_string = R"( + HloModule SimpleLoop + %fused_computation.slice (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[128,128] { + %param_0.51117 = s8[3,128,128] parameter(0) + p1 = s32[] parameter(1) + %constant.85694 = s32[] constant(0) + %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} + ROOT %bitcast.31250 = s8[128,128] bitcast(s8[1,128,128] %dynamic-slice.22040) + } + + %while.body (wide_param: (s32[], bf16[8,128], s8[3,128,128], s8[3,128,128])) -> (s32[], bf16[8,128], s8[3,128,128], s8[3,128,128]) { + wide_p = (s32[], bf16[8,128], s8[3,128,128], s8[3,128,128]) parameter(0) + i = s32[] get-tuple-element(wide_p), index=0 + p0 = bf16[8,128] get-tuple-element(wide_p), index=1 + p1 = s8[3,128,128] get-tuple-element(wide_p), index=2 + one = s32[] constant(1) + inc = s32[] add(i, one) + %fusion.67830 = s8[128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.slice + conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf + ROOT out = (s32[], bf16[8,128], s8[3,128,128], s8[3,128,128]) tuple(inc, conv, p1, p1) + } + + %while.cond (wide_param: (s32[], bf16[8,128], s8[3,128,128], s8[3,128,128])) -> pred[] { + wide_p = (s32[], bf16[8,128], s8[3,128,128], s8[3,128,128]) parameter(0) + i = s32[] get-tuple-element(wide_p), index=0 + %constant.12857 = s32[] constant(3) + ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, s32[] %constant.12857), direction=LT + } + + ENTRY main { + p0 = s8[3,128,128] parameter(0) + p1 = bf16[8,128] parameter(1) + init = s32[] constant(0) + zero = s8[] constant(0) + buffer = s8[3,128,128] broadcast(zero), dimensions={} + while.input = (s32[], bf16[8,128], s8[3,128,128], s8[3,128,128]) tuple(init, p1, p0, buffer) + while.out = (s32[], bf16[8,128], s8[3,128,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body + while_use = s8[3,128,128] get-tuple-element(while.out), index=2 + ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + auto original = module->Clone(); + TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); + EXPECT_TRUE(unstacked); + // Check for the creation of slice instructions. + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 6); + // Check that the bitcast is unfused and there are not fusions. + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kFusion), + 0); + EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), + std::nullopt, false)); +} + TEST_F(UnstackerTest, UnstackReduceFusionPattern) { std::string hlo_string = R"( HloModule SimpleLoop @@ -265,6 +431,55 @@ TEST_F(UnstackerTest, UnstackDSFusionPatternNoBitcastKeepFused) { std::nullopt, false)); } +TEST_F(UnstackerTest, UnstackDSFusionPatternKeepFused) { + std::string hlo_string = R"( + HloModule SimpleLoop + %fused_computation.slice (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[128,128] { + %param_0.51117 = s8[3,128,128] parameter(0) + p1 = s32[] parameter(1) + %constant.85694 = s32[] constant(0) + %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} + ROOT out = s8[128,128] bitcast(%dynamic-slice.22040) + } + + %while.body (wide_param: (s32[], bf16[8,128], s8[3,128,128])) -> (s32[], bf16[8,128], s8[3,128,128]) { + wide_p = (s32[], bf16[8,128], s8[3,128,128]) parameter(0) + i = s32[] get-tuple-element(wide_p), index=0 + p0 = bf16[8,128] get-tuple-element(wide_p), index=1 + p1 = s8[3,128,128] get-tuple-element(wide_p), index=2 + one = s32[] constant(1) + inc = s32[] add(i, one) + %fusion.67830 = s8[128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.slice + conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf + ROOT out = (s32[], bf16[8,128], s8[3,128,128]) tuple(inc, conv, p1) + } + + %while.cond (wide_param: (s32[], bf16[8,128], s8[3,128,128])) -> pred[] { + wide_p = (s32[], bf16[8,128], s8[3,128,128]) parameter(0) + i = s32[] get-tuple-element(wide_p), index=0 + %constant.12857 = s32[] constant(3) + ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, s32[] %constant.12857), direction=LT + } + + ENTRY main { + p0 = s8[3,128,128] parameter(0) + p1 = bf16[8,128] parameter(1) + init = s32[] constant(0) + while.input = (s32[], bf16[8,128], s8[3,128,128]) tuple(init, p1, p0) + while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body + while_use = s8[3,128,128] get-tuple-element(while.out), index=2 + ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + auto original = module->Clone(); + auto unfuse = [](HloInstruction* instruction) { return false; }; + TF_ASSERT_OK_AND_ASSIGN(bool unstacked, + HloUnstacker(unfuse).Run(module.get())); + EXPECT_FALSE(unstacked); +} + TEST_F(UnstackerTest, UnstackDSFusionPatternWithDifferentLayout) { std::string hlo_string = R"( HloModule SimpleLoop @@ -650,7 +865,8 @@ TEST_F(UnstackerTest, UnstackNestedDSFusionPatternWithSameUnstackingComps) { std::nullopt, false)); } -TEST_F(UnstackerTest, NotUnstackNestedDSFusionPatternWithSameUnstackingComps) { +TEST_F(UnstackerTest, + NotUnstackNestedDSFusionPatternWithDifferentUnstackingComps) { std::string hlo_string = R"( HloModule SimpleLoop %fused_computation.slice.1 (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[1,128,128] { diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis.h b/third_party/xla/xla/service/hlo_value_semantics_analysis.h index dc0eeb224768e6..4a946206879037 100644 --- a/third_party/xla/xla/service/hlo_value_semantics_analysis.h +++ b/third_party/xla/xla/service/hlo_value_semantics_analysis.h @@ -16,428 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_HLO_VALUE_SEMANTICS_ANALYSIS_H_ #define XLA_SERVICE_HLO_VALUE_SEMANTICS_ANALYSIS_H_ -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/container/node_hash_map.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/types/span.h" -#include "xla/hlo/ir/dfs_hlo_visitor.h" -#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_value.h" -#include "xla/shape.h" -#include "xla/shape_tree.h" -#include "xla/shape_util.h" - -namespace xla { - -struct SendRecvGroup { - HloInstruction* send; - HloInstruction* recv; -}; - -class SendRecvGroupMap { - public: - explicit SendRecvGroupMap(const HloModule& hlo_module); - SendRecvGroupMap(SendRecvGroupMap&& other) = default; - SendRecvGroupMap(const SendRecvGroupMap& other) = default; - virtual ~SendRecvGroupMap() = default; - virtual absl::StatusOr GetMatchingSendOrRecv( - HloInstruction* send_or_recv) const; - - private: - absl::flat_hash_map host_transfer_rendezvous_map_; -}; - -class HloPreOrderDFS { - public: - HloPreOrderDFS() = default; - ~HloPreOrderDFS() = default; - absl::Status Run(const HloComputation& computation, - DfsHloVisitorBase* visitor); - - private: - bool IsReady(const HloInstruction* instruction) const; - std::vector stack_; - absl::flat_hash_set visited_; -}; - -using EinsumDepthMap = - absl::node_hash_map>; - -// The einsum depth is the length of the einsum dependency chain. And we -// distinguish instructions that are used by root and that are not used by -// root. -// The einsum depth of an HLO value A is defined as follows: -// for B = op(A, ...) -// 1) the root instruction has a depth of 0; -// 2) non-root instructions that have zero users have a depth of -1; -// 3) if op is a Dot or Convolution (i.e., einsum), -// depth(A, B) = depth(B) >= 0 ? depth(B) + 1 : depth(B) - 1. -// depth(A, B) means the depth of A because of B; -// 4) otherwise depth(A, B) = depth(B); -// 5) depth(A) is computed by merging all depth(A, u) where u is a user of A. -// See MergeDepth for how user depths are merged. - -class EinsumDepthAnalysis : public DfsHloVisitorWithDefault { - public: - static absl::StatusOr> Run( - const HloComputation& computation, - const SendRecvGroupMap& send_recv_group_map); - ~EinsumDepthAnalysis() override = default; - absl::Status DefaultAction(HloInstruction* instruction) override; - absl::Status HandleTuple(HloInstruction* tuple) override; - absl::Status HandleGetTupleElement( - HloInstruction* get_tuple_element) override; - absl::Status HandleDot(HloInstruction* dot) override; - absl::Status HandleConvolution(HloInstruction* convolution) override; - absl::Status HandleCall(HloInstruction* call) override; - absl::Status HandleFusion(HloInstruction* fusion) override; - absl::Status HandleWhile(HloInstruction* xla_while) override; - absl::Status HandleConditional(HloInstruction* conditional) override; - absl::Status HandleAfterAll(HloInstruction* after_all) override; - absl::Status HandleSend(HloInstruction* send) override; - absl::Status HandleRecv(HloInstruction* recv) override; - absl::Status HandleSendDone(HloInstruction* send_done) override; - absl::Status HandleRecvDone(HloInstruction* recv_done) override; - absl::Status HandleAllReduce(HloInstruction* all_reduce) override; - absl::Status HandleAsyncStart(HloInstruction* async_start) override; - absl::Status HandleAsyncDone(HloInstruction* async_done) override; - const EinsumDepthMap& GetEinsumDepthMap() const { return einsum_depth_map_; } - - private: - explicit EinsumDepthAnalysis(const SendRecvGroupMap& send_recv_group_map) - : send_recv_group_map_(&send_recv_group_map) {} - absl::Status RunInternal(const HloComputation& computation, - const std::optional>& root_depth); - ShapeTree& GetOrCreateDepthTree(const HloInstruction* instruction); - ShapeTree& GetDepthTreeOrDie(const HloInstruction* instruction); - absl::Status SetInstructionDepth(const HloInstruction* instruction, - int depth); - absl::Status SetInstructionDepth(const HloInstruction* instruction, - const ShapeTree& depth); - absl::Status SetInstructionDepthFromTupleDepth( - const HloInstruction* instruction, const ShapeTree& tuple_depth_tree, - int tuple_index); - absl::Status HandleDepthIncrementInstruction(HloInstruction* instruction); - absl::Status HandleCalledComputation( - const HloComputation& called_computation, - const ShapeTree& root_depth, - absl::Span operands); - absl::Status HandleTupleLike(HloInstruction* tuple_like); - EinsumDepthMap einsum_depth_map_; - const SendRecvGroupMap* const send_recv_group_map_; -}; - -using EinsumHeightMap = - absl::node_hash_map>; - -// Einsum height is the maximum number of einsums between this instruction and -// any leaf. - -class EinsumHeightAnalysis : public DfsHloVisitorWithDefault { - public: - static absl::StatusOr> Run( - const HloComputation& computation, - const SendRecvGroupMap& send_recv_group_map); - ~EinsumHeightAnalysis() override = default; - absl::Status DefaultAction(HloInstruction* instruction) override; - absl::Status HandleTuple(HloInstruction* tuple) override; - absl::Status HandleGetTupleElement( - HloInstruction* get_tuple_element) override; - absl::Status HandleDot(HloInstruction* dot) override; - absl::Status HandleConvolution(HloInstruction* convolution) override; - absl::Status HandleCall(HloInstruction* call) override; - absl::Status HandleFusion(HloInstruction* fusion) override; - absl::Status HandleWhile(HloInstruction* xla_while) override; - absl::Status HandleConditional(HloInstruction* conditional) override; - absl::Status HandleSend(HloInstruction* send) override; - absl::Status HandleRecv(HloInstruction* recv) override; - absl::Status HandleSendDone(HloInstruction* send_done) override; - absl::Status HandleRecvDone(HloInstruction* recv_done) override; - absl::Status HandleAllReduce(HloInstruction* all_reduce) override; - absl::Status HandleAsyncStart(HloInstruction* async_start) override; - absl::Status HandleAsyncDone(HloInstruction* async_done) override; - const EinsumHeightMap& GetEinsumHeightMap() const { - return einsum_height_map_; - } - - private: - explicit EinsumHeightAnalysis(const SendRecvGroupMap& send_recv_group_map) - : send_recv_group_map_(&send_recv_group_map) {} - absl::Status RunInternal(const HloComputation& computation, - absl::Span operands); - ShapeTree& GetOrCreateHeightTree(const HloInstruction* instruction); - ShapeTree& GetHeightTreeOrDie(const HloInstruction* instruction); - bool HasHeightFor(const HloInstruction* instruction) const; - absl::Status SetInstructionHeight(const HloInstruction* instruction, - int height); - absl::Status SetInstructionHeight(const HloInstruction* instruction, - const ShapeTree& height); - absl::Status HandleHeightIncrementInstruction(HloInstruction* instruction); - absl::Status HandleCalledComputation( - const HloComputation& computation, - absl::Span operands); - absl::Status HandleTupleLike(HloInstruction* tuple_like); - - EinsumHeightMap einsum_height_map_; - const SendRecvGroupMap* const send_recv_group_map_; -}; - -// The comment below explains where the labels could originate from. Once -// originated, those labels are then propagated throughout the HLO module. -enum class HloValueSemanticLabel { - // Values that are known or predictable at compile time, including constants, - // iota, replica-id, and partition-id. - kStatic, - // Values that are not known or can't be predicated at compile time. - kRandom, - // HLO module parameters. - kWeight, - // Output of weight-weight or weight-activation matmuls. - kActivation, - // Output of weight-activation matmuls where the weight is a dependence of - // that activation. Or output of weight-activation-gradient matmuls. - kActivationGradient, - // Output of activation-gradient-activation matmuls. - kWeightGradient, - kTupleOrToken, -}; - -std::string HloValueSemanticLabelToString(HloValueSemanticLabel label); - -class HloValueSemantics { - public: - using Id = int64_t; - HloValueSemantics(HloValueSemanticLabel label, const HloPosition& origin); - HloValueSemantics(Id id, HloValueSemanticLabel label, - const HloPosition& origin); - HloValueSemantics(const HloValueSemantics& other) = default; - HloValueSemantics(HloValueSemantics&& other) = default; - HloValueSemantics& operator=(const HloValueSemantics& other) = default; - - Id id() const { return id_; } - HloValueSemanticLabel label() const { return label_; } - const HloPosition& origin() const { return origin_; } - std::string ToString() const; - - private: - const Id id_; - const HloValueSemanticLabel label_; - const HloPosition origin_; -}; - -std::string HloValueSemanticsTreeToString( - const ShapeTree& tree); - -using HloValueSemanticsMap = - absl::node_hash_map>; -class HloValueSemanticsPropagation; - -class HloValueSemanticsAnalysis { - public: - static absl::StatusOr> Run( - const HloModule& module, - const absl::flat_hash_set& execution_threads = {}); - virtual ~HloValueSemanticsAnalysis() = default; - bool HasSemanticsFor(const HloInstruction* instruction) const; - const HloValueSemantics* GetSemantics(const HloInstruction* instruction, - const ShapeIndex& index = {}) const; - - const HloValueSemanticsMap& GetSemanticsMap() const { - return value_semantics_; - } - - const EinsumDepthMap& GetEinsumDepthMap() const { return einsum_depth_map_; } - const EinsumHeightMap& GetEinsumHeightMap() const { - return einsum_height_map_; - } - int GetDepth(const HloInstruction* instruction, - const ShapeIndex& index = {}) const; - int GetHeight(const HloInstruction* instruction, - const ShapeIndex& index = {}) const; - - const SendRecvGroupMap& GetSendRecvGroupMap() const { - return *send_recv_group_map_; - } - - absl::StatusOr GetMatchingSendOrRecv( - HloInstruction* send_or_recv) const; - - protected: - friend class HloValueSemanticsPropagation; - explicit HloValueSemanticsAnalysis( - const HloModule& module, - const absl::flat_hash_set& execution_threads); - virtual absl::Status InitializeEinsumDepth(); - virtual absl::Status InitializeEinsumHeight(); - // We match send and recv HLOs to propagate semantics from send to recv. - virtual void InitializeSendRecvGroups(); - void AnnotateWeights(); - - // Infer semantics for all instructions in the computation. Computation - // parameters are assigned the semantics of the corresponding operand. - absl::Status RunOnComputation( - const HloComputation& computation, - absl::Span operands); - // Same as the above RunOnComputation, but computation parameters have - // already been assigned with semantics. - virtual absl::Status RunOnComputation(const HloComputation& computation); - HloValueSemantics::Id NextId(); - const HloValueSemantics* NewHloValueSemantics(HloValueSemanticLabel label, - const HloPosition& origin); - const ShapeTree& GetInstructionSemantics( - const HloInstruction* instruction) const; - void DeepCopyHloValueSemantics( - ShapeTree& copy_to, - const ShapeTree& copy_from, - const ShapeIndex& source_index, const ShapeIndex& destination_index); - void DeepCopyHloValueSemantics( - const HloInstruction* target, - const ShapeTree& copy_from, - const ShapeIndex& source_index = {}); - void SetHloValueSemantics( - const HloInstruction* target, - const ShapeTree& semantics); - void DeleteHloValueSemantics( - const ShapeTree& to_delete); - void DeleteHloValueSemantics(const HloValueSemantics* to_delete); - const HloModule& module_; - const absl::flat_hash_set& execution_threads_; - HloValueSemanticsMap value_semantics_; - absl::flat_hash_map> - value_semantics_map_; - HloValueSemantics::Id next_id_; - EinsumDepthMap einsum_depth_map_; - EinsumHeightMap einsum_height_map_; - std::unique_ptr send_recv_group_map_; -}; - -class HloValueSemanticsPropagation : public DfsHloVisitorWithDefault { - public: - explicit HloValueSemanticsPropagation(HloValueSemanticsAnalysis* analysis); - absl::Status Run(const HloComputation& computation); - // Infer the output semantics from all operands of the instruction. - absl::Status DefaultAction(HloInstruction* instruction) override; - absl::Status HandleParameter(HloInstruction* parameter) override; - absl::Status HandleConstant(HloInstruction* constant) override; - absl::Status HandleIota(HloInstruction* iota) override; - absl::Status HandlePartitionId(HloInstruction* partition_id) override; - absl::Status HandleReplicaId(HloInstruction* replica_id) override; - absl::Status HandleClamp(HloInstruction* clamp) override; - absl::Status HandleTuple(HloInstruction* tuple) override; - absl::Status HandleGetTupleElement( - HloInstruction* get_tuple_element) override; - absl::Status HandleCall(HloInstruction* call) override; - absl::Status HandleFusion(HloInstruction* fusion) override; - absl::Status HandleCustomCall(HloInstruction* custom_call) override; - absl::Status HandleWhile(HloInstruction* xla_while) override; - absl::Status HandleConditional(HloInstruction* conditional) override; - absl::Status HandleSelect(HloInstruction* select) override; - absl::Status HandleConcatenate(HloInstruction* concatenate) override; - absl::Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; - absl::Status HandleDynamicUpdateSlice( - HloInstruction* dynamic_update_slice) override; - absl::Status HandleCopyStart(HloInstruction* copy_start) override; - absl::Status HandleCopyDone(HloInstruction* copy_done) override; - absl::Status HandleAllGatherStart(HloInstruction* all_gather_start) override; - absl::Status HandleAllGatherDone(HloInstruction* all_gather_done) override; - absl::Status HandleCollectivePermuteStart( - HloInstruction* collective_permute_start) override; - absl::Status HandleCollectivePermuteDone( - HloInstruction* collective_permute_done) override; - absl::Status HandleGather(HloInstruction* gather) override; - absl::Status HandleScatter(HloInstruction* scatter) override; - absl::Status HandleAfterAll(HloInstruction* after_all) override; - absl::Status HandleAllReduce(HloInstruction* all_reduce) override; - absl::Status HandleAsyncStart(HloInstruction* async_start) override; - absl::Status HandleAsyncDone(HloInstruction* async_done) override; - absl::Status HandleInfeed(HloInstruction* infeed) override; - absl::Status HandleOutfeed(HloInstruction* outfeed) override; - absl::Status HandleDomain(HloInstruction* domain) override; - absl::Status HandleOptimizationBarrier(HloInstruction* opt_barrier) override; - absl::Status HandleRngBitGenerator( - HloInstruction* rng_bit_generator) override; - absl::Status HandleSend(HloInstruction* send) override; - absl::Status HandleRecv(HloInstruction* recv) override; - absl::Status HandleSendDone(HloInstruction* send_done) override; - absl::Status HandleRecvDone(HloInstruction* recv_done) override; - - protected: - HloValueSemantics CopySemantics(const HloValueSemantics& semantics) const; - HloValueSemantics CopySemanticsWithNewOrigin( - const HloValueSemantics& semantics, HloInstruction* new_origin, - const ShapeIndex& index = {}) const; - const HloValueSemantics* AddSemantics(const HloValueSemantics& semantics); - struct EinsumAndOperandIndex { - HloInstruction* einsum; - int64_t operand_index; - }; - // Checks if the origin of `semantics` is an einsum that takes - // `origin_dependence` as an operand. - // If `recursive` is set to true, recursively checks all ancestors of the - // `semantics`' origin (including itself) for the above condition. - // Returns all such einsums and the operand index corresponding to - // `origin_dependence`. - // We use this function to find whether the output of an einsum who has an - // operand X is used in another einsum who takes X as an operand. This is - // the pattern for gradient. - // For example, consider C = einsum(A, B), dC / dB = einsum(A, C). - std::vector FindEinsumsWhereOriginDependsOnOther( - const HloValueSemantics& semantics, const HloPosition& origin_dependence, - bool recursive = false) const; - bool OriginDependsOn(const HloValueSemantics& semantics, - const HloPosition& origin_dependence, - bool recursive = false) const; - absl::StatusOr MaybeCreateGradientSemantics( - HloInstruction* gradient_candidate, - HloValueSemanticLabel fallback_label) const; - absl::StatusOr ComputeSemanticsFromStaticAndOther( - const HloValueSemantics& static_semantics, - const HloValueSemantics& other_semantics, - HloInstruction* instruction) const; - absl::StatusOr ComputeSemanticsFromRandomAndOther( - const HloValueSemantics& random_semantics, - const HloValueSemantics& other_semantics, - HloInstruction* instruction) const; - absl::StatusOr ComputeSemanticsFromWeightAndOther( - const HloValueSemantics& weight_semantics, - const HloValueSemantics& other_semantics, - HloInstruction* instruction) const; - absl::StatusOr ComputeSemanticsFromActivationAndOther( - const HloValueSemantics& activation_semantics, - const HloValueSemantics& other_semantics, - HloInstruction* instruction) const; - absl::StatusOr - ComputeSemanticsFromActivationGradientAndOther( - const HloValueSemantics& activation_gradient_semantics, - const HloValueSemantics& other_semantics, - HloInstruction* instruction) const; - absl::StatusOr ComputeSemanticsFromWeightGradientAndOther( - const HloValueSemantics& weight_gradient_semantics, - const HloValueSemantics& other_semantics, - HloInstruction* instruction) const; - absl::StatusOr MergeSemanticsForAnInstruction( - HloInstruction* instruction, - std::vector& semantics_vec) const; - absl::StatusOr ComputeSemanticsFromOperands( - HloInstruction* instruction, absl::Span operand_indices, - absl::Span operand_shape_indices = {}) const; - absl::Status HandleTupleLike(HloInstruction* tuple_like); - absl::Status HandleCollectiveOrCopyStart(HloInstruction* op_start); - absl::Status HandleCollectiveOrCopyDone(HloInstruction* op_done); - HloValueSemanticsAnalysis* analysis_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/analysis/hlo_value_semantics_analysis.h" #endif // XLA_SERVICE_HLO_VALUE_SEMANTICS_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/hlo_verifier.cc b/third_party/xla/xla/service/hlo_verifier.cc index 22592935ec498f..c2cb566be5ffaa 100644 --- a/third_party/xla/xla/service/hlo_verifier.cc +++ b/third_party/xla/xla/service/hlo_verifier.cc @@ -615,6 +615,24 @@ absl::Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { } } +absl::Status ShapeVerifier::HandleRaggedAllToAll(HloInstruction* hlo) { + auto* all_to_all = Cast(hlo); + TF_ASSIGN_OR_RETURN(CollectiveOpGroupMode group_mode, + GetCollectiveOpGroupMode( + all_to_all->channel_id().has_value(), std::nullopt)); + + TF_RETURN_IF_ERROR(CheckReplicaGroups(hlo, group_mode)); + + TF_RET_CHECK(all_to_all != nullptr); + TF_RET_CHECK(hlo->operand_count() == 6); + std::vector operand_shapes; + for (const HloInstruction* operand : hlo->operands()) { + operand_shapes.push_back(&operand->shape()); + } + return CheckShape(hlo, + ShapeInference::InferRaggedAllToAllShape(operand_shapes)); +} + absl::Status ShapeVerifier::HandlePartitionId(HloInstruction* hlo) { return CheckShape(hlo, ShapeUtil::MakeShape(U32, {})); } @@ -2368,6 +2386,38 @@ absl::Status VerifyAsynchronousInstructionPairs(const HloModule& module) { instruction, {HloOpcode::kCollectivePermuteStart})); break; } + case HloOpcode::kSend: { + // If the instruction is kSend or kRecv, it can have no users if and + // only if it is wrapped in an async call. + if (instruction->IsRoot() && + instruction->parent()->IsAsyncComputation()) { + break; + } + TF_RETURN_IF_ERROR(VerifySingleUser( + instruction, {HloOpcode::kSendDone, HloOpcode::kTuple})); + break; + } + case HloOpcode::kSendDone: { + TF_RETURN_IF_ERROR(VerifySingleOperand( + instruction, {HloOpcode::kSend, HloOpcode::kGetTupleElement})); + break; + } + case HloOpcode::kRecv: { + // If the instruction is kSend or kRecv, it can have no users if and + // only if it is wrapped in an async call. + if (instruction->IsRoot() && + instruction->parent()->IsAsyncComputation()) { + break; + } + TF_RETURN_IF_ERROR(VerifySingleUser( + instruction, {HloOpcode::kRecvDone, HloOpcode::kTuple})); + break; + } + case HloOpcode::kRecvDone: { + TF_RETURN_IF_ERROR(VerifySingleOperand( + instruction, {HloOpcode::kRecv, HloOpcode::kGetTupleElement})); + break; + } default: break; } @@ -2414,44 +2464,13 @@ absl::Status VerifyLayoutConstrainedAllReduce(const HloModule& module) { // Checks various invariants of channel instructions (send/recv and // collectives). -absl::Status VerifyChannels(const HloModule& module) { +absl::Status VerifyChannels(const HloModule& module, + const HloVerifierOpts& opts) { absl::flat_hash_map> channel_instructions; - // For Async operations, we need to make sure: - // (1) AsyncStart and AsyncDone are used in pairs - // (2) AsynStart and Asyndone are connected, that is, an AsynDone has an - // AsyncStart as its only operand, and an AsynStart has an AsyncDone as - // its only user - // (3) the channel ID used by a pair of Async operations is unique - // - // Send and SendDone, Recv and RecvDone are such pairs of Async operations. - // Different from other Async operations, a channel ID can be used by one - // Send-SendDone pair and one Recv-RecvDone pair. As such, we verify the - // above three invariants for Send/Recv related instructions with adjustment - // to (3): - // (3*) the channel ID used by a pair of Send-SendDone can be shared by at - // most one pair of Recv-RecvDone. - // - // Currently, the GPU compiler can decomposed collective-permute into a group - // of instructions with a pair of Send-SendDone and a pair of Recv-RecvDone - // that use the same channel ID. When a while-body contains such instructions, - // the GPU compiler can also peel off Send and Recv, and statically order - // SendDone/RecvDone inside the while-body before Send/Recv. This breaks - // invariants (2) and (3*) for the pipelined Send/Recv case. We verify the - // following for a group of instructions using the same channel ID but don't - // satisfy invariants (1)(2)(3*): - // (4) All instructions in the group are annotated with frontend attributes. - // We avoid verifying the content of such a frontend attribute to avoid - // making the general HLO instruction verifier depend on the compiler pass - // that performs the transformation. - // (5) the group should contain equal number uses of each Send/Recv related - // instructions. - // - // Comparing the verification of unpipelined Send/Recv with the verification - // of pipelined, what we missing verifying is that the direct connection - // between Send/Recv and SendDone/RecvDone through operands. - // + // Send/recv instruction must have a unique user. If it is the corresponding + // send-done/recv-done operation, channel IDs must match. for (const HloComputation* computation : module.computations()) { for (const HloInstruction* instruction : computation->instructions()) { auto channel_instr = DynCast(instruction); @@ -2462,55 +2481,39 @@ absl::Status VerifyChannels(const HloModule& module) { switch (instruction->opcode()) { case HloOpcode::kSend: { - bool pipelined = true; - if (instruction->users().size() == 1) { - const HloInstruction* send_user = instruction->users().front(); - if (send_user->opcode() == HloOpcode::kSendDone) { - TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_user)); - TF_RETURN_IF_ERROR( - CheckSameIsHostTransfer(instruction, send_user)); - pipelined = false; - } + // If the instruction is kSend or kRecv, it can have no users if and + // only if it is wrapped in an async call. + if (instruction->IsRoot() && + instruction->parent()->IsAsyncComputation()) { + break; + } + TF_RET_CHECK(instruction->users().size() == 1); + const HloInstruction* send_done = instruction->users().front(); + if (send_done->opcode() == HloOpcode::kSendDone) { + TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done)); + TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done)); } - // Pipelined Send should be annotated with frontend attributes. - TF_RET_CHECK(pipelined == false || - !instruction->frontend_attributes().map().empty()); break; } case HloOpcode::kRecv: { - bool pipelined = true; - if (instruction->users().size() == 1) { - const HloInstruction* recv_user = instruction->users().front(); - if (recv_user->opcode() == HloOpcode::kRecvDone) { - TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_user)); - TF_RETURN_IF_ERROR( - CheckSameIsHostTransfer(instruction, recv_user)); - pipelined = false; - } + // If the instruction is kSend or kRecv, it can have no users if and + // only if it is wrapped in an async call. + if (instruction->IsRoot() && + instruction->parent()->IsAsyncComputation()) { + break; + } + TF_RET_CHECK(instruction->users().size() == 1); + const HloInstruction* recv_done = instruction->users().front(); + if (recv_done->opcode() == HloOpcode::kRecvDone) { + TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done)); + TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done)); } - // Pipelined Recv should be annotated with frontend attributes. - TF_RET_CHECK(pipelined == false || - !instruction->frontend_attributes().map().empty()); - break; - } - case HloOpcode::kSendDone: { - TF_RET_CHECK(instruction->operands().size() == 1); - const HloInstruction* send_done_operand = instruction->operand(0); - // If the operand is not a Send, the Send-done is pipelined and should - // have frontend attributes. - TF_RET_CHECK(send_done_operand->opcode() == HloOpcode::kSend || - !instruction->frontend_attributes().map().empty()); break; } - case HloOpcode::kRecvDone: { + case HloOpcode::kSendDone: + case HloOpcode::kRecvDone: TF_RET_CHECK(instruction->operands().size() == 1); - const HloInstruction* recv_done_operand = instruction->operand(0); - // If the operand is not a Recv, the Recv-done is pipelined and should - // have frontend attributes. - TF_RET_CHECK(recv_done_operand->opcode() == HloOpcode::kRecv || - !instruction->frontend_attributes().map().empty()); break; - } default: break; } @@ -2521,59 +2524,22 @@ absl::Status VerifyChannels(const HloModule& module) { for (auto& pair : channel_instructions) { auto& instructions = pair.second; const HloInstruction* first = instructions[0]; - auto sendrecv = DynCast(first); - if (sendrecv) { - // Check that all instructions are Send/Recv related and count the - // appearance of each opcode in the group. - absl::flat_hash_map opcode_to_count; + if (const auto* sendrecv = DynCast(first)) { + absl::flat_hash_set opcodes; for (const HloInstruction* instr : instructions) { - auto it = opcode_to_count.find(instr->opcode()); - if (it != opcode_to_count.end()) { - it->second++; - } else { - opcode_to_count[instr->opcode()] = 1; - } - TF_RET_CHECK(DynCast(instr) != nullptr) + opcodes.insert(instr->opcode()); + auto cast = DynCast(instr); + TF_RET_CHECK(cast != nullptr) << "channel " << pair.first << " is used for different types of channel instructions"; } - - int count = opcode_to_count.begin()->second; - bool consistent_count = - absl::c_all_of(opcode_to_count, [count](const auto& opcode_count) { - return opcode_count.second == count; - }); - // A pipelined group of Send/Recv should all have frontend attributes. - bool maybe_pipelined = - absl::c_all_of(instructions, [](const HloInstruction* inst) { - return !inst->frontend_attributes().map().empty(); - }); - - if (sendrecv->is_host_transfer()) { - TF_RET_CHECK(consistent_count && count == 1 && instructions.size() == 2) - << "channel " << pair.first - << " is used for multiple host send/recv instructions"; - } else { - if (consistent_count && count == 1) { - TF_RET_CHECK(instructions.size() == opcode_to_count.size()) - << "channel " << pair.first - << " is used for multiple send/recv instructions"; - } else { - TF_RET_CHECK(maybe_pipelined) << "channel " << pair.first - << " is used for multiple send/recv " - "instructions but not pipelined"; - TF_RET_CHECK(consistent_count && opcode_to_count.size() % 2 == 0) - << "channel " << pair.first - << " is pipelined. Not all Send/Recv related instructions are" - " used the same number of times or channel is used for other " - "instructions"; - } - } } else { for (const HloInstruction* instr : instructions) { - TF_RET_CHECK(first->opcode() == instr->opcode()) - << "channel " << pair.first - << " is used for different types of channel instructions"; + if (opts.verify_unique_channel_ids) { + TF_RET_CHECK(first->opcode() == instr->opcode()) + << "channel " << pair.first + << " is used for different types of channel instructions"; + } } } } @@ -2982,8 +2948,12 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { const Layout& operand_layout = operand_shape.layout(); Layout::Equal equal_predicate = Layout::Equal().IgnoreTiles().IgnoreMemorySpace(); - if (instruction->opcode() == HloOpcode::kConvert) { - // Convert instructions can change element_size_in_bits + if (instruction->opcode() == HloOpcode::kConvert || + instruction->opcode() == HloOpcode::kCompare || + (instruction->opcode() == HloOpcode::kSelect && + operand_shape.element_type() == PRED)) { + // Convert and Compare instructions can change element_size_in_bits + // Select instructions ignore element_size_in_bits for predicate equal_predicate.IgnoreElementSize(); } else if (instruction->opcode() == HloOpcode::kDynamicSlice || instruction->opcode() == HloOpcode::kDynamicUpdateSlice || @@ -3082,7 +3052,8 @@ absl::StatusOr HloVerifier::Run( TF_RETURN_IF_ERROR(VerifyHloStructure(module)); TF_RETURN_IF_ERROR(VerifyAsynchronousInstructionPairs(*module)); - TF_RETURN_IF_ERROR(VerifyChannels(*module)); + TF_RETURN_IF_ERROR( + VerifyChannels(*module, target_metadata_->GetVerifierOpts())); TF_RETURN_IF_ERROR(VerifyInstructionNameUnchanged( *module, target_metadata_->GetVerifierOpts())); @@ -3093,7 +3064,11 @@ absl::StatusOr HloVerifier::Run( for (auto* computation : module->computations(execution_threads)) { TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get())); TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier)); - if (computation->IsAsyncComputation()) { + // Verify that async computations contain a single instruction or a + // collection of send/recv instructions. This is needed to represent NCCL + // groups on GPU. + if (computation->IsAsyncComputation() && + !computation->OnlyContainsSendRecv()) { TF_RETURN_IF_ERROR(VerifyAsyncComputation(computation)); } } diff --git a/third_party/xla/xla/service/hlo_verifier.h b/third_party/xla/xla/service/hlo_verifier.h index 83f5e32a6def05..8aa7a1dcb7ee78 100644 --- a/third_party/xla/xla/service/hlo_verifier.h +++ b/third_party/xla/xla/service/hlo_verifier.h @@ -149,6 +149,9 @@ struct HloVerifierOpts { // cloned (".clone" suffix) or rematted (".remat"); bool verify_instruction_name_unchanged = false; + // Check if channel instructions all have unique channel ids. + bool verify_unique_channel_ids = true; + HloPredicate instruction_can_change_layout; // Returns a target-specific shape size. @@ -191,6 +194,7 @@ class ShapeVerifier : public DfsHloVisitor { absl::Status HandleAllReduceStart(HloInstruction* hlo) override; absl::Status HandleAllReduceDone(HloInstruction* hlo) override; absl::Status HandleAllToAll(HloInstruction* hlo) override; + absl::Status HandleRaggedAllToAll(HloInstruction* hlo) override; absl::Status HandleCollectiveBroadcast(HloInstruction* hlo) override; absl::Status HandleCollectivePermute(HloInstruction* hlo) override; absl::Status HandleCollectivePermuteStart(HloInstruction* hlo) override; diff --git a/third_party/xla/xla/service/hlo_verifier_test.cc b/third_party/xla/xla/service/hlo_verifier_test.cc index 07c47980640b71..eff127af51be37 100644 --- a/third_party/xla/xla/service/hlo_verifier_test.cc +++ b/third_party/xla/xla/service/hlo_verifier_test.cc @@ -35,10 +35,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/layout.h" #include "xla/literal_util.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" #include "xla/service/layout_assignment.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -75,12 +76,13 @@ class HloVerifierTestAllowMixedPrecision : public HloTestBase { /*allow_mixed_precision_in_hlo_verifier=*/true) {} }; -class HloVerifierTestLayoutSensitive : public HloTestBase { +class HloVerifierTestLayoutSensitive : public HloHardwareIndependentTestBase { public: HloVerifierTestLayoutSensitive() - : HloTestBase(/*verifier_layout_sensitive=*/true, - /*allow_mixed_precision_in_hlo_verifier=*/false, - LayoutAssignment::InstructionCanChangeLayout) {} + : HloHardwareIndependentTestBase( + /*verifier_layout_sensitive=*/true, + /*allow_mixed_precision_in_hlo_verifier=*/false, + LayoutAssignment::InstructionCanChangeLayout) {} }; class HloVerifierTestLayoutSensitiveAndAllowMixedPrecision @@ -997,7 +999,7 @@ TEST_F(HloVerifierTest, CopyStartMultipleCopyDone) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo_string)); - auto status = verifier().Run(module.get()).status(); + absl::Status status = verifier().Run(module.get()).status(); ASSERT_FALSE(status.ok()); EXPECT_THAT( status.message(), @@ -1415,6 +1417,60 @@ TEST_F(HloVerifierTest, AsyncOpComputationNotTrivial) { "expected to contain only the root and parameter instructions")); } +TEST_F(HloVerifierTest, AsyncMultiOpComputationSendRecvOnly) { + const char* const hlo_string = R"( + wrapped_send_recv_1 { + param0 = f32[] parameter(0) + param1 = token[] parameter(1) + send = (f32[], u32[], token[]) send(param0, param1), channel_id=1 + param2 = f32[] parameter(2) + param3 = token[] parameter(3) + send.1 = (f32[], u32[], token[]) send(param2, param3), channel_id=2 + param4 = token[] parameter(4) + recv = (f32[], u32[], token[]) recv(param4), channel_id=1 + param5 = token[] parameter(5) + recv.1 = (f32[], u32[], token[]) recv(param5), channel_id=2 + ROOT tuple = ((f32[], u32[], token[]), (f32[], u32[], token[]), + (f32[], u32[], token[]), (f32[], u32[], token[])) + tuple(send, send.1, recv, recv.1) + } + + ENTRY main { + data-1 = f32[] constant(1) + after-all-1 = token[] after-all() + data-2 = f32[] constant(2) + after-all-2 = token[] after-all() + tuple-start = ((f32[], token[], f32[], token[], token[], token[]), + ((f32[], u32[], token[]), (f32[], u32[], token[]), + (f32[], u32[], token[]), (f32[], u32[], token[])), s32[]) + async-start(data-1, after-all-1, data-2, after-all-2, after-all-1, after-all-2), + calls=wrapped_send_recv_1 + tuple-done = ((f32[], u32[], token[]), (f32[], u32[], token[]), + (f32[], u32[], token[]), (f32[], u32[], token[])) async-done(tuple-start) + gte.4 = (f32[], u32[], token[]) get-tuple-element(tuple-done), index=2 + gte.5 = f32[] get-tuple-element(gte.4), index=0 + gte.6 = token[] get-tuple-element(gte.4), index=2 + tuple.1 = (f32[], token[]) tuple(gte.5, gte.6) + data-out-1 = f32[] get-tuple-element(tuple.1), index=0 + gte.7 = (f32[], u32[], token[]) get-tuple-element(tuple-done), index=3 + gte.8 = f32[] get-tuple-element(gte.7), index=0 + gte.9 = token[] get-tuple-element(gte.7), index=2 + tuple.2 = (f32[], token[]) tuple(gte.8, gte.9) + data-out-2 = f32[] get-tuple-element(tuple.2), index=0 + ROOT out = (f32[], f32[]) tuple(data-out-1, data-out-2) + get-tuple-element = (f32[], u32[], token[]) get-tuple-element(tuple-done), index=0 + gte.1 = token[] get-tuple-element(get-tuple-element), index=2 + gte.2 = (f32[], u32[], token[]) get-tuple-element(tuple-done), index=1 + gte.3 = token[] get-tuple-element(gte.2), index=2 + } +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_TRUE(status.ok()); +} + TEST_F(HloVerifierTest, IotaNonArrayResult) { const char* const hlo_string = R"( HloModule IotaTupleResult @@ -2263,93 +2319,181 @@ TEST_F(HloVerifierTest, ChannelVerifier) { HasSubstr("used for different types of channel instructions")); } -TEST_F(HloVerifierTest, ChannelVerifierPipelinedMissingDones) { +TEST_F(HloVerifierTest, ChannelVerifierPartiallyPipelinedAsyncRecv) { const char* const kModuleStr = R"( - HloModule test - cond { - param = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[])) parameter(0) - count = get-tuple-element(%param), index=0 - ub = u32[] constant(1) - ROOT result = pred[] compare(count, ub), direction=LT - } - - body { - param = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[])) parameter(0) - count = get-tuple-element(%param), index=0 - - recv.0 = (u32[2], u32[], token[]) get-tuple-element(param), index=1 - recv-done.0 = (u32[2], token[]) recv-done(recv.0), channel_id=1, - frontend_attributes={ - _xla_send_recv_pipeline="0" - } - recv-data.0 = u32[2] get-tuple-element(recv-done.0), index=0 - - c1 = u32[] constant(1) - new_count = u32[] add(count, c1) - - send.0 = (u32[2], u32[], token[]) get-tuple-element(param), index=2 - send-done.0 = (u32[2], token[]) recv-done(send.0), channel_id=1, - frontend_attributes={ - _xla_send_recv_pipeline="0" - } - - after-all.0.n = token[] after-all() - recv.0.n = (u32[2], u32[], token[]) recv(after-all.0.n), channel_id=1, - frontend_attributes={ - _xla_send_recv_source_target_pairs="{{1,0}}", - _xla_send_recv_pipeline="0" - } - - - after-all.1.n = token[] after-all() - send.0.n = (u32[2], u32[], token[]) send(recv-data.0, after-all.1.n), - channel_id=1, - frontend_attributes={ - _xla_send_recv_source_target_pairs="{{1,0}}", - _xla_send_recv_pipeline="0" - } - - ROOT result = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[])) - tuple(new_count, recv.0.n, send.0.n) - } - - ENTRY test_computation { - c0 = u32[] constant(0) - init = u32[2] broadcast(c0), dimensions={} - after-all.0.p = token[] after-all() - recv.0.p = (u32[2], u32[], token[]) recv(after-all.0.p), channel_id=1, - frontend_attributes={ - _xla_send_recv_source_target_pairs="{{1,0}}", - _xla_send_recv_pipeline="0" - } - - after-all.1.p = token[] after-all() - send.0.p = (u32[2], u32[], token[]) send(init, after-all.1.p), - channel_id=1, - frontend_attributes={ - _xla_send_recv_source_target_pairs="{{1,0}}", - _xla_send_recv_pipeline="0" - } - - while_init = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[])) - tuple(c0, recv.0.p, send.0.p) - while_result = (u32[], (u32[2], u32[], token[]), (u32[2], u32[], token[])) - while(while_init), body=body, condition=cond - - recv.0.q = (u32[2], u32[], token[]) get-tuple-element(while_result), index=1 - recv-done.0.q = (u32[2], token[]) recv-done(recv.0.q), channel_id=1, - frontend_attributes={ - _xla_send_recv_pipeline="0" - } - - ROOT recv-data.0.q = u32[2] get-tuple-element(recv-done.0.q), index=0 - })"; + HloModule test + + while_body { + param = ((f32[16], u32[], token[])) parameter(0) + prev_recv = (f32[16], u32[], token[]) get-tuple-element(param), index=0 + recv_done = (f32[16], token[]) recv-done(prev_recv), channel_id=1 + after_all = token[] after-all() + recv = (f32[16], u32[], token[]) recv(after_all), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + ROOT tuple = ((f32[16], u32[], token[])) tuple(recv) + } + + // Infinite loop to keep IR small. + while_condition { + param = ((f32[16], u32[], token[])) parameter(0) + ROOT infinite_loop = pred[] constant(true) + } + + ENTRY main_spmd { + after_all = token[] after-all() + recv = (f32[16], u32[], token[]) recv(after_all), channel_id=1, + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}} + init = ((f32[16], u32[], token[])) tuple(recv) + while = ((f32[16], u32[], token[])) while(init), + condition=while_condition, body=while_body + recv_ctx = (f32[16], u32[], token[]) get-tuple-element(while), index=0 + recv_done = (f32[16], token[]) recv-done(recv_ctx), channel_id=1 + ROOT result = f32[16] get-tuple-element(recv_done), index=0 + })"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kModuleStr)); - EXPECT_THAT( - verifier().Run(module.get()).status().message(), - HasSubstr("is pipelined. Not all Send/Recv related instructions are used" - " the same number of times")); + TF_ASSERT_OK(verifier().Run(module.get())); +} + +TEST_F(HloVerifierTest, ChannelVerifierPartiallyPipelinedAsyncSend) { + const char* const kModuleStr = R"( + HloModule test + + while_body { + param = ((f32[16], u32[], token[]), f32[16]) parameter(0) + prev_send = (f32[16], u32[], token[]) get-tuple-element(param), index=0 + data = f32[16] get-tuple-element(param), index=1 + send_done = (f32[16], token[]) send-done(prev_send), channel_id=1 + after_all = token[] after-all() + send = (f32[16], u32[], token[]) send(data, after_all), channel_id=1, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + ROOT tuple = ((f32[16], u32[], token[]), f32[16]) tuple(send, data) + } + + // Infinite loop to keep IR small. + while_condition { + param = ((f32[16], u32[], token[]), f32[16]) parameter(0) + ROOT infinite_loop = pred[] constant(true) + } + + ENTRY main_spmd { + data = f32[16] parameter(0) + after_all = token[] after-all() + send = (f32[16], u32[], token[]) send(data, after_all), channel_id=1, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + init = ((f32[16], u32[], token[]), f32[16]) tuple(send, data) + while = ((f32[16], u32[], token[]), f32[16]) while(init), + condition=while_condition, body=while_body + send_ctx = (f32[16], u32[], token[]) get-tuple-element(while), index=0 + ROOT send_done = (f32[16], token[]) send-done(send_ctx), channel_id=1 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + TF_ASSERT_OK(verifier().Run(module.get())); +} + +TEST_F(HloVerifierTest, ChannelVerifierAsyncSend) { + const char* const kModuleStr = R"( + HloModule test + + ENTRY main_spmd { + data = f32[16] parameter(0) + after_all = token[] after-all() + send = (f32[16], u32[], token[]) send(after_all, data), channel_id=1, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + ROOT send_done = (f32[16], token[]) send-done(send), channel_id=1 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + TF_ASSERT_OK(verifier().Run(module.get())); +} + +TEST_F(HloVerifierTest, SingleUserExceptionForWrappedSendRecv) { + const char* const kModuleStr = R"( + wrapped_send { + data = f32[] parameter(0) + after-all = token[] parameter(1) + ROOT send = (f32[], u32[], token[]) send(data, after-all), channel_id=1, + frontend_attributes={_xla_send_recv_source_target_pairs={{0,1}}} + } + wrapped_recv { + after-all = token[] parameter(0) + ROOT recv = (f32[], u32[], token[]) recv(after-all), channel_id=1, + frontend_attributes={_xla_send_recv_source_target_pairs={{1,0}}} + } + ENTRY main () -> f32[] { + data = f32[] constant(5) + after-all = token[] after-all() + async-recv-start = ((token[]), (f32[], u32[], token[]), s32[]) async-start(after-all), calls=wrapped_recv + async-send-start = ((f32[], token[]), (f32[], u32[], token[]), s32[]) async-start(data, after-all), calls=wrapped_send + async-recv-done = (f32[], u32[], token[]) async-done(async-recv-start), calls=wrapped_recv + async-send-done = (f32[], u32[], token[]) async-done(async-send-start), calls=wrapped_send + ROOT out = f32[] get-tuple-element((f32[], u32[], token[]) async-recv-done), index=0 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + TF_ASSERT_OK(verifier().Run(module.get())); +} + +TEST_F(HloVerifierTest, ChannelVerifierAsyncRecv) { + const char* const kModuleStr = R"( + HloModule test + + ENTRY main_spmd { + after_all = token[] after-all() + recv = (f32[16], u32[], token[]) recv(after_all), channel_id=1, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + recv_done = (f32[16], token[]) recv-done(recv), channel_id=1 + ROOT result = f32[16] get-tuple-element(recv_done), index=0 + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + TF_ASSERT_OK(verifier().Run(module.get())); +} + +TEST_F(HloVerifierTest, ChannelVerifierMultipleSendUsers) { + const char* const kModuleStr = R"( + HloModule test + + ENTRY main_spmd { + data = f32[16] parameter(0) + after_all = token[] after-all() + send = (f32[16], u32[], token[]) send(data, after_all), channel_id=1, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + send_done = (f32[16], token[]) send-done(send), channel_id=1 + ROOT result = ((f32[16], u32[], token[]), f32[16]) tuple(send, send_done) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + EXPECT_THAT(verifier().Run(module.get()).status().message(), + HasSubstr("send instruction requires one consumer, found 2")); +} + +TEST_F(HloVerifierTest, ChannelVerifierMultipleRecvUsers) { + const char* const kModuleStr = R"( + HloModule test + + ENTRY main_spmd { + after_all = token[] after-all() + recv = (f32[16], u32[], token[]) recv(after_all), channel_id=1, + frontend_attributes={ + _xla_send_send_source_target_pairs={{0,1},{1,2},{2,3}}} + recv_done = (f32[16], token[]) recv-done(recv), channel_id=1 + ROOT result = (((f32[16], u32[], token[])), f32[16]) + tuple(recv, recv_done) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + EXPECT_THAT(verifier().Run(module.get()).status().message(), + HasSubstr("recv instruction requires one consumer, found 2")); } TEST_F(HloVerifierTest, CollectiveChannelVerifier) { @@ -3408,5 +3552,40 @@ TEST_F(HloVerifierTestLayoutSensitive, HasSubstr("Instruction has mismatched minor-to-major size and " "dimension size: ")); } + +TEST_F(HloVerifierTest, NoErrorOnDuplicateChannelId) { + const char* const hlo_string = R"( + HloModule m + + ENTRY main { + data_param = f32[2048,2048]{1,0} parameter(0) + cp1 = f32[2048,2048]{1,0} collective-permute(data_param), source_target_pairs={{0,1},{1,2},{2,3}}, channel_id=1 + cp2 = f32[2048,2048]{1,0} collective-permute(data_param), source_target_pairs={{0,1}}, channel_id=1 + + ROOT tuple = (f32[2048,2048]{1,0}, f32[2048,2048]{1,0}) tuple(cp1, cp2) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + HloVerifierOpts opts{}; + opts.verify_unique_channel_ids = false; + HloVerifier verifier(std::move(opts)); + ASSERT_IS_OK(verifier.Run(module.get()).status()); +} + +TEST_F(HloVerifierTestLayoutSensitive, Int4CompareSelect) { + const char* const kModuleStr = R"( + HloModule test + + ENTRY main { + a = s4[10]{0:E(4)} parameter(0) + b = s4[10]{0:E(4)} parameter(1) + less = pred[10] compare(a, b), direction=LT + ROOT result = select(less, a, b) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(kModuleStr)); + TF_ASSERT_OK(verifier().Run(module.get())); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/host_memory_transfer_asyncifier.h b/third_party/xla/xla/service/host_memory_transfer_asyncifier.h index 0f42d8f1cfa019..d2677f2ab2948e 100644 --- a/third_party/xla/xla/service/host_memory_transfer_asyncifier.h +++ b/third_party/xla/xla/service/host_memory_transfer_asyncifier.h @@ -15,44 +15,7 @@ limitations under the License. #ifndef XLA_SERVICE_HOST_MEMORY_TRANSFER_ASYNCIFIER_H_ #define XLA_SERVICE_HOST_MEMORY_TRANSFER_ASYNCIFIER_H_ -#include - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -/* -This pass finds copies between the host memory and device memory and converts -them into the async ops. This includes, but is not limited to: - - device to host DynamicUpdateSlice - - host to device DynamicSlice -* The examples below are not yet supported * - - host to device DynamicUpdateSlice - - device to host DynamicSlice - - host to device Copy - - device to host Copy -*/ -class HostMemoryTransferAsyncifier : public HloModulePass { - public: - explicit HostMemoryTransferAsyncifier(int64_t host_memory_space_color) - : kHostMemorySpaceColor(host_memory_space_color) {} - ~HostMemoryTransferAsyncifier() override = default; - - absl::string_view name() const override { - return "host-memory-transfer-asyncifier"; - } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - const int64_t kHostMemorySpaceColor; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/host_memory_transfer_asyncifier.h" #endif // XLA_SERVICE_HOST_MEMORY_TRANSFER_ASYNCIFIER_H_ diff --git a/third_party/xla/xla/service/host_offload_legalize.h b/third_party/xla/xla/service/host_offload_legalize.h index c04a9cd549d26d..181c82e269a183 100644 --- a/third_party/xla/xla/service/host_offload_legalize.h +++ b/third_party/xla/xla/service/host_offload_legalize.h @@ -15,48 +15,7 @@ #ifndef XLA_SERVICE_HOST_OFFLOAD_LEGALIZE_H_ #define XLA_SERVICE_HOST_OFFLOAD_LEGALIZE_H_ -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/service/hlo_alias_analysis.h" - -namespace xla { - -class HloCostAnalysis; - -// This pass legalizes the graph for the "host memory offloading" pass to -// correctly identified buffers that are meant to be move on the host. Any -// legalization that could block that is welcome into this pass. -class HostOffloadLegalize : public HloModulePass { - public: - explicit HostOffloadLegalize(int64_t host_memory_space_color, - bool after_layout) - : kHostMemorySpaceColor(host_memory_space_color), - after_layout_(after_layout) {} - ~HostOffloadLegalize() override = default; - - absl::string_view name() const override { return "host-offload-legalize"; } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - const int64_t kHostMemorySpaceColor; - const bool after_layout_; - - // For any memory offloaded to the host, return the instruction which is the - // start of such and offload. These will either be "MoveToHost" annotations or - // entry computation parameters. - std::vector FindStartingInstructionsOfHostMemoryOffload( - HloModule* module, - const absl::flat_hash_set& execution_threads) const; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/host_offload_legalize.h" #endif // XLA_SERVICE_HOST_OFFLOAD_LEGALIZE_H_ diff --git a/third_party/xla/xla/service/host_offload_utils.cc b/third_party/xla/xla/service/host_offload_utils.cc index 98732043da3a65..d391f3a3446207 100644 --- a/third_party/xla/xla/service/host_offload_utils.cc +++ b/third_party/xla/xla/service/host_offload_utils.cc @@ -34,6 +34,7 @@ limitations under the License. #include "xla/service/call_graph.h" #include "xla/service/host_memory_offload_annotations.h" #include "xla/shape_util.h" +#include "xla/side_effect_util.h" #include "xla/util.h" namespace xla { @@ -256,5 +257,14 @@ bool IsSynchronousCopyFromOrToHost(const HloInstruction* instruction) { Layout::kHostMemorySpace); } +bool ComputeTypeIsHost(const HloInstruction* hlo_instruction) { + const auto& frontend_attributes_map = + hlo_instruction->frontend_attributes().map(); + return (frontend_attributes_map.find(kXlaComputeTypeAttr) != + frontend_attributes_map.end() && + frontend_attributes_map.find(kXlaComputeTypeAttr)->second == + kXlaComputeTypeHost); +} + } // namespace host_offload_utils } // namespace xla diff --git a/third_party/xla/xla/service/host_offload_utils.h b/third_party/xla/xla/service/host_offload_utils.h index c71dcc21e4ea83..615f331b385513 100644 --- a/third_party/xla/xla/service/host_offload_utils.h +++ b/third_party/xla/xla/service/host_offload_utils.h @@ -101,6 +101,8 @@ bool IsHostAsyncStart(const HloInstruction* instruction); // Returns true if the copy is from or to host memory space. bool IsSynchronousCopyFromOrToHost(const HloInstruction* instruction); +bool ComputeTypeIsHost(const HloInstruction* hlo_instruction); + } // namespace host_offload_utils } // namespace xla diff --git a/third_party/xla/xla/service/host_offloader.h b/third_party/xla/xla/service/host_offloader.h index bcc761ea2fafab..0f68eb631fc033 100644 --- a/third_party/xla/xla/service/host_offloader.h +++ b/third_party/xla/xla/service/host_offloader.h @@ -15,151 +15,7 @@ #ifndef XLA_SERVICE_HOST_OFFLOADER_H_ #define XLA_SERVICE_HOST_OFFLOADER_H_ -#include -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/service/hlo_alias_analysis.h" -#include "xla/service/hlo_buffer.h" -#include "xla/service/host_offload_utils.h" - -namespace xla { - -class HloCostAnalysis; - -// This pass does "host memory offloading". If a tensor is annotated to be moved -// to or from the host, this pass will remove the annotations and update each -// tensor's layout with host memory spaces and insert copies if necessary. This -// pass checks to make sure that no compute is done on the tensors annotated for -// host memory offload; if there is compute, it is considered a user error and -// an error will be returned. -// The pass will "walk down" the Hlo graph starting from either MoveToHost -// custom calls or from parameters with host memory space in their layout. All -// tensors along each path have their memory space set as host memory space. If -// a MoveToHost custom call is paired with a DynamicUpdateSlice, the -// DynamicUpdateSlice will write into host memory space. Otherwise, a copy from -// device to host will be inserted. -// -// If an output of a host offloaded computation is only used on host, the memory -// space of the usages are updated to reflect it and no copies to and from host -// are performed. Any MoveToHost instructions for outputs used only on host, are -// removed. -// TODO(b/347101407): A better approach could be to remove redundant copies in a -// generalized fashion. Should also be moved out of Host Offloader. -// -// All MoveToHost and MoveToDevice custom calls are removed by the end of this -// pass. -class HostOffloader : public HloModulePass { - public: - explicit HostOffloader(int64_t host_memory_space_color) - : kHostMemorySpaceColor(host_memory_space_color) {} - ~HostOffloader() override = default; - - absl::string_view name() const override { return "host-offloader"; } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - const int64_t kHostMemorySpaceColor; - absl::flat_hash_set - already_visited_move_to_host_custom_calls_; - absl::flat_hash_set dynamic_update_slices_already_allocated_; - absl::flat_hash_set validated_slices_; - absl::flat_hash_map copies_created_after_; - absl::flat_hash_set move_to_device_custom_calls_to_remove_; - absl::flat_hash_set - already_inserted_copy_before_; - - // Sometimes previous transformations turn a DynamicSlice into a Slice. Since - // we're doing a DMA between the host and device, we need to turn the Slice - // back into a DynamicSlice. - absl::StatusOr DynamifySlice(HloInstruction* slice); - - // Returns true if the instruction is allowed to be in the - // middle of a path between a MoveToHost custom-call annotation and a - // DynamicUpdateSlice. Ideally the custom-call should be immediately followed - // by the DynamicUpdateSlice, but this is not always the case. - bool InstructionIsAllowedBetweenMoveToHostAndDus( - const HloInstruction* instruction) const; - - // Returns true if the instruction is allowed to be in the - // middle of a path between a DynamicSlice and a MoveToDevice custom-call - // annotation. Ideally the DynamicSlice should be immediately followed by the - // custom-call, but this is not always the case. - bool InstructionIsAllowedBetweenDsAndMoveToDevice( - const HloInstruction* instruction) const; - - // Walks down the graph and does "host memory offloading" starting from every - // host memory parameter in the entry computation. - absl::StatusOr HandleInputStreaming(HloComputation* entry_computation); - - // Walks down the graph and does "host memory offloading" starting from every - // MoveToHost custom call. - absl::StatusOr HandleMoveToHostCustomCall( - HloInstruction* custom_call_instruction); - - // Since we always walk the graph from the top down, this function only needs - // to remove these lingering custom calls. This function should only be called - // once all host memory offloading is done because multiple paths might lead - // to the same MoveToDevice custom call. Removing it too early will confuse - // subsequent walkings of the graph. - absl::StatusOr HandleMoveToDeviceCustomCall( - HloInstruction* custom_call_instruction); - - // DynamicUpdateSlices which write into host memory must have their - // destination buffer allocated on the host. This function creates the - // allocation and updates all positions to have host memory space. - absl::Status CreateAllocateBufferForDynamicUpdateSlice( - HloInstruction* dynamic_update_slice); - - // Returns an error if something unallowed exists between the - // Slice/DynamicSlice and the MoveToDevice custom call. - absl::Status ValidateSliceLeadsToMoveToDeviceCustomCall( - HloInstruction* slice); - - // Common function for doing the actual walking of the graph. Host memory - // spaces are set and copies are inserted in here. - absl::StatusOr WalkDownHostMemoryOffloadPaths( - const host_offload_utils::InstructionAndShapeIndex& - starting_instruction_and_index, - bool insert_copy_before); - - // Given a custom call, this returns the first instruction and shape index to - // start the host memory offload path from for each use of the custom call. - absl::StatusOr> - GetStartingInstructions(HloInstruction* custom_call_instruction); - - // When a MoveToHost custom call is not paired with a DynamicUpdateSlice, a - // copy from device to host must be inserted. - absl::StatusOr InsertCopyBetween( - const host_offload_utils::InstructionAndShapeIndex& - before_instruction_and_index, - const host_offload_utils::InstructionAndShapeIndex& - after_instruction_and_index); - - // This is a fix for scheduling. Add copies to inputs of dynamic-update-slice - // if the inserted value is directly a parameter of a computation. This is to - // avoid cases in while loop where parameter/output aliasing can stop - // scheduling because control-dependencies are added. - absl::StatusOr ApplySchedulingFix( - HloModule* module, - const absl::flat_hash_set& execution_threads); - - // Starting from the outputs of the host offloaded computation, track all - // their usages. For the outputs that are ONLY used on host, remove redundant - // copies to and from host, as well as update the memory space. - absl::StatusOr HandleRedundantCopiesBackToHost( - const HloModule* module, HloInstruction* instruction); -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/host_offloader.h" #endif // XLA_SERVICE_HOST_OFFLOADER_H_ diff --git a/third_party/xla/xla/service/host_offloading_prepare.h b/third_party/xla/xla/service/host_offloading_prepare.h index cb0bfe04078f11..016bfadb46bad7 100644 --- a/third_party/xla/xla/service/host_offloading_prepare.h +++ b/third_party/xla/xla/service/host_offloading_prepare.h @@ -16,75 +16,7 @@ #ifndef XLA_SERVICE_HOST_OFFLOADING_PREPARE_H_ #define XLA_SERVICE_HOST_OFFLOADING_PREPARE_H_ -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// This is a collection of rewrites that prepares HLO module for host -// offloading, mainly to work around different limitation of the compilation -// pipeline and runtime. These rewrites can be placed in a different parts of -// the overall compilation pipeline to prepare HLO module for host offloading -// for the given backend (different backends have different limitations). -class HostOffloadingPrepare : public HloModulePass { - public: - enum class Rewrite { - // Currently host compute offloading requires that all temporary inputs are - // in device memory. If they are streamed inputs (inputs to the entry - // computation), they can be in either device or host memory. - // - // This rewrite removes `MoveToHost` custom calls that feed directly into - // the computation offloading to the host. - kElideMoveToHost, - - // Currently host compute offloading does not support tiled layouts, and - // because of that layouts on the call instruction arguments might be - // different from the layouts in the called computation body. - // - // Host offloading handles layout mismatches at run time by delinearizing - // arguments and linearizing results on the fly. - // - // To keep HLO module valid we rewrite calls to host offloaded computations - // into custom calls with the only purpose to suppress verification error. - // Host offloading compiler later does its own verification to check that - // arguments are compatible with parameters in the offloaded computation and - // knows how to handle mismatched layouts. - kConvertToCustomCall, - }; - - static std::string RewriteName(Rewrite rewrite) { - switch (rewrite) { - case Rewrite::kElideMoveToHost: - return "elide-move-to-host"; - case Rewrite::kConvertToCustomCall: - return "convert-to-custom-call"; - } - } - - explicit HostOffloadingPrepare(Rewrite rewrite) - : rewrite_(rewrite), - pass_name_(absl::StrCat("host-offloading-prepare", "-", - RewriteName(rewrite_))) {} - - absl::string_view name() const override { return pass_name_; } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - Rewrite rewrite_; - std::string pass_name_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/host_offloading_prepare.h" #endif // XLA_SERVICE_HOST_OFFLOADING_PREPARE_H_ diff --git a/third_party/xla/xla/service/indexed_array_analysis.h b/third_party/xla/xla/service/indexed_array_analysis.h index 1f0f451b1cf070..6dbfd2a1eccf74 100644 --- a/third_party/xla/xla/service/indexed_array_analysis.h +++ b/third_party/xla/xla/service/indexed_array_analysis.h @@ -16,380 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ #define XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/log/check.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/literal.h" -#include "xla/shape.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/statusor.h" - -namespace xla { - -// IndexedArrayAnalysis decides if an HLO instruction can be rewritten as a -// gather from another array. It does this by mapping HLO instructions to -// instances of IndexedArrayAnalysis::Array, which can be inspected to discover -// whether said HLO is equivalent to a gather. -class IndexedArrayAnalysis { - public: - // IndexedArrayAnalysis maps each HLO instruction to an instance of a Array. - // Array really just a sum type of the classes that inherit from it. The - // meaning of each of the subtypes is documented on the subtype declaration. - // - // Array instances are immutable once created. - class Array { - public: - enum Kind { - kUnknown, - kConstant, - kReshaped, - kScalarIndexedConstant, - kScalarIndexed - }; - - virtual Kind kind() const = 0; - virtual const Shape& shape() const = 0; - - // Does a checked downcast from `Array` to `T` which must be one of its - // subtypes. - template - T* as() { - static_assert((std::is_base_of::value), - "target type not derived from source type"); - // We skip the CHECK and hence the dynamic_cast if RTTI is disabled. -#if !defined(__GNUC__) || defined(__GXX_RTTI) - CHECK_NE(dynamic_cast(this), nullptr); -#endif // !defined(__GNUC__) || defined(__GXX_RTTI) - - return static_cast(this); - } - - virtual ~Array() = default; - - Array& operator=(const Array& other) = delete; - }; - - // Represents an HLO instruction that was not analyzable by this - // IndexedArrayAnalysis. Instances of UnknownArray just wrap an existing - // HloInstruction. - class UnknownArray : public Array { - public: - Kind kind() const override { return kUnknown; } - const Shape& shape() const override { return instruction().shape(); } - const HloInstruction& instruction() const { return instruction_; } - - private: - explicit UnknownArray(const HloInstruction* instr) : instruction_(*instr) {} - - const HloInstruction& instruction_; - - friend class IndexedArrayAnalysis; - }; - - // Represents a constant value. This constant value may be present in the HLO - // module being analyzed, or it could have been created on the fly by the - // analysis. - class ConstantArray : public Array { - public: - Kind kind() const override { return kConstant; } - const Shape& shape() const override { return literal()->shape(); } - const Literal* literal() const { return literal_; } - - private: - explicit ConstantArray(const Literal* literal) : literal_(literal) {} - const Literal* literal_; - - friend class IndexedArrayAnalysis; - }; - - // Represents an Array that is a reshape of another Array. - class ReshapedArray : public Array { - public: - Kind kind() const override { return kReshaped; } - - // The array to reshape. - Array* operand() const { return operand_; } - - // The output shape. - const Shape& shape() const override { return shape_; } - - private: - explicit ReshapedArray(Array* operand, Shape shape) - : operand_(operand), shape_(shape) {} - - Array* operand_; - const Shape shape_; - - friend class IndexedArrayAnalysis; - }; - - // --------------------------------------------------------------------------- - // Indexed Array Overview - // --------------------------------------------------------------------------- - // - // ScalarIndexedArray and ScalarIndexedConstantArray form the core of this - // analysis. ScalarIndexedConstantArray is just a specialization of - // ScalarIndexedArray so we will only discuss ScalarIndexedArray in this - // overview. - // - // A ScalarIndexedArray represents an array that can be computed by indexing - // into a "source" array using an "indices" tensor. A simple example is a - // gather operation gathering 12 rows out of a [100,100] matrix -- such an - // operation will be represented by an instance of a ScalarIndexedArray with - // the [100,100] matrix as the "source" array and the [12]-shaped indices - // array as the "indices" tensor. The ScalarIndexedArray operation itself - // will be of shape [12,100] (assuming we were gathering with axis=0). - // - // Gather operations are not the only operation that maps to - // ScalarIndexedArray instances (if that were true there would be little point - // in having a separate analysis). We can often infer ScalarIndexedArrays for - // other operations too. For instance, consider: - // - // %source = f32[100,100] constant - // %indices = s32[12] ... - // %gather = f32[12,100] ... gather from %source using %indices at axis 0 - // %dot = dot(%gather, other_constant) [canonical contracting dims] - // - // The dot operation itself is also a ScalarIndexedArray with source = - // dot(constant, other_constant) and indices = %indices. A reshape of %gather - // to [12,5,20] too is a ScalarIndexedArray with source = an appropriately - // reshaped constant and indices = %indices. - - // Represents the result of a gather operation. This gather operation may - // explicitly be present in the HLO module being analyzed, or it could have - // been created on the fly by the analysis. - // - // An instance of ScalarIndexedArray represents a array whose I'th element can - // be mapped to the J'th element of the `source` array (where I and J are - // multidimensional indices) in this way: - // - // I' = remove components at positions `output_dims` from I - // G' = remove components not at positions `output_dims` from I - // T = indices[G'] - // J = I' with T inserted at position `source_dim` - // - // For example, if source is of shape [11,13,17,19], indices is of shape - // [23,29], output_dims is [0,2] and source_dim is 2 then the output is of - // shape [23,11,29,13,19] and the output index [A,B,C,D,E] is mapped to the - // input index [B,D,indices[A,C],E]. - class ScalarIndexedArray : public Array { - public: - Kind kind() const override { return kScalarIndexed; } - const Shape& shape() const override { return shape_; } - - Array* source() const { return source_; } - Array* indices() const { return indices_; } - - // `source_dim` is the dimension in the source array that is being indexed - // over using indices from the `indices` array. See the class documentation - // and the overview for more details. - int64_t source_dim() const { return source_dim_; } - - // `output_dims` are the dimensions in the output array that are being used - // to compute an index into the `indices` array. See the class - // documentation and the overview for more details. - absl::Span output_dims() const { return output_dims_; } - - private: - explicit ScalarIndexedArray(Array* source, Array* indices, - int64_t source_dim, - std::vector output_dims, Shape shape) - : source_(source), - indices_(indices), - source_dim_(source_dim), - output_dims_(std::move(output_dims)), - shape_(std::move(shape)) {} - - Array* source_; - Array* indices_; - int64_t source_dim_; - std::vector output_dims_; - Shape shape_; - - friend class IndexedArrayAnalysis; - }; - - // A ScalarIndexedConstantArray is just a ScalarIndexedArray constrained to - // have a ConstantArray instance as the source. This is an ergonomic - // concession -- in theory it is possible to just keep ScalarIndexedArray and - // check source()->kind(). - class ScalarIndexedConstantArray : public ScalarIndexedArray { - public: - Kind kind() const override { return kScalarIndexedConstant; } - - const Literal& literal() const { - return *source()->as()->literal(); - } - - private: - explicit ScalarIndexedConstantArray(Array* source, Array* indices, - int64_t source_dim, - std::vector output_dims, - Shape shape) - : ScalarIndexedArray(source, indices, source_dim, - std::move(output_dims), std::move(shape)) { - CHECK(dynamic_cast(source)); - } - - friend class IndexedArrayAnalysis; - }; - - // Returns an Array instance for `instr`. The IndexedArrayAnalysis instance - // keeps ownership of the returned Array instance. - // - // Caching Behavior: IndexedArrayAnalysis has a cache mapping HLO - // instructions to IndexedArrayAnalysis::Array instances. This entire cache - // becomes stale and may cause the analysis to return incorrect results if any - // transitive operand (stopping at the containing computation) is modified for - // any HLO instruction on which GetArrayFor has been invoked. - // - // NB! By inspecting the implementation, you may be able to infer a stronger - // caching guarantee than what is mentioned above. Nevertheless, what is - // stated above is the contract. - absl::StatusOr GetArrayFor(const HloInstruction* instr); - - // Pretty-prints the expression rooted at `root`. - std::string ToString(Array* root, bool print_constants = false); - - private: - // Helper function that ensures that every HLO instruction that is - // transitively used by `root` has an entry in `cache_`. - absl::Status TraverseAndPopulateCache(const HloInstruction* root); - - // Creates an Array instance for `instr` under the assumption that all - // operations of `instr` are present in `cache_`. - absl::StatusOr ComputeArrayFor(const HloInstruction* instr); - - absl::StatusOr ComputeArrayForConstant(const Literal& literal); - - absl::StatusOr ComputeArrayForGather( - const Shape& shape, const GatherDimensionNumbers& dim_numbers, - absl::Span slice_sizes, Array* source, Array* indices); - - absl::StatusOr ComputeArrayForDotWithIndexedLhs( - const Shape& shape, const DotDimensionNumbers& dim_numbers, - const PrecisionConfig& precision_config, ScalarIndexedConstantArray* lhs, - ConstantArray* rhs); - - absl::StatusOr ComputeArrayForDotWithIndexedRhs( - const Shape& shape, const DotDimensionNumbers& dim_numbers, - const PrecisionConfig& precision_config, ConstantArray* lhs, - ScalarIndexedConstantArray* rhs); - - absl::StatusOr ComputeArrayForDot( - const Shape& shape, const DotDimensionNumbers& dim_numbers, - const PrecisionConfig& precision_config, Array* lhs, Array* rhs); - - // This tries to fold a ScalarIndexedArray which has another - // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a - // ScalarIndexedArray as indices. If `source` happened to be a - // ScalarIndexedConstantArray this can result in an expression that is more - // canonical. - // - // As an example, consider a gather operation, G0, gathering 7 elements from - // an array "Arr" of shape [100] resulting in an array of shape [7], and a - // second gather operation, G1, which gathers 3 elements out of the result of - // G0 resulting in an array of shape [3]. Let the indices uses by G0 be I0 - // (of shape [7]) and the indices used by G1 be I1 (of shape [3]). We can - // instead rewrite G1 to gather directly from "Arr" with the three indices - // from I0 as per I1. In other words, we can rewrite: - // - // G0 = [Arr[i] for i in I0] - // G1 = [G0[i] for i in I1] - // - // into - // - // I2 = [I0[i] for i in I1] - // G1 = [Arr[i] for i in I2] - absl::StatusOr FoldGatherOfGather( - ScalarIndexedArray* source, Array* indices, int64_t source_dim, - absl::Span output_dims, Shape shape); - - // Reshapes a scalar-indexed node to remove the degenerate dimensions in its - // output. The result is always a scalar-indexed node. - absl::StatusOr ReshapeToRemoveDegenerateDims( - ScalarIndexedArray* operand); - - // Reshapes a scalar-indexed node such that the result has the degenerate - // dimensions `degenerate_dims`. The result is always a scalar-indexed node. - absl::StatusOr ReshapeToAddDegenerateDims( - ScalarIndexedArray* operand, absl::Span degenerate_dims); - - absl::StatusOr FoldReshapeOfGather( - const Shape& shape, ScalarIndexedConstantArray* operand); - absl::StatusOr FoldReshapeOfGatherNoDegenerateDims( - const Shape& shape, ScalarIndexedConstantArray* scalar_indexed); - absl::StatusOr ComputeArrayForReshape(const Shape& shape, - Array* operand); - - absl::StatusOr ComputeArrayForElementwiseBinaryOp(HloOpcode opcode, - Array* lhs, - Array* rhs); - absl::StatusOr ComputeArrayForElementwiseUnaryOp(HloOpcode opcode, - Array* operand); - - template - T* Construct(Args&&... args) { - T* new_tensor = new T(std::forward(args)...); - owned_tensors_.push_back(std::unique_ptr(new_tensor)); - return new_tensor; - } - - ScalarIndexedArray* ConstructScalarIndexedArray( - Array* source, Array* indices, int64_t source_dim, - std::vector output_dims, Shape shape) { - if (source->kind() == Array::kConstant) { - return Construct(source, indices, source_dim, - std::move(output_dims), - std::move(shape)); - } else { - return Construct(source, indices, source_dim, - std::move(output_dims), - std::move(shape)); - } - } - - Literal* TakeOwnership(Literal literal) { - owned_literals_.push_back(std::move(literal)); - return &owned_literals_.back(); - } - - absl::StatusOr TakeOwnership( - absl::StatusOr literal_or_error) { - TF_ASSIGN_OR_RETURN(Literal literal, std::move(literal_or_error)); - owned_literals_.push_back(std::move(literal)); - return &owned_literals_.back(); - } - - std::vector> owned_tensors_; - std::vector owned_literals_; - absl::flat_hash_map cache_; -}; - -// A pass that prints all non-trivial results returned by IndexedArrayAnalysis. -// This pass is a no-op if !VLOG_IS_ON(2) so it should be fine to -// unconditionally add to the regular HLO pass pipeline. -class IndexedArrayAnalysisPrinterPass : public HloModulePass { - public: - absl::string_view name() const override { - return "indexed-array-analysis-printer-pass"; - } - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/analysis/indexed_array_analysis.h" #endif // XLA_SERVICE_INDEXED_ARRAY_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/infeed_token_propagation.h b/third_party/xla/xla/service/infeed_token_propagation.h new file mode 100644 index 00000000000000..31a0aa19ed8c07 --- /dev/null +++ b/third_party/xla/xla/service/infeed_token_propagation.h @@ -0,0 +1,22 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_INFEED_TOKEN_PROPAGATION_H_ +#define XLA_SERVICE_INFEED_TOKEN_PROPAGATION_H_ + +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/collectives/infeed_token_propagation.h" + +#endif // XLA_SERVICE_INFEED_TOKEN_PROPAGATION_H_ diff --git a/third_party/xla/xla/service/instruction_fusion.cc b/third_party/xla/xla/service/instruction_fusion.cc index e96d811f4ba460..d299e9e7d31d13 100644 --- a/third_party/xla/xla/service/instruction_fusion.cc +++ b/third_party/xla/xla/service/instruction_fusion.cc @@ -38,12 +38,12 @@ limitations under the License. #endif // PLATFORM_GOOGLE #include "absl/types/span.h" #include "xla/debug_options_flags.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" +#include "xla/hlo/analysis/hlo_reachability.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/ir/hlo_reachability.h" #include "xla/map_util.h" #include "xla/service/fusion_queue.h" -#include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_graph_dumper.h" #include "xla/service/hlo_module_config.h" #include "xla/service/pattern_matcher.h" @@ -210,6 +210,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kMap: case HloOpcode::kParameter: case HloOpcode::kPower: + case HloOpcode::kRaggedAllToAll: case HloOpcode::kRecv: case HloOpcode::kRecvDone: case HloOpcode::kReduce: @@ -614,7 +615,7 @@ absl::StatusOr InstructionFusion::Run( << use_regular_fusion.Explain(); } - FusionDecision use_mof; + FusionDecision use_mof = FusionDecision::Allow(); if (!use_regular_fusion) { use_mof = ShouldFuseIntoMultiOutput(instruction, i); if (use_mof) { @@ -718,8 +719,11 @@ HloInstruction* InstructionFusion::AddFusionInstruction( fusion_instruction->set_fusion_kind(kind); } } else { - fusion_instruction = computation->AddInstruction( - HloInstruction::CreateFusion(consumer->shape(), kind, consumer)); + fusion_instruction = + computation->AddInstruction(HloInstruction::CreateFusion( + consumer->shape(), kind, consumer, + absl::StrCat(HloOpcodeString(producer->opcode()), "_", + HloOpcodeString(consumer->opcode()), "_"))); TF_CHECK_OK(computation->ReplaceInstruction(consumer, fusion_instruction)); } fusion_instruction->set_called_computations_execution_thread( @@ -911,8 +915,9 @@ bool IsSafeToFuseSliceIntoDusFusion(const HloInstruction* producer, for (int i = 0; i < consumer->operand_count(); ++i) { if (i != operand_number && consumer->operand(operand_number) == consumer->operand(i)) { - return "The consumer is an in-place operation that has an additional " - "operand that has the same value as the in-place buffer"; + return FusionDecision::Forbid( + "The consumer is an in-place operation that has an additional " + "operand that has the same value as the in-place buffer"); } } if (consumer->operand(operand_number) == producer || @@ -949,12 +954,13 @@ bool IsSafeToFuseSliceIntoDusFusion(const HloInstruction* producer, return is_nonelementwise_op(inst); }); if (producer_nonelementwise_ops.size() > 1) { - return "Producer fusion has multiple non-elementwise ops, bailing."; + return FusionDecision::Forbid( + "Producer fusion has multiple non-elementwise ops, bailing."); } // If the producer has only elementwise ops or bitcasts, we can fuse. if (producer_nonelementwise_ops.empty()) { if (consumer->opcode() != HloOpcode::kFusion) { - return {}; + return FusionDecision::Allow(); } // If the consumer fusion has both elementwise and non-elementwise ops, // and ops of the two groups access the same buffer of the producer, we @@ -980,9 +986,10 @@ bool IsSafeToFuseSliceIntoDusFusion(const HloInstruction* producer, instr); }); return inplace_conflict_after_fusion - ? "Non-elementwise ops in consumer lead to inplace conflict " - "after fusion." - : FusionDecision(); + ? FusionDecision::Forbid( + "Non-elementwise ops in consumer lead to inplace " + "conflict after fusion.") + : FusionDecision::Allow(); } auto dus_ops = ExtractInstructions(consumer, HloOpcode::kDynamicUpdateSlice); @@ -991,17 +998,18 @@ bool IsSafeToFuseSliceIntoDusFusion(const HloInstruction* producer, // TODO(akuegel): Are there other ops than dynamic update slice where we // have a special emitter if it can be done in-place? if (dus_ops.empty()) { - return {}; + return FusionDecision::Allow(); } if (dus_ops.size() > 1) { - return "multiple dus ops, bailing."; + return FusionDecision::Forbid("multiple dus ops, bailing."); } auto dus = dus_ops[0]; auto producer_nonelementwise = producer_nonelementwise_ops[0]; if (producer_nonelementwise->opcode() == HloOpcode::kSlice) { if (producer_nonelementwise->shape() != dus->operand(1)->shape()) { - return "Slice op has a different shape than the update shape of the " - "dus op, bailing."; + return FusionDecision::Forbid( + "Slice op has a different shape than the update shape of the " + "dus op, bailing."); } for (int i = 0; i < dus->shape().rank(); ++i) { const HloInstruction* dus_operand = @@ -1010,21 +1018,23 @@ bool IsSafeToFuseSliceIntoDusFusion(const HloInstruction* producer, if (!constant_operand || *constant_operand != producer_nonelementwise->slice_starts(i) || producer_nonelementwise->slice_strides(i) != 1) { - return "DUS and slice index mismatch"; + return FusionDecision::Forbid("DUS and slice index mismatch"); } } VLOG(4) << "DUS and slice index match"; if (consumer->opcode() == HloOpcode::kFusion && !IsSafeToFuseSliceIntoDusFusion(producer, consumer, dus)) { - return "Fusing slice into DUS will also fuse another non-elementwise " - "op with shared operand as DUS."; + return FusionDecision::Forbid( + "Fusing slice into DUS will also fuse another non-elementwise " + "op with shared operand as DUS."); } - return {}; + return FusionDecision::Allow(); } if (producer_nonelementwise->opcode() == HloOpcode::kDynamicSlice) { if (producer_nonelementwise->shape() != dus->operand(1)->shape()) { - return "Dynamic slice op has a different shape than the update shape " - "of the dus op, bailing."; + return FusionDecision::Forbid( + "Dynamic slice op has a different shape than the update shape " + "of the dus op, bailing."); } for (int i = 0; i < dus->shape().rank(); ++i) { const HloInstruction* ds_operand = get_real_operand( @@ -1035,21 +1045,23 @@ bool IsSafeToFuseSliceIntoDusFusion(const HloInstruction* producer, auto constant_dus_operand = get_constant_operand(dus_operand); if (constant_ds_operand != constant_dus_operand || (!constant_ds_operand && ds_operand != dus_operand)) { - return "DUS and DS index mismatch"; + return FusionDecision::Forbid("DUS and DS index mismatch"); } } VLOG(4) << "DUS and DS index match"; if (consumer->opcode() == HloOpcode::kFusion && !IsSafeToFuseSliceIntoDusFusion(producer, consumer, dus)) { - return "Fusing DS into DUS will also fuse another non-elementwise op " - "with shared operand as DUS."; + return FusionDecision::Forbid( + "Fusing DS into DUS will also fuse another non-elementwise op " + "with shared operand as DUS."); } - return {}; + return FusionDecision::Allow(); } - return "unrecognized inplace update non-elementwise output pair"; + return FusionDecision::Forbid( + "unrecognized inplace update non-elementwise output pair"); } } - return {}; + return FusionDecision::Allow(); } FusionDecision InstructionFusion::ShouldFuse(HloInstruction* consumer, @@ -1065,15 +1077,17 @@ FusionDecision InstructionFusion::ShouldFuse( // Don't fuse across a root instruction. if (producer == producer->parent()->root_instruction()) { - return "not fusing into the output of the root instruction"; + return FusionDecision::Forbid( + "not fusing into the output of the root instruction"); } // Cost condition: don't duplicate expensive instructions. if (FusionWouldDuplicate(*producer, *consumer) && (!may_duplicate_ || is_expensive_(*producer)) && !IsAlwaysDuplicable(*producer)) { - return may_duplicate_ ? "expensive producer would be duplicated" - : "fusion pass cannot duplicate"; + return FusionDecision::Forbid(may_duplicate_ + ? "expensive producer would be duplicated" + : "fusion pass cannot duplicate"); } return inplace_op_fusion_decider(producer, consumer); } diff --git a/third_party/xla/xla/service/instruction_fusion.h b/third_party/xla/xla/service/instruction_fusion.h index c4952aea15c2ae..0fa29ae93f01c9 100644 --- a/third_party/xla/xla/service/instruction_fusion.h +++ b/third_party/xla/xla/service/instruction_fusion.h @@ -33,10 +33,10 @@ limitations under the License. #if defined(PLATFORM_GOOGLE) #include "absl/types/source_location.h" #endif // PLATFORM_GOOGLE +#include "xla/hlo/analysis/hlo_reachability.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_reachability.h" #include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/service/fusion_queue.h" @@ -46,16 +46,11 @@ namespace xla { // explain the reason. class FusionDecision { public: - // Can not be fused: explain why. Implicit conversion due to optional-like - // semantics: waiver granted in cl/419938611. - FusionDecision(absl::string_view explanation) // NOLINT - : explanation_(explanation) {} - - // Same constructor as string_view, to allow implicit string conversion (can't - // implicitly convert both char* to string_view and string_view to - // FusionDecision). - FusionDecision(const char* explanation) // NOLINT - : explanation_(explanation) {} + static FusionDecision Allow() { return FusionDecision(); } + static FusionDecision Forbid(absl::string_view explanation) { + return FusionDecision(explanation); + } + FusionDecision(const FusionDecision& decision) = default; // If condition is `true` means that we CAN fuse. In that case, explanation is // discarded. @@ -74,9 +69,6 @@ class FusionDecision { absl::SourceLocation source_location = absl::SourceLocation::current()); #endif // PLATFORM_GOOGLE - // Can be fused. - FusionDecision() = default; - // Returns whether it can be fused. explicit operator bool() const { return CanFuse(); } @@ -88,9 +80,10 @@ class FusionDecision { // them is false to show why fusion wasn't performed. FusionDecision Or(const FusionDecision& decision) const { if (CanFuse() || decision.CanFuse()) { - return {}; + return Allow(); } - return {absl::StrCat(explanation_.value_or(""), " ; ", decision.Explain())}; + return Forbid( + absl::StrCat(explanation_.value_or(""), " ; ", decision.Explain())); } // Connects two fusion decision with a conjunction. Unlike disjunction, @@ -109,12 +102,12 @@ class FusionDecision { // Appends to explanation, or turns the decision negative. FusionDecision operator<<(absl::string_view explanation) const { - return {absl::StrCat(explanation_.value_or(""), explanation)}; + return Forbid(absl::StrCat(explanation_.value_or(""), explanation)); } // Appends to explanation, or turns the decision negative. FusionDecision operator<<(int64_t explanation) const { - return {absl::StrCat(explanation_.value_or(""), explanation)}; + return Forbid(absl::StrCat(explanation_.value_or(""), explanation)); } // Explains why the fusion could not be performed. @@ -123,6 +116,14 @@ class FusionDecision { private: // Empty IFF fusion is possible (explanation provided for negative cases). std::optional explanation_; + + FusionDecision() = default; + + explicit FusionDecision(absl::string_view explanation) + : explanation_(explanation) {} + + explicit FusionDecision(const char* explanation) + : explanation_(explanation) {} }; #define RETURN_IF_NOT_FUSIBLE(...) \ @@ -213,7 +214,8 @@ class InstructionFusion : public HloModulePass { // duplicated by multi-output fusion. virtual FusionDecision ShouldFuseIntoMultiOutput(HloInstruction* consumer, int64_t operand_index) { - return "multi-output fusion not supported by this pass"; + return FusionDecision::Forbid( + "multi-output fusion not supported by this pass"); } // Chooses a fusion kind for `producer` and `consumer`. diff --git a/third_party/xla/xla/service/instruction_fusion_test.cc b/third_party/xla/xla/service/instruction_fusion_test.cc index db6c3244c3932f..d75c9a4fe29d86 100644 --- a/third_party/xla/xla/service/instruction_fusion_test.cc +++ b/third_party/xla/xla/service/instruction_fusion_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/service/hlo_parser.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" @@ -791,20 +791,20 @@ TEST_F(InstructionFusionTest, DontFuseProducerIfInplaceConflict) { class FusionDecisionTest : public HloTestBase {}; TEST_F(FusionDecisionTest, NotFusionPossibleDisjunction) { - FusionDecision a = {}; - FusionDecision b = "not possible"; + FusionDecision a = FusionDecision::Allow(); + FusionDecision b = FusionDecision::Forbid("not possible"); EXPECT_TRUE(!a || !b); - a = "not possible"; - b = {}; + a = FusionDecision::Forbid("not possible"); + b = FusionDecision::Allow(); EXPECT_TRUE(!a || !b); - a = "impossible"; - b = "very impossible"; + a = FusionDecision::Forbid("impossible"); + b = FusionDecision::Forbid("very impossible"); EXPECT_TRUE(!a || !b); - a = {}; - b = {}; + a = FusionDecision::Allow(); + b = FusionDecision::Allow(); EXPECT_FALSE(!a || !b); } diff --git a/third_party/xla/xla/service/instruction_hoister.h b/third_party/xla/xla/service/instruction_hoister.h index a52598e0520376..bd002321eecf92 100644 --- a/third_party/xla/xla/service/instruction_hoister.h +++ b/third_party/xla/xla/service/instruction_hoister.h @@ -16,35 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_INSTRUCTION_HOISTER_H_ #define XLA_SERVICE_INSTRUCTION_HOISTER_H_ -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// HLO pass that hoists parameters and constants to increase opportunities for -// prefetching. -class InstructionHoister : public HloModulePass { - public: - explicit InstructionHoister(bool hoist_parameters = true, - bool host_constants = true) - : hoist_parameters_(hoist_parameters), host_constants_(host_constants) {} - - ~InstructionHoister() override = default; - - absl::string_view name() const override { return "instruction-hoister"; } - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - bool hoist_parameters_; - bool host_constants_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/instruction_hoister.h" #endif // XLA_SERVICE_INSTRUCTION_HOISTER_H_ diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.cc b/third_party/xla/xla/service/latency_hiding_scheduler.cc index f0e5af7ac3c9a3..b53fae81283425 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler.cc +++ b/third_party/xla/xla/service/latency_hiding_scheduler.cc @@ -40,15 +40,15 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/debug_options_flags.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/analysis/hlo_reachability.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/ir/hlo_reachability.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/map_util.h" #include "xla/service/dump.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_value.h" @@ -73,12 +73,70 @@ bool IsNopInstruction(const HloInstruction& hlo) { (op == HloOpcode::kTuple && hlo.user_count() == 1 && hlo.users().front()->opcode() == HloOpcode::kWhile); } + +bool InstructionDefinesValue(const HloInstruction* instruction, + const HloValue* value) { + if (value->defining_instruction() == instruction) { + return true; + } + if (value->shape().has_layout() && + value->shape().layout().memory_space() != kDefaultMemorySpace) { + return false; + } + // Also check if the instruction is a call to a computation that defines the + // value. This is needed in cases, e.g., where we wrap a value-defining + // instruction in a async call for offloading, and the async start itself will + // effectively define the value in the current scope that the scheduler is + // running in. + if (instruction->opcode() == HloOpcode::kAsyncStart) { + if (instruction->async_wrapped_opcode() == HloOpcode::kCall) { + return instruction->async_wrapped_instruction() + ->called_computations()[0] + ->root_instruction() == value->defining_instruction(); + } + return instruction->async_wrapped_instruction() == + value->defining_instruction(); + } + return false; +} + +bool InstructionFirstDefinesBuffer( + const HloInstruction* instruction, + const BufferInfoTracker::ValueInfo& buffer_value_info) { + if (buffer_value_info.first_definition == instruction) { + return true; + } + if (buffer_value_info.value->values()[0]->shape().has_layout() && + buffer_value_info.value->values()[0]->shape().layout().memory_space() != + kDefaultMemorySpace) { + return false; + } + // Similar to logic above, also check if the instruction is a call to a + // computation that defines the value. + if (instruction->opcode() == HloOpcode::kAsyncStart) { + if (instruction->async_wrapped_opcode() == HloOpcode::kCall) { + return instruction->async_wrapped_instruction() + ->called_computations()[0] + ->root_instruction() == buffer_value_info.first_definition; + } + return instruction->async_wrapped_instruction() == + buffer_value_info.first_definition; + } + return false; +} + } // namespace CanonicalAsyncOp DefaultGetCanonicalAsyncOp(const HloInstruction& hlo) { switch (hlo.opcode()) { case HloOpcode::kAsyncStart: case HloOpcode::kAsyncDone: + if (hlo.async_wrapped_opcode() == HloOpcode::kCall) { + return {hlo.opcode(), hlo.async_wrapped_instruction() + ->called_computations()[0] + ->root_instruction() + ->opcode()}; + } return {hlo.opcode(), hlo.async_wrapped_opcode()}; case HloOpcode::kAllReduceStart: return {HloOpcode::kAsyncStart, HloOpcode::kAllReduce}; @@ -279,12 +337,13 @@ ResourcesVector AsyncTracker::GetResourcesFromInstructionImpl( } } -ResourcesVector AsyncTracker::GetResourcesFromInstruction( +absl::Span AsyncTracker::GetResourcesFromInstruction( const HloInstruction& hlo) const { - if (!resources_cache_.contains(&hlo)) { - resources_cache_.insert({&hlo, GetResourcesFromInstructionImpl(hlo)}); + auto [it, inserted] = resources_cache_.emplace(&hlo, ResourcesVector{}); + if (inserted) { + it->second = GetResourcesFromInstructionImpl(hlo); } - return resources_cache_.at(&hlo); + return it->second; } int64_t AsyncTracker::GetNumResourcesPerInstruction( @@ -293,6 +352,31 @@ int64_t AsyncTracker::GetNumResourcesPerInstruction( instr); } +const absl::flat_hash_map& +AsyncTracker::RecursivelyComputeResourceMap( + const HloComputation* computation) const { + auto& per_opcode_map = async_in_computation_cache_[computation]; + if (per_opcode_map != nullptr) { + return *per_opcode_map; + } + per_opcode_map = std::make_unique>(); + auto* m = per_opcode_map.get(); + for (HloInstruction* instr : computation->instructions()) { + if (IsSupportedAsyncDone(*instr)) { + for (const auto& resource : GetResourcesFromInstruction(*instr)) { + ++(*m)[resource.first]; + } + } + for (const HloComputation* called_comp : instr->called_computations()) { + for (auto& called_per_opcode_pair : + RecursivelyComputeResourceMap(called_comp)) { + (*m)[called_per_opcode_pair.first] += called_per_opcode_pair.second; + } + } + } + return *m; +} + int64_t AsyncTracker::GetNumResourcesPerInstruction( int64_t resource_type, const HloInstruction& instr) const { // For instructions not calling a computation then return 1 if the instruction @@ -309,45 +393,13 @@ int64_t AsyncTracker::GetNumResourcesPerInstruction( ? 1 : 0; } - std::function recursively_compute_resource_map = - [this, - &recursively_compute_resource_map](const HloComputation* computation) { - absl::flat_hash_map per_opcode_map; - for (HloInstruction* instr : computation->instructions()) { - if (IsSupportedAsyncDone(*instr)) { - for (auto& resource : GetResourcesFromInstruction(*instr)) { - ++per_opcode_map[resource.first]; - } - } - for (const HloComputation* called_comp : - instr->called_computations()) { - auto it = async_in_computation_cache_.find(called_comp); - if (it == async_in_computation_cache_.end()) { - recursively_compute_resource_map(called_comp); - it = async_in_computation_cache_.find(called_comp); - CHECK(it != async_in_computation_cache_.end()); - } - for (auto& called_per_opcode_pair : it->second) { - per_opcode_map[called_per_opcode_pair.first] += - called_per_opcode_pair.second; - } - } - } - async_in_computation_cache_[computation] = std::move(per_opcode_map); - }; int64_t num_resources = 0; for (const HloComputation* computation : instr.called_computations()) { - auto it = async_in_computation_cache_.find(computation); - if (it == async_in_computation_cache_.end()) { - recursively_compute_resource_map(computation); - it = async_in_computation_cache_.find(computation); - CHECK(it != async_in_computation_cache_.end()); - } - auto opcode_it = it->second.find(resource_type); - if (opcode_it == it->second.end()) { - continue; + const auto& map = RecursivelyComputeResourceMap(computation); + auto opcode_it = map.find(resource_type); + if (opcode_it != map.end()) { + num_resources += opcode_it->second; } - num_resources += opcode_it->second; } return num_resources; } @@ -596,7 +648,7 @@ void MemoryPressureTracker::Initialize( output_values.push_back(std::make_pair( buffer_tracker_.GetBufferInfo(buffer->id()), index)); if (absl::c_any_of(buffer->values(), [&](const HloValue* value) { - return value->defining_instruction() == instruction; + return InstructionDefinesValue(instruction, value); })) { defined_values.push_back( buffer_tracker_.GetBufferInfo(buffer->id())); @@ -663,7 +715,7 @@ void MemoryPressureTracker::UpdateBuffers(const HloInstruction* instruction) { continue; } if (live_buffers_[b.value->id()] != 0) { - if (b.first_definition == instruction) { + if (InstructionFirstDefinesBuffer(instruction, b)) { live_memory_usage_ -= b.buffer_size; live_buffers_set_.erase(b.value->id()); } @@ -721,7 +773,7 @@ std::pair MemoryPressureTracker::MemoryPressureDifference( continue; } if (live_buffers_[b.value->id()]) { - if (b.first_definition == instruction) { + if (InstructionFirstDefinesBuffer(instruction, b)) { increase -= b.buffer_size; } } @@ -1736,8 +1788,9 @@ HloScheduleGraph::HloScheduleGraph( new_node_it->second->predecessors_.reserve(instr->operand_count()); new_node_it->second->successors_.reserve(instr->user_count()); new_node_it->second->cost_ = latency_estimator->NodeCost(instr); + auto resources = async_tracker->GetResourcesFromInstruction(*instr); new_node_it->second->resources_ = - async_tracker->GetResourcesFromInstruction(*instr); + ResourcesVector(resources.begin(), resources.end()); new_node_it->second->released_shareable_resources_ = async_tracker->GetReleasedShareableResourcesFromVector( new_node_it->second->GetResources()); diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.h b/third_party/xla/xla/service/latency_hiding_scheduler.h index 57cd09362893f1..09c97eccf81969 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler.h +++ b/third_party/xla/xla/service/latency_hiding_scheduler.h @@ -35,12 +35,12 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/map_util.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_value.h" @@ -139,6 +139,7 @@ struct SchedulerConfig { bool enable_selective_resources = false; int64_t max_hops_to_closest_selective_overlap = 0; int64_t rerun = 0; + int64_t parallel_collective_overlap_limit = 1; }; // Class used estimate latency between instructions and cost of HLOs. @@ -215,7 +216,7 @@ class AsyncTracker { const HloInstruction& hlo) const; // Returns resources used (i.e., occupied or released) by this instruction - virtual ResourcesVector GetResourcesFromInstruction( + absl::Span GetResourcesFromInstruction( const HloInstruction& hlo) const; // Modifies the schedule graph passed as input to add dependencies that are @@ -298,8 +299,12 @@ class AsyncTracker { : get_canonical_async_op_(std::move(func)), config_(config) {} private: - mutable absl::flat_hash_map> + const absl::flat_hash_map& RecursivelyComputeResourceMap( + const HloComputation* computation) const; + + mutable absl::flat_hash_map< + const HloComputation*, + std::unique_ptr>> async_in_computation_cache_; GetCanonicalAsyncOpFunc get_canonical_async_op_; diff --git a/third_party/xla/xla/service/latency_hiding_scheduler_test.cc b/third_party/xla/xla/service/latency_hiding_scheduler_test.cc index 60b00e642eb10a..8247cfa7cf35c9 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler_test.cc +++ b/third_party/xla/xla/service/latency_hiding_scheduler_test.cc @@ -37,7 +37,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" -#include "xla/service/async_collective_creator.h" +#include "xla/hlo/transforms/collectives/async_collective_creator.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -3165,9 +3165,10 @@ ENTRY %module { return AsyncTracker::GetResourceHazardType(resource_type); } - ResourcesVector GetResourcesFromInstruction( + ResourcesVector GetResourcesFromInstructionImpl( const HloInstruction& hlo) const override { - ResourcesVector result = AsyncTracker::GetResourcesFromInstruction(hlo); + ResourcesVector result = + AsyncTracker::GetResourcesFromInstructionImpl(hlo); // There is only one target defined resource (which is non-extendable). if (hlo.opcode() == HloOpcode::kAllGatherStart) { result.push_back({AsyncTracker::GetFirstTargetDefinedResource(), diff --git a/third_party/xla/xla/service/layout_assignment.cc b/third_party/xla/xla/service/layout_assignment.cc index 02885038251c06..56c7eb33895534 100644 --- a/third_party/xla/xla/service/layout_assignment.cc +++ b/third_party/xla/xla/service/layout_assignment.cc @@ -35,6 +35,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/hlo/analysis/tuple_points_to_analysis.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" @@ -42,16 +43,15 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/map_util.h" #include "xla/permutation_util.h" #include "xla/service/call_graph.h" #include "xla/service/computation_layout.h" -#include "xla/service/hlo_dce.h" #include "xla/service/logical_buffer.h" -#include "xla/service/tuple_points_to_analysis.h" -#include "xla/service/tuple_simplifier.h" #include "xla/shape.h" #include "xla/shape_layout.h" #include "xla/shape_util.h" @@ -275,11 +275,8 @@ absl::Status LayoutAssignment::SetBufferLayout(const Layout& layout, << LayoutUtil::HumanString(layout) << " with priority " << priority << "; mandatory = " << mandatory << "; dfs = " << dfs << "\n"; TF_RETURN_IF_ERROR(points_to_analysis_->VerifyBuffer(buffer)); - if (unconstrained_buffer_ids_.find(buffer.id()) != - unconstrained_buffer_ids_.end()) { + if (unconstrained_buffer_ids_.erase(buffer.id()) > 0) { VLOG(3) << "Erase buffer from unconstrained ids\n"; - TF_RET_CHECK(unconstrained_buffer_ids_.erase(buffer.id()) == 1) - << buffer.ToString(); } if (!buffer.IsArray()) { @@ -291,32 +288,27 @@ absl::Status LayoutAssignment::SetBufferLayout(const Layout& layout, TF_RETURN_IF_ERROR( LayoutUtil::ValidateLayoutForShape(layout, buffer.shape())); - auto iter = buffer_constraints_.find(&buffer); - if (iter != buffer_constraints_.end()) { - BufferLayoutConstraint curr_constraint = iter->second; - if (curr_constraint.UpdateLayout(priority, layout, mandatory, dfs, this, - user)) { + auto& buffer_constraint = buffer_constraints_[&buffer]; + if (buffer_constraint == nullptr) { + buffer_constraint = std::make_unique( + layout, buffer, mandatory, dfs, priority); + } else { + if (buffer_constraint->UpdateLayout(priority, layout, mandatory, dfs, this, + user)) { if (IsAtMostRank1(buffer.shape())) { return absl::OkStatus(); } - iter = - buffer_constraints_.insert_or_assign(&buffer, curr_constraint).first; } else { VLOG(3) << "Unable to update existing Buffer layout for " - << curr_constraint.ToString() << " with new layout" + << buffer_constraint->ToString() << " with new layout" << LayoutUtil::HumanString(layout) << " at priority " << priority << "\n"; return absl::OkStatus(); } - } else { - iter = buffer_constraints_ - .insert(std::make_pair( - &buffer, BufferLayoutConstraint(layout, buffer, mandatory, - dfs, priority))) - .first; - } - VLOG(3) << "SUCC setting buffer constraint: " << iter->second.ToString(); - added_constraints_.push_back(&iter->second); + } + VLOG(3) << "SUCC setting buffer constraint: " + << buffer_constraint->ToString(); + added_constraints_.push_back(buffer_constraint.get()); const HloInstruction* instruction = buffer.instruction(); if (dynamic_cast(instruction) != nullptr) { // Check and propagate via output-operand aliasing @@ -540,7 +532,7 @@ absl::Status LayoutAssignment::SetInstructionLayout( const BufferLayoutConstraint* LayoutAssignment::GetBufferLayoutConstraint( const LogicalBuffer& buffer) const { auto it = buffer_constraints_.find(&buffer); - return it == buffer_constraints_.end() ? nullptr : &it->second; + return it == buffer_constraints_.end() ? nullptr : it->second.get(); } const ShapeLayout* LayoutAssignment::LayoutConstraints::OperandLayout( @@ -2895,6 +2887,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kReduceScatter: case HloOpcode::kAllReduceStart: case HloOpcode::kAllReduceDone: + case HloOpcode::kRaggedAllToAll: return false; case HloOpcode::kAsyncStart: case HloOpcode::kAsyncUpdate: diff --git a/third_party/xla/xla/service/layout_assignment.h b/third_party/xla/xla/service/layout_assignment.h index 6d29b7cd6f53d9..549e0de376550e 100644 --- a/third_party/xla/xla/service/layout_assignment.h +++ b/third_party/xla/xla/service/layout_assignment.h @@ -33,6 +33,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/analysis/tuple_points_to_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -44,7 +45,6 @@ limitations under the License. #include "xla/service/call_graph.h" #include "xla/service/computation_layout.h" #include "xla/service/logical_buffer.h" -#include "xla/service/tuple_points_to_analysis.h" #include "xla/shape.h" #include "xla/shape_layout.h" #include "xla/shape_util.h" @@ -754,7 +754,8 @@ class LayoutAssignment : public HloModulePass { buffer_sets_cache_; // The set of BufferLayoutConstraints applied to the computation. - absl::node_hash_map + absl::flat_hash_map> buffer_constraints_; // A vector which holds constraints as they are added. Can be cleared with diff --git a/third_party/xla/xla/service/layout_assignment_test.cc b/third_party/xla/xla/service/layout_assignment_test.cc index 0b294c46ddef17..fb24c1d83647cc 100644 --- a/third_party/xla/xla/service/layout_assignment_test.cc +++ b/third_party/xla/xla/service/layout_assignment_test.cc @@ -28,13 +28,13 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/transforms/simplifiers/algebraic_simplifier.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/service/algebraic_simplifier.h" #include "xla/service/computation_layout.h" -#include "xla/service/hlo_parser.h" #include "xla/service/logical_buffer.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" diff --git a/third_party/xla/xla/service/layout_normalization.cc b/third_party/xla/xla/service/layout_normalization.cc index 16781509e22c60..74100a62e20111 100644 --- a/third_party/xla/xla/service/layout_normalization.cc +++ b/third_party/xla/xla/service/layout_normalization.cc @@ -347,7 +347,11 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { auto s = hlo->shape(); auto a = hlo->mutable_operand(0); auto b = hlo->mutable_operand(1); - TF_RET_CHECK(a->shape().layout() == s.layout()); + auto layout_equal = Layout::Equal(); + if (hlo->opcode() == HloOpcode::kCompare) { + layout_equal.IgnoreElementSize(); + } + TF_RET_CHECK(layout_equal(a->shape().layout(), s.layout())); TF_ASSIGN_OR_RETURN(auto a0, GetNormalizedInput(a)); TF_ASSIGN_OR_RETURN(auto b0, GetNormalizedInput(b)); diff --git a/third_party/xla/xla/service/layout_normalization_test.cc b/third_party/xla/xla/service/layout_normalization_test.cc index 88ea4828ec597a..6fcf848ea46be8 100644 --- a/third_party/xla/xla/service/layout_normalization_test.cc +++ b/third_party/xla/xla/service/layout_normalization_test.cc @@ -922,5 +922,21 @@ ENTRY main.17 { }); } +TEST_F(LayoutNormalizationTest, CompareInt4) { + const char* hlo = R"( +HloModule module + +ENTRY main { + a = s4[10]{0:E(4)} parameter(0) + b = s4[10]{0:E(4)} parameter(1) + ROOT out = compare(a, b), direction=EQ +} +)"; + + CheckLayoutNormalization(hlo, R"( +// CHECK: pred[10]{0} compare({{.*}}) +)"); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/llvm_ir/alias_analysis_test.cc b/third_party/xla/xla/service/llvm_ir/alias_analysis_test.cc index cb91226ecc4bac..1ca4a5f1a9fad4 100644 --- a/third_party/xla/xla/service/llvm_ir/alias_analysis_test.cc +++ b/third_party/xla/xla/service/llvm_ir/alias_analysis_test.cc @@ -25,7 +25,7 @@ namespace xla::cpu { namespace { class AliasAnalysisTest : public CpuCodegenTest { - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); // We do not generate IR for while loops with thunks runtime, so we // explicitly disable it for this test. diff --git a/third_party/xla/xla/service/llvm_ir/llvm_util.cc b/third_party/xla/xla/service/llvm_ir/llvm_util.cc index 27630b674d2ce4..fbd4ec99e29157 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_util.cc +++ b/third_party/xla/xla/service/llvm_ir/llvm_util.cc @@ -132,7 +132,7 @@ llvm::CallInst* EmitCallToIntrinsic( absl::Span overloaded_types, llvm::IRBuilder<>* b, absl::string_view name) { llvm::Module* module = ModuleFromIRBuilder(b); - llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration( + llvm::Function* intrinsic = llvm::Intrinsic::getOrInsertDeclaration( module, intrinsic_id, AsArrayRef(overloaded_types)); return b->CreateCall(intrinsic, AsArrayRef(operands), name.data()); } @@ -200,9 +200,11 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, return llvm::Type::getInt16Ty(module->getContext()); case F8E5M2: case F8E5M2FNUZ: + case F8E4M3: case F8E4M3FN: case F8E4M3B11FNUZ: case F8E4M3FNUZ: + case F8E3M4: // We represent F8 as an int since there is no LLVM F8 dtype. return llvm::Type::getInt8Ty(module->getContext()); case BF16: diff --git a/third_party/xla/xla/service/local_service.cc b/third_party/xla/xla/service/local_service.cc index 0d89b2712ca044..557fb216582695 100644 --- a/third_party/xla/xla/service/local_service.cc +++ b/third_party/xla/xla/service/local_service.cc @@ -23,7 +23,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "xla/client/executable_build_options.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/service/backend.h" #include "xla/service/compiler.h" #include "xla/service/computation_layout.h" @@ -93,7 +93,6 @@ LocalService::CompileExecutables( build_options.layout_canonicalization_callback(), false, {}, - nullptr, {build_options.key_value_store(), build_options.process_index(), build_options.process_count()}}; if (build_options.num_partitions() == 1) { diff --git a/third_party/xla/xla/service/local_service.h b/third_party/xla/xla/service/local_service.h index 4b9112386d4e01..be7eee43e0c6ee 100644 --- a/third_party/xla/xla/service/local_service.h +++ b/third_party/xla/xla/service/local_service.h @@ -22,7 +22,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/client/executable_build_options.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/service/backend.h" #include "xla/service/compiler.h" #include "xla/service/executable.h" diff --git a/third_party/xla/xla/service/local_service_utils.cc b/third_party/xla/xla/service/local_service_utils.cc index 9c0be82ab8860b..d6f6ce4f0280b3 100644 --- a/third_party/xla/xla/service/local_service_utils.cc +++ b/third_party/xla/xla/service/local_service_utils.cc @@ -23,7 +23,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/types/span.h" #include "xla/client/executable_build_options.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/backend.h" #include "xla/service/hlo.pb.h" diff --git a/third_party/xla/xla/service/local_service_utils.h b/third_party/xla/xla/service/local_service_utils.h index ecfd832f2b85a7..d17300c0c7e688 100644 --- a/third_party/xla/xla/service/local_service_utils.h +++ b/third_party/xla/xla/service/local_service_utils.h @@ -21,7 +21,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/client/executable_build_options.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/backend.h" #include "xla/service/hlo_module_config.h" diff --git a/third_party/xla/xla/service/logical_buffer.h b/third_party/xla/xla/service/logical_buffer.h index f951baea8c26f8..350bbdfcd31c46 100644 --- a/third_party/xla/xla/service/logical_buffer.h +++ b/third_party/xla/xla/service/logical_buffer.h @@ -24,9 +24,9 @@ limitations under the License. #include "xla/service/hlo.pb.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/lib/gtl/int_type.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/gtl/int_type.h" namespace xla { diff --git a/third_party/xla/xla/service/logical_buffer_analysis.h b/third_party/xla/xla/service/logical_buffer_analysis.h index 46b9f1d15b80a8..6571558fb208e4 100644 --- a/third_party/xla/xla/service/logical_buffer_analysis.h +++ b/third_party/xla/xla/service/logical_buffer_analysis.h @@ -16,77 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_ #define XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_ -#include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/logical_buffer.h" -#include "xla/shape_util.h" - -namespace xla { -// A class to create all the logical buffers defined by the HLO ops in a module. -class LogicalBufferAnalysis : public DfsHloVisitorWithDefault { - public: - // Runs points-to analysis on 'module'. - static absl::StatusOr> Run( - const HloModule* module); - - // Returns the logical buffer with the given ID. - LogicalBuffer& GetBuffer(LogicalBuffer::Id id) const; - - // Returns the logical buffer that represents the output of a given HLO - // at a given index. - LogicalBuffer& GetBuffer(HloInstruction* instruction, - const ShapeIndex& index) const; - - const std::vector>& logical_buffers() const { - return logical_buffers_; - } - size_t num_logical_buffers() const { return logical_buffers_.size(); } - - private: - explicit LogicalBufferAnalysis(const HloModule* module) : module_(module) {} - absl::Status Analyze(); - - // The module this analysis is performed on. - const HloModule* module_; - - // Create a new logical buffer and return a reference to it. The newly created - // buffer is stored in an internal vector of LogicalBuffers and can be - // accessed with GetBuffer. - void NewLogicalBuffer(HloInstruction* instruction, const ShapeIndex& index); - - absl::Status DefaultAction(HloInstruction* hlo_instruction) override; - absl::Status HandleTuple(HloInstruction* tuple) override; - absl::Status HandleGetTupleElement( - HloInstruction* get_tuple_element) override; - absl::Status HandleBitcast(HloInstruction* bitcast) override; - absl::Status HandleDomain(HloInstruction* domain) override; - absl::Status HandleCopy(HloInstruction* copy) override; - absl::Status HandleCopyStart(HloInstruction* copy_start) override; - absl::Status HandleCopyDone(HloInstruction* copy_done) override; - absl::Status HandleRecvDone(HloInstruction* recv_done) override; - absl::Status HandleSend(HloInstruction* send) override; - absl::Status HandleAddDependency(HloInstruction* add_dependency) override; - absl::Status HandleCustomCall(HloInstruction* custom_call) override; - absl::Status HandleFusion(HloInstruction* fusion) override; - - // A map from the buffer ID to the logical buffer - std::vector> logical_buffers_; - - // A map from an hlo + shape index to the logical buffer representing - // the appropriate output. - absl::flat_hash_map, - LogicalBuffer*> - output_buffers_; - // Whether to alias buffers defined by dataflow relations. This aliasing - // relation should not be recognized if copies can be inserted to break up - // the dataflow relation-induced aliasing. - const bool alias_buffer_across_dataflow_ = false; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/analysis/logical_buffer_analysis.h" #endif // XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/logistic_expander.h b/third_party/xla/xla/service/logistic_expander.h index 8e9aeec2e67952..c0c5ec0c37f0da 100644 --- a/third_party/xla/xla/service/logistic_expander.h +++ b/third_party/xla/xla/service/logistic_expander.h @@ -16,34 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_LOGISTIC_EXPANDER_H_ #define XLA_SERVICE_LOGISTIC_EXPANDER_H_ -#include - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/service/op_expander_pass.h" - -namespace xla { - -// A pass which performs expansion of the logistic function. -class LogisticExpander : public OpExpanderPass { - public: - LogisticExpander() = default; - ~LogisticExpander() override = default; - absl::string_view name() const override { return "logistic-expander"; } - - private: - // Returns `true` if `instruction` should be expanded by this pass. - bool InstructionMatchesPattern(HloInstruction* instruction) override; - // Returns a replacement for `instruction`, or nullptr if no replacement is - // needed (e.g. only the to_apply subcomputation of the instruction was - // modified). - absl::StatusOr ExpandInstruction( - HloInstruction* instruction) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/expanders/logistic_expander.h" #endif // XLA_SERVICE_LOGISTIC_EXPANDER_H_ diff --git a/third_party/xla/xla/service/loop_schedule_linearizer.cc b/third_party/xla/xla/service/loop_schedule_linearizer.cc index c01260039f7c74..b02716bb23db59 100644 --- a/third_party/xla/xla/service/loop_schedule_linearizer.cc +++ b/third_party/xla/xla/service/loop_schedule_linearizer.cc @@ -24,14 +24,14 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/graphcycles/graphcycles.h" -#include "xla/service/hlo_alias_analysis.h" -#include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_value.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/service/loop_schedule_linearizer.h b/third_party/xla/xla/service/loop_schedule_linearizer.h index 3286a2e41a0048..e6348b15f8acde 100644 --- a/third_party/xla/xla/service/loop_schedule_linearizer.h +++ b/third_party/xla/xla/service/loop_schedule_linearizer.h @@ -19,12 +19,12 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/service/hlo_alias_analysis.h" -#include "xla/service/hlo_dataflow_analysis.h" namespace xla { diff --git a/third_party/xla/xla/service/map_inliner.cc b/third_party/xla/xla/service/map_inliner.cc index 0bb7920496928d..3eaf04f69d6e6b 100644 --- a/third_party/xla/xla/service/map_inliner.cc +++ b/third_party/xla/xla/service/map_inliner.cc @@ -96,6 +96,7 @@ absl::Status MapInlinerVisitor::HandleMap(HloInstruction* map) { computation_->ReplaceInstruction(map, placed_instruction)); } else { std::vector params; + params.reserve(root.operands().size()); for (int64_t o = 0; o < root.operands().size(); o++) { params.push_back(map->operands()[root.operand(o)->parameter_number()]); } diff --git a/third_party/xla/xla/service/memory_space_assignment/BUILD b/third_party/xla/xla/service/memory_space_assignment/BUILD index f3f989d083f8ea..83b877fde498b9 100644 --- a/third_party/xla/xla/service/memory_space_assignment/BUILD +++ b/third_party/xla/xla/service/memory_space_assignment/BUILD @@ -47,12 +47,12 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_alias_analysis", + "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_live_range", "//xla/service:buffer_value", - "//xla/service:hlo_alias_analysis", "//xla/service:hlo_buffer", - "//xla/service:hlo_dataflow_analysis", "//xla/service:hlo_proto_cc", "//xla/service:hlo_value", "//xla/service/heap_simulator", @@ -93,16 +93,16 @@ xla_cc_test( "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_alias_analysis", + "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", + "//xla/hlo/transforms:instruction_hoister", "//xla/hlo/utils:hlo_live_range", "//xla/hlo/utils:hlo_matchers", "//xla/service:buffer_value", - "//xla/service:hlo_alias_analysis", "//xla/service:hlo_buffer", "//xla/service:hlo_cost_analysis", - "//xla/service:hlo_dataflow_analysis", "//xla/service:hlo_value", - "//xla/service:instruction_hoister", "//xla/service/heap_simulator", "//xla/service/heap_simulator:allocation_block", "//xla/tests:hlo_test_base", @@ -224,6 +224,7 @@ cc_library( "//xla/service/heap_simulator", "//xla/service/heap_simulator:allocation_block", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/log", @@ -237,6 +238,23 @@ cc_library( ], ) +xla_cc_test( + name = "allocation_test", + srcs = ["allocation_test.cc"], + deps = [ + ":allocation", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_value", + "//xla/service/heap_simulator", + "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + cc_library( name = "tuning_utils", srcs = ["tuning_utils.cc"], @@ -281,14 +299,14 @@ cc_library( deps = [ "//xla:shape_util", "//xla:util", + "//xla/hlo/analysis:hlo_alias_analysis", + "//xla/hlo/analysis:while_loop_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_live_range", "//xla/service:call_graph", - "//xla/service:hlo_alias_analysis", "//xla/service:hlo_buffer", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_value", - "//xla/service:while_loop_analysis", "//xla/service/heap_simulator", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:function_ref", @@ -328,13 +346,14 @@ cc_library( ":cost_analysis", "//xla:shape_util", "//xla:util", + "//xla/hlo/analysis:hlo_alias_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_live_range", - "//xla/service:hlo_alias_analysis", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) @@ -347,9 +366,9 @@ xla_cc_test( ":cost_analysis", ":simulator", "//xla:shape_util", + "//xla/hlo/analysis:hlo_alias_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_live_range", - "//xla/service:hlo_alias_analysis", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_value", "//xla/service/heap_simulator", @@ -395,10 +414,10 @@ cc_library( deps = [ ":cost_analysis", "//xla:shape_util", + "//xla/hlo/analysis:hlo_alias_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_live_range", "//xla/service:call_graph", - "//xla/service:hlo_alias_analysis", "//xla/service:hlo_cost_analysis", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", @@ -419,10 +438,10 @@ cc_library( "//xla:shape_util", "//xla:status_macros", "//xla:util", + "//xla/hlo/analysis:hlo_alias_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_live_range", "//xla/service:buffer_value", - "//xla/service:hlo_alias_analysis", "//xla/service:hlo_buffer", "//xla/service:hlo_proto_cc", "//xla/service:hlo_value", @@ -437,6 +456,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", ], @@ -458,10 +478,10 @@ xla_cc_test( "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_alias_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_live_range", "//xla/service:buffer_value", - "//xla/service:hlo_alias_analysis", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_value", "//xla/tests:hlo_test_base", @@ -501,13 +521,13 @@ cc_library( "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_alias_analysis", + "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_live_range", "//xla/service:buffer_value", "//xla/service:call_graph", - "//xla/service:hlo_alias_analysis", "//xla/service:hlo_buffer", - "//xla/service:hlo_dataflow_analysis", "//xla/service:hlo_proto_cc", "//xla/service:hlo_value", "//xla/service:time_utils", @@ -518,6 +538,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc index 67371364f1cd0d..1f4d09be94e995 100644 --- a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc +++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc @@ -30,7 +30,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -40,6 +39,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/functional/any_invocable.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -49,6 +49,8 @@ limitations under the License. #include "absl/types/span.h" #include "re2/re2.h" #include "xla/debug_options_flags.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -58,9 +60,7 @@ limitations under the License. #include "xla/service/call_graph.h" #include "xla/service/heap_simulator/allocation_block.h" #include "xla/service/heap_simulator/heap_simulator.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" -#include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_value.h" #include "xla/service/memory_space_assignment/allocation.h" #include "xla/service/memory_space_assignment/buffer_interval_comparator.h" @@ -465,6 +465,16 @@ bool MsaAlgorithm::IsIntervalPinnedToAlternateMemory( shape.layout().memory_space() == options_.alternate_memory_space; } +bool MsaAlgorithm::MatchesPrefetchContext( + const PrefetchContext& context, absl::string_view producer_name, + ShapeIndex producer_shape_index, absl::string_view consumer_name) const { + return context.request->use->hlo_use.instruction->name() == consumer_name && + context.request->allocation_value->defining_position() + .instruction->name() == producer_name && + context.request->allocation_value->defining_position().index == + producer_shape_index; +} + MsaAlgorithm::MsaAlgorithm(AllocationSequence* allocations, const Options& options, const HloAliasAnalysis& alias_analysis, @@ -655,9 +665,12 @@ void MsaAlgorithm::FindAliases( auto aliased_values_it = values_by_defining_inst.find(instruction); if (aliased_values_it != values_by_defining_inst.end()) { for (const AllocationValue* aliased_value : aliased_values_it->second) { - VLOG(3) << "Adding aliasing for use " << use->hlo_use.ToString() - << " to " << aliased_value->ToShortString(); - use->aliases.push_back(aliased_value->defining_position()); + if (absl::c_find(use->aliases, aliased_value->defining_position()) == + use->aliases.end()) { + VLOG(3) << "Adding aliasing for use " << use->hlo_use.ToString() + << " to " << aliased_value->ToShortString(); + use->aliases.push_back(aliased_value->defining_position()); + } } } }; @@ -954,23 +967,12 @@ absl::Status MsaAlgorithm::OptimizeMemoryBoundLoop(int loop_start_idx, const int iteration_start_idx = loop_start_idx + loop_size; const int iteration_end_idx = iteration_start_idx + loop_size; - TF_ASSIGN_OR_RETURN( - std::unique_ptr optimizer, - MemoryBoundLoopOptimizer::Create( - iteration_start_idx, iteration_end_idx, options_.max_size_in_bytes, - options_.memory_bound_loop_optimizer_options, hlo_live_range_, - alias_analysis_, *options_.cost_analysis, options_.size_fn, - options_.reserved_scoped_memory_fn)); + TF_ASSIGN_OR_RETURN(std::unique_ptr optimizer, + MemoryBoundLoopOptimizer::Create( + iteration_start_idx, iteration_end_idx, + hlo_live_range_, alias_analysis_, options_)); optimizer->Optimize(); - const int loop_optimized_allocations_original_size = - loop_optimized_allocations_.size(); - for (MemoryBoundLoopOptimizer::LoopValue& value : optimizer->loop_values()) { - if (!value.allocations.empty() && value.IsAllocationTypeSupported()) { - loop_optimized_allocations_.push_back(std::move(value.allocations)); - } - } - // Check if this unrolled loop is in a while loop. const auto& instruction_sequence = hlo_live_range_.flattened_instruction_sequence().instructions(); @@ -981,9 +983,12 @@ absl::Status MsaAlgorithm::OptimizeMemoryBoundLoop(int loop_start_idx, // Update the loop_optimized_allocations_map_ with the output of the // optimizer. - for (int i = loop_optimized_allocations_original_size; - i < loop_optimized_allocations_.size(); ++i) { - const AllocationSequence& sequence = loop_optimized_allocations_.at(i); + for (MemoryBoundLoopOptimizer::LoopValue& value : optimizer->loop_values()) { + if (value.allocations.empty() || !value.IsAllocationTypeSupported()) { + continue; + } + loop_optimized_allocations_.push_back(std::move(value.allocations)); + const AllocationSequence& sequence = loop_optimized_allocations_.back(); CHECK(!sequence.empty()); VLOG(3) << " alloc: " << sequence.back()->ToString(); for (const auto& allocation : sequence) { @@ -1278,17 +1283,96 @@ void MsaAlgorithm::IdentifyAndOptimizeMemoryBoundLoops() { } } -bool MsaAlgorithm::IsReplaceableSyncCopyCandidate( +bool MsaAlgorithm::IsAsyncConversionCandidate( const HloInstruction* instruction) const { - if (!options_.enable_sync_copy_replacement) { + bool meets_special_preconditions = + IsAsyncConversionCopyCandidate(instruction) || + IsAsyncConversionSliceCandidate(instruction) == + AsyncConversionResult::kSuccess; + if (!meets_special_preconditions) { + return false; + } + + for (auto& operand : instruction->operands()) { + // TODO(b/374835319): relax the operand constraint to be able to cover + // nested sync data movement cases. + if (IsAsyncConversionCandidate(operand)) { + VLOG(4) << "The instruction is not considered to be replaced, because it " + "potentially has a replaceable operand."; + return false; + } + const HloValue& operand_value = alias_analysis_.dataflow_analysis() + .GetValueSet(operand) + .GetUniqueValue(); + if (!buffer_intervals_.at(&operand_value).need_allocation) { + VLOG(4) + << "The instruction is not considered to be replaced, because its " + "operand value doesn't need an allocation."; + return false; + } + } + + const HloValue& value = alias_analysis_.dataflow_analysis() + .GetValueSet(instruction) + .GetUniqueValue(); + if (!buffer_intervals_.at(&value).need_allocation) { + VLOG(4) << "The instruction is not considered to be replaced, because its " + "output doesn't need an allocation and it might be too late to " + "replace this instruction."; + return false; + } + if (value.IsRootOf(instruction->parent())) { + VLOG(4) << "The instruction is not considered to be replaced, because its " + "output value is in the root of the computation."; return false; } - if (failed_copy_replacements_set_.contains(instruction)) { + if (finalized_values_.contains(&value)) { + VLOG(4) << "The instruction is not considered to be replaced, because its " + "output value is in the finalized values."; + return false; + } + if (buffer_intervals_.at(&value).size > available_heap_size()) { + VLOG(4) << "The instruction is not considered to be replaced, because its " + "output value is too large to fit in the heap."; + return false; + } + // This check is here only because we skip processing the values that are not + // allowed in alternate memory. + if (!MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( + buffer_intervals_.at(&value), options_.alternate_memory_space)) { + VLOG(4) << "The instruction is not considered to be replaced, because its " + "output value is not allowed in alternate memory."; + return false; + } + + for (const HloInstruction* user : instruction->users()) { + if (HloDataflowAnalysis::IsAsynchronousOperationStart(user->opcode())) { + VLOG(4) << "The instruction is not considered to be replaced, because " + "its used by an async start operation that might require " + "contiguous allocation."; + return false; + } + } + + return true; +} + +bool MsaAlgorithm::IsAsyncConversionCopyCandidate( + const HloInstruction* instruction) const { + if (!options_.enable_sync_copy_replacement) { return false; } if (instruction->opcode() != HloOpcode::kCopy) { return false; } + if (failed_async_conversions_.contains(instruction)) { + return false; + } + if (instruction->IsRoot()) { + // Root copy is not replaceable with current implementation, because the + // instruction has no uses + return false; + } if (instruction->operand(0)->shape() != instruction->shape()) { VLOG(5) << "Sync copy " << instruction->ToShortString() << " is not replaceable, because the operand and output shapes do " @@ -1308,109 +1392,151 @@ bool MsaAlgorithm::IsReplaceableSyncCopyCandidate( "initial assignment."; return false; } - if (alias_analysis_.dataflow_analysis() - .GetUniqueValueAt(instruction->operand(0)) - .positions() - .size() != 1) { - VLOG(5) << "Sync copy " << instruction->ToShortString() - << " is not replaceable because we currently do not support operand" - << " values that have more than one position."; - return false; - } return true; } -std::vector MsaAlgorithm::GenerateJointProcessedValues( - const HloValue* entrance_value) { - if (options_.enable_sync_copy_replacement) { - auto joint_processed_values = - GetJointProcessedValuesForSyncCopyReplacement(entrance_value); - UpdateSyncCopyCandidatesForJointProcessedValues(joint_processed_values); - return joint_processed_values; +namespace { + +bool IsTrivialInstruction(const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kGetTupleElement || + instruction->opcode() == HloOpcode::kTuple || + instruction->opcode() == HloOpcode::kBitcast; +} + +} // namespace + +MsaAlgorithm::AsyncConversionResult +MsaAlgorithm::IsAsyncConversionSliceCandidate( + const HloInstruction* instruction) const { + if (!options_.enable_sync_slice_replacement) { + return AsyncConversionResult::kFeatureNotEnabled; + } + if (failed_async_conversions_.contains(instruction)) { + return failed_async_conversions_.at(instruction); + } + if (instruction->opcode() != HloOpcode::kSlice) { + return AsyncConversionResult::kFailedPrecondition; + } + + if (instruction->IsRoot()) { + // Root slice is not replaceable with current implementation, because the + // instruction has no uses + return AsyncConversionResult::kFailedPrecondition; + } + + if (!options_.is_async_slice_implemented_fn(instruction)) { + VLOG(4) << "The sync slice is not considered to be replaced, because the " + "async version is not implemented for " + << instruction->ToShortString(); + return AsyncConversionResult::kFailedPrecondition; + } + + if (instruction->shape().layout().memory_space() != + static_cast(MemorySpace::kDefault) || + instruction->operand(0)->shape().layout().memory_space() != + static_cast(MemorySpace::kDefault)) { + VLOG(4) << "Sync slice " << instruction->ToShortString() + << " is not replaceable, because the operand or output have an " + "initial assignment."; + return AsyncConversionResult::kFailedPrecondition; + } + if (instruction->shape().element_type() != + instruction->operand(0)->shape().element_type()) { + VLOG(4) << "Sync slice " << instruction->ToShortString() + << " is not replaceable because the operand and output have " + "different element types."; + return AsyncConversionResult::kFailedPrecondition; } - return {entrance_value}; + return AsyncConversionResult::kSuccess; } -std::vector -MsaAlgorithm::GetJointProcessedValuesForSyncCopyReplacement( - const HloValue* entrance_value) const { +std::vector MsaAlgorithm::GenerateJointProcessedValues( + const HloValue* entrance_value) { std::vector worklist = {entrance_value}; + if (options_.enable_sync_copy_replacement || + options_.enable_sync_slice_replacement) { + // Adds the HloValue that is related to a given instruction to the worklist + auto add_to_worklist = [&](const HloInstruction* inst) { + const HloValue& next_value = alias_analysis_.dataflow_analysis() + .GetValueSet(inst) + .GetUniqueValue(); + if (std::find(worklist.begin(), worklist.end(), &next_value) == + worklist.end()) { + worklist.push_back(&next_value); + } + }; - // Adds the HloValue that is related to a given instruction to the worklist - auto add_to_worklist = [&](const HloInstruction* inst) { - const HloValue& next_value = - alias_analysis_.dataflow_analysis().GetValueSet(inst).GetUniqueValue(); - if (std::find(worklist.begin(), worklist.end(), &next_value) == - worklist.end()) { - worklist.push_back(&next_value); - } - }; - - for (size_t idx = 0; idx < worklist.size(); ++idx) { - const HloValue* value = worklist.at(idx); - // Values that are related to the current value through a sync copy use - // are added to the worklist. - for (const auto& use : value->GetUses()) { - if (IsReplaceableSyncCopyCandidate(use.instruction)) { - add_to_worklist(use.instruction); + for (size_t idx = 0; idx < worklist.size(); ++idx) { + const HloValue* value = worklist.at(idx); + // Values that are related to the current value through a sync copy use + // are added to the worklist. + for (const auto& use : value->GetUses()) { + if (IsAsyncConversionCandidate(use.instruction)) { + add_to_worklist(use.instruction); + } + } + // Expand the worklist to include values that connect to the current + // value as sync copy operands, if any. + HloInstruction* defining_instruction = value->instruction(); + if (IsAsyncConversionCandidate(defining_instruction)) { + CHECK_EQ(defining_instruction->operands().size(), 1); + add_to_worklist(defining_instruction->operands().back()); } } - // Expand the worklist to include values that connect to the current - // value as sync copy operands, if any. - HloInstruction* defining_instruction = value->instruction(); - if (IsReplaceableSyncCopyCandidate(defining_instruction)) { - CHECK_EQ(defining_instruction->operands().size(), 1); - add_to_worklist(defining_instruction->operands().back()); - } + // We're sensitive to the order of the worklist. + absl::c_stable_sort(worklist, [&](const HloValue* a, const HloValue* b) { + return std::make_pair(buffer_intervals_.at(a).start, + buffer_intervals_.at(a).end) < + std::make_pair(buffer_intervals_.at(b).start, + buffer_intervals_.at(b).end); + }); + UpdateSyncDataMovementCandidatesForJointProcessedValues(worklist); } - - // We're sensitive to the order of the worklist. - absl::c_stable_sort(worklist, [&](const HloValue* a, const HloValue* b) { - return std::make_pair(buffer_intervals_.at(a).start, - buffer_intervals_.at(a).end) < - std::make_pair(buffer_intervals_.at(b).start, - buffer_intervals_.at(b).end); - }); - return worklist; } -void MsaAlgorithm::UpdateSyncCopyCandidatesForJointProcessedValues( +void MsaAlgorithm::UpdateSyncDataMovementCandidatesForJointProcessedValues( const std::vector& joint_processed_values) { - absl::flat_hash_set pending_replaceable_copies; + absl::flat_hash_set replaceable_sync_instructions; + absl::flat_hash_set do_not_touch_instructions; for (const HloValue* value : joint_processed_values) { for (const auto& use : value->GetUses()) { - if (IsReplaceableSyncCopyCandidate(use.instruction)) { - pending_replaceable_copies.insert(use.instruction); + bool is_use_replaceable_sync_candidate = + IsAsyncConversionCandidate(use.instruction); + if (is_use_replaceable_sync_candidate && + !do_not_touch_instructions.contains(use.instruction)) { + replaceable_sync_instructions.insert(use.instruction); } } HloInstruction* inst = value->instruction(); - if (IsReplaceableSyncCopyCandidate(inst)) { - pending_replaceable_copies.insert(inst); + bool is_inst_replaceable_sync_candidate = IsAsyncConversionCandidate(inst); + if (is_inst_replaceable_sync_candidate && + !do_not_touch_instructions.contains(inst)) { + replaceable_sync_instructions.insert(inst); } } - sorted_sync_copy_replacement_candidates_.clear(); - sorted_sync_copy_replacement_candidates_.insert( - sorted_sync_copy_replacement_candidates_.end(), - pending_replaceable_copies.begin(), pending_replaceable_copies.end()); + sorted_async_conversion_candidates_.clear(); + sorted_async_conversion_candidates_.insert( + sorted_async_conversion_candidates_.end(), + replaceable_sync_instructions.begin(), + replaceable_sync_instructions.end()); const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); - absl::c_stable_sort(sorted_sync_copy_replacement_candidates_, + absl::c_stable_sort(sorted_async_conversion_candidates_, [&instruction_schedule](const HloInstruction* a, const HloInstruction* b) { return instruction_schedule.at(a) < instruction_schedule.at(b); }); VLOG(3) << "Sorted pending replaceable copies: "; - if (sorted_sync_copy_replacement_candidates_.empty()) { + if (sorted_async_conversion_candidates_.empty()) { VLOG(3) << " --Empty--"; } - for (size_t idx = 0; idx < sorted_sync_copy_replacement_candidates_.size(); + for (size_t idx = 0; idx < sorted_async_conversion_candidates_.size(); ++idx) { - VLOG(3) - << " " << idx + 1 << "/" - << sorted_sync_copy_replacement_candidates_.size() << ") " - << sorted_sync_copy_replacement_candidates_.at(idx)->ToShortString(); + VLOG(3) << " " << idx + 1 << "/" + << sorted_async_conversion_candidates_.size() << ") " + << sorted_async_conversion_candidates_.at(idx)->ToShortString(); } } @@ -1451,12 +1577,9 @@ void MsaAlgorithm::ColorColocatedIntervalsToAlternate( } } -void MsaAlgorithm::CreateAllocationValuesForJointProcessedIntervals( - const std::vector& joint_processed_values, - std::vector& joint_allocation_values, - std::vector>& - joint_colocated_intervals) { - for (auto& interval_hlo : joint_processed_values) { +void MsaAlgorithm::CreateAllocationValuesForJointProcessedValues( + JointAllocationProposal& proposal) { + for (auto& interval_hlo : proposal.values) { auto& interval = buffer_intervals_.at(interval_hlo); if (finalized_values_.contains(interval_hlo)) { @@ -1495,9 +1618,21 @@ void MsaAlgorithm::CreateAllocationValuesForJointProcessedIntervals( if (!options_.enable_window_prefetch && interval.size > available_heap_size()) { - VLOG(3) << "Skip " << interval.buffer->ToShortString() - << " because the buffer is larger than the heap size."; - continue; + const HloInstruction* defining_instruction = + interval.buffer->instruction(); + auto may_be_replaced_by_slice_fn = [this](const HloInstruction* user) { + return IsInstructionPendingReplacements(user) && + user->opcode() == HloOpcode::kSlice; + }; + bool may_be_replaced_by_slice = std::any_of( + defining_instruction->users().begin(), + defining_instruction->users().end(), may_be_replaced_by_slice_fn); + + if (!may_be_replaced_by_slice) { + VLOG(3) << "Skip " << interval.buffer->ToShortString() + << " because the buffer is larger than the heap size."; + continue; + } } auto colocated_intervals = GetSortedColocatedIntervals(interval); @@ -1521,10 +1656,58 @@ void MsaAlgorithm::CreateAllocationValuesForJointProcessedIntervals( } CreateAllocationValuesFromColocatedIntervals(colocated_intervals, - joint_allocation_values); - joint_colocated_intervals.push_back(colocated_intervals); - NicePrintAllocationValues(joint_allocation_values, /*log_level=*/3); + proposal.allocation_values); + proposal.colocated_intervals.push_back(colocated_intervals); + } + // Order allocation values so that when read and write sequences are + // different, we're sure the allocation value corresponding to the read + // sequence is processed before the written sequence's allocation value. + // We move the allocation values that have their defining instruction in the + // sync conversion list to the end to be processed last. + std::stable_partition(proposal.allocation_values.begin(), + proposal.allocation_values.end(), + [this](AllocationValue& allocation_value) { + return !IsInstructionPendingReplacements( + allocation_value.defining_instruction()); + }); + + NicePrintAllocationValues(proposal.allocation_values, /*log_level=*/3); +} + +MsaAlgorithm::JointAllocationProposal MsaAlgorithm::GetJointProposal( + MsaBufferInterval& interval) { + JointAllocationProposal proposal; + proposal.values = GenerateJointProcessedValues(interval.buffer); + if (VLOG_IS_ON(3)) { + VLOG(3) << "Joint-processed values for " << interval.buffer->ToShortString() + << ": "; + for (size_t idx = 0; idx < proposal.values.size(); ++idx) { + const HloValue* hlo_value = proposal.values.at(idx); + VLOG(3) << " " << idx + 1 << "/" << proposal.values.size() << ") " + << hlo_value->ToShortString(); + } } + + CreateAllocationValuesForJointProcessedValues(proposal); + return proposal; +} + +bool MsaAlgorithm::RepackAllocationsIncludeConvertedSyncMemOp() { + for (RepackAllocationBlock& allocation_block : repack_allocation_blocks_) { + if (allocation_block.allocation->is_copy_allocation()) { + if (dynamic_cast(allocation_block.allocation) + ->sync_mem_op()) { + return true; + } + } + if (allocation_block.allocation->is_sliced_copy_allocation()) { + if (dynamic_cast(allocation_block.allocation) + ->sync_mem_op()) { + return true; + } + } + } + return false; } absl::StatusOr> MsaAlgorithm::Finish() { @@ -1630,24 +1813,9 @@ absl::StatusOr> MsaAlgorithm::Finish() { << " because it is already processed."; continue; } - auto joint_processed_values = GenerateJointProcessedValues(interval.buffer); - if (VLOG_IS_ON(3)) { - VLOG(3) << "Joint-processed values for " - << interval.buffer->ToShortString() << ": "; - for (size_t idx = 0; idx < joint_processed_values.size(); ++idx) { - const HloValue* hlo_value = joint_processed_values.at(idx); - VLOG(3) << " " << idx + 1 << "/" << joint_processed_values.size() - << ") " << hlo_value->ToShortString(); - } - } - std::vector joint_allocation_values; - std::vector> - joint_colocated_intervals; - CreateAllocationValuesForJointProcessedIntervals(joint_processed_values, - joint_allocation_values, - joint_colocated_intervals); - if (joint_allocation_values.empty()) { + JointAllocationProposal proposal = GetJointProposal(interval); + if (proposal.allocation_values.empty()) { VLOG(3) << "No allocation values for these joint-processed values."; continue; } @@ -1655,48 +1823,59 @@ absl::StatusOr> MsaAlgorithm::Finish() { bool repacked = false; for (int retry_number = 0; retry_number < options_.max_retries; retry_number++) { - for (auto& colocated_intervals : joint_colocated_intervals) { + for (auto& colocated_intervals : proposal.colocated_intervals) { AddRequiredAssignmentsForColocatedIntervals(colocated_intervals); } options_.prefetch_interval_picker->SetRetryNumber(retry_number); TF_ASSIGN_OR_RETURN( Result result, - AllocateAllocationValues(absl::MakeSpan(joint_allocation_values))); + AllocateAllocationValues(absl::MakeSpan(proposal.allocation_values))); VLOG(2) << "Allocation result = " << ResultToString(result); - if (result_is(result, Result::kFailSyncCopyReplacement)) { - CHECK(options_.enable_sync_copy_replacement) - << "Allocation result is Result::kFailSyncCopyReplacement, but " - "sync copy replacement is not enabled."; - for (const HloInstruction* copy_inst : - sorted_sync_copy_replacement_candidates_) { - VLOG(3) << "Adding " << copy_inst->ToShortString() - << " to the set of failed copy replacements. These copies " - "will not be considered for replacement future efforts."; - failed_copy_replacements_set_.insert(copy_inst); + VLOG(3) << "--Allocations List Begin--"; + for (int allocation_value_idx = 0; + allocation_value_idx < proposal.allocation_values.size(); + ++allocation_value_idx) { + auto& allocation_value = + proposal.allocation_values.at(allocation_value_idx); + VLOG(3) << allocation_value_idx + 1 << "/" + << proposal.allocation_values.size() << ") " + << allocation_value.ToShortString(); + for (auto& allocation : *allocation_value.allocation_sequence()) { + VLOG(3) << " " << allocation->ToString(); } - UncommitPendingChunks(absl::MakeSpan(joint_allocation_values)); - --retry_number; - VLOG(3) << "Updating the joint-processed values after sync copy " - "replacement failure."; - joint_processed_values = {interval.buffer}; - joint_allocation_values.clear(); - joint_colocated_intervals.clear(); - CreateAllocationValuesForJointProcessedIntervals( - joint_processed_values, joint_allocation_values, - joint_colocated_intervals); - if (joint_allocation_values.empty()) { + } + VLOG(3) << "--Allocations List End--"; + if (result_is(result, Result::kFailSyncDataMoveReplacement)) { + CHECK(options_.enable_sync_copy_replacement || + options_.enable_sync_slice_replacement) + << "Allocation result is Result::kFailSyncCopyReplacement, but " + "no sync replacement is enabled."; + UncommitPendingChunks(absl::MakeSpan(proposal.allocation_values)); + proposal = GetJointProposal(interval); + if (proposal.allocation_values.empty()) { VLOG(3) << "No allocation values found in the updated joint-processed " "values. Moving on to the next set of joint-processed values."; break; } + --retry_number; + } else if (result_requires_uncommit(result)) { - UncommitPendingChunks(absl::MakeSpan(joint_allocation_values)); + UncommitPendingChunks(absl::MakeSpan(proposal.allocation_values)); VLOG(2) << "Couldn't allocate. Retry number " << retry_number; + if (retry_number > 0 && !sorted_async_conversion_candidates_.empty()) { + failed_async_conversions_[sorted_async_conversion_candidates_.at(0)] = + AsyncConversionResult::kFailedGaveUp; + VLOG(2) << "Giving the allocation another chance by dropping one " + "async conversion candidate."; + proposal = GetJointProposal(interval); + --retry_number; + } } else if ((result_is(result, Result::kFailOutOfMemory) || options_.repack_after_every_allocation) && - num_repacks_ < options_.max_repacks && !repacked) { - UncommitPendingChunks(absl::MakeSpan(joint_allocation_values)); + num_repacks_ < options_.max_repacks && !repacked && + !RepackAllocationsIncludeConvertedSyncMemOp()) { + UncommitPendingChunks(absl::MakeSpan(proposal.allocation_values)); ++num_repacks_; repacked = true; CHECK_NE(options_.repacker, nullptr); @@ -1720,10 +1899,13 @@ absl::StatusOr> MsaAlgorithm::Finish() { // Check if any of the allocation sites are inefficient. If so, get rid // of the pending allocation, require all of the inefficient sites in // the default memory, and perform allocation again. - std::vector inefficient_sites = - GetInefficientAllocationSites(joint_allocation_values); + std::vector inefficient_sites = {}; + if (sorted_async_conversion_candidates_.empty()) { + inefficient_sites = + GetInefficientAllocationSites(proposal.allocation_values); + } if (!inefficient_sites.empty()) { - UncommitPendingChunks(absl::MakeSpan(joint_allocation_values)); + UncommitPendingChunks(absl::MakeSpan(proposal.allocation_values)); for (const HloPositionOrUse& site : inefficient_sites) { // To avoid a livelock situation, we commit the required assignments // right away. Otherwise, reallocation can find alternate memory @@ -1741,19 +1923,22 @@ absl::StatusOr> MsaAlgorithm::Finish() { continue; } - FinalizeAllocations(absl::MakeSpan(joint_allocation_values)); + FinalizeAllocations(absl::MakeSpan(proposal.allocation_values)); break; } } // Keep track of the processed values to prevent double-processing in future // joint-processed intervals. - for (auto& value : joint_processed_values) { + for (auto& value : proposal.values) { finalized_values_.insert(value); } } if (options_.repack_after_every_allocation) { CHECK_NE(options_.repacker, nullptr); + CHECK(!RepackAllocationsIncludeConvertedSyncMemOp()) + << "Repacking is not supported yet when there are converted sync mem " + "ops."; std::vector repack_allocation_blocks; ExportAllocationsForRepacking(repack_allocation_blocks); VLOG(2) << "Final Repacking."; @@ -1781,11 +1966,12 @@ absl::StatusOr> MsaAlgorithm::Finish() { if (VLOG_IS_ON(3)) { VLOG(3) << "Sync copy replacement summary: "; - for (const HloInstruction* inst : successful_copy_replacements_set_) { + for (const HloInstruction* inst : successful_async_conversion_set_) { VLOG(3) << "Successful copy replacement: " << inst->ToString(); } - for (const HloInstruction* inst : failed_copy_replacements_set_) { - VLOG(3) << "Failed copy replacement: " << inst->ToString(); + for (auto& failure : failed_async_conversions_) { + VLOG(3) << "Failed copy replacement: " << failure.first->ToString() + << ", reason: " << int(failure.second); } } @@ -2046,9 +2232,10 @@ void MsaAlgorithm::AddRequiredAssignmentsForColocatedIntervals( void MsaAlgorithm::CreateAllocationValuesFromColocatedIntervals( absl::Span colocated_intervals, std::vector& allocation_values) { + std::vector new_allocation_values; // Create AllocationValues for all the colocated intervals. for (const auto& colocated_interval : colocated_intervals) { - CreateAllocationValues(*colocated_interval, allocation_values); + CreateAllocationValues(*colocated_interval, new_allocation_values); } // Go through the AllocationValues and delete the ones that have the identical // defining instruction and use instructions. This is useful for async @@ -2066,59 +2253,140 @@ void MsaAlgorithm::CreateAllocationValuesFromColocatedIntervals( } return instruction_vector; }; - for (int i = 0; i < allocation_values.size() - 1; ++i) { - for (int j = i + 1; j < allocation_values.size(); ++j) { - const AllocationValue& allocation_value_1 = allocation_values[i]; - const AllocationValue& allocation_value_2 = allocation_values[j]; + for (int i = 0; i < new_allocation_values.size() - 1; ++i) { + for (int j = i + 1; j < new_allocation_values.size(); ++j) { + const AllocationValue& allocation_value_1 = new_allocation_values[i]; + const AllocationValue& allocation_value_2 = new_allocation_values[j]; if (create_instruction_vector(allocation_value_1) == create_instruction_vector(allocation_value_2)) { VLOG(3) << "Allocation values " << allocation_value_1.ToShortString() << " and " << allocation_value_2.ToShortString() << " are equivalent, deleting the second one."; - allocation_values.erase(allocation_values.begin() + j); + new_allocation_values.erase(new_allocation_values.begin() + j); --j; } } } - FindAliases(&allocation_values); + FindAliases(&new_allocation_values); + absl::c_move(new_allocation_values, std::back_inserter(allocation_values)); +} + +bool MsaAlgorithm::RequiresNoCopyAlternateMemAllocation( + AllocationValue& allocation_value) const { + return allocation_value.value()->shape().has_layout() && + allocation_value.value()->shape().layout().memory_space() == + options_.alternate_memory_space; +} + +void MsaAlgorithm::AssignDefaultMemIfNotAllowedInAlternateMem( + AllocationValue& allocation_value, int64_t definition_time) { + if (!options_.is_position_allowed_in_alternate_mem_fn( + allocation_value.defining_position())) { + if (RequiresNoCopyAlternateMemAllocation(allocation_value)) { + LOG(WARNING) << "The value " << allocation_value.value()->ToShortString() + << " is pre-colored for alternate memory but the position " + << allocation_value.defining_position().ToString() + << " is not allowed in the alternate memory. Respecting the " + "color " + "but this may break things later in compilation."; + } else { + AddRequiredAssignment(allocation_value.value(), + allocation_value.defining_instruction(), + MemorySpace::kDefault, definition_time); + } + } +} + +std::vector +MsaAlgorithm::GenerateAllocationSegmentContexts( + absl::Span& allocation_values, + absl::flat_hash_map>& + value_indices_by_sync_inst, + int allocation_value_idx) const { + AllocationValue& allocation_value = + allocation_values.at(allocation_value_idx); + std::vector uses_work_list; + for (int primary_use_idx = 0; + primary_use_idx < allocation_value.uses().size(); ++primary_use_idx) { + AllocationValue::Use& primary_use = + allocation_value.uses().at(primary_use_idx); + if (!IsInstructionPendingReplacements(primary_use.hlo_use.instruction)) { + uses_work_list.push_back({&allocation_value.uses(), primary_use_idx, + allocation_value_idx, false}); + } else { + uses_work_list.push_back({&allocation_value.uses(), primary_use_idx, + allocation_value_idx, true}); + for (auto sync_destination_idx : + value_indices_by_sync_inst.at(primary_use.hlo_use.instruction)) { + AllocationValue& sync_destination = + allocation_values.at(sync_destination_idx); + if (sync_destination.defining_instruction() == + primary_use.hlo_use.instruction) { + VLOG(3) << "Adding secondary uses related to allocation value " + << sync_destination.ToShortString() + << " to uses worklist, because the allocation value is " + "defined at the copy use instruction output."; + for (int secondary_use_id = 0; + secondary_use_id < sync_destination.uses().size(); + ++secondary_use_id) { + // This is an important line + sync_destination.uses().at(secondary_use_id).sync_mem_op_operand = + primary_use.hlo_use.instruction; + int allocation_value_to_update_idx = sync_destination_idx; + uses_work_list.push_back({&sync_destination.uses(), + secondary_use_id, + allocation_value_to_update_idx, false}); + } + } else { + VLOG(3) << "Skipping secondary uses related to allocation value " + << sync_destination.ToShortString() + << ", because the allocation value is not defined at the " + "copy use instruction " + "output."; + } + } + } + } + // Sort uses according to their use time + std::sort(uses_work_list.begin(), uses_work_list.end(), + [](const auto& a, const auto& b) { + return a.uses->at(a.use_idx).time < b.uses->at(b.use_idx).time; + }); + VLOG(3) << "Uses work list:"; + for (int i = 0; i < uses_work_list.size(); i++) { + auto [uses, use_idx, allocation_value_to_update_idx, + only_extend_existing_allocation] = uses_work_list.at(i); + VLOG(3) << " " << i + 1 << "/" << uses_work_list.size() << ") " + << uses->at(use_idx).hlo_use.ToString(); + } + if (uses_work_list.empty()) { + VLOG(3) << " --Empty--"; + } + return uses_work_list; } absl::StatusOr MsaAlgorithm::AllocateAllocationValues( absl::Span allocation_values) { const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); - absl::flat_hash_map> - value_indices_by_copy_inst; + value_indices_by_sync_inst; for (size_t idx = 0; idx < allocation_values.size(); ++idx) { const HloInstruction* inst = allocation_values.at(idx).defining_instruction(); - if (IsReplaceableSyncCopyCandidate(inst)) { - value_indices_by_copy_inst[inst].push_back(idx); + if (IsInstructionPendingReplacements(inst)) { + value_indices_by_sync_inst[inst].push_back(idx); } } - absl::flat_hash_set all_use_times_set; + // Extract all use times + std::vector all_use_times; for (const AllocationValue& allocation_value : allocation_values) { for (const auto& use : allocation_value.uses()) { - if (!IsInstructionInPendingCopyReplacements(use)) { - all_use_times_set.insert(use.time); - } else { - for (size_t copy_destination_idx : - value_indices_by_copy_inst[use.hlo_use.instruction]) { - const AllocationValue& copy_destination = - allocation_values.at(copy_destination_idx); - for (const auto& copy_use : copy_destination.uses()) { - all_use_times_set.insert(copy_use.time); - } - } - } + all_use_times.push_back(use.time); } } - std::vector all_use_times(all_use_times_set.begin(), - all_use_times_set.end()); absl::c_sort(all_use_times); - for (int i = 0; i < all_use_times.size(); ++i) { VLOG(3) << "all_use_times[" << i << "] = " << all_use_times[i]; } @@ -2128,132 +2396,83 @@ absl::StatusOr MsaAlgorithm::AllocateAllocationValues( // as well as inside the while loop. absl::flat_hash_map preferred_offset_for_computation; + absl::flat_hash_map + preferred_offset_for_allocation_value; + absl::flat_hash_map + definition_time_for_allocation_value; Result result = Result::kSuccess; - for (int allocation_value_idx = 0; - allocation_value_idx < allocation_values.size(); - ++allocation_value_idx) { - auto& allocation_value = allocation_values.at(allocation_value_idx); - VLOG(3) << allocation_value_idx + 1 << "/" << allocation_values.size() + for (int alloc_value_idx = 0; alloc_value_idx < allocation_values.size(); + ++alloc_value_idx) { + auto& allocation_value = allocation_values.at(alloc_value_idx); + VLOG(3) << alloc_value_idx + 1 << "/" << allocation_values.size() << ") Allocating allocation value: " << allocation_value.ToShortString(); - if (IsInstructionInPendingCopyReplacements( + if (IsInstructionPendingReplacements( allocation_value.defining_instruction())) { VLOG(3) << "Skip allocating allocation value " << allocation_value.ToShortString(); continue; } - int64_t definition_time = - instruction_schedule.at(allocation_value.defining_instruction()); - - bool require_no_copy_alternate_mem_allocation = - allocation_value.value()->shape().has_layout() && - allocation_value.value()->shape().layout().memory_space() == - options_.alternate_memory_space; VLOG(4) << "require_no_copy_alternate_mem_allocation = " - << require_no_copy_alternate_mem_allocation; - if (require_no_copy_alternate_mem_allocation && + << RequiresNoCopyAlternateMemAllocation(allocation_value); + if (RequiresNoCopyAlternateMemAllocation(allocation_value) && allocation_value.size() > available_heap_size()) { VLOG(3) << "Skip " << allocation_value.value()->ToShortString() << " because the buffer is larger than the heap size."; continue; } - if (!options_.is_position_allowed_in_alternate_mem_fn( - allocation_value.defining_position())) { - if (require_no_copy_alternate_mem_allocation) { - LOG(WARNING) - << "The value " << allocation_value.value()->ToShortString() - << " is pre-colored for alternate memory but the position " - << allocation_value.defining_position().ToString() - << " is not allowed in the alternate memory. Respecting the color " - "but this may break things later in compilation."; - } else { - AddRequiredAssignment(allocation_value.value(), - allocation_value.defining_instruction(), - MemorySpace::kDefault, definition_time); - } - } - - AliasedOffset* preferred_offset = nullptr; - auto preferred_offset_it = - preferred_offset_for_computation.find(allocation_value.computation()); - if (preferred_offset_it != preferred_offset_for_computation.end()) { - preferred_offset = preferred_offset_it->second; - } const AllocationValue::Use* previous_use = nullptr; - std::vector&>> - uses_work_list; - for (int primary_use_idx = 0; - primary_use_idx < allocation_value.uses().size(); ++primary_use_idx) { - AllocationValue::Use& primary_use = - allocation_value.uses().at(primary_use_idx); - if (!IsInstructionInPendingCopyReplacements(primary_use)) { - uses_work_list.push_back({primary_use_idx, allocation_value.uses()}); - } else { - for (auto copy_destination_idx : - value_indices_by_copy_inst[primary_use.hlo_use.instruction]) { - AllocationValue& copy_destination = - allocation_values.at(copy_destination_idx); - if (copy_destination.defining_instruction() == - primary_use.hlo_use.instruction) { - VLOG(3) << "Adding secondary uses related to allocation value " - << copy_destination.ToShortString() - << " to uses worklist, because the allocation value is " - "defined at the copy use instruction output."; - for (int secondary_use_id = 0; - secondary_use_id < copy_destination.uses().size(); - ++secondary_use_id) { - // This is an important line - copy_destination.uses().at(secondary_use_id).copy_source = - primary_use.hlo_use.instruction; - uses_work_list.push_back( - {secondary_use_id, copy_destination.uses()}); - } - } else { - VLOG(3) << "Skipping secondary uses related to allocation value " - << copy_destination.ToShortString() - << ", because the allocation value is not defined at the " - "copy use instruction " - "output."; - } - } - } - } - VLOG(3) << "Uses work list:"; - for (int i = 0; i < uses_work_list.size(); i++) { - auto [use_idx, uses] = uses_work_list.at(i); - VLOG(3) << " " << i + 1 << "/" << uses_work_list.size() << ") " - << uses.at(use_idx).hlo_use.ToString(); - } - if (uses_work_list.empty()) { - VLOG(3) << " --Empty--"; - } + auto uses_work_list = GenerateAllocationSegmentContexts( + allocation_values, value_indices_by_sync_inst, alloc_value_idx); // Iterate over the uses. - for (auto& [use_idx, uses] : uses_work_list) { - const AllocationValue::Use& use = uses.at(use_idx); - VLOG(3) << "Working on use: " << use.hlo_use.ToString(); - // if (!use.copy_destinations.empty() && - if (IsInstructionInPendingCopyReplacements(use) && - use.hlo_use.instruction->IsUserOf( - allocation_value.defining_instruction())) { - // If the use is a copy instruction, we only process the uses of the - // allocation values which are defined at the output of the copy - // instruction. The rest of the uses will be processed later in the - // high-level loop over allocation_values. Nested copy uses are - // distinguished by checking if the use is a direct users of the - // allocation value - VLOG(3) << "Skip allocating a segment for use " - << use.hlo_use.ToString() << " because it's a copy use."; - continue; + for (auto& entry : uses_work_list) { + const AllocationValue::Use& use = entry.uses->at(entry.use_idx); + AllocationValue& allocation_value_to_update = + allocation_values.at(entry.allocation_value_to_update_idx); + std::string extension_only_hint_str = + entry.only_extend_existing_allocation ? " (extension only): " : ": "; + VLOG(3) << "Working on use" << extension_only_hint_str + << use.hlo_use.ToString() + << ", allocation value: " << allocation_value.ToShortString() + << ", updates allocation value: " + << allocation_value_to_update.ToShortString(); + + if (!definition_time_for_allocation_value.contains( + &allocation_value_to_update)) { + definition_time_for_allocation_value[&allocation_value_to_update] = + hlo_live_range_.instruction_schedule().at( + allocation_value_to_update.defining_instruction()); + AssignDefaultMemIfNotAllowedInAlternateMem( + allocation_value_to_update, definition_time_for_allocation_value.at( + &allocation_value_to_update)); + } + + if (!preferred_offset_for_allocation_value.contains( + &allocation_value_to_update)) { + auto preferred_offset_it = preferred_offset_for_computation.find( + allocation_value_to_update.computation()); + if (preferred_offset_it != preferred_offset_for_computation.end()) { + preferred_offset_for_allocation_value[&allocation_value_to_update] = + preferred_offset_it->second; + } else { + preferred_offset_for_allocation_value[&allocation_value_to_update] = + nullptr; + } } - preferred_offset = UpdatePreferredOffsetForUse(use, preferred_offset); + preferred_offset_for_allocation_value[&allocation_value_to_update] = + UpdatePreferredOffsetForUse(use, + preferred_offset_for_allocation_value.at( + &allocation_value_to_update)); AllocationRequest request = CreateAllocationRequest( - allocation_value, use, previous_use, preferred_offset, - definition_time, require_no_copy_alternate_mem_allocation, - all_use_times); + allocation_value, allocation_value_to_update, use, previous_use, + preferred_offset_for_allocation_value.at(&allocation_value_to_update), + definition_time_for_allocation_value.at(&allocation_value_to_update), + RequiresNoCopyAlternateMemAllocation(allocation_value_to_update), + all_use_times, entry.only_extend_existing_allocation); // Bitcasts don't define buffers and don't directly consume buffers. // Skip allocating buffers for bitcast uses (unless they are the root // instruction). The uses that feed from bitcasts will be handled @@ -2264,24 +2483,47 @@ absl::StatusOr MsaAlgorithm::AllocateAllocationValues( result_mark(AllocateSegment(request), result); if (request.require_copy_allocation) { auto allocation_sequence = - request.allocation_value->mutable_allocation_sequence(); + allocation_value_to_update.mutable_allocation_sequence(); auto it = std::find_if( allocation_sequence->begin(), allocation_sequence->end(), [&](const std::unique_ptr< xla::memory_space_assignment::Allocation>& allocation_ptr) { - auto copy_allocation = - dynamic_cast(allocation_ptr.get()); - return copy_allocation && - (copy_allocation->copy_done_schedule_before() <= - request.required_copy_allocation_latest_time); + if (allocation_ptr->is_copy_allocation()) { + auto copy_allocation = + dynamic_cast(allocation_ptr.get()); + return copy_allocation && + (copy_allocation->copy_done_schedule_before() <= + request.required_copy_allocation_latest_time) && + (copy_allocation->sync_mem_op() == + request.required_copy_allocation_for) && + (!request.required_copy_for_slice || + (request.required_copy_for_slice && + !copy_allocation->cross_program_prefetch_index() + .has_value())); + } + if (allocation_ptr->is_sliced_copy_allocation()) { + auto sliced_copy_allocation = + dynamic_cast( + allocation_ptr.get()); + return sliced_copy_allocation && + (sliced_copy_allocation->earliest_available_time() <= + request.required_copy_allocation_latest_time) && + (sliced_copy_allocation->sync_mem_op() == + request.required_copy_allocation_for) && + !request.required_copy_for_slice; + } + return false; }); + if (result_requires_uncommit(result) || it == allocation_sequence->end()) { VLOG(3) << "No async copy allocation found by the end of " "segment allocation. " "Sync copy replacement has failed. Fall back to the " "normal mode."; - result_mark(Result::kFailSyncCopyReplacement, result); + failed_async_conversions_[request.required_copy_allocation_for] = + AsyncConversionResult::kFailedSatisfyingConstraints; + result_mark(Result::kFailSyncDataMoveReplacement, result); result_mark(Result::kFailRequiresUncommit, result); } else { bool has_correct_use = false; @@ -2296,9 +2538,13 @@ absl::StatusOr MsaAlgorithm::AllocateAllocationValues( "segment allocation with the correct use. " "Sync copy replacement has failed. Fall back to the " "normal mode."; - result_mark(Result::kFailSyncCopyReplacement, result); + failed_async_conversions_[request.required_copy_allocation_for] = + AsyncConversionResult::kFailedPrecondition; + result_mark(Result::kFailSyncDataMoveReplacement, result); result_mark(Result::kFailRequiresUncommit, result); } else { + not_finalized_async_conversions_.push_back( + request.required_copy_allocation_for); VLOG(3) << "Replacing " << request.required_copy_allocation_for->ToShortString() << " with " << (*it)->ToString(); @@ -2325,19 +2571,48 @@ absl::StatusOr MsaAlgorithm::AllocateAllocationValues( // If there are multiple uses, they can try using the memory // allocation already at the alternate memory. - definition_time = instruction_schedule.at(use.hlo_use.instruction); + definition_time_for_allocation_value[&allocation_value_to_update] = + instruction_schedule.at(use.hlo_use.instruction); previous_use = &use; } + if (entry.only_extend_existing_allocation) { + continue; + } const auto use_time = request.end_time; - UpdateAllocationRequirementForUseAliases(allocation_value, use, use_time); + UpdateAllocationRequirementForUseAliases(allocation_value_to_update, use, + use_time); MaybeCreateMirroredParentAllocationForWhileUse( - allocation_value, use, use_time, allocation_values, + allocation_value_to_update, use, use_time, allocation_values, preferred_offset_for_computation); } } + + if (!VerifyAllConversionsAreSuccessful()) { + result_mark(Result::kFailSyncDataMoveReplacement, result); + result_mark(Result::kFailRequiresUncommit, result); + } + return result; } +bool MsaAlgorithm::VerifyAllConversionsAreSuccessful() { + for (const HloInstruction* instruction : + sorted_async_conversion_candidates_) { + if (absl::c_find(not_finalized_async_conversions_, instruction) == + not_finalized_async_conversions_.end()) { + if (!failed_async_conversions_.contains(instruction)) { + failed_async_conversions_[instruction] = + AsyncConversionResult::kFailedNotProcessed; + VLOG(3) << "Async conversion failed for " + << instruction->ToShortString() + << " because its operand or user was not processed."; + } + return false; + } + } + return true; +} + MsaAlgorithm::AliasedOffset* MsaAlgorithm::UpdatePreferredOffsetForUse( const AllocationValue::Use& use, MsaAlgorithm::AliasedOffset* preferred_offset) const { @@ -2358,20 +2633,50 @@ MsaAlgorithm::AliasedOffset* MsaAlgorithm::UpdatePreferredOffsetForUse( } MsaAlgorithm::AllocationRequest MsaAlgorithm::CreateAllocationRequest( - AllocationValue& allocation_value, const AllocationValue::Use& use, - const AllocationValue::Use* previous_use, AliasedOffset* preferred_offset, - int64_t definition_time, bool require_no_copy_alternate_mem_allocation, - const std::vector& all_use_times) { + AllocationValue& allocation_value, + AllocationValue& allocation_value_to_update, + const AllocationValue::Use& use, const AllocationValue::Use* previous_use, + AliasedOffset* preferred_offset, int64_t definition_time, + bool require_no_copy_alternate_mem_allocation, + const std::vector& all_use_times, + bool only_extend_existing_allocation) { const HloUse& hlo_use = use.hlo_use; const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); bool require_copy_allocation = false; int64_t required_copy_allocation_latest_time = 0; HloInstruction* required_copy_allocation_for = nullptr; - if (use.copy_source && - IsInstructionInPendingCopyReplacements(use.copy_source)) { - required_copy_allocation_latest_time = GetCorrectedUseTime(use.copy_source); - required_copy_allocation_for = use.copy_source; + bool required_copy_for_slice = false; + if (use.sync_mem_op_operand && + IsInstructionPendingReplacements(use.sync_mem_op_operand)) { + required_copy_allocation_for = use.sync_mem_op_operand; require_copy_allocation = true; + required_copy_for_slice = + (IsAsyncConversionSliceCandidate(use.sync_mem_op_operand) == + AsyncConversionResult::kSuccess); + // The async copy allocation can be delayed until the earliest time at which + // the value is used in a position or the earliest use time of the updated + // allocation value. We find the minimum of these two times. + int64_t min_time = + GetCorrectedUseTime(allocation_value.defining_instruction()); + int64_t earliest_position_time = std::numeric_limits::max(); + for (auto& position : allocation_value.value()->positions()) { + auto position_time = GetCorrectedUseTime(position.instruction); + if (position_time > min_time) { + earliest_position_time = + std::min(earliest_position_time, position_time); + } + } + int64_t earliest_use_time = std::numeric_limits::max(); + for (auto& secondary_use : allocation_value_to_update.uses()) { + if (!IsTrivialInstruction(secondary_use.hlo_use.instruction) || + secondary_use.hlo_use.instruction == + use.hlo_use.instruction->parent()->root_instruction()) { + earliest_use_time = std::min( + earliest_use_time, GetCorrectedUseTime(secondary_use.hlo_use)); + } + } + required_copy_allocation_latest_time = + std::min(earliest_use_time, earliest_position_time); } int64_t use_time = instruction_schedule.at(hlo_use.instruction); bool allow_no_copy_alternate_mem_allocation = true; @@ -2400,16 +2705,18 @@ MsaAlgorithm::AllocationRequest MsaAlgorithm::CreateAllocationRequest( // Add a required assignment in default memory if the use not allowed in // alternate memory. - if (!IsUseAllowedInAlternateMemory(allocation_value, hlo_use)) { + if (!IsUseAllowedInAlternateMemory(allocation_value_to_update, hlo_use)) { if (require_no_copy_alternate_mem_allocation) { - LOG(WARNING) << "The value " << allocation_value.value()->ToShortString() + LOG(WARNING) << "The value " + << allocation_value_to_update.value()->ToShortString() << " is pre-colored for alternate memory but the use " << hlo_use.ToString() << " is not allowed in the alternate memory. Respecting the " "color but this may break things later in compilation."; } else { - AddRequiredAssignment(allocation_value.value(), hlo_use.instruction, - MemorySpace::kDefault, use_time); + AddRequiredAssignment(allocation_value_to_update.value(), + hlo_use.instruction, MemorySpace::kDefault, + use_time); } } else if (previous_use != nullptr) { // We allow buffers in alternate memory that are passed into @@ -2488,7 +2795,9 @@ MsaAlgorithm::AllocationRequest MsaAlgorithm::CreateAllocationRequest( } } - if (options_.use_repeated_instance_for_preferred_prefetch_time) { + // TODO(mehrdadk): Remove this code once we have a better way to find + // repeated instructions. + if (false) { const std::vector* repeated_insts = GetRepeatedInstructionList(hlo_use.instruction); if (repeated_insts) { @@ -2540,7 +2849,7 @@ MsaAlgorithm::AllocationRequest MsaAlgorithm::CreateAllocationRequest( // time is the parameter use, which is less. request.inclusive_start_time = std::min(definition_time, use_time); request.latest_prefetch_time = latest_prefetch_time; - request.size = allocation_value.size(); + request.size = allocation_value_to_update.size(); request.prefer_no_copy_alternate_mem_allocation = prefer_no_copy_alternate_mem_allocation; request.allow_no_copy_alternate_mem_allocation = @@ -2558,9 +2867,12 @@ MsaAlgorithm::AllocationRequest MsaAlgorithm::CreateAllocationRequest( request.required_copy_allocation_latest_time = required_copy_allocation_latest_time; request.required_copy_allocation_for = required_copy_allocation_for; + request.required_copy_for_slice = required_copy_for_slice; + request.allocation_value_to_update = &allocation_value_to_update; } request.end_time = use_time; + request.only_extend_existing_allocation = only_extend_existing_allocation; return request; } @@ -3127,12 +3439,13 @@ void MsaAlgorithm::AllocateCrossProgramPrefetchBuffer( int64_t cross_program_prefetch_end_time = free_buffer ? last_use_time : prefetch_candidate.end; - AddAsyncCopy(*allocations.back(), MemorySpace::kAlternate, chunk_candidate, - /*exclusive_start_time=*/ - InclusiveToExclusiveStartTime(prefetch_candidate.start), - cross_program_prefetch_end_time, latest_prefetch_time, - &allocations, /*aliased_offset=*/nullptr, - /*resource=*/0.0, cross_program_prefetch_index); + AddAsyncCopyOrOtherMemOp( + *allocations.back(), MemorySpace::kAlternate, chunk_candidate, + /*exclusive_start_time=*/ + InclusiveToExclusiveStartTime(prefetch_candidate.start), + cross_program_prefetch_end_time, latest_prefetch_time, &allocations, + /*aliased_offset=*/nullptr, + /*resource=*/0.0, cross_program_prefetch_index); absl::c_for_each(uses, [&](auto& use) { allocations.back()->AddUse(use); }); AliasedOffset* cross_program_prefetch_offset = @@ -3141,14 +3454,14 @@ void MsaAlgorithm::AllocateCrossProgramPrefetchBuffer( if (free_buffer) { VLOG(2) << "Adding an end-of-program prefetch for freed " "cross-program-prefetched buffer."; - AddAsyncCopy(*allocations.front(), MemorySpace::kAlternate, chunk_candidate, - /*exclusive_start_time=*/ - InclusiveToExclusiveStartTime( - end_of_program_inclusive_prefetch_start_time), - end_of_program_prefetch_end_time, - end_of_program_prefetch_end_time, &allocations, - cross_program_prefetch_offset, - /*resource=*/0.0); + AddAsyncCopyOrOtherMemOp( + *allocations.front(), MemorySpace::kAlternate, chunk_candidate, + /*exclusive_start_time=*/ + InclusiveToExclusiveStartTime( + end_of_program_inclusive_prefetch_start_time), + end_of_program_prefetch_end_time, end_of_program_prefetch_end_time, + &allocations, cross_program_prefetch_offset, + /*resource=*/0.0); CHECK_EQ(cross_program_prefetch_offset->offset, allocations.back()->chunk().offset); } @@ -3799,12 +4112,6 @@ absl::Status MsaAlgorithm::AreRepackedSlicesValid( void MsaAlgorithm::UncommitPendingChunks( absl::Span allocation_values) { - if (!sorted_sync_copy_replacement_candidates_.empty()) { - VLOG(3) << "Withdrawing copy replacement efforts for this group of " - "joint-processed intervals, because the initial allocation " - "attempt was not successful."; - sorted_sync_copy_replacement_candidates_.clear(); - } // Clear the allocation sequence of the allocation values so that in case we // retry allocation after uncommitting. for (AllocationValue& allocation_value : allocation_values) { @@ -3869,10 +4176,10 @@ void MsaAlgorithm::UncommitPendingChunks( void MsaAlgorithm::FinalizeAllocations( absl::Span allocation_values) { - for (const HloInstruction* copy_inst : - sorted_sync_copy_replacement_candidates_) { - successful_copy_replacements_set_.insert(copy_inst); + for (const HloInstruction* copy_inst : sorted_async_conversion_candidates_) { + successful_async_conversion_set_.insert(copy_inst); } + not_finalized_async_conversions_.clear(); std::vector>> colocation_vector; absl::flat_hash_map offset_to_index; @@ -3938,12 +4245,11 @@ void MsaAlgorithm::ClearPendingChunks() { aliased_offsets_.clear(); } -bool MsaAlgorithm::IsInstructionInPendingCopyReplacements( +bool MsaAlgorithm::IsInstructionPendingReplacements( const HloInstruction* instruction) const { - return std::find(sorted_sync_copy_replacement_candidates_.begin(), - sorted_sync_copy_replacement_candidates_.end(), - instruction) != - sorted_sync_copy_replacement_candidates_.end(); + return std::find(sorted_async_conversion_candidates_.begin(), + sorted_async_conversion_candidates_.end(), + instruction) != sorted_async_conversion_candidates_.end(); } void MsaAlgorithm::AddToPendingChunks(const MsaBufferInterval& buffer_interval, @@ -4000,8 +4306,8 @@ std::string MsaAlgorithm::SingleFailureResultToString(const Result& result) { return "AllSlicesHaveTheSameStartTime"; case Result::kFailConflictingPreferredOffsets: return "FailConflictingPreferredOffsets"; - case Result::kFailSyncCopyReplacement: - return "FailSyncCopyReplacement"; + case Result::kFailSyncDataMoveReplacement: + return "FailSyncDataMoveReplacement"; default: return "UnknownResult"; } @@ -4068,6 +4374,10 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) { *use.instruction, use.operand_number, use.operand_index); } + if (request.only_extend_existing_allocation && + !allocation_sequence->empty()) { + allocation_sequence->back()->Extend(request.inclusive_start_time); + } // There could be a requirement to pin this buffer to default memory either // because it is a parameter or an output. If the buffer is a parameter, then // we're allowed to prefetch. If the use expects the output to be in default @@ -4082,7 +4392,7 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) { // Find required assignment both for the use and its aliases. If they are both // non-nullopt, then make sure they require the same assignment. auto required_assignment_at_end = RequiredMemoryAssignmentAt( - request.allocation_value->value(), request.end_time); + request.allocation_value_to_update->value(), request.end_time); auto aliased_required_assignment_at_end = AliasedRequiredAssignmentForUse(*request.use); if (required_assignment_at_end != aliased_required_assignment_at_end) { @@ -4108,7 +4418,7 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) { return allocation->memory_space() == required_memory_space_at_start; }); if (prev_allocation_it != allocation_sequence->rend()) { - (*prev_allocation_it)->set_end_time(request.inclusive_start_time); + (*prev_allocation_it)->Extend(request.inclusive_start_time); needs_required_allocation = false; } } @@ -4136,7 +4446,8 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) { // First try keeping the allocation entirely in the alternate memory. if (required_memory_space_at_start != MemorySpace::kDefault && required_memory_space_at_end != MemorySpace::kDefault && - request.allow_no_copy_alternate_mem_allocation) { + request.allow_no_copy_alternate_mem_allocation && + !request.require_copy_allocation) { allocation_result = AllocateInAlternateMemoryNoCopy(request); if (allocation_result == Result::kSuccess) { return Result::kSuccess; @@ -4206,7 +4517,7 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) { if (required_memory_space_at_end == MemorySpace::kDefault) { VLOG(3) << "Not trying to prefetch because use requires buffer in default mem."; - (*prev_allocation_in_default_mem_it)->set_end_time(request.end_time); + (*prev_allocation_in_default_mem_it)->Extend(request.end_time); (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use); // If the buffer is placed in default memory, we can also try window @@ -4218,8 +4529,9 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) { // Finally, try to prefetch the buffer into alternate memory. if (request.allow_prefetch && - !request.allocation_value->requires_contiguous_allocation()) { - if (request.require_copy_allocation) { + !request.allocation_value->requires_contiguous_allocation() && + !request.only_extend_existing_allocation) { + if (request.require_copy_allocation && !request.required_copy_for_slice) { auto it = std::find_if( allocation_sequence->begin(), allocation_sequence->end(), [&](const std::unique_ptr& @@ -4248,9 +4560,12 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) { if (request.preferred_prefetch_time) { // Warn if the prefetch time picked doesn't match the preferred prefetch // time. - CHECK(!request.allocation_value->allocation_sequence()->empty()); + CHECK(!request.allocation_value_to_update->allocation_sequence() + ->empty()); const Allocation* allocation = - request.allocation_value->allocation_sequence()->back().get(); + request.allocation_value_to_update->allocation_sequence() + ->back() + .get(); int64_t prefetch_time = 0; if (allocation->is_copy_allocation()) { prefetch_time = static_cast(allocation) @@ -4300,7 +4615,7 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) { // If a copy wasn't inserted, then add this use to the latest allocation in // default memory. - (*prev_allocation_in_default_mem_it)->set_end_time(request.end_time); + (*prev_allocation_in_default_mem_it)->Extend(request.end_time); (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use); // If the buffer is placed in default memory, we can try window prefetching @@ -4324,12 +4639,13 @@ void MsaAlgorithm::AddAsyncCopyForWindowPrefetch( /*cross_program_prefetch_index=*/std::nullopt); } -void MsaAlgorithm::AddAsyncCopy( +void MsaAlgorithm::AddAsyncCopyOrOtherMemOp( Allocation& prev_allocation, MemorySpace memory_space, std::optional chunk, int64_t exclusive_start_time, int64_t end_time, int64_t copy_done_schedule_before_time, AllocationSequence* allocations, AliasedOffset* aliased_offset, float resource, - std::optional cross_program_prefetch_index) { + std::optional cross_program_prefetch_index, + HloInstruction* sync_mem_op) { VLOG(3) << "Copy to " << (memory_space == MemorySpace::kDefault ? "default" : "alternate") << " memory in (" << exclusive_start_time << ", " @@ -4339,7 +4655,8 @@ void MsaAlgorithm::AddAsyncCopy( allocations->push_back(std::make_unique( prev_allocation, memory_space, chunk, exclusive_start_time, - copy_done_schedule_before_time, end_time, cross_program_prefetch_index)); + copy_done_schedule_before_time, end_time, cross_program_prefetch_index, + sync_mem_op)); RegisterAsyncCopy(memory_space, exclusive_start_time, copy_done_schedule_before_time, allocations, aliased_offset, @@ -4412,7 +4729,8 @@ void MsaAlgorithm::AddAsyncSlicesForPrefetch( const Allocation& prev_allocation, AllocationSequence* allocations, AliasedOffset* aliased_offset, const std::vector& slice_decisions_sorted_by_start_time, - int64_t prefetch_end_time, int64_t allocation_end_time) { + int64_t prefetch_end_time, int64_t allocation_end_time, + HloInstruction* sync_mem_op) { VLOG(3) << "Sliced copy to alternate memory. " << SliceTimesAndCopyResourcesToString( slice_decisions_sorted_by_start_time, prefetch_end_time, @@ -4426,7 +4744,7 @@ void MsaAlgorithm::AddAsyncSlicesForPrefetch( prev_allocation, MemorySpace::kAlternate, slice_decisions_sorted_by_start_time, prefetch_end_time, allocation_end_time, options_.sliced_prefetch_options, - options_.get_equivalent_s8_shape_fn)); + options_.get_equivalent_s8_shape_fn, sync_mem_op)); // Register the additional async copy with the interval tree to keep track of // the limit at any given time. @@ -4584,7 +4902,7 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateInAlternateMemoryNoCopy( if (prev_allocation != nullptr && (prev_allocation->is_copy_like_allocation() || prev_allocation->defining_position() == defining_position)) { - prev_allocation->set_end_time(request.end_time); + prev_allocation->Extend(request.end_time); } else { request.allocation_value->mutable_allocation_sequence()->push_back( std::make_unique( @@ -4595,8 +4913,10 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateInAlternateMemoryNoCopy( *request.allocation_value->allocation_sequence()->back(), preferred_offset); } - request.allocation_value->allocation_sequence()->back()->AddUse( - request.use->hlo_use); + if (!request.only_extend_existing_allocation) { + request.allocation_value->allocation_sequence()->back()->AddUse( + request.use->hlo_use); + } return Result::kSuccess; } if (request.prefer_no_copy_alternate_mem_allocation) { @@ -4627,18 +4947,23 @@ MsaAlgorithm::Result MsaAlgorithm::Evict(const AllocationRequest& request) { int64_t eviction_end_time = prev_allocation->end_time(); CHECK(eviction_exclusive_start_time <= eviction_end_time); - int64_t preferred_eviction_end_time = - std::max(options_.prefetch_interval_picker->PreferredEvictionEndTime( - request.allocation_value->defining_position().shape(), - eviction_exclusive_start_time, request.end_time), - eviction_end_time); + int64_t preferred_eviction_end_time = std::max( + options_.prefetch_interval_picker->PreferredEvictionEndTime( + request.allocation_value_to_update->defining_position().shape(), + eviction_exclusive_start_time, request.end_time), + eviction_end_time); // Evictions must complete by the time of this use. preferred_eviction_end_time = std::min(preferred_eviction_end_time, request.latest_prefetch_time); MsaBufferInterval eviction_mem_interval; eviction_mem_interval.buffer = request.allocation_value->value(); - eviction_mem_interval.size = request.size; + // When replacing an sync slice, the size of the original allocation_value + // matters instead of the queuing_allocation_value + // TODO(mehrdadk): separate the request size for src and dst + // AllocationSequence + eviction_mem_interval.size = + std::max(request.allocation_value->size(), request.size); // Try to reserve a buffer from the end of the previous allocation to the // preferred eviction end time. eviction_mem_interval.start = eviction_end_time + 1; @@ -4694,12 +5019,13 @@ MsaAlgorithm::Result MsaAlgorithm::Evict(const AllocationRequest& request) { // See if this interval would violate the asynchronous copy limit. if (!eviction_interval_too_short && !eviction_violates_outstanding_copies && !eviction_violates_resource) { - prev_allocation->set_end_time(eviction_end_time); - AddAsyncCopy(*prev_allocation, MemorySpace::kDefault, - /*chunk=*/std::nullopt, eviction_exclusive_start_time, - prev_allocation->end_time(), eviction_end_time, - request.allocation_value->mutable_allocation_sequence(), - /*aliased_offset=*/nullptr, eviction_resource); + prev_allocation->Extend(eviction_end_time); + AddAsyncCopyOrOtherMemOp( + *prev_allocation, MemorySpace::kDefault, + /*chunk=*/std::nullopt, eviction_exclusive_start_time, + prev_allocation->end_time(), eviction_end_time, + request.allocation_value->mutable_allocation_sequence(), + /*aliased_offset=*/nullptr, eviction_resource); } else { if (eviction_violates_outstanding_copies) { VLOG(3) << "This violates the maximum async copies."; @@ -4852,6 +5178,13 @@ MsaAlgorithm::Result MsaAlgorithm::Prefetch( ? options_.while_use_extra_outstanding_prefetch_limit : 0; + // If the request is for a sync mem op conversion to async, we may allow for + // more async copies. + if (context.request->require_copy_allocation) { + context.extra_async_copy_limit += + options_.extend_async_copies_limit_for_sync_mem_op_conversion; + } + // Loop over potential prefetch starting times. At the selected start time, we // check if we have enough resources and memory for a sliced version of the // request and a non-sliced version of the request. We return the first sliced @@ -4901,7 +5234,7 @@ MsaAlgorithm::Result MsaAlgorithm::Prefetch( } // Check if we found any solutions. - if (context.sliced_solution) { + if (context.sliced_solution && !context.request->required_copy_for_slice) { CHECK(!context.sliced_solution->slices_for_pending_chunks.empty()); VLOG(3) << DescribeSlicedBufferMove( context.sliced_solution->slice_decisions_sorted_by_start_time, result_, @@ -4914,12 +5247,15 @@ MsaAlgorithm::Result MsaAlgorithm::Prefetch( } AddAsyncSlicesForPrefetch( *context.prev_allocation_in_default_mem, - context.request->allocation_value->mutable_allocation_sequence(), + context.request->allocation_value_to_update + ->mutable_allocation_sequence(), context.request->preferred_offset, context.sliced_solution->slice_decisions_sorted_by_start_time, - context.prefetch_end_time, context.request->end_time); - context.request->allocation_value->allocation_sequence()->back()->AddUse( - context.request->use->hlo_use); + context.prefetch_end_time, context.request->end_time, + context.request->required_copy_allocation_for); + context.request->allocation_value_to_update->allocation_sequence() + ->back() + ->AddUse(context.request->use->hlo_use); return Result::kSuccess; } if (context.unsliced_solution) { @@ -4942,22 +5278,28 @@ MsaAlgorithm::Result MsaAlgorithm::Prefetch( context.unsliced_solution->chunk_candidate, context.unsliced_solution_intervals.full.start - 1, context.prefetch_end_time, - context.request->allocation_value->mutable_allocation_sequence(), + context.request->allocation_value_to_update + ->mutable_allocation_sequence(), context.request->preferred_offset, context.unsliced_solution->prefetch_resource, *context.request->window_prefetch_options); } else { - AddAsyncCopy( + AddAsyncCopyOrOtherMemOp( *context.prev_allocation_in_default_mem, MemorySpace::kAlternate, context.unsliced_solution->chunk_candidate, context.unsliced_solution_intervals.full.start - 1, context.request->end_time, context.prefetch_end_time, - context.request->allocation_value->mutable_allocation_sequence(), + context.request->allocation_value_to_update + ->mutable_allocation_sequence(), context.request->preferred_offset, - context.unsliced_solution->prefetch_resource); + context.unsliced_solution->prefetch_resource, + /*cross_program_prefetch_index=*/std::nullopt, + context.request->required_copy_allocation_for); + context.prev_allocation_in_default_mem->Extend( + context.request->latest_prefetch_time); } - request.allocation_value->allocation_sequence()->back()->AddUse( + request.allocation_value_to_update->allocation_sequence()->back()->AddUse( request.use->hlo_use); return Result::kSuccess; } @@ -4968,6 +5310,12 @@ MsaAlgorithm::Result MsaAlgorithm::Prefetch( } void MsaAlgorithm::GenerateSliceProposal(PrefetchContext& context) const { + if (context.request->required_copy_for_slice) { + VLOG(5) << "Not slicing " << context.request->use->hlo_use + << " because slicing a slice instruction is not supported yet."; + return; + } + if (options_.sliced_prefetch_options.max_slices() < 2) { return; } @@ -5017,7 +5365,7 @@ void MsaAlgorithm::SetupPrefetchWorkingIntervalsAndSliceProposal( // Setup the full WorkingIntervals for the sliced and unsliced solutions. // Future code will adjust the start and end times. context.sliced_solution_intervals.full = MsaBufferInterval{ - context.request->allocation_value->value(), + context.request->allocation_value_to_update->value(), /*size=*/context.request->size, /*start=*/-1, /*end=*/context.request->end_time, @@ -5534,7 +5882,8 @@ std::vector MsaAlgorithm::FindBestChunkCandidates( // Then find the latest use that can be allocated contiguously without // copies. - const Shape& shape = request.allocation_value->defining_position().shape(); + const Shape& shape = + request.allocation_value_to_update->defining_position().shape(); for (; (use_time_it + 1) != use_times.end() && options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy( diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.h b/third_party/xla/xla/service/memory_space_assignment/algorithm.h index 52d0f0ee563747..756a1a6bf94334 100644 --- a/third_party/xla/xla/service/memory_space_assignment/algorithm.h +++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.h @@ -25,7 +25,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -36,18 +35,17 @@ limitations under the License. #endif #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" -#include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/utils/hlo_live_range.h" -#include "xla/service/buffer_value.h" #include "xla/service/call_graph.h" #include "xla/service/heap_simulator/allocation_block.h" #include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_value.h" #include "xla/service/memory_space_assignment/allocation.h" #include "xla/service/memory_space_assignment/buffer_interval_comparator.h" @@ -140,9 +138,10 @@ class AllocationValue { // All the positions where this use aliases with. The aliased positions // must get the same allocation. std::vector aliases; - // The sync copy instruction that produced the allocation value that this - // use is consuming. - HloInstruction* copy_source = nullptr; + // A synchronous memory operation that feeds this use. + // TODO(mehrdadk): extend this to support multiple sync data movement + // operands. + HloInstruction* sync_mem_op_operand = nullptr; bool operator==(const Use& other) const { return hlo_use == other.hlo_use && time == other.time && @@ -223,6 +222,34 @@ struct AsynchronousCopy { } }; +// Represents a context for allocating a segment of an AllocationValue. +// AllocationValue typically provides enough information to allocate the entire +// live range of the AllocationValue, since all segments update only the +// AllocationSequence belonging to the AllocationValue. However, in cases of +// synchronous memory op conversion (e.g., copy, slice, etc.), we also need +// to modify the AllocationSequence of the AllocationValue produced at the +// synchronous memory op's output. This struct provides a context for allocating +// a segment of an AllocationValue, specifying the uses of the AllocationValue +// that we are processing, the index of the use that we are processing in the +// AllocationValue's uses vector, the index of the AllocationValue in +// Span, whose allocation sequence we will update, and whether +// the use is only processed to extend the lifetime of its operand's allocation, +// and the use will not receive a new allocation. +struct AllocationSegmentContext { + // The uses of the AllocationValue that we are processing. + const std::vector* uses; + // The index of the use that we are processing in the AllocationValue's + // AllocationValue::uses vector. + int use_idx; + // Index of the AllocationValue in allocation_values that is being processed + // in AllocateAllocationValues(), whose allocation sequence we will be + // updated. + int allocation_value_to_update_idx; + // If true, the use is only processed to extend the lifetime of its operand's + // allocation, and the use will not receive a new allocation. + bool only_extend_existing_allocation; +}; + // Compare asynchronous copies such that an earlier start time has the same or // earlier end time and an earlier end time has the same or earlier start time. bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b); @@ -395,22 +422,22 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { // Given an HloValue, returns a group of HloValues that need to be processed // jointly. Normally, HloValues can be processed individually. However, in - // case we are trying to replace synchronous copies, we need to process all - // copy-connected values together. + // case we are trying to replace synchronous copies, we need to jointly + // process all values that are produced or consumed by a synchronous memory + // call instruction. std::vector GenerateJointProcessedValues( const HloValue* entrance_value); - // Given an HloValue, returns a group of HloValues that are connected to it by - // replaceable sync copy candidates that feed into or follow from that value. - std::vector GetJointProcessedValuesForSyncCopyReplacement( - const HloValue* entrance_value) const; - // Updates sorted_sync_copy_replacement_candidates_ with synchronous copy // instructions that connect the given joint processed values, and meet the // conditions in IsReplaceableSyncCopyCandidate(). - void UpdateSyncCopyCandidatesForJointProcessedValues( + void UpdateSyncDataMovementCandidatesForJointProcessedValues( const std::vector& joint_processed_values); + // Returns true if repack_allocation_blocks_ includes an AllocationBlock + // belonging to a converted synchronous memory operations. + bool RepackAllocationsIncludeConvertedSyncMemOp(); + absl::StatusOr> Finish() override; protected: @@ -514,6 +541,17 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { absl::Span all_use_times; // See the comment for require_copy_allocation HloInstruction* required_copy_allocation_for; + // If the required copy in require_copy_allocation is only for a slice of + // the allocation_value + bool required_copy_for_slice; + // The resulting Allocation will be added to the AllocationSequence of + // allocation_value_to_update. We only expect allocation_value_to_update to + // be different from allocation_value in the case of a synchronous memory + // operation conversion to asynchronous, otherwise, they should be the same. + AllocationValue* allocation_value_to_update; + // No new Allocation is needed to be created and we will only extend an + // existing one. + bool only_extend_existing_allocation; // Data structure that contains the options for making window prefetched // allocations. const WindowPrefetchedAllocation::Options* window_prefetch_options = @@ -713,8 +751,9 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { kAllSlicesHaveTheSameStartTime = 128, // There were conflicting preferred offsets. kFailConflictingPreferredOffsets = 256, - // Could not replace the synchronous copy with an asynchronous one - kFailSyncCopyReplacement = 512 + // Could not replace the synchronous data movement instruction (e.g., kCopy, + // kSlice) with an asynchronous one + kFailSyncDataMoveReplacement = 512 }; // Return true if the result belongs to a failure. @@ -760,11 +799,27 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { // for the found loops. void IdentifyAndOptimizeMemoryBoundLoops(); + // Returns true if the instruction meets the preconditions of a replaceable + // synchronous copy or slice instruction. This only checks for necessary + // conditions, and doesn't guarantee a successful replacement. + bool IsAsyncConversionCandidate(const HloInstruction* instruction) const; // Not supported instructions for sync copy replacement: // 1. Layout-changing copies - // 2. Copied value appears in more than one position - // 3. Instruction operand or output has a pre-specified memory space - bool IsReplaceableSyncCopyCandidate(const HloInstruction* instruction) const; + // 2. Instruction operand or output has a pre-specified memory space + bool IsAsyncConversionCopyCandidate(const HloInstruction* instruction) const; + + enum class AsyncConversionResult { + kSuccess = 0, + kFeatureNotEnabled = 1, + kFailedPrecondition = 2, + kFailedValueNotAllowedInAlternateMemory = 4, + kFailedSatisfyingConstraints = 8, + kFailedNotProcessed = 16, + kFailedGaveUp = 32, + }; + + AsyncConversionResult IsAsyncConversionSliceCandidate( + const HloInstruction* instruction) const; // Allocates buffers for instructions that need reserved scoped allocations in // the alternate memory space. @@ -823,12 +878,39 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { // can be placed in alternate memory, considering the restrictions for loops // and conditionals. Also calculates the timing for prefetching, taking into // account instruction schedules, operation type (e.g., sequential vs. - // non-sequential calls), and prior usage patterns. + // non-sequential calls), and prior usage patterns. We add the resulting + // Allocation to the AllocationSequence of allocation_value_to_update. When + // only_extend_existing_allocation is true, no new Allocations will be created + // while processing the resulting AllocationRequest, and we only need to + // extend an existing Allocation's end_time. AllocationRequest CreateAllocationRequest( - AllocationValue& allocation_value, const AllocationValue::Use& use, - const AllocationValue::Use* previous_use, AliasedOffset* preferred_offset, - int64_t definition_time, bool require_no_copy_alternate_mem_allocation, - const std::vector& all_use_times); + AllocationValue& allocation_value, + AllocationValue& allocation_value_to_update, + const AllocationValue::Use& use, const AllocationValue::Use* previous_use, + AliasedOffset* preferred_offset, int64_t definition_time, + bool require_no_copy_alternate_mem_allocation, + const std::vector& all_use_times, + bool only_extend_existing_allocation); + + // Returns true, if the allocation value requires a pinned allocation in the + // alternate memory space. + bool RequiresNoCopyAlternateMemAllocation( + AllocationValue& allocation_value) const; + + // Adds a required assignment in default memory, at the given time, if + // allocation_value's defining position is not allowed in alternate memory. + void AssignDefaultMemIfNotAllowedInAlternateMem( + AllocationValue& allocation_value, int64_t time); + + // Returns all AllocationSegmentContexts needed for a given set of + // AllocationValues that we would like to process jointly. + std::vector GenerateAllocationSegmentContexts( + absl::Span& allocation_values, + absl::flat_hash_map>& + value_indices_by_sync_inst, + int allocation_value_idx) const; + + bool VerifyAllConversionsAreSuccessful(); // Finds allocations for allocation values generated from colocated intervals. // All of the allocation values have a must-alias relationship with each @@ -1036,14 +1118,19 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { AliasedOffset* aliased_offset, float resource, std::optional cross_program_prefetch_index); - // Adds an asynchronous copy to allocations. - void AddAsyncCopy( + // Adds an asynchronous copy or other memory operation (e.g., slice) to + // allocations. We pass sync_mem_op to the CopyAllocation constructor. When + // sync_mem_op is set, instead of an async copy, CopyAllocation::Process() + // will replace sync_mem_op with the async version of sync_mem_op's opcode + // (e.g., slice) and shape. + void AddAsyncCopyOrOtherMemOp( Allocation& prev_allocation, MemorySpace memory_space, std::optional chunk, int64_t exclusive_start_time, int64_t end_time, int64_t copy_done_schedule_before_time, AllocationSequence* allocations, AliasedOffset* aliased_offset, float resource, - std::optional cross_program_prefetch_index = std::nullopt); + std::optional cross_program_prefetch_index = std::nullopt, + HloInstruction* sync_mem_op = nullptr); // For prefetching, adds a SlicedCopyAllocation to allocations. Also updates // asynchronous copy data structures, prefetch_interval_tree_, and aliasing @@ -1052,7 +1139,8 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { const Allocation& prev_allocation, AllocationSequence* allocations, AliasedOffset* aliased_offset, const std::vector& slice_decisions_sorted_by_start_time, - int64_t prefetch_end_time, int64_t allocation_end_time); + int64_t prefetch_end_time, int64_t allocation_end_time, + HloInstruction* sync_mem_op); // For window prefetching, adds a WindowPrefetchedAllocation to allocations. // Also updates asynchronous copy data structures, prefetch_interval_tree_, @@ -1080,28 +1168,44 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { // Clears all pending chunks and asynchronous copies. void ClearPendingChunks(); - bool IsInstructionInPendingCopyReplacements( - const AllocationValue::Use& use) const { - return IsInstructionInPendingCopyReplacements(use.hlo_use.instruction); - } - - bool IsInstructionInPendingCopyReplacements( + // Returns true if we are trying to replace instruction with its async + // version, while processing JointAllocationProposal. + bool IsInstructionPendingReplacements( const HloInstruction* instruction) const; // Colors the colocated intervals in the alternate memory. void ColorColocatedIntervalsToAlternate( const std::vector& colocated_intervals); - // Iterates over joint_processed_values and populates - // joint_processed_allocation_values with allocation values created for the - // joint-processed values, and populates joint_processed_colocated_intervals - // with a vector of colocated buffer intervals, one vector per joint-processed - // value. - void CreateAllocationValuesForJointProcessedIntervals( - const std::vector& joint_processed_values, - std::vector& joint_allocation_values, - std::vector>& - joint_colocated_intervals); + // A proposal for a group of values to be allocated jointly. Proposals are not + // guaranteed to be accepted, and when they fail, the algorithm will try to + // come up with a new proposal on a smaller subset of values. + struct JointAllocationProposal { + // The values that are being jointly processed. + std::vector values; + // The allocation values created for the joint-processed values. + std::vector allocation_values; + // The colocated buffer intervals for the joint-processed values. This is a + // vector of vectors, one vector per joint-processed value, and the + // colocation must be only enforced on intervals belonging to the same + // joint-processed value. + std::vector> colocated_intervals; + }; + + // Iterates over proposal's values and populates its allocation_values and + // colocated_intervals with the appropriate allocation values and colocated + // intervals created for the values. + void CreateAllocationValuesForJointProcessedValues( + JointAllocationProposal& proposal); + + // Returns a JointAllocationProposal with values, allocation + // values, and colocated intervals that are proposed to be processed jointly + // for the given interval. Also, if the interval consumes or produces any + // synchronous memory call instructions (e.g., kCopy, kSlice) and the option + // to replace them with their asynchronous versions is enabled, this method + // will add those instructions to the sorted_async_conversion_candidates_ + // vector. + JointAllocationProposal GetJointProposal(MsaBufferInterval& interval); // Append buffer and allocation infos for debugging and dump it into a file, // if enabled. @@ -1151,6 +1255,13 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { bool IsIntervalPinnedToAlternateMemory( const MsaBufferInterval& interval) const; + // A convenience debugging method that returns true if the prefetch context + // matches the described producer and consumer. + bool MatchesPrefetchContext(const PrefetchContext& context, + absl::string_view producer_name, + ShapeIndex producer_shape_index, + absl::string_view consumer_name) const; + AllocationSequence* allocations_; const Options& options_; const HloAliasAnalysis& alias_analysis_; @@ -1173,12 +1284,11 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { std::vector pending_async_copies_; std::vector> pending_required_assignments_; - // A list of candidate sync copy instructions that we are trying to replace - // with asynchronous copies while processing the current interval, sorted by + // A list of candidate sync instructions that we are trying to replace with + // an asynchronous version, while processing the current interval, sorted by // their order in the instruction schedule. Being in this list doesn't - // guarantee that the copy will be replaced. These sync copies are basically - // the instructions that connects the values that are being jointly processed. - std::vector sorted_sync_copy_replacement_candidates_; + // guarantee that the sync instruction will be converted to async. + std::vector sorted_async_conversion_candidates_; // A cache to keep the peak memory usage at each point in the graph. We use // this to see if the proposed allocation in the alternate memory would fit // ignoring fragmentation, and if not, we can skip the more expensive lookup @@ -1215,7 +1325,7 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { absl::flat_hash_map loop_optimized_allocations_map_; // A map to look the operands of each instruction that are assigned in - // alternate memory. + // alternate memory or are window prefetched. absl::flat_hash_map>> operands_in_alternate_memory_map_; @@ -1228,8 +1338,10 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { absl::flat_hash_set finalized_values_; // Set of sync copy instructions that we failed/succeeded in replacing with // asynchronous copies. - absl::flat_hash_set failed_copy_replacements_set_; - absl::flat_hash_set successful_copy_replacements_set_; + absl::flat_hash_map + failed_async_conversions_; + absl::flat_hash_set successful_async_conversion_set_; + std::vector not_finalized_async_conversions_; // Debug strings. std::string buffer_info_str_; std::string allocation_info_str_; diff --git a/third_party/xla/xla/service/memory_space_assignment/allocation.cc b/third_party/xla/xla/service/memory_space_assignment/allocation.cc index 54edfba58e5276..52cce90dca9a81 100644 --- a/third_party/xla/xla/service/memory_space_assignment/allocation.cc +++ b/third_party/xla/xla/service/memory_space_assignment/allocation.cc @@ -279,7 +279,8 @@ std::string PinnedAllocation::ToString() const { memory_space() == MemorySpace::kDefault ? "def" : "alt"; std::optional chunk = maybe_chunk(); if (chunk) { - absl::StrAppend(&memory_space_str, " (off: ", chunk->offset, ")"); + absl::StrAppend(&memory_space_str, " (off: ", chunk->offset, + ", size: ", chunk->size, ")"); } return absl::StrCat((is_scoped_allocation() ? "Scoped " : ""), "PinnedAllocation in ", memory_space_str, " defined at ", @@ -303,7 +304,8 @@ CopyAllocation::CopyAllocation( std::optional chunk, int64_t copy_start_schedule_after_time, int64_t copy_done_schedule_before_time, int64_t end_time, - std::optional cross_program_prefetch_index) + std::optional cross_program_prefetch_index, + HloInstruction* sync_mem_op) : Allocation( /*defining_position=*/{nullptr, {}}, memory_space, chunk, // Allocation uses an inclusive start time @@ -312,7 +314,8 @@ CopyAllocation::CopyAllocation( /*is_scoped_allocation=*/false, cross_program_prefetch_index), prev_allocation_(prev_allocation), copy_start_schedule_after_(copy_start_schedule_after_time), - copy_done_schedule_before_(copy_done_schedule_before_time) {} + copy_done_schedule_before_(copy_done_schedule_before_time), + sync_mem_op_(sync_mem_op) {} int64_t CopyAllocation::earliest_available_time() const { return copy_done_schedule_before_; @@ -323,11 +326,31 @@ absl::Status CopyAllocation::Process() { Shape shape = defining_position().shape(); HloInstruction* producing_instruction = AddGetTupleElements(); HloComputation* computation = producing_instruction->parent(); - copy_start_ = computation->AddInstruction(HloInstruction::CreateCopyStart( - ShapeUtil::MakeTupleShape({shape, shape, ShapeUtil::MakeShape(U32, {})}), - producing_instruction, cross_program_prefetch_index())); - copy_done_ = computation->AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_)); + if (sync_mem_op_ != nullptr && sync_mem_op_->opcode() != HloOpcode::kCopy) { + TF_ASSIGN_OR_RETURN(copy_done_, + computation->CreateAsyncInstructions( + sync_mem_op_, {ShapeUtil::MakeShape(S32, {})}, + HloInstruction::kMainExecutionThread, false)); + copy_start_ = copy_done_->mutable_operand(0); + // If the shape of the copy start operand is not compatible with the + // shape of the producing instruction, we insert a bitcast to make them + // compatible. + if (!ShapeUtil::CompatibleIgnoringFpPrecision( + producing_instruction->shape(), copy_start_->operand(0)->shape())) { + producing_instruction = + computation->AddInstruction(HloInstruction::CreateBitcast( + copy_start_->operand(0)->shape(), producing_instruction)); + } + TF_RETURN_IF_ERROR( + copy_start_->ReplaceOperandWith(0, producing_instruction)); + } else { + copy_start_ = computation->AddInstruction(HloInstruction::CreateCopyStart( + ShapeUtil::MakeTupleShape( + {shape, shape, ShapeUtil::MakeShape(U32, {})}), + producing_instruction, cross_program_prefetch_index())); + copy_done_ = computation->AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_)); + } VLOG(4) << "Created " << copy_start_->name() << " for copy allocation: " << ToString(); @@ -353,13 +376,15 @@ std::string CopyAllocation::ToString() const { memory_space() == MemorySpace::kDefault ? "def" : "alt"; std::optional chunk = maybe_chunk(); if (chunk) { - absl::StrAppend(&memory_space_str, " (off: ", chunk->offset, ")"); + absl::StrAppend(&memory_space_str, " (off: ", chunk->offset, + ", size: ", chunk->size, ")"); } return absl::StrCat("Copy Allocation in ", memory_space_str, ", start_time:", start_time(), ", end_time:", end_time(), ", copy_start_after_time: ", copy_start_schedule_after(), ", copy_done_before_time: ", copy_done_schedule_before(), - ", uses: ", UsesToString(uses()), ", from ", + ", uses: ", UsesToString(uses()), ", sync_mem_op: ", + sync_mem_op_ ? sync_mem_op_->name() : "none", ", from ", prev_allocation_.ToString()); } @@ -412,7 +437,8 @@ SlicedCopyAllocation::SlicedCopyAllocation( std::vector slice_decisions_sorted_by_exclusive_start_time, int64_t copy_done_schedule_before_time, int64_t end_time, const SlicedPrefetchOptions& sliced_prefetch_options, - absl::FunctionRef get_equivalent_s8_shape_fn) + absl::FunctionRef get_equivalent_s8_shape_fn, + HloInstruction* sync_mem_op) : Allocation( /*defining_position=*/{nullptr, {}}, memory_space, GetSlicedCopyAllocationChunk( @@ -427,7 +453,8 @@ SlicedCopyAllocation::SlicedCopyAllocation( original_shape_to_slice_(prev_allocation.defining_position().shape()), prev_allocation_(prev_allocation), sliced_prefetch_options_(sliced_prefetch_options), - get_equivalent_s8_shape_fn_(get_equivalent_s8_shape_fn) { + get_equivalent_s8_shape_fn_(get_equivalent_s8_shape_fn), + sync_mem_op_(sync_mem_op) { CHECK_GE(slice_decisions_sorted_by_exclusive_start_time.size(), 2); slice_details_sorted_by_exclusive_start_time_.reserve( slice_decisions_sorted_by_exclusive_start_time.size()); @@ -621,7 +648,9 @@ std::string SlicedCopyAllocation::ToString() const { slice_details_sorted_by_start_time().front().copy_start_after_time, ", last_slice_copy_done_before_time: ", slice_details_sorted_by_start_time().back().copy_done_before_time, - ", uses: ", UsesToString(uses()), ", from ", prev_allocation_.ToString()); + ", uses: ", UsesToString(uses()), + ", sync_mem_op: ", sync_mem_op_ ? sync_mem_op_->name() : "none", + ", from ", prev_allocation_.ToString()); } absl::Status SlicedCopyAllocation::CreateBitcastConcat( @@ -1000,4 +1029,73 @@ std::vector GetAllocationSequenceInRawPointers( return allocations_in_raw_pointers; } +namespace { + +struct AllocationSummary { + static void Add(const Allocation& allocation, + std::vector& data) { + if (!allocation.is_sliced_copy_allocation()) { + std::string name = allocation.defining_position().ToString(); + if (allocation.cross_program_prefetch_index().has_value()) { + absl::StrAppend(&name, " (xprogram prefetch)"); + } + data.push_back(AllocationSummary{allocation.chunk(), + allocation.start_time(), + allocation.end_time(), name}); + return; + } + const SlicedCopyAllocation& sliced_copy_allocation = + dynamic_cast(allocation); + for (int i = 0; + i < sliced_copy_allocation.slice_details_sorted_by_start_time().size(); + ++i) { + std::string name = absl::StrCat( + sliced_copy_allocation.defining_position().ToString(), " (slice ", i, + (sliced_copy_allocation.cross_program_prefetch_index().has_value() + ? ", xprogram prefetch" + : ""), + ")"); + const SlicedCopyAllocation::SliceDetail& slice_detail = + sliced_copy_allocation.slice_details_sorted_by_start_time()[i]; + data.push_back(AllocationSummary{ + slice_detail.slice_decision.chunk, + ExclusiveToInclusiveStartTime( + slice_detail.slice_decision.exclusive_start_time), + sliced_copy_allocation.end_time(), name}); + } + } + + HeapSimulator::Chunk chunk; + int64_t start_time_inclusive; + int64_t end_time_inclusive; + std::string name; +}; + +} // namespace + +void AllocationSequenceDebugging::LogAltMemAllocationsAt( + const AllocationSequence& allocations, int64_t time) { + std::vector data_vector; + for (const std::unique_ptr& allocation : allocations) { + if (allocation->start_time() <= time && allocation->end_time() >= time && + allocation->is_in_alternate_mem()) { + AllocationSummary::Add(*allocation, data_vector); + } + } + absl::c_sort(data_vector, + [](const AllocationSummary& a, const AllocationSummary& b) { + if (a.chunk.offset != b.chunk.offset) { + return a.chunk.offset < b.chunk.offset; + } + return a.start_time_inclusive < b.start_time_inclusive; + }); + LOG(INFO) << "Live allocations in alternate mem at instruction time " << time + << " (before MSA alters the graph):"; + for (const AllocationSummary& data : data_vector) { + LOG(INFO) << "Alt mem allocation in chunk " << data.chunk.ToString() + << " during [" << data.start_time_inclusive << "," + << data.end_time_inclusive << "], holding " << data.name; + } +} + } // namespace xla::memory_space_assignment diff --git a/third_party/xla/xla/service/memory_space_assignment/allocation.h b/third_party/xla/xla/service/memory_space_assignment/allocation.h index b576ade9fcb34e..81ac4199c5b86f 100644 --- a/third_party/xla/xla/service/memory_space_assignment/allocation.h +++ b/third_party/xla/xla/service/memory_space_assignment/allocation.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_ALLOCATION_H_ #define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_ALLOCATION_H_ +#include + #include #include #include @@ -25,12 +27,15 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/functional/function_ref.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/heap_simulator/allocation_block.h" #include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo_value.h" @@ -240,7 +245,8 @@ class CopyAllocation final : public Allocation { std::optional chunk, int64_t copy_start_schedule_after_time, int64_t copy_done_schedule_before_time, int64_t end_time, - std::optional cross_program_prefetch_index = std::nullopt); + std::optional cross_program_prefetch_index = std::nullopt, + HloInstruction* sync_mem_op = nullptr); // Overridden methods // @@ -263,6 +269,7 @@ class CopyAllocation final : public Allocation { bool operator==(const Allocation& other) const override; // New non-virtual methods + const HloInstruction* sync_mem_op() const { return sync_mem_op_; } bool operator==(const CopyAllocation& other) const; const Allocation& prev_allocation() { return prev_allocation_; } @@ -286,6 +293,8 @@ class CopyAllocation final : public Allocation { int64_t copy_done_schedule_before_; HloInstruction* copy_start_ = nullptr; HloInstruction* copy_done_ = nullptr; + // The sync data movement instruction that this copy is associated with. + HloInstruction* sync_mem_op_ = nullptr; }; // This class represents an allocation resulting from asynchronous sliced @@ -343,7 +352,8 @@ class SlicedCopyAllocation final : public Allocation { std::vector slice_decisions_sorted_by_exclusive_start_time, int64_t copy_done_schedule_before_time, int64_t end_time, const SlicedPrefetchOptions& sliced_prefetch_options, - absl::FunctionRef get_equivalent_s8_shape_fn); + absl::FunctionRef get_equivalent_s8_shape_fn, + HloInstruction* sync_mem_op = nullptr); // Overridden methods // @@ -368,6 +378,7 @@ class SlicedCopyAllocation final : public Allocation { bool operator==(const Allocation& other) const override; // New non-virtual methods + const HloInstruction* sync_mem_op() const { return sync_mem_op_; } bool operator==(const SlicedCopyAllocation& other) const; std::vector SliceOffsetsSortedByStartTime() const; @@ -396,6 +407,8 @@ class SlicedCopyAllocation final : public Allocation { HloInstruction* concat_ = nullptr; const SlicedPrefetchOptions& sliced_prefetch_options_; absl::FunctionRef get_equivalent_s8_shape_fn_; + // The sync data movement instruction that this copy is associated with. + HloInstruction* sync_mem_op_ = nullptr; }; // This class represents an allocation resulting from asynchronously prefetching @@ -533,6 +546,17 @@ class ParentAllocation final : public Allocation { HloInstruction* calling_instruction_; }; +// A class with some utility functions that are useful in debugging. +struct AllocationSequenceDebugging { + // Developers can call this method to log all the allocations in alternate + // memory, at a given instruction time. + // + // REQUIRED: + // - This method is intended to be called before MSA modifies the HloModule. + static void LogAltMemAllocationsAt(const AllocationSequence& allocations, + int64_t time); +}; + } // namespace xla::memory_space_assignment #endif // XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_ALLOCATION_H_ diff --git a/third_party/xla/xla/service/memory_space_assignment/allocation_test.cc b/third_party/xla/xla/service/memory_space_assignment/allocation_test.cc new file mode 100644 index 00000000000000..bba0f400507f4c --- /dev/null +++ b/third_party/xla/xla/service/memory_space_assignment/allocation_test.cc @@ -0,0 +1,153 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/memory_space_assignment/allocation.h" + +#include +#include + +#include +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/heap_simulator/heap_simulator.h" +#include "xla/service/hlo_value.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" + +namespace xla::memory_space_assignment { +namespace { + +class AllocationTest : public HloTestBase {}; + +TEST_F(AllocationTest, CopyAllocationProcessSimple) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + p0 = f32[2,3]{1,0} parameter(0) + p1 = f32[2,3]{1,0} parameter(1) + p1_negate = f32[2,3]{1,0} negate(p1) + add = f32[2,3]{1,0} add(p0, p1_negate) + ROOT tuple = tuple(add, p0) +} + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + // HloComputation* computation = module->entry_computation(); + HloInstruction* add = FindInstruction(module.get(), "add"); + HloInstruction* p1_negate = FindInstruction(module.get(), "p1_negate"); + + HeapSimulator::Chunk p1_negate_chunk = + HeapSimulator::Chunk::FromOffsetSize(0, 24); + + PinnedAllocation p1_negate_pinned( + HloPosition{p1_negate, {}}, MemorySpace::kAlternate, p1_negate_chunk, + /*start_time=*/0, + /*end_time=*/5, /*is_scoped_allocation=*/false); + CopyAllocation copy_allocation(p1_negate_pinned, MemorySpace::kAlternate, + std::nullopt, + /*copy_start_schedule_after_time=*/2, + /*copy_done_schedule_before_time=*/3, + /*end_time=*/5, std::nullopt, + /*sync_instruction=*/nullptr); + + // Use the correct instruction and operand numbers for the add instruction + copy_allocation.AddUse(HloUse{add, 1}); // Use of p1_negate in add + + TF_ASSERT_OK(copy_allocation.Process()); + + // Check copy_start and copy_done instructions. + HloInstruction* copy_start = copy_allocation.copy_start(); + ASSERT_NE(copy_start, nullptr); + EXPECT_EQ(copy_start->opcode(), HloOpcode::kCopyStart); + EXPECT_EQ(copy_start->operand(0), p1_negate); + + HloInstruction* copy_done = copy_allocation.copy_done(); + ASSERT_NE(copy_done, nullptr); + EXPECT_EQ(copy_done->opcode(), HloOpcode::kCopyDone); + EXPECT_EQ(copy_done->operand(0), copy_start); + + // Check that uses are updated. + EXPECT_EQ(add->operand(1), copy_done); + + // Check defining position + EXPECT_EQ(copy_allocation.defining_position().instruction, copy_done); +} + +TEST_F(AllocationTest, CopyAllocationProcessReplaceSyncSlice) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + p0 = f32[1,3]{1,0} parameter(0) + p1 = f32[2,3]{1,0} parameter(1) + p1_negate = f32[2,3]{1,0} negate(p1) + slice = f32[1,3]{1,0} slice(p1_negate), slice={[0:1], [0:3]} + add = f32[1,3]{1,0} add(p0, slice) + ROOT tuple = tuple(add, p0) +} + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + // HloComputation* computation = module->entry_computation(); + HloInstruction* add = FindInstruction(module.get(), "add"); + HloInstruction* p1_negate = FindInstruction(module.get(), "p1_negate"); + HloInstruction* slice = FindInstruction(module.get(), "slice"); + + HeapSimulator::Chunk p1_negate_chunk = + HeapSimulator::Chunk::FromOffsetSize(0, 24); + + PinnedAllocation p1_negate_pinned( + HloPosition{p1_negate, {}}, MemorySpace::kAlternate, p1_negate_chunk, + /*start_time=*/0, + /*end_time=*/5, /*is_scoped_allocation=*/false); + CopyAllocation copy_allocation(p1_negate_pinned, MemorySpace::kAlternate, + std::nullopt, + /*copy_start_schedule_after_time=*/2, + /*copy_done_schedule_before_time=*/3, + /*end_time=*/5, std::nullopt, + /*sync_instruction=*/slice); + + // Use the correct instruction and operand numbers for the add instruction + copy_allocation.AddUse(HloUse{add, 1}); // Use of p1_negate in add + + TF_ASSERT_OK(copy_allocation.Process()); + + // Check copy_start and copy_done instructions. + HloInstruction* slice_start = copy_allocation.copy_start(); + ASSERT_NE(slice_start, nullptr); + EXPECT_EQ(slice_start->opcode(), HloOpcode::kAsyncStart); + EXPECT_EQ(slice_start->operand(0), p1_negate); + + HloInstruction* slice_done = copy_allocation.copy_done(); + ASSERT_NE(slice_done, nullptr); + EXPECT_EQ(slice_done->opcode(), HloOpcode::kAsyncDone); + EXPECT_EQ(slice_done->operand(0), slice_start); + + // Check the shapes. + EXPECT_EQ(slice_done->shape(), slice->shape()); + + // Check that uses are updated. + EXPECT_EQ(add->operand(1), slice_done); + + // Check defining position + EXPECT_EQ(copy_allocation.defining_position().instruction, slice_done); +} + +} // namespace +} // namespace xla::memory_space_assignment diff --git a/third_party/xla/xla/service/memory_space_assignment/cost_analysis.cc b/third_party/xla/xla/service/memory_space_assignment/cost_analysis.cc index d82ee863f868aa..ac4fabec0991f3 100644 --- a/third_party/xla/xla/service/memory_space_assignment/cost_analysis.cc +++ b/third_party/xla/xla/service/memory_space_assignment/cost_analysis.cc @@ -26,17 +26,17 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/analysis/while_loop_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/call_graph.h" #include "xla/service/heap_simulator/heap_simulator.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_value.h" -#include "xla/service/while_loop_analysis.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/memory_space_assignment/cost_analysis.h b/third_party/xla/xla/service/memory_space_assignment/cost_analysis.h index 72229fcab2d273..42899bd133d503 100644 --- a/third_party/xla/xla/service/memory_space_assignment/cost_analysis.h +++ b/third_party/xla/xla/service/memory_space_assignment/cost_analysis.h @@ -26,11 +26,11 @@ limitations under the License. #include "absl/functional/function_ref.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/call_graph.h" #include "xla/service/heap_simulator/heap_simulator.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_value.h" #include "xla/shape.h" diff --git a/third_party/xla/xla/service/memory_space_assignment/cost_analysis_test.cc b/third_party/xla/xla/service/memory_space_assignment/cost_analysis_test.cc index 06001beef0ef65..be602ea54b2fae 100644 --- a/third_party/xla/xla/service/memory_space_assignment/cost_analysis_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/cost_analysis_test.cc @@ -38,12 +38,6 @@ namespace { using memory_space_assignment::CostAnalysis; using memory_space_assignment::CostAnalysisOptions; -constexpr int64_t kPointerSize = 8; - -int64_t ShapeSize(const Shape& shape) { - return ShapeUtil::ByteSizeOf(shape, kPointerSize); -} - class MemorySpaceAssignmentCostAnalysisTest : public HloTestBase { protected: absl::Status Initialize(const HloModule* module, @@ -53,7 +47,6 @@ class MemorySpaceAssignmentCostAnalysisTest : public HloTestBase { options_.async_copy_bandwidth_bytes_per_second = 32; options_.pipeline_overhead_window_size_mib = pipeline_overhead_window_size_mib; - options.shape_size = ShapeSize; options.set_flops_per_second(8); options.set_bytes_per_second(32); options.set_transcendentals_per_second(16); diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc index 9c4b1a2e8bd39b..82290fe5ed0669 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.cc @@ -36,7 +36,9 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -45,7 +47,6 @@ limitations under the License. #include "xla/service/heap_simulator/allocation_block.h" #include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" #include "xla/service/hlo_value.h" #include "xla/service/memory_space_assignment/allocation.h" @@ -62,6 +63,21 @@ namespace xla { namespace memory_space_assignment { namespace { +struct LoopOptimizerChunkInterval { + int64_t begin_idx_in_loop; + int64_t end_idx_in_loop; + EvenOddChunkPair chunks; + + std::string ToString() const { + CHECK(chunks.HasValues()); + return absl::StrFormat( + "begin_idx_in_loop: %d, end_idx_in_loop: %d, even chunk: %s, odd " + "chunk: %s", + begin_idx_in_loop, end_idx_in_loop, chunks.even_chunk->ToString(), + chunks.odd_chunk->ToString()); + } +}; + std::optional GetInstructionIndex( const HloInstruction* instruction, const absl::flat_hash_map& @@ -137,7 +153,7 @@ void LoopOptimizerBestFitHeap::RemoveEvenOddChunkPair( EvenOddChunkPair& chunks) { CheckAllocationIntervalValid(begin_idx_in_loop, end_idx_in_loop); ShiftAllocationIntervalIfRequired(begin_idx_in_loop, end_idx_in_loop); - auto [even_chunk, odd_chunk] = chunks; + auto& [even_chunk, odd_chunk] = chunks; RemoveEvenChunks(begin_idx_in_loop, end_idx_in_loop, even_chunk); RemoveOddChunks(begin_idx_in_loop, end_idx_in_loop, odd_chunk); } @@ -325,18 +341,17 @@ int64_t LoopOptimizerBestFitHeap::LastMemoryOffsetOccupied() const { } /*static*/ absl::StatusOr> -MemoryBoundLoopOptimizer::Create( - int loop_start, int loop_end, uint64_t alternate_memory_size, - const MemoryBoundLoopOptimizerOptions& options, - const HloLiveRange& hlo_live_range, const HloAliasAnalysis& alias_analysis, - const CostAnalysis& cost_analysis, - const BufferValue::SizeFunction& size_function, - const ReservedScopedMemoryFunction& reserved_scoped_memory_fn) { +MemoryBoundLoopOptimizer::Create(int loop_start, int loop_end, + const HloLiveRange& hlo_live_range, + const HloAliasAnalysis& alias_analysis, + const Options& options) { + CHECK(options.cost_analysis != nullptr); std::unique_ptr optimizer = absl::WrapUnique(new MemoryBoundLoopOptimizer( - loop_start, loop_end, alternate_memory_size, options, hlo_live_range, - alias_analysis, cost_analysis, size_function, - reserved_scoped_memory_fn)); + loop_start, loop_end, options.max_size_in_bytes, + options.memory_bound_loop_optimizer_options, hlo_live_range, + alias_analysis, *options.cost_analysis, options.size_fn, + options.reserved_scoped_memory_fn, options.alignment_in_bytes)); TF_RETURN_IF_ERROR(optimizer->Initialize()); return std::move(optimizer); } @@ -347,7 +362,8 @@ MemoryBoundLoopOptimizer::MemoryBoundLoopOptimizer( const HloLiveRange& hlo_live_range, const HloAliasAnalysis& alias_analysis, const CostAnalysis& cost_analysis, const BufferValue::SizeFunction& size_function, - const ReservedScopedMemoryFunction& reserved_scoped_memory_fn) + const ReservedScopedMemoryFunction& reserved_scoped_memory_fn, + int64_t alignment_in_bytes) : loop_start_(loop_start), loop_end_(loop_end), loop_size_(loop_end - loop_start), @@ -357,13 +373,17 @@ MemoryBoundLoopOptimizer::MemoryBoundLoopOptimizer( alias_analysis_(alias_analysis), cost_analysis_(cost_analysis), size_function_(size_function), - reserved_scoped_memory_fn_(reserved_scoped_memory_fn) {} + reserved_scoped_memory_fn_(reserved_scoped_memory_fn), + heap_(LoopOptimizerBestFitHeap(alternate_memory_size, + /*loop_size=*/loop_end - loop_start, + alignment_in_bytes)) {} absl::Status MemoryBoundLoopOptimizer::Initialize() { const auto& instruction_sequence = hlo_live_range_.flattened_instruction_sequence().instructions(); VLOG(3) << "MemoryBoundLoopOptimizer::Initialize, loop start: " << loop_start_ - << ", loop end: " << loop_end_ << ", loop size: " << loop_size_; + << ", loop end: " << loop_end_ << ", loop size: " << loop_size_ + << ", alternate memory size: " << alternate_memory_size_; const HloComputation* loop_computation = nullptr; // Initialize the remaining memory array with the size of the alternate // memory. Also populate instructions_in_loop_ and @@ -387,11 +407,20 @@ absl::Status MemoryBoundLoopOptimizer::Initialize() { } else { TF_RET_CHECK(loop_computation == loop_inst->parent()); } - remaining_memory_.push_back( - alternate_memory_size_ - + int64_t reserved_memory = reserved_scoped_memory_fn_(loop_inst, /*operands_in_alternate_memory=*/{}, - /*outputs_in_alternate_memory=*/{})); + /*outputs_in_alternate_memory=*/{}); + if (reserved_memory == 0) { + continue; + } + // Chunks for reserved scoped memory should always be found at offset 0. + EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( + i, i, reserved_memory, /*preferred_offsets=*/{0, 0}); + CHECK(chunks.HasValues()); + CHECK(chunks.even_chunk->size == reserved_memory); + VLOG(3) << "Reserved chunk: " << chunks.even_chunk->ToString() + << " loop index: " << i; } // Create a tree set to keep track of all the values that the loop @@ -572,7 +601,7 @@ void MemoryBoundLoopOptimizer::MaybeCreateLoopValue( VLOG(3) << "Savings: " << loop_value.savings; VLOG(3) << "Savings per byte: " << loop_value.savings_per_byte; for (const HloValue* value : buffer.values()) { - VLOG(3) << value->ToString(); + VLOG(6) << value->ToString(); } loop_value.hlo_values = buffer.values(); } else { @@ -809,11 +838,22 @@ std::string MemoryBoundLoopOptimizer::LoopValue::ToString() const { for (const auto& allocation : allocations) { absl::StrAppend(&allocations_str, "\n - ", allocation->ToString()); } + std::string chunk_str; + if (chunks.HasValues()) { + absl::StrAppend(&chunk_str, "\n", + "even chunk: ", chunks.even_chunk->ToString()); + absl::StrAppend(&chunk_str, "\n", + "odd chunk: ", chunks.odd_chunk->ToString()); + absl::StrAppend(&chunk_str, "\n", "alternate memory begin idx in loop: ", + alternate_memory_begin_idx_in_loop.value()); + absl::StrAppend(&chunk_str, "\n", "alternate memory end idx in loop: ", + alternate_memory_end_idx_in_loop.value()); + } return absl::StrCat( "Size: ", size, " savings: ", savings, " savings per byte: ", savings_per_byte, - " allocation type: ", AllocationTypeToString(allocation_type), "\n", - values_str, "\n", allocations_str); + " allocation type: ", AllocationTypeToString(allocation_type), chunk_str, + "\n", values_str, "\n", allocations_str); } bool MemoryBoundLoopOptimizer::LoopValue::IsAllocationTypeSupported() const { @@ -822,6 +862,14 @@ bool MemoryBoundLoopOptimizer::LoopValue::IsAllocationTypeSupported() const { allocation_type == AllocationType::kPrefetch; } +void MemoryBoundLoopOptimizer::LoopValue::SetChunkPairAndInterval( + EvenOddChunkPair chunk_pair, int64_t begin_idx_in_loop, + int64_t end_idx_in_loop) { + chunks = chunk_pair; + alternate_memory_begin_idx_in_loop = begin_idx_in_loop; + alternate_memory_end_idx_in_loop = end_idx_in_loop; +} + void MemoryBoundLoopOptimizer::SortLoopValues() { absl::c_stable_sort(loop_values_, [](const LoopValue& a, const LoopValue& b) { return a.savings_per_byte > b.savings_per_byte; @@ -850,9 +898,13 @@ void MemoryBoundLoopOptimizer::AllocateLoopValues() { VLOG(1) << "Unsupported allocation: " << value.ToString(); } } + VLOG(6) << "Heap after allocating temporaries:\n" + << heap_.MemoryUsageToAsciiArt(); VLOG(3) << "Execution time after allocating temporaries: " << CalculateExecutionTime(); AllocatePrefetches(absl::MakeSpan(prefetch_values)); + VLOG(6) << "Heap after allocating prefetches:\n" + << heap_.MemoryUsageToAsciiArt(); VLOG(3) << "Execution time after allocating prefetches: " << CalculateExecutionTime(); } @@ -897,26 +949,10 @@ void MemoryBoundLoopOptimizer::PostProcess() { value.allocations.back()->AddUse(use); } } + VLOG(3) << "LoopValue: " << value.ToString(); } } -bool MemoryBoundLoopOptimizer::AllocateBetween(int64_t begin_idx, - int64_t end_idx, int64_t size) { - int64_t end_idx_sentinel = end_idx; - if (end_idx < begin_idx) { - end_idx_sentinel += loop_size_; - } - for (int64_t i = begin_idx; i <= end_idx_sentinel; ++i) { - if (remaining_memory_[i % loop_size_] < size) { - return false; - } - } - for (int64_t i = begin_idx; i <= end_idx_sentinel; ++i) { - remaining_memory_[i % loop_size_] -= size; - } - return true; -} - bool MemoryBoundLoopOptimizer::AllocateTemporary(LoopValue& value) { VLOG(3) << "AllocateTemporary: " << value.ToString(); if (value.hlo_values.size() > 1) { @@ -925,37 +961,59 @@ bool MemoryBoundLoopOptimizer::AllocateTemporary(LoopValue& value) { } int64_t definition_idx = value.loop_positions.front().first; int64_t max_use_idx; + int64_t begin_idx_in_loop = definition_idx; + int64_t end_idx_in_loop; if (!value.next_iteration_uses.empty()) { max_use_idx = value.next_iteration_uses.back().first; // If max_use_idx >= definition_idx, then this is a loop carried dependence // and we should not have called this function. CHECK_LT(max_use_idx, definition_idx); + end_idx_in_loop = max_use_idx + loop_size_; } else { max_use_idx = value.loop_uses.back().first; + end_idx_in_loop = max_use_idx; } - bool success = AllocateBetween(definition_idx, max_use_idx, value.size); - if (success) { - VLOG(3) << "Pos: " << value.loop_positions[0].second; - value.allocations.push_back(std::make_unique( - value.loop_positions[0].second, MemorySpace::kAlternate, std::nullopt, - definition_idx, max_use_idx, - /*is_scoped_allocation=*/false)); - AddAllLoopPositionsAndUses(value, /*allocate_next_iteration_uses=*/true); + EvenOddChunkPair chunks = heap_.AllocateSameEvenAndOddBetween( + begin_idx_in_loop, end_idx_in_loop, value.size); + if (!chunks.HasValues()) { + VLOG(3) << "Could not find Allocation for temporary value: " + << value.ToString(); + return false; } - return success; + value.SetChunkPairAndInterval(chunks, begin_idx_in_loop, end_idx_in_loop); + VLOG(3) << "Pos: " << value.loop_positions[0].second; + VLOG(3) << "Allocation found for temporary value: " << value.ToString(); + VLOG(6) << "Heap after allocating temporary value: " + << heap_.MemoryUsageToAsciiArt(); + value.allocations.push_back(std::make_unique( + value.loop_positions[0].second, MemorySpace::kAlternate, std::nullopt, + definition_idx, max_use_idx, + /*is_scoped_allocation=*/false)); + AddAllLoopPositionsAndUses(value, /*allocate_next_iteration_uses=*/true); + return true; } bool MemoryBoundLoopOptimizer::AllocatePinned(LoopValue& value) { - bool success = AllocateBetween(0, loop_size_ - 1, value.size); - if (success) { - CHECK(value.header_position); - value.allocations.push_back(std::make_unique( - *value.header_position, MemorySpace::kAlternate, std::nullopt, 0, - loop_size_, - /*is_scoped_allocation=*/false)); - AddAllLoopPositionsAndUses(value, /*allocate_next_iteration_uses=*/false); + int64_t begin_idx_in_loop = 0; + int64_t end_idx_in_loop = loop_size_ - 1; + EvenOddChunkPair chunks = heap_.AllocateSameEvenAndOddBetween( + begin_idx_in_loop, end_idx_in_loop, value.size); + if (!chunks.HasValues()) { + VLOG(3) << "Could not find Allocation for pinned value: " + << value.ToString(); + return false; } - return success; + value.SetChunkPairAndInterval(chunks, begin_idx_in_loop, end_idx_in_loop); + CHECK(value.header_position); + VLOG(3) << "Allocation found for pinned value: " << value.ToString(); + VLOG(6) << "Heap after allocating pinned value: " + << heap_.MemoryUsageToAsciiArt(); + value.allocations.push_back(std::make_unique( + *value.header_position, MemorySpace::kAlternate, std::nullopt, 0, + loop_size_, + /*is_scoped_allocation=*/false)); + AddAllLoopPositionsAndUses(value, /*allocate_next_iteration_uses=*/false); + return true; } bool MemoryBoundLoopOptimizer::AllocatePrefetches( @@ -1005,8 +1063,6 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetches( << *context.bandwidth_idle_times.rbegin(); } - context.additional_memory_used.resize(loop_size_, 0); - // Allocate prefetches by traversing the loop values in reverse order of // the first uses. for (int value_index : context.value_indices) { @@ -1014,10 +1070,6 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetches( } for (int i = 0; i < loop_size_; ++i) { - remaining_memory_[i] -= context.additional_memory_used[i]; - VLOG(3) << "Additional memory [" << i - << "]: " << context.additional_memory_used[i]; - VLOG(3) << "Remaining memory [" << i << "]: " << remaining_memory_[i]; VLOG(3) << "Remaining bandwidth [" << i << "] : " << context.bandwidth_idle_times[i]; } @@ -1026,7 +1078,7 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetches( bool MemoryBoundLoopOptimizer::AllocatePrefetch( int value_index, AllocatePrefetchesContext& context) { - LoopValue* value = context.values.at(value_index); + LoopValue* value = context.values[value_index]; VLOG(3) << "Allocating value: " << value->ToString(); int first_use_idx = value->loop_uses.front().first; int last_use_idx = value->loop_uses.back().first; @@ -1036,24 +1088,23 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( last_use_idx_sentinel = last_use_idx + loop_size_; CHECK_LT(last_use_idx, first_use_idx); } - bool out_of_memory = false; - for (int i = first_use_idx; i <= last_use_idx_sentinel; ++i) { - int loop_idx = i % loop_size_; - if (context.additional_memory_used[loop_idx] + value->size > - remaining_memory_[loop_idx]) { - VLOG(3) << "Ran out of memory allocating for uses."; - out_of_memory = true; - } - } - if (out_of_memory) { - return false; - } float copy_resource = cost_analysis_.GetAsyncCopyElapsed(value->hlo_values.front()->shape()); VLOG(3) << "First use: " << value->loop_uses.begin()->second << " use idx: " << first_use_idx << " copy resource: " << copy_resource; - std::optional copy_start_time; + const auto& [even_chunk, odd_chunk] = heap_.FindEvenAndOddAllocationBetween( + first_use_idx, last_use_idx_sentinel, value->size); + if (!even_chunk.has_value() || !odd_chunk.has_value()) { + // Not enough memory to even fit the value in the alternate memory for the + // duration of its live range. + VLOG(3) << "Could not find Allocation for prefetch value: " + << value->ToString(); + return false; + } + + std::optional copy_start_loop_idx; + int committed_early_forced_prefetches_count = 0; // The general allocation algorithm for prefetches is to first calculate the // default-memory bandwidth idle times at each point (assuming all prefetches // succeeded). We show this pictorially below. We also show the previous @@ -1160,23 +1211,31 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( float accumulated_copy_resource = 0; std::vector early_forced_prefetch_value_indices; int early_forced_prefetch_value_search_index = 0; - float early_forced_prefetch_additional_memory = 0; - for (int i = first_use_idx - 1; i >= last_use_idx_sentinel - loop_size_; - --i) { - int loop_idx = (i + loop_size_) % loop_size_; + VLOG(6) << "Memory usage before allocating prefetch value: " + << value->ToString() << "\n" + << heap_.MemoryUsageToAsciiArt(); + // NOTE: We can, in practice, run the following loop for loop_size + // iterations(one full loop), till first_use_idx - loop_size, as opposed to + // limiting it till last_use_idx_sentinel - loop_size. This will allow a + // prefetch to use all the idle bandwidth available during one full loop + // iteration. + for (int current_idx = first_use_idx - 1; + current_idx >= last_use_idx_sentinel - loop_size_; --current_idx) { + int loop_idx = (current_idx + loop_size_) % loop_size_; // Check if this prefetch rolls over to the previous iteration, check if any // already-scheduled prefetches would violate the FIFO order, and if so, // "early-force" them to be co-scheduled with this prefetch to maintain the // FIFO order. This of course increases the required memory, so also keep // track of additional memory that would be consumed. - if (i < 0) { + if (current_idx < 0) { for (; context.value_indices[early_forced_prefetch_value_search_index] != value_index; ++early_forced_prefetch_value_search_index) { VLOG(3) << "Searching for early forced: " << early_forced_prefetch_value_search_index; - LoopValue* early_forced_value = context.values.at( - context.value_indices[early_forced_prefetch_value_search_index]); + LoopValue* early_forced_value = + context.values[context.value_indices + [early_forced_prefetch_value_search_index]]; if (early_forced_value->allocations.empty()) { continue; } @@ -1199,31 +1258,85 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( } early_forced_prefetch_value_indices.push_back( early_forced_prefetch_value_search_index); - early_forced_prefetch_additional_memory += early_forced_value->size; - VLOG(3) << "Found early-forced prefetch value: " + VLOG(6) + << "Memory usage before removing prefetch value for early force: " + << early_forced_value->ToString() << "\n" + << heap_.MemoryUsageToAsciiArt(); + // Remove the original chunk from the heap. + heap_.RemoveEvenOddChunkPair( + early_forced_value->alternate_memory_begin_idx_in_loop.value(), + early_forced_value->alternate_memory_end_idx_in_loop.value(), + early_forced_value->chunks); + } + } + + VLOG(3) << "Loop idx:" << loop_idx << " Early force prefetch values: " + << early_forced_prefetch_value_indices.size(); + VLOG(6) << "Memory usage before adding pending chunks: \n" + << heap_.MemoryUsageToAsciiArt(); + std::vector pending_chunk_intervals; + for (int early_forced_prefetch_value_index : + early_forced_prefetch_value_indices) { + LoopValue* early_forced_value = + context + .values[context.value_indices[early_forced_prefetch_value_index]]; + int64_t begin_idx_in_loop = loop_idx; + int64_t end_idx_in_loop = + early_forced_value->alternate_memory_end_idx_in_loop.value(); + EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( + begin_idx_in_loop, end_idx_in_loop, early_forced_value->size); + if (!chunks.HasValues()) { + VLOG(3) << "Could not allocate between " << begin_idx_in_loop << " and " + << end_idx_in_loop << " for early forced value: " << early_forced_value->ToString(); - VLOG(3) << "Early forced prefetch additional memory: " - << early_forced_prefetch_additional_memory; + VLOG(6) << "Memory usage after failed allocation: \n" + << heap_.MemoryUsageToAsciiArt(); + break; } + pending_chunk_intervals.push_back( + {begin_idx_in_loop, end_idx_in_loop, chunks}); + VLOG(3) << "Added pending chunk: " + << pending_chunk_intervals.back().ToString() + << " for value: " << early_forced_value->ToString(); } - // Overlap memory overhead only happens if the copy start overlaps with the - // first use (i.e. fully pipelined), so we'd need to account for 2X the - // buffer at this time. - int64_t overlap_memory_overhead = 0; - if (loop_idx == last_use_idx) { - overlap_memory_overhead = value->size; - VLOG(3) << "Loop idx == last use idx (" << loop_idx - << "), overlap memory overhead = " << overlap_memory_overhead; + if (pending_chunk_intervals.size() == + early_forced_prefetch_value_indices.size()) { + int64_t begin_idx_in_loop = current_idx; + int64_t end_idx_in_loop = last_use_idx_sentinel; + EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( + begin_idx_in_loop, end_idx_in_loop, value->size); + if (chunks.HasValues()) { + pending_chunk_intervals.push_back( + {begin_idx_in_loop, end_idx_in_loop, chunks}); + VLOG(3) << "Added pending chunk: " + << pending_chunk_intervals.back().ToString() + << " for current value: " << value->ToString(); + } else { + VLOG(3) << "Could not allocate between " << begin_idx_in_loop << " and " + << end_idx_in_loop << " for value: " << value->ToString(); + VLOG(6) << "Memory usage after failed allocation: \n" + << heap_.MemoryUsageToAsciiArt(); + } + } + + bool out_of_memory = pending_chunk_intervals.size() < + early_forced_prefetch_value_indices.size() + 1; + + // Remove the pending chunks from the heap. + for (auto& pending_chunk_interval : pending_chunk_intervals) { + VLOG(3) << "Removing pending chunk: " + << pending_chunk_interval.ToString(); + heap_.RemoveEvenOddChunkPair(pending_chunk_interval.begin_idx_in_loop, + pending_chunk_interval.end_idx_in_loop, + pending_chunk_interval.chunks); } - // OOM; give up prefetch. - if (context.additional_memory_used[loop_idx] + value->size + - overlap_memory_overhead + early_forced_prefetch_additional_memory > - remaining_memory_[loop_idx]) { - VLOG(3) << "Ran out of memory. Accumulated copy resource " - << accumulated_copy_resource << " out of " << copy_resource - << " at " << loop_idx; + VLOG(6) << "Memory usage after removing pending chunks: " + << heap_.MemoryUsageToAsciiArt(); + + if (out_of_memory) { + VLOG(3) << "Ran out of memory for value: " << value->ToString(); break; } @@ -1243,16 +1356,20 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( (copy_resource - accumulated_copy_resource)); if (bandwidth_idle_time >= copy_resource - accumulated_copy_resource) { accumulated_copy_resource = copy_resource; - copy_start_time = loop_idx; + copy_start_loop_idx = current_idx; + committed_early_forced_prefetches_count = + early_forced_prefetch_value_indices.size(); VLOG(3) << "Found the complete copy ratio and updated accumulated copy " "resource: " << accumulated_copy_resource; break; - } else if (!copy_start_time && + } else if (!copy_start_loop_idx.has_value() && accumulated_copy_resource + bandwidth_idle_time >= copy_resource * options_.desired_copy_ratio()) { accumulated_copy_resource += bandwidth_idle_time; - copy_start_time = loop_idx; + copy_start_loop_idx = current_idx; + committed_early_forced_prefetches_count = + early_forced_prefetch_value_indices.size(); VLOG(3) << "Found the desired copy ratio and updated accumulated copy " "resource: " << accumulated_copy_resource; @@ -1261,7 +1378,9 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( // Even if desired resource isn't reached, and if the options allow it, // allow a fully pipelined prefetch. accumulated_copy_resource += bandwidth_idle_time; - copy_start_time = loop_idx; + copy_start_loop_idx = current_idx; + committed_early_forced_prefetches_count = + early_forced_prefetch_value_indices.size(); VLOG(3) << "Could not reach the desired copy ratio but scheduling " "fully pipelined prefetch anyway: " << accumulated_copy_resource; @@ -1273,27 +1392,52 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( } } - // Could not find a suitable copy start time. - if (!copy_start_time) { + // Restore original heap state as is for values that are not being early + // forced. This is to ensure that the memory usage is the same as before early + // forcing. If no copy start time was found, all the prefetches will be + // restored to their original state. If a copy start time was found, the + // prefetches that will not be early forced will be restored to their original + // state. + VLOG(6) << "Memory usage before restoring original state: " + << heap_.MemoryUsageToAsciiArt(); + for (int i = committed_early_forced_prefetches_count; + i < early_forced_prefetch_value_indices.size(); ++i) { + int early_forced_prefetch_value_index = + early_forced_prefetch_value_indices[i]; + LoopValue* early_forced_value = + context + .values[context.value_indices[early_forced_prefetch_value_index]]; + // Allocate a chunk in at the same offset as the original prefetch. + EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( + early_forced_value->alternate_memory_begin_idx_in_loop.value(), + early_forced_value->alternate_memory_end_idx_in_loop.value(), + early_forced_value->size, + {early_forced_value->chunks.even_chunk->offset, + early_forced_value->chunks.odd_chunk->offset}); + // The chunk should always be present as we are allocating at the same + // offset. + CHECK(chunks.HasValues()); + CHECK_EQ(chunks.even_chunk->offset, + early_forced_value->chunks.even_chunk->offset); + CHECK_EQ(chunks.odd_chunk->offset, + early_forced_value->chunks.odd_chunk->offset); + } + VLOG(6) << "Memory usage after restoring original state: " + << heap_.MemoryUsageToAsciiArt(); + + if (!copy_start_loop_idx.has_value()) { + VLOG(3) << "Could not find a suitable copy start time for value: " + << value->ToString(); return false; } - VLOG(3) << "Success: copy_start_time: " << *copy_start_time + VLOG(3) << "Success: copy_start_loop_idx: " << copy_start_loop_idx.value() << " leftover copy resource: " << (copy_resource - accumulated_copy_resource); - auto update_additional_memory_used = [&](int loop_idx, int64_t addition) { - VLOG(4) << "Updating additional memory used at " << loop_idx << ". " - << context.additional_memory_used[loop_idx] << " + " << addition - << " => " << (context.additional_memory_used[loop_idx] + addition) - << " (remaining: " << remaining_memory_[loop_idx] << ")"; - context.additional_memory_used[loop_idx] += addition; - CHECK_LE(context.additional_memory_used[loop_idx], - remaining_memory_[loop_idx]); - }; - for (int i = first_use_idx; i <= last_use_idx_sentinel; ++i) { - int loop_idx = i % loop_size_; - update_additional_memory_used(loop_idx, value->size); - } + // We are early forcing the prefetches of the previous iteration. This is the + // corresponding copy start index in the previous iteration. + int early_prefetch_copy_start_loop_idx = + (copy_start_loop_idx.value() + loop_size_) % loop_size_; // We reset accumulated copy resource and then reuse it to accumulate copy // resource time in order to replay the previous for loop. It is important // that we use the same arithmetic operations (as opposed to subtracting from @@ -1303,58 +1447,79 @@ bool MemoryBoundLoopOptimizer::AllocatePrefetch( --i) { int loop_idx = (i + loop_size_) % loop_size_; float& bandwidth_idle_time = context.bandwidth_idle_times[loop_idx]; - // Overlap memory overhead only happens if the copy start overlaps with the - // first use (i.e. fully pipelined), so we'd need to account for 2X the - // buffer at this time. - int64_t overlap_memory_overhead = 0; - update_additional_memory_used(loop_idx, - value->size + overlap_memory_overhead); if (bandwidth_idle_time < copy_resource - accumulated_copy_resource) { accumulated_copy_resource += bandwidth_idle_time; bandwidth_idle_time = 0; - if (loop_idx == *copy_start_time) { + if (loop_idx == early_prefetch_copy_start_loop_idx) { VLOG(3) << "Remaining copy resource: " << (copy_resource - accumulated_copy_resource); break; } } else { bandwidth_idle_time -= copy_resource - accumulated_copy_resource; - CHECK_EQ(loop_idx, *copy_start_time); + CHECK_EQ(loop_idx, early_prefetch_copy_start_loop_idx); break; } } - // Create the Allocation objects that correspond to the scheduled prefetch. - CHECK(value->header_position); - value->allocations.push_back(std::make_unique( - *value->header_position, MemorySpace::kDefault, std::nullopt, 0, - loop_size_, /*is_scoped_allocation=*/false)); - value->allocations.push_back(std::make_unique( - *value->allocations.back(), MemorySpace::kAlternate, std::nullopt, - ((*copy_start_time - 1) + loop_size_) % loop_size_, first_use_idx, - last_use_idx_sentinel)); - AddAllLoopPositionsAndUses(*value, /*allocate_next_iteration_uses=*/true); - // Account for the additional memory used by early forcing the already // scheduled prefetches. Also modify the start times of these to this // prefetch's copy start time. - for (int early_forced_prefetch_value_index : - early_forced_prefetch_value_indices) { - LoopValue* early_forced_value = context.values.at( - context.value_indices[early_forced_prefetch_value_index]); + // Allocate the force-early prefetches first, and allocate them in the same + // order as we did to check for out-of-memory, so we can reproduce the same + // allocation pattern. + // TODO(subhankarshah): Instead of depending on the order of allocation, store + // the offsets of the early forced prefetches and use that to allocate them. + for (int i = 0; i < committed_early_forced_prefetches_count; ++i) { + int early_forced_prefetch_value_index = + early_forced_prefetch_value_indices[i]; + LoopValue* early_forced_value = + context + .values[context.value_indices[early_forced_prefetch_value_index]]; CHECK(!early_forced_value->allocations.empty()); CopyAllocation* early_forced_prefetch = static_cast( early_forced_value->allocations.back().get()); - for (int index = early_forced_prefetch->copy_start_schedule_after(); - index >= *copy_start_time; --index) { - update_additional_memory_used(index, early_forced_value->size); - VLOG(3) << "Additional memory used: " << index << " " - << context.additional_memory_used[index]; - } + int64_t begin_idx_in_loop = early_prefetch_copy_start_loop_idx; + int64_t end_idx_in_loop = + early_forced_value->alternate_memory_end_idx_in_loop.value(); + EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( + begin_idx_in_loop, end_idx_in_loop, early_forced_value->size); + // The chunk should always be present as we reproducing the same allocation + // pattern as the out-of-memory check. + CHECK(chunks.HasValues()); + CHECK_LT(begin_idx_in_loop, + early_forced_value->alternate_memory_begin_idx_in_loop.value()); + early_forced_value->SetChunkPairAndInterval(chunks, begin_idx_in_loop, + end_idx_in_loop); early_forced_prefetch->set_copy_start_schedule_after( - ((*copy_start_time - 1) + loop_size_) % loop_size_); - VLOG(3) << "Updated prefetch: " << early_forced_prefetch->ToString(); + ((early_prefetch_copy_start_loop_idx - 1) + loop_size_) % loop_size_); + VLOG(3) << "Early forced prefetch: " << early_forced_value->ToString(); + VLOG(6) << "Memory usage after allocating early forced prefetch: " + << heap_.MemoryUsageToAsciiArt(); } + + // Create the Allocation objects that correspond to the scheduled prefetch. + CHECK(value->header_position); + value->allocations.push_back(std::make_unique( + *value->header_position, MemorySpace::kDefault, std::nullopt, 0, + loop_size_, /*is_scoped_allocation=*/false)); + int64_t begin_idx_in_loop = copy_start_loop_idx.value(); + int64_t end_idx_in_loop = last_use_idx_sentinel; + // The chunk should always be present as we reproducing the same allocation + // pattern as the out-of-memory check. + EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( + begin_idx_in_loop, end_idx_in_loop, value->size); + CHECK(chunks.HasValues()); + value->SetChunkPairAndInterval(chunks, begin_idx_in_loop, end_idx_in_loop); + value->allocations.push_back(std::make_unique( + *value->allocations.back(), MemorySpace::kAlternate, std::nullopt, + ((early_prefetch_copy_start_loop_idx - 1) + loop_size_) % loop_size_, + first_use_idx, last_use_idx_sentinel)); + VLOG(3) << "Allocation found for prefetch: " << value->ToString(); + VLOG(6) << "Memory usage after allocating prefetch: " << value->ToString() + << "\n" + << heap_.MemoryUsageToAsciiArt(); + AddAllLoopPositionsAndUses(*value, /*allocate_next_iteration_uses=*/true); return true; } diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.h b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.h index 5af196b4323af7..3db5b5d5ed062b 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.h +++ b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_MEMORY_BOUND_LOOP_OPTIMIZER_H_ #define XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_MEMORY_BOUND_LOOP_OPTIMIZER_H_ -#include #include #include #include @@ -26,16 +25,17 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/buffer_value.h" #include "xla/service/heap_simulator/allocation_block.h" #include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" #include "xla/service/hlo_value.h" #include "xla/service/memory_space_assignment/allocation.h" @@ -49,8 +49,14 @@ namespace xla { namespace memory_space_assignment { // Pair of chunks for even and odd loop iterations. -using EvenOddChunkPair = std::pair, - std::optional>; +struct EvenOddChunkPair { + std::optional even_chunk; + std::optional odd_chunk; + + bool HasValues() const { + return even_chunk.has_value() && odd_chunk.has_value(); + } +}; // LoopOptimizerBestFitHeap extends GlobalDecreasingSizeBestFitHeap to track // allocated buffers and their live intervals for the MemoryBoundLoopOptimizer. @@ -134,10 +140,10 @@ class LoopOptimizerBestFitHeap private: // REQUIRES: - // - begin_idx_in_loop <= end_idx_in_loop - // - begin_idx_in_loop is within [-loop_size loop_size) - // - end_idx_in_loop is within [0, 2 * loop_size) - // - end_idx_in_loop - begin_idx_in_loop + 1 <= 2 * loop_size (allocation + // * begin_idx_in_loop <= end_idx_in_loop + // * begin_idx_in_loop is within [-loop_size loop_size) + // * end_idx_in_loop is within [0, 2 * loop_size) + // * end_idx_in_loop - begin_idx_in_loop + 1 <= 2 * loop_size (allocation // colocated in even (or odd) iterations cannot span more than 2 loop // iterations) void CheckAllocationIntervalValid(int64_t begin_idx_in_loop, @@ -254,6 +260,7 @@ class MemoryBoundLoopOptimizer { // We represent each tensor used in the current iteration as a LoopValue, // wrapping the relevant information such as its HLO value, indices and // pointers to its use and position sites in different iterations. + // TODO(b/364621066): Make LoopValue a class. struct LoopValue { // An enum that encodes the allocation type that is suitable for this // LoopValue. See the comment above on what each of these mean. @@ -273,6 +280,12 @@ class MemoryBoundLoopOptimizer { // of a loop value. bool IsAllocationTypeSupported() const; + // Sets the data members `chunks`, `alternate_memory_begin_idx_in_loop`, and + // `alternate_memory_end_idx_in_loop`. + void SetChunkPairAndInterval(EvenOddChunkPair chunk_pair, + int64_t begin_idx_in_loop, + int64_t end_idx_in_loop); + // The HloValues that correspond to this LoopValue. std::vector hlo_values; // The position in the header, if any. @@ -299,17 +312,25 @@ class MemoryBoundLoopOptimizer { float savings_per_byte; // The optimized AllocationSequence. AllocationSequence allocations; + // Chunks for even and odd iterations. If a loop value is double buffered + // then it must have different chunks for even and odd iterations. + EvenOddChunkPair chunks; + // Begin index of loop value in alternate memory. + // REQUIRES: + // * (-loop_size) <= alternate_memory_begin_idx_in_loop + // * alternate_memory_begin_idx_in_loop < loop_size + std::optional alternate_memory_begin_idx_in_loop = std::nullopt; + // End index of loop value in alternate memory. + // REQUIRES: + // * 0 <= alternate_memory_end_idx_in_loop + // * alternate_memory_end_idx_in_loop < 2*loop_size + std::optional alternate_memory_end_idx_in_loop = std::nullopt; }; // Factory method to create and initialize a MemoryBoundLoopOptimizer. static absl::StatusOr> Create( - int loop_start, int loop_end, uint64_t alternate_memory_size, - const MemoryBoundLoopOptimizerOptions& options, - const HloLiveRange& hlo_live_range, - const HloAliasAnalysis& alias_analysis_, - const CostAnalysis& cost_analysis, - const BufferValue::SizeFunction& size_function, - const ReservedScopedMemoryFunction& reserved_scoped_memory_fn); + int loop_start, int loop_end, const HloLiveRange& hlo_live_range, + const HloAliasAnalysis& alias_analysis, const Options& options); // Optimize the loop. Initialize must be called first. void Optimize(); @@ -324,13 +345,16 @@ class MemoryBoundLoopOptimizer { // Return the remaining memory vector for each point in time in the loop using // the allocation decisions so far. - const std::vector& remaining_memory() const { - return remaining_memory_; + std::vector RemainingMemory() const { + return heap_.RemainingMemoryByTime(); } int64_t MaxAlternateMemoryUsed() const { - return alternate_memory_size_ - *std::min_element(remaining_memory_.begin(), - remaining_memory_.end()); + return heap_.LastMemoryOffsetOccupied(); + } + + std::string MemoryUsageToAsciiArt() const { + return heap_.MemoryUsageToAsciiArt(); } // The loop start, end, and size accessors. @@ -344,15 +368,12 @@ class MemoryBoundLoopOptimizer { // The values that are requested to be prefetched. absl::Span values; - // A list of indices into values array, sorted by the start time of the - // first use. + // A list of indices into values array, sorted by the (descending) start + // time of the first use. std::vector value_indices; // Default memory remaining bandwidths assuming all prefetches succeeded. std::vector bandwidth_idle_times; - - // Additional memory used while performing prefetching. - std::vector additional_memory_used; }; MemoryBoundLoopOptimizer( @@ -362,7 +383,8 @@ class MemoryBoundLoopOptimizer { const HloAliasAnalysis& alias_analysis_, const CostAnalysis& cost_analysis, const BufferValue::SizeFunction& size_function, - const ReservedScopedMemoryFunction& reserved_scoped_memory_fn); + const ReservedScopedMemoryFunction& reserved_scoped_memory_fn, + int64_t alignment_in_bytes); // Initializes the data structures used by the optimizer. absl::Status Initialize(); @@ -384,9 +406,6 @@ class MemoryBoundLoopOptimizer { // Allocate LoopValues by dispatching to the correct Allocate method. void AllocateLoopValues(); - // Allocate and reserve memory between the given indices. - bool AllocateBetween(int64_t begin_idx, int64_t end_idx, int64_t size); - // Perform allocation type kTemporary. Return true if successful. bool AllocateTemporary(LoopValue& value); @@ -440,13 +459,22 @@ class MemoryBoundLoopOptimizer { absl::flat_hash_map instructions_in_next_iteration_; std::vector loop_values_; - std::vector remaining_memory_; absl::flat_hash_map>> uses_in_alternate_mem_; absl::flat_hash_map> positions_in_alternate_mem_; const ReservedScopedMemoryFunction& reserved_scoped_memory_fn_; + + // The heap used to allocate loop values. Since some loop values can be double + // buffered, between successive iterations, they must have different chunks + // for even and odd iterations. We model 4 iterations of the loop to allocate + // the loop values to alternate memory so we can model the buffers that cross + // one or two loop boundaries. The allocations in the 2nd and 3rd iterations + // represent the actual memory view. The 0th and 1st iteration serve to + // account for allocations, whose buffers cross one or two loop boundaries, + // into the 2nd and 3rd iterations. + LoopOptimizerBestFitHeap heap_; }; } // namespace memory_space_assignment diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc index f241269bb6fa77..e6257df3cbc073 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc @@ -34,13 +34,13 @@ limitations under the License. #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "re2/re2.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/buffer_value.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_value.h" #include "xla/service/memory_space_assignment/allocation.h" @@ -68,14 +68,9 @@ namespace { using ::testing::ContainerEq; using ::testing::HasSubstr; -constexpr int64_t kPointerSize = 8; - -int64_t ShapeSize(const Shape& shape) { - return ShapeUtil::ByteSizeOf(shape, kPointerSize); -} int64_t SizeFunction(const BufferValue& value) { - return ShapeSize(value.shape()); + return HloCostAnalysis::DefaultShapeSize(value.shape()); } int64_t ReservedScopedMemoryFn( @@ -97,7 +92,7 @@ class LoopOptimizerBestFitHeapTest : public ::testing::Test { int64_t size) { EvenOddChunkPair chunks = heap_.AllocateSameEvenAndOddBetween( begin_idx_in_loop, end_idx_in_loop, size); - return chunks.first.has_value() && chunks.second.has_value(); + return chunks.HasValues(); } bool CanFindSameEvenAndOddAllocationBetween(int64_t begin_idx_in_loop, @@ -105,7 +100,7 @@ class LoopOptimizerBestFitHeapTest : public ::testing::Test { int64_t size) { EvenOddChunkPair chunks = heap_.FindSameEvenAndOddAllocationBetween( begin_idx_in_loop, end_idx_in_loop, size); - return chunks.first.has_value() && chunks.second.has_value(); + return chunks.HasValues(); } bool IsAllocateEvenAndOddBetweenSuccessful(int64_t begin_idx_in_loop, @@ -113,7 +108,7 @@ class LoopOptimizerBestFitHeapTest : public ::testing::Test { int64_t size) { EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween( begin_idx_in_loop, end_idx_in_loop, size); - return chunks.first.has_value() && chunks.second.has_value(); + return chunks.HasValues(); } bool CanFindEvenAndOddAllocationBetween(int64_t begin_idx_in_loop, @@ -121,7 +116,7 @@ class LoopOptimizerBestFitHeapTest : public ::testing::Test { int64_t size) { EvenOddChunkPair chunks = heap_.FindEvenAndOddAllocationBetween( begin_idx_in_loop, end_idx_in_loop, size); - return chunks.first.has_value() && chunks.second.has_value(); + return chunks.HasValues(); } std::string GetMemoryUsageAsciiArt() { return heap_.MemoryUsageToAsciiArt(); } @@ -193,10 +188,9 @@ TEST_F(LoopOptimizerBestFitHeapTest, TestAllocateEvenAndOddBetween) { TEST_F(LoopOptimizerBestFitHeapTest, TestRemoveChunk) { EvenOddChunkPair chunks = heap_.AllocateEvenAndOddBetween(3, 11, 16); - EXPECT_TRUE(chunks.first.has_value() && chunks.second.has_value()); + EXPECT_TRUE(chunks.HasValues()); EvenOddChunkPair second_chunks = heap_.AllocateEvenAndOddBetween(-3, 8, 16); - EXPECT_TRUE(second_chunks.first.has_value() && - second_chunks.second.has_value()); + EXPECT_TRUE(second_chunks.HasValues()); EXPECT_THAT(heap_.RemainingMemoryByTime(), ContainerEq(std::vector{16, 16, 16, 0, 0, 0})); EXPECT_EQ(heap_.LastMemoryOffsetOccupied(), 64); @@ -285,7 +279,6 @@ class MemoryBoundLoopOptimizerTest : public HloTestBase { cost_analysis_options_.alternate_mem_bandwidth_bytes_per_second = 128; cost_analysis_options_.async_copy_bandwidth_bytes_per_second = 32; cost_analysis_options_.pipeline_overhead_window_size_mib = 1; - options.shape_size = ShapeSize; options.set_flops_per_second(16); options.set_bytes_per_second(32); options.set_transcendentals_per_second(16); @@ -314,12 +307,17 @@ class MemoryBoundLoopOptimizerTest : public HloTestBase { optimizer_options.set_enabled(true); optimizer_options.set_desired_copy_ratio(0.7); optimizer_options.set_allow_unsatisfied_fully_pipelined_prefetch(false); - TF_ASSIGN_OR_RETURN( - optimizer_, - MemoryBoundLoopOptimizer::Create( - loop_start, loop_end, alternate_memory_size, optimizer_options, - *live_range_, *alias_analysis_, *cost_analysis_, SizeFunction, - reserved_scoped_memory_fn)); + Options options; + options.max_size_in_bytes = alternate_memory_size; + options.alignment_in_bytes = 8; + options.alternate_memory_space = kAlternateMemorySpace; + options.cost_analysis = cost_analysis_.get(); + options.size_fn = SizeFunction; + options.reserved_scoped_memory_fn = reserved_scoped_memory_fn; + options.memory_bound_loop_optimizer_options = optimizer_options; + TF_ASSIGN_OR_RETURN(optimizer_, MemoryBoundLoopOptimizer::Create( + loop_start, loop_end, *live_range_, + *alias_analysis_, options)); return optimizer_.get(); } @@ -702,7 +700,10 @@ TEST_F(MemoryBoundLoopOptimizerTest, SimplePrefetch) { )"; int loop_start_idx; MemoryBoundLoopOptimizer* optimizer; - int64_t alternate_memory_size = 64; + // Although alternate_memory_size=64 is minimum memory needed to fit the copy + // of param0 with desired copy ratio. alternate_memory_size=80 memory will + // ensure complete copy of param0 to alternate memory. + int64_t alternate_memory_size = 80; TF_ASSERT_OK_AND_ASSIGN( auto module, ParseAndCreateOptimizer(hlo_loop_str, alternate_memory_size, loop_start_idx, &optimizer)); @@ -736,6 +737,55 @@ TEST_F(MemoryBoundLoopOptimizerTest, SimplePrefetch) { EXPECT_EQ(optimizer->MaxAlternateMemoryUsed(), alternate_memory_size); } +TEST_F(MemoryBoundLoopOptimizerTest, SimplePrefetch2) { + absl::string_view hlo_loop_str = R"( + $op0 = f32[1,4] add(f32[1,4] $prev_op3, f32[1,4] $prev_op4) + $op1 = f32[1,4] add(f32[1,4] $prev_op4, f32[1,4] $op0) + $op2 = f32[1,4] add(f32[1,4] $op0, f32[1,4] $op1) + $op3 = f32[1,4] add(f32[1,4] $op1, f32[1,4] $op2) + $op4 = f32[1,4] add(f32[1,4] $param0, f32[1,4] $op3) + ROOT $root = tuple($op4, $param0) + )"; + int loop_start_idx; + MemoryBoundLoopOptimizer* optimizer; + // alternate_memory_size=64 is minimum memory needed to fit the copy of param0 + // with desired copy ratio. + int64_t alternate_memory_size = 64; + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndCreateOptimizer(hlo_loop_str, alternate_memory_size, + loop_start_idx, &optimizer)); + + optimizer->Optimize(); + absl::flat_hash_set seen_uses; + for (const MemoryBoundLoopOptimizer::LoopValue& loop_value : + optimizer->loop_values()) { + LOG(INFO) << loop_value.ToString(); + if (loop_value.hlo_values.front() + ->defining_position() + .instruction->name() == "param0") { + EXPECT_TRUE(loop_value.allocations.back()->is_copy_allocation()); + } + for (const auto& allocation : loop_value.allocations) { + for (const HloUse& use : allocation->uses()) { + EXPECT_FALSE(seen_uses.contains(use)) << use.ToString(); + seen_uses.insert(use); + } + } + } + + // Ensure all of the uses in the loop have an associated use. + for (absl::string_view inst_name : {"op0", "op1", "op2", "op3", "op4"}) { + HloInstruction* inst = + module->entry_computation()->GetInstructionWithName(inst_name); + EXPECT_TRUE(seen_uses.contains(HloUse{inst, 0})) << inst_name; + EXPECT_TRUE(seen_uses.contains(HloUse{inst, 1})) << inst_name; + } + // Check that execution time has increased to 2 since we will wait on copy + // done for param0. + EXPECT_EQ(optimizer->CalculateExecutionTime(), 2); + EXPECT_EQ(optimizer->MaxAlternateMemoryUsed(), alternate_memory_size); +} + // Specify a ReservedScopedMemoryFunction to the loop optimizer that causes each // HLO to reserve the entire alternate memory. If the loop optimizer is // correctly accounting for reserved scoped memory, it should not put any @@ -773,10 +823,10 @@ TEST_F(MemoryBoundLoopOptimizerTest, ReservedScopedMemory) { // Check that a spurious GetTupleElement instruction in a later iteration of a // loop does not cause MSA to CHECK fail, when identifying loops. Prior to the -// change instroduced with this test, IdentifyAndOptimizeMemoryBoundLoops() +// change introduced with this test, IdentifyAndOptimizeMemoryBoundLoops() // would recognize 4 iterations to the loop thinking that gte is a repeat of // op2. Doing so triggers the CHECKs introduced by the change that added this -// test to fail. So, the point of this test is to verfiy that we do not check +// test to fail. So, the point of this test is to verify that we do not check // fail. TEST_F(MemoryBoundLoopOptimizerTest, GetTupleElement) { absl::string_view hlo_string = R"( @@ -909,7 +959,7 @@ TEST_F(MemoryBoundLoopOptimizerTest, PrefetchFifoOrderWithOverlap) { int loop_start_idx; MemoryBoundLoopOptimizer* optimizer; - int64_t alternate_memory_size = 432; + int64_t alternate_memory_size = 464; TF_ASSERT_OK_AND_ASSIGN( auto module, ParseAndCreateOptimizer(hlo_loop_str, alternate_memory_size, loop_start_idx, &optimizer)); @@ -985,7 +1035,7 @@ TEST_F(MemoryBoundLoopOptimizerTest, PrefetchFifoOrderWithOverlap) { EXPECT_EQ(optimizer->CalculateExecutionTime(), 12.5); // Check the memory used at each point of the loop. - const std::vector& remaining_memory = optimizer->remaining_memory(); + std::vector remaining_memory = optimizer->RemainingMemory(); // Time 0: 3 temporaries (16 B) + param0 (128 B) + param1 (128 B) EXPECT_EQ(remaining_memory.at(0), alternate_memory_size - (3 * 16 + 128 + 128)); @@ -1049,7 +1099,7 @@ TEST_F(MemoryBoundLoopOptimizerTest, PrefetchFifoOrderWithoutOverlap) { int loop_start_idx; MemoryBoundLoopOptimizer* optimizer; - int64_t alternate_memory_size = 192; + int64_t alternate_memory_size = 208; TF_ASSERT_OK_AND_ASSIGN( auto module, ParseAndCreateOptimizer(hlo_loop_str, alternate_memory_size, loop_start_idx, &optimizer)); @@ -1133,7 +1183,7 @@ TEST_F(MemoryBoundLoopOptimizerTest, PrefetchFifoOrderWithOverlap2) { int loop_start_idx; MemoryBoundLoopOptimizer* optimizer; - int64_t alternate_memory_size = 432; + int64_t alternate_memory_size = 464; TF_ASSERT_OK_AND_ASSIGN( auto module, ParseAndCreateOptimizer(hlo_loop_str, alternate_memory_size, loop_start_idx, &optimizer)); @@ -1190,6 +1240,80 @@ TEST_F(MemoryBoundLoopOptimizerTest, PrefetchFifoOrderWithOverlap2) { EXPECT_EQ(optimizer->MaxAlternateMemoryUsed(), alternate_memory_size); } +TEST_F(MemoryBoundLoopOptimizerTest, PrefetchFifoOrderWithoutOverlap2) { + // Same as PrefetchFifoOrderWithoutOverlap, except that, we reduce the size of + // alternate memory, such that only one of param0 and param1 can be + // prefetched. Additionally, we add many more small prefetches, such that, + // during the prefetch of param0 or param1, a valid copy start time is found + // with desired copy ratio, but not with complete copy ratio, early forcing + // param2. After finding a copy start time with the desired copy ratio, when + // trying to find a better copy start time with complete copy ratio, more + // prefetches are temporarily early forced, but restored to their original + // state later. + absl::string_view hlo_loop_str = R"( + $op0 = f32[1,4] add(f32[1,4] $prev_op13, f32[1,4] $prev_op14) + $op1 = f32[8,4] add(f32[8,4] $param0, f32[8,4] $param1) + $op2 = f32[1,4] add(f32[1,4] $op0, f32[1,4] $param9) + $op3 = f32[1,4] add(f32[1,4] $param7, f32[1,4] $param8) + $op4 = f32[1,4] add(f32[1,4] $param5, f32[1,4] $param6) + $op5 = f32[1,4] add(f32[1,4] $param4, f32[1,4] $op3) + $op6 = f32[1,4] add(f32[1,4] $op4, f32[1,4] $param3) + $op7 = f32[1,4] add(f32[1,4] $op5, f32[1,4] $op2) + $op8 = f32[1,4] add(f32[1,4] $op6, f32[1,4] $op7) + $op9 = f32[1,4] add(f32[1,4] $op7, f32[1,4] $op8) + $op10 = f32[1,4] add(f32[1,4] $op8, f32[1,4] $op9) + $op11 = f32[1,4] add(f32[1,4] $op9, f32[1,4] $op10) + $op12 = f32[1,4] add(f32[1,4] $op10, f32[1,4] $op11) + $op13 = f32[1,4] add(f32[1,4] $op11, f32[1,4] $op12) + $op14 = f32[1,4] add(f32[1,4] $param2, f32[1,4] $op13) + )"; + + int loop_start_idx; + MemoryBoundLoopOptimizer* optimizer; + int64_t alternate_memory_size = 384; + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndCreateOptimizer(hlo_loop_str, alternate_memory_size, + loop_start_idx, &optimizer)); + + optimizer->Optimize(); + std::vector prefetches; + for (const MemoryBoundLoopOptimizer::LoopValue& loop_value : + optimizer->loop_values()) { + if (!loop_value.allocations.empty() && + loop_value.allocations.back()->is_copy_allocation()) { + prefetches.push_back(static_cast( + loop_value.allocations.back().get())); + } + } + EXPECT_EQ(prefetches.size(), 9); + for (const CopyAllocation* prefetch : prefetches) { + const HloUse& use = *prefetch->uses().begin(); + if (use.instruction->name() == "op14") { + EXPECT_EQ(prefetch->copy_done_schedule_before(), 14); + EXPECT_EQ(prefetch->copy_start_schedule_after(), 6); + } else if (use.instruction->name() == "op1") { + EXPECT_EQ(prefetch->copy_start_schedule_after(), 6); + EXPECT_EQ(prefetch->copy_done_schedule_before(), 1); + } + } + + EXPECT_NEAR(optimizer->CalculateExecutionTime(), 16.7083, 1e-3); + const std::vector& remaining_memory = optimizer->RemainingMemory(); + EXPECT_EQ(alternate_memory_size - remaining_memory.at(0), 176); + EXPECT_EQ(alternate_memory_size - remaining_memory.at(1), 208); + EXPECT_EQ(alternate_memory_size - remaining_memory.at(2), 112); + EXPECT_EQ(alternate_memory_size - remaining_memory.at(3), 112); + EXPECT_EQ(alternate_memory_size - remaining_memory.at(4), 112); + EXPECT_EQ(alternate_memory_size - remaining_memory.at(5), 96); + EXPECT_EQ(alternate_memory_size - remaining_memory.at(6), 80); + EXPECT_EQ(alternate_memory_size - remaining_memory.at(7), 208); + for (int i = 8; i < 14; ++i) { + EXPECT_EQ(alternate_memory_size - remaining_memory.at(i), 192); + } + EXPECT_EQ(alternate_memory_size - remaining_memory.at(14), 176); + EXPECT_EQ(optimizer->MaxAlternateMemoryUsed(), alternate_memory_size - 128); +} + TEST_F(MemoryBoundLoopOptimizerTest, OptimizerEndToEnd) { absl::string_view hlo_loop_str = R"( $op0 = f32[1,4] add(f32[1,4] $prev_op13, f32[1,4] $prev_op14) @@ -1302,13 +1426,13 @@ TEST_F(MemoryBoundLoopOptimizerTest, TempAndPinnedAllocations) { } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_str)); - int64_t alternate_memory_size = 64; + int64_t alternate_memory_size = 80; TF_ASSERT_OK_AND_ASSIGN( auto optimizer, CreateOptimizer(19, 24, module.get(), alternate_memory_size)); optimizer->Optimize(); - const std::vector& remaining_memory = optimizer->remaining_memory(); + std::vector remaining_memory = optimizer->RemainingMemory(); // Time 0: 3 temporaries (16 B) + 1 pinned (16 B) EXPECT_EQ(remaining_memory.at(0), alternate_memory_size - (3 * 16 + 16)); // Time 1: 3 temporaries (16 B) + 1 pinned (16 B) @@ -1373,12 +1497,12 @@ TEST_F(MemoryBoundLoopOptimizerTest, NegativeSavingNotPinned) { } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_str)); - int64_t alternate_memory_size = 52; + int64_t alternate_memory_size = 72; TF_ASSERT_OK_AND_ASSIGN( auto optimizer, CreateOptimizer(21, 27, module.get(), alternate_memory_size)); optimizer->Optimize(); - const std::vector& remaining_memory = optimizer->remaining_memory(); + std::vector remaining_memory = optimizer->RemainingMemory(); // We expect that pinned_prev_param0 would not get pinned due to negative // savings: 32(uses) - 28 * 16(size) = -416 Time 0: 3 temporaries (16 B) + 1 // pinned (4 B) diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc index 753a348e14a76f..d6e63ff011473d 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.cc @@ -38,6 +38,8 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -46,9 +48,7 @@ limitations under the License. #include "xla/service/buffer_value.h" #include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" -#include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_value.h" #include "xla/service/memory_space_assignment/algorithm.h" #include "xla/service/memory_space_assignment/allocation.h" @@ -369,9 +369,15 @@ MemorySpaceAssignment::RunMemorySpaceAssignment( TF_RETURN_IF_ERROR(SimplifyGraph()); TF_RETURN_IF_ERROR(FixSchedule()); TF_RETURN_IF_ERROR(ExportAndColorBuffers()); + std::vector alt_mem_bytes_occupied; + // alt_mem_bytes_occupied is used for logging in the RuntimeSimulator below. + // We only populate it in VerifyAndExportHeapSimulatorTrace if the + // RuntimeSimulator is present. + TF_RETURN_IF_ERROR(VerifyAndExportHeapSimulatorTrace( + runtime_simulator.has_value() ? &alt_mem_bytes_occupied : nullptr)); if (runtime_simulator.has_value()) { - float estimated_time = - runtime_simulator->SimulateElapsedTime(module_, allocations_); + float estimated_time = runtime_simulator->SimulateElapsedTime( + module_, allocations_, &alt_mem_bytes_occupied); VLOG(1) << "Estimated elapsed time with async copies (sec): " << estimated_time; } @@ -392,8 +398,6 @@ MemorySpaceAssignment::RunMemorySpaceAssignment( VLOG(1) << "Number of evictions: " << stats.num_evictions << ", in bytes: " << stats.eviction_bytes; - TF_RETURN_IF_ERROR(VerifyAndExportHeapSimulatorTrace()); - return std::move(preset_assignments_); } @@ -1003,7 +1007,8 @@ absl::Status MemorySpaceAssignment::FixSchedule() { return absl::OkStatus(); } -absl::Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() { +absl::Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace( + std::vector* alt_mem_bytes_occupied) { VLOG(1) << "Verifying..."; TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, HloAliasAnalysis::Run(module_)); @@ -1179,6 +1184,12 @@ absl::Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() { int64_t max_memory_usage = 0; int64_t prev_time = 0; int64_t prev_memory_usage = 0; + if (alt_mem_bytes_occupied) { + // Populate alt_mem_bytes_occupied with -1, for each instruction. + alt_mem_bytes_occupied->resize( + hlo_live_range->flattened_instruction_sequence().instructions().size(), + -1); + } for (const auto& event : events) { int64_t time; bool is_free; @@ -1215,8 +1226,23 @@ absl::Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() { } prev_memory_usage = std::max(prev_memory_usage, memory_usage); max_memory_usage = std::max(max_memory_usage, memory_usage); - VLOG(4) << "Memory usage: " << memory_usage << " at time: " << time; + if (alt_mem_bytes_occupied) { + // Update known alt mem usage. + (*alt_mem_bytes_occupied)[time] = memory_usage; + } } + if (alt_mem_bytes_occupied) { + // Replace the -1s in alt_mem_bytes_occupied with the previous alt memory + // usage. + int64_t prev_bytes = 0; + for (int64_t i = 0; i < alt_mem_bytes_occupied->size(); ++i) { + if ((*alt_mem_bytes_occupied)[i] == -1) { + (*alt_mem_bytes_occupied)[i] = prev_bytes; + } + prev_bytes = (*alt_mem_bytes_occupied)[i]; + } + } + VLOG(1) << "Max memory usage ignoring fragmentation: " << max_memory_usage; return absl::OkStatus(); diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h index 35e5854ddfdf6e..09cc44fd1dee65 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.h @@ -178,12 +178,12 @@ Useful logging and error messages #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/buffer_value.h" #include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_value.h" #include "xla/service/memory_space_assignment/allocation.h" #include "xla/service/memory_space_assignment/cost_analysis.h" @@ -298,7 +298,11 @@ class MemorySpaceAssignment { // Verify that the memory space assignment is free of overlapping buffers and // export heap simulator trace to be used by buffer_assignment. - absl::Status VerifyAndExportHeapSimulatorTrace(); + // + // If alt_mem_bytes_occupied is not null, it will be populated with the number + // of bytes occupied in the alternate memory space at each instruction time. + absl::Status VerifyAndExportHeapSimulatorTrace( + std::vector* alt_mem_bytes_occupied = nullptr); protected: // Main driver of the memory space assignment pass. diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc index 7fc30f30102e43..c5ff5d70ef5030 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc @@ -45,11 +45,14 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/comparison_util.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/transforms/simplifiers/instruction_hoister.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/layout_util.h" @@ -57,12 +60,9 @@ limitations under the License. #include "xla/service/buffer_value.h" #include "xla/service/heap_simulator/allocation_block.h" #include "xla/service/heap_simulator/heap_simulator.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" #include "xla/service/hlo_cost_analysis.h" -#include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_value.h" -#include "xla/service/instruction_hoister.h" #include "xla/service/memory_space_assignment/algorithm.h" #include "xla/service/memory_space_assignment/allocation.h" #include "xla/service/memory_space_assignment/buffer_interval_comparator.h" @@ -98,16 +98,13 @@ using ::testing::_; using ::testing::Return; using ::testing::UnorderedElementsAre; -constexpr int64_t kPointerSize = 8; constexpr float kAsyncCopyBandwidth = 100; constexpr float kAlternateMemBandwidth = 1000; constexpr float kBytesPerSecond = 100; constexpr float kFlopsPerSecond = 1000; constexpr float kTranscendentalsPerSecond = 10; -int64_t ShapeSize(const Shape& shape) { - return ShapeUtil::ByteSizeOf(shape, kPointerSize); -} +const auto& ShapeSize = HloCostAnalysis::DefaultShapeSize; int64_t SizeFunction(const BufferValue& value) { return ShapeSize(value.shape()); @@ -153,7 +150,6 @@ class MemorySpaceAssignmentTestBase : public HloTestBase { HloCostAnalysis::Options DefaultHloCostAnalysisOptions() { HloCostAnalysis::Options options; - options.shape_size = ShapeSize; options.set_flops_per_second(kFlopsPerSecond); options.set_bytes_per_second(kBytesPerSecond); options.set_transcendentals_per_second(kTranscendentalsPerSecond); @@ -692,8 +688,19 @@ ENTRY entry { negate5 = f32[2,3]{1,0} negate(negate4) negate6 = f32[2,3]{1,0} negate(negate5) negate7 = f32[2,3]{1,0} negate(negate6) - p0_copy = f32[2,3]{1,0} copy(p0) - ROOT tuple0 = tuple(negate7, p0, p0_copy) + p0_copy0 = f32[2,3]{1,0} copy(p0) + p0_copy1 = f32[2,3]{1,0} copy(p0) + negate8 = f32[2,3]{1,0} negate(negate7) + negate9 = f32[2,3]{1,0} negate(negate8) + negate10 = f32[2,3]{1,0} negate(negate9) + negate11 = f32[2,3]{1,0} negate(negate10) + negate12 = f32[2,3]{1,0} negate(negate11) + constant.1 = f32[] constant(0) + broadcast = f32[2,1] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + dynamic-update-slice.0 = f32[2,3] dynamic-update-slice(p0_copy0, broadcast, constant.3, constant.3) + dynamic-update-slice.1 = f32[2,3] dynamic-update-slice(p0_copy1, broadcast, constant.3, constant.3) + ROOT tuple0 = tuple(negate12, dynamic-update-slice.0, dynamic-update-slice.1) } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -701,10 +708,25 @@ ENTRY entry { Options options = DefaultMemorySpaceOptions(); options.enable_sync_copy_replacement = true; AssignMemorySpace(module.get(), options); - HloInstruction* tuple0 = FindInstruction(module.get(), "tuple0"); - ASSERT_NE(tuple0->operand(1), tuple0->operand(2)); + HloInstruction* p0 = FindInstruction(module.get(), "p0"); + + HloInstruction* dynamic_update_slice_0 = + FindInstruction(module.get(), "dynamic-update-slice.0"); + HloInstruction* dynamic_update_slice_1 = + FindInstruction(module.get(), "dynamic-update-slice.1"); + const HloInstruction* p0_copy0_replacement = + dynamic_update_slice_0->operand(0); + const HloInstruction* p0_copy1_replacement = + dynamic_update_slice_1->operand(0); + EXPECT_THAT(p0_copy0_replacement, + op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, p0)); + EXPECT_THAT(p0_copy1_replacement, + op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, p0)); + ASSERT_NE(p0_copy0_replacement, p0_copy1_replacement); } +// All uses of the sync copy operand that are scheduled before the replaced sync +// copy share the allocation in alternate memory (if any). TEST_F(MemorySpaceAssignmentTest, SyncCopyReplacementOperandHasMultipleUses) { absl::string_view hlo_string = R"( HloModule module, is_scheduled=true @@ -712,17 +734,19 @@ HloModule module, is_scheduled=true ENTRY entry { p0 = f32[2,3]{1,0} parameter(0) p1 = f32[2,3]{1,0} parameter(1) - negate0 = f32[2,3]{1,0} negate(p1) - negate1 = f32[2,3]{1,0} negate(negate0) - negate2 = f32[2,3]{1,0} negate(negate1) - negate3 = f32[2,3]{1,0} negate(negate2) - negate4 = f32[2,3]{1,0} negate(negate3) - negate5 = f32[2,3]{1,0} negate(negate4) - negate6 = f32[2,3]{1,0} negate(negate5) - negate7 = f32[2,3]{1,0} negate(negate6) + negate0 = negate(p1) + negate1 = negate(negate0) + negate2 = negate(negate1) + negate3 = negate(negate2) + negate4 = negate(negate3) + negate5 = negate(negate4) + p0_negate0 = negate(p0) + p0_negate1 = negate(p0) + negate6 = negate(negate5) + negate7 = negate(negate6) p0_copy = f32[2,3]{1,0} copy(p0) - add0 = add(p0_copy, p0) - ROOT tuple = tuple(negate7, add0) + add0 = add(p0_copy, p0_negate0) + ROOT tuple = tuple(negate7, add0, p0_negate1) } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -730,8 +754,421 @@ ENTRY entry { Options options = DefaultMemorySpaceOptions(); options.enable_sync_copy_replacement = true; AssignMemorySpace(module.get(), options); + HloInstruction* p0 = FindInstruction(module.get(), "p0"); HloInstruction* add0 = FindInstruction(module.get(), "add0"); - ASSERT_EQ(add0->operand(0), add0->operand(1)); + HloInstruction* p0_negate0 = FindInstruction(module.get(), "p0_negate0"); + HloInstruction* p0_negate1 = FindInstruction(module.get(), "p0_negate1"); + + EXPECT_THAT(add0->operand(0), + op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, p0)); + EXPECT_THAT(p0_negate0->operand(0), + op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, p0)); + EXPECT_THAT(p0_negate1->operand(0), + op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, p0)); + ASSERT_EQ(p0_negate0->operand(0), p0_negate1->operand(0)); + ASSERT_NE(add0->operand(0), p0_negate0->operand(0)); +} + +TEST_F(MemorySpaceAssignmentTest, SyncSliceReplacementAfterPrefetch) { + absl::string_view hlo_string = R"( +HloModule module, is_scheduled=true + +ENTRY entry { + p0 = f32[10,2,3]{2,1,0} parameter(0) + p1 = f32[10,2,3]{2,1,0} parameter(1) + negate0 = negate(p1) + negate1 = negate(negate0) + negate2 = negate(negate1) + negate3 = negate(negate2) + negate4 = negate(negate3) + negate5 = negate(negate4) + negate6 = negate(negate5) + negate7 = negate(negate6) + slice = f32[1,2,3] slice(p0), slice={[0:1], [0:2], [0:3]} + concat = f32[11,2,3] concatenate(negate7, slice), dimensions={0} + ROOT root = negate(concat) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + Options options = DefaultMemorySpaceOptions(); + options.max_size_in_bytes = 512; + options.enable_sync_copy_replacement = false; + options.enable_sync_slice_replacement = true; + options.is_async_slice_implemented_fn = + [](const HloInstruction* instruction) { return true; }; + AssignMemorySpace(module.get(), options); + HloInstruction* concat = FindInstruction(module.get(), "concat"); + ASSERT_NE(concat, nullptr); + HloInstruction* p0 = FindInstruction(module.get(), "p0"); + ASSERT_NE(p0, nullptr); + EXPECT_THAT(concat->operand(1), op::AsyncDone(op::AsyncStart(p0))); +} + +TEST_F(MemorySpaceAssignmentTest, SyncSliceReplacementIgnoredTrivials) { + absl::string_view hlo_string = R"( +HloModule module, is_scheduled=true + +ENTRY entry { + p0 = f32[10,2,3]{2,1,0} parameter(0) + p1 = f32[10,2,3]{2,1,0} parameter(1) + negate0 = negate(p1) + negate1 = negate(negate0) + negate2 = negate(negate1) + negate3 = negate(negate2) + negate4 = negate(negate3) + negate5 = negate(negate4) + negate6 = negate(negate5) + negate7 = negate(negate6) + slice = f32[1,2,3] slice(p0), slice={[0:1], [0:2], [0:3]} + bitcast0 = f32[1,3,2] bitcast(slice) + bitcast1 = f32[10,3,2] bitcast(negate7) + concat = f32[11,3,2] concatenate(bitcast1, bitcast0), dimensions={0} + ROOT root = negate(concat) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + Options options = DefaultMemorySpaceOptions(); + options.max_size_in_bytes = 512; + options.enable_sync_copy_replacement = false; + options.enable_sync_slice_replacement = true; + options.is_async_slice_implemented_fn = + [](const HloInstruction* instruction) { return true; }; + AssignMemorySpace(module.get(), options); + HloInstruction* concat = FindInstruction(module.get(), "concat"); + ASSERT_NE(concat, nullptr); + HloInstruction* p0 = FindInstruction(module.get(), "p0"); + ASSERT_NE(p0, nullptr); + EXPECT_THAT(concat->operand(1), + op::Bitcast(op::AsyncDone(op::AsyncStart(p0)))); +} + +TEST_F(MemorySpaceAssignmentTest, SyncSliceReplacementAfterEviction) { + absl::string_view hlo_string = R"( +HloModule module, is_scheduled=true + +ENTRY entry { + p0 = f32[8,4,2]{2,1,0} parameter(0) + p1 = f32[4,4,2]{2,1,0} parameter(1) + negate_p0 = negate(p0) + slice0 = f32[1,4,2] slice(negate_p0), slice={[0:1], [0:4], [0:2]} + negate0 = negate(p1) + negate1 = negate(negate0) + negate2 = negate(negate1) + negate3 = negate(negate2) + negate4 = negate(negate3) + negate5 = negate(negate4) + negate6 = negate(negate5) + negate7 = negate(negate6) + negate8 = negate(negate7) + negate9 = negate(negate8) + negate10 = negate(negate9) + slice1 = f32[1,4,2] slice(negate10), slice={[0:1], [0:4], [0:2]} + concat = f32[2,4,2] concatenate(slice0, slice1), dimensions={0} + ROOT root = negate(concat) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + Options options = DefaultMemorySpaceOptions(); + options.max_size_in_bytes = 400; + options.enable_sync_copy_replacement = false; + options.enable_sync_slice_replacement = true; + options.is_async_slice_implemented_fn = + [](const HloInstruction* instruction) { return true; }; + + AssignMemorySpace(module.get(), options); + + HloInstruction* negate_p0 = FindInstruction(module.get(), "negate_p0"); + ASSERT_NE(negate_p0, nullptr); + HloInstruction* concat = FindInstruction(module.get(), "concat"); + ASSERT_NE(concat, nullptr); + EXPECT_THAT(concat->operand(0), + op::AsyncDone(op::AsyncStart(op::AsyncCopy( + kDefaultMemorySpace, kAlternateMemorySpace, negate_p0)))); +} + +TEST_F(MemorySpaceAssignmentTest, SyncSliceReplacementTwoSlices) { + absl::string_view hlo_string = R"( +HloModule module, is_scheduled=true + +ENTRY entry { + p0 = f32[10,2,3]{2,1,0} parameter(0) + p1 = f32[10,2,3]{2,1,0} parameter(1) + negate0 = negate(p1) + negate1 = negate(negate0) + negate2 = negate(negate1) + negate3 = negate(negate2) + negate4 = negate(negate3) + negate5 = negate(negate4) + negate6 = negate(negate5) + negate7 = negate(negate6) + slice.1 = f32[1,2,3] slice(p0), slice={[0:1], [0:2], [0:3]} + slice.2 = f32[1,2,3] slice(p0), slice={[1:2], [0:2], [0:3]} + add = f32[1,2,3] add(slice.1, slice.2) + ROOT concat = f32[11,2,3] concatenate(negate7, add), dimensions={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + Options options = DefaultMemorySpaceOptions(); + options.max_size_in_bytes = 512; + options.enable_sync_copy_replacement = false; + options.enable_sync_slice_replacement = true; + options.is_async_slice_implemented_fn = + [](const HloInstruction* instruction) { return true; }; + AssignMemorySpace(module.get(), options); + HloInstruction* add = FindInstruction(module.get(), "add"); + ASSERT_NE(add, nullptr); + HloInstruction* p0 = FindInstruction(module.get(), "p0"); + ASSERT_NE(p0, nullptr); + ASSERT_NE(add->operand(0), add->operand(1)); + EXPECT_THAT(add->operand(0), op::AsyncDone(op::AsyncStart(p0))); + EXPECT_THAT(add->operand(1), op::AsyncDone(op::AsyncStart(p0))); +} + +TEST_F(MemorySpaceAssignmentTest, SyncSliceReplacementNestedSlices) { + absl::string_view hlo_string = R"( +HloModule module, is_scheduled=true + +ENTRY entry { + p0 = f32[10,2,3]{2,1,0} parameter(0) + p1 = f32[10,2,3]{2,1,0} parameter(1) + negate0 = negate(p1) + negate1 = negate(negate0) + negate2 = negate(negate1) + negate3 = negate(negate2) + negate4 = negate(negate3) + negate5 = negate(negate4) + negate6 = negate(negate5) + negate7 = negate(negate6) + slice0 = f32[9,2,3] slice(p0), slice={[0:9], [0:2], [0:3]} + negate8 = f32[10,2,3] negate(negate7) + negate9 = f32[10,2,3] negate(negate8) + negate10 = f32[10,2,3] negate(negate9) + negate11 = f32[10,2,3] negate(negate10) + negate12 = f32[10,2,3] negate(negate11) + slice1 = f32[1,2,3] slice(slice0), slice={[0:1], [0:2], [0:3]} + concat = f32[11,2,3] concatenate(negate12, slice1), dimensions={0} + ROOT root = negate(concat) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + Options options = DefaultMemorySpaceOptions(); + options.max_size_in_bytes = 300; + options.enable_sync_copy_replacement = false; + options.enable_sync_slice_replacement = true; + options.is_async_slice_implemented_fn = + [](const HloInstruction* instruction) { return true; }; + AssignMemorySpace(module.get(), options); + HloInstruction* p0 = FindInstruction(module.get(), "p0"); + ASSERT_NE(p0, nullptr); + HloInstruction* concat = FindInstruction(module.get(), "concat"); + ASSERT_NE(concat, nullptr); + EXPECT_THAT(concat->operand(1), op::Slice(op::AsyncDone(op::AsyncStart(p0)))); +} + +TEST_F(MemorySpaceAssignmentTest, SyncSliceReplacementOneFails) { + absl::string_view hlo_string = R"( +HloModule module, is_scheduled=true + +ENTRY entry { + p0 = f32[10,2,3]{2,1,0} parameter(0) + p1 = f32[1,2,3]{2,1,0} parameter(1) + negate0 = f32[1,2,3]{2,1,0:S(1)} negate(p1) + negate1 = f32[1,2,3]{2,1,0:S(1)} negate(negate0) + negate2 = f32[1,2,3]{2,1,0:S(1)} negate(negate1) + negate3 = f32[1,2,3]{2,1,0:S(1)} negate(negate2) + negate4 = f32[1,2,3]{2,1,0:S(1)} negate(negate3) + negate5 = f32[1,2,3]{2,1,0:S(1)} negate(negate4) + negate6 = f32[1,2,3]{2,1,0:S(1)} negate(negate5) + negate7 = f32[1,2,3]{2,1,0:S(1)} negate(negate6) + slice.0 = f32[8,2,3] slice(p0), slice={[0:8], [0:2], [0:3]} + slice.1 = f32[1,2,3] slice(p0), slice={[8:9], [0:2], [0:3]} + slice.2 = f32[1,2,3] slice(p0), slice={[9:10], [0:2], [0:3]} + add.0 = f32[1,2,3] add(slice.1, slice.2) + concat.0 = f32[9,2,3] concatenate(slice.0, add.0), dimensions={0} + ROOT concat.1 = f32[10,2,3] concatenate(negate7, concat.0), dimensions={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + Options options = DefaultMemorySpaceOptions(); + options.max_size_in_bytes = 72; + options.enable_sync_copy_replacement = false; + options.enable_sync_slice_replacement = true; + options.is_async_slice_implemented_fn = + [](const HloInstruction* instruction) { return true; }; + AssignMemorySpace(module.get(), options); + HloInstruction* p0 = FindInstruction(module.get(), "p0"); + ASSERT_NE(p0, nullptr); + HloInstruction* add0 = FindInstruction(module.get(), "add.0"); + ASSERT_NE(add0, nullptr); + EXPECT_THAT(add0->operand(0), op::AsyncDone(op::AsyncStart(p0))); + EXPECT_THAT(add0->operand(1), op::AsyncDone(op::AsyncStart(p0))); + HloInstruction* slice0 = FindInstruction(module.get(), "slice.0"); + ASSERT_NE(slice0, nullptr); +} + +// The prefetch logic has to correctly distinguish the output shape of an async +// copy vs an async slice. In this test, a prefetch of p0 would not fit into the +// memory, while prefetching a slice of p0 is feasible. +TEST_F(MemorySpaceAssignmentTest, SyncSliceReplacementTheSlicedOneFits) { + absl::string_view hlo_string = R"( +HloModule module, is_scheduled=true + +ENTRY entry { + p0 = f32[10,2,3]{2,1,0} parameter(0) + p1 = f32[1,2,3]{2,1,0} parameter(1) + negate0 = negate(p1) + negate1 = negate(negate0) + negate2 = negate(negate1) + negate3 = negate(negate2) + negate4 = negate(negate3) + negate5 = negate(negate4) + negate6 = negate(negate5) + negate7 = negate(negate6) + slice = f32[1,2,3] slice(p0), slice={[0:1], [0:2], [0:3]} + concat = f32[2,2,3] concatenate(negate7, slice), dimensions={0} + ROOT root = negate(concat) + } + )"; + + Options options = DefaultMemorySpaceOptions(); + options.is_async_slice_implemented_fn = + [](const HloInstruction* instruction) { return true; }; + options.max_size_in_bytes = 64; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module1, + ParseAndReturnVerifiedModule(hlo_string)); + options.enable_sync_slice_replacement = false; + AssignMemorySpace(module1.get(), options); + HloInstruction* p0 = FindInstruction(module1.get(), "p0"); + ASSERT_NE(p0, nullptr); + HloInstruction* concat = FindInstruction(module1.get(), "concat"); + ASSERT_NE(concat, nullptr); + EXPECT_THAT(concat->operand(1), op::Slice(p0)); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module2, + ParseAndReturnVerifiedModule(hlo_string)); + options.enable_sync_slice_replacement = true; + AssignMemorySpace(module2.get(), options); + p0 = FindInstruction(module2.get(), "p0"); + ASSERT_NE(p0, nullptr); + concat = FindInstruction(module2.get(), "concat"); + ASSERT_NE(concat, nullptr); + EXPECT_THAT(concat->operand(1), op::AsyncDone(op::AsyncStart(p0))); +} + +TEST_F(MemorySpaceAssignmentTest, SyncReplacementMultipleOpTypes) { + absl::string_view hlo_string = R"( +HloModule module, is_scheduled=true + +ENTRY entry { + p0 = f32[10,2,3]{2,1,0} parameter(0) + p1 = f32[10,2,3]{2,1,0} parameter(1) + p0_copy = copy(p0) + negate0 = negate(p1) + negate1 = negate(negate0) + negate2 = negate(negate1) + negate3 = negate(negate2) + negate4 = negate(negate3) + negate5 = negate(negate4) + negate6 = negate(negate5) + negate7 = negate(negate6) + slice = f32[1,2,3] slice(p0_copy), slice={[0:1], [0:2], [0:3]} + concat = f32[11,2,3] concatenate(negate7, slice), dimensions={0} + ROOT root = negate(concat) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + Options options = DefaultMemorySpaceOptions(); + options.max_size_in_bytes = 512; + options.enable_sync_copy_replacement = true; + options.enable_sync_slice_replacement = true; + options.is_async_slice_implemented_fn = + [](const HloInstruction* instruction) { return true; }; + AssignMemorySpace(module.get(), options); + HloInstruction* p0 = FindInstruction(module.get(), "p0"); + ASSERT_NE(p0, nullptr); + HloInstruction* concat = FindInstruction(module.get(), "concat"); + ASSERT_NE(concat, nullptr); + EXPECT_THAT( + concat->operand(1), + op::Slice(op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, p0))); +} + +// This test is for the redundant aliasing bug (b/374902759) introduced between +// different operands of the same instruction while converting sync copy to +// async ones. +TEST_F(MemorySpaceAssignmentTest, SyncReplacementAliasingBug) { + absl::string_view hlo_string = R"( +HloModule module, is_scheduled=true, entry_computation_layout={(f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, pred[])->f32[10,2,3]{2,1,0}} + +%while_body (p0.1: (f32[10,2,3], f32[10,2,3], f32[10,2,3], pred[])) -> (f32[10,2,3], f32[10,2,3], f32[10,2,3], pred[]) { + %p0.1 = (f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, pred[]) parameter(0) + %gte0 = f32[10,2,3]{2,1,0} get-tuple-element((f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, pred[]) %p0.1), index=0 + %gte1 = f32[10,2,3]{2,1,0} get-tuple-element((f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, pred[]) %p0.1), index=1 + %gte2 = f32[10,2,3]{2,1,0} get-tuple-element((f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, pred[]) %p0.1), index=2 + %gte3 = pred[] get-tuple-element((f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, pred[]) %p0.1), index=3 + %neg0 = f32[10,2,3]{2,1,0} negate(f32[10,2,3]{2,1,0} %gte2) + %neg1 = f32[10,2,3]{2,1,0} negate(f32[10,2,3]{2,1,0} %neg0) + %neg2 = f32[10,2,3]{2,1,0} negate(f32[10,2,3]{2,1,0} %neg1) + %neg3 = f32[10,2,3]{2,1,0} negate(f32[10,2,3]{2,1,0} %neg2) + %neg4 = f32[10,2,3]{2,1,0} negate(f32[10,2,3]{2,1,0} %neg3) + %neg5 = f32[10,2,3]{2,1,0} negate(f32[10,2,3]{2,1,0} %neg4) + %neg6 = f32[10,2,3]{2,1,0} negate(f32[10,2,3]{2,1,0} %neg5) + %neg7 = f32[10,2,3]{2,1,0} negate(f32[10,2,3]{2,1,0} %neg6) + ROOT %tuple = (f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, pred[]) tuple(f32[10,2,3]{2,1,0} %gte0, f32[10,2,3]{2,1,0} %gte1, f32[10,2,3]{2,1,0} %neg7, pred[] %gte3) +} + +%while_cond (p0: (f32[10,2,3], f32[10,2,3], f32[10,2,3], pred[])) -> pred[] { + %p0 = (f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, pred[]) parameter(0) + ROOT %gte = pred[] get-tuple-element((f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, pred[]) %p0), index=3 +} + +ENTRY %entry (p0.2: f32[10,2,3], p1: f32[10,2,3], p2: pred[]) -> f32[10,2,3] { + %p0.2 = f32[10,2,3]{2,1,0} parameter(0) + %p1 = f32[10,2,3]{2,1,0} parameter(1) + %p2 = pred[] parameter(2) + p0_copy = f32[10,2,3]{2,1,0} copy(f32[10,2,3]{2,1,0} %p0.2) + %p0_copy_copy = f32[10,2,3]{2,1,0} copy(f32[10,2,3]{2,1,0} p0_copy) + %negate0 = f32[10,2,3]{2,1,0} negate(f32[10,2,3]{2,1,0} %p1) + %negate1 = f32[10,2,3]{2,1,0} negate(f32[10,2,3]{2,1,0} %negate0) + %negate2 = f32[10,2,3]{2,1,0} negate(f32[10,2,3]{2,1,0} %negate1) + %negate3 = f32[10,2,3]{2,1,0} negate(f32[10,2,3]{2,1,0} %negate2) + %negate4 = f32[10,2,3]{2,1,0} negate(f32[10,2,3]{2,1,0} %negate3) + %negate5 = f32[10,2,3]{2,1,0} negate(f32[10,2,3]{2,1,0} %negate4) + %negate6 = f32[10,2,3]{2,1,0} negate(f32[10,2,3]{2,1,0} %negate5) + %negate7 = f32[10,2,3]{2,1,0} negate(f32[10,2,3]{2,1,0} %negate6) + %tuple.3 = (f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, pred[]) tuple(f32[10,2,3]{2,1,0} %negate7, f32[10,2,3]{2,1,0} %p0_copy, f32[10,2,3]{2,1,0} %p0_copy_copy, pred[] %p2) + while.1 = (f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, pred[]) while((f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, pred[]) %tuple.3), condition=%while_cond, body=%while_body + %gte.1 = f32[10,2,3]{2,1,0} get-tuple-element((f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, f32[10,2,3]{2,1,0}, pred[]) while.1), index=2 + ROOT %negate = f32[10,2,3]{2,1,0} negate(f32[10,2,3]{2,1,0} %gte.1) +} + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + Options options = DefaultMemorySpaceOptions(); + options.max_size_in_bytes = 1024; + options.enable_sync_copy_replacement = true; + options.enable_sync_slice_replacement = true; + options.is_async_slice_implemented_fn = + [](const HloInstruction* instruction) { return true; }; + AssignMemorySpace(module.get(), options); + HloInstruction* while_instruction = FindInstruction(module.get(), "while.1"); + ASSERT_NE(while_instruction, nullptr); + const HloInstruction* tuple = while_instruction->operand(0); + HloInstruction* p0_copy = FindInstruction(module.get(), "p0_copy"); + ASSERT_NE(p0_copy, nullptr); + EXPECT_THAT(tuple->operand(1), op::AsyncCopy(kDefaultMemorySpace, + kAlternateMemorySpace, p0_copy)); + EXPECT_THAT(tuple->operand(2), + op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace, + tuple->operand(1))); } TEST_F(MemorySpaceAssignmentTest, AlwaysSpillJitPrefetchTest) { @@ -3131,7 +3568,6 @@ TEST_F(MemorySpaceAssignmentTest, ConditionalMultiUseInWhile) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); AssignMemorySpace(module.get()); - // Make sure copy1/while{0}/cond_tuple{0} gets alternate memory allocation. // This will force an eviction and a prefetch for while body root. auto copy0 = @@ -8348,6 +8784,77 @@ entry { VLOG(2) << "module: " << module->ToString(); } +// This test verifies that window prefetched operands are seen by the +// reserved_scoped_memory_fn. Because window prefetched operands allocates space +// in the alternate memory, which will be identified as prefetched_operands. +// Therefore they will be seen by reserved_scoped_memory_fn. +TEST_F(MemorySpaceAssignmentTest, + WindowPrefetchedOperandsAreSeenByReservedScopedMemoryFn) { + absl::string_view hlo_string = R"( + HloModule module, is_scheduled=true + + fused_computation { + param0 = f32[1024] parameter(0) + param1 = f32[1024] parameter(1) + ROOT root = f32[1024] add(param0, param1) + } + + ENTRY Entry { + param0 = f32[1024] parameter(0) + param1 = f32[1024] parameter(1) + ROOT fusion = f32[1024] fusion(param0, param1), kind=kLoop, calls=fused_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + const HloInstruction* fusion = FindInstruction(module.get(), "fusion"); + bool seen_window_prefetched_operand = false; + + Options options = DefaultMemorySpaceOptions(); + options.max_repacks = 10; + options.repack_after_every_allocation = true; + options.reduce_scoped_memory_limit = true; + options.reserved_scoped_memory_fn = + [&](const HloInstruction* instruction, + const absl::flat_hash_set> + operands_in_alternate_memory, + const absl::flat_hash_set outputs_in_alternate_memory) { + if (instruction == fusion && !operands_in_alternate_memory.empty()) { + seen_window_prefetched_operand = true; + } + return 1; + }; + + // Make sure that the alternate memory is larger than the fusion operand's + // full size, but smaller than its span buffer size, so that it will be window + // prefetched. + options.enable_window_prefetch = true; + ASSERT_LT(options.max_size_in_bytes, 1024); + ASSERT_GT(options.max_size_in_bytes, 32); + // This lambda instructs MSA to allocate 32 bytes in the alternate memory as + // span buffer of the fusion instruction. + options.window_prefetch_detail_fn = + [&](const HloInstruction* instruction) -> WindowPrefetchDetail { + WindowPrefetchDetail detail; + if (instruction == fusion) { + WindowPrefetchDetail::WindowDetail* window = detail.add_windows(); + window->set_operand(0); + window->set_size(32); + } + return detail; + }; + + // Run memory space assignment and verify that window prefetched operands are + // seen by the reserved_scoped_memory_fn. + absl::flat_hash_map, int64_t> repack_map; + FakeMemorySpaceAssignmentRepacker repacker = + FakeMemorySpaceAssignmentRepacker(repack_map, nullptr); + options.repacker = &repacker; + AssignMemorySpace(module.get(), options, /*max_prefetch_interval=*/10, + /*min_prefetch_interval=*/0); + EXPECT_TRUE(seen_window_prefetched_operand); +} + using AsynchronousCopyOrderingTest = ::testing::Test; TEST_F(AsynchronousCopyOrderingTest, Simple) { @@ -9993,7 +10500,8 @@ ENTRY main { // Setup cost analysis so it takes 2 instructions to prefetch anything. HloCostAnalysis::Properties properties; properties[HloCostAnalysis::kBytesAccessedKey] = kBytesPerSecond; - HloCostAnalysis hlo_cost_analysis(ShapeSize, properties); + HloCostAnalysis hlo_cost_analysis(HloCostAnalysis::DefaultShapeSize, + properties); CostAnalysisOptions cost_analysis_options; HloCostAnalysisCosts hlo_cost_analysis_costs(hlo_cost_analysis); TF_ASSERT_OK_AND_ASSIGN( @@ -11066,6 +11574,42 @@ ENTRY main { TF_EXPECT_OK(CheckSliceChunks(*assignments, root->operand(1))); } +TEST_F(SlicedPrefetchTest, TwoSlicesWithCopyReplacement) { + std::string hlo_text = R"zz( +HloModule Slice, is_scheduled=true + +ENTRY main { + p0 = f32[8,8] parameter(0) + p1 = f32[8,8] parameter(1) + + a = f32[8,8] tanh(p0) + b = f32[8,8] tanh(a) + c = f32[8,8] tanh(b) + + p1_copy1 = f32[8,8] copy(p1) + p1_copy2 = f32[8,8] copy(p1) + + r1 = f32[8,8] add(c, p1_copy1) + r2 = f32[8,8] add(c, p1_copy2) + + ROOT r = f32[8,8] add(r1, r2) +})zz"; + Options options = options_; + options.enable_sync_copy_replacement = true; + SetupProposeSlicesToExpect2SlicesOfF32x8x8(); + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text)); + VLOG(1) << "Original module:\n" + << module->ToString(HloPrintOptions::ShortParsable()); + + std::unique_ptr assignments = AssignMemorySpace( + module.get(), options, + /*max_prefetch_interval=*/10, /*min_prefetch_interval=*/1); + + VLOG(1) << "Post-MSA module:\n" + << module->ToString(HloPrintOptions::ShortParsable()); +} + TEST_F(SlicedPrefetchTest, ThreeSlices) { std::string hlo_text = R"zz( HloModule Slice, is_scheduled=true diff --git a/third_party/xla/xla/service/memory_space_assignment/options.h b/third_party/xla/xla/service/memory_space_assignment/options.h index fb9730ced90641..48799ae046cea9 100644 --- a/third_party/xla/xla/service/memory_space_assignment/options.h +++ b/third_party/xla/xla/service/memory_space_assignment/options.h @@ -61,6 +61,8 @@ using WindowPrefetchDetailFunction = std::function; using WindowPrefetchNotifyOperandAppendedFunction = std::function; +using IsAsyncSliceImplementedFunction = + std::function; // The different options to be passed to the Run() API. struct Options { @@ -103,7 +105,9 @@ struct Options { [](const HloPosition&) { return true; }; // This function returns the amount of scoped memory in bytes that should be - // reserved during the execution of this instruction. + // reserved during the execution of this instruction. Note that the + // `operands_in_alternate_memory` also includes the window prefetched + // operands. ReservedScopedMemoryFunction reserved_scoped_memory_fn = [](const HloInstruction*, const absl::flat_hash_set< @@ -124,6 +128,9 @@ struct Options { WindowPrefetchNotifyOperandAppendedFunction notify_operand_appended_fn = [](HloInstruction*, int64_t, int64_t) {}; + IsAsyncSliceImplementedFunction is_async_slice_implemented_fn = + [](const HloInstruction*) { return false; }; + // If true, we will try to reduce scoped allocation buffer size for all // instructions if their operand/output has been allocated in alternate // memory. @@ -210,6 +217,14 @@ struct Options { // ones. If it fails to replace the copy, it keeps the sync version. bool enable_sync_copy_replacement = false; + // If true, tries to replace synchronous slice instructions with asynchronous + // ones. If it fails to replace the slice, it keeps the sync version. + bool enable_sync_slice_replacement = false; + + // If non-zero, this is the number of extra outstanding async copies that we + // allow for each sync mem op that is converted to an async mem op. + int extend_async_copies_limit_for_sync_mem_op_conversion = 0; + // The ratio of use bytes to copy bytes for a given allocation site below // which we consider the site to be inefficient. A value of 0 would treat all // sites as efficient and a value of 1 would require the amount of bytes used diff --git a/third_party/xla/xla/service/memory_space_assignment/prefetch_interval_picker_test.cc b/third_party/xla/xla/service/memory_space_assignment/prefetch_interval_picker_test.cc index 0ea6a0dfb5ba54..16c1e5be46f365 100644 --- a/third_party/xla/xla/service/memory_space_assignment/prefetch_interval_picker_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/prefetch_interval_picker_test.cc @@ -36,12 +36,6 @@ namespace xla { namespace memory_space_assignment { namespace { -constexpr int64_t kPointerSize = 8; - -int64_t ShapeSize(const Shape& shape) { - return ShapeUtil::ByteSizeOf(shape, kPointerSize); -} - using CostAnalysisPrefetchIntervalPickerTest = HloTestBase; TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrder) { @@ -77,7 +71,7 @@ TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrder) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - HloCostAnalysis hlo_cost_analysis(ShapeSize); + HloCostAnalysis hlo_cost_analysis; CostAnalysisOptions options; HloCostAnalysisCosts hlo_cost_analysis_costs(hlo_cost_analysis); TF_ASSERT_OK_AND_ASSIGN( @@ -177,7 +171,7 @@ TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrderWhile) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - HloCostAnalysis hlo_cost_analysis(ShapeSize); + HloCostAnalysis hlo_cost_analysis; CostAnalysisOptions options; HloCostAnalysisCosts hlo_cost_analysis_costs(hlo_cost_analysis); TF_ASSERT_OK_AND_ASSIGN( @@ -261,7 +255,7 @@ TEST_F(CostAnalysisPrefetchIntervalPickerTest, NestedWhile) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - HloCostAnalysis hlo_cost_analysis(ShapeSize); + HloCostAnalysis hlo_cost_analysis; CostAnalysisOptions options; HloCostAnalysisCosts hlo_cost_analysis_costs(hlo_cost_analysis); TF_ASSERT_OK_AND_ASSIGN( @@ -330,7 +324,7 @@ TEST_F(CostAnalysisPrefetchIntervalPickerTest, ConsecutiveConditionals) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - HloCostAnalysis hlo_cost_analysis(ShapeSize); + HloCostAnalysis hlo_cost_analysis; CostAnalysisOptions options; HloCostAnalysisCosts hlo_cost_analysis_costs(hlo_cost_analysis); TF_ASSERT_OK_AND_ASSIGN( @@ -376,7 +370,7 @@ TEST_F(CostAnalysisPrefetchIntervalPickerTest, EarliestLatestWindowTooSmall) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - HloCostAnalysis hlo_cost_analysis(ShapeSize); + HloCostAnalysis hlo_cost_analysis; CostAnalysisOptions options; HloCostAnalysisCosts hlo_cost_analysis_costs(hlo_cost_analysis); TF_ASSERT_OK_AND_ASSIGN( diff --git a/third_party/xla/xla/service/memory_space_assignment/simulator.cc b/third_party/xla/xla/service/memory_space_assignment/simulator.cc index d547c1e65eb998..14ce882f726c88 100644 --- a/third_party/xla/xla/service/memory_space_assignment/simulator.cc +++ b/third_party/xla/xla/service/memory_space_assignment/simulator.cc @@ -20,20 +20,23 @@ limitations under the License. #include #include #include +#include #include +#include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/layout.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/memory_space_assignment/allocation.h" #include "xla/shape_util.h" #include "xla/util.h" @@ -232,7 +235,8 @@ float RuntimeSimulator::SimulateAsyncCopyLikeDone( return elapsed_time; }; -float RuntimeSimulator::SimulateComputeInstruction( +RuntimeSimulator::ElapsedAndIdleTimes +RuntimeSimulator::SimulateComputeInstruction( const HloInstruction* instruction, absl::Span> operands_in_alternate_memory, @@ -245,17 +249,21 @@ float RuntimeSimulator::SimulateComputeInstruction( outputs_in_alternate_memory); // Execute the outstanding async copy likes in the idle time. - ProcessAsyncCopyLikesInIdleTime(default_memory_idle_time); + default_memory_idle_time = + ProcessAsyncCopyLikesInIdleTime(default_memory_idle_time); float inst_elapsed = cost_analysis_->GetInstructionElapsedInAlternateMemory( *instruction, operands_in_alternate_memory, outputs_in_alternate_memory); - return inst_elapsed; + return {inst_elapsed, default_memory_idle_time}; } -void RuntimeSimulator::ProcessAsyncCopyLikesInIdleTime(float time) { +float RuntimeSimulator::ProcessAsyncCopyLikesInIdleTime(float time) { if (time <= 0.0) { - return; + return 0.0; } + + float available_bandwidth = cost_analysis_->base_costs().BytesPerSecond(); + float remaining_simulation_time = time; // This loop simulates the execution of the front memory requests in the // read and/or write queues. The loop terminates when the remaining time is @@ -263,7 +271,6 @@ void RuntimeSimulator::ProcessAsyncCopyLikesInIdleTime(float time) { while ((!outstanding_read_default_queue_.empty() || !outstanding_write_default_queue_.empty()) && remaining_simulation_time > 0.0) { - float available_bandwidth = cost_analysis_->base_costs().BytesPerSecond(); if (!outstanding_read_default_queue_.empty() && !outstanding_write_default_queue_.empty()) { // Need to share the bandwidth @@ -283,15 +290,33 @@ void RuntimeSimulator::ProcessAsyncCopyLikesInIdleTime(float time) { float real_elapsed_time = bytes_to_process / available_bandwidth; remaining_simulation_time -= real_elapsed_time; + if (remaining_simulation_time <= 0.0) { + // This can happen due to floating point errors. + remaining_simulation_time = 0.0; + } + RemoveBytesFromQueueIfNotEmpty(outstanding_read_default_queue_, bytes_to_process); RemoveBytesFromQueueIfNotEmpty(outstanding_write_default_queue_, bytes_to_process); } + + return remaining_simulation_time; +} + +namespace { + +float GetUnusedDefaultMemBandwidthBytes(float bytes_per_second, float seconds) { + CHECK_GE(bytes_per_second, 0.0); + + return bytes_per_second * seconds; } +} // namespace + float RuntimeSimulator::SimulateElapsedTime( - const HloModule* hlo_module, const AllocationSequence& allocations) { + const HloModule* hlo_module, const AllocationSequence& allocations, + const std::vector* alt_mem_bytes_occupied) { InitializeAlternateMemoryMap(allocations); std::unique_ptr alias_analysis = @@ -305,11 +330,21 @@ float RuntimeSimulator::SimulateElapsedTime( CHECK_GT(cost_analysis_->base_costs().BytesPerSecond(), 0.0); float total_elapsed = 0.0; - + // The number of additional bytes that could be transferred between default + // and alternate memory. + float cumulative_available_transfer_bytes = 0.0; + + if (alt_mem_bytes_occupied) { + CHECK_EQ( + alt_mem_bytes_occupied->size(), + hlo_live_range->flattened_instruction_sequence().instructions().size()); + } const auto& instruction_sequence = hlo_live_range->flattened_instruction_sequence().instructions(); - for (const HloInstruction* instruction : instruction_sequence) { + for (int time = 0; time < instruction_sequence.size(); ++time) { + const HloInstruction* instruction = instruction_sequence[time]; float inst_elapsed = 0.0; + float idle_default_memory_bandwidth_time = 0.0; if (instruction->opcode() == HloOpcode::kWhile) { // Since the instructions in the while body are calculated // separately, we can skip the while instruction. @@ -356,10 +391,14 @@ float RuntimeSimulator::SimulateElapsedTime( if (operand_it != operands_in_alternate_memory_map_.end()) operands_in_alternate_memory = absl::MakeSpan(operand_it->second); - inst_elapsed = + ElapsedAndIdleTimes elapsed_and_idle = SimulateComputeInstruction(instruction, operands_in_alternate_memory, outputs_in_alternate_memory); + inst_elapsed = elapsed_and_idle.elapsed_time; + idle_default_memory_bandwidth_time = + elapsed_and_idle.idle_default_memory_bandwidth_time; } + float total_trip_count = 0.0; if (inst_elapsed > 0.0) { // The calculation assumes all instructions are executed independently. // Thus, the execution time is the same for each invocation. This property @@ -368,12 +407,39 @@ float RuntimeSimulator::SimulateElapsedTime( // the loop body. In this case, the first async copy in the first // iteration will be slower than other iterations, since it needs to wait // for the async copies issued before the loop. - float total_trip_count = cost_analysis_->CalculateNestTripCount( + total_trip_count = cost_analysis_->CalculateNestTripCount( instruction, &cost_analysis_cache_); total_elapsed += inst_elapsed * total_trip_count; } + + cumulative_available_transfer_bytes += + (GetUnusedDefaultMemBandwidthBytes( + cost_analysis_->base_costs().BytesPerSecond(), + idle_default_memory_bandwidth_time) * + total_trip_count); + VLOG(2) << [&]() { + std::string instruction_name(instruction->name()); + if (instruction->opcode() == HloOpcode::kCopyStart && + instruction->cross_program_prefetch_index().has_value()) { + absl::StrAppend(&instruction_name, " (xprogram prefetch)"); + } + std::string alt_mem_bytes_occupied_str = ""; + if (alt_mem_bytes_occupied) { + alt_mem_bytes_occupied_str = + absl::StrCat("; alt mem usage: ", alt_mem_bytes_occupied->at(time)); + } + + return absl::StrCat(time, ": instruction: ", instruction_name, + "; elapsed: ", inst_elapsed, + "; cumulative available transfer bytes: ", + cumulative_available_transfer_bytes, + "; trip count: ", total_trip_count, + alt_mem_bytes_occupied_str); + }(); } + return total_elapsed; } + } // namespace memory_space_assignment } // namespace xla diff --git a/third_party/xla/xla/service/memory_space_assignment/simulator.h b/third_party/xla/xla/service/memory_space_assignment/simulator.h index 906322b259a275..425f579bc21ddf 100644 --- a/third_party/xla/xla/service/memory_space_assignment/simulator.h +++ b/third_party/xla/xla/service/memory_space_assignment/simulator.h @@ -59,6 +59,13 @@ struct OutstandingAsyncCopyLike { // A wrapper class around runtime simulator. class RuntimeSimulator { public: + // A struct that captures an instructions elapsed time and the amount of time + // we estimate default memory bandwidth to be idle, during that instruction. + struct ElapsedAndIdleTimes { + float elapsed_time; + float idle_default_memory_bandwidth_time; + }; + explicit RuntimeSimulator(CostAnalysis* cost_analysis, int64_t alternate_memory_space) : cost_analysis_(cost_analysis), @@ -93,8 +100,12 @@ class RuntimeSimulator { // there is spare bandwidth to simulate async memory accesses to default // memory. If we get to an async copy like done, we must wait until it // finishes (potentially waiting for copies issued before it to finish. - float SimulateElapsedTime(const HloModule* hlo_module, - const AllocationSequence& allocations); + // + // alt_mem_bytes_occupied is a vector of the amount of alt mem bytes allocated + // at any given instruction. It may be null. + float SimulateElapsedTime( + const HloModule* hlo_module, const AllocationSequence& allocations, + const std::vector* alt_mem_bytes_occupied = nullptr); // This is an auxiliary function for simulating the execution // time for executing a copy-done instruction. It returns the @@ -125,7 +136,7 @@ class RuntimeSimulator { // Aside from returning the elapsed time, this function also updates the // outstanding memory request queues, by draining them when the compute // instruction is not occupying bandwidth. - float SimulateComputeInstruction( + ElapsedAndIdleTimes SimulateComputeInstruction( const HloInstruction* compute_instruction, absl::Span> operands_in_alternate_memory, @@ -152,7 +163,9 @@ class RuntimeSimulator { // the memory access queues in a given amount of time (seconds). If both // outstanding_*_default_queues are non-empty, they share bandwidth. If one of // the queues is empty and the other is not, it gets the full bandwdith. - void ProcessAsyncCopyLikesInIdleTime(float time); + // + // Returns the remaining idle time after processing async-copy-likes. + float ProcessAsyncCopyLikesInIdleTime(float time); int64_t alternate_memory_space_; std::list outstanding_read_default_queue_; diff --git a/third_party/xla/xla/service/memory_space_assignment/simulator_test.cc b/third_party/xla/xla/service/memory_space_assignment/simulator_test.cc index 3b61a70f9309f5..2f822ffccc2e16 100644 --- a/third_party/xla/xla/service/memory_space_assignment/simulator_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/simulator_test.cc @@ -27,11 +27,11 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/heap_simulator/heap_simulator.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_value.h" #include "xla/service/memory_space_assignment/allocation.h" @@ -53,13 +53,8 @@ using memory_space_assignment::RuntimeSimulator; using ::testing::ElementsAreArray; using ::testing::IsEmpty; -constexpr int64_t kPointerSize = 8; constexpr int64_t kAlternateMemorySpace = 1; -int64_t ShapeSize(const Shape& shape) { - return ShapeUtil::ByteSizeOf(shape, kPointerSize); -} - class MemorySpaceAssignmentSimulatorTest : public HloTestBase { protected: absl::Status Initialize(absl::string_view hlo_string) { @@ -84,7 +79,6 @@ class MemorySpaceAssignmentSimulatorTest : public HloTestBase { } } HloCostAnalysis::Options tpu_device_options; - tpu_device_options.shape_size = ShapeSize; // Assume 1 FLOP per second for testing. tpu_device_options.set_flops_per_second(1); // Assume 1 byte per second for testing. @@ -498,9 +492,12 @@ TEST_F(SimulateAsyncCopyLikeDoneTest, const HloInstruction* copy_start_1_inst = instruction_map_["copy-start.1"]; const HloInstruction* neg_inst = instruction_map_["neg"]; - float compute_elapsed_time = runtime_simulator_->SimulateComputeInstruction( - neg_inst, /*operands_in_alternate_memory=*/{}, - /*outputs_in_alternate_memory=*/{}); + float compute_elapsed_time = + runtime_simulator_ + ->SimulateComputeInstruction(neg_inst, + /*operands_in_alternate_memory=*/{}, + /*outputs_in_alternate_memory=*/{}) + .elapsed_time; // The compute operand requires 32 FLOPs and 32 * 4 * 2 bytes access, which // requires 32 and 256 secs respectively. Thus, it is default memory access @@ -539,9 +536,13 @@ TEST_F(SimulateAsyncCopyLikeDoneTest, // process the async copies in this time. Both queues are not empty, so the // bandwidth is shared. Each of the request at the front of the queue process // 64 sec * 0.5 bytes/sec = 32 bytes. - float compute_elapsed_time = runtime_simulator_->SimulateComputeInstruction( - instruction_map_["neg"], /*operands_in_alternate_memory=*/{{0, {}}}, - /*outputs_in_alternate_memory=*/{}); + float compute_elapsed_time = + runtime_simulator_ + ->SimulateComputeInstruction( + instruction_map_["neg"], + /*operands_in_alternate_memory=*/{{0, {}}}, + /*outputs_in_alternate_memory=*/{}) + .elapsed_time; // 64 secs for alternate memory access + 128 secs for default memory access EXPECT_EQ(compute_elapsed_time, 192); @@ -577,9 +578,13 @@ TEST_F(SimulateAsyncCopyLikeDoneTest, // 64 secs idle time to process async copies. Since only the read queue is not // empty, we can use the full bandwidth and process 64 sec * 1 bytes/sec = 64 // bytes. - float compute_elapsed_time = runtime_simulator_->SimulateComputeInstruction( - instruction_map_["neg"], /*operands_in_alternate_memory=*/{{0, {}}}, - /*outputs_in_alternate_memory=*/{}); + float compute_elapsed_time = + runtime_simulator_ + ->SimulateComputeInstruction( + instruction_map_["neg"], + /*operands_in_alternate_memory=*/{{0, {}}}, + /*outputs_in_alternate_memory=*/{}) + .elapsed_time; // 64 secs for alternate memory access + 128 secs for default memory access EXPECT_EQ(compute_elapsed_time, 192); @@ -602,9 +607,12 @@ TEST_F(SimulateAsyncCopyLikeDoneTest, TF_ASSERT_OK(Initialize(hlo_string)); - float compute_elapsed_time = runtime_simulator_->SimulateComputeInstruction( - instruction_map_["neg"], /*operands_in_alternate_memory=*/{}, - /*outputs_in_alternate_memory=*/{}); + float compute_elapsed_time = + runtime_simulator_ + ->SimulateComputeInstruction(instruction_map_["neg"], + /*operands_in_alternate_memory=*/{}, + /*outputs_in_alternate_memory=*/{}) + .elapsed_time; // Execution time: 128 * 4 * 2 / 1 for default access EXPECT_EQ(compute_elapsed_time, 1024); // The queues should remain empty. diff --git a/third_party/xla/xla/service/memory_space_assignment/testing_utils.h b/third_party/xla/xla/service/memory_space_assignment/testing_utils.h index 25267371d654c6..c1bf0f5a8648b3 100644 --- a/third_party/xla/xla/service/memory_space_assignment/testing_utils.h +++ b/third_party/xla/xla/service/memory_space_assignment/testing_utils.h @@ -24,10 +24,10 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/call_graph.h" -#include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/memory_space_assignment/cost_analysis.h" #include "xla/shape.h" diff --git a/third_party/xla/xla/service/memory_space_propagation.h b/third_party/xla/xla/service/memory_space_propagation.h index 8f741f6430cf50..11676aa45c3ba9 100644 --- a/third_party/xla/xla/service/memory_space_propagation.h +++ b/third_party/xla/xla/service/memory_space_propagation.h @@ -16,33 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_MEMORY_SPACE_PROPAGATION_H_ #define XLA_SERVICE_MEMORY_SPACE_PROPAGATION_H_ -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/service/hlo_dataflow_analysis.h" - -namespace xla { - -// This is a legalization pass that propagates the memory space in the layout to -// the fusion computations. -class MemorySpacePropagation : public HloModulePass { - public: - ~MemorySpacePropagation() override = default; - absl::string_view name() const override { return "memory-space-propagation"; } - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - // Given the shape index (operand or output) and its corresponding instruction - // in the fused computation (parameter or root), propagates the memory space - // in the callee side. Returns true if the module is modified. - bool Propagate(ShapeIndexView index, const HloInstruction* callee_instruction, - int64_t memory_space) const; - - std::unique_ptr dataflow_analysis_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/memory_space_propagation.h" #endif // XLA_SERVICE_MEMORY_SPACE_PROPAGATION_H_ diff --git a/third_party/xla/xla/service/multi_output_fusion.cc b/third_party/xla/xla/service/multi_output_fusion.cc index 779a292ac43348..0967e152717dae 100644 --- a/third_party/xla/xla/service/multi_output_fusion.cc +++ b/third_party/xla/xla/service/multi_output_fusion.cc @@ -19,11 +19,11 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "xla/debug_options_flags.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" +#include "xla/hlo/analysis/hlo_reachability.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/ir/hlo_reachability.h" -#include "xla/service/hlo_dataflow_analysis.h" -#include "xla/service/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/shape_util.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/multi_output_fusion.h b/third_party/xla/xla/service/multi_output_fusion.h index 14233ae4b8b9e8..dd321c7e93d4d2 100644 --- a/third_party/xla/xla/service/multi_output_fusion.h +++ b/third_party/xla/xla/service/multi_output_fusion.h @@ -23,8 +23,8 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/analysis/hlo_reachability.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_reachability.h" #include "xla/hlo/pass/hlo_pass_interface.h" namespace xla { @@ -128,6 +128,10 @@ class MultiOutputFusion : public HloModulePass { // reachability, worklist, and fusion candidates. HloInstruction* CreateFusion(HloInstruction* base, HloInstruction* to_fuse); + bool is_connected(HloInstruction* instr1, HloInstruction* instr2) { + return reachability_->IsConnected(instr1, instr2); + } + private: // An internal data structure for each instruction in current computation. // When an instruction is removed, member 'hlo' is set to nullptr. @@ -195,10 +199,6 @@ class MultiOutputFusion : public HloModulePass { candidates_[get_candidate_id(instr)].hlo = nullptr; } - bool is_connected(HloInstruction* instr1, HloInstruction* instr2) { - return reachability_->IsConnected(instr1, instr2); - } - std::vector candidates_; WorkList worklist_; diff --git a/third_party/xla/xla/service/op_expander_pass.h b/third_party/xla/xla/service/op_expander_pass.h index e8644f5abb30a1..df65b012e1da6c 100644 --- a/third_party/xla/xla/service/op_expander_pass.h +++ b/third_party/xla/xla/service/op_expander_pass.h @@ -16,48 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_OP_EXPANDER_PASS_H_ #define XLA_SERVICE_OP_EXPANDER_PASS_H_ -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// This pass is an abstract superclass for passes that replace operations that -// match a pattern. It is intended to be subclassed, not used directly. -// -// This pass is useful for legalizing HLO instructions that a particular backend -// does not support into other HLO instructions. -class OpExpanderPass : public HloModulePass { - public: - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - // extra_filter: Optional extra filtering criteria for matching instructions, - // used in conjunction with InstructionMatchesPattern. - // preserve_sharding and relay_control_dependency: If we preserve sharding and - // relay control dependency when replacing the matched instructions. - explicit OpExpanderPass(HloPredicate extra_filter = nullptr, - bool preserve_sharding = false, - bool relay_control_dependency = false) - : extra_filter_(std::move(extra_filter)), - preserve_sharding_(preserve_sharding), - relay_control_dependency_(relay_control_dependency) {} - - protected: - // Returns `true` if `instruction` should be expanded by this pass. - virtual bool InstructionMatchesPattern(HloInstruction* instruction) = 0; - - // Returns a replacement for `instruction`, or nullptr if no replacement is - // needed (e.g. only the to_apply subcomputation of the instruction was - // modified). - virtual absl::StatusOr ExpandInstruction( - HloInstruction* instruction) = 0; - - HloPredicate extra_filter_; - const bool preserve_sharding_; - const bool relay_control_dependency_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/expanders/op_expander_pass.h" #endif // XLA_SERVICE_OP_EXPANDER_PASS_H_ diff --git a/third_party/xla/xla/service/operand_upcaster.h b/third_party/xla/xla/service/operand_upcaster.h index d89daf4e415d0f..8b237a47e0cd65 100644 --- a/third_party/xla/xla/service/operand_upcaster.h +++ b/third_party/xla/xla/service/operand_upcaster.h @@ -16,32 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_OPERAND_UPCASTER_H_ #define XLA_SERVICE_OPERAND_UPCASTER_H_ -#include - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/op_expander_pass.h" -#include "xla/util.h" - -namespace xla { - -// Inserts Convert to operands of instructions that allows result accumulation -// as wider integral types. -class OperandUpcaster : public OpExpanderPass { - public: - explicit OperandUpcaster(HloPredicate extra_filter = nullptr) - : OpExpanderPass(std::move(extra_filter)) {} - - absl::string_view name() const override { return "operand_upcaster"; } - - protected: - bool InstructionMatchesPattern(HloInstruction* instruction) override; - - absl::StatusOr ExpandInstruction( - HloInstruction* instruction) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/operand_upcaster.h" #endif // XLA_SERVICE_OPERAND_UPCASTER_H_ diff --git a/third_party/xla/xla/service/optimization_barrier_expander.h b/third_party/xla/xla/service/optimization_barrier_expander.h index b614b80d8f3e4a..b257010fe9a616 100644 --- a/third_party/xla/xla/service/optimization_barrier_expander.h +++ b/third_party/xla/xla/service/optimization_barrier_expander.h @@ -16,23 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_OPTIMIZATION_BARRIER_EXPANDER_H_ #define XLA_SERVICE_OPTIMIZATION_BARRIER_EXPANDER_H_ -#include "xla/service/op_expander_pass.h" - -namespace xla { - -// This pass removes the opt-barrier operation which is functionally a no-op. -class OptimizationBarrierExpander : public HloModulePass { - public: - OptimizationBarrierExpander() = default; - - absl::string_view name() const override { return "cse_barrier_expander"; } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/expanders/optimization_barrier_expander.h" #endif // XLA_SERVICE_OPTIMIZATION_BARRIER_EXPANDER_H_ diff --git a/third_party/xla/xla/service/optimize_input_output_buffer_alias.h b/third_party/xla/xla/service/optimize_input_output_buffer_alias.h index d03128c8b9138f..04ad98bc488386 100644 --- a/third_party/xla/xla/service/optimize_input_output_buffer_alias.h +++ b/third_party/xla/xla/service/optimize_input_output_buffer_alias.h @@ -16,74 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_OPTIMIZE_INPUT_OUTPUT_BUFFER_ALIAS_H_ #define XLA_SERVICE_OPTIMIZE_INPUT_OUTPUT_BUFFER_ALIAS_H_ -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "xla/hlo/ir/hlo_input_output_alias_config.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" -#include "xla/shape.h" -#include "xla/shape_util.h" - -namespace xla { - -// This pass finds input and output buffers that can be aliased, and writes the -// alias config into the HloModule. -// -// The input and the output buffers can be in any shape, and each output buffer -// can alias with an input buffer with the same shape. Each input buffer may -// only alias with a single output buffer. For example, for the following -// parameter and the output buffers, -// -// Parameters : { P1(f32[3]), P2(s32[3]), P3(f32[3,12]), P4(f32[16,12]), ... } -// Outputs : { O1(s32[3]), O2(f32[3]), O3(f32[16,12]), ... } -// -// one potential aliasing would be (O1, P2), (O2, P1), (O3, P4), .. -class OptimizeInputOutputBufferAlias : public HloModulePass { - public: - OptimizeInputOutputBufferAlias() = default; - explicit OptimizeInputOutputBufferAlias( - bool registered_buffer_donor_only, - std::function shape_size_fn = - [](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape); }) - : registered_buffer_donor_only_(registered_buffer_donor_only), - shape_size_fn_(shape_size_fn) {} - ~OptimizeInputOutputBufferAlias() override = default; - - absl::string_view name() const override { - return "optimize_input_output_buffer_alias"; - } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - friend class OptimizeInputOutputBufferAliasTest; - - // If true, we only consider the registered buffer donor in - // HloBufferDonorConfig, ignoring unregistered input parameters. If false, we - // treat all input parameters as buffer donors. - bool registered_buffer_donor_only_ = false; - - // Match buffer donors and donees and save the matched paired in the - // alias_config. The availability of buffer donors is controlled by the flag - // registered_buffer_donor_only_. - absl::StatusOr Build(absl::Span input_shapes, - const Shape& output_shape, - HloInputOutputAliasConfig* alias_config, - HloBufferDonorConfig* buffer_donor_config); - - std::function shape_size_fn_ = [](const Shape& shape) { - return ShapeUtil::ByteSizeOf(shape); - }; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/optimize_input_output_buffer_alias.h" #endif // XLA_SERVICE_OPTIMIZE_INPUT_OUTPUT_BUFFER_ALIAS_H_ diff --git a/third_party/xla/xla/service/p2p_schedule_preparation.cc b/third_party/xla/xla/service/p2p_schedule_preparation.cc index 3ed9a6df03dfee..18b18f0a8ed41b 100644 --- a/third_party/xla/xla/service/p2p_schedule_preparation.cc +++ b/third_party/xla/xla/service/p2p_schedule_preparation.cc @@ -28,12 +28,12 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "xla/hlo/analysis/hlo_reachability.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/ir/hlo_reachability.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/collective_ops_utils.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/p2p_schedule_preparation_test.cc b/third_party/xla/xla/service/p2p_schedule_preparation_test.cc index b1127c586fe032..828ca209972823 100644 --- a/third_party/xla/xla/service/p2p_schedule_preparation_test.cc +++ b/third_party/xla/xla/service/p2p_schedule_preparation_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/pattern_matcher.h b/third_party/xla/xla/service/pattern_matcher.h index 76979f097ef1f9..957c29f1ea5e15 100644 --- a/third_party/xla/xla/service/pattern_matcher.h +++ b/third_party/xla/xla/service/pattern_matcher.h @@ -46,10 +46,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/ptrvec.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/literal.h" -#include "xla/service/hlo_parser.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" @@ -2690,6 +2690,7 @@ XLA_UNOP_PATTERN(Cos) XLA_UNOP_PATTERN(AllReduceStart) XLA_UNOP_PATTERN(AllReduceDone) XLA_UNOP_PATTERN(AllToAll) +XLA_UNOP_PATTERN(RaggedAllToAll) XLA_UNOP_PATTERN(AsyncDone) XLA_UNOP_PATTERN(CollectiveBroadcast) XLA_UNOP_PATTERN(CollectivePermute) diff --git a/third_party/xla/xla/service/pattern_matcher_test.cc b/third_party/xla/xla/service/pattern_matcher_test.cc index 73da06ae7c1eea..c8ecfe16029189 100644 --- a/third_party/xla/xla/service/pattern_matcher_test.cc +++ b/third_party/xla/xla/service/pattern_matcher_test.cc @@ -24,10 +24,10 @@ limitations under the License. #include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/literal_util.h" -#include "xla/service/hlo_parser.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" diff --git a/third_party/xla/xla/service/profile_guided_latency_estimator.cc b/third_party/xla/xla/service/profile_guided_latency_estimator.cc index d8e20f2445c4a4..d66f5334933184 100644 --- a/third_party/xla/xla/service/profile_guided_latency_estimator.cc +++ b/third_party/xla/xla/service/profile_guided_latency_estimator.cc @@ -188,16 +188,19 @@ absl::Status ProfileGuidedLatencyEstimator::CheckAccuracy( ProfileStatisticsAggregator::Statistics stats = aggregator_->GetStats(); size_t missing_instructions_count = stats.missing_instructions.size(); if (missing_instructions_count > 0) { - LOG(ERROR) << "Found " << stats.found_instructions_count - << " instructions from the profile."; - LOG(ERROR) << "Missing " << missing_instructions_count - << " instructions from the profile."; + LOG(WARNING) << "Found " << stats.found_instructions_count + << " instructions from the profile."; + LOG(WARNING) << "Missing " << missing_instructions_count + << " instructions from the profile."; for (const HloInstruction* instr : stats.missing_instructions) { - LOG(ERROR) << " " << instr->name(); + LOG(WARNING) << " " << instr->name(); + } + if (module.config().debug_options().xla_gpu_pgle_accuracy_checker() == + DebugOptions::PGLE_STRICTNESS_LEVEL_ERROR) { + return absl::InvalidArgumentError( + absl::StrCat("Found ", missing_instructions_count, + " missing instructions. Discarding the profile.")); } - return absl::InvalidArgumentError( - absl::StrCat("Found ", missing_instructions_count, - " missing instructions. Discarding the profile.")); } return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/propagate_original_value_test.cc b/third_party/xla/xla/service/propagate_original_value_test.cc index 37340843c2fb5d..301e31e68972e8 100644 --- a/third_party/xla/xla/service/propagate_original_value_test.cc +++ b/third_party/xla/xla/service/propagate_original_value_test.cc @@ -29,17 +29,17 @@ TEST_F(PropagateOriginalValueTest, InstructionFusion) { HloModule test, entry_computation_layout={(s32[]{:T(256)})->u32[2]{0:T(256)}} ENTRY test { - Arg_0 = s32[]{:T(256)} parameter(0), original_value={{"Arg_0"}}, metadata={op_name="seed"} - constant = s32[]{:T(256)} constant(32), original_value={{"constant"}} - shift-right-logical = s32[]{:T(256)} shift-right-logical(Arg_0, constant), original_value={{"shift-right-logical"}} - convert = u32[]{:T(256)} convert(shift-right-logical), original_value={{"convert"}} - bitcast = u32[1]{0:T(256)} bitcast(convert), original_value={{"reshape"}} + Arg_0 = s32[]{:T(256)} parameter(0), origin={{"Arg_0"}}, metadata={op_name="seed"} + constant = s32[]{:T(256)} constant(32), origin={{"constant"}} + shift-right-logical = s32[]{:T(256)} shift-right-logical(Arg_0, constant), origin={{"shift-right-logical"}} + convert = u32[]{:T(256)} convert(shift-right-logical), origin={{"convert"}} + bitcast = u32[1]{0:T(256)} bitcast(convert), origin={{"reshape"}} constant.1 = u32[]{:T(256)} constant(0) pad = u32[2]{0:T(256)} pad(bitcast, constant.1), padding=0_1 - convert.1 = u32[]{:T(256)} convert(Arg_0), original_value={{"convert.1"}} - bitcast.1 = u32[1]{0:T(256)} bitcast(convert.1), original_value={{"reshape.1"}} + convert.1 = u32[]{:T(256)} convert(Arg_0), origin={{"convert.1"}} + bitcast.1 = u32[1]{0:T(256)} bitcast(convert.1), origin={{"reshape.1"}} pad.1 = u32[2]{0:T(256)} pad(bitcast.1, constant.1), padding=1_0 - ROOT add = u32[2]{0:T(256)} add(pad, pad.1), original_value={{"concatenate"}} + ROOT add = u32[2]{0:T(256)} add(pad, pad.1), origin={{"concatenate"}} } )"; @@ -49,20 +49,20 @@ ENTRY test { R"( CHECK: %fused_computation CHECK: %[[PARAM:.*]] = s32[]{:T(256)} parameter(0) -CHECK: %[[CONSTANT:.*]] = s32[]{:T(256)} constant(32), original_value={{[{]}}{"constant"}} -CHECK: %[[SHIFT:.*]] = s32[]{:T(256)} shift-right-logical(%[[PARAM]], %[[CONSTANT]]), original_value={{[{]}}{"shift-right-logical"} -CHECK: %[[CONVERT:.*]] = u32[]{:T(256)} convert(%[[SHIFT]]), original_value={{[{]}}{"convert"} -CHECK: %[[BITCAST:.*]] = u32[1]{0:T(256)} bitcast(%[[CONVERT]]), original_value={{[{]}}{"reshape"} +CHECK: %[[CONSTANT:.*]] = s32[]{:T(256)} constant(32), origin={{[{]}}{"constant"}} +CHECK: %[[SHIFT:.*]] = s32[]{:T(256)} shift-right-logical(%[[PARAM]], %[[CONSTANT]]), origin={{[{]}}{"shift-right-logical"} +CHECK: %[[CONVERT:.*]] = u32[]{:T(256)} convert(%[[SHIFT]]), origin={{[{]}}{"convert"} +CHECK: %[[BITCAST:.*]] = u32[1]{0:T(256)} bitcast(%[[CONVERT]]), origin={{[{]}}{"reshape"} CHECK: %[[CONSTANT1:.*]] = u32[]{:T(256)} constant(0) CHECK: %[[PAD:.*]] = u32[2]{0:T(256)} pad(%[[BITCAST]], %[[CONSTANT1]]), padding=0_1 -CHECK: %[[CONVERT1:.*]] = u32[]{:T(256)} convert(%[[PARAM]]), original_value={{[{]}}{"convert.1"} -CHECK: %[[BITCAST1:.*]] = u32[1]{0:T(256)} bitcast(%[[CONVERT1]]), original_value={{[{]}}{"reshape.1"} +CHECK: %[[CONVERT1:.*]] = u32[]{:T(256)} convert(%[[PARAM]]), origin={{[{]}}{"convert.1"} +CHECK: %[[BITCAST1:.*]] = u32[1]{0:T(256)} bitcast(%[[CONVERT1]]), origin={{[{]}}{"reshape.1"} CHECK: %[[PAD1:.*]] = u32[2]{0:T(256)} pad(%[[BITCAST1]], %[[CONSTANT1]]), padding=1_0 -CHECK: ROOT %[[ADD:.*]] = u32[2]{0:T(256)} add(%[[PAD]], %[[PAD1]]), original_value={{[{]}}{"concatenate"} +CHECK: ROOT %[[ADD:.*]] = u32[2]{0:T(256)} add(%[[PAD]], %[[PAD1]]), origin={{[{]}}{"concatenate"} CHECK: ENTRY %test -CHECK: %Arg_0 = s32[]{:T(256)} parameter(0), original_value={{[{]}}{"Arg_0"} -CHECK: ROOT %fusion = u32[2]{0:T(256)} fusion(%Arg_0), kind=kLoop, calls=%fused_computation +CHECK: %Arg_0 = s32[]{:T(256)} parameter(0), origin={{[{]}}{"Arg_0"} +CHECK: ROOT %pad_add_fusion = u32[2]{0:T(256)} fusion(%Arg_0), kind=kLoop, calls=%fused_computation, origin={{[{]}}{"concatenate"} )"); } diff --git a/third_party/xla/xla/service/qr_expander.h b/third_party/xla/xla/service/qr_expander.h index d4818f644d137f..067ea64c9166a9 100644 --- a/third_party/xla/xla/service/qr_expander.h +++ b/third_party/xla/xla/service/qr_expander.h @@ -16,42 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_QR_EXPANDER_H_ #define XLA_SERVICE_QR_EXPANDER_H_ -#include "absl/container/flat_hash_map.h" -#include "xla/client/lib/qr.h" -#include "xla/client/xla_builder.h" -#include "xla/service/op_expander_pass.h" - -namespace xla { - -class QrExpander : public OpExpanderPass { - public: - absl::string_view name() const override { return "qr_expander"; } - - protected: - bool InstructionMatchesPattern(HloInstruction* instruction) override; - - absl::StatusOr ExpandInstruction( - HloInstruction* instruction) override; - - virtual absl::StatusOr QrBlock( - XlaOp a, PrecisionConfig::Precision precision); - - virtual absl::StatusOr CompactWYRepresentation( - PrimitiveType type, absl::Span batch_dims, XlaOp vs, - XlaOp taus, int64_t m, int64_t n, PrecisionConfig::Precision precision); - - private: - absl::StatusOr BuildQrDecomposition( - XlaOp a, int64_t block_size, PrecisionConfig::Precision precision); - - absl::StatusOr ProductOfElementaryHouseholderReflectors( - XlaOp a, XlaOp taus, int64_t block_size, - PrecisionConfig::Precision precision); - - // Mapping from op signatures to existing computations. - absl::flat_hash_map computation_cache_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/expanders/qr_expander.h" #endif // XLA_SERVICE_QR_EXPANDER_H_ diff --git a/third_party/xla/xla/service/real_imag_expander.h b/third_party/xla/xla/service/real_imag_expander.h index 2c2bd9e08eb08b..fc87a60e747da6 100644 --- a/third_party/xla/xla/service/real_imag_expander.h +++ b/third_party/xla/xla/service/real_imag_expander.h @@ -16,22 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_REAL_IMAG_EXPANDER_H_ #define XLA_SERVICE_REAL_IMAG_EXPANDER_H_ -#include "xla/service/op_expander_pass.h" - -namespace xla { - -// Expands real/image instructions with non-complex inputs. -class RealImagExpander : public OpExpanderPass { - public: - absl::string_view name() const override { return "real_imag_expander"; } - - protected: - bool InstructionMatchesPattern(HloInstruction* inst) override; - - absl::StatusOr ExpandInstruction( - HloInstruction* inst) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/expanders/real_imag_expander.h" #endif // XLA_SERVICE_REAL_IMAG_EXPANDER_H_ diff --git a/third_party/xla/xla/service/reduce_decomposer.h b/third_party/xla/xla/service/reduce_decomposer.h index d112002fb56786..12fac9b0dec6b1 100644 --- a/third_party/xla/xla/service/reduce_decomposer.h +++ b/third_party/xla/xla/service/reduce_decomposer.h @@ -16,65 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_REDUCE_DECOMPOSER_H_ #define XLA_SERVICE_REDUCE_DECOMPOSER_H_ -#include - -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// For each reduction R(I), ensures the postcondition: -// -// !custom_layout_allowed(R) -// => -// layout(R) == layout(I) # modulo removed dimensions -// -// To achieve that, decomposes layout-mutating reductions which do not satisfy -// `custom_layout_allowed` into a reduction and a copy. -// -// For a singular reduction: -// -// -> reduce -> -// -// Gets turned into: -// -// -> reduce -> copy -> -// -// For a variadic recuction, the layout assignment guarantees that the layout -// is the same for all outputs. This pass will transpose the variadic reduction -// inputs which have different physical layout to the first operand. -// -// A{L} \ -// B{L} -> reduce{L'} -> -// C{L} / -// -// Get turned into: -// -// A{L} \ / GTE(1) -> copy{L'} \ -// B{L} -> reduce{E(L)} --- GTE(2) -> copy{L'} - Tuple{L'} -// C{L} / \ GTE(3) -> copy{L'} / -// -// Where E(L) is expected layout of a reduction (original layout with reduce -// dimensions dropped). -// -// PRECONDITION: -// In variadic reduction, all outputs have the same layout -// (enforced by layout assignment). -class ReduceDecomposer : public HloModulePass { - public: - explicit ReduceDecomposer(HloPredicate custom_layout_allowed = nullptr) - : custom_layout_allowed_(custom_layout_allowed) {} - - absl::string_view name() const override { return "reduce-decomposer"; } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - HloPredicate custom_layout_allowed_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/expanders/reduce_decomposer.h" #endif // XLA_SERVICE_REDUCE_DECOMPOSER_H_ diff --git a/third_party/xla/xla/service/reduce_scatter_combiner.cc b/third_party/xla/xla/service/reduce_scatter_combiner.cc index 1e58187720b4cc..6c0d34aa96870e 100644 --- a/third_party/xla/xla/service/reduce_scatter_combiner.cc +++ b/third_party/xla/xla/service/reduce_scatter_combiner.cc @@ -22,29 +22,33 @@ limitations under the License. #include #include #include +#include #include #include #include -#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/functional/function_ref.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/ir/hlo_reachability.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/service/all_reduce_key.h" #include "xla/service/collective_combiner_utils.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/hlo_domain_map.h" -#include "xla/service/shape_inference.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -71,9 +75,6 @@ int64_t FindMostFrequentScatterDim( return most_frequent_dim < min_rank ? most_frequent_dim : 0; } -using ReduceScatterKey = - std::tuple; - // Combines the elements of to_combine into a single ReduceScatter op. All // entries in to_combine must be ReduceScatter ops with exactly one operand // and the same reduction operation. @@ -172,16 +173,36 @@ absl::Status CombineReduceScatters( } } // namespace -ReduceScatterCombiner::ReduceScatterCombiner(int64_t combine_threshold_in_bytes, - int64_t combine_threshold_count, - bool combine_by_dim) - : combine_threshold_in_bytes_(combine_threshold_in_bytes), - combine_threshold_count_(combine_threshold_count), - combine_by_dim_(combine_by_dim) {} +/*static*/ std::string& ReduceScatterCombiner::GetGroupKeyExtraArgs( + ReduceScatterCombiner::GroupKey& key) { + return std::get<2>(key); +} -absl::StatusOr ReduceScatterCombiner::Run( +std::optional +ReduceScatterCombiner::CombineKey(const HloInstruction* instruction, + const HloDomainMap& domain_map, + bool combine_by_dim) { + auto* rs = DynCast(instruction); + std::optional key = GetAllReduceKey(instruction, &domain_map); + + if (!rs || !key) { + return std::nullopt; + } + if (!MatchReductionComputation(rs->to_apply())) { + return std::nullopt; + } + + // Ignore dimension (set to -1) if we are not grouping by dimension. + int64_t rs_dim_key = combine_by_dim ? rs->scatter_dimension() : -1; + return ReduceScatterCombiner::GroupKey{std::move(*key), rs_dim_key, ""}; +} + +absl::StatusOr ReduceScatterCombiner::RunWithKeyCombiner( HloModule* module, - const absl::flat_hash_set& execution_threads) { + const absl::flat_hash_set& execution_threads, + absl::FunctionRef( + const HloInstruction*, const HloDomainMap&, bool)> + combine_key) { VLOG(1) << "Running ReduceScatterCombiner with threshold of " << combine_threshold_in_bytes_ << " bytes"; @@ -202,27 +223,13 @@ absl::StatusOr ReduceScatterCombiner::Run( module->MakeNonfusionComputations(execution_threads)) { TF_ASSIGN_OR_RETURN(auto domain_map, HloDomainMap::Create(computation, "")); - auto key_fn = [&domain_map, this](const HloInstruction* instruction) - -> std::optional { - auto* rs = DynCast(instruction); - std::optional key = - GetAllReduceKey(instruction, domain_map.get()); - - if (!rs || !key) { - return std::nullopt; - } - if (!MatchReductionComputation(rs->to_apply())) { - return std::nullopt; - } - - // Ignore dimension (set to -1) if we are not grouping by dimension. - int64_t rs_dim_key = this->combine_by_dim_ ? rs->scatter_dimension() : -1; - return ReduceScatterKey{std::move(*key), rs_dim_key}; + auto key_fn = [&](const HloInstruction* instruction) { + return combine_key(instruction, *domain_map, combine_by_dim_); }; TF_ASSIGN_OR_RETURN( bool computation_changed, - CombineInstructionsByKey( + CombineInstructionsByKey( computation, key_fn, &CombineReduceScatters, combine_threshold_in_bytes_, combine_threshold_count_)); changed |= computation_changed; @@ -231,4 +238,19 @@ absl::StatusOr ReduceScatterCombiner::Run( return changed; } +ReduceScatterCombiner::ReduceScatterCombiner(int64_t combine_threshold_in_bytes, + int64_t combine_threshold_count, + bool combine_by_dim) + : combine_threshold_in_bytes_(combine_threshold_in_bytes), + combine_threshold_count_(combine_threshold_count), + combine_by_dim_(combine_by_dim) {} + +absl::StatusOr ReduceScatterCombiner::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + TF_ASSIGN_OR_RETURN( + bool changed, RunWithKeyCombiner(module, execution_threads, CombineKey)); + return changed; +} + } // namespace xla diff --git a/third_party/xla/xla/service/reduce_scatter_combiner.h b/third_party/xla/xla/service/reduce_scatter_combiner.h index 26e047bd400f92..60eaeb95cef105 100644 --- a/third_party/xla/xla/service/reduce_scatter_combiner.h +++ b/third_party/xla/xla/service/reduce_scatter_combiner.h @@ -16,9 +16,20 @@ limitations under the License. #ifndef XLA_SERVICE_REDUCE_SCATTER_COMBINER_H_ #define XLA_SERVICE_REDUCE_SCATTER_COMBINER_H_ +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/functional/function_ref.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/service/all_reduce_key.h" +#include "xla/service/hlo_domain_map.h" namespace xla { @@ -38,7 +49,26 @@ class ReduceScatterCombiner : public HloModulePass { HloModule* module, const absl::flat_hash_set& execution_threads) override; - private: + using GroupKey = std::tuple; + + static std::string& GetGroupKeyExtraArgs( + ReduceScatterCombiner::GroupKey& key); + + // Returns a key that will be equal for instructions that might be combined, + // or different if not. + static std::optional CombineKey( + const HloInstruction* instruction, const HloDomainMap& domain_map, + bool combine_by_dim); + + protected: + absl::StatusOr RunWithKeyCombiner( + HloModule* module, + const absl::flat_hash_set& execution_threads, + absl::FunctionRef( + const HloInstruction*, const HloDomainMap&, bool)> + combine_key); + // Combine reduce-scatter ops up to this threshold. int64_t combine_threshold_in_bytes_; diff --git a/third_party/xla/xla/service/reduce_window_rewriter.h b/third_party/xla/xla/service/reduce_window_rewriter.h index 616ec8b701e9a4..01f1ad58267695 100644 --- a/third_party/xla/xla/service/reduce_window_rewriter.h +++ b/third_party/xla/xla/service/reduce_window_rewriter.h @@ -16,57 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_REDUCE_WINDOW_REWRITER_H_ #define XLA_SERVICE_REDUCE_WINDOW_REWRITER_H_ -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// Rewrite ReduceWindow to be more performant in cases it is written in a -// quadratic way: -// -// 1) Work around unimplemented cases in the implementation of ReduceWindow. -// -// This rewrites all R1 ReduceWindow nodes. We reshape the operand to an -// R2, perform the operation, and reshape back to R1. The reshapes correspond to -// a bitcast if the tensor length is less than or equal to a passed parameter. -// The motivation for this is to avoid use of overly large reductions and the -// complexities and restrictions therein. -// -// 2) Rewrite ReduceWindow ops that represent a CumSum/CumProd into a -// tree-reduction (see details in the implementation). -// Note that this may itself generate R1 ReduceWindow ops, which means this pass -// needs to be run to a fixed point. -class ReduceWindowRewriter : public HloModulePass { - public: - // `base_length` is a size of a reduce-window we are comfortable with - // executing. - explicit ReduceWindowRewriter(int64_t base_length) - : base_length_(base_length) {} - - absl::string_view name() const override { return "reduce-window-rewriter"; } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - absl::Status ReplaceReduceWindowWithReshape( - HloReduceWindowInstruction* reduce_window); - - absl::StatusOr TryOptimizeCumSumOrProd( - HloReduceWindowInstruction* reduce_window); - - int64_t base_length_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/reduce_window_rewriter.h" #endif // XLA_SERVICE_REDUCE_WINDOW_REWRITER_H_ diff --git a/third_party/xla/xla/service/reshape_decomposer.h b/third_party/xla/xla/service/reshape_decomposer.h index 9c10649653cb3c..f5d5b140b1921f 100644 --- a/third_party/xla/xla/service/reshape_decomposer.h +++ b/third_party/xla/xla/service/reshape_decomposer.h @@ -16,25 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_RESHAPE_DECOMPOSER_H_ #define XLA_SERVICE_RESHAPE_DECOMPOSER_H_ -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// Decomposes a reshape which does not satisfy the ReshapeIsBitcast precondition -// into a bitcast and a copy (physical transposition). Tries to create only one -// transposition, but when it's not possible, creates two. -// -// Postcondition: All reshapes are turned into bitcasts. -class ReshapeDecomposer : public HloModulePass { - public: - absl::string_view name() const override { return "reshape-decomposer"; } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/expanders/reshape_decomposer.h" #endif // XLA_SERVICE_RESHAPE_DECOMPOSER_H_ diff --git a/third_party/xla/xla/service/reshape_mover.h b/third_party/xla/xla/service/reshape_mover.h index 14116a0ba0cf17..63f2003ed3e8c3 100644 --- a/third_party/xla/xla/service/reshape_mover.h +++ b/third_party/xla/xla/service/reshape_mover.h @@ -16,60 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_RESHAPE_MOVER_H_ #define XLA_SERVICE_RESHAPE_MOVER_H_ -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// This pass sinks kReshape and kTranspose operations (known as "rearrange" ops) -// down through elementwise ops: -// -// op(rearrange(x), rearrange(y)) => rearrange(op(x, y)). -// -// We also handle the case where one of the operands is not itself a rearrange -// op but can be trivially rearranged. For example: -// -// op(rearrange(x), broadcast(scalar_y)) => -// rearrange(x, broadcast'(scalar_y)). -// -// This pass should be run to a fixed point. It also expects algsimp to be run -// after each iteration. - -struct ReshapeMoverOptions { - // On some platforms, it's cheap to do `reshape(broadcast(f32[n] x))`. The - // reshape and broadcast can always be fused, and the index calculations are - // not expensive. In such cases it can be beneficial for us to create these - // reshapes eagerly, allowing us to get rid of more expensive ones. - bool reshape_of_1d_broadcast_is_cheap = false; -}; - -class ReshapeMover : public HloModulePass { - public: - explicit ReshapeMover( - const ReshapeMoverOptions& options = ReshapeMoverOptions{}) - : options_(options) {} - - absl::string_view name() const override { return "reshape-mover"; } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - absl::StatusOr TryReshapeMoveOnCandidates( - HloInstructionSet* candidates); - absl::StatusOr SinkRearrangeOperands(HloInstruction* instruction); - absl::StatusOr ApplyInverseRearrange( - const HloInstruction* rearrange, HloInstruction* operand); - bool IsReshapeMoveCandidate(HloInstruction* instruction); - const HloInstruction* FirstNontrivialRearrange( - absl::Span instrs); - bool CanTriviallyRearrange(const HloInstruction* instr, - const HloInstruction* rearrange); - - ReshapeMoverOptions options_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/reshape_mover.h" #endif // XLA_SERVICE_RESHAPE_MOVER_H_ diff --git a/third_party/xla/xla/service/result_caster.h b/third_party/xla/xla/service/result_caster.h index 1abb59dfe9a277..d8fc21221f5038 100644 --- a/third_party/xla/xla/service/result_caster.h +++ b/third_party/xla/xla/service/result_caster.h @@ -16,36 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_RESULT_CASTER_H_ #define XLA_SERVICE_RESULT_CASTER_H_ -#include - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/op_expander_pass.h" -#include "xla/util.h" - -namespace xla { - -// Inserts Convert to result of instructions to the preferred element type -// specified by the instructions when direct accumulation of that type isn't -// supported by the backend. This pass is run in combination with -// OperandUpcaster. If the inferred accumulation type has less precision, -// OperandUpcaster will convert the operands to the higher precision type if -// necessary. -class ResultCaster : public OpExpanderPass { - public: - explicit ResultCaster(HloPredicate extra_filter = nullptr) - : OpExpanderPass(std::move(extra_filter)) {} - - absl::string_view name() const override { return "result_caster"; } - - protected: - bool InstructionMatchesPattern(HloInstruction* instruction) override; - - absl::StatusOr ExpandInstruction( - HloInstruction* instruction) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/result_caster.h" #endif // XLA_SERVICE_RESULT_CASTER_H_ diff --git a/third_party/xla/xla/service/rng_bit_generator_expander.h b/third_party/xla/xla/service/rng_bit_generator_expander.h index 7e45c53c8362fc..40a8b353804746 100644 --- a/third_party/xla/xla/service/rng_bit_generator_expander.h +++ b/third_party/xla/xla/service/rng_bit_generator_expander.h @@ -16,57 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_RNG_BIT_GENERATOR_EXPANDER_H_ #define XLA_SERVICE_RNG_BIT_GENERATOR_EXPANDER_H_ -#include "absl/container/flat_hash_map.h" -#include "absl/status/statusor.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/op_expander_pass.h" -#include "xla/shape_util.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -class RngBitGeneratorExpander : public OpExpanderPass { - public: - explicit RngBitGeneratorExpander(RandomAlgorithm default_algorithm) - : default_algorithm_(default_algorithm) { - CHECK_NE(default_algorithm_, RandomAlgorithm::RNG_DEFAULT); - } - - absl::string_view name() const override { - return "rng-bit-generator-expander"; - } - - protected: - struct RngGeneratorKey { - Shape data_shape; - Shape state_shape; - RandomAlgorithm algorithm; - HloModule* module; - - template - friend H AbslHashValue(H h, const RngGeneratorKey& c) { - return H::combine(std::move(h), c.state_shape, c.data_shape, c.algorithm, - c.module); - } - - bool operator==(const RngGeneratorKey& o) const { - return data_shape == o.data_shape && state_shape == o.state_shape && - algorithm == o.algorithm && module == o.module; - } - }; - - bool InstructionMatchesPattern(HloInstruction* instruction) override; - absl::StatusOr ExpandInstruction( - HloInstruction* hlo) override; - absl::StatusOr GetGeneratorComputation( - const Shape& data_shape, const Shape& state_shape, - RandomAlgorithm algorithm, HloModule* module); - - const RandomAlgorithm default_algorithm_; - absl::flat_hash_map computation_cache_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/expanders/rng_bit_generator_expander.h" #endif // XLA_SERVICE_RNG_BIT_GENERATOR_EXPANDER_H_ diff --git a/third_party/xla/xla/service/rng_expander.h b/third_party/xla/xla/service/rng_expander.h index dd41a2a94838e5..5f1951d7c2c6f4 100644 --- a/third_party/xla/xla/service/rng_expander.h +++ b/third_party/xla/xla/service/rng_expander.h @@ -16,28 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_RNG_EXPANDER_H_ #define XLA_SERVICE_RNG_EXPANDER_H_ -#include "xla/service/op_expander_pass.h" - -namespace xla { - -class RngExpander : public OpExpanderPass { - public: - absl::string_view name() const override { return "rng-expander"; } - - protected: - bool InstructionMatchesPattern(HloInstruction* instruction) override; - - absl::StatusOr ExpandInstruction( - HloInstruction* rng) override; - - private: - // Cache RNG computations based on the distribution, output shape and shapes - // of the first and second operand. - absl::flat_hash_map, - HloComputation*> - expanded_rng_instructions_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/expanders/rng_expander.h" #endif // XLA_SERVICE_RNG_EXPANDER_H_ diff --git a/third_party/xla/xla/service/root_instruction_sinker.h b/third_party/xla/xla/service/root_instruction_sinker.h index 50975bd0a876e1..38cc3c7756908e 100644 --- a/third_party/xla/xla/service/root_instruction_sinker.h +++ b/third_party/xla/xla/service/root_instruction_sinker.h @@ -16,29 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_ROOT_INSTRUCTION_SINKER_H_ #define XLA_SERVICE_ROOT_INSTRUCTION_SINKER_H_ -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// Given a scheduled HLO module, this pass sinks the ROOT of the instruction to -// the bottom of the non-fusion computations. To avoid dependency violations of -// moving the ROOT instruction, it creates a new ROOT instruction that looks -// like the following: -// - For tuple ROOT type: -// new_root = tuple(gte(old_root), gte(old_root), ...) -// - For non-tuple ROOT type: -// new_root = bitcast(old_root) -class RootInstructionSinker : public HloModulePass { - public: - ~RootInstructionSinker() override = default; - absl::string_view name() const override { return "root-instruction-sinker"; } - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/root_instruction_sinker.h" #endif // XLA_SERVICE_ROOT_INSTRUCTION_SINKER_H_ diff --git a/third_party/xla/xla/service/scan_loop_accumulator_input_unification.cc b/third_party/xla/xla/service/scan_loop_accumulator_input_unification.cc index 4470ae615eb7e5..fb20157f3fd090 100644 --- a/third_party/xla/xla/service/scan_loop_accumulator_input_unification.cc +++ b/third_party/xla/xla/service/scan_loop_accumulator_input_unification.cc @@ -25,16 +25,16 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/analysis/hlo_alias_analysis.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/literal_util.h" #include "xla/service/call_graph.h" -#include "xla/service/hlo_alias_analysis.h" -#include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/pattern_matcher.h" -#include "xla/service/tuple_simplifier.h" #include "xla/service/while_loop_simplifier.h" #include "xla/service/while_loop_unroller.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/service/scatter_determinism_expander.cc b/third_party/xla/xla/service/scatter_determinism_expander.cc new file mode 100644 index 00000000000000..b938121a107af4 --- /dev/null +++ b/third_party/xla/xla/service/scatter_determinism_expander.cc @@ -0,0 +1,462 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/scatter_determinism_expander.h" + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_format.h" +#include "xla/array.h" +#include "xla/array2d.h" +#include "xla/comparison_util.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal_util.h" +#include "xla/service/hlo_creation_utils.h" +#include "xla/service/scatter_utils.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" + +namespace xla { + +// Canonicalizes the scatter_updates in order to keep them uniform while +// performing the scatter operation. +static absl::StatusOr> CanonicalizeScatterUpdates( + const std::vector& scatter_updates, + HloInstruction* scatter_indices, const ScatterDimensionNumbers& dim_numbers, + int64_t scatter_loop_trip_count) { + std::vector adjusted_updates; + adjusted_updates.reserve(scatter_updates.size()); + for (HloInstruction* update : scatter_updates) { + TF_ASSIGN_OR_RETURN( + HloInstruction * canonical_update, + PermuteScatterAndWindowDims(update, dim_numbers.update_window_dims())); + TF_ASSIGN_OR_RETURN( + HloInstruction * adjusted_update, + AdjustScatterDims(scatter_indices->shape(), canonical_update, + dim_numbers.index_vector_dim())); + CHECK_EQ(scatter_loop_trip_count, adjusted_update->shape().dimensions(0)); + adjusted_updates.push_back(adjusted_update); + } + return adjusted_updates; +} + +// Create the out-of-bound tensor for the scatter operation. +HloInstruction* CreateOutOfBoundTensor(HloComputation* parent, + HloInstruction* scatter_indices, + const Shape& scatter_shape) { + if (scatter_indices->shape().rank() == 1) { + CHECK_EQ(scatter_shape.dimensions_size(), 1); + Array out_of_bound_array({scatter_indices->shape().dimensions(0)}, + scatter_shape.dimensions(0)); + return parent->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateFromArray(out_of_bound_array))); + } + // More than one dimension in scatter_indices + Array2D out_of_bound_array(scatter_indices->shape().dimensions(0), + scatter_indices->shape().dimensions(1)); + for (int i = 0; i < scatter_indices->shape().dimensions(0); ++i) { + for (int j = 0; j < scatter_indices->shape().dimensions(1); ++j) { + out_of_bound_array(i, j) = scatter_shape.dimensions(j); + } + } + return parent->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR2FromArray2D(out_of_bound_array))); +} + +// Computation for sorting the scalar scatter indices and updates together +HloComputation* ScalarSortingComparison(HloModule* module, + const Shape key_shape, + const Shape update_shape, + int64_t num_updates) { + HloComputation::Builder builder("sorting_computation"); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, key_shape, "lhs_key")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, key_shape, "rhs_key")); + const int kExistingParams = 2; + for (int i = 0; i < num_updates; ++i) { + builder.AddInstruction( + HloInstruction::CreateParameter(kExistingParams + i, update_shape, + absl::StrFormat("lhs_update_%d", i))); + builder.AddInstruction( + HloInstruction::CreateParameter(kExistingParams + 1 + i, update_shape, + absl::StrFormat("rhs_update_%d", i))); + } + builder.AddInstruction( + HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), param0, + param1, ComparisonDirection::kLt)); + return module->AddEmbeddedComputation(builder.Build()); +} + +static std::vector SortIndicesAndUpdates( + HloInstruction* scatter_indices, + const std::vector& scatter_updates, int64_t num_indices, + HloScatterInstruction* scatter, HloComputation* parent) { + const Shape& indices_shape = scatter_indices->shape(); + const Shape& updates_shape = scatter_updates[0]->shape(); + auto updates_dims = updates_shape.dimensions(); + // Since we canonicalized the scatter updates, the first dim will always be + // the number of updates and the rest will be the shape of each update + + HloInstruction* scalar_indices = scatter_indices; + + std::vector single_update_dimensions(updates_dims.begin() + 1, + updates_dims.end()); + + const Shape update_shape = ShapeUtil::MakeShape(updates_shape.element_type(), + single_update_dimensions); + + const Shape& scalar_index_shape = + ShapeUtil::MakeShape(indices_shape.element_type(), {num_indices}); + + auto* comparison = ScalarSortingComparison( + scatter->GetModule(), + ShapeUtil::MakeShape(indices_shape.element_type(), {}), + ShapeUtil::MakeShape(updates_shape.element_type(), {}), + scatter_updates.size()); + + std::vector sort_operands = {scalar_indices}; + std::vector sort_shapes = {scalar_index_shape}; + for (auto update : scatter_updates) { + sort_operands.push_back(update); + sort_shapes.push_back(update->shape()); + } + + auto* sorting = parent->AddInstruction(HloInstruction::CreateSort( + ShapeUtil::MakeTupleShape(sort_shapes), 0, sort_operands, comparison, + /*is_stable=*/false)); + auto* sorted_scalar_indices = + parent->AddInstruction(HloInstruction::CreateGetTupleElement( + scalar_indices->shape(), sorting, 0)); + + std::vector sorted_updates(scatter_updates.size()); + for (int i = 0; i < scatter_updates.size(); i++) { + sorted_updates[i] = + parent->AddInstruction(HloInstruction::CreateGetTupleElement( + scatter_updates[i]->shape(), sorting, i + 1)); + } + std::vector sorted_tensors = {sorted_scalar_indices}; + sorted_tensors.insert(sorted_tensors.end(), sorted_updates.begin(), + sorted_updates.end()); + return sorted_tensors; +} + +// CreateScanWithIndices performs a prefix scan operation (akin to parallel +// prefix sum) on the updates and indices, to compute the accumulated updates in +// log(n) time. +// +// High-level algorithm: +// +// Iteration through log2(num_updates): +// - For each iteration, the `updates` tensor will be sliced and padded to +// perform shifting by `offset`. +// - Similarly, the `indices` tensor is also sliced and padded. +// - A mask is created that compares each element of shifted `indices` and +// original `indices` are equal (used to avoid combining updates from +// different indices). +// - The `to_apply` function is used to combine the original and shifted +// updates to generate a combined update tensor. +// - Based on the mask, the new update tensor will choose from either the +// combined update or the original update. +// - The result becomes the `new_updates`, which is then used as the +// input for the next iteration. +static absl::StatusOr CreateScanWithIndices( + HloComputation* parent, HloInstruction* updates, HloInstruction* indices, + HloComputation* to_apply) { + const Shape& updates_shape = updates->shape(); + const Shape& indices_shape = indices->shape(); + // Get the length of the input array + int64_t num_updates = updates_shape.dimensions(0); + + // Calculate the number of iterations needed (log_2(n)) + int64_t log_n = Log2Ceiling(static_cast(num_updates)); + + // Start to traverse + HloInstruction* prev_updates = updates; + HloInstruction* prev_indices = indices; + HloInstruction* new_updates = nullptr; + + std::vector start_indices = {0}; + std::vector strides = {1}; + + for (int64_t iteration = 0; iteration < log_n; ++iteration) { + int64_t offset = static_cast(1) << iteration; + std::vector end_indices = {num_updates - offset}; + + auto shifted_updates_shape = ShapeUtil::MakeShape( + updates_shape.element_type(), {num_updates - offset}); + auto padding_updates_shape = + ShapeUtil::MakeShape(updates_shape.element_type(), {offset}); + + auto shifted_indices_shape = ShapeUtil::MakeShape( + indices_shape.element_type(), {num_updates - offset}); + auto padding_indices_shape = + ShapeUtil::MakeShape(indices_shape.element_type(), {offset}); + + auto* shifted_updates = parent->AddInstruction( + HloInstruction::CreateSlice(shifted_updates_shape, prev_updates, + start_indices, end_indices, strides)); + auto* padding_updates = + parent->AddInstruction(HloInstruction::CreateBroadcast( + padding_updates_shape, + parent->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(updates_shape.element_type(), 0))), + {})); + + auto* shifted_indices = parent->AddInstruction( + HloInstruction::CreateSlice(shifted_indices_shape, prev_indices, + start_indices, end_indices, strides)); + auto* padding_indices = + parent->AddInstruction(HloInstruction::CreateBroadcast( + padding_indices_shape, + parent->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(indices_shape.element_type(), 0))), + {})); + + auto* concatenated_updates = + parent->AddInstruction(HloInstruction::CreateConcatenate( + updates_shape, {padding_updates, shifted_updates}, 0)); + auto* concatenated_indices = + parent->AddInstruction(HloInstruction::CreateConcatenate( + indices_shape, {padding_indices, shifted_indices}, 0)); + + auto* indices_mask = parent->AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {num_updates}), prev_indices, + concatenated_indices, ComparisonDirection::kEq)); + std::vector map_operands = {prev_updates, + concatenated_updates}; + TF_ASSIGN_OR_RETURN(HloInstruction * reduced_updates, + MakeMapHlo(map_operands, to_apply)); + new_updates = parent->AddInstruction(HloInstruction::CreateTernary( + updates_shape, HloOpcode::kSelect, indices_mask, reduced_updates, + prev_updates)); + prev_updates = new_updates; + } + return new_updates; +} + +absl::StatusOr> ComputePrefixScan( + const std::vector& sorted_updates, + HloInstruction* sorted_scalar_indices, HloScatterInstruction* scatter, + HloComputation* parent) { + std::vector prefix_scans(sorted_updates.size()); + for (int i = 0; i < sorted_updates.size(); i++) { + // TODO(chenhao) change to use the extracted computation + TF_ASSIGN_OR_RETURN( + HloComputation * to_apply, + CallComputationAndGetIthOutputWithBinaryParams(scatter->to_apply(), i)); + TF_ASSIGN_OR_RETURN(prefix_scans[i], + CreateScanWithIndices(parent, sorted_updates[i], + sorted_scalar_indices, to_apply)); + } + return prefix_scans; +} + +static HloInstruction* FindLastOccurrenceIndices( + HloInstruction* scatter_indices, HloInstruction* sorted_scalar_indices, + HloInstruction* scatter, HloComputation* parent, int64_t num_indices) { + int64_t indices_len = sorted_scalar_indices->shape().dimensions(0); + HloInstruction* sorted_indices = sorted_scalar_indices; + auto* sorted_indices_preceding_part = + parent->AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(scatter_indices->shape().element_type(), + {indices_len - 1}), + sorted_scalar_indices, {0}, {indices_len - 1}, {1})); + auto* sorted_indices_following_part = + parent->AddInstruction(HloInstruction::CreateSlice( + ShapeUtil::MakeShape(scatter_indices->shape().element_type(), + {indices_len - 1}), + sorted_scalar_indices, {1}, {indices_len}, {1})); + auto* indices_mask_without_padding = + parent->AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {indices_len - 1}), + sorted_indices_preceding_part, sorted_indices_following_part, + ComparisonDirection::kNe)); + // Pad the comparison with a true value at the end + auto* true_constant = parent->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + auto* padding = parent->AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(PRED, {1}), true_constant, {})); + std::vector padding_operands = {indices_mask_without_padding, + padding}; + auto* indices_mask = parent->AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(PRED, {indices_len}), padding_operands, 0)); + + // Mask the indices + indices_mask = parent->AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(PRED, scatter_indices->shape().dimensions()), + indices_mask, {0})); + + auto* out_of_bound_tensor = + CreateOutOfBoundTensor(parent, scatter_indices, scatter->shape()); + + auto* masked_indices = parent->AddInstruction(HloInstruction::CreateTernary( + sorted_indices->shape(), HloOpcode::kSelect, indices_mask, sorted_indices, + out_of_bound_tensor)); + return masked_indices; +} + +absl::StatusOr ScatterDeterminismExpander::ExpandInstruction( + HloInstruction* inst) { + auto* scatter = Cast(inst); + auto scatter_operands = scatter->scatter_operands(); + HloInstruction* scatter_indices = scatter->scatter_indices(); + std::vector scatter_updates( + scatter->scatter_updates().begin(), scatter->scatter_updates().end()); + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + + // If the updates tensors are empty, there is no need to update the operands. + // The operands can be forwarded. + if (ShapeUtil::IsZeroElementArray(scatter_updates[0]->shape())) { + if (scatter_operands.size() == 1) { + return scatter_operands[0]; + } + return scatter->parent()->AddInstruction( + HloInstruction::CreateTuple(scatter_operands)); + } + + // Compute the trip count for the while loop to be used for scatter. This + // should be the number of indices we should scatter into the operand. + int64_t scatter_indices_count = ScatterIndicesCount(scatter); + if (!IsInt32(scatter_indices_count)) { + // 2147483647 is the maximum value for a 32-bit signed integer (INT32_MAX). + return Unimplemented( + "Scatter operations with more than 2147483647 scatter indices are not " + "supported. This error occurred for %s.", + scatter->ToString()); + } + + // Canonicalize the scatter_indices, after which the size of its most-major + // dimension must be same as the while loop trip count. + TF_ASSIGN_OR_RETURN(scatter_indices, + CanonicalizeScatterIndices( + scatter_indices, dim_numbers.index_vector_dim())); + CHECK_EQ(scatter_indices_count, scatter_indices->shape().dimensions(0)); + + // Canonicalize the updates, after which the size of their most-major + // dimensions must be same as the while loop trip count. + TF_ASSIGN_OR_RETURN(scatter_updates, CanonicalizeScatterUpdates( + scatter_updates, scatter_indices, + dim_numbers, scatter_indices_count)); + + HloComputation* parent = scatter->parent(); + + // Sort the scatter indices and updates together based on the scatter indices. + int64_t num_indices = ShapeUtil::ElementsIn(scatter_updates[0]->shape()); + std::vector sorted_tensors = SortIndicesAndUpdates( + scatter_indices, scatter_updates, num_indices, scatter, parent); + HloInstruction* sorted_scalar_indices = sorted_tensors[0]; + std::vector sorted_updates(sorted_tensors.begin() + 1, + sorted_tensors.end()); + + TF_ASSIGN_OR_RETURN(std::vector prefix_scan_updates, + ComputePrefixScan(sorted_updates, sorted_scalar_indices, + scatter, parent)); + + HloInstruction* last_occurrence_indices = FindLastOccurrenceIndices( + scatter_indices, sorted_scalar_indices, scatter, parent, num_indices); + + // Finally, recreate the scatter instruction with unique indices + return parent->AddInstruction(HloInstruction::CreateScatter( + scatter->shape(), scatter_operands, last_occurrence_indices, + prefix_scan_updates, scatter->to_apply(), dim_numbers, + /*indices_are_sorted=*/true, /*unique_indices=*/true)); +} + +namespace { +void RecursivelyGetInputParamNumbers( + const HloInstruction* instruction, std::vector& param_numbers, + absl::flat_hash_set& visited) { + if (!visited.emplace(instruction).second) { + return; + } + + if (instruction->opcode() == HloOpcode::kParameter) { + param_numbers.push_back(instruction->parameter_number()); + return; + } + for (HloInstruction* operand : instruction->operands()) { + RecursivelyGetInputParamNumbers(operand, param_numbers, visited); + } +} + +// Check if every output of the scatter computation only depends on the +// corresponding operand and updates +bool CheckOutputDependency(HloComputation* to_apply, int operand_size) { + HloInstruction* root = to_apply->root_instruction(); + if (!root->shape().IsTuple()) { + return true; + } + CHECK_EQ(operand_size, root->operand_count()); + + // traverse the tuple output of the computation + for (int i = 0; i < operand_size; ++i) { + const HloInstruction* output = root->operand(i); + std::vector param_numbers; + absl::flat_hash_set visited; + RecursivelyGetInputParamNumbers(output, param_numbers, visited); + // The input dependencies can be at most 2 + if (param_numbers.size() > 2) { + return false; + } + for (int64_t param_number : param_numbers) { + if (param_number != i && param_number != operand_size + i) { + return false; + } + } + } + return true; +} + +} // namespace + +bool ScatterDeterminismExpander::InstructionMatchesPattern( + HloInstruction* inst) { + auto* scatter = DynCast(inst); + // Need to check if updates and indices are scalar, as the current pass does + // not expand scatter with multi-dimensional updates or indices. This is + // temporary and will be removed in a future PR soon. + if (scatter == nullptr) { + return false; + } + + const Shape& indices_shape = scatter->scatter_indices()->shape(); + const Shape& updates_shape = scatter->scatter_updates()[0]->shape(); + + // Check if indices and updates are effectively 1D. + bool indices_are_1d = + (indices_shape.rank() == 1 || + (indices_shape.rank() == 2 && indices_shape.dimensions(1) == 1)); + bool updates_are_1d = + (updates_shape.rank() == 1 || + (updates_shape.rank() == 2 && updates_shape.dimensions(1) == 1)); + + return indices_are_1d && updates_are_1d && !IsScatterDeterministic(scatter) && + CheckOutputDependency(scatter->to_apply(), + scatter->scatter_operands().size()); +} + +} // namespace xla diff --git a/third_party/xla/xla/service/scatter_determinism_expander.h b/third_party/xla/xla/service/scatter_determinism_expander.h new file mode 100644 index 00000000000000..62c80346cad87f --- /dev/null +++ b/third_party/xla/xla/service/scatter_determinism_expander.h @@ -0,0 +1,44 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_SCATTER_DETERMINISM_EXPANDER_H_ +#define XLA_SERVICE_SCATTER_DETERMINISM_EXPANDER_H_ + +#include "xla/hlo/transforms/expanders/op_expander_pass.h" + +namespace xla { + +// This pass rewrites scatter operations into a prefix-scan based algorithm that +// ensures the scatter results to be determininstic. Note that the computation +// after the expansion still contains a scatter operation, but it does not have +// duplicated indices and hence the results are guaranteed to be deterministic. +class ScatterDeterminismExpander : public OpExpanderPass { + public: + explicit ScatterDeterminismExpander() = default; + + absl::string_view name() const override { + return "scatter_determinism_expander"; + } + + protected: + bool InstructionMatchesPattern(HloInstruction* inst) override; + + absl::StatusOr ExpandInstruction( + HloInstruction* inst) override; +}; + +} // namespace xla + +#endif // XLA_SERVICE_SCATTER_DETERMINISM_EXPANDER_H_ diff --git a/third_party/xla/xla/service/scatter_determinism_expander_test.cc b/third_party/xla/xla/service/scatter_determinism_expander_test.cc new file mode 100644 index 00000000000000..23e7e87d4bcce9 --- /dev/null +++ b/third_party/xla/xla/service/scatter_determinism_expander_test.cc @@ -0,0 +1,330 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/scatter_determinism_expander.h" + +#include +#include +#include + +#include "xla/literal.h" +#include "xla/test.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +class ScatterDeterminismExpanderTest : public HloTestBase {}; + +TEST_F(ScatterDeterminismExpanderTest, + DoNotEliminateScatterWithAssociativeCombiner) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = s32[] parameter(1) + arg0.172 = s32[] parameter(0) + ROOT add.48 = s32[] add(arg0.172, arg1.173) + } + + ENTRY fused_computation { + bitcast.2335 = s32[1,4096] parameter(0) + pad.96 = s32[4096,2] parameter(1) + bitcast.2748 = s32[4096,1,1] parameter(2) + ROOT scatter.48 = s32[1,4096] scatter(bitcast.2335, pad.96, bitcast.2748), + update_window_dims={1,2}, inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + EXPECT_FALSE(result); +} + +TEST_F(ScatterDeterminismExpanderTest, + EliminateScatterWithNonAssociativeCombiner) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY fused_computation { + bitcast.2335 = f32[4096] parameter(0) + pad.96 = s32[4096,1] parameter(1) + bitcast.2748 = f32[4096] parameter(2) + ROOT scatter.48 = f32[4096] scatter(bitcast.2335, pad.96, bitcast.2748), + update_window_dims={}, inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + EXPECT_TRUE(result); +} + +TEST_F(ScatterDeterminismExpanderTest, + DoNotEliminateScatterWithAssociativeFp32Combiner) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT max.48 = f32[] maximum(arg0.172, arg1.173) + } + + ENTRY fused_computation { + bitcast.2335 = f32[1,4096] parameter(0) + pad.96 = s32[4096,2] parameter(1) + bitcast.2748 = f32[4096,1,1] parameter(2) + ROOT scatter.48 = f32[1,4096] scatter(bitcast.2335, pad.96, bitcast.2748), + update_window_dims={1,2}, inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + EXPECT_FALSE(result); +} + +TEST_F(ScatterDeterminismExpanderTest, ScatterAddCorrectnessTest) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[4] constant({0, 0, 0, 0}) + indices = s32[7,1] constant({{0}, {1}, {2}, {3}, {1}, {1}, {2}}) + updates = f32[7] constant({2, 1, 5, 3, 8, 7, 9}) + ROOT scatter.48 = f32[4] scatter(operand, indices, updates), + update_window_dims={}, inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + std::vector expected_result = {2.0, 16.0, 14.0, 3.0}; + + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, ScatterAddHloVerificationTest) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[2] constant({0, 0}) + indices = s32[3,1] constant({{0}, {1}, {1}}) + updates = f32[3] constant({2, 1, 5}) + ROOT scatter.48 = f32[2] scatter(operand, indices, updates), + update_window_dims={}, inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + const char* const kExpectedPattern = R"( + CHECK: ENTRY %scatter_add_computation () -> f32[2] { + CHECK-DAG: %[[INDICES:.*]] = s32[3,1]{1,0} constant({ {0}, {1}, {1} }) + CHECK-DAG: %[[RESHAPE:.*]] = s32[3]{0} reshape(%[[INDICES]]) + CHECK-DAG: %[[OPERAND:.*]] = f32[2]{0} constant({0, 0}) + CHECK-DAG: %[[RESHAPE1:.*]] = s32[3]{0} reshape(%[[INDICES]]) + CHECK-DAG: %[[UPDATES:.*]] = f32[3]{0} constant({2, 1, 5}) + CHECK-DAG: %[[TRANSPOSE:.*]] = f32[3]{0} transpose(%[[UPDATES]]), dimensions={0} + CHECK-DAG: %[[RESHAPE2:.*]] = f32[3]{0} reshape(%[[TRANSPOSE]]) + CHECK-DAG: %[[SORT:.*]] = (s32[3]{0}, f32[3]{0}) sort(%[[RESHAPE1]], %[[RESHAPE2]]), dimensions={0}, to_apply=%sorting_computation + CHECK-DAG: %[[GET_TUPLE_ELEMENT:.*]] = s32[3]{0} get-tuple-element(%[[SORT]]), index=0 + CHECK-DAG: %[[SLICE4:.*]] = s32[2]{0} slice(%[[GET_TUPLE_ELEMENT]]), slice={[0:2]} + CHECK-DAG: %[[SLICE5:.*]] = s32[2]{0} slice(%[[GET_TUPLE_ELEMENT]]), slice={[1:3]} + CHECK-DAG: %[[COMPARE3:.*]] = pred[2]{0} compare(%[[SLICE4]], %[[SLICE5]]), direction=NE + CHECK-DAG: %[[CONSTANT4:.*]] = pred[] constant(true) + CHECK-DAG: %[[BROADCAST4:.*]] = pred[1]{0} broadcast(%[[CONSTANT4]]), dimensions={} + CHECK-DAG: %[[CONCAT_COMPARE4:.*]] = pred[3]{0} concatenate(%[[COMPARE3]], %[[BROADCAST4]]), dimensions={0} + CHECK-DAG: %[[BROADCAST5:.*]] = pred[3]{0} broadcast(%[[CONCAT_COMPARE4]]), dimensions={0} + CHECK-DAG: %[[CONSTANT5:.*]] = s32[3]{0} constant({2, 2, 2}) + CHECK-DAG: %[[SELECT2:.*]] = s32[3]{0} select(%[[BROADCAST5]], %[[GET_TUPLE_ELEMENT]], %[[CONSTANT5]]) + CHECK-DAG: %[[CONSTANT3:.*]] = s32[] constant(0) + CHECK-DAG: %[[BROADCAST3:.*]] = s32[2]{0} broadcast(%[[CONSTANT3]]), dimensions={} + CHECK-DAG: %[[SLICE3:.*]] = s32[1]{0} slice(%[[GET_TUPLE_ELEMENT]]), slice={[0:1]} + CHECK-DAG: %[[CONCAT3:.*]] = s32[3]{0} concatenate(%[[BROADCAST3]], %[[SLICE3]]), dimensions={0} + CHECK-DAG: %[[COMPARE2:.*]] = pred[3]{0} compare(%[[GET_TUPLE_ELEMENT]], %[[CONCAT3]]), direction=EQ + CHECK-DAG: %[[CONSTANT1:.*]] = s32[] constant(0) + CHECK-DAG: %[[BROADCAST1:.*]] = s32[1]{0} broadcast(%[[CONSTANT1]]), dimensions={} + CHECK-DAG: %[[SLICE1:.*]] = s32[2]{0} slice(%[[GET_TUPLE_ELEMENT]]), slice={[0:2]} + CHECK-DAG: %[[CONCAT1:.*]] = s32[3]{0} concatenate(%[[BROADCAST1]], %[[SLICE1]]), dimensions={0} + CHECK-DAG: %[[COMPARE1:.*]] = pred[3]{0} compare(%[[GET_TUPLE_ELEMENT]], %[[CONCAT1]]), direction=EQ + CHECK-DAG: %[[GET_TUPLE_ELEMENT1:.*]] = f32[3]{0} get-tuple-element(%[[SORT]]), index=1 + CHECK-DAG: %[[CONSTANT_F32:.*]] = f32[] constant(0) + CHECK-DAG: %[[BROADCAST_F32:.*]] = f32[1]{0} broadcast(%[[CONSTANT_F32]]), dimensions={} + CHECK-DAG: %[[SLICE_F32:.*]] = f32[2]{0} slice(%[[GET_TUPLE_ELEMENT1]]), slice={[0:2]} + CHECK-DAG: %[[CONCAT_F32:.*]] = f32[3]{0} concatenate(%[[BROADCAST_F32]], %[[SLICE_F32]]), dimensions={0} + CHECK-DAG: %[[MAP:.*]] = f32[3]{0} map(%[[GET_TUPLE_ELEMENT1]], %[[CONCAT_F32]]), dimensions={0}, to_apply=%scatter_computation + CHECK-DAG: %[[SELECT:.*]] = f32[3]{0} select(%[[COMPARE1]], %[[MAP]], %[[GET_TUPLE_ELEMENT1]]) + CHECK-DAG: %[[CONSTANT2:.*]] = f32[] constant(0) + CHECK-DAG: %[[BROADCAST2:.*]] = f32[2]{0} broadcast(%[[CONSTANT2]]), dimensions={} + CHECK-DAG: %[[SLICE2:.*]] = f32[1]{0} slice(%[[SELECT]]), slice={[0:1]} + CHECK-DAG: %[[CONCAT2:.*]] = f32[3]{0} concatenate(%[[BROADCAST2]], %[[SLICE2]]), dimensions={0} + CHECK-DAG: %[[MAP1:.*]] = f32[3]{0} map(%[[SELECT]], %[[CONCAT2]]), dimensions={0}, to_apply=%scatter_computation + CHECK-DAG: %[[SELECT1:.*]] = f32[3]{0} select(%[[COMPARE2]], %[[MAP1]], %[[SELECT]]) + CHECK-DAG: ROOT %[[SCATTER:.*]] = f32[2]{0} scatter(%[[OPERAND]], %[[SELECT2]], %[[SELECT1]]), + CHECK-SAME: update_window_dims={}, + CHECK-SAME: inserted_window_dims={0}, + CHECK-SAME: scatter_dims_to_operand_dims={0}, + CHECK-SAME: index_vector_dim=1, + CHECK-SAME: indices_are_sorted=true, + CHECK-SAME: unique_indices=true, + CHECK-SAME: to_apply=%scatter_computation + )"; + + RunAndFilecheckHloRewrite(kModuleStr, ScatterDeterminismExpander(), + kExpectedPattern); +} + +TEST_F(ScatterDeterminismExpanderTest, ScatterAddOutOfBoundCorrectnessTest) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[4] constant({0, 0, 0, 0}) + indices = s32[7,1] constant({{0}, {1}, {5}, {4}, {1}, {1}, {2}}) + updates = f32[7] constant({2, 1, 5, 3, 8, 7, 9}) + ROOT scatter.48 = f32[4] scatter(operand, indices, updates), + update_window_dims={}, inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + std::vector expected_result = {2.0, 16.0, 9.0, 0.0}; + + Literal result_literal = ExecuteAndTransfer(std::move(module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, expected_result); +} + +TEST_F(ScatterDeterminismExpanderTest, ScatterAddReproducibilityTest) { + const char* const kModuleStr = R"( + HloModule scatter_determinism_expander + + scatter_computation { + arg1.173 = f32[] parameter(1) + arg0.172 = f32[] parameter(0) + ROOT add.48 = f32[] add(arg0.172, arg1.173) + } + + ENTRY scatter_add_computation { + operand = f32[3] constant({0, 0, 0}) + indices = s32[100,1] constant({{0}, {3}, {0}, {1}, {0}, {3}, {1}, {2}, {1}, {2}, {2}, {2}, {0}, {2}, {1}, {0}, {1}, {1}, {2}, {0}, {2}, {1}, {2}, {1}, {2}, {2}, {3}, {2}, {2}, {0}, {3}, {0}, {3}, {2}, {0}, {3}, {3}, {3}, {3}, {3}, {2}, {3}, {3}, {0}, {0}, {3}, {3}, {3}, {2}, {3}, {2}, {3}, {0}, {0}, {2}, {0}, {1}, {3}, {1}, {3}, {2}, {2}, {2}, {1}, {0}, {3}, {1}, {1}, {1}, {1}, {1}, {2}, {2}, {3}, {0}, {2}, {2}, {0}, {2}, {1}, {0}, {2}, {2}, {2}, {0}, {2}, {0}, {1}, {3}, {0}, {2}, {3}, {3}, {2}, {0}, {3}, {3}, {2}, {3}, {2}}) + updates = f32[100] constant({0.02379167, 0.8527204, 0.8132185, 0.5140263, 0.17172801, 0.8026866, 0.5124631, 0.34838438, 0.50526905, 0.3370521, 0.10868239, 0.10520637, 0.83827364, 0.78986526, 0.34059846, 0.8349273, 0.24575627, 0.21387374, 0.02423227, 0.5617423, 0.28066766, 0.94366455, 0.61214995, 0.7383388, 0.52419806, 0.65466726, 0.41012764, 0.24028647, 0.74443066, 0.03544927, 0.851014, 0.02434528, 0.47239733, 0.72706807, 0.35055435, 0.6274171, 0.61077535, 0.06525731, 0.8091929, 0.21307838, 0.6465323, 0.3245015, 0.5538883, 0.8849807, 0.9591211, 0.83856845, 0.48919427, 0.11810577, 0.16933143, 0.83657074, 0.587505, 0.6867087, 0.95522237, 0.5797727, 0.28024232, 0.34749162, 0.5199702, 0.9811766, 0.5645981, 0.2446456, 0.68722725, 0.9616587, 0.480047, 0.88953114, 0.7083205, 0.948612, 0.67764974, 0.44131804, 0.36789334, 0.95148766, 0.30909216, 0.70908046, 0.8749926, 0.60973287, 0.60751855, 0.22647333, 0.5363518, 0.96195626, 0.08158326, 0.5266887, 0.85922587, 0.648262, 0.4657668, 0.31623375, 0.43507564, 0.48351157, 0.41285944, 0.73501325, 0.15267539, 0.67055714, 0.08459568, 0.04527426, 0.21078384, 0.4654404, 0.7363906, 0.23245859, 0.22119188, 0.99092937, 0.878675, 0.4102913}) + ROOT scatter.48 = f32[3] scatter(operand, indices, updates), + update_window_dims={}, inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, index_vector_dim=1, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + ScatterDeterminismExpander scatter_determinism_expander; + TF_ASSERT_OK_AND_ASSIGN( + bool result, RunHloPass(&scatter_determinism_expander, module.get())); + + EXPECT_TRUE(result); + + auto cloned_module = module->Clone(); + Literal first_result_literal = + ExecuteAndTransfer(std::move(cloned_module), {}); + auto first_result_span = first_result_literal.data(); + std::vector first_result(first_result_span.begin(), + first_result_span.end()); + + const int num_trials = 20; + std::vector> results; + + for (int i = 0; i < num_trials; ++i) { + auto cloned_module = module->Clone(); + + Literal result_literal = ExecuteAndTransfer(std::move(cloned_module), {}); + + auto result_data = result_literal.data(); + std::vector actual_result(result_data.begin(), result_data.end()); + + EXPECT_EQ(actual_result, first_result) + << "Results are not reproducible across trials!"; + } +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/service/scatter_expander.cc b/third_party/xla/xla/service/scatter_expander.cc index cd7f72c64c1777..01ebae5dd533dd 100644 --- a/third_party/xla/xla/service/scatter_expander.cc +++ b/third_party/xla/xla/service/scatter_expander.cc @@ -15,8 +15,13 @@ limitations under the License. #include "xla/service/scatter_expander.h" +#include +#include +#include + #include "absl/algorithm/container.h" #include "absl/status/statusor.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -24,154 +29,14 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal_util.h" -#include "xla/service/call_inliner.h" +#include "xla/service/gather_scatter_utils.h" #include "xla/service/hlo_creation_utils.h" +#include "xla/service/scatter_utils.h" #include "xla/service/while_util.h" +#include "xla/shape.h" namespace xla { -// Transposes the given scatter_indices such that the index_vector_dim becomes -// the most-minor dimension. -static absl::StatusOr TransposeIndexVectorDimToLast( - HloInstruction* scatter_indices, int64_t index_vector_dim) { - const Shape& scatter_indices_shape = scatter_indices->shape(); - - if (scatter_indices_shape.dimensions_size() == index_vector_dim) { - return scatter_indices; - } - - if (index_vector_dim == (scatter_indices_shape.dimensions_size() - 1)) { - return scatter_indices; - } - - std::vector permutation; - permutation.reserve(scatter_indices_shape.dimensions_size()); - for (int64_t i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { - if (i != index_vector_dim) { - permutation.push_back(i); - } - } - permutation.push_back(index_vector_dim); - return MakeTransposeHlo(scatter_indices, permutation); -} - -// Canonicalizes the scatter_indices tensor in order to keep them uniform while -// performing the scatter operation. -static absl::StatusOr CanonicalizeScatterIndices( - HloInstruction* scatter_indices, int64_t index_vector_dim) { - // Transpose the non-index-vector dimensions to the front. - TF_ASSIGN_OR_RETURN( - HloInstruction * transposed_scatter_indices, - TransposeIndexVectorDimToLast(scatter_indices, index_vector_dim)); - if (scatter_indices->shape().rank() == index_vector_dim + 1 && - scatter_indices->shape().dimensions(index_vector_dim) == 1) { - auto new_shape = - ShapeUtil::DeleteDimension(index_vector_dim, scatter_indices->shape()); - TF_ASSIGN_OR_RETURN(scatter_indices, - MakeReshapeHlo(new_shape, scatter_indices)); - } - bool indices_are_scalar = - index_vector_dim == scatter_indices->shape().dimensions_size(); - - // The number of dimensions in scatter_indices that are index dimensions. - const int64_t index_dims_in_scatter_indices = indices_are_scalar ? 0 : 1; - - // If there is only one index (i.e. scatter_indices has rank 1 and this - // scatter is really just a dynamic update slice) add a leading degenerate - // dimension for uniformity. Otherwise create a "collapsed" leading dimension - // that subsumes all of the non-index-vector dimensions. - const Shape& shape = transposed_scatter_indices->shape(); - if (shape.dimensions_size() == index_dims_in_scatter_indices) { - return PrependDegenerateDims(transposed_scatter_indices, 1); - } else { - // Collapse all but the dimensions (0 or 1) in scatter_indices containing - // the index vectors. - return CollapseFirstNDims( - transposed_scatter_indices, - shape.dimensions_size() - index_dims_in_scatter_indices); - } -} - -// Permutes the `updates` tensor such that all the scatter dims appear in the -// major dimensions and all the window dimensions appear in the minor -// dimensions. -static absl::StatusOr PermuteScatterAndWindowDims( - HloInstruction* updates, absl::Span update_window_dims) { - std::vector permutation; - const int64_t updates_rank = updates->shape().rank(); - permutation.reserve(updates_rank); - - for (int64_t i = 0; i < updates_rank; ++i) { - bool is_scatter_dim = !absl::c_binary_search(update_window_dims, i); - if (is_scatter_dim) { - permutation.push_back(i); - } - } - for (auto window_dim : update_window_dims) { - permutation.push_back(window_dim); - } - - return MakeTransposeHlo(updates, permutation); -} - -// Expands or contracts the scatter indices in the updates tensor. -static absl::StatusOr AdjustScatterDims( - const Shape& scatter_indices_shape, HloInstruction* updates, - int64_t index_vector_dim) { - int64_t num_scatter_dims = scatter_indices_shape.dimensions_size(); - if (index_vector_dim < scatter_indices_shape.dimensions_size()) { - --num_scatter_dims; - } - if (num_scatter_dims == 0) { - // If there are no scatter dims, this must be a dynamic-update-slice kind of - // scatter. In this case, we prepend a degenerate dimension to work - // uniformly in the while loop. - return PrependDegenerateDims(updates, 1); - } - return CollapseFirstNDims(updates, num_scatter_dims); -} - -// Expands an index vector from the scatter_indices tensor into a vector that -// can be used to dynamic-update-slice to perform the scatter update. -static absl::StatusOr ExpandIndexVectorIntoOperandSpace( - HloInstruction* index_vector, const ScatterDimensionNumbers& dim_numbers, - int64_t operand_rank) { - HloComputation* computation = index_vector->parent(); - const Shape& index_shape = index_vector->shape(); - - // Scatter of a scalar. Return a zero-sized vector of indices. - if (operand_rank == 0) { - return computation->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0}))); - } - - HloInstruction* zero = - computation->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1}))); - - // We extract out individual components from the smaller index and concatenate - // them (interspersing zeros as needed) into the larger index. - std::vector expanded_index_components; - - for (int i = 0; i < operand_rank; i++) { - int64_t index_vector_dim_index = - FindIndex(dim_numbers.scatter_dims_to_operand_dims(), i); - if (index_vector_dim_index != - dim_numbers.scatter_dims_to_operand_dims_size()) { - TF_ASSIGN_OR_RETURN( - HloInstruction * component_to_concat, - MakeSliceHlo(index_vector, /*start_indices=*/{index_vector_dim_index}, - /*limit_indices=*/{index_vector_dim_index + 1}, - /*strides=*/{1})); - expanded_index_components.push_back(component_to_concat); - } else { - expanded_index_components.push_back(zero); - } - } - - return MakeConcatHlo(expanded_index_components, /*dimension=*/0); -} - static absl::StatusOr CheckIndexValidity( HloComputation* computation, HloInstruction* index, absl::Span operand_dims, @@ -218,31 +83,21 @@ static absl::StatusOr CheckIndexValidity( return MakeBroadcastHlo(valid_index_reduced, {}, window_sizes); } -static absl::StatusOr CallAndGetOutput( - HloComputation* original, int output_index) { - HloInstruction* original_root = original->root_instruction(); - if (!original_root->shape().IsTuple()) { - return original; - } - HloComputation* new_comp = [&] { - HloComputation::Builder builder( - absl::StrCat(original->name(), ".dup.", output_index)); - for (int i = 0, n = original->num_parameters(); i < n; ++i) { - HloInstruction* original_param = original->parameter_instruction(i); - builder.AddInstruction(HloInstruction::CreateParameter( - i, original_param->shape(), original_param->name())); - } - return original->parent()->AddEmbeddedComputation(builder.Build()); - }(); - HloInstruction* call_original = new_comp->AddInstruction( - HloInstruction::CreateCall(original_root->shape(), - new_comp->parameter_instructions(), original)); - new_comp->set_root_instruction( - new_comp->AddInstruction( - HloInstruction::CreateGetTupleElement(call_original, output_index)), - /*accept_different_shape=*/true); - TF_RETURN_IF_ERROR(CallInliner::Inline(call_original).status()); - return new_comp; +// Returns the sorted dimensions in a slice that are either collapsed or +// corresponding to an explicit batching dimension. +std::vector GetDegeneratedSliceDims( + const ScatterDimensionNumbers& dim_numbers) { + absl::Span input_batching_dims = + dim_numbers.input_batching_dims(); + absl::Span inserted_window_dims = + dim_numbers.inserted_window_dims(); + std::vector degenerated_dims; + degenerated_dims.reserve(inserted_window_dims.size() + + input_batching_dims.size()); + absl::c_copy(inserted_window_dims, std::back_inserter(degenerated_dims)); + absl::c_copy(input_batching_dims, std::back_inserter(degenerated_dims)); + absl::c_sort(degenerated_dims); + return degenerated_dims; } // Body of the while loop that performs the scatter operation using other HLOs. @@ -286,7 +141,12 @@ static absl::StatusOr> ScatterLoopBody( TF_ASSIGN_OR_RETURN( HloInstruction * scatter_slice_start, ExpandIndexVectorIntoOperandSpace( - index_vector, dim_numbers, operands[0]->shape().dimensions_size())); + scatter->scatter_indices()->shape(), + operands[0]->shape().dimensions_size(), + dim_numbers.index_vector_dim(), + dim_numbers.scatter_dims_to_operand_dims(), + dim_numbers.scatter_indices_batching_dims(), + dim_numbers.input_batching_dims(), index_vector, induction_var)); // Extract the slice to be used to update from `updates` tensor for the // induction_var corresponding to this iteration of the while loop. @@ -307,6 +167,9 @@ static absl::StatusOr> ScatterLoopBody( auto update_slices_with_dims_inserted = absl::MakeSpan(map_operands).last(updates.size()); absl::Span actual_update_slice_dims; + + std::vector degenerated_dims = GetDegeneratedSliceDims(dim_numbers); + for (int i = 0, n = operands.size(); i < n; ++i) { HloInstruction* update = updates[i]; TF_ASSIGN_OR_RETURN( @@ -316,8 +179,7 @@ static absl::StatusOr> ScatterLoopBody( ElideDegenerateDims(update_slice, {0})); TF_ASSIGN_OR_RETURN( HloInstruction * update_slice_with_dims_inserted, - InsertDegenerateDims(update_slice_for_scatter, - dim_numbers.inserted_window_dims())); + InsertDegenerateDims(update_slice_for_scatter, degenerated_dims)); update_slices_with_dims_inserted[i] = update_slice_with_dims_inserted; // Note that the following transformation assumes that both DynamicSlice and // DynamicUpdateSlice follow the same semantics for OOB indices. For @@ -377,22 +239,6 @@ static absl::StatusOr> ScatterLoopBody( return updated_loop_state; } -static int64_t ScatterTripCount(const HloScatterInstruction* scatter) { - // Compute the trip count for the while loop to be used for scatter. This - // should be the number of indices we should scatter into the operand. - const HloInstruction* scatter_indices = scatter->scatter_indices(); - const Shape& scatter_indices_shape = scatter_indices->shape(); - const ScatterDimensionNumbers& dim_numbers = - scatter->scatter_dimension_numbers(); - int64_t scatter_loop_trip_count = 1; - for (int64_t i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { - if (i != dim_numbers.index_vector_dim()) { - scatter_loop_trip_count *= scatter_indices_shape.dimensions(i); - } - } - return scatter_loop_trip_count; -} - // High Level Algorithm. // // 1. Canonicalize the scatter_indices tensor such that it has rank 2, where @@ -431,7 +277,7 @@ absl::StatusOr ScatterExpander::ExpandInstruction( // Compute the trip count for the while loop to be used for scatter. This // should be the number of indices we should scatter into the operand. - int64_t scatter_loop_trip_count = ScatterTripCount(scatter); + int64_t scatter_loop_trip_count = ScatterIndicesCount(scatter); if (!IsInt32(scatter_loop_trip_count)) { return Unimplemented( "Scatter operations with more than 2147483647 scatter indices are not " @@ -485,48 +331,13 @@ absl::StatusOr ScatterExpander::ExpandInstruction( return MaybeMakeTuple(results); } -namespace { - -bool IsCombinerAssociative(const HloComputation* combiner) { - // Consider simple binary combiner functions only. - if (combiner->instruction_count() != 3) { - return false; - } - switch (combiner->root_instruction()->opcode()) { - // Minimum and Maximum are common associative combiners. - case HloOpcode::kMinimum: - case HloOpcode::kMaximum: - return true; - // Other common combiners are associative at least for integer arithmetic. - case HloOpcode::kAdd: - case HloOpcode::kMultiply: - case HloOpcode::kOr: - case HloOpcode::kXor: - return combiner->root_instruction()->shape().IsInteger(); - default: - return false; - } -} - -bool IsDeterministic(const HloScatterInstruction* scatter) { - if (scatter->unique_indices()) { - return true; - } - if (IsCombinerAssociative(scatter->to_apply())) { - return true; - } - return false; -} - -} // namespace - bool ScatterExpander::InstructionMatchesPattern(HloInstruction* inst) { auto* scatter = DynCast(inst); return (scatter != nullptr) && (mode_ == kEliminateAllScatters || (mode_ == kEliminateSimpleScatters && - ScatterTripCount(scatter) == 1) || + ScatterIndicesCount(scatter) == 1) || (mode_ == kEliminateIndeterministicScatters && - !IsDeterministic(scatter))); + !IsScatterDeterministic(scatter))); } } // namespace xla diff --git a/third_party/xla/xla/service/scatter_expander.h b/third_party/xla/xla/service/scatter_expander.h index 0658e962017454..fd19be4461b45f 100644 --- a/third_party/xla/xla/service/scatter_expander.h +++ b/third_party/xla/xla/service/scatter_expander.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_SCATTER_EXPANDER_H_ #define XLA_SERVICE_SCATTER_EXPANDER_H_ -#include "xla/service/op_expander_pass.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" namespace xla { diff --git a/third_party/xla/xla/service/scatter_expander_test.cc b/third_party/xla/xla/service/scatter_expander_test.cc index 4d135d3bb26dad..664f0112068fb8 100644 --- a/third_party/xla/xla/service/scatter_expander_test.cc +++ b/third_party/xla/xla/service/scatter_expander_test.cc @@ -16,11 +16,14 @@ limitations under the License. #include "xla/service/scatter_expander.h" #include +#include #include +#include #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/filecheck.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal.h" #include "xla/shape_util.h" @@ -65,7 +68,6 @@ TEST_F(ScatterExpanderTest, ScatterOperandWithoutLayout) { ParseAndReturnVerifiedModule(kModuleStr)); ClearInstructionLayout(module.get(), "operand"); - ScatterExpander scatter_expander(ScatterExpander::kEliminateAllScatters); TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&scatter_expander, module.get())); @@ -140,6 +142,85 @@ TEST_F(ScatterExpanderTest, EliminateSimpleScattersSkipsNontrivialScatter) { EXPECT_FALSE(result); } +TEST_F(ScatterExpanderTest, ScatterToLoopWithBatchDims) { + const char* kModuleStr = R"( +HloModule TensorFlowScatter + func { + x = s32[] parameter(0) + y = s32[] parameter(1) + ROOT s = s32[] add(x,y) + } + + ENTRY main { + indices = s32[2,3,5]{2,1,0} parameter(0) + update = s32[2,3,2,5]{3,2,1,0} parameter(1) + z = s32[] constant(0) + input = s32[5,3,2,2]{3,2,1,0} broadcast(z), dimensions={} + ROOT s = s32[5,3,2,2]{3,2,1,0} scatter(input, indices, update), + update_window_dims={2}, + inserted_window_dims={1}, + scatter_dims_to_operand_dims={1}, + index_vector_dim=3, + input_batching_dims={0,3}, + scatter_indices_batching_dims={2,0}, + to_apply=func + })"; + + // Verify the code that indexes into the operand. + const std::string expected = R"( + //CHECK: (s32[], s32[5,3,2,2], s32[30], s32[30,2])) -> (s32[], s32[5,3,2,2], s32[30], s32[30,2]) { + //CHECK: %[[PARAM:.*]] = (s32[], s32[5,3,2,2], s32[30], s32[30,2]) parameter(0) + //CHECK: %[[I:.*]] = s32[] get-tuple-element((s32[], s32[5,3,2,2], s32[30], s32[30,2]) %[[PARAM]]), index=0 + //CHECK: %[[CONSTANT1:.*]] = s32[] constant(1) + //CHECK: %[[I_PLUS_1:.*]] = s32[] add(s32[] %[[I]], s32[] %[[CONSTANT1]]) + //CHECK: %[[OPERAND:.*]] = s32[5,3,2,2] get-tuple-element((s32[], s32[5,3,2,2], s32[30], s32[30,2]) %[[PARAM]]), index=1 + + //CHECK: %[[CONSTANT0:.*]] = s32[] constant(0) + //CHECK: %[[OPERAND_INDICES_LOWER_BOUND:.*]] = s32[4] broadcast(s32[] %[[CONSTANT0]]) + //CHECK: %[[CONSTANT5:.*]] = s32[] constant(5) + //CHECK: %[[REMAINDER:.*]] = s32[] remainder(s32[] %[[I]], s32[] %[[CONSTANT5]]) + //CHECK: %[[BD2:.*]] = s32[1] broadcast(s32[] %[[REMAINDER]]) + //CHECK: %[[START_INDICES:.*]] = s32[30] get-tuple-element((s32[], s32[5,3,2,2], s32[30], s32[30,2]) %[[PARAM]]), index=2 + //CHECK: %[[I_1D_1:.*]] = s32[1] broadcast(s32[] %[[I]]) + //CHECK: %[[START_INDICES_INDEX_RAW:.*]] = s32[1] slice(s32[1] %[[I_1D_1]]) + //CHECK: %[[START_INDICES_INDEX:.*]] = s32[] reshape(s32[1] %[[START_INDICES_INDEX_RAW]]) + //CHECK: %[[INDEX_VECTOR:.*]] = s32[1] dynamic-slice(s32[30] %[[START_INDICES]], s32[] %[[START_INDICES_INDEX]]) + + //CHECK: %[[SCATTER_INDEX:.*]] = s32[1] slice(s32[1] %[[INDEX_VECTOR]]) + //CHECK: %[[CONSTANT0_2:.*]] = s32[1] constant({0}) + //CHECK: %[[BD_0_1:.*]] = s32[] divide(s32[] %[[I]], s32[] %[[CONSTANT5]]) + //CHECK: %[[CONSTANT3:.*]] = s32[] constant(3) + //CHECK: %[[BD0_RAW:.*]] = s32[] divide(s32[] %[[BD_0_1]], s32[] %[[CONSTANT3]]) + //CHECK: %[[BD0:.*]] = s32[1] broadcast(s32[] %[[BD0_RAW]]) + //CHECK: %[[OPERAND_INDICES:.*]] = s32[4] concatenate(s32[1] %[[BD2]], s32[1] %[[SCATTER_INDEX]], s32[1] %[[CONSTANT0_2]], s32[1] %[[BD0]]) + //CHECK: %[[OPERAND_INDEX_D0_RAW:.*]] = s32[1] slice(s32[4] %[[OPERAND_INDICES]]), slice={[0:1]} + //CHECK: %[[OPERAND_INDEX_D0:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D0_RAW]]) + //CHECK: %[[OPERAND_INDEX_D1_RAW:.*]] = s32[1] slice(s32[4] %[[OPERAND_INDICES]]), slice={[1:2]} + //CHECK: %[[OPERAND_INDEX_D1:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D1_RAW]]) + //CHECK: %[[OPERAND_INDEX_D2_RAW:.*]] = s32[1] slice(s32[4] %[[OPERAND_INDICES]]), slice={[2:3]} + //CHECK: %[[OPERAND_INDEX_D2:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D2_RAW]]) + //CHECK: %[[OPERAND_INDEX_D3_RAW:.*]] = s32[1] slice(s32[4] %[[OPERAND_INDICES]]), slice={[3:4]} + //CHECK: %[[OPERAND_INDEX_D3:.*]] = s32[] reshape(s32[1] %[[OPERAND_INDEX_D3_RAW]]) + //CHECK: %{{.*}} = s32[1,1,2,1] dynamic-slice(s32[5,3,2,2] %[[OPERAND]], s32[] %[[OPERAND_INDEX_D0]], s32[] %[[OPERAND_INDEX_D1]], s32[] %[[OPERAND_INDEX_D2]], s32[] %[[OPERAND_INDEX_D3]]) +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kModuleStr)); + ScatterExpander scatter_expander(ScatterExpander::kEliminateAllScatters); + TF_ASSERT_OK_AND_ASSIGN(bool result, + RunHloPass(&scatter_expander, module.get())); + EXPECT_TRUE(result); + + std::vector while_instructions = + FindInstructions(module.get(), HloOpcode::kWhile); + EXPECT_EQ(while_instructions.size(), 1); + HloComputation* while_body = while_instructions[0]->while_body(); + EXPECT_TRUE( + *RunFileCheck(while_body->ToString( + HloPrintOptions{}.set_include_layout_in_shapes(false)), + expected)); +} + TEST_F(ScatterExpanderTest, EliminateSimpleMultioutpuScattersSkipsNontrivialScatter) { const char* kModuleStr = R"( diff --git a/third_party/xla/xla/service/scatter_simplifier.h b/third_party/xla/xla/service/scatter_simplifier.h index 8b14e16abc9fff..42fbf443bfc88e 100644 --- a/third_party/xla/xla/service/scatter_simplifier.h +++ b/third_party/xla/xla/service/scatter_simplifier.h @@ -17,7 +17,7 @@ limitations under the License. #define XLA_SERVICE_SCATTER_SIMPLIFIER_H_ #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/op_expander_pass.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" namespace xla { diff --git a/third_party/xla/xla/service/scatter_simplifier_test.cc b/third_party/xla/xla/service/scatter_simplifier_test.cc index 919a68620fee42..994dc76d2c2d51 100644 --- a/third_party/xla/xla/service/scatter_simplifier_test.cc +++ b/third_party/xla/xla/service/scatter_simplifier_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/pass/hlo_pass_fix.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" -#include "xla/service/hlo_parser.h" #include "xla/tests/hlo_test_base.h" namespace xla { diff --git a/third_party/xla/xla/service/scatter_utils.cc b/third_party/xla/xla/service/scatter_utils.cc new file mode 100644 index 00000000000000..fe97c0d3e9f25d --- /dev/null +++ b/third_party/xla/xla/service/scatter_utils.cc @@ -0,0 +1,242 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/scatter_utils.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/call_inliner.h" +#include "xla/service/hlo_creation_utils.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { + +absl::StatusOr TransposeIndexVectorDimToLast( + HloInstruction* scatter_indices, int64_t index_vector_dim) { + const Shape& scatter_indices_shape = scatter_indices->shape(); + if (index_vector_dim >= (scatter_indices_shape.dimensions_size() - 1)) { + return scatter_indices; + } + + std::vector permutation; + permutation.reserve(scatter_indices_shape.dimensions_size()); + for (int64_t i = 0; i < scatter_indices_shape.dimensions_size(); i++) { + if (i != index_vector_dim) { + permutation.push_back(i); + } + } + permutation.push_back(index_vector_dim); + return MakeTransposeHlo(scatter_indices, permutation); +} + +absl::StatusOr PermuteScatterAndWindowDims( + HloInstruction* updates, absl::Span update_window_dims) { + std::vector permutation; + const int64_t updates_rank = updates->shape().rank(); + permutation.reserve(updates_rank); + + for (int64_t i = 0; i < updates_rank; ++i) { + bool is_scatter_dim = !absl::c_binary_search(update_window_dims, i); + if (is_scatter_dim) { + permutation.push_back(i); + } + } + for (int64_t window_dim : update_window_dims) { + permutation.push_back(window_dim); + } + + return MakeTransposeHlo(updates, permutation); +} + +// Expands or contracts the scatter indices in the updates tensor. +absl::StatusOr AdjustScatterDims( + const Shape& scatter_indices_shape, HloInstruction* updates, + int64_t index_vector_dim) { + int64_t num_scatter_dims = scatter_indices_shape.dimensions_size(); + if (index_vector_dim < scatter_indices_shape.dimensions_size()) { + --num_scatter_dims; + } + if (num_scatter_dims == 0) { + // If there are no scatter dims, this must be a dynamic-update-slice kind of + // scatter. In this case, we prepend a degenerate dimension to work + // uniformly in the while loop. + return PrependDegenerateDims(updates, 1); + } + return CollapseFirstNDims(updates, num_scatter_dims); +} + +absl::StatusOr CanonicalizeScatterIndices( + HloInstruction* scatter_indices, int64_t index_vector_dim) { + // Transpose the non-index-vector dimensions to the front. + TF_ASSIGN_OR_RETURN( + HloInstruction * transposed_scatter_indices, + TransposeIndexVectorDimToLast(scatter_indices, index_vector_dim)); + if (scatter_indices->shape().rank() - 1 == index_vector_dim && + scatter_indices->shape().dimensions(index_vector_dim) == 1) { + auto new_shape = + ShapeUtil::DeleteDimension(index_vector_dim, scatter_indices->shape()); + TF_ASSIGN_OR_RETURN(scatter_indices, + MakeReshapeHlo(new_shape, scatter_indices)); + } + bool indices_are_scalar = + index_vector_dim == scatter_indices->shape().dimensions_size(); + + // The number of dimensions in scatter_indices that are index dimensions. + const int64_t index_dims_in_scatter_indices = indices_are_scalar ? 0 : 1; + + // If there is only one index (i.e. scatter_indices has rank 1 and this + // scatter is really just a dynamic update slice) add a leading degenerate + // dimension for uniformity. Otherwise create a "collapsed" leading dimension + // that subsumes all of the non-index-vector dimensions. + const Shape& shape = transposed_scatter_indices->shape(); + if (shape.dimensions_size() == index_dims_in_scatter_indices) { + return PrependDegenerateDims(transposed_scatter_indices, 1); + } + // Collapse all but the dimensions (0 or 1) in scatter_indices containing + // the index vectors. + return CollapseFirstNDims( + transposed_scatter_indices, + shape.dimensions_size() - index_dims_in_scatter_indices); +} + +absl::StatusOr CallAndGetOutput(HloComputation* original, + int output_index) { + HloInstruction* original_root = original->root_instruction(); + if (!original_root->shape().IsTuple()) { + return original; + } + HloComputation* new_comp = [&] { + HloComputation::Builder builder( + absl::StrCat(original->name(), ".dup.", output_index)); + for (int i = 0, n = original->num_parameters(); i < n; ++i) { + HloInstruction* original_param = original->parameter_instruction(i); + builder.AddInstruction(HloInstruction::CreateParameter( + i, original_param->shape(), original_param->name())); + } + return original->parent()->AddEmbeddedComputation(builder.Build()); + }(); + HloInstruction* call_original = new_comp->AddInstruction( + HloInstruction::CreateCall(original_root->shape(), + new_comp->parameter_instructions(), original)); + new_comp->set_root_instruction( + new_comp->AddInstruction( + HloInstruction::CreateGetTupleElement(call_original, output_index)), + /*accept_different_shape=*/true); + TF_RETURN_IF_ERROR(CallInliner::Inline(call_original).status()); + return new_comp; +} + +absl::StatusOr CallComputationAndGetIthOutputWithBinaryParams( + HloComputation* original, int output_index) { + HloInstruction* original_root = original->root_instruction(); + if (!original_root->shape().IsTuple()) { + return original; + } + int64_t num_params = original->num_parameters(); + int64_t num_outputs = original_root->shape().tuple_shapes_size(); + + CHECK_EQ(num_params / 2, num_outputs); + HloComputation* new_comp = [&] { + HloComputation::Builder builder( + absl::StrCat(original->name(), ".dup.", output_index)); + HloInstruction* original_param_lhs = + original->parameter_instruction(output_index); + builder.AddInstruction(HloInstruction::CreateParameter( + 0, original_param_lhs->shape(), original_param_lhs->name())); + HloInstruction* original_param_rhs = + original->parameter_instruction(output_index + num_outputs); + builder.AddInstruction(HloInstruction::CreateParameter( + 1, original_param_rhs->shape(), original_param_rhs->name())); + return original->parent()->AddEmbeddedComputation(builder.Build()); + }(); + std::vector operands; + operands.reserve(num_params); + for (int i = 0; i < num_outputs; ++i) { + operands.push_back(new_comp->parameter_instruction(0)); + } + for (int i = 0; i < num_outputs; ++i) { + operands.push_back(new_comp->parameter_instruction(1)); + } + + HloInstruction* call_original = new_comp->AddInstruction( + HloInstruction::CreateCall(original_root->shape(), operands, original)); + new_comp->set_root_instruction( + new_comp->AddInstruction( + HloInstruction::CreateGetTupleElement(call_original, output_index)), + /*accept_different_shape=*/true); + TF_RETURN_IF_ERROR(CallInliner::Inline(call_original).status()); + return new_comp; +} + +int64_t ScatterIndicesCount(const HloScatterInstruction* scatter) { + // Compute the trip count for the while loop to be used for scatter. This + // should be the number of indices we should scatter into the operand. + const HloInstruction* scatter_indices = scatter->scatter_indices(); + const Shape& scatter_indices_shape = scatter_indices->shape(); + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + int64_t scatter_loop_trip_count = 1; + for (int64_t i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { + if (i != dim_numbers.index_vector_dim()) { + scatter_loop_trip_count *= scatter_indices_shape.dimensions(i); + } + } + return scatter_loop_trip_count; +} + +bool IsScatterCombinerAssociative(const HloComputation* combiner) { + // Consider simple binary combiner functions only. + if (combiner->instruction_count() != 3) { + return false; + } + switch (combiner->root_instruction()->opcode()) { + // Minimum and Maximum are common associative combiners. + case HloOpcode::kMinimum: + case HloOpcode::kMaximum: + return true; + // Other common combiners are associative at least for integer arithmetic. + case HloOpcode::kAdd: + case HloOpcode::kMultiply: + case HloOpcode::kOr: + case HloOpcode::kXor: + return combiner->root_instruction()->shape().IsInteger(); + default: + return false; + } +} + +bool IsScatterDeterministic(const HloScatterInstruction* scatter) { + if (scatter->unique_indices()) { + return true; + } + if (IsScatterCombinerAssociative(scatter->to_apply())) { + return true; + } + return false; +} +} // namespace xla diff --git a/third_party/xla/xla/service/scatter_utils.h b/third_party/xla/xla/service/scatter_utils.h new file mode 100644 index 00000000000000..22209e4fef7bb0 --- /dev/null +++ b/third_party/xla/xla/service/scatter_utils.h @@ -0,0 +1,61 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_SCATTER_UTILS_H_ +#define XLA_SERVICE_SCATTER_UTILS_H_ + +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" + +namespace xla { + +// Transposes the given scatter_indices such that the index_vector_dim becomes +// the most-minor dimension. +absl::StatusOr TransposeIndexVectorDimToLast( + HloInstruction* scatter_indices, int64_t index_vector_dim); + +// Permutes the `updates` tensor such that all the scatter dims appear in the +// major dimensions and all the window dimensions appear in the minor +// dimensions. +absl::StatusOr PermuteScatterAndWindowDims( + HloInstruction* updates, absl::Span update_window_dims); + +// Expands or contracts the scatter indices in the updates tensor. +absl::StatusOr AdjustScatterDims( + const Shape& scatter_indices_shape, HloInstruction* updates, + int64_t index_vector_dim); + +// Canonicalizes the scatter_indices tensor in order to keep them uniform while +// performing the scatter operation. +absl::StatusOr CanonicalizeScatterIndices( + HloInstruction* scatter_indices, int64_t index_vector_dim); + +absl::StatusOr CallAndGetOutput(HloComputation* original, + int output_index); +absl::StatusOr CallComputationAndGetIthOutputWithBinaryParams( + HloComputation* original, int output_index); + +int64_t ScatterIndicesCount(const HloScatterInstruction* scatter); + +// Checks if the combiner is associative. +bool IsScatterCombinerAssociative(const HloComputation* combiner); + +// Checks if the scatter operation is deterministic. +bool IsScatterDeterministic(const HloScatterInstruction* scatter); + +} // namespace xla + +#endif // XLA_SERVICE_SCATTER_UTILS_H_ diff --git a/third_party/xla/xla/service/select_and_scatter_expander.h b/third_party/xla/xla/service/select_and_scatter_expander.h index 9e544972b3fec9..bcaabdb6f2427a 100644 --- a/third_party/xla/xla/service/select_and_scatter_expander.h +++ b/third_party/xla/xla/service/select_and_scatter_expander.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_SELECT_AND_SCATTER_EXPANDER_H_ #define XLA_SERVICE_SELECT_AND_SCATTER_EXPANDER_H_ -#include "xla/service/op_expander_pass.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" namespace xla { diff --git a/third_party/xla/xla/service/service.cc b/third_party/xla/xla/service/service.cc index 4b003ce7d6332b..34f3f4e82314d6 100644 --- a/third_party/xla/xla/service/service.cc +++ b/third_party/xla/xla/service/service.cc @@ -24,15 +24,18 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "xla/debug_options_flags.h" -#include "xla/execution_options_util.h" +#include "absl/types/span.h" +#include "xla/executable_run_options.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" +#include "xla/layout.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/service/backend.h" @@ -43,27 +46,28 @@ limitations under the License. #include "xla/service/dynamic_dimension_inference.h" #include "xla/service/dynamic_padder.h" #include "xla/service/executable.h" -#include "xla/service/hlo_cost_analysis.h" +#include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_module_util.h" #include "xla/service/hlo_proto_util.h" -#include "xla/service/platform_util.h" -#include "xla/service/source_map_util.h" +#include "xla/service/service_executable_run_options.h" +#include "xla/service/shaped_buffer.h" #include "xla/service/stream_pool.h" #include "xla/service/transfer_manager.h" #include "xla/shape.h" -#include "xla/shape_layout.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" -#include "xla/types.h" #include "xla/util.h" +#include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/protobuf.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" #include "tsl/profiler/lib/scoped_annotation.h" namespace xla { diff --git a/third_party/xla/xla/service/service.h b/third_party/xla/xla/service/service.h index 3fd7f227c362e8..b8c0c65dc00b3f 100644 --- a/third_party/xla/xla/service/service.h +++ b/third_party/xla/xla/service/service.h @@ -24,18 +24,23 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/client/xla_computation.h" #include "xla/debug_options_flags.h" #include "xla/executable_run_options.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/layout.h" +#include "xla/literal.h" #include "xla/service/allocation_tracker.h" #include "xla/service/backend.h" #include "xla/service/channel_tracker.h" #include "xla/service/compilation_cache.h" +#include "xla/service/compiler.h" #include "xla/service/executable.h" #include "xla/service/execution_tracker.h" #include "xla/service/hlo_execution_profile.h" #include "xla/service/hlo_module_config.h" +#include "xla/service/shaped_buffer.h" +#include "xla/shape.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor.h" #include "xla/types.h" diff --git a/third_party/xla/xla/service/shape_inference.cc b/third_party/xla/xla/service/shape_inference.cc index 4271cc897f41d7..2ab75ddc86e6df 100644 --- a/third_party/xla/xla/service/shape_inference.cc +++ b/third_party/xla/xla/service/shape_inference.cc @@ -348,6 +348,8 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); switch (opcode) { case HloOpcode::kFloor: + case HloOpcode::kCbrt: // Complex cbrt is not implemented in either of the + // backends. case HloOpcode::kCeil: case HloOpcode::kErf: case HloOpcode::kRoundNearestAfz: @@ -368,7 +370,6 @@ absl::StatusOr InferMostSpecificDimAndBound(int64_t dim, case HloOpcode::kLogistic: case HloOpcode::kRsqrt: case HloOpcode::kSqrt: - case HloOpcode::kCbrt: case HloOpcode::kTan: case HloOpcode::kTanh: if (!ShapeUtil::ElementIsFloating(shape) && @@ -1231,6 +1232,14 @@ ShapeInference::InferElementwiseBinaryOpShape( } } + if (operation == HloOpcode::kAtan2 && !ShapeUtil::ElementIsFloating(lhs) && + !ShapeUtil::ElementIsComplex(lhs)) { + return InvalidArgument( + "Expected input element type to be floating or complex for %s " + "operation; got %s.", + HloOpcodeString(operation), PrimitiveType_Name(lhs.element_type())); + } + if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs) && !lhs.is_unbounded_dynamic() && !rhs.is_unbounded_dynamic()) { // If the shapes are the same other than layout, the output shape is the @@ -2559,6 +2568,13 @@ ShapeInference::InferScalarBroadcastShape(absl::Span shapes) { return InferVariadicOpShape(HloOpcode::kTuple, operand_shapes); } +/* static */ absl::StatusOr ShapeInference::InferRaggedAllToAllShape( + absl::Span operand_shapes) { + TF_RETURN_IF_ERROR( + ExpectArray(*(operand_shapes[1]), "operand 1 of ragged-all-to-all")); + return *(operand_shapes[1]); +} + /* static */ absl::StatusOr ShapeInference::InferCollectiveBroadcastShape( absl::Span operand_shapes) { @@ -3755,6 +3771,10 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { on_false.is_dynamic_dimension(dimension)); } } + if (result.has_layout()) { + result.mutable_layout()->set_element_size_in_bits( + on_true.layout().element_size_in_bits()); + } return std::move(result); } diff --git a/third_party/xla/xla/service/shape_inference.h b/third_party/xla/xla/service/shape_inference.h index e48915f4a3b968..80e8f9ebeda0a4 100644 --- a/third_party/xla/xla/service/shape_inference.h +++ b/third_party/xla/xla/service/shape_inference.h @@ -178,6 +178,10 @@ class ShapeInference { static absl::StatusOr InferAllToAllTupleShape( absl::Span operand_shapes); + // Infers the shape of an HLO ragged-all-to-all instruction. + static absl::StatusOr InferRaggedAllToAllShape( + absl::Span operand_shapes); + // Infers the shape of a collective broadcast operation. static absl::StatusOr InferCollectiveBroadcastShape( absl::Span operand_shapes); diff --git a/third_party/xla/xla/service/shape_inference_test.cc b/third_party/xla/xla/service/shape_inference_test.cc index 29ae32add358e3..e6b8fdd7bc6214 100644 --- a/third_party/xla/xla/service/shape_inference_test.cc +++ b/third_party/xla/xla/service/shape_inference_test.cc @@ -30,16 +30,17 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "absl/types/span.h" -#include "xla/client/padding.h" +#include "xla/hlo/builder/padding.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/hlo_parser.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" +#include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" namespace xla { @@ -239,6 +240,18 @@ TEST_F(ShapeInferenceTest, SelectBadShapes) { HasSubstr("Expected array argument for select pred")); } +TEST_F(ShapeInferenceTest, SelectPreservesElementSize) { + Shape pred_shape = ShapeUtil::MakeShape(PRED, {10}); + Shape int4_shape = ShapeUtil::MakeShape(S4, {10}); + int4_shape.mutable_layout()->set_element_size_in_bits(4); + + const absl::StatusOr inferred_shape = + ShapeInference::InferTernaryOpShape(HloOpcode::kSelect, pred_shape, + int4_shape, int4_shape); + ASSERT_IS_OK(inferred_shape.status()); + ASSERT_TRUE(ShapeUtil::Equal(*inferred_shape, int4_shape)); +} + TEST_F(ShapeInferenceTest, ClampAllMatrix) { const absl::StatusOr inferred_shape = ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, matrix_64_48_, @@ -338,6 +351,17 @@ TEST_F(ShapeInferenceTest, ClampBadShapes) { .ok()); } +TEST_F(ShapeInferenceTest, Atan2FailsWithIntegerInput) { + const Shape input = ShapeUtil::MakeScalarShape(S8); + const absl::StatusOr inferred_shape = + ShapeInference::InferBinaryOpShape(HloOpcode::kAtan2, input, input, {}); + EXPECT_THAT( + inferred_shape.status(), + tsl::testing::StatusIs(tsl::error::INVALID_ARGUMENT, + HasSubstr("Expected input element type to be " + "floating or complex for atan2"))); +} + TEST_F(ShapeInferenceTest, Complex) { const auto complex_shape = [&](const Shape& lhs, const Shape& rhs, absl::Span bcast) { @@ -379,6 +403,17 @@ TEST_F(ShapeInferenceTest, Complex) { ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C128, {}))); } +TEST_F(ShapeInferenceTest, ComplexCbrtIsNotSupported) { + const Shape input = ShapeUtil::MakeScalarShape(C64); + const absl::StatusOr inferred_shape = + ShapeInference::InferUnaryOpShape(HloOpcode::kCbrt, input); + EXPECT_THAT( + inferred_shape.status(), + tsl::testing::StatusIs(tsl::error::INVALID_ARGUMENT, + HasSubstr("Expected element type in shape to be " + "floating for cbrt operation"))); +} + TEST_F(ShapeInferenceTest, VariadicOpTuplify) { const absl::StatusOr result = ShapeInference::InferVariadicOpShape(HloOpcode::kTuple, {&s32_, &f32_}); diff --git a/third_party/xla/xla/service/shaped_buffer.cc b/third_party/xla/xla/service/shaped_buffer.cc index a5155e4331f624..2dad9ae7f42fd4 100644 --- a/third_party/xla/xla/service/shaped_buffer.cc +++ b/third_party/xla/xla/service/shaped_buffer.cc @@ -22,13 +22,14 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "xla/layout_util.h" #include "xla/shape.h" +#include "xla/shape_tree.h" #include "xla/shape_util.h" -#include "xla/status_macros.h" -#include "xla/types.h" -#include "xla/util.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/device_memory_allocator.h" #include "tsl/platform/logging.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -46,7 +47,7 @@ ShapedBuffer::ShapedBuffer(Shape on_host_shape, Shape on_device_shape, int device_ordinal, int physical_device_ordinal) : ShapedBuffer(on_device_shape, device_ordinal, physical_device_ordinal) {} -ShapedBuffer::ShapedBuffer(ShapedBuffer&& s) +ShapedBuffer::ShapedBuffer(ShapedBuffer&& s) noexcept : on_host_shape_(std::move(s.on_host_shape_)), on_device_shape_(std::move(s.on_device_shape_)), device_ordinal_(s.device_ordinal_), @@ -58,7 +59,7 @@ ShapedBuffer::ShapedBuffer(ShapedBuffer&& s) buffers_.replace_shape_ptr(on_device_shape_); } -ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) { +ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) noexcept { on_device_shape_ = std::move(s.on_device_shape_); on_host_shape_ = std::move(s.on_host_shape_); device_ordinal_ = s.device_ordinal_; @@ -140,13 +141,14 @@ ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer, se::DeviceMemoryAllocator* allocator) : ShapedBuffer(std::move(shaped_buffer)), allocator_(allocator) {} -ScopedShapedBuffer::ScopedShapedBuffer(ScopedShapedBuffer&& s) +ScopedShapedBuffer::ScopedShapedBuffer(ScopedShapedBuffer&& s) noexcept : ShapedBuffer(static_cast(s)), allocator_(s.allocator_) { // Null out s.allocator_ so it doesn't try to free anything in its destructor. s.allocator_ = nullptr; } -ScopedShapedBuffer& ScopedShapedBuffer::operator=(ScopedShapedBuffer&& s) { +ScopedShapedBuffer& ScopedShapedBuffer::operator=( + ScopedShapedBuffer&& s) noexcept { Deallocate(); *static_cast(this) = std::move(static_cast(s)); diff --git a/third_party/xla/xla/service/shaped_buffer.h b/third_party/xla/xla/service/shaped_buffer.h index 5faf97cb64f3d7..92644c18bf3af1 100644 --- a/third_party/xla/xla/service/shaped_buffer.h +++ b/third_party/xla/xla/service/shaped_buffer.h @@ -20,10 +20,13 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor.h" #include "xla/xla_data.pb.h" @@ -52,8 +55,8 @@ class ShapedBuffer { int physical_device_ordinal = -1); // Movable, but not copyable. - ShapedBuffer(ShapedBuffer&& s); - ShapedBuffer& operator=(ShapedBuffer&&); + ShapedBuffer(ShapedBuffer&& s) noexcept; + ShapedBuffer& operator=(ShapedBuffer&&) noexcept; ShapedBuffer(const ShapedBuffer&) = delete; ShapedBuffer& operator=(const ShapedBuffer&) = delete; @@ -170,8 +173,8 @@ class ScopedShapedBuffer : public ShapedBuffer { se::DeviceMemoryAllocator* allocator); // Movable, but not copyable. - ScopedShapedBuffer(ScopedShapedBuffer&& s); - ScopedShapedBuffer& operator=(ScopedShapedBuffer&&); + ScopedShapedBuffer(ScopedShapedBuffer&& s) noexcept; + ScopedShapedBuffer& operator=(ScopedShapedBuffer&&) noexcept; ScopedShapedBuffer(const ScopedShapedBuffer&) = delete; ScopedShapedBuffer& operator=(const ScopedShapedBuffer&) = delete; diff --git a/third_party/xla/xla/service/shaped_buffer_test.cc b/third_party/xla/xla/service/shaped_buffer_test.cc index 00e98275bfb6f1..b07e246e33b43d 100644 --- a/third_party/xla/xla/service/shaped_buffer_test.cc +++ b/third_party/xla/xla/service/shaped_buffer_test.cc @@ -19,11 +19,19 @@ limitations under the License. #include #include +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "xla/service/platform_util.h" +#include "xla/shape.h" +#include "xla/shape_tree.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/test.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test_benchmark.h" namespace xla { diff --git a/third_party/xla/xla/service/sharding_format_picker.h b/third_party/xla/xla/service/sharding_format_picker.h index 2f8f47a0e48bfb..9a369faedf284b 100644 --- a/third_party/xla/xla/service/sharding_format_picker.h +++ b/third_party/xla/xla/service/sharding_format_picker.h @@ -16,30 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_SHARDING_FORMAT_PICKER_H_ #define XLA_SERVICE_SHARDING_FORMAT_PICKER_H_ -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// Test-only pass to transform the HloSharding format of all the instructions in -// a module to the selected format. -class ShardingFormatPicker : public HloModulePass { - public: - enum class ShardingType { - kV1, // Converts all HloSharding to V1 format. - kBestEffortV2, // Best effort to convert all HloSharding to V2 format. - }; - explicit ShardingFormatPicker(ShardingType sharding_type) - : sharding_type_(sharding_type) {} - absl::string_view name() const override { return "sharding-format-picker"; } - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - const ShardingType sharding_type_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/sharding_format_picker.h" #endif // XLA_SERVICE_SHARDING_FORMAT_PICKER_H_ diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc index 316644bf87ea8b..68adfb73c9a85b 100644 --- a/third_party/xla/xla/service/sharding_propagation.cc +++ b/third_party/xla/xla/service/sharding_propagation.cc @@ -16,10 +16,10 @@ limitations under the License. #include "xla/service/sharding_propagation.h" #include -#include #include #include #include +#include #include #include #include @@ -35,7 +35,6 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/status/status.h" -#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/array.h" @@ -200,6 +199,7 @@ const HloInstruction* PickRepresentativeOperand( case HloOpcode::kPad: case HloOpcode::kPower: case HloOpcode::kOptimizationBarrier: + case HloOpcode::kRaggedAllToAll: case HloOpcode::kReverse: case HloOpcode::kSlice: case HloOpcode::kShiftLeft: @@ -419,6 +419,55 @@ bool SupportSpatialPartitioning( } } +// Helper to lookahead sharding of user of an instruction to be used as guidance +// for ambiguous cases. +std::optional LookaheadUserSharding(HloInstruction* instr, + bool is_spmd, + const CallGraph& call_graph) { + if (instr->user_count() != 1) { + return std::nullopt; + } + HloInstruction* current_user = instr->users()[0]; + std::optional sharding; + std::vector users_chain = {instr, current_user}; + // Collect single user instructions along the way. + while (!current_user->has_sharding()) { + // Only consider single user chains. + if (current_user->users().size() != 1) { + users_chain.clear(); + break; + } + current_user = current_user->users()[0]; + users_chain.push_back(current_user); + } + // Early exit for unsupported cases. + if (users_chain.empty()) { + return std::nullopt; + } + for (int i = users_chain.size() - 1; i >= 1; --i) { + HloInstruction* user = users_chain[i]; + HloInstruction* current = users_chain[i - 1]; + CHECK(user->has_sharding()); + sharding = ShardingPropagation::GetShardingFromUser( + *current, *user, INT64_MAX, is_spmd, call_graph, + /*sharding_helper=*/nullptr); + // We need to set the sharding to the instruction, because + // GetShardingFromUser() interface uses sharding from the instruction + // itself. It will be cleared out later. + if (sharding.has_value() && i != 1) { + current->set_sharding(*sharding); + continue; + } + break; + } + // Clear the sharding of the middle instructions we set the sharding of + // because they were unsharded. + for (int i = 1; i < users_chain.size() - 1; ++i) { + users_chain[i]->clear_sharding(); + } + return sharding; +} + // Infer output sharding on index parallel dimensions for gather from operand // and indices. bool InferGatherParallelShardingFromOperands( @@ -427,28 +476,24 @@ bool InferGatherParallelShardingFromOperands( bool may_combine_partial_sharding) { CHECK(DynCast(instruction)); bool changed = false; - auto aligned_operand_parallel_dims = - hlo_sharding_util::IndexAlignedOperandParallelDims(parallel_dims); auto output_parallel_dims = hlo_sharding_util::GetGatherParallelOutputDims( *instruction, parallel_dims); - // Infer output sharding from scatter operand sharding. + // Infer output sharding from gather operand sharding. if (hlo_sharding_util::IsSpatiallyPartitioned(instruction->operand(0))) { changed |= MaybeImproveInstructionSharding( hlo_sharding_util:: InferGatherScatterParallelShardingFromOperandSharding( - instruction->operand(0)->sharding(), - instruction->operand(0)->shape(), instruction->shape(), - absl::MakeConstSpan(aligned_operand_parallel_dims), + instruction->operand(0)->sharding(), instruction->shape(), + absl::MakeConstSpan(parallel_dims.operand_parallel_dims), absl::MakeConstSpan(output_parallel_dims)), instruction, may_combine_partial_sharding); } - // Infer output sharding from scatter indices sharding. + // Infer output sharding from gather indices sharding. if (hlo_sharding_util::IsSpatiallyPartitioned(instruction->operand(1))) { changed |= MaybeImproveInstructionSharding( hlo_sharding_util:: InferGatherScatterParallelShardingFromOperandSharding( - instruction->operand(1)->sharding(), - instruction->operand(1)->shape(), instruction->shape(), + instruction->operand(1)->sharding(), instruction->shape(), absl::MakeConstSpan(parallel_dims.indices_parallel_dims), absl::MakeConstSpan(output_parallel_dims)), instruction, may_combine_partial_sharding); @@ -469,11 +514,8 @@ bool InferScatterParallelShardingFromOperands( auto scatter_indices = scatter->scatter_indices(); auto scatter_updates = scatter->scatter_updates(); bool changed = false; - auto aligned_operand_parallel_dims = - hlo_sharding_util::IndexAlignedOperandParallelDims(parallel_dims); auto update_parallel_dims = hlo_sharding_util::GetScatterParallelUpdateDims( *instruction, parallel_dims); - auto output_parallel_dims = aligned_operand_parallel_dims; // Infer output sharding from scatter operand sharding. Shape shape = operand_count == 1 ? instruction->shape() @@ -483,9 +525,9 @@ bool InferScatterParallelShardingFromOperands( changed |= MaybeImproveInstructionSubSharding( hlo_sharding_util:: InferGatherScatterParallelShardingFromOperandSharding( - scatter_operands[i]->sharding(), scatter_operands[i]->shape(), - shape, absl::MakeConstSpan(aligned_operand_parallel_dims), - absl::MakeConstSpan(output_parallel_dims)), + scatter_operands[i]->sharding(), shape, + absl::MakeConstSpan(parallel_dims.operand_parallel_dims), + absl::MakeConstSpan(parallel_dims.operand_parallel_dims)), instruction, {i}, may_combine_partial_sharding); } } @@ -493,9 +535,9 @@ bool InferScatterParallelShardingFromOperands( if (hlo_sharding_util::IsSpatiallyPartitioned(scatter_indices)) { auto parallel_sharding_from_indices = hlo_sharding_util:: InferGatherScatterParallelShardingFromOperandSharding( - scatter_indices->sharding(), scatter_indices->shape(), shape, + scatter_indices->sharding(), shape, absl::MakeConstSpan(parallel_dims.indices_parallel_dims), - absl::MakeConstSpan(output_parallel_dims)); + absl::MakeConstSpan(parallel_dims.operand_parallel_dims)); for (int64_t i = 0; i != operand_count; ++i) { changed |= MaybeImproveInstructionSubSharding( parallel_sharding_from_indices, instruction, {i}, @@ -508,9 +550,9 @@ bool InferScatterParallelShardingFromOperands( changed |= MaybeImproveInstructionSubSharding( hlo_sharding_util:: InferGatherScatterParallelShardingFromOperandSharding( - scatter_updates[i]->sharding(), scatter_updates[i]->shape(), - shape, absl::MakeConstSpan(update_parallel_dims), - absl::MakeConstSpan(output_parallel_dims)), + scatter_updates[i]->sharding(), shape, + absl::MakeConstSpan(update_parallel_dims), + absl::MakeConstSpan(parallel_dims.operand_parallel_dims)), instruction, {i}, may_combine_partial_sharding); } } @@ -519,26 +561,29 @@ bool InferScatterParallelShardingFromOperands( bool CanPropagateThroughAtAggressiveLevel(const HloInstruction& inst, int64_t aggressiveness) { - // At minimum aggressiveness, only allow pass-through ops. - if (aggressiveness < 1 && - !(inst.IsElementwise() || inst.IsCustomCall("Sharding")) && - inst.opcode() != HloOpcode::kTranspose && - inst.opcode() != HloOpcode::kReshape && - inst.opcode() != HloOpcode::kTuple && - inst.opcode() != HloOpcode::kGetTupleElement && - inst.opcode() != HloOpcode::kWhile && - inst.opcode() != HloOpcode::kDynamicSlice && - inst.opcode() != HloOpcode::kDynamicUpdateSlice && - inst.opcode() != HloOpcode::kOptimizationBarrier && - inst.opcode() != HloOpcode::kConcatenate && - inst.opcode() != HloOpcode::kCall && inst.opcode() != HloOpcode::kCopy) { - return false; + // Always allow pass-through ops. + if (inst.IsElementwise() || inst.IsCustomCall("Sharding") || + inst.opcode() == HloOpcode::kCall || + inst.opcode() == HloOpcode::kConcatenate || + inst.opcode() == HloOpcode::kCopy || + inst.opcode() == HloOpcode::kDynamicSlice || + inst.opcode() == HloOpcode::kDynamicUpdateSlice || + inst.opcode() == HloOpcode::kGetTupleElement || + inst.opcode() == HloOpcode::kOptimizationBarrier || + inst.opcode() == HloOpcode::kReshape || + inst.opcode() == HloOpcode::kTuple || + inst.opcode() == HloOpcode::kTranspose || + inst.opcode() == HloOpcode::kWhile) { + return true; } + // Broadcast propagation should have at least aggressiveness 2. - if (aggressiveness < 2 && inst.opcode() == HloOpcode::kBroadcast) { - return false; + if (inst.opcode() == HloOpcode::kBroadcast) { + return aggressiveness >= 2; } - return true; + + // Other ops should have at least aggressiveness 1. + return aggressiveness >= 1; } // Checks if two HloShardings have the same metadata attached. @@ -1025,9 +1070,9 @@ bool IsCSEPreventionSharding(const HloSharding& sharding) { } // namespace bool InferDotShardingFromOperands( - HloInstruction* instruction, + HloInstruction* instruction, const CallGraph& call_graph, const dot_as_convolution_util::DotConvolutionDimsInfo& dnums, - int64_t aggressiveness, bool may_combine_partial_sharding) { + bool may_combine_partial_sharding, bool is_spmd) { auto from_operand = [&](int64_t operand_index) { auto operand = instruction->operand(operand_index); const HloSharding& operand_sharding = operand->sharding(); @@ -1082,66 +1127,55 @@ bool InferDotShardingFromOperands( from_operand(1), instruction, may_combine_partial_sharding, /*allow_aggressive_resharding=*/false); } - - // Four cases based on if improved_operand_0 and improved_operand_1 are - // available. - // Case 0. Both operands have no improved sharding. + // If not improved sharding found then do not set any sharding. if (!improved_operand_0.has_value() && !improved_operand_1.has_value()) { return false; } - // Case 1. Sharding found from operand 0 but not operand 1. Set sharding from - // operand 0. + // Sharding found from operand 0 but not operand 1. Set sharding from operand + // 0 if (improved_operand_0.has_value() && !improved_operand_1.has_value()) { instruction->set_sharding(*improved_operand_0); return true; } - // Case 2. Sharding found from operand 1 but not operand 0. Set sharding from - // operand 1. + // Sharding found from operand 1 but not operand 0. Set sharding from operand + // 1 if (!improved_operand_0.has_value() && improved_operand_1.has_value()) { instruction->set_sharding(*improved_operand_1); return true; } - // Case 3. Both operands have improved shardings. CHECK(improved_operand_0.has_value() && improved_operand_1.has_value()); - - // If one of the improved shardings is a sub-tiling or equal to the other, use - // the better sharding with more tiles. - if (hlo_sharding_util::IsSubTilingOrEqualSharding( - instruction->shape(), *improved_operand_0, *improved_operand_1)) { - instruction->set_sharding(*improved_operand_0); - return true; - } - if (hlo_sharding_util::IsSubTilingOrEqualSharding( - instruction->shape(), *improved_operand_1, *improved_operand_0)) { - instruction->set_sharding(*improved_operand_1); - return true; - } - - // If the two improved shardings are mergeable, there is no conflict. - if (std::optional improved_sharding = - hlo_sharding_util::ReturnImprovedShardingImpl( - *improved_operand_0, &improved_operand_1.value(), - instruction->shape(), may_combine_partial_sharding, - /*allow_aggressive_resharding=*/false)) { - instruction->set_sharding(*improved_sharding); - return true; - } - - if (aggressiveness < 3) { - // We can improve the dot with different shardings. Pause the propagation - // and wait for the winner between the two operands. - return false; - } - - // The two improved sharding are different and we are at the highest - // aggressiveness. Prioritize the operand with larger size. + std::optional lookahead_sharding = + LookaheadUserSharding(instruction, is_spmd, call_graph); std::array sharding_priority = {*improved_operand_0, *improved_operand_1}; - if (ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) < - ShapeUtil::ByteSizeOf(instruction->operand(1)->shape())) { + bool priority_defined_with_lookahead = false; + // Found sharding from lookahead. + if (lookahead_sharding.has_value()) { + const bool operand_0_is_lookahead_subtiling = + hlo_sharding_util::IsSubTilingOrEqualSharding( + instruction->shape(), *lookahead_sharding, *improved_operand_0); + const bool operand_1_is_lookahead_subtiling = + hlo_sharding_util::IsSubTilingOrEqualSharding( + instruction->shape(), *lookahead_sharding, *improved_operand_1); + // If the sharding from operand 0 is a subtiling of the user, but not the + // one from operand 1 prioritize that sharding. + if (operand_0_is_lookahead_subtiling && !operand_1_is_lookahead_subtiling) { + priority_defined_with_lookahead = true; + } + // If the sharding from operand 1 is a subtiling of the user, but not the + // one from operand 0 prioritize that sharding. + if (!operand_0_is_lookahead_subtiling && operand_1_is_lookahead_subtiling) { + instruction->set_sharding(*improved_operand_1); + std::swap(sharding_priority[0], sharding_priority[1]); + priority_defined_with_lookahead = true; + } + } + // If lookahead didn't define a priority then use size. + if (!priority_defined_with_lookahead && + ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()) < + ShapeUtil::ByteSizeOf(instruction->operand(1)->shape())) { std::swap(sharding_priority[0], sharding_priority[1]); } - // Set primary sharding to the instruction and then try to improve it with // the secondary sharding. instruction->set_sharding(sharding_priority[0]); @@ -1152,8 +1186,9 @@ bool InferDotShardingFromOperands( // Convolution handling for InferShardingFromOperands(). bool InferConvolutionShardingFromOperands(HloInstruction* instruction, - int64_t aggressiveness, - bool may_combine_partial_sharding) { + const CallGraph& call_graph, + bool may_combine_partial_sharding, + bool is_spmd) { auto get_partitions_for_dims = [&](const HloInstruction* inst, absl::Span< @@ -1188,8 +1223,8 @@ bool InferConvolutionShardingFromOperands(HloInstruction* instruction, (lhs_conv_spatial_partitions == 1 && rhs_conv_spatial_partitions == 1 && instruction->batch_group_count() == 1 && instruction->feature_group_count() == 1)) { - return InferDotShardingFromOperands(instruction, dot_dims, aggressiveness, - may_combine_partial_sharding); + return InferDotShardingFromOperands(instruction, call_graph, dot_dims, + may_combine_partial_sharding, is_spmd); } const auto& dnums = instruction->convolution_dimension_numbers(); const HloInstruction* lhs = instruction->operand(0); @@ -1348,7 +1383,8 @@ absl::StatusOr ProcessShardingInstruction( absl::flat_hash_map>* shard_group_id_to_shard_like_group, const std::vector* - allow_spmd_sharding_propagation_to_parameters_vector) { + allow_spmd_sharding_propagation_to_parameters_vector, + bool remove_unknown_shardings) { bool changed = false; const bool use_shard_group = instruction_to_shard_group_id && @@ -1440,7 +1476,7 @@ absl::StatusOr ProcessShardingInstruction( bool replaced_with_copy = replace_sharding_with_copy && - (!original_sharding.IsUnknown() || + (!original_sharding.IsUnknown() || remove_unknown_shardings || instruction->operand(0)->opcode() == HloOpcode::kParameter); // Replace the sharding instruction with a copy node so that it does not // need special handling. @@ -2084,7 +2120,7 @@ bool ShardingPropagation::InferShardingFromOperands( return false; } // Do not pass through manual sharding to concat or dynamic slice when - // aggressiveneess is 0. + // aggressiveness is 0. if (aggressiveness == 0 && (instruction->opcode() == HloOpcode::kConcatenate || instruction->opcode() == HloOpcode::kDynamicSlice)) { @@ -2292,8 +2328,8 @@ bool ShardingPropagation::InferShardingFromOperands( 1); } case HloOpcode::kConvolution: - return InferConvolutionShardingFromOperands(instruction, aggressiveness, - may_combine_partial_sharding); + return InferConvolutionShardingFromOperands( + instruction, call_graph, may_combine_partial_sharding, is_spmd_); case HloOpcode::kTranspose: { const HloInstruction* input = instruction->operand(0); if (!hlo_sharding_util::IsSpatiallyPartitioned(input)) { @@ -2382,8 +2418,9 @@ bool ShardingPropagation::InferShardingFromOperands( case HloOpcode::kDot: { const auto& dnums = dot_as_convolution_util::ParseDotGeneralFromDot(instruction); - return InferDotShardingFromOperands(instruction, dnums, aggressiveness, - may_combine_partial_sharding); + return InferDotShardingFromOperands(instruction, call_graph, dnums, + may_combine_partial_sharding, + is_spmd_); } case HloOpcode::kParameter: { auto parent_it = computation_map.find(instruction->parent()); @@ -2463,6 +2500,21 @@ bool ShardingPropagation::InferShardingFromOperands( } case HloOpcode::kGather: { bool changed = false; + + const GatherDimensionNumbers& dnums = + instruction->gather_dimension_numbers(); + if (!dnums.operand_batching_dims().empty()) { + hlo_sharding_util::GatherScatterParallelDims explict_batch_dims; + explict_batch_dims.operand_parallel_dims.assign( + dnums.operand_batching_dims().begin(), + dnums.operand_batching_dims().end()); + explict_batch_dims.indices_parallel_dims.assign( + dnums.start_indices_batching_dims().begin(), + dnums.start_indices_batching_dims().end()); + changed |= InferGatherParallelShardingFromOperands( + instruction, explict_batch_dims, may_combine_partial_sharding); + } + if (hlo_sharding_util::IsSpatiallyPartitioned(instruction->operand(1))) { HloSharding new_sharding = hlo_sharding_util:: GatherOutputShardingFromIndexIndexPassthroughDimensions( @@ -2502,11 +2554,26 @@ bool ShardingPropagation::InferShardingFromOperands( } case HloOpcode::kScatter: { auto& scatter = *Cast(instruction); + bool changed = false; + + const ScatterDimensionNumbers& dnums = + instruction->scatter_dimension_numbers(); + if (!dnums.input_batching_dims().empty()) { + hlo_sharding_util::GatherScatterParallelDims explict_batch_dims; + explict_batch_dims.operand_parallel_dims.assign( + dnums.input_batching_dims().begin(), + dnums.input_batching_dims().end()); + explict_batch_dims.indices_parallel_dims.assign( + dnums.scatter_indices_batching_dims().begin(), + dnums.scatter_indices_batching_dims().end()); + changed |= InferScatterParallelShardingFromOperands( + instruction, explict_batch_dims, may_combine_partial_sharding); + } + const int64_t operand_count = scatter.scatter_operand_count(); auto scatter_operands = scatter.scatter_operands(); auto scatter_indices = scatter.scatter_indices(); auto scatter_updates = scatter.scatter_updates(); - bool changed = false; if (is_spmd_) { for (int64_t i = 0; i != operand_count; ++i) { if (hlo_sharding_util::IsSpatiallyPartitioned(scatter_operands[i])) { @@ -2521,10 +2588,9 @@ bool ShardingPropagation::InferShardingFromOperands( })) { return changed; } - auto scatter_parallel_dims = - hlo_sharding_util::GetScatterParallelBatchDims(*instruction, - call_graph); - if (scatter_parallel_dims) { + if (auto scatter_parallel_dims = + hlo_sharding_util::GetScatterParallelBatchDims(*instruction, + call_graph)) { changed |= InferScatterParallelShardingFromOperands( instruction, *scatter_parallel_dims, may_combine_partial_sharding); @@ -2676,6 +2742,10 @@ bool ShardingPropagation::InferShardingFromUsers( bool improved_sharding = false; const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0; for (const HloInstruction* user : instruction->users()) { + if (user->opcode() == HloOpcode::kRngBitGenerator) { + instruction->set_sharding(HloSharding::Replicate()); + return true; + } std::optional user_sharding = ShardingPropagation::GetShardingFromUser(*instruction, *user, aggressiveness, is_spmd, diff --git a/third_party/xla/xla/service/sharding_propagation.h b/third_party/xla/xla/service/sharding_propagation.h index 2654a1fd7d335b..903d5d7730822d 100644 --- a/third_party/xla/xla/service/sharding_propagation.h +++ b/third_party/xla/xla/service/sharding_propagation.h @@ -16,16 +16,23 @@ limitations under the License. #ifndef XLA_SERVICE_SHARDING_PROPAGATION_H_ #define XLA_SERVICE_SHARDING_PROPAGATION_H_ -#include #include #include #include #include #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_domain_metadata.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/service/call_graph.h" #include "xla/service/custom_call_sharding_helper.h" @@ -36,15 +43,16 @@ namespace xla { // Infers the shardings for a dot HLO op from the shardings on its operands, // which are expected to have sharding annotations. bool InferDotShardingFromOperands( - HloInstruction* instruction, + HloInstruction* instruction, const CallGraph& call_graph, const dot_as_convolution_util::DotConvolutionDimsInfo& dnums, - int64_t aggressiveness, bool may_combine_partial_sharding); + bool may_combine_partial_sharding, bool is_spmd); // Infers the shardings for a convolution HLO op from the shardings on its // operands, which are expected to have sharding annotations. bool InferConvolutionShardingFromOperands(HloInstruction* instruction, - int64_t aggressiveness, - bool may_combine_partial_sharding); + const CallGraph& call_graph, + bool may_combine_partial_sharding, + bool is_spmd); // Remove Sharding custom-call instruction by folding the sharding attribute // to its operand. If the operand already has a different sharding, insert a @@ -71,7 +79,8 @@ absl::StatusOr ProcessShardingInstruction( absl::flat_hash_map>* shard_group_id_to_shard_like_group = nullptr, const std::vector* - allow_spmd_sharding_propagation_to_parameters_vector = nullptr); + allow_spmd_sharding_propagation_to_parameters_vector = nullptr, + bool remove_unknown_shardings = false); int64_t ComputeNonRootUsers(const HloInstruction* instr); diff --git a/third_party/xla/xla/service/sharding_propagation_test.cc b/third_party/xla/xla/service/sharding_propagation_test.cc index 565314d9150e33..f0e227a3ced16f 100644 --- a/third_party/xla/xla/service/sharding_propagation_test.cc +++ b/third_party/xla/xla/service/sharding_propagation_test.cc @@ -33,11 +33,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_op_metadata.h" #include "xla/hlo/ir/hlo_sharding.h" -#include "xla/hlo/transforms/hlo_constant_splitter.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/transforms/simplifiers/hlo_constant_splitter.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/protobuf_util.h" -#include "xla/service/hlo_dce.h" -#include "xla/service/hlo_parser.h" +#include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -3324,7 +3325,7 @@ ENTRY %conv { EXPECT_THAT(instruction, op::Sharding("{devices=[2,2,2]0,1,2,3,4,5,6,7}")); if (GetParam().propagate_metadata && !GetParam().clear_metadata) { EXPECT_THAT(instruction->sharding(), - ShardingMetadata({CreateMetadata("a"), CreateMetadata("b")})); + ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")})); } else { EXPECT_THAT(instruction->sharding(), ShardingMetadata({})); } @@ -3396,7 +3397,7 @@ ENTRY %conv { EXPECT_THAT(instruction, op::Sharding("{devices=[2,4]0,2,3,1,4,6,7,5}")); if (GetParam().propagate_metadata && !GetParam().clear_metadata) { EXPECT_THAT(instruction->sharding(), - ShardingMetadata({CreateMetadata("a"), CreateMetadata("b")})); + ShardingMetadata({CreateMetadata("b"), CreateMetadata("a")})); } else { EXPECT_THAT(instruction->sharding(), ShardingMetadata({})); } @@ -6487,6 +6488,231 @@ ENTRY %module { EXPECT_THAT(copy_p, op::Sharding("{replicated}")); } +TEST_F(ShardingPropagationTest, GatherExplicitBatchDimsFromOperandToResult) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[10,3,14,4] parameter(0), sharding={devices=[2,2,2,2]<=[16]} + %indices = s32[14,10,6,2] parameter(1) + ROOT %gather = f32[14,10,6,4] gather(%input, %indices), offset_dims={3}, + collapsed_slice_dims={1}, operand_batching_dims={0,2}, + start_indices_batching_dims={1,0}, start_index_map={1,3}, + index_vector_dim=3, slice_sizes={1,1,1,4} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation(/*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Sharding("{devices=[2,2,1,2,2]<=[2,2,2,2]T(2,0," + "3,1) last_tile_dim_replicate}")); +} + +TEST_F(ShardingPropagationTest, GatherExplicitBatchDimsFromIndicesToResult) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[10,3,14,4] parameter(0) + %indices = s32[14,10,6,2] parameter(1), sharding={devices=[2,2,2,2]<=[16]} + ROOT %gather = f32[14,10,6,4] gather(%input, %indices), offset_dims={3}, + collapsed_slice_dims={1}, operand_batching_dims={0,2}, + start_indices_batching_dims={1,0}, start_index_map={1,3}, + index_vector_dim=3, slice_sizes={1,1,1,4} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation(/*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Sharding("{devices=[2,2,2,1,2]<=[16] last_tile_dim_replicate}")); +} + +TEST_F(ShardingPropagationTest, GatherBackwardWithExplicitBatchDims) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[10,3,14,4] parameter(0) + %indices = s32[14,10,6,2] parameter(1) + ROOT %gather = f32[14,10,6,4] gather(%input, %indices), offset_dims={3}, + collapsed_slice_dims={1}, operand_batching_dims={0,2}, + start_indices_batching_dims={1,0}, start_index_map={1,3}, + index_vector_dim=3, slice_sizes={1,1,1,4}, + sharding={devices=[2,2,2,2]<=[16]} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}, + /*allow_spmd_sharding_propagation_to_parameters=*/{true, true}) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + + EXPECT_THAT(module->entry_computation()->parameter_instruction(0), + op::Sharding("{devices=[2,1,2,2,2]<=[2,2,2,2]T(1,0,3,2) " + "last_tile_dim_replicate}")); + EXPECT_THAT( + module->entry_computation()->parameter_instruction(1), + op::Sharding("{devices=[2,2,2,1,2]<=[16] last_tile_dim_replicate}")); +} + +TEST_F(ShardingPropagationTest, ScatterExplicitBatchDimsFromOperandToResult) { + const char* const hlo_string = R"( +HloModule module + +min (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT min = f32[] minimum(lhs, rhs) +} + +ENTRY entry { + %input = f32[10,6,14,4] parameter(0), sharding={devices=[2,2,2,2]<=[16]} + %indices = s32[14,10,6,2] parameter(1) + %updates = f32[14,10,6,2] parameter(2) + ROOT %scatter = f32[10,6,14,4] scatter(%input, %indices, %updates), + to_apply=min, update_window_dims={3}, inserted_window_dims={1}, + scatter_dims_to_operand_dims={1,3}, input_batching_dims={0,2}, + scatter_indices_batching_dims={1,0}, index_vector_dim=3 +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation(/*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Sharding("{devices=[2,2,2,2]<=[16]}")); +} + +TEST_F(ShardingPropagationTest, ScatterExplicitBatchDimsFromIndicesToResult) { + const char* const hlo_string = R"( +HloModule module + +min (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT min = f32[] minimum(lhs, rhs) +} + +ENTRY entry { + %input = f32[10,6,14,4] parameter(0) + %indices = s32[14,10,6,2] parameter(1), sharding={devices=[2,2,2,2]<=[16]} + %updates = f32[14,10,6,2] parameter(2) + ROOT %scatter = f32[10,6,14,4] scatter(%input, %indices, %updates), + to_apply=min, update_window_dims={3}, inserted_window_dims={1}, + scatter_dims_to_operand_dims={1,3}, input_batching_dims={0,2}, + scatter_indices_batching_dims={1,0}, index_vector_dim=3 +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation(/*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Sharding( + "{devices=[2,1,2,1,4]<=[2,2,4]T(1,0,2) last_tile_dim_replicate}")); +} + +TEST_F(ShardingPropagationTest, ScatterExplicitBatchDimsFromUpdatesToResult) { + const char* const hlo_string = R"( +HloModule module + +min (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT min = f32[] minimum(lhs, rhs) +} + +ENTRY entry { + %input = f32[10,6,14,4] parameter(0) + %indices = s32[14,10,6,2] parameter(1) + %updates = f32[14,10,6,4] parameter(2), sharding={devices=[2,2,2,2]<=[16]} + ROOT %scatter = f32[10,6,14,4] scatter(%input, %indices, %updates), + to_apply=min, update_window_dims={3}, inserted_window_dims={1}, + scatter_dims_to_operand_dims={1,3}, input_batching_dims={0,2}, + scatter_indices_batching_dims={1,0}, index_vector_dim=3 +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation(/*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Sharding("{devices=[2,1,2,2,2]<=[2,2,2,2]T(1,0,3,2) " + "last_tile_dim_replicate}")); +} + +TEST_F(ShardingPropagationTest, ScatterBackwardWithExplicitBatchDims) { + const char* const hlo_string = R"( +HloModule module + +min (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT min = f32[] minimum(lhs, rhs) +} + +ENTRY entry { + %input = f32[10,6,14,4] parameter(0) + %indices = s32[14,10,6,2] parameter(1) + %updates = f32[14,10,6,4] parameter(2) + ROOT %scatter = f32[10,6,14,4] scatter(%input, %indices, %updates), + to_apply=min, update_window_dims={3}, inserted_window_dims={1}, + scatter_dims_to_operand_dims={1,3}, input_batching_dims={0,2}, + scatter_indices_batching_dims={1,0}, index_vector_dim=3, sharding={devices=[2,2,2,2]<=[16]} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}, + /*allow_spmd_sharding_propagation_to_parameters=*/{true, true, true}) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + + EXPECT_THAT(module->entry_computation()->parameter_instruction(0), + op::Sharding("{devices=[2,2,2,2]<=[16]}")); + EXPECT_THAT(module->entry_computation()->parameter_instruction(1), + op::Sharding("{devices=[2,2,1,1,4]<=[2,2,2,2]T(2,0,1,3) " + "last_tile_dim_replicate}")); + EXPECT_THAT(module->entry_computation()->parameter_instruction(2), + op::Sharding("{devices=[2,2,1,2,2]<=[2,2,2,2]T(2,0,3,1) " + "last_tile_dim_replicate}")); +} + TEST_P(ParameterizedMetadataTest, ParallelGatherFromOperandForwardPass) { const char* const hlo_string = R"( HloModule module @@ -11863,7 +12089,7 @@ ENTRY main.9 { op::Sharding("{{devices=[4]<=[4]}, {devices=[4]<=[4]}}")); } -TEST_F(ShardingPropagationTest, InferDotShardingFromOperands1) { +TEST_F(ShardingPropagationTest, LookaheadUsersOfDot) { const char* const hlo_string = R"( HloModule module @@ -11880,108 +12106,24 @@ ENTRY %entry { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); TF_ASSERT_OK_AND_ASSIGN( - bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}, + /*allow_spmd_sharding_propagation_to_parameters=*/{true}) + .Run(module.get())); EXPECT_TRUE(changed); XLA_VLOG_LINES(1, module->ToString()); + // Check dangling sharding custom-call can be removed by DCE after + // propagation. auto* instruction = FindInstruction(module.get(), "dot.1"); + // Check sharding is correctly propagated. EXPECT_THAT(instruction, op::Sharding( "{devices=[4,4,1,4]<=[4,16]T(1,0) last_tile_dim_replicate}")); } -TEST_F(ShardingPropagationTest, InferDotShardingFromOperands2) { - const char* const hlo_string = R"( -HloModule module - -ENTRY %entry { - p0 = bf16[16,32] parameter(0), sharding={devices=[16,1]<=[16]} - p1 = bf16[32,64] parameter(1), sharding={devices=[1,16]<=[16]} - dot = bf16[16,64] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} - ROOT copy = bf16[16,64] copy(dot), sharding={replicated} -})"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - TF_ASSERT_OK_AND_ASSIGN( - bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); - EXPECT_TRUE(changed); - - XLA_VLOG_LINES(1, module->ToString()); - auto* instruction = FindInstruction(module.get(), "dot"); - EXPECT_THAT(instruction, op::Sharding("{devices=[1,16]<=[16]}")); -} - -TEST_F(ShardingPropagationTest, InferDotShardingFromOperands3) { - const char* const hlo_string = R"( -HloModule module - -ENTRY %entry { - p0 = bf16[4,16,32] parameter(0), sharding={devices=[2,4,2]<=[16]} - p1 = bf16[4,32,64] parameter(1), sharding={devices=[2,8,1]<=[16]} - dot = bf16[4,16,64] dot(p0, p1), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1} - ROOT copy = bf16[4,16,64] copy(dot) -})"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - TF_ASSERT_OK_AND_ASSIGN( - bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); - EXPECT_TRUE(changed); - - XLA_VLOG_LINES(1, module->ToString()); - auto* instruction = FindInstruction(module.get(), "dot"); - EXPECT_THAT( - instruction, - op::Sharding("{devices=[2,4,1,2]<=[16] last_tile_dim_replicate}")); -} - -TEST_F(ShardingPropagationTest, InferDotShardingFromOperands4) { - const char* const hlo_string = R"( -HloModule module - -ENTRY %entry { - p0 = bf16[4,16,32] parameter(0), sharding={devices=[2,1,8]<=[16]} - p1 = bf16[4,32,64] parameter(1), sharding={devices=[4,1,4]<=[16]} - dot = bf16[4,16,64] dot(p0, p1), lhs_batch_dims={0}, rhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_contracting_dims={1} - ROOT copy = bf16[4,16,64] copy(dot) -})"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - TF_ASSERT_OK_AND_ASSIGN( - bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); - EXPECT_TRUE(changed); - - XLA_VLOG_LINES(1, module->ToString()); - auto* instruction = FindInstruction(module.get(), "dot"); - EXPECT_THAT(instruction, op::Sharding("{devices=[4,1,4]<=[16]}")); -} - -TEST_F(ShardingPropagationTest, InferDotShardingFromOperands5) { - const char* const hlo_string = R"( -HloModule module - -ENTRY %entry { - p0 = bf16[16,16] parameter(0), sharding={devices=[4,4]<=[4,4]T(1,0)} - p1 = bf16[16,16] parameter(1), sharding={devices=[4,4]<=[4,4]T(1,0)} - dot.0 = bf16[16,16] dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={1} - p2 = bf16[16,16] parameter(2), sharding={devices=[4,4]<=[16]} - p3 = bf16[16,16] parameter(3), sharding={devices=[4,4]<=[16]} - dot.1 = bf16[16,16] dot(p2, p3), lhs_contracting_dims={1}, rhs_contracting_dims={0} - add = bf16[16,16] add(dot.0, dot.1) - ROOT copy = bf16[16,16] copy(add) -})"; - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(hlo_string)); - TF_ASSERT_OK_AND_ASSIGN( - bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); - EXPECT_TRUE(changed); - - XLA_VLOG_LINES(1, module->ToString()); - for (absl::string_view name : {"dot.0", "dot.1", "add"}) { - auto* instruction = FindInstruction(module.get(), name); - EXPECT_THAT(instruction, op::Sharding("{devices=[4,4]<=[16]}")); - } -} - TEST_F(ShardingPropagationTest, AsyncInstructionManualShardingArray) { const char* const hlo_string = R"( HloModule module @@ -12279,5 +12421,38 @@ ENTRY main { op::Sharding("{devices=[2,1,2]<=[4] last_tile_dim_replicate}")); } +TEST_F(ShardingPropagationTest, ReplicateRngBitGeneratorSeed) { + const char* const hlo_string = R"( +HloModule module +apply_or { + x = u64[] parameter(0) + y = u64[] parameter(1) + ROOT x_or_y = or(x, y) +} +ENTRY main { + p = s32[2,2]{1,0} parameter(0), sharding={devices=[2,2]<=[4]} + up = u64[2,2] convert(p) + i = u64[] constant(0) + seed = u64[2] reduce(up, i), dimensions={1}, to_apply=apply_or + rbg = u32[2048,4096] rng-bit-generator(seed), algorithm=rng_default + ROOT s = u32[2048,4096]{1,0} custom-call(rbg), custom_call_target="Sharding", sharding={devices=[2,2]<=[4]} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{true}, + /*allow_spmd_sharding_propagation_to_parameters=*/{true}) + .Run(module.get())); + EXPECT_TRUE(changed); + + XLA_VLOG_LINES(1, module->ToString()); + auto* instruction = FindInstruction(module.get(), "seed"); + // Check sharding is correctly propagated. + EXPECT_TRUE(instruction->sharding().IsReplicated()); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/sharding_remover.cc b/third_party/xla/xla/service/sharding_remover.cc index 83c746b12c6882..b9029581aaa8a6 100644 --- a/third_party/xla/xla/service/sharding_remover.cc +++ b/third_party/xla/xla/service/sharding_remover.cc @@ -21,6 +21,9 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" diff --git a/third_party/xla/xla/service/sharding_remover.h b/third_party/xla/xla/service/sharding_remover.h index 3c88ed1e5c0adc..5ea1b6e1273bce 100644 --- a/third_party/xla/xla/service/sharding_remover.h +++ b/third_party/xla/xla/service/sharding_remover.h @@ -19,7 +19,9 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" diff --git a/third_party/xla/xla/service/sharding_remover_test.cc b/third_party/xla/xla/service/sharding_remover_test.cc index 86b52d32013cad..110b3bc4168677 100644 --- a/third_party/xla/xla/service/sharding_remover_test.cc +++ b/third_party/xla/xla/service/sharding_remover_test.cc @@ -15,13 +15,14 @@ limitations under the License. #include "xla/service/sharding_remover.h" +#include #include #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/service/hlo_parser.h" -#include "xla/status_macros.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" namespace op = xla::testing::opcode_matchers; diff --git a/third_party/xla/xla/service/simplify_fp_conversions.h b/third_party/xla/xla/service/simplify_fp_conversions.h index 2c92319c96476e..b12727941fb086 100644 --- a/third_party/xla/xla/service/simplify_fp_conversions.h +++ b/third_party/xla/xla/service/simplify_fp_conversions.h @@ -16,31 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_SIMPLIFY_FP_CONVERSIONS_H_ #define XLA_SERVICE_SIMPLIFY_FP_CONVERSIONS_H_ -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// Simplifies chains of floating-point conversions. -// -// The algebraic simplifier will remove convert pairs of the form `X -> Y -> X`, -// only when they are a no-op, e.g. `bf16 -> f32 -> bf16` or -// `f32 -> bf16 -> f32`. Note that the latter optimization might lead to -// increased precision. -class SimplifyFPConversions : public HloModulePass { - public: - explicit SimplifyFPConversions() = default; - - absl::string_view name() const override { return "simplify-fp-conversions"; } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/simplify_fp_conversions.h" #endif // XLA_SERVICE_SIMPLIFY_FP_CONVERSIONS_H_ diff --git a/third_party/xla/xla/service/slice_sinker.h b/third_party/xla/xla/service/slice_sinker.h index eab42adf5171a1..d1d1aa599b1a0f 100644 --- a/third_party/xla/xla/service/slice_sinker.h +++ b/third_party/xla/xla/service/slice_sinker.h @@ -16,22 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_SLICE_SINKER_H_ #define XLA_SERVICE_SLICE_SINKER_H_ -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// An HLO pass that sinks slice operations used by a group of elementwise -// operations and merges the group of elementwise operations. -class SliceSinker : public HloModulePass { - public: - absl::string_view name() const override { return "slice-sinker"; } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/slice_sinker.h" #endif // XLA_SERVICE_SLICE_SINKER_H_ diff --git a/third_party/xla/xla/service/sort_simplifier.h b/third_party/xla/xla/service/sort_simplifier.h index bd4d38c855927a..d05996705787c0 100644 --- a/third_party/xla/xla/service/sort_simplifier.h +++ b/third_party/xla/xla/service/sort_simplifier.h @@ -16,23 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_SORT_SIMPLIFIER_H_ #define XLA_SERVICE_SORT_SIMPLIFIER_H_ -#include "absl/status/statusor.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// HLO pass which removes unused operands from sort, where an unused operand is -// defined as an operand at some index 'x' at which the output is not used. -class SortSimplifier : public HloModulePass { - public: - absl::string_view name() const override { return "simplify-sorts"; } - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/sort_simplifier.h" #endif // XLA_SERVICE_SORT_SIMPLIFIER_H_ diff --git a/third_party/xla/xla/service/space_to_batch_converter.cc b/third_party/xla/xla/service/space_to_batch_converter.cc index 5a6b8c5e0627b9..4efd1cbb6d8fbf 100644 --- a/third_party/xla/xla/service/space_to_batch_converter.cc +++ b/third_party/xla/xla/service/space_to_batch_converter.cc @@ -25,18 +25,18 @@ limitations under the License. #include #include -#include "absl/algorithm/algorithm.h" #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" +#include "xla/comparison_util.h" #include "xla/debug_options_flags.h" -#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal.h" #include "xla/literal_util.h" @@ -44,14 +44,12 @@ limitations under the License. #include "xla/service/pattern_matcher.h" #include "xla/service/shape_inference.h" #include "xla/shape_util.h" -#include "xla/status_macros.h" #include "xla/tsl/lib/core/bitmap.h" -#include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/space_to_batch_converter.h b/third_party/xla/xla/service/space_to_batch_converter.h index 47141bb1a6d4c3..37cf1809703d21 100644 --- a/third_party/xla/xla/service/space_to_batch_converter.h +++ b/third_party/xla/xla/service/space_to_batch_converter.h @@ -17,6 +17,8 @@ limitations under the License. #include +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" diff --git a/third_party/xla/xla/service/space_to_batch_converter_test.cc b/third_party/xla/xla/service/space_to_batch_converter_test.cc index 6f1d86b618216e..a88d157314c7aa 100644 --- a/third_party/xla/xla/service/space_to_batch_converter_test.cc +++ b/third_party/xla/xla/service/space_to_batch_converter_test.cc @@ -20,11 +20,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "xla/types.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/spmd/BUILD b/third_party/xla/xla/service/spmd/BUILD index 2d3a5d993d0c80..5481cf0886b8d0 100644 --- a/third_party/xla/xla/service/spmd/BUILD +++ b/third_party/xla/xla/service/spmd/BUILD @@ -45,14 +45,18 @@ cc_library( "//xla:util", "//xla:window_util", "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/client/lib:comparators", + "//xla/hlo/analysis:hlo_reachability", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/builder/lib:comparators", "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_reachability", "//xla/hlo/ir:tile_assignment", + "//xla/hlo/parser:hlo_lexer", "//xla/hlo/pass:hlo_pass", "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/transforms:flatten_call_graph", + "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms:tuple_simplifier", "//xla/hlo/utils:hlo_query", "//xla/hlo/utils:hlo_sharding_util", "//xla/service:call_graph", @@ -60,16 +64,13 @@ cc_library( "//xla/service:computation_layout", "//xla/service:custom_call_sharding_helper", "//xla/service:dot_as_convolution_util", - "//xla/service:flatten_call_graph", + "//xla/service:hlo_creation_utils", "//xla/service:hlo_cse", - "//xla/service:hlo_dce", - "//xla/service:hlo_lexer", "//xla/service:hlo_module_config", "//xla/service:host_memory_offload_annotations_hdr", "//xla/service:pattern_matcher", "//xla/service:shape_inference", "//xla/service:sharding_propagation", - "//xla/service:tuple_simplifier", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", @@ -93,6 +94,7 @@ cc_library( xla_cc_test( name = "spmd_partitioner_test", srcs = ["spmd_partitioner_test.cc"], + shard_count = 10, deps = [ ":spmd_partitioner", ":spmd_prepare", @@ -101,11 +103,11 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/transforms:sharding_format_picker", "//xla/hlo/utils:hlo_matchers", "//xla/hlo/utils:hlo_sharding_util", "//xla/service:hlo_module_config", "//xla/service:hlo_verifier", - "//xla/service:sharding_format_picker", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "//xla/tsl/lib/core:status_test_util", @@ -271,10 +273,10 @@ xla_cc_test( "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/transforms:rng_expander", "//xla/hlo/utils:hlo_matchers", "//xla/service:hlo_module_config", "//xla/service:hlo_verifier", - "//xla/service:rng_expander", "//xla/service:sharding_propagation", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", @@ -294,9 +296,9 @@ cc_library( "//xla:comparison_util", "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:while_loop_analysis", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", - "//xla/service:while_loop_analysis", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/third_party/xla/xla/service/spmd/collective_permute_motion.cc b/third_party/xla/xla/service/spmd/collective_permute_motion.cc index c398ef815b63c4..cff4cb1a0cec1f 100644 --- a/third_party/xla/xla/service/spmd/collective_permute_motion.cc +++ b/third_party/xla/xla/service/spmd/collective_permute_motion.cc @@ -29,10 +29,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/comparison_util.h" +#include "xla/hlo/analysis/while_loop_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/while_loop_analysis.h" #include "xla/shape_util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/xla/service/spmd/custom_call_handler.cc b/third_party/xla/xla/service/spmd/custom_call_handler.cc index dab26f5985a0c5..4903decc236836 100644 --- a/third_party/xla/xla/service/spmd/custom_call_handler.cc +++ b/third_party/xla/xla/service/spmd/custom_call_handler.cc @@ -29,10 +29,10 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "xla/client/lib/comparators.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" #include "xla/comparison_util.h" +#include "xla/hlo/builder/lib/comparators.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_computation.h" @@ -40,10 +40,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/parser/hlo_lexer.h" #include "xla/hlo/utils/hlo_sharding_util.h" #include "xla/literal_util.h" #include "xla/service/custom_call_sharding_helper.h" -#include "xla/service/hlo_lexer.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/service/hlo_module_config.h" #include "xla/service/host_memory_offload_annotations.h" #include "xla/service/spmd/spmd_partitioner.h" @@ -207,13 +208,8 @@ absl::Status SpmdPartitioningVisitor::HandleCustomCallTopK( XlaComputation comparator = CreateScalarComparisonComputation( "compare-value-and-index", {input->shape().element_type(), S32}, {Gt, Lt}, &b); - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comparator.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, - HloModule::CreateFromProto(comparator.proto(), config)); - HloCloneContext context(module_); - auto compare_computation = - module_->DeepCloneComputation(new_module->entry_computation(), &context); + TF_ASSIGN_OR_RETURN(HloComputation * compare_computation, + XlaComputationToHloComputation(comparator, module_)); // Each partition needs to do TopK separately, thus the base shape for sort // becomes [ceil(batch_size / batch_dim_partition), k * shard_count]. const Shape sort_shape = ShapeUtil::MakeTupleShape( diff --git a/third_party/xla/xla/service/spmd/dot_handler.cc b/third_party/xla/xla/service/spmd/dot_handler.cc index 22f88cf0dad143..00388aab4d1e54 100644 --- a/third_party/xla/xla/service/spmd/dot_handler.cc +++ b/third_party/xla/xla/service/spmd/dot_handler.cc @@ -32,6 +32,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/types/span.h" #include "xla/comparison_util.h" +#include "xla/hlo/analysis/hlo_reachability.h" #include "xla/hlo/ir/collective_device_list.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -39,7 +40,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/hlo/ir/hlo_reachability.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/utils/hlo_sharding_util.h" #include "xla/literal_util.h" @@ -454,6 +454,24 @@ bool RequiresTransposeSharding( has_different_lhs_rhs_dim_sharding; } +bool should_enable_windowed_einsum_with_threshold( + const SpmdPartitionerOptions& options, const HloInstruction* lhs, + const HloInstruction* rhs, int64_t operand_or_output_shape_size) { + if (options.total_bytes_windowed_einsum_threshold != std::nullopt) { + if (lhs == nullptr || rhs == nullptr) { + return false; + } + int64_t total_operand_bytes = (ShapeUtil::ByteSizeOf(rhs->shape()) + + ShapeUtil::ByteSizeOf(lhs->shape())); + int64_t operand_bytes_threshold = + options.total_bytes_windowed_einsum_threshold.value(); + return total_operand_bytes >= operand_bytes_threshold; + } else { + return operand_or_output_shape_size >= + options.threshold_for_windowed_einsum_mib * 1024 * 1024; + } +} + std::optional GetWindowedEinsumConfiguration( int64_t num_partitions, int64_t output_lhs_non_contracting_partitions, int64_t output_rhs_non_contracting_partitions, @@ -663,8 +681,8 @@ std::optional GetWindowedEinsumConfiguration( if (output_lhs_non_contracting_partitions == num_partitions && output_sharding_transposed_to_match_lhs == lhs_sharding && - rhs_shape_size >= - options.threshold_for_windowed_einsum_mib * 1024 * 1024 && + should_enable_windowed_einsum_with_threshold(options, lhs, rhs, + rhs_shape_size) && (!rhs || check_users_sharding(rhs)) && !disable_windowed_einsum(/*lhs_needs_ag=*/false, /*rhs_needs_ag=*/true) && options.enable_windowed_einsum_for_all_gather) { @@ -695,8 +713,8 @@ std::optional GetWindowedEinsumConfiguration( } if (output_rhs_non_contracting_partitions == num_partitions && output_sharding_transposed_to_match_rhs == rhs_sharding && - lhs_shape_size >= - options.threshold_for_windowed_einsum_mib * 1024 * 1024 && + should_enable_windowed_einsum_with_threshold(options, lhs, rhs, + lhs_shape_size) && (!lhs || check_users_sharding(lhs)) && !disable_windowed_einsum(/*lhs_needs_ag=*/true, /*rhs_needs_ag=*/false) && options.enable_windowed_einsum_for_all_gather) { @@ -729,8 +747,8 @@ std::optional GetWindowedEinsumConfiguration( lhs_contracting_partitions == num_partitions && (output_lhs_non_contracting_partitions == num_partitions || output_rhs_non_contracting_partitions == num_partitions) && - output_shape_size >= - options.threshold_for_windowed_einsum_mib * 1024 * 1024 && + should_enable_windowed_einsum_with_threshold(options, lhs, rhs, + output_shape_size) && !disable_windowed_einsum(/*lhs_needs_ag=*/false, /*rhs_needs_ag=*/false) && options.enable_windowed_einsum_for_reduce_scatter) { diff --git a/third_party/xla/xla/service/spmd/gather_scatter_handler.cc b/third_party/xla/xla/service/spmd/gather_scatter_handler.cc index fd45e649f18cd2..4b9a700fe927ad 100644 --- a/third_party/xla/xla/service/spmd/gather_scatter_handler.cc +++ b/third_party/xla/xla/service/spmd/gather_scatter_handler.cc @@ -13,10 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include +#include +#include #include #include @@ -266,6 +267,10 @@ absl::StatusOr PartitionGatherIndexPassthroughDimensions( const HloSharding& output_sharding, absl::Span batch_dims, absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, bool allow_recursive) { + if (indices.sharding().IsTileMaximal()) { + return nullptr; + } + // Perform clean up actions upon exiting function scope. absl::InlinedVector, 3> clean_ups; absl::Cleanup cleaner = [&clean_ups] { @@ -274,7 +279,7 @@ absl::StatusOr PartitionGatherIndexPassthroughDimensions( } }; - GatherDimensionNumbers dnums = gather->gather_dimension_numbers(); + const GatherDimensionNumbers& dnums = gather->gather_dimension_numbers(); SpmdBuilder* b = visitor->builder(); absl::InlinedVector index_group_dims = hlo_sharding_util::GetGatherScatterIndexPassthroughIndexDims( @@ -300,10 +305,6 @@ absl::StatusOr PartitionGatherIndexPassthroughDimensions( AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims(indices.sharding(), index_group_dims), output_grouped); - - if (indices.sharding().IsTileMaximal()) { - return nullptr; - } // See if we can group partially replicated dimensions from the operand // otherwise replicate it. const GroupedSharding operand_grouped = AlignGroupsWith( @@ -430,7 +431,7 @@ absl::StatusOr PartitionGatherTrivialSlicedOperandDimensions( }; SpmdBuilder* b = visitor->builder(); - GatherDimensionNumbers dnums = gather->gather_dimension_numbers(); + const GatherDimensionNumbers& dnums = gather->gather_dimension_numbers(); std::vector start_index_map(dnums.start_index_map().begin(), dnums.start_index_map().end()); if (std::optional> trivial_slice_dims = @@ -579,16 +580,20 @@ absl::StatusOr PartitionGatherTrivialSlicedOperandDimensions( return nullptr; } -// Partition a gather over a indices dimensions that are cosidered parallel -// (which means that the indices access the operand in a monotonically -// increasing way across the respective operand dimension referenced by the -// index). -absl::StatusOr PartitionGatherIndexParallelDimensions( +absl::StatusOr PartitionGatherParallelDimensions( const HloGatherInstruction* gather, PartitionedHlo operand, PartitionedHlo indices, const Shape& output_shape, const HloSharding& output_sharding, absl::Span batch_dims, absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, - bool allow_recursive) { + bool allow_recursive, + const hlo_sharding_util::GatherScatterParallelDims& parallel_dims, + bool need_offset) { + auto gather_sharding = GatherScatterOperandsShardedAcrossParallelDims( + *operand.hlo(), *indices.hlo(), parallel_dims); + if (!gather_sharding.has_value()) { + return nullptr; + } + // Perform clean up actions upon exiting function scope. absl::InlinedVector, 3> clean_ups; absl::Cleanup cleaner = [&clean_ups] { @@ -598,125 +603,177 @@ absl::StatusOr PartitionGatherIndexParallelDimensions( }; SpmdBuilder* b = visitor->builder(); - GatherDimensionNumbers dnums = gather->gather_dimension_numbers(); + const GatherDimensionNumbers& dnums = gather->gather_dimension_numbers(); const int64_t index_dim = dnums.index_vector_dim(); - // Handle the case where operand is tile maximal. In this case we check if - // the index is not TileMaximal and in this case we use the index sharding - // to drive the output sharding. - if (std::optional - parallel_dims = hlo_sharding_util::GetGatherParallelBatchDims( - *gather, visitor->call_graph())) { - if (auto gather_sharding = GatherScatterOperandsShardedAcrossParallelDims( - *operand.hlo(), *indices.hlo(), *parallel_dims)) { - const auto indices_parallel_dims = parallel_dims->indices_parallel_dims; - const auto operand_parallel_dims = parallel_dims->operand_parallel_dims; - const auto output_parallel_dims = - hlo_sharding_util::GetGatherParallelOutputDims(*gather, - *parallel_dims); - operand = operand.Reshard(gather_sharding->operand_sharding); - indices = indices.Reshard(gather_sharding->indices_sharding); - HloSharding gather_output_sharding = hlo_sharding_util:: - GatherOutputOrScatterUpdateShardingFromIndicesParallelDimensions( - indices.sharding(), output_shape.rank(), indices_parallel_dims, - output_parallel_dims); - - // Refine output sharding from the operand. it should be inferred from - // operand sharding, so that the partitioned gather can be either 1) - // directly created on the partitioned operand, or 2) recursively created - // without aligning the groups. - if (auto maybe_passthrough = hlo_sharding_util:: - GatherOutputShardingFromOperandOperandPassthroughDimensions( - operand.base_shape(), - hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( - operand.sharding(), operand_parallel_dims), - *gather, slice_sizes)) { - hlo_sharding_util::MergeShardingIfCompatible(*maybe_passthrough, - &gather_output_sharding); - } - // Construct the offsets for the operand sharding to be used to adjust - // the indices. Because we know the only dimensions partitioned are the - // parallel ones and because the partitioning is the same across indices - // and operands we can apply the offsets on the operands on the indices. - std::vector operand_offsets = MakePartitionOffsets( - operand.base_shape(), operand.sharding(), - operand.state().partition_id, b, operand_parallel_dims); - absl::InlinedVector index_offsets; - for (int start_idx = 0; start_idx < dnums.start_index_map_size(); - ++start_idx) { - HloInstruction* index_offset = - indices.rank() > index_dim - ? b->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(S32, {1}), - operand_offsets[dnums.start_index_map(start_idx)])) - : operand_offsets[dnums.start_index_map(start_idx)]; - index_offsets.push_back(index_offset); - } - HloInstruction* adjusted_indices = nullptr; - if (indices.rank() > index_dim) { - // Concatenate the offsets for the parallel dimensions to subtract. - adjusted_indices = b->AddInstruction(HloInstruction::CreateConcatenate( - ShapeUtil::MakeShape(S32, - {indices.base_shape().dimensions(index_dim)}), - index_offsets, 0)); - } else { - CHECK_EQ(index_offsets.size(), 1); - adjusted_indices = index_offsets[0]; - } - if (indices.hlo()->shape().element_type() != PrimitiveType::S32) { - adjusted_indices = b->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType(adjusted_indices->shape(), - indices.hlo()->shape().element_type()), - adjusted_indices)); - } - if (adjusted_indices->shape().rank() == 0) { - adjusted_indices = b->AddInstruction(HloInstruction::CreateBroadcast( - indices.hlo()->shape(), adjusted_indices, {})); - } else { - adjusted_indices = b->AddInstruction(HloInstruction::CreateBroadcast( - indices.hlo()->shape(), adjusted_indices, {index_dim})); - } - // Adjust indices by subtracting the offsets based on the partition id. - adjusted_indices = b->AddInstruction(HloInstruction::CreateBinary( - indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(), + + const auto& indices_parallel_dims = parallel_dims.indices_parallel_dims; + const auto& operand_parallel_dims = parallel_dims.operand_parallel_dims; + const auto output_parallel_dims = + hlo_sharding_util::GetGatherParallelOutputDims(*gather, parallel_dims); + operand = operand.Reshard(gather_sharding->operand_sharding); + indices = indices.Reshard(gather_sharding->indices_sharding); + HloSharding gather_output_sharding = hlo_sharding_util:: + GatherOutputOrScatterUpdateShardingFromIndicesParallelDimensions( + indices.sharding(), output_shape.rank(), indices_parallel_dims, + output_parallel_dims); + if (!need_offset) { + hlo_sharding_util::MergeShardingIfCompatible( + hlo_sharding_util:: + GatherOutputShardingFromIndexIndexPassthroughDimensions( + indices.sharding(), gather), + &gather_output_sharding); + } + + // Refine output sharding from the operand. it should be inferred from + // operand sharding, so that the partitioned gather can be either 1) + // directly created on the partitioned operand, or 2) recursively created + // without aligning the groups. + if (auto maybe_passthrough = hlo_sharding_util:: + GatherOutputShardingFromOperandOperandPassthroughDimensions( + operand.base_shape(), + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + operand.sharding(), operand_parallel_dims), + *gather, slice_sizes)) { + hlo_sharding_util::MergeShardingIfCompatible(*maybe_passthrough, + &gather_output_sharding); + } + + // Construct the offsets for the operand sharding to be used to adjust + // the indices. Because we know the only dimensions partitioned are the + // parallel ones and because the partitioning is the same across indices + // and operands we can apply the offsets on the operands on the indices. + PartitionedHlo new_indices = indices; + if (need_offset) { + std::vector operand_offsets = MakePartitionOffsets( + operand.base_shape(), operand.sharding(), operand.state().partition_id, + b, operand_parallel_dims); + absl::InlinedVector index_offsets; + for (int start_idx = 0; start_idx < dnums.start_index_map_size(); + ++start_idx) { + HloInstruction* index_offset = + indices.rank() > index_dim + ? b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(S32, {1}), + operand_offsets[dnums.start_index_map(start_idx)])) + : operand_offsets[dnums.start_index_map(start_idx)]; + index_offsets.push_back(index_offset); + } + HloInstruction* adjusted_indices = nullptr; + if (indices.rank() > index_dim) { + // Concatenate the offsets for the parallel dimensions to subtract. + adjusted_indices = b->AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(S32, + {indices.base_shape().dimensions(index_dim)}), + index_offsets, 0)); + } else { + CHECK_EQ(index_offsets.size(), 1); + adjusted_indices = index_offsets[0]; + } + if (indices.hlo()->shape().element_type() != PrimitiveType::S32) { + adjusted_indices = b->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(adjusted_indices->shape(), + indices.hlo()->shape().element_type()), adjusted_indices)); - PartitionedHlo new_indices = indices.CloneWithNewHlo(adjusted_indices); - const GroupedSharding new_indices_grouped = - hlo_sharding_util::GroupShardingOnDims(new_indices.sharding(), - indices_parallel_dims); - const GroupedSharding operand_grouped = - AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims( - operand.sharding(), operand_parallel_dims), - new_indices_grouped); - const GroupedSharding output_grouped = - AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims( - gather_output_sharding, output_parallel_dims), - new_indices_grouped); - PartitionedHlo per_group_operand = - PerGroupPartitionedHlo(operand, operand_grouped, b, clean_ups); - PartitionedHlo per_group_new_indices = PerGroupPartitionedHlo( - new_indices, new_indices_grouped, b, clean_ups); - const Shape pshape = GetPerGroupBaseShape(output_grouped, output_shape); - TF_ASSIGN_OR_RETURN( - HloInstruction * pgather, - PartitionGather(gather, per_group_operand, per_group_new_indices, - pshape, output_grouped.sharding, batch_dims, - slice_sizes, visitor, allow_recursive)); - if (allow_recursive) { - VLOG(5) << "[Gather partitioning]: Partitioned as parallel batch_dim"; - } - pgather->set_sharding(hlo_sharding_util::UngroupSharding(output_grouped)); - return PartitionedHlo(pgather, output_shape, operand.state()) - .Reshard(output_sharding) - .hlo(); } + if (adjusted_indices->shape().rank() == 0) { + adjusted_indices = b->AddInstruction(HloInstruction::CreateBroadcast( + indices.hlo()->shape(), adjusted_indices, {})); + } else { + adjusted_indices = b->AddInstruction(HloInstruction::CreateBroadcast( + indices.hlo()->shape(), adjusted_indices, {index_dim})); + } + // Adjust indices by subtracting the offsets based on the partition id. + adjusted_indices = b->AddInstruction(HloInstruction::CreateBinary( + indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(), + adjusted_indices)); + new_indices = indices.CloneWithNewHlo(adjusted_indices); } - return nullptr; + + const GroupedSharding new_indices_grouped = + hlo_sharding_util::GroupShardingOnDims(new_indices.sharding(), + indices_parallel_dims); + const GroupedSharding operand_grouped = + AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims( + operand.sharding(), operand_parallel_dims), + new_indices_grouped); + const GroupedSharding output_grouped = + AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims( + gather_output_sharding, output_parallel_dims), + new_indices_grouped); + PartitionedHlo per_group_operand = + PerGroupPartitionedHlo(operand, operand_grouped, b, clean_ups); + PartitionedHlo per_group_new_indices = + PerGroupPartitionedHlo(new_indices, new_indices_grouped, b, clean_ups); + const Shape pshape = GetPerGroupBaseShape(output_grouped, output_shape); + TF_ASSIGN_OR_RETURN( + HloInstruction * pgather, + PartitionGather(gather, per_group_operand, per_group_new_indices, pshape, + output_grouped.sharding, batch_dims, slice_sizes, visitor, + allow_recursive)); + if (allow_recursive) { + VLOG(5) << "[Gather partitioning]: Partitioned as parallel batch_dim"; + } + pgather->set_sharding(hlo_sharding_util::UngroupSharding(output_grouped)); + return PartitionedHlo(pgather, output_shape, operand.state()) + .Reshard(output_sharding) + .hlo(); +} + +// Partition a gather over indices dimensions that are considered parallel +// (which means that the indices access the operand in a monotonically +// increasing way across the respective operand dimension referenced by the +// index). +absl::StatusOr PartitionGatherIndexParallelDimensions( + const HloGatherInstruction* gather, PartitionedHlo operand, + PartitionedHlo indices, const Shape& output_shape, + const HloSharding& output_sharding, absl::Span batch_dims, + absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, + bool allow_recursive) { + std::optional parallel_dims = + hlo_sharding_util::GetGatherParallelBatchDims(*gather, + visitor->call_graph()); + if (!parallel_dims.has_value()) { + return nullptr; + } + return PartitionGatherParallelDimensions( + gather, operand, indices, output_shape, output_sharding, batch_dims, + slice_sizes, visitor, allow_recursive, *parallel_dims, + /*need_offset=*/true); +} + +// Partition a gather over explicit batch dimensions defined in +// operand_batching_dims and start_indices_batching_dims. +absl::StatusOr PartitionGatherExplicitBatchDimensions( + const HloGatherInstruction* gather, PartitionedHlo operand, + PartitionedHlo indices, const Shape& output_shape, + const HloSharding& output_sharding, absl::Span batch_dims, + absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, + bool allow_recursive) { + const GatherDimensionNumbers& dnums = gather->gather_dimension_numbers(); + if (dnums.operand_batching_dims().empty()) { + return nullptr; + } + + hlo_sharding_util::GatherScatterParallelDims parallel_dims; + parallel_dims.operand_parallel_dims.assign( + dnums.operand_batching_dims().begin(), + dnums.operand_batching_dims().end()); + parallel_dims.indices_parallel_dims.assign( + dnums.start_indices_batching_dims().begin(), + dnums.start_indices_batching_dims().end()); + + return PartitionGatherParallelDimensions( + gather, operand, indices, output_shape, output_sharding, batch_dims, + slice_sizes, visitor, allow_recursive, parallel_dims, + /*need_offset=*/false); } // Returns a full list of partitioning methods used for gather. std::vector> GatherPartitionMethods() { - return {{PartitionGatherIndexParallelDimensions, + return {{PartitionGatherExplicitBatchDimensions, + "PartitionGatherExplicitBatchDimensions"}, + {PartitionGatherIndexParallelDimensions, "PartitionGatherIndexParallelDimensions"}, {PartitionGatherOperandPassthroughDimensions, "PartitionGatherOperandPassthroughDimensions"}, @@ -729,6 +786,8 @@ GatherPartitionMethods() { // Helper function to get the gather partitioning method. decltype(PartitionGather)* GetGatherPartitionMethod(PartitioningMethod method) { switch (method) { + case PartitioningMethod::kExplicitBatch: + return PartitionGatherExplicitBatchDimensions; case PartitioningMethod::kIndexParallel: return PartitionGatherIndexParallelDimensions; case PartitioningMethod::kOperandPassthrough: @@ -738,7 +797,7 @@ decltype(PartitionGather)* GetGatherPartitionMethod(PartitioningMethod method) { case PartitioningMethod::kIndexPassthrough: return PartitionGatherIndexPassthroughDimensions; default: - return PartitionGatherIndexParallelDimensions; + return PartitionGatherExplicitBatchDimensions; } } @@ -1019,16 +1078,20 @@ absl::StatusOr PartitionScatter( absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, bool allow_recursive = true); -// Partition a scatter over a indices dimensions that are cosidered parallel -// (which means that the indices access the operand in a monotonically -// increasing way across the respective operand dimension referenced by the -// index). -absl::StatusOr PartitionScatterIndexParallelDimensions( +absl::StatusOr PartitionScatterParallelDimensions( const HloScatterInstruction* scatter, std::vector operands, PartitionedHlo indices, std::vector updates, const Shape& output_shape, const HloSharding& output_sharding, absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, - bool allow_recursive) { + bool allow_recursive, + const hlo_sharding_util::GatherScatterParallelDims& parallel_dims, + bool need_offset) { + auto scatter_sharding = GatherScatterOperandsShardedAcrossParallelDims( + *operands[0].hlo(), *indices.hlo(), parallel_dims); + if (!scatter_sharding) { + return nullptr; + } + // Perform clean up actions upon exiting function scope. absl::InlinedVector, 3> clean_ups; absl::Cleanup cleaner = [&clean_ups] { @@ -1038,136 +1101,184 @@ absl::StatusOr PartitionScatterIndexParallelDimensions( }; SpmdBuilder* b = visitor->builder(); - auto dnums = scatter->scatter_dimension_numbers(); + const auto& dnums = scatter->scatter_dimension_numbers(); const int64_t index_dim = dnums.index_vector_dim(); - // Handle the case where operand is tile maximal. In this case we check if - // the index is not TileMaximal and in this case we use the index sharding - // to drive the output sharding. - if (std::optional - parallel_dims = hlo_sharding_util::GetScatterParallelBatchDims( - *scatter, visitor->call_graph())) { - if (auto scatter_sharding = GatherScatterOperandsShardedAcrossParallelDims( - *operands[0].hlo(), *indices.hlo(), *parallel_dims)) { - const auto operand_parallel_dims = parallel_dims->operand_parallel_dims; - const auto indices_parallel_dims = parallel_dims->indices_parallel_dims; - const auto update_parallel_dims = - hlo_sharding_util::GetScatterParallelUpdateDims(*scatter, - *parallel_dims); - for (auto& operand : operands) { - operand = operand.Reshard(scatter_sharding->operand_sharding); - } - indices = indices.Reshard(scatter_sharding->indices_sharding); - HloSharding update_sharding = hlo_sharding_util:: - GatherOutputOrScatterUpdateShardingFromIndicesParallelDimensions( - indices.sharding(), updates[0].rank(), indices_parallel_dims, - update_parallel_dims); - // Refine update sharding from the operand. it should be inferred from - // operand sharding, so that the partitioned scatter can be either 1) - // directly created on the partitioned operand, or 2) recursively created - // without aligning the groups. - if (auto maybe_passthrough = hlo_sharding_util:: - ScatterUpdateShardingFromOutputOperandPassthroughDimensions( - operands[0].base_shape(), - hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( - operands[0].sharding(), operand_parallel_dims), - *scatter, slice_sizes)) { - hlo_sharding_util::MergeShardingIfCompatible(*maybe_passthrough, - &update_sharding); - } - for (auto& update : updates) { - update = update.Reshard(update_sharding); - } + const auto operand_parallel_dims = parallel_dims.operand_parallel_dims; + const auto indices_parallel_dims = parallel_dims.indices_parallel_dims; + const auto update_parallel_dims = + hlo_sharding_util::GetScatterParallelUpdateDims(*scatter, parallel_dims); + for (auto& operand : operands) { + operand = operand.Reshard(scatter_sharding->operand_sharding); + } + indices = indices.Reshard(scatter_sharding->indices_sharding); + HloSharding update_sharding = hlo_sharding_util:: + GatherOutputOrScatterUpdateShardingFromIndicesParallelDimensions( + indices.sharding(), updates[0].rank(), indices_parallel_dims, + update_parallel_dims); + if (!need_offset) { + hlo_sharding_util::MergeShardingIfCompatible( + hlo_sharding_util:: + ScatterUpdateShardingFromIndexIndexPassthroughDimensions( + indices.sharding(), scatter), + &update_sharding); + } - // Construct the offsets for the operand sharding to be used to adjust - // the indices. Because we know the only dimensions partitioned are the - // parallel ones and because the partitioning is the same across indices - // and operands we can apply the offsets on the operands on the indices. - std::vector operand_offsets = MakePartitionOffsets( - operands[0].base_shape(), operands[0].sharding(), - operands[0].state().partition_id, b, operand_parallel_dims); - absl::InlinedVector index_offsets; - for (int start_idx = 0; - start_idx < dnums.scatter_dims_to_operand_dims_size(); ++start_idx) { - HloInstruction* index_offset = - indices.base_shape().dimensions_size() > index_dim - ? b->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(S32, {1}), - operand_offsets[dnums.scatter_dims_to_operand_dims( - start_idx)])) - : operand_offsets[dnums.scatter_dims_to_operand_dims( - start_idx)]; - index_offsets.push_back(index_offset); - } - HloInstruction* adjusted_indices = nullptr; - if (indices.base_shape().dimensions_size() > index_dim) { - // Concatenate the offsets for the parallel dimensions to subtract. - adjusted_indices = b->AddInstruction(HloInstruction::CreateConcatenate( - ShapeUtil::MakeShape(S32, - {indices.base_shape().dimensions(index_dim)}), - index_offsets, 0)); - } else { - CHECK_EQ(index_offsets.size(), 1); - adjusted_indices = index_offsets[0]; - } - if (indices.hlo()->shape().element_type() != PrimitiveType::S32) { - adjusted_indices = b->AddInstruction(HloInstruction::CreateConvert( - ShapeUtil::ChangeElementType(adjusted_indices->shape(), - indices.hlo()->shape().element_type()), - adjusted_indices)); - } - if (adjusted_indices->shape().rank() == 0) { - adjusted_indices = b->AddInstruction(HloInstruction::CreateBroadcast( - indices.hlo()->shape(), adjusted_indices, {})); - } else { - adjusted_indices = b->AddInstruction(HloInstruction::CreateBroadcast( - indices.hlo()->shape(), adjusted_indices, {index_dim})); - } - // Adjust indices by subtracting the offsets based on the partition id. - adjusted_indices = b->AddInstruction(HloInstruction::CreateBinary( - indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(), + // Refine update sharding from the operand. it should be inferred from + // operand sharding, so that the partitioned scatter can be either 1) + // directly created on the partitioned operand, or 2) recursively created + // without aligning the groups. + if (auto maybe_passthrough = hlo_sharding_util:: + ScatterUpdateShardingFromOutputOperandPassthroughDimensions( + operands[0].base_shape(), + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + operands[0].sharding(), operand_parallel_dims), + *scatter, slice_sizes)) { + hlo_sharding_util::MergeShardingIfCompatible(*maybe_passthrough, + &update_sharding); + } + + for (auto& update : updates) { + update = update.Reshard(update_sharding); + } + + // Construct the offsets for the operand sharding to be used to adjust + // the indices. Because we know the only dimensions partitioned are the + // parallel ones and because the partitioning is the same across indices + // and operands we can apply the offsets on the operands on the indices. + PartitionedHlo new_indices = indices; + if (need_offset) { + std::vector operand_offsets = MakePartitionOffsets( + operands[0].base_shape(), operands[0].sharding(), + operands[0].state().partition_id, b, operand_parallel_dims); + absl::InlinedVector index_offsets; + for (int start_idx = 0; + start_idx < dnums.scatter_dims_to_operand_dims_size(); ++start_idx) { + HloInstruction* index_offset = + indices.base_shape().dimensions_size() > index_dim + ? b->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(S32, {1}), + operand_offsets[dnums.scatter_dims_to_operand_dims( + start_idx)])) + : operand_offsets[dnums.scatter_dims_to_operand_dims(start_idx)]; + index_offsets.push_back(index_offset); + } + HloInstruction* adjusted_indices = nullptr; + if (indices.base_shape().dimensions_size() > index_dim) { + // Concatenate the offsets for the parallel dimensions to subtract. + adjusted_indices = b->AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(S32, + {indices.base_shape().dimensions(index_dim)}), + index_offsets, 0)); + } else { + CHECK_EQ(index_offsets.size(), 1); + adjusted_indices = index_offsets[0]; + } + if (indices.hlo()->shape().element_type() != PrimitiveType::S32) { + adjusted_indices = b->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(adjusted_indices->shape(), + indices.hlo()->shape().element_type()), adjusted_indices)); - PartitionedHlo new_indices = indices.CloneWithNewHlo(adjusted_indices); - const GroupedSharding new_indices_grouped = - hlo_sharding_util::GroupShardingOnDims(new_indices.sharding(), - indices_parallel_dims); - const GroupedSharding operand_grouped = - AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims( - operands[0].sharding(), operand_parallel_dims), - new_indices_grouped); - const GroupedSharding update_grouped = - AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims( - updates[0].sharding(), update_parallel_dims), - new_indices_grouped); - const GroupedSharding& output_grouped = operand_grouped; - std::vector per_group_operands = - PerGroupPartitionedHlos(operands, operand_grouped, b, clean_ups); - std::vector per_group_updates = - PerGroupPartitionedHlos(updates, update_grouped, b, clean_ups); - PartitionedHlo per_group_new_indices = PerGroupPartitionedHlo( - new_indices, new_indices_grouped, b, clean_ups); - auto pshape = - MaybeGetTuplePerGroupBaseShape(output_grouped, output_shape); - TF_ASSIGN_OR_RETURN( - HloInstruction * pscatter, - PartitionScatter( - scatter, per_group_operands, per_group_new_indices, - per_group_updates, pshape, - HloSharding::Single(scatter->shape(), output_grouped.sharding), - slice_sizes, visitor, allow_recursive)); - pscatter->set_sharding(HloSharding::Single( - pscatter->shape(), - hlo_sharding_util::UngroupSharding(output_grouped))); - if (allow_recursive) { - VLOG(5) << "[Scatter partitioning]: Partitioned as index parallel"; - } - return PartitionedHlo(pscatter, output_shape, operands[0].state()) - .Reshard(output_sharding) - .hlo(); } + if (adjusted_indices->shape().rank() == 0) { + adjusted_indices = b->AddInstruction(HloInstruction::CreateBroadcast( + indices.hlo()->shape(), adjusted_indices, {})); + } else { + adjusted_indices = b->AddInstruction(HloInstruction::CreateBroadcast( + indices.hlo()->shape(), adjusted_indices, {index_dim})); + } + // Adjust indices by subtracting the offsets based on the partition id. + adjusted_indices = b->AddInstruction(HloInstruction::CreateBinary( + indices.hlo()->shape(), HloOpcode::kSubtract, indices.hlo(), + adjusted_indices)); + new_indices = indices.CloneWithNewHlo(adjusted_indices); } - return nullptr; + + const GroupedSharding new_indices_grouped = + hlo_sharding_util::GroupShardingOnDims(new_indices.sharding(), + indices_parallel_dims); + const GroupedSharding operand_grouped = + AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims( + operands[0].sharding(), operand_parallel_dims), + new_indices_grouped); + const GroupedSharding update_grouped = + AlignGroupsWith(hlo_sharding_util::GroupShardingOnDims( + updates[0].sharding(), update_parallel_dims), + new_indices_grouped); + const GroupedSharding& output_grouped = operand_grouped; + std::vector per_group_operands = + PerGroupPartitionedHlos(operands, operand_grouped, b, clean_ups); + std::vector per_group_updates = + PerGroupPartitionedHlos(updates, update_grouped, b, clean_ups); + PartitionedHlo per_group_new_indices = + PerGroupPartitionedHlo(new_indices, new_indices_grouped, b, clean_ups); + auto pshape = MaybeGetTuplePerGroupBaseShape(output_grouped, output_shape); + TF_ASSIGN_OR_RETURN( + HloInstruction * pscatter, + PartitionScatter( + scatter, per_group_operands, per_group_new_indices, per_group_updates, + pshape, + HloSharding::Single(scatter->shape(), output_grouped.sharding), + slice_sizes, visitor, allow_recursive)); + pscatter->set_sharding(HloSharding::Single( + pscatter->shape(), hlo_sharding_util::UngroupSharding(output_grouped))); + if (allow_recursive) { + VLOG(5) << "[Scatter partitioning]: Partitioned as index parallel"; + } + return PartitionedHlo(pscatter, output_shape, operands[0].state()) + .Reshard(output_sharding) + .hlo(); } + +// Partition a scatter over a indices dimensions that are cosidered parallel +// (which means that the indices access the operand in a monotonically +// increasing way across the respective operand dimension referenced by the +// index). +absl::StatusOr PartitionScatterIndexParallelDimensions( + const HloScatterInstruction* scatter, std::vector operands, + PartitionedHlo indices, std::vector updates, + const Shape& output_shape, const HloSharding& output_sharding, + absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, + bool allow_recursive) { + std::optional parallel_dims = + hlo_sharding_util::GetScatterParallelBatchDims(*scatter, + visitor->call_graph()); + if (!parallel_dims) { + return nullptr; + } + + return PartitionScatterParallelDimensions( + scatter, operands, indices, updates, output_shape, output_sharding, + slice_sizes, visitor, allow_recursive, *parallel_dims, true); +} + +// Partition a scatter over a indices dimensions that are cosidered parallel +// (which means that the indices access the operand in a monotonically +// increasing way across the respective operand dimension referenced by the +// index). +absl::StatusOr PartitionScatterExplicitBatchDimensions( + const HloScatterInstruction* scatter, std::vector operands, + PartitionedHlo indices, std::vector updates, + const Shape& output_shape, const HloSharding& output_sharding, + absl::Span slice_sizes, SpmdPartitioningVisitor* visitor, + bool allow_recursive) { + const ScatterDimensionNumbers& dnums = scatter->scatter_dimension_numbers(); + if (dnums.input_batching_dims().empty()) { + return nullptr; + } + + hlo_sharding_util::GatherScatterParallelDims parallel_dims; + parallel_dims.operand_parallel_dims.assign( + dnums.input_batching_dims().begin(), dnums.input_batching_dims().end()); + parallel_dims.indices_parallel_dims.assign( + dnums.scatter_indices_batching_dims().begin(), + dnums.scatter_indices_batching_dims().end()); + + return PartitionScatterParallelDimensions( + scatter, operands, indices, updates, output_shape, output_sharding, + slice_sizes, visitor, allow_recursive, parallel_dims, false); +} + // Perform partitioning of Scatter when the operand is split in a update window // dimension that is passed through (slice size is the same size of the operand // dimension). @@ -1276,7 +1387,7 @@ absl::StatusOr PartitionScatterIndexPassthroughDimensions( }; SpmdBuilder* b = visitor->builder(); - auto dnums = scatter->scatter_dimension_numbers(); + const auto& dnums = scatter->scatter_dimension_numbers(); // Parse non-variadic computation only. Vardiadic case will be replicated. const HloSharding original_indices_sharding = indices.sharding(); absl::InlinedVector index_group_dims = @@ -1410,7 +1521,7 @@ absl::StatusOr PartitionScatterTrivialSlicedOperandDimensions( }; SpmdBuilder* b = visitor->builder(); - auto dnums = scatter->scatter_dimension_numbers(); + const auto& dnums = scatter->scatter_dimension_numbers(); if (std::optional> trivial_slice_dims = GatherScatterOperandPartitionedOnTrivialSliceDims( operands[0], dnums.scatter_dims_to_operand_dims(), slice_sizes)) { @@ -1510,7 +1621,9 @@ absl::StatusOr PartitionScatterTrivialSlicedOperandDimensions( // Returns a full list of partitioning methods used for scatter. std::vector> ScatterPartitionMethods() { - return {{PartitionScatterIndexParallelDimensions, + return {{PartitionScatterExplicitBatchDimensions, + "PartitionScatterExplicitBatchDimensions"}, + {PartitionScatterIndexParallelDimensions, "PartitionScatterIndexParallelDimensions"}, {PartitionScatterOperandPassthroughDimensions, "PartitionScatterOperandPassthroughDimensions"}, @@ -1524,6 +1637,8 @@ ScatterPartitionMethods() { decltype(PartitionScatter)* GetScatterPartitionMethod( PartitioningMethod method) { switch (method) { + case PartitioningMethod::kExplicitBatch: + return PartitionScatterExplicitBatchDimensions; case PartitioningMethod::kIndexParallel: return PartitionScatterIndexParallelDimensions; case PartitioningMethod::kOperandPassthrough: @@ -1657,8 +1772,7 @@ absl::Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) { if (hlo->sharding().HasUniqueDevice()) { return DefaultAction(hlo); } - auto scatter = Cast(hlo); - auto dnums = scatter->scatter_dimension_numbers(); + const auto scatter = Cast(hlo); // Check all operands have the same shapes and shardings, and all updates have // the same shapes and shardings, and live with this assumption during scatter // partitioning. @@ -1724,7 +1838,8 @@ absl::Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) { } scatter_partition_method = options().scatter_partition_method; std::vector slice_sizes = hlo_sharding_util::GetScatterSliceSize( - operands[0].base_shape(), updates[0].base_shape(), dnums); + operands[0].base_shape(), updates[0].base_shape(), + scatter->scatter_dimension_numbers()); TF_ASSIGN_OR_RETURN( HloInstruction * pscatter, diff --git a/third_party/xla/xla/service/spmd/shardy/BUILD b/third_party/xla/xla/service/spmd/shardy/BUILD index b5d040c5562195..32bca413997b09 100644 --- a/third_party/xla/xla/service/spmd/shardy/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/BUILD @@ -23,30 +23,6 @@ package_group( ], ) -cc_library( - name = "shardy_call_inliner", - srcs = ["shardy_call_inliner.cc"], - hdrs = ["shardy_call_inliner.h"], - deps = [ - "//xla/hlo/ir:hlo", - "//xla/service:call_inliner", - "@com_google_absl//absl/strings", - ], -) - -xla_cc_test( - name = "shardy_call_inliner_test", - srcs = ["shardy_call_inliner_test.cc"], - deps = [ - ":shardy_call_inliner", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/log", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - ], -) - cc_library( name = "shardy_xla_pass", srcs = ["shardy_xla_pass.cc"], @@ -60,18 +36,18 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/transforms:hlo_dce", + "//xla/hlo/transforms:tuple_simplifier", + "//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", + "//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "//xla/hlo/utils:hlo_sharding_util", "//xla/mlir_hlo:mhlo_passes", "//xla/service:computation_layout", - "//xla/service:hlo_dce", "//xla/service:hlo_proto_cc", - "//xla/service:tuple_simplifier", "//xla/service/llvm_ir:llvm_util", "//xla/service/spmd/shardy/mhlo_round_trip:mhlo_export", "//xla/service/spmd/shardy/mhlo_round_trip:mhlo_import", "//xla/service/spmd/shardy/sdy_round_trip:pipelines", - "//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", - "//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "//xla/tsl/framework/mlir:status_scoped_diagnostic_handler", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -123,6 +99,7 @@ xla_cc_test( name = "shardy_xla_pass_test", srcs = ["shardy_xla_pass_test.cc"], deps = [ + ":constants", ":shardy_xla_pass", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", @@ -146,13 +123,16 @@ xla_cc_binary( "//xla/service/spmd/shardy/mhlo_round_trip:mhlo_import", "//xla/service/spmd/shardy/mhlo_round_trip:shard_map_export", "//xla/service/spmd/shardy/mhlo_round_trip:shard_map_import", - "//xla/service/spmd/shardy/round_trip_common:convert_sharding_custom_calls", + "//xla/service/spmd/shardy/round_trip_common:export_named_computations", + "//xla/service/spmd/shardy/round_trip_common:import_backend_func_calls", "//xla/service/spmd/shardy/round_trip_common:import_constants", + "//xla/service/spmd/shardy/round_trip_common:import_sdy_custom_calls", "//xla/service/spmd/shardy/round_trip_common:open_while_free_vars_sharding", "//xla/service/spmd/shardy/sdy_round_trip:export_ops", - "//xla/service/spmd/shardy/sdy_round_trip:export_shardings", - "//xla/service/spmd/shardy/sdy_round_trip:import_shardings", + "//xla/service/spmd/shardy/sdy_round_trip:export_shardy_attrs", + "//xla/service/spmd/shardy/sdy_round_trip:import_shardy_attrs", "//xla/service/spmd/shardy/sdy_round_trip:pipelines", + "//xla/service/spmd/shardy/sdy_round_trip:remove_size_one_axes", "//xla/service/spmd/shardy/sdy_round_trip:shard_map_export", "//xla/service/spmd/shardy/sdy_round_trip:shard_map_import", "//xla/service/spmd/shardy/sdy_round_trip/test_utils:mhlo_to_hlo_to_mhlo", diff --git a/third_party/xla/xla/service/spmd/shardy/constants.h b/third_party/xla/xla/service/spmd/shardy/constants.h index 6d92bbc1660ed2..220a43e1b48cc9 100644 --- a/third_party/xla/xla/service/spmd/shardy/constants.h +++ b/third_party/xla/xla/service/spmd/shardy/constants.h @@ -38,14 +38,23 @@ inline constexpr llvm::StringRef kSPMDShardToFullShapeCallTargetName = // The attribute name for backend config. inline constexpr llvm::StringRef kXlaBackendConfigAttr = "backend_config"; -// Attribute name for temporarily storing the Shardonnay sharding during HLO -// round-trip. It cannot match the name kShardingAttr ("sdy.sharding"), as -// during round-trip, going from HLO to MHLO, the code removes attributes -// in the `frontend_attributes` field, making them top level. And Shardonnay +// Attribute name for temporarily storing the Shardy sharding during HLO +// sdy-round-trip. It cannot match the name `kShardingAttr` ("sdy.sharding"), as +// during sdy-round-trip, going from HLO to MHLO, the code removes attributes +// in the `frontend_attributes` field, making them top level. And Shardy // verification expects `kShardingAttr` to be of type // TensorShardingAttr/TensorShardingPerValueAttr - not a StringAttr. inline constexpr llvm::StringRef kShardingRoundTripAttr = "xla.sdy.sharding"; +// Attribute name for temporarily storing the Shardy sharding rule during HLO +// sdy-round-trip. It cannot match the name `kShardingRuleAttr` +// ("sdy.sharding_rule"), as during sdy-round-trip, going from HLO to MHLO, the +// code removes attributes in the `frontend_attributes` field, making them top +// level. And Shardy verification expects `kShardingRuleAttr` to be of type +// OpShardingRuleAttr - not a StringAttr. +inline constexpr llvm::StringRef kShardingRuleRoundTripAttr = + "xla.sdy.sharding_rule"; + // Attribute name for temporarily storing the Shardonnay meshes during HLO // round-trip. inline constexpr llvm::StringRef kMeshesRoundTripAttr = "xla.sdy.meshes"; @@ -55,16 +64,23 @@ inline constexpr llvm::StringRef kMeshesRoundTripAttr = "xla.sdy.meshes"; inline constexpr llvm::StringRef kFuncResultShardingTargetName = "xla.sdy.FuncResultSharding"; +// The target name of the ShardingGroup custom call. +inline constexpr llvm::StringRef kShardingGroupCustomCallTargetName = + "xla.sdy.ShardingGroup"; + +// Sharding group id attribute name. The attribute will be of type `int64_t` +// and will be used to identify a group of ops that should be sharded together. +inline constexpr llvm::StringRef kShardingGroupIdAttr = + "xla.sdy.sharding_group_id"; + // Attribute name for storing frontend attributes in XLA. inline constexpr llvm::StringRef kFrontendAttributesAttr = "mhlo.frontend_attributes"; -// Attribute name for determining whether the frontend Python framework has -// lowered to SDY collectives and has exported them using -// `SdyRoundTripExportPipeline`. -// TODO(bartchr): remove this when JAX & PartIR integration is complete. -inline constexpr llvm::StringRef kPythonIntegrationComplete = - "xla.sdy.python_integration_complete"; +// Attribute name for determining whether we need to import MHLO shardings, +// i.e., the input module doesn't contain SDY shardings as frontend attributes. +inline constexpr llvm::StringRef kImportMhloShardings = + "xla.sdy.import_mhlo_shardings"; // Attribute name for determining whether tuple parameters should be used for // the rest of the XLA pipeline. @@ -81,17 +97,21 @@ inline constexpr llvm::StringRef kOutShardings = "xla.sdy.out_shardings"; // Attribute name for the manual axes of a `ManualComputationOp`. inline constexpr llvm::StringRef kManualAxes = "xla.sdy.manual_axes"; -// The target name of the custom call that will store the various attrs of a -// `ManualComputationOp` and a reference to a `FuncOp` that is the body of the -// original `ManualComputationOp`. -inline constexpr llvm::StringRef kManualComputationCustomCallTargetName = - "xla.sdy.ManualComputation"; - // The function name of the of the body of a `ManualComputationOp` during Shardy // round tripping. Used inline constexpr llvm::StringRef kManualComputationBodyFuncName = "xla.sdy.manual_computation_body"; +// The target name of the custom call that changes operands from global to local +// shape during Shardy round tripping. +inline constexpr llvm::StringRef kGlobalToLocalShapeCallTargetName = + "xla.sdy.GlobalToLocalShape"; + +// The target name of the custom call that changes results from local to global +// shape during Shardy round tripping. +inline constexpr llvm::StringRef kLocalToGlobalShapeCallTargetName = + "xla.sdy.LocalToGlobalShape"; + // The name of the global mesh. inline constexpr llvm::StringRef kGlobalMeshName = "mesh"; diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/BUILD b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/BUILD index d2c8a4daa2318a..eb7756caa77127 100644 --- a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/BUILD @@ -1,6 +1,7 @@ # Import/Export passes for going from `sdy.sharding`s to `mhlo.sharding`s and vice versa. load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla:xla.bzl", "xla_cc_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -25,11 +26,10 @@ cc_library( "//xla:array", "//xla:shape_util", "//xla/hlo/ir:hlo", + "//xla/hlo/translate/mhlo_to_hlo:type_to_shape", "//xla/mlir_hlo", "//xla/service/spmd/shardy:constants", - "//xla/translate/mhlo_to_hlo:type_to_shape", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log:check", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -87,6 +87,7 @@ cc_library( ":export_ops", ":export_shardings", ":shard_map_export", + "//xla/service/spmd/shardy/round_trip_common:export_named_computations", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", ], @@ -102,11 +103,11 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/ir:tile_assignment", + "//xla/hlo/translate/mhlo_to_hlo:attribute_exporter", "//xla/mlir_hlo", "//xla/mlir_hlo:mhlo_passes", "//xla/service/spmd/shardy:constants", "//xla/service/spmd/shardy/round_trip_common:pipeline_passes", - "//xla/translate/mhlo_to_hlo:attribute_exporter", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", @@ -121,6 +122,25 @@ cc_library( ], ) +xla_cc_test( + name = "mhlo_import_test", + srcs = ["mhlo_import_test.cc"], + deps = [ + ":mhlo_import", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:test", + "@shardy//shardy/dialect/sdy/ir:dialect", + "@shardy//shardy/dialect/sdy/ir:register", + ], +) + cc_library( name = "shard_map_import", srcs = ["shard_map_import.cc"], diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc index 0ffff7134c61a6..fbc7beca1bf085 100644 --- a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc +++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc @@ -44,6 +44,7 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" #include "shardy/dialect/sdy/ir/constants.h" #include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/dialect/sdy/ir/utils.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/spmd/shardy/constants.h" #include "xla/sharding_op_util.h" @@ -99,8 +100,7 @@ class ReshardPattern : public OpConversionPattern { rewriter.replaceOpWithNewOp(op, adaptor.getInput()); TensorShardingAttr sdySharding = adaptor.getShardingAttr(); - copyOp->setAttr(kShardingAttr, TensorShardingPerValueAttr::get( - op.getContext(), sdySharding)); + mlir::sdy::setShardings(copyOp, sdySharding); SmallVector unspecifiedDims; for (auto [dim, dimSharding] : diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_shardings.cc b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_shardings.cc index 3ab00020fc21d8..7f697980300703 100644 --- a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_shardings.cc +++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_shardings.cc @@ -58,11 +58,11 @@ limitations under the License. #include "shardy/dialect/sdy/ir/utils.h" #include "xla/array.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/spmd/shardy/constants.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" namespace xla { namespace sdy { @@ -152,7 +152,7 @@ LogicalResult exportFunc(FuncOp funcOp, const SymbolTable& symbolTable, }; std::function getMeshAttr = [&](TensorShardingAttr sharding) { - return mlir::sdy::getMeshAttr(symbolTable, sharding.getMeshName()); + return sharding.getMesh(symbolTable); }; for (int64_t argNum = 0; argNum < funcOp.getNumArguments(); ++argNum) { @@ -177,11 +177,11 @@ LogicalResult exportFunc(FuncOp funcOp, const SymbolTable& symbolTable, } funcOp.front().walk([&](Operation* op) { - if (auto shardingPerValue = - op->getAttrOfType(kShardingAttr)) { - op->setAttr(kXlaShardingAttr, - convertToHloShardingAttr(op, shardingPerValue.getShardings(), - getMeshAttr, getStringAttr)); + if (ArrayRef shardings = mlir::sdy::getShardings(op); + !shardings.empty()) { + op->setAttr( + kXlaShardingAttr, + convertToHloShardingAttr(op, shardings, getMeshAttr, getStringAttr)); op->removeAttr(kShardingAttr); } }); diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.cc b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.cc index a4f48c59943275..36aee9a64f266b 100644 --- a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.cc +++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.cc @@ -23,6 +23,7 @@ limitations under the License. #include "xla/service/spmd/shardy/mhlo_round_trip/export_ops.h" #include "xla/service/spmd/shardy/mhlo_round_trip/export_shardings.h" #include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.h" +#include "xla/service/spmd/shardy/round_trip_common/export_named_computations.h" namespace xla { namespace sdy { @@ -33,6 +34,7 @@ void addMhloExportPipeline(mlir::OpPassManager& pm) { // shouldn't be applied before converting to HLO as they apply folding. pm.addPass(createExportOpsPass()); pm.addPass(createMhloRoundTripShardMapExportPass()); + pm.addPass(createExportNamedComputationsPass()); pm.addPass(createExportMhloShardingsPass()); } diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc index 02553d14e9a98b..1f0cff4c61a75c 100644 --- a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc +++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc @@ -56,14 +56,15 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" #include "shardy/dialect/sdy/ir/constants.h" #include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/dialect/sdy/ir/utils.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/tile_assignment.h" +#include "xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/spmd/shardy/constants.h" #include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_import.h" #include "xla/service/spmd/shardy/round_trip_common/pipeline_passes.h" -#include "xla/translate/mhlo_to_hlo/attribute_exporter.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -95,7 +96,6 @@ using ::mlir::sdy::MeshAxisAttr; using ::mlir::sdy::MeshOp; using ::mlir::sdy::SdyDialect; using ::mlir::sdy::TensorShardingAttr; -using ::mlir::sdy::TensorShardingPerValueAttr; // The information of a sub-dimension in IotaTileAssignment. One tile dimension // in tile assignment may correspond to multiple sub-dimensions. See @@ -441,6 +441,9 @@ TensorShardingAttr convertToSdySharding( // break it when we find common mesh axes. while (product < localAxisSize) { MeshAxisAttr axisAttr = globalMesh.getAxes()[globalAxisIndex++]; + if (axisAttr.getSize() == 1) { + continue; + } globalAxes.push_back(AxisRefAttr::get(ctx, axisAttr.getName())); product *= axisAttr.getSize(); } @@ -550,8 +553,7 @@ LogicalResult importShardings( mlir::cast(resType).getRank(), /*openDims=*/false)); } - op->setAttr(kShardingAttr, TensorShardingPerValueAttr::get( - globalMesh.getContext(), newShardings)); + mlir::sdy::setShardings(op, newShardings); op->removeAttr(kXlaShardingAttr); } }); diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import_test.cc b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import_test.cc new file mode 100644 index 00000000000000..f1635a804ab332 --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import_test.cc @@ -0,0 +1,76 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.h" + +#include + +#include +#include "llvm/ADT/DenseMap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Support/LLVM.h" +#include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/dialect/sdy/ir/register.h" +#include "shardy/dialect/sdy/ir/utils.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "tsl/platform/test.h" + +namespace mlir::sdy { + +namespace { + +TEST(MhloImportTest, SkipFirstAxisOfSize1) { + MLIRContext context; + loadAllRequiredDialects(&context); + SmallVector axes; + axes.emplace_back(mlir::sdy::MeshAxisAttr::get(&context, "x", 1)); + axes.emplace_back(mlir::sdy::MeshAxisAttr::get(&context, "y", 4)); + axes.emplace_back(mlir::sdy::MeshAxisAttr::get(&context, "z", 2)); + auto mesh = sdy::MeshAttr::get(&context, axes); + + TensorShardingAttr sharding = xla::sdy::convertToSdySharding( + /*hloSharding=*/xla::HloSharding::IotaTile({4, 2}), + /*globalMesh=*/mesh, + /*deviceIdToMaximalMeshName=*/ + llvm::SmallDenseMap(), /*rank=*/2, + /*openDims=*/true); + EXPECT_EQ(attributeToString(sharding), + "#sdy.sharding<@mesh, [{\"y\", ?}, {\"z\", ?}]>"); +} + +// As above, but the middle axis is the one with size 1. +TEST(MhloImportTest, SkipSecondAxisOfSize1) { + MLIRContext context; + loadAllRequiredDialects(&context); + SmallVector axes; + axes.emplace_back(mlir::sdy::MeshAxisAttr::get(&context, "y", 4)); + axes.emplace_back(mlir::sdy::MeshAxisAttr::get(&context, "x", 1)); + axes.emplace_back(mlir::sdy::MeshAxisAttr::get(&context, "z", 2)); + auto mesh = sdy::MeshAttr::get(&context, axes); + + TensorShardingAttr sharding = xla::sdy::convertToSdySharding( + /*hloSharding=*/xla::HloSharding::IotaTile({4, 2}), + /*globalMesh=*/mesh, + /*deviceIdToMaximalMeshName=*/ + llvm::SmallDenseMap(), /*rank=*/2, + /*openDims=*/true); + EXPECT_EQ(attributeToString(sharding), + "#sdy.sharding<@mesh, [{\"y\", ?}, {\"z\", ?}]>"); +} + +} // namespace +} // namespace mlir::sdy diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/shard_map_import.cc b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/shard_map_import.cc index a24c797a7ec4ca..a8098832a71d5a 100644 --- a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/shard_map_import.cc +++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/shard_map_import.cc @@ -140,36 +140,6 @@ void insertAxisNamesFromSharding(mlir::MLIRContext* context, } } -// Let axesInOldSharding be the axes used in `sharding`, which should be a -// subset of `axesSet`. Append the set difference `axesSet - axesInOldSharding` -// in the replicated axes of `sharding`. -TensorShardingAttr appendReplicatedAxes( - mlir::MLIRContext* context, TensorShardingAttr sharding, - const llvm::SmallDenseSet& axesSet, MeshAttr mesh) { - llvm::SmallDenseSet axesInOldSharding; - insertAxisNamesFromSharding(context, sharding, axesInOldSharding); - - // `axesInOldSharding` is a subset of `axesSet` since `axesSet` is the union - // of axes in all in/out shardings. - CHECK(llvm::set_is_subset(axesInOldSharding, axesSet)); - if (axesInOldSharding.size() == axesSet.size()) { - return sharding; - } - - SmallVector newReplicatedAxes(sharding.getReplicatedAxes()); - newReplicatedAxes.reserve(newReplicatedAxes.size() + axesSet.size() - - axesInOldSharding.size()); - for (StringAttr axis : axesSet) { - if (!axesInOldSharding.contains(axis)) { - newReplicatedAxes.push_back(AxisRefAttr::get(context, axis)); - } - } - llvm::sort(newReplicatedAxes, AxisRefAttr::getMeshComparator(mesh)); - - return TensorShardingAttr::get(context, sharding.getMeshName(), - sharding.getDimShardings(), newReplicatedAxes); -} - // Assumptions to confirm this is a shard_map pattern: // 1. All operands are the result of a `SPMDFullToShardShape` custom call, which // is the result of a `Sharding` custom call. @@ -365,29 +335,18 @@ class ShardMapImportPass auto inOutShardings = llvm::concat(inShardings, outShardings); // All in/out shardings must refer to the same mesh. - const mlir::FlatSymbolRefAttr meshName = - inOutShardings.begin()->getMeshSymName(); - if (absl::c_any_of(inOutShardings, - [&meshName](TensorShardingAttr sharding) { - return sharding.getMeshSymName() != meshName; - })) { + MeshAttr mesh = mlir::sdy::getCommonMesh(inShardings, outShardings, op); + if (!mesh) { op.emitError("Multiple meshes in a single manual computation."); success = false; return mlir::WalkResult::interrupt(); } - MeshAttr mesh = sdy::getMeshAttr(op, meshName); // Manual axes are the union of the axes in the in/out shardings. llvm::SmallDenseSet manualAxesSet; for (TensorShardingAttr tensorSharding : inOutShardings) { insertAxisNamesFromSharding(context, tensorSharding, manualAxesSet); } - // Update the in/out shardings with manual axes. Append the unused - // manual axes in the list of explicitly replicated axes. - for (TensorShardingAttr& tensorSharding : inOutShardings) { - tensorSharding = appendReplicatedAxes(context, tensorSharding, - manualAxesSet, mesh); - } manualAxes.assign(manualAxesSet.begin(), manualAxesSet.end()); llvm::sort(manualAxes, mesh.getAxisNameComparator()); diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/BUILD b/third_party/xla/xla/service/spmd/shardy/round_trip_common/BUILD index ef04dfd7a00d6a..9ca8930594473c 100644 --- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/BUILD @@ -14,13 +14,14 @@ package_group( ) cc_library( - name = "convert_sharding_custom_calls", - srcs = ["convert_sharding_custom_calls.cc"], - hdrs = ["convert_sharding_custom_calls.h"], + name = "import_sdy_custom_calls", + srcs = ["import_sdy_custom_calls.cc"], + hdrs = ["import_sdy_custom_calls.h"], deps = [ "//xla:sharding_op_util", "//xla/mlir_hlo", "//xla/service/spmd/shardy:constants", + "//xla/service/spmd/shardy:utils", "@com_google_absl//absl/log:check", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", @@ -31,6 +32,39 @@ cc_library( ], ) +cc_library( + name = "export_named_computations", + srcs = ["export_named_computations.cc"], + hdrs = ["export_named_computations.h"], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@shardy//shardy/dialect/sdy/ir:dialect", + ], +) + +cc_library( + name = "import_backend_func_calls", + srcs = ["import_backend_func_calls.cc"], + hdrs = ["import_backend_func_calls.h"], + deps = [ + "//xla/service/spmd/shardy:constants", + "//xla/service/spmd/shardy:utils", + "@com_google_absl//absl/log:check", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@shardy//shardy/dialect/sdy/ir:dialect", + ], +) + cc_library( name = "import_constants", srcs = ["import_constants.cc"], @@ -68,8 +102,9 @@ cc_library( srcs = ["pipeline_passes.cc"], hdrs = ["pipeline_passes.h"], deps = [ - ":convert_sharding_custom_calls", + ":import_backend_func_calls", ":import_constants", + ":import_sdy_custom_calls", ":open_while_free_vars_sharding", "//xla/mlir_hlo:mhlo_passes", "@llvm-project//mlir:FuncDialect", diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/convert_sharding_custom_calls.cc b/third_party/xla/xla/service/spmd/shardy/round_trip_common/convert_sharding_custom_calls.cc deleted file mode 100644 index 0053512b9b89d7..00000000000000 --- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/convert_sharding_custom_calls.cc +++ /dev/null @@ -1,146 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/spmd/shardy/round_trip_common/convert_sharding_custom_calls.h" - -#include -#include -#include -#include -#include - -#include "absl/log/check.h" -#include "llvm/ADT/StringRef.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Support/TypeID.h" -#include "mlir/Transforms/DialectConversion.h" -#include "shardy/dialect/sdy/ir/constants.h" -#include "shardy/dialect/sdy/ir/dialect.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/service/spmd/shardy/constants.h" -#include "xla/sharding_op_util.h" - -namespace xla { -namespace sdy { - -namespace { - -using ::mlir::StringRef; - -using ::mlir::mhlo::CustomCallOp; - -using ::mlir::sdy::kShardingAttr; -using ::mlir::sdy::ShardingConstraintOp; -using ::mlir::sdy::TensorShardingAttr; -using ::mlir::sdy::TensorShardingPerValueAttr; - -class ShardingCustomCallPattern - : public mlir::OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult matchAndRewrite( - CustomCallOp op, OpAdaptor adaptor, - mlir::ConversionPatternRewriter& rewriter) const override { - if (op.getCallTargetName() != kShardingCustomCallTargetName) { - return rewriter.notifyMatchFailure( - op, "expected CustomCallOp with target name " + - kShardingCustomCallTargetName.str()); - } - - CHECK_EQ(op.getNumOperands(), 1); - - std::vector unspecDims; - if (std::optional backendConfig = op.getBackendConfig()) { - CHECK_OK(xla::sharding_op_util::ParseAttributes( - mlir::dyn_cast(*backendConfig).getValue(), - &unspecDims)); - } - - auto shardingPerValue = - op->getAttrOfType(kShardingAttr); - if (!shardingPerValue) { - op.emitError() << "expected CustomCallOp with sharding attribute"; - return mlir::failure(); - } - if (shardingPerValue.size() != 1) { - op.emitError() << "expected CustomCallOp with exactly one sharding " - "attribute"; - return mlir::failure(); - } - TensorShardingAttr sharding = shardingPerValue.getShardings().front(); - - if (!unspecDims.empty()) { - sharding = sharding.openShardingDims(unspecDims); - } - - rewriter.replaceOpWithNewOp( - op, adaptor.getInputs().front(), sharding); - - return mlir::success(); - } -}; - -// Converts a CustomCall with target name Sharding into a -// ShardingConstraintOp. -class ConvertShardingCustomCallsPass - : public mlir::PassWrapper> { - public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertShardingCustomCallsPass) - - void runOnOperation() final { - mlir::MLIRContext& context = getContext(); - mlir::ConversionTarget target(context); - target.addLegalDialect(); - target.addDynamicallyLegalOp([](CustomCallOp op) { - return op.getCallTargetName() != kShardingCustomCallTargetName; - }); - mlir::RewritePatternSet patterns(&context); - patterns.add(&context); - if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, - std::move(patterns)))) { - signalPassFailure(); - } - } - - StringRef getArgument() const override { - return "xla-sdy-convert-sharding-custom-calls"; - } - - StringRef getDescription() const override { - return "Converts a CustomCall with target name Sharding into a " - "ShardingConstraintOp."; - } -}; - -} // namespace - -std::unique_ptr createConvertShardingCustomCallsPass() { - return std::make_unique(); -} - -void registerConvertShardingCustomCallsPass() { - mlir::registerPass(createConvertShardingCustomCallsPass); -} - -} // namespace sdy -} // namespace xla diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/export_named_computations.cc b/third_party/xla/xla/service/spmd/shardy/round_trip_common/export_named_computations.cc new file mode 100644 index 00000000000000..4a94c676ae1e76 --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/export_named_computations.cc @@ -0,0 +1,135 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/spmd/shardy/round_trip_common/export_named_computations.h" + +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" +#include "mlir/Transforms/DialectConversion.h" +#include "shardy/dialect/sdy/ir/constants.h" +#include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/dialect/sdy/ir/utils.h" + +namespace xla { +namespace sdy { + +namespace { + +using ::mlir::ArrayAttr; +using ::mlir::ModuleOp; +using ::mlir::NamedAttribute; +using ::mlir::StringRef; +using ::mlir::SymbolTable; +using ::mlir::func::CallOp; +using ::mlir::func::FuncOp; +using ::mlir::sdy::kShardingAttr; +using ::mlir::sdy::NamedComputationOp; +using ::mlir::sdy::TensorShardingPerValueAttr; + +// Converts a `NamedComputationOp` into a `CallOp`. +class ExportNamedComputationsPass + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ExportNamedComputationsPass) + + void runOnOperation() final { + ModuleOp moduleOp = getOperation(); + SymbolTable symbolTable(moduleOp); + mlir::Block& moduleBlock = moduleOp.getRegion().front(); + getOperation()->walk([&](NamedComputationOp namedComputationOp) { + mlir::IRRewriter rewriter(namedComputationOp); + rewriter.setInsertionPointToEnd(&moduleBlock); + auto funcOp = rewriter.create( + namedComputationOp.getLoc(), namedComputationOp.getName(), + rewriter.getFunctionType( + namedComputationOp.getBody().getArgumentTypes(), + namedComputationOp.getResultTypes()), + rewriter.getStringAttr("private"), + /*argAttrs=*/ArrayAttr(), /*resultAttrs=*/ArrayAttr()); + rewriter.setInsertionPointToStart(funcOp->getBlock()); + mlir::sdy::inlineRegionAndConvertTerminatorOp( + namedComputationOp.getBody(), funcOp.getBody()); + rewriter.setInsertionPoint(namedComputationOp); + + // Copy the input shardings to the func. + if (std::optional inShardings = + namedComputationOp.getInShardings()) { + for (auto [arg, sharding] : llvm::zip_equal( + funcOp.getArguments(), inShardings->getShardings())) { + setSharding(arg, sharding); + } + } + + // Copy the output shardings to the func AND call. + mlir::SmallVector callOpAttrs( + namedComputationOp->getDiscardableAttrs()); + if (std::optional outShardings = + namedComputationOp.getOutShardings()) { + for (auto [i, sharding] : + llvm::enumerate(outShardings->getShardings())) { + funcOp.setResultAttr(i, kShardingAttr, sharding); + } + callOpAttrs.push_back(NamedAttribute( + rewriter.getStringAttr(kShardingAttr), *outShardings)); + } + + mlir::StringAttr funcName = symbolTable.insert(funcOp); + auto callOp = rewriter.replaceOpWithNewOp( + namedComputationOp, namedComputationOp.getResultTypes(), funcName, + namedComputationOp.getOperands()); + callOp->setAttrs(callOpAttrs); + }); + } + + StringRef getArgument() const override { + return "xla-sdy-export-named-computations"; + } + + StringRef getDescription() const override { + return "Creates a pass that converts a `NamedComputationOp` with a " + "`to a `CallOp` with a new private function " + "called the `NamedComputationOp`'s `name`. The new `FuncOp` and " + "`CallOp` have the same shardings as the original " + "`NamedComputationOp`s operands/results."; + } +}; + +} // namespace + +std::unique_ptr createExportNamedComputationsPass() { + return std::make_unique(); +} + +void registerExportNamedComputationsPass() { + mlir::registerPass(createExportNamedComputationsPass); +} + +} // namespace sdy +} // namespace xla diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/export_named_computations.h b/third_party/xla/xla/service/spmd/shardy/round_trip_common/export_named_computations.h new file mode 100644 index 00000000000000..cdd8fd42c27c20 --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/export_named_computations.h @@ -0,0 +1,41 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_EXPORT_NAMED_COMPUTATIONS_H_ +#define XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_EXPORT_NAMED_COMPUTATIONS_H_ + +#include + +#include "mlir/Pass/Pass.h" + +namespace xla { +namespace sdy { + +// Creates a pass that converts a `NamedComputationOp` to a `CallOp` with a new +// private function called the `NamedComputationOp`'s `name`. The new `FuncOp` +// and `CallOp` have the same shardings as the original `NamedComputationOp`s +// operands/results. +// +// If there is a function with the same name as the `NamedComputationOp` in the +// module, the MLIR symbol table will change it to `{name}_#`. +std::unique_ptr createExportNamedComputationsPass(); + +// Register the xla-sdy-export-named-computations pass. +void registerExportNamedComputationsPass(); + +} // namespace sdy +} // namespace xla + +#endif // XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_EXPORT_NAMED_COMPUTATIONS_H_ diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_backend_func_calls.cc b/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_backend_func_calls.cc new file mode 100644 index 00000000000000..57a50d928d3bde --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_backend_func_calls.cc @@ -0,0 +1,157 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/spmd/shardy/round_trip_common/import_backend_func_calls.h" + +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Support/TypeID.h" +#include "mlir/Transforms/DialectConversion.h" +#include "shardy/dialect/sdy/ir/constants.h" +#include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/dialect/sdy/ir/utils.h" +#include "xla/service/spmd/shardy/constants.h" +#include "xla/service/spmd/shardy/utils.h" + +namespace xla { +namespace sdy { + +namespace { + +using ::mlir::MLIRContext; +using ::mlir::OpConversionPattern; +using ::mlir::StringRef; +using ::mlir::SymbolTable; +using ::mlir::func::CallOp; +using ::mlir::func::FuncOp; +using ::mlir::sdy::kShardingAttr; +using ::mlir::sdy::NamedComputationOp; + +class BackendFuncCallPattern : public OpConversionPattern { + public: + explicit BackendFuncCallPattern(MLIRContext* context, + const SymbolTable& symbolTable) + : OpConversionPattern(context), symbolTable(symbolTable) {} + + mlir::LogicalResult matchAndRewrite( + CallOp callOp, OpAdaptor adaptor, + mlir::ConversionPatternRewriter& rewriter) const override { + if (!hasFrontendAttr(callOp, kXlaBackendConfigAttr)) { + return mlir::failure(); + } + + FuncOp func = symbolTable.lookup(adaptor.getCallee()); + CHECK(func) << "Failed to lookup function: " + << std::string_view(adaptor.getCallee()); + mlir::SmallVector namedCompAttrs; + llvm::copy_if(callOp->getDiscardableAttrs(), + std::back_inserter(namedCompAttrs), + [](const mlir::NamedAttribute& attr) { + return attr.getName() != kShardingAttr; + }); + + auto namedCompOp = rewriter.replaceOpWithNewOp( + callOp, callOp->getResultTypes(), adaptor.getCallee(), + adaptor.getOperands(), /*inShardings=*/nullptr, + /*outShardings=*/mlir::sdy::getShardingPerValue(callOp)); + namedCompOp->setAttrs(namedCompAttrs); + if (func.getBody().empty()) { + return rewriter.notifyMatchFailure(callOp, [](mlir::Diagnostic& diag) { + diag << "Tried to use an already inlined FuncOp. Expected each CallOp " + "with backend_config to have a unique FuncOp."; + }); + } + + mlir::sdy::inlineRegionAndConvertTerminatorOp( + func.getBody(), namedCompOp.getRegion(), rewriter); + rewriter.eraseOp(func); + + return mlir::success(); + } + + private: + const SymbolTable& symbolTable; +}; + +// Converts a `CallOp` with `backend_config` into a `NamedComputationOp`. +class ImportBackendFuncCallsPass + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ImportBackendFuncCallsPass) + + void runOnOperation() final { + // NOTE: Assume that there is a unique callee for each caller. So no need to + // do a walk and copy the callees if there are multiple callers for the + // callee. + mlir::MLIRContext& context = getContext(); + mlir::ConversionTarget target(context); + target.addLegalOp(); + SymbolTable symbolTable(getOperation()); + target.addDynamicallyLegalOp([&](CallOp op) { + // In case the assumption that each host-callback caller has a unique + // callee is not true, and an optimized build is being run without + // verification, make sure that the callee is a function that exists. + return !hasFrontendAttr(op, kXlaBackendConfigAttr); + }); + mlir::RewritePatternSet patterns(&context); + patterns.add(&context, symbolTable); + if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + signalPassFailure(); + } + } + + StringRef getArgument() const override { + return "xla-sdy-import-backend-func-calls"; + } + + StringRef getDescription() const override { + return "Creates a pass that converts a `CallOp` with a `backend_config` " + "attr to a `NamedComputationOp` with the function body inlined and " + "name of the callee."; + } +}; + +} // namespace + +std::unique_ptr createImportBackendFuncCallsPass() { + return std::make_unique(); +} + +void registerImportBackendFuncCallsPass() { + mlir::registerPass(createImportBackendFuncCallsPass); +} + +} // namespace sdy +} // namespace xla diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_backend_func_calls.h b/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_backend_func_calls.h new file mode 100644 index 00000000000000..50f03781a172a1 --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_backend_func_calls.h @@ -0,0 +1,41 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_IMPORT_BACKEND_FUNC_CALLS_H_ +#define XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_IMPORT_BACKEND_FUNC_CALLS_H_ + +#include + +#include "mlir/Pass/Pass.h" + +namespace xla { +namespace sdy { + +// Creates a pass that converts a `CallOp` with a `backend_config` attr to a +// `NamedComputationOp` with the function body inlined and name of the callee. +// +// This pass is used to handle host offloading calls which are non inlined +// functions that require the callee to be propagated through. +// +// NOTE: it assumes that there is a unique callee for each caller. +std::unique_ptr createImportBackendFuncCallsPass(); + +// Register the xla-sdy-import-backend-func-calls pass. +void registerImportBackendFuncCallsPass(); + +} // namespace sdy +} // namespace xla + +#endif // XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_IMPORT_BACKEND_FUNC_CALLS_H_ diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.cc b/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.cc new file mode 100644 index 00000000000000..8172a217e30a91 --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.cc @@ -0,0 +1,176 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.h" + +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Support/TypeID.h" +#include "mlir/Transforms/DialectConversion.h" +#include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/dialect/sdy/ir/utils.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/service/spmd/shardy/constants.h" +#include "xla/service/spmd/shardy/utils.h" +#include "xla/sharding_op_util.h" + +namespace xla { +namespace sdy { + +namespace { + +using ::mlir::IntegerAttr; +using ::mlir::StringRef; +using ::mlir::mhlo::CustomCallOp; +using ::mlir::sdy::ShardingConstraintOp; +using ::mlir::sdy::ShardingGroupOp; +using ::mlir::sdy::TensorShardingAttr; +using ::mlir::mhlo::CustomCallOpAdaptor; + +mlir::LogicalResult rewriteShardingCustomCall( + CustomCallOp op, CustomCallOpAdaptor adaptor, + mlir::ConversionPatternRewriter& rewriter) { + CHECK_EQ(op.getNumOperands(), 1); + + std::vector unspecDims; + if (std::optional backendConfig = op.getBackendConfig()) { + CHECK_OK(xla::sharding_op_util::ParseAttributes( + mlir::dyn_cast(*backendConfig).getValue(), + &unspecDims)); + } + + if (op->getNumResults() != 1) { + op.emitError() << "expected CustomCallOp with exactly one result"; + return mlir::failure(); + } + TensorShardingAttr sharding = mlir::sdy::getSharding(op->getResult(0)); + if (!sharding) { + op.emitError() << "expected CustomCallOp with a sharding attribute"; + return mlir::failure(); + } + + if (!unspecDims.empty()) { + sharding = sharding.openShardingDims(unspecDims); + } + + rewriter.replaceOpWithNewOp( + op, adaptor.getInputs().front(), sharding); + + return mlir::success(); +} + +mlir::LogicalResult rewriteShardingGroupCustomCall( + CustomCallOp op, CustomCallOpAdaptor adaptor, + mlir::ConversionPatternRewriter& rewriter) { + CHECK_EQ(op.getNumOperands(), 1); + CHECK_LE(op.getNumResults(), 1); + + std::optional shardingGroupId = + tryGetFrontendAttr(op, kShardingGroupIdAttr); + if (!shardingGroupId.has_value()) { + return op.emitError() << "expected CustomCallOp with a sharding group id."; + } + if (!op.use_empty()) { + return op.emitError() + << "xla.sdy.ShardingGroup CustomCallOp should have no uses."; + } + + rewriter.replaceOpWithNewOp(op, adaptor.getInputs().front(), + shardingGroupId->getInt()); + + return mlir::success(); +} + +class SdyCustomCallPattern : public mlir::OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult matchAndRewrite( + CustomCallOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter& rewriter) const override { + if (op.getCallTargetName() == kShardingCustomCallTargetName) { + return rewriteShardingCustomCall(op, adaptor, rewriter); + } + + if (op.getCallTargetName() == kShardingGroupCustomCallTargetName) { + return rewriteShardingGroupCustomCall(op, adaptor, rewriter); + } + + return rewriter.notifyMatchFailure( + op, "expected CustomCallOp with xla.sdy target name."); + } +}; + +// Convert custom calls into sdy APIs. +// * xla.sdy.Sharding -> ShardingConstraintOp +// * xla.sdy.ShardingGroup -> ShardingGroupOp +class ImportSdyCustomCallsPass + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ImportSdyCustomCallsPass) + + void runOnOperation() final { + mlir::MLIRContext& context = getContext(); + mlir::ConversionTarget target(context); + target.addLegalDialect(); + target.addDynamicallyLegalOp([](CustomCallOp op) { + return op.getCallTargetName() != kShardingCustomCallTargetName && + op.getCallTargetName() != kShardingGroupCustomCallTargetName; + }); + mlir::RewritePatternSet patterns(&context); + patterns.add(&context); + if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, + std::move(patterns)))) { + signalPassFailure(); + } + } + + StringRef getArgument() const override { + return "xla-sdy-import-sdy-custom-calls"; + } + + StringRef getDescription() const override { + return "Converts a CustomCall with target name Sharding into a " + "ShardingConstraintOp and with target name ShardingGroup into a " + "ShardingGroupOp."; + } +}; + +} // namespace + +std::unique_ptr createImportSdyCustomCallsPass() { + return std::make_unique(); +} + +void registerImportSdyCustomCallsPass() { + mlir::registerPass(createImportSdyCustomCallsPass); +} + +} // namespace sdy +} // namespace xla diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/convert_sharding_custom_calls.h b/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.h similarity index 57% rename from third_party/xla/xla/service/spmd/shardy/round_trip_common/convert_sharding_custom_calls.h rename to third_party/xla/xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.h index c7e2f4d0310f64..74ac5c847b7caa 100644 --- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/convert_sharding_custom_calls.h +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_CONVERT_SHARDING_CUSTOM_CALLS_H_ -#define XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_CONVERT_SHARDING_CUSTOM_CALLS_H_ +#ifndef XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_IMPORT_SDY_CUSTOM_CALLS_H_ +#define XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_IMPORT_SDY_CUSTOM_CALLS_H_ #include @@ -23,14 +23,15 @@ limitations under the License. namespace xla { namespace sdy { -// Creates a pass that converts a `CustomCall` with target name Sharding into a -// `ShardingConstraintOp`. -std::unique_ptr createConvertShardingCustomCallsPass(); +// Creates a pass that imports sdy tagged `CustomCall` ops. Namely it converts +// * xla.sdy.Sharding -> ShardingConstraintOp +// * xla.sdy.ShardingGroup -> ShardingGroupOp +std::unique_ptr createImportSdyCustomCallsPass(); -// Register the xla-sdy-convert-sharding-custom-calls pass. -void registerConvertShardingCustomCallsPass(); +// Register the xla-sdy-import-sdy-custom-calls pass. +void registerImportSdyCustomCallsPass(); } // namespace sdy } // namespace xla -#endif // XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_CONVERT_SHARDING_CUSTOM_CALLS_H_ +#endif // XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_IMPORT_SDY_CUSTOM_CALLS_H_ diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc b/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc index d558679fff7e8d..68592c1918a3e3 100644 --- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc @@ -19,8 +19,9 @@ limitations under the License. #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" -#include "xla/service/spmd/shardy/round_trip_common/convert_sharding_custom_calls.h" +#include "xla/service/spmd/shardy/round_trip_common/import_backend_func_calls.h" #include "xla/service/spmd/shardy/round_trip_common/import_constants.h" +#include "xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.h" #include "xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h" namespace xla { @@ -51,8 +52,9 @@ void addCommonPreImportPasses(mlir::OpPassManager& pm) { } void addCommonPostImportPasses(mlir::OpPassManager& pm) { - pm.addPass(createConvertShardingCustomCallsPass()); + pm.addPass(createImportSdyCustomCallsPass()); pm.addNestedPass(createOpenWhileFreeVarsShardingPass()); + pm.addPass(createImportBackendFuncCallsPass()); } } // namespace sdy diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.h b/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.h index ada06bcdd623cd..75812a4e6e0985 100644 --- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.h +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.h @@ -22,15 +22,20 @@ namespace xla { namespace sdy { // Adds the common import passes for both the SDY and MHLO import -// pipelines that need to be called before each pass converts an HLO sharding/ -// SDY sharding string into an `sdy.sharding` attribute. +// pipelines that need to be called before each pipeline converts an HLO +// sharding/SDY sharding string into an `sdy.sharding` attribute. void addCommonPreImportPasses(mlir::OpPassManager& pm); // Adds the common import passes for both the SDY and MHLO import -// pipelines that need to be called after each pass converts an HLO sharding/ -// SDY sharding string into an `sdy.sharding` attribute. +// pipelines that need to be called after each pipeline converts an HLO +// sharding/SDY sharding string into an `sdy.sharding` attribute. void addCommonPostImportPasses(mlir::OpPassManager& pm); +// Adds the common export passes for both the SDY and MHLO import +// pipelines that need to be called before each pipeline converts an HLO +// sharding/SDY sharding string into an `sdy.sharding` attribute. +void addCommonPreExportPasses(mlir::OpPassManager& pm); + } // namespace sdy } // namespace xla diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc b/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc index 892dd9b66a0859..7f2dff488a7f00 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc @@ -29,13 +29,16 @@ limitations under the License. #include "xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.h" #include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.h" #include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_import.h" -#include "xla/service/spmd/shardy/round_trip_common/convert_sharding_custom_calls.h" +#include "xla/service/spmd/shardy/round_trip_common/export_named_computations.h" +#include "xla/service/spmd/shardy/round_trip_common/import_backend_func_calls.h" #include "xla/service/spmd/shardy/round_trip_common/import_constants.h" +#include "xla/service/spmd/shardy/round_trip_common/import_sdy_custom_calls.h" #include "xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h" #include "xla/service/spmd/shardy/sdy_round_trip/export_ops.h" -#include "xla/service/spmd/shardy/sdy_round_trip/export_shardings.h" -#include "xla/service/spmd/shardy/sdy_round_trip/import_shardings.h" +#include "xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.h" +#include "xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h" #include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h" +#include "xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.h" #include "xla/service/spmd/shardy/sdy_round_trip/shard_map_export.h" #include "xla/service/spmd/shardy/sdy_round_trip/shard_map_import.h" #include "xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.h" @@ -56,18 +59,21 @@ int main(int argc, char** argv) { xla::sdy::registerMhloImportPipeline(); xla::sdy::registerMhloImportShardingsPass(); xla::sdy::registerMhloRoundTripShardMapImportPass(); - xla::sdy::registerConvertShardingCustomCallsPass(); + xla::sdy::registerImportSdyCustomCallsPass(); xla::sdy::registerOpenWhileFreeVarsShardingPass(); + xla::sdy::registerImportBackendFuncCallsPass(); xla::sdy::registerImportConstantsPass(); xla::sdy::registerMhloExportPipeline(); xla::sdy::registerMhloExportShardingsPass(); xla::sdy::registerMhloRoundTripShardMapExportPass(); + xla::sdy::registerExportNamedComputationsPass(); xla::sdy::registerExportOpsPass(); xla::sdy::registerSdyRoundTripMhloToHloToMhloPass(); - xla::sdy::registerSdyRoundTripExportShardingsPass(); - xla::sdy::registerSdyRoundTripImportShardingsPass(); + xla::sdy::registerSdyRoundTripExportShardyAttrsPass(); + xla::sdy::registerSdyRoundTripImportShardyAttrsPass(); + xla::sdy::registerSdyRoundTripRemoveSizeOneAxesPass(); xla::sdy::registerSdyRoundTripExportOpsPass(); xla::sdy::registerSdyRoundTripExportPipeline(); xla::sdy::registerSdyRoundTripShardMapExportPass(); diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD index ebc36cdc54d2a9..034bee542d0e57 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD @@ -16,9 +16,9 @@ package_group( ) cc_library( - name = "export_shardings", - srcs = ["export_shardings.cc"], - hdrs = ["export_shardings.h"], + name = "export_shardy_attrs", + srcs = ["export_shardy_attrs.cc"], + hdrs = ["export_shardy_attrs.h"], deps = [ "//xla/mlir_hlo", "//xla/service/spmd/shardy:constants", @@ -40,6 +40,7 @@ cc_library( deps = [ "//xla/mlir_hlo", "//xla/service/spmd/shardy:constants", + "//xla/service/spmd/shardy:utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -50,16 +51,14 @@ cc_library( ) cc_library( - name = "import_shardings", - srcs = ["import_shardings.cc"], - hdrs = ["import_shardings.h"], + name = "import_shardy_attrs", + srcs = ["import_shardy_attrs.cc"], + hdrs = ["import_shardy_attrs.h"], deps = [ "//xla/mlir_hlo", "//xla/mlir_hlo:mhlo_passes", "//xla/service/spmd/shardy:constants", "//xla/service/spmd/shardy:utils", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:FuncDialect", @@ -107,7 +106,22 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@shardy//shardy/dialect/sdy/ir:dialect", - "@stablehlo//:stablehlo_ops", + ], +) + +cc_library( + name = "remove_size_one_axes", + srcs = ["remove_size_one_axes.cc"], + hdrs = ["remove_size_one_axes.h"], + deps = [ + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@shardy//shardy/dialect/sdy/ir:dialect", + "@shardy//shardy/dialect/sdy/transforms/common:sharding_walker", ], ) @@ -117,13 +131,17 @@ cc_library( hdrs = ["pipelines.h"], deps = [ ":export_ops", - ":export_shardings", - ":import_shardings", + ":export_shardy_attrs", + ":import_shardy_attrs", + ":remove_size_one_axes", + ":shard_map_export", + ":shard_map_import", "//xla/service:hlo_proto_cc", "//xla/service/spmd/shardy/mhlo_round_trip:export_shardings", - "//xla/service/spmd/shardy/mhlo_round_trip:shard_map_import", + "//xla/service/spmd/shardy/round_trip_common:export_named_computations", "//xla/service/spmd/shardy/round_trip_common:pipeline_passes", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", ], ) diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc index b5bd21fbeaa04f..67c4bc63b86802 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc @@ -42,6 +42,7 @@ limitations under the License. #include "shardy/dialect/sdy/ir/dialect.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/spmd/shardy/constants.h" +#include "xla/service/spmd/shardy/utils.h" namespace mhlo = ::mlir::mhlo; @@ -62,6 +63,7 @@ using ::mlir::success; using ::mlir::sdy::ConstantOp; using ::mlir::sdy::ShardingConstraintOp; +using ::mlir::sdy::ShardingGroupOp; using ::mlir::sdy::TensorShardingAttr; using ::mlir::sdy::TensorShardingPerValueAttr; @@ -107,6 +109,25 @@ class ShardingConstraintPattern } }; +class ShardingGroupPattern : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + private: + LogicalResult matchAndRewrite( + ShardingGroupOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto customCallOp = rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), adaptor.getInput()); + + customCallOp.setCallTargetName(kShardingGroupCustomCallTargetName); + addFrontendAttribute(customCallOp, kShardingGroupIdAttr, + op.getGroupIdAttr()); + customCallOp.setHasSideEffectAttr(rewriter.getBoolAttr(true)); + return success(); + } +}; + class SdyRoundTripExportOpsPass : public PassWrapper> { public: @@ -118,7 +139,9 @@ class SdyRoundTripExportOpsPass target.addIllegalOp(); target.addLegalOp(); mlir::RewritePatternSet patterns(&context); - patterns.add(&context); + patterns + .add( + &context); if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, std::move(patterns)))) { signalPassFailure(); diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardings.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.cc similarity index 73% rename from third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardings.cc rename to third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.cc index a7177f334c077d..8474d3efb0e6e2 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardings.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.cc @@ -13,12 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/spmd/shardy/sdy_round_trip/export_shardings.h" +#include "xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.h" #include #include -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/ErrorHandling.h" @@ -70,15 +69,16 @@ using ::mlir::func::FuncOp; using ::mlir::mhlo::CustomCallOp; using ::mlir::sdy::kShardingAttr; +using ::mlir::sdy::kShardingRuleAttr; using ::mlir::sdy::MeshOp; +using ::mlir::sdy::OpShardingRuleAttr; using ::mlir::sdy::TensorShardingAttr; using ::mlir::sdy::TensorShardingPerValueAttr; // Saves `shardingPerValueAttr` including any existing `frontendAttributes` on // the `op`. -void saveOpShardingPerValueAttr(Operation* op, - TensorShardingPerValueAttr shardingPerValueAttr, - OpBuilder& builder) { +void saveOpShardingPerValueAttr( + Operation* op, TensorShardingPerValueAttr shardingPerValueAttr) { addFrontendAttribute(op, kShardingRoundTripAttr, shardingPerValueAttr); } @@ -92,11 +92,11 @@ LogicalResult exportFunc(FuncOp funcOp, OpBuilder& builder) { } } - for (mlir::OpOperand& returnOperand : - mlir::sdy::getBodyTerminatorOpOperands(funcOp)) { - int64_t resultNum = returnOperand.getOperandNumber(); + Operation* terminatorOp = mlir::sdy::getBodyTerminator(funcOp); + builder.setInsertionPoint(terminatorOp); + for (mlir::OpOperand& returnOperand : terminatorOp->getOpOperands()) { if (auto sharding = funcOp.getResultAttrOfType( - resultNum, kShardingAttr)) { + returnOperand.getOperandNumber(), kShardingAttr)) { // We cannot save the result shardings as frontend attributes. MHLO->HLO // conversion converts `mhlo.sharding`s on the results to a tuple // sharding on the ROOT instruction, but it discards the frontend @@ -106,7 +106,6 @@ LogicalResult exportFunc(FuncOp funcOp, OpBuilder& builder) { // Op's sharding to the FuncOp's result and delete te temporary custom // call. Value returnValue = returnOperand.get(); - builder.setInsertionPoint(returnOperand.getOwner()); auto customCallOp = builder.create( returnValue.getLoc(), returnValue.getType(), returnValue); customCallOp.setCallTargetName(kFuncResultShardingTargetName); @@ -115,27 +114,32 @@ LogicalResult exportFunc(FuncOp funcOp, OpBuilder& builder) { customCallOp.setHasSideEffect(true); saveOpShardingPerValueAttr( customCallOp, - TensorShardingPerValueAttr::get(customCallOp.getContext(), sharding), - builder); + TensorShardingPerValueAttr::get(customCallOp.getContext(), sharding)); returnOperand.set(customCallOp.getResult(0)); } } funcOp.front().walk([&](Operation* op) { - if (auto oldShardingPerValue = - op->getAttrOfType(kShardingAttr)) { - saveOpShardingPerValueAttr(op, oldShardingPerValue, builder); + if (TensorShardingPerValueAttr oldShardingPerValue = + mlir::sdy::getShardingPerValue(op)) { + saveOpShardingPerValueAttr(op, oldShardingPerValue); + } + if (auto oldShardingRule = + op->getAttrOfType(kShardingRuleAttr)) { + addFrontendAttribute(op, kShardingRuleRoundTripAttr, oldShardingRule); + op->removeAttr(kShardingRuleAttr); } }); return mlir::success(); } -class SdyRoundTripExportShardingsPass - : public PassWrapper> { public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SdyRoundTripExportShardingsPass) + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + SdyRoundTripExportShardyAttrsPass) void runOnOperation() final { ModuleOp moduleOp = getOperation(); @@ -154,19 +158,22 @@ class SdyRoundTripExportShardingsPass for (MeshOp meshOp : moduleOp.getOps()) { mhloMeshes.emplace_back(meshOp.getSymNameAttr(), meshOp.getMeshAttr()); } - addFrontendAttribute(moduleOp, kMeshesRoundTripAttr, - DictionaryAttr::get(context, mhloMeshes)); + if (!mhloMeshes.empty()) { + addFrontendAttribute(moduleOp, kMeshesRoundTripAttr, + DictionaryAttr::get(context, mhloMeshes)); + } } StringRef getArgument() const override { - return "xla-sdy-round-trip-export-shardings"; + return "xla-sdy-round-trip-export-shardy-attrs"; } StringRef getDescription() const override { - return "Converts the shardings from kShardingAttr to " - "kShardingRoundTripAttr in the HLO frontend attributes and saves " - "the mesh symbols as kMeshesRoundTripAttr in the module frontend " - "attributes."; + return "Converts the shardy attributes from " + "kShardingAttr/kShardingRuleAttr to " + "kShardingRoundTripAttr/kShardingRuleRoundTripAttr in the HLO " + "frontend attributes and saves the mesh symbols as " + "kMeshesRoundTripAttr in the module frontend attributes."; } void getDependentDialects(mlir::DialectRegistry& registry) const final { @@ -176,12 +183,12 @@ class SdyRoundTripExportShardingsPass } // namespace -void registerSdyRoundTripExportShardingsPass() { - mlir::registerPass(createSdyRoundTripExportShardingsPass); +void registerSdyRoundTripExportShardyAttrsPass() { + mlir::registerPass(createSdyRoundTripExportShardyAttrsPass); } -std::unique_ptr createSdyRoundTripExportShardingsPass() { - return std::make_unique(); +std::unique_ptr createSdyRoundTripExportShardyAttrsPass() { + return std::make_unique(); } } // namespace sdy diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardings.h b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.h similarity index 54% rename from third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardings.h rename to third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.h index 4b8ce6ab737419..d4d64aefbdbc96 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardings.h +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_EXPORT_SHARDINGS_H_ -#define XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_EXPORT_SHARDINGS_H_ +#ifndef XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_EXPORT_SHARDY_ATTRS_H_ +#define XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_EXPORT_SHARDY_ATTRS_H_ #include @@ -23,19 +23,22 @@ limitations under the License. namespace xla { namespace sdy { -// Registers the xla-sdy-round-trip-export-shardings pass. -void registerSdyRoundTripExportShardingsPass(); +// Registers the xla-sdy-round-trip-export-shardy-attrs pass. +void registerSdyRoundTripExportShardyAttrsPass(); -// Creates the pass that converts the shardings from `kShardingAttr` to -// `kShardingRoundTripAttr` in the HLO frontend attributes and saves the -// mesh symbols as `kMeshesRoundTripAttr` in the module frontend attributes. +// Creates the pass to convert SDY attributes to frontend attributes: +// +// - Converts shardings from `kShardingAttr` to `kShardingRoundTripAttr` +// - Converts sharding rules from `kShardingRuleAttr` to +// `kShardingRuleRoundTripAttr` +// - Saves the mesh symbols as `kMeshesRoundTripAttr` // // NOTE: The `kShardingAttr`s are not removed from the ops. They are kept around -// because part of the `SdyRoundTripExportPipeline` it also converts the +// because part of the `SdyRoundTripExportPipeline` also converts the // `kShardingAttr`s to `kXlaShardingAttr`s. -std::unique_ptr createSdyRoundTripExportShardingsPass(); +std::unique_ptr createSdyRoundTripExportShardyAttrsPass(); } // namespace sdy } // namespace xla -#endif // XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_EXPORT_SHARDINGS_H_ +#endif // XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_EXPORT_SHARDY_ATTRS_H_ diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc similarity index 50% rename from third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc rename to third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc index eb11cc53f1c456..26f3539163b15f 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.cc @@ -13,15 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/spmd/shardy/sdy_round_trip/import_shardings.h" +#include "xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h" #include #include #include -#include +#include -#include "absl/log/check.h" -#include "absl/strings/escaping.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "mlir/AsmParser/AsmParser.h" @@ -34,6 +32,7 @@ limitations under the License. #include "mlir/IR/Diagnostics.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Value.h" #include "mlir/IR/Visitors.h" @@ -46,7 +45,6 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" #include "shardy/dialect/sdy/ir/constants.h" #include "shardy/dialect/sdy/ir/dialect.h" -#include "shardy/dialect/sdy/ir/utils.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/spmd/shardy/constants.h" @@ -59,27 +57,29 @@ namespace { using ::mlir::Attribute; using ::mlir::DictionaryAttr; +using ::mlir::IRRewriter; using ::mlir::ModuleOp; using ::mlir::NamedAttribute; using ::mlir::Operation; using ::mlir::StringAttr; using ::mlir::StringRef; using ::mlir::SymbolTable; -using ::mlir::SymbolTableCollection; using ::mlir::func::FuncOp; using ::mlir::mhlo::CustomCallOp; using ::mlir::sdy::kShardingAttr; +using ::mlir::sdy::kShardingRuleAttr; using ::mlir::sdy::MeshAttr; +using ::mlir::sdy::OpShardingRuleAttr; using ::mlir::sdy::TensorShardingAttr; using ::mlir::sdy::TensorShardingPerValueAttr; -// Builds the shardings coming from Shardy previously. This means +// Builds the shardy attributes coming from Shardy previously. This means // the module was exported from Shardy and we are now round-tripping back. // This should happen after the meshes were created from the `ModuleOp` attrs -// (see `SdyRoundTripImportShardingsPass`). -void convertShardings(FuncOp funcOp) { +// (see `SdyRoundTripImportShardyAttrsPass`). +void convertShardyAttrs(FuncOp funcOp, IRRewriter& rewriter) { // Copy over the argument shardings, but not the result shardings yet. // We need to wait until after we've converted all the Operations before // copying the result shardings. @@ -102,98 +102,104 @@ void convertShardings(FuncOp funcOp) { resNum, StringAttr::get(funcOp.getContext(), kXlaShardingAttr)); } - // Extract the round-tripped SDY shardings from the operations. + // Extract the round-tripped SDY shardy attributes from the operations. funcOp.front().walk([&](Operation* op) { op->removeAttr(kXlaShardingAttr); - if (DictionaryAttr dictAttr = getFrontendAttrs(op)) { - // NOTE: we are only setting the sharding on known custom-calls. For any - // other op that has a `kShardingRoundTripAttr` we discard it. XLA - // sometimes creates new instructions, copying over the operand's frontend - // attrs, which may mean the shapes are wrong when the new instruction is - // a reshape for example. This does mean we can't fully round-trip b/w HLO - // and MLIR after SDY propagation. - if (auto customCallOp = mlir::dyn_cast(op)) { - StringRef targetName = customCallOp.getCallTargetName(); - if (targetName == kFuncResultShardingTargetName) { - // This is a temporary CustomCallOp that holds the sharding from a - // func result. When importing we want to move that sharding to the - // func result and delete the CustomCallOp. - auto shardingPerValueAttr = - parseStringAttr( - dictAttr, kShardingRoundTripAttr); - for (mlir::OpOperand& use : - llvm::make_early_inc_range(customCallOp->getUses())) { - int64_t resNum = use.getOperandNumber(); - funcOp.setResultAttr(resNum, kShardingAttr, + DictionaryAttr dictAttr = getFrontendAttrs(op); + if (!dictAttr) { + return; + } + // NOTE: we are only setting the sharding on known custom-calls. For any + // other op that has a `kShardingRoundTripAttr` we discard it. XLA sometimes + // creates new instructions, copying over the operand's frontend attrs, + // which may mean the shapes are wrong when the new instruction is a reshape + // for example. This does mean we can't fully round-trip b/w HLO and MLIR + // after SDY propagation. + if (auto customCallOp = mlir::dyn_cast(op)) { + StringRef targetName = customCallOp.getCallTargetName(); + if (targetName == kFuncResultShardingTargetName) { + // This is a temporary CustomCallOp that holds the sharding from a + // func result. When importing we want to move that sharding to the + // func result and delete the CustomCallOp. + auto shardingPerValueAttr = parseStringAttr( + dictAttr, kShardingRoundTripAttr); + for (mlir::OpOperand& use : + llvm::make_early_inc_range(customCallOp->getUses())) { + // We currently ignore users that are not the func return op. + // This might happen due to inlined func ops that originally had + // result shardings. + // TODO(b/370984308): explore if we need to support this properly. + if (mlir::isa(use.getOwner())) { + funcOp.setResultAttr(use.getOperandNumber(), kShardingAttr, shardingPerValueAttr.getSharding(0)); - mlir::sdy::getBodyTerminator(funcOp)->setOperand( - resNum, customCallOp.getOperand(0)); + use.set(customCallOp.getOperand(0)); } - customCallOp.erase(); - return; - } - if (targetName == kShardingCustomCallTargetName || - targetName == kSPMDFullToShardShapeCallTargetName || - targetName == kSPMDShardToFullShapeCallTargetName) { - customCallOp->setAttr(kShardingAttr, - parseStringAttr( - dictAttr, kShardingRoundTripAttr)); } + rewriter.replaceOp(customCallOp, customCallOp.getOperand(0)); + return; + } + if (targetName == kShardingCustomCallTargetName || + targetName == kSPMDFullToShardShapeCallTargetName || + targetName == kSPMDShardToFullShapeCallTargetName) { + customCallOp->setAttr(kShardingAttr, + parseStringAttr( + dictAttr, kShardingRoundTripAttr)); } - removeFrontendAttribute(op, kShardingRoundTripAttr); + } + removeFrontendAttribute(op, kShardingRoundTripAttr); + + // Import sharding rules. + if (auto shardingRuleAttr = parseStringAttr( + dictAttr, kShardingRuleRoundTripAttr)) { + op->setAttr(kShardingRuleAttr, shardingRuleAttr); + removeFrontendAttribute(op, kShardingRuleRoundTripAttr); } }); } -class SdyRoundTripImportShardingsPass - : public mlir::PassWrapper> { public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SdyRoundTripImportShardingsPass) + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + SdyRoundTripImportShardyAttrsPass) void runOnOperation() final { ModuleOp moduleOp = getOperation(); - SymbolTableCollection symbolTableCollection; - SymbolTable& symbolTable = symbolTableCollection.getSymbolTable(moduleOp); - // If there is a dictionary attribute `kFrontendAttributesAttr` and it - // contains `kMeshesRoundTripAttr`, it means that the function was a - // Shardy function and we are roundtripping back to Shardy. In that - // case, we can use the saved string attributes to restore the original mesh - // and value shardings with the original mesh axis names and priorities on - // the sharding. - DictionaryAttr moduleDictAttr = getFrontendAttrs(moduleOp); - if (!moduleDictAttr) { - moduleOp.emitError( - "Expected an attribute `kFrontendAttributesAttr` on the module that " - "contains the Shardy meshes."); - signalPassFailure(); - return; - } - auto sdyMeshes = - parseStringAttr(moduleDictAttr, kMeshesRoundTripAttr); - mlir::OpBuilder builder(moduleOp); + // We can use the saved string attributes to restore the original mesh and + // value shardings with the original mesh axis names and priorities on the + // sharding. If there is no `kMeshesRoundTripAttr, there were no meshes in + // the original Shardy model. + std::optional meshesAttr = + tryGetFrontendAttr(moduleOp, kMeshesRoundTripAttr); + mlir::ArrayRef sdyMeshes = + meshesAttr.has_value() ? meshesAttr.value().getValue() + : mlir::ArrayRef(); + + IRRewriter rewriter(moduleOp); // Insert the meshes before any functions. - builder.setInsertionPointToStart(moduleOp.getBody()); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + SymbolTable symbolTable(moduleOp); for (NamedAttribute mesh : sdyMeshes) { auto meshAttr = mlir::cast(mesh.getValue()); - symbolTable.insert(builder.create( + symbolTable.insert(rewriter.create( moduleOp.getLoc(), mesh.getName(), meshAttr)); } removeFrontendAttribute(moduleOp, kMeshesRoundTripAttr); for (auto funcOp : moduleOp.getOps()) { - convertShardings(funcOp); + convertShardyAttrs(funcOp, rewriter); } } StringRef getArgument() const override { - return "xla-sdy-round-trip-import-shardings"; + return "xla-sdy-round-trip-import-shardy-attrs"; } StringRef getDescription() const override { - return "Converts the shardings from strings in MHLO frontend attributes to " - "SDY meshes and shardings."; + return "Converts the shardy attributes from strings in MHLO frontend " + "attributes to SDY meshes, shardings and sharding rules."; } void getDependentDialects(mlir::DialectRegistry& registry) const final { @@ -203,12 +209,12 @@ class SdyRoundTripImportShardingsPass } // namespace -std::unique_ptr createSdyRoundTripImportShardingsPass() { - return std::make_unique(); +std::unique_ptr createSdyRoundTripImportShardyAttrsPass() { + return std::make_unique(); } -void registerSdyRoundTripImportShardingsPass() { - mlir::registerPass(createSdyRoundTripImportShardingsPass); +void registerSdyRoundTripImportShardyAttrsPass() { + mlir::registerPass(createSdyRoundTripImportShardyAttrsPass); } } // namespace sdy diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h new file mode 100644 index 00000000000000..0e75e2fbc648a4 --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h @@ -0,0 +1,40 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_IMPORT_SHARDY_ATTRS_H_ +#define XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_IMPORT_SHARDY_ATTRS_H_ + +#include + +#include "mlir/Pass/Pass.h" + +namespace xla { +namespace sdy { + +// Creates the pass to convert frontend attributes to SDY attributes: +// +// - Converts shardings from `kShardingRoundTripAttr` to `kShardingAttr` +// - Converts sharding rules from `kShardingRuleRoundTripAttr` to +// `kShardingRuleAttr` +// - Converts meshes from `kMeshesRoundTripAttr` to sdy.mesh symbols +std::unique_ptr createSdyRoundTripImportShardyAttrsPass(); + +// Registers the xla-sdy-round-trip-import-shardy-attrs pass. +void registerSdyRoundTripImportShardyAttrsPass(); + +} // namespace sdy +} // namespace xla + +#endif // XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_IMPORT_SHARDY_ATTRS_H_ diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc index 14ad8133ab33f0..32e15074c843a1 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc @@ -22,11 +22,14 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "xla/service/hlo.pb.h" #include "xla/service/spmd/shardy/mhlo_round_trip/export_shardings.h" -#include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_import.h" +#include "xla/service/spmd/shardy/round_trip_common/export_named_computations.h" #include "xla/service/spmd/shardy/round_trip_common/pipeline_passes.h" #include "xla/service/spmd/shardy/sdy_round_trip/export_ops.h" -#include "xla/service/spmd/shardy/sdy_round_trip/export_shardings.h" -#include "xla/service/spmd/shardy/sdy_round_trip/import_shardings.h" +#include "xla/service/spmd/shardy/sdy_round_trip/export_shardy_attrs.h" +#include "xla/service/spmd/shardy/sdy_round_trip/import_shardy_attrs.h" +#include "xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.h" +#include "xla/service/spmd/shardy/sdy_round_trip/shard_map_export.h" +#include "xla/service/spmd/shardy/sdy_round_trip/shard_map_import.h" namespace xla { namespace sdy { @@ -34,24 +37,21 @@ namespace sdy { using ::mlir::PassPipelineRegistration; void addSdyRoundTripExportPipeline(mlir::OpPassManager& pm) { - // NOTE: we don't do any exporting for ManualComputationOp, since during - // SDY round-trip we expect the same pattern of custom calls to continue to - // exist. We save `sdy.sharding`s on those custom calls during - // `createSdyRoundTripExportShardingsPass` and make use of - // `createSdyRoundTripImportShardingsPass` to import them. + pm.addPass(createExportNamedComputationsPass()); pm.addPass(createSdyRoundTripExportOpsPass()); + pm.addPass(createSdyRoundTripShardMapExportPass()); // Preserve the SDY shardings for `createExportMhloShardingsPass` so that // we have both `mhlo.sharding`s and hidden `sdy.sharding`s on the module. We // want to have `mhlo.sharding`s for Pathways to read from. - pm.addPass(createSdyRoundTripExportShardingsPass()); + pm.addPass(createSdyRoundTripExportShardyAttrsPass()); pm.addPass(createExportMhloShardingsPass()); } void addSdyRoundTripImportPipeline(mlir::OpPassManager& pm) { addCommonPreImportPasses(pm); - pm.addPass(createSdyRoundTripImportShardingsPass()); - // TODO(bartchr): replace with an sdy round trip shard map pass. - pm.addPass(createMhloRoundTripShardMapImportPass()); + pm.addPass(createSdyRoundTripImportShardyAttrsPass()); + pm.addPass(createSdyRoundTripShardMapImportPass()); + pm.addPass(createSdyRoundTripRemoveSizeOneAxesPass()); addCommonPostImportPasses(pm); } diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.cc new file mode 100644 index 00000000000000..06a383f1fefafd --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.cc @@ -0,0 +1,194 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.h" + +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" +#include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/dialect/sdy/ir/utils.h" +#include "shardy/dialect/sdy/transforms/common/sharding_walker.h" + +namespace xla { +namespace sdy { + +namespace { + +using ::mlir::ModuleOp; +using ::mlir::Operation; +using ::mlir::SmallVector; +using ::mlir::StringAttr; +using ::mlir::StringRef; +using ::mlir::SymbolTable; +using ::mlir::sdy::AxisRefAttr; +using ::mlir::sdy::DimensionShardingAttr; +using ::mlir::sdy::getMeshAttr; +using ::mlir::sdy::ManualAxesAttr; +using ::mlir::sdy::ManualComputationOp; +using ::mlir::sdy::MeshAttr; +using ::mlir::sdy::MeshAxisAttr; +using ::mlir::sdy::MeshOp; +using ::mlir::sdy::TensorShardingAttr; + +bool hasSizeOneAxes(MeshOp meshOp) { + return llvm::any_of(meshOp.getMesh().getAxes(), + [](MeshAxisAttr axis) { return axis.getSize() == 1; }); +} + +MeshAttr removeSizeOneAxes(MeshAttr mesh) { + SmallVector axes; + llvm::copy_if(mesh.getAxes(), std::back_inserter(axes), + [](MeshAxisAttr axis) { return axis.getSize() != 1; }); + return MeshAttr::get(mesh.getContext(), axes, mesh.getDeviceIds()); +} + +TensorShardingAttr removeSizeOneAxes(TensorShardingAttr sharding, + const SymbolTable& symbolTable) { + MeshAttr mesh = sharding.getMesh(symbolTable); + CHECK(mesh) << "unknown mesh: " << std::string_view(sharding.getMeshName()); + + auto isNotSizeOne = [&](AxisRefAttr axis) { return axis.getSize(mesh) != 1; }; + + // Remove from dimension shardings. + SmallVector dimShardings; + dimShardings.reserve(sharding.getRank()); + for (DimensionShardingAttr dimSharding : sharding.getDimShardings()) { + SmallVector newAxes; + newAxes.reserve(dimSharding.getAxes().size()); + llvm::copy_if(dimSharding.getAxes(), std::back_inserter(newAxes), + isNotSizeOne); + // Remove priority if there are no sharding axes and the dimension is + // closed, since this isn't allowed by verification (would have no effect on + // propagation). + std::optional priority = + newAxes.empty() && dimSharding.getIsClosed() + ? std::nullopt + : dimSharding.getPriority(); + dimShardings.push_back( + DimensionShardingAttr::get(dimSharding.getContext(), newAxes, + dimSharding.getIsClosed(), priority)); + } + + // Remove from replicated axes. + SmallVector replicatedAxes; + llvm::copy_if(sharding.getReplicatedAxes(), + std::back_inserter(replicatedAxes), isNotSizeOne); + + // Remove for inlined mesh. + mlir::Attribute meshOrRef = sharding.getMeshOrRef(); + if (auto mesh = mlir::dyn_cast(meshOrRef)) { + meshOrRef = removeSizeOneAxes(mesh); + } + + return TensorShardingAttr::get(sharding.getContext(), meshOrRef, dimShardings, + replicatedAxes); +} + +void removeSizeOneManualAxes(ManualComputationOp manualComputationOp, + const SymbolTable& symbolTable) { + MeshAttr mesh = mlir::sdy::getCommonMesh( + manualComputationOp.getInShardings().getShardings(), + manualComputationOp.getOutShardings().getShardings(), symbolTable); + CHECK(mesh) << "no common mesh found for ManualComputationOp"; + + SmallVector newManualAxes; + llvm::copy_if( + manualComputationOp.getManualAxes(), std::back_inserter(newManualAxes), + [&](StringAttr axisName) { return mesh.getAxisSize(axisName) != 1; }); + manualComputationOp.setManualAxesAttr( + ManualAxesAttr::get(manualComputationOp.getContext(), newManualAxes)); +} + +class SdyRoundTripRemoveSizeOneAxesPass + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + SdyRoundTripRemoveSizeOneAxesPass) + + void runOnOperation() final { + ModuleOp moduleOp = getOperation(); + SymbolTable symbolTable(moduleOp); + + if (llvm::none_of(moduleOp.getOps(), hasSizeOneAxes)) { + // Nothing to do. + return; + } + + LOG(INFO) << "[Shardy] removing axes of size one."; + + mlir::sdy::transformShardings( + moduleOp, + [&](TensorShardingAttr sharding) { + return removeSizeOneAxes(sharding, symbolTable); + }, + [&](Operation* op) { + if (auto manualComputationOp = + mlir::dyn_cast(op)) { + removeSizeOneManualAxes(manualComputationOp, symbolTable); + } + }); + + for (auto meshOp : moduleOp.getOps()) { + meshOp.setMeshAttr(removeSizeOneAxes(meshOp.getMesh())); + } + } + + StringRef getArgument() const override { + return "xla-sdy-round-trip-remove-size-one-axes"; + } + + StringRef getDescription() const override { + return "Removes axes of size one from all meshes, shardings, and manual " + "computation ops, to avoid conflict during propagation that are due " + "to such axes."; + } + + void getDependentDialects(mlir::DialectRegistry& registry) const final { + registry.insert(); + } +}; + +} // namespace + +std::unique_ptr createSdyRoundTripRemoveSizeOneAxesPass() { + return std::make_unique(); +} + +void registerSdyRoundTripRemoveSizeOneAxesPass() { + mlir::registerPass(createSdyRoundTripRemoveSizeOneAxesPass); +} + +} // namespace sdy +} // namespace xla diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.h b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.h similarity index 57% rename from third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.h rename to third_party/xla/xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.h index 2f77466af87626..04d280e5d91178 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.h +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/remove_size_one_axes.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_IMPORT_SHARDINGS_H_ -#define XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_IMPORT_SHARDINGS_H_ +#ifndef XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_REMOVE_SIZE_ONE_AXES_H_ +#define XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_REMOVE_SIZE_ONE_AXES_H_ #include @@ -23,14 +23,15 @@ limitations under the License. namespace xla { namespace sdy { -// Creates the pass that converts the shardings from strings in MHLO frontend -// attributes to SDY meshes and shardings. -std::unique_ptr createSdyRoundTripImportShardingsPass(); +// Creates the pass that removes axes of size one from all meshes, shardings, +// and manual computation ops, to avoid conflict during propagation that are due +// to such axes. +std::unique_ptr createSdyRoundTripRemoveSizeOneAxesPass(); -// Registers the xla-sdy-round-trip-import-shardings pass. -void registerSdyRoundTripImportShardingsPass(); +// Registers the xla-sdy-round-trip-remove-size-one-axes pass. +void registerSdyRoundTripRemoveSizeOneAxesPass(); } // namespace sdy } // namespace xla -#endif // XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_IMPORT_SHARDINGS_H_ +#endif // XLA_SERVICE_SPMD_SHARDY_SDY_ROUND_TRIP_REMOVE_SIZE_ONE_AXES_H_ diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_export.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_export.cc index 08363df42d500a..16d9397ed16ee7 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_export.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_export.cc @@ -31,7 +31,9 @@ limitations under the License. #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeRange.h" #include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassRegistry.h" @@ -52,6 +54,7 @@ namespace { using ::mlir::MLIRContext; using ::mlir::ModuleOp; using ::mlir::StringRef; +using ::mlir::func::CallOp; using ::mlir::func::FuncOp; namespace stablehlo = ::mlir::stablehlo; @@ -72,29 +75,46 @@ class SdyRoundTripShardMapExportPass auto rewriter = mlir::IRRewriter(context); moduleOp->walk([&](sdy::ManualComputationOp manualComputation) { rewriter.setInsertionPointToEnd(&moduleOp.getRegion().front()); + mlir::Location loc = manualComputation.getLoc(); + mlir::Region& manualCompBody = manualComputation.getBody(); + mlir::TypeRange manualCompBodyArgTypes = + manualCompBody.getArgumentTypes(); + mlir::TypeRange localResultTypes = + sdy::getBodyTerminatorOpOperandTypes(manualComputation); auto funcOp = rewriter.create( - manualComputation.getLoc(), kManualComputationBodyFuncName, - rewriter.getFunctionType( - manualComputation.getBody().getArgumentTypes(), - sdy::getBodyTerminatorOpOperandTypes(manualComputation))); - sdy::inlineRegionAndConvertTerminatorOp( - manualComputation.getBody(), funcOp.getBody()); + loc, kManualComputationBodyFuncName, + rewriter.getFunctionType(manualCompBodyArgTypes, localResultTypes)); mlir::StringAttr funcName = symbolTable.insert(funcOp); rewriter.setInsertionPoint(manualComputation); - auto customCallOp = rewriter.create( - manualComputation.getLoc(), manualComputation.getResultTypes(), - manualComputation->getOperands()); - customCallOp.setCallTargetName(kManualComputationCustomCallTargetName); - customCallOp.setCalledComputationsAttr( - rewriter.getArrayAttr(mlir::FlatSymbolRefAttr::get(funcName))); - addFrontendAttribute(customCallOp, kInShardings, + stablehlo::CustomCallOp fullToShard; + mlir::ValueRange operands = manualComputation->getOperands(); + if (!operands.empty()) { + fullToShard = rewriter.create( + loc, manualCompBodyArgTypes, operands); + fullToShard.setCallTargetName(kGlobalToLocalShapeCallTargetName); + operands = fullToShard->getResults(); + } + + auto callOp = + rewriter.create(loc, localResultTypes, funcName, operands); + addFrontendAttribute(callOp, kInShardings, manualComputation.getInShardings()); - addFrontendAttribute(customCallOp, kOutShardings, + addFrontendAttribute(callOp, kOutShardings, manualComputation.getOutShardings()); - addFrontendAttribute(customCallOp, kManualAxes, + addFrontendAttribute(callOp, kManualAxes, manualComputation.getManualAxesAttr()); - rewriter.replaceOp(manualComputation, customCallOp->getResults()); + + mlir::ResultRange results = manualComputation->getResults(); + if (!results.empty()) { + auto shardToFull = rewriter.create( + loc, manualComputation.getResultTypes(), callOp->getResults()); + shardToFull.setCallTargetName(kLocalToGlobalShapeCallTargetName); + results = shardToFull->getResults(); + } + sdy::inlineRegionAndConvertTerminatorOp( + manualCompBody, funcOp.getBody()); + rewriter.replaceOp(manualComputation, results); }); } @@ -104,9 +124,9 @@ class SdyRoundTripShardMapExportPass StringRef getDescription() const override { return "Converts the body of a ManualComputationOp to a separate function " - "with a CustomCallOp of the same name referring to it. The " - "CustomCallOp saves the in/out shardings and manual axes as " - "frontend attrs for HLO round tripping."; + "with a CallOp and a pair of CustomCallOps that change the shape of " + "the arguments/results. The CallOp saves the in/out shardings and " + "manual axes as frontend attrs."; } void getDependentDialects(mlir::DialectRegistry& registry) const final { registry.insert(); diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_export.h b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_export.h index 34eef5eeb13cf6..c3a7ed9b7ea3a7 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_export.h +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_export.h @@ -23,8 +23,10 @@ limitations under the License. namespace xla { namespace sdy { -// Creates the pass that converts `ManualComputationOps` to a separate function -// and `CustomCallOp` for round tripping between HLO. +// Creates the pass that converts `ManualComputationOp`s to a separate function +// with a CallOp and a pair of `CustomCallOp`s that change the shape of the +// arguments/results. The CallOp saves the in/out shardings and manual axes as +// frontend attrs. std::unique_ptr createSdyRoundTripShardMapExportPass(); // Registers the xla-sdy-round-trip-shard-map-export pass. diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc index 8975fa2142691c..7a0e1d018e0c2e 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.cc @@ -15,14 +15,15 @@ limitations under the License. #include "xla/service/spmd/shardy/sdy_round_trip/shard_map_import.h" +#include #include #include #include "absl/log/check.h" +#include "absl/strings/match.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" @@ -31,7 +32,9 @@ limitations under the License. #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeRange.h" #include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" #include "mlir/IR/Visitors.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" @@ -41,7 +44,7 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" #include "shardy/dialect/sdy/ir/dialect.h" #include "shardy/dialect/sdy/ir/utils.h" -#include "stablehlo/dialect/StablehloOps.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/spmd/shardy/constants.h" #include "xla/service/spmd/shardy/utils.h" @@ -55,47 +58,71 @@ using ::mlir::ModuleOp; using ::mlir::OpConversionPattern; using ::mlir::StringRef; using ::mlir::SymbolTable; +using ::mlir::func::CallOp; using ::mlir::func::FuncOp; -using ::mlir::stablehlo::CustomCallOp; +using ::mlir::mhlo::CustomCallOp; namespace sdy = ::mlir::sdy; -// Converts `CustomCallOp`s called `@local_xla.sdy.ManualComputation` with in/out -// shardings and manual axes as frontend attrs to `ManualComputationOp`s. -class ManualComputationPattern : public OpConversionPattern { +// Converts a CallOp calling a @local_xla.sdy.manual_computation_body func with in/out +// shardings and manual axes as frontend attrs, wrapped with custom calls that +// change the shape of the arguments/results to a `ManualComputationOp`. See +// `SdyRoundTripShardMapExportPass` for its counterpart. +class ManualComputationPattern : public OpConversionPattern { public: explicit ManualComputationPattern(MLIRContext* context, const SymbolTable& symbolTable) - : OpConversionPattern(context), symbolTable(symbolTable) {} + : OpConversionPattern(context), symbolTable(symbolTable) {} mlir::LogicalResult matchAndRewrite( - CustomCallOp customCallOp, OpAdaptor adaptor, + CallOp callOp, OpAdaptor adaptor, mlir::ConversionPatternRewriter& rewriter) const override { - if (customCallOp.getCallTargetName() != - kManualComputationCustomCallTargetName) { + if (!absl::StartsWith(callOp.getCallee(), kManualComputationBodyFuncName)) { return mlir::failure(); } - CHECK_EQ(customCallOp.getCalledComputations().size(), 1); - auto shmapBodyFunc = - symbolTable.lookup((*customCallOp.getCalledComputations() - .getAsRange() - .begin()) - .getValue()); + // NOTE: if the original `ManualComputationOp` had no operands (results), + // then a @FullToShard (@ShardToFull) custom call won't be present. So + // we have to take the operands/results of the newly created + // `ManualComputationOp` differently depending on whether the original had + // operands/results. + CustomCallOp fullToShard; + mlir::ValueRange operands = callOp->getOperands(); + if (!operands.empty()) { + fullToShard = callOp->getOperand(0).getDefiningOp(); + CHECK(fullToShard); + CHECK(fullToShard.getCallTargetName() == + kGlobalToLocalShapeCallTargetName); + operands = fullToShard->getOperands(); + } + mlir::TypeRange resultTypes = callOp->getResultTypes(); + CustomCallOp shardToFull; + if (!resultTypes.empty()) { + CHECK(callOp->getResult(0).hasOneUse()) + << "all CallOp results should be used by a single ShardToFull"; + shardToFull = + mlir::cast(*callOp->getResult(0).getUsers().begin()); + CHECK(shardToFull.getCallTargetName() == + kLocalToGlobalShapeCallTargetName); + resultTypes = shardToFull->getResultTypes(); + } + + auto shmapBodyFunc = symbolTable.lookup(callOp.getCallee()); if (shmapBodyFunc.empty()) { - return customCallOp->emitOpError( + return callOp->emitOpError( "expected a unique FuncOp per " - "@local_xla.sdy.ManualComputation custom call. Were " + "@local_xla.sdy.manual_computation_body call. Were " "functions maybe somehow shared/de-duped between " "two ManualComputations?"); } - mlir::DictionaryAttr frontendAttrs = getFrontendAttrs(customCallOp); - CHECK(frontendAttrs); + mlir::DictionaryAttr frontendAttrs = getFrontendAttrs(callOp); + CHECK(frontendAttrs) + << "Expected in/out shardings and manual axes as frontend attrs on the " + "CallOp during round tripping."; auto manualComputationOp = rewriter.replaceOpWithNewOp( - customCallOp, customCallOp->getResultTypes(), - customCallOp->getOperands(), + callOp, resultTypes, operands, parseStringAttr(frontendAttrs, kInShardings), parseStringAttr(frontendAttrs, @@ -104,6 +131,12 @@ class ManualComputationPattern : public OpConversionPattern { sdy::inlineRegionAndConvertTerminatorOp( shmapBodyFunc.getBody(), manualComputationOp.getRegion(), rewriter); rewriter.eraseOp(shmapBodyFunc); + if (fullToShard) { + rewriter.eraseOp(fullToShard); + } + if (shardToFull) { + rewriter.replaceOp(shardToFull, manualComputationOp->getResults()); + } return mlir::success(); } @@ -124,10 +157,10 @@ class SdyRoundTripShardMapImportPass SymbolTable& symbolTable = symbolTableCollection.getSymbolTable(module); MLIRContext& context = getContext(); mlir::ConversionTarget target(context); - target.addDynamicallyLegalOp([](CustomCallOp op) { - return op.getCallTargetName() != kManualComputationCustomCallTargetName; + target.addDynamicallyLegalOp([](CallOp op) { + return !absl::StartsWith(op.getCallee(), kManualComputationBodyFuncName); }); - target.addLegalOp(); + target.addLegalOp(); mlir::RewritePatternSet patterns(&context); patterns.add(&context, symbolTable); if (mlir::failed(mlir::applyPartialConversion(module, target, @@ -141,9 +174,10 @@ class SdyRoundTripShardMapImportPass } StringRef getDescription() const override { - return "converts CustomCalls called @local_xla.sdy.manual_computation_body " - "with in/out shardings and manual axes as frontend attrs to a " - "`ManualComputationOp`"; + return "converts a CallOp calling a @local_xla.sdy.manual_computation_body func " + "with in/out shardings and manual axes as frontend attrs, wrapped " + "with a pair of `CustomCallOps` that change the shape of the " + "arguments/results, to a ManualComputationOp"; } void getDependentDialects(mlir::DialectRegistry& registry) const final { registry.insert(); diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.h b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.h index 1520c8baa663f7..e84304a177dce9 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.h +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/shard_map_import.h @@ -23,9 +23,10 @@ limitations under the License. namespace xla { namespace sdy { -// Creates the pass that converts a `CustomCallOp` called -// `kManualComputationBodyFuncName` with in/out shardings and manual -// axes as frontend attrs to a `ManualComputationOp`. +// Creates the pass that converts a `CallOp` calling +// `@local_xla.sdy.manual_computation_body` with in/out shardings and manual +// axes as frontend attrs, wrapped with a pair of `CustomCallOp`s that change +// the shape of the arguments/results, to a `ManualComputationOp`. std::unique_ptr createSdyRoundTripShardMapImportPass(); // Registers the xla-sdy-round-trip-shard-map-import pass. diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD index 448496e4e6de84..75974ab9c50c87 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/BUILD @@ -22,12 +22,12 @@ cc_library( deps = [ "//xla:shape_util", "//xla/hlo/ir:hlo", + "//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", + "//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "//xla/mlir_hlo", "//xla/mlir_hlo:mhlo_passes", "//xla/service:hlo_module_config", "//xla/service:hlo_proto_cc", - "//xla/translate/hlo_to_mhlo:hlo_to_mlir_hlo", - "//xla/translate/mhlo_to_hlo:mlir_hlo_to_hlo", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc index b9c55aebcdbf6b..da7bda8f60e3b9 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/test_utils/mhlo_to_hlo_to_mhlo.cc @@ -23,7 +23,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" -#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Quant/IR/Quant.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/DialectRegistry.h" @@ -36,13 +36,13 @@ limitations under the License. #include "shardy/dialect/sdy/ir/dialect.h" #include "stablehlo/dialect/StablehloOps.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" +#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" -#include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" -#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "tsl/platform/errors.h" namespace xla { @@ -121,7 +121,7 @@ class SdyRoundTripMhloToHloToMhloPass void getDependentDialects(mlir::DialectRegistry& registry) const final { registry.insert(); } }; diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner.h b/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner.h deleted file mode 100644 index 666e168322b5ab..00000000000000 --- a/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner.h +++ /dev/null @@ -1,59 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_SPMD_SHARDY_SHARDY_CALL_INLINER_H_ -#define XLA_SERVICE_SPMD_SHARDY_SHARDY_CALL_INLINER_H_ - -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/call_inliner.h" - -namespace xla { - -// The same as CallInliner, except as part of -// go/jax-shmap -> `sdy.ManualComputationOp` importing, we require the pattern -// in MHLO: -// ``` -// %shard_arg0_0 = custom_call @Sharding(%0) -// %shard_arg0_1 = custom_call @SPMDFullToShardShape(%shard_arg0_0) -// ... -// %shard_argN_0 = custom_call @Sharding(%N) -// %shard_argN_1 = custom_call @SPMDFullToShardShape(%shard_argN_0) -// -// %shard_result0, ..., %shard_resultN = func.call @shmap_body(%shard_arg0_1, -// ..., -// %shard_argN_1) -// -// %shard_result0_0 = custom_call @Sharding(%shard_result0) -// %shard_result0_1 = custom_call @SPMDShardToFullShape(%shard_result0_0) -// ... -// %shard_resultN_0 = custom_call @Sharding(%shard_resultN) -// %shard_resultN_1 = custom_call @SPMDShardToFullShape(%shard_resultN_0) -// ``` -// We specifically match on the `func.call @shmap_body` since we want to inline -// the body of that function into the `ManualComputationOp` body. So this makes -// sure we inline all functions except for the shmap_body's when using -// Shardy. When Shardy is disabled, then we have the same behavior as -// CallInliner. -class ShardyCallInliner : public CallInliner { - public: - using CallInliner::CallInliner; - absl::string_view name() const override { return "shardy-call-inliner"; } - - bool IsInlineableCallOp(HloInstruction* instruction) const override; -}; - -} // namespace xla - -#endif // XLA_SERVICE_SPMD_SHARDY_SHARDY_CALL_INLINER_H_ diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner_test.cc b/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner_test.cc deleted file mode 100644 index 00d952b3b80461..00000000000000 --- a/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner_test.cc +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/spmd/shardy/shardy_call_inliner.h" - -#include -#include "absl/log/log.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/tests/hlo_test_base.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace sdy { - -using ShardyCallInlinerTest = xla::HloTestBase; - -TEST_F(ShardyCallInlinerTest, MhloToHloShmapBodyNotInlined) { - const char* const hloString = R"( - HloModule jit_f, entry_computation_layout={(f32[8,8]{1,0})->f32[8,8]{1,0}} - - %prefix_shmap_body_suffix.4 (Arg_0.5: f32[1,8]) -> f32[1,8] { - %Arg_0.5 = f32[1,8]{1,0} parameter(0) - ROOT %add.6 = f32[1,8]{1,0} add(f32[1,8]{1,0} %Arg_0.5, f32[1,8]{1,0} %Arg_0.5), metadata={source_file="-" source_line=11} - } - - ENTRY %main.10 (Arg_0.1: f32[8,8]) -> f32[8,8] { - %Arg_0.1 = f32[8,8]{1,0} parameter(0) - %custom-call.2 = f32[8,8]{1,0} custom-call(f32[8,8]{1,0} %Arg_0.1), custom_call_target="Sharding", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=3} - %custom-call.3 = f32[1,8]{1,0} custom-call(f32[8,8]{1,0} %custom-call.2), custom_call_target="SPMDFullToShardShape", sharding={manual}, metadata={source_file="-" source_line=4} - %call.7 = f32[1,8]{1,0} call(f32[1,8]{1,0} %custom-call.3), to_apply=%prefix_shmap_body_suffix.4 - %custom-call.8 = f32[1,8]{1,0} custom-call(f32[1,8]{1,0} %call.7), custom_call_target="Sharding", sharding={manual}, metadata={source_file="-" source_line=6} - ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %custom-call.8), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7} - })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString)); - module->mutable_config().set_use_shardy_partitioner(true); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyCallInliner().Run(module.get())); - VLOG(1) << module->ToString(); - // The single call in the module is not inlined. - EXPECT_FALSE(changed); - - HloInstruction* call = FindInstruction(module.get(), xla::HloOpcode::kCall); - EXPECT_NE(call, nullptr); - EXPECT_TRUE(call->has_to_apply()); - EXPECT_EQ(call->to_apply()->name(), "prefix_shmap_body_suffix.4"); -} - -} // namespace sdy -} // namespace xla diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc index c5c8997a6d51e2..ea7bb1043eeafc 100644 --- a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc +++ b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/spmd/shardy/shardy_xla_pass.h" #include +#include #include #include #include @@ -45,25 +46,25 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" +#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/hlo/utils/hlo_sharding_util.h" #include "xla/layout.h" #include "xla/map_util.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/computation_layout.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_dce.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/service/spmd/shardy/constants.h" #include "xla/service/spmd/shardy/mhlo_round_trip/mhlo_export.h" #include "xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.h" #include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h" #include "xla/service/spmd/shardy/utils.h" -#include "xla/service/tuple_simplifier.h" #include "xla/shape.h" #include "xla/shape_layout.h" #include "xla/shape_util.h" -#include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" -#include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/tsl/framework/mlir/status_scoped_diagnostic_handler.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -308,11 +309,22 @@ absl::StatusOr ShardyXLA::Run( /*flatten_computation_args_result=*/true)); std::string shardyDir = hloModule->config().debug_options().xla_dump_to(); + + if (shardyDir == "sponge") { + shardyDir = getenv("TEST_UNDECLARED_OUTPUTS_DIR"); + if (shardyDir.empty()) { + LOG(WARNING) << "\"sponge\" specified as dump directory but " + "TEST_UNDECLARED_OUTPUTS_DIR is not set!"; + } + } + if (!shardyDir.empty()) { shardyDir = tsl::io::JoinPath(shardyDir, "shardy", std::string_view(mlirModule->getName().value_or(""))); + LOG(INFO) << "Using Shardy output directory: " << shardyDir; } + // MLIR pipeline: (1) import, (2) Shardy, and (3) export. bool enableVerifier = false; @@ -326,19 +338,13 @@ absl::StatusOr ShardyXLA::Run( "sdy_module_before_xla_import")); bool useTupleArgs = false; mlir::DictionaryAttr moduleFrontendAttrs = getFrontendAttrs(*mlirModule); - if (moduleFrontendAttrs && moduleFrontendAttrs.get(kUseTupleArgs)) { + if (hasKey(moduleFrontendAttrs, kUseTupleArgs)) { useTupleArgs = true; removeFrontendAttribute(*mlirModule, kUseTupleArgs); } - // TODO(bartchr): Only call addSdyRoundTripImportPipeline when JAX & PartIR - // integration is complete. Need to branch on `kPythonIntegrationComplete` - // since partir.jit lowers to SDY (so call addSdyRoundTripImportPipeline) but - // jax.jit doesn't yet (so call addMhloImportPipeline). - if (moduleFrontendAttrs && - moduleFrontendAttrs.get(kPythonIntegrationComplete)) { - removeFrontendAttribute(*mlirModule, kPythonIntegrationComplete); - addSdyRoundTripImportPipeline(pm); - } else { + + if (hasKey(moduleFrontendAttrs, kImportMhloShardings)) { + removeFrontendAttribute(*mlirModule, kImportMhloShardings); auto spanToArrayRef = [](absl::Span span) { return mlir::ArrayRef(span.data(), span.size()); }; @@ -349,6 +355,9 @@ absl::StatusOr ShardyXLA::Run( .allow_spmd_sharding_propagation_to_parameters()), spanToArrayRef( hloModule->config().allow_spmd_sharding_propagation_to_output())); + } else { + // This is the default path. + addSdyRoundTripImportPipeline(pm); } // Store the entry computation layout, input-output alias config, and buffer @@ -415,9 +424,7 @@ absl::StatusOr ShardyXLA::Run( // We don't fully replace the HLO module, so it will continue to have the // temporary frontend attributes. So clean them up as XLA won't need them. - removeFrontendAttributes( - hloModule, - {kUseTupleArgs, kPythonIntegrationComplete, kMeshesRoundTripAttr}); + removeFrontendAttributes(hloModule, {kUseTupleArgs, kMeshesRoundTripAttr}); return true; } diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc index 40463d0dc74fce..df52bb57562a57 100644 --- a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc +++ b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" +#include "xla/service/spmd/shardy/constants.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" #include "tsl/platform/statusor.h" @@ -32,8 +33,21 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace sdy { +namespace { + using ShardyXLATest = xla::HloTestBase; +void runShardy(VerifiedHloModule* module) { + FrontendAttributes attrs; + attrs.mutable_map()->try_emplace(xla::sdy::kImportMhloShardings, "t"); + module->add_frontend_attributes(attrs); + TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyXLA().Run(module)); + VLOG(1) << module->ToString(); + EXPECT_TRUE(changed); +} + +} // namespace + TEST_F(ShardyXLATest, AllowSpmdShardingPropagationParametersOutputRespected) { const char* const hloString = R"( HloModule module, allow_spmd_sharding_propagation_to_parameters={false,true}, allow_spmd_sharding_propagation_to_output={true} @@ -47,9 +61,7 @@ TEST_F(ShardyXLATest, AllowSpmdShardingPropagationParametersOutputRespected) { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyXLA().Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_TRUE(changed); + runShardy(module.get()); EXPECT_THAT(module->entry_computation()->parameter_instruction(0), op::Sharding("{replicated}")); @@ -76,9 +88,7 @@ TEST_F(ShardyXLATest, ElementWise) { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyXLA().Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_TRUE(changed); + runShardy(module.get()); HloInstruction* add = FindInstruction(module.get(), xla::HloOpcode::kAdd); EXPECT_NE(add, nullptr); @@ -113,9 +123,7 @@ TEST_F(ShardyXLATest, CostantSplitter) { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyXLA().Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_TRUE(changed); + runShardy(module.get()); HloInstruction* dot = FindInstruction(module.get(), xla::HloOpcode::kDot); @@ -163,9 +171,7 @@ TEST_F(ShardyXLATest, Dot) { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyXLA().Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_TRUE(changed); + runShardy(module.get()); EXPECT_THAT(module->entry_computation()->parameter_instruction(0), op::Sharding("{devices=[2,2,1,2]<=[8] last_tile_dim_replicate}")); @@ -200,9 +206,7 @@ TEST_F(ShardyXLATest, DotTiledBatchDim) { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyXLA().Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_TRUE(changed); + runShardy(module.get()); EXPECT_THAT(module->entry_computation()->parameter_instruction(0), op::Sharding("{devices=[2,2,1]<=[4]}")); @@ -228,9 +232,7 @@ TEST_F(ShardyXLATest, DotMergeOperands1) { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyXLA().Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_TRUE(changed); + runShardy(module.get()); EXPECT_THAT(module->entry_computation()->parameter_instruction(0), op::Sharding("{devices=[2,2,1,2]<=[8] last_tile_dim_replicate}")); @@ -256,9 +258,7 @@ TEST_F(ShardyXLATest, DotMergeOperands2) { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyXLA().Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_TRUE(changed); + runShardy(module.get()); EXPECT_THAT(module->entry_computation()->parameter_instruction(0), op::Sharding("{devices=[2,2,2]<=[8]}")); @@ -281,9 +281,7 @@ TEST_F(ShardyXLATest, DotMergeOperands3) { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyXLA().Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_TRUE(changed); + runShardy(module.get()); EXPECT_THAT(module->entry_computation()->parameter_instruction(0), op::Sharding("{devices=[2,4]<=[8]}")); @@ -309,9 +307,7 @@ TEST_F(ShardyXLATest, BackwardDotFromContracting) { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyXLA().Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_TRUE(changed); + runShardy(module.get()); EXPECT_THAT(module->entry_computation()->parameter_instruction(0), op::Sharding("{devices=[2,2,2]<=[8]}")); @@ -336,9 +332,7 @@ TEST_F(ShardyXLATest, EntryComputationLayoutSingleResult) { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyXLA().Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_TRUE(changed); + runShardy(module.get()); EXPECT_EQ( module->entry_computation_layout().ToString(), @@ -357,9 +351,7 @@ TEST_F(ShardyXLATest, EntryComputationLayoutNestedTuple) { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyXLA().Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_TRUE(changed); + runShardy(module.get()); EXPECT_EQ(module->entry_computation_layout().ToString(), "(f32[4,2]{0,1:T(2,128)}, f32[4,2]{0,1:T(2,128)}, " @@ -383,9 +375,7 @@ TEST_F(ShardyXLATest, EntryComputationLayoutMissingLayout) { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyXLA().Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_TRUE(changed); + runShardy(module.get()); EXPECT_EQ(module->entry_computation_layout().ToString(), "(f32[3,8,32,4]{2,1,3,0:T(8,128)}, " @@ -404,9 +394,7 @@ TEST_F(ShardyXLATest, InputOutputAliasConfigSingleResult) { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyXLA().Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_TRUE(changed); + runShardy(module.get()); EXPECT_EQ(module->input_output_alias_config().ToShortString(), "{}: (1, {}, may-alias)"); @@ -425,9 +413,7 @@ TEST_F(ShardyXLATest, InputOutputAliasConfigSingleResultNestedParams) { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyXLA().Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_TRUE(changed); + runShardy(module.get()); EXPECT_EQ(module->input_output_alias_config().ToShortString(), "{}: (1, {}, may-alias)"); @@ -444,9 +430,7 @@ TEST_F(ShardyXLATest, InputOutputAliasConfigNestedResultAndParams) { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyXLA().Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_TRUE(changed); + runShardy(module.get()); EXPECT_EQ(module->input_output_alias_config().ToShortString(), "{1}: (1, {}, may-alias), {3}: (3, {}, may-alias)"); @@ -464,9 +448,7 @@ TEST_F(ShardyXLATest, BufferDonorConfigSingleResult) { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyXLA().Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_TRUE(changed); + runShardy(module.get()); EXPECT_EQ(module->buffer_donor_config().ToShortString(), "(1, {})"); } @@ -482,9 +464,7 @@ TEST_F(ShardyXLATest, BufferDonorConfigNestedTuple) { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyXLA().Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_TRUE(changed); + runShardy(module.get()); EXPECT_EQ(module->buffer_donor_config().ToShortString(), "(0, {}), (2, {})"); } @@ -500,9 +480,7 @@ TEST_F(ShardyXLATest, ShardingCustomCall) { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyXLA().Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_TRUE(changed); + runShardy(module.get()); EXPECT_THAT(module->entry_computation()->parameter_instruction(0), op::Sharding("{devices=[2,1]<=[2]}")); @@ -525,9 +503,7 @@ TEST_F(ShardyXLATest, RngBitGenerator) { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyXLA().Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_TRUE(changed); + runShardy(module.get()); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Sharding("{{devices=[16,2]<=[32]}, {devices=[8,4]<=[32]}}")); @@ -574,9 +550,7 @@ TEST_F(ShardyXLATest, WhileWithFreeVariables) { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyXLA().Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_TRUE(changed); + runShardy(module.get()); HloInstruction* whileInst = FindInstruction(module.get(), xla::HloOpcode::kWhile); @@ -622,9 +596,7 @@ TEST_F(ShardyXLATest, ShardMap) { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyXLA().Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_TRUE(changed); + runShardy(module.get()); // The entry computation and the region_add for the all-reduce. shmap_body is // inlined. @@ -653,9 +625,8 @@ TEST_F(ShardyXLATest, EmptyModule) { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyXLA().Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_TRUE(changed); + runShardy(module.get()); + EXPECT_EQ(module->entry_computation_layout().ToString(), "()->()"); EXPECT_EQ(module->input_output_alias_config().ToShortString(), ""); } @@ -679,9 +650,8 @@ TEST_F(ShardyXLATest, TestUseTuplesTrue) { })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyXLA().Run(module.get())); - VLOG(1) << module->ToString(); - EXPECT_TRUE(changed); + runShardy(module.get()); + EXPECT_EQ(module->entry_computation()->parameter_instructions().size(), 1); EXPECT_EQ(module->buffer_donor_config().ToShortString(), "(0, {1})"); EXPECT_EQ(module->input_output_alias_config().ToShortString(), diff --git a/third_party/xla/xla/service/spmd/shardy/test/export_named_computations.mlir b/third_party/xla/xla/service/spmd/shardy/test/export_named_computations.mlir new file mode 100644 index 00000000000000..2df5e216bc0c05 --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/test/export_named_computations.mlir @@ -0,0 +1,70 @@ +// RUN: sdy_opt %s -xla-sdy-export-named-computations 2>&1 | FileCheck %s + +sdy.mesh @mesh = <["x"=2, "y"=2]> + +// Note we don't override the block argument shardings of the function +// @ignore_operand_shardings, but we set the argument shardings on the call +// to @foo. +// CHECK-LABEL: func @ignore_operand_shardings( +// CHECK-SAME: %arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {"x"}]>}) +// CHECK-SAME: -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) { +func.func @ignore_operand_shardings(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {"x"}]>}) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) { + // CHECK-NEXT: %[[CALL:.*]] = call @foo(%arg0) + // CHECK-SAME: {mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}, + // CHECK-SAME: random_attr = "random_value", + // CHECK-SAME: sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} + // CHECK-SAME: : (tensor<8x2xi32>) -> tensor<8x2xi32> + // CHECK-NEXT: %[[MOVE_TO_HOST:.*]] = stablehlo.custom_call @MoveToHost(%[[CALL]]) {backend_config = "", sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"y", ?}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32> + // CHECK-NEXT: return %[[MOVE_TO_HOST]] : tensor<8x2xi32> + %0 = sdy.named_computation<"foo">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {}]>] (%arg1: tensor<8x2xi32>) { + %2 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {"y", ?}]>]>} : tensor<8x2xi32> + sdy.return %2 : tensor<8x2xi32> + } {random_attr = "random_value", mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}} : (tensor<8x2xi32>) -> tensor<8x2xi32> + %1 = stablehlo.custom_call @MoveToHost(%0) {backend_config = "", sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"y", ?}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32> + return %1 : tensor<8x2xi32> +} + +// CHECK-LABEL: func @vanilla_named_computation(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { +func.func @vanilla_named_computation(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { + // CHECK-NEXT: %[[CALL:.*]] = call @bar(%arg0) : (tensor<8x2xi32>) -> tensor<8x2xi32> + // CHECK-NEXT: return %[[CALL]] : tensor<8x2xi32> + %0 = sdy.named_computation<"bar">(%arg0) (%arg1: tensor<8x2xi32>) { + %1 = stablehlo.multiply %arg1, %arg1 : tensor<8x2xi32> + sdy.return %1 : tensor<8x2xi32> + } : (tensor<8x2xi32>) -> tensor<8x2xi32> + return %0 : tensor<8x2xi32> +} + +// CHECK-LABEL: func @multiple_same_named_computations( +func.func @multiple_same_named_computations(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {"x"}]>}) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) { + // CHECK-NEXT: %0 = call @baz(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32> + // CHECK-NEXT: %1 = call @baz_0(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {}]>]>} : (tensor<8x2xi32>) -> tensor<8x2xi32> + // CHECK-NEXT: return %1 : tensor<8x2xi32> + %0 = sdy.named_computation<"baz">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {}]>] (%arg1: tensor<8x2xi32>) { + %2 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {"y", ?}]>]>} : tensor<8x2xi32> + sdy.return %2 : tensor<8x2xi32> + } : (tensor<8x2xi32>) -> tensor<8x2xi32> + %1 = sdy.named_computation<"baz">(%arg0) in_shardings=[<@mesh, [{}, {"y"}]>] out_shardings=[<@mesh, [{"x"}, {}]>] (%arg1: tensor<8x2xi32>) { + %3 = stablehlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {"y", ?}]>]>} : tensor<8x2xi32> + sdy.return %3 : tensor<8x2xi32> + } : (tensor<8x2xi32>) -> tensor<8x2xi32> + return %1 : tensor<8x2xi32> +} + +// CHECK-LABEL: func private @foo +// CHECK-SAME: (%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}) +// CHECK-SAME: -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) { +// CHECK-NEXT: %[[MULT:.*]] = stablehlo.multiply %arg0, %arg0 {mhlo.frontend_attributes = {_xla_compute_type = "host"}, sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x", ?}, {"y", ?}]>]>} : tensor<8x2xi32> +// CHECK-NEXT: return %0 : tensor<8x2xi32> + +// CHECK-LABEL: func private @bar(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { +// CHECK-NEXT: %[[MULT:.*]] = stablehlo.multiply %arg0, %arg0 : tensor<8x2xi32> +// CHECK-NEXT: return %0 : tensor<8x2xi32> + +// CHECK-LABEL: func private @baz( +// CHECK-SAME: %arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}) +// CHECK-SAME: -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) + +// CHECK-LABEL: func private @baz_0( +// CHECK-SAME: %arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}) +// CHECK-SAME: -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) diff --git a/third_party/xla/xla/service/spmd/shardy/test/import_backend_func_calls.mlir b/third_party/xla/xla/service/spmd/shardy/test/import_backend_func_calls.mlir new file mode 100644 index 00000000000000..35c4d62e8d099d --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/test/import_backend_func_calls.mlir @@ -0,0 +1,58 @@ +// RUN: sdy_opt %s -xla-sdy-import-backend-func-calls 2>&1 | FileCheck %s + +sdy.mesh @mesh = #sdy.mesh<["x"=2, "y"=2]> + +// CHECK-LABEL: func @no_out_shardings +func.func @no_out_shardings(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) { + // CHECK-NEXT: %[[NC:.*]] = sdy.named_computation<"foo">(%arg0) (%arg1: tensor<8x2xi32>) { + // CHECK-NEXT: %[[MULT:.*]] = mhlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}} : tensor<8x2xi32> + // CHECK-NEXT: sdy.return %[[MULT]] : tensor<8x2xi32> + // CHECK-NEXT: } {mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}, + // CHECK-SAME: random_attr = "random_value"} + // CHECK-SAME: (tensor<8x2xi32>) -> tensor<8x2xi32> + // CHECK-NEXT: %[[MOVE_TO_HOST:.*]] = mhlo.custom_call @MoveToHost(%[[NC]]) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> + // CHECK-NEXT: return %[[MOVE_TO_HOST]] : tensor<8x2xi32> + %0 = call @foo(%arg0) {random_attr = "random_value", mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}} : (tensor<8x2xi32>) -> tensor<8x2xi32> + %1 = mhlo.custom_call @MoveToHost(%0) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> + return %1 : tensor<8x2xi32> +} + +func.func private @foo(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { + %0 = mhlo.multiply %arg0, %arg0 {mhlo.frontend_attributes = {_xla_compute_type = "host"}} : tensor<8x2xi32> + return %0 : tensor<8x2xi32> +} + +// CHECK-LABEL: func @out_shardings +func.func @out_shardings(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) { + // CHECK-NEXT: %[[NC:.*]] = sdy.named_computation<"bar">(%arg0) out_shardings=[<@mesh, [{"x"}, {"y"}]>] (%arg1: tensor<8x2xi32>) { + // CHECK-NEXT: %[[MULT:.*]] = mhlo.multiply %arg1, %arg1 {mhlo.frontend_attributes = {_xla_compute_type = "host"}} : tensor<8x2xi32> + // CHECK-NEXT: sdy.return %[[MULT]] : tensor<8x2xi32> + // CHECK-NEXT: } {mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}, + // CHECK-SAME: random_attr = "random_value"} + // CHECK-SAME: (tensor<8x2xi32>) -> tensor<8x2xi32> + // CHECK-NEXT: %[[MOVE_TO_HOST:.*]] = mhlo.custom_call @MoveToHost(%[[NC]]) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> + // CHECK-NEXT: return %[[MOVE_TO_HOST]] : tensor<8x2xi32> + %0 = call @bar(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>, random_attr = "random_value", mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}} : (tensor<8x2xi32>) -> tensor<8x2xi32> + %1 = mhlo.custom_call @MoveToHost(%0) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> + return %1 : tensor<8x2xi32> +} + +// NOTE: we ignore any arg/result shardings on the function. +func.func private @bar(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}) { + %0 = mhlo.multiply %arg0, %arg0 {mhlo.frontend_attributes = {_xla_compute_type = "host"}} : tensor<8x2xi32> + return %0 : tensor<8x2xi32> +} + +// Don't import if there is no backend_config. +// CHECK-LABEL: func @no_backend_config +func.func @no_backend_config(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) { + // CHECK-NEXT: %[[CALL:.*]] = call @baz(%arg0) : (tensor<8x2xi32>) -> tensor<8x2xi32> + // CHECK-NEXT: return %[[CALL]] : tensor<8x2xi32> + %0 = call @baz(%arg0) : (tensor<8x2xi32>) -> tensor<8x2xi32> + return %0 : tensor<8x2xi32> +} + +func.func private @baz(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { + %0 = mhlo.multiply %arg0, %arg0 : tensor<8x2xi32> + return %0 : tensor<8x2xi32> +} diff --git a/third_party/xla/xla/service/spmd/shardy/test/import_backend_func_calls_failure.mlir b/third_party/xla/xla/service/spmd/shardy/test/import_backend_func_calls_failure.mlir new file mode 100644 index 00000000000000..d663bcf3ce5ac7 --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/test/import_backend_func_calls_failure.mlir @@ -0,0 +1,16 @@ +// RUN: sdy_opt %s -xla-sdy-import-backend-func-calls -split-input-file -verify-diagnostics + +sdy.mesh @mesh = #sdy.mesh<["x"=2, "y"=2]> + +func.func @out_shardings(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) { + %0 = call @bar(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>, random_attr = "random_value", mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}} : (tensor<8x2xi32>) -> tensor<8x2xi32> + // expected-error @+1 {{failed to legalize operation 'func.call' that was explicitly marked illegal}} + %1 = call @bar(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"x"}, {"y"}]>]>, random_attr = "random_value", mhlo.frontend_attributes = {backend_config = "{\22flag_configs\22:[],\22scoped_memory_configs\22:[],\22device_type\22:\22DEVICE_TYPE_HOST\22,\22used_scoped_memory_configs\22:[]}"}} : (tensor<8x2xi32>) -> tensor<8x2xi32> + %2 = mhlo.custom_call @MoveToHost(%1) {backend_config = ""} : (tensor<8x2xi32>) -> tensor<8x2xi32> + return %2 : tensor<8x2xi32> +} + +func.func private @bar(%arg0: tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {}]>}) -> (tensor<8x2xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}]>}) { + %0 = mhlo.multiply %arg0, %arg0 {mhlo.frontend_attributes = {_xla_compute_type = "host"}} : tensor<8x2xi32> + return %0 : tensor<8x2xi32> +} diff --git a/third_party/xla/xla/service/spmd/shardy/test/import_shardings.mlir b/third_party/xla/xla/service/spmd/shardy/test/import_shardings.mlir index 045f49b020d0de..9cc62dd41959b7 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/import_shardings.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/import_shardings.mlir @@ -19,10 +19,11 @@ func.func @non_trivial_common_mesh(%arg0: tensor<8x8xf32> {mhlo.sharding = "{dev // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"axis_2"}, {"axis_0", "axis_1"}]>}, // CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"axis_0", "axis_2"}]>}, // CHECK-SAME: %arg2: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"axis_1"}]>}) -// CHECK-SAME: -> tensor<8x16xf32> { +// CHECK-SAME: -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"axis_0", "axis_1"}, {"axis_2"}]>}) { func.func @multiple_shardings(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4,8]<=[8,4]T(1,0)}"}, %arg1: tensor<8x8xf32> {mhlo.sharding = "{devices=[1,8,4]<=[2,4,4]T(0,2,1) last_tile_dim_replicate}"}, - %arg2: tensor<8x16xf32> {mhlo.sharding = "{devices=[1,4,8]<=[2,4,4]T(1,0,2) last_tile_dim_replicate}"}) -> tensor<8x16xf32> { + %arg2: tensor<8x16xf32> {mhlo.sharding = "{devices=[1,4,8]<=[2,4,4]T(1,0,2) last_tile_dim_replicate}"}) + -> (tensor<8x16xf32> {mhlo.sharding = "{devices=[8,4]<=[32]}"}) { // CHECK-NEXT: mhlo.add // CHECK-SAME{LITERAL}: {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"axis_1", "axis_0"}, {}]>]>} %0 = mhlo.add %arg0, %arg1 {mhlo.sharding = "{devices=[8,1,4]<=[2,4,4]T(1,0,2) last_tile_dim_replicate}"} : tensor<8x8xf32> diff --git a/third_party/xla/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir index 0884347694aa77..0fe29bef5870a7 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir @@ -29,10 +29,11 @@ func.func @non_trivial_common_mesh(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.s // CHECK-SAME: %arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4,8]<=[8,4]T(1,0)}"}, // CHECK-SAME: %arg1: tensor<8x8xf32> {mhlo.sharding = "{devices=[1,8,4]<=[2,4,4]T(0,2,1) last_tile_dim_replicate}"}, // CHECK-SAME: %arg2: tensor<8x16xf32> {mhlo.sharding = "{devices=[1,4,8]<=[2,4,4]T(1,0,2) last_tile_dim_replicate}"}) -// CHECK-SAME: -> tensor<8x16xf32> { +// CHECK-SAME: -> (tensor<8x16xf32> {mhlo.sharding = "{devices=[8,4]<=[32]}"}) { func.func @multiple_shardings(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"axis_2"}, {"axis_0", "axis_1"}]>}, %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {"axis_0", "axis_2"}]>}, - %arg2: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {"axis_1"}]>}) -> tensor<8x16xf32> { + %arg2: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {"axis_1"}]>}) + -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"axis_0", "axis_1"}, {"axis_2"}]>}) { // CHECK-NEXT: mhlo.add // CHECK-SAME{LITERAL}: {mhlo.sharding = "{devices=[8,1,4]<=[2,4,4]T(1,0,2) last_tile_dim_replicate}"} %0 = mhlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"axis_1","axis_0"}, {}]>]>} : tensor<8x8xf32> @@ -174,10 +175,37 @@ func.func @mesh_empty_should_be_converted_to_replicated_sharding(%arg0: tensor<8 func.func @multiple_shardings_with_device_list(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_4, [{"axis_2"}, {"axis_0", "axis_1"}]>}, %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_4, [{}, {"axis_0", "axis_2"}]>}, %arg2: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_4, [{}, {"axis_1"}]>}) -> tensor<8x16xf32> { -// CHECK-NEXT: mhlo.add -// CHECK-SAME{LITERAL}: {mhlo.sharding = "{devices=[4,1,2]0,2,1,3,4,6,5,7 last_tile_dim_replicate}"} + // CHECK-NEXT: mhlo.add + // CHECK-SAME{LITERAL}: {mhlo.sharding = "{devices=[4,1,2]0,2,1,3,4,6,5,7 last_tile_dim_replicate}"} %0 = mhlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_4, [{"axis_1","axis_0"}, {}]>]>} : tensor<8x8xf32> %1 = "mhlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> } +// CHECK-LABEL: func @named_sharding_in_manual_computation( +// CHECK-SAME: %arg0: tensor<32x2xi32> {mhlo.sharding = "{devices=[32,1]<=[32]}"}) +// CHECK-SAME: -> (tensor<32x2xi32> {mhlo.sharding = "{devices=[32,1]<=[32]}"}) { +func.func @named_sharding_in_manual_computation( + %arg0: tensor<32x2xi32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", "y"}, {}]>}) + -> (tensor<32x2xi32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", "y"}, {}]>}) { + // CHECK-NEXT: %[[COPY_0:.*]] = mhlo.copy %arg0 {mhlo.sharding = "{devices=[32,1]<=[32]}"} : tensor<32x2xi32> + // CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = mhlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{devices=[1,1,8,4]<=[32] last_tile_dims={manual, replicated}}"} : (tensor<32x2xi32>) -> tensor<4x2xi32> + // CHECK-NEXT: %[[FOO:.*]] = call @foo(%[[FULL_TO_SHARD]]) {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"} : (tensor<4x2xi32>) -> tensor<4x2xi32> + // CHECK-NEXT: %[[COPY_1:.*]] = mhlo.copy %[[FOO]] {mhlo.sharding = "{devices=[1,1,8,4]<=[32] last_tile_dims={manual, replicated}}"} : tensor<4x2xi32> + // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = mhlo.custom_call @SPMDShardToFullShape(%[[COPY_1]]) {mhlo.sharding = "{devices=[32,1]<=[32]}"} : (tensor<4x2xi32>) -> tensor<32x2xi32> + // CHECK-NEXT: return %[[SHARD_TO_FULL]] : tensor<32x2xi32> + %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_2, [{"x", "y"}, {}]>] out_shardings=[<@mesh_2, [{"x", "y"}, {}]>] manual_axes={"x"} (%arg1: tensor<4x2xi32>) { + %1 = sdy.named_computation<"foo">(%arg1) in_shardings=[<@mesh_2, [{"y"}, {}]>] out_shardings=[<@mesh_2, [{"y"}, {}]>] (%arg2: tensor<4x2xi32>) { + %2 = stablehlo.multiply %arg2, %arg2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"y"}, {}]>]>} : tensor<4x2xi32> + sdy.return %2 : tensor<4x2xi32> + } : (tensor<4x2xi32>) -> tensor<4x2xi32> + sdy.return %1 : tensor<4x2xi32> + } : (tensor<32x2xi32>) -> tensor<32x2xi32> + return %0 : tensor<32x2xi32> +} + +// CHECK-LABEL: func private @foo +// CHECK-SAME: %arg0: tensor<4x2xi32> {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"} +// CHECK-SAME: -> (tensor<4x2xi32> {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dim_replicate}"}) { +// CHECK-NEXT: %[[MULT:.*]] = stablehlo.multiply %arg0, %arg0 {mhlo.sharding = "{devices=[4,1,8]<=[8,4]T(1,0) last_tile_dims={manual}}"} : tensor<4x2xi32> +// CHECK-NEXT: return %[[MULT]] : tensor<4x2xi32> diff --git a/third_party/xla/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir index 582e8d2a48935e..55ccddd9645d5e 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir @@ -29,7 +29,7 @@ func.func @sharding_custom_call_with_unspecified_dims(%arg0: tensor<8x8xf32> {mh func.func @manual(%arg0: tensor<8x8xf32> {mhlo.sharding = "{replicated}"}, %arg1: tensor<4x8xf32> {mhlo.sharding = "{devices=[4,1,2]<=[8] last_tile_dim_replicate}"}) -> (tensor<8x8xf32>) { // CHECK: sdy.manual_computation(%arg0, %arg1) - // CHECK-SAME: in_shardings=[<@mesh, [{"axis_0", "axis_1"}, {}]>, <@mesh, [{"axis_0"}, {}], replicated={"axis_1"}>] + // CHECK-SAME: in_shardings=[<@mesh, [{"axis_0", "axis_1"}, {}]>, <@mesh, [{"axis_0"}, {}]>] // CHECK-SAME: out_shardings=[<@mesh, [{"axis_0", "axis_1"}, {}]>] // CHECK-SAME: manual_axes={"axis_0", "axis_1"} (%arg2: tensor<1x8xf32>, %arg3: tensor<1x8xf32>) { // CHECK-LABEL: mhlo.add diff --git a/third_party/xla/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_import.mlir b/third_party/xla/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_import.mlir index 645c059427cb76..12641b0d746476 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_import.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/mhlo_round_trip_shard_map_import.mlir @@ -261,10 +261,10 @@ func.func private @shmap_body_6(%arg0: tensor<16x8xf32>) -> (tensor<16x8xf32>) { return %0 : tensor<16x8xf32> } -// CHECK-LABEL: func.func public @sorted_replicated_axes -func.func public @sorted_replicated_axes(%arg0: tensor<16x16xf32>) -> tensor<32x4xf32> { +// CHECK-LABEL: func.func public @sharding_with_missing_manual_axes +func.func public @sharding_with_missing_manual_axes(%arg0: tensor<16x16xf32>) -> tensor<32x4xf32> { // CHECK: %0 = sdy.manual_computation(%arg0) - // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_2, [{"b"}, {"a"}], replicated={"c"}>] out_shardings=[<@mesh_2, [{"a"}, {}], replicated={"b", "c"}>] manual_axes={"a", "b", "c"} (%arg1: tensor<8x4xf32>) { + // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_2, [{"b"}, {"a"}]>] out_shardings=[<@mesh_2, [{"a"}, {}], replicated={"c"}>] manual_axes={"a", "b", "c"} (%arg1: tensor<8x4xf32>) { // CHECK-NEXT: %1 = mhlo.add %arg1, %arg1 : tensor<8x4xf32> // CHECK-NEXT: sdy.return %1 : tensor<8x4xf32> // CHECK-NEXT: } : (tensor<16x16xf32>) -> tensor<32x4xf32> diff --git a/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir index 037368f8e305bd..ae6e1640c50a04 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir @@ -5,6 +5,21 @@ // These would be needed to work for round-tripping in JAX integration. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Basic test with no meshes or shardings + +// CHECK-NOT: sdy.mesh + +// CHECK-LABEL: func @main( +// CHECK-SAME: %arg0: tensor<8x16xf32>) +func.func @main( + %arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>) { + %0 = mhlo.add %arg0, %arg0 : tensor<8x16xf32> + %1 = mhlo.add %0, %0 : tensor<8x16xf32> + return %1 : tensor<8x16xf32> +} + +// ----- + // Basic test with func arg sharding // Make sure this temp attr doesn't exist anymore. @@ -202,6 +217,17 @@ func.func @main( return %2#0 : tensor<32x96xf32> } +// ----- + +// Test that sharding group op is preserved under import and export passes. + +// CHECK-LABEL: func @main +func.func @main(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>) { + // CHECK: sdy.sharding_group %arg0 group_id=13 : tensor<8x16xf32> + sdy.sharding_group %arg0 group_id=13 : tensor<8x16xf32> + return %arg0 : tensor<8x16xf32> +} + // TODO(b/335481977): Add more tests for MHLO ops. So far tested all SDY // compiler APIs other than shard as/like (doesn't exist yet). See // round_trip_pipeline_manual_computation.mlir for ManualComputationOp tests. diff --git a/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline_manual_computation.mlir b/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline_manual_computation.mlir index 37e94fbb510bb6..90754f8e9bf0a2 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline_manual_computation.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline_manual_computation.mlir @@ -18,17 +18,13 @@ func.func @main(%arg0: tensor<16x32xf32>) -> tensor<128x32xf32> { // CHECK-NEXT: } : (tensor<16x32xf32>) -> (tensor<128x32xf32>, tensor<128x32xf32>) // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[SHARD_MAP]]#0, %[[SHARD_MAP]]#1 : tensor<128x32xf32> // CHECK-NEXT: return %[[ADD]] : tensor<128x32xf32> - %0 = mhlo.custom_call @Sharding(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{}, {}], replicated={"a", "b"}>]>} : (tensor<16x32xf32>) -> tensor<16x32xf32> - %1 = mhlo.custom_call @SPMDFullToShardShape(%0) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %2:2 = call @shmap_body_4(%1) : (tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>) - %3 = mhlo.custom_call @Sharding(%2#0) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %4 = mhlo.custom_call @SPMDShardToFullShape(%3) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"a", "b"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<128x32xf32> - %5 = mhlo.custom_call @Sharding(%2#1) : (tensor<16x32xf32>) -> tensor<16x32xf32> - %6 = mhlo.custom_call @SPMDShardToFullShape(%5) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_1, [{"b", "a"}, {}]>]>} : (tensor<16x32xf32>) -> tensor<128x32xf32> - %7 = mhlo.add %4, %6 : tensor<128x32xf32> - return %7 : tensor<128x32xf32> + %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<16x32xf32>) -> tensor<16x32xf32> + %1:2 = call @local_xla.sdy.manual_computation_body(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {}], replicated={\\\22a\\\22, \\\22b\\\22}>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22, \\\22b\\\22}, {}]>, <@mesh_1, [{\\\22b\\\22, \\\22a\\\22}, {}]>]>"}} : (tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>) + %2:2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1#0, %1#1) : (tensor<16x32xf32>, tensor<16x32xf32>) -> (tensor<128x32xf32>, tensor<128x32xf32>) + %3 = mhlo.add %2#0, %2#1 : tensor<128x32xf32> + return %3 : tensor<128x32xf32> } -// CHECK-NOT: func.func private @shmap_body_4 -func.func private @shmap_body_4(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>) { +// CHECK-NOT: func.func private @local_xla.sdy.manual_computation_body +func.func private @local_xla.sdy.manual_computation_body(%arg0: tensor<16x32xf32>) -> (tensor<16x32xf32>, tensor<16x32xf32>) { return %arg0, %arg0 : tensor<16x32xf32>, tensor<16x32xf32> } diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_inline_round_trip.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_inline_round_trip.mlir new file mode 100644 index 00000000000000..d0ed401a2a4299 --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_inline_round_trip.mlir @@ -0,0 +1,31 @@ +// RUN: sdy_opt %s -xla-sdy-round-trip-export-pipeline -inline -xla-sdy-round-trip-testing-pipeline -split-input-file 2>&1 | FileCheck %s + +// Test with a nested func op that gets inlined after first export. + +// Make sure this temp attr doesn't exist anymore. +// CHECK-NOT: xla.sdy.sharding + +// CHECK: sdy.mesh @mesh = <["a"=2, "b"=2, "c"=2]> +sdy.mesh @mesh = <["a"=2, "b"=2, "c"=2]> + +// CHECK-LABEL: func @main( +// CHECK-SAME: %arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}) +// CHECK-SAME: -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}, {}]>}) +func.func @main(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}) + -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}, {}]>}) { + // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %arg0, %arg0 : tensor<8x16xf32> + // CHECK-NEXT: %[[MUL:.*]] = mhlo.multiply %[[ADD_0]], %[[ADD_0]] : tensor<8x16xf32> + // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %[[MUL]], %[[MUL]] : tensor<8x16xf32> + // CHECK-NEXT: return %[[ADD_1]] : tensor<8x16xf32> + %0 = mhlo.add %arg0, %arg0 : tensor<8x16xf32> + %1 = func.call @nested_func(%0) : (tensor<8x16xf32>) -> (tensor<8x16xf32>) + %2 = mhlo.add %1, %1 : tensor<8x16xf32> + return %2 : tensor<8x16xf32> +} + +// CHECK-NOT: func @nested_func +func.func @nested_func(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"b"}]>}) + -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"b"}]>}) { + %0 = mhlo.multiply %arg0, %arg0 : tensor<8x16xf32> + return %0 : tensor<8x16xf32> +} diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir index 959cbb9a4d4f4c..977de9208630fb 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir @@ -1,4 +1,4 @@ -// RUN: sdy_opt %s -xla-sdy-round-trip-export-pipeline 2>&1 | FileCheck %s +// RUN: sdy_opt %s --split-input-file -xla-sdy-round-trip-export-pipeline 2>&1 | FileCheck %s sdy.mesh @mesh_0 = <["axis_0"=2, "axis_1"=4, "axis_2"=4]> sdy.mesh @mesh_1 = <["axis_0"=16]> @@ -65,20 +65,26 @@ func.func @func_result_sharding_returning_func_arg( return %arg0 : tensor<8x16xf32> } -// CHECK-LABEL: func @func_result_sharding_returning_op_value( -func.func @func_result_sharding_returning_op_value( - // CHECK: %arg0: tensor<8x16xf32>) - // CHECK-SAME: -> (tensor<8x16xf32> {mhlo.sharding = "{devices=[8,4]<=[32]}"}, tensor<8x16xf32> {mhlo.sharding = "{devices=[1,4,8]<=[8,4]T(1,0) last_tile_dim_replicate}"}, tensor<8x16xf32> {mhlo.sharding = "{devices=[8,4]<=[32]}"}) { - %arg0: tensor<8x16xf32>) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", ?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x"}, {"y"}p1]>}) { +// CHECK-LABEL: func @func_result_sharding_returning_op_value(%arg0: tensor<8x16xf32>) +func.func @func_result_sharding_returning_op_value(%arg0: tensor<8x16xf32>) + // CHECK-SAME: -> (tensor<8x16xf32> {mhlo.sharding = "{devices=[8,4]<=[32]}"}, + // CHECK-SAME: tensor<8x16xf32> {mhlo.sharding = "{devices=[1,4,8]<=[8,4]T(1,0) last_tile_dim_replicate}"}, + // CHECK-SAME: tensor<8x16xf32> {mhlo.sharding = "{devices=[8,4]<=[32]}"}, + // CHECK-SAME: tensor<8x16xf32> {mhlo.sharding = "{replicated}"}) { + -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", ?}, {"y"}p4]>}, + tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{?}, {"y"}p4]>}, + tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x"}, {"y"}p1]>}, + tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{}, {}]>}) { // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %arg0, %arg0 : tensor<8x16xf32> // CHECK-NEXT: %[[TEST_ONLY:.*]]:2 = mhlo.custom_call @sdy_testonly(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22, \\\22y\\\22}, {}]>, <@mesh_2, [{\\\22y\\\22, \\\22x\\\22}, {}]>]>"}, mhlo.sharding = - // CHECK-NEXT: %[[ADD_RESULT_SHARDING:.*]] = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%[[ADD]]) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22, ?}, {\\\22y\\\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: %[[ADD_RESULT_SHARDING_0:.*]] = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%[[ADD]]) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22, ?}, {\\\22y\\\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> // CHECK-NEXT: %[[TEST_ONLY_RES_SHARDING_0:.*]] = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%[[TEST_ONLY]]#0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{?}, {\\\22y\\\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> // CHECK-NEXT: %[[TEST_ONLY_RES_SHARDING_1:.*]] = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%[[TEST_ONLY]]#1) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22}, {\\\22y\\\22}p1]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> - // CHECK-NEXT: return %[[ADD_RESULT_SHARDING]], %[[TEST_ONLY_RES_SHARDING_0]], %[[TEST_ONLY_RES_SHARDING_1]] : tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32> + // CHECK-NEXT: %[[ADD_RESULT_SHARDING_1:.*]] = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%[[ADD]]) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{}, {}]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK-NEXT: return %[[ADD_RESULT_SHARDING_0]], %[[TEST_ONLY_RES_SHARDING_0]], %[[TEST_ONLY_RES_SHARDING_1]], %[[ADD_RESULT_SHARDING_1]] %0 = mhlo.add %arg0, %arg0 : tensor<8x16xf32> %1:2 = mhlo.custom_call @sdy_testonly(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"x","y"}, {}]>, <@mesh_2, [{"y","x"}, {}]>]>} : (tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) - return %0, %1#0, %1#1 : tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32> + return %0, %1#0, %1#1, %0 : tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32> } // CHECK-LABEL: func @sharding_constraint @@ -89,6 +95,14 @@ func.func @sharding_constraint(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { return %0 : tensor<8x8xf32> } +// CHECK-LABEL: func @export_sharding_group +// CHECK-SAME: %arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { +func.func @export_sharding_group(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { + // CHECK: mhlo.custom_call @local_xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding_group_id = "12 : i64"}} + sdy.sharding_group %arg0 group_id = 12: tensor<8x8xf32> + return %arg0 : tensor<8x8xf32> +} + // CHECK-LABEL: func @constant func.func @constant() -> tensor { // CHECK-NEXT: %[[CONST:.*]] = mhlo.constant dense<0> @@ -96,3 +110,60 @@ func.func @constant() -> tensor { %0 = sdy.constant dense<0> : tensor return %0 : tensor } + +// CHECK-LABEL: func @inlined_mesh( +// CHECK-SAME: %arg0: tensor<32xi32> +// CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding, [{\\\22a\\\22}]>"}, +// CHECK-SAME: mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"}) +// CHECK-SAME: -> (tensor<32xi32> {mhlo.sharding = "{maximal device=5}"}) { +func.func @inlined_mesh( + %arg0: tensor<32xi32> {sdy.sharding = #sdy.sharding, [{"a"}]>} +) -> (tensor<32xi32> {sdy.sharding = #sdy.sharding, [{}]>}) { + // CHECK-NEXT: %[[SHARDING:.*]] = mhlo.custom_call @Sharding(%arg0) + // CHECK-SAME: mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[, [{\\\22c\\\22}]>]>"}, mhlo.sharding = "{devices=[4]<=[4]}"} + // CHECK-NEXT: %[[RESULT_SHARDING:.*]] = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%[[SHARDING]]) + // CHECK-SAME: mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[, [{}]>]>"} + // CHECK-NEXT: return %[[RESULT_SHARDING]] + %0 = sdy.sharding_constraint %arg0 , [{"c"}]> : tensor<32xi32> + return %0 : tensor<32xi32> +} + +// CHECK-LABEL: func @op_sharding_rule +func.func @op_sharding_rule(%arg0: tensor<8x2xf32>, %arg1: tensor<8x2xf32>) -> tensor<8x2xf64> { + // CHECK: stablehlo.custom_call @foo(%arg0, %arg1) {mhlo.frontend_attributes = {xla.sdy.sharding_rule = "#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=8, j=2}>"}} + %0 = stablehlo.custom_call @foo(%arg0, %arg1) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=8, j=2}>} : (tensor<8x2xf32>, tensor<8x2xf32>) -> tensor<8x2xf64> + return %0 : tensor<8x2xf64> +} + +// CHECK-LABEL: func @sharding_and_op_sharding_rule +func.func @sharding_and_op_sharding_rule(%arg0: tensor<8x2xf32>, %arg1: tensor<8x2xf32>) -> tensor<8x2xf64> { + // CHECK: stablehlo.custom_call @foo(%arg0, %arg1) {mhlo.frontend_attributes = + // CHECK-SAME: {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\\\22x\\\22}, {}]>]>" + // CHECK-SAME: xla.sdy.sharding_rule = "#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=8, j=2}>"} + %0 = stablehlo.custom_call @foo(%arg0, %arg1) + {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=8, j=2}>, + sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"x"}, {}]>]>} + : (tensor<8x2xf32>, tensor<8x2xf32>) -> tensor<8x2xf64> + return %0 : tensor<8x2xf64> +} + +// ----- + +// CHECK-NOT: xla.sdy.meshes + +// CHECK-LABEL: func @non_sdy_module( +// CHECK-SAME: %arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4,8]<=[8,4]T(1,0)}"}, +// CHECK-SAME: %arg1: tensor<8x8xf32> {mhlo.sharding = "{devices=[1,2,16]<=[32] last_tile_dim_replicate}"}, +// CHECK-SAME: %arg2: tensor<8x16xf32> {mhlo.sharding = "{devices=[4,4,2]<=[2,16]T(1,0) last_tile_dim_replicate}"}) +// CHECK-SAME: -> (tensor<8x16xf32> {mhlo.sharding = "{devices=[8,4]<=[32]}"}) { +func.func @non_sdy_module(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4,8]<=[8,4]T(1,0)}"}, + %arg1: tensor<8x8xf32> {mhlo.sharding = "{devices=[1,2,16]<=[32] last_tile_dim_replicate}"}, + %arg2: tensor<8x16xf32> {mhlo.sharding = "{devices=[4,4,2]<=[2,16]T(1,0) last_tile_dim_replicate}"}) + -> (tensor<8x16xf32> {mhlo.sharding = "{devices=[8,4]<=[32]}"}) { + // CHECK-NEXT: mhlo.add %arg0, %arg1 {mhlo.sharding = "{devices=[4,8]<=[8,4]T(1,0)}"} + // CHECK-NOT: xla.sdy.sharding + // CHECK-NOT: xla.sdy.sharding_rule + %0 = mhlo.add %arg0, %arg1 {mhlo.sharding = "{devices=[4,8]<=[8,4]T(1,0)}"} : tensor<8x8xf32> + %1 = "mhlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + return %1 : tensor<8x16xf32> +} diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir index 82edbe3f82c7a7..9c8e27a4871429 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir @@ -1,32 +1,55 @@ -// RUN: sdy_opt %s -xla-sdy-round-trip-import-pipeline 2>&1 | FileCheck %s +// RUN: sdy_opt %s --split-input-file -xla-sdy-round-trip-import-pipeline 2>&1 | FileCheck %s // CHECK-LABEL: module @multiple_func_result_shardings -module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\\\22a\\\22=8, \\\22b\\\22=8, \\\22c\\\22=8]>}"}} { +module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {xla.sdy.meshes = + "{mesh = #sdy.mesh<[\\\22a\\\22=8, \\\22b\\\22=8, \\\22c\\\22=8]>, mesh2 = #sdy.mesh<[\\\22a\\\22=1, \\\22b\\\22=4, \\\22c\\\22=1]>}"}} { // CHECK: sdy.mesh @mesh = <["a"=8, "b"=8, "c"=8]> + // CHECK: sdy.mesh @mesh2 = <["b"=4]> + // CHECK-LABEL: func @func_results_with_sharding - // CHECK-SAME: %arg0: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p2]>}, - // CHECK-SAME: %arg1: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p1]>}, - // CHECK-SAME: %arg2: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}p0]>} - // CHECK-SAME: ) -> ( - // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p0]>}, - // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p2]>}, - // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p1]>}, - // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}p0]>}, - // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p3]>}) { - // CHECK-NEXT: return %arg0, %arg1, %arg0, %arg1, %arg2 : tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32> + // CHECK-SAME: %arg0: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p2]>}, + // CHECK-SAME: %arg1: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p1]>}, + // CHECK-SAME: %arg2: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}p0]>} + // CHECK-SAME: ) -> ( + // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p0]>}, + // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p2]>}, + // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p1]>}, + // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}p0]>}, + // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p2]>}, + // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p3]>}) { + // CHECK-NEXT: return %arg0, %arg1, %arg0, %arg1, %arg1, %arg2 // CHECK-NEXT: } func.func @func_results_with_sharding( %arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22b\\\22}p2]>"}}, %arg1: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22a\\\22}p1]>"}}, %arg2: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22c\\\22}p0]>"}} - ) -> (tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>) { + ) -> (tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>) { %0 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> %1 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> %2 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p1]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> %3 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22c\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> %4 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%arg2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> - return %0, %1, %2, %3, %4 : tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32> + return %0, %1, %2, %3, %1, %4 : tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32> + } + + // This might happen due to inlined funcs that originally had result shardings + // CHECK-LABEL: func @func_result_shardings_used_by_other_ops( + // CHECK-SAME: %arg0: tensor<32xi32>, %arg1: tensor<32xi32> + // CHECK-SAME: ) -> ( + // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p2]>}, + // CHECK-SAME: tensor<32xi32>) { + // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %arg0, %arg1 + // CHECK-NEXT: return %arg0, %[[ADD]] + // CHECK-NEXT: } + func.func @func_result_shardings_used_by_other_ops( + %arg0: tensor<32xi32>, %arg1: tensor<32xi32> + ) -> (tensor<32xi32>, tensor<32xi32>) { + %0 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %1 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %2 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %3 = mhlo.add %1, %2 : tensor<32xi32> + return %1, %3 : tensor<32xi32>, tensor<32xi32> } // CHECK-LABEL: func @while_with_free_variables @@ -109,4 +132,122 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x %3 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p4]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> return %3 : tensor<32xi32> } + + // CHECK-LABEL: func @inlined_mesh( + // CHECK-SAME: %arg0: tensor<32xi32> {sdy.sharding = #sdy.sharding, [{"a"}]>}) + // CHECK-SAME: -> (tensor<32xi32> {sdy.sharding = #sdy.sharding, [{}]>}) { + func.func @inlined_mesh( + %arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding, [{\\\22a\\\22}]>"}} + ) -> tensor<32xi32> { + // CHECK-NEXT: %[[SHARDING:.*]] = sdy.sharding_constraint %arg0 , [{"c"}]> : tensor<32xi32> + // CHECK-NEXT: return %[[SHARDING]] + %0 = mhlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[, [{\\\22c\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %1 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[, [{}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + return %1 : tensor<32xi32> + } + + // CHECK-LABEL: func @shardings_with_size_one_axes + // CHECK-SAME: %arg0: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh2, [{"b"}p1]>}, + // CHECK-SAME: %arg1: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh2, [{}], replicated={"b"}>}, + // CHECK-SAME: %arg2: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh2, [{"b", ?}p0]>} + // CHECK-SAME: ) -> ( + // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh2, [{}]>}, + // CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh2, [{"b"}]>}) { + func.func @shardings_with_size_one_axes( + %arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh2, [{\\\22b\\\22}p1], replicated={\\\22c\\\22}>"}}, + %arg1: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh2, [{\\\22a\\\22}p2], replicated={\\\22b\\\22}>"}}, + %arg2: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh2, [{\\\22c\\\22, \\\22b\\\22, ?}p0]>"}} + ) -> (tensor<32xi32>, tensor<32xi32>) { + // CHECK-NEXT: %[[SC1:.*]] = sdy.sharding_constraint %arg0 <@mesh2, [{"b", ?}]> + // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[SC1]], %[[SC1]] + // CHECK-NOT: sdy.sharding + // CHECK-NEXT: %[[SC2:.*]] = sdy.sharding_constraint %arg1 <@mesh2, [{}]> + // CHECK-NEXT: return %[[ADD]], %[[SC2]] + // CHECK-NEXT: } + %0 = mhlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22a\\\22, \\\22b\\\22, ?}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %1 = mhlo.add %0, %0 : tensor<32xi32> + %2 = mhlo.custom_call @Sharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22c\\\22, \\\22a\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %3 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22a\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + %4 = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh2, [{\\\22b\\\22}]>]>"}} : (tensor<32xi32>) -> tensor<32xi32> + return %3, %4 : tensor<32xi32>, tensor<32xi32> + } + + // CHECK-LABEL: func @manual_computation_with_size_one_axes + func.func @manual_computation_with_size_one_axes(%arg0: tensor<16x32xf32>, %arg1: tensor<16x32xf32>) -> (tensor<16x32xf32>) { + // CHECK-NOT: call @local_xla.sdy.manual_computation_body + // CHECK: %[[MAN_COMP:.*]] = sdy.manual_computation(%arg0, %arg1) + // CHECK-SAME{LITERAL}: in_shardings=[<@mesh2, [{}, {"b"}]>, <@mesh2, [{}, {"b"}]>] + // CHECK-SAME{LITERAL}: out_shardings=[<@mesh2, [{}, {"b"}]>] + // CHECK-SAME{LITERAL}: manual_axes={"b"} + // CHECK-SAME: (%arg2: tensor<16x8xf32>, %arg3: tensor<16x8xf32>) { + // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %arg2, %arg3 + // CHECK-NEXT: sdy.return %[[ADD]] + // CHECK-NEXT: } : (tensor<16x32xf32>, tensor<16x32xf32>) -> tensor<16x32xf32> + // CHECK-NEXT: return %[[MAN_COMP]] + %0:2 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0, %arg1) : (tensor<16x32xf32>, tensor<16x32xf32>) -> (tensor<16x8xf32>, tensor<16x8xf32>) + %1 = call @local_xla.sdy.manual_computation_body(%0#0, %0#1) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh2, [{\\\22a\\\22}, {\\\22b\\\22}]>, <@mesh2, [{}, {\\\22b\\\22}], replicated={\\\22a\\\22}>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh2, [{}, {\\\22b\\\22, \\\22a\\\22}]>]>"}} : (tensor<16x8xf32>, tensor<16x8xf32>) -> tensor<16x8xf32> + %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<16x8xf32>) -> tensor<16x32xf32> + return %2 : tensor<16x32xf32> + } + + // CHECK-NOT: func @local_xla.sdy.manual_computation_body( + func.func @local_xla.sdy.manual_computation_body(%arg0: tensor<16x8xf32>, %arg1: tensor<16x8xf32>) -> tensor<16x8xf32> { + %0 = mhlo.add %arg0, %arg1 : tensor<16x8xf32> + return %0 : tensor<16x8xf32> + } +} + +// ----- + +// CHECK-NOT: sdy.mesh @mesh + +module @no_meshes_module attributes {mhlo.frontend_attributes = {xla.sdy.meshes = "{}"}} { + // CHECK-LABEL: func @no_sharding_rule + func.func @no_sharding_rule(%arg0: tensor<8x2xf32>, %arg1: tensor<8x2xf32>) -> tensor<8x2xf64> { + // CHECK-NEXT: stablehlo.custom_call @foo(%arg0, %arg1) : (tensor<8x2xf32>, tensor<8x2xf32>) -> tensor<8x2xf64> + %0 = stablehlo.custom_call @foo(%arg0, %arg1) : (tensor<8x2xf32>, tensor<8x2xf32>) -> tensor<8x2xf64> + return %0 : tensor<8x2xf64> + } + + // CHECK-LABEL: func @op_sharding_rule + func.func @op_sharding_rule(%arg0: tensor<8x2xf32>, %arg1: tensor<8x2xf32>) -> tensor<8x2xf64> { + // CHECK-NEXT: stablehlo.custom_call @foo(%arg0, %arg1) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=8, j=2}>} + %0 = stablehlo.custom_call @foo(%arg0, %arg1) + {mhlo.frontend_attributes = {xla.sdy.sharding_rule = "#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=8, j=2}>"}} : (tensor<8x2xf32>, tensor<8x2xf32>) -> tensor<8x2xf64> + return %0 : tensor<8x2xf64> + } +} + +// ----- + +// CHECK-NOT: sdy.mesh @mesh + +module @no_meshes_attr_module { + // CHECK-LABEL: func @op_sharding_rule + func.func @op_sharding_rule(%arg0: tensor<8x2xf32>, %arg1: tensor<8x2xf32>) -> tensor<8x2xf64> { + // CHECK-NEXT: stablehlo.custom_call @foo(%arg0, %arg1) {sdy.sharding_rule = #sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=8, j=2}>} + %0 = stablehlo.custom_call @foo(%arg0, %arg1) + {mhlo.frontend_attributes = {xla.sdy.sharding_rule = "#sdy.op_sharding_rule<([i, j], [i, j])->([i, j]) {i=8, j=2}>"}} : (tensor<8x2xf32>, tensor<8x2xf32>) -> tensor<8x2xf64> + return %0 : tensor<8x2xf64> + } +} + +// ----- + +// CHECK-LABEL: func @import_sharding_group +// CHECK-SAME: %arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { +func.func @import_sharding_group(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { + // CHECK sdy.sharding_group %arg0 group_id = 21: tensor<8x8xf32> + mhlo.custom_call @local_xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding_group_id = "21 : i64"}} : (tensor<8x8xf32>) -> () + return %arg0 : tensor<8x8xf32> +} + +// ----- + +// CHECK-LABEL: func @import_sharding_group_with_unused_result +// CHECK-SAME: %arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { +func.func @import_sharding_group_with_unused_result(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { + // CHECK sdy.sharding_group %arg0 group_id = 21: tensor<8x8xf32> + %0 = mhlo.custom_call @local_xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding_group_id = "21 : i64"}} : (tensor<8x8xf32>) -> tuple<> + return %arg0 : tensor<8x8xf32> } diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_remove_size_one_axes.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_remove_size_one_axes.mlir new file mode 100644 index 00000000000000..296dbda1e1f397 --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_remove_size_one_axes.mlir @@ -0,0 +1,111 @@ +// RUN: sdy_opt %s -xla-sdy-round-trip-remove-size-one-axes 2>&1 | FileCheck %s + +sdy.mesh @mesh1 = <["a"=1, "b"=2, "c"=1, "d"=4, "e"=1], device_ids=[0, 2, 1, 3, 4, 6, 5, 7]> +sdy.mesh @mesh2 = <["a"=4, "b"=2]> +sdy.mesh @mesh3 = <["x"=1, "y"=1]> +sdy.mesh @mesh4 = <["a"=1, "b"=2, "c"=1]> + +// CHECK: sdy.mesh @mesh1 = <["b"=2, "d"=4], device_ids=[0, 2, 1, 3, 4, 6, 5, 7]> +// CHECK: sdy.mesh @mesh2 = <["a"=4, "b"=2]> +// CHECK: sdy.mesh @mesh3 = <[]> +// CHECK: sdy.mesh @mesh4 = <["b"=2]> + +// CHECK-LABEL: func @func_and_op_shardings +// CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"b"}, {?}]>}, +// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"d", ?}, {}], replicated={"b"}>}, +// CHECK-SAME: %arg2: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh2, [{"a"}, {"b"}]>} +// CHECK-SAME: ) -> ( +// CHECK-SAME: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{}, {?}]>}, +// CHECK-SAME: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"b"}, {}]>}) { +func.func @func_and_op_shardings( + %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"a", "b"}, {"c", ?}]>}, + %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"d", "e", ?}, {}], replicated={"b", "c"}>}, + %arg2: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh2, [{"a"}, {"b"}]>} +) -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"e"}, {"c", ?}]>}, + tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"a", "b", "c"}, {}], replicated={"e"}>}) { + // CHECK-NEXT: %[[ADD1:.*]] = mhlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{"d", ?}, {?}]>]>} + // CHECK-NEXT: %[[ADD2:.*]] = mhlo.add %arg2, %arg2 + // CHECK-NOT: sdy.sharding + // CHECK-NEXT: %[[ADD3:.*]] = mhlo.add %[[ADD2]], %[[ADD2]] {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{}, {}], replicated={"d"}>]>} + // CHECK-NEXT: return %[[ADD1]], %[[ADD3]] + // CHECK-NEXT: } + %0 = mhlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{"d", ?}, {"e", ?}]>]>} : tensor<8x8xf32> + %1 = mhlo.add %arg2, %arg2 : tensor<8x8xf32> + %2 = mhlo.add %1, %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{"c"}, {}], replicated={"d"}>]>} : tensor<8x8xf32> + return %0, %2 : tensor<8x8xf32>, tensor<8x8xf32> +} + +// CHECK-LABEL: func @inlined_mesh +// CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding, [{"b"}, {?}]>} +// CHECK-SAME: ) -> ( +// CHECK-SAME: tensor<8x8xf32> {sdy.sharding = #sdy.sharding, [{"a", "b"}, {}]>}) { +func.func @inlined_mesh( + %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding, [{"a", "b"}, {"c", ?}]>} +) -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding, [{"a", "b"}, {}]>}) { + // CHECK-NEXT: mhlo.add %arg0, %arg0 {sdy.sharding = #sdy.sharding_per_value<[, [{?}, {?}]>]>} + %0 = mhlo.add %arg0, %arg0 {sdy.sharding = #sdy.sharding_per_value<[, [{"a", ?}, {"b", ?}]>]>} : tensor<8x8xf32> + return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: func @shardings_with_priorities +// CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"b"}p0, {?}p3], replicated={"d"}>}, +// CHECK-SAME: %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh2, [{"a", ?}p2, {}]>} +// CHECK-SAME: ) -> ( +// CHECK-SAME: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh3, [{}, {?}p2]>}) { +func.func @shardings_with_priorities( + %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"a", "b"}p0, {"c", ?}p3], replicated={"d", "e"}>}, + %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh2, [{"a", ?}p2, {}]>} +) -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh3, [{"x"}p1, {"y", ?}p2]>}) { + // CHECK-NEXT: %[[ADD1:.*]] = mhlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{"d", ?}p1, {?}]>]>} + // CHECK-NEXT: %[[ADD2:.*]] = mhlo.add %[[ADD1]], %[[ADD1]] + // CHECK-NOT: sdy.sharding + // CHECK-NEXT: return %[[ADD2]] + // CHECK-NEXT: } + %0 = mhlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{"d", ?}p1, {"e", ?}]>]>} : tensor<8x8xf32> + %1 = mhlo.add %0, %0 : tensor<8x8xf32> + return %1 : tensor<8x8xf32> +} + +// CHECK-LABEL: func @manual_computation +func.func @manual_computation(%arg0: tensor<8x16xf32>, %arg1: tensor<16x32xf32>) -> tensor<8x32xf32> { + // CHECK-NEXT: %[[MAN_COMP:.*]] = sdy.manual_computation(%arg0, %arg1) + // CHECK-SAME{LITERAL}: in_shardings=[<@mesh1, [{"d"}, {"b"}]>, <@mesh1, [{"b"}, {}], replicated={"d"}>] + // CHECK-SAME{LITERAL}: out_shardings=[<@mesh1, [{"d"}, {}], replicated={"b"}>] + // CHECK-SAME{LITERAL}: manual_axes={"b", "d"} + // CHECK-SAME: (%arg2: tensor<2x8xf32>, %arg3: tensor<8x32xf32>) { + // CHECK-NEXT: stablehlo.add %arg2, %arg2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{?}, {?}]>]>} + // CHECK: } : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32> + // CHECK-NEXT: return %[[MAN_COMP]] + %0 = sdy.manual_computation(%arg0, %arg1) + in_shardings=[<@mesh1, [{"d", "a"}, {"b"}]>, <@mesh1, [{"b"}, {"c", "a"}], replicated={"d"}>] + out_shardings=[<@mesh1, [{"d"}, {}], replicated={"b", "c"}>] + manual_axes={"a", "b", "c", "d"} (%arg2: tensor<2x8xf32>, %arg3: tensor<8x32xf32>) { + %1 = stablehlo.add %arg2, %arg2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, [{?}, {"e", ?}]>]>} : tensor<2x8xf32> + %2 = stablehlo.dot %1, %arg3 : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> + %3 = "stablehlo.all_reduce"(%2) ({ + ^bb0(%arg4: tensor, %arg5: tensor): + %4 = stablehlo.add %arg4, %arg5 : tensor + stablehlo.return %4 : tensor + }) { + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor<2x32xf32>) -> tensor<2x32xf32> + sdy.return %3 : tensor<2x32xf32> + } : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32> + return %0 : tensor<8x32xf32> +} + +// CHECK-LABEL: func @manual_computation_inlined_mesh +func.func @manual_computation_inlined_mesh(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> { + // CHECK-NEXT: %[[MAN_COMP:.*]] = sdy.manual_computation(%arg0, %arg1) + // CHECK-SAME{LITERAL}: in_shardings=[<@mesh4, [{"b"}, {}]>, , [{"b"}, {}]>] + // CHECK-SAME{LITERAL}: out_shardings=[, [{"b"}, {}]>] + // CHECK-SAME{LITERAL}: manual_axes={"b"} + %0 = sdy.manual_computation(%arg0, %arg1) + in_shardings=[<@mesh4, [{"b", "a"}, {}]>, , [{"b"}, {}], replicated={"a"}>] + out_shardings=[, [{"a", "b"}, {}]>] + manual_axes={"a", "b"} (%arg2: tensor<4x16xf32>, %arg3: tensor<4x16xf32>) { + %1 = stablehlo.add %arg2, %arg2 : tensor<4x16xf32> + sdy.return %1 : tensor<4x16xf32> + } : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + return %0 : tensor<8x16xf32> +} diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_export.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_export.mlir index 2376056f8f0735..ffb749919eca2d 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_export.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_export.mlir @@ -5,14 +5,15 @@ sdy.mesh @mesh_1 = <["a"=2, "b"=2, "c"=2, "d"=2]> // CHECK-LABEL: func @single_manual_comp func.func @single_manual_comp(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"a", ?}, {"b", ?}]>}, %arg1: tensor<16x32xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"b", ?}, {?}]>}) -> (tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"a"}, {}]>}) { - // CHECK: %[[SHMAP:.*]] = stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0, %arg1) - // CHECK-SAME: {called_computations = [@local_xla.sdy.manual_computation_body], - // CHECK-SAME: mhlo.frontend_attributes = { + // CHECK-NEXT: %[[FULL_TO_SHARD:.*]]:2 = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0, %arg1) : (tensor<8x16xf32>, tensor<16x32xf32>) -> (tensor<2x8xf32>, tensor<8x32xf32>) + // CHECK-NEXT: %[[SHMAP:.*]] = call @local_xla.sdy.manual_computation_body(%[[FULL_TO_SHARD]]#0, %[[FULL_TO_SHARD]]#1) + // CHECK-SAME: {mhlo.frontend_attributes = { // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>, <@mesh_0, [{\\\22b\\\22}, {}], replicated={\\\22a\\\22}>]>", // CHECK-SAME: xla.sdy.manual_axes = "#sdy", // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} - // CHECK-SAME: : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32> - // CHECK-NEXT: return %[[SHMAP]] : tensor<8x32xf32> + // CHECK-SAME: : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%[[SHMAP]]) : (tensor<2x32xf32>) -> tensor<8x32xf32> + // CHECK-NEXT: return %[[SHARD_TO_FULL]] : tensor<8x32xf32> %0 = sdy.manual_computation(%arg0, %arg1) in_shardings=[<@mesh_0, [{"a"}, {"b"}]>, <@mesh_0, [{"b"}, {}], replicated={"a"}>] out_shardings=[<@mesh_0, [{"a"}, {}], replicated={"b"}>] manual_axes={"a", "b"} (%arg2: tensor<2x8xf32>, %arg3: tensor<8x32xf32>) { %1 = stablehlo.add %arg2, %arg2 : tensor<2x8xf32> %2 = stablehlo.dot %1, %arg3 : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> @@ -31,20 +32,23 @@ func.func @single_manual_comp(%arg0: tensor<8x16xf32> {sdy.sharding = #sdy.shard // CHECK-LABEL: func @manual_comp_using_another func.func @manual_comp_using_another(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"a"}, {}]>}) -> (tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {"b"}]>}) { - // CHECK: %[[SHMAP_0:.*]] = stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0) - // CHECK-SAME: {called_computations = [@local_xla.sdy.manual_computation_body_0], - // CHECK-SAME: mhlo.frontend_attributes = { + // CHECK-NEXT: %[[FULL_TO_SHARD_0:.*]] = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<8x8xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %[[SHMAP_0:.*]] = call @local_xla.sdy.manual_computation_body_0(%[[FULL_TO_SHARD_0]]) + // CHECK-SAME: {mhlo.frontend_attributes = { // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>", // CHECK-SAME: xla.sdy.manual_axes = "#sdy", // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>"}} - // CHECK-SAME: : (tensor<8x8xf32>) -> tensor<8x8xf32> - // CHECK-NEXT: %[[SHMAP_1:.*]] = stablehlo.custom_call @local_xla.sdy.ManualComputation(%[[SHMAP_0]]) - // CHECK-SAME: {called_computations = [@local_xla.sdy.manual_computation_body_1], - // CHECK-SAME: mhlo.frontend_attributes = { + // CHECK-SAME: : (tensor<2x8xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL_0:.*]] = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%[[SHMAP_0]]) : (tensor<2x8xf32>) -> tensor<8x8xf32> + // CHECK-NEXT: %[[FULL_TO_SHARD_1:.*]] = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%[[SHARD_TO_FULL_0]]) : (tensor<8x8xf32>) -> tensor<8x4xf32> + // CHECK-NEXT: %[[SHMAP_1:.*]] = call @local_xla.sdy.manual_computation_body_1(%[[FULL_TO_SHARD_1]]) + // CHECK-SAME: {mhlo.frontend_attributes = { // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>", // CHECK-SAME: xla.sdy.manual_axes = "#sdy", // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>"}} - // CHECK-SAME: : (tensor<8x8xf32>) -> tensor<8x8xf32> + // CHECK-SAME: : (tensor<8x4xf32>) -> tensor<8x4xf32 + // CHECK-NEXT: %[[SHARD_TO_FULL_1:.*]] = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%[[SHMAP_1]]) : (tensor<8x4xf32>) -> tensor<8x8xf32> + // CHECK-NEXT: return %[[SHARD_TO_FULL_1]] : tensor<8x8xf32> %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_0, [{"a"}, {}]>] out_shardings=[<@mesh_0, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<2x8xf32>) { sdy.return %arg1 : tensor<2x8xf32> } : (tensor<8x8xf32>) -> tensor<8x8xf32> @@ -57,14 +61,15 @@ func.func @manual_comp_using_another(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy // CHECK-LABEL: func @nested_shmaps func.func @nested_shmaps(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a"}, {"b"}]>}) -> (tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a", ?}, {?}]>}) { - // CHECK: %[[SHMAP:.*]] = stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0) - // CHECK-SAME: {called_computations = [@local_xla.sdy.manual_computation_body_3], - // CHECK-SAME: mhlo.frontend_attributes = { + // CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %[[SHMAP:.*]] = call @local_xla.sdy.manual_computation_body_3(%[[FULL_TO_SHARD]]) + // CHECK-SAME: {mhlo.frontend_attributes = { // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", // CHECK-SAME: xla.sdy.manual_axes = "#sdy", // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} - // CHECK-SAME: : (tensor<4x8xf32>) -> tensor<4x8xf32 - // CHECK-NEXT: return %[[SHMAP]] : tensor<4x8xf32> + // CHECK-SAME: : (tensor<2x8xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%[[SHMAP]]) : (tensor<2x8xf32>) -> tensor<4x8xf32> + // CHECK-NEXT: return %[[SHARD_TO_FULL]] : tensor<4x8xf32> %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_1, [{"a"}, {}]>] out_shardings=[<@mesh_1, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<2x8xf32>) { %1 = sdy.manual_computation(%arg1) in_shardings=[<@mesh_1, [{}, {"b"}]>] out_shardings=[<@mesh_1, [{}, {"b"}]>] manual_axes={"b"} (%arg2: tensor<2x4xf32>) { %2 = stablehlo.multiply %arg2, %arg2 : tensor<2x4xf32> @@ -77,13 +82,15 @@ func.func @nested_shmaps(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@m // CHECK-LABEL: func @nested_shmaps_extra_op func.func @nested_shmaps_extra_op(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a"}, {"b"}]>}) -> (tensor<4x8xf32> {sdy.sharding = #sdy.sharding<@mesh_1, [{"a", ?}, {?}]>}) { - // CHECK: %[[SHMAP:.*]] = stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0) - // CHECK-SAME: {called_computations = [@local_xla.sdy.manual_computation_body_5], - // CHECK-SAME: mhlo.frontend_attributes = { + // CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %[[SHMAP:.*]] = call @local_xla.sdy.manual_computation_body_5(%[[FULL_TO_SHARD]]) + // CHECK-SAME: {mhlo.frontend_attributes = { // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", // CHECK-SAME: xla.sdy.manual_axes = "#sdy", - // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} : (tensor<4x8xf32>) -> tensor<4x8xf32> - // CHECK-NEXT: return %[[SHMAP]] : tensor<4x8xf32> + // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} + // CHECK-SAME: (tensor<2x8xf32>) -> tensor<2x8xf32> + // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%[[SHMAP]]) : (tensor<2x8xf32>) -> tensor<4x8xf32> + // CHECK-NEXT: return %[[SHARD_TO_FULL]] : tensor<4x8xf32> %0 = sdy.manual_computation(%arg0) in_shardings=[<@mesh_1, [{"a"}, {}]>] out_shardings=[<@mesh_1, [{"a"}, {}]>] manual_axes={"a"} (%arg1: tensor<2x8xf32>) { %1 = sdy.manual_computation(%arg1) in_shardings=[<@mesh_1, [{}, {"b"}]>] out_shardings=[<@mesh_1, [{}, {"b"}]>] manual_axes={"b"} (%arg2: tensor<2x4xf32>) { %2 = stablehlo.multiply %arg2, %arg2 : tensor<2x4xf32> @@ -95,6 +102,40 @@ func.func @nested_shmaps_extra_op(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sh return %0 : tensor<4x8xf32> } +// CHECK-LABEL: func @manual_computation_no_inputs +func.func @manual_computation_no_inputs() -> tensor<4xi64> { + // CHECK-NEXT: %[[SHMAP:.*]] = call @local_xla.sdy.manual_computation_body_6() + // CHECK-SAME: {mhlo.frontend_attributes = { + // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[]>", + // CHECK-SAME: xla.sdy.manual_axes = "#sdy", + // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22b\\\22}]>]>"}} + // CHECK-SAME: () -> tensor<2xi64> + // CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%[[SHMAP]]) : (tensor<2xi64>) -> tensor<4xi64> + // CHECK-NEXT: return %[[SHARD_TO_FULL]] : tensor<4xi64> + %0 = sdy.manual_computation() in_shardings=[] out_shardings=[<@mesh_0, [{"b"}]>] manual_axes={"b"} () { + %1 = stablehlo.constant dense<[2, 3]> : tensor<2xi64> + sdy.return %1 : tensor<2xi64> + } : () -> tensor<4xi64> + func.return %0 : tensor<4xi64> +} + +// CHECK-LABEL: func @manual_computation_no_outputs +func.func @manual_computation_no_outputs(%arg0: tensor<4xi64>) { + // CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4xi64>) -> tensor<2xi64> + // CHECK-NEXT: call @local_xla.sdy.manual_computation_body_7(%[[FULL_TO_SHARD]]) + // CHECK-SAME: {mhlo.frontend_attributes = { + // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22b\\\22}]>]>", + // CHECK-SAME: xla.sdy.manual_axes = "#sdy", + // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[]>"}} + // CHECK-SAME: : (tensor<2xi64>) -> () + // CHECK-NEXT: return + sdy.manual_computation(%arg0) in_shardings=[<@mesh_0, [{"b"}]>] out_shardings=[] manual_axes={"b"} (%arg1: tensor<2xi64>) { + stablehlo.custom_call @sdy_testonly(%arg1) : (tensor<2xi64>) -> () + sdy.return + } : (tensor<4xi64>) -> () + func.return +} + // CHECK-LABEL: func @local_xla.sdy.manual_computation_body(%arg0: tensor<2x8xf32>, %arg1: tensor<8x32xf32>) -> tensor<2x32xf32> // CHECK-NEXT: stablehlo.add // CHECK-NEXT: stablehlo.dot @@ -110,24 +151,34 @@ func.func @nested_shmaps_extra_op(%arg0: tensor<4x8xf32> {sdy.sharding = #sdy.sh // CHECK-NEXT: stablehlo.multiply %arg0, %arg0 : tensor<2x4xf32> // CHECK-LABEL: func @local_xla.sdy.manual_computation_body_3(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32 -// CHECK-NEXT: stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0) -// CHECK-SAME: {called_computations = [@local_xla.sdy.manual_computation_body_2], -// CHECK-SAME: mhlo.frontend_attributes = { +// CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32> +// CHECK-NEXT: %[[SHMAP:.*]] = call @local_xla.sdy.manual_computation_body_2(%[[FULL_TO_SHARD]]) +// CHECK-SAME: {mhlo.frontend_attributes = { // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", // CHECK-SAME: xla.sdy.manual_axes = "#sdy", // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} -// CHECK-SAME: : (tensor<2x8xf32>) -> tensor<2x8xf32> +// CHECK-SAME: : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK-NEXT: stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%[[SHMAP]]) : (tensor<2x4xf32>) -> tensor<2x8xf32> // CHECK-LABEL: func @local_xla.sdy.manual_computation_body_4(%arg0: tensor<2x4xf32>) -> tensor<2x4xf32> // CHECK-NEXT: stablehlo.multiply %arg0, %arg0 : tensor<2x4xf32> // CHECK-LABEL: func @local_xla.sdy.manual_computation_body_5(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32> -// CHECK-NEXT: %[[SHMAP:.*]] = stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0) -// CHECK-SAME: {called_computations = [@local_xla.sdy.manual_computation_body_4], -// CHECK-SAME: mhlo.frontend_attributes = { +// CHECK-NEXT: %[[FULL_TO_SHARD:.*]] = stablehlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32 +// CHECK-NEXT: %[[SHMAP:.*]] = call @local_xla.sdy.manual_computation_body_4(%[[FULL_TO_SHARD]]) +// CHECK-SAME: {mhlo.frontend_attributes = { // CHECK-SAME: xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", // CHECK-SAME: xla.sdy.manual_axes = "#sdy", // CHECK-SAME: xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} -// CHECK-SAME: : (tensor<2x8xf32>) -> tensor<2x8xf32> -// CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[SHMAP]], %[[SHMAP]] : tensor<2x8xf32> +// CHECK-SAME: : (tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK-NEXT: %[[SHARD_TO_FULL:.*]] = stablehlo.custom_call @local_xla.sdy.LocalToGlobalShape(%[[SHMAP]]) : (tensor<2x4xf32>) -> tensor<2x8xf32> +// CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[SHARD_TO_FULL]], %[[SHARD_TO_FULL]] : tensor<2x8xf32> // CHECK-NEXT: return %[[ADD]] : tensor<2x8xf32> + +// CHECK-LABEL: func @local_xla.sdy.manual_computation_body_6() -> tensor<2xi64> { +// CHECK-NEXT: %[[C:.*]] = stablehlo.constant dense<[2, 3]> : tensor<2xi64> +// CHECK-NEXT: return %[[C]] : tensor<2xi64> + +// CHECK-LABEL: func @local_xla.sdy.manual_computation_body_7(%arg0: tensor<2xi64>) { +// CHECK-NEXT: stablehlo.custom_call @sdy_testonly(%arg0) : (tensor<2xi64>) -> () +// CHECK-NEXT: return diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir index 1f49e3858cd899..0f55988e0f123b 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import.mlir @@ -4,7 +4,7 @@ sdy.mesh @mesh_0 = <["a"=4, "b"=2]> sdy.mesh @mesh_1 = <["a"=2, "b"=2, "c"=2, "d"=2]> func.func @single_manual_comp(%arg0: tensor<8x16xf32>, %arg1: tensor<16x32xf32>) -> (tensor<8x32xf32>) { - // CHECK-NOT: xla.sdy.ManualComputation + // CHECK-NOT: call @local_xla.sdy.manual_computation_body // CHECK: %[[MAN_COMP:.*]] = sdy.manual_computation(%arg0, %arg1) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{"a"}, {"b"}]>, <@mesh_0, [{"b"}, {}], replicated={"a"}>] // CHECK-SAME{LITERAL}: out_shardings=[<@mesh_0, [{"a"}, {}], replicated={"b"}>] @@ -20,12 +20,14 @@ func.func @single_manual_comp(%arg0: tensor<8x16xf32>, %arg1: tensor<16x32xf32>) // CHECK-NEXT: sdy.return %[[REDUCE]] : tensor<2x32xf32> // CHECK-NEXT: } : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32> // CHECK-NEXT: return %[[MAN_COMP]] : tensor<8x32xf32> - %0 = stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0, %arg1) {called_computations = [@local_xla.sdy.manual_computation_body], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>, <@mesh_0, [{\\\22b\\\22}, {}], replicated={\\\22a\\\22}>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} : (tensor<8x16xf32>, tensor<16x32xf32>) -> tensor<8x32xf32> - return %0 : tensor<8x32xf32> + %0:2 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0, %arg1) : (tensor<8x16xf32>, tensor<16x32xf32>) -> (tensor<2x8xf32>, tensor<8x32xf32>) + %1 = call @local_xla.sdy.manual_computation_body(%0#0, %0#1) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>, <@mesh_0, [{\\\22b\\\22}, {}], replicated={\\\22a\\\22}>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} : (tensor<2x8xf32>, tensor<8x32xf32>) -> tensor<2x32xf32> + %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x32xf32>) -> tensor<8x32xf32> + return %2 : tensor<8x32xf32> } func.func @manual_comp_using_another(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { - // CHECK-NOT: xla.sdy.ManualComputation + // CHECK-NOT: call @local_xla.sdy.manual_computation_body_0 // CHECK: %[[MAN_COMP_0:.*]] = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{"a"}, {}]>] // CHECK-SAME{LITERAL}: out_shardings=[<@mesh_0, [{"a"}, {}]>] @@ -33,6 +35,7 @@ func.func @manual_comp_using_another(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> // CHECK-SAME: (%arg1: tensor<2x8xf32>) { // CHECK-NEXT: sdy.return %arg1 : tensor<2x8xf32> // CHECK-NEXT: } : (tensor<8x8xf32>) -> tensor<8x8xf32> + // CHECK-NOT: call @local_xla.sdy.manual_computation_body_1 // CHECK-NEXT: %[[MAN_COMP_1:.*]] = sdy.manual_computation(%[[MAN_COMP_0]]) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{}, {"b"}]>] // CHECK-SAME{LITERAL}: out_shardings=[<@mesh_0, [{}, {"b"}]>] @@ -41,15 +44,21 @@ func.func @manual_comp_using_another(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> // CHECK-NEXT: sdy.return %arg1 : tensor<8x4xf32> // CHECK-NEXT: } : (tensor<8x8xf32>) -> tensor<8x8xf32> // CHECK-NEXT: return %[[MAN_COMP_1]] : tensor<8x8xf32> - %0 = stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0) {called_computations = [@local_xla.sdy.manual_computation_body_0], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>"}} : (tensor<8x8xf32>) -> tensor<8x8xf32> - %1 = stablehlo.custom_call @local_xla.sdy.ManualComputation(%0) {called_computations = [@local_xla.sdy.manual_computation_body_1], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>"}} : (tensor<8x8xf32>) -> tensor<8x8xf32> - return %1 : tensor<8x8xf32> + %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<8x8xf32>) -> tensor<2x8xf32> + %1 = call @local_xla.sdy.manual_computation_body_0(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32> + %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<8x8xf32> + %3 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%2) : (tensor<8x8xf32>) -> tensor<8x4xf32> + %4 = call @local_xla.sdy.manual_computation_body_1(%3) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{}, {\\\22b\\\22}]>]>"}} : (tensor<8x4xf32>) -> tensor<8x4xf32> + %5 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%4) : (tensor<8x4xf32>) -> tensor<8x8xf32> + return %5 : tensor<8x8xf32> } // CHECK-NOT: func @local_xla.sdy.manual_computation_body_3( func.func @local_xla.sdy.manual_computation_body_3(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32> { - %0 = stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0) {called_computations = [@local_xla.sdy.manual_computation_body_2], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32> - return %0 : tensor<2x8xf32> + %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32> + %1 = call @local_xla.sdy.manual_computation_body_2(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} : (tensor<2x4xf32>) -> tensor<2x4xf32> + %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x4xf32>) -> tensor<2x8xf32> + return %2 : tensor<2x8xf32> } // CHECK-NOT: func @local_xla.sdy.manual_computation_body_2( @@ -59,7 +68,7 @@ func.func @local_xla.sdy.manual_computation_body_2(%arg0: tensor<2x4xf32>) -> te } func.func @nested_shmaps(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { - // CHECK-NOT: xla.sdy.ManualComputation + // CHECK-NOT: call @local_xla.sdy.manual_computation_body_3 // CHECK: %[[MAN_COMP_0:.*]] = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_1, [{"a"}, {}]>] // CHECK-SAME{LITERAL}: out_shardings=[<@mesh_1, [{"a"}, {}]>] @@ -76,12 +85,14 @@ func.func @nested_shmaps(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { // CHECK-NEXT: sdy.return %[[MAN_COMP_1]] : tensor<2x8xf32> // CHECK-NEXT: } : (tensor<4x8xf32>) -> tensor<4x8xf32> // CHECK-NEXT: return %[[MAN_COMP_0]] : tensor<4x8xf32> - %0 = stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0) {called_computations = [@local_xla.sdy.manual_computation_body_3], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} : (tensor<4x8xf32>) -> tensor<4x8xf32> - return %0 : tensor<4x8xf32> + %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32> + %1 = call @local_xla.sdy.manual_computation_body_3(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32> + %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<4x8xf32> + return %2 : tensor<4x8xf32> } func.func @nested_shmaps_extra_op(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { - // CHECK-NOT: xla.sdy.ManualComputation + // CHECK-NOT: call @local_xla.sdy.manual_computation_body_5 // CHECK: %[[MAN_COMP_0:.*]] = sdy.manual_computation(%arg0) // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_1, [{"a"}, {}]>] // CHECK-SAME{LITERAL}: out_shardings=[<@mesh_1, [{"a"}, {}]>] @@ -99,8 +110,42 @@ func.func @nested_shmaps_extra_op(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { // CHECK-NEXT: sdy.return %[[ADD]] : tensor<2x8xf32> // CHECK-NEXT: } : (tensor<4x8xf32>) -> tensor<4x8xf32> // CHECK-NEXT: return %[[MAN_COMP_0]] : tensor<4x8xf32> - %0 = stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0) {called_computations = [@local_xla.sdy.manual_computation_body_5], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} : (tensor<4x8xf32>) -> tensor<4x8xf32> - return %0 : tensor<4x8xf32> + %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4x8xf32>) -> tensor<2x8xf32> + %1 = call @local_xla.sdy.manual_computation_body_5(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{\\\22a\\\22}, {}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32> + %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> tensor<4x8xf32> + return %2 : tensor<4x8xf32> +} + +func.func @manual_computation_no_inputs() -> tensor<4xi64> { + // CHECK-NOT: call @local_xla.sdy.manual_computation_body_6 + // CHECK: %[[SHMAP:.*]] = sdy.manual_computation() + // CHECK-SAME{LITERAL}: in_shardings=[] + // CHECK-SAME{LITERAL}: out_shardings=[<@mesh_0, [{"b"}]>] + // CHECK-SAME{LITERAL}: manual_axes={"b"} + // CHECK-SAME{LITERAL}: () { + // CHECK-NEXT: %[[C:.*]] = stablehlo.constant dense<[2, 3]> : tensor<2xi64> + // CHECK-NEXT: sdy.return %[[C]] : tensor<2xi64> + // CHECK-NEXT: } : () -> tensor<4xi64> + // CHECK-NEXT: return %[[SHMAP]] : tensor<4xi64> + %0 = call @local_xla.sdy.manual_computation_body_6() {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22b\\\22}]>]>"}} : () -> tensor<2xi64> + %1 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%0) : (tensor<2xi64>) -> tensor<4xi64> + return %1 : tensor<4xi64> +} + +func.func @manual_computation_no_outputs(%arg0: tensor<4xi64>) { + // CHECK-NOT: call @local_xla.sdy.manual_computation_body_7 + // CHECK: sdy.manual_computation(%arg0) + // CHECK-SAME{LITERAL}: in_shardings=[<@mesh_0, [{"b"}]>] + // CHECK-SAME{LITERAL}: out_shardings=[] + // CHECK-SAME{LITERAL}: manual_axes={"b"} + // CHECK-SAME{LITERAL}: (%arg1: tensor<2xi64>) { + // CHECK-NEXT: mhlo.custom_call @sdy_testonly(%arg1) : (tensor<2xi64>) -> () + // CHECK-NEXT: sdy.return + // CHECK-NEXT: } : (tensor<4xi64>) -> () + // CHECK-NEXT: return + %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<4xi64>) -> tensor<2xi64> + call @local_xla.sdy.manual_computation_body_7(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[]>"}} : (tensor<2xi64>) -> () + return } // CHECK-NOT: func @local_xla.sdy.manual_computation_body( @@ -133,7 +178,21 @@ func.func @local_xla.sdy.manual_computation_body_4(%arg0: tensor<2x4xf32>) -> te // CHECK-NOT: func @local_xla.sdy.manual_computation_body_5( func.func @local_xla.sdy.manual_computation_body_5(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32> { - %0 = stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0) {called_computations = [@local_xla.sdy.manual_computation_body_4], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} : (tensor<2x8xf32>) -> tensor<2x8xf32> - %1 = stablehlo.add %0, %0 : tensor<2x8xf32> - return %1 : tensor<2x8xf32> + %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<2x8xf32>) -> tensor<2x4xf32> + %1 = call @local_xla.sdy.manual_computation_body_4(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_1, [{}, {\\\22b\\\22}]>]>"}} : (tensor<2x4xf32>) -> tensor<2x4xf32> + %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x4xf32>) -> tensor<2x8xf32> + %3 = stablehlo.add %2, %2 : tensor<2x8xf32> + return %3 : tensor<2x8xf32> +} + +// CHECK-NOT: func @local_xla.sdy.manual_computation_body_6( +func.func @local_xla.sdy.manual_computation_body_6() -> tensor<2xi64> { + %c = stablehlo.constant dense<[2, 3]> : tensor<2xi64> + return %c : tensor<2xi64> +} + +// CHECK-NOT: func @local_xla.sdy.manual_computation_body_7( +func.func @local_xla.sdy.manual_computation_body_7(%arg0: tensor<2xi64>) { + mhlo.custom_call @sdy_testonly(%arg0) : (tensor<2xi64>) -> () + return } diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir index d1516899e450d6..9f2a3a5740924d 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_shard_map_import_failure.mlir @@ -3,13 +3,17 @@ sdy.mesh @mesh = <["a"=2]> func.func @using_same_body_func(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { - %0 = stablehlo.custom_call @local_xla.sdy.ManualComputation(%arg0) {called_computations = [@local_xla.sdy.manual_computation_body_0], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}, {}]>]>"}} : (tensor<8x8xf32>) -> tensor<8x8xf32> - // expected-error @+2 {{'stablehlo.custom_call' op expected a unique FuncOp per @local_xla.sdy.ManualComputation custom call}} - // expected-error @+1 {{failed to legalize operation 'stablehlo.custom_call'}} - %1 = stablehlo.custom_call @local_xla.sdy.ManualComputation(%0) {called_computations = [@local_xla.sdy.manual_computation_body_0], mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}, {}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}, {}]>]>"}} : (tensor<8x8xf32>) -> tensor<8x8xf32> - return %1 : tensor<8x8xf32> + %0 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%arg0) : (tensor<8x8xf32>) -> (tensor<2x8xf32>) + %1 = call @local_xla.sdy.manual_computation_body(%0) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} : (tensor<2x8xf32>) -> (tensor<2x8xf32>) + %2 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%1) : (tensor<2x8xf32>) -> (tensor<8x8xf32>) + %3 = mhlo.custom_call @local_xla.sdy.GlobalToLocalShape(%2) : (tensor<8x8xf32>) -> (tensor<2x8xf32>) + // expected-error @+2 {{'func.call' op expected a unique FuncOp per @local_xla.sdy.manual_computation_body call}} + // expected-error @+1 {{failed to legalize operation 'func.call'}} + %4 = call @local_xla.sdy.manual_computation_body(%3) {mhlo.frontend_attributes = {xla.sdy.in_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {\\\22b\\\22}]>]>", xla.sdy.manual_axes = "#sdy", xla.sdy.out_shardings = "#sdy.sharding_per_value<[<@mesh_0, [{\\\22a\\\22}, {}], replicated={\\\22b\\\22}>]>"}} : (tensor<2x8xf32>) -> (tensor<2x8xf32>) + %5 = mhlo.custom_call @local_xla.sdy.LocalToGlobalShape(%4) : (tensor<2x8xf32>) -> (tensor<8x8xf32>) + return %5 : tensor<8x8xf32> } -func.func @local_xla.sdy.manual_computation_body_0(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32> { +func.func @local_xla.sdy.manual_computation_body(%arg0: tensor<2x8xf32>) -> tensor<2x8xf32> { return %arg0 : tensor<2x8xf32> } diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_sharding_group_import_failure.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_sharding_group_import_failure.mlir new file mode 100644 index 00000000000000..f30c0150ce0264 --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_sharding_group_import_failure.mlir @@ -0,0 +1,18 @@ +// RUN: sdy_opt %s -xla-sdy-import-sdy-custom-calls -split-input-file -verify-diagnostics + +func.func @sharding_group_import_failure_if_no_group_id(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> { + // expected-error @+2 {{failed to legalize operation 'mhlo.custom_call' that was explicitly marked illegal}} + // expected-error @+1 {{expected CustomCallOp with a sharding group id.}} + mhlo.custom_call @local_xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {}} : (tensor<16x16xf32>) -> () + return %arg0 : tensor<16x16xf32> +} + +// ----- + +func.func @sharding_group_import_with_used_result(%arg0: tensor<8x8xf32>) -> tuple> { + // expected-error @+2 {{failed to legalize operation 'mhlo.custom_call' that was explicitly marked illegal}} + // expected-error @+1 {{xla.sdy.ShardingGroup CustomCallOp should have no uses.}} + %0 = mhlo.custom_call @local_xla.sdy.ShardingGroup(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding_group_id = "21 : i64"}} : (tensor<8x8xf32>) -> tuple<> + %1 = "mhlo.tuple"(%0) : (tuple<>) -> tuple> + return %1 : tuple> +} diff --git a/third_party/xla/xla/service/spmd/shardy/utils.cc b/third_party/xla/xla/service/spmd/shardy/utils.cc index 2245b582733004..604ed05b306ec3 100644 --- a/third_party/xla/xla/service/spmd/shardy/utils.cc +++ b/third_party/xla/xla/service/spmd/shardy/utils.cc @@ -152,6 +152,14 @@ void removeFrontendAttribute(FuncOp funcOp, StringRef attributeName, [&]() { funcOp.removeArgAttr(argNum, kFrontendAttributesAttr); }); } +bool hasFrontendAttr(mlir::Operation* op, mlir::StringRef key) { + return hasKey(getFrontendAttrs(op), key); +} + +bool hasKey(mlir::DictionaryAttr dictAttr, mlir::StringRef key) { + return dictAttr && dictAttr.contains(key); +} + void loadAllRequiredDialects(mlir::MLIRContext* context) { mlir::DialectRegistry registry; mlir::func::registerAllExtensions(registry); diff --git a/third_party/xla/xla/service/spmd/shardy/utils.h b/third_party/xla/xla/service/spmd/shardy/utils.h index 80194b3ca04c40..552de063ce2e4a 100644 --- a/third_party/xla/xla/service/spmd/shardy/utils.h +++ b/third_party/xla/xla/service/spmd/shardy/utils.h @@ -17,6 +17,8 @@ limitations under the License. #define XLA_SERVICE_SPMD_SHARDY_UTILS_H_ #include +#include +#include #include "absl/log/check.h" #include "absl/strings/escaping.h" @@ -59,27 +61,41 @@ void removeFrontendAttribute(mlir::Operation* op, void removeFrontendAttribute(mlir::func::FuncOp funcOp, mlir::StringRef attributeName, int64_t argNum); -void loadAllRequiredDialects(mlir::MLIRContext* context); +// Checks if "frontend_attributes" `DictionaryAttr` from `op` contains `key`. +bool hasFrontendAttr(mlir::Operation* op, mlir::StringRef key); -// Parses `stringAttr` to an attribute of type `AttrTy`. -// -// NOTE: assumes `stringAttr` is of type `StringAttr`. -template -AttrTy parseStringAttr(mlir::Attribute stringAttr) { - std::string value; - std::string error; - CHECK(absl::CUnescape(mlir::cast(stringAttr).getValue(), - &value, &error)) - << error; - return mlir::cast( - mlir::parseAttribute(value, stringAttr.getContext())); -} +// Checks if `dictAttr` exists and contains `key`. +bool hasKey(mlir::DictionaryAttr dictAttr, mlir::StringRef key); + +void loadAllRequiredDialects(mlir::MLIRContext* context); // Parses `attrName` from `dictAttr` to an attribute of type `AttrTy`. template AttrTy parseStringAttr(mlir::DictionaryAttr dictAttr, llvm::StringRef attrName) { - return parseStringAttr(dictAttr.get(attrName)); + if (mlir::Attribute stringAttr = dictAttr.get(attrName)) { + std::string value; + std::string error; + CHECK(absl::CUnescape(mlir::cast(stringAttr).getValue(), + &value, &error)) + << error; + return mlir::cast( + mlir::parseAttribute(value, stringAttr.getContext())); + } + return nullptr; +} + +// Checks if `op`'s "frontend_attributes" `DictionaryAttr` contains `attrName` +// and parses it to an attribute of type `AttrTy`. If it doesn't exist, then +// returns std::nullopt. +template +std::optional tryGetFrontendAttr(mlir::Operation* op, + mlir::StringRef attrName) { + mlir::DictionaryAttr dictAttr = getFrontendAttrs(op); + if (hasKey(dictAttr, attrName)) { + return parseStringAttr(dictAttr, attrName); + } + return std::nullopt; } } // namespace sdy diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner.cc b/third_party/xla/xla/service/spmd/spmd_partitioner.cc index ceb9553f9ee095..37a9ad5328633f 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner.cc @@ -49,6 +49,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/tile_assignment.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/transforms/simplifiers/flatten_call_graph.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/hlo/utils/hlo_sharding_util.h" #include "xla/layout_util.h" @@ -58,14 +61,11 @@ limitations under the License. #include "xla/service/call_graph.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/computation_layout.h" -#include "xla/service/flatten_call_graph.h" #include "xla/service/hlo_cse.h" -#include "xla/service/hlo_dce.h" #include "xla/service/hlo_module_config.h" #include "xla/service/shape_inference.h" #include "xla/service/spmd/custom_call_handler.h" #include "xla/service/spmd/spmd_partitioner_util.h" -#include "xla/service/tuple_simplifier.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" @@ -246,135 +246,179 @@ HloInstruction* SpmdBuilder::AddInstruction( HloInstruction* hlo = HloComputation::Builder::AddInstruction(std::move(instruction)); if (visiting_hlo_) { - hlo->set_metadata(visiting_hlo_->metadata()); + std::shared_ptr prev_sharding = hlo->sharding_ptr(); + visiting_hlo_->SetupDerivedInstruction(hlo); + if (prev_sharding != nullptr) { + hlo->set_sharding(*prev_sharding); + } else { + hlo->clear_sharding(); + } instructions_[visiting_hlo_].push_back(hlo); } - if (hlo->opcode() == HloOpcode::kBroadcast) { - for (int64_t i = 0; i < hlo->shape().rank(); ++i) { - if (!absl::c_linear_search(hlo->dimensions(), i)) { - broadcast_dims_[hlo].insert(i); + SetBroadcastDimsForAddedHlo(*hlo); + return hlo; +} + +void SpmdBuilder::SetBroadcastDimsForAddedHlo(const HloInstruction& hlo) { + if (hlo.opcode() == HloOpcode::kBroadcast) { + for (int64_t i = 0; i < hlo.shape().rank(); ++i) { + if (!absl::c_linear_search(hlo.dimensions(), i)) { + broadcast_dims_[&hlo].insert(i); } } } - if (hlo->IsElementwise() && hlo->operand_count() > 0 && + if (hlo.IsElementwise() && hlo.operand_count() > 0 && // Copy can have a tuple result. - hlo->shape().IsArray()) { - absl::flat_hash_set broadcast_dims; - for (int64_t i = 0; i < hlo->shape().rank(); ++i) { - broadcast_dims.insert(i); + hlo.shape().IsArray()) { + SetBroadcastDimsForElementwise(hlo); + } + if (hlo.opcode() == HloOpcode::kTranspose) { + SetBroadcastDimsForTranspose(hlo); + } + if (hlo.opcode() == HloOpcode::kReshape && + Product(hlo.shape().dimensions()) > 0) { + SetBroadcastDimsForReshape(hlo); + } + if (hlo.opcode() == HloOpcode::kSlice || + hlo.opcode() == HloOpcode::kDynamicSlice) { + SetBroadcastDimsForSlice(hlo); + } + if (hlo.opcode() == HloOpcode::kPad) { + SetBroadcastDimsForPad(hlo); + } +} + +void SpmdBuilder::SetBroadcastDimsForReshape(const HloInstruction& hlo) { + CHECK(hlo.opcode() == HloOpcode::kReshape); + + auto it = broadcast_dims_.find(hlo.operand(0)); + if (it == broadcast_dims_.end()) { + return; + } + std::vector iota_dims(hlo.shape().rank()); + absl::c_iota(iota_dims, 0); + absl::flat_hash_set reshape_broadcast_dims(iota_dims.begin(), + iota_dims.end()); + + absl::Span operand_dims = hlo.operand(0)->shape().dimensions(); + absl::Span hlo_dims = hlo.shape().dimensions(); + std::vector before_dim_size_stack(operand_dims.rbegin(), + operand_dims.rend()); + std::vector after_dim_size_stack(hlo_dims.rbegin(), hlo_dims.rend()); + + auto erase_reshape_broadcast_dims = [&reshape_broadcast_dims](int64_t from, + int64_t to) { + for (int64_t i = from; i < to; ++i) { + reshape_broadcast_dims.erase(i); } - for (int64_t i = 0; i < hlo->operand_count(); ++i) { - auto it = broadcast_dims_.find(hlo->operand(i)); - if (it == broadcast_dims_.end()) { - broadcast_dims.clear(); - break; - } - for (int64_t i = 0; i < hlo->shape().rank(); ++i) { - if (!it->second.contains(i)) { - broadcast_dims.erase(i); - } - } + }; + + while (!before_dim_size_stack.empty() && !after_dim_size_stack.empty()) { + int64_t before_size = before_dim_size_stack.back(); + int64_t after_size = after_dim_size_stack.back(); + int64_t current_before_dim = + hlo.operand(0)->shape().rank() - before_dim_size_stack.size(); + int64_t current_after_dim = + hlo.shape().rank() - after_dim_size_stack.size(); + before_dim_size_stack.pop_back(); + after_dim_size_stack.pop_back(); + if (!it->second.contains(current_before_dim)) { + reshape_broadcast_dims.erase(current_after_dim); + } + if (before_size == after_size) { + continue; } - if (!broadcast_dims.empty()) { - broadcast_dims_[hlo] = std::move(broadcast_dims); + if (before_size % after_size == 0) { + // Split dim. + before_dim_size_stack.push_back(before_size / after_size); + } else if (after_size % before_size == 0) { + // Merge dim. + after_dim_size_stack.push_back(after_size / before_size); + } else { + // Other cases, mark all remaining dims as non-broadcast. + erase_reshape_broadcast_dims(current_after_dim, hlo.shape().rank()); + break; } } - if (hlo->opcode() == HloOpcode::kTranspose) { - auto it = broadcast_dims_.find(hlo->operand(0)); - if (it != broadcast_dims_.end()) { - absl::flat_hash_set xpose_broadcast_dims; - std::vector reverse_map(hlo->shape().rank()); - for (int64_t i = 0; i < reverse_map.size(); ++i) { - reverse_map[hlo->dimensions(i)] = i; - } - for (int64_t dim : it->second) { - xpose_broadcast_dims.insert(reverse_map[dim]); - } - broadcast_dims_[hlo] = std::move(xpose_broadcast_dims); - } + + bool has_broadcast_dims = !reshape_broadcast_dims.empty() && + before_dim_size_stack.empty() && + after_dim_size_stack.empty(); + if (has_broadcast_dims) { + broadcast_dims_[&hlo] = std::move(reshape_broadcast_dims); } - if (hlo->opcode() == HloOpcode::kReshape && - Product(hlo->shape().dimensions()) > 0) { - auto it = broadcast_dims_.find(hlo->operand(0)); - if (it != broadcast_dims_.end()) { - absl::flat_hash_set reshape_broadcast_dims; - for (int64_t i = 0; i < hlo->shape().rank(); ++i) { - reshape_broadcast_dims.insert(i); - } - std::vector before_dim_size_stack; - std::vector after_dim_size_stack; - const int64_t operand0_rank = hlo->operand(0)->shape().rank(); - const int64_t hlo_shape_rank = hlo->shape().rank(); - before_dim_size_stack.reserve(operand0_rank); - after_dim_size_stack.reserve(hlo_shape_rank); - for (int64_t i = operand0_rank - 1; i >= 0; --i) { - before_dim_size_stack.push_back(hlo->operand(0)->shape().dimensions(i)); - } - for (int64_t i = hlo_shape_rank - 1; i >= 0; --i) { - after_dim_size_stack.push_back(hlo->shape().dimensions(i)); - } - while (!before_dim_size_stack.empty() && !after_dim_size_stack.empty()) { - int64_t before_size = before_dim_size_stack.back(); - int64_t after_size = after_dim_size_stack.back(); - int64_t current_before_dim = - hlo->operand(0)->shape().rank() - before_dim_size_stack.size(); - int64_t current_after_dim = - hlo->shape().rank() - after_dim_size_stack.size(); - before_dim_size_stack.pop_back(); - after_dim_size_stack.pop_back(); - if (!it->second.contains(current_before_dim)) { - reshape_broadcast_dims.erase(current_after_dim); - } - if (before_size == after_size) { - continue; - } - if (before_size % after_size == 0) { - // Split dim. - before_dim_size_stack.push_back(before_size / after_size); - } else if (after_size % before_size == 0) { - // Merge dim. - after_dim_size_stack.push_back(after_size / before_size); - } else { - // Other cases, mark all remaining dims as non-broadcast. - for (int64_t i = current_after_dim; i < hlo->shape().rank(); ++i) { - reshape_broadcast_dims.erase(i); - } - break; - } - } - if (!before_dim_size_stack.empty() || !after_dim_size_stack.empty()) { - reshape_broadcast_dims.clear(); - } - if (!reshape_broadcast_dims.empty()) { - broadcast_dims_[hlo] = std::move(reshape_broadcast_dims); - } - } +} + +void SpmdBuilder::SetBroadcastDimsForTranspose(const HloInstruction& hlo) { + CHECK(hlo.opcode() == HloOpcode::kTranspose); + auto it = broadcast_dims_.find(hlo.operand(0)); + if (it == broadcast_dims_.end()) { + return; + } + absl::flat_hash_set xpose_broadcast_dims; + std::vector reverse_map(hlo.shape().rank()); + for (int64_t i = 0; i < reverse_map.size(); ++i) { + reverse_map[hlo.dimensions(i)] = i; + } + for (int64_t dim : it->second) { + xpose_broadcast_dims.insert(reverse_map[dim]); + } + broadcast_dims_[&hlo] = std::move(xpose_broadcast_dims); +} + +void SpmdBuilder::SetBroadcastDimsForPad(const HloInstruction& hlo) { + CHECK(hlo.opcode() == HloOpcode::kPad); + auto it = broadcast_dims_.find(hlo.operand(0)); + if (it == broadcast_dims_.end()) { + return; } - if (hlo->opcode() == HloOpcode::kSlice || - hlo->opcode() == HloOpcode::kDynamicSlice) { - auto it = broadcast_dims_.find(hlo->operand(0)); - if (it != broadcast_dims_.end()) { - auto dims = it->second; - broadcast_dims_[hlo] = std::move(dims); + absl::flat_hash_set pad_broadcast_dims; + for (int64_t i = 0; i < hlo.shape().rank(); ++i) { + const auto& dim = hlo.padding_config().dimensions(i); + if (dim.edge_padding_low() == 0 && dim.edge_padding_high() == 0 && + dim.interior_padding() == 0 && it->second.contains(i)) { + pad_broadcast_dims.insert(i); } } - if (hlo->opcode() == HloOpcode::kPad) { - auto it = broadcast_dims_.find(hlo->operand(0)); - if (it != broadcast_dims_.end()) { - absl::flat_hash_set pad_broadcast_dims; - for (int64_t i = 0; i < hlo->shape().rank(); ++i) { - const auto& dim = hlo->padding_config().dimensions(i); - if (dim.edge_padding_low() == 0 && dim.edge_padding_high() == 0 && - dim.interior_padding() == 0 && it->second.contains(i)) { - pad_broadcast_dims.insert(i); - } - } - if (!pad_broadcast_dims.empty()) { - broadcast_dims_[hlo] = std::move(pad_broadcast_dims); + if (!pad_broadcast_dims.empty()) { + broadcast_dims_[&hlo] = std::move(pad_broadcast_dims); + } +} + +void SpmdBuilder::SetBroadcastDimsForSlice(const HloInstruction& hlo) { + CHECK(hlo.opcode() == HloOpcode::kSlice || + hlo.opcode() == HloOpcode::kDynamicSlice); + auto it = broadcast_dims_.find(hlo.operand(0)); + if (it != broadcast_dims_.end()) { + auto dims = it->second; + broadcast_dims_[&hlo] = std::move(dims); + } +} + +void SpmdBuilder::SetBroadcastDimsForElementwise(const HloInstruction& hlo) { + CHECK(hlo.IsElementwise()); + if (hlo.operand_count() == 0 || hlo.shape().IsTuple()) { + return; + } + absl::flat_hash_set broadcast_dims; + for (int64_t i = 0; i < hlo.shape().rank(); ++i) { + broadcast_dims.insert(i); + } + for (int64_t i = 0; i < hlo.operand_count(); ++i) { + auto it = broadcast_dims_.find(hlo.operand(i)); + if (it == broadcast_dims_.end()) { + broadcast_dims.clear(); + break; + } + for (int64_t i = 0; i < hlo.shape().rank(); ++i) { + if (!it->second.contains(i)) { + broadcast_dims.erase(i); } } } - return hlo; + if (!broadcast_dims.empty()) { + broadcast_dims_[&hlo] = std::move(broadcast_dims); + } } PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target, @@ -1143,8 +1187,6 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window, std::vector(halo_exchange_base_shape.rank(), 1))); } - std::vector left_halo_size_functions(base_shape_.rank()); - std::vector right_halo_size_functions(base_shape_.rank()); // TODO(yuanzx): We are concatenating on each sharded dimension one at time, // and in the second dimension (and beyond) we create halos by slicing the // concat in the previous dimension, which is not optimal. We should generate @@ -1162,18 +1204,18 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window, // partition. MultiplyAddDivideOffsetCalculation shard_limit_of_previous_on_padded( input_shard_size, explicit_left_padding[dim], 1); - left_halo_size_functions[dim] = + OffsetCalculation left_halo_size_functions = shard_limit_of_previous_on_padded - start_on_padded_calculations[dim]; // Right halo. MultiplyAddDivideOffsetCalculation shard_start_of_next_on_padded( input_shard_size, input_shard_size + explicit_left_padding[dim], 1); - right_halo_size_functions[dim] = + OffsetCalculation right_halo_size_functions = limit_on_padded_calculations[dim] - shard_start_of_next_on_padded; auto resharded = ExchangeHaloAndGetValidData( - visiting_hlo, halo_exchange_base_shape, left_halo_size_functions[dim], - right_halo_size_functions[dim], explicit_left_padding[dim], + visiting_hlo, halo_exchange_base_shape, left_halo_size_functions, + right_halo_size_functions, explicit_left_padding[dim], padded_shape.dimensions(dim), shard_shape.dimensions(dim), dim, *halo_exchange_target, offsets_on_padded_shape[dim], pad_value, partition_ordinals[dim], state_.collective_ops_creator, @@ -5349,8 +5391,9 @@ absl::Status SpmdPartitioner::PreprocessHlos( auto* merged_dim = merged_padding->mutable_dimensions(i); merged_dim->set_edge_padding_low(dim.edge_padding_low() - hlo->slice_starts(i)); - merged_dim->set_edge_padding_high(hlo->slice_limits(i) - - operand->shape().dimensions(i)); + merged_dim->set_edge_padding_high( + hlo->slice_limits(i) - + (operand->shape().dimensions(i) - dim.edge_padding_high())); } if (merged_padding.has_value() && may_have_multi_halo_exchanges) { // Rewrite to a single Pad. diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner.h b/third_party/xla/xla/service/spmd/spmd_partitioner.h index 26ae71f44d21f4..2466d99a03a054 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner.h +++ b/third_party/xla/xla/service/spmd/spmd_partitioner.h @@ -54,6 +54,7 @@ namespace spmd { // Enum representing the partitioning methods for gather and scatter. enum class PartitioningMethod { + kExplicitBatch, kIndexParallel, kOperandPassthrough, kTrivialSlicedOperand, @@ -116,6 +117,11 @@ struct SpmdPartitionerOptions { // Partitioning method to prioritize for scatter operations. PartitioningMethod scatter_partition_method = PartitioningMethod::kIndexParallel; + + // The minimum size to enable windowed einsum in total bytes. + // This combines sizes in bytes of both operands. + // When it's set, it will override threshold_for_windowed_einsum_mib. + std::optional total_bytes_windowed_einsum_threshold = std::nullopt; }; // Class to wrap the computation builder to capture information during SPMD @@ -153,6 +159,18 @@ class SpmdBuilder : public HloComputation::Builder { } private: + // Sets the broadcast dims for the newly added/created hlo. + void SetBroadcastDimsForAddedHlo(const HloInstruction& hlo); + + void SetBroadcastDimsForReshape(const HloInstruction& hlo); + + void SetBroadcastDimsForTranspose(const HloInstruction& hlo); + + void SetBroadcastDimsForPad(const HloInstruction& hlo); + + void SetBroadcastDimsForSlice(const HloInstruction& hlo); + + void SetBroadcastDimsForElementwise(const HloInstruction& hlo); // Currently visiting instruction. HloInstruction* visiting_hlo_; diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc index 3ed5b7bb5f0116..e94c752a06bd94 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -42,12 +41,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/transforms/sharding_format_picker.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/hlo/utils/hlo_sharding_util.h" #include "xla/layout_util.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_verifier.h" -#include "xla/service/sharding_format_picker.h" #include "xla/service/spmd/spmd_prepare.h" #include "xla/shape.h" #include "xla/tests/hlo_test_base.h" @@ -77,7 +76,9 @@ class SpmdPartitioningTest bool bidirectional_windowed_einsum = false, int64_t threshold_for_windowed_einsum_mib = -1, PartitioningMethod gather_method = PartitioningMethod::kIndexParallel, - PartitioningMethod scatter_method = PartitioningMethod::kIndexParallel) { + PartitioningMethod scatter_method = PartitioningMethod::kIndexParallel, + std::optional total_bytes_windowed_einsum_threshold = + std::nullopt) { // Some tests (BackpropFilter convs) set this flag false to test two // different paths of the implementation. SpmdPartitionerOptions options; @@ -87,6 +88,8 @@ class SpmdPartitioningTest choose_faster_windowed_einsum; options.unroll_windowed_einsum = unroll_windowed_einsum; options.bidirectional_windowed_einsum = bidirectional_windowed_einsum; + options.total_bytes_windowed_einsum_threshold = + total_bytes_windowed_einsum_threshold; if (threshold_for_windowed_einsum_mib >= 0) { options.threshold_for_windowed_einsum_mib = threshold_for_windowed_einsum_mib; @@ -4906,6 +4909,34 @@ ENTRY entry { } } +TEST_P(SpmdPartitioningTest, WindowedEinsumNoRewriteWithTotalBytesThreshold) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + %p0 = f32[2048,2,3264]{2,1,0} parameter(0), sharding={devices=[1,1,2]0,1} + %p1 = f32[2,3264,2176]{2,1,0} parameter(1), sharding={devices=[2,1,1]0,1} + ROOT %dot.224 = f32[2048,2176]{1,0} dot(f32[2048,2,3264]{2,1,0} %p0, f32[2,3264,2176]{2,1,0} %p1), lhs_contracting_dims={1,2}, rhs_contracting_dims={0,1}, sharding={devices=[1,2]0,1} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_devices=*/2, + /*conv_halo_exchange_always_on_lhs=*/true, + /*choose_faster_windowed_einsum=*/false, + /*unroll_windowed_einsum=*/false, + /*bidirectional_windowed_einsum=*/false, + /*threshold_for_windowed_einsum_mib=*/5, + PartitioningMethod::kIndexParallel, + PartitioningMethod::kIndexParallel, + /*total_bytes_windowed_einsum_threshold=*/1 << 30)); + VLOG(1) << module->ToString(); + // Total bytes threshold overrides threshold_for_windowed_einsum_mib, + // there shouldn't be any while loop after partitioner. + HloInstruction* while_inst = FindInstruction(module.get(), HloOpcode::kWhile); + EXPECT_EQ(while_inst, nullptr); +} + TEST_P(SpmdPartitioningTest, DotPartialDeviceOrder) { absl::string_view hlo_string = R"( HloModule module @@ -7730,6 +7761,72 @@ ENTRY entry { EXPECT_THAT(root, op::CollectivePermute(gather)); } +TEST_P(SpmdPartitioningTest, GatherExplicitBatchDims) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[10,3,14,4] parameter(0), sharding={devices=[2,1,2,1]<=[2,2]T(1,0)} + %indices = s32[14,10,6,2] parameter(1), sharding={devices=[2,2,1,1]<=[4]} + ROOT %gather = f32[14,10,6,2] gather(%input, %indices), offset_dims={3}, + collapsed_slice_dims={1}, operand_batching_dims={0,2}, + start_indices_batching_dims={1,0}, start_index_map={1,3}, + index_vector_dim=3, slice_sizes={1,1,1,2}, sharding={devices=[2,2,1,1]<=[4]} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto input = AllOf(op::Shape("f32[5,3,7,4]"), op::Parameter(0)); + auto indices = AllOf(op::Shape("s32[7,5,6,2]"), op::Parameter(1)); + auto gather = AllOf(op::Shape("f32[7,5,6,2]"), op::Gather(input, indices)); + EXPECT_THAT(module->entry_computation()->root_instruction(), gather); +} + +TEST_P(SpmdPartitioningTest, GatherExplicitBatchAndOperandPassthroughDims) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[10,3,14,4] parameter(0), sharding={devices=[2,1,1,2]<=[4]} + %indices = s32[14,10,6,2] parameter(1), sharding={devices=[1,2,1,1,2]<=[4] last_tile_dim_replicate} + ROOT %gather = f32[14,10,6,4] gather(%input, %indices), offset_dims={3}, + collapsed_slice_dims={1}, operand_batching_dims={0,2}, + start_indices_batching_dims={1,0}, start_index_map={1,3}, + index_vector_dim=3, slice_sizes={1,1,1,4}, sharding={devices=[1,2,1,2]<=[4]} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto input = AllOf(op::Shape("f32[5,3,14,2]"), op::Parameter(0)); + auto indices = AllOf(op::Shape("s32[14,5,6,2]"), op::Parameter(1)); + auto gather = AllOf(op::Shape("f32[14,5,6,2]"), op::Gather(input, indices)); + EXPECT_THAT(module->entry_computation()->root_instruction(), gather); +} + +TEST_P(SpmdPartitioningTest, GatherExplicitBatchAndIndexPassthroughDims) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + %input = f32[10,3,14,4] parameter(0), sharding={devices=[1,1,2,1,2]<=[4] last_tile_dim_replicate} + %indices = s32[14,10,6,2] parameter(1), sharding={devices=[2,1,2,1]<=[4]} + ROOT %gather = f32[14,10,6,2] gather(%input, %indices), offset_dims={3}, + collapsed_slice_dims={1}, operand_batching_dims={0,2}, + start_indices_batching_dims={1,0}, start_index_map={1,3}, + index_vector_dim=3, slice_sizes={1,1,1,2}, sharding={devices=[2,1,2,1]<=[4]} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto input = AllOf(op::Shape("f32[10,3,7,4]"), op::Parameter(0)); + auto indices = AllOf(op::Shape("s32[7,10,3,2]"), op::Parameter(1)); + auto gather = AllOf(op::Shape("f32[7,10,3,2]"), op::Gather(input, indices)); + EXPECT_THAT(module->entry_computation()->root_instruction(), gather); +} + TEST_P(SpmdPartitioningTest, GatherPartitionedOnTrivialSliceDims) { absl::string_view hlo_string = R"( HloModule module @@ -8191,6 +8288,100 @@ ENTRY entry { op::Shape("f32[2,9,8]"))); } +TEST_P(SpmdPartitioningTest, ScatterExplicitBatchDims) { + absl::string_view hlo_string = R"( +HloModule module + +min (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT min = f32[] minimum(lhs, rhs) +} + +ENTRY entry { + %input = f32[10,6,14,4] parameter(0), sharding={devices=[2,1,2,1]<=[4]} + %indices = s32[14,10,6,2] parameter(1), sharding={devices=[2,2,1,1]<=[2,2]T(1,0)} + %updates = f32[14,10,6,2] parameter(2), sharding={devices=[2,2,1,1]<=[2,2]T(1,0)} + ROOT %scatter = f32[10,6,14,4] scatter(%input, %indices, %updates), + to_apply=min, update_window_dims={3}, inserted_window_dims={1}, + scatter_dims_to_operand_dims={1,3}, input_batching_dims={0,2}, + scatter_indices_batching_dims={1,0}, index_vector_dim=3, sharding={devices=[2,1,2,1]<=[4]} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto input = AllOf(op::Shape("f32[5,6,7,4]"), op::Parameter(0)); + auto indices = AllOf(op::Shape("s32[7,5,6,2]"), op::Parameter(1)); + auto updates = AllOf(op::Shape("f32[7,5,6,2]"), op::Parameter(2)); + auto scatter = + AllOf(op::Shape("f32[5,6,7,4]"), op::Scatter(input, indices, updates)); + EXPECT_THAT(module->entry_computation()->root_instruction(), scatter); +} + +TEST_P(SpmdPartitioningTest, ScatterExplicitBatchAndOperandPassthroughDims) { + absl::string_view hlo_string = R"( +HloModule module + +min (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT min = f32[] minimum(lhs, rhs) +} + +ENTRY entry { + %input = f32[10,6,14,4] parameter(0), sharding={devices=[1,1,2,2]<=[4]} + %indices = s32[14,10,6,2] parameter(1), sharding={devices=[2,1,1,1,2]<=[4] last_tile_dim_replicate} + %updates = f32[14,10,6,4] parameter(2), sharding={devices=[2,1,1,2]<=[4]} + ROOT %scatter = f32[10,6,14,4] scatter(%input, %indices, %updates), + to_apply=min, update_window_dims={3}, inserted_window_dims={1}, + scatter_dims_to_operand_dims={1,3}, input_batching_dims={0,2}, + scatter_indices_batching_dims={1,0}, index_vector_dim=3, sharding={devices=[1,1,2,2]<=[4]} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto input = AllOf(op::Shape("f32[10,6,7,2]"), op::Parameter(0)); + auto indices = AllOf(op::Shape("s32[7,10,6,2]"), op::Parameter(1)); + auto updates = AllOf(op::Shape("f32[7,10,6,2]"), op::Parameter(2)); + auto scatter = + AllOf(op::Shape("f32[10,6,7,2]"), op::Scatter(input, indices, updates)); + EXPECT_THAT(module->entry_computation()->root_instruction(), scatter); +} + +TEST_P(SpmdPartitioningTest, ScatterExplicitBatchAndIndexPassthroughDims) { + absl::string_view hlo_string = R"( +HloModule module + +min (lhs: f32[], rhs: f32[]) -> f32[] { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT min = f32[] minimum(lhs, rhs) +} + +ENTRY entry { + %input = f32[10,6,14,4] parameter(0), sharding={devices=[1,1,2,1,2]<=[4] last_tile_dim_replicate} + %indices = s32[14,10,6,2] parameter(1), sharding={devices=[2,1,2,1]<=[4]} + %updates = f32[14,10,6,2] parameter(2), sharding={devices=[2,1,2,1]<=[4]} + ROOT %scatter = f32[10,6,14,4] scatter(%input, %indices, %updates), + to_apply=min, update_window_dims={3}, inserted_window_dims={1}, + scatter_dims_to_operand_dims={1,3}, input_batching_dims={0,2}, + scatter_indices_batching_dims={1,0}, index_vector_dim=3, sharding={devices=[1,1,2,1,2]<=[4] last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto input = + AllOf(op::Shape("f32[10,6,7,4]"), op::Select(_, _, op::Parameter(0))); + auto indices = AllOf(op::Shape("s32[7,10,3,2]"), op::Parameter(1)); + auto updates = AllOf(op::Shape("f32[7,10,3,2]"), op::Parameter(2)); + auto scatter = AllOf(op::Shape("f32[10,6,7,4]"), + op::AllReduce(op::Scatter(input, indices, updates))); + EXPECT_THAT(module->entry_computation()->root_instruction(), scatter); +} + TEST_P(SpmdPartitioningTest, ScatterPartitionedOnTrivialSliceDims) { absl::string_view hlo_string = R"( HloModule module @@ -14923,6 +15114,27 @@ ENTRY offloading (param0: f32[1,256,128]) -> f32[1,256,128] { EXPECT_THAT(move_to_device, op::Shape("f32[1,256,32]")); } +TEST_P(SpmdPartitioningTest, MergedPadThenSliceWithPaddingHigh) { + absl::string_view hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[4] parameter(0), sharding={devices=[4]<=[4]} + %init = f32[] constant(2.0) + %pad = f32[8] pad(%param0, %init), padding=2_2, sharding={devices=[4]<=[4]} + ROOT %slice = f32[4] slice(%pad), slice={[4:8]}, sharding={devices=[4]<=[4]} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + const auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf(op::Parameter(0), op::Shape("f32[1]")); + EXPECT_THAT(root, AllOf(op::Select(_, op::CollectivePermute(param0), _), + op::Shape("f32[1]"))); +} + } // namespace } // namespace spmd } // namespace xla diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner_util.cc b/third_party/xla/xla/service/spmd/spmd_partitioner_util.cc index 83c41ec91f840f..ac1d272bac56f0 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner_util.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner_util.cc @@ -956,8 +956,7 @@ HloInstruction* ExchangeHaloCompact( (i + 1) * input_shard_size + right_halo_size_function.Calculate(i); max_window_size = std::max(max_window_size, limit - start); while (next_start < limit) { - halos[i].emplace_back(); - Halo& halo = halos[i].back(); + Halo& halo = halos[i].emplace_back(); halo.my_index = i; halo.halo_offset = next_start - start; halo.start = next_start % input_shard_size; @@ -1038,11 +1037,12 @@ HloInstruction* ExchangeHaloCompact( // Sort halos that are from the same src according to halo_offset, so that // they are more likely to have similar characteristics. for (int64_t i = 0; i < src_to_dst.size(); ++i) { - absl::c_sort(src_to_dst[i], [&](const std::pair& a, - const std::pair& b) { - return halos[a.first][a.second].halo_offset < - halos[b.first][b.second].halo_offset; - }); + absl::c_stable_sort(src_to_dst[i], + [&](const std::pair& a, + const std::pair& b) { + return halos[a.first][a.second].halo_offset < + halos[b.first][b.second].halo_offset; + }); } // Build collective permutes with distinct src/dst values. @@ -2142,8 +2142,8 @@ std::optional GatherScatterOperandsShardedAcrossParallelDims( const HloInstruction& operand, const HloInstruction& indices, const hlo_sharding_util::GatherScatterParallelDims& parallel_dims) { - auto& indices_parallel_dims = parallel_dims.indices_parallel_dims; - auto& operand_parallel_dims = parallel_dims.operand_parallel_dims; + const auto& indices_parallel_dims = parallel_dims.indices_parallel_dims; + const auto& operand_parallel_dims = parallel_dims.operand_parallel_dims; if (indices_parallel_dims.size() != operand_parallel_dims.size()) { return std::nullopt; } @@ -2154,32 +2154,26 @@ GatherScatterOperandsShardedAcrossParallelDims( if (idx_parallel_tiles_num == 1 && op_parallel_tiles_num == 1) { return std::nullopt; } - absl::InlinedVector indices_parallel_dims_ordered_as_operand; - for (int idx : parallel_dims.index_parallel_in_dim) { - if (idx != -1) { - indices_parallel_dims_ordered_as_operand.push_back(idx); - } - } + if (new_index_shard.IsReplicated()) { return GatherScatterParallelDimSharding{ CreateMatchingShardingOnDims(indices.shape(), new_operand_shard, - indices_parallel_dims_ordered_as_operand, + indices_parallel_dims, operand_parallel_dims), new_operand_shard}; } if (new_operand_shard.IsReplicated()) { return GatherScatterParallelDimSharding{ - new_index_shard, - CreateMatchingShardingOnDims(operand.shape(), new_index_shard, - operand_parallel_dims, - indices_parallel_dims_ordered_as_operand)}; + new_index_shard, CreateMatchingShardingOnDims( + operand.shape(), new_index_shard, + operand_parallel_dims, indices_parallel_dims)}; } // Parallel dimension distribution needs to be the same, so try to steal // sharding from partial replication to compensate. if (idx_parallel_tiles_num != op_parallel_tiles_num) { auto to_adjust_dims = operand_parallel_dims; - auto target_dims = indices_parallel_dims_ordered_as_operand; + auto target_dims = indices_parallel_dims; HloSharding* target = &new_index_shard; HloSharding* to_adjust = &new_operand_shard; if (idx_parallel_tiles_num < op_parallel_tiles_num) { @@ -2231,19 +2225,17 @@ GatherScatterOperandsShardedAcrossParallelDims( std::vector operand_shard_tile_dims( new_operand_shard.tile_assignment().dimensions().begin(), new_operand_shard.tile_assignment().dimensions().end()); - for (int i = 0; i < indices_parallel_dims_ordered_as_operand.size(); ++i) { + for (int i = 0; i < indices_parallel_dims.size(); ++i) { operand_shard_tile_dims[operand_parallel_dims[i]] = - new_index_shard.tile_assignment().dim( - indices_parallel_dims_ordered_as_operand[i]); + new_index_shard.tile_assignment().dim(indices_parallel_dims[i]); } auto operand_shard_tiles = new_operand_shard.tile_assignment().Reshape(operand_shard_tile_dims); - new_operand_shard = - AlignShardingOnDims(new_operand_shard.ReplicateOnLastTileDim() - ? HloSharding::PartialTile(operand_shard_tiles) - : HloSharding::Tile(operand_shard_tiles), - operand_parallel_dims, new_index_shard, - indices_parallel_dims_ordered_as_operand); + new_operand_shard = AlignShardingOnDims( + new_operand_shard.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(operand_shard_tiles) + : HloSharding::Tile(operand_shard_tiles), + operand_parallel_dims, new_index_shard, indices_parallel_dims); return GatherScatterParallelDimSharding{new_index_shard, new_operand_shard}; } diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner_util.h b/third_party/xla/xla/service/spmd/spmd_partitioner_util.h index a982c3edf1e8db..058cfd49779ac9 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner_util.h +++ b/third_party/xla/xla/service/spmd/spmd_partitioner_util.h @@ -46,12 +46,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/hlo/utils/hlo_sharding_util.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/hlo_dce.h" #include "xla/service/spmd/spmd_partitioner.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/service/spmd/spmd_prepare.cc b/third_party/xla/xla/service/spmd/spmd_prepare.cc index 51655b90861d48..83bf8495cb18ca 100644 --- a/third_party/xla/xla/service/spmd/spmd_prepare.cc +++ b/third_party/xla/xla/service/spmd/spmd_prepare.cc @@ -108,9 +108,7 @@ absl::StatusOr ProcessScatter(HloInstruction* hlo, if (lhs_parallel_dims->operand_parallel_dims != rhs_parallel_dims->operand_parallel_dims || lhs_parallel_dims->indices_parallel_dims != - rhs_parallel_dims->indices_parallel_dims || - lhs_parallel_dims->index_parallel_in_dim != - rhs_parallel_dims->index_parallel_in_dim) { + rhs_parallel_dims->indices_parallel_dims) { return false; } if (lhs_parallel_dims->operand_parallel_dims.size() != diff --git a/third_party/xla/xla/service/spmd/stateful_rng_spmd_partitioner.h b/third_party/xla/xla/service/spmd/stateful_rng_spmd_partitioner.h index a8962df6ace617..6141c6b38770ea 100644 --- a/third_party/xla/xla/service/spmd/stateful_rng_spmd_partitioner.h +++ b/third_party/xla/xla/service/spmd/stateful_rng_spmd_partitioner.h @@ -54,13 +54,16 @@ class StatefulRngSpmdPartitioner : public spmd::SpmdPartitioner { int64_t threshold_for_windowed_einsum_mib = 100000, bool windowed_einsum_use_multiple_streams = false, bool skip_checking_windowed_einsum_users = false, - bool disable_ag_rewrite_for_multiple_consumers = false) - : spmd::SpmdPartitioner(num_partitions, num_replicas, - GetSpmdPartitionerOptions( - threshold_for_windowed_einsum_mib, - windowed_einsum_use_multiple_streams, - skip_checking_windowed_einsum_users, - disable_ag_rewrite_for_multiple_consumers)) {} + bool disable_ag_rewrite_for_multiple_consumers = false, + std::optional total_bytes_windowed_einsum_threshold = + std::nullopt) + : spmd::SpmdPartitioner( + num_partitions, num_replicas, + GetSpmdPartitionerOptions(threshold_for_windowed_einsum_mib, + windowed_einsum_use_multiple_streams, + skip_checking_windowed_einsum_users, + disable_ag_rewrite_for_multiple_consumers, + total_bytes_windowed_einsum_threshold)) {} protected: std::unique_ptr CreateVisitor( @@ -87,7 +90,9 @@ class StatefulRngSpmdPartitioner : public spmd::SpmdPartitioner { int64_t threshold_for_windowed_einsum_mib, bool windowed_einsum_use_multiple_streams = false, bool skip_checking_windowed_einsum_users = false, - bool disable_ag_rewrite_for_multiple_consumers = false) { + bool disable_ag_rewrite_for_multiple_consumers = false, + std::optional total_bytes_windowed_einsum_threshold = + std::nullopt) { spmd::SpmdPartitionerOptions options; options.allow_module_signature_change = true; options.threshold_for_windowed_einsum_mib = @@ -97,6 +102,8 @@ class StatefulRngSpmdPartitioner : public spmd::SpmdPartitioner { skip_checking_windowed_einsum_users; options.disable_ag_rewrite_for_multiple_consumers = disable_ag_rewrite_for_multiple_consumers; + options.total_bytes_windowed_einsum_threshold = + total_bytes_windowed_einsum_threshold; return options; } }; diff --git a/third_party/xla/xla/service/spmd/stateful_rng_spmd_partitioner_test.cc b/third_party/xla/xla/service/spmd/stateful_rng_spmd_partitioner_test.cc index 6549214a74b792..abb22b4747500b 100644 --- a/third_party/xla/xla/service/spmd/stateful_rng_spmd_partitioner_test.cc +++ b/third_party/xla/xla/service/spmd/stateful_rng_spmd_partitioner_test.cc @@ -23,10 +23,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/transforms/expanders/rng_expander.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_verifier.h" -#include "xla/service/rng_expander.h" #include "xla/service/sharding_propagation.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" @@ -77,7 +77,8 @@ class StatefulRngSpmdPartitionerTest : public HloTestBase { debug_options.xla_gpu_threshold_for_windowed_einsum_mib(), debug_options.xla_gpu_multi_streamed_windowed_einsum(), skip_checking_windowed_einsum_users, - disable_ag_rewrite_for_multiple_consumers); + disable_ag_rewrite_for_multiple_consumers, + debug_options.xla_gpu_operand_bytes_threshold_for_windowed_einsum()); pass.AddPass(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); TF_RETURN_IF_ERROR(pass.Run(module.get()).status()); @@ -262,6 +263,75 @@ ENTRY %test { rotate = op::Concatenate(op::CollectivePermute(op::Slice()), op::Slice()); EXPECT_THAT(root, AllOf(rotate, op::Shape("f32[3]"))); } + +TEST_F(StatefulRngSpmdPartitionerTest, + TotalFlopsThresholdOverrideOperandThreshold) { + absl::string_view hlo_string = R"( +HloModule test, entry_computation_layout={(bf16[2,128,256]{2,1,0}, bf16[256,512]{1,0})->bf16[2,128,512]{2,1,0}}, num_partitions=4 + +ENTRY main { + Arg_0.1 = bf16[2,128,256]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} + Arg_1.2 = bf16[256,512]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} + ROOT dot.5 = bf16[2,128,512]{2,1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={2}, rhs_contracting_dims={0}, sharding={devices=[1,1,4]<=[4]} +} + +)"; + DebugOptions debug_options = GetDefaultDebugOptions(); + debug_options.set_xla_gpu_threshold_for_windowed_einsum_mib(0); + debug_options.set_xla_gpu_multi_streamed_windowed_einsum(true); + int64_t oper_bytes_threshold = 1 << 20; + debug_options.set_xla_gpu_operand_bytes_threshold_for_windowed_einsum( + oper_bytes_threshold); + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_partitions=*/4, debug_options, + /*add_passes=*/nullptr, + /*skip_checking_windowed_einsum_users=*/true, + /*disable_ag_rewrite_for_multiple_consumers=*/true)); + XLA_VLOG_LINES(1, module->ToString()); + // The operand threshold is set to 0 but flops threshold is set to be + // larger than the total flops of the gemm. So we don't expect any + // windowed einsum loop but rather an all-gather. + EXPECT_EQ(CountInstructions(*module->entry_computation(), HloOpcode::kWhile), + 0); + EXPECT_EQ( + CountInstructions(*module->entry_computation(), HloOpcode::kAllGather), + 1); +} + +TEST_F(StatefulRngSpmdPartitionerTest, + TotalFlopsThresholdShouldEnableWindowedEinsum) { + absl::string_view hlo_string = R"( +HloModule test, entry_computation_layout={(bf16[2,128,256]{2,1,0}, bf16[256,512]{1,0})->bf16[2,128,512]{2,1,0}}, num_partitions=4 + +ENTRY main { + Arg_0.1 = bf16[2,128,256]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} + Arg_1.2 = bf16[256,512]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} + ROOT dot.5 = bf16[2,128,512]{2,1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={2}, rhs_contracting_dims={0}, sharding={devices=[1,1,4]<=[4]} +} + +)"; + DebugOptions debug_options = GetDefaultDebugOptions(); + debug_options.set_xla_gpu_multi_streamed_windowed_einsum(true); + int64_t oper_bytes_threshold = 1 << 8; + debug_options.set_xla_gpu_operand_bytes_threshold_for_windowed_einsum( + oper_bytes_threshold); + TF_ASSERT_OK_AND_ASSIGN( + auto module, + PartitionComputation(hlo_string, /*num_partitions=*/4, debug_options, + /*add_passes=*/nullptr, + /*skip_checking_windowed_einsum_users=*/true, + /*disable_ag_rewrite_for_multiple_consumers=*/true)); + XLA_VLOG_LINES(1, module->ToString()); + // The operand threshold is not set which defaults to 1000000 MB. + // But the flops threshold is set, the windowed einsum should still kick in. + EXPECT_EQ(CountInstructions(*module->entry_computation(), HloOpcode::kWhile), + 1); + EXPECT_EQ( + CountInstructions(*module->entry_computation(), HloOpcode::kAllGather), + 0); +} + } // namespace } // namespace spmd } // namespace xla diff --git a/third_party/xla/xla/service/stable_sort_expander.h b/third_party/xla/xla/service/stable_sort_expander.h index d35f36ed21515f..78d58b24ba822e 100644 --- a/third_party/xla/xla/service/stable_sort_expander.h +++ b/third_party/xla/xla/service/stable_sort_expander.h @@ -16,36 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_STABLE_SORT_EXPANDER_H_ #define XLA_SERVICE_STABLE_SORT_EXPANDER_H_ -#include - -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/op_expander_pass.h" - -namespace xla { - -// HLO pass which expands Sort ops that have the is_stable field set to true -// into equivalent Sort ops which guarantee stable sorting without relying on -// the is_stable field. -class StableSortExpander : public OpExpanderPass { - public: - absl::string_view name() const override { return "stable-sort-expander"; } - - // Returns the index of the sort operand that is an iota op with an iota - // dimension which is the same as the dimension to sort. Also it should have - // an integral type that is large enough for the number of elements in the - // sort dimension. For now, we only allow S32, because we expect to find a S32 - // iota operand for all Sort ops which are created by TopK. - // - // If no operand of the input sort matches the conditions above, returns -1. - static int64_t IotaOperandIndexForStableSort(const HloSortInstruction& sort); - - private: - bool InstructionMatchesPattern(HloInstruction* instruction) override; - absl::StatusOr ExpandInstruction( - HloInstruction* instruction) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/expanders/stable_sort_expander.h" #endif // XLA_SERVICE_STABLE_SORT_EXPANDER_H_ diff --git a/third_party/xla/xla/service/stochastic_convert_decomposer.h b/third_party/xla/xla/service/stochastic_convert_decomposer.h index ad53237257c817..79aefac76e302a 100644 --- a/third_party/xla/xla/service/stochastic_convert_decomposer.h +++ b/third_party/xla/xla/service/stochastic_convert_decomposer.h @@ -16,24 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_STOCHASTIC_CONVERT_DECOMPOSER_H_ #define XLA_SERVICE_STOCHASTIC_CONVERT_DECOMPOSER_H_ -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// StochasticConvertDecomposer is a pass which replaces unsupported -// stochastic-convert with multiple hlos. -class StochasticConvertDecomposer : public HloModulePass { - public: - absl::string_view name() const override { - return "stochastic_convert_decomposer"; - } - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/expanders/stochastic_convert_decomposer.h" #endif // XLA_SERVICE_STOCHASTIC_CONVERT_DECOMPOSER_H_ diff --git a/third_party/xla/xla/service/stream_pool.cc b/third_party/xla/xla/service/stream_pool.cc index e94b5867cd8b51..1e80875681cec6 100644 --- a/third_party/xla/xla/service/stream_pool.cc +++ b/third_party/xla/xla/service/stream_pool.cc @@ -51,8 +51,8 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamPriority priority) { if (!stream) { // Create a new stream. stream = executor_->CreateStream(priority).value(); - stream->set_name(absl::StrFormat("%s pool stream", - se::StreamPriorityToString(priority))); + stream->SetName(absl::StrFormat("%s pool stream", + se::StreamPriorityToString(priority))); VLOG(1) << absl::StrFormat("Created new stream (%p) with priority = %s", stream.get(), se::StreamPriorityToString(priority)); diff --git a/third_party/xla/xla/service/sub_byte_normalization.h b/third_party/xla/xla/service/sub_byte_normalization.h index 69d503944f713d..3f9f700509c4b2 100644 --- a/third_party/xla/xla/service/sub_byte_normalization.h +++ b/third_party/xla/xla/service/sub_byte_normalization.h @@ -16,51 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_SUB_BYTE_NORMALIZATION_H_ #define XLA_SERVICE_SUB_BYTE_NORMALIZATION_H_ -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// A pass that can modify the sub-byte element_size_in_bits annotation on -// layouts. Depending on the constructor argument, it either removes the -// element_size_in_bits annotation for platforms that don't support packed -// types, or it sets element_size_in_bits to N for N-bit values. -class SubByteNormalization : public HloModulePass { - public: - enum Mode { - // Remove element_size_in_bits on all layouts. Useful for platforms which - // do not support packed types. - REMOVE_ELEMENT_SIZE, - // Set element_size_in_bits to bitwidth(type) for layouts of types < 8 bits - // (S4, U4, etc.), and to 0 for all other layouts. Useful for platforms - // which support packed types. - SET_ELEMENT_SIZE, - }; - - explicit SubByteNormalization(Mode mode) : mode_(mode) {} - - ~SubByteNormalization() override = default; - - absl::string_view name() const override { - switch (mode_) { - case REMOVE_ELEMENT_SIZE: - return "sub-byte-size-removal"; - case SET_ELEMENT_SIZE: - return "sub-byte-size-setter"; - } - } - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - Mode mode_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/sub_byte_normalization.h" #endif // XLA_SERVICE_SUB_BYTE_NORMALIZATION_H_ diff --git a/third_party/xla/xla/service/topk_rewriter.cc b/third_party/xla/xla/service/topk_rewriter.cc index bb65d436acedbd..94afccffe6174f 100644 --- a/third_party/xla/xla/service/topk_rewriter.cc +++ b/third_party/xla/xla/service/topk_rewriter.cc @@ -23,14 +23,15 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" -#include "xla/client/lib/comparators.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/comparators.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/primitive_util.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/service/pattern_matcher.h" #include "xla/shape_util.h" #include "xla/util.h" @@ -41,20 +42,6 @@ namespace xla { namespace m = match; -// TODO(cheshire): Avoid duplication w/ cudnn_vectorize_convolutions. -static absl::StatusOr BuilderToHloComputation( - XlaComputation& comp, HloComputation* sibling_computation) { - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comp.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, - HloModule::CreateFromProto(comp.proto(), config)); - - HloModule* dest_module = sibling_computation->parent(); - HloCloneContext context(dest_module); - return dest_module->DeepCloneComputation(new_module->entry_computation(), - &context); -} - static bool IsNanSafeGt(HloComputation* comp) { namespace m = match; auto match_bitcast_f32 = [](int64_t parameter_number) { @@ -500,9 +487,9 @@ class TopkDecomposerVisitor : public DfsHloRewriteVisitor { XlaComputation comparison = topk->largest() ? CreateScalarGtComputation(ptypes, &b) : CreateScalarLtComputation(ptypes, &b); - - TF_ASSIGN_OR_RETURN(HloComputation * comparator, - BuilderToHloComputation(comparison, topk->parent())); + TF_ASSIGN_OR_RETURN( + HloComputation * comparator, + XlaComputationToHloComputation(comparison, topk->parent()->parent())); return comparator; } diff --git a/third_party/xla/xla/service/topk_rewriter_test.cc b/third_party/xla/xla/service/topk_rewriter_test.cc index c678bef94e373f..5eda22467dade6 100644 --- a/third_party/xla/xla/service/topk_rewriter_test.cc +++ b/third_party/xla/xla/service/topk_rewriter_test.cc @@ -25,11 +25,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/service/hlo_dce.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" -#include "xla/service/tuple_simplifier.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tsl/lib/core/status_test_util.h" diff --git a/third_party/xla/xla/service/transpose_folding_test.cc b/third_party/xla/xla/service/transpose_folding_test.cc index 1dd4f0c361badc..31a17d78429176 100644 --- a/third_party/xla/xla/service/transpose_folding_test.cc +++ b/third_party/xla/xla/service/transpose_folding_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" diff --git a/third_party/xla/xla/service/tree_reduction_rewriter.h b/third_party/xla/xla/service/tree_reduction_rewriter.h index 9563c769d8ee54..e505b69e92d0d9 100644 --- a/third_party/xla/xla/service/tree_reduction_rewriter.h +++ b/third_party/xla/xla/service/tree_reduction_rewriter.h @@ -16,46 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_TREE_REDUCTION_REWRITER_H_ #define XLA_SERVICE_TREE_REDUCTION_REWRITER_H_ -#include - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// Increase precision for the reduction operation by applying the reduce-window -// first. -// -// E.g. suppose we want to reduce f32[1024] to a scalar. This pass first applies -// a reduce-window (with kSame padding) of size `reduce_window_size`, and then -// reduces the resulting array f32[32]. The rewrite is not applied if any of the -// reduced dimensions is smaller than the `reduce_window_size`. -// -// Applying this pass until a fixed point performs a variant of pairwise -// summation (https://en.wikipedia.org/wiki/Pairwise_summation), which is -// guaranteed to have an asymptotically smaller error bound provided that -// intermediate roundoff errors are random and have random sign. -// -// If this pass lowers the performance too much, the window size can always be -// increased to a larger value. -class TreeReductionRewriter : public HloModulePass { - public: - explicit TreeReductionRewriter(int64_t reduce_window_size = 32) - : reduce_window_size_(reduce_window_size) {} - ~TreeReductionRewriter() override = default; - absl::string_view name() const override { return "tree_reduction_rewriter"; } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - int64_t reduce_window_size_; -}; - -} // end namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/tree_reduction_rewriter.h" #endif // XLA_SERVICE_TREE_REDUCTION_REWRITER_H_ diff --git a/third_party/xla/xla/service/triangular_solve_expander.cc b/third_party/xla/xla/service/triangular_solve_expander.cc index c61dc148c0ec33..0a328e86b3905a 100644 --- a/third_party/xla/xla/service/triangular_solve_expander.cc +++ b/third_party/xla/xla/service/triangular_solve_expander.cc @@ -25,15 +25,16 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" -#include "xla/client/lib/constants.h" -#include "xla/client/lib/math.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/lib/slicing.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/constants.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/lib/slicing.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/hlo_creation_utils.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -599,15 +600,8 @@ absl::StatusOr TriangularSolveExpander::ExpandInstruction( /*block_size=*/block_size_, /*precision=*/PrecisionConfig::HIGHEST); TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); - - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - xla_computation.GetProgramShape()); - HloModuleConfig config(program_shape); - TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( - xla_computation.proto(), config)); - HloCloneContext context(module); - computation = - module->DeepCloneComputation(new_module->entry_computation(), &context); + TF_ASSIGN_OR_RETURN( + computation, XlaComputationToHloComputation(xla_computation, module)); } return instruction->parent()->AddInstruction(HloInstruction::CreateCall( diff --git a/third_party/xla/xla/service/triangular_solve_expander.h b/third_party/xla/xla/service/triangular_solve_expander.h index 0ccbcf1cf7ceaa..87aaf5612ce48a 100644 --- a/third_party/xla/xla/service/triangular_solve_expander.h +++ b/third_party/xla/xla/service/triangular_solve_expander.h @@ -17,8 +17,8 @@ limitations under the License. #define XLA_SERVICE_TRIANGULAR_SOLVE_EXPANDER_H_ #include "absl/container/flat_hash_map.h" -#include "xla/client/xla_builder.h" -#include "xla/service/op_expander_pass.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/transforms/expanders/op_expander_pass.h" namespace xla { diff --git a/third_party/xla/xla/service/tuple_points_to_analysis.h b/third_party/xla/xla/service/tuple_points_to_analysis.h index 0b9710d3075810..1b231e4b76ad29 100644 --- a/third_party/xla/xla/service/tuple_points_to_analysis.h +++ b/third_party/xla/xla/service/tuple_points_to_analysis.h @@ -16,355 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_ #define XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_ -#include - -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/inlined_vector.h" -#include "absl/status/statusor.h" -#include "absl/types/span.h" -#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/logical_buffer.h" -#include "xla/service/logical_buffer_analysis.h" -#include "xla/shape_tree.h" -#include "xla/types.h" -#include "xla/xla_data.pb.h" -#include "tsl/lib/gtl/compactptrset.h" -#include "tsl/platform/status.h" - -namespace xla { - -// A class describing the source(s) of the Buffer(s) contained in the output of -// a particular HLO instruction. The structure of PointsToSet mirrors the -// structure of the instruction's shape, which may be an arbitrary tree (eg, a -// nested tuple). Each node in this tree corresponds to a single buffer in the -// instruction's output and contains the set of Buffers which might define -// the corresponding buffer. -class PointsToSet { - public: - // Construct our ShapeTree with a pointer rather than a reference to a Shape - // because this is very hot code, and copying (and then destroying) all these - // Shapes is slow. - explicit PointsToSet(const Shape* shape) : tree_(shape) {} - - // Returns true if any points-to sets for any subshape element is not a - // singleton. - bool IsAmbiguous() const; - - // Returns true if no LogicalBuffer appears in more than one points-to set of - // the shape nodes. - bool IsDistinct() const; - - // Returns the total number of different LogicalBuffers contained in this - // object. This is equal to CreateFlattenedSet().size(). - size_t size() const; - - // Creates a set containing the union of all LogicalBuffers contained in the - // PointsToSet. - using BufferSet = tsl::gtl::CompactPointerSet; - BufferSet CreateFlattenedSet() const; - - // Returns true if the given buffer is in the points-to set at the given - // index. - bool ContainsBufferAtIndex(const LogicalBuffer& buffer, - const ShapeIndex& index) const; - - // Returns true if the given buffer is in the points-to set at any index. - bool ContainsBuffer(const LogicalBuffer& buffer) const; - - // Adds the given buffer to the points-to set at the given index. This is a - // nop if the buffer already is in the set at that index. - void AddPointedToBuffer(const LogicalBuffer& buffer, const ShapeIndex& index); - - // For the subshape at the given index (where index is defined as in - // ShapeUtil::GetSubshape) this method returns the set of HLO instructions - // which may produce the tuple subshape at that index. For example, given: - // - // %tuple1 = tuple(...) - // %tuple2 = tuple(...) - // %select = select(%tuple1, %tuple2) - // %nested_tuple = tuple(%select, %tuple1) - // - // These are the values for tuple_sources() for the PointsToSet of - // %nested_tuple: - // - // tuple_sources({}) = {%nested_tuple} - // tuple_sources({0}) = {%tuple1, %tuple2} - // tuple_sources({1}) = {%tuple1} - // - // tuple_sources() at the index of an array shape (not a tuple) returns the - // empty set. The instructions in the set returned by tuple_sources - // necessarily are either Tuple instructions, constants, or parameters. - using SourceSet = tsl::gtl::CompactPointerSet; - const SourceSet& tuple_sources(const ShapeIndex& index) const; - - // Add a tuple source instruction for the given index. - void add_tuple_source(const ShapeIndex& index, HloInstruction* tuple); - - using BufferList = absl::InlinedVector; - - // Return the list of logical buffers for the subshape at index. - const BufferList& element(const ShapeIndex& index) const { - return tree_.element(index).buffers; - } - BufferList* mutable_element(const ShapeIndex& index) { - return &tree_.mutable_element(index)->buffers; - } - - // Call fn(index, buflist) for every subshape index. - template - void ForEachElement(const Fn& fn) const { - tree_.ForEachElement([&fn](const ShapeIndex& index, const Elem& elem) { - fn(index, elem.buffers); - }); - } - template - void ForEachMutableElement(const Fn& fn) { - tree_.ForEachMutableElement([&fn](const ShapeIndex& index, Elem* elem) { - fn(index, &elem->buffers); - }); - } - template - absl::Status ForEachElementWithStatus(const Fn& fn) const { - return tree_.ForEachElementWithStatus( - [&fn](const ShapeIndex& index, const Elem& elem) { - return fn(index, elem.buffers); - }); - } - - private: - struct Elem { - BufferList buffers; - SourceSet tuple_sources; - }; - ShapeTree tree_; - - // PointsToSet contains references (const LogicalBuffer*) to elements within - // TuplePointsToAnalysis, so disable copying. - PointsToSet(const PointsToSet&) = delete; - PointsToSet& operator=(const PointsToSet&) = delete; -}; - -// This class describes a particular subshape in a computation (instruction and -// shape index) and the logical buffer which may be a source of the subshape -// value. -class BufferAlias { - public: - BufferAlias(HloInstruction* instruction, const ShapeIndex& index) - : instruction_(instruction), index_(index) {} - - // Return the instruction/index of the subshape. - HloInstruction* instruction() const { return instruction_; } - const ShapeIndex& index() const { return index_; } - - bool operator==(const BufferAlias& other) const { - return instruction_ == other.instruction_ && index_ == other.index_; - } - bool operator!=(const BufferAlias& other) const { return !(*this == other); } - - std::string ToString() const; - - private: - HloInstruction* instruction_; - ShapeIndex index_; -}; - -std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias); - -// DFS visitor that performs tuple points-to analysis. This analysis determines -// the potential sources of each buffer in each instruction's output. -class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { - public: - // Runs points-to analysis on 'module'. - static absl::StatusOr> Run( - const HloModule* module); - - // Return the points-to set of an instruction. This describes the potential - // sources of each buffer in the instruction's output. - const PointsToSet& GetPointsToSet( - const HloInstruction* hlo_instruction) const; - - // Returns the logical buffer with the given ID. - const LogicalBuffer& GetBuffer(LogicalBuffer::Id id) const; - - // Returns the buffer defined at the given instruction and index. An error is - // returned if no buffer is defined at that point. - absl::StatusOr GetBufferDefinedAt( - const HloInstruction* instruction, const ShapeIndex& index) const; - - // Return a (possibly empty) vector containing all BufferAliases of the given - // logical buffer The buffer alias set is the inverse of the points-to set. - // That is, LogicalBuffer B is in the points-to set of instruction I at index - // N iff instruction I, index N is a BufferAlias of B. - using BufferAliasVector = absl::InlinedVector; - const BufferAliasVector& GetBufferAliases(const LogicalBuffer& buffer) const; - - // Returns the number of logical buffers in the module - LogicalBuffer::Id num_logical_buffers() const { - return logical_buffer_analysis_->num_logical_buffers(); - } - - // Return a the logical buffer with id "id" in the module. Iteration - // over all logical buffers is usually done with something like: - // - // for (LogicalBuffer:Id id = 0; id < points_to.num_logical_buffers(); id++){ - // const auto& buffer = points_to.logical_buffer(id); - // ... do something with buffer ... - // } - LogicalBuffer& logical_buffer(LogicalBuffer::Id id) const { - return logical_buffer_analysis_->GetBuffer(id); - } - - // Returns a vector of buffers that the instruction produces. Most - // instructions produce a single buffer (the top-level buffer), some produce - // no buffers (eg bitcast), and some produce more than one buffer (eg, - // tuple-shaped parameters). - using BufferDefinitionVector = absl::InlinedVector; - const BufferDefinitionVector& GetBuffersDefinedByInstruction( - const HloInstruction* instruction) const; - - // Returns true if the given instruction defines a buffer at the given index. - bool InstructionDefinesBufferAtIndex(const HloInstruction* instruction, - const ShapeIndex& index) const; - - // Returns an OK status if the given buffer is defined by instruction - // 'buffer.instruction()' at index 'buffer.index()' and if the given buffer - // matches the TuplePointsToAnalysis' LogicalBuffer with 'buffer.id'. Returns - // an FailedPrecondition error status otherwise. An example of a LogicalBuffer - // which is not defined is a tuple element in a Tuple instruction. In this - // case, the Tuple instruction does not define the LogicalBuffer, rather that - // index aliases one of its operands. - absl::Status VerifyBuffer(const LogicalBuffer& buffer) const; - - absl::Status DefaultAction(HloInstruction* hlo_instruction) override; - absl::Status HandleTuple(HloInstruction* tuple) override; - absl::Status HandleGetTupleElement( - HloInstruction* get_tuple_element) override; - absl::Status HandleAsyncStart(HloInstruction* async_start) override; - absl::Status HandleAsyncUpdate(HloInstruction* async_update) override; - absl::Status HandleAsyncDone(HloInstruction* async_done) override; - absl::Status HandleBitcast(HloInstruction* bitcast) override; - absl::Status HandleDomain(HloInstruction* domain) override; - absl::Status HandleCopy(HloInstruction* copy) override; - absl::Status HandleCopyStart(HloInstruction* copy_start) override; - absl::Status HandleCopyDone(HloInstruction* copy_done) override; - absl::Status HandleRecvDone(HloInstruction* recv_done) override; - absl::Status HandleSend(HloInstruction* send) override; - absl::Status HandleAddDependency(HloInstruction* add_dependency) override; - absl::Status HandleCustomCall(HloInstruction* custom_call) override; - absl::Status HandleFusion(HloInstruction* fusion) override; - absl::Status HandleOptimizationBarrier(HloInstruction* barrier) override; - - std::string ToString() const; - - // Returns true if 'user' cannot possibly use the buffer at 'index' in - // 'operand'. Returns false otherwise. - // - // REQUIRES: 'operand' is an operand of 'user'. - bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user) const; - - private: - explicit TuplePointsToAnalysis( - const HloModule* module, - std::unique_ptr logical_buffer_analysis) - : module_(module), - logical_buffer_analysis_(std::move(logical_buffer_analysis)) {} - - // Perform the analysis. Should be called immediately after constructing the - // object and before calling GetPointsToSet. - absl::Status Analyze(); - - // Populates instruction-defined buffers and aliases for each instruction - // in 'instructions'. - absl::Status PopulateDefinedBuffersAndAliases( - const decltype(std::declval() - .instructions())& instructions); - - // Creates an empty PointsToSet in the points_to_ map for the given - // instruction. - PointsToSet& CreateEmptyPointsToSet(const HloInstruction* instruction); - - // Creates a PointsToSet in the points_to_ map for 'instruction' which is a - // copy of the existing PointsToSet for 'src'. - PointsToSet& CreateCopiedPointsToSet(const HloInstruction* instruction, - const HloInstruction* src); - - // Adds the buffers defined by the given instruction to the given vector. - absl::Status GatherBuffersDefinedByInstruction( - const HloInstruction* instruction, BufferDefinitionVector* buffers); - - // Print points-to set for 'instruction' to 'output'. - void InstructionToString(const HloInstruction* instruction, - std::string* output) const; - - // Information kept per instruction - struct PerInstruction { - std::unique_ptr points_to_set; - // Empirically, ~92% of instructions have 1 - // instruction_defined_buffer, and 99% have 0 or 1 - BufferDefinitionVector instruction_defined_buffers; - }; - - const PerInstruction* PerInst(const HloInstruction* inst) const { - int id = inst->unique_id(); - DCHECK_GE(id, 0); - auto iter = per_instruction_.find(id); - if (iter == per_instruction_.end()) { - LOG(FATAL) << "Expected per-instruction information to already exist"; - } else { - return iter->second.get(); - } - } - PerInstruction* PerInst(const HloInstruction* inst) { - int id = inst->unique_id(); - DCHECK_GE(id, 0); - auto iter = per_instruction_.find(id); - if (iter == per_instruction_.end()) { - return per_instruction_.emplace(id, std::make_unique()) - .first->second.get(); - } else { - return iter->second.get(); - } - } - - std::vector> - GetAllUsesOfInstructionAtIndex(HloInstruction* instruction, - const ShapeIndex& index) const; - bool HasUniqueFusedUseOfOperandAt(HloInstruction* operand, - const ShapeIndex& operand_index, - HloInstruction* fusion, - const int64_t use_operand_index) const; - - // The module this analysis is performed on. - const HloModule* module_; - - // The logical buffers for this module. - const std::unique_ptr logical_buffer_analysis_; - - // A map from instruction->unique_id() to - absl::flat_hash_map> per_instruction_; - - // A map from LogicalBuffer->id() to alias information about that logical - // buffer - std::vector logical_buffer_aliases_; - - TuplePointsToAnalysis(const TuplePointsToAnalysis&) = delete; - TuplePointsToAnalysis& operator=(const TuplePointsToAnalysis&) = delete; - // Whether to alias buffers connected by dataflow relations. This aliasing - // relation should not be recognized if copies can be inserted to break up - // the dataflow relation. - const bool alias_buffer_across_dataflow_ = false; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/analysis/tuple_points_to_analysis.h" #endif // XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/tuple_simplifier.h b/third_party/xla/xla/service/tuple_simplifier.h index 81c315d25ec47f..19d81248537be4 100644 --- a/third_party/xla/xla/service/tuple_simplifier.h +++ b/third_party/xla/xla/service/tuple_simplifier.h @@ -16,54 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_TUPLE_SIMPLIFIER_H_ #define XLA_SERVICE_TUPLE_SIMPLIFIER_H_ -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// A pass which simplifies patterns of Tuple and GetTupleElement instructions in -// the module. -class TupleSimplifier : public HloModulePass { - public: - TupleSimplifier() : TupleSimplifier(/*exclude_entry_computation=*/false) {} - explicit TupleSimplifier(bool exclude_entry_computation); - ~TupleSimplifier() override {} - absl::string_view name() const override { return "tuple-simplifier"; } - - // Runs tuple simplification on the given module. Returns whether the module - // was changed. - using HloPassInterface::Run; - using HloPassInterface::RunOnModuleGroup; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - private: - // When set, this pipeline stage will perform optimization of all computations - // apart from the module's entry computation. This is used by Graphcore's - // backend. - bool exclude_entry_computation_; - - // Collapse the following structure into just 'Tuple-shaped Op', iff the - // sequence of GTE ops is order-preserving: - // - // Tuple-shaped Op - // | - // +-----+-----+ - // | | | - // GTE GTE GTE - // | | | - // +-----+-----+ - // | - // Tuple - // - absl::StatusOr RemoveWholeTuple(HloInstruction* tuple); -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #endif // XLA_SERVICE_TUPLE_SIMPLIFIER_H_ diff --git a/third_party/xla/xla/service/tuple_util.cc b/third_party/xla/xla/service/tuple_util.cc index e55b0d629c662a..abbc4975a5b6a1 100644 --- a/third_party/xla/xla/service/tuple_util.cc +++ b/third_party/xla/xla/service/tuple_util.cc @@ -246,4 +246,28 @@ HloInstruction* TupleUtil::AssembleTupleInstruction( return elements.element({}); } +HloInstruction* TupleUtil::GetTupleInstructionAtIndex( + HloInstruction& tuple, const ShapeIndex& target_index) { + HloInstruction* target_index_instr = &tuple; + + for (int32_t tuple_index : target_index) { + bool found = false; + for (HloInstruction* user : target_index_instr->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement && + user->tuple_index() == tuple_index) { + target_index_instr = user; + found = true; + break; + } + } + + if (!found) { + // No GTE found at the target index. + return nullptr; + } + } + + return target_index_instr; +} + } // namespace xla diff --git a/third_party/xla/xla/service/tuple_util.h b/third_party/xla/xla/service/tuple_util.h index 334b2cbc3dd8f5..86719aba74e120 100644 --- a/third_party/xla/xla/service/tuple_util.h +++ b/third_party/xla/xla/service/tuple_util.h @@ -85,6 +85,12 @@ class TupleUtil { static HloInstruction* AssembleTupleInstruction( HloComputation* computation, ShapeTree elements, absl::string_view name = ""); + + // Returns the tuple instruction at the given ShapeIndex `target_index`. + // Returns nullptr if there does not exist a tuple instruction at the given + // index, or if the index is invalid. + static HloInstruction* GetTupleInstructionAtIndex( + HloInstruction& tuple, const ShapeIndex& target_index); }; } // namespace xla diff --git a/third_party/xla/xla/service/tuple_util_test.cc b/third_party/xla/xla/service/tuple_util_test.cc index 2956e6edfbac86..e2a7176bc12b44 100644 --- a/third_party/xla/xla/service/tuple_util_test.cc +++ b/third_party/xla/xla/service/tuple_util_test.cc @@ -18,13 +18,16 @@ limitations under the License. #include #include +#include +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -170,5 +173,56 @@ ENTRY entry { EXPECT_THAT(new_gte, op::GetTupleElement(existing_gte, 0)); } +TEST_F(TupleUtilTest, GetTupleInstructionAtIndexTest) { + const std::string hlo_string = R"( +HloModule GetTupleInstructionAtIndexTest + +ENTRY entry { + p0 = (f32[32,32]{1,0}, (f32[32,32]{1,0}, f32[32,32]{1,0}, (f32[32,32]{1,0})), f32[32,32]) parameter(0) + gte = (f32[32,32]{1,0}, f32[32,32]{1,0}, (f32[32,32]{1,0})) get-tuple-element(p0), index=1 + gte.1 = f32[32,32]{1,0} get-tuple-element(p0), index=0 + gte.2 = (f32[32,32]{1,0}) get-tuple-element(gte), index=2 + gte.3 = f32[32,32]{1,0} get-tuple-element(gte.2), index=0 + ROOT root = f32[32,32]{1,0} get-tuple-element(gte), index=1 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + HloInstruction* p0 = FindInstruction(module.get(), "p0"); + ASSERT_NE(p0, nullptr); + HloInstruction* gte = FindInstruction(module.get(), "gte"); + ASSERT_NE(gte, nullptr); + HloInstruction* gte1 = FindInstruction(module.get(), "gte.1"); + ASSERT_NE(gte1, nullptr); + HloInstruction* gte2 = FindInstruction(module.get(), "gte.2"); + ASSERT_NE(gte2, nullptr); + + // Valid queries. + EXPECT_THAT(TupleUtil::GetTupleInstructionAtIndex(*p0, {1, 1}), + op::GetTupleElement(gte, 1)); + EXPECT_THAT(TupleUtil::GetTupleInstructionAtIndex(*p0, {1}), + op::GetTupleElement(p0, 1)); + EXPECT_THAT(TupleUtil::GetTupleInstructionAtIndex(*p0, {1, 2, 0}), + op::GetTupleElement(gte2, 0)); + EXPECT_THAT(TupleUtil::GetTupleInstructionAtIndex(*p0, {1, 2}), + op::GetTupleElement(gte, 2)); + EXPECT_THAT(TupleUtil::GetTupleInstructionAtIndex(*p0, {1, 2}), + op::GetTupleElement(gte, 2)); + EXPECT_THAT(TupleUtil::GetTupleInstructionAtIndex(*p0, {0}), + op::GetTupleElement(p0, 0)); + + // Invalid queries. + // Out of bounds + EXPECT_EQ(TupleUtil::GetTupleInstructionAtIndex(*p0, {3}), nullptr); + EXPECT_EQ(TupleUtil::GetTupleInstructionAtIndex(*p0, {-1}), nullptr); + EXPECT_EQ(TupleUtil::GetTupleInstructionAtIndex(*p0, {1, 3}), nullptr); + EXPECT_EQ(TupleUtil::GetTupleInstructionAtIndex(*p0, {0, -1}), nullptr); + EXPECT_EQ(TupleUtil::GetTupleInstructionAtIndex(*p0, {1, 2, 3}), nullptr); + EXPECT_EQ(TupleUtil::GetTupleInstructionAtIndex(*p0, {1, 2, -1}), nullptr); + // Valid index but no gte present. + EXPECT_EQ(TupleUtil::GetTupleInstructionAtIndex(*p0, {2}), nullptr); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/value_range_test.cc b/third_party/xla/xla/service/value_range_test.cc index d95f044c6c5ead..1f98c489edc373 100644 --- a/third_party/xla/xla/service/value_range_test.cc +++ b/third_party/xla/xla/service/value_range_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" #include "xla/tests/hlo_test_base.h" namespace xla { diff --git a/third_party/xla/xla/service/while_loop_all_reduce_code_motion.cc b/third_party/xla/xla/service/while_loop_all_reduce_code_motion.cc index c9b34c702efc32..aac2a22abfe0c5 100644 --- a/third_party/xla/xla/service/while_loop_all_reduce_code_motion.cc +++ b/third_party/xla/xla/service/while_loop_all_reduce_code_motion.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/types/span.h" +#include "xla/hlo/analysis/hlo_replication_analysis.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -36,7 +37,6 @@ limitations under the License. #include "xla/map_util.h" #include "xla/service/call_graph.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/hlo_replication_analysis.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -936,7 +936,7 @@ absl::StatusOr WhileLoopAllReduceCodeMotion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool is_changed = false; - bool run_next_pass = true; + // In case of MPMD, all-reduces might be cross-module and should preserve // their channel ID. Do not move all-reduces in this case since the channel // ID might be changed. @@ -965,96 +965,95 @@ absl::StatusOr WhileLoopAllReduceCodeMotion::Run( // loop. We recursively sink the all-reduce through nested while loops if // applicable by repeating this process. uint32_t count_all_reduce = 0, count_reduce_scatter = 0; - while (run_next_pass) { - run_next_pass = false; - std::unique_ptr call_graph = CallGraph::Build(module); + std::unique_ptr call_graph = CallGraph::Build(module); + // We process all callees of a computation before processing the computation, + // so that when we process a computation, the all-reduce instructions that + // need to be hoisted to the computation from its callees have been hoisted. + for (HloComputation* computation : + module->MakeComputationPostOrder(execution_threads)) { // A computation could be the while body of multiple while instructions, // so we start from the computation and find all of its callers that is a // kWhile if there is any. - for (HloComputation* computation : - module->computations(execution_threads)) { - std::vector computation_callers = - call_graph->GetComputationCallers(computation); - std::vector while_caller_instructions; - for (HloInstruction* caller_instruction : computation_callers) { - // For simplicity, we only support while instructions whose shape is - // tuple. - if (caller_instruction->opcode() == HloOpcode::kWhile && - caller_instruction->shape().IsTuple() && - caller_instruction->while_body() == computation) { - while_caller_instructions.push_back(caller_instruction); - } - } - // Skip to next computation if this computation is not the while body of - // any while instruction. - if (while_caller_instructions.empty()) { - continue; + std::vector computation_callers = + call_graph->GetComputationCallers(computation); + std::vector while_caller_instructions; + for (HloInstruction* caller_instruction : computation_callers) { + // For simplicity, we only support while instructions whose shape is + // tuple. + if (caller_instruction->opcode() == HloOpcode::kWhile && + caller_instruction->shape().IsTuple() && + caller_instruction->while_body() == computation) { + while_caller_instructions.push_back(caller_instruction); } - std::vector while_body_all_reduces; - for (HloInstruction* while_body_instruction : - computation->MakeInstructionPostOrder()) { - HloOpcode op = while_body_instruction->opcode(); - const bool is_candidate = - (op == HloOpcode::kAllReduce) || - (enable_reduce_scatter_ && op == HloOpcode::kReduceScatter); - if (!is_candidate) { - continue; - } - auto* all_reduce_instruction = - Cast(while_body_instruction); - if (all_reduce_instruction->constrain_layout()) { - return false; - } else { - while_body_all_reduces.push_back(all_reduce_instruction); - } - } - HloInstructionMap> - all_reduce_to_accumulations; - for (HloAllReduceInstructionBase* all_reduce : while_body_all_reduces) { - auto movable_all_reduce_context = IsAllReduceMovable( - all_reduce, computation, cross_replica_replication_analysis, - cross_partition_replication_analysis); - if (movable_all_reduce_context.is_movable) { - all_reduce_to_accumulations[all_reduce] = - std::move(movable_all_reduce_context.accumulation_contexts); - } - VLOG(3) << "WhileLoopAllReduceCodeMotion, all-reduce: " - << all_reduce->ToString() - << " is_movable: " << movable_all_reduce_context.is_movable - << " while loop: " << while_caller_instructions.front()->name() - << " num_accumulations: " - << (movable_all_reduce_context.is_movable - ? all_reduce_to_accumulations[all_reduce].size() - : 0); - } - if (all_reduce_to_accumulations.empty()) { + } + // Skip to next computation if this computation is not the while body of + // any while instruction. + if (while_caller_instructions.empty()) { + continue; + } + std::vector while_body_all_reduces; + for (HloInstruction* while_body_instruction : + computation->MakeInstructionPostOrder()) { + HloOpcode op = while_body_instruction->opcode(); + const bool is_candidate = + (op == HloOpcode::kAllReduce) || + (enable_reduce_scatter_ && op == HloOpcode::kReduceScatter); + if (!is_candidate) { continue; } - // For each while instruction calling this computation, create the - // corresponding all-reduces after the while loop. - for (HloInstruction* while_instruction : while_caller_instructions) { - TF_RETURN_IF_ERROR(AddSinkedAllReducesAndReplaceWhile( - while_instruction, all_reduce_to_accumulations)); - is_changed = true; - run_next_pass = true; + auto* all_reduce_instruction = + Cast(while_body_instruction); + if (all_reduce_instruction->constrain_layout()) { + return false; + } else { + while_body_all_reduces.push_back(all_reduce_instruction); } - // At last, remove the old all-reduce instructions in the while body. - for (const auto& all_reduce_accumulations_pair : - all_reduce_to_accumulations) { - HloInstruction* all_reduce = all_reduce_accumulations_pair.first; - if (all_reduce->opcode() == HloOpcode::kAllReduce) { - count_all_reduce++; - } else { - count_reduce_scatter++; - } - TF_RETURN_IF_ERROR(computation->ReplaceInstructionWithDifferentShape( - all_reduce, all_reduce->mutable_operand(0))); + } + HloInstructionMap> + all_reduce_to_accumulations; + for (HloAllReduceInstructionBase* all_reduce : while_body_all_reduces) { + auto movable_all_reduce_context = IsAllReduceMovable( + all_reduce, computation, cross_replica_replication_analysis, + cross_partition_replication_analysis); + if (movable_all_reduce_context.is_movable) { + all_reduce_to_accumulations[all_reduce] = + std::move(movable_all_reduce_context.accumulation_contexts); } - // Needs to rebuild the call graph or we could access removed - // instructions. - if (run_next_pass) { - break; + VLOG(3) << "WhileLoopAllReduceCodeMotion, all-reduce: " + << all_reduce->ToString() + << " is_movable: " << movable_all_reduce_context.is_movable + << " while loop: " << while_caller_instructions.front()->name() + << " num_accumulations: " + << (movable_all_reduce_context.is_movable + ? all_reduce_to_accumulations[all_reduce].size() + : 0); + } + if (all_reduce_to_accumulations.empty()) { + continue; + } + // For each while instruction calling this computation, create the + // corresponding all-reduces after the while loop. + for (HloInstruction* while_instruction : while_caller_instructions) { + TF_RETURN_IF_ERROR(AddSinkedAllReducesAndReplaceWhile( + while_instruction, all_reduce_to_accumulations)); + is_changed = true; + } + // At last, remove the old all-reduce instructions in the while body. + for (const auto& all_reduce_accumulations_pair : + all_reduce_to_accumulations) { + HloInstruction* all_reduce = all_reduce_accumulations_pair.first; + if (all_reduce->opcode() == HloOpcode::kAllReduce) { + count_all_reduce++; + } else { + count_reduce_scatter++; } + TF_RETURN_IF_ERROR(computation->ReplaceInstructionWithDifferentShape( + all_reduce, all_reduce->mutable_operand(0))); + } + // Needs to rebuild the call graph after we remove instructions to avoid + // accessing removed instructions. + if (!all_reduce_to_accumulations.empty()) { + call_graph = CallGraph::Build(module); } } VLOG(2) << "Hoisted " << count_all_reduce << " all-reduce and " diff --git a/third_party/xla/xla/service/while_loop_analysis.h b/third_party/xla/xla/service/while_loop_analysis.h index 5fe4038ab6d0bc..c6d95ac80db238 100644 --- a/third_party/xla/xla/service/while_loop_analysis.h +++ b/third_party/xla/xla/service/while_loop_analysis.h @@ -16,47 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ #define XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ -#include - -#include "xla/hlo/ir/hlo_instruction.h" - -namespace xla { - -// Returns the precise trip count of the loop if it's statically known, -// nullopt otherwise. -// -// max_brute_force_iters limits the number of steps that are evaluated while -// trying to brute force a loop trip count. trip counts larger than -// max_brute_force_iters may be returned if we can pattern-match the loop -// condition. -std::optional ComputeWhileLoopTripCount( - const HloInstruction *while_op, int64_t max_brute_force_iters = 128); - -// Returns an upper bound on the trip count of the loop if it's statically -// known, nullopt otherwise. -std::optional ComputeWhileLoopTripCountUpperBound( - const HloInstruction *while_op); - -// The below function identifies a subset of all possible auxiliary -// induction variables (AIV). Specifically, candidates are gtes, e.g., -// gte(param0, N) -std::vector GetAuxiliaryLoopInductionVars( - const HloInstruction *while_op); -// Returns the tuple index of the loop induction variable if there is such an -// induction variable detected. Otherwise returns nullopt. -std::optional GetLoopInductionVarTupleIdx( - const HloInstruction *while_op); - -// Checks the following conditions: -// - `i`, the induction varaiable, is initialized to a scalar constant K -// (namely, `indvar_init`), -// - the while condition does `i < N` or `i <= N` (where N is a know constant) -// - the while body does `i++`. -// If so, it's trivial to compute the loop bound as `N - k` or `N - k + 1`, -// respectively. -std::optional MatchTrivialLoopTripCount(const HloInstruction *while_op, - int64_t indvar_tuple_idx, - const Literal &indvar_init); -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/analysis/while_loop_analysis.h" #endif // XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_ diff --git a/third_party/xla/xla/service/while_loop_concat_code_motion.cc b/third_party/xla/xla/service/while_loop_concat_code_motion.cc index f63c7afd9e96bf..e1aa072d30ecc6 100644 --- a/third_party/xla/xla/service/while_loop_concat_code_motion.cc +++ b/third_party/xla/xla/service/while_loop_concat_code_motion.cc @@ -32,8 +32,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" -#include "xla/service/hlo_dce.h" -#include "xla/service/tuple_simplifier.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/service/while_loop_simplifier.h" #include "xla/shape_util.h" #include "xla/status_macros.h" diff --git a/third_party/xla/xla/service/while_loop_constant_sinking.cc b/third_party/xla/xla/service/while_loop_constant_sinking.cc index 83bd7f056ae6ae..49dfe2a5f7e4dc 100644 --- a/third_party/xla/xla/service/while_loop_constant_sinking.cc +++ b/third_party/xla/xla/service/while_loop_constant_sinking.cc @@ -15,11 +15,26 @@ limitations under the License. #include "xla/service/while_loop_constant_sinking.h" +#include +#include +#include +#include + #include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_clone_context.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/while_util.h" #include "xla/shape_util.h" #include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -65,7 +80,7 @@ HloInstruction* CloneHelper(const HloInstruction* instruction, } // namespace absl::StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop( - HloInstruction* while_instr) { + HloModule* module, HloInstruction* while_instr) { HloComputation* while_cond = while_instr->while_condition(); HloComputation* while_body = while_instr->while_body(); @@ -74,14 +89,16 @@ absl::StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop( return false; } - bool changed = false; - absl::flat_hash_map> conditional_gte_index_to_insts = WhileUtil::GetGTEsMapForWhileConditional(*while_cond); std::vector invariant_body_gtes = WhileUtil::GetInvariantGTEsForWhileBody(*while_body); + HloCloneContext body_clone_context(module); + HloCloneContext cond_clone_context(module); + HloComputation* body_clone = nullptr; + HloComputation* cond_clone = nullptr; for (HloInstruction* invariant_body_gte : invariant_body_gtes) { int64_t index = invariant_body_gte->tuple_index(); const HloInstruction& invariant_value = *init_value.operand(index); @@ -103,12 +120,18 @@ absl::StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop( // Sink into the while_body. // Should have at least one user that's not while_body_root. if (invariant_body_gte->user_count() > 1) { + if (!body_clone) { + body_clone = module->AddEmbeddedComputation( + while_body->Clone("sunk", &body_clone_context)); + while_instr->set_while_body(body_clone); + } HloInstruction* constant_instr = - CloneHelper(&invariant_value, while_body); + CloneHelper(&invariant_value, body_clone); TF_RETURN_IF_ERROR(ReplaceUsesWhileKeepingLoopInvariance( - invariant_body_gte, constant_instr, while_body->root_instruction(), + body_clone_context.FindInstruction(invariant_body_gte), + constant_instr, + body_clone_context.FindInstruction(while_body->root_instruction()), index)); - changed = true; } // Check if there is a corresponding GTE in while_conditional. @@ -120,16 +143,22 @@ absl::StatusOr WhileLoopConstantSinking::TrySinkingConstantsIntoWhileLoop( for (HloInstruction* invariant_cond_gte : it->second) { // Should have at least one user. if (invariant_cond_gte->user_count() > 0) { + if (!cond_clone) { + cond_clone = module->AddEmbeddedComputation( + while_cond->Clone("sunk", &cond_clone_context)); + while_instr->set_while_condition(cond_clone); + } HloInstruction* constant_instr = - CloneHelper(&invariant_value, while_cond); - TF_RETURN_IF_ERROR( - invariant_cond_gte->ReplaceAllUsesWith(constant_instr)); - changed = true; + CloneHelper(&invariant_value, cond_clone); + HloInstruction* cond_gte = + cond_clone_context.FindInstruction(invariant_cond_gte); + TF_RETURN_IF_ERROR(cond_gte->ReplaceAllUsesWith(constant_instr)); + TF_RETURN_IF_ERROR(cond_clone->RemoveInstruction(cond_gte)); } } } - return changed; + return body_clone || cond_clone; } absl::StatusOr WhileLoopConstantSinking::Run( @@ -140,37 +169,51 @@ absl::StatusOr WhileLoopConstantSinking::Run( bool changed = false; std::vector while_instrs; - for (auto* comp : module->MakeNonfusionComputations(execution_threads)) { - // Right now we don't particularly care about optimizing while-of-while - // patterns. If/When we do, we'll want to visit the outer while (while_0) - // before we visit the inner while (while_1): - // - // while_1_body(state) { - // val = gte(state, 0) // Loop invariant - // use(val) - // } - // - // while_0_body(state) { - // val = gte(state, 0) // Loop invariant - // while_1 = while(init=tuple(val, ...), body=while_1_body, ...) - // ... - // } - // - // main { - // while_0 = while(init=(constant, ...), body=while_0_body, ...) - // } - // - // This will let us sink the constant into the outer while first and then - // into the inner while in a single run of this pass. - absl::c_copy_if(comp->instructions(), std::back_inserter(while_instrs), - HloPredicateIsOp); - } - for (HloInstruction* while_instr : while_instrs) { - TF_ASSIGN_OR_RETURN(bool result, - TrySinkingConstantsIntoWhileLoop(while_instr)); - changed |= result; + // Visit computations in order, from outermost to innermost. + // We want to visit the outer while (while_0) before we visit the inner + // while (while_1): + // + // while_1_body(state) { + // val = gte(state, 0) // Loop invariant + // use(val) + // } + // + // while_0_body(state) { + // val = gte(state, 0) // Loop invariant + // while_1 = while(init=tuple(val, ...), body=while_1_body, ...) + // ... + // } + // + // main { + // while_0 = while(init=(constant, ...), body=while_0_body, ...) + // } + // + // This will let us sink the constant into the outer while first and then + // into the inner while in a single run of this pass. + std::stack agenda; + agenda.push(module->entry_computation()); + absl::flat_hash_set visited; + while (!agenda.empty()) { + HloComputation* comp = agenda.top(); + agenda.pop(); + if (!visited.insert(comp).second) { + continue; + } + for (auto* instr : comp->instructions()) { + // Sinking constants may change the called computations, so do that first + // if this is a while instruction. + if (instr->opcode() == HloOpcode::kWhile) { + TF_ASSIGN_OR_RETURN(bool result, + TrySinkingConstantsIntoWhileLoop(module, instr)); + changed |= result; + } + for (HloComputation* child : instr->called_computations()) { + agenda.push(child); + } + } } + TF_RETURN_IF_ERROR(module->RemoveUnusedComputations()); if (changed) { VLOG(2) << "HLO module after WhileLoopConstantSinking:"; diff --git a/third_party/xla/xla/service/while_loop_constant_sinking.h b/third_party/xla/xla/service/while_loop_constant_sinking.h index 8d1402ff72d29b..1ea8e4db0f1b18 100644 --- a/third_party/xla/xla/service/while_loop_constant_sinking.h +++ b/third_party/xla/xla/service/while_loop_constant_sinking.h @@ -66,7 +66,7 @@ class WhileLoopConstantSinking : public HloModulePass { private: absl::StatusOr TrySinkingConstantsIntoWhileLoop( - HloInstruction* while_instr); + HloModule* module, HloInstruction* while_instr); const bool sink_broadcast_of_constants_; const bool sink_only_scalar_constants_; diff --git a/third_party/xla/xla/service/while_loop_constant_sinking_test.cc b/third_party/xla/xla/service/while_loop_constant_sinking_test.cc index 3597686e9b9cce..2cfd69a9254e8b 100644 --- a/third_party/xla/xla/service/while_loop_constant_sinking_test.cc +++ b/third_party/xla/xla/service/while_loop_constant_sinking_test.cc @@ -15,9 +15,13 @@ limitations under the License. #include "xla/service/while_loop_constant_sinking.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" +#include "xla/literal_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -68,7 +72,7 @@ ENTRY entry { .Run(module.get())); ASSERT_TRUE(changed); - auto* while_body = module->GetComputationWithName("body"); + auto* while_body = module->GetComputationWithName("body.sunk"); EXPECT_THAT(while_body->root_instruction(), op::Tuple(op::Add(_, op::Constant()), _)); } @@ -115,7 +119,7 @@ ENTRY entry { .Run(module.get())); ASSERT_TRUE(changed); - auto* while_body = module->GetComputationWithName("body"); + auto* while_body = module->GetComputationWithName("body.sunk"); EXPECT_THAT(while_body->root_instruction(), op::Tuple(op::Add(_, op::Broadcast(op::Constant())), _)); } @@ -155,7 +159,7 @@ ENTRY entry { WhileLoopConstantSinking{}.Run(module.get())); ASSERT_TRUE(changed); - auto* while_body = module->GetComputationWithName("body"); + auto* while_body = module->GetComputationWithName("body.sunk"); EXPECT_THAT(while_body->root_instruction(), op::Tuple(op::Add(op::Constant(), op::Constant()), op::GetTupleElement(op::Parameter(0)), @@ -196,7 +200,7 @@ ENTRY entry { WhileLoopConstantSinking{}.Run(module.get())); ASSERT_TRUE(changed); - auto* while_body = module->GetComputationWithName("body"); + auto* while_body = module->GetComputationWithName("body.sunk"); EXPECT_THAT(while_body->root_instruction(), op::Tuple(op::GetTupleElement(op::Constant(), 0), op::GetTupleElement(op::Parameter(0)))); @@ -244,7 +248,7 @@ ENTRY entry { WhileLoopConstantSinking{}.Run(module.get())); ASSERT_TRUE(changed); - auto* while_body = module->GetComputationWithName("body"); + auto* while_body = module->GetComputationWithName("body.sunk"); EXPECT_THAT(while_body->root_instruction(), op::Tuple(op::Add(op::Constant(), ::testing::Not(op::Constant())), op::GetTupleElement(op::Parameter(0)), @@ -286,7 +290,7 @@ ENTRY entry { WhileLoopConstantSinking{}.Run(module.get())); ASSERT_TRUE(changed); - auto* while_body = module->GetComputationWithName("body"); + auto* while_body = module->GetComputationWithName("body.sunk"); EXPECT_THAT(while_body->root_instruction(), op::Tuple(op::GetTupleElement(), op::GetTupleElement(), op::GetTupleElement())); @@ -332,7 +336,7 @@ ENTRY entry { WhileLoopConstantSinking{}.Run(module.get())); ASSERT_TRUE(changed); - auto* while_condition = module->GetComputationWithName("condition"); + auto* while_condition = module->GetComputationWithName("condition.sunk"); EXPECT_THAT(while_condition->root_instruction(), op::Lt(_, op::Constant())); } @@ -372,7 +376,7 @@ ENTRY entry { WhileLoopConstantSinking{}.Run(module.get())); ASSERT_TRUE(changed); - auto* while_condition = module->GetComputationWithName("condition"); + auto* while_condition = module->GetComputationWithName("condition.sunk"); EXPECT_THAT(while_condition->root_instruction(), op::Lt(_, op::GetTupleElement(op::Constant()))); } @@ -415,7 +419,7 @@ ENTRY entry { WhileLoopConstantSinking{}.Run(module.get())); ASSERT_TRUE(changed); - auto* while_condition = module->GetComputationWithName("condition"); + auto* while_condition = module->GetComputationWithName("condition.sunk"); EXPECT_THAT(while_condition->root_instruction(), op::Lt(_, op::Constant())); for (const HloInstruction* inst : while_condition->instructions()) { if (inst->opcode() == HloOpcode::kConstant) { @@ -465,9 +469,61 @@ ENTRY entry { WhileLoopConstantSinking{}.Run(module.get())); ASSERT_TRUE(changed); - auto* while_condition = module->GetComputationWithName("condition"); + auto* while_condition = module->GetComputationWithName("condition.sunk"); EXPECT_THAT(while_condition->root_instruction(), op::And(op::Lt(_, op::Constant()), op::Lt(_, op::Constant()))); } + +TEST_F(WhileLoopConstantSinkingTest, SinkWithSharedBody) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +body { + p_body = (f32[2],f32[2]) parameter(0) + p_body.0 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=0 + p_body.1 = f32[2] get-tuple-element((f32[2],f32[2]) p_body), index=1 + + add.0 = f32[2] add(p_body.0, p_body.1) + ROOT root = (f32[2],f32[2]) tuple(add.0, p_body.1) +} + +condition { + p_cond = (f32[2],f32[2]) parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + const_0 = f32[2] constant({1, 2}) + const_1 = f32[2] constant({2, 1}) + while_init = (f32[2],f32[2]) tuple(const_0, const_1) + while = (f32[2],f32[2]) while(while_init), condition=condition, body=body + while_init2 = (f32[2],f32[2]) tuple(const_1, const_0) + while2 = (f32[2],f32[2]) while(while_init2), condition=condition, body=body + ROOT tuple = ((f32[2],f32[2]),(f32[2],f32[2])) tuple(while, while2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + WhileLoopConstantSinking(/*sink_broadcast_of_constants=*/false, + /*sink_only_scalar_constants=*/false) + .Run(module.get())); + ASSERT_TRUE(changed); + + auto* while_body = module->GetComputationWithName("body.sunk"); + EXPECT_THAT( + while_body->root_instruction(), + op::Tuple(op::Add(_, op::Constant(LiteralUtil::CreateR1({2, 1}))), + _)); + while_body = module->GetComputationWithName("body.sunk.1"); + EXPECT_THAT( + while_body->root_instruction(), + op::Tuple(op::Add(_, op::Constant(LiteralUtil::CreateR1({1, 2}))), + _)); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/while_loop_expensive_invariant_code_motion.cc b/third_party/xla/xla/service/while_loop_expensive_invariant_code_motion.cc index 399fd7c88c3333..284a5a2f192466 100644 --- a/third_party/xla/xla/service/while_loop_expensive_invariant_code_motion.cc +++ b/third_party/xla/xla/service/while_loop_expensive_invariant_code_motion.cc @@ -23,7 +23,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" -#include "xla/service/while_loop_analysis.h" +#include "xla/hlo/analysis/while_loop_analysis.h" #include "xla/service/while_util.h" #include "xla/shape_util.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/while_loop_expensive_invariant_code_motion_test.cc b/third_party/xla/xla/service/while_loop_expensive_invariant_code_motion_test.cc index 9aa7a15876f969..fdf3164a4db5a7 100644 --- a/third_party/xla/xla/service/while_loop_expensive_invariant_code_motion_test.cc +++ b/third_party/xla/xla/service/while_loop_expensive_invariant_code_motion_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/service/hlo_parser.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/while_loop_fusible_sinking.cc b/third_party/xla/xla/service/while_loop_fusible_sinking.cc index 07b49dbafe45d1..1fc04ea791c169 100644 --- a/third_party/xla/xla/service/while_loop_fusible_sinking.cc +++ b/third_party/xla/xla/service/while_loop_fusible_sinking.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/service/while_loop_fusible_sinking.h" #include +#include #include #include "absl/algorithm/container.h" @@ -24,11 +25,20 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/comparison_util.h" +#include "xla/hlo/analysis/while_loop_analysis.h" +#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/service/pattern_matcher.h" #include "xla/service/while_util.h" +#include "xla/shape_util.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -50,6 +60,232 @@ bool IsFusionCandidate(const HloInstruction* instr) { (instr->IsElementwise() || instr->opcode() == HloOpcode::kReshape || instr->opcode() == HloOpcode::kTranspose); } + +// For element-wise op 'instr' we have: +// forall index i in output shape: instr[i] = f(operand1[i], ...), where +// f is the elementwise operation. We can see that all the indices of the output +// shape is written to. +bool IsShapeCoveringWriteOnlyInstruction(HloInstruction* instr) { + // Clamp is tricky to handle, we bail. + if (instr->opcode() == HloOpcode::kClamp) { + return false; + } + return instr->IsElementwise(); +} + +// Updates the uses of the while loop with the equivalent tuple that retrieves +// the first original_operand_count elements of the while output. +absl::Status UpdateWhileUsesWithTuple(HloInstruction* while_instr, + int64_t original_operand_count) { + const std::vector users = while_instr->users(); + std::vector gtes(original_operand_count); + for (int64_t i = 0; i < gtes.size(); ++i) { + gtes[i] = while_instr->AddInstruction( + HloInstruction::CreateGetTupleElement(while_instr, i)); + } + HloInstruction* tuple = + while_instr->AddInstruction(HloInstruction::CreateTuple(gtes)); + if (while_instr->IsRoot()) { + while_instr->parent()->set_root_instruction(tuple); + } + if (!users.empty()) { + TF_RETURN_IF_ERROR(while_instr->ReplaceUsesWith(users, tuple)); + } + return absl::OkStatus(); +} + +// Appends the given new operand to while input and update loops computations +// and shape accordingly and returns the gte instruction within the body that +// represents the new operand. +absl::StatusOr AppendToWhileState( + HloInstruction* while_instr, HloInstruction* new_operand) { + // Update the while initial value + HloInstruction* while_input = while_instr->while_init(); + ShapeUtil::AppendShapeToTuple(new_operand->shape(), + while_input->mutable_shape()); + while_input->AppendOperand(new_operand); + // Update the body computation. + HloComputation* body = while_instr->while_body(); + *body->parameter_instruction(0)->mutable_shape() = while_input->shape(); + HloInstruction* new_gte = + body->AddInstruction(HloInstruction::CreateGetTupleElement( + body->parameter_instruction(0), while_input->operand_count() - 1)); + ShapeUtil::AppendShapeToTuple(new_gte->shape(), + body->root_instruction()->mutable_shape()); + body->root_instruction()->AppendOperand(new_gte); + // Update the condition computation. + HloComputation* condition = while_instr->while_condition(); + *condition->parameter_instruction(0)->mutable_shape() = while_input->shape(); + // Finalize the update by changing the uses of the while loop and updating its + // shape. + TF_RETURN_IF_ERROR( + UpdateWhileUsesWithTuple(while_instr, while_input->operand_count() - 1)); + *while_instr->mutable_shape() = while_input->shape(); + return new_gte; +} + +// Return the list of indices of the given while loop that are written to +// entirely in the loop body. +std::vector GetLoopShapeCoveringWriteIndices( + HloInstruction* while_instr) { + HloInstruction* tuple; + if (!Match(while_instr->while_init(), + match::Op(&tuple).WithOpcode(HloOpcode::kTuple).WithOneUse())) { + return {}; + } + + std::vector loop_indices; + for (int64_t tuple_index = 0; tuple_index < tuple->operand_count(); + ++tuple_index) { + HloInstruction* arg_operand = tuple->mutable_operand(tuple_index); + // We're looking for an argument that is a broadcast(constant) feeds a while + // loop. + if (!Match(arg_operand, match::Broadcast(match::ConstantScalar()))) { + continue; + } + HloInstruction* broadcast_gte = hlo_query::GetUniqueGteInstruction( + while_instr->while_body()->parameter_instruction(0), tuple_index); + if (broadcast_gte == nullptr) { + continue; + } + + // If the buffer is not written to entirely, we won't sink it. We might be + // able to support this case in the future, but for now we'll just skip it. + HloInstruction* root_buffer_value = + while_instr->while_body()->root_instruction()->mutable_operand( + tuple_index); + if (!IsShapeCoveringWriteOnlyInstruction(root_buffer_value)) { + continue; + } + loop_indices.push_back(tuple_index); + } + + return loop_indices; +} + +// Returns true if the given instruction is monotonic, i.e. it is either +// monotonically increasing or decreasing. This is not an exhaustive list of +// monotonic operations. +bool IsMonotonic(HloInstruction* instr) { + return instr->opcode() == HloOpcode::kAdd || + instr->opcode() == HloOpcode::kSubtract; +} + +// The idea is that certain constant-initialized buffers can be left as +// uninitialized if all the elements of the buffer are written to in the loop +// body. This way, we eliminate the need to initialize the buffer (with +// broadcast) in the critical path of the program. To summarize, the conditions +// to apply this optimization are: +// 1. The buffer is a constant-initialized buffer. +// 2. All the elements of the buffer are written to in the loop body. +// 3. The iteration variable of the loop is monotonically increasing or +// decreasing. +// The optimization is applied by creating a select between the initial value +// and the value in the body. The select is guarded by a predicate that checks +// if the loop iteration variable is equal to the first iteration value. +absl::StatusOr TryRewritingBroadcastAsAllocateBuffer( + HloInstruction* while_instr) { + std::optional induction_var_tuple_index = + GetLoopInductionVarTupleIdx(while_instr); + if (!induction_var_tuple_index.has_value()) { + return false; + } + HloComputation* while_body = while_instr->while_body(); + bool changed = false; + std::vector old_buffers; + std::vector loop_indices = + GetLoopShapeCoveringWriteIndices(while_instr); + if (loop_indices.empty()) { + return false; + } + HloInstruction* loop_iteration_variable_initial_value = + while_instr->while_init()->mutable_operand( + induction_var_tuple_index.value()); + // We only support integer loop iteration variables since these are the only + // ones that can be compared to get the first iteration value. + if (!ShapeUtil::ElementIsIntegral( + loop_iteration_variable_initial_value->shape())) { + return false; + } + + // Also we have to make sure that the induction variable is either + // monotonically increasing or decreasing since we rely on this fact to get + // the first iteration value. + HloInstruction* induction_var_update_fun = + while_instr->while_body()->root_instruction()->mutable_operand( + induction_var_tuple_index.value()); + if (!IsMonotonic(induction_var_update_fun)) { + return false; + } + + VLOG(3) << "Sinking fusible broadcast into " << while_instr->ToString(); + + // If we find any sinkable indices, we prepare the loop state by adding the + // initial value of the loop iteration variable to the loop state and use it + // inside the body to create a predicate that checks if the loop iteration + // variable is equal to the first iteration value. This is done only once + // regardless of the number of sinkable indices. + TF_ASSIGN_OR_RETURN( + HloInstruction * loop_iteration_variable_initial_value_gte, + AppendToWhileState(while_instr, loop_iteration_variable_initial_value)); + HloInstruction* iteration_var_gte = hlo_query::GetUniqueGteInstruction( + while_body->parameter_instruction(0), induction_var_tuple_index.value()); + if (iteration_var_gte == nullptr) { + return false; + } + HloInstruction* is_first_iteration_pred = + while_body->AddInstruction(HloInstruction::CreateCompare( + ShapeUtil::MakeShape(PRED, {}), iteration_var_gte, + loop_iteration_variable_initial_value_gte, + Comparison::Direction::kEq)); + for (int64_t loop_index : loop_indices) { + HloInstruction* buffer = + while_instr->while_init()->mutable_operand(loop_index); + VLOG(3) << "Sinking " << buffer->ToString() << " at index " << loop_index; + if (absl::c_find(old_buffers, buffer) == old_buffers.end()) { + old_buffers.push_back(buffer); + } + // It is possible that the same broadcast has multiple users, first clone + // the buffer and then replace this specific use with the clone. + HloInstruction* buffer_clone = buffer->AddInstruction(buffer->Clone()); + TF_RETURN_IF_ERROR(while_instr->while_init()->ReplaceOperandWith( + loop_index, buffer_clone)); + + // Replace the clone with a free AllocateBuffer. + HloInstruction* new_buffer = + while_instr->parent()->AddInstruction(HloInstruction::CreateCustomCall( + buffer_clone->shape(), {}, "AllocateBuffer")); + TF_RETURN_IF_ERROR(buffer_clone->ReplaceAllUsesWith(new_buffer)); + TF_RETURN_IF_ERROR(buffer_clone->parent()->RemoveInstruction(buffer_clone)); + // Broadcast the predicate to the shape of the buffer. + HloInstruction* is_first_iteration_pred_broadcast = + while_body->AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShapeWithDescendingLayout( + PRED, new_buffer->shape().dimensions()), + is_first_iteration_pred, {})); + HloInstruction* sunk_constant_broadcast = + while_body->AddInstruction(HloInstruction::CreateBroadcast( + new_buffer->shape(), + while_body->AddInstruction(buffer->mutable_operand(0)->Clone()), + {})); + // Create a select between the initial broadcasted value (in the first + // iteration of the loop) and the value in the body in the subsequent + // iterations and replace the use of the buffer in the body with the select. + HloInstruction* buffer_body_gte = hlo_query::GetUniqueGteInstruction( + while_body->parameter_instruction(0), loop_index); + HloInstruction* new_buffer_value = + while_body->AddInstruction(HloInstruction::CreateTernary( + new_buffer->shape(), HloOpcode::kSelect, + is_first_iteration_pred_broadcast, sunk_constant_broadcast, + buffer_body_gte)); + TF_RETURN_IF_ERROR(buffer_body_gte->ReplaceAllUsesWith(new_buffer_value)); + if (buffer->user_count() == 0) { + TF_RETURN_IF_ERROR(buffer->parent()->RemoveInstruction(buffer)); + } + changed = true; + } + return changed; +} } // namespace bool WhileLoopFusibleSinking::IsSinkableFusion(HloInstruction* while_operand) { @@ -243,13 +479,14 @@ absl::StatusOr WhileLoopFusibleSinking::TrySinkingFusiblesIntoWhileLoop( while_body->ReplaceInstruction(invariant_body_gte, cloned_fusion)); TF_RETURN_IF_ERROR(cloned_fusion->Defuse()); } - return changed; } absl::StatusOr WhileLoopFusibleSinking::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { + VLOG(5) << "Before WhileLoopFusibleSinking " << module->unique_id(); + XLA_VLOG_LINES(5, module->ToString()); call_counts_.clear(); bool changed = false; std::vector while_instrs; @@ -289,6 +526,17 @@ absl::StatusOr WhileLoopFusibleSinking::Run( TrySinkingFusiblesIntoWhileLoop(while_instr)); changed |= result; } + + for (auto* comp : module->MakeNonfusionComputations(execution_threads)) { + for (HloInstruction* instr : comp->instructions()) { + // TODO: b/358837872 - Handle loops with sharding. + if (Match(instr, match::While()) && !instr->has_sharding()) { + TF_ASSIGN_OR_RETURN(bool result, + TryRewritingBroadcastAsAllocateBuffer(instr)); + changed |= result; + } + } + } return changed; } } // namespace xla diff --git a/third_party/xla/xla/service/while_loop_fusible_sinking.h b/third_party/xla/xla/service/while_loop_fusible_sinking.h index e1b38bd7c41531..6ba5b6da5d690e 100644 --- a/third_party/xla/xla/service/while_loop_fusible_sinking.h +++ b/third_party/xla/xla/service/while_loop_fusible_sinking.h @@ -27,9 +27,12 @@ limitations under the License. namespace xla { -// Sinks while loop invariant values that happen to be fusibles into the while -// loop body and conditional. This is probably not a win in isolation but may -// unlock further optimizations like fusible folding. +// Sinks values into the while loop body and conditional that fusibles. This is +// probably not a win in isolation but may unlock further optimizations like +// fusible folding. There are two categories: + +// 1. Sinks while loop invariant values into the while +// loop body and conditional. // // state = (..., fusible_graph, ...) // while (pred(state)) { @@ -51,6 +54,37 @@ namespace xla { // tuple trivially loop invariant. WhileLoopSimplifier will later get rid of // `v`. // +// 2. Sinks constant-initialized value, i.e., broadcast(constant) into the while +// body. The high level idea is that we don't want to leave any element of the +// buffer after loop execution as undefined. Therefore, all the elements of the +// buffer must be written to in the body. For element-wise operation 'instr' we +// have: +// forall index i in output shape: instr[i] = f(operand1[i], ...), where +// f is the elementwise operation. +// We can see that all the indices of the output shape is written to. These +// values can sink into the loop and fused later. +// +// state = (..., broadcast(constant), ...) +// while (pred(state)) { +// (..., v, ...) = state +// value = f(v) // f writes to the entire shape of v. +// state = (..., value, ...) +// } +// +// => +// +// state = (..., allocate-buffer(), ...) +// while (pred(state)) { +// i = iteration_var +// (..., v, ...) = state +// new_v = select(i == 0, broadcast(constant), v) +// value = f(new_v) +// state = (..., value, ...) +// } +// +// This transformation replaces the broadcast with a free AllocateBuffer +// outside the while loop with the hope that the broadcast inside the loop +// will be fused. class WhileLoopFusibleSinking : public HloModulePass { public: WhileLoopFusibleSinking() = default; diff --git a/third_party/xla/xla/service/while_loop_fusible_sinking_test.cc b/third_party/xla/xla/service/while_loop_fusible_sinking_test.cc index fc457f290ff895..54dca7ac0d8a19 100644 --- a/third_party/xla/xla/service/while_loop_fusible_sinking_test.cc +++ b/third_party/xla/xla/service/while_loop_fusible_sinking_test.cc @@ -15,8 +15,15 @@ limitations under the License. #include "xla/service/while_loop_fusible_sinking.h" +#include +#include + +#include +#include +#include "absl/log/check.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/transforms/simplifiers/flatten_call_graph.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" @@ -154,5 +161,197 @@ ENTRY entry { EXPECT_FALSE(changed); } +TEST_F(WhileLoopFusibleSinkingTest, TestPlumbSingleBroadcast) { + const std::string hlo_string_before = R"( + HloModule test + + loop.body { + loop_var.1 = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) parameter(0) + get-tuple-element.1 = s32[]{:T(128)} get-tuple-element(loop_var.1), index=0 + get-tuple-element.2 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} get-tuple-element(loop_var.1), index=1 + get-tuple-element.3 = s32[4,3,5]{2,1,0} get-tuple-element(loop_var.1), index=2 + bitcast.12855 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} bitcast(get-tuple-element.3) + add.40974 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} add(get-tuple-element.2, bitcast.12855) + constant.1 = s32[]{:T(128)} constant(1) + idx = s32[]{:T(128)} add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) tuple(idx, add.40974, get-tuple-element.3) + } + + loop.condition { + loop_var.2 = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) parameter(0) + get-tuple-element.3 = s32[]{:T(128)} get-tuple-element(loop_var.2), index=0 + constant.2 = s32[]{:T(128)} constant(4) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + } + + ENTRY %main { + param.1 = s32[4,3,5]{2,1,0} iota(), iota_dimension=0 + zero = s32[]{:T(128)} constant(0) + zeros32 = s32[]{:T(128)} constant(0) + broadcast = s32[1,1,1,4,3,5]{5,4,3,2,1,0} broadcast(zeros32) + input = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) tuple(zero, broadcast, param.1) + ROOT while = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) while(input), condition=loop.condition, body=loop.body + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_before, + ParseAndReturnVerifiedModule(hlo_string_before)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopFusibleSinking{}.Run(module_before.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module_before.get(), "while"), + op::While(op::Tuple(_, op::CustomCall(), _, _))); +} + +TEST_F(WhileLoopFusibleSinkingTest, + TestPlumbSingleBroadcastNotFlattenCallGraph) { + const std::string hlo_string_before = R"( + HloModule test + + loop.body { + loop_var.1 = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) parameter(0) + get-tuple-element.1 = s32[]{:T(128)} get-tuple-element(loop_var.1), index=0 + get-tuple-element.2 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} get-tuple-element(loop_var.1), index=1 + get-tuple-element.3 = s32[4,3,5]{2,1,0} get-tuple-element(loop_var.1), index=2 + bitcast.12855 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} bitcast(get-tuple-element.3) + add.40974 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} add(get-tuple-element.2, bitcast.12855) + constant.1 = s32[]{:T(128)} constant(1) + idx = s32[]{:T(128)} add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) tuple(idx, add.40974, get-tuple-element.3) + } + + loop.condition { + loop_var.2 = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) parameter(0) + get-tuple-element.3 = s32[]{:T(128)} get-tuple-element(loop_var.2), index=0 + constant.2 = s32[]{:T(128)} constant(4) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + } + + ENTRY %main { + param.1 = s32[4,3,5]{2,1,0} iota(), iota_dimension=0 + zero = s32[]{:T(128)} constant(0) + zeros32 = s32[]{:T(128)} constant(0) + broadcast = s32[1,1,1,4,3,5]{5,4,3,2,1,0} broadcast(zeros32) + input = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) tuple(zero, broadcast, param.1) + while1 = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) while(input), condition=loop.condition, body=loop.body + input2 = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) tuple(zero, broadcast, param.1) + ROOT while2 = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) while(input2), condition=loop.condition, body=loop.body + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_before, + ParseAndReturnVerifiedModule(hlo_string_before)); + CHECK_OK(FlattenCallGraph{}.Run(module_before.get()).status()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopFusibleSinking{}.Run(module_before.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module_before.get(), "while1"), + op::While(op::Tuple(_, op::CustomCall(), _, _))); + EXPECT_THAT(FindInstruction(module_before.get(), "while2"), + op::While(op::Tuple(_, op::CustomCall(), _, _))); +} + +TEST_F(WhileLoopFusibleSinkingTest, + TestPlumbSingleBroadcastNoneZeroLoopIterationVar) { + const std::string hlo_string_before = R"( + HloModule cluster_6512412223095190558_f15n_0__.258 + + %wide._functionalize_body_1_const_0__.164.clone.clone.clone.clone (wide.arg_tuple.1: (s32[], f32[2])) -> (s32[], f32[2]) { + %wide.arg_tuple.1 = (s32[], f32[2]{0}) parameter(0) + %get-tuple-element.383 = s32[] get-tuple-element((s32[], f32[2]{0}) %wide.arg_tuple.1), index=0 + %constant.50..sunk.4 = s32[] constant(-1) + %add.48 = s32[] add(s32[] %get-tuple-element.383, s32[] %constant.50..sunk.4) + %get-tuple-element.384 = f32[2]{0} get-tuple-element((s32[], f32[2]{0}) %wide.arg_tuple.1), index=1 + %constant.11..sunk.4 = f32[] constant(1) + %broadcast.19 = f32[2]{0} broadcast(f32[] %constant.11..sunk.4), dimensions={} + %add.49 = f32[2]{0} add(f32[2]{0} %get-tuple-element.384, f32[2]{0} %broadcast.19) + ROOT %tuple.55 = (s32[], f32[2]{0}) tuple(s32[] %add.48, f32[2]{0} %add.49) + } + + %wide.cond_wrapper.236.clone.clone.clone.clone (wide.inputs.1: (s32[], f32[2])) -> pred[] { + %wide.inputs.1 = (s32[], f32[2]{0}) parameter(0) + %get-tuple-element.382 = s32[] get-tuple-element((s32[], f32[2]{0}) %wide.inputs.1), index=0 + %constant.66 = s32[] constant(1) + ROOT %compare.10 = pred[] compare(s32[] %get-tuple-element.382, s32[] %constant.66), direction=GE + } + + %_functionalize_body_0_const_0__.40.clone.clone.clone.clone.clone.clone.clone (arg_tuple.9: (s32[])) -> (s32[]) { + %arg_tuple.9 = (s32[]) parameter(0) + %get-tuple-element.409 = s32[] get-tuple-element((s32[]) %arg_tuple.9), index=0 + %constant.71 = s32[] constant(1) + %add.57 = s32[] add(s32[] %get-tuple-element.409, s32[] %constant.71) + ROOT %tuple.61 = (s32[]) tuple(s32[] %add.57) + } + + %cond_wrapper.120.clone.clone.clone.clone.clone.clone (inputs.7: (s32[])) -> pred[] { + %inputs.7 = (s32[]) parameter(0) + %get-tuple-element.408 = s32[] get-tuple-element((s32[]) %inputs.7), index=0 + %constant.70 = s32[] constant(10) + ROOT %compare.12 = pred[] compare(s32[] %get-tuple-element.408, s32[] %constant.70), direction=LT + } + + ENTRY %cluster_6512412223095190558_f15n_0__.258{ + %arg_tuple.1 = () parameter(0) + %constant.24 = s32[] constant(0) + %tuple.60 = (s32[]) tuple(s32[] %constant.24) + %while.10 = (s32[]) while((s32[]) %tuple.60), condition=%cond_wrapper.120.clone.clone.clone.clone.clone.clone, body=%_functionalize_body_0_const_0__.40.clone.clone.clone.clone.clone.clone.clone + %get-tuple-element.380 = s32[] get-tuple-element((s32[]) %while.10), index=0 + %constant.9 = f32[] constant(0) + %broadcast.10 = f32[2]{0} broadcast(f32[] %constant.9), dimensions={} + %tuple.54 = (s32[], f32[2]{0}) tuple(s32[] %get-tuple-element.380, f32[2]{0} %broadcast.10) + ROOT %while.8 = (s32[], f32[2]{0}) while((s32[], f32[2]{0}) %tuple.54), condition=%wide.cond_wrapper.236.clone.clone.clone.clone, body=%wide._functionalize_body_1_const_0__.164.clone.clone.clone.clone + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_before, + ParseAndReturnVerifiedModule(hlo_string_before)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopFusibleSinking{}.Run(module_before.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module_before.get(), "while.8"), + op::While(op::Tuple(_, op::CustomCall(), _))); +} + +TEST_F(WhileLoopFusibleSinkingTest, TestPlumbMultipleBroadcast) { + const std::string hlo_string_before = R"( + HloModule test + + loop.body { + loop_var.1 = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) parameter(0) + get-tuple-element.1 = s32[]{:T(128)} get-tuple-element(loop_var.1), index=0 + get-tuple-element.2 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} get-tuple-element(loop_var.1), index=1 + get-tuple-element.4 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} get-tuple-element(loop_var.1), index=2 + get-tuple-element.3 = s32[4,3,5]{2,1,0} get-tuple-element(loop_var.1), index=3 + bitcast.12855 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} bitcast(get-tuple-element.3) + add.40974 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} add(get-tuple-element.2, bitcast.12855) + add.1 = s32[1,1,1,4,3,5]{5,4,3,2,1,0} add(get-tuple-element.4, add.40974) + constant.1 = s32[]{:T(128)} constant(1) + idx = s32[]{:T(128)} add(get-tuple-element.1, constant.1) + ROOT tuple = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) tuple(idx, add.40974, add.1, get-tuple-element.3) + } + + loop.condition { + loop_var.2 = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) parameter(0) + get-tuple-element.3 = s32[]{:T(128)} get-tuple-element(loop_var.2), index=0 + constant.2 = s32[]{:T(128)} constant(4) + ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT + } + + ENTRY %main { + param.1 = s32[4,3,5]{2,1,0} iota(), iota_dimension=0 + zero = s32[]{:T(128)} constant(0) + zeros32 = s32[]{:T(128)} constant(0) + broadcast = s32[1,1,1,4,3,5]{5,4,3,2,1,0} broadcast(zeros32) + input = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) tuple(zero, broadcast, broadcast, param.1) + ROOT while = (s32[]{:T(128)}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[1,1,1,4,3,5]{5,4,3,2,1,0}, s32[4,3,5]{2,1,0}) while(input), condition=loop.condition, body=loop.body + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module_before, + ParseAndReturnVerifiedModule(hlo_string_before)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + WhileLoopFusibleSinking{}.Run(module_before.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + FindInstruction(module_before.get(), "while"), + op::While(op::Tuple(_, op::CustomCall(), op::CustomCall(), _, _))); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/while_loop_invariant_code_motion.cc b/third_party/xla/xla/service/while_loop_invariant_code_motion.cc index ed44547af3fca4..2a9033f0ff9b03 100644 --- a/third_party/xla/xla/service/while_loop_invariant_code_motion.cc +++ b/third_party/xla/xla/service/while_loop_invariant_code_motion.cc @@ -27,13 +27,13 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/strings/string_view.h" +#include "xla/hlo/analysis/while_loop_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/map_util.h" #include "xla/service/compile_time_cap.h" -#include "xla/service/hlo_dce.h" -#include "xla/service/while_loop_analysis.h" #include "xla/service/while_util.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -232,6 +232,7 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody( } if (instruction->HasSideEffect() || + instruction->opcode() == HloOpcode::kAfterAll || instruction->opcode() == HloOpcode::kParameter || !instruction->control_predecessors().empty() || !instruction->control_successors().empty()) { diff --git a/third_party/xla/xla/service/while_loop_invariant_code_motion_test.cc b/third_party/xla/xla/service/while_loop_invariant_code_motion_test.cc index 7d311df3546e65..eadb19462118f4 100644 --- a/third_party/xla/xla/service/while_loop_invariant_code_motion_test.cc +++ b/third_party/xla/xla/service/while_loop_invariant_code_motion_test.cc @@ -21,9 +21,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal_util.h" -#include "xla/service/hlo_parser.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" diff --git a/third_party/xla/xla/service/while_loop_pipeline_unroller.cc b/third_party/xla/xla/service/while_loop_pipeline_unroller.cc new file mode 100644 index 00000000000000..19f74c72834d7c --- /dev/null +++ b/third_party/xla/xla/service/while_loop_pipeline_unroller.cc @@ -0,0 +1,212 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/while_loop_pipeline_unroller.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/transforms/simplifiers/flatten_call_graph.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/service/call_inliner.h" +#include "xla/service/while_util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +/*static*/ +int64_t WhileLoopPipelineUnroller::ComputeWhileLoopPipelineDepth( + const HloInstruction& while_instruction) { + CHECK_EQ(while_instruction.opcode(), HloOpcode::kWhile); + const HloComputation* while_body = while_instruction.while_body(); + + // Look for pattern param -> gte -> root, where indices in the param and + // root tuples are mismatching. + absl::flat_hash_map loop_permutations; + HloInstruction* while_param = while_body->parameter_instruction(0); + HloInstruction* while_root = while_body->root_instruction(); + CHECK_EQ(while_root->opcode(), HloOpcode::kTuple) + << "While Instruction has not been canonicalized to have a tuple shape"; + for (int64_t output_index = 0; output_index < while_root->operand_count(); + ++output_index) { + const HloInstruction* operand = while_root->operand(output_index); + if (operand->opcode() == HloOpcode::kGetTupleElement && + operand->operand(0) == while_param) { + int64_t input_index = operand->tuple_index(); + if (input_index != output_index) { + // Don't try to analyze loops with complicated permutation patterns. + if (loop_permutations.contains(input_index)) { + return 1; + } + loop_permutations.emplace(input_index, output_index); + } + } + } + + // Find all indices at which the pipelined chains start from. + std::vector start_indices; + absl::flat_hash_set output_indices; + for (auto&& [_, output_index] : loop_permutations) { + output_indices.insert(output_index); + } + for (auto&& [input_index, _] : loop_permutations) { + if (!output_indices.contains(input_index)) { + start_indices.push_back(input_index); + } + } + + // Find all pipelining chains. + std::vector> pipelined_chains; + for (int64_t start_index : start_indices) { + std::stack>> stack; + stack.push({start_index, {start_index}}); + while (!stack.empty()) { + auto [current_index, current_chain] = stack.top(); + stack.pop(); + if (!loop_permutations.contains(current_index)) { + pipelined_chains.push_back(std::move(current_chain)); + } else { + int64_t next_index = loop_permutations[current_index]; + current_chain.push_back(next_index); + stack.emplace(next_index, std::move(current_chain)); + } + } + } + + // Compute the pipeline depth of the loop body. + // https://en.wikipedia.org/wiki/Permutation#Order_of_a_permutation + int64_t pipeline_depth = 1; + for (auto&& pipelined_chain : pipelined_chains) { + pipeline_depth = + std::lcm(pipelined_chain.size() + 1, pipeline_depth); + } + + return pipeline_depth; +} + +absl::StatusOr WhileLoopPipelineUnroller::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + std::vector> while_instructions; + for (HloComputation* computation : + module->MakeNonfusionComputations(execution_threads)) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kWhile) { + int64_t pipeline_depth = ComputeWhileLoopPipelineDepth(*instruction); + if (pipeline_depth > 1) { + // The pipeline depth is our unroll factor. + while_instructions.emplace_back(instruction, pipeline_depth); + } + } + } + } + + std::vector original_roots; + for (auto&& [while_instruction, unroll_factor] : while_instructions) { + HloComputation* body = while_instruction->while_body(); + HloComputation* condition = while_instruction->while_condition(); + + // Generate the unrolled loop body. This will call the original body + // unroll_factor times. + HloComputation::Builder b( + absl::StrFormat("%s.unrolled_%dx", body->name(), unroll_factor)); + HloInstruction* input_tuple = + b.AddInstruction(HloInstruction::CreateParameter( + 0, while_instruction->shape(), "input_tuple")); + HloComputation* unrolled_body = module->AddEmbeddedComputation(b.Build()); + for (int64_t step = 0; step < unroll_factor; ++step) { + HloComputation* loop_step = module->AddEmbeddedComputation(body->Clone( + absl::StrFormat("unrolled_%dx_step_%d", unroll_factor, step))); + input_tuple = unrolled_body->AddInstruction(HloInstruction::CreateCall( + while_instruction->shape(), {input_tuple}, loop_step)); + TF_ASSIGN_OR_RETURN(auto inline_map, CallInliner::Inline(input_tuple)); + // Find the original bodies root after inlining. This is the inputs for + // the next (unrolled) loop iteration. + input_tuple = inline_map[loop_step->root_instruction()]; + original_roots.push_back(input_tuple); + } + // The final original root is now the root of the unrolled loop. + HloInstruction* unrolled_root = original_roots.back(); + original_roots.pop_back(); + unrolled_body->set_root_instruction(unrolled_root); + + // We need the unrolled loop and the remainder (original) loop to execute + // a combined number of steps equal to the unroll factor. Since the unrolled + // loop on each iteration executes unroll_factor steps, we split the + // work by having the unrolled loop execute num_steps // unroll_factor + // times, and then the remainder loop will execute num_steps % unroll_factor + // times. This can be guaranteed by using the original condition for the + // unrolled loop, but reducing its trip count by (unroll_factor - 1), + // accounting for the original body execution. + HloComputation* unrolled_condition = module->AddEmbeddedComputation( + condition->Clone(absl::StrFormat("unrolled_%dx", unroll_factor))); + // We don't set the unrolled body right away, as it is non-trivial for + // IncrementWhileLoopTripCount to find the trip count variable inside the + // unrolled version. + HloInstruction* unrolled_while_instruction = + while_instruction->parent()->AddInstruction(HloInstruction::CreateWhile( + while_instruction->shape(), unrolled_condition, body, + while_instruction->mutable_operand(0))); + TF_RETURN_IF_ERROR(WhileUtil::IncrementWhileLoopTripCount( + *unrolled_while_instruction, -(unroll_factor - 1))); + unrolled_while_instruction->set_while_body(unrolled_body); + + TF_RETURN_IF_ERROR( + while_instruction->ReplaceOperandWith(0, unrolled_while_instruction)); + } + + const bool changed = !while_instructions.empty(); + if (changed) { + // We're unrolling the loop to remove aliasing copies, not to find better + // scheduling opportunities. + // Create a global barrier at the boundary of steps, so sideffecting ops + // don't get moved to neighbouring steps during scheduling. + // This creates a soft guarantee that the unrolled loop steps will have + // an identical schedule to their original counterpart. + for (HloInstruction* original_root : original_roots) { + HloInstruction* sideeffect_barrier = original_root->AddInstruction( + HloInstruction::CreateAfterAll(original_root->operands())); + for (HloInstruction* user : original_root->users()) { + if (user->shape().IsToken()) { + TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(sideeffect_barrier)); + } + } + } + + // When we cloned the loop body for each unrolled step, we didn't + // recursively clone all the nested computations. FCG will take care of this + // for us. + FlattenCallGraph fcg; + TF_RETURN_IF_ERROR(fcg.Run(module).status()); + HloDCE dce; + TF_RETURN_IF_ERROR(dce.Run(module).status()); + } + + return changed; +} +} // namespace xla diff --git a/third_party/xla/xla/service/while_loop_pipeline_unroller.h b/third_party/xla/xla/service/while_loop_pipeline_unroller.h new file mode 100644 index 00000000000000..4e5318f8f90385 --- /dev/null +++ b/third_party/xla/xla/service/while_loop_pipeline_unroller.h @@ -0,0 +1,60 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_WHILE_LOOP_PIPELINE_UNROLLER_H_ +#define XLA_SERVICE_WHILE_LOOP_PIPELINE_UNROLLER_H_ + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { +// Pipelined loops have inherent aliasing interference in them, due to loop +// inputs shifting positions across iterations. This results in copy insertion +// adding copies for each pipelined input. In some cases extra copies on top of +// this are needed to properly sequence all the mandatory aliasing copies. +// +// It is not necessary to insert copies to resolve interference in this case. +// The loop inputs, despite directly carried out as loop outputs, still have +// finite lifetimes across a certain amount of loop iterations. If the loop was +// unrolled just enough times to have the lifetimes of its inputs end before the +// outputs would be materialized, this would implicitly remove any sort of +// interference. The drawback of this approach is that it can in some cases +// drastically increase compile times due to linearly increasing graph size. +class WhileLoopPipelineUnroller : public HloModulePass { + public: + std::string_view name() const override { + return "while_loop_pipeline_unroller"; + } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + // The pipeline depth of a while loop is the number of loop iterations that + // pipelined loop inputs live throughout. This is used to determine how many + // times to unroll the loop in order to remove aliasing interference. + static int64_t ComputeWhileLoopPipelineDepth( + const HloInstruction& while_instruction); +}; +} // namespace xla + +#endif // XLA_SERVICE_WHILE_LOOP_PIPELINE_UNROLLER_H_ diff --git a/third_party/xla/xla/service/while_loop_pipeline_unroller_test.cc b/third_party/xla/xla/service/while_loop_pipeline_unroller_test.cc new file mode 100644 index 00000000000000..f8618a304514c6 --- /dev/null +++ b/third_party/xla/xla/service/while_loop_pipeline_unroller_test.cc @@ -0,0 +1,188 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/while_loop_pipeline_unroller.h" + +#include +#include + +#include +#include "absl/container/inlined_vector.h" +#include "xla/hlo/analysis/hlo_ordering.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/copy_insertion.h" +#include "xla/test_helpers.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace { + +// Copied from xla/service/copy_insertion_test.cc +int64_t CountCopies(const HloComputation& computation) { + int64_t count = 0; + for (const auto& instruction : computation.instructions()) { + if (instruction->opcode() == HloOpcode::kCopy) { + count++; + } + } + return count; +} + +class WhileLoopPipelineUnrollerTest : public HloTestBase { + protected: + WhileLoopPipelineUnrollerTest() = default; +}; + +TEST_F(WhileLoopPipelineUnrollerTest, PipelinedLoop) { + constexpr std::string_view hlo = R"( +HloModule main + +body { + input_tuple.0 = (s32[], s32[], s32[], s32[]) parameter(0) + arg.0 = get-tuple-element(input_tuple.0), index=0 + arg.1 = get-tuple-element(input_tuple.0), index=1 + arg.2 = get-tuple-element(input_tuple.0), index=2 + arg.3 = get-tuple-element(input_tuple.0), index=3 + + one.0 = s32[] constant(1) + out.0 = add(arg.0, one.0) + + add.0 = add(arg.3, one.0) + ROOT output_tuple.0 = tuple(arg.1, arg.2, out.0, add.0) +} + +condition { + input_tuple.0 = (s32[], s32[], s32[], s32[]) parameter(0) + arg.3 = get-tuple-element(input_tuple.0), index=3 + three.0 = s32[] constant(3) + ROOT pred.0 = compare(arg.3, three.0), direction=LT +} + +ENTRY main { + while_tuple.0 = (s32[], s32[], s32[], s32[]) parameter(0) + ROOT while.0 = (s32[], s32[], s32[], s32[]) while(while_tuple.0), body=body, condition=condition +} + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + WhileLoopPipelineUnroller wlpu; + ASSERT_IS_OK(wlpu.Run(module.get()).status()); + CopyInsertion copy_insertion(nullptr, + /*use_region_based_live_range_analysis=*/-1); + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + + const HloInstruction* original_loop = + FindInstruction(module.get(), "while.0"); + // The original loop should have 3 copies. + // arg.1 moves to index 0. + // arg.2 moves to index 1. + // out.0 moves to index 2. + EXPECT_EQ(CountCopies(*original_loop->while_body()), 3); + + const HloInstruction* unrolled_loop = original_loop->operand(0); + EXPECT_EQ(unrolled_loop->opcode(), HloOpcode::kWhile); + // There should be no copies inserted into the unrolled loop. + EXPECT_EQ(CountCopies(*unrolled_loop->while_body()), 0); +} + +TEST_F(WhileLoopPipelineUnrollerTest, PipelinedLoopWithInfeed) { + constexpr std::string_view hlo = R"( +HloModule main + +body { + input_tuple.0 = (s32[], s32[], s32[], token[], s32[]) parameter(0) + arg.0 = get-tuple-element(input_tuple.0), index=0 + arg.1 = get-tuple-element(input_tuple.0), index=1 + arg.2 = get-tuple-element(input_tuple.0), index=2 + arg.3 = get-tuple-element(input_tuple.0), index=3 + arg.4 = get-tuple-element(input_tuple.0), index=4 + + infeed.0 = (s32[], token[]) infeed(arg.3) + infeed_value.0 = get-tuple-element(infeed.0), index=0 + infeed_output_token.0 = get-tuple-element(infeed.0), index=1 + + out.0 = add(arg.0, arg.1) + + one.0 = s32[] constant(1) + add.0 = add(arg.4, one.0) + ROOT output_tuple.0 = tuple(out.0, arg.2, infeed_value.0, infeed_output_token.0, add.0) +} + +condition { + input_tuple.0 = (s32[], s32[], s32[], token[], s32[]) parameter(0) + arg.4 = get-tuple-element(input_tuple.0), index=4 + three.0 = s32[] constant(3) + ROOT pred.0 = compare(arg.4, three.0), direction=LT +} + +ENTRY main { + infeed_input_token.0 = after-all() + infeed.0 = (s32[], token[]) infeed(infeed_input_token.0) + infeed_value.0 = s32[] get-tuple-element(infeed.0), index=0 + infeed_output_token.0 = token[] get-tuple-element(infeed.0), index=1 + + infeed.1 = (s32[], token[]) infeed(infeed_output_token.0) + infeed_value.1 = s32[] get-tuple-element(infeed.1), index=0 + infeed_output_token.1 = token[] get-tuple-element(infeed.1), index=1 + + zero.0 = s32[] constant(0) + while_tuple.0 = tuple(zero.0, infeed_value.0, infeed_value.1, infeed_output_token.1, zero.0) + while.0 = (s32[], s32[], s32[], token[], s32[]) while(while_tuple.0), body=body, condition=condition + + ROOT root.0 = get-tuple-element(while.0), index=0 +} + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + WhileLoopPipelineUnroller wlpu; + ASSERT_IS_OK(wlpu.Run(module.get()).status()); + CopyInsertion copy_insertion(nullptr, + /*use_region_based_live_range_analysis=*/-1); + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + + const HloInstruction* original_loop = + FindInstruction(module.get(), "while.0"); + // The original loop should have 1 copy. + // arg.2 moves to index 1. + EXPECT_EQ(CountCopies(*original_loop->while_body()), 1); + + const HloInstruction* unrolled_loop = original_loop->operand(0); + EXPECT_EQ(unrolled_loop->opcode(), HloOpcode::kWhile); + // There should be no copies inserted into the unrolled loop. + EXPECT_EQ(CountCopies(*unrolled_loop->while_body()), 0); + + // All infeeds in the unrolled body need to be ordered with respect to each + // other. + absl::InlinedVector unrolled_infeeds; + for (HloInstruction* instruction : + unrolled_loop->while_body()->instructions()) { + if (instruction->opcode() == HloOpcode::kInfeed) { + unrolled_infeeds.push_back(instruction); + } + } + DependencyHloOrdering dlo(module.get()); + for (HloInstruction* lhs : unrolled_infeeds) { + for (HloInstruction* rhs : unrolled_infeeds) { + if (lhs != rhs) { + EXPECT_TRUE(dlo.ExecutesBefore(lhs, rhs) || + dlo.ExecutesBefore(rhs, lhs)); + } + } + } +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/service/while_loop_simplifier.cc b/third_party/xla/xla/service/while_loop_simplifier.cc index d6baca8c4e3823..a11d826040e7f8 100644 --- a/third_party/xla/xla/service/while_loop_simplifier.cc +++ b/third_party/xla/xla/service/while_loop_simplifier.cc @@ -28,19 +28,19 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/comparison_util.h" +#include "xla/hlo/analysis/while_loop_analysis.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/literal_util.h" #include "xla/primitive_util.h" #include "xla/service/call_inliner.h" #include "xla/service/hlo_creation_utils.h" -#include "xla/service/hlo_dce.h" #include "xla/service/pattern_matcher.h" -#include "xla/service/while_loop_analysis.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" diff --git a/third_party/xla/xla/service/while_loop_simplifier_test.cc b/third_party/xla/xla/service/while_loop_simplifier_test.cc index 494271c2023ceb..273f832ec5ce76 100644 --- a/third_party/xla/xla/service/while_loop_simplifier_test.cc +++ b/third_party/xla/xla/service/while_loop_simplifier_test.cc @@ -24,11 +24,11 @@ limitations under the License. #include "absl/strings/str_replace.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal_util.h" -#include "xla/service/hlo_dce.h" -#include "xla/service/hlo_parser.h" -#include "xla/service/tuple_simplifier.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" diff --git a/third_party/xla/xla/service/while_loop_trip_count_annotator.h b/third_party/xla/xla/service/while_loop_trip_count_annotator.h index d4185e558f958e..ee7377423b8b02 100644 --- a/third_party/xla/xla/service/while_loop_trip_count_annotator.h +++ b/third_party/xla/xla/service/while_loop_trip_count_annotator.h @@ -16,38 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_WHILE_LOOP_TRIP_COUNT_ANNOTATOR_H_ #define XLA_SERVICE_WHILE_LOOP_TRIP_COUNT_ANNOTATOR_H_ -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" - -namespace xla { - -// Pass that annotates `while` loops with known trip counts. -// -// The annotation is stored as a backend-config on the while loop node. -// -// This pass should run after all passes that might semantically modify a while -// loop, e.g. by unrolling it. Otherwise, a loop could end up with a -// backend-config that doesn't match its true trip-count. -// -// This pass does some pattern-matching on loop bodies and conditions, so it -// should run after most HLO simplifications and before fusion and layout -// assignment, which make pattern matching much more difficult by e.g. -// introducing `copy` nodes. -class WhileLoopTripCountAnnotator : public HloModulePass { - public: - ~WhileLoopTripCountAnnotator() override {} - absl::string_view name() const override { - return "while-loop-trip-count-annotator"; - } - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/while_loop_trip_count_annotator.h" #endif // XLA_SERVICE_WHILE_LOOP_TRIP_COUNT_ANNOTATOR_H_ diff --git a/third_party/xla/xla/service/while_loop_unroller.cc b/third_party/xla/xla/service/while_loop_unroller.cc index 053c20aa8aa231..d2731d22f61575 100644 --- a/third_party/xla/xla/service/while_loop_unroller.cc +++ b/third_party/xla/xla/service/while_loop_unroller.cc @@ -32,24 +32,24 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "xla/comparison_util.h" +#include "xla/hlo/analysis/while_loop_analysis.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/pass/hlo_pass_fix.h" +#include "xla/hlo/transforms/simplifiers/flatten_call_graph.h" +#include "xla/hlo/transforms/simplifiers/tuple_simplifier.h" #include "xla/hlo/utils/hlo_query.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/overflow_util.h" #include "xla/service/call_inliner.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/flatten_call_graph.h" #include "xla/service/hlo_creation_utils.h" #include "xla/service/hlo_cse.h" #include "xla/service/pattern_matcher.h" -#include "xla/service/tuple_simplifier.h" -#include "xla/service/while_loop_analysis.h" #include "xla/service/while_loop_constant_sinking.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -514,13 +514,31 @@ std::optional MatchShapeCoveringDynamicIndexInstruction( return std::nullopt; } - // The shape's broadcast_dim must be exactly equal to the loop trip count. if (operand->shape().dimensions(dynamic_index) != config.trip_count) { - VLOG(3) << "The shape's broadcast_dim must be exactly equal to the loop " - "trip count."; + VLOG(3) << "The dynamic_index dimension size of the operand must be equal " + "to the loop trip count."; return std::nullopt; } + if (opcode == HloOpcode::kDynamicSlice) { + const Shape& result_shape = instr->shape(); + if (result_shape.dimensions(dynamic_index) != 1) { + VLOG(3) << "The slice size on the dynamic_index dimension must be 1."; + return std::nullopt; + } + + const Shape& operand_shape = operand->shape(); + CHECK_EQ(result_shape.dimensions_size(), operand_shape.dimensions_size()); + for (int64_t i = 0; i < result_shape.dimensions_size(); ++i) { + if (i != dynamic_index && + result_shape.dimensions(i) != operand_shape.dimensions(i)) { + VLOG(3) << "The slice sizes must match the operand-shape on " + "non-dynamic-index dimensions."; + return std::nullopt; + } + } + } + return dynamic_index; } diff --git a/third_party/xla/xla/service/while_loop_unroller.h b/third_party/xla/xla/service/while_loop_unroller.h index face63336372a1..619c11697435bc 100644 --- a/third_party/xla/xla/service/while_loop_unroller.h +++ b/third_party/xla/xla/service/while_loop_unroller.h @@ -65,7 +65,10 @@ struct UnrollResult { // 1. All start indices must be constant zero except only a single dimension. // 2. The start index of that dimension should be equal to the enclosing loop // induction variable. -// 3. And, the size of that dimension must match the loop trip count. +// 3. The size of that dimension must match the loop trip count. +// 4. For dynamic-slice, the slice size for the induction variable dimension is +// 1, and the size of all other dimensions is the same as the shape of the +// input. // If so, it returns the dynamic index. std::optional MatchShapeCoveringDynamicIndexInstruction( const HloInstruction* instr, const HloInstruction* input, HloOpcode opcode, @@ -86,7 +89,7 @@ std::optional MatchEffectivelyStaticDynamicSliceInsideLoop( // // The trip count for loops is calculated based on // `MatchTrivialLoopTripCount` function in -// tensorflow/compiler/xla/service/while_loop_analysis.h` +// tensorflow/compiler/xla/hlo/analysis/while_loop_analysis.h` // // TODO(b/301472793): Add utility functions to unroll specific loops. class WhileLoopUnroller : public HloModulePass { diff --git a/third_party/xla/xla/service/while_loop_unroller_test.cc b/third_party/xla/xla/service/while_loop_unroller_test.cc index fa44905344f66d..952b6f5240a95f 100644 --- a/third_party/xla/xla/service/while_loop_unroller_test.cc +++ b/third_party/xla/xla/service/while_loop_unroller_test.cc @@ -177,7 +177,7 @@ WhileLoopUnrollerTest::MakeModuleWithNestedLoopBodyIndirectInc(int num_iters) { constant.3 = s32[] constant(0) tuple.1 = (s32[], s32[], s32[3]{0}) tuple(constant.3, constant.1, get-tuple-element.22) inner-while = (s32[], s32[], s32[3]{0}) while(tuple.1), condition= - SimpleLoop.condition, body=SimpleLoop.body + SimpleLoop.condition, body=SimpleLoop.body get-tuple-element.6 = s32[3]{0} get-tuple-element(inner-while), index=2 inc = s32[] add(get-tuple-element.1, get-tuple-element.2) ROOT tuple = (s32[], s32[], s32[3]{0}, s32[10]{0}) tuple(inc, get-tuple-element.2, get-tuple-element.6, output) @@ -269,22 +269,22 @@ std::unique_ptr WhileLoopUnrollerTest::MakeModuleWithSimpleLoopAllReduce(int num_iters) { std::string hlo_string_template = R"( HloModule SimpleLoop - + %reduction { %x = f32[] parameter(0) %y = f32[] parameter(1) ROOT %add = f32[] add(f32[] %x, f32[] %y) } - + SimpleLoop.body { loop_var.1 = (s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0) get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 get-tuple-element.2 = f32[1024, 1024] get-tuple-element(loop_var.1), index=1 get-tuple-element.3 = f32[1024, 1024] get-tuple-element(loop_var.1), index=2 - + %all-reduce = f32[1024, 1024] all-reduce(f32[1024, 1024] get-tuple-element.2), channel_id=1, replica_groups={{0}}, to_apply=%reduction %accumulation = f32[1024, 1024] add(f32[1024, 1024] %all-reduce, f32[1024, 1024] get-tuple-element.3) - + constant.1 = s32[] constant(1) add = s32[] add(get-tuple-element.1, constant.1) ROOT tuple = (s32[], f32[1024, 1024], f32[1024, 1024]) tuple(add, get-tuple-element.2, %accumulation) @@ -298,10 +298,10 @@ WhileLoopUnrollerTest::MakeModuleWithSimpleLoopAllReduce(int num_iters) { ENTRY SimpleLoop { %param.1 = f32[1024, 1024] parameter(0) constant.3 = s32[] constant(0) - + %accumulation_buffer_init = f32[] constant(0) %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={} - + tuple.1 = (s32[], f32[1024, 1024], f32[1024, 1024]) tuple(constant.3, %param.1, %accumulation_buffer) ROOT while = (s32[], f32[1024, 1024], f32[1024, 1024]) while(tuple.1), condition=SimpleLoop.condition, body=SimpleLoop.body } @@ -987,6 +987,49 @@ TEST_F(WhileLoopUnrollerTest, MatchShapeCoveringDS) { .has_value()); } +TEST_F(WhileLoopUnrollerTest, MatchShapeCoveringDSShapeMismatch) { + const std::string hlo_string = R"( + HloModule SimpleLoop + body { + param = (s32[]{:T(128)}, s32[3,10]{1,0}, s32[3,11]{1,0}) parameter(0) + idx = s32[]{:T(128)} get-tuple-element(param), index=0 + constant1 = s32[]{:T(128)} constant(1) + new-idx = s32[]{:T(128)} add(idx, constant1) + update = s32[3,10]{1,0} get-tuple-element(param), index=1 + input = s32[3,11]{1,0} get-tuple-element(param), index=2 + zero = s32[] constant(0) + slice = s32[1,10] dynamic-slice(input, idx, zero), dynamic_slice_sizes={1,10} + new-update = s32[3,10]{1,0} dynamic-update-slice(update, slice, idx, zero) + ROOT tuple = (s32[]{:T(128)}, s32[3,10]{1,0}, s32[3,11]{1,0}) tuple(new-idx, new-update, input) + } + condition { + param = (s32[]{:T(128)}, s32[3,10]{1,0}, s32[3,11]{1,0}) parameter(0) + idx = s32[] get-tuple-element(param), index=0 + constant3 = s32[]{:T(128)} constant(3) + ROOT less-than = pred[] compare(idx, constant3), direction=LT + } + ENTRY main { + constant0 = s32[]{:T(128)} constant(0) + init-update = s32[3,10]{1,0} constant({...}) + init-input = s32[3,11]{1,0} constant({...}) + init-while = (s32[]{:T(128)}, s32[3,10]{1,0}, s32[3,11]{1,0}) tuple(constant0, init-update, init-input) + ROOT while = (s32[]{:T(128)}, s32[3,10]{1,0}, s32[3,11]{1,0}) while(init-while), condition= + condition, body=body + } + )"; + + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + HloInstruction* loop = module->entry_computation()->root_instruction(); + auto config = WhileLoopUnroller::IsLoopUnrollable(loop); + EXPECT_TRUE(config.has_value()); + HloComputation* body = module->GetComputationWithName("body"); + HloInstruction* input = body->GetInstructionWithName("input"); + HloInstruction* instr = body->GetInstructionWithName("slice"); + EXPECT_FALSE(MatchShapeCoveringDynamicIndexInstruction( + instr, input, HloOpcode::kDynamicSlice, config.value()) + .has_value()); +} + TEST_F(WhileLoopUnrollerTest, MatchShapeCoveringDSNested) { std::string hlo_string_template = R"( HloModule SimpleLoop @@ -1127,7 +1170,7 @@ TEST_F(WhileLoopUnrollerTest, IsEffectivelyStaticDynamicSlice) { %dynamic-slice.static = s8[1,128,128] dynamic-slice(s8[6,128,128] %param_0.51117, static.p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} ROOT %bitcast.31250 = s8[128,128] bitcast(s8[1,128,128] %dynamic-slice.static) } - + %fused_computation.slice.2 (param_0.51117: s8[6,128,128], p1: s32[]) -> s8[128,128] { %param_0.51117 = s8[6,128,128] parameter(0) dynamic.p1 = s32[] parameter(1) @@ -1270,10 +1313,10 @@ TEST_F(WhileLoopUnrollerTest, SimpleLoopWithCustomCallNonTupleForRoot) { )"; auto m = ParseAndReturnVerifiedModule(hlo_string).value(); UnrollConfig config; - EXPECT_TRUE(WhileLoopUnroller(/*unroll_factor=*/-1, - /*wrap_in_trivial_loop=*/false, config) - .Run(m.get()) - .value()); + EXPECT_FALSE(WhileLoopUnroller(/*unroll_factor=*/-1, + /*wrap_in_trivial_loop=*/false, config) + .Run(m.get()) + .value()); } TEST_F(WhileLoopUnrollerTest, SimpleLoopWithCustomCall) { @@ -1306,10 +1349,10 @@ TEST_F(WhileLoopUnrollerTest, SimpleLoopWithCustomCall) { )"; auto m = ParseAndReturnVerifiedModule(hlo_string).value(); UnrollConfig config; - EXPECT_TRUE(WhileLoopUnroller(/*unroll_factor=*/-1, - /*wrap_in_trivial_loop=*/false, config) - .Run(m.get()) - .value()); + EXPECT_FALSE(WhileLoopUnroller(/*unroll_factor=*/-1, + /*wrap_in_trivial_loop=*/false, config) + .Run(m.get()) + .value()); } } // namespace diff --git a/third_party/xla/xla/service/while_util.cc b/third_party/xla/xla/service/while_util.cc index d7f85d5e4cb3cc..a2080f6f594756 100644 --- a/third_party/xla/xla/service/while_util.cc +++ b/third_party/xla/xla/service/while_util.cc @@ -27,6 +27,8 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" #include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "xla/comparison_util.h" @@ -37,6 +39,7 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/service/call_inliner.h" #include "xla/service/hlo_creation_utils.h" +#include "xla/service/pattern_matcher.h" #include "xla/service/tuple_util.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -365,4 +368,130 @@ WhileUtil::GetGTEsMapForWhileConditional( return result; } +/*static*/ +absl::Status WhileUtil::IncrementWhileLoopTripCount( + const HloInstruction& while_instruction, int32_t increment) { + CHECK_EQ(while_instruction.opcode(), HloOpcode::kWhile); + const HloComputation* while_body = while_instruction.while_body(); + HloComputation* while_conditional = while_instruction.while_condition(); + + HloInstruction* compare = while_conditional->root_instruction(); + if (compare->opcode() != HloOpcode::kCompare) { + return absl::InvalidArgumentError("While condition root is not a compare"); + } + HloInstruction* induction_var; + const HloInstruction* trip_count; + if ((compare->comparison_direction() == ComparisonDirection::kGt) || + (compare->comparison_direction() == ComparisonDirection::kGe)) { + induction_var = compare->mutable_operand(1); + trip_count = compare->mutable_operand(0); + } else if ((compare->comparison_direction() == ComparisonDirection::kLt) || + (compare->comparison_direction() == ComparisonDirection::kLe)) { + induction_var = compare->mutable_operand(0); + trip_count = compare->mutable_operand(1); + } else { + return absl::InvalidArgumentError("Unhandled comparison direction"); + } + + // Verify that the induction variable flows through directly inside the loop + // condition. + if (induction_var->user_count() > 1) { + return absl::InvalidArgumentError( + "Loop induction variable has multiple users"); + } + if (induction_var->opcode() != HloOpcode::kGetTupleElement && + induction_var->operand(0) != + while_conditional->parameter_instruction(0)) { + return absl::InvalidArgumentError( + "Loop induction variable does not pass through unmodified through the " + "condition body"); + } + + // Verify that the induction variable is being incremented exactly by 1 inside + // loop body. + bool found_induction_var = false; + for (const HloInstruction* gte : + while_body->parameter_instruction(0)->users()) { + if (gte->tuple_index() == induction_var->tuple_index()) { + if (gte->user_count() != 1) { + return absl::InvalidArgumentError( + "Loop induction variable has multiple users"); + } + const HloInstruction* add = gte->users()[0]; + if (!Match(add, + match::AddAnyOrder(match::GetTupleElement().WithTupleIndex( + induction_var->tuple_index()), + match::ConstantScalar(1)))) { + return absl::InvalidArgumentError( + "Loop induction variable is not being incremented exactly by one " + "(1)"); + } + found_induction_var = true; + break; + } + } + if (!found_induction_var) { + return absl::InvalidArgumentError( + "Could not match induction variable between loop body and condition"); + } + + // Verify that the trip count is: + // a) A compile time constant. + // b) A run-time constant. + // c) A pure operation with operands that are either a) or b) + if (trip_count->opcode() != HloOpcode::kConstant) { + auto is_trip_count = [while_conditional](const HloInstruction* trip_count) { + return trip_count->opcode() == HloOpcode::kGetTupleElement && + trip_count->operand(0) == + while_conditional->parameter_instruction(0); + }; + const HloInstruction* runtime_trip_count = nullptr; + if (is_trip_count(trip_count)) { + runtime_trip_count = trip_count; + } else { + if (trip_count->HasSideEffect()) { + return absl::InvalidArgumentError( + "Trip count passes through sideeffecting op"); + } + for (HloInstruction* operand : trip_count->operands()) { + if (operand->opcode() == HloOpcode::kConstant) { + continue; + } else if (is_trip_count(operand)) { + // Check if we already found something that looks like the runtime + // trip count. + if (runtime_trip_count != nullptr) { + return absl::InvalidArgumentError( + "Could not identify trip count variable"); + } + runtime_trip_count = operand; + } else { + return absl::InvalidArgumentError( + "Trip count consists of non-constant variable"); + } + } + } + // Verify that the runtime trip count stays constant through the while + // body. + auto invariant_gtes = GetInvariantGTEsForWhileBody(*while_body); + if (!absl::c_any_of(invariant_gtes, + [runtime_trip_count](HloInstruction* invariant_gte) { + return invariant_gte->tuple_index() == + runtime_trip_count->tuple_index(); + })) { + return absl::InvalidArgumentError( + "Trip count is not constant throughout the while loop"); + } + } + + // Decrementing the induction var is equivalent to incrementing the + // trip count. + HloInstruction* trip_count_increment = + while_conditional->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR0(-increment))); + HloInstruction* decremented_induction_var = while_conditional->AddInstruction( + HloInstruction::CreateBinary(induction_var->shape(), HloOpcode::kAdd, + induction_var, trip_count_increment)); + return induction_var->ReplaceAllUsesWith(decremented_induction_var); +} + } // namespace xla diff --git a/third_party/xla/xla/service/while_util.h b/third_party/xla/xla/service/while_util.h index 05c44c09110ca7..69c44e0a6301dc 100644 --- a/third_party/xla/xla/service/while_util.h +++ b/third_party/xla/xla/service/while_util.h @@ -23,8 +23,10 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/call_inliner.h" #include "xla/xla_data.pb.h" @@ -123,6 +125,11 @@ class WhileUtil { // question. static absl::flat_hash_map> GetGTEsMapForWhileConditional(const HloComputation& while_conditional); + + // Modifies the trip count of the loop by the given increment. + // Requires loop body to be incrementing the induction variable by exactly 1. + static absl::Status IncrementWhileLoopTripCount( + const HloInstruction& while_instruction, int32_t increment); }; } // namespace xla diff --git a/third_party/xla/xla/service/while_util_test.cc b/third_party/xla/xla/service/while_util_test.cc index a760bbf9f10318..f8e597ecc43932 100644 --- a/third_party/xla/xla/service/while_util_test.cc +++ b/third_party/xla/xla/service/while_util_test.cc @@ -16,15 +16,19 @@ limitations under the License. #include "xla/service/while_util.h" #include +#include +#include +#include #include "absl/algorithm/container.h" #include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/testlib/verified_hlo_module.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "tsl/platform/statusor.h" @@ -218,5 +222,191 @@ ENTRY main { }; EXPECT_EQ(absl::c_count_if(main->instructions(), is_while), 1); } + +TEST_F(WhileUtilTest, TryIncrementNonCounterTripCount) { + constexpr std::string_view hlo = R"( +HloModule main + +body { + param.0 = (s32[], s32[]) parameter(0) + gte.0 = get-tuple-element(param.0), index=0 + gte.1 = get-tuple-element(param.0), index=1 + one.0 = s32[] constant(2) + add.0 = s32[] add(gte.0, one.0) + ROOT tuple.0 = (s32[], s32[]) tuple(add.0, gte.1) +} + +cond { + param.0 = (s32[], s32[]) parameter(0) + gte.0 = get-tuple-element(param.0), index=0 + gte.1 = get-tuple-element(param.0), index=1 + minus-one.0 = s32[] constant(-1) + add.0 = add(gte.1, minus-one.0) + ROOT compare.0 = compare(gte.0, add.0), direction=LT +} + +ENTRY main { + param.0 = (s32[], s32[]) parameter(0) + ROOT while = while(param.0), condition=cond, body=body +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + const HloComputation* main = module->GetComputationWithName("main"); + const HloInstruction* while_instr = main->root_instruction(); + // Loop body increments induction variable by 2, in this case we should fail. + EXPECT_FALSE( + WhileUtil::IncrementWhileLoopTripCount(*while_instr, /*increment=*/1) + .ok()); +} + +TEST_F(WhileUtilTest, TryIncrementNonConstantTripCount) { + constexpr std::string_view hlo = R"( +HloModule main + +body { + param.0 = (s32[], s32[]) parameter(0) + gte.0 = get-tuple-element(param.0), index=0 + gte.1 = get-tuple-element(param.0), index=1 + one.0 = s32[] constant(1) + add.0 = s32[] add(gte.0, one.0) + add.1 = s32[] add(gte.1, one.0) + ROOT tuple.0 = (s32[], s32[]) tuple(add.0, add.1) +} + +cond { + param.0 = (s32[], s32[]) parameter(0) + gte.0 = get-tuple-element(param.0), index=0 + gte.1 = get-tuple-element(param.0), index=1 + minus-one.0 = s32[] constant(-1) + add.0 = add(gte.1, minus-one.0) + ROOT compare.0 = compare(gte.0, add.0), direction=LT +} + +ENTRY main { + param.0 = (s32[], s32[]) parameter(0) + ROOT while = while(param.0), condition=cond, body=body +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + const HloComputation* main = module->GetComputationWithName("main"); + const HloInstruction* while_instr = main->root_instruction(); + // Loop body increments trip count, in this case we should fail. + EXPECT_FALSE( + WhileUtil::IncrementWhileLoopTripCount(*while_instr, /*increment=*/1) + .ok()); +} + +TEST_F(WhileUtilTest, TryIncrementSideEffecting) { + constexpr std::string_view hlo = R"( +HloModule main + +body { + param.0 = (s32[], s32[]) parameter(0) + gte.0 = get-tuple-element(param.0), index=0 + gte.1 = get-tuple-element(param.0), index=1 + one.0 = s32[] constant(1) + add.0 = s32[] add(gte.0, one.0) + ROOT tuple.0 = (s32[], s32[]) tuple(add.0, gte.1) +} + +cond { + param.0 = (s32[], s32[]) parameter(0) + gte.0 = get-tuple-element(param.0), index=0 + gte.1 = get-tuple-element(param.0), index=1 + minus-one.0 = s32[] constant(-1) + add.0 = s32[] custom-call(gte.1, minus-one.0), custom_call_target="add", custom_call_has_side_effect=true + ROOT compare.0 = compare(gte.0, add.0), direction=LT +} + +ENTRY main { + param.0 = (s32[], s32[]) parameter(0) + ROOT while = while(param.0), condition=cond, body=body +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + const HloComputation* main = module->GetComputationWithName("main"); + const HloInstruction* while_instr = main->root_instruction(); + // The trip count is modified with a side effecting op, in this case we + // should fail. + EXPECT_FALSE( + WhileUtil::IncrementWhileLoopTripCount(*while_instr, /*increment=*/1) + .ok()); +} + +TEST_F(WhileUtilTest, IncrementTripCountLt) { + constexpr std::string_view hlo = R"( +HloModule main + +body { + param.0 = (s32[], s32[]) parameter(0) + gte.0 = get-tuple-element(param.0), index=0 + gte.1 = get-tuple-element(param.0), index=1 + one.0 = s32[] constant(1) + add.0 = s32[] add(gte.0, one.0) + ROOT tuple.0 = (s32[], s32[]) tuple(add.0, gte.1) +} + +cond { + param.0 = (s32[], s32[]) parameter(0) + gte.0 = get-tuple-element(param.0), index=0 + gte.1 = get-tuple-element(param.0), index=1 + minus-one.0 = s32[] constant(-1) + add.0 = add(gte.1, minus-one.0) + ROOT compare.0 = compare(gte.0, add.0), direction=LT +} + +ENTRY main { + param.0 = (s32[], s32[]) parameter(0) + ROOT while = while(param.0), condition=cond, body=body +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + const HloComputation* main = module->GetComputationWithName("main"); + const HloInstruction* while_instr = main->root_instruction(); + TF_EXPECT_OK( + WhileUtil::IncrementWhileLoopTripCount(*while_instr, /*increment=*/1)); + + const HloComputation* cond = module->GetComputationWithName("cond"); + EXPECT_THAT(cond->root_instruction()->operand(0), + op::Add(op::GetTupleElement(), op::Constant())); +} + +TEST_F(WhileUtilTest, IncrementTripCountGt) { + constexpr std::string_view hlo = R"( +HloModule main + +body { + param.0 = (s32[], s32[]) parameter(0) + gte.0 = get-tuple-element(param.0), index=0 + gte.1 = get-tuple-element(param.0), index=1 + one.0 = s32[] constant(1) + add.0 = s32[] add(gte.1, one.0) + ROOT tuple.0 = (s32[], s32[]) tuple(gte.0, add.0) +} + +cond { + param.0 = (s32[], s32[]) parameter(0) + gte.0 = get-tuple-element(param.0), index=0 + gte.1 = get-tuple-element(param.0), index=1 + minus-one.0 = s32[] constant(-1) + add.0 = add(gte.0, minus-one.0) + ROOT compare.0 = compare(add.0, gte.1), direction=GT +} + +ENTRY main { + param.0 = (s32[], s32[]) parameter(0) + ROOT while = while(param.0), condition=cond, body=body +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + const HloComputation* main = module->GetComputationWithName("main"); + const HloInstruction* while_instr = main->root_instruction(); + TF_EXPECT_OK( + WhileUtil::IncrementWhileLoopTripCount(*while_instr, /*increment=*/1)); + + const HloComputation* cond = module->GetComputationWithName("cond"); + EXPECT_THAT(cond->root_instruction()->operand(1), + op::Add(op::GetTupleElement(), op::Constant())); +} } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/xla_compile_result.proto b/third_party/xla/xla/service/xla_compile_result.proto index 5846b8b11bacc2..7634661596a5ad 100644 --- a/third_party/xla/xla/service/xla_compile_result.proto +++ b/third_party/xla/xla/service/xla_compile_result.proto @@ -19,7 +19,7 @@ package xla; import "google/protobuf/duration.proto"; import "xla/service/hlo.proto"; -import "tsl/protobuf/status.proto"; +import "xla/tsl/protobuf/status.proto"; // Statistics on how long various parts of compilation took. // Not all durations may be relevant for all producers of this message, in diff --git a/third_party/xla/xla/service/zero_sized_hlo_elimination.h b/third_party/xla/xla/service/zero_sized_hlo_elimination.h index de7488b959ae80..3da82bd21355bb 100644 --- a/third_party/xla/xla/service/zero_sized_hlo_elimination.h +++ b/third_party/xla/xla/service/zero_sized_hlo_elimination.h @@ -16,23 +16,7 @@ limitations under the License. #ifndef XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_ #define XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_ -#include "absl/container/flat_hash_set.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/pass/hlo_pass_interface.h" +// The current header will be deprecated in favour of the following. +#include "xla/hlo/transforms/simplifiers/zero_sized_hlo_elimination.h" -// HLO pass that replaces zero sized Hlos with a zero sized constant literal. -namespace xla { -class ZeroSizedHloElimination : public HloModulePass { - public: - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - absl::string_view name() const override { - return "zero_sized_hlo_elimination"; - } -}; -} // namespace xla #endif // XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_ diff --git a/third_party/xla/xla/shape_tree.h b/third_party/xla/xla/shape_tree.h index ba4e13560fd2c3..fd4448e0265089 100644 --- a/third_party/xla/xla/shape_tree.h +++ b/third_party/xla/xla/shape_tree.h @@ -31,7 +31,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "tsl/lib/gtl/iterator_range.h" +#include "xla/tsl/lib/gtl/iterator_range.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc index 01f7cacfc9b441..9def58503a0854 100644 --- a/third_party/xla/xla/shape_util.cc +++ b/third_party/xla/xla/shape_util.cc @@ -1067,6 +1067,9 @@ Shape ShapeUtil::PrependMajorDimension(int64_t bound, Shape shape) { } else { Shape new_shape = original; new_shape.set_element_type(type); + if (new_shape.has_layout() && type == PRED) { + new_shape.mutable_layout()->set_element_size_in_bits(0); + } return new_shape; } } diff --git a/third_party/xla/xla/shape_util_test.cc b/third_party/xla/xla/shape_util_test.cc index e239a96ce6aa02..2ed50604569880 100644 --- a/third_party/xla/xla/shape_util_test.cc +++ b/third_party/xla/xla/shape_util_test.cc @@ -1224,6 +1224,12 @@ TEST(ShapeUtilTest, Int4ShapeSize) { layout->set_element_size_in_bits(4); EXPECT_EQ(ShapeUtil::ArrayDataSize(int4_shape2), 9216 * 6144 / 2); EXPECT_EQ(ShapeUtil::ArraySize(int4_shape2), 9216 * 6144 / 2); + + // Changing the type to PRED should clear element_size_in_bits. + Shape pred_shape = ShapeUtil::ChangeElementType(int4_shape, PRED); + EXPECT_EQ(pred_shape.layout().element_size_in_bits(), 0); + Shape u4_shape = ShapeUtil::ChangeElementType(int4_shape, U4); + EXPECT_EQ(u4_shape.layout().element_size_in_bits(), 4); } TEST(XlaShapeUtilTest, ZeroSize) { diff --git a/third_party/xla/xla/sharding_op_util.cc b/third_party/xla/xla/sharding_op_util.cc index 40154c61f45c63..16a26e4a5b8e09 100644 --- a/third_party/xla/xla/sharding_op_util.cc +++ b/third_party/xla/xla/sharding_op_util.cc @@ -23,7 +23,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "xla/service/hlo_lexer.h" +#include "xla/hlo/parser/hlo_lexer.h" #include "xla/status_macros.h" #include "xla/util.h" diff --git a/third_party/xla/xla/side_effect_util.cc b/third_party/xla/xla/side_effect_util.cc index f7a7f198f840e6..18e0144d863b53 100644 --- a/third_party/xla/xla/side_effect_util.cc +++ b/third_party/xla/xla/side_effect_util.cc @@ -59,4 +59,14 @@ const char kXlaBufferPlacementAttr[] = "_xla_buffer_placement"; const char kXlaBufferPlacementParam[] = "arg"; +const char kXlaCollectiveMatmulAttr[] = "_xla_collective_matmul"; + +const char kXlaCollectiveMatmulLhsAg[] = "lhs_ag"; + +const char kXlaCollectiveMatmulRhsAg[] = "rhs_ag"; + +const char kXlaCollectiveMatmulRs[] = "rs"; + +const char kXlaMultiRecvCountAttr[] = "_xla_multi_recv_count"; + } // namespace xla diff --git a/third_party/xla/xla/side_effect_util.h b/third_party/xla/xla/side_effect_util.h index 756ecf82f6b93d..f16949fff635ba 100644 --- a/third_party/xla/xla/side_effect_util.h +++ b/third_party/xla/xla/side_effect_util.h @@ -66,6 +66,18 @@ extern const char kXlaTableId[]; // XLA frontend attribute for buffer placement. extern const char kXlaBufferPlacementAttr[]; extern const char kXlaBufferPlacementParam[]; + +// XLA frontend attribute for collective matmul control. +extern const char kXlaCollectiveMatmulAttr[]; + +// XLA frontend attribute values for kXlaCollectiveMatmulAttr +extern const char kXlaCollectiveMatmulLhsAg[]; +extern const char kXlaCollectiveMatmulRhsAg[]; +extern const char kXlaCollectiveMatmulRs[]; + +// XLA frontend attribute for specifying the number of sends this recv should +// match. +extern const char kXlaMultiRecvCountAttr[]; } // namespace xla #endif // XLA_SIDE_EFFECT_UTIL_H_ diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index 17f517730afe29..94c963ee36064a 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -31,6 +31,7 @@ bzl_library( srcs = ["build_defs.bzl"], deps = [ "@local_config_cuda//cuda:build_defs_bzl", + "@local_tsl//third_party/py/rules_pywrap:pywrap_bzl", "@local_tsl//tsl/platform:rules_cc_bzl", "@local_tsl//tsl/platform/default:cuda_build_defs_bzl", ] + stream_executor_build_defs_bzl_deps(), @@ -84,6 +85,7 @@ cc_library( ":stream_executor_api_headers", ], deps = [ + ":activate_context", ":allocator_stats", ":blas", ":command_buffer", @@ -104,9 +106,10 @@ cc_library( ":stream_common", ":stream_executor_common", ":stream_executor_h", - "//xla/stream_executor/platform", "//xla/tsl/framework:device_id", "//xla/tsl/framework:device_type", + "//xla/tsl/lib/gtl:int_type", + "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", @@ -121,7 +124,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/gtl:int_type", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", @@ -130,7 +132,6 @@ cc_library( "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:thread_annotations", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/protobuf:dnn_proto_cc", ] + if_static([ ":stream_executor_impl", ]) + if_google([ @@ -149,6 +150,11 @@ tf_proto_library( protodeps = ["//xla:autotune_results_proto"], ) +cc_library( + name = "activate_context", + hdrs = ["activate_context.h"], +) + cc_library( name = "device_description", srcs = ["device_description.cc"], @@ -157,9 +163,9 @@ cc_library( ":device_description_proto_cc", ":launch_dim", ":semantic_version", + "//xla/tsl/lib/math:math_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/math:math_util", "@local_tsl//tsl/platform:logging", ], ) @@ -187,6 +193,7 @@ cc_library( name = "module_spec", hdrs = ["module_spec.h"], deps = [ + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:logging", ], @@ -255,6 +262,7 @@ cc_library( testonly = True, hdrs = ["mock_stream_executor.h"], deps = [ + ":activate_context", ":allocator_stats", ":blas", ":command_buffer", @@ -282,8 +290,8 @@ cc_library( name = "data_type", hdrs = ["data_type.h"], deps = [ + "//xla/tsl/protobuf:dnn_proto_cc", "@local_tsl//tsl/platform:ml_dtypes", - "@local_tsl//tsl/protobuf:dnn_proto_cc", ], ) @@ -411,13 +419,15 @@ cc_library( ":data_type", ":device_memory", ":numeric_options", - "//xla/stream_executor/platform", + ":scratch_allocator", + ":stream", + "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/log", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/protobuf:dnn_proto_cc", ], ) @@ -430,8 +440,11 @@ cc_library( ":device_description_proto_cc", ":device_memory", ":numeric_options", - "//xla/stream_executor/platform", + ":scratch_allocator", + ":stream", + "//xla:util", "//xla/tsl/lib/strings:proto_serialization", + "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", @@ -443,11 +456,11 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", # buildcleaner: keep + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:dnn_proto_cc", ] + if_google(["@com_google_protobuf//:wrappers_cc_proto"]), ) @@ -455,7 +468,6 @@ cc_library( name = "fft", hdrs = ["fft.h"], deps = [ - "//xla/stream_executor/platform", ], ) @@ -466,11 +478,11 @@ cc_library( ":dnn", ":stream", ":stream_executor_h", + "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/base", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:dnn_proto_cc", ], ) @@ -491,6 +503,7 @@ cc_library( "stream_executor.h", ], deps = [ + ":activate_context", ":allocator_stats", ":blas", ":command_buffer", @@ -592,10 +605,10 @@ cc_library( ":kernel", ":launch_dim", ":platform", + "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/gtl:int_type", "@local_tsl//tsl/platform:errors", ], alwayslink = True, @@ -730,6 +743,7 @@ cc_library( srcs = ["stream_executor_common.cc"], hdrs = ["stream_executor_common.h"], deps = [ + ":activate_context", ":device_description", ":platform", ":stream_executor_h", @@ -794,7 +808,7 @@ cc_library( ":stream_common", ":stream_executor_common", ":stream_executor_h", - ] + if_oss(["@local_tsl//tsl/protobuf:dnn_proto_cc_impl"]), + ] + if_oss(["//xla/tsl/protobuf:dnn_proto_cc_impl"]), ) #===--------------------------------------------------------------------------------------------===# @@ -806,8 +820,11 @@ xla_cc_test( srcs = ["kernel_test.cc"], deps = [ ":device_memory", + ":kernel", ":kernel_spec", - ":stream_executor", + ":platform", + ":platform_manager", + ":stream_executor_h", ":typed_kernel_factory", "//xla/stream_executor/host:host_platform", "@local_tsl//tsl/platform:test", @@ -820,7 +837,9 @@ xla_cc_test( name = "stream_executor_test", srcs = ["stream_executor_test.cc"], deps = [ - ":stream_executor", + ":platform", + ":platform_manager", + ":stream_executor_h", "//xla/stream_executor/host:host_platform", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:statusor", @@ -834,7 +853,10 @@ xla_cc_test( size = "small", srcs = ["stream_test.cc"], deps = [ - ":stream_executor", + ":platform", + ":platform_manager", + ":stream", + ":stream_executor_h", "//xla/stream_executor/host:host_platform", "@com_google_absl//absl/log:check", "@local_tsl//tsl/platform:statusor", @@ -928,14 +950,28 @@ xla_cc_test( alias( name = "cuda_platform", actual = "//xla/stream_executor/cuda:all_runtime", + tags = [ + "cuda-only", + "gpu", + ], ) alias( name = "rocm_platform", actual = "//xla/stream_executor/rocm:all_runtime", + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), ) alias( name = "sycl_platform", actual = "//xla/stream_executor/sycl:all_runtime", + tags = [ + "gpu", + ], ) diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_activation.h b/third_party/xla/xla/stream_executor/activate_context.h similarity index 53% rename from third_party/xla/xla/stream_executor/gpu/gpu_activation.h rename to third_party/xla/xla/stream_executor/activate_context.h index 2de957934d1571..d0b85fc3d5ea15 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_activation.h +++ b/third_party/xla/xla/stream_executor/activate_context.h @@ -1,4 +1,4 @@ -/* Copyright 2019 The OpenXLA Authors. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,23 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file contains APIs that assume a StreamExecutor is backed by a GPU -// implementation. It ensures the underlying GPU context is active. - -#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_ACTIVATION_H_ -#define XLA_STREAM_EXECUTOR_GPU_GPU_ACTIVATION_H_ - -#include "xla/stream_executor/gpu/scoped_activate_context.h" +#ifndef XLA_STREAM_EXECUTOR_ACTIVATE_CONTEXT_H_ +#define XLA_STREAM_EXECUTOR_ACTIVATE_CONTEXT_H_ namespace stream_executor { -class StreamExecutor; - -namespace gpu { - -using ScopedActivateExecutorContext = ScopedActivateContext; +// An RAII handle for ensuring a context is activated for the duration of the +// ActivateContext's scope. The creation of an ActivateContext ensures that any +// necessary state changes are done to make the requested context active. When +// the ActivateContext is destroyed, it will enable any previous context that +// was active. +class ActivateContext { + public: + virtual ~ActivateContext() = default; +}; -} // namespace gpu } // namespace stream_executor -#endif // XLA_STREAM_EXECUTOR_GPU_GPU_ACTIVATION_H_ +#endif // XLA_STREAM_EXECUTOR_ACTIVATE_CONTEXT_H_ diff --git a/third_party/xla/xla/stream_executor/blas.h b/third_party/xla/xla/stream_executor/blas.h index cf235c83a9729c..f6ad27aebee97d 100644 --- a/third_party/xla/xla/stream_executor/blas.h +++ b/third_party/xla/xla/stream_executor/blas.h @@ -30,13 +30,15 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/stream_executor/data_type.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/numeric_options.h" -#include "xla/stream_executor/platform/port.h" +#include "xla/stream_executor/scratch_allocator.h" +#include "xla/stream_executor/stream.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "tsl/platform/errors.h" -#include "tsl/protobuf/dnn.pb.h" namespace Eigen { struct half; @@ -50,15 +52,6 @@ struct MatrixDescriptor; struct OutputMatrixDescriptor; } // namespace gpu -class Stream; -class ScratchAllocator; - -template -class DeviceMemory; - -template -class HostOrDeviceScalar; - template using DeviceMemorySlice = absl::Span *const>; @@ -222,6 +215,10 @@ class BlasSupport { virtual gpu::BlasLt *GetBlasLt() = 0; + // For tests only: sets *is_main_stream to true if the underlying Blas library + // has stream 0 set as its current stream. + virtual absl::StatusOr IsMainStreamSet() const = 0; + // Computes the product of a vector by a scalar: x <- a*x. virtual bool DoBlasScal(Stream *stream, uint64_t elem_count, float alpha, DeviceMemory *x, int incx) = 0; @@ -287,7 +284,7 @@ class BlasSupport { // case the expected alpha/beta type is `float`. virtual absl::Status DoBlasGemm( Stream *stream, blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64 n, uint64_t k, DataType dtype, const void *alpha, + uint64_t m, uint64_t n, uint64_t k, DataType dtype, const void *alpha, const DeviceMemoryBase &a, int lda, const DeviceMemoryBase &b, int ldb, const void *beta, DeviceMemoryBase *c, int ldc, const NumericOptions &numeric_options, blas::CallContext context) = 0; @@ -312,7 +309,7 @@ class BlasSupport { // creating a new Stream for each attempt. virtual absl::Status DoBlasGemmWithAlgorithm( Stream *stream, blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64_t n, uint64 k, const void *alpha, + uint64_t m, uint64_t n, uint64_t k, const void *alpha, const DeviceMemoryBase &a, DataType type_a, int lda, const DeviceMemoryBase &b, DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c, DataType type_c, int ldc, @@ -321,7 +318,7 @@ class BlasSupport { ProfileResult *output_profile_result, blas::CallContext context) = 0; virtual absl::Status DoBlasGemmStridedBatchedWithAlgorithm( Stream *stream, blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64_t n, uint64 k, const void *alpha, + uint64_t m, uint64_t n, uint64_t k, const void *alpha, const DeviceMemoryBase &a, DataType type_a, int lda, int64_t stride_a, const DeviceMemoryBase &b, DataType type_b, int ldb, int64_t stride_b, const void *beta, DeviceMemoryBase *c, DataType type_c, int ldc, @@ -335,7 +332,7 @@ class BlasSupport { // and c, which contain batch_count DeviceMemory objects. virtual bool DoBlasGemmBatched(Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64_t n, - uint64 k, float alpha, + uint64_t k, float alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, @@ -345,7 +342,7 @@ class BlasSupport { blas::CallContext context) = 0; virtual bool DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64_t n, uint64 k, float alpha, + uint64_t m, uint64_t n, uint64_t k, float alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, @@ -353,21 +350,21 @@ class BlasSupport { ScratchAllocator *scratch_allocator, blas::CallContext context) = 0; virtual bool DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64_t n, uint64 k, float alpha, DeviceMemorySlice a, - int lda, DeviceMemorySlice b, int ldb, float beta, - DeviceMemorySlice c, int ldc, int batch_count, + uint64_t m, uint64_t n, uint64_t k, float alpha, + DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, + float beta, DeviceMemorySlice c, int ldc, int batch_count, const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, blas::CallContext context) = 0; virtual bool DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64_t n, uint64 k, double alpha, + uint64_t m, uint64_t n, uint64_t k, double alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, double beta, DeviceMemorySlice c, int ldc, int batch_count, const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, blas::CallContext context) = 0; virtual bool DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64_t n, uint64 k, std::complex alpha, + uint64_t m, uint64_t n, uint64_t k, std::complex alpha, DeviceMemorySlice> a, int lda, DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, @@ -375,7 +372,7 @@ class BlasSupport { ScratchAllocator *scratch_allocator, blas::CallContext context) = 0; virtual bool DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64_t n, uint64 k, std::complex alpha, + uint64_t m, uint64_t n, uint64_t k, std::complex alpha, DeviceMemorySlice> a, int lda, DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, @@ -384,7 +381,7 @@ class BlasSupport { // Batched gemm with strides instead of pointer arrays. virtual absl::Status DoBlasGemmStridedBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64_t n, uint64 k, DataType dtype, const void *alpha, + uint64_t m, uint64_t n, uint64_t k, DataType dtype, const void *alpha, const DeviceMemoryBase &a, int lda, int64_t stride_a, const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta, DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count, @@ -393,7 +390,7 @@ class BlasSupport { template absl::Status BlasGemmStridedBatchedWithAlgorithm( Stream *stream, blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64 n, uint64_t k, ConstantType alpha, + uint64_t m, uint64_t n, uint64_t k, ConstantType alpha, const DeviceMemory &a, int lda, int64_t stride_a, const DeviceMemory &b, int ldb, int64_t stride_b, ConstantType beta, DeviceMemory *c, int ldc, int64_t stride_c, @@ -425,9 +422,10 @@ class BlasSupport { template absl::Status BlasGemm(Stream *stream, blas::Transpose transa, - blas::Transpose transb, uint64_t m, uint64 n, uint64 k, - ConstantType alpha, const DeviceMemory &a, - int lda, const DeviceMemory &b, int ldb, + blas::Transpose transb, uint64_t m, uint64_t n, + uint64_t k, ConstantType alpha, + const DeviceMemory &a, int lda, + const DeviceMemory &b, int ldb, ConstantType beta, DeviceMemory *c, int ldc, const NumericOptions &numeric_options, blas::CallContext context) { @@ -459,8 +457,8 @@ class BlasSupport { template absl::Status BlasGemm(Stream *stream, blas::Transpose transa, - blas::Transpose transb, uint64_t m, uint64 n, uint64 k, - const DeviceMemory &a, int lda, + blas::Transpose transb, uint64_t m, uint64_t n, + uint64_t k, const DeviceMemory &a, int lda, const DeviceMemory &b, int ldb, DeviceMemory *c, int ldc, const NumericOptions &numeric_options, @@ -474,7 +472,7 @@ class BlasSupport { template absl::Status BlasGemmWithAlgorithm( Stream *stream, blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64 n, uint64_t k, ConstantType alpha, + uint64_t m, uint64_t n, uint64_t k, ConstantType alpha, const DeviceMemory &a, int lda, const DeviceMemory &b, int ldb, ConstantType beta, DeviceMemory *c, int ldc, @@ -508,7 +506,7 @@ class BlasSupport { template absl::Status BlasGemmWithAlgorithm( Stream *stream, blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64 n, uint64_t k, const DeviceMemory &a, + uint64_t m, uint64_t n, uint64_t k, const DeviceMemory &a, int lda, const DeviceMemory &b, int ldb, DeviceMemory *c, int ldc, blas::ComputationType computation_type, blas::AlgorithmType algorithm, @@ -525,7 +523,7 @@ class BlasSupport { template absl::Status BlasGemmStridedBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, - uint64_t m, uint64 n, uint64_t k, ConstantType alpha, + uint64_t m, uint64_t n, uint64_t k, ConstantType alpha, const DeviceMemory &a, int lda, int64_t stride_a, const DeviceMemory &b, int ldb, int64_t stride_b, ConstantType beta, DeviceMemory *c, int ldc, int64_t stride_c, @@ -565,23 +563,23 @@ class BlasSupport { // or op(a) = conj(a'). virtual bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, + blas::Diagonal diag, uint64_t m, uint64_t n, float alpha, const DeviceMemory &a, int lda, DeviceMemory *b, int ldb) = 0; virtual bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, + blas::Diagonal diag, uint64_t m, uint64_t n, double alpha, const DeviceMemory &a, int lda, DeviceMemory *b, int ldb) = 0; virtual bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, + blas::Diagonal diag, uint64_t m, uint64_t n, std::complex alpha, const DeviceMemory> &a, int lda, DeviceMemory> *b, int ldb) = 0; virtual bool DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, + blas::Diagonal diag, uint64_t m, uint64_t n, std::complex alpha, const DeviceMemory> &a, int lda, DeviceMemory> *b, int ldb) = 0; @@ -590,19 +588,19 @@ class BlasSupport { // `as` and `bs` must have the same length. virtual bool DoBlasTrsmBatched(Stream *stream, blas::Side side, blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, + blas::Diagonal diag, uint64_t m, uint64_t n, float alpha, const DeviceMemory &as, int lda, DeviceMemory *bs, int ldb, int batch_count) = 0; virtual bool DoBlasTrsmBatched(Stream *stream, blas::Side side, blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, + blas::Diagonal diag, uint64_t m, uint64_t n, double alpha, const DeviceMemory &as, int lda, DeviceMemory *bs, int ldb, int batch_count) = 0; virtual bool DoBlasTrsmBatched(Stream *stream, blas::Side side, blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, + blas::Diagonal diag, uint64_t m, uint64_t n, std::complex alpha, const DeviceMemory *> &as, int lda, @@ -610,7 +608,7 @@ class BlasSupport { int ldb, int batch_count) = 0; virtual bool DoBlasTrsmBatched(Stream *stream, blas::Side side, blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, + blas::Diagonal diag, uint64_t m, uint64_t n, std::complex alpha, const DeviceMemory *> &as, int lda, @@ -727,6 +725,7 @@ class BlasSupport { // Macro used to quickly declare overrides for abstract virtuals in the // BlasSupport base class. #define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES \ + absl::StatusOr IsMainStreamSet() const override; \ bool DoBlasScal(Stream *stream, uint64_t elem_count, float alpha, \ DeviceMemory *x, int incx) override; \ bool DoBlasScal(Stream *stream, uint64_t elem_count, double alpha, \ @@ -741,33 +740,33 @@ class BlasSupport { bool DoBlasScal(Stream *stream, uint64_t elem_count, \ std::complex alpha, \ DeviceMemory> *x, int incx) override; \ - bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, \ - float alpha, const DeviceMemory &a, int lda, \ - const DeviceMemory &x, int incx, float beta, \ + bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, \ + uint64_t n, float alpha, const DeviceMemory &a, \ + int lda, const DeviceMemory &x, int incx, float beta, \ DeviceMemory *y, int incy) override; \ - bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, \ - double alpha, const DeviceMemory &a, int lda, \ - const DeviceMemory &x, int incx, double beta, \ - DeviceMemory *y, int incy) override; \ - bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, \ - std::complex alpha, \ + bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, \ + uint64_t n, double alpha, const DeviceMemory &a, \ + int lda, const DeviceMemory &x, int incx, \ + double beta, DeviceMemory *y, int incy) override; \ + bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, \ + uint64_t n, std::complex alpha, \ const DeviceMemory> &a, int lda, \ const DeviceMemory> &x, int incx, \ std::complex beta, \ DeviceMemory> *y, int incy) override; \ - bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, uint64 n, \ - std::complex alpha, \ + bool DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, \ + uint64_t n, std::complex alpha, \ const DeviceMemory> &a, int lda, \ const DeviceMemory> &x, int incx, \ std::complex beta, \ DeviceMemory> *y, int incy) override; \ absl::Status DoBlasGemm( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ - uint64_t m, uint64 n, uint64 k, blas::DataType dtype, const void *alpha, \ - const DeviceMemoryBase &a, int lda, const DeviceMemoryBase &b, int ldb, \ - const void *beta, DeviceMemoryBase *c, int ldc, \ - const NumericOptions &numeric_options, blas::CallContext context) \ - override; \ + uint64_t m, uint64_t n, uint64_t k, blas::DataType dtype, \ + const void *alpha, const DeviceMemoryBase &a, int lda, \ + const DeviceMemoryBase &b, int ldb, const void *beta, \ + DeviceMemoryBase *c, int ldc, const NumericOptions &numeric_options, \ + blas::CallContext context) override; \ bool GetBlasGemmAlgorithms( \ Stream *stream, const gpu::MatrixDescriptor &a, \ const gpu::MatrixDescriptor &b, gpu::OutputMatrixDescriptor *c, \ @@ -775,7 +774,7 @@ class BlasSupport { std::vector *out_algorithms) override; \ absl::Status DoBlasGemmWithAlgorithm( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ - uint64_t m, uint64 n, uint64 k, const void *alpha, \ + uint64_t m, uint64_t n, uint64_t k, const void *alpha, \ const DeviceMemoryBase &a, blas::DataType type_a, int lda, \ const DeviceMemoryBase &b, blas::DataType type_b, int ldb, \ const void *beta, DeviceMemoryBase *c, blas::DataType type_c, int ldc, \ @@ -785,7 +784,7 @@ class BlasSupport { override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ - uint64_t m, uint64 n, uint64 k, float alpha, \ + uint64_t m, uint64_t n, uint64_t k, float alpha, \ DeviceMemorySlice a, int lda, \ DeviceMemorySlice b, int ldb, float beta, \ DeviceMemorySlice c, int ldc, int batch_count, \ @@ -794,7 +793,7 @@ class BlasSupport { override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ - uint64_t m, uint64 n, uint64 k, float alpha, \ + uint64_t m, uint64_t n, uint64_t k, float alpha, \ DeviceMemorySlice a, int lda, \ DeviceMemorySlice b, int ldb, float beta, \ DeviceMemorySlice c, int ldc, int batch_count, \ @@ -803,15 +802,15 @@ class BlasSupport { override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ - uint64_t m, uint64 n, uint64 k, float alpha, DeviceMemorySlice a, \ - int lda, DeviceMemorySlice b, int ldb, float beta, \ - DeviceMemorySlice c, int ldc, int batch_count, \ - const NumericOptions &numeric_options, \ + uint64_t m, uint64_t n, uint64_t k, float alpha, \ + DeviceMemorySlice a, int lda, DeviceMemorySlice b, \ + int ldb, float beta, DeviceMemorySlice c, int ldc, \ + int batch_count, const NumericOptions &numeric_options, \ ScratchAllocator *scratch_allocator, blas::CallContext context) \ override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ - uint64_t m, uint64 n, uint64 k, double alpha, \ + uint64_t m, uint64_t n, uint64_t k, double alpha, \ DeviceMemorySlice a, int lda, DeviceMemorySlice b, \ int ldb, double beta, DeviceMemorySlice c, int ldc, \ int batch_count, const NumericOptions &numeric_options, \ @@ -819,7 +818,7 @@ class BlasSupport { override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ - uint64_t m, uint64 n, uint64 k, std::complex alpha, \ + uint64_t m, uint64_t n, uint64_t k, std::complex alpha, \ DeviceMemorySlice> a, int lda, \ DeviceMemorySlice> b, int ldb, \ std::complex beta, DeviceMemorySlice> c, \ @@ -828,7 +827,7 @@ class BlasSupport { override; \ bool DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ - uint64_t m, uint64 n, uint64 k, std::complex alpha, \ + uint64_t m, uint64_t n, uint64_t k, std::complex alpha, \ DeviceMemorySlice> a, int lda, \ DeviceMemorySlice> b, int ldb, \ std::complex beta, DeviceMemorySlice> c, \ @@ -837,15 +836,15 @@ class BlasSupport { override; \ absl::Status DoBlasGemmStridedBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ - uint64_t m, uint64 n, uint64 k, blas::DataType dtype, const void *alpha, \ - const DeviceMemoryBase &a, int lda, int64_t stride_a, \ + uint64_t m, uint64_t n, uint64_t k, blas::DataType dtype, \ + const void *alpha, const DeviceMemoryBase &a, int lda, int64_t stride_a, \ const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta, \ DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count, \ const NumericOptions &numeric_options, blas::CallContext context) \ override; \ absl::Status DoBlasGemmStridedBatchedWithAlgorithm( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ - uint64_t m, uint64 n, uint64 k, const void *alpha, \ + uint64_t m, uint64_t n, uint64_t k, const void *alpha, \ const DeviceMemoryBase &a, blas::DataType type_a, int lda, \ int64_t stride_a, const DeviceMemoryBase &b, blas::DataType type_b, \ int ldb, int64_t stride_b, const void *beta, DeviceMemoryBase *c, \ @@ -874,24 +873,24 @@ class BlasSupport { DeviceMemory> *b, int ldb) override; \ bool DoBlasTrsmBatched( \ Stream *stream, blas::Side side, blas::UpperLower uplo, \ - blas::Transpose transa, blas::Diagonal diag, uint64_t m, uint64 n, \ + blas::Transpose transa, blas::Diagonal diag, uint64_t m, uint64_t n, \ float alpha, const DeviceMemory &as, int lda, \ DeviceMemory *bs, int ldb, int batch_count) override; \ bool DoBlasTrsmBatched( \ Stream *stream, blas::Side side, blas::UpperLower uplo, \ - blas::Transpose transa, blas::Diagonal diag, uint64_t m, uint64 n, \ + blas::Transpose transa, blas::Diagonal diag, uint64_t m, uint64_t n, \ double alpha, const DeviceMemory &as, int lda, \ DeviceMemory *bs, int ldb, int batch_count) override; \ bool DoBlasTrsmBatched(Stream *stream, blas::Side side, \ blas::UpperLower uplo, blas::Transpose transa, \ - blas::Diagonal diag, uint64_t m, uint64 n, \ + blas::Diagonal diag, uint64_t m, uint64_t n, \ std::complex alpha, \ const DeviceMemory *> &as, \ int lda, DeviceMemory *> *bs, \ int ldb, int batch_count) override; \ bool DoBlasTrsmBatched(Stream *stream, blas::Side side, \ blas::UpperLower uplo, blas::Transpose transa, \ - blas::Diagonal diag, uint64_t m, uint64 n, \ + blas::Diagonal diag, uint64_t m, uint64_t n, \ std::complex alpha, \ const DeviceMemory *> &as, \ int lda, DeviceMemory *> *bs, \ diff --git a/third_party/xla/xla/stream_executor/build_defs.bzl b/third_party/xla/xla/stream_executor/build_defs.bzl index 959de0fa172871..3204b886c651ff 100644 --- a/third_party/xla/xla/stream_executor/build_defs.bzl +++ b/third_party/xla/xla/stream_executor/build_defs.bzl @@ -1,6 +1,5 @@ """Configurations for StreamExecutor builds""" -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load( "@local_config_rocm//rocm:build_defs.bzl", _if_cuda_or_rocm = "if_cuda_or_rocm", @@ -64,34 +63,5 @@ def gpu_only_cc_library(name, tags = [], **kwargs): target_compatible_with = kwargs.get("target_compatible_with"), ) -def cuda_only_cc_library(name, tags = [], **kwargs): - """A library that only gets compiled when CUDA is configured, otherwise it's an empty target. - - Args: - name: Name of the target - tags: Tags being applied to the implementation target - **kwargs: Accepts all arguments that a `cc_library` would also accept - """ - if not native.package_name().startswith("xla/stream_executor"): - fail("cuda_only_cc_library may only be used in `xla/stream_executor/...`.") - - cc_library( - name = "%s_non_cuda" % name, - tags = ["manual"], - ) - cc_library( - name = "%s_cuda_only" % name, - tags = tags + ["manual", "no_rocm"], - **kwargs - ) - native.alias( - name = name, - actual = if_cuda_is_configured(":%s_cuda_only" % name, ":%s_non_cuda" % name), - visibility = kwargs.get("visibility"), - compatible_with = kwargs.get("compatible_with"), - restricted_to = kwargs.get("restricted_to"), - target_compatible_with = kwargs.get("target_compatible_with"), - ) - def stream_executor_build_defs_bzl_deps(): return [] diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index 2b92b504f2059a..a5e7ac61ccd0e4 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -28,7 +28,7 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" -#include "tsl/lib/gtl/int_type.h" +#include "xla/tsl/lib/gtl/int_type.h" #include "tsl/platform/errors.h" namespace stream_executor { diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 6ff53527b930f4..8c67fcdf158850 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -10,20 +10,14 @@ load( ) load( "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", - "if_cuda_is_configured", "if_cuda_newer_than", ) load( "//xla:xla.bzl", "xla_cc_test", ) -load( - "//xla/service/gpu:build_defs.bzl", - "gpu_kernel_library", -) load( "//xla/stream_executor:build_defs.bzl", - "cuda_only_cc_library", "stream_executor_friends", "tf_additional_cuda_platform_deps", "tf_additional_cudnn_plugin_copts", @@ -87,25 +81,32 @@ cc_library( deps = ["//xla/stream_executor:platform"], ) -cuda_only_cc_library( +cc_library( name = "cuda_platform", srcs = ["cuda_platform.cc"], hdrs = ["cuda_platform.h"], + tags = [ + "cuda-only", + "gpu", + ], visibility = ["//visibility:public"], deps = [ ":cuda_collectives", + ":cuda_diagnostics", # buildcleaner: keep ":cuda_driver", ":cuda_executor", ":cuda_platform_id", ":cuda_runtime", - "//xla/stream_executor", + ":cuda_status", + "//xla/stream_executor:device_description", "//xla/stream_executor:executor_cache", + "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor/gpu:gpu_diagnostics_header", "//xla/stream_executor/gpu:gpu_driver_header", - "//xla/stream_executor/gpu:gpu_executor_header", - "//xla/stream_executor/platform", + "//xla/stream_executor/platform:initialize", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", @@ -116,6 +117,7 @@ cuda_only_cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "@local_config_cuda//cuda:cuda_headers", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", @@ -123,10 +125,14 @@ cuda_only_cc_library( alwayslink = True, # Registers itself with the PlatformManager. ) -cuda_only_cc_library( +cc_library( name = "cuda_diagnostics", srcs = ["cuda_diagnostics.cc"], hdrs = ["cuda_diagnostics.h"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ "//xla/stream_executor/gpu:gpu_diagnostics_header", "@com_google_absl//absl/container:inlined_vector", @@ -157,14 +163,42 @@ cc_library( ), ) -cuda_only_cc_library( +cc_library( + name = "cuda_context", + srcs = ["cuda_context.cc"], + hdrs = ["cuda_context.h"], + tags = [ + "cuda-only", + "gpu", + ], + deps = [ + ":cuda_status", + "//xla/stream_executor/gpu:context", + "//xla/stream_executor/gpu:context_map", + "//xla/stream_executor/gpu:scoped_activate_context", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@local_config_cuda//cuda:cuda_headers", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + ], +) + +cc_library( name = "cuda_driver", srcs = ["cuda_driver.cc"], - hdrs = ["cuda_driver.h"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ + ":cuda_context", ":cuda_diagnostics", # buildcleaner: keep ":cuda_status", - "//xla/stream_executor", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:context", "//xla/stream_executor/gpu:context_map", "//xla/stream_executor/gpu:gpu_diagnostics_header", @@ -175,33 +209,31 @@ cuda_only_cc_library( "//xla/tsl/cuda:cudart", "@com_google_absl//absl/base", "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/debugging:leak_check", - "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_config_cuda//cuda:cuda_headers", "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:numbers", - "@local_tsl//tsl/platform:stacktrace", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", ], ) -cuda_only_cc_library( +cc_library( name = "cuda_status", srcs = ["cuda_status.cc"], hdrs = ["cuda_status.h"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", @@ -210,12 +242,15 @@ cuda_only_cc_library( ], ) -cuda_only_cc_library( +cc_library( name = "cuda_runtime", srcs = ["cuda_runtime.cc"], + hdrs = ["cuda_runtime.h"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ - "//xla/stream_executor/gpu:gpu_runtime_header", - "//xla/stream_executor/gpu:gpu_types_header", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -223,41 +258,97 @@ cuda_only_cc_library( "@com_google_absl//absl/strings", "@local_config_cuda//cuda:cuda_headers", "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", ], ) -cuda_only_cc_library( +cc_library( name = "cuda_collectives", - srcs = ["cuda_collectives.cc"], - defines = if_nccl(["STREAM_EXECUTOR_GPU_ENABLE_XCCL"]), - deps = [ - ":cuda_driver", + hdrs = ["cuda_collectives.h"], + tags = [ + "cuda-only", + "gpu", + ], + deps = if_nccl( + [":cuda_collectives_impl"], + [":cuda_collectives_stub"], + ) + [ + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:context", - "//xla/stream_executor/gpu:gpu_collectives_header", - "//xla/stream_executor/gpu:scoped_activate_context", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "cuda_collectives_impl", + srcs = [ + "cuda_collectives.cc", + "cuda_collectives.h", + ], + tags = [ + "cuda-only", + "gpu", + "manual", + ], + deps = [ + "//xla/stream_executor:activate_context", + "//xla/stream_executor:stream_executor_h", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@local_config_nccl//:nccl", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:numbers", + ], +) + +cc_library( + name = "cuda_collectives_stub", + srcs = [ + "cuda_collectives.h", + "cuda_collectives_stub.cc", + ], + deps = [ + "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor/gpu:context", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +xla_test( + name = "cuda_collectives_test", + srcs = ["cuda_collectives_test.cc"], + backends = ["gpu_any"], + tags = ["cuda-only"], + deps = [ + ":cuda_collectives", + "//xla/service/gpu/runtime:nccl_api", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_h", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", - ] + if_nccl(["@local_config_nccl//:nccl"]), + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], ) xla_test( name = "cuda_driver_test", srcs = ["cuda_driver_test.cc"], backends = ["gpu"], - tags = ["no_rocm"], + tags = [ + "cuda-only", + ], deps = [ + ":cuda_context", ":cuda_diagnostics", - ":cuda_driver", ":cuda_status", "//xla/stream_executor/gpu:gpu_driver_header", "//xla/stream_executor/gpu:gpu_types_header", - "//xla/stream_executor/gpu:scoped_activate_context", + "@com_google_absl//absl/cleanup", "@com_google_absl//absl/log", "@com_google_googletest//:gtest_main", "@local_config_cuda//cuda:cuda_headers", @@ -268,18 +359,23 @@ xla_test( ], ) -cuda_only_cc_library( +cc_library( name = "cublas_lt_header", hdrs = [ "cuda_blas_lt.h", "cuda_blas_utils.h", ], + tags = [ + "cuda-only", + "gpu", + ], visibility = ["//visibility:public"], deps = [ "//xla:types", - "//xla/stream_executor", "//xla/stream_executor:blas", + "//xla/stream_executor:device_memory", "//xla/stream_executor:scratch_allocator", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:gpu_blas_lt", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", @@ -290,7 +386,7 @@ cuda_only_cc_library( ], ) -cuda_only_cc_library( +cc_library( name = "cublas_plugin", srcs = [ "cuda_blas.cc", @@ -300,6 +396,10 @@ cuda_only_cc_library( "cuda_blas.h", "cuda_blas_lt.h", ], + tags = [ + "cuda-only", + "gpu", + ], visibility = ["//visibility:public"], deps = [ ":cuda_blas_utils", @@ -311,24 +411,25 @@ cuda_only_cc_library( "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/stream_executor", + "//xla/stream_executor:activate_context", "//xla/stream_executor:blas", + "//xla/stream_executor:device_description", "//xla/stream_executor:device_memory", "//xla/stream_executor:event_based_timer", "//xla/stream_executor:host_or_device_scalar", "//xla/stream_executor:numeric_options", "//xla/stream_executor:plugin_registry", "//xla/stream_executor:scratch_allocator", - "//xla/stream_executor/gpu:gpu_activation_header", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:gpu_blas_lt", - "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_helpers_header", "//xla/stream_executor/gpu:gpu_stream_header", "//xla/stream_executor/gpu:gpu_types_header", - "//xla/stream_executor/gpu:scoped_activate_context", - "//xla/stream_executor/platform", + "//xla/stream_executor/platform:initialize", "//xla/tsl/cuda:cublas", "//xla/tsl/cuda:cublas_lt", + "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -345,19 +446,21 @@ cuda_only_cc_library( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", - "@local_tsl//tsl/protobuf:dnn_proto_cc", ] + if_static([ "@local_tsl//tsl/platform:tensor_float_32_utils", ]), alwayslink = True, ) -cuda_only_cc_library( +cc_library( name = "cuda_blas_utils", srcs = ["cuda_blas_utils.cc"], hdrs = ["cuda_blas_utils.h"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ - "//xla/stream_executor", "//xla/stream_executor:blas", "//xla/tsl/cuda:cublas", "@com_google_absl//absl/log", @@ -368,24 +471,28 @@ cuda_only_cc_library( ], ) -cuda_only_cc_library( +cc_library( name = "cufft_plugin", srcs = ["cuda_fft.cc"], hdrs = ["cuda_fft.h"], + tags = [ + "cuda-only", + "gpu", + ], visibility = ["//visibility:public"], deps = [ ":cuda_helpers", ":cuda_platform_id", - "//xla/stream_executor", + "//xla/stream_executor:activate_context", + "//xla/stream_executor:device_memory", "//xla/stream_executor:fft", "//xla/stream_executor:plugin_registry", "//xla/stream_executor:scratch_allocator", + "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", - "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_helpers_header", "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:scoped_activate_context", - "//xla/stream_executor/platform", + "//xla/stream_executor/platform:initialize", "//xla/tsl/cuda:cufft", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -397,61 +504,64 @@ cuda_only_cc_library( alwayslink = True, ) -gpu_kernel_library( +cuda_library( name = "delay_kernel_cuda", srcs = [ - "delay_kernel.h", "delay_kernel_cuda.cu.cc", ], - tags = ["manual"], + hdrs = ["delay_kernel.h"], + # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"], + tags = [ + "cuda-only", + "gpu", + ], visibility = internal_visibility([ "//xla/stream_executor:__subpackages__", ]), deps = [ - "//xla/stream_executor", + "//xla/stream_executor:stream", "//xla/stream_executor:typed_kernel_factory", - "//xla/stream_executor/gpu:gpu_driver_header", - "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_semaphore", "@com_google_absl//absl/status:statusor", ], ) -cuda_only_cc_library( +cc_library( name = "cudnn_plugin", srcs = ["cuda_dnn.cc"], hdrs = ["cuda_dnn.h"], copts = tf_additional_cudnn_plugin_copts(), + tags = [ + "cuda-only", + "gpu", + ], visibility = ["//visibility:public"], deps = [ ":cuda_diagnostics", - ":cuda_driver", - ":cuda_executor", ":cuda_platform_id", ":cudnn_frontend_helpers", - "//xla/stream_executor", + "//xla/stream_executor:activate_context", "//xla/stream_executor:data_type", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory", "//xla/stream_executor:dnn", "//xla/stream_executor:event_based_timer", "//xla/stream_executor:numeric_options", "//xla/stream_executor:plugin_registry", "//xla/stream_executor:scratch_allocator", + "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", - "//xla/stream_executor/gpu:gpu_activation_header", "//xla/stream_executor/gpu:gpu_diagnostics_header", - "//xla/stream_executor/gpu:gpu_driver_header", - "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_stream", - "//xla/stream_executor/gpu:scoped_activate_context", - "//xla/stream_executor/platform", + "//xla/stream_executor/platform:initialize", "//xla/tsl/cuda:cudnn", + "//xla/tsl/protobuf:dnn_proto_cc", "//xla/tsl/util:env_var", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -462,27 +572,59 @@ cuda_only_cc_library( "@cudnn_frontend_archive//:cudnn_frontend", "@eigen_archive//:eigen3", "@local_config_cuda//cuda:cuda_headers", - "@local_config_cuda//cuda:cudnn_header", + "@local_config_cuda//cuda:cudnn_header", # build_cleaner: keep "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", "@local_tsl//tsl/platform:tensor_float_32_utils", - "@local_tsl//tsl/protobuf:dnn_proto_cc", ], alwayslink = True, ) -cuda_only_cc_library( +cc_library( name = "cuda_kernel", srcs = ["cuda_kernel.cc"], + hdrs = ["cuda_kernel.h"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ - "//xla/stream_executor", - "//xla/stream_executor/gpu:gpu_driver_header", + ":cuda_status", + "//xla/stream_executor:activate_context", + "//xla/stream_executor:kernel", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:gpu_kernel_header", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@local_config_cuda//cuda:cuda_headers", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + ], +) + +xla_test( + name = "cuda_kernel_test", + srcs = ["cuda_kernel_test.cc"], + backends = ["gpu_any"], + tags = ["cuda-only"], + deps = [ + ":cuda_kernel", + ":cuda_runtime", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor/gpu:gpu_test_kernels_cuda", + "@com_google_googletest//:gtest_main", + "@local_config_cuda//cuda:cuda_headers", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", ], ) @@ -492,7 +634,9 @@ cuda_library( "command_buffer_kernels.cc", "command_buffer_kernels.cu.cc", ], - tags = ["no_rocm"], + hdrs = ["command_buffer_kernels.h"], + # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"], + tags = ["cuda-only"], deps = [ "//xla/stream_executor:kernel_spec", "//xla/stream_executor/gpu:gpu_types_header", @@ -500,31 +644,66 @@ cuda_library( ], ) -# TODO(leary) we likely need to canonicalize/eliminate this. cc_library( name = "cuda_helpers", - textual_hdrs = if_cuda_is_configured(["cuda_helpers.h"]), - deps = if_cuda_is_configured([ + hdrs = ["cuda_helpers.h"], + tags = [ + "cuda-only", + "gpu", + ], + deps = [ "//xla/stream_executor/gpu:gpu_helpers_header", - "@local_config_cuda//cuda:cuda_headers", - ]) + [ "@com_google_absl//absl/log:check", + "@local_config_cuda//cuda:cuda_headers", ], ) -cuda_only_cc_library( +cc_library( name = "cuda_event", srcs = ["cuda_event.cc"], hdrs = ["cuda_event.h"], + tags = [ + "cuda-only", + "gpu", + ], deps = [ - ":cuda_driver", + ":cuda_status", + "//xla/stream_executor:activate_context", "//xla/stream_executor:event", - "//xla/stream_executor/gpu:gpu_event", - "//xla/stream_executor/gpu:scoped_activate_context", + "//xla/stream_executor:stream_executor_h", + "@com_google_absl//absl/base", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@local_config_cuda//cuda:cuda_headers", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) +xla_test( + name = "cuda_event_test", + srcs = ["cuda_event_test.cc"], + backends = ["gpu"], + tags = ["cuda-only"], + deps = [ + ":cuda_event", + ":cuda_executor", + ":cuda_platform_id", + "//xla/stream_executor:event", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_h", + "@com_google_googletest//:gtest_main", + "@local_config_cuda//cuda:cuda_headers", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + +# This target serves to expose a single variable to all other kinds of +# targets and should stay minimal (no dependencies). cc_library( name = "ptx_compiler_support", srcs = ["ptx_compiler_support.cc"], @@ -537,9 +716,14 @@ cc_library( "LIBNVPTXCOMPILER_SUPPORT=false", ], }), +) + +cc_library( + name = "ptx_compiler_helpers", + srcs = ["ptx_compiler_helpers.cc"], + hdrs = ["ptx_compiler_helpers.h"], deps = [ "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", ], ) @@ -565,7 +749,7 @@ cc_library( ], tags = ["manual"], deps = [ - ":ptx_compiler_support", + ":ptx_compiler_helpers", "//xla/stream_executor:semantic_version", "//xla/stream_executor/gpu:gpu_asm_opts", "@com_google_absl//absl/algorithm:container", @@ -578,13 +762,14 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:nvptxcompiler", + "@local_tsl//tsl/platform:logging", ], ) cc_library( name = "ptx_compiler", hdrs = ["ptx_compiler.h"], - tags = ["no_rocm"], + tags = ["cuda-only"], deps = select({ ":libnvptxcompiler_support_enabled": [":ptx_compiler_impl"], "//conditions:default": [":ptx_compiler_stub"], @@ -599,7 +784,7 @@ xla_test( name = "cuda_platform_test", srcs = ["cuda_platform_test.cc"], backends = ["gpu"], - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":cuda_platform", "//xla/stream_executor:platform", @@ -616,7 +801,7 @@ xla_cc_test( name = "ptx_compiler_test", srcs = ["ptx_compiler_test.cc"], tags = [ - "no_rocm", + "cuda-only", # TODO(b/343996893): Figure out whether msan reports a false positive or not. "nomsan", ], @@ -675,7 +860,7 @@ cc_library( ], tags = ["manual"], deps = [ - ":ptx_compiler_support", + ":ptx_compiler_helpers", "//xla/stream_executor/gpu:gpu_asm_opts", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -690,6 +875,7 @@ cc_library( "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:nvjitlink", # buildcleaner: keep "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", ], ) @@ -751,7 +937,7 @@ xla_cc_test( ], ) -cuda_only_cc_library( +cc_library( name = "cuda_asm_compiler", srcs = ["cuda_asm_compiler.cc"], hdrs = ["cuda_asm_compiler.h"], @@ -770,6 +956,10 @@ cuda_only_cc_library( "@cuda_nvcc//:ptxas", ]), # copybara:comment_end + tags = [ + "cuda-only", + "gpu", + ], visibility = internal_visibility([ "//third_party/py/jax:__subpackages__", "//tensorflow/compiler/mlir/tools/kernel_gen:__subpackages__", @@ -781,15 +971,15 @@ cuda_only_cc_library( ":cuda_driver", # buildcleaner: keep ":cuda_status", ":ptx_compiler", + ":ptx_compiler_helpers", ":ptx_compiler_support", "//xla:status_macros", "//xla:util", + "//xla/stream_executor:activate_context", "//xla/stream_executor:device_description", "//xla/stream_executor:semantic_version", "//xla/stream_executor:stream_executor_h", - "//xla/stream_executor/gpu:context", "//xla/stream_executor/gpu:gpu_asm_opts", - "//xla/stream_executor/gpu:scoped_activate_context", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", @@ -804,7 +994,7 @@ cuda_only_cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_config_cuda//cuda:cuda_headers", - "@local_tsl//tsl/platform:cuda_libdevice_path", + "@local_tsl//tsl/platform:cuda_root_path", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", @@ -815,58 +1005,64 @@ cuda_only_cc_library( ], ) -cuda_only_cc_library( +cc_library( name = "cuda_executor", srcs = [ "cuda_executor.cc", - "delay_kernel.h", ], hdrs = [ "cuda_executor.h", ], + tags = [ + "cuda-only", + "gpu", + ], deps = [ - ":cuda_collectives", # buildcleaner: keep - ":cuda_diagnostics", - ":cuda_driver", + ":cuda_collectives", + ":cuda_command_buffer", + ":cuda_context", + ":cuda_driver", # buildcleaner: keep ":cuda_event", # buildcleaner: keep ":cuda_kernel", # buildcleaner: keep ":cuda_platform_id", - ":cuda_runtime", # buildcleaner: keep + ":cuda_runtime", ":cuda_status", + ":cuda_stream", + ":cuda_timer", ":cuda_version_parser", - "//xla/stream_executor", + "//xla/stream_executor:activate_context", "//xla/stream_executor:blas", "//xla/stream_executor:command_buffer", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory", "//xla/stream_executor:dnn", "//xla/stream_executor:event", "//xla/stream_executor:event_based_timer", "//xla/stream_executor:fft", "//xla/stream_executor:host_memory_allocation", + "//xla/stream_executor:kernel", "//xla/stream_executor:kernel_spec", + "//xla/stream_executor:launch_dim", "//xla/stream_executor:memory_allocation", "//xla/stream_executor:module_spec", + "//xla/stream_executor:platform", "//xla/stream_executor:plugin_registry", "//xla/stream_executor:semantic_version", + "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:context", - "//xla/stream_executor/gpu:gpu_collectives_header", "//xla/stream_executor/gpu:gpu_command_buffer", - "//xla/stream_executor/gpu:gpu_diagnostics_header", "//xla/stream_executor/gpu:gpu_driver_header", - "//xla/stream_executor/gpu:gpu_event_header", "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_kernel_header", - "//xla/stream_executor/gpu:gpu_runtime_header", - "//xla/stream_executor/gpu:gpu_semaphore", "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:gpu_timer", "//xla/stream_executor/gpu:gpu_types_header", "//xla/stream_executor/gpu:read_numa_node", - "//xla/stream_executor/integrations:device_mem_allocator", + "//xla/stream_executor/gpu:scoped_activate_context", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", "@com_google_absl//absl/numeric:int128", "@com_google_absl//absl/status", @@ -876,19 +1072,43 @@ cuda_only_cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_config_cuda//cuda:cuda_headers", + "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:fingerprint", "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:numbers", "@local_tsl//tsl/platform:statusor", - ] + if_cuda_is_configured([":delay_kernel_cuda"]), + ], alwayslink = True, ) +xla_test( + name = "cuda_executor_test", + srcs = ["cuda_executor_test.cc"], + backends = ["gpu"], + tags = ["cuda-only"], + deps = [ + ":cuda_executor", + ":cuda_platform", + "//xla/stream_executor:device_description", + "//xla/stream_executor:semantic_version", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/log:check", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + cc_library( name = "all_runtime", copts = tsl_copts(), + tags = [ + "cuda-only", + "gpu", + ], visibility = ["//visibility:public"], deps = [ ":cublas_plugin", @@ -911,14 +1131,13 @@ cc_library( cc_library( name = "stream_executor_cuda", + tags = ["cuda-only"], deps = [ - "//xla/stream_executor", + ":cuda_platform_id", "//xla/stream_executor:dnn", "//xla/stream_executor:platform_manager", "//xla/stream_executor:scratch_allocator", - "//xla/stream_executor/cuda:cuda_platform_id", "//xla/stream_executor/host:host_platform_id", - "//xla/stream_executor/platform:dso_loader", "//xla/stream_executor/rocm:rocm_platform_id", ] + if_google( select({ @@ -926,7 +1145,7 @@ cc_library( # "//tools/cc_target_os:gce": [], # copybara:uncomment_end_and_comment_begin "//conditions:default": [ - "@local_config_cuda//cuda:cudart_static", # buildcleaner: keep + "@local_config_cuda//cuda:cuda_runtime", # buildcleaner: keep ":cuda_platform", ], }), @@ -980,3 +1199,149 @@ cc_test( "@local_tsl//tsl/platform:test", ], ) + +cc_library( + name = "cuda_stream", + srcs = ["cuda_stream.cc"], + hdrs = ["cuda_stream.h"], + tags = [ + "cuda-only", + "gpu", + ], + deps = [ + ":cuda_context", + ":cuda_event", + ":cuda_kernel", + ":cuda_status", + "//xla/stream_executor:activate_context", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:event", + "//xla/stream_executor:event_based_timer", + "//xla/stream_executor:kernel", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_common", + "@com_google_absl//absl/base", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@local_config_cuda//cuda:cuda_headers", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:nvtx_utils", + ], +) + +xla_test( + name = "cuda_stream_test", + srcs = ["cuda_stream_test.cc"], + backends = ["gpu"], + tags = ["cuda-only"], + deps = [ + ":cuda_executor", + ":cuda_platform_id", + ":cuda_stream", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:kernel", + "//xla/stream_executor:kernel_spec", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor:typed_kernel_factory", + "//xla/stream_executor/gpu:gpu_test_kernels_cuda", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "cuda_timer", + srcs = [ + "cuda_timer.cc", + ], + hdrs = ["cuda_timer.h"], + tags = [ + "cuda-only", + "gpu", + ], + deps = [ + ":cuda_event", + ":cuda_status", + ":delay_kernel_cuda", + "//xla/stream_executor:activate_context", + "//xla/stream_executor:event_based_timer", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor/gpu:gpu_semaphore", + "//xla/stream_executor/gpu:gpu_stream", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/time", + "@local_config_cuda//cuda:cuda_headers", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "cuda_timer_test", + srcs = ["cuda_timer_test.cc"], + backends = ["gpu"], + tags = ["cuda-only"], + deps = [ + ":cuda_executor", + ":cuda_platform_id", + ":cuda_timer", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:kernel", + "//xla/stream_executor:kernel_spec", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream", + "//xla/stream_executor:typed_kernel_factory", + "//xla/stream_executor/gpu:gpu_stream", + "//xla/stream_executor/gpu:gpu_test_kernels_cuda", + "@com_google_absl//absl/status", + "@com_google_absl//absl/time", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:casts", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "cuda_command_buffer", + srcs = ["cuda_command_buffer.cc"], + hdrs = ["cuda_command_buffer.h"], + tags = [ + "cuda-only", + "gpu", + ], + deps = [ + ":command_buffer_kernels", + ":cuda_status", + "//xla/stream_executor:typed_kernel_factory", + "//xla/stream_executor/gpu:gpu_command_buffer", + "//xla/stream_executor/gpu:gpu_executor_header", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@local_config_cuda//cuda:cuda_headers", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) diff --git a/third_party/xla/xla/stream_executor/cuda/command_buffer_kernels.cc b/third_party/xla/xla/stream_executor/cuda/command_buffer_kernels.cc index 4ec6543d7a9b9d..06726c83a562cb 100644 --- a/third_party/xla/xla/stream_executor/cuda/command_buffer_kernels.cc +++ b/third_party/xla/xla/stream_executor/cuda/command_buffer_kernels.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "xla/stream_executor/cuda/command_buffer_kernels.h" + #include #include "absl/status/statusor.h" @@ -808,10 +810,6 @@ void* GetSetWhileConditionKernel(); void* GetNoOpKernel(); #endif -} // namespace cuda - -namespace gpu { - // TODO(b/362786589): Remove PTX usage when we only support cuda >= 12.4.1 // See comment at top of this file for why PTX is used for cuda < 12.4.1. absl::StatusOr GetSetIfConditionKernelLoaderSpec() { @@ -880,5 +878,5 @@ absl::StatusOr GetNoOpKernelLoaderSpec() { return spec; } -} // namespace gpu +} // namespace cuda } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/rocm/command_buffer_kernels.cc b/third_party/xla/xla/stream_executor/cuda/command_buffer_kernels.h similarity index 65% rename from third_party/xla/xla/stream_executor/rocm/command_buffer_kernels.cc rename to third_party/xla/xla/stream_executor/cuda/command_buffer_kernels.h index 14893e9a005ade..a610b3ebf4f4be 100644 --- a/third_party/xla/xla/stream_executor/rocm/command_buffer_kernels.cc +++ b/third_party/xla/xla/stream_executor/cuda/command_buffer_kernels.h @@ -13,33 +13,24 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "absl/status/status.h" +#ifndef XLA_STREAM_EXECUTOR_CUDA_COMMAND_BUFFER_KERNELS_H_ +#define XLA_STREAM_EXECUTOR_CUDA_COMMAND_BUFFER_KERNELS_H_ + #include "absl/status/statusor.h" #include "xla/stream_executor/kernel_spec.h" -namespace stream_executor::gpu { - -absl::StatusOr GetSetIfConditionKernelLoaderSpec() { - return absl::UnimplementedError("Unimplemented"); -} - -absl::StatusOr GetSetIfElseConditionKernelLoaderSpec() { - return absl::UnimplementedError("Unimplemented"); -} - -absl::StatusOr GetSetCaseConditionKernelLoaderSpec() { - return absl::UnimplementedError("Unimplemented"); -} +namespace stream_executor::cuda { -absl::StatusOr GetSetForConditionKernelLoaderSpec() { - return absl::UnimplementedError("Unimplemented"); -} +// These are various kernels that update Gpu conditionals based on the device +// memory values, and allow implementing on-device control flow via conditional +// command buffers. +absl::StatusOr GetSetIfConditionKernelLoaderSpec(); +absl::StatusOr GetSetIfElseConditionKernelLoaderSpec(); +absl::StatusOr GetSetCaseConditionKernelLoaderSpec(); +absl::StatusOr GetSetForConditionKernelLoaderSpec(); +absl::StatusOr GetSetWhileConditionKernelLoaderSpec(); +absl::StatusOr GetNoOpKernelLoaderSpec(); -absl::StatusOr GetSetWhileConditionKernelLoaderSpec() { - return absl::UnimplementedError("Unimplemented"); -} +} // namespace stream_executor::cuda -absl::StatusOr GetNoOpKernelLoaderSpec() { - return absl::UnimplementedError("Unimplemented"); -} -} // namespace stream_executor::gpu +#endif // XLA_STREAM_EXECUTOR_CUDA_COMMAND_BUFFER_KERNELS_H_ diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc index d788a8dd077fe6..cc822b46f65a7b 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc @@ -47,17 +47,17 @@ limitations under the License. #include "absl/types/span.h" #include "third_party/gpus/cuda/include/cuda.h" #include "xla/status_macros.h" +#include "xla/stream_executor/activate_context.h" #include "xla/stream_executor/cuda/cuda_status.h" #include "xla/stream_executor/cuda/ptx_compiler.h" +#include "xla/stream_executor/cuda/ptx_compiler_helpers.h" #include "xla/stream_executor/cuda/ptx_compiler_support.h" #include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/gpu/context.h" #include "xla/stream_executor/gpu/gpu_asm_opts.h" -#include "xla/stream_executor/gpu/scoped_activate_context.h" #include "xla/stream_executor/semantic_version.h" #include "xla/stream_executor/stream_executor.h" #include "xla/util.h" -#include "tsl/platform/cuda_libdevice_path.h" +#include "tsl/platform/cuda_root_path.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" @@ -592,9 +592,10 @@ absl::StatusOr> LinkUsingNvlink( } absl::StatusOr> LinkGpuAsm( - stream_executor::CudaComputeCapability cc, gpu::Context* context, + stream_executor::CudaComputeCapability cc, + stream_executor::StreamExecutor* executor, std::vector images) { - gpu::ScopedActivateContext activation(context); + std::unique_ptr activation = executor->Activate(); CUlinkState link_state; CUjit_option options[] = {CU_JIT_TARGET}; diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.h b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.h index 061301d0243b8a..a646882f08a503 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.h @@ -24,7 +24,6 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/gpu/context.h" #include "xla/stream_executor/gpu/gpu_asm_opts.h" #include "xla/stream_executor/semantic_version.h" #include "xla/stream_executor/stream_executor.h" @@ -75,7 +74,8 @@ absl::StatusOr> BundleGpuAsm( // Links multiple relocatable GPU images (e.g. results of ptxas -c) into a // single image. absl::StatusOr> LinkGpuAsm( - stream_executor::CudaComputeCapability cc, gpu::Context* context, + stream_executor::CudaComputeCapability cc, + stream_executor::StreamExecutor* executor, std::vector images); absl::StatusOr> LinkUsingNvlink( diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc b/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc index f6628ab0edfad4..f4fb52a31544c2 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas.cc @@ -37,6 +37,7 @@ limitations under the License. #include "third_party/gpus/cuda/include/driver_types.h" #include "third_party/gpus/cuda/include/library_types.h" #include "third_party/gpus/cuda/include/vector_types.h" +#include "xla/stream_executor/activate_context.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/cuda/cuda_blas_utils.h" #include "xla/stream_executor/cuda/cuda_helpers.h" @@ -44,21 +45,19 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event_based_timer.h" -#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_helpers.h" #include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/scoped_activate_context.h" +#include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/numeric_options.h" #include "xla/stream_executor/platform/initialize.h" -#include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/plugin_registry.h" #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/platform/tensor_float_32_utils.h" -#include "tsl/protobuf/dnn.pb.h" namespace stream_executor { namespace cuda { @@ -66,6 +65,7 @@ namespace cuda { using gpu::AsGpuStreamValue; using gpu::GpuMemory; using gpu::GpuMemoryMutable; +using gpu::GpuStreamHandle; // cuBLAS has interfaces that permit pointers to be passed from either the host // memory space or the device memory space; however, you must instruct it as to @@ -190,7 +190,7 @@ static const char *const kCublasNotInitializedExplanation = bool CUDABlas::Init() { absl::MutexLock lock(&mu_); - gpu::ScopedActivateContext sac{parent_}; + std::unique_ptr activation = parent_->Activate(); cublasStatus_t ret = cublasCreate(&blas_); if (ret != CUBLAS_STATUS_SUCCESS) { LOG(ERROR) << "failed to create cublas handle: " << ToString(ret); @@ -211,7 +211,7 @@ bool CUDABlas::Init() { return true; } -CUDABlas::CUDABlas(gpu::GpuExecutor *parent) +CUDABlas::CUDABlas(StreamExecutor *parent) : parent_(CHECK_NOTNULL(parent)), blas_(nullptr) #if CUDA_VERSION >= 11000 @@ -223,31 +223,32 @@ CUDABlas::CUDABlas(gpu::GpuExecutor *parent) CUDABlas::~CUDABlas() { if (blas_ != nullptr) { - gpu::ScopedActivateContext sac{parent_}; + std::unique_ptr activation = parent_->Activate(); cublasDestroy(blas_); } } bool CUDABlas::SetStream(Stream *stream) { - CHECK(stream != nullptr); - CHECK(AsGpuStreamValue(stream) != nullptr); CHECK(blas_ != nullptr); - gpu::ScopedActivateContext sac{parent_}; + std::unique_ptr activation = parent_->Activate(); - cublasStatus_t ret = cublasSetStream(blas_, AsGpuStreamValue(stream)); - if (ret != CUBLAS_STATUS_SUCCESS) { + auto handle = (stream != nullptr) ? AsGpuStreamValue(stream) : nullptr; + if (auto ret = cublasSetStream(blas_, handle); ret != CUBLAS_STATUS_SUCCESS) { LOG(ERROR) << "failed to set stream for cuBLAS calls: " << ToString(ret); return false; } - return true; } -cudaStream_t CUDABlas::CUDAStream(Stream *stream) { - CHECK(stream != nullptr); - CHECK(AsGpuStreamValue(stream) != nullptr); - gpu::ScopedActivateContext sac{parent_}; - return AsGpuStreamValue(stream); +absl::StatusOr CUDABlas::IsMainStreamSet() const { + absl::MutexLock lock{&mu_}; + CHECK(blas_ != nullptr); + GpuStreamHandle handle{}; + if (auto ret = cublasGetStream(blas_, &handle); + ret != CUBLAS_STATUS_SUCCESS) { + return absl::InternalError("failed to get the current stream value"); + } + return (handle == nullptr); } namespace { @@ -395,7 +396,7 @@ absl::Status CUDABlas::DoBlasInternalImpl(FuncT cublas_func, Stream *stream, } } - gpu::ScopedActivateContext sac{parent_}; + std::unique_ptr activation = parent_->Activate(); ScopedCublasPointerMode pointer_mode{blas_}; if (!pointer_mode.Init(pointer_mode_host ? CUBLAS_POINTER_MODE_HOST : CUBLAS_POINTER_MODE_DEVICE)) { @@ -513,7 +514,7 @@ bool CUDABlas::DoBlasGemv(Stream *stream, blas::Transpose trans, uint64_t m, absl::Status CUDABlas::DoBlasGemm( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64 n, uint64_t k, blas::DataType dtype, const void *alpha, + uint64_t n, uint64_t k, blas::DataType dtype, const void *alpha, const DeviceMemoryBase &a, int lda, const DeviceMemoryBase &b, int ldb, const void *beta, DeviceMemoryBase *c, int ldc, const NumericOptions &numeric_options, blas::CallContext context) { @@ -706,7 +707,7 @@ static absl::Status PopulateProfileFromTimer( absl::Status CUDABlas::DoBlasGemmWithAlgorithm( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, const void *alpha, const DeviceMemoryBase &a, + uint64_t n, uint64_t k, const void *alpha, const DeviceMemoryBase &a, blas::DataType type_a, int lda, const DeviceMemoryBase &b, blas::DataType type_b, int ldb, const void *beta, DeviceMemoryBase *c, blas::DataType type_c, int ldc, blas::ComputationType computation_type, @@ -741,7 +742,7 @@ absl::Status CUDABlas::DoBlasGemmWithAlgorithm( absl::Status CUDABlas::DoBlasGemmStridedBatchedWithAlgorithm( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, const void *alpha, const DeviceMemoryBase &a, + uint64_t n, uint64_t k, const void *alpha, const DeviceMemoryBase &a, blas::DataType type_a, int lda, int64_t stride_a, const DeviceMemoryBase &b, blas::DataType type_b, int ldb, int64_t stride_b, const void *beta, DeviceMemoryBase *c, blas::DataType type_c, int ldc, int64_t stride_c, @@ -910,7 +911,7 @@ T inline CUDAComplexValue(T v) { template absl::Status CUDABlas::DoBlasGemmBatchedInternal( FuncT cublas_func, Stream *stream, blas::Transpose transa, - blas::Transpose transb, uint64_t m, uint64 n, uint64 k, Scalar alpha, + blas::Transpose transb, uint64_t m, uint64_t n, uint64_t k, Scalar alpha, const DeviceMemorySlice &a_ptrs_to_wrappers, int lda, const DeviceMemorySlice &b_ptrs_to_wrappers, int ldb, Scalar beta, const DeviceMemorySlice &c_ptrs_to_wrappers, int ldc, int batch_count, @@ -1022,7 +1023,7 @@ absl::Status CUDABlas::DoBlasGemmBatchedInternal( bool CUDABlas::DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, float alpha, DeviceMemorySlice a_array, + uint64_t n, uint64_t k, float alpha, DeviceMemorySlice a_array, int lda, DeviceMemorySlice b_array, int ldb, float beta, DeviceMemorySlice c_array, int ldc, int batch_count, const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, @@ -1041,7 +1042,7 @@ bool CUDABlas::DoBlasGemmBatched( bool CUDABlas::DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, float alpha, + uint64_t n, uint64_t k, float alpha, DeviceMemorySlice a_array, int lda, DeviceMemorySlice b_array, int ldb, float beta, DeviceMemorySlice c_array, int ldc, int batch_count, @@ -1061,7 +1062,7 @@ bool CUDABlas::DoBlasGemmBatched( bool CUDABlas::DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, float alpha, DeviceMemorySlice a_array, + uint64_t n, uint64_t k, float alpha, DeviceMemorySlice a_array, int lda, DeviceMemorySlice b_array, int ldb, float beta, DeviceMemorySlice c_array, int ldc, int batch_count, const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, @@ -1078,7 +1079,7 @@ bool CUDABlas::DoBlasGemmBatched( bool CUDABlas::DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, double alpha, DeviceMemorySlice a_array, + uint64_t n, uint64_t k, double alpha, DeviceMemorySlice a_array, int lda, DeviceMemorySlice b_array, int ldb, double beta, DeviceMemorySlice c_array, int ldc, int batch_count, const NumericOptions &numeric_options, ScratchAllocator *scratch_allocator, @@ -1096,7 +1097,7 @@ bool CUDABlas::DoBlasGemmBatched( bool CUDABlas::DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, std::complex alpha, + uint64_t n, uint64_t k, std::complex alpha, DeviceMemorySlice> a_array, int lda, DeviceMemorySlice> b_array, int ldb, std::complex beta, DeviceMemorySlice> c_array, @@ -1115,7 +1116,7 @@ bool CUDABlas::DoBlasGemmBatched( bool CUDABlas::DoBlasGemmBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, std::complex alpha, + uint64_t n, uint64_t k, std::complex alpha, DeviceMemorySlice> a_array, int lda, DeviceMemorySlice> b_array, int ldb, std::complex beta, DeviceMemorySlice> c_array, @@ -1133,7 +1134,7 @@ bool CUDABlas::DoBlasGemmBatched( absl::Status CUDABlas::DoBlasGemmStridedBatched( Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64_t m, - uint64_t n, uint64 k, blas::DataType dtype, const void *alpha, + uint64_t n, uint64_t k, blas::DataType dtype, const void *alpha, const DeviceMemoryBase &a, int lda, int64_t stride_a, const DeviceMemoryBase &b, int ldb, int64_t stride_b, const void *beta, DeviceMemoryBase *c, int ldc, int64_t stride_c, int batch_count, @@ -1271,7 +1272,7 @@ absl::Status CUDABlas::DoBlasGemmStridedBatched( bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, + blas::Diagonal diag, uint64_t m, uint64_t n, float alpha, const DeviceMemory &a, int lda, DeviceMemory *b, int ldb) { return DoBlasInternal(cublasStrsm, stream, true /* = pointer_mode_host */, @@ -1282,7 +1283,7 @@ bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side, bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, + blas::Diagonal diag, uint64_t m, uint64_t n, double alpha, const DeviceMemory &a, int lda, DeviceMemory *b, int ldb) { return DoBlasInternal(cublasDtrsm, stream, true /* = pointer_mode_host */, @@ -1293,7 +1294,7 @@ bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side, bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, + blas::Diagonal diag, uint64_t m, uint64_t n, std::complex alpha, const DeviceMemory> &a, int lda, DeviceMemory> *b, int ldb) { @@ -1307,7 +1308,7 @@ bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side, bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side, blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, + blas::Diagonal diag, uint64_t m, uint64_t n, std::complex alpha, const DeviceMemory> &a, int lda, DeviceMemory> *b, int ldb) { @@ -1321,7 +1322,7 @@ bool CUDABlas::DoBlasTrsm(Stream *stream, blas::Side side, bool CUDABlas::DoBlasTrsmBatched(Stream *stream, blas::Side side, blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, + blas::Diagonal diag, uint64_t m, uint64_t n, float alpha, const DeviceMemory &as, int lda, DeviceMemory *bs, int ldb, int batch_count) { @@ -1334,7 +1335,7 @@ bool CUDABlas::DoBlasTrsmBatched(Stream *stream, blas::Side side, bool CUDABlas::DoBlasTrsmBatched(Stream *stream, blas::Side side, blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, + blas::Diagonal diag, uint64_t m, uint64_t n, double alpha, const DeviceMemory &as, int lda, DeviceMemory *bs, int ldb, int batch_count) { @@ -1347,7 +1348,7 @@ bool CUDABlas::DoBlasTrsmBatched(Stream *stream, blas::Side side, bool CUDABlas::DoBlasTrsmBatched(Stream *stream, blas::Side side, blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, + blas::Diagonal diag, uint64_t m, uint64_t n, std::complex alpha, const DeviceMemory *> &as, int lda, @@ -1364,7 +1365,7 @@ bool CUDABlas::DoBlasTrsmBatched(Stream *stream, blas::Side side, bool CUDABlas::DoBlasTrsmBatched(Stream *stream, blas::Side side, blas::UpperLower uplo, blas::Transpose transa, - blas::Diagonal diag, uint64_t m, uint64 n, + blas::Diagonal diag, uint64_t m, uint64_t n, std::complex alpha, const DeviceMemory *> &as, int lda, @@ -1396,16 +1397,7 @@ void initialize_cublas() { PluginRegistry::Instance()->RegisterFactory( kCudaPlatformId, "cuBLAS", [](::stream_executor::StreamExecutor *parent) -> blas::BlasSupport * { - gpu::GpuExecutor *cuda_executor = - dynamic_cast(parent); - if (cuda_executor == nullptr) { - LOG(ERROR) - << "Attempting to initialize an instance of the cuBLAS " - << "support library with a non-CUDA StreamExecutor"; - return nullptr; - } - - CUDABlas *blas = new CUDABlas(cuda_executor); + CUDABlas *blas = new CUDABlas(parent); if (!blas->Init()) { // Note: Init() will log a more specific error. delete blas; diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas.h b/third_party/xla/xla/stream_executor/cuda/cuda_blas.h index 5f69e8b04765a8..fc87558a8d4b59 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas.h @@ -30,16 +30,11 @@ limitations under the License. #include "xla/stream_executor/blas.h" #include "xla/stream_executor/cuda/cuda_blas_lt.h" #include "xla/stream_executor/numeric_options.h" -#include "xla/stream_executor/platform/port.h" +#include "xla/stream_executor/scratch_allocator.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" namespace stream_executor { - -class Stream; - -namespace gpu { -class GpuExecutor; -} // namespace gpu - namespace cuda { // BLAS plugin for CUDA platform via cuBLAS library. @@ -54,7 +49,7 @@ namespace cuda { // Thread-safe post-initialization. class CUDABlas : public blas::BlasSupport { public: - explicit CUDABlas(gpu::GpuExecutor *parent); + explicit CUDABlas(StreamExecutor *parent); // Allocates a cuBLAS handle. bool Init(); @@ -74,9 +69,6 @@ class CUDABlas : public blas::BlasSupport { // invoked before calling into cuBLAS. bool SetStream(Stream *stream) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - // Returns the underlying CUDA stream. - cudaStream_t CUDAStream(Stream *stream); - // A helper function that calls the real cuBLAS function together with error // handling. // @@ -106,7 +98,7 @@ class CUDABlas : public blas::BlasSupport { template absl::Status DoBlasGemmBatchedInternal( FuncT cublas_func, Stream *stream, blas::Transpose transa, - blas::Transpose transb, uint64_t m, uint64 n, uint64 k, Scalar alpha, + blas::Transpose transb, uint64_t m, uint64_t n, uint64_t k, Scalar alpha, const DeviceMemorySlice &a_array, int lda, const DeviceMemorySlice &b_array, int ldb, Scalar beta, const DeviceMemorySlice &c_array, int ldc, int batch_count, @@ -114,11 +106,11 @@ class CUDABlas : public blas::BlasSupport { ScratchAllocator *scratch_allocator); // Guards the cuBLAS handle for this device. - absl::Mutex mu_; + mutable absl::Mutex mu_; - // GpuExecutor which instantiated this CUDABlas. + // StreamExecutor which instantiated this CUDABlas. // Immutable post-initialization. - gpu::GpuExecutor *parent_; + StreamExecutor *parent_; // cuBLAS library handle on the device. cublasHandle_t blas_ ABSL_GUARDED_BY(mu_); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc index fbbcfad52fb3ae..3b381bd8e7a68e 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc @@ -38,6 +38,7 @@ limitations under the License. #include "third_party/gpus/cuda/include/library_types.h" #include "xla/primitive_util.h" #include "xla/status_macros.h" +#include "xla/stream_executor/activate_context.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/cuda/cuda_blas.h" #include "xla/stream_executor/cuda/cuda_blas_utils.h" @@ -46,7 +47,6 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_blas_lt.h" #include "xla/stream_executor/gpu/gpu_helpers.h" #include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/scoped_activate_context.h" #include "xla/stream_executor/stream.h" #include "xla/types.h" #include "xla/util.h" @@ -256,7 +256,8 @@ auto BlasLt::MatmulPlan::GetAlgorithms(size_t max_algorithm_count, cu_preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, max_workspace_size)); - gpu::ScopedActivateContext sac{blas_lt_ref_.parent_}; + std::unique_ptr activation = + blas_lt_ref_.parent_->Activate(); int found_algorithm_count = 0; SE_CUBLAS_RETURN_IF_ERROR(cublasLtMatmulAlgoGetHeuristic( @@ -449,15 +450,12 @@ absl::Status BlasLt::MatmulPlan::DoMatmul( CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, b_scale.opaque())); } - auto isF8Input = [](const auto& desc) { - return desc.type() == CUDA_R_8F_E4M3 || desc.type() == CUDA_R_8F_E5M2; - }; - if (c_scale != nullptr && isF8Input(c_desc_)) { + if (c_scale != nullptr) { TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_C_SCALE_POINTER, c_scale.opaque())); } - if (d_scale != nullptr && isF8Input(d_desc_)) { + if (d_scale != nullptr) { TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, d_scale.opaque())); @@ -505,7 +503,8 @@ absl::Status BlasLt::MatmulPlan::DoMatmul( #endif } - gpu::ScopedActivateContext sac{blas_lt_ref_.parent_}; + std::unique_ptr activation = + blas_lt_ref_.parent_->Activate(); if (palgo != nullptr) { SE_CUBLAS_RETURN_IF_ERROR(cublasLtMatmul( diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h index 3d61c816024af9..47bcc6d34817be 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h @@ -34,13 +34,10 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" #include "xla/stream_executor/scratch_allocator.h" +#include "xla/stream_executor/stream_executor.h" #include "xla/types.h" namespace stream_executor { -namespace gpu { -class GpuExecutor; -} // namespace gpu - namespace cuda { class BlasLt : public gpu::BlasLt { @@ -152,7 +149,7 @@ class BlasLt : public gpu::BlasLt { bool must_swap_operands_; }; // class MatmulPlan - explicit BlasLt(gpu::GpuExecutor* parent) + explicit BlasLt(StreamExecutor* parent) : parent_(parent), blas_lt_(nullptr, cublasLtDestroy) {} absl::Status Init() override; @@ -163,7 +160,7 @@ class BlasLt : public gpu::BlasLt { ~BlasLt() override = default; private: - gpu::GpuExecutor* parent_; + StreamExecutor* parent_; mutable absl::Mutex mu_; Owned blas_lt_ ABSL_GUARDED_BY(mu_); }; diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_collectives.cc b/third_party/xla/xla/stream_executor/cuda/cuda_collectives.cc index 281707de8a11ec..382fd7dc3fba10 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_collectives.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_collectives.cc @@ -13,30 +13,28 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "xla/stream_executor/cuda/cuda_collectives.h" + #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" -#include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_collectives.h" -#include "xla/stream_executor/gpu/scoped_activate_context.h" +#include "third_party/nccl/nccl.h" +#include "xla/stream_executor/activate_context.h" +#include "xla/stream_executor/stream_executor.h" #include "tsl/platform/logging.h" #include "tsl/platform/numbers.h" -#ifdef STREAM_EXECUTOR_GPU_ENABLE_XCCL -#include "third_party/nccl/nccl.h" -#endif // STREAM_EXECUTOR_GPU_ENABLE_XCCL - namespace stream_executor::gpu { -/* static */ absl::StatusOr GpuCollectives::CollectiveMemoryAllocate( - Context* context, uint64_t bytes) { +/* static */ absl::StatusOr CudaCollectives::CollectiveMemoryAllocate( + StreamExecutor* executor, uint64_t bytes) { if (bytes == 0) return nullptr; - ScopedActivateContext activated(context); + std::unique_ptr activation = executor->Activate(); -#ifdef STREAM_EXECUTOR_GPU_ENABLE_XCCL void* ptr = nullptr; ncclResult_t res = ncclMemAlloc(&ptr, bytes); if (res != ncclSuccess) { @@ -46,19 +44,15 @@ namespace stream_executor::gpu { tsl::strings::HumanReadableNumBytes(bytes), bytes, ncclGetErrorString(res), ncclGetLastError(nullptr))); } - VLOG(2) << "Allocated collective memory " << ptr << " for context " << context - << " of " << bytes << " bytes"; + VLOG(2) << "Allocated collective memory " << ptr << " for executor " + << executor << " of " << bytes << " bytes"; return ptr; -#else - return absl::FailedPreconditionError("XLA was compiled without NCCL support"); -#endif } -/* static */ absl::Status GpuCollectives::CollectiveMemoryDeallocate( - Context* context, void* location) { - ScopedActivateContext activation(context); +/* static */ absl::Status CudaCollectives::CollectiveMemoryDeallocate( + StreamExecutor* executor, void* location) { + std::unique_ptr activation = executor->Activate(); -#ifdef STREAM_EXECUTOR_GPU_ENABLE_XCCL ncclResult_t res = ncclMemFree(location); if (res != ncclSuccess) { return absl::InternalError(absl::StrFormat( @@ -67,12 +61,9 @@ namespace stream_executor::gpu { location, ncclGetErrorString(res), ncclGetLastError(nullptr))); } - VLOG(2) << "Deallocated collective memory " << location << " for context " - << context; + VLOG(2) << "Deallocated collective memory " << location << " for executor " + << executor; return absl::OkStatus(); -#else - return absl::FailedPreconditionError("XLA was compiled without NCCL support"); -#endif } } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_collectives.h b/third_party/xla/xla/stream_executor/cuda/cuda_collectives.h similarity index 70% rename from third_party/xla/xla/stream_executor/gpu/gpu_collectives.h rename to third_party/xla/xla/stream_executor/cuda/cuda_collectives.h index 188931312fbe38..bbcf021201d2a8 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_collectives.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_collectives.h @@ -13,33 +13,33 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_COLLECTIVES_H_ -#define XLA_STREAM_EXECUTOR_GPU_GPU_COLLECTIVES_H_ +#ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_COLLECTIVES_H_ +#define XLA_STREAM_EXECUTOR_CUDA_CUDA_COLLECTIVES_H_ #include #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "xla/stream_executor/gpu/context.h" +#include "xla/stream_executor/stream_executor.h" namespace stream_executor::gpu { -struct GpuCollectives { +struct CudaCollectives { // Allocates a collective device memory space of size bytes associated with // the given context. // // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclmemalloc - static absl::StatusOr CollectiveMemoryAllocate(Context* context, - uint64_t bytes); + static absl::StatusOr CollectiveMemoryAllocate( + StreamExecutor *executor, uint64_t bytes); // Deallocates a collective device memory space of size bytes associated with // the given context. // // https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclmemfree - static absl::Status CollectiveMemoryDeallocate(Context* context, - void* location); + static absl::Status CollectiveMemoryDeallocate(StreamExecutor *executor, + void *location); }; } // namespace stream_executor::gpu -#endif // XLA_STREAM_EXECUTOR_GPU_GPU_COLLECTIVES_H_ +#endif // XLA_STREAM_EXECUTOR_CUDA_CUDA_COLLECTIVES_H_ diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_collectives.cc b/third_party/xla/xla/stream_executor/cuda/cuda_collectives_stub.cc similarity index 56% rename from third_party/xla/xla/stream_executor/rocm/rocm_collectives.cc rename to third_party/xla/xla/stream_executor/cuda/cuda_collectives_stub.cc index 2b993bb25295e0..a486cfa4fccc56 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_collectives.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_collectives_stub.cc @@ -17,21 +17,20 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "xla/stream_executor/gpu/gpu_collectives.h" -#include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/cuda/cuda_collectives.h" +#include "xla/stream_executor/stream_executor.h" namespace stream_executor::gpu { -absl::StatusOr GpuCollectives::CollectiveMemoryAllocate(Context* context, - uint64_t bytes) { - return absl::UnimplementedError( - "Feature not supported on ROCm platform (CollectiveMemoryAllocate)"); +/* static */ absl::StatusOr CudaCollectives::CollectiveMemoryAllocate( + StreamExecutor *executor, uint64_t bytes) { + if (bytes == 0) return nullptr; + return absl::FailedPreconditionError("XLA was compiled without NCCL support"); } -absl::Status GpuCollectives::CollectiveMemoryDeallocate(Context* context, - void* location) { - return absl::UnimplementedError( - "Feature not supported on ROCm platform (CollectiveMemoryDeallocate)"); +/* static */ absl::Status CudaCollectives::CollectiveMemoryDeallocate( + StreamExecutor *executor, void *location) { + return absl::FailedPreconditionError("XLA was compiled without NCCL support"); } } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_collectives_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_collectives_test.cc new file mode 100644 index 00000000000000..97a4c81dc5e14f --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/cuda_collectives_test.cc @@ -0,0 +1,58 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/cuda/cuda_collectives.h" + +#include + +#include +#include +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor::gpu { +namespace { +using ::tsl::testing::IsOk; +using ::tsl::testing::IsOkAndHolds; + +TEST(CudaCollectivesTest, CollectiveMemoryAllocation) { + if (!xla::gpu::NcclApi::HasNcclSupport()) { + GTEST_SKIP() << "Compiled without NCCL support"; + } + + TF_ASSERT_OK_AND_ASSIGN(Platform * platform, + PlatformManager::PlatformWithName("CUDA")); + TF_ASSERT_OK_AND_ASSIGN(StreamExecutor * executor, + platform->ExecutorForDevice(0)); + + constexpr size_t kAllocateSize = 1024; + TF_ASSERT_OK_AND_ASSIGN( + void* memory, + CudaCollectives::CollectiveMemoryAllocate(executor, kAllocateSize)); + + EXPECT_THAT(executor->GetPointerMemorySpace(memory), + IsOkAndHolds(MemoryType::kDevice)); + + EXPECT_THAT(CudaCollectives::CollectiveMemoryDeallocate(executor, memory), + IsOk()); +} + +} // namespace +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc new file mode 100644 index 00000000000000..f93104d5068229 --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.cc @@ -0,0 +1,125 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/cuda/cuda_command_buffer.h" + +#include + +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "xla/stream_executor/cuda/command_buffer_kernels.h" +#include "xla/stream_executor/cuda/cuda_status.h" +#include "xla/stream_executor/gpu/gpu_command_buffer.h" +#include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/typed_kernel_factory.h" // IWYU pragma: keep +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace stream_executor::gpu { +namespace { +absl::StatusOr CreateGraph() { + VLOG(2) << "Create new CUDA graph"; + CUgraph graph = nullptr; + TF_RETURN_IF_ERROR(cuda::ToStatus(cuGraphCreate(&graph, /*flags=*/0), + "Failed to create CUDA graph")); + VLOG(2) << "Created CUDA graph " << graph; + return graph; +} +} // namespace + +absl::StatusOr> CudaCommandBuffer::Create( + Mode mode, GpuExecutor* parent) { + TF_ASSIGN_OR_RETURN(CUgraph graph, CreateGraph()); + return std::unique_ptr( + new CudaCommandBuffer(mode, parent, graph, + /*is_owned_graph=*/true)); +} + +absl::StatusOr +CudaCommandBuffer::GetSetIfConditionKernel() { + if (!set_if_condition_kernel_) { + TF_ASSIGN_OR_RETURN(auto spec, cuda::GetSetIfConditionKernelLoaderSpec()); + TF_ASSIGN_OR_RETURN( + set_if_condition_kernel_, + SetIfConditionKernel::FactoryType::Create(parent_, spec)); + } + return &set_if_condition_kernel_; +} + +absl::StatusOr +CudaCommandBuffer::GetSetIfElseConditionKernel() { + if (!set_if_else_condition_kernel_) { + TF_ASSIGN_OR_RETURN(auto spec, + cuda::GetSetIfElseConditionKernelLoaderSpec()); + TF_ASSIGN_OR_RETURN( + set_if_else_condition_kernel_, + SetIfElseConditionKernel::FactoryType::Create(parent_, spec)); + } + return &set_if_else_condition_kernel_; +} + +absl::StatusOr +CudaCommandBuffer::GetSetCaseConditionKernel() { + if (!set_case_condition_kernel_) { + TF_ASSIGN_OR_RETURN(auto spec, cuda::GetSetCaseConditionKernelLoaderSpec()); + TF_ASSIGN_OR_RETURN( + set_case_condition_kernel_, + SetCaseConditionKernel::FactoryType::Create(parent_, spec)); + } + return &set_case_condition_kernel_; +} + +absl::StatusOr +CudaCommandBuffer::GetSetForConditionKernel() { + if (!set_for_condition_kernel_) { + TF_ASSIGN_OR_RETURN(auto spec, cuda::GetSetForConditionKernelLoaderSpec()); + TF_ASSIGN_OR_RETURN( + set_for_condition_kernel_, + SetForConditionKernel::FactoryType::Create(parent_, spec)); + } + return &set_for_condition_kernel_; +} + +absl::StatusOr +CudaCommandBuffer::GetSetWhileConditionKernel() { + if (!set_while_condition_kernel_) { + TF_ASSIGN_OR_RETURN(auto spec, + cuda::GetSetWhileConditionKernelLoaderSpec()); + TF_ASSIGN_OR_RETURN( + set_while_condition_kernel_, + SetWhileConditionKernel::FactoryType::Create(parent_, spec)); + } + return &set_while_condition_kernel_; +} + +absl::StatusOr +CudaCommandBuffer::GetNoOpKernel() { + if (!noop_kernel_) { + TF_ASSIGN_OR_RETURN(auto spec, cuda::GetNoOpKernelLoaderSpec()); + TF_ASSIGN_OR_RETURN(noop_kernel_, + NoOpKernel::FactoryType::Create(parent_, spec)); + } + return &noop_kernel_; +} + +std::unique_ptr CudaCommandBuffer::CreateNestedCommandBuffer( + CUgraph graph) { + return std::unique_ptr( + new CudaCommandBuffer(Mode::kNested, parent_, graph, + /*is_owned_graph=*/false)); +} + +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.h b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.h new file mode 100644 index 00000000000000..9250acd364c1df --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/cuda_command_buffer.h @@ -0,0 +1,67 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_COMMAND_BUFFER_H_ +#define XLA_STREAM_EXECUTOR_CUDA_CUDA_COMMAND_BUFFER_H_ + +#include + +#include "absl/log/log.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "xla/stream_executor/gpu/gpu_command_buffer.h" +#include "xla/stream_executor/gpu/gpu_executor.h" + +namespace stream_executor::gpu { + +// This class implements GpuCommandBuffer for Nvidia GPUs. +class CudaCommandBuffer : public GpuCommandBuffer { + public: + // Creates a new CUDA command buffer and the underlying CUDA graph. + static absl::StatusOr> Create( + Mode mode, GpuExecutor* parent); + + private: + CudaCommandBuffer(Mode mode, GpuExecutor* parent, CUgraph graph, + bool is_owned_graph) + : GpuCommandBuffer(mode, parent, graph, is_owned_graph), + parent_(parent) {} + + absl::StatusOr GetSetIfConditionKernel() override; + absl::StatusOr GetSetIfElseConditionKernel() + override; + absl::StatusOr GetSetCaseConditionKernel() override; + absl::StatusOr GetSetForConditionKernel() override; + absl::StatusOr GetSetWhileConditionKernel() + override; + absl::StatusOr GetNoOpKernel() override; + + std::unique_ptr CreateNestedCommandBuffer( + CUgraph graph) override; + + // Lazy loaded auxiliary kernels required for building CUDA graphs (no-op + // barriers, updating conditional handles, etc.). + SetIfConditionKernel set_if_condition_kernel_; + SetIfElseConditionKernel set_if_else_condition_kernel_; + SetCaseConditionKernel set_case_condition_kernel_; + SetForConditionKernel set_for_condition_kernel_; + SetWhileConditionKernel set_while_condition_kernel_; + NoOpKernel noop_kernel_; + + GpuExecutor* parent_; +}; + +} // namespace stream_executor::gpu + +#endif // XLA_STREAM_EXECUTOR_CUDA_CUDA_COMMAND_BUFFER_H_ diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_context.cc b/third_party/xla/xla/stream_executor/cuda/cuda_context.cc new file mode 100644 index 00000000000000..cce7f2c2d20761 --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/cuda_context.cc @@ -0,0 +1,197 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/cuda/cuda_context.h" + +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "xla/stream_executor/cuda/cuda_status.h" +#include "xla/stream_executor/gpu/context_map.h" +#include "xla/stream_executor/gpu/scoped_activate_context.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status.h" + +namespace stream_executor::gpu { + +namespace { + +// Synchronize with spinlocks. +const char kScheduleSpinString[] = "spin"; +// Synchronize with spinlocks that also call CPU yield instructions. +const char kScheduleYieldString[] = "yield"; +// Synchronize with a "synchronization primitive" (e.g. mutex). +const char kScheduleBlockingSyncString[] = "blocking_sync"; + +int GetFlagsFromEnv() { + const char* gpu_schedule_string = + std::getenv("TF_CUDA_PLATFORM_GPU_DEVICE_SCHEDULE"); + + if (gpu_schedule_string == nullptr) { + return 0; + } + + unsigned device_flags = 0; + if (strcmp(kScheduleSpinString, gpu_schedule_string) == 0) { + device_flags = CU_CTX_SCHED_SPIN; + } else if (strcmp(kScheduleYieldString, gpu_schedule_string) == 0) { + device_flags = CU_CTX_SCHED_YIELD; + } else if (strcmp(kScheduleBlockingSyncString, gpu_schedule_string) == 0) { + device_flags = CU_CTX_SCHED_BLOCKING_SYNC; + } else { + LOG(QFATAL) << "Unknown option for environment variable " + "TF_CUDA_PLATFORM_GPU_DEVICE_SCHEDULE " + << gpu_schedule_string << " should be one of {" + << kScheduleBlockingSyncString << ", " << kScheduleSpinString + << ", " << kScheduleYieldString << "}"; + } + + return device_flags; +} + +// Returns the current context or dies if it fails. +CUcontext CurrentContextOrDie() { + CUcontext current = nullptr; + TF_CHECK_OK(cuda::ToStatus(cuCtxGetCurrent(¤t), + "Failed to query current context")); + return current; +} + +// Returns the current context and checks that it is in the set of CUDA contexts +// created by StreamExecutor (to ensure that the CUDA runtime didn't create a +// context behind our backs). +CUcontext CurrentContext() { + CUcontext current = CurrentContextOrDie(); + if (current != nullptr && !CudaContext::GetContextMap()->Has(current)) { + LOG(FATAL) << "current context was not created by the StreamExecutor " + "cuda_driver API: " + << current + << "; a CUDA runtime call " + "was likely performed without using a StreamExecutor context"; + } + return current; +} + +} // namespace + +// Returns the singleton ContextMap. +ContextMap* CudaContext::GetContextMap() { + static ContextMap* context_map = + new ContextMap([](void* ptr) { + int device_ordinal; + absl::Status status = cuda::ToStatus( + cuPointerGetAttribute(static_cast(&device_ordinal), + CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL, + reinterpret_cast(ptr))); + if (!status.ok()) { + LOG(FATAL) << "Not able to get the device_ordinal for ptr: " << ptr + << ". Error: " << status; + } + return device_ordinal; + }); + return context_map; +} + +CudaContext::~CudaContext() { + auto status = cuda::ToStatus(cuCtxPushCurrent(context())); + if (!status.ok()) { + LOG(ERROR) << "failed to Push CUDA context; leaking: " << status; + } + CUdevice device; + cuCtxGetDevice(&device); + cuCtxPopCurrent(nullptr); + + status = cuda::ToStatus(cuDevicePrimaryCtxRelease(device)); + + if (!status.ok()) { + LOG(ERROR) << "failed to release CUDA context; leaking: " << status; + } + + GetContextMap()->Remove(context()); +} + +absl::StatusOr CudaContext::Create(int device_ordinal, + CUdevice device) { + CudaContext* context = nullptr; + + int flags = GetFlagsFromEnv(); + + unsigned int former_primary_context_flags; + int former_primary_context_is_active; + TF_RETURN_IF_ERROR(cuda::ToStatus( + cuDevicePrimaryCtxGetState(device, &former_primary_context_flags, + &former_primary_context_is_active))); + if (former_primary_context_flags != flags) { + if (former_primary_context_is_active) { + LOG(ERROR) + << "The primary context is active and has a different flag set (" + << former_primary_context_flags << ") than the desired flag set (" + << flags << ")."; + } else { + TF_RETURN_IF_ERROR( + cuda::ToStatus(cuDevicePrimaryCtxSetFlags(device, flags))); + } + } + + CUcontext former_context = CurrentContextOrDie(); + CUcontext new_context; + TF_RETURN_IF_ERROR( + cuda::ToStatus(cuDevicePrimaryCtxRetain(&new_context, device))); + if (former_context != nullptr) { + CUdevice former_device; + if (cuCtxGetDevice(&former_device) == CUDA_SUCCESS) { + if (former_device == device) { + if (former_context == new_context) { + VLOG(2) << "The primary context " << former_context << " for device " + << device + << " exists before initializing the StreamExecutor."; + } else { + LOG(WARNING) << "A non-primary context " << former_context + << " for device " << device + << " exists before initializing the StreamExecutor. The " + << "primary context is now " << new_context << ". We " + << "haven't verified StreamExecutor works with that."; + } + } + } else { + LOG(ERROR) << "Failed to get the device of the current context " + << former_context; + } + } + TF_RETURN_IF_ERROR(cuda::ToStatus(cuCtxSetCurrent(former_context))); + + context = GetContextMap()->Add(new_context, device_ordinal); + CHECK(context != nullptr) + << "success in this call must entail non-null result"; + VLOG(2) << "created or reused context " << new_context << " for this thread"; + return context; +} + +void CudaContext::SetActive() { + TF_CHECK_OK( + cuda::ToStatus(cuCtxSetCurrent(context_), "Failed setting context")); +} + +bool CudaContext::IsActive() const { return CurrentContext() == context_; } + +absl::Status CudaContext::Synchronize() { + ScopedActivateContext activation(this); + return cuda::ToStatus(cuCtxSynchronize()); +} + +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.h b/third_party/xla/xla/stream_executor/cuda/cuda_context.h similarity index 53% rename from third_party/xla/xla/stream_executor/cuda/cuda_driver.h rename to third_party/xla/xla/stream_executor/cuda/cuda_context.h index c467de8bc3990e..bce62d4e6cba85 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_context.h @@ -15,58 +15,50 @@ limitations under the License. // CUDA userspace driver library wrapper functionality. -#ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_DRIVER_H_ -#define XLA_STREAM_EXECUTOR_CUDA_CUDA_DRIVER_H_ +#ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_CONTEXT_H_ +#define XLA_STREAM_EXECUTOR_CUDA_CUDA_CONTEXT_H_ -#include -#include -#include -#include - -#include "absl/container/node_hash_map.h" #include "absl/log/check.h" #include "absl/log/log.h" -#include "absl/synchronization/mutex.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "third_party/gpus/cuda/include/cuda.h" -#include "xla/stream_executor/cuda/cuda_status.h" #include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/gpu/context_map.h" -namespace stream_executor { -namespace gpu { +namespace stream_executor::gpu { -// CUDAContext wraps a cuda CUcontext handle, and includes a unique id. The -// unique id is positive, and ids are not repeated within the process. -class GpuContext : public Context { +// CudaContext implements the Context class for CUDA GPUs. +class CudaContext : public Context { public: - GpuContext(CUcontext context, int device_ordinal) + CudaContext(CUcontext context, int device_ordinal) : context_(context), device_ordinal_(device_ordinal) {} + ~CudaContext() override; void SetActive() override; bool IsActive() const override; CUcontext context() const { return context_; } int device_ordinal() const override { return device_ordinal_; } + absl::Status Synchronize() override; // Disallow copying and moving. - GpuContext(GpuContext&&) = delete; - GpuContext(const GpuContext&) = delete; - GpuContext& operator=(GpuContext&&) = delete; - GpuContext& operator=(const GpuContext&) = delete; + CudaContext(CudaContext&&) = delete; + CudaContext(const CudaContext&) = delete; + CudaContext& operator=(CudaContext&&) = delete; + CudaContext& operator=(const CudaContext&) = delete; + + // Returns a new context for the given device. + static absl::StatusOr Create(int device_ordinal, + CUdevice device); + + // Returns the context map for all XLA-known CUDA contexts. + static ContextMap* GetContextMap(); private: CUcontext const context_; const int device_ordinal_; }; -} // namespace gpu - -namespace cuda { - -using CUDADriver = gpu::GpuDriver; - -using CudaContext = gpu::GpuContext; - -} // namespace cuda -} // namespace stream_executor +} // namespace stream_executor::gpu -#endif // XLA_STREAM_EXECUTOR_CUDA_CUDA_DRIVER_H_ +#endif // XLA_STREAM_EXECUTOR_CUDA_CUDA_CONTEXT_H_ diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc index 6d09f0f627ad91..1ed3f90e53bc10 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc @@ -49,6 +49,7 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "third_party/gpus/cuda/include/driver_types.h" +#include "xla/stream_executor/activate_context.h" #include "xla/stream_executor/cuda/cuda_diagnostics.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/cuda/cudnn_frontend_helpers.h" @@ -57,23 +58,20 @@ limitations under the License. #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/gpu/gpu_diagnostics.h" -#include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/scoped_activate_context.h" #include "xla/stream_executor/numeric_options.h" #include "xla/stream_executor/platform/initialize.h" #include "xla/stream_executor/plugin_registry.h" #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "xla/tsl/util/env_var.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tsl/platform/tensor_float_32_utils.h" -#include "tsl/protobuf/dnn.pb.h" // clang-format off #include "third_party/gpus/cuda/include/library_types.h" @@ -213,16 +211,18 @@ std::string CudnnStatusToString(cudnnStatus_t status) { class CudnnHandle { public: // Takes ownership of the lock to access cuDNN using handle. - CudnnHandle(GpuExecutor* executor, std::unique_ptr lock, + CudnnHandle(StreamExecutor* executor, std::unique_ptr lock, cudnnHandle_t handle) - : context_(executor), lock_(std::move(lock)), handle_(handle) {} + : context_(executor->Activate()), + lock_(std::move(lock)), + handle_(handle) {} // Returns cuDNN handle. To be passed directly to cuDNN APIs, don't keep // a copy. cudnnHandle_t handle() const { return handle_; } private: - gpu::ScopedActivateContext context_; + std::unique_ptr context_; std::unique_ptr lock_; cudnnHandle_t handle_; // Not owned. }; @@ -287,7 +287,7 @@ class CudnnAccess { // The legacy default stream synchronizes with all other streams and it is // therefore a bad idea (performance wise) to call any cuDNN APIs that // enqueue work in the stream. - CudnnHandle GetHandle(GpuExecutor* executor, Stream* stream) { + CudnnHandle GetHandle(StreamExecutor* executor, Stream* stream) { auto lock = std::make_unique(&mutex_); mutex_.AssertHeld(); CUstream cu_stream = stream ? AsGpuStreamValue(stream) : cudaStreamLegacy; @@ -485,10 +485,10 @@ void PreloadCudnnSubLibsHelper(dnn::ConvolutionKind kind) { } // namespace -CudnnSupport::CudnnSupport(GpuExecutor* parent) : parent_(parent) {} +CudnnSupport::CudnnSupport(StreamExecutor* parent) : parent_(parent) {} absl::Status CudnnSupport::Init() { - ScopedActivateContext context(parent_); + std::unique_ptr context = parent_->Activate(); // Peek at the last error to give more information in cases of errors. cudaError_t cuda_error = cudaPeekAtLastError(); @@ -534,7 +534,7 @@ absl::Status CudnnSupport::Init() { LOG(ERROR) << "Could not create cudnn handle: " << CudnnStatusToString(status); int64_t free, total; - GpuDriver::GetDeviceMemoryInfo(parent_->gpu_context(), &free, &total); + parent_->DeviceMemoryUsage(&free, &total); LOG(ERROR) << "Memory usage: " << free << " bytes free, " << total << " bytes total."; @@ -1909,7 +1909,7 @@ absl::StatusOr CudnnRnnParamsDescriptor::Create( class CudnnRnnSequenceTensorDescriptor : public dnn::RnnSequenceTensorDescriptor { - CudnnRnnSequenceTensorDescriptor(GpuExecutor* parent, int max_seq_length, + CudnnRnnSequenceTensorDescriptor(StreamExecutor* parent, int max_seq_length, int batch_size, int data_size, RNNDataDescriptor data_handle, TensorDescriptor handle) @@ -1925,7 +1925,7 @@ class CudnnRnnSequenceTensorDescriptor default; static absl::StatusOr Create( - GpuExecutor* parent, int max_seq_length, int batch_size, int data_size, + StreamExecutor* parent, int max_seq_length, int batch_size, int data_size, cudnnDataType_t data_type) { if (max_seq_length <= 0) { return absl::InvalidArgumentError("max_seq_length <= 0"); @@ -1943,7 +1943,7 @@ class CudnnRnnSequenceTensorDescriptor } static absl::StatusOr Create( - GpuExecutor* parent, int max_seq_length, int batch_size, int data_size, + StreamExecutor* parent, int max_seq_length, int batch_size, int data_size, absl::Span seq_lengths, bool time_major, cudnnDataType_t data_type) { if (max_seq_length <= 0) { @@ -2001,7 +2001,7 @@ class CudnnRnnSequenceTensorDescriptor class CudnnRnnStateTensorDescriptor : public dnn::RnnStateTensorDescriptor { public: - CudnnRnnStateTensorDescriptor(GpuExecutor* parent, int num_layers, + CudnnRnnStateTensorDescriptor(StreamExecutor* parent, int num_layers, int batch_size, int data_size, cudnnDataType_t data_type) : handle_(CreateTensorDescriptor()), @@ -5240,10 +5240,12 @@ absl::StatusOr GetCudnnFlashAttentionF8OperationGraph( .set_uid(next_uid()); amax_s->set_output(true) .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) .set_data_type(cudnn_frontend::DataType_t::FLOAT) .set_uid(next_uid()); amax_o->set_output(true) .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) .set_data_type(cudnn_frontend::DataType_t::FLOAT) .set_uid(next_uid()); @@ -5278,6 +5280,202 @@ absl::StatusOr GetCudnnFlashAttentionF8OperationGraph( #endif } +absl::StatusOr GetCudnnFlashAttentionBackwardF8OperationGraph( + dnn::DnnSupport& dnn_support, const dnn::MatmulTensorDescriptor& q_desc, + const dnn::MatmulTensorDescriptor& k_desc, + const dnn::MatmulTensorDescriptor& p_desc, + const dnn::MatmulTensorDescriptor& v_desc, + const dnn::MatmulTensorDescriptor& do_desc, + const dnn::TensorDescriptor& dq_desc, const dnn::TensorDescriptor& dk_desc, + const dnn::TensorDescriptor& dv_desc, double scale, + dnn::FMHAMaskKind mask_type) { +#if CUDNN_VERSION >= 90100 + if (VLOG_IS_ON(4)) { + VLOG(4) << "\n bmm1_grad_gemm1_rhs(q): " << q_desc.ToString() + << "\n bmm1_grad_gemm2_rhs(k): " << k_desc.ToString() + << "\n bmm2_grad_gemm1_lhs(p): " << p_desc.ToString() + << "\n bmm2_grad_gemm2_rhs(v^t): " << v_desc.ToString() + << "\n d_output(do): " << do_desc.ToString() + << "\n d_bmm1_lhs(dq): " << dq_desc.ToString() + << "\n d_bmm1_rhs(dk): " << dk_desc.ToString() + << "\n d_bmm2_rhs(dv): " << dv_desc.ToString() + << "\n scale: " << scale; + } + using cudnn_frontend::graph::Tensor_attributes; + cudnn_frontend::graph::Graph graph; + if (!(q_desc.type() == k_desc.type() && v_desc.type() == do_desc.type() && + do_desc.type() == dq_desc.type() && dq_desc.type() == dk_desc.type() && + dk_desc.type() == dv_desc.type())) { + return absl::InternalError("Input datatypes do not match."); + } + + auto ioDataType = ToCudnnFrontendDataType(q_desc.type()); + graph.set_compute_data_type(cudnn_frontend::DataType_t::FLOAT) + .set_intermediate_data_type(cudnn_frontend::DataType_t::FLOAT) + .set_io_data_type(ioDataType); + + auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); }; + + std::shared_ptr q = + graph.tensor(Tensor_attributes() + .set_name("Q") + .set_dim(q_desc.GetCudnnCompatibleDimensions(false)) + .set_stride(q_desc.GetCudnnCompatibleStrides(false)) + .set_uid(next_uid()) + .set_data_type(ioDataType)); + std::shared_ptr k = + graph.tensor(Tensor_attributes() + .set_name("K") + .set_dim(k_desc.GetCudnnCompatibleDimensions(false)) + .set_stride(k_desc.GetCudnnCompatibleStrides(false)) + .set_uid(next_uid()) + .set_data_type(ioDataType)); + std::shared_ptr v = + graph.tensor(Tensor_attributes() + .set_name("V") + .set_dim(v_desc.GetCudnnCompatibleDimensions(true)) + .set_stride(v_desc.GetCudnnCompatibleStrides(true)) + .set_uid(next_uid()) + .set_data_type(ioDataType)); + std::shared_ptr o = + graph.tensor(Tensor_attributes() + .set_name("O") + .set_dim(do_desc.GetCudnnCompatibleDimensions(false)) + .set_stride(do_desc.GetCudnnCompatibleStrides(false)) + .set_uid(next_uid()) + .set_data_type(ioDataType)); + std::shared_ptr dO = + graph.tensor(Tensor_attributes() + .set_name("dO") + .set_dim(do_desc.GetCudnnCompatibleDimensions(false)) + .set_stride(do_desc.GetCudnnCompatibleStrides(false)) + .set_uid(next_uid()) + .set_data_type(ioDataType)); + + auto p_dims = p_desc.GetCudnnCompatibleDimensions(false); + auto p_strides = p_desc.GetCudnnCompatibleStrides(false); + std::vector p_reduction_dims(p_dims.begin(), p_dims.end() - 1); + p_reduction_dims.push_back(1); + + // Divide every stride by the last dim value. + std::vector p_reduction_strides; + p_reduction_strides.reserve(p_strides.size()); + int64_t p_reduced_dim_len = p_dims.back(); + for (auto stride : p_strides) { + p_reduction_strides.push_back(stride / p_reduced_dim_len); + } + p_reduction_strides[3] = 1; + std::shared_ptr Stats = + graph.tensor(Tensor_attributes() + .set_name("Stats") + .set_dim(p_reduction_dims) + .set_stride(p_reduction_strides) + .set_uid(next_uid()) + .set_data_type(cudnn_frontend::DataType_t::FLOAT)); + + auto descale_q = + graph.tensor(Tensor_attributes() + .set_name("Descale_Q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_uid(next_uid()) + .set_data_type(cudnn_frontend::DataType_t::FLOAT)); + auto descale_k = graph.tensor_like(descale_q, "Descale_K"); + auto descale_v = graph.tensor_like(descale_q, "Descale_V"); + auto descale_s = graph.tensor_like(descale_q, "Descale_S"); + auto descale_o = graph.tensor_like(descale_q, "Descale_O"); + auto descale_dO = graph.tensor_like(descale_q, "Descale_dO"); + auto descale_dP = graph.tensor_like(descale_q, "Descale_dP"); + + auto scale_s = graph.tensor_like(descale_q, "Scale_S"); + auto scale_dP = graph.tensor_like(descale_q, "Scale_dP"); + auto scale_dQ = graph.tensor_like(descale_q, "Scale_dQ"); + auto scale_dK = graph.tensor_like(descale_q, "Scale_dK"); + auto scale_dV = graph.tensor_like(descale_q, "Scale_dV"); + + descale_k->set_uid(next_uid()); + descale_v->set_uid(next_uid()); + descale_s->set_uid(next_uid()); + descale_o->set_uid(next_uid()); + descale_dO->set_uid(next_uid()); + descale_dP->set_uid(next_uid()); + + scale_s->set_uid(next_uid()); + scale_dP->set_uid(next_uid()); + scale_dQ->set_uid(next_uid()); + scale_dK->set_uid(next_uid()); + scale_dV->set_uid(next_uid()); + + bool is_causal = mask_type == dnn::FMHAMaskKind::CAUSAL; + auto sdpa_fp8_backwards_options = + cudnn_frontend::graph::SDPA_fp8_backward_attributes() + .set_name("sdpa_fp8_backward") + .set_causal_mask(is_causal) + .set_attn_scale(scale); + + auto [dQ, dK, dV, Amax_dQ, Amax_dK, Amax_dV, Amax_dP] = + graph.sdpa_fp8_backward(q, k, v, o, dO, Stats, descale_q, descale_k, + descale_v, descale_o, descale_dO, descale_s, + descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, + scale_dP, sdpa_fp8_backwards_options); + + dQ->set_output(true) + .set_dim(dq_desc.dimensions()) + .set_stride(dq_desc.GetLogicalStrides()) + .set_name("dQ") + .set_uid(next_uid()) + .set_data_type(ioDataType); + dK->set_output(true) + .set_dim(dk_desc.dimensions()) + .set_stride(dk_desc.GetLogicalStrides()) + .set_name("dK") + .set_uid(next_uid()) + .set_data_type(ioDataType); + dV->set_output(true) + .set_dim(dv_desc.dimensions()) + .set_stride(dv_desc.GetLogicalStrides()) + .set_name("dV") + .set_uid(next_uid()) + .set_data_type(ioDataType); + Amax_dQ->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(cudnn_frontend::DataType_t::FLOAT) + .set_uid(next_uid()); + Amax_dK->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(cudnn_frontend::DataType_t::FLOAT) + .set_uid(next_uid()); + Amax_dV->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(cudnn_frontend::DataType_t::FLOAT) + .set_uid(next_uid()); + Amax_dP->set_output(true) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(cudnn_frontend::DataType_t::FLOAT) + .set_uid(next_uid()); + + CudnnGraph cudnnGraph(std::move(graph)); + TF_RETURN_IF_ERROR(cudnnGraph.Prepare( + dnn_support, NumericOptions{/*require_determinism=*/false, + /*allow_tf32=*/true})); + TF_RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, /*plan_id=*/std::nullopt)); + + if (VLOG_IS_ON(4)) { + VLOG(4) << "\b workspace size:" << cudnnGraph.Graph().get_workspace_size(); + VLOG(4) << "\b flash attention f8 operation backward graph: " << graph; + } + + return cudnnGraph; +#else + return absl::UnimplementedError( + "Cudnn flash attention only supported with Cudnn >= 9.1.0"); +#endif +} + absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( dnn::DnnSupport& dnn_support, const dnn::MatmulTensorDescriptor& q_desc, const dnn::MatmulTensorDescriptor& k_desc, @@ -5573,7 +5771,7 @@ class CudnnLegacyConvRunner : public dnn::ConvRunner { public: // Queries the workspace size and constructs a 'CudnnLegacyConvRunner'. static absl::StatusOr Create( - GpuExecutor* parent, Stream* stream, CudnnAccess* cudnn, + StreamExecutor* parent, Stream* stream, CudnnAccess* cudnn, const dnn::AlgorithmDesc& algo, dnn::DataType input_type, dnn::DataType output_type, dnn::ConvolutionKind kind, CudnnTensorDescriptor input_nd, CudnnTensorDescriptor output_nd, @@ -5755,7 +5953,7 @@ class CudnnLegacyConvRunner : public dnn::ConvRunner { private: // Private to prevent passing in the wrong workspace_size. - CudnnLegacyConvRunner(GpuExecutor* parent, CudnnAccess* cudnn, + CudnnLegacyConvRunner(StreamExecutor* parent, CudnnAccess* cudnn, int64_t algo_id, bool tensor_ops_enabled, size_t workspace_size, dnn::DataType input_type, dnn::DataType output_type, dnn::ConvolutionKind kind, @@ -5781,7 +5979,7 @@ class CudnnLegacyConvRunner : public dnn::ConvRunner { return {algo_id_, tensor_ops_enabled_, workspace_size_}; } - GpuExecutor* parent_; + StreamExecutor* parent_; CudnnAccess* cudnn_; int64_t algo_id_; bool tensor_ops_enabled_; @@ -6135,7 +6333,7 @@ class CudnnExecutionPlanRunner } static absl::StatusOr Create( - GpuExecutor* parent, CudnnAccess* cudnn, + StreamExecutor* parent, CudnnAccess* cudnn, cudnn_frontend::ExecutionPlan plan, absl::Span uids, bool need_side_input) { auto workspace_size = static_cast(plan.getWorkspaceSize()); @@ -6151,7 +6349,7 @@ class CudnnExecutionPlanRunner } static absl::StatusOr Create( - GpuExecutor* parent, CudnnAccess* cudnn, + StreamExecutor* parent, CudnnAccess* cudnn, cudnn_frontend::ExecutionPlan plan, absl::Span uids, bool need_side_input, std::vector scalar_input_uids, std::vector scalar_input_values) { @@ -6162,7 +6360,7 @@ class CudnnExecutionPlanRunner } private: - CudnnExecutionPlanRunner(GpuExecutor* parent, CudnnAccess* cudnn, + CudnnExecutionPlanRunner(StreamExecutor* parent, CudnnAccess* cudnn, cudnn_frontend::ExecutionPlan plan, size_t workspace_size, absl::Span uids, bool need_side_input, @@ -6176,7 +6374,7 @@ class CudnnExecutionPlanRunner need_side_input_(need_side_input), scalar_input_uids_(scalar_input_uids), scalar_input_values_(scalar_input_values) {} - GpuExecutor* parent_; + StreamExecutor* parent_; CudnnAccess* cudnn_; cudnn_frontend::ExecutionPlan plan_; size_t workspace_size_; @@ -6190,7 +6388,7 @@ namespace { template absl::Status CreateOpRunners( - Stream* stream, CudnnHandle& cudnn, GpuExecutor* gpu_executor, + Stream* stream, CudnnHandle& cudnn, StreamExecutor* gpu_executor, CudnnAccess* cudnn_access, std::unique_ptr op_graph, dnn::ConvolutionKind kind, dnn::DataType input_type, @@ -6542,7 +6740,7 @@ class CudnnLegacyFusedConvRunner : public dnn::FusedConvRunner { public: // Queries the workspace size and constructs a 'CudnnLegacyFusedConvRunner'. static absl::StatusOr Create( - GpuExecutor* parent, Stream* stream, CudnnAccess* cudnn, + StreamExecutor* parent, Stream* stream, CudnnAccess* cudnn, const dnn::AlgorithmDesc& algo, dnn::DataType input_type, double conv_scale, double side_input_scale, CudnnTensorDescriptor input_nd, CudnnTensorDescriptor output_nd, @@ -6668,7 +6866,7 @@ class CudnnLegacyFusedConvRunner : public dnn::FusedConvRunner { private: // Private to prevent passing in the wrong workspace_size. - CudnnLegacyFusedConvRunner(GpuExecutor* parent, CudnnAccess* cudnn, + CudnnLegacyFusedConvRunner(StreamExecutor* parent, CudnnAccess* cudnn, int64_t algo_id, bool tensor_ops_enabled, size_t workspace_size, dnn::DataType input_type, double conv_scale, double side_input_scale, @@ -6698,7 +6896,7 @@ class CudnnLegacyFusedConvRunner : public dnn::FusedConvRunner { return {algo_id_, tensor_ops_enabled_, workspace_size_}; } - GpuExecutor* parent_; + StreamExecutor* parent_; CudnnAccess* cudnn_; int64_t algo_id_; bool tensor_ops_enabled_; @@ -8274,9 +8472,7 @@ absl::Status CudnnGraph::Execute(Stream& stream, const CudnnSupport& dnn_support = static_cast(*stream.parent()->AsDnn()); RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.execute( - dnn_support.cudnn_ - ->GetHandle(ExtractGpuExecutor(stream.parent()), &stream) - .handle(), + dnn_support.cudnn_->GetHandle(stream.parent(), &stream).handle(), tensor_to_ptr_map, workspace.opaque())); return absl::OkStatus(); } @@ -8289,15 +8485,7 @@ void initialize_cudnn() { PluginRegistry::Instance()->RegisterFactory( cuda::kCudaPlatformId, "cuDNN", [](StreamExecutor* parent) -> dnn::DnnSupport* { - gpu::GpuExecutor* cuda_executor = - dynamic_cast(parent); - if (cuda_executor == nullptr) { - LOG(ERROR) << "Attempting to initialize an instance of the cuDNN " - << "support library with a non-CUDA StreamExecutor"; - return nullptr; - } - - gpu::CudnnSupport* dnn = new gpu::CudnnSupport(cuda_executor); + gpu::CudnnSupport* dnn = new gpu::CudnnSupport(parent); if (!dnn->Init().ok()) { // Note: Init() will log a more specific error. delete dnn; diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h index 3a223731347766..16a08231263500 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h @@ -34,7 +34,8 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/numeric_options.h" -#include "tsl/protobuf/dnn.pb.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/protobuf/dnn.pb.h" #if CUDNN_VERSION >= 8100 #include "third_party/cudnn_frontend/include/cudnn_frontend.h" @@ -43,7 +44,6 @@ limitations under the License. namespace stream_executor { namespace gpu { -class GpuExecutor; class CudnnRnnDescriptor; class CudnnRnnSequenceTensorDescriptor; class CudnnRnnStateTensorDescriptor; @@ -90,7 +90,7 @@ class CudnnGraph : public dnn::DnnGraph { // functions, see dnn.h. class CudnnSupport : public dnn::DnnSupport { public: - explicit CudnnSupport(GpuExecutor* parent); + explicit CudnnSupport(StreamExecutor* parent); absl::Status Init() override; absl::StatusOr GetVersion() override; @@ -564,7 +564,7 @@ class CudnnSupport : public dnn::DnnSupport { // Uses cuDNN handle for execution. friend class CudnnGraph; - GpuExecutor* parent_; // Parent executor object. Not owned. + StreamExecutor* parent_; // Parent executor object. Not owned. // Provides access to the cuDNN handle. std::unique_ptr cudnn_; @@ -732,6 +732,16 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( const dnn::FMHAMaskKind mask_type, bool force_deterministic, const int sliding_window_length); +absl::StatusOr GetCudnnFlashAttentionBackwardF8OperationGraph( + dnn::DnnSupport& dnn_support, const dnn::MatmulTensorDescriptor& q_desc, + const dnn::MatmulTensorDescriptor& k_desc, + const dnn::MatmulTensorDescriptor& p_desc, + const dnn::MatmulTensorDescriptor& v_desc, + const dnn::MatmulTensorDescriptor& do_desc, + const dnn::TensorDescriptor& dq_desc, const dnn::TensorDescriptor& dk_desc, + const dnn::TensorDescriptor& dv_desc, double scale, + dnn::FMHAMaskKind mask_type); + } // namespace gpu } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc index 1308272ff09f89..c9feb76420e225 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc @@ -13,15 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/stream_executor/cuda/cuda_driver.h" - #include #include #include #include #include -#include #include #include #include @@ -29,19 +26,17 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/container/inlined_vector.h" -#include "absl/debugging/leak_check.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" -#include "absl/synchronization/notification.h" #include "absl/types/span.h" #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "third_party/gpus/cuda/include/driver_types.h" +#include "xla/stream_executor/cuda/cuda_context.h" #include "xla/stream_executor/cuda/cuda_status.h" #include "xla/stream_executor/gpu/context.h" #include "xla/stream_executor/gpu/context_map.h" @@ -55,246 +50,13 @@ limitations under the License. #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" -#include "tsl/platform/macros.h" #include "tsl/platform/numbers.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" -#include "tsl/platform/threadpool.h" namespace stream_executor { namespace gpu { -namespace { - -// Returns the device associated with the given context. -absl::StatusOr DeviceFromContext(Context* context) { - ScopedActivateContext activated{context}; - CUdevice device = -1; - auto status = cuda::ToStatus(cuCtxGetDevice(&device)); - if (status.ok()) { - return device; - } - - return status; -} - -CUcontext CurrentContextOrDie() { - CUcontext current = nullptr; - TF_CHECK_OK(cuda::ToStatus(cuCtxGetCurrent(¤t), - "Failed to query current context")); - return current; -} - -// Returns the singleton ContextMap. -ContextMap* GetContextMap() { - static ContextMap* context_map = - new ContextMap([](void* ptr) { - int device_ordinal; - absl::Status status = cuda::ToStatus( - cuPointerGetAttribute(static_cast(&device_ordinal), - CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL, - reinterpret_cast(ptr))); - if (!status.ok()) { - LOG(FATAL) << "Not able to get the device_ordinal for ptr: " << ptr - << ". Error: " << status; - } - return device_ordinal; - }); - return context_map; -} - -// Returns the current context and checks that it is in the set of CUDA contexts -// created by StreamExecutor (to ensure that the CUDA runtime didn't create a -// context behind our backs). -CUcontext CurrentContext() { - CUcontext current = CurrentContextOrDie(); - if (current != nullptr && !GetContextMap()->Has(current)) { - LOG(FATAL) << "current context was not created by the StreamExecutor " - "cuda_driver API: " - << current - << "; a CUDA runtime call " - "was likely performed without using a StreamExecutor context"; - } - return current; -} - -// CUDA driver routines may require a large amount of stack (particularly -// cuModuleLoadDataEx, in our experience). To avoid stack overflow when using -// stack-limited threads (such as those spawned by a default-argument -// thread::ThreadPool on some platforms), we run certain routines in this pool -// and wait for completion. -tsl::thread::ThreadPool* GetDriverExecutor() { - static tsl::thread::ThreadPool* thread_pool = new tsl::thread::ThreadPool( - tsl::Env::Default(), tsl::ThreadOptions(), "cuda_driver", 1); - return thread_pool; -} - -} // namespace - -void GpuContext::SetActive() { - TF_CHECK_OK( - cuda::ToStatus(cuCtxSetCurrent(context_), "Failed setting context")); -} - -bool GpuContext::IsActive() const { return CurrentContext() == context_; } - -namespace { - -// Actually performs the work of CUDA initialization. Wrapped up in one-time -// execution guard. -static absl::Status InternalInit() { - absl::Status status = - cuda::ToStatus(cuInit(0 /* = flags */), "Failed call to cuInit"); - if (status.ok()) { - return status; - } - - LOG(ERROR) << "failed call to cuInit: " << status; - - Diagnostician::LogDiagnosticInformation(); - return status; -} - -// Synchronize with spinlocks. -const char kScheduleSpinString[] = "spin"; -// Synchronize with spinlocks that also call CPU yield instructions. -const char kScheduleYieldString[] = "yield"; -// Synchronize with a "synchronization primitive" (e.g. mutex). -const char kScheduleBlockingSyncString[] = "blocking_sync"; - -int GetFlagsFromEnv() { - const char* gpu_schedule_string = - std::getenv("TF_CUDA_PLATFORM_GPU_DEVICE_SCHEDULE"); - - if (gpu_schedule_string == nullptr) { - return 0; - } - - unsigned device_flags = 0; - if (strcmp(kScheduleSpinString, gpu_schedule_string) == 0) { - device_flags = CU_CTX_SCHED_SPIN; - } else if (strcmp(kScheduleYieldString, gpu_schedule_string) == 0) { - device_flags = CU_CTX_SCHED_YIELD; - } else if (strcmp(kScheduleBlockingSyncString, gpu_schedule_string) == 0) { - device_flags = CU_CTX_SCHED_BLOCKING_SYNC; - } else { - LOG(QFATAL) << "Unknown option for environment variable " - "TF_CUDA_PLATFORM_GPU_DEVICE_SCHEDULE " - << gpu_schedule_string << " should be one of {" - << kScheduleBlockingSyncString << ", " << kScheduleSpinString - << ", " << kScheduleYieldString << "}"; - } - - return device_flags; -} - -} // namespace - -absl::Status GpuDriver::Init() { - // Cached return value from calling InternalInit(), as cuInit need only be - // called once, but GpuDriver::Init may be called many times. - static absl::Status* init_retval = [] { - return new absl::Status(InternalInit()); - }(); - return *init_retval; -} - -absl::Status GpuDriver::GetDevice(int device_ordinal, CUdevice* device) { - return cuda::ToStatus(cuDeviceGet(device, device_ordinal), - "Failed call to cuDeviceGet"); -} - -absl::Status GpuDriver::GetDeviceName(CUdevice device, - std::string* device_name) { - static const size_t kCharLimit = 64; - absl::InlinedVector chars(kCharLimit); - TF_RETURN_IF_ERROR( - cuda::ToStatus(cuDeviceGetName(chars.begin(), kCharLimit - 1, device), - "Failed to get device name")); - chars[kCharLimit - 1] = '\0'; - *device_name = chars.begin(); - return absl::OkStatus(); -} - -absl::Status GpuDriver::CreateContext(int device_ordinal, CUdevice device, - Context** context) { - *context = nullptr; - - int flags = GetFlagsFromEnv(); - - unsigned int former_primary_context_flags; - int former_primary_context_is_active; - TF_RETURN_IF_ERROR(cuda::ToStatus( - cuDevicePrimaryCtxGetState(device, &former_primary_context_flags, - &former_primary_context_is_active))); - if (former_primary_context_flags != flags) { - if (former_primary_context_is_active) { - LOG(ERROR) - << "The primary context is active and has a different flag set (" - << former_primary_context_flags << ") than the desired flag set (" - << flags << ")."; - } else { - TF_RETURN_IF_ERROR( - cuda::ToStatus(cuDevicePrimaryCtxSetFlags(device, flags))); - } - } - - CUcontext former_context = CurrentContextOrDie(); - CUcontext new_context; - TF_RETURN_IF_ERROR( - cuda::ToStatus(cuDevicePrimaryCtxRetain(&new_context, device))); - if (former_context != nullptr) { - CUdevice former_device; - if (cuCtxGetDevice(&former_device) == CUDA_SUCCESS) { - if (former_device == device) { - if (former_context == new_context) { - VLOG(2) << "The primary context " << former_context << " for device " - << device - << " exists before initializing the StreamExecutor."; - } else { - LOG(WARNING) << "A non-primary context " << former_context - << " for device " << device - << " exists before initializing the StreamExecutor. The " - << "primary context is now " << new_context << ". We " - << "haven't verified StreamExecutor works with that."; - } - } - } else { - LOG(ERROR) << "Failed to get the device of the current context " - << former_context; - } - } - TF_RETURN_IF_ERROR(cuda::ToStatus(cuCtxSetCurrent(former_context))); - - *context = GetContextMap()->Add(new_context, device_ordinal); - CHECK(*context != nullptr) - << "success in this call must entail non-null result"; - VLOG(2) << "created or reused context " << new_context << " for this thread"; - return absl::OkStatus(); -} - -void GpuDriver::DestroyContext(Context* context) { - if (context == nullptr) { - return; - } - GpuContext* cuda_context = tensorflow::down_cast(context); - auto status = cuda::ToStatus(cuCtxPushCurrent(cuda_context->context())); - if (!status.ok()) { - LOG(ERROR) << "failed to Push CUDA context; leaking: " << status; - } - CUdevice device; - cuCtxGetDevice(&device); - cuCtxPopCurrent(nullptr); - - status = cuda::ToStatus(cuDevicePrimaryCtxRelease(device)); - - if (!status.ok()) { - LOG(ERROR) << "failed to release CUDA context; leaking: " << status; - } - - GetContextMap()->Remove(cuda_context->context()); -} - absl::Status GpuDriver::CreateGraph(CUgraph* graph) { VLOG(2) << "Create new CUDA graph"; TF_RETURN_IF_ERROR(cuda::ToStatus(cuGraphCreate(graph, /*flags=*/0), @@ -526,22 +288,6 @@ absl::StatusOr GpuDriver::GraphDebugDotPrint( return std::string(path); } -absl::Status GpuDriver::DeviceGraphMemTrim(CUdevice device) { - VLOG(2) << "Trim CUDA device graph memory " << device; - return cuda::ToStatus(cuDeviceGraphMemTrim(device), - "Failed to trim device graph memory"); -} - -absl::StatusOr GpuDriver::StreamIsCapturing(CUstream stream) { - VLOG(2) << "Checking if stream " << stream << " is capturing"; - - CUstreamCaptureStatus status; - TF_RETURN_IF_ERROR(cuda::ToStatus(cuStreamIsCapturing(stream, &status), - "Failed to check stream capturing status")); - - return status == CU_STREAM_CAPTURE_STATUS_ACTIVE; -} - absl::Status GpuDriver::GraphConditionalHandleCreate( GpuGraphConditionalHandle* handle, CUgraph graph, Context* context, unsigned int default_launch_value, unsigned int flags) { @@ -553,7 +299,8 @@ absl::Status GpuDriver::GraphConditionalHandleCreate( #if CUDA_VERSION >= 12030 return cuda::ToStatus( cuGraphConditionalHandleCreate( - handle, graph, tensorflow::down_cast(context)->context(), + handle, graph, + tensorflow::down_cast(context)->context(), default_launch_value, flags), "Failed to create conditional handle for a CUDA graph"); #else @@ -584,8 +331,8 @@ absl::StatusOr GpuDriver::GraphAddNode( CUgraphNodeParams cu_params; memset(&cu_params, 0, sizeof(cu_params)); - GpuContext* gpu_context = - tensorflow::down_cast(conditional->context); + CudaContext* gpu_context = + tensorflow::down_cast(conditional->context); cu_params.type = CU_GRAPH_NODE_TYPE_CONDITIONAL; cu_params.conditional.handle = conditional->handle; @@ -713,7 +460,7 @@ absl::Status GpuDriver::GraphAddMemcpyD2DNode( Context* context, CUgraphNode* node, CUgraph graph, absl::Span deps, CUdeviceptr gpu_dst, CUdeviceptr gpu_src, uint64_t size) { - GpuContext* gpu_context = tensorflow::down_cast(context); + CudaContext* gpu_context = tensorflow::down_cast(context); VLOG(2) << "Add memcpy d2d node to a graph " << graph << "; dst: " << reinterpret_cast(gpu_dst) << "; src: " << reinterpret_cast(gpu_src) << "; size: " << size @@ -739,8 +486,8 @@ absl::Status GpuDriver::GraphAddMemcpyD2DNode( absl::Status GpuDriver::GraphExecMemcpyD2DNodeSetParams( Context* context, GpuGraphExecHandle exec, GpuGraphNodeHandle node, - GpuDevicePtr gpu_dst, GpuDevicePtr gpu_src, uint64_t size) { - GpuContext* gpu_context = tensorflow::down_cast(context); + CUdeviceptr gpu_dst, CUdeviceptr gpu_src, uint64_t size) { + CudaContext* gpu_context = tensorflow::down_cast(context); VLOG(2) << "Set memcpy d2d node params " << node << " in graph executable " << exec << "; dst: " << reinterpret_cast(gpu_dst) << "; src: " << reinterpret_cast(gpu_src) << "; size: " << size @@ -799,7 +546,7 @@ absl::Status GpuDriver::GraphAddMemsetNode( absl::Span deps, CUdeviceptr dst, std::variant bit_pattern, uint64_t num_elements) { - GpuContext* gpu_context = tensorflow::down_cast(context); + CudaContext* gpu_context = tensorflow::down_cast(context); VLOG(2) << "Add memset node to a graph " << graph << "; dst: " << reinterpret_cast(dst) << "; bit_pattern: " << std::visit(BitPatternToString(), bit_pattern) @@ -829,7 +576,7 @@ absl::Status GpuDriver::GraphExecMemsetNodeSetParams( Context* context, CUgraphExec exec, CUgraphNode node, CUdeviceptr dst, std::variant bit_pattern, uint64_t num_elements) { - GpuContext* gpu_context = tensorflow::down_cast(context); + CudaContext* gpu_context = tensorflow::down_cast(context); VLOG(2) << "Set memset node params " << node << " in graph executable " << exec << "; dst: " << reinterpret_cast(dst) << "; bit_pattern: " << std::visit(BitPatternToString(), bit_pattern) @@ -874,581 +621,6 @@ absl::Status GpuDriver::GraphAddChildNode(CUgraphNode* node, CUgraph graph, "Failed to set CUDA graph child node params"); } -absl::Status GpuDriver::LaunchKernel( - Context* context, absl::string_view kernel_name, CUfunction function, - unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, - unsigned int block_dim_x, unsigned int block_dim_y, - unsigned int block_dim_z, unsigned int shared_mem_bytes, CUstream stream, - void** kernel_params, void** extra) { - ScopedActivateContext activation(context); - VLOG(2) << "launching kernel: " << kernel_name << "; gdx: " << grid_dim_x - << " gdy: " << grid_dim_y << " gdz: " << grid_dim_z - << " bdx: " << block_dim_x << " bdy: " << block_dim_y - << " bdz: " << block_dim_z - << "; shared_mem_bytes: " << shared_mem_bytes; - - // TODO(ezhulenev): Why do we do it on every call to launch kernel? This - // should be moved one level up to se::Kernel level, and done just once (or - // updated once we get a new larger shared memory request). - if (shared_mem_bytes != 0) { - TF_RETURN_IF_ERROR(cuda::ToStatus( - cuFuncSetAttribute(function, - CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, - shared_mem_bytes), - "Failed to set shared memory size")); - } - - return cuda::ToStatus( - cuLaunchKernel(function, grid_dim_x, grid_dim_y, grid_dim_z, block_dim_x, - block_dim_y, block_dim_z, shared_mem_bytes, stream, - kernel_params, extra), - absl::StrCat("Failed to launch CUDA kernel: ", kernel_name, - "; block dims: ", block_dim_x, "x", block_dim_y, "x", - block_dim_z, "; grid dims: ", grid_dim_x, "x", grid_dim_y, - "x", grid_dim_z, - "; shared memory size: ", shared_mem_bytes)); -} - -absl::Status GpuDriver::LaunchKernel( - Context* context, absl::string_view kernel_name, GpuFunctionHandle function, - unsigned int cluster_dim_x, unsigned int cluster_dim_y, - unsigned int cluster_dim_z, unsigned int grid_dim_x, - unsigned int grid_dim_y, unsigned int grid_dim_z, unsigned int block_dim_x, - unsigned int block_dim_y, unsigned int block_dim_z, - unsigned int shared_mem_bytes, GpuStreamHandle stream, void** kernel_params, - void** extra) { - ScopedActivateContext activation(context); - VLOG(2) << "launching kernel: " << kernel_name << "; cdx: " << cluster_dim_x - << " cdy: " << cluster_dim_y << " cdz: " << cluster_dim_z - << " gdx: " << grid_dim_x << " gdy: " << grid_dim_y - << " gdz: " << grid_dim_z << " bdx: " << block_dim_x - << " bdy: " << block_dim_y << " bdz: " << block_dim_z - << "; shared_mem_bytes: " << shared_mem_bytes; - - // TODO(ezhulenev): Why do we do it on every call to launch kernel? This - // should be moved one level up to se::Kernel level, and done just once (or - // updated once we get a new larger shared memory request). - if (shared_mem_bytes != 0) { - TF_RETURN_IF_ERROR(cuda::ToStatus( - cuFuncSetAttribute(function, - CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, - shared_mem_bytes), - "Failed to set shared memory size")); - } - - CUlaunchConfig launch_config; - memset(&launch_config, 0, sizeof(launch_config)); - launch_config.blockDimX = block_dim_x; - launch_config.blockDimY = block_dim_y; - launch_config.blockDimZ = block_dim_z; - launch_config.gridDimX = grid_dim_x; - launch_config.gridDimY = grid_dim_y; - launch_config.gridDimZ = grid_dim_z; - launch_config.hStream = stream; - launch_config.sharedMemBytes = shared_mem_bytes; - - CUlaunchAttribute cluster_dims; - memset(&cluster_dims, 0, sizeof(cluster_dims)); - cluster_dims.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; - cluster_dims.value.clusterDim.x = cluster_dim_x; - cluster_dims.value.clusterDim.y = cluster_dim_y; - cluster_dims.value.clusterDim.z = cluster_dim_z; - - launch_config.attrs = &cluster_dims; - launch_config.numAttrs = 1; - - return cuda::ToStatus( - cuLaunchKernelEx(&launch_config, function, kernel_params, extra), - absl::StrCat("Failed to launch CUDA kernel: ", kernel_name, - "; cluster dims: ", cluster_dim_x, "x", cluster_dim_y, "x", - cluster_dim_z, "; block dims: ", block_dim_x, "x", - block_dim_y, "x", block_dim_z, "; grid dims: ", grid_dim_x, - "x", grid_dim_y, "x", grid_dim_z, - "; shared memory size: ", shared_mem_bytes)); -} - -absl::Status GpuDriver::LoadCubin(Context* context, const char* cubin_bytes, - CUmodule* module) { - ScopedActivateContext activation(context); - return cuda::ToStatus( - cuModuleLoadFatBinary(module, cubin_bytes), - "Failed to load in-memory CUBIN (compiled for a different GPU?)."); -} - -absl::Status GpuDriver::LoadPtx(Context* context, const char* ptx_contents, - CUmodule* module) { - absl::Notification notification; - absl::Status ret = absl::OkStatus(); - GetDriverExecutor()->Schedule( - [context, ptx_contents, module, &ret, ¬ification]() { - ScopedActivateContext activation(context); - void* ptx_data = const_cast(ptx_contents); - static const unsigned int kLogBufferBytesLimit = 1024; - unsigned int error_log_buffer_bytes = kLogBufferBytesLimit; - unsigned int info_log_buffer_bytes = kLogBufferBytesLimit; - absl::InlinedVector error_log_buffer(error_log_buffer_bytes); - absl::InlinedVector info_log_buffer(info_log_buffer_bytes); - bool log_verbose = true; - CUjit_option options[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, - CU_JIT_ERROR_LOG_BUFFER, - CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, - CU_JIT_INFO_LOG_BUFFER, CU_JIT_LOG_VERBOSE}; - // Note that the driver API wants the contents of this values to be - // stored in an array of void*s, so we coerce them accordingly. - void* option_values[] = { - absl::bit_cast(uintptr_t(error_log_buffer_bytes)), - absl::bit_cast(error_log_buffer.data()), - absl::bit_cast(uintptr_t(info_log_buffer_bytes)), - absl::bit_cast(info_log_buffer.data()), - absl::bit_cast(uintptr_t(log_verbose))}; - CHECK(TF_ARRAYSIZE(options) == TF_ARRAYSIZE(option_values)); - - absl::Status status; - { - // TODO(leary) Need to see if NVIDIA can expunge the leakiness in - // their module loading: see http://b/13248943 - absl::LeakCheckDisabler disabler; - status = cuda::ToStatus(cuModuleLoadDataEx( - module, ptx_data, TF_ARRAYSIZE(options), options, option_values)); - } - - // The PTX JIT mutates the values in the option values array to reflect - // the size of the logs it output; now that we've made the call, read - // the values back out. - error_log_buffer_bytes = reinterpret_cast(option_values[0]); - info_log_buffer_bytes = reinterpret_cast(option_values[2]); - CHECK_LE(error_log_buffer_bytes, kLogBufferBytesLimit); - CHECK_LE(info_log_buffer_bytes, kLogBufferBytesLimit); - - if (!status.ok()) { - LOG(ERROR) << "failed to load PTX text as a module: " << status; - // As a precaution for null termination of the API-provided value, - // ensure that at least the last byte is null. - error_log_buffer[error_log_buffer_bytes ? error_log_buffer_bytes - 1 - : 0] = '\0'; - LOG(ERROR) << "error log buffer (" << error_log_buffer_bytes - << " bytes): " << error_log_buffer.data(); - if (absl::StrContains(error_log_buffer.data(), - "Register allocation failed")) { - ret = absl::ResourceExhaustedError( - absl::StrFormat("Failed to load PTX text as a module (register " - "allocation failed): %s", - status.ToString())); - } else { - ret = status; - } - notification.Notify(); - return; - } - - VLOG(3) << "PTX compilation info log (" << info_log_buffer_bytes - << " bytes): " << info_log_buffer.data(); - VLOG(3) << "PTX compilation error log (" << error_log_buffer_bytes - << " bytes): " << error_log_buffer.data(); - CHECK(module != nullptr); - notification.Notify(); - }); - notification.WaitForNotification(); - - return ret; -} - -absl::Status GpuDriver::LoadHsaco(Context* context, const char* hsaco_contents, - CUmodule* module) { - return absl::InternalError( - "Feature not supported on CUDA platform (LoadHsaco)"); -} - -absl::Status GpuDriver::SynchronousMemsetUint8(Context* context, - CUdeviceptr location, - uint8_t value, size_t size) { - ScopedActivateContext activation(context); - return cuda::ToStatus(cuMemsetD8(location, value, size), - "Failed to memset memory"); -} - -absl::Status GpuDriver::SynchronousMemsetUint32(Context* context, - CUdeviceptr location, - uint32_t value, - size_t uint32_count) { - ScopedActivateContext activation(context); - return cuda::ToStatus(cuMemsetD32(location, value, uint32_count), - "Failed to memset memory"); -} - -absl::Status GpuDriver::AsynchronousMemsetUint8(Context* context, - CUdeviceptr location, - uint8_t value, - size_t uint32_count, - CUstream stream) { - ScopedActivateContext activation(context); - return cuda::ToStatus(cuMemsetD8Async(location, value, uint32_count, stream), - "Failed to enqueue async memset operation"); -} - -absl::Status GpuDriver::AsynchronousMemsetUint32(Context* context, - CUdeviceptr location, - uint32_t value, - size_t uint32_count, - CUstream stream) { - ScopedActivateContext activation(context); - return cuda::ToStatus(cuMemsetD32Async(location, value, uint32_count, stream), - "Failed to enqueue async memset operation"); -} - -absl::Status GpuDriver::AddStreamCallback(Context* context, CUstream stream, - StreamCallback callback, void* data) { - // Note: flags param is required to be zero according to CUDA 6.0. - return cuda::ToStatus(cuLaunchHostFunc(stream, callback, data)); -} - -absl::Status GpuDriver::GetModuleFunction(Context* context, CUmodule module, - const char* kernel_name, - CUfunction* function) { - ScopedActivateContext activated{context}; - CHECK(module != nullptr && kernel_name != nullptr); - cudaError_t cuda_error = cudaPeekAtLastError(); - if (cuda_error != cudaSuccess) { - return absl::InternalError( - absl::StrCat("There was an error before calling cuModuleGetFunction (", - cuda_error, "): ", cudaGetErrorName(cuda_error), " : ", - cudaGetErrorString(cuda_error))); - } - return cuda::ToStatus(cuModuleGetFunction(function, module, kernel_name), - "Failed to get module function"); -} - -absl::Status GpuDriver::GetModuleSymbol(Context* context, CUmodule module, - const char* symbol_name, - CUdeviceptr* dptr, size_t* bytes) { - ScopedActivateContext activated{context}; - CHECK(module != nullptr && symbol_name != nullptr && - (dptr != nullptr || bytes != nullptr)); - return cuda::ToStatus( - cuModuleGetGlobal(dptr, bytes, module, symbol_name), - absl::StrCat("Failed to get symbol '", symbol_name, "'")); -} - -void GpuDriver::UnloadModule(Context* context, CUmodule module) { - ScopedActivateContext activated{context}; - auto status = cuda::ToStatus(cuModuleUnload(module)); - if (!status.ok()) { - LOG(ERROR) << "failed to unload module " << module - << "; leaking: " << status; - } -} - -absl::StatusOr GpuDriver::CreateStream(Context* context, - int priority) { - ScopedActivateContext activated(context); - GpuStreamHandle stream; - // If the priority is 0, then use the previous api to create the stream with - // the default priority for backward compatibility. Probably there is no - // difference in using the new api call but leaving it as is for now. - if (priority == 0) { - TF_RETURN_IF_ERROR( - cuda::ToStatus(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING))); - } else { - TF_RETURN_IF_ERROR(cuda::ToStatus( - cuStreamCreateWithPriority(&stream, CU_STREAM_NON_BLOCKING, priority))); - } - - VLOG(2) << "successfully created stream " << stream << " for context " - << context << " on thread"; - return stream; -} - -void GpuDriver::DestroyStream(Context* context, GpuStreamHandle stream) { - if (stream == nullptr) { - return; - } - - ScopedActivateContext activated{context}; - CUresult res = cuStreamQuery(stream); - if (res != CUDA_SUCCESS) { - LOG(ERROR) << "stream not idle on destroy: " << cuda::ToStatus(res); - } - - auto status = cuda::ToStatus(cuStreamDestroy(stream)); - if (!status.ok()) { - LOG(ERROR) << "failed to destroy CUDA stream for context " << context - << ": " << status; - } else { - VLOG(2) << "successfully destroyed stream " << stream << " for context " - << context; - } -} - -void* GpuDriver::DeviceAllocate(Context* context, uint64_t bytes) { - if (bytes == 0) { - return nullptr; - } - - ScopedActivateContext activated{context}; - CUdeviceptr result = 0; - auto status = cuda::ToStatus(cuMemAlloc(&result, bytes)); - if (!status.ok()) { - // LOG(INFO) because this isn't always important to users (e.g. BFCAllocator - // implements a retry if the first allocation fails). - LOG(INFO) << "failed to allocate " - << tsl::strings::HumanReadableNumBytes(bytes) << " (" << bytes - << " bytes) from device: " << status; - return nullptr; - } - void* ptr = reinterpret_cast(result); - VLOG(2) << "allocated " << ptr << " for context " << context << " of " - << bytes << " bytes"; - return ptr; -} - -void GpuDriver::DeviceDeallocate(Context* context, void* location) { - ScopedActivateContext activation(context); - CUdeviceptr pointer = absl::bit_cast(location); - auto status = cuda::ToStatus(cuMemFree(pointer)); - if (!status.ok()) { - LOG(ERROR) << "failed to free device memory at " << location - << "; result: " << status; - } else { - VLOG(2) << "deallocated " << location << " for context " << context; - } -} - -void* GpuDriver::UnifiedMemoryAllocate(Context* context, uint64_t bytes) { - ScopedActivateContext activation(context); - CUdeviceptr result = 0; - // "Portable" memory is visible to all CUDA contexts. Safe for our use model. - auto status = - cuda::ToStatus(cuMemAllocManaged(&result, bytes, CU_MEM_ATTACH_GLOBAL)); - if (!status.ok()) { - LOG(ERROR) << "failed to alloc " << bytes - << " bytes unified memory; result: " << status; - return nullptr; - } - void* ptr = reinterpret_cast(result); - VLOG(2) << "allocated " << ptr << " for context " << context << " of " - << bytes << " bytes in unified memory"; - return ptr; -} - -void GpuDriver::UnifiedMemoryDeallocate(Context* context, void* location) { - ScopedActivateContext activation(context); - CUdeviceptr pointer = absl::bit_cast(location); - auto status = cuda::ToStatus(cuMemFree(pointer)); - if (!status.ok()) { - LOG(ERROR) << "failed to free unified memory at " << location - << "; result: " << status; - } else { - VLOG(2) << "deallocated unified memory at " << location << " for context " - << context; - } -} - -void* GpuDriver::HostAllocate(Context* context, uint64_t bytes) { - ScopedActivateContext activation(context); - void* host_mem = nullptr; - // "Portable" memory is visible to all CUDA contexts. Safe for our use model. - auto status = cuda::ToStatus( - cuMemHostAlloc(&host_mem, bytes, CU_MEMHOSTALLOC_PORTABLE)); - if (!status.ok()) { - LOG(ERROR) << "failed to alloc " << bytes << " bytes on host: " << status; - } - return host_mem; -} - -void GpuDriver::HostDeallocate(Context* context, void* location) { - ScopedActivateContext activation(context); - auto status = cuda::ToStatus(cuMemFreeHost(location)); - if (!status.ok()) { - LOG(ERROR) << "error deallocating host memory at " << location << ": " - << status; - } -} - -int GpuDriver::GetGpuStreamPriority( - Context* context, stream_executor::StreamPriority stream_priority) { - ScopedActivateContext activation(context); - if (stream_priority == stream_executor::StreamPriority::Default) { - return 0; - } - int lowest, highest; - auto status = cuda::ToStatus(cuCtxGetStreamPriorityRange(&lowest, &highest)); - if (!status.ok()) { - LOG(ERROR) - << "Could not query stream priority range. Returning default priority."; - return 0; - } - return stream_priority == stream_executor::StreamPriority::Highest ? highest - : lowest; -} - -absl::Status GpuDriver::DestroyEvent(Context* context, CUevent* event) { - if (*event == nullptr) { - return absl::InvalidArgumentError("input event cannot be null"); - } - - ScopedActivateContext activated{context}; - return cuda::ToStatus(cuEventDestroy(*event), "Error destroying CUDA event"); -} - -absl::Status GpuDriver::RecordEvent(Context* context, CUevent event, - CUstream stream) { - ScopedActivateContext activated{context}; - return cuda::ToStatus(cuEventRecord(event, stream), - "Error recording CUDA event"); -} - -absl::StatusOr GpuDriver::GetEventElapsedTime(Context* context, - CUevent start, - CUevent stop) { - ScopedActivateContext activated{context}; - // The stop event must have completed in order for cuEventElapsedTime to - // work. - auto status = cuda::ToStatus(cuEventSynchronize(stop)); - if (!status.ok()) { - LOG(ERROR) << "failed to synchronize the stop event: " << status; - return false; - } - - float elapsed_milliseconds; - - TF_RETURN_IF_ERROR( - cuda::ToStatus(cuEventElapsedTime(&elapsed_milliseconds, start, stop))); - - return elapsed_milliseconds; -} - -absl::Status GpuDriver::WaitStreamOnEvent(Context* context, CUstream stream, - CUevent event) { - ScopedActivateContext activation(context); - return cuda::ToStatus(cuStreamWaitEvent(stream, event, 0 /* = flags */)); -} - -absl::Status GpuDriver::SynchronizeContext(Context* context) { - ScopedActivateContext activation(context); - return cuda::ToStatus(cuCtxSynchronize()); -} - -absl::Status GpuDriver::SynchronizeStream(Context* context, CUstream stream) { - ScopedActivateContext activated{context}; - CHECK(stream != nullptr); - return cuda::ToStatus(cuStreamSynchronize(stream), - "Could not synchronize CUDA stream"); -} - -absl::Status GpuDriver::SynchronousMemcpyD2H(Context* context, void* host_dst, - CUdeviceptr gpu_src, - uint64_t size) { - ScopedActivateContext activation(context); - TF_RETURN_IF_ERROR(cuda::ToStatus( - cuMemcpyDtoH(host_dst, gpu_src, size), - absl::StrFormat("failed to synchronous memcpy from device to host " - "host dst: %p; GPU src: %p; size: %u=0x%x", - host_dst, absl::bit_cast(gpu_src), size, size))); - VLOG(2) << "successfully sync memcpy'd d2h of " << size << " bytes to " - << host_dst; - return absl::OkStatus(); -} - -absl::Status GpuDriver::SynchronousMemcpyH2D(Context* context, - CUdeviceptr gpu_dst, - const void* host_src, - uint64_t size) { - ScopedActivateContext activation(context); - TF_RETURN_IF_ERROR(cuda::ToStatus( - cuMemcpyHtoD(gpu_dst, host_src, size), - absl::StrFormat( - "failed to synchronous memcpy from host to device: GPU dst: %p;" - " host src: %p; size: %u=0x%x", - absl::bit_cast(gpu_dst), host_src, size, size))); - VLOG(2) << "successfully enqueued sync memcpy h2d of " << size << " bytes"; - return absl::OkStatus(); -} - -absl::Status GpuDriver::AsynchronousMemcpyD2H(Context* context, void* host_dst, - CUdeviceptr gpu_src, - uint64_t size, CUstream stream) { - ScopedActivateContext activation(context); - - TF_RETURN_IF_ERROR( - cuda::ToStatus(cuMemcpyDtoHAsync(host_dst, gpu_src, size, stream))); - - VLOG(2) << "successfully enqueued async memcpy d2h of " << size - << " bytes from " << absl::bit_cast(gpu_src) << " to " - << host_dst << " on stream " << stream; - return absl::OkStatus(); -} - -absl::Status GpuDriver::AsynchronousMemcpyH2D(Context* context, - CUdeviceptr gpu_dst, - const void* host_src, - uint64_t size, CUstream stream) { - ScopedActivateContext activation(context); - TF_RETURN_IF_ERROR( - cuda::ToStatus(cuMemcpyHtoDAsync(gpu_dst, host_src, size, stream))); - - VLOG(2) << "successfully enqueued async memcpy h2d of " << size << " bytes" - << " from " << host_src << " to " << absl::bit_cast(gpu_dst) - << " on stream " << stream; - return absl::OkStatus(); -} - -absl::Status GpuDriver::AsynchronousMemcpyD2D(Context* context, - CUdeviceptr gpu_dst, - CUdeviceptr gpu_src, - uint64_t size, CUstream stream) { - ScopedActivateContext activation(context); - - // In graph capture mode we never have operations that access peer memory, so - // we can always make a call to cuMemcpyDtoDAsync. - TF_ASSIGN_OR_RETURN(bool is_capturing, StreamIsCapturing(stream)); - - if ((gpu_dst == 0 || gpu_src == 0) || is_capturing) { - // GetContextMap()->GetAnyContext() doesn't works when ptr == 0. - // This happens when the size is 0. - TF_RETURN_IF_ERROR( - cuda::ToStatus(cuMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream))); - } else { - // Any context work here. - CUcontext dst_context = - GetContextMap()->GetAnyContext(absl::bit_cast(gpu_dst)); - CUcontext src_context = - GetContextMap()->GetAnyContext(absl::bit_cast(gpu_src)); - - if (dst_context == src_context) { - // Since the CUDA context is the same, the src and dst are within the same - // GPU. So we can use cuMemcpyDtoD. - TF_RETURN_IF_ERROR( - cuda::ToStatus(cuMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream))); - } else { - TF_RETURN_IF_ERROR(cuda::ToStatus(cuMemcpyPeerAsync( - gpu_dst, dst_context, gpu_src, src_context, size, stream))); - } - } - - VLOG(2) << "successfully enqueued async memcpy d2d of " << size << " bytes" - << " from " << absl::bit_cast(gpu_src) << " to " - << absl::bit_cast(gpu_dst) << " on stream " << stream; - return absl::OkStatus(); -} - -absl::Status GpuDriver::InitEvent(Context* context, CUevent* result, - EventFlags flags) { - int cuflags; - switch (flags) { - case EventFlags::kDefault: - cuflags = CU_EVENT_DEFAULT; - break; - case EventFlags::kDisableTiming: - cuflags = CU_EVENT_DISABLE_TIMING; - break; - default: - LOG(FATAL) << "impossible event flags: " << int(flags); - } - - ScopedActivateContext activated{context}; - return cuda::ToStatus(cuEventCreate(result, cuflags)); -} - int GpuDriver::GetDeviceCount() { int device_count = 0; auto status = cuda::ToStatus(cuDeviceGetCount(&device_count)); @@ -1460,120 +632,6 @@ int GpuDriver::GetDeviceCount() { return device_count; } -absl::StatusOr GpuDriver::GetPointerMemorySpace( - CUdeviceptr pointer) { - unsigned int value; - TF_RETURN_IF_ERROR(cuda::ToStatus(cuPointerGetAttribute( - &value, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, pointer))); - switch (value) { - case CU_MEMORYTYPE_DEVICE: - return MemoryType::kDevice; - case CU_MEMORYTYPE_HOST: - return MemoryType::kHost; - default: - return absl::InternalError( - absl::StrCat("unknown memory space provided by CUDA API: ", value)); - } -} - -absl::Status GpuDriver::GetPointerAddressRange(CUdeviceptr dptr, - CUdeviceptr* base, - size_t* size) { - return cuda::ToStatus(cuMemGetAddressRange(base, size, dptr)); -} - -absl::Status GpuDriver::GetComputeCapability(int* cc_major, int* cc_minor, - CUdevice device) { - *cc_major = 0; - *cc_minor = 0; - - TF_RETURN_IF_ERROR(cuda::ToStatus(cuDeviceGetAttribute( - cc_major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device))); - - return cuda::ToStatus(cuDeviceGetAttribute( - cc_minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device)); -} - -absl::Status GpuDriver::GetGpuISAVersion(int* version, CUdevice device) { - return absl::Status{ - absl::StatusCode::kInternal, - "Feature not supported on CUDA platform (GetGpuISAVersion)"}; -} - -absl::Status GpuDriver::GetGpuGCNArchName(CUdevice, std::string*) { - return absl::Status{ - absl::StatusCode::kInternal, - "Feature not supported on CUDA platform (GetGpuGCNArchName)"}; -} - -// Helper function that turns the integer output of cuDeviceGetAttribute to type -// T and wraps it in a absl::StatusOr. -template -static absl::StatusOr GetSimpleAttribute(CUdevice device, - CUdevice_attribute attribute) { - int value = -1; - TF_RETURN_IF_ERROR(cuda::ToStatus( - cuDeviceGetAttribute(&value, attribute, device), - absl::StrCat("Could not retrieve CUDA device attribute (", attribute))); - T converted = value; - return converted; -} - -absl::StatusOr GpuDriver::GetMultiprocessorCount(CUdevice device) { - return GetSimpleAttribute(device, - CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT); -} - -absl::StatusOr GpuDriver::GetMaxSharedMemoryPerCore(CUdevice device) { - return GetSimpleAttribute( - device, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR); -} - -absl::StatusOr GpuDriver::GetMaxSharedMemoryPerBlock(CUdevice device) { - return GetSimpleAttribute( - device, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK); -} - -absl::StatusOr GpuDriver::GetMaxSharedMemoryPerBlockOptin( - CUdevice device) { - return GetSimpleAttribute( - device, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN); -} - -absl::StatusOr GpuDriver::GetMaxThreadsPerMultiprocessor( - CUdevice device) { - return GetSimpleAttribute( - device, CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR); -} - -absl::StatusOr GpuDriver::GetMaxRegistersPerBlock(CUdevice device) { - return GetSimpleAttribute( - device, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK); -} - -absl::StatusOr GpuDriver::GetThreadsPerWarp(CUdevice device) { - return GetSimpleAttribute(device, CU_DEVICE_ATTRIBUTE_WARP_SIZE); -} - -absl::Status GpuDriver::GetGridLimits(int* x, int* y, int* z, CUdevice device) { - int value; - TF_RETURN_IF_ERROR(cuda::ToStatus( - cuDeviceGetAttribute(&value, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X, device), - "Could not get device attribute")); - *x = value; - - TF_RETURN_IF_ERROR(cuda::ToStatus( - cuDeviceGetAttribute(&value, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y, device), - "Could not get device attribute")); - *y = value; - - TF_RETURN_IF_ERROR(cuda::ToStatus( - cuDeviceGetAttribute(&value, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z, device), - "Could not get device attribute")); - *z = value; - return absl::OkStatus(); -} - absl::StatusOr GpuDriver::GetDriverVersion() { int32_t version; TF_RETURN_IF_ERROR(cuda::ToStatus(cuDriverGetVersion(&version), @@ -1581,141 +639,6 @@ absl::StatusOr GpuDriver::GetDriverVersion() { return version; } -bool GpuDriver::GetDeviceProperties(CUdevprop* device_properties, - int device_ordinal) { - auto status = - cuda::ToStatus(cuDeviceGetProperties(device_properties, device_ordinal)); - return status.ok(); -} - -absl::StatusOr GpuDriver::GetDeviceAttribute(CUdevice_attribute attribute, - CUdevice device) { - int val; - TF_RETURN_IF_ERROR( - cuda::ToStatus(cuDeviceGetAttribute(&val, attribute, device))); - return val; -} - -bool GpuDriver::IsEccEnabled(CUdevice device, bool* result) { - int value = -1; - auto status = cuda::ToStatus( - cuDeviceGetAttribute(&value, CU_DEVICE_ATTRIBUTE_ECC_ENABLED, device)); - if (!status.ok()) { - LOG(ERROR) << "failed to query ECC status: " << status; - return false; - } - - *result = value; - return true; -} - -bool GpuDriver::GetDeviceMemoryInfo(Context* context, int64_t* free_out, - int64_t* total_out) { - ScopedActivateContext activation(context); - size_t free = 0; - size_t total = 0; - auto status = cuda::ToStatus(cuMemGetInfo(&free, &total)); - if (!status.ok()) { - LOG(ERROR) << "failed to query device memory info: " << status; - return false; - } - - *free_out = free; - *total_out = total; - return true; -} - -bool GpuDriver::GetDeviceTotalMemory(CUdevice device, uint64_t* result) { - size_t value{}; - auto status = cuda::ToStatus(cuDeviceTotalMem(&value, device)); - if (!status.ok()) { - LOG(ERROR) << "failed to query total available memory: " << status; - return false; - } - - *result = value; - return true; -} - -std::string GpuDriver::GetPCIBusID(CUdevice device) { - std::string pci_bus_id; - static const int kBufferSize = 64; - absl::InlinedVector chars(kBufferSize); - chars[kBufferSize - 1] = '\0'; - auto status = cuda::ToStatus( - cuDeviceGetPCIBusId(chars.begin(), kBufferSize - 1, device)); - if (!status.ok()) { - LOG(ERROR) << "failed to query PCI bus id for device: " << status; - return pci_bus_id; - } - pci_bus_id = chars.begin(); - return pci_bus_id; -} - -bool GpuDriver::CanEnablePeerAccess(Context* from, Context* to) { - if (from == to) { - return true; // A context can always access its own memory. - } - - auto from_device = DeviceFromContext(from); - if (!from_device.ok()) { - LOG(ERROR) << "failed to resolve 'from' peer access context to a device: " - << from_device.status(); - return false; - } - auto to_device = DeviceFromContext(to); - if (!to_device.ok()) { - LOG(ERROR) << "failed to resolve 'to' peer access context to a device: " - << to_device.status(); - return false; - } - return CanEnablePeerAccess(from_device.value(), to_device.value()); -} - -bool GpuDriver::CanEnablePeerAccess(GpuDeviceHandle from, GpuDeviceHandle to) { - int can_access_peer = -1; - auto status = - cuda::ToStatus(cuDeviceCanAccessPeer(&can_access_peer, from, to)); - if (!status.ok()) { - LOG(ERROR) << "failed to detect peer access capability: " << status; - return false; - } - return can_access_peer; -} - -absl::Status GpuDriver::EnablePeerAccess(Context* from, Context* to) { - if (from == to) { - return absl::OkStatus(); // A context can always access its own - // memory. - } - - ScopedActivateContext activated{from}; - CUresult result = cuCtxEnablePeerAccess( - tensorflow::down_cast(to)->context(), 0 /* = flags */); - if (result != CUDA_SUCCESS && - result != CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED) { - return absl::InternalError( - absl::StrFormat("failed to enable peer access from %p to %p: %s", from, - to, cuda::ToStatus(result).ToString())); - } - - return absl::OkStatus(); -} - -absl::StatusOr GpuDriver::GetMaxOccupiedBlocksPerCore( - Context* context, CUfunction kernel, int threads_per_block, - size_t dynamic_shared_memory_bytes) { - ScopedActivateContext activation(context); - - int max_blocks; - TF_RETURN_IF_ERROR(cuda::ToStatus( - cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( - &max_blocks, kernel, threads_per_block, dynamic_shared_memory_bytes, - CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE), - absl::StrFormat("Failed to calculate occupancy of kernel %p", kernel))); - return max_blocks; -} - absl::StatusOr GpuDriver::GraphGetNodeCount(GpuGraphHandle graph) { size_t num_nodes; TF_RETURN_IF_ERROR( diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver_test.cc index cd7fc58bfe5ca0..ed47fa3f0d6618 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver_test.cc @@ -13,19 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/stream_executor/cuda/cuda_driver.h" - #include #include +#include "absl/cleanup/cleanup.h" #include "absl/log/log.h" #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "third_party/gpus/cuda/include/driver_types.h" +#include "xla/stream_executor/cuda/cuda_context.h" #include "xla/stream_executor/cuda/cuda_diagnostics.h" #include "xla/stream_executor/cuda/cuda_status.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_types.h" -#include "xla/stream_executor/gpu/scoped_activate_context.h" #include "tsl/platform/status.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/test.h" @@ -55,30 +54,6 @@ class CudaDriverTest : public ::testing::Test { static void SetUpTestSuite() { CHECK_CUDA(cuInit(0)); } }; -TEST_F(CudaDriverTest, ScopedActivateContextTest) { - CUdevice device; - CHECK_CUDA(cuDeviceGet(&device, 0)); - CUcontext context0, context1; - CHECK_CUDA(cuCtxCreate(&context0, 0, device)); - CHECK_CUDA(cuCtxCreate(&context1, 0, device)); - gpu::GpuContext se_context1(context1, /*device_ordinal=*/101); - { - gpu::ScopedActivateContext scope(&se_context1); - CUcontext c; - CHECK_CUDA(cuCtxGetCurrent(&c)); - EXPECT_EQ(c, context1); - } - CHECK_CUDA(cuCtxSetCurrent(context0)); - // ScopedActivateContext must correctly set the CUDA context even if some - // other code changes the context between the two scopes. - { - gpu::ScopedActivateContext scope(&se_context1); - CUcontext c; - CHECK_CUDA(cuCtxGetCurrent(&c)); - EXPECT_EQ(c, context1); - } -} - TEST_F(CudaDriverTest, DriverVersionParsingTest) { // Tests that the driver version can be right after 'Kernel Module', // or later as well. @@ -102,6 +77,8 @@ TEST_F(CudaDriverTest, GraphGetNodeCountTest) { CHECK_CUDA(cuCtxCreate(&context, 0, device)); gpu::GpuGraphHandle graph; TF_CHECK_OK(gpu::GpuDriver::CreateGraph(&graph)); + absl::Cleanup cleanup( + [graph] { TF_CHECK_OK(gpu::GpuDriver::DestroyGraph(graph)); }); EXPECT_THAT(gpu::GpuDriver::GraphGetNodeCount(graph), IsOkAndHolds(0)); gpu::GpuGraphNodeHandle node; TF_CHECK_OK(gpu::GpuDriver::GraphAddEmptyNode(&node, graph, {})); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_event.cc b/third_party/xla/xla/stream_executor/cuda/cuda_event.cc index 3dc81c2d9e258d..5656939a68c091 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_event.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_event.cc @@ -15,17 +15,66 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_event.h" +#include +#include + +#include "absl/base/casts.h" +#include "absl/log/log.h" +#include "absl/status/status.h" #include "third_party/gpus/cuda/include/cuda.h" -#include "xla/stream_executor/cuda/cuda_driver.h" +#include "xla/stream_executor/activate_context.h" +#include "xla/stream_executor/cuda/cuda_status.h" #include "xla/stream_executor/event.h" -#include "xla/stream_executor/gpu/scoped_activate_context.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace stream_executor { namespace gpu { +namespace { +absl::Status WaitStreamOnEvent(StreamExecutor *executor, CUstream stream, + CUevent event) { + std::unique_ptr activation = executor->Activate(); + return cuda::ToStatus(cuStreamWaitEvent(stream, event, 0 /* = flags */)); +} + +void DestroyEvent(StreamExecutor *executor, CUevent event) { + if (event == nullptr) { + return; + } + + std::unique_ptr activation = executor->Activate(); + auto result = + cuda::ToStatus(cuEventDestroy(event), "Error destroying CUDA event"); + if (!result.ok()) { + LOG(ERROR) << result.message(); + } +} + +enum class EventFlags { kDefault, kDisableTiming }; +absl::StatusOr InitEvent(StreamExecutor *executor, EventFlags flags) { + int cuflags; + switch (flags) { + case EventFlags::kDefault: + cuflags = CU_EVENT_DEFAULT; + break; + case EventFlags::kDisableTiming: + cuflags = CU_EVENT_DISABLE_TIMING; + break; + default: + LOG(FATAL) << "impossible event flags: " << int(flags); + } + + std::unique_ptr activation = executor->Activate(); + CUevent event_handle; + TF_RETURN_IF_ERROR(cuda::ToStatus(cuEventCreate(&event_handle, cuflags))); + return event_handle; +} + +} // namespace Event::Status CudaEvent::PollForStatus() { - ScopedActivateContext activated(context()); - CUresult res = cuEventQuery(gpu_event()); + std::unique_ptr activation = executor_->Activate(); + CUresult res = cuEventQuery(handle_); if (res == CUDA_SUCCESS) { return Event::Status::kComplete; } else if (res == CUDA_ERROR_NOT_READY) { @@ -34,5 +83,43 @@ Event::Status CudaEvent::PollForStatus() { return Event::Status::kError; } +absl::Status CudaEvent::WaitForEventOnExternalStream(std::intptr_t stream) { + return WaitStreamOnEvent(executor_, absl::bit_cast(stream), + handle_); +} + +absl::StatusOr CudaEvent::Create(StreamExecutor *executor, + bool allow_timing) { + TF_ASSIGN_OR_RETURN( + CUevent event_handle, + InitEvent(executor, allow_timing ? EventFlags::kDefault + : EventFlags::kDisableTiming)); + + return CudaEvent(executor, event_handle); +} + +CudaEvent::~CudaEvent() { DestroyEvent(executor_, handle_); } + +CudaEvent& CudaEvent::operator=(CudaEvent&& other) { + if (this == &other) { + return *this; + } + + DestroyEvent(executor_, handle_); + + executor_ = other.executor_; + handle_ = other.handle_; + other.executor_ = nullptr; + other.handle_ = nullptr; + + return *this; +} + +CudaEvent::CudaEvent(CudaEvent &&other) + : executor_(other.executor_), handle_(other.handle_) { + other.executor_ = nullptr; + other.handle_ = nullptr; +} + } // namespace gpu } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_event.h b/third_party/xla/xla/stream_executor/cuda/cuda_event.h index 46b69721b1eac4..0d6f871d0fbcc7 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_event.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_event.h @@ -16,19 +16,46 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_EVENT_H_ #define XLA_STREAM_EXECUTOR_CUDA_CUDA_EVENT_H_ +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/event.h" -#include "xla/stream_executor/gpu/gpu_event.h" +#include "xla/stream_executor/stream_executor.h" namespace stream_executor::gpu { class GpuContext; -// This class implements Event::PollForStatus for CUDA devices. -class CudaEvent : public GpuEvent { +// This class implements Event for CUDA devices. +class CudaEvent : public Event { public: - explicit CudaEvent(Context *context) : GpuEvent(context) {} - Event::Status PollForStatus() override; + absl::Status WaitForEventOnExternalStream(std::intptr_t stream) override; + + // Creates a new CudaEvent. If allow_timing is false, the event will not + // support timing, which is cheaper to create. + static absl::StatusOr Create(StreamExecutor* executor, + bool allow_timing); + + CUevent GetHandle() const { return handle_; } + + ~CudaEvent() override; + CudaEvent(const CudaEvent&) = delete; + CudaEvent& operator=(const CudaEvent&) = delete; + CudaEvent(CudaEvent&& other); + CudaEvent& operator=(CudaEvent&& other); + + private: + explicit CudaEvent(StreamExecutor* executor, CUevent handle) + : executor_(executor), handle_(handle) {} + + // The StreamExecutor to which this object and CUevent are bound. + StreamExecutor* executor_; + + // The underlying CUDA event handle. + CUevent handle_; }; } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_event_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_event_test.cc new file mode 100644 index 00000000000000..acc7af142e99c2 --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/cuda_event_test.cc @@ -0,0 +1,57 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/cuda/cuda_event.h" + +#include + +#include +#include "third_party/gpus/cuda/include/cuda.h" +#include "xla/stream_executor/cuda/cuda_executor.h" +#include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor::gpu { +namespace { +using ::tsl::testing::IsOk; + +TEST(CudaEventTest, CreateEvent) { + TF_ASSERT_OK_AND_ASSIGN(Platform * platform, + stream_executor::PlatformManager::PlatformWithId( + stream_executor::cuda::kCudaPlatformId)); + TF_ASSERT_OK_AND_ASSIGN(StreamExecutor * executor, + platform->ExecutorForDevice(0)); + CudaExecutor* cuda_executor = reinterpret_cast(executor); + + TF_ASSERT_OK_AND_ASSIGN(CudaEvent event, + CudaEvent::Create(cuda_executor, false)); + + EXPECT_NE(event.GetHandle(), nullptr); + EXPECT_EQ(event.PollForStatus(), Event::Status::kComplete); + + CUevent handle = event.GetHandle(); + CudaEvent event2 = std::move(event); + EXPECT_EQ(event2.GetHandle(), handle); +} + +} // namespace + +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index 28e15f1546b6c2..8f92021a61de98 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -26,24 +26,36 @@ limitations under the License. #include #include +#include "absl/base/casts.h" +#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/numeric/int128.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/ascii.h" +#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" #include "absl/types/span.h" #include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "xla/stream_executor/activate_context.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/cuda/cuda_collectives.h" +#include "xla/stream_executor/cuda/cuda_command_buffer.h" +#include "xla/stream_executor/cuda/cuda_context.h" #include "xla/stream_executor/cuda/cuda_event.h" +#include "xla/stream_executor/cuda/cuda_kernel.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/cuda/cuda_runtime.h" #include "xla/stream_executor/cuda/cuda_status.h" +#include "xla/stream_executor/cuda/cuda_stream.h" +#include "xla/stream_executor/cuda/cuda_timer.h" #include "xla/stream_executor/cuda/cuda_version_parser.h" -#include "xla/stream_executor/cuda/delay_kernel.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" @@ -51,30 +63,34 @@ limitations under the License. #include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/fft.h" #include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_collectives.h" #include "xla/stream_executor/gpu/gpu_command_buffer.h" #include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_event.h" +#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_kernel.h" -#include "xla/stream_executor/gpu/gpu_runtime.h" -#include "xla/stream_executor/gpu/gpu_semaphore.h" #include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/gpu_timer.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/gpu/read_numa_node.h" +#include "xla/stream_executor/gpu/scoped_activate_context.h" +#include "xla/stream_executor/host_memory_allocation.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/module_spec.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/plugin_registry.h" #include "xla/stream_executor/semantic_version.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/casts.h" +#include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/fingerprint.h" #include "tsl/platform/logging.h" +#include "tsl/platform/macros.h" +#include "tsl/platform/numbers.h" #include "tsl/platform/statusor.h" +#include "tsl/platform/threadpool.h" namespace stream_executor { namespace gpu { @@ -90,11 +106,420 @@ bool ShouldLaunchDelayKernel() { return value; } -absl::Status FuncGetAttribute(CUfunction_attribute attribute, CUfunction func, - int* attribute_value) { +// CUDA driver routines may require a large amount of stack (particularly +// cuModuleLoadDataEx, in our experience). To avoid stack overflow when using +// stack-limited threads (such as those spawned by a default-argument +// thread::ThreadPool on some platforms), we run certain routines in this pool +// and wait for completion. +tsl::thread::ThreadPool* GetDriverExecutor() { + static tsl::thread::ThreadPool* thread_pool = new tsl::thread::ThreadPool( + tsl::Env::Default(), tsl::ThreadOptions(), "cuda_driver", 1); + return thread_pool; +} + +// Loads ptx_contents with the CUDA driver's PTX JIT and return the resulting +// handle. Any error logs that are produced are logged internally. +absl::StatusOr LoadPtx(Context* context, const char* ptx_contents) { + absl::Notification notification; + absl::Status returned_status = absl::OkStatus(); + CUmodule module; + GetDriverExecutor()->Schedule( + [context, ptx_contents, &module, &returned_status, ¬ification]() { + ScopedActivateContext activation(context); + void* ptx_data = const_cast(ptx_contents); + static const unsigned int kLogBufferBytesLimit = 1024; + unsigned int error_log_buffer_bytes = kLogBufferBytesLimit; + unsigned int info_log_buffer_bytes = kLogBufferBytesLimit; + absl::InlinedVector error_log_buffer(error_log_buffer_bytes); + absl::InlinedVector info_log_buffer(info_log_buffer_bytes); + bool log_verbose = true; + CUjit_option options[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, + CU_JIT_ERROR_LOG_BUFFER, + CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, + CU_JIT_INFO_LOG_BUFFER, CU_JIT_LOG_VERBOSE}; + // Note that the driver API wants the contents of this values to be + // stored in an array of void*s, so we coerce them accordingly. + void* option_values[] = { + absl::bit_cast(uintptr_t(error_log_buffer_bytes)), + absl::bit_cast(error_log_buffer.data()), + absl::bit_cast(uintptr_t(info_log_buffer_bytes)), + absl::bit_cast(info_log_buffer.data()), + absl::bit_cast(uintptr_t(log_verbose))}; + CHECK(TF_ARRAYSIZE(options) == TF_ARRAYSIZE(option_values)); + + absl::Status status; + status = cuda::ToStatus(cuModuleLoadDataEx( + &module, ptx_data, TF_ARRAYSIZE(options), options, option_values)); + + // The PTX JIT mutates the values in the option values array to reflect + // the size of the logs it output; now that we've made the call, read + // the values back out. + error_log_buffer_bytes = reinterpret_cast(option_values[0]); + info_log_buffer_bytes = reinterpret_cast(option_values[2]); + CHECK_LE(error_log_buffer_bytes, kLogBufferBytesLimit); + CHECK_LE(info_log_buffer_bytes, kLogBufferBytesLimit); + + if (!status.ok()) { + LOG(ERROR) << "failed to load PTX text as a module: " << status; + // As a precaution for null termination of the API-provided value, + // ensure that at least the last byte is null. + error_log_buffer[error_log_buffer_bytes ? error_log_buffer_bytes - 1 + : 0] = '\0'; + LOG(ERROR) << "error log buffer (" << error_log_buffer_bytes + << " bytes): " << error_log_buffer.data(); + if (absl::StrContains(error_log_buffer.data(), + "Register allocation failed")) { + returned_status = absl::ResourceExhaustedError( + absl::StrFormat("Failed to load PTX text as a module (register " + "allocation failed): %s", + status.ToString())); + } else { + returned_status = status; + } + notification.Notify(); + return; + } + + VLOG(3) << "PTX compilation info log (" << info_log_buffer_bytes + << " bytes): " << info_log_buffer.data(); + VLOG(3) << "PTX compilation error log (" << error_log_buffer_bytes + << " bytes): " << error_log_buffer.data(); + CHECK(module != nullptr); + notification.Notify(); + }); + notification.WaitForNotification(); + + TF_RETURN_IF_ERROR(returned_status); + return module; +} + +// Loads cubin_bytes with the CUDA driver's blob loading interface and stores +// the resulting handle in "module". +absl::StatusOr LoadCubin(Context* context, const char* cubin_bytes) { + ScopedActivateContext activation(context); + CUmodule module; + TF_RETURN_IF_ERROR(cuda::ToStatus( + cuModuleLoadFatBinary(&module, cubin_bytes), + "Failed to load in-memory CUBIN (compiled for a different GPU?).")); + return module; +} + +// Retrieves a named kernel from a loaded module, and return the CUfunction +// handle on success. Neither kernel_name nor function may be null. No ownership +// is taken of kernel_name. +absl::StatusOr GetModuleFunction(Context* context, CUmodule module, + const char* kernel_name) { + ScopedActivateContext activated{context}; + CHECK(module != nullptr && kernel_name != nullptr); + cudaError_t cuda_error = cudaPeekAtLastError(); + if (cuda_error != cudaSuccess) { + return absl::InternalError( + absl::StrCat("There was an error before calling cuModuleGetFunction (", + cuda_error, "): ", cudaGetErrorName(cuda_error), " : ", + cudaGetErrorString(cuda_error))); + } + CUfunction function; + TF_RETURN_IF_ERROR( + cuda::ToStatus(cuModuleGetFunction(&function, module, kernel_name), + "Failed to get module function")); + return function; +} + +// Retrieves a named global/constant symbol from a loaded module, and returns +// a device pointer and size of the symbol on success. symbol_name may not be +// null. At least one of dptr or bytes should not be null. No ownership is +// taken of symbol_name. +absl::Status GetModuleSymbol(Context* context, CUmodule module, + const char* symbol_name, CUdeviceptr* dptr, + size_t* bytes) { + ScopedActivateContext activated{context}; + CHECK(module != nullptr && symbol_name != nullptr && + (dptr != nullptr || bytes != nullptr)); return cuda::ToStatus( - cuFuncGetAttribute(attribute_value, attribute, func), - absl::StrCat("Failed to query kernel attribute: ", attribute)); + cuModuleGetGlobal(dptr, bytes, module, symbol_name), + absl::StrCat("Failed to get symbol '", symbol_name, "'")); +} + +// Unloads module from the current context via cuModuleUnload. +void UnloadCudaModule(Context* context, CUmodule module) { + ScopedActivateContext activated{context}; + auto status = cuda::ToStatus(cuModuleUnload(module)); + if (!status.ok()) { + LOG(ERROR) << "failed to unload module " << module + << "; leaking: " << status; + } +} + +// Returns the integer output of cuDeviceGetAttribute. +absl::StatusOr GetDeviceAttribute(CUdevice_attribute attribute, + CUdevice device) { + int val; + TF_RETURN_IF_ERROR( + cuda::ToStatus(cuDeviceGetAttribute(&val, attribute, device))); + return val; +} + +// Returns the name of the device. +absl::StatusOr GetDeviceName(CUdevice device) { + static const size_t kCharLimit = 64; + absl::InlinedVector chars(kCharLimit); + TF_RETURN_IF_ERROR( + cuda::ToStatus(cuDeviceGetName(chars.begin(), kCharLimit - 1, device), + "Failed to get device name")); + chars[kCharLimit - 1] = '\0'; + return chars.begin(); +} + +// Returns the compute capability for the device; i.e (3, 5). +absl::Status GetComputeCapability(int* cc_major, int* cc_minor, + CUdevice device) { + *cc_major = 0; + *cc_minor = 0; + + TF_RETURN_IF_ERROR(cuda::ToStatus(cuDeviceGetAttribute( + cc_major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device))); + + return cuda::ToStatus(cuDeviceGetAttribute( + cc_minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device)); +} + +// Helper function that turns the integer output of cuDeviceGetAttribute to type +// T and wraps it in a absl::StatusOr. +template +static absl::StatusOr GetSimpleAttribute(CUdevice device, + CUdevice_attribute attribute) { + int value = -1; + TF_RETURN_IF_ERROR(cuda::ToStatus( + cuDeviceGetAttribute(&value, attribute, device), + absl::StrCat("Could not retrieve CUDA device attribute (", attribute))); + T converted = value; + return converted; +} + +// Returns the number of multiprocessors on the device (note that the device +// may be multi-GPU-per-board). +absl::StatusOr GetMultiprocessorCount(CUdevice device) { + return GetSimpleAttribute(device, + CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT); +} + +absl::StatusOr GetMaxSharedMemoryPerCore(CUdevice device) { + return GetSimpleAttribute( + device, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR); +} + +absl::StatusOr GetMaxSharedMemoryPerBlock(CUdevice device) { + return GetSimpleAttribute( + device, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK); +} + +absl::StatusOr GetMaxSharedMemoryPerBlockOptin(CUdevice device) { + return GetSimpleAttribute( + device, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN); +} + +absl::StatusOr GetMaxThreadsPerMultiprocessor(CUdevice device) { + return GetSimpleAttribute( + device, CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR); +} + +absl::StatusOr GetMaxRegistersPerBlock(CUdevice device) { + return GetSimpleAttribute( + device, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK); +} + +absl::StatusOr GetThreadsPerWarp(CUdevice device) { + return GetSimpleAttribute(device, CU_DEVICE_ATTRIBUTE_WARP_SIZE); +} + +absl::Status GetGridLimits(int* x, int* y, int* z, CUdevice device) { + int value; + TF_RETURN_IF_ERROR(cuda::ToStatus( + cuDeviceGetAttribute(&value, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X, device), + "Could not get device attribute")); + *x = value; + + TF_RETURN_IF_ERROR(cuda::ToStatus( + cuDeviceGetAttribute(&value, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y, device), + "Could not get device attribute")); + *y = value; + + TF_RETURN_IF_ERROR(cuda::ToStatus( + cuDeviceGetAttribute(&value, CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z, device), + "Could not get device attribute")); + *z = value; + return absl::OkStatus(); +} + +// Returns the device associated with the given device_ordinal. +absl::StatusOr GetDevice(int device_ordinal) { + CUdevice device; + TF_RETURN_IF_ERROR(cuda::ToStatus(cuDeviceGet(&device, device_ordinal), + "Failed call to cuDeviceGet")); + return device; +} + +// Returns the device associated with the given context. +absl::StatusOr DeviceFromContext(Context* context) { + ScopedActivateContext activated{context}; + CUdevice device = -1; + auto status = cuda::ToStatus(cuCtxGetDevice(&device)); + if (status.ok()) { + return device; + } + + return status; +} + +bool CanEnablePeerAccess(CUdevice from, CUdevice to) { + int can_access_peer = -1; + auto status = + cuda::ToStatus(cuDeviceCanAccessPeer(&can_access_peer, from, to)); + if (!status.ok()) { + LOG(ERROR) << "failed to detect peer access capability: " << status; + return false; + } + return can_access_peer; +} + +bool CanEnablePeerAccess(Context* from, Context* to) { + if (from == to) { + return true; // A context can always access its own memory. + } + + auto from_device = DeviceFromContext(from); + if (!from_device.ok()) { + LOG(ERROR) << "failed to resolve 'from' peer access context to a device: " + << from_device.status(); + return false; + } + auto to_device = DeviceFromContext(to); + if (!to_device.ok()) { + LOG(ERROR) << "failed to resolve 'to' peer access context to a device: " + << to_device.status(); + return false; + } + return CanEnablePeerAccess(from_device.value(), to_device.value()); +} + +absl::Status EnablePeerAccess(Context* from, Context* to) { + if (from == to) { + return absl::OkStatus(); // A context can always access its own + // memory. + } + + ScopedActivateContext activated{from}; + CUresult result = cuCtxEnablePeerAccess( + tensorflow::down_cast(to)->context(), 0 /* = flags */); + if (result != CUDA_SUCCESS && + result != CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED) { + return absl::InternalError( + absl::StrFormat("failed to enable peer access from %p to %p: %s", from, + to, cuda::ToStatus(result).ToString())); + } + + return absl::OkStatus(); +} + +// Returns the total amount of memory available on the device. +bool GetDeviceTotalMemory(CUdevice device, uint64_t* result) { + size_t value{}; + auto status = cuda::ToStatus(cuDeviceTotalMem(&value, device)); + if (!status.ok()) { + LOG(ERROR) << "failed to query total available memory: " << status; + return false; + } + + *result = value; + return true; +} + +bool IsEccEnabled(CUdevice device, bool* result) { + int value = -1; + auto status = cuda::ToStatus( + cuDeviceGetAttribute(&value, CU_DEVICE_ATTRIBUTE_ECC_ENABLED, device)); + if (!status.ok()) { + LOG(ERROR) << "failed to query ECC status: " << status; + return false; + } + + *result = value; + return true; +} + +std::string GetPCIBusID(CUdevice device) { + std::string pci_bus_id; + static const int kBufferSize = 64; + absl::InlinedVector chars(kBufferSize); + chars[kBufferSize - 1] = '\0'; + auto status = cuda::ToStatus( + cuDeviceGetPCIBusId(chars.begin(), kBufferSize - 1, device)); + if (!status.ok()) { + LOG(ERROR) << "failed to query PCI bus id for device: " << status; + return pci_bus_id; + } + pci_bus_id = std::string(chars.begin(), kBufferSize - 1); + return pci_bus_id; +} + +// Allocates memory on the GPU device. +void* DeviceAllocate(Context* context, uint64_t bytes) { + if (bytes == 0) { + return nullptr; + } + + ScopedActivateContext activated{context}; + CUdeviceptr result = 0; + auto status = cuda::ToStatus(cuMemAlloc(&result, bytes)); + if (!status.ok()) { + // LOG(INFO) because this isn't always important to users (e.g. BFCAllocator + // implements a retry if the first allocation fails). + LOG(INFO) << "failed to allocate " + << tsl::strings::HumanReadableNumBytes(bytes) << " (" << bytes + << " bytes) from device: " << status; + return nullptr; + } + void* ptr = reinterpret_cast(result); + VLOG(2) << "allocated " << ptr << " for context " << context << " of " + << bytes << " bytes"; + return ptr; +} + +// Deallocates memory on the GPU device that was previously allocated via +// DeviceAllocate. +void DeviceDeallocate(Context* context, void* location) { + ScopedActivateContext activation(context); + CUdeviceptr pointer = absl::bit_cast(location); + auto status = cuda::ToStatus(cuMemFree(pointer)); + if (!status.ok()) { + LOG(ERROR) << "failed to free device memory at " << location + << "; result: " << status; + } else { + VLOG(2) << "deallocated " << location << " for context " << context; + } +} + +// Allocates memory on the host. +void* HostAllocate(Context* context, uint64_t bytes) { + ScopedActivateContext activation(context); + void* host_mem = nullptr; + // "Portable" memory is visible to all CUDA contexts. Safe for our use model. + auto status = cuda::ToStatus( + cuMemHostAlloc(&host_mem, bytes, CU_MEMHOSTALLOC_PORTABLE)); + if (!status.ok()) { + LOG(ERROR) << "failed to alloc " << bytes << " bytes on host: " << status; + } + return host_mem; +} + +// Deallocates memory allocated via HostAllocate. +void HostDeallocate(Context* context, void* location) { + ScopedActivateContext activation(context); + auto status = cuda::ToStatus(cuMemFreeHost(location)); + if (!status.ok()) { + LOG(ERROR) << "error deallocating host memory at " << location << ": " + << status; + } } } // namespace @@ -114,23 +539,62 @@ static CUdeviceptr AsCudaDevicePtr(DeviceMemoryBase* gpu_mem) { return AsCudaDevicePtr(*gpu_mem); } +absl::StatusOr CudaExecutor::GetMemoryRange( + const DeviceMemoryBase& location) { + CUdeviceptr device_pointer; + size_t size; + TF_RETURN_IF_ERROR(cuda::ToStatus( + cuMemGetAddressRange(&device_pointer, &size, AsCudaDevicePtr(location)))); + return DeviceMemoryBase(reinterpret_cast(device_pointer), size); +} + +std::unique_ptr CudaExecutor::Activate() { + return std::make_unique(cuda_context_); +} + CudaExecutor::~CudaExecutor() { - CHECK(kernel_to_gpu_binary_.empty()) << "GpuExecutor has live kernels."; - CHECK(gpu_binary_to_module_.empty()) << "GpuExecutor has loaded modules."; - if (gpu_context() != nullptr) { - GpuDriver::DestroyContext(gpu_context()); + CHECK(kernel_to_gpu_binary_.empty()) << "CudaExecutor has live kernels."; + CHECK(gpu_binary_to_module_.empty()) << "CudaExecutor has loaded modules."; + set_context(nullptr); +} + +void CudaExecutor::UnifiedMemoryDeallocate(void* location) { + std::unique_ptr activation = Activate(); + CUdeviceptr pointer = absl::bit_cast(location); + auto status = cuda::ToStatus(cuMemFree(pointer)); + if (!status.ok()) { + LOG(ERROR) << "failed to free unified memory at " << location + << "; result: " << status; + } else { + VLOG(2) << "deallocated unified memory at " << location << " for context " + << cuda_context_; + } +} + +void* CudaExecutor::UnifiedMemoryAllocate(uint64_t size) { + std::unique_ptr activation = Activate(); + CUdeviceptr result = 0; + // "Portable" memory is visible to all CUDA contexts. Safe for our use model. + auto status = + cuda::ToStatus(cuMemAllocManaged(&result, size, CU_MEM_ATTACH_GLOBAL)); + if (!status.ok()) { + LOG(ERROR) << "failed to alloc " << size + << " bytes unified memory; result: " << status; + return nullptr; } + void* ptr = reinterpret_cast(result); + VLOG(2) << "allocated " << ptr << " for context " << cuda_context_ << " of " + << size << " bytes in unified memory"; + return ptr; } absl::Status CudaExecutor::Init() { - TF_RETURN_IF_ERROR(GpuDriver::Init()); - TF_RETURN_IF_ERROR(GpuDriver::GetDevice(device_ordinal(), &device_)); - Context* context; - TF_RETURN_IF_ERROR( - GpuDriver::CreateContext(device_ordinal(), device_, &context)); + TF_ASSIGN_OR_RETURN(device_, GetDevice(device_ordinal())); + TF_ASSIGN_OR_RETURN(CudaContext * context, + CudaContext::Create(device_ordinal(), device_)); set_context(context); - TF_RETURN_IF_ERROR( - GpuDriver::GetComputeCapability(&cc_major_, &cc_minor_, device_)); + cuda_context_ = context; + TF_RETURN_IF_ERROR(GetComputeCapability(&cc_major_, &cc_minor_, device_)); TF_ASSIGN_OR_RETURN(delay_kernels_supported_, DelayKernelIsSupported()); return absl::OkStatus(); } @@ -138,61 +602,57 @@ absl::Status CudaExecutor::Init() { absl::StatusOr CudaExecutor::DelayKernelIsSupported() { // Check the assumption that this device supports unified addressing, // otherwise skip the delay kernel - TF_ASSIGN_OR_RETURN(int status, - GpuDriver::GetDeviceAttribute( - CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, device_)); + TF_ASSIGN_OR_RETURN( + int status, + GetDeviceAttribute(CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, device_)); return static_cast(status); } -absl::Status CudaExecutor::LoadModuleFromCuBin(const char* cubin, - CUmodule* module) { +absl::StatusOr CudaExecutor::LoadModuleFromCuBin( + const char* cubin) { + ModuleHandle module_handle{cubin}; uint64_t module_refcount; - std::tie(*module, module_refcount) = gpu_binary_to_module_[cubin]; + CUmodule module; + std::tie(module, module_refcount) = gpu_binary_to_module_[module_handle]; - if (*module == nullptr) { - TF_RETURN_IF_ERROR(GpuDriver::LoadCubin(gpu_context(), cubin, module)); + if (module == nullptr) { + TF_ASSIGN_OR_RETURN(module, LoadCubin(cuda_context_, cubin)); module_refcount = 1; VLOG(3) << "Loaded CUBIN " << static_cast(cubin) - << " as module " << *module; + << " as module " << module; } else { ++module_refcount; VLOG(3) << "CUBIN " << static_cast(cubin) - << " is already loaded as module " << *module; + << " is already loaded as module " << module; } - gpu_binary_to_module_[cubin] = {*module, module_refcount}; - return absl::OkStatus(); + gpu_binary_to_module_[module_handle] = {module, module_refcount}; + return module_handle; } -absl::Status CudaExecutor::LoadModuleFromPtx(const char* ptx, - CUmodule* module) { +absl::StatusOr CudaExecutor::LoadModuleFromPtx(const char* ptx) { + ModuleHandle module_handle{ptx}; uint64_t module_refcount; - std::tie(*module, module_refcount) = gpu_binary_to_module_[ptx]; + CUmodule module; + std::tie(module, module_refcount) = gpu_binary_to_module_[module_handle]; - if (*module == nullptr) { - TF_RETURN_IF_ERROR(GpuDriver::LoadPtx(gpu_context(), ptx, module)); + if (module == nullptr) { + TF_ASSIGN_OR_RETURN(module, LoadPtx(cuda_context_, ptx)); VLOG(3) << "Loaded PTX " << static_cast(ptx) << " as module " - << *module; + << module; module_refcount = 1; } else { ++module_refcount; VLOG(3) << "PTX " << static_cast(ptx) << " is already loaded as module " << module; } - gpu_binary_to_module_[ptx] = {*module, module_refcount}; - return absl::OkStatus(); -} - -absl::Status CudaExecutor::LoadModuleFromHsaco(const char* hsaco, - CUmodule* module) { - return absl::InternalError( - "Feature not supported on CUDA platform (LoadModuleFromHsaco)"); + gpu_binary_to_module_[module_handle] = {module, module_refcount}; + return module_handle; } absl::StatusOr> CudaExecutor::LoadKernel( const MultiKernelLoaderSpec& spec) { - auto cuda_kernel = std::make_unique(this); - CUmodule module; + auto cuda_kernel = std::make_unique(this); const std::string* kernel_name; if (spec.has_cuda_cubin_in_memory()) { @@ -200,8 +660,15 @@ absl::StatusOr> CudaExecutor::LoadKernel( kernel_name = &spec.cuda_cubin_in_memory().kernel_name(); const char* cubin = reinterpret_cast( spec.cuda_cubin_in_memory().cubin_bytes().data()); - TF_RETURN_IF_ERROR(LoadModuleFromCuBin(cubin, &module)); - kernel_to_gpu_binary_[cuda_kernel.get()] = cubin; + TF_ASSIGN_OR_RETURN(ModuleHandle module_handle, LoadModuleFromCuBin(cubin)); + kernel_to_gpu_binary_[cuda_kernel.get()] = module_handle; + + CUmodule module = gpu_binary_to_module_.at(module_handle).first; + VLOG(2) << "getting function " << *kernel_name << " from module " << module; + TF_ASSIGN_OR_RETURN( + CUfunction function, + GetModuleFunction(cuda_context_, module, kernel_name->c_str())); + cuda_kernel->set_gpu_function(function); } else if (spec.has_cuda_ptx_in_memory()) { kernel_name = &spec.cuda_ptx_in_memory().kernel_name(); @@ -219,8 +686,15 @@ absl::StatusOr> CudaExecutor::LoadKernel( } absl::MutexLock lock{&in_memory_modules_mu_}; - TF_RETURN_IF_ERROR(LoadModuleFromPtx(ptx, &module)); - kernel_to_gpu_binary_[cuda_kernel.get()] = ptx; + TF_ASSIGN_OR_RETURN(ModuleHandle module_handle, LoadModuleFromPtx(ptx)); + kernel_to_gpu_binary_[cuda_kernel.get()] = module_handle; + + CUmodule module = gpu_binary_to_module_.at(module_handle).first; + VLOG(2) << "getting function " << *kernel_name << " from module " << module; + TF_ASSIGN_OR_RETURN( + CUfunction function, + GetModuleFunction(cuda_context_, module, kernel_name->c_str())); + cuda_kernel->set_gpu_function(function); } else if (spec.has_in_process_symbol()) { kernel_name = &spec.in_process_symbol().kernel_name(); @@ -229,23 +703,14 @@ absl::StatusOr> CudaExecutor::LoadKernel( VLOG(2) << "Resolve CUDA kernel " << *kernel_name << " from symbol pointer: " << symbol; TF_ASSIGN_OR_RETURN( - GpuFunctionHandle function, - GpuRuntime::GetFuncBySymbol(spec.in_process_symbol().symbol())); + CUfunction function, + CudaRuntime::GetFuncBySymbol(spec.in_process_symbol().symbol())); cuda_kernel->set_gpu_function(function); } else { return absl::InternalError("No method of loading CUDA kernel provided"); } VLOG(3) << "LoadKernel on kernel : " << *kernel_name; - // If we resolved kernel from a symbol pointer, there is no need to load it - // from a module, as CUDA runtime did that automatically for us. - if (!spec.has_in_process_symbol()) { - VLOG(2) << "getting function " << *kernel_name << " from module " << module; - GpuFunctionHandle function; - TF_RETURN_IF_ERROR(GpuDriver::GetModuleFunction( - gpu_context(), module, kernel_name->c_str(), &function)); - cuda_kernel->set_gpu_function(function); - } // Update CUDA kernel properties after it was loaded in the CUDA context. cuda_kernel->set_name(*kernel_name); @@ -254,8 +719,8 @@ absl::StatusOr> CudaExecutor::LoadKernel( // to be a way to reflect on the number of expected arguments w/the CUDA API. cuda_kernel->set_arity(spec.arity()); - KernelMetadata kernel_metadata; - TF_RETURN_IF_ERROR(GetKernelMetadata(cuda_kernel.get(), &kernel_metadata)); + TF_ASSIGN_OR_RETURN(KernelMetadata kernel_metadata, + cuda_kernel->GetKernelMetadata()); cuda_kernel->set_metadata(kernel_metadata); cuda_kernel->set_name(*kernel_name); cuda_kernel->set_args_packing(spec.kernel_args_packing()); @@ -263,22 +728,19 @@ absl::StatusOr> CudaExecutor::LoadKernel( } absl::StatusOr> -CudaExecutor::CreateEventBasedTimer(GpuStream* stream, bool use_delay_kernel) { - GpuSemaphore semaphore{}; - - if (use_delay_kernel && ShouldLaunchDelayKernel() && - delay_kernels_supported_) { - TF_ASSIGN_OR_RETURN(semaphore, LaunchDelayKernel(stream)); - } - TF_ASSIGN_OR_RETURN(auto start_event, CreateGpuEvent(/*allow_timing=*/true)); - TF_ASSIGN_OR_RETURN(auto stop_event, CreateGpuEvent(/*allow_timing=*/true)); - TF_RETURN_IF_ERROR(start_event->Record(stream->gpu_stream())); - return std::make_unique(gpu_context(), std::move(start_event), - std::move(stop_event), stream, - std::move(semaphore)); +CudaExecutor::CreateEventBasedTimer(Stream* stream, bool use_delay_kernel) { + const CudaTimer::TimerType timer_type = + (use_delay_kernel && ShouldLaunchDelayKernel() && + delay_kernels_supported_) + ? CudaTimer::TimerType::kDelayKernel + : CudaTimer::TimerType::kEventBased; + + TF_ASSIGN_OR_RETURN(CudaTimer timer, + CudaTimer::Create(this, stream, timer_type)); + return std::make_unique(std::move(timer)); } -bool CudaExecutor::UnloadGpuBinary(const void* gpu_binary) { +bool CudaExecutor::UnloadGpuBinary(ModuleHandle gpu_binary) { auto module_it = gpu_binary_to_module_.find(gpu_binary); if (gpu_binary_to_module_.end() == module_it) { VLOG(3) << "No loaded CUDA module for " << gpu_binary; @@ -289,7 +751,7 @@ bool CudaExecutor::UnloadGpuBinary(const void* gpu_binary) { VLOG(3) << "Found CUDA module " << module << " with refcount " << refcount; if (--refcount == 0) { VLOG(3) << "Unloading CUDA module " << module; - GpuDriver::UnloadModule(gpu_context(), module); + UnloadCudaModule(cuda_context_, module); gpu_binary_to_module_.erase(module_it); } return true; @@ -313,19 +775,14 @@ void CudaExecutor::UnloadKernel(const Kernel* kernel) { kernel_to_gpu_binary_.erase(gpu_binary_it); } -absl::Status CudaExecutor::LoadModule(const MultiModuleLoaderSpec& spec, - ModuleHandle* module_handle) { - // In GpuExecutor we store the pointer to the GPU binary (PTX or CUBIN) as +absl::StatusOr CudaExecutor::LoadModule( + const MultiModuleLoaderSpec& spec) { + // We store the pointer to the GPU binary (PTX or CUBIN) as // ModuleHandle::id(). - CUmodule cu_module; if (spec.has_cuda_cubin_in_memory()) { absl::MutexLock lock{&in_memory_modules_mu_}; - TF_RETURN_IF_ERROR(LoadModuleFromCuBin( - reinterpret_cast(spec.cuda_cubin_in_memory().data()), - &cu_module)); - *module_handle = ModuleHandle(const_cast( - static_cast(spec.cuda_cubin_in_memory().data()))); - return absl::OkStatus(); + return LoadModuleFromCuBin( + reinterpret_cast(spec.cuda_cubin_in_memory().data())); } else if (spec.has_cuda_ptx_in_memory()) { if (cc_major_ == 0 && cc_minor_ == 0) { return absl::InternalError("Compute capability not set"); @@ -336,19 +793,14 @@ absl::Status CudaExecutor::LoadModule(const MultiModuleLoaderSpec& spec, } absl::MutexLock lock{&in_memory_modules_mu_}; - TF_RETURN_IF_ERROR( - LoadModuleFromPtx(spec.cuda_ptx_in_memory(), &cu_module)); - *module_handle = ModuleHandle( - const_cast(static_cast(spec.cuda_ptx_in_memory()))); - return absl::OkStatus(); + return LoadModuleFromPtx(spec.cuda_ptx_in_memory()); } return absl::InternalError("No method of loading CUDA module provided"); } bool CudaExecutor::UnloadModule(ModuleHandle module_handle) { - const char* gpu_binary = reinterpret_cast(module_handle.id()); absl::MutexLock lock{&in_memory_modules_mu_}; - return UnloadGpuBinary(gpu_binary); + return UnloadGpuBinary(module_handle); } namespace { @@ -426,32 +878,29 @@ CudaExecutor::CreateOrShareConstant(Stream* stream, return shared_constant; } -absl::Status CudaExecutor::GetKernelMetadata(GpuKernel* cuda_kernel, - KernelMetadata* kernel_metadata) { - int value; - TF_RETURN_IF_ERROR(FuncGetAttribute(CU_FUNC_ATTRIBUTE_NUM_REGS, - cuda_kernel->gpu_function(), &value)); - kernel_metadata->set_registers_per_thread(value); - - TF_RETURN_IF_ERROR(FuncGetAttribute(CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, - cuda_kernel->gpu_function(), &value)); - kernel_metadata->set_shared_memory_bytes(value); - return absl::OkStatus(); -} - DeviceMemoryBase CudaExecutor::Allocate(uint64_t size, int64_t memory_space) { if (memory_space == 1) { - auto result = GpuCollectives::CollectiveMemoryAllocate(gpu_context(), size); + auto result = CudaCollectives::CollectiveMemoryAllocate(this, size); if (!result.ok()) { LOG(ERROR) << result.status(); } - return DeviceMemoryBase(*result, size); + return DeviceMemoryBase(nullptr, 0); } else if (memory_space == static_cast(stream_executor::MemoryType::kHost)) { - return DeviceMemoryBase(GpuDriver::HostAllocate(gpu_context(), size), size); + return DeviceMemoryBase(HostAllocate(cuda_context_, size), size); } CHECK_EQ(memory_space, 0); - return DeviceMemoryBase(GpuDriver::DeviceAllocate(gpu_context(), size), size); + return DeviceMemoryBase(DeviceAllocate(cuda_context_, size), size); +} + +absl::StatusOr> +CudaExecutor::HostMemoryAllocate(uint64_t size) { + auto* buffer = HostAllocate(cuda_context_, size); + if (buffer == nullptr && size > 0) { + return absl::InternalError( + absl::StrFormat("Failed to allocate HostMemory of size %d", size)); + } + return std::make_unique(buffer, size, this); } void CudaExecutor::Deallocate(DeviceMemoryBase* mem) { @@ -462,39 +911,89 @@ void CudaExecutor::Deallocate(DeviceMemoryBase* mem) { } auto memory_space = status_or_memory_space.value(); if (memory_space == MemoryType::kHost) { - GpuDriver::HostDeallocate(gpu_context(), mem->opaque()); + HostDeallocate(cuda_context_, mem->opaque()); } else { - GpuDriver::DeviceDeallocate(gpu_context(), mem->opaque()); + DeviceDeallocate(cuda_context_, mem->opaque()); } } +void CudaExecutor::HostMemoryDeallocate(void* location) { + return HostDeallocate(cuda_context_, location); +} + bool CudaExecutor::SynchronizeAllActivity() { - return GpuDriver::SynchronizeContext(gpu_context()).ok(); + return cuda_context_->Synchronize().ok(); +} + +bool CudaExecutor::HostMemoryRegister(void* location, uint64_t size) { + VLOG(1) << "Called StreamExecutor::HostMemoryRegister(data=" << location + << ")"; + + std::unique_ptr activation = Activate(); + // "Portable" memory is visible to all CUDA contexts. Safe for our use model. + auto status = cuda::ToStatus( + cuMemHostRegister(location, size, CU_MEMHOSTREGISTER_PORTABLE)); + if (!status.ok()) { + LOG(ERROR) << "error registering host memory at " << location << ": " + << status; + return false; + } + return true; +} + +bool CudaExecutor::HostMemoryUnregister(void* location) { + VLOG(1) << "Called StreamExecutor::HostUnregister(data=" << location << ")"; + + std::unique_ptr activation = Activate(); + auto status = cuda::ToStatus(cuMemHostUnregister(location)); + if (!status.ok()) { + LOG(ERROR) << "error unregistering host memory at " << location << ": " + << status; + return false; + } + return true; } absl::Status CudaExecutor::SynchronousMemZero(DeviceMemoryBase* location, uint64_t size) { - if (reinterpret_cast(location->opaque()) % 4 == 0 && - size % 4 == 0) { - return GpuDriver::SynchronousMemsetUint32( - gpu_context(), AsCudaDevicePtr(location), 0x0, size / 4); + std::unique_ptr activation = Activate(); + CUdeviceptr cuda_location = AsCudaDevicePtr(location); + if (reinterpret_cast(location->opaque()) % sizeof(uint32_t) == 0 && + size % sizeof(uint32_t) == 0) { + return cuda::ToStatus( + cuMemsetD32(cuda_location, 0x0, size / sizeof(uint32_t)), + "Failed to memset memory"); } - return GpuDriver::SynchronousMemsetUint8( - gpu_context(), AsCudaDevicePtr(location), 0x0, size); + return cuda::ToStatus(cuMemsetD8(cuda_location, 0x0, size), + "Failed to memset memory"); } absl::Status CudaExecutor::SynchronousMemcpy(DeviceMemoryBase* gpu_dst, const void* host_src, uint64_t size) { - return GpuDriver::SynchronousMemcpyH2D( - gpu_context(), AsCudaDevicePtr(gpu_dst), host_src, size); + std::unique_ptr activation = Activate(); + TF_RETURN_IF_ERROR(cuda::ToStatus( + cuMemcpyHtoD(AsCudaDevicePtr(gpu_dst), host_src, size), + absl::StrFormat( + "failed to synchronous memcpy from host to device: GPU dst: %llx;" + " host src: %p; size: %u=0x%x", + AsCudaDevicePtr(gpu_dst), host_src, size, size))); + VLOG(2) << "successfully enqueued sync memcpy h2d of " << size << " bytes"; + return absl::OkStatus(); } absl::Status CudaExecutor::SynchronousMemcpy(void* host_dst, const DeviceMemoryBase& gpu_src, uint64_t size) { - return GpuDriver::SynchronousMemcpyD2H(gpu_context(), host_dst, - AsCudaDevicePtr(gpu_src), size); + std::unique_ptr activation = Activate(); + TF_RETURN_IF_ERROR(cuda::ToStatus( + cuMemcpyDtoH(host_dst, AsCudaDevicePtr(gpu_src), size), + absl::StrFormat("failed to synchronous memcpy from device to host " + "host dst: %p; GPU src: %llx; size: %u=0x%x", + host_dst, AsCudaDevicePtr(gpu_src), size, size))); + VLOG(2) << "successfully sync memcpy'd d2h of " << size << " bytes to " + << host_dst; + return absl::OkStatus(); } void CudaExecutor::DeallocateStream(Stream* stream) { @@ -504,13 +1003,8 @@ void CudaExecutor::DeallocateStream(Stream* stream) { dnn_->NotifyStreamDestroyed(stream); } } - GpuStream* gpu_stream = AsGpuStream(stream); absl::MutexLock l(&alive_gpu_streams_mu_); - alive_gpu_streams_.erase(gpu_stream->gpu_stream()); -} - -absl::Status CudaExecutor::BlockHostUntilDone(Stream* stream) { - return GpuDriver::SynchronizeStream(gpu_context(), AsGpuStreamValue(stream)); + alive_gpu_streams_.erase(stream->platform_specific_handle().stream); } blas::BlasSupport* CudaExecutor::AsBlas() { @@ -575,18 +1069,29 @@ fft::FftSupport* CudaExecutor::AsFft() { } bool CudaExecutor::CanEnablePeerAccessTo(StreamExecutor* other) { - GpuExecutor* cuda_other = static_cast(other); - return GpuDriver::CanEnablePeerAccess(gpu_context(), - cuda_other->gpu_context()); + CudaExecutor* cuda_other = static_cast(other); + return CanEnablePeerAccess(cuda_context_, cuda_other->cuda_context_); } absl::Status CudaExecutor::EnablePeerAccessTo(StreamExecutor* other) { - GpuExecutor* cuda_other = static_cast(other); - return GpuDriver::EnablePeerAccess(gpu_context(), cuda_other->gpu_context()); + CudaExecutor* cuda_other = static_cast(other); + return EnablePeerAccess(cuda_context_, cuda_other->cuda_context_); } -bool CudaExecutor::DeviceMemoryUsage(int64_t* free, int64_t* total) const { - return GpuDriver::GetDeviceMemoryInfo(gpu_context(), free, total); +bool CudaExecutor::DeviceMemoryUsage(int64_t* free_out, + int64_t* total_out) const { + ScopedActivateContext activation(cuda_context_); + size_t free = 0; + size_t total = 0; + auto status = cuda::ToStatus(cuMemGetInfo(&free, &total)); + if (!status.ok()) { + LOG(ERROR) << "failed to query device memory info: " << status; + return false; + } + + *free_out = free; + *total_out = total; + return true; } absl::StatusOr CudaExecutor::GetSymbol( @@ -597,14 +1102,14 @@ absl::StatusOr CudaExecutor::GetSymbol( { // give limited scope to mutex_lock absl::MutexLock lock{&in_memory_modules_mu_}; - auto it = gpu_binary_to_module_.find(module_handle.id()); + auto it = gpu_binary_to_module_.find(module_handle); CHECK(it != gpu_binary_to_module_.end()); - GpuModuleHandle gpu_module_handle = it->second.first; + CUmodule gpu_module_handle = it->second.first; CHECK(gpu_module_handle != nullptr); - TF_RETURN_IF_ERROR(GpuDriver::GetModuleSymbol( - gpu_context(), gpu_module_handle, symbol_name.c_str(), - reinterpret_cast(&mem), &bytes)); + TF_RETURN_IF_ERROR( + GetModuleSymbol(cuda_context_, gpu_module_handle, symbol_name.c_str(), + reinterpret_cast(&mem), &bytes)); return DeviceMemoryBase(mem, bytes); } @@ -614,63 +1119,50 @@ absl::StatusOr CudaExecutor::GetSymbol( reinterpret_cast(module_handle.id()), ")")); } -absl::Status FillBlockDimLimit(GpuDeviceHandle device, - BlockDim* block_dim_limit) { +absl::Status FillBlockDimLimit(CUdevice device, BlockDim* block_dim_limit) { // The BlockDim name is a mismatch against these GRID_DIM_* queries because // we use BlockDims to express the dimensions of blocks within a grid // (as opposed to ThreadDim which expresses the dimensions of threads // within a block). int x, y, z; - TF_RETURN_IF_ERROR(GpuDriver::GetGridLimits(&x, &y, &z, device)); + TF_RETURN_IF_ERROR(GetGridLimits(&x, &y, &z, device)); block_dim_limit->x = x; block_dim_limit->y = y; block_dim_limit->z = z; return absl::OkStatus(); } -absl::StatusOr> CudaExecutor::CreateGpuEvent( - bool allow_timing) { - auto gpu_event = std::make_unique(gpu_context()); - TF_RETURN_IF_ERROR(gpu_event->Init(allow_timing)); - return std::move(gpu_event); -} - absl::StatusOr> CudaExecutor::CreateEvent() { - return CreateGpuEvent(/*allow_timing=*/false); + TF_ASSIGN_OR_RETURN(auto event, CudaEvent::Create(this, false)); + return std::make_unique(std::move(event)); } absl::StatusOr> CudaExecutor::CreateStream( std::optional> priority) { - TF_ASSIGN_OR_RETURN(auto event, CreateGpuEvent(/*allow_timing=*/false)); - auto stream = std::make_unique(this, std::move(event), priority); + TF_ASSIGN_OR_RETURN(auto stream, CudaStream::Create(this, priority)); absl::MutexLock l(&alive_gpu_streams_mu_); - TF_RETURN_IF_ERROR(stream->Init()); - auto gpu_stream = stream->gpu_stream(); - alive_gpu_streams_[gpu_stream] = stream.get(); + alive_gpu_streams_[stream->stream_handle()] = stream.get(); return std::move(stream); } absl::StatusOr> CudaExecutor::CreateCommandBuffer(CommandBuffer::Mode mode) { VLOG(2) << "Create CUDA command buffer (CUDA graph)"; - GpuGraphHandle graph = nullptr; - TF_RETURN_IF_ERROR(GpuDriver::CreateGraph(&graph)); - return std::make_unique(mode, /*parent=*/this, graph); + return CudaCommandBuffer::Create(mode, this); } absl::Status CudaExecutor::TrimGraphMemory() { - return GpuDriver::DeviceGraphMemTrim(device_); + return cuda::ToStatus(cuDeviceGraphMemTrim(device_), + "Failed to trim device graph memory"); } absl::StatusOr> -GpuExecutor::CreateDeviceDescription(int device_ordinal) { - GpuDeviceHandle device; - TF_RETURN_IF_ERROR(GpuDriver::GetDevice(device_ordinal, &device)); +CudaExecutor::CreateDeviceDescription(int device_ordinal) { + TF_ASSIGN_OR_RETURN(CUdevice device, GetDevice(device_ordinal)); int cc_major; int cc_minor; - TF_RETURN_IF_ERROR( - GpuDriver::GetComputeCapability(&cc_major, &cc_minor, device)); + TF_RETURN_IF_ERROR(GetComputeCapability(&cc_major, &cc_minor, device)); DeviceDescription desc; @@ -678,13 +1170,13 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) { ParseCudaVersion(GpuDriver::GetDriverVersion().value_or(0)) .value_or(SemanticVersion{0, 0, 0})); desc.set_runtime_version( - ParseCudaVersion(GpuRuntime::GetRuntimeVersion().value_or(0)) + ParseCudaVersion(CudaRuntime::GetRuntimeVersion().value_or(0)) .value_or(SemanticVersion{0, 0, 0})); desc.set_compile_time_toolkit_version( ParseCudaVersion(CUDA_VERSION).value_or(SemanticVersion{0, 0, 0})); { - std::string pci_bus_id = GpuDriver::GetPCIBusID(device); + std::string pci_bus_id = GetPCIBusID(device); // Lower the hex characters to match sysfs. pci_bus_id = absl::AsciiStrToLower(pci_bus_id); @@ -697,46 +1189,40 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) { { desc.set_threads_per_block_limit( - GpuDriver::GetDeviceAttribute(CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK, - device) + GetDeviceAttribute(CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK, device) .value()); ThreadDim thread_dim_limit; - thread_dim_limit.x = GpuDriver::GetDeviceAttribute( - CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X, device) - .value(); - thread_dim_limit.y = GpuDriver::GetDeviceAttribute( - CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y, device) - .value(); - thread_dim_limit.z = GpuDriver::GetDeviceAttribute( - CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z, device) - .value(); + thread_dim_limit.x = + GetDeviceAttribute(CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X, device).value(); + thread_dim_limit.y = + GetDeviceAttribute(CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y, device).value(); + thread_dim_limit.z = + GetDeviceAttribute(CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z, device).value(); desc.set_thread_dim_limit(thread_dim_limit); } int sm_clock_khz = - GpuDriver::GetDeviceAttribute(CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device) - .value(); + GetDeviceAttribute(CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device).value(); desc.set_clock_rate_ghz(static_cast(sm_clock_khz) / 1e6); { bool ecc_enabled = false; - (void)GpuDriver::IsEccEnabled(device, &ecc_enabled); + IsEccEnabled(device, &ecc_enabled); desc.set_ecc_enabled(ecc_enabled); } uint64_t device_memory_size = static_cast(-1); - (void)GpuDriver::GetDeviceTotalMemory(device, &device_memory_size); + GetDeviceTotalMemory(device, &device_memory_size); desc.set_device_memory_size(device_memory_size); int64_t l2_cache_bytes = - GpuDriver::GetDeviceAttribute(CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE, device) - .value(); + GetDeviceAttribute(CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE, device).value(); desc.set_l2_cache_size(l2_cache_bytes); - absl::StatusOr mem_clock_khz = GpuDriver::GetDeviceAttribute( - CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device_ordinal); - absl::StatusOr mem_bus_width_bits = GpuDriver::GetDeviceAttribute( + absl::StatusOr mem_clock_khz = + GetDeviceAttribute(CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device_ordinal); + absl::StatusOr mem_bus_width_bits = GetDeviceAttribute( CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device_ordinal); if (mem_clock_khz.ok() && mem_bus_width_bits.ok()) { // Times 2 because HBM is DDR memory; it gets two data bits per each data @@ -752,8 +1238,7 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) { } { - std::string device_name; - TF_RETURN_IF_ERROR(GpuDriver::GetDeviceName(device, &device_name)); + TF_ASSIGN_OR_RETURN(std::string device_name, GetDeviceName(device)); desc.set_name(device_name); } @@ -766,23 +1251,20 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) { desc.set_device_vendor("NVIDIA Corporation"); desc.set_cuda_compute_capability(cc_major, cc_minor); - desc.set_shared_memory_per_core( - GpuDriver::GetMaxSharedMemoryPerCore(device).value()); - desc.set_shared_memory_per_block( - GpuDriver::GetMaxSharedMemoryPerBlock(device).value()); + desc.set_shared_memory_per_core(GetMaxSharedMemoryPerCore(device).value()); + desc.set_shared_memory_per_block(GetMaxSharedMemoryPerBlock(device).value()); desc.set_shared_memory_per_block_optin( - GpuDriver::GetMaxSharedMemoryPerBlockOptin(device).value()); - int core_count = GpuDriver::GetMultiprocessorCount(device).value(); + GetMaxSharedMemoryPerBlockOptin(device).value()); + int core_count = GetMultiprocessorCount(device).value(); desc.set_core_count(core_count); desc.set_fpus_per_core(fpus_per_core(cc_major, cc_minor)); desc.set_threads_per_core_limit( - GpuDriver::GetMaxThreadsPerMultiprocessor(device).value()); - desc.set_registers_per_block_limit( - GpuDriver::GetMaxRegistersPerBlock(device).value()); - desc.set_threads_per_warp(GpuDriver::GetThreadsPerWarp(device).value()); + GetMaxThreadsPerMultiprocessor(device).value()); + desc.set_registers_per_block_limit(GetMaxRegistersPerBlock(device).value()); + desc.set_threads_per_warp(GetThreadsPerWarp(device).value()); desc.set_registers_per_core_limit( - GpuDriver::GetDeviceAttribute( - CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR, device) + GetDeviceAttribute(CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR, + device) .value()); auto value_or = [](const auto& status_or, auto default_val) { @@ -794,7 +1276,7 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) { // identifier for the GPU model. But getting this requires using NVML or // other hacks, which we don't have access to in OSS TensorFlow. // - // Alternatively you might be tempted to use GpuDriver::GetDeviceName as a + // Alternatively you might be tempted to use GetDeviceName as a // unique identifier, but this is not stable across GPU VBIOS versions. // // For now, this identifier is good enough. @@ -806,5 +1288,22 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) { return std::make_unique(std::move(desc)); } +absl::StatusOr CudaExecutor::GetPointerMemorySpace( + const void* ptr) { + CUdeviceptr pointer = reinterpret_cast(const_cast(ptr)); + unsigned int value; + TF_RETURN_IF_ERROR(cuda::ToStatus(cuPointerGetAttribute( + &value, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, pointer))); + switch (value) { + case CU_MEMORYTYPE_DEVICE: + return MemoryType::kDevice; + case CU_MEMORYTYPE_HOST: + return MemoryType::kHost; + default: + return absl::InternalError( + absl::StrCat("unknown memory space provided by CUDA API: ", value)); + } +} + } // namespace gpu } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.h b/third_party/xla/xla/stream_executor/cuda/cuda_executor.h index 60ecd23d05d11d..50ba17fe4ec18b 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.h @@ -1,3 +1,5 @@ +#include "xla/stream_executor/activate_context.h" +#include "xla/stream_executor/cuda/cuda_context.h" /* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); @@ -21,7 +23,7 @@ limitations under the License. #include #include #include -#include +#include #include #include "absl/base/thread_annotations.h" @@ -29,24 +31,19 @@ limitations under the License. #include "absl/numeric/int128.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/cuda/cuda_collectives.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/fft.h" -#include "xla/stream_executor/gpu/gpu_collectives.h" -#include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_kernel.h" -#include "xla/stream_executor/gpu/gpu_types.h" -#include "xla/stream_executor/host_memory_allocation.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/memory_allocation.h" @@ -62,19 +59,22 @@ class CudaExecutor : public GpuExecutor { CudaExecutor(Platform* platform, int device_ordinal) : GpuExecutor(platform, device_ordinal) {} ~CudaExecutor() override; + std::unique_ptr Activate() override; absl::Status Init() override; bool SynchronizeAllActivity() override; + absl::StatusOr GetMemoryRange( + const DeviceMemoryBase& location) override; absl::StatusOr CollectiveMemoryAllocate(uint64_t size) override { - return GpuCollectives::CollectiveMemoryAllocate(gpu_context(), size); + return CudaCollectives::CollectiveMemoryAllocate(this, size); } absl::Status CollectiveMemoryDeallocate(void* location) override { - return GpuCollectives::CollectiveMemoryDeallocate(gpu_context(), location); + return CudaCollectives::CollectiveMemoryDeallocate(this, location); } absl::StatusOr> CreateEventBasedTimer( - GpuStream* stream, bool use_delay_kernel) override; + Stream* stream, bool use_delay_kernel) override; absl::StatusOr GetSymbol( const std::string& symbol_name, ModuleHandle module_handle) override; absl::Status SynchronousMemZero(DeviceMemoryBase* location, @@ -85,15 +85,14 @@ class CudaExecutor : public GpuExecutor { const DeviceMemoryBase& gpu_src, uint64_t size) override; void DeallocateStream(Stream* stream) override; - absl::Status BlockHostUntilDone(Stream* stream) override; absl::Status EnablePeerAccessTo(StreamExecutor* other) override; bool CanEnablePeerAccessTo(StreamExecutor* other) override; - bool DeviceMemoryUsage(int64_t* free, int64_t* total) const override; + bool DeviceMemoryUsage(int64_t* free_out, int64_t* total_out) const override; absl::StatusOr> LoadKernel( const MultiKernelLoaderSpec& spec) override; void UnloadKernel(const Kernel* kernel) override; - absl::Status LoadModule(const MultiModuleLoaderSpec& spec, - ModuleHandle* module_handle) override; + absl::StatusOr LoadModule( + const MultiModuleLoaderSpec& spec) override; bool UnloadModule(ModuleHandle module_handle) override; absl::StatusOr> CreateOrShareConstant( Stream* stream, absl::Span content) override; @@ -113,33 +112,18 @@ class CudaExecutor : public GpuExecutor { absl::StatusOr> CreateDeviceDescription() const override { - return GpuExecutor::CreateDeviceDescription(device_ordinal()); - } - void* UnifiedMemoryAllocate(uint64_t size) override { - return GpuDriver::UnifiedMemoryAllocate(gpu_context(), size); - } - - void UnifiedMemoryDeallocate(void* location) override { - return GpuDriver::UnifiedMemoryDeallocate(gpu_context(), location); + return CudaExecutor::CreateDeviceDescription(device_ordinal()); } + void* UnifiedMemoryAllocate(uint64_t size) override; + void UnifiedMemoryDeallocate(void* location) override; absl::StatusOr> HostMemoryAllocate( - uint64_t size) override { - auto* buffer = GpuDriver::HostAllocate(gpu_context(), size); - if (buffer == nullptr && size > 0) { - return absl::InternalError( - absl::StrFormat("Failed to allocate HostMemory of size %d", size)); - } - return std::make_unique(buffer, size, this); - } + uint64_t size) override; - void HostMemoryDeallocate(void* location) override { - return GpuDriver::HostDeallocate(gpu_context(), location); - } + void HostMemoryDeallocate(void* location) override; + bool HostMemoryRegister(void* location, uint64_t size) override; + bool HostMemoryUnregister(void* location) override; - absl::StatusOr GetPointerMemorySpace(const void* ptr) override { - return GpuDriver::GetPointerMemorySpace( - reinterpret_cast(const_cast(ptr))); - } + absl::StatusOr GetPointerMemorySpace(const void* ptr) override; Stream* FindAllocatedStream(void* gpu_stream) override { absl::MutexLock lock(&alive_gpu_streams_mu_); @@ -150,49 +134,27 @@ class CudaExecutor : public GpuExecutor { return it->second; } - private: - // Collects metadata for the specified kernel. - absl::Status GetKernelMetadata(GpuKernel* cuda_kernel, - KernelMetadata* kernel_metadata); - - // (supported on CUDA only) - absl::Status LoadModuleFromCuBin(const char* cubin, GpuModuleHandle* module) - ABSL_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); + static absl::StatusOr> + CreateDeviceDescription(int device_ordinal); - // Loads the PTX text `ptx` as a CUDA module. `ptx` must be null terminated. - // (supported on CUDA only) - absl::Status LoadModuleFromPtx(const char* ptx, GpuModuleHandle* module) + private: + // Loads a module in cubin format. + absl::StatusOr LoadModuleFromCuBin(const char* cubin) ABSL_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); - // (supported on ROCm only) - absl::Status LoadModuleFromHsaco(const char* hsaco, GpuModuleHandle* module) + // Loads the PTX text `ptx` as a CUDA module. `ptx` must be null terminated. + absl::StatusOr LoadModuleFromPtx(const char* ptx) ABSL_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); - bool UnloadGpuBinary(const void* gpu_binary) + bool UnloadGpuBinary(ModuleHandle gpu_binary) ABSL_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); - // Creates a GpuEvent for the given stream. - absl::StatusOr> CreateGpuEvent(bool allow_timing); - // Returns true if a delay kernel is supported. absl::StatusOr DelayKernelIsSupported(); - // Guards the on-disk-module mapping. - absl::Mutex disk_modules_mu_; - - // Mapping from filename to GPUModuleHandle, if it was already retrieved. - // Multiple GPUFunctionHandle are usually obtained from a single - // GPUModuleHandle so we attempt to hit in this mapping first, before - // retrieving it. - std::map disk_modules_ - ABSL_GUARDED_BY(disk_modules_mu_); - // Guards the in-memory-module mapping. absl::Mutex in_memory_modules_mu_; - std::map in_memory_modules_ - ABSL_GUARDED_BY(in_memory_modules_mu_); - absl::Mutex shared_constants_mu_; // On-device constants that can be shared between multiple executables. A // pointer for a given constant will expire when no executables require use @@ -200,16 +162,17 @@ class CudaExecutor : public GpuExecutor { std::map> shared_constants_ ABSL_GUARDED_BY(shared_constants_mu_); - // Kernel -> loaded GPU binary. Many kernels may load the same binary. - std::unordered_map kernel_to_gpu_binary_ + // Kernel -> loaded GPU module. Many kernels may load the same binary. + absl::flat_hash_map kernel_to_gpu_binary_ ABSL_GUARDED_BY(in_memory_modules_mu_); - // GPU binary (PTX or CUBIN or HSACO) -> {CUDA module, reference count}. - std::unordered_map> + + // Loaded GPU module handle -> {CUDA module, reference count}. + absl::flat_hash_map> gpu_binary_to_module_ ABSL_GUARDED_BY(in_memory_modules_mu_); // Handle for the CUDA device being operated on. Immutable // post-initialization. - GpuDeviceHandle device_; + CUdevice device_; // True if delay kernels are supported. bool delay_kernels_supported_ = false; @@ -240,6 +203,9 @@ class CudaExecutor : public GpuExecutor { // Lookup map for alive streams, from raw stream pointers. absl::flat_hash_map alive_gpu_streams_ ABSL_GUARDED_BY(alive_gpu_streams_mu_); + + // CudaContext for this device. + CudaContext* cuda_context_; }; } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor_test.cc new file mode 100644 index 00000000000000..bf898f18f31b11 --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor_test.cc @@ -0,0 +1,60 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/cuda/cuda_executor.h" + +#include + +#include +#include +#include "absl/log/check.h" +#include "xla/stream_executor/cuda/cuda_platform.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/semantic_version.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor::gpu { +namespace { +using testing::Ge; +using testing::IsEmpty; +using testing::Not; +using testing::VariantWith; + +TEST(CudaExecutorTest, CreateDeviceDescription) { + CudaPlatform platform; + ASSERT_GT(platform.VisibleDeviceCount(), 0); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + CudaExecutor::CreateDeviceDescription(0)); + + constexpr SemanticVersion kNullVersion{0, 0, 0}; + EXPECT_NE(result->runtime_version(), kNullVersion); + EXPECT_NE(result->driver_version(), kNullVersion); + EXPECT_NE(result->compile_time_toolkit_version(), kNullVersion); + + EXPECT_THAT(result->platform_version(), Not(IsEmpty())); + EXPECT_THAT(result->name(), Not(IsEmpty())); + EXPECT_THAT(result->model_str(), Not(IsEmpty())); + EXPECT_THAT(result->device_vendor(), "NVIDIA Corporation"); + + EXPECT_THAT( + result->gpu_compute_capability(), + VariantWith(Ge(CudaComputeCapability{1, 0}))); +} + +} // namespace +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_fft.cc b/third_party/xla/xla/stream_executor/cuda/cuda_fft.cc index 831aa1ccc12606..441befcb0cd26b 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_fft.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_fft.cc @@ -27,16 +27,14 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cufft.h" +#include "xla/stream_executor/activate_context.h" #include "xla/stream_executor/cuda/cuda_helpers.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/fft.h" -#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_helpers.h" #include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/scoped_activate_context.h" #include "xla/stream_executor/platform/initialize.h" -#include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/plugin_registry.h" #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream.h" @@ -74,8 +72,8 @@ cufftType CUDAFftType(fft::Type type) { } // Associates the given stream with the given cuFFT plan. -bool SetStream(GpuExecutor *parent, cufftHandle plan, Stream *stream) { - ScopedActivateContext sac(parent); +bool SetStream(StreamExecutor *parent, cufftHandle plan, Stream *stream) { + std::unique_ptr activation = parent->Activate(); auto ret = cufftSetStream(plan, AsGpuStreamValue(stream)); if (ret != CUFFT_SUCCESS) { LOG(ERROR) << "Failed to run cuFFT routine cufftSetStream: " << ret; @@ -102,16 +100,16 @@ absl::StatusOr> Downsize64bArray( } // namespace absl::Status CUDAFftPlan::Initialize( - GpuExecutor *parent, Stream *stream, int rank, uint64_t *elem_count, - uint64_t *input_embed, uint64 input_stride, uint64 input_distance, - uint64_t *output_embed, uint64 output_stride, uint64 output_distance, + StreamExecutor *parent, Stream *stream, int rank, uint64_t *elem_count, + uint64_t *input_embed, uint64_t input_stride, uint64_t input_distance, + uint64_t *output_embed, uint64_t output_stride, uint64_t output_distance, fft::Type type, int batch_count, ScratchAllocator *scratch_allocator) { if (IsInitialized()) { return absl::InternalError("cuFFT is already initialized."); } is_initialized_ = true; scratch_allocator_ = scratch_allocator; - ScopedActivateContext sac(parent); + std::unique_ptr activation = parent->Activate(); // NOLINTBEGIN std::array elem_count_ = {0}; std::array input_embed_ = {0}; @@ -273,7 +271,7 @@ absl::Status CUDAFftPlan::UpdateScratchAllocator( } } // Connect work area with allocated space. - ScopedActivateContext sac(parent_); + std::unique_ptr activation = parent_->Activate(); cufftResult_t ret = cufftSetWorkArea(plan_, scratch_.opaque()); if (ret != CUFFT_SUCCESS) { LOG(ERROR) << "Failed to set work area for cuFFT plan: " << ret; @@ -283,7 +281,7 @@ absl::Status CUDAFftPlan::UpdateScratchAllocator( } CUDAFftPlan::~CUDAFftPlan() { - ScopedActivateContext sac(parent_); + std::unique_ptr activation = parent_->Activate(); cufftDestroy(plan_); } @@ -309,9 +307,9 @@ int CUDAFftPlan::GetFftDirection() const { } std::unique_ptr CUDAFft::CreateBatchedPlanWithScratchAllocator( - Stream *stream, int rank, uint64_t *elem_count, uint64 *input_embed, - uint64_t input_stride, uint64 input_distance, uint64 *output_embed, - uint64_t output_stride, uint64 output_distance, fft::Type type, + Stream *stream, int rank, uint64_t *elem_count, uint64_t *input_embed, + uint64_t input_stride, uint64_t input_distance, uint64_t *output_embed, + uint64_t output_stride, uint64_t output_distance, fft::Type type, bool in_place_fft, int batch_count, ScratchAllocator *scratch_allocator) { std::unique_ptr fft_plan_ptr{new CUDAFftPlan()}; absl::Status status = fft_plan_ptr->Initialize( @@ -389,7 +387,7 @@ bool CUDAFft::DoFftInternal(Stream *stream, fft::Plan *plan, FuncT cufftExec, } #endif - ScopedActivateContext sac(parent_); + std::unique_ptr activation = parent_->Activate(); auto ret = cufftExec(cuda_fft_plan->GetPlan(), CUDAComplex(const_cast(GpuMemory(input_maybe_copy))), @@ -418,7 +416,7 @@ bool CUDAFft::DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan, return false; } - ScopedActivateContext sac(parent_); + std::unique_ptr activation = parent_->Activate(); auto ret = cufftExec(cuda_fft_plan->GetPlan(), CUDAComplex(const_cast(GpuMemory(input))), CUDAComplex(GpuMemoryMutable(output)), @@ -463,15 +461,7 @@ void initialize_cufft() { PluginRegistry::Instance()->RegisterFactory( cuda::kCudaPlatformId, "cuFFT", [](StreamExecutor *parent) -> fft::FftSupport * { - gpu::GpuExecutor *cuda_executor = - dynamic_cast(parent); - if (cuda_executor == nullptr) { - LOG(ERROR) << "Attempting to initialize an instance of the cuFFT " - << "support library with a non-CUDA StreamExecutor"; - return nullptr; - } - - return new gpu::CUDAFft(cuda_executor); + return new gpu::CUDAFft(parent); }); if (!status.ok()) { LOG(ERROR) << "Unable to register cuFFT factory: " << status.message(); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_fft.h b/third_party/xla/xla/stream_executor/cuda/cuda_fft.h index 56fc78fb360219..342f9b9c72e4dd 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_fft.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_fft.h @@ -27,17 +27,13 @@ limitations under the License. #include "absl/status/status.h" #include "third_party/gpus/cuda/include/cufft.h" #include "xla/stream_executor/fft.h" -#include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/scratch_allocator.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" namespace stream_executor { - -class Stream; - namespace gpu { -class GpuExecutor; - // CUDAFftPlan uses deferred initialization. Only a single call of // Initialize() is allowed to properly create cufft plan and set member // variable is_initialized_ to true. Newly added interface that uses member @@ -66,9 +62,9 @@ class CUDAFftPlan : public fft::Plan { } // Initialize function for batched plan - absl::Status Initialize(GpuExecutor* parent, Stream* stream, int rank, + absl::Status Initialize(StreamExecutor* parent, Stream* stream, int rank, uint64_t* elem_count, uint64_t* input_embed, - uint64_t input_stride, uint64 input_distance, + uint64_t input_stride, uint64_t input_distance, uint64_t* output_embed, uint64_t output_stride, uint64_t output_distance, fft::Type type, int batch_count, ScratchAllocator* scratch_allocator); @@ -82,7 +78,7 @@ class CUDAFftPlan : public fft::Plan { bool IsInitialized() const { return is_initialized_; } private: - GpuExecutor* parent_; + StreamExecutor* parent_; cufftHandle plan_; fft::Type fft_type_; DeviceMemory scratch_; @@ -96,7 +92,7 @@ class CUDAFftPlan : public fft::Plan { // This satisfies the platform-agnostic FftSupport interface. // // Note that the cuFFT handle that this encapsulates is implicitly tied to the -// context (and, as a result, the device) that the parent GpuExecutor is tied +// context (and, as a result, the device) that the parent StreamExecutor is tied // to. This simply happens as an artifact of creating the cuFFT handle when a // CUDA context is active. // @@ -104,29 +100,29 @@ class CUDAFftPlan : public fft::Plan { // context of parent_, so all context is explicit. class CUDAFft : public fft::FftSupport { public: - explicit CUDAFft(GpuExecutor* parent) : parent_(parent) {} + explicit CUDAFft(StreamExecutor* parent) : parent_(parent) {} ~CUDAFft() override {} TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES private: - GpuExecutor* parent_; + StreamExecutor* parent_; // Two helper functions that execute dynload::cufftExec?2?. // This is for complex to complex FFT, when the direction is required. template - bool DoFftWithDirectionInternal(Stream *stream, fft::Plan *plan, + bool DoFftWithDirectionInternal(Stream* stream, fft::Plan* plan, FuncT cufft_exec, - const DeviceMemory &input, - DeviceMemory *output); + const DeviceMemory& input, + DeviceMemory* output); // This is for complex to real or real to complex FFT, when the direction // is implied. template - bool DoFftInternal(Stream *stream, fft::Plan *plan, FuncT cufft_exec, - const DeviceMemory &input, - DeviceMemory *output); + bool DoFftInternal(Stream* stream, fft::Plan* plan, FuncT cufft_exec, + const DeviceMemory& input, + DeviceMemory* output); CUDAFft(const CUDAFft&) = delete; void operator=(const CUDAFft&) = delete; diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc b/third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc index c62a8f99ae3298..66d01bda9713a2 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_kernel.cc @@ -13,28 +13,66 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "xla/stream_executor/cuda/cuda_kernel.h" + #include #include +#include #include "absl/log/log.h" #include "absl/status/statusor.h" -#include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_kernel.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "xla/stream_executor/activate_context.h" +#include "xla/stream_executor/cuda/cuda_status.h" +#include "xla/stream_executor/kernel.h" #include "xla/stream_executor/launch_dim.h" +#include "tsl/platform/errors.h" namespace stream_executor { namespace gpu { -absl::StatusOr GpuKernel::GetMaxOccupiedBlocksPerCore( +namespace { + +absl::Status GetCudaAttribute(CUfunction_attribute attribute, CUfunction func, + int* attribute_value) { + return cuda::ToStatus( + cuFuncGetAttribute(attribute_value, attribute, func), + absl::StrCat("Failed to query kernel attribute: ", attribute)); +} + +} // namespace + +absl::StatusOr CudaKernel::GetMaxOccupiedBlocksPerCore( ThreadDim threads, size_t dynamic_shared_memory_bytes) const { int32_t threads_per_block = threads.x * threads.y * threads.z; VLOG(3) << "Get kernel block occupancy: " << name() << "; threads_per_block: " << threads_per_block << "; dynamic_shared_memory_bytes: " << dynamic_shared_memory_bytes; + std::unique_ptr activation = executor_->Activate(); + + int max_blocks; + TF_RETURN_IF_ERROR(cuda::ToStatus( + cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags( + &max_blocks, gpu_function_, threads_per_block, + dynamic_shared_memory_bytes, CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE), + absl::StrFormat("Failed to calculate occupancy of kernel %p", + gpu_function_))); + return max_blocks; +} + +absl::StatusOr CudaKernel::GetKernelMetadata() { + KernelMetadata kernel_metadata; + int value; + TF_RETURN_IF_ERROR( + GetCudaAttribute(CU_FUNC_ATTRIBUTE_NUM_REGS, gpu_function_, &value)); + kernel_metadata.set_registers_per_thread(value); - return GpuDriver::GetMaxOccupiedBlocksPerCore(gpu_context_, gpu_function_, - threads_per_block, - dynamic_shared_memory_bytes); + TF_RETURN_IF_ERROR(GetCudaAttribute(CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, + gpu_function_, &value)); + kernel_metadata.set_shared_memory_bytes(value); + return kernel_metadata; } } // namespace gpu diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_kernel.h b/third_party/xla/xla/stream_executor/cuda/cuda_kernel.h new file mode 100644 index 00000000000000..b317e61c7a7cd7 --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/cuda_kernel.h @@ -0,0 +1,71 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The CUDA implementation of the StreamExecutor functionality. +// CUDA inclusions are ideally confined to this implementation file. +// +// The notions from the StreamExecutor basically correspond to the CUDA streams +// programming model provided by the libcuda.so driver APIs, so we don't have +// to do much more than wrap the calls to the libraries appropriately. +#ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_KERNEL_H_ +#define XLA_STREAM_EXECUTOR_CUDA_CUDA_KERNEL_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "xla/stream_executor/gpu/gpu_kernel.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/logging.h" + +namespace stream_executor::gpu { + +class CudaKernel : public GpuKernel { + public: + explicit CudaKernel(StreamExecutor* executor) : executor_(executor) {} + + // Note that the function is unloaded when the module is unloaded, and the + // module that the function is contained in is owned by the StreamExecutor. + ~CudaKernel() override { executor_->UnloadKernel(this); } + + // As arity cannot be reflected upon using the CUDA API, the arity is + // explicitly set during the StreamExecutor::GetKernel initialization process. + void set_arity(unsigned arity) { arity_ = arity; } + unsigned Arity() const override { return arity_; } + + absl::StatusOr GetMaxOccupiedBlocksPerCore( + ThreadDim threads, size_t dynamic_shared_memory_bytes) const override; + + // Simple accessor methods. + CUfunction gpu_function() const override { return gpu_function_; } + void set_gpu_function(CUfunction gpu_function) { + gpu_function_ = gpu_function; + } + + // Collects metadata for the specified kernel. + absl::StatusOr GetKernelMetadata(); + + private: + StreamExecutor* executor_ = nullptr; + + CUfunction gpu_function_ = nullptr; // wrapped CUDA kernel handle + unsigned arity_ = 0; // number of formal parameters the kernel takes +}; + +} // namespace stream_executor::gpu + +#endif // XLA_STREAM_EXECUTOR_CUDA_CUDA_KERNEL_H_ diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_kernel_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_kernel_test.cc new file mode 100644 index 00000000000000..a1ccd78a094a42 --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/cuda_kernel_test.cc @@ -0,0 +1,58 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/cuda/cuda_kernel.h" + +#include +#include "third_party/gpus/cuda/include/cuda.h" +#include "xla/stream_executor/cuda/cuda_runtime.h" +#include "xla/stream_executor/gpu/gpu_test_kernels.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor::gpu { +namespace { +using testing::Ge; +using tsl::testing::IsOkAndHolds; + +TEST(CudaKernelTest, GetMaxOccupiedBlocksPerCore) { + TF_ASSERT_OK_AND_ASSIGN(Platform * platform, + PlatformManager::PlatformWithName("CUDA")); + TF_ASSERT_OK_AND_ASSIGN(StreamExecutor * executor, + platform->ExecutorForDevice(0)); + + CudaKernel cuda_kernel(executor); + cuda_kernel.set_arity(3); + + TF_ASSERT_OK_AND_ASSIGN( + CUfunction function, + CudaRuntime::GetFuncBySymbol(internal::GetAddI32Kernel())); + + cuda_kernel.set_gpu_function(function); + + EXPECT_EQ(cuda_kernel.Arity(), 3); + EXPECT_EQ(cuda_kernel.gpu_function(), function); + + EXPECT_THAT(cuda_kernel.GetMaxOccupiedBlocksPerCore( + ThreadDim(1, 1, 1), /*dynamic_shared_memory_bytes=*/0), + IsOkAndHolds(Ge(1))); +} + +} // namespace +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc b/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc index 42c2808dc1f23e..4e0ceca86c390d 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc @@ -23,9 +23,12 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/cuda/cuda_executor.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/cuda/cuda_status.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/gpu/gpu_diagnostics.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform/initialize.h" @@ -35,17 +38,42 @@ limitations under the License. namespace stream_executor { namespace gpu { +namespace { -CudaPlatform::CudaPlatform() : name_("CUDA") {} +// Actually performs the work of CUDA initialization. Wrapped up in one-time +// execution guard. +static absl::Status InternalInit() { + absl::Status status = + cuda::ToStatus(cuInit(0 /* = flags */), "Failed call to cuInit"); + if (status.ok()) { + return status; + } + + LOG(ERROR) << "failed call to cuInit: " << status; -CudaPlatform::~CudaPlatform() {} + Diagnostician::LogDiagnosticInformation(); + return status; +} + +static absl::Status PlatformInitialize() { + // Cached return value from calling InternalInit(), as cuInit need only be + // called once, but PlatformInitialize may be called many times. + static absl::Status* initialization_status = [] { + return new absl::Status(InternalInit()); + }(); + return *initialization_status; +} + +} // namespace + +CudaPlatform::CudaPlatform() : name_("CUDA") {} Platform::Id CudaPlatform::id() const { return cuda::kCudaPlatformId; } int CudaPlatform::VisibleDeviceCount() const { // Initialized in a thread-safe manner the first time this is run. static const int num_devices = [] { - if (!GpuDriver::Init().ok()) return -1; + if (!PlatformInitialize().ok()) return -1; return GpuDriver::GetDeviceCount(); }(); return num_devices; @@ -55,10 +83,12 @@ const std::string& CudaPlatform::Name() const { return name_; } absl::StatusOr> CudaPlatform::DescriptionForDevice(int ordinal) const { - return GpuExecutor::CreateDeviceDescription(ordinal); + TF_RETURN_IF_ERROR(PlatformInitialize()); + return CudaExecutor::CreateDeviceDescription(ordinal); } absl::StatusOr CudaPlatform::ExecutorForDevice(int ordinal) { + TF_RETURN_IF_ERROR(PlatformInitialize()); return executor_cache_.GetOrCreate( ordinal, [this, ordinal]() { return GetUncachedExecutor(ordinal); }); } diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_platform.h b/third_party/xla/xla/stream_executor/cuda/cuda_platform.h index e4ba806343f091..b03e90f08d8f27 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_platform.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_platform.h @@ -38,7 +38,6 @@ namespace gpu { class CudaPlatform : public Platform { public: CudaPlatform(); - ~CudaPlatform() override; // Platform interface implementation: // Returns the same value as kCudaPlatform above. @@ -55,13 +54,13 @@ class CudaPlatform : public Platform { absl::StatusOr ExecutorForDevice(int ordinal) override; absl::StatusOr FindExisting(int ordinal) override; + private: // Returns a device constructed with the ordinal without // looking in or storing to the Platform's executor cache. // Ownership IS transferred to the caller. absl::StatusOr> GetUncachedExecutor( int ordinal); - private: // This platform's name. std::string name_; diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_runtime.cc b/third_party/xla/xla/stream_executor/cuda/cuda_runtime.cc index bf355cf9b7b1da..c9ced05c4b91e0 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_runtime.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_runtime.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "xla/stream_executor/cuda/cuda_runtime.h" + #include #include "absl/base/optimization.h" @@ -23,8 +25,6 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "third_party/gpus/cuda/include/driver_types.h" -#include "xla/stream_executor/gpu/gpu_runtime.h" -#include "xla/stream_executor/gpu/gpu_types.h" #include "tsl/platform/logging.h" namespace stream_executor::gpu { @@ -42,7 +42,7 @@ static const char* ToString(cudaError_t error) { } \ } while (0) -absl::StatusOr GpuRuntime::GetFuncBySymbol(void* symbol) { +absl::StatusOr CudaRuntime::GetFuncBySymbol(void* symbol) { VLOG(2) << "Get CUDA function from a symbol: " << symbol; cudaFunction_t func; RETURN_IF_CUDA_RES_ERROR(cudaGetFuncBySymbol(&func, symbol), @@ -50,7 +50,7 @@ absl::StatusOr GpuRuntime::GetFuncBySymbol(void* symbol) { return reinterpret_cast(func); } -absl::StatusOr GpuRuntime::GetRuntimeVersion() { +absl::StatusOr CudaRuntime::GetRuntimeVersion() { VLOG(2) << "Get CUDA runtime version"; int32_t version; RETURN_IF_CUDA_RES_ERROR(cudaRuntimeGetVersion(&version), diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_runtime.h b/third_party/xla/xla/stream_executor/cuda/cuda_runtime.h similarity index 86% rename from third_party/xla/xla/stream_executor/gpu/gpu_runtime.h rename to third_party/xla/xla/stream_executor/cuda/cuda_runtime.h index 6f36c7ceab1ea1..32ebd5cf8611a5 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_runtime.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_runtime.h @@ -15,13 +15,13 @@ limitations under the License. // CUDA/ROCm runtime library wrapper functionality. -#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_RUNTIME_H_ -#define XLA_STREAM_EXECUTOR_GPU_GPU_RUNTIME_H_ +#ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_RUNTIME_H_ +#define XLA_STREAM_EXECUTOR_CUDA_CUDA_RUNTIME_H_ #include #include "absl/status/statusor.h" -#include "xla/stream_executor/gpu/gpu_types.h" +#include "third_party/gpus/cuda/include/cuda.h" namespace stream_executor::gpu { @@ -39,10 +39,10 @@ namespace stream_executor::gpu { // //===----------------------------------------------------------------------===// -// Gpu runtime returns types defined in the stream_executor::gpu namespace, and +// Cuda runtime returns types defined in the stream_executor::gpu namespace, and // they usually correspond to the driver types, as driver API is the primary // integration API of Gpus into StreamExecutor. -class GpuRuntime { +class CudaRuntime { public: // Get pointer to device entry function that matches entry function `symbol`. // @@ -52,7 +52,7 @@ class GpuRuntime { // current device (and create it if it doesn't exist yet). // // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DRIVER.html#group__CUDART__DRIVER_1gaba6f8d01e745f0c8d8776ceb18be617 - static absl::StatusOr GetFuncBySymbol(void* symbol); + static absl::StatusOr GetFuncBySymbol(void* symbol); // Returns the Gpu Runtime version. // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION_1g0e3952c7802fd730432180f1f4a6cdc6 @@ -61,4 +61,4 @@ class GpuRuntime { } // namespace stream_executor::gpu -#endif // XLA_STREAM_EXECUTOR_GPU_GPU_RUNTIME_H_ +#endif // XLA_STREAM_EXECUTOR_CUDA_CUDA_RUNTIME_H_ diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_stream.cc b/third_party/xla/xla/stream_executor/cuda/cuda_stream.cc new file mode 100644 index 00000000000000..469c19a8b60b58 --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/cuda_stream.cc @@ -0,0 +1,501 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/cuda/cuda_stream.h" + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/functional/any_invocable.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "xla/stream_executor/activate_context.h" +#include "xla/stream_executor/cuda/cuda_context.h" +#include "xla/stream_executor/cuda/cuda_event.h" +#include "xla/stream_executor/cuda/cuda_kernel.h" +#include "xla/stream_executor/cuda/cuda_status.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_common.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/nvtx_utils.h" + +namespace stream_executor { +namespace gpu { + +namespace { +absl::Status WaitStreamOnEvent(StreamExecutor* executor, CUstream stream, + CUevent event) { + std::unique_ptr activation = executor->Activate(); + return cuda::ToStatus(cuStreamWaitEvent(stream, event, 0 /* = flags */)); +} + +absl::Status RecordGpuEvent(StreamExecutor* executor, CUevent event, + CUstream stream) { + std::unique_ptr activation = executor->Activate(); + return cuda::ToStatus(cuEventRecord(event, stream), + "Error recording CUDA event"); +} + +int GetGpuStreamPriority(stream_executor::StreamPriority stream_priority) { + if (stream_priority == stream_executor::StreamPriority::Default) { + return 0; + } + int lowest, highest; + auto status = cuda::ToStatus(cuCtxGetStreamPriorityRange(&lowest, &highest)); + if (!status.ok()) { + LOG(ERROR) + << "Could not query stream priority range. Returning default priority."; + return 0; + } + return stream_priority == stream_executor::StreamPriority::Highest ? highest + : lowest; +} + +absl::StatusOr CreateStream(StreamExecutor* executor, int priority) { + std::unique_ptr activation = executor->Activate(); + CUstream stream; + // If the priority is 0, then use the previous api to create the stream with + // the default priority for backward compatibility. Probably there is no + // difference in using the new api call but leaving it as is for now. + if (priority == 0) { + TF_RETURN_IF_ERROR( + cuda::ToStatus(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING))); + } else { + TF_RETURN_IF_ERROR(cuda::ToStatus( + cuStreamCreateWithPriority(&stream, CU_STREAM_NON_BLOCKING, priority))); + } + + VLOG(2) << "successfully created stream " << stream << " for executor " + << executor << " on thread"; + return stream; +} + +absl::StatusOr StreamIsCapturing(CUstream stream) { + VLOG(2) << "Checking if stream " << stream << " is capturing"; + + CUstreamCaptureStatus status; + TF_RETURN_IF_ERROR(cuda::ToStatus(cuStreamIsCapturing(stream, &status), + "Failed to check stream capturing status")); + + return status == CU_STREAM_CAPTURE_STATUS_ACTIVE; +} + +absl::Status AsynchronousMemcpyD2H(StreamExecutor* executor, void* host_dst, + CUdeviceptr gpu_src, uint64_t size, + CUstream stream) { + std::unique_ptr activation = executor->Activate(); + + TF_RETURN_IF_ERROR( + cuda::ToStatus(cuMemcpyDtoHAsync(host_dst, gpu_src, size, stream))); + + VLOG(2) << "successfully enqueued async memcpy d2h of " << size + << " bytes from " << absl::bit_cast(gpu_src) << " to " + << host_dst << " on stream " << stream; + return absl::OkStatus(); +} + +absl::Status AsynchronousMemcpyH2D(StreamExecutor* executor, + CUdeviceptr gpu_dst, const void* host_src, + uint64_t size, CUstream stream) { + std::unique_ptr activation = executor->Activate(); + TF_RETURN_IF_ERROR( + cuda::ToStatus(cuMemcpyHtoDAsync(gpu_dst, host_src, size, stream))); + + VLOG(2) << "successfully enqueued async memcpy h2d of " << size << " bytes" + << " from " << host_src << " to " << absl::bit_cast(gpu_dst) + << " on stream " << stream; + return absl::OkStatus(); +} + +absl::Status AsynchronousMemcpyD2D(StreamExecutor* executor, + CUdeviceptr gpu_dst, CUdeviceptr gpu_src, + uint64_t size, CUstream stream) { + std::unique_ptr activation = executor->Activate(); + + // In graph capture mode we never have operations that access peer memory, so + // we can always make a call to cuMemcpyDtoDAsync. + TF_ASSIGN_OR_RETURN(bool is_capturing, StreamIsCapturing(stream)); + + if ((gpu_dst == 0 || gpu_src == 0) || is_capturing) { + // GetContextMap()->GetAnyContext() doesn't work when ptr == 0. + // This happens when the size is 0. + TF_RETURN_IF_ERROR( + cuda::ToStatus(cuMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream))); + } else { + // Any context work here. + CUcontext dst_context = CudaContext::GetContextMap()->GetAnyContext( + absl::bit_cast(gpu_dst)); + CUcontext src_context = CudaContext::GetContextMap()->GetAnyContext( + absl::bit_cast(gpu_src)); + + if (dst_context == src_context) { + // Since the CUDA context is the same, the src and dst are within the same + // GPU. So we can use cuMemcpyDtoD. + TF_RETURN_IF_ERROR( + cuda::ToStatus(cuMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream))); + } else { + TF_RETURN_IF_ERROR(cuda::ToStatus(cuMemcpyPeerAsync( + gpu_dst, dst_context, gpu_src, src_context, size, stream))); + } + } + + VLOG(2) << "successfully enqueued async memcpy d2d of " << size << " bytes" + << " from " << absl::bit_cast(gpu_src) << " to " + << absl::bit_cast(gpu_dst) << " on stream " << stream; + return absl::OkStatus(); +} + +} // namespace + +absl::StatusOr> CudaStream::Create( + StreamExecutor* executor, + std::optional> priority) { + int stream_priority = [&]() { + if (priority.has_value() && std::holds_alternative(priority.value())) { + return std::get(priority.value()); + } + std::unique_ptr activation = executor->Activate(); + return GetGpuStreamPriority( + std::get(priority.value_or(StreamPriority::Default))); + }(); + TF_ASSIGN_OR_RETURN(auto stream_handle, + CreateStream(executor, stream_priority)); + + TF_ASSIGN_OR_RETURN(auto completed_event, + CudaEvent::Create(executor, + /*allow_timing=*/false)); + + return std::unique_ptr(new CudaStream( + executor, std::move(completed_event), priority, stream_handle)); +} + +absl::Status CudaStream::WaitFor(Stream* other) { + CudaStream* other_stream = static_cast(other); + + TF_RETURN_IF_ERROR(other_stream->RecordCompletedEvent()); + return WaitStreamOnEvent(executor_, stream_handle_, + other_stream->completed_event_.GetHandle()); +} + +absl::Status CudaStream::RecordEvent(Event* event) { + return RecordGpuEvent(executor_, static_cast(event)->GetHandle(), + stream_handle_); +} + +absl::Status CudaStream::WaitFor(Event* event) { + return WaitStreamOnEvent(executor_, stream_handle_, + static_cast(event)->GetHandle()); +} + +absl::Status CudaStream::RecordCompletedEvent() { + return RecordEvent(&completed_event_); +} + +namespace { +void DestroyStream(StreamExecutor* executor, CUstream stream) { + if (stream == nullptr) { + return; + } + + std::unique_ptr activation = executor->Activate(); + CUresult res = cuStreamQuery(stream); + if (res != CUDA_SUCCESS) { + LOG(ERROR) << "stream not idle on destroy: " << cuda::ToStatus(res); + } + + auto status = cuda::ToStatus(cuStreamDestroy(stream)); + if (!status.ok()) { + LOG(ERROR) << "failed to destroy CUDA stream for executor " << executor + << ": " << status; + } else { + VLOG(2) << "successfully destroyed stream " << stream << " for executor " + << executor; + } +} + +absl::Status SynchronizeStream(StreamExecutor* executor, CUstream stream) { + std::unique_ptr activation = executor->Activate(); + CHECK(stream != nullptr); + return cuda::ToStatus(cuStreamSynchronize(stream), + "Could not synchronize CUDA stream"); +} +} // namespace + +CudaStream::~CudaStream() { + BlockHostUntilDone().IgnoreError(); + executor_->DeallocateStream(this); + + DestroyStream(executor_, stream_handle_); +} + +absl::Status CudaStream::BlockHostUntilDone() { + return SynchronizeStream(executor_, stream_handle_); +} + +absl::Status CudaStream::Memset32(DeviceMemoryBase* location, uint32_t pattern, + uint64_t size) { + if (absl::bit_cast(location->opaque()) % alignof(uint32_t) != 0) { + return absl::InvalidArgumentError("location must be 4 byte aligned."); + } + if (size % sizeof(uint32_t) != 0) { + return absl::InvalidArgumentError("size must be a multiple of 4 bytes."); + } + std::unique_ptr activation = executor_->Activate(); + return cuda::ToStatus( + cuMemsetD32Async(absl::bit_cast(location->opaque()), pattern, + size / 4, stream_handle_), + "Failed to enqueue async memset operation"); +} + +absl::Status CudaStream::MemZero(DeviceMemoryBase* location, uint64_t size) { + if (reinterpret_cast(location->opaque()) % alignof(uint32_t) == + 0 && + size % sizeof(uint32_t) == 0) { + return Memset32(location, 0x0, size); + } else { + std::unique_ptr activation = executor_->Activate(); + return cuda::ToStatus( + cuMemsetD8Async(absl::bit_cast(location->opaque()), 0x0, + size, stream_handle_), + "Failed to enqueue async memset operation"); + } +} + +absl::Status CudaStream::Memcpy(DeviceMemoryBase* gpu_dst, + const DeviceMemoryBase& gpu_src, + uint64_t size) { + return AsynchronousMemcpyD2D( + executor_, absl::bit_cast(gpu_dst->opaque()), + absl::bit_cast(gpu_src.opaque()), size, stream_handle_); +} + +absl::Status CudaStream::Memcpy(DeviceMemoryBase* gpu_dst, const void* host_src, + uint64_t size) { + return AsynchronousMemcpyH2D(executor_, + absl::bit_cast(gpu_dst->opaque()), + host_src, size, stream_handle_); +} + +absl::Status CudaStream::Memcpy(void* host_dst, const DeviceMemoryBase& gpu_src, + uint64_t size) { + return AsynchronousMemcpyD2H(executor_, host_dst, + absl::bit_cast(gpu_src.opaque()), + size, stream_handle_); +} + +namespace { +void InternalHostCallback(void* data) { + auto* callback = reinterpret_cast*>(data); + std::move (*callback)(); + delete callback; +} +} // namespace + +absl::Status CudaStream::DoHostCallbackWithStatus( + absl::AnyInvocable callback) { + auto callback_ptr = + new absl::AnyInvocable([cb = std::move(callback)]() mutable { + absl::Status s = (std::move(cb))(); + if (!s.ok()) { + LOG(WARNING) << "Host callback failed: " << s; + } + }); + return cuda::ToStatus( + cuLaunchHostFunc(stream_handle_, InternalHostCallback, callback_ptr)); +} + +namespace { +absl::Status LaunchKernel(StreamExecutor* executor, + absl::string_view kernel_name, CUfunction function, + unsigned int grid_dim_x, unsigned int grid_dim_y, + unsigned int grid_dim_z, unsigned int block_dim_x, + unsigned int block_dim_y, unsigned int block_dim_z, + unsigned int shared_mem_bytes, CUstream stream, + void** kernel_params, void** extra) { + std::unique_ptr activation = executor->Activate(); + VLOG(2) << "launching kernel: " << kernel_name << "; gdx: " << grid_dim_x + << " gdy: " << grid_dim_y << " gdz: " << grid_dim_z + << " bdx: " << block_dim_x << " bdy: " << block_dim_y + << " bdz: " << block_dim_z + << "; shared_mem_bytes: " << shared_mem_bytes; + + // TODO(ezhulenev): Why do we do it on every call to launch kernel? This + // should be moved one level up to se::Kernel level, and done just once (or + // updated once we get a new larger shared memory request). + if (shared_mem_bytes != 0) { + TF_RETURN_IF_ERROR(cuda::ToStatus( + cuFuncSetAttribute(function, + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + shared_mem_bytes), + "Failed to set shared memory size")); + } + + return cuda::ToStatus( + cuLaunchKernel(function, grid_dim_x, grid_dim_y, grid_dim_z, block_dim_x, + block_dim_y, block_dim_z, shared_mem_bytes, stream, + kernel_params, extra), + absl::StrCat("Failed to launch CUDA kernel: ", kernel_name, + "; block dims: ", block_dim_x, "x", block_dim_y, "x", + block_dim_z, "; grid dims: ", grid_dim_x, "x", grid_dim_y, + "x", grid_dim_z, + "; shared memory size: ", shared_mem_bytes)); +} + +absl::Status LaunchKernel(StreamExecutor* executor, + absl::string_view kernel_name, CUfunction function, + unsigned int cluster_dim_x, + unsigned int cluster_dim_y, + unsigned int cluster_dim_z, unsigned int grid_dim_x, + unsigned int grid_dim_y, unsigned int grid_dim_z, + unsigned int block_dim_x, unsigned int block_dim_y, + unsigned int block_dim_z, + unsigned int shared_mem_bytes, CUstream stream, + void** kernel_params, void** extra) { + std::unique_ptr activation = executor->Activate(); + VLOG(2) << "launching kernel: " << kernel_name << "; cdx: " << cluster_dim_x + << " cdy: " << cluster_dim_y << " cdz: " << cluster_dim_z + << " gdx: " << grid_dim_x << " gdy: " << grid_dim_y + << " gdz: " << grid_dim_z << " bdx: " << block_dim_x + << " bdy: " << block_dim_y << " bdz: " << block_dim_z + << "; shared_mem_bytes: " << shared_mem_bytes; + + // TODO(ezhulenev): Why do we do it on every call to launch kernel? This + // should be moved one level up to se::Kernel level, and done just once (or + // updated once we get a new larger shared memory request). + if (shared_mem_bytes != 0) { + TF_RETURN_IF_ERROR(cuda::ToStatus( + cuFuncSetAttribute(function, + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + shared_mem_bytes), + "Failed to set shared memory size")); + } + + CUlaunchConfig launch_config; + memset(&launch_config, 0, sizeof(launch_config)); + launch_config.blockDimX = block_dim_x; + launch_config.blockDimY = block_dim_y; + launch_config.blockDimZ = block_dim_z; + launch_config.gridDimX = grid_dim_x; + launch_config.gridDimY = grid_dim_y; + launch_config.gridDimZ = grid_dim_z; + launch_config.hStream = stream; + launch_config.sharedMemBytes = shared_mem_bytes; + + CUlaunchAttribute cluster_dims; + memset(&cluster_dims, 0, sizeof(cluster_dims)); + cluster_dims.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + cluster_dims.value.clusterDim.x = cluster_dim_x; + cluster_dims.value.clusterDim.y = cluster_dim_y; + cluster_dims.value.clusterDim.z = cluster_dim_z; + + launch_config.attrs = &cluster_dims; + launch_config.numAttrs = 1; + + return cuda::ToStatus( + cuLaunchKernelEx(&launch_config, function, kernel_params, extra), + absl::StrCat("Failed to launch CUDA kernel: ", kernel_name, + "; cluster dims: ", cluster_dim_x, "x", cluster_dim_y, "x", + cluster_dim_z, "; block dims: ", block_dim_x, "x", + block_dim_y, "x", block_dim_z, "; grid dims: ", grid_dim_x, + "x", grid_dim_y, "x", grid_dim_z, + "; shared memory size: ", shared_mem_bytes)); +} + +} // namespace + +absl::Status CudaStream::Launch(const ThreadDim& thread_dims, + const BlockDim& block_dims, + const std::optional& cluster_dims, + const Kernel& kernel, const KernelArgs& args) { + const CudaKernel* gpu_kernel = static_cast(&kernel); + CUfunction function = gpu_kernel->gpu_function(); + + // Launch kernels with packed arguments. + auto launch = [this, &kernel, &cluster_dims, &thread_dims, &block_dims, + &function](const KernelArgsPackedArrayBase& packed) { + int32_t expected_number_of_arguments = + kernel.Arity() + (packed.number_of_shared_bytes() > 0); + + CHECK_EQ(expected_number_of_arguments, packed.number_of_arguments()) + << "Kernel " << kernel.name() << " has " << packed.number_of_arguments() + << " arguments, but expected " << expected_number_of_arguments + << "; arity=" << kernel.Arity() + << "; number_of_shared_bytes=" << packed.number_of_shared_bytes(); + + void** params = const_cast(packed.argument_addresses().data()); + + if (cluster_dims.has_value()) { + return LaunchKernel( + executor_, kernel.name(), function, cluster_dims->x, cluster_dims->y, + cluster_dims->z, block_dims.x, block_dims.y, block_dims.z, + thread_dims.x, thread_dims.y, thread_dims.z, + packed.number_of_shared_bytes(), stream_handle_, params, + /*extra=*/nullptr); + } else { + return LaunchKernel( + executor_, kernel.name(), function, block_dims.x, block_dims.y, + block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z, + packed.number_of_shared_bytes(), stream_handle_, params, + /*extra=*/nullptr); + } + }; + + // If arguments are already packed we can just launch the kernel. + if (auto* packed = DynCast(&args)) { + return launch(*packed); + } + + // For device memory array we rely on a custom kernel arguments packing. + if (auto* device_mem = DynCast(&args)) { + auto& pack = kernel.args_packing(); + if (!pack) { + return absl::InternalError( + "Kernel is missing a custom arguments packing function for device " + "memory arguments array"); + } + + TF_ASSIGN_OR_RETURN(auto packed, pack(kernel, *device_mem)); + return launch(*packed); + } + + return absl::InternalError("Unsupported kernel arguments type"); +} + +void CudaStream::SetName(std::string name) { + tsl::profiler::NameStream( + absl::bit_cast(stream_handle_), name); + StreamCommon::SetName(std::move(name)); +} + +} // namespace gpu +} // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_stream.h b/third_party/xla/xla/stream_executor/cuda/cuda_stream.h new file mode 100644 index 00000000000000..7d8be77df9366c --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/cuda_stream.h @@ -0,0 +1,104 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_STREAM_H_ +#define XLA_STREAM_EXECUTOR_CUDA_CUDA_STREAM_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "xla/stream_executor/cuda/cuda_event.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/event_based_timer.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_common.h" + +namespace stream_executor { +namespace gpu { + +class CudaStream : public StreamCommon { + public: + absl::Status WaitFor(Stream* other) override; + absl::Status RecordEvent(Event* event) override; + absl::Status WaitFor(Event* event) override; + + absl::Status Memset32(DeviceMemoryBase* location, uint32_t pattern, + uint64_t size) override; + absl::Status MemZero(DeviceMemoryBase* location, uint64_t size) override; + absl::Status Memcpy(DeviceMemoryBase* gpu_dst, const void* host_src, + uint64_t size) override; + absl::Status Memcpy(void* host_dst, const DeviceMemoryBase& gpu_src, + uint64_t size) override; + absl::Status Memcpy(DeviceMemoryBase* gpu_dst, + const DeviceMemoryBase& gpu_src, uint64_t size) override; + absl::Status DoHostCallbackWithStatus( + absl::AnyInvocable callback) override; + absl::Status BlockHostUntilDone() override; + + void SetName(std::string name) override; + + Stream::PlatformSpecificHandle platform_specific_handle() const override { + return {stream_handle_}; + } + + absl::StatusOr> CreateEventBasedTimer( + bool use_delay_kernel) override { + return executor_->CreateEventBasedTimer(this, use_delay_kernel); + } + + static absl::StatusOr> Create( + StreamExecutor* executor, + std::optional> priority); + + ~CudaStream() override; + + CUstream stream_handle() const { return stream_handle_; } + + private: + CudaStream(StreamExecutor* executor, CudaEvent completed_event, + std::optional> priority, + CUstream stream_handle) + : StreamCommon(executor, priority), + executor_(executor), + completed_event_(std::move(completed_event)), + stream_handle_(stream_handle) {} + + absl::Status RecordCompletedEvent(); + + absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims, + const std::optional& cluster_dims, + const Kernel& kernel, const KernelArgs& args) override; + + StreamExecutor* executor_; + CudaEvent completed_event_; + CUstream stream_handle_; +}; +} // namespace gpu + +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_CUDA_CUDA_STREAM_H_ diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_stream_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_stream_test.cc new file mode 100644 index 00000000000000..2e905e990e648a --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/cuda_stream_test.cc @@ -0,0 +1,241 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/cuda/cuda_stream.h" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/stream_executor/cuda/cuda_executor.h" +#include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/gpu_test_kernels.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/typed_kernel_factory.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor { +namespace gpu { +namespace { + +using ::testing::Each; +using ::testing::ElementsAreArray; +using ::tsl::testing::IsOk; + +class CudaStreamTest : public ::testing::Test { + public: + CudaExecutor* executor_; + + private: + void SetUp() override { + TF_ASSERT_OK_AND_ASSIGN(Platform * platform, + stream_executor::PlatformManager::PlatformWithId( + stream_executor::cuda::kCudaPlatformId)); + TF_ASSERT_OK_AND_ASSIGN(StreamExecutor * executor, + platform->ExecutorForDevice(0)); + executor_ = reinterpret_cast(executor); + } +}; + +TEST_F(CudaStreamTest, Memset32) { + constexpr int kBufferNumElements = 42; + DeviceMemory buffer = + executor_->AllocateArray(kBufferNumElements, 0); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + CudaStream::Create(executor_, + /*priority=*/std::nullopt)); + + // Should fail due to the invalid size parameter. + EXPECT_THAT(stream->Memset32(&buffer, 0xDEADBEEF, + kBufferNumElements * sizeof(uint32_t) + 1), + ::tsl::testing::StatusIs(absl::StatusCode::kInvalidArgument)); + + // Should fail due to the non-4-byte-aligned pointer. + DeviceMemoryBase unaligned_pointer = + buffer.GetByteSlice(/*offset_bytes=*/1, /*size_bytes=*/0); + EXPECT_THAT(stream->Memset32(&unaligned_pointer, 0xDEADBEEF, + kBufferNumElements * sizeof(uint32_t) + 1), + ::tsl::testing::StatusIs(absl::StatusCode::kInvalidArgument)); + + // Correct call. Should succeed. + EXPECT_THAT(stream->Memset32(&buffer, 0xDEADBEEF, + kBufferNumElements * sizeof(uint32_t)), + IsOk()); + + std::array host_buffer; + EXPECT_THAT(stream->MemcpyD2H(buffer, absl::MakeSpan(host_buffer)), IsOk()); + + EXPECT_THAT(stream->BlockHostUntilDone(), IsOk()); + EXPECT_THAT(host_buffer, Each(0xDEADBEEF)); +} + +TEST_F(CudaStreamTest, MemZero) { + constexpr int kBufferNumElements = 42; + DeviceMemory buffer = + executor_->AllocateArray(kBufferNumElements, 0); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + CudaStream::Create(executor_, + /*priority=*/std::nullopt)); + + EXPECT_THAT(stream->Memset32(&buffer, 0xDEADBEEF, + kBufferNumElements * sizeof(uint32_t)), + IsOk()); + + // We overwrite half the buffer with zeros. + EXPECT_THAT( + stream->MemZero(&buffer, kBufferNumElements / 2 * sizeof(uint32_t)), + IsOk()); + + std::array host_buffer; + EXPECT_THAT(stream->MemcpyD2H(buffer, absl::MakeSpan(host_buffer)), IsOk()); + + EXPECT_THAT(stream->BlockHostUntilDone(), IsOk()); + // We expect the first half of the buffer to be zeros. + EXPECT_THAT( + absl::MakeConstSpan(host_buffer).subspan(0, kBufferNumElements / 2), + Each(0x0)); + + // And it shouldn't have touched the second half. + EXPECT_THAT(absl::MakeConstSpan(host_buffer).subspan(kBufferNumElements / 2), + Each(0xDEADBEEF)); +} + +TEST_F(CudaStreamTest, MemcpyHostToDeviceAndBack) { + constexpr int kBufferNumElements = 42; + DeviceMemory buffer = + executor_->AllocateArray(kBufferNumElements, 0); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + CudaStream::Create(executor_, + /*priority=*/std::nullopt)); + + std::array src_buffer; + std::generate(src_buffer.begin(), src_buffer.end(), + [i = 0]() mutable { return i++; }); + + EXPECT_THAT(stream->MemcpyH2D(absl::MakeConstSpan(src_buffer), &buffer), + IsOk()); + + std::array host_buffer; + EXPECT_THAT(stream->MemcpyD2H(buffer, absl::MakeSpan(host_buffer)), IsOk()); + + EXPECT_THAT(stream->BlockHostUntilDone(), IsOk()); + EXPECT_THAT(host_buffer, ElementsAreArray(src_buffer)); +} + +TEST_F(CudaStreamTest, MemcpyDeviceToDevice) { + constexpr int kBufferNumElements = 42; + DeviceMemory buffer1 = + executor_->AllocateArray(kBufferNumElements, 0); + DeviceMemory buffer2 = + executor_->AllocateArray(kBufferNumElements, 0); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + CudaStream::Create(executor_, + /*priority=*/std::nullopt)); + + EXPECT_THAT(stream->Memset32(&buffer1, 0xDEADBEEF, + kBufferNumElements * sizeof(uint32_t)), + IsOk()); + + EXPECT_THAT(stream->MemcpyD2D(&buffer2, buffer1, + kBufferNumElements * sizeof(uint32_t)), + IsOk()); + + std::array host_buffer; + EXPECT_THAT(stream->MemcpyD2H(buffer2, absl::MakeSpan(host_buffer)), IsOk()); + + EXPECT_THAT(stream->BlockHostUntilDone(), IsOk()); + EXPECT_THAT(host_buffer, Each(0xDEADBEEF)); +} + +TEST_F(CudaStreamTest, DoHostCallback) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + CudaStream::Create(executor_, + /*priority=*/std::nullopt)); + + int callback_call_counter = 0; + EXPECT_THAT(stream->DoHostCallback( + [&callback_call_counter]() { callback_call_counter++; }), + IsOk()); + + EXPECT_THAT(stream->BlockHostUntilDone(), IsOk()); + EXPECT_EQ(callback_call_counter, 1); +} + +TEST_F(CudaStreamTest, LaunchKernel) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + CudaStream::Create(executor_, + /*priority=*/std::nullopt)); + + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); + using AddI32Kernel = + TypedKernelFactory, DeviceMemory, + DeviceMemory>; + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor_, spec)); + + constexpr int64_t kLength = 4; + constexpr int64_t kByteLength = sizeof(int32_t) * kLength; + + // Prepare arguments: a=1, b=2, c=0 + DeviceMemory a = executor_->AllocateArray(kLength, 0); + DeviceMemory b = executor_->AllocateArray(kLength, 0); + DeviceMemory c = executor_->AllocateArray(kLength, 0); + + EXPECT_THAT(stream->Memset32(&a, 1, kByteLength), IsOk()); + EXPECT_THAT(stream->Memset32(&b, 2, kByteLength), IsOk()); + EXPECT_THAT(stream->MemZero(&c, kByteLength), IsOk()); + EXPECT_THAT(stream->ThenLaunch(ThreadDim(), BlockDim(kLength), add, a, b, c), + IsOk()); + + EXPECT_THAT(stream->BlockHostUntilDone(), IsOk()); + + std::array host_buffer; + EXPECT_THAT(stream->MemcpyD2H(c, absl::MakeSpan(host_buffer)), IsOk()); + EXPECT_THAT(host_buffer, Each(3)); +} + +TEST_F(CudaStreamTest, SetName) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + CudaStream::Create(executor_, + /*priority=*/std::nullopt)); + + constexpr absl::string_view kStreamName = "Test stream"; + stream->SetName(std::string(kStreamName)); + EXPECT_EQ(stream->GetName(), kStreamName); +} + +} // namespace +} // namespace gpu +} // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_timer.cc b/third_party/xla/xla/stream_executor/cuda/cuda_timer.cc new file mode 100644 index 00000000000000..b33fda0dc59317 --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/cuda_timer.cc @@ -0,0 +1,121 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/cuda/cuda_timer.h" + +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/time/time.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "xla/stream_executor/activate_context.h" +#include "xla/stream_executor/cuda/cuda_event.h" +#include "xla/stream_executor/cuda/cuda_status.h" +#include "xla/stream_executor/cuda/delay_kernel.h" +#include "xla/stream_executor/gpu/gpu_semaphore.h" +#include "xla/stream_executor/gpu/gpu_stream.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace stream_executor::gpu { + +namespace { +absl::StatusOr GetEventElapsedTime(StreamExecutor *executor, + CUevent start, CUevent stop) { + std::unique_ptr activation = executor->Activate(); + // The stop event must have completed in order for cuEventElapsedTime to + // work. + TF_RETURN_IF_ERROR(cuda::ToStatus(cuEventSynchronize(stop))); + + float elapsed_milliseconds; + + TF_RETURN_IF_ERROR( + cuda::ToStatus(cuEventElapsedTime(&elapsed_milliseconds, start, stop))); + + return elapsed_milliseconds; +} + +} // namespace + +CudaTimer::CudaTimer(StreamExecutor *executor, CudaEvent start_event, + CudaEvent stop_event, Stream *stream, + GpuSemaphore semaphore) + : semaphore_(std::move(semaphore)), + executor_(executor), + stream_(stream), + start_event_(std::move(start_event)), + stop_event_(std::move(stop_event)) {} + +CudaTimer::~CudaTimer() { + if (semaphore_ && !is_stopped_) { + // Signal the delay kernel that it can exit + *semaphore_ = GpuSemaphoreState::kRelease; + // Wait for the delay kernel to exit before destroying the value that it is + // watching. + absl::Status result = stream_->BlockHostUntilDone(); + if (!result.ok()) { + LOG(ERROR) << result.message(); + } + } +} + +absl::StatusOr CudaTimer::GetElapsedDuration() { + if (is_stopped_) { + return absl::FailedPreconditionError("Measuring inactive timer"); + } + TF_RETURN_IF_ERROR(stream_->RecordEvent(&stop_event_)); + // If we launched the delay kernel then check if it already timed out. + if (semaphore_) { + if (*semaphore_ == GpuSemaphoreState::kTimedOut) { + // The delay kernel did not achieve the intended result. + LOG(ERROR) << "Delay kernel timed out: measured time has sub-optimal " + "accuracy. There may be a missing warmup execution, please " + "investigate in Nsight Systems."; + } else { + // Signal that the kernel can exit + *semaphore_ = GpuSemaphoreState::kRelease; + } + } + TF_ASSIGN_OR_RETURN(float elapsed_milliseconds, + GetEventElapsedTime(executor_, start_event_.GetHandle(), + stop_event_.GetHandle())); + is_stopped_ = true; + return absl::Milliseconds(elapsed_milliseconds); +} + +absl::StatusOr CudaTimer::Create(StreamExecutor *executor, + Stream *stream, + TimerType timer_type) { + GpuSemaphore semaphore{}; + + if (timer_type == TimerType::kDelayKernel) { + TF_ASSIGN_OR_RETURN(semaphore, LaunchDelayKernel(stream)); + } + + TF_ASSIGN_OR_RETURN(CudaEvent start_event, + CudaEvent::Create(executor, /*allow_timing=*/true)); + TF_ASSIGN_OR_RETURN(CudaEvent stop_event, + CudaEvent::Create(executor, /*allow_timing=*/true)); + + TF_RETURN_IF_ERROR(stream->RecordEvent(&start_event)); + + return CudaTimer(executor, std::move(start_event), std::move(stop_event), + stream, std::move(semaphore)); +} + +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_timer.h b/third_party/xla/xla/stream_executor/cuda/cuda_timer.h new file mode 100644 index 00000000000000..2690c4b63cb434 --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/cuda_timer.h @@ -0,0 +1,59 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_CUDA_CUDA_TIMER_H_ +#define XLA_STREAM_EXECUTOR_CUDA_CUDA_TIMER_H_ + +#include "absl/status/statusor.h" +#include "absl/time/time.h" +#include "xla/stream_executor/cuda/cuda_event.h" +#include "xla/stream_executor/event_based_timer.h" +#include "xla/stream_executor/gpu/gpu_semaphore.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" + +namespace stream_executor::gpu { + +// This class implements EventBasedTimer for CUDA devices. +class CudaTimer : public EventBasedTimer { + public: + ~CudaTimer() override; + CudaTimer(CudaTimer&&) = default; + CudaTimer& operator=(CudaTimer&&) = default; + + absl::StatusOr GetElapsedDuration() override; + + enum class TimerType { + kDelayKernel, + kEventBased, + }; + static absl::StatusOr Create(StreamExecutor* executor, + Stream* stream, TimerType timer_type); + + private: + CudaTimer(StreamExecutor* executor, CudaEvent start_event, + CudaEvent stop_event, Stream* stream, GpuSemaphore semaphore); + + GpuSemaphore semaphore_; + bool is_stopped_ = false; + StreamExecutor* executor_; + Stream* stream_; + CudaEvent start_event_; + CudaEvent stop_event_; +}; + +} // namespace stream_executor::gpu + +#endif // XLA_STREAM_EXECUTOR_CUDA_CUDA_TIMER_H_ diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_timer_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_timer_test.cc new file mode 100644 index 00000000000000..021ce4f7d2cdd7 --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/cuda_timer_test.cc @@ -0,0 +1,106 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/cuda/cuda_timer.h" + +#include +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/time/time.h" +#include "xla/stream_executor/cuda/cuda_executor.h" +#include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/gpu_test_kernels.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/typed_kernel_factory.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor::gpu { +namespace { +using ::testing::Gt; +using ::tsl::testing::IsOk; + +class CudaTimerTest : public ::testing::TestWithParam { + public: + void LaunchSomeKernel(StreamExecutor* executor, Stream* stream) { + using AddI32Kernel = + TypedKernelFactory, DeviceMemory, + DeviceMemory>; + + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=1, b=2, c=0 + DeviceMemory a = executor->AllocateArray(length, 0); + DeviceMemory b = executor->AllocateArray(length, 0); + DeviceMemory c = executor->AllocateArray(length, 0); + + ASSERT_THAT(stream->Memset32(&a, 1, byte_length), IsOk()); + ASSERT_THAT(stream->Memset32(&b, 2, byte_length), IsOk()); + ASSERT_THAT(stream->MemZero(&c, byte_length), IsOk()); + + ASSERT_THAT(stream->ThenLaunch(ThreadDim(), BlockDim(4), add, a, b, c), + IsOk()); + } + + StreamExecutor* executor_; + std::unique_ptr stream_; + + private: + void SetUp() override { + TF_ASSERT_OK_AND_ASSIGN(Platform * platform, + stream_executor::PlatformManager::PlatformWithId( + stream_executor::cuda::kCudaPlatformId)); + TF_ASSERT_OK_AND_ASSIGN(executor_, platform->ExecutorForDevice(0)); + TF_ASSERT_OK_AND_ASSIGN(stream_, executor_->CreateStream(std::nullopt)); + } +}; + +TEST_P(CudaTimerTest, Create) { + TF_ASSERT_OK_AND_ASSIGN( + CudaTimer timer, CudaTimer::Create(executor_, stream_.get(), GetParam())); + + // We don't really care what kernel we launch here as long as it takes a + // non-zero amount of time. + LaunchSomeKernel(executor_, stream_.get()); + + TF_ASSERT_OK_AND_ASSIGN(absl::Duration timer_result, + timer.GetElapsedDuration()); + EXPECT_THAT(timer_result, Gt(absl::ZeroDuration())); + EXPECT_THAT(timer.GetElapsedDuration(), + tsl::testing::StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +INSTANTIATE_TEST_SUITE_P(CudaTimerTest, CudaTimerTest, + ::testing::Values(CudaTimer::TimerType::kEventBased, + CudaTimer::TimerType::kDelayKernel)); + +} // namespace +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/cuda/delay_kernel_cuda.cu.cc b/third_party/xla/xla/stream_executor/cuda/delay_kernel_cuda.cu.cc index 29035d049f2c31..e0c5138278e676 100644 --- a/third_party/xla/xla/stream_executor/cuda/delay_kernel_cuda.cu.cc +++ b/third_party/xla/xla/stream_executor/cuda/delay_kernel_cuda.cu.cc @@ -16,8 +16,6 @@ limitations under the License. #include #include "xla/stream_executor/cuda/delay_kernel.h" -#include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_semaphore.h" #include "xla/stream_executor/typed_kernel_factory.h" diff --git a/third_party/xla/xla/stream_executor/cuda/nvjitlink_impl.cc b/third_party/xla/xla/stream_executor/cuda/nvjitlink_impl.cc index 7d15865dd89110..75bcaa0965946f 100644 --- a/third_party/xla/xla/stream_executor/cuda/nvjitlink_impl.cc +++ b/third_party/xla/xla/stream_executor/cuda/nvjitlink_impl.cc @@ -35,9 +35,10 @@ limitations under the License. #include "absl/types/span.h" #include "third_party/gpus/cuda/include/nvJitLink.h" #include "xla/stream_executor/cuda/nvjitlink.h" -#include "xla/stream_executor/cuda/ptx_compiler_support.h" +#include "xla/stream_executor/cuda/ptx_compiler_helpers.h" #include "xla/stream_executor/gpu/gpu_asm_opts.h" #include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" namespace stream_executor { diff --git a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_helpers.cc b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_helpers.cc new file mode 100644 index 00000000000000..2f04878d5a96df --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_helpers.cc @@ -0,0 +1,30 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/cuda/ptx_compiler_helpers.h" + +#include + +#include "absl/strings/match.h" + +namespace stream_executor { + +bool IsPtxRegisterAllocationError(std::string_view str) { + return absl::StrContains(str, "ptxas fatal") && + (absl::StrContains(str, "Register allocation failed") || + absl::StrContains(str, "Insufficient registers")); +} + +} // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_helpers.h b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_helpers.h new file mode 100644 index 00000000000000..ceb586e283ce67 --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_helpers.h @@ -0,0 +1,24 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_STREAM_EXECUTOR_CUDA_PTX_COMPILER_HELPERS_H_ +#define XLA_STREAM_EXECUTOR_CUDA_PTX_COMPILER_HELPERS_H_ +#include + +namespace stream_executor { +// Checks whether ptxas log contains errors related to register allocation. +bool IsPtxRegisterAllocationError(std::string_view); +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_CUDA_PTX_COMPILER_HELPERS_H_ diff --git a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_impl.cc b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_impl.cc index a936a45e11b8ce..ba1959c2b0553e 100644 --- a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_impl.cc +++ b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_impl.cc @@ -35,9 +35,10 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/nvPTXCompiler.h" #include "xla/stream_executor/cuda/ptx_compiler.h" -#include "xla/stream_executor/cuda/ptx_compiler_support.h" +#include "xla/stream_executor/cuda/ptx_compiler_helpers.h" #include "xla/stream_executor/gpu/gpu_asm_opts.h" #include "xla/stream_executor/semantic_version.h" +#include "tsl/platform/logging.h" namespace stream_executor { @@ -147,9 +148,12 @@ absl::StatusOr> CompileGpuAsmUsingLibNvPtxCompiler( RETURN_IF_NVPTXCOMPILER_ERROR( nvPTXCompilerGetInfoLogSize(compiler_handle, &info_log_size)); - std::string info_log(info_log_size, '\0'); + std::vector info_log_buffer(info_log_size + 1); RETURN_IF_NVPTXCOMPILER_ERROR( - nvPTXCompilerGetInfoLog(compiler_handle, info_log.data())); + nvPTXCompilerGetInfoLog(compiler_handle, info_log_buffer.data())); + // The buffer may have several trailing null characters, so create a string + // from the pointer to the buffer rather than pair of iterators. + std::string info_log(info_log_buffer.data()); // Print the verbose output of ptxas. if (!info_log.empty()) { diff --git a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_support.cc b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_support.cc index e104d4f52f0dac..994687bf95a141 100644 --- a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_support.cc +++ b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_support.cc @@ -15,14 +15,6 @@ limitations under the License. #include "xla/stream_executor/cuda/ptx_compiler_support.h" -#include "absl/strings/match.h" - namespace stream_executor { bool IsLibNvPtxCompilerSupported() { return LIBNVPTXCOMPILER_SUPPORT; } - -bool IsPtxRegisterAllocationError(absl::string_view str) { - return absl::StrContains(str, "ptxas fatal") && - (absl::StrContains(str, "Register allocation failed") || - absl::StrContains(str, "Insufficient registers")); -} } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_support.h b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_support.h index 084c3351e8c691..37f28f8a45c9a8 100644 --- a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_support.h +++ b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_support.h @@ -16,15 +16,10 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_CUDA_PTX_COMPILER_SUPPORT_H_ #define XLA_STREAM_EXECUTOR_CUDA_PTX_COMPILER_SUPPORT_H_ -#include "absl/strings/string_view.h" - namespace stream_executor { // Returns true if XLA was built with libnvptxcompiler support. Otherwise false // is returned. bool IsLibNvPtxCompilerSupported(); - -// Checks whether ptxas log contains errors related to register allocation. -bool IsPtxRegisterAllocationError(absl::string_view); } // namespace stream_executor #endif // XLA_STREAM_EXECUTOR_CUDA_PTX_COMPILER_SUPPORT_H_ diff --git a/third_party/xla/xla/stream_executor/data_type.h b/third_party/xla/xla/stream_executor/data_type.h index ebac59ba7c4eae..f5246389e485c3 100644 --- a/third_party/xla/xla/stream_executor/data_type.h +++ b/third_party/xla/xla/stream_executor/data_type.h @@ -19,8 +19,8 @@ limitations under the License. #include #include +#include "xla/tsl/protobuf/dnn.pb.h" #include "tsl/platform/ml_dtypes.h" -#include "tsl/protobuf/dnn.pb.h" namespace Eigen { struct bfloat16; @@ -37,6 +37,14 @@ struct ToDataType; // Note: If you add a new specialization below, make sure to add the // corresponding definition in stream_executor/dnn.cc. template <> +struct ToDataType { + static constexpr DataType value = DataType::kF8E3M4; +}; +template <> +struct ToDataType { + static constexpr DataType value = DataType::kF8E4M3; +}; +template <> struct ToDataType { static constexpr DataType value = DataType::kF8E4M3FN; }; diff --git a/third_party/xla/xla/stream_executor/device_description.cc b/third_party/xla/xla/stream_executor/device_description.cc index ca19e68ffdc382..7486bda1002d72 100644 --- a/third_party/xla/xla/stream_executor/device_description.cc +++ b/third_party/xla/xla/stream_executor/device_description.cc @@ -21,7 +21,7 @@ limitations under the License. #include "xla/stream_executor/device_description.pb.h" #include "xla/stream_executor/launch_dim.h" -#include "tsl/lib/math/math_util.h" +#include "xla/tsl/lib/math/math_util.h" #include "tsl/platform/logging.h" namespace stream_executor { diff --git a/third_party/xla/xla/stream_executor/device_description.h b/third_party/xla/xla/stream_executor/device_description.h index 99d7f1ce5d83c8..195dd058aa64da 100644 --- a/third_party/xla/xla/stream_executor/device_description.h +++ b/third_party/xla/xla/stream_executor/device_description.h @@ -123,6 +123,18 @@ struct CudaComputeCapability { return !(*this == other); } + bool operator>(const CudaComputeCapability &other) const { + return ToPair() > other.ToPair(); + } + + bool operator>=(const CudaComputeCapability &other) const { + return ToPair() >= other.ToPair(); + } + + bool operator<=(const CudaComputeCapability &other) const { + return ToPair() <= other.ToPair(); + } + std::string ToString() const { return absl::StrCat(major, ".", minor); } std::pair ToPair() const { return std::make_pair(major, minor); } diff --git a/third_party/xla/xla/stream_executor/device_description_test.cc b/third_party/xla/xla/stream_executor/device_description_test.cc index 4600a7a04e97d5..ba65c78f8460c9 100644 --- a/third_party/xla/xla/stream_executor/device_description_test.cc +++ b/third_party/xla/xla/stream_executor/device_description_test.cc @@ -47,5 +47,35 @@ TEST(CudaComputeCapability, GenerationLiteralTest) { EXPECT_TRUE(CudaComputeCapability::Blackwell().IsAtLeast(10)); } +TEST(CudaComputeCapability, ComparisonTest) { + CudaComputeCapability lower{1, 0}; + CudaComputeCapability slightly_higher{1, 1}; + CudaComputeCapability higher{2, 0}; + + EXPECT_TRUE(lower == lower); + EXPECT_FALSE(lower == slightly_higher); + EXPECT_FALSE(lower == higher); + + EXPECT_TRUE(lower <= lower); + EXPECT_TRUE(lower < slightly_higher); + EXPECT_TRUE(lower <= slightly_higher); + + EXPECT_FALSE(lower < lower); + EXPECT_FALSE(slightly_higher <= lower); + EXPECT_FALSE(slightly_higher < lower); + + EXPECT_TRUE(slightly_higher >= slightly_higher); + EXPECT_TRUE(slightly_higher > lower); + EXPECT_TRUE(slightly_higher >= lower); + + EXPECT_FALSE(slightly_higher > slightly_higher); + EXPECT_FALSE(lower > slightly_higher); + EXPECT_FALSE(lower >= slightly_higher); + + EXPECT_TRUE(higher > slightly_higher); + EXPECT_TRUE(higher >= slightly_higher); + EXPECT_TRUE(higher >= higher); +} + } // namespace } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/device_memory.h b/third_party/xla/xla/stream_executor/device_memory.h index 5334e79f4565c6..43b645b4c345df 100644 --- a/third_party/xla/xla/stream_executor/device_memory.h +++ b/third_party/xla/xla/stream_executor/device_memory.h @@ -35,8 +35,6 @@ limitations under the License. namespace stream_executor { -class DeviceMemoryAllocator; - // void*-analogous device memory allocation. For the typed variation, see // DeviceMemory. // diff --git a/third_party/xla/xla/stream_executor/dnn.cc b/third_party/xla/xla/stream_executor/dnn.cc index 951b2f6e147cd8..6b7a87d80b3aec 100644 --- a/third_party/xla/xla/stream_executor/dnn.cc +++ b/third_party/xla/xla/stream_executor/dnn.cc @@ -41,9 +41,12 @@ limitations under the License. #include "xla/stream_executor/data_type.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/numeric_options.h" +#include "xla/stream_executor/scratch_allocator.h" +#include "xla/stream_executor/stream.h" #include "xla/tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/protobuf/dnn.pb.h" +#include "xla/util.h" #include "tsl/platform/ml_dtypes.h" -#include "tsl/protobuf/dnn.pb.h" namespace stream_executor { namespace dnn { @@ -66,6 +69,8 @@ bool ProtoMapsEqual(const google::protobuf::Map& x, } // namespace +constexpr DataType ToDataType::value; +constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; constexpr DataType ToDataType::value; @@ -296,19 +301,6 @@ absl::Status DnnSupport::DoPoolBackward( workspace_allocator); } -std::string QuantizedActivationModeString(QuantizedActivationMode mode) { - switch (mode) { - case dnn::QuantizedActivationMode::k8Bit: - return "uint8"; - case dnn::QuantizedActivationMode::k16Bit: - return "uint16"; - case dnn::QuantizedActivationMode::k32Bit: - return "int32"; - default: - return absl::StrCat("unknown: ", static_cast(mode)); - } -} - std::string ActivationModeString(ActivationMode mode) { switch (mode) { case ActivationMode::kNone: @@ -334,17 +326,6 @@ std::string ActivationModeString(ActivationMode mode) { } } -std::string ElementwiseOperationString(ElementwiseOperation op) { - switch (op) { - case ElementwiseOperation::kAdd: - return "add"; - case ElementwiseOperation::kMultiply: - return "multiply"; - default: - return absl::StrCat("unknown: ", static_cast(op)); - } -} - std::string DataLayoutString(DataLayout layout) { switch (layout) { case DataLayout::kYXDepthBatch: @@ -402,17 +383,6 @@ std::ostream& operator<<(std::ostream& str, dnn::PadAlignment alignment) { return str << PadAlignmentString(alignment); } -std::string ShortPoolingModeString(PoolingMode mode) { - switch (mode) { - case PoolingMode::kMaximum: - return "Max"; - case PoolingMode::kAverage: - return "Avg"; - default: - return absl::StrCat("unknown: ", static_cast(mode)); - } -} - struct ConvDimIndices { union { struct { @@ -627,16 +597,9 @@ std::string TensorDescriptor::ToString() const { absl::StatusOr> MatmulTensorDescriptor::GetNonContractingDims() const { - std::vector non_contracting_dims; - for (int64_t dim = 0; dim < tensor_.dimensions().size(); ++dim) { - bool is_batch = absl::c_count(batch_dimension_numbers_, dim) != 0; - bool is_contracting = absl::c_count(contracting_dim_, dim) != 0; - if (is_batch && is_contracting) - return absl::InternalError( - "A dimension cannot be both a batch dimension and a contracting " - "dimension."); - if (!(is_batch || is_contracting)) non_contracting_dims.push_back(dim); - } + auto nc = xla::GetNonContractingDims( + tensor_.dimensions().size(), contracting_dim_, batch_dimension_numbers_); + std::vector non_contracting_dims(nc.begin(), nc.end()); if (batch_dimension_numbers_.size() + contracting_dim_.size() + non_contracting_dims.size() != @@ -777,13 +740,6 @@ std::vector BatchDescriptor::vectorized_strides( return ReorderDims(phys_strides, this->layout(), layout); } -void BatchDescriptor::CloneFrom(const BatchDescriptor& other) { - tensor_ = other.tensor_; - value_max_ = other.value_max_; - value_min_ = other.value_min_; - quantized_activation_mode_ = other.quantized_activation_mode_; -} - std::string BatchDescriptor::ToString() const { std::string spatial; for (int i = 0; i < ndims(); i++) { @@ -846,34 +802,6 @@ int64_t BatchDescriptor::NodesAcrossFeatureMaps() const { return NodesPerFeatureMap() * feature_map_count(); } -int64_t BatchDescriptor::ElementCount() const { - return count() * feature_map_count() * NodesPerFeatureMap(); -} - -int64_t BatchDescriptor::FullyConnectedWeightCount( - const BatchDescriptor& input, const BatchDescriptor& output) { - return input.NodesAcrossFeatureMaps() * output.NodesAcrossFeatureMaps(); -} - -int64_t BatchDescriptor::FullyConnectedBiasCount( - const BatchDescriptor& output) { - return output.NodesAcrossFeatureMaps(); -} - -BatchDescriptor BatchDescriptor::DepthConcatenateOutputDescriptor( - absl::Span inputs) { - if (inputs.empty()) { - return BatchDescriptor(); - } - int feature_map_count = 0; - for (const auto& dimensions : inputs) { - feature_map_count += dimensions.feature_map_count(); - } - BatchDescriptor output = inputs[0]; - output.set_feature_map_count(feature_map_count); - return output; -} - TensorDescriptorProto BatchDescriptor::ToProto(DataType data_type) const { CHECK_EQ(0.0, value_max_); CHECK_EQ(0.0, value_min_); @@ -895,10 +823,6 @@ FilterDescriptor::FilterDescriptor() : FilterDescriptor(/*ndims=*/2) {} FilterDescriptor::~FilterDescriptor() {} -void FilterDescriptor::CloneFrom(const FilterDescriptor& other) { - tensor_ = other.tensor_; -} - std::string FilterDescriptor::ToString() const { std::string desc = absl::StrFormat( "{output_feature_map_count: %d input_feature_map_count: %d " @@ -913,45 +837,6 @@ std::string FilterDescriptor::ToString() const { return desc; } -std::string FilterDescriptor::ToShortString() const { - // All the constituent strings are less than 15 characters, so the - // small string optimization ensures that there will be at most one - // heap memory allocation. - std::string od = absl::StrCat("od", output_feature_map_count()); - std::string id = absl::StrCat("id", input_feature_map_count()); - - std::string spatial = "s"; - for (int i = 0; i < ndims(); i++) { - absl::StrAppendFormat(&spatial, "%d ", input_filter_dims()[i]); - } - - switch (layout()) { - case FilterLayout::kOutputInputYX: - return absl::StrCat(od, id, spatial); - case FilterLayout::kOutputYXInput: - return absl::StrCat(od, spatial, id); - case FilterLayout::kOutputInputYX4: - case FilterLayout::kOutputInputYX32: - case FilterLayout::kOutputInputYX32_CudnnReordered: - return absl::StrCat(od, id, spatial, "(VECT_C)"); - case FilterLayout::kInputYXOutput: - return absl::StrCat(id, spatial, od); - case FilterLayout::kYXInputOutput: - return absl::StrCat(spatial, id, od); - default: - LOG(FATAL) << "Unknown layout " << static_cast(layout()); - return ""; // Avoid return warning (unreachable) - } -} - -int64_t FilterDescriptor::ComputeWeightCount() const { - int64_t ret = output_feature_map_count() * input_feature_map_count(); - for (int i = 0; i < ndims(); i++) { - ret *= input_filter_dims()[i]; - } - return ret; -} - std::vector FilterDescriptor::full_dims( const FilterLayout& layout) const { std::vector oiyx_dims(ndims() + 2); @@ -1031,21 +916,6 @@ std::string ConvolutionDescriptor::ToString() const { padding, PadAlignmentString(pad_alignment()), strides, dilations); } -std::string ConvolutionDescriptor::ToShortString() const { - std::string desc; - for (int i = 0; i < ndims(); i++) { - if (i > 0) absl::StrAppend(&desc, "_"); - absl::StrAppendFormat(&desc, "p%d:%d", i, padding()[i]); - } - for (int i = 0; i < ndims(); i++) { - absl::StrAppendFormat(&desc, "_s%d:%d", i, strides()[i]); - } - for (int i = 0; i < ndims(); i++) { - absl::StrAppendFormat(&desc, "_d%d:%d", i, dilations()[i]); - } - return desc; -} - // -- PoolingDescriptor PoolingDescriptor::PoolingDescriptor(int ndims) @@ -1058,45 +928,6 @@ PoolingDescriptor::PoolingDescriptor(int ndims) PoolingDescriptor::PoolingDescriptor() : PoolingDescriptor(/*ndims=*/2) {} -void PoolingDescriptor::CloneFrom(const PoolingDescriptor& other) { - mode_ = other.mode_; - ndims_ = other.ndims_; - window_ = other.window_; - padding_ = other.padding_; - strides_ = other.strides_; - propagate_nans_ = other.propagate_nans_; -} - -std::string PoolingDescriptor::ToString() const { - const char* mode_string = - mode_ == dnn::PoolingMode::kMaximum ? "kMaximum" : "kAverage"; - - std::string window, strides, padding; - for (int i = 0; i < ndims_; i++) { - absl::StrAppendFormat(&window, "%d ", window_[i]); - absl::StrAppendFormat(&strides, "%d ", strides_[i]); - absl::StrAppendFormat(&padding, "%d", padding_[i]); - } - - const char* propagate_string = propagate_nans_ ? "Yes" : "No"; - - return absl::StrFormat( - "{mode: %s window: %s strides: %s padding: %s propagate NaNs: %s}", - mode_string, window, strides, padding, propagate_string); -} - -std::string PoolingDescriptor::ToShortString() const { - std::string window, strides, padding; - for (int i = 0; i < ndims_; i++) { - absl::StrAppendFormat(&window, "_w%d:%d", i, window_[i]); - absl::StrAppendFormat(&strides, "_s%d:%d", i, strides_[i]); - absl::StrAppendFormat(&padding, "_p%d:%d", i, padding_[i]); - } - return absl::StrCat(mode_ == dnn::PoolingMode::kMaximum ? "max" : "avg", - window, strides, padding, - propagate_nans_ ? "propagate_nans" : "ignore_nans"); -} - // -- NormalizeDescriptor NormalizeDescriptor::NormalizeDescriptor() @@ -1107,28 +938,6 @@ NormalizeDescriptor::NormalizeDescriptor() wrap_around_(false), segment_size_(0) {} -void NormalizeDescriptor::CloneFrom(const NormalizeDescriptor& other) { - bias_ = other.bias_; - range_ = other.range_; - alpha_ = other.alpha_; - beta_ = other.beta_; - wrap_around_ = other.wrap_around_; - segment_size_ = other.segment_size_; -} - -std::string NormalizeDescriptor::ToString() const { - return absl::StrFormat( - "{bias: %f range: %d alpha: %f beta: %f wrap_around: %d " - "segment_size: %d}", - bias_, range_, alpha_, beta_, wrap_around_, segment_size_); -} - -std::string NormalizeDescriptor::ToShortString() const { - return absl::StrCat("bias:", bias_, "_range:", range_, "_alpha:", alpha_, - "_beta:", beta_, "_wrap:", wrap_around_, - "_size:", segment_size_); -} - bool DnnSupport::IsStatusOk(const absl::Status& status, bool report_error) { if (status.ok()) { return true; diff --git a/third_party/xla/xla/stream_executor/dnn.h b/third_party/xla/xla/stream_executor/dnn.h index b1b89ff1c59d59..f99f6a0380c395 100644 --- a/third_party/xla/xla/stream_executor/dnn.h +++ b/third_party/xla/xla/stream_executor/dnn.h @@ -44,8 +44,11 @@ limitations under the License. #include "xla/stream_executor/device_description.pb.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/numeric_options.h" +#include "xla/stream_executor/scratch_allocator.h" +#include "xla/stream_executor/stream.h" +#include "xla/tsl/protobuf/dnn.pb.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" -#include "tsl/protobuf/dnn.pb.h" namespace Eigen { struct half; @@ -53,10 +56,6 @@ struct half; namespace stream_executor { -class HostBuffer; -class Stream; -class ScratchAllocator; - namespace dnn { // Specifies an index to use when accessing specific spatial dimensions. @@ -194,10 +193,6 @@ class MatmulTensorDescriptor { absl::Span minor_to_major, absl::Span batch_dims, absl::Span contracting_dims); - std::vector dimensions() const { return tensor_.dimensions(); } - std::vector minor_to_major() const { - return tensor_.minor_to_major(); - } DataType type() const { return tensor_.type(); } std::string ToString() const; @@ -258,9 +253,6 @@ class RnnStateTensorDescriptor { virtual ~RnnStateTensorDescriptor() = default; }; -// Returns a string representation of the given quantization mode. -std::string QuantizedActivationModeString(QuantizedActivationMode mode); - // Describes the dimensions that a layer consumes/produces. // // This is a matrix (height, width), its "depth" (feature_map_count), @@ -307,9 +299,6 @@ class BatchDescriptor { BatchDescriptor(); explicit BatchDescriptor(int ndims); - // Clones values from 'other' for initialization. - void CloneFrom(const BatchDescriptor& other); - std::string ToString() const; std::string ToShortString() const; @@ -324,9 +313,6 @@ class BatchDescriptor { int64_t feature_map_count() const { return tensor_.dimensions(1); } int64_t height() const { return GetDim(spatial_size(), DimIndex::Y); } int64_t width() const { return GetDim(spatial_size(), DimIndex::X); } - int64_t spatial_dim(DimIndex dim) const { - return GetDim(spatial_size(), dim); - } int ndims() const { return spatial_size().size(); } float value_max() const { return value_max_; } float value_min() const { return value_min_; } @@ -373,23 +359,10 @@ class BatchDescriptor { SetDim(spatial_size(), dim, value); return *this; } - BatchDescriptor& set_value_max(float value) { - value_max_ = value; - return *this; - } - BatchDescriptor& set_value_min(float value) { - value_min_ = value; - return *this; - } BatchDescriptor& set_layout(DataLayout layout) { tensor_.set_data_layout(layout); return *this; } - BatchDescriptor& set_quantized_activation_mode( - QuantizedActivationMode quantized_activation_mode) { - quantized_activation_mode_ = quantized_activation_mode; - return *this; - } // Return the number of nodes in a single feature map. int64_t NodesPerFeatureMap() const; @@ -398,28 +371,6 @@ class BatchDescriptor { // affected by the batch count. int64_t NodesAcrossFeatureMaps() const; - // Returns the number of elements (e.g. RGB pixel values) required to hold a - // given batch descriptor, given a no-padding assumption. Note that this is - // affected by the batch count. - int64_t ElementCount() const; - - // Return the number of weights required to fully connect a layer with - // dimensions given by the 'input' descriptor with a layer with dimensions - // given by the 'output' descriptor. - static int64_t FullyConnectedWeightCount(const BatchDescriptor& input, - const BatchDescriptor& output); - - // Return the number of biases required to fully connect to an output layer - // with dimensions given the 'output' descriptor. - static int64_t FullyConnectedBiasCount(const BatchDescriptor& output); - - // Return a BatchDescriptor for the output of a depth concatenation - // with the given input descriptors. The inputs should have the same - // dimensions, except possibly for feature_map_count(), though this - // function does not verify that. - static BatchDescriptor DepthConcatenateOutputDescriptor( - absl::Span inputs); - private: absl::Span spatial_size() const { return AsInt64Slice(tensor_.dimensions()).subspan(2); @@ -499,32 +450,11 @@ class FilterDescriptor { } int ndims() const { return input_filter_dims().size(); } - void CloneFrom(const FilterDescriptor& other); - std::string ToString() const; - std::string ToShortString() const; TensorDescriptorProto ToProto(DataType data_type) const; - // Returns the number of weights required as parameters for a convolution - // using this filter descriptor. - int64_t ComputeWeightCount() const; - - // Returns the number of biases required as parameters for a convolution - // using this filter descriptor. - int64_t bias_count() const { return output_feature_map_count(); } - int64_t output_feature_map_count() const { return tensor_.dimensions(0); } int64_t input_feature_map_count() const { return tensor_.dimensions(1); } - int64_t input_filter_height() const { - return GetDim(input_filter_dims(), DimIndex::Y); - } - int64_t input_filter_width() const { - return GetDim(input_filter_dims(), DimIndex::X); - } - int64_t input_filter_dim(DimIndex dim) const { - return GetDim(input_filter_dims(), dim); - } - FilterLayout layout() const { return tensor_.filter_layout(); } absl::Span input_filter_dims() const { @@ -610,7 +540,6 @@ class ConvolutionDescriptor { ~ConvolutionDescriptor(); std::string ToString() const; - std::string ToShortString() const; ConvolutionDescriptorProto ToProto() const { return proto_; } ConvolutionDescriptor& set_zero_padding_height(int64_t value) { @@ -658,28 +587,7 @@ class ConvolutionDescriptor { : ConvolutionMode::CROSS_CORRELATION); return *this; } - ConvolutionDescriptor& set_name(const std::string& name) { - proto_.set_name(name); - return *this; - } - int64_t zero_padding_height() const { return GetDim(padding(), DimIndex::Y); } - int64_t zero_padding_width() const { return GetDim(padding(), DimIndex::X); } - int64_t vertical_filter_stride() const { - return GetDim(strides(), DimIndex::Y); - } - int64_t horizontal_filter_stride() const { - return GetDim(strides(), DimIndex::X); - } - int64_t vertical_dilation_rate() const { - return GetDim(dilations(), DimIndex::Y); - } - int64_t horizontal_dilation_rate() const { - return GetDim(dilations(), DimIndex::X); - } - int zero_padding(DimIndex dim) const { return GetDim(padding(), dim); } - int filter_stride(DimIndex dim) const { return GetDim(strides(), dim); } - int dilation_rate(DimIndex dim) const { return GetDim(dilations(), dim); } // TODO(timshen): remove this function. No users of this class is setting a // non-default pad alignment. PadAlignment pad_alignment() const { return PadAlignment::kDefault; } @@ -701,8 +609,6 @@ class ConvolutionDescriptor { return AsInt64Slice(proto_.paddings()); } - std::string name() const { return proto_.name(); } - private: absl::Span strides() { return AsInt64Slice(proto_.mutable_strides()); @@ -739,9 +645,6 @@ enum class SpaceConcatenateMode : int64_t { YDirection, }; -// Returns a short name for the pooling mode, e.g. "Avg". -std::string ShortPoolingModeString(PoolingMode mode); - // Describes a pooling operation to be enqueued onto a stream via a platform's // DnnSupport. // @@ -804,38 +707,19 @@ class PoolingDescriptor { propagate_nans_ = value; return *this; } - PoolingDescriptor& set_name(const std::string& name) { - name_ = name; - return *this; - } int ndims() const { return ndims_; } - void CloneFrom(const PoolingDescriptor& other); - - std::string ToString() const; - std::string ToShortString() const; PoolingMode mode() const { return mode_; } - int64_t window_height() const { return GetDim(window_, DimIndex::Y); } - int64_t window_width() const { return GetDim(window_, DimIndex::X); } - int64_t window(DimIndex dim) const { return GetDim(window_, dim); } - int64_t vertical_padding() const { return GetDim(padding_, DimIndex::Y); } - int64_t horizontal_padding() const { return GetDim(padding_, DimIndex::X); } - int64_t padding(DimIndex dim) const { return GetDim(padding_, dim); } - int64_t vertical_stride() const { return GetDim(strides_, DimIndex::Y); } - int64_t horizontal_stride() const { return GetDim(strides_, DimIndex::X); } - int64_t stride(DimIndex dim) const { return GetDim(strides_, dim); } absl::Span window() const { return window_; } absl::Span padding() const { return padding_; } absl::Span strides() const { return strides_; } bool propagate_nans() const { return propagate_nans_; } - std::string name() const { return name_; } private: PoolingMode mode_; int ndims_; bool propagate_nans_; - std::string name_; // Name as in Tensorflow NodeDef, for debugging purposes. // Stored as: ..., y, x. std::vector window_; @@ -919,7 +803,6 @@ class ProfileResult { float elapsed_time_in_ms() const { return elapsed_time_in_ms_; } void set_elapsed_time_in_ms(float val) { elapsed_time_in_ms_ = val; } - size_t scratch_size() const { return scratch_size_; } void set_scratch_size(size_t val) { scratch_size_ = val; } private: @@ -1049,7 +932,6 @@ class AlgorithmConfig { algorithm_no_scratch_ = val; } std::optional scratch_size() const { return scratch_size_; } - void set_scratch_size(size_t val) { scratch_size_ = val; } bool operator==(const AlgorithmConfig& other) const { return this->algorithm_ == other.algorithm_ && this->algorithm_no_scratch_ == other.algorithm_no_scratch_ && @@ -1131,21 +1013,6 @@ class NormalizeDescriptor { return *this; } - NormalizeDescriptor& set_wrap_around(bool wrap_around) { - wrap_around_ = wrap_around; - return *this; - } - - NormalizeDescriptor& set_segment_size(int32_t segment_size) { - segment_size_ = segment_size; - return *this; - } - - void CloneFrom(const NormalizeDescriptor& other); - - std::string ToString() const; - std::string ToShortString() const; - float bias() const { return bias_; } int32_t range() const { return range_; } float alpha() const { return alpha_; } @@ -1169,8 +1036,6 @@ std::string ActivationModeString(ActivationMode mode); // inputs. enum class ElementwiseOperation { kAdd, kMultiply }; -std::string ElementwiseOperationString(ElementwiseOperation op); - // A simple class representing the version of the backing library, to // workaround the "too perfect forwarding" issue in gcc6+ compilers. // See PR#16309 and issue #18402 for links discussing the issue. diff --git a/third_party/xla/xla/stream_executor/event_based_timer.h b/third_party/xla/xla/stream_executor/event_based_timer.h index 2283f34619cff5..96900806d736f0 100644 --- a/third_party/xla/xla/stream_executor/event_based_timer.h +++ b/third_party/xla/xla/stream_executor/event_based_timer.h @@ -27,6 +27,9 @@ namespace stream_executor { class EventBasedTimer { public: virtual ~EventBasedTimer() = default; + EventBasedTimer() = default; + EventBasedTimer(EventBasedTimer&&) = default; + EventBasedTimer& operator=(EventBasedTimer&&) = default; // Stops the timer on the first call and returns the elapsed duration. // Subsequent calls error out. diff --git a/third_party/xla/xla/stream_executor/fft.h b/third_party/xla/xla/stream_executor/fft.h index d88834c72e074c..937ae639eed9f0 100644 --- a/third_party/xla/xla/stream_executor/fft.h +++ b/third_party/xla/xla/stream_executor/fft.h @@ -47,8 +47,6 @@ limitations under the License. #include #include -#include "xla/stream_executor/platform/port.h" - namespace stream_executor { class Stream; @@ -109,9 +107,9 @@ class FftSupport { // output_distance: Indicates the distance between the first element of two // consecutive signals in a batch of the output data. virtual std::unique_ptr CreateBatchedPlanWithScratchAllocator( - Stream *stream, int rank, uint64_t *elem_count, uint64 *input_embed, - uint64_t input_stride, uint64 input_distance, uint64 *output_embed, - uint64_t output_stride, uint64 output_distance, Type type, + Stream *stream, int rank, uint64_t *elem_count, uint64_t *input_embed, + uint64_t input_stride, uint64_t input_distance, uint64_t *output_embed, + uint64_t output_stride, uint64_t output_distance, Type type, bool in_place_fft, int batch_count, ScratchAllocator *scratch_allocator) = 0; @@ -162,9 +160,9 @@ class FftSupport { // ::stream_executor namespace. #define TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES \ std::unique_ptr CreateBatchedPlanWithScratchAllocator( \ - Stream *stream, int rank, uint64_t *elem_count, uint64 *input_embed, \ - uint64_t input_stride, uint64 input_distance, uint64 *output_embed, \ - uint64_t output_stride, uint64 output_distance, fft::Type type, \ + Stream *stream, int rank, uint64_t *elem_count, uint64_t *input_embed, \ + uint64_t input_stride, uint64_t input_distance, uint64_t *output_embed, \ + uint64_t output_stride, uint64_t output_distance, fft::Type type, \ bool in_place_fft, int batch_count, ScratchAllocator *scratch_allocator) \ override; \ void UpdatePlanWithScratchAllocator(Stream *stream, fft::Plan *plan, \ diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index 04d1783a18d2ee..d9e0b5ec9880bb 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -65,15 +65,12 @@ package( licenses = ["notice"], ) -cc_library( - name = "gpu_activation_header", - hdrs = ["gpu_activation.h"], - deps = [":scoped_activate_context"], -) - cc_library( name = "context", hdrs = ["context.h"], + deps = [ + "@com_google_absl//absl/status", + ], ) cc_library( @@ -126,8 +123,7 @@ cc_library( hdrs = ["scoped_activate_context.h"], deps = [ ":context", - ":gpu_executor_header", - "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor:activate_context", "@com_google_absl//absl/log:check", "@local_tsl//tsl/platform:logging", ], @@ -146,30 +142,12 @@ xla_cc_test( ], ) -gpu_only_cc_library( - name = "gpu_activation", - hdrs = ["gpu_activation.h"], - deps = [ - ":scoped_activate_context", - ], -) - gpu_only_cc_library( name = "gpu_diagnostics_header", hdrs = ["gpu_diagnostics.h"], deps = ["@com_google_absl//absl/status:statusor"], ) -gpu_only_cc_library( - name = "gpu_collectives_header", - hdrs = ["gpu_collectives.h"], - deps = [ - ":context", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - ], -) - gpu_only_cc_library( name = "gpu_driver_header", hdrs = ["gpu_driver.h"], @@ -182,7 +160,8 @@ gpu_only_cc_library( deps = [ ":context", ":gpu_types_header", - "//xla/stream_executor", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -195,19 +174,6 @@ gpu_only_cc_library( ), ) -gpu_only_cc_library( - name = "gpu_runtime_header", - hdrs = ["gpu_runtime.h"], - visibility = internal_visibility([ - "//xla/service/gpu:__subpackages__", - "//xla/stream_executor:__subpackages__", - ]), - deps = [ - ":gpu_types_header", - "@com_google_absl//absl/status:statusor", - ], -) - gpu_only_cc_library( name = "gpu_command_buffer", srcs = ["gpu_command_buffer.cc"], @@ -222,9 +188,12 @@ gpu_only_cc_library( ":gpu_stream", ":gpu_types_header", "//xla:util", - "//xla/stream_executor", "//xla/stream_executor:command_buffer", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:kernel", "//xla/stream_executor:kernel_spec", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:typed_kernel_factory", "@com_google_absl//absl/container:flat_hash_map", @@ -245,37 +214,9 @@ gpu_only_cc_library( "@local_tsl//tsl/platform:statusor", ] + if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", - "//xla/stream_executor/cuda:command_buffer_kernels", - ]) + if_rocm_is_configured([ - "//xla/stream_executor/rocm:command_buffer_kernels", ]), ) -gpu_only_cc_library( - name = "gpu_event_header", - hdrs = ["gpu_event.h"], - deps = [ - ":context", - ":gpu_types_header", - "//xla/stream_executor:event", - "@com_google_absl//absl/status", - ], -) - -gpu_only_cc_library( - name = "gpu_event", - srcs = ["gpu_event.cc"], - hdrs = ["gpu_event.h"], - deps = [ - ":context", - ":gpu_driver_header", - ":gpu_types_header", - "//xla/stream_executor:event", - "@com_google_absl//absl/base", - "@com_google_absl//absl/status", - ], -) - cc_library( name = "gpu_executor_header", hdrs = ["gpu_executor.h"], @@ -303,6 +244,36 @@ cc_library( ], ) +cc_library( + name = "mock_gpu_executor", + testonly = True, + hdrs = ["mock_gpu_executor.h"], + tags = ["gpu"], + deps = [ + ":gpu_executor_header", + "//xla:test", + "//xla/stream_executor:allocator_stats", + "//xla/stream_executor:blas", + "//xla/stream_executor:command_buffer", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:dnn", + "//xla/stream_executor:event", + "//xla/stream_executor:event_based_timer", + "//xla/stream_executor:fft", + "//xla/stream_executor:kernel", + "//xla/stream_executor:kernel_spec", + "//xla/stream_executor:memory_allocation", + "//xla/stream_executor:module_spec", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + ], +) + gpu_only_cc_library( name = "gpu_helpers_header", hdrs = ["gpu_helpers.h"], @@ -357,12 +328,8 @@ gpu_only_cc_library( name = "gpu_kernel_header", hdrs = ["gpu_kernel.h"], deps = [ - ":context", - ":gpu_executor_header", ":gpu_types_header", - "//xla/stream_executor", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:logging", + "//xla/stream_executor:kernel", ], ) @@ -370,20 +337,14 @@ gpu_only_cc_library( name = "gpu_stream_header", hdrs = ["gpu_stream.h"], deps = [ - ":gpu_event_header", - ":gpu_executor_header", ":gpu_types_header", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:event", - "//xla/stream_executor:event_based_timer", "//xla/stream_executor:kernel", "//xla/stream_executor:launch_dim", "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor:stream_common", - "@com_google_absl//absl/functional:any_invocable", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings:string_view", ], ) @@ -392,28 +353,10 @@ gpu_only_cc_library( srcs = ["gpu_stream.cc"], hdrs = ["gpu_stream.h"], deps = [ - ":gpu_driver_header", - ":gpu_event_header", - ":gpu_executor_header", - ":gpu_kernel_header", ":gpu_types_header", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:event", - "//xla/stream_executor:event_based_timer", - "//xla/stream_executor:kernel", - "//xla/stream_executor:launch_dim", - "//xla/stream_executor:platform", "//xla/stream_executor:stream", - "//xla/stream_executor:stream_common", - "@com_google_absl//absl/functional:any_invocable", - "@com_google_absl//absl/log", + "@com_google_absl//absl/base", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/profiler/lib:nvtx_utils", ], ) @@ -430,33 +373,6 @@ gpu_only_cc_library( ], ) -gpu_only_cc_library( - name = "gpu_timer", - srcs = [ - "gpu_timer.cc", - ], - hdrs = [ - "gpu_timer.h", - ], - deps = [ - ":context", - ":gpu_driver_header", - ":gpu_event", - ":gpu_semaphore", - ":gpu_stream", - "//xla/stream_executor:event_based_timer", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - gpu_only_cc_library( name = "gpu_types_header", hdrs = ["gpu_types.h"], @@ -465,9 +381,7 @@ gpu_only_cc_library( ]) + if_sycl_is_configured([ "TENSORFLOW_USE_SYCL=1", ]), - deps = [ - "//xla/stream_executor/platform", - ] + if_cuda_is_configured([ + deps = if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", ]) + if_rocm_is_configured([ "@local_config_rocm//rocm:rocm_headers", @@ -521,11 +435,16 @@ cc_library( "redzone_allocator_kernel.h", "redzone_allocator_kernel_cuda.cc", ], - tags = ["manual"], + tags = [ + "cuda-only", + "gpu", + "manual", + ], deps = [ ":gpu_asm_opts", - "//xla/stream_executor", "//xla/stream_executor:device_memory", + "//xla/stream_executor:kernel", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:typed_kernel_factory", "//xla/stream_executor/cuda:cuda_asm_compiler", "@com_google_absl//absl/base", @@ -551,8 +470,9 @@ gpu_kernel_library( tags = ["manual"], deps = [ ":gpu_asm_opts", - "//xla/stream_executor", "//xla/stream_executor:device_memory", + "//xla/stream_executor:kernel", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:typed_kernel_factory", "@com_google_absl//absl/status:statusor", "@local_config_rocm//rocm:rocm_headers", @@ -574,11 +494,16 @@ gpu_only_cc_library( ]), deps = [ ":gpu_asm_opts", - "//xla/stream_executor", "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:device_memory_handle", + "//xla/stream_executor:kernel", + "//xla/stream_executor:launch_dim", "//xla/stream_executor:scratch_allocator", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/tsl/framework:allocator", + "//xla/tsl/lib/math:math_util", "@com_google_absl//absl/container:fixed_array", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -586,7 +511,6 @@ gpu_only_cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/lib/math:math_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ] + if_rocm_is_configured([ @@ -604,8 +528,10 @@ xla_test( ":gpu_asm_opts", ":gpu_init", ":redzone_allocator", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream_executor_memory_allocator", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", @@ -632,6 +558,7 @@ tsl_gpu_library( tsl_gpu_library( name = "gpu_cudamallocasync_allocator", +<<<<<<< HEAD srcs = [ "gpu_cudamallocasync_allocator.cc", ], @@ -641,11 +568,24 @@ tsl_gpu_library( ]), cuda_deps = [ "//xla/stream_executor/cuda:cuda_executor", +======= + srcs = ["gpu_cudamallocasync_allocator.cc"], + hdrs = ["gpu_cudamallocasync_allocator.h"], + tags = [ + "cuda-only", + "gpu", +>>>>>>> master ], deps = [ ":gpu_init_impl", + "//xla/stream_executor:activate_context", "//xla/stream_executor:stream_executor_h", +<<<<<<< HEAD "//xla/stream_executor/gpu:scoped_activate_context", +======= + "//xla/stream_executor/cuda:cuda_executor", + "//xla/stream_executor/cuda:cuda_status", +>>>>>>> master "//xla/tsl/framework:allocator", "//xla/tsl/framework:device_id", "//xla/tsl/util:env_var", @@ -661,14 +601,14 @@ xla_test( name = "gpu_cudamallocasync_allocator_test", srcs = ["gpu_cudamallocasync_allocator_test.cc"], backends = ["gpu_any"], - tags = ["no_rocm"], + tags = ["cuda-only"], deps = [ ":gpu_cudamallocasync_allocator", ":gpu_stream", "//xla/service:platform_util", - "//xla/stream_executor", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/cuda:cuda_platform", "//xla/tsl/framework:device_id", "@com_google_absl//absl/log:check", @@ -693,9 +633,12 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/service:algorithm_util", - "//xla/stream_executor", "//xla/stream_executor:blas", + "//xla/stream_executor:device_memory", "//xla/stream_executor:host_or_device_scalar", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", + "//xla/tsl/protobuf:dnn_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -703,7 +646,6 @@ cc_library( "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:dnn_proto_cc", ] + if_cuda_is_configured([ "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", ]) + if_static([ @@ -781,10 +723,13 @@ xla_test( ":gpu_test_kernels", ":gpu_test_kernels_fatbin", "//xla/service:platform_util", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/stream_executor:kernel_spec", + "//xla/stream_executor:launch_dim", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:typed_kernel_factory", "//xla/stream_executor/rocm:rocm_platform_id", "//xla/tsl/lib/core:status_test_util", @@ -809,12 +754,16 @@ xla_test( ":gpu_test_kernels", ":gpu_types_header", "//xla/service:platform_util", - "//xla/stream_executor", "//xla/stream_executor:command_buffer", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:kernel", "//xla/stream_executor:kernel_spec", + "//xla/stream_executor:launch_dim", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor:semantic_version", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:trace_command_buffer_factory", "//xla/stream_executor:typed_kernel_factory", "//xla/stream_executor/cuda:cuda_platform_id", @@ -844,9 +793,11 @@ xla_test( "TENSORFLOW_USE_ROCM=1", ]), deps = [ - "//xla/stream_executor", "//xla/stream_executor:device_memory", + "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -867,7 +818,10 @@ xla_test( "TENSORFLOW_USE_ROCM=1", ]), deps = [ - "//xla/stream_executor", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_finder", "//xla/stream_executor/host:host_platform", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/stream_executor/gpu/context.h b/third_party/xla/xla/stream_executor/gpu/context.h index 77f26ddd74ca13..1baa0e589fb6f2 100644 --- a/third_party/xla/xla/stream_executor/gpu/context.h +++ b/third_party/xla/xla/stream_executor/gpu/context.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_CONTEXT_H_ #define XLA_STREAM_EXECUTOR_GPU_CONTEXT_H_ +#include "absl/status/status.h" + namespace stream_executor::gpu { // This defines a base class for interacting with any context-specific state @@ -32,6 +34,9 @@ class Context { // Returns the device ordinal associated with this context. virtual int device_ordinal() const = 0; + + // Synchronizes all activity on the GPU. + virtual absl::Status Synchronize() = 0; }; } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc index 6f931aeb6324fd..6aee86bf2cbc19 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.cc @@ -27,9 +27,9 @@ limitations under the License. #include "xla/stream_executor/blas.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/protobuf/dnn.pb.h" #if GOOGLE_CUDA #include "tsl/platform/tensor_float_32_utils.h" #endif @@ -46,12 +46,16 @@ absl::StatusOr AsBlasDataType(PrimitiveType dtype) { switch (dtype) { case PrimitiveType::F8E5M2: return DataType::kF8E5M2; + case PrimitiveType::F8E4M3: + return DataType::kF8E4M3; case PrimitiveType::F8E4M3FN: return DataType::kF8E4M3FN; case PrimitiveType::F8E5M2FNUZ: return DataType::kF8E5M2FNUZ; case PrimitiveType::F8E4M3FNUZ: return DataType::kF8E4M3FNUZ; + case PrimitiveType::F8E3M4: + return DataType::kF8E3M4; case PrimitiveType::S8: return DataType::kInt8; case PrimitiveType::F16: @@ -79,12 +83,16 @@ absl::StatusOr AsXlaPrimitiveType(DataType dtype) { switch (dtype) { case DataType::kF8E5M2: return PrimitiveType::F8E5M2; + case DataType::kF8E4M3: + return PrimitiveType::F8E4M3; case DataType::kF8E4M3FN: return PrimitiveType::F8E4M3FN; case DataType::kF8E5M2FNUZ: return PrimitiveType::F8E5M2FNUZ; case DataType::kF8E4M3FNUZ: return PrimitiveType::F8E4M3FNUZ; + case DataType::kF8E3M4: + return PrimitiveType::F8E3M4; case DataType::kInt8: return PrimitiveType::S8; case DataType::kHalf: @@ -141,9 +149,11 @@ absl::StatusOr GetBlasComputationType( if (algorithm == xla::PrecisionConfig::ALG_UNSET) { switch (output_dtype) { case PrimitiveType::F8E5M2: // fall-through + case PrimitiveType::F8E4M3: // fall-through case PrimitiveType::F8E4M3FN: // fall-through case PrimitiveType::F8E5M2FNUZ: // fall-through case PrimitiveType::F8E4M3FNUZ: // fall-through + case PrimitiveType::F8E3M4: // fall-through case PrimitiveType::F16: // fall-through case PrimitiveType::BF16: // Accumulate in f32 precision. diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index 4590e45ecb5ac2..1363ce564438bc 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -25,6 +25,8 @@ limitations under the License. #include #include +#include "xla/stream_executor/stream.h" + #if GOOGLE_CUDA #include "third_party/gpus/cuda/include/cuda.h" #endif @@ -47,8 +49,6 @@ limitations under the License. #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/stream_executor/typed_kernel_factory.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" @@ -61,17 +61,6 @@ namespace stream_executor::gpu { // Implementation details device kernels required by GpuCommandBuffer. //===----------------------------------------------------------------------===// -// See device specific implementations. These are -// various kernels that update Gpu conditionals based on the device memory -// values, and allow implementing on-device control flow via conditional command -// buffers. -absl::StatusOr GetSetIfConditionKernelLoaderSpec(); -absl::StatusOr GetSetIfElseConditionKernelLoaderSpec(); -absl::StatusOr GetSetCaseConditionKernelLoaderSpec(); -absl::StatusOr GetSetForConditionKernelLoaderSpec(); -absl::StatusOr GetSetWhileConditionKernelLoaderSpec(); -absl::StatusOr GetNoOpKernelLoaderSpec(); - using Mode = CommandBuffer::Mode; using State = CommandBuffer::State; @@ -226,71 +215,6 @@ GpuCommandBuffer::Dependencies GpuCommandBuffer::GetBarrier( : Dependencies{execution_scope.barriers.back().handle}; } -absl::StatusOr -GpuCommandBuffer::GetSetIfConditionKernel() { - if (!set_if_condition_kernel_) { - TF_ASSIGN_OR_RETURN(auto spec, GetSetIfConditionKernelLoaderSpec()); - TF_ASSIGN_OR_RETURN( - set_if_condition_kernel_, - SetIfConditionKernel::FactoryType::Create(parent_, spec)); - } - return &set_if_condition_kernel_; -} - -absl::StatusOr -GpuCommandBuffer::GetSetIfElseConditionKernel() { - if (!set_if_else_condition_kernel_) { - TF_ASSIGN_OR_RETURN(auto spec, GetSetIfElseConditionKernelLoaderSpec()); - TF_ASSIGN_OR_RETURN( - set_if_else_condition_kernel_, - SetIfElseConditionKernel::FactoryType::Create(parent_, spec)); - } - return &set_if_else_condition_kernel_; -} - -absl::StatusOr -GpuCommandBuffer::GetSetCaseConditionKernel() { - if (!set_case_condition_kernel_) { - TF_ASSIGN_OR_RETURN(auto spec, GetSetCaseConditionKernelLoaderSpec()); - TF_ASSIGN_OR_RETURN( - set_case_condition_kernel_, - SetCaseConditionKernel::FactoryType::Create(parent_, spec)); - } - return &set_case_condition_kernel_; -} - -absl::StatusOr -GpuCommandBuffer::GetSetForConditionKernel() { - if (!set_for_condition_kernel_) { - TF_ASSIGN_OR_RETURN(auto spec, GetSetForConditionKernelLoaderSpec()); - TF_ASSIGN_OR_RETURN( - set_for_condition_kernel_, - SetForConditionKernel::FactoryType::Create(parent_, spec)); - } - return &set_for_condition_kernel_; -} - -absl::StatusOr -GpuCommandBuffer::GetSetWhileConditionKernel() { - if (!set_while_condition_kernel_) { - TF_ASSIGN_OR_RETURN(auto spec, GetSetWhileConditionKernelLoaderSpec()); - TF_ASSIGN_OR_RETURN( - set_while_condition_kernel_, - SetWhileConditionKernel::FactoryType::Create(parent_, spec)); - } - return &set_while_condition_kernel_; -} - -absl::StatusOr -GpuCommandBuffer::GetNoOpKernel() { - if (!noop_kernel_) { - TF_ASSIGN_OR_RETURN(auto spec, GetNoOpKernelLoaderSpec()); - TF_ASSIGN_OR_RETURN(noop_kernel_, - NoOpKernel::FactoryType::Create(parent_, spec)); - } - return &noop_kernel_; -} - absl::Status GpuCommandBuffer::DisableBarriersExecution( GpuGraphExecHandle exec) { #if !defined(TENSORFLOW_USE_ROCM) @@ -691,14 +615,8 @@ GpuCommandBuffer::CreateConditionalCommandBuffers( absl::Span builders) { std::vector> cmd_buffers; - // Conditional command buffers always created in nested mode and with - // underlying graphs owned by a conditional node. - CommandBuffer::Mode nested = CommandBuffer::Mode::kNested; - bool is_owned_graph = false; - for (size_t i = 0; i < handles.size(); ++i) { - auto command_buffer = std::make_unique( - nested, parent_, graphs[i], is_owned_graph); + auto command_buffer = CreateNestedCommandBuffer(graphs[i]); TF_RETURN_IF_ERROR(builders[i](command_buffer.get(), handles[i])); TF_RETURN_IF_ERROR(command_buffer->Finalize()); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h index 4627e9a2bb262c..7f9a870aa80ce2 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -163,6 +162,7 @@ class GpuCommandBuffer : public CommandBuffer { private: using Dependencies = absl::InlinedVector; + protected: using NoOpKernel = TypedKernel<>; // A signature of a device kernels updating conditional handle(s). @@ -186,6 +186,7 @@ class GpuCommandBuffer : public CommandBuffer { using SetWhileConditionKernel = TypedKernel>; + private: // A callback to launch a kernel that updates conditional handles state. using SetConditionFn = std::function)>; @@ -250,12 +251,15 @@ class GpuCommandBuffer : public CommandBuffer { // Returns loaded auxiliary kernels, or loads them on a given stream executor. // Loaded kernels owned by a current command buffer. - absl::StatusOr GetSetIfConditionKernel(); - absl::StatusOr GetSetIfElseConditionKernel(); - absl::StatusOr GetSetCaseConditionKernel(); - absl::StatusOr GetSetForConditionKernel(); - absl::StatusOr GetSetWhileConditionKernel(); - absl::StatusOr GetNoOpKernel(); + virtual absl::StatusOr GetSetIfConditionKernel() = 0; + virtual absl::StatusOr + GetSetIfElseConditionKernel() = 0; + virtual absl::StatusOr + GetSetCaseConditionKernel() = 0; + virtual absl::StatusOr GetSetForConditionKernel() = 0; + virtual absl::StatusOr + GetSetWhileConditionKernel() = 0; + virtual absl::StatusOr GetNoOpKernel() = 0; // Recursively disable all nodes corresponding to barriers (including nested // conditional command buffers). This is work around the fact that we can't @@ -342,14 +346,10 @@ class GpuCommandBuffer : public CommandBuffer { // Track the number of command buffer updates for debugging. int64_t num_updates_ = 0; - // Lazy loaded auxiliary kernels required for building CUDA graphs (no-op - // barriers, updating conditional handles, etc.). - SetIfConditionKernel set_if_condition_kernel_; - SetIfElseConditionKernel set_if_else_condition_kernel_; - SetCaseConditionKernel set_case_condition_kernel_; - SetForConditionKernel set_for_condition_kernel_; - SetWhileConditionKernel set_while_condition_kernel_; - NoOpKernel noop_kernel_; + // Creates a nested command buffer, associated with the same executor. + // The given graph will not be owned by the created command buffer. + virtual std::unique_ptr CreateNestedCommandBuffer( + GpuGraphHandle graph) = 0; }; } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc b/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc index 908d02570538e2..45df35b6d9d536 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc @@ -32,9 +32,17 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +<<<<<<< HEAD #include "xla/stream_executor/gpu/gpu_init.h" // IWYU pragma: keep #include "xla/stream_executor/stream_executor.h" // IWYU pragma: keep #include "xla/stream_executor/gpu/scoped_activate_context.h" +======= +#include "absl/synchronization/mutex.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "xla/stream_executor/activate_context.h" +#include "xla/stream_executor/cuda/cuda_status.h" +#include "xla/stream_executor/gpu/gpu_init.h" +>>>>>>> master #include "xla/tsl/framework/allocator.h" #include "xla/tsl/framework/device_id.h" #include "xla/tsl/util/env_var.h" // IWYU pragma: keep @@ -150,7 +158,7 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator( << "Failed to retain context: " << GetCudaErrorMessage(result); } - gpu::ScopedActivateContext scoped_activation{stream_exec_}; + std::unique_ptr scoped_activation = stream_exec_->Activate(); // Check the CUDA runtime is recent enough. if (auto status2 = cuDriverGetVersion(&driverVersion)) { @@ -336,7 +344,7 @@ void* GpuCudaMallocAsyncAllocator::AllocateRaw(size_t alignment, if (stats_) { lock.lock(); } - gpu::ScopedActivateContext scoped_activation{stream_exec_}; + std::unique_ptr scoped_activation = stream_exec_->Activate(); void* ptr = nullptr; auto result = cuMemAllocFromPoolAsync(reinterpret_cast(&ptr), num_bytes, pool_, cuda_stream_); @@ -406,7 +414,8 @@ void GpuCudaMallocAsyncAllocator::DeallocateRaw(void* ptr) { VLOG(1) << "Ignoring CUDA error: " << GetCudaErrorMessage(result); } else { size_t free, total; - gpu::ScopedActivateContext scoped_activation{stream_exec_}; + std::unique_ptr scoped_activation = + stream_exec_->Activate(); cuMemGetInfo(&free, &total); LOG(ERROR) << "cudaFreeAsync failed to free " << ptr << ": " << GetCudaErrorMessage(result) diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h index 447201e63796e1..3f8753119f6e7b 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h @@ -57,138 +57,6 @@ namespace gpu { // Thread safety: these functions should not be used from signal handlers. class GpuDriver { public: - // Wraps a call to cuInit/hipInit with logging to help indicate what has gone - // wrong in the case of failure. Safe to call multiple times; will be fast on - // all calls after the first. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__INITIALIZE.html#group__CUDA__INITIALIZE_1g0a2f1517e1bd8502c7194c3a8c134bc3 - // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#initialization - static absl::Status Init(); - - // Creates a new CUDA/HIP stream associated with the given context via - // cuStreamCreate/hipStreamCreateWithFlags. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1ga581f0c5833e21ded8b5a56594e243f4 - // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#stream-management - static absl::StatusOr CreateStream(Context* context, - int priority = 0); - - // Destroys a CUDA/HIP stream associated with the given context. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g244c8833de4596bcd31a06cdf21ee758 - // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#stream-management - static void DestroyStream(Context* context, GpuStreamHandle stream); - - // CUDA/HIP events can explicitly disable event TSC retrieval for some - // presumed performance improvement if timing is unnecessary. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g450687e75f3ff992fe01662a43d9d3db - // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#cuda-driver-data-types - enum class EventFlags { kDefault, kDisableTiming }; - - // Creates a new event associated with the given context. - // result is an outparam owned by the caller and must not be null. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g450687e75f3ff992fe01662a43d9d3db - // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#cuda-driver-data-types - static absl::Status InitEvent(Context* context, GpuEventHandle* result, - EventFlags flags); - - // Destroys *event and turns it into a nullptr. event may not be null, but - // *event may be, via cuEventDestroy/hipEventDestroy - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g593ec73a8ec5a5fc031311d3e4dca1ef - // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#event-management - static absl::Status DestroyEvent(Context* context, GpuEventHandle* event); - - // Allocates a GPU memory space of size bytes associated with the given - // context via cuMemAlloc/hipMalloc. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gb82d2a09844a58dd9e744dc31e8aa467 - // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#memory-management - static void* DeviceAllocate(Context* context, uint64_t bytes); - - // Deallocates a GPU memory space of size bytes associated with the given - // context via cuMemFree/hipFree. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g89b3f154e17cc89b6eea277dbdf5c93a - // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#memory-management - static void DeviceDeallocate(Context* context, void* location); - - // Allocates a unified memory space of size bytes associated with the given - // context via cuMemAllocManaged. - // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gb347ded34dc326af404aa02af5388a32 - // (supported on CUDA only) - static void* UnifiedMemoryAllocate(Context* context, uint64_t bytes); - - // Deallocates a unified memory space of size bytes associated with the given - // context via cuMemFree. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g89b3f154e17cc89b6eea277dbdf5c93a - // (supported on CUDA only) - static void UnifiedMemoryDeallocate(Context* context, void* location); - - // Allocates page-locked and CUDA-registered memory on the host via - // cuMemAllocHost/hipHostMalloc. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gdd8311286d2c2691605362c689bc64e0 - // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#memory-management - static void* HostAllocate(Context* context, uint64_t bytes); - - // Deallocates a location created by HostAllocate, via - // cuMemFreeHost/hipHostFree. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g62e0fdbe181dab6b1c90fa1a51c7b92c - // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#memory-management - static void HostDeallocate(Context* context, void* location); - - // Queries the priority range and returns the corresponding integer value via - // cuCtxGetStreamPriorityRange/hipDeviceGetStreamPriorityRange - // - // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g137920ab61a71be6ce67605b9f294091 - // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#context-management - static int GetGpuStreamPriority( - Context* context, stream_executor::StreamPriority stream_priority); - - // Given a device ordinal, returns a device handle into the device outparam, - // which must not be null. - // - // N.B. these device handles do not have a corresponding destroy function in - // the CUDA/HIP driver API. - static absl::Status GetDevice(int device_ordinal, GpuDeviceHandle* device); - - // Given a device handle, returns the name reported by the driver for the - // device. - static absl::Status GetDeviceName(GpuDeviceHandle device, - std::string* device_name); - - // Given a device to create a context for, returns a context handle into the - // context outparam, which must not be null. - // - // N.B. CUDA contexts are weird. They are implicitly associated with the - // calling thread. Current documentation on contexts and their influence on - // userspace processes is given here: - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g65dc0012348bc84810e2103a40d8e2cf - static absl::Status CreateContext(int device_ordinal, GpuDeviceHandle device, - Context** context); - - // Destroys the provided context via cuCtxDestroy. - // Don't do this while clients could still be using the context, per the docs - // bad things will happen. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g27a365aebb0eb548166309f58a1e8b8e - static void DestroyContext(Context* context); - - // Launches a CUDA/ROCm kernel via cuLaunchKernel/hipModuleLaunchKernel. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1gb8f3dc3031b40da29d5f9a7139e52e15 - // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#execution-control - static absl::Status LaunchKernel( - Context* context, absl::string_view kernel_name, - GpuFunctionHandle function, unsigned int grid_dim_x, - unsigned int grid_dim_y, unsigned int grid_dim_z, - unsigned int block_dim_x, unsigned int block_dim_y, - unsigned int block_dim_z, unsigned int shared_mem_bytes, - GpuStreamHandle stream, void** kernel_params, void** extra); - - // Launches a CUDA/ROCm kernel via cuLaunchKernelEx/hipModuleLaunchKernelEx. - // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1gb9c891eb6bb8f4089758e64c9c976db9 - static absl::Status LaunchKernel( - Context* context, absl::string_view kernel_name, - GpuFunctionHandle function, unsigned int cluster_dim_x, - unsigned int cluster_dim_y, unsigned int cluster_dim_z, - unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, - unsigned int block_dim_x, unsigned int block_dim_y, - unsigned int block_dim_z, unsigned int shared_mem_bytes, - GpuStreamHandle stream, void** kernel_params, void** extra); - // Creates a new GPU graph. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1gd885f719186010727b75c3315f865fdf // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management @@ -304,16 +172,6 @@ class GpuDriver { GpuGraphHandle graph, const char* path, bool return_printed_graph = false); - // Returns a stream's capture status. - // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g37823c49206e3704ae23c7ad78560bca - // https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#stream-management - static absl::StatusOr StreamIsCapturing(GpuStreamHandle stream); - - // Free unused memory that was cached on the specified device for use with - // graphs back to the OS. - // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g57c87f4ba6af41825627cdd4e5a8c52b - static absl::Status DeviceGraphMemTrim(GpuDeviceHandle device); - // Creates a conditional handle. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1gece6f3b9e85d0edb8484d625fe567376 static absl::Status GraphConditionalHandleCreate( @@ -424,101 +282,6 @@ class GpuDriver { GpuGraphNodeHandle node, GpuGraphHandle child); - // Loads ptx_contents with the CUDA driver's PTX JIT and stores the resulting - // handle in "module". Any error logs that are produced are logged internally. - // (supported on CUDA only) - static absl::Status LoadPtx(Context* context, const char* ptx_contents, - GpuModuleHandle* module); - - // Loads cubin_bytes with the CUDA driver's blob loading interface and stores - // the resulting handle in "module". - // (supported on CUDA only) - static absl::Status LoadCubin(Context* context, const char* cubin_bytes, - GpuModuleHandle* module); - - // Loads HSACO with the ROCM runtime and stores the resulting handle in - // "module". Any error logs that are produced are logged internally. - // (supported on ROCm only) - static absl::Status LoadHsaco(Context* context, const char* hsaco_contents, - GpuModuleHandle* module); - - // Retrieves a named kernel from a loaded module, and places the resulting - // handle into function (outparam) on success. Neither kernel_name nor - // function may be null. No ownership is taken of kernel_name. - static absl::Status GetModuleFunction(Context* context, - GpuModuleHandle module, - const char* kernel_name, - GpuFunctionHandle* function); - - // Retrieves a named global/constant symbol from a loaded module, and returns - // a device pointer and size of the symbol on success. symbol_name may not be - // null. At least one of dptr or bytes should not be null. No ownership is - // taken of symbol_name. - static absl::Status GetModuleSymbol(Context* context, GpuModuleHandle module, - const char* symbol_name, - GpuDevicePtr* dptr, size_t* bytes); - - // Unloads module from the current context via cuModuleUnload. - // TODO(leary) the documentation doesn't say what kind of disasters happen - // if you try to unload a module while its GpuFunctionHandles are in use. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MODULE.html#group__CUDA__MODULE_1g8ea3d716524369de3763104ced4ea57b - static void UnloadModule(Context* context, GpuModuleHandle module); - - // Performs a synchronous memset of the device memory segment via cuMemsetD8. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g6e582bf866e9e2fb014297bfaf354d7b - static absl::Status SynchronousMemsetUint8(Context* context, - GpuDevicePtr location, - uint8_t value, size_t size); - - // Performs a synchronous memset of the device memory segment via cuMemsetD32. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g983e8d8759acd1b64326317481fbf132 - static absl::Status SynchronousMemsetUint32(Context* context, - GpuDevicePtr location, - uint32_t value, - size_t uint32_count); - - // Performs an asynchronous memset of the device memory segment via - // cuMemsetD8Async. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gaef08a7ccd61112f94e82f2b30d43627 - static absl::Status AsynchronousMemsetUint8(Context* context, - GpuDevicePtr location, - uint8_t value, - size_t uint32_count, - GpuStreamHandle stream); - - // Performs an asynchronous memset of the device memory segment via - // cuMemsetD32Async. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g58229da5d30f1c0cdf667b320ec2c0f5 - static absl::Status AsynchronousMemsetUint32(Context* context, - GpuDevicePtr location, - uint32_t value, - size_t uint32_count, - GpuStreamHandle stream); - - // -- Synchronous memcopies. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g4d32266788c440b0220b1a9ba5795169 - - static absl::Status SynchronousMemcpyD2H(Context* context, void* host_dst, - GpuDevicePtr gpu_src, uint64_t size); - static absl::Status SynchronousMemcpyH2D(Context* context, - GpuDevicePtr gpu_dst, - const void* host_src, uint64_t size); - - // -- Asynchronous memcopies. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g56f30236c7c5247f8e061b59d3268362 - - static absl::Status AsynchronousMemcpyD2H(Context* context, void* host_dst, - GpuDevicePtr gpu_src, uint64_t size, - GpuStreamHandle stream); - static absl::Status AsynchronousMemcpyH2D(Context* context, - GpuDevicePtr gpu_dst, - const void* host_src, uint64_t size, - GpuStreamHandle stream); - static absl::Status AsynchronousMemcpyD2D(Context* context, - GpuDevicePtr gpu_dst, - GpuDevicePtr gpu_src, uint64_t size, - GpuStreamHandle stream); - // The CUDA stream callback type signature. // The data passed to AddStreamCallback is subsequently passed to this // callback when it fires. @@ -530,20 +293,6 @@ class GpuDriver { // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1gab95a78143bae7f21eebb978f91e7f3f typedef void (*StreamCallback)(void* data); - // Enqueues a callback operation into stream. - // See StreamCallback above and the NVIDIA documentation for additional - // details. - static absl::Status AddStreamCallback(Context* context, - GpuStreamHandle stream, - StreamCallback callback, void* data); - - // Causes stream to wait for event to trigger before proceeding via - // cuStreamWaitEvent. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#axzz334nAXAhM - static absl::Status WaitStreamOnEvent(Context* context, - GpuStreamHandle stream, - GpuEventHandle event); - // Blocks the calling thread until the operations enqueued onto stream have // been completed, via cuStreamSynchronize. // @@ -555,141 +304,6 @@ class GpuDriver { static absl::Status SynchronizeStream(Context* context, GpuStreamHandle stream); - // Blocks the calling thread until the operations associated with the context - // have been completed, via cuCtxSynchronize. - // - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g7a54725f28d34b8c6299f0c6ca579616 - static absl::Status SynchronizeContext(Context* context); - - // Returns whether code in the from context can access memory in the to - // context via cuDeviceCanAccessPeer. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__PEER__ACCESS.html#group__CUDA__PEER__ACCESS_1g496bdaae1f632ebfb695b99d2c40f19e - static bool CanEnablePeerAccess(Context* from, Context* to); - - // Returns whether the from device can access memory in the to - // device via cuDeviceCanAccessPeer. Because of differences between ROCM and - // CUDA, this API is not supported in ROCM builds and will result in a link - // error if used. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__PEER__ACCESS.html#group__CUDA__PEER__ACCESS_1g496bdaae1f632ebfb695b99d2c40f19e - static bool CanEnablePeerAccess(GpuDeviceHandle from, GpuDeviceHandle to); - - // Enables peer access per CanEnablePeerAccess, via cuCtxEnablePeerAccess. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__PEER__ACCESS.html#group__CUDA__PEER__ACCESS_1g0889ec6728e61c05ed359551d67b3f5a - static absl::Status EnablePeerAccess(Context* from, Context* to); - - // Returns the elapsed milliseconds between start and stop via - // cuEventElapsedTime. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1gdfb1178807353bbcaa9e245da497cf97 - static absl::StatusOr GetEventElapsedTime(Context* context, - GpuEventHandle start, - GpuEventHandle stop); - - // Records that an event occurred when execution reaches the current point in - // thestream via cuEventRecord. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g95424d3be52c4eb95d83861b70fb89d1 - static absl::Status RecordEvent(Context* context, GpuEventHandle event, - GpuStreamHandle stream); - - // -- Pointer-specific calls. - - // Returns the memory space addressed by pointer. - static absl::StatusOr GetPointerMemorySpace(GpuDevicePtr pointer); - - // Returns the base address and size of the device pointer dptr. - static absl::Status GetPointerAddressRange(GpuDevicePtr dptr, - GpuDevicePtr* base, size_t* size); - - // -- Device-specific calls. - - // Returns the compute capability for the device; i.e (3, 5). - // This is currently done via the deprecated device API. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html#group__CUDA__DEVICE__DEPRECATED_1ge2091bbac7e1fb18c2821612115607ea - // (supported on CUDA only) - static absl::Status GetComputeCapability(int* cc_major, int* cc_minor, - GpuDeviceHandle device); - - // Returns Gpu ISA version for the device; i.e 803, 900. - // (supported on ROCm only) - static absl::Status GetGpuISAVersion(int* version, GpuDeviceHandle device); - - // Return the full GCN Architecture Name for the device - // for eg: amdgcn-amd-amdhsa--gfx908:sramecc+:xnack- - // (supported on ROCm only) - static absl::Status GetGpuGCNArchName(GpuDeviceHandle device, - std::string* gcnArchName); - - // Returns the number of multiprocessors on the device (note that the device - // may be multi-GPU-per-board). - static absl::StatusOr GetMultiprocessorCount(GpuDeviceHandle device); - - // Returns the limit on number of threads that can be resident in a single - // multiprocessor. - static absl::StatusOr GetMaxThreadsPerMultiprocessor( - GpuDeviceHandle device); - - // Returns the amount of shared memory available on a single GPU core (i.e. - // SM on NVIDIA devices). - static absl::StatusOr GetMaxSharedMemoryPerCore( - GpuDeviceHandle device); - - // Returns the amount of static shared memory available for a single block - // (cooperative thread array). - static absl::StatusOr GetMaxSharedMemoryPerBlock( - GpuDeviceHandle device); - - // Returns the total amount of shared memory available for a single block - // (cooperative thread array). - static absl::StatusOr GetMaxSharedMemoryPerBlockOptin( - GpuDeviceHandle device); - - // Returns the maximum supported number of registers per block. - static absl::StatusOr GetMaxRegistersPerBlock( - GpuDeviceHandle device); - - // Returns the number of threads per warp. - static absl::StatusOr GetThreadsPerWarp(GpuDeviceHandle device); - - // Queries the grid limits for device with cuDeviceGetAttribute calls. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE.html#group__CUDA__DEVICE_1g9c3e1414f0ad901d3278a4d6645fc266 - static absl::Status GetGridLimits(int* x, int* y, int* z, - GpuDeviceHandle device); - - // Returns a grab-bag of device properties in a caller-owned device_properties - // structure for device_ordinal via cuDeviceGetProperties. - // - // This call is deprecated in the NVIDIA driver API; its replacement is - // GetDeviceAttribute - // - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html#group__CUDA__DEVICE__DEPRECATED_1g65a5b4e25186bd257df80b98c98cffe6 - static bool GetDeviceProperties(GpuDeviceProperty* device_properties, - int device_ordinal); - - // Gets a specific integer-valued property about the given device. - // - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE.html#group__CUDA__DEVICE_1g9c3e1414f0ad901d3278a4d6645fc266 - static absl::StatusOr GetDeviceAttribute(GpuDeviceAttribute attribute, - GpuDeviceHandle device); - - // Returns whether ECC is enabled for the given GpuDeviceHandle via - // cuDeviceGetattribute with CU_DEVICE_ATTRIBUTE_ECC_ENABLED. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE.html#group__CUDA__DEVICE_1g9c3e1414f0ad901d3278a4d6645fc266 - static bool IsEccEnabled(GpuDeviceHandle device, bool* result); - - // Returns the total amount of memory available for allocation by the CUDA - // context, in bytes, via cuDeviceTotalMem. - static bool GetDeviceTotalMemory(GpuDeviceHandle device, uint64_t* result); - - // Returns the free amount of memory and total amount of memory, as reported - // by cuMemGetInfo. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g808f555540d0143a331cc42aa98835c0 - static bool GetDeviceMemoryInfo(Context* context, int64_t* free, - int64_t* total); - - // Returns a PCI bus id string for the device. - // [domain]:[bus]:[device].[function] - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g85295e7d9745ab8f0aa80dd1e172acfc - static std::string GetPCIBusID(GpuDeviceHandle device); - // -- Context- and device-independent calls. // Returns the number of visible CUDA device via cuDeviceGetCount. @@ -705,16 +319,6 @@ class GpuDriver { // // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__VERSION.html#group__CUDA__VERSION_1g8b7a10395392e049006e61bcdc8ebe71 static absl::StatusOr GetDriverVersion(); - - // -- Other calls - - // Returns the maximum number of blocks (per multiprocessor) occupied by the - // specified kernel/GpuFunctionHandle when launched with the specified - // parameters. - // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__OCCUPANCY.html#group__CUDA__OCCUPANCY_1gcc6e1094d05cba2cee17fe33ddd04a98 - static absl::StatusOr GetMaxOccupiedBlocksPerCore( - Context* context, GpuFunctionHandle kernel, int threads_per_block, - size_t dynamic_shared_memory_bytes); }; } // namespace gpu diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_event.cc b/third_party/xla/xla/stream_executor/gpu/gpu_event.cc deleted file mode 100644 index 12a8a99b842f44..00000000000000 --- a/third_party/xla/xla/stream_executor/gpu/gpu_event.cc +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/stream_executor/gpu/gpu_event.h" - -#include - -#include "absl/base/casts.h" -#include "absl/status/status.h" -#include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_types.h" - -namespace stream_executor { -namespace gpu { - -GpuEvent::GpuEvent(Context* context) : context_(context), gpu_event_(nullptr) {} - -GpuEvent::~GpuEvent() { Destroy().IgnoreError(); } - -absl::Status GpuEvent::Init(bool allow_timing) { - return GpuDriver::InitEvent(context_, &gpu_event_, - allow_timing - ? GpuDriver::EventFlags::kDefault - : GpuDriver::EventFlags::kDisableTiming); -} - -absl::Status GpuEvent::Destroy() { - return GpuDriver::DestroyEvent(context_, &gpu_event_); -} - -absl::Status GpuEvent::Record(GpuStreamHandle stream_handle) { - return GpuDriver::RecordEvent(context_, gpu_event_, stream_handle); -} - -GpuEventHandle GpuEvent::gpu_event() { return gpu_event_; } - -absl::Status GpuEvent::WaitForEventOnExternalStream(std::intptr_t stream) { - return GpuDriver::WaitStreamOnEvent( - context_, absl::bit_cast(stream), gpu_event_); -} - -} // namespace gpu -} // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_event.h b/third_party/xla/xla/stream_executor/gpu/gpu_event.h deleted file mode 100644 index f299cd90d3c562..00000000000000 --- a/third_party/xla/xla/stream_executor/gpu/gpu_event.h +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_EVENT_H_ -#define XLA_STREAM_EXECUTOR_GPU_GPU_EVENT_H_ - -#include - -#include "absl/status/status.h" -#include "xla/stream_executor/event.h" -#include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_types.h" - -namespace stream_executor { -namespace gpu { - -class GpuContext; - -// GpuEvent wraps a GpuEventHandle in the platform-independent Event interface. -class GpuEvent : public Event { - public: - explicit GpuEvent(Context* context); - - ~GpuEvent() override; - - // Populates the CUDA-platform-specific elements of this object. - absl::Status Init(bool allow_timing); - - // Deallocates any platform-specific elements of this object. This is broken - // out (not part of the destructor) to allow for error reporting. - absl::Status Destroy(); - - // Inserts the event at the current position into the specified stream. - absl::Status Record(GpuStreamHandle stream_handle); - - // The underlying CUDA event element. - GpuEventHandle gpu_event(); - - absl::Status WaitForEventOnExternalStream(std::intptr_t stream) override; - - protected: - Context* context() const { return context_; } - - private: - // The Executor used to which this object and GpuEventHandle are bound. - Context* context_; - - // The underlying CUDA event element. - GpuEventHandle gpu_event_; -}; - -} // namespace gpu -} // namespace stream_executor - -#endif // XLA_STREAM_EXECUTOR_GPU_GPU_EVENT_H_ diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h index d515e1099488fb..0d90dadf84c9d6 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h @@ -17,7 +17,6 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_GPU_GPU_EXECUTOR_H_ #include -#include #include #include #include @@ -26,20 +25,14 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" -#include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/gpu/context.h" #include "xla/stream_executor/host_memory_allocation.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_common.h" namespace stream_executor { -class StreamExecutor; - namespace gpu { class GpuStream; @@ -55,14 +48,6 @@ class GpuExecutor : public StreamExecutorCommon { int device_ordinal() const override { return device_ordinal_; }; - // Releases any state associated with the previously loaded kernel. - virtual void UnloadKernel(const Kernel* kernel) = 0; - // Creates an EventBasedTimer for the given stream. - virtual absl::StatusOr> - CreateEventBasedTimer(GpuStream* stream, bool use_delay_kernel) = 0; - static absl::StatusOr> - CreateDeviceDescription(int device_ordinal); - // Frees unused memory cached on the device for use with graphs back to the // OS. virtual absl::Status TrimGraphMemory() = 0; @@ -113,10 +98,6 @@ class GpuExecutor : public StreamExecutorCommon { void operator=(const GpuExecutor&) = delete; }; -inline GpuExecutor* ExtractGpuExecutor(StreamExecutor* stream_exec) { - return static_cast(stream_exec); -} - } // namespace gpu } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h b/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h index fb3439c507d1cb..c6714cce9def47 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h @@ -22,51 +22,16 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_KERNEL_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_KERNEL_H_ -#include -#include -#include -#include - -#include "absl/status/statusor.h" -#include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/launch_dim.h" -#include "tsl/platform/logging.h" namespace stream_executor::gpu { +// A GpuKernel is a `Kernel` that can be launched on a GPU. It allows +// access to the underlying GPU function through `gpu_function()`. class GpuKernel : public Kernel { public: - explicit GpuKernel(GpuExecutor* gpu_executor) - : gpu_executor_(gpu_executor), - gpu_context_(gpu_executor->gpu_context()) {} - - // Note that the function is unloaded when the module is unloaded, and the - // module that the function is contained in is owned by the GpuExecutor. - ~GpuKernel() override { gpu_executor_->UnloadKernel(this); } - - // As arity cannot be reflected upon using the CUDA API, the arity is - // explicitly set during the GpuExecutor::GetKernel initialization process. - void set_arity(unsigned arity) { arity_ = arity; } - unsigned Arity() const override { return arity_; } - - absl::StatusOr GetMaxOccupiedBlocksPerCore( - ThreadDim threads, size_t dynamic_shared_memory_bytes) const override; - - // Simple accessor methods. - GpuFunctionHandle gpu_function() const { return gpu_function_; } - void set_gpu_function(GpuFunctionHandle gpu_function) { - gpu_function_ = gpu_function; - } - - private: - GpuExecutor* gpu_executor_ = nullptr; - Context* gpu_context_ = nullptr; // context where kernel is loaded - - GpuFunctionHandle gpu_function_ = nullptr; // wrapped CUDA kernel handle - unsigned arity_ = 0; // number of formal parameters the kernel takes + virtual GpuFunctionHandle gpu_function() const = 0; }; inline const GpuKernel* AsGpuKernel(const Kernel* kernel) { diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc index fc3984bc3d78a2..ee9b15487bab65 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc @@ -15,242 +15,18 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_stream.h" -#include -#include -#include -#include -#include - -#include "absl/functional/any_invocable.h" +#include "absl/base/casts.h" #include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/strings/str_format.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/event.h" -#include "xla/stream_executor/event_based_timer.h" -#include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_event.h" -#include "xla/stream_executor/gpu/gpu_executor.h" -#include "xla/stream_executor/gpu/gpu_kernel.h" #include "xla/stream_executor/gpu/gpu_types.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/launch_dim.h" -#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" -#include "tsl/profiler/lib/nvtx_utils.h" namespace stream_executor { namespace gpu { -namespace { -void InternalHostCallback(void* data) { - auto* callback = reinterpret_cast*>(data); - std::move (*callback)(); - delete callback; -} -} // namespace - -absl::Status GpuStream::Init() { - int priority = [&]() { - if (std::holds_alternative(stream_priority_)) { - return std::get(stream_priority_); - } - return GpuDriver::GetGpuStreamPriority( - parent_->gpu_context(), std::get(stream_priority_)); - }(); - TF_ASSIGN_OR_RETURN( - gpu_stream_, GpuDriver::CreateStream(parent_->gpu_context(), priority)); - - return absl::OkStatus(); -} - -Stream::PlatformSpecificHandle GpuStream::platform_specific_handle() const { - PlatformSpecificHandle handle; - handle.stream = gpu_stream_; - return handle; -} - -absl::Status GpuStream::Memset32(DeviceMemoryBase* location, uint32_t pattern, - uint64_t size) { - CHECK(reinterpret_cast(location->opaque()) % 4 == 0 && - size % 4 == 0); - return GpuDriver::AsynchronousMemsetUint32( - parent_->gpu_context(), - reinterpret_cast(location->opaque()), pattern, size / 4, - gpu_stream()); -} - -absl::Status GpuStream::MemZero(DeviceMemoryBase* location, uint64_t size) { - if (reinterpret_cast(location->opaque()) % 4 == 0 && - size % 4 == 0) { - return Memset32(location, 0x0, size); - } else { - return GpuDriver::AsynchronousMemsetUint8( - parent_->gpu_context(), - reinterpret_cast(location->opaque()), 0x0, size, - gpu_stream()); - } -} - -absl::Status GpuStream::Memcpy(DeviceMemoryBase* gpu_dst, - const DeviceMemoryBase& gpu_src, uint64_t size) { - return GpuDriver::AsynchronousMemcpyD2D( - parent_->gpu_context(), - reinterpret_cast(const_cast(gpu_dst->opaque())), - reinterpret_cast(const_cast(gpu_src.opaque())), size, - gpu_stream()); -} - -absl::Status GpuStream::Memcpy(DeviceMemoryBase* gpu_dst, const void* host_src, - uint64_t size) { - return GpuDriver::AsynchronousMemcpyH2D( - parent_->gpu_context(), reinterpret_cast(gpu_dst->opaque()), - host_src, size, gpu_stream()); -} - -absl::Status GpuStream::Memcpy(void* host_dst, const DeviceMemoryBase& gpu_src, - uint64_t size) { - return GpuDriver::AsynchronousMemcpyD2H( - parent_->gpu_context(), host_dst, - reinterpret_cast(const_cast(gpu_src.opaque())), size, - gpu_stream()); -} - -absl::Status GpuStream::WaitFor(Stream* other) { - GpuStream* other_gpu = AsGpuStream(other); - - GpuEvent* other_completed_event = other_gpu->completed_event(); - TF_RETURN_IF_ERROR(other_completed_event->Record(other_gpu->gpu_stream())); - - return GpuDriver::WaitStreamOnEvent(parent_->gpu_context(), gpu_stream(), - other_completed_event->gpu_event()); -} - -absl::Status GpuStream::RecordEvent(Event* event) { - return static_cast(event)->Record(gpu_stream_); -} - -absl::Status GpuStream::WaitFor(Event* event) { - return GpuDriver::WaitStreamOnEvent( - parent_->gpu_context(), gpu_stream(), - static_cast(event)->gpu_event()); -} - -absl::Status GpuStream::DoHostCallbackWithStatus( - absl::AnyInvocable callback) { - auto callback_ptr = - new absl::AnyInvocable([cb = std::move(callback)]() mutable { - absl::Status s = std::move(cb)(); - if (!s.ok()) { - LOG(WARNING) << "Host callback failed: " << s; - } - }); - return GpuDriver::AddStreamCallback(parent_->gpu_context(), gpu_stream(), - InternalHostCallback, callback_ptr); -} - -GpuStream::~GpuStream() { - BlockHostUntilDone().IgnoreError(); - parent()->DeallocateStream(this); - - completed_event_.reset(); - GpuDriver::DestroyStream(parent_->gpu_context(), gpu_stream_); -} - -void GpuStream::set_name(absl::string_view name) { - name_ = name; - tsl::profiler::NameStream( - reinterpret_cast(gpu_stream()), name_); -} - -absl::StatusOr> -GpuStream::CreateEventBasedTimer(bool use_delay_kernel) { - return parent_->CreateEventBasedTimer(this, use_delay_kernel); -} - -absl::Status GpuStream::Launch(const ThreadDim& thread_dims, - const BlockDim& block_dims, const Kernel& kernel, - const KernelArgs& args) { - return Launch(thread_dims, block_dims, std::nullopt, kernel, args); -} - -absl::Status GpuStream::Launch(const ThreadDim& thread_dims, - const BlockDim& block_dims, - const ClusterDim& cluster_dims, - const Kernel& kernel, const KernelArgs& args) { - return Launch(thread_dims, block_dims, std::make_optional(cluster_dims), - kernel, args); -} - -absl::Status GpuStream::Launch(const ThreadDim& thread_dims, - const BlockDim& block_dims, - const std::optional& cluster_dims, - const Kernel& kernel, const KernelArgs& args) { - const GpuKernel* gpu_kernel = AsGpuKernel(&kernel); - GpuFunctionHandle function = gpu_kernel->gpu_function(); - - // Launch kernels with packed arguments. - auto launch = [this, &kernel, &cluster_dims, &thread_dims, &block_dims, - &function](const KernelArgsPackedArrayBase& packed) { - int32_t expected_number_of_arguments = - kernel.Arity() + (packed.number_of_shared_bytes() > 0); - - CHECK_EQ(expected_number_of_arguments, packed.number_of_arguments()) - << "Kernel " << kernel.name() << " has " << packed.number_of_arguments() - << " arguments, but expected " << expected_number_of_arguments - << "; arity=" << kernel.Arity() - << "; number_of_shared_bytes=" << packed.number_of_shared_bytes(); - - void** params = const_cast(packed.argument_addresses().data()); - - if (cluster_dims.has_value()) { - return GpuDriver::LaunchKernel( - parent_->gpu_context(), kernel.name(), function, cluster_dims->x, - cluster_dims->y, cluster_dims->z, block_dims.x, block_dims.y, - block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z, - packed.number_of_shared_bytes(), gpu_stream(), params, - /*extra=*/nullptr); - } else { - return GpuDriver::LaunchKernel( - parent_->gpu_context(), kernel.name(), function, block_dims.x, - block_dims.y, block_dims.z, thread_dims.x, thread_dims.y, - thread_dims.z, packed.number_of_shared_bytes(), gpu_stream(), params, - /*extra=*/nullptr); - } - }; - - // If arguments are already packed we can just launch the kernel. - if (auto* packed = DynCast(&args)) { - return launch(*packed); - } - - // For device memory array we rely on a custom kernel arguments packing. - if (auto* device_mem = DynCast(&args)) { - auto& pack = kernel.args_packing(); - if (!pack) { - return absl::InternalError( - "Kernel is missing a custom arguments packing function for device " - "memory arguments array"); - } - - TF_ASSIGN_OR_RETURN(auto packed, pack(kernel, *device_mem)); - return launch(*packed); - } - - return absl::InternalError("Unsupported kernel arguments type"); -} - -GpuStream* AsGpuStream(Stream* stream) { - DCHECK(stream != nullptr); - return static_cast(stream); -} - GpuStreamHandle AsGpuStreamValue(Stream* stream) { DCHECK(stream != nullptr); - return AsGpuStream(stream)->gpu_stream(); + return absl::bit_cast( + stream->platform_specific_handle().stream); } } // namespace gpu diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h index 249fbf78877a4e..ec95ec50e25226 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h @@ -19,114 +19,12 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_GPU_STREAM_H_ #define XLA_STREAM_EXECUTOR_GPU_GPU_STREAM_H_ -#include -#include -#include -#include -#include - -#include "absl/functional/any_invocable.h" -#include "absl/log/check.h" -#include "absl/strings/string_view.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/event.h" -#include "xla/stream_executor/event_based_timer.h" -#include "xla/stream_executor/gpu/gpu_event.h" -#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_types.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/launch_dim.h" -#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" -#include "xla/stream_executor/stream_common.h" namespace stream_executor { namespace gpu { -class GpuExecutor; - -// Wraps a GpuStreamHandle in order to satisfy the platform-independent -// StreamInterface. -// -// Thread-safe post-initialization. -class GpuStream : public StreamCommon { - public: - GpuStream(GpuExecutor* parent, std::unique_ptr completed_event, - std::optional> priority) - : StreamCommon(parent), - parent_(parent), - gpu_stream_(nullptr), - completed_event_(std::move(completed_event)) { - if (priority.has_value()) { - stream_priority_ = priority.value(); - } - } - - // Note: teardown is handled by a parent's call to DeallocateStream. - ~GpuStream() override; - - // Explicitly initialize the CUDA resources associated with this stream. - absl::Status Init(); - - std::variant priority() const override { - return stream_priority_; - } - PlatformSpecificHandle platform_specific_handle() const override; - - // Retrieves an event which indicates that all work enqueued into the stream - // has completed. Ownership of the event is not transferred to the caller, the - // event is owned by this stream. - GpuEvent* completed_event() { return completed_event_.get(); } - - // Returns the GpuStreamHandle value for passing to the CUDA API. - // - // Precond: this GpuStream has been allocated (otherwise passing a nullptr - // into the NVIDIA library causes difficult-to-understand faults). - GpuStreamHandle gpu_stream() const { - DCHECK(gpu_stream_ != nullptr); - return gpu_stream_; - } - - absl::Status WaitFor(Stream* other) override; - absl::Status WaitFor(Event* event) override; - absl::Status RecordEvent(Event* event) override; - absl::Status MemZero(DeviceMemoryBase* location, uint64_t size) override; - absl::Status Memset32(DeviceMemoryBase* location, uint32_t pattern, - uint64_t size) override; - absl::Status Memcpy(DeviceMemoryBase* gpu_dst, const void* host_src, - uint64_t size) override; - absl::Status Memcpy(void* host_dst, const DeviceMemoryBase& gpu_src, - uint64_t size) override; - absl::Status Memcpy(DeviceMemoryBase* gpu_dst, - const DeviceMemoryBase& gpu_src, uint64_t size) override; - absl::Status DoHostCallbackWithStatus( - absl::AnyInvocable callback) override; - - void set_name(absl::string_view name) override; - absl::StatusOr> CreateEventBasedTimer( - bool use_delay_kernel) override; - absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims, - const Kernel& k, const KernelArgs& args) override; - absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims, - const ClusterDim& cluster_dims, const Kernel& k, - const KernelArgs& args) override; - - private: - // Helper method to launch a kernel with optional cluster dimensions. - absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims, - const std::optional& cluster_dims, - const Kernel& kernel, const KernelArgs& args); - - GpuExecutor* parent_; // Executor that spawned this stream. - GpuStreamHandle gpu_stream_; // Wrapped CUDA stream handle. - std::variant stream_priority_; - std::unique_ptr completed_event_; -}; - -// Helper functions to simplify extremely common flows. -// Converts a Stream to the underlying GpuStream implementation. -GpuStream* AsGpuStream(Stream* stream); - // Extracts a GpuStreamHandle from a GpuStream-backed Stream object. GpuStreamHandle AsGpuStreamValue(Stream* stream); } // namespace gpu diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_timer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_timer.cc deleted file mode 100644 index d6b1b73f42d37f..00000000000000 --- a/third_party/xla/xla/stream_executor/gpu/gpu_timer.cc +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/stream_executor/gpu/gpu_timer.h" - -#include -#include - -#include "absl/base/const_init.h" -#include "absl/base/thread_annotations.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/synchronization/mutex.h" -#include "absl/time/time.h" -#include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_event.h" -#include "xla/stream_executor/gpu/gpu_semaphore.h" -#include "xla/stream_executor/gpu/gpu_stream.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" - -namespace stream_executor { -namespace gpu { - -namespace { - -bool return_random_durations = false; - -absl::Duration RandomDuration() { - static absl::Mutex mu(absl::kConstInit); - static std::mt19937 rng ABSL_GUARDED_BY(mu); - std::uniform_real_distribution distribution(10, 1000); - absl::MutexLock l(&mu); - return absl::Microseconds(distribution(rng)); -} - -} // namespace - -void GpuTimer::ReturnRandomDurationsForTesting() { - return_random_durations = true; -} - -GpuTimer::~GpuTimer() { - if (semaphore_ && !is_stopped_) { - // Signal the delay kernel that it can exit - *semaphore_ = GpuSemaphoreState::kRelease; - // Wait for the delay kernel to exit before destroying the value that it is - // watching. - absl::Status status = - GpuDriver::SynchronizeStream(context_, stream_->gpu_stream()); - if (!status.ok()) { - LOG(ERROR) << status; - } - } - start_event_.reset(); - stop_event_.reset(); -} - -absl::StatusOr GpuTimer::GetElapsedDuration() { - if (is_stopped_) { - return absl::InternalError("Measuring inactive timer"); - } - TF_RETURN_IF_ERROR(stop_event_->Record(stream_->gpu_stream())); - // If we launched the delay kernel then check if it already timed out. - if (semaphore_) { - if (*semaphore_ == GpuSemaphoreState::kTimedOut) { - // The delay kernel did not achieve the intended result. - LOG(ERROR) << "Delay kernel timed out: measured time has sub-optimal " - "accuracy. There may be a missing warmup execution, please " - "investigate in Nsight Systems."; - } else { - // Signal that the kernel can exit - *semaphore_ = GpuSemaphoreState::kRelease; - } - } - TF_ASSIGN_OR_RETURN( - float elapsed_milliseconds, - GpuDriver::GetEventElapsedTime(context_, start_event_->gpu_event(), - stop_event_->gpu_event())); - is_stopped_ = true; - if (return_random_durations) { - return RandomDuration(); - } - return absl::Milliseconds(elapsed_milliseconds); -} - -} // namespace gpu -} // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_timer.h b/third_party/xla/xla/stream_executor/gpu/gpu_timer.h deleted file mode 100644 index 14002f61bc7478..00000000000000 --- a/third_party/xla/xla/stream_executor/gpu/gpu_timer.h +++ /dev/null @@ -1,99 +0,0 @@ -/* Copyright 2019 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_TIMER_H_ -#define XLA_STREAM_EXECUTOR_GPU_GPU_TIMER_H_ - -#include -#include - -#include "absl/status/statusor.h" -#include "absl/time/time.h" -#include "xla/stream_executor/event_based_timer.h" -#include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_event.h" -#include "xla/stream_executor/gpu/gpu_semaphore.h" - -namespace xla { -namespace gpu { -class DeterminismTest; -} -} // namespace xla - -namespace stream_executor { -namespace gpu { - -class GpuStream; - -// When a timer is created it launches a delay kernel into the given stream and -// queues a start event immediately afterwards. This delay kernel blocks -// execution on the stream until GetElapsedDuration() is called, at which point -// an end event is queued and the delay kernel exits. This allows the device -// execution time of the tasks queued to the stream while the timer is active -// to be measured more accurately. -class GpuTimer : public EventBasedTimer { - public: - GpuTimer(Context* context, std::unique_ptr start_event, - std::unique_ptr stop_event, GpuStream* stream, - GpuSemaphore semaphore = {}) - : context_(context), - start_event_(std::move(start_event)), - stop_event_(std::move(stop_event)), - stream_(stream), - semaphore_(std::move(semaphore)) {} - - GpuTimer(GpuTimer&& other) - : context_(other.context_), - start_event_(std::exchange(other.start_event_, nullptr)), - stop_event_(std::exchange(other.stop_event_, nullptr)), - stream_(other.stream_), - semaphore_(std::move(other.semaphore_)) {} - - GpuTimer& operator=(GpuTimer&& other) { - if (this != &other) { - context_ = other.context_; - start_event_ = std::exchange(other.start_event_, nullptr); - stop_event_ = std::exchange(other.stop_event_, nullptr); - stream_ = other.stream_; - semaphore_ = std::move(other.semaphore_); - } - return *this; - } - - ~GpuTimer() override; - - absl::StatusOr GetElapsedDuration() override; - - private: - Context* context_; - std::unique_ptr start_event_; - std::unique_ptr stop_event_; - GpuStream* stream_; - GpuSemaphore semaphore_; - bool is_stopped_ = false; - - GpuTimer(const GpuTimer&) = delete; - void operator=(const GpuTimer&) = delete; - - // If called, all timers will return random durations instead of the actual - // duration the timer took. Used for testing only. - static void ReturnRandomDurationsForTesting(); - friend class ::xla::gpu::DeterminismTest; -}; - -} // namespace gpu -} // namespace stream_executor - -#endif // XLA_STREAM_EXECUTOR_GPU_GPU_TIMER_H_ diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_types.h b/third_party/xla/xla/stream_executor/gpu/gpu_types.h index a47b5d81d66d28..52a07b6ddf3114 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_types.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_types.h @@ -24,8 +24,6 @@ limitations under the License. #elif TENSORFLOW_USE_ROCM -#define __HIP_DISABLE_CPP_FUNCTIONS__ - #include "rocm/include/hip/hip_runtime.h" #include "rocm/include/hiprand/hiprand.h" @@ -45,14 +43,8 @@ struct UnsupportedGpuFeature {}; #if TENSORFLOW_USE_SYCL using GpuStreamHandle = ::sycl::queue*; -using GpuEventHandle = ::sycl::event*; using GpuFunctionHandle = ::sycl::kernel*; -using GpuDeviceHandle = ::sycl::device*; using GpuDevicePtr = void*; -using GpuDeviceAttribute = UnsupportedGpuFeature; -using GpuDeviceProperty = UnsupportedGpuFeature; -using GpuModuleHandle = ze_module_handle_t; -using GpuFuncCachePreference = UnsupportedGpuFeature; using GpuGraphHandle = UnsupportedGpuFeature; using GpuGraphExecHandle = UnsupportedGpuFeature; using GpuGraphNodeHandle = UnsupportedGpuFeature; @@ -61,14 +53,8 @@ using GpuGraphConditionalHandle = UnsupportedGpuFeature; #elif TENSORFLOW_USE_ROCM using GpuStreamHandle = hipStream_t; -using GpuEventHandle = hipEvent_t; using GpuFunctionHandle = hipFunction_t; -using GpuDeviceHandle = hipDevice_t; using GpuDevicePtr = hipDeviceptr_t; -using GpuDeviceAttribute = hipDeviceAttribute_t; -using GpuDeviceProperty = hipDeviceProp_t; -using GpuModuleHandle = hipModule_t; -using GpuFuncCachePreference = hipFuncCache_t; using GpuGraphHandle = hipGraph_t; using GpuGraphExecHandle = hipGraphExec_t; using GpuGraphNodeHandle = hipGraphNode_t; @@ -76,14 +62,8 @@ using GpuGraphConditionalHandle = UnsupportedGpuFeature; #else // CUDA using GpuStreamHandle = CUstream; -using GpuEventHandle = CUevent; using GpuFunctionHandle = CUfunction; -using GpuDeviceHandle = CUdevice; using GpuDevicePtr = CUdeviceptr; -using GpuDeviceAttribute = CUdevice_attribute; -using GpuDeviceProperty = CUdevprop; -using GpuModuleHandle = CUmodule; -using GpuFuncCachePreference = CUfunc_cache; using GpuGraphHandle = CUgraph; using GpuGraphExecHandle = CUgraphExec; using GpuGraphNodeHandle = CUgraphNode; diff --git a/third_party/xla/xla/stream_executor/gpu/mock_context.h b/third_party/xla/xla/stream_executor/gpu/mock_context.h index 9ca49331f77f44..11b3228a5cdb00 100644 --- a/third_party/xla/xla/stream_executor/gpu/mock_context.h +++ b/third_party/xla/xla/stream_executor/gpu/mock_context.h @@ -28,6 +28,7 @@ class MockContext : public Context { MOCK_METHOD(void, SetActive, (), (override)); MOCK_METHOD(bool, IsActive, (), (const, override)); MOCK_METHOD(int, device_ordinal, (), (const, override)); + MOCK_METHOD(absl::Status, Synchronize, (), (override)); }; } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/gpu/mock_gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/mock_gpu_executor.h new file mode 100644 index 00000000000000..3cd5c2dc65a81a --- /dev/null +++ b/third_party/xla/xla/stream_executor/gpu/mock_gpu_executor.h @@ -0,0 +1,124 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_GPU_MOCK_GPU_EXECUTOR_H_ +#define XLA_STREAM_EXECUTOR_GPU_MOCK_GPU_EXECUTOR_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xla/stream_executor/allocator_stats.h" +#include "xla/stream_executor/blas.h" +#include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/event_based_timer.h" +#include "xla/stream_executor/fft.h" +#include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/memory_allocation.h" +#include "xla/stream_executor/module_spec.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/test.h" + +namespace stream_executor::gpu { + +// Implements StreamExecutor for testing. +class MockGpuExecutor : public GpuExecutor { + public: + using GpuExecutor::GpuExecutor; + MOCK_METHOD(absl::Status, Init, (), (override)); + MOCK_METHOD(int, device_ordinal, (), (const, override)); + MOCK_METHOD(absl::StatusOr>, LoadKernel, + (const MultiKernelLoaderSpec& spec), (override)); + MOCK_METHOD(bool, UnloadModule, (ModuleHandle module_handle), (override)); + MOCK_METHOD(absl::StatusOr, LoadModule, + (const MultiModuleLoaderSpec& spec), (override)); + MOCK_METHOD(absl::StatusOr>, + CreateOrShareConstant, + (Stream * stream, absl::Span content), (override)); + MOCK_METHOD(DeviceMemoryBase, Allocate, (uint64_t size, int64_t memory_space), + (override)); + MOCK_METHOD(void, Deallocate, (DeviceMemoryBase * mem), (override)); + MOCK_METHOD(void*, UnifiedMemoryAllocate, (uint64_t size), (override)); + MOCK_METHOD(void, UnifiedMemoryDeallocate, (void* mem), (override)); + MOCK_METHOD(absl::StatusOr, CollectiveMemoryAllocate, (uint64_t size), + (override)); + MOCK_METHOD(absl::Status, CollectiveMemoryDeallocate, (void* mem), + (override)); + MOCK_METHOD(absl::StatusOr>, + HostMemoryAllocate, (uint64_t size), (override)); + MOCK_METHOD(void, HostMemoryDeallocate, (void* mem), (override)); + MOCK_METHOD(bool, SynchronizeAllActivity, (), (override)); + MOCK_METHOD(absl::Status, SynchronousMemZero, + (DeviceMemoryBase * location, uint64_t size), (override)); + MOCK_METHOD(absl::Status, SynchronousMemcpy, + (DeviceMemoryBase * device_dst, const void* host_src, + uint64_t size), + (override)); + MOCK_METHOD(absl::Status, SynchronousMemcpy, + (void* host_dst, const DeviceMemoryBase& device_src, + uint64_t size), + (override)); + MOCK_METHOD(void, DeallocateStream, (Stream * stream), (override)); + MOCK_METHOD(absl::Status, EnablePeerAccessTo, (StreamExecutor * other), + (override)); + MOCK_METHOD(bool, CanEnablePeerAccessTo, (StreamExecutor * other), + (override)); + MOCK_METHOD(bool, DeviceMemoryUsage, (int64_t* free, int64_t* total), + (const, override)); + MOCK_METHOD(absl::StatusOr, GetSymbol, + (const std::string& symbol_name, ModuleHandle module_handle), + (override)); + MOCK_METHOD(absl::StatusOr>, + CreateDeviceDescription, (), (const, override)); + MOCK_METHOD(blas::BlasSupport*, AsBlas, (), (override)); + MOCK_METHOD(fft::FftSupport*, AsFft, (), (override)); + MOCK_METHOD(dnn::DnnSupport*, AsDnn, (), (override)); + MOCK_METHOD(absl::StatusOr>, + CreateCommandBuffer, (CommandBuffer::Mode mode), (override)); + MOCK_METHOD(std::optional, GetAllocatorStats, (), (override)); + MOCK_METHOD(bool, ClearAllocatorStats, (), (override)); + MOCK_METHOD(absl::Status, FlushCompilationCache, (), (override)); + MOCK_METHOD(Stream*, FindAllocatedStream, (void* device_stream), (override)); + MOCK_METHOD(const Platform*, GetPlatform, (), (const, override)); + MOCK_METHOD(absl::StatusOr>, CreateStream, + ((std::optional>)), (override)); + MOCK_METHOD(int64_t, GetMemoryLimitBytes, (), (const.override)); + MOCK_METHOD(const DeviceDescription&, GetDeviceDescription, (), + (const, override)); + MOCK_METHOD(absl::StatusOr>, CreateEvent, (), + (override)); + + MOCK_METHOD(void, UnloadKernel, (const Kernel* kernel)); + MOCK_METHOD(absl::StatusOr>, + CreateEventBasedTimer, (Stream * stream, bool use_delay_kernel)); + MOCK_METHOD(absl::Status, TrimGraphMemory, ()); +}; + +} // namespace stream_executor::gpu + +#endif // XLA_STREAM_EXECUTOR_GPU_MOCK_GPU_EXECUTOR_H_ diff --git a/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc b/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc index f15b543167d7d1..57aea9ac3fa644 100644 --- a/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc +++ b/third_party/xla/xla/stream_executor/gpu/redzone_allocator.cc @@ -39,7 +39,7 @@ limitations under the License. #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/framework/allocator.h" -#include "tsl/lib/math/math_util.h" +#include "xla/tsl/lib/math/math_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/stream_executor/gpu/redzone_allocator.h b/third_party/xla/xla/stream_executor/gpu/redzone_allocator.h index 842289fbed6cab..ca744c84cd6ca8 100644 --- a/third_party/xla/xla/stream_executor/gpu/redzone_allocator.h +++ b/third_party/xla/xla/stream_executor/gpu/redzone_allocator.h @@ -60,7 +60,8 @@ class RedzoneAllocator : public ScratchAllocator { return allocated_bytes_excluding_redzones_; } - absl::StatusOr> AllocateBytes(int64_t byte_size) override; + absl::StatusOr> AllocateBytes( + int64_t byte_size) override; // Non-empty redzone check status implies that there was a write into a // redzone, with a string communicating the location of the write. diff --git a/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_rocm.cu.cc b/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_rocm.cu.cc index a832a7fa960784..59616362a448c8 100644 --- a/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_rocm.cu.cc +++ b/third_party/xla/xla/stream_executor/gpu/redzone_allocator_kernel_rocm.cu.cc @@ -39,7 +39,7 @@ namespace stream_executor { absl::StatusOr GetComparisonKernel( StreamExecutor* executor, GpuAsmOpts /*gpu_asm_opts*/) { static auto kernel = TypedKernelFactory< - DeviceMemory, uint8, uint64_t, + DeviceMemory, uint8_t, uint64_t, DeviceMemory>::Create(executor, "redzone_checker", reinterpret_cast( redzone_checker_kernel)); diff --git a/third_party/xla/xla/stream_executor/gpu/scoped_activate_context.cc b/third_party/xla/xla/stream_executor/gpu/scoped_activate_context.cc index 65018a9f9a43dc..8607e6a6fafa13 100644 --- a/third_party/xla/xla/stream_executor/gpu/scoped_activate_context.cc +++ b/third_party/xla/xla/stream_executor/gpu/scoped_activate_context.cc @@ -17,8 +17,6 @@ limitations under the License. #include "absl/log/check.h" #include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_executor.h" -#include "xla/stream_executor/stream_executor.h" #include "tsl/platform/logging.h" namespace stream_executor::gpu { @@ -33,12 +31,6 @@ thread_local struct ThreadLocalData { } // namespace -ScopedActivateContext::ScopedActivateContext(GpuExecutor* gpu_executor) - : ScopedActivateContext(gpu_executor->gpu_context()) {} - -ScopedActivateContext::ScopedActivateContext(StreamExecutor* executor) - : ScopedActivateContext(ExtractGpuExecutor(executor)) {} - ScopedActivateContext::ScopedActivateContext(gpu::Context* gpu_context) { auto* tls = &tls_data; diff --git a/third_party/xla/xla/stream_executor/gpu/scoped_activate_context.h b/third_party/xla/xla/stream_executor/gpu/scoped_activate_context.h index c6d2a34e48d601..1dad6ae4781efc 100644 --- a/third_party/xla/xla/stream_executor/gpu/scoped_activate_context.h +++ b/third_party/xla/xla/stream_executor/gpu/scoped_activate_context.h @@ -16,23 +16,20 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_GPU_SCOPED_ACTIVATE_CONTEXT_H_ #define XLA_STREAM_EXECUTOR_GPU_SCOPED_ACTIVATE_CONTEXT_H_ +#include "xla/stream_executor/activate_context.h" #include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_executor.h" -#include "xla/stream_executor/stream_executor.h" namespace stream_executor::gpu { // Ensures a context is activated within a scope. -class ScopedActivateContext { +class ScopedActivateContext : public ActivateContext { public: // Activates the context via Context::SetActive. explicit ScopedActivateContext(Context* gpu_context); - explicit ScopedActivateContext(GpuExecutor* gpu_executor); - explicit ScopedActivateContext(StreamExecutor* executor); // Checks that the context has remained activated for the duration of the // scope. - ~ScopedActivateContext(); + ~ScopedActivateContext() override; private: Context* to_restore_ = nullptr; diff --git a/third_party/xla/xla/stream_executor/host/BUILD b/third_party/xla/xla/stream_executor/host/BUILD index a03a21ceb5592b..fc470c638d311b 100644 --- a/third_party/xla/xla/stream_executor/host/BUILD +++ b/third_party/xla/xla/stream_executor/host/BUILD @@ -50,10 +50,12 @@ cc_library( deps = [ ":host_executor", ":host_platform_id", - "//xla/stream_executor", + "//xla/stream_executor:device_description", "//xla/stream_executor:executor_cache", + "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", - "//xla/stream_executor/platform", + "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor/platform:initialize", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", @@ -112,11 +114,11 @@ cc_library( hdrs = ["host_kernel.h"], deps = [ ":host_kernel_c_api", - "//xla/stream_executor", "//xla/stream_executor:device_memory", + "//xla/stream_executor:kernel", "//xla/stream_executor:kernel_spec", + "//xla/stream_executor:launch_dim", "//xla/tsl/concurrency:async_value", - "//xla/tsl/concurrency:ref_count", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", @@ -125,7 +127,6 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:blocking_counter", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", ], @@ -140,9 +141,13 @@ xla_cc_test( ":host_platform", ":jit_host_kernel_function", ":ptr_host_kernel_function", - "//xla/stream_executor", "//xla/stream_executor:device_memory", + "//xla/stream_executor:kernel", "//xla/stream_executor:kernel_spec", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_h", "//xla/tsl/concurrency:async_value", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", @@ -197,9 +202,10 @@ xla_cc_test( srcs = ["host_stream_test.cc"], deps = [ ":host_platform", - "//xla/stream_executor", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/synchronization", @@ -219,7 +225,7 @@ cc_library( ":host_kernel", ":host_kernel_c_api", "//xla/stream_executor:kernel_spec", - "//xla/stream_executor/platform", + "//xla/stream_executor/platform:initialize", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", ], @@ -235,7 +241,7 @@ cc_library( ":host_kernel", ":host_kernel_c_api", "//xla/stream_executor:kernel_spec", - "//xla/stream_executor/platform", + "//xla/stream_executor/platform:initialize", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", diff --git a/third_party/xla/xla/stream_executor/host/host_executor.cc b/third_party/xla/xla/stream_executor/host/host_executor.cc index 68207fe0b438c2..86a7225ecb63dc 100644 --- a/third_party/xla/xla/stream_executor/host/host_executor.cc +++ b/third_party/xla/xla/stream_executor/host/host_executor.cc @@ -143,10 +143,6 @@ static HostEvent* AsHostEvent(Event* event) { return static_cast(event); } -absl::Status HostExecutor::BlockHostUntilDone(Stream* stream) { - return AsHostStream(stream)->BlockUntilDone(); -} - absl::StatusOr> HostExecutor::CreateDeviceDescription(int device_ordinal) { DeviceDescription desc; diff --git a/third_party/xla/xla/stream_executor/host/host_executor.h b/third_party/xla/xla/stream_executor/host/host_executor.h index 55eacc5fff4851..831cf27727b3ae 100644 --- a/third_party/xla/xla/stream_executor/host/host_executor.h +++ b/third_party/xla/xla/stream_executor/host/host_executor.h @@ -88,8 +88,6 @@ class HostExecutor : public StreamExecutorCommon { void DeallocateStream(Stream* stream) override; - absl::Status BlockHostUntilDone(Stream* stream) override; - bool DeviceMemoryUsage(int64_t* free, int64_t* total) const override; absl::StatusOr> CreateDeviceDescription() diff --git a/third_party/xla/xla/stream_executor/host/host_platform.h b/third_party/xla/xla/stream_executor/host/host_platform.h index b8ce8f4340d6c4..3d6f09d3fb49b5 100644 --- a/third_party/xla/xla/stream_executor/host/host_platform.h +++ b/third_party/xla/xla/stream_executor/host/host_platform.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/executor_cache.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" diff --git a/third_party/xla/xla/stream_executor/host/host_stream.cc b/third_party/xla/xla/stream_executor/host/host_stream.cc index 76b66711e03d62..1cbf01298ce213 100644 --- a/third_party/xla/xla/stream_executor/host/host_stream.cc +++ b/third_party/xla/xla/stream_executor/host/host_stream.cc @@ -22,6 +22,7 @@ limitations under the License. #include // NOLINT #include #include +#include #include #include @@ -197,7 +198,13 @@ absl::Status HostStream::BlockUntilDone() { absl::Status HostStream::Launch(const ThreadDim& thread_dims, const BlockDim& block_dims, + const std::optional& cluster_dims, const Kernel& kernel, const KernelArgs& args) { + if (cluster_dims.has_value()) { + if (cluster_dims->x != 1 || cluster_dims->y != 1 || cluster_dims->z != 1) { + return absl::UnimplementedError("Not implemented for Host"); + } + } const HostKernel* host_kernel = AsHostKernel(&kernel); const KernelArgsDeviceMemoryArray* device_mem = diff --git a/third_party/xla/xla/stream_executor/host/host_stream.h b/third_party/xla/xla/stream_executor/host/host_stream.h index a43ba610e25417..dc6760f8f629ca 100644 --- a/third_party/xla/xla/stream_executor/host/host_stream.h +++ b/third_party/xla/xla/stream_executor/host/host_stream.h @@ -56,6 +56,8 @@ class HostStream : public StreamCommon { // (if any) and clears the error status. absl::Status BlockUntilDone(); + absl::Status BlockHostUntilDone() override { return BlockUntilDone(); } + absl::Status WaitFor(Stream* other) override; absl::Status WaitFor(Event* event) override; absl::Status RecordEvent(Event* event) override; @@ -71,6 +73,7 @@ class HostStream : public StreamCommon { absl::Status DoHostCallbackWithStatus( absl::AnyInvocable callback) override; absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims, + const std::optional& cluster_dims, const Kernel& kernel, const KernelArgs& args) override; private: diff --git a/third_party/xla/xla/stream_executor/host/jit_host_kernel_function.cc b/third_party/xla/xla/stream_executor/host/jit_host_kernel_function.cc index 76b9d1a42cb9da..1034abca258ee6 100644 --- a/third_party/xla/xla/stream_executor/host/jit_host_kernel_function.cc +++ b/third_party/xla/xla/stream_executor/host/jit_host_kernel_function.cc @@ -38,6 +38,7 @@ limitations under the License. #include "llvm/ExecutionEngine/ObjectCache.h" #include "llvm/ExecutionEngine/Orc/CompileUtils.h" #include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" #include "llvm/ExecutionEngine/Orc/LLJIT.h" diff --git a/third_party/xla/xla/stream_executor/integrations/BUILD b/third_party/xla/xla/stream_executor/integrations/BUILD index ab75bbcd5d66c4..2a9cc0b3d24e97 100644 --- a/third_party/xla/xla/stream_executor/integrations/BUILD +++ b/third_party/xla/xla/stream_executor/integrations/BUILD @@ -44,10 +44,11 @@ cc_library( srcs = ["tf_allocator_adapter.cc"], hdrs = ["tf_allocator_adapter.h"], deps = [ - "//xla/stream_executor", "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:platform", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/tsl/framework:allocator", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -69,8 +70,8 @@ cc_library( "device_mem_allocator.h", ], deps = [ - "//xla/stream_executor", "//xla/stream_executor:memory_allocation", + "//xla/stream_executor:stream_executor_h", "//xla/tsl/framework:allocator", "//xla/tsl/framework:device_id", "@com_google_absl//absl/base:core_headers", @@ -88,7 +89,10 @@ xla_cc_test( ":tf_allocator_adapter", "//xla/service:cpu_plugin", "//xla/service:platform_util", - "//xla/stream_executor", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/log:check", diff --git a/third_party/xla/xla/stream_executor/kernel.h b/third_party/xla/xla/stream_executor/kernel.h index f03b373b5a4d76..6076717d430598 100644 --- a/third_party/xla/xla/stream_executor/kernel.h +++ b/third_party/xla/xla/stream_executor/kernel.h @@ -94,8 +94,6 @@ limitations under the License. namespace stream_executor { -class Kernel; - //===----------------------------------------------------------------------===// // Kernel metadata //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/stream_executor/launch_dim.h b/third_party/xla/xla/stream_executor/launch_dim.h index f8a408bfc7211e..59b935c1ac7574 100644 --- a/third_party/xla/xla/stream_executor/launch_dim.h +++ b/third_party/xla/xla/stream_executor/launch_dim.h @@ -58,6 +58,10 @@ struct ThreadDim : internal::Dim3D { struct BlockDim : internal::Dim3D { explicit BlockDim(uint64_t x = 1, uint64_t y = 1, uint64_t z = 1) : internal::Dim3D({x, y, z}) {} + + std::string ToString() const { + return absl::StrCat("BlockDim{", x, ", ", y, ", ", z, "}"); + } }; // Cluster dimensionality for use in a kernel launch. diff --git a/third_party/xla/xla/stream_executor/lazy_op_runner.h b/third_party/xla/xla/stream_executor/lazy_op_runner.h index bf964e05bbaae6..f3c8d004397639 100644 --- a/third_party/xla/xla/stream_executor/lazy_op_runner.h +++ b/third_party/xla/xla/stream_executor/lazy_op_runner.h @@ -29,8 +29,8 @@ limitations under the License. #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/protobuf/dnn.pb.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/dnn.pb.h" namespace stream_executor { namespace dnn { diff --git a/third_party/xla/xla/stream_executor/mock_stream.h b/third_party/xla/xla/stream_executor/mock_stream.h index 5e9750e124caaa..41d06aa4f6e607 100644 --- a/third_party/xla/xla/stream_executor/mock_stream.h +++ b/third_party/xla/xla/stream_executor/mock_stream.h @@ -18,12 +18,13 @@ limitations under the License. #include #include +#include +#include #include #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/string_view.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event.h" @@ -76,15 +77,11 @@ class MockStream : public Stream { (const, override)); MOCK_METHOD(absl::Status, Launch, (const ThreadDim &thread_dims, const BlockDim &block_dims, - const Kernel &k, const KernelArgs &args), - (override)); - MOCK_METHOD(absl::Status, Launch, - (const ThreadDim &thread_dims, const BlockDim &block_dims, - const ClusterDim &cluster_dims, const Kernel &k, + const std::optional &cluster_dims, const Kernel &k, const KernelArgs &args), (override)); - MOCK_METHOD(absl::string_view, name, (), (const, override)); - MOCK_METHOD(void, set_name, (absl::string_view name), (override)); + MOCK_METHOD(const std::string &, GetName, (), (const, override)); + MOCK_METHOD(void, SetName, (std::string name), (override)); MOCK_METHOD(absl::StatusOr>, CreateEventBasedTimer, (bool use_delay_kernel), (override)); }; diff --git a/third_party/xla/xla/stream_executor/mock_stream_executor.h b/third_party/xla/xla/stream_executor/mock_stream_executor.h index 0379c2c068dc18..68d76c5e958259 100644 --- a/third_party/xla/xla/stream_executor/mock_stream_executor.h +++ b/third_party/xla/xla/stream_executor/mock_stream_executor.h @@ -1,3 +1,4 @@ +#include "xla/stream_executor/activate_context.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/fft.h" @@ -52,10 +53,10 @@ class MockStreamExecutor : public StreamExecutor { MOCK_METHOD(int, device_ordinal, (), (const, override)); MOCK_METHOD(absl::StatusOr>, LoadKernel, (const MultiKernelLoaderSpec& spec), (override)); + MOCK_METHOD(std::unique_ptr, Activate, (), (override)); MOCK_METHOD(bool, UnloadModule, (ModuleHandle module_handle), (override)); - MOCK_METHOD(absl::Status, LoadModule, - (const MultiModuleLoaderSpec& spec, ModuleHandle* module_handle), - (override)); + MOCK_METHOD(absl::StatusOr, LoadModule, + (const MultiModuleLoaderSpec& spec), (override)); MOCK_METHOD(absl::StatusOr>, CreateOrShareConstant, (Stream * stream, absl::Span content), (override)); @@ -83,7 +84,6 @@ class MockStreamExecutor : public StreamExecutor { uint64_t size), (override)); MOCK_METHOD(void, DeallocateStream, (Stream * stream), (override)); - MOCK_METHOD(absl::Status, BlockHostUntilDone, (Stream * stream), (override)); MOCK_METHOD(absl::Status, EnablePeerAccessTo, (StreamExecutor * other), (override)); MOCK_METHOD(bool, CanEnablePeerAccessTo, (StreamExecutor * other), @@ -112,6 +112,10 @@ class MockStreamExecutor : public StreamExecutor { (const, override)); MOCK_METHOD(absl::StatusOr>, CreateEvent, (), (override)); + MOCK_METHOD(void, UnloadKernel, (const Kernel* kernel), (override)); + MOCK_METHOD(absl::StatusOr>, + CreateEventBasedTimer, (Stream * stream, bool use_delay_kernel), + (override)); }; } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/module_spec.h b/third_party/xla/xla/stream_executor/module_spec.h index eb5c54a939befc..db733e48cfb647 100644 --- a/third_party/xla/xla/stream_executor/module_spec.h +++ b/third_party/xla/xla/stream_executor/module_spec.h @@ -17,7 +17,9 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_MODULE_SPEC_H_ #include +#include +#include "absl/strings/str_format.h" #include "absl/types/span.h" #include "tsl/platform/logging.h" @@ -32,16 +34,35 @@ namespace stream_executor { // An instance of this is returned from StreamExecutor::GetModule. class ModuleHandle { public: - explicit ModuleHandle(void* id = nullptr) : id_(id) {} + explicit ModuleHandle(const void* id = nullptr) : id_(id) {} // A ModuleHandle with id() == nullptr is an invalid module handle, akin to a // null pointer. - void* id() const { return id_; } + const void* id() const { return id_; } explicit operator bool() const { return id() != nullptr; } + template + friend H AbslHashValue(H h, const ModuleHandle& handle) { + return H::combine(std::move(h), handle.id_); + } + friend bool operator==(const ModuleHandle& lhs, const ModuleHandle& rhs) { + return lhs.id_ == rhs.id_; + } + friend bool operator!=(const ModuleHandle& lhs, const ModuleHandle& rhs) { + return lhs.id_ != rhs.id_; + } + template + friend void AbslStringify(Sink& sink, const ModuleHandle& handle) { + sink.Append(absl::StrFormat("ModuleHandle(id=%p)", handle.id_)); + } + friend std::ostream& operator<<(std::ostream& os, + const ModuleHandle& handle) { + return os << absl::StrFormat("ModuleHandle(id=%p)", handle.id_); + } + private: - void* id_; + const void* id_; }; //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/stream_executor/platform/BUILD b/third_party/xla/xla/stream_executor/platform/BUILD index 7334d2cc63b1cf..35c1e32f6c761c 100644 --- a/third_party/xla/xla/stream_executor/platform/BUILD +++ b/third_party/xla/xla/stream_executor/platform/BUILD @@ -15,23 +15,14 @@ package_group( ) cc_library( - name = "platform", - textual_hdrs = [ + name = "initialize", + hdrs = [ "initialize.h", - "platform.h", - "port.h", ], deps = [ + "@local_tsl//tsl/platform", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:types", ] + tf_stream_executor_deps("platform", "//xla/stream_executor/platform/"), ) - -cc_library( - name = "dso_loader", - hdrs = ["dso_loader.h"], - deps = [ - ":platform", - ] + tf_stream_executor_deps("dso_loader", "//xla/stream_executor/platform/"), -) diff --git a/third_party/xla/xla/stream_executor/platform/default/BUILD b/third_party/xla/xla/stream_executor/platform/default/BUILD index 4e5e517d28dbef..17fbcf2be7e9a2 100644 --- a/third_party/xla/xla/stream_executor/platform/default/BUILD +++ b/third_party/xla/xla/stream_executor/platform/default/BUILD @@ -1,5 +1,4 @@ load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("//xla/tsl:tsl.bzl", "tsl_copts") licenses(["notice"]) @@ -14,20 +13,3 @@ cc_library( name = "platform", textual_hdrs = ["initialize.h"], ) - -cc_library( - name = "dso_loader", - hdrs = ["dso_loader.h"], - compatible_with = [], - copts = tsl_copts(), - tags = [ - "manual", - "nobuilder", - ], - deps = [ - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:dso_loader", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - ], -) diff --git a/third_party/xla/xla/stream_executor/platform/default/initialize.h b/third_party/xla/xla/stream_executor/platform/default/initialize.h index 78b24977ac7ae3..a3cc0f24423f26 100644 --- a/third_party/xla/xla/stream_executor/platform/default/initialize.h +++ b/third_party/xla/xla/stream_executor/platform/default/initialize.h @@ -23,39 +23,14 @@ namespace port { class Initializer { public: - typedef void (*InitializerFunc)(); - explicit Initializer(InitializerFunc func) { func(); } - - struct Dependency { - Dependency(const char *n, Initializer *i) : name(n), initializer(i) {} - const char *const name; - Initializer *const initializer; - }; - - struct DependencyRegisterer { - DependencyRegisterer(const char *type, const char *name, - Initializer *initializer, - const Dependency &dependency); - }; + explicit Initializer(void (*func)()) { func(); } }; } // namespace port } // namespace stream_executor -#define STREAM_EXECUTOR_REGISTER_INITIALIZER(type, name, body) \ - static void google_init_##type##_##name() { body; } \ - ::stream_executor::port::Initializer google_initializer_##type##_##name( \ - google_init_##type##_##name) - -#define STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(name, body) \ - STREAM_EXECUTOR_REGISTER_INITIALIZER(module, name, body) - -#define STREAM_EXECUTOR_DECLARE_INITIALIZER(type, name) \ - extern ::stream_executor::port::Initializer google_initializer_##type##_##name - -#define STREAM_EXECUTOR_DECLARE_MODULE_INITIALIZER(name) \ - STREAM_EXECUTOR_DECLARE_INITIALIZER(module, name) - -#define STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER_SEQUENCE(name1, name2) +#define STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(name, body) \ + ::stream_executor::port::Initializer google_initializer_module##_##name( \ + []() { body; }) #endif // XLA_STREAM_EXECUTOR_PLATFORM_DEFAULT_INITIALIZE_H_ diff --git a/third_party/xla/xla/stream_executor/platform/dso_loader.h b/third_party/xla/xla/stream_executor/platform/dso_loader.h deleted file mode 100644 index bfd1e061f9d824..00000000000000 --- a/third_party/xla/xla/stream_executor/platform/dso_loader.h +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_STREAM_EXECUTOR_PLATFORM_DSO_LOADER_H_ -#define XLA_STREAM_EXECUTOR_PLATFORM_DSO_LOADER_H_ - -#include "xla/stream_executor/platform/platform.h" - -// Include appropriate platform-dependent implementations -#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_CHROMIUMOS) -#include "xla/stream_executor/platform/google/dso_loader.h" -#elif defined(PLATFORM_POSIX) || defined(PLATFORM_POSIX_ANDROID) || \ - defined(PLATFORM_GOOGLE_ANDROID) || defined(PLATFORM_WINDOWS) -#include "xla/stream_executor/platform/default/dso_loader.h" -#else -#error Define the appropriate PLATFORM_ macro for this platform -#endif - -#endif // XLA_STREAM_EXECUTOR_PLATFORM_DSO_LOADER_H_ diff --git a/third_party/xla/xla/stream_executor/platform/initialize.h b/third_party/xla/xla/stream_executor/platform/initialize.h index 910b0116343181..1e3069f782c9ef 100644 --- a/third_party/xla/xla/stream_executor/platform/initialize.h +++ b/third_party/xla/xla/stream_executor/platform/initialize.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_PLATFORM_INITIALIZE_H_ #define XLA_STREAM_EXECUTOR_PLATFORM_INITIALIZE_H_ -#include "xla/stream_executor/platform/platform.h" +#include "tsl/platform/platform.h" #if defined(PLATFORM_GOOGLE) || defined(PLATFORM_CHROMIUMOS) #include "xla/stream_executor/platform/google/initialize.h" // IWYU pragma: export diff --git a/third_party/xla/xla/stream_executor/platform/platform.h b/third_party/xla/xla/stream_executor/platform/platform.h deleted file mode 100644 index 3b00ab8cc64768..00000000000000 --- a/third_party/xla/xla/stream_executor/platform/platform.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2018 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_STREAM_EXECUTOR_PLATFORM_PLATFORM_H_ -#define XLA_STREAM_EXECUTOR_PLATFORM_PLATFORM_H_ - -#if !defined(PLATFORM_POSIX) && !defined(PLATFORM_GOOGLE) && \ - !defined(PLATFORM_POSIX_ANDROID) && !defined(PLATFORM_GOOGLE_ANDROID) && \ - !defined(PLATFORM_WINDOWS) && !defined(PLATFORM_CHROMIUMOS) - -// Choose which platform we are on. -#if defined(ANDROID) || defined(__ANDROID__) -#define PLATFORM_POSIX_ANDROID - -#elif defined(__APPLE__) -#define PLATFORM_POSIX - -#elif defined(_WIN32) -#define PLATFORM_WINDOWS - -#elif defined(__TF_CHROMIUMOS__) -#define PLATFORM_CHROMIUMOS - -#else -// If no platform specified, use: -#define PLATFORM_POSIX - -#endif -#endif - -#endif // XLA_STREAM_EXECUTOR_PLATFORM_PLATFORM_H_ diff --git a/third_party/xla/xla/stream_executor/platform/port.h b/third_party/xla/xla/stream_executor/platform/port.h deleted file mode 100644 index 9561d7bf20cb15..00000000000000 --- a/third_party/xla/xla/stream_executor/platform/port.h +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2015 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// IWYU pragma: private, include "xla/stream_executor/stream_executor.h" - -#ifndef XLA_STREAM_EXECUTOR_PLATFORM_PORT_H_ -#define XLA_STREAM_EXECUTOR_PLATFORM_PORT_H_ - -#include "tsl/platform/macros.h" -#include "tsl/platform/types.h" - -namespace stream_executor { - -using tsl::int16; -using tsl::int32; -using tsl::int8; - -using tsl::uint16; -using tsl::uint32; -using tsl::uint64; -using tsl::uint8; - -#if !defined(PLATFORM_GOOGLE) -using std::string; -#endif - -#define SE_FALLTHROUGH_INTENDED TF_FALLTHROUGH_INTENDED - -} // namespace stream_executor - -// DEPRECATED: directly use the macro implementation instead. -#define SE_DISALLOW_COPY_AND_ASSIGN TF_DISALLOW_COPY_AND_ASSIGN - -#define SE_MUST_USE_RESULT TF_MUST_USE_RESULT -#define SE_PREDICT_TRUE TF_PREDICT_TRUE -#define SE_PREDICT_FALSE TF_PREDICT_FALSE - -#endif // XLA_STREAM_EXECUTOR_PLATFORM_PORT_H_ diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD index bd1610200b16f5..f8a839d07f57d2 100644 --- a/third_party/xla/xla/stream_executor/rocm/BUILD +++ b/third_party/xla/xla/stream_executor/rocm/BUILD @@ -1,22 +1,24 @@ # Description: # ROCm-platform specific StreamExecutor support code. -# buildifier: disable=out-of-order-load - -# buildifier: disable=out-of-order-load -load( - "//xla/stream_executor:build_defs.bzl", - "stream_executor_friends", -) load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm_hipblaslt", - "if_rocm_is_configured", "rocm_library", ) -load("//xla/tsl:tsl.bzl", "internal_visibility", "tsl_copts") load("@local_tsl//tsl/platform:build_config_root.bzl", "if_static") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") +load( + "//xla/stream_executor:build_defs.bzl", + "stream_executor_friends", +) +load("//xla/tests:build_defs.bzl", "xla_test") +load( + "//xla/tsl:tsl.bzl", + "if_google", + "internal_visibility", + "tsl_copts", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -31,45 +33,102 @@ package_group( cc_library( name = "rocm_diagnostics", - srcs = if_rocm_is_configured(["rocm_diagnostics.cc"]), - hdrs = if_rocm_is_configured(["rocm_diagnostics.h"]), - deps = if_rocm_is_configured([ + srcs = ["rocm_diagnostics.cc"], + hdrs = ["rocm_diagnostics.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ + "//xla/stream_executor/gpu:gpu_diagnostics_header", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "//xla/stream_executor/gpu:gpu_diagnostics_header", - "//xla/stream_executor/platform", - "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:platform_port", + ], +) + +cc_library( + name = "rocm_context", + srcs = ["rocm_context.cc"], + hdrs = ["rocm_context.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ + ":rocm_driver_wrapper", + ":rocm_status", + "//xla/stream_executor:device_description", + "//xla/stream_executor/gpu:context", + "//xla/stream_executor/gpu:context_map", + "//xla/stream_executor/gpu:scoped_activate_context", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_config_rocm//rocm:rocm_headers", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + ], +) + +cc_library( + name = "rocm_driver_wrapper", + hdrs = ["rocm_driver_wrapper.h"], + defines = {"__HIP_DISABLE_CPP_FUNCTIONS__": "1"}, + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(klucke): Remove this tag once the target can be built without --config=rocm. + "manual", ]), + deps = [ + "@local_config_rocm//rocm:hip", # buildcleaner: keep + "@local_config_rocm//rocm:rocm_headers", + "@local_tsl//tsl/platform:dso_loader", + "@local_tsl//tsl/platform:env", + ], ) cc_library( name = "rocm_driver", - srcs = if_rocm_is_configured(["rocm_driver.cc"]), - hdrs = if_rocm_is_configured([ - "rocm_driver_wrapper.h", - "rocm_driver.h", - ]), - deps = if_rocm_is_configured([ - # keep sorted - ":rocm_diagnostics", - "//xla/stream_executor", + srcs = ["rocm_driver.cc"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ + ":rocm_context", + ":rocm_driver_wrapper", + ":rocm_status", + "//xla/stream_executor:device_description", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:context", "//xla/stream_executor/gpu:context_map", "//xla/stream_executor/gpu:gpu_diagnostics_header", "//xla/stream_executor/gpu:gpu_driver_header", + "//xla/stream_executor/gpu:gpu_types_header", "//xla/stream_executor/gpu:scoped_activate_context", - "//xla/stream_executor/platform", - "//xla/stream_executor/platform:dso_loader", "@com_google_absl//absl/base", "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@local_config_rocm//rocm:hip", + "@com_google_absl//absl/types:span", + "@local_config_rocm//rocm:hip", # buildcleaner: keep "@local_config_rocm//rocm:rocm_headers", "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:env", @@ -77,85 +136,113 @@ cc_library( "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:numbers", "@local_tsl//tsl/platform:stacktrace", - ]), + "@local_tsl//tsl/platform:status", + ], ) cc_library( name = "rocm_runtime", - srcs = if_rocm_is_configured(["rocm_runtime.cc"]), - hdrs = if_rocm_is_configured([ - "rocm_driver_wrapper.h", - "rocm_driver.h", - ]), - deps = if_rocm_is_configured([ - # keep sorted - "//xla/stream_executor", - "//xla/stream_executor/gpu:context", - "//xla/stream_executor/gpu:gpu_driver_header", - "//xla/stream_executor/gpu:gpu_runtime_header", - "//xla/stream_executor/gpu:gpu_types_header", - "//xla/stream_executor/platform", - "//xla/stream_executor/platform:dso_loader", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:node_hash_map", + srcs = ["rocm_runtime.cc"], + hdrs = ["rocm_runtime.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ + ":rocm_driver_wrapper", + ":rocm_status", "@com_google_absl//absl/log", - "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@local_config_rocm//rocm:rocm_headers", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - ]), + "@local_tsl//tsl/platform:errors", + ], ) cc_library( - name = "rocm_collectives", - srcs = if_rocm_is_configured(["rocm_collectives.cc"]), - deps = if_rocm_is_configured([ - # keep sorted - "//xla/stream_executor/gpu:gpu_collectives_header", - "//xla/stream_executor/gpu:gpu_driver_header", + name = "rocm_event", + srcs = ["rocm_event.cc"], + hdrs = ["rocm_event.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ + ":rocm_driver_wrapper", + ":rocm_status", + "//xla/stream_executor:activate_context", + "//xla/stream_executor:event", + "//xla/stream_executor:stream_executor_h", + "@com_google_absl//absl/base", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - ]), + "@com_google_absl//absl/strings:str_format", + "@local_config_rocm//rocm:rocm_headers", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], ) -cc_library( - name = "rocm_event", - srcs = if_rocm_is_configured(["rocm_event.cc"]), - hdrs = if_rocm_is_configured(["rocm_event.h"]), - deps = if_rocm_is_configured([ - # keep sorted - ":rocm_driver", - "//xla/stream_executor", - "//xla/stream_executor/gpu:context", - "//xla/stream_executor/gpu:gpu_event_header", - "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:scoped_activate_context", +xla_test( + name = "rocm_event_test", + srcs = ["rocm_event_test.cc"], + backends = ["gpu"], + tags = ["rocm-only"] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", ]), + deps = [ + ":rocm_event", + ":rocm_executor", + ":rocm_platform_id", + "//xla/stream_executor:event", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "@com_google_googletest//:gtest_main", + "@local_config_rocm//rocm:rocm_headers", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], ) cc_library( name = "rocm_executor", - srcs = if_rocm_is_configured(["rocm_executor.cc"]), - hdrs = if_rocm_is_configured(["rocm_executor.h"]), - deps = if_rocm_is_configured([ - # keep sorted - ":rocm_collectives", - ":rocm_diagnostics", - ":rocm_driver", + srcs = ["rocm_executor.cc"], + hdrs = ["rocm_executor.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ + ":rocm_command_buffer", + ":rocm_context", + ":rocm_diagnostics", # buildcleaner: keep + ":rocm_driver", # buildcleaner: keep + ":rocm_driver_wrapper", ":rocm_event", ":rocm_kernel", ":rocm_platform_id", ":rocm_runtime", + ":rocm_status", + ":rocm_stream", + ":rocm_timer", ":rocm_version_parser", - "//xla/stream_executor", + "//xla/stream_executor:activate_context", "//xla/stream_executor:blas", "//xla/stream_executor:command_buffer", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory", "//xla/stream_executor:dnn", "//xla/stream_executor:event", "//xla/stream_executor:event_based_timer", @@ -163,30 +250,28 @@ cc_library( "//xla/stream_executor:host_memory_allocation", "//xla/stream_executor:kernel", "//xla/stream_executor:kernel_spec", + "//xla/stream_executor:launch_dim", "//xla/stream_executor:memory_allocation", "//xla/stream_executor:module_spec", + "//xla/stream_executor:platform", "//xla/stream_executor:plugin_registry", "//xla/stream_executor:semantic_version", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:context", - "//xla/stream_executor/gpu:gpu_collectives_header", "//xla/stream_executor/gpu:gpu_command_buffer", - "//xla/stream_executor/gpu:gpu_diagnostics_header", "//xla/stream_executor/gpu:gpu_driver_header", - "//xla/stream_executor/gpu:gpu_event", "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_kernel_header", - "//xla/stream_executor/gpu:gpu_runtime_header", "//xla/stream_executor/gpu:gpu_stream", - "//xla/stream_executor/gpu:gpu_timer", "//xla/stream_executor/gpu:gpu_types_header", "//xla/stream_executor/gpu:read_numa_node", - "//xla/stream_executor/integrations:device_mem_allocator", - "//xla/stream_executor/platform", - "//xla/stream_executor/platform:dso_loader", + "//xla/stream_executor/gpu:scoped_activate_context", + "//xla/stream_executor/platform:initialize", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/numeric:int128", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -195,59 +280,123 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_config_rocm//rocm:rocm_headers", + "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:fingerprint", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:numbers", "@local_tsl//tsl/platform:statusor", - ]), + ], alwayslink = True, ) +xla_test( + name = "rocm_executor_test", + srcs = ["rocm_executor_test.cc"], + backends = ["gpu"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ + ":rocm_executor", + "//xla/stream_executor:device_description", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "rocm_kernel", - srcs = if_rocm_is_configured(["rocm_kernel.cc"]), + srcs = ["rocm_kernel.cc"], + hdrs = ["rocm_kernel.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), visibility = ["//visibility:public"], - deps = if_rocm_is_configured([ + deps = [ + ":rocm_driver_wrapper", + ":rocm_status", + "//xla/stream_executor:activate_context", + "//xla/stream_executor:kernel", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:gpu_kernel_header", - "//xla/stream_executor/gpu:gpu_driver_header", - ]), - alwayslink = True, + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_config_rocm//rocm:rocm_headers", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + ], ) -cc_library( - name = "command_buffer_kernels", - srcs = ["command_buffer_kernels.cc"], +xla_test( + name = "rocm_kernel_test", + srcs = ["rocm_kernel_test.cc"], + backends = ["gpu"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), deps = [ - "//xla/stream_executor:kernel_spec", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", + ":rocm_kernel", + ":rocm_runtime", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor/gpu:gpu_test_kernels", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", ], ) cc_library( name = "rocm_platform", - srcs = if_rocm_is_configured(["rocm_platform.cc"]), - hdrs = if_rocm_is_configured(["rocm_platform.h"]), + srcs = ["rocm_platform.cc"], + hdrs = ["rocm_platform.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), visibility = ["//visibility:public"], - deps = if_rocm_is_configured([ - # keep sorted - ":rocm_collectives", - ":rocm_driver", + deps = [ + ":rocm_diagnostics", # buildcleaner: keep + ":rocm_driver_wrapper", ":rocm_executor", ":rocm_platform_id", - ":rocm_runtime", - "//xla/stream_executor", # buildcleaner: keep + ":rocm_status", + "//xla/stream_executor:device_description", "//xla/stream_executor:executor_cache", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor/gpu:gpu_diagnostics_header", "//xla/stream_executor/gpu:gpu_driver_header", - "//xla/stream_executor/gpu:gpu_executor_header", - "//xla/stream_executor/platform", - "@com_google_absl//absl/base", - "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings:str_format", + "//xla/stream_executor/platform:initialize", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:errors", - ]), + "@local_tsl//tsl/platform:status", + ], alwayslink = True, # Registers itself with the PlatformManager. ) @@ -260,43 +409,54 @@ cc_library( cc_library( name = "rocblas_if_static", - deps = if_static([ - ":rocblas_if_rocm_configured", + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", ]), -) - -cc_library( - name = "rocblas_if_rocm_configured", - deps = if_rocm_is_configured([ + deps = if_static([ "@local_config_rocm//rocm:rocblas", ]), ) cc_library( name = "rocblas_wrapper", - hdrs = if_rocm_is_configured(["rocblas_wrapper.h"]), - deps = if_rocm_is_configured([ - # keep sorted + hdrs = ["rocblas_wrapper.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ ":rocblas_if_static", ":rocm_executor", ":rocm_platform_id", - "//xla/stream_executor/platform", - "//xla/stream_executor/platform:dso_loader", "//xla/tsl/util:determinism_for_kernels", "@local_config_rocm//rocm:rocm_headers", "@local_tsl//tsl/platform", + "@local_tsl//tsl/platform:dso_loader", "@local_tsl//tsl/platform:env", - ]), + ], alwayslink = True, ) cc_library( name = "rocblas_plugin", - srcs = if_rocm_is_configured(["rocm_blas.cc"]), - hdrs = if_rocm_is_configured(["rocm_blas.h"]), + srcs = ["rocm_blas.cc"], + hdrs = ["rocm_blas.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), visibility = ["//visibility:public"], - deps = if_rocm_is_configured([ - # keep sorted + deps = [ ":hipblas_lt_header", ":rocblas_if_static", ":rocblas_wrapper", @@ -304,21 +464,22 @@ cc_library( ":rocm_executor", ":rocm_helpers", ":rocm_platform_id", - "//xla/stream_executor", + "//xla/stream_executor:activate_context", "//xla/stream_executor:blas", + "//xla/stream_executor:device_memory", "//xla/stream_executor:event_based_timer", "//xla/stream_executor:host_or_device_scalar", "//xla/stream_executor:plugin_registry", + "//xla/stream_executor:scratch_allocator", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:gpu_blas_lt", - "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_helpers_header", "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:scoped_activate_context", - "//xla/stream_executor/platform", - "//xla/stream_executor/platform:dso_loader", + "//xla/stream_executor/gpu:gpu_types_header", + "//xla/stream_executor/platform:initialize", "//xla/tsl/util:determinism_hdr_lib", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", @@ -326,95 +487,108 @@ cc_library( "@eigen_archive//:eigen3", "@local_config_rocm//rocm:rocm_headers", "@local_tsl//tsl/platform:logging", - ]), + ], alwayslink = True, ) cc_library( name = "hipfft_if_static", - deps = if_static([ - ":hipfft_if_rocm_configured", + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", ]), -) - -cc_library( - name = "hipfft_if_rocm_configured", - deps = if_rocm_is_configured([ + deps = if_static([ "@local_config_rocm//rocm:hipfft", ]), ) cc_library( name = "hipfft_plugin", - srcs = if_rocm_is_configured(["rocm_fft.cc"]), - hdrs = if_rocm_is_configured(["rocm_fft.h"]), + srcs = ["rocm_fft.cc"], + hdrs = ["rocm_fft.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), visibility = ["//visibility:public"], - deps = if_rocm_is_configured([ - # keep sorted + deps = [ ":hipfft_if_static", ":rocm_complex_converters", ":rocm_platform_id", - "//xla/stream_executor", + "//xla/stream_executor:activate_context", + "//xla/stream_executor:device_memory", "//xla/stream_executor:fft", "//xla/stream_executor:plugin_registry", - "//xla/stream_executor/gpu:gpu_executor_header", + "//xla/stream_executor:scratch_allocator", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:gpu_helpers_header", "//xla/stream_executor/gpu:gpu_kernel_header", "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:scoped_activate_context", - "//xla/stream_executor/platform", - "//xla/stream_executor/platform:dso_loader", + "//xla/stream_executor/platform:initialize", "@local_config_rocm//rocm:rocm_headers", + "@local_tsl//tsl/platform:dso_loader", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", - ]), + ], alwayslink = True, ) cc_library( name = "miopen_if_static", - deps = if_static([ - ":miopen_if_rocm_configured", + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", ]), -) - -cc_library( - name = "miopen_if_rocm_configured", - deps = if_rocm_is_configured([ + deps = if_static([ "@local_config_rocm//rocm:miopen", ]), ) cc_library( name = "miopen_plugin", - srcs = if_rocm_is_configured(["rocm_dnn.cc"]), - hdrs = if_rocm_is_configured(["rocm_dnn.h"]), + srcs = ["rocm_dnn.cc"], + hdrs = ["rocm_dnn.h"], copts = [ # STREAM_EXECUTOR_CUDNN_WRAP would fail on Clang with the default # setting of template depth 256 "-ftemplate-depth-512", ], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), visibility = ["//visibility:public"], - deps = if_rocm_is_configured([ - # keep sorted - ":miopen_if_static", - ":rocm_diagnostics", - ":rocm_driver", + deps = [ + ":miopen_if_static", # build_cleaner: keep + ":rocm_diagnostics", # build_cleaner: keep + ":rocm_driver", # build_cleaner: keep ":rocm_executor", ":rocm_helpers", ":rocm_platform_id", - "//xla/stream_executor", + "//xla/stream_executor:activate_context", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:dnn", "//xla/stream_executor:event_based_timer", "//xla/stream_executor:plugin_registry", + "//xla/stream_executor:scratch_allocator", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor/gpu:gpu_driver_header", - "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:gpu_types_header", - "//xla/stream_executor/gpu:scoped_activate_context", - "//xla/stream_executor/platform", - "//xla/stream_executor/platform:dso_loader", + "//xla/stream_executor/platform:initialize", "//xla/tsl/util:determinism_for_kernels", "//xla/tsl/util:env_var", "@com_google_absl//absl/algorithm:container", @@ -425,120 +599,146 @@ cc_library( "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", "@local_config_rocm//rocm:rocm_headers", + "@local_tsl//tsl/platform:dso_loader", "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_impl", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:hash", "@local_tsl//tsl/platform:logging", - ]), + ], alwayslink = True, ) cc_library( name = "hiprand_if_static", - deps = if_static([ - ":hiprand_if_rocm_configured", + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", ]), -) - -cc_library( - name = "hiprand_if_rocm_configured", - deps = if_rocm_is_configured([ + deps = if_static([ "@local_config_rocm//rocm:hiprand", ]), ) cc_library( name = "hipsparse_if_static", - deps = if_static([ - ":hipsparse_if_rocm_configured", + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", ]), -) - -cc_library( - name = "hipsparse_if_rocm_configured", - deps = if_rocm_is_configured([ + deps = if_static([ "@local_config_rocm//rocm:hipsparse", ]), ) cc_library( name = "hipsparse_wrapper", - srcs = if_rocm_is_configured(["hipsparse_wrapper.h"]), - hdrs = if_rocm_is_configured(["hipsparse_wrapper.h"]), - deps = if_rocm_is_configured([ + srcs = ["hipsparse_wrapper.h"], + hdrs = ["hipsparse_wrapper.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ ":hipsparse_if_static", ":rocm_executor", ":rocm_platform_id", "@local_config_rocm//rocm:rocm_headers", - "//xla/stream_executor/platform", - "//xla/stream_executor/platform:dso_loader", + "@local_tsl//tsl/platform", + "@local_tsl//tsl/platform:dso_loader", "@local_tsl//tsl/platform:env", - ]), + ], alwayslink = True, ) cc_library( name = "rocsolver_if_static", - deps = if_static([ - ":rocsolver_if_rocm_configured", + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", ]), -) - -cc_library( - name = "rocsolver_if_rocm_configured", - deps = if_rocm_is_configured([ + deps = if_static([ "@local_config_rocm//rocm:rocsolver", ]), ) cc_library( name = "rocsolver_wrapper", - srcs = if_rocm_is_configured(["rocsolver_wrapper.h"]), - hdrs = if_rocm_is_configured(["rocsolver_wrapper.h"]), - deps = if_rocm_is_configured([ + srcs = ["rocsolver_wrapper.h"], + hdrs = ["rocsolver_wrapper.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ ":rocm_executor", ":rocm_platform_id", ":rocsolver_if_static", "@local_config_rocm//rocm:rocm_headers", - "//xla/stream_executor/platform", - "//xla/stream_executor/platform:dso_loader", + "@local_tsl//tsl/platform:dso_loader", "@local_tsl//tsl/platform:env", - ]), + ], alwayslink = True, ) cc_library( name = "hipsolver_if_static", - deps = if_static([ - ":hipsolver_if_rocm_configured", + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", ]), -) - -cc_library( - name = "hipsolver_if_rocm_configured", - deps = if_rocm_is_configured([ + deps = if_static([ "@local_config_rocm//rocm:hipsolver", ]), ) cc_library( name = "hipsolver_wrapper", - hdrs = if_rocm_is_configured(["hipsolver_wrapper.h"]), - deps = if_rocm_is_configured([ + hdrs = ["hipsolver_wrapper.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ + ":hipsolver_if_static", ":rocm_executor", ":rocm_platform_id", - ":hipsolver_if_static", "@local_config_rocm//rocm:rocm_headers", - "//xla/stream_executor/platform", - "//xla/stream_executor/platform:dso_loader", + "@local_tsl//tsl/platform:dso_loader", "@local_tsl//tsl/platform:env", - ]), + ], alwayslink = True, ) cc_library( name = "hipblaslt_if_static", + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), deps = if_rocm_hipblaslt([ "@local_config_rocm//rocm:hipblaslt", ]), @@ -546,14 +746,21 @@ cc_library( cc_library( name = "amdhipblaslt_plugin", - srcs = if_rocm_is_configured(["hip_blas_lt.cc"]), - hdrs = if_rocm_is_configured([ + srcs = ["hip_blas_lt.cc"], + hdrs = [ "hip_blas_lt.h", - "hipblaslt_wrapper.h", "hip_blas_utils.h", + "hipblaslt_wrapper.h", + ], + defines = {"__HIP_DISABLE_CPP_FUNCTIONS__": "1"}, + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", ]), - deps = if_rocm_is_configured([ - # keep sorted + deps = [ ":hip_blas_utils", ":hipblas_lt_header", ":rocblas_plugin", @@ -563,22 +770,23 @@ cc_library( "//xla:status_macros", "//xla:types", "//xla:util", - "//xla/stream_executor", + "//xla/stream_executor:activate_context", "//xla/stream_executor:blas", + "//xla/stream_executor:device_memory", "//xla/stream_executor:event_based_timer", "//xla/stream_executor:host_or_device_scalar", + "//xla/stream_executor:scratch_allocator", + "//xla/stream_executor:stream", "//xla/stream_executor/gpu:gpu_blas_lt", "//xla/stream_executor/gpu:gpu_helpers_header", "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:scoped_activate_context", - "//xla/stream_executor/platform", - "//xla/stream_executor/platform:dso_loader", "@com_google_absl//absl/status", "@local_config_rocm//rocm:rocm_headers", + "@local_tsl//tsl/platform:dso_loader", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", - ]) + if_static([ + ] + if_static([ ":hipblaslt_if_static", ]), alwayslink = True, @@ -586,135 +794,169 @@ cc_library( cc_library( name = "hipblas_lt_header", - hdrs = if_rocm_is_configured([ + hdrs = [ "hip_blas_lt.h", - "hipblaslt_wrapper.h", "hip_blas_utils.h", + "hipblaslt_wrapper.h", + ], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", ]), visibility = ["//visibility:public"], - deps = if_rocm_is_configured([ - # keep sorted + deps = [ "//xla:types", - "//xla/stream_executor", "//xla/stream_executor:blas", + "//xla/stream_executor:device_memory", "//xla/stream_executor:host_or_device_scalar", "//xla/stream_executor/gpu:gpu_blas_lt", - "//xla/stream_executor/platform", - "//xla/stream_executor/platform:dso_loader", "@com_google_absl//absl/status", "@local_config_rocm//rocm:rocm_headers", + "@local_tsl//tsl/platform:dso_loader", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", - ]), + ], ) cc_library( name = "hip_blas_utils", - srcs = if_rocm_is_configured(["hip_blas_utils.cc"]), - hdrs = if_rocm_is_configured(["hip_blas_utils.h"]), - deps = if_rocm_is_configured([ - # keep sorted + srcs = ["hip_blas_utils.cc"], + hdrs = ["hip_blas_utils.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ ":hipblas_lt_header", ":rocblas_plugin", - "//xla/stream_executor", "//xla/stream_executor:blas", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@local_config_rocm//rocm:rocm_headers", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", - ]), + ], ) cc_library( name = "roctracer_if_static", - deps = if_static([ - ":roctracer_if_rocm_configured", + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", ]), -) - -cc_library( - name = "roctracer_if_rocm_configured", - deps = if_rocm_is_configured([ + deps = if_static([ "@local_config_rocm//rocm:roctracer", ]), ) cc_library( name = "roctracer_wrapper", - srcs = if_rocm_is_configured(["roctracer_wrapper.h"]), - hdrs = if_rocm_is_configured(["roctracer_wrapper.h"]), - deps = if_rocm_is_configured([ - # keep sorted + srcs = ["roctracer_wrapper.h"], + hdrs = ["roctracer_wrapper.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ ":rocm_executor", ":rocm_platform_id", ":roctracer_if_static", - "//xla/stream_executor/platform", - "//xla/stream_executor/platform:dso_loader", "@local_config_rocm//rocm:rocm_headers", "@local_tsl//tsl/platform", + "@local_tsl//tsl/platform:dso_loader", "@local_tsl//tsl/platform:env", - ]), + ], alwayslink = True, ) rocm_library( name = "rocm_helpers", - srcs = if_rocm_is_configured(["rocm_helpers.cu.cc"]), - deps = if_rocm_is_configured([ + srcs = ["rocm_helpers.cu.cc"], + # copybara:uncomment compatible_with = ["//buildenv/target:non_prod"], + tags = [ + "gpu", + "rocm-only", + ], + deps = [ "@local_config_rocm//rocm:rocm_headers", - ]), + ], alwayslink = True, ) cc_library( name = "rocm_complex_converters", - hdrs = if_rocm_is_configured(["rocm_complex_converters.h"]), - deps = ["@com_google_absl//absl/log:check"] + if_rocm_is_configured([ + hdrs = ["rocm_complex_converters.h"], + tags = [ + "gpu", + "rocm-only", + ], + deps = [ + "@com_google_absl//absl/log:check", "@local_config_rocm//rocm:rocm_headers", - ]), + ], ) cc_library( name = "all_runtime", copts = tsl_copts(), + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), visibility = ["//visibility:public"], - deps = if_rocm_is_configured([ - ":miopen_plugin", + deps = [ + ":amdhipblaslt_plugin", ":hipfft_plugin", + ":miopen_plugin", ":rocblas_plugin", ":rocm_driver", - ":rocm_platform", ":rocm_helpers", - ":amdhipblaslt_plugin", - ]), + ":rocm_platform", + ], alwayslink = 1, ) cc_library( name = "rocm_rpath", - data = [], linkopts = select({ "//conditions:default": [ "-Wl,-rpath,../local_config_rocm/rocm/rocm/lib", ], }), - deps = [], ) cc_library( name = "stream_executor_rocm", + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), deps = [ + ":rocm_platform_id", ":rocm_rpath", - "//xla/stream_executor", "//xla/stream_executor:dnn", "//xla/stream_executor:platform_manager", "//xla/stream_executor:scratch_allocator", "//xla/stream_executor/cuda:cuda_platform_id", "//xla/stream_executor/host:host_platform_id", - "//xla/stream_executor/platform:dso_loader", - "//xla/stream_executor/rocm:rocm_platform_id", ] + if_static( [":all_runtime"], ), @@ -733,7 +975,14 @@ cc_library( cc_test( name = "rocm_version_parser_test", - srcs = if_rocm_is_configured(["rocm_version_parser_test.cc"]), + srcs = ["rocm_version_parser_test.cc"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), deps = [ ":rocm_version_parser", "//xla/stream_executor:semantic_version", @@ -744,3 +993,199 @@ cc_test( "@local_tsl//tsl/platform:test", ], ) + +cc_library( + name = "rocm_stream", + srcs = ["rocm_stream.cc"], + hdrs = ["rocm_stream.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ + ":rocm_driver_wrapper", + ":rocm_event", + ":rocm_kernel", + ":rocm_status", + "//xla/stream_executor:activate_context", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:event", + "//xla/stream_executor:event_based_timer", + "//xla/stream_executor:kernel", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_common", + "@com_google_absl//absl/base", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@local_config_rocm//rocm:rocm_headers", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "rocm_stream_test", + srcs = ["rocm_stream_test.cc"], + backends = ["gpu"], + tags = ["rocm-only"] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ + ":rocm_executor", + ":rocm_platform_id", + ":rocm_stream", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:kernel", + "//xla/stream_executor:kernel_spec", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor:typed_kernel_factory", + "//xla/stream_executor/gpu:gpu_test_kernels", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "rocm_timer", + srcs = ["rocm_timer.cc"], + hdrs = ["rocm_timer.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ + ":rocm_driver_wrapper", + ":rocm_event", + ":rocm_status", + "//xla/stream_executor:activate_context", + "//xla/stream_executor:event_based_timer", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/time", + "@local_config_rocm//rocm:rocm_headers", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "rocm_timer_test", + srcs = ["rocm_timer_test.cc"], + backends = ["gpu"], + tags = ["rocm-only"] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ + ":rocm_executor", + ":rocm_platform_id", + ":rocm_timer", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:kernel", + "//xla/stream_executor:kernel_spec", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream", + "//xla/stream_executor:typed_kernel_factory", + "//xla/stream_executor/gpu:gpu_test_kernels_rocm", + "@com_google_absl//absl/status", + "@com_google_absl//absl/time", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "rocm_status", + srcs = ["rocm_status.cc"], + hdrs = ["rocm_status.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@local_config_rocm//rocm:rocm_headers", + ], +) + +cc_test( + name = "rocm_status_test", + srcs = ["rocm_status_test.cc"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ + ":rocm_status", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest_main", + "@local_config_rocm//rocm:rocm_headers", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "rocm_command_buffer", + srcs = ["rocm_command_buffer.cc"], + hdrs = ["rocm_command_buffer.h"], + tags = [ + "gpu", + "rocm-only", + ] + if_google([ + # TODO(b/360374983): Remove this tag once the target can be built without --config=rocm. + "manual", + ]), + deps = [ + ":rocm_driver_wrapper", + ":rocm_status", + "//xla/stream_executor:command_buffer", + "//xla/stream_executor/gpu:gpu_command_buffer", + "//xla/stream_executor/gpu:gpu_executor_header", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@local_config_rocm//rocm:rocm_headers", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc index 5e04daf134b4d0..9b87f9e4a4e4c0 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc @@ -21,7 +21,7 @@ limitations under the License. #include "rocm/rocm_config.h" #include "xla/primitive_util.h" #include "xla/status_macros.h" -#include "xla/stream_executor/gpu/scoped_activate_context.h" +#include "xla/stream_executor/activate_context.h" #include "xla/util.h" #if TF_HIPBLASLT @@ -217,7 +217,8 @@ auto BlasLt::MatmulPlan::GetAlgorithms(size_t max_algorithm_count, hip_preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, max_workspace_size)); - gpu::ScopedActivateContext sac{blas_lt_ref_.parent_}; + std::unique_ptr activation = + blas_lt_ref_.parent_->Activate(); // hipBlasLt requires setting the bias pointer (even a dummy one), otherwise // no algorithms can be found for "bias epilogues". This is to be removed @@ -459,7 +460,8 @@ absl::Status BlasLt::MatmulPlan::DoMatmul( "hipblaslt does not support auxiliary inputs / outputs"); } - gpu::ScopedActivateContext sac{blas_lt_ref_.parent_}; + std::unique_ptr activation = + blas_lt_ref_.parent_->Activate(); if (palgo != nullptr) { SE_HIPBLAS_RETURN_IF_ERROR(wrap::hipblasLtMatmul( diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h index c839f51866ac0b..b781e2848cd479 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h +++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h @@ -28,10 +28,6 @@ limitations under the License. namespace stream_executor { -namespace gpu { -class GpuExecutor; -} // namespace gpu - namespace rocm { class BlasLt : public gpu::BlasLt { @@ -149,7 +145,7 @@ class BlasLt : public gpu::BlasLt { bool must_swap_operands_; }; // class MatmulPlan - explicit BlasLt(gpu::GpuExecutor* parent) + explicit BlasLt(StreamExecutor* parent) : parent_(parent), blas_lt_(nullptr, wrap::hipblasLtDestroy) {} absl::Status Init() override; @@ -160,7 +156,7 @@ class BlasLt : public gpu::BlasLt { ~BlasLt() override = default; private: - gpu::GpuExecutor* parent_; + StreamExecutor* parent_; mutable absl::Mutex mu_; Owned blas_lt_ ABSL_GUARDED_BY(mu_); }; diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.cc b/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.cc index 5aeadfeebda23a..06d62d4c512643 100644 --- a/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.cc +++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_utils.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/stream_executor/rocm/hip_blas_utils.h" +#include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "xla/stream_executor/blas.h" #include "rocm/rocm_config.h" @@ -36,8 +37,11 @@ absl::Status ToStatus(hipblasStatus_t status, const char* prefix) { hipDataType AsHipblasDataType(blas::DataType type) { switch (type) { case blas::DataType::kF8E5M2: + case blas::DataType::kF8E4M3: case blas::DataType::kF8E4M3FN: - LOG(FATAL) << "hipblaslt does not support F8E5M2 and F8E4M3FN"; + case blas::DataType::kF8E3M4: + LOG(FATAL) + << "hipblaslt does not support F8E5M2, F8E4M3, F8E4M3FN and F8E3M4"; #if TF_ROCM_VERSION >= 60000 case blas::DataType::kF8E5M2FNUZ: return HIP_R_8F_E5M2_FNUZ; diff --git a/third_party/xla/xla/stream_executor/rocm/hipblaslt_wrapper.h b/third_party/xla/xla/stream_executor/rocm/hipblaslt_wrapper.h index c53cff6a933913..09cac948f0d185 100644 --- a/third_party/xla/xla/stream_executor/rocm/hipblaslt_wrapper.h +++ b/third_party/xla/xla/stream_executor/rocm/hipblaslt_wrapper.h @@ -17,8 +17,6 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_ROCM_HIPBLASLT_WRAPPER_H_ #define XLA_STREAM_EXECUTOR_ROCM_HIPBLASLT_WRAPPER_H_ -#define __HIP_DISABLE_CPP_FUNCTIONS__ - #include "rocm/rocm_config.h" #if TF_HIPBLASLT @@ -27,8 +25,7 @@ limitations under the License. #else #include "rocm/include/hipblaslt.h" #endif -#include "xla/stream_executor/platform/dso_loader.h" -#include "xla/stream_executor/platform/port.h" +#include "tsl/platform/dso_loader.h" #include "tsl/platform/env.h" namespace stream_executor { @@ -47,22 +44,21 @@ namespace wrap { #define TO_STR_(x) #x #define TO_STR(x) TO_STR_(x) -#define HIPBLASLT_API_WRAPPER(api_name) \ - template \ - auto api_name(Args... args) -> decltype(::api_name(args...)) { \ - using FuncPtrT = std::add_pointer::type; \ - static FuncPtrT loaded = []() -> FuncPtrT { \ - static const char* kName = TO_STR(api_name); \ - void* f; \ - auto s = tsl::Env::Default() -> GetSymbolFromLibrary( \ - stream_executor::internal::CachedDsoLoader::GetHipblasltDsoHandle() \ - .value(), \ - kName, &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in hipblaslt lib; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - }(); \ - return loaded(args...); \ +#define HIPBLASLT_API_WRAPPER(api_name) \ + template \ + auto api_name(Args... args) -> decltype(::api_name(args...)) { \ + using FuncPtrT = std::add_pointer::type; \ + static FuncPtrT loaded = []() -> FuncPtrT { \ + static const char* kName = TO_STR(api_name); \ + void* f; \ + auto s = tsl::Env::Default()->GetSymbolFromLibrary( \ + tsl::internal::CachedDsoLoader::GetHipblasltDsoHandle().value(), \ + kName, &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in hipblaslt lib; dlerror: " << s.message(); \ + return reinterpret_cast(f); \ + }(); \ + return loaded(args...); \ } #endif diff --git a/third_party/xla/xla/stream_executor/rocm/hipsolver_wrapper.h b/third_party/xla/xla/stream_executor/rocm/hipsolver_wrapper.h index 8434ae03b96685..10bd1cb767a319 100644 --- a/third_party/xla/xla/stream_executor/rocm/hipsolver_wrapper.h +++ b/third_party/xla/xla/stream_executor/rocm/hipsolver_wrapper.h @@ -29,8 +29,7 @@ limitations under the License. #include "rocm/include/hipsolver.h" #endif -#include "xla/stream_executor/platform/dso_loader.h" -#include "xla/stream_executor/platform/port.h" +#include "tsl/platform/dso_loader.h" #include "tsl/platform/env.h" namespace stream_executor { @@ -49,22 +48,21 @@ namespace wrap { #define TO_STR_(x) #x #define TO_STR(x) TO_STR_(x) -#define HIPSOLVER_API_WRAPPER(api_name) \ - template \ - auto api_name(Args... args) -> decltype(::api_name(args...)) { \ - using FuncPtrT = std::add_pointer::type; \ - static FuncPtrT loaded = []() -> FuncPtrT { \ - static const char* kName = TO_STR(api_name); \ - void* f; \ - auto s = tsl::Env::Default() -> GetSymbolFromLibrary( \ - stream_executor::internal::CachedDsoLoader::GetHipsolverDsoHandle() \ - .value(), \ - kName, &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in hipsolver lib; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - }(); \ - return loaded(args...); \ +#define HIPSOLVER_API_WRAPPER(api_name) \ + template \ + auto api_name(Args... args) -> decltype(::api_name(args...)) { \ + using FuncPtrT = std::add_pointer::type; \ + static FuncPtrT loaded = []() -> FuncPtrT { \ + static const char* kName = TO_STR(api_name); \ + void* f; \ + auto s = tsl::Env::Default()->GetSymbolFromLibrary( \ + tsl::internal::CachedDsoLoader::GetHipsolverDsoHandle().value(), \ + kName, &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in hipsolver lib; dlerror: " << s.message(); \ + return reinterpret_cast(f); \ + }(); \ + return loaded(args...); \ } #endif diff --git a/third_party/xla/xla/stream_executor/rocm/hipsparse_wrapper.h b/third_party/xla/xla/stream_executor/rocm/hipsparse_wrapper.h index d66b2426d50d84..213ef7e2e7c89d 100644 --- a/third_party/xla/xla/stream_executor/rocm/hipsparse_wrapper.h +++ b/third_party/xla/xla/stream_executor/rocm/hipsparse_wrapper.h @@ -27,10 +27,9 @@ limitations under the License. #else #include "rocm/include/hipsparse.h" #endif -#include "xla/stream_executor/platform/dso_loader.h" -#include "xla/stream_executor/platform/platform.h" -#include "xla/stream_executor/platform/port.h" +#include "tsl/platform/dso_loader.h" #include "tsl/platform/env.h" +#include "tsl/platform/platform.h" namespace stream_executor { namespace wrap { @@ -48,31 +47,30 @@ namespace wrap { #else -#define HIPSPARSE_API_WRAPPER(__name) \ - static struct DynLoadShim__##__name { \ - constexpr static const char* kName = #__name; \ - using FuncPtrT = std::add_pointer::type; \ - static void* GetDsoHandle() { \ - auto s = \ - stream_executor::internal::CachedDsoLoader::GetHipsparseDsoHandle(); \ - return s.value(); \ - } \ - static FuncPtrT LoadOrDie() { \ - void* f; \ - auto s = tsl::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \ - kName, &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in miopen DSO; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - } \ - static FuncPtrT DynLoad() { \ - static FuncPtrT f = LoadOrDie(); \ - return f; \ - } \ - template \ - hipsparseStatus_t operator()(Args... args) { \ - return DynLoad()(args...); \ - } \ +#define HIPSPARSE_API_WRAPPER(__name) \ + static struct DynLoadShim__##__name { \ + constexpr static const char* kName = #__name; \ + using FuncPtrT = std::add_pointer::type; \ + static void* GetDsoHandle() { \ + auto s = tsl::internal::CachedDsoLoader::GetHipsparseDsoHandle(); \ + return s.value(); \ + } \ + static FuncPtrT LoadOrDie() { \ + void* f; \ + auto s = tsl::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \ + kName, &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in miopen DSO; dlerror: " << s.message(); \ + return reinterpret_cast(f); \ + } \ + static FuncPtrT DynLoad() { \ + static FuncPtrT f = LoadOrDie(); \ + return f; \ + } \ + template \ + hipsparseStatus_t operator()(Args... args) { \ + return DynLoad()(args...); \ + } \ } __name; #endif diff --git a/third_party/xla/xla/stream_executor/rocm/rocblas_wrapper.h b/third_party/xla/xla/stream_executor/rocm/rocblas_wrapper.h index 21d4301b543680..26e35cff9cf67a 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocblas_wrapper.h +++ b/third_party/xla/xla/stream_executor/rocm/rocblas_wrapper.h @@ -25,8 +25,7 @@ limitations under the License. #include "rocm/include/rocblas/rocblas.h" #include "rocm/rocm_config.h" -#include "xla/stream_executor/platform/dso_loader.h" -#include "xla/stream_executor/platform/port.h" +#include "tsl/platform/dso_loader.h" #include "tsl/platform/env.h" #include "tsl/platform/platform.h" @@ -44,7 +43,7 @@ namespace wrap { } __name; #else -using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle; +using tsl::internal::CachedDsoLoader::GetRocblasDsoHandle; #define ROCBLAS_API_WRAPPER(__name) \ static struct DynLoadShim__##__name { \ @@ -263,6 +262,7 @@ using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle; __macro(rocblas_gemm_batched_ex_get_solutions) \ __macro(rocblas_gemm_batched_ex_get_solutions_by_type) \ __macro(rocblas_gemm_strided_batched_ex_get_solutions) \ +<<<<<<< HEAD __macro(rocblas_strsm_batched) \ __macro(rocblas_dtrsm_batched) \ __macro(rocblas_ctrsm_batched) \ @@ -275,23 +275,27 @@ using stream_executor::internal::CachedDsoLoader::GetRocblasDsoHandle; __macro(rocblas_is_managing_device_memory) \ __macro(rocblas_is_user_managing_device_memory) \ __macro(rocblas_set_workspace) +======= + __macro(rocblas_is_managing_device_memory) \ + __macro(rocblas_is_user_managing_device_memory) \ + __macro(rocblas_set_workspace) \ + __macro(rocblas_strsm_batched) \ + __macro(rocblas_dtrsm_batched) \ + __macro(rocblas_ctrsm_batched) \ + __macro(rocblas_ztrsm_batched) \ + __macro(rocblas_create_handle) \ + __macro(rocblas_destroy_handle) \ + __macro(rocblas_get_stream) \ + __macro(rocblas_set_stream) \ + __macro(rocblas_set_atomics_mode) \ + __macro(rocblas_get_version_string_size) \ + __macro(rocblas_get_version_string) +>>>>>>> master // clang-format on FOREACH_ROCBLAS_API(ROCBLAS_API_WRAPPER) -#if TF_ROCM_VERSION >= 60200 - -// clang-format off -#define FOREACH_ROCBLAS_API_62(__macro) \ - __macro(rocblas_get_version_string_size) \ - __macro(rocblas_get_version_string) -// clang-format on - -FOREACH_ROCBLAS_API_62(ROCBLAS_API_WRAPPER) - -#endif // TF_ROCM_VERSION >= 60200 - } // namespace wrap } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc b/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc index b769397193d978..f720ff459d875b 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_blas.cc @@ -15,29 +15,28 @@ limitations under the License. #include "xla/stream_executor/rocm/rocm_blas.h" -#include "xla/stream_executor/gpu/scoped_activate_context.h" -#include "xla/stream_executor/rocm/rocblas_wrapper.h" - #define EIGEN_USE_GPU #define EIGEN_USE_HIP #include #include +#include +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "absl/types/span.h" +#include "absl/synchronization/mutex.h" #include "unsupported/Eigen/CXX11/Tensor" #include "rocm/rocm_config.h" +#include "xla/stream_executor/activate_context.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event_based_timer.h" -#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_helpers.h" #include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/platform/dso_loader.h" +#include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/platform/initialize.h" -#include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/plugin_registry.h" +#include "xla/stream_executor/rocm/rocblas_wrapper.h" #include "xla/stream_executor/rocm/rocm_complex_converters.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/scratch_allocator.h" @@ -109,7 +108,7 @@ static std::string ToString(rocblas_status status) { } bool ROCMBlas::Init() { - ScopedActivateContext sac{parent_}; + std::unique_ptr activation = parent_->Activate(); rocblas_status ret = wrap::rocblas_create_handle(&blas_); if (ret != rocblas_status_success) { LOG(ERROR) << "failed to create rocBLAS handle: " << ToString(ret); @@ -136,7 +135,7 @@ bool ROCMBlas::Init() { return true; } -ROCMBlas::ROCMBlas(gpu::GpuExecutor *parent) +ROCMBlas::ROCMBlas(StreamExecutor *parent) : parent_(CHECK_NOTNULL(parent)), blas_(nullptr) #if TF_HIPBLASLT @@ -148,32 +147,31 @@ ROCMBlas::ROCMBlas(gpu::GpuExecutor *parent) ROCMBlas::~ROCMBlas() { if (blas_ != nullptr) { - ScopedActivateContext sac{parent_}; + std::unique_ptr activation = parent_->Activate(); wrap::rocblas_destroy_handle(blas_); } } bool ROCMBlas::SetStream(Stream *stream) { - CHECK(stream != nullptr); - CHECK(AsGpuStreamValue(stream) != nullptr); CHECK(blas_ != nullptr); - ScopedActivateContext sac{parent_}; - - rocblas_status ret = - wrap::rocblas_set_stream(blas_, AsGpuStreamValue(stream)); - if (ret != rocblas_status_success) { + auto handle = (stream != nullptr) ? AsGpuStreamValue(stream) : nullptr; + if (auto ret = wrap::rocblas_set_stream(blas_, handle); + ret != rocblas_status_success) { LOG(ERROR) << "failed to set stream for rocBLAS calls: " << ToString(ret); return false; } - return true; } -hipStream_t ROCMBlas::ROCMStream(Stream *stream) { - CHECK(stream != nullptr); - CHECK(AsGpuStreamValue(stream) != nullptr); - ScopedActivateContext sac{parent_}; - return AsGpuStreamValue(stream); +absl::StatusOr ROCMBlas::IsMainStreamSet() const { + absl::MutexLock lock{&mu_}; + CHECK(blas_ != nullptr); + GpuStreamHandle handle{}; + if (auto ret = wrap::rocblas_get_stream(blas_, &handle); + ret != rocblas_status_success) { + return absl::InternalError("failed to get the current stream value"); + } + return (handle == nullptr); } namespace { @@ -351,11 +349,11 @@ absl::Status ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, absl::MutexLock lock{&mu_}; CHECK(blas_ != nullptr); + std::unique_ptr activation = parent_->Activate(); if (!SetStream(stream)) { return absl::InternalError("Setting stream failed"); } - ScopedActivateContext sac{parent_}; rocblas_status ret; // set the atomics mode, leaving default to library bool allow_atomics = !OpDeterminismRequired(); @@ -383,6 +381,8 @@ absl::Status ROCMBlas::DoBlasInternalImpl(FuncT rocblas_func, Stream *stream, #endif ret = rocblas_func(blas_, std::forward(args)...); + SetStream(nullptr); // Resetting stream after the function call + if (ret != rocblas_status_success) { auto err_str = absl::StrFormat("%s failed with: %s", FuncT::kName, ToString(ret)); @@ -897,7 +897,7 @@ absl::StatusOr> AllocateStridedBuffer( if (scratch_allocator == nullptr) { return absl::InternalError("scratch_allocator is null"); } - TF_ASSIGN_OR_RETURN(DeviceMemory batch_matrix_bytes, + TF_ASSIGN_OR_RETURN(DeviceMemory batch_matrix_bytes, scratch_allocator->AllocateBytes(matrix_batch_byte_size)); res.device_mem = DeviceMemory(batch_matrix_bytes); res.reallocated = true; @@ -1111,9 +1111,9 @@ bool ROCMBlas::DoBlasGemmBatched( #define IMPL_DoBlasGemmBatched(T, Fun) \ bool ROCMBlas::DoBlasGemmBatched( \ Stream *stream, blas::Transpose transa, blas::Transpose transb, \ - uint64_t m, uint64_t n, uint64 k, T alpha, DeviceMemorySlice a_array, \ - int lda, DeviceMemorySlice b_array, int ldb, T beta, \ - DeviceMemorySlice c_array, int ldc, int batch_count, \ + uint64_t m, uint64_t n, uint64_t k, T alpha, \ + DeviceMemorySlice a_array, int lda, DeviceMemorySlice b_array, \ + int ldb, T beta, DeviceMemorySlice c_array, int ldc, int batch_count, \ const NumericOptions &numeric_options, \ ScratchAllocator *scratch_allocator, blas::CallContext context) { \ MaybeLogGemmOp(GemmCallTrace::GemmType::kBatched, context, a_array.size(), \ @@ -1136,7 +1136,7 @@ IMPL_DoBlasGemmBatched(float, wrap::rocblas_sgemm_strided_batched) #define IMPL_DoBlasTrsm(T, Fun, Fun2) \ bool ROCMBlas::DoBlasTrsm(Stream *stream, blas::Side side, \ blas::UpperLower uplo, blas::Transpose transa, \ - blas::Diagonal diag, uint64_t m, uint64 n, \ + blas::Diagonal diag, uint64_t m, uint64_t n, \ T alpha, const DeviceMemory &a, int lda, \ DeviceMemory *b, int ldb) { \ return DoBlasInternal(Fun, stream, /* pointer_mode_host = */ true, \ @@ -1148,7 +1148,7 @@ IMPL_DoBlasGemmBatched(float, wrap::rocblas_sgemm_strided_batched) \ bool ROCMBlas::DoBlasTrsmBatched( \ Stream *stream, blas::Side side, blas::UpperLower uplo, \ - blas::Transpose transa, blas::Diagonal diag, uint64_t m, uint64 n, \ + blas::Transpose transa, blas::Diagonal diag, uint64_t m, uint64_t n, \ T alpha, const DeviceMemory &as, int lda, DeviceMemory *bs, \ int ldb, int batch_count) { \ return DoBlasInternal(Fun2, stream, true /* = pointer_mode_host */, \ @@ -1241,8 +1241,11 @@ IMPL_DoBlasGemmBatched(float, wrap::rocblas_sgemm_strided_batched) } absl::Status ROCMBlas::GetVersion(std::string *version) { +<<<<<<< HEAD #if TF_ROCM_VERSION >= 60200 // Not available in ROCM-6.1 +======= +>>>>>>> master absl::MutexLock lock{&mu_}; size_t len = 0; if (auto res = wrap::rocblas_get_version_string_size(&len); @@ -1258,9 +1261,6 @@ absl::Status ROCMBlas::GetVersion(std::string *version) { } *version = std::string(buf.begin(), buf.end()); return absl::OkStatus(); -#else - return absl::UnimplementedError(""); -#endif } } // namespace gpu @@ -1275,17 +1275,7 @@ void initialize_rocblas() { ->RegisterFactory( rocm::kROCmPlatformId, "rocBLAS", [](StreamExecutor *parent) -> blas::BlasSupport * { - gpu::GpuExecutor *rocm_executor = - dynamic_cast(parent); - if (rocm_executor == nullptr) { - LOG(ERROR) - << "Attempting to initialize an instance of the " - "rocBLAS " - << "support library with a non-ROCM StreamExecutor"; - return nullptr; - } - - gpu::ROCMBlas *blas = new gpu::ROCMBlas(rocm_executor); + gpu::ROCMBlas *blas = new gpu::ROCMBlas(parent); if (!blas->Init()) { // Note: Init() will log a more specific error. delete blas; diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_blas.h b/third_party/xla/xla/stream_executor/rocm/rocm_blas.h index 6199d0e551a815..48a3576293b59e 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_blas.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_blas.h @@ -33,7 +33,6 @@ limitations under the License. #endif #include "xla/stream_executor/blas.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" -#include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/plugin_registry.h" #if TF_HIPBLASLT #include "xla/stream_executor/rocm/hip_blas_lt.h" @@ -75,21 +74,19 @@ using RocBlasType_t = rocblas_float_complex, std::complex, rocblas_double_complex>::type; -class GpuExecutor; - // BLAS plugin for ROCM platform via rocBLAS library. // // This satisfies the platform-agnostic BlasSupport interface. // // Note that the rocBLAS handle that this encapsulates is implicitly tied to the -// context (and, as a result, the device) that the parent GpuExecutor is tied +// context (and, as a result, the device) that the parent StreamExecutor is tied // to. This simply happens as an artifact of creating the rocBLAS handle when a // ROCM context is active. // // Thread-safe post-initialization. class ROCMBlas : public blas::BlasSupport { public: - explicit ROCMBlas(GpuExecutor *parent); + explicit ROCMBlas(StreamExecutor *parent); // Allocates a rocBLAS handle. bool Init(); @@ -115,9 +112,6 @@ class ROCMBlas : public blas::BlasSupport { // invoked before calling into rocBLAS. bool SetStream(Stream *stream) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); - // Returns the underlying ROCm stream - hipStream_t ROCMStream(Stream *stream); - // A helper function that calls the real rocBLAS function together with error // handling. // @@ -188,11 +182,11 @@ class ROCMBlas : public blas::BlasSupport { ScratchAllocator *scratch_allocator); // mutex that guards the rocBLAS handle for this device. - absl::Mutex mu_; + mutable absl::Mutex mu_; - // GpuExecutor which instantiated this ROCMBlas. + // StreamExecutor which instantiated this ROCMBlas. // Immutable post-initialization. - GpuExecutor *parent_; + StreamExecutor *parent_; // rocBLAS library handle on the device. rocblas_handle blas_ ABSL_GUARDED_BY(mu_); diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.cc b/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.cc new file mode 100644 index 00000000000000..59339fa1a6cf00 --- /dev/null +++ b/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.cc @@ -0,0 +1,88 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/rocm/rocm_command_buffer.h" + +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "rocm/include/hip/hip_runtime.h" +#include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/gpu/gpu_command_buffer.h" +#include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/rocm/rocm_driver_wrapper.h" +#include "xla/stream_executor/rocm/rocm_status.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace stream_executor::gpu { +namespace { +absl::StatusOr CreateGraph() { + VLOG(2) << "Create new HIP graph"; + hipGraph_t graph; + TF_RETURN_IF_ERROR(ToStatus(wrap::hipGraphCreate(&graph, /*flags=*/0), + "Failed to create HIP graph")); + VLOG(2) << "Created HIP graph " << graph; + return graph; +} +} // namespace + +absl::StatusOr> RocmCommandBuffer::Create( + Mode mode, GpuExecutor* parent) { + TF_ASSIGN_OR_RETURN(hipGraph_t graph, CreateGraph()); + return std::unique_ptr( + new RocmCommandBuffer(mode, parent, graph, + /*is_owned_graph=*/true)); +} + +std::unique_ptr RocmCommandBuffer::CreateNestedCommandBuffer( + hipGraph_t graph) { + return std::unique_ptr( + new RocmCommandBuffer(Mode::kNested, parent_, graph, + /*is_owned_graph=*/false)); +} + +absl::StatusOr +RocmCommandBuffer::GetSetIfConditionKernel() { + return absl::UnimplementedError("Conditionals are not supported on ROCM."); +} + +absl::StatusOr +RocmCommandBuffer::GetSetIfElseConditionKernel() { + return absl::UnimplementedError("Conditionals are not supported on ROCM."); +} + +absl::StatusOr +RocmCommandBuffer::GetSetCaseConditionKernel() { + return absl::UnimplementedError("Conditionals are not supported on ROCM."); +} + +absl::StatusOr +RocmCommandBuffer::GetSetForConditionKernel() { + return absl::UnimplementedError("Conditionals are not supported on ROCM."); +} + +absl::StatusOr +RocmCommandBuffer::GetSetWhileConditionKernel() { + return absl::UnimplementedError("Conditionals are not supported on ROCM."); +} + +absl::StatusOr +RocmCommandBuffer::GetNoOpKernel() { + return absl::UnimplementedError("Conditionals are not supported on ROCM."); +} +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.h b/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.h new file mode 100644 index 00000000000000..17c095a9d69772 --- /dev/null +++ b/third_party/xla/xla/stream_executor/rocm/rocm_command_buffer.h @@ -0,0 +1,58 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_ROCM_ROCM_COMMAND_BUFFER_H_ +#define XLA_STREAM_EXECUTOR_ROCM_ROCM_COMMAND_BUFFER_H_ + +#include + +#include "absl/status/statusor.h" +#include "rocm/include/hip/hip_runtime.h" +#include "xla/stream_executor/gpu/gpu_command_buffer.h" +#include "xla/stream_executor/gpu/gpu_executor.h" + +namespace stream_executor::gpu { + +// Implements GpuCommandBuffer for AMD GPUs. +class RocmCommandBuffer : public GpuCommandBuffer { + public: + // Creates a new ROCm command buffer and the underlying HIP graph. + static absl::StatusOr> Create( + Mode mode, GpuExecutor* parent); + + private: + RocmCommandBuffer(Mode mode, GpuExecutor* parent, hipGraph_t graph, + bool is_owned_graph) + : GpuCommandBuffer(mode, parent, graph, is_owned_graph), + parent_(parent) {} + + absl::StatusOr GetSetIfConditionKernel() override; + absl::StatusOr GetSetIfElseConditionKernel() + override; + absl::StatusOr GetSetCaseConditionKernel() override; + absl::StatusOr GetSetForConditionKernel() override; + absl::StatusOr GetSetWhileConditionKernel() + override; + absl::StatusOr GetNoOpKernel() override; + + std::unique_ptr CreateNestedCommandBuffer( + hipGraph_t graph) override; + + GpuExecutor* parent_; +}; + +} // namespace stream_executor::gpu + +#endif // XLA_STREAM_EXECUTOR_ROCM_ROCM_COMMAND_BUFFER_H_ diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_context.cc b/third_party/xla/xla/stream_executor/rocm/rocm_context.cc new file mode 100644 index 00000000000000..9addbb1a29af4b --- /dev/null +++ b/third_party/xla/xla/stream_executor/rocm/rocm_context.cc @@ -0,0 +1,275 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/rocm/rocm_context.h" + +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "rocm/include/hip/hip_runtime_api.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/gpu/context_map.h" +#include "xla/stream_executor/gpu/scoped_activate_context.h" +#include "xla/stream_executor/rocm/rocm_driver_wrapper.h" +#include "xla/stream_executor/rocm/rocm_status.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status.h" + +namespace stream_executor::gpu { + +namespace { + +// Returns the current context or dies if it fails. +hipCtx_t CurrentContextOrDie() { + hipCtx_t current = nullptr; + TF_CHECK_OK( + ToStatus(hipCtxGetCurrent(¤t), "Failed to query current context")); + return current; +} + +// Returns the current context and checks that it is in the set of HIP contexts +// created by StreamExecutor (to ensure that the HIP runtime didn't create a +// context behind our backs). +hipCtx_t CurrentContext() { + hipCtx_t current = CurrentContextOrDie(); + if (current != nullptr && !RocmContext::GetContextMap()->Has(current)) { + LOG(FATAL) << "current context was not created by the StreamExecutor " + "rocm_driver API: " + << current + << "; a HIP runtime call " + "was likely performed without using a StreamExecutor context"; + } + return current; +} + +// Returns the amount of memory reserved by ROCm libraries. +bool GetReservedMemory(uint64_t* reserve) { + hipDeviceProp_t props; + hipDevice_t dev; + hipError_t res = wrap::hipGetDevice(&dev); + + if (res != hipSuccess) { + LOG(FATAL) << "failed to query current device: " << ToString(res); + return false; + } + res = wrap::hipGetDeviceProperties(&props, dev); + if (res != hipSuccess) { + LOG(ERROR) << "failed to query device properties: " << ToString(res); + return false; + } + + std::string gcnArchName = props.gcnArchName; + auto compute_capability = RocmComputeCapability(gcnArchName); + // On gfx90a, we hide 1 GB of GPU memory (512MB for gfx908) from TF, + // to allow for late allocations by internal ROCm libraries + // (e.g. rocBLAS alone needs~200 MB to put its kernels as of ROCm 4.1) + const uint64_t RESERVED_GFX908 = 1048576 * 512; + const uint64_t RESERVED_GFX9_X = 1048576 * 1024; + const uint64_t RESERVED_GFX10_X = 1048576 * 512; + const uint64_t RESERVED_GFX11_X = 1048576 * 512; + if (compute_capability.gfx9_mi100()) { + *reserve = RESERVED_GFX908; + } else if (compute_capability.gfx9_mi200_or_later()) { + *reserve = RESERVED_GFX9_X; + } else if (compute_capability.gfx10_rx68xx() || + compute_capability.gfx10_rx69xx()) { + *reserve = RESERVED_GFX10_X; + } else if (compute_capability.gfx11_rx7900()) { + *reserve = RESERVED_GFX11_X; + } + + return true; +} + +} // namespace + +// Returns the singleton ContextMap. +ContextMap* RocmContext::GetContextMap() { + static ContextMap* context_map = + new ContextMap([](void* ptr) { + int device_ordinal; + hipError_t result = + hipPointerGetAttribute(static_cast(&device_ordinal), + HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL, + reinterpret_cast(ptr)); + if (result != hipSuccess) { + LOG(FATAL) << "Not able to get the device_ordinal for ptr: " << ptr + << ". Error: " << ToString(result); + } + return device_ordinal; + }); + return context_map; +} + +bool RocmContext::GetDeviceTotalMemory(hipDevice_t device, uint64_t* result) { + size_t value = -1; + hipError_t res = wrap::hipDeviceTotalMem(&value, device); + if (res != hipSuccess) { + LOG(ERROR) << "failed to query total available memory: " << ToString(res); + return false; + } + uint64_t reserve = 0; + if (!GetReservedMemory(&reserve)) { + LOG(ERROR) << "failed to reserved device memory for ROCm libraries"; + return false; + } + *result = value - reserve; + return true; +} + +bool RocmContext::GetDeviceMemoryUsage(int64_t* free_out, int64_t* total_out) { + ScopedActivateContext activation(this); + size_t free = 0; + size_t total = 0; + hipError_t res = wrap::hipMemGetInfo(&free, &total); + if (res != hipSuccess) { + LOG(ERROR) << "failed to query device memory info: " << ToString(res); + return false; + } + + uint64_t reserve = 0; + if (!GetReservedMemory(&reserve)) { + LOG(ERROR) << "failed to reserved device memory for ROCm libraries"; + return false; + } + + VLOG(1) << "Device memory: " << total / 1048576 << " MB total, " + << free / 1048576 << " MB free, reserving " << reserve / 1048576 + << " MB"; + + // overflow check + if (free > std::numeric_limits::max()) { + LOG(ERROR) << "free memory (" << free << ") is overflow int64_t"; + return false; + } + + *free_out = free >= reserve ? free - reserve : 0; + *total_out = total - reserve; + return true; +} + +RocmContext::~RocmContext() { + hipCtx_t former_context = CurrentContext(); + // Explicitly call RocmContext::SetActive() to silence clang-tidy warnings + // about calling a virtual method in the destructor. + RocmContext::SetActive(); + hipDevice_t device; + CHECK_EQ(hipSuccess, wrap::hipCtxGetDevice(&device)); + CHECK_EQ(hipSuccess, wrap::hipCtxSetCurrent(former_context)); + + auto res = wrap::hipDevicePrimaryCtxRelease(device); + + if (res != hipSuccess) { + LOG(ERROR) << "failed to release HIP context; leaking: " << ToString(res); + } + + GetContextMap()->Remove(context()); +} + +void RocmContext::SetActive() { + TF_CHECK_OK( + ToStatus(wrap::hipCtxSetCurrent(context_), "Failed setting context")); +} + +bool RocmContext::IsActive() const { return CurrentContext() == context_; } + +absl::Status RocmContext::Synchronize() { + ScopedActivateContext activation(this); + TF_RETURN_IF_ERROR(ToStatus(wrap::hipDeviceSynchronize(), + "could not synchronize on ROCM device")); + return absl::OkStatus(); +} + +absl::StatusOr RocmContext::Create(int device_ordinal, + hipDevice_t device) { + RocmContext* context = nullptr; + + int flags = 0; + + hipError_t res; + hipCtx_t former_context; + hipCtx_t new_context; + + unsigned int former_primary_context_flags; + int former_primary_context_is_active; + CHECK_EQ(hipSuccess, wrap::hipDevicePrimaryCtxGetState( + device, &former_primary_context_flags, + &former_primary_context_is_active)); + if (former_primary_context_flags != flags) { + if (former_primary_context_is_active) { + LOG(ERROR) + << "The primary context is active and has a different flag set (" + << former_primary_context_flags << ") than the desired flag set (" + << flags << ")."; + } else { + CHECK_EQ(hipSuccess, wrap::hipDevicePrimaryCtxSetFlags(device, flags)); + } + } + + former_context = CurrentContextOrDie(); + res = wrap::hipDevicePrimaryCtxRetain(&new_context, device); + if (former_context != nullptr) { + hipDevice_t former_device; + if (wrap::hipCtxGetDevice(&former_device) == hipSuccess) { + if (former_device == device) { + if (former_context == new_context) { + VLOG(2) << "The primary context " << former_context << " for device " + << device + << " exists before initializing the StreamExecutor."; + } else { + LOG(WARNING) << "A non-primary context " << former_context + << " for device " << device + << " exists before initializing the StreamExecutor. The " + << "primary context is now " << new_context << ". We " + << "haven't verified StreamExecutor works with that."; + } + } + } else { + LOG(ERROR) << "Failed to get the device of the current context " + << former_context; + } + } + CHECK_EQ(hipSuccess, wrap::hipCtxSetCurrent(former_context)); + + if (res == hipSuccess) { + context = GetContextMap()->Add(new_context, device_ordinal); + CHECK(context != nullptr) + << "success in this call must entail non-null result"; + VLOG(2) << "created or reused context " << new_context + << " for this thread"; + return context; + } + + std::string message = + "failed call to hipDevicePrimaryCtxRetain: " + ToString(res); + if (res == hipErrorOutOfMemory) { + uint64_t total_memory; + if (GetDeviceTotalMemory(device, &total_memory)) { + absl::StrAppend(&message, "; total memory reported: ", total_memory); + } else { + absl::StrAppend(&message, "; could not query total memory"); + } + } + + return absl::InternalError(message); +} + +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_context.h b/third_party/xla/xla/stream_executor/rocm/rocm_context.h new file mode 100644 index 00000000000000..60480f49054128 --- /dev/null +++ b/third_party/xla/xla/stream_executor/rocm/rocm_context.h @@ -0,0 +1,72 @@ +#include "absl/status/statusor.h" +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The ROCM-specific Driver library support, implementing the general Driver +// interface. + +#ifndef XLA_STREAM_EXECUTOR_ROCM_ROCM_CONTEXT_H_ +#define XLA_STREAM_EXECUTOR_ROCM_ROCM_CONTEXT_H_ + +#include + +#include "absl/status/status.h" +#include "rocm/include/hip/hip_runtime.h" +#include "xla/stream_executor/gpu/context.h" +#include "xla/stream_executor/gpu/context_map.h" + +namespace stream_executor::gpu { + +// RocmContext implements the Context class for ROCm GPUs. +class RocmContext : public Context { + public: + RocmContext(hipCtx_t context, const int ordinal) + : context_(context), device_ordinal_(ordinal) {} + ~RocmContext() override; + + hipCtx_t context() const { return context_; } + void SetActive() override; + bool IsActive() const override; + int device_ordinal() const override { return device_ordinal_; } + absl::Status Synchronize() override; + + // Disallow copying and moving. + RocmContext(RocmContext&&) = delete; + RocmContext(const RocmContext&) = delete; + RocmContext& operator=(RocmContext&&) = delete; + RocmContext& operator=(const RocmContext&) = delete; + + // Returns the free amount of memory and total amount of memory, as reported + // by hipDeviceTotalMem. + bool GetDeviceMemoryUsage(int64_t* free_out, int64_t* total_out); + + // Returns the total amount of memory available on the device. + static bool GetDeviceTotalMemory(hipDevice_t device, uint64_t* result); + + // Returns the context map for all XLA-known ROCm contexts. + static ContextMap* GetContextMap(); + + // Creates a new context for the given device. + static absl::StatusOr Create(int device_ordinal, + hipDevice_t device); + + private: + hipCtx_t const context_; + const int device_ordinal_; +}; + +} // namespace stream_executor::gpu + +#endif // XLA_STREAM_EXECUTOR_ROCM_ROCM_CONTEXT_H_ diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_dnn.cc b/third_party/xla/xla/stream_executor/rocm/rocm_dnn.cc index 039cee85f9d728..cc86eaef11e5a7 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_dnn.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_dnn.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include "absl/algorithm/container.h" +#include "absl/base/optimization.h" #include "absl/base/thread_annotations.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -29,13 +30,11 @@ limitations under the License. #include "Eigen/Core" #include "rocm/include/miopen/miopen.h" #include "rocm/rocm_config.h" +#include "xla/stream_executor/activate_context.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/scoped_activate_context.h" -#include "xla/stream_executor/platform/dso_loader.h" #include "xla/stream_executor/platform/initialize.h" #include "xla/stream_executor/plugin_registry.h" #include "xla/stream_executor/rocm/rocm_diagnostics.h" @@ -45,6 +44,7 @@ limitations under the License. #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/util/determinism.h" #include "xla/tsl/util/env_var.h" +#include "tsl/platform/dso_loader.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/hash.h" @@ -219,16 +219,18 @@ class MIOpenHandle { public: // Takes ownership of the executor context and the lock to access MIOpen // using handle. - MIOpenHandle(GpuExecutor* executor, std::unique_ptr lock, + MIOpenHandle(StreamExecutor* executor, std::unique_ptr lock, miopenHandle_t handle) - : context_(executor), lock_(std::move(lock)), handle_(handle) {} + : context_(executor->Activate()), + lock_(std::move(lock)), + handle_(handle) {} // Returns MIOpen handle. To be passed directly to MIOpen APIs, don't keep // a copy. miopenHandle_t handle() const { return handle_; } private: - ScopedActivateContext context_; + std::unique_ptr context_; std::unique_ptr lock_; miopenHandle_t handle_; // Not owned. }; @@ -252,7 +254,7 @@ namespace wrap { static const char* kName; \ using FuncPtrT = std::add_pointer::type; \ static void* GetDsoHandle() { \ - auto s = internal::CachedDsoLoader::GetMiopenDsoHandle(); \ + auto s = tsl::internal::CachedDsoLoader::GetMiopenDsoHandle(); \ return s.value(); \ } \ static FuncPtrT LoadOrDie() { \ @@ -738,7 +740,7 @@ class MIOpenAccess { // The null stream synchronizes with all other streams and it is // therefore a bad idea (performance wise) to call any MIOpen APIs that // enqueue work in the stream. - MIOpenHandle GetHandle(GpuExecutor* executor, Stream* stream) { + MIOpenHandle GetHandle(StreamExecutor* executor, Stream* stream) { auto lock = std::make_unique(&mutex_); mutex_.AssertHeld(); hipStream_t hip_stream = stream ? AsGpuStreamValue(stream) : nullptr; @@ -755,7 +757,7 @@ class MIOpenAccess { miopenHandle_t handle_ ABSL_GUARDED_BY(mutex_); // Owned. }; -MIOpenSupport::MIOpenSupport(GpuExecutor* parent) : parent_(parent) { +MIOpenSupport::MIOpenSupport(StreamExecutor* parent) : parent_(parent) { // by default, the Get*Algorithm API will return the list of all applicable // algorithms return_best_algo_only_ = false; @@ -778,7 +780,7 @@ MIOpenSupport::MIOpenSupport(GpuExecutor* parent) : parent_(parent) { } absl::Status MIOpenSupport::Init() { - ScopedActivateContext context(parent_); + std::unique_ptr context = parent_->Activate(); miopenHandle_t miopen_handle = nullptr; auto status = wrap::miopenCreateWithStream( reinterpret_cast(&miopen_handle), (hipStream_t)(0)); @@ -1970,7 +1972,7 @@ class MixinBase {}; } // namespace #define RETURN_IF_MIOPEN_ERROR(STATUS, ...) \ - if (!SE_PREDICT_TRUE((STATUS) == miopenStatusSuccess)) { \ + if (!ABSL_PREDICT_TRUE((STATUS) == miopenStatusSuccess)) { \ std::string error_msg = absl::StrCat(ToString(STATUS), " ", __VA_ARGS__); \ SetFailure(::absl::UnknownError(error_msg)); \ LOG(ERROR) << error_msg; \ @@ -2392,7 +2394,7 @@ bool CreateRnnWorkspace(Stream* stream, miopenHandle_t miopen_handle, const MIOpenRnnDescriptor& rnn_desc, const MIOpenRnnSequenceTensorDescriptor& input_desc, ScratchAllocator* workspace_allocator, - DeviceMemory* workspace) { + DeviceMemory* workspace) { // Query the workspace size. size_t workspace_size_in_bytes = 0; auto status = wrap::miopenGetRNNWorkspaceSize( @@ -2416,7 +2418,7 @@ bool CreateRnnWorkspace(Stream* stream, miopenHandle_t miopen_handle, return false; } } else { - *workspace = DeviceMemory(); + *workspace = DeviceMemory(); } return true; } @@ -2463,7 +2465,7 @@ absl::Status MIOpenSupport::DoRnnForwardImpl( } // create the workspace - DeviceMemory workspace; + DeviceMemory workspace; if (!CreateRnnWorkspace(stream, miopen.handle(), rnn_desc, input_desc, workspace_allocator, &workspace)) { LOG(ERROR) << "Unable to create rnn workspace"; @@ -2472,7 +2474,7 @@ absl::Status MIOpenSupport::DoRnnForwardImpl( // query the reserve space size // allocate the reserve space - DeviceMemory reserve_space; + DeviceMemory reserve_space; if (is_training) { size_t reserve_space_size_in_bytes = 0; auto status = wrap::miopenGetRNNTrainingReserveSize( @@ -2577,7 +2579,7 @@ absl::Status MIOpenSupport::DoRnnBackwardImpl( DeviceMemory* input_h_backprop_data, DeviceMemory* input_c_backprop_data, DeviceMemory* params_backprop_data, - DeviceMemory* reserve_space_data, + DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) { // extract model parameters @@ -2601,7 +2603,7 @@ absl::Status MIOpenSupport::DoRnnBackwardImpl( } // create the workspace - DeviceMemory workspace; + DeviceMemory workspace; if (!CreateRnnWorkspace(stream, miopen.handle(), rnn_desc, input_desc, workspace_allocator, &workspace)) { LOG(ERROR) << "Unable to create rnn workspace"; @@ -2778,7 +2780,7 @@ absl::Status MIOpenSupport::DoPrepareForCtcLoss( absl::Span labels_lengths_data, absl::Span input_lengths_data, const NumericOptions& numeric_options, ScratchAllocator* scratch_allocator, - DeviceMemory* scratch_memory, int* ctc_loss_algo_id) { + DeviceMemory* scratch_memory, int* ctc_loss_algo_id) { auto miopen = miopen_->GetHandle(parent_, stream); MIOpenCTCLossDescriptor miopen_ctc_loss_desc(ToMIOpenDataType(element_type)); @@ -2805,7 +2807,7 @@ absl::Status MIOpenSupport::DoPrepareForCtcLoss( "Failed to determine scratch memory size for MIOpen CTC Loss"); } - *scratch_memory = DeviceMemory(); + *scratch_memory = DeviceMemory(); // Allocate the workspace. if (workspace_size_in_bytes != 0) { @@ -2840,7 +2842,7 @@ absl::Status MIOpenSupport::DoCtcLossImpl( absl::Span input_lengths_data, DeviceMemoryBase costs_data, const MIOpenRnnStateTensorDescriptor& grads_desc, DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc, - DeviceMemory scratch_memory, int ctc_loss_algo_id) { + DeviceMemory scratch_memory, int ctc_loss_algo_id) { auto miopen = miopen_->GetHandle(parent_, stream); int kNumTimestamps = probs_desc.num_layers(); @@ -2870,7 +2872,7 @@ absl::Status MIOpenSupport::DoCtcLoss( absl::Span labels_lengths_data, absl::Span input_lengths_data, DeviceMemoryBase costs_data, const dnn::RnnStateTensorDescriptor& grads_desc, - DeviceMemoryBase grads_data, DeviceMemory scratch_memory, + DeviceMemoryBase grads_data, DeviceMemory scratch_memory, int ctc_loss_algo_id) { // Current MIOPen CTC Loss only supports the float datatype if (element_type != dnn::DataType::kFloat) { @@ -3089,7 +3091,7 @@ bool MIOpenSupport::DoRnnBackward( DeviceMemory* input_h_backprop_data, DeviceMemory* input_c_backprop_data, DeviceMemory* params_backprop_data, - DeviceMemory* reserve_space_data, + DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) { const MIOpenRnnDescriptor& miopen_rnn_desc = @@ -3142,7 +3144,7 @@ bool MIOpenSupport::DoRnnBackward( DeviceMemory* input_h_backprop_data, DeviceMemory* input_c_backprop_data, DeviceMemory* params_backprop_data, - DeviceMemory* reserve_space_data, + DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) { const MIOpenRnnDescriptor& miopen_rnn_desc = @@ -3196,7 +3198,7 @@ bool MIOpenSupport::DoRnnBackward( DeviceMemory* input_h_backprop_data, DeviceMemory* input_c_backprop_data, DeviceMemory* params_backprop_data, - DeviceMemory* reserve_space_data, + DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) { LOG(ERROR) << "miopen does not support half type RNN bwd yet"; @@ -3216,7 +3218,7 @@ void* MIOpenAllocatorCallback(void* ctx, size_t size_in_bytes) { auto* mac = static_cast(ctx); auto allocated = mac->scratch_allocator_->AllocateBytes(size_in_bytes); - DeviceMemory scratch; + DeviceMemory scratch; if (allocated.ok()) { scratch = allocated.value(); return scratch.opaque(); @@ -3239,7 +3241,7 @@ absl::Status MIOpenSupport::DoPrepareForConvolution( const dnn::ConvolutionDescriptor& convolution_descriptor, const dnn::AlgorithmConfig& algorithm_config, ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc, - DeviceMemory* scratch_memory) { + DeviceMemory* scratch_memory) { std::optional input_algo_desc = algorithm_config.algorithm(); @@ -3279,7 +3281,7 @@ absl::Status MIOpenSupport::DoPrepareForConvolution( class RocmConvRunner : public dnn::ConvRunner { public: - RocmConvRunner(GpuExecutor* parent, MIOpenAccess* miopen, int64_t algo_id, + RocmConvRunner(StreamExecutor* parent, MIOpenAccess* miopen, int64_t algo_id, size_t workspace_size, dnn::ConvolutionKind kind, dnn::DataType input_type, dnn::DataType output_type, bool use_immediate_mode, @@ -3424,7 +3426,7 @@ class RocmConvRunner : public dnn::ConvRunner { } private: - GpuExecutor* parent_; + StreamExecutor* parent_; MIOpenAccess* miopen_; int64_t algo_id_; size_t workspace_size_; @@ -3445,7 +3447,7 @@ absl::Status MIOpenSupport::DoConvolve( DeviceMemoryBase filter_data, const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, - dnn::AlgorithmDesc algorithm_desc, DeviceMemory scratch_memory, + dnn::AlgorithmDesc algorithm_desc, DeviceMemory scratch_memory, dnn::ProfileResult* output_profile_result) { TF_ASSIGN_OR_RETURN( auto runner, @@ -3832,7 +3834,7 @@ absl::Status MIOpenSupport::GetMIOpenConvolveAlgorithmsFindMode( } // allocate scratch memory - DeviceMemory scratch_memory; + DeviceMemory scratch_memory; if (scratch_memory_size != 0) { if (scratch_allocator == nullptr) { return absl::InternalError( @@ -4086,7 +4088,7 @@ bool MIOpenSupport::DoBatchNormalizationBackward( dnn::ActivationMode activation_mode, DeviceMemory* x_backprop, DeviceMemory* scale_backprop, DeviceMemory* offset_backprop, DeviceMemory* side_input_backprop, - DeviceMemory* reserve_space_data, + DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator) { return DoBatchNormalizationBackwardImpl( stream, miopenHalf, miopenFloat, y_backprop, x, scale, mean, @@ -4105,7 +4107,7 @@ bool MIOpenSupport::DoBatchNormalizationBackward( dnn::ActivationMode activation_mode, DeviceMemory* x_backprop, DeviceMemory* scale_backprop, DeviceMemory* offset_backprop, DeviceMemory* side_input_backprop, - DeviceMemory* reserve_space_data, + DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator) { return DoBatchNormalizationBackwardImpl( stream, miopenFloat, miopenFloat, y_backprop, x, scale, mean, @@ -4397,7 +4399,7 @@ absl::Status MIOpenSupport::DoPoolForward( TF_ASSIGN_OR_RETURN(auto pooling_desc, scope(pooling_dimensions)); bool do_backward = false; - uint8* workspace = nullptr; + uint8_t* workspace = nullptr; size_t workspace_size = 0; if (m_pooling_cache_enabled && element_type == dnn::DataType::kFloat) { do_backward = true; @@ -4417,11 +4419,12 @@ absl::Status MIOpenSupport::DoPoolForward( miopenFloat, pdesc); if (cache_hit) { // reusing the same buffer - workspace = reinterpret_cast(pdesc->workspace.ptr()->opaque()); + workspace = + reinterpret_cast(pdesc->workspace.ptr()->opaque()); } else { TF_ASSIGN_OR_RETURN(auto allocated, workspace_allocator->AllocateBytes(workspace_size)); - workspace = reinterpret_cast(allocated.opaque()); + workspace = reinterpret_cast(allocated.opaque()); } } } @@ -4474,7 +4477,7 @@ void PoolingWorkspaceCache::insert( const void* p, const dnn::BatchDescriptor& input_dimensions, const dnn::BatchDescriptor& output_dimensions, const dnn::PoolingDescriptor& pooling_dimensions, int _type, - ScopedDeviceMemory& workspace, size_t wsp_size, + ScopedDeviceMemory& workspace, size_t wsp_size, hipStream_t hip_stream) { PoolingWorkspaceDescriptor* desc = 0; auto it = cache.find(p); @@ -4550,8 +4553,8 @@ absl::Status MIOpenSupport::DoPoolBackward( TF_ASSIGN_OR_RETURN(auto dest_desc, scope(output_dimensions, miopen_dtype)); TF_ASSIGN_OR_RETURN(auto pooling_desc, scope(pooling_dimensions)); - uint8* workspace_ptr = 0; - DeviceMemory workspace; + uint8_t* workspace_ptr = 0; + DeviceMemory workspace; PoolingWorkspaceDescriptor* pdesc = 0; size_t workspace_size_in_bytes = 0; @@ -4572,7 +4575,7 @@ absl::Status MIOpenSupport::DoPoolBackward( if (cache_hit) { assert(pdesc != 0); workspace_ptr = - reinterpret_cast(pdesc->workspace.ptr()->opaque()); + reinterpret_cast(pdesc->workspace.ptr()->opaque()); VLOG(1) << "Pooling cache hit"; } else { VLOG(1) << "Pooling cache miss"; @@ -4583,7 +4586,7 @@ absl::Status MIOpenSupport::DoPoolBackward( return absl::InternalError( "Failed to allocate backward pooling workspace"); } - DeviceMemory dest2; // duplicated dest from forward: + DeviceMemory dest2; // duplicated dest from forward: int64_t dest2_size = 0; // miopen requires the strides and dims to be ordered as BDYX. @@ -4619,7 +4622,7 @@ absl::Status MIOpenSupport::DoPoolBackward( "Failed to enqueue forward pooling (before backward) on stream: ", ToString(status))); } - workspace_ptr = reinterpret_cast(workspace.opaque()); + workspace_ptr = reinterpret_cast(workspace.opaque()); } } @@ -4708,7 +4711,7 @@ bool MIOpenSupport::DoNormalizeBackwardWithDimensions( float alpha = 1.0f; float beta = 0.0f; - DeviceMemory workspace; + DeviceMemory workspace; size_t workspace_size_in_bytes = 0; auto status = wrap::miopenLRNGetWorkSpaceSize(dims.handle(), &workspace_size_in_bytes); @@ -4729,7 +4732,7 @@ bool MIOpenSupport::DoNormalizeBackwardWithDimensions( } } - DeviceMemory dest2; // duplicated dest from forward: + DeviceMemory dest2; // duplicated dest from forward: int dest2_size = 0; // miopen requires the strides and dims to be ordered as BDYX. @@ -4892,7 +4895,7 @@ class RocmFusedConvRunner : public dnn::FusedConvRunner { public: // Queries the workspace size and constructs a 'RocmFusedConvRunner'. static absl::StatusOr> Create( - GpuExecutor* parent, Stream* stream, MIOpenAccess* miopen, + StreamExecutor* parent, Stream* stream, MIOpenAccess* miopen, const dnn::AlgorithmDesc& algo, dnn::DataType input_type, dnn::DataType bias_type, double conv_scale, double side_input_scale, double leakyrelu_alpha, BatchDescriptor input_nd, @@ -4964,7 +4967,7 @@ class RocmFusedConvRunner : public dnn::FusedConvRunner { private: // Private to prevent passing in the wrong workspace_size. RocmFusedConvRunner( - GpuExecutor* parent, Stream* stream, MIOpenAccess* miopen, + StreamExecutor* parent, Stream* stream, MIOpenAccess* miopen, int64_t algo_id, size_t workspace_size, dnn::DataType input_type, dnn::DataType bias_type, double conv_scale, double side_input_scale, double leakyrelu_alpha, BatchDescriptor dnn_input_nd, @@ -5080,7 +5083,7 @@ class RocmFusedConvRunner : public dnn::FusedConvRunner { std::string desc_; - GpuExecutor* parent_; + StreamExecutor* parent_; MIOpenAccess* miopen_; int64_t algo_id_; size_t workspace_size_; @@ -5195,16 +5198,7 @@ void initialize_miopen() { PluginRegistry::Instance()->RegisterFactory( rocm::kROCmPlatformId, "MIOpen", [](StreamExecutor* parent) -> dnn::DnnSupport* { - gpu::GpuExecutor* rocm_executor = - dynamic_cast(parent); - if (rocm_executor == nullptr) { - LOG(ERROR) - << "Attempting to initialize an instance of the MIOpen " - << "support library with a non-ROCM StreamExecutor"; - return nullptr; - } - - gpu::MIOpenSupport* dnn = new gpu::MIOpenSupport(rocm_executor); + gpu::MIOpenSupport* dnn = new gpu::MIOpenSupport(parent); if (!dnn->Init().ok()) { // Note: Init() will log a more specific error. delete dnn; diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_dnn.h b/third_party/xla/xla/stream_executor/rocm/rocm_dnn.h index 957c288f12c278..27f99258d60db7 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_dnn.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_dnn.h @@ -29,7 +29,6 @@ limitations under the License. namespace stream_executor { namespace gpu { -class GpuExecutor; class MIOpenRnnDescriptor; class MIOpenRnnSequenceTensorDescriptor; class MIOpenRnnStateTensorDescriptor; @@ -41,7 +40,7 @@ struct PoolingWorkspaceDescriptor { dnn::PoolingDescriptor op; int dtype; uint64_t timestamp; - ScopedDeviceMemory workspace; + ScopedDeviceMemory workspace; size_t workspace_size; bool IsSame(const dnn::BatchDescriptor& input_dimensions, const dnn::BatchDescriptor& output_dimensions, @@ -61,7 +60,7 @@ struct PoolingWorkspaceCache { void insert(const void* p, const dnn::BatchDescriptor& input_dimensions, const dnn::BatchDescriptor& output_dimensions, const dnn::PoolingDescriptor& pooling_dimensions, int _type, - ScopedDeviceMemory& workspace, size_t wsp_size, + ScopedDeviceMemory& workspace, size_t wsp_size, hipStream_t hip_stream); private: @@ -72,7 +71,7 @@ struct PoolingWorkspaceCache { // functions, see dnn.h. class MIOpenSupport : public dnn::DnnSupport { public: - explicit MIOpenSupport(GpuExecutor* parent); + explicit MIOpenSupport(StreamExecutor* parent); absl::Status Init() override; absl::StatusOr GetVersion() override; @@ -173,7 +172,7 @@ class MIOpenSupport : public dnn::DnnSupport { DeviceMemory* input_h_backprop_data, DeviceMemory* input_c_backprop_data, DeviceMemory* params_backprop_data, - DeviceMemory* reserve_space_data, + DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) override; @@ -199,7 +198,7 @@ class MIOpenSupport : public dnn::DnnSupport { DeviceMemory* input_h_backprop_data, DeviceMemory* input_c_backprop_data, DeviceMemory* params_backprop_data, - DeviceMemory* reserve_space_data, + DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) override; @@ -225,7 +224,7 @@ class MIOpenSupport : public dnn::DnnSupport { DeviceMemory* input_h_backprop_data, DeviceMemory* input_c_backprop_data, DeviceMemory* params_backprop_data, - DeviceMemory* reserve_space_data, + DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result) override; @@ -357,7 +356,7 @@ class MIOpenSupport : public dnn::DnnSupport { dnn::ActivationMode activation_mode, DeviceMemory* x_backprop, DeviceMemory* scale_backprop, DeviceMemory* offset_backprop, DeviceMemory* side_input_backprop, - DeviceMemory* reserve_space_data, + DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator) override; bool DoBatchNormalizationBackward( @@ -371,7 +370,7 @@ class MIOpenSupport : public dnn::DnnSupport { DeviceMemory* x_backprop, DeviceMemory* scale_backprop, DeviceMemory* offset_backprop, DeviceMemory* side_input_backprop, - DeviceMemory* reserve_space_data, + DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator) override; bool DoBatchNormalizationBackward( @@ -398,7 +397,7 @@ class MIOpenSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& output_descriptor, DeviceMemoryBase output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, - dnn::AlgorithmDesc algorithm_desc, DeviceMemory scratch_memory, + dnn::AlgorithmDesc algorithm_desc, DeviceMemory scratch_memory, dnn::ProfileResult* output_profile_result) override; absl::Status DoFusedConvolve( @@ -490,7 +489,7 @@ class MIOpenSupport : public dnn::DnnSupport { dnn::DataType output_type, float scale, DeviceMemoryBase* output_data) override; - GpuExecutor* GetParentExecutor() { return parent_; } + StreamExecutor* GetParentExecutor() { return parent_; } absl::Status DoCtcLoss(Stream* stream, dnn::DataType element_type, const dnn::RnnStateTensorDescriptor& probs_desc, @@ -501,11 +500,11 @@ class MIOpenSupport : public dnn::DnnSupport { DeviceMemoryBase costs_data, const dnn::RnnStateTensorDescriptor& grads_desc, DeviceMemoryBase grads_data, - DeviceMemory scratch_memory, + DeviceMemory scratch_memory, int ctc_loss_algo_id) override; private: - GpuExecutor* parent_; // Parent executor object. Not owned. + StreamExecutor* parent_; // Parent executor object. Not owned. // Flag to indicate whether Get*Algorithm routines should only return // the best algorithm (as opposed to a list of all applicable ones) @@ -586,7 +585,7 @@ class MIOpenSupport : public dnn::DnnSupport { DeviceMemory* input_h_backprop_data, DeviceMemory* input_c_backprop_data, DeviceMemory* params_backprop_data, - DeviceMemory* reserve_space_data, + DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator, dnn::ProfileResult* output_profile_result); @@ -600,7 +599,7 @@ class MIOpenSupport : public dnn::DnnSupport { const dnn::ConvolutionDescriptor& convolution_descriptor, const dnn::AlgorithmConfig& algorithm_config, ScratchAllocator* scratch_allocator, dnn::AlgorithmDesc* algorithm_desc, - DeviceMemory* scratch_memory) override; + DeviceMemory* scratch_memory) override; absl::Status DoCtcLossImpl( Stream* stream, const MIOpenRnnStateTensorDescriptor& probs_desc, @@ -609,7 +608,7 @@ class MIOpenSupport : public dnn::DnnSupport { absl::Span input_lengths_data, DeviceMemoryBase costs_data, const MIOpenRnnStateTensorDescriptor& grads_desc, DeviceMemoryBase grads_data, const MIOpenCTCLossDescriptor& ctc_loss_desc, - DeviceMemory scratch_memory, int ctc_loss_algo_id); + DeviceMemory scratch_memory, int ctc_loss_algo_id); absl::Status DoPrepareForCtcLoss( Stream* stream, dnn::DataType element_type, @@ -619,8 +618,8 @@ class MIOpenSupport : public dnn::DnnSupport { absl::Span labels_lengths_data, absl::Span input_lengths_data, const NumericOptions& numeric_options, - ScratchAllocator* scratch_allocator, DeviceMemory* scratch_memory, - int* ctc_loss_algo_id) override; + ScratchAllocator* scratch_allocator, + DeviceMemory* scratch_memory, int* ctc_loss_algo_id) override; MIOpenSupport(const MIOpenSupport&) = delete; void operator=(const MIOpenSupport&) = delete; diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc index 81f7f3d76fbd79..6c9bf65ffcf36c 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc @@ -13,28 +13,36 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/stream_executor/rocm/rocm_driver.h" - #include #include -#include -#include +#include +#include +#include +#include #include +#include +#include #include "absl/base/casts.h" #include "absl/container/inlined_vector.h" -#include "absl/container/node_hash_map.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "absl/synchronization/mutex.h" -#include "absl/synchronization/notification.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "rocm/include/hip/hip_runtime.h" +#include "rocm/rocm_config.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/gpu/context.h" #include "xla/stream_executor/gpu/context_map.h" #include "xla/stream_executor/gpu/gpu_diagnostics.h" #include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/gpu/scoped_activate_context.h" -#include "xla/stream_executor/platform/port.h" +#include "xla/stream_executor/rocm/rocm_context.h" #include "xla/stream_executor/rocm/rocm_driver_wrapper.h" +#include "xla/stream_executor/rocm/rocm_status.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/platform/casts.h" #include "tsl/platform/env.h" @@ -42,343 +50,21 @@ limitations under the License. #include "tsl/platform/logging.h" #include "tsl/platform/numbers.h" #include "tsl/platform/stacktrace.h" -#include "tsl/platform/threadpool.h" - -#define RETURN_IF_ROCM_ERROR(expr, ...) \ - do { \ - hipError_t _res = (expr); \ - if (TF_PREDICT_FALSE(_res != hipSuccess)) { \ - if (_res == hipErrorOutOfMemory) \ - return absl::ResourceExhaustedError(absl::StrCat( \ - __VA_ARGS__, ":", ::stream_executor::gpu::ToString(_res))); \ - else \ - return absl::InternalError(absl::StrCat( \ - __VA_ARGS__, ": ", ::stream_executor::gpu::ToString(_res))); \ - } \ - } while (0) - -#define FAIL_IF_ROCM_ERROR(expr, ...) \ - do { \ - hipError_t _res = (expr); \ - if (ABSL_PREDICT_FALSE(_res != hipSuccess)) { \ - LOG(FATAL) << absl::StrCat(__VA_ARGS__) << ": " \ - << ::stream_executor::gpu::ToString(_res); \ - } \ - } while (0) +#include "tsl/platform/status.h" namespace stream_executor::gpu { -namespace { - -hipCtx_t CurrentContextOrDie() { - hipCtx_t current = nullptr; - FAIL_IF_ROCM_ERROR(hipCtxGetCurrent(¤t), - "Failed to query current context"); - return current; -} - -// Returns the singleton ContextMap. -ContextMap* GetContextMap() { - static ContextMap* context_map = - new ContextMap([](void* ptr) { - int device_ordinal; - hipError_t result = - hipPointerGetAttribute(static_cast(&device_ordinal), - HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL, - reinterpret_cast(ptr)); - if (result != hipSuccess) { - LOG(FATAL) << "Not able to get the device_ordinal for ptr: " << ptr - << ". Error: " << ToString(result); - } - return device_ordinal; - }); - return context_map; -} - -} // namespace - -// Formats hipError_t to output prettified values into a log stream. -// Error summaries taken from: -std::string ToString(hipError_t result) { -#define OSTREAM_ROCM_ERROR(__name) \ - case hipError##__name: \ - return "HIP_ERROR_" #__name; - - switch (result) { - OSTREAM_ROCM_ERROR(InvalidValue) - OSTREAM_ROCM_ERROR(OutOfMemory) - OSTREAM_ROCM_ERROR(NotInitialized) - OSTREAM_ROCM_ERROR(Deinitialized) - OSTREAM_ROCM_ERROR(NoDevice) - OSTREAM_ROCM_ERROR(InvalidDevice) - OSTREAM_ROCM_ERROR(InvalidImage) - OSTREAM_ROCM_ERROR(InvalidContext) - OSTREAM_ROCM_ERROR(InvalidHandle) - OSTREAM_ROCM_ERROR(NotFound) - OSTREAM_ROCM_ERROR(NotReady) - OSTREAM_ROCM_ERROR(NoBinaryForGpu) - - // Encountered an uncorrectable ECC error during execution. - OSTREAM_ROCM_ERROR(ECCNotCorrectable) - - // Load/store on an invalid address. Must reboot all context. - case 700: - return "ROCM_ERROR_ILLEGAL_ADDRESS"; - // Passed too many / wrong arguments, too many threads for register count. - case 701: - return "ROCM_ERROR_LAUNCH_OUT_OF_RESOURCES"; - - OSTREAM_ROCM_ERROR(ContextAlreadyInUse) - OSTREAM_ROCM_ERROR(PeerAccessUnsupported) - OSTREAM_ROCM_ERROR(Unknown) // Unknown internal error to ROCM. -#if TF_ROCM_VERSION >= 60200 - OSTREAM_ROCM_ERROR(LaunchTimeOut) - OSTREAM_ROCM_ERROR(PeerAccessAlreadyEnabled) - OSTREAM_ROCM_ERROR(PeerAccessNotEnabled) - OSTREAM_ROCM_ERROR(SetOnActiveProcess) - OSTREAM_ROCM_ERROR(ContextIsDestroyed) - OSTREAM_ROCM_ERROR(Assert) - OSTREAM_ROCM_ERROR(HostMemoryAlreadyRegistered) - OSTREAM_ROCM_ERROR(HostMemoryNotRegistered) - OSTREAM_ROCM_ERROR(LaunchFailure) - OSTREAM_ROCM_ERROR(CooperativeLaunchTooLarge) - OSTREAM_ROCM_ERROR(NotSupported) - OSTREAM_ROCM_ERROR(StreamCaptureUnsupported) - OSTREAM_ROCM_ERROR(StreamCaptureInvalidated) - OSTREAM_ROCM_ERROR(StreamCaptureMerge) - OSTREAM_ROCM_ERROR(StreamCaptureUnmatched) - OSTREAM_ROCM_ERROR(StreamCaptureUnjoined) - OSTREAM_ROCM_ERROR(StreamCaptureIsolation) - OSTREAM_ROCM_ERROR(StreamCaptureImplicit) - OSTREAM_ROCM_ERROR(CapturedEvent) - OSTREAM_ROCM_ERROR(StreamCaptureWrongThread) - OSTREAM_ROCM_ERROR(GraphExecUpdateFailure) - OSTREAM_ROCM_ERROR(RuntimeMemory) - OSTREAM_ROCM_ERROR(RuntimeOther) -#endif // TF_ROCM_VERSION >= 60200 - default: - return absl::StrCat("hipError_t(", static_cast(result), ")"); - } -} - -namespace { - -// Returns the current context and checks that it is in the set of HIP contexts -// created by StreamExecutor (to ensure that the HIP runtime didn't create a -// context behind our backs). -hipCtx_t CurrentContext() { - hipCtx_t current = CurrentContextOrDie(); - if (current != nullptr && !GetContextMap()->Has(current)) { - LOG(FATAL) << "current context was not created by the StreamExecutor " - "rocm_driver API: " - << current - << "; a HIP runtime call " - "was likely performed without using a StreamExecutor context"; - } - return current; -} - -// Returns the device associated with the given context. -absl::StatusOr DeviceFromContext(Context* context) { - ScopedActivateContext activated{context}; - hipDevice_t device = -1; - hipError_t result = wrap::hipCtxGetDevice(&device); - if (result == hipSuccess) return device; - - return absl::InternalError( - absl::StrCat("failed to get device for context: ", ToString(result))); -} - -// ROCM driver routines may require a large amount of stack (particularly -// hipModuleLoadDataEx, in our experience). To avoid stack overflow when using -// stack-limited threads (such as those spawned by a default-argument -// thread::ThreadPool on some platforms), we run certain routines in this pool -// and wait for completion. -tsl::thread::ThreadPool* GetDriverExecutor() { - static tsl::thread::ThreadPool* thread_pool = new tsl::thread::ThreadPool( - tsl::Env::Default(), tsl::ThreadOptions(), "rocm_driver", 1); - return thread_pool; -} - -} // namespace - -namespace { - -// Call hipDeviceSynchronize and crash if it doesn't succeed. -void SynchronizeOrDie() { - auto res = wrap::hipDeviceSynchronize(); - if (res != hipSuccess) { - LOG(FATAL) << "Synchronize found " << ToString(res) - << " :: " << tsl::CurrentStackTrace(); - } -} - -} // namespace - -void GpuContext::SetActive() { - FAIL_IF_ROCM_ERROR(wrap::hipCtxSetCurrent(context_), - "Failed setting context"); -} - -bool GpuContext::IsActive() const { return CurrentContext() == context_; } - -namespace { - -// Actually performs the work of ROCM initialization. Wrapped up in one-time -// execution guard. -static absl::Status InternalInit() { - hipError_t res = wrap::hipInit(0 /* = flags */); - - if (res == hipSuccess) { - return absl::OkStatus(); - } - - LOG(ERROR) << "failed call to hipInit: " << ToString(res); - Diagnostician::LogDiagnosticInformation(); - return absl::AbortedError( - absl::StrCat("failed call to hipInit: ", ToString(res))); -} - -} // namespace - -absl::Status GpuDriver::Init() { - // Cached return value from calling InternalInit(), as hipInit need only be - // called once, but GpuDriver::Init may be called many times. - static absl::Status* init_retval = [] { - return new absl::Status(InternalInit()); - }(); - return *init_retval; -} - -absl::Status GpuDriver::GetDevice(int device_ordinal, hipDevice_t* device) { - hipError_t res = wrap::hipDeviceGet(device, device_ordinal); - if (res == hipSuccess) { - return absl::OkStatus(); - } - - return absl::InternalError( - absl::StrCat("failed call to hipDeviceGet: ", ToString(res))); -} - -absl::Status GpuDriver::GetDeviceName(hipDevice_t device, - std::string* device_name) { - static const size_t kCharLimit = 64; - absl::InlinedVector chars(kCharLimit); - RETURN_IF_ROCM_ERROR( - wrap::hipDeviceGetName(chars.begin(), kCharLimit - 1, device), - "Failed to get device name"); - chars[kCharLimit - 1] = '\0'; - *device_name = chars.begin(); - return absl::OkStatus(); -} - -absl::Status GpuDriver::CreateContext(int device_ordinal, hipDevice_t device, - Context** context) { - *context = nullptr; - - int flags = 0; - - hipError_t res; - hipCtx_t former_context; - hipCtx_t new_context; - - unsigned int former_primary_context_flags; - int former_primary_context_is_active; - CHECK_EQ(hipSuccess, wrap::hipDevicePrimaryCtxGetState( - device, &former_primary_context_flags, - &former_primary_context_is_active)); - if (former_primary_context_flags != flags) { - if (former_primary_context_is_active) { - LOG(ERROR) - << "The primary context is active and has a different flag set (" - << former_primary_context_flags << ") than the desired flag set (" - << flags << ")."; - } else { - CHECK_EQ(hipSuccess, wrap::hipDevicePrimaryCtxSetFlags(device, flags)); - } - } - - former_context = CurrentContextOrDie(); - res = wrap::hipDevicePrimaryCtxRetain(&new_context, device); - if (former_context != nullptr) { - hipDevice_t former_device; - if (wrap::hipCtxGetDevice(&former_device) == hipSuccess) { - if (former_device == device) { - if (former_context == new_context) { - VLOG(2) << "The primary context " << former_context << " for device " - << device - << " exists before initializing the StreamExecutor."; - } else { - LOG(WARNING) << "A non-primary context " << former_context - << " for device " << device - << " exists before initializing the StreamExecutor. The " - << "primary context is now " << new_context << ". We " - << "haven't verified StreamExecutor works with that."; - } - } - } else { - LOG(ERROR) << "Failed to get the device of the current context " - << former_context; - } - } - CHECK_EQ(hipSuccess, wrap::hipCtxSetCurrent(former_context)); - - if (res == hipSuccess) { - *context = GetContextMap()->Add(new_context, device_ordinal); - CHECK(*context != nullptr) - << "success in this call must entail non-null result"; - VLOG(2) << "created or reused context " << new_context - << " for this thread"; - return absl::OkStatus(); - } - - std::string message = - "failed call to hipDevicePrimaryCtxRetain: " + ToString(res); - if (res == hipErrorOutOfMemory) { - uint64_t total_memory; - if (GetDeviceTotalMemory(device, &total_memory)) { - absl::StrAppend(&message, "; total memory reported: ", total_memory); - } else { - absl::StrAppend(&message, "; could not query total memory"); - } - } - - return absl::InternalError(message); -} - -void GpuDriver::DestroyContext(Context* context) { - if (context == nullptr) { - return; - } - GpuContext* gpu_context = tensorflow::down_cast(context); - hipCtx_t former_context = CurrentContext(); - context->SetActive(); - hipDevice_t device; - CHECK_EQ(hipSuccess, wrap::hipCtxGetDevice(&device)); - CHECK_EQ(hipSuccess, wrap::hipCtxSetCurrent(former_context)); - - auto res = wrap::hipDevicePrimaryCtxRelease(device); - - if (res != hipSuccess) { - LOG(ERROR) << "failed to release HIP context; leaking: " << ToString(res); - } - - GetContextMap()->Remove(gpu_context->context()); -} - absl::Status GpuDriver::CreateGraph(hipGraph_t* graph) { VLOG(2) << "Create new HIP graph"; - RETURN_IF_ROCM_ERROR(wrap::hipGraphCreate(graph, /*flags=*/0), - "Failed to create HIP graph"); + TF_RETURN_IF_ERROR(ToStatus(wrap::hipGraphCreate(graph, /*flags=*/0), + "Failed to create HIP graph")); VLOG(2) << "Created HIP graph " << *graph; return absl::OkStatus(); } absl::Status GpuDriver::DestroyGraph(hipGraph_t graph) { VLOG(2) << "Destroy HIP graph " << graph; - RETURN_IF_ROCM_ERROR(wrap::hipGraphDestroy(graph), - "Failed to destroy HIP graph"); - return absl::OkStatus(); + return ToStatus(wrap::hipGraphDestroy(graph), "Failed to destroy HIP graph"); } static std::string_view StreamCaptureModeToString( @@ -410,9 +96,8 @@ absl::Status GpuDriver::StreamBeginCapture(GpuStreamHandle stream, VLOG(2) << "Beging stream " << stream << " capture in " << StreamCaptureModeToString(mode) << " mode"; - RETURN_IF_ROCM_ERROR(wrap::hipStreamBeginCapture(stream, hip_mode), - "Failed to begin stream capture"); - return absl::OkStatus(); + return ToStatus(wrap::hipStreamBeginCapture(stream, hip_mode), + "Failed to begin stream capture"); } absl::Status GpuDriver::StreamBeginCaptureToGraph(GpuStreamHandle stream, @@ -426,10 +111,8 @@ absl::Status GpuDriver::StreamEndCapture(GpuStreamHandle stream, hipGraph_t* graph) { VLOG(2) << "End stream " << stream << " capture"; - RETURN_IF_ROCM_ERROR(wrap::hipStreamEndCapture(stream, graph), - "Failed to end stream capture"); - - return absl::OkStatus(); + return ToStatus(wrap::hipStreamEndCapture(stream, graph), + "Failed to end stream capture"); } absl::Status GpuDriver::GraphInstantiate(hipGraphExec_t* exec, hipGraph_t graph, @@ -439,19 +122,16 @@ absl::Status GpuDriver::GraphInstantiate(hipGraphExec_t* exec, hipGraph_t graph, << "device_launch=" << flags.device_launch << ", " << "use_node_priority=" << flags.use_node_prirotiy << ", " << "upload=" << flags.upload << ")"; - RETURN_IF_ROCM_ERROR( - wrap::hipGraphInstantiate(exec, graph, nullptr, nullptr, 0), - "Failed to instantiate HIP graph"); - return absl::OkStatus(); + return ToStatus(wrap::hipGraphInstantiate(exec, graph, nullptr, nullptr, 0), + "Failed to instantiate HIP graph"); } absl::Status GpuDriver::GraphLaunch(hipGraphExec_t exec, GpuStreamHandle stream) { VLOG(2) << "Launching HIP executable graph " << exec << " on a stream " << stream; - RETURN_IF_ROCM_ERROR(wrap::hipGraphLaunch(exec, stream), - "Failed to launch HIP graph"); - return absl::OkStatus(); + return ToStatus(wrap::hipGraphLaunch(exec, stream), + "Failed to launch HIP graph"); } absl::Status GpuDriver::GraphNodeSetEnabled(hipGraphExec_t exec, @@ -460,9 +140,8 @@ absl::Status GpuDriver::GraphNodeSetEnabled(hipGraphExec_t exec, unsigned value = enabled ? 1 : 0; VLOG(2) << "Set HIP executable graph " << exec << " node " << node << " enabled flag to " << value; - RETURN_IF_ROCM_ERROR(wrap::hipGraphNodeSetEnabled(exec, node, value), - "Failed to set HIP graph node enabled flag"); - return absl::OkStatus(); + return ToStatus(wrap::hipGraphNodeSetEnabled(exec, node, value), + "Failed to set HIP graph node enabled flag"); } absl::Status GpuDriver::GraphExecUpdate(hipGraphExec_t exec, hipGraph_t graph, @@ -506,8 +185,7 @@ absl::Status GpuDriver::GraphExecUpdate(hipGraphExec_t exec, hipGraph_t graph, // TODO: HIP hasn't GRAPH_EXEC_UPDATE_ERROR_ATTRIBUTES_CHANGED yet } - RETURN_IF_ROCM_ERROR(hip_error, "Failed to update HIP graph"); - return absl::OkStatus(); + return ToStatus(hip_error, "Failed to update HIP graph"); } absl::StatusOr> @@ -517,23 +195,22 @@ GpuDriver::GraphNodeGetDependencies(GpuGraphNodeHandle node) { std::vector dependencies; size_t num_dependencies = 0; - RETURN_IF_ROCM_ERROR( - hipGraphNodeGetDependencies(node, nullptr, &num_dependencies), - "Failed to get HIP graph node depedencies size"); + TF_RETURN_IF_ERROR( + ToStatus(hipGraphNodeGetDependencies(node, nullptr, &num_dependencies), + "Failed to get HIP graph node depedencies size")); dependencies.resize(num_dependencies, nullptr); - RETURN_IF_ROCM_ERROR( + TF_RETURN_IF_ERROR(ToStatus( hipGraphNodeGetDependencies(node, dependencies.data(), &num_dependencies), - "Failed to get HIP graph node depedencies"); + "Failed to get HIP graph node depedencies")); return dependencies; } absl::Status GpuDriver::DestroyGraphExec(hipGraphExec_t exec) { VLOG(2) << "Destroying HIP executable graph" << exec; - RETURN_IF_ROCM_ERROR(wrap::hipGraphExecDestroy(exec), - "Failed to destroy HIP graph"); - return absl::OkStatus(); + return ToStatus(wrap::hipGraphExecDestroy(exec), + "Failed to destroy HIP graph"); } absl::StatusOr GpuDriver::GraphDebugDotPrint( @@ -541,8 +218,8 @@ absl::StatusOr GpuDriver::GraphDebugDotPrint( VLOG(2) << "Print HIP graph " << graph << " debug dot file to " << path; int flags = hipGraphDebugDotFlagsVerbose; - RETURN_IF_ROCM_ERROR(wrap::hipGraphDebugDotPrint(graph, path, flags), - "Failed to print gpu graph debug file"); + TF_RETURN_IF_ERROR(ToStatus(wrap::hipGraphDebugDotPrint(graph, path, flags), + "Failed to print gpu graph debug file")); if (return_printed_graph) { std::string data; @@ -556,23 +233,6 @@ absl::StatusOr GpuDriver::GraphDebugDotPrint( return std::string(path); } -absl::Status GpuDriver::DeviceGraphMemTrim(GpuDeviceHandle device) { - VLOG(2) << "Trim ROCM device graph memory " << device; - RETURN_IF_ROCM_ERROR(wrap::hipDeviceGraphMemTrim(device), - "Failed to trim device graph memory"); - return absl::OkStatus(); -} - -absl::StatusOr GpuDriver::StreamIsCapturing(GpuStreamHandle stream) { - VLOG(2) << "Checking if stream " << stream << " is capturing"; - - hipStreamCaptureStatus status; - RETURN_IF_ROCM_ERROR(wrap::hipStreamIsCapturing(stream, &status), - "Failed to check stream capturing status"); - - return status == hipStreamCaptureStatusActive; -} - absl::Status GpuDriver::GraphConditionalHandleCreate( GpuGraphConditionalHandle* handle, hipGraph_t graph, Context* context, unsigned int default_launch_value, unsigned int flags) { @@ -596,11 +256,9 @@ absl::Status GpuDriver::GraphAddEmptyNode( absl::Span deps) { VLOG(2) << "Add empty node to a graph " << graph << "; deps: " << deps.size(); - RETURN_IF_ROCM_ERROR( + return ToStatus( wrap::hipGraphAddEmptyNode(node, graph, deps.data(), deps.size()), "Failed to add empty node to a HIP graph"); - - return absl::OkStatus(); } absl::Status GpuDriver::GraphAddKernelNode( @@ -631,32 +289,30 @@ absl::Status GpuDriver::GraphAddKernelNode( params.extra = extra; if (shared_mem_bytes != 0) { - RETURN_IF_ROCM_ERROR( + TF_RETURN_IF_ERROR(ToStatus( wrap::hipFuncSetAttribute(function, hipFuncAttributeMaxDynamicSharedMemorySize, shared_mem_bytes), - "Failed to set shared memory size"); + "Failed to set shared memory size")); } - RETURN_IF_ROCM_ERROR(wrap::hipGraphAddKernelNode(node, graph, deps.data(), - deps.size(), ¶ms), - "Failed to add kernel node to a HIP graph"); - - return absl::OkStatus(); + return ToStatus(wrap::hipGraphAddKernelNode(node, graph, deps.data(), + deps.size(), ¶ms), + "Failed to add kernel node to a HIP graph"); } absl::StatusOr GpuDriver::GraphGetNodeCount(hipGraph_t graph) { VLOG(2) << "Get node count in graph " << graph; size_t numNodes; - RETURN_IF_ROCM_ERROR(wrap::hipGraphGetNodes(graph, nullptr, &numNodes), - "Failed to get HIP graph node count"); + TF_RETURN_IF_ERROR(ToStatus(wrap::hipGraphGetNodes(graph, nullptr, &numNodes), + "Failed to get HIP graph node count")); return numNodes; } /*static*/ absl::Status GpuDriver::GraphExecKernelNodeSetParams( GpuGraphExecHandle exec, GpuGraphNodeHandle node, - absl::string_view kernel_name, GpuFunctionHandle function, + absl::string_view kernel_name, hipFunction_t function, unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, unsigned int block_dim_x, unsigned int block_dim_y, unsigned int block_dim_z, unsigned int shared_mem_bytes, @@ -682,18 +338,15 @@ absl::StatusOr GpuDriver::GraphGetNodeCount(hipGraph_t graph) { params.extra = extra; if (shared_mem_bytes != 0) { - RETURN_IF_ROCM_ERROR( + TF_RETURN_IF_ERROR(ToStatus( wrap::hipFuncSetAttribute(function, hipFuncAttributeMaxDynamicSharedMemorySize, shared_mem_bytes), - "Failed to set shared memory size"); + "Failed to set shared memory size")); } - RETURN_IF_ROCM_ERROR( - wrap::hipGraphExecKernelNodeSetParams(exec, node, ¶ms), - "Failed to set HIP graph kernel node params"); - - return absl::OkStatus(); + return ToStatus(wrap::hipGraphExecKernelNodeSetParams(exec, node, ¶ms), + "Failed to set HIP graph kernel node params"); } absl::Status GpuDriver::GraphAddChildNode(hipGraphNode_t* node, @@ -703,11 +356,10 @@ absl::Status GpuDriver::GraphAddChildNode(hipGraphNode_t* node, VLOG(2) << "Create a new node by cloning the child graph " << child << " and add it to " << graph << "; deps: " << deps.size(); - RETURN_IF_ROCM_ERROR( + return ToStatus( wrap::hipGraphAddChildGraphNode(node, graph, deps.data(), deps.size(), child), "Failed to create a child graph node and add it to a HIP graph"); - return absl::OkStatus(); } /*static*/ absl::Status GpuDriver::GraphExecChildNodeSetParams( @@ -715,44 +367,37 @@ absl::Status GpuDriver::GraphAddChildNode(hipGraphNode_t* node, VLOG(2) << "Set child node params " << node << " in graph executable " << exec << "to params contained in " << child; - RETURN_IF_ROCM_ERROR( - wrap::hipGraphExecChildGraphNodeSetParams(exec, node, child), - "Failed to set HIP graph child node params"); - - return absl::OkStatus(); + return ToStatus(wrap::hipGraphExecChildGraphNodeSetParams(exec, node, child), + "Failed to set HIP graph child node params"); } absl::Status GpuDriver::GraphAddMemcpyD2DNode( Context* context, GpuGraphNodeHandle* node, GpuGraphHandle graph, - absl::Span deps, GpuDevicePtr gpu_dst, - GpuDevicePtr gpu_src, uint64_t size) { + absl::Span deps, hipDeviceptr_t gpu_dst, + hipDeviceptr_t gpu_src, uint64_t size) { VLOG(2) << "Add memcpy d2d node to a graph " << graph << "; dst: " << reinterpret_cast(gpu_dst) << "; src: " << reinterpret_cast(gpu_src) << "; size: " << size << "; context: " << context << "; deps: " << deps.size(); - RETURN_IF_ROCM_ERROR(wrap::hipGraphAddMemcpyNode1D( - node, graph, deps.data(), deps.size(), gpu_dst, - gpu_src, size, hipMemcpyDeviceToDevice), - "Failed to add memcpy d2d node to a HIP graph"); - - return absl::OkStatus(); + return ToStatus(wrap::hipGraphAddMemcpyNode1D(node, graph, deps.data(), + deps.size(), gpu_dst, gpu_src, + size, hipMemcpyDeviceToDevice), + "Failed to add memcpy d2d node to a HIP graph"); } absl::Status GpuDriver::GraphExecMemcpyD2DNodeSetParams( Context* context, GpuGraphExecHandle exec, GpuGraphNodeHandle node, - GpuDevicePtr gpu_dst, GpuDevicePtr gpu_src, uint64_t size) { + hipDeviceptr_t gpu_dst, hipDeviceptr_t gpu_src, uint64_t size) { VLOG(2) << "Set memcpy d2d node params " << node << " in graph executable " << exec << "; dst: " << reinterpret_cast(gpu_dst) << "; src: " << reinterpret_cast(gpu_src) << "; size: " << size << "; context: " << context; - RETURN_IF_ROCM_ERROR( + return ToStatus( wrap::hipGraphExecMemcpyNodeSetParams1D(exec, node, gpu_dst, gpu_src, size, hipMemcpyDeviceToDevice), "Failed to set memcpy d2d node params"); - - return absl::OkStatus(); } namespace { @@ -789,7 +434,7 @@ struct BitPatternToValue { absl::Status GpuDriver::GraphAddMemsetNode( Context* context, GpuGraphNodeHandle* node, GpuGraphHandle graph, - absl::Span deps, GpuDevicePtr dst, + absl::Span deps, hipDeviceptr_t dst, std::variant bit_pattern, uint64_t num_elements) { VLOG(2) << "Add memset node to a graph " << graph @@ -809,16 +454,14 @@ absl::Status GpuDriver::GraphAddMemsetNode( .width = num_elements, }; - RETURN_IF_ROCM_ERROR(wrap::hipGraphAddMemsetNode(node, graph, deps.data(), - deps.size(), ¶ms), - "Failed to add memset node to a HIP graph"); - - return absl::OkStatus(); + return ToStatus(wrap::hipGraphAddMemsetNode(node, graph, deps.data(), + deps.size(), ¶ms), + "Failed to add memset node to a HIP graph"); } absl::Status GpuDriver::GraphExecMemsetNodeSetParams( Context* context, GpuGraphExecHandle exec, GpuGraphNodeHandle node, - GpuDevicePtr dst, std::variant bit_pattern, + hipDeviceptr_t dst, std::variant bit_pattern, uint64_t num_elements) { VLOG(2) << "Set memset node params " << node << " in graph executable " << exec << "; dst: " << reinterpret_cast(dst) @@ -836,535 +479,8 @@ absl::Status GpuDriver::GraphExecMemsetNodeSetParams( .width = num_elements, }; - RETURN_IF_ROCM_ERROR( - wrap::hipGraphExecMemsetNodeSetParams(exec, node, ¶ms), - "Failed to set memset node params"); - - return absl::OkStatus(); -} - -absl::Status GpuDriver::LaunchKernel( - Context* context, absl::string_view kernel_name, hipFunction_t function, - unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, - unsigned int block_dim_x, unsigned int block_dim_y, - unsigned int block_dim_z, unsigned int shared_mem_bytes, - GpuStreamHandle stream, void** kernel_params, void** extra) { - ScopedActivateContext activation{context}; - VLOG(2) << "launching kernel: " << kernel_name << "; gdx: " << grid_dim_x - << " gdy: " << grid_dim_y << " gdz: " << grid_dim_z - << " bdx: " << block_dim_x << " bdy: " << block_dim_y - << " bdz: " << block_dim_z << " smem: " << shared_mem_bytes - << " func: " << (const void*)function; - - auto res = hipSuccess; -#if TF_ROCM_VERSION < 60200 - // for in-process kernel this function returns mangled kernel function name, - // and null otherwise - auto name = wrap::hipKernelNameRefByPtr((const void*)function, stream); - if (name != nullptr) { - res = wrap::hipLaunchKernel((const void*)function, - dim3(grid_dim_x, grid_dim_y, grid_dim_z), - dim3(block_dim_x, block_dim_y, block_dim_z), - kernel_params, shared_mem_bytes, stream); - } else // NOLINT(readability/braces) -#endif // TF_ROCM_VERSION < 60200 - { - res = wrap::hipModuleLaunchKernel( - function, grid_dim_x, grid_dim_y, grid_dim_z, block_dim_x, block_dim_y, - block_dim_z, shared_mem_bytes, stream, kernel_params, extra); - } - RETURN_IF_ROCM_ERROR(res, "Failed to launch ROCm kernel: ", kernel_name, - " with block dimensions: ", block_dim_x, "x", - block_dim_y, "x", block_dim_z); - - VLOG(2) << "successfully launched kernel"; - return absl::OkStatus(); -} - -absl::Status GpuDriver::LaunchKernel( - Context* context, absl::string_view kernel_name, GpuFunctionHandle function, - unsigned int cluster_dim_x, unsigned int cluster_dim_y, - unsigned int cluster_dim_z, unsigned int grid_dim_x, - unsigned int grid_dim_y, unsigned int grid_dim_z, unsigned int block_dim_x, - unsigned int block_dim_y, unsigned int block_dim_z, - unsigned int shared_mem_bytes, GpuStreamHandle stream, void** kernel_params, - void** extra) { - if (cluster_dim_x != 1 || cluster_dim_y != 1 || cluster_dim_z != 1) - return absl::UnimplementedError("Not implemented for ROCm"); - return LaunchKernel(context, kernel_name, function, grid_dim_x, grid_dim_y, - grid_dim_z, block_dim_x, block_dim_y, block_dim_z, - shared_mem_bytes, stream, kernel_params, extra); -} - -absl::Status GpuDriver::LoadPtx(Context* context, const char* ptx_contents, - hipModule_t* module) { - return absl::InternalError( - "Feature not supported on ROCm platform (LoadPtx)"); -} - -absl::Status GpuDriver::LoadCubin(Context* context, const char* cubin_bytes, - hipModule_t* module) { - return absl::InternalError( - "Feature not supported on ROCm platform (LoadCubin)"); -} - -absl::Status GpuDriver::LoadHsaco(Context* context, const char* hsaco_contents, - hipModule_t* module) { - absl::Notification notification; - absl::Status ret = absl::OkStatus(); - GetDriverExecutor()->Schedule( - [context, hsaco_contents, module, &ret, ¬ification]() { - ScopedActivateContext activation{context}; - void* hsaco_data = const_cast(hsaco_contents); - - hipError_t res = wrap::hipModuleLoadData(module, hsaco_data); - - if (res != hipSuccess) { - ret = absl::InternalError( - absl::StrCat("Failed to load HSACO: ", ToString(res))); - notification.Notify(); - } - - CHECK(module != nullptr); - notification.Notify(); - }); - notification.WaitForNotification(); - - return ret; -} - -absl::Status GpuDriver::SynchronousMemsetUint8(Context* context, - hipDeviceptr_t location, - uint8 value, size_t size) { - ScopedActivateContext activation{context}; - RETURN_IF_ROCM_ERROR(wrap::hipMemsetD8(location, value, size), - "Failed to memset memory"); - return absl::OkStatus(); -} - -absl::Status GpuDriver::SynchronousMemsetUint32(Context* context, - hipDeviceptr_t location, - uint32 value, - size_t uint32_count) { - ScopedActivateContext activation{context}; - void* pointer = absl::bit_cast(location); - RETURN_IF_ROCM_ERROR(wrap::hipMemsetD32(pointer, value, uint32_count), - "Failed to memset memory"); - return absl::OkStatus(); -} - -absl::Status GpuDriver::AsynchronousMemsetUint8(Context* context, - hipDeviceptr_t location, - uint8 value, - size_t uint32_count, - GpuStreamHandle stream) { - ScopedActivateContext activation{context}; - RETURN_IF_ROCM_ERROR( - wrap::hipMemsetAsync(location, value, uint32_count, stream), - "Failed to enqueue async memset operation"); - return absl::OkStatus(); -} - -absl::Status GpuDriver::AsynchronousMemsetUint32(Context* context, - hipDeviceptr_t location, - uint32 value, - size_t uint32_count, - GpuStreamHandle stream) { - ScopedActivateContext activation{context}; - void* pointer = absl::bit_cast(location); - RETURN_IF_ROCM_ERROR( - wrap::hipMemsetD32Async(pointer, value, uint32_count, stream), - "Failed to enqueue async memset operation"); - VLOG(2) << "successfully enqueued async memset operation"; - return absl::OkStatus(); -} - -absl::Status GpuDriver::AddStreamCallback(Context* context, - GpuStreamHandle stream, - StreamCallback callback, void* data) { - RETURN_IF_ROCM_ERROR( - wrap::hipLaunchHostFunc(stream, (hipHostFn_t)callback, data), - "unable to add host callback"); - return absl::OkStatus(); -} - -absl::Status GpuDriver::GetModuleFunction(Context* context, hipModule_t module, - const char* kernel_name, - hipFunction_t* function) { - ScopedActivateContext activated{context}; - CHECK(module != nullptr && kernel_name != nullptr); - RETURN_IF_ROCM_ERROR( - wrap::hipModuleGetFunction(function, module, kernel_name), - "Failed to get kernel"); - return absl::OkStatus(); -} - -absl::Status GpuDriver::GetModuleSymbol(Context* context, hipModule_t module, - const char* symbol_name, - hipDeviceptr_t* dptr, size_t* bytes) { - ScopedActivateContext activated{context}; - CHECK(module != nullptr && symbol_name != nullptr && - (dptr != nullptr || bytes != nullptr)); - RETURN_IF_ROCM_ERROR( - wrap::hipModuleGetGlobal(dptr, bytes, module, symbol_name), - absl::StrCat("Failed to get symbol '", symbol_name, "'")); - return absl::OkStatus(); -} - -void GpuDriver::UnloadModule(Context* context, hipModule_t module) { - ScopedActivateContext activated{context}; - hipError_t res = wrap::hipModuleUnload(module); - if (res != hipSuccess) { - LOG(ERROR) << "failed to unload module " << module - << "; leaking: " << ToString(res); - } -} - -absl::StatusOr GpuDriver::CreateStream(Context* context, - int priority) { - ScopedActivateContext activated(context); - GpuStreamHandle stream; - if (priority == 0) { - RETURN_IF_ROCM_ERROR( - wrap::hipStreamCreateWithFlags(&stream, hipStreamDefault), - "Failed to create stream"); // switch to hipStreamNonBlocking? - } else { - RETURN_IF_ROCM_ERROR( - wrap::hipStreamCreateWithPriority(&stream, hipStreamDefault, priority), - "Failed to create stream"); // switch to hipStreamNonBlocking? - } - - VLOG(2) << "successfully created stream " << stream << " for device " - << context->device_ordinal() << " on thread"; - return stream; -} - -void GpuDriver::DestroyStream(Context* context, GpuStreamHandle stream) { - if (stream == nullptr) { - return; - } - hipError_t res = wrap::hipStreamQuery(stream); - if (res != hipSuccess) { - LOG(ERROR) << "stream not idle on destroy: " << ToString(res); - } - - ScopedActivateContext activated(context); - res = wrap::hipStreamDestroy(stream); - if (res != hipSuccess) { - LOG(ERROR) << "failed to destroy ROCM stream for device " - << context->device_ordinal() << ": " << ToString(res); - } else { - VLOG(2) << "successfully destroyed stream " << stream << " for device " - << context->device_ordinal(); - } -} - -void* GpuDriver::DeviceAllocate(Context* context, uint64_t bytes) { - if (bytes == 0) { - return nullptr; - } - - ScopedActivateContext activated{context}; - hipDeviceptr_t result = 0; - hipError_t res = wrap::hipMalloc(&result, bytes); - if (res != hipSuccess) { - // LOG(INFO) because this isn't always important to users (e.g. BFCAllocator - // implements a retry if the first allocation fails). - LOG(INFO) << "failed to allocate " - << tsl::strings::HumanReadableNumBytes(bytes) << " (" << bytes - << " bytes) from device: " << ToString(res); - return nullptr; - } - void* ptr = reinterpret_cast(result); - VLOG(2) << "allocated " << ptr << " for device " << context->device_ordinal() - << " of " << bytes << " bytes"; - return ptr; -} - -void GpuDriver::DeviceDeallocate(Context* context, void* location) { - ScopedActivateContext activation{context}; - hipDeviceptr_t pointer = absl::bit_cast(location); - hipError_t res = wrap::hipFree(pointer); - if (res != hipSuccess) { - LOG(ERROR) << "failed to free device memory at " << location - << "; result: " << ToString(res); - } else { - VLOG(2) << "deallocated " << location << " for device " - << context->device_ordinal(); - } -} - -void* GpuDriver::UnifiedMemoryAllocate(Context* context, uint64_t bytes) { - ScopedActivateContext activated{context}; - hipDeviceptr_t result = 0; - // "managed" memory is visible to both CPU and GPU. - hipError_t res = wrap::hipMallocManaged(&result, bytes, hipMemAttachGlobal); - if (res != hipSuccess) { - LOG(ERROR) << "failed to alloc " << bytes - << " bytes unified memory; result: " << ToString(res); - return nullptr; - } - void* ptr = reinterpret_cast(result); - VLOG(2) << "allocated " << ptr << " for context " << context << " of " - << bytes << " bytes in unified memory"; - return ptr; -} - -void GpuDriver::UnifiedMemoryDeallocate(Context* context, void* location) { - ScopedActivateContext activation(context); - hipDeviceptr_t pointer = absl::bit_cast(location); - hipError_t res = wrap::hipFree(pointer); - if (res != hipSuccess) { - LOG(ERROR) << "failed to free unified memory at " << location - << "; result: " << ToString(res); - } else { - VLOG(2) << "deallocated unified memory at " << location << " for context " - << context; - } -} - -void* GpuDriver::HostAllocate(Context* context, uint64_t bytes) { - ScopedActivateContext activation{context}; - void* host_mem = nullptr; - // "Portable" memory is visible to all ROCM contexts. Safe for our use model. - hipError_t res = wrap::hipHostMalloc(&host_mem, bytes, hipHostMallocPortable); - if (res != hipSuccess) { - LOG(ERROR) << "failed to alloc " << bytes - << " bytes on host: " << ToString(res); - } - return host_mem; -} - -void GpuDriver::HostDeallocate(Context* context, void* location) { - ScopedActivateContext activation{context}; - hipError_t res = wrap::hipHostFree(location); - if (res != hipSuccess) { - LOG(ERROR) << "error deallocating host memory at " << location << ": " - << ToString(res); - } -} - -int GpuDriver::GetGpuStreamPriority( - Context* context, stream_executor::StreamPriority stream_priority) { - ScopedActivateContext activation(context); - if (stream_priority == stream_executor::StreamPriority::Default) { - return 0; - } - int lowest, highest; - hipError_t res = wrap::hipDeviceGetStreamPriorityRange(&lowest, &highest); - if (res != hipSuccess) { - LOG(ERROR) - << "Could not query stream priority range. Returning default priority."; - return 0; - } - return stream_priority == stream_executor::StreamPriority::Highest ? highest - : lowest; -} - -absl::Status GpuDriver::DestroyEvent(Context* context, GpuEventHandle* event) { - if (*event == nullptr) { - return absl::InvalidArgumentError("input event cannot be null"); - } - - ScopedActivateContext activated{context}; - hipError_t res = wrap::hipEventDestroy(*event); - *event = nullptr; - - switch (res) { - case hipSuccess: - return absl::OkStatus(); - case hipErrorDeinitialized: - case hipErrorNotInitialized: - return absl::FailedPreconditionError( - absl::StrFormat("error destroying ROCM event in device %d: %s", - context->device_ordinal(), ToString(res).c_str())); - default: - return absl::InternalError( - absl::StrFormat("error destroying ROCM event in device %d: %s", - context->device_ordinal(), ToString(res).c_str())); - } -} - -absl::Status GpuDriver::RecordEvent(Context* context, GpuEventHandle event, - GpuStreamHandle stream) { - ScopedActivateContext activated{context}; - hipError_t res = wrap::hipEventRecord(event, stream); - switch (res) { - case hipSuccess: - return absl::OkStatus(); - case hipErrorDeinitialized: - case hipErrorNotInitialized: - return absl::FailedPreconditionError( - absl::StrFormat("error recording ROCM event on stream %p: %s", stream, - ToString(res).c_str())); - default: - return absl::InvalidArgumentError( - absl::StrFormat("error recording ROCM event on stream %p: %s", stream, - ToString(res).c_str())); - } -} - -absl::StatusOr GpuDriver::GetEventElapsedTime(Context* context, - GpuEventHandle start, - GpuEventHandle stop) { - ScopedActivateContext activated{context}; - // The stop event must have completed in order for hipEventElapsedTime to - // work. - hipError_t res = wrap::hipEventSynchronize(stop); - if (res != hipSuccess) { - LOG(ERROR) << "failed to synchronize the stop event: " << ToString(res); - return false; - } - float elapsed_milliseconds; - RETURN_IF_ROCM_ERROR( - wrap::hipEventElapsedTime(&elapsed_milliseconds, start, stop), - "failed to get elapsed time between events"); - - return elapsed_milliseconds; -} - -absl::Status GpuDriver::WaitStreamOnEvent(Context* context, - GpuStreamHandle stream, - GpuEventHandle event) { - ScopedActivateContext activation{context}; - RETURN_IF_ROCM_ERROR(wrap::hipStreamWaitEvent(stream, event, 0 /* = flags */), - "could not wait stream on event"); - return absl::OkStatus(); -} - -absl::Status GpuDriver::SynchronizeContext(Context* context) { - ScopedActivateContext activation{context}; - RETURN_IF_ROCM_ERROR(wrap::hipDeviceSynchronize(), - "could not synchronize on ROCM device"); - return absl::OkStatus(); -} - -absl::Status GpuDriver::SynchronizeStream(Context* context, - GpuStreamHandle stream) { - ScopedActivateContext activated{context}; - CHECK(stream != nullptr); - RETURN_IF_ROCM_ERROR(wrap::hipStreamSynchronize(stream), - "Could not synchronize on ROCM stream"); - VLOG(2) << "successfully synchronized stream " << stream << " on device " - << context->device_ordinal(); - return absl::OkStatus(); -} - -absl::Status GpuDriver::SynchronousMemcpyD2H(Context* context, void* host_dst, - hipDeviceptr_t gpu_src, - uint64_t size) { - ScopedActivateContext activation{context}; - RETURN_IF_ROCM_ERROR( - wrap::hipMemcpyDtoH(host_dst, gpu_src, size), - absl::StrFormat("failed to synchronous memcpy from device to host: " - "host dst: %p; Gpu src: %p; size: %llu=0x%llx", - host_dst, absl::bit_cast(gpu_src), size, size)); - VLOG(2) << "successfully sync memcpy'd d2h of " << size << " bytes to " - << host_dst; - return absl::OkStatus(); -} - -absl::Status GpuDriver::SynchronousMemcpyH2D(Context* context, - hipDeviceptr_t gpu_dst, - const void* host_src, - uint64_t size) { - ScopedActivateContext activation{context}; - RETURN_IF_ROCM_ERROR( - wrap::hipMemcpyHtoD(gpu_dst, const_cast(host_src), size), - absl::StrFormat( - "failed to synchronous memcpy from host to device: Gpu dst: %p;" - " host src: %p; size: %llu=0x%llx", - absl::bit_cast(gpu_dst), host_src, size, size)); - VLOG(2) << "successfully sync memcpy'd h2d of " << size << " bytes"; - return absl::OkStatus(); -} - -absl::Status GpuDriver::AsynchronousMemcpyD2H(Context* context, void* host_dst, - hipDeviceptr_t gpu_src, - uint64_t size, - GpuStreamHandle stream) { - ScopedActivateContext activation{context}; - RETURN_IF_ROCM_ERROR( - wrap::hipMemcpyDtoHAsync(host_dst, gpu_src, size, stream), - absl::StrFormat( - "failed to enqueue async memcpy from device to host: host dst: %p; " - "Gpu src: %p; size: %llu=0x%llx", - host_dst, absl::bit_cast(gpu_src), size, size)); - - VLOG(2) << "successfully enqueued async memcpy d2h of " << size - << " bytes from " << absl::bit_cast(gpu_src) << " to " - << host_dst << " on stream " << stream - << " device: " << context->device_ordinal(); - return absl::OkStatus(); -} - -absl::Status GpuDriver::AsynchronousMemcpyH2D(Context* context, - hipDeviceptr_t gpu_dst, - const void* host_src, - uint64_t size, - GpuStreamHandle stream) { - ScopedActivateContext activation{context}; - RETURN_IF_ROCM_ERROR( - wrap::hipMemcpyHtoDAsync(gpu_dst, const_cast(host_src), size, - stream), - absl::StrFormat( - "failed to enqueue async memcpy from host to device: Gpu dst: %p; " - "host src: %p; size: %llu=0x%llx", - absl::bit_cast(gpu_dst), host_src, size, size)); - - VLOG(2) << "successfully enqueued async memcpy h2d of " << size - << " bytes from " << host_src << " to " - << absl::bit_cast(gpu_dst) << " on stream " << stream - << " device: " << context->device_ordinal(); - return absl::OkStatus(); -} - -absl::Status GpuDriver::AsynchronousMemcpyD2D(Context* context, - hipDeviceptr_t gpu_dst, - hipDeviceptr_t gpu_src, - uint64_t size, - GpuStreamHandle stream) { - ScopedActivateContext activation{context}; - RETURN_IF_ROCM_ERROR( - wrap::hipMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream), - absl::StrFormat("failed to enqueue async memcpy from device to device: " - "Gpu dst: %p ; Gpu src: %p ; size: %llu=0x%llx", - absl::bit_cast(gpu_dst), - absl::bit_cast(gpu_src), size, size)); - - VLOG(2) << "successfully enqueued async memcpy d2d of " << size - << " bytes from " << absl::bit_cast(gpu_src) << " to " - << absl::bit_cast(gpu_dst) << " on stream " << stream - << " device: " << context->device_ordinal(); - return absl::OkStatus(); -} - -absl::Status GpuDriver::InitEvent(Context* context, GpuEventHandle* event, - EventFlags flags) { - int hipflags; - switch (flags) { - case EventFlags::kDefault: - hipflags = hipEventDefault; - break; - case EventFlags::kDisableTiming: - hipflags = hipEventDisableTiming | hipEventReleaseToSystem; - break; - default: - LOG(FATAL) << "impossible event flags: " << int(hipflags); - } - - ScopedActivateContext activated{context}; - hipError_t res = wrap::hipEventCreateWithFlags(event, hipflags); - - if (res == hipSuccess) { - return absl::OkStatus(); - } else if (res == hipErrorMemoryAllocation) { - return absl::ResourceExhaustedError( - "could not create ROCM event: out of device memory"); - } else { - return absl::FailedPreconditionError( - absl::StrCat("could not create ROCM event: ", ToString(res))); - } + return ToStatus(wrap::hipGraphExecMemsetNodeSetParams(exec, node, ¶ms), + "Failed to set memset node params"); } int GpuDriver::GetDeviceCount() { @@ -1378,351 +494,11 @@ int GpuDriver::GetDeviceCount() { return device_count; } -absl::Status GpuDriver::GetComputeCapability(int* cc_major, int* cc_minor, - hipDevice_t device) { - return absl::InternalError( - absl::StrFormat("failed to get compute capability for device: %d " - "(unsupported API on AMD Gpus)", - device)); -} - -absl::Status GpuDriver::GetPointerAddressRange(hipDeviceptr_t dptr, - hipDeviceptr_t* base, - size_t* size) { - hipError_t result = wrap::hipMemGetAddressRange(base, size, dptr); - if (result == hipSuccess) { - return absl::OkStatus(); - } else if (result == hipErrorNotFound) { - // We differentiate between "this pointer is unknown" (return here) and - // "there was an internal error while performing this operation" (return - // below). - return absl::NotFoundError(absl::StrFormat("not a device pointer %p; %s", - reinterpret_cast(dptr), - ToString(result).c_str())); - } - - return absl::InternalError( - absl::StrFormat("failed to get pointer into for device pointer %p; %s", - reinterpret_cast(dptr), ToString(result).c_str())); -} - -absl::StatusOr GpuDriver::GetPointerMemorySpace( - hipDeviceptr_t pointer) { - unsigned int value; - hipError_t result = wrap::hipPointerGetAttribute( - &value, HIP_POINTER_ATTRIBUTE_MEMORY_TYPE, pointer); - if (result == hipSuccess) { - switch (value) { - case hipMemoryTypeDevice: - return MemoryType::kDevice; - case hipMemoryTypeHost: - return MemoryType::kHost; - default: - return absl::InternalError( - absl::StrCat("unknown memory space provided by ROCM API: ", value)); - } - } - - return absl::InternalError(absl::StrCat( - "failed to query device pointer for memory space: ", ToString(result))); -} - -absl::Status GpuDriver::GetGpuISAVersion(int* version, hipDevice_t device) { - hipDeviceProp_t props; - hipError_t result = wrap::hipGetDeviceProperties(&props, device); - if (result == hipSuccess) { - std::string gcnName = props.gcnArchName; - std::vector tokens = absl::StrSplit(gcnName, ':'); - std::string amdgpu_version = gcnName; - if (!tokens.empty() && tokens[0].size() >= 3) { - amdgpu_version = tokens[0].substr(3); - } - *version = stoi(amdgpu_version); - return absl::OkStatus(); - } - *version = 0; - return absl::InternalError(absl::StrFormat( - "failed to determine AMDGpu ISA version for device %d", device)); -} - -absl::Status GpuDriver::GetGpuGCNArchName(hipDevice_t device, - std::string* gcnArchName) { - hipDeviceProp_t props; - hipError_t result = wrap::hipGetDeviceProperties(&props, device); - if (result == hipSuccess) { - *gcnArchName = props.gcnArchName; - return absl::OkStatus(); - } - *gcnArchName = ""; - return absl::InternalError(absl::StrFormat( - "failed to determine AMDGpu GCN Arch Name for device %d", device)); -} - -// Helper function that turns the integer output of hipDeviceGetAttribute to -// type T and wraps it in a absl::StatusOr. -template -static absl::StatusOr GetSimpleAttribute(hipDevice_t device, - hipDeviceAttribute_t attribute) { - int value = -1; - hipError_t result = wrap::hipDeviceGetAttribute(&value, attribute, device); - if (result != hipSuccess) { - return absl::NotFoundError( - absl::StrCat("could not retrieve ROCM device attribute (", attribute, - "): ", ToString(result))); - } - T converted = value; - return converted; -} - -absl::StatusOr GpuDriver::GetMultiprocessorCount(hipDevice_t device) { - return GetSimpleAttribute(device, hipDeviceAttributeMultiprocessorCount); -} - -absl::StatusOr GpuDriver::GetMaxSharedMemoryPerCore( - hipDevice_t device) { - return GetSimpleAttribute( - device, hipDeviceAttributeMaxSharedMemoryPerMultiprocessor); -} - -absl::StatusOr GpuDriver::GetMaxSharedMemoryPerBlock( - hipDevice_t device) { - return GetSimpleAttribute(device, - hipDeviceAttributeMaxSharedMemoryPerBlock); -} - -absl::StatusOr GpuDriver::GetMaxThreadsPerMultiprocessor( - hipDevice_t device) { - return GetSimpleAttribute( - device, hipDeviceAttributeMaxThreadsPerMultiProcessor); -} - -absl::StatusOr GpuDriver::GetMaxRegistersPerBlock(hipDevice_t device) { - return GetSimpleAttribute(device, - hipDeviceAttributeMaxRegistersPerBlock); -} - -absl::StatusOr GpuDriver::GetThreadsPerWarp(hipDevice_t device) { - return GetSimpleAttribute(device, hipDeviceAttributeWarpSize); -} - -absl::Status GpuDriver::GetGridLimits(int* x, int* y, int* z, - hipDevice_t device) { - int value; - RETURN_IF_ROCM_ERROR(wrap::hipDeviceGetAttribute( - &value, hipDeviceAttributeMaxGridDimX, device), - "failed to query max grid dim x"); - *x = value; - - RETURN_IF_ROCM_ERROR(wrap::hipDeviceGetAttribute( - &value, hipDeviceAttributeMaxGridDimY, device), - "failed to query max grid dim y"); - *y = value; - - RETURN_IF_ROCM_ERROR(wrap::hipDeviceGetAttribute( - &value, hipDeviceAttributeMaxGridDimZ, device), - "failed to query max grid dim z"); - *z = value; - return absl::OkStatus(); -} - absl::StatusOr GpuDriver::GetDriverVersion() { int32_t version; - RETURN_IF_ROCM_ERROR(wrap::hipDriverGetVersion(&version), - "Could not get driver version"); + TF_RETURN_IF_ERROR(ToStatus(wrap::hipDriverGetVersion(&version), + "Could not get driver version")); return version; } -bool GpuDriver::GetDeviceProperties(hipDeviceProp_t* device_properties, - int device_ordinal) { - hipError_t res = - wrap::hipGetDeviceProperties(device_properties, device_ordinal); - if (res != hipSuccess) { - LOG(ERROR) << "failed to query device properties: " << ToString(res); - return false; - } - - return true; -} - -absl::StatusOr GpuDriver::GetDeviceAttribute( - hipDeviceAttribute_t attribute, hipDevice_t device) { - return GetSimpleAttribute(device, attribute); -} - -bool GpuDriver::IsEccEnabled(hipDevice_t device, bool* result) { - int value = -1; - hipError_t res = hipSuccess; - // TODO(ROCm) implement this feature in HIP - if (res != hipSuccess) { - LOG(ERROR) << "failed to query ECC status: " << ToString(res); - return false; - } - - *result = value; - return true; -} - -bool GetReservedMemory(uint64_t* reserve) { - hipDeviceProp_t props; - hipDevice_t dev; - hipError_t res = wrap::hipGetDevice(&dev); - - if (res != hipSuccess) { - LOG(FATAL) << "failed to query current device: " << ToString(res); - return false; - } - res = wrap::hipGetDeviceProperties(&props, dev); - if (res != hipSuccess) { - LOG(ERROR) << "failed to query device properties: " << ToString(res); - return false; - } - - std::string gcnArchName = props.gcnArchName; - auto compute_capability = RocmComputeCapability(gcnArchName); - // On gfx90a, we hide 1 GB of GPU memory (512MB for gfx908) from TF, - // to allow for late allocations by internal ROCm libraries - // (e.g. rocBLAS alone needs~200 MB to put its kernels as of ROCm 4.1) - const uint64_t RESERVED_GFX908 = 1048576 * 512; - const uint64_t RESERVED_GFX9_X = 1048576 * 1024; - const uint64_t RESERVED_GFX10_X = 1048576 * 512; - const uint64_t RESERVED_GFX11_X = 1048576 * 512; - if (compute_capability.gfx9_mi100()) { - *reserve = RESERVED_GFX908; - } else if (compute_capability.gfx9_mi200_or_later()) { - *reserve = RESERVED_GFX9_X; - } else if (compute_capability.gfx10_rx68xx() || - compute_capability.gfx10_rx69xx()) { - *reserve = RESERVED_GFX10_X; - } else if (compute_capability.gfx11_rx7900()) { - *reserve = RESERVED_GFX11_X; - } - - return true; -} - -bool GpuDriver::GetDeviceMemoryInfo(Context* context, int64_t* free_out, - int64_t* total_out) { - ScopedActivateContext activation{context}; - size_t free = 0; - size_t total = 0; - hipError_t res = wrap::hipMemGetInfo(&free, &total); - if (res != hipSuccess) { - LOG(ERROR) << "failed to query device memory info: " << ToString(res); - return false; - } - - uint64_t reserve = 0; - if (!GetReservedMemory(&reserve)) { - LOG(ERROR) << "failed to reserved device memory for ROCm libraries"; - return false; - } - - VLOG(1) << "Device memory: " << total / 1048576 << " MB total, " - << free / 1048576 << " MB free, reserving " << reserve / 1048576 - << " MB"; - - // overflow check - if (free > std::numeric_limits::max()) { - LOG(ERROR) << "free memory (" << free << ") is overflow int64_t"; - return false; - } - - *free_out = free >= reserve ? free - reserve : 0; - *total_out = total - reserve; - return true; -} - -bool GpuDriver::GetDeviceTotalMemory(hipDevice_t device, uint64_t* result) { - size_t value = -1; - hipError_t res = wrap::hipDeviceTotalMem(&value, device); - if (res != hipSuccess) { - LOG(ERROR) << "failed to query total available memory: " << ToString(res); - return false; - } - uint64_t reserve = 0; - if (!GetReservedMemory(&reserve)) { - LOG(ERROR) << "failed to reserved device memory for ROCm libraries"; - return false; - } - *result = value - reserve; - return true; -} - -std::string GpuDriver::GetPCIBusID(hipDevice_t device) { - std::string pci_bus_id; - static const int kBufferSize = 64; - absl::InlinedVector chars(kBufferSize); - chars[kBufferSize - 1] = '\0'; - hipError_t res = - wrap::hipDeviceGetPCIBusId(chars.begin(), kBufferSize - 1, device); - if (res != hipSuccess) { - LOG(ERROR) << "failed to query PCI bus id for device: " << ToString(res); - return pci_bus_id; - } - pci_bus_id = chars.begin(); - return pci_bus_id; -} - -bool GpuDriver::CanEnablePeerAccess(Context* from, Context* to) { - // A context can always access its own memory. - if (from == to) return true; - - auto from_device = DeviceFromContext(from); - if (!from_device.ok()) { - LOG(ERROR) << "failed to resolve 'from' peer access context to a device: " - << from_device.status(); - return false; - } - - auto to_device = DeviceFromContext(to); - if (!to_device.ok()) { - LOG(ERROR) << "failed to resolve 'to' peer access context to a device: " - << to_device.status(); - return false; - } - return CanEnablePeerAccess(from_device.value(), to_device.value()); -} - -bool GpuDriver::CanEnablePeerAccess(GpuDeviceHandle from, GpuDeviceHandle to) { - int can_access_peer = -1; - hipError_t result = wrap::hipDeviceCanAccessPeer(&can_access_peer, from, to); - if (result != hipSuccess) { - LOG(ERROR) << "failed to detect peer access capability: " - << ToString(result); - return false; - } - return can_access_peer; -} - -absl::Status GpuDriver::EnablePeerAccess(Context* from, Context* to) { - if (from == to) { - return absl::OkStatus(); // A device can always access its own memory. - } - - ScopedActivateContext activated{from}; - hipError_t result = wrap::hipCtxEnablePeerAccess( - tensorflow::down_cast(to)->context(), 0 /* = flags */); - if (result != hipSuccess && result != hipErrorPeerAccessAlreadyEnabled) { - return absl::InternalError( - absl::StrFormat("failed to enable peer access from %d to %d: %s", - from->device_ordinal(), to->device_ordinal(), - ToString(result).c_str())); - } - - return absl::OkStatus(); -} - -absl::StatusOr GpuDriver::GetMaxOccupiedBlocksPerCore( - Context* context, hipFunction_t kernel, int threads_per_block, - size_t dynamic_shared_memory_bytes) { - ScopedActivateContext activation{context}; - - int max_blocks = 0; - RETURN_IF_ROCM_ERROR( - wrap::hipModuleOccupancyMaxActiveBlocksPerMultiprocessor( - &max_blocks, kernel, threads_per_block, dynamic_shared_memory_bytes), - "Failed to calculate maximal active blocks per SM"); - return max_blocks; -} - } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_driver.h b/third_party/xla/xla/stream_executor/rocm/rocm_driver.h deleted file mode 100644 index aa1538f3886fbd..00000000000000 --- a/third_party/xla/xla/stream_executor/rocm/rocm_driver.h +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// The ROCM-specific Driver library support, implementing the general Driver -// interface. - -#ifndef XLA_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_H_ -#define XLA_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_H_ - -#include "absl/container/node_hash_map.h" -#include "absl/memory/memory.h" -#include "absl/strings/str_cat.h" -#include "absl/synchronization/mutex.h" -#include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_driver.h" -#include "tsl/platform/logging.h" - -namespace stream_executor { -namespace gpu { -// Formats hipError_t to output prettified values into a log stream. -// Error summaries taken from: -std::string ToString(hipError_t result); - -// GpuContext implements the Context class for ROCm GPUs. -class GpuContext : public Context { - public: - GpuContext(hipCtx_t context, const int ordinal) - : context_(context), device_ordinal_(ordinal) {} - - hipCtx_t context() const { return context_; } - void SetActive() override; - bool IsActive() const override; - int device_ordinal() const override { return device_ordinal_; } - - // Disallow copying and moving. - GpuContext(GpuContext&&) = delete; - GpuContext(const GpuContext&) = delete; - GpuContext& operator=(GpuContext&&) = delete; - GpuContext& operator=(const GpuContext&) = delete; - - private: - hipCtx_t const context_; - const int device_ordinal_; -}; - -} // namespace gpu -} // namespace stream_executor - -#endif // XLA_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_H_ diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h b/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h index 6d541adef338e8..391890e012f2ef 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_driver_wrapper.h @@ -20,12 +20,9 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_ #define XLA_STREAM_EXECUTOR_ROCM_ROCM_DRIVER_WRAPPER_H_ -#define __HIP_DISABLE_CPP_FUNCTIONS__ - #include "rocm/include/hip/hip_runtime.h" #include "rocm/rocm_config.h" -#include "xla/stream_executor/platform/dso_loader.h" -#include "xla/stream_executor/platform/port.h" +#include "tsl/platform/dso_loader.h" #include "tsl/platform/env.h" namespace stream_executor { @@ -47,22 +44,21 @@ namespace wrap { #define TO_STR_(x) #x #define TO_STR(x) TO_STR_(x) -#define STREAM_EXECUTOR_HIP_WRAP(hipSymbolName) \ - template \ - auto hipSymbolName(Args... args) -> decltype(::hipSymbolName(args...)) { \ - using FuncPtrT = std::add_pointer::type; \ - static FuncPtrT loaded = []() -> FuncPtrT { \ - static const char *kName = TO_STR(hipSymbolName); \ - void *f; \ - auto s = tsl::Env::Default()->GetSymbolFromLibrary( \ - stream_executor::internal::CachedDsoLoader::GetHipDsoHandle() \ - .value(), \ - kName, &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in HIP DSO; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - }(); \ - return loaded(args...); \ +#define STREAM_EXECUTOR_HIP_WRAP(hipSymbolName) \ + template \ + auto hipSymbolName(Args... args) -> decltype(::hipSymbolName(args...)) { \ + using FuncPtrT = std::add_pointer::type; \ + static FuncPtrT loaded = []() -> FuncPtrT { \ + static const char *kName = TO_STR(hipSymbolName); \ + void *f; \ + auto s = tsl::Env::Default()->GetSymbolFromLibrary( \ + tsl::internal::CachedDsoLoader::GetHipDsoHandle().value(), kName, \ + &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in HIP DSO; dlerror: " << s.message(); \ + return reinterpret_cast(f); \ + }(); \ + return loaded(args...); \ } #endif diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_event.cc b/third_party/xla/xla/stream_executor/rocm/rocm_event.cc index d770af56a9f5a4..1eb63f12df2a19 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_event.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_event.cc @@ -15,18 +15,85 @@ limitations under the License. #include "xla/stream_executor/rocm/rocm_event.h" -#include "xla/stream_executor/gpu/gpu_event.h" -#include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/scoped_activate_context.h" -#include "xla/stream_executor/rocm/rocm_driver.h" +#include +#include + +#include "absl/base/casts.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "rocm/include/hip/hip_runtime.h" +#include "xla/stream_executor/activate_context.h" +#include "xla/stream_executor/event.h" #include "xla/stream_executor/rocm/rocm_driver_wrapper.h" +#include "xla/stream_executor/rocm/rocm_status.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace stream_executor { namespace gpu { +namespace { +absl::Status WaitStreamOnEvent(StreamExecutor *executor, hipStream_t stream, + hipEvent_t event) { + std::unique_ptr activation = executor->Activate(); + TF_RETURN_IF_ERROR( + ToStatus(wrap::hipStreamWaitEvent(stream, event, 0 /* = flags */), + "could not wait stream on event")); + return absl::OkStatus(); +} + +enum class EventFlags { kDefault, kDisableTiming }; +absl::StatusOr InitEvent(StreamExecutor *executor, + EventFlags flags) { + int hipflags; + switch (flags) { + case EventFlags::kDefault: + hipflags = hipEventDefault; + break; + case EventFlags::kDisableTiming: + hipflags = hipEventDisableTiming | hipEventReleaseToSystem; + break; + default: + LOG(FATAL) << "impossible event flags: " << int(hipflags); + } + + std::unique_ptr activation = executor->Activate(); + hipEvent_t event; + hipError_t res = wrap::hipEventCreateWithFlags(&event, hipflags); + + if (res == hipSuccess) { + return event; + } + if (res == hipErrorMemoryAllocation) { + return absl::ResourceExhaustedError( + "could not create ROCM event: out of device memory"); + } + return absl::FailedPreconditionError( + absl::StrCat("could not create ROCM event: ", ToString(res))); +} + +void DestroyEvent(StreamExecutor *executor, hipEvent_t event) { + if (event == nullptr) { + return; + } + + std::unique_ptr activation = executor->Activate(); + hipError_t res = wrap::hipEventDestroy(event); + + if (res != hipSuccess) { + LOG(ERROR) << absl::StrFormat( + "error destroying ROCM event in device %d: %s", + executor->device_ordinal(), ToString(res)); + } +} + +} // namespace Event::Status RocmEvent::PollForStatus() { - ScopedActivateContext activated(context()); - hipError_t res = wrap::hipEventQuery(gpu_event()); + std::unique_ptr activated = executor_->Activate(); + hipError_t res = wrap::hipEventQuery(handle_); if (res == hipSuccess) { return Event::Status::kComplete; @@ -37,5 +104,41 @@ Event::Status RocmEvent::PollForStatus() { return Event::Status::kError; } +absl::Status RocmEvent::WaitForEventOnExternalStream(std::intptr_t stream) { + return WaitStreamOnEvent(executor_, absl::bit_cast(stream), + handle_); +} + +absl::StatusOr RocmEvent::Create(StreamExecutor *executor, + bool allow_timing) { + TF_ASSIGN_OR_RETURN( + hipEvent_t event_handle, + InitEvent(executor, allow_timing ? EventFlags::kDefault + : EventFlags::kDisableTiming)); + + return RocmEvent(executor, event_handle); +} + +RocmEvent::~RocmEvent() { DestroyEvent(executor_, handle_); } + +RocmEvent::RocmEvent(RocmEvent &&other) + : executor_(other.executor_), handle_(other.handle_) { + other.executor_ = nullptr; + other.handle_ = nullptr; +} + +RocmEvent& RocmEvent::operator=(RocmEvent&& other) { + if (this == &other) { + return *this; + } + + DestroyEvent(executor_, handle_); + + executor_ = other.executor_; + handle_ = other.handle_; + other.executor_ = nullptr; + other.handle_ = nullptr; + return *this; +} } // namespace gpu } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_event.h b/third_party/xla/xla/stream_executor/rocm/rocm_event.h index d81a207f3c2eed..81e0cbbaa03863 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_event.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_event.h @@ -16,17 +16,44 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_ROCM_ROCM_EVENT_H_ #define XLA_STREAM_EXECUTOR_ROCM_ROCM_EVENT_H_ -#include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_event.h" +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "rocm/include/hip/hip_runtime.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/stream_executor.h" namespace stream_executor::gpu { -// This class implements Event::PollForStatus for ROCm devices. -class RocmEvent : public GpuEvent { +// This class implements Event for ROCm devices. +class RocmEvent : public Event { public: - explicit RocmEvent(Context *context) : GpuEvent(context) {} - Event::Status PollForStatus() override; + absl::Status WaitForEventOnExternalStream(std::intptr_t stream) override; + + // Creates a new RocmEvent. If allow_timing is false, the event will not + // support timing, which is cheaper to create. + static absl::StatusOr Create(StreamExecutor* executor, + bool allow_timing); + + hipEvent_t GetHandle() const { return handle_; } + + ~RocmEvent() override; + RocmEvent(const RocmEvent&) = delete; + RocmEvent& operator=(const RocmEvent&) = delete; + RocmEvent(RocmEvent&& other); + RocmEvent& operator=(RocmEvent&& other); + + private: + explicit RocmEvent(StreamExecutor* executor, hipEvent_t handle) + : executor_(executor), handle_(handle) {} + + // The Executor used to which this object and hipEvent_t are bound. + StreamExecutor* executor_; + + // The underlying CUDA event handle. + hipEvent_t handle_; }; } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_event_test.cc b/third_party/xla/xla/stream_executor/rocm/rocm_event_test.cc new file mode 100644 index 00000000000000..5f1db89a8b52e8 --- /dev/null +++ b/third_party/xla/xla/stream_executor/rocm/rocm_event_test.cc @@ -0,0 +1,51 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/rocm/rocm_event.h" + +#include + +#include +#include "rocm/include/hip/hip_runtime.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/rocm/rocm_platform_id.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor::gpu { +namespace { + +TEST(RocmEventTest, CreateEvent) { + TF_ASSERT_OK_AND_ASSIGN(Platform * platform, + stream_executor::PlatformManager::PlatformWithId( + stream_executor::rocm::kROCmPlatformId)); + TF_ASSERT_OK_AND_ASSIGN(StreamExecutor * executor, + platform->ExecutorForDevice(0)); + + TF_ASSERT_OK_AND_ASSIGN(RocmEvent event, RocmEvent::Create(executor, false)); + + EXPECT_NE(event.GetHandle(), nullptr); + EXPECT_EQ(event.PollForStatus(), Event::Status::kComplete); + + hipEvent_t handle = event.GetHandle(); + RocmEvent event2 = std::move(event); + EXPECT_EQ(event2.GetHandle(), handle); +} + +} // namespace + +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc index ac368c3e2b63c2..67e09ad4d5bb85 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ + #include "xla/stream_executor/rocm/rocm_executor.h" #include @@ -21,23 +22,28 @@ limitations under the License. #include #include #include +#include #include #include +#include #include "absl/base/casts.h" -#include "absl/functional/any_invocable.h" +#include "absl/container/inlined_vector.h" #include "absl/numeric/int128.h" #include "absl/status/status.h" #include "absl/strings/ascii.h" -#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" #include "absl/types/span.h" +#include "rocm/include/hip/driver_types.h" +#include "rocm/include/hip/hip_runtime.h" #include "rocm/include/hip/hip_version.h" #include "rocm/rocm_config.h" +#include "xla/stream_executor/activate_context.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_description.h" @@ -47,94 +53,63 @@ limitations under the License. #include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/fft.h" #include "xla/stream_executor/gpu/context.h" -#include "xla/stream_executor/gpu/gpu_collectives.h" #include "xla/stream_executor/gpu/gpu_command_buffer.h" -#include "xla/stream_executor/gpu/gpu_diagnostics.h" #include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_event.h" -#include "xla/stream_executor/gpu/gpu_executor.h" -#include "xla/stream_executor/gpu/gpu_kernel.h" -#include "xla/stream_executor/gpu/gpu_runtime.h" #include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/gpu_timer.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/gpu/read_numa_node.h" -#include "xla/stream_executor/integrations/device_mem_allocator.h" +#include "xla/stream_executor/gpu/scoped_activate_context.h" +#include "xla/stream_executor/host_memory_allocation.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/module_spec.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/platform/dso_loader.h" #include "xla/stream_executor/platform/initialize.h" -#include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/plugin_registry.h" -#include "xla/stream_executor/rocm/rocm_diagnostics.h" -#include "xla/stream_executor/rocm/rocm_driver.h" +#include "xla/stream_executor/rocm/rocm_command_buffer.h" +#include "xla/stream_executor/rocm/rocm_context.h" #include "xla/stream_executor/rocm/rocm_driver_wrapper.h" #include "xla/stream_executor/rocm/rocm_event.h" +#include "xla/stream_executor/rocm/rocm_kernel.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" +#include "xla/stream_executor/rocm/rocm_runtime.h" +#include "xla/stream_executor/rocm/rocm_status.h" +#include "xla/stream_executor/rocm/rocm_stream.h" +#include "xla/stream_executor/rocm/rocm_timer.h" #include "xla/stream_executor/rocm/rocm_version_parser.h" #include "xla/stream_executor/semantic_version.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/casts.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/fingerprint.h" #include "tsl/platform/logging.h" +#include "tsl/platform/numbers.h" #include "tsl/platform/statusor.h" - -#define RETURN_IF_ROCM_ERROR(expr, ...) \ - do { \ - hipError_t _res = (expr); \ - if (TF_PREDICT_FALSE(_res != hipSuccess)) { \ - if (_res == hipErrorOutOfMemory) \ - return absl::ResourceExhaustedError(absl::StrCat( \ - __VA_ARGS__, ":", ::stream_executor::gpu::ToString(_res))); \ - else \ - return absl::InternalError(absl::StrCat( \ - __VA_ARGS__, ": ", ::stream_executor::gpu::ToString(_res))); \ - } \ - } while (0) +#include "tsl/platform/threadpool.h" namespace stream_executor { namespace gpu { +namespace { // Given const GPU memory, returns a librocm device pointer datatype, suitable // for passing directly to librocm APIs. // // N.B. we must lose constness in order to pass a suitable type to the existing // librocm APIs, so the caller should take care to only pass the result of const // GPU memory conversions to librocm functions which will honor constness. -static hipDeviceptr_t AsROCmDevicePtr(const DeviceMemoryBase& gpu_mem) { +hipDeviceptr_t AsROCmDevicePtr(const DeviceMemoryBase& gpu_mem) { return const_cast(gpu_mem.opaque()); } // See description on const version above. -static hipDeviceptr_t AsROCmDevicePtr(DeviceMemoryBase* gpu_mem) { +hipDeviceptr_t AsROCmDevicePtr(DeviceMemoryBase* gpu_mem) { return AsROCmDevicePtr(*gpu_mem); } -RocmExecutor::~RocmExecutor() { - for (auto& it : disk_modules_) { - GpuDriver::UnloadModule(gpu_context(), it.second); - } - for (auto& it : in_memory_modules_) { - GpuDriver::UnloadModule(gpu_context(), it.second); - } - if (gpu_context() != nullptr) { - GpuDriver::DestroyContext(gpu_context()); - } - CHECK(kernel_to_gpu_binary_.empty()) << "GpuExecutor has live kernels."; - CHECK(gpu_binary_to_module_.empty()) << "GpuExecutor has loaded modules."; -} -bool RocmExecutor::UnloadModule(ModuleHandle module_handle) { - const char* gpu_binary = reinterpret_cast(module_handle.id()); - absl::MutexLock lock{&in_memory_modules_mu_}; - return UnloadGpuBinary(gpu_binary); -} - -namespace { absl::uint128 Fingerprint128(const absl::string_view s) { auto fp = tsl::Fingerprint128(s); return absl::MakeUint128(fp.high64, fp.low64); @@ -150,15 +125,384 @@ int fpus_per_core(std::string gcn_arch_name) { return n; } -absl::Status FuncGetAttribute(hipFunction_attribute attribute, - hipFunction_t func, int* attribute_value) { - RETURN_IF_ROCM_ERROR( - wrap::hipFuncGetAttribute(attribute_value, attribute, func), - "Failed to query kernel attribute: ", attribute); +// ROCM driver routines may require a large amount of stack (particularly +// hipModuleLoadDataEx, in our experience). To avoid stack overflow when using +// stack-limited threads (such as those spawned by a default-argument +// thread::ThreadPool on some platforms), we run certain routines in this pool +// and wait for completion. +tsl::thread::ThreadPool* GetDriverExecutor() { + static tsl::thread::ThreadPool* thread_pool = new tsl::thread::ThreadPool( + tsl::Env::Default(), tsl::ThreadOptions(), "rocm_driver", 1); + return thread_pool; +} + +// Loads HSACO with the ROCM runtime and stores the resulting handle in +// "module". Any error logs that are produced are logged internally. +absl::StatusOr LoadHsaco(Context* context, + const char* hsaco_contents) { + absl::Notification notification; + absl::Status returned_status = absl::OkStatus(); + hipModule_t module; + GetDriverExecutor()->Schedule( + [context, hsaco_contents, &module, &returned_status, ¬ification]() { + ScopedActivateContext activation(context); + hipError_t res = wrap::hipModuleLoadData(&module, hsaco_contents); + + if (res != hipSuccess) { + returned_status = absl::InternalError( + absl::StrCat("Failed to load HSACO: ", ToString(res))); + notification.Notify(); + } + + CHECK(module != nullptr); + notification.Notify(); + }); + notification.WaitForNotification(); + + TF_RETURN_IF_ERROR(returned_status); + return module; +} + +// Retrieves a named kernel from a loaded module, and places the resulting +// handle into function (outparam) on success. Neither kernel_name nor +// function may be null. No ownership is taken of kernel_name. +absl::StatusOr GetModuleFunction(Context* context, + hipModule_t module, + const char* kernel_name) { + ScopedActivateContext activated(context); + CHECK(module != nullptr && kernel_name != nullptr); + hipFunction_t function; + TF_RETURN_IF_ERROR( + ToStatus(wrap::hipModuleGetFunction(&function, module, kernel_name), + "Failed to get kernel")); + return function; +} + +// Retrieves a named global/constant symbol from a loaded module, and returns +// a device pointer and size of the symbol on success. symbol_name may not be +// null. At least one of dptr or bytes should not be null. No ownership is +// taken of symbol_name. +absl::Status GetModuleSymbol(Context* context, hipModule_t module, + const char* symbol_name, hipDeviceptr_t* dptr, + size_t* bytes) { + ScopedActivateContext activated(context); + CHECK(module != nullptr && symbol_name != nullptr && + (dptr != nullptr || bytes != nullptr)); + return ToStatus(wrap::hipModuleGetGlobal(dptr, bytes, module, symbol_name), + absl::StrCat("Failed to get symbol '", symbol_name, "'")); +} + +// Unloads module from the current context via cuModuleUnload. +void UnloadRocmModule(Context* context, hipModule_t module) { + ScopedActivateContext activated(context); + hipError_t res = wrap::hipModuleUnload(module); + if (res != hipSuccess) { + LOG(ERROR) << "failed to unload module " << module + << "; leaking: " << ToString(res); + } +} + +// Returns the name of the device. +absl::StatusOr GetDeviceName(hipDevice_t device) { + static const size_t kCharLimit = 64; + absl::InlinedVector chars(kCharLimit); + TF_RETURN_IF_ERROR( + ToStatus(wrap::hipDeviceGetName(chars.begin(), kCharLimit - 1, device), + "Failed to get device name")); + chars[kCharLimit - 1] = '\0'; + return chars.begin(); +} + +absl::StatusOr GetGpuISAVersion(hipDevice_t device) { + hipDeviceProp_t props; + hipError_t result = wrap::hipGetDeviceProperties(&props, device); + if (result == hipSuccess) { + std::string gcnName = props.gcnArchName; + std::vector tokens = absl::StrSplit(gcnName, ':'); + std::string amdgpu_version = gcnName; + if (!tokens.empty() && tokens[0].size() >= 3) { + amdgpu_version = tokens[0].substr(3); + } + int version = std::stoi(amdgpu_version); + return version; + } + return absl::InternalError(absl::StrFormat( + "failed to determine AMDGpu ISA version for device %d", device)); +} + +// Return the full GCN Architecture Name for the device +// for eg: amdgcn-amd-amdhsa--gfx908:sramecc+:xnack- +absl::StatusOr GetGpuGCNArchName(hipDevice_t device) { + hipDeviceProp_t props; + hipError_t result = wrap::hipGetDeviceProperties(&props, device); + if (result == hipSuccess) { + return props.gcnArchName; + } + return absl::InternalError(absl::StrFormat( + "failed to determine AMDGpu GCN Arch Name for device %d", device)); +} + +// Helper function that turns the integer output of hipDeviceGetAttribute to +// type T and wraps it in a absl::StatusOr. +template +static absl::StatusOr GetSimpleAttribute(hipDevice_t device, + hipDeviceAttribute_t attribute) { + int value = -1; + hipError_t result = wrap::hipDeviceGetAttribute(&value, attribute, device); + if (result != hipSuccess) { + return absl::NotFoundError( + absl::StrCat("could not retrieve ROCM device attribute (", attribute, + "): ", ToString(result))); + } + T converted = value; + return converted; +} + +// Returns the number of multiprocessors on the device (note that the device +// may be multi-GPU-per-board). + +absl::StatusOr GetMultiprocessorCount(hipDevice_t device) { + return GetSimpleAttribute(device, hipDeviceAttributeMultiprocessorCount); +} + +absl::StatusOr GetMaxSharedMemoryPerCore(hipDevice_t device) { + return GetSimpleAttribute( + device, hipDeviceAttributeMaxSharedMemoryPerMultiprocessor); +} + +absl::StatusOr GetMaxSharedMemoryPerBlock(hipDevice_t device) { + return GetSimpleAttribute(device, + hipDeviceAttributeMaxSharedMemoryPerBlock); +} + +absl::StatusOr GetMaxThreadsPerMultiprocessor(hipDevice_t device) { + return GetSimpleAttribute( + device, hipDeviceAttributeMaxThreadsPerMultiProcessor); +} + +absl::StatusOr GetMaxRegistersPerBlock(hipDevice_t device) { + return GetSimpleAttribute(device, + hipDeviceAttributeMaxRegistersPerBlock); +} + +absl::StatusOr GetThreadsPerWarp(hipDevice_t device) { + return GetSimpleAttribute(device, hipDeviceAttributeWarpSize); +} + +absl::Status GetGridLimits(int* x, int* y, int* z, hipDevice_t device) { + int value; + TF_RETURN_IF_ERROR( + ToStatus(wrap::hipDeviceGetAttribute( + &value, hipDeviceAttributeMaxGridDimX, device), + "failed to query max grid dim x")); + *x = value; + + TF_RETURN_IF_ERROR( + ToStatus(wrap::hipDeviceGetAttribute( + &value, hipDeviceAttributeMaxGridDimY, device), + "failed to query max grid dim y")); + *y = value; + + TF_RETURN_IF_ERROR( + ToStatus(wrap::hipDeviceGetAttribute( + &value, hipDeviceAttributeMaxGridDimZ, device), + "failed to query max grid dim z")); + *z = value; return absl::OkStatus(); } + +// Returns the device associated with the given device_ordinal. +absl::StatusOr GetDevice(int device_ordinal) { + hipDevice_t device; + hipError_t res = wrap::hipDeviceGet(&device, device_ordinal); + if (res == hipSuccess) { + return device; + } + + return absl::InternalError( + absl::StrCat("failed call to hipDeviceGet: ", ToString(res))); +} + +// Returns the device associated with the given context. +absl::StatusOr DeviceFromContext(Context* context) { + ScopedActivateContext activated(context); + hipDevice_t device = -1; + hipError_t result = wrap::hipCtxGetDevice(&device); + if (result == hipSuccess) return device; + + return absl::InternalError( + absl::StrCat("failed to get device for context: ", ToString(result))); +} + +bool CanEnablePeerAccess(hipDevice_t from, hipDevice_t to) { + int can_access_peer = -1; + hipError_t result = wrap::hipDeviceCanAccessPeer(&can_access_peer, from, to); + if (result != hipSuccess) { + LOG(ERROR) << "failed to detect peer access capability: " + << ToString(result); + return false; + } + return can_access_peer; +} + +bool CanEnablePeerAccess(Context* from, Context* to) { + // A context can always access its own memory. + if (from == to) return true; + + auto from_device = DeviceFromContext(from); + if (!from_device.ok()) { + LOG(ERROR) << "failed to resolve 'from' peer access context to a device: " + << from_device.status(); + return false; + } + + auto to_device = DeviceFromContext(to); + if (!to_device.ok()) { + LOG(ERROR) << "failed to resolve 'to' peer access context to a device: " + << to_device.status(); + return false; + } + return CanEnablePeerAccess(from_device.value(), to_device.value()); +} + +absl::Status EnablePeerAccess(Context* from, Context* to) { + if (from == to) { + return absl::OkStatus(); // A device can always access its own memory. + } + + ScopedActivateContext activated(from); + hipError_t result = wrap::hipCtxEnablePeerAccess( + tensorflow::down_cast(to)->context(), 0 /* = flags */); + if (result != hipSuccess && result != hipErrorPeerAccessAlreadyEnabled) { + return absl::InternalError( + absl::StrFormat("failed to enable peer access from %d to %d: %s", + from->device_ordinal(), to->device_ordinal(), + ToString(result).c_str())); + } + + return absl::OkStatus(); +} + +std::string GetPCIBusID(hipDevice_t device) { + std::string pci_bus_id; + static const int kBufferSize = 64; + absl::InlinedVector chars(kBufferSize); + chars[kBufferSize - 1] = '\0'; + hipError_t res = + wrap::hipDeviceGetPCIBusId(chars.begin(), kBufferSize - 1, device); + if (res != hipSuccess) { + LOG(ERROR) << "failed to query PCI bus id for device: " << ToString(res); + return pci_bus_id; + } + pci_bus_id = chars.begin(); + return pci_bus_id; +} + +bool GetDeviceProperties(hipDeviceProp_t* device_properties, + int device_ordinal) { + hipError_t res = + wrap::hipGetDeviceProperties(device_properties, device_ordinal); + if (res != hipSuccess) { + LOG(ERROR) << "failed to query device properties: " << ToString(res); + return false; + } + + return true; +} + +// Allocates memory on the GPU device. +void* DeviceAllocate(Context* context, uint64_t bytes) { + if (bytes == 0) { + return nullptr; + } + + ScopedActivateContext activated(context); + hipDeviceptr_t result = 0; + hipError_t res = wrap::hipMalloc(&result, bytes); + if (res != hipSuccess) { + // LOG(INFO) because this isn't always important to users (e.g. BFCAllocator + // implements a retry if the first allocation fails). + LOG(INFO) << "failed to allocate " + << tsl::strings::HumanReadableNumBytes(bytes) << " (" << bytes + << " bytes) from device: " << ToString(res); + return nullptr; + } + void* ptr = reinterpret_cast(result); + VLOG(2) << "allocated " << ptr << " for device " << context->device_ordinal() + << " of " << bytes << " bytes"; + return ptr; +} + +// Deallocates memory on the GPU device that was previously allocated via +// DeviceAllocate. +void DeviceDeallocate(Context* context, void* location) { + ScopedActivateContext activation(context); + hipDeviceptr_t pointer = absl::bit_cast(location); + hipError_t res = wrap::hipFree(pointer); + if (res != hipSuccess) { + LOG(ERROR) << "failed to free device memory at " << location + << "; result: " << ToString(res); + } else { + VLOG(2) << "deallocated " << location << " for device " + << context->device_ordinal(); + } +} + +// Allocates memory on the host. +void* HostAllocate(Context* context, uint64_t bytes) { + ScopedActivateContext activation(context); + void* host_mem = nullptr; + // "Portable" memory is visible to all ROCM contexts. Safe for our use model. + hipError_t res = wrap::hipHostMalloc(&host_mem, bytes, hipHostMallocPortable); + if (res != hipSuccess) { + LOG(ERROR) << "failed to alloc " << bytes + << " bytes on host: " << ToString(res); + } + return host_mem; +} + } // namespace +RocmExecutor::~RocmExecutor() { + for (auto& it : in_memory_modules_) { + UnloadRocmModule(rocm_context_, it.second); + } + set_context(nullptr); + CHECK(kernel_to_gpu_binary_.empty()) << "RocmExecutor has live kernels."; + CHECK(gpu_binary_to_module_.empty()) << "RocmExecutor has loaded modules."; +} + +std::unique_ptr RocmExecutor::Activate() { + return std::make_unique(rocm_context_); +} + +bool RocmExecutor::UnloadModule(ModuleHandle module_handle) { + absl::MutexLock lock{&in_memory_modules_mu_}; + return UnloadGpuBinary(module_handle); +} + +absl::StatusOr RocmExecutor::GetMemoryRange( + const DeviceMemoryBase& location) { + hipDeviceptr_t device_pointer; + size_t size; + hipError_t result = wrap::hipMemGetAddressRange( + &device_pointer, &size, const_cast(location.opaque())); + if (result == hipSuccess) { + return DeviceMemoryBase(device_pointer, size); + } else if (result == hipErrorNotFound) { + // We differentiate between "this pointer is unknown" (return here) and + // "there was an internal error while performing this operation" (return + // below). + return absl::NotFoundError(absl::StrFormat("not a device pointer %p; %s", + location.opaque(), + ToString(result).c_str())); + } + + return absl::InternalError( + absl::StrFormat("failed to get pointer into for device pointer %p; %s", + location.opaque(), ToString(result).c_str())); +} + absl::StatusOr> RocmExecutor::CreateOrShareConstant(Stream* stream, absl::Span content) { @@ -214,18 +558,15 @@ RocmExecutor::CreateOrShareConstant(Stream* stream, } absl::StatusOr> -RocmExecutor::CreateEventBasedTimer(GpuStream* stream, bool use_delay_kernel) { - TF_ASSIGN_OR_RETURN(auto start_event, CreateGpuEvent(/*allow_timing=*/true)); - TF_ASSIGN_OR_RETURN(auto stop_event, CreateGpuEvent(/*allow_timing=*/true)); - TF_RETURN_IF_ERROR(start_event->Record(stream->gpu_stream())); - return std::make_unique(gpu_context(), std::move(start_event), - std::move(stop_event), stream); +RocmExecutor::CreateEventBasedTimer(Stream* stream, bool use_delay_kernel) { + TF_ASSIGN_OR_RETURN(auto timer, RocmTimer::Create(this, stream)); + return std::make_unique(std::move(timer)); } -bool RocmExecutor::UnloadGpuBinary(const void* gpu_binary) { - auto module_it = gpu_binary_to_module_.find(gpu_binary); +bool RocmExecutor::UnloadGpuBinary(ModuleHandle module_handle) { + auto module_it = gpu_binary_to_module_.find(module_handle); if (gpu_binary_to_module_.end() == module_it) { - VLOG(3) << "No loaded HSACO module for " << gpu_binary; + VLOG(3) << "No loaded HSACO module for " << module_handle; return false; } auto& module = module_it->second.first; @@ -233,13 +574,13 @@ bool RocmExecutor::UnloadGpuBinary(const void* gpu_binary) { VLOG(3) << "Found HSACO module " << module << " with refcount " << refcount; if (--refcount == 0) { VLOG(3) << "Unloading HSACO module " << module; - GpuDriver::UnloadModule(gpu_context(), module); + UnloadRocmModule(rocm_context_, module); gpu_binary_to_module_.erase(module_it); - const char* mem_it = nullptr; + ModuleHandle mem_it{}; for (auto x : in_memory_modules_) { if (x.second == module) mem_it = x.first; } - if (mem_it != nullptr) in_memory_modules_.erase(mem_it); + if (mem_it != ModuleHandle{}) in_memory_modules_.erase(mem_it); } return true; } @@ -261,21 +602,18 @@ void RocmExecutor::UnloadKernel(const Kernel* kernel) { } absl::Status RocmExecutor::Init() { - TF_RETURN_IF_ERROR(GpuDriver::Init()); - - TF_RETURN_IF_ERROR(GpuDriver::GetDevice(device_ordinal(), &device_)); + TF_ASSIGN_OR_RETURN(device_, GetDevice(device_ordinal())); - Context* context; - TF_RETURN_IF_ERROR( - GpuDriver::CreateContext(device_ordinal(), device_, &context)); - set_context(context); - return GpuDriver::GetGpuISAVersion(&version_, device_); + TF_ASSIGN_OR_RETURN(rocm_context_, + RocmContext::Create(device_ordinal(), device_)); + set_context(rocm_context_); + TF_ASSIGN_OR_RETURN(version_, GetGpuISAVersion(device_)); + return absl::OkStatus(); } absl::StatusOr> RocmExecutor::LoadKernel( const MultiKernelLoaderSpec& spec) { - auto rocm_kernel = std::make_unique(this); - hipModule_t module = nullptr; + auto rocm_kernel = std::make_unique(this); const std::string* kernel_name; if (spec.has_cuda_cubin_in_memory()) { @@ -284,12 +622,19 @@ absl::StatusOr> RocmExecutor::LoadKernel( const char* hsaco = reinterpret_cast( spec.cuda_cubin_in_memory().cubin_bytes().data()); absl::MutexLock lock{&in_memory_modules_mu_}; - module = in_memory_modules_[hsaco]; + ModuleHandle module_handle{hsaco}; + hipModule_t& module = in_memory_modules_[module_handle]; if (module == nullptr) { - TF_RETURN_IF_ERROR(GpuDriver::LoadHsaco(gpu_context(), hsaco, &module)); + TF_ASSIGN_OR_RETURN(module, LoadHsaco(rocm_context_, hsaco)); } - kernel_to_gpu_binary_[rocm_kernel.get()] = hsaco; + kernel_to_gpu_binary_[rocm_kernel.get()] = module_handle; + + VLOG(2) << "getting function " << *kernel_name << " from module " << module; + TF_ASSIGN_OR_RETURN( + hipFunction_t function, + GetModuleFunction(rocm_context_, module, kernel_name->c_str())); + rocm_kernel->set_gpu_function(function); } else if (spec.has_in_process_symbol()) { kernel_name = &spec.in_process_symbol().kernel_name(); void* symbol = spec.in_process_symbol().symbol(); @@ -299,8 +644,8 @@ absl::StatusOr> RocmExecutor::LoadKernel( #if TF_ROCM_VERSION >= 60200 TF_ASSIGN_OR_RETURN( - GpuFunctionHandle function, - GpuRuntime::GetFuncBySymbol(spec.in_process_symbol().symbol())); + hipFunction_t function, + RocmRuntime::GetFuncBySymbol(spec.in_process_symbol().symbol())); rocm_kernel->set_gpu_function(function); #else rocm_kernel->set_gpu_function( @@ -311,24 +656,14 @@ absl::StatusOr> RocmExecutor::LoadKernel( return absl::InternalError("No method of loading ROCM kernel provided"); } - // If we resolved kernel from a symbol pointer, there is no need to load it - // from a module, as ROCm runtime did that automatically for us. - if (!spec.has_in_process_symbol()) { - VLOG(2) << "getting function " << *kernel_name << " from module " << module; - GpuFunctionHandle function; - TF_RETURN_IF_ERROR(GpuDriver::GetModuleFunction( - gpu_context(), module, kernel_name->c_str(), &function)); - rocm_kernel->set_gpu_function(function); - } - // We have to trust the kernel loader spec arity because there doesn't appear // to be a way to reflect on the number of expected arguments w/the ROCM API. rocm_kernel->set_arity(spec.arity()); // unable to get kernel metadata for in-process kernel if (!spec.has_in_process_symbol()) { - KernelMetadata kernel_metadata; - TF_RETURN_IF_ERROR(GetKernelMetadata(rocm_kernel.get(), &kernel_metadata)); + TF_ASSIGN_OR_RETURN(KernelMetadata kernel_metadata, + rocm_kernel->GetKernelMetadata()); rocm_kernel->set_metadata(kernel_metadata); } rocm_kernel->set_name(*kernel_name); @@ -336,98 +671,147 @@ absl::StatusOr> RocmExecutor::LoadKernel( return std::move(rocm_kernel); } -absl::Status RocmExecutor::GetKernelMetadata(GpuKernel* rocm_kernel, - KernelMetadata* kernel_metadata) { - int value = 0; - TF_RETURN_IF_ERROR(FuncGetAttribute(HIP_FUNC_ATTRIBUTE_NUM_REGS, - rocm_kernel->gpu_function(), &value)); - kernel_metadata->set_registers_per_thread(value); - - TF_RETURN_IF_ERROR(FuncGetAttribute(HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, - rocm_kernel->gpu_function(), &value)); - kernel_metadata->set_shared_memory_bytes(value); - return absl::OkStatus(); -} +absl::StatusOr RocmExecutor::LoadModule( + const MultiModuleLoaderSpec& spec) { + // We store the pointer to the HSACO binary as ModuleHandle::id(). -absl::Status RocmExecutor::LoadModule(const MultiModuleLoaderSpec& spec, - ModuleHandle* module_handle) { - // In GpuExecutor we store the pointer to the HSACO binary as - // ModuleHandle::id(). - hipModule_t hip_module = nullptr; // TODO(ROCm): Need generic term instead of cubin/cuda/ptx if (spec.has_cuda_cubin_in_memory()) { absl::MutexLock lock{&in_memory_modules_mu_}; - TF_RETURN_IF_ERROR(LoadModuleFromHsaco( - reinterpret_cast(spec.cuda_cubin_in_memory().data()), - &hip_module)); - *module_handle = ModuleHandle(const_cast( - static_cast(spec.cuda_cubin_in_memory().data()))); - return absl::OkStatus(); + return LoadModuleFromHsaco( + reinterpret_cast(spec.cuda_cubin_in_memory().data())); } else { return absl::InternalError("No HASCO binary found"); } } -absl::Status RocmExecutor::LoadModuleFromHsaco(const char* hsaco, - hipModule_t* module) { +absl::StatusOr RocmExecutor::LoadModuleFromHsaco( + const char* hsaco) { + ModuleHandle module_handle{hsaco}; uint64_t module_refcount; - std::tie(*module, module_refcount) = gpu_binary_to_module_[hsaco]; + hipModule_t module; + std::tie(module, module_refcount) = gpu_binary_to_module_[module_handle]; - if (*module == nullptr) { - TF_RETURN_IF_ERROR(GpuDriver::LoadHsaco(gpu_context(), hsaco, module)); + if (module == nullptr) { + TF_ASSIGN_OR_RETURN(module, LoadHsaco(rocm_context_, hsaco)); module_refcount = 1; - in_memory_modules_[hsaco] = *module; + in_memory_modules_[module_handle] = module; VLOG(3) << "Loaded HSACO " << static_cast(hsaco) - << " as module " << *module; + << " as module " << module; } else { ++module_refcount; VLOG(3) << "HSACO " << static_cast(hsaco) - << " is already loaded as module " << *module; + << " is already loaded as module " << module; } - gpu_binary_to_module_[hsaco] = {*module, module_refcount}; - return absl::OkStatus(); + gpu_binary_to_module_[module_handle] = {module, module_refcount}; + return module_handle; } DeviceMemoryBase RocmExecutor::Allocate(uint64_t size, int64_t memory_space) { if (memory_space == static_cast(stream_executor::MemoryType::kHost)) { - return DeviceMemoryBase(GpuDriver::HostAllocate(gpu_context(), size), size); + return DeviceMemoryBase(HostAllocate(rocm_context_, size), size); } CHECK_EQ(memory_space, 0); - return DeviceMemoryBase(GpuDriver::DeviceAllocate(gpu_context(), size), size); + return DeviceMemoryBase(DeviceAllocate(rocm_context_, size), size); +} +absl::StatusOr> +RocmExecutor::HostMemoryAllocate(uint64_t size) { + auto* buffer = HostAllocate(rocm_context_, size); + if (buffer == nullptr && size > 0) { + return absl::InternalError( + absl::StrFormat("Failed to allocate HostMemory of size %d", size)); + } + return std::make_unique(buffer, size, this); +} + +void RocmExecutor::HostMemoryDeallocate(void* location) { + std::unique_ptr activation = Activate(); + hipError_t res = wrap::hipHostFree(location); + if (res != hipSuccess) { + LOG(ERROR) << "error deallocating host memory at " << location << ": " + << ToString(res); + } } void RocmExecutor::Deallocate(DeviceMemoryBase* mem) { - GpuDriver::DeviceDeallocate(gpu_context(), mem->opaque()); + DeviceDeallocate(rocm_context_, mem->opaque()); +} + +void* RocmExecutor::UnifiedMemoryAllocate(uint64_t size) { + std::unique_ptr activation = Activate(); + hipDeviceptr_t result = 0; + // "managed" memory is visible to both CPU and GPU. + hipError_t res = wrap::hipMallocManaged(&result, size, hipMemAttachGlobal); + if (res != hipSuccess) { + LOG(ERROR) << "failed to alloc " << size + << " bytes unified memory; result: " << ToString(res); + return nullptr; + } + void* ptr = reinterpret_cast(result); + VLOG(2) << "allocated " << ptr << " for context " << rocm_context_ << " of " + << size << " bytes in unified memory"; + return ptr; +} + +void RocmExecutor::UnifiedMemoryDeallocate(void* location) { + std::unique_ptr activation = Activate(); + hipDeviceptr_t pointer = absl::bit_cast(location); + hipError_t res = wrap::hipFree(pointer); + if (res != hipSuccess) { + LOG(ERROR) << "failed to free unified memory at " << location + << "; result: " << ToString(res); + } else { + VLOG(2) << "deallocated unified memory at " << location << " for context " + << rocm_context_; + } } bool RocmExecutor::SynchronizeAllActivity() { - return GpuDriver::SynchronizeContext(gpu_context()).ok(); + return rocm_context_->Synchronize().ok(); } absl::Status RocmExecutor::SynchronousMemZero(DeviceMemoryBase* location, uint64_t size) { - if (reinterpret_cast(location->opaque()) % 4 == 0 && - size % 4 == 0) { - return GpuDriver::SynchronousMemsetUint32( - gpu_context(), AsROCmDevicePtr(location), 0x0, size / 4); + std::unique_ptr activation = Activate(); + hipDeviceptr_t rocm_location = AsROCmDevicePtr(location); + if (reinterpret_cast(location->opaque()) % sizeof(uint32_t) == 0 && + size % sizeof(uint32_t) == 0) { + return ToStatus( + wrap::hipMemsetD32(rocm_location, 0x0, size / sizeof(uint32_t)), + "Failed to memset memory"); } - return GpuDriver::SynchronousMemsetUint8( - gpu_context(), AsROCmDevicePtr(location), 0x0, size); + return ToStatus(wrap::hipMemsetD8(rocm_location, 0x0, size), + "Failed to memset memory"); } absl::Status RocmExecutor::SynchronousMemcpy(DeviceMemoryBase* gpu_dst, const void* host_src, uint64_t size) { - return GpuDriver::SynchronousMemcpyH2D( - gpu_context(), AsROCmDevicePtr(gpu_dst), host_src, size); + std::unique_ptr activation = Activate(); + TF_RETURN_IF_ERROR(ToStatus( + wrap::hipMemcpyHtoD(AsROCmDevicePtr(gpu_dst), const_cast(host_src), + size), + absl::StrFormat( + "failed to synchronous memcpy from host to device: Gpu dst: %p;" + " host src: %p; size: %llu=0x%llx", + AsROCmDevicePtr(gpu_dst), host_src, size, size))); + VLOG(2) << "successfully sync memcpy'd h2d of " << size << " bytes"; + return absl::OkStatus(); } absl::Status RocmExecutor::SynchronousMemcpy(void* host_dst, const DeviceMemoryBase& gpu_src, uint64_t size) { - return GpuDriver::SynchronousMemcpyD2H(gpu_context(), host_dst, - AsROCmDevicePtr(gpu_src), size); + std::unique_ptr activation = Activate(); + TF_RETURN_IF_ERROR(ToStatus( + wrap::hipMemcpyDtoH(host_dst, AsROCmDevicePtr(gpu_src), size), + absl::StrFormat("failed to synchronous memcpy from device to host: " + "host dst: %p; Gpu src: %p; size: %llu=0x%llx", + host_dst, AsROCmDevicePtr(gpu_src), size, size))); + VLOG(2) << "successfully sync memcpy'd d2h of " << size << " bytes to " + << host_dst; + return absl::OkStatus(); } void RocmExecutor::DeallocateStream(Stream* stream) { @@ -437,13 +821,9 @@ void RocmExecutor::DeallocateStream(Stream* stream) { dnn_->NotifyStreamDestroyed(stream); } } - GpuStream* rocm_stream = AsGpuStream(stream); + RocmStream* rocm_stream = static_cast(stream); absl::MutexLock l(&alive_gpu_streams_mu_); - alive_gpu_streams_.erase(rocm_stream->gpu_stream()); -} - -absl::Status RocmExecutor::BlockHostUntilDone(Stream* stream) { - return GpuDriver::SynchronizeStream(gpu_context(), AsGpuStreamValue(stream)); + alive_gpu_streams_.erase(rocm_stream->stream_handle()); } blas::BlasSupport* RocmExecutor::AsBlas() { @@ -508,18 +888,17 @@ fft::FftSupport* RocmExecutor::AsFft() { } bool RocmExecutor::CanEnablePeerAccessTo(StreamExecutor* other) { - GpuExecutor* rocm_other = static_cast(other); - return GpuDriver::CanEnablePeerAccess(gpu_context(), - rocm_other->gpu_context()); + RocmExecutor* rocm_other = static_cast(other); + return CanEnablePeerAccess(rocm_context_, rocm_other->rocm_context_); } absl::Status RocmExecutor::EnablePeerAccessTo(StreamExecutor* other) { - GpuExecutor* rocm_other = static_cast(other); - return GpuDriver::EnablePeerAccess(gpu_context(), rocm_other->gpu_context()); + RocmExecutor* rocm_other = static_cast(other); + return EnablePeerAccess(rocm_context_, rocm_other->rocm_context_); } bool RocmExecutor::DeviceMemoryUsage(int64_t* free, int64_t* total) const { - return GpuDriver::GetDeviceMemoryInfo(gpu_context(), free, total); + return rocm_context_->GetDeviceMemoryUsage(free, total); } absl::StatusOr RocmExecutor::GetSymbol( @@ -529,18 +908,18 @@ absl::StatusOr RocmExecutor::GetSymbol( absl::MutexLock lock{&in_memory_modules_mu_}; if (static_cast(module_handle)) { - auto it = gpu_binary_to_module_.find(module_handle.id()); + auto it = gpu_binary_to_module_.find(module_handle); CHECK(it != gpu_binary_to_module_.end()); - TF_RETURN_IF_ERROR(GpuDriver::GetModuleSymbol( - gpu_context(), it->second.first, symbol_name.c_str(), - reinterpret_cast(&mem), &bytes)); + TF_RETURN_IF_ERROR( + GetModuleSymbol(rocm_context_, it->second.first, symbol_name.c_str(), + reinterpret_cast(&mem), &bytes)); return DeviceMemoryBase(mem, bytes); } for (auto& it : gpu_binary_to_module_) { - TF_RETURN_IF_ERROR(GpuDriver::GetModuleSymbol( - gpu_context(), it.second.first, symbol_name.c_str(), - reinterpret_cast(&mem), &bytes)); + TF_RETURN_IF_ERROR( + GetModuleSymbol(rocm_context_, it.second.first, symbol_name.c_str(), + reinterpret_cast(&mem), &bytes)); return DeviceMemoryBase(mem, bytes); } @@ -551,14 +930,13 @@ absl::StatusOr RocmExecutor::GetSymbol( reinterpret_cast(module_handle.id()), ")")); } -absl::Status FillBlockDimLimit(GpuDeviceHandle device, - BlockDim* block_dim_limit) { +absl::Status FillBlockDimLimit(hipDevice_t device, BlockDim* block_dim_limit) { // The BlockDim name is a mismatch against these GRID_DIM_* queries because // we use BlockDims to express the dimensions of blocks within a grid // (as opposed to ThreadDim which expresses the dimensions of threads // within a block). int x, y, z; - TF_RETURN_IF_ERROR(GpuDriver::GetGridLimits(&x, &y, &z, device)); + TF_RETURN_IF_ERROR(GetGridLimits(&x, &y, &z, device)); block_dim_limit->x = x; block_dim_limit->y = y; @@ -566,64 +944,41 @@ absl::Status FillBlockDimLimit(GpuDeviceHandle device, return absl::OkStatus(); } -absl::StatusOr> RocmExecutor::CreateGpuEvent( - bool allow_timing) { - auto gpu_event = std::make_unique(gpu_context()); - TF_RETURN_IF_ERROR(gpu_event->Init(allow_timing)); - return std::move(gpu_event); -} - absl::StatusOr> RocmExecutor::CreateEvent() { - return CreateGpuEvent(/*allow_timing=*/false); + TF_ASSIGN_OR_RETURN(auto event, + RocmEvent::Create(this, /*allow_timing=*/false)); + return std::make_unique(std::move(event)); } absl::StatusOr> RocmExecutor::CreateStream( std::optional> priority) { - TF_ASSIGN_OR_RETURN(auto event, CreateGpuEvent(/*allow_timing=*/false)); - auto stream = std::make_unique(this, std::move(event), priority); + TF_ASSIGN_OR_RETURN(auto stream, RocmStream::Create(this, priority)); absl::MutexLock l(&alive_gpu_streams_mu_); - TF_RETURN_IF_ERROR(stream->Init()); - auto gpu_stream = stream->gpu_stream(); - alive_gpu_streams_[gpu_stream] = stream.get(); + alive_gpu_streams_[stream->stream_handle()] = stream.get(); return std::move(stream); } absl::StatusOr> RocmExecutor::CreateCommandBuffer(CommandBuffer::Mode mode) { VLOG(2) << "Create ROCm command buffer (ROCm graph)"; - GpuGraphHandle graph = nullptr; - TF_RETURN_IF_ERROR(GpuDriver::CreateGraph(&graph)); - return std::make_unique(mode, /*parent=*/this, graph); + return RocmCommandBuffer::Create(mode, this); } absl::Status RocmExecutor::TrimGraphMemory() { - return GpuDriver::DeviceGraphMemTrim(device_); + return ToStatus(wrap::hipDeviceGraphMemTrim(device_), + "Failed to trim device graph memory"); } absl::StatusOr> -GpuExecutor::CreateDeviceDescription(int device_ordinal) { - GpuDeviceHandle device; - auto status = GpuDriver::GetDevice(device_ordinal, &device); - if (!status.ok()) { - return status; - } +RocmExecutor::CreateDeviceDescription(int device_ordinal) { + TF_ASSIGN_OR_RETURN(hipDevice_t device, GetDevice(device_ordinal)); - int version; - status = GpuDriver::GetGpuISAVersion(&version, device); - if (!status.ok()) { - return status; - } - - std::string gcn_arch_name; - status = GpuDriver::GetGpuGCNArchName(device, &gcn_arch_name); - if (!status.ok()) { - return status; - } + TF_ASSIGN_OR_RETURN(std::string gcn_arch_name, GetGpuGCNArchName(device)); DeviceDescription desc; { - std::string pci_bus_id = GpuDriver::GetPCIBusID(device); + std::string pci_bus_id = GetPCIBusID(device); // Lower the hex characters to match sysfs. pci_bus_id = absl::AsciiStrToLower(pci_bus_id); @@ -635,7 +990,7 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) { } hipDeviceProp_t prop; - if (GpuDriver::GetDeviceProperties(&prop, device_ordinal)) { + if (GetDeviceProperties(&prop, device_ordinal)) { desc.set_threads_per_block_limit(prop.maxThreadsPerBlock); ThreadDim thread_dim_limit; @@ -656,14 +1011,11 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) { desc.set_l2_cache_size(prop.l2CacheSize); } - { - bool ecc_enabled = false; - (void)GpuDriver::IsEccEnabled(device, &ecc_enabled); - desc.set_ecc_enabled(ecc_enabled); - } + // No way to query ECC status from the API. + desc.set_ecc_enabled(false); uint64_t device_memory_size = -1; - (void)GpuDriver::GetDeviceTotalMemory(device, &device_memory_size); + (void)RocmContext::GetDeviceTotalMemory(device, &device_memory_size); desc.set_device_memory_size(device_memory_size); { @@ -673,8 +1025,7 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) { } { - std::string device_name; - TF_RETURN_IF_ERROR(GpuDriver::GetDeviceName(device, &device_name)); + TF_ASSIGN_OR_RETURN(std::string device_name, GetDeviceName(device)); desc.set_name(device_name); } @@ -688,48 +1039,69 @@ GpuExecutor::CreateDeviceDescription(int device_ordinal) { desc.set_device_vendor("Advanced Micro Devices, Inc"); desc.set_rocm_compute_capability(gcn_arch_name); - desc.set_shared_memory_per_core( - GpuDriver::GetMaxSharedMemoryPerCore(device).value()); - desc.set_shared_memory_per_block( - GpuDriver::GetMaxSharedMemoryPerBlock(device).value()); - int core_count = GpuDriver::GetMultiprocessorCount(device).value(); + desc.set_shared_memory_per_core(GetMaxSharedMemoryPerCore(device).value()); + desc.set_shared_memory_per_block(GetMaxSharedMemoryPerBlock(device).value()); + int core_count = GetMultiprocessorCount(device).value(); desc.set_core_count(core_count); desc.set_fpus_per_core(fpus_per_core(gcn_arch_name)); desc.set_threads_per_core_limit( - GpuDriver::GetMaxThreadsPerMultiprocessor(device).value()); - desc.set_registers_per_block_limit( - GpuDriver::GetMaxRegistersPerBlock(device).value()); - desc.set_threads_per_warp(GpuDriver::GetThreadsPerWarp(device).value()); + GetMaxThreadsPerMultiprocessor(device).value()); + desc.set_registers_per_block_limit(GetMaxRegistersPerBlock(device).value()); + desc.set_threads_per_warp(GetThreadsPerWarp(device).value()); desc.set_registers_per_core_limit(64 * 1024); desc.set_compile_time_toolkit_version( SemanticVersion{HIP_VERSION_MAJOR, HIP_VERSION_MINOR, HIP_VERSION_PATCH}); desc.set_runtime_version( - ParseRocmVersion(GpuRuntime::GetRuntimeVersion().value_or(0)) + ParseRocmVersion(RocmRuntime::GetRuntimeVersion().value_or(0)) .value_or(SemanticVersion{0, 0, 0})); desc.set_driver_version( ParseRocmVersion(GpuDriver::GetDriverVersion().value_or(0)) .value_or(SemanticVersion{0, 0, 0})); - int cc_major = 0; - int cc_minor = 0; - GpuDriver::GetComputeCapability(&cc_major, &cc_minor, device).IgnoreError(); - // It would be better to use the PCI device ID or some other truly unique // identifier for the GPU model. But getting this requires using NVML or // other hacks, which we don't have access to in OSS TensorFlow. // - // Alternatively you might be tempted to use GpuDriver::GetDeviceName as a + // Alternatively you might be tempted to use GetDeviceName as a // unique identifier, but this is not stable across GPU VBIOS versions. // // TODO(jlebar): This really should be more unique. In CUDA land, we mix in // the clock speed and L2 cache size. +<<<<<<< HEAD desc.set_model_str(absl::StrFormat("cc_%d.%d.%d with %dB RAM, %d cores", cc_major, cc_minor, device_ordinal, device_memory_size, core_count)); +======= + desc.set_model_str( + absl::StrFormat("%dB RAM, %d cores", device_memory_size, core_count)); +>>>>>>> master return std::make_unique(std::move(desc)); } +absl::StatusOr RocmExecutor::GetPointerMemorySpace( + const void* ptr) { + hipDeviceptr_t pointer = + reinterpret_cast(const_cast(ptr)); + unsigned int value; + hipError_t result = wrap::hipPointerGetAttribute( + &value, HIP_POINTER_ATTRIBUTE_MEMORY_TYPE, pointer); + if (result == hipSuccess) { + switch (value) { + case hipMemoryTypeDevice: + return MemoryType::kDevice; + case hipMemoryTypeHost: + return MemoryType::kHost; + default: + return absl::InternalError( + absl::StrCat("unknown memory space provided by ROCM API: ", value)); + } + } + + return absl::InternalError(absl::StrCat( + "failed to query device pointer for memory space: ", ToString(result))); +} + } // namespace gpu } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.h b/third_party/xla/xla/stream_executor/rocm/rocm_executor.h index f8d88d46d6db28..7b5583b404fdd4 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor.h @@ -21,7 +21,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -30,9 +29,10 @@ limitations under the License. #include "absl/numeric/int128.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "rocm/include/hip/hip_runtime.h" +#include "xla/stream_executor/activate_context.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_description.h" @@ -41,17 +41,15 @@ limitations under the License. #include "xla/stream_executor/event.h" #include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/fft.h" -#include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_kernel.h" -#include "xla/stream_executor/gpu/gpu_types.h" -#include "xla/stream_executor/host_memory_allocation.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/module_spec.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/rocm/rocm_context.h" +#include "xla/stream_executor/rocm/rocm_event.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" @@ -63,6 +61,7 @@ class RocmExecutor : public GpuExecutor { RocmExecutor(Platform* platform, int device_ordinal) : GpuExecutor(platform, device_ordinal) {} ~RocmExecutor() override; + std::unique_ptr Activate() override; absl::Status Init() override; blas::BlasSupport* AsBlas() override; @@ -76,16 +75,23 @@ class RocmExecutor : public GpuExecutor { absl::StatusOr> LoadKernel( const MultiKernelLoaderSpec& spec) override; void UnloadKernel(const Kernel* kernel) override; +<<<<<<< HEAD absl::Status LoadModule(const MultiModuleLoaderSpec& spec, ModuleHandle* module_handle) override; +======= + absl::StatusOr LoadModule( + const MultiModuleLoaderSpec& spec) override; +>>>>>>> master bool UnloadModule(ModuleHandle module_handle) override; absl::StatusOr> CreateOrShareConstant( Stream* stream, absl::Span content) override; DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override; + absl::StatusOr GetMemoryRange( + const DeviceMemoryBase& location) override; void Deallocate(DeviceMemoryBase* mem) override; bool SynchronizeAllActivity() override; absl::StatusOr> CreateEventBasedTimer( - GpuStream* stream, bool use_delay_kernel) override; + Stream* stream, bool use_delay_kernel) override; absl::StatusOr GetSymbol( const std::string& symbol_name, ModuleHandle module_handle) override; absl::Status SynchronousMemZero(DeviceMemoryBase* location, @@ -97,40 +103,22 @@ class RocmExecutor : public GpuExecutor { uint64_t size) override; absl::Status TrimGraphMemory() override; void DeallocateStream(Stream* stream) override; - absl::Status BlockHostUntilDone(Stream* stream) override; absl::Status EnablePeerAccessTo(StreamExecutor* other) override; bool CanEnablePeerAccessTo(StreamExecutor* other) override; bool DeviceMemoryUsage(int64_t* free, int64_t* total) const override; absl::StatusOr> CreateDeviceDescription() const override { - return GpuExecutor::CreateDeviceDescription(device_ordinal()); - } - void* UnifiedMemoryAllocate(uint64_t size) override { - return GpuDriver::UnifiedMemoryAllocate(gpu_context(), size); + return RocmExecutor::CreateDeviceDescription(device_ordinal()); } + void* UnifiedMemoryAllocate(uint64_t size) override; - void UnifiedMemoryDeallocate(void* location) override { - return GpuDriver::UnifiedMemoryDeallocate(gpu_context(), location); - } + void UnifiedMemoryDeallocate(void* location) override; absl::StatusOr> HostMemoryAllocate( - uint64_t size) override { - auto* buffer = GpuDriver::HostAllocate(gpu_context(), size); - if (buffer == nullptr && size > 0) { - return absl::InternalError( - absl::StrFormat("Failed to allocate HostMemory of size %d", size)); - } - return std::make_unique(buffer, size, this); - } + uint64_t size) override; + void HostMemoryDeallocate(void* location) override; - void HostMemoryDeallocate(void* location) override { - return GpuDriver::HostDeallocate(gpu_context(), location); - } - - absl::StatusOr GetPointerMemorySpace(const void* ptr) override { - return GpuDriver::GetPointerMemorySpace( - reinterpret_cast(const_cast(ptr))); - } + absl::StatusOr GetPointerMemorySpace(const void* ptr) override; Stream* FindAllocatedStream(void* gpu_stream) override { absl::MutexLock lock(&alive_gpu_streams_mu_); @@ -141,35 +129,24 @@ class RocmExecutor : public GpuExecutor { return it->second; } - private: - // Collects metadata for the specified kernel. - absl::Status GetKernelMetadata(GpuKernel* rocm_kernel, - KernelMetadata* kernel_metadata); + static absl::StatusOr> + CreateDeviceDescription(int device_ordinal); + private: // Loads a module in HSACO format. - absl::Status LoadModuleFromHsaco(const char* hsaco, GpuModuleHandle* module) + absl::StatusOr LoadModuleFromHsaco(const char* hsaco) ABSL_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); - bool UnloadGpuBinary(const void* gpu_binary) + bool UnloadGpuBinary(ModuleHandle module_handle) ABSL_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); // Creates a GpuEvent for the given stream. - absl::StatusOr> CreateGpuEvent(bool allow_timing); - - // Guards the on-disk-module mapping. - absl::Mutex disk_modules_mu_; - - // Mapping from filename to GPUModuleHandle, if it was already retrieved. - // Multiple GPUFunctionHandle are usually obtained from a single - // GPUModuleHandle so we attempt to hit in this mapping first, before - // retrieving it. - std::map disk_modules_ - ABSL_GUARDED_BY(disk_modules_mu_); + absl::StatusOr> CreateGpuEvent(bool allow_timing); // Guards the in-memory-module mapping. absl::Mutex in_memory_modules_mu_; - std::map in_memory_modules_ + absl::flat_hash_map in_memory_modules_ ABSL_GUARDED_BY(in_memory_modules_mu_); absl::Mutex shared_constants_mu_; @@ -180,15 +157,16 @@ class RocmExecutor : public GpuExecutor { shared_constants_ ABSL_GUARDED_BY(shared_constants_mu_); // Kernel -> loaded GPU binary. Many kernels may load the same binary. - std::unordered_map kernel_to_gpu_binary_ + absl::flat_hash_map kernel_to_gpu_binary_ ABSL_GUARDED_BY(in_memory_modules_mu_); - // GPU binary (PTX or CUBIN or HSACO) -> {module, reference count}. - std::unordered_map> + + // Loaded GPU binary handle -> {module, reference count}. + absl::flat_hash_map> gpu_binary_to_module_ ABSL_GUARDED_BY(in_memory_modules_mu_); // Handle for the ROCm device being operated on. Immutable // post-initialization. - GpuDeviceHandle device_; + hipDevice_t device_; // Reader/writer lock for mutable data structures on this object. absl::Mutex mu_; @@ -213,6 +191,9 @@ class RocmExecutor : public GpuExecutor { // GPU ISA version for device_. int version_; + + // RocmContext for this device. + RocmContext* rocm_context_; }; } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor_test.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor_test.cc new file mode 100644 index 00000000000000..1ed4fed1b0e462 --- /dev/null +++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor_test.cc @@ -0,0 +1,49 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/rocm/rocm_executor.h" + +#include +#include +#include "xla/stream_executor/device_description.h" + +namespace stream_executor::gpu { +namespace { +using testing::Ge; +using testing::IsEmpty; +using testing::Not; + +TEST(RocmExecutorTest, CreateDeviceDescription) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + RocmExecutor::CreateDeviceDescription(0)); + + constexpr SemanticVersion kNullVersion{0, 0, 0}; + EXPECT_NE(result->runtime_version(), kNullVersion); + EXPECT_NE(result->driver_version(), kNullVersion); + EXPECT_NE(result->compile_time_toolkit_version(), kNullVersion); + + EXPECT_THAT(result->platform_version(), Not(IsEmpty())); + EXPECT_THAT(result->name(), Not(IsEmpty())); + EXPECT_THAT(result->model_str(), Not(IsEmpty())); + EXPECT_THAT(result->device_vendor(), "Advanced Micro Devices, Inc"); + + EXPECT_THAT( + std::get_if(&result->gpu_compute_capability()) + ->gcn_arch_name(), + Not(IsEmpty())); +} + +} // namespace +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_fft.cc b/third_party/xla/xla/stream_executor/rocm/rocm_fft.cc index fe85422350838e..02ab236ef408dd 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_fft.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_fft.cc @@ -16,19 +16,18 @@ limitations under the License. #include "xla/stream_executor/rocm/rocm_fft.h" #include +#include +#include "xla/stream_executor/activate_context.h" #include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_helpers.h" #include "xla/stream_executor/gpu/gpu_stream.h" -#include "xla/stream_executor/gpu/scoped_activate_context.h" -#include "xla/stream_executor/platform/dso_loader.h" #include "xla/stream_executor/platform/initialize.h" -#include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/plugin_registry.h" #include "xla/stream_executor/rocm/rocm_complex_converters.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/dso_loader.h" #include "tsl/platform/env.h" #include "tsl/platform/logging.h" @@ -45,13 +44,13 @@ namespace wrap { // manner on first use. This dynamic loading technique is used to avoid DSO // dependencies on vendor libraries which may or may not be available in the // deployed binary environment. -#define STREAM_EXECUTOR_ROCFFT_WRAP(__name) \ - struct WrapperShim__##__name { \ - template \ - hipfftResult operator()(GpuExecutor *parent, Args... args) { \ - ScopedActivateContext sac{parent}; \ - return ::__name(args...); \ - } \ +#define STREAM_EXECUTOR_ROCFFT_WRAP(__name) \ + struct WrapperShim__##__name { \ + template \ + hipfftResult operator()(StreamExecutor *parent, Args... args) { \ + std::unique_ptr activation = parent->Activate(); \ + return ::__name(args...); \ + } \ } __name; #else @@ -61,7 +60,7 @@ namespace wrap { static const char *kName; \ using FuncPtrT = std::add_pointer::type; \ static void *GetDsoHandle() { \ - auto s = internal::CachedDsoLoader::GetHipfftDsoHandle(); \ + auto s = tsl::internal::CachedDsoLoader::GetHipfftDsoHandle(); \ return s.value(); \ } \ static FuncPtrT LoadOrDie() { \ @@ -77,8 +76,8 @@ namespace wrap { return f; \ } \ template \ - hipfftResult operator()(GpuExecutor *parent, Args... args) { \ - ScopedActivateContext sac{parent}; \ + hipfftResult operator()(StreamExecutor *parent, Args... args) { \ + std::unique_ptr activation = parent->Activate(); \ return DynLoad()(args...); \ } \ } __name; \ @@ -143,7 +142,7 @@ hipfftType ROCMFftType(fft::Type type) { } // Associates the given stream with the given rocFFT plan. -bool SetStream(GpuExecutor *parent, hipfftHandle plan, Stream *stream) { +bool SetStream(StreamExecutor *parent, hipfftHandle plan, Stream *stream) { auto ret = wrap::hipfftSetStream(parent, plan, AsGpuStreamValue(stream)); if (ret != HIPFFT_SUCCESS) { LOG(ERROR) << "failed to run rocFFT routine hipfftSetStream: " << ret; @@ -155,9 +154,9 @@ bool SetStream(GpuExecutor *parent, hipfftHandle plan, Stream *stream) { } // namespace absl::Status ROCMFftPlan::Initialize( - GpuExecutor *parent, Stream *stream, int rank, uint64_t *elem_count, - uint64_t *input_embed, uint64 input_stride, uint64 input_distance, - uint64_t *output_embed, uint64 output_stride, uint64 output_distance, + StreamExecutor *parent, Stream *stream, int rank, uint64_t *elem_count, + uint64_t *input_embed, uint64_t input_stride, uint64_t input_distance, + uint64_t *output_embed, uint64_t output_stride, uint64_t output_distance, fft::Type type, int batch_count, ScratchAllocator *scratch_allocator) { if (IsInitialized()) { LOG(FATAL) << "Try to repeatedly initialize."; @@ -316,7 +315,7 @@ absl::Status ROCMFftPlan::Initialize( return absl::OkStatus(); } -absl::Status ROCMFftPlan::Initialize(GpuExecutor *parent, Stream *stream, +absl::Status ROCMFftPlan::Initialize(StreamExecutor *parent, Stream *stream, int rank, uint64_t *elem_count, fft::Type type, ScratchAllocator *scratch_allocator) { @@ -370,9 +369,9 @@ int ROCMFftPlan::GetFftDirection() const { } std::unique_ptr ROCMFft::CreateBatchedPlanWithScratchAllocator( - Stream *stream, int rank, uint64_t *elem_count, uint64 *input_embed, - uint64_t input_stride, uint64 input_distance, uint64 *output_embed, - uint64_t output_stride, uint64 output_distance, fft::Type type, + Stream *stream, int rank, uint64_t *elem_count, uint64_t *input_embed, + uint64_t input_stride, uint64_t input_distance, uint64_t *output_embed, + uint64_t output_stride, uint64_t output_distance, fft::Type type, bool in_place_fft, int batch_count, ScratchAllocator *scratch_allocator) { std::unique_ptr fft_plan_ptr{new ROCMFftPlan()}; absl::Status status = fft_plan_ptr->Initialize( @@ -513,16 +512,7 @@ void initialize_rocfft() { PluginRegistry::Instance()->RegisterFactory( rocm::kROCmPlatformId, "rocFFT", [](StreamExecutor *parent) -> fft::FftSupport * { - gpu::GpuExecutor *rocm_executor = - dynamic_cast(parent); - if (rocm_executor == nullptr) { - LOG(ERROR) - << "Attempting to initialize an instance of the rocFFT " - << "support library with a non-ROCM StreamExecutor"; - return nullptr; - } - - return new gpu::ROCMFft(rocm_executor); + return new gpu::ROCMFft(parent); }); if (!status.ok()) { LOG(ERROR) << "Unable to register rocFFT factory: " << status.message(); diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_fft.h b/third_party/xla/xla/stream_executor/rocm/rocm_fft.h index dad6f3e0864f19..5de76ec71a5a1b 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_fft.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_fft.h @@ -20,6 +20,7 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_ROCM_ROCM_FFT_H_ #define XLA_STREAM_EXECUTOR_ROCM_ROCM_FFT_H_ +#include #if TENSORFLOW_USE_ROCM #include "rocm/rocm_config.h" @@ -33,19 +34,14 @@ limitations under the License. #endif #include "xla/stream_executor/fft.h" -#include "xla/stream_executor/platform/port.h" #include "xla/stream_executor/plugin_registry.h" #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream.h" namespace stream_executor { -class Stream; - namespace gpu { -class GpuExecutor; - // ROCMFftPlan uses deferred initialization. Only a single call of // Initialize() is allowed to properly create hipfft plan and set member // variable is_initialized_ to true. Newly added interface that uses member @@ -73,15 +69,15 @@ class ROCMFftPlan : public fft::Plan { } // Initialize function for batched plan - absl::Status Initialize(GpuExecutor *parent, Stream *stream, int rank, - uint64_t *elem_count, uint64 *input_embed, - uint64_t input_stride, uint64 input_distance, - uint64_t *output_embed, uint64 output_stride, + absl::Status Initialize(StreamExecutor *parent, Stream *stream, int rank, + uint64_t *elem_count, uint64_t *input_embed, + uint64_t input_stride, uint64_t input_distance, + uint64_t *output_embed, uint64_t output_stride, uint64_t output_distance, fft::Type type, int batch_count, ScratchAllocator *scratch_allocator); // Initialize function for 1d,2d, and 3d plan - absl::Status Initialize(GpuExecutor *parent, Stream *stream, int rank, + absl::Status Initialize(StreamExecutor *parent, Stream *stream, int rank, uint64_t *elem_count, fft::Type type, ScratchAllocator *scratch_allocator); @@ -95,10 +91,10 @@ class ROCMFftPlan : public fft::Plan { ScratchAllocator *scratch_allocator_; private: - GpuExecutor *parent_; + StreamExecutor *parent_; hipfftHandle plan_; fft::Type fft_type_; - DeviceMemory scratch_; + DeviceMemory scratch_; size_t scratch_size_bytes_; bool is_initialized_; }; @@ -108,7 +104,7 @@ class ROCMFftPlan : public fft::Plan { // This satisfies the platform-agnostic FftSupport interface. // // Note that the hipFFT handle that this encapsulates is implicitly tied to the -// context (and, as a result, the device) that the parent GpuExecutor is tied +// context (and, as a result, the device) that the parent StreamExecutor is tied // to. This simply happens as an artifact of creating the hipFFT handle when a // ROCM context is active. // @@ -116,13 +112,13 @@ class ROCMFftPlan : public fft::Plan { // context of parent_, so all context is explicit. class ROCMFft : public fft::FftSupport { public: - explicit ROCMFft(GpuExecutor *parent) : parent_(parent) {} + explicit ROCMFft(StreamExecutor *parent) : parent_(parent) {} ~ROCMFft() override {} TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES private: - GpuExecutor *parent_; + StreamExecutor *parent_; // Two helper functions that execute dynload::hipfftExec?2?. diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_kernel.cc b/third_party/xla/xla/stream_executor/rocm/rocm_kernel.cc index e8eb84a89092e1..a75b62927ba1c2 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_kernel.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_kernel.cc @@ -13,24 +13,64 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "xla/stream_executor/rocm/rocm_kernel.h" + +#include #include +#include -#include "xla/stream_executor/gpu/gpu_driver.h" -#include "xla/stream_executor/gpu/gpu_kernel.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "xla/stream_executor/activate_context.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/rocm/rocm_driver_wrapper.h" +#include "xla/stream_executor/rocm/rocm_status.h" +#include "tsl/platform/errors.h" namespace stream_executor { namespace gpu { -absl::StatusOr GpuKernel::GetMaxOccupiedBlocksPerCore( +namespace { + +absl::Status FuncGetAttribute(hipFunction_attribute attribute, + hipFunction_t func, int* attribute_value) { + return ToStatus( + wrap::hipFuncGetAttribute(attribute_value, attribute, func), + absl::StrCat("Failed to query kernel attribute: ", attribute)); +} + +} // namespace +absl::StatusOr RocmKernel::GetMaxOccupiedBlocksPerCore( ThreadDim threads, size_t dynamic_shared_memory_bytes) const { int32_t threads_per_block = threads.x * threads.y * threads.z; VLOG(0) << "Get kernel block occupancy: " << name() << "; threads_per_block: " << threads_per_block << "; dynamic_shared_memory_bytes: " << dynamic_shared_memory_bytes; - return GpuDriver::GetMaxOccupiedBlocksPerCore(gpu_context_, gpu_function_, - threads_per_block, - dynamic_shared_memory_bytes); + std::unique_ptr activation = executor_->Activate(); + + int max_blocks = 0; + TF_RETURN_IF_ERROR( + ToStatus(wrap::hipModuleOccupancyMaxActiveBlocksPerMultiprocessor( + &max_blocks, rocm_function_, threads_per_block, + dynamic_shared_memory_bytes), + "Failed to calculate maximal active blocks per SM")); + return max_blocks; +} + +absl::StatusOr RocmKernel::GetKernelMetadata() { + KernelMetadata kernel_metadata; + int value = 0; + TF_RETURN_IF_ERROR( + FuncGetAttribute(HIP_FUNC_ATTRIBUTE_NUM_REGS, rocm_function_, &value)); + kernel_metadata.set_registers_per_thread(value); + + TF_RETURN_IF_ERROR(FuncGetAttribute(HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, + rocm_function_, &value)); + kernel_metadata.set_shared_memory_bytes(value); + return kernel_metadata; } } // namespace gpu diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_kernel.h b/third_party/xla/xla/stream_executor/rocm/rocm_kernel.h new file mode 100644 index 00000000000000..6713ceeb8c74b7 --- /dev/null +++ b/third_party/xla/xla/stream_executor/rocm/rocm_kernel.h @@ -0,0 +1,72 @@ +/* Copyright 2019 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The CUDA implementation of the StreamExecutor functionality. +// CUDA inclusions are ideally confined to this implementation file. +// +// The notions from the StreamExecutor basically correspond to the CUDA streams +// programming model provided by the libcuda.so driver APIs, so we don't have +// to do much more than wrap the calls to the libraries appropriately. +#ifndef XLA_STREAM_EXECUTOR_ROCM_ROCM_KERNEL_H_ +#define XLA_STREAM_EXECUTOR_ROCM_ROCM_KERNEL_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "rocm/include/hip/hip_runtime.h" +#include "xla/stream_executor/gpu/gpu_kernel.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/logging.h" + +namespace stream_executor::gpu { + +class RocmKernel : public GpuKernel { + public: + explicit RocmKernel(StreamExecutor* executor) : executor_(executor) {} + + // Note that the function is unloaded when the module is unloaded, and the + // module that the function is contained in is owned by the StreamExecutor. + ~RocmKernel() override { executor_->UnloadKernel(this); } + + // As arity cannot be reflected upon using the HIP API, the arity is + // explicitly set during the RocmExecutor::GetKernel initialization process. + void set_arity(unsigned arity) { arity_ = arity; } + unsigned Arity() const override { return arity_; } + + absl::StatusOr GetMaxOccupiedBlocksPerCore( + ThreadDim threads, size_t dynamic_shared_memory_bytes) const override; + + // Simple accessor methods. + hipFunction_t gpu_function() const override { return rocm_function_; } + void set_gpu_function(hipFunction_t rocm_function) { + rocm_function_ = rocm_function; + } + + // Collects metadata for the specified kernel. + absl::StatusOr GetKernelMetadata(); + + private: + StreamExecutor* executor_ = nullptr; + + hipFunction_t rocm_function_ = nullptr; // wrapped CUDA kernel handle + unsigned arity_ = 0; // number of formal parameters the kernel takes +}; + +} // namespace stream_executor::gpu + +#endif // XLA_STREAM_EXECUTOR_ROCM_ROCM_KERNEL_H_ diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_kernel_test.cc b/third_party/xla/xla/stream_executor/rocm/rocm_kernel_test.cc new file mode 100644 index 00000000000000..f46c19d5e42f25 --- /dev/null +++ b/third_party/xla/xla/stream_executor/rocm/rocm_kernel_test.cc @@ -0,0 +1,57 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/rocm/rocm_kernel.h" + +#include +#include "xla/stream_executor/gpu/gpu_test_kernels.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/rocm/rocm_runtime.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor::gpu { +namespace { +using testing::Ge; +using tsl::testing::IsOkAndHolds; + +TEST(RocmKernelTest, GetMaxOccupiedBlocksPerCore) { + TF_ASSERT_OK_AND_ASSIGN(Platform * platform, + PlatformManager::PlatformWithName("ROCM")); + TF_ASSERT_OK_AND_ASSIGN(StreamExecutor * executor, + platform->ExecutorForDevice(0)); + + RocmKernel rocm_kernel(executor); + rocm_kernel.set_arity(3); + + TF_ASSERT_OK_AND_ASSIGN( + hipFunction_t function, + RocmRuntime::GetFuncBySymbol(internal::GetAddI32Kernel())); + + rocm_kernel.set_gpu_function(function); + + EXPECT_EQ(rocm_kernel.Arity(), 3); + EXPECT_EQ(rocm_kernel.gpu_function(), function); + + EXPECT_THAT(rocm_kernel.GetMaxOccupiedBlocksPerCore( + ThreadDim(1, 1, 1), /*dynamic_shared_memory_bytes=*/0), + IsOkAndHolds(Ge(1))); +} + +} // namespace +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc b/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc index a8414a142115cb..2373509c2c56fd 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc @@ -16,21 +16,55 @@ limitations under the License. #include "xla/stream_executor/rocm/rocm_platform.h" #include - -#include "absl/base/call_once.h" -#include "absl/strings/str_format.h" +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/gpu/gpu_diagnostics.h" #include "xla/stream_executor/gpu/gpu_driver.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform/initialize.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/rocm/rocm_driver_wrapper.h" #include "xla/stream_executor/rocm/rocm_executor.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" +#include "xla/stream_executor/rocm/rocm_status.h" #include "tsl/platform/errors.h" +#include "tsl/platform/status.h" namespace stream_executor { namespace gpu { +namespace { -ROCmPlatform::ROCmPlatform() : name_("ROCM") {} +// Actually performs the work of ROCM initialization. Wrapped up in one-time +// execution guard. +static absl::Status InternalInitialize() { + hipError_t res = wrap::hipInit(0 /* = flags */); -ROCmPlatform::~ROCmPlatform() {} + if (res == hipSuccess) { + return absl::OkStatus(); + } + + LOG(ERROR) << "failed call to hipInit: " << ToString(res); + Diagnostician::LogDiagnosticInformation(); + return absl::AbortedError( + absl::StrCat("failed call to hipInit: ", ToString(res))); +} + +static absl::Status PlatformInitialize() { + // Cached return value from calling InternalInitialize(), as hipInit need only + // be called once, but PlatformInitialize may be called many times. + static absl::Status* init_retval = [] { + return new absl::Status(InternalInitialize()); + }(); + return *init_retval; +} +} // namespace + +ROCmPlatform::ROCmPlatform() : name_("ROCM") {} Platform::Id ROCmPlatform::id() const { return rocm::kROCmPlatformId; } @@ -38,7 +72,7 @@ int ROCmPlatform::VisibleDeviceCount() const { // Throw away the result - it logs internally, and this [containing] function // isn't in the path of user control. It's safe to call this > 1x. - if (!gpu::GpuDriver::Init().ok()) { + if (!PlatformInitialize().ok()) { return -1; } @@ -49,10 +83,12 @@ const std::string& ROCmPlatform::Name() const { return name_; } absl::StatusOr> ROCmPlatform::DescriptionForDevice(int ordinal) const { - return GpuExecutor::CreateDeviceDescription(ordinal); + TF_RETURN_IF_ERROR(PlatformInitialize()); + return RocmExecutor::CreateDeviceDescription(ordinal); } absl::StatusOr ROCmPlatform::ExecutorForDevice(int ordinal) { + TF_RETURN_IF_ERROR(PlatformInitialize()); return executor_cache_.GetOrCreate( ordinal, [this, ordinal]() { return GetUncachedExecutor(ordinal); }); } diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_platform.h b/third_party/xla/xla/stream_executor/rocm/rocm_platform.h index e37345c5275127..1cab970e37d385 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_platform.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_platform.h @@ -17,13 +17,13 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_ROCM_ROCM_PLATFORM_H_ #include -#include +#include +#include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/executor_cache.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/platform/port.h" -#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" namespace stream_executor { @@ -41,7 +41,6 @@ namespace gpu { class ROCmPlatform : public Platform { public: ROCmPlatform(); - ~ROCmPlatform() override; // Platform interface implementation: // Returns the same value as kROCmPlatform above. diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_runtime.cc b/third_party/xla/xla/stream_executor/rocm/rocm_runtime.cc index fe8dd31c47a7ee..5998d42f1ba513 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_runtime.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_runtime.cc @@ -13,44 +13,37 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "xla/stream_executor/rocm/rocm_runtime.h" + #include -#include "absl/base/optimization.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "xla/stream_executor/gpu/gpu_runtime.h" -#include "xla/stream_executor/gpu/gpu_types.h" -#include "xla/stream_executor/rocm/rocm_driver.h" #include "xla/stream_executor/rocm/rocm_driver_wrapper.h" - -#define RETURN_IF_ROCM_ERROR(expr, ...) \ - if (auto res = (expr); TF_PREDICT_FALSE(res != hipSuccess)) { \ - return absl::InternalError(absl::StrCat( \ - __VA_ARGS__, ": ", ::stream_executor::gpu::ToString(res))); \ - } +#include "xla/stream_executor/rocm/rocm_status.h" +#include "tsl/platform/errors.h" namespace stream_executor { namespace gpu { -absl::StatusOr GpuRuntime::GetFuncBySymbol(void* symbol) { +absl::StatusOr RocmRuntime::GetFuncBySymbol(void* symbol) { VLOG(2) << "Get ROCM function from a symbol: " << symbol; #if TF_ROCM_VERSION >= 60200 hipFunction_t func; - RETURN_IF_ROCM_ERROR(wrap::hipGetFuncBySymbol(&func, symbol), - "Failed call to hipGetFuncBySymbol"); + TF_RETURN_IF_ERROR(ToStatus(wrap::hipGetFuncBySymbol(&func, symbol), + "Failed call to hipGetFuncBySymbol")); return func; #else return absl::UnimplementedError("GetFuncBySymbol is not implemented"); #endif // TF_ROCM_VERSION >= 60200 } -absl::StatusOr GpuRuntime::GetRuntimeVersion() { +absl::StatusOr RocmRuntime::GetRuntimeVersion() { VLOG(2) << "Get ROCM runtime version"; int32_t version; - RETURN_IF_ROCM_ERROR(wrap::hipRuntimeGetVersion(&version), - "Failed call to hipRuntimeGetVersion"); + TF_RETURN_IF_ERROR(ToStatus(wrap::hipRuntimeGetVersion(&version), + "Failed call to hipRuntimeGetVersion")); return version; } diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_runtime.h b/third_party/xla/xla/stream_executor/rocm/rocm_runtime.h new file mode 100644 index 00000000000000..b1a197fe0643bd --- /dev/null +++ b/third_party/xla/xla/stream_executor/rocm/rocm_runtime.h @@ -0,0 +1,47 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// CUDA/ROCm runtime library wrapper functionality. + +#ifndef XLA_STREAM_EXECUTOR_ROCM_ROCM_RUNTIME_H_ +#define XLA_STREAM_EXECUTOR_ROCM_ROCM_RUNTIME_H_ + +#include + +#include "absl/status/statusor.h" +#include "xla/stream_executor/rocm/rocm_driver_wrapper.h" + +namespace stream_executor::gpu { + +// Rocm runtime returns types defined in the stream_executor::gpu namespace, and +// they usually correspond to the driver types, as driver API is the primary +// integration API of Gpus into StreamExecutor. +class RocmRuntime { + public: + // Get pointer to device entry function that matches entry function `symbol`. + // + // WARNING: This will load all fatbins statically registered with the + // underlying runtime into runtime modules for the current context. If no + // context is current, the runtime will use the primary context for the + // current device (and create it if it doesn't exist yet). + static absl::StatusOr GetFuncBySymbol(void* symbol); + + // Returns the Gpu Runtime version. + static absl::StatusOr GetRuntimeVersion(); +}; + +} // namespace stream_executor::gpu + +#endif // XLA_STREAM_EXECUTOR_ROCM_ROCM_RUNTIME_H_ diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_status.cc b/third_party/xla/xla/stream_executor/rocm/rocm_status.cc new file mode 100644 index 00000000000000..bddadd870cf3ad --- /dev/null +++ b/third_party/xla/xla/stream_executor/rocm/rocm_status.cc @@ -0,0 +1,102 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/rocm/rocm_status.h" + +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "rocm/include/hip/hip_runtime.h" + +namespace stream_executor::gpu { + +// Formats hipError_t to output prettified values into a log stream. +// Error summaries taken from: +std::string ToString(hipError_t result) { +#define OSTREAM_ROCM_ERROR(__name) \ + case hipError##__name: \ + return "HIP_ERROR_" #__name; + + switch (result) { + OSTREAM_ROCM_ERROR(InvalidValue) + OSTREAM_ROCM_ERROR(OutOfMemory) + OSTREAM_ROCM_ERROR(NotInitialized) + OSTREAM_ROCM_ERROR(Deinitialized) + OSTREAM_ROCM_ERROR(NoDevice) + OSTREAM_ROCM_ERROR(InvalidDevice) + OSTREAM_ROCM_ERROR(InvalidImage) + OSTREAM_ROCM_ERROR(InvalidContext) + OSTREAM_ROCM_ERROR(InvalidHandle) + OSTREAM_ROCM_ERROR(NotFound) + OSTREAM_ROCM_ERROR(NotReady) + OSTREAM_ROCM_ERROR(NoBinaryForGpu) + + // Encountered an uncorrectable ECC error during execution. + OSTREAM_ROCM_ERROR(ECCNotCorrectable) + + // Load/store on an invalid address. Must reboot all context. + case 700: + return "ROCM_ERROR_ILLEGAL_ADDRESS"; + // Passed too many / wrong arguments, too many threads for register count. + case 701: + return "ROCM_ERROR_LAUNCH_OUT_OF_RESOURCES"; + + OSTREAM_ROCM_ERROR(ContextAlreadyInUse) + OSTREAM_ROCM_ERROR(PeerAccessUnsupported) + OSTREAM_ROCM_ERROR(Unknown) // Unknown internal error to ROCM. +#if TF_ROCM_VERSION >= 60200 + OSTREAM_ROCM_ERROR(LaunchTimeOut) + OSTREAM_ROCM_ERROR(PeerAccessAlreadyEnabled) + OSTREAM_ROCM_ERROR(PeerAccessNotEnabled) + OSTREAM_ROCM_ERROR(SetOnActiveProcess) + OSTREAM_ROCM_ERROR(ContextIsDestroyed) + OSTREAM_ROCM_ERROR(Assert) + OSTREAM_ROCM_ERROR(HostMemoryAlreadyRegistered) + OSTREAM_ROCM_ERROR(HostMemoryNotRegistered) + OSTREAM_ROCM_ERROR(LaunchFailure) + OSTREAM_ROCM_ERROR(CooperativeLaunchTooLarge) + OSTREAM_ROCM_ERROR(NotSupported) + OSTREAM_ROCM_ERROR(StreamCaptureUnsupported) + OSTREAM_ROCM_ERROR(StreamCaptureInvalidated) + OSTREAM_ROCM_ERROR(StreamCaptureMerge) + OSTREAM_ROCM_ERROR(StreamCaptureUnmatched) + OSTREAM_ROCM_ERROR(StreamCaptureUnjoined) + OSTREAM_ROCM_ERROR(StreamCaptureIsolation) + OSTREAM_ROCM_ERROR(StreamCaptureImplicit) + OSTREAM_ROCM_ERROR(CapturedEvent) + OSTREAM_ROCM_ERROR(StreamCaptureWrongThread) + OSTREAM_ROCM_ERROR(GraphExecUpdateFailure) + OSTREAM_ROCM_ERROR(RuntimeMemory) + OSTREAM_ROCM_ERROR(RuntimeOther) +#endif // TF_ROCM_VERSION >= 60200 + default: + return absl::StrCat("hipError_t(", static_cast(result), ")"); + } +#undef OSTREAM_ROCM_ERROR +} + +namespace internal { +absl::Status ToStatusSlow(hipError_t result, absl::string_view detail) { + std::string error_message = absl::StrCat(detail, ": ", ToString(result)); + if (result == hipErrorOutOfMemory) { + return absl::ResourceExhaustedError(error_message); + } + return absl::InternalError(error_message); +} +} // namespace internal + +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_status.h b/third_party/xla/xla/stream_executor/rocm/rocm_status.h new file mode 100644 index 00000000000000..e86e970c8e623a --- /dev/null +++ b/third_party/xla/xla/stream_executor/rocm/rocm_status.h @@ -0,0 +1,47 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_ROCM_ROCM_STATUS_H_ +#define XLA_STREAM_EXECUTOR_ROCM_ROCM_STATUS_H_ + +#include + +#include "absl/base/optimization.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "rocm/include/hip/hip_runtime.h" + +namespace stream_executor::gpu { + +namespace internal { +// Helper method to handle the slow path of ToStatus. Assumes a non-successful +// result code. +absl::Status ToStatusSlow(hipError_t result, absl::string_view detail); +} // namespace internal + +// Returns an absl::Status corresponding to the hipError_t. +inline absl::Status ToStatus(hipError_t result, absl::string_view detail = "") { + if (ABSL_PREDICT_TRUE(result == hipSuccess)) { + return absl::OkStatus(); + } + return internal::ToStatusSlow(result, detail); +} + +// Returns a textual description of the given hipError_t. +std::string ToString(hipError_t result); + +} // namespace stream_executor::gpu + +#endif // XLA_STREAM_EXECUTOR_ROCM_ROCM_STATUS_H_ diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_status_test.cc b/third_party/xla/xla/stream_executor/rocm/rocm_status_test.cc new file mode 100644 index 00000000000000..0f5e46f33a557e --- /dev/null +++ b/third_party/xla/xla/stream_executor/rocm/rocm_status_test.cc @@ -0,0 +1,51 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/rocm/rocm_status.h" + +#include + +#include +#include +#include "absl/status/status.h" +#include "rocm/include/hip/hip_runtime.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/test.h" + +namespace stream_executor::gpu { +namespace { + +using ::testing::HasSubstr; +using ::tsl::testing::IsOk; +using ::tsl::testing::StatusIs; + +TEST(RocmStatusTest, ToStatusReturnsExpectedStatusCodes) { + // We only promise hipSuccess to map to Ok, hipErrorOutOfMemory to + // ResourceExhausted, and everything else to Internal. + EXPECT_THAT(ToStatus(hipSuccess), IsOk()); + EXPECT_THAT(ToStatus(hipErrorOutOfMemory), + StatusIs(absl::StatusCode::kResourceExhausted)); + EXPECT_THAT(ToStatus(hipErrorNotInitialized), + StatusIs(absl::StatusCode::kInternal)); +} + +TEST(RocmStatusTest, ToStatusIncludesDetailMessage) { + constexpr std::string_view kMyMessage = "Some arbitrary message"; + EXPECT_THAT(ToStatus(hipErrorNotInitialized, kMyMessage), + StatusIs(absl::StatusCode::kInternal, HasSubstr(kMyMessage))); +} + +} // namespace +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_stream.cc b/third_party/xla/xla/stream_executor/rocm/rocm_stream.cc new file mode 100644 index 00000000000000..dff3a877227fc5 --- /dev/null +++ b/third_party/xla/xla/stream_executor/rocm/rocm_stream.cc @@ -0,0 +1,450 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/rocm/rocm_stream.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/base/casts.h" +#include "absl/functional/any_invocable.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "rocm/include/hip/driver_types.h" +#include "rocm/include/hip/hip_runtime.h" +#include "rocm/rocm_config.h" +#include "xla/stream_executor/activate_context.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/rocm/rocm_driver_wrapper.h" +#include "xla/stream_executor/rocm/rocm_event.h" +#include "xla/stream_executor/rocm/rocm_kernel.h" +#include "xla/stream_executor/rocm/rocm_status.h" +#include "xla/stream_executor/stream.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace stream_executor::gpu { +namespace { +int GetGpuStreamPriority(StreamExecutor* executor, + stream_executor::StreamPriority stream_priority) { + std::unique_ptr activation = executor->Activate(); + if (stream_priority == stream_executor::StreamPriority::Default) { + return 0; + } + int lowest, highest; + hipError_t res = wrap::hipDeviceGetStreamPriorityRange(&lowest, &highest); + if (res != hipSuccess) { + LOG(ERROR) + << "Could not query stream priority range. Returning default priority."; + return 0; + } + return stream_priority == stream_executor::StreamPriority::Highest ? highest + : lowest; +} + +absl::StatusOr CreateStream(StreamExecutor* executor, + int priority) { + std::unique_ptr activation = executor->Activate(); + hipStream_t stream; + if (priority == 0) { + TF_RETURN_IF_ERROR(ToStatus( + wrap::hipStreamCreateWithFlags(&stream, hipStreamDefault), + "Failed to create stream")); // switch to hipStreamNonBlocking? + } else { + TF_RETURN_IF_ERROR(ToStatus( + wrap::hipStreamCreateWithPriority(&stream, hipStreamDefault, priority), + "Failed to create stream")); // switch to hipStreamNonBlocking? + } + + VLOG(2) << "successfully created stream " << stream << " for device " + << executor->device_ordinal() << " on thread"; + return stream; +} + +absl::Status RecordEvent(StreamExecutor* executor, hipEvent_t event, + hipStream_t stream) { + std::unique_ptr activation = executor->Activate(); + hipError_t res = wrap::hipEventRecord(event, stream); + switch (res) { + case hipSuccess: + return absl::OkStatus(); + case hipErrorDeinitialized: + case hipErrorNotInitialized: + return absl::FailedPreconditionError( + absl::StrFormat("error recording ROCM event on stream %p: %s", stream, + ToString(res).c_str())); + default: + return absl::InvalidArgumentError( + absl::StrFormat("error recording ROCM event on stream %p: %s", stream, + ToString(res).c_str())); + } +} + +absl::Status WaitStreamOnEvent(StreamExecutor* executor, hipStream_t stream, + hipEvent_t event) { + std::unique_ptr activation = executor->Activate(); + TF_RETURN_IF_ERROR( + ToStatus(wrap::hipStreamWaitEvent(stream, event, 0 /* = flags */), + "could not wait stream on event")); + return absl::OkStatus(); +} + +absl::Status AsynchronousMemcpyD2H(StreamExecutor* executor, void* host_dst, + hipDeviceptr_t gpu_src, uint64_t size, + hipStream_t stream) { + std::unique_ptr activation = executor->Activate(); + TF_RETURN_IF_ERROR(ToStatus( + wrap::hipMemcpyDtoHAsync(host_dst, gpu_src, size, stream), + absl::StrFormat( + "failed to enqueue async memcpy from device to host: host dst: %p; " + "Gpu src: %p; size: %llu=0x%llx", + host_dst, absl::bit_cast(gpu_src), size, size))); + + VLOG(2) << "successfully enqueued async memcpy d2h of " << size + << " bytes from " << absl::bit_cast(gpu_src) << " to " + << host_dst << " on stream " << stream + << " device: " << executor->device_ordinal(); + return absl::OkStatus(); +} + +absl::Status AsynchronousMemcpyH2D(StreamExecutor* executor, + hipDeviceptr_t gpu_dst, const void* host_src, + uint64_t size, hipStream_t stream) { + std::unique_ptr activation = executor->Activate(); + TF_RETURN_IF_ERROR(ToStatus( + wrap::hipMemcpyHtoDAsync(gpu_dst, const_cast(host_src), size, + stream), + absl::StrFormat( + "failed to enqueue async memcpy from host to device: Gpu dst: %p; " + "host src: %p; size: %llu=0x%llx", + absl::bit_cast(gpu_dst), host_src, size, size))); + + VLOG(2) << "successfully enqueued async memcpy h2d of " << size + << " bytes from " << host_src << " to " + << absl::bit_cast(gpu_dst) << " on stream " << stream + << " device: " << executor->device_ordinal(); + return absl::OkStatus(); +} + +absl::Status AsynchronousMemcpyD2D(StreamExecutor* executor, + hipDeviceptr_t gpu_dst, + hipDeviceptr_t gpu_src, uint64_t size, + hipStream_t stream) { + std::unique_ptr activation = executor->Activate(); + TF_RETURN_IF_ERROR(ToStatus( + wrap::hipMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream), + absl::StrFormat("failed to enqueue async memcpy from device to device: " + "Gpu dst: %p ; Gpu src: %p ; size: %llu=0x%llx", + absl::bit_cast(gpu_dst), + absl::bit_cast(gpu_src), size, size))); + + VLOG(2) << "successfully enqueued async memcpy d2d of " << size + << " bytes from " << absl::bit_cast(gpu_src) << " to " + << absl::bit_cast(gpu_dst) << " on stream " << stream + << " device: " << executor->device_ordinal(); + return absl::OkStatus(); +} + +absl::Status SynchronizeStream(StreamExecutor* executor, hipStream_t stream) { + std::unique_ptr activation = executor->Activate(); + TF_RETURN_IF_ERROR(ToStatus(wrap::hipStreamSynchronize(stream), + "Could not synchronize on ROCM stream")); + VLOG(2) << "successfully synchronized stream " << stream << " on device " + << executor->device_ordinal(); + return absl::OkStatus(); +} + +} // namespace + +absl::StatusOr> RocmStream::Create( + StreamExecutor* executor, + std::optional> priority) { + int stream_priority = [&]() { + if (priority.has_value() && std::holds_alternative(priority.value())) { + return std::get(priority.value()); + } + return GetGpuStreamPriority( + executor, + std::get(priority.value_or(StreamPriority::Default))); + }(); + TF_ASSIGN_OR_RETURN(auto stream_handle, + CreateStream(executor, stream_priority)); + + TF_ASSIGN_OR_RETURN(auto completed_event, + RocmEvent::Create(executor, + /*allow_timing=*/false)); + + return std::unique_ptr(new RocmStream( + executor, std::move(completed_event), priority, stream_handle)); +} + +absl::Status RocmStream::WaitFor(Stream* other) { + RocmStream* other_stream = static_cast(other); + + TF_RETURN_IF_ERROR(other_stream->RecordCompletedEvent()); + + return WaitStreamOnEvent(executor_, stream_handle_, + other_stream->completed_event_.GetHandle()); +} + +absl::Status RocmStream::RecordEvent(Event* event) { + return stream_executor::gpu::RecordEvent( + executor_, static_cast(event)->GetHandle(), stream_handle_); +} + +absl::Status RocmStream::WaitFor(Event* event) { + return WaitStreamOnEvent(executor_, stream_handle_, + static_cast(event)->GetHandle()); +} + +absl::Status RocmStream::RecordCompletedEvent() { + return RecordEvent(&completed_event_); +} + +namespace { +void DestroyStream(StreamExecutor* executor, hipStream_t stream) { + if (stream == nullptr) { + return; + } + hipError_t res = wrap::hipStreamQuery(stream); + if (res != hipSuccess) { + LOG(ERROR) << "stream not idle on destroy: " << ToString(res); + } + + std::unique_ptr activation = executor->Activate(); + res = wrap::hipStreamDestroy(stream); + if (res != hipSuccess) { + LOG(ERROR) << "failed to destroy ROCM stream for device " + << executor->device_ordinal() << ": " << ToString(res); + } else { + VLOG(2) << "successfully destroyed stream " << stream << " for device " + << executor->device_ordinal(); + } +} +} // namespace + +RocmStream::~RocmStream() { + BlockHostUntilDone().IgnoreError(); + executor_->DeallocateStream(this); + + DestroyStream(executor_, stream_handle_); +} + +absl::Status RocmStream::Memset32(DeviceMemoryBase* location, uint32_t pattern, + uint64_t size) { + if (absl::bit_cast(location->opaque()) % alignof(uint32_t) != 0) { + return absl::InvalidArgumentError("location must be 4 byte aligned."); + } + if (size % sizeof(uint32_t) != 0) { + return absl::InvalidArgumentError("size must be a multiple of 4 bytes."); + } + return ToStatus(wrap::hipMemsetD32Async(location->opaque(), pattern, size / 4, + stream_handle_), + "Failed to memset memory"); +} + +absl::Status RocmStream::MemZero(DeviceMemoryBase* location, uint64_t size) { + if (absl::bit_cast(location->opaque()) % alignof(uint32_t) == 0 && + size % sizeof(uint32_t) == 0) { + return Memset32(location, 0x0, size); + } else { + std::unique_ptr activation = executor_->Activate(); + return ToStatus( + wrap::hipMemsetAsync(location->opaque(), 0x0, size, stream_handle_), + "Failed to enqueue async memset operation"); + } +} + +absl::Status RocmStream::Memcpy(DeviceMemoryBase* gpu_dst, + const DeviceMemoryBase& gpu_src, + uint64_t size) { + return AsynchronousMemcpyD2D( + executor_, absl::bit_cast(gpu_dst->opaque()), + absl::bit_cast(gpu_src.opaque()), size, stream_handle_); +} + +absl::Status RocmStream::Memcpy(DeviceMemoryBase* gpu_dst, const void* host_src, + uint64_t size) { + return AsynchronousMemcpyH2D( + executor_, absl::bit_cast(gpu_dst->opaque()), host_src, + size, stream_handle_); +} + +absl::Status RocmStream::Memcpy(void* host_dst, const DeviceMemoryBase& gpu_src, + uint64_t size) { + return AsynchronousMemcpyD2H(executor_, host_dst, + absl::bit_cast(gpu_src.opaque()), + size, stream_handle_); +} + +namespace { +void InternalHostCallback(void* data) { + auto* callback = reinterpret_cast*>(data); + std::move (*callback)(); + delete callback; +} +} // namespace + +absl::Status RocmStream::DoHostCallbackWithStatus( + absl::AnyInvocable callback) { + auto callback_ptr = + new absl::AnyInvocable([cb = std::move(callback)]() mutable { + absl::Status s = std::move(cb)(); + if (!s.ok()) { + LOG(WARNING) << "Host callback failed: " << s; + } + }); + return ToStatus( + wrap::hipLaunchHostFunc(stream_handle_, (hipHostFn_t)InternalHostCallback, + callback_ptr), + "unable to add host callback"); +} + +namespace { +absl::Status LaunchKernel(StreamExecutor* executor, + absl::string_view kernel_name, hipFunction_t function, + unsigned int grid_dim_x, unsigned int grid_dim_y, + unsigned int grid_dim_z, unsigned int block_dim_x, + unsigned int block_dim_y, unsigned int block_dim_z, + unsigned int shared_mem_bytes, hipStream_t stream, + void** kernel_params, void** extra) { + std::unique_ptr activation = executor->Activate(); + VLOG(2) << "launching kernel: " << kernel_name << "; gdx: " << grid_dim_x + << " gdy: " << grid_dim_y << " gdz: " << grid_dim_z + << " bdx: " << block_dim_x << " bdy: " << block_dim_y + << " bdz: " << block_dim_z << " smem: " << shared_mem_bytes + << " func: " << (const void*)function; + + auto res = hipSuccess; +#if TF_ROCM_VERSION < 60200 + // for in-process kernel this function returns mangled kernel function name, + // and null otherwise + auto name = wrap::hipKernelNameRefByPtr((const void*)function, stream); + if (name != nullptr) { + res = wrap::hipLaunchKernel((const void*)function, + dim3(grid_dim_x, grid_dim_y, grid_dim_z), + dim3(block_dim_x, block_dim_y, block_dim_z), + kernel_params, shared_mem_bytes, stream); + } else // NOLINT(readability/braces) +#endif // TF_ROCM_VERSION < 60200 + { + res = wrap::hipModuleLaunchKernel( + function, grid_dim_x, grid_dim_y, grid_dim_z, block_dim_x, block_dim_y, + block_dim_z, shared_mem_bytes, stream, kernel_params, extra); + } + TF_RETURN_IF_ERROR( + ToStatus(res, absl::StrCat("Failed to launch ROCm kernel: ", kernel_name, + " with block dimensions: ", block_dim_x, "x", + block_dim_y, "x", block_dim_z))); + + VLOG(2) << "successfully launched kernel"; + return absl::OkStatus(); +} + +absl::Status LaunchKernel(StreamExecutor* executor, + absl::string_view kernel_name, hipFunction_t function, + unsigned int cluster_dim_x, + unsigned int cluster_dim_y, + unsigned int cluster_dim_z, unsigned int grid_dim_x, + unsigned int grid_dim_y, unsigned int grid_dim_z, + unsigned int block_dim_x, unsigned int block_dim_y, + unsigned int block_dim_z, + unsigned int shared_mem_bytes, hipStream_t stream, + void** kernel_params, void** extra) { + if (cluster_dim_x != 1 || cluster_dim_y != 1 || cluster_dim_z != 1) + return absl::UnimplementedError("Not implemented for ROCm"); + return LaunchKernel(executor, kernel_name, function, grid_dim_x, grid_dim_y, + grid_dim_z, block_dim_x, block_dim_y, block_dim_z, + shared_mem_bytes, stream, kernel_params, extra); +} + +} // namespace + +absl::Status RocmStream::BlockHostUntilDone() { + return SynchronizeStream(executor_, stream_handle_); +} + +absl::Status RocmStream::Launch(const ThreadDim& thread_dims, + const BlockDim& block_dims, + const std::optional& cluster_dims, + const Kernel& kernel, const KernelArgs& args) { + const RocmKernel* gpu_kernel = static_cast(&kernel); + hipFunction_t function = gpu_kernel->gpu_function(); + + // Launch kernels with packed arguments. + auto launch = [this, &kernel, &cluster_dims, &thread_dims, &block_dims, + &function](const KernelArgsPackedArrayBase& packed) { + int32_t expected_number_of_arguments = + kernel.Arity() + (packed.number_of_shared_bytes() > 0); + + CHECK_EQ(expected_number_of_arguments, packed.number_of_arguments()) + << "Kernel " << kernel.name() << " has " << packed.number_of_arguments() + << " arguments, but expected " << expected_number_of_arguments + << "; arity=" << kernel.Arity() + << "; number_of_shared_bytes=" << packed.number_of_shared_bytes(); + + void** params = const_cast(packed.argument_addresses().data()); + + if (cluster_dims.has_value()) { + return LaunchKernel( + executor_, kernel.name(), function, cluster_dims->x, cluster_dims->y, + cluster_dims->z, block_dims.x, block_dims.y, block_dims.z, + thread_dims.x, thread_dims.y, thread_dims.z, + packed.number_of_shared_bytes(), stream_handle_, params, + /*extra=*/nullptr); + } else { + return LaunchKernel( + executor_, kernel.name(), function, block_dims.x, block_dims.y, + block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z, + packed.number_of_shared_bytes(), stream_handle_, params, + /*extra=*/nullptr); + } + }; + + // If arguments are already packed we can just launch the kernel. + if (auto* packed = DynCast(&args)) { + return launch(*packed); + } + + // For device memory array we rely on a custom kernel arguments packing. + if (auto* device_mem = DynCast(&args)) { + auto& pack = kernel.args_packing(); + if (!pack) { + return absl::InternalError( + "Kernel is missing a custom arguments packing function for device " + "memory arguments array"); + } + + TF_ASSIGN_OR_RETURN(auto packed, pack(kernel, *device_mem)); + return launch(*packed); + } + + return absl::InternalError("Unsupported kernel arguments type"); +} + +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_stream.h b/third_party/xla/xla/stream_executor/rocm/rocm_stream.h new file mode 100644 index 00000000000000..693335daa187bf --- /dev/null +++ b/third_party/xla/xla/stream_executor/rocm/rocm_stream.h @@ -0,0 +1,99 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_ROCM_ROCM_STREAM_H_ +#define XLA_STREAM_EXECUTOR_ROCM_ROCM_STREAM_H_ + +#include +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "rocm/include/hip/hip_runtime.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/event_based_timer.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/rocm/rocm_event.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_common.h" + +namespace stream_executor { +namespace gpu { + +class RocmStream : public StreamCommon { + public: + absl::Status WaitFor(Stream* other) override; + absl::Status RecordEvent(Event* event) override; + absl::Status WaitFor(Event* event) override; + + absl::Status Memset32(DeviceMemoryBase* location, uint32_t pattern, + uint64_t size) override; + absl::Status MemZero(DeviceMemoryBase* location, uint64_t size) override; + absl::Status Memcpy(DeviceMemoryBase* gpu_dst, const void* host_src, + uint64_t size) override; + absl::Status Memcpy(void* host_dst, const DeviceMemoryBase& gpu_src, + uint64_t size) override; + absl::Status Memcpy(DeviceMemoryBase* gpu_dst, + const DeviceMemoryBase& gpu_src, uint64_t size) override; + absl::Status DoHostCallbackWithStatus( + absl::AnyInvocable callback) override; + absl::Status BlockHostUntilDone() override; + + Stream::PlatformSpecificHandle platform_specific_handle() const override { + return {stream_handle_}; + } + + absl::StatusOr> CreateEventBasedTimer( + bool use_delay_kernel) override { + return executor_->CreateEventBasedTimer(this, use_delay_kernel); + } + + static absl::StatusOr> Create( + StreamExecutor* executor, + std::optional> priority); + + ~RocmStream() override; + + hipStream_t stream_handle() const { return stream_handle_; } + + private: + RocmStream(StreamExecutor* executor, RocmEvent completed_event, + std::optional> priority, + hipStream_t stream_handle) + : StreamCommon(executor, priority), + executor_(executor), + completed_event_(std::move(completed_event)), + stream_handle_(stream_handle) {} + + absl::Status RecordCompletedEvent(); + + absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims, + const std::optional& cluster_dims, + const Kernel& kernel, const KernelArgs& args) override; + + StreamExecutor* executor_; + RocmEvent completed_event_; + hipStream_t stream_handle_; +}; + +} // namespace gpu +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_ROCM_ROCM_STREAM_H_ diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_stream_test.cc b/third_party/xla/xla/stream_executor/rocm/rocm_stream_test.cc new file mode 100644 index 00000000000000..70acb6fdb7306e --- /dev/null +++ b/third_party/xla/xla/stream_executor/rocm/rocm_stream_test.cc @@ -0,0 +1,240 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/rocm/rocm_stream.h" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/gpu_test_kernels.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/rocm/rocm_executor.h" +#include "xla/stream_executor/rocm/rocm_platform_id.h" +#include "xla/stream_executor/typed_kernel_factory.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor { +namespace gpu { +namespace { + +using ::testing::Each; +using ::testing::ElementsAreArray; +using ::tsl::testing::IsOk; + +class RocmStreamTest : public ::testing::Test { + public: + std::optional executor_; + + private: + void SetUp() override { + TF_ASSERT_OK_AND_ASSIGN(Platform * platform, + stream_executor::PlatformManager::PlatformWithId( + stream_executor::rocm::kROCmPlatformId)); + executor_.emplace(platform, 0); + ASSERT_THAT(executor_->Init(), IsOk()); + } +}; + +TEST_F(RocmStreamTest, Memset32) { + constexpr int kBufferNumElements = 42; + DeviceMemory buffer = + executor_->AllocateArray(kBufferNumElements, 0); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + RocmStream::Create(&executor_.value(), + /*priority=*/std::nullopt)); + + // Should fail due to the invalid size parameter. + EXPECT_THAT(stream->Memset32(&buffer, 0xDEADBEEF, + kBufferNumElements * sizeof(uint32_t) + 1), + ::tsl::testing::StatusIs(absl::StatusCode::kInvalidArgument)); + + // Should fail due to the non-4-byte-aligned pointer. + DeviceMemoryBase unaligned_pointer = + buffer.GetByteSlice(/*offset_bytes=*/1, /*size_bytes=*/0); + EXPECT_THAT(stream->Memset32(&unaligned_pointer, 0xDEADBEEF, + kBufferNumElements * sizeof(uint32_t) + 1), + ::tsl::testing::StatusIs(absl::StatusCode::kInvalidArgument)); + + // Correct call. Should succeed. + EXPECT_THAT(stream->Memset32(&buffer, 0xDEADBEEF, + kBufferNumElements * sizeof(uint32_t)), + IsOk()); + + std::array host_buffer; + EXPECT_THAT(stream->MemcpyD2H(buffer, absl::MakeSpan(host_buffer)), IsOk()); + + EXPECT_THAT(stream->BlockHostUntilDone(), IsOk()); + EXPECT_THAT(host_buffer, Each(0xDEADBEEF)); +} + +TEST_F(RocmStreamTest, MemZero) { + constexpr int kBufferNumElements = 42; + DeviceMemory buffer = + executor_->AllocateArray(kBufferNumElements, 0); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + RocmStream::Create(&executor_.value(), + /*priority=*/std::nullopt)); + + EXPECT_THAT(stream->Memset32(&buffer, 0xDEADBEEF, + kBufferNumElements * sizeof(uint32_t)), + IsOk()); + + // We overwrite half the buffer with zeros. + EXPECT_THAT( + stream->MemZero(&buffer, kBufferNumElements / 2 * sizeof(uint32_t)), + IsOk()); + + std::array host_buffer; + EXPECT_THAT(stream->MemcpyD2H(buffer, absl::MakeSpan(host_buffer)), IsOk()); + + EXPECT_THAT(stream->BlockHostUntilDone(), IsOk()); + // We expect the first half of the buffer to be zeros. + EXPECT_THAT( + absl::MakeConstSpan(host_buffer).subspan(0, kBufferNumElements / 2), + Each(0x0)); + + // And it shouldn't have touched the second half. + EXPECT_THAT(absl::MakeConstSpan(host_buffer).subspan(kBufferNumElements / 2), + Each(0xDEADBEEF)); +} + +TEST_F(RocmStreamTest, MemcpyHostToDeviceAndBack) { + constexpr int kBufferNumElements = 42; + DeviceMemory buffer = + executor_->AllocateArray(kBufferNumElements, 0); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + RocmStream::Create(&executor_.value(), + /*priority=*/std::nullopt)); + + std::array src_buffer; + std::generate(src_buffer.begin(), src_buffer.end(), + [i = 0]() mutable { return i++; }); + + EXPECT_THAT(stream->MemcpyH2D(absl::MakeConstSpan(src_buffer), &buffer), + IsOk()); + + std::array host_buffer; + EXPECT_THAT(stream->MemcpyD2H(buffer, absl::MakeSpan(host_buffer)), IsOk()); + + EXPECT_THAT(stream->BlockHostUntilDone(), IsOk()); + EXPECT_THAT(host_buffer, ElementsAreArray(src_buffer)); +} + +TEST_F(RocmStreamTest, MemcpyDeviceToDevice) { + constexpr int kBufferNumElements = 42; + DeviceMemory buffer1 = + executor_->AllocateArray(kBufferNumElements, 0); + DeviceMemory buffer2 = + executor_->AllocateArray(kBufferNumElements, 0); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + RocmStream::Create(&executor_.value(), + /*priority=*/std::nullopt)); + + EXPECT_THAT(stream->Memset32(&buffer1, 0xDEADBEEF, + kBufferNumElements * sizeof(uint32_t)), + IsOk()); + + EXPECT_THAT(stream->MemcpyD2D(&buffer2, buffer1, + kBufferNumElements * sizeof(uint32_t)), + IsOk()); + + std::array host_buffer; + EXPECT_THAT(stream->MemcpyD2H(buffer2, absl::MakeSpan(host_buffer)), IsOk()); + + EXPECT_THAT(stream->BlockHostUntilDone(), IsOk()); + EXPECT_THAT(host_buffer, Each(0xDEADBEEF)); +} + +TEST_F(RocmStreamTest, DoHostCallback) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + RocmStream::Create(&executor_.value(), + /*priority=*/std::nullopt)); + + bool callback_called = false; + EXPECT_THAT( + stream->DoHostCallback([&callback_called]() { callback_called = true; }), + IsOk()); + + EXPECT_THAT(stream->BlockHostUntilDone(), IsOk()); + EXPECT_TRUE(callback_called); +} + +TEST_F(RocmStreamTest, LaunchKernel) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + RocmStream::Create(&executor_.value(), + /*priority=*/std::nullopt)); + + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); + using AddI32Kernel = + TypedKernelFactory, DeviceMemory, + DeviceMemory>; + TF_ASSERT_OK_AND_ASSIGN(auto add, + AddI32Kernel::Create(&executor_.value(), spec)); + + constexpr int64_t kLength = 4; + constexpr int64_t kByteLength = sizeof(int32_t) * kLength; + + // Prepare arguments: a=1, b=2, c=0 + DeviceMemory a = executor_->AllocateArray(kLength, 0); + DeviceMemory b = executor_->AllocateArray(kLength, 0); + DeviceMemory c = executor_->AllocateArray(kLength, 0); + + EXPECT_THAT(stream->Memset32(&a, 1, kByteLength), IsOk()); + EXPECT_THAT(stream->Memset32(&b, 2, kByteLength), IsOk()); + EXPECT_THAT(stream->MemZero(&c, kByteLength), IsOk()); + EXPECT_THAT(stream->ThenLaunch(ThreadDim(), BlockDim(kLength), add, a, b, c), + IsOk()); + + EXPECT_THAT(stream->BlockHostUntilDone(), IsOk()); + + std::array host_buffer; + EXPECT_THAT(stream->MemcpyD2H(c, absl::MakeSpan(host_buffer)), IsOk()); + EXPECT_THAT(host_buffer, Each(3)); +} + +TEST_F(RocmStreamTest, SetName) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr stream, + RocmStream::Create(&executor_.value(), + /*priority=*/std::nullopt)); + + constexpr absl::string_view kStreamName = "Test stream"; + stream->SetName(std::string(kStreamName)); + EXPECT_EQ(stream->GetName(), kStreamName); +} + +} // namespace +} // namespace gpu +} // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_timer.cc b/third_party/xla/xla/stream_executor/rocm/rocm_timer.cc new file mode 100644 index 00000000000000..784f6f67ba14bd --- /dev/null +++ b/third_party/xla/xla/stream_executor/rocm/rocm_timer.cc @@ -0,0 +1,85 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/rocm/rocm_timer.h" + +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/time/time.h" +#include "rocm/include/hip/hip_runtime.h" +#include "xla/stream_executor/activate_context.h" +#include "xla/stream_executor/rocm/rocm_driver_wrapper.h" +#include "xla/stream_executor/rocm/rocm_event.h" +#include "xla/stream_executor/rocm/rocm_status.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace stream_executor::gpu { +namespace { +absl::StatusOr GetEventElapsedTime(StreamExecutor* executor, + hipEvent_t start, hipEvent_t stop) { + std::unique_ptr activation = executor->Activate(); + // The stop event must have completed in order for hipEventElapsedTime to + // work. + hipError_t res = wrap::hipEventSynchronize(stop); + if (res != hipSuccess) { + LOG(ERROR) << "failed to synchronize the stop event: " << ToString(res); + return false; + } + float elapsed_milliseconds; + TF_RETURN_IF_ERROR( + ToStatus(wrap::hipEventElapsedTime(&elapsed_milliseconds, start, stop), + "failed to get elapsed time between events")); + + return elapsed_milliseconds; +} +} // namespace + +RocmTimer::RocmTimer(StreamExecutor* executor, RocmEvent start_event, + RocmEvent stop_event, Stream* stream) + : executor_(executor), + stream_(stream), + start_event_(std::move(start_event)), + stop_event_(std::move(stop_event)) {} + +absl::StatusOr RocmTimer::GetElapsedDuration() { + if (is_stopped_) { + return absl::FailedPreconditionError("Measuring inactive timer"); + } + TF_RETURN_IF_ERROR(stream_->RecordEvent(&stop_event_)); + TF_ASSIGN_OR_RETURN(float elapsed_milliseconds, + GetEventElapsedTime(executor_, start_event_.GetHandle(), + stop_event_.GetHandle())); + is_stopped_ = true; + return absl::Milliseconds(elapsed_milliseconds); +} + +absl::StatusOr RocmTimer::Create(StreamExecutor* executor, + Stream* stream) { + TF_ASSIGN_OR_RETURN(RocmEvent start_event, + RocmEvent::Create(executor, /*allow_timing=*/true)); + TF_ASSIGN_OR_RETURN(RocmEvent stop_event, + RocmEvent::Create(executor, /*allow_timing=*/true)); + TF_RETURN_IF_ERROR(stream->RecordEvent(&start_event)); + return RocmTimer(executor, std::move(start_event), std::move(stop_event), + stream); +} +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_timer.h b/third_party/xla/xla/stream_executor/rocm/rocm_timer.h new file mode 100644 index 00000000000000..f6764fbcbc4300 --- /dev/null +++ b/third_party/xla/xla/stream_executor/rocm/rocm_timer.h @@ -0,0 +1,50 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_ROCM_ROCM_TIMER_H_ +#define XLA_STREAM_EXECUTOR_ROCM_ROCM_TIMER_H_ + +#include "absl/status/statusor.h" +#include "absl/time/time.h" +#include "xla/stream_executor/event_based_timer.h" +#include "xla/stream_executor/rocm/rocm_event.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_executor.h" + +namespace stream_executor::gpu { + +class RocmTimer : public EventBasedTimer { + public: + RocmTimer(RocmTimer&&) = default; + RocmTimer& operator=(RocmTimer&&) = default; + + absl::StatusOr GetElapsedDuration() override; + + static absl::StatusOr Create(StreamExecutor* executor, + Stream* stream); + + private: + RocmTimer(StreamExecutor* executor, RocmEvent start_event, + RocmEvent stop_event, Stream* stream); + + bool is_stopped_ = false; + StreamExecutor* executor_; + Stream* stream_; + RocmEvent start_event_; + RocmEvent stop_event_; +}; +} // namespace stream_executor::gpu + +#endif // XLA_STREAM_EXECUTOR_ROCM_ROCM_TIMER_H_ diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_timer_test.cc b/third_party/xla/xla/stream_executor/rocm/rocm_timer_test.cc new file mode 100644 index 00000000000000..958c5dfa53316f --- /dev/null +++ b/third_party/xla/xla/stream_executor/rocm/rocm_timer_test.cc @@ -0,0 +1,104 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/rocm/rocm_timer.h" + +#include +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/time/time.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/gpu/gpu_test_kernels.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/rocm/rocm_executor.h" +#include "xla/stream_executor/rocm/rocm_platform_id.h" +#include "xla/stream_executor/stream.h" +#include "xla/stream_executor/typed_kernel_factory.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor::gpu { +namespace { +using ::testing::Gt; +using ::tsl::testing::IsOk; + +class RocmTimerTest : public ::testing::Test { + public: + void LaunchSomeKernel(StreamExecutor* executor, Stream* stream) { + using AddI32Kernel = + TypedKernelFactory, DeviceMemory, + DeviceMemory>; + + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=1, b=2, c=0 + DeviceMemory a = executor->AllocateArray(length, 0); + DeviceMemory b = executor->AllocateArray(length, 0); + DeviceMemory c = executor->AllocateArray(length, 0); + + ASSERT_THAT(stream->Memset32(&a, 1, byte_length), IsOk()); + ASSERT_THAT(stream->Memset32(&b, 2, byte_length), IsOk()); + ASSERT_THAT(stream->MemZero(&c, byte_length), IsOk()); + + ASSERT_THAT(stream->ThenLaunch(ThreadDim(), BlockDim(4), add, a, b, c), + IsOk()); + } + + RocmExecutor* executor_; + std::unique_ptr stream_; + + private: + void SetUp() override { + TF_ASSERT_OK_AND_ASSIGN(Platform * platform, + stream_executor::PlatformManager::PlatformWithId( + stream_executor::rocm::kROCmPlatformId)); + TF_ASSERT_OK_AND_ASSIGN(StreamExecutor * executor, + platform->ExecutorForDevice(0)); + executor_ = reinterpret_cast(executor); + TF_ASSERT_OK_AND_ASSIGN(stream_, executor_->CreateStream(std::nullopt)); + } +}; + +TEST_F(RocmTimerTest, Create) { + TF_ASSERT_OK_AND_ASSIGN(RocmTimer timer, + RocmTimer::Create(executor_, stream_.get())); + + // We don't really care what kernel we launch here as long as it takes a + // non-zero amount of time. + LaunchSomeKernel(executor_, stream_.get()); + + TF_ASSERT_OK_AND_ASSIGN(absl::Duration timer_result, + timer.GetElapsedDuration()); + EXPECT_THAT(timer_result, Gt(absl::ZeroDuration())); + EXPECT_THAT(timer.GetElapsedDuration(), + tsl::testing::StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +} // namespace +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_version_parser_test.cc b/third_party/xla/xla/stream_executor/rocm/rocm_version_parser_test.cc index 2306ae8717a110..1859aed034fbdb 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_version_parser_test.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_version_parser_test.cc @@ -30,7 +30,7 @@ using tsl::testing::IsOkAndHolds; using tsl::testing::StatusIs; TEST(ParseRocmVersionTest, Simple) { - EXPECT_THAT(stream_executor::ParseRocmVersion(60102), + EXPECT_THAT(stream_executor::ParseRocmVersion(60'100'002), IsOkAndHolds(SemanticVersion(6, 1, 2))); } diff --git a/third_party/xla/xla/stream_executor/rocm/rocsolver_wrapper.h b/third_party/xla/xla/stream_executor/rocm/rocsolver_wrapper.h index 23b6c4b99dabbc..173f4e5faeaa9a 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocsolver_wrapper.h +++ b/third_party/xla/xla/stream_executor/rocm/rocsolver_wrapper.h @@ -27,8 +27,7 @@ limitations under the License. #include "rocm/include/rocsolver.h" #endif -#include "xla/stream_executor/platform/dso_loader.h" -#include "xla/stream_executor/platform/port.h" +#include "tsl/platform/dso_loader.h" #include "tsl/platform/env.h" namespace stream_executor { @@ -47,22 +46,21 @@ namespace wrap { #define TO_STR_(x) #x #define TO_STR(x) TO_STR_(x) -#define ROCSOLVER_API_WRAPPER(api_name) \ - template \ - auto api_name(Args... args) -> decltype(::api_name(args...)) { \ - using FuncPtrT = std::add_pointer::type; \ - static FuncPtrT loaded = []() -> FuncPtrT { \ - static const char* kName = TO_STR(api_name); \ - void* f; \ - auto s = tsl::Env::Default() -> GetSymbolFromLibrary( \ - stream_executor::internal::CachedDsoLoader::GetRocsolverDsoHandle() \ - .value(), \ - kName, &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in rocsolver lib; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - }(); \ - return loaded(args...); \ +#define ROCSOLVER_API_WRAPPER(api_name) \ + template \ + auto api_name(Args... args) -> decltype(::api_name(args...)) { \ + using FuncPtrT = std::add_pointer::type; \ + static FuncPtrT loaded = []() -> FuncPtrT { \ + static const char* kName = TO_STR(api_name); \ + void* f; \ + auto s = tsl::Env::Default()->GetSymbolFromLibrary( \ + tsl::internal::CachedDsoLoader::GetRocsolverDsoHandle().value(), \ + kName, &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in rocsolver lib; dlerror: " << s.message(); \ + return reinterpret_cast(f); \ + }(); \ + return loaded(args...); \ } #endif diff --git a/third_party/xla/xla/stream_executor/rocm/roctracer_wrapper.h b/third_party/xla/xla/stream_executor/rocm/roctracer_wrapper.h index b42751bb53e0cf..871df2cb9e2f69 100644 --- a/third_party/xla/xla/stream_executor/rocm/roctracer_wrapper.h +++ b/third_party/xla/xla/stream_executor/rocm/roctracer_wrapper.h @@ -28,8 +28,7 @@ limitations under the License. #else #include "rocm/include/roctracer/roctracer_hcc.h" #endif -#include "xla/stream_executor/platform/dso_loader.h" -#include "xla/stream_executor/platform/port.h" +#include "tsl/platform/dso_loader.h" #include "tsl/platform/env.h" #include "tsl/platform/platform.h" @@ -46,22 +45,21 @@ namespace wrap { #else -#define ROCTRACER_API_WRAPPER(API_NAME) \ - template \ - auto API_NAME(Args... args) -> decltype(::API_NAME(args...)) { \ - using FuncPtrT = std::add_pointer::type; \ - static FuncPtrT loaded = []() -> FuncPtrT { \ - static const char* kName = #API_NAME; \ - void* f; \ - auto s = tsl::Env::Default()->GetSymbolFromLibrary( \ - stream_executor::internal::CachedDsoLoader::GetRoctracerDsoHandle() \ - .value(), \ - kName, &f); \ - CHECK(s.ok()) << "could not find " << kName \ - << " in roctracer DSO; dlerror: " << s.message(); \ - return reinterpret_cast(f); \ - }(); \ - return loaded(args...); \ +#define ROCTRACER_API_WRAPPER(API_NAME) \ + template \ + auto API_NAME(Args... args) -> decltype(::API_NAME(args...)) { \ + using FuncPtrT = std::add_pointer::type; \ + static FuncPtrT loaded = []() -> FuncPtrT { \ + static const char* kName = #API_NAME; \ + void* f; \ + auto s = tsl::Env::Default()->GetSymbolFromLibrary( \ + tsl::internal::CachedDsoLoader::GetRoctracerDsoHandle().value(), \ + kName, &f); \ + CHECK(s.ok()) << "could not find " << kName \ + << " in roctracer DSO; dlerror: " << s.message(); \ + return reinterpret_cast(f); \ + }(); \ + return loaded(args...); \ } #endif // PLATFORM_GOOGLE diff --git a/third_party/xla/xla/stream_executor/stream.h b/third_party/xla/xla/stream_executor/stream.h index 0fccf94270a85c..220cbf761c24fd 100644 --- a/third_party/xla/xla/stream_executor/stream.h +++ b/third_party/xla/xla/stream_executor/stream.h @@ -23,6 +23,8 @@ limitations under the License. #include #include +#include +#include #include #include @@ -270,25 +272,24 @@ class Stream { // Launches a data parallel kernel with the given thread/block // dimensionality and already-packed args/sizes to pass to the underlying // platform driver. - virtual absl::Status Launch(const ThreadDim &thread_dims, - const BlockDim &block_dims, const Kernel &k, - const KernelArgs &args) { - return absl::UnimplementedError("Not implemented"); + absl::Status Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, + const Kernel &kernel, const KernelArgs &args) { + return Launch(thread_dims, block_dims, std::nullopt, kernel, args); } // Launches a data parallel kernel with the given thread/block // dimensionality and already-packed args/sizes to pass to the underlying // platform driver. - virtual absl::Status Launch(const ThreadDim &thread_dims, - const BlockDim &block_dims, - const ClusterDim &cluster_dims, const Kernel &k, - const KernelArgs &args) { - return absl::UnimplementedError("Not implemented"); + absl::Status Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, + const ClusterDim &cluster_dims, const Kernel &kernel, + const KernelArgs &args) { + return Launch(thread_dims, block_dims, std::make_optional(cluster_dims), + kernel, args); } // Get/set a name for a stream, which can be shown in profiling tools - virtual absl::string_view name() const = 0; - virtual void set_name(absl::string_view name) = 0; + virtual const std::string &GetName() const = 0; + virtual void SetName(std::string name) = 0; // Create an EventBasedTimer that can be used to time operations on this // stream using Events. @@ -304,6 +305,15 @@ class Stream { return absl::UnimplementedError( "This stream does not support EventBasedTimers."); } + + private: + // Helper method to launch a kernel with optional cluster dimensions. + virtual absl::Status Launch(const ThreadDim &thread_dims, + const BlockDim &block_dims, + const std::optional &cluster_dims, + const Kernel &kernel, const KernelArgs &args) { + return absl::UnimplementedError("Not implemented"); + } }; template diff --git a/third_party/xla/xla/stream_executor/stream_common.cc b/third_party/xla/xla/stream_executor/stream_common.cc index e7833bfd25dab2..2c4dd7828c539d 100644 --- a/third_party/xla/xla/stream_executor/stream_common.cc +++ b/third_party/xla/xla/stream_executor/stream_common.cc @@ -18,14 +18,16 @@ limitations under the License. #include #include #include +#include #include +#include #include #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" -#include "xla/stream_executor/blas.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" @@ -33,10 +35,21 @@ limitations under the License. namespace stream_executor { StreamCommon::StreamCommon(StreamExecutor *parent) - : parent_(parent), status_(absl::OkStatus()) { + : parent_(parent), + status_(absl::OkStatus()), + stream_priority_(StreamPriority::Default) { CHECK_NE(parent, nullptr); } +StreamCommon::StreamCommon( + StreamExecutor *parent, + std::optional> priority) + : StreamCommon(parent) { + if (priority.has_value()) { + stream_priority_ = priority.value(); + } +} + StreamCommon::PlatformSpecificHandle StreamCommon::platform_specific_handle() const { PlatformSpecificHandle handle; @@ -83,7 +96,7 @@ absl::StatusOr StreamCommon::GetOrCreateSubStream() { // No streams are reusable; create a new stream. TF_ASSIGN_OR_RETURN(auto stream, parent_->CreateStream()); Stream *sub_stream = stream.get(); - sub_stream->set_name(absl::StrFormat("Sub-stream of %s", name())); + sub_stream->SetName(absl::StrFormat("Sub-stream of %s", GetName())); sub_streams_.emplace_back(std::move(stream), false); VLOG(1) << "stream=" << this << " created new sub_stream=" << sub_stream; @@ -135,22 +148,6 @@ void StreamCommon::CheckError(bool operation_retcode) { status_ = absl::InternalError("Unknown error"); } -absl::Status StreamCommon::BlockHostUntilDone() { - if (!ok()) { - absl::MutexLock lock(&mu_); - LOG(INFO) << status_.ToString(); - absl::Status status = absl::InternalError( - "stream did not block host until done; was already in an error state"); - LOG(INFO) << "stream = " << this << " " << status; - return status; - } - - absl::Status error = parent_->BlockHostUntilDone(this); - CheckError(error.ok()); - - return error; -} - void StreamCommon::CheckStatus(absl::Status status) { if (status.ok()) { return; diff --git a/third_party/xla/xla/stream_executor/stream_common.h b/third_party/xla/xla/stream_executor/stream_common.h index f7029c72fbadbf..5832a8a1950146 100644 --- a/third_party/xla/xla/stream_executor/stream_common.h +++ b/third_party/xla/xla/stream_executor/stream_common.h @@ -22,6 +22,7 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_STREAM_COMMON_H_ #include +#include #include #include #include @@ -31,7 +32,6 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/fft.h" @@ -60,16 +60,21 @@ class StreamCommon : public Stream { // StreamExecutor's platform. explicit StreamCommon(StreamExecutor *parent); + StreamCommon(StreamExecutor *parent, + std::optional> priority); + PlatformSpecificHandle platform_specific_handle() const override; bool ok() const override { return !InErrorState(); } absl::StatusOr GetOrCreateSubStream() override TF_LOCKS_EXCLUDED(mu_); void ReturnSubStream(Stream *sub_stream) override TF_LOCKS_EXCLUDED(mu_); - absl::Status BlockHostUntilDone() override TF_LOCKS_EXCLUDED(mu_); StreamExecutor *parent() const override { CHECK(parent_ != nullptr); return parent_; } + std::variant priority() const override { + return stream_priority_; + } CudaComputeCapability GetCudaComputeCapability() const override { return parent()->GetDeviceDescription().cuda_compute_capability(); @@ -78,13 +83,10 @@ class StreamCommon : public Stream { RocmComputeCapability GetRocmComputeCapability() const override { return parent()->GetDeviceDescription().rocm_compute_capability(); } - std::variant priority() const override { - return StreamPriority::Default; - } // Doesn't do anything interesting by default; GpuStream connects this to NVTX - absl::string_view name() const override { return name_; } - void set_name(absl::string_view name) override { name_ = name; } + const std::string &GetName() const override { return name_; } + void SetName(std::string name) override { name_ = std::move(name); } protected: bool InErrorState() const TF_LOCKS_EXCLUDED(mu_) { @@ -118,8 +120,7 @@ class StreamCommon : public Stream { std::vector, bool>> sub_streams_ ABSL_GUARDED_BY(mu_); - StreamCommon(const StreamCommon &) = delete; - void operator=(const StreamCommon &) = delete; + std::variant stream_priority_; }; } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/stream_executor.h b/third_party/xla/xla/stream_executor/stream_executor.h index a0c3c48e521e30..2ebd361fa16756 100644 --- a/third_party/xla/xla/stream_executor/stream_executor.h +++ b/third_party/xla/xla/stream_executor/stream_executor.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/stream_executor/activate_context.h" #include "xla/stream_executor/allocator_stats.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/command_buffer.h" @@ -71,6 +72,10 @@ class StreamExecutor { public: virtual ~StreamExecutor() = default; + // Returns an ActivateContext that ensures the StreamExecutor's context is + // activated for the duration of the returned ActivateContext's scope. + virtual std::unique_ptr Activate() = 0; + // Returns a reference to the platform that created this executor. virtual const Platform* GetPlatform() const = 0; @@ -86,6 +91,11 @@ class StreamExecutor { absl::StatusOr> CreateStream() { return CreateStream(std::nullopt); } + // Creates an EventBasedTimer for the given stream. + virtual absl::StatusOr> + CreateEventBasedTimer(Stream* stream, bool use_delay_kernel) { + return absl::UnimplementedError("Not Implemented"); + } // Creates and initializes an Event. virtual absl::StatusOr> CreateEvent() = 0; @@ -117,16 +127,18 @@ class StreamExecutor { return absl::UnimplementedError("Not Implemented"); } + // Releases any state associated with the previously loaded kernel. + virtual void UnloadKernel(const Kernel* kernel) {} + // Unloads the module with handle `module_handle`. virtual bool UnloadModule(ModuleHandle module_handle) { return false; } // Loads a module for the platform this StreamExecutor is acting upon. // - // `spec` describes the module to be loaded. On success writes the handle for - // the loaded module to `module_handle` and returns absl::OkStatus(). - // Otherwise, returns the error which has occurred. - virtual absl::Status LoadModule(const MultiModuleLoaderSpec& spec, - ModuleHandle* module_handle) { + // `spec` describes the module to be loaded. On success returns the handle + // for the loaded module. Otherwise, returns the error which has occurred. + virtual absl::StatusOr LoadModule( + const MultiModuleLoaderSpec& spec) { return absl::UnimplementedError("Not Implemented"); } @@ -193,6 +205,19 @@ class StreamExecutor { virtual absl::Status SynchronousMemZero(DeviceMemoryBase* location, uint64_t size) = 0; + // Returns a DeviceMemoryBase representing the range [base, base + size) + // for the given DeviceMemoryBase, such that location is contained within the + // returned range. + virtual absl::StatusOr GetMemoryRange( + const DeviceMemoryBase& location) { + return absl::UnimplementedError("Not implemented for this executor."); + } + + virtual bool HostMemoryUnregister(void* location) { return false; }; + virtual bool HostMemoryRegister(void* location, uint64_t size) { + return false; + }; + // Blocks the caller while "size" bytes are copied to the given location in // device memory. virtual absl::Status SynchronousMemcpy(DeviceMemoryBase* device_dst, @@ -216,11 +241,6 @@ class StreamExecutor { // Deallocates stream resources on the underlying platform. virtual void DeallocateStream(Stream* stream) = 0; - // Causes the host code to synchronously wait for operations enqueued - // onto stream to complete. Effectively a join on the asynchronous device - // operations enqueued on the stream before this program point. - virtual absl::Status BlockHostUntilDone(Stream* stream) = 0; - // Enables peer access from this StreamExecutor to memory // allocated by other, such that launched device code, memcpies, etc may // access it directly. diff --git a/third_party/xla/xla/stream_executor/stream_executor_common.h b/third_party/xla/xla/stream_executor/stream_executor_common.h index 52bb6f6394c831..1e755d637b05cd 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_common.h +++ b/third_party/xla/xla/stream_executor/stream_executor_common.h @@ -1,3 +1,5 @@ +#include "xla/stream_executor/activate_context.h" +#include "xla/stream_executor/device_description.h" /* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); @@ -42,6 +44,11 @@ class StreamExecutorCommon : public StreamExecutor { public: explicit StreamExecutorCommon(const Platform* platform); + std::unique_ptr Activate() override { + // Non-GPU stream executors don't have a context to activate. + return std::make_unique(); + } + const Platform* GetPlatform() const override { return platform_; } const DeviceDescription& GetDeviceDescription() const override; int64_t GetMemoryLimitBytes() const override { return memory_limit_bytes_; } diff --git a/third_party/xla/xla/stream_executor/stream_executor_memory_allocator.cc b/third_party/xla/xla/stream_executor/stream_executor_memory_allocator.cc index c17656c6d34147..e03e2783a6f4a0 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_memory_allocator.cc +++ b/third_party/xla/xla/stream_executor/stream_executor_memory_allocator.cc @@ -105,7 +105,7 @@ absl::StatusOr StreamExecutorMemoryAllocator::GetStream( if (!streams_.count(device_ordinal)) { TF_ASSIGN_OR_RETURN(auto stream, executor->CreateStream()); auto stream_ptr = stream.get(); - stream_ptr->set_name("StreamExecutorMemoryAllocator"); + stream_ptr->SetName("StreamExecutorMemoryAllocator"); streams_.emplace(device_ordinal, std::move(stream)); return stream_ptr; } diff --git a/third_party/xla/xla/stream_executor/sycl/BUILD b/third_party/xla/xla/stream_executor/sycl/BUILD index 86c00a09d67028..906df8332046a2 100644 --- a/third_party/xla/xla/stream_executor/sycl/BUILD +++ b/third_party/xla/xla/stream_executor/sycl/BUILD @@ -41,13 +41,11 @@ cc_library( ":sycl_platform_id", "@com_google_absl//absl/base", "@com_google_absl//absl/memory", - "//xla/stream_executor", # buildcleaner: keep "//xla/stream_executor:executor_cache", - "//xla/stream_executor/platform", + "//xla/stream_executor/platform:initialize", "//xla/stream_executor/gpu:gpu_types_header", "//xla/stream_executor/gpu:gpu_driver_header", "//xla/stream_executor/gpu:gpu_executor_header", - "//xla/stream_executor/gpu:gpu_collectives_header", "@local_tsl//tsl/platform:errors", ]), alwayslink = True, # Registers itself with the PlatformManager. @@ -83,10 +81,8 @@ cc_library( deps = [ ":sycl_platform_id", ":sycl_rpath", - "//xla/stream_executor", "//xla/stream_executor:dnn", "//xla/stream_executor:platform_manager", "//xla/stream_executor:scratch_allocator", - "//xla/stream_executor/platform:dso_loader", ] + if_static([":all_runtime"]), ) diff --git a/third_party/xla/xla/stream_executor/sycl/sycl_platform.cc b/third_party/xla/xla/stream_executor/sycl/sycl_platform.cc index a78e104670bf21..cb3759d88fdeb0 100644 --- a/third_party/xla/xla/stream_executor/sycl/sycl_platform.cc +++ b/third_party/xla/xla/stream_executor/sycl/sycl_platform.cc @@ -49,10 +49,7 @@ Platform::Id SyclPlatform::id() const { return sycl::kSyclPlatformId; } int SyclPlatform::VisibleDeviceCount() const { // Initialized in a thread-safe manner the first time this is run. - static const int num_devices = [] { - if (!GpuDriver::Init().ok()) return -1; - return GpuDriver::GetDeviceCount(); - }(); + static const int num_devices = [] { return GpuDriver::GetDeviceCount(); }(); return num_devices; } diff --git a/third_party/xla/xla/stream_executor/sycl/sycl_platform.h b/third_party/xla/xla/stream_executor/sycl/sycl_platform.h index 61f0eb3d5372b9..7c70e5d17e0f6e 100644 --- a/third_party/xla/xla/stream_executor/sycl/sycl_platform.h +++ b/third_party/xla/xla/stream_executor/sycl/sycl_platform.h @@ -60,7 +60,7 @@ class SyclPlatform : public Platform { // looking in or storing to the Platform's executor cache. // Ownership IS transferred to the caller. absl::StatusOr> GetUncachedExecutor( - int ordinal) override; + int ordinal); // This platform's name. std::string name_; diff --git a/third_party/xla/xla/stream_executor/tpu/BUILD b/third_party/xla/xla/stream_executor/tpu/BUILD index 15abc2b4dcee23..e8e9f8385d2ec5 100644 --- a/third_party/xla/xla/stream_executor/tpu/BUILD +++ b/third_party/xla/xla/stream_executor/tpu/BUILD @@ -32,8 +32,9 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":libtftpu_header", - "//xla/stream_executor", "//xla/stream_executor:event", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", ], ) @@ -93,8 +94,8 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", "//xla/service:hlo_proto_cc", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -199,13 +200,16 @@ cc_library( ":tpu_platform_interface", ":tpu_stream_interface", ":tpu_topology_external", - "//xla/stream_executor", "//xla/stream_executor:allocator_stats", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory", "//xla/stream_executor:event", "//xla/stream_executor:executor_cache", "//xla/stream_executor:memory_allocation", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", - "//xla/stream_executor/platform", + "//xla/stream_executor/platform:initialize", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", @@ -234,9 +238,10 @@ cc_library( ":tpu_executor_c_api_hdrs", ":tpu_platform_interface", ":tpu_topology_external", - "//xla/stream_executor", "//xla/stream_executor:event", "//xla/stream_executor:executor_cache", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", @@ -264,11 +269,14 @@ cc_library( ":tpu_executor_api", ":tpu_executor_c_api_hdrs", ":tpu_topology_external", - "//xla/stream_executor", "//xla/stream_executor:allocator_stats", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory", "//xla/stream_executor:event", "//xla/stream_executor:executor_cache", "//xla/stream_executor:memory_allocation", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream", "//xla/stream_executor:stream_common", "//xla/stream_executor:stream_executor_common", "//xla/stream_executor:stream_executor_h", @@ -408,7 +416,7 @@ cc_library( "//xla:literal", "//xla:shape_util", "//xla/service:transfer_manager", - "//xla/stream_executor", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/log", "@com_google_absl//absl/status", ], @@ -445,7 +453,10 @@ cc_library( "//xla:shape_util", "//xla/service:shaped_buffer", "//xla/service:transfer_manager", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -472,7 +483,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:executable", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -491,14 +502,15 @@ cc_library( deps = [ ":c_api_decl", ":tpu_topology_external", - "//xla/stream_executor", "//xla/stream_executor:event", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/tsl/protobuf:error_codes_proto_impl_cc", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", ], ) @@ -509,7 +521,8 @@ cc_library( deps = [ ":tpu_platform_interface", ":tpu_topology_external", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_common", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -554,7 +567,7 @@ cc_library( "//xla/service:executable", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_proto_cc", - "//xla/stream_executor", + "//xla/stream_executor:platform", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:statusor", @@ -567,7 +580,7 @@ cc_library( hdrs = ["tpu_stream_interface.h"], visibility = ["//visibility:public"], deps = [ - "//xla/stream_executor", + "//xla/stream_executor:device_memory", "//xla/stream_executor:stream_common", "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/status", @@ -590,7 +603,10 @@ cc_library( "//xla/service:maybe_owning_device_memory", "//xla/service:shaped_buffer", "//xla/service:transfer_manager", - "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -622,8 +638,8 @@ cc_library( "//xla/service:executable", "//xla/service:hlo_execution_profile", "//xla/service:shaped_buffer", - "//xla/stream_executor", "//xla/stream_executor:device_memory", + "//xla/stream_executor:stream", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", diff --git a/third_party/xla/xla/stream_executor/tpu/c_api_conversions_test.cc b/third_party/xla/xla/stream_executor/tpu/c_api_conversions_test.cc index 196a6b004c96c6..015536525c64fc 100644 --- a/third_party/xla/xla/stream_executor/tpu/c_api_conversions_test.cc +++ b/third_party/xla/xla/stream_executor/tpu/c_api_conversions_test.cc @@ -27,10 +27,10 @@ limitations under the License. #include "absl/types/span.h" #include "xla/executable_run_options.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/layout.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/tpu/c_api_decl.h" diff --git a/third_party/xla/xla/stream_executor/tpu/noncopyable_buffer.h b/third_party/xla/xla/stream_executor/tpu/noncopyable_buffer.h index 8e0abf45c88af0..f9c86d59a5b753 100644 --- a/third_party/xla/xla/stream_executor/tpu/noncopyable_buffer.h +++ b/third_party/xla/xla/stream_executor/tpu/noncopyable_buffer.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_STREAM_EXECUTOR_TPU_NONCOPYABLE_BUFFER_H_ #define XLA_STREAM_EXECUTOR_TPU_NONCOPYABLE_BUFFER_H_ +#include #include #include #include @@ -67,10 +68,8 @@ class NoncopyableBuffer { } #endif uint32_t* data_u32 = reinterpret_cast(data_.get()); - uint32_t v = value.value_or(0); - for (uint32_t *p = data_u32, *e = data_u32 + size_in_u32s; p < e; ++p) { - *p = v; - } + uint32_t v = value.value_or(uint32_t{0}); + std::fill_n(data_u32, size_in_u32s, v); } // Directly use buf pointer without copying it to owning data_. This delays diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor.cc b/third_party/xla/xla/stream_executor/tpu/tpu_executor.cc index fdf7fdf67fdc29..75295ea47af979 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executor.cc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor.cc @@ -56,13 +56,6 @@ bool TpuExecutor::SynchronizeAllActivity() { return ExecutorApiFn()->TpuExecutor_SynchronizeAllActivityFn(executor_); } -absl::Status TpuExecutor::BlockHostUntilDone(Stream* stream) { - StatusHelper status; - ExecutorApiFn()->TpuExecutor_BlockHostUntilDoneFn( - executor_, get_stream(stream), status.c_status); - return status.status(); -} - tensorflow::tpu::TpuCoreLocationExternal TpuExecutor::GetCoreLocationExternal() const { return tensorflow::tpu::TpuCoreLocationExternal( diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor.h b/third_party/xla/xla/stream_executor/tpu/tpu_executor.h index 85646afbb68762..65ca29a0b7317f 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executor.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor.h @@ -66,8 +66,6 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface { DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override; - absl::Status BlockHostUntilDone(Stream* stream) override; - absl::StatusOr> CreateDeviceDescription() const override; diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_platform_interface.cc b/third_party/xla/xla/stream_executor/tpu/tpu_platform_interface.cc index 63c83e5696cfc5..d1248356d914a4 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_platform_interface.cc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_platform_interface.cc @@ -18,9 +18,9 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/logging.h" // IWYU pragma: keep -#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace tpu { diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_stream.h b/third_party/xla/xla/stream_executor/tpu/tpu_stream.h index 298773caef6e2a..dfa4b4a97ec0b0 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_stream.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_stream.h @@ -70,6 +70,13 @@ class TpuStream : public tensorflow::tpu::TpuStreamInterface { return status.status(); } + absl::Status BlockHostUntilDone() override { + StatusHelper status; + stream_executor::tpu::ExecutorApiFn()->TpuExecutor_BlockHostUntilDoneFn( + se_executor_, stream_, status.c_status); + return status.status(); + } + absl::Status EnqueueTransferDeviceToHost( stream_executor::DeviceMemoryBase device_src, void* host_dst, uint64_t size) { diff --git a/third_party/xla/xla/stream_executor/trace_command_buffer_factory.cc b/third_party/xla/xla/stream_executor/trace_command_buffer_factory.cc index 0f18b68a08e343..0f2643d77a3c37 100644 --- a/third_party/xla/xla/stream_executor/trace_command_buffer_factory.cc +++ b/third_party/xla/xla/stream_executor/trace_command_buffer_factory.cc @@ -34,7 +34,7 @@ TraceCommandBufferFactory::Create( absl::AnyInvocable function, CommandBuffer::Mode mode) { TF_ASSIGN_OR_RETURN(auto stream, executor->CreateStream()); - stream->set_name("Command buffer tracer"); + stream->SetName("Command buffer tracer"); return TraceCommandBufferFactory::Create(executor, stream.get(), std::move(function), mode); } diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 2d934d6564f792..989d09ff1d658e 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -14,7 +14,7 @@ load( load("//xla:package_groups.bzl", "xla_tests_package_groups") load("//xla:xla.bzl", "tests_build_defs_bzl_deps", "xla_cc_binary", "xla_cc_test") load("//xla/tests:build_defs.bzl", "generate_backend_suites", "generate_backend_test_macros", "xla_test", "xla_test_library") -load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.bzl", "if_google", "internal_visibility") load("//xla/tsl:tsl.default.bzl", "filegroup") package( @@ -72,8 +72,8 @@ cc_library( "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", + "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", - "//xla/service:hlo_dataflow_analysis", "//xla/service:hlo_verifier", "//xla/service:transfer_manager", "@com_google_absl//absl/status", @@ -113,23 +113,10 @@ cc_library( cc_library( name = "verified_hlo_module", testonly = True, - srcs = ["verified_hlo_module.cc"], hdrs = ["verified_hlo_module.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/testlib:verified_hlo_module instead.", deps = [ - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", - "//xla/service:hlo_verifier", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:test", + "//xla/hlo/testlib:verified_hlo_module", ], ) @@ -175,37 +162,106 @@ cc_library( deps = [ ":filecheck", ":literal_test_util", + ":new_hlo_test_base", ":pjrt_client_registry", ":test_utils", ":verified_hlo_module", "//xla:debug_options_flags", - "//xla:shape_layout", + "//xla:error_spec", + "//xla:literal", "//xla:shape_util", - "//xla:types", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", "//xla/hlo/utils:hlo_query", + "//xla/pjrt:pjrt_client", "//xla/service:backend", - "//xla/service:computation_layout", + "//xla/service:computation_placer_hdr", + "//xla/service:executable", + "//xla/service:hlo_module_config", "//xla/service:hlo_module_util", - "//xla/service:hlo_parser", "//xla/service:hlo_runner", "//xla/service:hlo_runner_interface", "//xla/service:hlo_runner_pjrt", "//xla/service:hlo_verifier", "//xla/service:interpreter_plugin", # reference backend "//xla/service:platform_util", - "//xla/stream_executor", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_memory_allocator", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "new_hlo_test_base", + testonly = True, + srcs = ["new_hlo_test_base.cc"], + hdrs = ["new_hlo_test_base.h"], + deps = [ + ":filecheck", + ":literal_test_util", + ":test_utils", + ":verified_hlo_module", + "//xla:debug_options_flags", + "//xla:error_spec", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_layout", + "//xla:shape_util", + "//xla:test_helpers", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_module_group", + "//xla/hlo/pass:hlo_pass", + "//xla/hlo/testlib:hlo_hardware_independent_test_base", + "//xla/hlo/testlib:verified_hlo_module", + "//xla/hlo/utils:hlo_query", + "//xla/service:backend", + "//xla/service:computation_layout", + "//xla/service:computation_placer_hdr", + "//xla/service:executable", + "//xla/service:hlo_module_config", + "//xla/service:hlo_module_util", + "//xla/service:hlo_runner", + "//xla/service:hlo_runner_interface", + "//xla/service:hlo_verifier", + "//xla/service:interpreter_plugin", # reference backend + "//xla/service:platform_util", + "//xla/stream_executor:device_memory_allocator", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], ) @@ -217,8 +273,8 @@ xla_cc_binary( "//xla:types", "//xla:util", "//xla/client:client_library", - "//xla/client:xla_builder", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", "//xla/service:cpu_plugin", "//xla/service/cpu:cpu_compiler", "//xla/service/llvm_ir:llvm_util", @@ -259,11 +315,11 @@ cc_library( "//xla/client:client_library", "//xla/client:global_data", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", "//xla/service:interpreter_plugin", # reference backend "//xla/service:platform_util", - "//xla/stream_executor", + "//xla/stream_executor:stream_executor_h", "//xla/tsl/lib/core:bitmap", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -307,23 +363,10 @@ cc_library( cc_library( name = "filecheck", testonly = True, - srcs = ["filecheck.cc"], hdrs = ["filecheck.h"], - data = [ - "@llvm-project//llvm:FileCheck", - ], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/testlib:filecheck instead.", deps = [ - "//xla:types", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:resource_loader", - "@local_tsl//tsl/platform:subprocess", + "//xla/hlo/testlib:filecheck", ], ) @@ -342,16 +385,16 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/client:client_library", "//xla/client:local_client", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/parser:hlo_parser", "//xla/service:computation_placer", "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", "//xla/service:local_service", "//xla/service:platform_util", "//xla/service:shaped_buffer", "//xla/service:transfer_manager", - "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -374,8 +417,8 @@ xla_test( "//xla:types", "//xla:xla_data_proto_cc", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:logging", ], @@ -384,6 +427,19 @@ xla_test( xla_test( name = "buffer_donation_test", srcs = ["buffer_donation_test.cc"], + backend_args = if_google( + { + "cpu": [ + # TODO(b/372312816): Fix the leak in the test. + "--heap_check=", + ], + "interpreter": [ + # TODO(b/372312816): Fix the leak in the test. + "--heap_check=", + ], + }, + {}, + ), tags = ["test_xla_cpu_thunks"], deps = [ ":hlo_test_base", @@ -419,9 +475,9 @@ xla_test( "//xla:execution_options_util", "//xla:status_macros", "//xla:test", - "//xla/client:xla_computation", - "//xla/service:despecializer", - "//xla/service:float_normalization", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/transforms:despecializer", + "//xla/hlo/transforms:float_normalization", ], ) @@ -439,9 +495,9 @@ xla_test( "//xla:execution_options_util", "//xla:status_macros", "//xla:test", - "//xla/client:xla_computation", - "//xla/service:despecializer", - "//xla/service:float_normalization", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/transforms:despecializer", + "//xla/hlo/transforms:float_normalization", ], ) @@ -464,9 +520,9 @@ xla_test( "//xla:execution_options_util", "//xla:status_macros", "//xla:test", - "//xla/client:xla_computation", - "//xla/service:despecializer", - "//xla/service:float_normalization", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/transforms:despecializer", + "//xla/hlo/transforms:float_normalization", "@com_google_absl//absl/algorithm:container", ], ) @@ -486,7 +542,7 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/client:global_data", "//xla/client:local_client", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "@com_google_absl//absl/status:statusor", ], ) @@ -502,7 +558,7 @@ xla_test( "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/client:local_client", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:test", ], @@ -523,9 +579,9 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/client:client_library", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/client/lib:arithmetic", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/builder/lib:arithmetic", "//xla/service:platform_util", "//xla/stream_executor:stream_executor_memory_allocator", "//xla/tsl/lib/core:status_test_util", @@ -553,8 +609,8 @@ xla_test( "//xla:shape_util", "//xla:util", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", "//xla/service:platform_util", "//xla/service:stream_pool", "//xla/tsl/lib/core:status_test_util", @@ -577,7 +633,7 @@ xla_test( ":test_macros_header", ":xla_internal_test_main", "//xla/client:local_client", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "@local_tsl//tsl/platform:test", ], ) @@ -601,10 +657,10 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/client:global_data", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/client/lib:arithmetic", - "//xla/stream_executor", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/builder/lib:arithmetic", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/status:statusor", ], ) @@ -629,8 +685,8 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/client:global_data", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:protobuf", @@ -647,8 +703,8 @@ xla_test( ":xla_internal_test_main", "//xla:array2d", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client/lib:arithmetic", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder/lib:arithmetic", "//xla/tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:test", ], @@ -665,7 +721,7 @@ xla_test( ":xla_internal_test_main", "//xla/client:global_data", "//xla/client:local_client", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "@local_tsl//tsl/platform:test", ], ) @@ -680,8 +736,8 @@ xla_test( ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", - "//xla/client:xla_builder", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", ], ) @@ -697,7 +753,7 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/client:global_data", "//xla/client:local_client", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "@local_tsl//tsl/platform:test", ], ) @@ -722,8 +778,8 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/client:global_data", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client/lib:math", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder/lib:math", "@local_tsl//tsl/platform:test", ], ) @@ -745,8 +801,8 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/client:global_data", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -766,8 +822,8 @@ xla_test( "//xla:test_helpers", "//xla/client:global_data", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", ], @@ -788,8 +844,8 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/client:global_data", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:test", @@ -820,7 +876,7 @@ xla_test( "//xla:types", "//xla/client:global_data", "//xla/client:local_client", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "@com_google_absl//absl/base", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", @@ -845,9 +901,9 @@ cc_library( "//xla:execution_options_util", "//xla:status_macros", "//xla:test", - "//xla/client:xla_computation", - "//xla/service:despecializer", - "//xla/service:float_normalization", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/transforms:despecializer", + "//xla/hlo/transforms:float_normalization", ], ) @@ -867,7 +923,7 @@ xla_test( "//xla:types", "//xla/client:global_data", "//xla/client:local_client", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "@com_google_absl//absl/base", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -924,19 +980,38 @@ xla_test( ":client_library_test_base", ":hlo_test_base", ":test_macros_header", - ":test_utils", ":xla_internal_test_main", "//xla:array2d", "//xla:array3d", + "//xla:array4d", + "//xla:error_spec", + "//xla:executable_run_options", + "//xla:literal", + "//xla:literal_util", "//xla:reference_util", "//xla:shape_util", + "//xla:test_helpers", + "//xla:types", + "//xla/client:client_library", + "//xla/client:executable_build_options", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client/lib:arithmetic", - "//xla/client/lib:matrix", - "//xla/service:hlo_parser", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder/lib:arithmetic", + "//xla/hlo/builder/lib:matrix", + "//xla/hlo/parser:hlo_parser", + "//xla/service", + "//xla/service:platform_util", + "//xla/service:shaped_buffer", + "//xla/stream_executor:device_description", + "//xla/stream_executor:platform", "//xla/stream_executor:stream_executor_memory_allocator", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:ml_dtypes", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", ] + if_rocm_is_configured([ @@ -963,19 +1038,21 @@ xla_test( ":client_library_test_base", ":hlo_test_base", ":test_macros_header", - ":test_utils", ":xla_internal_test_main", "//xla:array2d", "//xla:array3d", + "//xla:error_spec", + "//xla:literal_util", "//xla:reference_util", "//xla:shape_util", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client/lib:arithmetic", - "//xla/client/lib:matrix", - "//xla/service:hlo_parser", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder/lib:arithmetic", + "//xla/hlo/builder/lib:matrix", + "//xla/hlo/parser:hlo_parser", "//xla/stream_executor:stream_executor_memory_allocator", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", ] + if_rocm_is_configured([ @@ -1007,19 +1084,21 @@ xla_test( ":client_library_test_base", ":hlo_test_base", ":test_macros_header", - ":test_utils", ":xla_internal_test_main", "//xla:array2d", "//xla:array3d", + "//xla:error_spec", + "//xla:literal_util", "//xla:reference_util", "//xla:shape_util", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client/lib:arithmetic", - "//xla/client/lib:matrix", - "//xla/service:hlo_parser", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder/lib:arithmetic", + "//xla/hlo/builder/lib:matrix", + "//xla/hlo/parser:hlo_parser", "//xla/stream_executor:stream_executor_memory_allocator", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", ] + if_rocm_is_configured([ @@ -1046,7 +1125,7 @@ xla_test( "//xla:literal_util", "//xla:status_macros", "//xla:test", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "//xla/service", ], ) @@ -1090,19 +1169,21 @@ xla_test( ":client_library_test_base", ":hlo_test_base", ":test_macros_header", - ":test_utils", ":xla_internal_test_main", "//xla:array2d", "//xla:array3d", + "//xla:error_spec", + "//xla:literal_util", "//xla:reference_util", "//xla:shape_util", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client/lib:arithmetic", - "//xla/client/lib:matrix", - "//xla/service:hlo_parser", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder/lib:arithmetic", + "//xla/hlo/builder/lib:matrix", + "//xla/hlo/parser:hlo_parser", "//xla/stream_executor:stream_executor_memory_allocator", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", ] + if_rocm_is_configured([ @@ -1125,7 +1206,7 @@ xla_test( "//xla:literal_util", "//xla:reference_util", "//xla:util", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], @@ -1146,10 +1227,12 @@ xla_test( "//xla:array3d", "//xla:array4d", "//xla:literal_util", + "//xla:types", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client/lib:constants", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder/lib:constants", "//xla/tsl/lib/core:status_test_util", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:test", ], @@ -1167,8 +1250,8 @@ CONVOLUTION_TEST_DEPS = [ "//xla:xla_data_proto_cc", "//xla/client:global_data", "//xla/client:local_client", - "//xla/client:padding", - "//xla/client:xla_builder", + "//xla/hlo/builder:padding", + "//xla/hlo/builder:xla_builder", "//xla/tests:client_library_test_base", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", @@ -1209,7 +1292,7 @@ xla_test( ], shard_count = 50, tags = [ - "no_rocm", + "cuda-only", "optonly", ], deps = CONVOLUTION_TEST_DEPS + [ @@ -1228,7 +1311,7 @@ xla_test( ], shard_count = 50, tags = [ - "no_rocm", + "cuda-only", "optonly", "test_xla_cpu_thunks", ], @@ -1249,7 +1332,7 @@ xla_test( backends = ["gpu"], shard_count = 40, tags = [ - "no_rocm", + "cuda-only", "optonly", ], deps = CONVOLUTION_TEST_DEPS + [ @@ -1266,7 +1349,7 @@ xla_test( backends = ["gpu"], shard_count = 40, tags = [ - "no_rocm", + "cuda-only", "optonly", ], deps = CONVOLUTION_TEST_DEPS + [ @@ -1295,7 +1378,7 @@ xla_test( backend_args = {"gpu": ["--xla_backend_extra_options=xla_gpu_experimental_conv_disable_layout_heuristic"]}, backends = ["gpu"], shard_count = 25, - tags = ["no_rocm"], + tags = ["cuda-only"], deps = CONVOLUTION_TEST_DEPS + [ "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -1340,8 +1423,8 @@ xla_test( "//xla:reference_util", "//xla:xla_data_proto_cc", "//xla/client:local_client", - "//xla/client:padding", - "//xla/client:xla_builder", + "//xla/hlo/builder:padding", + "//xla/hlo/builder:xla_builder", "@local_tsl//tsl/platform:test", ], ) @@ -1361,8 +1444,8 @@ xla_test( "//xla:reference_util", "//xla:test", "//xla/client:local_client", - "//xla/client:padding", - "//xla/client:xla_builder", + "//xla/hlo/builder:padding", + "//xla/hlo/builder:xla_builder", "@com_google_absl//absl/status:statusor", ], ) @@ -1415,14 +1498,14 @@ xla_test( "//xla:util", "//xla:xla_data_proto_cc", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/client/lib:arithmetic", - "//xla/client/lib:math", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/builder/lib:arithmetic", + "//xla/hlo/builder/lib:math", "//xla/hlo/ir:hlo", + "//xla/tsl/lib/math:math_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/math:math_util", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:test", ], @@ -1449,8 +1532,8 @@ xla_test( "//xla:test_helpers", "//xla:util", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client/lib:arithmetic", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder/lib:arithmetic", "//xla/hlo/ir:hlo", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:test", @@ -1465,7 +1548,7 @@ xla_test( ":client_library_test_base", ":xla_internal_test_main", "//xla:test", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:ml_dtypes", ], @@ -1487,7 +1570,7 @@ xla_test( "//xla:literal", "//xla:test", "//xla:test_helpers", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "@com_google_absl//absl/status:statusor", ], ) @@ -1506,7 +1589,7 @@ xla_test( ":hlo_test_base", ":xla_internal_test_main", "//xla:test", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "@local_tsl//tsl/platform:errors", ], ) @@ -1525,7 +1608,7 @@ xla_test( "//xla:array2d", "//xla:reference_util", "//xla/client:local_client", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -1546,7 +1629,7 @@ xla_test( "//xla:array2d", "//xla:array3d", "//xla/client:local_client", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "@local_tsl//tsl/platform:test", ], ) @@ -1568,14 +1651,14 @@ xla_test( "//xla:test_helpers", "//xla/client:client_library", "//xla/client:local_client", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "//xla/service:computation_placer", "//xla/service:local_service", "//xla/service:platform_util", "//xla/service:shaped_buffer", "//xla/service:transfer_manager", - "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", @@ -1598,9 +1681,9 @@ xla_test( "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/service:hlo_parser", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/parser:hlo_parser", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:test", @@ -1620,8 +1703,8 @@ xla_test( "//xla:array3d", "//xla:xla_data_proto_cc", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client/lib:arithmetic", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder/lib:arithmetic", "@local_tsl//tsl/platform:test", ], ) @@ -1650,9 +1733,9 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/client:global_data", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/client/lib:arithmetic", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/builder/lib:arithmetic", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", @@ -1691,24 +1774,28 @@ xla_test_library( deps = [ ":client_library_test_base", ":hlo_test_base", - ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", "//xla:array2d", "//xla:array3d", "//xla:array4d", + "//xla:error_spec", + "//xla:literal", + "//xla:literal_util", "//xla:reference_util", "//xla:shape_util", "//xla:xla_data_proto_cc", - "//xla/client:local_client", - "//xla/client:padding", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/client/lib:arithmetic", + "//xla/hlo/builder:padding", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/builder/lib:arithmetic", "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:status", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], ) @@ -1747,10 +1834,10 @@ xla_test( "//xla:error_spec", "//xla:reference_util", "//xla:xla_data_proto_cc", - "//xla/client:padding", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/client/lib:arithmetic", + "//xla/hlo/builder:padding", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/builder/lib:arithmetic", "@local_tsl//tsl/platform:test", ], ) @@ -1774,8 +1861,10 @@ xla_test( "//xla:literal_util", "//xla:shape_util", "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "//xla/hlo/ir:hlo", + "//xla/service:platform_util", + "//xla/stream_executor:platform", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:test", @@ -1791,7 +1880,15 @@ xla_test( ":test_macros_header", ":test_utils", ":xla_internal_test_main", + "//xla:error_spec", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], ) @@ -1805,6 +1902,8 @@ xla_test( ":test_macros_header", ":xla_internal_test_main", "//xla:error_spec", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_googletest//:gtest_main", ], ) @@ -1872,8 +1971,8 @@ xla_test( "//xla:shape_util", "//xla:test_helpers", "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", "@local_tsl//tsl/platform:test", ], ) @@ -1899,10 +1998,10 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/client:client_library", "//xla/client:local_client", - "//xla/client:xla_builder", "//xla/ffi", "//xla/ffi:execution_context", "//xla/ffi:ffi_api", + "//xla/hlo/builder:xla_builder", "//xla/hlo/ir:hlo", "//xla/service", "//xla/service:custom_call_status", @@ -1938,7 +2037,7 @@ xla_test( "//xla:array4d", "//xla:reference_util", "//xla/client:local_client", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "@local_tsl//tsl/platform:test", ], ) @@ -1958,7 +2057,7 @@ xla_test( "//xla:literal_util", "//xla:test", "//xla/client:local_client", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", ], @@ -1970,16 +2069,20 @@ xla_test( tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", - ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", "//xla:array2d", + "//xla:array3d", "//xla:array4d", + "//xla:error_spec", + "//xla:literal_util", "//xla:reference_util", + "//xla:shape_util", + "//xla:util", "//xla:xla_data_proto_cc", - "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client/lib:arithmetic", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/builder/lib:arithmetic", "@local_tsl//tsl/platform:test", ], ) @@ -1993,7 +2096,7 @@ xla_test( ":test_macros_header", ":xla_internal_test_main", "//xla:error_spec", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "//xla/service", ], ) @@ -2008,7 +2111,7 @@ xla_test( ":test_macros_header", ":xla_internal_test_main", "//xla/client:local_client", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "@local_tsl//tsl/platform:test", ], ) @@ -2031,8 +2134,8 @@ xla_test( "//xla:test_helpers", "//xla:xla_data_proto_cc", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -2065,7 +2168,7 @@ xla_test( "//xla:util", "//xla:xla_data_proto_cc", "//xla/client:local_client", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:protobuf", @@ -2083,8 +2186,8 @@ xla_test( "//xla:literal", "//xla:literal_util", "//xla/hlo/ir:hlo", - "//xla/service:rng_bit_generator_expander", - "//xla/service:rng_expander", + "//xla/hlo/transforms:rng_bit_generator_expander", + "//xla/hlo/transforms:rng_expander", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", @@ -2105,20 +2208,25 @@ xla_test( ":test_macros_header", ":xla_internal_test_main", "//xla:array2d", + "//xla:array3d", "//xla:array4d", "//xla:error_spec", + "//xla:literal", "//xla:literal_util", "//xla:reference_util", "//xla:shape_util", - "//xla:status_macros", "//xla:test", + "//xla:types", "//xla:xla_data_proto_cc", - "//xla/client:global_data", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:ml_dtypes", + "@local_tsl//tsl/platform:statusor", ], ) @@ -2146,15 +2254,18 @@ xla_test( tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", - ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", - "//xla:array2d", "//xla:array4d", - "//xla/client:local_client", - "//xla/client:xla_builder", + "//xla:error_spec", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/builder:xla_builder", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:test", ], ) @@ -2166,9 +2277,9 @@ xla_test( tags = ["test_xla_cpu_thunks"], deps = [ ":hlo_test_base", - ":test_utils", "//xla:error_spec", "//xla:literal", + "//xla:literal_util", "//xla:shape_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", @@ -2193,10 +2304,9 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/client:global_data", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/client/lib:arithmetic", - "//xla/stream_executor", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/builder/lib:arithmetic", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:test", ], @@ -2219,8 +2329,8 @@ xla_test( "//xla:test", "//xla:test_helpers", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:test", ], @@ -2238,7 +2348,7 @@ xla_test( "//xla:types", "//xla:xla_data_proto_cc", "//xla/client:local_client", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@local_tsl//tsl/platform:ml_dtypes", @@ -2317,15 +2427,10 @@ xla_test( "multi_gpu", "no_oss", ], - "cpu": [ - "notsan", - ], }, backends = [ "gpu", - "cpu", ], - tags = ["test_xla_cpu_thunks"], deps = [ ":hlo_test_base", ":literal_test_util", @@ -2339,6 +2444,7 @@ xla_test( "//xla:shape_util", "//xla:statusor", "//xla/hlo/ir:hlo", + "//xla/service:computation_placer", "//xla/service:executable", "//xla/service:hlo_module_config", "@com_google_absl//absl/log", @@ -2373,7 +2479,9 @@ xla_test( "//xla/hlo/utils:hlo_matchers", "//xla/service:hlo_module_config", "//xla/service/gpu:backend_configs_cc", - "//xla/stream_executor", + "//xla/stream_executor:device_description", + "//xla/stream_executor:stream_executor_h", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -2390,10 +2498,10 @@ xla_test( "//xla:error_spec", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/transforms:hlo_dce", "//xla/service:collective_pipeliner", - "//xla/service:hlo_dce", - "//xla/service:hlo_parser", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", @@ -2436,8 +2544,7 @@ xla_test( "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/stream_executor", + "//xla/hlo/builder:xla_builder", "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:test", ], @@ -2469,7 +2576,7 @@ xla_test( ":test_macros_header", ":xla_internal_test_main", "//xla:error_spec", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:logging", @@ -2521,11 +2628,11 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/client:client_library", "//xla/client:global_data", - "//xla/client:value_inference", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/client/lib:arithmetic", - "//xla/client/lib:prng", + "//xla/hlo/builder:value_inference", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", + "//xla/hlo/builder/lib:arithmetic", + "//xla/hlo/builder/lib:prng", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -2551,8 +2658,8 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/client:client_library", "//xla/client:global_data", - "//xla/client:xla_builder", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -2575,8 +2682,8 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/client:global_data", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:test", ], @@ -2597,8 +2704,8 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/client:global_data", "//xla/client:local_client", - "//xla/client:xla_builder", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", "//xla/service:hlo_proto_cc", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:test", @@ -2641,7 +2748,8 @@ xla_test( "//xla/hlo/ir:hlo_module_group", "//xla/service:backend", "//xla/service:llvm_compiler", - "//xla/stream_executor", + "//xla/stream_executor:device_description", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/status", "@llvm-project//llvm:Core", "@local_tsl//tsl/platform:casts", @@ -2694,7 +2802,7 @@ xla_test( "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/client:client_library", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "//xla/hlo/ir:hlo", "//xla/service:platform_util", "//xla/stream_executor:stream_executor_memory_allocator", @@ -2763,7 +2871,7 @@ xla_test( ":xla_internal_test_main", "//xla:literal", "//xla/client:local_client", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "//xla/service:local_service", "//xla/service:shaped_buffer", "@com_google_absl//absl/status:statusor", @@ -2793,18 +2901,19 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/client:client_library", "//xla/client:local_client", - "//xla/client:sharding_builder", - "//xla/client:xla_builder", + "//xla/hlo/builder:sharding_builder", + "//xla/hlo/builder:xla_builder", "//xla/service:platform_util", "//xla/service:shaped_buffer", "//xla/service:transfer_manager", - "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:platform_manager", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", "//xla/stream_executor/host:host_platform", "//xla/stream_executor/host:host_platform_id", "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:test_benchmark", ], @@ -2836,7 +2945,7 @@ xla_cc_test( ":local_client_test_base", "//xla:test_helpers", "//xla/client:local_client", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "//xla/service:cpu_plugin", "//xla/service:local_service", "@local_tsl//tsl/platform:test_main", @@ -2879,7 +2988,7 @@ xla_test( "//xla:test_helpers", "//xla/client:global_data", "//xla/client:local_client", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:test", @@ -2893,7 +3002,7 @@ xla_test( deps = [ ":client_library_test_base", ":xla_internal_test_main", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", ], ) @@ -2928,12 +3037,12 @@ xla_test( "//xla:shape_util", "//xla:types", "//xla:xla_data_proto_cc", + "//xla/hlo/parser:hlo_parser", "//xla/service:generic_transfer_manager", - "//xla/service:hlo_parser", "//xla/service:shaped_buffer", "//xla/service:stream_pool", - "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:test_benchmark", @@ -2989,8 +3098,8 @@ xla_test( ":test_utils", ":xla_internal_test_main", "//xla:shape_util", - "//xla/client:xla_builder", - "//xla/service:hlo_parser", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/parser:hlo_parser", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_set", @@ -2999,11 +3108,13 @@ xla_test( xla_test( name = "iota_test", + timeout = "long", srcs = ["iota_test.cc"], + backend_tags = { + "cpu": ["optonly"], + }, shard_count = 50, tags = [ - # Require optimized builds, iota_test_cpu is very slow in fastbuild. - "optonly", "test_xla_cpu_thunks", ], deps = [ @@ -3015,7 +3126,7 @@ xla_test( "//xla:shape_util", "//xla:types", "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "@com_google_absl//absl/log:check", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:ml_dtypes", @@ -3031,7 +3142,7 @@ xla_cc_test( ":xla_internal_test_main", # fixdeps: keep "//xla:shape_util", "//xla/client:client_library", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "//xla/service:cpu_plugin", "//xla/stream_executor:platform_manager", "//xla/tsl/lib/core:status_test_util", @@ -3114,9 +3225,9 @@ xla_test( "//xla:test", "//xla:types", "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/client/lib:math", - "//xla/client/lib:matrix", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder/lib:math", + "//xla/hlo/builder/lib:matrix", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -3140,9 +3251,9 @@ xla_test( "//xla:literal", "//xla:test", "//xla:types", - "//xla/client:xla_builder", - "//xla/client/lib:arithmetic", - "//xla/client/lib:matrix", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder/lib:arithmetic", + "//xla/hlo/builder/lib:matrix", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", ], diff --git a/third_party/xla/xla/tests/array_elementwise_ops_test.cc b/third_party/xla/xla/tests/array_elementwise_ops_test.cc index 6000ed029bccc5..c12ce79a06e8fa 100644 --- a/third_party/xla/xla/tests/array_elementwise_ops_test.cc +++ b/third_party/xla/xla/tests/array_elementwise_ops_test.cc @@ -36,9 +36,9 @@ limitations under the License. #include "xla/array4d.h" #include "xla/client/global_data.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" #include "xla/comparison_util.h" #include "xla/fp_util.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/primitive_util.h" @@ -1423,7 +1423,8 @@ class TotalOrderTest : public ClientLibraryTestBase { } }; -using Types = ::testing::Types #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/bad_rng_shape_validation_test.cc b/third_party/xla/xla/tests/bad_rng_shape_validation_test.cc index 5e655d0e1400d7..cb077b05dda71d 100644 --- a/third_party/xla/xla/tests/bad_rng_shape_validation_test.cc +++ b/third_party/xla/xla/tests/bad_rng_shape_validation_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/types.h" diff --git a/third_party/xla/xla/tests/batch_normalization_test.cc b/third_party/xla/xla/tests/batch_normalization_test.cc index c388c0fd72b89c..3b6aebc95cb05d 100644 --- a/third_party/xla/xla/tests/batch_normalization_test.cc +++ b/third_party/xla/xla/tests/batch_normalization_test.cc @@ -21,11 +21,11 @@ limitations under the License. #include "absl/strings/str_join.h" #include "xla/array2d.h" #include "xla/array4d.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/math.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -39,10 +39,10 @@ limitations under the License. #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" +#include "xla/tsl/lib/math/math_util.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/math/math_util.h" #include "tsl/platform/logging.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/tests/bfloat16_test.cc b/third_party/xla/xla/tests/bfloat16_test.cc index 6f5f132d1eba4c..22085485fde573 100644 --- a/third_party/xla/xla/tests/bfloat16_test.cc +++ b/third_party/xla/xla/tests/bfloat16_test.cc @@ -20,9 +20,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/array2d.h" #include "xla/array4d.h" -#include "xla/client/lib/arithmetic.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" diff --git a/third_party/xla/xla/tests/binop_scaling_test.cc b/third_party/xla/xla/tests/binop_scaling_test.cc index 8205318b40f154..6aab7717b2a1f7 100644 --- a/third_party/xla/xla/tests/binop_scaling_test.cc +++ b/third_party/xla/xla/tests/binop_scaling_test.cc @@ -16,7 +16,7 @@ limitations under the License. #include "xla/array2d.h" #include "xla/array4d.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/reference_util.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" diff --git a/third_party/xla/xla/tests/bitcast_convert_test.cc b/third_party/xla/xla/tests/bitcast_convert_test.cc index 78c6b435fe5c97..0deca38d59c163 100644 --- a/third_party/xla/xla/tests/bitcast_convert_test.cc +++ b/third_party/xla/xla/tests/bitcast_convert_test.cc @@ -19,9 +19,8 @@ limitations under the License. #include #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape_util.h" -#include "xla/stream_executor/stream_executor.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" diff --git a/third_party/xla/xla/tests/broadcast_simple_test.cc b/third_party/xla/xla/tests/broadcast_simple_test.cc index caf0a57b8f6254..2876714ab94e02 100644 --- a/third_party/xla/xla/tests/broadcast_simple_test.cc +++ b/third_party/xla/xla/tests/broadcast_simple_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "xla/array2d.h" #include "xla/array4d.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/test.h" diff --git a/third_party/xla/xla/tests/build_defs.bzl b/third_party/xla/xla/tests/build_defs.bzl index 990daa1423aeaa..711301808c345a 100644 --- a/third_party/xla/xla/tests/build_defs.bzl +++ b/third_party/xla/xla/tests/build_defs.bzl @@ -47,7 +47,7 @@ def prepare_nvidia_gpu_backend_data(backends, disabled_backends, backend_tags, b new_disabled_backends.extend(NVIDIA_GPU_BACKENDS) new_backend_tags = {key: value for key, value in backend_tags.items() if key != "gpu"} - gpu_backend_tags = backend_tags.get("gpu", []) + gpu_backend_tags = backend_tags.get("gpu", tf_gpu_tests_tags()) for key in NVIDIA_GPU_BACKENDS: new_backend_tags.setdefault(key, gpu_backend_tags[:]) @@ -97,7 +97,7 @@ def prepare_nvidia_gpu_backend_data(backends, disabled_backends, backend_tags, b sm_tag += ":%d" % num_gpus new_backend_tags[gpu_backend] = [t for t in all_tags if t not in requires_gpu] new_backend_tags[gpu_backend].append(sm_tag) - new_backend_tags[gpu_backend].append("no_rocm") + new_backend_tags[gpu_backend].append("cuda-only") return new_backends, new_disabled_backends, new_backend_tags, new_backend_args @@ -130,9 +130,10 @@ def prepare_amd_gpu_backend_data(backends, disabled_backends, backend_tags, back new_backend_tags.setdefault(key, gpu_backend_tags[:]) for backend in AMD_GPU_DEFAULT_BACKENDS: - if "no_rocm" not in gpu_backend_tags: + if "cuda-only" not in gpu_backend_tags: new_backend_tags[backend].append("requires-gpu-amd") new_backend_tags[backend].append("notap") + new_backend_tags[backend].append("rocm-only") return new_backends, new_disabled_backends, new_backend_tags, backend_args diff --git a/third_party/xla/xla/tests/call_test.cc b/third_party/xla/xla/tests/call_test.cc index 4bc38c213b384d..36aae1aed51de1 100644 --- a/third_party/xla/xla/tests/call_test.cc +++ b/third_party/xla/xla/tests/call_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/tests/check_execution_arity_test.cc b/third_party/xla/xla/tests/check_execution_arity_test.cc index 20f2083ee0e141..fd0f5bd9bf75e0 100644 --- a/third_party/xla/xla/tests/check_execution_arity_test.cc +++ b/third_party/xla/xla/tests/check_execution_arity_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/client/global_data.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/shape_util.h" #include "xla/test.h" diff --git a/third_party/xla/xla/tests/cholesky_test.cc b/third_party/xla/xla/tests/cholesky_test.cc index 9215319bbf8e40..c52ea4b9ea849c 100644 --- a/third_party/xla/xla/tests/cholesky_test.cc +++ b/third_party/xla/xla/tests/cholesky_test.cc @@ -20,9 +20,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/array2d.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" diff --git a/third_party/xla/xla/tests/client_library_test_base.cc b/third_party/xla/xla/tests/client_library_test_base.cc index 743db05f93f73b..01944740eab9ec 100644 --- a/third_party/xla/xla/tests/client_library_test_base.cc +++ b/third_party/xla/xla/tests/client_library_test_base.cc @@ -24,13 +24,14 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "xla/client/client_library.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" #include "xla/execution_options_util.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "xla/service/platform_util.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/test_helpers.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" namespace xla { @@ -291,7 +292,7 @@ absl::StatusOr ClientLibraryTestBase::ComputeAndTransfer( for (const auto& argument : arguments_) { TF_ASSIGN_OR_RETURN( std::unique_ptr owned_argument, - client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))); + client_->TransferToServer(MaybeConvertLiteralToTestType(argument))); owning_arguments.push_back(std::move(owned_argument)); arguments.push_back(owning_arguments.back().get()); } @@ -315,7 +316,7 @@ absl::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( for (const auto& argument : arguments_) { TF_ASSIGN_OR_RETURN( std::unique_ptr owned_argument, - client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))); + client_->TransferToServer(MaybeConvertLiteralToTestType(argument))); owning_arguments.push_back(std::move(owned_argument)); arguments.push_back(owning_arguments.back().get()); } @@ -326,20 +327,20 @@ absl::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( ShapeUtil::ElementIsComplex(expected.shape())) { LOG(WARNING) << "performing exact comparison of floating point numbers"; } - // We allow using a float expected literal for a bfloat16 output. In this - // case, we need to convert the expected literal to bfloat16. + // We allow using a float expected literal for non float outputs. In this + // case, we need to convert the expected literal to test_type_. const Literal* expected_ptr = &expected; Literal converted_expected; Shape layout_shape; - if (use_bfloat16_) { - converted_expected = LiteralUtil::ConvertF32ToBF16(expected); + if (test_type_ != F32) { + converted_expected = MaybeConvertLiteralToTestType(expected); expected_ptr = &converted_expected; if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; ShapeUtil::ForEachMutableSubshape( &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) { if (subshape->element_type() == F32) { - subshape->set_element_type(BF16); + subshape->set_element_type(test_type_); } }); shape_with_layout = &layout_shape; @@ -377,27 +378,27 @@ absl::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( for (const auto& argument : arguments_) { TF_ASSIGN_OR_RETURN( std::unique_ptr owned_argument, - client_->TransferToServer(MaybeConvertLiteralToBfloat16(argument))); + client_->TransferToServer(MaybeConvertLiteralToTestType(argument))); owning_arguments.push_back(std::move(owned_argument)); arguments.push_back(owning_arguments.back().get()); } } TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); - // We allow using a float expected literal for a bfloat16 output. In this - // case, we need to convert the expected literal to bfloat16. + // We allow using a float expected literal for a non float outputs. In this + // case, we need to convert the expected literal to type_test_. const Literal* expected_ptr = &expected; Literal converted_expected; Shape layout_shape; - if (use_bfloat16_) { - converted_expected = LiteralUtil::ConvertF32ToBF16(expected); + if (test_type_ != F32) { + converted_expected = MaybeConvertLiteralToTestType(expected); expected_ptr = &converted_expected; if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; ShapeUtil::ForEachMutableSubshape( &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) { if (subshape->element_type() == F32) { - subshape->set_element_type(BF16); + subshape->set_element_type(test_type_); } }); shape_with_layout = &layout_shape; @@ -535,13 +536,11 @@ ClientLibraryTestBase::ComputeValueAndReference( return std::make_pair(std::move(reference), std::move(result)); } -XlaComputation ClientLibraryTestBase::CreateScalarRelu() { +XlaComputation ClientLibraryTestBase::CreateScalarReluF32() { XlaBuilder builder("relu"); - auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); + auto shape = ShapeUtil::MakeShape(F32, {}); auto z_value = Parameter(&builder, 0, shape, "z_value"); - auto zero = use_bfloat16_ - ? ConstantR0(&builder, static_cast(0.0f)) - : ConstantR0(&builder, 0.0f); + auto zero = ConstantR0(&builder, 0.0f); Max(z_value, zero); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); @@ -550,7 +549,7 @@ XlaComputation ClientLibraryTestBase::CreateScalarRelu() { XlaComputation ClientLibraryTestBase::CreateScalarMax() { XlaBuilder builder("max"); - auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); + auto shape = ShapeUtil::MakeShape(test_type_, {}); auto x = Parameter(&builder, 0, shape, "x"); auto y = Parameter(&builder, 1, shape, "y"); Max(x, y); @@ -559,22 +558,6 @@ XlaComputation ClientLibraryTestBase::CreateScalarMax() { return std::move(computation_status).value(); } -XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() { - XlaBuilder builder("relu_sensitivity"); - auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); - auto activation = Parameter(&builder, 0, shape, "activation"); - auto backprop = Parameter(&builder, 1, shape, "backprop"); - auto zero = use_bfloat16_ - ? ConstantR0(&builder, static_cast(0.0f)) - : ConstantR0(&builder, 0.0f); - auto activation_gtz = Gt(activation, zero); - Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero); - - auto computation_status = builder.Build(); - TF_CHECK_OK(computation_status.status()); - return std::move(computation_status).value(); -} - std::unique_ptr> ClientLibraryTestBase::CreatePatternedMatrix( int rows, int cols, float offset) { auto array = std::make_unique>(rows, cols); @@ -605,14 +588,12 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument, XlaBuilder* builder) { arguments_.push_back(argument.Clone()); return Parameter(builder, /*parameter_number=*/arguments_.size() - 1, - MaybeConvertShapeToBfloat16(argument.shape()), ""); + MaybeConvertShapeToTestType(argument.shape()), ""); } XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder) { - return ConstantLiteral(builder, use_bfloat16_ - ? LiteralUtil::ConvertF32ToBF16(literal) - : LiteralSlice(literal)); + return ConstantLiteral(builder, MaybeConvertLiteralToTestType(literal)); } absl::StatusOr> @@ -623,26 +604,34 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral( nullptr, builder, data_handle); } -Shape ClientLibraryTestBase::MaybeConvertShapeToBfloat16(const Shape& shape) { - if (!use_bfloat16_) { +Shape ClientLibraryTestBase::MaybeConvertShapeToTestType(const Shape& shape) { + if (test_type_ == F32) { return shape; } Shape new_shape = shape; - ShapeUtil::ForEachMutableSubshape(&new_shape, - [](Shape* subshape, const ShapeIndex&) { - if (subshape->element_type() == F32) { - subshape->set_element_type(BF16); - } - }); + ShapeUtil::ForEachMutableSubshape( + &new_shape, [test_type = test_type_](Shape* subshape, const ShapeIndex&) { + if (subshape->element_type() == F32) { + subshape->set_element_type(test_type); + } + }); return new_shape; } -Literal ClientLibraryTestBase::MaybeConvertLiteralToBfloat16( +Literal ClientLibraryTestBase::MaybeConvertLiteralToTestType( const Literal& literal) { - if (use_bfloat16_) { - return LiteralUtil::ConvertF32ToBF16(literal); + switch (test_type_) { + case BF16: + return LiteralUtil::ConvertF32ToBF16(literal); + case F32: + return literal.Clone(); + case F8E5M2: + return LiteralUtil::ConvertF32ToF8E5M2(literal); + case F8E4M3FN: + return LiteralUtil::ConvertF32ToF8E4M3FN(literal); + default: + LOG(FATAL) << "Unsupported test type: " << test_type_; } - return literal.Clone(); } absl::StatusOr> @@ -650,7 +639,7 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral( int64_t parameter_number, const Literal& literal, const std::string& name, const DeviceHandle* device_handle, XlaBuilder* builder, XlaOp* data_handle) { - Literal param_literal = MaybeConvertLiteralToBfloat16(literal); + Literal param_literal = MaybeConvertLiteralToTestType(literal); TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(param_literal, device_handle)); *data_handle = diff --git a/third_party/xla/xla/tests/client_library_test_base.h b/third_party/xla/xla/tests/client_library_test_base.h index 9185a31d7f6211..016b73b6e7682d 100644 --- a/third_party/xla/xla/tests/client_library_test_base.h +++ b/third_party/xla/xla/tests/client_library_test_base.h @@ -29,8 +29,8 @@ limitations under the License. #include "xla/array4d.h" #include "xla/client/client_library.h" #include "xla/client/global_data.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/stream_executor/stream_executor.h" @@ -44,18 +44,15 @@ limitations under the License. namespace xla { -// Sets the use_bfloat16 on a container of test cases according to the values in -// use_bfloat16_params. Generates one set of test cases for each values in -// use_bfloat16_params with that value. Returns the result. template -std::vector ExpandUseBfloat16( - absl::Span use_bfloat16_params, +std::vector ExpandTestType( + absl::Span test_type_params, absl::Span specs) { std::vector expanded; - for (bool use_bfloat16 : use_bfloat16_params) { + for (const PrimitiveType test_type : test_type_params) { for (const auto& spec : specs) { expanded.push_back(spec); - expanded.back().use_bfloat16 = use_bfloat16; + expanded.back().test_type = test_type; } } return expanded; @@ -236,9 +233,8 @@ class ClientLibraryTestBase : public ::testing::Test { absl::Span arguments, ErrorSpec error); // Create scalar operations for use in reductions. - XlaComputation CreateScalarRelu(); + XlaComputation CreateScalarReluF32(); XlaComputation CreateScalarMax(); - XlaComputation CreateScalarReluSensitivity(); // Special case convenience functions for creating filled arrays. @@ -276,8 +272,8 @@ class ClientLibraryTestBase : public ::testing::Test { // Creates a parameter instruction, transfers the literal for the parameter to // server, then stores into "data_handle" the global handle for that - // parameter. When the use_bfloat16 flag is set but the literal has F32 - // elements, the literal will be converted to BF16 before being transferred. + // parameter. When the test_type is bfloat16 but the literal has F32 elements, + // the literal will be converted to test_type_ before being transferred. absl::StatusOr> CreateParameterAndTransferLiteral( int64_t parameter_number, const Literal& literal, const std::string& name, XlaBuilder* builder, XlaOp* data_handle); @@ -302,15 +298,13 @@ class ClientLibraryTestBase : public ::testing::Test { return AddParam(LiteralUtil::CreateFromArray(argument), builder); } - // Creates a constant instruction with the given literal. When the - // use_bfloat16 flag is set but the literal has F32 elements, the elements - // will be converted to BF16s. + // Creates a constant instruction with the given literal. When the test_type + // is bfloat16 but the literal has F32 elements, the literal will be converted + // to test_type_ before being transferred. XlaOp CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder); - // Creates a constant instruction with the given array. When the use_bfloat16 - // flag is set but the array has float elements, the elements will be - // converted to bfloat16s. - + // Creates a constant instruction with the given array. When the test_type is + // bfloat16, the elements will be converted to bfloat16s. template XlaOp CreateConstantFromArray(const Array& array, XlaBuilder* builder) { @@ -331,7 +325,7 @@ class ClientLibraryTestBase : public ::testing::Test { // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. // - // When the use_bfloat16 flag is set but NativeT is float, the data will be + // When the test_type is bfloat16 but NativeT is float, the data will be // converted to bfloat16. template std::unique_ptr CreateR0Parameter(NativeT value, @@ -346,7 +340,7 @@ class ClientLibraryTestBase : public ::testing::Test { // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. // - // When the use_bfloat16 flag is set but NativeT is float, the data will be + // When the test_type is bfloat16 but NativeT is float, the data will be // converted to bfloat16. template std::unique_ptr CreateR1Parameter( @@ -360,7 +354,7 @@ class ClientLibraryTestBase : public ::testing::Test { // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. // - // When the use_bfloat16 flag is set but NativeT is float, the data will be + // When the test_type is bfloat16 but NativeT is float, the data will be // converted to bfloat16. template std::unique_ptr CreateR2Parameter( @@ -374,7 +368,7 @@ class ClientLibraryTestBase : public ::testing::Test { // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. // - // When the use_bfloat16 flag is set but NativeT is float, the data will be + // When the test_type is bfloat16 but NativeT is float, the data will be // converted to bfloat16. template std::unique_ptr CreateR3Parameter( @@ -388,7 +382,7 @@ class ClientLibraryTestBase : public ::testing::Test { // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. // - // When the use_bfloat16 flag is set but NativeT is float, the data will be + // When the test_type is bfloat16 but NativeT is float, the data will be // converted to bfloat16. template std::unique_ptr CreateR4Parameter( @@ -402,13 +396,9 @@ class ClientLibraryTestBase : public ::testing::Test { XlaBuilder* builder, XlaOp* data_handle); - // Getter and setter for the use_bfloat16 flag, which indicates whether to run - // tests with all float-type input/output converted to bfloat16. - bool use_bfloat16() const { return use_bfloat16_; } - void set_use_bfloat16(bool value) { use_bfloat16_ = value; } - - // The float type used in this test, BF16 or F32 according to use_bfloat16. - PrimitiveType FloatType() const { return use_bfloat16_ ? BF16 : F32; } + // The float type used in this test. + PrimitiveType FloatType() const { return test_type_; } + void set_float_type(PrimitiveType type) { test_type_ = type; } // Executes the computation and calculates the expected reference value using // the reference client. Returns two literals in the order of (expected, @@ -416,8 +406,8 @@ class ClientLibraryTestBase : public ::testing::Test { absl::StatusOr> ComputeValueAndReference( XlaBuilder* builder, absl::Span arguments); - // Converts an f32 literal to bf16 if use_bfloat16_ is true. - Literal MaybeConvertLiteralToBfloat16(const Literal& literal); + // Converts a literal to the test_type if the literal's type is F32. + Literal MaybeConvertLiteralToTestType(const Literal& literal); LocalClient* client_; LocalClient* ref_client_; // To compute reference result. @@ -438,12 +428,12 @@ class ClientLibraryTestBase : public ::testing::Test { verify_output, const Shape* output_with_layout = nullptr); - // Converts an f32 shape to bf16 if use_bfloat16_ is true. - Shape MaybeConvertShapeToBfloat16(const Shape& shape); + // Converts an f32 shape to test_type_. + Shape MaybeConvertShapeToTestType(const Shape& shape); - // Whether to run tests with all float-type input/output converted to - // bfloat16. - bool use_bfloat16_ = false; + // Type to use when running tests. By default, we use F32 for historical + // reasons and we rely on the underlying tests to change it. + PrimitiveType test_type_ = F32; // Arguments to be passed to the computation when it runs. std::vector arguments_; @@ -584,9 +574,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64_t parameter_number, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { Literal literal = LiteralUtil::CreateR0(value); - if (use_bfloat16_ && literal.shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(literal); - } + literal = MaybeConvertLiteralToTestType(literal); std::unique_ptr data = client_->TransferToServer(literal).value(); *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; @@ -597,9 +585,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( absl::Span values, int64_t parameter_number, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { Literal literal = LiteralUtil::CreateR1(values); - if (use_bfloat16_ && literal.shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(literal); - } + literal = MaybeConvertLiteralToTestType(literal); std::unique_ptr data = client_->TransferToServer(literal).value(); *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; @@ -610,9 +596,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const Array2D& array_2d, int64_t parameter_number, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { Literal literal = LiteralUtil::CreateR2FromArray2D(array_2d); - if (use_bfloat16_ && literal.shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(literal); - } + literal = MaybeConvertLiteralToTestType(literal); std::unique_ptr data = client_->TransferToServer(literal).value(); *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; @@ -623,9 +607,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const Array3D& array_3d, int64_t parameter_number, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { Literal literal = LiteralUtil::CreateR3FromArray3D(array_3d); - if (use_bfloat16_ && literal.shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(literal); - } + literal = MaybeConvertLiteralToTestType(literal); std::unique_ptr data = client_->TransferToServer(literal).value(); *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; @@ -636,9 +618,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR4Parameter( const Array4D& array_4d, int64_t parameter_number, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { Literal literal = LiteralUtil::CreateR4FromArray4D(array_4d); - if (use_bfloat16_ && literal.shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(literal); - } + literal = MaybeConvertLiteralToTestType(literal); std::unique_ptr data = client_->TransferToServer(literal).value(); *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; @@ -649,9 +629,7 @@ std::unique_ptr ClientLibraryTestBase::CreateParameter( const Array& array, int64_t parameter_number, const std::string& name, XlaBuilder* builder, XlaOp* data_handle) { Literal literal = LiteralUtil::CreateFromArray(array); - if (use_bfloat16_ && literal.shape().element_type() == F32) { - literal = LiteralUtil::ConvertF32ToBF16(literal); - } + literal = MaybeConvertLiteralToTestType(literal); std::unique_ptr data = client_->TransferToServer(literal).value(); *data_handle = Parameter(builder, parameter_number, literal.shape(), name); return data; diff --git a/third_party/xla/xla/tests/client_test.cc b/third_party/xla/xla/tests/client_test.cc index 1adb92207748aa..59eafe57b12141 100644 --- a/third_party/xla/xla/tests/client_test.cc +++ b/third_party/xla/xla/tests/client_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/client/global_data.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/test_helpers.h" diff --git a/third_party/xla/xla/tests/collective_ops_e2e_test.cc b/third_party/xla/xla/tests/collective_ops_e2e_test.cc index 01c60e5d6ac683..a5610a618f3161 100644 --- a/third_party/xla/xla/tests/collective_ops_e2e_test.cc +++ b/third_party/xla/xla/tests/collective_ops_e2e_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -54,6 +55,13 @@ DeviceAssignment MakeDeviceAssn(int64_t num_replicas) { class CollectiveOpsTestE2E : public HloTestBase { public: + CollectiveOpsTestE2E() { + replacements_[kF8E4M3DatatypePlaceholder] = + IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz"; + replacements_[kF8E5M2DatatypePlaceholder] = + IsCuda() ? "f8e5m2" : "f8e5m2fnuz"; + } + bool IsCuda() { return std::holds_alternative(Capability()); } @@ -93,10 +101,11 @@ class CollectiveOpsTestE2E : public HloTestBase { CreateExecutable(std::move(module), /*run_hlo_passes=*/true)); EXPECT_TRUE(executable->has_module()); - HloInstruction* gemm_op = - FindInstruction(&executable->module(), HloOpcode::kCustomCall); - EXPECT_THAT(gemm_op, NotNull()); - EXPECT_EQ(gemm_op->custom_call_target(), "__cublas$lt$matmul$f8"); + std::vector gemm_ops = + FindInstructions(&executable->module(), HloOpcode::kCustomCall); + for (HloInstruction* gemm_op : gemm_ops) { + EXPECT_EQ(gemm_op->custom_call_target(), "__cublas$lt$matmul$f8"); + } } absl::StatusOr> ExecuteReplicated(Executable* executable, @@ -108,6 +117,13 @@ class CollectiveOpsTestE2E : public HloTestBase { /*argument_provider*/ [](int64_t, int64_t) { return nullptr; }, num_replicas, /*run_hlo_passes=*/false, &device_assignment); } + + protected: + absl::flat_hash_map replacements_; + + private: + static constexpr const char* kF8E4M3DatatypePlaceholder{"<>"}; + static constexpr const char* kF8E5M2DatatypePlaceholder{"<>"}; }; // E2E tests for collective ops. These will generally verify some HLO transform @@ -126,7 +142,7 @@ class AsyncCollectiveOps : public CollectiveOpsTestE2E, } protected: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); // Enable or disable all async collectives based on test parameter. @@ -434,6 +450,46 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncAllToAllWithSplitDim) { LiteralTestUtil::ExpectR1Equal({15, 16}, results[1]); } +TEST_F(CollectiveOpsTestE2E, AsyncAllToAllMemCpy) { + const absl::string_view kModuleStr = R"( + HloModule test + ENTRY test_computation { + id = u32[] replica-id() + id2 = u32[2, 2] broadcast(id), dimensions={} + a0 = u32[2, 2] constant({{10, 15}, {20, 25}}) + a1 = u32[2, 2] add(id2, a0) + all2all = u32[2, 2] all-to-all(a1), dimensions={0} + ROOT out = u32[4] reshape(all2all) + } + )"; + const int64_t kNumReplicas = 2; + + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_use_memcpy_local_p2p(true); + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + config.set_debug_options(debug_options); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + TF_ASSERT_OK_AND_ASSIGN( + auto executable, + CreateExecutable(std::move(module), /*run_hlo_passes=*/true)); + ASSERT_TRUE(executable->has_module()); + HloModule* executable_module = &executable->module(); + + // Verify that the all-to-all is not decomposed into a tuple all-to-all. + const HloInstruction* all_to_all = + FindInstruction(executable_module, HloOpcode::kAllToAll); + EXPECT_THAT(all_to_all, op::Shape("u32[2, 2]")); + + TF_ASSERT_OK_AND_ASSIGN(std::vector results, + ExecuteReplicated(executable.get(), kNumReplicas)); + ASSERT_EQ(results.size(), kNumReplicas); + LiteralTestUtil::ExpectR1Equal({10, 15, 11, 16}, results[0]); + LiteralTestUtil::ExpectR1Equal({20, 25, 21, 26}, results[1]); +} + XLA_TEST_P(AsyncCollectiveOps, AsyncAllToAllWithoutSplitDim) { const absl::string_view kModuleStr = R"( HloModule test @@ -537,11 +593,10 @@ TEST_P(AsyncCollectiveOps, MatmulReplicated) { true /*run_hlo_passes*/, true /*use-threads*/)); ASSERT_EQ(results.size(), kNumReplicas); - auto& ref_runner = HloTestBase::reference_runner_; TF_ASSERT_OK_AND_ASSIGN( auto ref_module, ParseAndReturnVerifiedModule(kModuleSingleStr, config)); - TF_ASSERT_OK_AND_ASSIGN( - auto ref_exec, ref_runner.CreateExecutable(std::move(ref_module), true)); + TF_ASSERT_OK_AND_ASSIGN(auto ref_exec, reference_runner().CreateExecutable( + std::move(ref_module), true)); ErrorSpec error_spec{1e-5, 1e-5}; fake_ptrs.push_back(nullptr); @@ -549,8 +604,8 @@ TEST_P(AsyncCollectiveOps, MatmulReplicated) { auto replica_id = LiteralUtil::CreateFullWithDescendingLayout({}, i); fake_ptrs.back() = &replica_id; - TF_ASSERT_OK_AND_ASSIGN( - auto res, ref_runner.ExecuteWithExecutable(ref_exec.get(), fake_ptrs)); + TF_ASSERT_OK_AND_ASSIGN(auto res, reference_runner().ExecuteWithExecutable( + ref_exec.get(), fake_ptrs)); EXPECT_TRUE(LiteralTestUtil::Near(res, results[i], error_spec)); } } @@ -808,46 +863,24 @@ ENTRY main.12 { CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr); } -TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, - WindowedEinsumE2EAllGatherAndReduceScatterF8) { +TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, WindowedEinsumE2EAllGatherF8) { absl::string_view kModuleReplicatedStr = R"( -HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[48,192]{1,0}, f8e4m3fn[192,48]{1,0}, bf16[], bf16[], bf16[], bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 +HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[48,192]{1,0}, bf16[], bf16[])->bf16[2,16,192]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 -ENTRY main.12 { - Arg_0.1 = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} - Arg_1.2 = f8e4m3fn[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} - Arg_2.3 = bf16[] parameter(3) - Arg_3.4 = bf16[] parameter(4) - broadcast = bf16[2,16,48]{2,1,0} broadcast(Arg_2.3), dimensions={} - broadcast.1 = bf16[48,192]{1,0} broadcast(Arg_3.4), dimensions={} - convert = bf16[2,16,48]{2,1,0} convert(Arg_0.1) - convert.1 = bf16[48,192]{1,0} convert(Arg_1.2) - multiply = bf16[2,16,48]{2,1,0} multiply(broadcast, convert) - multiply.1 = bf16[48,192]{1,0} multiply(broadcast.1, convert.1) - dot.5 = bf16[2,16,192]{2,1,0} dot(multiply, multiply.1), lhs_contracting_dims={2}, rhs_contracting_dims={0} - custom-call.7 = bf16[2,16,192]{2,1,0} custom-call(dot.5), custom_call_target="Sharding", sharding={devices=[1,1,4]<=[4]} - Arg_4.5 = bf16[] parameter(5) - broadcast.2 = bf16[2,16,192]{2,1,0} broadcast(Arg_4.5), dimensions={} - divide = bf16[2,16,192]{2,1,0} divide(custom-call.7, broadcast.2) - constant = bf16[] constant(-448.) - broadcast.3 = bf16[2,16,192]{2,1,0} broadcast(constant), dimensions={} - constant.1 = bf16[] constant(448.) - broadcast.4 = bf16[2,16,192]{2,1,0} broadcast(constant.1), dimensions={} - clamp = bf16[2,16,192]{2,1,0} clamp(broadcast.3, divide, broadcast.4) - convert.2 = f8e4m3fn[2,16,192]{2,1,0} convert(clamp) - Arg_5.6 = bf16[] parameter(6) - broadcast.5 = bf16[2,16,192]{2,1,0} broadcast(Arg_5.6), dimensions={} - convert.3 = bf16[2,16,192]{2,1,0} convert(convert.2) - multiply.2 = bf16[2,16,192]{2,1,0} multiply(convert.3, broadcast.5) - Arg_6.7 = f8e4m3fn[192,48]{1,0} parameter(2), sharding={devices=[4,1]<=[4]} - Arg_7.8 = bf16[] parameter(7) - broadcast.6 = bf16[192,48]{1,0} broadcast(Arg_7.8), dimensions={} - convert.4 = bf16[192,48]{1,0} convert(Arg_6.7) - multiply.3 = bf16[192,48]{1,0} multiply(convert.4, broadcast.6) - dot.6 = bf16[2,16,48]{2,1,0} dot(multiply.2, multiply.3), lhs_contracting_dims={2}, rhs_contracting_dims={0} - tuple.10 = (bf16[2,16,48]{2,1,0}) tuple(dot.6) - ROOT get-tuple-element.11 = bf16[2,16,48]{2,1,0} get-tuple-element(tuple.10), index=0, sharding={devices=[1,4,1]<=[4]} -} // main.12 +ENTRY main { + lhs = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} + rhs = f8e4m3fn[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} + scale_lhs = bf16[] parameter(2) + scale_rhs = bf16[] parameter(3) + scale_lhs_bcast = bf16[2,16,48]{2,1,0} broadcast(scale_lhs), dimensions={} + scale_rhs_bcast = bf16[48,192]{1,0} broadcast(scale_rhs), dimensions={} + lhs_bf16 = bf16[2,16,48]{2,1,0} convert(lhs) + rhs_bf16 = bf16[48,192]{1,0} convert(rhs) + lhs_scaled = bf16[2,16,48]{2,1,0} multiply(scale_lhs_bcast, lhs_bf16) + rhs_scaled = bf16[48,192]{1,0} multiply(scale_rhs_bcast, rhs_bf16) + dot = bf16[2,16,192]{2,1,0} dot(lhs_scaled, rhs_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0} + ROOT custom-call = bf16[2,16,192]{2,1,0} custom-call(dot), custom_call_target="Sharding", sharding={devices=[1,1,4]<=[4]} +} // main )"; // Disable the dot merger pass which can prevent the creation of FP8 GEMM @@ -866,30 +899,107 @@ ENTRY main.12 { CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts); } +TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, + WindowedEinsumE2EAllGatherReshapeF8) { + absl::string_view kModuleReplicatedStr = R"( +HloModule windowed_einsum_e2e_all_gather_multi_consumer_f8, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[2,24,192]{2,1,0}, bf16[], bf16[])->bf16[2,16,192]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 + +ENTRY main { + lhs = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} + rhs = f8e4m3fn[2,24,192]{2,1,0} parameter(1), sharding={devices=[1,1,4]<=[4]} + scale_lhs = bf16[] parameter(2) + scale_rhs = bf16[] parameter(3) + scale_lhs_bcast = bf16[2,16,48]{2,1,0} broadcast(scale_rhs), dimensions={} + scale_rhs_bcast = bf16[2,24,192]{2,1,0} broadcast(scale_lhs), dimensions={} + lhs_bf16 = bf16[2,16,48]{2,1,0} convert(lhs) + rhs_bf16 = bf16[2,24,192]{2,1,0} convert(rhs) + lhs_scaled = bf16[2,16,48]{2,1,0} multiply(scale_lhs_bcast, lhs_bf16) + rhs_scaled = bf16[2,24,192]{2,1,0} multiply(scale_rhs_bcast, rhs_bf16) + rhs_reshaped = bf16[48,192]{1,0} reshape(rhs_scaled) + dot = bf16[2,16,192]{2,1,0} dot(lhs_scaled, rhs_reshaped), lhs_contracting_dims={2}, rhs_contracting_dims={0} + ROOT custom-call = bf16[2,16,192]{2,1,0} custom-call(dot), custom_call_target="Sharding", sharding={devices=[1,1,4]<=[4]} +} // main +)"; + + // Disable the dot merger pass which can prevent the creation of FP8 GEMM + // Custom Calls. + CollectiveOpsCompareWindowedNonWindowed( + absl::StrReplaceAll(kModuleReplicatedStr, replacements_), + /*disable_dot_merger=*/true); + + // Verify the creation of FP8 GEMM Custom Calls on Hopper and newer + // architectures. + DebugOptions opts = GetDebugOptionsForTest(); + opts.set_xla_gpu_threshold_for_windowed_einsum_mib(0); + opts.set_xla_gpu_multi_streamed_windowed_einsum(true); + opts.set_xla_gpu_graph_min_graph_size(200); + opts.set_xla_gpu_enable_triton_gemm(false); + opts.add_xla_disable_hlo_passes("dot-merger"); + CollectiveOpsVerifyF8Matmul( + absl::StrReplaceAll(kModuleReplicatedStr, replacements_), opts); +} + TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, WindowedEinsumE2EAllGatherMultiConsumerF8) { absl::string_view kModuleReplicatedStr = R"( HloModule windowed_einsum_e2e_all_gather_multi_consumer_f8, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[48,192]{1,0}, f8e4m3fn[48,192]{1,0}, bf16[], bf16[], bf16[])->bf16[2,16,192]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 ENTRY main { - rhs = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} - lhs0 = f8e4m3fn[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} + lhs = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} + rhs0 = f8e4m3fn[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} + scale_lhs = bf16[] parameter(3) + scale_rhs0 = bf16[] parameter(4) + scale_lhs_bcast = bf16[2,16,48]{2,1,0} broadcast(scale_lhs), dimensions={} + scale_rhs0_bcast = bf16[48,192]{1,0} broadcast(scale_rhs0), dimensions={} + lhs_bf16 = bf16[2,16,48]{2,1,0} convert(lhs) + rhs0_bf16 = bf16[48,192]{1,0} convert(rhs0) + lhs_scaled = bf16[2,16,48]{2,1,0} multiply(scale_lhs_bcast, lhs_bf16) + rhs0_scaled = bf16[48,192]{1,0} multiply(scale_rhs0_bcast, rhs0_bf16) + dot0 = bf16[2,16,192]{2,1,0} dot(lhs_scaled, rhs0_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0} + rhs1 = f8e4m3fn[48,192]{1,0} parameter(2), sharding={devices=[1,4]<=[4]} + scale_rhs1 = bf16[] parameter(5) + scale_rhs1_bcast = bf16[48,192]{1,0} broadcast(scale_rhs1), dimensions={} + rhs1_bf16 = bf16[48,192]{1,0} convert(rhs1) + rhs1_scaled = bf16[48,192]{1,0} multiply(scale_rhs1_bcast, rhs1_bf16) + dot1 = bf16[2,16,192]{2,1,0} dot(lhs_scaled, rhs1_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0} + ROOT add = bf16[2,16,192]{2,1,0} add(dot0, dot1) +} // main +)"; + + // Disable the dot merger pass which can prevent the creation of FP8 GEMM + // Custom Calls. + CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr, + /*disable_dot_merger=*/true); + + // Verify the creation of FP8 GEMM Custom Calls on Hopper and newer + // architectures. + DebugOptions opts = GetDebugOptionsForTest(); + opts.set_xla_gpu_threshold_for_windowed_einsum_mib(0); + opts.set_xla_gpu_multi_streamed_windowed_einsum(true); + opts.set_xla_gpu_graph_min_graph_size(200); + opts.set_xla_gpu_enable_triton_gemm(false); + opts.add_xla_disable_hlo_passes("dot-merger"); + CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts); +} + +TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, + WindowedEinsumE2EReduceScatterF8) { + absl::string_view kModuleReplicatedStr = R"( +HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,16,192]{2,1,0}, f8e4m3fn[192,48]{1,0}, bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 + +ENTRY main { + lhs = f8e4m3fn[2,16,192]{2,1,0} parameter(0), sharding={devices=[1,1,4]<=[4]} + rhs = f8e4m3fn[192,48]{1,0} parameter(1), sharding={devices=[4,1]<=[4]} + scale_lhs = bf16[] parameter(2) scale_rhs = bf16[] parameter(3) - scale_lhs0 = bf16[] parameter(4) - scale_rhs_bcast = bf16[2,16,48]{2,1,0} broadcast(scale_rhs), dimensions={} - scale_lhs0_bcast = bf16[48,192]{1,0} broadcast(scale_lhs0), dimensions={} - rhs_bf16 = bf16[2,16,48]{2,1,0} convert(rhs) - lhs0_bf16 = bf16[48,192]{1,0} convert(lhs0) - rhs_scaled = bf16[2,16,48]{2,1,0} multiply(scale_rhs_bcast, rhs_bf16) - lhs0_scaled = bf16[48,192]{1,0} multiply(scale_lhs0_bcast, lhs0_bf16) - dot0 = bf16[2,16,192]{2,1,0} dot(rhs_scaled, lhs0_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0} - lhs1 = f8e4m3fn[48,192]{1,0} parameter(2), sharding={devices=[1,4]<=[4]} - scale_lhs1 = bf16[] parameter(5) - scale_lhs1_bcast = bf16[48,192]{1,0} broadcast(scale_lhs1), dimensions={} - lhs1_bf16 = bf16[48,192]{1,0} convert(lhs1) - lhs1_scaled = bf16[48,192]{1,0} multiply(scale_lhs1_bcast, lhs1_bf16) - dot1 = bf16[2,16,192]{2,1,0} dot(rhs_scaled, lhs1_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0} - ROOT add.8 = bf16[2,16,192]{2,1,0} add(dot0, dot1) + scale_lhs_bcast = bf16[2,16,192]{2,1,0} broadcast(scale_lhs), dimensions={} + scale_rhs_bcast = bf16[192,48]{1,0} broadcast(scale_rhs), dimensions={} + lhs_bf16 = bf16[2,16,192]{2,1,0} convert(lhs) + rhs_bf16 = bf16[192,48]{1,0} convert(rhs) + lhs_scaled = bf16[2,16,192]{2,1,0} multiply(scale_lhs_bcast, lhs_bf16) + rhs_scaled = bf16[192,48]{1,0} multiply(scale_rhs_bcast, rhs_bf16) + dot = bf16[2,16,48]{2,1,0} dot(lhs_scaled, rhs_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0} + ROOT custom-call = bf16[2,16,48]{2,1,0} custom-call(dot), custom_call_target="Sharding", sharding={devices=[1,4,1]<=[4]} } // main )"; @@ -982,154 +1092,57 @@ ENTRY main.9_spmd { CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr); } -TEST_F(CollectiveOpsTestE2E, PostLayoutCollectivePipeliner) { - // We need fp8 support to test the post-layout collective pipeliner. This will - // preserve the desired fp8 patterns and so the gemm rewriter can correctly - // recognize them and rewrite to custom fp8 gemm calls. +TEST_F(CollectiveOpsTestE2E, CollectivePipelinerF8) { + // Verify that FP8 patterns are preserved when collectives are pipelined so + // the GEMM rewriter can create FP8 matmuls. if (!HasFp8Support()) { - GTEST_SKIP() << "Test requires a post-Ada GPU."; + GTEST_SKIP() << "Test requires Hopper or newer architecture."; } absl::string_view kModuleReplicatedStr = R"( -HloModule module, entry_computation_layout={(bf16[384,128], bf16[96,128], bf16[], bf16[])->bf16[384,128]}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 -add { - lhs = bf16[] parameter(0) - rhs = bf16[] parameter(1) - ROOT add = bf16[] add(lhs, rhs) -} +HloModule module, entry_computation_layout={(bf16[128,128], bf16[32,128], bf16[], bf16[])->bf16[512,128]}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 while_cond { - param = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) parameter(0) - gte = s32[] get-tuple-element(param), index=0 - constant.1 = s32[] constant(3) - ROOT cmp = pred[] compare(gte, constant.1), direction=LT + input = (s32[], bf16[128,128], bf16[32,128], bf16[], bf16[], bf16[512,128]) parameter(0) + loop_counter = s32[] get-tuple-element(input), index=0 + c4 = s32[] constant(4) + ROOT compare = pred[] compare(loop_counter, c4), direction=LT } while_body { - param = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) parameter(0) - get-tuple-element.394 = s32[] get-tuple-element(param), index=0 - get-tuple-element.395 = bf16[384,128] get-tuple-element(param), index=1 - get-tuple-element.k = bf16[96,128] get-tuple-element(param), index=2 - constant.2561 = s32[] constant(0) - constant.2557 = s32[] constant(1) - add.230 = s32[] add(get-tuple-element.394, constant.2557) - constant.2559 = s32[] constant(3) - subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394) - constant.2560 = s32[] constant(-1) - add.231 = s32[] add(subtract.139, constant.2560) - compare.747 = pred[] compare(add.231, constant.2561), direction=LT - constant.2562 = s32[] constant(2) - add.232 = s32[] add(subtract.139, constant.2562) - select.1348 = s32[] select(compare.747, add.232, add.231) - dynamic-slice.k = bf16[32,128] dynamic-slice(get-tuple-element.k, select.1348, constant.2561), dynamic_slice_sizes={32,128} - r = bf16[32,128] bitcast(dynamic-slice.k) - a = bf16[32,128] add(r, r), control-predecessors={constant.2559} - // A fp8 pattern of quant-dequant before the collective AG. - qa = f8e4m3fn[32,128] convert(a) - dqa = bf16[32,128] convert(qa) - a_scale = bf16[] get-tuple-element(param), index=3 - a_scales = bf16[32,128] broadcast(a_scale), dimensions={} - dqa_unscaled = bf16[32,128] multiply(dqa, a_scales) - mb = bf16[128,128] all-gather(dqa_unscaled), channel_id=1, use_global_device_ids=true, dimensions={0}, replica_groups={{0,1,2,3}} - ma = bf16[128,128] dynamic-slice(get-tuple-element.395, select.1348, constant.2561), dynamic_slice_sizes={128,128} - - qma = f8e4m3fn[128,128] convert(ma) - dqma = bf16[128,128] convert(qma) - ma_scale = bf16[] get-tuple-element(param), index=4 - ma_scales = bf16[128,128] broadcast(ma_scale), dimensions={} - dqma_unscaled = bf16[128,128] multiply(dqma, ma_scales) - mc = bf16[128,128] dot(dqma_unscaled, mb), lhs_contracting_dims={1}, rhs_contracting_dims={0} - dynamic-update-slice.35 = bf16[384,128] dynamic-update-slice(get-tuple-element.395, mc, select.1348, constant.2561) - ROOT tuple = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.k, a_scale, ma_scale), control-predecessors={a} -} -ENTRY entry { + input = (s32[], bf16[128,128], bf16[32,128], bf16[], bf16[], bf16[512,128]) parameter(0) + loop_counter = s32[] get-tuple-element(input), index=0 + lhs = bf16[128,128] get-tuple-element(input), index=1 + rhs = bf16[32,128] get-tuple-element(input), index=2 + partial_dot_output = bf16[512,128] get-tuple-element(input), index=5 + lhs_f8 = f8e4m3fn[128,128] convert(lhs) + rhs_f8 = f8e4m3fn[32,128] convert(rhs) + lhs_bf16 = bf16[128,128] convert(lhs_f8) + rhs_bf16 = bf16[32,128] convert(rhs_f8) + scale_lhs = bf16[] get-tuple-element(input), index=3 + scale_rhs = bf16[] get-tuple-element(input), index=4 + scale_lhs_bcast = bf16[128,128] broadcast(scale_lhs), dimensions={} + scale_rhs_bcast = bf16[32,128] broadcast(scale_rhs), dimensions={} + lhs_scaled = bf16[128,128] multiply(lhs_bf16, scale_lhs_bcast) + rhs_scaled = bf16[32,128] multiply(rhs_bf16, scale_rhs_bcast) + rhs_scaled_all_gathered = bf16[128,128] all-gather(rhs_scaled), channel_id=1, use_global_device_ids=true, dimensions={0}, replica_groups={{0,1,2,3}} + dot = bf16[128,128] dot(lhs_scaled, rhs_scaled_all_gathered), lhs_contracting_dims={1}, rhs_contracting_dims={1} c0 = s32[] constant(0) - p0 = bf16[384,128] parameter(0) - p1 = bf16[96,128] parameter(1) - s0 = bf16[] parameter(2) - s1 = bf16[] parameter(3) - tuple = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) tuple(c0, p0, p1, s0, s1) - while = (s32[], bf16[384,128], bf16[96,128], bf16[], bf16[]) while(tuple), condition=while_cond, body=while_body - ROOT gte1 = bf16[384,128] get-tuple-element(while), index=1 -} -)"; - - const int64_t kNumReplicas = 1; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); - - HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - auto opts = GetDebugOptionsForTest(); - opts.set_xla_gpu_run_post_layout_collective_pipeliner(true); - opts.set_xla_gpu_enable_pipelined_collectives(true); - opts.set_xla_gpu_enable_triton_gemm(false); - CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts); -} - -TEST_F(CollectiveOpsTestE2E, - PostLayoutCollectivePipelinerShouldFlattenCallGraph) { - // The allgather in the loop has a nested while loop as its operand, - // when the pipelining happens, the nested while loop will be peeled outside. - // However, when a while is cloned, its call sites are still preserved which - // will error out in alias analysis. When the graph is flattened, the error - // should not happen. - absl::string_view kModuleReplicatedStr = R"( -HloModule module - -while_cond { - param = (s32[], f32[2,128], f32[8,128], f32[8,128]) parameter(0) - gte = s32[] get-tuple-element(param), index=0 - constant.1 = s32[] constant(3) - ROOT cmp = pred[] compare(gte, constant.1), direction=LT -} - -while_nested_cond { - param.nested = (s32[], f32[2,128]) parameter(0) - gte.nested = s32[] get-tuple-element(param.nested), index=0 - constant.nested = s32[] constant(3) - ROOT cmp.nested = pred[] compare(gte.nested, constant.nested), direction=LT -} -while_nested_body { - param.body_nested = (s32[], f32[2,128]) parameter(0) - gte.body_nested = s32[] get-tuple-element(param.body_nested), index=0 - gte.2.body_nested = f32[2,128] get-tuple-element(param.body_nested), index=1 - - constant.body_nested = s32[] constant(1) - add.body_nested = s32[] add(gte.body_nested, constant.body_nested) - rsqrt.body_nested = f32[2,128] rsqrt(gte.2.body_nested) - ROOT tuple.body_nested = (s32[], f32[2,128]) tuple(add.body_nested, rsqrt.body_nested) + size = s32[] constant(128) + iteration_offset = s32[] multiply(loop_counter, size) + updated_dot_output = bf16[512,128] dynamic-update-slice(partial_dot_output, dot, iteration_offset, c0) + c1 = s32[] constant(1) + loop_counter_plus_one = s32[] add(loop_counter, c1) + ROOT tuple = (s32[], bf16[128,128], bf16[32,128], bf16[], bf16[], bf16[512,128]) tuple(loop_counter_plus_one, lhs, rhs, scale_lhs, scale_rhs, updated_dot_output) } - -while_body { - param = (s32[], f32[2,128], f32[8,128], f32[8,128]) parameter(0) - get-tuple-element.394 = s32[] get-tuple-element(param), index=0 - get-tuple-element.395 = f32[2,128] get-tuple-element(param), index=1 - get-tuple-element.35 = f32[8,128] get-tuple-element(param), index=2 - get-tuple-element.36 = f32[8,128] get-tuple-element(param), index=3 - - constant.2557 = s32[] constant(1) - add.230 = s32[] add(get-tuple-element.394, constant.2557) - mul = f32[2,128] multiply(get-tuple-element.395, get-tuple-element.395) - constant.while = s32[] constant(0) - tuple.1 = (s32[], f32[2,128]) tuple(constant.while, mul) - while.1 = (s32[], f32[2,128]) while(tuple.1), condition=while_nested_cond, body=while_nested_body - gte.while = f32[2,128] get-tuple-element(while.1), index=1 - add.while = f32[2,128] add(gte.while, get-tuple-element.395) - - ag.1 = f32[8,128] all-gather(add.while), replica_groups={}, dimensions={0} - add.ag = f32[8,128] add(ag.1, get-tuple-element.36) - - ROOT tuple = (s32[], f32[2,128], f32[8,128], f32[8,128]) tuple(add.230, get-tuple-element.395, get-tuple-element.35, ag.1) -} - ENTRY entry { c0 = s32[] constant(0) - p0 = f32[2,128] parameter(0) - p1 = f32[8,128] parameter(1) - - tuple = (s32[], f32[2,128], f32[8,128], f32[8,128]) tuple(c0, p0, p1, p1) - while = (s32[], f32[2,128], f32[8,128], f32[8,128]) while(tuple), condition=while_cond, body=while_body - gte1 = f32[2,128] get-tuple-element(while), index=1 - gte2 = f32[8,128] get-tuple-element(while), index=3 - ROOT tuple.result = (f32[2,128], f32[8,128]) tuple(gte1, gte2) + lhs = bf16[128,128] parameter(0) + rhs = bf16[32,128] parameter(1) + scale_lhs = bf16[] parameter(2) + scale_rhs = bf16[] parameter(3) + result_buffer = bf16[512,128] constant(0.) + while_input = (s32[], bf16[128,128], bf16[32,128], bf16[], bf16[], bf16[512,128]) tuple(c0, lhs, rhs, scale_lhs, scale_rhs, result_buffer) + while = (s32[], bf16[128,128], bf16[32,128], bf16[], bf16[], bf16[512,128]) while(while_input), condition=while_cond, body=while_body + ROOT dot_output = bf16[512,128] get-tuple-element(while), index=5 } )"; @@ -1139,22 +1152,10 @@ ENTRY entry { HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); auto opts = GetDebugOptionsForTest(); - opts.set_xla_gpu_run_post_layout_collective_pipeliner(true); - opts.set_xla_gpu_enable_pipelined_all_reduce(true); - opts.set_xla_gpu_enable_pipelined_all_gather(true); - opts.set_xla_gpu_enable_pipelined_reduce_scatter(true); - + opts.set_xla_gpu_enable_pipelined_collectives(true); opts.set_xla_gpu_enable_triton_gemm(false); - config.set_debug_options(opts); - config.set_use_spmd_partitioning(false); - - TF_ASSERT_OK_AND_ASSIGN( - auto module, ParseAndReturnVerifiedModule(kModuleReplicatedStr, config)); - - TF_ASSERT_OK_AND_ASSIGN(auto executable, - CreateExecutable(std::move(module), - /*run_hlo_passes=*/true)); - EXPECT_TRUE(executable->has_module()); + CollectiveOpsVerifyF8Matmul( + absl::StrReplaceAll(kModuleReplicatedStr, replacements_), opts); } TEST_F(CollectiveOpsTestE2E, AllToAllQuantizeCollectiveQuantizer) { @@ -1261,5 +1262,37 @@ ENTRY entry { LiteralTestUtil::ExpectR1Equal({8., 8.}, results[1]); } +TEST_F(CollectiveOpsTestE2E, NoErrorOnDuplicateChannelId) { + absl::string_view kModuleReplicatedStr = R"( +HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f32[4,32,128]{2,1,0})->(f32[4,32,128]{2,1,0}, f32[4,32,128]{2,1,0})}, num_partitions=4 +ENTRY entry { + param = f32[4,32,128]{2,1,0} parameter(0) + all-to-all = f32[4,32,128]{2,1,0} all-to-all(param), channel_id=1, replica_groups={{0,1,2,3}}, dimensions={1} + all-to-all.1 = f32[4,32,128]{2,1,0} all-to-all(param), channel_id=1, replica_groups={{0,1,2,3}}, dimensions={0} + ROOT tuple = (f32[4,32,128]{2,1,0}, f32[4,32,128]{2,1,0}) tuple(all-to-all, all-to-all.1) +} +)"; + + const int64_t kNumReplicas = 1; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + const int64_t kNumPartitions = 4; + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + + auto opts = GetDebugOptionsForTest(); + opts.set_xla_experimental_ignore_channel_id(true); + config.set_debug_options(opts); + + config.set_num_partitions(kNumPartitions); + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule(kModuleReplicatedStr, config)); + + TF_ASSERT_OK_AND_ASSIGN(auto executable, + CreateExecutable(std::move(module), + /*run_hlo_passes=*/true)); + EXPECT_TRUE(executable->has_module()); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/tests/collective_ops_test.cc b/third_party/xla/xla/tests/collective_ops_test.cc index 9cd874c9e03c13..ae19748244ac49 100644 --- a/third_party/xla/xla/tests/collective_ops_test.cc +++ b/third_party/xla/xla/tests/collective_ops_test.cc @@ -53,7 +53,7 @@ class CollectiveOpsTest : public HloTestBase { } protected: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); // Disable async->sync collective conversion pass to enable unit testing // of async collectives. @@ -402,15 +402,14 @@ XLA_TEST_F(CollectiveOpsTest, HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/2); auto executable = - test_runner_ - .CreateExecutable(MakeCrsModule(input_literal.shape(), - /*replica_groups=*/{}, config), - /*run_hlo_passes=*/true) + CreateExecutable(MakeCrsModule(input_literal.shape(), + /*replica_groups=*/{}, config), + /*run_hlo_passes=*/true) .value(); std::vector devices = {0, 1}; auto device_assn = MakeDeviceAssn(devices); - HloRunner::ReplicatedExecuteOptions opts; + HloRunnerInterface::ReplicatedExecuteOptions opts; opts.num_replicas = devices.size(); opts.use_threads = true; opts.arguments.push_back(&input_literal); @@ -420,7 +419,7 @@ XLA_TEST_F(CollectiveOpsTest, for (int64_t i = 0; i < kNumThreads * kRunsPerThread; ++i) { pool.Schedule([&] { TF_ASSERT_OK( - test_runner_.ExecuteReplicated(executable.get(), opts, &device_assn) + ExecuteReplicatedWithHloRunner(executable.get(), opts, &device_assn) .status()); done.DecrementCount(); }); @@ -652,6 +651,44 @@ XLA_TEST_F(CollectiveOpsTest, ReplicaId) { } } +XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectiveBroadcast_TwoGPUs)) { + const char* const kModuleStr = R"( + HloModule test + + collective_broadcast { + p0 = u32[2] parameter(0) + ROOT result = u32[2] collective-broadcast(p0), replica_groups={{1, 0}} + } + + ENTRY test_computation { + replica = u32[] replica-id() + ten = u32[] constant(10) + sum = u32[] add(replica, ten) + p = u32[2] broadcast(sum), dimensions={} + cb = ((u32[2]), u32[2]) async-start(u32[2] %p), calls=collective_broadcast + ROOT res = u32[2] async-done(cb), calls=collective_broadcast + } + )"; + const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, + /*use_threads=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({11, 11}), + results[0])); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({11, 11}), + results[1])); +} + XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectiveBroadcast_Simple)) { const char* const kModuleStr = R"( HloModule test @@ -694,6 +731,38 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectiveBroadcast_Simple)) { results[3])); } +XLA_TEST_F(CollectiveOpsTest, CollectivePermute_TwoGPUs) { + const char* const kModuleStr = R"( + HloModule test + ENTRY test_computation { + replica = u32[] replica-id() + ten = u32[] constant(10) + sum = u32[] add(replica, ten) + p = u32[2] broadcast(sum), dimensions={} + permute = u32[2] collective-permute(p), source_target_pairs={{1,0}, {0,1}} + ROOT copy = u32[2] copy(permute) + } + )"; + const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({11, 11}), + results[0])); + EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR1({10, 10}), + results[1])); +} + XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) { const char* const kModuleStr = R"( HloModule test @@ -1753,80 +1822,6 @@ XLA_TEST_F(CollectiveOpsTest, AllReduceBFloat16Min) { } } -XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) { - const char* const kModuleStr = R"( - HloModule test - ENTRY test_computation { - a0 = f8e4m3fn[1,2] constant({{1,2}}) - allgather = f8e4m3fn[2, 2] all-gather(a0), dimensions={0} - p = f8e4m3fn[4] reshape(allgather) - ROOT out = f32[4] convert(p) - } - )"; - const int64_t kNumReplicas = 2; - HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr, config)); - TF_ASSERT_OK_AND_ASSIGN( - std::vector results, - ExecuteReplicated(std::move(module), absl::Span{}, - kNumReplicas, - /*use_threads=*/true, /*run_hlo_passes=*/true)); - ASSERT_EQ(results.size(), kNumReplicas); - for (const Literal& result : results) { - LiteralTestUtil::ExpectR1Equal({1, 2, 1, 2}, result); - } -} - -XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) { - const char* const kModuleStr = R"( - HloModule test - ENTRY test_computation { - a0 = f8e4m3fn[2] constant({1,2}) - a2a = f8e4m3fn[2] all-to-all(a0), dimensions={0} - ROOT out = f32[2] convert(a2a) - } - )"; - const int64_t kNumReplicas = 2; - HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr, config)); - TF_ASSERT_OK_AND_ASSIGN( - std::vector results, - ExecuteReplicated(std::move(module), absl::Span{}, - kNumReplicas, - /*use_threads=*/true, /*run_hlo_passes=*/true)); - ASSERT_EQ(results.size(), kNumReplicas); - LiteralTestUtil::ExpectR1Equal({1, 1}, results[0]); - LiteralTestUtil::ExpectR1Equal({2, 2}, results[1]); -} - -XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_8BitFloat)) { - const char* const kModuleStr = R"( - HloModule test - ENTRY test_computation { - a0 = f8e5m2[2] constant({1,2}) - a1 = f8e5m2[2] collective-permute(a0), source_target_pairs={{0,1}, {1,0}} - ROOT out = f32[2] convert(a1) - } - )"; - const int64_t kNumReplicas = 2; - HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr, config)); - TF_ASSERT_OK_AND_ASSIGN( - std::vector results, - ExecuteReplicated(std::move(module), absl::Span{}, - kNumReplicas, - /*use_threads=*/true, /*run_hlo_passes=*/true)); - ASSERT_EQ(results.size(), kNumReplicas); - LiteralTestUtil::ExpectR1Equal({1, 2}, results[0]); - LiteralTestUtil::ExpectR1Equal({1, 2}, results[1]); -} - XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllGather)) { const char* const kModuleStr = R"( HloModule test @@ -2225,6 +2220,10 @@ body { } recv-data.1 = u32[2] get-tuple-element(recv-done.1), index=0 + send-done.1 = token[] send-done(send.1), channel_id=0, + frontend_attributes={ + _xla_send_recv_pipeline="0" + } replica = u32[] replica-id() constant0 = u32[] constant(0) compare0 = pred[] compare(replica, constant0), direction=EQ @@ -2237,10 +2236,6 @@ body { r = u32[2] broadcast(c1), dimensions={} s = u32[2] add(r, recv-data) - send-done.1 = token[] send-done(send.1), channel_id=0, - frontend_attributes={ - _xla_send_recv_pipeline="0" - } ROOT result = (u32[], u32[2]) tuple(new_count, s) } @@ -2273,5 +2268,110 @@ body { results[1])); } +class Fp8CollectiveOpsTest : public CollectiveOpsTest { + public: + Fp8CollectiveOpsTest() { + replacements_[kF8E4M3DatatypePlaceholder] = + IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz"; + replacements_[kF8E5M2DatatypePlaceholder] = + IsCuda() ? "f8e5m2" : "f8e5m2fnuz"; + } + + protected: + bool IsCuda() { + return std::holds_alternative(Capability()); + } + + const se::GpuComputeCapability& Capability() { + return backend() + .default_stream_executor() + ->GetDeviceDescription() + .gpu_compute_capability(); + } + + absl::flat_hash_map replacements_; + + private: + static constexpr const char* kF8E4M3DatatypePlaceholder{"<>"}; + static constexpr const char* kF8E5M2DatatypePlaceholder{"<>"}; +}; + +XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) { + const char* const kModuleStr = R"( + HloModule test + ENTRY test_computation { + a0 = <>[1,2] constant({{1,2}}) + allgather = <>[2, 2] all-gather(a0), dimensions={0} + p = <>[4] reshape(allgather) + ROOT out = f32[4] convert(p) + } + )"; + const int64_t kNumReplicas = 2; + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule( + absl::StrReplaceAll(kModuleStr, replacements_), config)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + for (const Literal& result : results) { + LiteralTestUtil::ExpectR1Equal({1, 2, 1, 2}, result); + } +} + +XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) { + const char* const kModuleStr = R"( + HloModule test + ENTRY test_computation { + a0 = <>[2] constant({1,2}) + a2a = <>[2] all-to-all(a0), dimensions={0} + ROOT out = f32[2] convert(a2a) + } + )"; + const int64_t kNumReplicas = 2; + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule( + absl::StrReplaceAll(kModuleStr, replacements_), config)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + LiteralTestUtil::ExpectR1Equal({1, 1}, results[0]); + LiteralTestUtil::ExpectR1Equal({2, 2}, results[1]); +} + +XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(CollectivePermute_8BitFloat)) { + const char* const kModuleStr = R"( + HloModule test + ENTRY test_computation { + a0 = <>[2] constant({1,2}) + a1 = <>[2] collective-permute(a0), source_target_pairs={{0,1}, {1,0}} + ROOT out = f32[2] convert(a1) + } + )"; + const int64_t kNumReplicas = 2; + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule( + absl::StrReplaceAll(kModuleStr, replacements_), config)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, + /*use_threads=*/true, /*run_hlo_passes=*/true)); + ASSERT_EQ(results.size(), kNumReplicas); + LiteralTestUtil::ExpectR1Equal({1, 2}, results[0]); + LiteralTestUtil::ExpectR1Equal({1, 2}, results[1]); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/tests/collective_pipeline_parallelism_test.cc b/third_party/xla/xla/tests/collective_pipeline_parallelism_test.cc index ee844727b9c7f8..7bf2054e384f9c 100644 --- a/third_party/xla/xla/tests/collective_pipeline_parallelism_test.cc +++ b/third_party/xla/xla/tests/collective_pipeline_parallelism_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "xla/error_spec.h" #include "xla/literal.h" #include "xla/literal_util.h" +#include "xla/service/computation_placer.h" #include "xla/service/hlo_module_config.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" @@ -832,5 +833,366 @@ XLA_TEST_F(CollectivePipelineParallelismTest, ErrorSpec{1e-5, 1e-5})); } +XLA_TEST_F(CollectivePipelineParallelismTest, SendRecvLoop) { + const absl::string_view kModuleStr = R"( + HloModule test, num_partitions=4 + + while_condidtion { + param = (u32[], f32[2,2]) parameter(0) + i = u32[] get-tuple-element(param), index=0 + c3 = u32[] constant(3) + ROOT cmp = pred[] compare(i, c3), direction=LT + } + + while_body { + param = (u32[], f32[2,2]) parameter(0) + i = u32[] get-tuple-element(param), index=0 + data = f32[2,2] get-tuple-element(param), index=1 + + // Send data from GPU i to i+1. Break cycle to avoid deadlock. + after_all = token[] after-all() + send_ctx = (f32[2,2], u32[], token[]) send(data, after_all), + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}, channel_id=1 + recv_ctx = (f32[2,2], u32[], token[]) recv(after_all), + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}, channel_id=2 + send_done = token[] send-done(send_ctx), channel_id=1 + recv_done = (f32[2,2], token[]) recv-done(recv_ctx), channel_id=2 + data_ = f32[2,2] get-tuple-element(recv_done), index=0 + + c1 = u32[] constant(1) + i_ = u32[] add(i, c1) + ROOT result = (u32[], f32[2,2]) tuple(i_, data_) + } + + ENTRY test_computation { + data = f32[2,2] parameter(0) + i = u32[] constant(0) + init = (u32[], f32[2,2]) tuple(i, data) + while = (u32[], f32[2,2]) while(init), condition=while_condidtion, + body=while_body + ROOT data_ = f32[2,2] get-tuple-element(while), index=1 + } + )"; + + const int64_t kNumReplicas = 1; + const int64_t kNumPartitions = 4; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions); + + // Parse HLO module. + HloModuleConfig config = GetModuleConfigForTest( + /*replica_count=*/kNumReplicas, /*num_partitions=*/kNumPartitions); + std::unique_ptr module; + TF_ASSERT_OK_AND_ASSIGN(module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + // Create input data. + std::vector literals; + for (int64_t i = 0; i < kNumPartitions; ++i) { + float val = i + 1; + literals.push_back(LiteralUtil::CreateR2({{val, val}, {val, val}})); + } + std::vector> inputs; + for (int64_t i = 0; i < kNumPartitions; ++i) { + inputs.push_back({&literals[i]}); + } + + // Create device assignment running across partitions. + DeviceAssignment device_assignment(/*replica_count=*/kNumReplicas, + /*computation_count=*/kNumPartitions); + for (int64_t i = 0; i < kNumPartitions; ++i) { + device_assignment(0, i) = i; + } + + // Execute and check results. + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), inputs, + /*num_replicas=*/kNumPartitions, + /*run_hlo_passes=*/false, &device_assignment)); + LiteralTestUtil::ExpectR2Equal({{0, 0}, {0, 0}}, results[0]); + LiteralTestUtil::ExpectR2Equal({{0, 0}, {0, 0}}, results[1]); + LiteralTestUtil::ExpectR2Equal({{0, 0}, {0, 0}}, results[2]); + LiteralTestUtil::ExpectR2Equal({{1, 1}, {1, 1}}, results[3]); +} + +XLA_TEST_F(CollectivePipelineParallelismTest, SendRecvLoop2Devices) { + const absl::string_view kModuleStr = R"( + HloModule test, num_partitions=2 + + // 1 iteration so that we can test on 2 GPUs. + while_condidtion { + param = (u32[], f32[2,2]) parameter(0) + i = u32[] get-tuple-element(param), index=0 + c1 = u32[] constant(1) + ROOT cmp = pred[] compare(i, c1), direction=LT + } + + while_body { + param = (u32[], f32[2,2]) parameter(0) + i = u32[] get-tuple-element(param), index=0 + data = f32[2,2] get-tuple-element(param), index=1 + + // Just send from GPU 0 to GPU 1 to avoid deadlock. + after_all = token[] after-all() + send_ctx = (f32[2,2], u32[], token[]) send(data, after_all), + frontend_attributes={_xla_send_recv_source_target_pairs={{0,1}}}, + channel_id=1 + recv_ctx = (f32[2,2], u32[], token[]) recv(after_all), + frontend_attributes={_xla_send_recv_source_target_pairs={{0,1}}}, + channel_id=2 + send_done = token[] send-done(send_ctx), channel_id=1 + recv_done = (f32[2,2], token[]) recv-done(recv_ctx), channel_id=2 + data_ = f32[2,2] get-tuple-element(recv_done), index=0 + + c1 = u32[] constant(1) + i_ = u32[] add(i, c1) + ROOT result = (u32[], f32[2,2]) tuple(i_, data_) + } + + ENTRY test_computation { + data = f32[2,2] parameter(0) + i = u32[] constant(0) + init = (u32[], f32[2,2]) tuple(i, data) + while = (u32[], f32[2,2]) while(init), condition=while_condidtion, + body=while_body + ROOT data_ = f32[2,2] get-tuple-element(while), index=1 + } + )"; + + const int64_t kNumReplicas = 1; + const int64_t kNumPartitions = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions); + + // Parse HLO module. + HloModuleConfig config = GetModuleConfigForTest( + /*replica_count=*/kNumReplicas, /*num_partitions=*/kNumPartitions); + std::unique_ptr module; + TF_ASSERT_OK_AND_ASSIGN(module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + // Create input data. + std::vector literals; + for (int64_t i = 0; i < kNumPartitions; ++i) { + float val = i + 1; + literals.push_back(LiteralUtil::CreateR2({{val, val}, {val, val}})); + } + std::vector> inputs; + for (int64_t i = 0; i < kNumPartitions; ++i) { + inputs.push_back({&literals[i]}); + } + + // Create device assignment running across partitions. + DeviceAssignment device_assignment(/*replica_count=*/kNumReplicas, + /*computation_count=*/kNumPartitions); + for (int64_t i = 0; i < kNumPartitions; ++i) { + device_assignment(0, i) = i; + } + + // Execute and check results. + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), inputs, + /*num_replicas=*/kNumPartitions, + /*run_hlo_passes=*/false, &device_assignment)); + LiteralTestUtil::ExpectR2Equal({{0, 0}, {0, 0}}, results[0]); + LiteralTestUtil::ExpectR2Equal({{1, 1}, {1, 1}}, results[1]); +} + +XLA_TEST_F(CollectivePipelineParallelismTest, + PartiallyPipelinedAsyncSendRecvLoop) { + const absl::string_view kModuleStr = R"( + HloModule test, num_partitions=4 + + while_condidtion { + param = (u32[], (f32[2,2], u32[], token[]), (f32[2,2], u32[], token[])) + parameter(0) + i = u32[] get-tuple-element(param), index=0 + c2 = u32[] constant(2) + ROOT cmp = pred[] compare(i, c2), direction=LT + } + + while_body { + param = (u32[], (f32[2,2], u32[], token[]), (f32[2,2], u32[], token[])) + parameter(0) + i = u32[] get-tuple-element(param), index=0 + send_ctx = get-tuple-element(param), index=1 + recv_ctx = get-tuple-element(param), index=2 + send_done = token[] send-done(send_ctx), channel_id=1 + recv_done = (f32[2,2], token[]) recv-done(recv_ctx), channel_id=2 + data = get-tuple-element(recv_done), index=0 + after_all = token[] after-all() + send_ctx_ = (f32[2,2], u32[], token[]) send(data, after_all), + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}, channel_id=1 + recv_ctx_ = (f32[2,2], u32[], token[]) recv(after_all), + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}, channel_id=2 + c1 = u32[] constant(1) + i_ = u32[] add(i, c1) + ROOT result = (u32[], (f32[2,2], u32[], token[]), + (f32[2,2], u32[], token[])) tuple(i_, send_ctx_, recv_ctx_) + } + + ENTRY test_computation { + data = f32[2,2] parameter(0) + i = u32[] constant(0) + after_all = token[] after-all() + send_ctx_ = (f32[2,2], u32[], token[]) send(data, after_all), + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}, channel_id=1 + recv_ctx_ = (f32[2,2], u32[], token[]) recv(after_all), + frontend_attributes={ + _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}}, channel_id=2 + init = (u32[], (f32[2,2], u32[], token[]), (f32[2,2], u32[], token[])) + tuple(i, send_ctx_, recv_ctx_) + while = (u32[], (f32[2,2], u32[], token[]), (f32[2,2], u32[], token[])) + while(init), condition=while_condidtion, body=while_body + send_ctx = get-tuple-element(while), index=1 + recv_ctx = get-tuple-element(while), index=2 + send_done = token[] send-done(send_ctx), channel_id=1 + recv_done = (f32[2,2], token[]) recv-done(recv_ctx), channel_id=2 + ROOT data_ = get-tuple-element(recv_done), index=0 + } + )"; + + const int64_t kNumReplicas = 1; + const int64_t kNumPartitions = 4; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions); + + // Parse HLO module. + HloModuleConfig config = GetModuleConfigForTest( + /*replica_count=*/kNumReplicas, /*num_partitions=*/kNumPartitions); + std::unique_ptr module; + TF_ASSERT_OK_AND_ASSIGN(module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + // Create input data. + std::vector literals; + for (int64_t i = 0; i < kNumPartitions; ++i) { + float val = i + 1; + literals.push_back(LiteralUtil::CreateR2({{val, val}, {val, val}})); + } + std::vector> inputs; + for (int64_t i = 0; i < kNumPartitions; ++i) { + inputs.push_back({&literals[i]}); + } + + // Create device assignment running across partitions. + DeviceAssignment device_assignment(/*replica_count=*/kNumReplicas, + /*computation_count=*/kNumPartitions); + for (int64_t i = 0; i < kNumPartitions; ++i) { + device_assignment(0, i) = i; + } + + // Execute and check results. + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), inputs, + /*num_replicas=*/kNumPartitions, + /*run_hlo_passes=*/false, &device_assignment)); + LiteralTestUtil::ExpectR2Equal({{0, 0}, {0, 0}}, results[0]); + LiteralTestUtil::ExpectR2Equal({{0, 0}, {0, 0}}, results[1]); + LiteralTestUtil::ExpectR2Equal({{0, 0}, {0, 0}}, results[2]); + LiteralTestUtil::ExpectR2Equal({{1, 1}, {1, 1}}, results[3]); +} + +XLA_TEST_F(CollectivePipelineParallelismTest, + PartiallyPipelinedAsyncSendRecvLoop2Devices) { + const absl::string_view kModuleStr = R"( + HloModule test, num_partitions=2 + + while_condidtion { + param = (u32[], (f32[2,2], u32[], token[]), (f32[2,2], u32[], token[])) + parameter(0) + i = u32[] get-tuple-element(param), index=0 + c2 = u32[] constant(2) + ROOT cmp = pred[] compare(i, c2), direction=LT + } + + while_body { + param = (u32[], (f32[2,2], u32[], token[]), (f32[2,2], u32[], token[])) + parameter(0) + i = u32[] get-tuple-element(param), index=0 + send_ctx = get-tuple-element(param), index=1 + recv_ctx = get-tuple-element(param), index=2 + send_done = token[] send-done(send_ctx), channel_id=1 + recv_done = (f32[2,2], token[]) recv-done(recv_ctx), channel_id=2 + data = get-tuple-element(recv_done), index=0 + after_all = token[] after-all() + send_ctx_ = (f32[2,2], u32[], token[]) send(data, after_all), + frontend_attributes={_xla_send_recv_source_target_pairs={{0,1}}}, + channel_id=1 + recv_ctx_ = (f32[2,2], u32[], token[]) recv(after_all), + frontend_attributes={_xla_send_recv_source_target_pairs={{0,1}}}, + channel_id=2 + c1 = u32[] constant(1) + i_ = u32[] add(i, c1) + ROOT result = (u32[], (f32[2,2], u32[], token[]), + (f32[2,2], u32[], token[])) tuple(i_, send_ctx_, recv_ctx_) + } + + ENTRY test_computation { + data = f32[2,2] parameter(0) + i = u32[] constant(0) + after_all = token[] after-all() + send_ctx_ = (f32[2,2], u32[], token[]) send(data, after_all), + frontend_attributes={_xla_send_recv_source_target_pairs={{0,1}}}, + channel_id=1 + recv_ctx_ = (f32[2,2], u32[], token[]) recv(after_all), + frontend_attributes={_xla_send_recv_source_target_pairs={{0,1}}}, + channel_id=2 + init = (u32[], (f32[2,2], u32[], token[]), (f32[2,2], u32[], token[])) + tuple(i, send_ctx_, recv_ctx_) + while = (u32[], (f32[2,2], u32[], token[]), (f32[2,2], u32[], token[])) + while(init), condition=while_condidtion, body=while_body + send_ctx = get-tuple-element(while), index=1 + recv_ctx = get-tuple-element(while), index=2 + send_done = token[] send-done(send_ctx), channel_id=1 + recv_done = (f32[2,2], token[]) recv-done(recv_ctx), channel_id=2 + ROOT data_ = get-tuple-element(recv_done), index=0 + } + )"; + + const int64_t kNumReplicas = 1; + const int64_t kNumPartitions = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions); + + // Parse HLO module. + HloModuleConfig config = GetModuleConfigForTest( + /*replica_count=*/kNumReplicas, /*num_partitions=*/kNumPartitions); + std::unique_ptr module; + TF_ASSERT_OK_AND_ASSIGN(module, + ParseAndReturnVerifiedModule(kModuleStr, config)); + + // Create input data. + std::vector literals; + for (int64_t i = 0; i < kNumPartitions; ++i) { + float val = i + 1; + literals.push_back(LiteralUtil::CreateR2({{val, val}, {val, val}})); + } + std::vector> inputs; + for (int64_t i = 0; i < kNumPartitions; ++i) { + inputs.push_back({&literals[i]}); + } + + // Create device assignment running across partitions. + DeviceAssignment device_assignment(/*replica_count=*/kNumReplicas, + /*computation_count=*/kNumPartitions); + for (int64_t i = 0; i < kNumPartitions; ++i) { + device_assignment(0, i) = i; + } + + // Execute and check results. + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), inputs, + /*num_replicas=*/kNumPartitions, + /*run_hlo_passes=*/false, &device_assignment)); + LiteralTestUtil::ExpectR2Equal({{0, 0}, {0, 0}}, results[0]); + LiteralTestUtil::ExpectR2Equal({{0, 0}, {0, 0}}, results[1]); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/tests/collective_pipeliner_execution_test.cc b/third_party/xla/xla/tests/collective_pipeliner_execution_test.cc index b9a0588e9ec649..c76df8f8a30694 100644 --- a/third_party/xla/xla/tests/collective_pipeliner_execution_test.cc +++ b/third_party/xla/xla/tests/collective_pipeliner_execution_test.cc @@ -26,10 +26,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/service/collective_pipeliner.h" -#include "xla/service/hlo_dce.h" -#include "xla/service/hlo_parser.h" #include "xla/tests/hlo_test_base.h" #include "xla/util.h" diff --git a/third_party/xla/xla/tests/complex_unary_op_samples.h b/third_party/xla/xla/tests/complex_unary_op_samples.h index 3ccce969d00ad9..a4725384a001c2 100644 --- a/third_party/xla/xla/tests/complex_unary_op_samples.h +++ b/third_party/xla/xla/tests/complex_unary_op_samples.h @@ -804,9 +804,15 @@ struct Tan { /* 123 */ { { -2.e+00f, -min }, { 2.1850398e+00f, -6.7877737e-38f }, 2.5e-01f }, /* 124 */ { { -3.6093321e-13f, -min }, { -3.6093321e-13f, -min }, 2.1990233e+12f }, /* 125 */ { { -6.5136393e-26f, -min }, { -6.5136393e-26f, -min }, 9.6714066e+24f }, +#ifndef __aarch64__ +// TODO(b/342448599); Fix and re-enable on Arm. /* 126 */ { { -min, -min }, { -min, -min }, 4.2535296e+37f }, +#endif /* 127 */ { { zero, -min }, { zero, -min }, 4.2535296e+37f }, +#ifndef __aarch64__ +// TODO(b/342448599); Fix and re-enable on Arm. /* 128 */ { { min, -min }, { min, -min }, 4.2535296e+37f }, +#endif /* 129 */ { { 6.5136393e-26f, -min }, { 6.5136393e-26f, -min }, 9.6714066e+24f }, /* 130 */ { { 3.6093321e-13f, -min }, { 3.6093321e-13f, -min }, 2.1990233e+12f }, /* 131 */ { { 2.e+00f, -min }, { -2.1850398e+00f, -6.7877737e-38f }, 2.5e-01f }, @@ -821,9 +827,15 @@ struct Tan { /* 140 */ { { -2.e+00f, zero }, { 2.1850398e+00f, zero }, 2.5e-01f }, /* 141 */ { { -3.6093321e-13f, zero }, { -3.6093321e-13f, zero }, 2.1990233e+12f }, /* 142 */ { { -6.5136393e-26f, zero }, { -6.5136393e-26f, zero }, 9.6714066e+24f }, +#ifndef __aarch64__ +// TODO(b/342448599); Fix and re-enable on Arm. /* 143 */ { { -min, zero }, { -min, zero }, 4.2535296e+37f }, +#endif /* 144 */ { { zero, zero }, { zero, zero }, 1.e+00f }, +#ifndef __aarch64__ +// TODO(b/342448599); Fix and re-enable on Arm. /* 145 */ { { min, zero }, { min, zero }, 4.2535296e+37f }, +#endif /* 146 */ { { 6.5136393e-26f, zero }, { 6.5136393e-26f, zero }, 9.6714066e+24f }, /* 147 */ { { 3.6093321e-13f, zero }, { 3.6093321e-13f, zero }, 2.1990233e+12f }, /* 148 */ { { 2.e+00f, zero }, { -2.1850398e+00f, zero }, 2.5e-01f }, @@ -838,9 +850,15 @@ struct Tan { /* 157 */ { { -2.e+00f, min }, { 2.1850398e+00f, 6.7877737e-38f }, 2.5e-01f }, /* 158 */ { { -3.6093321e-13f, min }, { -3.6093321e-13f, min }, 2.1990233e+12f }, /* 159 */ { { -6.5136393e-26f, min }, { -6.5136393e-26f, min }, 9.6714066e+24f }, +#ifndef __aarch64__ +// TODO(b/342448599); Fix and re-enable on Arm. /* 160 */ { { -min, min }, { -min, min }, 4.2535296e+37f }, +#endif /* 161 */ { { zero, min }, { zero, min }, 4.2535296e+37f }, +#ifndef __aarch64__ +// TODO(b/342448599); Fix and re-enable on Arm. /* 162 */ { { min, min }, { min, min }, 4.2535296e+37f }, +#endif /* 163 */ { { 6.5136393e-26f, min }, { 6.5136393e-26f, min }, 9.6714066e+24f }, /* 164 */ { { 3.6093321e-13f, min }, { 3.6093321e-13f, min }, 2.1990233e+12f }, /* 165 */ { { 2.e+00f, min }, { -2.1850398e+00f, 6.7877737e-38f }, 2.5e-01f }, @@ -967,6 +985,7 @@ struct Tan { /* 286 */ { { 6.1409603e+25f, inf }, { zero, 1.e+00f }, 5.e-01f }, /* 287 */ { { max, inf }, { zero, 1.e+00f }, 5.e-01f }, /* 288 */ { { inf, inf }, { zero, 1.e+00f }, 5.e-01f } + // clang-format on }; return table; diff --git a/third_party/xla/xla/tests/complex_unary_op_test.cc b/third_party/xla/xla/tests/complex_unary_op_test.cc index ee2f6ae26e25a2..119e50967eb1d2 100644 --- a/third_party/xla/xla/tests/complex_unary_op_test.cc +++ b/third_party/xla/xla/tests/complex_unary_op_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include "xla/client/global_data.h" -#include "xla/client/lib/math.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/complex_unary_op_samples.h" #include "xla/tests/literal_test_util.h" diff --git a/third_party/xla/xla/tests/compute_constant_test.cc b/third_party/xla/xla/tests/compute_constant_test.cc index 8742656f17ff7a..6524e47ffa5486 100644 --- a/third_party/xla/xla/tests/compute_constant_test.cc +++ b/third_party/xla/xla/tests/compute_constant_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include "absl/strings/match.h" #include "xla/client/client_library.h" #include "xla/client/global_data.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/tests/concat_test.cc b/third_party/xla/xla/tests/concat_test.cc index c3af237497d0d6..6f831d8f29c998 100644 --- a/third_party/xla/xla/tests/concat_test.cc +++ b/third_party/xla/xla/tests/concat_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/literal_util.h" #include "xla/reference_util.h" #include "xla/test.h" diff --git a/third_party/xla/xla/tests/conditional_test.cc b/third_party/xla/xla/tests/conditional_test.cc index 7d13d06851077a..c9e8f2b40a5b16 100644 --- a/third_party/xla/xla/tests/conditional_test.cc +++ b/third_party/xla/xla/tests/conditional_test.cc @@ -16,8 +16,8 @@ limitations under the License. #include #include -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/constants_test.cc b/third_party/xla/xla/tests/constants_test.cc index 1a462f7fbad51e..9650077ed57b28 100644 --- a/third_party/xla/xla/tests/constants_test.cc +++ b/third_party/xla/xla/tests/constants_test.cc @@ -15,16 +15,17 @@ limitations under the License. // Tests that constants in program memory round trip as expected. -#include "xla/client/lib/constants.h" +#include "xla/hlo/builder/lib/constants.h" #include #include +#include #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/array4d.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" @@ -32,6 +33,7 @@ limitations under the License. #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/types.h" #include "tsl/platform/ml_dtypes.h" #include "tsl/platform/test.h" @@ -46,10 +48,11 @@ class ConstantsTest : public ClientLibraryTestBase { template class ConstantsFloatTest : public ConstantsTest {}; -typedef ::testing::Types - FloatTypes; +using FloatTypes = + ::testing::Types; TYPED_TEST_SUITE(ConstantsFloatTest, FloatTypes); diff --git a/third_party/xla/xla/tests/conv_depthwise_backprop_filter_test.cc b/third_party/xla/xla/tests/conv_depthwise_backprop_filter_test.cc index 55a540c7e9df13..e8225b7838a4d5 100644 --- a/third_party/xla/xla/tests/conv_depthwise_backprop_filter_test.cc +++ b/third_party/xla/xla/tests/conv_depthwise_backprop_filter_test.cc @@ -15,10 +15,10 @@ limitations under the License. #include -#include "xla/client/xla_computation.h" #include "xla/execution_options_util.h" -#include "xla/service/despecializer.h" -#include "xla/service/float_normalization.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/transforms/despecializer.h" +#include "xla/hlo/transforms/simplifiers/float_normalization.h" #include "xla/status_macros.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" diff --git a/third_party/xla/xla/tests/conv_depthwise_common.cc b/third_party/xla/xla/tests/conv_depthwise_common.cc index 07a7bbfa53b283..5c4bb5d1fcef45 100644 --- a/third_party/xla/xla/tests/conv_depthwise_common.cc +++ b/third_party/xla/xla/tests/conv_depthwise_common.cc @@ -17,10 +17,10 @@ limitations under the License. #include -#include "xla/client/xla_computation.h" #include "xla/execution_options_util.h" -#include "xla/service/despecializer.h" -#include "xla/service/float_normalization.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/transforms/despecializer.h" +#include "xla/hlo/transforms/simplifiers/float_normalization.h" #include "xla/status_macros.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" diff --git a/third_party/xla/xla/tests/conv_depthwise_common.h b/third_party/xla/xla/tests/conv_depthwise_common.h index 0deb41064ee0df..350858498111f4 100644 --- a/third_party/xla/xla/tests/conv_depthwise_common.h +++ b/third_party/xla/xla/tests/conv_depthwise_common.h @@ -18,10 +18,10 @@ limitations under the License. #include -#include "xla/client/xla_computation.h" #include "xla/execution_options_util.h" -#include "xla/service/despecializer.h" -#include "xla/service/float_normalization.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/transforms/despecializer.h" +#include "xla/hlo/transforms/simplifiers/float_normalization.h" #include "xla/status_macros.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" diff --git a/third_party/xla/xla/tests/conv_depthwise_test.cc b/third_party/xla/xla/tests/conv_depthwise_test.cc index e1826bd902505d..05d2e6c446ee4a 100644 --- a/third_party/xla/xla/tests/conv_depthwise_test.cc +++ b/third_party/xla/xla/tests/conv_depthwise_test.cc @@ -15,10 +15,10 @@ limitations under the License. #include -#include "xla/client/xla_computation.h" #include "xla/execution_options_util.h" -#include "xla/service/despecializer.h" -#include "xla/service/float_normalization.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/transforms/despecializer.h" +#include "xla/hlo/transforms/simplifiers/float_normalization.h" #include "xla/status_macros.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" diff --git a/third_party/xla/xla/tests/convert_test.cc b/third_party/xla/xla/tests/convert_test.cc index ef252a594e930b..4f06ea0cc290c7 100644 --- a/third_party/xla/xla/tests/convert_test.cc +++ b/third_party/xla/xla/tests/convert_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/base/casts.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape_util.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" @@ -54,11 +54,19 @@ class ConvertTestT : public ConvertTest { using ConvertTest::ConvertTest; }; using FloatingPointTypeList = - ::testing::Types; + ::testing::Types; TYPED_TEST_SUITE(ConvertTestT, FloatingPointTypeList); +template +class ConvertTestF16 : public ConvertTest { + public: + using ConvertTest::ConvertTest; +}; +using F16TypeList = ::testing::Types; +TYPED_TEST_SUITE(ConvertTestF16, F16TypeList); + TEST_F(ConvertTest, ConvertR1S32ToR1S32) { XlaBuilder builder(TestName()); auto a = ConstantR1(&builder, {42, 64}); @@ -729,8 +737,21 @@ XLA_TEST_F(ConvertTest, ConvertF32BF16) { } } +XLA_TYPED_TEST(ConvertTestT, ConvertFPToPred) { + XlaBuilder builder(this->TestName()); + using FP = TypeParam; + + auto a = ConstantR1(&builder, {FP{0.0}, FP{0.25}, FP{2.0}, FP{-0.0}}); + ConvertElementType(a, PRED); + + std::array expected = {false, true, true, false}; + this->template ComputeAndCompareR1(&builder, expected, {}); +} + +// ----- F8E5M2 + XLA_TEST_F(ConvertTest, ConvertF16F8e5m2Roundtrip) { - // Convert from FP16 to FP8, then back to FP16 + // Convert from FP16 to FP8, then back to FP16. XlaBuilder builder(TestName()); float nan = std::numeric_limits::quiet_NaN(); float inf = std::numeric_limits::infinity(); @@ -741,6 +762,7 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e5m2Roundtrip) { } test_cases[] = { // clang-format off {0.0, 0.0}, + {-0.0, 0.0}, {1.0, 1.0}, {-1.0, -1.0}, {nan, nan}, @@ -752,8 +774,18 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e5m2Roundtrip) { {0x1.DFCp15, 0x1.Cp15}, // Largest number that doesn't overflow {0x1.Ep15, inf}, // Smallest number that overflows {0x1p16, inf}, // Overflow - {0x1p-14, 0x1p-14}, // Smallest normal - {0x1.8p-15, 0x1.8p-15}, // Denormal + {0x1p-14, 0x1p-14}, // Smallest F8 normal + {0x1.Cp-15, 0x1p-14}, // Smallest number rounding up to normal + + // Denormal tests + {0x0.8p-14, 0x0.8p-14}, // Denormal without rounding + {0x0.Ap-14, 0x0.8p-14}, // Round-to-even down + {0x0.Ep-14, 0x1.0p-14}, // Round-to-even up + {0x0.98p-14, 0x0.8p-14}, // Round-to-nearest down + {0x0.A8p-14, 0x0.Cp-14}, // Round-to-nearest up + {0x0.2p-14, 0}, // Largest number that underflows + {0x0.204p-14, 0x0.4p-14}, // Smallest number that doesn't underflow + {0x0.DFCp-14, 0x0.Cp-14}, // Largest number that rounds to denormal }; std::vector inputs; @@ -762,126 +794,307 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e5m2Roundtrip) { inputs.push_back(Eigen::half{test_case.input}); expected_roundtrip.push_back(Eigen::half{test_case.expected_roundtrip}); } + auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E5M2); ConvertElementType(f8, F16); - const bool saved = - execution_options_.debug_options().xla_allow_excess_precision(); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - false); - // Pass in ErrorSpec, as this causes all NaNs to be treated as equal. ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - saved); } -XLA_TEST_F(ConvertTest, ConvertF8e5m2F16RoundtripExhaustive) { - // Convert from FP8 to FP16, then back to FP8 +XLA_TEST_F(ConvertTest, DISABLED_ON_CPU(ConvertF32F8e5m2Roundtrip)) { + // Convert from FP32 to FP8, then back to FP32. XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {nan, nan}, + {inf, inf}, + // clang-format on + {0x1.2p0, 0x1p0}, // Round-to-even down + {0x1.6p0, 0x1.8p0}, // Round-to-even up + {0x1.Cp15, 0x1.Cp15}, // Max value + {0x1.DFFFFEp15, 0x1.Cp15}, // Largest number that doesn't overflow + {0x1.Ep15, inf}, // Smallest number that overflows + {0x1p16, inf}, // Overflow + {0x1p-14, 0x1p-14}, // Smallest F8 normal + {0x1.Cp-15, 0x1p-14}, // Smallest number rounding up to normal + + // Denormal tests + {0x1.0p-15, 0x0.8p-14}, // Denormal without rounding + {0x1.4p-15, 0x0.8p-14}, // Round-to-even down + {0x1.Cp-15, 0x1.0p-14}, // Round-to-even up + {0x1.3p-15, 0x0.8p-14}, // Round-to-nearest down + {0x1.5p-15, 0x0.Cp-14}, // Round-to-nearest up + {0x1p-17, 0}, // Largest number that underflows + {0x1.000002p-17, 0x0.4p-14}, // Smallest number that doesn't underflow + {0x1.BFFFFEp-15, 0x0.Cp-14}, // Largest number that rounds to denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(test_case.input); + expected_roundtrip.push_back(test_case.expected_roundtrip); + } + + auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E5M2); + ConvertElementType(f8, F32); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2RoundtripExhaustive) { + // Convert from FP8 to supported floating point type, then back to FP8. + XlaBuilder builder(this->TestName()); - std::vector all_f8; + using From = tsl::float8_e5m2; + std::vector all_f8; for (int i = 0; i < 256; i++) { - all_f8.push_back( - Eigen::numext::bit_cast(static_cast(i))); + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_f8_as_f8 = ConstantR1(&builder, all_f8); - xla::XlaOp all_f8_as_f16 = ConvertElementType(all_f8_as_f8, F16); - ConvertElementType(all_f8_as_f16, F8E5M2); + xla::XlaOp all_f8_as_fp = + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + ConvertElementType(all_f8_as_fp, F8E5M2); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} - // Pass in ErrorSpec, as this causes all NaNs to be treated as equal. - // Round-tripping a NaN will turn it into a quiet NaN and doesn't necessarily - // preserve the payload. - ComputeAndCompareR1(&builder, all_f8, {}, ErrorSpec(0.)); +XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2RoundtripExhaustive2) { + // Convert from supported floating point type to FP8. + XlaBuilder builder(this->TestName()); + if constexpr (std::is_same_v) { + // TODO(b/370786669): Enable this test. + GTEST_SKIP() << "Skipping test for E3M4 as it requires an ml_dtypes " + "release with https://github.com/jax-ml/ml_dtypes/pull/205"; + } else { + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(static_cast( + Eigen::numext::bit_cast(static_cast(i)))); + } + + ConvertElementType(ConstantR1(&builder, all_f8), F8E5M2); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + } } -XLA_TEST_F(ConvertTest, ConvertF8e5m2F32Exhaustive) { - // Convert from f8e5m2 to f32. - XlaBuilder builder(TestName()); +XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2RoundtripExhaustive3) { + // Convert from FP8 to supported floating point type. + XlaBuilder builder(this->TestName()); - std::vector all_f8; - std::vector all_f32; + using From = tsl::float8_e5m2; + std::vector all_f8; for (int i = 0; i < 256; i++) { - all_f8.push_back( - Eigen::numext::bit_cast(static_cast(i))); - all_f32.push_back(tsl::float8_e5m2::ConvertTo(all_f8.back())); + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_f8_as_f8 = ConstantR1(&builder, all_f8); - ConvertElementType(all_f8_as_f8, F32); + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestF16, ConvertF8e5m2F16RoundtripExhaustive4) { + // Convert from (B)F16 to FP8. + XlaBuilder builder(this->TestName()); + + std::vector inputs; + for (int i = 0; i < 65536; i++) { + inputs.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } - ComputeAndCompareR1(&builder, all_f32, {}, ErrorSpec(0.)); + xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); + ConvertElementType(all_f16_to_f8, F8E5M2); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e5m2fnuzF32Exhaustive) { - // Convert from f8e5m2fnuz to f32. +// ----- F8E4M3 + +XLA_TEST_F(ConvertTest, ConvertF16F8e4m3Roundtrip) { + // Convert from FP16 to FP8, then back to FP16 XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); - std::vector all_f8; - std::vector all_f32; - for (int i = 0; i < 256; i++) { - all_f8.push_back( - Eigen::numext::bit_cast(static_cast(i))); - all_f32.push_back(tsl::float8_e5m2fnuz::ConvertTo(all_f8.back())); - } + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {nan, nan}, + {-nan, -nan}, + {inf, inf}, + {-inf, -inf}, + // clang-format on + {0x1.1p0, 0x1p0}, // Round-to-even down + {0x1.3p0, 0x1.4p0}, // Round-to-even up + {0x1.Ep7, 0x1.Ep7}, // Max value + {0x1.EFCp7, 0x1.Ep7}, // Largest number that doesn't overflow + {0x1.Fp7, inf}, // Smallest number that overflows + {0x1p8, inf}, // Overflow + {0x1p-6, 0x1p-6}, // Smallest F8 normal + {0x1.Ep-7, 0x1p-6}, // Smallest number rounding up to normal - xla::XlaOp all_f8_as_f8 = ConstantR1(&builder, all_f8); - ConvertElementType(all_f8_as_f8, F32); + // Denormal tests + {0x0.2p-6, 0x0.2p-6}, // Smallest denormal + {0x0.Ep-6, 0x0.Ep-6}, // Largest denormal + {0x0.8p-6, 0x0.8p-6}, // Denormal without rounding + {0x0.9p-6, 0x0.8p-6}, // Round-to-even down + {0x0.Fp-6, 0x0.8p-5}, // Round-to-even up + {0x0.8Fp-6, 0x0.8p-6}, // Round-to-nearest down + {0x0.91p-6, 0x0.Ap-6}, // Round-to-nearest up + {0x1p-10, 0}, // Largest number that underflows + {0x1.004p-10, 0x0.2p-6}, // Smallest number that doesn't underflow + {0x0.EFCp-6, 0x0.Ep-6}, // Largest number that rounds to denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(Eigen::half{test_case.input}); + expected_roundtrip.push_back(Eigen::half{test_case.expected_roundtrip}); + } - ComputeAndCompareR1(&builder, all_f32, {}, ErrorSpec(0.)); + auto f8 = + ConvertElementType(ConstantR1(&builder, inputs), F8E4M3); + ConvertElementType(f8, F16); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, + ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3F32Exhaustive) { - // Convert from f8e4m3 to f32. +XLA_TEST_F(ConvertTest, DISABLED_ON_CPU(ConvertF32F8e4m3Roundtrip)) { + // Convert from FP32 to FP8, then back to FP32. XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {nan, nan}, + {-nan, -nan}, + {inf, inf}, + {-inf, -inf}, + // clang-format on + {0x1.1p0, 0x1p0}, // Round-to-even down + {0x1.3p0, 0x1.4p0}, // Round-to-even up + {0x1.Ep7, 0x1.Ep7}, // Max value + {0x1.EFFFFEp7, 0x1.Ep7}, // Largest number that doesn't overflow + {0x1.Fp7, inf}, // Smallest number that overflows + {0x1p8, inf}, // Overflow + {0x1p-6, 0x1p-6}, // Smallest F8 normal + {0x1.Ep-7, 0x1p-6}, // Smallest number rounding up to normal - std::vector all_f8; - std::vector all_f32; + // Denormal tests + {0x0.2p-6, 0x0.2p-6}, // Smallest denormal + {0x0.Ep-6, 0x0.Ep-6}, // Largest denormal + {0x0.8p-6, 0x0.8p-6}, // Denormal without rounding + {0x0.9p-6, 0x0.8p-6}, // Round-to-even down + {0x0.Fp-6, 0x0.8p-5}, // Round-to-even up + {0x0.8Fp-6, 0x0.8p-6}, // Round-to-nearest down + {0x0.91p-6, 0x0.Ap-6}, // Round-to-nearest up + {0x1p-10, 0}, // Largest number that underflows + {0x1.000002p-10, 0x0.2p-6}, // Smallest number that doesn't underflow + {0x0.EFFFFEp-6, 0x0.Ep-6}, // Largest number that rounds to denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(test_case.input); + expected_roundtrip.push_back(test_case.expected_roundtrip); + } + + auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E4M3); + ConvertElementType(f8, F32); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3RoundtripExhaustive) { + // Convert from FP8 to supported floating point type, then back to FP8. + XlaBuilder builder(this->TestName()); + + using From = tsl::float8_e4m3; + std::vector all_f8; for (int i = 0; i < 256; i++) { - all_f8.push_back( - Eigen::numext::bit_cast(static_cast(i))); - all_f32.push_back(tsl::float8_e4m3fn::ConvertTo(all_f8.back())); + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_f8_as_f8 = ConstantR1(&builder, all_f8); - ConvertElementType(all_f8_as_f8, F32); + xla::XlaOp all_f8_as_fp = + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + ConvertElementType(all_f8_as_fp, F8E4M3); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3RoundtripExhaustive2) { + // Convert from supported floating point type to FP8. + XlaBuilder builder(this->TestName()); + + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(static_cast( + Eigen::numext::bit_cast(static_cast(i)))); + } - ComputeAndCompareR1(&builder, all_f32, {}, ErrorSpec(0.)); + ConvertElementType(ConstantR1(&builder, all_f8), F8E4M3); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e5m2F16RoundtripExhaustive2) { - // Convert from F16 to FP8. +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3RoundtripExhaustive3) { + // Convert from FP8 to supported floating point type. XlaBuilder builder(this->TestName()); - std::vector inputs; - for (int i = 0; i < 65536; i++) { - inputs.push_back( - Eigen::numext::bit_cast(static_cast(i))); + using From = tsl::float8_e4m3; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); - ConvertElementType(all_f16_to_f8, F8E5M2); + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e5m2BF16RoundtripExhaustive3) { - // Convert from BF16 to FP8. +XLA_TYPED_TEST(ConvertTestF16, ConvertF8e4m3F16RoundtripExhaustive4) { + // Convert from (B)F16 to FP8. XlaBuilder builder(this->TestName()); - std::vector inputs; + std::vector inputs; for (int i = 0; i < 65536; i++) { inputs.push_back( - Eigen::numext::bit_cast(static_cast(i))); + Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_bf16_to_f8 = ConstantR1(&builder, inputs); - ConvertElementType(all_bf16_to_f8, F8E5M2); + xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); + ConvertElementType(all_f16_to_f8, F8E4M3); this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } +// ----- F8E4M3FN + XLA_TEST_F(ConvertTest, ConvertF16F8e4m3fnRoundtrip) { - // Convert from FP16 to FP8, then back to FP16 + // Convert from FP16 to FP8, then back to FP16. XlaBuilder builder(TestName()); float nan = std::numeric_limits::quiet_NaN(); float inf = std::numeric_limits::infinity(); @@ -907,14 +1120,14 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e4m3fnRoundtrip) { {0x1.Ep-7, 0x1p-6}, // Smallest number rounding up to normal // Denormal tests - {0x1.0p-8, 0x1.0p-8}, // Denormal without rounding - {0x1.4p-8, 0x1.0p-8}, // Round-to-even down - {0x1.Cp-8, 0x1.0p-7}, // Round-to-even up - {0x1.5p-7, 0x1.4p-7}, // Round-to-nearest down - {0x1.3p-7, 0x1.4p-7}, // Round-to-nearest up - {0x1p-10, 0}, // Largest number that underflows - {0x1.004p-10, 0x1p-9}, // Smallest number that doesn't underflow - {0x1.DFCp-7, 0x1.Cp-7}, // Largest number that rounds to denormal + {0x1.0p-8, 0x0.4p-6}, // Denormal without rounding + {0x1.4p-8, 0x0.4p-6}, // Round-to-even down + {0x1.Cp-8, 0x0.8p-6}, // Round-to-even up + {0x1.3p-8, 0x0.4p-6}, // Round-to-nearest down + {0x1.5p-8, 0x0.6p-6}, // Round-to-nearest up + {0x1p-10, 0}, // Largest number that underflows + {0x1.004p-10, 0x0.2p-6}, // Smallest number that doesn't underflow + {0x1.DFCp-7, 0x0.Ep-6}, // Largest number that rounds to denormal }; std::vector inputs; @@ -927,95 +1140,124 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e4m3fnRoundtrip) { auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E4M3FN); ConvertElementType(f8, F16); - const bool saved = - execution_options_.debug_options().xla_allow_excess_precision(); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - false); - // Pass in ErrorSpec, as this causes all NaNs to be treated as equal. ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - saved); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3fnF16RoundtripExhaustive) { - // Convert from FP8 to FP16, then back to FP8 +XLA_TEST_F(ConvertTest, DISABLED_ON_CPU(ConvertF32F8e4m3fnRoundtrip)) { + // Convert from FP32 to FP8, then back to FP32. XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); - std::vector all_f8; - for (int i = 0; i < 256; i++) { - all_f8.push_back( - Eigen::numext::bit_cast(static_cast(i))); + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {inf, nan}, + // clang-format on + {0x1.1p0, 0x1p0}, // Round-to-even down + {0x1.3p0, 0x1.4p0}, // Round-to-even up + {0x1.Cp8, 0x1.Cp8}, // Max value + {0x1.Dp8, 0x1.Cp8}, // Largest number that doesn't overflow + {0x1.D00002p8, nan}, // Smallest number that overflows + {0x1p9, nan}, // Overflow + {0x1p-6, 0x1p-6}, // Smallest F8 normal + {0x1.Ep-7, 0x1p-6}, // Smallest number rounding up to normal + + // Denormal tests + {0x1.0p-8, 0x0.4p-6}, // Denormal without rounding + {0x1.4p-8, 0x0.4p-6}, // Round-to-even down + {0x1.Cp-8, 0x0.8p-6}, // Round-to-even up + {0x1.3p-8, 0x0.4p-6}, // Round-to-nearest down + {0x1.5p-8, 0x0.6p-6}, // Round-to-nearest up + {0x1p-10, 0}, // Largest number that underflows + {0x1.000002p-10, 0x0.2p-6}, // Smallest number that doesn't underflow + {0x1.DFFFFEp-7, 0x0.Ep-6}, // Largest number that rounds to denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(test_case.input); + expected_roundtrip.push_back(test_case.expected_roundtrip); } - xla::XlaOp all_f8_as_f8 = ConstantR1(&builder, all_f8); - xla::XlaOp all_f8_as_f16 = ConvertElementType(all_f8_as_f8, F16); - ConvertElementType(all_f8_as_f16, F8E4M3FN); - ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E4M3FN); + ConvertElementType(f8, F32); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3fnF16RoundtripExhaustive2) { - // Convert from FP32 to FP8. - XlaBuilder builder(TestName()); +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3fnRoundtripExhaustive) { + // Convert from FP8 to supported floating point type, then back to FP8. + XlaBuilder builder(this->TestName()); - std::vector all_f8; + using From = tsl::float8_e4m3fn; + std::vector all_f8; for (int i = 0; i < 256; i++) { - all_f8.push_back(static_cast( - Eigen::numext::bit_cast(static_cast(i)))); + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_f8_as_f32 = ConstantR1(&builder, all_f8); - ConvertElementType(all_f8_as_f32, F8E4M3FN); - ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + xla::XlaOp all_f8_as_fp = + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + ConvertElementType(all_f8_as_fp, F8E4M3FN); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3fnF16RoundtripExhaustive3) { - // Convert from FP8 to FP32. - XlaBuilder builder(TestName()); +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3fnRoundtripExhaustive2) { + // Convert from supported floating point type to FP8. + XlaBuilder builder(this->TestName()); - std::vector all_f8; + std::vector all_f8; for (int i = 0; i < 256; i++) { - all_f8.push_back( - Eigen::numext::bit_cast(static_cast(i))); + all_f8.push_back(static_cast( + Eigen::numext::bit_cast(static_cast(i)))); } - xla::XlaOp all_f8_as_f8 = ConstantR1(&builder, all_f8); - ConvertElementType(all_f8_as_f8, F32); - ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + ConvertElementType(ConstantR1(&builder, all_f8), F8E4M3FN); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3fnF16RoundtripExhaustive4) { - // Convert from F16 to FP8. +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3fnRoundtripExhaustive3) { + // Convert from FP8 to supported floating point type. XlaBuilder builder(this->TestName()); - std::vector inputs; - for (int i = 0; i < 65536; i++) { - inputs.push_back( - Eigen::numext::bit_cast(static_cast(i))); + using From = tsl::float8_e4m3fn; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); - ConvertElementType(all_f16_to_f8, F8E4M3FN); + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3fnBF16RoundtripExhaustive5) { - // Convert from BF16 to FP8. +XLA_TYPED_TEST(ConvertTestF16, ConvertF8e4m3fnF16RoundtripExhaustive4) { + // Convert from (B)F16 to FP8. XlaBuilder builder(this->TestName()); - std::vector inputs; + std::vector inputs; for (int i = 0; i < 65536; i++) { inputs.push_back( - Eigen::numext::bit_cast(static_cast(i))); + Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_bf16_to_f8 = ConstantR1(&builder, inputs); - ConvertElementType(all_bf16_to_f8, F8E4M3FN); + xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); + ConvertElementType(all_f16_to_f8, F8E4M3FN); this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } +// ----- F8E4M3B11FNUZ + XLA_TEST_F(ConvertTest, ConvertF16F8e4m3b11fnuzRoundtrip) { - // Convert from FP16 to FP8, then back to FP16 + // Convert from FP16 to FP8, then back to FP16. XlaBuilder builder(TestName()); float nan = std::numeric_limits::quiet_NaN(); float inf = std::numeric_limits::infinity(); @@ -1041,14 +1283,14 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e4m3b11fnuzRoundtrip) { {0x1.Ep-11, 0x1p-10}, // Smallest number rounding up to normal // Denormal tests - {0x1.0p-12, 0x1.0p-12}, // Denormal without rounding - {0x1.4p-12, 0x1.0p-12}, // Round-to-even down - {0x1.Cp-12, 0x1.0p-11}, // Round-to-even up - {0x1.5p-11, 0x1.4p-11}, // Round-to-nearest down - {0x1.3p-11, 0x1.4p-11}, // Round-to-nearest up + {0x1.0p-12, 0x0.4p-10}, // Denormal without rounding + {0x1.4p-12, 0x0.4p-10}, // Round-to-even down + {0x1.Cp-12, 0x0.8p-10}, // Round-to-even up + {0x1.3p-12, 0x0.4p-10}, // Round-to-nearest down + {0x1.5p-12, 0x0.6p-10}, // Round-to-nearest up {0x1p-14, 0}, // Largest number that underflows - {0x1.004p-14, 0x1p-13}, // Smallest number that doesn't underflow - {0x1.DFCp-11, 0x1.Cp-11}, // Largest number that rounds to denormal + {0x1.004p-14, 0x0.2p-10}, // Smallest number that doesn't underflow + {0x1.DFCp-11, 0x0.Ep-10}, // Largest number that rounds to denormal }; std::vector inputs; @@ -1061,67 +1303,125 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e4m3b11fnuzRoundtrip) { auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E4M3B11FNUZ); ConvertElementType(f8, F16); - const bool saved = - execution_options_.debug_options().xla_allow_excess_precision(); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - false); ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - saved); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3b11fnuzF16RoundtripExhaustive) { - // Convert from FP8 to FP16, then back to FP8 +XLA_TEST_F(ConvertTest, DISABLED_ON_CPU(ConvertF32F8e4m3b11fnuzRoundtrip)) { + // Convert from FP32 to FP8, then back to FP32. XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, 0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {inf, nan}, + // clang-format on + {0x1.1p0, 0x1p0}, // Round-to-even down + {0x1.3p0, 0x1.4p0}, // Round-to-even up + {0x1.Ep4, 0x1.Ep4}, // Max value + {0x1.EFFFFEp4, 0x1.Ep4}, // Largest number that doesn't overflow + {0x1.Fp4, nan}, // Smallest number that overflows + {0x1p5, nan}, // Overflow + {0x1p-10, 0x1p-10}, // Smallest F8 normal + {0x1.Ep-11, 0x1p-10}, // Smallest number rounding up to normal - std::vector all_f8; + // Denormal tests + {0x1.0p-12, 0x0.4p-10}, // Denormal without rounding + {0x1.4p-12, 0x0.4p-10}, // Round-to-even down + {0x1.Cp-12, 0x0.8p-10}, // Round-to-even up + {0x1.3p-12, 0x0.4p-10}, // Round-to-nearest down + {0x1.5p-12, 0x0.6p-10}, // Round-to-nearest up + {0x1p-14, 0}, // Largest number that underflows + {0x1.000002p-14, 0x0.2p-10}, // Smallest number that doesn't underflow + {0x1.DFFFFEp-11, 0x0.Ep-10}, // Largest number that rounds to denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(test_case.input); + expected_roundtrip.push_back(test_case.expected_roundtrip); + } + + auto f8 = + ConvertElementType(ConstantR1(&builder, inputs), F8E4M3B11FNUZ); + ConvertElementType(f8, F32); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3b11fnuzRoundtripExhaustive) { + // Convert from FP8 to supported floating point type, then back to FP8. + XlaBuilder builder(this->TestName()); + + using From = tsl::float8_e4m3b11fnuz; + std::vector all_f8; for (int i = 0; i < 256; i++) { - all_f8.push_back(Eigen::numext::bit_cast( - static_cast(i))); + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_f8_as_f8 = - ConstantR1(&builder, all_f8); - xla::XlaOp all_f8_as_f16 = ConvertElementType(all_f8_as_f8, F16); - ConvertElementType(all_f8_as_f16, F8E4M3B11FNUZ); - ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + xla::XlaOp all_f8_as_fp = + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + ConvertElementType(all_f8_as_fp, F8E4M3B11FNUZ); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3b11fnuzF16RoundtripExhaustive2) { - // Convert from FP32 to FP8. - XlaBuilder builder(TestName()); +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3b11fnuzRoundtripExhaustive2) { + // Convert from supported floating point type to FP8. + XlaBuilder builder(this->TestName()); - std::vector all_f8; + std::vector all_f8; for (int i = 0; i < 256; i++) { all_f8.push_back( - static_cast(Eigen::numext::bit_cast( + static_cast(Eigen::numext::bit_cast( static_cast(i)))); } - xla::XlaOp all_f8_as_f32 = ConstantR1(&builder, all_f8); - ConvertElementType(all_f8_as_f32, F8E4M3B11FNUZ); - ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + ConvertElementType(ConstantR1(&builder, all_f8), F8E4M3B11FNUZ); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3b11fnuzF16RoundtripExhaustive3) { - // Convert from FP8 to FP32. - XlaBuilder builder(TestName()); +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3b11fnuzRoundtripExhaustive3) { + // Convert from FP8 to supported floating point type. + XlaBuilder builder(this->TestName()); - std::vector all_f8; + using From = tsl::float8_e4m3b11fnuz; + std::vector all_f8; for (int i = 0; i < 256; i++) { - all_f8.push_back(Eigen::numext::bit_cast( - static_cast(i))); + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestF16, ConvertF8e4m3b11fnuzF16RoundtripExhaustive4) { + // Convert from (B)F16 to FP8. + XlaBuilder builder(this->TestName()); + + std::vector all_f16; + for (int i = 0; i < 65536; i++) { + all_f16.push_back( + Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_f8_as_f8 = - ConstantR1(&builder, all_f8); - ConvertElementType(all_f8_as_f8, F32); - ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + ConvertElementType(ConstantR1(&builder, all_f16), F8E4M3B11FNUZ); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } +// ----- F8E5M2FNUZ + XLA_TEST_F(ConvertTest, ConvertF16F8e5m2fnuzRoundtrip) { - // Convert from FP16 to FP8, then back to FP16 + // Convert from FP16 to FP8, then back to FP16. XlaBuilder builder(TestName()); float nan = std::numeric_limits::quiet_NaN(); float inf = std::numeric_limits::infinity(); @@ -1132,11 +1432,11 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e5m2fnuzRoundtrip) { } test_cases[] = { // clang-format off {0.0, 0.0}, - {-0.0, 0.0}, // No signed zero in F8E5M2FNUZ + {-0.0, 0.0}, {1.0, 1.0}, {-1.0, -1.0}, {nan, nan}, - {inf, nan}, // No Inf in F8E4M3FNUZ + {inf, nan}, // clang-format on {0x1.2p0, 0x1p0}, // Round-to-even down {0x1.6p0, 0x1.8p0}, // Round-to-even up @@ -1148,14 +1448,14 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e5m2fnuzRoundtrip) { {0x1.Cp-16, 0x1p-15}, // Smallest number rounding up to normal // Denormal tests - {0x1.0p-16, 0x1.0p-16}, // Denormal without rounding - {0x1.4p-16, 0x1.0p-16}, // Round-to-even down - {0x1.Cp-16, 0x1.0p-15}, // Round-to-even up - {0x1.3p-16, 0x1.0p-16}, // Round-to-nearest down - {0x1.5p-16, 0x1.8p-16}, // Round-to-nearest up - {0x1p-18, 0}, // Largest number that underflows - {0x1.04p-18, 0x1p-17}, // Smallest number that doesn't underflow - {0x1.BFp-16, 0x1.8p-16}, // Largest number that rounds to denormal + {0x0.4p-14, 0x0.8p-15}, // Denormal without rounding + {0x0.5p-14, 0x0.8p-15}, // Round-to-even down + {0x0.7p-14, 0x1.0p-15}, // Round-to-even up + {0x0.4Cp-14, 0x0.8p-15}, // Round-to-nearest down + {0x0.54p-14, 0x0.Cp-15}, // Round-to-nearest up + {0x0.1p-14, 0}, // Largest number that underflows + {0x0.104p-14, 0x0.4p-15}, // Smallest number that doesn't underflow + {0x0.6FCp-14, 0x0.Cp-15}, // Largest number that rounds to denormal }; std::vector inputs; @@ -1168,18 +1468,12 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e5m2fnuzRoundtrip) { auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E5M2FNUZ); ConvertElementType(f8, F16); - const bool saved = - execution_options_.debug_options().xla_allow_excess_precision(); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - false); ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - saved); } XLA_TEST_F(ConvertTest, ConvertF32F8e5m2fnuzRoundtrip) { - // Convert from FP32 to FP8, then back to FP32 + // Convert from FP32 to FP8, then back to FP32. XlaBuilder builder(TestName()); float nan = std::numeric_limits::quiet_NaN(); float inf = std::numeric_limits::infinity(); @@ -1190,11 +1484,11 @@ XLA_TEST_F(ConvertTest, ConvertF32F8e5m2fnuzRoundtrip) { } test_cases[] = { // clang-format off {0.0, 0.0}, - {-0.0, 0.0}, // No signed zero in F8E5M2FNUZ + {-0.0, 0.0}, {1.0, 1.0}, {-1.0, -1.0}, {nan, nan}, - {inf, nan}, // No Inf in F8E4M3FNUZ + {inf, nan}, // clang-format on {0x1.2p0, 0x1p0}, // Round-to-even down {0x1.6p0, 0x1.8p0}, // Round-to-even up @@ -1206,15 +1500,14 @@ XLA_TEST_F(ConvertTest, ConvertF32F8e5m2fnuzRoundtrip) { {0x1.Cp-16, 0x1p-15}, // Smallest number rounding up to normal // Denormal tests - {0x1.0p-16, 0x1.0p-16}, // Denormal without rounding - {0x1.4p-16, 0x1.0p-16}, // Round-to-even down + {0x1.0p-16, 0x0.8p-15}, // Denormal without rounding + {0x1.4p-16, 0x0.8p-15}, // Round-to-even down {0x1.Cp-16, 0x1.0p-15}, // Round-to-even up - {0x1.3FFFFEp-16, 0x1.0p-16}, // Round-to-nearest down - {0x1.5FFFFEp-16, 0x1.8p-16}, // Round-to-nearest up + {0x1.3p-16, 0x0.8p-15}, // Round-to-nearest down + {0x1.5p-16, 0x0.Cp-15}, // Round-to-nearest up {0x1p-18, 0}, // Largest number that underflows - {0x1.000002p-18, 0x1p-17}, // Smallest number that doesn't underflow - {0x1.BFFFFEp-16, 0x1.8p-16}, // Largest number that rounds to denormal - {0x1.FFFFFEp-50, 0}, // A very small input that should underflow + {0x1.000002p-18, 0x0.4p-15}, // Smallest number that doesn't underflow + {0x1.BFFFFEp-16, 0x0.Cp-15}, // Largest number that rounds to denormal }; std::vector inputs; @@ -1226,110 +1519,80 @@ XLA_TEST_F(ConvertTest, ConvertF32F8e5m2fnuzRoundtrip) { auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E5M2FNUZ); ConvertElementType(f8, F32); - const bool saved = - execution_options_.debug_options().xla_allow_excess_precision(); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - false); ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - saved); } -XLA_TEST_F(ConvertTest, ConvertF8e5m2fnuzRoundtripExhaustive) { - // Convert from FP8 to each supported floating type, then back to FP8. - XlaBuilder builder(TestName()); +XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2fnuzRoundtripExhaustive) { + // Convert from FP8 to supported floating point type, then back to FP8. + XlaBuilder builder(this->TestName()); - std::vector all_f8; + using From = tsl::float8_e5m2fnuz; + std::vector all_f8; for (int i = 0; i < 256; i++) { - all_f8.push_back(Eigen::numext::bit_cast( - static_cast(i))); + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); } - const bool saved = - execution_options_.debug_options().xla_allow_excess_precision(); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - false); - - for (auto type : {F8E4M3B11FNUZ, F8E4M3FN, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, - F16, BF16, F32, F64}) { - xla::XlaOp all_f8_as_f8 = - ConstantR1(&builder, all_f8); - xla::XlaOp all_f8_as_type = ConvertElementType(all_f8_as_f8, type); - ConvertElementType(all_f8_as_type, F8E5M2FNUZ); - ComputeAndCompare(&builder, {}, ErrorSpec(0.)); - } - - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - saved); + xla::XlaOp all_f8_as_fp = + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + ConvertElementType(all_f8_as_fp, F8E5M2FNUZ); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2fnuzRoundtripExhaustive2) { // Convert from supported floating point type to FP8. XlaBuilder builder(this->TestName()); - std::vector all_f8; - for (int i = 0; i < 256; i++) { - all_f8.push_back( - static_cast(Eigen::numext::bit_cast( - static_cast(i)))); - } - - xla::XlaOp all_f8_as_f32 = ConstantR1(&builder, all_f8); - ConvertElementType(all_f8_as_f32, F8E5M2FNUZ); - this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); -} - -XLA_TEST_F(ConvertTest, ConvertF8e5m2fnuzRoundtripExhaustive3) { - // Convert from FP8 to supported floating point types. - XlaBuilder builder(TestName()); - - std::vector all_f8; - for (int i = 0; i < 256; i++) { - all_f8.push_back( - Eigen::numext::bit_cast(static_cast(i))); - } + if constexpr (std::is_same_v) { + // TODO(b/370786669): Enable this test. + GTEST_SKIP() << "Skipping test for E3M4 as it requires an ml_dtypes " + "release with https://github.com/jax-ml/ml_dtypes/pull/205"; + } else { + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back( + static_cast(Eigen::numext::bit_cast( + static_cast(i)))); + } - for (auto type : {F8E4M3FN, F8E4M3B11FNUZ, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, - F16, BF16, F32, F64}) { - xla::XlaOp all_f8_as_f8 = - ConstantR1(&builder, all_f8); - ConvertElementType(all_f8_as_f8, type); - ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + ConvertElementType(ConstantR1(&builder, all_f8), F8E5M2FNUZ); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } } -XLA_TEST_F(ConvertTest, ConvertF8e5m2fnuzF16RoundtripExhaustive4) { - // Convert from F16 to FP8. +XLA_TYPED_TEST(ConvertTestT, ConvertF8e5m2fnuzRoundtripExhaustive3) { + // Convert from FP8 to supported floating point type. XlaBuilder builder(this->TestName()); - std::vector inputs; - for (int i = 0; i < 65536; i++) { - inputs.push_back( - Eigen::numext::bit_cast(static_cast(i))); + using From = tsl::float8_e5m2fnuz; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); - ConvertElementType(all_f16_to_f8, F8E5M2FNUZ); + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e5m2fnuzBF16RoundtripExhaustive5) { - // Convert from BF16 to FP8. +XLA_TYPED_TEST(ConvertTestF16, ConvertF8e5m2fnuzF16RoundtripExhaustive4) { + // Convert from (B)F16 to FP8. XlaBuilder builder(this->TestName()); - std::vector inputs; + std::vector all_f16; for (int i = 0; i < 65536; i++) { - inputs.push_back( - Eigen::numext::bit_cast(static_cast(i))); + all_f16.push_back( + Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_bf16_to_f8 = ConstantR1(&builder, inputs); - ConvertElementType(all_bf16_to_f8, F8E5M2FNUZ); + ConvertElementType(ConstantR1(&builder, all_f16), F8E5M2FNUZ); this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } +// ----- F8E4M3FNUZ + XLA_TEST_F(ConvertTest, ConvertF16F8e4m3fnuzRoundtrip) { - // Convert from FP16 to FP8, then back to FP16 + // Convert from FP16 to FP8, then back to FP16. XlaBuilder builder(TestName()); float nan = std::numeric_limits::quiet_NaN(); float inf = std::numeric_limits::infinity(); @@ -1340,10 +1603,10 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e4m3fnuzRoundtrip) { } test_cases[] = { // clang-format off {0.0, 0.0}, - {-0.0, 0.0}, // No signed zero in F8E4M3FNUZ + {-0.0, 0.0}, {1.0, 1.0}, {-1.0, -1.0}, - {inf, nan}, // No Inf in F8E4M3FNUZ + {inf, nan}, // clang-format on {0x1.1p0, 0x1p0}, // Round-to-even down {0x1.3p0, 0x1.4p0}, // Round-to-even up @@ -1355,14 +1618,14 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e4m3fnuzRoundtrip) { {0x1.Ep-8, 0x1p-7}, // Smallest number rounding up to normal // Denormal tests - {0x1.0p-9, 0x1.0p-9}, // Denormal without rounding - {0x1.4p-9, 0x1.0p-9}, // Round-to-even down - {0x1.Cp-9, 0x1.0p-8}, // Round-to-even up - {0x1.5p-8, 0x1.4p-8}, // Round-to-nearest down - {0x1.3p-8, 0x1.4p-8}, // Round-to-nearest up - {0x1p-11, 0}, // Largest number that underflows - {0x1.004p-11, 0x1p-10}, // Smallest number that doesn't underflow - {0x1.DFCp-8, 0x1.Cp-8}, // Largest number that rounds to denormal + {0x1.0p-9, 0x0.4p-7}, // Denormal without rounding + {0x1.4p-9, 0x0.4p-7}, // Round-to-even down + {0x1.Cp-9, 0x0.8p-7}, // Round-to-even up + {0x1.3p-9, 0x0.4p-7}, // Round-to-nearest down + {0x1.5p-9, 0x0.6p-7}, // Round-to-nearest up + {0x1p-11, 0}, // Largest number that underflows + {0x1.004p-11, 0x0.2p-7}, // Smallest number that doesn't underflow + {0x1.DFCp-8, 0x0.Ep-7}, // Largest number that rounds to denormal }; std::vector inputs; @@ -1375,18 +1638,12 @@ XLA_TEST_F(ConvertTest, ConvertF16F8e4m3fnuzRoundtrip) { auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E4M3FNUZ); ConvertElementType(f8, F16); - const bool saved = - execution_options_.debug_options().xla_allow_excess_precision(); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - false); ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - saved); } XLA_TEST_F(ConvertTest, ConvertF32F8e4m3fnuzRoundtrip) { - // Convert from FP16 to FP8, then back to FP16 + // Convert from FP32 to FP8, then back to FP32. XlaBuilder builder(TestName()); float nan = std::numeric_limits::quiet_NaN(); float inf = std::numeric_limits::infinity(); @@ -1397,10 +1654,10 @@ XLA_TEST_F(ConvertTest, ConvertF32F8e4m3fnuzRoundtrip) { } test_cases[] = { // clang-format off {0.0, 0.0}, - {-0.0, 0.0}, // No signed zero in F8E4M3FNUZ + {-0.0, 0.0}, {1.0, 1.0}, {-1.0, -1.0}, - {inf, nan}, // No Inf in F8E4M3FNUZ + {inf, nan}, // clang-format on {0x1.1p0, 0x1p0}, // Round-to-even down {0x1.3p0, 0x1.4p0}, // Round-to-even up @@ -1412,15 +1669,14 @@ XLA_TEST_F(ConvertTest, ConvertF32F8e4m3fnuzRoundtrip) { {0x1.Ep-8, 0x1p-7}, // Smallest number rounding up to normal // Denormal tests - {0x1.0p-9, 0x1.0p-9}, // Denormal without rounding - {0x1.4p-9, 0x1.0p-9}, // Round-to-even down - {0x1.Cp-9, 0x1.0p-8}, // Round-to-even up - {0x1.5p-8, 0x1.4p-8}, // Round-to-nearest down - {0x1.3p-8, 0x1.4p-8}, // Round-to-nearest up - {0x1p-11, 0}, // Largest number that underflows - {0x1.000002p-11, 0x1p-10}, // Smallest number that doesn't underflow - {0x1.DFFFFEp-8, 0x1.Cp-8}, // Largest number that rounds to denormal - {0x1.FFFFFEp-50, 0}, // A very small input that should underflow + {0x1.0p-9, 0x0.4p-7}, // Denormal without rounding + {0x1.4p-9, 0x0.4p-7}, // Round-to-even down + {0x1.Cp-9, 0x0.8p-7}, // Round-to-even up + {0x1.3p-9, 0x0.4p-7}, // Round-to-nearest down + {0x1.5p-9, 0x0.6p-7}, // Round-to-nearest up + {0x1p-11, 0}, // Largest number that underflows + {0x1.000002p-11, 0x0.2p-7}, // Smallest number that doesn't underflow + {0x1.DFFFFEp-8, 0x0.Ep-7}, // Largest number that rounds to denormal }; std::vector inputs; @@ -1432,45 +1688,28 @@ XLA_TEST_F(ConvertTest, ConvertF32F8e4m3fnuzRoundtrip) { auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E4M3FNUZ); ConvertElementType(f8, F32); - const bool saved = - execution_options_.debug_options().xla_allow_excess_precision(); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - false); ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - saved); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3fnuzRoundtripExhaustive) { - // Convert from FP8 to each supported floating type, then back to FP8. - XlaBuilder builder(TestName()); +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3fnuzRoundtripExhaustive) { + // Convert from FP8 to supported floating point type, then back to FP8. + XlaBuilder builder(this->TestName()); - std::vector all_f8; + using From = tsl::float8_e4m3fnuz; + std::vector all_f8; for (int i = 0; i < 256; i++) { - all_f8.push_back( - Eigen::numext::bit_cast(static_cast(i))); - } - - const bool saved = - execution_options_.debug_options().xla_allow_excess_precision(); - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - false); - - for (auto type : {F8E4M3FN, F8E4M3B11FNUZ, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, - F16, BF16, F32, F64}) { - xla::XlaOp all_f8_as_f8 = - ConstantR1(&builder, all_f8); - xla::XlaOp all_f8_as_type = ConvertElementType(all_f8_as_f8, type); - ConvertElementType(all_f8_as_type, F8E4M3FNUZ); - ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); } - execution_options_.mutable_debug_options()->set_xla_allow_excess_precision( - saved); + xla::XlaOp all_f8_as_fp = + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + ConvertElementType(all_f8_as_fp, F8E4M3FNUZ); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3fnuzRoundtripExhaustive2) { - // Convert from support floating types to FP8. + // Convert from supported floating point type to FP8. XlaBuilder builder(this->TestName()); std::vector all_f8; @@ -1480,98 +1719,210 @@ XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3fnuzRoundtripExhaustive2) { static_cast(i)))); } - xla::XlaOp all_f8_as_f32 = ConstantR1(&builder, all_f8); - ConvertElementType(all_f8_as_f32, F8E4M3FNUZ); + ConvertElementType(ConstantR1(&builder, all_f8), F8E4M3FNUZ); this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3fnuzRoundtripExhaustive3) { - // Convert from FP8 to supported floating point types. - XlaBuilder builder(TestName()); +XLA_TYPED_TEST(ConvertTestT, ConvertF8e4m3fnuzRoundtripExhaustive3) { + // Convert from FP8 to supported floating point type. + XlaBuilder builder(this->TestName()); - std::vector all_f8; + using From = tsl::float8_e4m3fnuz; + std::vector all_f8; for (int i = 0; i < 256; i++) { - all_f8.push_back( - Eigen::numext::bit_cast(static_cast(i))); + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); } - for (auto type : {F8E4M3FN, F8E4M3B11FNUZ, F8E4M3FNUZ, F8E5M2, F8E5M2FNUZ, - F16, BF16, F32, F64}) { - xla::XlaOp all_f8_as_f8 = - ConstantR1(&builder, all_f8); - ConvertElementType(all_f8_as_f8, type); - ComputeAndCompare(&builder, {}, ErrorSpec(0.)); - } + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3fnuzF16RoundtripExhaustive4) { - // Convert from F16 to FP8. +XLA_TYPED_TEST(ConvertTestF16, ConvertF8e4m3fnuzF16RoundtripExhaustive4) { + // Convert from (B)F16 to FP8. XlaBuilder builder(this->TestName()); - std::vector inputs; + std::vector all_f16; for (int i = 0; i < 65536; i++) { - inputs.push_back( - Eigen::numext::bit_cast(static_cast(i))); + all_f16.push_back( + Eigen::numext::bit_cast(static_cast(i))); } - xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); - ConvertElementType(all_f16_to_f8, F8E4M3FNUZ); + ConvertElementType(ConstantR1(&builder, all_f16), F8E4M3FNUZ); this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3fnuzBF16RoundtripExhaustive5) { - // Convert from BF16 to FP8. - XlaBuilder builder(this->TestName()); +// ----- F8E3M4 - std::vector inputs; - for (int i = 0; i < 65536; i++) { - inputs.push_back( - Eigen::numext::bit_cast(static_cast(i))); +XLA_TEST_F(ConvertTest, ConvertF16F8e3m4Roundtrip) { + // Convert from FP16 to FP8, then back to FP16 + XlaBuilder builder(TestName()); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {nan, nan}, + {-nan, -nan}, + {inf, inf}, + {-inf, -inf}, + // clang-format on + {0x1.08p0, 0x1p0}, // Round-to-even down + {0x1.18p0, 0x1.2p0}, // Round-to-even up + {0x1.Fp3, 0x1.Fp3}, // Max value + {0x1.F7Cp3, 0x1.Fp3}, // Largest number that doesn't overflow + {0x1.F8p3, inf}, // Smallest number that overflows + {0x1p4, inf}, // Overflow + {0x1p-2, 0x1p-2}, // Smallest F8 normal + {0x1.Fp-3, 0x1p-2}, // Smallest number rounding up to normal + + // Denormal tests + {0x0.1p-2, 0x0.1p-2}, // Smallest denormal + {0x0.Fp-2, 0x0.Fp-2}, // Largest denormal + {0x0.8p-2, 0x0.8p-2}, // Denormal without rounding + {0x0.88p-2, 0x0.8p-2}, // Round-to-even down + {0x0.F8p-2, 0x0.8p-1}, // Round-to-even up + {0x0.87p-2, 0x0.8p-2}, // Round-to-nearest down + {0x0.89p-2, 0x0.9p-2}, // Round-to-nearest up + {0x1p-7, 0}, // Largest number that underflows + {0x1.004p-7, 0x0.1p-2}, // Smallest number that doesn't underflow + {0x0.F7Cp-2, 0x0.Fp-2}, // Largest number that rounds to denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(Eigen::half{test_case.input}); + expected_roundtrip.push_back(Eigen::half{test_case.expected_roundtrip}); } - xla::XlaOp all_bf16_to_f8 = ConstantR1(&builder, inputs); - ConvertElementType(all_bf16_to_f8, F8E4M3FNUZ); - this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); + auto f8 = + ConvertElementType(ConstantR1(&builder, inputs), F8E3M4); + ConvertElementType(f8, F16); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, + ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e5m2ToPred) { +XLA_TEST_F(ConvertTest, DISABLED_ON_CPU(ConvertF32F8e3m4Roundtrip)) { + // Convert from FP32 to FP8, then back to FP32. XlaBuilder builder(TestName()); - using F8 = tsl::float8_e5m2; - auto a = ConstantR1(&builder, {F8{0.0}, F8{0.25}, F8{2.0}}); - ConvertElementType(a, PRED); + float nan = std::numeric_limits::quiet_NaN(); + float inf = std::numeric_limits::infinity(); - std::array expected = {false, true, true}; - ComputeAndCompareR1(&builder, expected, {}); + struct TestCase { + float input; + float expected_roundtrip; + } test_cases[] = { + // clang-format off + {0.0, 0.0}, + {-0.0, -0.0}, + {1.0, 1.0}, + {-1.0, -1.0}, + {nan, nan}, + {-nan, -nan}, + {inf, inf}, + {-inf, -inf}, + // clang-format on + {0x1.08p0, 0x1p0}, // Round-to-even down + {0x1.18p0, 0x1.2p0}, // Round-to-even up + {0x1.Fp3, 0x1.Fp3}, // Max value + {0x1.F7FFFEp3, 0x1.Fp3}, // Largest number that doesn't overflow + {0x1.F8p3, inf}, // Smallest number that overflows + {0x1p4, inf}, // Overflow + {0x1p-2, 0x1p-2}, // Smallest F8 normal + {0x1.Fp-3, 0x1p-2}, // Smallest number rounding up to normal + + // Denormal tests + {0x0.1p-2, 0x0.1p-2}, // Smallest denormal + {0x0.Fp-2, 0x0.Fp-2}, // Largest denormal + {0x0.8p-2, 0x0.8p-2}, // Denormal without rounding + {0x0.88p-2, 0x0.8p-2}, // Round-to-even down + {0x0.F8p-2, 0x0.8p-1}, // Round-to-even up + {0x0.87p-2, 0x0.8p-2}, // Round-to-nearest down + {0x0.89p-2, 0x0.9p-2}, // Round-to-nearest up + {0x1p-7, 0}, // Largest number that underflows + {0x1.000002p-7, 0x0.1p-2}, // Smallest number that doesn't underflow + {0x0.F7FFFEp-2, 0x0.Fp-2}, // Largest number that rounds to denormal + }; + + std::vector inputs; + std::vector expected_roundtrip; + for (auto test_case : test_cases) { + inputs.push_back(test_case.input); + expected_roundtrip.push_back(test_case.expected_roundtrip); + } + + auto f8 = ConvertElementType(ConstantR1(&builder, inputs), F8E3M4); + ConvertElementType(f8, F32); + ComputeAndCompareR1(&builder, expected_roundtrip, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3fnToPred) { - XlaBuilder builder(TestName()); - using F8 = tsl::float8_e4m3fn; - auto a = ConstantR1(&builder, {F8{0.0}, F8{0.25}, F8{2.0}}); - ConvertElementType(a, PRED); +XLA_TYPED_TEST(ConvertTestT, ConvertF8e3m4RoundtripExhaustive) { + // Convert from FP8 to supported floating point type, then back to FP8. + XlaBuilder builder(this->TestName()); - std::array expected = {false, true, true}; - ComputeAndCompareR1(&builder, expected, {}); + using From = tsl::float8_e3m4; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_f8_as_fp = + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + ConvertElementType(all_f8_as_fp, F8E3M4); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e5m2fnuzToPred) { - XlaBuilder builder(TestName()); - using F8 = tsl::float8_e5m2fnuz; - auto a = ConstantR1(&builder, {F8{0.0}, F8{0.25}, F8{2.0}}); - ConvertElementType(a, PRED); +XLA_TYPED_TEST(ConvertTestT, ConvertF8e3m4RoundtripExhaustive2) { + // Convert from supported floating point type to FP8. + XlaBuilder builder(this->TestName()); - std::array expected = {false, true, true}; - ComputeAndCompareR1(&builder, expected, {}); + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(static_cast( + Eigen::numext::bit_cast(static_cast(i)))); + } + + ConvertElementType(ConstantR1(&builder, all_f8), F8E3M4); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } -XLA_TEST_F(ConvertTest, ConvertF8e4m3fnuzToPred) { - XlaBuilder builder(TestName()); - using F8 = tsl::float8_e4m3fnuz; - auto a = ConstantR1(&builder, {F8{0.0}, F8{0.25}, F8{2.0}}); - ConvertElementType(a, PRED); +XLA_TYPED_TEST(ConvertTestT, ConvertF8e3m4RoundtripExhaustive3) { + // Convert from FP8 to supported floating point type. + XlaBuilder builder(this->TestName()); - std::array expected = {false, true, true}; - ComputeAndCompareR1(&builder, expected, {}); + using From = tsl::float8_e3m4; + std::vector all_f8; + for (int i = 0; i < 256; i++) { + all_f8.push_back(Eigen::numext::bit_cast(static_cast(i))); + } + + ConvertElementType(ConstantR1(&builder, all_f8), + primitive_util::NativeToPrimitiveType()); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); +} + +XLA_TYPED_TEST(ConvertTestF16, ConvertF8e3m4F16RoundtripExhaustive4) { + // Convert from (B)F16 to FP8. + XlaBuilder builder(this->TestName()); + + std::vector inputs; + for (int i = 0; i < 65536; i++) { + inputs.push_back( + Eigen::numext::bit_cast(static_cast(i))); + } + + xla::XlaOp all_f16_to_f8 = ConstantR1(&builder, inputs); + ConvertElementType(all_f16_to_f8, F8E3M4); + this->ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } } // namespace diff --git a/third_party/xla/xla/tests/convolution_dimension_numbers_test.cc b/third_party/xla/xla/tests/convolution_dimension_numbers_test.cc index da3909a199ac1c..557f4046ca4e82 100644 --- a/third_party/xla/xla/tests/convolution_dimension_numbers_test.cc +++ b/third_party/xla/xla/tests/convolution_dimension_numbers_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/array4d.h" #include "xla/client/local_client.h" -#include "xla/client/padding.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/reference_util.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" diff --git a/third_party/xla/xla/tests/convolution_test.cc b/third_party/xla/xla/tests/convolution_test.cc index 56703df13727da..cb8841d1d5109d 100644 --- a/third_party/xla/xla/tests/convolution_test.cc +++ b/third_party/xla/xla/tests/convolution_test.cc @@ -24,8 +24,8 @@ limitations under the License. #include "xla/array4d.h" #include "xla/client/global_data.h" #include "xla/client/local_client.h" -#include "xla/client/padding.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/tests/convolution_test_1d.cc b/third_party/xla/xla/tests/convolution_test_1d.cc index a0f93d9cc70372..502a47ffcdcce8 100644 --- a/third_party/xla/xla/tests/convolution_test_1d.cc +++ b/third_party/xla/xla/tests/convolution_test_1d.cc @@ -24,8 +24,8 @@ limitations under the License. #include "xla/array4d.h" #include "xla/client/global_data.h" #include "xla/client/local_client.h" -#include "xla/client/padding.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/reference_util.h" diff --git a/third_party/xla/xla/tests/convolution_variants_test.cc b/third_party/xla/xla/tests/convolution_variants_test.cc index 6fb47d1e15640c..719e9b3d80e8bd 100644 --- a/third_party/xla/xla/tests/convolution_variants_test.cc +++ b/third_party/xla/xla/tests/convolution_variants_test.cc @@ -26,8 +26,8 @@ limitations under the License. #include "xla/array3d.h" #include "xla/array4d.h" #include "xla/client/local_client.h" -#include "xla/client/padding.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/reference_util.h" #include "xla/tests/client_library_test_base.h" diff --git a/third_party/xla/xla/tests/copy_test.cc b/third_party/xla/xla/tests/copy_test.cc index 91b7fa2a1473c8..36b7e0815a844f 100644 --- a/third_party/xla/xla/tests/copy_test.cc +++ b/third_party/xla/xla/tests/copy_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/array3d.h" #include "xla/array4d.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -31,8 +31,10 @@ limitations under the License. #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/literal_util.h" +#include "xla/service/platform_util.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/stream_executor/platform.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" @@ -45,6 +47,8 @@ namespace { class CopyOpTest : public HloTestBase { protected: + CopyOpTest() : platform_(*PlatformUtil::GetDefaultPlatform()) {} + void TestCopyOp(const Literal& literal) { auto builder = HloComputation::Builder(TestName()); auto constant = @@ -81,6 +85,11 @@ class CopyOpTest : public HloTestBase { void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3); void TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, size_t n4, absl::Span permutation); + + se::Platform* platform() const { return platform_; } + + private: + se::Platform* platform_; }; XLA_TEST_F(CopyOpTest, CopyR0Bool) { @@ -97,7 +106,7 @@ XLA_TEST_F(CopyOpTest, CopyR1S3U32) { XLA_TEST_F(CopyOpTest, CopyDynamicR1S1310720U32Dynamic0) { // TODO(vsytch): CPU emitter doesn't handle dynamic shapes. - if (backend().platform()->Name() == "Host") { + if (platform()->Name() == "Host") { GTEST_SKIP(); } Shape bounded_shape = @@ -110,7 +119,7 @@ XLA_TEST_F(CopyOpTest, CopyDynamicR1S1310720U32Dynamic0) { XLA_TEST_F(CopyOpTest, CopyDynamicR1S1310720U32Dynamic106632) { // TODO(vsytch): CPU emitter doesn't handle dynamic shapes. - if (backend().platform()->Name() == "Host") { + if (platform()->Name() == "Host") { GTEST_SKIP(); } Shape bounded_shape = @@ -124,7 +133,7 @@ XLA_TEST_F(CopyOpTest, CopyDynamicR1S1310720U32Dynamic106632) { XLA_TEST_F(CopyOpTest, CopyDynamicR1S1310720U32Dynamic1310720) { // TODO(vsytch): CPU emitter doesn't handle dynamic shapes. - if (backend().platform()->Name() == "Host") { + if (platform()->Name() == "Host") { GTEST_SKIP(); } Shape bounded_shape = @@ -138,7 +147,7 @@ XLA_TEST_F(CopyOpTest, CopyDynamicR1S1310720U32Dynamic1310720) { XLA_TEST_F(CopyOpTest, CopyDynamicR1S512U32Dynamic64) { // TODO(vsytch): CPU emitter doesn't handle dynamic shapes. - if (backend().platform()->Name() == "Host") { + if (platform()->Name() == "Host") { GTEST_SKIP(); } Shape bounded_shape = ShapeUtil::MakeShape(PrimitiveType::F32, {512}, {true}); diff --git a/third_party/xla/xla/tests/cpu_gpu_fusion_test.cc b/third_party/xla/xla/tests/cpu_gpu_fusion_test.cc index d53699e477a885..af903c53dabe53 100644 --- a/third_party/xla/xla/tests/cpu_gpu_fusion_test.cc +++ b/third_party/xla/xla/tests/cpu_gpu_fusion_test.cc @@ -27,7 +27,7 @@ limitations under the License. #include "unsupported/Eigen/CXX11/Tensor" #include "xla/array2d.h" #include "xla/client/client_library.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -138,7 +138,7 @@ class CpuGpuFusionTest : public HloTestBase { absl::Span xs); bool ComputeElementwiseAnswerCompare(ComparisonDirection direction, absl::Span xs); - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.add_xla_disable_hlo_passes("layout-assignment"); return debug_options; diff --git a/third_party/xla/xla/tests/custom_call_test.cc b/third_party/xla/xla/tests/custom_call_test.cc index 7e03bd9f2971a7..bca61939a79f3e 100644 --- a/third_party/xla/xla/tests/custom_call_test.cc +++ b/third_party/xla/xla/tests/custom_call_test.cc @@ -38,11 +38,11 @@ limitations under the License. #include "xla/array3d.h" #include "xla/client/client_library.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" #include "xla/executable_run_options.h" #include "xla/ffi/execution_context.h" #include "xla/ffi/ffi.h" #include "xla/ffi/ffi_api.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" diff --git a/third_party/xla/xla/tests/deallocation_test.cc b/third_party/xla/xla/tests/deallocation_test.cc index 42a5401c936c0c..213e3f05ed9931 100644 --- a/third_party/xla/xla/tests/deallocation_test.cc +++ b/third_party/xla/xla/tests/deallocation_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "absl/types/span.h" #include "xla/client/global_data.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" diff --git a/third_party/xla/xla/tests/deconstruct_tuple_test.cc b/third_party/xla/xla/tests/deconstruct_tuple_test.cc index 8eb5d15a1cd5d1..e5579e7abc4e20 100644 --- a/third_party/xla/xla/tests/deconstruct_tuple_test.cc +++ b/third_party/xla/xla/tests/deconstruct_tuple_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include "absl/types/span.h" #include "xla/client/global_data.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/literal.h" #include "xla/shape_util.h" #include "xla/test.h" diff --git a/third_party/xla/xla/tests/deep_graph_test.cc b/third_party/xla/xla/tests/deep_graph_test.cc index eed8303b6e0a48..2024fe697e8848 100644 --- a/third_party/xla/xla/tests/deep_graph_test.cc +++ b/third_party/xla/xla/tests/deep_graph_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/tests/client_library_test_base.h" namespace xla { diff --git a/third_party/xla/xla/tests/dot_operation_test.cc b/third_party/xla/xla/tests/dot_operation_test.cc index 4c3c728f6e1fda..1dcaa6318ee4ae 100644 --- a/third_party/xla/xla/tests/dot_operation_test.cc +++ b/third_party/xla/xla/tests/dot_operation_test.cc @@ -20,19 +20,21 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "xla/array2d.h" #include "xla/array3d.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/matrix.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/error_spec.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/literal_util.h" #include "xla/primitive_util.h" #include "xla/reference_util.h" -#include "xla/service/hlo_parser.h" #include "xla/shape_util.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" -#include "xla/tests/test_utils.h" +#include "tsl/platform/ml_dtypes.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" @@ -312,6 +314,7 @@ class ParametricDotTest : public DotOperationTest, std::string_view name( ::testing::UnitTest::GetInstance()->current_test_info()->name()); if (name.find("TestF16/270x270x520_MajorToMinor") != std::string::npos) { + GTEST_SKIP() << "Not supported on ROCm until Triton is re-enabled."; execution_options_.mutable_debug_options()->set_xla_gpu_autotune_level( 0); DotTestParam param = GetParam(); @@ -364,6 +367,27 @@ void ParametricDotTest::ComputeAndCompareR2WithError( ComputeAndCompareR2(builder, expected, arguments); } +template <> +void ParametricDotTest::ComputeAndCompareR2WithError( + XlaBuilder* builder, const Array2D& expected, + absl::Span arguments) { + ErrorSpec error_spec(0.3, 3e-3); + error_spec.low_precision_fp_error_spec.type = + primitive_util::NativeToPrimitiveType(); + error_spec.low_precision_fp_error_spec.within_n_values = 1; + ComputeAndCompareR2(builder, expected, arguments, error_spec); +} + +template <> +void ParametricDotTest::ComputeAndCompareR2WithError( + XlaBuilder* builder, const Array2D& expected, + absl::Span arguments) { + ErrorSpec error_spec(0.3, 3e-3); + error_spec.low_precision_fp_error_spec.type = + primitive_util::NativeToPrimitiveType(); + error_spec.low_precision_fp_error_spec.within_n_values = 1; + ComputeAndCompareR2(builder, expected, arguments, error_spec); +} template void ParametricDotTest::TestImpl() { DotTestParam param = GetParam(); @@ -486,6 +510,8 @@ XLA_TEST_P(ParametricDotTest, TestC64) { TestImpl>(); } XLA_TEST_P(ParametricDotTest, TestC128) { TestImpl>(); } #endif XLA_TEST_P(ParametricDotTest, TestS32) { TestImpl(); } +XLA_TEST_P(ParametricDotTest, TestF8E5M2) { TestImpl(); } +XLA_TEST_P(ParametricDotTest, TestF8E4M3FN) { TestImpl(); } XLA_TEST_P(ParametricDotTest, TestU8) { TestImpl(); } diff --git a/third_party/xla/xla/tests/dynamic_ops_test.cc b/third_party/xla/xla/tests/dynamic_ops_test.cc index ac530970229611..ab27dbe99072fe 100644 --- a/third_party/xla/xla/tests/dynamic_ops_test.cc +++ b/third_party/xla/xla/tests/dynamic_ops_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "xla/array2d.h" #include "xla/client/client_library.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/reference_util.h" #include "xla/service/local_service.h" #include "xla/service/platform_util.h" diff --git a/third_party/xla/xla/tests/exhaustive/BUILD b/third_party/xla/xla/tests/exhaustive/BUILD index 902c2c696db98f..10258f7c1261d6 100644 --- a/third_party/xla/xla/tests/exhaustive/BUILD +++ b/third_party/xla/xla/tests/exhaustive/BUILD @@ -4,6 +4,7 @@ load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load("//xla/tests:build_defs.bzl", "xla_test") load("//xla/tests/exhaustive:build_defs.bzl", "exhaustive_xla_test") +load("//xla/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -21,18 +22,41 @@ package_group( ], ) -cc_library( - name = "exhaustive_op_test_utils", +filegroup( + name = "exhaustive_op_test_utils_shared_hdrs", + testonly = True, + srcs = [ + "error_spec.h", + "exhaustive_op_test_base.h", + "exhaustive_op_test_utils.h", + ], + compatible_with = get_compatible_with_portable(), +) + +filegroup( + name = "exhaustive_op_test_utils_shared_srcs", testonly = True, srcs = [ "exhaustive_op_test_base.cc", "exhaustive_op_test_utils.cc", ], + compatible_with = get_compatible_with_portable(), +) + +cc_library( + name = "exhaustive_op_test_utils", + testonly = True, + srcs = [ + "platform.cc", + ":exhaustive_op_test_utils_shared_srcs", + ], hdrs = [ - "error_spec.h", - "exhaustive_op_test_base.h", - "exhaustive_op_test_utils.h", + "exhaustive_op_test.h", + "platform.h", + "test_op.h", + ":exhaustive_op_test_utils_shared_hdrs", ], + visibility = ["//visibility:private"], deps = [ "//xla:bit_cast", "//xla:executable_run_options", @@ -43,9 +67,11 @@ cc_library( "//xla:types", "//xla:xla_data_proto_cc", "//xla/client:executable_build_options", - "//xla/client:xla_builder", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder:xla_computation", "//xla/service:shaped_buffer", + "//xla/stream_executor:device_description", + "//xla/stream_executor:platform", "//xla/tests:client_library_test_base", "//xla/tsl/lib/core:status_test_util", "//xla/tsl/util:command_line_flags", @@ -64,33 +90,23 @@ cc_library( ], ) -filegroup( - name = "exhaustive_unary_test_srcs", - testonly = True, - srcs = [ - "exhaustive_unary_test_definitions.h", - "exhaustive_unary_test_functions.cc", +cc_library( + name = "exhaustive_unary_test_textual_hdrs", + textual_hdrs = [ + "exhaustive_unary_test_definitions.inc", + "exhaustive_unary_test_f32_and_smaller_instantiation.inc", + "exhaustive_unary_test_f64_instantiation.inc", + "exhaustive_unary_test_ops.inc", ], ) -filegroup( - name = "exhaustive_unary_test_f32_and_smaller_srcs", - testonly = True, - srcs = ["exhaustive_unary_test_f32_and_smaller_instantiation.cc"], -) - -filegroup( - name = "exhaustive_unary_test_f64_srcs", - testonly = True, - srcs = ["exhaustive_unary_test_f64_instantiation.cc"], -) - exhaustive_xla_test( name = "exhaustive_unary_test", timeout = "long", srcs = [ "exhaustive_test_main.cc", - ":exhaustive_unary_test_srcs", + "exhaustive_unary_test_definitions.h", + "exhaustive_unary_test_functions.cc", ], # Nvidia close-sourced libraries are not TSAN friendly, but are doing their own synchronization. # This can lead to TSAN false positives that are hard to track down. @@ -107,8 +123,12 @@ exhaustive_xla_test( # exhaustive_xla_test needs to have all partition names added to allow other build tools to # function. partitions = { - "f32_and_smaller": [":exhaustive_unary_test_f32_and_smaller_srcs"], - "f64": [":exhaustive_unary_test_f64_srcs"], + "f32_and_smaller": [ + "exhaustive_unary_test_f32_and_smaller_instantiation.cc", + ], + "f64": [ + ":exhaustive_unary_test_f64_instantiation.cc", + ], }, real_hardware_only = True, # Very slow on the interpreter. shard_count = 50, @@ -120,15 +140,17 @@ exhaustive_xla_test( ], deps = [ ":exhaustive_op_test_utils", + ":exhaustive_unary_test_textual_hdrs", "//xla:literal", "//xla:types", - "//xla/client:xla_builder", - "//xla/client/lib:constants", - "//xla/client/lib:math", + "//xla/hlo/builder:xla_builder", + "//xla/hlo/builder/lib:constants", + "//xla/hlo/builder/lib:math", "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:test", ], @@ -161,50 +183,36 @@ xla_test( ":exhaustive_op_test_utils", "//xla:literal", "//xla:types", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log", "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:test", ], ) -filegroup( - name = "exhaustive_binary_test_srcs", - testonly = True, - srcs = [ - "exhaustive_binary_test_definitions.h", - "exhaustive_binary_test_functions.cc", +cc_library( + name = "exhaustive_binary_test_textual_hdrs", + textual_hdrs = [ + "exhaustive_binary_test_definitions.inc", + "exhaustive_binary_test_f16_and_smaller_instantiation.inc", + "exhaustive_binary_test_f32_instantiation.inc", + "exhaustive_binary_test_f64_instantiation.inc", + "exhaustive_binary_test_ops.inc", ], ) -filegroup( - name = "exhaustive_binary_test_f16_and_smaller_srcs", - testonly = True, - srcs = ["exhaustive_binary_test_f16_and_smaller_instantiation.cc"], -) - -filegroup( - name = "exhaustive_binary_test_f32_srcs", - testonly = True, - srcs = ["exhaustive_binary_test_f32_instantiation.cc"], -) - -filegroup( - name = "exhaustive_binary_test_f64_srcs", - testonly = True, - srcs = ["exhaustive_binary_test_f64_instantiation.cc"], -) - exhaustive_xla_test( name = "exhaustive_binary_test", timeout = "long", srcs = [ + "exhaustive_binary_test_definitions.h", + "exhaustive_binary_test_functions.cc", "exhaustive_test_main.cc", - ":exhaustive_binary_test_srcs", ], # Nvidia close-sourced libraries are not TSAN friendly, but are doing their own synchronization. # This can lead to TSAN false positives that are hard to track down. @@ -221,9 +229,15 @@ exhaustive_xla_test( # exhasutive_xla_test needs to have all partition names added to allow other build tools to # function. partitions = { - "f16_and_smaller": [":exhaustive_binary_test_f16_and_smaller_srcs"], - "f32": [":exhaustive_binary_test_f32_srcs"], - "f64": [":exhaustive_binary_test_f64_srcs"], + "f16_and_smaller": [ + "exhaustive_binary_test_f16_and_smaller_instantiation.cc", + ], + "f32": [ + "exhaustive_binary_test_f32_instantiation.cc", + ], + "f64": [ + "exhaustive_binary_test_f64_instantiation.cc", + ], }, shard_count = 50, tags = [ @@ -233,15 +247,17 @@ exhaustive_xla_test( "no_oss", ], deps = [ + ":exhaustive_binary_test_textual_hdrs", ":exhaustive_op_test_utils", "//xla:literal", "//xla:types", - "//xla/client:xla_builder", + "//xla/hlo/builder:xla_builder", "//xla/tests:xla_internal_test_main", "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:test", ], diff --git a/third_party/xla/xla/tests/exhaustive/build_defs.bzl b/third_party/xla/xla/tests/exhaustive/build_defs.bzl index 446fe7de188460..f92de799dcae77 100644 --- a/third_party/xla/xla/tests/exhaustive/build_defs.bzl +++ b/third_party/xla/xla/tests/exhaustive/build_defs.bzl @@ -53,5 +53,8 @@ def exhaustive_xla_test(name, srcs, partitions, tags, **kwargs): register_extension_info( extension = exhaustive_xla_test, # Needs to be kept up-to-date on all partition names defined in the invocations. - label_regex_for_dep = "{extension_name}_(f16_and_smaller|f32_and_smaller|f32|f64)_.*", + # + # For some reason, manually specifying the expansion targets like (cpu|cpu_.*|...) is required + # for build tools. + label_regex_for_dep = "{extension_name}_(f16_and_smaller|f32_and_smaller|f32|f64)_(cpu|cpu_.*|gpu|gpu_.*)", ) diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_definitions.h b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_definitions.h index 6ee448ed1c9cf8..e4aee45b302cac 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_definitions.h +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_definitions.h @@ -16,205 +16,26 @@ limitations under the License. #ifndef XLA_TESTS_EXHAUSTIVE_EXHAUSTIVE_BINARY_TEST_DEFINITIONS_H_ #define XLA_TESTS_EXHAUSTIVE_EXHAUSTIVE_BINARY_TEST_DEFINITIONS_H_ -#include -#include -#include -#include -#include -#include - -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/types/span.h" -#include "xla/literal.h" -#include "xla/tests/exhaustive/exhaustive_op_test_base.h" -#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" -#include "xla/tests/test_macros.h" -#include "xla/types.h" -#include "tsl/platform/test.h" +#include // IWYU pragma: keep, exhaustive_binary_test_definitions.inc +#include // IWYU pragma: keep, exhaustive_binary_test_definitions.inc +#include // IWYU pragma: keep, exhaustive_binary_test_definitions.inc +#include // IWYU pragma: keep, exhaustive_binary_test_definitions.inc +#include // IWYU pragma: keep, exhaustive_binary_test_definitions.inc +#include // IWYU pragma: keep, exhaustive_binary_test_definitions.inc + +#include "absl/log/check.h" // IWYU pragma: keep, exhaustive_binary_test_definitions.inc +#include "absl/log/log.h" // IWYU pragma: keep, exhaustive_binary_test_definitions.inc +#include "absl/types/span.h" // IWYU pragma: keep, exhaustive_binary_test_definitions.inc +#include "xla/literal.h" // IWYU pragma: keep, exhaustive_binary_test_definitions.inc +#include "xla/tests/exhaustive/exhaustive_op_test.h" // IWYU pragma: keep, exhaustive_binary_test_definitions.inc +#include "xla/tests/test_macros.h" // IWYU pragma: keep, exhaustive_binary_test_definitions.inc +#include "xla/types.h" // IWYU pragma: keep, exhaustive_binary_test_definitions.inc +#include "tsl/platform/test.h" // IWYU pragma: keep, exhaustive_binary_test_definitions.inc namespace xla { namespace exhaustive_op_test { -// Exhaustive test for binary operations for 16 bit floating point types, -// including float16 and bfloat. -// -// Test parameter is a pair of (begin, end) for range under test. -template -class Exhaustive16BitBinaryTest - : public ExhaustiveBinaryTest, - public ::testing::WithParamInterface> { - public: - int64_t GetInputSize() override { - int64_t begin, end; - std::tie(begin, end) = GetParam(); - return end - begin; - } - - // Given a range of uint64_t representation, uses bits 0..15 and bits 16..31 - // for the values of src0 and src1 (see below for ordering) for the 16 bit - // binary operation being tested, and generates the cartesian product of the - // two sets as the two inputs for the test. - // - // If `kLeftToRightPacking == true`, bit 31..16 become src0 and 15..0 becomes - // src1. If `kLeftToRightPacking == false`, then bits 31..16 become src1 - // and 15..0 becomes src0. - void FillInput(std::array* input_literals) override { - int64_t input_size = GetInputSize(); - CHECK_EQ(input_size, (*input_literals)[0].element_count()); - CHECK_EQ(input_size, (*input_literals)[1].element_count()); - - int64_t begin, end; - std::tie(begin, end) = GetParam(); - if (VLOG_IS_ON(2)) { - uint16_t left_begin, left_end, right_begin, right_end; - if constexpr (kLeftToRightPacking) { - left_begin = std::bit_cast(static_cast(begin >> 16)); - left_end = std::bit_cast(static_cast(end >> 16)); - right_begin = std::bit_cast(static_cast(begin)); - right_end = std::bit_cast(static_cast(end)); - } else { - left_begin = std::bit_cast(static_cast(begin)); - left_end = std::bit_cast(static_cast(end)); - right_begin = - std::bit_cast(static_cast(begin >> 16)); - right_end = std::bit_cast(static_cast(end >> 16)); - } - - // N.B.: Use INFO directly instead of doing another thread-safe VLOG - // check. - LOG(INFO) << this->SuiteName() << this->TestName() << " Range:"; - LOG(INFO) << "\tfrom=(" << left_begin << ", " << right_begin << "); hex=(" - << std::hex << left_begin << ", " << right_begin << "); float=(" - << *reinterpret_cast(&left_begin) << ", " - << *reinterpret_cast(&right_begin) - << ") (inclusive)"; - LOG(INFO) << "\tto=(" << left_end << ", " << right_end << "); hex=(" - << std::hex << left_end << ", " << right_end << "); float=(" - << *reinterpret_cast(&left_end) << ", " - << *reinterpret_cast(&right_end) - << ") (exclusive)"; - LOG(INFO) << "\ttotal values to test=" << (end - begin); - } - - absl::Span input_arr_0 = (*input_literals)[0].data(); - absl::Span input_arr_1 = (*input_literals)[1].data(); - for (int64_t i = 0; i < input_size; i++) { - uint32_t input_val = i + begin; - // Convert the packed bits to a pair of NativeT and replace known - // incorrect input values with 0. - // - // In either case, we only use 32 bits out of the 64 bits possible. - if constexpr (kLeftToRightPacking) { - // Left is stored at higher 16 bits. - input_arr_0[i] = this->ConvertValue(input_val >> 16); - input_arr_1[i] = this->ConvertValue(input_val); - } else { - // Left is stored at lower 16 bits. - input_arr_0[i] = this->ConvertValue(input_val); - input_arr_1[i] = this->ConvertValue(input_val >> 16); - } - } - } - - protected: - using typename ExhaustiveBinaryTest::NativeT; -}; - -// Exhaustive test for binary operations for float and double. -// -// Test parameter is a tuple of (FpValues, FpValues) describing the possible -// values for each operand. The inputs for the test are the Cartesian product -// of the possible values for the two operands. -template -class Exhaustive32BitOrMoreBinaryTest - : public ExhaustiveBinaryTest, - public ::testing::WithParamInterface> { - protected: - using typename ExhaustiveBinaryTest::NativeT; - - private: - int64_t GetInputSize() override { - FpValues values_0; - FpValues values_1; - std::tie(values_0, values_1) = GetParam(); - return values_0.GetTotalNumValues() * values_1.GetTotalNumValues(); - } - - void FillInput(std::array* input_literals) override { - int64_t input_size = GetInputSize(); - FpValues values_0; - FpValues values_1; - std::tie(values_0, values_1) = GetParam(); - if (VLOG_IS_ON(2)) { - // N.B.: Use INFO directly instead of doing another thread-safe VLOG - // check. - LOG(INFO) << this->SuiteName() << this->TestName() << " Values:"; - LOG(INFO) << "\tleft values=" << values_0.ToString(); - LOG(INFO) << "\tright values=" << values_1.ToString(); - LOG(INFO) << "\ttotal values to test=" << input_size; - } - CHECK(input_size == (*input_literals)[0].element_count() && - input_size == (*input_literals)[1].element_count()); - - absl::Span input_arr_0 = (*input_literals)[0].data(); - absl::Span input_arr_1 = (*input_literals)[1].data(); - - uint64_t i = 0; - for (auto src0 : values_0) { - for (auto src1 : values_1) { - input_arr_0[i] = this->ConvertValue(src0); - input_arr_1[i] = this->ConvertValue(src1); - ++i; - } - } - CHECK_EQ(i, input_size); - } -}; - -using ExhaustiveF16BinaryTest = Exhaustive16BitBinaryTest; -using ExhaustiveBF16BinaryTest = Exhaustive16BitBinaryTest; -using ExhaustiveF32BinaryTest = Exhaustive32BitOrMoreBinaryTest; -using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest; - -#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) -#define BINARY_TEST_F16(test_name, ...) \ - XLA_TEST_P(ExhaustiveF16BinaryTest, test_name) \ - __VA_ARGS__ -#else -#define BINARY_TEST_F16(test_name, ...) -#endif - -#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16) -#define BINARY_TEST_BF16(test_name, ...) \ - XLA_TEST_P(ExhaustiveBF16BinaryTest, test_name) \ - __VA_ARGS__ -#else -#define BINARY_TEST_BF16(test_name, ...) -#endif - -#define BINARY_TEST_F32(test_name, ...) \ - XLA_TEST_P(ExhaustiveF32BinaryTest, test_name) \ - __VA_ARGS__ - -#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) -using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest; -#define BINARY_TEST_F64(test_name, ...) \ - XLA_TEST_P(ExhaustiveF64BinaryTest, test_name) \ - __VA_ARGS__ -#else -#define BINARY_TEST_F64(test_name, ...) -#endif - -#define BINARY_TEST(test_name, ...) \ - BINARY_TEST_F16(test_name, __VA_ARGS__) \ - BINARY_TEST_BF16(test_name, __VA_ARGS__) \ - BINARY_TEST_F32(test_name, __VA_ARGS__) \ - BINARY_TEST_F64(test_name, __VA_ARGS__) - -#define BINARY_TEST_COMPLEX(test_name, ...) \ - BINARY_TEST_F32(test_name, __VA_ARGS__) \ - BINARY_TEST_F64(test_name, __VA_ARGS__) +#include "xla/tests/exhaustive/exhaustive_binary_test_definitions.inc" } // namespace exhaustive_op_test } // namespace xla diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_definitions.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_definitions.inc new file mode 100644 index 00000000000000..1c1967256218a0 --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_definitions.inc @@ -0,0 +1,293 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Exhaustive test for binary operations for 8-bit floating point types, +// including float16 and bfloat. +// +// Test parameter is a pair of (begin, end) for range under test. +template +class Exhaustive8BitBinaryTest + : public ExhaustiveBinaryTest, + public ::testing::WithParamInterface> { + public: + int64_t GetInputSize() override { + int64_t begin, end; + std::tie(begin, end) = GetParam(); + return end - begin; + } + + // Given a range of uint64_t representation, uses bits 7..0 and bits 15..8 + // for the values of src0 and src1 (see below for ordering) for the 8-bit + // binary operation being tested, and generate the cartesian product of the + // two sets as the two inputs for the test. + // + // If `kLeftToRightPacking == true`, then bits 15..8 are interpreted as src0 + // and bits 7..0 are interpreted as src1. If `kLeftToRightPacking == false`, + // then bits 15..8 are interpreted as src1 and 7..0 are interpreted as src0. + void FillInput(std::array* input_literals) override { + int64_t input_size = GetInputSize(); + CHECK_EQ(input_size, (*input_literals)[0].element_count()); + CHECK_EQ(input_size, (*input_literals)[1].element_count()); + + int64_t begin, end; + std::tie(begin, end) = GetParam(); + + if (VLOG_IS_ON(2)) { + uint8_t left_begin, left_end, right_begin, right_end; + if constexpr (kLeftToRightPacking) { + left_begin = std::bit_cast(static_cast(begin >> 8)); + left_end = std::bit_cast(static_cast(end >> 8)); + right_begin = std::bit_cast(static_cast(begin)); + right_end = std::bit_cast(static_cast(end)); + } else { + left_begin = std::bit_cast(static_cast(begin)); + left_end = std::bit_cast(static_cast(end)); + right_begin = std::bit_cast(static_cast(begin >> 8)); + right_end = std::bit_cast(static_cast(end >> 8)); + } + + LOG(INFO) << this->SuiteName() << this->TestName() << " Range:"; + // N.B.: Cast to u32 to avoid printing values as char. + LOG(INFO) << "\tfrom=(" << static_cast(left_begin) << ", " + << static_cast(right_begin) << "); hex=(" << std::hex + << static_cast(left_begin) << ", " + << static_cast(right_begin) << "); float=(" + << std::bit_cast(left_begin) << ", " + << std::bit_cast(right_begin) + << ") (inclusive)"; + LOG(INFO) << "\tto=(" << static_cast(left_end) << ", " + << static_cast(right_end) << "); hex=(" << std::hex + << static_cast(left_end) << ", " + << static_cast(right_end) << "); float=(" + << std::bit_cast(left_end) << ", " + << std::bit_cast(right_end) + << ") (exclusive)"; + LOG(INFO) << "\ttotal values to test=" << (end - begin); + } + + absl::Span input_arr_0 = (*input_literals)[0].data(); + absl::Span input_arr_1 = (*input_literals)[1].data(); + for (int64_t i = 0; i < input_size; i++) { + uint32_t input_val = i + begin; + // Convert the packed bits to a pair of NativeT and replace known + // incorrect input values with 0. + // + // In either case, we only use 16 bits out of the 64 bits possible. + if constexpr (kLeftToRightPacking) { + input_arr_0[i] = this->ConvertValue(input_val >> 8); + input_arr_1[i] = this->ConvertValue(input_val); + } else { + input_arr_0[i] = this->ConvertValue(input_val); + input_arr_1[i] = this->ConvertValue(input_val >> 8); + } + } + } + + protected: + using typename ExhaustiveBinaryTest::NativeT; +}; + +// Exhaustive test for binary operations for 16 bit floating point types, +// including float16 and bfloat. +// +// Test parameter is a pair of (begin, end) for range under test. +template +class Exhaustive16BitBinaryTest + : public ExhaustiveBinaryTest, + public ::testing::WithParamInterface> { + protected: + int64_t GetInputSize() override { + int64_t begin, end; + std::tie(begin, end) = GetParam(); + return end - begin; + } + + // Given a range of uint64_t representation, uses bits 0..15 and bits 16..31 + // for the values of src0 and src1 (see below for ordering) for the 16 bit + // binary operation being tested, and generates the cartesian product of the + // two sets as the two inputs for the test. + // + // If `kLeftToRightPacking == true`, bit 31..16 become src0 and 15..0 becomes + // src1. If `kLeftToRightPacking == false`, then bits 31..16 become src1 + // and 15..0 becomes src0. + void FillInput(std::array* input_literals) override { + using NativeT = typename ExhaustiveBinaryTest::NativeT; + + int64_t input_size = GetInputSize(); + CHECK_EQ(input_size, (*input_literals)[0].element_count()); + CHECK_EQ(input_size, (*input_literals)[1].element_count()); + + int64_t begin, end; + std::tie(begin, end) = GetParam(); + if (VLOG_IS_ON(2)) { + uint16_t left_begin, left_end, right_begin, right_end; + if constexpr (kLeftToRightPacking) { + left_begin = std::bit_cast(static_cast(begin >> 16)); + left_end = std::bit_cast(static_cast(end >> 16)); + right_begin = std::bit_cast(static_cast(begin)); + right_end = std::bit_cast(static_cast(end)); + } else { + left_begin = std::bit_cast(static_cast(begin)); + left_end = std::bit_cast(static_cast(end)); + right_begin = + std::bit_cast(static_cast(begin >> 16)); + right_end = std::bit_cast(static_cast(end >> 16)); + } + + // N.B.: Use INFO directly instead of doing another thread-safe VLOG + // check. + LOG(INFO) << this->SuiteName() << this->TestName() << " Range:"; + LOG(INFO) << "\tfrom=(" << left_begin << ", " << right_begin << "); hex=(" + << std::hex << left_begin << ", " << right_begin << "); float=(" + << *reinterpret_cast(&left_begin) << ", " + << *reinterpret_cast(&right_begin) + << ") (inclusive)"; + LOG(INFO) << "\tto=(" << left_end << ", " << right_end << "); hex=(" + << std::hex << left_end << ", " << right_end << "); float=(" + << *reinterpret_cast(&left_end) << ", " + << *reinterpret_cast(&right_end) + << ") (exclusive)"; + LOG(INFO) << "\ttotal values to test=" << (end - begin); + } + + absl::Span input_arr_0 = (*input_literals)[0].data(); + absl::Span input_arr_1 = (*input_literals)[1].data(); + for (int64_t i = 0; i < input_size; i++) { + uint32_t input_val = i + begin; + // Convert the packed bits to a pair of NativeT and replace known + // incorrect input values with 0. + // + // In either case, we only use 32 bits out of the 64 bits possible. + if constexpr (kLeftToRightPacking) { + // Left is stored at higher 16 bits. + input_arr_0[i] = this->ConvertValue(input_val >> 16); + input_arr_1[i] = this->ConvertValue(input_val); + } else { + // Left is stored at lower 16 bits. + input_arr_0[i] = this->ConvertValue(input_val); + input_arr_1[i] = this->ConvertValue(input_val >> 16); + } + } + } +}; + +// Exhaustive test for binary operations for float and double. +// +// Test parameter is a tuple of (FpValues, FpValues) describing the possible +// values for each operand. The inputs for the test are the Cartesian product +// of the possible values for the two operands. +template +class Exhaustive32BitOrMoreBinaryTest + : public ExhaustiveBinaryTest, + public ::testing::WithParamInterface> { + protected: + int64_t GetInputSize() override { + FpValues values_0; + FpValues values_1; + std::tie(values_0, values_1) = GetParam(); + return values_0.GetTotalNumValues() * values_1.GetTotalNumValues(); + } + + void FillInput(std::array* input_literals) override { + using NativeT = typename ExhaustiveBinaryTest::NativeT; + + int64_t input_size = GetInputSize(); + FpValues values_0; + FpValues values_1; + std::tie(values_0, values_1) = GetParam(); + if (VLOG_IS_ON(2)) { + // N.B.: Use INFO directly instead of doing another thread-safe VLOG + // check. + LOG(INFO) << this->SuiteName() << this->TestName() << " Values:"; + LOG(INFO) << "\tleft values=" << values_0.ToString(); + LOG(INFO) << "\tright values=" << values_1.ToString(); + LOG(INFO) << "\ttotal values to test=" << input_size; + } + CHECK(input_size == (*input_literals)[0].element_count() && + input_size == (*input_literals)[1].element_count()); + + absl::Span input_arr_0 = (*input_literals)[0].data(); + absl::Span input_arr_1 = (*input_literals)[1].data(); + + uint64_t i = 0; + for (auto src0 : values_0) { + for (auto src1 : values_1) { + input_arr_0[i] = this->ConvertValue(src0); + input_arr_1[i] = this->ConvertValue(src1); + ++i; + } + } + CHECK_EQ(i, input_size); + } +}; + +using ExhaustiveF8E4M3FNBinaryTest = Exhaustive8BitBinaryTest; +using ExhaustiveF8E5M2BinaryTest = Exhaustive8BitBinaryTest; +using ExhaustiveF16BinaryTest = Exhaustive16BitBinaryTest; +using ExhaustiveBF16BinaryTest = Exhaustive16BitBinaryTest; +using ExhaustiveF32BinaryTest = Exhaustive32BitOrMoreBinaryTest; +using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest; + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E4M3FN) +#define BINARY_TEST_F8E4M3FN(test_name, ...) \ + XLA_TEST_P(ExhaustiveF8E4M3FNBinaryTest, test_name) \ + __VA_ARGS__ +#else +#define BINARY_TEST_E4M3FN(test_name, ...) +#endif + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E5M2) +#define BINARY_TEST_F8E5M2(test_name, ...) \ + XLA_TEST_P(ExhaustiveF8E5M2BinaryTest, test_name) \ + __VA_ARGS__ +#else +#define BINARY_TEST_E5M2(test_name, ...) +#endif + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) +#define BINARY_TEST_F16(test_name, ...) \ + XLA_TEST_P(ExhaustiveF16BinaryTest, test_name) \ + __VA_ARGS__ +#else +#define BINARY_TEST_F16(test_name, ...) +#endif + +#define BINARY_TEST_BF16(test_name, ...) \ + XLA_TEST_P(ExhaustiveBF16BinaryTest, test_name) \ + __VA_ARGS__ + +#define BINARY_TEST_F32(test_name, ...) \ + XLA_TEST_P(ExhaustiveF32BinaryTest, test_name) \ + __VA_ARGS__ + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +#define BINARY_TEST_F64(test_name, ...) \ + XLA_TEST_P(ExhaustiveF64BinaryTest, test_name) \ + __VA_ARGS__ +#else +#define BINARY_TEST_F64(test_name, ...) +#endif + +#define BINARY_TEST(test_name, ...) \ + BINARY_TEST_F8E4M3FN(test_name, __VA_ARGS__) \ + BINARY_TEST_F8E5M2(test_name, __VA_ARGS__) \ + BINARY_TEST_F16(test_name, __VA_ARGS__) \ + BINARY_TEST_BF16(test_name, __VA_ARGS__) \ + BINARY_TEST_F32(test_name, __VA_ARGS__) \ + BINARY_TEST_F64(test_name, __VA_ARGS__) + +#define BINARY_TEST_COMPLEX(test_name, ...) \ + BINARY_TEST_F32(test_name, __VA_ARGS__) \ + BINARY_TEST_F64(test_name, __VA_ARGS__) \ No newline at end of file diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.cc index d8451068624898..db863ed53af0fe 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.cc @@ -13,31 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/tests/exhaustive/exhaustive_binary_test_definitions.h" -#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" -#include "tsl/platform/test.h" +#include "xla/tests/exhaustive/exhaustive_binary_test_definitions.h" // IWYU pragma: keep, exhaustive_binary_test_f16_and_smaller_instantiation.inc +#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" // IWYU pragma: keep, exhaustive_binary_test_f16_and_smaller_instantiation.inc +#include "tsl/platform/test.h" // IWYU pragma: keep, exhaustive_binary_test_f16_and_smaller_instantiation.inc namespace xla { namespace exhaustive_op_test { namespace { -#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16) -INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16BinaryTest, - ::testing::ValuesIn(CreateExhaustiveF32Ranges())); -#else -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16BinaryTest); -#endif - -#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) -INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16BinaryTest, - ::testing::ValuesIn(CreateExhaustiveF32Ranges())); -#else -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16BinaryTest); -#endif - -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF32BinaryTest); - -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF64BinaryTest); +#include "xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.inc" } // namespace } // namespace exhaustive_op_test diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.inc new file mode 100644 index 00000000000000..339406cd0da05d --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.inc @@ -0,0 +1,42 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E4M3FN) +INSTANTIATE_TEST_SUITE_P(F8E4M3FN, ExhaustiveF8E4M3FNBinaryTest, + ::testing::ValuesIn(CreateExhaustiveU16Ranges())); +#else +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E4M3FNBinaryTest); +#endif + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E5M2) +INSTANTIATE_TEST_SUITE_P(F8E5M2, ExhaustiveF8E5M2BinaryTest, + ::testing::ValuesIn(CreateExhaustiveU16Ranges())); +#else +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E5M2BinaryTest); +#endif + +INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16BinaryTest, + ::testing::ValuesIn(CreateExhaustiveU32Ranges())); + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) +INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16BinaryTest, + ::testing::ValuesIn(CreateExhaustiveU32Ranges())); +#else +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16BinaryTest); +#endif + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF32BinaryTest); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF64BinaryTest); diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.cc index ed28e923a035bf..c625282e987d5a 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.cc @@ -13,54 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/tests/exhaustive/exhaustive_binary_test_definitions.h" -#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" -#include "tsl/platform/test.h" +#include "xla/tests/exhaustive/exhaustive_binary_test_definitions.h" // IWYU pragma: keep, exhaustive_binary_test_f32_instantiation.inc +#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" // IWYU pragma: keep, exhaustive_binary_test_f32_instantiation.inc +#include "tsl/platform/test.h" // IWYU pragma: keep, exhaustive_binary_test_f32_instantiation.inc namespace xla { namespace exhaustive_op_test { namespace { -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16BinaryTest); - -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16BinaryTest); - -INSTANTIATE_TEST_SUITE_P( - SpecialValues, ExhaustiveF32BinaryTest, - ::testing::Combine( - ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), - ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); - -INSTANTIATE_TEST_SUITE_P( - SpecialAndNormalValues, ExhaustiveF32BinaryTest, - ::testing::Combine( - ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), - ::testing::Values(GetNormals(2000)))); - -INSTANTIATE_TEST_SUITE_P( - NormalAndSpecialValues, ExhaustiveF32BinaryTest, - ::testing::Combine( - ::testing::Values(GetNormals(2000)), - ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); - -INSTANTIATE_TEST_SUITE_P( - NormalAndNormalValues, ExhaustiveF32BinaryTest, - ::testing::Combine(::testing::Values(GetNormals(2000)), - ::testing::Values(GetNormals(2000)))); - -// Tests a total of 40000 ^ 2 inputs, with 2000 ^ 2 inputs in each sub-test. -// Comparing with the unary tests, the binary tests use a smaller set of inputs -// for each sub-test to avoid timeout because the implementation of ExpectNear -// more than 2x slower for binary test. -INSTANTIATE_TEST_SUITE_P( - LargeAndSmallMagnitudeNormalValues, ExhaustiveF32BinaryTest, - ::testing::Combine( - ::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals(40000, - 2000)), - ::testing::ValuesIn( - GetFpValuesForMagnitudeExtremeNormals(40000, 2000)))); - -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF64BinaryTest); +#include "xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.inc" } // namespace } // namespace exhaustive_op_test diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.inc new file mode 100644 index 00000000000000..ba62061d7437fc --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.inc @@ -0,0 +1,59 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E4M3FNBinaryTest); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E5M2BinaryTest); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16BinaryTest); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16BinaryTest); + +INSTANTIATE_TEST_SUITE_P( + SpecialValues, ExhaustiveF32BinaryTest, + ::testing::Combine( + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); + +INSTANTIATE_TEST_SUITE_P( + SpecialAndNormalValues, ExhaustiveF32BinaryTest, + ::testing::Combine( + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), + ::testing::Values(GetNormals(2000)))); + +INSTANTIATE_TEST_SUITE_P( + NormalAndSpecialValues, ExhaustiveF32BinaryTest, + ::testing::Combine( + ::testing::Values(GetNormals(2000)), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); + +INSTANTIATE_TEST_SUITE_P( + NormalAndNormalValues, ExhaustiveF32BinaryTest, + ::testing::Combine(::testing::Values(GetNormals(2000)), + ::testing::Values(GetNormals(2000)))); + +// Tests a total of 40000 ^ 2 inputs, with 2000 ^ 2 inputs in each sub-test. +// Comparing with the unary tests, the binary tests use a smaller set of inputs +// for each sub-test to avoid timeout because the implementation of ExpectNear +// more than 2x slower for binary test. +INSTANTIATE_TEST_SUITE_P( + LargeAndSmallMagnitudeNormalValues, ExhaustiveF32BinaryTest, + ::testing::Combine( + ::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals(40000, + 2000)), + ::testing::ValuesIn( + GetFpValuesForMagnitudeExtremeNormals(40000, 2000)))); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF64BinaryTest); diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.cc index c948c83703171e..fd2b73a706cca1 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.cc @@ -13,57 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/tests/exhaustive/exhaustive_binary_test_definitions.h" -#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" -#include "tsl/platform/test.h" +#include "xla/tests/exhaustive/exhaustive_binary_test_definitions.h" // IWYU pragma: keep, exhaustive_binary_test_f64_instantiation.inc +#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" // IWYU pragma: keep, exhaustive_binary_test_f64_instantiation.inc +#include "tsl/platform/test.h" // IWYU pragma: keep, exhaustive_binary_test_f64_instantiation.inc namespace xla { namespace exhaustive_op_test { namespace { -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16BinaryTest); - -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16BinaryTest); - -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF32BinaryTest); - -#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) -INSTANTIATE_TEST_SUITE_P( - SpecialValues, ExhaustiveF64BinaryTest, - ::testing::Combine( - ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), - ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); - -INSTANTIATE_TEST_SUITE_P( - SpecialAndNormalValues, ExhaustiveF64BinaryTest, - ::testing::Combine( - ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), - ::testing::Values(GetNormals(1000)))); - -INSTANTIATE_TEST_SUITE_P( - NormalAndSpecialValues, ExhaustiveF64BinaryTest, - ::testing::Combine( - ::testing::Values(GetNormals(1000)), - ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); - -INSTANTIATE_TEST_SUITE_P( - NormalAndNormalValues, ExhaustiveF64BinaryTest, - ::testing::Combine(::testing::Values(GetNormals(1000)), - ::testing::Values(GetNormals(1000)))); - -// Tests a total of 40000 ^ 2 inputs, with 2000 ^ 2 inputs in each sub-test. -// Similar to ExhaustiveF64BinaryTest, we use a smaller set of inputs for each -// for each sub-test comparing with the unary test to avoid timeout. -INSTANTIATE_TEST_SUITE_P( - LargeAndSmallMagnitudeNormalValues, ExhaustiveF64BinaryTest, - ::testing::Combine( - ::testing::ValuesIn( - GetFpValuesForMagnitudeExtremeNormals(40000, 2000)), - ::testing::ValuesIn( - GetFpValuesForMagnitudeExtremeNormals(40000, 2000)))); -#else -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF64BinaryTest); -#endif // !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +#include "xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.inc" } // namespace } // namespace exhaustive_op_test diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.inc new file mode 100644 index 00000000000000..a91f93ee155d45 --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.inc @@ -0,0 +1,62 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E4M3FNBinaryTest); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E5M2BinaryTest); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16BinaryTest); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16BinaryTest); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF32BinaryTest); + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +INSTANTIATE_TEST_SUITE_P( + SpecialValues, ExhaustiveF64BinaryTest, + ::testing::Combine( + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); + +INSTANTIATE_TEST_SUITE_P( + SpecialAndNormalValues, ExhaustiveF64BinaryTest, + ::testing::Combine( + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()), + ::testing::Values(GetNormals(1000)))); + +INSTANTIATE_TEST_SUITE_P( + NormalAndSpecialValues, ExhaustiveF64BinaryTest, + ::testing::Combine( + ::testing::Values(GetNormals(1000)), + ::testing::ValuesIn(CreateFpValuesForBoundaryTest()))); + +INSTANTIATE_TEST_SUITE_P( + NormalAndNormalValues, ExhaustiveF64BinaryTest, + ::testing::Combine(::testing::Values(GetNormals(1000)), + ::testing::Values(GetNormals(1000)))); + +// Tests a total of 40000 ^ 2 inputs, with 2000 ^ 2 inputs in each sub-test. +// Similar to ExhaustiveF64BinaryTest, we use a smaller set of inputs for each +// for each sub-test comparing with the unary test to avoid timeout. +INSTANTIATE_TEST_SUITE_P( + LargeAndSmallMagnitudeNormalValues, ExhaustiveF64BinaryTest, + ::testing::Combine( + ::testing::ValuesIn( + GetFpValuesForMagnitudeExtremeNormals(40000, 2000)), + ::testing::ValuesIn( + GetFpValuesForMagnitudeExtremeNormals(40000, 2000)))); +#else +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF64BinaryTest); +#endif // !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_functions.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_functions.cc index a74b86ac89d019..2e87f43787f4b4 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_functions.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_functions.cc @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include -#include -#include +#include // IWYU pragma: keep, exhaustive_binary_test_ops.inc +#include // IWYU pragma: keep, exhaustive_binary_test_ops.inc #include #include -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" // IWYU pragma: keep, exhaustive_binary_test_ops.inc +#include "xla/tests/exhaustive/error_spec.h" #include "xla/tests/exhaustive/exhaustive_binary_test_definitions.h" #include "xla/tests/exhaustive/exhaustive_op_test_utils.h" +#include "xla/tests/exhaustive/test_op.h" // IWYU pragma: keep, exhaustive_binary_test_ops.inc #include "xla/types.h" #ifdef __FAST_MATH__ @@ -33,346 +33,163 @@ namespace xla { namespace exhaustive_op_test { namespace { +#include "xla/tests/exhaustive/exhaustive_binary_test_ops.inc" + // Can be thought of as an absolute error of // `<= |std::numeric_limits::::min()|`. template -double AddCpuTpuAbsErr(NativeT left, NativeT right) { +double AddCpuAbsErr(NativeT left, NativeT right) { NativeRefT output = static_cast(left) + static_cast(right); - // Hardware flushes subnormal outputs to 0. if (IsSubnormal(output)) { return std::numeric_limits::min(); } - return 0.0; } BINARY_TEST(Add, { - ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - }; - - if ((IsCpu(platform_) || IsTpu(platform_))) { - if (std::is_same_v || - std::is_same_v || std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(AddCpuTpuAbsErr(left, right)) - .strict_signed_zeros() - .build(); - }; - } - } - - Run( - AddEmptyBroadcastDimension(Add), - [](NativeRefT x, NativeRefT y) { return x + y; }, error_spec_gen); + AddOp(this) + .Error(+[](NativeT, NativeT) { + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .CpuError(+[](NativeT left, NativeT right) { + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder() + .abs_err(AddCpuAbsErr(left, right)) + .strict_signed_zeros() + .build(); + } + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .Run(); }) // Can be thought of as an absolute error of // `<= |std::numeric_limits::::min()|`. template -double SubCpuTpuAbsErr(NativeT left, NativeT right) { +double SubCpuAbsErr(NativeT left, NativeT right) { NativeRefT output = static_cast(left) - static_cast(right); - // Hardware flushes subnormal outputs to 0. if (IsSubnormal(output)) { return std::numeric_limits::min(); } - return 0.0; } BINARY_TEST(Sub, { - ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - }; - - if (IsCpu(platform_) || IsTpu(platform_)) { - if (std::is_same_v || - std::is_same_v || std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(SubCpuTpuAbsErr(left, right)) - .strict_signed_zeros() - .build(); - }; - } - } - - Run( - AddEmptyBroadcastDimension(Sub), - [](NativeRefT x, NativeRefT y) { return x - y; }, error_spec_gen); + SubOp(this) + .Error(+[](NativeT, NativeT) { + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .CpuError(+[](NativeT left, NativeT right) { + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder() + .abs_err(SubCpuAbsErr(left, right)) + .strict_signed_zeros() + .build(); + } + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .Run(); }) // Can be thought of as an absolute error of // `<= |std::numeric_limits::::min()|`. template -double MulCpuTpuAbsErr(NativeT left, NativeT right) { +double MulCpuAbsErr(NativeT left, NativeT right) { NativeRefT output = static_cast(left) * static_cast(right); - - // CPU BF16 and TPU (all types) flush subnormals to 0. + // CPU BF16 flush subnormals to 0. auto output_is_subnormal = IsSubnormal(output); if (output_is_subnormal) { return std::numeric_limits::min(); } - return 0.0; } -bool MulCpuTpuBf16Skip(xla::bfloat16 left, xla::bfloat16 right) { - // For CPU and TPU BF16, multiplying a subnormal by infinity will lead to +bool MulCpuBf16Skip(xla::bfloat16 left, xla::bfloat16 right) { + // For CPU BF16, multiplying a subnormal by infinity will lead to // calculating 0 multiplied by infinity due to subnormal flushing, which is // defined to be NaN. However, the calculation in higher precision does not // flush the subnormal value to 0, leading to a result of infinity. - if ((IsSubnormal(left) && std::isinf(right)) || - (std::isinf(left) && IsSubnormal(right))) { - return true; - } - return false; + return (IsSubnormal(left) && std::isinf(right)) || + (std::isinf(left) && IsSubnormal(right)); } BINARY_TEST(Mul, { - ErrorSpecGen error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - }; - - if (IsCpu(platform_) || IsTpu(platform_)) { - if (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(MulCpuTpuAbsErr(left, right)) - .strict_signed_zeros() - .skip_comparison( - MulCpuTpuBf16Skip(static_cast(left), - static_cast(right))) - .build(); - }; - } - if (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(MulCpuTpuAbsErr(left, right)) - .strict_signed_zeros() - .build(); - }; - } - } - - Run( - AddEmptyBroadcastDimension(Mul), - [](NativeRefT x, NativeRefT y) { return x * y; }, error_spec_gen); + MulOp(this) + .Error(+[](NativeT left, NativeT right) { + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .CpuError(+[](NativeT left, NativeT right) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .abs_err(MulCpuAbsErr(left, right)) + .strict_signed_zeros() + .skip_comparison( + MulCpuBf16Skip(static_cast(left), + static_cast(right))) + .build(); + } + if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .abs_err(MulCpuAbsErr(left, right)) + .strict_signed_zeros() + .build(); + } + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .Run(); }) // Can be thought of as an absolute error of // `<= |std::numeric_limits::::min()|`. template -double DivCpuTpuAbsErr(NativeT left, NativeT right) { +double DivCpuAbsErr(NativeT left, NativeT right) { NativeRefT output = static_cast(left) / static_cast(right); - // Subnormals are flushed to 0 so we add a absolute error margin that is // larger than any subnormal. if (IsSubnormal(output)) { return std::numeric_limits::min(); } - - return 0.0; -} - -template -double DivTpuAbsErr(NativeT left, NativeT right) { - NativeRefT reciprocal = 1.0f / static_cast(right); - NativeT output = left / right; - NativeRefT output_as_native_ref_t = - static_cast(left) / static_cast(right); - - // If we calculate NaN, we don't need to adjust tolerances. - if (std::isnan(output_as_native_ref_t)) { - return 0.0; - } - - // TPUs perform `left * (1 / right)`, where `left` and `1 / right` are - // flushed to `0` if they are subnormal. Also applies to if reciprocal is min - // normal. - if (IsSubnormal(left) || IsSubnormal(reciprocal)) { - // Subnormals can have a larger value in BF16 than float due to rounding to - // the nearest BF16 value during conversion while having less representation - // bits. For normals, the float value is usually always bigger due to - // greater precision. - return std::max(std::abs(output), std::abs(output_as_native_ref_t)); - } - - // For subnormals, we need to set absolute error to the smallest positive - // representable value due to hardware implementations that truncate - // subnormals to zero. - if (IsSubnormal(output)) { - return std::numeric_limits::min(); - } - - return 0.0; -} - -template -double DivTpuBf16F32AbsErr(NativeT left, NativeT right) { - NativeRefT reciprocal = 1.0f / static_cast(right); - NativeT output = left / right; - NativeRefT output_as_native_ref_t = - static_cast(left) / static_cast(right); - - // If we calculate NaN, we don't need to adjust tolerances. - if (std::isnan(output_as_native_ref_t)) { - return 0.0; - } - - // TPUs perform `left * (1 / right)`, where `left` and `1 / right` are - // flushed to `0` if they are subnormal. Also applies to if reciprocal is min - // normal. - if (IsSubnormal(left) || IsSubnormalOrMinNormal(reciprocal)) { - // Subnormals can have a larger value in BF16 than float due to rounding to - // the nearest BF16 value during conversion while having less representation - // bits. For normals, the float value is usually always bigger due to - // greater precision. - return std::max(std::abs(output), std::abs(output_as_native_ref_t)); - } - - // For subnormals, we need to set absolute error to the smallest positive - // representable value due to hardware implementations that truncate - // subnormals to zero. - if (IsSubnormalOrMinNormal(output)) { - return std::numeric_limits::min(); - } - return 0.0; } -template -bool DivTpuBf16F32Skip(NativeT left, NativeT right) { - NativeRefT reciprocal = 1.0f / right; - - // TPU calculates `left * (1 / right)` and flushed `(1 / right)` to `0` when - // it is subnormal or min normal. It also follows the IEEE multiplication spec - // that inf * 0 is NaN. However, IEEE division of infinity by a subnormal is - // infinity, so we must skip comparison. - if (std::isinf(left) && IsSubnormalOrMinNormal(reciprocal)) { - return true; - } - - return false; -} - BINARY_TEST(Div, { - ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - }; - - if (IsCpu(platform_) && - (std::is_same_v || - std::is_same_v || std::is_same_v)) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(DivCpuTpuAbsErr(left, right)) - .strict_signed_zeros() - .build(); - }; - } - - if (IsGpu(platform_)) { - if (std::is_same_v) { - error_spec_gen = +[](NativeT, NativeT) { - return ErrorSpec::Builder() - .distance_err(1) - .strict_signed_zeros() - .build(); - }; - } else if (std::is_same_v) { - error_spec_gen = +[](NativeT, NativeT) { - return ErrorSpec::Builder() - .distance_err(2) - .strict_signed_zeros() - .build(); - }; - } - } - - if (IsTpu(platform_)) { - if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(DivTpuBf16F32AbsErr(left, right)) - .strict_signed_zeros() - .skip_comparison( - DivTpuBf16F32Skip(left, right)) - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - // This is basically distance_err(1), but is tighter because it - // guarantees this only happens when the abs_err is less than min - // normal. - .abs_err(std::numeric_limits::min()) - .strict_signed_zeros() - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(DivTpuAbsErr(left, right)) - .distance_err(2) - .strict_signed_zeros() - .skip_comparison( - DivTpuBf16F32Skip(left, right)) - .build(); - }; - } - } - if (IsPreV6Tpu(platform_)) { - if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder() - .abs_err(DivTpuAbsErr(left, right)) - .rel_err(34 * eps) - .strict_signed_zeros() - .skip_comparison( - DivTpuBf16F32Skip(left, right)) - .build(); - }; - } - } - if (IsPreV5Tpu(platform_)) { - if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder() - .abs_err(DivTpuBf16F32AbsErr(left, right)) - .rel_err(eps) - .strict_signed_zeros() - .skip_comparison( - DivTpuBf16F32Skip(left, right)) - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder() - .abs_err(DivTpuAbsErr(left, right)) - .rel_err(136 * eps) - .strict_signed_zeros() - .skip_comparison( - DivTpuBf16F32Skip(left, right)) - .build(); - }; - } - } - - Run( - AddEmptyBroadcastDimension(Div), - [](NativeRefT x, NativeRefT y) { return x / y; }, error_spec_gen); + DivOp(this) + .CpuError(+[](NativeT left, NativeT right) { + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder() + .abs_err(DivCpuAbsErr(left, right)) + .strict_signed_zeros() + .build(); + } + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .GpuError(+[](NativeT, NativeT) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .distance_err(1) + .strict_signed_zeros() + .build(); + } else if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .distance_err(2) + .strict_signed_zeros() + .build(); + } + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .Run(); }) // Can be thought of as an absolute error of @@ -388,61 +205,42 @@ double MaxMinCpuAbsErr(NativeT left, NativeT right) { } BINARY_TEST(Max, { - ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - }; - - if (IsCpu(platform_) && - (std::is_same_v || - std::is_same_v || std::is_same_v)) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(MaxMinCpuAbsErr(left, right)) - .strict_signed_zeros() - .build(); - }; - } - - if (IsGpu(platform_) || IsTpu(platform_)) { - error_spec_gen = +[](NativeT, NativeT) { - // A100 and H100 return -0 for max(-0,0). - // - // TPUs return -0 for max(0,-0) and 0 for max(-0,0). - return ErrorSpec::Builder().strict_signed_zeros(false).build(); - }; - } - - Run(AddEmptyBroadcastDimension(Max), ReferenceMax, - error_spec_gen); + MaxOp(this) + .CpuError(+[](NativeT left, NativeT right) { + if ((std::is_same_v || + std::is_same_v || + std::is_same_v)) { + return ErrorSpec::Builder() + .abs_err(MaxMinCpuAbsErr(left, right)) + .strict_signed_zeros() + .build(); + } + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .GpuError(+[](NativeT, NativeT) { + // A100 and H100 return -0 for max(-0,0). + return ErrorSpec::Builder().strict_signed_zeros(false).build(); + }) + .Run(); }) BINARY_TEST(Min, { - ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - }; - - if (IsCpu(platform_) && - (std::is_same_v || - std::is_same_v || std::is_same_v)) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(MaxMinCpuAbsErr(left, right)) - .strict_signed_zeros() - .build(); - }; - } - - if (IsGpu(platform_) || IsTpu(platform_)) { - error_spec_gen = +[](NativeT, NativeT) { - // A100 and H100 return 0 for min(0,-0). - // - // TPUs return 0 for min(-0,0) and -0 for min(0,-0). - return ErrorSpec::Builder().strict_signed_zeros(false).build(); - }; - } - - Run(AddEmptyBroadcastDimension(Min), ReferenceMin, - error_spec_gen); + MinOp(this) + .CpuError(+[](NativeT left, NativeT right) { + if (std::is_same_v || + std::is_same_v || std::is_same_v) { + return ErrorSpec::Builder() + .abs_err(MaxMinCpuAbsErr(left, right)) + .strict_signed_zeros() + .build(); + } + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .GpuError(+[](NativeT, NativeT) { + // A100 and H100 return 0 for min(0,-0). + return ErrorSpec::Builder().strict_signed_zeros(false).build(); + }) + .Run(); }) template @@ -476,17 +274,6 @@ double PowCpuBf16F32AbsErr(NativeT left, NativeT right) { return 0.0; } -double PowTpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) { - float output = std::pow(static_cast(left), static_cast(right)); - - // Output is flushed to 0 if subnormal. - if (IsSubnormal(output)) { - return std::numeric_limits::min(); - } - - return 0.0; -} - bool PowCpuF64Skip(double left, double right) { // Hardware returns 0 if right is positive and inf otherwise. if ((IsSubnormal(left) || std::isinf(left) || left == 0) && @@ -509,119 +296,43 @@ bool PowCpuGpuF16Skip(NativeT left, NativeT right) { return false; } -template -bool PowTpuSkip(NativeT left, NativeT right) { - // Hardware always returns 1 if right is 0 (or subnormal due to - // flushing subnormals to zero before the operation), no matter if left is - // NaN. - if (std::isnan(left) && (right == 0.0f || IsSubnormal(right))) { - return true; - } - // Hardware always returns 1 if left is 1, no matter if right is NaN. - if (left == 1.0f && std::isnan(right)) { - return true; - } - - return false; -} - BINARY_TEST(Pow, { - ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - }; - - if (IsCpu(platform_)) { - if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .strict_signed_zeros() - .skip_comparison(PowCpuGpuF16Skip(left, right)) - .build(); - }; - } else if constexpr (std::is_same_v || - std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(PowCpuBf16F32AbsErr(left, right)) - .strict_signed_zeros() - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .strict_signed_zeros() - .skip_comparison(PowCpuF64Skip(static_cast(left), - static_cast(right))) - .build(); - }; - } - } - - if (IsGpu(platform_)) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .distance_err(1) - .strict_signed_zeros() - .skip_comparison(PowCpuGpuF16Skip(left, right)) - .build(); - }; - } - - if (IsTpu(platform_)) { - if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(PowTpuBf16AbsErr(static_cast(left), - static_cast(right))) - .distance_err(1) - .strict_signed_zeros() - .skip_comparison(PowTpuSkip(left, right)) - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { + PowOp(this) + .CpuError(+[](NativeT left, NativeT right) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder() + .distance_err(1) + .strict_signed_zeros() + .build(); + } else if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .strict_signed_zeros() + .skip_comparison(PowCpuGpuF16Skip(left, right)) + .build(); + } else if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder() + .abs_err(PowCpuBf16F32AbsErr(left, right)) + .strict_signed_zeros() + .build(); + } else if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .strict_signed_zeros() + .skip_comparison(PowCpuF64Skip(static_cast(left), + static_cast(right))) + .build(); + } + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .GpuError(+[](NativeT left, NativeT right) { return ErrorSpec::Builder() .distance_err(1) .strict_signed_zeros() - .skip_comparison(PowTpuSkip(left, right)) - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .distance_err(8) - .strict_signed_zeros() - .skip_comparison(PowTpuSkip(left, right)) - .build(); - }; - } - } - if (IsPreV6Tpu(platform_)) { - if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder() - .rel_err(41 * eps) - .strict_signed_zeros() - .skip_comparison(PowTpuSkip(left, right)) - .build(); - }; - } - } - if (IsPreV5Tpu(platform_)) { - if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder() - .rel_err(44 * eps) - .strict_signed_zeros() - .skip_comparison(PowTpuSkip(left, right)) + .skip_comparison(PowCpuGpuF16Skip(left, right)) .build(); - }; - } - } - - Run(AddEmptyBroadcastDimension(Pow), std::pow, error_spec_gen); + }) + .Run(); }) // Can be thought of as an absolute error of @@ -630,13 +341,11 @@ template double Atan2CpuBf16F32F64AbsErr(NativeT left, NativeT right) { NativeRefT output = std::atan2(static_cast(left), static_cast(right)); - // If the output would be a subnormal float, we allow some error to account // for BF16 implementation flushing subnormals to zero. if (IsSubnormal(output)) { return std::numeric_limits::min(); } - return 0.0; } @@ -648,151 +357,72 @@ bool Atan2CpuBf16F32Skip(NativeT left, NativeT right) { if (IsSubnormal(left) && IsSubnormal(right)) { return true; } - return false; } -template -double Atan2TpuBf16F32AbsErr(NativeT left, NativeT right) { - NativeT output = static_cast(std::atan2(left, right)); - NativeRefT output_as_float = - std::atan2(static_cast(left), static_cast(right)); - - // If the output would be a subnormal float, we allow some error to account - // for BF16 implementation flushing subnormals to zero. TPUs also seem to - // flush the minimum value to 0 along with subnormals. - if (IsSubnormalOrMinNormal(output_as_float)) { - return std::numeric_limits::min(); - } - - // Implementation of Atan2 on TPUs is that they take the reciprocal of the - // larger of left or right. If this is subnormal or the minimum value, the TPU - // flushes it to 0 before using it in multiplication. When this happens, the - // error is the output calculation, either in BF16 or float, or PI/2, - // depending on which of the three is bigger. - NativeRefT reciprocal_as_float = - 1.0f / std::max(std::abs(static_cast(left)), - std::abs(static_cast(right))); - if (!std::isnan(output_as_float) && IsSubnormal(reciprocal_as_float)) { - return std::max({std::abs(output_as_float), std::abs(output), - static_cast(M_PI_2)}); - } - - return 0.0; -} - BINARY_TEST(Atan2, { - auto error_spec_gen = +[](NativeT, NativeT) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - }; - - if (IsCpu(platform_)) { - if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(Atan2CpuBf16F32F64AbsErr(left, right)) - .strict_signed_zeros() - .skip_comparison(Atan2CpuBf16F32Skip(left, right)) - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(Atan2CpuBf16F32F64AbsErr(left, right)) - // Only used when right is subnormal. - .distance_err(2) - .strict_signed_zeros() - .skip_comparison(Atan2CpuBf16F32Skip(left, right)) - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(Atan2CpuBf16F32F64AbsErr(left, right)) - .strict_signed_zeros() - .build(); - }; - } - } - - if (IsGpu(platform_)) { - if constexpr (std::is_same_v || - std::is_same_v) { - error_spec_gen = +[](NativeT, NativeT) { - return ErrorSpec::Builder() - .distance_err(1) - .strict_signed_zeros() - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT, NativeT) { - return ErrorSpec::Builder() - .distance_err(3) - .strict_signed_zeros() - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT, NativeT) { - return ErrorSpec::Builder() - .distance_err(2) - .strict_signed_zeros() - .build(); - }; - } - } - - if (IsTpu(platform_)) { - if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(Atan2TpuBf16F32AbsErr(left, right)) - .distance_err(1) - .strict_signed_zeros() - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .distance_err(1) - .strict_signed_zeros() - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - return ErrorSpec::Builder() - .abs_err(Atan2TpuBf16F32AbsErr(left, right)) - .distance_err(3) - .strict_signed_zeros() - .build(); - }; - } - } - if (IsPreV6Tpu(platform_)) { - if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder() - .abs_err(Atan2TpuBf16F32AbsErr(left, right)) - .rel_err(28 * eps) - .strict_signed_zeros() - .build(); - }; - } - } - if (IsPreV5Tpu(platform_)) { - if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeT left, NativeT right) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder() - .abs_err(Atan2TpuBf16F32AbsErr(left, right)) - .rel_err(133 * eps) - .strict_signed_zeros() - .build(); - }; - } - } - - Run(AddEmptyBroadcastDimension(Atan2), std::atan2, error_spec_gen); + Atan2Op(this) + .CpuError([](NativeT left, NativeT right) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder() + .distance_err(1) + .strict_signed_zeros() + .build(); + + } else if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .abs_err( + Atan2CpuBf16F32F64AbsErr(left, right)) + .strict_signed_zeros() + .skip_comparison(Atan2CpuBf16F32Skip(left, right)) + .build(); + } else if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .abs_err( + Atan2CpuBf16F32F64AbsErr(left, right)) + // Only used when right is subnormal. + .distance_err(2) + .strict_signed_zeros() + .skip_comparison(Atan2CpuBf16F32Skip(left, right)) + .build(); + } else if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .abs_err( + Atan2CpuBf16F32F64AbsErr(left, right)) + .strict_signed_zeros() + .build(); + } + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .GpuError(+[](NativeT, NativeT) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder() + .distance_err(1) + .strict_signed_zeros() + .build(); + } + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder() + .distance_err(1) + .strict_signed_zeros() + .build(); + } else if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .distance_err(3) + .strict_signed_zeros() + .build(); + } else if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .distance_err(2) + .strict_signed_zeros() + .build(); + } + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .Run(); }) // Can be thought of as an absolute error of @@ -810,26 +440,7 @@ double AbsComplexCpuAbsErr(NativeRefT real, NativeRefT imag) { template bool AbsComplexSkip(NativeRefT real, NativeRefT imag) { // TODO(timshen): see b/162664705. - if (std::isnan(real) || std::isnan(imag)) { - return true; - } - return false; -} - -template -double AbsComplexTpuRelErr(NativeRefT real, NativeRefT imag) { - NativeRefT abs_max = std::max(std::abs(real), std::abs(imag)); - NativeRefT kOne(1); - NativeRefT reciprocal = kOne / abs_max; - if (IsSubnormal(reciprocal)) { - // In this case, the reciprocal erroneously returns zero, and - // we get max(|real|, |imag|) instead of sqrt(real^2 + imag^2), - // so the relative error can be as large as (sqrt(2)-1)/sqrt(2) ~= 0.293, - // when using the typical hypot implementation hypot(max, min) = max * - // sqrt(1 + min / max). - return 0.293; - } - return 0.0; + return std::isnan(real) || std::isnan(imag); } // It is more convenient to implement Abs(complex) as a binary op than a unary @@ -838,65 +449,33 @@ double AbsComplexTpuRelErr(NativeRefT real, NativeRefT imag) { // TODO(bixia): May want to move this test to unary test if we will be able to // implement Abs(complex) as unary conveniently. BINARY_TEST_COMPLEX(AbsComplex, { - ErrorSpecGen error_spec_gen = +[](NativeRefT, NativeRefT) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - }; - - if (IsCpu(platform_)) { - if constexpr (std::is_same_v || - std::is_same_v) { - error_spec_gen = +[](NativeRefT real, NativeRefT imag) { - return ErrorSpec::Builder() - .abs_err(AbsComplexCpuAbsErr(real, imag)) - .distance_err(2) - .skip_comparison(AbsComplexSkip(real, imag)) - .build(); - }; - } - } - - if (IsGpu(platform_)) { - if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeRefT real, NativeRefT imag) { - return ErrorSpec::Builder() - .distance_err(3) - .skip_comparison(AbsComplexSkip(real, imag)) - .build(); - }; - } else if constexpr (std::is_same_v) { - error_spec_gen = +[](NativeRefT real, NativeRefT imag) { - return ErrorSpec::Builder() - .distance_err(2) - .skip_comparison(AbsComplexSkip(real, imag)) - .build(); - }; - } - } - - if (IsTpu(platform_)) { - error_spec_gen = +[](NativeRefT real, NativeRefT imag) { - return ErrorSpec::Builder() - .rel_err(AbsComplexTpuRelErr(real, imag)) - .distance_err(3) - .skip_comparison(AbsComplexSkip(real, imag)) - .build(); - }; - } - if (IsPreV6Tpu(platform_)) { - error_spec_gen = +[](NativeRefT real, NativeRefT imag) { - return ErrorSpec::Builder() - .rel_err(AbsComplexTpuRelErr(real, imag)) - .distance_err(125) - .skip_comparison(AbsComplexSkip(real, imag)) - .build(); - }; - } - - Run([](XlaOp x, XlaOp y) { return Abs(Complex(x, y)); }, - [](NativeRefT x, NativeRefT y) { - return std::abs(std::complex(x, y)); - }, - error_spec_gen); + AbsComplexOp(this) + .CpuError(+[](NativeRefT real, NativeRefT imag) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder() + .abs_err(AbsComplexCpuAbsErr(real, imag)) + .distance_err(2) + .skip_comparison(AbsComplexSkip(real, imag)) + .build(); + } + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .GpuError(+[](NativeRefT real, NativeRefT imag) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .distance_err(3) + .skip_comparison(AbsComplexSkip(real, imag)) + .build(); + } else if constexpr (std::is_same_v) { + return ErrorSpec::Builder() + .distance_err(2) + .skip_comparison(AbsComplexSkip(real, imag)) + .build(); + } + return ErrorSpec::Builder().strict_signed_zeros().build(); + }) + .Run(); }) } // namespace diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_ops.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_ops.inc new file mode 100644 index 00000000000000..611eb2c231aa46 --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_test_ops.inc @@ -0,0 +1,80 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define DEFINE_BINARY_TEST_OP(NAME, ENQUEUE, EVALUATE) \ + template \ + class NAME final : public BinaryTestOp { \ + public: \ + using Traits = BinaryTestOp::Traits; \ + using Test = BinaryTestOp::Test; \ + \ + explicit NAME(Test* test) : BinaryTestOp(test) {} \ + ~NAME() override {} \ + \ + Traits::EnqueueOp EnqueueOp() const override ENQUEUE; \ + \ + Traits::EvaluateOp EvaluateOp() const override EVALUATE; \ + }; \ + static_assert(true, "") + +DEFINE_BINARY_TEST_OP( + AddOp, { return AddEmptyBroadcastDimension(Add); }, + { + return +[](typename Traits::NativeRefT x, typename Traits::NativeRefT y) { + return x + y; + }; + }); +DEFINE_BINARY_TEST_OP( + SubOp, { return AddEmptyBroadcastDimension(Sub); }, + { + return +[](typename Traits::NativeRefT x, typename Traits::NativeRefT y) { + return x - y; + }; + }); +DEFINE_BINARY_TEST_OP( + MulOp, { return AddEmptyBroadcastDimension(Mul); }, + { + return +[](typename Traits::NativeRefT x, typename Traits::NativeRefT y) { + return x * y; + }; + }); +DEFINE_BINARY_TEST_OP( + DivOp, { return AddEmptyBroadcastDimension(Div); }, + { + return +[](typename Traits::NativeRefT x, typename Traits::NativeRefT y) { + return x / y; + }; + }); +DEFINE_BINARY_TEST_OP( + MaxOp, { return AddEmptyBroadcastDimension(Max); }, + { return ReferenceMax; }); +DEFINE_BINARY_TEST_OP( + MinOp, { return AddEmptyBroadcastDimension(Min); }, + { return ReferenceMin; }); +DEFINE_BINARY_TEST_OP( + PowOp, { return AddEmptyBroadcastDimension(Pow); }, { return std::pow; }); +DEFINE_BINARY_TEST_OP( + Atan2Op, { return AddEmptyBroadcastDimension(Atan2); }, + { return std::atan2; }); +DEFINE_BINARY_TEST_OP( + AbsComplexOp, + { return +[](XlaOp x, XlaOp y) { return Abs(Complex(x, y)); }; }, + { + return +[](typename Traits::NativeRefT x, typename Traits::NativeRefT y) { + return std::abs(std::complex(x, y)); + }; + }); + +#undef DEFINE_BINARY_TEST_OP \ No newline at end of file diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test.h b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test.h new file mode 100644 index 00000000000000..524ab7f53fb289 --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test.h @@ -0,0 +1,74 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TESTS_EXHAUSTIVE_EXHAUSTIVE_OP_TEST_H_ +#define XLA_TESTS_EXHAUSTIVE_EXHAUSTIVE_OP_TEST_H_ + +#include + +#include "xla/tests/exhaustive/exhaustive_op_test_base.h" +#include "xla/tests/exhaustive/platform.h" + +namespace xla { +namespace exhaustive_op_test { + +// openXLA-specific ExhaustiveOpTestBase subclass. +// +// Holds utility functions related to determining the execution platform. +// +// Type Parameters: +// - T: The primitive type being tested. +// - N: The number of operands that the function being tested takes. +// +// Pure Virtual Functions: +// - GetInputSize +// - FillInput +template +class ExhaustiveOpTest : public ExhaustiveOpTestBase { + public: + using Traits = ExhaustiveOpTestBase::Traits; + + ExhaustiveOpTest() : platform_(*this->client_->platform()) {} + + bool RelaxedDenormalSigns() const override { + return !platform_.IsNvidiaGpu(); + } + + const Platform& Platform() { return platform_; } + + // DEPRECATED: Only kept until exhaustive_unary_complex_test is merged into + // exhaustive_unary_test. Use the new TestOp framework for + // exhaustive_unary_test. + bool IsGpu() const { return platform_.IsGpu(); } + bool IsCpu() const { return platform_.IsCpu(); } + + static typename Traits::ErrorSpecGen GetDefaultSpecGenerator() { + return exhaustive_op_test::GetDefaultSpecGenerator(); + } + + protected: + const class Platform platform_; +}; + +template +using ExhaustiveUnaryTest = ExhaustiveOpTest; + +template +using ExhaustiveBinaryTest = ExhaustiveOpTest; + +} // namespace exhaustive_op_test +} // namespace xla + +#endif // XLA_TESTS_EXHAUSTIVE_EXHAUSTIVE_OP_TEST_H_ diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_base.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_base.cc index b18bbec9c5ad04..accf828a599913 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_base.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_base.cc @@ -40,9 +40,10 @@ limitations under the License. #include "absl/types/span.h" #include "xla/bit_cast.h" #include "xla/client/executable_build_options.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" #include "xla/executable_run_options.h" +#include "xla/fp_util.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/shaped_buffer.h" @@ -51,6 +52,7 @@ limitations under the License. #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/util/command_line_flags.h" #include "xla/types.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/file_system.h" #include "tsl/platform/path.h" @@ -60,10 +62,6 @@ limitations under the License. namespace xla { namespace exhaustive_op_test { -int eup_version = 0; - -int GetEupVersion() { return eup_version; } - bool dump_values = false; bool ShouldDumpValues() { return dump_values; } @@ -198,7 +196,6 @@ int GetCacheLocation(const std::array& input) { } // The inverse function of GetCacheLocation. - template ::value>::type* = nullptr> RetT FromCacheLocationComponent(int cache_loc) { @@ -568,11 +565,11 @@ ExhaustiveOpTestBase::GetTestValuesWithSubnormalSubstitutions( ComponentNativeRefT value) { std::vector test_values; if (std::fpclassify(value) == FP_SUBNORMAL) { - test_values.reserve(relaxed_denormal_signs_ ? 3 : 2); + test_values.reserve(RelaxedDenormalSigns() ? 3 : 2); test_values.push_back(std::copysign(0, value)); test_values.push_back( std::copysign(std::numeric_limits::min(), value)); - if (relaxed_denormal_signs_) { + if (RelaxedDenormalSigns()) { test_values.push_back(std::copysign(0, -value)); } } else { @@ -869,11 +866,15 @@ template class ExhaustiveOpTestBase; template class ExhaustiveOpTestBase; template class ExhaustiveOpTestBase; template class ExhaustiveOpTestBase; +template class ExhaustiveOpTestBase; +template class ExhaustiveOpTestBase; template class ExhaustiveOpTestBase; template class ExhaustiveOpTestBase; template class ExhaustiveOpTestBase; template class ExhaustiveOpTestBase; +template class ExhaustiveOpTestBase; +template class ExhaustiveOpTestBase; } // namespace exhaustive_op_test } // namespace xla diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_base.h b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_base.h index 9f90e64f9bc392..1f4f515046aa89 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_base.h +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_base.h @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include @@ -28,8 +27,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/bit_cast.h" -#include "xla/client/xla_computation.h" #include "xla/executable_run_options.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/literal.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/exhaustive/error_spec.h" @@ -40,16 +39,11 @@ limitations under the License. namespace xla { namespace exhaustive_op_test { -// Access this through GetEupVersion. -extern int eup_version; - -// Get the TPU EUP version (if it was provided). -int GetEupVersion(); - // Return if the user specified dumping all tested values with their expected // and actual results. bool ShouldDumpValues(); +// Add all extra CLI flags that are used by ExhaustiveOpTestBase. void AddExhaustiveFlags(std::vector& flag_list); // Base class from which all exhaustive tests should inherit. @@ -60,10 +54,16 @@ void AddExhaustiveFlags(std::vector& flag_list); // Type Parameters: // - T: The primitive type being tested. // - N: The number of operands that the function being tested takes. +// +// Pure Virtual Functions: +// - GetInputSize +// - FillInput +// - RelaxedDenormalSigns template class ExhaustiveOpTestBase : public ClientLibraryTestBase { public: using Traits = ExhaustiveOpTestTraits; + static constexpr PrimitiveType kT = Traits::kT; using NativeT = typename Traits::NativeT; using NativeRefT = typename Traits::NativeRefT; @@ -85,10 +85,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { using ErrorSpecGen = typename Traits::ErrorSpecGen; ExhaustiveOpTestBase() - : ty_(T), - platform_(client_->platform()->Name()), - eup_version_(xla::exhaustive_op_test::GetEupVersion()), - should_dump_values_(xla::exhaustive_op_test::ShouldDumpValues()) { + : should_dump_values_(xla::exhaustive_op_test::ShouldDumpValues()) { SetFastMathDisabled(true); // Run all HLO passes. In particular, constant folding is disabled by @@ -105,8 +102,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { // uint64_t. This function is used to convert such a bit pattern stored as // uint64_t to the input value for T. static ComponentNativeT ConvertValue(uint64_t bits) { - using I = ComponentIntegralNativeT; - I used_bits = static_cast(bits); + auto used_bits = static_cast(bits); return BitCast(used_bits); } @@ -116,13 +112,22 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { // Fills the literals with values to test for. virtual void FillInput(LiteralInputs* literals) = 0; + // If true, allows denormals to be flushed to non-sign-preserving 0. + // + // For example, normally we'd expect sqrt(-denormal) to be either nan (sqrt of + // a negative number) or -inf (flush the denormal to sign-preserving zero, + // then sqrt(-0)). When true, we'll also accept 0 (sqrt(0)). + // + // XLA:GPU preserves denormal signs, but other backends don't. + virtual bool RelaxedDenormalSigns() const = 0; + // Enable debug logging for the invocation of the lambda. // - // This is intended to be used to wrap a call to `Run`, which will then log - // extra debug information for a failure such as the calculated absolute, - // relative, and distance errors. In addition, in an effort to reduce output - // log size, this will trigger an ASSERT failure to early return from a test - // at the first failure. + // This is intended to be used to wrap a call to `Run`, which will then + // log extra debug information for a failure such as the calculated + // absolute, relative, and distance errors. In addition, in an effort to + // reduce output log size, this will trigger an ASSERT failure to early + // return from a test at the first failure. template , int> = 0> void EnableDebugLoggingForScope(Callable&& work) { @@ -218,41 +223,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { ErrorSpecGen error_spec_gen, OutputRangeCheck check_valid_range = nullptr); - const std::string& Platform() { return platform_; } - - bool IsGpu(const std::string& platform) const { return platform == "CUDA"; } - bool IsCpu(const std::string& platform) const { return platform == "Host"; } - bool IsTpu(const std::string& platform) const { - return !IsGpu(platform) && !IsCpu(platform); - } - - int EupVersion() const { return eup_version_; } - bool IsPreV5Tpu(const std::string& platform) const { - return IsTpu(platform) && eup_version_ < 2; - } - bool IsPreV6Tpu(const std::string& platform) const { - return IsTpu(platform) && eup_version_ < 3; - } - protected: - // The primitive type being tested. - const PrimitiveType ty_; - - // The platform under test. - const std::string platform_; - - // Version of the EUP for a TPU target. Only relevant for TPU platforms. - const int eup_version_; - - // If true, allows denormals to be flushed to non-sign-preserving 0. - // - // For example, normally we'd expect sqrt(-denormal) to be either nan (sqrt of - // a negative number) or -inf (flush the denormal to sign-preserving zero, - // then sqrt(-0)). But with this as true, we'll also accept 0 (sqrt(0)). - // - // XLA:GPU preserves denormal signs, but other backends don't. - bool relaxed_denormal_signs_ = platform_ != "CUDA"; - // Indicates if files of the expected and actual values should be dumped. bool should_dump_values_ = false; @@ -261,24 +232,6 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { bool should_emit_debug_logging_ = false; }; -template -class ExhaustiveUnaryTest : public ExhaustiveOpTestBase { - public: - static typename ExhaustiveOpTestTraits::ErrorSpecGen - GetDefaultSpecGenerator() { - return exhaustive_op_test::GetDefaultSpecGenerator(); - } -}; - -template -class ExhaustiveBinaryTest : public ExhaustiveOpTestBase { - public: - static typename ExhaustiveOpTestTraits::ErrorSpecGen - GetDefaultSpecGenerator() { - return exhaustive_op_test::GetDefaultSpecGenerator(); - } -}; - } // namespace exhaustive_op_test } // namespace xla diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc index aa1a501a73b42c..11408952e920af 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc @@ -15,11 +15,46 @@ limitations under the License. #include "xla/tests/exhaustive/exhaustive_op_test_utils.h" +#include + +#include "xla/tests/exhaustive/error_spec.h" #include "xla/types.h" +#include "xla/xla_data.pb.h" namespace xla { namespace exhaustive_op_test { +template +/* static */ typename ExhaustiveOpTestTraits::ErrorSpecGen +ExhaustiveOpTestTraits::FallbackErrorSpecGen() { + if constexpr (N == 1) { + return +[](NativeT) { return ErrorSpec{}; }; + } else if constexpr (N == 2) { + return +[](NativeT, NativeT) { return ErrorSpec{}; }; + } else { + static_assert( + N == 1 || N == 2, + "ExhaustiveOpTestTraits::FallbackErrorSpecGen() is only " + "implemented for N == 1 and N == 2."); + } +} + +template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; + +template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; + bool IsSubnormalReal(xla::complex64 value) { return IsSubnormal(value.real()); } bool IsSubnormalReal(xla::complex128 value) { diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h index 79200f6107d8df..6448c960de2bce 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -37,8 +38,7 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "xla/client/xla_builder.h" -#include "xla/fp_util.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/primitive_util.h" #include "xla/tests/exhaustive/error_spec.h" @@ -50,7 +50,7 @@ namespace exhaustive_op_test { // The primitive type used to compute the reference output. constexpr PrimitiveType Ref(PrimitiveType T) { - return !primitive_util::IsFloatingPointType(T) || T == F64 ? T : F32; + return (!primitive_util::IsFloatingPointType(T) || T == F64) ? T : F32; } // The primitive type of the component of T. If T is not complex, then @@ -117,6 +117,12 @@ class ExhaustiveOpTestTraits { N == 1, ErrorSpec (*)(NativeT), std::conditional_t>>; + + // Returns an ErrorSpecGen that sets no error tolerances. + // + // The intention of this default is to force test writers to tighten bounds at + // least somewhat and not rely on overly large default tolerances. + static ErrorSpecGen FallbackErrorSpecGen(); }; template @@ -188,6 +194,16 @@ inline ErrorSpec DefaultSpecGenerator(xla::bfloat16) { return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build(); } +template <> +inline ErrorSpec DefaultSpecGenerator(tsl::float8_e4m3fn) { + return ErrorSpec::Builder().strict_signed_zeros().build(); +} + +template <> +inline ErrorSpec DefaultSpecGenerator(tsl::float8_e5m2) { + return ErrorSpec::Builder().strict_signed_zeros().build(); +} + template <> inline ErrorSpec DefaultSpecGenerator(double, double) { double atol = kDefaultAbsoluteToleranceSlackFactor * @@ -224,6 +240,18 @@ inline ErrorSpec DefaultSpecGenerator(bfloat16, bfloat16) { return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build(); } +template <> +inline ErrorSpec DefaultSpecGenerator(tsl::float8_e4m3fn, + tsl::float8_e4m3fn) { + return ErrorSpec::Builder().strict_signed_zeros().build(); +} + +template <> +inline ErrorSpec DefaultSpecGenerator(tsl::float8_e5m2, + tsl::float8_e5m2) { + return ErrorSpec::Builder().strict_signed_zeros().build(); +} + template typename ExhaustiveOpTestTraits::ErrorSpecGen GetDefaultSpecGenerator() { // Select overload by casting to fn ptr type. @@ -231,6 +259,21 @@ typename ExhaustiveOpTestTraits::ErrorSpecGen GetDefaultSpecGenerator() { DefaultSpecGenerator); } +template +typename Traits::ErrorSpecGen PickFirstErrorSpecGenPresent( + std::initializer_list error_specs) { + typename Traits::ErrorSpecGen ret = Traits::FallbackErrorSpecGen(); + for (auto it = error_specs.begin(); it != error_specs.end(); it++) { + // Check if the ErrorSpecGen is nullptr to indicate it is not set. Replace + // ret with the first non-nullptr ErrorSpecGen. + if (*it != nullptr) { + ret = *it; + break; + } + } + return ret; +} + // Determines if the real component of the complex number is subnormal (either // sign). // @@ -288,7 +331,7 @@ bool IsMinNormal(NativeT value) { std::is_same_v) { return IsMinNormalReal(value) || IsMinNormalImaginary(value); } else { - return std::abs(value) == std::numeric_limits::min(); + return std::abs(value) == std::numeric_limits::min(); // NOLINT } } @@ -760,7 +803,13 @@ CreateSubnormalExhaustiveRanges() { return ret; } -inline std::vector> CreateExhaustiveF32Ranges() { +inline std::vector> CreateExhaustiveU16Ranges() { + // The entire U16 range is small enough that we don't need to do any + // partitioning. + return {{0, std::numeric_limits::max()}}; +} + +inline std::vector> CreateExhaustiveU32Ranges() { // We break up the 2^32-element space into small-ish chunks to keep peak // memory usage low. std::vector> result; @@ -800,7 +849,7 @@ T ReferenceMin(T x, T y) { inline std::function AddEmptyBroadcastDimension( std::function)> build_method) { - return [&](XlaOp src0, XlaOp src1) -> XlaOp { + return [build_method](XlaOp src0, XlaOp src1) -> XlaOp { return build_method(src0, src1, {}); }; } diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_test_main.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_test_main.cc index 61a0f106a627d0..74fbed5c292b85 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_test_main.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_test_main.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include #include "xla/tests/exhaustive/exhaustive_op_test_base.h" #include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_complex_test.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_complex_test.cc index bdc0ca990268ef..ce6dd4ed4179cf 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_complex_test.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_complex_test.cc @@ -21,9 +21,10 @@ limitations under the License. #include "absl/log/log.h" #include "absl/types/span.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/tests/exhaustive/error_spec.h" +#include "xla/tests/exhaustive/exhaustive_op_test.h" #include "xla/tests/exhaustive/exhaustive_op_test_base.h" #include "xla/tests/exhaustive/exhaustive_op_test_utils.h" #include "xla/tests/test_macros.h" @@ -137,16 +138,19 @@ UNARY_TEST_COMPLEX_64(Sqrt, { Run(Sqrt, [](complex64 x) { return std::sqrt(x); }, error_spec_gen); }) -double RsqrtCpuGpuAbsErr(complex64 x) { - return std::sqrt(std::numeric_limits::min()); +template +double RsqrtCpuGpuAbsErr(NativeT x) { + return std::sqrt(std::numeric_limits::min()); } -double RsqrtCpuGpuRelErr(complex64 x) { +template +double RsqrtCpuGpuRelErr(NativeT x) { // As noted above for Sqrt, the accuracy of sqrt degrades severely for // inputs with inputs with subnormals entries. - constexpr double eps = std::numeric_limits::epsilon(); - constexpr double norm_min = std::numeric_limits::min(); - constexpr double denorm_min = std::numeric_limits::denorm_min(); + constexpr double eps = std::numeric_limits::epsilon(); + constexpr double norm_min = std::numeric_limits::min(); + constexpr double denorm_min = + std::numeric_limits::denorm_min(); if (std::abs(x) < norm_min) { // Gradually loosen the relative tolerance as abs(x) becomes smaller // than norm_min, letting it reach 100% when abs(x) = 10 * denorm_min. @@ -160,22 +164,22 @@ UNARY_TEST_COMPLEX_64(Rsqrt, { return ErrorSpec::Builder().strict_signed_zeros().build(); }; - if (IsCpu(platform_)) { + if (IsCpu()) { error_spec_gen = +[](complex64 x) { return ErrorSpec::Builder() - .abs_err(RsqrtCpuGpuAbsErr(x)) - .rel_err(RsqrtCpuGpuRelErr(x)) + .abs_err(RsqrtCpuGpuAbsErr(x)) + .rel_err(RsqrtCpuGpuRelErr(x)) .skip_comparison(x.real() == 0.0f) .strict_signed_zeros(false) .build(); }; } - if (IsGpu(platform_)) { + if (IsGpu()) { error_spec_gen = +[](complex64 x) { return ErrorSpec::Builder() - .abs_err(RsqrtCpuGpuAbsErr(x)) - .rel_err(RsqrtCpuGpuRelErr(x)) + .abs_err(RsqrtCpuGpuAbsErr(x)) + .rel_err(RsqrtCpuGpuRelErr(x)) .strict_signed_zeros(false) .build(); }; @@ -251,7 +255,7 @@ UNARY_TEST_COMPLEX_128(Log, { return ErrorSpec::Builder().strict_signed_zeros().build(); }; - if (IsCpu(platform_) || IsGpu(platform_)) { + if (IsCpu() || IsGpu()) { error_spec_gen = +[](complex128 x) { // TODO(rmlarsen): see b/162664705 and b/138578594 bool should_skip = std::isnan(x.real()) || std::isnan(x.imag()); @@ -285,24 +289,38 @@ UNARY_TEST_COMPLEX_128(Sqrt, { }) UNARY_TEST_COMPLEX_128(Rsqrt, { - ErrorSpecGen error_spec_gen = +[](complex128 x) { - // As noted above for Sqrt, the accuracy of sqrt degrades severely for - // inputs with inputs with subnormals entries. - constexpr double norm_min = std::numeric_limits::min(); - constexpr double denorm_min = std::numeric_limits::denorm_min(); - if (std::abs(x) < norm_min) { - // Gradually loosen the relative tolerance as abs(x) becomes smaller - // than norm_min, letting it reach 100% when abs(x) = 10 * denorm_min. + ErrorSpecGen error_spec_gen = +[](complex128) { + return ErrorSpec::Builder().strict_signed_zeros().build(); + }; + + if (IsCpu()) { + error_spec_gen = +[](complex128 x) { return ErrorSpec::Builder() - .abs_err(std::sqrt(std::numeric_limits::min())) - .rel_err(10 * denorm_min / std::abs(x)) + .abs_err(RsqrtCpuGpuAbsErr(x)) + .rel_err(RsqrtCpuGpuRelErr(x)) +#ifdef __aarch64__ + // TODO(b/365620546): ARM and x86 handle complex(inf, nan) + // differently. + .skip_comparison(x.real() == 0.0f || + (std::isinf(x.real()) && std::isnan(x.imag()))) +#else + .skip_comparison(x.real() == 0.0f) +#endif + .strict_signed_zeros(false) .build(); - } - return ErrorSpec::Builder() - .abs_err(std::sqrt(std::numeric_limits::min())) - .rel_err(50 * std::numeric_limits::epsilon()) - .build(); - }; + }; + } + + if (IsGpu()) { + error_spec_gen = +[](complex128 x) { + return ErrorSpec::Builder() + .abs_err(RsqrtCpuGpuAbsErr(x)) + .rel_err(RsqrtCpuGpuRelErr(x)) + .strict_signed_zeros(false) + .build(); + }; + } + Run( Rsqrt, [](complex128 x) { return complex128(1, 0) / std::sqrt(x); }, error_spec_gen); diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_definitions.h b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_definitions.h index 0c5e839c71aa69..2b3fa8f3c34a3a 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_definitions.h +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_definitions.h @@ -16,152 +16,24 @@ limitations under the License. #ifndef XLA_TESTS_EXHAUSTIVE_EXHAUSTIVE_UNARY_TEST_DEFINITIONS_H_ #define XLA_TESTS_EXHAUSTIVE_EXHAUSTIVE_UNARY_TEST_DEFINITIONS_H_ -#include -#include -#include -#include - -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/types/span.h" -#include "xla/literal.h" -#include "xla/tests/exhaustive/exhaustive_op_test_base.h" -#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" -#include "xla/tests/test_macros.h" -#include "tsl/platform/test.h" +#include // IWYU pragma: keep, exhaustive_unary_test_definitions.inc +#include // IWYU pragma: keep, exhaustive_unary_test_definitions.inc +#include // IWYU pragma: keep, exhaustive_unary_test_definitions.inc +#include // IWYU pragma: keep, exhaustive_unary_test_definitions.inc + +#include "absl/log/check.h" // IWYU pragma: keep, exhaustive_unary_test_definitions.inc +#include "absl/log/log.h" // IWYU pragma: keep, exhaustive_unary_test_definitions.inc +#include "absl/types/span.h" // IWYU pragma: keep, exhaustive_unary_test_definitions.inc +#include "xla/literal.h" // IWYU pragma: keep, exhaustive_unary_test_definitions.inc +#include "xla/tests/exhaustive/exhaustive_op_test.h" // IWYU pragma: keep, exhaustive_unary_test_definitions.inc +#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" // IWYU pragma: keep, exhaustive_unary_test_definitions.inc +#include "xla/tests/test_macros.h" // IWYU pragma: keep, exhaustive_unary_test_definitions.inc +#include "tsl/platform/test.h" // IWYU pragma: keep, exhaustive_unary_test_definitions.inc namespace xla { namespace exhaustive_op_test { -// Exhaustive test for unary operations for <= 32bit floating point types. -// -// Test parameter is a tuple containing -// - primitive type under test, -// - (begin, end) range under test, as zero-extended int64_ts bitcast to the -// primitive type under test. -template -class Exhaustive32BitOrLessUnaryTest - : public ExhaustiveUnaryTest, - public ::testing::WithParamInterface> { - public: - // Sets error parameters appropriately for testing tan. - void SetParamsForTan(); - - protected: - using typename ExhaustiveUnaryTest::NativeT; - - private: - int64_t GetInputSize() override { - auto [begin, end] = GetParam(); - return end - begin; - } - - // Generates all the input values for the test. The range of the bit - // representation of the input values is described by the test parameter as - // a pair of int64_t representing the starting bit pattern and the ending - // pattern. Each bit representation is first truncated to the integral type of - // the same bit as the type being tested, if needed, and then bitcasted to the - // type being tested. - void FillInput(std::array* input_literal) override { - using IntegralT = - typename ExhaustiveOpTestBase::ComponentIntegralNativeT; - - auto [begin, end] = GetParam(); - if (VLOG_IS_ON(2)) { - // N.B.: Use INFO directly instead of doing another thread-safe VLOG - // check. - LOG(INFO) << this->SuiteName() << this->TestName() << " Range:"; - LOG(INFO) << "\tfrom=" << begin << "; hex=" << std::hex << begin - << "; float=" << *reinterpret_cast(&begin) - << " (inclusive)"; - LOG(INFO) << "\tto=" << end << "; hex=" << std::hex << end - << "; float=" << *reinterpret_cast(&end) - << " (exclusive)"; - LOG(INFO) << "\ttotal values to test=" << (end - begin); - } - - int64_t input_size = (*input_literal)[0].element_count(); - CHECK_EQ(input_size, end - begin); - - absl::Span input_arr = (*input_literal)[0].data(); - for (int64_t i = 0; i < input_size; i++) { - IntegralT input_val = i + begin; - input_arr[i] = this->ConvertValue(input_val); - } - } -}; - -using ExhaustiveF32UnaryTest = Exhaustive32BitOrLessUnaryTest; -using ExhaustiveF16UnaryTest = Exhaustive32BitOrLessUnaryTest; -using ExhaustiveBF16UnaryTest = Exhaustive32BitOrLessUnaryTest; - -// Exhaustive test for unary operations for double. -// -// Test parameter is a tuple containing -// - primitive type under test, -// - FpValues representing a set of double values. -class ExhaustiveF64UnaryTest : public ExhaustiveUnaryTest, - public ::testing::WithParamInterface { - private: - int64_t GetInputSize() override { - FpValues values = GetParam(); - return values.GetTotalNumValues(); - } - - void FillInput(std::array* input_literal) override { - FpValues fp_values = GetParam(); - int64_t input_size = (*input_literal)[0].element_count(); - if (VLOG_IS_ON(2)) { - // N.B.: Use INFO directly instead of doing another thread-safe VLOG - // check. - LOG(INFO) << this->SuiteName() << this->TestName() << " Values:"; - LOG(INFO) << "\t" << fp_values.ToString(); - LOG(INFO) << "\ttotal values to test=" << input_size; - } - - uint64_t i = 0; - absl::Span input_arr = (*input_literal)[0].data(); - for (auto bits : fp_values) { - input_arr[i] = this->ConvertValue(bits); - ++i; - } - CHECK_EQ(i, input_size); - } -}; - -#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 -#define UNARY_TEST_BF16(test_name, ...) \ - XLA_TEST_P(ExhaustiveBF16UnaryTest, test_name) \ - __VA_ARGS__ -#else -#define UNARY_TEST_BF16(test_name, ...) -#endif - -#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) -#define UNARY_TEST_F16(test_name, ...) \ - XLA_TEST_P(ExhaustiveF16UnaryTest, test_name) \ - __VA_ARGS__ -#else -#define UNARY_TEST_F16(test_name, ...) -#endif - -#define UNARY_TEST_F32(test_name, ...) \ - XLA_TEST_P(ExhaustiveF32UnaryTest, test_name) \ - __VA_ARGS__ - -#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) -#define UNARY_TEST_F64(test_name, ...) \ - XLA_TEST_P(ExhaustiveF64UnaryTest, test_name) \ - __VA_ARGS__ -#else -#define UNARY_TEST_F64(test_name, ...) -#endif - -#define UNARY_TEST(test_name, ...) \ - UNARY_TEST_BF16(test_name, __VA_ARGS__) \ - UNARY_TEST_F16(test_name, __VA_ARGS__) \ - UNARY_TEST_F32(test_name, __VA_ARGS__) \ - UNARY_TEST_F64(test_name, __VA_ARGS__) +#include "xla/tests/exhaustive/exhaustive_unary_test_definitions.inc" } // namespace exhaustive_op_test } // namespace xla diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_definitions.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_definitions.inc new file mode 100644 index 00000000000000..5c023eec713643 --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_definitions.inc @@ -0,0 +1,156 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Exhaustive test for unary operations for <= 32bit floating point types. +// +// Test parameter is a tuple containing +// - primitive type under test, +// - (begin, end) range under test, as zero-extended int64_ts bitcast to the +// primitive type under test. +template +class Exhaustive32BitOrLessUnaryTest + : public ExhaustiveUnaryTest, + public ::testing::WithParamInterface> { + protected: + int64_t GetInputSize() override { + auto [begin, end] = GetParam(); + return end - begin; + } + + // Generates all the input values for the test. The range of the bit + // representation of the input values is described by the test parameter as + // a pair of int64_t representing the starting bit pattern and the ending + // pattern. Each bit representation is first truncated to the integral type of + // the same bit as the type being tested, if needed, and then bitcasted to the + // type being tested. + void FillInput(std::array* input_literal) override { + using NativeT = typename ExhaustiveUnaryTest::NativeT; + using ComponentIntegralNativeT = + typename ExhaustiveUnaryTest::ComponentIntegralNativeT; + + auto [begin, end] = GetParam(); + if (VLOG_IS_ON(2)) { + // N.B.: Use INFO directly instead of doing another thread-safe VLOG + // check. + LOG(INFO) << this->SuiteName() << this->TestName() << " Range:"; + LOG(INFO) << "\tfrom=" << begin << "; hex=" << std::hex << begin + << "; float=" << *reinterpret_cast(&begin) + << " (inclusive)"; + LOG(INFO) << "\tto=" << end << "; hex=" << std::hex << end + << "; float=" << *reinterpret_cast(&end) + << " (exclusive)"; + LOG(INFO) << "\ttotal values to test=" << (end - begin); + } + + int64_t input_size = (*input_literal)[0].element_count(); + CHECK_EQ(input_size, end - begin); + + absl::Span input_arr = (*input_literal)[0].data(); + for (int64_t i = 0; i < input_size; i++) { + ComponentIntegralNativeT input_val = + // We guarantee i + begin will be within range. + static_cast(i + begin); + input_arr[i] = this->ConvertValue(input_val); + } + } +}; + +using ExhaustiveF8E4M3FNUnaryTest = Exhaustive32BitOrLessUnaryTest; +using ExhaustiveF8E5M2UnaryTest = Exhaustive32BitOrLessUnaryTest; +using ExhaustiveBF16UnaryTest = Exhaustive32BitOrLessUnaryTest; +using ExhaustiveF16UnaryTest = Exhaustive32BitOrLessUnaryTest; +using ExhaustiveF32UnaryTest = Exhaustive32BitOrLessUnaryTest; + +// Exhaustive test for unary operations for double. +// +// Test parameter is a tuple containing +// - primitive type under test, +// - FpValues representing a set of double values. +class ExhaustiveF64UnaryTest : public ExhaustiveUnaryTest, + public ::testing::WithParamInterface { + protected: + int64_t GetInputSize() override { + FpValues values = GetParam(); + return values.GetTotalNumValues(); + } + + void FillInput(std::array* input_literal) override { + FpValues fp_values = GetParam(); + int64_t input_size = (*input_literal)[0].element_count(); + if (VLOG_IS_ON(2)) { + // N.B.: Use INFO directly instead of doing another thread-safe VLOG + // check. + LOG(INFO) << this->SuiteName() << this->TestName() << " Values:"; + LOG(INFO) << "\t" << fp_values.ToString(); + LOG(INFO) << "\ttotal values to test=" << input_size; + } + + uint64_t i = 0; + absl::Span input_arr = (*input_literal)[0].data(); + for (auto bits : fp_values) { + input_arr[i] = ExhaustiveOpTestBase::ConvertValue(bits); + ++i; + } + CHECK_EQ(i, input_size); + } +}; + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E4M3FN) +#define UNARY_TEST_F8E4M3FN(test_name, ...) \ + XLA_TEST_P(ExhaustiveF8E4M3FNUnaryTest, test_name) \ + __VA_ARGS__ +#else +#define UNARY_TEST_E4M3FN(test_name, ...) +#endif + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E5M2) +#define UNARY_TEST_F8E5M2(test_name, ...) \ + XLA_TEST_P(ExhaustiveF8E5M2UnaryTest, test_name) \ + __VA_ARGS__ +#else +#define UNARY_TEST_E5M2(test_name, ...) +#endif + +#define UNARY_TEST_BF16(test_name, ...) \ + XLA_TEST_P(ExhaustiveBF16UnaryTest, test_name) \ + __VA_ARGS__ + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) +#define UNARY_TEST_F16(test_name, ...) \ + XLA_TEST_P(ExhaustiveF16UnaryTest, test_name) \ + __VA_ARGS__ +#else +#define UNARY_TEST_F16(test_name, ...) +#endif + +#define UNARY_TEST_F32(test_name, ...) \ + XLA_TEST_P(ExhaustiveF32UnaryTest, test_name) \ + __VA_ARGS__ + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +#define UNARY_TEST_F64(test_name, ...) \ + XLA_TEST_P(ExhaustiveF64UnaryTest, test_name) \ + __VA_ARGS__ +#else +#define UNARY_TEST_F64(test_name, ...) +#endif + +#define UNARY_TEST(test_name, ...) \ + UNARY_TEST_F8E4M3FN(test_name, __VA_ARGS__) \ + UNARY_TEST_F8E5M2(test_name, __VA_ARGS__) \ + UNARY_TEST_BF16(test_name, __VA_ARGS__) \ + UNARY_TEST_F16(test_name, __VA_ARGS__) \ + UNARY_TEST_F32(test_name, __VA_ARGS__) \ + UNARY_TEST_F64(test_name, __VA_ARGS__) diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.cc index 69ab47f64412ed..f252590a0edfd0 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.cc @@ -13,34 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include // IWYU pragma: keep, exhaustive_unary_test_f32_and_smaller_instantiation.inc -#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" -#include "xla/tests/exhaustive/exhaustive_unary_test_definitions.h" -#include "tsl/platform/test.h" +#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" // IWYU pragma: keep, exhaustive_unary_test_f32_and_smaller_instantiation.inc +#include "xla/tests/exhaustive/exhaustive_unary_test_definitions.h" // IWYU pragma: keep, exhaustive_unary_test_f32_and_smaller_instantiation.inc +#include "tsl/platform/test.h" // IWYU pragma: keep, exhaustive_unary_test_f32_and_smaller_instantiation.inc namespace xla { namespace exhaustive_op_test { namespace { -#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 -INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16UnaryTest, - ::testing::Values(std::make_pair(0, 1 << 16))); -#else -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16UnaryTest); -#endif - -#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) -INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16UnaryTest, - ::testing::Values(std::make_pair(0, 1 << 16))); -#else -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16UnaryTest); -#endif - -INSTANTIATE_TEST_SUITE_P(F32, ExhaustiveF32UnaryTest, - ::testing::ValuesIn(CreateExhaustiveF32Ranges())); - -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF64UnaryTest); +#include "xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.inc" } // namespace } // namespace exhaustive_op_test diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.inc new file mode 100644 index 00000000000000..efb173686c9849 --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.inc @@ -0,0 +1,43 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E4M3FN) +INSTANTIATE_TEST_SUITE_P(F8E4M3FN, ExhaustiveF8E4M3FNUnaryTest, + ::testing::Values(std::make_pair(0, 1 << 8))); +#else +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E4M3FNUnaryTest); +#endif + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E5M2) +INSTANTIATE_TEST_SUITE_P(F8E5M2, ExhaustiveF8E5M2UnaryTest, + ::testing::Values(std::make_pair(0, 1 << 8))); +#else +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E5M2UnaryTest); +#endif + +INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16UnaryTest, + ::testing::Values(std::make_pair(0, 1 << 16))); + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) +INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16UnaryTest, + ::testing::Values(std::make_pair(0, 1 << 16))); +#else +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16UnaryTest); +#endif + +INSTANTIATE_TEST_SUITE_P(F32, ExhaustiveF32UnaryTest, + ::testing::ValuesIn(CreateExhaustiveU32Ranges())); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF64UnaryTest); diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.cc index 9b94a7ced85959..6271809d97df5f 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.cc @@ -13,37 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" -#include "xla/tests/exhaustive/exhaustive_unary_test_definitions.h" -#include "tsl/platform/test.h" +#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" // IWYU pragma: keep, exhaustive_unary_test_f64_instantiation.inc +#include "xla/tests/exhaustive/exhaustive_unary_test_definitions.h" // IWYU pragma: keep, exhaustive_unary_test_f64_instantiation.inc +#include "tsl/platform/test.h" // IWYU pragma: keep, exhaustive_unary_test_f64_instantiation.inc namespace xla { namespace exhaustive_op_test { namespace { -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16UnaryTest); - -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16UnaryTest); - -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF32UnaryTest); - -#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) -INSTANTIATE_TEST_SUITE_P( - SpecialValues, ExhaustiveF64UnaryTest, - ::testing::ValuesIn(CreateFpValuesForBoundaryTest())); - -INSTANTIATE_TEST_SUITE_P(NormalValues, ExhaustiveF64UnaryTest, - ::testing::Values(GetNormals(1000))); - -// Tests a total of 4,000,000,000 inputs, with 16,000,000 inputs in each -// sub-test, to keep the peak memory usage low. -INSTANTIATE_TEST_SUITE_P( - LargeAndSmallMagnitudeNormalValues, ExhaustiveF64UnaryTest, - ::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals( - 4000000000ull, 16000000))); -#else -GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF64UnaryTest); -#endif // !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +#include "xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.inc" } // namespace } // namespace exhaustive_op_test diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.inc new file mode 100644 index 00000000000000..a2e67ff4f8fb0c --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.inc @@ -0,0 +1,42 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E4M3FNUnaryTest); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E5M2UnaryTest); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16UnaryTest); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16UnaryTest); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF32UnaryTest); + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) +INSTANTIATE_TEST_SUITE_P( + SpecialValues, ExhaustiveF64UnaryTest, + ::testing::ValuesIn(CreateFpValuesForBoundaryTest())); + +INSTANTIATE_TEST_SUITE_P(NormalValues, ExhaustiveF64UnaryTest, + ::testing::Values(GetNormals(1000))); + +// Tests a total of 4,000,000,000 inputs, with 16,000,000 inputs in each +// sub-test, to keep the peak memory usage low. +INSTANTIATE_TEST_SUITE_P( + LargeAndSmallMagnitudeNormalValues, ExhaustiveF64UnaryTest, + ::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals( + 4000000000ull, 16000000))); +#else +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF64UnaryTest); +#endif // !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT64) diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_functions.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_functions.cc index e07742874f6e98..b57ad05476be25 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_functions.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_functions.cc @@ -13,22 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include +#include // IWYU pragma: keep, exhaustive_unary_test_ops.inc +#include // IWYU pragma: keep, exhaustive_unary_test_ops.inc #include // NOLINT #include -#include -#include +#include // IWYU pragma: keep, exhaustive_unary_test_ops.inc #include #include -#include "xla/client/lib/constants.h" -#include "xla/client/lib/math.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/constants.h" // IWYU pragma: keep, exhaustive_unary_test_ops.inc +#include "xla/hlo/builder/lib/math.h" // IWYU pragma: keep, exhaustive_unary_test_ops.inc +#include "xla/hlo/builder/xla_builder.h" // IWYU pragma: keep, exhaustive_unary_test_ops.inc #include "xla/tests/exhaustive/error_spec.h" -#include "xla/tests/exhaustive/exhaustive_op_test_base.h" +#include "xla/tests/exhaustive/exhaustive_op_test.h" // IWYU pragma: keep, exhaustive_unary_test_ops.inc #include "xla/tests/exhaustive/exhaustive_op_test_utils.h" #include "xla/tests/exhaustive/exhaustive_unary_test_definitions.h" +#include "xla/tests/exhaustive/test_op.h" // IWYU pragma: keep, exhaustive_unary_test_ops.inc #include "xla/types.h" #ifdef __FAST_MATH__ @@ -39,624 +39,422 @@ namespace xla { namespace exhaustive_op_test { namespace { -using Eigen::half; +#include "xla/tests/exhaustive/exhaustive_unary_test_ops.inc" -template -T EvaluatePolynomial(T x, const std::array& coeffs) { - // Evaluate the polynomial as accurately as we can using double precision and - // FMA. - double result = 0; - double x_d = static_cast(x); - for (T c : coeffs) { - result = std::fma(result, x_d, static_cast(c)); - } - return static_cast(result); -} - -// There's no std::erfinv, so we have to implement it ourselves. This follows -// Wichura 1998, https://www.jstor.org/stable/2347330 which, notably, is a -// different implementation from that in math.cc. -template -NativeRefT HostErfInv(NativeRefT x) { - std::array kPolyA = { - 8.8709406962545514830200e2, 1.1819493347062294404278e4, - 2.3782041382114385731252e4, 1.6235862515167575384252e4, - 4.8548868893843886794648e3, 6.9706266534389598238465e2, - 4.7072688112383978012285e1, 1.1975323115670912564578e0, - }; - std::array kPolyB = { - 5.2264952788528545610e3, 2.8729085735721942674e4, 3.9307895800092710610e4, - 2.1213794301586595867e4, 5.3941960214247511077e3, 6.8718700749205790830e2, - 4.2313330701600911252e1, 1.0000000000000000000e0, - }; - std::array kPolyC = { - 7.74545014278341407640e-4, 2.27238449892691845833e-2, - 2.41780725177450611770e-1, 1.27045825245236838258e0, - 3.64784832476320460504e0, 5.76949722146069140550e0, - 4.63033784615654529590e0, 1.42343711074968357734e0, - }; - std::array kPolyD = { - 1.4859850019840355905497876e-9, 7.7441459065157709165577218e-4, - 2.1494160384252876777097297e-2, 2.0945065210512749128288442e-1, - 9.7547832001787427186894837e-1, 2.3707661626024532365971225e0, - 2.9036514445419946173133295e0, 1.4142135623730950488016887e0, - }; - std::array kPolyE = { - 2.01033439929228813265e-7, 2.71155556874348757815e-5, - 1.24266094738807843860e-3, 2.65321895265761230930e-2, - 2.96560571828504891230e-1, 1.78482653991729133580e0, - 5.46378491116411436990e0, 6.65790464350110377720e0, - }; - std::array kPolyF = { - 2.891024605872965461538222e-15, 2.010321207683943062279931e-7, - 2.611088405080593625138020e-5, 1.112800997078859844711555e-3, - 2.103693768272068968719679e-2, 1.936480946950659106176712e-1, - 8.482908416595164588112026e-1, 1.414213562373095048801689e0, - }; - - if (std::abs(x) > 1 || std::isnan(x)) { - return std::numeric_limits::quiet_NaN(); - } - if (std::abs(x) == 1) { - return std::copysign(std::numeric_limits::infinity(), x); - } - - double unsigned_result = [&] { - double y = std::abs(x); - if (y <= 0.85) { - double r = 0.180625 - 0.25 * y * y; - return (y * EvaluatePolynomial(r, kPolyA)) / - EvaluatePolynomial(r, kPolyB); - } else { - double r = std::sqrt(std::log(2.0) - std::log1p(-y)); - if (r <= 5.0) { - r -= 1.6; - return EvaluatePolynomial(r, kPolyC) / EvaluatePolynomial(r, kPolyD); - } else { - r -= 5; - return EvaluatePolynomial(r, kPolyE) / EvaluatePolynomial(r, kPolyF); - } - } - }(); - return static_cast(std::copysign(unsigned_result, x)); -} - -// Digamma implementation using a polynomial from Cephes. Notably this is a -// different implementation from the one in math.cc. -template -NativeRefT HostDigamma(NativeRefT x) { - // Euler-Mascheroni constant - double kGamma = 0.57721566490153286061; - double kPi = M_PI; - - std::array kPoly = { - -4.16666666666666666667E-3, - 3.96825396825396825397E-3, - -8.33333333333333333333E-3, - 8.33333333333333333333E-2, - }; - - double reflection = 0; - if (x <= 0) { - double floor = std::floor(x); - if (x == floor) { - return std::numeric_limits::quiet_NaN(); - } - // Compute reflection term, pi * cot(pi * x). - reflection = x - floor; - if (reflection == 0.5) { - reflection = 0; - } else { - if (reflection > 0.5) { - reflection = x - (floor + 1.0f); - } - reflection = kPi / std::tan(kPi * reflection); - } - x = 1 - x; - } - - double result = 0; - if (x <= 10 && x == std::floor(x)) { - // Special case for integers <= 10. - for (int i = 1; i < x; ++i) { - result += 1.0 / i; - } - result -= kGamma; - } else { - double w = 0; - for (; x < 10; ++x) { - w += 1.0 / x; - } - if (x < 1e8) { - double z = 1.0 / (x * x); - result = z * EvaluatePolynomial(z, kPoly); - } - result = std::log(x) - 0.5 / x - result - w; - } - - // Compute the final, reflected value. - return static_cast(result - reflection); -} - -UNARY_TEST(Log, { - auto error_spec_gen = GetDefaultSpecGenerator(); - if (IsPreV6Tpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(2e-4).rel_err(eps).build(); - }; - } - Run(Log, std::log, error_spec_gen); -}) +UNARY_TEST(Log, { LogOp(this).Error(GetDefaultSpecGenerator()).Run(); }) UNARY_TEST(Log1p, { - auto error_spec_gen = GetDefaultSpecGenerator(); - if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(2e-4).rel_err(eps).build(); - }; - } - Run(Log1p, std::log1p, error_spec_gen); -}) -UNARY_TEST(Exp, { - auto error_spec_gen = GetDefaultSpecGenerator(); - if (IsPreV6Tpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT min = std::numeric_limits::min(); - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(min).rel_err(75 * eps).build(); - }; - } else if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT min = std::numeric_limits::min(); - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(min).rel_err(33 * eps).build(); - }; - } - Run(Exp, std::exp, error_spec_gen); + Log1pOp(this) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .GpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .Run(); }) -UNARY_TEST(Expm1, { - auto error_spec_gen = GetDefaultSpecGenerator(); - if (IsPreV6Tpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(1e-5).rel_err(100 * eps).build(); - }; - } else if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT min = std::numeric_limits::min(); - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(2 * min).rel_err(33 * eps).build(); - }; - } - - Run(Expm1, std::expm1, error_spec_gen); +UNARY_TEST(Exp, { + ExpOp(this) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .GpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .Run(); }) +UNARY_TEST(Expm1, { Expm1Op(this).Error(GetDefaultSpecGenerator()).Run(); }) +UNARY_TEST(Exp2, { Exp2Op(this).Error(GetDefaultSpecGenerator()).Run(); }) UNARY_TEST(Logistic, { - // FIXME(rmlarsen): Break into region around zero and everything else. - auto error_spec_gen = GetDefaultSpecGenerator(); - if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - NativeT atol = std::min(static_cast(0.004), - static_cast(200 * eps)); - return ErrorSpec::Builder().abs_err(atol).rel_err(0).build(); - }; - } - EvaluateOp fn = +[](NativeRefT x) { return 1.0f / (1.0f + std::exp(-x)); }; - auto range_checker = +[](NativeInputs in, NativeT out) { - if (Eigen::numext::isnan(in[0])) { - return Eigen::numext::isnan(out); - } - return Eigen::numext::abs(out) <= 1.0f; - }; - Run(Logistic, fn, error_spec_gen, range_checker); + LogisticOp(this) + .OutputRangeCheck(+[](NativeInputs in, NativeT out) { + if (std::isnan(in[0])) { + return std::isnan(out); + } + return std::abs(out) <= 1.0f; + }) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + // FIXME(rmlarsen): Break into region around zero and everything else. + return GetDefaultSpecGenerator()(x); + }) + .GpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + // FIXME(rmlarsen): Break into region around zero and everything else. + return GetDefaultSpecGenerator()(x); + }) + .Run(); }) // It feels a little overkill to exhaustively test sqrt and pow(x, 0.5), but // this *did* find a bug, namely that some backends were assuming sqrt(x) == // pow(x, 0.5), but this is not true for x == -inf. -UNARY_TEST(PowOneHalf, { - EvaluateOp fn = +[](NativeRefT x) { return std::pow(x, 0.5f); }; - Run([](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); }, fn); -}) - +UNARY_TEST(PowOneHalf, + { PowOneHalfOp(this).Error(GetDefaultSpecGenerator()).Run(); }) UNARY_TEST(Rsqrt, { - auto error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder() - .abs_err(0) - .rel_err(2 * eps) - .strict_signed_zeros() - .build(); - }; - Run(Rsqrt, +[](NativeRefT x) { return 1 / std::sqrt(x); }, error_spec_gen); + RsqrtOp(this) + .Error(+[](NativeT x) { + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder() + .abs_err(0) + .rel_err(2 * eps) + .strict_signed_zeros() + .build(); + }) + .Run(); }) - UNARY_TEST(Sqrt, { - auto error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder() - .abs_err(0) - .rel_err(2 * eps) - .strict_signed_zeros() - .build(); - }; - Run(Sqrt, std::sqrt, error_spec_gen); -}) - -UNARY_TEST(Cbrt, { - auto error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder() - .abs_err(0) - .rel_err(16 * eps) - .strict_signed_zeros() - .build(); - }; - if (IsCpu(platform_)) { - error_spec_gen = +[](NativeT x) { - // While GPUs and TPUs flush subnormal inputs to zero, the CPU returns a - // relatively inaccurate approximation for such inputs. Therefore we - // allow a small absolute error (e.g. ~9e-16 for F32). This corresponds - // to a 0.5% relative error for the smallest normalized floating point - // values, increasing gradually to 100% for the smallest subnormal - // value. - NativeT denorm_min = std::numeric_limits::denorm_min(); - double abs_err = std::cbrt(denorm_min); - - if constexpr (std::is_same_v) { + SqrtOp(this) + .Error(+[](NativeT x) { NativeT eps = std::numeric_limits::epsilon(); return ErrorSpec::Builder() - .abs_err(abs_err) - .rel_err(70 * eps) + .abs_err(0) + .rel_err(2 * eps) .strict_signed_zeros() .build(); - } else { + }) + .Run(); +}) +UNARY_TEST(Cbrt, { + CbrtOp(this) + .Error(+[](NativeT x) { NativeT eps = std::numeric_limits::epsilon(); return ErrorSpec::Builder() - .abs_err(abs_err) - .rel_err(10 * eps) + .abs_err(0) + .rel_err(16 * eps) .strict_signed_zeros() .build(); - } - }; - } - Run(Cbrt, std::cbrt, error_spec_gen); + }) + .CpuError(+[](NativeT x) { + // While GPUs flush subnormal inputs to zero, CPU returns a relatively + // inaccurate approximation for such inputs. Therefore we allow a small + // absolute error (e.g. ~9e-16 for F32). This corresponds to a 0.5% + // relative error for the smallest normalized floating point values, + // increasing gradually to 100% for the smallest subnormal value. + NativeT denorm_min = std::numeric_limits::denorm_min(); + double abs_err = std::cbrt(denorm_min); + + if constexpr (std::is_same_v) { + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder() + .abs_err(abs_err) + .rel_err(70 * eps) + .strict_signed_zeros() + .build(); + } else { + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder() + .abs_err(abs_err) + .rel_err(10 * eps) + .strict_signed_zeros() + .build(); + } + }) + .Run(); }) // Tests for inverse hyperbolic functions. UNARY_TEST(Acosh, { - auto error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(1e-7).rel_err(50 * eps).build(); - }; - if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(2e-4).rel_err(eps).build(); - }; - } - Run(Acosh, std::acosh, error_spec_gen); + AcoshOp(this) + .Error(+[](NativeT x) { + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder().abs_err(1e-7).rel_err(50 * eps).build(); + }) + .Run(); }) UNARY_TEST(Asinh, { - auto error_spec_gen = GetDefaultSpecGenerator(); - if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(2e-4).rel_err(eps).build(); - }; - } - Run(Asinh, std::asinh, error_spec_gen); + AsinhOp(this) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .GpuError(+[](NativeT x) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .Run(); }) UNARY_TEST(Atanh, { - auto error_spec_gen = GetDefaultSpecGenerator(); - if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(1e-4).rel_err(eps).build(); - }; - } - Run(Atanh, std::atanh, error_spec_gen); + AtanhOp(this) + .Error(GetDefaultSpecGenerator()) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .Run(); }) // Tests for inverse trigonometric functions. UNARY_TEST(Acos, { - auto error_spec_gen = GetDefaultSpecGenerator(); - if (platform_ != "Host") { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(1e-6).rel_err(10 * eps).build(); - }; - } - Run(Acos, std::acos, error_spec_gen); + AcosOp(this) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .GpuError(+[](NativeT x) { + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder().abs_err(1e-6).rel_err(10 * eps).build(); + }) + .Run(); }) UNARY_TEST(Asin, { - auto error_spec_gen = +[](NativeT x) { - NativeT min = std::numeric_limits::min(); - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(2.0f * min).rel_err(10 * eps).build(); - }; - Run(Asin, std::asin, error_spec_gen); + AsinOp(this) + .Error(+[](NativeT x) { + NativeT min = std::numeric_limits::min(); + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder() + .abs_err(2.0f * min) + .rel_err(10 * eps) + .build(); + }) + .Run(); }) UNARY_TEST(Atan, { - auto error_spec_gen = +[](NativeT x) { - NativeT min = std::numeric_limits::min(); - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(2.0f * min).rel_err(20 * eps).build(); - }; - Run(Atan, std::atan, error_spec_gen); + AtanOp(this) + .Error(+[](NativeT x) { + NativeT min = std::numeric_limits::min(); + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder() + .abs_err(2.0f * min) + .rel_err(20 * eps) + .build(); + }) + .Run(); }) UNARY_TEST(Cosh, { - auto error_spec_gen = GetDefaultSpecGenerator(); - if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - // Cosh is always greater than or equal to 1, so an absolute - // tolerance does not make sense. - return ErrorSpec::Builder().abs_err(0).rel_err(100 * eps).build(); - }; - } - auto range_checker = - +[](NativeInputs in, NativeT actual) { return !(actual < 1); }; - Run(Cosh, std::cosh, error_spec_gen, range_checker); + CoshOp(this) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(3).build(); + } else if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(4).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .GpuError(+[](NativeT x) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .OutputRangeCheck( + +[](NativeInputs in, NativeT actual) { return !(actual < 1); }) + .Run(); }) - UNARY_TEST(Sinh, { - auto error_spec_gen = GetDefaultSpecGenerator(); - if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(1e-5).rel_err(100 * eps).build(); - }; - } - Run(Sinh, std::sinh, error_spec_gen); + SinhOp(this) + .Error(GetDefaultSpecGenerator()) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(3).build(); + } else if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(4).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .GpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .Run(); }) - UNARY_TEST(Tanh, { - auto error_spec_gen = GetDefaultSpecGenerator(); - if (IsPreV6Tpu(platform_)) { - error_spec_gen = +[](NativeT x) { - // The range of tanh is [-1:1], so no point in giving a relative - // tolerance when we have an absolute one. - return ErrorSpec::Builder().abs_err(5e-4).rel_err(0).build(); - }; - } - Run(Tanh, std::tanh, error_spec_gen, - [](NativeInputs in, NativeT out) -> bool { - if (Eigen::numext::isnan(in[0])) { - return Eigen::numext::isnan(out); + TanhOp(this) + .Error(GetDefaultSpecGenerator()) + .OutputRangeCheck([](NativeInputs in, NativeT out) -> bool { + if (std::isnan(in[0])) { + return std::isnan(out); } - return Eigen::numext::abs(out) <= 1.0f; - }); + return std::abs(out) <= 1.0f; + }) + .Run(); }) UNARY_TEST(Cos, { - auto range_checker = - +[](NativeInputs in, NativeT out) { return !(out < -1 || out > 1); }; - Run( - Cos, std::cos, - +[](NativeT) { + CosOp(this) + .Error(+[](NativeT) { // This error spec corresponds to a maximum relative error of 2 ULP. NativeT eps = std::numeric_limits::epsilon(); return ErrorSpec::Builder().abs_err(0).rel_err(2 * eps).build(); - }, - range_checker); + }) + .OutputRangeCheck( + +[](NativeInputs in, NativeT out) { return !(out < -1 || out > 1); }) + .Run(); }) - UNARY_TEST(Sin, { - auto range_checker = - +[](NativeInputs in, NativeT out) { return !(out < -1 || out > 1); }; - Run( - Sin, std::sin, - +[](NativeT) { + SinOp(this) + .Error(+[](NativeT) { // This error spec corresponds to a maximum relative error of 2 ULP. NativeT eps = std::numeric_limits::epsilon(); return ErrorSpec::Builder().abs_err(0).rel_err(2 * eps).build(); - }, - range_checker); -}) + }) + .CpuArmError(+[](NativeT val) { + // Flushes subnormals and minimum positive output to 0. + NativeT output = static_cast(std::sin(val)); + // TODO(b/365622116): Understand why ARM flushes these but x86 doesn't. + if (IsSubnormalOrMinNormal(output)) { + return ErrorSpec::Builder() + .abs_err(std::numeric_limits::min()) + .build(); + } + // This error spec corresponds to a maximum relative error of 2 ULP. + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder().abs_err(0).rel_err(2 * eps).build(); + }) + .OutputRangeCheck( + +[](NativeInputs in, NativeT out) { return !(out < -1 || out > 1); }) + .Run(); +}) UNARY_TEST(Tan, { - Run( - Tan, std::tan, +[](NativeT) { + TanOp(this) + .Error(+[](NativeT) { // This error spec corresponds to a maximum relative error of 4 ULP. NativeT eps = std::numeric_limits::epsilon(); return ErrorSpec::Builder().abs_err(0).rel_err(4 * eps).build(); - }); + }) + .CpuArmError(+[](NativeT val) { + // Flushes positive subnormals and minimum positive output to 0. + NativeT output = static_cast(std::tan(val)); + // TODO(b/365622116): Understand why ARM flushes these but x86 doesn't. + if (IsSubnormalOrMinNormal(output)) { + return ErrorSpec::Builder() + .abs_err(std::numeric_limits::min()) + .build(); + } + + // This error spec corresponds to a maximum relative error of 4 ULP. + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder().abs_err(0).rel_err(4 * eps).build(); + }) + .Run(); }) -UNARY_TEST(Erf, { Run(Erf, std::erf); }) +UNARY_TEST(Erf, { ErfOp(this).Error(GetDefaultSpecGenerator()).Run(); }) UNARY_TEST(Erfc, { - auto error_spec_gen = +[](NativeT x) { - NativeT min = std::numeric_limits::min(); - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(2 * min).rel_err(50 * eps).build(); - }; - if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT min = std::numeric_limits::min(); - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(2 * min).rel_err(100 * eps).build(); - }; - } - Run(Erfc, std::erfc, error_spec_gen); + ErfcOp(this) + .Error(+[](NativeT x) { + NativeT min = std::numeric_limits::min(); + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder().abs_err(min).rel_err(35 * eps).build(); + }) + .Run(); }) UNARY_TEST(ErfInv, { - auto error_spec_gen = +[](NativeT x) { - NativeT min = std::numeric_limits::min(); - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(2 * min).rel_err(50 * eps).build(); - }; - if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(5e-5).rel_err(2 * eps).build(); - }; - } - Run(ErfInv, HostErfInv, error_spec_gen); + ErfInvOp(this) + .Error(+[](NativeT x) { + NativeT min = std::numeric_limits::min(); + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder().abs_err(2 * min).rel_err(50 * eps).build(); + }) + .Run(); }) UNARY_TEST(Digamma, { - auto error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(2e-5).rel_err(10 * eps).build(); - }; - if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(2e-4).rel_err(10 * eps).build(); - }; - } - Run(Digamma, HostDigamma, error_spec_gen); + DigammaOp(this) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder().abs_err(2e-5).rel_err(10 * eps).build(); + }) + .GpuError(+[](NativeT x) { + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder().abs_err(2e-5).rel_err(10 * eps).build(); + }) + .Run(); }) UNARY_TEST(Lgamma, { - auto error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(1e-5).rel_err(150 * eps).build(); - }; - if (IsGpu(platform_)) { - error_spec_gen = +[](NativeT x) { - if constexpr (std::is_same_v) { - // Very large error on the smallest subnormal input. - if (static_cast(std::abs(x)) == 4.9406564584124654e-324) { - return ErrorSpec::Builder().abs_err(0.05).build(); + LgammaOp(this) + .Error(+[](NativeT x) { + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder().abs_err(1e-5).rel_err(150 * eps).build(); + }) + .GpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + // Very large error on the smallest subnormal input. + if (static_cast(std::abs(x)) == 4.9406564584124654e-324) { + return ErrorSpec::Builder().abs_err(0.05).build(); + } else { + return ErrorSpec::Builder().distance_err(2).build(); + } } else { - return ErrorSpec::Builder().distance_err(2).build(); + NativeT eps = std::numeric_limits::epsilon(); + return ErrorSpec::Builder().abs_err(1e-5).rel_err(5000 * eps).build(); } - } else { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(1e-5).rel_err(5000 * eps).build(); - } - }; - } else if (IsTpu(platform_)) { - error_spec_gen = +[](NativeT x) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder().abs_err(5e-4).rel_err(5000 * eps).build(); - }; - } - Run(Lgamma, std::lgamma, error_spec_gen); + }) + .Run(); }) -UNARY_TEST(Round, { Run(Round, std::round); }) - +UNARY_TEST(Round, { RoundOp(this).Error(GetDefaultSpecGenerator()).Run(); }) UNARY_TEST(RoundNearestEven, { - auto error_spec_gen = +[](NativeT) { - return ErrorSpec::Builder().abs_err(0.0).rel_err(0.0).build(); - }; int curr_direction = fegetround(); fesetround(FE_TONEAREST); - Run(RoundNearestEven, std::nearbyint, error_spec_gen); + RoundNearestEvenOp(this).Run(); fesetround(curr_direction); }) UNARY_TEST(Reciprocal, { // Can be thought of as an absolute error of `<= // |std::numeric_limits::min()|`. - auto abs_err = +[](NativeT val) -> double { + auto* abs_err = +[](NativeT val) -> double { NativeT output = static_cast(1.0) / val; if (IsSubnormal(output)) { return std::numeric_limits::min(); } return 0.0; }; - auto abs_err_bf16 = +[](NativeT val) -> double { - NativeT output = static_cast(1.0) / val; - if (IsSubnormalOrMinNormal(output)) { - return std::numeric_limits::min(); - } - return 0.0; - }; - ErrorSpecGen error_spec_gen = [](NativeT) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - }; - if (IsCpu(platform_)) { - error_spec_gen = [&](NativeT val) { - return ErrorSpec::Builder() - .abs_err(abs_err(val)) - .strict_signed_zeros() - .build(); - }; - } - if (IsGpu(platform_)) { - error_spec_gen = [&](NativeT val) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder() - .abs_err(abs_err(val)) - .rel_err(eps) - .strict_signed_zeros() - .build(); - }; - } - if (IsTpu(platform_)) { - error_spec_gen = [&](NativeT val) { - if constexpr (std::is_same_v) { - return ErrorSpec::Builder() - .abs_err(abs_err_bf16(val)) - .strict_signed_zeros() - .build(); - } else if constexpr (std::is_same_v) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - } else if constexpr (std::is_same_v) { - NativeT eps = std::numeric_limits::epsilon(); + ReciprocalOp(this) + .CpuError([&](NativeT val) { return ErrorSpec::Builder() .abs_err(abs_err(val)) - .rel_err(eps) - .strict_signed_zeros() - .build(); - } else { - return ErrorSpec{}; - } - }; - } - if (IsPreV6Tpu(platform_)) { - error_spec_gen = [&](NativeT val) { - if constexpr (std::is_same_v) { - return ErrorSpec::Builder() - .abs_err(abs_err_bf16(val)) .strict_signed_zeros() .build(); - } else if constexpr (std::is_same_v) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - } else if constexpr (std::is_same_v) { + }) + .GpuError([&](NativeT val) { NativeT eps = std::numeric_limits::epsilon(); return ErrorSpec::Builder() .abs_err(abs_err(val)) - .rel_err(34 * eps) - .strict_signed_zeros() - .build(); - } else { - return ErrorSpec{}; - } - }; - } - if (IsPreV5Tpu(platform_)) { - error_spec_gen = [&](NativeT val) { - if constexpr (std::is_same_v) { - return ErrorSpec::Builder() - .abs_err(abs_err_bf16(val)) - .strict_signed_zeros() - .build(); - } else if constexpr (std::is_same_v) { - return ErrorSpec::Builder().strict_signed_zeros().build(); - } else if constexpr (std::is_same_v) { - NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec::Builder() - .abs_err(abs_err(val)) - .rel_err(136 * eps) + .rel_err(eps) .strict_signed_zeros() .build(); - } else { - return ErrorSpec{}; - } - }; - } - Run(Reciprocal, +[](NativeRefT x) { return 1 / x; }, error_spec_gen); + }) + .Run(); }) } // namespace diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_ops.inc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_ops.inc new file mode 100644 index 00000000000000..8efa13538f3632 --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_test_ops.inc @@ -0,0 +1,221 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +template +T EvaluatePolynomial(T x, const std::array& coeffs) { + // Evaluate the polynomial as accurately as we can using double precision and + // FMA. + double result = 0; + double x_d = static_cast(x); + for (T c : coeffs) { + result = std::fma(result, x_d, static_cast(c)); + } + return static_cast(result); +} + +// There's no std::erfinv, so we have to implement it ourselves. This follows +// Wichura 1998, https://www.jstor.org/stable/2347330 which, notably, is a +// different implementation from that in math.cc. +template +NativeRefT HostErfInv(NativeRefT x) { + const std::array poly_a = { + 8.8709406962545514830200e2, 1.1819493347062294404278e4, + 2.3782041382114385731252e4, 1.6235862515167575384252e4, + 4.8548868893843886794648e3, 6.9706266534389598238465e2, + 4.7072688112383978012285e1, 1.1975323115670912564578e0, + }; + const std::array poly_b = { + 5.2264952788528545610e3, 2.8729085735721942674e4, 3.9307895800092710610e4, + 2.1213794301586595867e4, 5.3941960214247511077e3, 6.8718700749205790830e2, + 4.2313330701600911252e1, 1.0000000000000000000e0, + }; + const std::array poly_c = { + 7.74545014278341407640e-4, 2.27238449892691845833e-2, + 2.41780725177450611770e-1, 1.27045825245236838258e0, + 3.64784832476320460504e0, 5.76949722146069140550e0, + 4.63033784615654529590e0, 1.42343711074968357734e0, + }; + const std::array poly_d = { + 1.4859850019840355905497876e-9, 7.7441459065157709165577218e-4, + 2.1494160384252876777097297e-2, 2.0945065210512749128288442e-1, + 9.7547832001787427186894837e-1, 2.3707661626024532365971225e0, + 2.9036514445419946173133295e0, 1.4142135623730950488016887e0, + }; + const std::array poly_e = { + 2.01033439929228813265e-7, 2.71155556874348757815e-5, + 1.24266094738807843860e-3, 2.65321895265761230930e-2, + 2.96560571828504891230e-1, 1.78482653991729133580e0, + 5.46378491116411436990e0, 6.65790464350110377720e0, + }; + const std::array poly_f = { + 2.891024605872965461538222e-15, 2.010321207683943062279931e-7, + 2.611088405080593625138020e-5, 1.112800997078859844711555e-3, + 2.103693768272068968719679e-2, 1.936480946950659106176712e-1, + 8.482908416595164588112026e-1, 1.414213562373095048801689e0, + }; + + if (std::abs(x) > 1 || std::isnan(x)) { + return std::numeric_limits::quiet_NaN(); + } + if (std::abs(x) == 1) { + return static_cast( + std::copysign(std::numeric_limits::infinity(), x)); + } + + double unsigned_result = [&] { + double y = std::abs(x); + if (y <= 0.85) { + double r = 0.180625 - 0.25 * y * y; + return (y * EvaluatePolynomial(r, poly_a)) / + EvaluatePolynomial(r, poly_b); + } + + double r = std::sqrt(std::log(2.0) - std::log1p(-y)); + if (r <= 5.0) { + r -= 1.6; + return EvaluatePolynomial(r, poly_c) / EvaluatePolynomial(r, poly_d); + } + + r -= 5; + return EvaluatePolynomial(r, poly_e) / EvaluatePolynomial(r, poly_f); + }(); + return static_cast(std::copysign(unsigned_result, x)); +} + +// Digamma implementation using a polynomial from Cephes. Notably this is a +// different implementation from the one in math.cc. +template +NativeRefT HostDigamma(NativeRefT x) { + // Euler-Mascheroni constant + const double gamma_constant = 0.57721566490153286061; + + const std::array poly = { + -4.16666666666666666667E-3, + 3.96825396825396825397E-3, + -8.33333333333333333333E-3, + 8.33333333333333333333E-2, + }; + + double reflection = 0; + if (x <= 0) { + double floor = std::floor(x); + if (x == floor) { + return std::numeric_limits::quiet_NaN(); + } + // Compute reflection term, pi * cot(pi * x). + reflection = x - floor; + if (reflection == 0.5) { + reflection = 0; + } else { + if (reflection > 0.5) { + reflection = x - (floor + 1.0f); + } + reflection = M_PI / std::tan(M_PI * reflection); + } + x = 1 - x; + } + + double result = 0; + if (x <= 10 && x == std::floor(x)) { + // Special case for integers <= 10. + for (size_t i = 1; i < static_cast(std::floor(x)); ++i) { + result += 1.0 / static_cast(i); + } + result -= gamma_constant; + } else { + double w = 0; + while (x < 10) { + w += 1.0 / x; + ++x; + } + if (x < 1e8) { + double z = 1.0 / (x * x); + result = z * EvaluatePolynomial(z, poly); + } + result = std::log(x) - 0.5 / x - result - w; + } + + // Compute the final, reflected value. + return static_cast(result - reflection); +} + +#define DEFINE_UNARY_TEST_OP(NAME, ENQUEUE, EVALUATE) \ + template \ + class NAME final : public UnaryTestOp { \ + public: \ + using Traits = UnaryTestOp::Traits; \ + using Test = UnaryTestOp::Test; \ + \ + explicit NAME(Test* test) : UnaryTestOp(test) {} \ + ~NAME() override {} \ + \ + Traits::EnqueueOp EnqueueOp() const override ENQUEUE; \ + \ + Traits::EvaluateOp EvaluateOp() const override EVALUATE; \ + }; \ + static_assert(true, "") + +DEFINE_UNARY_TEST_OP(LogOp, { return Log; }, { return std::log; }); +DEFINE_UNARY_TEST_OP(Log1pOp, { return Log1p; }, { return std::log1p; }); +DEFINE_UNARY_TEST_OP(ExpOp, { return Exp; }, { return std::exp; }); +DEFINE_UNARY_TEST_OP(Expm1Op, { return Expm1; }, { return std::expm1; }); +DEFINE_UNARY_TEST_OP( + LogisticOp, { return Logistic; }, + { + return +[](Traits::NativeRefT x) { return 1.0f / (1.0f + std::exp(-x)); }; + }); +DEFINE_UNARY_TEST_OP( + PowOneHalfOp, + { return [](XlaOp x) { return Pow(x, ScalarLike(x, 0.5)); }; }, + { return +[](Traits::NativeRefT x) { return std::pow(x, 0.5f); }; }); +DEFINE_UNARY_TEST_OP( + Exp2Op, + { return [](XlaOp x) { return Pow(ScalarLike(x, 2.0f), x); }; }, + { return +[](Traits::NativeRefT x) { return std::exp2(x); }; }); +DEFINE_UNARY_TEST_OP( + RsqrtOp, { return Rsqrt; }, + { return +[](Traits::NativeRefT x) { return 1 / std::sqrt(x); }; }); +DEFINE_UNARY_TEST_OP(SqrtOp, { return Sqrt; }, { return std::sqrt; }); +DEFINE_UNARY_TEST_OP(CbrtOp, { return Cbrt; }, { return std::cbrt; }); +DEFINE_UNARY_TEST_OP(AcoshOp, { return Acosh; }, { return std::acosh; }); +DEFINE_UNARY_TEST_OP(AsinhOp, { return Asinh; }, { return std::asinh; }); +DEFINE_UNARY_TEST_OP(AtanhOp, { return Atanh; }, { return std::atanh; }); +DEFINE_UNARY_TEST_OP(AcosOp, { return Acos; }, { return std::acos; }); +DEFINE_UNARY_TEST_OP(AsinOp, { return Asin; }, { return std::asin; }); +DEFINE_UNARY_TEST_OP(AtanOp, { return Atan; }, { return std::atan; }); +DEFINE_UNARY_TEST_OP(CoshOp, { return Cosh; }, { return std::cosh; }); +DEFINE_UNARY_TEST_OP(SinhOp, { return Sinh; }, { return std::sinh; }); +DEFINE_UNARY_TEST_OP(TanhOp, { return Tanh; }, { return std::tanh; }); +DEFINE_UNARY_TEST_OP(CosOp, { return Cos; }, { return std::cos; }); +DEFINE_UNARY_TEST_OP(SinOp, { return Sin; }, { return std::sin; }); +DEFINE_UNARY_TEST_OP(TanOp, { return Tan; }, { return std::tan; }); +DEFINE_UNARY_TEST_OP(ErfOp, { return Erf; }, { return std::erf; }); +DEFINE_UNARY_TEST_OP(ErfcOp, { return Erfc; }, { return std::erfc; }); +DEFINE_UNARY_TEST_OP( + ErfInvOp, { return ErfInv; }, + { return HostErfInv; }); +DEFINE_UNARY_TEST_OP( + DigammaOp, { return Digamma; }, + { return HostDigamma; }); +DEFINE_UNARY_TEST_OP(LgammaOp, { return Lgamma; }, { return std::lgamma; }); +DEFINE_UNARY_TEST_OP(RoundOp, { return Round; }, { return std::round; }); +DEFINE_UNARY_TEST_OP( + RoundNearestEvenOp, { return RoundNearestEven; }, + { return std::nearbyint; }); +DEFINE_UNARY_TEST_OP( + ReciprocalOp, { return Reciprocal; }, + { return +[](Traits::NativeRefT x) { return 1 / x; }; }); + +#undef DEFINE_UNARY_TEST_OP diff --git a/third_party/xla/xla/tests/exhaustive/platform.cc b/third_party/xla/xla/tests/exhaustive/platform.cc new file mode 100644 index 00000000000000..704b1a5b8df9be --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/platform.cc @@ -0,0 +1,103 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/tests/exhaustive/platform.h" + +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/platform.h" + +namespace xla { +namespace exhaustive_op_test { + +Platform::Value GetPlatformValue(const stream_executor::Platform& platform) { + if (platform.Name() == "Host") { +// We process these copts in a library instead of the final exhaustive_xla_test +// target because we assume the final target will use the same target CPU arch +// as this target. +#ifdef __x86_64__ + return Platform::CpuValue::X86_64; +#endif +#ifdef __aarch64__ + return Platform::CpuValue::AARCH64; +#endif + } else if (platform.Name() == "CUDA") { + auto device_descriptor_status = platform.DescriptionForDevice(0); + CHECK_OK(device_descriptor_status); + std::unique_ptr device_descriptor = + std::move(*device_descriptor_status); + + auto cuda_compute_compatibility = + device_descriptor->cuda_compute_capability(); + // If not available, CudaComputeCompatibility will have major version 0. + if (cuda_compute_compatibility.IsAtLeast(1, 0)) { + return cuda_compute_compatibility; + } + } else if (platform.Name() == "ROCM") { + auto device_descriptor_status = platform.DescriptionForDevice(0); + CHECK_OK(device_descriptor_status); + std::unique_ptr device_descriptor = + std::move(*device_descriptor_status); + + auto rocm_compute_compatibility = + device_descriptor->rocm_compute_capability(); + // If not available, RocmComputeCompatibility will be an invalid platform + // value. + if (rocm_compute_compatibility.gfx_version() == "gfx000") { + return rocm_compute_compatibility; + } + } + LOG(FATAL) << "Unhandled stream_executor::Platform: " << platform.Name() + << ". Please add support to " __FILE__ "."; +} + +bool Platform::IsNvidiaP100() const { + return std::holds_alternative( + value_) && + !std::get(value_).IsAtLeast( + stream_executor::CudaComputeCapability::Volta()); +} + +bool Platform::IsNvidiaV100() const { + return std::holds_alternative( + value_) && + std::get(value_) == + stream_executor::CudaComputeCapability::Volta(); +} + +bool Platform::IsNvidiaA100() const { + return std::holds_alternative( + value_) && + std::get(value_) == + stream_executor::CudaComputeCapability::Ampere(); +} + +bool Platform::IsNvidiaH100() const { + return std::holds_alternative( + value_) && + std::get(value_) == + stream_executor::CudaComputeCapability::Hopper(); +} + +Platform::Platform(const stream_executor::Platform& platform) + : value_(GetPlatformValue(platform)) {} + +} // namespace exhaustive_op_test +} // namespace xla diff --git a/third_party/xla/xla/tests/exhaustive/platform.h b/third_party/xla/xla/tests/exhaustive/platform.h new file mode 100644 index 00000000000000..7728033ec5ea93 --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/platform.h @@ -0,0 +1,77 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TESTS_EXHAUSTIVE_PLATFORM_H_ +#define XLA_TESTS_EXHAUSTIVE_PLATFORM_H_ + +#include + +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/platform.h" + +namespace xla { +namespace exhaustive_op_test { + +// Represents an enum class of all possible openXLA execution platforms along +// with helper functions to categorically handle them. +class Platform { + public: + enum class CpuValue { + AARCH64, + X86_64, + }; + + using Value = std::variant; + + explicit Platform(const stream_executor::Platform& platform); + + bool IsCpu() const { return std::holds_alternative(value_); } + + bool IsGpu() const { + return std::holds_alternative( + value_) || + std::holds_alternative( + value_); + } + + bool IsNvidiaGpu() const { + return std::holds_alternative( + value_); + } + + bool IsNvidiaP100() const; + + bool IsNvidiaV100() const; + + bool IsNvidiaA100() const; + + bool IsNvidiaH100() const; + + bool IsAmdGpu() const { + return std::holds_alternative( + value_); + } + + const Value& value() const { return value_; } + + private: + const Value value_; +}; + +} // namespace exhaustive_op_test +} // namespace xla + +#endif // XLA_TESTS_EXHAUSTIVE_PLATFORM_H_ diff --git a/third_party/xla/xla/tests/exhaustive/test_op.h b/third_party/xla/xla/tests/exhaustive/test_op.h new file mode 100644 index 00000000000000..35ad4b51f69ad6 --- /dev/null +++ b/third_party/xla/xla/tests/exhaustive/test_op.h @@ -0,0 +1,247 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TESTS_EXHAUSTIVE_TEST_OP_H_ +#define XLA_TESTS_EXHAUSTIVE_TEST_OP_H_ + +#include +#include + +#include "xla/tests/exhaustive/exhaustive_op_test.h" +#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" +#include "xla/tests/exhaustive/platform.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace exhaustive_op_test { + +// Declares a single exhaustive test operation. +// +// This class is intended to be subclassed by an actual operation implementation +// that configures EnqueueOp() and EvaluateOp() as necessary. +// +// The exhaustive test can be run using the Run() function defined here. +// +// Pure virtual functions: +// - EnqueueOp +// - EvaluateOp +template +class TestOp { + public: + using Traits = ExhaustiveOpTestTraits; + using Test = std::conditional_t< + N == 1, ExhaustiveUnaryTest, + std::conditional_t, + std::enable_if_t>>; + + explicit TestOp(Test* test) : test_(test) {} + + virtual ~TestOp() = default; + + virtual Traits::EnqueueOp EnqueueOp() const = 0; + virtual Traits::EvaluateOp EvaluateOp() const = 0; + + // Establish a verification check that each EnqueueOp() value is within range. + TestOp& OutputRangeCheck(Traits::OutputRangeCheck output_range_check) & { + output_range_check_ = output_range_check; + return *this; + } + TestOp&& OutputRangeCheck(Traits::OutputRangeCheck output_range_check) && { + output_range_check_ = output_range_check; + return std::move(*this); + } + + // The following methods set ErrorSpecGen for associated platforms. There is a + // precedence hierarchy to allow for easily setting fallbacks and overriding + // for certain platforms. + // + // CPU Precedence: + // CPU Make (x86, ARM, etc) Error -> CPU Error -> Error + // + // GPU Precedence: + // GPU Model (P100, V100, etc) Error -> GPU Make (Nvidia) Error -> GPU Error + // -> Error + + TestOp& Error(Traits::ErrorSpecGen error_spec_gen) & { + error_spec_gen_ = error_spec_gen; + return *this; + } + TestOp&& Error(Traits::ErrorSpecGen error_spec_gen) && { + error_spec_gen_ = std::move(error_spec_gen); + return std::move(*this); + } + + TestOp& CpuError(Traits::ErrorSpecGen error_spec_gen) & { + cpu_error_spec_gen_ = error_spec_gen; + return *this; + } + TestOp&& CpuError(Traits::ErrorSpecGen error_spec_gen) && { + cpu_error_spec_gen_ = std::move(error_spec_gen); + return std::move(*this); + } + + TestOp& CpuX86Error(Traits::ErrorSpecGen error_spec_gen) & { + cpu_x86_error_spec_gen_ = error_spec_gen; + return *this; + } + TestOp&& CpuX86Error(Traits::ErrorSpecGen error_spec_gen) && { + cpu_x86_error_spec_gen_ = std::move(error_spec_gen); + return std::move(*this); + } + + TestOp& CpuArmError(Traits::ErrorSpecGen error_spec_gen) & { + cpu_arm_error_spec_gen_ = error_spec_gen; + return *this; + } + TestOp&& CpuArmError(Traits::ErrorSpecGen error_spec_gen) && { + cpu_arm_error_spec_gen_ = std::move(error_spec_gen); + return std::move(*this); + } + + TestOp& GpuError(Traits::ErrorSpecGen error_spec_gen) & { + gpu_error_spec_gen_ = error_spec_gen; + return *this; + } + TestOp&& GpuError(Traits::ErrorSpecGen error_spec_gen) && { + gpu_error_spec_gen_ = std::move(error_spec_gen); + return std::move(*this); + } + + TestOp& GpuNvidiaError(Traits::ErrorSpecGen error_spec_gen) & { + gpu_nv_error_spec_gen_ = error_spec_gen; + return *this; + } + TestOp&& GpuNvidiaError(Traits::ErrorSpecGen error_spec_gen) && { + gpu_nv_error_spec_gen_ = std::move(error_spec_gen); + return std::move(*this); + } + + TestOp& GpuP100Error(Traits::ErrorSpecGen error_spec_gen) & { + gpu_nv_p100_error_spec_gen_ = error_spec_gen; + return *this; + } + TestOp&& GpuP100Error(Traits::ErrorSpecGen error_spec_gen) && { + gpu_nv_p100_error_spec_gen_ = std::move(error_spec_gen); + return std::move(*this); + } + + TestOp& GpuV100Error(Traits::ErrorSpecGen error_spec_gen) & { + gpu_nv_v100_error_spec_gen_ = error_spec_gen; + return *this; + } + TestOp&& GpuV100Error(Traits::ErrorSpecGen error_spec_gen) && { + gpu_nv_v100_error_spec_gen_ = std::move(error_spec_gen); + return std::move(*this); + } + + TestOp& GpuA100Error(Traits::ErrorSpecGen error_spec_gen) & { + gpu_nv_a100_error_spec_gen_ = error_spec_gen; + return *this; + } + TestOp&& GpuA100Error(Traits::ErrorSpecGen error_spec_gen) && { + gpu_nv_a100_error_spec_gen_ = std::move(error_spec_gen); + return std::move(*this); + } + + TestOp& GpuH100Error(Traits::ErrorSpecGen error_spec_gen) & { + gpu_nv_h100_error_spec_gen_ = error_spec_gen; + return *this; + } + TestOp&& GpuH100Error(Traits::ErrorSpecGen error_spec_gen) && { + gpu_nv_h100_error_spec_gen_ = std::move(error_spec_gen); + return std::move(*this); + } + + // Execute the TestCase as configured. + // + // Requires invoking on a TestCase&& to ensure the TestCase is not used + // afterwards. + void Run() && { + typename Traits::ErrorSpecGen error_spec_gen; + if (test_->Platform().IsCpu()) { + switch (std::get(test_->Platform().value())) { + case Platform::CpuValue::X86_64: { + error_spec_gen = PickFirstErrorSpecGenPresent( + {cpu_x86_error_spec_gen_, cpu_error_spec_gen_, error_spec_gen_}); + break; + } + case Platform::CpuValue::AARCH64: { + error_spec_gen = PickFirstErrorSpecGenPresent( + {cpu_arm_error_spec_gen_, cpu_error_spec_gen_, error_spec_gen_}); + break; + } + default: { + error_spec_gen = PickFirstErrorSpecGenPresent( + {cpu_error_spec_gen_, error_spec_gen_}); + break; + } + } + } else if (test_->Platform().IsGpu()) { + if (test_->Platform().IsNvidiaGpu()) { + if (test_->Platform().IsNvidiaP100()) { + error_spec_gen = PickFirstErrorSpecGenPresent( + {gpu_nv_p100_error_spec_gen_, gpu_nv_error_spec_gen_, + gpu_error_spec_gen_, error_spec_gen_}); + } else if (test_->Platform().IsNvidiaV100()) { + error_spec_gen = PickFirstErrorSpecGenPresent( + {gpu_nv_v100_error_spec_gen_, gpu_nv_error_spec_gen_, + gpu_error_spec_gen_, error_spec_gen_}); + } else if (test_->Platform().IsNvidiaA100()) { + error_spec_gen = PickFirstErrorSpecGenPresent( + {gpu_nv_a100_error_spec_gen_, gpu_nv_error_spec_gen_, + gpu_error_spec_gen_, error_spec_gen_}); + } else if (test_->Platform().IsNvidiaH100()) { + error_spec_gen = PickFirstErrorSpecGenPresent( + {gpu_nv_h100_error_spec_gen_, gpu_nv_error_spec_gen_, + gpu_error_spec_gen_, error_spec_gen_}); + } else { + error_spec_gen = PickFirstErrorSpecGenPresent( + {gpu_nv_error_spec_gen_, gpu_error_spec_gen_, error_spec_gen_}); + } + } else { + error_spec_gen = PickFirstErrorSpecGenPresent( + {gpu_error_spec_gen_, error_spec_gen_}); + } + } else { + error_spec_gen = PickFirstErrorSpecGenPresent({error_spec_gen_}); + } + test_->Run(EnqueueOp(), EvaluateOp(), error_spec_gen, output_range_check_); + } + + private: + Test* test_ = nullptr; + Traits::OutputRangeCheck output_range_check_ = nullptr; + Traits::ErrorSpecGen error_spec_gen_ = nullptr; + Traits::ErrorSpecGen cpu_error_spec_gen_ = nullptr; + Traits::ErrorSpecGen cpu_x86_error_spec_gen_ = nullptr; + Traits::ErrorSpecGen cpu_arm_error_spec_gen_ = nullptr; + Traits::ErrorSpecGen gpu_error_spec_gen_ = nullptr; + Traits::ErrorSpecGen gpu_nv_error_spec_gen_ = nullptr; + Traits::ErrorSpecGen gpu_nv_p100_error_spec_gen_ = nullptr; + Traits::ErrorSpecGen gpu_nv_v100_error_spec_gen_ = nullptr; + Traits::ErrorSpecGen gpu_nv_a100_error_spec_gen_ = nullptr; + Traits::ErrorSpecGen gpu_nv_h100_error_spec_gen_ = nullptr; +}; + +template +using UnaryTestOp = TestOp; + +template +using BinaryTestOp = TestOp; + +} // namespace exhaustive_op_test +} // namespace xla + +#endif // XLA_TESTS_EXHAUSTIVE_TEST_OP_H_ diff --git a/third_party/xla/xla/tests/filecheck.h b/third_party/xla/xla/tests/filecheck.h index 5ea4134d691a72..e96152510c455f 100644 --- a/third_party/xla/xla/tests/filecheck.h +++ b/third_party/xla/xla/tests/filecheck.h @@ -16,26 +16,7 @@ limitations under the License. #ifndef XLA_TESTS_FILECHECK_H_ #define XLA_TESTS_FILECHECK_H_ -#include - -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/types.h" - -namespace xla { - -// Runs FileCheck with the given pattern over given input string. Provided that -// FileCheck can execute, returns true if and only if FileCheck succeeded in -// matching the input. -absl::StatusOr RunFileCheck(const std::string& input, - absl::string_view pattern); - -// Runs FileCheck with the given pattern file over given input string. Provided -// that FileCheck can execute, returns true if and only if FileCheck succeeded -// in matching the input. -absl::StatusOr RunFileCheckWithPatternFile( - const std::string& input, const std::string& pattern_file); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/testlib/filecheck.h" #endif // XLA_TESTS_FILECHECK_H_ diff --git a/third_party/xla/xla/tests/float8_test.cc b/third_party/xla/xla/tests/float8_test.cc index ab5debea32355b..648c718d7cd958 100644 --- a/third_party/xla/xla/tests/float8_test.cc +++ b/third_party/xla/xla/tests/float8_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" @@ -27,11 +27,12 @@ limitations under the License. namespace xla { namespace { -// Test FP8 floating-point types (F8E5M2, F8E4M3FN) +// Test FP8 floating-point types template class Float8Test : public ClientLibraryTestBase {}; -using DataTypes = ::testing::Types; +using DataTypes = ::testing::Types; TYPED_TEST_SUITE(Float8Test, DataTypes); XLA_TYPED_TEST(Float8Test, ScalarOperation) { diff --git a/third_party/xla/xla/tests/floor_ceil_test.cc b/third_party/xla/xla/tests/floor_ceil_test.cc index 3bcab08518ed72..c164645e954e7a 100644 --- a/third_party/xla/xla/tests/floor_ceil_test.cc +++ b/third_party/xla/xla/tests/floor_ceil_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/types/span.h" -#include "xla/client/xla_builder.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/tests/fmax_fmin_test.cc b/third_party/xla/xla/tests/fmax_fmin_test.cc index d7ea4d28d45471..b386de39ad20b3 100644 --- a/third_party/xla/xla/tests/fmax_fmin_test.cc +++ b/third_party/xla/xla/tests/fmax_fmin_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include -#include "xla/client/xla_builder.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/service/service.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/fuzz/BUILD b/third_party/xla/xla/tests/fuzz/BUILD index 1d07a5a088d2e4..ed6396a1c9208c 100644 --- a/third_party/xla/xla/tests/fuzz/BUILD +++ b/third_party/xla/xla/tests/fuzz/BUILD @@ -16,7 +16,7 @@ cc_library( [hlo_test( name = hlo + "_test", hlo = hlo, - tags = (["no_rocm"] if hlo == "rand_000079.hlo" else []), # No int8 + tags = (["cuda-only"] if hlo == "rand_000079.hlo" else []), # No int8 ) for hlo in glob( include = ["rand_*.hlo"], exclude = [ diff --git a/third_party/xla/xla/tests/gather_operation_test.cc b/third_party/xla/xla/tests/gather_operation_test.cc index 94d02abbea3b04..7bf57a8f05138f 100644 --- a/third_party/xla/xla/tests/gather_operation_test.cc +++ b/third_party/xla/xla/tests/gather_operation_test.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include "xla/array.h" -#include "xla/client/xla_builder.h" #include "xla/execution_options_util.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "xla/service/service.h" #include "xla/status_macros.h" diff --git a/third_party/xla/xla/tests/grouped_convolution_test.cc b/third_party/xla/xla/tests/grouped_convolution_test.cc index 24f41c3c96cfeb..7a86547f171aae 100644 --- a/third_party/xla/xla/tests/grouped_convolution_test.cc +++ b/third_party/xla/xla/tests/grouped_convolution_test.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include "absl/algorithm/container.h" -#include "xla/client/xla_computation.h" #include "xla/execution_options_util.h" -#include "xla/service/despecializer.h" -#include "xla/service/float_normalization.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/transforms/despecializer.h" +#include "xla/hlo/transforms/simplifiers/float_normalization.h" #include "xla/status_macros.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" diff --git a/third_party/xla/xla/tests/half_test.cc b/third_party/xla/xla/tests/half_test.cc index a86e47ce53f782..04f23a6c1faff2 100644 --- a/third_party/xla/xla/tests/half_test.cc +++ b/third_party/xla/xla/tests/half_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "absl/status/statusor.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/test.h" #include "xla/test_helpers.h" diff --git a/third_party/xla/xla/tests/hlo_metadata_test.cc b/third_party/xla/xla/tests/hlo_metadata_test.cc index ed5260426044eb..30cb1fa0e3b262 100644 --- a/third_party/xla/xla/tests/hlo_metadata_test.cc +++ b/third_party/xla/xla/tests/hlo_metadata_test.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/service/local_service.h" #include "xla/test_helpers.h" #include "xla/tests/local_client_test_base.h" diff --git a/third_party/xla/xla/tests/hlo_test_base.cc b/third_party/xla/xla/tests/hlo_test_base.cc index dbbcd5f866e924..18c53c4fa0c727 100644 --- a/third_party/xla/xla/tests/hlo_test_base.cc +++ b/third_party/xla/xla/tests/hlo_test_base.cc @@ -15,12 +15,14 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" +#include #include #include #include #include -#include +#include #include +#include #include #include @@ -28,63 +30,79 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/debug_options_flags.h" +#include "xla/error_spec.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module_group.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" #include "xla/hlo/utils/hlo_query.h" -#include "xla/layout_util.h" +#include "xla/literal.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/service/backend.h" +#include "xla/service/computation_placer.h" +#include "xla/service/executable.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_module_util.h" -#include "xla/service/hlo_parser.h" +#include "xla/service/hlo_runner.h" #include "xla/service/hlo_runner_interface.h" #include "xla/service/hlo_runner_pjrt.h" +#include "xla/service/hlo_verifier.h" #include "xla/service/platform_util.h" #include "xla/shape.h" -#include "xla/shape_util.h" #include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/filecheck.h" #include "xla/tests/literal_test_util.h" +#include "xla/tests/new_hlo_test_base.h" #include "xla/tests/pjrt_client_registry.h" #include "xla/tests/test_utils.h" #include "xla/tests/verified_hlo_module.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "xla/types.h" +#include "xla/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { - namespace { -using absl::string_view; -using std::optional; +constexpr absl::string_view kInterpreter = "interpreter"; -constexpr char kInterpreter[] = "interpreter"; +// Returns either an HloRunner or HloRunnerPjRt implementation depending on +// whether there exists a registered PjRtClientFactory. +absl::StatusOr> GetHloRunnerForTest( + se::Platform* test_platform) { + if (ShouldUsePjRt()) { + PjRtClientTestFactoryRegistry& pjrt_registry = + GetGlobalPjRtClientTestFactory(); + TF_ASSIGN_OR_RETURN(std::unique_ptr client, + pjrt_registry.Get()()); + PjRtClientTestFactoryRegistry::DeviceShapeRepresentationFn + device_shape_representation_fn = + pjrt_registry.GetDeviceShapeRepresentationFn(client.get()); + PjRtClientTestFactoryRegistry::DeviceShapeSizeFn device_shape_size_fn = + pjrt_registry.GetDeviceShapeSizeFn(client.get()); -bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs) { - if (lhs.parameters_size() != rhs.parameters_size()) { - return false; + return std::make_unique(std::move(client), + device_shape_representation_fn, + device_shape_size_fn); } - for (int i = 0; i < lhs.parameters_size(); i++) { - if (!Shape::Equal().IgnoreElementSizeInLayout()(lhs.parameters(i), - rhs.parameters(i))) { - return false; - } - } - return Shape::Equal().IgnoreElementSizeInLayout()(lhs.result(), rhs.result()); + + return std::make_unique(test_platform); } -ProgramShape GetProgramShapeWithLayout(const HloModule& module) { - ProgramShape program_shape; - const auto* entry = module.entry_computation(); - for (const auto* param : entry->parameter_instructions()) { - *program_shape.add_parameters() = param->shape(); - *program_shape.add_parameter_names() = param->name(); - } - *program_shape.mutable_result() = entry->root_instruction()->shape(); - return program_shape; +absl::StatusOr> GetHloRunnerForReference( + se::Platform* reference_platform) { + return std::make_unique(reference_platform); } } // namespace @@ -102,19 +120,12 @@ HloTestBase::HloTestBase(se::Platform* test_platform, bool verifier_layout_sensitive, bool allow_mixed_precision_in_hlo_verifier, HloPredicate instruction_can_change_layout_func) - : test_runner_(test_platform), - reference_runner_(reference_platform), - verifier_layout_sensitive_(verifier_layout_sensitive), - allow_mixed_precision_in_hlo_verifier_( - allow_mixed_precision_in_hlo_verifier), - instruction_can_change_layout_func_(instruction_can_change_layout_func), - test_platform_(test_platform) { - hlo_verifier_ = std::make_unique( - /*layout_sensitive=*/verifier_layout_sensitive, - /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier, - instruction_can_change_layout_func); - runner_ = GetHloRunner().value(); -} + : NewHloTestBase( + /*test_runner=*/GetHloRunnerForTest(test_platform).value(), + /*reference_runner=*/ + GetHloRunnerForReference(reference_platform).value(), + verifier_layout_sensitive, allow_mixed_precision_in_hlo_verifier), + test_platform_(test_platform) {} /*static*/ se::Platform* HloTestBase::GetReferencePlatform() { auto result = PlatformUtil::GetPlatform(kInterpreter); @@ -128,895 +139,13 @@ HloTestBase::HloTestBase(se::Platform* test_platform, return result.value(); } -std::unique_ptr HloTestBase::CreateNewUnverifiedModule( - const std::string& name) { - return std::make_unique(name, GetModuleConfigForTest()); -} - -std::unique_ptr HloTestBase::CreateNewVerifiedModule( - const std::string& name, int64_t replica_count) { - return std::make_unique( - name, GetModuleConfigForTest(replica_count), verifier_layout_sensitive_, - allow_mixed_precision_in_hlo_verifier_, - backend().compiler()->ShapeSizeBytesFunction(), - instruction_can_change_layout_func_); -} - -absl::StatusOr> -HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text, - int64_t replica_count, - int64_t num_partitions) { - return ParseAndReturnVerifiedModule( - hlo_text, GetModuleConfigForTest(replica_count, num_partitions)); -} - -absl::Status HloTestBase::UpdateEntryComputationLayoutToMatchProgramLayout( - HloModule* module) { - for (auto* const computation : module->computations({})) { - if (computation->IsEntryComputation()) { - for (int64_t i = 0; i < computation->num_parameters(); ++i) { - const Shape& param_shape = - computation->parameter_instruction(i)->shape(); - TF_RETURN_IF_ERROR(computation->parent() - ->mutable_entry_computation_layout() - ->mutable_parameter_layout(i) - ->CopyLayoutFromShape(param_shape)); - } - - TF_RETURN_IF_ERROR( - computation->parent() - ->mutable_entry_computation_layout() - ->mutable_result_layout() - ->CopyLayoutFromShape(computation->root_instruction()->shape())); - } - } - return absl::OkStatus(); -} - -absl::StatusOr> -HloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text, - const HloModuleConfig& config) { - auto module = std::make_unique( - TestName(), config, verifier_layout_sensitive_, - allow_mixed_precision_in_hlo_verifier_, - backend().compiler()->ShapeSizeBytesFunction(), - instruction_can_change_layout_func_); - TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text)); - UpdateEntryComputationLayout(module.get()); - return std::move(module); -} - -HloComputation* HloTestBase::AddEntryComputationAndUpdateEntryComputationLayout( - HloModule* module, std::unique_ptr computation) { - auto comp = module->AddEntryComputation(std::move(computation)); - UpdateEntryComputationLayout(module); - return comp; -} - -void HloTestBase::UpdateEntryComputationLayout(HloModule* module) { - xla::UpdateEntryComputationLayout( - module, test_runner_.device_shape_representation_fn()); -} - -/* static */ -absl::StatusOr HloTestBase::RunHloPass(HloPassInterface* hlo_pass, - HloModule* module) { - const std::string module_str_before_run = - module->ToProto().ShortDebugString(); - const auto status_or = hlo_pass->Run(module); - if (status_or.status().ok()) { - const std::string module_str_after_run = - module->ToProto().ShortDebugString(); - const bool passChangedHlo = status_or.value(); - if (passChangedHlo) { - // Check that the proto actually changed. - EXPECT_NE(module_str_after_run, module_str_before_run); - } else { - // Check that the proto remains same. - EXPECT_EQ(module_str_after_run, module_str_before_run); - } - } - return status_or; -} - -/* static */ -absl::StatusOr HloTestBase::RunHloPass(HloPassInterface&& hlo_pass, - HloModuleGroup* module_group) { - const std::string module_group_str_before_run = - module_group->ToProto().ShortDebugString(); - const auto status_or = hlo_pass.RunOnModuleGroup(module_group); - if (status_or.status().ok()) { - const std::string module_group_str_after_run = - module_group->ToProto().ShortDebugString(); - const bool passChangedHlo = status_or.value(); - if (passChangedHlo) { - // Check that the proto actually changed. - EXPECT_NE(module_group_str_after_run, module_group_str_before_run); - } else { - // Check that the proto remains same. - EXPECT_EQ(module_group_str_after_run, module_group_str_before_run); - } - } - return status_or; -} - -/* static */ -PrecisionConfig HloTestBase::DefaultPrecisionConfig(int operands) { - PrecisionConfig precision_config; - precision_config.mutable_operand_precision()->Resize( - operands, PrecisionConfig::DEFAULT); - return precision_config; -} - -void HloTestBase::SetAotFastMathDebugOptions(DebugOptions* options) { - options->set_xla_cpu_enable_fast_math(true); - options->set_xla_gpu_enable_fast_min_max(true); - options->set_xla_cpu_enable_fast_min_max(true); - options->set_xla_cpu_fast_math_honor_nans(false); - options->set_xla_cpu_fast_math_honor_infs(false); - options->set_xla_cpu_fast_math_honor_functions(false); - options->set_xla_cpu_fast_math_honor_division(false); -} - -DebugOptions HloTestBase::GetDebugOptionsForTest() { - auto debug_options = GetDebugOptionsFromFlags(); - // TODO(b/38354253): Change tests to use Parameters instead of Constants. - debug_options.add_xla_disable_hlo_passes("constant_folding"); - debug_options.set_xla_hlo_evaluator_use_fast_path(true); - return debug_options; -} - -void HloTestBase::RunAndFilecheckHloRewrite( - absl::string_view hlo, HloPassInterface&& hlo_pass, - std::optional expected, - std::function after_pass_checks, - const HloModuleConfig* config) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - config ? ParseAndReturnVerifiedModule(hlo, *config) - : ParseAndReturnVerifiedModule(hlo)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&hlo_pass, module.get())); - EXPECT_EQ(changed, expected.has_value()) << module->ToString(); - if (changed) { - TF_ASSERT_OK_AND_ASSIGN( - bool filecheck_matches, - RunFileCheck( - module->ToString(HloPrintOptions{}.set_print_operand_shape(false)), - *expected)); - EXPECT_TRUE(filecheck_matches); - if (after_pass_checks) { - after_pass_checks(module.get()); - } - } -} - -void HloTestBase::RunAndFilecheckHloModuleGroupRewrite( - absl::Span hlo_module_strs, - HloPassInterface&& hlo_pass, - std::optional> expected) { - std::vector> modules; - for (absl::string_view hlo : hlo_module_strs) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo)); - modules.push_back(std::move(module)); - } - HloModuleGroup module_group("test_input_module_group", std::move(modules)); - - TF_ASSERT_OK_AND_ASSIGN(bool changed, - RunHloPass(std::move(hlo_pass), &module_group)); - EXPECT_EQ(changed, expected.has_value()) << module_group.ToString(); - - if (!changed) { - return; - } - - EXPECT_THAT(module_group.modules(), - ::testing::SizeIs(expected.value().size())); - int index = 0; - for (auto expected_str : expected.value()) { - TF_ASSERT_OK_AND_ASSIGN( - bool filecheck_matches, - RunFileCheck(module_group.module(index).ToString( - HloPrintOptions{}.set_print_operand_shape(false)), - expected_str)); - EXPECT_TRUE(filecheck_matches); - index++; - } -} - -absl::StatusOr HloTestBase::Execute( - std::unique_ptr module, absl::Span arguments, - bool run_hlo_passes) { - return runner_->Execute(std::move(module), arguments, run_hlo_passes); -} - -Literal HloTestBase::ExecuteNoHloPasses(std::unique_ptr module, - absl::Span arguments) { - return Execute(std::move(module), arguments, - /*run_hlo_passes=*/false) - .value(); -} - absl::StatusOr> HloTestBase::GetHloRunner() { - if (runner_ != nullptr) { - return std::move(runner_); - } absl::StatusOr> status_or_runner = GetHloRunnerForTest(test_platform_); - // Test for successful creation of PjRt based Hlo Runner. - EXPECT_TRUE(status_or_runner.ok()); - - return std::move(status_or_runner.value()); -} - -Literal HloTestBase::ExecuteAndTransfer(std::unique_ptr module, - absl::Span arguments) { - return runner_->Execute(std::move(module), arguments, true, nullptr).value(); -} - -std::vector HloTestBase::CompareInputs(const HloModule& module_0, - const HloModule& module_1) { - const auto params_0 = module_0.entry_computation()->parameter_instructions(); - const auto params_1 = module_1.entry_computation()->parameter_instructions(); - std::vector mismatches; - int64_t min = std::min(params_0.size(), params_1.size()); - int64_t max = std::max(params_0.size(), params_1.size()); - for (int64_t i = 0; i < min; ++i) { - const HloModuleConfig& module_config_0 = module_0.config(); - const Shape& param_shape_0 = - (module_config_0.has_entry_computation_layout() && - module_config_0.entry_computation_layout() - .parameter_layout(i) - .shape() - .is_static()) - ? module_config_0.entry_computation_layout() - .parameter_layout(i) - .shape() - : params_0[i]->shape(); - - const HloModuleConfig& module_config_1 = module_1.config(); - const Shape& param_shape_1 = - (module_config_1.has_entry_computation_layout() && - module_config_1.entry_computation_layout() - .parameter_layout(i) - .shape() - .is_static()) - ? module_config_1.entry_computation_layout() - .parameter_layout(i) - .shape() - : params_1[i]->shape(); - - if (!Shape::Equal().IgnoreTilesInLayout()(param_shape_0, param_shape_1)) { - mismatches.push_back(i); - } - } - for (int64_t i = min; i < max; i++) { - mismatches.push_back(i); - } - return mismatches; -} - -absl::StatusOr> HloTestBase::ExecuteReplicated( - std::unique_ptr module, absl::Span arguments, - int64_t num_replicas, bool use_threads, bool run_hlo_passes) { - HloRunner::ReplicatedExecuteOptions options; - options.num_replicas = num_replicas; - options.run_hlo_passes = run_hlo_passes; - options.use_threads = use_threads; - for (auto argument : arguments) { - options.arguments.push_back(argument); - } - - return runner_->ExecuteReplicated(std::move(module), options); -} - -absl::StatusOr> HloTestBase::ExecuteReplicated( - std::unique_ptr module, absl::Span arguments, - int64_t num_replicas, DeviceAssignment* device_assignment, - bool run_hlo_passes, bool use_threads) { - HloRunner::ReplicatedExecuteOptions options; - options.num_replicas = num_replicas; - options.run_hlo_passes = run_hlo_passes; - options.use_threads = use_threads; - for (auto argument : arguments) { - options.arguments.push_back(argument); - } - return runner_->ExecuteReplicated(std::move(module), options, - device_assignment); -} - -absl::StatusOr> HloTestBase::ExecuteReplicated( - std::function executable_provider, - std::function argument_count_provider, - std::function argument_provider, - int64_t num_replicas, bool run_hlo_passes, - DeviceAssignment* device_assignment) { - HloRunner::ReplicatedExecuteOptions options; - options.num_replicas = num_replicas; - options.run_hlo_passes = run_hlo_passes; - options.use_threads = true; - return runner_->ExecuteReplicated(executable_provider, - argument_count_provider, argument_provider, - options, device_assignment); -} - -absl::StatusOr> HloTestBase::ExecuteReplicated( - std::unique_ptr module, - std::vector> arguments, int64_t num_replicas, - bool run_hlo_passes) { - CHECK(num_replicas > 0 && "expect at least one replica"); - CHECK(num_replicas == arguments.size() && - "expect arguments for each replica"); - int64_t argument_count = arguments.front().size(); - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - runner_->CreateExecutable(std::unique_ptr(std::move(module)), - run_hlo_passes)); - return ExecuteReplicated( - /*executable_provider=*/[&](int64_t) { return executable.get(); }, - /*argument_count_provider=*/[&](int64_t) { return argument_count; }, - /*argument_provider=*/ - [&](int64_t replica_idx, int64_t argument_idx) -> const Literal* { - return arguments[replica_idx][argument_idx]; - }, - num_replicas, /*run_hlo_passes=*/run_hlo_passes, - /*device_assignment=*/nullptr); -} - -absl::StatusOr> HloTestBase::MakeReferenceModule( - const HloModule& test_module, - const std::function& reference_preprocessor) { - std::unique_ptr reference_module = test_module.Clone(); - const auto& program_shape = GetProgramShapeWithLayout(test_module); - - if (reference_preprocessor != nullptr) { - reference_preprocessor(reference_module.get()); - if (!ProgramShapesEqual(program_shape, - GetProgramShapeWithLayout(*reference_module))) { - return InvalidArgument( - "reference preprocessor must not modify the program shape"); - } - } - TF_RETURN_IF_ERROR(hlo_verifier_->Run(reference_module.get()).status()); - return std::move(reference_module); -} - -absl::StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( - std::unique_ptr module, - const absl::Span arguments, - const optional& error, bool run_hlo_passes, - const std::function& reference_preprocessor, - const std::function& test_preprocessor) { - TF_RETURN_IF_ERROR(hlo_verifier_->Run(module.get()).status()); - TF_ASSIGN_OR_RETURN(auto reference_module, - MakeReferenceModule(*module, reference_preprocessor)); - if (test_preprocessor) { - test_preprocessor(module.get()); - } - // Execute on two backends. - TF_ASSIGN_OR_RETURN(auto test, runner_->Execute(std::move(module), arguments, - run_hlo_passes)); - TF_ASSIGN_OR_RETURN(auto reference, - reference_runner_.Execute(std::move(reference_module), - arguments, run_hlo_passes)); - if (reference.IsAll(0)) { - LOG(WARNING) << "Reference value is only zeros."; - } - - return LiteralTestUtil::NearOrEqual(/*expected=*/reference, /*actual=*/test, - error); -} - -::testing::AssertionResult HloTestBase::RunAndCompare( - std::unique_ptr module, - const absl::Span arguments, - const optional& error, - const std::function& reference_preprocessor) { - auto result = - RunAndCompareInternal(std::move(module), arguments, error, - /*run_hlo_passes=*/true, reference_preprocessor); - if (!result.ok()) { - return ::testing::AssertionFailure() << result.status(); - } - return result.value(); -} - -::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( - std::unique_ptr module, - const absl::Span arguments, - const optional& error, - const std::function& reference_preprocessor, - const std::function& test_preprocessor) { - auto result = RunAndCompareInternal( - std::move(module), arguments, error, - /*run_hlo_passes=*/false, reference_preprocessor, test_preprocessor); - if (!result.ok()) { - return ::testing::AssertionFailure() << result.status(); - } - return result.value(); -} - -::testing::AssertionResult HloTestBase::RunAndCompare( - std::unique_ptr module, const optional& error, - const std::function& reference_preprocessor, - std::optional args_max_bits_of_precision) { - auto fake_arguments = - MakeFakeArguments(module.get(), /*pseudo_random=*/true, - /*use_large_range=*/false, - /*treat_gte_as_data_formatting=*/false, - args_max_bits_of_precision) - .value(); - - std::vector fake_argument_ptrs; - absl::c_transform( - fake_arguments, std::back_inserter(fake_argument_ptrs), - [](const Literal& literal) { return const_cast(&literal); }); - - return RunAndCompare(std::move(module), fake_argument_ptrs, error, - reference_preprocessor); -} - -::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( - std::unique_ptr module, const optional& error, - const std::function& reference_preprocessor, - const std::function& test_preprocessor) { - const auto fake_arguments = MakeFakeArguments(module.get()).value(); - std::vector fake_argument_ptrs; - absl::c_transform( - fake_arguments, std::back_inserter(fake_argument_ptrs), - [](const Literal& literal) { return const_cast(&literal); }); - - auto assertion_result = - RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error, - reference_preprocessor, test_preprocessor); - if (!assertion_result) { - for (const auto& literal : fake_arguments) { - uint64_t total_elements = 1; - absl::c_for_each(literal.shape().dimensions(), - [&](int64_t dim) { total_elements *= dim; }); - if (total_elements > 1000) { - assertion_result << "argument literal is too large to print: " - << literal.shape().ToString(); - continue; - } - assertion_result << "argument literal: " << literal.ToString(); - } - } - return assertion_result; -} - -::testing::AssertionResult HloTestBase::Run(std::unique_ptr module, - bool run_hlo_passes) { - const auto fake_arguments = MakeFakeArguments(module.get()).value(); - const auto change = hlo_verifier_->Run(module.get()); - if (!change.ok()) { - return ::testing::AssertionFailure() << change.status(); - } - - const auto output = - runner_->Execute(std::move(module), fake_arguments, run_hlo_passes); - return output.ok() - ? ::testing::AssertionSuccess() - : ::testing::AssertionFailure() << output.status().message(); -} - -::testing::AssertionResult HloTestBase::RunAndCompare( - string_view hlo_string, const std::optional& error, - const std::function& reference_preprocessor, - std::optional args_max_bits_of_precision) { - auto module_or_status = ParseAndReturnVerifiedModule(hlo_string); - if (!module_or_status.ok()) { - return ::testing::AssertionFailure() - << "Error while parsing HLO text format: " - << module_or_status.status().ToString(); - } - return RunAndCompare(std::move(module_or_status).value(), error, - reference_preprocessor, args_max_bits_of_precision); -} - -absl::StatusOr<::testing::AssertionResult> -HloTestBase::RunAndCompareTwoModulesInternalReplicated( - std::unique_ptr module_0, std::unique_ptr module_1, - HloRunner::ReplicatedExecuteOptions options, - const std::optional& error) { - TF_RETURN_IF_ERROR(hlo_verifier_->Run(module_0.get()).status()); - TF_RETURN_IF_ERROR(hlo_verifier_->Run(module_1.get()).status()); - - // Execute the two modules. - TF_ASSIGN_OR_RETURN(auto test_0, - runner_->ExecuteReplicated(std::move(module_0), options)); - TF_ASSIGN_OR_RETURN(auto test_1, - runner_->ExecuteReplicated(std::move(module_1), options)); - - for (auto [expected, actual] : llvm::zip_equal(test_0, test_1)) { - auto compare_result = LiteralTestUtil::NearOrEqual(expected, actual, error); - if (!compare_result) { - return compare_result; - } - } - return ::testing::AssertionSuccess(); -} - -::testing::AssertionResult HloTestBase::RunAndCompareTwoModulesReplicated( - std::unique_ptr module_0, std::unique_ptr module_1, - HloRunner::ReplicatedExecuteOptions options, - const optional& error) { - int replica_count = module_0->config().replica_count(); - if (replica_count != module_1->config().replica_count()) { - return ::testing::AssertionFailure() - << "Number of replicas is not the same: " << replica_count << " Vs " - << module_1->config().replica_count(); - } - if (options.num_replicas != replica_count) { - return ::testing::AssertionFailure() - << "Number of execution replicas is different from number of " - "replicas in the module: requested number of replicas = " - << options.num_replicas - << ", number of replicas in hlo = " << replica_count; - } - - std::vector mismatches = CompareInputs(*module_0, *module_1); - if (!mismatches.empty()) { - return ::testing::AssertionFailure() - << "Error: parameter mismatch at indices: " - << absl::StrJoin(mismatches, ","); - } - auto num_args = module_0->entry_computation()->num_parameters(); - if (num_args != options.arguments.size()) { - return ::testing::AssertionFailure() - << "Mismatch in number of arguments passed while running replicated " - "hlo module. Expected: " - << num_args << ", actual: " << options.arguments.size(); - } - auto result = RunAndCompareTwoModulesInternalReplicated( - std::move(module_0), std::move(module_1), options, error); - if (!result.ok()) { - return ::testing::AssertionFailure() << result.status(); - } - return result.value(); -} - -::testing::AssertionResult HloTestBase::RunAndCompareTwoModulesReplicated( - std::unique_ptr module_0, std::unique_ptr module_1, - bool run_hlo_passes, bool use_threads, - const std::optional& error) { - absl::StatusOr> fake_arguments = MakeFakeArguments( - /*module=*/module_0.get(), /*pseudo_random=*/true, - /*use_large_range=*/false, - /*treat_gte_as_data_formatting=*/false, - /*max_bits_of_precision=*/std::nullopt); - CHECK_OK(fake_arguments); - std::vector fake_argument_ptrs; - absl::c_transform( - /*input=*/*fake_arguments, - /*output=*/std::back_inserter(fake_argument_ptrs), - /*unary_op=*/[](const Literal& literal) -> Literal* { - return const_cast(&literal); - }); - HloRunner::ReplicatedExecuteOptions options{ - /*num_replicas=*/module_0->config().replica_count(), - /*arguments=*/fake_argument_ptrs, - /*infeed_values=*/{}, - /*infeed_steps=*/-1, - /*outfeed_shape=*/{}, - /*outfeed_values=*/nullptr, - /*run_hlo_passes=*/run_hlo_passes, - /*use_threads=*/use_threads}; - return RunAndCompareTwoModulesReplicated(std::move(module_0), - std::move(module_1), options, error); -} - -::testing::AssertionResult HloTestBase::RunAndCompareTwoModulesReplicated( - absl::string_view module_0, absl::string_view module_1, bool run_hlo_passes, - bool use_threads, const std::optional& error) { - auto module_0_or_status = ParseAndReturnVerifiedModule(module_0); - if (!module_0_or_status.ok()) { - return ::testing::AssertionFailure() - << "Error while parsing HLO text format: " - << module_0_or_status.status().ToString(); - } - - auto module_1_or_status = ParseAndReturnVerifiedModule(module_1); - if (!module_1_or_status.ok()) { - return ::testing::AssertionFailure() - << "Error while parsing HLO text format: " - << module_1_or_status.status().ToString(); - } - return RunAndCompareTwoModulesReplicated( - std::move(module_0_or_status).value(), - std::move(module_1_or_status).value(), run_hlo_passes, use_threads, - error); -} - -absl::StatusOr<::testing::AssertionResult> -HloTestBase::RunAndCompareTwoModulesInternal( - std::unique_ptr module_0, std::unique_ptr module_1, - const absl::Span arguments, - const std::optional& error, bool run_hlo_passes) { - TF_RETURN_IF_ERROR(hlo_verifier_->Run(module_0.get()).status()); - TF_RETURN_IF_ERROR(hlo_verifier_->Run(module_1.get()).status()); - - // Execute the two modules. - TF_ASSIGN_OR_RETURN(auto test_0, runner_->Execute(std::move(module_0), - arguments, run_hlo_passes)); - TF_ASSIGN_OR_RETURN(auto test_1, runner_->Execute(std::move(module_1), - arguments, run_hlo_passes)); - - return LiteralTestUtil::NearOrEqual(/*expected=*/test_0, /*actual=*/test_1, - error); -} - -::testing::AssertionResult HloTestBase::RunAndCompareTwoModules( - std::unique_ptr module_0, std::unique_ptr module_1, - const absl::Span arguments, - const optional& error, bool run_hlo_passes) { - auto result = - RunAndCompareTwoModulesInternal(std::move(module_0), std::move(module_1), - arguments, error, run_hlo_passes); - if (!result.ok()) { - return ::testing::AssertionFailure() << result.status(); - } - return result.value(); -} - -::testing::AssertionResult HloTestBase::RunAndCompareTwoModules( - std::unique_ptr module_0, std::unique_ptr module_1, - const optional& error, bool run_hlo_passes, - std::optional args_max_bits_of_precision) { - std::vector mismatches = CompareInputs(*module_0, *module_1); - if (!mismatches.empty()) { - return ::testing::AssertionFailure() - << "Error : mismatching parameter shapes for parameters " - << absl::StrJoin(mismatches, ", "); - } - - absl::StatusOr> fake_arguments = MakeFakeArguments( - module_0.get(), /*pseudo_random=*/true, /*use_large_range=*/false, - /*treat_gte_as_data_formatting=*/false, args_max_bits_of_precision); - CHECK_OK(fake_arguments); - - std::vector fake_argument_ptrs; - absl::c_transform( - *fake_arguments, std::back_inserter(fake_argument_ptrs), - [](const Literal& literal) { return const_cast(&literal); }); - - return RunAndCompareTwoModules(std::move(module_0), std::move(module_1), - fake_argument_ptrs, error, run_hlo_passes); -} - -::testing::AssertionResult HloTestBase::RunAndCompareTwoModules( - string_view hlo_string_module_0, string_view hlo_string_module_1, - const std::optional& error, bool run_hlo_passes, - std::optional args_max_bits_of_precision) { - auto module_0_or_status = ParseAndReturnVerifiedModule(hlo_string_module_0); - if (!module_0_or_status.ok()) { - return ::testing::AssertionFailure() - << "Error while parsing HLO text format: " - << module_0_or_status.status().ToString(); - } - - auto module_1_or_status = ParseAndReturnVerifiedModule(hlo_string_module_1); - if (!module_1_or_status.ok()) { - return ::testing::AssertionFailure() - << "Error while parsing HLO text format: " - << module_1_or_status.status().ToString(); - } - return RunAndCompareTwoModules(std::move(module_0_or_status).value(), - std::move(module_1_or_status).value(), error, - run_hlo_passes, args_max_bits_of_precision); -} - -::testing::AssertionResult HloTestBase::RunAndCompareTwoModules( - string_view hlo_string_module_0, string_view hlo_string_module_1, - const HloModuleConfig& config_0, const HloModuleConfig& config_1, - const std::optional& error, bool run_hlo_passes, - std::optional args_max_bits_of_precision) { - auto module_0_or_status = - ParseAndReturnVerifiedModule(hlo_string_module_0, config_0); - if (!module_0_or_status.ok()) { - return ::testing::AssertionFailure() - << "Error while parsing HLO text format: " - << module_0_or_status.status().ToString(); - } - - auto module_1_or_status = - ParseAndReturnVerifiedModule(hlo_string_module_1, config_1); - if (!module_1_or_status.ok()) { - return ::testing::AssertionFailure() - << "Error while parsing HLO text format: " - << module_1_or_status.status().ToString(); - } - return RunAndCompareTwoModules(std::move(module_0_or_status).value(), - std::move(module_1_or_status).value(), error, - run_hlo_passes, args_max_bits_of_precision); -} - -::testing::AssertionResult HloTestBase::RunAndCompareTwoModules( - absl::string_view hlo_string_module_0, - absl::string_view hlo_string_module_1, - const absl::Span arguments, - const std::optional& error, bool run_hlo_passes) { - auto module_0_or_status = ParseAndReturnVerifiedModule(hlo_string_module_0); - if (!module_0_or_status.ok()) { - return ::testing::AssertionFailure() - << "Error while parsing HLO text format: " - << module_0_or_status.status().ToString(); - } - - auto module_1_or_status = ParseAndReturnVerifiedModule(hlo_string_module_1); - if (!module_1_or_status.ok()) { - return ::testing::AssertionFailure() - << "Error while parsing HLO text format: " - << module_1_or_status.status().ToString(); - } - return RunAndCompareTwoModules(std::move(module_0_or_status).value(), - std::move(module_1_or_status).value(), - arguments, error, run_hlo_passes); -} - -::testing::AssertionResult HloTestBase::Run( - string_view hlo_string, bool run_hlo_passes, ExecutionProfile* profile, - const tsl::protobuf::Message* backend_config, bool use_random_data) { - auto module_or_status = ParseAndReturnVerifiedModule(hlo_string); - if (!module_or_status.ok()) { - return ::testing::AssertionFailure() - << "Error while parsing HLO text format: " - << module_or_status.status().ToString(); - } - std::unique_ptr module = std::move(module_or_status.value()); - const auto fake_arguments = - MakeFakeArguments(module.get(), use_random_data).value(); - std::vector fake_argument_ptrs; - absl::c_transform( - fake_arguments, std::back_inserter(fake_argument_ptrs), - [](const Literal& literal) { return const_cast(&literal); }); - - if (profile != nullptr) { - // We have to enable HLO profiling since otherwise currently the - // ExecutionProfile is not correct. - // - // TODO(b/119432044): Fix collection of the ExecutionProfile - // so that this is not necessary. - HloModuleConfig config = module->config(); - DebugOptions debug_options = config.debug_options(); - debug_options.set_xla_hlo_profile(true); - config.set_debug_options(debug_options); - module->set_config(config); - } - - if (backend_config) { - // Set backend configuration if it is given. - HloInstruction* instruction = - module->entry_computation()->root_instruction(); - absl::Status s = instruction->set_backend_config(*backend_config); - return s.ok() ? ::testing::AssertionSuccess() - : ::testing::AssertionFailure() << s.message(); - } - - auto output = runner_->Execute(std::move(module), fake_argument_ptrs, - /*run_hlo_passes=*/run_hlo_passes, - /*profile=*/profile); - - return output.ok() - ? ::testing::AssertionSuccess() - : ::testing::AssertionFailure() << output.status().message(); -} - -::testing::AssertionResult HloTestBase::RunReplicated( - string_view hlo_string, bool run_hlo_passes, int64_t num_replicas, - const tsl::protobuf::Message* backend_config) { - auto module_or_status = - ParseAndReturnVerifiedModule(hlo_string, num_replicas); - if (!module_or_status.ok()) { - return ::testing::AssertionFailure() - << "Error while parsing HLO text format: " - << module_or_status.status().ToString(); - } - - std::unique_ptr module = std::move(module_or_status.value()); - const auto fake_arguments = MakeFakeArguments(module.get()).value(); - std::vector fake_argument_ptrs; - absl::c_transform( - fake_arguments, std::back_inserter(fake_argument_ptrs), - [](const Literal& literal) { return const_cast(&literal); }); - - if (backend_config) { - // Set backend configuration if it is given. - HloInstruction* instruction = - module->entry_computation()->root_instruction(); - absl::Status s = instruction->set_backend_config(*backend_config); - return s.ok() ? ::testing::AssertionSuccess() - : ::testing::AssertionFailure() << s.message(); - } - - HloRunner::ReplicatedExecuteOptions options; - options.num_replicas = num_replicas; - options.run_hlo_passes = run_hlo_passes; - options.use_threads = true; - for (auto argument : fake_argument_ptrs) { - options.arguments.push_back(argument); - } - auto output = runner_->ExecuteReplicated(std::move(module), options); - - return output.ok() - ? ::testing::AssertionSuccess() - : ::testing::AssertionFailure() << output.status().message(); -} - -::testing::AssertionResult HloTestBase::RunMultipleTimes( - string_view hlo_string, bool run_hlo_passes, - std::vector* profiles, - const tsl::protobuf::Message* backend_config, bool assert_determinism) { - int n = profiles->size(); - std::vector> fake_argument_ptrs(n); - std::vector> fake_arguments(n); - std::vector> executables(n); - - for (int i = 0; i < n; ++i) { - auto module_or_status = ParseAndReturnVerifiedModule(hlo_string); - if (!module_or_status.ok()) { - return ::testing::AssertionFailure() - << "Error while parsing HLO text format: " - << module_or_status.status().ToString(); - } - std::unique_ptr module = std::move(module_or_status.value()); - - fake_arguments[i] = MakeFakeArguments(module.get()).value(); - - if (profiles != nullptr) { - // We have to enable HLO profiling since otherwise currently the - // ExecutionProfile is not correct. - // - // TODO(b/119432044): Fix collection of the ExecutionProfile - // so that this is not necessary. - HloModuleConfig config = module->config(); - DebugOptions debug_options = config.debug_options(); - debug_options.set_xla_hlo_profile(true); - config.set_debug_options(debug_options); - module->set_config(config); - } - - if (backend_config) { - // Set backend configuration if it is given. - HloInstruction* instruction = - module->entry_computation()->root_instruction(); - absl::Status s = instruction->set_backend_config(*backend_config); - return s.ok() ? ::testing::AssertionSuccess() - : ::testing::AssertionFailure() << s.message(); - } - - auto executable = - runner_->CreateExecutable(std::move(module), run_hlo_passes); - if (!executable.ok()) { - return ::testing::AssertionFailure() << executable.status().message(); - } - executables[i] = std::move(executable.value()); - } - - std::optional canonical_output; - for (int i = 0; i < n; ++i) { - absl::StatusOr output = - runner_->ExecuteWithExecutable(executables[i].get(), fake_arguments[i], - /*profile=*/&((*profiles)[i])); - if (!output.ok()) { - return ::testing::AssertionFailure() << output.status().message(); - } - - if (assert_determinism) { - if (!canonical_output.has_value()) { - canonical_output = std::move(output).value(); - } else { - if (*canonical_output != output.value()) { - return ::testing::AssertionFailure() - << "Successive runs have returned different results: " - << *canonical_output << " vs. " << output.value(); - } - } - } - } - - return ::testing::AssertionSuccess(); + TF_CHECK_OK(status_or_runner.status()); + return *std::move(status_or_runner); } ::testing::AssertionResult HloTestBase::RunAndCompareFromFile( @@ -1032,20 +161,6 @@ ::testing::AssertionResult HloTestBase::RunAndCompareFromFile( reference_preprocessor); } -::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( - string_view hlo_string, const std::optional& error, - const std::function& reference_preprocessor, - const std::function& test_preprocessor) { - auto module_or_status = ParseAndReturnVerifiedModule(hlo_string); - if (!module_or_status.ok()) { - return ::testing::AssertionFailure() - << "Error while parsing HLO text format: " - << module_or_status.status().ToString(); - } - return RunAndCompareNoHloPasses(std::move(module_or_status).value(), error, - reference_preprocessor, test_preprocessor); -} - ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile( const std::string& filename, const std::optional& error, const std::function& reference_preprocessor) { @@ -1059,43 +174,6 @@ ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile( reference_preprocessor); } -HloComputation* HloTestBase::FindComputation(HloModule* module, - absl::string_view name) { - return hlo_query::FindComputation(module, name); -} - -HloInstruction* HloTestBase::FindInstruction(HloModule* module, - absl::string_view name) { - for (const HloComputation* computation : module->computations()) { - if (auto instruction = hlo_query::FindFirstInstruction(computation, name); - instruction.first != nullptr) { - return instruction.first; - } - } - return nullptr; -} - -HloInstruction* HloTestBase::FindInstruction(HloModule* module, - HloOpcode opcode) { - for (const HloComputation* computation : module->computations()) { - if (auto instruction = hlo_query::FindFirstInstruction(computation, opcode); - instruction.first != nullptr) { - return instruction.first; - } - } - return nullptr; -} - -std::vector HloTestBase::FindInstructions(HloModule* module, - HloOpcode opcode) { - std::vector instructions; - for (const HloComputation* c : module->computations()) { - absl::c_copy_if(c->instructions(), std::back_inserter(instructions), - [&](HloInstruction* i) { return i->opcode() == opcode; }); - } - return instructions; -} - se::DeviceMemoryAllocator* HloTestBase::GetAllocator() { if (allocator_ == nullptr) { allocator_ = std::make_unique( @@ -1104,14 +182,6 @@ se::DeviceMemoryAllocator* HloTestBase::GetAllocator() { return allocator_.get(); } -Backend& HloTestBase::backend() { return test_runner_.backend(); } -const Backend& HloTestBase::backend() const { return test_runner_.backend(); } - -/* static */ -std::string HloTestBase::TestName() { - return ::testing::UnitTest::GetInstance()->current_test_info()->name(); -} - void HloTestBase::MatchOptimizedHlo(absl::string_view hlo, absl::string_view pattern, bool print_operand_shape) { @@ -1141,21 +211,4 @@ absl::StatusOr> HloTestBase::GetOptimizedModule( GetAllocator()); } -absl::StatusOr> -HloTestBase::GetHloRunnerForTest(se::Platform* test_platform) { - if (ShouldUsePjRt()) { - PjRtClientTestFactoryRegistry& pjrt_registry = - GetGlobalPjRtClientTestFactory(); - TF_ASSIGN_OR_RETURN(auto client, pjrt_registry.Get()()); - - auto device_shape_representation_fn = - pjrt_registry.GetDeviceShapeRepresentationFn(client.get()); - - return std::unique_ptr( - new HloRunnerPjRt(std::move(client), device_shape_representation_fn)); - } else { - return std::unique_ptr(new HloRunner(test_platform)); - } -} - } // namespace xla diff --git a/third_party/xla/xla/tests/hlo_test_base.h b/third_party/xla/xla/tests/hlo_test_base.h index e075d7fd7123a0..732e9de6eee226 100644 --- a/third_party/xla/xla/tests/hlo_test_base.h +++ b/third_party/xla/xla/tests/hlo_test_base.h @@ -16,30 +16,34 @@ limitations under the License. #ifndef XLA_TESTS_HLO_TEST_BASE_H_ #define XLA_TESTS_HLO_TEST_BASE_H_ +#include #include +#include #include #include #include #include #include +#include "absl/base/attributes.h" +#include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/types/span.h" -#include "xla/hlo/ir/hlo_instruction.h" +#include "absl/strings/string_view.h" +#include "xla/error_spec.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/hlo/ir/hlo_module_group.h" -#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/literal.h" #include "xla/service/backend.h" -#include "xla/service/computation_layout.h" +#include "xla/service/computation_placer.h" +#include "xla/service/executable.h" #include "xla/service/hlo_runner.h" -#include "xla/service/hlo_verifier.h" -#include "xla/service/platform_util.h" -#include "xla/shape_layout.h" +#include "xla/service/hlo_runner_interface.h" #include "xla/stream_executor/device_memory_allocator.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/tests/literal_test_util.h" -#include "xla/tests/verified_hlo_module.h" -#include "xla/types.h" +#include "xla/stream_executor/platform.h" +#include "xla/tests/new_hlo_test_base.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/test.h" @@ -72,72 +76,20 @@ namespace xla { // ) // // For a more detailed example, see "../tests/sample_text_test.cc". -class HloTestBase : public ::testing::Test { +// +// This class is deprecated in favor of NewHloTestBase. We are in the process of +// incrementally migrating tests to use this new base class. HloTestBase remains +// as a shim on tests during this migration process. Please avoid introducing +// new tests that use this class. +class [[deprecated("Use NewHloTestBase instead.")]] HloTestBase + : public NewHloTestBase { public: - // Creates a new HLO module for a test. The module created will have - // TestName() for its name; it will also automatically populate its debug - // options from command-line flags. If you want a fresh HloModule object and - // then add HloComputations to it, it's recommended to use this method in your - // tests. - // - // This returns a vanilla HloModule that doesn't run the HLO verifier on - // destruction. - ABSL_DEPRECATED("Use CreateNewVerifiedModule instead.") - std::unique_ptr CreateNewUnverifiedModule( - const std::string& name = TestName()); - - // Like CreateNewUnverifiedModule, except the HloModule returned here runs the - // HLO verifier on destruction. - std::unique_ptr CreateNewVerifiedModule( - const std::string& name = TestName(), int64_t replica_count = 1); - - // Parses the given string and returns module as a VerifiedHloModule. - absl::StatusOr> - ParseAndReturnVerifiedModule(absl::string_view hlo_text, - int64_t replica_count = 1, - int64_t num_partitions = 1); - absl::StatusOr> - ParseAndReturnVerifiedModule(absl::string_view hlo_text, - const HloModuleConfig& config); - - // Runs the hlo_pass with the provided module and returns the result. This - // function also verifies that the module remains unchanged when hlo_pass - // returns false as the absl::StatusOr value. - // - // These three overloads all do the same thing. The && overload lets you do - // `RunHloPass(MyPass(), module)` all in one line. The reason for the - // overload that takes a pointer is that, at one point in the past, non-const - // lvalue references were banned in Google code. - static absl::StatusOr RunHloPass(HloPassInterface* hlo_pass, - HloModule* module); - static absl::StatusOr RunHloPass(HloPassInterface& hlo_pass, - HloModule* module) { - return RunHloPass(&hlo_pass, module); - } - static absl::StatusOr RunHloPass(HloPassInterface&& hlo_pass, - HloModule* module) { - return RunHloPass(&hlo_pass, module); - } - - // Runs the hlo_pass with the provided module group and returns the result. - // This method runs the input HLO module group pass for a `HloModuleGroup` and - // it also verifies the module group remains unchanged when hlo_pass returns - // false as the absl::StatusOr value. - static absl::StatusOr RunHloPass(HloPassInterface&& hlo_pass, - HloModuleGroup* module_group); - - static PrecisionConfig DefaultPrecisionConfig(int operands); - - // Sets most fath math options to be enabled to model the fast math flags - // generally used for CPU:AOT compilation. - static void SetAotFastMathDebugOptions(DebugOptions* options); - // Compiles the given `hlo` with optimizations, and verifies that optimized // HLO matches the given FileCheck pattern. void MatchOptimizedHlo(absl::string_view hlo, absl::string_view pattern, bool print_operand_shape = false); - // LikeMatchOptimizedHlo, but checks operand shapes as well. + // Like MatchOptimizedHlo, but checks operand shapes as well. void MatchOptimizedHloWithShapes(absl::string_view hlo, absl::string_view pattern) { MatchOptimizedHlo(hlo, pattern, /*print_operand_shape=*/true); @@ -150,6 +102,8 @@ class HloTestBase : public ::testing::Test { absl::StatusOr> GetOptimizedModule( std::unique_ptr hlo_module); + using NewHloTestBase::ParseAndReturnVerifiedModule; + protected: // This uses the interpreter backend as the reference backend and // automatically finds another supported backend as the test backend. If the @@ -167,370 +121,88 @@ class HloTestBase : public ::testing::Test { bool allow_mixed_precision_in_hlo_verifier = true, HloPredicate instruction_can_change_layout_func = {}); - ~HloTestBase() override {} - - // Runs pass `hlo_pass` on input HLO module `hlo` with optional config, and - // FileChecks the result against `expected`. - // - // If the rewrite has changed the module, also runs `additional_checks` on the - // result. - void RunAndFilecheckHloRewrite( - absl::string_view hlo, HloPassInterface&& hlo_pass, - std::optional expected, - std::function after_pass_checks = nullptr, - const HloModuleConfig* config = nullptr); - - // Runs pass `hlo_pass` on a group of input HLO modules `hlo_module_strs`, - // and FileChecks the result against `expected`. - void RunAndFilecheckHloModuleGroupRewrite( - absl::Span hlo_module_strs, - HloPassInterface&& hlo_pass, - std::optional> expected); - - // Populates debug options from command-line flags and adjusts the options for - // testing. It is recommended to use this when you need to pass in - // DebugOptions, e.g. when creating a module from a string or a file. - // - // This function is virtual so tests can specify an alternative set of debug - // options (e.g. disabling additional passes). - virtual DebugOptions GetDebugOptionsForTest(); - - // Gets an HloModuleConfig with options appropriate for tests. - HloModuleConfig GetModuleConfigForTest(int64_t replica_count = 1, - int64_t num_partitions = 1) { - HloModuleConfig config; - config.set_debug_options(GetDebugOptionsForTest()); - config.set_replica_count(replica_count); - config.set_num_partitions(num_partitions); - return config; - } - - // Executes the given module and return the result as a Literal. - absl::StatusOr Execute(std::unique_ptr module, - absl::Span arguments, - bool run_hlo_passes = true); - - // Same as above, except the module will be executed without running any HLO - // passes on it. - Literal ExecuteNoHloPasses(std::unique_ptr module, - absl::Span arguments); - - Literal ExecuteAndTransfer(std::unique_ptr module, - absl::Span arguments); - - // Compile the given module to an executable. - absl::StatusOr> CreateExecutable( - std::unique_ptr module, bool run_hlo_passes) { - return runner_->CreateExecutable(std::move(module), run_hlo_passes); + // DO NOT USE: This is a temporary method to help migrate away from HloRunner. + // Some test fixures rely on functionality that is not supported by other + // HloRunnerInterface implementations, thus we expose it here. + [[nodiscard]] [[deprecated( + "This is a temporary method to help migrate existing tests away from " + "directly depending on HloRunner. Please do not introduce new uses.")]] + absl::StatusOr> ExecuteReplicatedWithHloRunner( + Executable* executable, + const HloRunnerInterface::ReplicatedExecuteOptions& options, + DeviceAssignment* device_assignment, + ExecutionProfile* profile = nullptr) { + return test_runner_as_hlo_runner().ExecuteReplicated( + executable, options, device_assignment, profile); } - // Executes the given module on multiple replicas. - // - // use_threads indicates whether this replicated computation will be executed - // with a thread-per-replica, vs using an implicitly async call such as - // Executable::ExecuteOnStreams. - absl::StatusOr> ExecuteReplicated( - std::unique_ptr module, absl::Span arguments, - int64_t num_replicas, bool use_threads, bool run_hlo_passes = false); - - // Same as above, but uses specified device assignment. - absl::StatusOr> ExecuteReplicated( - std::unique_ptr module, absl::Span arguments, - int64_t num_replicas, DeviceAssignment* device_assignment, - bool run_hlo_passes, bool use_threads); - - // Same as above, but allows passing different programs for replicas. - absl::StatusOr> ExecuteReplicated( - std::function executable_provider, - std::function argument_count_provider, - std::function argument_provider, - int64_t num_replicas, bool run_hlo_passes, - DeviceAssignment* device_assignment = nullptr); - - // Convenience function for above. Allows passing different inputs to - // different replicas of the same program. - absl::StatusOr> ExecuteReplicated( - std::unique_ptr module, - std::vector> arguments, int64_t num_replicas, - bool run_hlo_passes); - - // Executes the given hlo module on two backends and compares results. - // - // 'arguments': the input of the hlo module. - // - // 'error': if has value, expects the results to be near (within the error - // bound). Otherwise, expects the results to be equal. - // - // 'reference_preprocessor': the module should be ready to run on the test - // backend, but it might need to be tailored so that it is able to run on the - // reference backend. Note that the program shape of the module must not be - // modified. - [[nodiscard]] ::testing::AssertionResult RunAndCompare( - std::unique_ptr module, - const absl::Span arguments, - const std::optional& error, - const std::function& reference_preprocessor = nullptr); - - // Same as above, except that the module will be executed without Hlo - // optimization. - [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses( - std::unique_ptr module, - const absl::Span arguments, - const std::optional& error, - const std::function& reference_preprocessor = nullptr, - const std::function& test_preprocessor = nullptr); - - // Executes an hlo module with fake inputs and compares the results. - [[nodiscard]] ::testing::AssertionResult RunAndCompare( - std::unique_ptr module, const std::optional& error, - const std::function& reference_preprocessor = nullptr, - std::optional args_max_bits_of_precision = std::nullopt); - - // Same as above, except that the module will be executed without Hlo - // optimization. - [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses( - std::unique_ptr module, const std::optional& error, - const std::function& reference_preprocessor = nullptr, - const std::function& test_preprocessor = nullptr); - - // Executes an hlo module with fake inputs and checks that the execution is - // successful. - [[nodiscard]] ::testing::AssertionResult Run( - std::unique_ptr module, bool run_hlo_passes); - - // Convenient wrappers for executing and comparing an hlo module with fake - // input. Module can be passed in directly, or parsed from an hlo_string, - // or loaded from a file. - [[nodiscard]] ::testing::AssertionResult RunAndCompare( - const absl::string_view hlo_string, const std::optional& error, - const std::function& reference_preprocessor = nullptr, - std::optional args_max_bits_of_precision = std::nullopt); - [[nodiscard]] ::testing::AssertionResult Run( - const absl::string_view hlo_string, bool run_hlo_passes = true, - ExecutionProfile* profile = nullptr, - const tsl::protobuf::Message* backend_config = nullptr, - bool use_random_data = true); - - // Same as below, except that it requires all the options to be passed. - ::testing::AssertionResult RunAndCompareTwoModulesReplicated( - std::unique_ptr module_0, std::unique_ptr module_1, - HloRunner::ReplicatedExecuteOptions options, - const std::optional& error); - - // Same as below, except that it requires the parsed modules to be passed. - ::testing::AssertionResult RunAndCompareTwoModulesReplicated( - std::unique_ptr module_0, std::unique_ptr module_1, - bool run_hlo_passes, bool use_threads, - const std::optional& error); - - // Parses the modules, and executes them based on `run_hlo_passes` and - // `use_threads` flags. The replica count should be mentioned in the module - // itself. - ::testing::AssertionResult RunAndCompareTwoModulesReplicated( - absl::string_view module_0, absl::string_view module_1, - bool run_hlo_passes, bool use_threads, - const std::optional& error); - - // Same as below, except requires passing fake arguments. - ::testing::AssertionResult RunAndCompareTwoModules( - std::unique_ptr module_0, std::unique_ptr module_1, - const absl::Span arguments, - const std::optional& error, bool run_hlo_passes = true); - - // Same as below, except requires passing the modules. - ::testing::AssertionResult RunAndCompareTwoModules( - std::unique_ptr module_0, std::unique_ptr module_1, - const std::optional& error, bool run_hlo_passes = true, - std::optional args_max_bits_of_precision = std::nullopt); - - // Convenient wrapper for executing and comparing results of two hlo modules - // with fake input. By default compares unoptimized modules. If the modules - // are already optimized, set |run_hlo_passes| to false. - ::testing::AssertionResult RunAndCompareTwoModules( - absl::string_view hlo_string_module_0, - absl::string_view hlo_string_module_1, - const std::optional& error, bool run_hlo_passes = true, - std::optional args_max_bits_of_precision = std::nullopt); - - // Same as above but allows running with different configs. - ::testing::AssertionResult RunAndCompareTwoModules( - absl::string_view hlo_string_module_0, - absl::string_view hlo_string_module_1, const HloModuleConfig& config_0, - const HloModuleConfig& config_1, const std::optional& error, - bool run_hlo_passes = true, - std::optional args_max_bits_of_precision = std::nullopt); - - // Same as above but requires explicit arguments. - ::testing::AssertionResult RunAndCompareTwoModules( - absl::string_view hlo_string_module_0, - absl::string_view hlo_string_module_1, - absl::Span arguments, - const std::optional& error, bool run_hlo_passes = true); - - // Executes an hlo module with fake inputs on multiple replicas. - [[nodiscard]] ::testing::AssertionResult RunReplicated( - const absl::string_view hlo_string, bool run_hlo_passes = true, - int64_t num_replicas = 1, - const tsl::protobuf::Message* backend_config = nullptr); - - // If assert_determinism is true, the assertion will fail unless all runs - // produce exactly the same output. - [[nodiscard]] ::testing::AssertionResult RunMultipleTimes( - const absl::string_view hlo_string, bool run_hlo_passes, - std::vector* profiles, - const tsl::protobuf::Message* backend_config = nullptr, - bool assert_determinism = false); [[nodiscard]] ::testing::AssertionResult RunAndCompareFromFile( const std::string& filename, const std::optional& error, const std::function& reference_preprocessor = nullptr); - [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses( - const absl::string_view hlo_string, const std::optional& error, - const std::function& reference_preprocessor = nullptr, - const std::function& test_preprocessor = nullptr); [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPassesFromFile( const std::string& filename, const std::optional& error, const std::function& reference_preprocessor = nullptr); - // Convenience method to force the layout of a given parameter in a module. - // The layout of parameter number 'param_no' in the 'module' is set to - // 'layout'. - void ForceParameterLayout(HloModule* module, int64_t param_no, - const Layout& layout) { - ASSERT_LT(param_no, - module->mutable_entry_computation_layout()->parameter_count()); - module->mutable_entry_computation_layout() - ->mutable_parameter_layout(param_no) - ->ResetLayout(layout); + // DO NOT USE: This is a temporary method to help migrate away from HloRunner. + // Some test fixures rely on functionality that is not supported by other + // HloRunnerInterface implementations, thus we expose it here. + [[deprecated( + "This is a temporary method to help migrate existing tests away from " + "directly depending on HloRunner. Please do not introduce new uses.")]] + const Backend& backend() const { + return test_runner_as_hlo_runner().backend(); } - - // Convenience method to force the layout of the computation result in a - // module. The result layout of 'module' is set to 'layout'. - void ForceResultLayout(HloModule* module, const Layout& layout) { - module->mutable_entry_computation_layout() - ->mutable_result_layout() - ->ResetLayout(layout); + // Returns the backend owned by the test runner. + // DO NOT USE: This is a temporary method to help migrate away from HloRunner. + // Some test fixures rely on functionality that is not supported by other + // HloRunnerInterface implementations, thus we expose it here. + [[deprecated( + "This is a temporary method to help migrate existing tests away from " + "directly depending on HloRunner. Please do not introduce new uses.")]] + Backend& backend() { + return test_runner_as_hlo_runner().backend(); } - void ForceResultLayout(HloModule* module, const Layout& layout, - ShapeIndexView shape_index) { - module->mutable_entry_computation_layout() - ->mutable_result_layout() - ->ResetLayout(layout, shape_index); + // DO NOT USE: This is a temporary method to help migrate away from HloRunner. + // Some test fixures rely on functionality that is not supported by other + // HloRunnerInterface implementations, thus we expose it here. + [[deprecated( + "This is a temporary method to help migrate existing tests away from " + "directly depending on HloRunner. Please do not introduce new uses.")]] + const HloRunner& test_runner_as_hlo_runner() const { + return *static_cast(&test_runner()); } - - // Convenience method to clear the layout of the computation result in - // 'module'. - void ForceClearResultLayout(HloModule* module) { - module->mutable_entry_computation_layout() - ->mutable_result_layout() - ->Clear(); + // DO NOT USE: This is a temporary method to help migrate away from HloRunner. + // Some test fixures rely on functionality that is not supported by other + // HloRunnerInterface implementations, thus we expose it here. + [[deprecated( + "This is a temporary method to help migrate existing tests away from " + "directly depending on HloRunner. Please do not introduce new uses.")]] + HloRunner& test_runner_as_hlo_runner() { + return *static_cast(&test_runner()); } - // Gets the computation/instruction from the given module with the given name. - // Note that it is encouraged to use these functions directly via the - // hlo_query.h header instead since they are independent from any test-time - // variables or contexts. - - // This is useful for tests which create HLOs from a string and then want to - // inspect a particular computation or instruction. - HloComputation* FindComputation(HloModule* module, absl::string_view name); - HloInstruction* FindInstruction(HloModule* module, absl::string_view name); - // Gets the instruction from the given module with the given opcode. - HloInstruction* FindInstruction(HloModule* module, HloOpcode opcode); - // Gets all the instructions from the given module with the given opcode. - std::vector FindInstructions(HloModule* module, - HloOpcode opcode); - - // Return an HLO verifier constructed for the test backend. - HloVerifier& verifier() const { return *hlo_verifier_; } - - static std::string TestName(); - - // Returns the backend owned by the test runner. - Backend& backend(); - const Backend& backend() const; - - int64_t num_devices() { return backend().device_count(); } - - HloRunner test_runner_; - HloRunner reference_runner_; - - bool verifier_layout_sensitive_; - bool allow_mixed_precision_in_hlo_verifier_; - HloPredicate instruction_can_change_layout_func_; - std::unique_ptr hlo_verifier_; - - ErrorSpec error_spec_{0.0001}; - - HloComputation* AddEntryComputationAndUpdateEntryComputationLayout( - HloModule*, std::unique_ptr computation); - void UpdateEntryComputationLayout(HloModule* module); - - // Updates the entry computation layout to match the program shape. Useful - // when tiling assignment has been run to update the latter and we want those - // changes propagated into the former. - absl::Status UpdateEntryComputationLayoutToMatchProgramLayout( - HloModule* module); + [[deprecated( + "This is a temporary method to help migrate existing tests away from " + "directly depending on HloRunner. Please do not introduce new uses.")]] + int64_t num_devices() { + return backend().device_count(); + } absl::StatusOr> GetHloRunner(); - protected: // Helper functions to get test and reference platforms. static se::Platform* GetReferencePlatform(); static se::Platform* GetTestPlatform(); - // Compares the inputs shapes of two modules and returns the list of parameter - // indices that mismatch. The mismatch could be either in shape or datatype. - // If there is no mismatch, an empty vector is returned. - [[nodiscard]] std::vector CompareInputs(const HloModule& module_0, - const HloModule& module_1); - - private: // Creates or retrieves the allocator. se::DeviceMemoryAllocator* GetAllocator(); - // Either an HloRunner or HloRunnerPjRt depending on if ShouldUsePjRt() - std::unique_ptr runner_; - se::Platform* test_platform_; - std::unique_ptr allocator_; - - // Given the test module, makes a reference module that is ready to run on the - // reference platform. This assumes that the given module is ready to run on - // the test platform. - absl::StatusOr> MakeReferenceModule( - const HloModule& test_module, - const std::function& reference_preprocessor); - - // Runs the module on two platforms with or without running hlo passes and - // compares the results. Returns whether the results are near or equal. If any - // error happens before the results are computed, returns the error status. - absl::StatusOr<::testing::AssertionResult> RunAndCompareInternal( - std::unique_ptr module, - const absl::Span arguments, - const std::optional& error, bool run_hlo_passes, - const std::function& reference_preprocessor, - const std::function& test_preprocessor = nullptr); - - // Runs the two module with or without running hlo passes and compares - // the results. Returns whether the results are near or equal. If any - // error happens before the results are computed, returns the error status. - absl::StatusOr<::testing::AssertionResult> - RunAndCompareTwoModulesInternalReplicated( - std::unique_ptr module_0, std::unique_ptr module_1, - HloRunner::ReplicatedExecuteOptions options, - const std::optional& error); - // Runs the two module on with or without running hlo passes and - // compares the results. Returns whether the results are near or equal. If any - // error happens before the results are computed, returns the error status. - absl::StatusOr<::testing::AssertionResult> RunAndCompareTwoModulesInternal( - std::unique_ptr module_0, std::unique_ptr module_1, - const absl::Span arguments, - const std::optional& error, bool run_hlo_passes); + ErrorSpec error_spec_{0.0001}; - // Returns either an HloRunner or HloRunnerPjRt implementation depending if - // there exists a registered PjRtClientFactory. - absl::StatusOr> GetHloRunnerForTest( - se::Platform* test_platform); + private: + se::Platform* test_platform_; + std::unique_ptr allocator_; }; #define SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(x) \ diff --git a/third_party/xla/xla/tests/int4_test.cc b/third_party/xla/xla/tests/int4_test.cc index 6d83d9489aee4d..2e2e863ce5a697 100644 --- a/third_party/xla/xla/tests/int4_test.cc +++ b/third_party/xla/xla/tests/int4_test.cc @@ -157,5 +157,34 @@ XLA_TEST_F(HloTestBase, HorizontalLoopFusion) { EXPECT_TRUE(RunAndCompare(hlo_text, std::nullopt)); } +class HloTestBaseWithAlgsimpDisabled : public HloTestBase { + DebugOptions GetDebugOptionsForTest() const override { + DebugOptions options = HloTestBase::GetDebugOptionsForTest(); + options.add_xla_disable_hlo_passes("algsimp"); + return options; + } +}; + +XLA_TEST_F(HloTestBaseWithAlgsimpDisabled, TwoDots) { + // This tests a regression that occured when a non-parameter non-ROOT + // instruction was s4 as the input or output of a fusion. Fusion passes tend + // to make any int4 instructions only internal to a fusion, but this HLO, at + // the time it is written, has an int4 tensor existing between fusions when + // algebraic simplifier is disabled. + const std::string hlo_text = R"( + HloModule TwoDots + + ENTRY main { + x = s8[25,20,10,5] parameter(0) + y = s8[25,20,10,5] parameter(1) + z = s8[5,20] parameter(2) + dot0 = s8[25,20,10,5] dot(x, y), lhs_batch_dims={0,1,2,3}, lhs_contracting_dims={}, rhs_batch_dims={0,1,2,3}, rhs_contracting_dims={} + dot0_4 = s4[25,20,10,5] convert(dot0) + dot0_8 = s8[25,20,10,5] convert(dot0_4) + dot1 = s8[5,25,10] dot(z, dot0_8), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={3}, rhs_contracting_dims={1} + } +)"; + EXPECT_TRUE(RunAndCompare(hlo_text, std::nullopt)); +} } // namespace } // namespace xla diff --git a/third_party/xla/xla/tests/iota_test.cc b/third_party/xla/xla/tests/iota_test.cc index bb5bf5b93151dd..a9dddb816b4705 100644 --- a/third_party/xla/xla/tests/iota_test.cc +++ b/third_party/xla/xla/tests/iota_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include #include "absl/log/check.h" -#include "xla/client/xla_builder.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/primitive_util.h" #include "xla/shape_util.h" #include "xla/tests/client_library_test_base.h" diff --git a/third_party/xla/xla/tests/local_client_allocation_test.cc b/third_party/xla/xla/tests/local_client_allocation_test.cc index a391a8e2586c2c..81b317e3438d42 100644 --- a/third_party/xla/xla/tests/local_client_allocation_test.cc +++ b/third_party/xla/xla/tests/local_client_allocation_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/service/local_service.h" #include "xla/service/shaped_buffer.h" diff --git a/third_party/xla/xla/tests/local_client_aot_test_helper.cc b/third_party/xla/xla/tests/local_client_aot_test_helper.cc index bb6fdc0cb0bff5..2554ce9df424d5 100644 --- a/third_party/xla/xla/tests/local_client_aot_test_helper.cc +++ b/third_party/xla/xla/tests/local_client_aot_test_helper.cc @@ -22,8 +22,8 @@ limitations under the License. #include "llvm/TargetParser/Host.h" #include "llvm/TargetParser/Triple.h" #include "xla/client/client_library.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/service/cpu/cpu_compiler.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/types.h" diff --git a/third_party/xla/xla/tests/local_client_execute_test.cc b/third_party/xla/xla/tests/local_client_execute_test.cc index b4e4a167a4d07d..829bfc31fb3449 100644 --- a/third_party/xla/xla/tests/local_client_execute_test.cc +++ b/third_party/xla/xla/tests/local_client_execute_test.cc @@ -17,11 +17,12 @@ limitations under the License. #include #include +#include #include "absl/status/statusor.h" #include "xla/client/client_library.h" #include "xla/client/local_client.h" -#include "xla/client/sharding_builder.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/sharding_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/service/platform_util.h" @@ -1018,6 +1019,53 @@ XLA_TEST_F(LocalClientExecuteTest, ValidateUseShardyPartitioner) { EXPECT_EQ(proto.config().use_shardy_partitioner(), true); } +XLA_TEST_F(LocalClientExecuteTest, ValidateExecTimeOptimizationEffort) { + XlaBuilder builder(TestName()); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); + auto y = ConstantR1(&builder, {2.0f, 3.0f, 4.0f}); + Add(x, y); + Shape argument_layout = + local_client_->backend().compiler()->DefaultDeviceShapeRepresentation( + ShapeUtil::MakeShapeWithDenseLayout(F32, /*dimensions=*/{3}, {0})); + + ExecutableBuildOptions build_options; + build_options.set_exec_time_optimization_effort(-1.5f); + TF_ASSERT_OK_AND_ASSIGN( + auto executables, + local_client_->Compile(builder.Build().value(), {&argument_layout}, + build_options)); + EXPECT_EQ(1, executables.size()); + const HloModule& compiled_module = + executables.front()->executable()->module(); + EXPECT_FLOAT_EQ(compiled_module.config().exec_time_optimization_effort(), + -1.5f); + auto proto = compiled_module.ToProtoWithConfig(); + EXPECT_FLOAT_EQ(proto.config().exec_time_optimization_effort(), -1.5f); +} + +XLA_TEST_F(LocalClientExecuteTest, ValidateMemoryFittingEffort) { + XlaBuilder builder(TestName()); + auto x = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3}), "x"); + auto y = ConstantR1(&builder, {2.0f, 3.0f, 4.0f}); + Add(x, y); + Shape argument_layout = + local_client_->backend().compiler()->DefaultDeviceShapeRepresentation( + ShapeUtil::MakeShapeWithDenseLayout(F32, /*dimensions=*/{3}, {0})); + + ExecutableBuildOptions build_options; + build_options.set_memory_fitting_effort(2.0f); + TF_ASSERT_OK_AND_ASSIGN( + auto executables, + local_client_->Compile(builder.Build().value(), {&argument_layout}, + build_options)); + EXPECT_EQ(1, executables.size()); + const HloModule& compiled_module = + executables.front()->executable()->module(); + EXPECT_FLOAT_EQ(compiled_module.config().memory_fitting_effort(), 2.0f); + auto proto = compiled_module.ToProtoWithConfig(); + EXPECT_FLOAT_EQ(proto.config().memory_fitting_effort(), 2.0f); +} + BENCHMARK(BM_LocalClientOverhead); } // namespace diff --git a/third_party/xla/xla/tests/local_client_test_base.cc b/third_party/xla/xla/tests/local_client_test_base.cc index aeac00409e7009..0f4750132889ba 100644 --- a/third_party/xla/xla/tests/local_client_test_base.cc +++ b/third_party/xla/xla/tests/local_client_test_base.cc @@ -23,10 +23,10 @@ limitations under the License. #include "absl/strings/string_view.h" #include "unsupported/Eigen/CXX11/Tensor" #include "xla/client/local_client.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/map_util.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" diff --git a/third_party/xla/xla/tests/local_client_test_base.h b/third_party/xla/xla/tests/local_client_test_base.h index 87cfb3756e0179..dfe45beb735b89 100644 --- a/third_party/xla/xla/tests/local_client_test_base.h +++ b/third_party/xla/xla/tests/local_client_test_base.h @@ -26,7 +26,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/client/client_library.h" #include "xla/client/local_client.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/service/hlo_module_config.h" #include "xla/service/local_service.h" #include "xla/service/platform_util.h" diff --git a/third_party/xla/xla/tests/log_test.cc b/third_party/xla/xla/tests/log_test.cc index f5bc530f0c1c82..114a00ee387682 100644 --- a/third_party/xla/xla/tests/log_test.cc +++ b/third_party/xla/xla/tests/log_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/map_test.cc b/third_party/xla/xla/tests/map_test.cc index cce561c75dfeb2..6d654a74a06656 100644 --- a/third_party/xla/xla/tests/map_test.cc +++ b/third_party/xla/xla/tests/map_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/array2d.h" #include "xla/client/global_data.h" -#include "xla/client/lib/arithmetic.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/literal.h" #include "xla/shape_util.h" #include "xla/stream_executor/stream_executor.h" diff --git a/third_party/xla/xla/tests/matmul_test.cc b/third_party/xla/xla/tests/matmul_test.cc index 668fa32425391c..1ed47869346ad0 100644 --- a/third_party/xla/xla/tests/matmul_test.cc +++ b/third_party/xla/xla/tests/matmul_test.cc @@ -28,7 +28,7 @@ namespace { class MatmulTestWithCublas : public HloTestBase, public ::testing::WithParamInterface { public: - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { auto debug_options = HloTestBase::GetDebugOptionsForTest(); debug_options.set_xla_gpu_enable_cublaslt(use_cublas_lt_); return debug_options; diff --git a/third_party/xla/xla/tests/matrix_ops_simple_test.cc b/third_party/xla/xla/tests/matrix_ops_simple_test.cc index 471595c4aba98f..65bad8ae68fe38 100644 --- a/third_party/xla/xla/tests/matrix_ops_simple_test.cc +++ b/third_party/xla/xla/tests/matrix_ops_simple_test.cc @@ -27,8 +27,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/array2d.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/literal.h" #include "xla/reference_util.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/tests/multidimensional_slice_test.cc b/third_party/xla/xla/tests/multidimensional_slice_test.cc index 1c15cd2ac94fc3..0b89cbee3341f6 100644 --- a/third_party/xla/xla/tests/multidimensional_slice_test.cc +++ b/third_party/xla/xla/tests/multidimensional_slice_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/multioutput_fusion_test.cc b/third_party/xla/xla/tests/multioutput_fusion_test.cc index 97ee8b70575426..829124c8da981b 100644 --- a/third_party/xla/xla/tests/multioutput_fusion_test.cc +++ b/third_party/xla/xla/tests/multioutput_fusion_test.cc @@ -52,7 +52,7 @@ class MultiOutputFusionTest : public HloTestBase { // Layout assignment assumes that there are no fusions in the input graph. // Since the purpose of this test is to send pre-fused graphs to XLA, we have // to do layout assignment ourselves. - DebugOptions GetDebugOptionsForTest() override { + DebugOptions GetDebugOptionsForTest() const override { auto opts = HloTestBase::GetDebugOptionsForTest(); opts.add_xla_disable_hlo_passes("layout-assignment"); return opts; diff --git a/third_party/xla/xla/tests/multiple_devices_on_host_test.cc b/third_party/xla/xla/tests/multiple_devices_on_host_test.cc index 8aa1502a3a951d..01c87236e9a768 100644 --- a/third_party/xla/xla/tests/multiple_devices_on_host_test.cc +++ b/third_party/xla/xla/tests/multiple_devices_on_host_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "xla/client/client_library.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape_util.h" #include "xla/stream_executor/platform_manager.h" #include "xla/tsl/lib/core/status_test_util.h" diff --git a/third_party/xla/xla/tests/multithreaded_compilation_test.cc b/third_party/xla/xla/tests/multithreaded_compilation_test.cc index b9cb8d253cb511..1e5f138389a289 100644 --- a/third_party/xla/xla/tests/multithreaded_compilation_test.cc +++ b/third_party/xla/xla/tests/multithreaded_compilation_test.cc @@ -70,9 +70,8 @@ XLA_TEST_F(MultithreadedCompilation, EightModuleCompilation) { absl::Mutex mu; std::vector> executables; auto do_compilation = [&](int iteration) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - test_runner_.CreateExecutable(std::move(modules[iteration]), true)); + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + CreateExecutable(std::move(modules[iteration]), true)); absl::MutexLock lock(&mu); executables.push_back(std::move(executable)); VLOG(2) << "Adding executable obtained from thread: " << iteration; diff --git a/third_party/xla/xla/tests/new_hlo_test_base.cc b/third_party/xla/xla/tests/new_hlo_test_base.cc new file mode 100644 index 00000000000000..8a1eba325adbdc --- /dev/null +++ b/third_party/xla/xla/tests/new_hlo_test_base.cc @@ -0,0 +1,816 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/tests/new_hlo_test_base.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/nullability.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" +#include "xla/debug_options_flags.h" +#include "xla/error_spec.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module_group.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/verified_hlo_module.h" +#include "xla/hlo/utils/hlo_query.h" +#include "xla/literal.h" +#include "xla/service/computation_placer.h" +#include "xla/service/executable.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/hlo_module_util.h" +#include "xla/service/hlo_runner_interface.h" +#include "xla/service/hlo_verifier.h" +#include "xla/shape.h" +#include "xla/tests/filecheck.h" +#include "xla/tests/literal_test_util.h" +#include "xla/tests/test_utils.h" +#include "xla/tests/verified_hlo_module.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { + +namespace { + +bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs) { + if (lhs.parameters_size() != rhs.parameters_size()) { + return false; + } + for (int i = 0; i < lhs.parameters_size(); i++) { + if (!Shape::Equal().IgnoreElementSizeInLayout()(lhs.parameters(i), + rhs.parameters(i))) { + return false; + } + } + return Shape::Equal().IgnoreElementSizeInLayout()(lhs.result(), rhs.result()); +} + +ProgramShape GetProgramShapeWithLayout(const HloModule& module) { + ProgramShape program_shape; + const auto* entry = module.entry_computation(); + for (const auto* param : entry->parameter_instructions()) { + *program_shape.add_parameters() = param->shape(); + *program_shape.add_parameter_names() = param->name(); + } + *program_shape.mutable_result() = entry->root_instruction()->shape(); + return program_shape; +} + +} // namespace + +NewHloTestBase::NewHloTestBase( + absl::Nonnull> test_runner, + absl::Nonnull> reference_runner, + const bool verifier_layout_sensitive, + const bool allow_mixed_precision_in_hlo_verifier, + const HloPredicate instruction_can_change_layout_func) + : HloHardwareIndependentTestBase(verifier_layout_sensitive, + allow_mixed_precision_in_hlo_verifier, + instruction_can_change_layout_func), + test_runner_(std::move(test_runner)), + reference_runner_(std::move(reference_runner)) {} + +std::unique_ptr NewHloTestBase::CreateNewVerifiedModule( + const std::string& name, const int64_t replica_count) { + return std::make_unique( + name, GetModuleConfigForTest(replica_count), verifier_layout_sensitive(), + allow_mixed_precision_in_hlo_verifier(), + test_runner_->device_shape_size_fn(), + instruction_can_change_layout_func()); +} + +absl::StatusOr> +NewHloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text, + int64_t replica_count, + int64_t num_partitions) { + return ParseAndReturnVerifiedModule( + hlo_text, GetModuleConfigForTest(replica_count, num_partitions)); +} + +absl::StatusOr> +NewHloTestBase::ParseAndReturnVerifiedModule(absl::string_view hlo_text, + const HloModuleConfig& config) { + auto module = std::make_unique( + TestName(), config, verifier_layout_sensitive(), + allow_mixed_precision_in_hlo_verifier(), + test_runner_->device_shape_size_fn(), + instruction_can_change_layout_func()); + TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text)); + UpdateEntryComputationLayout(module.get()); + return std::move(module); +} + +HloComputation* +NewHloTestBase::AddEntryComputationAndUpdateEntryComputationLayout( + HloModule* const module, std::unique_ptr computation) { + HloComputation* const comp = + module->AddEntryComputation(std::move(computation)); + UpdateEntryComputationLayout(module); + return comp; +} + +void NewHloTestBase::UpdateEntryComputationLayout( + HloModule* const module) const { + xla::UpdateEntryComputationLayout( + module, test_runner_->device_shape_representation_fn()); +} + +absl::StatusOr NewHloTestBase::Execute( + std::unique_ptr module, absl::Span arguments, + bool run_hlo_passes) { + return test_runner_->Execute(std::move(module), arguments, run_hlo_passes); +} + +Literal NewHloTestBase::ExecuteNoHloPasses( + std::unique_ptr module, absl::Span arguments) { + absl::StatusOr result = Execute(std::move(module), arguments, + /*run_hlo_passes=*/false); + CHECK_OK(result.status()); + return *std::move(result); +} + +Literal NewHloTestBase::ExecuteAndTransfer( + std::unique_ptr module, absl::Span arguments) { + absl::StatusOr result = + test_runner_->Execute(std::move(module), arguments, true, nullptr); + CHECK_OK(result.status()); + return *std::move(result); +} + +absl::StatusOr> NewHloTestBase::ExecuteReplicated( + std::unique_ptr module, + const absl::Span arguments, const int64_t num_replicas, + const bool use_threads, const bool run_hlo_passes) { + HloRunnerInterface::ReplicatedExecuteOptions options; + options.num_replicas = num_replicas; + options.arguments = {arguments.begin(), arguments.end()}; + options.run_hlo_passes = run_hlo_passes; + options.use_threads = use_threads; + return test_runner_->ExecuteReplicated(std::move(module), std::move(options)); +} + +absl::StatusOr> NewHloTestBase::ExecuteReplicated( + std::unique_ptr module, + const absl::Span arguments, const int64_t num_replicas, + DeviceAssignment* const device_assignment, const bool run_hlo_passes, + const bool use_threads) { + HloRunnerInterface::ReplicatedExecuteOptions options; + options.num_replicas = num_replicas; + options.arguments = {arguments.begin(), arguments.end()}; + options.run_hlo_passes = run_hlo_passes; + options.use_threads = use_threads; + return test_runner_->ExecuteReplicated(std::move(module), std::move(options), + device_assignment); +} + +absl::StatusOr> NewHloTestBase::ExecuteReplicated( + const std::function executable_provider, + const std::function argument_count_provider, + const std::function argument_provider, + const int64_t num_replicas, const bool run_hlo_passes, + DeviceAssignment* const device_assignment) { + HloRunnerInterface::ReplicatedExecuteOptions options; + options.num_replicas = num_replicas; + options.run_hlo_passes = run_hlo_passes; + options.use_threads = true; + return test_runner_->ExecuteReplicated( + executable_provider, argument_count_provider, argument_provider, + std::move(options), device_assignment); +} + +absl::StatusOr> NewHloTestBase::ExecuteReplicated( + std::unique_ptr module, + const std::vector> arguments, + const int64_t num_replicas, const bool run_hlo_passes, + DeviceAssignment* const device_assignment) { + CHECK(num_replicas > 0 && "expect at least one replica"); + CHECK(num_replicas == arguments.size() && + "expect arguments for each replica"); + int64_t argument_count = arguments.front().size(); + TF_ASSIGN_OR_RETURN( + const std::unique_ptr executable, + test_runner_->CreateExecutable(std::move(module), run_hlo_passes)); + return ExecuteReplicated( + /*executable_provider=*/[&](int64_t) { return executable.get(); }, + /*argument_count_provider=*/[&](int64_t) { return argument_count; }, + /*argument_provider=*/ + [&](int64_t replica_idx, int64_t argument_idx) -> const Literal* { + return arguments[replica_idx][argument_idx]; + }, + num_replicas, /*run_hlo_passes=*/run_hlo_passes, + /*device_assignment=*/device_assignment); +} + +::testing::AssertionResult NewHloTestBase::RunAndCompare( + std::unique_ptr module, absl::Span arguments, + const std::optional& error, + const std::function& reference_preprocessor, + const std::function& test_preprocessor) { + const absl::StatusOr<::testing::AssertionResult> result = + RunAndCompareInternal(std::move(module), arguments, error, + /*run_hlo_passes=*/true, reference_preprocessor, + test_preprocessor); + if (!result.ok()) { + return ::testing::AssertionFailure() << result.status(); + } + return *result; +} + +::testing::AssertionResult NewHloTestBase::RunAndCompareNoHloPasses( + std::unique_ptr module, + const absl::Span arguments, + const std::optional& error, + const std::function& reference_preprocessor, + const std::function& test_preprocessor) { + const absl::StatusOr<::testing::AssertionResult> result = + RunAndCompareInternal(std::move(module), arguments, error, + /*run_hlo_passes=*/false, reference_preprocessor, + test_preprocessor); + if (!result.ok()) { + return ::testing::AssertionFailure() << result.status(); + } + return *result; +} + +::testing::AssertionResult NewHloTestBase::RunAndCompare( + std::unique_ptr module, const std::optional& error, + const std::function& reference_preprocessor, + const std::function& test_preprocessor, + const std::optional args_max_bits_of_precision) { + const std::vector fake_arguments = + MakeFakeArguments(module.get(), /*pseudo_random=*/true, + /*use_large_range=*/false, + /*treat_gte_as_data_formatting=*/false, + args_max_bits_of_precision) + .value(); + std::vector fake_argument_ptrs; + absl::c_transform( + fake_arguments, std::back_inserter(fake_argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + + return RunAndCompare(std::move(module), fake_argument_ptrs, error, + reference_preprocessor, test_preprocessor); +} + +::testing::AssertionResult NewHloTestBase::RunAndCompareNoHloPasses( + std::unique_ptr module, const std::optional& error, + const std::function& reference_preprocessor, + const std::function& test_preprocessor) { + const std::vector fake_arguments = + MakeFakeArguments(module.get()).value(); + std::vector fake_argument_ptrs; + absl::c_transform( + fake_arguments, std::back_inserter(fake_argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error, + reference_preprocessor, test_preprocessor); +} + +::testing::AssertionResult NewHloTestBase::Run( + std::unique_ptr module, const bool run_hlo_passes, + const std::function& test_preprocessor) { + const std::vector fake_arguments = + MakeFakeArguments(module.get()).value(); + if (const absl::StatusOr change = verifier().Run(module.get()); + !change.ok()) { + return ::testing::AssertionFailure() << change.status(); + } + if (test_preprocessor != nullptr) { + test_preprocessor(module.get()); + } + + const absl::StatusOr output = + test_runner_->Execute(std::move(module), fake_arguments, run_hlo_passes); + return output.ok() + ? ::testing::AssertionSuccess() + : ::testing::AssertionFailure() << output.status().message(); +} + +::testing::AssertionResult NewHloTestBase::RunAndCompare( + const absl::string_view hlo_string, const std::optional& error, + const std::function& reference_preprocessor, + const std::function& test_preprocessor, + const std::optional args_max_bits_of_precision) { + absl::StatusOr> module = + ParseAndReturnVerifiedModule(hlo_string); + if (!module.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module.status().ToString(); + } + return RunAndCompare(*std::move(module), error, reference_preprocessor, + test_preprocessor, args_max_bits_of_precision); +} + +::testing::AssertionResult NewHloTestBase::RunAndCompareTwoModulesReplicated( + std::unique_ptr module_0, std::unique_ptr module_1, + const HloRunnerInterface::ReplicatedExecuteOptions options, + const std::optional& error) { + const int replica_count = module_0->config().replica_count(); + if (replica_count != module_1->config().replica_count()) { + return ::testing::AssertionFailure() + << "Number of replicas is not the same: " << replica_count << " Vs " + << module_1->config().replica_count(); + } + if (options.num_replicas != replica_count) { + return ::testing::AssertionFailure() + << "Number of execution replicas is different from number of " + "replicas in the module: requested number of replicas = " + << options.num_replicas + << ", number of replicas in hlo = " << replica_count; + } + + if (const std::vector mismatches = CompareInputs(*module_0, *module_1); + !mismatches.empty()) { + return ::testing::AssertionFailure() + << "Error: parameter mismatch at indices: " + << absl::StrJoin(mismatches, ","); + } + if (const int64_t num_args = module_0->entry_computation()->num_parameters(); + num_args != options.arguments.size()) { + return ::testing::AssertionFailure() + << "Mismatch in number of arguments passed while running replicated " + "hlo module. Expected: " + << num_args << ", actual: " << options.arguments.size(); + } + const absl::StatusOr<::testing::AssertionResult> result = + RunAndCompareTwoModulesInternalReplicated( + std::move(module_0), std::move(module_1), options, error); + if (!result.ok()) { + return ::testing::AssertionFailure() << result.status(); + } + return *result; +} + +::testing::AssertionResult NewHloTestBase::RunAndCompareTwoModulesReplicated( + std::unique_ptr module_0, std::unique_ptr module_1, + const bool run_hlo_passes, const bool use_threads, + const std::optional& error) { + const absl::StatusOr> fake_arguments = MakeFakeArguments( + /*module=*/module_0.get(), /*pseudo_random=*/true, + /*use_large_range=*/false, + /*treat_gte_as_data_formatting=*/false, + /*max_bits_of_precision=*/std::nullopt); + CHECK_OK(fake_arguments); + std::vector fake_argument_ptrs; + absl::c_transform( + /*input=*/*fake_arguments, + /*output=*/std::back_inserter(fake_argument_ptrs), + /*unary_op=*/[](const Literal& literal) -> Literal* { + return const_cast(&literal); + }); + const HloRunnerInterface::ReplicatedExecuteOptions options{ + /*num_replicas=*/module_0->config().replica_count(), + /*arguments=*/fake_argument_ptrs, + /*infeed_values=*/{}, + /*infeed_steps=*/-1, + /*outfeed_shape=*/{}, + /*outfeed_values=*/nullptr, + /*run_hlo_passes=*/run_hlo_passes, + /*use_threads=*/use_threads}; + return RunAndCompareTwoModulesReplicated(std::move(module_0), + std::move(module_1), options, error); +} + +::testing::AssertionResult NewHloTestBase::RunAndCompareTwoModulesReplicated( + const absl::string_view module_0_str, const absl::string_view module_1_str, + const bool run_hlo_passes, const bool use_threads, + const std::optional& error) { + absl::StatusOr> module_0 = + ParseAndReturnVerifiedModule(module_0_str); + if (!module_0.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_0.status().ToString(); + } + + absl::StatusOr> module_1 = + ParseAndReturnVerifiedModule(module_1_str); + if (!module_1.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_1.status().ToString(); + } + return RunAndCompareTwoModulesReplicated(*std::move(module_0), + *std::move(module_1), run_hlo_passes, + use_threads, error); +} + +::testing::AssertionResult NewHloTestBase::RunAndCompareTwoModules( + std::unique_ptr module_0, std::unique_ptr module_1, + const absl::Span arguments, + const std::optional& error, bool run_hlo_passes) { + const absl::StatusOr<::testing::AssertionResult> result = + RunAndCompareTwoModulesInternal(std::move(module_0), std::move(module_1), + arguments, error, run_hlo_passes); + if (!result.ok()) { + return ::testing::AssertionFailure() << result.status(); + } + return *result; +} + +::testing::AssertionResult NewHloTestBase::RunAndCompareTwoModules( + std::unique_ptr module_0, std::unique_ptr module_1, + const std::optional& error, const bool run_hlo_passes, + const std::optional args_max_bits_of_precision) { + if (const std::vector mismatches = CompareInputs(*module_0, *module_1); + !mismatches.empty()) { + return ::testing::AssertionFailure() + << "Error : mismatching parameter shapes for parameters " + << absl::StrJoin(mismatches, ", "); + } + + const absl::StatusOr> fake_arguments = MakeFakeArguments( + module_0.get(), /*pseudo_random=*/true, /*use_large_range=*/false, + /*treat_gte_as_data_formatting=*/false, args_max_bits_of_precision); + CHECK_OK(fake_arguments); + + std::vector fake_argument_ptrs; + absl::c_transform( + *fake_arguments, std::back_inserter(fake_argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + + return RunAndCompareTwoModules(std::move(module_0), std::move(module_1), + fake_argument_ptrs, error, run_hlo_passes); +} + +::testing::AssertionResult NewHloTestBase::RunAndCompareTwoModules( + const absl::string_view hlo_string_module_0, + const absl::string_view hlo_string_module_1, + const std::optional& error, const bool run_hlo_passes, + const std::optional args_max_bits_of_precision) { + absl::StatusOr> module_0 = + ParseAndReturnVerifiedModule(hlo_string_module_0); + if (!module_0.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_0.status().ToString(); + } + + absl::StatusOr> module_1 = + ParseAndReturnVerifiedModule(hlo_string_module_1); + if (!module_1.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_1.status().ToString(); + } + return RunAndCompareTwoModules(*std::move(module_0), *std::move(module_1), + error, run_hlo_passes, + args_max_bits_of_precision); +} + +::testing::AssertionResult NewHloTestBase::RunAndCompareTwoModules( + const absl::string_view hlo_string_module_0, + const absl::string_view hlo_string_module_1, + const HloModuleConfig& config_0, const HloModuleConfig& config_1, + const std::optional& error, const bool run_hlo_passes, + const std::optional args_max_bits_of_precision) { + absl::StatusOr> module_0 = + ParseAndReturnVerifiedModule(hlo_string_module_0, config_0); + if (!module_0.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_0.status().ToString(); + } + + absl::StatusOr> module_1 = + ParseAndReturnVerifiedModule(hlo_string_module_1, config_1); + if (!module_1.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_1.status().ToString(); + } + return RunAndCompareTwoModules(*std::move(module_0), *std::move(module_1), + error, run_hlo_passes, + args_max_bits_of_precision); +} + +::testing::AssertionResult NewHloTestBase::RunAndCompareTwoModules( + absl::string_view hlo_string_module_0, + absl::string_view hlo_string_module_1, + const absl::Span arguments, + const std::optional& error, const bool run_hlo_passes) { + auto module_0_or_status = ParseAndReturnVerifiedModule(hlo_string_module_0); + if (!module_0_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_0_or_status.status().ToString(); + } + + auto module_1_or_status = ParseAndReturnVerifiedModule(hlo_string_module_1); + if (!module_1_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_1_or_status.status().ToString(); + } + return RunAndCompareTwoModules(std::move(module_0_or_status).value(), + std::move(module_1_or_status).value(), + arguments, error, run_hlo_passes); +} + +::testing::AssertionResult NewHloTestBase::Run( + const absl::string_view hlo_string, const bool run_hlo_passes, + ExecutionProfile* const profile, + const tsl::protobuf::Message* backend_config, const bool use_random_data) { + absl::StatusOr> module = + ParseAndReturnVerifiedModule(hlo_string); + if (!module.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module.status().ToString(); + } + const std::vector fake_arguments = + MakeFakeArguments(module->get(), use_random_data).value(); + std::vector fake_argument_ptrs; + absl::c_transform( + fake_arguments, std::back_inserter(fake_argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + + if (profile != nullptr) { + // We have to enable HLO profiling since otherwise currently the + // ExecutionProfile is not correct. + // + // TODO(b/119432044): Fix collection of the ExecutionProfile + // so that this is not necessary. + HloModuleConfig config = (*module)->config(); + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_hlo_profile(true); + config.set_debug_options(debug_options); + (*module)->set_config(config); + } + + if (backend_config) { + // Set backend configuration if it is given. + HloInstruction* instruction = + (*module)->entry_computation()->root_instruction(); + absl::Status s = instruction->set_backend_config(*backend_config); + return s.ok() ? ::testing::AssertionSuccess() + : ::testing::AssertionFailure() << s.message(); + } + + auto output = test_runner_->Execute(*std::move(module), fake_argument_ptrs, + /*run_hlo_passes=*/run_hlo_passes, + /*profile=*/profile); + + return output.ok() + ? ::testing::AssertionSuccess() + : ::testing::AssertionFailure() << output.status().message(); +} + +::testing::AssertionResult NewHloTestBase::RunReplicated( + const absl::string_view hlo_string, const bool run_hlo_passes, + const int64_t num_replicas, const tsl::protobuf::Message* backend_config) { + absl::StatusOr> module = + ParseAndReturnVerifiedModule(hlo_string, num_replicas); + if (!module.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module.status().ToString(); + } + + const std::vector fake_arguments = + MakeFakeArguments(module->get()).value(); + std::vector fake_argument_ptrs; + absl::c_transform( + fake_arguments, std::back_inserter(fake_argument_ptrs), + [](const Literal& literal) { return const_cast(&literal); }); + + if (backend_config) { + // Set backend configuration if it is given. + HloInstruction* instruction = + (*module)->entry_computation()->root_instruction(); + if (const absl::Status s = instruction->set_backend_config(*backend_config); + !s.ok()) { + return ::testing::AssertionFailure() << s.message(); + } + return ::testing::AssertionSuccess(); + } + + HloRunnerInterface::ReplicatedExecuteOptions options; + options.num_replicas = num_replicas; + options.arguments = {fake_argument_ptrs.begin(), fake_argument_ptrs.end()}; + options.run_hlo_passes = run_hlo_passes; + options.use_threads = true; + const absl::StatusOr> output = + test_runner_->ExecuteReplicated(*std::move(module), std::move(options)); + if (output.ok()) { + return ::testing::AssertionSuccess(); + } + return ::testing::AssertionFailure() << output.status().message(); +} + +::testing::AssertionResult NewHloTestBase::RunMultipleTimes( + const absl::string_view hlo_string, const bool run_hlo_passes, + std::vector* const profiles, + const tsl::protobuf::Message* const backend_config, + const bool assert_determinism) { + const int n = profiles->size(); + std::vector> fake_argument_ptrs(n); + std::vector> fake_arguments(n); + std::vector> executables(n); + + for (int i = 0; i < n; ++i) { + absl::StatusOr> module = + ParseAndReturnVerifiedModule(hlo_string); + if (!module.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module.status().ToString(); + } + + fake_arguments[i] = MakeFakeArguments(module->get()).value(); + + if (profiles != nullptr) { + // We have to enable HLO profiling since otherwise currently the + // ExecutionProfile is not correct. + // + // TODO(b/119432044): Fix collection of the ExecutionProfile + // so that this is not necessary. + HloModuleConfig config = (*module)->config(); + DebugOptions debug_options = config.debug_options(); + debug_options.set_xla_hlo_profile(true); + config.set_debug_options(debug_options); + (*module)->set_config(config); + } + + if (backend_config) { + // Set backend configuration if it is given. + HloInstruction* instruction = + (*module)->entry_computation()->root_instruction(); + absl::Status s = instruction->set_backend_config(*backend_config); + return s.ok() ? ::testing::AssertionSuccess() + : ::testing::AssertionFailure() << s.message(); + } + + absl::StatusOr> executable = + test_runner_->CreateExecutable(*std::move(module), run_hlo_passes); + if (!executable.ok()) { + return ::testing::AssertionFailure() << executable.status().message(); + } + executables[i] = *std::move(executable); + } + + std::optional canonical_output; + for (int i = 0; i < n; ++i) { + absl::StatusOr output = test_runner_->ExecuteWithExecutable( + executables[i].get(), fake_arguments[i], + /*profile=*/&((*profiles)[i])); + if (!output.ok()) { + return ::testing::AssertionFailure() << output.status().message(); + } + + if (assert_determinism) { + if (!canonical_output.has_value()) { + canonical_output = *std::move(output); + } else { + if (*canonical_output != *output) { + return ::testing::AssertionFailure() + << "Successive runs have returned different results: " + << *canonical_output << " vs. " << *output; + } + } + } + } + + return ::testing::AssertionSuccess(); +} + +::testing::AssertionResult NewHloTestBase::RunAndCompareNoHloPasses( + const absl::string_view hlo_string, const std::optional& error, + const std::function& reference_preprocessor, + const std::function& test_preprocessor) { + absl::StatusOr> module = + ParseAndReturnVerifiedModule(hlo_string); + if (!module.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module.status().ToString(); + } + return RunAndCompareNoHloPasses(*std::move(module), error, + reference_preprocessor, test_preprocessor); +} + +absl::StatusOr> NewHloTestBase::MakeReferenceModule( + const HloModule& test_module, + const std::function& reference_preprocessor) { + std::unique_ptr reference_module = test_module.Clone(); + const ProgramShape program_shape = GetProgramShapeWithLayout(test_module); + + if (reference_preprocessor != nullptr) { + reference_preprocessor(reference_module.get()); + if (!ProgramShapesEqual(program_shape, + GetProgramShapeWithLayout(*reference_module))) { + return InvalidArgument( + "reference preprocessor must not modify the program shape"); + } + } + TF_RETURN_IF_ERROR(verifier().Run(reference_module.get()).status()); + return std::move(reference_module); +} + +absl::StatusOr<::testing::AssertionResult> +NewHloTestBase::RunAndCompareInternal( + std::unique_ptr module, + const absl::Span arguments, + const std::optional& error, const bool run_hlo_passes, + const std::function& reference_preprocessor, + const std::function& test_preprocessor) { + TF_RETURN_IF_ERROR(verifier().Run(module.get()).status()); + TF_ASSIGN_OR_RETURN(std::unique_ptr reference_module, + MakeReferenceModule(*module, reference_preprocessor)); + if (test_preprocessor != nullptr) { + test_preprocessor(module.get()); + } + // Execute on two backends. + TF_ASSIGN_OR_RETURN( + const Literal test, + test_runner_->Execute(std::move(module), arguments, run_hlo_passes)); + TF_ASSIGN_OR_RETURN(const Literal reference, + reference_runner_->Execute(std::move(reference_module), + arguments, run_hlo_passes)); + if (reference.IsAll(0)) { + LOG(WARNING) << "Reference value is only zeros."; + } + + return LiteralTestUtil::NearOrEqual(/*expected=*/reference, /*actual=*/test, + error); +} + +absl::StatusOr<::testing::AssertionResult> +NewHloTestBase::RunAndCompareTwoModulesInternalReplicated( + std::unique_ptr module_0, std::unique_ptr module_1, + const HloRunnerInterface::ReplicatedExecuteOptions options, + const std::optional& error) { + TF_RETURN_IF_ERROR(verifier().Run(module_0.get()).status()); + TF_RETURN_IF_ERROR(verifier().Run(module_1.get()).status()); + + // Execute the two modules. + TF_ASSIGN_OR_RETURN(auto test_0, test_runner_->ExecuteReplicated( + std::move(module_0), options)); + TF_ASSIGN_OR_RETURN(auto test_1, test_runner_->ExecuteReplicated( + std::move(module_1), options)); + + for (const auto& [expected, actual] : llvm::zip_equal(test_0, test_1)) { + if (::testing::AssertionResult result = + LiteralTestUtil::NearOrEqual(expected, actual, error); + !result) { + return result; + } + } + return ::testing::AssertionSuccess(); +} + +absl::StatusOr<::testing::AssertionResult> +NewHloTestBase::RunAndCompareTwoModulesInternal( + std::unique_ptr module_0, std::unique_ptr module_1, + const absl::Span arguments, + const std::optional& error, bool run_hlo_passes) { + TF_RETURN_IF_ERROR(verifier().Run(module_0.get()).status()); + TF_RETURN_IF_ERROR(verifier().Run(module_1.get()).status()); + + // Execute the two modules. + TF_ASSIGN_OR_RETURN( + const Literal test_0, + test_runner_->Execute(std::move(module_0), arguments, run_hlo_passes)); + TF_ASSIGN_OR_RETURN( + const Literal test_1, + test_runner_->Execute(std::move(module_1), arguments, run_hlo_passes)); + + return LiteralTestUtil::NearOrEqual(/*expected=*/test_0, /*actual=*/test_1, + error); +} + +} // namespace xla diff --git a/third_party/xla/xla/tests/new_hlo_test_base.h b/third_party/xla/xla/tests/new_hlo_test_base.h new file mode 100644 index 00000000000000..0c8d003f9d3d1a --- /dev/null +++ b/third_party/xla/xla/tests/new_hlo_test_base.h @@ -0,0 +1,355 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TESTS_NEW_HLO_TEST_BASE_H_ +#define XLA_TESTS_NEW_HLO_TEST_BASE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/error_spec.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_module_group.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" +#include "xla/hlo/testlib/verified_hlo_module.h" +#include "xla/layout.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/service/backend.h" +#include "xla/service/computation_layout.h" +#include "xla/service/computation_placer.h" +#include "xla/service/executable.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/hlo_runner.h" +#include "xla/service/hlo_runner_interface.h" +#include "xla/service/hlo_verifier.h" +#include "xla/service/platform_util.h" +#include "xla/shape_layout.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/test_helpers.h" +#include "xla/tests/literal_test_util.h" +#include "xla/tests/verified_hlo_module.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/test.h" + +namespace xla { + +// A base class for tests which build and/or run HLO code. The class includes +// support for running an HLO module on two platforms and compare the results. +// This is a lower level of abstraction than using the client interface and +// enables, for one, explicitly building a graph of HLO instructions to run. +// +// This can also be used to write text/file-based test cases. Note that the test +// target is responsible for linking the needed backends. A convenient way to do +// this is to make it an xla_test: it will generate test targets linking with +// the respective backends, which will be used as the test backend; the +// interpreter backend is already linked with hlo_test_base so it will be the +// default reference backend. For example, if you want to compare both cpu vs. +// interpreter, and gpu vs. interpreter, you can: +// +// xla_test ( +// name = "sample_text_test", +// srcs = ["sample_text_test.cc"], +// backends = [ +// "cpu", +// "gpu", +// ], +// deps = [ +// "//xla/tests:new_hlo_test_base", +// ... +// ], +// ) +// +// Unlike HloTestBase, which relies on StreamExecutor via HloRunner, this class +// relies on HloRunnerInterface. HloRunnerInterface supports HloRunner among +// other implementations. We plan to incrementally migrate tests this class and +// away from HloTestBase. +class NewHloTestBase : public HloHardwareIndependentTestBase { + protected: + explicit NewHloTestBase( + absl::Nonnull> test_runner, + absl::Nonnull> reference_runner, + bool verifier_layout_sensitive = false, + bool allow_mixed_precision_in_hlo_verifier = true, + HloPredicate instruction_can_change_layout_func = {}); + + // Creates a new HLO module for a test. The module created will have + // TestName() for its name; it will also automatically populate its debug + // options from command-line flags. If you want a fresh HloModule object and + // then add HloComputations to it, it's recommended to use this method in your + // tests. + // + // This returns a VerifiedHloModule that runs the HLO verifier on + // destruction. + std::unique_ptr CreateNewVerifiedModule( + const std::string& name = TestName(), int64_t replica_count = 1); + + // Parses the given string and returns module as a VerifiedHloModule. + absl::StatusOr> + ParseAndReturnVerifiedModule(absl::string_view hlo_text, + int64_t replica_count = 1, + int64_t num_partitions = 1); + absl::StatusOr> + ParseAndReturnVerifiedModule(absl::string_view hlo_text, + const HloModuleConfig& config); + + HloComputation* AddEntryComputationAndUpdateEntryComputationLayout( + HloModule*, std::unique_ptr computation); + void UpdateEntryComputationLayout(HloModule* module) const; + + // Executes the given module and return the result as a Literal. + absl::StatusOr Execute(std::unique_ptr module, + absl::Span arguments, + bool run_hlo_passes = true); + + // Same as above, except the module will be executed without running any HLO + // passes on it. + Literal ExecuteNoHloPasses(std::unique_ptr module, + absl::Span arguments); + + Literal ExecuteAndTransfer(std::unique_ptr module, + absl::Span arguments); + + // Compile the given module to an executable. + absl::StatusOr> CreateExecutable( + std::unique_ptr module, bool run_hlo_passes) { + return test_runner_->CreateExecutable(std::move(module), run_hlo_passes); + } + + // Executes the given module on multiple replicas. + // + // use_threads indicates whether this replicated computation will be executed + // with a thread-per-replica, vs using an implicitly async call such as + // Executable::ExecuteOnStreams. + absl::StatusOr> ExecuteReplicated( + std::unique_ptr module, absl::Span arguments, + int64_t num_replicas, bool use_threads, bool run_hlo_passes = false); + + // Same as above, but uses specified device assignment. + absl::StatusOr> ExecuteReplicated( + std::unique_ptr module, absl::Span arguments, + int64_t num_replicas, DeviceAssignment* device_assignment, + bool run_hlo_passes, bool use_threads); + + // Same as above, but allows passing different programs for replicas. + absl::StatusOr> ExecuteReplicated( + std::function executable_provider, + std::function argument_count_provider, + std::function argument_provider, + int64_t num_replicas, bool run_hlo_passes, + DeviceAssignment* device_assignment = nullptr); + + // Convenience function for above. Allows passing different inputs to + // different replicas of the same program. + absl::StatusOr> ExecuteReplicated( + std::unique_ptr module, + std::vector> arguments, int64_t num_replicas, + bool run_hlo_passes, DeviceAssignment* device_assignment = nullptr); + + // Executes the given hlo module on two backends and compares results. + // + // 'arguments': the input of the hlo module. + // + // 'error': if has value, expects the results to be near (within the error + // bound). Otherwise, expects the results to be equal. + // + // 'reference_preprocessor': the module should be ready to run on the test + // backend, but it might need to be tailored so that it is able to run on the + // reference backend. Note that the program shape of the module must not be + // modified. + [[nodiscard]] ::testing::AssertionResult RunAndCompare( + std::unique_ptr module, absl::Span arguments, + const std::optional& error, + const std::function& reference_preprocessor = nullptr, + const std::function& test_preprocessor = nullptr); + + // Same as above, except that the module will be executed without Hlo + // optimization. + [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses( + std::unique_ptr module, absl::Span arguments, + const std::optional& error, + const std::function& reference_preprocessor = nullptr, + const std::function& test_preprocessor = nullptr); + + // Executes an hlo module with fake inputs and compares the results. + [[nodiscard]] ::testing::AssertionResult RunAndCompare( + std::unique_ptr module, const std::optional& error, + const std::function& reference_preprocessor = nullptr, + const std::function& test_preprocessor = nullptr, + std::optional args_max_bits_of_precision = std::nullopt); + + // Same as above, except that the module will be executed without Hlo + // optimization. + [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses( + std::unique_ptr module, const std::optional& error, + const std::function& reference_preprocessor = nullptr, + const std::function& test_preprocessor = nullptr); + + // Executes an hlo module with fake inputs and checks that the execution is + // successful. + [[nodiscard]] ::testing::AssertionResult Run( + std::unique_ptr module, bool run_hlo_passes, + const std::function& test_preprocessor = nullptr); + + // Convenient wrappers for executing and comparing an hlo module with fake + // input. Module can be passed in directly, or parsed from an hlo_string, + // or loaded from a file. + [[nodiscard]] ::testing::AssertionResult RunAndCompare( + absl::string_view hlo_string, const std::optional& error, + const std::function& reference_preprocessor = nullptr, + const std::function& test_preprocessor = nullptr, + std::optional args_max_bits_of_precision = std::nullopt); + [[nodiscard]] ::testing::AssertionResult Run( + absl::string_view hlo_string, bool run_hlo_passes = true, + ExecutionProfile* profile = nullptr, + const tsl::protobuf::Message* backend_config = nullptr, + bool use_random_data = true); + + // Same as below, except that it requires all the options to be passed. + ::testing::AssertionResult RunAndCompareTwoModulesReplicated( + std::unique_ptr module_0, std::unique_ptr module_1, + HloRunnerInterface::ReplicatedExecuteOptions options, + const std::optional& error); + + // Same as below, except that it requires the parsed modules to be passed. + ::testing::AssertionResult RunAndCompareTwoModulesReplicated( + std::unique_ptr module_0, std::unique_ptr module_1, + bool run_hlo_passes, bool use_threads, + const std::optional& error); + + // Parses the modules, and executes them based on `run_hlo_passes` and + // `use_threads` flags. The replica count should be mentioned in the module + // itself. + ::testing::AssertionResult RunAndCompareTwoModulesReplicated( + absl::string_view module_0_str, absl::string_view module_1_str, + bool run_hlo_passes, bool use_threads, + const std::optional& error); + + // Same as below, except requires passing fake arguments. + ::testing::AssertionResult RunAndCompareTwoModules( + std::unique_ptr module_0, std::unique_ptr module_1, + absl::Span arguments, + const std::optional& error, bool run_hlo_passes = true); + + // Same as below, except requires passing the modules. + ::testing::AssertionResult RunAndCompareTwoModules( + std::unique_ptr module_0, std::unique_ptr module_1, + const std::optional& error, bool run_hlo_passes = true, + std::optional args_max_bits_of_precision = std::nullopt); + + // Convenient wrapper for executing and comparing results of two hlo modules + // with fake input. By default compares unoptimized modules. If the modules + // are already optimized, set |run_hlo_passes| to false. + ::testing::AssertionResult RunAndCompareTwoModules( + absl::string_view hlo_string_module_0, + absl::string_view hlo_string_module_1, + const std::optional& error, bool run_hlo_passes = true, + std::optional args_max_bits_of_precision = std::nullopt); + + // Same as above but allows running with different configs. + ::testing::AssertionResult RunAndCompareTwoModules( + absl::string_view hlo_string_module_0, + absl::string_view hlo_string_module_1, const HloModuleConfig& config_0, + const HloModuleConfig& config_1, const std::optional& error, + bool run_hlo_passes = true, + std::optional args_max_bits_of_precision = std::nullopt); + + // Same as above but requires explicit arguments. + ::testing::AssertionResult RunAndCompareTwoModules( + absl::string_view hlo_string_module_0, + absl::string_view hlo_string_module_1, + absl::Span arguments, + const std::optional& error, bool run_hlo_passes = true); + + // Executes an hlo module with fake inputs on multiple replicas. + [[nodiscard]] ::testing::AssertionResult RunReplicated( + absl::string_view hlo_string, bool run_hlo_passes = true, + int64_t num_replicas = 1, + const tsl::protobuf::Message* backend_config = nullptr); + + // If assert_determinism is true, the assertion will fail unless all runs + // produce exactly the same output. + [[nodiscard]] ::testing::AssertionResult RunMultipleTimes( + absl::string_view hlo_string, bool run_hlo_passes, + std::vector* profiles, + const tsl::protobuf::Message* backend_config = nullptr, + bool assert_determinism = false); + [[nodiscard]] ::testing::AssertionResult RunAndCompareNoHloPasses( + absl::string_view hlo_string, const std::optional& error, + const std::function& reference_preprocessor = nullptr, + const std::function& test_preprocessor = nullptr); + + HloRunnerInterface& test_runner() const { return *test_runner_; } + HloRunnerInterface& reference_runner() const { return *reference_runner_; } + + private: + // Given the test module, makes a reference module that is ready to run on the + // reference platform. This assumes that the given module is ready to run on + // the test platform. + absl::StatusOr> MakeReferenceModule( + const HloModule& test_module, + const std::function& reference_preprocessor); + + // Runs the module on two platforms with or without running hlo passes and + // compares the results. Returns whether the results are near or equal. If any + // error happens before the results are computed, returns the error status. + absl::StatusOr<::testing::AssertionResult> RunAndCompareInternal( + std::unique_ptr module, absl::Span arguments, + const std::optional& error, bool run_hlo_passes, + const std::function& reference_preprocessor, + const std::function& test_preprocessor = nullptr); + + // Runs the two module with or without running hlo passes and compares + // the results. Returns whether the results are near or equal. If any + // error happens before the results are computed, returns the error status. + absl::StatusOr<::testing::AssertionResult> + RunAndCompareTwoModulesInternalReplicated( + std::unique_ptr module_0, std::unique_ptr module_1, + HloRunnerInterface::ReplicatedExecuteOptions options, + const std::optional& error); + + // Runs the two module on with or without running hlo passes and + // compares the results. Returns whether the results are near or equal. If any + // error happens before the results are computed, returns the error status. + absl::StatusOr<::testing::AssertionResult> RunAndCompareTwoModulesInternal( + std::unique_ptr module_0, std::unique_ptr module_1, + absl::Span arguments, + const std::optional& error, bool run_hlo_passes); + + std::unique_ptr test_runner_; + std::unique_ptr reference_runner_; +}; + +} // namespace xla + +#endif // XLA_TESTS_NEW_HLO_TEST_BASE_H_ diff --git a/third_party/xla/xla/tests/pad_test.cc b/third_party/xla/xla/tests/pad_test.cc index cd039d3daaf931..ca753c122a1921 100644 --- a/third_party/xla/xla/tests/pad_test.cc +++ b/third_party/xla/xla/tests/pad_test.cc @@ -13,31 +13,33 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include #include #include "xla/array2d.h" +#include "xla/array3d.h" #include "xla/array4d.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/error_spec.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/literal_util.h" #include "xla/reference_util.h" #include "xla/tests/client_library_test_base.h" -#include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/test.h" namespace xla { namespace { -#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 -// Tests both F32 and BF16. -static std::array use_bfloat16_params{false, true}; -#else -// Only tests F32. -static std::array use_bfloat16_params{false}; -#endif +static std::array test_type_params{F32, BF16, F8E5M2, + F8E4M3FN}; class PadTest : public ClientLibraryTestBase { protected: @@ -68,17 +70,11 @@ class PadTest : public ClientLibraryTestBase { }; class PadTestFloat : public PadTest, - public ::testing::WithParamInterface { + public ::testing::WithParamInterface { protected: - PadTestFloat() { set_use_bfloat16(GetParam()); } - - ErrorSpec DefaultErrorSpec() const { - if (use_bfloat16()) { - return ErrorSpec(1e-3, 1e-3); - } else { - return ErrorSpec(1e-5, 1e-5); - } - } + PadTestFloat() { set_float_type(GetParam()); } + + ErrorSpec DefaultErrorSpec() const { return ErrorSpec(1e-5, 1e-5); } }; // Tests a Pad() with a zero-element input and output. @@ -464,7 +460,7 @@ XLA_TEST_P(PadTestFloat, ReducePad) { } INSTANTIATE_TEST_CASE_P(PadTestFloatInstantiation, PadTestFloat, - ::testing::ValuesIn(use_bfloat16_params)); + ::testing::ValuesIn(test_type_params)); } // namespace } // namespace xla diff --git a/third_party/xla/xla/tests/params_test.cc b/third_party/xla/xla/tests/params_test.cc index cae752ee158d6b..46079c711d9e8a 100644 --- a/third_party/xla/xla/tests/params_test.cc +++ b/third_party/xla/xla/tests/params_test.cc @@ -23,8 +23,8 @@ limitations under the License. #include "xla/array2d.h" #include "xla/client/global_data.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/tests/pjrt_client_registry.cc b/third_party/xla/xla/tests/pjrt_client_registry.cc index c4f66923c229c2..c4412e67f7abc5 100644 --- a/third_party/xla/xla/tests/pjrt_client_registry.cc +++ b/third_party/xla/xla/tests/pjrt_client_registry.cc @@ -29,9 +29,12 @@ PjRtClientTestFactoryRegistry& GetGlobalPjRtClientTestFactory() { void RegisterPjRtClientTestFactory( PjRtClientTestFactoryRegistry::PjRtClientFactory factory, PjRtClientTestFactoryRegistry::DeviceShapeRepresentationFnFactory - registered_device_shape_representation_fn) { + registered_device_shape_representation_fn, + PjRtClientTestFactoryRegistry::DeviceShapeSizeFnFactory + registered_device_shape_size_fn) { GetGlobalPjRtClientTestFactory().Register( - std::move(factory), registered_device_shape_representation_fn); + std::move(factory), registered_device_shape_representation_fn, + registered_device_shape_size_fn); } bool ShouldUsePjRt() { diff --git a/third_party/xla/xla/tests/pjrt_client_registry.h b/third_party/xla/xla/tests/pjrt_client_registry.h index 7d82fddc058c30..a8add48497624f 100644 --- a/third_party/xla/xla/tests/pjrt_client_registry.h +++ b/third_party/xla/xla/tests/pjrt_client_registry.h @@ -29,20 +29,34 @@ namespace xla { class PjRtClientTestFactoryRegistry { public: - typedef std::function DeviceShapeRepresentationFn; - typedef std::function - DeviceShapeRepresentationFnFactory; - typedef std::function>()> - PjRtClientFactory; + using DeviceShapeRepresentationFn = std::function; + using DeviceShapeRepresentationFnFactory = + std::function; + using DeviceShapeSizeFn = std::function; + using DeviceShapeSizeFnFactory = + std::function; + using PjRtClientFactory = + std::function>()>; static DeviceShapeRepresentationFn DefaultShapeRepresentationRegisteredFn( - absl::StatusOr client) { + PjRtClient* client) { return [](const Shape& host_shape) { return host_shape; }; } + static DeviceShapeSizeFn DefaultDeviceShapeSizeRegisteredFn( + PjRtClient* client) { + return [](const Shape& shape) -> int64_t { + if (shape.IsOpaque()) { + return sizeof(void*); + } + return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); + }; + } void Register(PjRtClientFactory factory, DeviceShapeRepresentationFnFactory - registered_device_shape_representation_fn) { + registered_device_shape_representation_fn, + DeviceShapeSizeFnFactory registered_device_shape_size_fn) + ABSL_LOCKS_EXCLUDED(mu_) { if (HasRegisteredFactory()) { LOG(FATAL) << "A PjRtClient has already been registered."; return; @@ -52,16 +66,27 @@ class PjRtClientTestFactoryRegistry { factory_ = std::move(factory); registered_device_shape_representation_fn_ = std::move(registered_device_shape_representation_fn); + registered_device_shape_size_fn_ = + std::move(registered_device_shape_size_fn); } // Return the device shape representation of 'host_shape'. DeviceShapeRepresentationFn GetDeviceShapeRepresentationFn( - PjRtClient* pjrt_client) { + PjRtClient* pjrt_client) ABSL_LOCKS_EXCLUDED(mu_) { absl::MutexLock lock(&mu_); return registered_device_shape_representation_fn_(pjrt_client); } - bool HasRegisteredFactory() { + // Return the device shape size of 'host_shape'. + // This function is used e.g. to create a VerifiedHloModule. It returns an + // integer representing the size of the shape in bytes as opposed to a Shape. + DeviceShapeSizeFn GetDeviceShapeSizeFn(PjRtClient* pjrt_client) + ABSL_LOCKS_EXCLUDED(mu_) { + absl::MutexLock lock(&mu_); + return registered_device_shape_size_fn_(pjrt_client); + } + + bool HasRegisteredFactory() ABSL_LOCKS_EXCLUDED(mu_) { absl::MutexLock lock(&mu_); return factory_ != nullptr; } @@ -75,7 +100,10 @@ class PjRtClientTestFactoryRegistry { mutable absl::Mutex mu_; std::function>()> factory_ ABSL_GUARDED_BY(mu_); - DeviceShapeRepresentationFnFactory registered_device_shape_representation_fn_; + DeviceShapeRepresentationFnFactory registered_device_shape_representation_fn_ + ABSL_GUARDED_BY(mu_); + DeviceShapeSizeFnFactory registered_device_shape_size_fn_ + ABSL_GUARDED_BY(mu_); }; PjRtClientTestFactoryRegistry& GetGlobalPjRtClientTestFactory(); @@ -85,7 +113,10 @@ void RegisterPjRtClientTestFactory( PjRtClientTestFactoryRegistry::DeviceShapeRepresentationFnFactory registered_device_shape_representation_fn = PjRtClientTestFactoryRegistry:: - DefaultShapeRepresentationRegisteredFn); + DefaultShapeRepresentationRegisteredFn, + PjRtClientTestFactoryRegistry::DeviceShapeSizeFnFactory + registered_device_shape_size_fn_ = + PjRtClientTestFactoryRegistry::DefaultDeviceShapeSizeRegisteredFn); bool ShouldUsePjRt(); diff --git a/third_party/xla/xla/tests/pjrt_cpu_client_registry.cc b/third_party/xla/xla/tests/pjrt_cpu_client_registry.cc index 540cf3d59ff6b8..a9205b640b7e3c 100644 --- a/third_party/xla/xla/tests/pjrt_cpu_client_registry.cc +++ b/third_party/xla/xla/tests/pjrt_cpu_client_registry.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "xla/pjrt/cpu/cpu_client.h" #include "xla/tests/pjrt_client_registry.h" @@ -23,7 +25,7 @@ namespace { const bool kUnused = (RegisterPjRtClientTestFactory([]() { CpuClientOptions options; options.cpu_device_count = 4; - return GetTfrtCpuClient(options); + return GetTfrtCpuClient(std::move(options)); }), true); diff --git a/third_party/xla/xla/tests/pred_test.cc b/third_party/xla/xla/tests/pred_test.cc index 9f8af6a013d677..060a433753aa8e 100644 --- a/third_party/xla/xla/tests/pred_test.cc +++ b/third_party/xla/xla/tests/pred_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include "xla/array2d.h" -#include "xla/client/lib/arithmetic.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/tests/client_library_test_base.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/tests/prng_test.cc b/third_party/xla/xla/tests/prng_test.cc index 3ee3b851760b58..b68f56c4157635 100644 --- a/third_party/xla/xla/tests/prng_test.cc +++ b/third_party/xla/xla/tests/prng_test.cc @@ -26,7 +26,7 @@ limitations under the License. #include "absl/types/span.h" #include "unsupported/Eigen/SpecialFunctions" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/primitive_util.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/tests/query_inferred_shape_test.cc b/third_party/xla/xla/tests/query_inferred_shape_test.cc index 8312b15c0c975c..871e6266220cc2 100644 --- a/third_party/xla/xla/tests/query_inferred_shape_test.cc +++ b/third_party/xla/xla/tests/query_inferred_shape_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/shape_util.h" #include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" diff --git a/third_party/xla/xla/tests/reduce_hlo_test.cc b/third_party/xla/xla/tests/reduce_hlo_test.cc index c4dcfedb669bc9..87b0d35a602b72 100644 --- a/third_party/xla/xla/tests/reduce_hlo_test.cc +++ b/third_party/xla/xla/tests/reduce_hlo_test.cc @@ -14,12 +14,29 @@ limitations under the License. ==============================================================================*/ #include - +#include +#include +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "xla/error_spec.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout_util.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/shape.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" // Tests the Reduce HLO in ways that can't be done using the ComputationBuilder @@ -64,8 +81,9 @@ Sum { ENTRY reduce.1 { parameter = f32[2,2,2,3]{3,2,1,0} parameter(0) init_value = f32[] constant(0) - reduce = f32[2,2,3]{2,1,0} reduce(parameter, init_value), dimensions={1}, to_apply=Sum - ROOT copy = f32[2,2,3]{2,1,0} copy(reduce) + reduce = f32[2,2,3]{2,1,0} reduce(parameter, init_value), dimensions={1}, + to_apply=Sum transpose = f32[2,2,3]{2,1,0} transpose(reduce), + dimensions={0,1,2} ROOT bitcast = f32[2,2,3]{2,1,0} bitcast(transpose) } )"; @@ -79,8 +97,10 @@ XLA_TEST_P(ReduceWithLayoutTest, DISABLED_ON_TPU(Reduce)) { } TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, GetParsedModule()); - HloInstruction* reduce_instruction = - module->entry_computation()->root_instruction()->mutable_operand(0); + HloInstruction* reduce_instruction = module->entry_computation() + ->root_instruction() + ->mutable_operand(0) + ->mutable_operand(0); ASSERT_EQ(reduce_instruction->opcode(), HloOpcode::kReduce); const ReduceLayout& reduce_layout = GetParam(); @@ -110,7 +130,7 @@ XLA_TEST_P(ReduceWithLayoutTest, DISABLED_ON_TPU(Reduce)) { {-0.241772294, -0.245131493, -0.160247207}, {-0.179881215, -0.23383224, -0.121976733}}}}); - auto reduce_input_relaid = + Literal reduce_input_relaid = reduce_input.Relayout(reduce_input_shape->layout()); EXPECT_TRUE(RunAndCompareNoHloPasses( std::move(module), {&reduce_input_relaid}, ErrorSpec(1e-5))); diff --git a/third_party/xla/xla/tests/reduce_precision_test.cc b/third_party/xla/xla/tests/reduce_precision_test.cc index 38f3ed1b8f0e1f..b3614174902dc5 100644 --- a/third_party/xla/xla/tests/reduce_precision_test.cc +++ b/third_party/xla/xla/tests/reduce_precision_test.cc @@ -26,7 +26,7 @@ limitations under the License. #include "xla/array2d.h" #include "xla/client/global_data.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/test.h" diff --git a/third_party/xla/xla/tests/reduce_test.cc b/third_party/xla/xla/tests/reduce_test.cc index f5db7397cad818..ebe799f704d1f5 100644 --- a/third_party/xla/xla/tests/reduce_test.cc +++ b/third_party/xla/xla/tests/reduce_test.cc @@ -44,10 +44,10 @@ limitations under the License. #include "xla/array2d.h" #include "xla/array4d.h" #include "xla/client/global_data.h" -#include "xla/client/lib/arithmetic.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/layout_util.h" #include "xla/literal_util.h" #include "xla/reference_util.h" diff --git a/third_party/xla/xla/tests/reduce_window_test.cc b/third_party/xla/xla/tests/reduce_window_test.cc index c65cd9c9af1969..e6e374e95e8568 100644 --- a/third_party/xla/xla/tests/reduce_window_test.cc +++ b/third_party/xla/xla/tests/reduce_window_test.cc @@ -15,46 +15,55 @@ limitations under the License. // Tests the reduce-window XLA operation. -#include +#include +#include +#include +#include #include - +#include +#include +#include +#include +#include + +#include +#include "absl/log/check.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/array4d.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/local_client.h" -#include "xla/client/padding.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/error_spec.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/layout_util.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/primitive_util.h" #include "xla/reference_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/status.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { namespace { -#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 -// Tests both F32 and BF16. -static std::array use_bfloat16_params{false, true}; -#else -// Only tests F32. -static std::array use_bfloat16_params{false}; -#endif +static std::array test_type_params = {F32, BF16}; class ReduceWindowTestBase : public ClientLibraryTestBase { public: ErrorSpec DefaultErrorSpec() const { - if (use_bfloat16()) { + if (FloatType() == BF16) { return ErrorSpec(2e-1, 6e-2); } else { return ErrorSpec(1e-3, 1e-3); @@ -62,10 +71,10 @@ class ReduceWindowTestBase : public ClientLibraryTestBase { } }; -class ReduceWindowTest : public ::testing::WithParamInterface, +class ReduceWindowTest : public ::testing::WithParamInterface, public ReduceWindowTestBase { public: - ReduceWindowTest() : builder_(TestName()) { set_use_bfloat16(GetParam()); } + ReduceWindowTest() : builder_(TestName()) { set_float_type(GetParam()); } void ReduceWindowAdd(const XlaOp input, absl::Span window_dimensions, @@ -569,7 +578,7 @@ XLA_TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { } INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest, - ::testing::ValuesIn(use_bfloat16_params)); + ::testing::ValuesIn(test_type_params)); enum Reducer { kAdd, kMax }; @@ -586,7 +595,7 @@ struct R4ReduceWindowTestData { std::string R4ReduceWindowTestDataToString( const ::testing::TestParamInfo< - ::testing::tuple>& data) { + ::testing::tuple>& data) { const auto& param = ::testing::get<0>(data.param); std::string str = absl::StrCat( "base_bounds_", absl::StrJoin(param.base_bounds, "x"), // @@ -600,17 +609,18 @@ std::string R4ReduceWindowTestDataToString( // Test names are not allowed to contain the '-' character. std::replace(str.begin(), str.end(), '-', 'n'); - if (::testing::get<1>(data.param)) { - absl::StrAppend(&str, "_bfloat16"); - } + absl::StrAppend(&str, "_", + primitive_util::LowercasePrimitiveTypeName( + ::testing::get<1>(data.param))); return str; } -class R4ReduceWindowTest : public ReduceWindowTestBase, - public ::testing::WithParamInterface< - ::testing::tuple> { +class R4ReduceWindowTest + : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { protected: - R4ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } + R4ReduceWindowTest() { set_float_type(::testing::get<1>(GetParam())); } void DoIt() { XlaBuilder b(TestName()); @@ -884,7 +894,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { INSTANTIATE_TEST_CASE_P( R4ReduceWindowTestInstantiation, R4ReduceWindowTest, ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowTestValues), - ::testing::ValuesIn(use_bfloat16_params)), + ::testing::ValuesIn(test_type_params)), R4ReduceWindowTestDataToString); class R4ReduceWindowLargeTest : public R4ReduceWindowTest {}; @@ -973,7 +983,7 @@ const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = { INSTANTIATE_TEST_CASE_P( R4ReduceWindowLargeTestInstantiation, R4ReduceWindowLargeTest, ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowLargeTestValues), - ::testing::ValuesIn(use_bfloat16_params)), + ::testing::ValuesIn(test_type_params)), R4ReduceWindowTestDataToString); struct R3ReduceWindowTestData { @@ -1023,7 +1033,7 @@ R3ReduceWindowTestData kR3TestCases[] = { std::string R3ReduceWindowTestDataToString( const ::testing::TestParamInfo< - ::testing::tuple>& data) { + ::testing::tuple>& data) { const auto& param = ::testing::get<0>(data.param); std::string str = absl::StrCat( "base_bounds_", absl::StrJoin(param.base_bounds, "x"), "__window_bounds_", @@ -1032,17 +1042,18 @@ std::string R3ReduceWindowTestDataToString( param.padding == Padding::kSame ? "same" : "valid", "__layout_", param.layout[0], "_", param.layout[1], "_", param.layout[2], "__reducer_", param.reducer == kAdd ? "add" : "max"); - if (::testing::get<1>(data.param)) { - absl::StrAppend(&str, "_bfloat16"); - } + absl::StrAppend(&str, "_", + primitive_util::LowercasePrimitiveTypeName( + ::testing::get<1>(data.param))); return str; } -class R3ReduceWindowTest : public ReduceWindowTestBase, - public ::testing::WithParamInterface< - ::testing::tuple> { +class R3ReduceWindowTest + : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { protected: - R3ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } + R3ReduceWindowTest() { set_float_type(::testing::get<1>(GetParam())); } void DoIt() { XlaBuilder b(TestName()); @@ -1058,7 +1069,7 @@ class R3ReduceWindowTest : public ReduceWindowTestBase, Literal input_literal = LiteralUtil::CreateR3FromArray3DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); auto reducer = param.reducer; - if (use_bfloat16()) { + if (FloatType() == BF16) { input_literal = LiteralUtil::ConvertF32ToBF16(input_literal); // To avoid numerical issues, force the reducer to be kMax for bf16 @@ -1089,7 +1100,7 @@ XLA_TEST_P(R3ReduceWindowTest, DoIt) { DoIt(); } INSTANTIATE_TEST_CASE_P( R3ReduceWindowTestInstantiation, R3ReduceWindowTest, ::testing::Combine(::testing::ValuesIn(kR3TestCases), - ::testing::ValuesIn(use_bfloat16_params)), + ::testing::ValuesIn(test_type_params)), R3ReduceWindowTestDataToString); class R3ReduceWindowLargeTest : public R3ReduceWindowTest {}; @@ -1112,7 +1123,7 @@ const R3ReduceWindowTestData kR3ReduceWindowLargeTestValues[] = { INSTANTIATE_TEST_CASE_P( R3ReduceWindowLargeTestInstantiation, R3ReduceWindowLargeTest, ::testing::Combine(::testing::ValuesIn(kR3ReduceWindowLargeTestValues), - ::testing::ValuesIn(use_bfloat16_params)), + ::testing::ValuesIn(test_type_params)), R3ReduceWindowTestDataToString); struct R2ReduceWindowTestData { @@ -1274,7 +1285,7 @@ struct R2ReduceWindowTestData { std::string R2ReduceWindowTestDataToString( const ::testing::TestParamInfo< - ::testing::tuple>& data) { + ::testing::tuple>& data) { const auto& param = ::testing::get<0>(data.param); std::string str = absl::StrCat( "base_bounds_", absl::StrJoin(param.base_bounds, "x"), // @@ -1289,24 +1300,25 @@ std::string R2ReduceWindowTestDataToString( // Test names are not allowed to contain the '-' character. std::replace(str.begin(), str.end(), '-', 'n'); - if (::testing::get<1>(data.param)) { - absl::StrAppend(&str, "_bfloat16"); - } + absl::StrAppend(&str, "_", + primitive_util::LowercasePrimitiveTypeName( + ::testing::get<1>(data.param))); return str; } -class R2ReduceWindowTest : public ReduceWindowTestBase, - public ::testing::WithParamInterface< - ::testing::tuple> { +class R2ReduceWindowTest + : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { protected: - R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } + R2ReduceWindowTest() { set_float_type(::testing::get<1>(GetParam())); } void DoIt() { XlaBuilder b(TestName()); const auto& param = ::testing::get<0>(GetParam()); Array2D input(param.base_bounds[0], param.base_bounds[1], 1.0f); - if (!::testing::get<1>(GetParam())) { + if (FloatType() == F32) { // We only do this in F32 mode, to avoid precision issues with BF16. input = *MakeLinspaceArray2D(0, 100, param.base_bounds[0], param.base_bounds[1]); @@ -1339,7 +1351,7 @@ class R2ReduceWindowTest : public ReduceWindowTestBase, /*window_dilations=*/param.window_dilation, /*padding=*/padding); - ComputeAndCompare(&b, {MaybeConvertLiteralToBfloat16(input_literal)}, + ComputeAndCompare(&b, {MaybeConvertLiteralToTestType(input_literal)}, DefaultErrorSpec()); } }; @@ -1349,7 +1361,7 @@ XLA_TEST_P(R2ReduceWindowTest, DoIt) { DoIt(); } INSTANTIATE_TEST_CASE_P( R2ReduceWindowTestInstantiation, R2ReduceWindowTest, ::testing::Combine(::testing::ValuesIn(kR2TestCases), - ::testing::ValuesIn(use_bfloat16_params)), + ::testing::ValuesIn(test_type_params)), R2ReduceWindowTestDataToString); struct R1ReduceWindowTestData { @@ -1505,7 +1517,7 @@ struct R1ReduceWindowTestData { std::string R1ReduceWindowTestDataToString( const ::testing::TestParamInfo< - ::testing::tuple>& data) { + ::testing::tuple>& data) { const auto& param = ::testing::get<0>(data.param); std::string str = absl::StrCat("base_bounds_", absl::StrJoin(param.base_bounds, "x"), @@ -1517,17 +1529,18 @@ std::string R1ReduceWindowTestDataToString( // Test names are not allowed to contain the '-' character. std::replace(str.begin(), str.end(), '-', 'n'); - if (::testing::get<1>(data.param)) { - absl::StrAppend(&str, "_bfloat16"); - } + absl::StrAppend(&str, "_", + primitive_util::LowercasePrimitiveTypeName( + ::testing::get<1>(data.param))); return str; } -class R1ReduceWindowTest : public ReduceWindowTestBase, - public ::testing::WithParamInterface< - ::testing::tuple> { +class R1ReduceWindowTest + : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { protected: - R1ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } + R1ReduceWindowTest() { set_float_type(::testing::get<1>(GetParam())); } }; XLA_TEST_P(R1ReduceWindowTest, DoIt) { @@ -1581,7 +1594,7 @@ XLA_TEST_P(R1ReduceWindowTest, DoIt) { INSTANTIATE_TEST_CASE_P( R1ReduceWindowTestInstantiation, R1ReduceWindowTest, ::testing::Combine(::testing::ValuesIn(kR1TestCases), - ::testing::ValuesIn(use_bfloat16_params)), + ::testing::ValuesIn(test_type_params)), R1ReduceWindowTestDataToString); // Test class for text-based test cases. Note that this compares with the diff --git a/third_party/xla/xla/tests/replay_test.cc b/third_party/xla/xla/tests/replay_test.cc index 8280251d3bae2f..3107c7b05d6180 100644 --- a/third_party/xla/xla/tests/replay_test.cc +++ b/third_party/xla/xla/tests/replay_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/client/global_data.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/literal.h" #include "xla/protobuf_util.h" #include "xla/service/hlo.pb.h" diff --git a/third_party/xla/xla/tests/replicated_io_feed_test.cc b/third_party/xla/xla/tests/replicated_io_feed_test.cc index 0164f8b6b30e69..415faa01ff89e7 100644 --- a/third_party/xla/xla/tests/replicated_io_feed_test.cc +++ b/third_party/xla/xla/tests/replicated_io_feed_test.cc @@ -59,8 +59,7 @@ XLA_TEST_F(ReplicatedIOFeedTest, InfeedAndOutfeed) { std::unique_ptr module = ParseAndReturnVerifiedModule(hlo_text, config).value(); auto executable = - test_runner_.CreateExecutable(std::move(module), /*run_hlo_passes=*/true) - .value(); + CreateExecutable(std::move(module), /*run_hlo_passes=*/true).value(); auto device_assn = MakeDeviceAssn(kNumReplicas); @@ -81,7 +80,7 @@ XLA_TEST_F(ReplicatedIOFeedTest, InfeedAndOutfeed) { opts.use_threads = true; TF_ASSERT_OK( - test_runner_.ExecuteReplicated(executable.get(), opts, &device_assn) + ExecuteReplicatedWithHloRunner(executable.get(), opts, &device_assn) .status()); // Verify that each infeed and outfeed is routed correctly. Each replica diff --git a/third_party/xla/xla/tests/reshape_motion_test.cc b/third_party/xla/xla/tests/reshape_motion_test.cc index df65a847ca6e46..2300df5990c635 100644 --- a/third_party/xla/xla/tests/reshape_motion_test.cc +++ b/third_party/xla/xla/tests/reshape_motion_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "xla/array4d.h" #include "xla/client/global_data.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/reference_util.h" diff --git a/third_party/xla/xla/tests/reshape_test.cc b/third_party/xla/xla/tests/reshape_test.cc index 9e3c09dd12ffc0..84d51c5f53de49 100644 --- a/third_party/xla/xla/tests/reshape_test.cc +++ b/third_party/xla/xla/tests/reshape_test.cc @@ -13,41 +13,46 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include #include #include +#include +#include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/array2d.h" +#include "xla/array3d.h" #include "xla/array4d.h" -#include "xla/client/global_data.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/layout_util.h" +#include "xla/literal.h" #include "xla/literal_util.h" #include "xla/reference_util.h" #include "xla/shape_util.h" -#include "xla/status_macros.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" +#include "xla/types.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/ml_dtypes.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { -// Use a bool parameter to indicate whether to use bfloat16. -class ReshapeTest : public ::testing::WithParamInterface, +class ReshapeTest : public ::testing::WithParamInterface, public ClientLibraryTestBase { public: - ReshapeTest() { set_use_bfloat16(GetParam()); } + ReshapeTest() { set_float_type(GetParam()); } ErrorSpec zero_error_spec_{0.0}; }; @@ -652,16 +657,15 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { XlaComputation computation = builder.Build().value(); ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = - ShapeUtil::MakeShapeWithDenseLayout(use_bfloat16() ? BF16 : F32, {2, 8}, - {1, 0}) + ShapeUtil::MakeShapeWithDenseLayout(FloatType(), {2, 8}, {1, 0}) .ToProto(); Literal actual = client_ ->ExecuteAndTransfer(computation, {input.get()}, &execution_options) .value(); Literal expected = LiteralUtil::CreateR2FromArray2D(expected_array); - if (use_bfloat16()) { - expected = LiteralUtil::ConvertF32ToBF16(expected); + if (FloatType() != F32) { + expected = MaybeConvertLiteralToTestType(expected); } EXPECT_TRUE(LiteralTestUtil::Equal(expected, actual)); } @@ -808,8 +812,8 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = - ShapeUtil::MakeShapeWithDenseLayout(use_bfloat16() ? BF16 : F32, - {7, 2, 3, 5}, {2, 3, 0, 1}) + ShapeUtil::MakeShapeWithDenseLayout(FloatType(), {7, 2, 3, 5}, + {2, 3, 0, 1}) .ToProto(); Literal output_literal = client_ @@ -819,11 +823,29 @@ XLA_TEST_P(ReshapeTest, NoopReshape) { // Since the reshape is a no-op, verify that it does not change the underlying // data. - if (use_bfloat16()) { - auto expected = LiteralUtil::ConvertF32ToBF16(input_literal); - EXPECT_EQ(expected.data(), output_literal.data()); - } else { - EXPECT_EQ(input_literal.data(), output_literal.data()); + switch (FloatType()) { + case F32: + EXPECT_EQ(input_literal.data(), output_literal.data()); + break; + case BF16: { + auto expected = MaybeConvertLiteralToTestType(input_literal); + EXPECT_EQ(expected.data(), output_literal.data()); + break; + } + case F8E4M3FN: { + auto expected = MaybeConvertLiteralToTestType(input_literal); + EXPECT_EQ(expected.data(), + output_literal.data()); + break; + } + case F8E5M2: { + auto expected = MaybeConvertLiteralToTestType(input_literal); + EXPECT_EQ(expected.data(), + output_literal.data()); + break; + } + default: + LOG(FATAL) << "Unsupported float type: " << FloatType(); } } @@ -1017,12 +1039,8 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { zero_error_spec_, &expected.shape()); } -#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 -INSTANTIATE_TEST_CASE_P(ReshapeTestInstance, ReshapeTest, ::testing::Bool()); -#else INSTANTIATE_TEST_CASE_P(ReshapeTestInstance, ReshapeTest, - ::testing::ValuesIn(std::vector{false})); -#endif + ::testing::ValuesIn({F32, BF16, F8E5M2, F8E4M3FN})); using ReshapeHloTest = HloTestBase; diff --git a/third_party/xla/xla/tests/reverse_test.cc b/third_party/xla/xla/tests/reverse_test.cc index 299ea416e3c9e7..a7991d930c7f85 100644 --- a/third_party/xla/xla/tests/reverse_test.cc +++ b/third_party/xla/xla/tests/reverse_test.cc @@ -13,47 +13,51 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include +#include +#include +#include +#include #include #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" -#include "xla/array2d.h" +#include "absl/types/span.h" #include "xla/array4d.h" -#include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/error_spec.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/primitive_util.h" +#include "xla/shape_util.h" #include "xla/tests/client_library_test_base.h" -#include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/test.h" namespace xla { namespace { -#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 -// Tests both F32 and BF16. -static std::array use_bfloat16_params{false, true}; -#else -// Only tests F32. -static std::array use_bfloat16_params{false}; -#endif +static std::array primitive_type_params{F32, BF16, F8E5M2, + F8E4M3FN}; struct ReverseSpec { std::vector input_dims; std::vector reversal; - bool use_bfloat16; + PrimitiveType test_type; std::string ToTestCaseName() const { return absl::StrFormat( "reverse_%s_in_dims_%s_%s", absl::StrJoin(input_dims, "x"), - absl::StrJoin(reversal, "x"), use_bfloat16 ? "bf16" : "f32"); + absl::StrJoin(reversal, "x"), + primitive_util::LowercasePrimitiveTypeName(test_type)); } }; static std::vector GetTestCases() { // clang-format off - return ExpandUseBfloat16( - use_bfloat16_params, + return ExpandTestType( + primitive_type_params, {{{}, {}}, {{0, 0}, {0, 1}}, {{0, 1}, {0, 1}}, @@ -74,7 +78,7 @@ void PrintTo(const ReverseSpec& spec, std::ostream* os) { class FloatReverseTest : public ClientLibraryTestBase, public ::testing::WithParamInterface { public: - FloatReverseTest() { set_use_bfloat16(GetParam().use_bfloat16); } + FloatReverseTest() { set_float_type(GetParam().test_type); } }; TEST_P(FloatReverseTest, Reverses) { diff --git a/third_party/xla/xla/tests/rng_test.cc b/third_party/xla/xla/tests/rng_test.cc index 307b464716d2c3..a9e2896f97b34d 100644 --- a/third_party/xla/xla/tests/rng_test.cc +++ b/third_party/xla/xla/tests/rng_test.cc @@ -20,10 +20,10 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/transforms/expanders/rng_bit_generator_expander.h" +#include "xla/hlo/transforms/expanders/rng_expander.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/service/rng_bit_generator_expander.h" -#include "xla/service/rng_expander.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/tests/scalar_computations_test.cc b/third_party/xla/xla/tests/scalar_computations_test.cc index eeed89f239f737..bc9ab8b7326d3e 100644 --- a/third_party/xla/xla/tests/scalar_computations_test.cc +++ b/third_party/xla/xla/tests/scalar_computations_test.cc @@ -23,8 +23,8 @@ limitations under the License. #include "absl/types/span.h" #include "xla/client/global_data.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/status_macros.h" diff --git a/third_party/xla/xla/tests/select_and_scatter_test.cc b/third_party/xla/xla/tests/select_and_scatter_test.cc index 7d97927bc3bb62..632ec7a786c54e 100644 --- a/third_party/xla/xla/tests/select_and_scatter_test.cc +++ b/third_party/xla/xla/tests/select_and_scatter_test.cc @@ -27,11 +27,11 @@ limitations under the License. #include "xla/array.h" #include "xla/array2d.h" #include "xla/array4d.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/padding.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" #include "xla/error_spec.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/reference_util.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/select_test.cc b/third_party/xla/xla/tests/select_test.cc index 30eba6bdd783ad..660f223ee47565 100644 --- a/third_party/xla/xla/tests/select_test.cc +++ b/third_party/xla/xla/tests/select_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "xla/client/global_data.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/slice_test.cc b/third_party/xla/xla/tests/slice_test.cc index 3800431fed8dcf..91f5989438b2df 100644 --- a/third_party/xla/xla/tests/slice_test.cc +++ b/third_party/xla/xla/tests/slice_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/array2d.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/reference_util.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" diff --git a/third_party/xla/xla/tests/sort_test.cc b/third_party/xla/xla/tests/sort_test.cc index b832dbdd0df0d5..3e25d8ba0039c4 100644 --- a/third_party/xla/xla/tests/sort_test.cc +++ b/third_party/xla/xla/tests/sort_test.cc @@ -13,9 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" #include "xla/error_spec.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" @@ -85,5 +91,60 @@ XLA_TEST_F(SortTest, SortTwiceWithSameComparator) { EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{0.0, 0.0})); } +class SortManyInputsTest : public SortTest, + public ::testing::WithParamInterface { + public: + static std::string Name(const ::testing::TestParamInfo& info) { + auto num_inputs = info.param; + return absl::StrFormat("Sort%dInputs", num_inputs); + } +}; + +XLA_TEST_P(SortManyInputsTest, SortManyInputs) { + int num_inputs = GetParam(); + std::string_view hlo_text_module_template = R"( + HloModule sort + + compare { + ${COMPARE_DECLARATIONS} + ROOT lt = pred[] compare(p0, p1), direction=LT + } + + ENTRY e { + ${SORT_DECLARATIONS} + ROOT sort = (${SORT_SHAPE}) sort(${SORT_PARAMS}), dimensions={0}, + to_apply=compare + } + )"; + + // Prepare values for template substitutions. + std::string sort_decls = ""; + std::vector param_names; + param_names.reserve(num_inputs * 2); + for (int i = 0; i < num_inputs; ++i) { + sort_decls += absl::StrFormat("p%d = f32[32,64] parameter(%d)\n", i, i); + param_names.emplace_back(absl::StrCat("p", i)); + } + std::string sort_params = absl::StrJoin(param_names, ", "); + std::string sort_shape = + absl::StrJoin(std::vector(num_inputs, "f32[32,64]"), ","); + std::string compare_decls = ""; + for (int i = 0; i < num_inputs * 2; ++i) { + compare_decls += absl::StrFormat("p%d = f32[] parameter(%d)\n", i, i); + } + std::string compare_params = absl::StrJoin(param_names, ", "); + + // Finalize HLO text. + std::string hlo_text_module = absl::StrReplaceAll( + hlo_text_module_template, {{"${SORT_DECLARATIONS}", sort_decls}, + {"${SORT_SHAPE}", sort_shape}, + {"${SORT_PARAMS}", sort_params}, + {"${COMPARE_DECLARATIONS}", compare_decls}}); + EXPECT_TRUE(RunAndCompare(hlo_text_module, ErrorSpec{0.0, 0.0})); +} + +INSTANTIATE_TEST_SUITE_P(ManyInputs, SortManyInputsTest, + ::testing::Values(17, 20), SortManyInputsTest::Name); + } // namespace } // namespace xla diff --git a/third_party/xla/xla/tests/stochastic_convert_test.cc b/third_party/xla/xla/tests/stochastic_convert_test.cc index 9aa1f023850347..a2c351114c19e0 100644 --- a/third_party/xla/xla/tests/stochastic_convert_test.cc +++ b/third_party/xla/xla/tests/stochastic_convert_test.cc @@ -20,10 +20,10 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/error_spec.h" #include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" -#include "xla/tests/test_utils.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/tests/test_utils.cc b/third_party/xla/xla/tests/test_utils.cc index 424fcdd0af3110..6cc71ce1ce7460 100644 --- a/third_party/xla/xla/tests/test_utils.cc +++ b/third_party/xla/xla/tests/test_utils.cc @@ -25,11 +25,11 @@ limitations under the License. #include #include "absl/status/status.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/literal_util.h" #include "xla/primitive_util.h" -#include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_verifier.h" #include "xla/service/transfer_manager.h" #include "xla/xla_data.pb.h" @@ -38,289 +38,6 @@ namespace xla { namespace { -template -void PopulateWithRandomFloatingPointData(Literal* literal, - std::minstd_rand0* engine) { - std::uniform_real_distribution generator(-0.1f, 0.2f); - for (FloatT& value : literal->data()) { - value = static_cast(generator(*engine)); - } -} - -// Populates a floating point literal with random floating points sampled from a -// uniform-log distribution spanning approximately the entire range of the -// representable floating point. -template -void PopulateWithRandomFullRangeFloatingPointData(Literal* literal, - std::minstd_rand0* engine) { - constexpr float kSpecialValueProbability = 1e-6; - constexpr float kSpecialValues[] = {+0.F, - -0.F, - 1.F, - -1.F, - std::numeric_limits::infinity(), - -std::numeric_limits::infinity()}; - constexpr int kNumSpecialValues = sizeof(kSpecialValues) / sizeof(float); - std::uniform_real_distribution special_value_gen(0, 1); - - // Generates floating points with a log-uniform distribution. This causes the - // exponent of the floating point to have a uniform distribution. - const int min_exp = std::numeric_limits::min_exponent; - const int max_exp = std::numeric_limits::max_exponent; - std::uniform_real_distribution generator(min_exp - 1, max_exp - 1); - - for (FloatT& value : literal->data()) { - // Each special value has a kSpecialValueProbability chance to be generated - // instead of sampling using the normal distributions. - if (special_value_gen(*engine) < - kSpecialValueProbability * kNumSpecialValues) { - value = - static_cast(kSpecialValues[(*engine)() % kNumSpecialValues]); - } else { - float sign = ((*engine)() % 2 == 0) ? 1 : -1; - value = static_cast(pow(2, generator(*engine)) * sign); - } - } -} - -template -void PopulateWithIntNext(Literal* literal) { - using BitRepT = UnsignedIntegerTypeForSizeType; - // Duplicates may be generated if we don't have enough bits. - // Skip bfloat16 and float32 subnormals. - const FloatT kFirstValue = - std::is_same_v || sizeof(FloatT) >= sizeof(float) - ? std::numeric_limits::min() - : std::numeric_limits::denorm_min(); - // `current` keeps track of the next value we need to populate. - auto current = literal->data().begin(); - auto end = literal->data().end(); - // `sign` keeps track of the sign of the next value. - bool sign = false; - while (current != end) { - // We start populating values at zero and increase magnitude from there. - *current = sign ? static_cast(-0.0f) : static_cast(0.0f); - current++; - // The next value is either the smallest denormal or normal. - auto value = sign ? -kFirstValue : kFirstValue; - // Fill the array with values of increasing magnitude until we hit a - // non-finite value. - while (current != end && Eigen::numext::isfinite(value)) { - // Populate the value. - *current = value; - // Generate the next value by lexicographically increasing the bit - // representation. - const BitRepT next_value = Eigen::numext::bit_cast(value) + 1; - value = Eigen::numext::bit_cast(next_value); - current++; - } - // We ran out of finite values, flip the sign and begin again. - sign = !sign; - } -} - -template -void PopulateWithNoDuplicateData(Literal* literal, std::minstd_rand0* engine) { - PopulateWithIntNext(literal); - std::shuffle(literal->data().begin(), literal->data().end(), - *engine); -} - -template -void PopulateWithFloatingPointData( - Literal* literal, std::minstd_rand0* engine, bool no_duplicates, - bool use_large_range, std::optional max_bits_of_precision) { - using ComputeT = - std::conditional_t; - CHECK(engine != nullptr); - CHECK_EQ(literal->shape().element_type(), - primitive_util::NativeToPrimitiveType()); - if (max_bits_of_precision.has_value()) { - CHECK(!use_large_range) << "Cannot set both use_large_range and " - "max_bits_of_precision for floating points."; - CHECK(!no_duplicates) << "Cannot set both no_duplicates and " - "max_bits_of_precision for floating points."; - std::uniform_int_distribution generator( - -(1 << *max_bits_of_precision), 1 << *max_bits_of_precision); - for (FloatT& value : literal->data()) { - int64_t temp = generator(*engine); - // We want to generate floating point numbers to a fixed precision, while - // keeping them between -1 and 1. This preserves their bits of precision - // while keeping the numbers small. - value = static_cast(temp * pow(2, -ceil(log2(abs(temp))))); - } - } else if (no_duplicates) { - PopulateWithNoDuplicateData(literal, engine); - } else if (use_large_range) { - PopulateWithRandomFullRangeFloatingPointData(literal, engine); - } else { - PopulateWithRandomFloatingPointData(literal, engine); - } -} - -template -void PopulateWithComplexData(Literal* result, std::minstd_rand0* engine, - bool no_duplicates, bool use_large_range) { - using InnerFloatT = typename ComplexT::value_type; - CHECK(engine != nullptr); - CHECK_EQ(result->shape().element_type(), - primitive_util::NativeToPrimitiveType()); - Shape floating_point_shape = ShapeUtil::ChangeElementType( - result->shape(), primitive_util::NativeToPrimitiveType()); - Literal real_lit(floating_point_shape); - Literal imaginary_lit(floating_point_shape); - - PopulateWithFloatingPointData( - &real_lit, engine, no_duplicates, use_large_range, - /*max_bits_of_precision=*/std::nullopt); - PopulateWithFloatingPointData( - &imaginary_lit, engine, no_duplicates, use_large_range, - /*max_bits_of_precision=*/std::nullopt); - - absl::Span real_data = real_lit.data(); - absl::Span imaginary_data = - imaginary_lit.data(); - absl::Span result_data = result->data(); - for (int i = 0; i < real_lit.data().size(); i++) { - result_data[i] = ComplexT(real_data[i], imaginary_data[i]); - } -} - -// uniform_int_distribution is not defined for 8-bit integers. -// Use 'short' for those types. -template -using RngT = std::conditional_t< - sizeof(IntT) < sizeof(uint16_t), - std::conditional_t::is_signed, int16_t, uint16_t>, - IntT>; - -template -void PopulateWithRandomIntegralDataWithBounds(Literal* literal, - std::minstd_rand0* engine, - bool no_duplicates, IntT min, - IntT max) { - CHECK(engine != nullptr); - CHECK_EQ(literal->shape().element_type(), - primitive_util::NativeToPrimitiveType()); - if (no_duplicates && - ShapeUtil::ElementsIn(literal->shape()) < static_cast(max)) { - std::iota(literal->data().begin(), literal->data().end(), - static_cast(0)); - std::shuffle(literal->data().begin(), literal->data().end(), - *engine); - } else { - std::uniform_int_distribution> generator( - static_cast>(min), static_cast>(max)); - for (IntT& value : literal->data()) { - value = static_cast(generator(*engine)); - } - } -} - -// Similar to MakeFakeLiteral but takes a random number generator engine to -// enable reusing the engine across randomly generated literals. -// 'limit' is a optional pair that contains the min and the max values to be -// sample for integers (integer format only). -// 'is_sorted' sorts the sample data for integers (integer format only). -// 'no_duplicates' indicates that there should be no duplicate values in each -// generated array. This is uniqueness is best-effort only. Some types -// (half and bfloat16) are not supported and uniqueness cannot be guaranteed if -// the number of elements exceeds the number of different values supported by -// the type. (floating point format only) -// 'use_large_range' indicates the sampled data is from the full range of the -// floating point format. (floating point format only) -// 'max_bits_of_precision' sets the data to have the given number of bits or -// less (integer or floating point formats only). -absl::StatusOr MakeFakeLiteralInternal( - const Shape& shape, std::minstd_rand0* engine, - std::optional> limit, bool is_sorted, - bool no_duplicates, bool use_large_range, - std::optional max_bits_of_precision) { - if (shape.IsTuple()) { - std::vector elements; - const auto& shape_tuple_shapes = shape.tuple_shapes(); - elements.reserve(shape_tuple_shapes.size()); - for (const Shape& element_shape : shape_tuple_shapes) { - TF_ASSIGN_OR_RETURN( - Literal element, - MakeFakeLiteralInternal(element_shape, engine, limit, is_sorted, - no_duplicates, use_large_range, - max_bits_of_precision)); - elements.push_back(std::move(element)); - } - return LiteralUtil::MakeTupleOwned(std::move(elements)); - } - if (engine == nullptr) { - return Literal::CreateFromShape(shape); - } - // Clear tiles/element size in shape's layout before using it for creating - // literal. - Shape new_shape = shape; - new_shape.mutable_layout()->clear_tiles(); - new_shape.mutable_layout()->set_tail_padding_alignment_in_elements(1); - new_shape.mutable_layout()->set_element_size_in_bits(0); - Literal literal(new_shape); - - TF_RETURN_IF_ERROR(primitive_util::PrimitiveTypeSwitch( - [&](auto primitive_type_constant) -> absl::Status { - if constexpr (primitive_util::IsArrayType(primitive_type_constant)) { - using NativeT = primitive_util::NativeTypeOf; - if constexpr (primitive_util::IsFloatingPointType( - primitive_type_constant)) { - PopulateWithFloatingPointData( - &literal, engine, no_duplicates, use_large_range, - max_bits_of_precision); - return absl::OkStatus(); - } - if constexpr (primitive_type_constant == PRED) { - std::uniform_int_distribution generator(0, 1); - TF_CHECK_OK(literal.Populate( - [&](absl::Span /*indices*/) { - return generator(*engine); - })); - return absl::OkStatus(); - } - if constexpr (primitive_util::IsIntegralType( - primitive_type_constant)) { - NativeT max = std::numeric_limits::max(); - NativeT min = std::numeric_limits::lowest(); - if (limit.has_value()) { - max = static_cast(limit->second); - min = static_cast(limit->first); - } - if (max_bits_of_precision.has_value()) { - max = std::min(max, - static_cast(1 << *max_bits_of_precision)); - if (primitive_util::IsSignedIntegralType( - primitive_type_constant)) { - min = std::max( - min, static_cast(-(1 << *max_bits_of_precision))); - } - } - PopulateWithRandomIntegralDataWithBounds( - &literal, engine, /*no_duplicate*/ no_duplicates, min, max); - if (is_sorted) { - std::sort(literal.data().begin(), - literal.data().end()); - } - return absl::OkStatus(); - } - if constexpr (primitive_util::IsComplexType( - primitive_type_constant)) { - PopulateWithComplexData(&literal, engine, no_duplicates, - use_large_range); - return absl::OkStatus(); - } - } - return Unimplemented( - "Unsupported type for fake random literal generation with bounds: " - "%s", - ShapeUtil::HumanString(shape)); - }, - shape.element_type())); - return std::move(literal); -} - enum class ConstantType { kUnknown, kZero, kOne }; // Return the constant type required by this computation, if known. @@ -551,10 +268,10 @@ absl::StatusOr CreateLiteralForConstrainedUses( return Unimplemented("Conflicting operand generation constraints."); } if (index_bound != INT64_MAX) { - return MakeFakeLiteralInternal(param_shape, engine, - std::pair(0, index_bound), - needs_sorted_indices, no_duplicates, - use_large_range, max_bits_of_precision); + return MakeFakeLiteral(param_shape, engine, + std::pair(0, index_bound), + needs_sorted_indices, no_duplicates, use_large_range, + max_bits_of_precision); } else if (needs_constant) { switch (constant_type) { case ConstantType::kZero: @@ -565,16 +282,15 @@ absl::StatusOr CreateLiteralForConstrainedUses( // We want the identity element for the computation, but we don't // really know what it is - so any value we generate will be just as // wrong. - return MakeFakeLiteralInternal( - param_shape, engine, /*limit=*/std::nullopt, - /*is_sorted=*/needs_sorted_indices, - /*no_duplicates=*/false, use_large_range, max_bits_of_precision); + return MakeFakeLiteral(param_shape, engine, /*limit=*/std::nullopt, + /*is_sorted=*/needs_sorted_indices, + /*no_duplicates=*/false, use_large_range, + max_bits_of_precision); } } else { - return MakeFakeLiteralInternal(param_shape, engine, /*limit=*/std::nullopt, - /*is_sorted=*/needs_sorted_indices, - no_duplicates, use_large_range, - max_bits_of_precision); + return MakeFakeLiteral(param_shape, engine, /*limit=*/std::nullopt, + /*is_sorted=*/needs_sorted_indices, no_duplicates, + use_large_range, max_bits_of_precision); } } @@ -594,15 +310,6 @@ absl::StatusOr MakeConstrainedArgument( } // namespace -absl::StatusOr MakeFakeLiteral(const Shape& shape, bool pseudo_random, - bool use_large_range) { - auto engine = pseudo_random ? std::make_unique() : nullptr; - return MakeFakeLiteralInternal(shape, engine.get(), /*limit=*/std::nullopt, - /*is_sorted=*/false, - /*no_duplicates=*/false, use_large_range, - /*max_bits_of_precision=*/std::nullopt); -} - absl::StatusOr> MakeFakeArguments( const HloModule* module, bool pseudo_random, bool use_large_range, bool treat_gte_as_data_formatting, diff --git a/third_party/xla/xla/tests/test_utils.h b/third_party/xla/xla/tests/test_utils.h index 1df3ee2e0b5da0..ec851b976941a6 100644 --- a/third_party/xla/xla/tests/test_utils.h +++ b/third_party/xla/xla/tests/test_utils.h @@ -57,13 +57,6 @@ class PseudorandomGenerator { std::mt19937 generator_; }; -// Generates fake data in a literal of the given shape, or returns an error -// status if the element type is currently unhandled for fake data -// generation. See below for documentation of pseudo_random and use_large_range. -absl::StatusOr MakeFakeLiteral(const Shape& shape, - bool pseudo_random = true, - bool use_large_range = false); - // Generates a vector of arguments containing fake data. The number, shape and // layout of the arguments is appropriate for given HLO module. // diff --git a/third_party/xla/xla/tests/test_utils_test.cc b/third_party/xla/xla/tests/test_utils_test.cc index 22212a02998239..8368cfe01582dc 100644 --- a/third_party/xla/xla/tests/test_utils_test.cc +++ b/third_party/xla/xla/tests/test_utils_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/container/flat_hash_set.h" -#include "xla/client/xla_builder.h" -#include "xla/service/hlo_parser.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/shape_util.h" #include "xla/tests/local_client_test_base.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/transfer_manager_test.cc b/third_party/xla/xla/tests/transfer_manager_test.cc index b290dfea6ef92e..2478ec08171208 100644 --- a/third_party/xla/xla/tests/transfer_manager_test.cc +++ b/third_party/xla/xla/tests/transfer_manager_test.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/service/generic_transfer_manager.h" -#include "xla/service/hlo_parser.h" #include "xla/service/shaped_buffer.h" #include "xla/service/stream_pool.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/tests/transpose_test.cc b/third_party/xla/xla/tests/transpose_test.cc index e52fccfcc69a6d..ffd1a5c6156faf 100644 --- a/third_party/xla/xla/tests/transpose_test.cc +++ b/third_party/xla/xla/tests/transpose_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "xla/array2d.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal_util.h" #include "xla/reference_util.h" #include "xla/tests/client_library_test_base.h" diff --git a/third_party/xla/xla/tests/triangular_solve_test.cc b/third_party/xla/xla/tests/triangular_solve_test.cc index b04ac99d4110e4..3bbe5ca227c074 100644 --- a/third_party/xla/xla/tests/triangular_solve_test.cc +++ b/third_party/xla/xla/tests/triangular_solve_test.cc @@ -21,9 +21,9 @@ limitations under the License. #include "absl/strings/ascii.h" #include "xla/array.h" #include "xla/array2d.h" -#include "xla/client/lib/math.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/math.h" +#include "xla/hlo/builder/lib/matrix.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/literal.h" #include "xla/test.h" #include "xla/tests/client_library_test_base.h" diff --git a/third_party/xla/xla/tests/tuple_test.cc b/third_party/xla/xla/tests/tuple_test.cc index 8d6c1c641579e9..5cc7f7b1bb9d18 100644 --- a/third_party/xla/xla/tests/tuple_test.cc +++ b/third_party/xla/xla/tests/tuple_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/array2d.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/literal_util.h" -#include "xla/service/hlo_parser.h" #include "xla/shape_util.h" #include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" diff --git a/third_party/xla/xla/tests/unary_op_test.cc b/third_party/xla/xla/tests/unary_op_test.cc index e8ea9ff1ae84ce..adf4912cfdc96c 100644 --- a/third_party/xla/xla/tests/unary_op_test.cc +++ b/third_party/xla/xla/tests/unary_op_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "xla/client/global_data.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/value_inference_test.cc b/third_party/xla/xla/tests/value_inference_test.cc index 50da08967a01eb..5ac2f038f67180 100644 --- a/third_party/xla/xla/tests/value_inference_test.cc +++ b/third_party/xla/xla/tests/value_inference_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/client/value_inference.h" +#include "xla/hlo/builder/value_inference.h" #include #include @@ -24,10 +24,10 @@ limitations under the License. #include "absl/types/span.h" #include "xla/client/client_library.h" #include "xla/client/global_data.h" -#include "xla/client/lib/arithmetic.h" -#include "xla/client/lib/prng.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/lib/prng.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/layout_util.h" #include "xla/literal.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/tests/vector_ops_reduce_test.cc b/third_party/xla/xla/tests/vector_ops_reduce_test.cc index f0524ec0a6787d..f35beb32f78fde 100644 --- a/third_party/xla/xla/tests/vector_ops_reduce_test.cc +++ b/third_party/xla/xla/tests/vector_ops_reduce_test.cc @@ -19,9 +19,9 @@ limitations under the License. #include "xla/array2d.h" #include "xla/array3d.h" -#include "xla/client/lib/arithmetic.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/xla_builder.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" diff --git a/third_party/xla/xla/tests/vector_ops_simple_test.cc b/third_party/xla/xla/tests/vector_ops_simple_test.cc index 5501124b2f3878..eb67f886d6254f 100644 --- a/third_party/xla/xla/tests/vector_ops_simple_test.cc +++ b/third_party/xla/xla/tests/vector_ops_simple_test.cc @@ -22,12 +22,11 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/array4d.h" #include "xla/client/global_data.h" -#include "xla/client/lib/arithmetic.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/shape_util.h" -#include "xla/stream_executor/stream_executor.h" #include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" diff --git a/third_party/xla/xla/tests/verified_hlo_module.h b/third_party/xla/xla/tests/verified_hlo_module.h index 837166a36d3e1d..3b27b3bd0cefa5 100644 --- a/third_party/xla/xla/tests/verified_hlo_module.h +++ b/third_party/xla/xla/tests/verified_hlo_module.h @@ -15,52 +15,7 @@ limitations under the License. #ifndef XLA_TESTS_VERIFIED_HLO_MODULE_H_ #define XLA_TESTS_VERIFIED_HLO_MODULE_H_ -#include - -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_verifier.h" -#include "xla/shape.h" -#include "xla/types.h" -#include "tsl/platform/status.h" - -namespace xla { - -// An HLO module derived class which verifies itself on destruction. This class -// is intended to be used in unit tests. Any verification errors are raised via -// ADD_FAILURE. -class VerifiedHloModule : public HloModule { - public: - VerifiedHloModule(const std::string& name, const HloModuleConfig& config, - bool verifier_layout_sensitive, - bool allow_mixed_precision_in_hlo_verifier, - std::function shape_size_function, - HloPredicate instruction_can_change_layout_func = {}) - : HloModule(name, config), - verifier_(verifier_layout_sensitive, - allow_mixed_precision_in_hlo_verifier, - instruction_can_change_layout_func, shape_size_function) {} - - ~VerifiedHloModule() override { VerifyOrAddFailure("in destructor"); } - - // Given a string in the HloModule::ToString() format, parses the string and - // builds the VerifiedHloModule in place. Before calling this method, the - // module must be empty (no computations). Finally verifies the module using - // HloVerifier and returns the status. - absl::Status ParseHloStringAndVerifyModule(absl::string_view str); - - // Verifies the module and flags any error with ADD_FAILURE. 'message' is - // included in the failure message. - void VerifyOrAddFailure(absl::string_view message); - - // Verifies the module using HloVerifier and returns the status. - absl::Status Verify(); - - private: - HloVerifier verifier_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/testlib/verified_hlo_module.h" #endif // XLA_TESTS_VERIFIED_HLO_MODULE_H_ diff --git a/third_party/xla/xla/tests/while_test.cc b/third_party/xla/xla/tests/while_test.cc index 473875960fdd16..aa5d0392e941a6 100644 --- a/third_party/xla/xla/tests/while_test.cc +++ b/third_party/xla/xla/tests/while_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/client/client_library.h" -#include "xla/client/lib/arithmetic.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/lib/arithmetic.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/literal.h" #include "xla/service/platform_util.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/tests/xla_hlo_profile_test.cc b/third_party/xla/xla/tests/xla_hlo_profile_test.cc index 72e4387da2beb8..8013f2d8be904f 100644 --- a/third_party/xla/xla/tests/xla_hlo_profile_test.cc +++ b/third_party/xla/xla/tests/xla_hlo_profile_test.cc @@ -25,8 +25,8 @@ limitations under the License. #include "absl/strings/str_split.h" #include "xla/array2d.h" #include "xla/client/local_client.h" -#include "xla/client/xla_builder.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_builder.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/map_util.h" #include "xla/service/platform_util.h" #include "xla/service/stream_pool.h" diff --git a/third_party/xla/xla/text_literal_reader.cc b/third_party/xla/xla/text_literal_reader.cc index 209790bc33eafa..7b627e5b9d1b1f 100644 --- a/third_party/xla/xla/text_literal_reader.cc +++ b/third_party/xla/xla/text_literal_reader.cc @@ -21,21 +21,26 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/literal.h" -#include "xla/service/hlo_parser.h" +#include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/status_macros.h" -#include "xla/types.h" +#include "xla/tsl/lib/io/buffered_inputstream.h" +#include "xla/tsl/lib/io/random_inputstream.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/io/buffered_inputstream.h" -#include "tsl/lib/io/random_inputstream.h" -#include "tsl/platform/protobuf.h" +#include "tsl/platform/env.h" +#include "tsl/platform/file_system.h" +#include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/text_literal_reader.h b/third_party/xla/xla/text_literal_reader.h index 397229e74d81cf..20684755cae91d 100644 --- a/third_party/xla/xla/text_literal_reader.h +++ b/third_party/xla/xla/text_literal_reader.h @@ -24,6 +24,7 @@ limitations under the License. #include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/platform/env.h" +#include "tsl/platform/file_system.h" namespace xla { diff --git a/third_party/xla/xla/text_literal_reader_test.cc b/third_party/xla/xla/text_literal_reader_test.cc index afeed461c61be2..11d76f224f4c9a 100644 --- a/third_party/xla/xla/text_literal_reader_test.cc +++ b/third_party/xla/xla/text_literal_reader_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include "xla/literal.h" #include "xla/shape_util.h" #include "xla/test.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/platform/env.h" diff --git a/third_party/xla/xla/text_literal_writer.cc b/third_party/xla/xla/text_literal_writer.cc index 050eb5fe835adc..83833dacbf5924 100644 --- a/third_party/xla/xla/text_literal_writer.cc +++ b/third_party/xla/xla/text_literal_writer.cc @@ -18,14 +18,15 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/literal.h" #include "xla/shape_util.h" -#include "xla/status_macros.h" -#include "xla/types.h" #include "tsl/platform/env.h" +#include "tsl/platform/file_system.h" namespace xla { diff --git a/third_party/xla/xla/text_literal_writer.h b/third_party/xla/xla/text_literal_writer.h index 2ce5b368773d34..a11205c905f626 100644 --- a/third_party/xla/xla/text_literal_writer.h +++ b/third_party/xla/xla/text_literal_writer.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_TEXT_LITERAL_WRITER_H_ #define XLA_TEXT_LITERAL_WRITER_H_ +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "xla/literal.h" #include "xla/types.h" diff --git a/third_party/xla/xla/text_literal_writer_test.cc b/third_party/xla/xla/text_literal_writer_test.cc index e517279a4c447d..657937f749fa32 100644 --- a/third_party/xla/xla/text_literal_writer_test.cc +++ b/third_party/xla/xla/text_literal_writer_test.cc @@ -18,12 +18,10 @@ limitations under the License. #include #include -#include "xla/literal.h" #include "xla/literal_util.h" #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "xla/types.h" #include "tsl/platform/env.h" namespace xla { diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD index 3935af05d34a4c..ef54458f7d115e 100644 --- a/third_party/xla/xla/tools/BUILD +++ b/third_party/xla/xla/tools/BUILD @@ -12,6 +12,7 @@ load( "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) +load("//xla:lit.bzl", "enforce_glob", "lit_test_suite") load( "//xla:xla.bzl", "xla_cc_binary", @@ -45,29 +46,6 @@ filegroup( visibility = ["//xla:internal"], ) -build_test( - name = "hex_floats_to_packed_literal_build_test", - targets = [ - ":hex_floats_to_packed_literal", - ], -) - -xla_cc_binary( - name = "hex_floats_to_packed_literal", - srcs = ["hex_floats_to_packed_literal.cc"], - deps = [ - "//xla/tsl/util:command_line_flags", - "@com_google_absl//absl/base", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/io:buffered_inputstream", - "@local_tsl//tsl/lib/io:random_inputstream", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:platform_port", - "@local_tsl//tsl/platform:status", - ], -) - build_test( name = "show_signature_build_test", targets = [ @@ -94,92 +72,6 @@ xla_cc_binary( ], ) -build_test( - name = "show_literal_build_test", - targets = [ - ":show_literal", - ], -) - -xla_cc_binary( - name = "show_literal", - srcs = ["show_literal.cc"], - deps = [ - "//xla:literal", - "//xla:types", - "//xla:xla_data_proto_cc", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:platform_port", - "@local_tsl//tsl/platform:status", - ], -) - -build_test( - name = "convert_computation_build_test", - targets = [ - ":convert_computation", - ], -) - -xla_cc_binary( - name = "convert_computation", - srcs = ["convert_computation.cc"], - deps = [ - "//xla/service:hlo_proto_cc", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:platform_port", - "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:status", - ], -) - -build_test( - name = "hlo_module_metadata_processor_build_test", - targets = [ - ":hlo_module_metadata_processor", - ], -) - -xla_cc_binary( - name = "hlo_module_metadata_processor", - srcs = ["hlo_module_metadata_processor.cc"], - deps = [ - "//xla/service:hlo_proto_cc", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:platform_port", - "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:status", - ], -) - -build_test( - name = "show_text_literal_build_test", - targets = [ - ":show_text_literal", - ], -) - -xla_cc_binary( - name = "show_text_literal", - srcs = ["show_text_literal.cc"], - deps = [ - "//xla:literal", - "//xla:text_literal_reader", - "//xla:types", - "//xla:xla_data_proto_cc", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:platform_port", - "@local_tsl//tsl/platform:protobuf", - ], -) - build_test( name = "dumped_computation_to_text_build_test", targets = [ @@ -196,7 +88,7 @@ xla_cc_binary( "//xla/client:client_library", "//xla/client:executable_build_options", "//xla/client:local_client", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", "//xla/service", "//xla/service:hlo_proto_cc", @@ -227,7 +119,7 @@ xla_cc_binary( "//xla/client:client_library", "//xla/client:executable_build_options", "//xla/client:local_client", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", "//xla/service", "//xla/service:hlo_proto_cc", @@ -245,28 +137,6 @@ xla_cc_binary( ], ) -build_test( - name = "hlo_proto_to_json_build_test", - targets = [ - ":hlo_proto_to_json", - ], -) - -xla_cc_binary( - name = "hlo_proto_to_json", - srcs = ["hlo_proto_to_json.cc"], - deps = [ - "//xla:util", - "//xla/service:hlo_proto_cc", - "//xla/tsl/util:command_line_flags", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:platform_port", - "@local_tsl//tsl/platform:status", - ], -) - xla_cc_test( name = "hlo_extractor_test", srcs = ["hlo_extractor_test.cc"], @@ -293,7 +163,6 @@ cc_library( "//xla/service:compilation_environments", "//xla/service:hlo_module_config", "//xla/service:hlo_verifier", - "//xla/tests:test_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", @@ -346,12 +215,12 @@ cc_library( deps = [ "//xla:xla_data_proto_cc", "//xla/hlo/pass:hlo_pass_pipeline", + "//xla/hlo/transforms:cholesky_expander", + "//xla/hlo/transforms:rng_bit_generator_expander", + "//xla/hlo/transforms:rng_expander", "//xla/service:batchnorm_expander", - "//xla/service:cholesky_expander", "//xla/service:hlo_proto_cc", "//xla/service:hlo_verifier", - "//xla/service:rng_bit_generator_expander", - "//xla/service:rng_expander", "//xla/service:sharding_propagation", "//xla/service:triangular_solve_expander", "//xla/service/spmd:stateful_rng_spmd_partitioner", @@ -428,6 +297,7 @@ xla_cc_binary( "//xla/service:hlo_runner", "//xla/service:local_service", "//xla/service:platform_util", + "//xla/tsl/protobuf:error_codes_proto_impl_cc", "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", @@ -436,7 +306,6 @@ xla_cc_binary( "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:subprocess", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", ] + if_cuda_or_rocm([ "//xla/service:gpu_plugin", ]) + if_cuda([ @@ -472,15 +341,20 @@ cc_library( deps = [ ":run_hlo_module_proto_cc", "//xla:debug_options_flags", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", + "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:statusor", ], ) @@ -504,7 +378,7 @@ cc_library( "//xla:debug_options_flags", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:despecializer", + "//xla/hlo/transforms:despecializer", "//xla/service:hlo_module_config", "//xla/service:hlo_runner_interface", "//xla/stream_executor:platform", @@ -538,7 +412,6 @@ tf_proto_library( xla_py_proto_library( name = "run_hlo_module_pb2", - api_version = 2, visibility = ["//visibility:public"], deps = [":run_hlo_module_proto"], ) @@ -622,13 +495,13 @@ xla_cc_binary( deps = [ ":run_hlo_module_lib", "//xla:debug_options_flags", + "//xla/hlo/translate/mhlo_to_hlo:translate", + "//xla/hlo/translate/stablehlo_to_hlo:translate", "//xla/service:cpu_plugin", "//xla/service:hlo_module_config", "//xla/service:hlo_runner", "//xla/service:interpreter_plugin", "//xla/service:platform_util", - "//xla/translate/mhlo_to_hlo:translate", - "//xla/translate/stablehlo_to_hlo:translate", "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", @@ -660,7 +533,7 @@ xla_cc_test( "//xla:literal", "//xla:literal_util", "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", + "//xla/hlo/parser:hlo_parser", "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", @@ -683,11 +556,12 @@ cc_library( "//xla:literal_util", "//xla:shape_util", "//xla:util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", + "//xla/hlo/transforms:hlo_dce", "//xla/service:call_graph", "//xla/service:collective_ops_utils", - "//xla/service:hlo_dce", "//xla/service:tuple_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -705,15 +579,18 @@ xla_cc_test( srcs = ["hlo_control_flow_flattening_test.cc"], deps = [ ":hlo_control_flow_flattening", + "//xla/hlo/transforms:despecializer", "//xla/hlo/utils:hlo_matchers", "//xla/service:collective_ops_utils", - "//xla/service:despecializer", "//xla/service:hlo_verifier", "//xla/service/spmd:spmd_partitioner", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", ], ) @@ -739,9 +616,11 @@ xla_cc_binary( deps = [ ":hlo_module_loader", "//xla:debug_options_flags", - "//xla:shape_util", + "//xla/hlo/ir:hlo", "//xla/service:hlo_cost_analysis", + "//xla/service/gpu/model:gpu_hlo_cost_analysis", "//xla/tsl/util:command_line_flags", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:platform_port", @@ -749,6 +628,23 @@ xla_cc_binary( ], ) +lit_test_suite( + name = "compute_cost_test", + srcs = enforce_glob( + [ + "compute_cost_test.hlo", + ], + include = [ + "*.hlo", + ], + ), + cfg = "//xla:lit.cfg.py", + tools = [ + ":compute_cost", + "@llvm-project//llvm:FileCheck", + ], +) + xla_cc_binary( name = "extract_collective_operations", srcs = ["extract_collective_operations.cc"], @@ -782,7 +678,7 @@ tsl_gpu_library( "//xla:debug_options_flags", "//xla:shape_util", "//xla:util", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", "//xla/mlir_hlo", @@ -796,8 +692,8 @@ tsl_gpu_library( "//xla/service:xla_compile_result_proto_cc_impl", "//xla/service/cpu:cpu_compiler", "//xla/service/cpu:cpu_executable", - "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/status", @@ -851,6 +747,8 @@ xla_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/protobuf:error_codes_proto_impl_cc", + "//xla/tsl/protobuf:status_proto_cc", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_googletest//:gtest", @@ -861,8 +759,6 @@ xla_test( "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", - "@local_tsl//tsl/protobuf:status_proto_cc", ] + if_google(["@com_google_protobuf//:duration_cc_proto"]), ) @@ -877,10 +773,6 @@ xla_test( "//xla/service:xla_aot_compile_test_gpu_target_config.prototxt", "//xla/service/gpu:gpu_compiler_test_autotune_db.textproto", ], - tags = [ - "config-cuda-only", - "no_rocm", - ], deps = [ ":xla_compile_lib", "//xla:util", @@ -894,14 +786,14 @@ xla_test( "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/protobuf:error_codes_proto_impl_cc", + "//xla/tsl/protobuf:status_proto_cc", "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", - "@local_tsl//tsl/protobuf:status_proto_cc", ], ) diff --git a/third_party/xla/xla/tools/compute_cost.cc b/third_party/xla/xla/tools/compute_cost.cc index 9615ae01b59940..cd283720e2c383 100644 --- a/third_party/xla/xla/tools/compute_cost.cc +++ b/third_party/xla/xla/tools/compute_cost.cc @@ -17,16 +17,21 @@ limitations under the License. #include #include +#include #include #include #include +#include "absl/container/flat_hash_map.h" #include "absl/log/log.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "xla/debug_options_flags.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" #include "xla/service/hlo_cost_analysis.h" -#include "xla/shape.h" -#include "xla/shape_util.h" #include "xla/tools/hlo_module_loader.h" #include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/init_main.h" @@ -41,15 +46,50 @@ The input file can be obtained from XProf graph viewer by clicking Usage: - bazel run compute_cost -- -input=path/to/hlo_module -format=[hlo|pb|pbtxt] + bazel run compute_cost -- --input=path/to/hlo_module --format=[hlo|pb|pbtxt] [--gpu] [--all] )"; } // namespace +namespace xla { +void print_costs_of_all_instructions(const HloModule& module, + const HloCostAnalysis& analysis) { + absl::flat_hash_map fingerprint_to_name; + std::cout << "HLO name, deduplicated name, bytes accessed, flops\n"; + for (const HloComputation* computation : module.computations()) { + for (const HloInstruction* hlo : computation->instructions()) { + if (hlo->opcode() == HloOpcode::kParameter || + hlo->opcode() == HloOpcode::kConstant || + hlo->opcode() == HloOpcode::kTuple || + hlo->opcode() == HloOpcode::kGetTupleElement || + hlo->opcode() == HloOpcode::kBitcast) { + // These instructions always have zero costs. + continue; + } + absl::string_view deduplicated_name = hlo->metadata().deduplicated_name(); + if (deduplicated_name.empty()) { + deduplicated_name = hlo->name(); + } + std::cout << hlo->name() << ", " << deduplicated_name << ", " + << analysis.bytes_accessed(*hlo) << ", " + << analysis.flop_count(*hlo) << "\n"; + } + } +} +} // namespace xla + int main(int argc, char** argv) { std::string input, format; + bool gpu = false; + bool all = false; std::vector flag_list = { tsl::Flag("input", &input, "input file"), - tsl::Flag("format", &format, "hlo|pb|pbtxt")}; + tsl::Flag("format", &format, "hlo|pb|pbtxt"), + tsl::Flag("gpu", &gpu, + "Use GPU flavor of cost analysis instead of the generic one"), + tsl::Flag( + "all", &all, + "Also print costs and deduplicated name of each instruction, not " + "just the total costs for the module")}; xla::AppendDebugOptionsFlags(&flag_list); const std::string kUsageString = absl::StrCat(kUsage, "\n\n", tsl::Flags::Usage(argv[0], flag_list)); @@ -59,18 +99,27 @@ int main(int argc, char** argv) { LOG(QFATAL) << kUsageString; } - xla::HloCostAnalysis analysis([](const xla::Shape& shape) { - return xla::ShapeUtil::ByteSizeOf(shape, 8); - }); + std::unique_ptr analysis; + if (gpu) { + analysis = std::make_unique( + xla::HloCostAnalysis::Options{}); + } else { + analysis = std::make_unique(); + } + + std::unique_ptr module = + *xla::LoadModuleFromFile(input, format, {}); + + TF_CHECK_OK( + module->entry_computation()->root_instruction()->Accept(&*analysis)); - TF_CHECK_OK(xla::LoadModuleFromFile(input, format, {}) - .value() - ->entry_computation() - ->root_instruction() - ->Accept(&analysis)); + if (all) { + print_costs_of_all_instructions(*module, *analysis); + } std::cout << std::setw(5) << std::setprecision(4) - << analysis.flop_count() / (1e9) << " GFLOPS. " - << analysis.bytes_accessed() / (1e6) << " MiB." << std::endl; + << "Total: " << analysis->flop_count() / (1e9) << " GFLOPS. " + << analysis->bytes_accessed() / (1e6) << " MB." << std::endl; + return 0; } diff --git a/third_party/xla/xla/tools/compute_cost_test.hlo b/third_party/xla/xla/tools/compute_cost_test.hlo new file mode 100644 index 00000000000000..0d4a12a1657920 --- /dev/null +++ b/third_party/xla/xla/tools/compute_cost_test.hlo @@ -0,0 +1,14 @@ +// RUN: compute_cost --all --input=%s | FileCheck %s + + +// CHECK: HLO name, deduplicated name, bytes accessed, flops +// CHECK-NEXT: a, a, 30, 10 +// CHECK-NEXT: a2, a, 30, 10 +// CHECK-NEXT: r, r, 210, 0 + +e { + c = s8[10] constant(1) + a = s8[10] add(c, c) + a2 = s8[10] add(a, a), metadata={deduplicated_name="a"} + r = s8[10,20] broadcast(a2), dimensions={0} +} diff --git a/third_party/xla/xla/tools/driver.cc b/third_party/xla/xla/tools/driver.cc index 780968098cf32b..4f4895b57123ae 100644 --- a/third_party/xla/xla/tools/driver.cc +++ b/third_party/xla/xla/tools/driver.cc @@ -120,22 +120,28 @@ enum PrimitiveType { C64, C128, F8E5M2, + F8E4M3, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ, + F8E3M4, }; const std::vector& primitive_strings() { - static auto vec = - new std::vector({"s2", "s4", "s8", - "s16", "s32", "s64", - "u2", "u4", "u8", - "u16", "u32", "u64", - "f16", "bf16", "f32", - "f64", "c64", "c128", - "f8e5m2", "f8e4m3fn", "f8e4m3b11fnuz", - "f8e5m2fnuz", "f8e4m3fnuz"}); + static auto vec = new std::vector({"s2", "s4", + "s8", "s16", + "s32", "s64", + "u2", "u4", + "u8", "u16", + "u32", "u64", + "f16", "bf16", + "f32", "f64", + "c64", "c128", + "f8e5m2", "f8e4m3", + "f8e4m3fn", "f8e4m3b11fnuz", + "f8e5m2fnuz", "f8e4m3fnuz", + "f8e3m4"}); return *vec; } @@ -413,10 +419,12 @@ void Fill(void* buffer, const ArrayShape& shape) { return FillFloatT(buffer, num_elements); case F8E5M2: + case F8E4M3: case F8E4M3FN: case F8E4M3B11FNUZ: case F8E5M2FNUZ: case F8E4M3FNUZ: + case F8E3M4: case F16: case BF16: case C64: @@ -469,10 +477,12 @@ void Display(const void* buffer, const ArrayShape& shape) { return DisplayT(buffer, num_elements); case F8E5M2: + case F8E4M3: case F8E4M3FN: case F8E4M3B11FNUZ: case F8E5M2FNUZ: case F8E4M3FNUZ: + case F8E3M4: case F16: case BF16: case C64: diff --git a/third_party/xla/xla/tools/dumped_computation_to_operation_list.cc b/third_party/xla/xla/tools/dumped_computation_to_operation_list.cc index 645021031107ca..b6be3188ffa02e 100644 --- a/third_party/xla/xla/tools/dumped_computation_to_operation_list.cc +++ b/third_party/xla/xla/tools/dumped_computation_to_operation_list.cc @@ -15,10 +15,12 @@ limitations under the License. // Dumps out the operations that are present in a serialized computation. +#include #include #include #include #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -29,7 +31,7 @@ limitations under the License. #include "xla/client/client_library.h" #include "xla/client/executable_build_options.h" #include "xla/client/local_client.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" diff --git a/third_party/xla/xla/tools/dumped_computation_to_text.cc b/third_party/xla/xla/tools/dumped_computation_to_text.cc index 695d4c928a6866..72e1710e194507 100644 --- a/third_party/xla/xla/tools/dumped_computation_to_text.cc +++ b/third_party/xla/xla/tools/dumped_computation_to_text.cc @@ -18,13 +18,14 @@ limitations under the License. #include #include #include +#include #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/client/client_library.h" #include "xla/client/executable_build_options.h" #include "xla/client/local_client.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo.pb.h" diff --git a/third_party/xla/xla/tools/hlo_bisect/BUILD b/third_party/xla/xla/tools/hlo_bisect/BUILD index 7f9747bad6f252..4ee9192cce9345 100644 --- a/third_party/xla/xla/tools/hlo_bisect/BUILD +++ b/third_party/xla/xla/tools/hlo_bisect/BUILD @@ -43,10 +43,10 @@ cc_library( hdrs = ["hlo_bisect_state.h"], deps = [ "//xla:literal", + "//xla:literal_util", "//xla:util", "//xla/hlo/ir:hlo", - "//xla/service:hlo_dce", - "//xla/tests:test_utils", + "//xla/hlo/transforms:hlo_dce", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -84,8 +84,8 @@ cc_library( "//xla:protobuf_util", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/service:dump", - "//xla/service:hlo_parser", "//xla/service:hlo_proto_cc", "//xla/service:hlo_proto_util", "//xla/service:hlo_runner", diff --git a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state.cc b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state.cc index aeabb3193bd439..6db6e970d22f7a 100644 --- a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state.cc +++ b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_state.cc @@ -28,8 +28,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/hlo_dce.h" -#include "xla/tests/test_utils.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" +#include "xla/literal_util.h" #include "xla/util.h" namespace xla { diff --git a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_utils.cc b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_utils.cc index d4e6d0d60e70fa..f25be607e81610 100644 --- a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_utils.cc +++ b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_utils.cc @@ -24,10 +24,10 @@ limitations under the License. #include "absl/cleanup/cleanup.h" #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/protobuf_util.h" #include "xla/service/dump.h" #include "xla/service/hlo.pb.h" -#include "xla/service/hlo_parser.h" #include "xla/service/hlo_proto_util.h" #include "xla/service/hlo_runner.h" #include "xla/service/hlo_verifier.h" diff --git a/third_party/xla/xla/tools/hlo_control_flow_flattening.cc b/third_party/xla/xla/tools/hlo_control_flow_flattening.cc index e81f3ef6ee8604..61ef1175457e1a 100644 --- a/third_party/xla/xla/tools/hlo_control_flow_flattening.cc +++ b/third_party/xla/xla/tools/hlo_control_flow_flattening.cc @@ -34,15 +34,16 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/transforms/simplifiers/hlo_dce.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/call_graph.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/hlo_dce.h" #include "xla/service/tuple_util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/tools/hlo_control_flow_flattening_test.cc b/third_party/xla/xla/tools/hlo_control_flow_flattening_test.cc index 40d16ca88574cf..df23c3c9527e8d 100644 --- a/third_party/xla/xla/tools/hlo_control_flow_flattening_test.cc +++ b/third_party/xla/xla/tools/hlo_control_flow_flattening_test.cc @@ -15,13 +15,20 @@ limitations under the License. #include "xla/tools/hlo_control_flow_flattening.h" +#include #include #include +#include +#include +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/transforms/despecializer.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/despecializer.h" #include "xla/service/hlo_verifier.h" #include "xla/service/spmd/spmd_partitioner.h" #include "xla/tests/hlo_test_base.h" diff --git a/third_party/xla/xla/tools/hlo_decomposer.cc b/third_party/xla/xla/tools/hlo_decomposer.cc index 30741733af2132..577456c24787a2 100644 --- a/third_party/xla/xla/tools/hlo_decomposer.cc +++ b/third_party/xla/xla/tools/hlo_decomposer.cc @@ -24,7 +24,6 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/log/check.h" -#include "absl/status/status.h" #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" diff --git a/third_party/xla/xla/tools/hlo_expand.cc b/third_party/xla/xla/tools/hlo_expand.cc index 95c97e3c8e9911..8e7ac95da76df1 100644 --- a/third_party/xla/xla/tools/hlo_expand.cc +++ b/third_party/xla/xla/tools/hlo_expand.cc @@ -18,12 +18,12 @@ limitations under the License. #include #include "xla/hlo/pass/hlo_pass_pipeline.h" +#include "xla/hlo/transforms/expanders/cholesky_expander.h" +#include "xla/hlo/transforms/expanders/rng_bit_generator_expander.h" +#include "xla/hlo/transforms/expanders/rng_expander.h" #include "xla/service/batchnorm_expander.h" -#include "xla/service/cholesky_expander.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_verifier.h" -#include "xla/service/rng_bit_generator_expander.h" -#include "xla/service/rng_expander.h" #include "xla/service/sharding_propagation.h" #include "xla/service/spmd/stateful_rng_spmd_partitioner.h" #include "xla/service/triangular_solve_expander.h" diff --git a/third_party/xla/xla/tools/hlo_extractor.cc b/third_party/xla/xla/tools/hlo_extractor.cc index 3a34570c8eb50b..e376fd0e4d0f86 100644 --- a/third_party/xla/xla/tools/hlo_extractor.cc +++ b/third_party/xla/xla/tools/hlo_extractor.cc @@ -42,7 +42,6 @@ limitations under the License. #include "xla/service/hlo_verifier.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/tests/test_utils.h" #include "tsl/platform/status.h" namespace xla { diff --git a/third_party/xla/xla/tools/hlo_module_loader.cc b/third_party/xla/xla/tools/hlo_module_loader.cc index f6a685435825ca..3ab573dfa2ac42 100644 --- a/third_party/xla/xla/tools/hlo_module_loader.cc +++ b/third_party/xla/xla/tools/hlo_module_loader.cc @@ -22,19 +22,26 @@ limitations under the License. #include #include #include +#include +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "re2/re2.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/parser/hlo_parser.h" +#include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" +#include "xla/tools/run_hlo_module.pb.h" +#include "xla/xla.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/logging.h" #include "tsl/platform/path.h" #include "tsl/platform/protobuf.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -70,8 +77,7 @@ absl::StatusOr> LoadModuleFromData( const std::string& data, std::string_view format, const hlo_module_loader_details::Config& ovr_config, const std::function& config_modifier_hook, - BufferAssignmentProto* buffer_assignment_proto, - bool set_to_default_entry_computation_layout) { + BufferAssignmentProto* buffer_assignment_proto, bool fill_missing_layouts) { DebugOptions debug_options = GetDebugOptionsFromFlags(); std::unique_ptr module; if (format == "hlo" || format == "txt") { @@ -82,9 +88,10 @@ absl::StatusOr> LoadModuleFromData( if (config_modifier_hook) { config_modifier_hook(&config); } - TF_ASSIGN_OR_RETURN(module, ParseAndReturnUnverifiedModule( - hlo_string, config, - set_to_default_entry_computation_layout)); + HloParserOptions options; + options.set_fill_missing_layouts(fill_missing_layouts); + TF_ASSIGN_OR_RETURN( + module, ParseAndReturnUnverifiedModule(hlo_string, config, options)); } else { HloSnapshot proto; if (format == "pb") { @@ -132,16 +139,14 @@ absl::StatusOr> LoadModuleFromFile( const std::string& path, std::string format, const hlo_module_loader_details::Config& ovr_config, const std::function& config_modifier_hook, - BufferAssignmentProto* buffer_assignment_proto, - bool set_to_default_entry_computation_layout) { + BufferAssignmentProto* buffer_assignment_proto, bool fill_missing_layouts) { std::string data; if (format.empty()) { format = std::string(tsl::io::Extension(path)); } TF_RETURN_IF_ERROR(tsl::ReadFileToString(tsl::Env::Default(), path, &data)); return LoadModuleFromData(data, format, ovr_config, config_modifier_hook, - buffer_assignment_proto, - set_to_default_entry_computation_layout); + buffer_assignment_proto, fill_missing_layouts); } absl::StatusOr> diff --git a/third_party/xla/xla/tools/hlo_module_loader.h b/third_party/xla/xla/tools/hlo_module_loader.h index a841bceed512d0..4dc0653cd9729b 100644 --- a/third_party/xla/xla/tools/hlo_module_loader.h +++ b/third_party/xla/xla/tools/hlo_module_loader.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef XLA_TOOLS_HLO_MODULE_LOADER_H_ #define XLA_TOOLS_HLO_MODULE_LOADER_H_ +#include #include #include #include @@ -24,6 +25,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo.pb.h" #include "xla/tools/run_hlo_module.pb.h" namespace xla { @@ -61,7 +63,7 @@ absl::StatusOr> LoadModuleFromData( hlo_module_loader_details::Config(), const std::function& config_modifier_hook = {}, BufferAssignmentProto* buffer_assignment_proto = nullptr, - bool set_to_default_entry_computation_layout = true); + bool fill_missing_layouts = true); // Loads an HLO module from file. // The file can be one of the followings: @@ -84,7 +86,7 @@ absl::StatusOr> LoadModuleFromFile( hlo_module_loader_details::Config(), const std::function& config_modifier_hook = {}, BufferAssignmentProto* buffer_assignment_proto = nullptr, - bool set_to_default_entry_computation_layout = true); + bool fill_missing_layouts = true); // Loads an HLO snapshot from a string, only for its inputs // The data format must be one of the following: diff --git a/third_party/xla/xla/tools/hlo_module_loader_test.cc b/third_party/xla/xla/tools/hlo_module_loader_test.cc index 16fbe45e4ae451..b3b4a8fb0b9c38 100644 --- a/third_party/xla/xla/tools/hlo_module_loader_test.cc +++ b/third_party/xla/xla/tools/hlo_module_loader_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/tools/hlo_module_loader.h" +#include #include #include "xla/tests/hlo_test_base.h" diff --git a/third_party/xla/xla/tools/hlo_opt/BUILD b/third_party/xla/xla/tools/hlo_opt/BUILD index d9cefa0eddad8a..74dc6159a8dc36 100644 --- a/third_party/xla/xla/tools/hlo_opt/BUILD +++ b/third_party/xla/xla/tools/hlo_opt/BUILD @@ -39,8 +39,8 @@ cc_library( "//xla/service:executable", "//xla/service:hlo_graph_dumper", "//xla/service:platform_util", - "//xla/stream_executor", "//xla/stream_executor:platform", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", @@ -74,8 +74,9 @@ cc_library( "//xla/service/gpu:gpu_compiler", "//xla/service/gpu:gpu_hlo_schedule", "//xla/service/llvm_ir:llvm_util", - "//xla/stream_executor", - "//xla/stream_executor/platform", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor/platform:initialize", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", @@ -105,7 +106,7 @@ cc_library( "//xla/service:hlo_graph_dumper", "//xla/service/cpu:cpu_executable", "//xla/stream_executor/host:host_platform", - "//xla/stream_executor/platform", + "//xla/stream_executor/platform:initialize", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:statusor", @@ -177,7 +178,7 @@ lit_test_suite( default_tags = tf_cuda_tests_tags(), hermetic_cuda_data_dir = "%S/../../../../cuda_nvcc", tags_override = { - "gpu_hlo_ptx.hlo": ["no_rocm"], + "gpu_hlo_ptx.hlo": ["cuda-only"], }, tools = [ "//xla/tools:hlo-opt", diff --git a/third_party/xla/xla/tools/interactive_graphviz.cc b/third_party/xla/xla/tools/interactive_graphviz.cc index 10d68162d54627..47ba7b2300c45a 100644 --- a/third_party/xla/xla/tools/interactive_graphviz.cc +++ b/third_party/xla/xla/tools/interactive_graphviz.cc @@ -46,12 +46,12 @@ limitations under the License. #include "xla/service/local_service.h" #include "xla/service/platform_util.h" #include "xla/tools/hlo_extractor.h" +#include "xla/tsl/protobuf/error_codes.pb.h" #include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/init_main.h" #include "tsl/platform/logging.h" #include "tsl/platform/path.h" #include "tsl/platform/subprocess.h" -#include "tsl/protobuf/error_codes.pb.h" #if defined(PLATFORM_GOOGLE) #include "util/readline/readline.h" #endif diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD index 4578fcef45e177..fb42753f89810c 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD +++ b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD @@ -83,9 +83,9 @@ xla_cc_binary( name = "hlo_runner_main_gpu", testonly = True, tags = [ + "cuda-only", "gpu", "no_mac", - "no_rocm", ] + tf_gpu_tests_tags(), deps = [ ":hlo_runner_main_lib", @@ -135,8 +135,9 @@ cc_library( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/client:executable_build_options", - "//xla/client:xla_computation", + "//xla/hlo/builder:xla_computation", "//xla/hlo/ir:hlo", + "//xla/hlo/parser:hlo_parser", "//xla/hlo/pass:hlo_pass_pipeline", "//xla/pjrt:host_memory_spaces", "//xla/pjrt:pjrt_client", @@ -147,7 +148,6 @@ cc_library( "//xla/service:computation_layout", "//xla/service:computation_placer_hdr", "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", "//xla/service:hlo_proto_cc", "//xla/tests:test_utils", "//xla/tools:hlo_control_flow_flattening", diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/create_client.h b/third_party/xla/xla/tools/multihost_hlo_runner/create_client.h index 230a39e67a0939..6c8bb2a4ca685c 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/create_client.h +++ b/third_party/xla/xla/tools/multihost_hlo_runner/create_client.h @@ -31,8 +31,9 @@ limitations under the License. namespace xla { struct PjRtEnvironment { - std::unique_ptr client; + // Sequence matters here, client should be destroyed before service. std::unique_ptr service; + std::unique_ptr client; std::shared_ptr kv_store; std::shared_ptr distributed_client; }; diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc index c20ba23e4f0919..9df993644898bc 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.cc @@ -38,10 +38,11 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/client/executable_build_options.h" -#include "xla/client/xla_computation.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/layout.h" #include "xla/literal.h" @@ -57,7 +58,6 @@ limitations under the License. #include "xla/service/computation_placer.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" -#include "xla/service/hlo_parser.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/tests/test_utils.h" @@ -822,14 +822,13 @@ FunctionalHloRunner::Run(PjRtClient& client, PjRtLoadedExecutable* executable, flattened_arguments.insert({device_id, std::move(flattened_argument)}); } return CopyArgumentsToDevice(client, executable, flattened_arguments, - running_options.log_input_output(), + running_options, /*flattened_arguments=*/true); } // If the per-device argument is not a single tuple, we ignore the // flatten_tupled_arguments parameter and assume the provided arguments have // already been flattened. - return CopyArgumentsToDevice(client, executable, arguments, - running_options.log_input_output(), + return CopyArgumentsToDevice(client, executable, arguments, running_options, /*flattened_arguments=*/false); }; return RunInternal(client, executable, create_argument_buffers_on_device, @@ -1164,14 +1163,13 @@ FunctionalHloRunner::CreateArgumentsOnDevice( } if (kUseSharedInputs) { - return CopyArgumentsToDevice( - client, executable, per_device_argument_literals, - running_options.log_input_output(), flatten_arguments, - /*clone_device0_arguments=*/true); + return CopyArgumentsToDevice(client, executable, + per_device_argument_literals, running_options, + flatten_arguments, + /*clone_device0_arguments=*/true); } return CopyArgumentsToDevice(client, executable, per_device_argument_literals, - running_options.log_input_output(), - flatten_arguments); + running_options, flatten_arguments); } absl::StatusOr>>> @@ -1261,8 +1259,10 @@ FunctionalHloRunner::CreateUninitializedArgumentsOnDevice( absl::StatusOr>>> FunctionalHloRunner::CopyArgumentsToDevice( PjRtClient& client, const PjRtLoadedExecutable* executable, - const PerDeviceLiteralVecType& arguments, bool log_input, - bool flattened_arguments, bool clone_device0_arguments) { + const PerDeviceLiteralVecType& arguments, + const RunningOptions& running_options, bool flattened_arguments, + bool clone_device0_arguments) { + const bool log_input = running_options.log_input_output(); absl::Span addressable_devices = executable->addressable_devices(); size_t num_addressable_devices = addressable_devices.size(); @@ -1301,20 +1301,22 @@ FunctionalHloRunner::CopyArgumentsToDevice( TF_RET_CHECK(!shape.IsTuple()) << "Param tuple without flattened_arguments"; return non_tuple_memory_space(shape); }; - auto buffer_from_host_literal = [&client, &argument_memory_space]( - const HloModule* module, - PjRtDevice* device, int arg_i, - const Literal& literal) + auto buffer_from_host_literal = + [&client, &argument_memory_space, &running_options]( + const HloModule* module, PjRtDevice* device, int arg_i, + const Literal& literal) -> absl::StatusOr> { + const Layout* layout = nullptr; + if (running_options.use_argument_host_layout && + literal.shape().has_layout()) { + layout = &literal.shape().layout(); + } if (client.memory_spaces().empty()) { - return client.BufferFromHostLiteral( - literal, device, - literal.shape().has_layout() ? &literal.shape().layout() : nullptr); + return client.BufferFromHostLiteral(literal, device, layout); } TF_ASSIGN_OR_RETURN(PjRtMemorySpace * memory_space, argument_memory_space(module, device, arg_i)); - return client.BufferFromHostLiteral(literal, memory_space, - /* device_layout */ nullptr); + return client.BufferFromHostLiteral(literal, memory_space, layout); }; absl::Span diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h index 79d986148b9a37..30864615343962 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h +++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner.h @@ -203,6 +203,9 @@ class FunctionalHloRunner { // Whether to untuple the result of running HLO module into a vector of // arrays. If unprovided, use the default in ExecuteOptions. std::optional untuple_result = std::nullopt; + // Whether to use the layout on host when allocating buffers for arguments. + // Some platforms (e.g. CPU) do not support this yet. + bool use_argument_host_layout = false; // Should we log the inputs and outputs to stderr? bool log_input_output() const { @@ -377,7 +380,7 @@ class FunctionalHloRunner { CopyArgumentsToDevice(PjRtClient& client, const PjRtLoadedExecutable* executable, const PerDeviceLiteralVecType& arguments, - bool log_input, bool flattened_arguments, + const RunningOptions& options, bool flattened_arguments, bool clone_device0_arguments = false); static absl::StatusOr RunInternal( diff --git a/third_party/xla/xla/tools/prepare_reference_module.cc b/third_party/xla/xla/tools/prepare_reference_module.cc index 82fd57a8f183f5..bc8a11aa9c8b00 100644 --- a/third_party/xla/xla/tools/prepare_reference_module.cc +++ b/third_party/xla/xla/tools/prepare_reference_module.cc @@ -21,7 +21,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/despecializer.h" +#include "xla/hlo/transforms/despecializer.h" #include "xla/service/hlo_module_config.h" #include "xla/stream_executor/platform.h" #include "xla/xla.pb.h" diff --git a/third_party/xla/xla/tools/run_hlo_module_bin_test.cc b/third_party/xla/xla/tools/run_hlo_module_bin_test.cc index 9b6138f45e7246..3421516ee41213 100644 --- a/third_party/xla/xla/tools/run_hlo_module_bin_test.cc +++ b/third_party/xla/xla/tools/run_hlo_module_bin_test.cc @@ -21,9 +21,9 @@ limitations under the License. #include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/parser/hlo_parser.h" #include "xla/literal.h" #include "xla/literal_util.h" -#include "xla/service/hlo_parser.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" diff --git a/third_party/xla/xla/tools/run_hlo_module_main.cc b/third_party/xla/xla/tools/run_hlo_module_main.cc index 92e2e23efad6e0..17795415652bbb 100644 --- a/third_party/xla/xla/tools/run_hlo_module_main.cc +++ b/third_party/xla/xla/tools/run_hlo_module_main.cc @@ -30,12 +30,12 @@ limitations under the License. #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/ToolOutputFile.h" #include "xla/debug_options_flags.h" +#include "xla/hlo/translate/mhlo_to_hlo/translate.h" +#include "xla/hlo/translate/stablehlo_to_hlo/translate.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_runner.h" #include "xla/service/platform_util.h" #include "xla/tools/run_hlo_module.h" -#include "xla/translate/mhlo_to_hlo/translate.h" -#include "xla/translate/stablehlo_to_hlo/translate.h" #include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/init_main.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/tools/xla_compile_lib.cc b/third_party/xla/xla/tools/xla_compile_lib.cc index 16d4d0d65e29d3..20d0a33593368a 100644 --- a/third_party/xla/xla/tools/xla_compile_lib.cc +++ b/third_party/xla/xla/tools/xla_compile_lib.cc @@ -37,8 +37,8 @@ limitations under the License. #include "mlir/IR/OwningOpRef.h" #include "mlir/Parser/Parser.h" #include "stablehlo/dialect/Register.h" -#include "xla/client/xla_computation.h" #include "xla/debug_options_flags.h" +#include "xla/hlo/builder/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" diff --git a/third_party/xla/xla/tools/xla_cpu_compile_lib_test.cc b/third_party/xla/xla/tools/xla_cpu_compile_lib_test.cc index 62c06734ddb990..6cef1d4e58ad44 100644 --- a/third_party/xla/xla/tools/xla_cpu_compile_lib_test.cc +++ b/third_party/xla/xla/tools/xla_cpu_compile_lib_test.cc @@ -29,6 +29,8 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tools/xla_compile_lib.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/error_codes.pb.h" +#include "xla/tsl/protobuf/status.pb.h" #include "xla/util.h" #include "tsl/platform/env.h" #include "tsl/platform/env_time.h" @@ -37,8 +39,6 @@ limitations under the License. #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/error_codes.pb.h" -#include "tsl/protobuf/status.pb.h" namespace xla { namespace { diff --git a/third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc b/third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc index bc34c8790fb14e..7bf90e5bf88970 100644 --- a/third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc +++ b/third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc @@ -30,14 +30,14 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tools/xla_compile_lib.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/error_codes.pb.h" +#include "xla/tsl/protobuf/status.pb.h" #include "xla/util.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/error_codes.pb.h" -#include "tsl/protobuf/status.pb.h" namespace xla { namespace { diff --git a/third_party/xla/xla/translate/BUILD b/third_party/xla/xla/translate/BUILD index 4833529d409a24..4fab823a4a5215 100644 --- a/third_party/xla/xla/translate/BUILD +++ b/third_party/xla/xla/translate/BUILD @@ -1,5 +1,3 @@ -load("@bazel_skylib//rules:build_test.bzl", "build_test") -load("//xla:xla.bzl", "xla_cc_binary") load("//xla/tsl:tsl.bzl", "internal_visibility") package( @@ -11,52 +9,14 @@ package( licenses = ["notice"], ) -build_test( - name = "xla-translate_build_test", - targets = [ - ":xla-translate", - ], -) - -xla_cc_binary( +alias( name = "xla-translate", - testonly = True, - srcs = ["xla_translate_main.cc"], - deps = [ - "//xla/service/cpu:cpu_compiler", - "//xla/service/cpu:cpu_transfer_manager", - "//xla/stream_executor/host:host_platform", - "//xla/translate/hlo_to_mhlo:translate_registration", - "//xla/translate/mhlo_to_hlo:translate_registration", - "//xla/translate/stablehlo_to_hlo:translate_registration", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TranslateLib", - "@local_tsl//tsl/platform:platform_port", - ], -) - -build_test( - name = "xla-translate-opt_build_test", - targets = [ - ":xla-translate-opt", - ], + actual = "//xla/hlo/translate:xla-translate", + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/translate:xla-translate instead.", ) -xla_cc_binary( +alias( name = "xla-translate-opt", - testonly = True, - srcs = ["xla_translate_opt_main.cc"], - deps = [ - "//xla/mlir/framework/ir:xla_framework", - "//xla/mlir/framework/transforms:passes", - "//xla/mlir_hlo:hlo_dialect_registration", - "//xla/service:cpu_plugin", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AllPassesAndDialects", - "@llvm-project//mlir:MlirOptLib", - "@local_tsl//tsl/platform:platform_port", - "@stablehlo//:register", - ], + actual = "//xla/hlo/translate:xla-translate-opt", + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/translate:xla-translate-opt instead.", ) diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/BUILD b/third_party/xla/xla/translate/hlo_to_mhlo/BUILD index fb4f7b9fc662fd..ae3aed9e9a9cea 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/BUILD +++ b/third_party/xla/xla/translate/hlo_to_mhlo/BUILD @@ -1,5 +1,4 @@ load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("//xla:xla.bzl", "xla_cc_test") load("//xla/tsl:tsl.bzl", "internal_visibility") package( @@ -13,250 +12,67 @@ package( cc_library( name = "attribute_importer", - srcs = ["attribute_importer.cc"], hdrs = ["attribute_importer.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/translate/hlo_to_mhlo:attribute_importer instead.", deps = [ - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/mlir_hlo", - "//xla/service:hlo_proto_cc", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - ], -) - -cc_library( - name = "async_importer", - srcs = ["async_importer.cc"], - hdrs = ["async_importer.h"], - deps = [ - ":attribute_importer", - ":hlo_utils", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/mlir_hlo", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:errors", - ], -) - -cc_library( - name = "custom_call_importer", - srcs = ["custom_call_importer.cc"], - hdrs = ["custom_call_importer.h"], - deps = [ - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/mlir_hlo", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AsmParser", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:QuantOps", - "@llvm-project//mlir:Support", - ], -) - -cc_library( - name = "stack_location_utils", - srcs = ["stack_location_utils.cc"], - hdrs = ["stack_location_utils.h"], - deps = [ - "//xla/hlo/ir:hlo", - "//xla/service:hlo_proto_cc", - "@llvm-project//mlir:IR", + "//xla/hlo/translate/hlo_to_mhlo:attribute_importer", ], ) cc_library( name = "hlo_function_importer", - srcs = ["hlo_function_importer.cc"], hdrs = ["hlo_function_importer.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/translate/hlo_to_mhlo:hlo_function_importer instead.", deps = [ - ":async_importer", - ":attribute_importer", - ":custom_call_importer", - ":hlo_utils", - ":location_importer", - "//xla:comparison_util", - "//xla:literal", - "//xla:protobuf_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/mlir_hlo", - "//xla/service:hlo_proto_cc", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:optional", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:AsmParser", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:SideEffectInterfaces", - "@llvm-project//mlir:SparseTensorDialect", - "@llvm-project//mlir:Support", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/translate/hlo_to_mhlo:hlo_function_importer", ], ) cc_library( name = "hlo_module_importer", - srcs = [ - "hlo_module_importer.cc", - ], hdrs = [ "hlo_module_importer.h", ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/translate/hlo_to_mhlo:hlo_module_importer instead.", deps = [ - ":hlo_function_importer", - ":module_attributes_importer", - "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/mlir_hlo", - "@com_google_absl//absl/status", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:QuantOps", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/translate/hlo_to_mhlo:hlo_module_importer", ], ) cc_library( name = "hlo_to_mlir_hlo", - srcs = ["hlo_to_mlir_hlo.cc"], hdrs = ["hlo_to_mlir_hlo.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo instead.", deps = [ - ":hlo_module_importer", - "//xla:status_macros", - "//xla/mlir/utils:error_util", - "//xla/service/llvm_ir:llvm_util", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:errors", + "//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", ], ) cc_library( name = "hlo_utils", - srcs = ["hlo_utils.cc"], hdrs = ["hlo_utils.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/translate/hlo_to_mhlo:hlo_utils instead.", includes = ["include"], deps = [ - "//xla:literal", - "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/mlir/utils:type_util", - "//xla/mlir_hlo", - "@com_google_absl//absl/status:statusor", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:SparseTensorDialect", - "@llvm-project//mlir:SparseTensorEnums", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "hlo_utils_test", - srcs = ["hlo_utils_test.cc"], - deps = [ - ":hlo_utils", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_util", - "//xla:test", - "//xla:types", - "//xla/tsl/lib/core:status_test_util", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "location_importer", - srcs = ["location_importer.cc"], - hdrs = ["location_importer.h"], - deps = [ - "stack_location_utils", - "//xla/hlo/ir:hlo", - "@llvm-project//mlir:IR", - ], -) - -cc_library( - name = "module_attributes_importer", - srcs = ["module_attributes_importer.cc"], - hdrs = ["module_attributes_importer.h"], - deps = [ - ":hlo_function_importer", - ":hlo_utils", - "//xla:shape_layout", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/mlir_hlo", - "//xla/service:computation_layout", - "//xla/service:hlo_module_config", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", + "//xla/hlo/translate/hlo_to_mhlo:hlo_utils", ], ) cc_library( name = "translate", - srcs = ["translate.cc"], hdrs = ["translate.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/translate/hlo_to_mhlo:translate instead.", deps = [ - ":hlo_to_mlir_hlo", - "//xla/mlir_hlo", - "//xla/mlir_hlo:mhlo_passes", - "//xla/service:hlo_parser", - "//xla/service:hlo_proto_cc", - "//xla/service/llvm_ir:llvm_util", - "@com_google_absl//absl/status", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@local_tsl//tsl/platform:protobuf", + "//xla/hlo/translate/hlo_to_mhlo:translate", ], ) cc_library( name = "translate_registration", testonly = True, - srcs = ["translate_registration.cc"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/translate/hlo_to_mhlo:translate_registration instead.", deps = [ - ":translate", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:TranslateLib", + "//xla/hlo/translate/hlo_to_mhlo:translate_registration", ], alwayslink = 1, ) diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.h b/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.h index ead3a3955fc79c..2b5f81982fd6d8 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.h +++ b/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.h @@ -16,89 +16,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_HLO_TO_MHLO_ATTRIBUTE_IMPORTER_H_ #define XLA_TRANSLATE_HLO_TO_MHLO_ATTRIBUTE_IMPORTER_H_ -#include -#include -#include -#include - -#include "absl/status/statusor.h" -#include "absl/types/span.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/service/hlo.pb.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Converts an XLA PrecisionConfig to the corresponding MLIR attribute. -mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config, - mlir::Builder* builder); - -// Converts the gather dimensions to attributes. -mlir::mhlo::GatherDimensionNumbersAttr ConvertGatherDimensionNumbers( - const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder); - -// Converts the scatter dimensions to attributes. -mlir::mhlo::ScatterDimensionNumbersAttr ConvertScatterDimensionNumbers( - const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder); - -// Converts the dot algorithm to attributes. -mlir::mhlo::DotAlgorithmAttr ConvertDotAlgorithm( - PrecisionConfig::Algorithm algorithm, mlir::Builder* builder); - -// Converts the dot dimensions to attributes. -mlir::mhlo::DotDimensionNumbersAttr ConvertDotDimensionNumbers( - const DotDimensionNumbers& dnums, mlir::Builder* builder); - -// Converts the conv dimensions to attributes. -mlir::mhlo::ConvDimensionNumbersAttr ConvertConvDimensionNumbers( - const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder); - -// Converts the output operand aliasing to attributes. -mlir::ArrayAttr ConvertOutputOperandAliasing( - const std::vector>>& aliaInfo, - mlir::Builder* builder); - -// Converts the sparsity descriptor to attributes. -absl::StatusOr ConvertSparsityDescriptor( - xla::SparsityDescriptor sparsity_descriptor, mlir::Builder* builder); - -absl::StatusOr ConvertFftType(FftType type); -absl::StatusOr ConvertTranspose( - TriangularSolveOptions_Transpose transpose); - -absl::StatusOr ConvertCustomCallApiVersion( - xla::CustomCallApiVersion api_version); - -mlir::NamedAttribute ConvertChannelHandle(const ChannelHandle& channel, - mlir::Builder* builder); -mlir::NamedAttribute ConvertChannelHandle(std::optional channel_id, - mlir::Builder* builder); - -mlir::NamedAttribute ConvertReplicaGroups( - absl::Span replica_groups, mlir::Builder* builder); - -mlir::NamedAttribute ConvertSourceTargetPairs( - const std::vector>& source_target_pairs, - mlir::Builder* builder); - -mlir::NamedAttribute ConvertUseGlobalDeviceIds(mlir::Builder* builder); - -// Extracts layouts from shapes and converts it into layout attributes (array of -// rank-1 index tensors). Returns an error if any of the shapes is a tuple. -absl::StatusOr ExtractLayoutsFromShapes( - const absl::Span shapes_with_layouts, mlir::Builder* builder); - -// Extracts the layouts of each element from a tuple shape and returns them as -// an array of rank-1 index tensors. Returns an error in presence of nested -// tuple shapes. -absl::StatusOr ExtractLayoutsFromTuple(const xla::Shape shape, - mlir::Builder* builder); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/hlo_to_mhlo/attribute_importer.h" #endif // XLA_TRANSLATE_HLO_TO_MHLO_ATTRIBUTE_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h index fa22a6d11f1086..0ebd37fa6af125 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h @@ -16,243 +16,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_HLO_TO_MHLO_HLO_FUNCTION_IMPORTER_H_ #define XLA_TRANSLATE_HLO_TO_MHLO_HLO_FUNCTION_IMPORTER_H_ -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/types/optional.h" -#include "absl/types/span.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Region.h" -#include "mlir/IR/SymbolTable.h" -#include "mlir/IR/Value.h" -#include "mlir/IR/ValueRange.h" -#include "xla/comparison_util.h" -#include "xla/hlo/ir/hlo_input_output_alias_config.h" -#include "xla/hlo/ir/hlo_sharding.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/service/hlo.pb.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -class HloModule; -class HloComputation; -class HloInstruction; -class Shape; - -// HLO bounded dynamic shapes can be converted to either MLIR dynamic shapes -// (which lose the bound information) or casted to static shape using the -// bounds. -enum class DynamicShapeHandlingMode { kDynamic, kConvertToStatic }; - -// Helper class for importing HloComputations. -class HloFunctionImporter { - public: - // Imports the given computation as a function in the given symbol table and - // returns the FuncOp. This also imports any computations referred by - // instructions in this computation. - static absl::StatusOr ImportAsFunc( - const HloComputation& computation, mlir::SymbolTable& symbol_table, - std::unordered_map* - function_map, - mlir::Builder* builder, bool is_main, - bool flatten_computation_args_result = false); - - // Imports the given hlo computation to the specified region. - // - // Flattens the tuple-typed region argument(s) and return value(s). - static absl::Status ImportAsRegion( - const HloComputation& computation, mlir::SymbolTable& symbol_table, - mlir::Region* region, mlir::Builder* builder, - bool flatten_computation_args_result = false); - - // Imports the given computation to the given place specified by `builder`. - // `arguments` contains values for all parameters. - static absl::StatusOr ImportInstructions( - const HloComputation& computation, - const llvm::SmallVectorImpl& arguments, - mlir::SymbolTable& symbol_table, mlir::OpBuilder* builder, - bool flatten_computation_args_result = false); - - static absl::StatusOr ImportInstruction( - const HloInstruction* instr, - const llvm::SmallVectorImpl& operands, - mlir::SymbolTable& symbol_table, mlir::OpBuilder* builder, - bool flatten_computation_args_result = false, - DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic); - - static void SetLayoutForMlir(mlir::Operation* op, const Shape& shape, - llvm::StringRef attr_name); - - // For mlir::IfOp or mlir::CaseOp, replace the uses of their region's block - // arguments with 'implicit_operands'. Here | implicit_operands | == sum of - // the number of arguments in all the regions in IfOp or CaseOp. - void ReplaceBlockArgumentsWithImplicitOperands( - mlir::Operation* op, llvm::ArrayRef implicit_operands); - - // FlattenTupleType flattens the types in (nested) tuple-type 'type' and - // stores them in 'flattened_types'. - static void FlattenTupleType( - mlir::Type type, llvm::SmallVectorImpl& flattened_types); - - // FlattenTupleValue flattens the values in (nested) tuple-typed 'value' and - // stores them in 'flattened_values'. - static void FlattenTupleValue( - mlir::OpBuilder* func_builder, mlir::Location loc, mlir::Value value, - llvm::SmallVectorImpl& flattened_values); - - // FlattenTupleValues flattens the values in (nested) tuple-typed 'values' and - // returns the flattened values. - static llvm::SmallVector FlattenTupleValues( - mlir::OpBuilder* func_builder, mlir::Location loc, - mlir::ValueRange values, std::optional reserve_size = std::nullopt); - - private: - HloFunctionImporter(mlir::SymbolTable& symbol_table, - std::unordered_map* function_map, - mlir::Builder* builder, - bool flatten_computation_args_result) - : context_(symbol_table.getOp()->getContext()), - symbol_table_(symbol_table), - builder_(builder), - function_map_(function_map), - flatten_computation_args_result_(flatten_computation_args_result) { - context_->loadDialect(); - context_->loadDialect(); - context_->loadDialect(); - context_->loadDialect(); - } - - // Imports the given computation as a new function, if it hasn't been already - // imported. - absl::StatusOr ImportAsFunc( - const HloComputation& computation, bool is_main); - - // Imports the given computation in the specified region. - absl::Status ImportAsRegion(const HloComputation& computation, - mlir::Region* region); - - // Imports instructions from the given computation in the specified block. - // Assumes that the block already has correct arguments populated. - absl::Status ImportInstructions(const HloComputation& computation, - mlir::Block* block); - absl::StatusOr ImportInstructionsImpl( - const HloComputation& computation, - const llvm::SmallVectorImpl& arguments, - mlir::OpBuilder* builder); - - // Imports an instruction. - absl::StatusOr ImportInstructionWithLayout( - const HloInstruction* instruction, - const llvm::SmallVectorImpl& operands, - mlir::OpBuilder* func_builder, - DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic); - - absl::StatusOr ImportInstructionImpl( - const HloInstruction* instruction, - const llvm::SmallVectorImpl& operands, - mlir::OpBuilder* func_builder, - DynamicShapeHandlingMode mode = DynamicShapeHandlingMode::kDynamic); - - // Gets the MLIR operand values from an HLO Instruction. - absl::StatusOr> GetOperands( - const HloInstruction* instruction); - - // Converts xla Tensor type to the corresponding MLIR type. - absl::StatusOr ConvertTensorType(const Shape& shape); - - // Converts an XLA shape/layout to the corresponding MLIR layout, in - // flattened_attr, while flattening the tuple layout. - absl::Status ConvertShapeToMlirLayout( - const Shape& shape, - llvm::SmallVectorImpl& flattened_attr); - - // Returns the output type of an HloInstruction. - absl::StatusOr GetReturnType(const HloInstruction* instruction); - - // Takes a list of HloInstructions and generates the list of types used for - // input, bypassing tuples to subsets. - absl::Status GetMlirTypes( - absl::Span instructions, - llvm::SmallVectorImpl* types); - - // Returns the Mlir Value for the corresponding HloInstruction. - absl::StatusOr GetMlirValue(const HloInstruction* instruction); - - // TODO(b/179166199): Move attribute converters to attribute_importer. - // Converts an XLA ComparisonDirection to the corresponding MLIR attribute. - mlir::NamedAttribute ConvertComparisonDirection( - ComparisonDirection direction); - - // Converts an XLA Comparison::Type to the corresponding MLIR attribute. - mlir::NamedAttribute ConvertComparisonType(Comparison::Type type); - - // Converts an XLA CustomCallSchedule to the corresponding MLIR attribute. - mlir::NamedAttribute ConvertCustomCallSchedule(CustomCallSchedule schedule); - - // Converts the dimensions of an HLO instruction into an MLIR attribute. - mlir::DenseIntElementsAttr ConvertDimensions( - absl::Span op_dimensions); - - // Converts Array ref to an DenseIntElementsAttr. - mlir::DenseIntElementsAttr Convert(llvm::ArrayRef elements); - - // Converts Array ref of bools to a DenseIntElementsAttr of I1 type. - mlir::DenseIntElementsAttr Convert(llvm::ArrayRef elements); - - // Converts Array ref to padding attribute. Input is a flattened list of - // padding low and padding high for each of the spatial dimensions. - mlir::NamedAttribute ConvertPadding(llvm::ArrayRef padding); - - mlir::MLIRContext* context_; - - // SymbolTable to which new functions should be inserted. - mlir::SymbolTable& symbol_table_; - - mlir::Builder* builder_; - - // Mapping from HloComputation to the created MLIR function. - std::unordered_map* function_map_; - - // Mapping from HloInstructions to the associative MLIR values. - std::unordered_map instruction_value_map_; - - bool flatten_computation_args_result_; -}; - -// Returns a StringAttr that carries a prettyprinted representation of the -// given HLO C++ input_output_alias_config. -// Always succeeds and returns a non-empty attribute. -mlir::Attribute ConvertInputOutputAlias(const HloInputOutputAliasConfig& alias, - mlir::Builder* builder); - -// Returns a StringAttr that carries a prettyprinted representation of the -// given HLO C++ sharding. -// Always succeeds and returns a non-empty attribute. -mlir::Attribute ConvertSharding(const HloSharding& sharding, - mlir::Builder* builder); - -// Returns a StringAttr that carries a prettyprinted representation of the -// given HLO proto sharding. -// Will fail and return an empty attribute if the proto sharding cannot be -// converted to the C++ sharding. -mlir::Attribute ConvertSharding(const OpSharding& sharding, - mlir::Builder* builder); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h" #endif // XLA_TRANSLATE_HLO_TO_MHLO_HLO_FUNCTION_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.h index 0cc1a39d8eb003..8577e86dc93839 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.h +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.h @@ -16,49 +16,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_HLO_TO_MHLO_HLO_MODULE_IMPORTER_H_ #define XLA_TRANSLATE_HLO_TO_MHLO_HLO_MODULE_IMPORTER_H_ -#include - -#include "absl/status/status.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/SymbolTable.h" -#include "xla/xla_data.pb.h" - -namespace xla { -class HloModule; -class HloModuleProto; -class HloComputation; -class HloInstruction; -class Shape; - -// Importer that takes an HloModule and imports it as an MLIR module in the XLA -// dialect. HloModuleImporter does not take ownership. -class HloModuleImporter { - public: - explicit HloModuleImporter(mlir::ModuleOp module, - bool import_all_computation = false, - bool flatten_computation_args_result = false); - - // Import the HloModule into the MLIR Module. - absl::Status Import(const xla::HloModule& module); - - // Import the HloModuleProto into the MLIR Module. - absl::Status Import(const xla::HloModuleProto& module); - - private: - bool import_all_computation_; - bool flatten_computation_args_result_; - mlir::SymbolTable symbol_table_; - mlir::Builder builder_; - - // Map for tracking which MLIR function map to which HLO Computation. This - // tracks functions as they are imported and provides a quick lookup for - // functions invoked by control flow related operations (e.g. while, call). - std::unordered_map - function_map_; -}; - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/hlo_to_mhlo/hlo_module_importer.h" #endif // XLA_TRANSLATE_HLO_TO_MHLO_HLO_MODULE_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h index 775d6367dc8fc9..4943ef790d35f1 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h @@ -16,56 +16,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_HLO_TO_MHLO_HLO_TO_MLIR_HLO_H_ #define XLA_TRANSLATE_HLO_TO_MHLO_HLO_TO_MLIR_HLO_H_ -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/OwningOpRef.h" - -namespace mlir { -class ModuleOp; -} // namespace mlir - -namespace xla { -class HloModule; -class HloModuleProto; - -// Converts an HLO module proto to a MLIR module in HLO dialect. -// -// If `import_all_computation` is set to true, imports all computations -// irrespective if transitively called from entry computation. -// -// If `flatten_computation_args_result` is set to true, flattens all tuple -// arguments and result of every computation when importing them as func ops. -absl::StatusOr> ConvertHloToMlirHlo( - mlir::MLIRContext& ctx, xla::HloModuleProto const* hlo_module, - bool import_all_computations = false, - bool flatten_computation_args_result = false); - -absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module, - xla::HloModuleProto const* hlo_module, - bool import_all_computations = false, - bool flatten_computation_args_result = false); - -// Converts an HLO module to a MLIR module in HLO dialect. -// -// If `import_all_computation` is set to true, imports all computations -// irrespective if transitively called from entry computation. -// -// If `flatten_computation_args_result` is set to true, flattens all tuple -// arguments and result of every computation when importing them as func ops. -absl::StatusOr> ConvertHloToMlirHlo( - mlir::MLIRContext& ctx, const xla::HloModule* hlo_module, - bool import_all_computations = false, - bool flatten_computation_args_result = false); - -absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module, - const xla::HloModule* hlo_module, - bool import_all_computations = false, - bool flatten_computation_args_result = false); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #endif // XLA_TRANSLATE_HLO_TO_MHLO_HLO_TO_MLIR_HLO_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h index bb9785d8242664..50e31028617463 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h @@ -18,232 +18,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_HLO_TO_MHLO_HLO_UTILS_H_ #define XLA_TRANSLATE_HLO_TO_MHLO_HLO_UTILS_H_ -#include -#include -#include -#include - -#include "absl/status/statusor.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/SparseTensor/IR/Enums.h" -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Value.h" -#include "mlir/IR/ValueRange.h" -#include "xla/layout.h" -#include "xla/layout_util.h" -#include "xla/literal.h" -#include "xla/mlir/utils/type_util.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/util.h" -#include "tsl/platform/statusor.h" - -namespace xla { - -absl::StatusOr CreateDenseElementsAttrFromLiteral( - const LiteralBase& literal, mlir::Builder builder); - -// Creates an DenseIntElementsAttr using the elements of the vector and the -// optional shape. -mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector( - const llvm::ArrayRef vector, mlir::Builder builder, - llvm::ArrayRef shape = {}); - -// Converts the given XLA shape for tensors to the template MLIR type. -template -static absl::StatusOr ConvertTensorShapeToType(const Shape& xla_ty, - mlir::Builder builder) { - auto element_type_or = - ConvertPrimitiveTypeToMlirType(xla_ty.element_type(), builder); - if (!element_type_or.ok()) return element_type_or.status(); - - bool is_bounded_dynamic = false; - int64_t rank = xla_ty.rank(); - llvm::SmallVector shape(rank, mlir::ShapedType::kDynamic); - llvm::SmallVector bounds(rank, mlir::ShapedType::kDynamic); - for (int64_t dim = 0; dim < rank; ++dim) { - int64_t dim_size = xla_ty.dimensions(dim); - if (xla_ty.is_dynamic_dimension(dim)) { - if (!xla_ty.is_unbounded_dynamic_dimension(dim)) { - bounds[dim] = dim_size; - is_bounded_dynamic = true; - } - } else { - shape[dim] = dim_size; - } - } - using mlir::mhlo::TypeExtensionsAttr; - mlir::Attribute encoding; - if (is_bounded_dynamic) { - encoding = TypeExtensionsAttr::get(builder.getContext(), bounds); - } - - using mlir::sparse_tensor::SparseTensorEncodingAttr; - // TODO(b/238903065): We don't yet support bounded dynamism shapes and - // sparsity at the same time, as we can currently only have one `encoding` on - // a RankedTensorType, and we don't currently have a meet of - // SparseTensorEncodingAttr and TypeExtensionsAttr (which holds bounds). - // - // For example, we wouldn't be able to represent the xla type - // `f32[4,<=4]{1,0:D(D,C)}`. - if (xla_ty.has_layout()) { - auto layout = xla_ty.layout(); - if (LayoutUtil::IsSparse(layout)) { - if (is_bounded_dynamic) - return Unimplemented( - "MHLO doesn't support bounded dynamic shapes for sparse tensors"); - llvm::SmallVector lts; - for (size_t i = 0, e = layout.dim_level_types_size(); i < e; ++i) { - auto dlt = layout.dim_level_type(i); - bool ordered = - i < layout.dim_ordered_size() ? layout.dim_ordered(i) : true; - bool unique = - i < layout.dim_unique_size() ? layout.dim_unique(i) : true; - switch (dlt) { - case DimLevelType::DIM_DENSE: - lts.push_back(*mlir::sparse_tensor::buildLevelType( - mlir::sparse_tensor::LevelFormat::Dense, ordered, unique)); - break; - case DimLevelType::DIM_COMPRESSED: - lts.push_back(*mlir::sparse_tensor::buildLevelType( - mlir::sparse_tensor::LevelFormat::Compressed, ordered, unique)); - break; - case DimLevelType::DIM_SINGLETON: - lts.push_back(*mlir::sparse_tensor::buildLevelType( - mlir::sparse_tensor::LevelFormat::Singleton, ordered, unique)); - break; - case DimLevelType::DIM_LOOSE_COMPRESSED: - lts.push_back(*mlir::sparse_tensor::buildLevelType( - mlir::sparse_tensor::LevelFormat::LooseCompressed, ordered, - unique)); - break; - default: - return InvalidArgument("Unknown DimLevelType from HLO"); - } - } - auto ordering = layout.minor_to_major(); - llvm::SmallVector major_to_minor = {ordering.rbegin(), - ordering.rend()}; - auto id_map = mlir::AffineMap::getPermutationMap(major_to_minor, - builder.getContext()); - // TODO(atondwal): support sizes other than 32 when XLA does - encoding = SparseTensorEncodingAttr::get( - builder.getContext(), lts, id_map, mlir::AffineMap(), 32, 32); - } - } - return TypeT::get(shape, element_type_or.value(), encoding); -} - -absl::StatusOr ConvertTensorShapeToMemRefType( - const Shape& shape, mlir::Builder builder); - -template <> -inline absl::StatusOr ConvertTensorShapeToType( - const Shape& shape, mlir::Builder builder) { - if (shape.is_dynamic()) { - return FailedPrecondition( // NOLINT - "MemRefType don't support dynamic shapes"); - } - return ConvertTensorShapeToMemRefType(shape, builder); -} - -// Converts the given XLA shape to the template MLIR type. -template -static absl::StatusOr ConvertShapeToType(const Shape& shape, - mlir::Builder builder) { - if (shape.IsTuple()) { - llvm::SmallVector contents; - contents.reserve(shape.tuple_shapes_size()); - for (const auto& subtype : shape.tuple_shapes()) { - TF_ASSIGN_OR_RETURN(auto mlir_subtype, - ConvertShapeToType(subtype, builder)); - contents.push_back(mlir_subtype); - } - return builder.getTupleType(contents); - } - if (shape.IsToken()) { - return mlir::mhlo::TokenType::get(builder.getContext()); - } - return ConvertTensorShapeToType(shape, builder); -} - -// CreateTupleValue creates a root TupleOp of (nested) tuple-type 'type' using -// the non-tuple-typed values in 'flatten_values'. -// -// e.g., Given 'flatten_values': [V1, V2, V3] &'type': tuple>, -// The function returns %t2 such that: -// %t1 = mhlo.tuple(V2,V3) : (T2,T3) -> tuple -// %t2 = mhlo.tuple(V1,%t1): (T1,tuple) -> tuple> -// -// Note: 1. FlattenTupleValue and CreateTupleValue is a pair of functions to -// resp. flatten and create tuples in the exact same order. -// 2. `flatten_values`, initially storing the flattened values, will be -// mutated to a 0-length array by the end of function invocation. -mlir::Value CreateTupleValue(mlir::OpBuilder* func_builder, mlir::Location loc, - mlir::ValueRange& flatten_values, mlir::Type type); - -// Create a TupleOp using the results of 'op' if 'type' is a mlir::TupleType. -// Otherwise, return 'op'. -mlir::Operation* CreateTupleFromOpResults(mlir::OpBuilder* func_builder, - mlir::Location loc, - mlir::Operation* op, mlir::Type type); - -mlir::TypeRange Untuple(const mlir::Type& type); - -static std::pair GetLayoutAttribute( - mlir::Builder& b, const Shape& shape, - std::optional maybe_layout = std::nullopt) { - if (shape.IsTuple()) { - llvm::SmallVector element_attrs; - llvm::SmallVector tile_attrs; - for (const auto& tuple_shape : shape.tuple_shapes()) { - // TODO here we do not dissect the layout of a tuple into sublayouts. - // Presently ShapeLayout cannot represent an explicit layout for a tuple - // type so this should never occur. However, if this function were to - // be used in another context where this assumption were to be lifted. - // users should be aware of this limitation which will use the default - // layout for tuple subshapes. - std::pair inner = - tuple_shape.has_layout() - ? GetLayoutAttribute(b, tuple_shape, tuple_shape.layout()) - : GetLayoutAttribute(b, tuple_shape); - element_attrs.push_back(inner.first); - tile_attrs.push_back(inner.second); - } - return std::make_pair((mlir::Attribute)b.getArrayAttr(element_attrs), - b.getArrayAttr(tile_attrs)); - } - - Layout layout = maybe_layout.value_or( - shape.has_layout() ? shape.layout() - : LayoutUtil::GetDefaultLayoutForShape(shape)); - - llvm::SmallVector vec_of_tiles; - for (const Tile& tile : layout.tiles()) { - llvm::SmallVector tile_vec = {tile.dimensions().begin(), - tile.dimensions().end()}; - vec_of_tiles.push_back(b.getIndexTensorAttr(tile_vec)); - } - llvm::SmallVector layout_vec = {layout.minor_to_major().begin(), - layout.minor_to_major().end()}; - return std::make_pair(b.getIndexTensorAttr(layout_vec), - b.getArrayAttr(vec_of_tiles)); -} - -static bool HasCustomLayout(const Shape& shape) { - if (shape.IsTuple()) { - return llvm::any_of(shape.tuple_shapes(), HasCustomLayout); - } - return shape.has_layout() && !shape.layout().minor_to_major().empty() && - shape.layout() != LayoutUtil::GetDefaultLayoutForShape(shape); -} - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" #endif // XLA_TRANSLATE_HLO_TO_MHLO_HLO_UTILS_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/translate.h b/third_party/xla/xla/translate/hlo_to_mhlo/translate.h index 2594aa17fc2d21..4ed0dc5c1ba216 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/translate.h +++ b/third_party/xla/xla/translate/hlo_to_mhlo/translate.h @@ -16,73 +16,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_HLO_TO_MHLO_TRANSLATE_H_ #define XLA_TRANSLATE_HLO_TO_MHLO_TRANSLATE_H_ -namespace llvm { -class StringRef; -} // namespace llvm - -namespace mlir { -class MLIRContext; -class ModuleOp; -template -class OwningOpRef; -} // namespace mlir - -namespace xla { - -// Converts a HloModuleProto stored in the file with the given `input_filename` -// into a MHLO module. Creates MLIR entities into the given MLIR `context`. -// -// If `import_all_computation` is set to true, imports all computations -// irrespective if transitively called from entry computation. -// -// If `flatten_computation_args_result` is set to true, flattens all tuple -// arguments and result of every computation when importing them as func ops. -mlir::OwningOpRef HloToMlirHloTranslateFunction( - llvm::StringRef input, mlir::MLIRContext* context, - bool import_all_computations = false, - bool flatten_computation_args_result = false); - -// Converts a HloModule stored in text form for a file with the given -// `input_filename` into a MHLO module. Creates MLIR entities into the given -// MLIR `context`. -// -// If `import_all_computation` is set to true, imports all computations -// irrespective if transitively called from entry computation. -// -// If `flatten_computation_args_result` is set to true, flattens all tuple -// arguments and result of every computation when importing them as func ops. -mlir::OwningOpRef HloTextToMlirHloTranslateFunction( - llvm::StringRef input, mlir::MLIRContext* context, - bool import_all_computations = false, - bool flatten_computation_args_result = false); - -// Converts a HloModuleProto stored in the file with the given `input_filename` -// into a StableHLO module. Creates MLIR entities into the given MLIR `context`. -// -// If `import_all_computation` is set to true, imports all computations -// irrespective if transitively called from entry computation. -// -// If `flatten_computation_args_result` is set to true, flattens all tuple -// arguments and result of every computation when importing them as func ops. -mlir::OwningOpRef HloToStablehloTranslateFunction( - llvm::StringRef input, mlir::MLIRContext* context, - bool import_all_computations = false, - bool flatten_computation_args_result = false); - -// Converts a HloModule stored in text form for a file with the given -// `input_filename` into a StableHLO module. Creates MLIR entities into the -// given MLIR `context`. -// -// If `import_all_computation` is set to true, imports all computations -// irrespective if transitively called from entry computation. -// -// If `flatten_computation_args_result` is set to true, flattens all tuple -// arguments and result of every computation when importing them as func ops. -mlir::OwningOpRef HloTextToStablehloTranslateFunction( - llvm::StringRef input, mlir::MLIRContext* context, - bool import_all_computations = false, - bool flatten_computation_args_result = false); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/hlo_to_mhlo/translate.h" #endif // XLA_TRANSLATE_HLO_TO_MHLO_TRANSLATE_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/BUILD b/third_party/xla/xla/translate/mhlo_to_hlo/BUILD index b486d47233cfb3..e63ee872db8869 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/BUILD +++ b/third_party/xla/xla/translate/mhlo_to_hlo/BUILD @@ -1,9 +1,5 @@ -load("@bazel_skylib//rules:build_test.bzl", "build_test") -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") -load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_binary", "cc_library") -load("//xla:xla.bzl", "xla_cc_test") +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load("//xla/tsl:tsl.bzl", "internal_visibility") -load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -16,280 +12,76 @@ package( cc_library( name = "attribute_exporter", - srcs = ["attribute_exporter.cc"], hdrs = ["attribute_exporter.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:attribute_exporter instead.", deps = [ - "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/mlir_hlo", - "//xla/service:hlo_parser", - "//xla/service:hlo_proto_cc", - "//xla/stream_executor:dnn", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - "@stablehlo//:base", + "//xla/hlo/translate/mhlo_to_hlo:attribute_exporter", ], ) cc_library( name = "layout_util", - srcs = ["layout_util.cc"], hdrs = ["layout_util.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:layout_util instead.", deps = [ - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/hlo/ir:hlo", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/translate/mhlo_to_hlo:layout_util", ], ) cc_library( name = "location_exporter", - srcs = ["location_exporter.cc"], hdrs = ["location_exporter.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:location_exporter instead.", deps = [ - ":stack_frame_index_builder", - "//xla:xla_data_proto_cc", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", + "//xla/hlo/translate/mhlo_to_hlo:location_exporter", ], ) -cc_library( +alias( name = "module_attributes_exporter", - srcs = ["module_attributes_exporter.cc"], - hdrs = ["module_attributes_exporter.h"], - deps = [ - "//xla:xla_data_proto_cc", - "//xla/service:hlo_module_config", - "//xla/service:hlo_proto_cc", - "@com_google_absl//absl/status", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - ], + actual = "//xla/hlo/translate/mhlo_to_hlo:module_attributes_exporter", + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:module_attributes_exporter instead.", ) -cc_library( +alias( name = "stack_frame_index_builder", - srcs = ["stack_frame_index_builder.cc"], - hdrs = ["stack_frame_index_builder.h"], - deps = [ - "//xla/service:hlo_proto_cc", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - ], + actual = "//xla/hlo/translate/mhlo_to_hlo:stack_frame_index_builder", + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:stack_frame_index_builder instead.", ) cc_library( name = "mlir_hlo_to_hlo", - srcs = [ - "mlir_hlo_to_hlo.cc", - "operator_writers.inc", - ], hdrs = ["mlir_hlo_to_hlo.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo instead.", deps = [ - ":attribute_exporter", - ":layout_util", - ":location_exporter", - ":module_attributes_exporter", - ":operator_writer_inc", - ":stack_frame_index_builder", - ":type_to_shape", - "//xla:array", - "//xla:comparison_util", - "//xla:debug_options_flags", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:xla_data_proto_cc", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/client/lib:approx_topk", - "//xla/client/lib:approx_topk_shape", - "//xla/client/lib:matrix", - "//xla/client/lib:quantize", - "//xla/client/lib:slicing", - "//xla/hlo/ir:hlo", - "//xla/mlir/utils:error_util", - "//xla/mlir/utils:type_util", - "//xla/mlir_hlo", - "//xla/mlir_hlo:mhlo_passes", - "//xla/service:computation_layout", - "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", - "//xla/service:hlo_proto_cc", - "//xla/service/gpu:backend_configs_cc", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TransformUtils", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:ml_dtypes", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:types", - "@stablehlo//:base", - "@stablehlo//:stablehlo_ops", - ], -) - -build_test( - name = "operator_writer_gen_build_test", - targets = [ - ":operator_writer_gen", - ], -) - -cc_binary( - name = "operator_writer_gen", - srcs = ["operator_writer_gen.cc"], - deps = [ - "@llvm-project//llvm:Support", - "@llvm-project//llvm:TableGen", - "@llvm-project//mlir:TableGen", - ], -) - -gentbl_cc_library( - name = "operator_writer_inc", - compatible_with = get_compatible_with_portable(), - tbl_outs = [([], "operator_writers.inc")], - tblgen = ":operator_writer_gen", - td_file = "//xla/mlir_hlo:mhlo/IR/hlo_ops.td", - deps = [ - "//xla/mlir_hlo:hlo_ops_td_files", - "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", - "@llvm-project//mlir:OpBaseTdFiles", - "@llvm-project//mlir:SideEffectInterfacesTdFiles", - ], -) - -xla_cc_test( - name = "mlir_hlo_to_hlo_test", - srcs = ["mlir_hlo_to_hlo_test.cc"], - deps = [ - ":mlir_hlo_to_hlo", - "//xla/mlir/utils:error_util", - "//xla/tsl/lib/core:status_test_util", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:ShapeDialect", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - "@stablehlo//:register", + "//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo", ], ) cc_library( name = "translate", - srcs = ["translate.cc"], hdrs = ["translate.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:translate instead.", deps = [ - ":mlir_hlo_to_hlo", - ":type_to_shape", - "//xla:debug_options_flags", - "//xla:shape_util", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/hlo/ir:hlo", - "//xla/mlir_hlo:hlo_dialect_registration", - "//xla/service:hlo_module_config", - "//xla/service:hlo_proto_cc", - "//xla/service:hlo_proto_util", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Support", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", + "//xla/hlo/translate/mhlo_to_hlo:translate", ], ) cc_library( name = "translate_registration", testonly = True, - srcs = [ - "translate_registration.cc", - "translate_registration.h", - ], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:translate_registration instead.", deps = [ - ":translate", - "//xla/mlir_hlo:hlo_dialect_registration", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TranslateLib", + "//xla/hlo/translate/mhlo_to_hlo:translate_registration", ], alwayslink = 1, ) cc_library( name = "type_to_shape", - srcs = ["type_to_shape.cc"], hdrs = ["type_to_shape.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/translate/mhlo_to_hlo:type_to_shape instead.", deps = [ - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/mlir/utils:type_util", - "//xla/mlir_hlo", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:SparseTensorDialect", - "@llvm-project//mlir:SparseTensorEnums", - "@llvm-project//mlir:Support", - "@stablehlo//:stablehlo_ops", - ], -) - -xla_cc_test( - name = "type_to_shape_test", - srcs = ["type_to_shape_test.cc"], - deps = [ - ":type_to_shape", - "//xla:shape_util", - "//xla:test", - "//xla:xla_data_proto_cc", - "//xla/mlir_hlo", - "//xla/translate/hlo_to_mhlo:hlo_utils", - "@com_google_absl//absl/status:statusor", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:test_main", + "//xla/hlo/translate/mhlo_to_hlo:type_to_shape", ], ) diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.h b/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.h index 97c20f112af1ba..2caf77bf3a3d2a 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.h +++ b/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.h @@ -16,60 +16,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_MHLO_TO_HLO_ATTRIBUTE_EXPORTER_H_ #define XLA_TRANSLATE_MHLO_TO_HLO_ATTRIBUTE_EXPORTER_H_ -#include +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/mhlo_to_hlo/attribute_exporter.h" -#include "absl/status/statusor.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/StringRef.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/Support/LLVM.h" -#include "xla/hlo/ir/hlo_input_output_alias_config.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/service/hlo.pb.h" -#include "xla/shape_util.h" -#include "xla/stream_executor/dnn.h" -#include "xla/types.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Converts the conv dimensions attribute to XLA HLO. -ConvolutionDimensionNumbers ConvertConvDimensionNumbers( - mlir::mhlo::ConvDimensionNumbersAttr input); - -// Converts the dot algorithm attribute to XLA HLO. -absl::StatusOr ConvertDotAlgorithm( - mlir::mhlo::DotAlgorithmAttr attr); - -absl::StatusOr> ConvertReplicaGroups( - mlir::DenseIntElementsAttr input); - -// Convert a (N, 2) dense attribute to a list of tuples. This is the way padding -// and source-target pairs are defined in HLO. -absl::StatusOr>> ConvertNx2Attribute( - std::optional optional_attr); - -absl::StatusOr ConvertTranspose( - llvm::StringRef transpose_string); - -absl::StatusOr ConvertCustomCallSchedule( - mlir::mhlo::CustomCallSchedule schedule); - -absl::StatusOr ConvertCustomCallApiVersion( - mlir::mhlo::CustomCallApiVersion api_version); - -absl::StatusOr< - std::vector>>> -ConvertOutputOperandAliasing(mlir::ArrayAttr aliasArrayAttr); - -// Returns an OpSharding that represents the result of parsing the given string: -// first, as serialized protobuf, and then as prettyprinted representation. -// Will fail if both attempts at parsing failed. -std::optional ConvertSharding(mlir::StringRef sharding); - -std::optional ConvertInputOutputAlias( - llvm::ArrayRef aliasing); - -} // namespace xla #endif // XLA_TRANSLATE_MHLO_TO_HLO_ATTRIBUTE_EXPORTER_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h b/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h index 2ecd4e3ef3ba3d..6005d23d69e910 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h +++ b/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h @@ -18,68 +18,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_MHLO_TO_HLO_LAYOUT_UTIL_H_ #define XLA_TRANSLATE_MHLO_TO_HLO_LAYOUT_UTIL_H_ -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "xla/client/xla_builder.h" -#include "xla/hlo/ir/hlo_sharding.h" -#include "xla/shape.h" -#include "xla/xla_data.pb.h" - -namespace mlir { - -// XLA Layout preferences. Currently, when it comes to TPU, there are two -// primary layout choices for any XLA arguments (parameter or resource): (1) -// CompactChunkPadded and (2) Linear. CompactChunkPadded is the native TPU -// layout while Linear is native host (CPU) layout. -// This enum allows the caller of XLA to propagate layout preference to the XLA -// compiler. -// kNoPreference: the generic layout where the XLA compiler has the freedom -// to assign any layout. -// kTpuPreferCompactChunkPaddedLayout: use native TPU layout on TPU. -// kTpuPreferLinearLayout: use native CPU layout on TPU. The compiler may -// insert transformation TPU kernels. -// As the layout of any argument will change from a native host layout to a -// native TPU layout either on host or on device, XLA compiler and TPU runtime -// must be in coordination to transform the parameters in a consistent way. -enum class XlaLayoutPreference { - kNoPreference = 0, - kTpuPreferCompactChunkPaddedLayout = 1, - kTpuPreferLinearLayout = 2 -}; - -// The following defines the layout preference of an xla tensor. -// The return value of LayoutPreferenceFn can be used in -// ShapeRepresentationFn. -typedef std::function( - const xla::Shape& shape)> - LayoutPreferenceFn; - -typedef std::function( - const xla::Shape& shape, bool fast_mem, - XlaLayoutPreference layout_preference)> - ShapeRepresentationFn; - -// Return a LayoutPreferenceFn that always uses kNoPreference layout. -LayoutPreferenceFn UseNoPreferenceLayoutFn(); - -// Rewrites the layout of xla_shape if there is tiled sharding. -absl::Status RewriteLayoutWithShardedShape( - const std::optional& sharding, bool use_fast_memory, - const LayoutPreferenceFn& layout_preference_fn, - const ShapeRepresentationFn& shape_representation_fn, - xla::Shape* xla_shape); - -// Adds reshapes to fix the layout of an output, if a shape_representation_fn or -// sharding is present. -absl::StatusOr ReshapeWithCorrectRepresentationAndSharding( - xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, - const LayoutPreferenceFn& layout_preference_fn, - const ShapeRepresentationFn& shape_representation_fn, - std::optional sharding, bool fast_mem); - -} // namespace mlir +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/mhlo_to_hlo/layout_util.h" #endif // XLA_TRANSLATE_MHLO_TO_HLO_LAYOUT_UTIL_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/location_exporter.h b/third_party/xla/xla/translate/mhlo_to_hlo/location_exporter.h index d7ec94c1622918..b5c43ce49c481a 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/location_exporter.h +++ b/third_party/xla/xla/translate/mhlo_to_hlo/location_exporter.h @@ -16,29 +16,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_MHLO_TO_HLO_LOCATION_EXPORTER_H_ #define XLA_TRANSLATE_MHLO_TO_HLO_LOCATION_EXPORTER_H_ -#include - -#include "mlir/IR/Location.h" -#include "mlir/IR/Operation.h" -#include "xla/translate/mhlo_to_hlo/stack_frame_index_builder.h" -#include "xla/xla_data.pb.h" - -namespace mlir { -namespace mhlo { - -// Returns a OpMetadata proto based on the location of the op. If the location -// is unknown, an empty proto is returned. `op_name` are populated with the op -// location (converted). FileLineColLoc locations are populated by taking the -// file name and line number, and populating `source_file` and `source_line` -// respectively. -xla::OpMetadata CreateOpMetadataFromLocation( - Operation* op, StackFrameIndexBuilder* frame_index_builder); - -// Returns a name that can be used for debugging purposes, e.g., naming -// variable names in generated IR or producing logging output. -std::string GetDebugNameFromLocation(Location location); - -} // namespace mhlo -} // namespace mlir +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/mhlo_to_hlo/location_exporter.h" #endif // XLA_TRANSLATE_MHLO_TO_HLO_LOCATION_EXPORTER_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h index d55bfc15cf653d..1544b99e069571 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h +++ b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h @@ -16,81 +16,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_MHLO_TO_HLO_MLIR_HLO_TO_HLO_H_ #define XLA_TRANSLATE_MHLO_TO_HLO_MLIR_HLO_TO_HLO_H_ -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "llvm/ADT/ArrayRef.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/BuiltinOps.h" -#include "xla/client/xla_builder.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo.pb.h" -#include "xla/service/hlo_module_config.h" -#include "xla/translate/mhlo_to_hlo/layout_util.h" - -namespace mlir { - -struct MlirToHloConversionOptions { - // Best-effort propagation of the layouts. These layouts serve as performance - // hints to the backend. - // - // Note that non-array shapes are not carrying layouts, and users have to - // figure out the proper layouts of them through context. This is one of the - // reasons why the attribute-based solution is temporary. - // - // TODO(timshen): Investigate the necessity of having layouts in MHLO. - bool propagate_layouts = false; - - // Propagate the source and result layouts from mhlo bitcast op into the - // backend config for the bitcast. This is required for XLA:GPU backend to - // use elemental IR emitters for fused bitcasts without propagating layouts. - bool propagate_bitcast_layouts_to_backend_config = false; - - LayoutPreferenceFn layout_preference_fn; - ShapeRepresentationFn shape_representation_fn; - - // If use_tuple_args is set, then the entry computations's arguments are - // converted to a tuple and passed as a single parameter. - bool use_tuple_args = false; - - // If return tuple is true, then the entry function's return values - // are converted to a tuple even when there is only a single return value. - // Multiple return values are always converted to a tuple and returned as a - // single value. - bool return_tuple = true; -}; - -// Prefer `ConvertMlirHloToHloModule` over this method when possible, as it -// preserves more information and abstracts away the proto. This method is -// preserved for legacy reasons. -// TODO (b/345806521): Migrate callsites to ConvertMlirHloToHloModule, -// and delete this method. -// -// Converts a MLIR module in HLO dialect into a HloModuleProto. -// -absl::Status ConvertMlirHloToHlo(mlir::ModuleOp module, - ::xla::HloProto* hlo_proto, - bool use_tuple_args, bool return_tuple, - MlirToHloConversionOptions options = {}); - -// Converts a MLIR module in HLO dialect into a HloModule with HloModuleConfig. -// This method preserves config data stored in MHLO module attributes. -// -// See `MlirToHloConversionOptions` for details on conversion flags. -absl::StatusOr> ConvertMlirHloToHloModule( - mlir::ModuleOp module, MlirToHloConversionOptions options = {}); - -// Transforms a Block into HLO, where the HLO is represented as calls into an -// XlaBuilder. Callee functions are allowed in the Block's ancestor ModuleOp. -// xla_params are inputs to block. returns are the returned XlaOps. -absl::Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder, - llvm::ArrayRef xla_params, - std::vector& returns, - MlirToHloConversionOptions options = {}); - -} // namespace mlir +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #endif // XLA_TRANSLATE_MHLO_TO_HLO_MLIR_HLO_TO_HLO_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/composite.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/composite.mlir deleted file mode 100644 index 60c5548587a677..00000000000000 --- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/composite.mlir +++ /dev/null @@ -1,190 +0,0 @@ -// RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s - -module @composite { - // CHECK: HloModule composite, entry_computation_layout={()->f32[]} - // CHECK: %add.2 (Arg_0.3: f32[]) -> f32[] { - // CHECK: %Arg_0.3 = f32[] parameter(0) - // CHECK: %constant.4 = f32[] constant(2) - // CHECK: ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4) - // CHECK: } - // CHECK: ENTRY %main.7 () -> f32[] { - // CHECK: %constant.1 = f32[] constant(42) - // CHECK: ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="1"} - // CHECK: } - func.func @main() -> tensor { - %0 = mhlo.constant dense<4.200000e+01> : tensor - %1 = mhlo.composite "foo.bar" %0 { - composite_attributes = { - n = 1 : i32, - tensor = dense<1> : tensor - }, - decomposition = @add, - version = 1 : i32 - } : (tensor) -> tensor - return %1 : tensor - } - func.func @add(%arg0: tensor) -> tensor { - %0 = mhlo.constant dense<2.000000e+00> : tensor - %1 = mhlo.add %arg0, %0 : tensor - return %1 : tensor - } -} - -// ----- - -// zero-output composite -module @composite { - //CHECK: HloModule composite, entry_computation_layout={()->()} - //CHECK: %return.2 (Arg_0.3: f32[]) -> () { - //CHECK: %Arg_0.3 = f32[] parameter(0) - //CHECK: ROOT %tuple.4 = () tuple() - //CHECK: } - //CHECK: ENTRY %main.7 () -> () { - //CHECK: %constant.1 = f32[] constant(42) - //CHECK: %call.5 = () call(f32[] %constant.1), to_apply=%return.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="1"} - //CHECK: ROOT %tuple.6 = () tuple() - //CHECK: } - func.func @main() -> () { - %0 = mhlo.constant dense<4.200000e+01> : tensor - "mhlo.composite"(%0) { - name = "foo.bar", - composite_attributes = { - n = 1 : i32, - tensor = dense<1> : tensor - }, - decomposition = @return, - version = 1 : i32 - } : (tensor) -> () - return - } - func.func @return(%arg0: tensor) -> () { - return - } -} - -// ----- - -// multi-output composite -module @composite { - //CHECK: HloModule composite, entry_computation_layout={()->(f32[], f32[])} - //CHECK: %add.2 (Arg_0.3: f32[]) -> (f32[], f32[]) { - //CHECK: %Arg_0.3 = f32[] parameter(0) - //CHECK: %constant.4 = f32[] constant(2) - //CHECK: %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4) - //CHECK: ROOT %tuple.6 = (f32[], f32[]) tuple(f32[] %add.5, f32[] %add.5) - //CHECK: } - //CHECK: ENTRY %main.11 () -> (f32[], f32[]) { - //CHECK: %constant.1 = f32[] constant(42) - //CHECK: %call.7 = (f32[], f32[]) call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="1"} - //CHECK: %get-tuple-element.8 = f32[] get-tuple-element((f32[], f32[]) %call.7), index=0 - //CHECK: %get-tuple-element.9 = f32[] get-tuple-element((f32[], f32[]) %call.7), index=1 - //CHECK: ROOT %tuple.10 = (f32[], f32[]) tuple(f32[] %get-tuple-element.8, f32[] %get-tuple-element.9) - //CHECK: } - func.func @main() -> (tensor, tensor) { - %0 = mhlo.constant dense<4.200000e+01> : tensor - %result:2 = "mhlo.composite"(%0) { - name = "foo.bar", - composite_attributes = { - n = 1 : i32, - tensor = dense<1> : tensor - }, - decomposition = @add, - version = 1 : i32 - } : (tensor) -> (tensor, tensor) - return %result#0, %result#1 : tensor, tensor - } - func.func @add(%arg0: tensor) -> (tensor, tensor) { - %0 = mhlo.constant dense<2.000000e+00> : tensor - %1 = mhlo.add %arg0, %0 : tensor - return %1, %1 : tensor, tensor - } -} - -// ----- - -// optional composite attributes -module @composite { - // CHECK: HloModule composite, entry_computation_layout={()->f32[]} - // CHECK: %add.2 (Arg_0.3: f32[]) -> f32[] { - // CHECK: %Arg_0.3 = f32[] parameter(0) - // CHECK: %constant.4 = f32[] constant(2) - // CHECK: ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4) - // CHECK: } - // CHECK: ENTRY %main.7 () -> f32[] { - // CHECK: %constant.1 = f32[] constant(42) - // CHECK: ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={},composite.name="foo.bar",composite.version="1"} - // CHECK: } - func.func @main() -> tensor { - %0 = mhlo.constant dense<4.200000e+01> : tensor - %1 = mhlo.composite "foo.bar" %0 { - decomposition = @add, - version = 1 : i32 - } : (tensor) -> tensor - return %1 : tensor - } - func.func @add(%arg0: tensor) -> tensor { - %0 = mhlo.constant dense<2.000000e+00> : tensor - %1 = mhlo.add %arg0, %0 : tensor - return %1 : tensor - } -} - -// ----- - -// optional composite version -module @composite { - // CHECK: HloModule composite, entry_computation_layout={()->f32[]} - // CHECK: %add.2 (Arg_0.3: f32[]) -> f32[] { - // CHECK: %Arg_0.3 = f32[] parameter(0) - // CHECK: %constant.4 = f32[] constant(2) - // CHECK: ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4) - // CHECK: } - // CHECK: ENTRY %main.7 () -> f32[] { - // CHECK: %constant.1 = f32[] constant(42) - // CHECK: ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="0"} - // CHECK: } - func.func @main() -> tensor { - %0 = mhlo.constant dense<4.200000e+01> : tensor - %1 = mhlo.composite "foo.bar" %0 { - composite_attributes = { - n = 1 : i32, - tensor = dense<1> : tensor - }, - decomposition = @add - } : (tensor) -> tensor - return %1 : tensor - } - func.func @add(%arg0: tensor) -> tensor { - %0 = mhlo.constant dense<2.000000e+00> : tensor - %1 = mhlo.add %arg0, %0 : tensor - return %1 : tensor - } -} - -// ----- - -// optional composite attributes and version -module @composite { - // CHECK: HloModule composite, entry_computation_layout={()->f32[]} - // CHECK: %add.2 (Arg_0.3: f32[]) -> f32[] { - // CHECK: %Arg_0.3 = f32[] parameter(0) - // CHECK: %constant.4 = f32[] constant(2) - // CHECK: ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4) - // CHECK: } - // CHECK: ENTRY %main.7 () -> f32[] { - // CHECK: %constant.1 = f32[] constant(42) - // CHECK: ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={},composite.name="foo.bar",composite.version="0"} - // CHECK: } - func.func @main() -> tensor { - %0 = mhlo.constant dense<4.200000e+01> : tensor - %1 = mhlo.composite "foo.bar" %0 { - decomposition = @add - } : (tensor) -> tensor - return %1 : tensor - } - func.func @add(%arg0: tensor) -> tensor { - %0 = mhlo.constant dense<2.000000e+00> : tensor - %1 = mhlo.add %arg0, %0 : tensor - return %1 : tensor - } -} diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/sharding.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/sharding.mlir deleted file mode 100644 index 295c95075c9fc7..00000000000000 --- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/sharding.mlir +++ /dev/null @@ -1,362 +0,0 @@ -// RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s - -// CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: f32[], Arg_1.2: f32[4]) -> f32[4,4] -func.func public @main(%arg0: tensor {mhlo.sharding = ""}, %arg1: tensor<4xf32> {mhlo.sharding = "\08\03\1A\01\02\22\02\00\01"}) -> (tensor<4x4xf32> {mhlo.sharding = "\08\03\1A\02\02\01\22\02\00\01"}) { - // CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1), sharding={devices=[2]0,1} - // CHECK-NEXT: %Arg_0.1 = f32[] parameter(0), sharding={replicated} - %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4xf32> - %1 = mhlo.multiply %arg1, %0 : tensor<4xf32> - %2 = "mhlo.broadcast_in_dim"(%1) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<4xf32>) -> tensor<4x4xf32> - // CHECK: ROOT {{.*}}, sharding={devices=[2,1]0,1} - func.return %2 : tensor<4x4xf32> -} - -// ----- - -// CHECK-LABEL: ENTRY %main.{{.*}} ({{[^,]*}}: f32[5,8,128]) -> f32[5,8,128] -func.func @main(%arg0: tensor<5x8x128xf32> {mhlo.sharding = "\08\03\1A\03\01\02\01\22\02\00\01"}) -> (tensor<5x8x128xf32> {mhlo.sharding = "\08\03\1A\03\01\02\01\22\02\00\01"}) { - // CHECK-NEXT: %Arg_0.1 = f32[5,8,128] parameter(0), sharding={devices=[1,2,1]0,1} - // CHECK-NEXT: %custom-call.2 = f32[5,8,128] custom-call(f32[5,8,128] %Arg_0.1), custom_call_target="Sharding", sharding={devices=[1,2,1]0,1} - // CHECK-NEXT: %tuple.3 = (f32[5,8,128]) tuple(f32[5,8,128] %custom-call.2) - // CHECK-NEXT: ROOT %get-tuple-element.4 = f32[5,8,128] get-tuple-element((f32[5,8,128]) %tuple.3), index=0 - // CHECK-SAME: sharding={devices=[1,2,1]0,1} - %0 = "mhlo.custom_call"(%arg0) {call_target_name = "Sharding", - mhlo.sharding = "\08\03\1A\03\01\02\01\22\02\00\01" - } : (tensor<5x8x128xf32>) -> tensor<5x8x128xf32> - func.return %0 : tensor<5x8x128xf32> -} - -// ----- - -// CHECK-LABEL: ENTRY %main.{{.*}} ({{[^,]*}}: f32[4,4]) -> (f32[4,4], f32[4,4]) -func.func @main(%arg0: tensor<4x4xf32>) -> (tensor<4x4xf32> {mhlo.sharding = "\08\03\1A\03\02\01\02\22\04\00\01\02\03B\01\00"}, tensor<4x4xf32>) { - // CHECK-NEXT: %Arg_0.1 = f32[4,4] parameter(0) - // CHECK-NEXT: [[RESHAPE_0:%.*]] = f32[4,4] reshape(f32[4,4] %Arg_0.1), sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} - // CHECK-NEXT: [[RESHAPE_1:%.*]] = f32[4,4] reshape(f32[4,4] %Arg_0.1) - // CHECK-NOT: sharding - // CHECK-NEXT: ROOT {{%.*}} = (f32[4,4], f32[4,4]) tuple(f32[4,4] [[RESHAPE_0]], f32[4,4] [[RESHAPE_1]]) - // CHECK-SAME: sharding={{\{}}{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}, {replicated}} - return %arg0, %arg0 : tensor<4x4xf32>, tensor<4x4xf32> -} - -// ----- - -// CHECK-LABEL: ENTRY %main.{{.*}} () -> f32[4] -func.func @main() -> (tensor<4xf32>) { - // CHECK-NEXT: %constant.1 = f32[] constant(3.1415925) - // CHECK-NEXT: %broadcast.2 = f32[4] broadcast(f32[] %constant.1), dimensions={}, sharding={devices=[2]0,1} - // CHECK-NEXT: ROOT %add.3 = f32[4] add(f32[4] %broadcast.2, f32[4] %broadcast.2) - %0 = mhlo.constant {mhlo.sharding = "{devices=[2]0,1}"} dense<3.1415926> : tensor<4xf32> - %1 = mhlo.add %0, %0 : tensor<4xf32> - return %1 : tensor<4xf32> -} - -// ----- - -// CHECK-LABEL: ENTRY %main.{{.*}} () -> f32[12,24,36] -func.func @main() -> (tensor<12x24x36xf32>) { - // CHECK-NEXT: %constant.1 = f32[] constant(3.1415925) - // CHECK-NEXT: %broadcast.2 = f32[12,24,36] broadcast(f32[] %constant.1), dimensions={}, sharding={devices=[1,2,1]0,1} - // CHECK-NEXT: ROOT %add.3 = f32[12,24,36] add(f32[12,24,36] %broadcast.2, f32[12,24,36] %broadcast.2) - %0 = mhlo.constant {mhlo.sharding = "{devices=[1,2,1]0,1}"} dense<3.1415926> : tensor<12x24x36xf32> - %1 = mhlo.add %0, %0 : tensor<12x24x36xf32> - return %1 : tensor<12x24x36xf32> -} - -// ----- - -// CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: u64[2]) -> (u64[2], u32[512,4]) -func.func @main(%arg0: tensor<2xui64>) -> (tensor<2xui64> {mhlo.sharding = "{devices=[2,16]<=[32] last_tile_dim_replicate}"}, tensor<512x4xui32> {mhlo.sharding = "{devices=[4,8]<=[32]}"}) { - // CHECK-NEXT: %Arg_0.1 = u64[2] parameter(0) - // CHECK-NEXT: %rng-bit-generator.2 = (u64[2], u32[512,4]) rng-bit-generator(u64[2] %Arg_0.1), algorithm=rng_default, sharding={{\{}}{replicated}, {devices=[8,4]<=[32]}} - // CHECK-NEXT: %get-tuple-element.3 = u64[2] get-tuple-element((u64[2], u32[512,4]) %rng-bit-generator.2), index=0, sharding={replicated} - // CHECK-NEXT: %add.5 = u64[2] add(u64[2] %get-tuple-element.3, u64[2] %get-tuple-element.3) - // CHECK-NEXT: %reshape.6 = u64[2] reshape(u64[2] %add.5) - // CHECK-NEXT: %get-tuple-element.4 = u32[512,4] get-tuple-element((u64[2], u32[512,4]) %rng-bit-generator.2), index=1, sharding={devices=[8,4]<=[32]} - // CHECK-NEXT: %reshape.7 = u32[512,4] reshape(u32[512,4] %get-tuple-element.4) - // CHECK-NEXT: ROOT %tuple.8 = (u64[2], u32[512,4]) tuple(u64[2] %reshape.6, u32[512,4] %reshape.7), sharding={{\{}}{devices=[2,16]<=[32] last_tile_dim_replicate}, {devices=[4,8]<=[32]}} - %output_state, %output = "mhlo.rng_bit_generator"(%arg0) <{rng_algorithm = #mhlo.rng_algorithm}> {mhlo.sharding = "{{replicated}, {devices=[8,4]<=[32]}}"} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<512x4xui32>) - %0 = mhlo.add %output_state, %output_state : tensor<2xui64> - return %0, %output : tensor<2xui64>, tensor<512x4xui32> -} - -// ----- - -// CHECK-LABEL: ENTRY %main.{{.*}} (Arg_0.1: u64[2]) -> (u64[2], u32[512,4]) -func.func @main(%arg0: tensor<2xui64>) -> (tensor<2xui64>, tensor<512x4xui32>) { - // CHECK-NEXT: %Arg_0.1 = u64[2] parameter(0) - // CHECK-NEXT: %rng-bit-generator.2 = (u64[2], u32[512,4]) rng-bit-generator(u64[2] %Arg_0.1), algorithm=rng_default, sharding={{\{}}{replicated}, {replicated}} - // CHECK-NEXT: %get-tuple-element.3 = u64[2] get-tuple-element((u64[2], u32[512,4]) %rng-bit-generator.2), index=0, sharding={replicated} - // CHECK-NEXT: %add.5 = u64[2] add(u64[2] %get-tuple-element.3, u64[2] %get-tuple-element.3) - // CHECK-NEXT: %get-tuple-element.4 = u32[512,4] get-tuple-element((u64[2], u32[512,4]) %rng-bit-generator.2), index=1, sharding={replicated} - // CHECK-NEXT: ROOT %tuple.6 = (u64[2], u32[512,4]) tuple(u64[2] %add.5, u32[512,4] %get-tuple-element.4) - %output_state, %output = "mhlo.rng_bit_generator"(%arg0) <{rng_algorithm = #mhlo.rng_algorithm}> {mhlo.sharding = "{replicated}"} : (tensor<2xui64>) -> (tensor<2xui64>, tensor<512x4xui32>) - %0 = mhlo.add %output_state, %output_state : tensor<2xui64> - return %0, %output : tensor<2xui64>, tensor<512x4xui32> -} - -// ----- - -// CHECK-LABEL: HloModule main - -// CHECK: %region_0.2 (Arg_.3: s32[]) -> s32[] { -// CHECK-NEXT: %Arg_.3 = s32[] parameter(0), sharding={replicated} -// CHECK-NEXT: %add.4 = s32[] add(s32[] %Arg_.3, s32[] %Arg_.3) -// CHECK-NEXT: %tuple.5 = (s32[]) tuple(s32[] %add.4) -// CHECK-NEXT: ROOT %get-tuple-element.6 = s32[] get-tuple-element((s32[]) %tuple.5), index=0, sharding={replicated} - -// CHECK: %region_1.7 (Arg_.8: s32[]) -> pred[] { -// CHECK-NEXT: %Arg_.8 = s32[] parameter(0), sharding={replicated} -// CHECK-NEXT: ROOT %compare.9 = pred[] compare(s32[] %Arg_.8, s32[] %Arg_.8), direction=LT - -// CHECK: ENTRY %main.11 (Arg_0.1: s32[]) -> s32[] { -// CHECK-NEXT: %Arg_0.1 = s32[] parameter(0) -// CHECK-NEXT: ROOT %while.10 = s32[] while(s32[] %Arg_0.1), condition=%region_1.7, body=%region_0.2, sharding={replicated} - -func.func @main(%arg0: tensor) -> tensor { - %0 = mhlo.while(%iterArg = %arg0) : tensor attributes {mhlo.sharding = "{replicated}"} - cond { - %1 = mhlo.compare LT, %iterArg, %iterArg : (tensor, tensor) -> tensor - mhlo.return %1 : tensor - } do { - %1 = mhlo.add %iterArg, %iterArg : tensor - mhlo.return %1 : tensor - } - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: HloModule main - -// CHECK: %region_0.5 (arg_tuple.6: (s32[], f32[4], f32[4])) -> (s32[], f32[4], f32[4]) { -// CHECK-NEXT: %arg_tuple.6 = (s32[], f32[4], f32[4]) parameter(0) -// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {devices=[4]<=[4]}} -// CHECK-NEXT: %get-tuple-element.7 = s32[] get-tuple-element((s32[], f32[4], f32[4]) %arg_tuple.6), index=0, sharding={replicated} -// CHECK-NEXT: %get-tuple-element.8 = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %arg_tuple.6), index=1, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} -// CHECK-NEXT: %get-tuple-element.9 = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %arg_tuple.6), index=2, sharding={devices=[4]<=[4]} -// CHECK-NEXT: %add.10 = f32[4] add(f32[4] %get-tuple-element.8, f32[4] %get-tuple-element.9) -// CHECK-NEXT: ROOT %tuple.11 = (s32[], f32[4], f32[4]) tuple(s32[] %get-tuple-element.7, f32[4] %add.10, f32[4] %get-tuple-element.9) -// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {devices=[4]<=[4]}} - -// CHECK: %region_1.12 (arg_tuple.13: (s32[], f32[4], f32[4])) -> pred[] { -// CHECK-NEXT: %arg_tuple.13 = (s32[], f32[4], f32[4]) parameter(0) -// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {devices=[4]<=[4]}} -// CHECK-NEXT: %get-tuple-element.15 = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %arg_tuple.13), index=1, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} -// CHECK-NEXT: %get-tuple-element.16 = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %arg_tuple.13), index=2, sharding={devices=[4]<=[4]} -// CHECK-NEXT: %get-tuple-element.14 = s32[] get-tuple-element((s32[], f32[4], f32[4]) %arg_tuple.13), index=0, sharding={replicated} -// CHECK-NEXT: ROOT %compare.17 = pred[] compare(s32[] %get-tuple-element.14, s32[] %get-tuple-element.14), direction=LT - -// CHECK: ENTRY %main.23 (Arg_0.1: s32[], Arg_1.2: f32[4], Arg_2.3: f32[4]) -> (f32[4], f32[4]) { -// CHECK-NEXT: %Arg_0.1 = s32[] parameter(0) -// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1) -// CHECK-NEXT: %Arg_2.3 = f32[4] parameter(2) -// CHECK-NEXT: %tuple.4 = (s32[], f32[4], f32[4]) tuple(s32[] %Arg_0.1, f32[4] %Arg_1.2, f32[4] %Arg_2.3) -// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {devices=[4]<=[4]}} -// CHECK-NEXT: %while.18 = (s32[], f32[4], f32[4]) while((s32[], f32[4], f32[4]) %tuple.4), condition=%region_1.12, body=%region_0.5 -// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {devices=[4]<=[4]}} -// CHECK-NEXT: %get-tuple-element.19 = s32[] get-tuple-element((s32[], f32[4], f32[4]) %while.18), index=0, sharding={replicated} -// CHECK-NEXT: %get-tuple-element.20 = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %while.18), index=1, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} -// CHECK-NEXT: %get-tuple-element.21 = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %while.18), index=2, sharding={devices=[4]<=[4]} -// CHECK-NEXT: ROOT %tuple.22 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.20, f32[4] %get-tuple-element.21) - -func.func @main(%arg0: tensor, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { - %0:3 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %arg1, %iterArg_1 = %arg2) : tensor, tensor<4xf32>, tensor<4xf32> - attributes {mhlo.sharding = "{{replicated},{devices=[2,2]<=[4] last_tile_dim_replicate},{devices=[4]<=[4]}}"} - cond { - %1 = mhlo.compare LT, %iterArg, %iterArg : (tensor, tensor) -> tensor - mhlo.return %1 : tensor - } do { - %1 = mhlo.add %iterArg_0, %iterArg_1 : tensor<4xf32> - mhlo.return %iterArg, %1, %iterArg_1 : tensor, tensor<4xf32>, tensor<4xf32> - } - func.return %0#1, %0#2 : tensor<4xf32>, tensor<4xf32> -} - -// ----- - -// CHECK-LABEL: HloModule main - -// CHECK: %region_0.5 (arg_tuple.6: (s32[], f32[4], f32[4])) -> (s32[], f32[4], f32[4]) { -// CHECK-NEXT: %arg_tuple.6 = (s32[], f32[4], f32[4]) parameter(0) -// CHECK-SAME: sharding={{\{}}{manual}, {manual}, {manual}} -// CHECK-NEXT: %get-tuple-element.7 = s32[] get-tuple-element((s32[], f32[4], f32[4]) %arg_tuple.6), index=0, sharding={manual} -// CHECK-NEXT: %get-tuple-element.8 = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %arg_tuple.6), index=1, sharding={manual} -// CHECK-NEXT: %get-tuple-element.9 = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %arg_tuple.6), index=2, sharding={manual} -// CHECK-NEXT: %add.10 = f32[4] add(f32[4] %get-tuple-element.8, f32[4] %get-tuple-element.9) -// CHECK-NEXT: ROOT %tuple.11 = (s32[], f32[4], f32[4]) tuple(s32[] %get-tuple-element.7, f32[4] %add.10, f32[4] %get-tuple-element.9) -// CHECK-SAME: sharding={{\{}}{manual}, {manual}, {manual}} - -// CHECK: %region_1.12 (arg_tuple.13: (s32[], f32[4], f32[4])) -> pred[] { -// CHECK-NEXT: %arg_tuple.13 = (s32[], f32[4], f32[4]) parameter(0) -// CHECK-SAME: sharding={{\{}}{manual}, {manual}, {manual}} -// CHECK-NEXT: %get-tuple-element.15 = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %arg_tuple.13), index=1, sharding={manual} -// CHECK-NEXT: %get-tuple-element.16 = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %arg_tuple.13), index=2, sharding={manual} -// CHECK-NEXT: %get-tuple-element.14 = s32[] get-tuple-element((s32[], f32[4], f32[4]) %arg_tuple.13), index=0, sharding={manual} -// CHECK-NEXT: ROOT %compare.17 = pred[] compare(s32[] %get-tuple-element.14, s32[] %get-tuple-element.14), direction=LT - -// CHECK: ENTRY %main.23 (Arg_0.1: s32[], Arg_1.2: f32[4], Arg_2.3: f32[4]) -> (f32[4], f32[4]) { -// CHECK-NEXT: %Arg_0.1 = s32[] parameter(0) -// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1) -// CHECK-NEXT: %Arg_2.3 = f32[4] parameter(2) -// CHECK-NEXT: %tuple.4 = (s32[], f32[4], f32[4]) tuple(s32[] %Arg_0.1, f32[4] %Arg_1.2, f32[4] %Arg_2.3) -// CHECK-SAME: sharding={{\{}}{manual}, {manual}, {manual}} -// CHECK-NEXT: %while.18 = (s32[], f32[4], f32[4]) while((s32[], f32[4], f32[4]) %tuple.4), condition=%region_1.12, body=%region_0.5 -// CHECK-SAME: sharding={{\{}}{manual}, {manual}, {manual}} -// CHECK-NEXT: %get-tuple-element.19 = s32[] get-tuple-element((s32[], f32[4], f32[4]) %while.18), index=0, sharding={manual} -// CHECK-NEXT: %get-tuple-element.20 = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %while.18), index=1, sharding={manual} -// CHECK-NEXT: %get-tuple-element.21 = f32[4] get-tuple-element((s32[], f32[4], f32[4]) %while.18), index=2, sharding={manual} -// CHECK-NEXT: ROOT %tuple.22 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.20, f32[4] %get-tuple-element.21) - -func.func @main(%arg0: tensor, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) { - %0:3 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %arg1, %iterArg_1 = %arg2) : tensor, tensor<4xf32>, tensor<4xf32> - attributes {mhlo.sharding = "{manual}"} - cond { - %1 = mhlo.compare LT, %iterArg, %iterArg : (tensor, tensor) -> tensor - mhlo.return %1 : tensor - } do { - %1 = mhlo.add %iterArg_0, %iterArg_1 : tensor<4xf32> - mhlo.return %iterArg, %1, %iterArg_1 : tensor, tensor<4xf32>, tensor<4xf32> - } - func.return %0#1, %0#2 : tensor<4xf32>, tensor<4xf32> -} - -// ----- - -// CHECK-LABEL: HloModule main - -// CHECK: %region_0.8 (arg_tuple.9: (f32[4], f32[4])) -> (f32[4], f32[4]) { -// CHECK-NEXT: %arg_tuple.9 = (f32[4], f32[4]) parameter(0), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}} -// CHECK-NEXT: %get-tuple-element.10 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.9), index=0, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} -// CHECK-NEXT: %get-tuple-element.11 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.9), index=1 -// CHECK-NEXT: ROOT %tuple.12 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.10, f32[4] %get-tuple-element.11), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} - -// CHECK: %region_1.13 (arg_tuple.14: (f32[4], f32[4])) -> (f32[4], f32[4]) { -// CHECK-NEXT: %arg_tuple.14 = (f32[4], f32[4]) parameter(0), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} -// CHECK-NEXT: %get-tuple-element.15 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.14), index=0, sharding={replicated} -// CHECK-NEXT: %get-tuple-element.16 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.14), index=1, sharding={devices=[4]<=[4]} -// CHECK-NEXT: ROOT %tuple.17 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.15, f32[4] %get-tuple-element.16), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} - -// CHECK: ENTRY %main.22 (Arg_0.1: s32[], Arg_1.2: f32[4], Arg_2.3: f32[4], Arg_3.4: f32[4], Arg_4.5: f32[4]) -> (f32[4], f32[4]) { -// CHECK-NEXT: %Arg_0.1 = s32[] parameter(0) -// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate} -// CHECK-NEXT: %Arg_2.3 = f32[4] parameter(2) -// CHECK-NEXT: %tuple.6 = (f32[4], f32[4]) tuple(f32[4] %Arg_1.2, f32[4] %Arg_2.3), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}} -// CHECK-NEXT: %Arg_3.4 = f32[4] parameter(3), sharding={replicated} -// CHECK-NEXT: %Arg_4.5 = f32[4] parameter(4), sharding={devices=[4]<=[4]} -// CHECK-NEXT: %tuple.7 = (f32[4], f32[4]) tuple(f32[4] %Arg_3.4, f32[4] %Arg_4.5), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} -// CHECK-NEXT: %conditional.18 = (f32[4], f32[4]) conditional(s32[] %Arg_0.1, (f32[4], f32[4]) %tuple.6, (f32[4], f32[4]) %tuple.7), branch_computations={%region_0.8, %region_1.13}, -// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[4]<=[4]}} -// CHECK-NEXT: %get-tuple-element.19 = f32[4] get-tuple-element((f32[4], f32[4]) %conditional.18), index=0, sharding={replicated} -// CHECK-NEXT: %get-tuple-element.20 = f32[4] get-tuple-element((f32[4], f32[4]) %conditional.18), index=1, sharding={devices=[4]<=[4]} -// CHECK-NEXT: ROOT %tuple.21 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.19, f32[4] %get-tuple-element.20) - -func.func @main(%arg0: tensor, - %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"}, - %arg2: tensor<4xf32>, - %arg3: tensor<4xf32> {mhlo.sharding = "{replicated}"}, - %arg4: tensor<4xf32> {mhlo.sharding = "{devices=[4]<=[4]}"}) -> (tensor<4xf32>, tensor<4xf32>) { - %0:2 = "mhlo.case"(%arg0) ( { - mhlo.return %arg1, %arg2 : tensor<4xf32>, tensor<4xf32> - }, { - mhlo.return %arg3, %arg4 : tensor<4xf32>, tensor<4xf32> - }) {mhlo.sharding = "{{replicated},{devices=[4]<=[4]}}"} : (tensor) -> (tensor<4xf32>, tensor<4xf32>) - func.return %0#0, %0#1 : tensor<4xf32>, tensor<4xf32> -} - - -// ----- - -// CHECK-LABEL: HloModule main - -// CHECK: %region_0.4 (Arg_.5: f32[4]) -> f32[4] { -// CHECK-NEXT: ROOT %Arg_.5 = f32[4] parameter(0) - -// CHECK: %region_1.6 (Arg_.7: f32[4]) -> f32[4] { -// CHECK-NEXT: ROOT %Arg_.7 = f32[4] parameter(0) - -// CHECK: ENTRY %main.9 (Arg_0.1: s32[], Arg_1.2: f32[4], Arg_2.3: f32[4]) -> f32[4] { -// CHECK-NEXT: %Arg_0.1 = s32[] parameter(0) -// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate} -// CHECK-NEXT: %Arg_2.3 = f32[4] parameter(2) -// CHECK-NEXT: ROOT %conditional.8 = f32[4] conditional(s32[] %Arg_0.1, f32[4] %Arg_1.2, f32[4] %Arg_2.3), branch_computations={%region_0.4, %region_1.6} -func.func @main(%arg0: tensor, - %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"}, - %arg2: tensor<4xf32>) -> tensor<4xf32> { - %0 = "mhlo.case"(%arg0) ( { - mhlo.return %arg1 : tensor<4xf32> - }, { - mhlo.return %arg2 : tensor<4xf32> - }) : (tensor) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} - -// ----- - -// CHECK-LABEL: HloModule main - -// CHECK: %region_0.8 (arg_tuple.9: (f32[4], f32[4])) -> (f32[4], f32[4]) { -// CHECK-NEXT: %arg_tuple.9 = (f32[4], f32[4]) parameter(0), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}} -// CHECK-NEXT: %get-tuple-element.10 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.9), index=0, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} -// CHECK-NEXT: %get-tuple-element.11 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.9), index=1 -// CHECK-NEXT: ROOT %tuple.12 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.10, f32[4] %get-tuple-element.11), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} - -// CHECK: %region_1.13 (arg_tuple.14: (f32[4], f32[4])) -> (f32[4], f32[4]) { -// CHECK-NEXT: %arg_tuple.14 = (f32[4], f32[4]) parameter(0), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} -// CHECK-NEXT: %get-tuple-element.15 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.14), index=0, sharding={replicated} -// CHECK-NEXT: %get-tuple-element.16 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.14), index=1, sharding={devices=[4]<=[4]} -// CHECK-NEXT: ROOT %tuple.17 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.15, f32[4] %get-tuple-element.16), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} - -// CHECK: ENTRY %main.22 (Arg_0.1: pred[], Arg_1.2: f32[4], Arg_2.3: f32[4], Arg_3.4: f32[4], Arg_4.5: f32[4]) -> (f32[4], f32[4]) { -// CHECK-NEXT: %Arg_0.1 = pred[] parameter(0) -// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate} -// CHECK-NEXT: %Arg_2.3 = f32[4] parameter(2) -// CHECK-NEXT: %tuple.6 = (f32[4], f32[4]) tuple(f32[4] %Arg_1.2, f32[4] %Arg_2.3), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}} -// CHECK-NEXT: %Arg_3.4 = f32[4] parameter(3), sharding={replicated} -// CHECK-NEXT: %Arg_4.5 = f32[4] parameter(4), sharding={devices=[4]<=[4]} -// CHECK-NEXT: %tuple.7 = (f32[4], f32[4]) tuple(f32[4] %Arg_3.4, f32[4] %Arg_4.5), sharding={{\{}}{replicated}, {devices=[4]<=[4]}} -// CHECK-NEXT: %conditional.18 = (f32[4], f32[4]) conditional(pred[] %Arg_0.1, (f32[4], f32[4]) %tuple.6, (f32[4], f32[4]) %tuple.7), true_computation=%region_0.8, false_computation=%region_1.13, -// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[4]<=[4]}} -// CHECK-NEXT: %get-tuple-element.19 = f32[4] get-tuple-element((f32[4], f32[4]) %conditional.18), index=0, sharding={replicated} -// CHECK-NEXT: %get-tuple-element.20 = f32[4] get-tuple-element((f32[4], f32[4]) %conditional.18), index=1, sharding={devices=[4]<=[4]} -// CHECK-NEXT: ROOT %tuple.21 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.19, f32[4] %get-tuple-element.20) - -func.func @main(%arg0: tensor, - %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"}, - %arg2: tensor<4xf32>, - %arg3: tensor<4xf32> {mhlo.sharding = "{replicated}"}, - %arg4: tensor<4xf32> {mhlo.sharding = "{devices=[4]<=[4]}"}) -> (tensor<4xf32>, tensor<4xf32>) { - %0:2 = "mhlo.if"(%arg0) ( { - mhlo.return %arg1, %arg2 : tensor<4xf32>, tensor<4xf32> - }, { - mhlo.return %arg3, %arg4 : tensor<4xf32>, tensor<4xf32> - }) {mhlo.sharding = "{{replicated},{devices=[4]<=[4]}}"} : (tensor) -> (tensor<4xf32>, tensor<4xf32>) - func.return %0#0, %0#1 : tensor<4xf32>, tensor<4xf32> -} - -// ----- - -// CHECK-LABEL: HloModule main - -// CHECK: %region_0.4 (Arg_.5: f32[4]) -> f32[4] { -// CHECK-NEXT: ROOT %Arg_.5 = f32[4] parameter(0) - -// CHECK: %region_1.6 (Arg_.7: f32[4]) -> f32[4] { -// CHECK-NEXT: ROOT %Arg_.7 = f32[4] parameter(0) - -// CHECK: ENTRY %main.9 (Arg_0.1: pred[], Arg_1.2: f32[4], Arg_2.3: f32[4]) -> f32[4] { -// CHECK-NEXT: %Arg_0.1 = pred[] parameter(0) -// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate} -// CHECK-NEXT: %Arg_2.3 = f32[4] parameter(2) -// CHECK-NEXT: ROOT %conditional.8 = f32[4] conditional(pred[] %Arg_0.1, f32[4] %Arg_1.2, f32[4] %Arg_2.3), true_computation=%region_0.4, false_computation=%region_1.6 - -func.func @main(%arg0: tensor, - %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"}, - %arg2: tensor<4xf32>) -> tensor<4xf32> { - %0 = "mhlo.if"(%arg0) ( { - mhlo.return %arg1 : tensor<4xf32> - }, { - mhlo.return %arg2 : tensor<4xf32> - }) : (tensor) -> tensor<4xf32> - func.return %0 : tensor<4xf32> -} diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/while_free_vars.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/while_free_vars.mlir deleted file mode 100644 index 3663f927ae6876..00000000000000 --- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/while_free_vars.mlir +++ /dev/null @@ -1,89 +0,0 @@ -// RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text %s -o - | FileCheck %s - -// This test verifies that the correct shardings are added when a while loop -// has free variables. - -// CHECK-LABEL: HloModule main - -// CHECK: %region_0.7 (arg_tuple.8: (s32[], f32[4], s32[], s32[], f32[4])) -> (s32[], f32[4], s32[], s32[], f32[4]) { -// CHECK-NEXT: %arg_tuple.8 = (s32[], f32[4], s32[], s32[], f32[4]) parameter(0) -// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}} -// CHECK-DAG: %get-tuple-element.12 = s32[] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %arg_tuple.8), index=3 -// CHECK-DAG: %get-tuple-element.13 = f32[4] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %arg_tuple.8), index=4, sharding={devices=[4]<=[4]} -// CHECK-DAG: %add.14 = s32[] add(s32[] %get-tuple-element.9, s32[] %get-tuple-element.12) -// CHECK-DAG: %add.15 = f32[4] add(f32[4] %get-tuple-element.10, f32[4] %get-tuple-element.13) -// CHECK: ROOT %tuple.16 = (s32[], f32[4], s32[], s32[], f32[4]) tuple(s32[] %add.14, f32[4] %add.15, s32[] %get-tuple-element.11, s32[] %get-tuple-element.12, f32[4] %get-tuple-element.13) -// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}} - -// CHECK: %region_1.17 (arg_tuple.18: (s32[], f32[4], s32[], s32[], f32[4])) -> pred[] { -// CHECK-NEXT: %arg_tuple.18 = (s32[], f32[4], s32[], s32[], f32[4]) parameter(0) -// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}} -// CHECK: %get-tuple-element.21 = s32[] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %arg_tuple.18), index=2 -// CHECK-NEXT: ROOT %compare.24 = pred[] compare(s32[] %get-tuple-element.19, s32[] %get-tuple-element.21), direction=LT - -// CHECK: ENTRY %main.28 (Arg_0.1: s32[], Arg_1.2: f32[4], Arg_2.3: f32[4]) -> f32[4] { -// CHECK-NEXT: %Arg_0.1 = s32[] parameter(0) -// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1) -// CHECK-NEXT: %constant.4 = s32[] constant(0) -// CHECK-NEXT: %constant.5 = s32[] constant(1) -// CHECK-NEXT: %Arg_2.3 = f32[4] parameter(2) -// CHECK-NEXT: %tuple.6 = (s32[], f32[4], s32[], s32[], f32[4]) tuple(s32[] %Arg_0.1, f32[4] %Arg_1.2, s32[] %constant.4, s32[] %constant.5, f32[4] %Arg_2.3) -// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}} -// CHECK-NEXT: %while.25 = (s32[], f32[4], s32[], s32[], f32[4]) while((s32[], f32[4], s32[], s32[], f32[4]) %tuple.6), condition=%region_1.17, body=%region_0.7 -// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}} -// CHECK-NEXT: %get-tuple-element.26 = s32[] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %while.25), index=0, sharding={replicated} -// CHECK-NEXT: ROOT %get-tuple-element.27 = f32[4] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %while.25), index=1, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} - -func.func @main(%arg0: tensor, %arg1: tensor<4xf32>, %arg2: tensor<4xf32> {mhlo.sharding = "{devices=[4]<=[4]}"}) -> tensor<4xf32> { - %0 = mhlo.constant dense<0> : tensor - %1 = mhlo.constant dense<1> : tensor - %2:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %arg1) : tensor, tensor<4xf32> - attributes {mhlo.sharding = "{{replicated},{devices=[2,2]<=[4] last_tile_dim_replicate}}"} - cond { - %3 = mhlo.compare LT, %iterArg, %0 : (tensor, tensor) -> tensor - mhlo.return %3 : tensor - } do { - %3 = mhlo.add %iterArg, %1 : tensor - %4 = mhlo.add %iterArg_0, %arg2 : tensor<4xf32> - mhlo.return %3, %4: tensor, tensor<4xf32> - } - func.return %2#1 : tensor<4xf32> -} - -// ----- - -// This test verifies that a value captured multiple times is only lifted once -// and all its uses are replaced. Also verifies that no sharding is added to -// region parameters or root when the while doesn't have a sharding. - -// CHECK-LABEL: HloModule main - -// CHECK: %region_0.5 (arg_tuple.6: (s32[], f32[4], s32[])) -> (s32[], f32[4], s32[]) { -// CHECK-NEXT: %arg_tuple.6 = (s32[], f32[4], s32[]) parameter(0) -// CHECK: %get-tuple-element.9 = s32[] get-tuple-element((s32[], f32[4], s32[]) %arg_tuple.6), index=2 -// CHECK: %add.10 = s32[] add(s32[] %get-tuple-element.7, s32[] %get-tuple-element.9) -// CHECK: ROOT %tuple.11 = (s32[], f32[4], s32[]) tuple(s32[] %add.10, f32[4] %get-tuple-element.8, s32[] %get-tuple-element.9) - -// CHECK: %region_1.12 (arg_tuple.13: (s32[], f32[4], s32[])) -> pred[] { -// CHECK-NEXT: %arg_tuple.13 = (s32[], f32[4], s32[]) parameter(0) -// CHECK: %get-tuple-element.16 = s32[] get-tuple-element((s32[], f32[4], s32[]) %arg_tuple.13), index=2 -// CHECK: ROOT %compare.17 = pred[] compare(s32[] %get-tuple-element.14, s32[] %get-tuple-element.16), direction=LT - -// CHECK: ENTRY %main.21 (Arg_0.1: s32[], Arg_1.2: f32[4], Arg_2.3: s32[]) -> f32[4] { -// CHECK-NEXT: %Arg_0.1 = s32[] parameter(0) -// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1) -// CHECK-NEXT: %Arg_2.3 = s32[] parameter(2) -// CHECK-NEXT: %tuple.4 = (s32[], f32[4], s32[]) tuple(s32[] %Arg_0.1, f32[4] %Arg_1.2, s32[] %Arg_2.3) -// CHECK-NEXT: %while.18 = (s32[], f32[4], s32[]) while((s32[], f32[4], s32[]) %tuple.4), condition=%region_1.12, body=%region_0.5 - -func.func @main(%arg0: tensor, %arg1: tensor<4xf32>, %arg2: tensor) -> tensor<4xf32> { - %2:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %arg1) : tensor, tensor<4xf32> - cond { - %3 = mhlo.compare LT, %iterArg, %arg2 : (tensor, tensor) -> tensor - mhlo.return %3 : tensor - } do { - %3 = mhlo.add %iterArg, %arg2 : tensor - mhlo.return %3, %iterArg_0: tensor, tensor<4xf32> - } - func.return %2#1 : tensor<4xf32> -} diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/translate.h b/third_party/xla/xla/translate/mhlo_to_hlo/translate.h index b65e7d9c78cdaa..373eaca3fca4f3 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/translate.h +++ b/third_party/xla/xla/translate/mhlo_to_hlo/translate.h @@ -16,35 +16,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_MHLO_TO_HLO_TRANSLATE_H_ #define XLA_TRANSLATE_MHLO_TO_HLO_TRANSLATE_H_ -#include -#include - -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/raw_os_ostream.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Support/LogicalResult.h" - -namespace xla { - -mlir::LogicalResult MlirHloToHloTranslateFunction(mlir::ModuleOp module, - llvm::raw_ostream& output, - bool emit_return_tuple, - bool emit_use_tuple_arg); - -mlir::LogicalResult MlirHloToHloTextTranslateFunction( - mlir::ModuleOp module, llvm::raw_ostream& output, bool emit_return_tuple, - bool emit_use_tuple_arg, bool print_layouts, bool print_large_constants, - bool print_sugar, bool via_builder, bool with_layouts); - -// Translate the MHLO program in in-memory file 'buffer' to a HLO program -// written in a file represented with handle 'output_stream'; -mlir::LogicalResult MlirHloToHloTextMain( - std::unique_ptr buffer, - llvm::raw_ostream& output_stream, bool emit_return_tuple, - bool emit_use_tuple_arg, bool print_layouts, bool print_large_constants, - bool print_sugar, bool via_builder, bool with_layouts); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/mhlo_to_hlo/translate.h" #endif // XLA_TRANSLATE_MHLO_TO_HLO_TRANSLATE_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.h b/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.h index 6e75fc2b75df40..2e99276efe7c5b 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.h +++ b/third_party/xla/xla/translate/mhlo_to_hlo/type_to_shape.h @@ -16,16 +16,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_MHLO_TO_HLO_TYPE_TO_SHAPE_H_ #define XLA_TRANSLATE_MHLO_TO_HLO_TYPE_TO_SHAPE_H_ -#include "llvm/ADT/STLExtras.h" -#include "mlir/IR/Types.h" -#include "xla/shape.h" -#include "xla/xla_data.pb.h" - -namespace xla { - -// Returns a XLA Shape equivalent of a MLIR Type, else returns empty shape. -Shape TypeToShape(mlir::Type type); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" #endif // XLA_TRANSLATE_MHLO_TO_HLO_TYPE_TO_SHAPE_H_ diff --git a/third_party/xla/xla/translate/stablehlo_to_hlo/BUILD b/third_party/xla/xla/translate/stablehlo_to_hlo/BUILD index fb6c1a7d961b9c..452f72aae373b1 100644 --- a/third_party/xla/xla/translate/stablehlo_to_hlo/BUILD +++ b/third_party/xla/xla/translate/stablehlo_to_hlo/BUILD @@ -12,35 +12,9 @@ package( cc_library( name = "translate", - srcs = ["translate.cc"], hdrs = ["translate.h"], + deprecation = "This library is deprecated. Use //third_party/tensorflow/compiler/xla/hlo/translate/stablehlo_to_hlo:translate instead.", deps = [ - "//xla/mlir_hlo:mhlo_passes", - "//xla/translate/mhlo_to_hlo:translate", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - "@stablehlo//:register", + "//xla/hlo/translate/stablehlo_to_hlo:translate", ], ) - -cc_library( - name = "translate_registration", - testonly = True, - srcs = ["translate_registration.cc"], - deps = [ - ":translate", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TranslateLib", - "@stablehlo//:register", - ], - alwayslink = 1, -) diff --git a/third_party/xla/xla/translate/stablehlo_to_hlo/translate.h b/third_party/xla/xla/translate/stablehlo_to_hlo/translate.h index d21dcbdacfcf68..badaeeaa9acb30 100644 --- a/third_party/xla/xla/translate/stablehlo_to_hlo/translate.h +++ b/third_party/xla/xla/translate/stablehlo_to_hlo/translate.h @@ -16,35 +16,7 @@ limitations under the License. #ifndef XLA_TRANSLATE_STABLEHLO_TO_HLO_TRANSLATE_H_ #define XLA_TRANSLATE_STABLEHLO_TO_HLO_TRANSLATE_H_ -#include -#include - -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/raw_os_ostream.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Support/LogicalResult.h" - -namespace xla { - -mlir::LogicalResult StablehloToHloTranslateFunction(mlir::ModuleOp module, - llvm::raw_ostream& output, - bool emit_return_tuple, - bool emit_use_tuple_arg); - -mlir::LogicalResult StablehloToHloTextTranslateFunction( - mlir::ModuleOp module, llvm::raw_ostream& output, bool emit_return_tuple, - bool emit_use_tuple_arg, bool print_layouts, bool print_large_constants, - bool print_sugar, bool via_builder, bool with_layouts); - -// Translate the StableHLO program in in-memory file 'buffer' to a HLO program -// written in a file represented with handle 'output_stream'; -mlir::LogicalResult StablehloToHloTextMain( - std::unique_ptr buffer, - llvm::raw_ostream& output_stream, bool emit_return_tuple, - bool emit_use_tuple_arg, bool print_layouts, bool print_large_constants, - bool print_sugar, bool via_builder, bool with_layouts); - -} // namespace xla +// The current header will be deprecated in favour of the following. +#include "xla/hlo/translate/stablehlo_to_hlo/translate.h" #endif // XLA_TRANSLATE_STABLEHLO_TO_HLO_TRANSLATE_H_ diff --git a/third_party/xla/xla/tsl/BUILD b/third_party/xla/xla/tsl/BUILD index 48719c0966cd16..8a1e7cd54613ee 100644 --- a/third_party/xla/xla/tsl/BUILD +++ b/third_party/xla/xla/tsl/BUILD @@ -3,7 +3,7 @@ load("@bazel_skylib//lib:selects.bzl", "selects") load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "bool_setting") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load("//xla/tsl:package_groups.bzl", "tsl_package_groups") -load("//xla/tsl:tsl.bzl", "if_google", "if_oss") +load("//xla/tsl:tsl.bzl", "if_google", "if_oss", "internal_visibility") load( "//xla/tsl:tsl.default.bzl", "tsl_extra_config_settings", @@ -39,7 +39,7 @@ alias( name = "is_cuda_enabled", actual = if_oss( "@local_config_cuda//:is_cuda_enabled", - "@local_config_cuda//cuda:using_clang", + "@local_config_cuda//cuda:using_config_cuda", ), visibility = ["//visibility:public"], ) @@ -53,6 +53,24 @@ selects.config_setting_group( visibility = ["//visibility:public"], ) +# Config setting that is satisfied when CUDA device code should be compiled +# with nvcc. It does not imply that CUDA support has been enabled. +alias( + name = "is_cuda_compiler_nvcc", + actual = if_oss( + "@local_config_cuda//:is_cuda_compiler_nvcc", + "@local_config_cuda//cuda:FALSE", + ), +) + +selects.config_setting_group( + name = "is_cuda_nvcc", + match_all = [ + ":is_cuda_enabled", + ":is_cuda_compiler_nvcc", + ], +) + # Crosses between framework_shared_object and a bunch of other configurations # due to limitations in nested select() statements. config_setting( @@ -266,30 +284,50 @@ config_setting( config_setting( name = "linux_aarch64", + constraint_values = if_google( + ["//third_party/bazel_platforms/os:linux"], + [], + ), values = {"cpu": "aarch64"}, visibility = ["//visibility:public"], ) config_setting( name = "linux_armhf", + constraint_values = if_google( + ["//third_party/bazel_platforms/os:linux"], + [], + ), values = {"cpu": "armhf"}, visibility = ["//visibility:public"], ) config_setting( name = "linux_x86_64", + constraint_values = if_google( + ["//third_party/bazel_platforms/os:linux"], + [], + ), values = {"cpu": "k8"}, visibility = ["//visibility:public"], ) config_setting( name = "linux_ppc64le", + constraint_values = if_google( + ["//third_party/bazel_platforms/os:linux"], + [], + ), values = {"cpu": "ppc"}, visibility = ["//visibility:public"], ) config_setting( name = "linux_s390x", + constraint_values = if_google( + ["//third_party/bazel_platforms/os:linux"], + [], + ), values = {"cpu": "s390x"}, visibility = ["//visibility:public"], ) @@ -500,11 +538,20 @@ config_setting( ) config_setting( - name = "no_nccl_support", + name = "using_no_nccl_support_define", define_values = dict( - if_google({"GOOGLE_CUDA_COMPILER": "clang"}), no_nccl_support = "true", ), + visibility = internal_visibility(["//visibility:private"]), +) + +selects.config_setting_group( + name = "no_nccl_support", + match_all = [ + ":using_no_nccl_support_define", + ] + if_google([ + "@local_config_cuda//cuda:using_config_cuda", + ]), visibility = ["//visibility:public"], ) @@ -519,6 +566,7 @@ bzl_library( "@local_config_cuda//cuda:build_defs_bzl", "@local_config_rocm//rocm:build_defs_bzl", "@local_config_tensorrt//:build_defs_bzl", + "@local_tsl//third_party/py/rules_pywrap:pywrap_bzl", "@local_tsl//tsl/platform:rules_cc_bzl", ], ) diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD b/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD index 91e77bda9d79e6..1caa5d166f5ce9 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD @@ -16,11 +16,11 @@ cc_library( srcs = ["coordination_service_error_util.cc"], hdrs = ["coordination_service_error_util.h"], deps = [ + "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@local_tsl//tsl/platform:regexp", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ], ) @@ -29,12 +29,12 @@ tsl_cc_test( srcs = ["coordination_service_error_util_test.cc"], deps = [ ":coordination_service_error_util", + "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ], ) @@ -43,8 +43,8 @@ cc_library( hdrs = ["coordination_client.h"], deps = [ "//xla/tsl/distributed_runtime:call_options", + "//xla/tsl/protobuf:coordination_service_proto_cc", "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ], ) @@ -53,14 +53,14 @@ cc_library( hdrs = ["coordination_service.h"], deps = [ ":coordination_client", + "//xla/tsl/protobuf:coordination_config_proto_cc", + "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/time", "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/protobuf:coordination_config_proto_cc", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ], ) @@ -75,6 +75,8 @@ tsl_gpu_library( ":coordination_service", ":coordination_service_error_util", "//xla/tsl/distributed_runtime:call_options", + "//xla/tsl/protobuf:coordination_config_proto_cc", + "//xla/tsl/protobuf:coordination_service_proto_cc", "//xla/tsl/util:device_name_utils", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -91,8 +93,6 @@ tsl_gpu_library( "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:random", "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/protobuf:coordination_config_proto_cc", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ], alwayslink = 1, ) @@ -118,6 +118,8 @@ tsl_cc_test( ":test_device_proto_cc", "//xla/tsl/distributed_runtime:call_options", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/protobuf:coordination_config_proto_cc", + "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -131,8 +133,6 @@ tsl_cc_test( "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/protobuf:coordination_config_proto_cc", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ], ) @@ -146,6 +146,8 @@ tsl_gpu_library( "//xla/tsl/distributed_runtime:call_options", "//xla/tsl/framework:cancellation", "//xla/tsl/lib/monitoring:gauge", + "//xla/tsl/protobuf:coordination_config_proto_cc", + "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/log", @@ -159,8 +161,6 @@ tsl_gpu_library( "@local_tsl//tsl/platform:random", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/protobuf:coordination_config_proto_cc", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ], ) @@ -172,6 +172,8 @@ tsl_cc_test( ":coordination_service_agent", "//xla/tsl/distributed_runtime:call_options", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/protobuf:coordination_config_proto_cc_impl", + "//xla/tsl/protobuf:coordination_service_proto_cc_impl", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", @@ -182,8 +184,6 @@ tsl_cc_test( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/protobuf:coordination_config_proto_cc_impl", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc_impl", ], ) @@ -197,6 +197,7 @@ cc_library( ":coordination_service", ":coordination_service_agent", ":coordination_service_error_util", + "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -206,7 +207,6 @@ cc_library( "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:thread_annotations", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ], ) @@ -222,6 +222,9 @@ tsl_cc_test( "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_client", "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_service_impl", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/protobuf:coordination_config_proto_cc_impl", + "//xla/tsl/protobuf:coordination_service_proto_cc_impl", + "//xla/tsl/protobuf:distributed_runtime_payloads_proto_cc_impl", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", @@ -233,9 +236,6 @@ tsl_cc_test( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/protobuf:coordination_config_proto_cc_impl", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc_impl", - "@local_tsl//tsl/protobuf:distributed_runtime_payloads_proto_cc_impl", ] + tsl_grpc_cc_dependencies(), ) diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_client.h b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_client.h index cea5ba4890d37b..71bc536af63135 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_client.h +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_client.h @@ -20,8 +20,8 @@ limitations under the License. #include #include "xla/tsl/distributed_runtime/call_options.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/status.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tsl { using tensorflow::BarrierRequest; diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc index a6e835c8e6a2d4..9349f456eea810 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -46,12 +47,12 @@ limitations under the License. #include "xla/tsl/distributed_runtime/call_options.h" #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "xla/tsl/util/device_name_utils.h" #include "tsl/platform/env.h" #include "tsl/platform/random.h" #include "tsl/platform/status.h" -#include "tsl/protobuf/coordination_config.pb.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tsl { namespace { @@ -63,6 +64,8 @@ using tensorflow::CoordinationServiceError; using tensorflow::DeviceInfo; using tensorflow::KeyValueEntry; +constexpr char kClusterRegisterBarrierId[] = + "[Init]Wait_for_all_tasks_to_register"; constexpr absl::Duration kDevicePropagationTimeout = absl::Hours(1); constexpr int kDefaultHeartbeatTimeoutMs = 10 * 1000; // 10 seconds constexpr int kServiceToClientTimeoutMs = 10 * 1000; // 10 seconds @@ -107,7 +110,10 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { CoordinationServiceStandaloneImpl( Env* env, const CoordinationServiceConfig& config, std::unique_ptr client_cache); - ~CoordinationServiceStandaloneImpl() override { Stop(); } + ~CoordinationServiceStandaloneImpl() override { + absl::MutexLock lock(&state_mu_); + Stop(); + } void SetDeviceAggregationFunction( std::function @@ -117,6 +123,8 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { absl::Status RegisterTask(const CoordinatedTask& task, uint64_t incarnation) override; + void RegisterTaskAsync(const CoordinatedTask& task, uint64_t incarnation, + StatusCallback done) override; void WaitForAllTasks(const CoordinatedTask& task, const DeviceInfo& devices, StatusCallback done) override; void ShutdownTaskAsync(const CoordinatedTask& task, @@ -125,7 +133,7 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { absl::Status RecordHeartbeat(const CoordinatedTask& task, uint64_t incarnation) override; absl::Status ReportTaskError(const CoordinatedTask& task, - absl::Status error) override; + const absl::Status& error) override; std::vector GetTaskState( const std::vector& task) override; absl::Status InsertKeyValue(std::string_view key, @@ -138,11 +146,11 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { std::vector GetKeyValueDir( std::string_view directory_key) override; absl::Status DeleteKeyValue(std::string_view key) override; - void BarrierAsync(std::string_view barrier_id, absl::Duration timeout, + void BarrierAsync(std::string barrier_id, absl::Duration timeout, const CoordinatedTask& task, const std::vector& participating_tasks, StatusCallback done) override; - absl::Status CancelBarrier(std::string_view barrier_id, + absl::Status CancelBarrier(std::string barrier_id, const CoordinatedTask& task) override; void PollForErrorAsync(const CoordinatedTask& task, StatusCallback done) override; @@ -151,6 +159,14 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { const DeviceInfo& ListClusterDevices() override ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); uint64_t GetServiceIncarnation() override; + void BarrierAsyncLocked( + std::string barrier_id, absl::Duration timeout, + const CoordinatedTask& task, + const std::vector& participating_tasks, + StatusCallback done) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + StatusCallback ConnectAfterBarrierPasses(absl::string_view task_name, + uint64_t incarnation, + StatusCallback done); // Checks if any task has stopped sending heartbeats. void CheckHeartbeatTimeout(); // Checks if any barrier has timed out. @@ -160,20 +176,17 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { void CheckStaleness(); // Starts a thread to check staleness. void StartCheckStaleness(); - void Stop(bool shut_staleness_thread = true); + void Stop() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); bool ServiceHasStopped() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - // Report service error to a specified task. - void ReportServiceErrorToTaskAsync(const CoordinatedTask& destination_task, - absl::Status error); // Report error from a task to all other connected tasks if the task is not // recoverable. // Note: SetTaskError() must be called before propagating its error. - void PropagateError(const CoordinatedTask& source_task, + void PropagateError(const absl::Status& error, + std::optional source_task = std::nullopt, bool is_reported_by_task = false) - ABSL_LOCKS_EXCLUDED(state_mu_); - void SetTaskError(std::string_view task_name, absl::Status error) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - void AggregateClusterDevices() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + void SetTaskError(std::string_view task_name, const absl::Status& error) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); absl::Status DisconnectTask(const CoordinatedTask& task) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); @@ -192,9 +205,32 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { // barrier). CoordinatedTask initiating_task; }; - void PassBarrier(std::string_view barrier_id, absl::Status result, + // Validates that the barrier is invoked with the right args. Returns false if + // the barrier should fail immediately. + bool ValidateBarrierArgs( + std::string_view barrier_id, absl::Duration timeout, + const CoordinatedTask& task, + const std::vector& participating_tasks, + StatusCallback done) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + // Initializes a new barrier. Returns false if the barrier should fail + // immediately. + bool InitializeBarrier( + BarrierState* barrier, std::string_view barrier_id, + absl::Duration timeout, const CoordinatedTask& task, + const std::vector& participating_tasks, + StatusCallback done) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + void PassBarrier(std::string_view barrier_id, const absl::Status& result, BarrierState* barrier) ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + // Post-barrier hook to connect all tasks. + void ConnectAllTasks() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + // Post-barrier hook to aggregate device info. + void AggregateClusterDevices() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + // Post-shutdown barrier hook to disconnect tasks that acked and propagate + // errors to those that have not. + void CompleteShutdownAfterBarrier(const absl::Status& result, + BarrierState* barrier) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); // Check if participating tasks are specified correctly across barrier calls. bool ValidateTaskArgs( const std::vector& tasks_args, @@ -203,11 +239,13 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { int64_t cluster_size); bool isRecoverableJob(std::string_view task_name) const; // Sends responses to error polling requests when an error is encountered. - void SendErrorPollingResponse(const absl::Status& error); + void SendErrorPollingResponse(const absl::Status& error) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); // Responds to error polling or stops the service when an error is // encountered. Should only be called when there is no service to client // connection. Returns true if the service stops, otherwise returns false. - bool SendErrorPollingResponseOrStopService(const absl::Status& error); + bool SendErrorPollingResponseOrStopService(const absl::Status& error) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); // Returns whether the clients are polling for error from the service. If the // clients are not polling for error from the service, the service should stop // when there is an error. Otherwise, the service should not stop. @@ -248,14 +286,25 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { // When task state becomes ERROR, propagate this status to other CONNECTED // tasks in the cluster. + explicit TaskState(absl::string_view task) { task_name_ = task; } + CoordinatedTaskState GetState() { return state_; } absl::Status GetStatus() { return status_; } uint64_t GetTaskIncarnation() { return task_incarnation_; } + void SetTaskIncarnation(uint64_t task_incarnation) { + task_incarnation_ = task_incarnation; + } + void Connect() { + SetConnected(task_incarnation_); + LOG(INFO) << task_name_ + << " has connected to coordination service. Incarnation: " + << task_incarnation_; + } void SetConnected(uint64_t task_incarnation); void Disconnect(uint64_t grace_period_duration_us); absl::Status RecordHeartbeat(uint64_t task_incarnation); int64_t TimeSinceLastHeartbeatMs(); - void SetError(absl::Status status); + void SetError(const absl::Status& status); DeviceInfo GetDeviceInfo() { return devices_; } void CollectDeviceInfo(const DeviceInfo& devices) { devices_ = devices; } // Checks if task has called WaitForAllTasks() previously, which gathers the @@ -272,6 +321,7 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { bool IsDisconnectedBeyondGracePeriod(); private: + std::string task_name_; // Incarnation ID for CPU:0 on remote task. uint64_t task_incarnation_ = 0; @@ -294,6 +344,8 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { Env& env_; const uint64_t service_incarnation_ = random::New64(); const uint64_t heartbeat_timeout_ms_; + bool cluster_register_with_barrier_ = false; + const absl::Duration cluster_register_timeout_; const absl::Duration shutdown_barrier_timeout_; // If a task restarts with a new incarnation, we may allow it to reconnect // silently if configured. This is useful when we know that a task can @@ -322,10 +374,6 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { absl::flat_hash_map> get_cb_ ABSL_GUARDED_BY(kv_mu_); - absl::CondVar check_staleness_thread_cv_; - bool shutting_down_ ABSL_GUARDED_BY(state_mu_) = false; - std::unique_ptr check_staleness_thread_; - absl::flat_hash_map barriers_ ABSL_GUARDED_BY(state_mu_); // For now, we assume there won't be many simultaneous barriers so we simply @@ -336,6 +384,13 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { ErrorPollingState error_polling_state_ ABSL_GUARDED_BY(state_mu_); + absl::CondVar check_staleness_thread_cv_; + bool shutting_down_ ABSL_GUARDED_BY(state_mu_) = false; + // Note: sequence matters here, we must destroy the staleness thread before + // the other state related to barriers and heartbeats to prevent illegal + // memory access. + std::unique_ptr check_staleness_thread_; + CoordinationServiceStandaloneImpl(const CoordinationServiceStandaloneImpl&) = delete; void operator=(const CoordinationServiceStandaloneImpl&) = delete; @@ -378,7 +433,7 @@ void CoordinationServiceStandaloneImpl::TaskState::Disconnect( } void CoordinationServiceStandaloneImpl::TaskState::SetError( - const absl::Status status) { + const absl::Status& status) { if (state_ == CoordinatedTaskState::TASKSTATE_ERROR) return; state_ = CoordinatedTaskState::TASKSTATE_ERROR; status_ = status; @@ -389,8 +444,11 @@ absl::Status CoordinationServiceStandaloneImpl::TaskState::RecordHeartbeat( if (!status_.ok()) return status_; if (task_incarnation != task_incarnation_) { return MakeCoordinationError(absl::AbortedError(absl::StrCat( - "Incarnation ID mismatch: expecting ", task_incarnation_, " but got ", - task_incarnation, ". This means the remote task has restarted."))); + task_name_, "Heartbeat: Incarnation ID mismatch: expecting ", + task_incarnation_, " but got ", task_incarnation, + ". The task has restarted and likely crashed earlier - check for any " + "earlier errors or any scheduler events (e.g. preemption, eviction) to " + "debug further."))); } absl::MutexLock l(&last_heartbeat_mu_); last_heartbeat_us_ = Env::Default()->NowMicros(); @@ -440,6 +498,9 @@ CoordinationServiceStandaloneImpl::CoordinationServiceStandaloneImpl( ? config.heartbeat_timeout_in_ms() : kDefaultHeartbeatTimeoutMs; }()), + cluster_register_with_barrier_(config.cluster_register_with_barrier()), + cluster_register_timeout_( + absl::Milliseconds(config.cluster_register_timeout_in_ms())), shutdown_barrier_timeout_( absl::Milliseconds(config.shutdown_barrier_timeout_in_ms())), allow_new_incarnation_to_reconnect_( @@ -450,7 +511,7 @@ CoordinationServiceStandaloneImpl::CoordinationServiceStandaloneImpl( for (const auto& job : config.coordinated_job_list()) { for (int i = 0; i < job.num_tasks(); ++i) { const std::string task_name = GetTaskName(job.name(), i); - cluster_state_.emplace(task_name, std::make_unique()); + cluster_state_.emplace(task_name, std::make_unique(task_name)); } } StartCheckStaleness(); @@ -459,108 +520,87 @@ CoordinationServiceStandaloneImpl::CoordinationServiceStandaloneImpl( void CoordinationServiceStandaloneImpl::CheckHeartbeatTimeout() { absl::Status status = absl::OkStatus(); std::vector stale_task_names; - const bool has_service_to_client_connection = client_cache_ != nullptr; - { - absl::MutexLock l(&state_mu_); - for (const auto& [task_name, task_state] : cluster_state_) { - // Skip tasks that are not registered or in error state - if (task_state->GetState() != CoordinatedTaskState::TASKSTATE_CONNECTED) { - continue; - } - const bool is_stale = - task_state->TimeSinceLastHeartbeatMs() > heartbeat_timeout_ms_; - VLOG(10) << "Checking staleness for " << task_name - << " stale?=" << is_stale; - if (is_stale) { - stale_task_names.push_back(task_name); - status = MakeCoordinationError(absl::UnavailableError( - absl::StrCat("Task ", task_name, - " heartbeat timeout. This indicates that the " - "remote task has failed, got preempted, or " - "crashed unexpectedly. Check the task logs " - "for an earlier error to debug further."))); - SetTaskError(task_name, status); + absl::MutexLock l(&state_mu_); + for (const auto& [task_name, task_state] : cluster_state_) { + // Skip tasks that are not registered or in error state. + if (task_state->GetState() != CoordinatedTaskState::TASKSTATE_CONNECTED) { + continue; + } + const bool is_stale = + task_state->TimeSinceLastHeartbeatMs() > heartbeat_timeout_ms_; + VLOG(10) << "Checking staleness for " << task_name + << " stale?=" << is_stale; + if (is_stale) { + stale_task_names.push_back(task_name); + status = MakeCoordinationError(absl::UnavailableError( + absl::StrCat("Task ", task_name, + " heartbeat timeout. This indicates that the " + "remote task has failed, got preempted, or " + "crashed unexpectedly. Check the task logs " + "for an earlier error or scheduler events (e.g. " + "preemption, eviction) to debug further."))); + + SetTaskError(task_name, status); + if (ServiceHasStopped()) { + // Setting the task to error may cause service to stop (e.g. task is + // waiting for shutdown barrier). In this case, all the state is invalid + // and we should exit immediately. + return; } } } // Propagate heartbeat timeout errors to other connected tasks. if (!stale_task_names.empty()) { - if (!has_service_to_client_connection) { - absl::Status heartbeat_timeout_error = - MakeCoordinationError(absl::UnavailableError(absl::StrCat( - "The following tasks are unhealthy (stopped sending " - "heartbeats):\n", - absl::StrJoin(stale_task_names, "\n"), - "\nCheck the task logs for an earlier error to debug " - "further."))); - if (SendErrorPollingResponseOrStopService(heartbeat_timeout_error)) { - return; - } - } else { - for (const auto& stale_task_name : stale_task_names) { - PropagateError(GetTaskFromName(stale_task_name)); - } - } + absl::Status heartbeat_timeout_error = + MakeCoordinationError(absl::UnavailableError( + absl::StrCat("The following tasks are unhealthy (stopped sending " + "heartbeats):\n", + absl::StrJoin(stale_task_names, "\n"), + "\nThe tasks have crashed. Check the task logs for an " + "earlier error, or scheduler events (e.g. preemption, " + "eviction) to debug further."))); + PropagateError(heartbeat_timeout_error); } } void CoordinationServiceStandaloneImpl::CheckBarrierTimeout() { - const bool has_service_to_client_connection = client_cache_ != nullptr; absl::flat_hash_map expired_barriers; uint64_t current_time_micros = Env::Default()->NowMicros(); - { - absl::MutexLock l(&state_mu_); - // Gather barriers which have timed out. - for (std::string_view barrier_id : ongoing_barriers_) { - auto* barrier = &barriers_[barrier_id]; - if (current_time_micros > barrier->deadline_in_micros) { - expired_barriers[barrier_id] = barrier; - } + absl::MutexLock l(&state_mu_); + // Gather barriers which have timed out. + for (std::string_view barrier_id : ongoing_barriers_) { + auto* barrier = &barriers_[barrier_id]; + if (current_time_micros > barrier->deadline_in_micros) { + expired_barriers[barrier_id] = barrier; } - // Pass these barriers with the time out error. - for (const auto& [barrier_id, barrier] : expired_barriers) { - std::string pending_tasks; - int pending_task_count = 0; - for (const auto& [task, at_barrier] : barrier->tasks_at_barrier) { - if (at_barrier) { - continue; - } - ++pending_task_count; - if (pending_task_count > kPendingTaskLogLimit) { - break; - } - absl::StrAppend(&pending_tasks, GetTaskName(task), "\n"); + } + // Pass these barriers with the time out error. + for (const auto& [barrier_id, barrier] : expired_barriers) { + std::string pending_tasks; + int pending_task_count = 0; + // Count and track pending tasks that have not reached the barrier. + for (const auto& [task, at_barrier] : barrier->tasks_at_barrier) { + if (at_barrier) { + continue; } - std::string error_message = absl::StrFormat( - "Barrier timed out. This usually happens because a task " - "triggered the barrier unexpectedly early, or some tasks are " - "too slow. Please look at the other task logs to debug " - "further. Barrier_id: %s. The first task at the barrier: " - "%s. ", - barrier_id, GetTaskName(barrier->initiating_task)); - if (pending_task_count > kPendingTaskLogLimit) { - absl::StrAppend( - &error_message, "Too many tasks have timed out. The first ", - kPendingTaskLogLimit, " timed out task names:\n", pending_tasks); - } else { - absl::StrAppend(&error_message, - "Total Number of tasks already at the barrier: ", - barrier->tasks_at_barrier.size() - pending_task_count, - "/", barrier->tasks_at_barrier.size(), - ". Timed out task names:\n", pending_tasks); + ++pending_task_count; + if (pending_task_count < kPendingTaskLogLimit) { + absl::StrAppend(&pending_tasks, GetTaskName(task), "\n"); } - const absl::Status error = - MakeCoordinationError(absl::DeadlineExceededError(error_message)); - PassBarrier(barrier_id, error, barrier); } - } - if (!has_service_to_client_connection && - expired_barriers.contains(shutdown_barrier_id_)) { - // Error cannot be propagated through service-to-client connection. - SendErrorPollingResponseOrStopService( - MakeCoordinationError(absl::DeadlineExceededError( - "Shutdown barrier timed out. Check the task logs for an " - "earlier error."))); + const int64_t tasks_at_barrier = + barrier->tasks_at_barrier.size() - pending_task_count; + std::string error_message = absl::StrFormat( + "Barrier timed out. Id: %s. This usually happens because a task " + "triggered the barrier too early or too slowly. Please look at the " + "task logs (both timed out and first task) to debug further.\n" + "# of tasks that reached the barrier: %d/%d.\nThe first " + "task at the barrier: %s. Some timed out task names:\n%s", + barrier_id, tasks_at_barrier, barrier->tasks_at_barrier.size(), + GetTaskName(barrier->initiating_task), pending_tasks); + const absl::Status error = + MakeCoordinationError(absl::DeadlineExceededError(error_message)); + PassBarrier(barrier_id, error, barrier); } } @@ -586,7 +626,11 @@ void CoordinationServiceStandaloneImpl::StartCheckStaleness() { this))); } -void CoordinationServiceStandaloneImpl::Stop(bool shut_staleness_thread) { +void CoordinationServiceStandaloneImpl::Stop() { + // Prevent recursion. + if (shutting_down_) { + return; + } { absl::MutexLock l(&kv_mu_); for (const auto& [key, get_kv_callbacks] : get_cb_) { @@ -599,38 +643,31 @@ void CoordinationServiceStandaloneImpl::Stop(bool shut_staleness_thread) { } get_cb_.clear(); } - { - absl::MutexLock l(&state_mu_); - // Indicate that the service is shutting down and stop accepting new RPCs. - shutting_down_ = true; - // Stop the heartbeat thread. - check_staleness_thread_cv_.SignalAll(); - // Fail all ongoing barriers. - for (auto& [barrier_id, barrier] : barriers_) { - if (!barrier.passed) { - absl::Status error = - MakeCoordinationError(absl::AbortedError(absl::StrCat( - "Barrier failed because service is shutting down. Barrier_id: ", - barrier_id))); - PassBarrier(barrier_id, error, &barrier); - } + // Indicate that the service is shutting down and stop accepting new RPCs. + shutting_down_ = true; + // Stop the heartbeat thread. + check_staleness_thread_cv_.SignalAll(); + // Fail all ongoing barriers. + for (auto& [barrier_id, barrier] : barriers_) { + if (!barrier.passed) { + absl::Status error = + MakeCoordinationError(absl::AbortedError(absl::StrCat( + "Barrier failed because service is shutting down. Barrier_id: ", + barrier_id))); + PassBarrier(barrier_id, error, &barrier); } - barriers_.clear(); - // Erase cluster state. - // Note: sequence matters here, this must happen after barrier clean-up as - // the state is used in `PassBarrier`. - cluster_state_.clear(); } + barriers_.clear(); + // Erase cluster state. + // Note: sequence matters here, this must happen after barrier clean-up as + // the state is used in `PassBarrier`. + cluster_state_.clear(); // Cancel all pending PollForErrorAsync() calls. if (IsClientPollingForError()) { SendErrorPollingResponse( absl::CancelledError("Coordination service is shutting down. " "Cancelling PollForErrorAsync()")); } - // Destroy thread outside of the mutex. - if (shut_staleness_thread) { - check_staleness_thread_.reset(); - } } bool CoordinationServiceStandaloneImpl::ServiceHasStopped() const { @@ -659,84 +696,128 @@ void CoordinationServiceStandaloneImpl::LogConnectStatusLocked() const { absl::Status CoordinationServiceStandaloneImpl::RegisterTask( const CoordinatedTask& task, uint64_t incarnation) { + absl::Notification done; + absl::Status status; + RegisterTaskAsync(task, incarnation, [&](absl::Status s) { + status = s; + done.Notify(); + }); + done.WaitForNotification(); + return status; +} + +StatusCallback CoordinationServiceStandaloneImpl::ConnectAfterBarrierPasses( + absl::string_view task_name, uint64_t incarnation, StatusCallback done) { + return [this, task = std::string(task_name), incarnation, + done = std::move(done)](absl::Status s) mutable { + state_mu_.AssertHeld(); + if (!s.ok()) { + done(s); + } else if (incarnation == cluster_state_[task]->GetTaskIncarnation()) { + // Connect task to service. + cluster_state_[task]->Connect(); + done(absl::OkStatus()); + } else { + // Avoid using `AbortedError` which typically has retry semantics. + done(MakeCoordinationError( + absl::AlreadyExistsError("Aborted connect attempt as there is a " + "request from a newer incarnation."))); + } + }; +} + +void CoordinationServiceStandaloneImpl::RegisterTaskAsync( + const CoordinatedTask& task, uint64_t incarnation, StatusCallback done) { const std::string task_name = GetTaskName(task); - absl::Status error; std::string error_message; - { - absl::MutexLock l(&state_mu_); - if (ServiceHasStopped()) { - return MakeCoordinationError(absl::InternalError(absl::StrCat( - "Coordination service has stopped. RegisterTask() from task: ", - task_name, - " failed. This usually implies an earlier error that caused " - "coordination service to shut down before the workers disconnect " - "gracefully. Check the task leader's logs for an earlier error to " - "debug the root cause."))); - } - if (!cluster_state_.contains(task_name)) { - // Note: return early here as unexpected task register errors should not - // be propagated to other tasks. - return MakeCoordinationError(absl::InvalidArgumentError(absl::StrCat( - "Unexpected task registered with task_name=", task_name))); - } + absl::MutexLock l(&state_mu_); + if (ServiceHasStopped()) { + done(MakeCoordinationError(absl::InternalError(absl::StrCat( + "Coordination service has stopped. RegisterTask() from task: ", + task_name, + " failed. This usually implies an earlier error that caused " + "coordination service to shut down before the workers disconnect " + "gracefully. Check the task leader's logs for an earlier error or " + "scheduler events (e.g. preemption, eviction) to debug the root " + "cause.")))); + return; + } + if (!cluster_state_.contains(task_name)) { + // Note: return early here as unexpected task register errors should not + // be propagated to other tasks. + done(MakeCoordinationError(absl::InvalidArgumentError(absl::StrCat( + "Unexpected task registered with task_name=", task_name)))); + return; + } - auto* task_cluster_state = cluster_state_[task_name].get(); - const auto task_state = task_cluster_state->GetState(); - const auto task_status = task_cluster_state->GetStatus(); - - if (task_state == CoordinatedTaskState::TASKSTATE_DISCONNECTED || - (allow_new_incarnation_to_reconnect_ && - (absl::IsUnavailable(task_status) && - task_status.GetPayload(CoordinationErrorPayloadKey())))) { - // The task is allowed to register itself if: - // - this task is currently disconnected (registering for the first time - // or has called ResetTask() previously). - // - this task has lost connection previously which caused it to have - // an unavailable error state, but has now restarted (possibly with - // a new incarnation). This is only allowed if configured with - // `allow_new_incarnation_to_reconnect`. - task_cluster_state->SetConnected(incarnation); - LOG(INFO) << task_name - << " has connected to coordination service. Incarnation: " - << incarnation; + auto* task_cluster_state = cluster_state_[task_name].get(); + const auto task_state = task_cluster_state->GetState(); + const auto task_status = task_cluster_state->GetStatus(); + + if (task_state == CoordinatedTaskState::TASKSTATE_DISCONNECTED || + (allow_new_incarnation_to_reconnect_ && + (absl::IsUnavailable(task_status) && + task_status.GetPayload(CoordinationErrorPayloadKey())))) { + // The task is allowed to register itself if: + // - this task is currently disconnected (registering for the first time + // or has called ResetTask() previously). + // - this task has lost connection previously which caused it to have + // an unavailable error state, but has now restarted (possibly with + // a new incarnation). This is only allowed if configured with + // `allow_new_incarnation_to_reconnect`. + if (cluster_register_with_barrier_) { + // Impose barrier so that all tasks can register together. + // Note: it is possible that the same task restarts multiple times and + // registers itself with new incarnations. + // That is okay; in this code branch, the tasks are not connected yet, + // and the barrier has not succeeded yet. + // There is no state that needs to be cleaned up. + task_cluster_state->SetTaskIncarnation(incarnation); + BarrierAsyncLocked( + kClusterRegisterBarrierId, cluster_register_timeout_, task, {}, + ConnectAfterBarrierPasses(task_name, incarnation, std::move(done))); + return; + } + task_cluster_state->SetTaskIncarnation(incarnation); + task_cluster_state->Connect(); + // TODO(b/369222279): Think about the barrier case - may need periodic + // reporting of stragglers. + LogConnectStatusLocked(); + done(absl::OkStatus()); + return; + } else if (task_state == CoordinatedTaskState::TASKSTATE_CONNECTED) { + // This may happen if the service processes the initial RegisterTask(), + // but the agent did not receive the response so the agent retries again. + if (task_cluster_state->GetTaskIncarnation() == incarnation) { + // This should be a no-op, but we update the last heartbeat timestamp + // to give a longer grace period for the agent to start sending + // heartbeats. + task_cluster_state->Connect(); LogConnectStatusLocked(); - return absl::OkStatus(); - } else if (task_state == CoordinatedTaskState::TASKSTATE_CONNECTED) { - // This may happen if the service processes the initial RegisterTask(), - // but the agent did not receive the response so the agent retries again. - if (task_cluster_state->GetTaskIncarnation() == incarnation) { - // This should be a no-op, but we update the last heartbeat timestamp - // to give a longer grace period for the agent to start sending - // heartbeats. - task_cluster_state->SetConnected(incarnation); - LOG(INFO) << task_name - << " has connected to coordination service with the same " - << "incarnation again: " << incarnation; - LogConnectStatusLocked(); - return absl::OkStatus(); - } else { - error_message = - absl::StrCat(task_name, - " unexpectedly tried to connect with a different " - "incarnation. It has likely restarted."); - } + done(absl::OkStatus()); + return; } else { - // This task is connected or already in error, which implies it has - // registered previously. error_message = absl::StrCat(task_name, - " unexpectedly tried to connect while it is already in " - "error. ResetTask() should be called before a " - "subsequent connect attempt."); + " unexpectedly tried to connect with a different " + "incarnation. It has likely restarted."); } - LOG(ERROR) << error_message; - error = MakeCoordinationError(absl::AbortedError(error_message), task); - SetTaskError(task_name, error); - } - assert(!error.ok()); - PropagateError(task); - return error; + } else { + // This task is connected or already in error, which implies it has + // registered previously. + error_message = + absl::StrCat(task_name, + " unexpectedly tried to connect while it is already in " + "error. ResetTask() should be called before a " + "subsequent connect attempt."); + } + LOG(ERROR) << error_message; + absl::Status error = + MakeCoordinationError(absl::AbortedError(error_message), task); + SetTaskError(task_name, error); + PropagateError(error, task); + done(error); } void CoordinationServiceStandaloneImpl::WaitForAllTasks( @@ -815,8 +896,8 @@ absl::Status CoordinationServiceStandaloneImpl::DisconnectTask( for (const auto& barrier_id : cluster_state_[task_name]->GetOngoingBarriers()) { absl::Status error = MakeCoordinationError(absl::InternalError(absl::StrCat( - "Barrier failed from a disconnected task. Barrier Id: ", barrier_id, - ", Task: ", task_name))); + "Barrier failed because a task has disconnected. Barrier Id: ", + barrier_id, ", Task: ", task_name))); PassBarrier(barrier_id, error, &barriers_[barrier_id]); } @@ -833,25 +914,22 @@ uint64_t CoordinationServiceStandaloneImpl::GetServiceIncarnation() { } absl::Status CoordinationServiceStandaloneImpl::ReportTaskError( - const CoordinatedTask& task, absl::Status error) { + const CoordinatedTask& task, const absl::Status& error) { const std::string task_name = GetTaskName(task); - { - absl::MutexLock l(&state_mu_); - if (ServiceHasStopped()) { - return MakeCoordinationError(absl::InternalError( - "Coordination service has stopped. ReportTaskError() failed.")); - } else if (!cluster_state_.contains(task_name)) { - return MakeCoordinationError(absl::InvalidArgumentError( - absl::StrCat("Unexpected request from task ", task_name))); - } else if (cluster_state_[task_name]->GetState() != - CoordinatedTaskState::TASKSTATE_CONNECTED) { - return MakeCoordinationError(absl::FailedPreconditionError( - "The task is not connected or already has an error.")); - } else { - SetTaskError(task_name, error); - } + absl::MutexLock l(&state_mu_); + if (ServiceHasStopped()) { + return MakeCoordinationError(absl::InternalError( + "Coordination service has stopped. ReportTaskError() failed.")); + } else if (!cluster_state_.contains(task_name)) { + return MakeCoordinationError(absl::InvalidArgumentError( + absl::StrCat("Unexpected request from task ", task_name))); + } else if (cluster_state_[task_name]->GetState() != + CoordinatedTaskState::TASKSTATE_CONNECTED) { + return MakeCoordinationError(absl::FailedPreconditionError( + "The task is not connected or already has an error.")); } - PropagateError(task, /*is_reported_by_task=*/true); + SetTaskError(task_name, error); + PropagateError(error, task, /*is_reported_by_task=*/true); return absl::OkStatus(); } @@ -883,132 +961,87 @@ absl::Status CoordinationServiceStandaloneImpl::RecordHeartbeat( const CoordinatedTask& task, uint64_t incarnation) { const std::string task_name = GetTaskName(task); absl::Status s = absl::OkStatus(); - { - absl::MutexLock l(&state_mu_); - if (ServiceHasStopped()) { - return MakeCoordinationError(absl::InternalError(absl::StrCat( - "Coordination service has stopped. RecordHeartbeat() from task: ", - task_name, - " failed. This usually implies an earlier error that caused " - "coordination service to shut down before the workers disconnect " - "gracefully. Check the task leader's logs for an earlier error to " - "debug the root cause."))); - } else if (!cluster_state_.contains(task_name)) { - return MakeCoordinationError(absl::InvalidArgumentError( - absl::StrCat("Unexpected heartbeat request from task: ", task_name, - ". This usually implies a configuration error."))); - } - if (!cluster_state_[task_name]->GetStatus().ok()) { - return cluster_state_[task_name]->GetStatus(); - } else if (cluster_state_[task_name]->IsDisconnectedBeyondGracePeriod()) { - // We accept heartbeats for a short grace period to account for the lag - // time between the service recording the state change and the agent - // stopping heartbeats. - return MakeCoordinationError(absl::InvalidArgumentError(absl::StrCat( - "Task with task_name=", task_name, - " must be registered before sending heartbeat messages"))); - } - VLOG(10) << "Record heartbeat from task: " << task_name - << "at incarnation: " << incarnation << "at " << absl::Now(); - s = cluster_state_[task_name]->RecordHeartbeat(incarnation); - } + absl::MutexLock l(&state_mu_); + if (ServiceHasStopped()) { + return MakeCoordinationError(absl::InternalError(absl::StrCat( + "Coordination service has stopped. RecordHeartbeat() from task: ", + task_name, + " failed. This usually implies an earlier error that caused " + "coordination service to shut down before the workers disconnect " + "gracefully. Check the task leader's logs for an earlier error or " + "scheduler events (e.g. preemption, eviction) to debug the root " + "cause."))); + } else if (!cluster_state_.contains(task_name)) { + return MakeCoordinationError(absl::InvalidArgumentError( + absl::StrCat("Unexpected heartbeat request from task: ", task_name, + ". This usually implies a configuration error."))); + } + if (!cluster_state_[task_name]->GetStatus().ok()) { + return cluster_state_[task_name]->GetStatus(); + } else if (cluster_state_[task_name]->IsDisconnectedBeyondGracePeriod()) { + // We accept heartbeats for a short grace period to account for the lag + // time between the service recording the state change and the agent + // stopping heartbeats. + return MakeCoordinationError(absl::InvalidArgumentError( + absl::StrCat("Task with task_name=", task_name, + " must be registered before sending heartbeat messages. " + "The service might have restarted, please restart / reset " + "and register again."))); + } + VLOG(10) << "Record heartbeat from task: " << task_name + << "at incarnation: " << incarnation << "at " << absl::Now(); + s = cluster_state_[task_name]->RecordHeartbeat(incarnation); // Set and propagate any heartbeat errors. if (!s.ok()) { - { - absl::MutexLock l(&state_mu_); - SetTaskError(task_name, s); - } - PropagateError(task); + SetTaskError(task_name, s); + PropagateError(s, task); } return s; } -void CoordinationServiceStandaloneImpl::ReportServiceErrorToTaskAsync( - const CoordinatedTask& destination_task, absl::Status error) { +void CoordinationServiceStandaloneImpl::PropagateError( + const absl::Status& error, std::optional source_task, + bool is_reported_by_task) { + VLOG(3) << "PropagateError(): " << error; assert(!error.ok()); - - // Don't report error if there is no service-to-client connection. + // If there is no service-to-client connection, use error polling or stop + // the service. if (client_cache_ == nullptr) { - LOG(ERROR) << error; + SendErrorPollingResponseOrStopService(error); return; } - auto request = std::make_shared(); - auto response = std::make_shared(); - request->set_error_code(error.raw_code()); - request->set_error_message(std::string(error.message())); - CoordinatedTask* error_source = - request->mutable_error_payload()->mutable_source_task(); - error_source->set_job_name("coordination_service"); - auto call_opts = std::make_shared(); - call_opts->SetTimeout(kServiceToClientTimeoutMs); - - const std::string task_name = GetTaskName(destination_task); - CoordinationClient* client = client_cache_->GetClient(task_name); - client->ReportErrorToTaskAsync( - call_opts.get(), request.get(), response.get(), - [request, response, task_name, call_opts](absl::Status s) { - if (!s.ok()) { - LOG(ERROR) << "Encountered another error while reporting to " - << task_name << ": " << s; - } - }); -} - -void CoordinationServiceStandaloneImpl::PropagateError( - const CoordinatedTask& source_task, bool is_reported_by_task) { - VLOG(3) << "PropagateError() from " << GetTaskName(source_task); - // If the error task is recoverable, do not propagate the error to other - // connected tasks. - if (isRecoverableJob(source_task.job_name())) return; - absl::Status error; - { - absl::MutexLock l(&state_mu_); - error = cluster_state_[GetTaskName(source_task)]->GetStatus(); - } - assert(!error.ok()); ReportErrorToTaskRequest request; request.set_error_code(error.raw_code()); request.set_error_message(std::string(error.message())); CoordinationServiceError* payload = request.mutable_error_payload(); - *payload->mutable_source_task() = source_task; payload->set_is_reported_error(is_reported_by_task); CallOptions call_opts; call_opts.SetTimeout(kServiceToClientTimeoutMs); - std::vector> notifications; - - std::vector task_names; - { - absl::ReaderMutexLock l(&state_mu_); - task_names.reserve(cluster_state_.size()); - for (const auto& pair : cluster_state_) { - task_names.emplace_back(pair.first); - } + if (source_task.has_value()) { + // If the error task is recoverable, do not propagate the error to other + // connected tasks. + if (isRecoverableJob(source_task->job_name())) return; + *payload->mutable_source_task() = *source_task; } - for (std::string_view task : task_names) { - { - absl::MutexLock l(&state_mu_); - // Propagate error only to tasks that are connected - if (cluster_state_[task]->GetState() != - CoordinatedTaskState::TASKSTATE_CONNECTED) - continue; - } - // If there is no service-to-client connection, use error polling or stop - // the service. - if (client_cache_ == nullptr) { - SendErrorPollingResponseOrStopService(error); - return; + std::vector> notifications; + + for (const auto& pair : cluster_state_) { + // Propagate error only to tasks that are connected + if (pair.second->GetState() != CoordinatedTaskState::TASKSTATE_CONNECTED) { + continue; } + std::string task = pair.first; - CoordinationClient* client = client_cache_->GetClient(std::string(task)); + CoordinationClient* client = client_cache_->GetClient(task); auto response = std::make_shared(); auto n = std::make_shared(); client->ReportErrorToTaskAsync( &call_opts, &request, response.get(), - [response, n, task](absl::Status s) { + [response, n, task](const absl::Status& s) { if (!s.ok()) { LOG(ERROR) << "Encountered another error while reporting to " << task << ": " << s; @@ -1156,19 +1189,19 @@ absl::Status CoordinationServiceStandaloneImpl::DeleteKeyValue( return absl::OkStatus(); } -void CoordinationServiceStandaloneImpl::SetTaskError(std::string_view task_name, - absl::Status error) { +void CoordinationServiceStandaloneImpl::SetTaskError( + std::string_view task_name, const absl::Status& error) { cluster_state_[task_name]->SetError(error); + LOG(ERROR) << task_name + << " has been set to ERROR in coordination service: " << error; for (const auto& barrier_id : cluster_state_[task_name]->GetOngoingBarriers()) { - absl::Status error = MakeCoordinationError(absl::InternalError(absl::StrCat( - "Barrier failed from a task error. Barrier Id: ", barrier_id, - ", Task: ", task_name))); - PassBarrier(barrier_id, error, &barriers_[barrier_id]); + absl::Status barrier_error = + MakeCoordinationError(absl::InternalError(absl::StrCat( + "Barrier failed beacuse a task is in error. Barrier Id: ", + barrier_id, ", Task: ", task_name, " Error: ", error.ToString()))); + PassBarrier(barrier_id, barrier_error, &barriers_[barrier_id]); } - - LOG(ERROR) << task_name - << " has been set to ERROR in coordination service: " << error; } void CoordinationServiceStandaloneImpl::PollForErrorAsync( @@ -1230,14 +1263,13 @@ void CoordinationServiceStandaloneImpl::PollForErrorAsync( error_polling_state_.AddTask(task, std::move(done)); } -void CoordinationServiceStandaloneImpl::BarrierAsync( +// Validates that the barrier is invoked with the right args. Returns false if +// the barrier should fail immediately. +bool CoordinationServiceStandaloneImpl::ValidateBarrierArgs( std::string_view barrier_id, absl::Duration timeout, const CoordinatedTask& task, const std::vector& participating_tasks, StatusCallback done) { - VLOG(3) << "Task " << GetTaskName(task) << " invoked BarrierAsync(" - << barrier_id << ")."; - // Check if caller task is participating in the barrier. If not, update // `barriers_` to cause subsequent calls from the same task and other tasks // that have already called this instance of the barrier to fail. @@ -1254,26 +1286,119 @@ void CoordinationServiceStandaloneImpl::BarrierAsync( absl::Status error = MakeCoordinationError(absl::InvalidArgumentError( absl::StrCat("A non-participating task (", GetTaskName(task), ") called the barrier: ", barrier_id))); - { - absl::MutexLock l(&state_mu_); - // Check if coordination service has stopped. If so, return an error - // immediately. - if (ServiceHasStopped()) { - done(MakeCoordinationError(absl::InternalError( - "Barrier requested after coordination service has shut down."))); - return; + // Check if coordination service has stopped. If so, return an error + // immediately. + if (ServiceHasStopped()) { + done(MakeCoordinationError(absl::InternalError( + "Barrier requested after coordination service has shut down."))); + return false; + } + auto pair = barriers_.try_emplace(barrier_id); + auto it = pair.first; + auto* barrier = &it->second; + // Make sure subsequent calls fail and existing waiting tasks receive the + // error. + PassBarrier(barrier_id, error, barrier); + done(error); + return false; + } + return true; +}; + +// Initializes a new barrier. Returns false if the barrier should fail +// immediately. +bool CoordinationServiceStandaloneImpl::InitializeBarrier( + BarrierState* barrier, std::string_view barrier_id, absl::Duration timeout, + const CoordinatedTask& task, + const std::vector& participating_tasks, + StatusCallback done) { + // Initialize barrier state. + barrier->passed = false; + barrier->initiating_task = task; + // Assume barrier is for entire cluster if no tasks are specified. + if (participating_tasks.empty()) { + for (const auto& task_state : cluster_state_) { + std::string_view task_name = task_state.first; + barrier->tasks_at_barrier[GetTaskFromName(task_name)] = false; + } + } else { + for (const auto& task : participating_tasks) { + // Fail the barrier immediately if unexpected task is included in the + // barrier. + const std::string task_name = GetTaskName(task); + if (!cluster_state_.contains(task_name)) { + absl::Status error = MakeCoordinationError(absl::InvalidArgumentError( + absl::StrCat("Unexpected task (", task_name, + ") that is not in the cluster called the barrier. " + "Barrier Id: ", + barrier_id))); + PassBarrier(barrier_id, error, barrier); + done(error); + return false; } - auto pair = barriers_.try_emplace(barrier_id); - auto it = pair.first; - auto* barrier = &it->second; - // Make sure subsequent calls fail and existing waiting tasks receive the - // error. + barrier->tasks_at_barrier[task] = false; + } + } + barrier->num_pending_tasks = barrier->tasks_at_barrier.size(); + + // Fail the barrier immediately if any tasks are already in error. + for (const auto& pending_task : barrier->tasks_at_barrier) { + const std::string task_name = GetTaskName(pending_task.first); + if (cluster_state_[task_name]->GetState() == + CoordinatedTaskState::TASKSTATE_ERROR) { + absl::Status error = MakeCoordinationError(absl::InternalError( + absl::StrCat("Task (", task_name, + ") is already in error before the barrier " + "was called. Barrier Id: ", + barrier_id, " Task error: ", + cluster_state_[task_name]->GetStatus().ToString()))); PassBarrier(barrier_id, error, barrier); + done(error); + return false; } - done(error); - return; } + barrier->deadline_in_micros = + Env::Default()->NowMicros() + (timeout / absl::Microseconds(1)); + + // Add ongoing barrier to cluster state. + ongoing_barriers_.emplace(barrier_id); + const size_t num_ongoing_barriers = ongoing_barriers_.size(); + if (num_ongoing_barriers > kOngoingBarriersSoftLimit) { + LOG(WARNING) << "There is a high number of ongoing barriers in " + "coordination service: " + << num_ongoing_barriers; + } + for (const auto& pending_task : barrier->tasks_at_barrier) { + const CoordinatedTask& task = pending_task.first; + cluster_state_[GetTaskName(task)]->JoinBarrier(barrier_id); + } + return true; +} + +void CoordinationServiceStandaloneImpl::BarrierAsync( + // Note: `barrier_id` uses a `std::string` instead of `string_view` as the + // RPC may end (i.e. done callback is invoked) before this handler + // completes, which would invalidate the `string_view`. + std::string barrier_id, absl::Duration timeout, const CoordinatedTask& task, + const std::vector& participating_tasks, + StatusCallback done) { absl::MutexLock l(&state_mu_); + return BarrierAsyncLocked(barrier_id, timeout, task, participating_tasks, + std::move(done)); +}; + +void CoordinationServiceStandaloneImpl::BarrierAsyncLocked( + std::string barrier_id, absl::Duration timeout, const CoordinatedTask& task, + const std::vector& participating_tasks, + StatusCallback done) { + VLOG(3) << "Task " << GetTaskName(task) << " invoked BarrierAsync(" + << barrier_id << ")."; + + if (!ValidateBarrierArgs(barrier_id, timeout, task, participating_tasks, + done)) { + return; // Exit early if args are wrong. + } + // Check if coordination service has stopped. If so, return an error // immediately. if (ServiceHasStopped()) { @@ -1281,70 +1406,17 @@ void CoordinationServiceStandaloneImpl::BarrierAsync( "Barrier requested after coordination service has shut down."))); return; } + auto pair = barriers_.try_emplace(barrier_id); auto it = pair.first; bool inserted = pair.second; auto* barrier = &it->second; + // Create barrier for the first time. if (inserted) { - // Initialize barrier state. - barrier->passed = false; - barrier->initiating_task = task; - // Assume barrier is for entire cluster if no tasks are specified. - if (participating_tasks.empty()) { - for (const auto& task_state : cluster_state_) { - std::string_view task_name = task_state.first; - barrier->tasks_at_barrier[GetTaskFromName(task_name)] = false; - } - } else { - for (const auto& task : participating_tasks) { - // Fail the barrier immediately if unexpected task is included in the - // barrier. - const std::string task_name = GetTaskName(task); - if (!cluster_state_.contains(task_name)) { - absl::Status error = MakeCoordinationError(absl::InvalidArgumentError( - absl::StrCat("Unexpected task (", task_name, - ") that is not in the cluster called the barrier. " - "Barrier Id: ", - barrier_id))); - PassBarrier(barrier_id, error, barrier); - done(error); - return; - } - barrier->tasks_at_barrier[task] = false; - } - } - barrier->num_pending_tasks = barrier->tasks_at_barrier.size(); - - // Fail the barrier immediately if any tasks are already in error. - for (const auto& pending_task : barrier->tasks_at_barrier) { - const std::string task_name = GetTaskName(pending_task.first); - if (cluster_state_[task_name]->GetState() == - CoordinatedTaskState::TASKSTATE_ERROR) { - absl::Status error = MakeCoordinationError(absl::InternalError( - absl::StrCat("Task (", task_name, - ") is already in error before the barrier " - "was called. Barrier Id: ", - barrier_id))); - PassBarrier(barrier_id, error, barrier); - done(error); - return; - } - } - barrier->deadline_in_micros = - Env::Default()->NowMicros() + (timeout / absl::Microseconds(1)); - - // Add ongoing barrier to cluster state. - ongoing_barriers_.emplace(barrier_id); - const size_t num_ongoing_barriers = ongoing_barriers_.size(); - if (num_ongoing_barriers > kOngoingBarriersSoftLimit) { - LOG(WARNING) << "There is a high number of ongoing barriers in " - "coordination service: " - << num_ongoing_barriers; - } - for (const auto& pending_task : barrier->tasks_at_barrier) { - const CoordinatedTask& task = pending_task.first; - cluster_state_[GetTaskName(task)]->JoinBarrier(barrier_id); + if (!InitializeBarrier(barrier, barrier_id, timeout, task, + participating_tasks, done)) { + return; // Exit early if barrier init failed. } } @@ -1392,7 +1464,10 @@ void CoordinationServiceStandaloneImpl::BarrierAsync( } absl::Status CoordinationServiceStandaloneImpl::CancelBarrier( - std::string_view barrier_id, const CoordinatedTask& task) { + // Note: `barrier_id` uses a `std::string` instead of `string_view` as the + // RPC may end (i.e. done callback is invoked) before this handler + // completes, which would invalidate the `string_view`. + std::string barrier_id, const CoordinatedTask& task) { absl::MutexLock l(&state_mu_); if (ServiceHasStopped()) { return MakeCoordinationError(absl::InternalError( @@ -1424,7 +1499,7 @@ absl::Status CoordinationServiceStandaloneImpl::CancelBarrier( // Mark barrier as passed. void CoordinationServiceStandaloneImpl::PassBarrier(std::string_view barrier_id, - absl::Status result, + const absl::Status& result, BarrierState* barrier) { barrier->passed = true; barrier->result = result; @@ -1438,47 +1513,32 @@ void CoordinationServiceStandaloneImpl::PassBarrier(std::string_view barrier_id, const CoordinatedTask& task = task_at_barrier.first; cluster_state_[GetTaskName(task)]->ExitBarrier(barrier_id); } - - // Special hook for shutdown barrier to disconnect tasks at the barrier. - if (barrier_id == shutdown_barrier_id_) { - if (result.ok()) { - LOG(INFO) << "Shutdown barrier in coordination service has passed."; - } else { - LOG(ERROR) << "Shutdown barrier in coordination service has failed:\n" - << result - << "\nThis suggests that the workers are out of sync. Either " - "at least one worker is too fast in its execution / " - "crashed early or too slow / hanging. Check the logs for " - "an earlier error to identify the root cause."; - } - absl::Status shutdown_error = MakeCoordinationError(absl::InternalError( - absl::StrCat("Shutdown barrier has been passed with status: '", - barrier->result.ToString(), - "', but this task is not at the barrier yet."))); - for (const auto& [task, at_barrier] : barrier->tasks_at_barrier) { - if (at_barrier) { - // Disconnect tasks that reached the barrier. - absl::Status disconnect_status = DisconnectTask(task); - if (!disconnect_status.ok()) { - LOG(ERROR) << disconnect_status; - } - } else { - // Propagate errors to straggling tasks that have not reached the - // barrier. The barrier must have failed if any task did not reach the - // barrier. - ReportServiceErrorToTaskAsync(task, shutdown_error); - } - } - } barrier->tasks_at_barrier.clear(); ongoing_barriers_.erase(barrier_id); - // Note: barrier_id shouldn't be referenced after this line as its lifetime - // may be tied to one of the callbacks. // Propagate results to participating tasks. for (const auto& callback : barrier->done_callbacks) { callback(result); } barrier->done_callbacks.clear(); + if (barrier_id == kClusterRegisterBarrierId && !result.ok()) { + // Stop service if register failed. + LOG(ERROR) + << "Stopping coordination service as cluster registration failed. This " + "may be due to 1) some tasks crashed earlier before connecting, 2) " + "some tasks were never scheduled, or 3) scheduling delays. Consider " + "setting a longer initialization timeout if such delays are " + "expected, the timeout is currently set to: " + << cluster_register_timeout_ << ".\n\nOriginal error: " << result; + Stop(); + return; + } + // Special hook for shutdown barrier to disconnect tasks at the barrier and + // propagate errors to those that have not. + if (barrier_id == shutdown_barrier_id_) { + CompleteShutdownAfterBarrier(result, barrier); + // Note: this may stop the service. Be careful about referencing barrier + // state after this point. + } } void CoordinationServiceStandaloneImpl::SendErrorPollingResponse( @@ -1486,11 +1546,8 @@ void CoordinationServiceStandaloneImpl::SendErrorPollingResponse( CHECK(IsClientPollingForError()) << "`SendErrorPollingResponse` should only be called after agents poll " "errors from the service."; - { - absl::MutexLock l(&state_mu_); - if (error_polling_state_.Responded()) { - return; - } + if (error_polling_state_.Responded()) { + return; } if (!absl::IsCancelled(error)) { VLOG(2) << "An error is encountered. Sending the error as a response to " @@ -1498,16 +1555,13 @@ void CoordinationServiceStandaloneImpl::SendErrorPollingResponse( << error; } std::vector missing_tasks; - { - absl::MutexLock l(&state_mu_); - missing_tasks.reserve(cluster_state_.size()); - for (const auto& [task_name, task_state] : cluster_state_) { - if (!error_polling_state_.IsTaskPolling(task_name)) { - missing_tasks.push_back(task_name); - } + missing_tasks.reserve(cluster_state_.size()); + for (const auto& [task_name, task_state] : cluster_state_) { + if (!error_polling_state_.IsTaskPolling(task_name)) { + missing_tasks.push_back(task_name); } - error_polling_state_.SetError(error); } + error_polling_state_.SetError(error); if (!missing_tasks.empty()) { LOG(ERROR) << absl::StrFormat( "The following %d tasks in the cluster has not sent request to poll " @@ -1561,6 +1615,43 @@ void CoordinationServiceStandaloneImpl::AggregateClusterDevices() { cluster_devices_ = post_aggregate_device_fn_(cluster_devices_); } } + +void CoordinationServiceStandaloneImpl::CompleteShutdownAfterBarrier( + const absl::Status& result, BarrierState* barrier) { + if (result.ok()) { + LOG(INFO) << "Shutdown barrier in coordination service has passed."; + } else { + LOG(ERROR) << "Shutdown barrier in coordination service has failed:\n" + << result + << "\nThis suggests that the workers are out of sync. Either at " + "least one worker (a) crashed early due to program error or " + "scheduler events (e.g. preemption, eviction), (b) was " + "too fast in its execution, or (c) too slow / hanging. Check " + "the logs (both the program and scheduler events) for an " + "earlier error to identify the root cause."; + absl::Status shutdown_error = MakeCoordinationError(absl::InternalError( + absl::StrCat("Shutdown barrier has failed, but this task is not at the " + "barrier yet.\nBarrier result: '", + barrier->result.ToString()))); + // Propagate error to all tasks before disconnecting them. + PropagateError(shutdown_error); + } + // It is possible that PropagateError() stops the service. In this case, the + // task state is forcibly erased and disconnecting the task is not + // necessary. + if (ServiceHasStopped()) { + return; + } + for (const auto& [task, at_barrier] : barrier->tasks_at_barrier) { + if (at_barrier) { + // Disconnect tasks that reached the barrier. + absl::Status disconnect_status = DisconnectTask(task); + if (!disconnect_status.ok()) { + LOG(ERROR) << disconnect_status; + } + } + } +} } // namespace std::unique_ptr EnableCoordinationService( @@ -1592,7 +1683,7 @@ bool CoordinationServiceStandaloneImpl::SendErrorPollingResponseOrStopService( LOG(ERROR) << "Stopping coordination service as there is no " "service-to-client connection, but we encountered an error: " << error; - Stop(/*shut_staleness_thread=*/false); + Stop(); return true; } diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h index 9ef96f1f6b425a..af9eae6a34ec62 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h @@ -30,10 +30,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/time/time.h" #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/macros.h" #include "tsl/platform/status.h" -#include "tsl/protobuf/coordination_config.pb.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tsl { class Env; @@ -122,8 +122,11 @@ class CoordinationServiceInterface { // - InvalidArgument: Unexpected task request. // - Aborted: (1) task is in error state, or (2) task is in connected state // with a different incarnation, indicating that it restarted. + // - DeadlineExceeded: waited too long for straggler tasks to register. virtual absl::Status RegisterTask(const tensorflow::CoordinatedTask& task, uint64_t incarnation) = 0; + virtual void RegisterTaskAsync(const tensorflow::CoordinatedTask& task, + uint64_t incarnation, StatusCallback done) = 0; // Wait for all tasks to be up and running, and register local device // info. The callback is invoked when all tasks are up and registered, or some @@ -159,7 +162,7 @@ class CoordinationServiceInterface { // Set a task in error state permanently. virtual absl::Status ReportTaskError(const tensorflow::CoordinatedTask& task, - absl::Status error) = 0; + const absl::Status& error) = 0; // Get the state and the error status of the tasks. virtual std::vector GetTaskState( @@ -225,7 +228,7 @@ class CoordinationServiceInterface { // list of participating tasks. // - FailedPrecondition: Agent is in UNINITIALIZED or ERROR state. virtual void BarrierAsync( - std::string_view barrier_id, absl::Duration timeout, + std::string barrier_id, absl::Duration timeout, const tensorflow::CoordinatedTask& task, const std::vector& participating_tasks, StatusCallback done) = 0; @@ -236,7 +239,7 @@ class CoordinationServiceInterface { // Possible service errors: // - FailedPrecondition: Barrier has already been passed. virtual absl::Status CancelBarrier( - std::string_view barrier_id, const tensorflow::CoordinatedTask& task) = 0; + std::string barrier_id, const tensorflow::CoordinatedTask& task) = 0; // Gets error from the coordination service. Block until the service // returns an error or the task/service is shutdown. This should never be used diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc index 4290fba754f880..f89ea78b0b7100 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc @@ -45,12 +45,12 @@ limitations under the License. #include "xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h" #include "xla/tsl/framework/cancellation.h" #include "xla/tsl/lib/monitoring/gauge.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/random.h" #include "tsl/platform/status.h" #include "tsl/platform/thread_annotations.h" -#include "tsl/protobuf/coordination_config.pb.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tsl { using tensorflow::CoordinatedTask; @@ -69,7 +69,6 @@ constexpr absl::Duration kDefaultClusterRegisterTimeout = absl::Hours(1); constexpr absl::Duration kDefaultHeartbeatTimeout = absl::Seconds(10); constexpr absl::Duration kDefaultShutdownTimeout = absl::Seconds(10); constexpr char kHeartbeatThread[] = "CoordinationServiceHeartbeatLoop"; -constexpr char kErrorPollingThread[] = "CoordinationServiceErrorPolling"; class CoordinationServiceAgentImpl : public CoordinationServiceAgent { public: @@ -146,10 +145,8 @@ class CoordinationServiceAgentImpl : public CoordinationServiceAgent { absl::Status ShutdownInternal(); // Starts sending heartbeats to the coordination service. void StartSendingHeartbeats(); - // Use long polling to get error from the coordination service. This function - // will block until an error is received or the agent is shutdown or reset. - absl::Status PollForError(); - std::shared_ptr PollForErrorAsync(StatusCallback done); + // Use long polling to get error from the coordination service. + void PollForErrorAsync(StatusCallback done); // Starts polling for error from the coordination service. void StartPollingForError(); @@ -180,7 +177,6 @@ class CoordinationServiceAgentImpl : public CoordinationServiceAgent { absl::CondVar heartbeat_thread_cv_; bool shutting_down_ TF_GUARDED_BY(heartbeat_thread_shutdown_mu_) = false; std::unique_ptr heartbeat_thread_; - std::unique_ptr error_polling_thread_; // Must outlive coordination client which may need to access it within // GetKeyValueAsync() callbacks. CancellationManager cancellation_manager_; @@ -259,7 +255,6 @@ void CoordinationServiceAgentImpl::StopHeartbeat() { void CoordinationServiceAgentImpl::StopErrorPolling() { // Cancel pending error polling RPC call. error_polling_cancellation_manager_->StartCancel(); - error_polling_thread_ = nullptr; } void CoordinationServiceAgentImpl::ResetCancellationManager() { @@ -298,7 +293,7 @@ absl::Status CoordinationServiceAgentImpl::Connect() { call_opts.SetTimeout(absl::ToInt64Milliseconds(deadline - absl::Now())); absl::Notification n; leader_client_->RegisterTaskAsync( - &call_opts, &request, &response, [&](absl::Status s) { + &call_opts, &request, &response, [&](const absl::Status& s) { if (s.ok()) { leader_incarnation_ = response.leader_incarnation(); { @@ -343,11 +338,7 @@ absl::Status CoordinationServiceAgentImpl::Connect() { absl::bind_front(&CoordinationServiceAgentImpl::StartSendingHeartbeats, this))); if (configs_.poll_for_error_from_service_at_startup()) { - // Start a thread to poll for error from the coordination service. - error_polling_thread_.reset(env_->StartThread( - ThreadOptions(), kErrorPollingThread, - absl::bind_front(&CoordinationServiceAgentImpl::StartPollingForError, - this))); + StartPollingForError(); } return absl::OkStatus(); } @@ -371,7 +362,7 @@ void CoordinationServiceAgentImpl::StartSendingHeartbeats() { // transient network failures. VLOG(10) << "HeartbeatRequest: " << request.DebugString(); leader_client_->HeartbeatAsync(&call_opts, &request, &response, - [&](absl::Status s) { + [&](const absl::Status& s) { status = s; n.Notify(); }); @@ -394,9 +385,11 @@ void CoordinationServiceAgentImpl::StartSendingHeartbeats() { } SetError(status); } else if (response.leader_incarnation() != leader_incarnation_) { - SetError(MakeCoordinationError( - absl::AbortedError("Leader incarnation ID mismatch: the " - "coordination leader has restarted."))); + SetError(MakeCoordinationError(absl::AbortedError( + "Leader incarnation ID mismatch: the coordination leader " + "(usually slice 0 task 0) has restarted. Check for earlier " + "errors or any scheduler events (e.g. preemption, eviction) to " + "debug further."))); } // Send next heartbeat after an interval. { @@ -412,44 +405,33 @@ void CoordinationServiceAgentImpl::StartSendingHeartbeats() { } void CoordinationServiceAgentImpl::StartPollingForError() { - LOG(INFO) << "Polling for error from coordination service. This thread will " - "run until an error is encountered or the agent is shutdown."; - absl::Status status = PollForError(); - CHECK(!status.ok()) << "PollForError returned OK status. Should " - "always return an error."; - if (absl::IsCancelled(status)) { - LOG(INFO) << "Cancelling error polling because the service or the agent is " - "shutting down."; - // Return early and there is no need to set error. - return; - } - LOG(ERROR) << "An error is returned from coordination service (this can be " - "an error from this or another task)."; - SetError(status); -} - -absl::Status CoordinationServiceAgentImpl::PollForError() { - absl::Status status = absl::OkStatus(); - absl::Notification n; - PollForErrorAsync([&](absl::Status s) { - status = s; - n.Notify(); + LOG(INFO) << "Polling for error from coordination service. This is a " + "long-running RPC that will return only if an error is " + "encountered or cancelled (e.g. due to shutdown)."; + PollForErrorAsync([&](const absl::Status& status) { + CHECK(!status.ok()) << "PollForError returned OK status. Should " + "always return an error."; + if (absl::IsCancelled(status)) { + LOG(INFO) + << "Cancelling error polling because the service or the agent is " + "shutting down."; + // Return early and there is no need to set error. + return; + } + LOG(ERROR) << "Polled an error from coordination service (this can be " + "an error from this or another task)."; + SetError(status); }); - n.WaitForNotification(); - CHECK(!status.ok()) - << "PollForError returned OK status. Should always return an error."; - return status; } -std::shared_ptr CoordinationServiceAgentImpl::PollForErrorAsync( - StatusCallback done) { +void CoordinationServiceAgentImpl::PollForErrorAsync(StatusCallback done) { auto call_opts = std::make_shared(); absl::Status agent_running_status = ValidateRunningAgent(/*allow_disconnected=*/true); if (!agent_running_status.ok()) { done(agent_running_status); - return call_opts; + return; } auto request = std::make_shared(); auto response = std::make_shared(); @@ -463,7 +445,7 @@ std::shared_ptr CoordinationServiceAgentImpl::PollForErrorAsync( token, [call_opts]() { call_opts->StartCancel(); }); if (already_cancelled) { done(absl::CancelledError("PollForErrorAsync() was cancelled.")); - return call_opts; + return; } leader_client_->PollForErrorAsync( @@ -476,7 +458,6 @@ std::shared_ptr CoordinationServiceAgentImpl::PollForErrorAsync( cm->TryDeregisterCallback(token); done(s); }); - return call_opts; } absl::Status CoordinationServiceAgentImpl::WaitForAllTasks( @@ -493,7 +474,7 @@ absl::Status CoordinationServiceAgentImpl::WaitForAllTasks( absl::Status status; absl::Notification n; leader_client_->WaitForAllTasksAsync(&request, &response, - [&](absl::Status s) { + [&](const absl::Status& s) { status = s; n.Notify(); }); @@ -569,7 +550,7 @@ absl::Status CoordinationServiceAgentImpl::ReportError( absl::Notification n; leader_client_->ReportErrorToServiceAsync( - &request, &response, [&](absl::Status s) { + &request, &response, [&](const absl::Status& s) { VLOG(5) << "ReportErrorToServiceResponse: " << s; if (!s.ok()) { LOG(ERROR) @@ -577,8 +558,10 @@ absl::Status CoordinationServiceAgentImpl::ReportError( "coordination service: " << s << "\nThis is usually caused by an earlier error during " - "execution. Check the logs (this task or the leader) for " - "an earlier error to debug further."; + "execution. Check the logs of (a) this task, (b) the " + "leader (usually slice 0 task 0) and (c) the scheduler " + "(e.g. preemption, eviction) for an earlier error to debug " + "further."; } n.Notify(); }); @@ -612,7 +595,7 @@ absl::Status CoordinationServiceAgentImpl::ShutdownInternal() { absl::Notification n; leader_client_->ShutdownTaskAsync(&call_opts, &request, &response, - [&status, &n](absl::Status s) { + [&status, &n](const absl::Status& s) { status = s; n.Notify(); }); @@ -624,8 +607,10 @@ absl::Status CoordinationServiceAgentImpl::ShutdownInternal() { << "Failed to disconnect from coordination service with status: " << TrimCoordinationErrorMessage(status) << "\nProceeding with agent shutdown anyway. This is usually caused " - "by an earlier error during execution. Check the logs (this task " - "or the leader) for an earlier error to debug further."; + "by an earlier error during execution. Check the logs of (a) this " + "task, (b) the leader (usually slice 0 task 0) and (c) the " + "scheduler (e.g. preemption, eviction) for an earlier error to " + "debug further."; } } @@ -641,8 +626,9 @@ absl::Status CoordinationServiceAgentImpl::ShutdownInternal() { "still shutdown anyway. Agent status: ", status_.ToString(), "\nThis is usually caused by an earlier error during execution. " - "Check the logs (this task or the leader) for an earlier error to " - "debug further."); + "Check the logs of (a) this task, (b) the leader (usually slice 0 " + "task 0) and (c) the scheduler (e.g. preemption, eviction) for an " + "earlier error to debug further."); status = MakeCoordinationError(absl::FailedPreconditionError(status_message)); LOG(ERROR) << status_message; @@ -672,7 +658,7 @@ absl::Status CoordinationServiceAgentImpl::Reset() { absl::Status status; absl::Notification n; leader_client_->ResetTaskAsync(&request, &response, - [&status, &n](absl::Status s) { + [&status, &n](const absl::Status& s) { status = s; n.Notify(); }); @@ -837,10 +823,11 @@ absl::Status CoordinationServiceAgentImpl::InsertKeyValue( absl::Status status; absl::Notification n; - leader_client_->InsertKeyValueAsync(&request, &response, [&](absl::Status s) { - status = s; - n.Notify(); - }); + leader_client_->InsertKeyValueAsync(&request, &response, + [&](const absl::Status& s) { + status = s; + n.Notify(); + }); n.WaitForNotification(); VLOG(3) << "InsertKeyValueResponse: " << status; return status; @@ -856,10 +843,11 @@ absl::Status CoordinationServiceAgentImpl::DeleteKeyValue( absl::Status status; absl::Notification n; - leader_client_->DeleteKeyValueAsync(&request, &response, [&](absl::Status s) { - status = s; - n.Notify(); - }); + leader_client_->DeleteKeyValueAsync(&request, &response, + [&](const absl::Status& s) { + status = s; + n.Notify(); + }); n.WaitForNotification(); VLOG(3) << "DeleteKeyValueResponse " << status; return absl::OkStatus(); @@ -889,7 +877,6 @@ void CoordinationServiceAgentImpl::SetError(const absl::Status& error) { if (state_ == CoordinatedTaskState::TASKSTATE_ERROR) return; absl::Status trimmed_error = TrimCoordinationErrorMessage(error); - LOG(ERROR) << "Coordination agent is set to ERROR: " << trimmed_error; state_ = CoordinatedTaskState::TASKSTATE_ERROR; status_ = trimmed_error; error_fn_(trimmed_error); @@ -906,7 +893,7 @@ absl::Status CoordinationServiceAgentImpl::WaitAtBarrier( const std::vector& tasks) { absl::Status status; absl::Notification n; - WaitAtBarrierAsync(barrier_id, timeout, tasks, [&](absl::Status s) { + WaitAtBarrierAsync(barrier_id, timeout, tasks, [&](const absl::Status& s) { status = s; n.Notify(); }); diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.h b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.h index 3ec188ac251801..6c58501d51e112 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.h +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.h @@ -28,8 +28,8 @@ limitations under the License. #include "absl/time/time.h" #include "xla/tsl/distributed_runtime/call_options.h" #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/status.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tensorflow { class CoordinationServiceConfig; diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc index 1281ea8f78988f..8e6783f0bccc97 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc @@ -30,11 +30,11 @@ limitations under the License. #include "xla/tsl/distributed_runtime/call_options.h" #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/status.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/coordination_config.pb.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.cc index 8fc7631b458197..b3babb837fb9a0 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.cc @@ -41,9 +41,10 @@ absl::Status TrimCoordinationErrorMessage(const absl::Status& s) { // This error is not provided by us, so it's probably an RPC layer error. auto prefix_message = "Failed to send RPC to coordination service. Either the leader task " - "died/restarted unexpectedly or this task is experiencing network " - "issues. Check earlier logs from this task and the " - "leader (usually slice 0 process/task/worker 0) to debug further.\n"; + "was preempted/died/restarted unexpectedly or this task is " + "experiencing network issues. Check earlier logs from 1) this task, 2) " + "the leader (usually slice 0 task 0), and 3) cluster scheduler to debug" + " further.\n"; status_message = absl::StrCat( prefix_message, // Replace the duplicated error message at the start with the prefix. diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h index e1a3cdc06eefe9..07df399979a6fa 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h @@ -18,7 +18,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" -#include "tsl/protobuf/coordination_service.pb.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc index 535f471f0a3fc1..f4917f717a8584 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/match.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tsl { namespace { using ::tensorflow::CoordinatedTask; @@ -143,6 +143,7 @@ TEST(CoordinationServiceErrorUtil, TrimCoordinationErrorMessage_NetworkError) { auto message = trimmed_error.message(); EXPECT_EQ(trimmed_error.code(), error.code()); EXPECT_TRUE(absl::StrContains(message, "Check earlier logs")); + EXPECT_TRUE(absl::StrContains(message, "preempted")); // Message is not duplicated. EXPECT_EQ(message.find("failed to connect"), message.rfind("failed to connect")) diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc index 3ec3290c9507e1..737091b1ca7fc3 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc @@ -32,11 +32,11 @@ limitations under the License. #include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.h" #include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/status.h" #include "tsl/platform/test.h" #include "tsl/platform/threadpool.h" -#include "tsl/protobuf/coordination_config.pb.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc index 200db9df7ee232..b043795731b5b1 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.cc @@ -30,9 +30,9 @@ limitations under the License. #include "xla/tsl/distributed_runtime/coordination/coordination_service.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/status.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tsl { namespace { @@ -66,7 +66,7 @@ void CoordinationServiceRpcHandler::RegisterTaskAsync( const uint64_t incarnation = request->incarnation(); const uint64_t leader_incarnation = service_->GetServiceIncarnation(); response->set_leader_incarnation(leader_incarnation); - done(service_->RegisterTask(task, incarnation)); + service_->RegisterTaskAsync(task, incarnation, done); } void CoordinationServiceRpcHandler::HeartbeatAsync( diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.h b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.h index 537a5d5be3a652..2b9ca2ef9f3d2e 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.h +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_rpc_handler.h @@ -20,9 +20,9 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/status.h" #include "tsl/platform/thread_annotations.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tsl { class CoordinationServiceRpcHandler { diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc index 2fa500109acc17..c233347f7a924f 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc @@ -37,13 +37,13 @@ limitations under the License. #include "xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h" #include "xla/tsl/distributed_runtime/coordination/test_device.pb.h" #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/random.h" #include "tsl/platform/status.h" #include "tsl/platform/test.h" #include "tsl/platform/types.h" -#include "tsl/protobuf/coordination_config.pb.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tsl { namespace { @@ -236,6 +236,7 @@ class CoordinateTwoTasksTest : public ::testing::Test { void EnableCoordinationService( bool has_service_to_client_connection = true, bool enable_shutdown_barrier = false, + bool enable_register_barrier = false, bool set_worker_job_recoverable = false, bool allow_new_incarnation_to_reconnect = false) { CoordinationServiceConfig config = @@ -256,6 +257,11 @@ class CoordinateTwoTasksTest : public ::testing::Test { config.set_shutdown_barrier_timeout_in_ms(kShutdownBarrierTimeout / absl::Milliseconds(1)); } + if (enable_register_barrier) { + config.set_cluster_register_with_barrier(true); + config.set_cluster_register_timeout_in_ms(absl::Seconds(1) / + absl::Milliseconds(1)); + } if (allow_new_incarnation_to_reconnect) { config.set_allow_new_incarnation_to_reconnect(true); } @@ -306,14 +312,16 @@ TEST_F(CoordinateTwoTasksTest, TestStandaloneService) { ASSERT_OK(coord_service_->RecordHeartbeat(task_0_, incarnation_0_)); ASSERT_OK(coord_service_->RecordHeartbeat(task_1_, incarnation_1_)); - EXPECT_TRUE( - absl::IsInvalidArgument(coord_service_->RecordHeartbeat(task_2, 0))); + EXPECT_THAT(coord_service_->RecordHeartbeat(task_2, 0), + StatusIs(absl::StatusCode::kInvalidArgument)); // Sending heartbeat with incarnation mismatch leads to Aborted error. - EXPECT_TRUE(absl::IsAborted(coord_service_->RecordHeartbeat(task_1_, 0))); - EXPECT_TRUE(absl::IsAborted(coord_service_->RecordHeartbeat(task_1_, 0))); + EXPECT_THAT(coord_service_->RecordHeartbeat(task_1_, 0), + StatusIs(absl::StatusCode::kAborted)); + EXPECT_THAT(coord_service_->RecordHeartbeat(task_1_, 0), + StatusIs(absl::StatusCode::kAborted)); // Error is propagated to other tasks. - EXPECT_TRUE(absl::IsAborted(client_0_.GetStatus())); + EXPECT_THAT(client_0_.GetStatus(), StatusIs(absl::StatusCode::kAborted)); } TEST(CoordinationServiceTest, TestCoordinatedJobs) { @@ -379,8 +387,8 @@ TEST(CoordinationServiceTest, TestCoordinatedJobs) { // Registering the evaluator task is unexpected absl::Status status = coord_service->RegisterTask(evaluator, /*incarnation=*/0); - EXPECT_TRUE(absl::IsInvalidArgument(status)) << status; - EXPECT_TRUE(!status.message().empty()); + + EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument)); } // RegisterTask() may succeed in the service, but the agent response times out. @@ -426,12 +434,9 @@ TEST(CoordinationServiceTest, const absl::Status status = coord_service->RegisterTask(task_0, /*incarnation=*/1); - EXPECT_TRUE(absl::IsAborted(status)) << status; - EXPECT_TRUE(!status.message().empty()); + EXPECT_THAT(status, StatusIs(absl::StatusCode::kAborted)); } -// TODO(b/195990880): Remove this test once server-client connection is removed. -// This test passes only when there is a single task. TEST(CoordinationServiceTest, RegisterTask_AlreadyInError_Fails) { CoordinationServiceConfig config = GetCoordinationServiceConfig(/*num_tasks=*/1); @@ -448,12 +453,14 @@ TEST(CoordinationServiceTest, RegisterTask_AlreadyInError_Fails) { ASSERT_OK(coord_service->ReportTaskError(task_0, absl::InternalError("test_error"))); - // Registration should fail since task already registered previously. + // Registration should fail. const absl::Status status = coord_service->RegisterTask(task_0, /*incarnation=*/0); - EXPECT_TRUE(absl::IsAborted(status)) << status; - EXPECT_TRUE(!status.message().empty()); + // Impl note: the error triggers the service to stop, which fails new + // requests. It's okay to change the error code during development as long as + // it fails. + EXPECT_THAT(status, StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinateTwoTasksTest, TestTaskHeartbeatTimeout) { @@ -464,10 +471,10 @@ TEST_F(CoordinateTwoTasksTest, TestTaskHeartbeatTimeout) { // No heartbeat for a while, leader considers the task as stale. Env::Default()->SleepForMicroseconds( absl::ToInt64Microseconds(2 * kHeartbeatTimeout)); - EXPECT_TRUE(absl::IsUnavailable( - coord_service_->RecordHeartbeat(task_0_, incarnation_0_))); - EXPECT_TRUE(absl::IsUnavailable( - coord_service_->RecordHeartbeat(task_1_, incarnation_1_))); + EXPECT_THAT(coord_service_->RecordHeartbeat(task_0_, incarnation_0_), + StatusIs(absl::StatusCode::kUnavailable)); + EXPECT_THAT(coord_service_->RecordHeartbeat(task_1_, incarnation_1_), + StatusIs(absl::StatusCode::kUnavailable)); } TEST_F(CoordinateTwoTasksTest, @@ -505,10 +512,10 @@ TEST_F(CoordinateTwoTasksTest, absl::ToInt64Microseconds(2 * kHeartbeatTimeout)); // Unexpected heartbeat from unregistered tasks since service state has been // reset. - EXPECT_TRUE(absl::IsInternal( - coord_service_->RecordHeartbeat(task_0_, incarnation_0_))); - EXPECT_TRUE(absl::IsInternal( - coord_service_->RecordHeartbeat(task_1_, incarnation_1_))); + EXPECT_THAT(coord_service_->RecordHeartbeat(task_0_, incarnation_0_), + StatusIs(absl::StatusCode::kInternal)); + EXPECT_THAT(coord_service_->RecordHeartbeat(task_1_, incarnation_1_), + StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinateTwoTasksTest, @@ -518,15 +525,14 @@ TEST_F(CoordinateTwoTasksTest, ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); // Use notifications to guarantee the ordering of operations across threads. absl::Notification n0, n1; + absl::Status s0, s1; - // The heartbeat error below should be propagated to all tasks. - absl::StatusCode expected_error_code = absl::StatusCode::kUnavailable; coord_service_->PollForErrorAsync(task_0_, [&](const absl::Status& status) { - EXPECT_THAT(status, StatusIs(expected_error_code)); + s0 = status; n0.Notify(); }); coord_service_->PollForErrorAsync(task_1_, [&](const absl::Status& status) { - EXPECT_THAT(status, StatusIs(expected_error_code)); + s1 = status; n1.Notify(); }); @@ -534,10 +540,13 @@ TEST_F(CoordinateTwoTasksTest, // the error to the tasks. Env::Default()->SleepForMicroseconds( absl::ToInt64Microseconds(2 * kHeartbeatTimeout)); - // Make sure the StatusCallbacks are called. n0.WaitForNotification(); n1.WaitForNotification(); + + // Heartbeat errors are propagated to everyone. + EXPECT_THAT(s0, StatusIs(absl::StatusCode::kUnavailable)); + EXPECT_THAT(s1, StatusIs(absl::StatusCode::kUnavailable)); } TEST_F(CoordinateTwoTasksTest, @@ -546,16 +555,15 @@ TEST_F(CoordinateTwoTasksTest, ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); // Use notifications to guarantee the ordering of operations across threads. + absl::Status s0, s1; absl::Notification n0, n1; - // The heartbeat error from `task_1_` below should be propagated to all tasks. - absl::StatusCode expected_error_code = absl::StatusCode::kUnavailable; coord_service_->PollForErrorAsync(task_0_, [&](const absl::Status& status) { - EXPECT_THAT(status, StatusIs(expected_error_code, HasSubstr("task:1"))); + s0 = status; n0.Notify(); }); coord_service_->PollForErrorAsync(task_1_, [&](const absl::Status& status) { - EXPECT_THAT(status, StatusIs(expected_error_code, HasSubstr("task:1"))); + s1 = status; n1.Notify(); }); @@ -569,10 +577,15 @@ TEST_F(CoordinateTwoTasksTest, Env::Default()->SleepForMicroseconds(sleeping_time); TF_EXPECT_OK(coord_service_->RecordHeartbeat(task_0_, incarnation_0_)); Env::Default()->SleepForMicroseconds(sleeping_time); - // Make sure the StatusCallbacks are called. n0.WaitForNotification(); n1.WaitForNotification(); + + // The heartbeat error from `task_1_` below should be propagated to all tasks. + EXPECT_THAT(s0, + StatusIs(absl::StatusCode::kUnavailable, HasSubstr("task:1"))); + EXPECT_THAT(s1, + StatusIs(absl::StatusCode::kUnavailable, HasSubstr("task:1"))); } TEST_F(CoordinateTwoTasksTest, ReportedErrorCanPropagateThroughErrorPolling) { @@ -601,16 +614,21 @@ TEST_F(CoordinateTwoTasksTest, TestTaskRestart) { // Simulate task restart scenario: trying to register to cluster again. absl::Status s = coord_service_->RegisterTask(task_1_, /*incarnation=*/random::New64()); - EXPECT_TRUE(absl::IsAborted(s)) << s; + + EXPECT_THAT(s, StatusIs(absl::StatusCode::kAborted)); // Aborted error is also propagated to other tasks in cluster. - EXPECT_TRUE(absl::IsAborted(client_0_.GetStatus())) << client_0_.GetStatus(); + EXPECT_THAT(client_0_.GetStatus(), StatusIs(absl::StatusCode::kAborted)); } TEST_F(CoordinateTwoTasksTest, InsertKeyValue_Duplicate_Fail) { EnableCoordinationService(); ASSERT_OK(coord_service_->InsertKeyValue("key0", "original_value")); - EXPECT_TRUE(absl::IsAlreadyExists( - coord_service_->InsertKeyValue("key0", "never_added"))); + + // Inserting the same key again should fail. + EXPECT_THAT(coord_service_->InsertKeyValue("key0", "never_added"), + StatusIs(absl::StatusCode::kAlreadyExists)); + + // The original value should still be set. auto result = coord_service_->TryGetKeyValue("key0"); TF_EXPECT_OK(result.status()); EXPECT_EQ(result.value(), "original_value"); @@ -701,7 +719,7 @@ TEST(CoordinationServiceTest, TryGetKeyValue) { // Try to get nonexistent key. absl::StatusOr result = coord_service->TryGetKeyValue("test_key"); - EXPECT_TRUE(absl::IsNotFound(result.status())); + EXPECT_THAT(result.status(), StatusIs(absl::StatusCode::kNotFound)); // Insert key value. ASSERT_OK(coord_service->InsertKeyValue("test_key", "test_value")); @@ -711,7 +729,7 @@ TEST(CoordinationServiceTest, TryGetKeyValue) { // Delete Key, and try to get the key again. ASSERT_OK(coord_service->DeleteKeyValue("test_key")); result = coord_service->TryGetKeyValue("test_key"); - EXPECT_TRUE(absl::IsNotFound(result.status())); + EXPECT_THAT(result.status(), StatusIs(absl::StatusCode::kNotFound)); } TEST_F(CoordinateTwoTasksTest, GetKeyValueDir_SingleValueInDirectory) { @@ -1074,8 +1092,8 @@ TEST_F(CoordinationBarrierTest, BarrierWithMismatchedTasks) { /*participating_tasks=*/{GetTask(1), GetTask(2)}, [&barrier_status_1](absl::Status s) { barrier_status_1 = s; }); - EXPECT_TRUE(absl::IsInvalidArgument(barrier_status_0)); - EXPECT_TRUE(absl::IsInvalidArgument(barrier_status_1)); + EXPECT_THAT(barrier_status_0, StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(barrier_status_1, StatusIs(absl::StatusCode::kInvalidArgument)); } TEST_F(CoordinationBarrierTest, BarrierByNonParticipatingTask) { @@ -1097,8 +1115,8 @@ TEST_F(CoordinationBarrierTest, BarrierByNonParticipatingTask) { [&barrier_status_1](absl::Status s) { barrier_status_1 = s; }); // Barrier should fail for all tasks with the unexpected call. - EXPECT_TRUE(absl::IsInvalidArgument(barrier_status_0)); - EXPECT_TRUE(absl::IsInvalidArgument(barrier_status_1)); + EXPECT_THAT(barrier_status_0, StatusIs(absl::StatusCode::kInvalidArgument)); + EXPECT_THAT(barrier_status_1, StatusIs(absl::StatusCode::kInvalidArgument)); } TEST_F(CoordinationBarrierTest, BarrierByNonParticipatingTaskThreeTasks) { @@ -1139,7 +1157,7 @@ TEST_F(CoordinationBarrierTest, BarrierByNonParticipatingTaskThreeTasks) { [&barrier_status_2](absl::Status s) { barrier_status_2 = s; }); // Barrier should fail for task 2 which is not participating in the barrier. - EXPECT_TRUE(absl::IsInvalidArgument(barrier_status_2)); + EXPECT_THAT(barrier_status_2, StatusIs(absl::StatusCode::kInvalidArgument)); // Other clients would need to check the barrier key to detect the error. } @@ -1163,7 +1181,7 @@ TEST_F(CoordinationBarrierTest, BarrierByNonClusterTask) { n_0.WaitForNotification(); // Barrier should fail with the unexpected participating task argument. - EXPECT_TRUE(absl::IsInvalidArgument(barrier_status_0)); + EXPECT_THAT(barrier_status_0, StatusIs(absl::StatusCode::kInvalidArgument)); } TEST_F(CoordinationBarrierTest, BarrierTimeout) { @@ -1191,7 +1209,7 @@ TEST_F(CoordinationBarrierTest, BarrierTimeout) { // All barrier calls should fail with the same error. EXPECT_EQ(barrier_status_0, barrier_status_1); - EXPECT_TRUE(absl::IsDeadlineExceeded(barrier_status_0)); + EXPECT_THAT(barrier_status_0, StatusIs(absl::StatusCode::kDeadlineExceeded)); EXPECT_FALSE( absl::StrContains(barrier_status_0.message(), GetTaskName(GetTask(0)))); EXPECT_TRUE( @@ -1227,8 +1245,8 @@ TEST_F(CoordinationBarrierTest, BarrierReturnsPreviousError) { /*participating_tasks=*/{}, [&barrier_status_1](absl::Status s) { barrier_status_1 = s; }); - EXPECT_TRUE(absl::IsInternal(barrier_status_0)); - EXPECT_TRUE(absl::IsInternal(barrier_status_1)); + EXPECT_THAT(barrier_status_0, StatusIs(absl::StatusCode::kInternal)); + EXPECT_THAT(barrier_status_1, StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinationBarrierTest, BarrierCancelled) { @@ -1243,7 +1261,7 @@ TEST_F(CoordinationBarrierTest, BarrierCancelled) { absl::Status cancelled_status = GetCoordinationService()->CancelBarrier(barrier_id, GetTask(0)); - EXPECT_TRUE(absl::IsCancelled(barrier_status)); + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kCancelled)); TF_EXPECT_OK(cancelled_status); } @@ -1260,7 +1278,7 @@ TEST_F(CoordinationBarrierTest, CancelNonExistentBarrier_FutureBarrierFails) { /*participating_tasks=*/{}, [&barrier_status](absl::Status s) { barrier_status = s; }); - EXPECT_TRUE(absl::IsCancelled(barrier_status)) << barrier_status; + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kCancelled)); } TEST_F(CoordinationBarrierTest, CancelAfterBarrierHasPassed) { @@ -1286,7 +1304,8 @@ TEST_F(CoordinationBarrierTest, CancelAfterBarrierHasPassed) { absl::Status cancelled_status = GetCoordinationService()->CancelBarrier(barrier_id, GetTask(0)); - EXPECT_TRUE(absl::IsFailedPrecondition(cancelled_status)); + EXPECT_THAT(cancelled_status, + StatusIs(absl::StatusCode::kFailedPrecondition)); TF_EXPECT_OK(barrier_status_0); TF_EXPECT_OK(barrier_status_1); TF_EXPECT_OK(barrier_status_2); @@ -1354,7 +1373,7 @@ TEST_F(CoordinationBarrierTest, BarrierFailsIfTaskIsAlreadyInError) { /*participating_tasks=*/{}, [&barrier_status](absl::Status s) { barrier_status = s; }); - EXPECT_TRUE(absl::IsInternal(barrier_status)); + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinationBarrierTest, BarrierFailsUponTaskError) { @@ -1373,7 +1392,7 @@ TEST_F(CoordinationBarrierTest, BarrierFailsUponTaskError) { GetTask(0), absl::InternalError("test_error"))); n0.WaitForNotification(); - EXPECT_TRUE(absl::IsInternal(barrier_status)); + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinationBarrierTest, @@ -1440,8 +1459,8 @@ TEST_F(CoordinateTwoTasksTest, Reset_HeartbeatsAreAcceptedForAGracePeriod) { // period. Env::Default()->SleepForMicroseconds( absl::ToInt64Microseconds(3 * kHeartbeatTimeout)); - EXPECT_TRUE(absl::IsInvalidArgument( - coord_service_->RecordHeartbeat(task_0_, incarnation_0_))); + EXPECT_THAT(coord_service_->RecordHeartbeat(task_0_, incarnation_0_), + StatusIs(absl::StatusCode::kInvalidArgument)); } TEST_F(CoordinateTwoTasksTest, Reset_FailsOngoingBarrier) { @@ -1462,7 +1481,7 @@ TEST_F(CoordinateTwoTasksTest, Reset_FailsOngoingBarrier) { // Ongoing barrier should fail with error after shutdown. EXPECT_TRUE(barrier_n.HasBeenNotified()); - EXPECT_TRUE(absl::IsInternal(barrier_status)) << barrier_status; + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinateTwoTasksTest, Shutdown_HeartbeatsAreAcceptedForAGracePeriod) { @@ -1484,8 +1503,8 @@ TEST_F(CoordinateTwoTasksTest, Shutdown_HeartbeatsAreAcceptedForAGracePeriod) { // period. Env::Default()->SleepForMicroseconds( absl::ToInt64Microseconds(3 * kHeartbeatTimeout)); - EXPECT_TRUE(absl::IsInvalidArgument( - coord_service_->RecordHeartbeat(task_0_, incarnation_0_))); + EXPECT_THAT(coord_service_->RecordHeartbeat(task_0_, incarnation_0_), + StatusIs(absl::StatusCode::kInvalidArgument)); } TEST_F(CoordinateTwoTasksTest, Shutdown_FailsOngoingBarrier) { @@ -1511,7 +1530,7 @@ TEST_F(CoordinateTwoTasksTest, Shutdown_FailsOngoingBarrier) { // Ongoing barrier should fail with error after shutdown. EXPECT_TRUE(barrier_n.HasBeenNotified()); - EXPECT_TRUE(absl::IsInternal(barrier_status)) << barrier_status; + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinateTwoTasksTest, ShutdownWithBarrier_BarrierSucceeds) { @@ -1554,7 +1573,7 @@ TEST_F(CoordinateTwoTasksTest, // Block until barrier times out. n.WaitForNotification(); - EXPECT_TRUE(absl::IsDeadlineExceeded(barrier_status)) << barrier_status; + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kDeadlineExceeded)); // Confirm that task_0_ has disconnected. // Note: this should not happen in prod where RegisterTask() is called after // Shutdown(), which is prevented by agent-side logic. @@ -1562,7 +1581,7 @@ TEST_F(CoordinateTwoTasksTest, // Other task is alerted that shutdown has been initiated without it. absl::Status other_task_status = client_1_.GetStatus(); - EXPECT_TRUE(absl::IsInternal(other_task_status)) << other_task_status; + EXPECT_THAT(other_task_status, StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinateTwoTasksTest, @@ -1585,7 +1604,7 @@ TEST_F(CoordinateTwoTasksTest, Env::Default()->SleepForMicroseconds( absl::ToInt64Microseconds(absl::Seconds(1))); - EXPECT_TRUE(absl::IsDeadlineExceeded(barrier_status)) << barrier_status; + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kDeadlineExceeded)); // Service stops because no service-to-client connection is available for // error propagation. @@ -1593,7 +1612,7 @@ TEST_F(CoordinateTwoTasksTest, // service has stopped yet, which should fail. absl::Status s = coord_service_->RecordHeartbeat(task_1_, incarnation_1_); - EXPECT_TRUE(absl::IsInternal(s)) << s; + EXPECT_THAT(s, StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinateTwoTasksTest, BarrierFailsIfServiceHasStopped) { @@ -1615,7 +1634,7 @@ TEST_F(CoordinateTwoTasksTest, BarrierFailsIfServiceHasStopped) { }); n0.WaitForNotification(); - EXPECT_TRUE(absl::IsInternal(barrier_status)) << barrier_status; + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinateTwoTasksTest, BarrierFailsAfterErrorPollingResponse) { @@ -1624,15 +1643,14 @@ TEST_F(CoordinateTwoTasksTest, BarrierFailsAfterErrorPollingResponse) { ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); // Use notifications to guarantee the ordering of operations across threads. absl::Notification n0, n1; + absl::Status s0, s1; - // The heartbeat error below should be propagated to all tasks. - absl::StatusCode expected_error_code = absl::StatusCode::kUnavailable; coord_service_->PollForErrorAsync(task_0_, [&](const absl::Status& status) { - EXPECT_THAT(status, StatusIs(expected_error_code)); + s0 = status; n0.Notify(); }); coord_service_->PollForErrorAsync(task_1_, [&](const absl::Status& status) { - EXPECT_THAT(status, StatusIs(expected_error_code)); + s1 = status; n1.Notify(); }); @@ -1644,6 +1662,9 @@ TEST_F(CoordinateTwoTasksTest, BarrierFailsAfterErrorPollingResponse) { // Make sure the StatusCallbacks are called before the barrier is called. n0.WaitForNotification(); n1.WaitForNotification(); + // The heartbeat error should be propagated to all tasks. + EXPECT_THAT(s0, StatusIs(absl::StatusCode::kUnavailable)); + EXPECT_THAT(s1, StatusIs(absl::StatusCode::kUnavailable)); absl::Notification n_barrier; absl::Status barrier_status; @@ -1655,7 +1676,7 @@ TEST_F(CoordinateTwoTasksTest, BarrierFailsAfterErrorPollingResponse) { }); n_barrier.WaitForNotification(); - EXPECT_TRUE(absl::IsInternal(barrier_status)) << barrier_status; + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinateTwoTasksTest, BarrierWithSubsetFailsIfServiceHasStopped) { @@ -1680,7 +1701,7 @@ TEST_F(CoordinateTwoTasksTest, BarrierWithSubsetFailsIfServiceHasStopped) { }); n0.WaitForNotification(); - EXPECT_TRUE(absl::IsInternal(barrier_status)) << barrier_status; + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinateTwoTasksTest, @@ -1708,12 +1729,13 @@ TEST_F(CoordinateTwoTasksTest, }); n0.WaitForNotification(); - EXPECT_TRUE(absl::IsInternal(barrier_status)) << barrier_status; + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinateTwoTasksTest, UnrecoverableTaskPropagatesError) { EnableCoordinationService(/*has_service_to_client_connection=*/true, /*enable_shutdown_barrier=*/false, + /*enable_register_barrier=*/false, /*set_worker_job_recoverable=*/false); TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); @@ -1722,15 +1744,16 @@ TEST_F(CoordinateTwoTasksTest, UnrecoverableTaskPropagatesError) { ASSERT_OK(coord_service_->ReportTaskError(task_0_, absl::InternalError("test_error"))); - EXPECT_TRUE(absl::IsInternal( - coord_service_->RecordHeartbeat(task_0_, incarnation_0_))); + EXPECT_THAT(coord_service_->RecordHeartbeat(task_0_, incarnation_0_), + StatusIs(absl::StatusCode::kInternal)); // For unrecoverable task, error propagates to all connected tasks. - EXPECT_TRUE(absl::IsInternal(client_1_.GetStatus())); + EXPECT_THAT(client_1_.GetStatus(), StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinateTwoTasksTest, RecoverableTaskWillNotPropagateError) { EnableCoordinationService(/*has_service_to_client_connection=*/true, /*enable_shutdown_barrier=*/false, + /*enable_register_barrier=*/false, /*set_worker_job_recoverable=*/true); TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); @@ -1739,8 +1762,8 @@ TEST_F(CoordinateTwoTasksTest, RecoverableTaskWillNotPropagateError) { ASSERT_OK(coord_service_->ReportTaskError(task_0_, absl::InternalError("test_error"))); - EXPECT_TRUE(absl::IsInternal( - coord_service_->RecordHeartbeat(task_0_, incarnation_0_))); + EXPECT_THAT(coord_service_->RecordHeartbeat(task_0_, incarnation_0_), + StatusIs(absl::StatusCode::kInternal)); // Since no error propagation for recoverable tasks, other tasks should work // as normal. TF_EXPECT_OK(client_1_.GetStatus()); @@ -1750,6 +1773,7 @@ TEST_F(CoordinateTwoTasksTest, RecoverableTaskReportErrorResetAndRegisterAgain) { EnableCoordinationService(/*has_service_to_client_connection=*/true, /*enable_shutdown_barrier=*/false, + /*enable_register_barrier=*/false, /*set_worker_job_recoverable=*/true); TF_EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); @@ -1758,8 +1782,8 @@ TEST_F(CoordinateTwoTasksTest, ASSERT_OK(coord_service_->ReportTaskError(task_0_, absl::InternalError("test_error"))); - EXPECT_TRUE(absl::IsInternal( - coord_service_->RecordHeartbeat(task_0_, incarnation_0_))); + EXPECT_THAT(coord_service_->RecordHeartbeat(task_0_, incarnation_0_), + StatusIs(absl::StatusCode::kInternal)); // Since no error propagation for recoverable tasks, other tasks should work // as normal. TF_EXPECT_OK(client_1_.GetStatus()); @@ -1774,6 +1798,7 @@ TEST_F(CoordinateTwoTasksTest, TEST_F(CoordinateTwoTasksTest, UnavailableTaskCanReconnect) { EnableCoordinationService(/*has_service_to_client_connection=*/true, /*enable_shutdown_barrier=*/false, + /*enable_register_barrier=*/false, /*set_worker_job_recoverable=*/false, /*allow_new_incarnation_to_reconnect=*/true); @@ -1890,8 +1915,10 @@ TEST_F(CoordinateTwoTasksTest, DoNotAllowPollForErrorWhenInErrorState) { coord_service_->PollForErrorAsync( task_0_, [&](const absl::Status& status) { s = status; }); - EXPECT_THAT(s, StatusIs(absl::StatusCode::kFailedPrecondition, - HasSubstr("test_error"))); + // Impl note: the error triggers the service to stop, which fails new + // requests. It is okay to change the error code during development as long as + // it fails. + EXPECT_THAT(s, StatusIs(absl::StatusCode::kInternal)); } TEST_F(CoordinateTwoTasksTest, DoNotAllowPollForErrorIfServiceHasStopped) { @@ -1954,4 +1981,96 @@ TEST_F(CoordinateTwoTasksTest, LatePollingTaskCanGetError) { HasSubstr("test_error_from_task_0")))); } +TEST_F(CoordinateTwoTasksTest, + RegisterWithBarrier_OldHeartbeat_ServiceNotStopped) { + EnableCoordinationService(/*has_service_to_client_connection=*/false, + /*enable_shutdown_barrier=*/false, + /*enable_register_barrier=*/true); + // Service restarted. + // Old task 0 sends an unexpected heartbeat, which should fail. + // Crucially, this should not stop the service, so future API calls should not + // trigger an internal error (which occurs if service has shut down). + ASSERT_THAT(coord_service_->RecordHeartbeat(task_0_, incarnation_0_ - 1), + StatusIs(absl::StatusCode::kInvalidArgument)); + absl::Status task0_status = absl::InternalError("uninitialized_status"); + // Task 0 registers first. + coord_service_->RegisterTaskAsync(task_0_, incarnation_0_, + [](const absl::Status& s) {}); + // Task 0 restarts with a new incarnation, and registers again. + // This should be allowed since all tasks have not joined the cluster yet. + coord_service_->RegisterTaskAsync( + task_0_, /*incarnation=*/incarnation_0_ + 1, + [&](const absl::Status& s) { task0_status = s; }); + // Now all tasks will register in a synchronized fashion due to the barrier. + EXPECT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); + EXPECT_OK(task0_status); +} + +TEST_F(CoordinateTwoTasksTest, + RegisterWithBarrier_RestartBeforeBarrier_Succeeds) { + EnableCoordinationService(/*has_service_to_client_connection=*/false, + /*enable_shutdown_barrier=*/false, + /*enable_register_barrier=*/true); + absl::Status task0_status = absl::InternalError("uninitialized_status"); + absl::Status restarted_task0_status = + absl::InternalError("uninitialized_status"); + // Task 0 registers first. + coord_service_->RegisterTaskAsync( + task_0_, incarnation_0_, + [&](const absl::Status& s) { task0_status = s; }); + // Task 0 restarts with a new incarnation, and registers again. + // This should be allowed since all tasks have not joined the cluster yet. + coord_service_->RegisterTaskAsync( + task_0_, /*incarnation=*/incarnation_0_ + 1, + [&](const absl::Status& s) { restarted_task0_status = s; }); + // Now all tasks will register in a synchronized fashion due to the barrier. + ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); + ASSERT_THAT(task0_status, StatusIs(absl::StatusCode::kAlreadyExists)); + ASSERT_OK(restarted_task0_status); + // Task 0 joins again with the same incarnation. + // This is okay, it didn't restart, probably sent RPC twice due to network + // retries. + EXPECT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_ + 1)); +} + +TEST_F(CoordinateTwoTasksTest, RegisterWithBarrier_RestartAfterBarrier_Fails) { + EnableCoordinationService(/*has_service_to_client_connection=*/false, + /*enable_shutdown_barrier=*/false, + /*enable_register_barrier=*/true); + absl::Status task0_status = absl::InternalError("uninitialized_status"); + // Task 0 registers first. + coord_service_->RegisterTaskAsync( + task_0_, incarnation_0_, + [&](const absl::Status& s) { task0_status = s; }); + // Now all tasks will register in a synchronized fashion due to the barrier. + ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); + ASSERT_OK(task0_status); + + // Task 0 restarts again with a new incarnation. + // This should fail since this happens after the initial register barrier + // (i.e. all tasks already acked once). + ASSERT_THAT(coord_service_->RegisterTask(task_0_, incarnation_0_ + 2), + StatusIs(absl::StatusCode::kAborted)); + // Service should have stopped due to the previous registration failure. + // Check for internal error code. + absl::Notification n; + absl::Status barrier_status; + coord_service_->BarrierAsync("barrier_id", absl::Seconds(10), task_0_, {}, + [&](const absl::Status& s) { + n.Notify(); + barrier_status = s; + }); + n.WaitForNotification(); + EXPECT_THAT(barrier_status, StatusIs(absl::StatusCode::kInternal)); +} + +TEST_F(CoordinateTwoTasksTest, RegisterWithBarrier_Timeout) { + EnableCoordinationService(/*has_service_to_client_connection=*/false, + /*enable_shutdown_barrier=*/false, + /*enable_register_barrier=*/true); + // Task 0 joins without task 1. Times out eventually as this function is + // blocking. + EXPECT_THAT(coord_service_->RegisterTask(task_0_, incarnation_0_), + StatusIs(absl::StatusCode::kDeadlineExceeded)); +} } // namespace tsl diff --git a/third_party/xla/xla/tsl/distributed_runtime/preemption/BUILD b/third_party/xla/xla/tsl/distributed_runtime/preemption/BUILD index b533f7b4bd88ef..1789c00cdd4316 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/preemption/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/preemption/BUILD @@ -55,6 +55,7 @@ cc_library( "//xla/tsl/distributed_runtime:call_options", "//xla/tsl/distributed_runtime/coordination:coordination_service_agent", "//xla/tsl/lib/monitoring:gauge", + "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/memory", @@ -65,7 +66,6 @@ cc_library( "@com_google_absl//absl/time", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ], ) @@ -83,6 +83,9 @@ tsl_cc_test( "//xla/tsl/distributed_runtime/rpc:async_service_interface", "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_client", "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_service_impl", + "//xla/tsl/protobuf:coordination_config_proto_cc_impl", + "//xla/tsl/protobuf:coordination_service_proto_cc_impl", + "//xla/tsl/protobuf:distributed_runtime_payloads_proto_cc_impl", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", @@ -92,8 +95,5 @@ tsl_cc_test( "@local_tsl//tsl/platform:env_impl", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/protobuf:coordination_config_proto_cc_impl", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc_impl", - "@local_tsl//tsl/protobuf:distributed_runtime_payloads_proto_cc_impl", ] + tsl_grpc_cc_dependencies(), ) diff --git a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc index ee85f70e04b771..c6e41a9f030f62 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc @@ -37,9 +37,9 @@ limitations under the License. #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" #include "xla/tsl/distributed_runtime/preemption/preemption_notifier.h" #include "xla/tsl/lib/monitoring/gauge.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager_test.cc b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager_test.cc index e02c7b03b7f917..616b8ccd5fcf99 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager_test.cc @@ -34,10 +34,10 @@ limitations under the License. #include "xla/tsl/distributed_runtime/rpc/async_service_interface.h" #include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.h" #include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h" +#include "xla/tsl/protobuf/coordination_config.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" #include "tsl/platform/threadpool.h" -#include "tsl/protobuf/coordination_config.pb.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD b/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD index 6fe7b4064235f8..985a709fe09aeb 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD @@ -36,13 +36,13 @@ cc_library( srcs = ["grpc_util.cc"], hdrs = ["grpc_util.h"], deps = [ + "//xla/tsl/protobuf:distributed_runtime_payloads_proto_cc", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:cord", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:stringpiece", "@local_tsl//tsl/platform:stringprintf", - "@local_tsl//tsl/protobuf:distributed_runtime_payloads_proto_cc", ] + tsl_grpc_cc_dependencies(), ) @@ -56,12 +56,12 @@ tsl_cc_test( deps = [ ":grpc_util", ":test_request_proto_cc_impl", + "//xla/tsl/protobuf:distributed_runtime_payloads_proto_cc_impl", "@local_tsl//tsl/platform:env_impl", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/protobuf:distributed_runtime_payloads_proto_cc_impl", ] + tsl_grpc_cc_dependencies(), ) @@ -83,9 +83,10 @@ cc_library( deps = [ ":grpc_channel_common", ":grpc_util", + "//xla/tsl/lib/gtl:map_util", + "//xla/tsl/protobuf:rpc_options_proto_cc", "//xla/tsl/util:device_name_utils", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/gtl:map_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:macros", @@ -96,7 +97,6 @@ cc_library( "@local_tsl//tsl/platform:strcat", "@local_tsl//tsl/platform:thread_annotations", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/protobuf:rpc_options_proto_cc", ] + tsl_grpc_cc_dependencies(), ) @@ -109,12 +109,12 @@ tsl_cc_test( deps = [ ":grpc_channel", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/protobuf:rpc_options_proto_cc_impl", "//xla/tsl/util:device_name_utils", "@local_tsl//tsl/platform:env_impl", "@local_tsl//tsl/platform:strcat", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/protobuf:rpc_options_proto_cc_impl", ], ) diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/BUILD b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/BUILD index 838c2cbdf5ab5c..c14352c8375163 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/BUILD @@ -21,13 +21,13 @@ cc_library( "//xla/tsl/distributed_runtime/rpc:grpc_client_cq_tag", "//xla/tsl/distributed_runtime/rpc:grpc_state", "//xla/tsl/distributed_runtime/rpc:grpc_util", + "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ] + tsl_grpc_cc_dependencies(), ) @@ -42,12 +42,12 @@ cc_library( "//xla/tsl/distributed_runtime/rpc:async_service_interface", "//xla/tsl/distributed_runtime/rpc:grpc_call", "//xla/tsl/distributed_runtime/rpc:grpc_util", + "//xla/tsl/protobuf:coordination_service_cc_grpc_proto", + "//xla/tsl/protobuf:coordination_service_proto_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/protobuf:coordination_service_cc_grpc_proto", - "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ] + tsl_grpc_cc_dependencies(), ) diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc index 6bd7885d2cafb7..639e2f2e10ec25 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc @@ -34,10 +34,10 @@ limitations under the License. #include "xla/tsl/distributed_runtime/rpc/grpc_client_cq_tag.h" #include "xla/tsl/distributed_runtime/rpc/grpc_state.h" #include "xla/tsl/distributed_runtime/rpc/grpc_util.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/status.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tsl { namespace { diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h index 4b6c74e0b870af..0fdaafc9f579bb 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h @@ -30,9 +30,9 @@ limitations under the License. #include "xla/tsl/distributed_runtime/rpc/async_service_interface.h" #include "xla/tsl/distributed_runtime/rpc/grpc_call.h" #include "xla/tsl/distributed_runtime/rpc/grpc_util.h" +#include "xla/tsl/protobuf/coordination_service.grpc.pb.h" +#include "xla/tsl/protobuf/coordination_service.pb.h" #include "tsl/platform/threadpool.h" -#include "tsl/protobuf/coordination_service.grpc.pb.h" -#include "tsl/protobuf/coordination_service.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.cc b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.cc index 3e80a68c3bab02..2ebb8cc7e9499b 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.cc @@ -26,8 +26,9 @@ limitations under the License. #include "absl/strings/str_split.h" #include "grpcpp/create_channel.h" #include "xla/tsl/distributed_runtime/rpc/grpc_channel_common.h" +#include "xla/tsl/lib/gtl/map_util.h" +#include "xla/tsl/protobuf/rpc_options.pb.h" #include "xla/tsl/util/device_name_utils.h" -#include "tsl/lib/gtl/map_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/macros.h" @@ -38,7 +39,6 @@ limitations under the License. #include "tsl/platform/strcat.h" #include "tsl/platform/thread_annotations.h" #include "tsl/platform/types.h" -#include "tsl/protobuf/rpc_options.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.h b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.h index d1fcba72793483..de9aadff1db4af 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.h +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel.h @@ -24,7 +24,7 @@ limitations under the License. #include "grpcpp/grpcpp.h" #include "xla/tsl/distributed_runtime/rpc/grpc_util.h" -#include "tsl/protobuf/rpc_options.pb.h" +#include "xla/tsl/protobuf/rpc_options.pb.h" namespace tsl { using tensorflow::RPCOptions; diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel_test.cc b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel_test.cc index 80c976640fa6f1..2790b0cd65dc44 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/protobuf/rpc_options.pb.h" #include "xla/tsl/util/device_name_utils.h" #include "tsl/platform/strcat.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/rpc_options.pb.h" namespace tsl { #define IsSameAddrSp DeviceNameUtils::IsSameAddressSpace diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_util.h b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_util.h index 8e3e328cefd8eb..d39eb8e0f1be56 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_util.h +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_util.h @@ -23,11 +23,11 @@ limitations under the License. #include "absl/strings/cord.h" #include "grpcpp/grpcpp.h" #include "grpcpp/support/byte_buffer.h" +#include "xla/tsl/protobuf/distributed_runtime_payloads.pb.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/status.h" #include "tsl/platform/stringpiece.h" #include "tsl/platform/stringprintf.h" -#include "tsl/protobuf/distributed_runtime_payloads.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/framework/BUILD b/third_party/xla/xla/tsl/framework/BUILD index 27bfe40de0ea69..79ce7471ce480a 100644 --- a/third_party/xla/xla/tsl/framework/BUILD +++ b/third_party/xla/xla/tsl/framework/BUILD @@ -118,7 +118,7 @@ cc_library( ] + if_static( extra_deps = [ ":allocator_registry_impl", - "@local_tsl//tsl/lib/gtl:inlined_vector", + "//xla/tsl/lib/gtl:inlined_vector", "@local_tsl//tsl/platform:strcat", "@local_tsl//tsl/platform:stringprintf", "@local_tsl//tsl/platform:env", @@ -131,7 +131,7 @@ cc_library( "@local_tsl//tsl/platform:types", ], otherwise = [ - "@local_tsl//tsl/lib/gtl:inlined_vector", + "//xla/tsl/lib/gtl:inlined_vector", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:strcat", @@ -162,10 +162,10 @@ cc_library( deps = [ ":numeric_types", ":type_traits", + "//xla/tsl/lib/gtl:inlined_vector", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", - "@local_tsl//tsl/lib/gtl:inlined_vector", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:mutex", @@ -233,7 +233,7 @@ cc_library( "device_id_manager.h", ], deps = [ - "@local_tsl//tsl/lib/gtl:int_type", + "//xla/tsl/lib/gtl:int_type", ] + if_static([ ":device_id_impl", ]), @@ -248,9 +248,9 @@ cc_library( ], deps = [ ":device_type", + "//xla/tsl/lib/gtl:int_type", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/gtl:int_type", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:macros", @@ -273,6 +273,7 @@ cc_library( "//xla/tsl/util:device_name_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", @@ -365,8 +366,8 @@ cc_library( ), visibility = ["//visibility:public"], deps = [ + "//xla/tsl/lib/gtl:flatmap", "@com_google_absl//absl/memory", - "@local_tsl//tsl/lib/gtl:flatmap", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:hash", "@local_tsl//tsl/platform:logging", @@ -467,10 +468,10 @@ tsl_cc_test( ":device_id_impl", ":device_id_utils", "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/protobuf:error_codes_proto_impl_cc", "//xla/tsl/util:device_name_utils", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", ], ) diff --git a/third_party/xla/xla/tsl/framework/cancellation.cc b/third_party/xla/xla/tsl/framework/cancellation.cc index 7802eb926de59d..83d60bcddb96d6 100644 --- a/third_party/xla/xla/tsl/framework/cancellation.cc +++ b/third_party/xla/xla/tsl/framework/cancellation.cc @@ -78,7 +78,7 @@ void CancellationManager::StartCancelWithStatus(const absl::Status& status) { LOG(WARNING) << "Cancellation callback \"" << config.name << "\" is triggered due to a " << (StatusGroup::IsDerived(status) ? "derived" : "root") - << " error: " << status.ToString(); + << " error: " << status; } config.callback(); } diff --git a/third_party/xla/xla/tsl/framework/cancellation.h b/third_party/xla/xla/tsl/framework/cancellation.h index 38f7ebf60a63b2..6dd04e269ff5d3 100644 --- a/third_party/xla/xla/tsl/framework/cancellation.h +++ b/third_party/xla/xla/tsl/framework/cancellation.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tsl/lib/gtl/flatmap.h" +#include "xla/tsl/lib/gtl/flatmap.h" #include "tsl/platform/hash.h" #include "tsl/platform/mutex.h" #include "tsl/platform/notification.h" diff --git a/third_party/xla/xla/tsl/framework/device_id.h b/third_party/xla/xla/tsl/framework/device_id.h index b56c9ecbc64ec1..e80d84298195fe 100644 --- a/third_party/xla/xla/tsl/framework/device_id.h +++ b/third_party/xla/xla/tsl/framework/device_id.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef XLA_TSL_FRAMEWORK_DEVICE_ID_H_ #define XLA_TSL_FRAMEWORK_DEVICE_ID_H_ -#include "tsl/lib/gtl/int_type.h" +#include "xla/tsl/lib/gtl/int_type.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/framework/device_id_utils.cc b/third_party/xla/xla/tsl/framework/device_id_utils.cc index 812b119c0201da..343674c8d399d8 100644 --- a/third_party/xla/xla/tsl/framework/device_id_utils.cc +++ b/third_party/xla/xla/tsl/framework/device_id_utils.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/tsl/framework/device_id_utils.h" +#include #include #include #include @@ -22,13 +23,39 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/numbers.h" +#include "absl/strings/string_view.h" #include "xla/tsl/framework/device_id.h" #include "xla/tsl/framework/device_id_manager.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/str_util.h" namespace tsl { +namespace { + +absl::StatusOr ParsePlatformDeviceIdString( + absl::string_view platform_device_id_str, absl::string_view device_type) { + int32_t platform_device_id; + if (!absl::SimpleAtoi(platform_device_id_str, &platform_device_id)) { + // Pluggable device would have both device type and id in the string. + const std::vector device_type_and_id = + tsl::str_util::Split(platform_device_id_str, ':'); // non-absl ok + if (device_type_and_id.size() != 2 || + !absl::SimpleAtoi(device_type_and_id[1], &platform_device_id)) { + return tsl::errors::InvalidArgument( + "Could not parse entry in 'visible_device_list': '", + platform_device_id_str, "'."); + } + if (!device_type.empty() && device_type_and_id[0] != device_type) { + return -1; // Return -1 to indicate that the device type doesn't match. + } + } + return platform_device_id; +} + +} // namespace void CheckValidTfDeviceId(const DeviceType& type, const int visible_device_count, @@ -45,7 +72,8 @@ void CheckValidTfDeviceId(const DeviceType& type, absl::Status ParseVisibleDeviceList( const std::string& visible_device_list, const int visible_device_count, - std::vector* visible_device_order) { + std::vector* visible_device_order, + absl::string_view device_type) { visible_device_order->clear(); // If the user wants to remap the visible to virtual Device mapping, @@ -59,11 +87,11 @@ absl::Status ParseVisibleDeviceList( tsl::str_util::Split(visible_device_list, ','); // non-absl ok for (const std::string& platform_device_id_str : order_str) { int32_t platform_device_id; - if (!absl::SimpleAtoi(platform_device_id_str, &platform_device_id)) { - return tsl::errors::InvalidArgument( - "Could not parse entry in 'visible_device_list': '", - platform_device_id_str, - "'. visible_device_list = ", visible_device_list); + TF_ASSIGN_OR_RETURN( + platform_device_id, + ParsePlatformDeviceIdString(platform_device_id_str, device_type)); + if (platform_device_id == -1) { + continue; // Skip the device if the device type doesn't match. } if (platform_device_id < 0 || platform_device_id >= visible_device_count) { @@ -102,9 +130,9 @@ absl::StatusOr GetNumberTfDevicesAndConfigurePlatformDeviceId( return 0; } std::vector visible_device_order; - TF_RETURN_IF_ERROR(ParseVisibleDeviceList(std::string(visible_device_list), - visible_device_count, - &visible_device_order)); + TF_RETURN_IF_ERROR(ParseVisibleDeviceList( + std::string(visible_device_list), visible_device_count, + &visible_device_order, device_type)); if (num_tf_devices > visible_device_order.size()) { num_tf_devices = visible_device_order.size(); } diff --git a/third_party/xla/xla/tsl/framework/device_id_utils.h b/third_party/xla/xla/tsl/framework/device_id_utils.h index 0da5969a189531..b4552431cc97d5 100644 --- a/third_party/xla/xla/tsl/framework/device_id_utils.h +++ b/third_party/xla/xla/tsl/framework/device_id_utils.h @@ -37,9 +37,19 @@ void CheckValidTfDeviceId(const DeviceType& type, int visible_device_count, TfDeviceId tf_device_id); // Parse `visible_device_list` into a list of platform Device ids. +// When parsing non-PluggableDevices, the `device_type` parameter is +// optional (can be empty) and ignored. When using this function to +// parse the `visible_device_list` for PluggableDevices, the pluggable +// device type will be included in the `visible_device_list`, e.g. +// "PluggableDeviceA:0,PluggableDeviceA:1,PluggableDeviceB:0". +// In this case, the `device_type` parameter should be set to the +// corresponding pluggable device type to be parsed, e.g. +// "PluggableDeviceA". And the other types of PluggableDevices +// in the `visible_device_list` will be ignored. absl::Status ParseVisibleDeviceList( const std::string& visible_device_list, int visible_device_count, - std::vector* visible_device_order); + std::vector* visible_device_order, + absl::string_view device_type = ""); // Returns how many TF devices should be created, and generates the mapping // between TfDeviceId and PlatformDeviceId. The number of TF devices is the diff --git a/third_party/xla/xla/tsl/framework/device_id_utils_test.cc b/third_party/xla/xla/tsl/framework/device_id_utils_test.cc index e230d85a61cf51..9d2417e59765b2 100644 --- a/third_party/xla/xla/tsl/framework/device_id_utils_test.cc +++ b/third_party/xla/xla/tsl/framework/device_id_utils_test.cc @@ -101,6 +101,48 @@ TEST(DeviceIdUtilsTest, ParseDuplicateVisibleDeviceList) { HasSubstr("visible_device_list contained a duplicate entry: 1,1"))); } +TEST(DeviceIdUtilsTest, ParseMultiplePluggableVisibleDeviceList) { + { + std::vector visible_device_order; + TF_EXPECT_OK( + ParseVisibleDeviceList("A:0,A:1,B:0", 3, &visible_device_order, "A")); + PlatformDeviceId platform_device_id0(0), platform_device_id1(1); + std::vector expected = {platform_device_id0, + platform_device_id1}; + EXPECT_EQ(visible_device_order, expected); + } + + { + std::vector visible_device_order; + TF_EXPECT_OK( + ParseVisibleDeviceList("A:0,A:1,B:0", 3, &visible_device_order, "B")); + PlatformDeviceId platform_device_id0(0); + std::vector expected = {platform_device_id0}; + EXPECT_EQ(visible_device_order, expected); + } +} + +TEST(DeviceIdUtilsTest, ParseMultiplePluggableOutOfOrderVisibleDeviceList) { + { + std::vector visible_device_order; + TF_EXPECT_OK( + ParseVisibleDeviceList("A:1,B:0,A:0", 3, &visible_device_order, "A")); + PlatformDeviceId platform_device_id0(0), platform_device_id1(1); + std::vector expected = {platform_device_id1, + platform_device_id0}; + EXPECT_EQ(visible_device_order, expected); + } + + { + std::vector visible_device_order; + TF_EXPECT_OK( + ParseVisibleDeviceList("A:1,B:0,A:0", 3, &visible_device_order, "B")); + PlatformDeviceId platform_device_id0(0); + std::vector expected = {platform_device_id0}; + EXPECT_EQ(visible_device_order, expected); + } +} + TEST(DeviceIdUtilsTest, GetNumberTfDevicesDefault) { TF_ASSERT_OK_AND_ASSIGN(size_t num_tf_device, GetNumberTfDevicesAndConfigurePlatformDeviceId( diff --git a/third_party/xla/xla/tsl/framework/tracking_allocator.h b/third_party/xla/xla/tsl/framework/tracking_allocator.h index 32d0026db63464..b0e4288fc99617 100644 --- a/third_party/xla/xla/tsl/framework/tracking_allocator.h +++ b/third_party/xla/xla/tsl/framework/tracking_allocator.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "xla/tsl/framework/allocator.h" -#include "tsl/lib/gtl/inlined_vector.h" +#include "xla/tsl/lib/gtl/inlined_vector.h" #include "tsl/platform/mutex.h" #include "tsl/platform/thread_annotations.h" #include "tsl/platform/types.h" diff --git a/third_party/xla/xla/tsl/framework/type_traits.h b/third_party/xla/xla/tsl/framework/type_traits.h index 46fa640ee62298..f7a9bd7a54bc91 100644 --- a/third_party/xla/xla/tsl/framework/type_traits.h +++ b/third_party/xla/xla/tsl/framework/type_traits.h @@ -70,9 +70,13 @@ struct is_simple_type { std::is_trivial::value || std::is_same::value || std::is_same::value || std::is_same::value || is_quantized::value || std::is_same::value || + std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || - std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value; }; diff --git a/third_party/xla/xla/tsl/lib/core/bitmap_test.cc b/third_party/xla/xla/tsl/lib/core/bitmap_test.cc index b748249d7b79f1..bab7f7e4bc9bf5 100644 --- a/third_party/xla/xla/tsl/lib/core/bitmap_test.cc +++ b/third_party/xla/xla/tsl/lib/core/bitmap_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "xla/tsl/lib/core/bitmap.h" -#include "tsl/lib/random/simple_philox.h" +#include "xla/tsl/lib/random/simple_philox.h" #include "tsl/platform/macros.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/BUILD b/third_party/xla/xla/tsl/lib/gtl/BUILD similarity index 77% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/BUILD rename to third_party/xla/xla/tsl/lib/gtl/BUILD index f601fb129e1521..9c0c8faf532110 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/BUILD +++ b/third_party/xla/xla/tsl/lib/gtl/BUILD @@ -1,13 +1,13 @@ load( - "@local_tsl//tsl/platform:rules_cc.bzl", - "cc_library", + "@local_tsl//tsl/platform:build_config.bzl", + "tsl_cc_test", ) -load("@local_xla//xla/tsl:tsl.bzl", "internal_visibility") -load("@local_xla//xla/tsl:tsl.default.bzl", "filegroup") load( - "//tsl/platform:build_config.bzl", - "tsl_cc_test", + "@local_tsl//tsl/platform:rules_cc.bzl", + "cc_library", ) +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.default.bzl", "filegroup") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -16,21 +16,21 @@ package( "//tensorflow/core:__pkg__", # tensorflow/core/lib/strings:proto_serialization uses on gtl:inlined_vector "//tensorflow/core/lib/strings:__pkg__", - "//tsl/lib/strings:__pkg__", + "//xla/tsl/lib/strings:__pkg__", # tensorflow/core/framework uses map_util, and flatmap "//tensorflow/core/framework:__pkg__", - "@local_xla//xla/tsl/framework:__pkg__", - "//tsl/platform/cloud:__pkg__", + "//xla/tsl/framework:__pkg__", + "@local_tsl//tsl/platform/cloud:__pkg__", # tensorflow/core/util uses inlined_vector "//tensorflow/core/util:__pkg__", # tensorflow/core/tfrt/utils uses inlined_vector "//tensorflow/core/tfrt/utils:__pkg__", # tensorflow/examples/custom_ops_doc/simple_hash_table uses map_util "//tensorflow/examples/custom_ops_doc/simple_hash_table:__pkg__", - "@local_xla//xla:__subpackages__", + "//xla:__subpackages__", "//tensorflow/core/lib/gtl:__subpackages__", - "@local_xla//xla/tsl/distributed_runtime/rpc:__pkg__", - "//tsl/profiler/utils:__pkg__", + "//xla/tsl/distributed_runtime/rpc:__pkg__", + "//xla/tsl/profiler/utils:__pkg__", ]), licenses = ["notice"], ) @@ -48,9 +48,9 @@ cc_library( hdrs = ["flatmap.h"], deps = [ ":flatrep", - "//tsl/platform:hash", - "//tsl/platform:logging", - "//tsl/platform:types", + "@local_tsl//tsl/platform:hash", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:types", ], ) @@ -58,8 +58,8 @@ cc_library( name = "flatrep", hdrs = ["flatrep.h"], deps = [ - "//tsl/platform:types", "@com_google_absl//absl/base:prefetch", + "@local_tsl//tsl/platform:types", ], ) @@ -68,9 +68,9 @@ cc_library( hdrs = ["flatset.h"], deps = [ ":flatrep", - "//tsl/platform:hash", - "//tsl/platform:logging", - "//tsl/platform:types", + "@local_tsl//tsl/platform:hash", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:types", ], ) @@ -78,10 +78,10 @@ cc_library( name = "inlined_vector", hdrs = ["inlined_vector.h"], deps = [ - "//tsl/platform:macros", - "//tsl/platform:types", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:types", ], ) @@ -89,8 +89,8 @@ cc_library( name = "int_type", hdrs = ["int_type.h"], deps = [ - "//tsl/platform:macros", - "//tsl/platform:types", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:types", ], ) @@ -103,7 +103,7 @@ cc_library( name = "map_util", srcs = [ "map_util.h", - "//tsl/lib/gtl/subtle:map_traits", + "//xla/tsl/lib/gtl/subtle:map_traits", ], hdrs = ["map_util.h"], ) @@ -166,7 +166,7 @@ filegroup( visibility = internal_visibility([ "//tensorflow/core:__pkg__", "//tensorflow/core/lib/gtl:__pkg__", - "//tsl:__subpackages__", + "@local_tsl//tsl:__subpackages__", ]), ) @@ -177,7 +177,7 @@ filegroup( "int_type.h", "iterator_range.h", "map_util.h", - "//tsl/lib/gtl/subtle:map_traits", + "//xla/tsl/lib/gtl/subtle:map_traits", ], visibility = internal_visibility([ "//tensorflow/core:__pkg__", @@ -196,7 +196,7 @@ filegroup( "int_type.h", "iterator_range.h", "map_util.h", - "//tsl/lib/gtl/subtle:map_traits", + "//xla/tsl/lib/gtl/subtle:map_traits", ], visibility = internal_visibility([ "//tensorflow/core:__pkg__", @@ -221,10 +221,10 @@ tsl_cc_test( ":int_type", ":iterator_range", ":map_util", - "//tsl/platform:hash", - "//tsl/platform:macros", - "//tsl/platform:test", - "//tsl/platform:test_main", - "//tsl/platform:types", + "@local_tsl//tsl/platform:hash", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/platform:types", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/compactptrset.h b/third_party/xla/xla/tsl/lib/gtl/compactptrset.h similarity index 96% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/compactptrset.h rename to third_party/xla/xla/tsl/lib/gtl/compactptrset.h index 8fbb7a8560dd1b..3848430e76fb92 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/compactptrset.h +++ b/third_party/xla/xla/tsl/lib/gtl/compactptrset.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_GTL_COMPACTPTRSET_H_ -#define TENSORFLOW_TSL_LIB_GTL_COMPACTPTRSET_H_ +#ifndef XLA_TSL_LIB_GTL_COMPACTPTRSET_H_ +#define XLA_TSL_LIB_GTL_COMPACTPTRSET_H_ #include -#include "tsl/lib/gtl/flatset.h" +#include "xla/tsl/lib/gtl/flatset.h" namespace tsl { namespace gtl { @@ -206,4 +206,4 @@ class CompactPointerSet { } // namespace gtl } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_GTL_COMPACTPTRSET_H_ +#endif // XLA_TSL_LIB_GTL_COMPACTPTRSET_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/compactptrset_test.cc b/third_party/xla/xla/tsl/lib/gtl/compactptrset_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/compactptrset_test.cc rename to third_party/xla/xla/tsl/lib/gtl/compactptrset_test.cc index 9dc146c2e52b79..6f5e52dc085047 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/compactptrset_test.cc +++ b/third_party/xla/xla/tsl/lib/gtl/compactptrset_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/gtl/compactptrset.h" +#include "xla/tsl/lib/gtl/compactptrset.h" #include "tsl/platform/hash.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatmap.h b/third_party/xla/xla/tsl/lib/gtl/flatmap.h similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/flatmap.h rename to third_party/xla/xla/tsl/lib/gtl/flatmap.h index 8d5cf7912e9d78..e74dbd46531d9a 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatmap.h +++ b/third_party/xla/xla/tsl/lib/gtl/flatmap.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_GTL_FLATMAP_H_ -#define TENSORFLOW_TSL_LIB_GTL_FLATMAP_H_ +#ifndef XLA_TSL_LIB_GTL_FLATMAP_H_ +#define XLA_TSL_LIB_GTL_FLATMAP_H_ #include @@ -23,7 +23,7 @@ limitations under the License. #include #include -#include "tsl/lib/gtl/flatrep.h" +#include "xla/tsl/lib/gtl/flatrep.h" #include "tsl/platform/hash.h" #include "tsl/platform/logging.h" #include "tsl/platform/types.h" @@ -393,4 +393,4 @@ class FlatMap { } // namespace gtl } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_GTL_FLATMAP_H_ +#endif // XLA_TSL_LIB_GTL_FLATMAP_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatmap_test.cc b/third_party/xla/xla/tsl/lib/gtl/flatmap_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/flatmap_test.cc rename to third_party/xla/xla/tsl/lib/gtl/flatmap_test.cc index a2b4fd11df3dbf..231970ccbe45ac 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatmap_test.cc +++ b/third_party/xla/xla/tsl/lib/gtl/flatmap_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/gtl/flatmap.h" +#include "xla/tsl/lib/gtl/flatmap.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatrep.h b/third_party/xla/xla/tsl/lib/gtl/flatrep.h similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/flatrep.h rename to third_party/xla/xla/tsl/lib/gtl/flatrep.h index d6c77e7de363ea..74ae18fc37c0f8 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatrep.h +++ b/third_party/xla/xla/tsl/lib/gtl/flatrep.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_GTL_FLATREP_H_ -#define TENSORFLOW_TSL_LIB_GTL_FLATREP_H_ +#ifndef XLA_TSL_LIB_GTL_FLATREP_H_ +#define XLA_TSL_LIB_GTL_FLATREP_H_ #include @@ -350,4 +350,4 @@ class FlatRep { } // namespace gtl } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_GTL_FLATREP_H_ +#endif // XLA_TSL_LIB_GTL_FLATREP_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatset.h b/third_party/xla/xla/tsl/lib/gtl/flatset.h similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/flatset.h rename to third_party/xla/xla/tsl/lib/gtl/flatset.h index b3178225647fe1..f272ad1fa7bd1d 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatset.h +++ b/third_party/xla/xla/tsl/lib/gtl/flatset.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_GTL_FLATSET_H_ -#define TENSORFLOW_TSL_LIB_GTL_FLATSET_H_ +#ifndef XLA_TSL_LIB_GTL_FLATSET_H_ +#define XLA_TSL_LIB_GTL_FLATSET_H_ #include @@ -23,7 +23,7 @@ limitations under the License. #include #include -#include "tsl/lib/gtl/flatrep.h" +#include "xla/tsl/lib/gtl/flatrep.h" #include "tsl/platform/hash.h" #include "tsl/platform/logging.h" #include "tsl/platform/types.h" @@ -293,4 +293,4 @@ class FlatSet { } // namespace gtl } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_GTL_FLATSET_H_ +#endif // XLA_TSL_LIB_GTL_FLATSET_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatset_test.cc b/third_party/xla/xla/tsl/lib/gtl/flatset_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/flatset_test.cc rename to third_party/xla/xla/tsl/lib/gtl/flatset_test.cc index abf7892f2d8798..8adb9133a76ecb 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatset_test.cc +++ b/third_party/xla/xla/tsl/lib/gtl/flatset_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/gtl/flatset.h" +#include "xla/tsl/lib/gtl/flatset.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/inlined_vector.h b/third_party/xla/xla/tsl/lib/gtl/inlined_vector.h similarity index 89% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/inlined_vector.h rename to third_party/xla/xla/tsl/lib/gtl/inlined_vector.h index fc8533b02937ab..6072f87ff6931d 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/inlined_vector.h +++ b/third_party/xla/xla/tsl/lib/gtl/inlined_vector.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_GTL_INLINED_VECTOR_H_ -#define TENSORFLOW_TSL_LIB_GTL_INLINED_VECTOR_H_ +#ifndef XLA_TSL_LIB_GTL_INLINED_VECTOR_H_ +#define XLA_TSL_LIB_GTL_INLINED_VECTOR_H_ #include @@ -39,4 +39,4 @@ using InlinedVector ABSL_DEPRECATE_AND_INLINE() = absl::InlinedVector; } // namespace gtl } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_GTL_INLINED_VECTOR_H_ +#endif // XLA_TSL_LIB_GTL_INLINED_VECTOR_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/int_type.h b/third_party/xla/xla/tsl/lib/gtl/int_type.h similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/int_type.h rename to third_party/xla/xla/tsl/lib/gtl/int_type.h index 7a5d7935782884..2a54fc58fada8f 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/int_type.h +++ b/third_party/xla/xla/tsl/lib/gtl/int_type.h @@ -149,8 +149,8 @@ limitations under the License. // void GetGlobalDoc(int64 global) { ... // GetGlobalDoc(local.value()); <-- Compiles fine. -#ifndef TENSORFLOW_TSL_LIB_GTL_INT_TYPE_H_ -#define TENSORFLOW_TSL_LIB_GTL_INT_TYPE_H_ +#ifndef XLA_TSL_LIB_GTL_INT_TYPE_H_ +#define XLA_TSL_LIB_GTL_INT_TYPE_H_ #include @@ -361,4 +361,4 @@ INT_TYPE_COMPARISON_OP(>=); // NOLINT } // namespace gtl } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_GTL_INT_TYPE_H_ +#endif // XLA_TSL_LIB_GTL_INT_TYPE_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/int_type_test.cc b/third_party/xla/xla/tsl/lib/gtl/int_type_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/int_type_test.cc rename to third_party/xla/xla/tsl/lib/gtl/int_type_test.cc index 2716eb139fa0a9..6ab47039fe1653 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/int_type_test.cc +++ b/third_party/xla/xla/tsl/lib/gtl/int_type_test.cc @@ -15,7 +15,7 @@ limitations under the License. // Unit test cases for IntType. -#include "tsl/lib/gtl/int_type.h" +#include "xla/tsl/lib/gtl/int_type.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/iterator_range.h b/third_party/xla/xla/tsl/lib/gtl/iterator_range.h similarity index 93% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/iterator_range.h rename to third_party/xla/xla/tsl/lib/gtl/iterator_range.h index 6e420c940142cc..2914dce38c7f9e 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/iterator_range.h +++ b/third_party/xla/xla/tsl/lib/gtl/iterator_range.h @@ -22,8 +22,8 @@ limitations under the License. // // Converted from chandlerc@'s code to Google style by joshl@. -#ifndef TENSORFLOW_TSL_LIB_GTL_ITERATOR_RANGE_H_ -#define TENSORFLOW_TSL_LIB_GTL_ITERATOR_RANGE_H_ +#ifndef XLA_TSL_LIB_GTL_ITERATOR_RANGE_H_ +#define XLA_TSL_LIB_GTL_ITERATOR_RANGE_H_ #include @@ -65,4 +65,4 @@ iterator_range make_range(T x, T y) { } // namespace gtl } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_GTL_ITERATOR_RANGE_H_ +#endif // XLA_TSL_LIB_GTL_ITERATOR_RANGE_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/iterator_range_test.cc b/third_party/xla/xla/tsl/lib/gtl/iterator_range_test.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/iterator_range_test.cc rename to third_party/xla/xla/tsl/lib/gtl/iterator_range_test.cc index 35d1fe5854d8b8..08028094552ff1 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/iterator_range_test.cc +++ b/third_party/xla/xla/tsl/lib/gtl/iterator_range_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/gtl/iterator_range.h" +#include "xla/tsl/lib/gtl/iterator_range.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/map_util.h b/third_party/xla/xla/tsl/lib/gtl/map_util.h similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/map_util.h rename to third_party/xla/xla/tsl/lib/gtl/map_util.h index 63a966228481bc..d04ba3644a09a0 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/map_util.h +++ b/third_party/xla/xla/tsl/lib/gtl/map_util.h @@ -17,8 +17,8 @@ limitations under the License. // structures, such as std::map and hash_map. Some functions will also work with // sets, such as ContainsKey(). -#ifndef TENSORFLOW_TSL_LIB_GTL_MAP_UTIL_H_ -#define TENSORFLOW_TSL_LIB_GTL_MAP_UTIL_H_ +#ifndef XLA_TSL_LIB_GTL_MAP_UTIL_H_ +#define XLA_TSL_LIB_GTL_MAP_UTIL_H_ #include @@ -27,7 +27,7 @@ limitations under the License. #include #include -#include "tsl/lib/gtl/subtle/map_traits.h" +#include "xla/tsl/lib/gtl/subtle/map_traits.h" namespace tsl { namespace gtl { @@ -212,4 +212,4 @@ typename Collection::value_type::second_type EraseKeyReturnValuePtr( } // namespace gtl } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_GTL_MAP_UTIL_H_ +#endif // XLA_TSL_LIB_GTL_MAP_UTIL_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/map_util_test.cc b/third_party/xla/xla/tsl/lib/gtl/map_util_test.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/map_util_test.cc rename to third_party/xla/xla/tsl/lib/gtl/map_util_test.cc index 7ecf4a4b394251..ce2a13c9e394e9 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/map_util_test.cc +++ b/third_party/xla/xla/tsl/lib/gtl/map_util_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/gtl/map_util.h" +#include "xla/tsl/lib/gtl/map_util.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/subtle/BUILD b/third_party/xla/xla/tsl/lib/gtl/subtle/BUILD similarity index 69% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/subtle/BUILD rename to third_party/xla/xla/tsl/lib/gtl/subtle/BUILD index e2f9763dfd2b58..dfedc36f004cb1 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/subtle/BUILD +++ b/third_party/xla/xla/tsl/lib/gtl/subtle/BUILD @@ -1,8 +1,8 @@ # Description: # gtl subtle packages. -load("@local_xla//xla/tsl:tsl.bzl", "internal_visibility") -load("@local_xla//xla/tsl:tsl.default.bzl", "filegroup") +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.default.bzl", "filegroup") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -16,6 +16,6 @@ filegroup( ], visibility = internal_visibility([ "//tensorflow/core/lib/gtl/subtle:__pkg__", - "//tsl/lib/gtl:__pkg__", + "//xla/tsl/lib/gtl:__pkg__", ]), ) diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/subtle/map_traits.h b/third_party/xla/xla/tsl/lib/gtl/subtle/map_traits.h similarity index 93% rename from third_party/xla/third_party/tsl/tsl/lib/gtl/subtle/map_traits.h rename to third_party/xla/xla/tsl/lib/gtl/subtle/map_traits.h index 535db74402ba91..961dc550747bd2 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/subtle/map_traits.h +++ b/third_party/xla/xla/tsl/lib/gtl/subtle/map_traits.h @@ -17,8 +17,8 @@ limitations under the License. // 1. If T has a `first` or `second` field, use them. // 2. Otherwise if it has `key()` or `value()` methods, use them. // 3. Otherwise the program is ill-formed. -#ifndef TENSORFLOW_TSL_LIB_GTL_SUBTLE_MAP_TRAITS_H_ -#define TENSORFLOW_TSL_LIB_GTL_SUBTLE_MAP_TRAITS_H_ +#ifndef XLA_TSL_LIB_GTL_SUBTLE_MAP_TRAITS_H_ +#define XLA_TSL_LIB_GTL_SUBTLE_MAP_TRAITS_H_ #include namespace tsl { namespace gtl { @@ -62,4 +62,4 @@ auto GetMapped(V&& v) } // namespace subtle } // namespace gtl } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_GTL_SUBTLE_MAP_TRAITS_H_ +#endif // XLA_TSL_LIB_GTL_SUBTLE_MAP_TRAITS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/hash/BUILD b/third_party/xla/xla/tsl/lib/hash/BUILD similarity index 71% rename from third_party/xla/third_party/tsl/tsl/lib/hash/BUILD rename to third_party/xla/xla/tsl/lib/hash/BUILD index c497abfe17ac47..a25dc7d9cda14b 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/hash/BUILD +++ b/third_party/xla/xla/tsl/lib/hash/BUILD @@ -1,24 +1,24 @@ +load( + "@local_tsl//tsl/platform:build_config.bzl", + "tsl_cc_test", +) load( "@local_tsl//tsl/platform:rules_cc.bzl", "cc_library", ) load( - "@local_xla//xla/tsl:tsl.bzl", + "//xla/tsl:tsl.bzl", "if_linux_x86_64", "internal_visibility", "tsl_copts", ) -load("@local_xla//xla/tsl:tsl.default.bzl", "filegroup") -load( - "//tsl/platform:build_config.bzl", - "tsl_cc_test", -) +load("//xla/tsl:tsl.default.bzl", "filegroup") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = internal_visibility([ - # tensorflow/tsl/lib/io/table_builder.cc uses crc functionality - "//tsl/lib/io:__pkg__", + # tensorflow/compiler/xla/tsl/lib/io/table_builder.cc uses crc functionality + "//xla/tsl/lib/io:__pkg__", # tensorflow/core/lib/hash aliases hash for now "//tensorflow/core/lib/hash:__pkg__", ]), @@ -34,12 +34,12 @@ cc_library( # -msse4.2 enables the use of crc32c compiler builtins. copts = tsl_copts() + if_linux_x86_64(["-msse4.2"]), deps = [ - "//tsl/platform", - "//tsl/platform:cord", - "//tsl/platform:types", "@com_google_absl//absl/crc:crc32c", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform", + "@local_tsl//tsl/platform:cord", + "@local_tsl//tsl/platform:types", ], ) @@ -67,11 +67,11 @@ tsl_cc_test( srcs = ["crc32c_test.cc"], deps = [ ":crc32c", - "//tsl/platform:logging", - "//tsl/platform:test", - "//tsl/platform:test_benchmark", - "//tsl/platform:test_main", - "//tsl/platform:types", "@com_google_absl//absl/strings:cord", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_benchmark", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/platform:types", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/lib/hash/crc32c.cc b/third_party/xla/xla/tsl/lib/hash/crc32c.cc similarity index 96% rename from third_party/xla/third_party/tsl/tsl/lib/hash/crc32c.cc rename to third_party/xla/xla/tsl/lib/hash/crc32c.cc index 1bd005b6b05297..8ad835fb1d80f8 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/hash/crc32c.cc +++ b/third_party/xla/xla/tsl/lib/hash/crc32c.cc @@ -16,7 +16,7 @@ limitations under the License. // A portable implementation of crc32c, optimized to handle // four bytes at a time. -#include "tsl/lib/hash/crc32c.h" +#include "xla/tsl/lib/hash/crc32c.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/hash/crc32c.h b/third_party/xla/xla/tsl/lib/hash/crc32c.h similarity index 94% rename from third_party/xla/third_party/tsl/tsl/lib/hash/crc32c.h rename to third_party/xla/xla/tsl/lib/hash/crc32c.h index 10c4ea13e864d5..29c71eed3f0a99 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/hash/crc32c.h +++ b/third_party/xla/xla/tsl/lib/hash/crc32c.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_HASH_CRC32C_H_ -#define TENSORFLOW_TSL_LIB_HASH_CRC32C_H_ +#ifndef XLA_TSL_LIB_HASH_CRC32C_H_ +#define XLA_TSL_LIB_HASH_CRC32C_H_ #include @@ -67,4 +67,4 @@ inline uint32 Unmask(uint32 masked_crc) { } // namespace crc32c } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_HASH_CRC32C_H_ +#endif // XLA_TSL_LIB_HASH_CRC32C_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/hash/crc32c_test.cc b/third_party/xla/xla/tsl/lib/hash/crc32c_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/hash/crc32c_test.cc rename to third_party/xla/xla/tsl/lib/hash/crc32c_test.cc index 9ba2e6e8108cf7..291121d5043f6f 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/hash/crc32c_test.cc +++ b/third_party/xla/xla/tsl/lib/hash/crc32c_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/hash/crc32c.h" +#include "xla/tsl/lib/hash/crc32c.h" #include diff --git a/third_party/xla/xla/tsl/lib/histogram/BUILD b/third_party/xla/xla/tsl/lib/histogram/BUILD index cbd206f6bd8083..fc486455c63c14 100644 --- a/third_party/xla/xla/tsl/lib/histogram/BUILD +++ b/third_party/xla/xla/tsl/lib/histogram/BUILD @@ -20,12 +20,12 @@ cc_library( hdrs = ["histogram.h"], visibility = ["//visibility:public"], deps = [ + "//xla/tsl/protobuf:histogram_proto_cc", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:macros", "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:thread_annotations", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/protobuf:histogram_proto_cc", ], alwayslink = True, ) @@ -55,9 +55,9 @@ tsl_cc_test( ], deps = [ ":histogram", + "//xla/tsl/protobuf:histogram_proto_cc", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/protobuf:histogram_proto_cc", ], ) diff --git a/third_party/xla/xla/tsl/lib/histogram/histogram.cc b/third_party/xla/xla/tsl/lib/histogram/histogram.cc index e8203549272547..35ff514e1fe1dd 100644 --- a/third_party/xla/xla/tsl/lib/histogram/histogram.cc +++ b/third_party/xla/xla/tsl/lib/histogram/histogram.cc @@ -20,10 +20,10 @@ limitations under the License. #include +#include "xla/tsl/protobuf/histogram.pb.h" #include "tsl/platform/logging.h" #include "tsl/platform/mutex.h" #include "tsl/platform/types.h" -#include "tsl/protobuf/histogram.pb.h" namespace tsl { namespace histogram { diff --git a/third_party/xla/xla/tsl/lib/histogram/histogram_test.cc b/third_party/xla/xla/tsl/lib/histogram/histogram_test.cc index 4051d98f49ab97..1b2f1827521a17 100644 --- a/third_party/xla/xla/tsl/lib/histogram/histogram_test.cc +++ b/third_party/xla/xla/tsl/lib/histogram/histogram_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include +#include "xla/tsl/protobuf/histogram.pb.h" #include "tsl/platform/logging.h" #include "tsl/platform/test.h" -#include "tsl/protobuf/histogram.pb.h" namespace tsl { namespace histogram { diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/BUILD b/third_party/xla/xla/tsl/lib/io/BUILD similarity index 55% rename from third_party/xla/third_party/tsl/tsl/lib/io/BUILD rename to third_party/xla/xla/tsl/lib/io/BUILD index cd527743282c01..43152cb1ea1444 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/BUILD +++ b/third_party/xla/xla/tsl/lib/io/BUILD @@ -1,24 +1,24 @@ +load("@local_tsl//tsl/platform:build_config.bzl", "tsl_cc_test") load( "@local_tsl//tsl/platform:rules_cc.bzl", "cc_library", ) -load("@local_xla//xla/tsl:tsl.bzl", "internal_visibility") -load("@local_xla//xla/tsl:tsl.default.bzl", "filegroup") -load("//tsl/platform:build_config.bzl", "tsl_cc_test") +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.default.bzl", "filegroup") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = internal_visibility([ "//tensorflow/c/experimental/filesystem:__pkg__", "//tensorflow/c/experimental/filesystem/plugins/posix:__pkg__", - "//tsl/lib/io/snappy:__pkg__", - "@local_xla//xla:__subpackages__", + "//xla/tsl/lib/io/snappy:__pkg__", + "//xla:__subpackages__", # tensorflow/core:lib effectively exposes all targets under tensorflow/core/lib/** "//tensorflow/core/util:__subpackages__", "//tensorflow/core:__pkg__", "//tensorflow/core/lib/io:__subpackages__", - "@local_xla//xla/tsl/profiler:__subpackages__", - "//tsl/profiler:__subpackages__", + "//xla/tsl/profiler:__subpackages__", + "@local_tsl//tsl/profiler:__subpackages__", "//tensorflow/core/profiler:__subpackages__", ]), licenses = ["notice"], @@ -41,16 +41,16 @@ cc_library( deps = [ ":iterator", ":table_options", - "//tsl/lib/hash:crc32c", - "//tsl/platform:coding", - "//tsl/platform:env", - "//tsl/platform:errors", - "//tsl/platform:logging", - "//tsl/platform:platform_port", - "//tsl/platform:raw_coding", - "//tsl/platform:status", - "//tsl/platform:stringpiece", - "//tsl/platform:types", + "//xla/tsl/lib/hash:crc32c", + "@local_tsl//tsl/platform:coding", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:raw_coding", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:stringpiece", + "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -62,8 +62,8 @@ cc_library( deps = [ ":inputstream_interface", ":random_inputstream", - "//tsl/platform:env", "@com_google_absl//absl/status", + "@local_tsl//tsl/platform:env", ], alwayslink = True, ) @@ -80,13 +80,13 @@ cc_library( srcs = ["inputbuffer.cc"], hdrs = ["inputbuffer.h"], deps = [ - "//tsl/platform:coding", - "//tsl/platform:env", - "//tsl/platform:errors", - "//tsl/platform:logging", - "//tsl/platform:macros", - "//tsl/platform:status", - "//tsl/platform:types", + "@local_tsl//tsl/platform:coding", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -96,10 +96,10 @@ cc_library( srcs = ["inputstream_interface.cc"], hdrs = ["inputstream_interface.h"], deps = [ - "//tsl/platform:cord", - "//tsl/platform:errors", - "//tsl/platform:status", - "//tsl/platform:types", + "@local_tsl//tsl/platform:cord", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -109,8 +109,8 @@ cc_library( srcs = ["iterator.cc"], hdrs = ["iterator.h"], deps = [ - "//tsl/platform:status", - "//tsl/platform:stringpiece", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:stringpiece", ], alwayslink = True, ) @@ -119,10 +119,10 @@ cc_library( name = "proto_encode_helper", hdrs = ["proto_encode_helper.h"], deps = [ - "//tsl/platform:coding", - "//tsl/platform:logging", - "//tsl/platform:protobuf", - "//tsl/platform:stringpiece", + "@local_tsl//tsl/platform:coding", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:stringpiece", ], ) @@ -132,8 +132,8 @@ cc_library( hdrs = ["random_inputstream.h"], deps = [ ":inputstream_interface", - "//tsl/platform:cord", - "//tsl/platform:env", + "@local_tsl//tsl/platform:cord", + "@local_tsl//tsl/platform:env", ], alwayslink = True, ) @@ -151,13 +151,13 @@ cc_library( ":snappy_inputstream", ":zlib_compression_options", ":zlib_inputstream", - "//tsl/lib/hash:crc32c", - "//tsl/platform:env", - "//tsl/platform:errors", - "//tsl/platform:macros", - "//tsl/platform:raw_coding", - "//tsl/platform:stringpiece", - "//tsl/platform:types", + "//xla/tsl/lib/hash:crc32c", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:raw_coding", + "@local_tsl//tsl/platform:stringpiece", + "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -172,36 +172,36 @@ cc_library( ":snappy_outputbuffer", ":zlib_compression_options", ":zlib_outputbuffer", - "//tsl/lib/hash:crc32c", - "//tsl/platform:coding", - "//tsl/platform:cord", - "//tsl/platform:env", - "//tsl/platform:macros", - "//tsl/platform:status", - "//tsl/platform:stringpiece", - "//tsl/platform:types", + "//xla/tsl/lib/hash:crc32c", + "@local_tsl//tsl/platform:coding", + "@local_tsl//tsl/platform:cord", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:stringpiece", + "@local_tsl//tsl/platform:types", ], alwayslink = True, ) alias( name = "snappy_inputbuffer", - actual = "//tsl/lib/io/snappy:snappy_inputbuffer", + actual = "//xla/tsl/lib/io/snappy:snappy_inputbuffer", ) alias( name = "snappy_inputstream", - actual = "//tsl/lib/io/snappy:snappy_inputstream", + actual = "//xla/tsl/lib/io/snappy:snappy_inputstream", ) alias( name = "snappy_outputbuffer", - actual = "//tsl/lib/io/snappy:snappy_outputbuffer", + actual = "//xla/tsl/lib/io/snappy:snappy_outputbuffer", ) alias( name = "snappy_compression_options", - actual = "//tsl/lib/io/snappy:snappy_compression_options", + actual = "//xla/tsl/lib/io/snappy:snappy_compression_options", ) cc_library( @@ -213,9 +213,9 @@ cc_library( "cache.h", ], deps = [ - "//tsl/platform:mutex", - "//tsl/platform:raw_coding", - "//tsl/platform:stringpiece", + "@local_tsl//tsl/platform:mutex", + "@local_tsl//tsl/platform:raw_coding", + "@local_tsl//tsl/platform:stringpiece", ], ) @@ -234,9 +234,9 @@ cc_library( ":cache", ":iterator", ":table_options", - "//tsl/platform:coding", - "//tsl/platform:env", - "//tsl/platform:errors", + "@local_tsl//tsl/platform:coding", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", ], alwayslink = True, ) @@ -251,10 +251,10 @@ cc_library( hdrs = ["buffered_file.h"], visibility = ["//visibility:public"], deps = [ - "//tsl/lib/hash:crc32c", - "//tsl/platform:cord", - "//tsl/platform:env", - "//tsl/platform:status", + "//xla/tsl/lib/hash:crc32c", + "@local_tsl//tsl/platform:cord", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:status", ], ) @@ -264,12 +264,12 @@ tsl_cc_test( srcs = ["buffered_file_test.cc"], deps = [ ":buffered_file", - "//tsl/platform:env", - "//tsl/platform:env_impl", - "//tsl/platform:test", - "//tsl/platform:test_benchmark", - "//tsl/platform:test_main", - "@local_xla//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_benchmark", + "@local_tsl//tsl/platform:test_main", ], ) @@ -278,7 +278,7 @@ cc_library( srcs = ["zlib_compression_options.cc"], hdrs = ["zlib_compression_options.h"], deps = [ - "//tsl/platform:types", + "@local_tsl//tsl/platform:types", "@zlib", ], alwayslink = True, @@ -291,12 +291,12 @@ cc_library( deps = [ ":inputstream_interface", ":zlib_compression_options", - "//tsl/platform:env", - "//tsl/platform:logging", - "//tsl/platform:macros", - "//tsl/platform:status", - "//tsl/platform:strcat", - "//tsl/platform:types", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:strcat", + "@local_tsl//tsl/platform:types", "@zlib", ], alwayslink = True, @@ -308,12 +308,12 @@ cc_library( hdrs = ["zlib_outputbuffer.h"], deps = [ ":zlib_compression_options", - "//tsl/platform:env", - "//tsl/platform:errors", - "//tsl/platform:macros", - "//tsl/platform:status", - "//tsl/platform:stringpiece", - "//tsl/platform:types", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:stringpiece", + "@local_tsl//tsl/platform:types", "@zlib", ], alwayslink = True, @@ -357,9 +357,9 @@ filegroup( "zlib_compression_options.h", "zlib_inputstream.cc", "zlib_inputstream.h", - "//tsl/lib/io/snappy:snappy_compression_options.h", - "//tsl/lib/io/snappy:snappy_inputstream.cc", - "//tsl/lib/io/snappy:snappy_inputstream.h", + "//xla/tsl/lib/io/snappy:snappy_compression_options.h", + "//xla/tsl/lib/io/snappy:snappy_inputstream.cc", + "//xla/tsl/lib/io/snappy:snappy_inputstream.h", ], ) @@ -385,10 +385,10 @@ filegroup( "zlib_compression_options.h", "zlib_inputstream.h", "zlib_outputbuffer.h", - "//tsl/lib/io/snappy:snappy_compression_options.h", - "//tsl/lib/io/snappy:snappy_inputbuffer.h", - "//tsl/lib/io/snappy:snappy_inputstream.h", - "//tsl/lib/io/snappy:snappy_outputbuffer.h", + "//xla/tsl/lib/io/snappy:snappy_compression_options.h", + "//xla/tsl/lib/io/snappy:snappy_inputbuffer.h", + "//xla/tsl/lib/io/snappy:snappy_inputstream.h", + "//xla/tsl/lib/io/snappy:snappy_outputbuffer.h", ], visibility = internal_visibility(["//tensorflow/core:__pkg__"]), ) @@ -419,10 +419,10 @@ filegroup( "zlib_compression_options.h", "zlib_inputstream.h", "zlib_outputbuffer.h", - "//tsl/lib/io/snappy:snappy_compression_options.h", - "//tsl/lib/io/snappy:snappy_inputbuffer.h", - "//tsl/lib/io/snappy:snappy_inputstream.h", - "//tsl/lib/io/snappy:snappy_outputbuffer.h", + "//xla/tsl/lib/io/snappy:snappy_compression_options.h", + "//xla/tsl/lib/io/snappy:snappy_inputbuffer.h", + "//xla/tsl/lib/io/snappy:snappy_inputstream.h", + "//xla/tsl/lib/io/snappy:snappy_outputbuffer.h", ], visibility = internal_visibility(["//tensorflow/core:__pkg__"]), ) @@ -444,12 +444,12 @@ tsl_cc_test( deps = [ ":buffered_inputstream", ":random_inputstream", - "//tsl/platform:env", - "//tsl/platform:env_impl", - "//tsl/platform:test", - "//tsl/platform:test_benchmark", - "//tsl/platform:test_main", - "@local_xla//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_benchmark", + "@local_tsl//tsl/platform:test_main", ], ) @@ -459,10 +459,10 @@ tsl_cc_test( srcs = ["cache_test.cc"], deps = [ ":cache", - "//tsl/platform:coding", - "//tsl/platform:raw_coding", - "//tsl/platform:test", - "//tsl/platform:test_main", + "@local_tsl//tsl/platform:coding", + "@local_tsl//tsl/platform:raw_coding", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -472,17 +472,17 @@ tsl_cc_test( srcs = ["inputbuffer_test.cc"], deps = [ ":inputbuffer", - "//tsl/platform:coding", - "//tsl/platform:env", - "//tsl/platform:env_impl", - "//tsl/platform:errors", - "//tsl/platform:logging", - "//tsl/platform:status", - "//tsl/platform:str_util", - "//tsl/platform:strcat", - "//tsl/platform:test", - "//tsl/platform:test_main", - "@local_xla//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:coding", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:str_util", + "@local_tsl//tsl/platform:strcat", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -492,10 +492,10 @@ tsl_cc_test( srcs = ["inputstream_interface_test.cc"], deps = [ ":inputstream_interface", - "//tsl/platform:errors", - "//tsl/platform:test", - "//tsl/platform:test_main", - "@local_xla//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -505,11 +505,11 @@ tsl_cc_test( srcs = ["random_inputstream_test.cc"], deps = [ ":random_inputstream", - "//tsl/platform:env", - "//tsl/platform:env_impl", - "//tsl/platform:test", - "//tsl/platform:test_main", - "@local_xla//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -520,15 +520,15 @@ tsl_cc_test( deps = [ ":record_reader", ":record_writer", - "//tsl/platform:env", - "//tsl/platform:env_impl", - "//tsl/platform:errors", - "//tsl/platform:logging", - "//tsl/platform:status", - "//tsl/platform:strcat", - "//tsl/platform:test", - "//tsl/platform:test_main", - "@local_xla//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:strcat", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", "@zlib", ], ) @@ -540,16 +540,16 @@ tsl_cc_test( deps = [ ":record_reader", ":record_writer", - "//tsl/lib/hash:crc32c", - "//tsl/lib/random:philox", - "//tsl/platform:coding", - "//tsl/platform:env", - "//tsl/platform:env_impl", - "//tsl/platform:errors", - "//tsl/platform:str_util", - "//tsl/platform:test", - "//tsl/platform:test_main", - "@local_xla//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/hash:crc32c", + "//xla/tsl/lib/random:philox", + "@local_tsl//tsl/platform:coding", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:str_util", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -561,14 +561,14 @@ tsl_cc_test( ":block", ":iterator", ":table", - "//tsl/lib/random:philox", - "//tsl/platform:env", - "//tsl/platform:env_impl", - "//tsl/platform:errors", - "//tsl/platform:platform_port", - "//tsl/platform:test", - "//tsl/platform:test_main", + "//xla/tsl/lib/random:philox", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -581,12 +581,12 @@ tsl_cc_test( ":zlib_compression_options", ":zlib_inputstream", ":zlib_outputbuffer", - "//tsl/platform:env", - "//tsl/platform:env_impl", - "//tsl/platform:errors", - "//tsl/platform:strcat", - "//tsl/platform:test", - "//tsl/platform:test_main", - "@local_xla//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:strcat", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/block.cc b/third_party/xla/xla/tsl/lib/io/block.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/io/block.cc rename to third_party/xla/xla/tsl/lib/io/block.cc index afae26cf20caec..eed80e59cf9243 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/block.cc +++ b/third_party/xla/xla/tsl/lib/io/block.cc @@ -15,11 +15,11 @@ limitations under the License. // Decodes the blocks generated by block_builder.cc. -#include "tsl/lib/io/block.h" +#include "xla/tsl/lib/io/block.h" #include -#include "tsl/lib/io/format.h" +#include "xla/tsl/lib/io/format.h" #include "tsl/platform/coding.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/block.h b/third_party/xla/xla/tsl/lib/io/block.h similarity index 89% rename from third_party/xla/third_party/tsl/tsl/lib/io/block.h rename to third_party/xla/xla/tsl/lib/io/block.h index b31808627157c2..c97a0f9830d48f 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/block.h +++ b/third_party/xla/xla/tsl/lib/io/block.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_BLOCK_H_ -#define TENSORFLOW_TSL_LIB_IO_BLOCK_H_ +#ifndef XLA_TSL_LIB_IO_BLOCK_H_ +#define XLA_TSL_LIB_IO_BLOCK_H_ #include #include -#include "tsl/lib/io/iterator.h" +#include "xla/tsl/lib/io/iterator.h" namespace tsl { namespace table { @@ -54,4 +54,4 @@ class Block { } // namespace table } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_BLOCK_H_ +#endif // XLA_TSL_LIB_IO_BLOCK_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/block_builder.cc b/third_party/xla/xla/tsl/lib/io/block_builder.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/io/block_builder.cc rename to third_party/xla/xla/tsl/lib/io/block_builder.cc index d28d718a24f6d7..e471852a7bfda4 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/block_builder.cc +++ b/third_party/xla/xla/tsl/lib/io/block_builder.cc @@ -37,13 +37,13 @@ limitations under the License. // num_restarts: uint32 // restarts[i] contains the offset within the block of the ith restart point. -#include "tsl/lib/io/block_builder.h" +#include "xla/tsl/lib/io/block_builder.h" #include #include -#include "tsl/lib/io/table_builder.h" +#include "xla/tsl/lib/io/table_builder.h" #include "tsl/platform/coding.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/block_builder.h b/third_party/xla/xla/tsl/lib/io/block_builder.h similarity index 93% rename from third_party/xla/third_party/tsl/tsl/lib/io/block_builder.h rename to third_party/xla/xla/tsl/lib/io/block_builder.h index 578d8bab57e854..0defea6d866e0f 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/block_builder.h +++ b/third_party/xla/xla/tsl/lib/io/block_builder.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_BLOCK_BUILDER_H_ -#define TENSORFLOW_TSL_LIB_IO_BLOCK_BUILDER_H_ +#ifndef XLA_TSL_LIB_IO_BLOCK_BUILDER_H_ +#define XLA_TSL_LIB_IO_BLOCK_BUILDER_H_ #include @@ -67,4 +67,4 @@ class BlockBuilder { } // namespace table } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_BLOCK_BUILDER_H_ +#endif // XLA_TSL_LIB_IO_BLOCK_BUILDER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file.h b/third_party/xla/xla/tsl/lib/io/buffered_file.h similarity index 95% rename from third_party/xla/third_party/tsl/tsl/lib/io/buffered_file.h rename to third_party/xla/xla/tsl/lib/io/buffered_file.h index e6abe32c465fff..6d173c83d12530 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file.h +++ b/third_party/xla/xla/tsl/lib/io/buffered_file.h @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_BUFFERED_FILE_H_ -#define TENSORFLOW_TSL_LIB_IO_BUFFERED_FILE_H_ +#ifndef XLA_TSL_LIB_IO_BUFFERED_FILE_H_ +#define XLA_TSL_LIB_IO_BUFFERED_FILE_H_ #include #include #include #include -#include "tsl/lib/hash/crc32c.h" +#include "xla/tsl/lib/hash/crc32c.h" #include "tsl/platform/cord.h" #include "tsl/platform/file_system.h" #include "tsl/platform/status.h" @@ -113,4 +113,4 @@ class BufferedWritableFile : public WritableFile { }; } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_BUFFERED_FILE_H_ +#endif // XLA_TSL_LIB_IO_BUFFERED_FILE_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file_test.cc b/third_party/xla/xla/tsl/lib/io/buffered_file_test.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/io/buffered_file_test.cc rename to third_party/xla/xla/tsl/lib/io/buffered_file_test.cc index f9fa67dd1572f5..2c3fc0fe5070ca 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file_test.cc +++ b/third_party/xla/xla/tsl/lib/io/buffered_file_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/buffered_file.h" +#include "xla/tsl/lib/io/buffered_file.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.cc b/third_party/xla/xla/tsl/lib/io/buffered_inputstream.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.cc rename to third_party/xla/xla/tsl/lib/io/buffered_inputstream.cc index 89ed20757cf093..244c15882ab502 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.cc +++ b/third_party/xla/xla/tsl/lib/io/buffered_inputstream.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/buffered_inputstream.h" +#include "xla/tsl/lib/io/buffered_inputstream.h" #include "absl/status/status.h" -#include "tsl/lib/io/random_inputstream.h" +#include "xla/tsl/lib/io/random_inputstream.h" namespace tsl { namespace io { diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.h b/third_party/xla/xla/tsl/lib/io/buffered_inputstream.h similarity index 95% rename from third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.h rename to third_party/xla/xla/tsl/lib/io/buffered_inputstream.h index 6681f1bbfbed32..1a187012766ab1 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream.h +++ b/third_party/xla/xla/tsl/lib/io/buffered_inputstream.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_BUFFERED_INPUTSTREAM_H_ -#define TENSORFLOW_TSL_LIB_IO_BUFFERED_INPUTSTREAM_H_ +#ifndef XLA_TSL_LIB_IO_BUFFERED_INPUTSTREAM_H_ +#define XLA_TSL_LIB_IO_BUFFERED_INPUTSTREAM_H_ #include -#include "tsl/lib/io/inputstream_interface.h" +#include "xla/tsl/lib/io/inputstream_interface.h" #include "tsl/platform/file_system.h" namespace tsl { @@ -124,4 +124,4 @@ extern template Status BufferedInputStream::ReadAll(tstring* result); } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_BUFFERED_INPUTSTREAM_H_ +#endif // XLA_TSL_LIB_IO_BUFFERED_INPUTSTREAM_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc b/third_party/xla/xla/tsl/lib/io/buffered_inputstream_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc rename to third_party/xla/xla/tsl/lib/io/buffered_inputstream_test.cc index 83e5776d6602d2..1ad2476eb5af6e 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc +++ b/third_party/xla/xla/tsl/lib/io/buffered_inputstream_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/buffered_inputstream.h" +#include "xla/tsl/lib/io/buffered_inputstream.h" #include "xla/tsl/lib/core/status_test_util.h" -#include "tsl/lib/io/random_inputstream.h" +#include "xla/tsl/lib/io/random_inputstream.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/cache.cc b/third_party/xla/xla/tsl/lib/io/cache.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/io/cache.cc rename to third_party/xla/xla/tsl/lib/io/cache.cc index dee0871e6b8539..6515783f5c99e2 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/cache.cc +++ b/third_party/xla/xla/tsl/lib/io/cache.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/cache.h" +#include "xla/tsl/lib/io/cache.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/cache.h b/third_party/xla/xla/tsl/lib/io/cache.h similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/io/cache.h rename to third_party/xla/xla/tsl/lib/io/cache.h index 831288b56abd75..9cd5502cb2e715 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/cache.h +++ b/third_party/xla/xla/tsl/lib/io/cache.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_CACHE_H_ -#define TENSORFLOW_TSL_LIB_IO_CACHE_H_ +#ifndef XLA_TSL_LIB_IO_CACHE_H_ +#define XLA_TSL_LIB_IO_CACHE_H_ #include @@ -124,4 +124,4 @@ class Cache { } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_CACHE_H_ +#endif // XLA_TSL_LIB_IO_CACHE_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/cache_test.cc b/third_party/xla/xla/tsl/lib/io/cache_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/io/cache_test.cc rename to third_party/xla/xla/tsl/lib/io/cache_test.cc index 62a53601fd3c73..3c54c82a11ac25 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/cache_test.cc +++ b/third_party/xla/xla/tsl/lib/io/cache_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/cache.h" +#include "xla/tsl/lib/io/cache.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/compression.cc b/third_party/xla/xla/tsl/lib/io/compression.cc similarity index 95% rename from third_party/xla/third_party/tsl/tsl/lib/io/compression.cc rename to third_party/xla/xla/tsl/lib/io/compression.cc index 18f821bc805efe..450962fde73d98 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/compression.cc +++ b/third_party/xla/xla/tsl/lib/io/compression.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/compression.h" +#include "xla/tsl/lib/io/compression.h" namespace tsl { namespace io { diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/compression.h b/third_party/xla/xla/tsl/lib/io/compression.h similarity index 86% rename from third_party/xla/third_party/tsl/tsl/lib/io/compression.h rename to third_party/xla/xla/tsl/lib/io/compression.h index bed94981eca9c4..ce3b7fb4ca3e4c 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/compression.h +++ b/third_party/xla/xla/tsl/lib/io/compression.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_COMPRESSION_H_ -#define TENSORFLOW_TSL_LIB_IO_COMPRESSION_H_ +#ifndef XLA_TSL_LIB_IO_COMPRESSION_H_ +#define XLA_TSL_LIB_IO_COMPRESSION_H_ namespace tsl { namespace io { @@ -29,4 +29,4 @@ extern const char kZlib[]; } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_COMPRESSION_H_ +#endif // XLA_TSL_LIB_IO_COMPRESSION_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/format.cc b/third_party/xla/xla/tsl/lib/io/format.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/io/format.cc rename to third_party/xla/xla/tsl/lib/io/format.cc index aa26afd84b6677..e02451c08d7e0e 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/format.cc +++ b/third_party/xla/xla/tsl/lib/io/format.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/format.h" +#include "xla/tsl/lib/io/format.h" #include -#include "tsl/lib/hash/crc32c.h" -#include "tsl/lib/io/block.h" +#include "xla/tsl/lib/hash/crc32c.h" +#include "xla/tsl/lib/io/block.h" #include "tsl/platform/coding.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/format.h b/third_party/xla/xla/tsl/lib/io/format.h similarity index 95% rename from third_party/xla/third_party/tsl/tsl/lib/io/format.h rename to third_party/xla/xla/tsl/lib/io/format.h index 2f704c9ca9d200..3cf5d6312a5f02 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/format.h +++ b/third_party/xla/xla/tsl/lib/io/format.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_FORMAT_H_ -#define TENSORFLOW_TSL_LIB_IO_FORMAT_H_ +#ifndef XLA_TSL_LIB_IO_FORMAT_H_ +#define XLA_TSL_LIB_IO_FORMAT_H_ #include #include -#include "tsl/lib/io/table_builder.h" +#include "xla/tsl/lib/io/table_builder.h" #include "tsl/platform/status.h" #include "tsl/platform/stringpiece.h" @@ -110,4 +110,4 @@ inline BlockHandle::BlockHandle() } // namespace table } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_FORMAT_H_ +#endif // XLA_TSL_LIB_IO_FORMAT_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.cc b/third_party/xla/xla/tsl/lib/io/inputbuffer.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.cc rename to third_party/xla/xla/tsl/lib/io/inputbuffer.cc index f5a46ae7e87c1d..5fdff4943331ed 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.cc +++ b/third_party/xla/xla/tsl/lib/io/inputbuffer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/inputbuffer.h" +#include "xla/tsl/lib/io/inputbuffer.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.h b/third_party/xla/xla/tsl/lib/io/inputbuffer.h similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.h rename to third_party/xla/xla/tsl/lib/io/inputbuffer.h index 57a4a983c11e75..bec656ecd00ef6 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer.h +++ b/third_party/xla/xla/tsl/lib/io/inputbuffer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_INPUTBUFFER_H_ -#define TENSORFLOW_TSL_LIB_IO_INPUTBUFFER_H_ +#ifndef XLA_TSL_LIB_IO_INPUTBUFFER_H_ +#define XLA_TSL_LIB_IO_INPUTBUFFER_H_ #include @@ -149,4 +149,4 @@ inline absl::Status InputBuffer::ReadVarint64(uint64* result) { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_INPUTBUFFER_H_ +#endif // XLA_TSL_LIB_IO_INPUTBUFFER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer_test.cc b/third_party/xla/xla/tsl/lib/io/inputbuffer_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer_test.cc rename to third_party/xla/xla/tsl/lib/io/inputbuffer_test.cc index ae99467be0ea2a..a4d170101ea675 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer_test.cc +++ b/third_party/xla/xla/tsl/lib/io/inputbuffer_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/inputbuffer.h" +#include "xla/tsl/lib/io/inputbuffer.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.cc b/third_party/xla/xla/tsl/lib/io/inputstream_interface.cc similarity index 96% rename from third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.cc rename to third_party/xla/xla/tsl/lib/io/inputstream_interface.cc index 6425ff0656b658..7bf261f6757609 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.cc +++ b/third_party/xla/xla/tsl/lib/io/inputstream_interface.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/inputstream_interface.h" +#include "xla/tsl/lib/io/inputstream_interface.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.h b/third_party/xla/xla/tsl/lib/io/inputstream_interface.h similarity index 93% rename from third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.h rename to third_party/xla/xla/tsl/lib/io/inputstream_interface.h index 8eb7f2ad868965..3ecb5b5af9e8df 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface.h +++ b/third_party/xla/xla/tsl/lib/io/inputstream_interface.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_INPUTSTREAM_INTERFACE_H_ -#define TENSORFLOW_TSL_LIB_IO_INPUTSTREAM_INTERFACE_H_ +#ifndef XLA_TSL_LIB_IO_INPUTSTREAM_INTERFACE_H_ +#define XLA_TSL_LIB_IO_INPUTSTREAM_INTERFACE_H_ #include @@ -67,4 +67,4 @@ class InputStreamInterface { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_INPUTSTREAM_INTERFACE_H_ +#endif // XLA_TSL_LIB_IO_INPUTSTREAM_INTERFACE_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc b/third_party/xla/xla/tsl/lib/io/inputstream_interface_test.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc rename to third_party/xla/xla/tsl/lib/io/inputstream_interface_test.cc index c9c34dba55364e..9021440b6e1d84 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc +++ b/third_party/xla/xla/tsl/lib/io/inputstream_interface_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/inputstream_interface.h" +#include "xla/tsl/lib/io/inputstream_interface.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/iterator.cc b/third_party/xla/xla/tsl/lib/io/iterator.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/io/iterator.cc rename to third_party/xla/xla/tsl/lib/io/iterator.cc index b7e69f7081aa92..2db370d7478d21 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/iterator.cc +++ b/third_party/xla/xla/tsl/lib/io/iterator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/iterator.h" +#include "xla/tsl/lib/io/iterator.h" namespace tsl { namespace table { diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/iterator.h b/third_party/xla/xla/tsl/lib/io/iterator.h similarity index 96% rename from third_party/xla/third_party/tsl/tsl/lib/io/iterator.h rename to third_party/xla/xla/tsl/lib/io/iterator.h index 7fe51bfd785bc1..ba0b1dbc4b76de 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/iterator.h +++ b/third_party/xla/xla/tsl/lib/io/iterator.h @@ -23,8 +23,8 @@ limitations under the License. // non-const method, all threads accessing the same Iterator must use // external synchronization. -#ifndef TENSORFLOW_TSL_LIB_IO_ITERATOR_H_ -#define TENSORFLOW_TSL_LIB_IO_ITERATOR_H_ +#ifndef XLA_TSL_LIB_IO_ITERATOR_H_ +#define XLA_TSL_LIB_IO_ITERATOR_H_ #include "tsl/platform/status.h" #include "tsl/platform/stringpiece.h" @@ -101,4 +101,4 @@ extern Iterator* NewErrorIterator(const absl::Status& status); } // namespace table } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_ITERATOR_H_ +#endif // XLA_TSL_LIB_IO_ITERATOR_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/proto_encode_helper.h b/third_party/xla/xla/tsl/lib/io/proto_encode_helper.h similarity index 94% rename from third_party/xla/third_party/tsl/tsl/lib/io/proto_encode_helper.h rename to third_party/xla/xla/tsl/lib/io/proto_encode_helper.h index c5bf8262f9df91..33c2411cbc3ca3 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/proto_encode_helper.h +++ b/third_party/xla/xla/tsl/lib/io/proto_encode_helper.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_PROTO_ENCODE_HELPER_H_ -#define TENSORFLOW_TSL_LIB_IO_PROTO_ENCODE_HELPER_H_ +#ifndef XLA_TSL_LIB_IO_PROTO_ENCODE_HELPER_H_ +#define XLA_TSL_LIB_IO_PROTO_ENCODE_HELPER_H_ #include "tsl/platform/coding.h" #include "tsl/platform/logging.h" @@ -96,4 +96,4 @@ class ProtoEncodeHelper { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_PROTO_ENCODE_HELPER_H_ +#endif // XLA_TSL_LIB_IO_PROTO_ENCODE_HELPER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.cc b/third_party/xla/xla/tsl/lib/io/random_inputstream.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.cc rename to third_party/xla/xla/tsl/lib/io/random_inputstream.cc index 6802707c3387fe..26e138c0e231c2 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.cc +++ b/third_party/xla/xla/tsl/lib/io/random_inputstream.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/random_inputstream.h" +#include "xla/tsl/lib/io/random_inputstream.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.h b/third_party/xla/xla/tsl/lib/io/random_inputstream.h similarity index 89% rename from third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.h rename to third_party/xla/xla/tsl/lib/io/random_inputstream.h index 4d48db62c2b03f..99685ab055ac6a 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream.h +++ b/third_party/xla/xla/tsl/lib/io/random_inputstream.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_RANDOM_INPUTSTREAM_H_ -#define TENSORFLOW_TSL_LIB_IO_RANDOM_INPUTSTREAM_H_ +#ifndef XLA_TSL_LIB_IO_RANDOM_INPUTSTREAM_H_ +#define XLA_TSL_LIB_IO_RANDOM_INPUTSTREAM_H_ -#include "tsl/lib/io/inputstream_interface.h" +#include "xla/tsl/lib/io/inputstream_interface.h" #include "tsl/platform/cord.h" #include "tsl/platform/file_system.h" @@ -59,4 +59,4 @@ class RandomAccessInputStream : public InputStreamInterface { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_RANDOM_INPUTSTREAM_H_ +#endif // XLA_TSL_LIB_IO_RANDOM_INPUTSTREAM_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream_test.cc b/third_party/xla/xla/tsl/lib/io/random_inputstream_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream_test.cc rename to third_party/xla/xla/tsl/lib/io/random_inputstream_test.cc index dfa4ec80e20a17..e2fc82374e47bb 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream_test.cc +++ b/third_party/xla/xla/tsl/lib/io/random_inputstream_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/random_inputstream.h" +#include "xla/tsl/lib/io/random_inputstream.h" #include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.cc b/third_party/xla/xla/tsl/lib/io/record_reader.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/io/record_reader.cc rename to third_party/xla/xla/tsl/lib/io/record_reader.cc index 8d17c610b09f71..8332debff876c2 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.cc +++ b/third_party/xla/xla/tsl/lib/io/record_reader.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/record_reader.h" +#include "xla/tsl/lib/io/record_reader.h" #include -#include "tsl/lib/hash/crc32c.h" -#include "tsl/lib/io/buffered_inputstream.h" -#include "tsl/lib/io/compression.h" -#include "tsl/lib/io/random_inputstream.h" +#include "xla/tsl/lib/hash/crc32c.h" +#include "xla/tsl/lib/io/buffered_inputstream.h" +#include "xla/tsl/lib/io/compression.h" +#include "xla/tsl/lib/io/random_inputstream.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/raw_coding.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.h b/third_party/xla/xla/tsl/lib/io/record_reader.h similarity index 93% rename from third_party/xla/third_party/tsl/tsl/lib/io/record_reader.h rename to third_party/xla/xla/tsl/lib/io/record_reader.h index 61540a657324c8..3c18992ec86279 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader.h +++ b/third_party/xla/xla/tsl/lib/io/record_reader.h @@ -13,17 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_RECORD_READER_H_ -#define TENSORFLOW_TSL_LIB_IO_RECORD_READER_H_ +#ifndef XLA_TSL_LIB_IO_RECORD_READER_H_ +#define XLA_TSL_LIB_IO_RECORD_READER_H_ -#include "tsl/lib/io/inputstream_interface.h" +#include "xla/tsl/lib/io/inputstream_interface.h" #include "tsl/platform/errors.h" #include "tsl/platform/stringpiece.h" #if !defined(IS_SLIM_BUILD) -#include "tsl/lib/io/snappy/snappy_compression_options.h" -#include "tsl/lib/io/snappy/snappy_inputstream.h" -#include "tsl/lib/io/zlib_compression_options.h" -#include "tsl/lib/io/zlib_inputstream.h" +#include "xla/tsl/lib/io/snappy/snappy_compression_options.h" +#include "xla/tsl/lib/io/snappy/snappy_inputstream.h" +#include "xla/tsl/lib/io/zlib_compression_options.h" +#include "xla/tsl/lib/io/zlib_inputstream.h" #endif // IS_SLIM_BUILD #include "tsl/platform/macros.h" #include "tsl/platform/types.h" @@ -177,4 +177,4 @@ class SequentialRecordReader { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_RECORD_READER_H_ +#endif // XLA_TSL_LIB_IO_RECORD_READER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc b/third_party/xla/xla/tsl/lib/io/record_reader_writer_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc rename to third_party/xla/xla/tsl/lib/io/record_reader_writer_test.cc index 45934c9f355576..e91f1ecaed1b99 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc +++ b/third_party/xla/xla/tsl/lib/io/record_reader_writer_test.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ // clang-format off -#include "tsl/lib/io/record_reader.h" -#include "tsl/lib/io/record_writer.h" +#include "xla/tsl/lib/io/record_reader.h" +#include "xla/tsl/lib/io/record_writer.h" // clang-format on #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.cc b/third_party/xla/xla/tsl/lib/io/record_writer.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/io/record_writer.cc rename to third_party/xla/xla/tsl/lib/io/record_writer.cc index b6e829206e5f2e..2e47e9d0686eb0 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.cc +++ b/third_party/xla/xla/tsl/lib/io/record_writer.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/record_writer.h" +#include "xla/tsl/lib/io/record_writer.h" -#include "tsl/lib/hash/crc32c.h" -#include "tsl/lib/io/compression.h" +#include "xla/tsl/lib/hash/crc32c.h" +#include "xla/tsl/lib/io/compression.h" #include "tsl/platform/coding.h" #include "tsl/platform/env.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.h b/third_party/xla/xla/tsl/lib/io/record_writer.h similarity index 92% rename from third_party/xla/third_party/tsl/tsl/lib/io/record_writer.h rename to third_party/xla/xla/tsl/lib/io/record_writer.h index 94c7ca576403df..5cb160790b9f1c 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_writer.h +++ b/third_party/xla/xla/tsl/lib/io/record_writer.h @@ -13,18 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_RECORD_WRITER_H_ -#define TENSORFLOW_TSL_LIB_IO_RECORD_WRITER_H_ +#ifndef XLA_TSL_LIB_IO_RECORD_WRITER_H_ +#define XLA_TSL_LIB_IO_RECORD_WRITER_H_ -#include "tsl/lib/hash/crc32c.h" +#include "xla/tsl/lib/hash/crc32c.h" #include "tsl/platform/coding.h" #include "tsl/platform/status.h" #include "tsl/platform/stringpiece.h" #if !defined(IS_SLIM_BUILD) -#include "tsl/lib/io/snappy/snappy_compression_options.h" -#include "tsl/lib/io/snappy/snappy_outputbuffer.h" -#include "tsl/lib/io/zlib_compression_options.h" -#include "tsl/lib/io/zlib_outputbuffer.h" +#include "xla/tsl/lib/io/snappy/snappy_compression_options.h" +#include "xla/tsl/lib/io/snappy/snappy_outputbuffer.h" +#include "xla/tsl/lib/io/zlib_compression_options.h" +#include "xla/tsl/lib/io/zlib_outputbuffer.h" #endif // IS_SLIM_BUILD #include "tsl/platform/cord.h" #include "tsl/platform/macros.h" @@ -153,4 +153,4 @@ void RecordWriter::PopulateFooter(char* footer, const absl::Cord& data) { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_RECORD_WRITER_H_ +#endif // XLA_TSL_LIB_IO_RECORD_WRITER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc b/third_party/xla/xla/tsl/lib/io/recordio_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc rename to third_party/xla/xla/tsl/lib/io/recordio_test.cc index c07d26a37e698f..02d22ec4931218 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc +++ b/third_party/xla/xla/tsl/lib/io/recordio_test.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include "xla/tsl/lib/core/status_test_util.h" -#include "tsl/lib/hash/crc32c.h" -#include "tsl/lib/io/record_reader.h" -#include "tsl/lib/io/record_writer.h" -#include "tsl/lib/random/simple_philox.h" +#include "xla/tsl/lib/hash/crc32c.h" +#include "xla/tsl/lib/io/record_reader.h" +#include "xla/tsl/lib/io/record_writer.h" +#include "xla/tsl/lib/random/simple_philox.h" #include "tsl/platform/coding.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/BUILD b/third_party/xla/xla/tsl/lib/io/snappy/BUILD similarity index 56% rename from third_party/xla/third_party/tsl/tsl/lib/io/snappy/BUILD rename to third_party/xla/xla/tsl/lib/io/snappy/BUILD index 0adc5e2fa467aa..6246244fa740e7 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/BUILD +++ b/third_party/xla/xla/tsl/lib/io/snappy/BUILD @@ -1,8 +1,8 @@ -load("@local_xla//xla/tsl:tsl.bzl", "internal_visibility") load( - "//tsl/platform:build_config.bzl", + "@local_tsl//tsl/platform:build_config.bzl", "tsl_cc_test", ) +load("//xla/tsl:tsl.bzl", "internal_visibility") # Snappy targets. @@ -15,7 +15,7 @@ package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = internal_visibility([ "//tensorflow/core/lib/io:__pkg__", - "//tsl/lib/io:__pkg__", + "//xla/tsl/lib/io:__pkg__", ]), licenses = ["notice"], ) @@ -34,13 +34,13 @@ cc_library( srcs = ["snappy_inputbuffer.cc"], hdrs = ["snappy_inputbuffer.h"], deps = [ - "//tsl/lib/io:inputstream_interface", - "//tsl/platform:env", - "//tsl/platform:macros", - "//tsl/platform:platform_port", - "//tsl/platform:status", - "//tsl/platform:types", + "//xla/tsl/lib/io:inputstream_interface", "@com_google_absl//absl/status", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -50,12 +50,12 @@ cc_library( srcs = ["snappy_outputbuffer.cc"], hdrs = ["snappy_outputbuffer.h"], deps = [ - "//tsl/platform", - "//tsl/platform:env", - "//tsl/platform:macros", - "//tsl/platform:platform_port", - "//tsl/platform:status", - "//tsl/platform:types", + "@local_tsl//tsl/platform", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -65,10 +65,10 @@ cc_library( srcs = ["snappy_inputstream.cc"], hdrs = ["snappy_inputstream.h"], deps = [ - "//tsl/lib/io:inputstream_interface", - "//tsl/platform:errors", - "//tsl/platform:platform_port", + "//xla/tsl/lib/io:inputstream_interface", "@com_google_absl//absl/memory", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:platform_port", ], alwayslink = True, ) @@ -77,7 +77,7 @@ cc_library( name = "snappy_compression_options", hdrs = ["snappy_compression_options.h"], deps = [ - "//tsl/platform:types", + "@local_tsl//tsl/platform:types", ], alwayslink = True, ) @@ -90,12 +90,12 @@ tsl_cc_test( ":snappy_inputbuffer", ":snappy_inputstream", ":snappy_outputbuffer", - "//tsl/lib/io:inputbuffer", - "//tsl/lib/io:random_inputstream", - "//tsl/platform:env", - "//tsl/platform:env_impl", - "//tsl/platform:test", - "//tsl/platform:test_main", - "@local_xla//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/io:inputbuffer", + "//xla/tsl/lib/io:random_inputstream", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_compression_options.h b/third_party/xla/xla/tsl/lib/io/snappy/snappy_compression_options.h similarity index 84% rename from third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_compression_options.h rename to third_party/xla/xla/tsl/lib/io/snappy/snappy_compression_options.h index 4d1ba01e3c15d7..3772a415056cf9 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_compression_options.h +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_compression_options.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_SNAPPY_SNAPPY_COMPRESSION_OPTIONS_H_ -#define TENSORFLOW_TSL_LIB_IO_SNAPPY_SNAPPY_COMPRESSION_OPTIONS_H_ +#ifndef XLA_TSL_LIB_IO_SNAPPY_SNAPPY_COMPRESSION_OPTIONS_H_ +#define XLA_TSL_LIB_IO_SNAPPY_SNAPPY_COMPRESSION_OPTIONS_H_ #include "tsl/platform/types.h" @@ -33,4 +33,4 @@ struct SnappyCompressionOptions { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_SNAPPY_SNAPPY_COMPRESSION_OPTIONS_H_ +#endif // XLA_TSL_LIB_IO_SNAPPY_SNAPPY_COMPRESSION_OPTIONS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputbuffer.cc b/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputbuffer.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputbuffer.cc rename to third_party/xla/xla/tsl/lib/io/snappy/snappy_inputbuffer.cc index 7844b8993fd98d..09c5e482ef51fa 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputbuffer.cc +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputbuffer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/snappy/snappy_inputbuffer.h" +#include "xla/tsl/lib/io/snappy/snappy_inputbuffer.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputbuffer.h b/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputbuffer.h similarity index 95% rename from third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputbuffer.h rename to third_party/xla/xla/tsl/lib/io/snappy/snappy_inputbuffer.h index 4d7fd3fe2e010d..969c1e00c2bfe3 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputbuffer.h +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputbuffer.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_SNAPPY_SNAPPY_INPUTBUFFER_H_ -#define TENSORFLOW_TSL_LIB_IO_SNAPPY_SNAPPY_INPUTBUFFER_H_ +#ifndef XLA_TSL_LIB_IO_SNAPPY_SNAPPY_INPUTBUFFER_H_ +#define XLA_TSL_LIB_IO_SNAPPY_SNAPPY_INPUTBUFFER_H_ #include #include -#include "tsl/lib/io/inputstream_interface.h" +#include "xla/tsl/lib/io/inputstream_interface.h" #include "tsl/platform/env.h" #include "tsl/platform/macros.h" #include "tsl/platform/snappy.h" @@ -131,4 +131,4 @@ class SnappyInputBuffer : public InputStreamInterface { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_SNAPPY_SNAPPY_INPUTBUFFER_H_ +#endif // XLA_TSL_LIB_IO_SNAPPY_SNAPPY_INPUTBUFFER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputstream.cc b/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputstream.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputstream.cc rename to third_party/xla/xla/tsl/lib/io/snappy/snappy_inputstream.cc index 264f524fcef48f..bcbe96e21139e7 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputstream.cc +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputstream.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/snappy/snappy_inputstream.h" +#include "xla/tsl/lib/io/snappy/snappy_inputstream.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputstream.h b/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputstream.h similarity index 92% rename from third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputstream.h rename to third_party/xla/xla/tsl/lib/io/snappy/snappy_inputstream.h index 6240aa53feb7fa..44535fe65d8763 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputstream.h +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_inputstream.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_SNAPPY_SNAPPY_INPUTSTREAM_H_ -#define TENSORFLOW_TSL_LIB_IO_SNAPPY_SNAPPY_INPUTSTREAM_H_ +#ifndef XLA_TSL_LIB_IO_SNAPPY_SNAPPY_INPUTSTREAM_H_ +#define XLA_TSL_LIB_IO_SNAPPY_SNAPPY_INPUTSTREAM_H_ #include -#include "tsl/lib/io/inputstream_interface.h" +#include "xla/tsl/lib/io/inputstream_interface.h" namespace tsl { namespace io { @@ -89,4 +89,4 @@ class SnappyInputStream : public InputStreamInterface { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_SNAPPY_SNAPPY_INPUTSTREAM_H_ +#endif // XLA_TSL_LIB_IO_SNAPPY_SNAPPY_INPUTSTREAM_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.cc b/third_party/xla/xla/tsl/lib/io/snappy/snappy_outputbuffer.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.cc rename to third_party/xla/xla/tsl/lib/io/snappy/snappy_outputbuffer.cc index e851f58f1b9cda..7241d24c46b155 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.cc +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_outputbuffer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/snappy/snappy_outputbuffer.h" +#include "xla/tsl/lib/io/snappy/snappy_outputbuffer.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.h b/third_party/xla/xla/tsl/lib/io/snappy/snappy_outputbuffer.h similarity index 96% rename from third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.h rename to third_party/xla/xla/tsl/lib/io/snappy/snappy_outputbuffer.h index a3bd44748c152f..631014c3b6e189 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.h +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_outputbuffer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_SNAPPY_SNAPPY_OUTPUTBUFFER_H_ -#define TENSORFLOW_TSL_LIB_IO_SNAPPY_SNAPPY_OUTPUTBUFFER_H_ +#ifndef XLA_TSL_LIB_IO_SNAPPY_SNAPPY_OUTPUTBUFFER_H_ +#define XLA_TSL_LIB_IO_SNAPPY_SNAPPY_OUTPUTBUFFER_H_ #include #include @@ -155,4 +155,4 @@ class SnappyOutputBuffer : public WritableFile { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_SNAPPY_SNAPPY_OUTPUTBUFFER_H_ +#endif // XLA_TSL_LIB_IO_SNAPPY_SNAPPY_OUTPUTBUFFER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_test.cc b/third_party/xla/xla/tsl/lib/io/snappy/snappy_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_test.cc rename to third_party/xla/xla/tsl/lib/io/snappy/snappy_test.cc index 78eecf360d9489..f3504e9268a76e 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_test.cc +++ b/third_party/xla/xla/tsl/lib/io/snappy/snappy_test.cc @@ -16,11 +16,11 @@ limitations under the License. #include #include "xla/tsl/lib/core/status_test_util.h" -#include "tsl/lib/io/inputbuffer.h" -#include "tsl/lib/io/random_inputstream.h" -#include "tsl/lib/io/snappy/snappy_inputbuffer.h" -#include "tsl/lib/io/snappy/snappy_inputstream.h" -#include "tsl/lib/io/snappy/snappy_outputbuffer.h" +#include "xla/tsl/lib/io/inputbuffer.h" +#include "xla/tsl/lib/io/random_inputstream.h" +#include "xla/tsl/lib/io/snappy/snappy_inputbuffer.h" +#include "xla/tsl/lib/io/snappy/snappy_inputstream.h" +#include "xla/tsl/lib/io/snappy/snappy_outputbuffer.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/table.cc b/third_party/xla/xla/tsl/lib/io/table.cc similarity index 96% rename from third_party/xla/third_party/tsl/tsl/lib/io/table.cc rename to third_party/xla/xla/tsl/lib/io/table.cc index 05f8cd1d71e1c5..5c36b4649859b8 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/table.cc +++ b/third_party/xla/xla/tsl/lib/io/table.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/table.h" +#include "xla/tsl/lib/io/table.h" -#include "tsl/lib/io/block.h" -#include "tsl/lib/io/cache.h" -#include "tsl/lib/io/format.h" -#include "tsl/lib/io/table_options.h" -#include "tsl/lib/io/two_level_iterator.h" +#include "xla/tsl/lib/io/block.h" +#include "xla/tsl/lib/io/cache.h" +#include "xla/tsl/lib/io/format.h" +#include "xla/tsl/lib/io/table_options.h" +#include "xla/tsl/lib/io/two_level_iterator.h" #include "tsl/platform/coding.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/table.h b/third_party/xla/xla/tsl/lib/io/table.h similarity index 95% rename from third_party/xla/third_party/tsl/tsl/lib/io/table.h rename to third_party/xla/xla/tsl/lib/io/table.h index 4a6855c661f6b8..3afdb0c461ea10 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/table.h +++ b/third_party/xla/xla/tsl/lib/io/table.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_TABLE_H_ -#define TENSORFLOW_TSL_LIB_IO_TABLE_H_ +#ifndef XLA_TSL_LIB_IO_TABLE_H_ +#define XLA_TSL_LIB_IO_TABLE_H_ #include -#include "tsl/lib/io/iterator.h" +#include "xla/tsl/lib/io/iterator.h" namespace tsl { @@ -84,4 +84,4 @@ class Table { } // namespace table } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_TABLE_H_ +#endif // XLA_TSL_LIB_IO_TABLE_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/table_builder.cc b/third_party/xla/xla/tsl/lib/io/table_builder.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/io/table_builder.cc rename to third_party/xla/xla/tsl/lib/io/table_builder.cc index c07227b934a2b7..b5fcb0c9ed47dc 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/table_builder.cc +++ b/third_party/xla/xla/tsl/lib/io/table_builder.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/table_builder.h" +#include "xla/tsl/lib/io/table_builder.h" #include -#include "tsl/lib/hash/crc32c.h" -#include "tsl/lib/io/block_builder.h" -#include "tsl/lib/io/format.h" -#include "tsl/lib/io/table_options.h" +#include "xla/tsl/lib/hash/crc32c.h" +#include "xla/tsl/lib/io/block_builder.h" +#include "xla/tsl/lib/io/format.h" +#include "xla/tsl/lib/io/table_options.h" #include "tsl/platform/coding.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/table_builder.h b/third_party/xla/xla/tsl/lib/io/table_builder.h similarity index 94% rename from third_party/xla/third_party/tsl/tsl/lib/io/table_builder.h rename to third_party/xla/xla/tsl/lib/io/table_builder.h index d4e88e989d47bf..059f9ab60546c1 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/table_builder.h +++ b/third_party/xla/xla/tsl/lib/io/table_builder.h @@ -21,12 +21,12 @@ limitations under the License. // non-const method, all threads accessing the same TableBuilder must use // external synchronization. -#ifndef TENSORFLOW_TSL_LIB_IO_TABLE_BUILDER_H_ -#define TENSORFLOW_TSL_LIB_IO_TABLE_BUILDER_H_ +#ifndef XLA_TSL_LIB_IO_TABLE_BUILDER_H_ +#define XLA_TSL_LIB_IO_TABLE_BUILDER_H_ #include -#include "tsl/lib/io/table_options.h" +#include "xla/tsl/lib/io/table_options.h" #include "tsl/platform/status.h" #include "tsl/platform/stringpiece.h" @@ -98,4 +98,4 @@ class TableBuilder { } // namespace table } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_TABLE_BUILDER_H_ +#endif // XLA_TSL_LIB_IO_TABLE_BUILDER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/table_format.txt b/third_party/xla/xla/tsl/lib/io/table_format.txt similarity index 100% rename from third_party/xla/third_party/tsl/tsl/lib/io/table_format.txt rename to third_party/xla/xla/tsl/lib/io/table_format.txt diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/table_options.h b/third_party/xla/xla/tsl/lib/io/table_options.h similarity index 95% rename from third_party/xla/third_party/tsl/tsl/lib/io/table_options.h rename to third_party/xla/xla/tsl/lib/io/table_options.h index c3ca3e1b2fe96d..7784d225c371fe 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/table_options.h +++ b/third_party/xla/xla/tsl/lib/io/table_options.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_TABLE_OPTIONS_H_ -#define TENSORFLOW_TSL_LIB_IO_TABLE_OPTIONS_H_ +#ifndef XLA_TSL_LIB_IO_TABLE_OPTIONS_H_ +#define XLA_TSL_LIB_IO_TABLE_OPTIONS_H_ #include @@ -73,4 +73,4 @@ struct Options { } // namespace table } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_TABLE_OPTIONS_H_ +#endif // XLA_TSL_LIB_IO_TABLE_OPTIONS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/table_test.cc b/third_party/xla/xla/tsl/lib/io/table_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/io/table_test.cc rename to third_party/xla/xla/tsl/lib/io/table_test.cc index 567881d1dce92d..6671bc816abc17 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/table_test.cc +++ b/third_party/xla/xla/tsl/lib/io/table_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/table.h" +#include "xla/tsl/lib/io/table.h" #include #include @@ -21,12 +21,12 @@ limitations under the License. #include #include "absl/strings/escaping.h" -#include "tsl/lib/io/block.h" -#include "tsl/lib/io/block_builder.h" -#include "tsl/lib/io/format.h" -#include "tsl/lib/io/iterator.h" -#include "tsl/lib/io/table_builder.h" -#include "tsl/lib/random/simple_philox.h" +#include "xla/tsl/lib/io/block.h" +#include "xla/tsl/lib/io/block_builder.h" +#include "xla/tsl/lib/io/format.h" +#include "xla/tsl/lib/io/iterator.h" +#include "xla/tsl/lib/io/table_builder.h" +#include "xla/tsl/lib/random/simple_philox.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/snappy.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/two_level_iterator.cc b/third_party/xla/xla/tsl/lib/io/two_level_iterator.cc similarity index 95% rename from third_party/xla/third_party/tsl/tsl/lib/io/two_level_iterator.cc rename to third_party/xla/xla/tsl/lib/io/two_level_iterator.cc index f3ee26b3cc71c8..853ea9c037cf49 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/two_level_iterator.cc +++ b/third_party/xla/xla/tsl/lib/io/two_level_iterator.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/two_level_iterator.h" +#include "xla/tsl/lib/io/two_level_iterator.h" -#include "tsl/lib/io/block.h" -#include "tsl/lib/io/format.h" -#include "tsl/lib/io/iterator.h" -#include "tsl/lib/io/table.h" +#include "xla/tsl/lib/io/block.h" +#include "xla/tsl/lib/io/format.h" +#include "xla/tsl/lib/io/iterator.h" +#include "xla/tsl/lib/io/table.h" namespace tsl { namespace table { diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/two_level_iterator.h b/third_party/xla/xla/tsl/lib/io/two_level_iterator.h similarity index 88% rename from third_party/xla/third_party/tsl/tsl/lib/io/two_level_iterator.h rename to third_party/xla/xla/tsl/lib/io/two_level_iterator.h index 1ae98da5af9695..87f2e1545aa344 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/two_level_iterator.h +++ b/third_party/xla/xla/tsl/lib/io/two_level_iterator.h @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_TWO_LEVEL_ITERATOR_H_ -#define TENSORFLOW_TSL_LIB_IO_TWO_LEVEL_ITERATOR_H_ +#ifndef XLA_TSL_LIB_IO_TWO_LEVEL_ITERATOR_H_ +#define XLA_TSL_LIB_IO_TWO_LEVEL_ITERATOR_H_ -#include "tsl/lib/io/iterator.h" +#include "xla/tsl/lib/io/iterator.h" namespace tsl { namespace table { @@ -39,4 +39,4 @@ extern Iterator* NewTwoLevelIterator( } // namespace table } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_TWO_LEVEL_ITERATOR_H_ +#endif // XLA_TSL_LIB_IO_TWO_LEVEL_ITERATOR_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_buffers_test.cc b/third_party/xla/xla/tsl/lib/io/zlib_buffers_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/io/zlib_buffers_test.cc rename to third_party/xla/xla/tsl/lib/io/zlib_buffers_test.cc index 75554fa9bca17e..c66d9229e480c9 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_buffers_test.cc +++ b/third_party/xla/xla/tsl/lib/io/zlib_buffers_test.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include "xla/tsl/lib/core/status_test_util.h" -#include "tsl/lib/io/random_inputstream.h" -#include "tsl/lib/io/zlib_compression_options.h" -#include "tsl/lib/io/zlib_inputstream.h" -#include "tsl/lib/io/zlib_outputbuffer.h" +#include "xla/tsl/lib/io/random_inputstream.h" +#include "xla/tsl/lib/io/zlib_compression_options.h" +#include "xla/tsl/lib/io/zlib_inputstream.h" +#include "xla/tsl/lib/io/zlib_outputbuffer.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/strcat.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_compression_options.cc b/third_party/xla/xla/tsl/lib/io/zlib_compression_options.cc similarity index 94% rename from third_party/xla/third_party/tsl/tsl/lib/io/zlib_compression_options.cc rename to third_party/xla/xla/tsl/lib/io/zlib_compression_options.cc index 4f30c5252c9d9a..724eec1478ccbd 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_compression_options.cc +++ b/third_party/xla/xla/tsl/lib/io/zlib_compression_options.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/zlib_compression_options.h" +#include "xla/tsl/lib/io/zlib_compression_options.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_compression_options.h b/third_party/xla/xla/tsl/lib/io/zlib_compression_options.h similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/io/zlib_compression_options.h rename to third_party/xla/xla/tsl/lib/io/zlib_compression_options.h index 612f32c507148d..0cae3a2ef54128 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_compression_options.h +++ b/third_party/xla/xla/tsl/lib/io/zlib_compression_options.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_ -#define TENSORFLOW_TSL_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_ +#ifndef XLA_TSL_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_ +#define XLA_TSL_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_ #include "tsl/platform/types.h" @@ -138,4 +138,4 @@ inline ZlibCompressionOptions ZlibCompressionOptions::GZIP() { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_ +#endif // XLA_TSL_LIB_IO_ZLIB_COMPRESSION_OPTIONS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_inputstream.cc b/third_party/xla/xla/tsl/lib/io/zlib_inputstream.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/io/zlib_inputstream.cc rename to third_party/xla/xla/tsl/lib/io/zlib_inputstream.cc index 3407805e62ddff..fda83637279579 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_inputstream.cc +++ b/third_party/xla/xla/tsl/lib/io/zlib_inputstream.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/zlib_inputstream.h" +#include "xla/tsl/lib/io/zlib_inputstream.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_inputstream.h b/third_party/xla/xla/tsl/lib/io/zlib_inputstream.h similarity index 95% rename from third_party/xla/third_party/tsl/tsl/lib/io/zlib_inputstream.h rename to third_party/xla/xla/tsl/lib/io/zlib_inputstream.h index 7a61fda7c9de71..16df9508636019 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_inputstream.h +++ b/third_party/xla/xla/tsl/lib/io/zlib_inputstream.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_ZLIB_INPUTSTREAM_H_ -#define TENSORFLOW_TSL_LIB_IO_ZLIB_INPUTSTREAM_H_ +#ifndef XLA_TSL_LIB_IO_ZLIB_INPUTSTREAM_H_ +#define XLA_TSL_LIB_IO_ZLIB_INPUTSTREAM_H_ #include -#include "tsl/lib/io/inputstream_interface.h" -#include "tsl/lib/io/zlib_compression_options.h" +#include "xla/tsl/lib/io/inputstream_interface.h" +#include "xla/tsl/lib/io/zlib_compression_options.h" #include "tsl/platform/env.h" #include "tsl/platform/macros.h" #include "tsl/platform/status.h" @@ -139,4 +139,4 @@ class ZlibInputStream : public InputStreamInterface { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_ZLIB_INPUTSTREAM_H_ +#endif // XLA_TSL_LIB_IO_ZLIB_INPUTSTREAM_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_outputbuffer.cc b/third_party/xla/xla/tsl/lib/io/zlib_outputbuffer.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/io/zlib_outputbuffer.cc rename to third_party/xla/xla/tsl/lib/io/zlib_outputbuffer.cc index afcf5a46752074..646e4397898841 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_outputbuffer.cc +++ b/third_party/xla/xla/tsl/lib/io/zlib_outputbuffer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/io/zlib_outputbuffer.h" +#include "xla/tsl/lib/io/zlib_outputbuffer.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_outputbuffer.h b/third_party/xla/xla/tsl/lib/io/zlib_outputbuffer.h similarity index 96% rename from third_party/xla/third_party/tsl/tsl/lib/io/zlib_outputbuffer.h rename to third_party/xla/xla/tsl/lib/io/zlib_outputbuffer.h index a255524ff78c04..96b1d1bb9704da 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_outputbuffer.h +++ b/third_party/xla/xla/tsl/lib/io/zlib_outputbuffer.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_IO_ZLIB_OUTPUTBUFFER_H_ -#define TENSORFLOW_TSL_LIB_IO_ZLIB_OUTPUTBUFFER_H_ +#ifndef XLA_TSL_LIB_IO_ZLIB_OUTPUTBUFFER_H_ +#define XLA_TSL_LIB_IO_ZLIB_OUTPUTBUFFER_H_ #include #include -#include "tsl/lib/io/zlib_compression_options.h" +#include "xla/tsl/lib/io/zlib_compression_options.h" #include "tsl/platform/env.h" #include "tsl/platform/file_system.h" #include "tsl/platform/macros.h" @@ -156,4 +156,4 @@ class ZlibOutputBuffer : public WritableFile { } // namespace io } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_IO_ZLIB_OUTPUTBUFFER_H_ +#endif // XLA_TSL_LIB_IO_ZLIB_OUTPUTBUFFER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/math/BUILD b/third_party/xla/xla/tsl/lib/math/BUILD similarity index 74% rename from third_party/xla/third_party/tsl/tsl/lib/math/BUILD rename to third_party/xla/xla/tsl/lib/math/BUILD index a78947f3c38ffa..8bb5fd993079fb 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/math/BUILD +++ b/third_party/xla/xla/tsl/lib/math/BUILD @@ -1,9 +1,9 @@ -load("@local_xla//xla/tsl:tsl.bzl", "internal_visibility") -load("@local_xla//xla/tsl:tsl.default.bzl", "get_compatible_with_portable") load( - "//tsl/platform:build_config.bzl", + "@local_tsl//tsl/platform:build_config.bzl", "tsl_cc_test", ) +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -33,11 +33,11 @@ tsl_cc_test( ], deps = [ ":math_util", - "//tsl/platform:logging", - "//tsl/platform:test", - "//tsl/platform:test_benchmark", - "//tsl/platform:test_main", - "//tsl/platform:types", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_benchmark", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/platform:types", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/lib/math/math_util.h b/third_party/xla/xla/tsl/lib/math/math_util.h similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/math/math_util.h rename to third_party/xla/xla/tsl/lib/math/math_util.h index 26dc0093982740..a2622d48976726 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/math/math_util.h +++ b/third_party/xla/xla/tsl/lib/math/math_util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_MATH_MATH_UTIL_H_ -#define TENSORFLOW_TSL_LIB_MATH_MATH_UTIL_H_ +#ifndef XLA_TSL_LIB_MATH_MATH_UTIL_H_ +#define XLA_TSL_LIB_MATH_MATH_UTIL_H_ #include @@ -158,4 +158,4 @@ constexpr T MathUtil::IPow(T base, int exp) { } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_MATH_MATH_UTIL_H_ +#endif // XLA_TSL_LIB_MATH_MATH_UTIL_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/math/math_util_test.cc b/third_party/xla/xla/tsl/lib/math/math_util_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/math/math_util_test.cc rename to third_party/xla/xla/tsl/lib/math/math_util_test.cc index ccceabf5cc6da7..c60f9796695ceb 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/math/math_util_test.cc +++ b/third_party/xla/xla/tsl/lib/math/math_util_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/math/math_util.h" +#include "xla/tsl/lib/math/math_util.h" #include #include diff --git a/third_party/xla/xla/tsl/lib/monitoring/BUILD b/third_party/xla/xla/tsl/lib/monitoring/BUILD index 008246504b846d..138efecd6b2580 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/BUILD +++ b/third_party/xla/xla/tsl/lib/monitoring/BUILD @@ -72,6 +72,7 @@ cc_library( ":collection_registry", ":metric_def", "//xla/tsl/lib/histogram", + "//xla/tsl/protobuf:histogram_proto_cc", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@local_tsl//tsl/platform", @@ -80,7 +81,6 @@ cc_library( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:thread_annotations", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/protobuf:histogram_proto_cc", ], ) @@ -100,9 +100,9 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":types", + "//xla/tsl/protobuf:histogram_proto_cc", "@local_tsl//tsl/platform:stringpiece", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/protobuf:histogram_proto_cc", ], ) @@ -115,6 +115,7 @@ cc_library( ":collected_metrics", ":metric_def", ":types", + "//xla/tsl/protobuf:histogram_proto_cc", "@local_tsl//tsl/platform", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", @@ -123,7 +124,6 @@ cc_library( "@local_tsl//tsl/platform:stringpiece", "@local_tsl//tsl/platform:thread_annotations", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/protobuf:histogram_proto_cc", ], ) @@ -135,7 +135,7 @@ cc_library( deps = [ ":metric_def", ":types", - "@local_tsl//tsl/protobuf:histogram_proto_cc", + "//xla/tsl/protobuf:histogram_proto_cc", ], ) @@ -201,12 +201,12 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":types", + "//xla/tsl/protobuf:histogram_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:histogram_proto_cc", ], ) diff --git a/third_party/xla/xla/tsl/lib/monitoring/collected_metrics.h b/third_party/xla/xla/tsl/lib/monitoring/collected_metrics.h index 48b655c2a8a2ba..8e305493e83c6b 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/collected_metrics.h +++ b/third_party/xla/xla/tsl/lib/monitoring/collected_metrics.h @@ -27,7 +27,7 @@ limitations under the License. #include "xla/tsl/lib/monitoring/metric_def.h" #include "xla/tsl/lib/monitoring/types.h" -#include "tsl/protobuf/histogram.pb.h" +#include "xla/tsl/protobuf/histogram.pb.h" namespace tsl { namespace monitoring { diff --git a/third_party/xla/xla/tsl/lib/monitoring/collection_registry.h b/third_party/xla/xla/tsl/lib/monitoring/collection_registry.h index 46e93e5c2e46f9..6c48ea9114c8db 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/collection_registry.h +++ b/third_party/xla/xla/tsl/lib/monitoring/collection_registry.h @@ -110,6 +110,7 @@ class CollectionRegistry { #include "xla/tsl/lib/monitoring/collected_metrics.h" #include "xla/tsl/lib/monitoring/metric_def.h" #include "xla/tsl/lib/monitoring/types.h" +#include "xla/tsl/protobuf/histogram.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/logging.h" #include "tsl/platform/macros.h" @@ -117,7 +118,6 @@ class CollectionRegistry { #include "tsl/platform/stringpiece.h" #include "tsl/platform/thread_annotations.h" #include "tsl/platform/types.h" -#include "tsl/protobuf/histogram.pb.h" namespace tsl { namespace monitoring { diff --git a/third_party/xla/xla/tsl/lib/monitoring/metric_def.h b/third_party/xla/xla/tsl/lib/monitoring/metric_def.h index 05a1c44da5b9a9..dcee3f92db4c30 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/metric_def.h +++ b/third_party/xla/xla/tsl/lib/monitoring/metric_def.h @@ -22,9 +22,9 @@ limitations under the License. #include #include "xla/tsl/lib/monitoring/types.h" +#include "xla/tsl/protobuf/histogram.pb.h" #include "tsl/platform/stringpiece.h" #include "tsl/platform/types.h" -#include "tsl/protobuf/histogram.pb.h" namespace tsl { namespace monitoring { diff --git a/third_party/xla/xla/tsl/lib/monitoring/sampler.h b/third_party/xla/xla/tsl/lib/monitoring/sampler.h index 34e7d79b9cced3..3976e312876cb4 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/sampler.h +++ b/third_party/xla/xla/tsl/lib/monitoring/sampler.h @@ -29,10 +29,10 @@ limitations under the License. #include #include "xla/tsl/lib/monitoring/metric_def.h" +#include "xla/tsl/protobuf/histogram.pb.h" #include "tsl/platform/macros.h" #include "tsl/platform/status.h" #include "tsl/platform/types.h" -#include "tsl/protobuf/histogram.pb.h" namespace tsl { namespace monitoring { @@ -125,11 +125,11 @@ class Sampler { #include "xla/tsl/lib/histogram/histogram.h" #include "xla/tsl/lib/monitoring/collection_registry.h" #include "xla/tsl/lib/monitoring/metric_def.h" +#include "xla/tsl/protobuf/histogram.pb.h" #include "tsl/platform/macros.h" #include "tsl/platform/mutex.h" #include "tsl/platform/status.h" #include "tsl/platform/thread_annotations.h" -#include "tsl/protobuf/histogram.pb.h" namespace tsl { namespace monitoring { diff --git a/third_party/xla/xla/tsl/lib/monitoring/test_utils.cc b/third_party/xla/xla/tsl/lib/monitoring/test_utils.cc index b895ead9d0b923..3691130880ab24 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/test_utils.cc +++ b/third_party/xla/xla/tsl/lib/monitoring/test_utils.cc @@ -21,8 +21,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_join.h" #include "xla/tsl/lib/monitoring/types.h" +#include "xla/tsl/protobuf/histogram.pb.h" #include "tsl/platform/errors.h" -#include "tsl/protobuf/histogram.pb.h" namespace tsl { namespace monitoring { diff --git a/third_party/xla/xla/tsl/lib/monitoring/test_utils.h b/third_party/xla/xla/tsl/lib/monitoring/test_utils.h index d761c9d27ce039..85101ebffc6d69 100644 --- a/third_party/xla/xla/tsl/lib/monitoring/test_utils.h +++ b/third_party/xla/xla/tsl/lib/monitoring/test_utils.h @@ -19,8 +19,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/tsl/lib/monitoring/types.h" +#include "xla/tsl/protobuf/histogram.pb.h" #include "tsl/platform/statusor.h" -#include "tsl/protobuf/histogram.pb.h" namespace tsl { namespace monitoring { diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/BUILD b/third_party/xla/xla/tsl/lib/random/BUILD similarity index 73% rename from third_party/xla/third_party/tsl/tsl/lib/random/BUILD rename to third_party/xla/xla/tsl/lib/random/BUILD index c64a1332e76ff8..4fa352c07886eb 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/BUILD +++ b/third_party/xla/xla/tsl/lib/random/BUILD @@ -1,13 +1,13 @@ -load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("@local_xla//xla/tsl:tsl.bzl", "internal_visibility") -load("@local_xla//xla/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") load( - "//tsl/platform:build_config.bzl", + "@local_tsl//tsl/platform:build_config.bzl", "tsl_cc_test", ) +load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") default_visibility = [ - "//tsl/lib/io:__pkg__", + "//xla/tsl/lib/io:__pkg__", # tensorflow/core/platform/random aliases this package "//tensorflow/core/lib/random:__pkg__", ] @@ -40,11 +40,11 @@ cc_library( ":exact_uniform_int", ":philox_random", ":random_distributions_utils", - "//tsl/platform:logging", - "//tsl/platform:macros", - "//tsl/platform:types", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:types", ], alwayslink = 1, ) @@ -73,8 +73,8 @@ cc_library( hdrs = ["philox_random_test_utils.h"], deps = [ ":philox_random", - "//tsl/platform:logging", - "//tsl/platform:random", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:random", ], ) @@ -84,9 +84,9 @@ cc_library( hdrs = ["weighted_picker.h"], deps = [ ":philox", - "//tsl/platform:logging", - "//tsl/platform:macros", - "//tsl/platform:types", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:types", ], alwayslink = 1, ) @@ -159,11 +159,11 @@ tsl_cc_test( srcs = ["distribution_sampler_test.cc"], deps = [ ":philox", - "//tsl/platform:macros", - "//tsl/platform:test", - "//tsl/platform:test_benchmark", - "//tsl/platform:test_main", - "//tsl/platform:types", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_benchmark", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/platform:types", ], ) @@ -175,10 +175,10 @@ tsl_cc_test( ":philox", ":philox_random", ":philox_random_test_utils", - "//tsl/platform:logging", - "//tsl/platform:random", - "//tsl/platform:test", - "//tsl/platform:test_main", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:random", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -190,11 +190,11 @@ tsl_cc_test( ":philox", ":philox_random", ":philox_random_test_utils", - "//tsl/lib/math:math_util", - "//tsl/platform:logging", - "//tsl/platform:random", - "//tsl/platform:test", - "//tsl/platform:test_main", + "//xla/tsl/lib/math:math_util", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:random", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -204,10 +204,10 @@ tsl_cc_test( srcs = ["simple_philox_test.cc"], deps = [ ":philox", - "//tsl/platform:logging", - "//tsl/platform:test", - "//tsl/platform:test_main", - "//tsl/platform:types", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/platform:types", ], ) @@ -218,11 +218,11 @@ tsl_cc_test( deps = [ ":philox", ":weighted_picker", - "//tsl/platform:logging", - "//tsl/platform:macros", - "//tsl/platform:test", - "//tsl/platform:test_benchmark", - "//tsl/platform:test_main", - "//tsl/platform:types", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_benchmark", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/platform:types", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/distribution_sampler.cc b/third_party/xla/xla/tsl/lib/random/distribution_sampler.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/random/distribution_sampler.cc rename to third_party/xla/xla/tsl/lib/random/distribution_sampler.cc index 9e597ffea0a390..384dd50fc34e74 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/distribution_sampler.cc +++ b/third_party/xla/xla/tsl/lib/random/distribution_sampler.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/random/distribution_sampler.h" +#include "xla/tsl/lib/random/distribution_sampler.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/distribution_sampler.h b/third_party/xla/xla/tsl/lib/random/distribution_sampler.h similarity index 92% rename from third_party/xla/third_party/tsl/tsl/lib/random/distribution_sampler.h rename to third_party/xla/xla/tsl/lib/random/distribution_sampler.h index 877660c8532baa..ababcc6bf23a31 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/distribution_sampler.h +++ b/third_party/xla/xla/tsl/lib/random/distribution_sampler.h @@ -28,14 +28,14 @@ limitations under the License. // // The algorithm used is Walker's Aliasing algorithm, described in Knuth, Vol 2. -#ifndef TENSORFLOW_TSL_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ -#define TENSORFLOW_TSL_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ +#ifndef XLA_TSL_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ +#define XLA_TSL_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ #include #include #include "absl/types/span.h" -#include "tsl/lib/random/simple_philox.h" +#include "xla/tsl/lib/random/simple_philox.h" #include "tsl/platform/logging.h" #include "tsl/platform/macros.h" #include "tsl/platform/types.h" @@ -92,4 +92,4 @@ class DistributionSampler { } // namespace random } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ +#endif // XLA_TSL_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/distribution_sampler_test.cc b/third_party/xla/xla/tsl/lib/random/distribution_sampler_test.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/lib/random/distribution_sampler_test.cc rename to third_party/xla/xla/tsl/lib/random/distribution_sampler_test.cc index 142c01a77023eb..16107ec61c26c0 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/distribution_sampler_test.cc +++ b/third_party/xla/xla/tsl/lib/random/distribution_sampler_test.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/random/distribution_sampler.h" +#include "xla/tsl/lib/random/distribution_sampler.h" #include #include #include -#include "tsl/lib/random/simple_philox.h" +#include "xla/tsl/lib/random/simple_philox.h" #include "tsl/platform/macros.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/exact_uniform_int.h b/third_party/xla/xla/tsl/lib/random/exact_uniform_int.h similarity index 91% rename from third_party/xla/third_party/tsl/tsl/lib/random/exact_uniform_int.h rename to third_party/xla/xla/tsl/lib/random/exact_uniform_int.h index 392d1aa2835110..25d05cb69eefcc 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/exact_uniform_int.h +++ b/third_party/xla/xla/tsl/lib/random/exact_uniform_int.h @@ -15,8 +15,8 @@ limitations under the License. // Exact uniform integers using rejection sampling -#ifndef TENSORFLOW_TSL_LIB_RANDOM_EXACT_UNIFORM_INT_H_ -#define TENSORFLOW_TSL_LIB_RANDOM_EXACT_UNIFORM_INT_H_ +#ifndef XLA_TSL_LIB_RANDOM_EXACT_UNIFORM_INT_H_ +#define XLA_TSL_LIB_RANDOM_EXACT_UNIFORM_INT_H_ #include @@ -31,7 +31,7 @@ UintType ExactUniformInt(const UintType n, const RandomBits& random) { "random() should return UintType"); if (n == 0) { // Consume a value anyway - // TODO(irving): Assert n != 0, since this case makes no sense. + // TODO(geoffreyi): Assert n != 0, since this case makes no sense. return random() * n; } else if (0 == (n & (n - 1))) { // N is a power of two, so just mask off the lower bits. @@ -80,4 +80,4 @@ UintType ExactUniformInt(const UintType n, const RandomBits& random) { } // namespace random } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_RANDOM_EXACT_UNIFORM_INT_H_ +#endif // XLA_TSL_LIB_RANDOM_EXACT_UNIFORM_INT_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/philox_random.h b/third_party/xla/xla/tsl/lib/random/philox_random.h similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/random/philox_random.h rename to third_party/xla/xla/tsl/lib/random/philox_random.h index 03b54aae3ec48e..f3b57794f737bc 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/philox_random.h +++ b/third_party/xla/xla/tsl/lib/random/philox_random.h @@ -17,8 +17,8 @@ limitations under the License. // Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. // http://www.thesalmons.org/john/random123/papers/random123sc11.pdf -#ifndef TENSORFLOW_TSL_LIB_RANDOM_PHILOX_RANDOM_H_ -#define TENSORFLOW_TSL_LIB_RANDOM_PHILOX_RANDOM_H_ +#ifndef XLA_TSL_LIB_RANDOM_PHILOX_RANDOM_H_ +#define XLA_TSL_LIB_RANDOM_PHILOX_RANDOM_H_ #include @@ -255,4 +255,4 @@ class PhiloxRandom { } // namespace random } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_RANDOM_PHILOX_RANDOM_H_ +#endif // XLA_TSL_LIB_RANDOM_PHILOX_RANDOM_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/philox_random_test.cc b/third_party/xla/xla/tsl/lib/random/philox_random_test.cc similarity index 93% rename from third_party/xla/third_party/tsl/tsl/lib/random/philox_random_test.cc rename to third_party/xla/xla/tsl/lib/random/philox_random_test.cc index 714c510aa4ddd5..7af1f9485754fd 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/philox_random_test.cc +++ b/third_party/xla/xla/tsl/lib/random/philox_random_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/random/philox_random.h" +#include "xla/tsl/lib/random/philox_random.h" #include @@ -22,8 +22,8 @@ limitations under the License. #include #include -#include "tsl/lib/random/philox_random_test_utils.h" -#include "tsl/lib/random/random_distributions.h" +#include "xla/tsl/lib/random/philox_random_test_utils.h" +#include "xla/tsl/lib/random/random_distributions.h" #include "tsl/platform/logging.h" #include "tsl/platform/random.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/philox_random_test_utils.h b/third_party/xla/xla/tsl/lib/random/philox_random_test_utils.h similarity index 86% rename from third_party/xla/third_party/tsl/tsl/lib/random/philox_random_test_utils.h rename to third_party/xla/xla/tsl/lib/random/philox_random_test_utils.h index 4e217d6362cce1..6bbb1c89596b80 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/philox_random_test_utils.h +++ b/third_party/xla/xla/tsl/lib/random/philox_random_test_utils.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_ -#define TENSORFLOW_TSL_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_ +#ifndef XLA_TSL_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_ +#define XLA_TSL_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_ #include -#include "tsl/lib/random/philox_random.h" +#include "xla/tsl/lib/random/philox_random.h" #include "tsl/platform/logging.h" #include "tsl/platform/random.h" @@ -48,4 +48,4 @@ void FillRandoms(PhiloxRandom gen, typename Distribution::ResultElementType* p, } // namespace random } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_ +#endif // XLA_TSL_LIB_RANDOM_PHILOX_RANDOM_TEST_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/random_distributions.cc b/third_party/xla/xla/tsl/lib/random/random_distributions.cc similarity index 90% rename from third_party/xla/third_party/tsl/tsl/lib/random/random_distributions.cc rename to third_party/xla/xla/tsl/lib/random/random_distributions.cc index 12a806f80acbe5..ab8930008f8c8b 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/random_distributions.cc +++ b/third_party/xla/xla/tsl/lib/random/random_distributions.cc @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/random/distribution_sampler.h" -#include "tsl/lib/random/philox_random.h" +#include "xla/tsl/lib/random/distribution_sampler.h" +#include "xla/tsl/lib/random/philox_random.h" namespace tsl { namespace random { diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/random_distributions.h b/third_party/xla/xla/tsl/lib/random/random_distributions.h similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/random/random_distributions.h rename to third_party/xla/xla/tsl/lib/random/random_distributions.h index 70a78cf86082ab..ce231f9f652c27 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/random_distributions.h +++ b/third_party/xla/xla/tsl/lib/random/random_distributions.h @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ -#define TENSORFLOW_TSL_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ +#ifndef XLA_TSL_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ +#define XLA_TSL_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ #include #include #include -#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive -#include "tsl/lib/random/philox_random.h" -#include "tsl/lib/random/random_distributions_utils.h" +#include "unsupported/Eigen/CXX11/Tensor" +#include "xla/tsl/lib/random/philox_random.h" +#include "xla/tsl/lib/random/random_distributions_utils.h" #include "tsl/platform/types.h" namespace tsl { @@ -753,4 +753,4 @@ PHILOX_DEVICE_INLINE bfloat16 Uint16ToGfloat16(uint16 x) { } // namespace random } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ +#endif // XLA_TSL_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/random_distributions_test.cc b/third_party/xla/xla/tsl/lib/random/random_distributions_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/random/random_distributions_test.cc rename to third_party/xla/xla/tsl/lib/random/random_distributions_test.cc index ccb595aa0dae4f..b1dab4cd81d6d8 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/random_distributions_test.cc +++ b/third_party/xla/xla/tsl/lib/random/random_distributions_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/random/random_distributions.h" +#include "xla/tsl/lib/random/random_distributions.h" #include #include @@ -22,9 +22,9 @@ limitations under the License. #include #include -#include "tsl/lib/math/math_util.h" -#include "tsl/lib/random/philox_random.h" -#include "tsl/lib/random/philox_random_test_utils.h" +#include "xla/tsl/lib/math/math_util.h" +#include "xla/tsl/lib/random/philox_random.h" +#include "xla/tsl/lib/random/philox_random_test_utils.h" #include "tsl/platform/logging.h" #include "tsl/platform/random.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/random_distributions_utils.h b/third_party/xla/xla/tsl/lib/random/random_distributions_utils.h similarity index 93% rename from third_party/xla/third_party/tsl/tsl/lib/random/random_distributions_utils.h rename to third_party/xla/xla/tsl/lib/random/random_distributions_utils.h index 38f2d792f58a2d..8da345b83e5c97 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/random_distributions_utils.h +++ b/third_party/xla/xla/tsl/lib/random/random_distributions_utils.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_RANDOM_RANDOM_DISTRIBUTIONS_UTILS_H_ -#define TENSORFLOW_TSL_LIB_RANDOM_RANDOM_DISTRIBUTIONS_UTILS_H_ +#ifndef XLA_TSL_LIB_RANDOM_RANDOM_DISTRIBUTIONS_UTILS_H_ +#define XLA_TSL_LIB_RANDOM_RANDOM_DISTRIBUTIONS_UTILS_H_ #include #include -#include "tsl/lib/random/philox_random.h" +#include "xla/tsl/lib/random/philox_random.h" #ifndef M_PI #define M_PI (3.14159265358979323846) @@ -94,4 +94,4 @@ void BoxMullerFloat(uint32_t x0, uint32_t x1, float* f0, float* f1) { } // namespace random } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_RANDOM_RANDOM_DISTRIBUTIONS_UTILS_H_ +#endif // XLA_TSL_LIB_RANDOM_RANDOM_DISTRIBUTIONS_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/simple_philox.cc b/third_party/xla/xla/tsl/lib/random/simple_philox.cc similarity index 92% rename from third_party/xla/third_party/tsl/tsl/lib/random/simple_philox.cc rename to third_party/xla/xla/tsl/lib/random/simple_philox.cc index 1ae957ed0a9c29..f2c2bbe5820863 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/simple_philox.cc +++ b/third_party/xla/xla/tsl/lib/random/simple_philox.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/random/simple_philox.h" +#include "xla/tsl/lib/random/simple_philox.h" -#include "tsl/lib/random/exact_uniform_int.h" +#include "xla/tsl/lib/random/exact_uniform_int.h" #include "tsl/platform/logging.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/simple_philox.h b/third_party/xla/xla/tsl/lib/random/simple_philox.h similarity index 89% rename from third_party/xla/third_party/tsl/tsl/lib/random/simple_philox.h rename to third_party/xla/xla/tsl/lib/random/simple_philox.h index 631656519478d3..736bec4d84d238 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/simple_philox.h +++ b/third_party/xla/xla/tsl/lib/random/simple_philox.h @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_RANDOM_SIMPLE_PHILOX_H_ -#define TENSORFLOW_TSL_LIB_RANDOM_SIMPLE_PHILOX_H_ +#ifndef XLA_TSL_LIB_RANDOM_SIMPLE_PHILOX_H_ +#define XLA_TSL_LIB_RANDOM_SIMPLE_PHILOX_H_ #include #include #include -#include "tsl/lib/random/philox_random.h" -#include "tsl/lib/random/random_distributions.h" +#include "xla/tsl/lib/random/philox_random.h" +#include "xla/tsl/lib/random/random_distributions.h" namespace tsl { namespace random { @@ -74,4 +74,4 @@ class SimplePhilox { } // namespace random } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_RANDOM_SIMPLE_PHILOX_H_ +#endif // XLA_TSL_LIB_RANDOM_SIMPLE_PHILOX_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/simple_philox_test.cc b/third_party/xla/xla/tsl/lib/random/simple_philox_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/random/simple_philox_test.cc rename to third_party/xla/xla/tsl/lib/random/simple_philox_test.cc index 657d4cf64758ca..3eded84eb0ee33 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/simple_philox_test.cc +++ b/third_party/xla/xla/tsl/lib/random/simple_philox_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/random/simple_philox.h" +#include "xla/tsl/lib/random/simple_philox.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/weighted_picker.cc b/third_party/xla/xla/tsl/lib/random/weighted_picker.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/random/weighted_picker.cc rename to third_party/xla/xla/tsl/lib/random/weighted_picker.cc index 06e0df581fbd46..911f0f4d300616 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/weighted_picker.cc +++ b/third_party/xla/xla/tsl/lib/random/weighted_picker.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/random/weighted_picker.h" +#include "xla/tsl/lib/random/weighted_picker.h" #include #include -#include "tsl/lib/random/simple_philox.h" +#include "xla/tsl/lib/random/simple_philox.h" namespace tsl { namespace random { diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/weighted_picker.h b/third_party/xla/xla/tsl/lib/random/weighted_picker.h similarity index 96% rename from third_party/xla/third_party/tsl/tsl/lib/random/weighted_picker.h rename to third_party/xla/xla/tsl/lib/random/weighted_picker.h index 05fabea852b5f7..27903077df2a73 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/weighted_picker.h +++ b/third_party/xla/xla/tsl/lib/random/weighted_picker.h @@ -24,8 +24,8 @@ limitations under the License. // Alternative: distribution-sampler.h allows O(1) time picking, but no weight // adjustment after construction. -#ifndef TENSORFLOW_TSL_LIB_RANDOM_WEIGHTED_PICKER_H_ -#define TENSORFLOW_TSL_LIB_RANDOM_WEIGHTED_PICKER_H_ +#ifndef XLA_TSL_LIB_RANDOM_WEIGHTED_PICKER_H_ +#define XLA_TSL_LIB_RANDOM_WEIGHTED_PICKER_H_ #include @@ -131,4 +131,4 @@ inline int WeightedPicker::num_elements() const { return N_; } } // namespace random } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_RANDOM_WEIGHTED_PICKER_H_ +#endif // XLA_TSL_LIB_RANDOM_WEIGHTED_PICKER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/random/weighted_picker_test.cc b/third_party/xla/xla/tsl/lib/random/weighted_picker_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/random/weighted_picker_test.cc rename to third_party/xla/xla/tsl/lib/random/weighted_picker_test.cc index a81b2ad99d5e6b..64e40c05c432a8 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/random/weighted_picker_test.cc +++ b/third_party/xla/xla/tsl/lib/random/weighted_picker_test.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/random/weighted_picker.h" +#include "xla/tsl/lib/random/weighted_picker.h" #include #include -#include "tsl/lib/random/simple_philox.h" +#include "xla/tsl/lib/random/simple_philox.h" #include "tsl/platform/logging.h" #include "tsl/platform/macros.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/tsl/lib/strings/BUILD b/third_party/xla/xla/tsl/lib/strings/BUILD index 03f82a366f78c6..1e36a8d78f0b5e 100644 --- a/third_party/xla/xla/tsl/lib/strings/BUILD +++ b/third_party/xla/xla/tsl/lib/strings/BUILD @@ -13,9 +13,9 @@ cc_library( hdrs = ["proto_serialization.h"], visibility = ["//visibility:public"], deps = [ + "//xla/tsl/lib/gtl:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/gtl:inlined_vector", "@local_tsl//tsl/platform:hash", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:macros", diff --git a/third_party/xla/xla/tsl/lib/strings/proto_serialization.cc b/third_party/xla/xla/tsl/lib/strings/proto_serialization.cc index 06ef0747ee553d..fef78bd1835a00 100644 --- a/third_party/xla/xla/tsl/lib/strings/proto_serialization.cc +++ b/third_party/xla/xla/tsl/lib/strings/proto_serialization.cc @@ -19,7 +19,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/string_view.h" -#include "tsl/lib/gtl/inlined_vector.h" +#include "xla/tsl/lib/gtl/inlined_vector.h" #include "tsl/platform/hash.h" #include "tsl/platform/logging.h" #include "tsl/platform/macros.h" diff --git a/third_party/xla/xla/tsl/profiler/backends/cpu/BUILD b/third_party/xla/xla/tsl/profiler/backends/cpu/BUILD index 28c4a7c57b2142..4121ad5fc3a92b 100644 --- a/third_party/xla/xla/tsl/profiler/backends/cpu/BUILD +++ b/third_party/xla/xla/tsl/profiler/backends/cpu/BUILD @@ -40,6 +40,8 @@ cc_library( "//xla/tsl/profiler:xla_internal", ]), deps = [ + "//xla/tsl/profiler/utils:lock_free_queue", + "//xla/tsl/profiler/utils:per_thread", "@com_google_absl//absl/container:flat_hash_map", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", @@ -47,8 +49,6 @@ cc_library( "@local_tsl//tsl/platform:mutex", "@local_tsl//tsl/platform:thread_annotations", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/profiler/utils:lock_free_queue", - "@local_tsl//tsl/profiler/utils:per_thread", ], alwayslink = True, ) @@ -59,6 +59,9 @@ tsl_cc_test( deps = [ ":traceme_recorder", ":traceme_recorder_impl", + "//xla/tsl/profiler/utils:math_utils", + "//xla/tsl/profiler/utils:time_utils", + "//xla/tsl/profiler/utils:time_utils_impl", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:env", @@ -68,9 +71,6 @@ tsl_cc_test( "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/profiler/utils:math_utils", - "@local_tsl//tsl/profiler/utils:time_utils", - "@local_tsl//tsl/profiler/utils:time_utils_impl", ], ) @@ -120,13 +120,13 @@ cc_library( ]), deps = [ ":traceme_recorder", + "//xla/tsl/profiler/utils:parse_annotation", + "//xla/tsl/profiler/utils:tf_op_utils", + "//xla/tsl/profiler/utils:xplane_builder", + "//xla/tsl/profiler/utils:xplane_utils", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:parse_annotation", - "@local_tsl//tsl/profiler/utils:tf_op_utils", - "@local_tsl//tsl/profiler/utils:xplane_builder", - "@local_tsl//tsl/profiler/utils:xplane_utils", ], ) @@ -144,6 +144,8 @@ cc_library( deps = [ ":threadpool_listener_state", ":traceme_recorder", + "//xla/tsl/profiler/utils:time_utils", + "//xla/tsl/profiler/utils:xplane_schema", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@local_tsl//tsl/platform:logging", @@ -153,8 +155,6 @@ cc_library( "@local_tsl//tsl/profiler/lib:profiler_interface", "@local_tsl//tsl/profiler/lib:traceme_encode", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:time_utils", - "@local_tsl//tsl/profiler/utils:xplane_schema", ], ) diff --git a/third_party/xla/xla/tsl/profiler/backends/cpu/host_tracer_utils.cc b/third_party/xla/xla/tsl/profiler/backends/cpu/host_tracer_utils.cc index 4195d8df45c953..3ee8fae3f04883 100644 --- a/third_party/xla/xla/tsl/profiler/backends/cpu/host_tracer_utils.cc +++ b/third_party/xla/xla/tsl/profiler/backends/cpu/host_tracer_utils.cc @@ -19,12 +19,12 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/tsl/profiler/backends/cpu/traceme_recorder.h" +#include "xla/tsl/profiler/utils/parse_annotation.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/parse_annotation.h" -#include "tsl/profiler/utils/tf_op_utils.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/backends/cpu/threadpool_listener.cc b/third_party/xla/xla/tsl/profiler/backends/cpu/threadpool_listener.cc index 3772efb86adbb9..af9fc451b2d238 100644 --- a/third_party/xla/xla/tsl/profiler/backends/cpu/threadpool_listener.cc +++ b/third_party/xla/xla/tsl/profiler/backends/cpu/threadpool_listener.cc @@ -21,14 +21,14 @@ limitations under the License. #include "absl/status/status.h" #include "xla/tsl/profiler/backends/cpu/threadpool_listener_state.h" #include "xla/tsl/profiler/backends/cpu/traceme_recorder.h" +#include "xla/tsl/profiler/utils/time_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tsl/platform/logging.h" #include "tsl/platform/tracing.h" #include "tsl/platform/types.h" #include "tsl/profiler/lib/context_types.h" #include "tsl/profiler/lib/traceme_encode.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/time_utils.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder.cc b/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder.cc index 1e0b77e92ed7b4..1279b80dd62dbe 100644 --- a/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder.cc +++ b/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder.cc @@ -25,12 +25,12 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "xla/tsl/profiler/utils/lock_free_queue.h" +#include "xla/tsl/profiler/utils/per_thread.h" #include "tsl/platform/env.h" #include "tsl/platform/logging.h" #include "tsl/platform/macros.h" #include "tsl/platform/types.h" -#include "tsl/profiler/utils/lock_free_queue.h" -#include "tsl/profiler/utils/per_thread.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder_test.cc b/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder_test.cc index bdaf0e9ae3f81d..9fa89ed3d5e400 100644 --- a/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder_test.cc +++ b/third_party/xla/xla/tsl/profiler/backends/cpu/traceme_recorder_test.cc @@ -23,14 +23,14 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/profiler/utils/math_utils.h" +#include "xla/tsl/profiler/utils/time_utils.h" #include "tsl/platform/env.h" #include "tsl/platform/logging.h" #include "tsl/platform/notification.h" #include "tsl/platform/test.h" #include "tsl/platform/threadpool.h" #include "tsl/platform/types.h" -#include "tsl/profiler/utils/math_utils.h" -#include "tsl/profiler/utils/time_utils.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/convert/BUILD b/third_party/xla/xla/tsl/profiler/convert/BUILD index edda2c08a65d07..7b1eb9abd00ee2 100644 --- a/third_party/xla/xla/tsl/profiler/convert/BUILD +++ b/third_party/xla/xla/tsl/profiler/convert/BUILD @@ -55,11 +55,11 @@ cc_library( copts = tf_profiler_copts(), visibility = internal_visibility(["//xla/tsl/profiler:internal"]), deps = [ + "//xla/tsl/profiler/utils:timestamp_utils", + "//xla/tsl/profiler/utils:xplane_schema", + "//xla/tsl/profiler/utils:xplane_utils", "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:timestamp_utils", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_utils", ], ) @@ -73,14 +73,14 @@ cc_library( ]), deps = [ ":trace_container", + "//xla/tsl/profiler/utils:format_utils", + "//xla/tsl/profiler/utils:math_utils", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@jsoncpp_git//:jsoncpp", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:trace_events_proto_cc", - "@local_tsl//tsl/profiler/utils:format_utils", - "@local_tsl//tsl/profiler/utils:math_utils", ], ) @@ -120,16 +120,16 @@ cc_library( ]), deps = [ ":trace_container", + "//xla/tsl/profiler/utils:tf_xplane_visitor", + "//xla/tsl/profiler/utils:trace_utils", + "//xla/tsl/profiler/utils:xplane_schema", + "//xla/tsl/profiler/utils:xplane_utils", + "//xla/tsl/profiler/utils:xplane_visitor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:trace_events_proto_cc", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:tf_xplane_visitor", - "@local_tsl//tsl/profiler/utils:trace_utils", - "@local_tsl//tsl/profiler/utils:xplane_schema", - "@local_tsl//tsl/profiler/utils:xplane_utils", - "@local_tsl//tsl/profiler/utils:xplane_visitor", ], ) @@ -139,12 +139,12 @@ tsl_cc_test( srcs = ["xplane_to_trace_events_test.cc"], deps = [ ":xplane_to_trace_events", + "//xla/tsl/profiler/utils:trace_utils", + "//xla/tsl/profiler/utils:xplane_builder", + "//xla/tsl/profiler/utils:xplane_schema", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", "@local_tsl//tsl/profiler/protobuf:trace_events_proto_cc", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:trace_utils", - "@local_tsl//tsl/profiler/utils:xplane_builder", - "@local_tsl//tsl/profiler/utils:xplane_schema", ], ) diff --git a/third_party/xla/xla/tsl/profiler/convert/post_process_single_host_xplane.cc b/third_party/xla/xla/tsl/profiler/convert/post_process_single_host_xplane.cc index ec1c4e72ba970a..864da925423d99 100644 --- a/third_party/xla/xla/tsl/profiler/convert/post_process_single_host_xplane.cc +++ b/third_party/xla/xla/tsl/profiler/convert/post_process_single_host_xplane.cc @@ -17,11 +17,11 @@ limitations under the License. #include #include +#include "xla/tsl/profiler/utils/timestamp_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/timestamp_utils.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/convert/trace_events_to_json.cc b/third_party/xla/xla/tsl/profiler/convert/trace_events_to_json.cc index 24d55cc81998be..d9bc3319fdbb5d 100644 --- a/third_party/xla/xla/tsl/profiler/convert/trace_events_to_json.cc +++ b/third_party/xla/xla/tsl/profiler/convert/trace_events_to_json.cc @@ -22,11 +22,11 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/strings/str_cat.h" #include "json/json.h" +#include "xla/tsl/profiler/utils/format_utils.h" +#include "xla/tsl/profiler/utils/math_utils.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/trace_events.pb.h" -#include "tsl/profiler/utils/format_utils.h" -#include "tsl/profiler/utils/math_utils.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events.cc b/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events.cc index fbc6ea7da1a8d3..c37951436d7168 100644 --- a/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events.cc +++ b/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events.cc @@ -24,14 +24,14 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/trace_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/trace_events.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/trace_utils.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events_test.cc b/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events_test.cc index 8a724596767a31..6e0d3955c84cbf 100644 --- a/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events_test.cc +++ b/third_party/xla/xla/tsl/profiler/convert/xplane_to_trace_events_test.cc @@ -18,12 +18,12 @@ limitations under the License. #include #include +#include "xla/tsl/profiler/utils/trace_utils.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tsl/platform/test.h" #include "tsl/profiler/protobuf/trace_events.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/trace_utils.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/rpc/BUILD b/third_party/xla/xla/tsl/profiler/rpc/BUILD index 6f96ed053dc550..2b902f0c35d391 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/BUILD +++ b/third_party/xla/xla/tsl/profiler/rpc/BUILD @@ -29,6 +29,10 @@ cc_library( ]), deps = [ "//xla/tsl/profiler/rpc/client:save_profile", + "//xla/tsl/profiler/utils:file_system_utils", + "//xla/tsl/profiler/utils:math_utils", + "//xla/tsl/profiler/utils:time_utils", + "//xla/tsl/profiler/utils:xplane_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -43,10 +47,6 @@ cc_library( "@local_tsl//tsl/profiler/protobuf:profiler_service_cc_grpc_proto", "@local_tsl//tsl/profiler/protobuf:profiler_service_proto_cc", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:file_system_utils", - "@local_tsl//tsl/profiler/utils:math_utils", - "@local_tsl//tsl/profiler/utils:time_utils", - "@local_tsl//tsl/profiler/utils:xplane_utils", ] + tsl_grpc_cc_dependencies(), ) diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/BUILD b/third_party/xla/xla/tsl/profiler/rpc/client/BUILD index 8e02d961b59f47..87082573a0b624 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/BUILD +++ b/third_party/xla/xla/tsl/profiler/rpc/client/BUILD @@ -36,6 +36,7 @@ cc_library( ":save_profile", "//xla/tsl/profiler/convert:trace_events_to_json", "//xla/tsl/profiler/convert:xplane_to_trace_events", + "//xla/tsl/profiler/utils:session_manager", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", @@ -47,7 +48,6 @@ cc_library( "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc", "@local_tsl//tsl/profiler/protobuf:profiler_service_proto_cc", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:session_manager", ], ) @@ -64,10 +64,11 @@ cc_library( "//learning/pathways/data_parallel:__pkg__", ]), deps = [ + "//xla/tsl/lib/io:zlib_compression_options", + "//xla/tsl/lib/io:zlib_outputbuffer", + "//xla/tsl/profiler/utils:file_system_utils", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@local_tsl//tsl/lib/io:zlib_compression_options", - "@local_tsl//tsl/lib/io:zlib_outputbuffer", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", @@ -75,7 +76,6 @@ cc_library( "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:profiler_service_proto_cc", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", - "@local_tsl//tsl/profiler/utils:file_system_utils", ], ) @@ -121,6 +121,7 @@ cc_library( "//tensorflow/python/profiler/internal:__pkg__", ]), deps = [ + "//xla/tsl/protobuf:error_codes_proto_impl_cc", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", @@ -130,7 +131,6 @@ cc_library( "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/protobuf:profiler_analysis_cc_grpc_proto", "@local_tsl//tsl/profiler/protobuf:profiler_service_cc_grpc_proto", - "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", ] + tsl_grpc_cc_dependencies(), alwayslink = True, ) @@ -161,6 +161,7 @@ tsl_cc_test( ":profiler_client_test_util", "//xla/tsl/profiler/rpc:profiler_server_impl", "//xla/tsl/profiler/rpc:profiler_service_impl", + "//xla/tsl/profiler/utils:time_utils_impl", "@com_google_absl//absl/time", "@local_tsl//tsl/platform:env_impl", "@local_tsl//tsl/platform:errors", @@ -170,7 +171,6 @@ tsl_cc_test( "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/lib:profiler_factory_impl", "@local_tsl//tsl/profiler/lib:profiler_session_impl", - "@local_tsl//tsl/profiler/utils:time_utils_impl", ] + tf_protos_profiler_service(), ) @@ -181,6 +181,7 @@ cc_library( copts = tf_profiler_copts(), deps = [ ":profiler_client_for_pybind", + "//xla/tsl/profiler/utils:time_utils", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", @@ -192,7 +193,6 @@ cc_library( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:thread_annotations", "@local_tsl//tsl/platform:types", - "@local_tsl//tsl/profiler/utils:time_utils", ], ) @@ -205,6 +205,7 @@ tsl_cc_test( ":remote_profiler_session_manager", "//xla/tsl/profiler/rpc:profiler_server_impl", "//xla/tsl/profiler/rpc:profiler_service_impl", + "//xla/tsl/profiler/utils:time_utils_impl", "@com_google_absl//absl/status", "@com_google_absl//absl/time", "@local_tsl//tsl/platform:env_impl", @@ -216,6 +217,5 @@ tsl_cc_test( "@local_tsl//tsl/profiler/lib:profiler_factory_impl", "@local_tsl//tsl/profiler/lib:profiler_session_impl", "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc", - "@local_tsl//tsl/profiler/utils:time_utils_impl", ] + tf_protos_profiler_service(), ) diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/capture_profile.cc b/third_party/xla/xla/tsl/profiler/rpc/client/capture_profile.cc index b78ceadcda2b93..84dd66b6e2f118 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/capture_profile.cc +++ b/third_party/xla/xla/tsl/profiler/rpc/client/capture_profile.cc @@ -31,6 +31,7 @@ limitations under the License. #include "xla/tsl/profiler/rpc/client/profiler_client.h" #include "xla/tsl/profiler/rpc/client/remote_profiler_session_manager.h" #include "xla/tsl/profiler/rpc/client/save_profile.h" +#include "xla/tsl/profiler/utils/session_manager.h" #include "tsl/platform/errors.h" #include "tsl/platform/host_info.h" #include "tsl/platform/status.h" @@ -38,7 +39,6 @@ limitations under the License. #include "tsl/profiler/protobuf/profiler_analysis.pb.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" #include "tsl/profiler/protobuf/profiler_service.pb.h" -#include "tsl/profiler/utils/session_manager.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/profiler_client.cc b/third_party/xla/xla/tsl/profiler/rpc/client/profiler_client.cc index 77d295bf1af6f8..892bef42bb0a82 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/profiler_client.cc +++ b/third_party/xla/xla/tsl/profiler/rpc/client/profiler_client.cc @@ -21,11 +21,11 @@ limitations under the License. #include "absl/time/clock.h" #include "absl/time/time.h" #include "grpcpp/grpcpp.h" // IWYU pragma: keep +#include "xla/tsl/protobuf/error_codes.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" #include "tsl/platform/types.h" -#include "tsl/protobuf/error_codes.pb.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/remote_profiler_session_manager.cc b/third_party/xla/xla/tsl/profiler/rpc/client/remote_profiler_session_manager.cc index bc987157013dcf..2eb7e0d6743180 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/remote_profiler_session_manager.cc +++ b/third_party/xla/xla/tsl/profiler/rpc/client/remote_profiler_session_manager.cc @@ -23,11 +23,11 @@ limitations under the License. #include "absl/time/clock.h" #include "absl/time/time.h" #include "xla/tsl/profiler/rpc/client/profiler_client.h" +#include "xla/tsl/profiler/utils/time_utils.h" #include "tsl/platform/env_time.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/types.h" -#include "tsl/profiler/utils/time_utils.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/rpc/client/save_profile.cc b/third_party/xla/xla/tsl/profiler/rpc/client/save_profile.cc index cf98c56d1944e4..bc8bf69f492bfd 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/client/save_profile.cc +++ b/third_party/xla/xla/tsl/profiler/rpc/client/save_profile.cc @@ -28,8 +28,9 @@ limitations under the License. #include "absl/strings/strip.h" #include "absl/time/clock.h" #include "absl/time/time.h" -#include "tsl/lib/io/zlib_compression_options.h" -#include "tsl/lib/io/zlib_outputbuffer.h" +#include "xla/tsl/lib/io/zlib_compression_options.h" +#include "xla/tsl/lib/io/zlib_outputbuffer.h" +#include "xla/tsl/profiler/utils/file_system_utils.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/file_system.h" @@ -37,7 +38,6 @@ limitations under the License. #include "tsl/platform/status.h" #include "tsl/profiler/protobuf/profiler_service.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/file_system_utils.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/xla/tsl/profiler/rpc/profiler_service_impl.cc b/third_party/xla/xla/tsl/profiler/rpc/profiler_service_impl.cc index 3f4e3669174f5e..d359f0bdadb1fd 100644 --- a/third_party/xla/xla/tsl/profiler/rpc/profiler_service_impl.cc +++ b/third_party/xla/xla/tsl/profiler/rpc/profiler_service_impl.cc @@ -21,6 +21,10 @@ limitations under the License. #include "absl/strings/str_replace.h" #include "grpcpp/support/status.h" #include "xla/tsl/profiler/rpc/client/save_profile.h" +#include "xla/tsl/profiler/utils/file_system_utils.h" +#include "xla/tsl/profiler/utils/math_utils.h" +#include "xla/tsl/profiler/utils/time_utils.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tsl/platform/env.h" #include "tsl/platform/env_time.h" #include "tsl/platform/errors.h" @@ -32,10 +36,6 @@ limitations under the License. #include "tsl/profiler/protobuf/profiler_service.grpc.pb.h" #include "tsl/profiler/protobuf/profiler_service.pb.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/file_system_utils.h" -#include "tsl/profiler/utils/math_utils.h" -#include "tsl/profiler/utils/time_utils.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD b/third_party/xla/xla/tsl/profiler/utils/BUILD similarity index 66% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD rename to third_party/xla/xla/tsl/profiler/utils/BUILD index 8ea8fd71837272..281fca182068f9 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD +++ b/third_party/xla/xla/tsl/profiler/utils/BUILD @@ -1,14 +1,14 @@ +load("@local_tsl//tsl/platform:build_config.bzl", "tsl_cc_test") +load("@local_tsl//tsl/platform:build_config_root.bzl", "if_static") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") -load("@local_xla//xla/tsl:tsl.bzl", "internal_visibility") -load("@local_xla//xla/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") -load("@local_xla//xla/tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") -load("//tsl/platform:build_config.bzl", "tsl_cc_test") -load("//tsl/platform:build_config_root.bzl", "if_static") +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") +load("//xla/tsl/profiler/builds:build_config.bzl", "tf_profiler_copts") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = internal_visibility([ - "@local_xla//xla/tsl/profiler:internal", + "//xla/tsl/profiler:internal", ]), licenses = ["notice"], ) @@ -16,7 +16,7 @@ package( package_group( name = "friends", includes = [ - "@local_xla//xla/tsl/profiler:friends", + "//xla/tsl/profiler:friends", ], ) @@ -29,7 +29,7 @@ cc_library( name = "format_utils", hdrs = ["format_utils.h"], deps = [ - "//tsl/platform:logging", + "@local_tsl//tsl/platform:logging", ], ) @@ -53,9 +53,9 @@ cc_library( ], copts = tf_profiler_copts(), visibility = internal_visibility([ - "@local_xla//xla:__subpackages__", - "//tsl/platform/cloud:__pkg__", - "@local_xla//xla/tsl/profiler:internal", + "//xla:__subpackages__", + "@local_tsl//tsl/platform/cloud:__pkg__", + "//xla/tsl/profiler:internal", ]), deps = [ ":math_utils", @@ -71,9 +71,9 @@ cc_library( visibility = internal_visibility([":friends"]), deps = [ ":math_utils", - "//tsl/platform:logging", - "//tsl/platform:types", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:types", ], ) @@ -82,8 +82,8 @@ tsl_cc_test( srcs = ["timespan_test.cc"], deps = [ ":timespan", - "//tsl/platform:test", - "//tsl/platform:test_main", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -93,9 +93,9 @@ cc_library( hdrs = ["tf_op_utils.h"], copts = tf_profiler_copts(), deps = [ - "//tsl/platform:macros", - "//tsl/platform:regexp", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:regexp", ], ) @@ -105,9 +105,9 @@ tsl_cc_test( srcs = ["tf_op_utils_test.cc"], deps = [ ":tf_op_utils", - "//tsl/platform:test", - "//tsl/platform:test_main", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -119,15 +119,15 @@ cc_library( visibility = internal_visibility([":friends"]), deps = [ ":tf_op_utils", - "//tsl/lib/gtl:map_util", - "//tsl/platform:logging", - "//tsl/platform:macros", - "//tsl/platform:types", - "//tsl/profiler/lib:context_types_hdrs", + "//xla/tsl/lib/gtl:map_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:types", + "@local_tsl//tsl/profiler/lib:context_types_hdrs", ], ) @@ -139,12 +139,12 @@ cc_library( visibility = internal_visibility([":friends"]), deps = [ ":timespan", - "//tsl/platform:logging", - "//tsl/platform:types", - "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:types", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -157,14 +157,14 @@ cc_library( deps = [ ":math_utils", ":timespan", - "//tsl/platform:macros", - "//tsl/platform:protobuf", - "//tsl/platform:types", - "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/meta:type_traits", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:types", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -175,10 +175,10 @@ tsl_cc_test( deps = [ ":xplane_builder", ":xplane_visitor", - "//tsl/platform:test", - "//tsl/platform:test_main", - "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -187,12 +187,12 @@ cc_library( hdrs = ["trace_utils.h"], copts = tf_profiler_copts(), visibility = internal_visibility([ - "@local_xla//xla/backends/profiler/gpu:__pkg__", - "@local_xla//xla/tsl/profiler:internal", + "//xla/backends/profiler/gpu:__pkg__", + "//xla/tsl/profiler:internal", ]), deps = [ - "//tsl/platform:types", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:types", ], ) @@ -210,16 +210,16 @@ cc_library( ":xplane_builder", ":xplane_schema", ":xplane_visitor", - "//tsl/platform:fingerprint", - "//tsl/platform:types", - "//tsl/profiler/lib:context_types", - "//tsl/profiler/protobuf:xplane_proto_cc", + "//xla/tsl/util:stats_calculator_portable", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", - "@local_xla//xla/tsl/util:stats_calculator_portable", + "@local_tsl//tsl/platform:fingerprint", + "@local_tsl//tsl/platform:types", + "@local_tsl//tsl/profiler/lib:context_types", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -232,14 +232,14 @@ tsl_cc_test( ":xplane_schema", ":xplane_utils", ":xplane_visitor", - "//tsl/platform:test", - "//tsl/platform:test_main", - "//tsl/platform:types", - "//tsl/profiler/protobuf:xplane_proto_cc", - "//tsl/profiler/utils:tf_xplane_visitor", + "//xla/tsl/profiler/utils:tf_xplane_visitor", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/platform:types", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -251,7 +251,7 @@ cc_library( deps = [ ":xplane_schema", ":xplane_visitor", - "//tsl/profiler/protobuf:xplane_proto_cc", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -271,9 +271,9 @@ tsl_cc_test( srcs = ["parse_annotation_test.cc"], deps = [ ":parse_annotation", - "//tsl/platform:test", - "//tsl/platform:test_main", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -289,18 +289,18 @@ cc_library( ":xplane_schema", ":xplane_utils", ":xplane_visitor", - "//tsl/lib/gtl:map_util", - "//tsl/platform:dso_loader", - "//tsl/platform:env", - "//tsl/platform:logging", - "//tsl/platform:types", - "//tsl/profiler/protobuf:xplane_proto_cc", + "//xla/tsl/lib/gtl:map_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@local_tsl//tsl/platform:dso_loader", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:types", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -315,11 +315,11 @@ cc_library( ":xplane_builder", ":xplane_schema", ":xplane_utils", - "//tsl/platform:types", - "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:variant", + "@local_tsl//tsl/platform:types", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -333,13 +333,13 @@ tsl_cc_test( ":xplane_schema", ":xplane_test_utils", ":xplane_visitor", - "//tsl/platform:env_impl", - "//tsl/platform:test", - "//tsl/platform:test_main", - "//tsl/platform:types", - "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/platform:types", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -350,10 +350,10 @@ cc_library( deps = [ ":xplane_schema", ":xplane_utils", - "//tsl/platform:regexp", - "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@local_tsl//tsl/platform:regexp", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -365,10 +365,10 @@ tsl_cc_test( ":xplane_schema", ":xplane_utils", ":xplane_visitor", - "//tsl/platform:test", - "//tsl/platform:test_main", - "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -377,12 +377,12 @@ cc_library( hdrs = ["file_system_utils.h"], copts = tf_profiler_copts(), visibility = internal_visibility([ - "@local_xla//xla/python:__pkg__", - "@local_xla//xla/tsl/profiler:internal", + "//xla/python:__pkg__", + "//xla/tsl/profiler:internal", ]), deps = [ - "//tsl/platform", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform", ], ) @@ -392,14 +392,14 @@ cc_library( hdrs = ["buffer_pool.h"], copts = tf_profiler_copts(), visibility = internal_visibility([ - "@local_xla//xla/backends/profiler/gpu:__pkg__", - "@local_xla//xla/tsl/profiler:internal", + "//xla/backends/profiler/gpu:__pkg__", + "//xla/tsl/profiler:internal", ]), deps = [ - "//tsl/platform:logging", - "//tsl/platform:mutex", - "//tsl/platform:platform_port", - "//tsl/platform:thread_annotations", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:mutex", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:thread_annotations", ], ) @@ -408,8 +408,8 @@ tsl_cc_test( srcs = ["buffer_pool_test.cc"], deps = [ ":buffer_pool", - "//tsl/platform:test", - "//tsl/platform:test_main", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -425,14 +425,14 @@ cc_library( ":xplane_builder", ":xplane_mutators", ":xplane_schema", - "//tsl/profiler/lib:context_types", - "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@local_tsl//tsl/profiler/lib:context_types", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -446,12 +446,12 @@ tsl_cc_test( ":xplane_schema", ":xplane_test_utils", ":xplane_visitor", - "//tsl/platform:test", - "//tsl/platform:test_main", - "//tsl/profiler/lib:connected_traceme", - "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/profiler/lib:connected_traceme", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -460,11 +460,11 @@ cc_library( srcs = ["session_manager.cc"], hdrs = ["session_manager.h"], deps = [ - "//tsl/platform:errors", - "//tsl/platform:status", - "//tsl/profiler/lib:profiler_session", - "//tsl/profiler/protobuf:profiler_options_proto_cc", "@com_google_absl//absl/container:flat_hash_map", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/profiler/lib:profiler_session", + "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc", ], ) @@ -476,8 +476,8 @@ cc_library( ":xplane_builder", ":xplane_schema", ":xplane_utils", - "//tsl/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/log", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -489,8 +489,8 @@ tsl_cc_test( ":xplane_schema", ":xplane_utils", ":xplane_visitor", - "//tsl/platform:test", - "//tsl/platform:test_main", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -516,8 +516,8 @@ cc_library( hdrs = ["lock_free_queue.h"], deps = [ ":no_init", - "//tsl/platform:logging", - "//tsl/platform:macros", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:macros", ], ) @@ -528,11 +528,11 @@ tsl_cc_test( srcs = ["lock_free_queue_test.cc"], deps = [ ":lock_free_queue", - "//tsl/platform:env", - "//tsl/platform:env_impl", - "//tsl/platform:test", - "//tsl/platform:test_main", "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) @@ -559,12 +559,36 @@ tsl_cc_test( srcs = ["per_thread_test.cc"], deps = [ ":per_thread", - "//tsl/platform:env", - "//tsl/platform:env_impl", - "//tsl/platform:test", - "//tsl/platform:test_main", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_impl", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "device_utils", + srcs = ["device_utils.cc"], + hdrs = ["device_utils.h"], + deps = [ + ":xplane_schema", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", + ], +) + +tsl_cc_test( + name = "device_utils_test", + srcs = ["device_utils_test.cc"], + deps = [ + ":device_utils", + ":xplane_schema", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/buffer_pool.cc b/third_party/xla/xla/tsl/profiler/utils/buffer_pool.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/buffer_pool.cc rename to third_party/xla/xla/tsl/profiler/utils/buffer_pool.cc index e7811327c475b2..f16fe91d573a8b 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/buffer_pool.cc +++ b/third_party/xla/xla/tsl/profiler/utils/buffer_pool.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/buffer_pool.h" +#include "xla/tsl/profiler/utils/buffer_pool.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/buffer_pool.h b/third_party/xla/xla/tsl/profiler/utils/buffer_pool.h similarity index 92% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/buffer_pool.h rename to third_party/xla/xla/tsl/profiler/utils/buffer_pool.h index dcfd5b0acb6a1b..5482b7cd8bc261 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/buffer_pool.h +++ b/third_party/xla/xla/tsl/profiler/utils/buffer_pool.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_BUFFER_POOL_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_BUFFER_POOL_H_ +#ifndef XLA_TSL_PROFILER_UTILS_BUFFER_POOL_H_ +#define XLA_TSL_PROFILER_UTILS_BUFFER_POOL_H_ #include @@ -59,4 +59,4 @@ class BufferPool { } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_BUFFER_POOL_H_ +#endif // XLA_TSL_PROFILER_UTILS_BUFFER_POOL_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/buffer_pool_test.cc b/third_party/xla/xla/tsl/profiler/utils/buffer_pool_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/buffer_pool_test.cc rename to third_party/xla/xla/tsl/profiler/utils/buffer_pool_test.cc index ec1696a12e08b7..4e5dbab63085de 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/buffer_pool_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/buffer_pool_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/buffer_pool.h" +#include "xla/tsl/profiler/utils/buffer_pool.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/tsl/profiler/utils/device_utils.cc b/third_party/xla/xla/tsl/profiler/utils/device_utils.cc new file mode 100644 index 00000000000000..945a157ec9c456 --- /dev/null +++ b/third_party/xla/xla/tsl/profiler/utils/device_utils.cc @@ -0,0 +1,37 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/tsl/profiler/utils/device_utils.h" + +#include "absl/strings/match.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "tsl/profiler/protobuf/xplane.pb.h" + +namespace tsl { +namespace profiler { + +DeviceType GetDeviceType(const tensorflow::profiler::XPlane& plane) { + if (plane.name() == kHostThreadsPlaneName) { + return DeviceType::kCpu; + } else if (absl::StartsWith(plane.name(), kTpuPlanePrefix)) { + return DeviceType::kTpu; + } else if (absl::StartsWith(plane.name(), kGpuPlanePrefix)) { + return DeviceType::kGpu; + } else { + return DeviceType::kUnknown; + } +} +} // namespace profiler +} // namespace tsl diff --git a/third_party/xla/xla/tsl/profiler/utils/device_utils.h b/third_party/xla/xla/tsl/profiler/utils/device_utils.h new file mode 100644 index 00000000000000..825a9fe975437d --- /dev/null +++ b/third_party/xla/xla/tsl/profiler/utils/device_utils.h @@ -0,0 +1,37 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TSL_PROFILER_UTILS_DEVICE_UTILS_H_ +#define XLA_TSL_PROFILER_UTILS_DEVICE_UTILS_H_ + +#include "tsl/profiler/protobuf/xplane.pb.h" + +namespace tsl { +namespace profiler { + +enum class DeviceType { + kUnknown, + kCpu, + kTpu, + kGpu, +}; + +// Get DeviceType from XPlane. +DeviceType GetDeviceType(const tensorflow::profiler::XPlane& plane); + +} // namespace profiler +} // namespace tsl + +#endif // XLA_TSL_PROFILER_UTILS_DEVICE_UTILS_H_ diff --git a/third_party/xla/xla/tsl/profiler/utils/device_utils_test.cc b/third_party/xla/xla/tsl/profiler/utils/device_utils_test.cc new file mode 100644 index 00000000000000..6f872dc5713bed --- /dev/null +++ b/third_party/xla/xla/tsl/profiler/utils/device_utils_test.cc @@ -0,0 +1,45 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/tsl/profiler/utils/device_utils.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "tsl/platform/test.h" + +namespace tsl { +namespace profiler { +namespace { + +tensorflow::profiler::XPlane CreateXPlane(absl::string_view name) { + tensorflow::profiler::XPlane plane; + plane.set_name(name.data(), name.size()); + return plane; +} + +TEST(DeviceUtilsTest, GetDeviceType) { + EXPECT_EQ(GetDeviceType(CreateXPlane(kHostThreadsPlaneName)), + DeviceType::kCpu); + EXPECT_EQ(GetDeviceType(CreateXPlane(absl::StrCat(kTpuPlanePrefix, 0))), + DeviceType::kTpu); + EXPECT_EQ(GetDeviceType(CreateXPlane(absl::StrCat(kGpuPlanePrefix, 0))), + DeviceType::kGpu); + EXPECT_EQ(GetDeviceType(CreateXPlane("unknown")), DeviceType::kUnknown); +} + +} // namespace +} // namespace profiler +} // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/file_system_utils.h b/third_party/xla/xla/tsl/profiler/utils/file_system_utils.h similarity index 91% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/file_system_utils.h rename to third_party/xla/xla/tsl/profiler/utils/file_system_utils.h index 6d7c937908f43c..522b5284afec49 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/file_system_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/file_system_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_FILE_SYSTEM_UTILS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_FILE_SYSTEM_UTILS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_FILE_SYSTEM_UTILS_H_ +#define XLA_TSL_PROFILER_UTILS_FILE_SYSTEM_UTILS_H_ #include #include @@ -66,4 +66,4 @@ std::string ProfilerJoinPath(const T&... args) { } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_FILE_SYSTEM_UTILS_H_ +#endif // XLA_TSL_PROFILER_UTILS_FILE_SYSTEM_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/format_utils.h b/third_party/xla/xla/tsl/profiler/utils/format_utils.h similarity index 91% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/format_utils.h rename to third_party/xla/xla/tsl/profiler/utils/format_utils.h index 4a1be939de9ccd..d93d69e8592d70 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/format_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/format_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_FORMAT_UTILS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_FORMAT_UTILS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_FORMAT_UTILS_H_ +#define XLA_TSL_PROFILER_UTILS_FORMAT_UTILS_H_ #include @@ -60,4 +60,4 @@ inline std::string MaxPrecision(double d) { } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_FORMAT_UTILS_H_ +#endif // XLA_TSL_PROFILER_UTILS_FORMAT_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/group_events.cc b/third_party/xla/xla/tsl/profiler/utils/group_events.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/group_events.cc rename to third_party/xla/xla/tsl/profiler/utils/group_events.cc index 2811232b43bc5e..393e170b839446 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/group_events.cc +++ b/third_party/xla/xla/tsl/profiler/utils/group_events.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/group_events.h" +#include "xla/tsl/profiler/utils/group_events.h" #include #include @@ -31,15 +31,15 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/functional/bind_front.h" #include "absl/strings/str_cat.h" -#include "tsl/lib/gtl/map_util.h" +#include "xla/tsl/lib/gtl/map_util.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/dso_loader.h" #include "tsl/platform/env.h" #include "tsl/platform/types.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/group_events.h b/third_party/xla/xla/tsl/profiler/utils/group_events.h similarity index 97% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/group_events.h rename to third_party/xla/xla/tsl/profiler/utils/group_events.h index c77c32a623ea08..52a73529fb734c 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/group_events.h +++ b/third_party/xla/xla/tsl/profiler/utils/group_events.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_GROUP_EVENTS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_GROUP_EVENTS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_GROUP_EVENTS_H_ +#define XLA_TSL_PROFILER_UTILS_GROUP_EVENTS_H_ #include #include @@ -28,11 +28,11 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/logging.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { @@ -258,4 +258,4 @@ void GroupTpuEventsOSS( } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_GROUP_EVENTS_H_ +#endif // XLA_TSL_PROFILER_UTILS_GROUP_EVENTS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/group_events_test.cc b/third_party/xla/xla/tsl/profiler/utils/group_events_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/group_events_test.cc rename to third_party/xla/xla/tsl/profiler/utils/group_events_test.cc index e3607626263004..e8c3306ee4ea3d 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/group_events_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/group_events_test.cc @@ -13,20 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/group_events.h" +#include "xla/tsl/profiler/utils/group_events.h" #include #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_test_utils.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/test.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_test_utils.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/lock_free_queue.h b/third_party/xla/xla/tsl/profiler/utils/lock_free_queue.h similarity index 97% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/lock_free_queue.h rename to third_party/xla/xla/tsl/profiler/utils/lock_free_queue.h index d8f197be6c314b..9f22aa8b8e5094 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/lock_free_queue.h +++ b/third_party/xla/xla/tsl/profiler/utils/lock_free_queue.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_LOCK_FREE_QUEUE_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_LOCK_FREE_QUEUE_H_ +#ifndef XLA_TSL_PROFILER_UTILS_LOCK_FREE_QUEUE_H_ +#define XLA_TSL_PROFILER_UTILS_LOCK_FREE_QUEUE_H_ #include @@ -23,9 +23,9 @@ limitations under the License. #include #include +#include "xla/tsl/profiler/utils/no_init.h" #include "tsl/platform/logging.h" #include "tsl/platform/macros.h" -#include "tsl/profiler/utils/no_init.h" namespace tsl { namespace profiler { @@ -311,4 +311,4 @@ class LockFreeQueue final } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_LOCK_FREE_QUEUE_H_ +#endif // XLA_TSL_PROFILER_UTILS_LOCK_FREE_QUEUE_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/lock_free_queue_test.cc b/third_party/xla/xla/tsl/profiler/utils/lock_free_queue_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/lock_free_queue_test.cc rename to third_party/xla/xla/tsl/profiler/utils/lock_free_queue_test.cc index 78a8b07d7bdc20..fd8ccdfb659207 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/lock_free_queue_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/lock_free_queue_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/lock_free_queue.h" +#include "xla/tsl/profiler/utils/lock_free_queue.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/math_utils.h b/third_party/xla/xla/tsl/profiler/utils/math_utils.h similarity index 94% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/math_utils.h rename to third_party/xla/xla/tsl/profiler/utils/math_utils.h index 06b3495576d887..cd9e8685e8c35c 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/math_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/math_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_MATH_UTILS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_MATH_UTILS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_MATH_UTILS_H_ +#define XLA_TSL_PROFILER_UTILS_MATH_UTILS_H_ #include @@ -70,4 +70,4 @@ inline double GibibytesPerSecond(double gigabytes, double ns) { } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_MATH_UTILS_H_ +#endif // XLA_TSL_PROFILER_UTILS_MATH_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/no_init.h b/third_party/xla/xla/tsl/profiler/utils/no_init.h similarity index 89% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/no_init.h rename to third_party/xla/xla/tsl/profiler/utils/no_init.h index 5beb1908380c90..6f7d6aa95d1b25 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/no_init.h +++ b/third_party/xla/xla/tsl/profiler/utils/no_init.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_NO_INIT_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_NO_INIT_H_ +#ifndef XLA_TSL_PROFILER_UTILS_NO_INIT_H_ +#define XLA_TSL_PROFILER_UTILS_NO_INIT_H_ #include @@ -48,4 +48,4 @@ union NoInit { } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_NO_INIT_H_ +#endif // XLA_TSL_PROFILER_UTILS_NO_INIT_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/parse_annotation.cc b/third_party/xla/xla/tsl/profiler/utils/parse_annotation.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/parse_annotation.cc rename to third_party/xla/xla/tsl/profiler/utils/parse_annotation.cc index 986f9b08e65a32..67328c1ea6e9bc 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/parse_annotation.cc +++ b/third_party/xla/xla/tsl/profiler/utils/parse_annotation.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/parse_annotation.h" +#include "xla/tsl/profiler/utils/parse_annotation.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/parse_annotation.h b/third_party/xla/xla/tsl/profiler/utils/parse_annotation.h similarity index 89% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/parse_annotation.h rename to third_party/xla/xla/tsl/profiler/utils/parse_annotation.h index 1552a2f271140b..8d755f7e64fb4c 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/parse_annotation.h +++ b/third_party/xla/xla/tsl/profiler/utils/parse_annotation.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_PARSE_ANNOTATION_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_PARSE_ANNOTATION_H_ +#ifndef XLA_TSL_PROFILER_UTILS_PARSE_ANNOTATION_H_ +#define XLA_TSL_PROFILER_UTILS_PARSE_ANNOTATION_H_ #include @@ -48,4 +48,4 @@ std::vector ParseAnnotationStack( } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_PARSE_ANNOTATION_H_ +#endif // XLA_TSL_PROFILER_UTILS_PARSE_ANNOTATION_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/parse_annotation_test.cc b/third_party/xla/xla/tsl/profiler/utils/parse_annotation_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/parse_annotation_test.cc rename to third_party/xla/xla/tsl/profiler/utils/parse_annotation_test.cc index 93730c851d4311..6225916ef96cfc 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/parse_annotation_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/parse_annotation_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/parse_annotation.h" +#include "xla/tsl/profiler/utils/parse_annotation.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/per_thread.h b/third_party/xla/xla/tsl/profiler/utils/per_thread.h similarity index 96% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/per_thread.h rename to third_party/xla/xla/tsl/profiler/utils/per_thread.h index 3163fd890c1c7f..f3e9d79242ceab 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/per_thread.h +++ b/third_party/xla/xla/tsl/profiler/utils/per_thread.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_PER_THREAD_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_PER_THREAD_H_ +#ifndef XLA_TSL_PROFILER_UTILS_PER_THREAD_H_ +#define XLA_TSL_PROFILER_UTILS_PER_THREAD_H_ #include #include @@ -143,4 +143,4 @@ class PerThread { } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_PER_THREAD_H_ +#endif // XLA_TSL_PROFILER_UTILS_PER_THREAD_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/per_thread_test.cc b/third_party/xla/xla/tsl/profiler/utils/per_thread_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/per_thread_test.cc rename to third_party/xla/xla/tsl/profiler/utils/per_thread_test.cc index 15af47c3195a47..9007319c4d0c74 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/per_thread_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/per_thread_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/per_thread.h" +#include "xla/tsl/profiler/utils/per_thread.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/preprocess_xplane.cc b/third_party/xla/xla/tsl/profiler/utils/preprocess_xplane.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/preprocess_xplane.cc rename to third_party/xla/xla/tsl/profiler/utils/preprocess_xplane.cc index 3d06a05609a118..9925276dacfdde 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/preprocess_xplane.cc +++ b/third_party/xla/xla/tsl/profiler/utils/preprocess_xplane.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/preprocess_xplane.h" +#include "xla/tsl/profiler/utils/preprocess_xplane.h" #include #include @@ -21,10 +21,10 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tsl/profiler/lib/context_types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/preprocess_xplane.h b/third_party/xla/xla/tsl/profiler/utils/preprocess_xplane.h similarity index 98% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/preprocess_xplane.h rename to third_party/xla/xla/tsl/profiler/utils/preprocess_xplane.h index 9d027c780c9bfb..c64a6d02417e48 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/preprocess_xplane.h +++ b/third_party/xla/xla/tsl/profiler/utils/preprocess_xplane.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_PREPROCESS_XPLANE_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_PREPROCESS_XPLANE_H_ +#ifndef XLA_TSL_PROFILER_UTILS_PREPROCESS_XPLANE_H_ +#define XLA_TSL_PROFILER_UTILS_PREPROCESS_XPLANE_H_ #include #include @@ -31,13 +31,13 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" +#include "xla/tsl/profiler/utils/trace_utils.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_mutators.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tsl/profiler/lib/context_types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/tpu_xplane_utils.h" -#include "tsl/profiler/utils/trace_utils.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_mutators.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tsl { namespace profiler { @@ -533,4 +533,4 @@ void PreprocessXPlane(XPlane* plane); } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_PREPROCESS_XPLANE_H_ +#endif // XLA_TSL_PROFILER_UTILS_PREPROCESS_XPLANE_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/preprocess_xplane_test.cc b/third_party/xla/xla/tsl/profiler/utils/preprocess_xplane_test.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/preprocess_xplane_test.cc rename to third_party/xla/xla/tsl/profiler/utils/preprocess_xplane_test.cc index 9712893645090c..d18d6452a6a85d 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/preprocess_xplane_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/preprocess_xplane_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/preprocess_xplane.h" +#include "xla/tsl/profiler/utils/preprocess_xplane.h" #include #include @@ -21,14 +21,14 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/hash/hash.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_test_utils.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/test.h" #include "tsl/profiler/lib/connected_traceme.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_test_utils.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/session_manager.cc b/third_party/xla/xla/tsl/profiler/utils/session_manager.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/session_manager.cc rename to third_party/xla/xla/tsl/profiler/utils/session_manager.cc index 7fd31dab970104..d45b6edd83efba 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/session_manager.cc +++ b/third_party/xla/xla/tsl/profiler/utils/session_manager.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/session_manager.h" +#include "xla/tsl/profiler/utils/session_manager.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/session_manager.h b/third_party/xla/xla/tsl/profiler/utils/session_manager.h similarity index 91% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/session_manager.h rename to third_party/xla/xla/tsl/profiler/utils/session_manager.h index 9a6a6300cef51f..fd8c60cbc63d13 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/session_manager.h +++ b/third_party/xla/xla/tsl/profiler/utils/session_manager.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_SESSION_MANAGER_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_SESSION_MANAGER_H_ +#ifndef XLA_TSL_PROFILER_UTILS_SESSION_MANAGER_H_ +#define XLA_TSL_PROFILER_UTILS_SESSION_MANAGER_H_ #include #include @@ -52,4 +52,4 @@ absl::Status ValidateHostPortPair(absl::string_view host_port); } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_SESSION_MANAGER_H_ +#endif // XLA_TSL_PROFILER_UTILS_SESSION_MANAGER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tf_op_utils.cc b/third_party/xla/xla/tsl/profiler/utils/tf_op_utils.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/tf_op_utils.cc rename to third_party/xla/xla/tsl/profiler/utils/tf_op_utils.cc index 4129e2ae8fa7c7..981de3f3141d32 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tf_op_utils.cc +++ b/third_party/xla/xla/tsl/profiler/utils/tf_op_utils.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/tf_op_utils.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tf_op_utils.h b/third_party/xla/xla/tsl/profiler/utils/tf_op_utils.h similarity index 97% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/tf_op_utils.h rename to third_party/xla/xla/tsl/profiler/utils/tf_op_utils.h index 85230331cc6e06..078d4d7c3b6f9c 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tf_op_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/tf_op_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_TF_OP_UTILS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_TF_OP_UTILS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_TF_OP_UTILS_H_ +#define XLA_TSL_PROFILER_UTILS_TF_OP_UTILS_H_ #include #include @@ -151,4 +151,4 @@ bool IsJaxOpNameAndType(absl::string_view op_name, absl::string_view op_type); } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_TF_OP_UTILS_H_ +#endif // XLA_TSL_PROFILER_UTILS_TF_OP_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tf_op_utils_test.cc b/third_party/xla/xla/tsl/profiler/utils/tf_op_utils_test.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/tf_op_utils_test.cc rename to third_party/xla/xla/tsl/profiler/utils/tf_op_utils_test.cc index d03be20b197ca6..aef2bbc686f4d8 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tf_op_utils_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/tf_op_utils_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/tf_op_utils.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tf_xplane_visitor.h b/third_party/xla/xla/tsl/profiler/utils/tf_xplane_visitor.h similarity index 78% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/tf_xplane_visitor.h rename to third_party/xla/xla/tsl/profiler/utils/tf_xplane_visitor.h index 59dbbcc100fbfc..f902562935743b 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tf_xplane_visitor.h +++ b/third_party/xla/xla/tsl/profiler/utils/tf_xplane_visitor.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_TF_XPLANE_VISITOR_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_TF_XPLANE_VISITOR_H_ +#ifndef XLA_TSL_PROFILER_UTILS_TF_XPLANE_VISITOR_H_ +#define XLA_TSL_PROFILER_UTILS_TF_XPLANE_VISITOR_H_ +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { @@ -32,4 +32,4 @@ inline XPlaneVisitor CreateTfXPlaneVisitor( } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_TF_XPLANE_VISITOR_H_ +#endif // XLA_TSL_PROFILER_UTILS_TF_XPLANE_VISITOR_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/time_utils.cc b/third_party/xla/xla/tsl/profiler/utils/time_utils.cc similarity index 96% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/time_utils.cc rename to third_party/xla/xla/tsl/profiler/utils/time_utils.cc index 03d9973df0562f..a101ec2335070a 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/time_utils.cc +++ b/third_party/xla/xla/tsl/profiler/utils/time_utils.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/time_utils.h" +#include "xla/tsl/profiler/utils/time_utils.h" #include "absl/time/clock.h" #include "absl/time/time.h" diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/time_utils.h b/third_party/xla/xla/tsl/profiler/utils/time_utils.h similarity index 87% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/time_utils.h rename to third_party/xla/xla/tsl/profiler/utils/time_utils.h index 3cd30214f49975..65c12c70005b30 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/time_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/time_utils.h @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_TIME_UTILS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_TIME_UTILS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_TIME_UTILS_H_ +#define XLA_TSL_PROFILER_UTILS_TIME_UTILS_H_ #include -#include "tsl/profiler/utils/math_utils.h" +#include "xla/tsl/profiler/utils/math_utils.h" namespace tsl { namespace profiler { @@ -40,4 +40,4 @@ inline void SpinForMicros(int64_t us) { SpinForNanos(us * 1000); } } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_TIME_UTILS_H_ +#endif // XLA_TSL_PROFILER_UTILS_TIME_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/timespan.h b/third_party/xla/xla/tsl/profiler/utils/timespan.h similarity index 96% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/timespan.h rename to third_party/xla/xla/tsl/profiler/utils/timespan.h index ee4c1d646ab3a3..d1883b8566a6ae 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/timespan.h +++ b/third_party/xla/xla/tsl/profiler/utils/timespan.h @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_TIMESPAN_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_TIMESPAN_H_ +#ifndef XLA_TSL_PROFILER_UTILS_TIMESPAN_H_ +#define XLA_TSL_PROFILER_UTILS_TIMESPAN_H_ #include #include #include "absl/strings/str_cat.h" +#include "xla/tsl/profiler/utils/math_utils.h" #include "tsl/platform/logging.h" #include "tsl/platform/types.h" -#include "tsl/profiler/utils/math_utils.h" namespace tsl { namespace profiler { @@ -131,4 +131,4 @@ inline Timespan MilliSpan(double start_ms, double end_ms) { } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_TIMESPAN_H_ +#endif // XLA_TSL_PROFILER_UTILS_TIMESPAN_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/timespan_test.cc b/third_party/xla/xla/tsl/profiler/utils/timespan_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/timespan_test.cc rename to third_party/xla/xla/tsl/profiler/utils/timespan_test.cc index f729f088b14222..57d7876365c904 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/timespan_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/timespan_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/timespan.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils.cc b/third_party/xla/xla/tsl/profiler/utils/timestamp_utils.cc similarity index 89% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils.cc rename to third_party/xla/xla/tsl/profiler/utils/timestamp_utils.cc index ea208ed309c468..17b728f7f9cad1 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils.cc +++ b/third_party/xla/xla/tsl/profiler/utils/timestamp_utils.cc @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/timestamp_utils.h" +#include "xla/tsl/profiler/utils/timestamp_utils.h" #include #include "absl/log/log.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils.h b/third_party/xla/xla/tsl/profiler/utils/timestamp_utils.h similarity index 86% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils.h rename to third_party/xla/xla/tsl/profiler/utils/timestamp_utils.h index 87013c97a6f5b0..a2b61672fbabba 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/timestamp_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_TIMESTAMP_UTILS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_TIMESTAMP_UTILS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_TIMESTAMP_UTILS_H_ +#define XLA_TSL_PROFILER_UTILS_TIMESTAMP_UTILS_H_ #include @@ -30,4 +30,4 @@ void SetSessionTimestamps(uint64_t start_walltime_ns, uint64_t stop_walltime_ns, } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_TIMESTAMP_UTILS_H_ +#endif // XLA_TSL_PROFILER_UTILS_TIMESTAMP_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils_test.cc b/third_party/xla/xla/tsl/profiler/utils/timestamp_utils_test.cc similarity index 86% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils_test.cc rename to third_party/xla/xla/tsl/profiler/utils/timestamp_utils_test.cc index 893e31ebb5ec59..dd2e434adbc0f3 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/timestamp_utils_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/timestamp_utils_test.cc @@ -13,12 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/timestamp_utils.h" +#include "xla/tsl/profiler/utils/timestamp_utils.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/test.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.cc b/third_party/xla/xla/tsl/profiler/utils/tpu_xplane_utils.cc similarity index 92% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.cc rename to third_party/xla/xla/tsl/profiler/utils/tpu_xplane_utils.cc index 9274a1da941743..d456164be18a46 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.cc +++ b/third_party/xla/xla/tsl/profiler/utils/tpu_xplane_utils.cc @@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/tpu_xplane_utils.h" +#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" #include #include #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tsl/platform/regexp.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.h b/third_party/xla/xla/tsl/profiler/utils/tpu_xplane_utils.h similarity index 89% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.h rename to third_party/xla/xla/tsl/profiler/utils/tpu_xplane_utils.h index 2fb7c677e3a058..3f6adb498cd270 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/tpu_xplane_utils.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_TPU_XPLANE_UTILS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_TPU_XPLANE_UTILS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_TPU_XPLANE_UTILS_H_ +#define XLA_TSL_PROFILER_UTILS_TPU_XPLANE_UTILS_H_ #include #include @@ -43,4 +43,4 @@ std::optional GetSparseCoreId(absl::string_view plane_name); } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_TPU_XPLANE_UTILS_H_ +#endif // XLA_TSL_PROFILER_UTILS_TPU_XPLANE_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils_test.cc b/third_party/xla/xla/tsl/profiler/utils/tpu_xplane_utils_test.cc similarity index 93% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils_test.cc rename to third_party/xla/xla/tsl/profiler/utils/tpu_xplane_utils_test.cc index e5bcd73c339be9..fc341c98582cc9 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/tpu_xplane_utils_test.cc @@ -12,16 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/tpu_xplane_utils.h" +#include "xla/tsl/profiler/utils/tpu_xplane_utils.h" #include #include "absl/strings/str_cat.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/test.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/trace_utils.h b/third_party/xla/xla/tsl/profiler/utils/trace_utils.h similarity index 95% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/trace_utils.h rename to third_party/xla/xla/tsl/profiler/utils/trace_utils.h index 98e18973a5ea79..ef53e611ab95fa 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/trace_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/trace_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_TRACE_UTILS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_TRACE_UTILS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_TRACE_UTILS_H_ +#define XLA_TSL_PROFILER_UTILS_TRACE_UTILS_H_ #include @@ -81,4 +81,4 @@ static inline std::optional ParseDeviceOrdinal( } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_TRACE_UTILS_H_ +#endif // XLA_TSL_PROFILER_UTILS_TRACE_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_builder.cc b/third_party/xla/xla/tsl/profiler/utils/xplane_builder.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_builder.cc rename to third_party/xla/xla/tsl/profiler/utils/xplane_builder.cc index 6def9d1c768d47..ebcc5884ffe3e7 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_builder.cc +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_builder.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" #include #include @@ -22,10 +22,10 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/math_utils.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/math_utils.h" -#include "tsl/profiler/utils/timespan.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_builder.h b/third_party/xla/xla/tsl/profiler/utils/xplane_builder.h similarity index 98% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_builder.h rename to third_party/xla/xla/tsl/profiler/utils/xplane_builder.h index 4a837c86ac00aa..522b34612d48bc 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_builder.h +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_builder.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_BUILDER_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_BUILDER_H_ +#ifndef XLA_TSL_PROFILER_UTILS_XPLANE_BUILDER_H_ +#define XLA_TSL_PROFILER_UTILS_XPLANE_BUILDER_H_ #include @@ -27,12 +27,12 @@ limitations under the License. #include "absl/strings/numbers.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/math_utils.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tsl/platform/macros.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/math_utils.h" -#include "tsl/profiler/utils/timespan.h" namespace tsl { namespace profiler { @@ -454,4 +454,4 @@ absl::string_view XStatsBuilder::StrOrRefValue(const XStat& stat) { } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_BUILDER_H_ +#endif // XLA_TSL_PROFILER_UTILS_XPLANE_BUILDER_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_builder_test.cc b/third_party/xla/xla/tsl/profiler/utils/xplane_builder_test.cc similarity index 97% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_builder_test.cc rename to third_party/xla/xla/tsl/profiler/utils/xplane_builder_test.cc index 87c433773472b6..ee2c8e4df0400b 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_builder_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_builder_test.cc @@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" #include #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/test.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_mutators.h b/third_party/xla/xla/tsl/profiler/utils/xplane_mutators.h similarity index 88% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_mutators.h rename to third_party/xla/xla/tsl/profiler/utils/xplane_mutators.h index e558936b2130e6..0873aee9a41f3a 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_mutators.h +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_mutators.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_MUTATORS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_MUTATORS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_XPLANE_MUTATORS_H_ +#define XLA_TSL_PROFILER_UTILS_XPLANE_MUTATORS_H_ #include #include -#include "tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" namespace tsl { namespace profiler { @@ -60,4 +60,4 @@ class XplaneEventMutatorFactory { } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_MUTATORS_H_ +#endif // XLA_TSL_PROFILER_UTILS_XPLANE_MUTATORS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc b/third_party/xla/xla/tsl/profiler/utils/xplane_schema.cc similarity index 74% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc rename to third_party/xla/xla/tsl/profiler/utils/xplane_schema.cc index da16a8704187be..a3f85cb7f2dd78 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_schema.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include #include @@ -22,8 +22,8 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" -#include "tsl/lib/gtl/map_util.h" -#include "tsl/profiler/utils/tf_op_utils.h" +#include "xla/tsl/lib/gtl/map_util.h" +#include "xla/tsl/profiler/utils/tf_op_utils.h" namespace tsl { namespace profiler { @@ -49,6 +49,7 @@ const absl::string_view kHostCpusPlaneName = "Host CPUs"; const absl::string_view kSyscallsPlaneName = "Syscalls"; const absl::string_view kStepLineName = "Steps"; +const absl::string_view kSparseCoreStepLineName = "Sparse Core Steps"; const absl::string_view kTensorFlowNameScopeLineName = "Framework Name Scope"; const absl::string_view kTensorFlowOpLineName = "Framework Ops"; const absl::string_view kXlaModuleLineName = "XLA Modules"; @@ -200,147 +201,151 @@ const HostEventTypeMap& GetHostEventTypeMap() { } const StatTypeMap& GetStatTypeMap() { - static auto* stat_type_map = new StatTypeMap({ - {"UnknownStatType", kUnknownStatType}, - // TraceMe arguments. - {"id", kStepId}, - {"device_ordinal", kDeviceOrdinal}, - {"chip_ordinal", kChipOrdinal}, - {"node_ordinal", kNodeOrdinal}, - {"model_id", kModelId}, - {"queue_addr", kQueueAddr}, - {"queue_id", kQueueId}, - {"request_id", kRequestId}, - {"run_id", kRunId}, - {"replica_id", kReplicaId}, - {"graph_type", kGraphType}, - {"step_num", kStepNum}, - {"iter_num", kIterNum}, - {"index_on_host", kIndexOnHost}, - {"allocator_name", kAllocatorName}, - {"bytes_reserved", kBytesReserved}, - {"bytes_allocated", kBytesAllocated}, - {"bytes_available", kBytesAvailable}, - {"fragmentation", kFragmentation}, - {"peak_bytes_in_use", kPeakBytesInUse}, - {"requested_bytes", kRequestedBytes}, - {"allocation_bytes", kAllocationBytes}, - {"addr", kAddress}, - {"region_type", kRegionType}, - {"data_type", kDataType}, - {"shape", kTensorShapes}, - {"layout", kTensorLayout}, - {"kpi_name", kKpiName}, - {"kpi_value", kKpiValue}, - {"element_id", kElementId}, - {"parent_id", kParentId}, - {"core_type", kCoreType}, - // XPlane semantics related. - {"_pt", kProducerType}, - {"_ct", kConsumerType}, - {"_p", kProducerId}, - {"_c", kConsumerId}, - {"_r", kIsRoot}, - {"_a", kIsAsync}, - // Device trace arguments. - {"device_id", kDeviceId}, - {"device_type_string", kDeviceTypeString}, - {"context_id", kContextId}, - {"correlation_id", kCorrelationId}, - {"memcpy_details", kMemcpyDetails}, - {"memalloc_details", kMemallocDetails}, - {"MemFree_details", kMemFreeDetails}, - {"Memset_details", kMemsetDetails}, - {"MemoryResidency_details", kMemoryResidencyDetails}, - {"kernel_details", kKernelDetails}, - {"nvtx_range", kNVTXRange}, - {"stream", kStream}, - // Stats added when processing traces. - {"group_id", kGroupId}, - {"flow", kFlow}, - {"step_name", kStepName}, - {"tf_op", kTfOp}, - {"hlo_op", kHloOp}, - {"deduplicated_name", kDeduplicatedName}, - {"hlo_category", kHloCategory}, - {"hlo_module", kHloModule}, - {"program_id", kProgramId}, - {"equation", kEquation}, - {"is_eager", kIsEager}, - {"is_func", kIsFunc}, - {"tf_function_call", kTfFunctionCall}, - {"tracing_count", kTfFunctionTracingCount}, - {"flops", kFlops}, - {"model_flops", kModelFlops}, - {"bytes_accessed", kBytesAccessed}, - {"memory_access_breakdown", kMemoryAccessBreakdown}, - {"source", kSourceInfo}, - {"model_name", kModelName}, - {"model_version", kModelVersion}, - {"bytes_transferred", kBytesTransferred}, - {"queue", kDmaQueue}, - {"dcn_collective_info", kDcnCollectiveInfo}, - // Performance counter related. - {"Raw Value", kRawValue}, - {"Scaled Value", kScaledValue}, - {"Thread Id", kThreadId}, - {"matrix_unit_utilization_percent", kMatrixUnitUtilizationPercent}, - // XLA metadata map related. - {"Hlo Proto", kHloProto}, - {"EdgeTPU Model information", kEdgeTpuModelInfo}, - {"EdgeTPU Model Profile information", kEdgeTpuModelProfileInfo}, - {"EdgeTPU MLIR", kEdgeTpuMlir}, - // Device capability related. - {"clock_rate", kDevCapClockRateKHz}, - {"core_count", kDevCapCoreCount}, - {"memory_bandwidth", kDevCapMemoryBandwidth}, - {"memory_size", kDevCapMemorySize}, - {"compute_cap_major", kDevCapComputeCapMajor}, - {"compute_cap_minor", kDevCapComputeCapMinor}, - {"peak_teraflops_per_second", kDevCapPeakTeraflopsPerSecond}, - {"peak_hbm_bw_gigabytes_per_second", kDevCapPeakHbmBwGigabytesPerSecond}, - {"peak_sram_rd_bw_gigabytes_per_second", - kDevCapPeakSramRdBwGigabytesPerSecond}, - {"peak_sram_wr_bw_gigabytes_per_second", - kDevCapPeakSramWrBwGigabytesPerSecond}, - {"device_vendor", kDevVendor}, - // Batching related. - {"batch_size_after_padding", kBatchSizeAfterPadding}, - {"padding_amount", kPaddingAmount}, - {"batching_input_task_size", kBatchingInputTaskSize}, - // GPU related metrics. - {"theoretical_occupancy_pct", kTheoreticalOccupancyPct}, - {"occupancy_min_grid_size", kOccupancyMinGridSize}, - {"occupancy_suggested_block_size", kOccupancySuggestedBlockSize}, - // Aggregated Stat - {"self_duration_ps", kSelfDurationPs}, - {"min_duration_ps", kMinDurationPs}, - {"total_profile_duration_ps", kTotalProfileDurationPs}, - {"max_iteration_num", kMaxIterationNum}, - {"device_type", kDeviceType}, - {"uses_megacore", kUsesMegaCore}, - {"symbol_id", kSymbolId}, - {"hlo_category", kHloCategory}, - {"tf_op_name", kTfOpName}, - {"dma_stall_duration_ps", kDmaStallDurationPs}, - {"key", kKey}, - {"payload_size_bytes", kPayloadSizeBytes}, - {"duration_us", kDuration}, - {"buffer_size", kBufferSize}, - {"transfers", kTransfers}, - // Dcn message Stats - {"dcn_label", kDcnLabel}, - {"dcn_source_slice_id", kDcnSourceSliceId}, - {"dcn_source_per_slice_device_id", kDcnSourcePerSliceDeviceId}, - {"dcn_destination_slice_id", kDcnDestinationSliceId}, - {"dcn_destination_per_slice_device_id", kDcnDestinationPerSliceDeviceId}, - {"dcn_chunk", kDcnChunk}, - {"dcn_loop_index", kDcnLoopIndex}, - {"dropped_traces", kDroppedTraces}, - {"cuda_graph_id", kCudaGraphId}, - {"cuda_graph_exec_id", kCudaGraphExecId}, - {"cuda_graph_orig_id", kCudaGraphOrigId}, - }); + static auto* stat_type_map = new StatTypeMap( + {{"UnknownStatType", kUnknownStatType}, + // TraceMe arguments. + {"id", kStepId}, + {"device_ordinal", kDeviceOrdinal}, + {"chip_ordinal", kChipOrdinal}, + {"node_ordinal", kNodeOrdinal}, + {"model_id", kModelId}, + {"queue_addr", kQueueAddr}, + {"queue_id", kQueueId}, + {"request_id", kRequestId}, + {"run_id", kRunId}, + {"replica_id", kReplicaId}, + {"graph_type", kGraphType}, + {"step_num", kStepNum}, + {"iter_num", kIterNum}, + {"index_on_host", kIndexOnHost}, + {"allocator_name", kAllocatorName}, + {"bytes_reserved", kBytesReserved}, + {"bytes_allocated", kBytesAllocated}, + {"bytes_available", kBytesAvailable}, + {"fragmentation", kFragmentation}, + {"peak_bytes_in_use", kPeakBytesInUse}, + {"requested_bytes", kRequestedBytes}, + {"allocation_bytes", kAllocationBytes}, + {"addr", kAddress}, + {"region_type", kRegionType}, + {"data_type", kDataType}, + {"shape", kTensorShapes}, + {"layout", kTensorLayout}, + {"kpi_name", kKpiName}, + {"kpi_value", kKpiValue}, + {"element_id", kElementId}, + {"parent_id", kParentId}, + {"core_type", kCoreType}, + // XPlane semantics related. + {"_pt", kProducerType}, + {"_ct", kConsumerType}, + {"_p", kProducerId}, + {"_c", kConsumerId}, + {"_r", kIsRoot}, + {"_a", kIsAsync}, + // Device trace arguments. + {"device_id", kDeviceId}, + {"device_type_string", kDeviceTypeString}, + {"context_id", kContextId}, + {"correlation_id", kCorrelationId}, + {"memcpy_details", kMemcpyDetails}, + {"memalloc_details", kMemallocDetails}, + {"MemFree_details", kMemFreeDetails}, + {"Memset_details", kMemsetDetails}, + {"MemoryResidency_details", kMemoryResidencyDetails}, + {"kernel_details", kKernelDetails}, + {"nvtx_range", kNVTXRange}, + {"stream", kStream}, + // Stats added when processing traces. + {"group_id", kGroupId}, + {"flow", kFlow}, + {"step_name", kStepName}, + {"tf_op", kTfOp}, + {"hlo_op", kHloOp}, + {"deduplicated_name", kDeduplicatedName}, + {"hlo_category", kHloCategory}, + {"hlo_module", kHloModule}, + {"program_id", kProgramId}, + {"equation", kEquation}, + {"is_eager", kIsEager}, + {"is_func", kIsFunc}, + {"tf_function_call", kTfFunctionCall}, + {"tracing_count", kTfFunctionTracingCount}, + {"flops", kFlops}, + {"model_flops", kModelFlops}, + {"bytes_accessed", kBytesAccessed}, + {"memory_access_breakdown", kMemoryAccessBreakdown}, + {"source", kSourceInfo}, + {"model_name", kModelName}, + {"model_version", kModelVersion}, + {"bytes_transferred", kBytesTransferred}, + {"queue", kDmaQueue}, + {"dcn_collective_info", kDcnCollectiveInfo}, + // Performance counter related. + {"Raw Value", kRawValue}, + {"Scaled Value", kScaledValue}, + {"Thread Id", kThreadId}, + {"matrix_unit_utilization_percent", kMatrixUnitUtilizationPercent}, + // XLA metadata map related. + {"Hlo Proto", kHloProto}, + {"EdgeTPU Model information", kEdgeTpuModelInfo}, + {"EdgeTPU Model Profile information", kEdgeTpuModelProfileInfo}, + {"EdgeTPU MLIR", kEdgeTpuMlir}, + // Device capability related. + {"clock_rate", kDevCapClockRateKHz}, + {"core_count", kDevCapCoreCount}, + {"memory_bandwidth", kDevCapMemoryBandwidth}, + {"memory_size", kDevCapMemorySize}, + {"compute_cap_major", kDevCapComputeCapMajor}, + {"compute_cap_minor", kDevCapComputeCapMinor}, + {"peak_teraflops_per_second", kDevCapPeakTeraflopsPerSecond}, + {"peak_hbm_bw_gigabytes_per_second", kDevCapPeakHbmBwGigabytesPerSecond}, + {"peak_sram_rd_bw_gigabytes_per_second", + kDevCapPeakSramRdBwGigabytesPerSecond}, + {"peak_sram_wr_bw_gigabytes_per_second", + kDevCapPeakSramWrBwGigabytesPerSecond}, + {"device_vendor", kDevVendor}, + // Batching related. + {"batch_size_after_padding", kBatchSizeAfterPadding}, + {"padding_amount", kPaddingAmount}, + {"batching_input_task_size", kBatchingInputTaskSize}, + // GPU related metrics. + {"theoretical_occupancy_pct", kTheoreticalOccupancyPct}, + {"occupancy_min_grid_size", kOccupancyMinGridSize}, + {"occupancy_suggested_block_size", kOccupancySuggestedBlockSize}, + // Aggregated Stat + {"self_duration_ps", kSelfDurationPs}, + {"min_duration_ps", kMinDurationPs}, + {"total_profile_duration_ps", kTotalProfileDurationPs}, + {"max_iteration_num", kMaxIterationNum}, + {"device_type", kDeviceType}, + {"uses_megacore", kUsesMegaCore}, + {"symbol_id", kSymbolId}, + {"hlo_category", kHloCategory}, + {"tf_op_name", kTfOpName}, + {"dma_stall_duration_ps", kDmaStallDurationPs}, + {"key", kKey}, + {"payload_size_bytes", kPayloadSizeBytes}, + {"duration_us", kDuration}, + {"buffer_size", kBufferSize}, + {"transfers", kTransfers}, + // Dcn message Stats + {"dcn_label", kDcnLabel}, + {"dcn_source_slice_id", kDcnSourceSliceId}, + {"dcn_source_per_slice_device_id", kDcnSourcePerSliceDeviceId}, + {"dcn_destination_slice_id", kDcnDestinationSliceId}, + {"dcn_destination_per_slice_device_id", kDcnDestinationPerSliceDeviceId}, + {"dcn_chunk", kDcnChunk}, + {"dcn_loop_index", kDcnLoopIndex}, + {"dropped_traces", kDroppedTraces}, + {"cuda_graph_id", kCudaGraphId}, + {"cuda_graph_exec_id", kCudaGraphExecId}, + {"cuda_graph_orig_id", kCudaGraphOrigId}, + {"step_idle_time_ps", kStepIdleTimePs}, + {"gpu_device_name", kGpuDeviceName}, + {"source_stack", kSourceStack}, + {"device_offset_ps", kDeviceOffsetPs}, + {"device_duration_ps", kDeviceDurationPs}}); DCHECK_EQ(stat_type_map->size(), kNumStatTypes); return *stat_type_map; } diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h b/third_party/xla/xla/tsl/profiler/utils/xplane_schema.h similarity index 97% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h rename to third_party/xla/xla/tsl/profiler/utils/xplane_schema.h index 96ebf29d4fefff..a929971c99877a 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_schema.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_SCHEMA_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_SCHEMA_H_ +#ifndef XLA_TSL_PROFILER_UTILS_XPLANE_SCHEMA_H_ +#define XLA_TSL_PROFILER_UTILS_XPLANE_SCHEMA_H_ #include #include @@ -70,6 +70,7 @@ TF_CONST_INIT extern const absl::string_view kTensorFlowNameScopeLineName; TF_CONST_INIT extern const absl::string_view kTensorFlowOpLineName; TF_CONST_INIT extern const absl::string_view kXlaModuleLineName; TF_CONST_INIT extern const absl::string_view kXlaOpLineName; +TF_CONST_INIT extern const absl::string_view kSparseCoreStepLineName; TF_CONST_INIT extern const absl::string_view kXlaAsyncOpLineName; TF_CONST_INIT extern const absl::string_view kKernelLaunchLineName; TF_CONST_INIT extern const absl::string_view kSourceLineName; @@ -277,6 +278,7 @@ enum StatType { kHloProto, // Device capability related. kDevCapClockRateKHz, + // For GPU, this is the number of SMs. kDevCapCoreCount, kDevCapMemoryBandwidth, kDevCapMemorySize, @@ -328,7 +330,12 @@ enum StatType { // on the GPU device when tracing is in graph level. kCudaGraphExecId, kCudaGraphOrigId, - kLastStatType = kCudaGraphOrigId, + kStepIdleTimePs, + kGpuDeviceName, + kSourceStack, + kDeviceOffsetPs, + kDeviceDurationPs, + kLastStatType = kDeviceDurationPs, }; enum MegaScaleStatType : uint8_t { @@ -532,4 +539,4 @@ TF_CONST_INIT extern const absl::string_view kThreadpoolListenerRegion; } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_SCHEMA_H_ +#endif // XLA_TSL_PROFILER_UTILS_XPLANE_SCHEMA_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_test_utils.cc b/third_party/xla/xla/tsl/profiler/utils/xplane_test_utils.cc similarity index 95% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_test_utils.cc rename to third_party/xla/xla/tsl/profiler/utils/xplane_test_utils.cc index f80e7e15bae4a7..548b444b912263 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_test_utils.cc +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_test_utils.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/xplane_test_utils.h" +#include "xla/tsl/profiler/utils/xplane_test_utils.h" #include #include @@ -20,11 +20,11 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_utils.h" namespace tsl { namespace profiler { diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_test_utils.h b/third_party/xla/xla/tsl/profiler/utils/xplane_test_utils.h similarity index 88% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_test_utils.h rename to third_party/xla/xla/tsl/profiler/utils/xplane_test_utils.h index 4568feed8d0439..b2e5e58494c67a 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_test_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_test_utils.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_TEST_UTILS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_TEST_UTILS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_XPLANE_TEST_UTILS_H_ +#define XLA_TSL_PROFILER_UTILS_XPLANE_TEST_UTILS_H_ #include #include @@ -21,9 +21,9 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/variant.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" #include "tsl/platform/types.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" namespace tsl { namespace profiler { @@ -58,4 +58,4 @@ void CreateTfFunctionCallEvent(XPlaneBuilder* plane_builder, } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_TEST_UTILS_H_ +#endif // XLA_TSL_PROFILER_UTILS_XPLANE_TEST_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils.cc b/third_party/xla/xla/tsl/profiler/utils/xplane_utils.cc similarity index 95% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils.cc rename to third_party/xla/xla/tsl/profiler/utils/xplane_utils.cc index 9e1632dd66b6b7..8c1ec1882bb486 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils.cc +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_utils.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include #include @@ -28,17 +28,17 @@ limitations under the License. #include "absl/log/log.h" #include "absl/strings/match.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/math_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "xla/tsl/util/stats_calculator.h" #include "tsl/platform/fingerprint.h" #include "tsl/platform/types.h" #include "tsl/profiler/lib/context_types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/math_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/timespan.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { @@ -167,6 +167,19 @@ bool IsOpLineName(absl::string_view line_name) { return line_name == kXlaOpLineName || line_name == kTensorFlowOpLineName; } +Timespan GetEventTimespan(const XEventVisitor& event) { + const std::optional device_offset_ps = + event.GetStat(StatType::kDeviceOffsetPs); + const std::optional device_duration_ps = + event.GetStat(StatType::kDeviceDurationPs); + if (device_offset_ps.has_value() && device_duration_ps.has_value()) { + return Timespan(device_offset_ps->IntOrUintValue(), + device_duration_ps->IntOrUintValue()); + } + + return event.GetTimespan(); +} + } // namespace const XPlane* FindPlaneWithName(const XSpace& space, absl::string_view name) { @@ -548,7 +561,8 @@ void AggregateXPlane(const XPlane& full_trace, XPlane& aggregated_trace) { uint64_t last_op_end_ps = 0; plane.ForEachLine([&](const XLineVisitor& line) { - if (line.Name() == kStepLineName) { + if (line.Name() == kStepLineName || + line.Name() == kSparseCoreStepLineName) { XLineBuilder aggregated_line = aggregated_plane.GetOrCreateLine(line.Id()); aggregated_line.SetName(kStepLineName); @@ -562,26 +576,27 @@ void AggregateXPlane(const XPlane& full_trace, XPlane& aggregated_trace) { aggregated_line.SetName(line.Name()); std::vector event_stack; line.ForEachEvent([&](XEventVisitor event) { + Timespan timespan = GetEventTimespan(event); first_op_start_ps = first_op_start_ps <= event.TimestampPs() ? first_op_start_ps - : event.TimestampPs(); + : timespan.begin_ps(); last_op_end_ps = last_op_end_ps >= event.EndTimestampPs() ? last_op_end_ps - : event.EndTimestampPs(); + : timespan.end_ps(); const auto& group_stat = event.GetStat(StatType::kGroupId); int64_t group_id = group_stat.has_value() ? group_stat->IntOrUintValue() : kint64max; StatByEvent& line_stats = stats[line.Id()][group_id]; - line_stats[event.Id()].stat.UpdateStat(event.DurationPs()); + line_stats[event.Id()].stat.UpdateStat(timespan.duration_ps()); DCHECK(event_stack.empty() || !(event < event_stack.back())); while (!event_stack.empty() && - !event_stack.back().GetTimespan().Includes(event.GetTimespan())) { + !GetEventTimespan(event_stack.back()).Includes(timespan)) { event_stack.pop_back(); } if (!event_stack.empty()) { line_stats[event_stack.back().Id()].children_duration += - event.DurationPs(); + timespan.duration_ps(); } event_stack.push_back(std::move(event)); }); diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils.h b/third_party/xla/xla/tsl/profiler/utils/xplane_utils.h similarity index 96% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils.h rename to third_party/xla/xla/tsl/profiler/utils/xplane_utils.h index 8ea1429c1d90d2..7992d795ea0318 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils.h +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_utils.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_UTILS_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_UTILS_H_ +#ifndef XLA_TSL_PROFILER_UTILS_XPLANE_UTILS_H_ +#define XLA_TSL_PROFILER_UTILS_XPLANE_UTILS_H_ #include #include @@ -24,11 +24,11 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" +#include "xla/tsl/profiler/utils/timespan.h" +#include "xla/tsl/profiler/utils/trace_utils.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/timespan.h" -#include "tsl/profiler/utils/trace_utils.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { @@ -222,4 +222,4 @@ bool IsDevicePlane(const XPlane& plane); } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_UTILS_H_ +#endif // XLA_TSL_PROFILER_UTILS_XPLANE_UTILS_H_ diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils_test.cc b/third_party/xla/xla/tsl/profiler/utils/xplane_utils_test.cc similarity index 89% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils_test.cc rename to third_party/xla/xla/tsl/profiler/utils/xplane_utils_test.cc index bdb5a3c4da4b03..ec44a499f56ad9 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_utils_test.cc +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_utils_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/xplane_utils.h" +#include "xla/tsl/profiler/utils/xplane_utils.h" #include #include @@ -24,14 +24,14 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/math_utils.h" +#include "xla/tsl/profiler/utils/tf_xplane_visitor.h" +#include "xla/tsl/profiler/utils/xplane_builder.h" +#include "xla/tsl/profiler/utils/xplane_schema.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include "tsl/platform/test.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/math_utils.h" -#include "tsl/profiler/utils/tf_xplane_visitor.h" -#include "tsl/profiler/utils/xplane_builder.h" -#include "tsl/profiler/utils/xplane_schema.h" -#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { @@ -518,6 +518,80 @@ TEST(XplaneUtilsTest, TestAggregateXPlanes) { #endif } +TEST(XPlaneUtilsTest, TestAggregateXPlaneWithCycleStats) { + XPlane xplane; + XPlaneBuilder builder(&xplane); + const XStatMetadata& device_offset_stat = *builder.GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kDeviceOffsetPs)); + const XStatMetadata& device_duration_stat = *builder.GetOrCreateStatMetadata( + GetStatTypeStr(StatType::kDeviceDurationPs)); + + const XEventMetadata& event_metadata1 = + *builder.GetOrCreateEventMetadata("EventMetadata1"); + const XEventMetadata& event_metadata2 = + *builder.GetOrCreateEventMetadata("EventMetadata2"); + const XEventMetadata& event_metadata3 = + *builder.GetOrCreateEventMetadata("EventMetadata3"); + + XLineBuilder line = builder.GetOrCreateLine(1); + line.SetName(kXlaOpLineName); + XEventBuilder event1 = line.AddEvent(event_metadata1); + event1.SetOffsetNs(0); + event1.SetDurationNs(5); + event1.AddStatValue(device_offset_stat, 50); + event1.AddStatValue(device_duration_stat, 4950); + XEventBuilder event2 = line.AddEvent(event_metadata2); + event2.SetOffsetNs(2); + event2.SetDurationNs(1); + event2.AddStatValue(device_offset_stat, 1950); + event2.AddStatValue(device_duration_stat, 890); + XEventBuilder event3 = line.AddEvent(event_metadata3); + event3.SetOffsetNs(3); + event3.SetDurationNs(2); + event3.AddStatValue(device_offset_stat, 2950); + event3.AddStatValue(device_duration_stat, 2050); + XEventBuilder event4 = line.AddEvent(event_metadata1); + event4.SetOffsetNs(5); + event4.SetDurationNs(5); + event4.AddStatValue(device_offset_stat, 5000); + event4.AddStatValue(device_duration_stat, 4950); + XEventBuilder event5 = line.AddEvent(event_metadata2); + event5.SetOffsetNs(7); + event5.SetDurationNs(1); + event5.AddStatValue(device_offset_stat, 7050); + event5.AddStatValue(device_duration_stat, 900); + XEventBuilder event6 = line.AddEvent(event_metadata3); + event6.SetOffsetNs(8); + event6.SetDurationNs(2); + event6.AddStatValue(device_offset_stat, 8050); + event6.AddStatValue(device_duration_stat, 1900); + + XPlane aggregated_xplane; + AggregateXPlane(xplane, aggregated_xplane); + + XPlaneVisitor visitor = CreateTfXPlaneVisitor(&aggregated_xplane); + visitor.ForEachLine([&](const XLineVisitor& line) { + EXPECT_EQ(line.Name(), kXlaOpLineName); + line.ForEachEvent([&](const XEventVisitor& event) { + EXPECT_EQ(event.OffsetPs(), 0); + if (event.Metadata().Name() == "EventMetadata1") { + EXPECT_EQ(event.NumOccurrences(), 2); + EXPECT_EQ(event.DurationPs(), 9900); + EXPECT_EQ((*event.GetStat(StatType::kSelfDurationPs)).IntOrUintValue(), + 4160); + } + if (event.Metadata().Name() == "EventMetadata2") { + EXPECT_EQ(event.NumOccurrences(), 2); + EXPECT_EQ(event.DurationPs(), 1790); + } + if (event.Metadata().Name() == "EventMetadata3") { + EXPECT_EQ(event.NumOccurrences(), 2); + EXPECT_EQ(event.DurationPs(), 3950); + } + }); + }); +} + TEST(XPlanuUtilsTest, TestInstantEventDoesNotFail) { XPlane xplane; XPlaneBuilder xplane_builder(&xplane); diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_visitor.cc b/third_party/xla/xla/tsl/profiler/utils/xplane_visitor.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_visitor.cc rename to third_party/xla/xla/tsl/profiler/utils/xplane_visitor.cc index 7d0a723221ffc5..b7bfad3f7211eb 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_visitor.cc +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_visitor.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/profiler/utils/xplane_visitor.h" +#include "xla/tsl/profiler/utils/xplane_visitor.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_visitor.h b/third_party/xla/xla/tsl/profiler/utils/xplane_visitor.h similarity index 98% rename from third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_visitor.h rename to third_party/xla/xla/tsl/profiler/utils/xplane_visitor.h index eb0eb01a6b17c2..a9c8510355cde2 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_visitor.h +++ b/third_party/xla/xla/tsl/profiler/utils/xplane_visitor.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_VISITOR_H_ -#define TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_VISITOR_H_ +#ifndef XLA_TSL_PROFILER_UTILS_XPLANE_VISITOR_H_ +#define XLA_TSL_PROFILER_UTILS_XPLANE_VISITOR_H_ #include @@ -25,9 +25,9 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/profiler/utils/timespan.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/xplane.pb.h" -#include "tsl/profiler/utils/timespan.h" namespace tsl { namespace profiler { @@ -359,4 +359,4 @@ void XEventMetadataVisitor::ForEachChild( } // namespace profiler } // namespace tsl -#endif // TENSORFLOW_TSL_PROFILER_UTILS_XPLANE_VISITOR_H_ +#endif // XLA_TSL_PROFILER_UTILS_XPLANE_VISITOR_H_ diff --git a/third_party/xla/xla/tsl/protobuf/BUILD b/third_party/xla/xla/tsl/protobuf/BUILD index d4d6f822814797..7b1b6fdbe4a753 100644 --- a/third_party/xla/xla/tsl/protobuf/BUILD +++ b/third_party/xla/xla/tsl/protobuf/BUILD @@ -1,7 +1,6 @@ -load( - "@local_tsl//tsl/platform:build_config.bzl", - "tf_proto_library", -) +load("@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library") + +# copybara:uncomment load("@rules_python//python:proto.bzl", "py_proto_library") load( "//xla/tsl:tsl.bzl", "if_google", @@ -36,3 +35,95 @@ tf_proto_library( ]), visibility = ["//visibility:public"], ) + +tf_proto_library( + name = "dnn_proto", + srcs = ["dnn.proto"], + make_default_target_header_only = True, + protodeps = if_google(["@com_google_protobuf//:wrappers"]), + visibility = ["//visibility:public"], +) + +tf_proto_library( + name = "status_proto", + srcs = ["status.proto"], + make_default_target_header_only = True, + protodeps = ["//xla/tsl/protobuf:error_codes_proto_impl"], + visibility = ["//visibility:public"], +) + +tf_proto_library( + name = "histogram_proto", + srcs = ["histogram.proto"], + make_default_target_header_only = True, + visibility = ["//visibility:public"], +) + +tf_proto_library( + name = "coordination_config_proto", + srcs = ["coordination_config.proto"], + make_default_target_header_only = True, + visibility = ["//visibility:public"], +) + +tf_proto_library( + name = "coordination_service_proto", + srcs = ["coordination_service.proto"], + has_services = 1, + create_grpc_library = True, + create_java_proto = False, + create_kotlin_proto = False, + create_service = True, + protodeps = if_google(["@com_google_protobuf//:any"]), + visibility = ["//visibility:public"], +) + +# copybara:uncomment_begin(google-only) +# py_proto_library( +# name = "coordination_service_py_pb2", +# visibility = ["//visibility:public"], +# deps = [":coordination_service_proto"], +# ) +# copybara:uncomment_end + +tf_proto_library( + name = "distributed_runtime_payloads_proto", + srcs = ["distributed_runtime_payloads.proto"], + make_default_target_header_only = True, + visibility = ["//visibility:public"], +) + +tf_proto_library( + name = "rpc_options_proto", + srcs = ["rpc_options.proto"], + make_default_target_header_only = True, + visibility = ["//visibility:public"], +) + +tf_proto_library( + name = "error_codes_proto_impl", + srcs = ["error_codes.proto"], + make_default_target_header_only = True, + protodeps = if_google(["@com_google_protobuf//:any"]), + visibility = ["//visibility:public"], +) + +tf_proto_library( + name = "protos_all", + create_go_proto = False, + create_kotlin_proto = False, + make_default_target_header_only = True, + protodeps = [ + # TODO(tlongeri): Conceptually, these fit into protos_all but adding them currently causes + # breakages (and they are not actually used). + ":bfc_memory_map_proto", + ":coordination_config_proto", + ":distributed_runtime_payloads_proto", + "//xla/tsl/protobuf:error_codes_proto_impl", + ":histogram_proto", + ":rpc_options_proto", + ":status_proto", + ":test_log_proto", + ] + if_google(["@com_google_protobuf//:any"]), + visibility = ["//visibility:public"], +) diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/coordination_config.proto b/third_party/xla/xla/tsl/protobuf/coordination_config.proto similarity index 94% rename from third_party/xla/third_party/tsl/tsl/protobuf/coordination_config.proto rename to third_party/xla/xla/tsl/protobuf/coordination_config.proto index 23aff65eb67985..645c992c64df42 100644 --- a/third_party/xla/third_party/tsl/tsl/protobuf/coordination_config.proto +++ b/third_party/xla/xla/tsl/protobuf/coordination_config.proto @@ -29,6 +29,10 @@ message CoordinationServiceConfig { // Maximum wait time for all members in the cluster to be registered. int64 cluster_register_timeout_in_ms = 4; + // Denotes if we should synchronize the agents' register attempts by blocking + // on a barrier. This is useful for synchronized restarts. + bool cluster_register_with_barrier = 14; + // Heartbeat timeout, if a task does not record heartbeat in this time // window, it will be considered disconnected. // Note: This is also used as a grace period to accept any heartbeats after diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/coordination_service.proto b/third_party/xla/xla/tsl/protobuf/coordination_service.proto similarity index 100% rename from third_party/xla/third_party/tsl/tsl/protobuf/coordination_service.proto rename to third_party/xla/xla/tsl/protobuf/coordination_service.proto diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/distributed_runtime_payloads.proto b/third_party/xla/xla/tsl/protobuf/distributed_runtime_payloads.proto similarity index 100% rename from third_party/xla/third_party/tsl/tsl/protobuf/distributed_runtime_payloads.proto rename to third_party/xla/xla/tsl/protobuf/distributed_runtime_payloads.proto diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/dnn.proto b/third_party/xla/xla/tsl/protobuf/dnn.proto similarity index 99% rename from third_party/xla/third_party/tsl/tsl/protobuf/dnn.proto rename to third_party/xla/xla/tsl/protobuf/dnn.proto index 695db935f6a0b4..2ac31005c16629 100644 --- a/third_party/xla/third_party/tsl/tsl/protobuf/dnn.proto +++ b/third_party/xla/xla/tsl/protobuf/dnn.proto @@ -22,6 +22,8 @@ enum DataType { kF8E5M2FNUZ = 10; kF8E4M3FNUZ = 11; kInt64 = 12; + kF8E4M3 = 13; + kF8E3M4 = 14; } // Describes how a convolution input or output layer's data is formatted. diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/error_codes.proto b/third_party/xla/xla/tsl/protobuf/error_codes.proto similarity index 100% rename from third_party/xla/third_party/tsl/tsl/protobuf/error_codes.proto rename to third_party/xla/xla/tsl/protobuf/error_codes.proto diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/histogram.proto b/third_party/xla/xla/tsl/protobuf/histogram.proto similarity index 100% rename from third_party/xla/third_party/tsl/tsl/protobuf/histogram.proto rename to third_party/xla/xla/tsl/protobuf/histogram.proto diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/rpc_options.proto b/third_party/xla/xla/tsl/protobuf/rpc_options.proto similarity index 100% rename from third_party/xla/third_party/tsl/tsl/protobuf/rpc_options.proto rename to third_party/xla/xla/tsl/protobuf/rpc_options.proto diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/status.proto b/third_party/xla/xla/tsl/protobuf/status.proto similarity index 60% rename from third_party/xla/third_party/tsl/tsl/protobuf/status.proto rename to third_party/xla/xla/tsl/protobuf/status.proto index 09d72218941ee2..b61383cf3193ca 100644 --- a/third_party/xla/third_party/tsl/tsl/protobuf/status.proto +++ b/third_party/xla/xla/tsl/protobuf/status.proto @@ -2,7 +2,7 @@ syntax = "proto3"; package tensorflow; -import "tsl/protobuf/error_codes.proto"; +import "xla/tsl/protobuf/error_codes.proto"; option cc_enable_arenas = true; option java_multiple_files = true; @@ -10,11 +10,15 @@ option java_package = "org.tensorflow.framework"; option go_package = "github.com/google/tsl/tsl/go/protobuf/for_core_protos_go_proto"; // Wire-format for Status. -// Next tag: 3 +// Next tag: 4 message StatusProto { - // Status code as defined in tensorflow/tsl/protobuf/error_codes.proto. + // Status code as defined in + // tensorflow/compiler/xla/tsl/protobuf/error_codes.proto. error.Code code = 1; // Detail error message. string message = 2; + + // Unique type URL -> value, like absl::Status payloads. + map payload = 3; } diff --git a/third_party/xla/xla/tsl/python/lib/core/ml_dtypes.cc b/third_party/xla/xla/tsl/python/lib/core/ml_dtypes.cc index 717ab3e462a7bf..e2c5eb295c6b12 100644 --- a/third_party/xla/xla/tsl/python/lib/core/ml_dtypes.cc +++ b/third_party/xla/xla/tsl/python/lib/core/ml_dtypes.cc @@ -61,6 +61,10 @@ struct MlDtypesInitInfo { numpy_dtypes.bfloat16 = py::dtype::from_args(ml_dtypes.attr("bfloat16")).num(); + numpy_dtypes.float8_e3m4 = + py::dtype::from_args(ml_dtypes.attr("float8_e3m4")).num(); + numpy_dtypes.float8_e4m3 = + py::dtype::from_args(ml_dtypes.attr("float8_e4m3")).num(); numpy_dtypes.float8_e4m3fn = py::dtype::from_args(ml_dtypes.attr("float8_e4m3fn")).num(); numpy_dtypes.float8_e5m2 = @@ -81,6 +85,8 @@ struct MlDtypesInitInfo { // Verify all types were successfully loaded. if (numpy_dtypes.bfloat16 == NPY_NOTYPE || + numpy_dtypes.float8_e3m4 == NPY_NOTYPE || + numpy_dtypes.float8_e4m3 == NPY_NOTYPE || numpy_dtypes.float8_e4m3fn == NPY_NOTYPE || numpy_dtypes.float8_e4m3fnuz == NPY_NOTYPE || numpy_dtypes.float8_e4m3b11fnuz == NPY_NOTYPE || diff --git a/third_party/xla/xla/tsl/python/lib/core/ml_dtypes.h b/third_party/xla/xla/tsl/python/lib/core/ml_dtypes.h index bf9eab2200a76b..b3aa94e430239a 100644 --- a/third_party/xla/xla/tsl/python/lib/core/ml_dtypes.h +++ b/third_party/xla/xla/tsl/python/lib/core/ml_dtypes.h @@ -24,6 +24,8 @@ namespace ml_dtypes { struct NumpyDtypes { int bfloat16; + int float8_e3m4; + int float8_e4m3; int float8_e4m3fn; int float8_e4m3b11fnuz; int float8_e4m3fnuz; diff --git a/third_party/xla/xla/tsl/python/lib/core/numpy.h b/third_party/xla/xla/tsl/python/lib/core/numpy.h index 6a5a6a6486ccf7..ca57a0370548ed 100644 --- a/third_party/xla/xla/tsl/python/lib/core/numpy.h +++ b/third_party/xla/xla/tsl/python/lib/core/numpy.h @@ -29,6 +29,13 @@ limitations under the License. #define NO_IMPORT_ARRAY #endif +// Prevent linking error with numpy>=2.1.0 +// error: undefined hidden symbol: _xla_numpy_apiPyArray_RUNTIME_VERSION +// Without this define, Numpy's API symbols will have hidden symbol visibility, +// which may break things if Bazel chooses to build a cc_library target into +// its own .so file. Bazel typically does this for debug builds. +#define NPY_API_SYMBOL_ATTRIBUTE + // clang-format off // Place `` before to avoid build failure in macOS. #include diff --git a/third_party/xla/xla/tsl/tsl.bzl b/third_party/xla/xla/tsl/tsl.bzl index 0cf769ddf4eadb..9cff2cc413ce51 100644 --- a/third_party/xla/xla/tsl/tsl.bzl +++ b/third_party/xla/xla/tsl/tsl.bzl @@ -21,7 +21,6 @@ load( load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm", - "if_rocm_is_configured", ) load( "@local_tsl//tsl/platform:rules_cc.bzl", @@ -33,6 +32,11 @@ load( "@local_config_tensorrt//:build_defs.bzl", "if_tensorrt", ) +load( + "@local_tsl//third_party/py/rules_pywrap:pywrap.bzl", + "use_pywrap_rules", + _pybind_extension = "pybind_extension", +) # Internally this loads a macro, but in OSS this is a function # buildifier: disable=out-of-order-load @@ -88,8 +92,7 @@ def if_cuda_or_rocm(if_true, if_false = []): """ return select({ - "@local_config_cuda//cuda:using_nvcc": if_true, - "@local_config_cuda//cuda:using_clang": if_true, + clean_dep("//xla/tsl:is_cuda_enabled"): if_true, "@local_config_rocm//rocm:using_hipcc": if_true, "//conditions:default": if_false, }) @@ -121,7 +124,10 @@ def internal_visibility(internal_targets): return if_google(internal_targets, ["//visibility:public"]) # TODO(jakeharmon): Use this to replace if_static +# TODO(b/356020232): remove completely after migration is done def if_tsl_link_protobuf(if_true, if_false = []): + if use_pywrap_rules(): + return if_true return select({ "//conditions:default": if_true, clean_dep("//xla/tsl:tsl_protobuf_header_only"): if_false, @@ -164,7 +170,7 @@ def if_not_fuchsia(a): def if_nvcc(a): return select({ - "@local_config_cuda//cuda:using_nvcc": a, + clean_dep("//xla/tsl:is_cuda_nvcc"): a, "//conditions:default": [], }) @@ -274,6 +280,8 @@ def get_win_copts(is_external = False): return WINDOWS_COPTS + ["/DTF_COMPILE_LIBRARY"] # copybara:comment_end +# TODO(b/356020232): cleanup non-use_pywrap_rules part once migration is done +# buildozer: disable=function-docstring-args def tsl_copts( android_optimization_level_override = "-O2", is_external = False, @@ -288,6 +296,16 @@ def tsl_copts( ] if android_optimization_level_override: android_copts.append(android_optimization_level_override) + + framework_deps = [] + if use_pywrap_rules(): + pass + else: + framework_deps = select({ + clean_dep("//xla/tsl:framework_shared_object"): [], + "//conditions:default": ["-DTENSORFLOW_MONOLITHIC_BUILD"], + }) + return ( if_not_windows([ "-DEIGEN_AVOID_STL_ARRAY", @@ -316,10 +334,7 @@ def tsl_copts( if_linux_x86_64(["-msse3"]) + if_ios_x86_64(["-msse4.1"]) + if_no_default_logger(["-DNO_DEFAULT_LOGGER"]) + - select({ - clean_dep("//xla/tsl:framework_shared_object"): [], - "//conditions:default": ["-DTENSORFLOW_MONOLITHIC_BUILD"], - }) + + framework_deps + select({ clean_dep("//xla/tsl:android"): android_copts, clean_dep("//xla/tsl:emscripten"): [], @@ -367,7 +382,7 @@ def tsl_gpu_library(deps = None, cuda_deps = None, copts = tsl_copts(), **kwargs cuda_deps = [] kwargs["features"] = kwargs.get("features", []) + ["-use_header_modules"] - deps = deps + if_cuda_or_rocm(cuda_deps) + deps = deps + if_cuda(cuda_deps) if "default_copts" in kwargs: copts = kwargs["default_copts"] + copts kwargs.pop("default_copts", None) @@ -375,7 +390,8 @@ def tsl_gpu_library(deps = None, cuda_deps = None, copts = tsl_copts(), **kwargs deps = deps + if_cuda([ clean_dep("//xla/tsl/cuda:cudart"), "@local_config_cuda//cuda:cuda_headers", - ]) + if_rocm_is_configured([ + ]) + if_rocm([ + "@local_config_rocm//rocm:hip", "@local_config_rocm//rocm:rocm_headers", ]), copts = (copts + if_cuda(["-DGOOGLE_CUDA=1", "-DNV_CUDNN_DISABLE_EXCEPTION"]) + if_rocm(["-DTENSORFLOW_USE_ROCM=1"]) + if_xla_available(["-DTENSORFLOW_USE_XLA=1"]) + if_mkl(["-DINTEL_MKL=1"]) + if_enable_mkl(["-DENABLE_MKL"]) + if_tensorrt(["-DGOOGLE_TENSORRT=1"])), @@ -498,6 +514,24 @@ def transitive_hdrs(name, deps = [], **kwargs): _transitive_hdrs(name = name + "_gather", deps = deps) native.filegroup(name = name, srcs = [":" + name + "_gather"], **kwargs) +def cc_header_only_library(name, deps = [], includes = [], extra_deps = [], compatible_with = None, **kwargs): + if use_pywrap_rules(): + cc_library( + name = name, + deps = deps + extra_deps, + compatible_with = compatible_with, + **kwargs + ) + else: + custom_op_cc_header_only_library( + name, + deps, + includes, + extra_deps, + compatible_with, + **kwargs + ) + # Create a header only library that includes all the headers exported by # the libraries in deps. # @@ -511,7 +545,7 @@ def transitive_hdrs(name, deps = [], **kwargs): # * Eigen: it's a header-only library. Add it directly to your deps. # * GRPC: add a direct dep on @com_github_grpc_grpc//:grpc++_public_hdrs. # -def cc_header_only_library(name, deps = [], includes = [], extra_deps = [], compatible_with = None, **kwargs): +def custom_op_cc_header_only_library(name, deps = [], includes = [], extra_deps = [], compatible_with = None, **kwargs): _transitive_hdrs( name = name + "_gather", deps = deps, @@ -802,3 +836,6 @@ def tsl_extra_config_settings(): def tsl_extra_config_settings_targets(): return [] + +# TODO(b/356020232): remove after migration is done +tsl_pybind_extension = _pybind_extension if use_pywrap_rules() else tsl_pybind_extension_opensource diff --git a/third_party/xla/xla/tsl/tsl.default.bzl b/third_party/xla/xla/tsl/tsl.default.bzl index 2a6e4b3a0e2fb2..ffa4b8f9cb1dda 100644 --- a/third_party/xla/xla/tsl/tsl.default.bzl +++ b/third_party/xla/xla/tsl/tsl.default.bzl @@ -11,7 +11,7 @@ load( _tsl_extra_config_settings_targets = "tsl_extra_config_settings_targets", _tsl_google_bzl_deps = "tsl_google_bzl_deps", _tsl_grpc_cc_dependencies = "tsl_grpc_cc_dependencies", - _tsl_pybind_extension = "tsl_pybind_extension_opensource", + _tsl_pybind_extension = "tsl_pybind_extension", ) get_compatible_with_portable = _get_compatible_with_portable diff --git a/third_party/xla/xla/tsl/util/stats_calculator.h b/third_party/xla/xla/tsl/util/stats_calculator.h index 84045fb6ceece2..253895ca605fae 100644 --- a/third_party/xla/xla/tsl/util/stats_calculator.h +++ b/third_party/xla/xla/tsl/util/stats_calculator.h @@ -117,6 +117,42 @@ class Stat { HighPrecisionValueType squared_sum_ = 0; }; +// A `StatWithPercentiles` inherited from `Stat`, also keeps track of the +// values added and can be used to compute the percentile values. +template +class StatWithPercentiles : public Stat { + public: + void UpdateStat(ValueType v) { + Stat::UpdateStat(v); + values_.push_back(v); + } + + // Returns the percentile value. + ValueType percentile(int percentile) const { + if (percentile < 0 || percentile > 100 || values_.empty()) { + return std::numeric_limits::quiet_NaN(); + } + std::vector values = values_; + if (percentile == 100) { + return values[values.size() - 1]; + } else { + std::nth_element(values.begin(), + values.begin() + values.size() * percentile / 100, + values.end()); + return values[values.size() * percentile / 100]; + } + } + + void OutputToStream(std::ostream* stream) const { + Stat::OutputToStream(stream); + *stream << " p5=" << percentile(5) << " median=" << percentile(50) + << " p95=" << percentile(95); + } + + private: + std::vector values_; +}; + // A StatsCalculator assists in performance analysis of Graph executions. // // It summarizes time spent executing (on GPU/CPU), memory used etc for diff --git a/third_party/xla/xla/tsl/util/stats_calculator_test.cc b/third_party/xla/xla/tsl/util/stats_calculator_test.cc index d58186630598f0..bab88a0236fe7e 100644 --- a/third_party/xla/xla/tsl/util/stats_calculator_test.cc +++ b/third_party/xla/xla/tsl/util/stats_calculator_test.cc @@ -16,6 +16,8 @@ limitations under the License. #include "xla/tsl/util/stats_calculator.h" #include +#include +#include #include "tsl/platform/test.h" @@ -104,5 +106,39 @@ TEST(StatsCalculatorTest, UpdateStat) { EXPECT_NEAR(43.30704330706496060826769, stat.std_deviation(), FLT_EPSILON); } +TEST(StatsCalculatorTest, StatWithPercentiles) { + StatWithPercentiles stat; + EXPECT_TRUE(stat.empty()); + EXPECT_TRUE(stat.all_same()); + stat.UpdateStat(1); + EXPECT_TRUE(stat.all_same()); + stat.UpdateStat(-1.0); + EXPECT_FALSE(stat.all_same()); + stat.UpdateStat(100); + stat.UpdateStat(0); + EXPECT_EQ(4, stat.count()); + EXPECT_EQ(-1, stat.min()); + EXPECT_EQ(100, stat.max()); + EXPECT_EQ(25, stat.avg()); + EXPECT_EQ(1, stat.first()); + EXPECT_EQ(0, stat.newest()); + EXPECT_EQ(10002, stat.squared_sum()); + EXPECT_EQ(625, stat.avg() * stat.avg()); + // Sample variance + EXPECT_EQ(7502 / 3, stat.sample_variance()); + // Sample standard deviation, from WolframAlpha + EXPECT_EQ(50, std::sqrt(stat.sample_variance())); + // Population variance + EXPECT_EQ(7502 / 4, stat.variance()); + // Population standard deviation, from WolframAlpha + EXPECT_EQ(43, stat.std_deviation()); + EXPECT_EQ(1, stat.percentile(50)); + EXPECT_EQ(100, stat.percentile(90)); + stat.UpdateStat(150); + EXPECT_EQ(1, stat.percentile(50)); + EXPECT_EQ(150, stat.percentile(90)); + EXPECT_EQ(150, stat.percentile(100)); +} + } // namespace } // namespace tsl diff --git a/third_party/xla/xla/util.cc b/third_party/xla/xla/util.cc index 9b1a6db1fa22c0..8378920df5921f 100644 --- a/third_party/xla/xla/util.cc +++ b/third_party/xla/xla/util.cc @@ -29,16 +29,25 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/base/casts.h" +#include "absl/base/const_init.h" #include "absl/base/log_severity.h" #include "absl/container/inlined_vector.h" #include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/strings/ascii.h" #include "absl/strings/match.h" +#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" #include "xla/types.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/env.h" +#include "tsl/platform/logging.h" #include "tsl/platform/numbers.h" #include "tsl/platform/stacktrace.h" @@ -137,7 +146,7 @@ std::string Reindent(absl::string_view original, template static void RoundTripNanPayload(FloatT value, std::string* result) { static_assert(!std::is_same::value, - "RoundTripNanPayload does not support E4M3"); + "RoundTripNanPayload does not support E4M3FN"); static_assert(!std::is_same::value, "RoundTripNanPayload does not support E4M3FNUZ"); static_assert(!std::is_same::value, @@ -168,6 +177,12 @@ std::string RoundTripFpToString(tsl::float8_e5m2 value) { return result; } +std::string RoundTripFpToString(tsl::float8_e4m3 value) { + std::string result = GenericRoundTripFpToString(value); + RoundTripNanPayload(value, &result); + return result; +} + std::string RoundTripFpToString(tsl::float8_e4m3fnuz value) { std::string result = GenericRoundTripFpToString(value); return result; @@ -188,6 +203,12 @@ std::string RoundTripFpToString(tsl::float8_e4m3b11fnuz value) { return result; } +std::string RoundTripFpToString(tsl::float8_e3m4 value) { + std::string result = GenericRoundTripFpToString(value); + RoundTripNanPayload(value, &result); + return result; +} + std::string RoundTripFpToString(bfloat16 value) { std::string result = GenericRoundTripFpToString(value); RoundTripNanPayload(value, &result); @@ -453,6 +474,20 @@ ConvertedDimensionNumbers ConvertDimensionNumbers( absl::c_sort(dimensions.to_dimensions); return dimensions; } + +DimensionVector GetNonContractingDims( + int64_t rank, absl::Span contracting_dim_numbers, + absl::Span batch_dim_numbers) { + DimensionVector non_contracting_dim_numbers; + for (int64_t i = 0; i < rank; ++i) { + if (!absl::c_linear_search(contracting_dim_numbers, i) && + !absl::c_linear_search(batch_dim_numbers, i)) { + non_contracting_dim_numbers.push_back(i); + } + } + return non_contracting_dim_numbers; +} + std::string SanitizeFileName(std::string file_name) { for (char& c : file_name) { if (c == '/' || c == '\\' || c == '[' || c == ']' || c == ' ') { diff --git a/third_party/xla/xla/util.h b/third_party/xla/xla/util.h index 325080f0f08201..959009073e96f9 100644 --- a/third_party/xla/xla/util.h +++ b/third_party/xla/xla/util.h @@ -47,9 +47,9 @@ limitations under the License. #include "absl/types/span.h" #include "Eigen/Core" #include "xla/status_macros.h" +#include "xla/tsl/lib/math/math_util.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/math/math_util.h" #include "tsl/platform/bfloat16.h" #include "tsl/platform/casts.h" #include "tsl/platform/errors.h" // IWYU pragma: keep @@ -420,6 +420,9 @@ std::string VectorString(const std::initializer_list& c) { std::string RoundTripFpToString(tsl::float8_e5m2 value); // Returns a string which can losslessly round trip to a float8 E4M3. +std::string RoundTripFpToString(tsl::float8_e4m3 value); + +// Returns a string which can losslessly round trip to a float8 E4M3FN. std::string RoundTripFpToString(tsl::float8_e4m3fn value); // Returns a string which can losslessly round trip to a float8 E4M3B11. @@ -431,6 +434,9 @@ std::string RoundTripFpToString(tsl::float8_e5m2fnuz value); // Returns a string which can losslessly round trip to a float8 E4M3FNUZ. std::string RoundTripFpToString(tsl::float8_e4m3fnuz value); +// Returns a string which can losslessly round trip to a float8 E3M4. +std::string RoundTripFpToString(tsl::float8_e3m4 value); + // Returns a string which can losslessly round trip to a bfloat. std::string RoundTripFpToString(tsl::bfloat16 value); @@ -765,6 +771,12 @@ ConvertedDimensionNumbers ConvertDimensionNumbers( absl::Span from_dimensions, absl::Span from_sizes, absl::Span to_sizes); +// Returns non contracting dimensions for a dot operand based on rank, batch and +// contracting dimension numbers. +DimensionVector GetNonContractingDims( + int64_t rank, absl::Span contracting_dim_numbers, + absl::Span batch_dim_numbers); + // Removes illegal characters from filenames. std::string SanitizeFileName(std::string file_name); diff --git a/third_party/xla/xla/util_test.cc b/third_party/xla/xla/util_test.cc index 707696ea1c3a99..2fe6317bfbb8ea 100644 --- a/third_party/xla/xla/util_test.cc +++ b/third_party/xla/xla/util_test.cc @@ -27,6 +27,9 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" +#include "absl/types/span.h" +#include "ml_dtypes/include/float8.h" #include "xla/maybe_owning.h" #include "xla/test.h" #include "xla/types.h" @@ -130,6 +133,18 @@ TEST(UtilTest, RoundTripFpToString) { EXPECT_EQ(RoundTripFpToString(NanWithSignAndPayload( true, QuietNanWithoutPayload())), "-nan"); + EXPECT_EQ(RoundTripFpToString(NanWithSignAndPayload( + false, QuietNanWithoutPayload())), + "nan"); + EXPECT_EQ(RoundTripFpToString(NanWithSignAndPayload( + true, QuietNanWithoutPayload())), + "-nan"); + EXPECT_EQ(RoundTripFpToString(NanWithSignAndPayload( + false, QuietNanWithoutPayload())), + "nan"); + EXPECT_EQ(RoundTripFpToString(NanWithSignAndPayload( + true, QuietNanWithoutPayload())), + "-nan"); EXPECT_EQ( RoundTripFpToString(std::numeric_limits::quiet_NaN()), "nan"); @@ -237,6 +252,18 @@ TEST(UtilTest, TotalOrder_F8E5M2) { } } +TEST(UtilTest, TotalOrder_F8E4M3) { + for (int a = 0; a < 256; ++a) { + tsl::float8_e4m3 x = + Eigen::numext::bit_cast(static_cast(a)); + for (int b = 0; b < 256; ++b) { + tsl::float8_e4m3 y = + Eigen::numext::bit_cast(static_cast(b)); + TotalOrderHelper(x, y); + } + } +} + TEST(UtilTest, TotalOrder_F8E4M3FN) { for (int a = 0; a < 256; ++a) { tsl::float8_e4m3fn x = @@ -287,6 +314,18 @@ TEST(UtilTest, TotalOrder_F8E5M2FNUZ) { } } +TEST(UtilTest, TotalOrder_F8E3M4) { + for (int a = 0; a < 256; ++a) { + tsl::float8_e3m4 x = + Eigen::numext::bit_cast(static_cast(a)); + for (int b = 0; b < 256; ++b) { + tsl::float8_e3m4 y = + Eigen::numext::bit_cast(static_cast(b)); + TotalOrderHelper(x, y); + } + } +} + void PackInt4(absl::Span input, absl::Span output) { CHECK_EQ(output.size(), CeilOfRatio(input.size(), size_t{2})); for (size_t i = 0; i < input.size(); ++i) { diff --git a/third_party/xla/xla/window_util.cc b/third_party/xla/xla/window_util.cc index 471aec16511f32..140d15f58984d0 100644 --- a/third_party/xla/xla/window_util.cc +++ b/third_party/xla/xla/window_util.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/functional/function_ref.h" #include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/window_util_test.cc b/third_party/xla/xla/window_util_test.cc index a5d7ac7a265d3e..e1f6e13597a54e 100644 --- a/third_party/xla/xla/window_util_test.cc +++ b/third_party/xla/xla/window_util_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/window_util.h" #include "xla/test.h" +#include "xla/xla_data.pb.h" namespace xla { namespace { diff --git a/third_party/xla/xla/xla.bzl b/third_party/xla/xla/xla.bzl index a81deeb562893c..c1193cb53fe951 100644 --- a/third_party/xla/xla/xla.bzl +++ b/third_party/xla/xla/xla.bzl @@ -53,8 +53,8 @@ _XLA_SHARED_OBJECT_SENSITIVE_DEPS = if_static(extra_deps = [], otherwise = [ "//xla/tsl/profiler/backends/cpu:traceme_recorder_impl", "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc_impl", "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc_impl", - "@local_tsl//tsl/profiler/utils:time_utils_impl", - "@local_tsl//tsl/protobuf:protos_all_cc_impl", + "//xla/tsl/profiler/utils:time_utils_impl", + "//xla/tsl/protobuf:protos_all_cc_impl", ]) + if_cuda_is_configured([ Label("//xla/stream_executor/cuda:all_runtime"), Label("//xla/stream_executor/cuda:stream_executor_cuda"), diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index d5384f52d52e8d..9c75901e3ab616 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -48,9 +48,11 @@ message DebugOptions { //--------------------------------------------------------------------------// // XLA:CPU options. //--------------------------------------------------------------------------// + // go/keep-sorted start newline_separated=yes skip_lines=1 + + // Use region analysis in copy insertion pass. + bool xla_cpu_copy_insertion_use_region_analysis = 337; - // go/keep-sorted start newline_separated=yes - // // When true, XLA:CPU uses HLO module scheduler that is optimized for // extracting concurrency at the cost of extra memory: we extend the live // ranges of temporaries to allow XLA runtime to schedule independent @@ -98,6 +100,9 @@ message DebugOptions { // When true, XLA:CPU uses the thunk runtime to execute compiled program. bool xla_cpu_use_thunk_runtime = 298; + // Enabling this will enable optimizations that ignore the possibility of NaN. + bool xla_enable_fast_math = 335; + // The number of parts to split the LLVM module into before codegen. This // allows XLA to compile all parts in parallel, and resolve kernel symbols // from different dynamic libraries. @@ -107,6 +112,11 @@ message DebugOptions { // value is `256` (AVX2 on x86 platforms). int32 xla_cpu_prefer_vector_width = 308; + // When set, XLA:CPU will only generate code up to the specified ISA. + // (It will not use newer ISAs.) Using the string format allows us to extend + // the flag for more flexible control if necessary. + string xla_cpu_max_isa = 333; + // go/keep-sorted end //--------------------------------------------------------------------------// @@ -120,6 +130,28 @@ message DebugOptions { // Experimentally disables binary libraries in GPU compiler passes. bool xla_gpu_experimental_disable_binary_libraries = 329; + // Dump FDO profiles in a binary format to a separate file. + bool xla_gpu_experimental_dump_fdo_profiles = 338; + + // Enabling this flag will attempt to redirect every already-constructed + // fusion possible to the Triton emitter. + // + // For example, a fusion with kind kLoop will be transformed to a fusion with + // kind kCustom (and underlying kTritonFusionKind) if it can be tiled + // correctly, and if all the instructions it contains are supported by XLA's + // Triton emitter. Tile sizes are assigned automatically. + // + // Pre-existing block-level fusions are left unmodified. + bool xla_gpu_experimental_enable_fusion_block_level_rewriter = 334; + + // When enabled, the PriorityFusion pass will try to make Triton fusions first + // and foremost where it is possible. + // + // A kCustom fusion with underlying kTritonFusionKind will be created if it + // can be tiled correctly, and if all the instructions it contains are + // supported by XLA's Triton emitter. Tile sizes are assigned automatically. + bool xla_gpu_experimental_enable_triton_heroless_priority_fusion = 340; + // Gates the experimental feature coupling the Triton Softmax pattern matcher // with priority fusion. bool xla_gpu_experimental_enable_triton_softmax_priority_fusion = 325; @@ -567,6 +599,7 @@ message DebugOptions { CONDITIONALS = 5; CUSTOM_CALL = 6; CUBLASLT = 7; + DYNAMIC_SLICE = 8; } // Determine the types of commands that are recorded into command buffers. @@ -643,7 +676,8 @@ message DebugOptions { bool xla_gpu_enable_pipelined_all_gather = 227; bool xla_gpu_enable_pipelined_reduce_scatter = 231; bool xla_gpu_enable_pipelined_p2p = 246; - bool xla_gpu_run_post_layout_collective_pipeliner = 313; + + reserved 313; // Was xla_gpu_run_post_layout_collective_pipeliner. // The minimum data size in bytes to trigger collective-permute-decomposer // transformation. @@ -673,8 +707,6 @@ message DebugOptions { reserved 220; // Was xla_gpu_enable_triton_softmax_fusion - bool xla_gpu_enable_priority_fusion = 221; - reserved 286; // Was xla_gpu_enable_triton_softmax_priority_fusion // File to write autotune results to. It will be a binary file unless the name @@ -740,6 +772,9 @@ message DebugOptions { WHILE_LOOP_UNROLLING_DOUBLE_BUFFER = 1; // Enables full loop unrolling using the same strategy as `DOUBLE_BUFFER`. WHILE_LOOP_UNROLLING_FULL_UNROLL = 2; + // Enables loop unrolling when we have at least one collective within a + // while loop. + WHILE_LOOP_UNROLLING_AUTO_UNROLL = 3; } // Determine the while loop unrolling scheme. @@ -946,9 +981,6 @@ message DebugOptions { // If enabled, uses the libnvjitlink library for PTX compilation and linking bool xla_gpu_enable_libnvjitlink = 319; - // If enabled, generates triton gemm kernels for int4 inputs. - bool xla_gpu_enable_triton_gemm_int4 = 320; - // If true, XLA will wrap `dot` operations into async computations in an // effort to parallelize matrix operations. bool xla_gpu_async_dot = 321; @@ -968,13 +1000,55 @@ message DebugOptions { // Enables strict PGLE checking. If an FDO profile is specified and latency // hiding scheduler encounters missing instructions in the profile // compilation will halt. + // TODO: remove this field - it is replaced by xla_gpu_pgle_accuracy_checker. bool xla_gpu_enable_pgle_accuracy_checker = 326; // Timeouts for RendezvousSingle stuck warning and termination. int32 xla_gpu_executable_warn_stuck_timeout_seconds = 327; int32 xla_gpu_executable_terminate_timeout_seconds = 328; - // Next id: 330 + // Whether to ignore channel ids(including verifier channel id checks) + // for collectives in the given HLO. + bool xla_experimental_ignore_channel_id = 330; + + // DotMerger pass threshold size to be used in MB. + int32 xla_gpu_dot_merger_threshold_mb = 331; + + // If enabled, in the absence of user provided knobs might tune pass + // configurations based on the HLO. For example it decides to unroll the while + // loop by a factor of two if a collective op is present. + bool xla_gpu_enable_heuristic_pass_configuration = 332; + + // This controls how many in-flight collectives latency hiding scheduler + // can schedule. Example usage: + // With xla_gpu_experimental_parallel_collective_overlap_limit = 1: + // coll.1-start = collective(input) + // coll.1-done = collective(coll.1-start) + // coll.2-start = collective(input2) + // coll.2-done = collective(coll.2-start) + // With xla_gpu_experimental_parallel_collective_overlap_limit = 2: + // coll.1-start = collective(input) + // coll.2-start = collective(input2) + // coll.1-done = collective(coll.1-start) + // coll.2-done = collective(coll.2-start) + int32 xla_gpu_experimental_parallel_collective_overlap_limit = 336; + + // If set >= 0, this controls the total bytes(combined sizes of both + // operands in bytes) to enable windowed einsum and + // xla_gpu_threshold_for_windowed_einsum_mib will be ignored. + int64 xla_gpu_operand_bytes_threshold_for_windowed_einsum = 339; + + // Enables strict PGLE checking. If an FDO profile is specified and latency + // hiding scheduler encounters missing instructions in the profile + // compilation will halt or warn depending on the value of this option. + enum PGLEStrictnessLevel { + PGLE_STRICTNESS_LEVEL_OFF = 0; + PGLE_STRICTNESS_LEVEL_WARN = 1; + PGLE_STRICTNESS_LEVEL_ERROR = 2; + } + PGLEStrictnessLevel xla_gpu_pgle_accuracy_checker = 341; + + // Next id: 342 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. @@ -990,7 +1064,9 @@ message DebugOptions { // xla_gpu_graph_level // xla_gpu_single_wave_autotuning // xla_gpu_enable_persistent_temp_buffers - reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194, 242, 206; + // xla_gpu_enable_triton_gemm_int4 + // xla_gpu_enable_priority_fusion + reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194, 221, 242, 206, 320; } // Contains flags which affects the GPU compilation result. @@ -1012,7 +1088,7 @@ message ShardableValueUpdatePairProto { // will have an effect on every platform. // // When adding new fields, keep in mind that boolean fields default to false. -// Next id: 25. +// Next id: 27. message ExecutionOptions { // This optional field's layout is used as a hint when storing the output of // this computation. Subsequent transfers of this output array to the client @@ -1068,6 +1144,24 @@ message ExecutionOptions { // use_auto_spmd_partitioning=true. repeated int64 auto_spmd_partitioning_mesh_ids = 17; + // The amount of effort to spend on optimizing for minimizing program + // execution time, as a value in [-1.0, +1.0]. The baseline is 0.0, which + // strongly prioritizes execution time at the cost of longer compile times, + // suitable for production workloads. A value of -0.5 would be appropriate for + // research use cases that prefer faster compilations to iterate more quickly. + // Positive values, on the other hand, might enable costly optimizations that + // are off by default. + float exec_time_optimization_effort = 25; + + // The amount of effort to spend on making the program fit in memory (where + // "fit in memory" here has a backend-dependent meaning), as a value in + // [-1.0,+1.0]. The baseline is 0.0, which expends significant effort on + // attempting to make the program fit. A value of -1.0 would be appropriate + // for use cases that wish to spend minimal effort here and fail as quickly as + // possible instead. Positive values, on the other hand, might enable costly + // algorithms to reduce memory usage that are off by default. + float memory_fitting_effort = 26; + // If set, deduplicate hlo into function calls to reduce binary size. Only // works on TPU. bool deduplicate_hlo = 12; @@ -1129,7 +1223,7 @@ message ExecutionOptions { // Serialization of HloModuleConfig. See the C++ class definition for // descriptions of each field. // There are no guarantees of backwards or forwards compatibility. -// Next id: 36. +// Next id: 38. message HloModuleConfigProto { enum FusionConfigCollection { OFF = 0; // Do not collect configuration. @@ -1157,6 +1251,8 @@ message HloModuleConfigProto { bool use_auto_spmd_partitioning = 8; repeated int64 auto_spmd_partitioning_mesh_shape = 9; repeated int64 auto_spmd_partitioning_mesh_ids = 10; + float exec_time_optimization_effort = 36; + float memory_fitting_effort = 37; bool deduplicate_hlo = 11; int64 intra_op_parallelism_threads = 12; string device_type = 13; diff --git a/third_party/xla/xla/xla_data.proto b/third_party/xla/xla/xla_data.proto index 335b59e2064d23..c67116a167eea5 100644 --- a/third_party/xla/xla/xla_data.proto +++ b/third_party/xla/xla/xla_data.proto @@ -66,6 +66,9 @@ enum PrimitiveType { // F8E5M2 has 5 exponent bits and 2 mantissa bits, and is similar to the // existing IEEE types. // + // F8E4M3 has 4 exponent bits and 3 mantissa bits, and is similar to the + // existing IEEE types. + // // F8E4M3FN has 4 exponent bits and 3 mantissa bits. The "FN" means only // Finite and NaN values are supported. Unlike IEEE types, infinities are not // supported. NaN is represented when the exponent and mantissa bits are all @@ -77,12 +80,17 @@ enum PrimitiveType { // the exponent and mantissa bits are all 0s with a sign bit of 1. All other // values are finite. // + // F8E3M4 has 3 exponent bits and 4 mantissa bits, and is similar to the + // existing IEEE types. + // // Support for these dtypes is under development. They do not yet work // properly in most cases. // TODO(b/259609697): Fully support FP8. F8E5M2 = 19; + F8E4M3 = 28; F8E4M3FN = 20; F8E4M3B11FNUZ = 23; + F8E3M4 = 29; // FP8 dtypes, as described in this paper: https://arxiv.org/abs/2206.02915 // @@ -126,10 +134,9 @@ enum PrimitiveType { // primitive type will have empty dimensions and tuple_shapes fields. TOKEN = 17; - // Next = 28 + // Next = 30 } // LINT.ThenChange( -// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc, // https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc // ) @@ -573,12 +580,14 @@ message LiteralProto { bytes u16s = 16; bytes s16s = 17; bytes f8e5m2s = 19; + bytes f8e4m3s = 28; bytes f8e4m3fns = 20; bytes f8e4m3b11fnuzs = 23; bytes f8e5m2fnuzs = 24; bytes f8e4m3fnuzs = 25; + bytes f8e3m4s = 29; repeated int64 sparse_indices = 14; - // Next = 28 + // Next = 30 } message WindowDimension {